From f4cc85d5b0188340bf80233f32957d8cacbb1c69 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20Gince?= <50332514+JeremieGince@users.noreply.github.com> Date: Tue, 10 Feb 2026 11:34:32 -0500 Subject: [PATCH 1/3] Use train/val split instead of k-folds Replace N_FOLDS/fold_id-based k-fold splitting with a fixed train/validation split. Add DEFAULT_TRAIN_VAL_SPLIT and new split_id parameter (used as RNG seed) and expose train_val_split in from_dataset_name and the DataModule constructor with validation assertions. _split_train_val_dataset now uses random_split into [train, val] lengths instead of concatenating k-fold subsets. Update type hint for train_dataset and adjust MaxcutDataModule to pass split_id. Remove N_FOLDS constant. --- src/matchcake_opt/datamodules/datamodule.py | 33 ++++++++++--------- .../datamodules/maxcut_datamodule.py | 2 +- 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/src/matchcake_opt/datamodules/datamodule.py b/src/matchcake_opt/datamodules/datamodule.py index f1d2032..f91a3b5 100644 --- a/src/matchcake_opt/datamodules/datamodule.py +++ b/src/matchcake_opt/datamodules/datamodule.py @@ -11,7 +11,7 @@ class DataModule(lightning.LightningDataModule): DEFAULT_RANDOM_STATE = 0 - N_FOLDS = 5 + DEFAULT_TRAIN_VAL_SPLIT = 0.85 DEFAULT_BATCH_SIZE = 32 DEFAULT_NUM_WORKERS = min(2, psutil.cpu_count(logical=True) - 1) @@ -19,7 +19,9 @@ class DataModule(lightning.LightningDataModule): def from_dataset_name( cls, dataset_name: str, - fold_id: int, + split_id: int, + *, + train_val_split: float = DEFAULT_TRAIN_VAL_SPLIT, batch_size: int = DEFAULT_BATCH_SIZE, random_state: int = DEFAULT_RANDOM_STATE, num_workers: int = DEFAULT_NUM_WORKERS, @@ -29,7 +31,8 @@ def from_dataset_name( return cls( train_dataset=get_dataset_cls_by_name(dataset_name)(train=True), test_dataset=get_dataset_cls_by_name(dataset_name)(train=False), - fold_id=fold_id, + split_id=split_id, + train_val_split=train_val_split, batch_size=batch_size, 
random_state=random_state, num_workers=num_workers, @@ -50,17 +53,21 @@ def __init__( self, train_dataset: BaseDataset, test_dataset: BaseDataset, - fold_id: int, + split_id: int, + *, + train_val_split: float = DEFAULT_TRAIN_VAL_SPLIT, batch_size: int = DEFAULT_BATCH_SIZE, random_state: int = DEFAULT_RANDOM_STATE, num_workers: int = DEFAULT_NUM_WORKERS, ): super().__init__() assert batch_size > 0, f"Batch size must be positive, got {batch_size}" + assert train_val_split > 0, f"Train split must be positive, got {train_val_split}" + assert train_val_split <= 1, f"Train split must be at most 1, got {train_val_split}" + self._train_val_split = train_val_split self._batch_size = batch_size self._random_state = random_state - assert 0 <= fold_id < self.N_FOLDS, f"Fold id {fold_id} is out of range [0, {self.N_FOLDS})" - self._fold_id = fold_id + self._split_id = split_id self._given_train_dataset = train_dataset self._test_dataset = test_dataset self._num_workers = num_workers @@ -74,16 +81,10 @@ def prepare_data(self) -> None: return def _split_train_val_dataset(self, dataset: Dataset) -> Tuple[Any, Any]: - fold_ratio = 1 / self.N_FOLDS - subsets = random_split( + train_subset, val_subset = random_split( dataset, - lengths=[fold_ratio for _ in range(self.N_FOLDS)], - generator=torch.Generator().manual_seed(self._random_state), - ) - val_subset = subsets[self._fold_id] - train_subset_indexes = [i for i in range(self.N_FOLDS) if i != self._fold_id] - train_subset: torch.utils.data.Dataset = torch.utils.data.ConcatDataset( - [subsets[i] for i in train_subset_indexes] + lengths=[self._train_val_split, 1 - self._train_val_split], + generator=torch.Generator().manual_seed(self._split_id), ) return train_subset, val_subset @@ -124,7 +125,7 @@ def output_shape(self): return self.test_dataset.get_output_shape() @property - def train_dataset(self) -> Optional[ConcatDataset]: + def train_dataset(self) -> Optional[Subset]: return self._train_dataset @property diff --git 
a/src/matchcake_opt/datamodules/maxcut_datamodule.py b/src/matchcake_opt/datamodules/maxcut_datamodule.py index 220047e..34d5582 100644 --- a/src/matchcake_opt/datamodules/maxcut_datamodule.py +++ b/src/matchcake_opt/datamodules/maxcut_datamodule.py @@ -38,7 +38,7 @@ def __init__( super().__init__( train_dataset=train_dataset, test_dataset=test_dataset, - fold_id=0, + split_id=0, batch_size=1, random_state=0, num_workers=0, From 6a1a62c92195a0488122fd0f28e1952428b65b7c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20Gince?= <50332514+JeremieGince@users.noreply.github.com> Date: Tue, 10 Feb 2026 11:40:29 -0500 Subject: [PATCH 2/3] Document DataModule and refine typing/API Add class and __init__ docstrings to DataModule to clarify responsibilities and parameters. Tighten type hints by changing _train_dataset to Optional[Subset] and making _split_train_val_dataset return Tuple[Subset, Subset] instead of generic Any. Update MaxcutDataModule.from_dataset_name signature: rename fold_id to split_id, add a train_val_split kw-only parameter, and default batch_size, random_state, and num_workers to DataModule's DEFAULT_* constants for consistent defaults and clearer API. --- src/matchcake_opt/datamodules/datamodule.py | 36 +++++++++++++++++-- .../datamodules/maxcut_datamodule.py | 10 +++--- 2 files changed, 40 insertions(+), 6 deletions(-) diff --git a/src/matchcake_opt/datamodules/datamodule.py b/src/matchcake_opt/datamodules/datamodule.py index f91a3b5..1d97125 100644 --- a/src/matchcake_opt/datamodules/datamodule.py +++ b/src/matchcake_opt/datamodules/datamodule.py @@ -10,6 +10,15 @@ class DataModule(lightning.LightningDataModule): + """ + Handles the loading, splitting, and management of datasets used for training, + validation, and testing in a PyTorch Lightning workflow. 
+ + The `DataModule` class provides a standardized interface for working with + datasets, including handling train/validation splits, data loaders, and other + dataset-specific configurations. + """ + DEFAULT_RANDOM_STATE = 0 DEFAULT_TRAIN_VAL_SPLIT = 0.85 DEFAULT_BATCH_SIZE = 32 @@ -60,6 +69,29 @@ def __init__( random_state: int = DEFAULT_RANDOM_STATE, num_workers: int = DEFAULT_NUM_WORKERS, ): + """ + Initializes the class with the provided training and testing datasets, split + information, and relevant parameters. + + :param train_dataset: The dataset to be used for training the model. + :type train_dataset: BaseDataset + :param test_dataset: The dataset to be used for testing the model. + :type test_dataset: BaseDataset + :param split_id: Seed of the random generator that draws the train/validation split. + :type split_id: int + :param train_val_split: Fraction of the training dataset kept for training; the rest is used + for validation. Defaults to DEFAULT_TRAIN_VAL_SPLIT. Must be in (0, 1]. + :type train_val_split: float, optional + :param batch_size: The size of each batch used during data loading. + Defaults to DEFAULT_BATCH_SIZE. Must be a positive integer. + :type batch_size: int, optional + :param random_state: Random state stored on the data module; the train/validation + split itself is seeded by split_id. Defaults to DEFAULT_RANDOM_STATE. + :type random_state: int, optional + :param num_workers: Number of workers to use for data loading. + Defaults to DEFAULT_NUM_WORKERS. 
+ :type num_workers: int, optional + """ super().__init__() assert batch_size > 0, f"Batch size must be positive, got {batch_size}" assert train_val_split > 0, f"Train split must be positive, got {train_val_split}" @@ -71,7 +103,7 @@ def __init__( self._given_train_dataset = train_dataset self._test_dataset = test_dataset self._num_workers = num_workers - self._train_dataset: Optional[ConcatDataset] = None + self._train_dataset: Optional[Subset] = None self._val_dataset: Optional[Subset] = None def prepare_data(self) -> None: @@ -80,7 +112,7 @@ def prepare_data(self) -> None: self._train_dataset, self._val_dataset = self._split_train_val_dataset(self._given_train_dataset) return - def _split_train_val_dataset(self, dataset: Dataset) -> Tuple[Any, Any]: + def _split_train_val_dataset(self, dataset: Dataset) -> Tuple[Subset, Subset]: train_subset, val_subset = random_split( dataset, lengths=[self._train_val_split, 1 - self._train_val_split], diff --git a/src/matchcake_opt/datamodules/maxcut_datamodule.py b/src/matchcake_opt/datamodules/maxcut_datamodule.py index 34d5582..4becea8 100644 --- a/src/matchcake_opt/datamodules/maxcut_datamodule.py +++ b/src/matchcake_opt/datamodules/maxcut_datamodule.py @@ -20,10 +20,12 @@ def add_specific_args(cls, parent_parser: Optional[argparse.ArgumentParser] = No def from_dataset_name( cls, dataset_name: str, - fold_id: int, - batch_size: int = 0, - random_state: int = 0, - num_workers: int = 0, + split_id: int, + *, + train_val_split: float = DataModule.DEFAULT_TRAIN_VAL_SPLIT, + batch_size: int = DataModule.DEFAULT_BATCH_SIZE, + random_state: int = DataModule.DEFAULT_RANDOM_STATE, + num_workers: int = DataModule.DEFAULT_NUM_WORKERS, ) -> "DataModule": raise NotImplementedError("MaxcutDataModule does not support from_dataset_name method.") # pragma: no cover From 1b61385f99d879cdd3e00bb3782c6eddc9bea5ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20Gince?= <50332514+JeremieGince@users.noreply.github.com> Date: Tue, 10 
Feb 2026 11:45:02 -0500 Subject: [PATCH 3/3] Use split_id instead of fold_id in notebooks Replace fold_id with split_id in automl_pipeline_tutorial.ipynb and ligthning_pipeline_tutorial.ipynb. Updated the variable declaration and the argument passed to DataModule.from_dataset_name to match the newer API that expects split_id. --- notebooks/automl_pipeline_tutorial.ipynb | 4 ++-- notebooks/ligthning_pipeline_tutorial.ipynb | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/notebooks/automl_pipeline_tutorial.ipynb b/notebooks/automl_pipeline_tutorial.ipynb index 2802bab..6eb74b1 100644 --- a/notebooks/automl_pipeline_tutorial.ipynb +++ b/notebooks/automl_pipeline_tutorial.ipynb @@ -246,7 +246,7 @@ "source": [ "# Dataset\n", "dataset_name = \"Digits2D\"\n", - "fold_id = 0\n", + "split_id = 0\n", "batch_size = 32\n", "random_state = 0\n", "num_workers = 0\n", @@ -284,7 +284,7 @@ "source": [ "datamodule = DataModule.from_dataset_name(\n", " dataset_name,\n", - " fold_id=fold_id,\n", + " split_id=split_id,\n", " batch_size=batch_size,\n", " random_state=random_state,\n", " num_workers=num_workers,\n", diff --git a/notebooks/ligthning_pipeline_tutorial.ipynb b/notebooks/ligthning_pipeline_tutorial.ipynb index dacad60..90c25a9 100644 --- a/notebooks/ligthning_pipeline_tutorial.ipynb +++ b/notebooks/ligthning_pipeline_tutorial.ipynb @@ -232,7 +232,7 @@ "source": [ "# Dataset\n", "dataset_name = \"Digits2D\"\n", - "fold_id = 0\n", + "split_id = 0\n", "batch_size = 32\n", "random_state = 0\n", "num_workers = 0\n", @@ -265,7 +265,7 @@ "source": [ "datamodule = DataModule.from_dataset_name(\n", " dataset_name,\n", - " fold_id=fold_id,\n", + " split_id=split_id,\n", " batch_size=batch_size,\n", " random_state=random_state,\n", " num_workers=num_workers,\n",