Paper ID: 2412.14374

Scaling Deep Learning Training with MPMD Pipeline Parallelism

Anxhelo Xhebraj, Sean Lee, Hanfeng Chen, Vinod Grover

We present JaxPP, a system for efficiently scaling the training of large deep learning models with flexible pipeline parallelism. We introduce a seamless programming model that allows users to implement their own pipeline schedules for gradient accumulation. JaxPP automatically distributes tasks, corresponding to pipeline stages, over a cluster of nodes and automatically infers the communication among them. We implement an MPMD runtime for asynchronous execution of SPMD tasks. The pipeline parallelism implementation of JaxPP improves hardware utilization by up to $1.11\times$ with respect to the best-performing SPMD configuration.
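For orientation, the sketch below is a minimal plain-JAX illustration of the gradient-accumulation pattern that user-defined pipeline schedules reorganize across stages and devices; it is not the JaxPP API (which the abstract does not show), and all names here (`loss_fn`, `accumulate_grads`, the toy linear model) are hypothetical.

```python
# Minimal plain-JAX sketch of gradient accumulation over microbatches.
# NOT the JaxPP API: in a pipelined setting, each microbatch's forward/backward
# would become a task assigned to a pipeline stage, and the schedule decides
# when each task runs on which device.
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    # Toy linear model; `params` is a (weights, bias) pair.
    w, b = params
    pred = x @ w + b
    return jnp.mean((pred - y) ** 2)

grad_fn = jax.grad(loss_fn)

def accumulate_grads(params, microbatches):
    # Sum per-microbatch gradients, then average over the number of microbatches.
    def body(acc, batch):
        x, y = batch
        g = grad_fn(params, x, y)
        return jax.tree_util.tree_map(jnp.add, acc, g), None

    zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
    total, _ = jax.lax.scan(body, zeros, microbatches)
    num_micro = microbatches[0].shape[0]
    return jax.tree_util.tree_map(lambda g: g / num_micro, total)

# Example usage with 8 microbatches of 16 examples each.
key = jax.random.PRNGKey(0)
params = (jnp.zeros((4, 1)), jnp.zeros((1,)))
xs = jax.random.normal(key, (8, 16, 4))
ys = jax.random.normal(key, (8, 16, 1))
grads = accumulate_grads(params, (xs, ys))
```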

Submitted: Dec 18, 2024