Efficient Distributed Training: From DP to ZeRO and FlashAttention

Training billion-parameter models on a single GPU is no longer feasible: both memory and compute demands exceed what any single device can provide. The solution is parallelism, but choosing the right parallelism strategies, and combining them effectively, is itself a non-trivial engineering problem. Data parallelism replicates the model but partitions the data; tensor parallelism slices individual operators across devices; pipeline parallelism assigns different layers to different GPUs; and sequence parallelism handles the regions that tensor parallelism leaves replicated. In practice, large-scale training combines all of these into a multi-dimensional parallel strategy.

This post covers the core parallelism methods—Data Parallelism, Megatron-style Tensor Parallelism, GPipe/PipeDream Pipeline Parallelism, and Sequence Parallelism—along with memory optimization techniques including ZeRO, PyTorch FSDP, mixed precision training, activation recomputation, and FlashAttention. For each method, we provide the key formulations, trade-offs, and reference implementations that practitioners need to make informed decisions when scaling training.

Training Parallelism

Data Parallelism

Data Parallelism (DP) replicates the full model on each device and partitions the mini-batch across them. Each device executes forward and backward passes on its data shard independently, then gradients are averaged via all-reduce before the synchronized weight update. DP is simple to implement and scales well when the model fits in a single GPU’s memory.
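A minimal sketch of this pattern with PyTorch's DistributedDataParallel (the toy model, batch, and torchrun launch below are illustrative assumptions, not a specific recipe):

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Launch with: torchrun --nproc_per_node=<num_gpus> train_dp.py
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

model = DDP(torch.nn.Linear(1024, 1024).cuda(), device_ids=[local_rank])  # full replica per rank
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

x = torch.randn(32, 1024, device="cuda")    # this rank's shard of the mini-batch
loss = model(x).pow(2).mean()
loss.backward()                              # DDP averages gradients across ranks via all-reduce
optimizer.step()                             # every replica applies the identical update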

Model Parallelism (Tensor Parallelism)

Tensor-level Model Parallelism (TP) [1] distributes individual operators across devices, enabling training of models too large to fit on a single GPU. Megatron-LM pioneered this approach for transformers.

MLP Block

Consider the standard MLP within a transformer:

\[Y = \text{GeLU}(XA), \qquad Z = \text{Dropout}(YB)\]

Megatron-LM tensor parallelism for MLP

TP splits $A$ along columns: \(A = [A_1, A_2]\). Since GeLU is element-wise, it can be applied independently to each partition:

\[[Y_1, Y_2] = [\text{GeLU}(XA_1), \text{GeLU}(XA_2)]\]

Matrix $B$ is then split along rows, allowing direct consumption of the GeLU outputs without inter-device communication until the final all-reduce.
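As a sketch, one rank's forward pass under this column/row split looks like the hypothetical helper below (Megatron-LM wraps the same logic in its ColumnParallelLinear and RowParallelLinear layers):

import torch
import torch.distributed as dist

def tp_mlp_forward(x, a_shard, b_shard, tp_group):
    # a_shard: this rank's column slice of A, shape (h, 4h / tp_size)
    # b_shard: the matching row slice of B, shape (4h / tp_size, h)
    y_local = torch.nn.functional.gelu(x @ a_shard)   # GeLU is element-wise, so no sync is needed
    z_partial = y_local @ b_shard                      # partial sum of Z held by this rank
    dist.all_reduce(z_partial, group=tp_group)         # the single forward all-reduce; dropout follows
    return z_partial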

Self-Attention Block

For self-attention:

\[\text{Attention}(X,Q,K,V) = \text{softmax}\!\left(\frac{(XQ)(XK)^\top}{\sqrt{d_k}}\right) XV\]

TP partitions the $Q$, $K$, $V$ projection matrices along columns so that each attention head resides entirely on a single GPU, and splits the output projection along rows. Combined with the MLP split above, each transformer layer needs only four all-reduce operations in total: two in the forward pass and two in the backward pass (one each for the attention and MLP blocks).
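A per-rank sketch of the head-parallel attention block (a hypothetical helper with masking and dropout omitted; each rank holds n_local = n_heads / tp_size heads):

import torch
import torch.distributed as dist

def tp_self_attention(x, wq, wk, wv, wo, n_local, tp_group):
    # wq, wk, wv: column shards covering this rank's heads; wo: the matching row shard
    b, s, _ = x.shape
    d_head = wq.shape[1] // n_local
    def split_heads(t):  # (b, s, n_local * d_head) -> (b, n_local, s, d_head)
        return t.view(b, s, n_local, d_head).transpose(1, 2)
    q, k, v = split_heads(x @ wq), split_heads(x @ wk), split_heads(x @ wv)
    attn = torch.softmax(q @ k.transpose(-2, -1) / d_head ** 0.5, dim=-1)
    ctx = (attn @ v).transpose(1, 2).reshape(b, s, n_local * d_head)
    out = ctx @ wo                           # partial output from this rank's heads
    dist.all_reduce(out, group=tp_group)     # the one forward all-reduce of the attention block
    return out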

Communication operations in tensor parallelism

Non-parallel components: Dropout, layer normalization, and residual connections are replicated (not split) across GPUs. Each GPU maintains a duplicate of the LayerNorm parameters, allowing these operations to proceed locally without communication.

Pipeline Parallelism

Pipeline Parallelism (PP) [2][3][4] assigns consecutive groups of layers to different devices, with data flowing through them sequentially as micro-batches.

GPipe

GPipe [2] splits the mini-batch into $m$ micro-batches that flow through the $p$ pipeline stages concurrently.

GPipe pipeline schedule

Pipeline bubble: The unavoidable idle time at the start and end of each batch. With $p$ stages and $m$ micro-batches:

\[\text{Bubble fraction} = \frac{p - 1}{m + p - 1}\]

GPipe micro-batch schedule

When $m \gg p$, the bubble overhead becomes negligible.
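For concreteness (an illustrative configuration, not taken from any particular run): with $p = 4$ stages and $m = 16$ micro-batches the bubble fraction is $\frac{3}{19} \approx 15.8\%$, and doubling to $m = 32$ micro-batches lowers it to $\frac{3}{35} \approx 8.6\%$.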

PipeDream (1F1B)

PipeDream [3] introduces the one-forward-one-backward (1F1B) schedule: each device starts the backward pass for a micro-batch as soon as its forward pass completes. In steady state every device alternates between forward and backward work, and the number of in-flight micro-batches (and hence activation memory) is bounded by the pipeline depth rather than by $m$.

PipeDream 1F1B schedule

With the interleaved schedule [4], each device is assigned $v$ model chunks (non-contiguous groups of layers) instead of one contiguous stage. Each chunk's forward or backward pass takes $1/v$ of a full stage's time, which shrinks the bubble to:

\[\text{Bubble fraction} = \frac{1}{v} \cdot \frac{p - 1}{m + p - 1}\]

PipeDream interleaved pipeline
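Continuing the illustrative numbers above ($p = 4$, $m = 16$), assigning $v = 2$ chunks per device halves the bubble from roughly $15.8\%$ to $\frac{1}{2}\cdot\frac{3}{19} \approx 7.9\%$, at the cost of proportionally more inter-stage communication.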

Sequence Parallelism

Sequence Parallelism (SP) [5] targets the non-tensor-parallel regions of a transformer (LayerNorm, dropout, residual connections), which are independent along the sequence dimension.

Sequence Parallelism (Megatron-LM v3)

The standard non-parallel block:

\[\begin{align} Y &= \text{LayerNorm}(X), \\ Z &= \text{GeLU}(YA), \\ W &= ZB, \\ V &= \text{Dropout}(W) \end{align}\]

SP splits $X$ along the sequence dimension: \(X = [X_1^s, X_2^s]\). After LayerNorm (computed sequence-parallel), an all-gather provides the full $Y$ for the TP region; after the TP GEMMs, a reduce-scatter returns to sequence-parallel sharding. In the equations below, $g$ denotes an all-gather in the forward pass (and a reduce-scatter in the backward pass), while $\bar{g}$ denotes a reduce-scatter in the forward pass (and an all-gather in the backward pass):

\[\begin{align} [Y_1^s, Y_2^s] &= \text{LayerNorm}([X_1^s, X_2^s]), \\ Y &= g(Y_1^s, Y_2^s), \\ [Z_1^h, Z_2^h] &= [\text{GeLU}(YA_1^c),\; \text{GeLU}(YA_2^c)], \\ W_1 &= Z_1^h B_1^r, \quad W_2 = Z_2^h B_2^r, \\ [W_1^s, W_2^s] &= \bar{g}(W_1, W_2), \\ [V_1^s, V_2^s] &= [\text{Dropout}(W_1^s),\; \text{Dropout}(W_2^s)] \end{align}\]

Tensor + Sequence Parallelism

SP replaces each all-reduce in TP with an all-gather (entering the TP region) and a reduce-scatter (leaving it). Since an all-reduce is exactly the composition of the two, total communication volume stays constant, while the activation memory redundancy in the non-TP regions is eliminated.
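A minimal sketch of the $g$ / $\bar{g}$ transitions around one TP GEMM pair (a hypothetical helper; tensors are laid out as (sequence, batch, hidden) so the sequence dimension is dim 0):

import torch
import torch.distributed as dist

def sp_tp_mlp(x_seq_shard, a_col_shard, b_row_shard, tp_group):
    # x_seq_shard: (seq / tp_size, batch, h); a_col_shard: (h, d_ff / tp_size); b_row_shard: (d_ff / tp_size, h)
    world = dist.get_world_size(tp_group)
    y_seq_shard = torch.nn.functional.layer_norm(x_seq_shard, x_seq_shard.shape[-1:])
    # g: all-gather along the sequence dimension to enter the tensor-parallel region
    y_full = torch.empty(world * y_seq_shard.shape[0], *y_seq_shard.shape[1:],
                         device=x_seq_shard.device, dtype=x_seq_shard.dtype)
    dist.all_gather_into_tensor(y_full, y_seq_shard, group=tp_group)
    w_partial = torch.nn.functional.gelu(y_full @ a_col_shard) @ b_row_shard
    # g-bar: reduce-scatter back to a sequence-sharded layout before dropout
    w_seq_shard = torch.empty_like(x_seq_shard)
    dist.reduce_scatter_tensor(w_seq_shard, w_partial, group=tp_group)
    return torch.nn.functional.dropout(w_seq_shard, p=0.1)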

ZeRO: Memory-Efficient Optimizer

Zero Redundancy Optimizer (ZeRO) [6] eliminates memory redundancy across DP ranks by partitioning (rather than replicating) model states. ZeRO has three progressive stages, whose per-GPU memory footprints are worked out below:

  1. \(P_\text{os}\): Partition optimizer states only.
  2. \(P_\text{os+g}\): Additionally partition gradients.
  3. \(P_\text{os+g+p}\): Additionally partition parameters.
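As a rough accounting (following the ZeRO paper's mixed-precision Adam setup, with $\Psi$ parameters and data-parallel degree $N_d$): FP16 parameters and gradients take $2\Psi$ bytes each, and the FP32 Adam states (master weights, momentum, variance) take $12\Psi$ bytes, so per-GPU model-state memory at each stage is roughly:

\[\begin{align} \text{Baseline:}\quad & 2\Psi + 2\Psi + 12\Psi = 16\Psi, \\ P_\text{os}:\quad & 2\Psi + 2\Psi + \frac{12\Psi}{N_d}, \\ P_\text{os+g}:\quad & 2\Psi + \frac{(2 + 12)\Psi}{N_d}, \\ P_\text{os+g+p}:\quad & \frac{16\Psi}{N_d} \end{align}\]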

ZeRO memory optimization stages

See DeepSpeed ZeRO tutorial for implementation details.

PyTorch FSDP

Fully Sharded Data Parallel (FSDP) [7] shards model parameters, gradients, and optimizer states across devices. During the forward and backward passes, only one FSDP unit's parameters are unsharded (all-gathered) at a time and are freed immediately after use, so peak parameter memory is roughly the sharded total plus one unit's full size.

PyTorch FSDP overview

FSDP uses a sharding factor $F$:

  • $F = 1$: equivalent to standard DP (all-reduce on gradients).
  • $F = W$ (world size): full sharding—each device holds $1/W$ of the model.
  • $1 < F < W$: hybrid sharding.

FSDP full sharding

Sharding strategy (flatten-concat-chunk): All parameters within an FSDP unit are flattened into a single contiguous FlatParameter, padded if necessary for divisibility by $F$, then chunked evenly across ranks.

FSDP hybrid sharding
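A minimal wrapping sketch using PyTorch's FSDP API (the toy model is an assumption, and the process group is assumed to be initialized already, e.g. via torchrun):

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

model = torch.nn.Sequential(
    torch.nn.Linear(4096, 4096),
    torch.nn.GELU(),
    torch.nn.Linear(4096, 4096),
).cuda()

fsdp_model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,   # F = W: parameters, grads, optimizer states all sharded
    # ShardingStrategy.NO_SHARD corresponds to F = 1 (plain DP);
    # ShardingStrategy.HYBRID_SHARD covers 1 < F < W.
)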

Mixed Precision Training

Mixed precision [8][9] executes most operations in FP16 (or BF16) for speed, while maintaining an FP32 master copy of weights for numerical stability.

Precision format evolution

Mixed Precision Training workflow

Training procedure:

  1. Maintain FP32 master weights.
  2. Each iteration: cast to FP16 → forward → scale loss by $S$ → backward → unscale gradients by $1/S$ → update FP32 master weights.

Dynamic Loss Scaling

Start with a large scaling factor $S$. If no overflow occurs for $N$ iterations, increase $S$. If overflow is detected, skip the update and decrease $S$.

import torch
import torch.optim as optim
from torch.amp import GradScaler, autocast

model = Net().cuda()                        # Net, loss_fn, data, epochs assumed defined elsewhere
optimizer = optim.SGD(model.parameters(), ...)
scaler = GradScaler()                       # dynamic loss scaling state

for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()
        # run the forward pass with FP16 autocasting where it is numerically safe
        with autocast(device_type='cuda', dtype=torch.float16):
            output = model(input)
            loss = loss_fn(output, target)
        scaler.scale(loss).backward()       # scale the loss by S, then backpropagate
        scaler.step(optimizer)              # unscale gradients; skip the step if overflow is detected
        scaler.update()                     # grow or shrink S based on overflow history

AMP Optimization Levels (NVIDIA Apex)

Level | Description                    | Weights                  | Loss Scaling
O0    | Pure FP32                      | FP32                     | None
O1    | Mixed precision (recommended)  | FP32, ops auto-cast      | Dynamic
O2    | Almost FP16                    | FP16 + FP32 master copy  | Dynamic
O3    | Pure FP16 (speed baseline)     | FP16                     | None

Floating-Point Formats

Format | Exponent | Mantissa | Dynamic Range | Use Case
FP32   | 8 bits   | 23 bits  | Wide          | Master weights, accumulation
FP16   | 5 bits   | 10 bits  | Limited       | Forward/backward compute
BF16   | 8 bits   | 7 bits   | Same as FP32  | Preferred for training stability
FP8    | 4-5 bits | 2-3 bits | Narrow        | Emerging for inference/training

Floating-point format comparison
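These ranges are easy to confirm from PyTorch itself (a quick illustrative check):

import torch

for dtype in (torch.float32, torch.float16, torch.bfloat16):
    info = torch.finfo(dtype)
    # FP16 tops out at 65504, while BF16 keeps FP32's ~3.4e38 range at reduced precision
    print(dtype, info.max, info.eps)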

Memory-Efficient Techniques

CPU Offload

When GPU memory is exhausted, model states (optimizer states, parameters) can be offloaded to CPU memory and fetched back on demand. ZeRO-Offload and ZeRO-Infinity implement this strategy.

Activation Recomputation

Instead of storing all intermediate activations for backpropagation, selective activation recomputation stores only a subset (e.g., activations at layer boundaries) and recomputes the rest during the backward pass. This trades compute for memory, typically with minimal wall-clock overhead.
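In PyTorch this is exposed through torch.utils.checkpoint; a minimal sketch for a stack of transformer blocks (the blocks themselves are assumed to be ordinary nn.Module instances):

import torch
from torch.utils.checkpoint import checkpoint

def forward_with_recomputation(blocks, x):
    for block in blocks:
        # Only the block's inputs are kept live; its internal activations are
        # recomputed on the fly during the backward pass.
        x = checkpoint(block, x, use_reentrant=False)
    return x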

FlashAttention

FlashAttention [10] reduces HBM (High Bandwidth Memory) accesses by tiling the attention computation into SRAM-sized blocks, avoiding materialization of the full $N \times N$ attention matrix.

FlashAttention tiling illustration

FlashAttention algorithm

FlashAttention v1 loops over K/V blocks in the outer loop and Q blocks in the inner loop; v2 swaps the loop order so that each query block's output accumulates in SRAM and is written to HBM only once, further cutting memory traffic. The result is exact attention (not an approximation) with $O(N)$ extra memory instead of $O(N^2)$.
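In PyTorch, the easiest way to get a FlashAttention-style fused kernel is scaled_dot_product_attention, which dispatches to one on supported GPUs (an illustrative call, not the reference flash-attn package):

import torch
import torch.nn.functional as F

# (batch, heads, seq_len, head_dim) in FP16 on GPU, so a fused kernel can be selected
q = torch.randn(2, 16, 4096, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # exact attention, no N x N matrix materialized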

References

  1. Shoeybi, M., et al. Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism. arXiv:1909.08053, 2019. 

  2. Huang, Y., et al. GPipe: Easy Scaling with Micro-Batch Pipeline Parallelism. NeurIPS 2019.

  3. Narayanan, D., Harlap, A., et al. PipeDream: Generalized Pipeline Parallelism for DNN Training. SOSP 2019.

  4. Narayanan, D., et al. Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM. SC 2021. 

  5. Korthikanti, V.A., et al. Reducing Activation Recomputation in Large Transformer Models. MLSys 2023. 

  6. Rajbhandari, S., et al. ZeRO: Memory Optimizations Toward Training Trillion Parameter Models. SC 2020. 

  7. Zhao, Y., et al. PyTorch FSDP: Experiences on Scaling Fully Sharded Data Parallel. arXiv:2304.11277, 2023. 

  8. NVIDIA. Training with Mixed Precision.

  9. Micikevicius, P., et al. Mixed Precision Training. ICLR 2018. 

  10. Dao, T., et al. FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. NeurIPS 2022. 



