DistJax is a mini-library and collection of examples designed to simplify the implementation of common distributed training paradigms in JAX and Flax. While JAX provides powerful low-level primitives like pmap and shard_map for parallelism, orchestrating them into cohesive, large-scale training strategies can be complex. This repository provides high-level, reusable building blocks for data parallelism, tensor parallelism (including asynchronous variants), and pipeline parallelism, allowing researchers and engineers to scale their models with clarity and confidence.
This library provides modular components and end-to-end examples for the following paradigms, which can be mixed and matched to suit your specific hardware and model architecture.
Data parallelism is the foundational technique of replicating a model's weights across multiple devices and processing a different shard of the data batch on each one. It is highly effective for scaling but requires each device to hold a full copy of the model, which can become a memory bottleneck.
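As a minimal sketch of the idea (not the implementation in `parallelism/data_parallel.py`), the snippet below computes gradients on each device's shard of the batch and averages them with `jax.lax.pmean`; the toy linear model and loss are assumptions for illustration.

```python
import jax
import jax.numpy as jnp
from functools import partial

# Each device runs this on its own shard of the batch (leading axis = devices).
@partial(jax.pmap, axis_name="data")
def data_parallel_grads(params, xs, ys):
    def loss_fn(p):
        # Toy squared-error loss over a linear model, purely illustrative.
        return jnp.mean((xs @ p - ys) ** 2)
    grads = jax.grad(loss_fn)(params)
    # Average gradients across the 'data' axis so every replica applies the
    # same update and the weights stay identical on all devices.
    return jax.lax.pmean(grads, axis_name="data")
```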
Tensor parallelism is a model-parallel strategy that shards individual layers (such as the weight matrices in Dense or Attention layers) across devices, making it possible to train models that are too large to fit on a single device. This library includes:
- Standard synchronous communication using collective operations like `all_gather` and `psum_scatter`, which are easy to reason about but introduce explicit synchronization points.
- Advanced asynchronous communication primitives that leverage JAX's `ppermute` to overlap communication with computation. By passing activations between devices in a staggered, ring-like fashion, this approach can hide communication latency and significantly improve GPU/TPU utilization (see the sketch after this list).
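To make the ring pattern concrete, here is a hedged sketch of an all-gather built from `jax.lax.ppermute`; it illustrates the communication pattern only and is not the code in `parallelism/tensor_parallel_async.py`. It assumes it runs inside `shard_map` or `pmap` over a 'model' axis of `axis_size` devices.

```python
import jax
import jax.numpy as jnp

def ring_all_gather(x, axis_size, axis_name="model"):
    """Collect every device's shard by passing chunks around a ring."""
    # Device i sends its current chunk to device i + 1 (wrapping around).
    perm = [(i, (i + 1) % axis_size) for i in range(axis_size)]
    chunks = [x]
    for _ in range(axis_size - 1):
        # In the asynchronous layers this transfer can overlap with compute
        # on the chunks that have already arrived.
        x = jax.lax.ppermute(x, axis_name, perm=perm)
        chunks.append(x)
    # Note: each device sees the chunks in a rotated order compared with
    # jax.lax.all_gather, which is fine as long as the weights match that order.
    return jnp.concatenate(chunks, axis=-1)
```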
Pipeline parallelism is a model-parallel strategy that stages sequential model layers or blocks across different devices. To keep all devices active, the input batch is split into smaller "micro-batches" that are fed into the pipeline in a staggered manner. This minimizes the "pipeline bubble" (the idle time when devices sit waiting for data), making it an efficient strategy for very deep models.
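The following is a deliberately simplified sketch of that schedule, not the `parallelism/pipeline_parallel.py` implementation: the batch is split into micro-batches, and on every tick each device applies its stage and hands its activations to the next stage with `ppermute`. The stage function, shapes, and axis name are assumptions, and the stage is assumed to preserve the activation shape.

```python
import jax
import jax.numpy as jnp

def run_pipeline(stage_fn, batch, num_microbatches, num_stages, axis_name="pipe"):
    """Toy looping pipeline; meant to run inside shard_map/pmap over `axis_name`."""
    stage_idx = jax.lax.axis_index(axis_name)
    # Split (B, ...) into (num_microbatches, B // num_microbatches, ...).
    microbatches = batch.reshape(num_microbatches, -1, *batch.shape[1:])
    carry = jnp.zeros_like(microbatches[0])
    outputs = []
    # The extra num_stages - 1 ticks are the pipeline "bubble" spent filling
    # and draining the stages.
    for t in range(num_microbatches + num_stages - 1):
        # Stage 0 injects a fresh micro-batch; later stages consume whatever
        # the previous stage handed them on the last tick.
        inp = jnp.where(stage_idx == 0,
                        microbatches[min(t, num_microbatches - 1)],
                        carry)
        out = stage_fn(inp)
        outputs.append(out)
        # Pass activations to the next stage in the ring.
        carry = jax.lax.ppermute(
            out, axis_name,
            perm=[(i, (i + 1) % num_stages) for i in range(num_stages)])
    # Only the last stage's final outputs correspond to completed micro-batches.
    return jnp.stack(outputs[-num_microbatches:])
```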
The components are designed to be composable. For example, you can combine Data Parallelism and Tensor Parallelism: within a group of 8 GPUs, you might use 4-way tensor parallelism to shard a large model, and then replicate this 4-GPU setup twice for 2-way data parallelism. This allows for flexible scaling across both model size and data throughput.
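For example, the 8-GPU layout above maps directly onto a JAX device mesh (the device count here is an assumption; adjust the reshape to your hardware):

```python
import jax
import numpy as np
from jax.sharding import Mesh

# Arrange 8 devices into a 2 x 4 grid: 2-way data parallelism by 4-way tensor
# parallelism. 'data'-annotated arrays are sharded across the first axis,
# 'model'-annotated parameters across the second.
devices = np.array(jax.devices()).reshape(2, 4)
mesh = Mesh(devices, axis_names=("data", "model"))
```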
To quickly run the examples, follow these steps:
```bash
# 1. Clone the repository
git clone https://github.com/your-username/DistJax.git
cd DistJax

# 2. Install dependencies
pip3 install -r requirements.txt

# 3. Run the data-parallel example
python3 -m examples.data_parallelism

# 4. Run the tensor-parallel example
python3 -m examples.tensor_parallelism
```

The repository is organized to separate reusable logic from specific model implementations and training scripts. This clean separation of concerns makes the library easy to navigate and extend.
```
DistJax/
├── core/ # Generic training utilities (TrainState, attention ops)
│ ├── attention.py
│ ├── module_utils.py
│ ├── training.py
│ └── utils.py
├── parallelism/ # The core parallelism primitives and modules
│ ├── data_parallel.py
│ ├── pipeline_parallel.py
│ ├── sharding.py
│ ├── tensor_parallel_async.py
│ └── tensor_parallel.py
├── models/ # Example model architectures built with the library
│ ├── mlp.py
│ ├── pp_classifier.py
│ ├── simple_classifier.py
│ ├── tp_classifier.py
│ └── transformer.py
├── configs/ # Configuration files for the models
│ ├── default_config.py
│ └── tp_config.py
├── examples/ # Standalone scripts to run training for each paradigm
│ ├── data_parallelism.py
│ ├── tensor_parallelism.py
│ └── utils.py
├── README.md
└── requirements.txt
```
The DistJax/parallelism directory contains the core building blocks for various parallelism strategies:
- `data_parallel.py`: Implements data parallelism with synchronized gradients.
- `pipeline_parallel.py`: Provides tools for pipeline parallelism, including micro-batching and model wrappers.
- `sharding.py`: Contains utilities for sharding parameters, including Fully Sharded Data Parallelism (FSDP).
- `tensor_parallel.py`: Implements synchronous tensor parallelism.
- `tensor_parallel_async.py`: Implements asynchronous tensor parallelism.
The DistJax/models directory contains several example models that demonstrate how to use the parallelism strategies:
- `simple_classifier.py`: A basic classifier for demonstrating data parallelism.
- `tp_classifier.py`: A classifier that uses tensor parallelism.
- `pp_classifier.py`: A classifier that uses pipeline parallelism.
- `transformer.py`: A Transformer model with tensor parallelism.
- `mlp.py`: Contains various MLP blocks, including tensor-parallel and asynchronous versions.
Follow these steps to set up the environment and run one of the examples.
```bash
git clone https://github.com/your-username/DistJax.git
cd DistJax
```

It's recommended to use a virtual environment to manage dependencies.

```bash
python3 -m venv venv
source venv/bin/activate
pip3 install -r requirements.txt
```

Note: For GPU or TPU support, ensure you have installed the appropriate version of JAX by following the official JAX installation guide.
The scripts in the examples/ directory are designed to be run directly. They will simulate a multi-device environment on your CPU for demonstration purposes.
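If you want to reproduce that setup in your own scripts, a common way to expose several virtual devices on a single CPU is the XLA host-platform flag shown below (this is the standard trick, not necessarily the exact mechanism used inside the example scripts):

```python
import os

# Must be set before importing jax; asks XLA to expose 8 virtual CPU devices.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
print(jax.devices())  # prints eight CPU devices
```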
To run the data-parallel training example:
```bash
python3 -m examples.data_parallelism
```

This will run a few training steps and print the final metrics.
The library is built around a few key ideas to promote modularity and ease of use.
The DistJax.parallelism module contains the core building blocks. These are often implemented as flax.linen.Module wrappers that inject parallelism logic into standard layers. For example, tensor_parallel.py provides TPDense, which looks like a normal Dense layer but automatically handles the sharding of its weights and the necessary communication (all_gather or psum_scatter) of its inputs and outputs across a device mesh.
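The sketch below shows the underlying column-parallel pattern in a stripped-down form; it is hypothetical and does not mirror the actual `TPDense` signature. It assumes the module is called inside `shard_map` or `pmap` with a 'model' axis.

```python
import jax
import flax.linen as nn

class ShardedDense(nn.Module):
    """Illustrative column-parallel dense layer (hypothetical, not TPDense)."""
    features_per_device: int  # this device's slice of the full output dimension
    axis_name: str = "model"

    @nn.compact
    def __call__(self, x):
        # Each device holds and applies only its shard of the weight matrix.
        y = nn.Dense(self.features_per_device)(x)
        # Gather the per-device output slices into the full feature dimension.
        return jax.lax.all_gather(y, self.axis_name, axis=-1, tiled=True)
```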
The DistJax.models directory demonstrates the design philosophy: model architecture should be decoupled from parallelism logic. The models are constructed by composing the parallel primitives from the library. For instance, transformer.py builds a fully tensor-parallel Transformer block by using TPAsyncDense for its MLP and attention layers, without cluttering the model definition with low-level communication code.
The DistJax.examples scripts tie everything together and provide a blueprint for your own training runs. They handle the essential boilerplate for distributed training:
- Setting up the JAX Mesh: A `Mesh` defines the logical topology of your devices (e.g., an 8-device array with a 'data' axis and a 'model' axis). This abstraction is crucial for telling JAX how to distribute data and computations.
- Loading Model Configurations: Using `ml_collections.ConfigDict` for clean and hierarchical management of hyperparameters.
- Initializing the Model State: Using `shard_map` to correctly initialize parameters across all devices according to the specified parallelism strategy. This ensures each device gets only its designated shard of the model.
- Defining the Parallel `train_step`: The core training function is written once and then parallelized using `shard_map`, with `PartitionSpec` annotations to define how the state, metrics, and data are sharded (see the condensed sketch after this list).
- Running the Main Training Loop: The loop executes the JIT-compiled parallel `train_step`, passing sharded data and updating the distributed model state.
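Condensed into one hedged sketch (the toy model, learning rate, and shapes are assumptions, and the real example scripts manage a full TrainState), the pieces above fit together roughly like this:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

# A single data-parallel axis spanning all available devices.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

def train_step(params, xs, ys):
    def loss_fn(p):
        # Toy linear model with squared-error loss, standing in for the real model.
        return jnp.mean((xs @ p - ys) ** 2)
    loss, grads = jax.value_and_grad(loss_fn)(params)
    # Average gradients and metrics across the data-parallel axis.
    grads = jax.lax.pmean(grads, "data")
    loss = jax.lax.pmean(loss, "data")
    return params - 1e-2 * grads, loss

# Parameters stay replicated (P()); the batch is sharded along 'data'.
parallel_train_step = jax.jit(
    shard_map(train_step, mesh,
              in_specs=(P(), P("data"), P("data")),
              out_specs=(P(), P())))
```

The training loop then simply calls `parallel_train_step(params, xs, ys)` with the sharded batch on each step.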
Contributions are welcome! If you have ideas for improvements, new features, or find any bugs, please feel free to open an issue or submit a pull request. Potential areas for future work include:
- More model examples (e.g., Mixture-of-Experts, Vision Transformers)
- Support for more advanced optimizers tailored for distributed settings (e.g., ZeRO)
- Enhanced documentation with more in-depth tutorials and conceptual guides
- Integration with more advanced JAX features as they become available
This project is licensed under the MIT License. See the LICENSE file for details.