From a0af07c0ba9bf351c04838c55d6b3cd34b51f6c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20Gince?= <50332514+JeremieGince@users.noreply.github.com> Date: Mon, 3 Nov 2025 09:29:34 -0500 Subject: [PATCH 1/3] Refactor dataset normalization and update transforms Updated dataset classes for CIFAR10, MNIST, PathMNIST, and RetinaMNIST to use torchvision.transforms.v2 and dataset-specific normalization values. Removed redundant to_long_tensor methods and related tests. Added a notebook for dataset normalization statistics. Updated dev dependencies to include pip>=25.3. --- notebooks/datasets_normalisation.ipynb | 206 ++++++++++++++++++ pyproject.toml | 1 + src/matchcake_opt/datasets/cifar10_dataset.py | 18 +- src/matchcake_opt/datasets/mnist_dataset.py | 18 +- .../datasets/pathmnist_dataset.py | 17 +- .../datasets/retinamnist_dataset.py | 17 +- tests/test_datasets/test_cifar10_dataset.py | 5 - tests/test_datasets/test_mnist_dataset.py | 5 - uv.lock | 11 + 9 files changed, 244 insertions(+), 54 deletions(-) create mode 100644 notebooks/datasets_normalisation.ipynb diff --git a/notebooks/datasets_normalisation.ipynb b/notebooks/datasets_normalisation.ipynb new file mode 100644 index 0000000..48b74b8 --- /dev/null +++ b/notebooks/datasets_normalisation.ipynb @@ -0,0 +1,206 @@ +{ + "cells": [ + { + "metadata": {}, + "cell_type": "markdown", + "source": "# Normalization of Datasets", + "id": "8b10448261733a07" + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "## MEDMNIST: PathMNIST", + "id": "822c3957d720a63b" + }, + { + "cell_type": "code", + "id": "initial_id", + "metadata": { + "collapsed": true, + "ExecuteTime": { + "end_time": "2025-11-03T14:28:38.684409Z", + "start_time": "2025-11-03T14:28:38.681215Z" + } + }, + "source": [ + "from medmnist import PathMNIST\n", + "from torchvision.transforms import v2\n", + "from pathlib import Path\n", + "import torch\n", + "import numpy as np" + ], + "outputs": [], + "execution_count": 28 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-11-03T14:28:54.224276Z", + "start_time": "2025-11-03T14:28:38.699938Z" + } + }, + "cell_type": "code", + "source": [ + "root = Path(\"./data/\") / \"PathMNIST\"\n", + "root.mkdir(parents=True, exist_ok=True)\n", + "data = PathMNIST(\n", + " root=root,\n", + " split=\"train\",\n", + " download=True,\n", + " transform=v2.Compose(\n", + " [\n", + " v2.ToImage(),\n", + " v2.ToDtype(torch.float32, scale=True),\n", + " ]\n", + " ),\n", + ")\n", + "imgs = torch.stack([d[0] for d in data], dim=-1).permute(3, 1, 2, 0).cpu().numpy()\n", + "imgs = imgs.reshape(-1, data.imgs.shape[-1])\n", + "print(f\"PathMNIST Dataset Shape: {data.imgs.shape}\")\n", + "print(f\"PathMNIST Dataset Means: {np.round(imgs.mean(0), 3)}\")\n", + "print(f\"PathMNIST Dataset Stds: {np.round(imgs.std(0), 3)}\")" + ], + "id": "56bb1e919827063c", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "PathMNIST Dataset Shape: (89996, 28, 28, 3)\n", + "PathMNIST Dataset Means: [0.238 0.238 0.238]\n", + "PathMNIST Dataset Stds: [0.358 0.309 0.352]\n" + ] + } + ], + "execution_count": 29 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "## MEDMNIST: PathMNIST", + "id": "b51bc4fe4aca42e4" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-11-03T14:28:54.599935Z", + "start_time": "2025-11-03T14:28:54.249797Z" + } + }, + "cell_type": "code", + "source": [ + "from medmnist import RetinaMNIST\n", + "\n", + "\n", + "root = Path(\"./data/\") / \"RetinaMNIST\"\n", + "root.mkdir(parents=True, exist_ok=True)\n", + "data = RetinaMNIST(\n", + " root=root,\n", + " split=\"train\",\n", + " download=True,\n", + " transform=v2.Compose(\n", + " [\n", + " v2.ToImage(),\n", + " v2.ToDtype(torch.float32, scale=True),\n", + " ]\n", + " ),\n", + ")\n", + "imgs = torch.stack([d[0] for d in data], dim=-1).permute(3, 1, 2, 0).cpu().numpy()\n", + "imgs = imgs.reshape(-1, data.imgs.shape[-1])\n", + "print(f\"RetinaMNIST Dataset Shape: {data.imgs.shape}\")\n", + "print(f\"RetinaMNIST Dataset Means: {np.round(imgs.mean(0), 3)}\")\n", + "print(f\"RetinaMNIST Dataset Stds: {np.round(imgs.std(0), 3)}\")" + ], + "id": "159a7610d26e86bb", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "RetinaMNIST Dataset Shape: (1080, 28, 28, 3)\n", + "RetinaMNIST Dataset Means: [0.399 0.245 0.156]\n", + "RetinaMNIST Dataset Stds: [0.298 0.201 0.151]\n" + ] + } + ], + "execution_count": 30 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "## CIFAR10", + "id": "2fbc0f268d1be290" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-11-03T14:29:06.319542Z", + "start_time": "2025-11-03T14:28:54.605304Z" + } + }, + "cell_type": "code", + "source": [ + "from torchvision.datasets import CIFAR10\n", + "\n", + "root = Path(\"./data/\") / \"CIFAR10\"\n", + "root.mkdir(parents=True, exist_ok=True)\n", + "data = CIFAR10(\n", + " root=root,\n", + " train=True,\n", + " download=True,\n", + " transform=v2.Compose(\n", + " [\n", + " v2.ToImage(),\n", + " v2.ToDtype(torch.float32, scale=True),\n", + " ]\n", + " ),\n", + ")\n", + "imgs = torch.stack([d[0] for d in data], dim=-1).permute(3, 1, 2, 0).cpu().numpy()\n", + "print(f\"CIFAR10 Dataset Shape: {imgs.shape}\")\n", + "imgs = imgs.reshape(-1, imgs.shape[-1])\n", + "print(f\"CIFAR10 Dataset Means: {np.round(imgs.mean(0), 3)}\")\n", + "print(f\"CIFAR10 Dataset Stds: {np.round(imgs.std(0), 3)}\")" + ], + "id": "b8edde9c8da4faeb", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CIFAR10 Dataset Shape: (50000, 32, 32, 3)\n", + "CIFAR10 Dataset Means: [0.328 0.328 0.328]\n", + "CIFAR10 Dataset Stds: [0.278 0.269 0.268]\n" + ] + } + ], + "execution_count": 31 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "----------------------", + "id": "53856f3e16fff049" + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pyproject.toml b/pyproject.toml index eb59dce..88dca9c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,6 +51,7 @@ dev = [ "pytest-xdist>=3.7.0,<4", "isort>=6.0.1,<7", "types-networkx>=3.5.0.20251001", + "pip>=25.3", ] docs = [ "sphinx>=6.2.1,<6.3.0", diff --git a/src/matchcake_opt/datasets/cifar10_dataset.py b/src/matchcake_opt/datasets/cifar10_dataset.py index 2991141..ced0730 100644 --- a/src/matchcake_opt/datasets/cifar10_dataset.py +++ b/src/matchcake_opt/datasets/cifar10_dataset.py @@ -2,10 +2,8 @@ from typing import Sequence, Union import torch -from torchvision import transforms from torchvision.datasets import CIFAR10 -from torchvision.transforms import ToTensor -from torchvision.transforms.v2 import Compose +from torchvision.transforms import v2 from .base_dataset import BaseDataset @@ -13,27 +11,25 @@ class Cifar10Dataset(BaseDataset): DATASET_NAME = "Cifar10" - @staticmethod - def to_long_tensor(y: Sequence[int]) -> torch.Tensor: - return torch.tensor(y, dtype=torch.long) - def __init__(self, data_dir: Union[str, Path] = Path("./data/") / DATASET_NAME, train: bool = True, **kwargs): super().__init__(data_dir, train, **kwargs) self._data = CIFAR10( self.data_dir, train=self.train, download=True, - transform=Compose( + transform=v2.Compose( [ - ToTensor(), - transforms.Normalize(0.0, 1.0), + v2.ToImage(), + v2.ToDtype(torch.float32, scale=True), + v2.Normalize((0.328, 0.328, 0.328), (0.278, 0.269, 0.268)), + # TODO: Add a param to add other transforms? # transforms.RandomCrop(32, pad_if_needed=True), # Randomly crop a 32x32 patch # transforms.RandomHorizontalFlip(), # Randomly flip horizontally # transforms.RandomRotation(10), # Randomly rotate up to 10 degrees # transforms.GaussianBlur(kernel_size=5, sigma=(0.1, 2.0)), ] ), - target_transform=Compose([self.to_long_tensor]), + target_transform=v2.Compose([v2.ToDtype(torch.long)]), ) def __getitem__(self, item): diff --git a/src/matchcake_opt/datasets/mnist_dataset.py b/src/matchcake_opt/datasets/mnist_dataset.py index 778ebc0..7414b53 100644 --- a/src/matchcake_opt/datasets/mnist_dataset.py +++ b/src/matchcake_opt/datasets/mnist_dataset.py @@ -4,8 +4,7 @@ import torch from torchvision import transforms from torchvision.datasets import MNIST -from torchvision.transforms import ToTensor -from torchvision.transforms.v2 import Compose +from torchvision.transforms import v2 from .base_dataset import BaseDataset @@ -13,28 +12,25 @@ class MNISTDataset(BaseDataset): DATASET_NAME = "MNIST" - @staticmethod - def to_long_tensor(y: Sequence[int]) -> torch.Tensor: - """Convert a Sequence of integers to a long tensor.""" - return torch.tensor(y, dtype=torch.long) - def __init__(self, data_dir: Union[str, Path] = Path("./data/") / DATASET_NAME, train: bool = True, **kwargs): super().__init__(data_dir, train, **kwargs) self._data = MNIST( self.data_dir, train=self.train, download=True, - transform=Compose( + transform=v2.Compose( [ - ToTensor(), - transforms.Normalize(0.0, 1.0), + v2.ToImage(), + v2.ToDtype(torch.float32, scale=True), + v2.Normalize((0.1307,), (0.3081,)), + # TODO: Add a param to add other transforms? # transforms.RandomCrop(32, pad_if_needed=True), # Randomly crop a 32x32 patch # transforms.RandomHorizontalFlip(), # Randomly flip horizontally # transforms.RandomRotation(10), # Randomly rotate up to 10 degrees # transforms.GaussianBlur(kernel_size=5, sigma=(0.1, 2.0)), ] ), - target_transform=Compose([self.to_long_tensor]), + target_transform=v2.Compose([v2.ToDtype(torch.long)]), ) def __getitem__(self, item): diff --git a/src/matchcake_opt/datasets/pathmnist_dataset.py b/src/matchcake_opt/datasets/pathmnist_dataset.py index feb19f3..5172bd2 100644 --- a/src/matchcake_opt/datasets/pathmnist_dataset.py +++ b/src/matchcake_opt/datasets/pathmnist_dataset.py @@ -5,9 +5,7 @@ import torch from medmnist import PathMNIST from torch.utils.data import ConcatDataset -from torchvision import transforms -from torchvision.transforms import ToTensor -from torchvision.transforms.v2 import Compose +from torchvision.transforms import v2 from .base_dataset import BaseDataset @@ -19,19 +17,16 @@ class PathMNISTDataset(BaseDataset): def to_scalar_tensor(y: Sequence[int]): return torch.tensor(y).item() - @staticmethod - def to_long_tensor(y: Sequence[int]) -> torch.Tensor: - return torch.tensor(y, dtype=torch.long) - def __init__(self, data_dir: Union[str, Path] = Path("./data/") / DATASET_NAME, train: bool = True, **kwargs): super().__init__(data_dir, train, **kwargs) - transform = Compose( + transform = v2.Compose( [ - ToTensor(), - transforms.Normalize(0.0, 1.0), + v2.ToImage(), + v2.ToDtype(torch.float32, scale=True), + v2.Normalize((0.238, 0.238, 0.238), (0.358, 0.309, 0.352)), ] ) - target_transform = Compose([self.to_scalar_tensor, self.to_long_tensor]) + target_transform = v2.Compose([self.to_scalar_tensor, v2.ToDtype(torch.long)]) self._data = PathMNIST( root=self.data_dir, split="train" if self.train else "test", diff --git a/src/matchcake_opt/datasets/retinamnist_dataset.py b/src/matchcake_opt/datasets/retinamnist_dataset.py index 8d89195..cc5c4a4 100644 --- a/src/matchcake_opt/datasets/retinamnist_dataset.py +++ b/src/matchcake_opt/datasets/retinamnist_dataset.py @@ -5,9 +5,7 @@ import torch from medmnist import RetinaMNIST from torch.utils.data import ConcatDataset -from torchvision import transforms -from torchvision.transforms import ToTensor -from torchvision.transforms.v2 import Compose +from torchvision.transforms import v2 from .base_dataset import BaseDataset @@ -19,19 +17,16 @@ class RetinaMNISTDataset(BaseDataset): def to_scalar_tensor(y: Sequence[int]): return torch.tensor(y).item() - @staticmethod - def to_long_tensor(y: Sequence[int]) -> torch.Tensor: - return torch.tensor(y, dtype=torch.long) - def __init__(self, data_dir: Union[str, Path] = Path("./data/") / DATASET_NAME, train: bool = True, **kwargs): super().__init__(data_dir, train, **kwargs) - transform = Compose( + transform = v2.Compose( [ - ToTensor(), - transforms.Normalize(0.0, 1.0), + v2.ToImage(), + v2.ToDtype(torch.float32, scale=True), + v2.Normalize((0.399, 0.245, 0.156), (0.298, 0.201, 0.151)), ] ) - target_transform = Compose([self.to_scalar_tensor, self.to_long_tensor]) + target_transform = v2.Compose([self.to_scalar_tensor, v2.ToDtype(torch.long)]) self._data = RetinaMNIST( root=self.data_dir, split="train" if self.train else "test", diff --git a/tests/test_datasets/test_cifar10_dataset.py b/tests/test_datasets/test_cifar10_dataset.py index 2163a76..b3de415 100644 --- a/tests/test_datasets/test_cifar10_dataset.py +++ b/tests/test_datasets/test_cifar10_dataset.py @@ -40,11 +40,6 @@ def test_getitem(self, cifar10_mock, data_dir): assert datum[1].dtype == torch.long cifar10_mock.__getitem__.assert_called_once_with(0) - def test_to_long_tensor(self): - x = torch.zeros(0, dtype=torch.int) - y = Cifar10Dataset.to_long_tensor(x) - assert y.dtype == torch.long - def test_len(self, cifar10_mock, data_dir): dataset = Cifar10Dataset(data_dir=data_dir, train=True) assert len(dataset) == self.MOCK_LEN diff --git a/tests/test_datasets/test_mnist_dataset.py b/tests/test_datasets/test_mnist_dataset.py index 73b84ea..35cae30 100644 --- a/tests/test_datasets/test_mnist_dataset.py +++ b/tests/test_datasets/test_mnist_dataset.py @@ -43,11 +43,6 @@ def test_getitem(self, data_mock, dataset_instance): assert datum[1].dtype == torch.long data_mock.__getitem__.assert_called_once_with(0) - def test_to_long_tensor(self, dataset_instance): - x = torch.zeros(0, dtype=torch.int) - y = dataset_instance.to_long_tensor(x) - assert y.dtype == torch.long - def test_len(self, dataset_instance): assert len(dataset_instance) == self.MOCK_LEN diff --git a/uv.lock b/uv.lock index 6acfe5f..63dc3ce 100644 --- a/uv.lock +++ b/uv.lock @@ -1842,6 +1842,7 @@ dev = [ { name = "isort" }, { name = "mypy" }, { name = "nbmake" }, + { name = "pip" }, { name = "pytest" }, { name = "pytest-cov" }, { name = "pytest-json-report" }, @@ -1897,6 +1898,7 @@ dev = [ { name = "isort", specifier = ">=6.0.1,<7" }, { name = "mypy", specifier = ">=1.15.0,<2" }, { name = "nbmake", specifier = ">=1.5.5,<2" }, + { name = "pip", specifier = ">=25.3" }, { name = "pytest", specifier = ">=8.3.5,<9" }, { name = "pytest-cov", specifier = ">=6.1.1,<7" }, { name = "pytest-json-report", specifier = ">=1.5.0,<2" }, @@ -3045,6 +3047,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/34/e7/ae39f538fd6844e982063c3a5e4598b8ced43b9633baa3a85ef33af8c05c/pillow-11.3.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:c84d689db21a1c397d001aa08241044aa2069e7587b398c8cc63020390b1c1b8", size = 6984598, upload-time = "2025-07-01T09:16:27.732Z" }, ] +[[package]] +name = "pip" +version = "25.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/fe/6e/74a3f0179a4a73a53d66ce57fdb4de0080a8baa1de0063de206d6167acc2/pip-25.3.tar.gz", hash = "sha256:8d0538dbbd7babbd207f261ed969c65de439f6bc9e5dbd3b3b9a77f25d95f343", size = 1803014, upload-time = "2025-10-25T00:55:41.394Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/44/3c/d717024885424591d5376220b5e836c2d5293ce2011523c9de23ff7bf068/pip-25.3-py3-none-any.whl", hash = "sha256:9655943313a94722b7774661c21049070f6bbb0a1516bf02f7c8d5d9201514cd", size = 1778622, upload-time = "2025-10-25T00:55:39.247Z" }, +] + [[package]] name = "platformdirs" version = "4.5.0" From aa594889880e13707e53830d8e796bac614a24f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20Gince?= <50332514+JeremieGince@users.noreply.github.com> Date: Mon, 3 Nov 2025 09:34:24 -0500 Subject: [PATCH 2/3] Remove redundant to_long_tensor tests from dataset tests Deleted the test_to_long_tensor test cases from both PathMNIST and RetinaMNIST dataset test files as they are no longer needed or relevant. --- tests/test_datasets/test_pathmnist_dataset.py | 5 ----- tests/test_datasets/test_retinamnist_dataset.py | 5 ----- 2 files changed, 10 deletions(-) diff --git a/tests/test_datasets/test_pathmnist_dataset.py b/tests/test_datasets/test_pathmnist_dataset.py index e701518..9ebf4b3 100644 --- a/tests/test_datasets/test_pathmnist_dataset.py +++ b/tests/test_datasets/test_pathmnist_dataset.py @@ -46,11 +46,6 @@ def test_getitem(self, data_mock, dataset_instance): assert datum[1].dtype == torch.long data_mock.__getitem__.assert_called_once_with(0) - def test_to_long_tensor(self, dataset_instance): - x = torch.zeros(0, dtype=torch.int) - y = dataset_instance.to_long_tensor(x) - assert y.dtype == torch.long - def test_len(self, dataset_instance): assert len(dataset_instance) == self.MOCK_LEN diff --git a/tests/test_datasets/test_retinamnist_dataset.py b/tests/test_datasets/test_retinamnist_dataset.py index 9dbe159..06b3861 100644 --- a/tests/test_datasets/test_retinamnist_dataset.py +++ b/tests/test_datasets/test_retinamnist_dataset.py @@ -46,11 +46,6 @@ def test_getitem(self, data_mock, dataset_instance): assert datum[1].dtype == torch.long data_mock.__getitem__.assert_called_once_with(0) - def test_to_long_tensor(self, dataset_instance): - x = torch.zeros(0, dtype=torch.int) - y = dataset_instance.to_long_tensor(x) - assert y.dtype == torch.long - def test_len(self, dataset_instance): assert len(dataset_instance) == self.MOCK_LEN From 635104f304d761591e00e25e005f2304e2ba5c58 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20Gince?= <50332514+JeremieGince@users.noreply.github.com> Date: Mon, 3 Nov 2025 09:43:14 -0500 Subject: [PATCH 3/3] Update dataset normalization statistics for accuracy Refined the mean and std values used in v2.Normalize for CIFAR10, MNIST, PathMNIST, and RetinaMNIST datasets to higher precision, based on updated calculations. Also updated the datasets_normalisation.ipynb notebook to reflect these new statistics and added MNIST normalization analysis. --- notebooks/datasets_normalisation.ipynb | 105 +++++++++++++----- src/matchcake_opt/datasets/cifar10_dataset.py | 2 +- src/matchcake_opt/datasets/mnist_dataset.py | 2 +- .../datasets/pathmnist_dataset.py | 2 +- .../datasets/retinamnist_dataset.py | 2 +- 5 files changed, 83 insertions(+), 30 deletions(-) diff --git a/notebooks/datasets_normalisation.ipynb b/notebooks/datasets_normalisation.ipynb index 48b74b8..bfee450 100644 --- a/notebooks/datasets_normalisation.ipynb +++ b/notebooks/datasets_normalisation.ipynb @@ -18,8 +18,8 @@ "metadata": { "collapsed": true, "ExecuteTime": { - "end_time": "2025-11-03T14:28:38.684409Z", - "start_time": "2025-11-03T14:28:38.681215Z" + "end_time": "2025-11-03T14:40:00.436371Z", + "start_time": "2025-11-03T14:39:56.065478Z" } }, "source": [ @@ -27,16 +27,19 @@ "from torchvision.transforms import v2\n", "from pathlib import Path\n", "import torch\n", - "import numpy as np" + "import numpy as np\n", + "\n", + "\n", + "PRECISION = 5" ], "outputs": [], - "execution_count": 28 + "execution_count": 1 }, { "metadata": { "ExecuteTime": { - "end_time": "2025-11-03T14:28:54.224276Z", - "start_time": "2025-11-03T14:28:38.699938Z" + "end_time": "2025-11-03T14:40:16.084198Z", + "start_time": "2025-11-03T14:40:00.440379Z" } }, "cell_type": "code", @@ -57,8 +60,8 @@ "imgs = torch.stack([d[0] for d in data], dim=-1).permute(3, 1, 2, 0).cpu().numpy()\n", "imgs = imgs.reshape(-1, data.imgs.shape[-1])\n", "print(f\"PathMNIST Dataset Shape: {data.imgs.shape}\")\n", - "print(f\"PathMNIST Dataset Means: {np.round(imgs.mean(0), 3)}\")\n", - "print(f\"PathMNIST Dataset Stds: {np.round(imgs.std(0), 3)}\")" + "print(f\"PathMNIST Dataset Means: {np.round(imgs.mean(0), PRECISION)}\")\n", + "print(f\"PathMNIST Dataset Stds: {np.round(imgs.std(0), PRECISION)}\")" ], "id": "56bb1e919827063c", "outputs": [ @@ -67,24 +70,24 @@ "output_type": "stream", "text": [ "PathMNIST Dataset Shape: (89996, 28, 28, 3)\n", - "PathMNIST Dataset Means: [0.238 0.238 0.238]\n", - "PathMNIST Dataset Stds: [0.358 0.309 0.352]\n" + "PathMNIST Dataset Means: [0.23778 0.23778 0.23778]\n", + "PathMNIST Dataset Stds: [0.35807 0.3089 0.35218]\n" ] } ], - "execution_count": 29 + "execution_count": 2 }, { "metadata": {}, "cell_type": "markdown", - "source": "## MEDMNIST: PathMNIST", + "source": "## MEDMNIST: RetinaMNIST", "id": "b51bc4fe4aca42e4" }, { "metadata": { "ExecuteTime": { - "end_time": "2025-11-03T14:28:54.599935Z", - "start_time": "2025-11-03T14:28:54.249797Z" + "end_time": "2025-11-03T14:40:16.493835Z", + "start_time": "2025-11-03T14:40:16.145319Z" } }, "cell_type": "code", @@ -108,8 +111,8 @@ "imgs = torch.stack([d[0] for d in data], dim=-1).permute(3, 1, 2, 0).cpu().numpy()\n", "imgs = imgs.reshape(-1, data.imgs.shape[-1])\n", "print(f\"RetinaMNIST Dataset Shape: {data.imgs.shape}\")\n", - "print(f\"RetinaMNIST Dataset Means: {np.round(imgs.mean(0), 3)}\")\n", - "print(f\"RetinaMNIST Dataset Stds: {np.round(imgs.std(0), 3)}\")" + "print(f\"RetinaMNIST Dataset Means: {np.round(imgs.mean(0), PRECISION)}\")\n", + "print(f\"RetinaMNIST Dataset Stds: {np.round(imgs.std(0), PRECISION)}\")" ], "id": "159a7610d26e86bb", "outputs": [ @@ -118,12 +121,12 @@ "output_type": "stream", "text": [ "RetinaMNIST Dataset Shape: (1080, 28, 28, 3)\n", - "RetinaMNIST Dataset Means: [0.399 0.245 0.156]\n", - "RetinaMNIST Dataset Stds: [0.298 0.201 0.151]\n" + "RetinaMNIST Dataset Means: [0.39862 0.24519 0.15615]\n", + "RetinaMNIST Dataset Stds: [0.29827 0.20057 0.15053]\n" ] } ], - "execution_count": 30 + "execution_count": 3 }, { "metadata": {}, @@ -134,8 +137,8 @@ { "metadata": { "ExecuteTime": { - "end_time": "2025-11-03T14:29:06.319542Z", - "start_time": "2025-11-03T14:28:54.605304Z" + "end_time": "2025-11-03T14:40:26.853063Z", + "start_time": "2025-11-03T14:40:16.504986Z" } }, "cell_type": "code", @@ -158,8 +161,8 @@ "imgs = torch.stack([d[0] for d in data], dim=-1).permute(3, 1, 2, 0).cpu().numpy()\n", "print(f\"CIFAR10 Dataset Shape: {imgs.shape}\")\n", "imgs = imgs.reshape(-1, imgs.shape[-1])\n", - "print(f\"CIFAR10 Dataset Means: {np.round(imgs.mean(0), 3)}\")\n", - "print(f\"CIFAR10 Dataset Stds: {np.round(imgs.std(0), 3)}\")" + "print(f\"CIFAR10 Dataset Means: {np.round(imgs.mean(0), PRECISION)}\")\n", + "print(f\"CIFAR10 Dataset Stds: {np.round(imgs.std(0), PRECISION)}\")" ], "id": "b8edde9c8da4faeb", "outputs": [ @@ -168,12 +171,62 @@ "output_type": "stream", "text": [ "CIFAR10 Dataset Shape: (50000, 32, 32, 3)\n", - "CIFAR10 Dataset Means: [0.328 0.328 0.328]\n", - "CIFAR10 Dataset Stds: [0.278 0.269 0.268]\n" + "CIFAR10 Dataset Means: [0.32768 0.32768 0.32768]\n", + "CIFAR10 Dataset Stds: [0.27755 0.2693 0.26812]\n" + ] + } + ], + "execution_count": 4 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "## MNIST", + "id": "7139efe270121e2b" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-11-03T14:40:34.124994Z", + "start_time": "2025-11-03T14:40:26.868597Z" + } + }, + "cell_type": "code", + "source": [ + "from torchvision.datasets import MNIST\n", + "\n", + "root = Path(\"./data/\") / \"MNIST\"\n", + "root.mkdir(parents=True, exist_ok=True)\n", + "data = MNIST(\n", + " root=root,\n", + " train=True,\n", + " download=True,\n", + " transform=v2.Compose(\n", + " [\n", + " v2.ToImage(),\n", + " v2.ToDtype(torch.float32, scale=True),\n", + " ]\n", + " ),\n", + ")\n", + "imgs = torch.stack([d[0] for d in data], dim=-1).permute(3, 1, 2, 0).cpu().numpy()\n", + "print(f\"MNIST Dataset Shape: {imgs.shape}\")\n", + "imgs = imgs.reshape(-1, imgs.shape[-1])\n", + "print(f\"MNIST Dataset Means: {np.round(imgs.mean(0), PRECISION)}\")\n", + "print(f\"MNIST Dataset Stds: {np.round(imgs.std(0), PRECISION)}\")" + ], + "id": "2313354e9a3c61f0", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MNIST Dataset Shape: (60000, 28, 28, 1)\n", + "MNIST Dataset Means: [0.13066]\n", + "MNIST Dataset Stds: [0.30811]\n" ] } ], - "execution_count": 31 + "execution_count": 5 }, { "metadata": {}, diff --git a/src/matchcake_opt/datasets/cifar10_dataset.py b/src/matchcake_opt/datasets/cifar10_dataset.py index ced0730..f8d4fce 100644 --- a/src/matchcake_opt/datasets/cifar10_dataset.py +++ b/src/matchcake_opt/datasets/cifar10_dataset.py @@ -21,7 +21,7 @@ def __init__(self, data_dir: Union[str, Path] = Path("./data/") / DATASET_NAME, [ v2.ToImage(), v2.ToDtype(torch.float32, scale=True), - v2.Normalize((0.328, 0.328, 0.328), (0.278, 0.269, 0.268)), + v2.Normalize((0.32768, 0.32768, 0.32768), (0.27755, 0.2693, 0.26812)), # TODO: Add a param to add other transforms? # transforms.RandomCrop(32, pad_if_needed=True), # Randomly crop a 32x32 patch # transforms.RandomHorizontalFlip(), # Randomly flip horizontally diff --git a/src/matchcake_opt/datasets/mnist_dataset.py b/src/matchcake_opt/datasets/mnist_dataset.py index 7414b53..8e36393 100644 --- a/src/matchcake_opt/datasets/mnist_dataset.py +++ b/src/matchcake_opt/datasets/mnist_dataset.py @@ -22,7 +22,7 @@ def __init__(self, data_dir: Union[str, Path] = Path("./data/") / DATASET_NAME, [ v2.ToImage(), v2.ToDtype(torch.float32, scale=True), - v2.Normalize((0.1307,), (0.3081,)), + v2.Normalize((0.13066,), (0.30811,)), # TODO: Add a param to add other transforms? # transforms.RandomCrop(32, pad_if_needed=True), # Randomly crop a 32x32 patch # transforms.RandomHorizontalFlip(), # Randomly flip horizontally diff --git a/src/matchcake_opt/datasets/pathmnist_dataset.py b/src/matchcake_opt/datasets/pathmnist_dataset.py index 5172bd2..70a231f 100644 --- a/src/matchcake_opt/datasets/pathmnist_dataset.py +++ b/src/matchcake_opt/datasets/pathmnist_dataset.py @@ -23,7 +23,7 @@ def __init__(self, data_dir: Union[str, Path] = Path("./data/") / DATASET_NAME, [ v2.ToImage(), v2.ToDtype(torch.float32, scale=True), - v2.Normalize((0.238, 0.238, 0.238), (0.358, 0.309, 0.352)), + v2.Normalize((0.23778, 0.23778, 0.23778), (0.35807, 0.3089, 0.35218)), ] ) target_transform = v2.Compose([self.to_scalar_tensor, v2.ToDtype(torch.long)]) diff --git a/src/matchcake_opt/datasets/retinamnist_dataset.py b/src/matchcake_opt/datasets/retinamnist_dataset.py index cc5c4a4..4747488 100644 --- a/src/matchcake_opt/datasets/retinamnist_dataset.py +++ b/src/matchcake_opt/datasets/retinamnist_dataset.py @@ -23,7 +23,7 @@ def __init__(self, data_dir: Union[str, Path] = Path("./data/") / DATASET_NAME, [ v2.ToImage(), v2.ToDtype(torch.float32, scale=True), - v2.Normalize((0.399, 0.245, 0.156), (0.298, 0.201, 0.151)), + v2.Normalize((0.39862, 0.24519, 0.15615), (0.29827, 0.20057, 0.15053)), ] ) target_transform = v2.Compose([self.to_scalar_tensor, v2.ToDtype(torch.long)])