diff --git a/.vulture_whitelist.py b/.vulture_whitelist.py index af316bf..4940fba 100644 --- a/.vulture_whitelist.py +++ b/.vulture_whitelist.py @@ -12,8 +12,6 @@ """ # SessionIOHandler methods - public API used in tests -save_model_checkpoint # noqa - Used in test_session_io_handler.py, test_model_io_integration.py -load_model_checkpoint # noqa - Used in test_model_io_integration.py list_sessions # noqa - Used in test_session_io_handler.py save_run # noqa - Used in test_run_label_migration.py save_labels_to_output_dir # noqa - Used in test_run_label_migration.py diff --git a/anomaly_match/data_io/SessionIOHandler.py b/anomaly_match/data_io/SessionIOHandler.py index 86fc0c3..724cf97 100644 --- a/anomaly_match/data_io/SessionIOHandler.py +++ b/anomaly_match/data_io/SessionIOHandler.py @@ -7,14 +7,13 @@ import json import os -import pickle from pathlib import Path from typing import Any, Dict, List, Optional import pandas as pd -import torch from loguru import logger +from anomaly_match.data_io.checkpoint_io import load_checkpoint, save_checkpoint from anomaly_match.data_io.save_config import save_config_toml from anomaly_match.pipeline.SessionTracker import IterationInfo, SessionTracker @@ -185,43 +184,6 @@ def save_iteration_scores( except Exception as e: logger.warning(f"Failed to save test scores: {e}") - def save_model_checkpoint( - self, - model_state: Dict[str, Any], - session_tracker: SessionTracker, - checkpoint_name: str = None, - ) -> str: - """ - Save a model checkpoint within the session directory. - - Args: - model_state: Model state dictionary to save. - session_tracker: Associated session tracker. - checkpoint_name: Optional custom checkpoint name. - - Returns: - Path to saved checkpoint. - """ - save_path = self.get_session_save_path(session_tracker) - save_path.mkdir(parents=True, exist_ok=True) - - checkpoints_dir = save_path / "checkpoints" - checkpoints_dir.mkdir(exist_ok=True) - - if checkpoint_name is None: - checkpoint_name = f"model_iter_{session_tracker.total_model_iterations}.pkl" - - checkpoint_path = checkpoints_dir / checkpoint_name - - with open(checkpoint_path, "wb") as f: - pickle.dump(model_state, f) - - # Update the session tracker with the checkpoint path - session_tracker.update_model_state_path(str(checkpoint_path)) - - logger.debug(f"Saved model checkpoint to: {checkpoint_path}") - return str(checkpoint_path) - def save_model(self, model, cfg, session_tracker: SessionTracker = None) -> str: """ Save the model to the session directory if session_tracker is available, @@ -246,7 +208,7 @@ def save_model(self, model, cfg, session_tracker: SessionTracker = None) -> str: if session_tracker.session_iterations else 0 ) - model_filename = f"model_iteration_{iteration_num}.pth" + model_filename = f"model_iteration_{iteration_num}.safetensors" model_path = save_path / model_filename else: if cfg.model_path is None: @@ -287,8 +249,9 @@ def save_model(self, model, cfg, session_tracker: SessionTracker = None) -> str: "fitsbolt_cfg": fitsbolt_cfg, } - # Save model - torch.save(save_state, model_path) + # Save model (save_checkpoint forces .safetensors extension) + save_checkpoint(save_state, model_path) + model_path = Path(model_path).with_suffix(".safetensors") if session_tracker is not None: # Ensure there's an active session iteration @@ -331,7 +294,7 @@ def load_model(self, model, cfg, model_path: str = None) -> bool: try: # Load checkpoint - checkpoint = torch.load(load_path, weights_only=False) + checkpoint = load_checkpoint(load_path) # Handle distributed 
training case train_model = ( @@ -426,37 +389,6 @@ def load_model(self, model, cfg, model_path: str = None) -> bool: logger.error(f"Failed to load model from {load_path}: {e}") return False - def load_model_checkpoint(self, checkpoint_path: str) -> Optional[Dict[str, Any]]: - """ - Load a model checkpoint from the specified path. - - Args: - checkpoint_path: Path to the checkpoint file - - Returns: - Dictionary containing the checkpoint data, or None if loading failed - """ - try: - if not os.path.exists(checkpoint_path): - logger.error(f"Checkpoint path does not exist: {checkpoint_path}") - return None - - # Try loading as pickle first (new format), then as torch (legacy) - try: - with open(checkpoint_path, "rb") as f: - checkpoint = pickle.load(f) - logger.debug(f"Loaded checkpoint from pickle format: {checkpoint_path}") - except (pickle.UnpicklingError, EOFError): - # Fall back to torch format - checkpoint = torch.load(checkpoint_path, weights_only=False, map_location="cpu") - logger.debug(f"Loaded checkpoint from torch format: {checkpoint_path}") - - return checkpoint - - except Exception as e: - logger.error(f"Failed to load checkpoint from {checkpoint_path}: {e}") - return None - def load_session(self, session_path: Path) -> SessionTracker: """ Load a session from disk. @@ -611,7 +543,9 @@ def save_run( "fitsbolt_cfg": fitsbolt_cfg, } - torch.save(save_state, save_filename) + save_checkpoint(save_state, save_filename) + # save_checkpoint forces .safetensors extension; update save_filename to match + save_filename = str(Path(save_filename).with_suffix(".safetensors")) # Update session tracker if provided if session_tracker is not None: @@ -706,7 +640,7 @@ def update_config_paths_for_session(self, cfg, session_tracker: SessionTracker) # Update model path to session directory only if not already set by user if cfg.model_path is None: - cfg.model_path = str(session_path / "model.pth") + cfg.model_path = str(session_path / "model.safetensors") # Update output directory to session directory cfg.output_dir = str(session_path) @@ -805,7 +739,7 @@ def print_session(filepath: str) -> None: checkpoints_dir = session_path / "checkpoints" if checkpoints_dir.exists(): - checkpoints = list(checkpoints_dir.glob("*.pkl")) + checkpoints = list(checkpoints_dir.glob("*.safetensors")) print(f"✓ {len(checkpoints)} model checkpoint(s)") print("=" * 60) diff --git a/anomaly_match/data_io/checkpoint_io.py b/anomaly_match/data_io/checkpoint_io.py new file mode 100644 index 0000000..191b63c --- /dev/null +++ b/anomaly_match/data_io/checkpoint_io.py @@ -0,0 +1,326 @@ +# Copyright (c) European Space Agency, 2025. +# +# This file is subject to the terms and conditions defined in file 'LICENCE.txt', which +# is part of this source code package. No part of the package, including +# this file, may be copied, modified, propagated, or distributed except according to +# the terms contained in the file 'LICENCE.txt'. + +"""Checkpoint I/O using safetensors for secure model serialization. + +Replaces pickle-based ``torch.save`` / ``torch.load`` with safetensors to +prevent arbitrary code execution when loading untrusted model files. + +Checkpoint layout inside a single ``.safetensors`` file: + +* **Binary section** — all ``torch.Tensor`` values (model weights, optimizer + momentum buffers, …) stored under namespaced keys + (``train_model.<param_name>``, ``optimizer.state.<param_idx>.<key>``, …). +* **Metadata header** — every non-tensor value is JSON-encoded into the + ``Dict[str, str]`` metadata that safetensors carries in its header.
+""" + +from __future__ import annotations + +import json +from enum import Enum +from pathlib import Path +from typing import Any + +import numpy as np +import torch +from loguru import logger + +# --------------------------------------------------------------------------- +# JSON helpers for types that appear in checkpoint metadata +# --------------------------------------------------------------------------- + + +def _nullify_empty_dicts(obj: Any) -> Any: + """Recursively replace empty dicts with ``None``. + + DotMap auto-creates empty child maps when accessing missing keys. After + ``toDict()`` these become ``{}``, which breaks fitsbolt's + ``validate_config`` on reload (e.g. ``channel_combination`` is expected to + be ``None`` or ``np.ndarray``, not ``{}``). + """ + if isinstance(obj, dict): + if len(obj) == 0: + return None + return {k: _nullify_empty_dicts(v) for k, v in obj.items()} + if isinstance(obj, (list, tuple)): + return [_nullify_empty_dicts(v) for v in obj] + return obj + + +def _prepare_for_json(obj: Any) -> Any: + """Recursively convert non-JSON-native types to tagged representations. + + This is needed because ``IntEnum`` (which ``NormalisationMethod`` inherits + from) is a subclass of ``int`` — the standard JSON encoder serializes it + as a plain integer and never calls ``default()``. By walking the + structure up-front we ensure *all* special types are tagged. + + """ + # Enum check MUST come before int/float because IntEnum is also an int + if isinstance(obj, Enum): + return {"__enum__": type(obj).__name__, "name": obj.name} + if isinstance(obj, np.dtype): + return {"__numpy_dtype__": str(obj)} + if isinstance(obj, type) and issubclass(obj, np.generic): + return {"__numpy_dtype_type__": np.dtype(obj).str} + if isinstance(obj, np.ndarray): + return {"__numpy_array__": obj.tolist(), "dtype": str(obj.dtype)} + if isinstance(obj, np.integer): + return int(obj) + if isinstance(obj, np.floating): + return float(obj) + if isinstance(obj, np.bool_): + return bool(obj) + if isinstance(obj, dict): + return {k: _prepare_for_json(v) for k, v in obj.items()} + if isinstance(obj, (list, tuple)): + return [_prepare_for_json(v) for v in obj] + return obj + + +class _CheckpointEncoder(json.JSONEncoder): + """JSON encoder that handles checkpoint-specific types. + + Note: ``IntEnum`` values bypass ``default()`` because they *are* ints. + Use :func:`_prepare_for_json` on the data **before** calling + ``json.dumps`` to ensure those types are correctly tagged. 
+ """ + + def default(self, obj: Any) -> Any: + if isinstance(obj, Enum): + return {"__enum__": type(obj).__name__, "name": obj.name} + if isinstance(obj, np.dtype): + return {"__numpy_dtype__": str(obj)} + if isinstance(obj, type) and issubclass(obj, np.generic): + return {"__numpy_dtype_type__": np.dtype(obj).str} + if isinstance(obj, np.ndarray): + return {"__numpy_array__": obj.tolist(), "dtype": str(obj.dtype)} + if isinstance(obj, np.integer): + return int(obj) + if isinstance(obj, np.floating): + return float(obj) + if isinstance(obj, np.bool_): + return bool(obj) + return super().default(obj) + + +def _checkpoint_object_hook(obj: dict) -> Any: + """JSON object-hook that restores checkpoint-specific types.""" + if "__enum__" in obj: + enum_name = obj["__enum__"] + if enum_name == "NormalisationMethod": + from fitsbolt.normalisation.NormalisationMethod import NormalisationMethod + + return NormalisationMethod[obj["name"]] + return f"{enum_name}.{obj['name']}" + if "__numpy_dtype__" in obj: + return np.dtype(obj["__numpy_dtype__"]) + if "__numpy_dtype_type__" in obj: + return np.dtype(obj["__numpy_dtype_type__"]).type + if "__numpy_array__" in obj: + return np.array(obj["__numpy_array__"], dtype=obj["dtype"]) + return obj + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +def save_checkpoint(save_state: dict[str, Any], path: str | Path) -> Path: + """Save a model checkpoint in safetensors format. + + Tensors are stored in the safetensors binary section; everything else is + JSON-encoded into the safetensors metadata header. + + Args: + save_state: Checkpoint dictionary (same keys as previously passed to + ``torch.save``). + path: Destination file path. The extension is forced to + ``.safetensors``. + + Returns: + The actual path written (with ``.safetensors`` extension). 
+ """ + from safetensors.torch import save_file + + path = Path(path).with_suffix(".safetensors") + + tensors: dict[str, torch.Tensor] = {} + metadata: dict[str, str] = {} + + # ---- model state-dicts ------------------------------------------------ + for model_key in ("train_model", "eval_model"): + state_dict = save_state.get(model_key) + if state_dict is None: + continue + for param_name, tensor in state_dict.items(): + tensors[f"{model_key}.{param_name}"] = tensor.detach().clone().contiguous() + + # ---- optimizer state -------------------------------------------------- + opt_state = save_state.get("optimizer") + if opt_state is not None: + opt_skeleton: dict[str, Any] = { + "state": {}, + "param_groups": opt_state.get("param_groups", []), + } + for param_idx, state in opt_state.get("state", {}).items(): + idx_str = str(param_idx) + opt_skeleton["state"][idx_str] = {} + for key, val in state.items(): + if isinstance(val, torch.Tensor): + tensors[f"optimizer.state.{param_idx}.{key}"] = ( + val.detach().clone().contiguous() + ) + opt_skeleton["state"][idx_str][key] = "__tensor__" + else: + opt_skeleton["state"][idx_str][key] = val + metadata["optimizer"] = json.dumps(_prepare_for_json(opt_skeleton), cls=_CheckpointEncoder) + else: + metadata["optimizer"] = "null" + + # ---- scheduler state -------------------------------------------------- + sched_state = save_state.get("scheduler") + metadata["scheduler"] = ( + json.dumps(_prepare_for_json(sched_state), cls=_CheckpointEncoder) + if sched_state is not None + else "null" + ) + + # ---- scalar / enum metadata ------------------------------------------- + for key in ( + "it", + "total_it", + "best_eval_acc", + "best_it", + "num_channels", + "net", + "normalisation_method", + "last_normalisation_method", + ): + metadata[key] = json.dumps(_prepare_for_json(save_state.get(key)), cls=_CheckpointEncoder) + + # ---- fitsbolt config (DotMap → dict → JSON) --------------------------- + fb_cfg = save_state.get("fitsbolt_cfg") + if fb_cfg is not None: + cfg_dict = fb_cfg.toDict() if hasattr(fb_cfg, "toDict") else fb_cfg + # DotMap auto-creates empty child maps on missing-key access (e.g. + # channel_combination). After toDict() these become empty dicts {}, + # which break fitsbolt's validate_config on reload. Normalize + # leaf-level empty dicts to None. + cfg_dict = _nullify_empty_dicts(cfg_dict) + metadata["fitsbolt_cfg"] = json.dumps(_prepare_for_json(cfg_dict), cls=_CheckpointEncoder) + else: + metadata["fitsbolt_cfg"] = "null" + + # ---- labeled-data CSV ------------------------------------------------- + csv_str = save_state.get("labeled_data_csv") + if csv_str is not None: + metadata["labeled_data_csv"] = csv_str + + # safetensors requires at least one tensor + if not tensors: + tensors["__placeholder__"] = torch.zeros(1) + + save_file(tensors, str(path), metadata=metadata) + logger.debug(f"Saved checkpoint in safetensors format: {path}") + return path + + +def load_checkpoint(path: str | Path, device: str = "cpu") -> dict[str, Any]: + """Load a model checkpoint from a ``.safetensors`` file. + + Args: + path: Path to the ``.safetensors`` checkpoint file. + device: Device to map tensors to (default ``"cpu"``). + + Returns: + Checkpoint dictionary with the same structure as originally saved. + + Raises: + FileNotFoundError: If *path* does not exist. 
+ """ + from safetensors import safe_open + from safetensors.torch import load_file + + path = Path(path) + if not path.exists(): + raise FileNotFoundError(f"Checkpoint not found: {path}") + + all_tensors = load_file(str(path), device=device) + + with safe_open(str(path), framework="pt", device=device) as f: + raw_metadata = f.metadata() or {} + + checkpoint: dict[str, Any] = {} + + # ---- model state-dicts ------------------------------------------------ + for model_key in ("train_model", "eval_model"): + prefix = f"{model_key}." + state_dict = {k[len(prefix) :]: v for k, v in all_tensors.items() if k.startswith(prefix)} + if state_dict: + checkpoint[model_key] = state_dict + + # ---- optimizer state -------------------------------------------------- + opt_skeleton = json.loads( + raw_metadata.get("optimizer", "null"), object_hook=_checkpoint_object_hook + ) + if opt_skeleton is not None: + new_state: dict[int, dict] = {} + for idx_str, state in opt_skeleton.get("state", {}).items(): + restored: dict[str, Any] = {} + for key, val in state.items(): + if val == "__tensor__": + restored[key] = all_tensors[f"optimizer.state.{idx_str}.{key}"] + else: + restored[key] = val + new_state[int(idx_str)] = restored + opt_skeleton["state"] = new_state + checkpoint["optimizer"] = opt_skeleton + else: + checkpoint["optimizer"] = None + + # ---- scheduler state -------------------------------------------------- + checkpoint["scheduler"] = json.loads( + raw_metadata.get("scheduler", "null"), object_hook=_checkpoint_object_hook + ) + + # ---- scalar / enum metadata ------------------------------------------- + for key in ( + "it", + "total_it", + "best_eval_acc", + "best_it", + "num_channels", + "net", + "normalisation_method", + "last_normalisation_method", + ): + checkpoint[key] = json.loads( + raw_metadata.get(key, "null"), object_hook=_checkpoint_object_hook + ) + + # ---- fitsbolt config -------------------------------------------------- + fb_data = json.loads( + raw_metadata.get("fitsbolt_cfg", "null"), object_hook=_checkpoint_object_hook + ) + if fb_data is not None: + from dotmap import DotMap + + # _dynamic=False prevents DotMap from auto-creating empty child maps + # on missing-key access, which would break fitsbolt's validate_config + # (e.g. channel_combination should stay absent, not become DotMap()). 
+ checkpoint["fitsbolt_cfg"] = DotMap(fb_data, _dynamic=False) + else: + checkpoint["fitsbolt_cfg"] = None + + # ---- labeled-data CSV ------------------------------------------------- + if "labeled_data_csv" in raw_metadata: + checkpoint["labeled_data_csv"] = raw_metadata["labeled_data_csv"] + + return checkpoint diff --git a/anomaly_match/utils/get_default_cfg.py b/anomaly_match/utils/get_default_cfg.py index 5ac0bc9..969efbc 100644 --- a/anomaly_match/utils/get_default_cfg.py +++ b/anomaly_match/utils/get_default_cfg.py @@ -32,7 +32,7 @@ def get_default_cfg(): cfg.metadata_file = None # Path to the metadata CSV file cfg.prediction_search_dir = None cfg.save_path = os.path.join(cfg.save_dir) - cfg.save_file = create_model_string(cfg) + ".pth" + cfg.save_file = create_model_string(cfg) + ".safetensors" cfg.model_path = None # Will be set by SessionIOHandler when session is active cfg.N_batch_prediction = None # User specified batch size for evaluating a directory, if None: determined automatically cfg.subprocess_buffer_size = ( diff --git a/environment.yml b/environment.yml index 94b3ce4..209c452 100644 --- a/environment.yml +++ b/environment.yml @@ -38,4 +38,5 @@ dependencies: - cutana>=0.2.1 - fitsbolt>=0.2 - opencv-python-headless + - safetensors - timm diff --git a/environment_CI.yml b/environment_CI.yml index d815bc1..6c52521 100644 --- a/environment_CI.yml +++ b/environment_CI.yml @@ -35,6 +35,7 @@ dependencies: - pip: - opencv-python-headless - albumentations + - safetensors - timm - fitsbolt>=0.2 - cutana>=0.2.1 diff --git a/prediction_utils.py b/prediction_utils.py index 6ce6348..e47e810 100644 --- a/prediction_utils.py +++ b/prediction_utils.py @@ -24,6 +24,7 @@ from loguru import logger from turbojpeg import TurboJPEG +from anomaly_match.data_io.checkpoint_io import load_checkpoint from anomaly_match.data_io.load_images import get_fitsbolt_config, process_single_wrapper from anomaly_match.utils.get_default_cfg import get_default_cfg @@ -189,10 +190,8 @@ def load_model(cfg): else: logger.info("Using CPU for inference") - if torch.cuda.is_available(): - checkpoint = torch.load(model_path, weights_only=False) - else: - checkpoint = torch.load(model_path, weights_only=False, map_location=torch.device("cpu")) + device = "cuda" if torch.cuda.is_available() else "cpu" + checkpoint = load_checkpoint(model_path, device=device) if "eval_model" not in checkpoint: raise KeyError( diff --git a/pyproject.toml b/pyproject.toml index 3e25184..71703fb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,7 @@ dependencies = [ "psutil", "pyarrow", "pyturbojpeg", + "safetensors", "scikit-image", "scikit-learn", "scipy", diff --git a/tests/e2e/test_prediction_process.py b/tests/e2e/test_prediction_process.py index 799d366..e8bfad1 100644 --- a/tests/e2e/test_prediction_process.py +++ b/tests/e2e/test_prediction_process.py @@ -40,7 +40,7 @@ def test_config(): cfg.net = "efficientnet-lite0" cfg.pretrained = True cfg.num_channels = 3 - cfg.model_path = "tests/test_data/test_model.pth" + cfg.model_path = "tests/test_data/test_model.safetensors" cfg.gpu = 0 cfg.output_dir = tempfile.mkdtemp() cfg.normalisation.normalisation_method = NormalisationMethod.CONVERSION_ONLY diff --git a/tests/integration/test_fitsbolt_config_persistence.py b/tests/integration/test_fitsbolt_config_persistence.py index ee58584..c8d80c9 100644 --- a/tests/integration/test_fitsbolt_config_persistence.py +++ b/tests/integration/test_fitsbolt_config_persistence.py @@ -7,8 +7,8 @@ """Tests for fitsbolt configuration 
persistence in model checkpoints. -The fitsbolt DotMap configuration can be pickled directly via torch.save/load -without explicit serialization. +The fitsbolt DotMap configuration is serialized via safetensors metadata +(JSON-encoded) through save_checkpoint/load_checkpoint. """ import shutil @@ -22,14 +22,36 @@ from fitsbolt.cfg.create_config import validate_config from fitsbolt.normalisation.NormalisationMethod import NormalisationMethod +from anomaly_match.data_io.checkpoint_io import load_checkpoint, save_checkpoint from anomaly_match.data_io.load_images import get_fitsbolt_config -class TestFitsboltConfigPickling: - """Test cases for fitsbolt config pickling via torch.save/load.""" - - def test_pickle_roundtrip_basic(self): - """Test basic pickle roundtrip via torch checkpoint.""" +def _make_checkpoint(fitsbolt_cfg=None, **extra): + """Create a minimal checkpoint dict suitable for save_checkpoint.""" + checkpoint = { + "train_model": {"dummy.weight": torch.randn(2, 2)}, + "eval_model": {"dummy.weight": torch.randn(2, 2)}, + "optimizer": None, + "scheduler": None, + "it": 0, + "total_it": 0, + "best_eval_acc": None, + "best_it": None, + "num_channels": 3, + "net": "efficientnet-lite0", + "normalisation_method": None, + "last_normalisation_method": None, + "fitsbolt_cfg": fitsbolt_cfg, + } + checkpoint.update(extra) + return checkpoint + + +class TestFitsboltConfigSafetensors: + """Test cases for fitsbolt config persistence via safetensors.""" + + def test_roundtrip_basic(self): + """Test basic roundtrip via safetensors checkpoint.""" original_cfg = fb_create_cfg( output_dtype=np.uint8, size=[64, 64], @@ -38,44 +60,34 @@ def test_pickle_roundtrip_basic(self): num_workers=4, ) - # Save via torch - with tempfile.NamedTemporaryFile(suffix=".pth", delete=False) as f: - checkpoint_path = f.name - - try: - torch.save({"fitsbolt_cfg": original_cfg}, checkpoint_path) - loaded = torch.load(checkpoint_path, weights_only=False) + with tempfile.TemporaryDirectory() as tmp: + checkpoint_path = Path(tmp) / "model.safetensors" + save_checkpoint(_make_checkpoint(fitsbolt_cfg=original_cfg), checkpoint_path) + loaded = load_checkpoint(checkpoint_path) loaded_cfg = loaded["fitsbolt_cfg"] assert loaded_cfg.size == original_cfg.size assert loaded_cfg.n_output_channels == original_cfg.n_output_channels - assert loaded_cfg.num_workers == original_cfg.num_workers assert loaded_cfg.normalisation_method == original_cfg.normalisation_method - finally: - Path(checkpoint_path).unlink(missing_ok=True) - def test_pickle_numpy_dtype(self): - """Test pickling of numpy dtypes.""" + def test_numpy_dtype(self): + """Test persistence of numpy dtypes.""" original_cfg = fb_create_cfg( output_dtype=np.float32, size=[128, 128], n_output_channels=3, ) - with tempfile.NamedTemporaryFile(suffix=".pth", delete=False) as f: - checkpoint_path = f.name - - try: - torch.save({"fitsbolt_cfg": original_cfg}, checkpoint_path) - loaded = torch.load(checkpoint_path, weights_only=False) + with tempfile.TemporaryDirectory() as tmp: + checkpoint_path = Path(tmp) / "model.safetensors" + save_checkpoint(_make_checkpoint(fitsbolt_cfg=original_cfg), checkpoint_path) + loaded = load_checkpoint(checkpoint_path) loaded_cfg = loaded["fitsbolt_cfg"] assert loaded_cfg.output_dtype == np.float32 - finally: - Path(checkpoint_path).unlink(missing_ok=True) - def test_pickle_all_normalisation_methods(self): - """Test pickling with all normalisation methods.""" + def test_all_normalisation_methods(self): + """Test persistence with all normalisation 
methods.""" for method in NormalisationMethod: original_cfg = fb_create_cfg( output_dtype=np.uint8, @@ -84,20 +96,16 @@ def test_pickle_all_normalisation_methods(self): normalisation_method=method, ) - with tempfile.NamedTemporaryFile(suffix=".pth", delete=False) as f: - checkpoint_path = f.name - - try: - torch.save({"fitsbolt_cfg": original_cfg}, checkpoint_path) - loaded = torch.load(checkpoint_path, weights_only=False) + with tempfile.TemporaryDirectory() as tmp: + checkpoint_path = Path(tmp) / "model.safetensors" + save_checkpoint(_make_checkpoint(fitsbolt_cfg=original_cfg), checkpoint_path) + loaded = load_checkpoint(checkpoint_path) loaded_cfg = loaded["fitsbolt_cfg"] assert loaded_cfg.normalisation_method == method - finally: - Path(checkpoint_path).unlink(missing_ok=True) - def test_pickle_channel_combination(self): - """Test pickling of numpy array channel_combination.""" + def test_channel_combination(self): + """Test persistence of numpy array channel_combination.""" original_cfg = fb_create_cfg( output_dtype=np.uint8, size=[64, 64], @@ -106,22 +114,18 @@ def test_pickle_channel_combination(self): channel_combination=np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]), ) - with tempfile.NamedTemporaryFile(suffix=".pth", delete=False) as f: - checkpoint_path = f.name - - try: - torch.save({"fitsbolt_cfg": original_cfg}, checkpoint_path) - loaded = torch.load(checkpoint_path, weights_only=False) + with tempfile.TemporaryDirectory() as tmp: + checkpoint_path = Path(tmp) / "model.safetensors" + save_checkpoint(_make_checkpoint(fitsbolt_cfg=original_cfg), checkpoint_path) + loaded = load_checkpoint(checkpoint_path) loaded_cfg = loaded["fitsbolt_cfg"] np.testing.assert_array_equal( loaded_cfg.channel_combination, original_cfg.channel_combination ) - finally: - Path(checkpoint_path).unlink(missing_ok=True) - def test_pickle_asinh_settings(self): - """Test pickling of asinh normalisation settings.""" + def test_asinh_settings(self): + """Test persistence of asinh normalisation settings.""" original_cfg = fb_create_cfg( output_dtype=np.uint8, size=[64, 64], @@ -131,25 +135,21 @@ def test_pickle_asinh_settings(self): norm_asinh_clip=[99.0, 99.5, 99.8], ) - with tempfile.NamedTemporaryFile(suffix=".pth", delete=False) as f: - checkpoint_path = f.name - - try: - torch.save({"fitsbolt_cfg": original_cfg}, checkpoint_path) - loaded = torch.load(checkpoint_path, weights_only=False) + with tempfile.TemporaryDirectory() as tmp: + checkpoint_path = Path(tmp) / "model.safetensors" + save_checkpoint(_make_checkpoint(fitsbolt_cfg=original_cfg), checkpoint_path) + loaded = load_checkpoint(checkpoint_path) loaded_cfg = loaded["fitsbolt_cfg"] assert loaded_cfg.normalisation.asinh_scale == original_cfg.normalisation.asinh_scale assert loaded_cfg.normalisation.asinh_clip == original_cfg.normalisation.asinh_clip - finally: - Path(checkpoint_path).unlink(missing_ok=True) class TestFitsboltConfigValidation: - """Test cases for fitsbolt config validation after pickling.""" + """Test cases for fitsbolt config validation after safetensors roundtrip.""" - def test_validate_pickled_config(self): - """Test that pickled config passes fitsbolt validation.""" + def test_validate_roundtripped_config(self): + """Test that roundtripped config passes fitsbolt validation.""" original_cfg = fb_create_cfg( output_dtype=np.uint8, size=[64, 64], @@ -158,25 +158,21 @@ def test_validate_pickled_config(self): num_workers=4, ) - with tempfile.NamedTemporaryFile(suffix=".pth", delete=False) as f: - checkpoint_path = f.name - - try: - 
torch.save({"fitsbolt_cfg": original_cfg}, checkpoint_path) - loaded = torch.load(checkpoint_path, weights_only=False) + with tempfile.TemporaryDirectory() as tmp: + checkpoint_path = Path(tmp) / "model.safetensors" + save_checkpoint(_make_checkpoint(fitsbolt_cfg=original_cfg), checkpoint_path) + loaded = load_checkpoint(checkpoint_path) loaded_cfg = loaded["fitsbolt_cfg"] # Validate using fitsbolt's validate_config validate_config(loaded_cfg) - finally: - Path(checkpoint_path).unlink(missing_ok=True) class TestFitsboltConfigCompatibility: """Test compatibility with fitsbolt's create_config function.""" def test_compatibility_with_fits_extension_settings(self): - """Test pickling with various fits_extension configurations.""" + """Test persistence with various fits_extension configurations.""" # Single integer extension cfg1 = fb_create_cfg( output_dtype=np.uint8, @@ -185,15 +181,11 @@ def test_compatibility_with_fits_extension_settings(self): fits_extension=0, ) - with tempfile.NamedTemporaryFile(suffix=".pth", delete=False) as f: - checkpoint_path = f.name - - try: - torch.save({"fitsbolt_cfg": cfg1}, checkpoint_path) - loaded = torch.load(checkpoint_path, weights_only=False) + with tempfile.TemporaryDirectory() as tmp: + checkpoint_path = Path(tmp) / "model.safetensors" + save_checkpoint(_make_checkpoint(fitsbolt_cfg=cfg1), checkpoint_path) + loaded = load_checkpoint(checkpoint_path) validate_config(loaded["fitsbolt_cfg"]) - finally: - Path(checkpoint_path).unlink(missing_ok=True) # List of extensions cfg2 = fb_create_cfg( @@ -203,22 +195,18 @@ def test_compatibility_with_fits_extension_settings(self): fits_extension=[0, 1, 2], ) - with tempfile.NamedTemporaryFile(suffix=".pth", delete=False) as f: - checkpoint_path = f.name - - try: - torch.save({"fitsbolt_cfg": cfg2}, checkpoint_path) - loaded = torch.load(checkpoint_path, weights_only=False) + with tempfile.TemporaryDirectory() as tmp: + checkpoint_path = Path(tmp) / "model.safetensors" + save_checkpoint(_make_checkpoint(fitsbolt_cfg=cfg2), checkpoint_path) + loaded = load_checkpoint(checkpoint_path) validate_config(loaded["fitsbolt_cfg"]) - finally: - Path(checkpoint_path).unlink(missing_ok=True) class TestGetFitsboltConfigIntegration: - """Test get_fitsbolt_config integration with pickling.""" + """Test get_fitsbolt_config integration with safetensors persistence.""" - def test_get_fitsbolt_config_pickling(self): - """Test that config from get_fitsbolt_config can be pickled.""" + def test_get_fitsbolt_config_roundtrip(self): + """Test that config from get_fitsbolt_config survives safetensors roundtrip.""" # Create an AnomalyMatch-style config cfg = DotMap() cfg.normalisation = DotMap() @@ -240,13 +228,10 @@ def test_get_fitsbolt_config_pickling(self): # Get fitsbolt config cfg = get_fitsbolt_config(cfg) - # Save and load via torch - with tempfile.NamedTemporaryFile(suffix=".pth", delete=False) as f: - checkpoint_path = f.name - - try: - torch.save({"fitsbolt_cfg": cfg.fitsbolt_cfg}, checkpoint_path) - loaded = torch.load(checkpoint_path, weights_only=False) + with tempfile.TemporaryDirectory() as tmp: + checkpoint_path = Path(tmp) / "model.safetensors" + save_checkpoint(_make_checkpoint(fitsbolt_cfg=cfg.fitsbolt_cfg), checkpoint_path) + loaded = load_checkpoint(checkpoint_path) loaded_cfg = loaded["fitsbolt_cfg"] # Validate @@ -256,8 +241,6 @@ def test_get_fitsbolt_config_pickling(self): assert loaded_cfg.size == [64, 64] assert loaded_cfg.n_output_channels == 3 assert loaded_cfg.normalisation_method == 
NormalisationMethod.CONVERSION_ONLY - finally: - Path(checkpoint_path).unlink(missing_ok=True) class TestFitsboltConfigE2EWithCheckpoint: @@ -272,7 +255,7 @@ def teardown_method(self): shutil.rmtree(self.temp_dir) def test_fitsbolt_config_in_checkpoint_dict(self): - """Test that fitsbolt config can be saved and loaded in a checkpoint-like dict.""" + """Test that fitsbolt config can be saved and loaded in a checkpoint dict.""" # Create a fitsbolt config fitsbolt_cfg = fb_create_cfg( output_dtype=np.uint8, @@ -283,19 +266,12 @@ def test_fitsbolt_config_in_checkpoint_dict(self): norm_asinh_clip=[99.0, 99.5, 99.8], ) - # Create a mock checkpoint - checkpoint = { - "model_state": {"dummy": "data"}, - "optimizer_state": None, - "fitsbolt_cfg": fitsbolt_cfg, - } - # Save checkpoint - checkpoint_path = Path(self.temp_dir) / "test_checkpoint.pth" - torch.save(checkpoint, checkpoint_path) + checkpoint_path = Path(self.temp_dir) / "test_checkpoint.safetensors" + save_checkpoint(_make_checkpoint(fitsbolt_cfg=fitsbolt_cfg), checkpoint_path) # Load checkpoint - loaded_checkpoint = torch.load(checkpoint_path, weights_only=False) + loaded_checkpoint = load_checkpoint(checkpoint_path) loaded_fitsbolt_cfg = loaded_checkpoint["fitsbolt_cfg"] # Verify @@ -310,22 +286,12 @@ def test_fitsbolt_config_in_checkpoint_dict(self): def test_backward_compatibility_checkpoint_without_fitsbolt(self): """Test loading checkpoints that don't have fitsbolt_cfg.""" - # Create a mock checkpoint without fitsbolt_cfg (legacy format) - checkpoint = { - "model_state": {"dummy": "data"}, - "optimizer_state": None, - } - - # Save checkpoint - checkpoint_path = Path(self.temp_dir) / "legacy_checkpoint.pth" - torch.save(checkpoint, checkpoint_path) + # Save checkpoint without fitsbolt_cfg + checkpoint_path = Path(self.temp_dir) / "legacy_checkpoint.safetensors" + save_checkpoint(_make_checkpoint(fitsbolt_cfg=None), checkpoint_path) # Load checkpoint - loaded_checkpoint = torch.load(checkpoint_path, weights_only=False) - - # Check that fitsbolt_cfg is not present - assert "fitsbolt_cfg" not in loaded_checkpoint + loaded_checkpoint = load_checkpoint(checkpoint_path) - # Accessing non-existent key should return None via .get() - result = loaded_checkpoint.get("fitsbolt_cfg") - assert result is None + # fitsbolt_cfg should be None (not missing) + assert loaded_checkpoint["fitsbolt_cfg"] is None diff --git a/tests/integration/test_model_io_integration.py b/tests/integration/test_model_io_integration.py index eb8e0e3..9fbb7a2 100644 --- a/tests/integration/test_model_io_integration.py +++ b/tests/integration/test_model_io_integration.py @@ -15,6 +15,7 @@ from dotmap import DotMap from fitsbolt.normalisation.NormalisationMethod import NormalisationMethod +from anomaly_match.data_io.checkpoint_io import load_checkpoint from anomaly_match.data_io.SessionIOHandler import SessionIOHandler from anomaly_match.pipeline.SessionTracker import SessionTracker from anomaly_match.utils.get_net_builder import get_net_builder @@ -59,7 +60,7 @@ def setup_method(self): from anomaly_match.utils.get_default_cfg import get_default_cfg self.cfg = get_default_cfg() - self.cfg.model_path = str(self.temp_dir / "test_model.pth") + self.cfg.model_path = str(self.temp_dir / "test_model.safetensors") def teardown_method(self): """Clean up test fixtures.""" @@ -135,47 +136,19 @@ def test_load_model_with_normalisation_update(self): def test_load_model_nonexistent_file(self): """Test loading from nonexistent file.""" - self.cfg.model_path = str(self.temp_dir / 
"nonexistent.pth") + self.cfg.model_path = str(self.temp_dir / "nonexistent.safetensors") success = self.session_io.load_model(self.mock_model, self.cfg) assert not success - def test_load_model_checkpoint(self): - """Test loading model checkpoint.""" - # Create and save a checkpoint - model_state = { - "train_model_state_dict": self.mock_model.train_model.state_dict(), - "eval_model_state_dict": self.mock_model.eval_model.state_dict(), - "total_it": self.mock_model.total_it, - } - - checkpoint_path = self.session_io.save_model_checkpoint( - model_state, self.session_tracker, "test_checkpoint.pkl" - ) - - # Load checkpoint - loaded_checkpoint = self.session_io.load_model_checkpoint(checkpoint_path) - - # Verify checkpoint was loaded - assert loaded_checkpoint is not None - assert "train_model_state_dict" in loaded_checkpoint - assert "total_it" in loaded_checkpoint - assert loaded_checkpoint["total_it"] == self.mock_model.total_it - - def test_load_model_checkpoint_nonexistent(self): - """Test loading nonexistent checkpoint.""" - checkpoint = self.session_io.load_model_checkpoint(str(self.temp_dir / "nonexistent.pkl")) - - assert checkpoint is None - -TEST_MODEL_PATH = Path(__file__).parent.parent / "test_data" / "test_model.pth" +TEST_MODEL_PATH = Path(__file__).parent.parent / "test_data" / "test_model.safetensors" -@pytest.mark.skipif(not TEST_MODEL_PATH.exists(), reason="test_model.pth not available") +@pytest.mark.skipif(not TEST_MODEL_PATH.exists(), reason="test_model.safetensors not available") class TestStoredModelLoading: - """Regression tests for loading the stored test_model.pth checkpoint. + """Regression tests for loading the stored test_model.safetensors checkpoint. These tests verify that the checked-in test model remains compatible with the current model architecture (timm-based EfficientNet). @@ -183,7 +156,7 @@ class TestStoredModelLoading: def test_stored_model_has_expected_keys(self): """Verify the stored checkpoint contains expected top-level keys.""" - checkpoint = torch.load(str(TEST_MODEL_PATH), weights_only=False, map_location="cpu") + checkpoint = load_checkpoint(TEST_MODEL_PATH) assert "eval_model" in checkpoint, ( f"Checkpoint missing 'eval_model' key. 
Found: {list(checkpoint.keys())}" @@ -192,7 +165,7 @@ def test_stored_model_has_expected_keys(self): def test_stored_model_loads_into_efficientnet_lite0(self): """Verify stored model state_dict is compatible with the current architecture.""" - checkpoint = torch.load(str(TEST_MODEL_PATH), weights_only=False, map_location="cpu") + checkpoint = load_checkpoint(TEST_MODEL_PATH) net_builder = get_net_builder("efficientnet-lite0", pretrained=False, in_channels=3) model = net_builder(num_classes=2, in_channels=3) diff --git a/tests/integration/test_run_label_migration.py b/tests/integration/test_run_label_migration.py index a484cf7..de913f6 100644 --- a/tests/integration/test_run_label_migration.py +++ b/tests/integration/test_run_label_migration.py @@ -13,6 +13,7 @@ import pytest import torch +from anomaly_match.data_io.checkpoint_io import load_checkpoint from anomaly_match.data_io.SessionIOHandler import SessionIOHandler from anomaly_match.pipeline.SessionTracker import SessionTracker @@ -67,25 +68,25 @@ def mock_config(self): """Create a mock configuration.""" config = Mock() config.normalisation_method = "min_max" - config.model_path = "test_model.pth" + config.model_path = "test_model.safetensors" # Explicitly set fitsbolt_cfg to None to avoid pickling issues with Mock config.fitsbolt_cfg = None return config def test_save_run_basic(self, session_io, mock_model, temp_dir): """Test basic save_run functionality.""" - save_name = "test_model.pth" + save_name = "test_model.safetensors" save_path = temp_dir result_path = session_io.save_run(mock_model, save_name, save_path) - # Check that the model was saved + # Check that the model was saved (save_checkpoint forces .safetensors extension) expected_path = os.path.join(save_path, save_name) assert result_path == expected_path assert os.path.exists(expected_path) # Verify the saved model can be loaded - checkpoint = torch.load(expected_path, weights_only=False) + checkpoint = load_checkpoint(expected_path) assert "train_model" in checkpoint assert "eval_model" in checkpoint assert "optimizer" in checkpoint @@ -95,7 +96,7 @@ def test_save_run_basic(self, session_io, mock_model, temp_dir): def test_save_run_with_session_tracker(self, session_io, mock_model, session_tracker, temp_dir): """Test save_run with session tracker integration.""" - save_name = "test_model.pth" + save_name = "test_model.safetensors" save_path = temp_dir # Start a session iteration @@ -111,7 +112,7 @@ def test_save_run_with_session_tracker(self, session_io, mock_model, session_tra def test_save_run_with_config(self, session_io, mock_model, mock_config, temp_dir): """Test save_run with configuration saving.""" - save_name = "test_model.pth" + save_name = "test_model.safetensors" save_path = temp_dir # Mock the config saving function @@ -182,7 +183,7 @@ def test_integration_training_run_flow( self, session_io, session_tracker, mock_model, mock_config, temp_dir ): """Test the complete integration flow for training run saving.""" - save_name = "final_model.pth" + save_name = "final_model.safetensors" save_path = temp_dir # Simulate a training session @@ -199,7 +200,7 @@ def test_integration_training_run_flow( assert session_tracker.session_iterations[0].model_state_path == model_path # Verify model checkpoint structure - checkpoint = torch.load(model_path, weights_only=False) + checkpoint = load_checkpoint(model_path) assert all(key in checkpoint for key in ["train_model", "eval_model", "optimizer", "it"]) def test_integration_label_saving_flow(self, session_io, session_tracker, 
temp_dir): diff --git a/tests/test_data/test_model.pth b/tests/test_data/test_model.safetensors similarity index 50% rename from tests/test_data/test_model.pth rename to tests/test_data/test_model.safetensors index 1ffb3f4..53e7f45 100644 Binary files a/tests/test_data/test_model.pth and b/tests/test_data/test_model.safetensors differ diff --git a/tests/unit/test_checkpoint_io.py b/tests/unit/test_checkpoint_io.py new file mode 100644 index 0000000..d87bfc8 --- /dev/null +++ b/tests/unit/test_checkpoint_io.py @@ -0,0 +1,242 @@ +# Copyright (c) European Space Agency, 2025. +# +# This file is subject to the terms and conditions defined in file 'LICENCE.txt', which +# is part of this source code package. No part of the package, including +# this file, may be copied, modified, propagated, or distributed except according to +# the terms contained in the file 'LICENCE.txt'. + +"""Unit tests for checkpoint_io: safetensors-based model checkpoint serialization.""" + +import numpy as np +import pytest +import torch +from dotmap import DotMap +from fitsbolt.normalisation.NormalisationMethod import NormalisationMethod + +from anomaly_match.data_io.checkpoint_io import load_checkpoint, save_checkpoint + + +def _make_state_dict(seed=0): + """Create a small deterministic state_dict for testing.""" + torch.manual_seed(seed) + return { + "layer.weight": torch.randn(4, 3), + "layer.bias": torch.randn(4), + "bn.running_mean": torch.zeros(4), + "bn.running_var": torch.ones(4), + "bn.num_batches_tracked": torch.tensor(0, dtype=torch.long), + } + + +def _make_full_checkpoint(**overrides): + """Create a complete checkpoint dict with sensible defaults.""" + checkpoint = { + "train_model": _make_state_dict(seed=0), + "eval_model": _make_state_dict(seed=1), + "optimizer": None, + "scheduler": None, + "it": 42, + "total_it": 100, + "best_eval_acc": 0.95, + "best_it": 80, + "num_channels": 3, + "net": "efficientnet-lite0", + "normalisation_method": NormalisationMethod.CONVERSION_ONLY, + "last_normalisation_method": NormalisationMethod.LOG, + "fitsbolt_cfg": None, + } + checkpoint.update(overrides) + return checkpoint + + +class TestSaveLoadRoundTrip: + """Test that save_checkpoint → load_checkpoint round-trips all data correctly.""" + + def test_model_weights_roundtrip(self, tmp_path): + """Verify train_model and eval_model state_dicts survive round-trip.""" + original = _make_full_checkpoint() + path = save_checkpoint(original, tmp_path / "model") + + loaded = load_checkpoint(path) + + for key in ("train_model", "eval_model"): + for param_name in original[key]: + assert torch.equal(original[key][param_name], loaded[key][param_name]), ( + f"{key}.{param_name} mismatch after round-trip" + ) + + def test_scalar_metadata_roundtrip(self, tmp_path): + """Verify scalar metadata (it, total_it, etc.) 
survives round-trip.""" + original = _make_full_checkpoint() + path = save_checkpoint(original, tmp_path / "model") + loaded = load_checkpoint(path) + + assert loaded["it"] == 42 + assert loaded["total_it"] == 100 + assert loaded["best_eval_acc"] == 0.95 + assert loaded["best_it"] == 80 + assert loaded["num_channels"] == 3 + assert loaded["net"] == "efficientnet-lite0" + + def test_normalisation_enum_roundtrip(self, tmp_path): + """Verify NormalisationMethod enum values survive round-trip.""" + original = _make_full_checkpoint() + path = save_checkpoint(original, tmp_path / "model") + loaded = load_checkpoint(path) + + assert loaded["normalisation_method"] == NormalisationMethod.CONVERSION_ONLY + assert loaded["last_normalisation_method"] == NormalisationMethod.LOG + assert isinstance(loaded["normalisation_method"], NormalisationMethod) + + def test_optimizer_state_roundtrip(self, tmp_path): + """Verify optimizer state (including momentum tensors) survives round-trip.""" + # Build a real optimizer state + model = torch.nn.Linear(3, 2) + optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) + # Step once to create momentum buffers + loss = model(torch.randn(1, 3)).sum() + loss.backward() + optimizer.step() + + opt_state = optimizer.state_dict() + original = _make_full_checkpoint(optimizer=opt_state) + path = save_checkpoint(original, tmp_path / "model") + loaded = load_checkpoint(path) + + # Check param_groups + assert loaded["optimizer"]["param_groups"][0]["lr"] == 0.01 + assert loaded["optimizer"]["param_groups"][0]["momentum"] == 0.9 + + # Check state tensors + for param_idx in opt_state["state"]: + for key in opt_state["state"][param_idx]: + orig_val = opt_state["state"][param_idx][key] + loaded_val = loaded["optimizer"]["state"][param_idx][key] + if isinstance(orig_val, torch.Tensor): + assert torch.equal(orig_val, loaded_val) + + def test_scheduler_state_roundtrip(self, tmp_path): + """Verify scheduler state survives round-trip.""" + sched_state = { + "T_max": 200, + "eta_min": 0, + "last_epoch": 50, + "_step_count": 51, + "base_lrs": [0.01], + "_last_lr": [0.005], + } + original = _make_full_checkpoint(scheduler=sched_state) + path = save_checkpoint(original, tmp_path / "model") + loaded = load_checkpoint(path) + + assert loaded["scheduler"]["T_max"] == 200 + assert loaded["scheduler"]["last_epoch"] == 50 + + def test_fitsbolt_cfg_roundtrip(self, tmp_path): + """Verify fitsbolt DotMap config survives round-trip.""" + fb_cfg = DotMap( + { + "output_dtype": np.uint8, + "size": [64, 64], + "normalisation_method": NormalisationMethod.CONVERSION_ONLY, + "n_output_channels": 3, + "channel_combination": np.array([[1, 0], [0, 1], [0.5, 0.5]]), + } + ) + original = _make_full_checkpoint(fitsbolt_cfg=fb_cfg) + path = save_checkpoint(original, tmp_path / "model") + loaded = load_checkpoint(path) + + loaded_fb = loaded["fitsbolt_cfg"] + assert isinstance(loaded_fb, DotMap) + assert loaded_fb.normalisation_method == NormalisationMethod.CONVERSION_ONLY + assert loaded_fb.output_dtype == np.uint8 + assert np.array_equal(loaded_fb.channel_combination, fb_cfg.channel_combination) + + def test_labeled_data_csv_roundtrip(self, tmp_path): + """Verify labeled_data_csv string survives round-trip.""" + csv = "filename,label\nimg1.jpg,anomaly\nimg2.jpg,normal\n" + original = _make_full_checkpoint(labeled_data_csv=csv) + path = save_checkpoint(original, tmp_path / "model") + loaded = load_checkpoint(path) + + assert loaded["labeled_data_csv"] == csv + + def test_none_values_roundtrip(self, 
tmp_path): + """Verify None values survive round-trip correctly.""" + original = _make_full_checkpoint( + optimizer=None, + scheduler=None, + fitsbolt_cfg=None, + best_eval_acc=None, + normalisation_method=None, + ) + path = save_checkpoint(original, tmp_path / "model") + loaded = load_checkpoint(path) + + assert loaded["optimizer"] is None + assert loaded["scheduler"] is None + assert loaded["fitsbolt_cfg"] is None + assert loaded["best_eval_acc"] is None + assert loaded["normalisation_method"] is None + + +class TestFileFormat: + """Test file format details.""" + + def test_extension_forced_to_safetensors(self, tmp_path): + """save_checkpoint forces .safetensors extension.""" + path = save_checkpoint(_make_full_checkpoint(), tmp_path / "model.pth") + assert path.suffix == ".safetensors" + assert path.exists() + + def test_safetensors_extension_preserved(self, tmp_path): + """If .safetensors extension is already correct, it's preserved.""" + path = save_checkpoint(_make_full_checkpoint(), tmp_path / "model.safetensors") + assert path.suffix == ".safetensors" + + def test_load_nonexistent_raises(self, tmp_path): + """Loading a nonexistent file raises FileNotFoundError.""" + with pytest.raises(FileNotFoundError): + load_checkpoint(tmp_path / "nonexistent.safetensors") + + def test_shared_memory_tensors(self, tmp_path): + """Tensors that share memory (e.g. EMA copy) are saved without error.""" + shared = _make_state_dict(seed=0) + original = _make_full_checkpoint( + train_model=shared, + eval_model=shared, # same object, shares memory + ) + # Should not raise RuntimeError about shared tensors + path = save_checkpoint(original, tmp_path / "model") + loaded = load_checkpoint(path) + assert "train_model" in loaded + assert "eval_model" in loaded + + +class TestSecurity: + """Verify the format is safe against code execution attacks.""" + + def test_no_pickle_in_file(self, tmp_path): + """The saved file must not contain pickle opcodes.""" + path = save_checkpoint(_make_full_checkpoint(), tmp_path / "model") + data = path.read_bytes() + # Pickle protocol markers (0x80 = protocol 2+, 'cos\n' = protocol 0) + # safetensors files start with a little-endian u64 header size + assert not data[8:].startswith(b"\x80\x02") # not pickle protocol 2 + assert not data[8:].startswith(b"cos\n") # not pickle protocol 0 + + def test_metadata_is_plain_json(self, tmp_path): + """All metadata in the safetensors header is valid JSON strings.""" + import json + + from safetensors import safe_open + + path = save_checkpoint(_make_full_checkpoint(), tmp_path / "model") + with safe_open(str(path), framework="pt") as f: + metadata = f.metadata() + + for key, value in metadata.items(): + # Every metadata value must be a valid JSON string + parsed = json.loads(value) + assert parsed is not None or value == "null" diff --git a/tests/unit/test_session_io_handler.py b/tests/unit/test_session_io_handler.py index f0f4e1e..b9433fd 100644 --- a/tests/unit/test_session_io_handler.py +++ b/tests/unit/test_session_io_handler.py @@ -6,7 +6,6 @@ # the terms contained in the file 'LICENCE.txt'. 
import json -import pickle import shutil import tempfile from pathlib import Path @@ -101,39 +100,6 @@ def test_save_session_custom_path(self): assert save_path.exists() assert (save_path / "session_metadata.json").exists() - def test_save_model_checkpoint(self): - """Test saving model checkpoint.""" - model_state = {"weights": [1, 2, 3], "epoch": 10} - - checkpoint_path = self.io_handler.save_model_checkpoint(model_state, self.session_tracker) - - # Check checkpoint was saved - assert Path(checkpoint_path).exists() - assert "checkpoints" in checkpoint_path - assert checkpoint_path.endswith(".pkl") - - # Verify checkpoint content - with open(checkpoint_path, "rb") as f: - loaded_state = pickle.load(f) - assert loaded_state == model_state - - # Verify that session tracker was updated - check the last iteration - assert len(self.session_tracker.session_iterations) > 0 - last_iter = self.session_tracker.session_iterations[-1] - assert last_iter.model_state_path == checkpoint_path - - def test_save_model_checkpoint_custom_name(self): - """Test saving model checkpoint with custom name.""" - model_state = {"test": "data"} - custom_name = "custom_checkpoint.pkl" - - checkpoint_path = self.io_handler.save_model_checkpoint( - model_state, self.session_tracker, checkpoint_name=custom_name - ) - - assert checkpoint_path.endswith(custom_name) - assert Path(checkpoint_path).exists() - def test_load_session_complete_cycle(self): """Test complete save/load cycle.""" # First save a session @@ -221,7 +187,7 @@ def setup_method(self): session_tracker.add_labeled_sample("img1.jpg", "anomaly") session_tracker.add_labeled_sample("img2.jpg", "normal") session_tracker.update_test_performance({"AUROC": 0.92, "AUPRC": 0.88}) - session_tracker.update_model_state_path("models/final_model.pth") + session_tracker.update_model_state_path("models/final_model.safetensors") # Start second iteration session_tracker.start_new_session_iteration() @@ -347,15 +313,11 @@ def test_full_workflow_integration(self): tracker.update_model_iteration(0.5) tracker.add_labeled_sample("img4.jpg", "anomaly") tracker.update_test_performance({"AUROC": 0.93, "AUPRC": 0.89}) - tracker.update_model_state_path("models/best_model.pth") + tracker.update_model_state_path("models/best_model.safetensors") # Save session saved_path = self.io_handler.save_session(tracker) - # Save model checkpoint - model_state = {"epoch": 50, "weights": [1, 2, 3, 4]} - checkpoint_path = self.io_handler.save_model_checkpoint(model_state, tracker) - # Load session back loaded_tracker = self.io_handler.load_session(saved_path) @@ -365,12 +327,6 @@ def test_full_workflow_integration(self): assert len(loaded_tracker.get_labeled_data_df()) == 4 assert len(loaded_tracker.session_iterations) == 2 - # Check model checkpoint exists - assert Path(checkpoint_path).exists() - with open(checkpoint_path, "rb") as f: - loaded_model = pickle.load(f) - assert loaded_model == model_state - def test_multiple_sessions_management(self): """Test managing multiple sessions.""" # Create multiple sessions
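
Reviewer note: a minimal usage sketch of the new checkpoint_io API, round-tripping the checkpoint-dict layout this patch uses everywhere (key names match save_model/save_run; the tensor shapes, scalar values, and the /tmp path are illustrative only).

import torch
from fitsbolt.normalisation.NormalisationMethod import NormalisationMethod
from anomaly_match.data_io.checkpoint_io import load_checkpoint, save_checkpoint

state = {
    "train_model": {"layer.weight": torch.randn(4, 3)},  # tensors -> binary section
    "eval_model": {"layer.weight": torch.randn(4, 3)},
    "optimizer": None,
    "scheduler": None,
    "it": 42,
    "total_it": 100,
    "best_eval_acc": 0.95,
    "best_it": 80,
    "num_channels": 3,
    "net": "efficientnet-lite0",
    "normalisation_method": NormalisationMethod.CONVERSION_ONLY,  # enum -> tagged JSON metadata
    "last_normalisation_method": None,
    "fitsbolt_cfg": None,
}

path = save_checkpoint(state, "/tmp/model.pth")  # extension is forced to .safetensors
loaded = load_checkpoint(path)  # tensors land on CPU unless device= says otherwise
assert loaded["normalisation_method"] == NormalisationMethod.CONVERSION_ONLY
assert torch.equal(loaded["train_model"]["layer.weight"], state["train_model"]["layer.weight"])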
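Why _prepare_for_json walks the structure up-front instead of relying on the encoder's default() hook, shown standalone (the Method enum below is a stand-in for NormalisationMethod, nothing here needs fitsbolt): json.dumps only calls default() for values it cannot already serialize, and an IntEnum is an int, so the encoder alone would silently flatten enums to plain integers.

import json
from enum import IntEnum

class Method(IntEnum):  # stand-in for NormalisationMethod
    LOG = 1

def default(obj):
    raise AssertionError("never reached for IntEnum")

# The hook is bypassed: IntEnum goes through the plain-int serialization path.
print(json.dumps(Method.LOG, default=default))  # -> 1  (enum identity lost)

# Walking the structure first lets us tag the enum before json.dumps sees it.
pre_tagged = {"__enum__": type(Method.LOG).__name__, "name": Method.LOG.name}
print(json.dumps(pre_tagged))  # -> {"__enum__": "Method", "name": "LOG"}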
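The DotMap hazard that _nullify_empty_dicts (on save) and _dynamic=False (on load) guard against, sketched with the dotmap package as this patch uses it; channel_combination is the example key named in the comments above.

from dotmap import DotMap

cfg = DotMap({"size": [64, 64]})
_ = cfg.channel_combination  # dynamic DotMap fabricates and stores an empty child
print(cfg.toDict())          # {'size': [64, 64], 'channel_combination': {}}
# fitsbolt's validate_config expects None or np.ndarray here, so {} breaks reload;
# save_checkpoint therefore rewrites leaf-level {} to None before JSON-encoding
# the fitsbolt config.

safe = DotMap({"size": [64, 64]}, _dynamic=False)
# safe.channel_combination now raises instead of fabricating an empty child,
# so a reloaded config cannot re-acquire spurious empty dicts.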
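Note that the patch does not migrate legacy .pth files on disk (the checked-in test model appears to have been converted and renamed in the same change), and load_checkpoint only reads safetensors. For any remaining trusted legacy checkpoints, a one-off conversion along these lines should work; migrate_pth is a hypothetical helper, not part of this patch, and torch.load(weights_only=False) still unpickles, so it must only ever see trusted files.

from pathlib import Path

import torch

from anomaly_match.data_io.checkpoint_io import save_checkpoint

def migrate_pth(pth_path: str) -> Path:
    """Convert one *trusted* legacy torch checkpoint to safetensors (hypothetical)."""
    legacy = torch.load(pth_path, weights_only=False, map_location="cpu")
    # save_checkpoint reads only the known keys (train_model, eval_model, optimizer,
    # scheduler, the scalar block, fitsbolt_cfg, labeled_data_csv) and forces the
    # .safetensors extension, so passing the original path is fine.
    return save_checkpoint(legacy, pth_path)

# migrate_pth("runs/model.pth")  # writes runs/model.safetensors alongside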