Why We Moved an AudioLLM to Megatron

Scaling a 10B multimodal model beyond FSDP with tensor, pipeline, and context parallelism


We trained our 10B-parameter AudioLLM — a Whisper speech encoder fused with a Gemma2 9B text decoder — using Megatron with Mosaic Streaming to handle training data.

The wall

The architecture is a Whisper-large-v3 speech encoder feeding into a Gemma2 9B text decoder through a small MLP adapter. At 10B parameters with 30-second audio clips, a single H100 GPU with 80GB VRAM can just barely hold the full model and its activations at small micro-batch sizes.

Our multimodal trainer only supported FSDP. FSDP shards parameters, gradients, and optimizer states across GPUs, then all-gathers the parameters for each forward/backward pass. Read here for a primer on data parallelism.

While FSDP lets us scale the global batch size across multiple nodes, FSDP alone is not enough to keep long sequences from blowing up memory. A 30-second audio clip produces ~1500 encoder frames, which after 5× adapter compression become ~300 tokens spliced into a 1040-token context. On our H100 nodes (80 GB VRAM per GPU), we could train at ~1k context but couldn't generate more than ~500 output tokens. Scaling to 8k+ contexts for multi-turn conversations with audio was out of the question: activation memory scales roughly linearly with sequence length, and FSDP shards parameters but not activations, so every GPU still holds the full activation tensor for its batch. Activation checkpointing bought us ~60% back, but not enough. We were also stuck in the 10B to 27B parameter range: anything larger wouldn't feasibly fit with FSDP sharding across the full cluster while maintaining a sufficiently large global batch size and a reasonable wall-clock time.

What Megatron gives you

Megatron provides native support for four parallelism strategies, each mapped to the level of the hardware where it's cheapest.

1. Use Tensor Parallelism (TP) for Intra-Node Scaling: When you have a group of GPUs (e.g., the 8 A100s in a DGX node) connected with a high-speed backplane, tensor parallelism is the ideal way to scale up to handle massive layers. TP splits weight matrices within a node, accomplishing two goals simultaneously:

(a) Distributes memory load: It combines the memory of several GPUs to store a single massive layer that would be too large for any one device.

(b) Parallelizes computation: Each GPU performs a smaller portion of the matrix multiplication using its local shard of the weight matrix. This allows the GPUs to work on the same forward or backward pass of a single layer at the same time, speeding up the computation.

You may refer to this article for a more detailed writeup of how TP works.

Tensor parallelism

Source: Stanford CS 336 Language Modeling from Scratch, Lecture 7: Parallelism 1, 58 min 12 sec

To enable TP in Megatron, you don't manually shard weights — you declare which linear layers should be column-parallel (output sharded across TP ranks) and which should be row-parallel (input sharded across TP ranks) in a layer spec. Megatron and TransformerEngine handle the sharding and communication internally.

The imports:

from megatron.core.extensions.transformer_engine import (
    TEColumnParallelLinear,  # shards output dim: each GPU holds W[:, out/tp]
    TERowParallelLinear,     # shards input dim: each GPU holds W[in/tp, :]
    TEDotProductAttention,   # handles attention with TP, CP, and packed sequences
    TENorm,                  # fused RMSNorm / LayerNorm
)
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules

The layer spec wires these into each transformer layer. For Gemma2:

layer_spec = ModuleSpec(
    module=Gemma2TransformerLayer,
    submodules=Gemma2TransformerLayerSubmodules(
        input_layernorm=TENorm,
        self_attention=ModuleSpec(
            module=SelfAttention,
            submodules=SelfAttentionSubmodules(
                linear_qkv=TEColumnParallelLinear,   # QKV projection: column-parallel
                core_attention=TEDotProductAttention,
                linear_proj=TERowParallelLinear,      # output projection: row-parallel
            ),
        ),
        mlp=ModuleSpec(
            module=MLP,
            submodules=MLPSubmodules(
                linear_fc1=TEColumnParallelLinear,    # up-projection: column-parallel
                linear_fc2=TERowParallelLinear,       # down-projection: row-parallel
            ),
        ),
        # ... layernorms and bias-dropout-add omitted
    ),
)

The pattern is: column-parallel for layers that fan out (QKV, MLP up-projection — output dimension gets sharded), row-parallel for layers that fan back in (attention output projection, MLP down-projection — input dimension gets sharded). A column-parallel layer followed by a row-parallel layer forms a pair: the column layer's sharded output feeds directly into the row layer's sharded input with no communication between them. The only all-reduce (or reduce-scatter with sequence parallelism) happens after the row-parallel layer, once per attention block and once per MLP block.
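The column/row pairing can be checked numerically without any GPU. The sketch below is a dependency-free simulation, not Megatron code: the helper names are illustrative, ReLU stands in for the real activation, and the "all-reduce" is just an element-wise sum over the per-rank partial outputs.

```python
def matmul(a, b):
    """Plain-Python matrix product for small nested-list matrices."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def relu(m):
    """Element-wise activation (a stand-in for the model's real nonlinearity)."""
    return [[max(0, v) for v in row] for row in m]

def split_columns(m, parts):
    """Column-parallel shard: each rank keeps a contiguous slice of output columns."""
    width = len(m[0]) // parts
    return [[row[r * width:(r + 1) * width] for row in m] for r in range(parts)]

def split_rows(m, parts):
    """Row-parallel shard: each rank keeps a contiguous slice of input rows."""
    height = len(m) // parts
    return [m[r * height:(r + 1) * height] for r in range(parts)]

def tp_mlp(x, w_up, w_down, tp):
    """Simulate a column-parallel up-projection feeding a row-parallel down-projection."""
    partials = []
    for up_shard, down_shard in zip(split_columns(w_up, tp), split_rows(w_down, tp)):
        hidden = relu(matmul(x, up_shard))            # local to the rank, no communication
        partials.append(matmul(hidden, down_shard))   # partial sum of the full output
    # The single all-reduce: element-wise sum of the per-rank partial outputs.
    return [[sum(vals) for vals in zip(*rows)] for rows in zip(*partials)]

x = [[1, -2], [3, 4]]
w_up = [[1, 0, -1, 2], [2, 1, 0, -1]]        # 2 x 4 up-projection
w_down = [[1, 2], [0, 1], [-1, 0], [3, 1]]   # 4 x 2 down-projection

reference = matmul(relu(matmul(x, w_up)), w_down)   # unsharded computation
sharded = tp_mlp(x, w_up, w_down, tp=2)             # two simulated TP ranks
```

Because the activation is element-wise, each rank can apply it to its own column slice, which is why no communication is needed between the two layers.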

2. Use Pipeline and Data Parallelism for Inter-Node Scaling: To scale beyond a single node, you typically combine tensor parallelism with pipeline parallelism, where each stage of the pipeline is a tensor-parallel group of GPUs. This minimizes the slow communication across nodes to only what is necessary between pipeline stages.

Pipeline Parallelism (PP) splits layers across nodes. For example, with PP=2, node 1 holds layers 0–20 and node 2 holds layers 21–41. Cross-node communication is a single point-to-point (P2P) activation transfer per microbatch per forward pass — not an all-reduce. For training runs, there is also a second P2P transfer per microbatch passing the gradients backwards from node 2 to node 1.
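As a rough illustration of the layer split and the transfer count described above, here is a small sketch. It assumes contiguous, equally sized stages and a layer count divisible by PP; the function names are hypothetical and are not Megatron APIs.

```python
def partition_layers(num_layers: int, pp: int) -> list[range]:
    """Split layer indices into `pp` contiguous, equally sized pipeline stages."""
    if num_layers % pp != 0:
        raise ValueError("this sketch assumes num_layers is divisible by pp")
    per_stage = num_layers // pp
    return [range(s * per_stage, (s + 1) * per_stage) for s in range(pp)]

def p2p_transfers_per_microbatch(pp: int, training: bool = True) -> int:
    """One activation send per stage boundary, plus one gradient send when training."""
    boundaries = pp - 1
    return boundaries * (2 if training else 1)

stages = partition_layers(42, 2)               # the 42-layer decoder from the example
transfers = p2p_transfers_per_microbatch(2)    # forward activation + backward gradient
```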

3. Data Parallelism (DP) and Context Parallelism (CP) handle the remaining dimensions. DP is standard (replicate the model and shard optimizer states across ranks). CP is the interesting one for us; more on this below.

Here is a sample config for a production launch:

# Production: 4 nodes × 8 GPUs = 32 GPUs
TP=4    # split weights within node (NVLink)
PP=2    # split layers across 2 nodes (IB, point-to-point)
CP=1    # context parallelism (increase for 8k+ sequences)
# DP = 32 / (4×2×1) = 4 data-parallel ranks

INFRA_ARGS=(
    --tensor-model-parallel-size ${TP}
    --pipeline-model-parallel-size ${PP}
    --context-parallel-size ${CP}
    --use-distributed-optimizer
    --sequence-parallel
    --recompute-granularity full
    --recompute-method uniform
    --recompute-num-layers 1
)

For single-node development we drop to TP=1, PP=1 for fast iteration — the same code, just different flags.
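The data-parallel size is whatever remains after the other dimensions are fixed, as in the comment in the config above. A minimal sketch of that arithmetic, with a hypothetical helper name and a divisibility check added for illustration:

```python
def data_parallel_size(world_size: int, tp: int, pp: int, cp: int = 1) -> int:
    """DP ranks left over once TP, PP, and CP have claimed their GPUs."""
    model_parallel = tp * pp * cp
    if world_size % model_parallel != 0:
        raise ValueError(
            f"world_size={world_size} is not divisible by TP*PP*CP={model_parallel}"
        )
    return world_size // model_parallel

prod_dp = data_parallel_size(32, tp=4, pp=2, cp=1)   # production launch above
dev_dp = data_parallel_size(8, tp=1, pp=1)           # single-node development
```

Raising CP for longer sequences shrinks DP by the same factor, so the global batch size per step has to be maintained through more gradient accumulation or more nodes.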

Context parallelism and audio

CP is interesting because it unblocks long-context audio training.

Standard CP splits the sequence across GPUs and exchanges KV pairs in a ring. With CP=4, an 8k sequence means each GPU processes 2k tokens and stores activations for only its segment — solving the memory problem that FSDP couldn't.
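To make the ring exchange concrete, the sketch below simulates only the rotation pattern: which sequence chunk's K/V each rank holds at each step. It is a toy model with hypothetical function names; real implementations also handle causal masking and load balancing, which are omitted here.

```python
def tokens_per_rank(seq_len: int, cp: int) -> int:
    """Tokens (and therefore activations) each CP rank keeps locally."""
    if seq_len % cp != 0:
        raise ValueError("this sketch assumes seq_len is divisible by cp")
    return seq_len // cp

def ring_schedule(cp: int) -> list[list[int]]:
    """schedule[step][rank] = index of the chunk whose K/V that rank holds at that step."""
    return [[(rank - step) % cp for rank in range(cp)] for step in range(cp)]

local_tokens = tokens_per_rank(8192, 4)   # an 8k sequence with CP=4
schedule = ring_schedule(4)
# Step 0: every rank attends over its own chunk.
# Each later step: every rank receives the K/V its neighbor held on the previous step.
```

After CP steps every rank has attended over every chunk, while only ever storing activations for its own segment.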

But there's a subtlety with multimodal sequences. In our model, a typical input looks like [system tokens] [audio placeholder tokens] [instruction tokens] [response tokens]. Because CP splits the sequence at the token level before embedding, the multimodal integration logic must be CP-aware. The placeholder IDs are replaced with Whisper encoder outputs via masked_scatter, but this happens locally on each GPU. The system has to slice the Whisper encoder outputs and route the correct chunks to whichever specific CP ranks hold the audio placeholders. Each rank performs the replacement for its local chunk, and from that point on, they enter the transformer decoder as regular embeddings, allowing standard CP ring-attention to take over. Below is an example using a simple tensor to show what we mean.

Imagine an input sequence of length 8 (\(L=8\)), a hidden dimension of 2 (\(D=2\)), and we are using CP=2 (2 GPUs). Let token 99 be the audio placeholder.

  • Input Tokens: [10, 99, 99, 99, 99, 20, 30, 40]

  • Whisper Encoder Output (\(4 \times 2\) tensor): [[A1, A1], [A2, A2], [A3, A3], [A4, A4]]

1. The sequence is split at the token level (Dataloader):

  • GPU 0 gets Tokens [10, 99, 99, 99]

  • GPU 1 gets Tokens [99, 20, 30, 40]

2. Local Embeddings: Each GPU passes its 4 tokens through the standard text embedding layer, resulting in local \(4 \times 2\) tensors.

3. Localized Replacement (The Multimodal Subtlety): To replace the placeholders, CP must know where the audio tokens landed. The Whisper encoder outputs must be sliced so each GPU gets the exact audio features it needs for its local chunk.

  • GPU 0 searches its chunk and sees three placeholders. It receives the first 3 Whisper outputs and applies masked_scatter locally:

    • GPU 0 Tensor becomes: [[Emb_10], [A1, A1], [A2, A2], [A3, A3]]
  • GPU 1 searches its chunk and sees one placeholder. It receives the last Whisper output and applies masked_scatter locally:

    • GPU 1 Tensor becomes: [[A4, A4], [Emb_20], [Emb_30], [Emb_40]]
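The walkthrough above can be reproduced in a few lines of plain Python. This is a sketch of the bookkeeping only, with strings standing in for embedding vectors. It assumes a simple contiguous split as in the example, and the function and variable names are illustrative rather than taken from our trainer.

```python
AUDIO_PLACEHOLDER = 99  # placeholder id from the worked example

def embed(token):
    """Stand-in for the text embedding layer."""
    return [f"Emb_{token}"]

def local_audio_fusion(tokens, audio_features, cp, rank):
    """Return the fused embeddings for one CP rank's contiguous chunk."""
    chunk_len = len(tokens) // cp
    start = rank * chunk_len
    chunk = tokens[start:start + chunk_len]
    # Placeholders that landed on earlier ranks tell us where this rank's audio slice begins.
    offset = sum(1 for t in tokens[:start] if t == AUDIO_PLACEHOLDER)
    count = sum(1 for t in chunk if t == AUDIO_PLACEHOLDER)
    local_audio = iter(audio_features[offset:offset + count])
    # Equivalent of a local masked_scatter: fill placeholder positions in order.
    return [next(local_audio) if t == AUDIO_PLACEHOLDER else embed(t) for t in chunk]

tokens = [10, 99, 99, 99, 99, 20, 30, 40]
audio = [["A1", "A1"], ["A2", "A2"], ["A3", "A3"], ["A4", "A4"]]

gpu0 = local_audio_fusion(tokens, audio, cp=2, rank=0)
gpu1 = local_audio_fusion(tokens, audio, cp=2, rank=1)
```

The key quantity is the offset: each rank needs to know how many placeholders precede its chunk, which here is computed from the full token list before the split.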

What Megatron doesn't give you

Megatron-LM is primarily a text decoder framework. As of March 2026, while there is an example of a Vision LLM, it does not yet officially support an audio encoder, nor does Energon support audio data. Specifically, we needed to add:

  • A HuggingFace Whisper encoder that produces frame-level features, plus an adapter that compresses and projects them into the decoder's embedding space

  • A data pipeline that reads 700+ multimodal datasets (audio bytes + text) from Mosaic MDS format without converting to Energon/WebDataset

  • Loss masking that correctly excludes audio prompt tokens from the training objective
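To illustrate the loss-masking item, here is a minimal sketch. It assumes the response start index is known per sample and reuses the placeholder id 99 from the CP example; the padding id, helper names, and per-token loss values are all hypothetical.

```python
AUDIO_PLACEHOLDER = 99  # hypothetical placeholder id, reused from the CP example
PAD_ID = 0              # hypothetical padding id

def build_loss_mask(tokens, response_start):
    """1.0 where a token contributes to the loss, 0.0 elsewhere."""
    return [
        1.0 if i >= response_start and t not in (AUDIO_PLACEHOLDER, PAD_ID) else 0.0
        for i, t in enumerate(tokens)
    ]

def masked_mean(per_token_loss, mask):
    """Average the loss over unmasked positions only."""
    total = sum(loss * m for loss, m in zip(per_token_loss, mask))
    return total / max(sum(mask), 1.0)

tokens = [10, 99, 99, 20, 30, 40, 0]          # prompt with audio, response, padding
mask = build_loss_mask(tokens, response_start=4)
loss = masked_mean([5.0, 9.0, 9.0, 4.0, 2.0, 1.0, 7.0], mask)
```

Only the two response tokens contribute here, so the large losses at the audio positions never reach the gradient.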

Megatron has Energon for multimodal data loading, but our 30+ TB sat in MDS format. In addition, Energon as of March 2026 does not yet natively support audio. Even if we could customize Energon to accept audio files, converting all our data from MDS to Energon format would involve significant work.

Instead we wrote an MDS adapter that slots into Megatron's extension points (model_provider, forward_step, train_valid_test_dataloaders_provider).

Megatron-core also does not provide a drop-in solution for LoRA. LoRA is readily available in Megatron Bridge, but thus far we have not managed to get Megatron-core and Megatron Bridge working together in the same Python environment. Thankfully, it is possible to build your own LoRA with just Megatron-core itself.

What we built (to be covered in later articles)

  1. A data shim reading MDS datasets in-place — 712 train + 85 val datasets, zero format conversion.

  2. Whisper encoder integration with learned layer weighting, MLP adapter, and masked_scatter fusion into the token stream.

  3. Sequence packing via cu_seqlens for variable-length audio, eliminating ~40% padding waste.

  4. TP-aware LoRA with per-module learning rates (encoder 0.1×, adapter 0.5×, decoder 1.0×).

None of these required modifying Megatron's core. We monkey-patched TransformerEngine where needed (Gemma2's attention soft-capping) and used Megatron's standard extension points for everything else.
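For readers unfamiliar with the cu_seqlens convention mentioned in item 3, the sketch below shows how the boundaries are built: a cumulative sum of sample lengths with a leading zero. The sample lengths are made up for illustration, and the padding estimate is a simple comparison against padding every sample to the longest one.

```python
from itertools import accumulate

def build_cu_seqlens(lengths):
    """Cumulative boundaries: sample i occupies [cu[i], cu[i + 1]) in the packed buffer."""
    return [0, *accumulate(lengths)]

def padding_waste(lengths):
    """Fraction of a padded batch that would be padding if every sample were padded to the max."""
    return 1 - sum(lengths) / (max(lengths) * len(lengths))

lengths = [300, 1040, 512, 196]        # illustrative token counts per sample
cu_seqlens = build_cu_seqlens(lengths)
waste = padding_waste(lengths)
```

The attention kernel uses these boundaries to keep tokens from one sample from attending to tokens of another inside the same packed sequence.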

Was it worth it?

Three weeks of engineering. Most of the time went to things that had nothing to do with parallelism: shared memory exhaustion from opening 700 MDS datasets at once (fix: LRU pool of 8 open datasets), and a weight convention mismatch between HuggingFace and TransformerEngine that silently wrecked the model.
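The LRU pool idea is simple enough to sketch with the standard library. This is not our production code: the class name, the opener callback, and the fake dataset used for the demo are all hypothetical, and it assumes an open dataset exposes a close() method.

```python
from collections import OrderedDict

class LRUDatasetPool:
    """Keep at most `capacity` datasets open, closing the least recently used one."""

    def __init__(self, opener, capacity=8):
        self._opener = opener        # callable: name -> open dataset handle
        self._capacity = capacity
        self._open = OrderedDict()   # name -> handle, least recently used first

    def get(self, name):
        if name in self._open:
            self._open.move_to_end(name)   # mark as most recently used
            return self._open[name]
        if len(self._open) >= self._capacity:
            _, evicted = self._open.popitem(last=False)
            close = getattr(evicted, "close", None)
            if callable(close):
                close()                    # release the evicted dataset's resources
        handle = self._opener(name)
        self._open[name] = handle
        return handle

    def __len__(self):
        return len(self._open)

class FakeDataset:
    """Demo stand-in for a real dataset handle."""
    def __init__(self, name):
        self.name = name
        self.closed = False
    def close(self):
        self.closed = True

pool = LRUDatasetPool(FakeDataset, capacity=8)
handles = [pool.get(f"ds{i}") for i in range(10)]   # touching 10 datasets keeps only 8 open
```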

That last one is worth dwelling on. Gemma2's RMSNorm uses an offset convention: output = (1 + weight) * normalize(x), where weights initialize to zero. TransformerEngine's RMSNorm uses the standard convention: output = weight * normalize(x), with weights initializing to one. When we loaded the HF checkpoint into Megatron, the converter copied the weights verbatim, so every norm layer in the model was applying the wrong scale. The loss at iteration 1 was ~36, worse than the ~12.5 you'd expect from random predictions, because the softcapped logits were confidently wrong rather than uniformly distributed. It took a day of debugging to trace this to the norms. The fix was one line: param.data.add_(1.0) for every norm weight below 0.5.
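The mismatch is easy to reproduce on a toy vector. The sketch below implements both conventions in plain Python, with made-up near-zero checkpoint weights, and also checks the random-prediction baseline under the assumption of a roughly 256k-token vocabulary.

```python
import math

def normalize(x, eps=1e-6):
    """Divide by the root mean square of the vector."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]

def offset_rmsnorm(x, weight):
    """Offset convention: (1 + weight) * normalize(x)."""
    return [(1.0 + w) * v for w, v in zip(weight, normalize(x))]

def standard_rmsnorm(x, weight):
    """Standard convention: weight * normalize(x)."""
    return [w * v for w, v in zip(weight, normalize(x))]

x = [1.0, -2.0, 3.0, -4.0]
hf_weight = [0.0, 0.1, -0.05, 0.2]   # illustrative near-zero checkpoint values

expected = offset_rmsnorm(x, hf_weight)                       # what the checkpoint means
broken = standard_rmsnorm(x, hf_weight)                       # verbatim copy: activations collapse
fixed = standard_rmsnorm(x, [w + 1.0 for w in hf_weight])     # after adding 1.0 to the weights

uniform_loss = math.log(256_000)   # cross-entropy of a uniform guess, assuming a ~256k vocab
```

With weights near zero, the verbatim copy scales every activation down to almost nothing, which matches the symptom of a model that looks untrained from the first step.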

The point: framework migrations can break in subtle numerical ways that look like "the model just isn't learning." You need to validate loss curves against your baseline from the very first step.

We did that validation. After fixing the norm-weight issue and the attention soft-capping (covered in article 3), Megatron's loss was ~2.2 by step 20 and continued declining normally through the first 1k steps of our dev run.

On performance: on 8×H100 (single node, TP=1, PP=1), Megatron's built-in throughput counter reports ~1100 TFLOP/s/GPU at steady state, with ~855ms per iteration — that works out to ~19k tokens/s/GPU at our GBS=64 and seq_len=2048.
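The tokens-per-second figure follows directly from the batch size, sequence length, and iteration time quoted above. A quick check, assuming every sequence in the batch is filled to seq_len (the helper name is illustrative):

```python
def tokens_per_second_per_gpu(global_batch_size, seq_len, iteration_seconds, num_gpus):
    """Tokens processed per second on each GPU for one training iteration."""
    tokens_per_iteration = global_batch_size * seq_len
    return tokens_per_iteration / iteration_seconds / num_gpus

rate = tokens_per_second_per_gpu(
    global_batch_size=64, seq_len=2048, iteration_seconds=0.855, num_gpus=8
)
# rate lands a little above 19,000 tokens/s/GPU, matching the ~19k quoted above
```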

A note on MFU: the reported TFLOP/s exceeds H100's 989 TFLOPS BF16 dense peak because Megatron counts activation recomputation FLOPs (we run --recompute-granularity full). Backing out the recompute overhead and adjusting for the mixed architecture — a Whisper encoder, LoRA on the decoder, audio preprocessing not counted — we estimate effective MFU in the range of 35–40%.

The real payoff is that we can now push to 8k contexts by setting CP=2, train on 64+ GPUs by adding nodes, and pick up NVIDIA's latest optimizations (dynamic context parallelism, FP8 training) without re-architecting. Three weeks of upfront cost bought us a training infrastructure that scales with the hardware.

Next up: how we avoided converting 30 TB of training data by writing an MDS-to-Megatron adapter.