diff --git a/timm/data/loader.py b/timm/data/loader.py index 4d201e4c6e..83ebfc4f2e 100644 --- a/timm/data/loader.py +++ b/timm/data/loader.py @@ -334,7 +334,7 @@ def create_loader( assert num_aug_repeats == 0, "RepeatAugment not currently supported in non-distributed or IterableDataset use" if balance_classes: assert samples_csv_path, "Provide csv with labels to use balance_classes." - samples_csv = pd.read_csv(samples_csv_path) + samples_csv = pd.read_csv(samples_csv_path).astype(str) all_labels = samples_csv["label"].values unique, counts = np.unique(all_labels, return_counts=True) unique_counts = {v: c for v, c in zip(unique, counts)}