diff --git a/notebooks/automl_pipeline_tutorial.ipynb b/notebooks/automl_pipeline_tutorial.ipynb
index 2802bab..6eb74b1 100644
--- a/notebooks/automl_pipeline_tutorial.ipynb
+++ b/notebooks/automl_pipeline_tutorial.ipynb
@@ -246,7 +246,7 @@
    "source": [
    "# Dataset\n",
    "dataset_name = \"Digits2D\"\n",
-   "fold_id = 0\n",
+   "split_id = 0\n",
    "batch_size = 32\n",
    "random_state = 0\n",
    "num_workers = 0\n",
@@ -284,7 +284,7 @@
    "source": [
    "datamodule = DataModule.from_dataset_name(\n",
    "    dataset_name,\n",
-   "    fold_id=fold_id,\n",
+   "    split_id=split_id,\n",
    "    batch_size=batch_size,\n",
    "    random_state=random_state,\n",
    "    num_workers=num_workers,\n",
diff --git a/notebooks/ligthning_pipeline_tutorial.ipynb b/notebooks/ligthning_pipeline_tutorial.ipynb
index dacad60..90c25a9 100644
--- a/notebooks/ligthning_pipeline_tutorial.ipynb
+++ b/notebooks/ligthning_pipeline_tutorial.ipynb
@@ -232,7 +232,7 @@
    "source": [
    "# Dataset\n",
    "dataset_name = \"Digits2D\"\n",
-   "fold_id = 0\n",
+   "split_id = 0\n",
    "batch_size = 32\n",
    "random_state = 0\n",
    "num_workers = 0\n",
@@ -265,7 +265,7 @@
    "source": [
    "datamodule = DataModule.from_dataset_name(\n",
    "    dataset_name,\n",
-   "    fold_id=fold_id,\n",
+   "    split_id=split_id,\n",
    "    batch_size=batch_size,\n",
    "    random_state=random_state,\n",
    "    num_workers=num_workers,\n",
diff --git a/src/matchcake_opt/datamodules/datamodule.py b/src/matchcake_opt/datamodules/datamodule.py
index f1d2032..1d97125 100644
--- a/src/matchcake_opt/datamodules/datamodule.py
+++ b/src/matchcake_opt/datamodules/datamodule.py
@@ -10,8 +10,17 @@
 
 
 class DataModule(lightning.LightningDataModule):
+    """
+    Handles the loading, splitting, and management of datasets used for training,
+    validation, and testing in a PyTorch Lightning workflow.
+
+    The `DataModule` class provides a standardized interface for working with
+    datasets, including handling train/validation splits, data loaders, and other
+    dataset-specific configurations.
+    """
+
     DEFAULT_RANDOM_STATE = 0
-    N_FOLDS = 5
+    DEFAULT_TRAIN_VAL_SPLIT = 0.85
     DEFAULT_BATCH_SIZE = 32
     DEFAULT_NUM_WORKERS = min(2, psutil.cpu_count(logical=True) - 1)
@@ -19,7 +28,9 @@ class DataModule(lightning.LightningDataModule):
     def from_dataset_name(
         cls,
         dataset_name: str,
-        fold_id: int,
+        split_id: int,
+        *,
+        train_val_split: float = DEFAULT_TRAIN_VAL_SPLIT,
         batch_size: int = DEFAULT_BATCH_SIZE,
         random_state: int = DEFAULT_RANDOM_STATE,
         num_workers: int = DEFAULT_NUM_WORKERS,
@@ -29,7 +40,8 @@ def from_dataset_name(
         return cls(
             train_dataset=get_dataset_cls_by_name(dataset_name)(train=True),
             test_dataset=get_dataset_cls_by_name(dataset_name)(train=False),
-            fold_id=fold_id,
+            split_id=split_id,
+            train_val_split=train_val_split,
             batch_size=batch_size,
             random_state=random_state,
             num_workers=num_workers,
@@ -50,21 +62,48 @@
         self,
         train_dataset: BaseDataset,
         test_dataset: BaseDataset,
-        fold_id: int,
+        split_id: int,
+        *,
+        train_val_split: float = DEFAULT_TRAIN_VAL_SPLIT,
         batch_size: int = DEFAULT_BATCH_SIZE,
         random_state: int = DEFAULT_RANDOM_STATE,
         num_workers: int = DEFAULT_NUM_WORKERS,
     ):
+        """
+        Initializes the class with the provided training and testing datasets, split
+        information, and relevant parameters.
+
+        :param train_dataset: The dataset to be used for training the model.
+        :type train_dataset: BaseDataset
+        :param test_dataset: The dataset to be used for testing the model.
+        :type test_dataset: BaseDataset
+        :param split_id: Identifier of the dataset split; seeds the random train/validation split.
+        :type split_id: int
+        :param train_val_split: Proportion of the training dataset kept for training
+            (the remainder is used for validation). Defaults to DEFAULT_TRAIN_VAL_SPLIT. Must be between 0 and 1.
+        :type train_val_split: float, optional
+        :param random_state: Random seed kept for reproducibility; note that the
+            train/validation split itself is seeded by split_id. Defaults to DEFAULT_RANDOM_STATE.
+        :type random_state: int, optional
+        :param batch_size: The size of each batch used during data loading.
+            Defaults to DEFAULT_BATCH_SIZE. Must be a positive integer.
+        :type batch_size: int, optional
+        :param num_workers: Number of workers to use for data loading.
+            Defaults to DEFAULT_NUM_WORKERS.
+        :type num_workers: int, optional
+        """
         super().__init__()
         assert batch_size > 0, f"Batch size must be positive, got {batch_size}"
+        assert train_val_split > 0, f"Train split must be positive, got {train_val_split}"
+        assert train_val_split <= 1, f"Train split must be at most 1, got {train_val_split}"
+        self._train_val_split = train_val_split
         self._batch_size = batch_size
         self._random_state = random_state
-        assert 0 <= fold_id < self.N_FOLDS, f"Fold id {fold_id} is out of range [0, {self.N_FOLDS})"
-        self._fold_id = fold_id
+        self._split_id = split_id
         self._given_train_dataset = train_dataset
         self._test_dataset = test_dataset
         self._num_workers = num_workers
-        self._train_dataset: Optional[ConcatDataset] = None
+        self._train_dataset: Optional[Subset] = None
         self._val_dataset: Optional[Subset] = None
@@ -73,17 +112,11 @@
     def prepare_data(self) -> None:
         self._train_dataset, self._val_dataset = self._split_train_val_dataset(self._given_train_dataset)
         return
 
-    def _split_train_val_dataset(self, dataset: Dataset) -> Tuple[Any, Any]:
-        fold_ratio = 1 / self.N_FOLDS
-        subsets = random_split(
+    def _split_train_val_dataset(self, dataset: Dataset) -> Tuple[Subset, Subset]:
+        train_subset, val_subset = random_split(
             dataset,
-            lengths=[fold_ratio for _ in range(self.N_FOLDS)],
-            generator=torch.Generator().manual_seed(self._random_state),
-        )
-        val_subset = subsets[self._fold_id]
-        train_subset_indexes = [i for i in range(self.N_FOLDS) if i != self._fold_id]
-        train_subset: torch.utils.data.Dataset = torch.utils.data.ConcatDataset(
-            [subsets[i] for i in train_subset_indexes]
+            lengths=[self._train_val_split, 1 - self._train_val_split],
+            generator=torch.Generator().manual_seed(self._split_id),
         )
         return train_subset, val_subset
@@ -124,7 +157,7 @@ def output_shape(self):
         return self.test_dataset.get_output_shape()
 
     @property
-    def train_dataset(self) -> Optional[ConcatDataset]:
+    def train_dataset(self) -> Optional[Subset]:
         return self._train_dataset
 
     @property
diff --git a/src/matchcake_opt/datamodules/maxcut_datamodule.py b/src/matchcake_opt/datamodules/maxcut_datamodule.py
index 220047e..4becea8 100644
--- a/src/matchcake_opt/datamodules/maxcut_datamodule.py
+++ b/src/matchcake_opt/datamodules/maxcut_datamodule.py
@@ -20,10 +20,12 @@ def add_specific_args(cls, parent_parser: Optional[argparse.ArgumentParser] = No
     def from_dataset_name(
         cls,
         dataset_name: str,
-        fold_id: int,
-        batch_size: int = 0,
-        random_state: int = 0,
-        num_workers: int = 0,
+        split_id: int,
+        *,
+        train_val_split: float = DataModule.DEFAULT_TRAIN_VAL_SPLIT,
+        batch_size: int = DataModule.DEFAULT_BATCH_SIZE,
+        random_state: int = DataModule.DEFAULT_RANDOM_STATE,
+        num_workers: int = DataModule.DEFAULT_NUM_WORKERS,
     ) -> "DataModule":
         raise NotImplementedError("MaxcutDataModule does not support from_dataset_name method.")  # pragma: no cover
 
@@ -38,7 +40,7 @@ def __init__(
         super().__init__(
             train_dataset=train_dataset,
             test_dataset=test_dataset,
-            fold_id=0,
+            split_id=0,
             batch_size=1,
             random_state=0,
             num_workers=0,