4 changes: 2 additions & 2 deletions notebooks/automl_pipeline_tutorial.ipynb
@@ -246,7 +246,7 @@
"source": [
"# Dataset\n",
"dataset_name = \"Digits2D\"\n",
"fold_id = 0\n",
"split_id = 0\n",
"batch_size = 32\n",
"random_state = 0\n",
"num_workers = 0\n",
@@ -284,7 +284,7 @@
"source": [
"datamodule = DataModule.from_dataset_name(\n",
" dataset_name,\n",
" fold_id=fold_id,\n",
" split_id=split_id,\n",
" batch_size=batch_size,\n",
" random_state=random_state,\n",
" num_workers=num_workers,\n",
4 changes: 2 additions & 2 deletions notebooks/ligthning_pipeline_tutorial.ipynb
@@ -232,7 +232,7 @@
"source": [
"# Dataset\n",
"dataset_name = \"Digits2D\"\n",
"fold_id = 0\n",
"split_id = 0\n",
"batch_size = 32\n",
"random_state = 0\n",
"num_workers = 0\n",
@@ -265,7 +265,7 @@
"source": [
"datamodule = DataModule.from_dataset_name(\n",
" dataset_name,\n",
" fold_id=fold_id,\n",
" split_id=split_id,\n",
" batch_size=batch_size,\n",
" random_state=random_state,\n",
" num_workers=num_workers,\n",
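Both notebooks now pass `split_id` where they previously passed `fold_id`. A minimal sketch of the updated cells, assuming the defaults shown in the diff and that the module is importable as `matchcake_opt.datamodules.datamodule` (import path inferred from the `src/` layout):

```python
# Sketch of the updated tutorial cells; values mirror the defaults in the diff.
from matchcake_opt.datamodules.datamodule import DataModule  # path assumed from src/ layout

# Dataset
dataset_name = "Digits2D"
split_id = 0        # replaces the former fold_id
batch_size = 32
random_state = 0
num_workers = 0

datamodule = DataModule.from_dataset_name(
    dataset_name,
    split_id=split_id,
    batch_size=batch_size,
    random_state=random_state,
    num_workers=num_workers,
)
```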
69 changes: 51 additions & 18 deletions src/matchcake_opt/datamodules/datamodule.py
@@ -10,16 +10,27 @@


class DataModule(lightning.LightningDataModule):
"""
Handles the loading, splitting, and management of datasets used for training,
validation, and testing in a PyTorch Lightning workflow.

The `DataModule` class provides a standardized interface for working with
datasets, including handling train/validation splits, data loaders, and other
dataset-specific configurations.
"""

DEFAULT_RANDOM_STATE = 0
N_FOLDS = 5
DEFAULT_TRAIN_VAL_SPLIT = 0.85
DEFAULT_BATCH_SIZE = 32
DEFAULT_NUM_WORKERS = min(2, psutil.cpu_count(logical=True) - 1)

@classmethod
def from_dataset_name(
cls,
dataset_name: str,
fold_id: int,
split_id: int,
*,
train_val_split: float = DEFAULT_TRAIN_VAL_SPLIT,
batch_size: int = DEFAULT_BATCH_SIZE,
random_state: int = DEFAULT_RANDOM_STATE,
num_workers: int = DEFAULT_NUM_WORKERS,
@@ -29,7 +40,8 @@ def from_dataset_name(
return cls(
train_dataset=get_dataset_cls_by_name(dataset_name)(train=True),
test_dataset=get_dataset_cls_by_name(dataset_name)(train=False),
fold_id=fold_id,
split_id=split_id,
train_val_split=train_val_split,
batch_size=batch_size,
random_state=random_state,
num_workers=num_workers,
@@ -50,21 +62,48 @@ def __init__(
self,
train_dataset: BaseDataset,
test_dataset: BaseDataset,
fold_id: int,
split_id: int,
*,
train_val_split: float = DEFAULT_TRAIN_VAL_SPLIT,
batch_size: int = DEFAULT_BATCH_SIZE,
random_state: int = DEFAULT_RANDOM_STATE,
num_workers: int = DEFAULT_NUM_WORKERS,
):
"""
Initializes the class with the provided training and testing datasets, split
information, and relevant parameters.

:param train_dataset: The dataset to be used for training the model.
:type train_dataset: BaseDataset
:param test_dataset: The dataset to be used for testing the model.
:type test_dataset: BaseDataset
:param split_id: Identifier of the train/validation split. It seeds the random
split, so different values produce different partitions of the training data.
:type split_id: int
:param train_val_split: Proportion of the training dataset kept for training;
the remaining fraction is used for validation. Defaults to DEFAULT_TRAIN_VAL_SPLIT.
Must be between 0 and 1.
:type train_val_split: float, optional
:param batch_size: The size of each batch used during data loading.
Defaults to DEFAULT_BATCH_SIZE. Must be a positive integer.
:type batch_size: int, optional
:param random_state: Determines the randomness for reproducibility during
dataset splitting. Defaults to DEFAULT_RANDOM_STATE.
:type random_state: int, optional
:param num_workers: Number of workers to use for data loading.
Defaults to DEFAULT_NUM_WORKERS.
:type num_workers: int, optional
"""
super().__init__()
assert batch_size > 0, f"Batch size must be positive, got {batch_size}"
assert train_val_split > 0, f"Train split must be positive, got {train_val_split}"
assert train_val_split <= 1, f"Train split must be at most 1, got {train_val_split}"
self._train_val_split = train_val_split
self._batch_size = batch_size
self._random_state = random_state
assert 0 <= fold_id < self.N_FOLDS, f"Fold id {fold_id} is out of range [0, {self.N_FOLDS})"
self._fold_id = fold_id
self._split_id = split_id
self._given_train_dataset = train_dataset
self._test_dataset = test_dataset
self._num_workers = num_workers
self._train_dataset: Optional[ConcatDataset] = None
self._train_dataset: Optional[Subset] = None
self._val_dataset: Optional[Subset] = None

def prepare_data(self) -> None:
@@ -73,17 +112,11 @@ def prepare_data(self) -> None:
self._train_dataset, self._val_dataset = self._split_train_val_dataset(self._given_train_dataset)
return

def _split_train_val_dataset(self, dataset: Dataset) -> Tuple[Any, Any]:
fold_ratio = 1 / self.N_FOLDS
subsets = random_split(
def _split_train_val_dataset(self, dataset: Dataset) -> Tuple[Subset, Subset]:
train_subset, val_subset = random_split(
dataset,
lengths=[fold_ratio for _ in range(self.N_FOLDS)],
generator=torch.Generator().manual_seed(self._random_state),
)
val_subset = subsets[self._fold_id]
train_subset_indexes = [i for i in range(self.N_FOLDS) if i != self._fold_id]
train_subset: torch.utils.data.Dataset = torch.utils.data.ConcatDataset(
[subsets[i] for i in train_subset_indexes]
lengths=[self._train_val_split, 1 - self._train_val_split],
generator=torch.Generator().manual_seed(self._split_id),
)
return train_subset, val_subset
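The reworked `_split_train_val_dataset` relies on `torch.utils.data.random_split` accepting fractional lengths that sum to 1 (available in recent PyTorch releases) and seeds the generator with `split_id`, so the same id reproduces the same partition. A standalone sketch of that behaviour, with a toy dataset standing in for the real one:

```python
# Standalone sketch of the ratio-based split; a toy TensorDataset stands in for the real dataset.
import torch
from torch.utils.data import TensorDataset, random_split

dataset = TensorDataset(torch.arange(100))
train_val_split, split_id = 0.85, 0  # defaults used in the diff

train_subset, val_subset = random_split(
    dataset,
    lengths=[train_val_split, 1 - train_val_split],      # fractions summing to 1
    generator=torch.Generator().manual_seed(split_id),   # same split_id -> same partition
)
print(len(train_subset), len(val_subset))  # 85 15
```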

@@ -124,7 +157,7 @@ def output_shape(self):
return self.test_dataset.get_output_shape()

@property
def train_dataset(self) -> Optional[ConcatDataset]:
def train_dataset(self) -> Optional[Subset]:
return self._train_dataset

@property
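For reference, a sketch of constructing the module directly via the documented `__init__` (everything after `split_id` is keyword-only; `get_dataset_cls_by_name` and the `Digits2D` dataset class are assumed to come from this repository, with the import path guessed):

```python
# Direct construction with the new signature; illustrative only, not taken from the PR.
from matchcake_opt.datamodules.datamodule import DataModule
from matchcake_opt.datasets import get_dataset_cls_by_name  # import path assumed

dataset_cls = get_dataset_cls_by_name("Digits2D")
datamodule = DataModule(
    train_dataset=dataset_cls(train=True),
    test_dataset=dataset_cls(train=False),
    split_id=0,
    train_val_split=0.85,  # fraction kept for training; the rest goes to validation
    batch_size=32,
)
datamodule.prepare_data()  # splits the training data into train/val subsets
```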
12 changes: 7 additions & 5 deletions src/matchcake_opt/datamodules/maxcut_datamodule.py
@@ -20,10 +20,12 @@ def add_specific_args(cls, parent_parser: Optional[argparse.ArgumentParser] = No
def from_dataset_name(
cls,
dataset_name: str,
fold_id: int,
batch_size: int = 0,
random_state: int = 0,
num_workers: int = 0,
split_id: int,
*,
train_val_split: float = DataModule.DEFAULT_TRAIN_VAL_SPLIT,
batch_size: int = DataModule.DEFAULT_BATCH_SIZE,
random_state: int = DataModule.DEFAULT_RANDOM_STATE,
num_workers: int = DataModule.DEFAULT_NUM_WORKERS,
) -> "DataModule":
raise NotImplementedError("MaxcutDataModule does not support from_dataset_name method.") # pragma: no cover

@@ -38,7 +40,7 @@ def __init__(
super().__init__(
train_dataset=train_dataset,
test_dataset=test_dataset,
fold_id=0,
split_id=0,
batch_size=1,
random_state=0,
num_workers=0,