Paper ID: 2412.09734

MPAX: Mathematical Programming in JAX

Haihao Lu, Zedong Peng, Jinwen Yang

We introduce MPAX (Mathematical Programming in JAX), a versatile and efficient toolbox for integrating mathematical programming into machine learning workflows. MPAX implements first-order methods in JAX, providing native support for hardware acceleration along with features such as batch solving, auto-differentiation, and device parallelism. Currently in beta, MPAX supports linear programming; it will be extended to more general mathematical programming problems and will add specialized modules for common machine learning tasks. The solver is available at this https URL
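The abstract does not say which first-order method MPAX uses or what its API looks like. The following is a minimal sketch that illustrates what "a first-order LP method in JAX with batch solving" can mean. It assumes the vanilla primal-dual hybrid gradient (PDHG) method applied to a standard-form LP, minimize c·x subject to Ax = b and x ≥ 0. The function `pdhg_lp`, its `num_iters` parameter, and the fixed step-size rule are hypothetical illustrations. They are not MPAX's interface or its exact algorithm.

```python
import jax
import jax.numpy as jnp


def pdhg_lp(c, A, b, num_iters=5000):
    """Vanilla PDHG for: minimize c @ x subject to A @ x == b, x >= 0.

    Hypothetical illustration only; not MPAX's API or its exact algorithm.
    """
    # Step sizes chosen so that tau * sigma * ||A||_2^2 = 0.81 < 1,
    # the standard convergence condition for PDHG.
    op_norm = jnp.linalg.norm(A, ord=2)
    tau = sigma = 0.9 / op_norm

    def step(_, state):
        x, y = state
        # Primal: projected gradient step on the Lagrangian c@x - y@(A@x - b).
        x_new = jnp.maximum(x - tau * (c - A.T @ y), 0.0)
        # Dual: gradient ascent step evaluated at the extrapolated point 2*x_new - x.
        y_new = y + sigma * (b - A @ (2.0 * x_new - x))
        return x_new, y_new

    x0 = jnp.zeros(A.shape[1])
    y0 = jnp.zeros(A.shape[0])
    x, y = jax.lax.fori_loop(0, num_iters, step, (x0, y0))
    return x, y


# A batch of three small LPs that share the constraint matrix A
# but differ in the cost vector c and the right-hand side b.
A = jnp.array([[1.0, 1.0]])
cs = jnp.array([[1.0, 2.0], [3.0, 1.0], [1.0, 2.0]])
bs = jnp.array([[1.0], [1.0], [2.0]])

# vmap maps over the batch axis of c and b while sharing A; jit compiles the
# whole batched solve, so it can run on an accelerator if one is available.
batched_solve = jax.jit(jax.vmap(pdhg_lp, in_axes=(0, None, 0)))
xs, ys = batched_solve(cs, A, bs)
```

Each LP in this toy batch can be solved by hand. The optimal primal solutions are x = (1, 0), (0, 1), and (2, 0), and every instance has dual solution y = 1. Because the batch is expressed with `jax.vmap` rather than a Python loop, all instances advance through the same compiled iteration together, which is the property that makes batch solving cheap on GPUs.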

Submitted: Dec 12, 2024