-
Notifications
You must be signed in to change notification settings - Fork 422
[Enhance] Show dataset build progress with cache status #1843
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,6 +13,7 @@ | |
| from torch.distributed.device_mesh import DeviceMesh | ||
| from torch.utils.data import ConcatDataset, RandomSampler, SequentialSampler | ||
| from torch.utils.data import DataLoader as TorchDataLoader | ||
| from tqdm import tqdm | ||
| from typing_extensions import TypedDict | ||
|
|
||
| from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast | ||
|
|
@@ -112,6 +113,9 @@ def build_datasets( | |
| if tokenizer_hash is None: | ||
| tokenizer_hash = tokenizer_xxhash(tokenizer)[:16] | ||
|
|
||
| # Expand each dataset config into its concrete jsonl paths while keeping the | ||
| # owning config/tokenize_fn so multi-dataset setups don't share the wrong builder. | ||
| build_tasks: list[tuple[DatasetConfig, BaseTokenizeFnConfig, str | Path]] = [] | ||
| for config in dataset_config: | ||
| _dataset_config = config["dataset"] | ||
| assert isinstance(_dataset_config, DatasetConfig) | ||
|
|
@@ -125,17 +129,24 @@ def build_datasets( | |
| for f in list_dir_or_file(anno_path, suffix=".jsonl", list_dir=False, recursive=True) | ||
| ] | ||
| all_anno_path.sort() | ||
| for anno_path in all_anno_path: | ||
| _dataset_config = copy.deepcopy(_dataset_config) | ||
| _dataset_config.anno_path = anno_path | ||
| anno_name = os.path.basename(anno_path) # for debug | ||
| _tokenize_fn = _tokenize_fn_name.build(tokenizer, tokenizer_hash=tokenizer_hash, anno_name=anno_name) | ||
| _dataset = _dataset_config.build(_tokenize_fn) | ||
| if get_rank() == 0: | ||
| logger.debug( | ||
| f"[Dataset] (Original) {_dataset_config.name}/{os.path.basename(anno_path)}: {len(_dataset)} samples." | ||
| ) | ||
| datasets.append(_dataset) | ||
| build_tasks.extend((_dataset_config, _tokenize_fn_name, path) for path in all_anno_path) | ||
|
|
||
| bar = None | ||
| if get_rank() == 0: | ||
| bar = tqdm(build_tasks, desc="Building datasets", unit="dataset") | ||
|
|
||
| for _dataset_config, _tokenize_fn_name, anno_path in build_tasks: | ||
| _dataset_config = copy.deepcopy(_dataset_config) | ||
| _dataset_config.anno_path = anno_path | ||
| anno_name = os.path.basename(anno_path) | ||
| _tokenize_fn = _tokenize_fn_name.build(tokenizer, tokenizer_hash=tokenizer_hash, anno_name=anno_name) | ||
| _dataset = _dataset_config.build(_tokenize_fn) | ||
| datasets.append(_dataset) | ||
|
|
||
| if get_rank() == 0: | ||
| assert bar is not None | ||
| bar.update(1) | ||
| bar.set_postfix_str(f"{Path(anno_path).stem} (cached: {_dataset.cached})") | ||
|
|
||
| return datasets | ||
|
|
||
|
|
@@ -154,7 +165,7 @@ def build_dataloader( | |
| dp_size = dp_mesh.size() if dp_mesh is not None else 1 | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Claude: Good catch fixing `if self.pack_level != "none" and get_rank == 0:` Since this PR already fixes the deprecated copy, it would be good to fix the active one in |
||
| assert global_batch_size % dp_size == 0, "global_batch_size must be divisible by dp_size." | ||
|
|
||
| if dataloader_config.pack_level != "none" and get_rank == 0: | ||
| if dataloader_config.pack_level != "none" and get_rank() == 0: | ||
| num_tokens = sum(dset.num_tokens.sum() for dset in datasets if dset.num_tokens is not None) | ||
| logger.debug(f"[Dataset] {num_tokens} tokens.") | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Claude: The tqdm progress bar is created but never closed. When using manual
`bar.update(1)` instead of iterating over the wrapper, you need an explicit `bar.close()` after the loop — otherwise the final state may not render correctly (e.g., the cursor stays on the bar line).