Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions .vulture_whitelist.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
"""

# SessionIOHandler methods - public API used in tests
save_model_checkpoint # noqa - Used in test_session_io_handler.py, test_model_io_integration.py
load_model_checkpoint # noqa - Used in test_model_io_integration.py
list_sessions # noqa - Used in test_session_io_handler.py
save_run # noqa - Used in test_run_label_migration.py
save_labels_to_output_dir # noqa - Used in test_run_label_migration.py
Expand Down
88 changes: 11 additions & 77 deletions anomaly_match/data_io/SessionIOHandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,13 @@

import json
import os
import pickle
from pathlib import Path
from typing import Any, Dict, List, Optional

import pandas as pd
import torch
from loguru import logger

from anomaly_match.data_io.checkpoint_io import load_checkpoint, save_checkpoint
from anomaly_match.data_io.save_config import save_config_toml
from anomaly_match.pipeline.SessionTracker import IterationInfo, SessionTracker

Expand Down Expand Up @@ -185,43 +184,6 @@ def save_iteration_scores(
except Exception as e:
logger.warning(f"Failed to save test scores: {e}")

def save_model_checkpoint(
    self,
    model_state: Dict[str, Any],
    session_tracker: SessionTracker,
    checkpoint_name: Optional[str] = None,
) -> str:
    """
    Save a model checkpoint within the session directory.

    The checkpoint is written under ``<session_dir>/checkpoints/`` and the
    session tracker is updated to point at the new file.

    Args:
        model_state: Model state dictionary to save.
        session_tracker: Associated session tracker; also supplies the
            iteration counter used for the default checkpoint name.
        checkpoint_name: Optional custom checkpoint file name. Defaults to
            ``model_iter_<total_model_iterations>.pkl``.

    Returns:
        Path (as a string) to the saved checkpoint file.
    """
    save_path = self.get_session_save_path(session_tracker)
    save_path.mkdir(parents=True, exist_ok=True)

    checkpoints_dir = save_path / "checkpoints"
    checkpoints_dir.mkdir(exist_ok=True)

    if checkpoint_name is None:
        checkpoint_name = f"model_iter_{session_tracker.total_model_iterations}.pkl"

    checkpoint_path = checkpoints_dir / checkpoint_name

    # NOTE(review): pickle is not a safe format for untrusted files — anyone
    # loading this checkpoint must trust its origin. Consider migrating to a
    # data-only format (e.g. safetensors) as done elsewhere in this module.
    with open(checkpoint_path, "wb") as f:
        pickle.dump(model_state, f)

    # Update the session tracker with the checkpoint path
    session_tracker.update_model_state_path(str(checkpoint_path))

    logger.debug(f"Saved model checkpoint to: {checkpoint_path}")
    return str(checkpoint_path)

def save_model(self, model, cfg, session_tracker: SessionTracker = None) -> str:
"""
Save the model to the session directory if session_tracker is available,
Expand All @@ -246,7 +208,7 @@ def save_model(self, model, cfg, session_tracker: SessionTracker = None) -> str:
if session_tracker.session_iterations
else 0
)
model_filename = f"model_iteration_{iteration_num}.pth"
model_filename = f"model_iteration_{iteration_num}.safetensors"
model_path = save_path / model_filename
else:
if cfg.model_path is None:
Expand Down Expand Up @@ -287,8 +249,9 @@ def save_model(self, model, cfg, session_tracker: SessionTracker = None) -> str:
"fitsbolt_cfg": fitsbolt_cfg,
}

# Save model
torch.save(save_state, model_path)
# Save model (save_checkpoint forces .safetensors extension)
save_checkpoint(save_state, model_path)
model_path = Path(model_path).with_suffix(".safetensors")

if session_tracker is not None:
# Ensure there's an active session iteration
Expand Down Expand Up @@ -331,7 +294,7 @@ def load_model(self, model, cfg, model_path: str = None) -> bool:

try:
# Load checkpoint
checkpoint = torch.load(load_path, weights_only=False)
checkpoint = load_checkpoint(load_path)

# Handle distributed training case
train_model = (
Expand Down Expand Up @@ -426,37 +389,6 @@ def load_model(self, model, cfg, model_path: str = None) -> bool:
logger.error(f"Failed to load model from {load_path}: {e}")
return False

def load_model_checkpoint(self, checkpoint_path: str) -> Optional[Dict[str, Any]]:
    """
    Load a model checkpoint from the specified path.

    Tries the pickle format first (the format this handler writes), then
    falls back to a CPU-mapped ``torch.load`` for legacy checkpoints.

    Args:
        checkpoint_path: Path to the checkpoint file

    Returns:
        Dictionary containing the checkpoint data, or None if loading failed
    """
    try:
        # Guard clause: nothing to do if the file is not there.
        if not os.path.exists(checkpoint_path):
            logger.error(f"Checkpoint path does not exist: {checkpoint_path}")
            return None

        try:
            # Preferred path: checkpoints written by this handler are pickles.
            with open(checkpoint_path, "rb") as fh:
                state = pickle.load(fh)
            logger.debug(f"Loaded checkpoint from pickle format: {checkpoint_path}")
        except (pickle.UnpicklingError, EOFError):
            # Legacy path: older checkpoints were saved via torch.save.
            state = torch.load(checkpoint_path, weights_only=False, map_location="cpu")
            logger.debug(f"Loaded checkpoint from torch format: {checkpoint_path}")

        return state

    except Exception as e:
        # Best-effort contract: any failure is logged and reported as None.
        logger.error(f"Failed to load checkpoint from {checkpoint_path}: {e}")
        return None

def load_session(self, session_path: Path) -> SessionTracker:
"""
Load a session from disk.
Expand Down Expand Up @@ -611,7 +543,9 @@ def save_run(
"fitsbolt_cfg": fitsbolt_cfg,
}

torch.save(save_state, save_filename)
save_checkpoint(save_state, save_filename)
# save_checkpoint forces .safetensors extension; update save_filename to match
save_filename = str(Path(save_filename).with_suffix(".safetensors"))

# Update session tracker if provided
if session_tracker is not None:
Expand Down Expand Up @@ -706,7 +640,7 @@ def update_config_paths_for_session(self, cfg, session_tracker: SessionTracker)

# Update model path to session directory only if not already set by user
if cfg.model_path is None:
cfg.model_path = str(session_path / "model.pth")
cfg.model_path = str(session_path / "model.safetensors")

# Update output directory to session directory
cfg.output_dir = str(session_path)
Expand Down Expand Up @@ -805,7 +739,7 @@ def print_session(filepath: str) -> None:

checkpoints_dir = session_path / "checkpoints"
if checkpoints_dir.exists():
checkpoints = list(checkpoints_dir.glob("*.pkl"))
checkpoints = list(checkpoints_dir.glob("*.safetensors"))
print(f"✓ {len(checkpoints)} model checkpoint(s)")

print("=" * 60)
Expand Down
Loading
Loading