torchtitan is designed to work seamlessly with most HuggingFace datasets. It supports three training flavours — pre-training (plain text), instruction-tuning / SFT (chat), and multimodal (vision) — each with its own dataloader. Both text flavours support single-source and multi-source interleaved configurations.
```text
torchtitan/hf_datasets/text_datasets.py            # pre-training and SFT
torchtitan/hf_datasets/multimodal/mm_datasets.py   # vision
```
To add a new pre-training dataset, you need three components: a loader function, a sample processor, and a registry entry in `DATASETS`.
```python
from typing import Any

from datasets import load_dataset


# 1. Loader: returns the raw Hugging Face dataset (streamed here).
def load_wikipedia_dataset(dataset_path: str, **kwargs):
    """Load Wikipedia dataset with specific configuration."""
    return load_dataset(
        dataset_path,
        name="20220301.en",
        split="train",
        streaming=True,
        trust_remote_code=True,
    )


# 2. Sample processor: turns one raw record into the text to tokenize.
def process_wikipedia_text(sample: dict[str, Any]) -> str:
    """Process Wikipedia dataset sample text."""
    return f"{sample['title']}\n\n{sample['text']}"


# 3. Registry entry: makes the dataset selectable by name.
DATASETS = {
    # ... existing datasets ...
    "wikipedia": DatasetConfig(
        path="wikipedia",
        loader=load_wikipedia_dataset,
        sample_processor=process_wikipedia_text,
    ),
}
```

Then select the dataset by its registry name in the dataloader config:

```python
dataloader=HuggingFaceTextDataLoader.Config(
    dataset="wikipedia",
    infinite=True,
),
```

The `ChatDataLoader` handles single-turn `[user, assistant]` message pairs. It tokenizes samples using the model's chat template, masks prompt tokens in the labels so that loss is computed on the assistant response only, and packs multiple short samples into each sequence.
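The masking and packing steps can be pictured with a small pure-Python sketch. It is not torchtitan's implementation: it operates on already-tokenized ID lists instead of applying a chat template, marks ignored label positions with `-100` (assumed here because it is the default `ignore_index` of PyTorch's `CrossEntropyLoss`), and uses hypothetical helpers `mask_prompt`, `pad`, and `pack_samples` with a greedy packing policy chosen only for illustration. It also leaves out the one-position shift between inputs and labels used for next-token prediction.

```python
# Hypothetical names throughout. IGNORE_INDEX mirrors PyTorch's default
# cross-entropy ignore_index; PAD_ID is an arbitrary padding token id.
IGNORE_INDEX = -100
PAD_ID = 0


def mask_prompt(prompt_ids: list[int], response_ids: list[int]) -> tuple[list[int], list[int]]:
    """Concatenate prompt and response; keep labels only on response tokens."""
    tokens = prompt_ids + response_ids
    labels = [IGNORE_INDEX] * len(prompt_ids) + list(response_ids)
    return tokens, labels


def pad(tokens: list[int], labels: list[int], seq_len: int) -> tuple[list[int], list[int]]:
    """Fill unused positions; padded labels are ignored by the loss."""
    n = seq_len - len(tokens)
    return tokens + [PAD_ID] * n, labels + [IGNORE_INDEX] * n


def pack_samples(samples, seq_len: int) -> list[tuple[list[int], list[int]]]:
    """Greedily pack (tokens, labels) pairs into fixed-length sequences.

    A sample that does not fit in the current sequence starts a new one;
    samples longer than seq_len are truncated.
    """
    packed = []
    cur_tokens: list[int] = []
    cur_labels: list[int] = []
    for tokens, labels in samples:
        tokens, labels = tokens[:seq_len], labels[:seq_len]
        if len(cur_tokens) + len(tokens) > seq_len:
            packed.append(pad(cur_tokens, cur_labels, seq_len))
            cur_tokens, cur_labels = [], []
        cur_tokens += tokens
        cur_labels += labels
    if cur_tokens:
        packed.append(pad(cur_tokens, cur_labels, seq_len))
    return packed


# Toy token ids standing in for chat-templated prompts and responses.
samples = [
    mask_prompt([11, 12, 13], [21, 22]),
    mask_prompt([14, 15], [23]),
    mask_prompt([16, 17, 18, 19], [24, 25, 26]),
]
packed = pack_samples(samples, seq_len=8)
for tokens, labels in packed:
    print(tokens, labels)
```

The first two samples share one 8-token sequence and the third starts a new one. Padding positions also receive the ignore label, so only response tokens contribute to the loss; torchtitan may split, truncate, or pad differently.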
```python
from torchtitan.hf_datasets.text_datasets import ChatDataLoader


def process_gsm8k(sample: dict) -> list[dict]:
    # Map one GSM8K record to a single-turn [user, assistant] conversation.
    return [
        {"role": "user", "content": sample["question"]},
        {"role": "assistant", "content": sample["answer"]},
    ]


dataloader=ChatDataLoader.Config(
    dataset_path="openai/gsm8k",
    load_dataset_kwargs={"name": "main", "split": "train"},
    sample_processor=process_gsm8k,
    infinite=True,
),
```

Both text flavours support interleaving multiple sources with configurable sampling weights. At each step a source is drawn with probability proportional to its weight, and the drawn source returns one packed sample, which may consist of multiple data points from that source. The stopping strategy (`on_first_exhausted` or `all_exhausted`) decides when iteration stops, and that point defines the epoch boundary. Re-looping and shuffling are handled per source exactly as in the single-source case.

All sources must share the same `infinite` setting.
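The sampling loop described above can be sketched as follows. The `interleave` function is hypothetical and works on finite in-memory lists whose items stand in for packed samples. The strategy strings reuse the names from the text, but the `all_exhausted` branch here simply drops a finished source and keeps sampling the others, which may not match torchtitan's exact semantics.

```python
import random
from typing import Iterator, Sequence


def interleave(
    sources: Sequence[Sequence[str]],
    weights: Sequence[float],
    stopping_strategy: str = "on_first_exhausted",
    seed: int = 42,
) -> Iterator[str]:
    """Yield items from `sources`, picking a source in proportion to its weight."""
    if stopping_strategy not in ("on_first_exhausted", "all_exhausted"):
        raise ValueError(f"unknown stopping_strategy: {stopping_strategy}")
    rng = random.Random(seed)
    iters = [iter(s) for s in sources]
    active = list(range(len(sources)))  # sources that may still have items
    while active:
        # Weights need not sum to 1; only their ratios matter to random.choices.
        idx = rng.choices(active, weights=[weights[i] for i in active], k=1)[0]
        try:
            item = next(iters[idx])
        except StopIteration:
            if stopping_strategy == "on_first_exhausted":
                return  # the epoch ends as soon as a drawn source is empty
            active.remove(idx)  # simplified "all_exhausted": keep sampling the rest
            continue
        yield item


# Each string stands in for one packed sample from a finite source.
a = [f"a{i}" for i in range(6)]
b = [f"b{i}" for i in range(2)]
first = list(interleave([a, b], [3.0, 1.0], "on_first_exhausted"))
every = list(interleave([a, b], [3.0, 1.0], "all_exhausted"))
print(first)
print(every)
```

With `on_first_exhausted` the stream usually stops before every item is seen, while the simplified `all_exhausted` variant yields all eight. Because the RNG is seeded, repeated runs produce the same order, which is presumably the role of `seed=42` in the configs.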
```python
from torchtitan.hf_datasets.text_datasets import (
    HFDataSource,
    InterleavedHuggingFaceTextDataLoader,
)

# Pre-training mix sampled 7:2:1 (70% / 20% / 10% of draws).
dataloader=InterleavedHuggingFaceTextDataLoader.Config(
    sources=[
        HFDataSource(dataset="c4", weight=7.0, infinite=True),
        HFDataSource(dataset="wikipedia", weight=2.0, infinite=True),
        HFDataSource(dataset="my_dataset", weight=1.0, infinite=True),
    ],
    seed=42,
),
```

```python
from torchtitan.hf_datasets.text_datasets import (
    ChatDataSource,
    InterleavedChatDataLoader,
)


def process_gsm8k(sample):
    return [
        {"role": "user", "content": sample["question"]},
        {"role": "assistant", "content": sample["answer"]},
    ]


def process_alpaca(sample):
    # Some Alpaca records carry extra context in "input"; append it when present.
    prompt = sample["instruction"]
    if sample.get("input"):
        prompt += "\n\n" + sample["input"]
    return [
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": sample["output"]},
    ]


# SFT mix: GSM8K is drawn three times as often as Alpaca on average.
dataloader=InterleavedChatDataLoader.Config(
    sources=[
        ChatDataSource(
            dataset_path="openai/gsm8k",
            load_dataset_kwargs={"name": "main", "split": "train"},
            sample_processor=process_gsm8k,
            weight=3.0,
            infinite=True,
        ),
        ChatDataSource(
            dataset_path="tatsu-lab/alpaca",
            load_dataset_kwargs={"split": "train"},
            sample_processor=process_alpaca,
            weight=1.0,
            infinite=True,
        ),
    ],
    seed=42,
),
```

Weights are relative sampling probabilities and are normalised internally. A weight of 3.0 alongside 1.0 means the first source is drawn three times as often on average; it does not mean the source is iterated three times per epoch. Under `on_first_exhausted`, the epoch boundary is defined by whichever source exhausts first.
This makes weights easy to reason about as a token mixture ratio. Because every draw yields one packed sample of the same length, the share of draws is also the share of tokens: if source A has weight 3 and source B has weight 1, roughly 75% of training tokens come from A and 25% from B, regardless of the absolute dataset sizes.
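The arithmetic behind the 75% / 25% figure is a normalisation followed by sampling. This sketch assumes that "normalised internally" means dividing each weight by the sum of all weights, and that every packed sample has the same length (the hypothetical `SEQ_LEN`); `normalise` is an illustrative helper, not a torchtitan function.

```python
import random
from collections import Counter


def normalise(weights: dict[str, float]) -> dict[str, float]:
    """Turn raw sampling weights into probabilities that sum to 1."""
    total = sum(weights.values())
    return {name: w / total for name, w in weights.items()}


probs = normalise({"A": 3.0, "B": 1.0})
print(probs)  # {'A': 0.75, 'B': 0.25}
print(normalise({"c4": 7.0, "wikipedia": 2.0, "my_dataset": 1.0}))
# {'c4': 0.7, 'wikipedia': 0.2, 'my_dataset': 0.1}

# Simulate draws; each draw contributes one packed sequence of SEQ_LEN tokens.
SEQ_LEN = 2048  # assumed fixed packed length
rng = random.Random(0)
names = list(probs)
draws = rng.choices(names, weights=[probs[n] for n in names], k=10_000)
tokens = Counter()
for name in draws:
    tokens[name] += SEQ_LEN
share_a = tokens["A"] / sum(tokens.values())
print(f"share of tokens from A: {share_a:.3f}")
```

The simulated share lands close to 0.75 but not exactly on it, since the mixture ratio holds on average rather than step by step.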
Interleaved dataloaders are fully stateful. The interleaver RNG and the state of every source are saved together, so resuming from a checkpoint produces byte-identical continuations.
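A toy version of this resume behavior is below. `TinyInterleaver` is a hypothetical class over in-memory lists that wrap around like infinite sources; its `state_dict` / `load_state_dict` method names follow the usual PyTorch convention, and a real source's state would be a position inside a Hugging Face dataset rather than a list offset.

```python
import random
from typing import Any


class TinyInterleaver:
    """Hypothetical stateful interleaver over in-memory, re-looping sources."""

    def __init__(self, sources: list[list[str]], weights: list[float], seed: int = 42):
        self.sources = sources
        self.weights = weights
        self.rng = random.Random(seed)
        self.offsets = [0] * len(sources)  # items consumed so far, per source

    def next_sample(self) -> str:
        idx = self.rng.choices(range(len(self.sources)), weights=self.weights, k=1)[0]
        src = self.sources[idx]
        item = src[self.offsets[idx] % len(src)]  # wrap around like an infinite source
        self.offsets[idx] += 1
        return item

    def state_dict(self) -> dict[str, Any]:
        # The interleaver RNG and every source's position are saved together.
        return {"rng": self.rng.getstate(), "offsets": list(self.offsets)}

    def load_state_dict(self, state: dict[str, Any]) -> None:
        self.rng.setstate(state["rng"])
        self.offsets = list(state["offsets"])


sources = [["a0", "a1", "a2"], ["b0", "b1"]]
loader = TinyInterleaver(sources, [3.0, 1.0])
for _ in range(5):
    loader.next_sample()
ckpt = loader.state_dict()  # "checkpoint" taken after 5 samples
expected = [loader.next_sample() for _ in range(10)]

# A fresh instance with a different seed resumes exactly where the checkpoint was taken.
resumed = TinyInterleaver(sources, [3.0, 1.0], seed=0)
resumed.load_state_dict(ckpt)
continued = [resumed.next_sample() for _ in range(10)]
print(continued == expected)  # True
```

The resumed instance is deliberately built with a different seed: restoring the saved RNG state, not re-seeding, is what makes the continuation identical.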
| Use case | Dataloader |
|---|---|
| Single pre-training source | HuggingFaceTextDataLoader |
| Multiple pre-training sources | InterleavedHuggingFaceTextDataLoader |
| Single SFT source | ChatDataLoader |
| Multiple SFT sources | InterleavedChatDataLoader |
| Multimodal (vision + text) | MMDataLoader |