Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 23 additions & 12 deletions xtuner/v1/datasets/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from torch.distributed.device_mesh import DeviceMesh
from torch.utils.data import ConcatDataset, RandomSampler, SequentialSampler
from torch.utils.data import DataLoader as TorchDataLoader
from tqdm import tqdm
from typing_extensions import TypedDict

from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
Expand Down Expand Up @@ -112,6 +113,9 @@ def build_datasets(
if tokenizer_hash is None:
tokenizer_hash = tokenizer_xxhash(tokenizer)[:16]

# Expand each dataset config into its concrete jsonl paths while keeping the
# owning config/tokenize_fn so multi-dataset setups don't share the wrong builder.
build_tasks: list[tuple[DatasetConfig, BaseTokenizeFnConfig, str | Path]] = []
for config in dataset_config:
_dataset_config = config["dataset"]
assert isinstance(_dataset_config, DatasetConfig)
Expand All @@ -125,17 +129,24 @@ def build_datasets(
for f in list_dir_or_file(anno_path, suffix=".jsonl", list_dir=False, recursive=True)
]
all_anno_path.sort()
for anno_path in all_anno_path:
_dataset_config = copy.deepcopy(_dataset_config)
_dataset_config.anno_path = anno_path
anno_name = os.path.basename(anno_path) # for debug
_tokenize_fn = _tokenize_fn_name.build(tokenizer, tokenizer_hash=tokenizer_hash, anno_name=anno_name)
_dataset = _dataset_config.build(_tokenize_fn)
if get_rank() == 0:
logger.debug(
f"[Dataset] (Original) {_dataset_config.name}/{os.path.basename(anno_path)}: {len(_dataset)} samples."
)
datasets.append(_dataset)
build_tasks.extend((_dataset_config, _tokenize_fn_name, path) for path in all_anno_path)

bar = None
if get_rank() == 0:
bar = tqdm(build_tasks, desc="Building datasets", unit="dataset")

for _dataset_config, _tokenize_fn_name, anno_path in build_tasks:
_dataset_config = copy.deepcopy(_dataset_config)
_dataset_config.anno_path = anno_path
anno_name = os.path.basename(anno_path)
_tokenize_fn = _tokenize_fn_name.build(tokenizer, tokenizer_hash=tokenizer_hash, anno_name=anno_name)
_dataset = _dataset_config.build(_tokenize_fn)
datasets.append(_dataset)

if get_rank() == 0:
assert bar is not None
bar.update(1)
bar.set_postfix_str(f"{Path(anno_path).stem} (cached: {_dataset.cached})")

return datasets
Comment on lines 150 to 151
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude: The tqdm progress bar is created but never closed. When using manual bar.update(1) instead of iterating over the wrapper, you need an explicit bar.close() after the loop — otherwise the final state may not render correctly (e.g., the cursor stays on the bar line).

Suggested change
return datasets
if bar is not None:
bar.close()
return datasets


Expand All @@ -154,7 +165,7 @@ def build_dataloader(
dp_size = dp_mesh.size() if dp_mesh is not None else 1
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude: Good catch fixing `get_rank == 0` → `get_rank() == 0` here. However, the same bug exists at line 402 in DataloaderConfig.build:

if self.pack_level != "none" and get_rank == 0:

Since this PR already fixes the deprecated copy, it would be good to fix the active one in DataloaderConfig.build as well — that code path is the one users actually hit.

assert global_batch_size % dp_size == 0, "global_batch_size must be divisible by dp_size."

if dataloader_config.pack_level != "none" and get_rank == 0:
if dataloader_config.pack_level != "none" and get_rank() == 0:
num_tokens = sum(dset.num_tokens.sum() for dset in datasets if dset.num_tokens is not None)
logger.debug(f"[Dataset] {num_tokens} tokens.")

Expand Down
7 changes: 7 additions & 0 deletions xtuner/v1/datasets/jsonl.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,9 +287,11 @@ def __init__(
logger.debug(f"[Dataset] Start loading [{self.name}]{self.path} with sample_ratio={sample_ratio}.")

self._has_chunk = isinstance(tokenize_fn, LongTextPretrainTokenizeFunction)
self._cached = False

tok_cache_dir: str | None = None # set inside cache_dir branch when tokenize_fn is CachableTokenizeFunction
if cache_tag is not None and (cached := self._get_cached_tag(cache_tag, tokenize_fn)) is not None:
self._cached = True
logger.debug(f"[Dataset] Load cached [{self.name}]{self.path} of cache tags {cache_tag}.")
offset_path = cached["offsets"]
meta_path = cached.get("jsonl_meta")
Expand Down Expand Up @@ -377,6 +379,7 @@ def __init__(
_meta_file = os.path.join(tok_cache_dir, "jsonl_meta")
if os.path.exists(_meta_file):
logger.debug(f"Loading tokenize meta from cache: {_meta_file}")
self._cached = True
_meta = load_dict_from_npy_dir(_meta_file, mmap=enable_mmap_shared)
else:
_meta = self.count_tokens(offsets, tok_cache_dir)
Expand Down Expand Up @@ -835,3 +838,7 @@ def load_state_dict(self, state_dict: dict): ...

# Returns a fresh empty dict on every call; this implementation reports no state.
def get_state_dict(self):
return {}

# Read-only view of the `_cached` flag. In the hunks shown, `__init__` sets it to
# False up front and flips it to True when a cache-tag lookup succeeds or when a
# cached `jsonl_meta` file already exists (so token counting is not redone).
@property
def cached(self) -> bool:
return self._cached
Loading