diff --git a/xtuner/v1/datasets/config.py b/xtuner/v1/datasets/config.py index c81dc9d0d..b3cd12cc3 100644 --- a/xtuner/v1/datasets/config.py +++ b/xtuner/v1/datasets/config.py @@ -13,6 +13,7 @@ from torch.distributed.device_mesh import DeviceMesh from torch.utils.data import ConcatDataset, RandomSampler, SequentialSampler from torch.utils.data import DataLoader as TorchDataLoader +from tqdm import tqdm from typing_extensions import TypedDict from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast @@ -112,6 +113,9 @@ def build_datasets( if tokenizer_hash is None: tokenizer_hash = tokenizer_xxhash(tokenizer)[:16] + # Expand each dataset config into its concrete jsonl paths while keeping the + # owning config/tokenize_fn so multi-dataset setups don't share the wrong builder. + build_tasks: list[tuple[DatasetConfig, BaseTokenizeFnConfig, str | Path]] = [] for config in dataset_config: _dataset_config = config["dataset"] assert isinstance(_dataset_config, DatasetConfig) @@ -125,17 +129,24 @@ def build_datasets( for f in list_dir_or_file(anno_path, suffix=".jsonl", list_dir=False, recursive=True) ] all_anno_path.sort() - for anno_path in all_anno_path: - _dataset_config = copy.deepcopy(_dataset_config) - _dataset_config.anno_path = anno_path - anno_name = os.path.basename(anno_path) # for debug - _tokenize_fn = _tokenize_fn_name.build(tokenizer, tokenizer_hash=tokenizer_hash, anno_name=anno_name) - _dataset = _dataset_config.build(_tokenize_fn) - if get_rank() == 0: - logger.debug( - f"[Dataset] (Original) {_dataset_config.name}/{os.path.basename(anno_path)}: {len(_dataset)} samples." - ) - datasets.append(_dataset) + build_tasks.extend((_dataset_config, _tokenize_fn_name, path) for path in all_anno_path) + + bar = None + if get_rank() == 0: + bar = tqdm(build_tasks, desc="Building datasets", unit="dataset") + + for _dataset_config, _tokenize_fn_name, anno_path in build_tasks: + _dataset_config = copy.deepcopy(_dataset_config) + _dataset_config.anno_path = anno_path + anno_name = os.path.basename(anno_path) + _tokenize_fn = _tokenize_fn_name.build(tokenizer, tokenizer_hash=tokenizer_hash, anno_name=anno_name) + _dataset = _dataset_config.build(_tokenize_fn) + datasets.append(_dataset) + + if get_rank() == 0: + assert bar is not None + bar.update(1) + bar.set_postfix_str(f"{Path(anno_path).stem} (cached: {_dataset.cached})") return datasets @@ -154,7 +165,7 @@ def build_dataloader( dp_size = dp_mesh.size() if dp_mesh is not None else 1 assert global_batch_size % dp_size == 0, "global_batch_size must be divisible by dp_size." - if dataloader_config.pack_level != "none" and get_rank == 0: + if dataloader_config.pack_level != "none" and get_rank() == 0: num_tokens = sum(dset.num_tokens.sum() for dset in datasets if dset.num_tokens is not None) logger.debug(f"[Dataset] {num_tokens} tokens.") diff --git a/xtuner/v1/datasets/jsonl.py b/xtuner/v1/datasets/jsonl.py index 96856e2fb..f6a5404a0 100644 --- a/xtuner/v1/datasets/jsonl.py +++ b/xtuner/v1/datasets/jsonl.py @@ -287,9 +287,11 @@ def __init__( logger.debug(f"[Dataset] Start loading [{self.name}]{self.path} with sample_ratio={sample_ratio}.") self._has_chunk = isinstance(tokenize_fn, LongTextPretrainTokenizeFunction) + self._cached = False tok_cache_dir: str | None = None # set inside cache_dir branch when tokenize_fn is CachableTokenizeFunction if cache_tag is not None and (cached := self._get_cached_tag(cache_tag, tokenize_fn)) is not None: + self._cached = True logger.debug(f"[Dataset] Load cached [{self.name}]{self.path} of cache tags {cache_tag}.") offset_path = cached["offsets"] meta_path = cached.get("jsonl_meta") @@ -377,6 +379,7 @@ def __init__( _meta_file = os.path.join(tok_cache_dir, "jsonl_meta") if os.path.exists(_meta_file): logger.debug(f"Loading tokenize meta from cache: {_meta_file}") + self._cached = True _meta = load_dict_from_npy_dir(_meta_file, mmap=enable_mmap_shared) else: _meta = self.count_tokens(offsets, tok_cache_dir) @@ -835,3 +838,7 @@ def load_state_dict(self, state_dict: dict): ... def get_state_dict(self): return {} + + @property + def cached(self) -> bool: + return self._cached