Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
259 changes: 259 additions & 0 deletions notebooks/datasets_normalisation.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,259 @@
{
"cells": [
{
"metadata": {},
"cell_type": "markdown",
"source": "# Normalization of Datasets",
"id": "8b10448261733a07"
},
{
"metadata": {},
"cell_type": "markdown",
"source": "## MEDMNIST: PathMNIST",
"id": "822c3957d720a63b"
},
{
"cell_type": "code",
"id": "initial_id",
"metadata": {
"collapsed": true,
"ExecuteTime": {
"end_time": "2025-11-03T14:40:00.436371Z",
"start_time": "2025-11-03T14:39:56.065478Z"
}
},
"source": [
"from medmnist import PathMNIST\n",
"from torchvision.transforms import v2\n",
"from pathlib import Path\n",
"import torch\n",
"import numpy as np\n",
"\n",
"\n",
"PRECISION = 5"
],
"outputs": [],
"execution_count": 1
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-11-03T14:40:16.084198Z",
"start_time": "2025-11-03T14:40:00.440379Z"
}
},
"cell_type": "code",
"source": [
"root = Path(\"./data/\") / \"PathMNIST\"\n",
"root.mkdir(parents=True, exist_ok=True)\n",
"data = PathMNIST(\n",
" root=root,\n",
" split=\"train\",\n",
" download=True,\n",
" transform=v2.Compose(\n",
" [\n",
" v2.ToImage(),\n",
" v2.ToDtype(torch.float32, scale=True),\n",
" ]\n",
" ),\n",
")\n",
"imgs = torch.stack([d[0] for d in data], dim=-1).permute(3, 1, 2, 0).cpu().numpy()\n",
"imgs = imgs.reshape(-1, data.imgs.shape[-1])\n",
"print(f\"PathMNIST Dataset Shape: {data.imgs.shape}\")\n",
"print(f\"PathMNIST Dataset Means: {np.round(imgs.mean(0), PRECISION)}\")\n",
"print(f\"PathMNIST Dataset Stds: {np.round(imgs.std(0), PRECISION)}\")"
],
"id": "56bb1e919827063c",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"PathMNIST Dataset Shape: (89996, 28, 28, 3)\n",
"PathMNIST Dataset Means: [0.23778 0.23778 0.23778]\n",
"PathMNIST Dataset Stds: [0.35807 0.3089 0.35218]\n"
]
}
],
"execution_count": 2
},
{
"metadata": {},
"cell_type": "markdown",
"source": "## MEDMNIST: RetinaMNIST",
"id": "b51bc4fe4aca42e4"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-11-03T14:40:16.493835Z",
"start_time": "2025-11-03T14:40:16.145319Z"
}
},
"cell_type": "code",
"source": [
"from medmnist import RetinaMNIST\n",
"\n",
"\n",
"root = Path(\"./data/\") / \"RetinaMNIST\"\n",
"root.mkdir(parents=True, exist_ok=True)\n",
"data = RetinaMNIST(\n",
" root=root,\n",
" split=\"train\",\n",
" download=True,\n",
" transform=v2.Compose(\n",
" [\n",
" v2.ToImage(),\n",
" v2.ToDtype(torch.float32, scale=True),\n",
" ]\n",
" ),\n",
")\n",
"imgs = torch.stack([d[0] for d in data], dim=-1).permute(3, 1, 2, 0).cpu().numpy()\n",
"imgs = imgs.reshape(-1, data.imgs.shape[-1])\n",
"print(f\"RetinaMNIST Dataset Shape: {data.imgs.shape}\")\n",
"print(f\"RetinaMNIST Dataset Means: {np.round(imgs.mean(0), PRECISION)}\")\n",
"print(f\"RetinaMNIST Dataset Stds: {np.round(imgs.std(0), PRECISION)}\")"
],
"id": "159a7610d26e86bb",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"RetinaMNIST Dataset Shape: (1080, 28, 28, 3)\n",
"RetinaMNIST Dataset Means: [0.39862 0.24519 0.15615]\n",
"RetinaMNIST Dataset Stds: [0.29827 0.20057 0.15053]\n"
]
}
],
"execution_count": 3
},
{
"metadata": {},
"cell_type": "markdown",
"source": "## CIFAR10",
"id": "2fbc0f268d1be290"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-11-03T14:40:26.853063Z",
"start_time": "2025-11-03T14:40:16.504986Z"
}
},
"cell_type": "code",
"source": [
"from torchvision.datasets import CIFAR10\n",
"\n",
"root = Path(\"./data/\") / \"CIFAR10\"\n",
"root.mkdir(parents=True, exist_ok=True)\n",
"data = CIFAR10(\n",
" root=root,\n",
" train=True,\n",
" download=True,\n",
" transform=v2.Compose(\n",
" [\n",
" v2.ToImage(),\n",
" v2.ToDtype(torch.float32, scale=True),\n",
" ]\n",
" ),\n",
")\n",
"imgs = torch.stack([d[0] for d in data], dim=-1).permute(3, 1, 2, 0).cpu().numpy()\n",
"print(f\"CIFAR10 Dataset Shape: {imgs.shape}\")\n",
"imgs = imgs.reshape(-1, imgs.shape[-1])\n",
"print(f\"CIFAR10 Dataset Means: {np.round(imgs.mean(0), PRECISION)}\")\n",
"print(f\"CIFAR10 Dataset Stds: {np.round(imgs.std(0), PRECISION)}\")"
],
"id": "b8edde9c8da4faeb",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CIFAR10 Dataset Shape: (50000, 32, 32, 3)\n",
"CIFAR10 Dataset Means: [0.32768 0.32768 0.32768]\n",
"CIFAR10 Dataset Stds: [0.27755 0.2693 0.26812]\n"
]
}
],
"execution_count": 4
},
{
"metadata": {},
"cell_type": "markdown",
"source": "## MNIST",
"id": "7139efe270121e2b"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-11-03T14:40:34.124994Z",
"start_time": "2025-11-03T14:40:26.868597Z"
}
},
"cell_type": "code",
"source": [
"from torchvision.datasets import MNIST\n",
"\n",
"root = Path(\"./data/\") / \"MNIST\"\n",
"root.mkdir(parents=True, exist_ok=True)\n",
"data = MNIST(\n",
" root=root,\n",
" train=True,\n",
" download=True,\n",
" transform=v2.Compose(\n",
" [\n",
" v2.ToImage(),\n",
" v2.ToDtype(torch.float32, scale=True),\n",
" ]\n",
" ),\n",
")\n",
"imgs = torch.stack([d[0] for d in data], dim=-1).permute(3, 1, 2, 0).cpu().numpy()\n",
"print(f\"MNIST Dataset Shape: {imgs.shape}\")\n",
"imgs = imgs.reshape(-1, imgs.shape[-1])\n",
"print(f\"MNIST Dataset Means: {np.round(imgs.mean(0), PRECISION)}\")\n",
"print(f\"MNIST Dataset Stds: {np.round(imgs.std(0), PRECISION)}\")"
],
"id": "2313354e9a3c61f0",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"MNIST Dataset Shape: (60000, 28, 28, 1)\n",
"MNIST Dataset Means: [0.13066]\n",
"MNIST Dataset Stds: [0.30811]\n"
]
}
],
"execution_count": 5
},
{
"metadata": {},
"cell_type": "markdown",
"source": "----------------------",
"id": "53856f3e16fff049"
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ dev = [
"pytest-xdist>=3.7.0,<4",
"isort>=6.0.1,<7",
"types-networkx>=3.5.0.20251001",
"pip>=25.3",
]
docs = [
"sphinx>=6.2.1,<6.3.0",
Expand Down
18 changes: 7 additions & 11 deletions src/matchcake_opt/datasets/cifar10_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,38 +2,34 @@
from typing import Sequence, Union

import torch
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor
from torchvision.transforms.v2 import Compose
from torchvision.transforms import v2

from .base_dataset import BaseDataset


class Cifar10Dataset(BaseDataset):
DATASET_NAME = "Cifar10"

@staticmethod
def to_long_tensor(y: Sequence[int]) -> torch.Tensor:
    """Turn a sequence of integer class labels into a ``torch.long`` tensor."""
    return torch.tensor(data=y, dtype=torch.long)

def __init__(self, data_dir: Union[str, Path] = Path("./data/") / DATASET_NAME, train: bool = True, **kwargs):
super().__init__(data_dir, train, **kwargs)
self._data = CIFAR10(
self.data_dir,
train=self.train,
download=True,
transform=Compose(
transform=v2.Compose(
[
ToTensor(),
transforms.Normalize(0.0, 1.0),
v2.ToImage(),
v2.ToDtype(torch.float32, scale=True),
v2.Normalize((0.32768, 0.32768, 0.32768), (0.27755, 0.2693, 0.26812)),
# TODO: Add a param to add other transforms?
# transforms.RandomCrop(32, pad_if_needed=True), # Randomly crop a 32x32 patch
# transforms.RandomHorizontalFlip(), # Randomly flip horizontally
# transforms.RandomRotation(10), # Randomly rotate up to 10 degrees
# transforms.GaussianBlur(kernel_size=5, sigma=(0.1, 2.0)),
]
),
target_transform=Compose([self.to_long_tensor]),
target_transform=v2.Compose([v2.ToDtype(torch.long)]),
)

def __getitem__(self, item):
Expand Down
18 changes: 7 additions & 11 deletions src/matchcake_opt/datasets/mnist_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,37 +4,33 @@
import torch
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torchvision.transforms.v2 import Compose
from torchvision.transforms import v2

from .base_dataset import BaseDataset


class MNISTDataset(BaseDataset):
DATASET_NAME = "MNIST"

@staticmethod
def to_long_tensor(y: Sequence[int]) -> torch.Tensor:
    """Cast the given sequence of integer labels into a ``torch.long`` tensor."""
    return torch.tensor(data=y, dtype=torch.long)

def __init__(self, data_dir: Union[str, Path] = Path("./data/") / DATASET_NAME, train: bool = True, **kwargs):
super().__init__(data_dir, train, **kwargs)
self._data = MNIST(
self.data_dir,
train=self.train,
download=True,
transform=Compose(
transform=v2.Compose(
[
ToTensor(),
transforms.Normalize(0.0, 1.0),
v2.ToImage(),
v2.ToDtype(torch.float32, scale=True),
v2.Normalize((0.13066,), (0.30811,)),
# TODO: Add a param to add other transforms?
# transforms.RandomCrop(32, pad_if_needed=True), # Randomly crop a 32x32 patch
# transforms.RandomHorizontalFlip(), # Randomly flip horizontally
# transforms.RandomRotation(10), # Randomly rotate up to 10 degrees
# transforms.GaussianBlur(kernel_size=5, sigma=(0.1, 2.0)),
]
),
target_transform=Compose([self.to_long_tensor]),
target_transform=v2.Compose([v2.ToDtype(torch.long)]),
)

def __getitem__(self, item):
Expand Down
17 changes: 6 additions & 11 deletions src/matchcake_opt/datasets/pathmnist_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@
import torch
from medmnist import PathMNIST
from torch.utils.data import ConcatDataset
from torchvision import transforms
from torchvision.transforms import ToTensor
from torchvision.transforms.v2 import Compose
from torchvision.transforms import v2

from .base_dataset import BaseDataset

Expand All @@ -19,19 +17,16 @@ class PathMNISTDataset(BaseDataset):
def to_scalar_tensor(y: Sequence[int]):
    """Collapse a one-element label sequence into a plain Python scalar.

    NOTE(review): ``Tensor.item`` raises if *y* holds more than one element —
    callers are assumed to pass single-label sequences.
    """
    label = torch.tensor(y)
    return label.item()

@staticmethod
def to_long_tensor(y: Sequence[int]) -> torch.Tensor:
    """Convert a sequence of integer class labels to a ``torch.long`` tensor."""
    return torch.tensor(data=y, dtype=torch.long)

def __init__(self, data_dir: Union[str, Path] = Path("./data/") / DATASET_NAME, train: bool = True, **kwargs):
super().__init__(data_dir, train, **kwargs)
transform = Compose(
transform = v2.Compose(
[
ToTensor(),
transforms.Normalize(0.0, 1.0),
v2.ToImage(),
v2.ToDtype(torch.float32, scale=True),
v2.Normalize((0.23778, 0.23778, 0.23778), (0.35807, 0.3089, 0.35218)),
]
)
target_transform = Compose([self.to_scalar_tensor, self.to_long_tensor])
target_transform = v2.Compose([self.to_scalar_tensor, v2.ToDtype(torch.long)])
self._data = PathMNIST(
root=self.data_dir,
split="train" if self.train else "test",
Expand Down
Loading