Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/matchcake_opt/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .digits2d import Digits2D
from .mnist_dataset import MNISTDataset
from .pathmnist_dataset import PathMNISTDataset
from .retinamnist_dataset import RetinaMNISTDataset

dataset_name_to_type_map: Dict[str, Type[BaseDataset]] = {
_cls.DATASET_NAME: _cls
Expand Down
56 changes: 56 additions & 0 deletions src/matchcake_opt/datasets/retinamnist_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from pathlib import Path
from typing import Sequence, Union

import numpy as np
import torch
from medmnist import RetinaMNIST
from torch.utils.data import ConcatDataset
from torchvision import transforms
from torchvision.transforms import ToTensor
from torchvision.transforms.v2 import Compose

from .base_dataset import BaseDataset


class RetinaMNISTDataset(BaseDataset):
    """RetinaMNIST dataset wrapper built on the ``medmnist`` package.

    When ``train`` is True, the medmnist ``train`` and ``val`` splits are
    concatenated into a single training set; otherwise the medmnist
    ``test`` split is used as-is.
    """

    DATASET_NAME = "RetinaMNIST"

    @staticmethod
    def to_scalar_tensor(y: Sequence[int]):
        """Collapse a one-element label sequence into a plain Python scalar."""
        return torch.tensor(y).item()

    @staticmethod
    def to_long_tensor(y: Sequence[int]) -> torch.Tensor:
        """Convert a label (scalar or sequence) into a ``torch.long`` tensor."""
        return torch.tensor(y, dtype=torch.long)

    def __init__(self, data_dir: Union[str, Path] = Path("./data/") / DATASET_NAME, train: bool = True, **kwargs):
        """Download (if needed) and wrap the requested RetinaMNIST split.

        :param data_dir: Directory where the medmnist files are stored.
        :param train: If True, load train+val; otherwise load the test split.
        :param kwargs: Forwarded to :class:`BaseDataset`.
        """
        super().__init__(data_dir, train, **kwargs)
        image_transform = Compose([ToTensor(), transforms.Normalize(0.0, 1.0)])
        # Labels arrive as 1-element sequences: squeeze to a scalar, then cast to long.
        label_transform = Compose([self.to_scalar_tensor, self.to_long_tensor])
        split_name = "train" if self.train else "test"
        self._data = RetinaMNIST(
            root=self.data_dir,
            split=split_name,
            download=True,
            transform=image_transform,
            target_transform=label_transform,
        )
        # Count classes now, while `.labels` is still exposed (ConcatDataset hides it).
        self._n_classes = np.unique(self._data.labels).size
        if self.train:
            # Fold medmnist's validation split into the training data.
            extra = RetinaMNIST(
                root=self.data_dir,
                split="val",
                download=True,
                transform=image_transform,
                target_transform=label_transform,
            )
            self._data = ConcatDataset([self._data, extra])

    def __getitem__(self, item):
        """Delegate indexing to the wrapped (possibly concatenated) dataset."""
        return self._data[item]

    def __len__(self):
        """Number of samples in the wrapped dataset."""
        return len(self._data)

    def get_output_shape(self) -> tuple:
        """Model output shape: one entry per class."""
        return (self._n_classes,)
63 changes: 63 additions & 0 deletions tests/test_datasets/test_retinamnist_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import shutil
from pathlib import Path
from unittest.mock import MagicMock

import numpy as np
import pytest
import torch

from matchcake_opt.datasets.retinamnist_dataset import RetinaMNISTDataset


class TestRetinaMNISTDataset:
    """Unit tests for ``RetinaMNISTDataset`` with the medmnist backend mocked out."""

    MOCK_LEN = 10

    @pytest.fixture(scope="class")
    def data_dir(self):
        # Throwaway directory; removed once the whole test class has run.
        tmp_path = Path(".tmp") / "data_dir" / "retinamnist"
        yield tmp_path
        shutil.rmtree(tmp_path, ignore_errors=True)

    @pytest.fixture
    def data_mock(self, monkeypatch):
        # Stub out the RetinaMNIST class so no download or disk I/O happens.
        retina_cls = MagicMock()
        monkeypatch.setattr("matchcake_opt.datasets.retinamnist_dataset.RetinaMNIST", retina_cls)
        dataset_mock = MagicMock()
        retina_cls.return_value = dataset_mock
        dataset_mock.__getitem__.return_value = (torch.zeros(3, 28, 28), torch.zeros(1).long())
        dataset_mock.__len__.return_value = self.MOCK_LEN
        dataset_mock.labels = np.arange(5)
        # ConcatDataset is stubbed too, so train+val merging yields the same mock.
        monkeypatch.setattr("matchcake_opt.datasets.retinamnist_dataset.ConcatDataset", lambda *x: dataset_mock)
        return dataset_mock

    @pytest.fixture
    def dataset_instance(self, data_mock, data_dir):
        return RetinaMNISTDataset(data_dir=data_dir, train=True)

    def test_init(self, data_mock, data_dir):
        dataset = RetinaMNISTDataset(data_dir=data_dir, train=True)
        assert dataset._data == data_mock

    def test_getitem(self, data_mock, dataset_instance):
        sample = dataset_instance[0]
        assert isinstance(sample, tuple)
        assert isinstance(sample[0], torch.Tensor)
        assert isinstance(sample[1], torch.Tensor)
        assert sample[1].dtype == torch.long
        data_mock.__getitem__.assert_called_once_with(0)

    def test_to_long_tensor(self, dataset_instance):
        empty_ints = torch.zeros(0, dtype=torch.int)
        converted = dataset_instance.to_long_tensor(empty_ints)
        assert converted.dtype == torch.long

    def test_len(self, dataset_instance):
        assert len(dataset_instance) == self.MOCK_LEN

    def test_output_shape(self, dataset_instance):
        assert dataset_instance.get_output_shape() == (5,)

    def test_to_scalar_tensor(self, dataset_instance):
        one_label = torch.tensor([1])
        scalar = dataset_instance.to_scalar_tensor(one_label)
        assert isinstance(scalar, int)