Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
0d59a44
Add Maxcut dataset and model with torch-geometric support
JeremieGince Oct 17, 2025
dde02ae
Fix import paths for DataModule references
JeremieGince Oct 17, 2025
077ddcd
Update dataset import in notebooks
JeremieGince Oct 17, 2025
0a3e550
Raise MisconfigurationException in val_dataloader
JeremieGince Oct 17, 2025
dadf22e
Implement training, validation, and test steps in MaxcutModel
JeremieGince Oct 17, 2025
75066dd
Refactor predict method and validation metrics handling
JeremieGince Oct 17, 2025
24396dd
Refactor test_step to return metrics components
JeremieGince Oct 17, 2025
fb46bd5
Add bitstrings_to_arr utility to MaxcutModel
JeremieGince Oct 17, 2025
b77c0b7
Refactor MaxcutModel metrics and remove unused methods
JeremieGince Oct 17, 2025
c70cd45
Refactor dataset preparation and data module workflow
JeremieGince Oct 18, 2025
2a5bfa7
Add support for circular graphs in MaxcutDataset
JeremieGince Oct 18, 2025
48fc255
Update maxcut_dataset.py
JeremieGince Oct 18, 2025
bcd5a56
Add ckpt_path parameter to run_test method
JeremieGince Oct 18, 2025
79bab81
Implement proper val_dataloader and type annotations
JeremieGince Oct 19, 2025
e1daf86
Fallback to 'last' checkpoint if 'best' is unavailable
JeremieGince Oct 20, 2025
b1400e1
Handle SearchSpaceExhausted in AutoML pipeline
JeremieGince Oct 20, 2025
5f27a8e
Add checkpoint folder cleanup on overwrite fit
JeremieGince Oct 20, 2025
7686d85
Add CUDA 13.0 (cu130) support to dependencies
JeremieGince Oct 29, 2025
42cd977
Refactor datamodule imports and add MaxcutDataset tests
JeremieGince Nov 2, 2025
bd055d8
Refactor and expand Maxcut dataset tests
JeremieGince Nov 2, 2025
7608a49
Add tests for MaxcutDataModule and MaxcutModel
JeremieGince Nov 2, 2025
99dc64a
Reduce max_time for training in tutorial notebooks
JeremieGince Nov 2, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -305,3 +305,4 @@ coverage.json
/notebooks/Digits2D
/notebooks/Cifar10
/notebooks/data
/.tmp
4 changes: 2 additions & 2 deletions notebooks/automl_pipeline_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
")\n",
"from torchvision.transforms import Resize\n",
"\n",
"from matchcake_opt.datasets import *\n",
"from matchcake_opt.datamodules.datamodule import DataModule\n",
"from matchcake_opt.modules.classification_model import ClassificationModel\n",
"from matchcake_opt.tr_pipeline.automl_pipeline import AutoMLPipeline"
],
Expand Down Expand Up @@ -259,7 +259,7 @@
"checkpoint_folder = Path(job_output_folder) / \"checkpoints\"\n",
"pipeline_args = dict(\n",
" max_epochs=100, # increase at least to 256\n",
" max_time=\"00:00:02:00\", # DD:HH:MM:SS, increase at least to \"00:01:00:00\"\n",
" max_time=\"00:00:01:00\", # DD:HH:MM:SS, increase at least to \"00:01:00:00\"\n",
")"
],
"id": "d8db16a0825411",
Expand Down
4 changes: 2 additions & 2 deletions notebooks/ligthning_pipeline_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
")\n",
"from torchvision.transforms import Resize\n",
"\n",
"from matchcake_opt.datasets import *\n",
"from matchcake_opt.datamodules.datamodule import DataModule\n",
"from matchcake_opt.modules.classification_model import ClassificationModel\n",
"from matchcake_opt.tr_pipeline.lightning_pipeline import LightningPipeline"
],
Expand Down Expand Up @@ -275,7 +275,7 @@
" datamodule=datamodule,\n",
" checkpoint_folder=checkpoint_folder,\n",
" max_epochs=10,\n",
" max_time=\"00:00:03:00\", # DD:HH:MM:SS\n",
" max_time=\"00:00:01:00\", # DD:HH:MM:SS\n",
" overwrite_fit=True,\n",
" verbose=True,\n",
" **model_args,\n",
Expand Down
6 changes: 3 additions & 3 deletions notebooks/nif_deep_learning.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
"from matchcake import NonInteractingFermionicDevice\n",
"from matchcake.operations import SptmAngleEmbedding, SptmfRxRx, SptmFHH\n",
"\n",
"from matchcake_opt.datasets import *\n",
"from matchcake_opt.datamodules.datamodule import DataModule\n",
"from matchcake_opt.modules.classification_model import ClassificationModel\n",
"from matchcake_opt.tr_pipeline.automl_pipeline import AutoMLPipeline\n",
"from matchcake_opt.tr_pipeline.lightning_pipeline import LightningPipeline"
Expand Down Expand Up @@ -157,7 +157,7 @@
"checkpoint_folder = Path(job_output_folder) / \"checkpoints\"\n",
"pipeline_args = dict(\n",
" max_epochs=128, # increase at least to 256\n",
" max_time=\"00:00:02:00\", # DD:HH:MM:SS, increase at least to \"00:01:00:00\"\n",
" max_time=\"00:00:01:00\", # DD:HH:MM:SS, increase at least to \"00:01:00:00\"\n",
")"
],
"id": "412328c44c55e453",
Expand Down Expand Up @@ -211,7 +211,7 @@
" datamodule=datamodule,\n",
" checkpoint_folder=checkpoint_folder,\n",
" max_epochs=10,\n",
" max_time=\"00:00:03:00\", # DD:HH:MM:SS\n",
" max_time=\"00:00:01:00\", # DD:HH:MM:SS\n",
" overwrite_fit=True,\n",
" verbose=True,\n",
" **model_args,\n",
Expand Down
20 changes: 19 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ dependencies = [
"psutil>=5.9.6",
"importlib-metadata (>=8.7.0,<9.0.0)",
"torch (>=2.6.0,<3.0.0)",
"torchvision (>=0.21.0,<0.23.0)",
"torchvision (>=0.21.0)",
"torchaudio (>=2.6.0,<3.0.0)",
"lightning (>=2.5.2,<3.0.0)",
"tensorboardx (>=2.6.4,<3.0.0)",
Expand All @@ -32,6 +32,7 @@ dependencies = [
"matchcake (>=0.0.4,<0.0.5)",
"autoray (<=0.7.2)",
"medmnist (>=3.0.2,<4.0.0)",
"torch-geometric>=2.7.0",
]
dynamic = ["readme"]

Expand All @@ -49,6 +50,7 @@ dev = [
"twine>=6.1.0,<7",
"pytest-xdist>=3.7.0,<4",
"isort>=6.0.1,<7",
"types-networkx>=3.5.0.20251001",
]
docs = [
"sphinx>=6.2.1,<6.3.0",
Expand Down Expand Up @@ -81,6 +83,7 @@ conflicts = [
[
{ extra = "cpu" },
{ extra = "cu128" },
{ extra = "cu130" },
],
]

Expand All @@ -93,15 +96,21 @@ cu128 = [
"torch>=2.7.0",
"torchvision>=0.22.0",
]
cu130 = [
"torch>=2.7.0",
"torchvision>=0.23.0",
]

[tool.uv.sources]
torch = [
{ index = "pytorch-cpu", extra = "cpu" },
{ index = "pytorch-cu128", extra = "cu128" },
{ index = "pytorch-cu130", extra = "cu130" },
]
torchvision = [
{ index = "pytorch-cpu", extra = "cpu" },
{ index = "pytorch-cu128", extra = "cu128" },
{ index = "pytorch-cu130", extra = "cu130" },
]

[[tool.uv.index]]
Expand All @@ -114,6 +123,11 @@ name = "pytorch-cu128"
url = "https://download.pytorch.org/whl/cu128"
explicit = true

[[tool.uv.index]]
name = "pytorch-cu130"
url = "https://download.pytorch.org/whl/cu130"
explicit = true

[tool.setuptools.dynamic]
readme = {file = "README.md", content-type = "text/markdown"}

Expand Down Expand Up @@ -165,6 +179,10 @@ module = [
"pandas",
"psutil",
"matchcake.utils",
"torch_geometric.data",
"torch_geometric.utils",
"matchcake.utils.torch_utils",
"torch_geometric.loader",
]
ignore_missing_imports = true

Expand Down
3 changes: 3 additions & 0 deletions src/matchcake_opt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,6 @@

warnings.filterwarnings("ignore", category=Warning, module="docutils")
warnings.filterwarnings("ignore", category=Warning, module="sphinx")

from .datamodules import DataModule
from .datasets import get_dataset_cls_by_name
1 change: 1 addition & 0 deletions src/matchcake_opt/datamodules/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .datamodule import DataModule
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import argparse
from typing import Optional
from typing import Any, Optional, Tuple

import lightning
import psutil
import torch
from torch.utils.data import ConcatDataset, DataLoader, Dataset, Subset, random_split

from .base_dataset import BaseDataset
from ..datasets.base_dataset import BaseDataset


class DataModule(lightning.LightningDataModule):
Expand All @@ -24,7 +24,7 @@ def from_dataset_name(
random_state: int = DEFAULT_RANDOM_STATE,
num_workers: int = DEFAULT_NUM_WORKERS,
) -> "DataModule":
from . import get_dataset_cls_by_name
from ..datasets import get_dataset_cls_by_name

return cls(
train_dataset=get_dataset_cls_by_name(dataset_name)(train=True),
Expand Down Expand Up @@ -61,11 +61,19 @@ def __init__(
self._random_state = random_state
assert 0 <= fold_id < self.N_FOLDS, f"Fold id {fold_id} is out of range [0, {self.N_FOLDS})"
self._fold_id = fold_id
self._train_dataset, self._val_dataset = self._split_train_val_dataset(train_dataset)
self._given_train_dataset = train_dataset
self._test_dataset = test_dataset
self._num_workers = num_workers
self._train_dataset: Optional[ConcatDataset] = None
self._val_dataset: Optional[Subset] = None

def _split_train_val_dataset(self, dataset: Dataset):
def prepare_data(self) -> None:
self._given_train_dataset.prepare_data()
self._test_dataset.prepare_data()
self._train_dataset, self._val_dataset = self._split_train_val_dataset(self._given_train_dataset)
return

def _split_train_val_dataset(self, dataset: Dataset) -> Tuple[Any, Any]:
fold_ratio = 1 / self.N_FOLDS
subsets = random_split(
dataset,
Expand Down Expand Up @@ -116,11 +124,11 @@ def output_shape(self):
return self.test_dataset.get_output_shape()

@property
def train_dataset(self) -> ConcatDataset:
def train_dataset(self) -> Optional[ConcatDataset]:
return self._train_dataset

@property
def val_dataset(self) -> Subset:
def val_dataset(self) -> Optional[Subset]:
return self._val_dataset

@property
Expand Down
97 changes: 97 additions & 0 deletions src/matchcake_opt/datamodules/maxcut_datamodule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import argparse
from copy import deepcopy
from typing import Optional

from torch_geometric.loader import DataLoader

from ..datasets.maxcut_dataset import MaxcutDataset
from .datamodule import DataModule


class MaxcutDataModule(DataModule):
    """Data module serving Maxcut graph datasets through torch-geometric loaders.

    Unlike the generic :class:`DataModule`, no train/validation split is
    performed: the full training dataset is reused as the validation set, and
    when no test dataset is supplied a deep copy of the training dataset
    (flagged as a non-training split) is used for testing.
    """

    @classmethod
    def add_specific_args(cls, parent_parser: Optional[argparse.ArgumentParser] = None) -> argparse.ArgumentParser:
        """Return ``parent_parser`` extended with this datamodule's argument group.

        The group currently registers no arguments; it is created so tooling
        sees a consistent CLI structure across datamodules.
        """
        if parent_parser is None:
            parent_parser = argparse.ArgumentParser()
        # Previously the group was bound to an unused local; the binding is dropped.
        parent_parser.add_argument_group(f"{cls.__name__} Arguments")
        return parent_parser

    @classmethod
    def from_dataset_name(
        cls,
        dataset_name: str,
        fold_id: int,
        batch_size: int = 0,
        random_state: int = 0,
        num_workers: int = 0,
    ) -> "DataModule":
        """Unsupported: Maxcut datasets must be constructed and passed explicitly."""
        raise NotImplementedError("MaxcutDataModule does not support from_dataset_name method.")  # pragma: no cover

    def __init__(
        self,
        train_dataset: MaxcutDataset,
        test_dataset: Optional[MaxcutDataset] = None,
        batch_size: int = 1,
    ):
        """Build the datamodule.

        :param train_dataset: dataset used for training (and validation).
        :param test_dataset: dataset used for testing; defaults to a deep copy
            of ``train_dataset`` marked as a test split.
        :param batch_size: graphs per batch. Defaults to 1, the value that was
            previously hard-coded, so existing callers are unaffected.
        """
        if test_dataset is None:
            test_dataset = deepcopy(train_dataset)
            # BUGFIX: mark the *copy* as the test split. The original code set
            # ``train_dataset.train = False``, which flipped the training
            # dataset into test mode while leaving the copy in train mode.
            test_dataset.train = False
        super().__init__(
            train_dataset=train_dataset,
            test_dataset=test_dataset,
            fold_id=0,
            batch_size=batch_size,
            random_state=0,
            num_workers=0,
        )

    def _split_train_val_dataset(self, dataset: MaxcutDataset):  # type: ignore
        """No split is performed: everything trains, validation reuses it."""
        return dataset, None

    def _make_dataloader(self, dataset: MaxcutDataset) -> DataLoader:
        """Create a torch-geometric DataLoader with this module's settings."""
        return DataLoader(
            dataset,
            batch_size=self._batch_size,
            shuffle=False,
            num_workers=self._num_workers,
            persistent_workers=self._num_workers > 0,
            pin_memory=True,
        )

    def train_dataloader(self):
        """Loader over the (unshuffled) training dataset."""
        return self._make_dataloader(self.train_dataset)

    def val_dataloader(self):
        """Validation reuses the training dataset — no held-out split exists."""
        return self._make_dataloader(self.train_dataset)

    def test_dataloader(self):
        """Loader over the test dataset."""
        return self._make_dataloader(self.test_dataset)

    @property
    def input_shape(self):
        """Input shape as reported by the test dataset."""
        return self.test_dataset.get_input_shape()

    @property
    def output_shape(self):
        """Output shape as reported by the test dataset."""
        return self.test_dataset.get_output_shape()

    @property
    def train_dataset(self) -> MaxcutDataset:  # type: ignore
        # Narrowed return type: the train dataset is always a MaxcutDataset here.
        return self._train_dataset  # type: ignore

    @property
    def val_dataset(self):
        # Always None for this module (see _split_train_val_dataset).
        return self._val_dataset

    @property
    def test_dataset(self) -> MaxcutDataset:  # type: ignore
        return self._test_dataset  # type: ignore
1 change: 0 additions & 1 deletion src/matchcake_opt/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from ..utils import get_all_subclasses
from .base_dataset import BaseDataset
from .cifar10_dataset import Cifar10Dataset
from .datamodule import DataModule
from .digits2d import Digits2D
from .mnist_dataset import MNISTDataset
from .pathmnist_dataset import PathMNISTDataset
Expand Down
4 changes: 4 additions & 0 deletions src/matchcake_opt/datasets/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,17 @@ def __init__(self, data_dir: Union[str, Path] = Path("./data/") / DATASET_NAME,
self._data_dir = Path(data_dir)
self._data_dir.mkdir(parents=True, exist_ok=True)
self._train = train
self._kwargs = kwargs

def __getitem__(self, item):
    """Fetch the sample at index ``item``; concrete datasets must implement this."""
    raise NotImplementedError

def __len__(self):
    """Number of samples in the dataset; concrete datasets must implement this."""
    raise NotImplementedError

def prepare_data(self):
    """One-time data download/setup hook; the base implementation is a no-op."""
    return None

def get_input_shape(self) -> tuple:
    """Infer the per-sample input shape from the first item's features."""
    first_sample_features = self[0][0]
    return tuple(first_sample_features.shape)  # pragma: no cover

Expand Down
Loading