From acea20c2b87adca8a9dc1d88b4d99531fe88802e Mon Sep 17 00:00:00 2001 From: Krzysztof Czajkowski Date: Fri, 19 Dec 2025 12:06:22 +0100 Subject: [PATCH] rename csv path var name to match train function --- timm/data/loader.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/timm/data/loader.py b/timm/data/loader.py index 245d36a995..4d201e4c6e 100644 --- a/timm/data/loader.py +++ b/timm/data/loader.py @@ -229,7 +229,7 @@ def create_loader( worker_seeding: str = 'all', tf_preprocessing: bool = False, balance_classes: bool = False, - dataset_csv_path: Optional[str] = None + samples_csv_path: Optional[str] = None ): """ @@ -274,7 +274,7 @@ def create_loader( worker_seeding: Control worker random seeding at init. tf_preprocessing: Use TF 1.0 inference preprocessing for testing model ports. balance_classes: Sample classes with uniform probability - dataset_csv_path: Path to dataset csv, used for class balancing + samples_csv_path: Path to dataset csv, used for class balancing Returns: DataLoader @@ -333,9 +333,9 @@ def create_loader( else: assert num_aug_repeats == 0, "RepeatAugment not currently supported in non-distributed or IterableDataset use" if balance_classes: - assert dataset_csv_path, "Provide csv with labels to use balance_classes." - dataset_csv = pd.read_csv(dataset_csv_path) - all_labels = dataset_csv["label"].values + assert samples_csv_path, "Provide csv with labels to use balance_classes." + samples_csv = pd.read_csv(samples_csv_path) + all_labels = samples_csv["label"].values unique, counts = np.unique(all_labels, return_counts=True) unique_counts = {v: c for v, c in zip(unique, counts)} label_weights = np.array([1 / unique_counts[num] for num in all_labels])