Skip to content

fix(checkpointing): save/restore SeedableRandomSampler for map-style datasets#4019

Open
Anai-Guo wants to merge 1 commit into
huggingface:mainfrom
Anai-Guo:fix-checkpointing-seedable-sampler-map-style
Open

fix(checkpointing): save/restore SeedableRandomSampler for map-style datasets#4019
Anai-Guo wants to merge 1 commit into
huggingface:mainfrom
Anai-Guo:fix-checkpointing-seedable-sampler-map-style

Conversation

@Anai-Guo
Copy link
Copy Markdown

Problem

Fixes #3996.

When using use_seedable_sampler=True with a map-style (non-iterable) dataset, save_state() / load_state() silently drops the SeedableRandomSampler.epoch counter. On resume the counter resets to 0, so the sampler re-derives the same seed sequence (initial_seed + 0, initial_seed + 1, …) that was already used before the checkpoint — the model trains on the exact same shuffle order it already saw.

Root cause

save_state() and load_state() both gate sampler persistence on:

if isinstance(dataloader.dataset, IterableDatasetShard):
    ...

SeedableRandomSampler is attached to both iterable and map-style dataloaders (when use_seedable_sampler=True), but the guard prevents it from being saved for the map-style case.

Fix

Remove the IterableDatasetShard guard and persist any SeedableRandomSampler, regardless of dataset type. On the load side, add an input_sampler_file.exists() guard for backwards-compat with older checkpoints that do not contain a sampler file.

-        from .data_loader import IterableDatasetShard, SeedableRandomSampler
-
-        if isinstance(dataloader.dataset, IterableDatasetShard):
-            sampler = dataloader.get_sampler()
-            if isinstance(sampler, SeedableRandomSampler):
-                save(sampler, output_sampler_file, ...)
+        from .data_loader import SeedableRandomSampler
+
+        sampler = dataloader.get_sampler()
+        if isinstance(sampler, SeedableRandomSampler):
+            save(sampler, output_sampler_file, ...)

The fix is identical in shape for both save_state and load_state; the load path additionally checks input_sampler_file.exists() to stay compatible with checkpoints produced before this fix.

🤖 Generated with Claude Code

…datasets

DataLoader shuffle sequence replays from epoch 0 after resuming from a
checkpoint when using `use_seedable_sampler=True` with map-style datasets.

Root cause: the `IterableDatasetShard` guard in `save_state()` and
`load_state()` prevents saving/restoring `SeedableRandomSampler.epoch`
for non-iterable (map-style) datasets. On resume the epoch counter resets
to 0, so the sampler seeds with `initial_seed + 0`, `initial_seed + 1`,
... replaying the exact same shuffle sequence seen before the checkpoint.

Fix: remove the `IterableDatasetShard` guard and save/load any
`SeedableRandomSampler` regardless of dataset type. On the load side add
an `input_sampler_file.exists()` check for backwards-compat with
checkpoints written before this change.

Fixes huggingface#3996
@github-actions
Copy link
Copy Markdown
Contributor

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@Anai-Guo
Copy link
Copy Markdown
Author

Still relevant — happy to address any review feedback. Please keep open.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

DataLoader shuffle sequence replays from epoch 0 after resuming from a checkpoint

1 participant