diff --git a/src/matchcake_opt/datasets/__init__.py b/src/matchcake_opt/datasets/__init__.py
index adf6066..5f2803f 100644
--- a/src/matchcake_opt/datasets/__init__.py
+++ b/src/matchcake_opt/datasets/__init__.py
@@ -6,6 +6,7 @@
 from .digits2d import Digits2D
 from .mnist_dataset import MNISTDataset
 from .pathmnist_dataset import PathMNISTDataset
+from .retinamnist_dataset import RetinaMNISTDataset
 
 dataset_name_to_type_map: Dict[str, Type[BaseDataset]] = {
     _cls.DATASET_NAME: _cls
diff --git a/src/matchcake_opt/datasets/retinamnist_dataset.py b/src/matchcake_opt/datasets/retinamnist_dataset.py
new file mode 100644
index 0000000..8d89195
--- /dev/null
+++ b/src/matchcake_opt/datasets/retinamnist_dataset.py
@@ -0,0 +1,56 @@
+from pathlib import Path
+from typing import Sequence, Union
+
+import numpy as np
+import torch
+from medmnist import RetinaMNIST
+from torch.utils.data import ConcatDataset
+from torchvision import transforms
+from torchvision.transforms import ToTensor
+from torchvision.transforms.v2 import Compose
+
+from .base_dataset import BaseDataset
+
+
+class RetinaMNISTDataset(BaseDataset):
+    DATASET_NAME = "RetinaMNIST"
+
+    @staticmethod
+    def to_scalar_tensor(y: Sequence[int]):
+        return torch.tensor(y).item()
+
+    @staticmethod
+    def to_long_tensor(y: Sequence[int]) -> torch.Tensor:
+        return torch.tensor(y, dtype=torch.long)
+
+    def __init__(self, data_dir: Union[str, Path] = Path("./data/") / DATASET_NAME, train: bool = True, **kwargs):
+        super().__init__(data_dir, train, **kwargs)
+        transform = Compose(
+            [
+                ToTensor(),
+                transforms.Normalize(0.0, 1.0),
+            ]
+        )
+        target_transform = Compose([self.to_scalar_tensor, self.to_long_tensor])
+        self._data = RetinaMNIST(
+            root=self.data_dir,
+            split="train" if self.train else "test",
+            download=True,
+            transform=transform,
+            target_transform=target_transform,
+        )
+        self._n_classes = np.unique(self._data.labels).size
+        if self.train:
+            val_dataset = RetinaMNIST(
+                root=self.data_dir, split="val", download=True, transform=transform, target_transform=target_transform
+            )
+            self._data = ConcatDataset([self._data, val_dataset])
+
+    def __getitem__(self, item):
+        return self._data[item]
+
+    def __len__(self):
+        return len(self._data)
+
+    def get_output_shape(self) -> tuple:
+        return (self._n_classes,)
diff --git a/tests/test_datasets/test_retinamnist_dataset.py b/tests/test_datasets/test_retinamnist_dataset.py
new file mode 100644
index 0000000..9dbe159
--- /dev/null
+++ b/tests/test_datasets/test_retinamnist_dataset.py
@@ -0,0 +1,63 @@
+import shutil
+from pathlib import Path
+from unittest.mock import MagicMock
+
+import numpy as np
+import pytest
+import torch
+
+from matchcake_opt.datasets.retinamnist_dataset import RetinaMNISTDataset
+
+
+class TestRetinaMNISTDataset:
+    MOCK_LEN = 10
+
+    @pytest.fixture(scope="class")
+    def data_dir(self):
+        path = Path(".tmp") / "data_dir" / "retinamnist"
+        yield path
+        shutil.rmtree(path, ignore_errors=True)
+
+    @pytest.fixture
+    def data_mock(self, monkeypatch):
+        cls_mock = MagicMock()
+        monkeypatch.setattr("matchcake_opt.datasets.retinamnist_dataset.RetinaMNIST", cls_mock)
+        mock = MagicMock()
+        cls_mock.return_value = mock
+        mock.__getitem__.return_value = (torch.zeros(3, 28, 28), torch.zeros(1).long())
+        mock.__len__.return_value = self.MOCK_LEN
+        mock.labels = np.arange(5)
+        monkeypatch.setattr("matchcake_opt.datasets.retinamnist_dataset.ConcatDataset", lambda *x: mock)
+        return mock
+
+    @pytest.fixture
+    def dataset_instance(self, data_mock, data_dir):
+        return RetinaMNISTDataset(data_dir=data_dir, train=True)
+
+    def test_init(self, data_mock, data_dir):
+        dataset = RetinaMNISTDataset(data_dir=data_dir, train=True)
+        assert dataset._data == data_mock
+
+    def test_getitem(self, data_mock, dataset_instance):
+        datum = dataset_instance[0]
+        assert isinstance(datum, tuple)
+        assert isinstance(datum[0], torch.Tensor)
+        assert isinstance(datum[1], torch.Tensor)
+        assert datum[1].dtype == torch.long
+        data_mock.__getitem__.assert_called_once_with(0)
+
+    def test_to_long_tensor(self, dataset_instance):
+        x = torch.zeros(0, dtype=torch.int)
+        y = dataset_instance.to_long_tensor(x)
+        assert y.dtype == torch.long
+
+    def test_len(self, dataset_instance):
+        assert len(dataset_instance) == self.MOCK_LEN
+
+    def test_output_shape(self, dataset_instance):
+        assert dataset_instance.get_output_shape() == (5,)
+
+    def test_to_scalar_tensor(self, dataset_instance):
+        x = torch.tensor([1])
+        y = dataset_instance.to_scalar_tensor(x)
+        assert isinstance(y, int)