The MDS Shim — Zero-Conversion Data Loading for 800+ Datasets

How we trained on 800+ MDS datasets in Megatron without converting a single file


We have about 800 datasets in Mosaic MDS format, with tens of millions of multimodal samples — each one an audio clip, an instruction, and a target response — spread across thousands of compressed shards on a shared filesystem.

Megatron expects Energon format (WebDataset plus metadata). Converting 800+ MDS datasets to Energon would mean decoding every sample, re-encoding into WDS tar archives, generating .nv-meta configs, and temporarily doubling our ~30 TB storage footprint while both copies coexisted. On our shared filesystem, that is a lot of I/O that we wanted to avoid.

Instead of converting all our data from MDS to Energon format, we wrote one new file (mds_chatml_dataset.py, 815 lines), which took about two days end-to-end including debugging. It reads MDS shards directly, converts samples on the fly, and presents them to Megatron's training loop as if they came from Energon. With the MDS shim, we avoided format conversion, intermediate copies, and changes to the training loop. At launch, all we needed to do was swap --train-data-path from a directory to a YAML file.

This article was written with the assistance of Claude Opus 4.6 and Gemini 3.1 Pro. The banner is generated with Nano Banana.

What MDS looks like on disk

Each MDS dataset is a directory of numbered shard files (.mds or .mds.zstd) plus an index.json that describes the schema and shard layout. A shard is a flat binary blob: an offset table at the front, then packed sample data. Zstd-compressed shards decompress to the same format. Mosaic's streaming library handles all of this through its StreamingDataset class — in theory.
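
To make the role of index.json concrete, here is a minimal sketch of reading per-shard sample counts and turning them into cumulative offsets, so that a dataset-level sample index can be mapped to a shard. It assumes index.json holds a top-level "shards" list whose entries carry a "samples" count. The function names are hypothetical and are not taken from the adapter.

```python
import bisect
import json
import os


def load_shard_sample_counts(dataset_dir):
    """Return the per-shard sample counts listed in a dataset's index.json."""
    with open(os.path.join(dataset_dir, "index.json")) as f:
        index = json.load(f)
    # Assumed layout: {"shards": [{"samples": <int>, ...}, ...]}
    return [shard["samples"] for shard in index["shards"]]


def cumulative_offsets(counts):
    """offsets[i] is the dataset-level index of the first sample in shard i."""
    offsets = [0]
    for n in counts:
        offsets.append(offsets[-1] + n)
    return offsets


def resolve(offsets, sample_idx):
    """Map a dataset-level sample index to (shard_idx, index_within_shard)."""
    if not 0 <= sample_idx < offsets[-1]:
        raise IndexError(sample_idx)
    # bisect_right skips past empty shards that share the same start offset.
    shard_idx = bisect.bisect_right(offsets, sample_idx) - 1
    return shard_idx, sample_idx - offsets[shard_idx]
```

With shard counts of 3, 2, and 4, sample 3 resolves to the first sample of the second shard, and sample 8 to the last sample of the third.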

The five columns we care about per sample:

Column | Type | What it is
--- | --- | ---
context_audio | bytes | OGG/Vorbis audio (the input)
instruction_text | str | Task instruction
answer_text | str | Target output
task | str | Task type: ASR, AQA, ST, AC, ...
language | str | Language code: en, zh, ta, ms, ...

(The schema has three more columns — answer_audio, instruction_audio, context_text — but they're empty in our data.)

The training mix is controlled by a ~5,000-line YAML config. Each entry has a name, a path, a task label, and an optional choose field capping how many samples are drawn per epoch; the dataset's sampling weight is derived from choose:

train:
  - name: <dataset_name>
    path: <dataset_path>
    task: ASR
    choose: 50000

This YAML was already how our Mosaic Composer pipeline controlled the data mix. We reused it as-is.

The adapter pattern

Megatron's training loop calls train_valid_test_dataloaders_provider() at startup, which must return train and validation dataloaders. Our adapter hooks in by checking whether --train-data-path ends in .yaml:

def train_valid_test_dataloaders_provider(train_val_test_num_samples):
    args = get_args()

    if _is_yaml_path(args.train_data_path[0]):
        # MDS path — build our custom dataloaders
        train_ds = build_mds_train_dataloader(yaml_path=..., ...)
        val_name_ds_pairs = build_mds_val_dataloaders(yaml_path=..., ...)
        return EnergonDataloader(train_ds), val_name_ds_pairs, None

    # Otherwise, fall back to original Energon/WDS pipeline
    ...

EnergonDataloader wraps any iterable in cyclic_iter and exposes a save_state() method, which we treat as a no-op. The training loop calls next() on an iterator that yields batch dictionaries; it does not know the data came from MDS.

An issue with treating save_state() as a no-op: we can't resume from a data checkpoint. If training crashes, the shuffle plan re-seeds and the model sees a different sample order on restart. This differs from the original Mosaic library's behaviour (which supports deterministic resumability), but with tens of millions of samples, the probability of repeating a significant fraction of samples in one epoch is negligible. We decided that the tradeoff was acceptable.
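
The wrapper's contract is small enough to sketch. The class below is a hypothetical stand-in, not Megatron's EnergonDataloader: it assumes only what the text describes, namely endless iteration over an iterable and a save_state() that does nothing.

```python
class CyclicLoader:
    """Hypothetical stand-in: endless iteration plus a no-op save_state()."""

    def __init__(self, iterable):
        self._iterable = iterable
        self._iterator = self._cycle()

    def _cycle(self):
        # Restart the underlying iterable whenever it is exhausted.
        while True:
            yielded = False
            for item in self._iterable:
                yielded = True
                yield item
            if not yielded:
                return  # avoid spinning forever on an empty iterable

    def __iter__(self):
        return self

    def __next__(self):
        return next(self._iterator)

    def save_state(self):
        # No-op: the sample order is not checkpointed, so a restart
        # does not resume from the same position.
        return None
```

Because save_state() returns nothing, a checkpoint taken through this wrapper carries no data-order information, which is exactly the tradeoff described above.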

The conversion from MDS sample to Megatron batch happens in three stages.

Stage 1: Raw MDS sample to ChatML

Each raw MDS sample is a dict. We convert it to the internal ChatMLSample format that the task encoder expects:

def _mds_to_chatml_sample(mds_sample, sample_idx):
    task = mds_sample['task']
    raw_instruction = mds_sample['instruction_text']
    lang_code = mds_sample['language']
    answer_text = mds_sample['answer_text']
    waveform, sr = torchaudio.load(io.BytesIO(mds_sample['context_audio']))
    if sr != 16000:
        waveform = torchaudio.functional.resample(waveform, sr, 16000)
        sr = 16000  # keep the reported rate in sync with the resampled waveform

    # ASR and AC tasks get randomised task-specific prompts;
    # everything else keeps the raw instruction_text
    instruction = _get_task_instruction(task, raw_instruction, lang_code)

    conversation = [
        {'role': 'user', 'content': f'<audio>{instruction}'},
        {'role': 'assistant', 'content': answer_text},
    ]

    return ChatMLSample(audios=[(waveform, sr)], conversation=...)

This is the only function that knows about MDS. Everything downstream — mel spectrogram extraction, tokenization, <SpeechHere> placeholder expansion, loss mask construction — goes through TaskEncoder, shared with the Energon pipeline.

Stage 2: The shuffle plan

The naive approach to loading 800+ datasets is to open all of them as StreamingDataset objects and sample from them in proportion to their weights. We tried this first, and it exhausted shared memory immediately.

On 8 GPUs, that's at least 800 x 8 = 6,400 StreamingDataset instances, each allocating shared memory for internal shard tracking. The process hit /dev/shm's 64 GB limit within seconds, then died with OSError: [Errno 28] No space left on device. We didn't get to iteration 1.

The second attempt was a pool: keep 8 datasets open at a time, rotate through them weighted by sampling probability. This avoided the memory blowup but introduced loss spikes. Every time the pool swapped a dataset, the model suddenly shifted from all-ASR batches to all-speech-translation batches. The loss would jump 15-20% and take ~50 steps to recover. The training instability was unacceptable because momentary distribution shifts could corrupt the optimizer state in ways that wouldn't average out cleanly.

The solution: bypass StreamingDataset entirely and read shards with direct file I/O, using a pre-generated shuffle plan instead of a live shuffle buffer.

  1. At init, scan all 800+ datasets' index.json files to build a global index: which dataset has which shards, how many samples per shard, cumulative offsets. Compute normalised sampling weights from each dataset's choose field (choose_i / sum(choose)), matching how the Mosaic Composer pipeline allocates data. Takes about 3 seconds.

  2. Generate a shuffle plan (default 1.5M entries, configurable via MDS_SHUFFLE_PLAN_SIZE). Each dataset is allocated a share of the plan proportional to its normalised weight. A _DatasetEpochSampler fills each dataset's allocation with without-replacement indices. If the allocation exceeds the dataset size (e.g. a 10k-sample Tamil dataset with choose=50000), the sampler tiles full shuffled passes, so every sample appears once in a pass before any sample is repeated. All datasets' allocations are concatenated, then globally shuffled with rng.permutation().

  3. Shard the plan across ranks and workers. Each (dp_rank, worker_id) pair gets a deterministic RNG seed and a unique non-overlapping stride-slice of the shuffled plan: plan[my_shard::total_shards]. No rank sees another rank's data.

  4. Iterate through the plan: for each (dataset_idx, sample_idx) entry, resolve to a (shard, offset) pair, load the shard through an LRU cache (decompressed shards, 4 GB budget, configurable via MDS_SHARD_CACHE_GB), convert to ChatML, encode, batch, yield. When the plan is exhausted, generate a fresh one for the next epoch.
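
Steps 1 to 3 can be sketched in a few functions. This is a simplified illustration rather than the adapter's code: it uses the standard library's random module in place of NumPy's rng.permutation(), rounds each dataset's quota (so the plan length can differ slightly from plan_size), and assumes a particular numbering of (dp_rank, worker_id) pairs. All function names are hypothetical.

```python
import random


def normalised_weights(choose):
    """choose_i / sum(choose) for each dataset."""
    total = sum(choose)
    return [c / total for c in choose]


def build_plan(sizes, choose, plan_size, seed):
    """Return a globally shuffled list of (dataset_idx, sample_idx) pairs."""
    rng = random.Random(seed)
    plan = []
    for ds_idx, weight in enumerate(normalised_weights(choose)):
        quota = round(weight * plan_size)  # this dataset's share of the plan
        if quota == 0 or sizes[ds_idx] == 0:
            continue
        order = []
        while len(order) < quota:
            # Tile full shuffled passes: no sample repeats within a pass.
            one_pass = list(range(sizes[ds_idx]))
            rng.shuffle(one_pass)
            order.extend(one_pass)
        plan.extend((ds_idx, s) for s in order[:quota])
    rng.shuffle(plan)  # global shuffle across all datasets
    return plan


def shard_for(plan, dp_rank, worker_id, num_workers, dp_world_size):
    """Stride-slice: each (dp_rank, worker_id) gets a disjoint part of the plan."""
    my_shard = dp_rank * num_workers + worker_id  # assumed numbering scheme
    total_shards = dp_world_size * num_workers
    return plan[my_shard::total_shards]
```

Because every rank builds the same plan from the same seed and then takes a different stride-slice, no coordination between ranks is needed to keep their samples disjoint.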

class _ShardReader:
    """Reads samples from an MDS shard by direct file I/O.
    Decompresses zstd on first access. Cached by _ShardLRUCache."""

    def get_sample(self, idx):
        self._ensure_loaded()
        begin = int(self._offset_table[idx])
        # The last sample ends at the end of the data if the table has no trailing offset.
        end = int(self._offset_table[idx + 1]) if idx + 1 < len(self._offset_table) \
            else len(self._raw_data)
        return self._mds_reader.decode_sample(self._raw_data[begin:end])

class _ShardLRUCache:
    """LRU cache of decompressed shards, bounded at 4 GB."""
    # Evicts the least recently used shard when total memory exceeds the budget.

No shared memory allocation, no background threads leaking file descriptors, and every micro-batch contains samples from many datasets mixed together. After switching to direct shard I/O, loss spikes vanished completely.
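
A byte-budgeted LRU cache like the one described can be sketched with an OrderedDict. This is an illustration under assumptions, not the adapter's _ShardLRUCache: the class name, the loader callback, and the choice to always keep the newest entry even when it alone exceeds the budget are all hypothetical.

```python
from collections import OrderedDict


class ByteBudgetLRU:
    """Hypothetical LRU cache keyed by shard path, bounded by total bytes."""

    def __init__(self, budget_bytes, loader):
        self._budget = budget_bytes
        self._loader = loader          # callable: key -> bytes (decompressed shard)
        self._entries = OrderedDict()  # key -> bytes, least recently used first
        self._used = 0

    def get(self, key):
        if key in self._entries:
            self._entries.move_to_end(key)  # mark as most recently used
            return self._entries[key]
        data = self._loader(key)
        self._entries[key] = data
        self._used += len(data)
        # Evict least recently used entries, but always keep the newest one.
        while self._used > self._budget and len(self._entries) > 1:
            _, evicted = self._entries.popitem(last=False)
            self._used -= len(evicted)
        return data
```

Bounding by bytes rather than by entry count matters here because decompressed shard sizes vary, and the budget is what protects host memory.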

Plan size tuning. The default plan size is 1.5M entries. Because the plan is just two integer arrays (dataset index + sample index), memory overhead is negligible compared to a buffer of decoded samples. The plan provides thorough cross-dataset mixing: after global permutation, consecutive samples in any rank's slice come from different datasets with high probability.

Sampling is without replacement. The _DatasetEpochSampler generates shuffled passes over each dataset's full sample set. When choose > total (upsampling), the sampler tiles complete shuffled passes — a dataset with 10k samples and choose=50000 produces 5 full passes where every sample appears exactly 5 times before the epoch ends. This matches the epoch semantics of Mosaic's Stream(choose=N). When the plan is exhausted and a new one is generated, each dataset gets a fresh set of shuffled passes.

Stage 3: Validation — cache and replay

Validation needs to evaluate each of the 85 validation datasets independently and log per-dataset metrics to wandb under dataset-specific keys.

The first eval reads all samples through the shard-reader path, encodes them, and caches batches in memory. Subsequent evals replay the cache — avoiding 85 shared-memory allocations per eval interval. While this yields bit-identical results, it does freeze the randomized ASR/AC prompts to whatever was chosen during the first evaluation pass.

class MdsSingleValDataset(IterableDataset):
    def __iter__(self):
        if self._cached_batches is None:
            self._cached_batches = self._load_and_cache()
        yield from self._cached_batches

With 80+ datasets at ~512 samples each (capped by --eval-samples-per-dataset), the total cache footprint is roughly 2-3 GB of encoded tensors — well within our 80 GB GPU memory budget.

Per-dataset evaluation required monkey-patching Megatron's evaluate_and_print_results to iterate over each dataset independently and log to wandb under dataset-specific keys. This replaced Megatron's default of evaluating a single merged validation set — which would've been useless. An average loss across 85 datasets spanning ASR, speech translation, audio QA, and spoken dialogue tells you nothing about which tasks are degrading.
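
The shape of the per-dataset loop can be sketched without Megatron or wandb. The function below is hypothetical: the key format and the loss_fn callback are assumptions for illustration, and it only builds the metrics dictionary that a logger would receive.

```python
def evaluate_per_dataset(val_loaders, loss_fn, key_prefix="val"):
    """Return {"<prefix>/<name>/loss": mean loss} for each named loader."""
    metrics = {}
    for name, loader in val_loaders:
        losses = [loss_fn(batch) for batch in loader]
        if losses:  # skip empty datasets rather than divide by zero
            metrics[f"{key_prefix}/{name}/loss"] = sum(losses) / len(losses)
    return metrics
```

Keeping one key per dataset means a regression in a single low-resource dataset shows up as its own curve instead of being diluted in a global mean.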

The three bugs that took the longest

Shared memory exhaustion (2 hours to diagnose, 30 minutes to fix). Opening all 800+ datasets as StreamingDataset objects burned through /dev/shm before loading a single sample. The misleading part: the error says No space left on device, which sent us on a detour checking filesystem quotas before someone ran df -h /dev/shm and saw it was full. Fix: stop using StreamingDataset for training entirely.

Read-only data directories (1 hour). Some dataset directories sat on a read-only NFS mount. StreamingDataset writes .tmp decompression files next to the source data, and on those mounts the writes failed with PermissionError. Fix: auto-detect read-only directories with os.access(path, os.W_OK) and redirect decompression to a writable cache. The cache path is configurable via the MDS_LOCAL_CACHE env var (defaults to /dev/shm/mds_cache), and yes, we were aware of the irony of using /dev/shm after bug #1 killed us on shared memory. The difference is that the decompression cache uses bounded, predictable storage (one shard at a time), not the unbounded allocation pattern of thousands of StreamingDataset instances.
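
A minimal sketch of the read-only check and redirect, assuming a per-dataset subdirectory named by a hash of the dataset path. The function name and the hashing scheme are hypothetical; only os.access(path, os.W_OK), the MDS_LOCAL_CACHE variable, and its default come from the description above.

```python
import hashlib
import os


def decompression_dir(dataset_dir, cache_root=None):
    """Return dataset_dir if writable, else a per-dataset directory under the cache."""
    if os.access(dataset_dir, os.W_OK):
        return dataset_dir
    if cache_root is None:
        cache_root = os.environ.get("MDS_LOCAL_CACHE", "/dev/shm/mds_cache")
    # Hash the absolute path so datasets with the same basename do not collide.
    digest = hashlib.sha1(os.path.abspath(dataset_dir).encode()).hexdigest()[:16]
    target = os.path.join(cache_root, digest)
    os.makedirs(target, exist_ok=True)
    return target
```

Note that os.access reports what the current user may do, so a process running as root can see a directory as writable even when the underlying mount is read-only; a trial write is the more robust check in that case.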

Loss spikes from pool-based rotation (half a day). This was subtle. The spikes were small enough to look like normal training variance for the first few hundred steps. It wasn't until we plotted the loss colored by which dataset-pool-slot was active that the pattern jumped out: every spike aligned exactly with a pool rotation event. Switching to the globally shuffled plan eliminated them.

Corrupted shards and malformed index.json files don't crash training — the shard reader catches exceptions and skips bad samples, logging a warning. We saw about 0.01% of samples fail to decode across the full dataset.
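
The skip-and-warn behaviour can be sketched as a small generator. The names here are hypothetical, and read_sample stands in for whatever resolves a plan entry to a decoded sample.

```python
import logging

logger = logging.getLogger("mds_shim_sketch")


def iter_valid_samples(plan, read_sample):
    """Yield decoded samples, logging and skipping any entry that fails."""
    for dataset_idx, sample_idx in plan:
        try:
            sample = read_sample(dataset_idx, sample_idx)
        except Exception as exc:  # broad on purpose: any decode error is skipped
            logger.warning("skipping sample (%d, %d): %s", dataset_idx, sample_idx, exc)
            continue
        yield sample
```

Counting the warnings is worth doing alongside this pattern, because a silent rise in the skip rate would otherwise change the data mix without anyone noticing.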

What we'd do differently

Skip the StreamingDataset path entirely and go straight to direct shard I/O. Mosaic's streaming library is designed for distributed training with its own dataloader — using StreamingDataset as a sample reader is fighting the abstraction. Direct shard reading was simpler, faster, and side-stepped every resource management issue.

Profile the data loading pipeline more carefully from day one. Currently, our switch to MDS has introduced a momentary 0.3-second GPU power drop on every single iteration. This is a classic symptom of GPU starvation: because heavy operations like OGG audio decoding (torchaudio.load), resampling, and mel spectrogram extraction are happening synchronously, the GPUs are forced to sit idle while the CPU prepares the next batch. We are still actively working to eliminate this bottleneck. If we were to build this again, we would design the pipeline from the start to aggressively decouple data preparation—likely by wrapping the dataset in a standard PyTorch DataLoader with a high num_workers count and prefetch_factor—to ensure the GPUs are never waiting on the CPU.
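
The decoupling idea can be illustrated without PyTorch. The sketch below is a stand-in technique, not the DataLoader setup mentioned above: it uses a background thread and a bounded queue to prepare items ahead of the consumer. Threads only help when the preparation releases the GIL (file I/O, native decoding); CPU-bound Python work would need worker processes, which is what DataLoader's num_workers provides.

```python
import queue
import threading


def prefetch(iterable, depth=4):
    """Yield items from iterable while a background thread prepares the next ones."""
    buf = queue.Queue(maxsize=depth)  # bounded, so the producer cannot run far ahead

    def producer():
        try:
            for item in iterable:
                buf.put(("item", item))
            buf.put(("done", None))
        except Exception as exc:
            buf.put(("error", exc))  # hand the failure to the consumer

    threading.Thread(target=producer, daemon=True).start()
    while True:
        kind, payload = buf.get()
        if kind == "done":
            return
        if kind == "error":
            raise payload
        yield payload
```

The queue depth plays the same role as a prefetch factor: it caps how many prepared batches sit in memory while still letting preparation overlap with the training step.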

The numbers

On 8xH100 (single node, TP=1, PP=1) with the MDS adapter:

  • Step time: slightly less than 1 second per iteration (GBS=64, MBS=4, seq_len=2048)

  • Throughput: ~19k padded tokens/s/GPU (GBS x seq_len / time / GPUs, including padding) — matching the Energon baseline within measurement noise

  • Loss at step 5k: Mosaic baseline 0.58, Megatron with MDS adapter 0.56 — within 5% relative. This is an early-training snapshot; we confirmed the gap stayed under 3% through 50k steps before committing to the full run
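
As a sanity check on these figures, the throughput formula and the relative loss gap can be worked through directly. The helper names are hypothetical; the inputs are the numbers quoted in the list.

```python
def padded_tokens_per_gpu_per_sec(gbs, seq_len, step_time_s, num_gpus):
    """GBS x seq_len / time / GPUs, counting padding tokens."""
    return gbs * seq_len / step_time_s / num_gpus


def implied_step_time(gbs, seq_len, tokens_per_gpu_per_sec, num_gpus):
    """Step time that a given per-GPU throughput implies."""
    return gbs * seq_len / (tokens_per_gpu_per_sec * num_gpus)


def relative_gap(baseline, other):
    """Absolute difference as a fraction of the baseline."""
    return abs(baseline - other) / baseline


# 64 x 2048 = 131,072 padded tokens per step across 8 GPUs.
at_one_second = padded_tokens_per_gpu_per_sec(64, 2048, 1.0, 8)  # 16,384 tokens/s/GPU
step_time = implied_step_time(64, 2048, 19_000, 8)               # about 0.86 s
gap = relative_gap(0.58, 0.56)                                   # about 3.4%
```

A step time of exactly 1 second would give 16,384 tokens/s/GPU, so the quoted ~19k corresponds to roughly 0.86 seconds per step, and the 0.58 versus 0.56 loss gap works out to about 3.4% relative.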

The adapter adds no measurable overhead to step time compared to Energon on the same data. The cost is all at startup (3 seconds to scan index files) and in the first validation pass (building the 85-dataset cache takes ~10 minutes).

If you're doing a similar migration

The pattern generalises to any source format: write an IterableDataset yielding batch dicts, wrap it in EnergonDataloader, branch in train_valid_test_dataloaders_provider(). The default Energon path stays intact — anyone on the team without MDS data keeps using Energon, unchanged. Both formats run through the same training script.

The total diff: one new file (815 lines) and about 330 lines of MDS-related code in pretrain_llm.py for the dataloader branching, per-dataset validation monkey-patch, and wandb logging (plus another ~400 lines for generation-based evaluation, which is arguably a separate feature). We added 4 lines to Megatron core for a --wandb-entity argument; everything else is runtime monkey-patches that don't touch upstream files. The model successfully trained on the 800+ datasets on the first real run after the shuffle plan fix.

Maintenance-wise, the adapter depends on Mosaic's streaming library only for zstd decompression and MDS sample decoding — both stable internal APIs. The Megatron interface point (train_valid_test_dataloaders_provider) has been stable across the versions we've used. We've been running this without touching the adapter code.