From Loss=36 to Convergence: Integrating Whisper+Gemma2 into Megatron's TransformerEngine
Four bugs we had to fix to get our AudioLLM training stably inside Megatron's TransformerEngine

When we started debugging our AudioLLM on the Megatron trainer, our loss started at 36. This did not make sense: even a uniform random prediction over a 262k-token vocabulary gives a loss of about 12.5, so a value far above that means the model is confidently wrong rather than merely untrained. [For the math nerds: uniform random predictions over vocabulary size V give cross-entropy = ln(V). ln(262,144) = ln(2^18) = 18 × ln(2) = 18 × 0.6931 = 12.48 ≈ 12.5] This suggested the forward pass was broken.
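The baseline figure is easy to reproduce. The following is a minimal sketch of that arithmetic; the function name is ours for illustration and does not come from the training code.

```python
import math

def uniform_cross_entropy(vocab_size: int) -> float:
    """Cross-entropy in nats of a uniform prediction over vocab_size tokens."""
    return math.log(vocab_size)

baseline = uniform_cross_entropy(262_144)
print(f"uniform baseline: {baseline:.2f}")  # about 12.48
```

Any starting loss well above this value points at a broken forward pass rather than at an untrained model.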
A day later we traced it to a one-line mismatch in how Gemma2 stores its LayerNorm weights. But that was only one of four things we had to fix.
The model at a glance
audio waveform
→ mel spectrogram (128 bins × T frames)
→ Whisper encoder (32 layers, hidden_dim=1280)
→ learned weighted sum over 32 layer outputs
→ LayerNorm(1280)
→ MLP adapter: concat 5 frames → Linear(6400→6400) → SiLU → Dropout(0.01) → Linear(6400→3584)
→ masked_scatter into token embeddings
→ Gemma2 9B decoder (42 layers, hidden_dim=3584)
The speech encoder and adapter live on the first pipeline stage (pre_process=True). The Gemma2 decoder spans all stages. Audio meets text at exactly one point: a masked_scatter in the embedding layer. Everything upstream of that scatter is plain PyTorch — no TransformerEngine, no tensor parallelism. The Whisper encoder loads from HF and runs in float32; the adapter is two linear layers, a SiLU, and a dropout.
What took the time was getting Gemma2 to produce correct outputs inside TransformerEngine.
The embedding fusion
The input sequence contains <SpeechHere> placeholder tokens where audio should go. After looking up text embeddings, we scatter the adapted audio features into those positions:
class LanguageModelEmbedding(MegatronModule):
    def forward(self, input_ids, position_ids=None, speech_embeds=None):
        embeddings = self.word_embeddings(input_ids)
        if speech_embeds is not None:
            speech_mask = (input_ids == self.config.speech_token_id)
            speech_mask_expanded = speech_mask.unsqueeze(-1).expand_as(embeddings)
            embeddings = embeddings.masked_scatter(
                speech_mask_expanded,
                speech_embeds.to(embeddings.device, embeddings.dtype),
            )
        # Gemma2 convention: scale by sqrt(hidden_size) AFTER fusion
        embeddings = embeddings * self.embedding_multiplier  # sqrt(3584) ≈ 59.87
        embeddings = embeddings.transpose(0, 1).contiguous()  # [b,s,h] → [s,b,h]
        return embeddings
The scaling order matters. Gemma2 multiplies all embeddings by sqrt(hidden_size), and this has to happen after the scatter — otherwise the speech embeddings bypass the scaling entirely while the text embeddings at those positions get scaled then overwritten. We caught it by comparing embedding norms between HF and Megatron on the first forward pass. With wrong ordering, loss diverged within 100 steps.
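To make the ordering effect concrete, here is a framework-free toy version of the fusion using plain Python lists. The helper names (`fuse`, `l2_norm`) and the tiny two-dimensional embeddings are assumptions made for illustration; only the multiplier sqrt(3584) comes from the model.

```python
import math

def l2_norm(vec):
    return math.sqrt(sum(x * x for x in vec))

def fuse(text_rows, speech_rows, is_speech, multiplier, scale_after_scatter):
    """Toy stand-in for the embedding fusion; each row is a list of floats."""
    rows = [list(r) for r in text_rows]
    if not scale_after_scatter:
        # Wrong order: scale the text embeddings first...
        rows = [[x * multiplier for x in r] for r in rows]
    # ...then overwrite the placeholder positions with speech rows, in order.
    speech_iter = iter(speech_rows)
    rows = [list(next(speech_iter)) if flag else r for r, flag in zip(rows, is_speech)]
    if scale_after_scatter:
        # Right order: scale everything once, after the scatter.
        rows = [[x * multiplier for x in r] for r in rows]
    return rows

multiplier = math.sqrt(3584)
text = [[0.01, 0.02], [0.03, 0.01], [0.02, 0.02]]
speech = [[0.02, 0.01]]          # one adapted audio frame
mask = [False, True, False]      # position 1 is the <SpeechHere> slot

right = fuse(text, speech, mask, multiplier, scale_after_scatter=True)
wrong = fuse(text, speech, mask, multiplier, scale_after_scatter=False)
ratio = l2_norm(right[1]) / l2_norm(wrong[1])
print(f"speech-position norm ratio: {ratio:.2f}")
```

The text positions come out identical in both orderings; only the speech position differs, by a factor equal to the multiplier. That is the kind of discrepancy a per-position norm comparison against the reference makes visible.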
Note the .to(embeddings.device, embeddings.dtype) cast. The adapter output may be in a different precision than the embedding table (e.g. fp32 adapter vs bf16 embeddings). masked_scatter expects its source to share the destination's dtype and device, so do the cast explicitly instead of relying on any implicit conversion.
The speech_embeds tensor arriving here is already flat — all valid audio tokens concatenated into [total_tokens, hidden], trimmed per-sample to exclude Whisper padding frames. That trimming happens in the model's forward pass (model.py), where we loop over each sample and slice by its audio_length before concatenating. Without it, padding frames from shorter clips leak into the text embeddings.
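The trimming step can be sketched without any framework. This is an illustrative version on nested lists, not the code in model.py; the function and variable names are hypothetical.

```python
def trim_and_flatten(batch_features, audio_lengths):
    """Keep the first audio_lengths[i] frames of each sample, then concatenate.

    batch_features: list of samples, each a list of frame vectors padded to
                    the same number of frames
    audio_lengths:  number of valid (non-padding) frames per sample
    """
    flat = []
    for frames, length in zip(batch_features, audio_lengths):
        flat.extend(frames[:length])
    return flat

padded = [
    [[1.0, 1.0], [2.0, 2.0], [0.0, 0.0]],   # 2 valid frames + 1 padding frame
    [[3.0, 3.0], [4.0, 4.0], [5.0, 5.0]],   # 3 valid frames
]
speech_embeds = trim_and_flatten(padded, audio_lengths=[2, 3])
print(len(speech_embeds))  # 5 rows, one per placeholder token
```

Because the scatter consumes source rows in order, the number of rows kept here has to equal the number of placeholder tokens in the batch.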
What we had to fix for Gemma2
Megatron's TransformerLayer assumes a fairly standard pre-norm architecture. Gemma2 deviates in four places.
Peri-layer-norm
Standard Megatron uses pre-norm: one LayerNorm before attention, one before MLP. Gemma2 adds post-norms — four LayerNorms per layer, 168 total. The post-norm sits between the sub-layer output and the residual connection:
# Gemma2 peri-LN: post-norm between sub-layer output and residual add
attn_output = self.post_self_attn_layernorm(attn_output) # post-norm here
hidden_states = residual + dropout(attn_output) # then residual add
mlp_output = self.post_mlp_layernorm(mlp_output) # same pattern
hidden_states = residual + dropout(mlp_output)
We got this right on the first attempt, but to verify that the placement mattered we also ran a variant with the post-norm after the residual add and compared loss curves. At step 2.5k the variant sat at 2.13 against the correct implementation's 1.23. With the norm before the add, each sub-layer's contribution stays bounded, so the residual stream doesn't blow up with depth. Move it after and you're back to Post-LN, which at 42 layers crushes gradients in the bottom layers while the top layers see disproportionately large ones.
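The difference between the two orderings shows up even on a toy vector. The sketch below uses an RMS normalisation without learned weights, which is a simplification we chose for illustration; the function names are hypothetical.

```python
import math

def rms(vec):
    return math.sqrt(sum(x * x for x in vec) / len(vec))

def rms_norm(vec, eps=1e-6):
    # Learned scale omitted on purpose; only the normalisation matters here.
    scale = math.sqrt(sum(x * x for x in vec) / len(vec) + eps)
    return [x / scale for x in vec]

def peri_ln_step(residual, sublayer_out):
    # Gemma2 ordering: normalise the sub-layer output, then add the residual.
    return [r + o for r, o in zip(residual, rms_norm(sublayer_out))]

def post_ln_step(residual, sublayer_out):
    # Variant ordering: add first, then normalise the whole stream.
    return rms_norm([r + o for r, o in zip(residual, sublayer_out)])

residual = [0.5, -1.0, 2.0, 0.25]
loud_sublayer = [100.0, -300.0, 50.0, 200.0]   # deliberately large output

peri = peri_ln_step(residual, loud_sublayer)
post = post_ln_step(residual, loud_sublayer)
contribution = [p - r for p, r in zip(peri, residual)]
print(f"peri-LN contribution RMS: {rms(contribution):.3f}")
```

In the first ordering the sub-layer adds a unit-RMS update and the residual passes through untouched. In the second, the residual itself is rescaled at every layer, so the identity path through the network is lost.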
Alternating sliding window attention
Even-indexed layers use a 4096-token sliding window; odd-indexed layers use full causal attention. TransformerEngine's TEDotProductAttention takes window_size at init time, so we override it on the even layers after construction:
for layer_idx, layer in enumerate(lm.decoder.layers):
    if layer_idx % 2 == 0:
        core_attn = layer.self_attention.core_attention
        core_attn.window_size = (4096, 0)
At our training context of 2048, this is a no-op. It matters at 8k+ with CP=2, where each rank sees 4k tokens — exactly the window size on even layers, while odd layers exchange KV across ranks for global attention.
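The no-op claim can be checked by counting how many keys each query position can see. This is a sketch under an assumed (left, right) window convention in which a query at position i attends to keys in [i - left, i]; the helper names are ours.

```python
from typing import Optional

def visible_keys(query_pos: int, window_left: Optional[int]) -> int:
    """Number of key positions a causal query can attend to.

    window_left=None means full causal attention. Otherwise the query at
    position i sees keys in [i - window_left, i] (assumed convention).
    """
    if window_left is None:
        return query_pos + 1
    return min(query_pos, window_left) + 1

def window_is_noop(seq_len: int, window_left: int) -> bool:
    """True if the window never hides a key that causal attention would show."""
    return all(
        visible_keys(q, window_left) == visible_keys(q, None)
        for q in range(seq_len)
    )

print(window_is_noop(2048, 4096))   # True: every query already fits in the window
print(window_is_noop(8192, 4096))   # False: late queries lose their oldest keys
tokens_per_rank = 8192 // 2         # CP=2 at 8k context gives 4096 per rank
```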
Logit soft-capping
Gemma2 bounds attention scores to [-50, 50] via 50 * tanh(scores / 50) before softmax. TE 2.7.0 calls flash_attn internally but never passes the softcap kwarg. We monkey-patch two flash_attn entry points in TE's backends module plus flash_attn_with_kvcache in Megatron's attention module:
_GEMMA2_ATTN_SOFTCAP = 50.0
def _flash_attn_func_with_softcap(*args, **kwargs):
    if kwargs.get('softcap', 0.0) == 0.0:
        kwargs['softcap'] = _GEMMA2_ATTN_SOFTCAP
    return _orig_flash_attn_func(*args, **kwargs)
_te_fa_backends.flash_attn_func = _flash_attn_func_with_softcap
# Same wrapper for flash_attn_varlen_func and flash_attn_with_kvcache
We print a confirmation at startup for each patch group (TE backends and Megatron kvcache), so a failed patch shows up immediately rather than as mysterious loss divergence. If TE renames its internal module paths, the try/except falls through, the warning prints, and our launch script refuses to proceed without the confirmation string.
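The patch-and-confirm pattern can be exercised in isolation against a stand-in module. Everything here is a sketch: `patch_softcap`, the confirmation string, and the fake backend objects are hypothetical and only mimic the shape of the real patch.

```python
import types

GEMMA2_ATTN_SOFTCAP = 50.0
CONFIRMATION = "[softcap] patched flash_attn_func"   # hypothetical string

def patch_softcap(backend_module) -> bool:
    """Wrap backend_module.flash_attn_func so it always receives a softcap.

    Prints a confirmation and returns True on success. Prints a warning and
    returns False if the attribute is missing, e.g. after a library rename.
    """
    try:
        original = backend_module.flash_attn_func
    except AttributeError:
        print("[softcap] WARNING: flash_attn_func not found, patch skipped")
        return False

    def with_softcap(*args, **kwargs):
        # Only fill in the cap when the caller has not set one.
        if kwargs.get("softcap", 0.0) == 0.0:
            kwargs["softcap"] = GEMMA2_ATTN_SOFTCAP
        return original(*args, **kwargs)

    backend_module.flash_attn_func = with_softcap
    print(CONFIRMATION)
    return True

# Stand-in backend that simply echoes the keyword arguments it received.
fake_backend = types.SimpleNamespace(flash_attn_func=lambda *a, **kw: kw)
renamed_backend = types.SimpleNamespace()   # simulates a renamed module path

patched = patch_softcap(fake_backend)
missed = patch_softcap(renamed_backend)
```

A launch script can then grep the log for the confirmation string and abort when it is absent, which turns a silent numerical regression into a startup failure.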
There's also a final logit soft-cap of 30.0 on output logits — a tanh in gpt_model.py after the output projection.
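Both caps use the same formula with different constants. A minimal scalar sketch follows; the function name is ours.

```python
import math

def soft_cap(value: float, cap: float) -> float:
    """Smoothly bound value to (-cap, cap) using cap * tanh(value / cap)."""
    return cap * math.tanh(value / cap)

ATTN_CAP = 50.0    # attention scores, applied before softmax
LOGIT_CAP = 30.0   # final output logits

small = soft_cap(1.0, ATTN_CAP)      # small scores pass through almost unchanged
large = soft_cap(500.0, ATTN_CAP)    # large scores saturate just below the cap
final = soft_cap(99.0, LOGIT_CAP)    # an oversized logit lands just below 30
```

The cap is close to the identity near zero and only compresses outliers, which is why a missing cap tends to show up late as divergence and not as an immediate failure.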
The RMSNorm offset convention (loss = 36)
Back to the opening problem of loss of 36. HuggingFace's Gemma2 RMSNorm and TransformerEngine's RMSNorm use different weight conventions:
| Implementation | Weight init | Forward computation |
|---|---|---|
| HF Gemma2RMSNorm | zeros | (1 + weight) * normalised |
| TE RMSNorm | ones | weight * normalised |
HF stores weights as offsets from 1. A weight of 0.2 means an actual scale of 1.2. Our checkpoint converter copied these directly into TransformerEngine, which interprets 0.2 as a literal scale of 0.2. The result: all 168 layernorms scaled activations down by 5–6x, hidden states were crushed, logits were pinned to the softcap ceiling, and cross-entropy sat at ~36.
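The two conventions can be compared side by side on a toy vector. This is a framework-free sketch with hypothetical function names; it mirrors the formulas in the table and is not the HF or TE implementation.

```python
import math

def rms_normalise(x, eps=1e-6):
    scale = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / scale for v in x]

def offset_rmsnorm(x, weight):
    # HF Gemma2 convention: the stored weight is an offset from 1.
    return [(1.0 + w) * v for w, v in zip(weight, rms_normalise(x))]

def literal_rmsnorm(x, weight):
    # Plain convention: the stored weight is the literal scale.
    return [w * v for w, v in zip(weight, rms_normalise(x))]

x = [0.5, -1.5, 2.0, 1.0]
stored = [0.2, 0.2, 0.2, 0.2]   # checkpoint values in offset format

expected = offset_rmsnorm(x, stored)
miscopied = literal_rmsnorm(x, stored)                      # weights copied as-is
converted = literal_rmsnorm(x, [w + 1.0 for w in stored])   # offset folded in

shrink = expected[0] / miscopied[0]
print(f"activations shrink by {shrink:.1f}x per norm")
```

With a stored value of 0.2, a single miscopied norm shrinks its output sixfold, and folding the +1 into the weights restores the expected result.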
The diagnosis trail: disable final logit softcap → loss jumps to ~99 (softcap was masking the severity). Logit stats: mean=22.6, std=7.0, max=30.0 (all hitting ceiling). First layer's input_layernorm weight: mean=0.198 — should be ~1.2. That pointed us to the HF source code.
The debugging hotfix that got us unblocked:
for name, param in language_model.named_parameters():
    if 'layernorm' in name and name.endswith('.weight'):
        if param.data.min().item() < 0.5:
            param.data.add_(1.0)
But the proper fix is a single config flag. TransformerEngine natively supports the zero-centered gamma convention:
config.layernorm_zero_centered_gamma = True # tells TE to compute (1 + weight) * norm(x)
This is what we use in production. It eliminates the heuristic < 0.5 guard entirely — TE handles the +1 internally during the forward pass, so checkpoint weights stay in HF's offset format without conversion. If you're porting a model that uses this convention (Gemma2, PaLM), reach for this flag first.
One more catch: Gemma2 uses head_dim=256 independently of hidden_size/num_heads, but Megatron computes 3584/16=224 by default. Fix: --kv-channels 256. The larger head_dim gives 128 RoPE frequency pairs instead of 112, which provides finer positional resolution — relevant if you're planning to extend context length beyond the 8192-token training window with RoPE scaling. We created a config file (transformer_config.py) to document this mismatch because the resulting shape error is unhelpful.
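The head_dim arithmetic is short enough to spell out. The variable names below are ours; the numbers are the ones quoted above.

```python
hidden_size = 3584
num_attention_heads = 16

default_head_dim = hidden_size // num_attention_heads   # what gets derived by default
gemma2_head_dim = 256                                   # what the model actually uses

def rope_frequency_pairs(head_dim: int) -> int:
    """RoPE rotates dimensions in pairs, so there are head_dim // 2 frequencies."""
    return head_dim // 2

# With the real head_dim the query projection is wider than hidden_size.
q_proj_width = num_attention_heads * gemma2_head_dim

print(default_head_dim, rope_frequency_pairs(default_head_dim))   # 224 112
print(gemma2_head_dim, rope_frequency_pairs(gemma2_head_dim))     # 256 128
```

The query projection width of 4096 no longer equals hidden_size, which is why the mismatch surfaces as a shape error when loading weights.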
Activation function: GeGLU, not SwiGLU
Megatron's --swiglu flag correctly enables gated linear units but sets the gate activation to SiLU (SwiGLU). Gemma2 uses GELU gating (GeGLU). Without an override, TE selects the wrong fused kernel and the MLP computes a different function:
# Gemma2 uses GeGLU (gelu_pytorch_tanh), not SwiGLU.
# --swiglu correctly sets gated_linear_unit=True but uses F.silu as gate.
# Override to F.gelu so TE selects GEGLU kernel.
config.activation_func = torch.nn.functional.gelu
This won't crash — the shapes are identical — but loss will plateau higher than the reference because every MLP layer applies the wrong nonlinearity. Easy to miss if you're only checking dimensions.
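A scalar comparison shows how the two gates diverge while keeping the same shape. This is an illustrative sketch with our own helper names; the GELU here is the tanh approximation.

```python
import math

def gelu_tanh(x: float) -> float:
    """Tanh approximation of GELU."""
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)
    return 0.5 * x * (1.0 + math.tanh(inner))

def silu(x: float) -> float:
    return x / (1.0 + math.exp(-x))

def gated_unit(gate: float, up: float, activation) -> float:
    """Scalar gated linear unit: activation(gate) * up."""
    return activation(gate) * up

gate, up = 1.0, 2.0
geglu_out = gated_unit(gate, up, gelu_tanh)
swiglu_out = gated_unit(gate, up, silu)
print(f"GeGLU {geglu_out:.4f} vs SwiGLU {swiglu_out:.4f}")
```

Both calls accept the same inputs and return a value of the same shape, so only a numerical comparison against the reference model exposes the wrong gate.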
What we validated
With all fixes in, iteration-1 loss landed at 3.7 — close enough to HF's 3.6. The real question was whether the curves would track:
Loss curves: Gap under 2% in loss between HF vs Megatron at step 5k
Weight comparison: QKV, layernorm (after +1 fixup), and embedding weights matched to float32 precision
Downstream: WER on our internal ASR benchmark matched HF within 1% relative (N=12k utterances, 3 test sets)
Risks, mitigations, and what we'd do differently
Monkey-patches. Three softcap patches across two libraries (TE attention backends and Megatron's kvcache), plus patches to _get_param_groups for per-module learning rates (article 5), GPTModel._preprocess for batched generation, train for optimizer reference stashing, evaluate_and_print_results for per-dataset validation (article 6), and save_checkpoint for DDP param buffer re-sync — eight runtime patches total against Megatron internals. Zero source-level modifications to Megatron core, but the runtime patches carry equivalent maintenance burden.
If the softcap patch breaks, loss diverges within ~500 steps (we ablated this). We run a 200-step smoke test on every TE version bump, checking attention score max < 55.0. Detection latency: under 10 minutes.
What we'd do differently. The weight-comparison script. The RMSNorm bug consumed a full day of chasing attention masks and checkpoint loading order before someone compared weight values directly. The script took 20 minutes to write after the fact — iterate over both models' named_parameters(), flag any mean that differs by more than 0.1. It would have caught the +1 offset in seconds.
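A comparison of that kind might look like the sketch below. It is not the script we wrote: it works on plain dicts of floats and leaves out the framework, and it assumes parameter names have already been mapped between the two models and that the reference side holds effective scales. In practice the dicts would be built from each model's named_parameters().

```python
def mean(values):
    values = list(values)
    return sum(values) / len(values)

def compare_weight_means(reference, candidate, tolerance=0.1):
    """Return (name, ref_mean, cand_mean) for parameters whose means disagree.

    reference / candidate: dicts mapping parameter name -> flat list of floats.
    Names missing from candidate are reported with a candidate mean of None.
    """
    flagged = []
    for name, ref_values in reference.items():
        ref_mean = mean(ref_values)
        if name not in candidate:
            flagged.append((name, ref_mean, None))
            continue
        cand_mean = mean(candidate[name])
        if abs(ref_mean - cand_mean) > tolerance:
            flagged.append((name, ref_mean, cand_mean))
    return flagged

# Toy data: the reference holds effective scales, the candidate raw offsets.
reference_weights = {
    "layers.0.input_layernorm.weight": [1.20, 1.19, 1.21],
    "layers.0.mlp.weight": [0.01, -0.02, 0.03],
}
candidate_weights = {
    "layers.0.input_layernorm.weight": [0.20, 0.19, 0.21],
    "layers.0.mlp.weight": [0.01, -0.02, 0.03],
}

mismatches = compare_weight_means(reference_weights, candidate_weights)
```

On this toy data only the layernorm weight is flagged, with a gap of 1.0 between the two means.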
If you're porting a model to Megatron, check these things first:
Weight conventions: run the comparison script before the first forward pass
LayerNorm variant: verify pre/post/peri placement matches the reference, not just the count
Attention modifications: check if the model uses softcapping, sliding window, or non-standard head_dim; any of these may require TE monkey-patches
Activation function: verify SwiGLU vs GeGLU vs plain GeLU; Megatron's --swiglu defaults to SiLU gating
Embedding scaling: check whether scaling happens before or after modality fusion
Position encoding: verify head_dim, RoPE base, and whether kv_channels needs an explicit override
A note on context scaling: at 2048 tokens, the ~300 adapted audio tokens from a 30-second clip occupy about 15% of the sequence. At 8192, that drops to ~4%. Multi-turn audio conversations will need either longer clips, multiple audio segments, or higher-fidelity adapters (less compression than 5x). We haven't hit this wall yet, but it's the natural next question after sequence packing.
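The percentages above follow from simple arithmetic. The sketch assumes the Whisper encoder emits 1500 frames for a 30-second clip, which is our assumption here; the 5-frame concatenation comes from the adapter described earlier.

```python
WHISPER_FRAMES_PER_30S = 1500   # assumption: encoder output length for a 30 s clip
FRAMES_PER_TOKEN = 5            # the adapter concatenates 5 frames per token

audio_tokens = WHISPER_FRAMES_PER_30S // FRAMES_PER_TOKEN

def audio_share(context_len: int, tokens: int = audio_tokens) -> float:
    """Fraction of the context window taken by one clip's audio tokens."""
    return tokens / context_len

for context_len in (2048, 8192):
    print(f"{context_len}: {audio_share(context_len):.1%}")
```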
The integration is ~2,000 lines across 8 files (excluding LoRA, RoPE re-export, and __init__ covered in later articles), all using Megatron's public spec API. Three months stable.





