From e38553631d36db157a3e7683720aced7168686dd Mon Sep 17 00:00:00 2001 From: Jared Smith Date: Mon, 18 May 2026 21:07:36 -0700 Subject: [PATCH 1/7] adding a new dataloader to load into huggingface format, and testing stratified splitting --- .gitignore | 5 +- agml/data/__init__.py | 1 + agml/data/hf_loader.py | 279 +++++++++++++++++++++++++++++++++++++++++ requirements.txt | 2 + 4 files changed, 286 insertions(+), 1 deletion(-) create mode 100644 agml/data/hf_loader.py diff --git a/.gitignore b/.gitignore index cfd536f1..4db4f780 100644 --- a/.gitignore +++ b/.gitignore @@ -171,4 +171,7 @@ image*.png # under development agml/data/augmentations/ -agml/data/exporters/pascal_voc.py \ No newline at end of file +agml/data/exporters/pascal_voc.py + +# vs code settings +.vscode/ \ No newline at end of file diff --git a/agml/data/__init__.py b/agml/data/__init__.py index a1c729c3..4c239a85 100644 --- a/agml/data/__init__.py +++ b/agml/data/__init__.py @@ -18,3 +18,4 @@ from .point_cloud import PointCloud from .public import download_public_dataset, public_data_sources, source from .tools import coco_to_bboxes, convert_bbox_format +from .hf_loader import HuggingFaceDataLoader diff --git a/agml/data/hf_loader.py b/agml/data/hf_loader.py new file mode 100644 index 00000000..63bb98d8 --- /dev/null +++ b/agml/data/hf_loader.py @@ -0,0 +1,279 @@ +import os +import shutil +import urllib.parse +from collections import Counter +import math +from typing import List, Union + +try: + from datasets import load_dataset, DatasetDict, Image, Sequence, Value, ClassLabel +except ImportError: + raise ImportError( + "The `datasets` library is required to use the HuggingFaceDataLoader. " + "Please install it using `pip install datasets`." + ) + +from agml.framework import AgMLSerializable +from agml.backend.config import data_save_path + +class HuggingFaceDataLoader(AgMLSerializable): + """A data loader designed for loading datasets directly into Hugging Face formats. + + This loader retrieves datasets from S3 (or local paths) and structures them + into Hugging Face `DatasetDict` objects natively compatible with `transformers` + or `diffusers` pipelines. It supports custom multi-column stratification. + + Parameters + ---------- + location : str + The location of the dataset. This can be an S3 URI (s3://...), a public URL, + or a local directory path containing the HuggingFace-formatted dataset. + task : str + The computer vision task. Must be one of 'classification', 'detection', or 'segmentation'. + local_dir : str, optional + The local directory to download the data to. Defaults to `~/.agml/datasets/`. + """ + + serializable = frozenset(("location", "task", "local_dir")) + + def __init__(self, location: str, task: str, local_dir: str = None, **kwargs): + if task not in ['classification', 'detection', 'segmentation']: + raise ValueError("Task must be 'classification', 'detection', or 'segmentation'.") + + self.location = location + self.task = task + + if local_dir is None: + # Extract a pseudo-name from the location + base_name = os.path.basename(location.rstrip('/')) + if not base_name: + base_name = "hf_dataset" + self.local_dir = os.path.join(data_save_path(), base_name) + else: + self.local_dir = local_dir + + self._hf_dataset = None + self._setup_loader() + + def _setup_loader(self): + """Downloads the dataset and loads it into a Hugging Face Dataset.""" + # 1. Handle downloading from S3 if needed + self._download_from_s3() + + # 2. Load into Hugging Face Dataset object + # Assume the dataset is provided in imagefolder format. + try: + self._hf_dataset = load_dataset('imagefolder', data_dir=self.local_dir) + except Exception as e: + raise RuntimeError( + f"Failed to load dataset as 'imagefolder' from {self.local_dir}: {e}" + ) + + # 3. Cast features based on the task + self._cast_features() + + def _download_from_s3(self): + """Downloads files from an S3 URI to `self.local_dir`.""" + if not self.location.startswith('s3://'): + if os.path.isdir(self.location): + self.local_dir = self.location + return # Assume it's a valid path or handle HTTP elsewhere + + if os.path.exists(self.local_dir) and len(os.listdir(self.local_dir)) > 0: + return # Already downloaded + + try: + import boto3 + except ImportError: + raise ImportError("`boto3` is required to download S3 datasets. Run `pip install boto3`.") + + print(f"Downloading dataset from {self.location} to {self.local_dir}...") + os.makedirs(self.local_dir, exist_ok=True) + + parsed_url = urllib.parse.urlparse(self.location) + bucket_name = parsed_url.netloc + prefix = parsed_url.path.lstrip('/') + + s3 = boto3.client('s3') + paginator = s3.get_paginator('list_objects_v2') + pages = paginator.paginate(Bucket=bucket_name, Prefix=prefix) + + for page in pages: + if 'Contents' not in page: + continue + for obj in page['Contents']: + key = obj['Key'] + if key.endswith('/'): + continue + # Compute relative path + rel_path = os.path.relpath(key, prefix) + local_file_path = os.path.join(self.local_dir, rel_path) + + os.makedirs(os.path.dirname(local_file_path), exist_ok=True) + s3.download_file(bucket_name, key, local_file_path) + + def _cast_features(self): + """Casts Hugging Face Dataset features according to the specified CV task.""" + # Convert DatasetDict to a processable format to cast columns accurately + for split_name in self._hf_dataset.keys(): + ds = self._hf_dataset[split_name] + features = ds.features + + # Ensure "image" exists and is of HF Image type + if "image" in features and not isinstance(features["image"], Image): + self._hf_dataset[split_name] = ds.cast_column("image", Image()) + + if self.task == "segmentation": + # Segmentation masks should also be casted to the Image type for automated decoding + label_col = "mask" if "mask" in features else "label" + if label_col in features and not isinstance(features[label_col], Image): + self._hf_dataset[split_name] = ds.cast_column(label_col, Image()) + + def split(self, + val_size: float = None, + test_size: float = None, + stratify_cols: Union[str, List[str]] = None, + seed: int = 42) -> DatasetDict: + """ + Splits the dataset into train, val, and test splits. + Supports stratified splitting across any number of columns. + + Parameters + ---------- + val_size : float + Proportion of the dataset to include in the validation split. + test_size : float + Proportion of the dataset to include in the test split. + stratify_cols : str or list of str + Column(s) to stratify the split on. + seed : int + Random seed for reproducibility. + + Returns + ------- + DatasetDict + A Hugging Face DatasetDict containing 'train', 'val', and 'test' splits. + """ + # If the dataset is already split (e.g., contains train/test/val), + # and we operate on 'train' as the base: + base_ds = self._hf_dataset['train'] if 'train' in self._hf_dataset else self._hf_dataset + + # If both sizes are None, the caller is requesting no split — + # return the dataset as-is (or wrapped as a DatasetDict with 'train'). + if val_size is None and test_size is None: + if isinstance(self._hf_dataset, DatasetDict): + return self._hf_dataset + # Wrap a single Dataset into a DatasetDict under 'train' + dataset_dict = DatasetDict({"train": base_ds}) + self._hf_dataset = dataset_dict + return dataset_dict + + # Treat None for an individual split as 0.0 (i.e., not requested) + if val_size is None: + val_size = 0.0 + if test_size is None: + test_size = 0.0 + + if stratify_cols is None: + stratify_by_column = None + ds_to_split = base_ds + else: + if isinstance(stratify_cols, str): + stratify_cols = [stratify_cols] + + # Create composite key for stratification + def add_stratify_key(example): + key = "_".join(str(example[col]) for col in stratify_cols) + example["_stratify_key"] = key + return example + + ds_to_split = base_ds.map(add_stratify_key) + stratify_by_column = "_stratify_key" + + # Inspect counts per composite key to ensure stratification is possible + try: + values = ds_to_split[stratify_by_column] + counts = Counter(values) + # Basic requirement: at least 2 samples per class for initial stratified split + too_small = {k: v for k, v in counts.items() if v < 2} + if too_small: + raise ValueError( + "Stratified splitting requires at least 2 samples per class. " + f"The following stratify groups have fewer than 2 samples: {too_small}" + ) + + # Additional requirement when doing a second split (val vs test): + # ensure the temporary partition (val+test) will likely contain + # at least 2 samples per class. Conservatively require + # `count >= ceil(2 / eval_size)` for each class when `test_size > 0`. + eval_size_check = val_size + test_size + if test_size > 0 and eval_size_check > 0: + min_required_in_temp = math.ceil(2.0 / eval_size_check) + too_small_temp = {k: v for k, v in counts.items() if v < min_required_in_temp} + if too_small_temp: + raise ValueError( + "Stratified splitting into train/val/test is not possible with the current " + f"class counts and requested `val_size+test_size={eval_size_check}`. " + f"Each class must have >= {min_required_in_temp} samples to guarantee at least 2 " + f"samples in the combined val+test partition. The problematic groups: {too_small_temp}" + ) + except Exception as e: + # If we can't compute counts or there's an issue, raise a clearer error + raise RuntimeError( + f"Failed to validate stratification column '{stratify_by_column}': {e}" + ) + + # Convert the composite string key into a ClassLabel feature so + # Hugging Face's `train_test_split` can perform stratification. + try: + unique_labels = sorted(ds_to_split.unique(stratify_by_column)) + ds_to_split = ds_to_split.cast_column(stratify_by_column, ClassLabel(names=unique_labels)) + except Exception: + # If casting fails for any reason, leave the dataset as-is + # and allow the underlying split to raise a clear error. + pass + + # Split 1: Train vs. Temp (Val + Test) + eval_size = val_size + test_size + if eval_size >= 1.0 or eval_size <= 0.0: + raise ValueError("val_size + test_size must be strictly between 0 and 1.") + + split_1 = ds_to_split.train_test_split( + test_size=eval_size, + stratify_by_column=stratify_by_column, + seed=seed + ) + + train_ds = split_1['train'] + temp_ds = split_1['test'] + + # Split 2: Val vs. Test + if test_size > 0: + relative_test_size = test_size / eval_size + split_2 = temp_ds.train_test_split( + test_size=relative_test_size, + stratify_by_column=stratify_by_column, + seed=seed + ) + val_ds = split_2['train'] + test_ds = split_2['test'] + else: + val_ds = temp_ds + test_ds = None + + # Cleanup composite key + dataset_dict = DatasetDict({"train": train_ds, "val": val_ds}) + if test_ds is not None: + dataset_dict["test"] = test_ds + + if stratify_by_column: + for split in dataset_dict.keys(): + dataset_dict[split] = dataset_dict[split].remove_columns("_stratify_key") + + self._hf_dataset = dataset_dict + return dataset_dict + + @property + def dataset(self) -> DatasetDict: + """Returns the underlying Hugging Face DatasetDict.""" + return self._hf_dataset diff --git a/requirements.txt b/requirements.txt index 8286b8da..8187ffe5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,5 @@ opencv-python-headless; sys.platform == 'linux' pyyaml>=5.4.1 albumentations dict2xml +datasets +transformers \ No newline at end of file From f81ec9cb7f6261a0d25a3392bbd4e5375f914427 Mon Sep 17 00:00:00 2001 From: Jared Smith Date: Fri, 29 May 2026 11:58:09 -0700 Subject: [PATCH 2/7] updating hf data loader --- agml/data/hf_loader.py | 83 ++++++------------------------------------ 1 file changed, 12 insertions(+), 71 deletions(-) diff --git a/agml/data/hf_loader.py b/agml/data/hf_loader.py index 63bb98d8..3c59c95c 100644 --- a/agml/data/hf_loader.py +++ b/agml/data/hf_loader.py @@ -1,6 +1,3 @@ -import os -import shutil -import urllib.parse from collections import Counter import math from typing import List, Union @@ -14,104 +11,48 @@ ) from agml.framework import AgMLSerializable -from agml.backend.config import data_save_path class HuggingFaceDataLoader(AgMLSerializable): """A data loader designed for loading datasets directly into Hugging Face formats. - This loader retrieves datasets from S3 (or local paths) and structures them + This loader retrieves datasets from the Hugging Face Hub and structures them into Hugging Face `DatasetDict` objects natively compatible with `transformers` or `diffusers` pipelines. It supports custom multi-column stratification. Parameters ---------- - location : str - The location of the dataset. This can be an S3 URI (s3://...), a public URL, - or a local directory path containing the HuggingFace-formatted dataset. + dataset_name : str + The name of a dataset on the Hugging Face Hub. task : str The computer vision task. Must be one of 'classification', 'detection', or 'segmentation'. - local_dir : str, optional - The local directory to download the data to. Defaults to `~/.agml/datasets/`. + cache_dir : str, optional + The local directory to cache the dataset in. """ - serializable = frozenset(("location", "task", "local_dir")) + serializable = frozenset(("dataset_name", "task", "cache_dir")) - def __init__(self, location: str, task: str, local_dir: str = None, **kwargs): + def __init__(self, dataset_name: str, task: str, cache_dir: str = None, **kwargs): if task not in ['classification', 'detection', 'segmentation']: raise ValueError("Task must be 'classification', 'detection', or 'segmentation'.") - self.location = location + self.dataset_name = dataset_name self.task = task - - if local_dir is None: - # Extract a pseudo-name from the location - base_name = os.path.basename(location.rstrip('/')) - if not base_name: - base_name = "hf_dataset" - self.local_dir = os.path.join(data_save_path(), base_name) - else: - self.local_dir = local_dir + self.cache_dir = cache_dir self._hf_dataset = None self._setup_loader() def _setup_loader(self): - """Downloads the dataset and loads it into a Hugging Face Dataset.""" - # 1. Handle downloading from S3 if needed - self._download_from_s3() - - # 2. Load into Hugging Face Dataset object - # Assume the dataset is provided in imagefolder format. + """Loads the dataset from the Hugging Face Hub.""" try: - self._hf_dataset = load_dataset('imagefolder', data_dir=self.local_dir) + self._hf_dataset = load_dataset(self.dataset_name, cache_dir=self.cache_dir) except Exception as e: raise RuntimeError( - f"Failed to load dataset as 'imagefolder' from {self.local_dir}: {e}" + f"Failed to load Hugging Face dataset '{self.dataset_name}': {e}" ) - # 3. Cast features based on the task self._cast_features() - def _download_from_s3(self): - """Downloads files from an S3 URI to `self.local_dir`.""" - if not self.location.startswith('s3://'): - if os.path.isdir(self.location): - self.local_dir = self.location - return # Assume it's a valid path or handle HTTP elsewhere - - if os.path.exists(self.local_dir) and len(os.listdir(self.local_dir)) > 0: - return # Already downloaded - - try: - import boto3 - except ImportError: - raise ImportError("`boto3` is required to download S3 datasets. Run `pip install boto3`.") - - print(f"Downloading dataset from {self.location} to {self.local_dir}...") - os.makedirs(self.local_dir, exist_ok=True) - - parsed_url = urllib.parse.urlparse(self.location) - bucket_name = parsed_url.netloc - prefix = parsed_url.path.lstrip('/') - - s3 = boto3.client('s3') - paginator = s3.get_paginator('list_objects_v2') - pages = paginator.paginate(Bucket=bucket_name, Prefix=prefix) - - for page in pages: - if 'Contents' not in page: - continue - for obj in page['Contents']: - key = obj['Key'] - if key.endswith('/'): - continue - # Compute relative path - rel_path = os.path.relpath(key, prefix) - local_file_path = os.path.join(self.local_dir, rel_path) - - os.makedirs(os.path.dirname(local_file_path), exist_ok=True) - s3.download_file(bucket_name, key, local_file_path) - def _cast_features(self): """Casts Hugging Face Dataset features according to the specified CV task.""" # Convert DatasetDict to a processable format to cast columns accurately From 17d359a36d47388018ee40734cbb60d7dcac26f4 Mon Sep 17 00:00:00 2001 From: Jared Smith Date: Fri, 29 May 2026 12:03:29 -0700 Subject: [PATCH 3/7] adding deprecation warning to agml models module --- agml/models/__init__.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/agml/models/__init__.py b/agml/models/__init__.py index 6db5f722..a4041cc9 100644 --- a/agml/models/__init__.py +++ b/agml/models/__init__.py @@ -17,6 +17,14 @@ commonly used deep learning models on agricultural datasets within AgML. """ +import warnings + +warnings.warn( + "agml.models is deprecated and will be removed in a future release.", + DeprecationWarning, + stacklevel=2, +) + # Before anything can be imported, we need to run checks for PyTorch and # PyTorch Lightning, as these are not imported on their own. try: From 192bec339f1794c1d0832ce2b6d74d95d1e1b84c Mon Sep 17 00:00:00 2001 From: Jared Smith Date: Fri, 29 May 2026 12:27:22 -0700 Subject: [PATCH 4/7] attempting to update package requirements --- environment.yml | 2 ++ pyproject.toml | 33 +++++++++++++++++++-------------- 2 files changed, 21 insertions(+), 14 deletions(-) diff --git a/environment.yml b/environment.yml index a7eae300..80f65c1e 100644 --- a/environment.yml +++ b/environment.yml @@ -41,3 +41,5 @@ dependencies: - pyyaml>=5.4.1 - albumentations - dict2xml + - datasets + - transformers diff --git a/pyproject.toml b/pyproject.toml index 128f0f25..b219307a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,6 +51,8 @@ dependencies = [ "dict2xml>=1.7.6", "ipywidgets>=8.1.5", "rich>=14.0.0", + "datasets", + "transformers" ] [project.urls] @@ -60,20 +62,6 @@ dependencies = [ [tool.uv] -dev-dependencies = [ - "boto3>=1.35.66", - "scikit-image>=0.21.0", - "shapely>=2.0.6", - "botocore>=1.35.66", - "pandas>=2.0.3", - "pytest>=8.3.3", - "pytest-order>=1.3.0", - "pytest-cov>=5.0.0", - "ruff>=0.7.4", - "mypy>=1.13.0", - "coverage>=7.6.1", - "interrogate>=1.7.0", -] [dependency-groups] docs = [ @@ -89,3 +77,20 @@ docs = [ "mkdocs-git-revision-date-localized-plugin>=1.2.0", "mkdocs-minify-plugin>=0.8.0", ] +dev = [ + "boto3>=1.35.66", + "scikit-image>=0.21.0", + "shapely>=2.0.6", + "botocore>=1.35.66", + "pandas>=2.0.3", + "pytest>=8.3.3", + "pytest-order>=1.3.0", + "pytest-cov>=5.0.0", + "ruff>=0.7.4", + "mypy>=1.13.0", + "coverage>=7.6.1", + "interrogate>=1.7.0", + "datasets", + "transformers" +] + From 3b3be26f565a8706f4c04c07e99ec096ffad508a Mon Sep 17 00:00:00 2001 From: Jared Smith Date: Fri, 29 May 2026 12:41:20 -0700 Subject: [PATCH 5/7] more package fixing attempts --- .github/release-drafter.yml | 3 +-- .github/workflows/docs.yml | 4 ++-- pyproject.toml | 1 + 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/release-drafter.yml b/.github/release-drafter.yml index 45b762c5..148d8f1b 100644 --- a/.github/release-drafter.yml +++ b/.github/release-drafter.yml @@ -1,6 +1,5 @@ name-template: "v$RESOLVED_VERSION" tag-template: "v$RESOLVED_VERSION" -change-template: "- $TITLE #$NUMBER [@$AUTHOR]" change-template: '- $TITLE @$AUTHOR (#$NUMBER)' change-title-escapes: '\<*_&' # You can add # and @ to disable mentions, and add ` to disable code blocks. @@ -12,7 +11,7 @@ categories: - "enhancement" - "✨" - "⚡️" - - title: "📦 Dependencies Changes": + - title: "📦 Dependencies Changes" labels: - "dependency" - "deps" diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index e48bc8bf..200f625c 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -1,5 +1,5 @@ name: docs -on: [push, pull_request, workflow_dispatch] +on: [workflow_dispatch] permissions: contents: read @@ -47,7 +47,7 @@ jobs: - run: uv sync --only-group docs - - run: uv run mkdocs build --config-file config/mkdocs.yml + - run: uv run properdocs build --config-file config/mkdocs.yml # - run: uv run mkdocs gh-deploy --force --config-file config/mkdocs.yml diff --git a/pyproject.toml b/pyproject.toml index b219307a..788a05d3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,6 +76,7 @@ docs = [ "mkdocstrings-python>=1.8.0", "mkdocs-git-revision-date-localized-plugin>=1.2.0", "mkdocs-minify-plugin>=0.8.0", + "properdocs" ] dev = [ "boto3>=1.35.66", From 23dc8fda3c6ad9fd86bcb85eeda7dc3a29cb0930 Mon Sep 17 00:00:00 2001 From: Jared Smith Date: Fri, 29 May 2026 12:43:44 -0700 Subject: [PATCH 6/7] require python >=3.9 --- pyproject.toml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 788a05d3..6a42b2e1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,7 @@ license = {text = "Apache 2.0"} readme = "README.md" packages = [{include = "agml", from = "agml", exclude ='agml/_internal'}] version = "0.7.4" -requires-python = ">=3.8" +requires-python = ">=3.9" keywords = [] classifiers = [ "Development Status :: 4 - Beta", @@ -23,7 +23,6 @@ classifiers = [ "Programming Language :: Python", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3 :: Only", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", From 285de0a605bc1d10f66a34719153a9912a600e62 Mon Sep 17 00:00:00 2001 From: Jared Smith Date: Mon, 1 Jun 2026 15:32:15 -0700 Subject: [PATCH 7/7] updating loader to remove task type and add config tag --- agml/data/hf_loader.py | 61 ++++++++++++++++++++++++++---------------- 1 file changed, 38 insertions(+), 23 deletions(-) diff --git a/agml/data/hf_loader.py b/agml/data/hf_loader.py index 3c59c95c..7c7dc21b 100644 --- a/agml/data/hf_loader.py +++ b/agml/data/hf_loader.py @@ -23,20 +23,28 @@ class HuggingFaceDataLoader(AgMLSerializable): ---------- dataset_name : str The name of a dataset on the Hugging Face Hub. - task : str - The computer vision task. Must be one of 'classification', 'detection', or 'segmentation'. + config : str, optional + The dataset configuration (subset) name for datasets that expose multiple + configs (e.g. an augmented variant). cache_dir : str, optional The local directory to cache the dataset in. + + Examples + -------- + Single-config dataset:: + + loader = HuggingFaceDataLoader("org/dataset") + + Multi-config dataset with an augmented variant:: + + loader = HuggingFaceDataLoader("org/dataset", "augmented") """ - serializable = frozenset(("dataset_name", "task", "cache_dir")) + serializable = frozenset(("dataset_name", "config", "cache_dir")) - def __init__(self, dataset_name: str, task: str, cache_dir: str = None, **kwargs): - if task not in ['classification', 'detection', 'segmentation']: - raise ValueError("Task must be 'classification', 'detection', or 'segmentation'.") - + def __init__(self, dataset_name: str, config: str = None, cache_dir: str = None, **kwargs): self.dataset_name = dataset_name - self.task = task + self.config = config self.cache_dir = cache_dir self._hf_dataset = None @@ -45,30 +53,37 @@ def __init__(self, dataset_name: str, task: str, cache_dir: str = None, **kwargs def _setup_loader(self): """Loads the dataset from the Hugging Face Hub.""" try: - self._hf_dataset = load_dataset(self.dataset_name, cache_dir=self.cache_dir) - except Exception as e: - raise RuntimeError( - f"Failed to load Hugging Face dataset '{self.dataset_name}': {e}" + self._hf_dataset = load_dataset( + self.dataset_name, + self.config, + cache_dir=self.cache_dir, ) + except Exception as e: + name = self.dataset_name if self.config is None else f"{self.dataset_name}/{self.config}" + raise RuntimeError(f"Failed to load Hugging Face dataset '{name}': {e}") self._cast_features() def _cast_features(self): - """Casts Hugging Face Dataset features according to the specified CV task.""" - # Convert DatasetDict to a processable format to cast columns accurately + """Infers and casts image-like columns to the HF Image type for automated decoding.""" for split_name in self._hf_dataset.keys(): ds = self._hf_dataset[split_name] features = ds.features - - # Ensure "image" exists and is of HF Image type - if "image" in features and not isinstance(features["image"], Image): - self._hf_dataset[split_name] = ds.cast_column("image", Image()) - if self.task == "segmentation": - # Segmentation masks should also be casted to the Image type for automated decoding - label_col = "mask" if "mask" in features else "label" - if label_col in features and not isinstance(features[label_col], Image): - self._hf_dataset[split_name] = ds.cast_column(label_col, Image()) + if "image" in features and not isinstance(features["image"], Image): + ds = ds.cast_column("image", Image()) + + # A "mask" column is always a pixel map — cast unconditionally. + if "mask" in features and not isinstance(features["mask"], Image): + ds = ds.cast_column("mask", Image()) + # Cast "label" only when it looks like image data (string path or binary bytes), + # not when it is a ClassLabel or a numeric scalar. + elif "label" in features and not isinstance(features["label"], Image): + label_feat = features["label"] + if isinstance(label_feat, Value) and label_feat.dtype in ("string", "binary", "large_binary"): + ds = ds.cast_column("label", Image()) + + self._hf_dataset[split_name] = ds def split(self, val_size: float = None,