diff --git a/spatialtissuepy/mcp/session.py b/spatialtissuepy/mcp/session.py
index b596948..0eba50e 100644
--- a/spatialtissuepy/mcp/session.py
+++ b/spatialtissuepy/mcp/session.py
@@ -20,7 +20,9 @@ from __future__ import annotations
 
 import json
+import os
 import pickle
+import threading
 import shutil
 import uuid
 from dataclasses import dataclass, field
@@ -190,8 +192,7 @@ def store_data(
             Data to store.
         """
         path = self.base_dir / session_id / "data" / f"{key}.pkl"
-        with open(path, "wb") as f:
-            pickle.dump(data, f)
+        self._atomic_pickle_dump(data, path)
 
         meta = self._load_metadata(session_id)
         if meta and key not in meta.data_keys:
@@ -395,8 +396,7 @@ def store_panel(
     ) -> None:
         """Store a StatisticsPanel object."""
         path = self.base_dir / session_id / "panels" / f"{key}.pkl"
-        with open(path, "wb") as f:
-            pickle.dump(panel, f)
+        self._atomic_pickle_dump(panel, path)
 
         meta = self._load_metadata(session_id)
         if meta and key not in meta.panel_keys:
@@ -509,6 +509,43 @@ def _touch_session(self, session_id: str) -> None:
         meta.last_accessed = datetime.now().isoformat()
         self._save_metadata(session_id, meta)
 
+    @staticmethod
+    def _atomic_pickle_dump(obj: Any, path: Path) -> None:
+        """Pickle ``obj`` to ``path`` atomically.
+
+        Writes to a sibling temp file first and then uses ``os.replace`` to
+        atomically swap it into place. This prevents truncated / partial
+        pickle files when two tool calls update the same data or panel
+        concurrently -- the symptom was ``pickle.UnpicklingError: Ran out
+        of input`` on a subsequent read, because one writer had truncated
+        the file after another had started writing.
+        """
+        path.parent.mkdir(parents=True, exist_ok=True)
+        # Unique per-call temp name: pid + thread id + uuid. Needed because
+        # two threads in the same process share the same pid, so a
+        # pid-only suffix would collide under in-process concurrency.
+        tmp_suffix = (
+            f".tmp.{os.getpid()}.{threading.get_ident()}.{uuid.uuid4().hex[:8]}"
+        )
+        tmp_path = path.with_suffix(path.suffix + tmp_suffix)
+        try:
+            with open(tmp_path, "wb") as f:
+                pickle.dump(obj, f)
+                f.flush()
+                try:
+                    os.fsync(f.fileno())
+                except OSError:
+                    # fsync is best-effort -- some filesystems reject it
+                    pass
+            os.replace(tmp_path, path)
+        except Exception:
+            # Clean up the temp file so we don't leak partial writes
+            try:
+                tmp_path.unlink()
+            except OSError:
+                pass
+            raise
+
     def get_session_info(self, session_id: str) -> Optional[Dict[str, Any]]:
         """Get detailed information about a session."""
         meta = self._load_metadata(session_id)
diff --git a/spatialtissuepy/summary/registry.py b/spatialtissuepy/summary/registry.py
index 6cc6ae0..e8626ad 100644
--- a/spatialtissuepy/summary/registry.py
+++ b/spatialtissuepy/summary/registry.py
@@ -103,6 +103,42 @@ def __repr__(self) -> str:
         custom_str = ", custom=True" if self.is_custom else ""
         return f"MetricInfo(name={self.name!r}, category={self.category!r}{custom_str})"
 
+    def __reduce__(self):
+        """Pickle by registry name, not by function reference.
+
+        The ``func`` attribute is set during registration to the unwrapped
+        metric function, but the decorator returns a ``functools.wraps``
+        wrapper as the module attribute. Pickle resolves functions by
+        ``__qualname__``, so it finds the wrapper at that name and refuses
+        to serialise the stored unwrapped function ("not the same object as
+        ...") -- which is the root cause of the long-standing panel
+        serialisation failure.
+
+        Instead, we pickle a MetricInfo as its name and re-resolve from the
+        global registry on unpickle. That's safe for any registered metric
+        (built-in or custom) as long as the registry is populated in the
+        unpickling process -- which it is for built-ins at import time.
+        Custom metrics must be re-registered with ``register_custom_metric``
+        before loading; otherwise unpickling raises a clear error
+        identifying which metric is missing.
+
+        Inline (per-panel) MetricInfo objects do not live in the registry
+        and cannot be pickled -- they raise a ``TypeError`` here, which is
+        the correct behavior: panels with inline functions have always been
+        flagged as non-JSON-serialisable, and pickling them would produce
+        objects that can't be loaded in a fresh process.
+        """
+        registered = _registry._metrics.get(self.name)
+        if registered is None:
+            raise TypeError(
+                f"MetricInfo {self.name!r} is not in the global registry "
+                "and cannot be pickled. Inline metrics added via "
+                "StatisticsPanel.add_custom_function() are per-panel only; "
+                "register them globally with register_custom_metric() "
+                "before saving the containing panel."
+            )
+        return (_resolve_metric_for_pickle, (self.name,))
+
 
 def _validate_metric_function(
     func: Callable,
@@ -869,6 +905,26 @@ def get_metric(name: str) -> MetricInfo:
     return _registry.get(name)
 
 
+def _resolve_metric_for_pickle(name: str) -> 'MetricInfo':
+    """Unpickle hook for MetricInfo: fetch the live registry entry by name.
+
+    Raises a descriptive error when the name is not registered in the
+    current process (e.g. a custom metric that was not re-registered
+    before loading a saved panel). Kept as a module-level function so
+    pickle can import it by qualname.
+    """
+    try:
+        return _registry.get(name)
+    except KeyError as exc:
+        raise RuntimeError(
+            f"Cannot unpickle MetricInfo {name!r}: no metric with that "
+            "name is registered in the current process. Built-in metrics "
+            "register automatically at import time; custom metrics must "
+            "be re-registered with register_custom_metric() before "
+            "loading a saved panel that references them."
+        ) from exc
+
+
 def list_metrics(
     category: Optional[str] = None,
     include_custom: bool = True
diff --git a/tests/test_summary_panel_pickle.py b/tests/test_summary_panel_pickle.py
new file mode 100644
index 0000000..5ee3ff0
--- /dev/null
+++ b/tests/test_summary_panel_pickle.py
@@ -0,0 +1,166 @@
+"""
+Regression tests for pickling StatisticsPanel / MetricInfo.
+
+Background
+----------
+Panels used to fail to pickle with
+
+    PicklingError: Can't pickle <function cell_counts at 0x...>: it's not
+    the same object as spatialtissuepy.summary.population.cell_counts
+
+because ``register()`` stores the raw ``func`` in ``MetricInfo`` while
+returning a ``functools.wraps`` wrapper as the module attribute, so
+pickle's qualname-based function lookup found a different object than the
+one being serialised. Parallel writes to the same panel file also
+produced truncated pickle blobs, surfacing on the next read as
+``Ran out of input``.
+
+The fix adds ``MetricInfo.__reduce__`` so panels pickle metric entries by
+*name* and re-resolve from the live registry on unpickle, plus an atomic
+``_atomic_pickle_dump`` helper in ``SessionManager`` that writes through
+a temp file + ``os.replace`` so parallel writers cannot leave a partial
+file on disk.
+"""
+
+from __future__ import annotations
+
+import pickle
+import threading
+from pathlib import Path
+from typing import Dict
+
+import pytest
+
+from spatialtissuepy.summary import (
+    StatisticsPanel,
+    load_panel,
+    register_custom_metric,
+    unregister_custom_metric,
+)
+
+
+class TestPanelPickle:
+    """Panels should pickle cleanly regardless of how metrics were registered."""
+
+    def test_builtin_panel_roundtrips(self):
+        panel = load_panel("comprehensive")
+        assert panel.n_metrics > 0
+
+        blob = pickle.dumps(panel)
+        restored = pickle.loads(blob)
+
+        assert restored.n_metrics == panel.n_metrics
+        assert [m.name for m in restored.metrics] == [m.name for m in panel.metrics]
+
+    @pytest.mark.parametrize("preset", ["basic", "spatial", "neighborhood", "comprehensive"])
+    def test_all_preset_panels_pickle(self, preset: str):
+        panel = load_panel(preset)
+        blob = pickle.dumps(panel)
+        restored = pickle.loads(blob)
+        assert restored.n_metrics == panel.n_metrics
+
+    def test_registered_custom_metric_roundtrips(self):
+        name = "test_pickle_ratio_unique"
+
+        @register_custom_metric(name=name, description="Test custom metric")
+        def _custom(data) -> Dict[str, float]:
+            return {name: 0.5}
+
+        try:
+            panel = StatisticsPanel(name="test")
+            panel.add(name)
+            blob = pickle.dumps(panel)
+            restored = pickle.loads(blob)
+            assert restored.n_metrics == 1
+            assert restored.metrics[0].name == name
+        finally:
+            unregister_custom_metric(name)
+
+    def test_inline_metric_refuses_to_pickle(self):
+        """Inline functions live only in the panel and can't survive a pickle round-trip."""
+        panel = StatisticsPanel(name="inline")
+
+        def _inline(data) -> Dict[str, float]:
+            return {"v": 1.0}
+
+        panel.add_custom_function("inline_unpicklable", _inline)
+
+        with pytest.raises(TypeError, match="not in the global registry"):
+            pickle.dumps(panel)
+
+    def test_missing_registry_entry_fails_loudly_on_unpickle(self):
+        """Unpickling a panel whose custom metric isn't re-registered raises a clear error."""
+        name = "test_unpickle_missing_unique"
+
+        @register_custom_metric(name=name, description="Test")
+        def _m(data) -> Dict[str, float]:
+            return {name: 1.0}
+
+        panel = StatisticsPanel(name="test")
+        panel.add(name)
+        blob = pickle.dumps(panel)
+
+        unregister_custom_metric(name)
+        with pytest.raises(RuntimeError, match="no metric with that name is registered"):
+            pickle.loads(blob)
+
+
+class TestAtomicPickleDump:
+    """Concurrent writes must never leave a truncated pickle on disk."""
+
+    def test_atomic_write_survives_concurrent_writers(self, tmp_path: Path):
+        from spatialtissuepy.mcp.session import SessionManager
+
+        path = tmp_path / "payload.pkl"
+        payloads = [
+            {"writer": i, "data": list(range(1000))} for i in range(8)
+        ]
+
+        def write(p):
+            SessionManager._atomic_pickle_dump(p, path)
+
+        threads = [threading.Thread(target=write, args=(p,)) for p in payloads]
+        for t in threads:
+            t.start()
+        for t in threads:
+            t.join()
+
+        # Whichever writer won, the result must be a valid, fully-written pickle
+        loaded = pickle.loads(path.read_bytes())
+        assert loaded["writer"] in {p["writer"] for p in payloads}
+        assert loaded["data"] == list(range(1000))
+
+    def test_failed_write_leaves_no_partial_file(self, tmp_path: Path):
+        from spatialtissuepy.mcp.session import SessionManager
+
+        path = tmp_path / "will_fail.pkl"
+
+        class Unpicklable:
+            def __reduce__(self):
+                raise RuntimeError("intentional failure")
+
+        with pytest.raises(RuntimeError, match="intentional"):
+            SessionManager._atomic_pickle_dump(Unpicklable(), path)
+
+        assert not path.exists(), "failed write should not have produced a file"
+        # Temp file should also be cleaned up
+        siblings = list(tmp_path.iterdir())
+        assert all("tmp" not in s.name for s in siblings), f"temp file leaked: {siblings}"
+
+
+class TestPanelSessionRoundtrip:
+    """End-to-end: SessionManager can store and retrieve a panel with metrics."""
+
+    def test_store_and_load_builtin_panel(self, tmp_path: Path):
+        from spatialtissuepy.mcp.session import SessionManager
+
+        mgr = SessionManager(base_dir=tmp_path)
+        sid = mgr.create_session()
+
+        panel = load_panel("basic")
+        mgr.store_panel(sid, "fingerprint", panel)
+
+        restored = mgr.load_panel(sid, "fingerprint")
+        assert restored is not None
+        assert restored.n_metrics == panel.n_metrics
+        assert [m.name for m in restored.metrics] == [m.name for m in panel.metrics]