From 307feba99233456992297f5f733356ddbdf79790 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20Gince?= <50332514+JeremieGince@users.noreply.github.com> Date: Sun, 2 Nov 2025 09:32:04 -0500 Subject: [PATCH 1/2] Add RetinaMNIST dataset and tests Introduces the RetinaMNISTDataset class for handling the RetinaMNIST dataset, including data loading, transformation, and output shape methods. Adds comprehensive unit tests to verify dataset initialization, item retrieval, tensor conversion, length, and output shape. --- .../datasets/retinamnist_dataset.py | 56 +++++++++++++++++ .../test_datasets/test_retinamnist_dataset.py | 63 +++++++++++++++++++ 2 files changed, 119 insertions(+) create mode 100644 src/matchcake_opt/datasets/retinamnist_dataset.py create mode 100644 tests/test_datasets/test_retinamnist_dataset.py diff --git a/src/matchcake_opt/datasets/retinamnist_dataset.py b/src/matchcake_opt/datasets/retinamnist_dataset.py new file mode 100644 index 0000000..8d89195 --- /dev/null +++ b/src/matchcake_opt/datasets/retinamnist_dataset.py @@ -0,0 +1,56 @@ +from pathlib import Path +from typing import Sequence, Union + +import numpy as np +import torch +from medmnist import RetinaMNIST +from torch.utils.data import ConcatDataset +from torchvision import transforms +from torchvision.transforms import ToTensor +from torchvision.transforms.v2 import Compose + +from .base_dataset import BaseDataset + + +class RetinaMNISTDataset(BaseDataset): + DATASET_NAME = "RetinaMNIST" + + @staticmethod + def to_scalar_tensor(y: Sequence[int]): + return torch.tensor(y).item() + + @staticmethod + def to_long_tensor(y: Sequence[int]) -> torch.Tensor: + return torch.tensor(y, dtype=torch.long) + + def __init__(self, data_dir: Union[str, Path] = Path("./data/") / DATASET_NAME, train: bool = True, **kwargs): + super().__init__(data_dir, train, **kwargs) + transform = Compose( + [ + ToTensor(), + transforms.Normalize(0.0, 1.0), + ] + ) + target_transform = Compose([self.to_scalar_tensor, self.to_long_tensor]) + self._data = RetinaMNIST( + root=self.data_dir, + split="train" if self.train else "test", + download=True, + transform=transform, + target_transform=target_transform, + ) + self._n_classes = np.unique(self._data.labels).size + if self.train: + val_dataset = RetinaMNIST( + root=self.data_dir, split="val", download=True, transform=transform, target_transform=target_transform + ) + self._data = ConcatDataset([self._data, val_dataset]) + + def __getitem__(self, item): + return self._data[item] + + def __len__(self): + return len(self._data) + + def get_output_shape(self) -> tuple: + return (self._n_classes,) diff --git a/tests/test_datasets/test_retinamnist_dataset.py b/tests/test_datasets/test_retinamnist_dataset.py new file mode 100644 index 0000000..29ac6a3 --- /dev/null +++ b/tests/test_datasets/test_retinamnist_dataset.py @@ -0,0 +1,63 @@ +import shutil +from pathlib import Path +from unittest.mock import MagicMock + +import numpy as np +import pytest +import torch + +from matchcake_opt.datasets.retinamnist_dataset import RetinaMNISTDataset + + +class TestRetinaMNISTDataset: + MOCK_LEN = 10 + + @pytest.fixture(scope="class") + def data_dir(self): + path = Path(".tmp") / "data_dir" / "pathmnist" + yield path + shutil.rmtree(path, ignore_errors=True) + + @pytest.fixture + def data_mock(self, monkeypatch): + cls_mock = MagicMock() + monkeypatch.setattr("matchcake_opt.datasets.pathmnist_dataset.PathMNIST", cls_mock) + mock = MagicMock() + cls_mock.return_value = mock + mock.__getitem__.return_value = (torch.zeros(28, 28), torch.zeros(1).long()) + mock.__len__.return_value = self.MOCK_LEN + mock.labels = np.arange(9) + monkeypatch.setattr("matchcake_opt.datasets.pathmnist_dataset.ConcatDataset", lambda *x: mock) + return mock + + @pytest.fixture + def dataset_instance(self, data_mock, data_dir): + return RetinaMNISTDataset(data_dir=data_dir, train=True) + + def test_init(self, data_mock, data_dir): + dataset = RetinaMNISTDataset(data_dir=data_dir, train=True) + assert dataset._data == data_mock + + def test_getitem(self, data_mock, dataset_instance): + datum = dataset_instance[0] + assert isinstance(datum, tuple) + assert isinstance(datum[0], torch.Tensor) + assert isinstance(datum[1], torch.Tensor) + assert datum[1].dtype == torch.long + data_mock.__getitem__.assert_called_once_with(0) + + def test_to_long_tensor(self, dataset_instance): + x = torch.zeros(0, dtype=torch.int) + y = dataset_instance.to_long_tensor(x) + assert y.dtype == torch.long + + def test_len(self, dataset_instance): + assert len(dataset_instance) == self.MOCK_LEN + + def test_output_shape(self, dataset_instance): + assert dataset_instance.get_output_shape() == (5,) + + def test_to_scalar_tensor(self, dataset_instance): + x = torch.tensor([1]) + y = dataset_instance.to_scalar_tensor(x) + assert isinstance(y, int) From a80737d3dd8a4fce85009da5e46438819744969f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20Gince?= <50332514+JeremieGince@users.noreply.github.com> Date: Sun, 2 Nov 2025 09:44:07 -0500 Subject: [PATCH 2/2] Update RetinaMNISTDataset import and test mocks Added RetinaMNISTDataset to the datasets module import. Updated the RetinaMNIST dataset test to use the correct dataset name, mock class, and label shape, ensuring consistency with the actual dataset implementation. --- src/matchcake_opt/datasets/__init__.py | 1 + tests/test_datasets/test_retinamnist_dataset.py | 10 +++++----- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/matchcake_opt/datasets/__init__.py b/src/matchcake_opt/datasets/__init__.py index ca9904e..a7d5d46 100644 --- a/src/matchcake_opt/datasets/__init__.py +++ b/src/matchcake_opt/datasets/__init__.py @@ -7,6 +7,7 @@ from .digits2d import Digits2D from .mnist_dataset import MNISTDataset from .pathmnist_dataset import PathMNISTDataset +from .retinamnist_dataset import RetinaMNISTDataset dataset_name_to_type_map: Dict[str, Type[BaseDataset]] = { _cls.DATASET_NAME: _cls diff --git a/tests/test_datasets/test_retinamnist_dataset.py b/tests/test_datasets/test_retinamnist_dataset.py index 29ac6a3..9dbe159 100644 --- a/tests/test_datasets/test_retinamnist_dataset.py +++ b/tests/test_datasets/test_retinamnist_dataset.py @@ -14,20 +14,20 @@ class TestRetinaMNISTDataset: @pytest.fixture(scope="class") def data_dir(self): - path = Path(".tmp") / "data_dir" / "pathmnist" + path = Path(".tmp") / "data_dir" / "retinamnist" yield path shutil.rmtree(path, ignore_errors=True) @pytest.fixture def data_mock(self, monkeypatch): cls_mock = MagicMock() - monkeypatch.setattr("matchcake_opt.datasets.pathmnist_dataset.PathMNIST", cls_mock) + monkeypatch.setattr("matchcake_opt.datasets.retinamnist_dataset.RetinaMNIST", cls_mock) mock = MagicMock() cls_mock.return_value = mock - mock.__getitem__.return_value = (torch.zeros(28, 28), torch.zeros(1).long()) + mock.__getitem__.return_value = (torch.zeros(3, 28, 28), torch.zeros(1).long()) mock.__len__.return_value = self.MOCK_LEN - mock.labels = np.arange(9) - monkeypatch.setattr("matchcake_opt.datasets.pathmnist_dataset.ConcatDataset", lambda *x: mock) + mock.labels = np.arange(5) + monkeypatch.setattr("matchcake_opt.datasets.retinamnist_dataset.ConcatDataset", lambda *x: mock) return mock @pytest.fixture