Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 42 additions & 4 deletions spatialtissuepy/mcp/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@
from __future__ import annotations

import json
import os
import pickle
import shutil
import threading
import uuid
from dataclasses import dataclass, field
Expand Down Expand Up @@ -190,8 +193,7 @@ def store_data(
Data to store.
"""
path = self.base_dir / session_id / "data" / f"{key}.pkl"
with open(path, "wb") as f:
pickle.dump(data, f)
self._atomic_pickle_dump(data, path)

meta = self._load_metadata(session_id)
if meta and key not in meta.data_keys:
Expand Down Expand Up @@ -395,8 +397,7 @@ def store_panel(
) -> None:
"""Store a StatisticsPanel object."""
path = self.base_dir / session_id / "panels" / f"{key}.pkl"
with open(path, "wb") as f:
pickle.dump(panel, f)
self._atomic_pickle_dump(panel, path)

meta = self._load_metadata(session_id)
if meta and key not in meta.panel_keys:
Expand Down Expand Up @@ -509,6 +510,43 @@ def _touch_session(self, session_id: str) -> None:
meta.last_accessed = datetime.now().isoformat()
self._save_metadata(session_id, meta)

@staticmethod
def _atomic_pickle_dump(obj: Any, path: Path) -> None:
"""Pickle ``obj`` to ``path`` atomically.

Writes to a sibling temp file first and then uses ``os.replace`` to
atomically swap it into place. This prevents truncated / partial
pickle files when two tool calls update the same data or panel
concurrently -- the symptom was ``pickle.UnpicklingError: Ran out
of input`` on a subsequent read, because one writer had truncated
the file after another had started writing.
"""
path.parent.mkdir(parents=True, exist_ok=True)
# Unique per-call temp name: pid + thread id + uuid. Needed because
# two threads in the same process share the same pid, so a
# pid-only suffix would collide under in-process concurrency.
tmp_suffix = (
f".tmp.{os.getpid()}.{threading.get_ident()}.{uuid.uuid4().hex[:8]}"
)
tmp_path = path.with_suffix(path.suffix + tmp_suffix)
try:
with open(tmp_path, "wb") as f:
pickle.dump(obj, f)
f.flush()
try:
os.fsync(f.fileno())
except OSError:
# fsync is best-effort -- some filesystems reject it
pass
os.replace(tmp_path, path)
except Exception:
# Clean up the temp file so we don't leak partial writes
try:
tmp_path.unlink()
except OSError:
pass
raise

def get_session_info(self, session_id: str) -> Optional[Dict[str, Any]]:
"""Get detailed information about a session."""
meta = self._load_metadata(session_id)
Expand Down
58 changes: 58 additions & 0 deletions spatialtissuepy/summary/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,44 @@ def __repr__(self) -> str:
custom_str = ", custom=True" if self.is_custom else ""
return f"MetricInfo(name={self.name!r}, category={self.category!r}{custom_str})"

def __reduce__(self):
    """Pickle by registry name, not by function reference.

    ``register()`` stores the unwrapped metric function on ``func``
    while the decorator publishes a ``functools.wraps`` wrapper as the
    module attribute. Pickle resolves functions by ``__qualname__``,
    finds the wrapper at that name, and refuses to serialise the stored
    unwrapped function ("not the same object as ...") -- the root cause
    of the long-standing panel serialisation failure.

    Instead, a MetricInfo pickles as its *name* and is re-resolved from
    the global registry on unpickle. That is safe for any registered
    metric (built-in or custom) as long as the registry is populated in
    the unpickling process -- true for built-ins at import time. Custom
    metrics must be re-registered with ``register_custom_metric`` before
    loading; otherwise unpickling raises a clear error identifying the
    missing metric.

    Inline (per-panel) MetricInfo objects do not live in the registry
    and cannot be pickled -- they raise ``TypeError`` here, which is the
    correct behavior: panels with inline functions have always been
    flagged as non-JSON-serialisable, and pickling them would produce
    objects that can't be loaded in a fresh process.

    Raises:
        TypeError: if ``self.name`` is not present in the global registry.
    """
    # ``_resolve_metric_for_pickle`` is a module-level sibling defined in
    # this same file, so it is in scope as a global by the time anything
    # is pickled -- the previous in-method self-import
    # (``from spatialtissuepy.summary.registry import ...``) was
    # redundant and has been removed.
    registered = _registry._metrics.get(self.name)
    if registered is None:
        raise TypeError(
            f"MetricInfo {self.name!r} is not in the global registry "
            "and cannot be pickled. Inline metrics added via "
            "StatisticsPanel.add_custom_function() are per-panel only; "
            "register them globally with register_custom_metric() "
            "before saving the containing panel."
        )
    # NOTE(review): if an inline metric ever shadows a registered name,
    # this pickles the *registered* entry silently -- presumably names
    # are unique across inline and registered metrics; confirm.
    return (_resolve_metric_for_pickle, (self.name,))


def _validate_metric_function(
func: Callable,
Expand Down Expand Up @@ -869,6 +907,26 @@ def get_metric(name: str) -> MetricInfo:
return _registry.get(name)


def _resolve_metric_for_pickle(name: str) -> 'MetricInfo':
    """Unpickle hook for MetricInfo: look up the live registry entry by name.

    Kept at module level so pickle can import it by qualname. When the
    name is unknown in the current process (typically a custom metric
    that was not re-registered before loading a saved panel), a
    descriptive ``RuntimeError`` is raised with the original ``KeyError``
    chained as the cause.
    """
    try:
        info = _registry.get(name)
    except KeyError as exc:
        message = (
            f"Cannot unpickle MetricInfo {name!r}: no metric with that "
            "name is registered in the current process. Built-in metrics "
            "register automatically at import time; custom metrics must "
            "be re-registered with register_custom_metric() before "
            "loading a saved panel that references them."
        )
        raise RuntimeError(message) from exc
    return info


def list_metrics(
category: Optional[str] = None,
include_custom: bool = True
Expand Down
168 changes: 168 additions & 0 deletions tests/test_summary_panel_pickle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
"""
Regression tests for pickling StatisticsPanel / MetricInfo.

Background
----------
Panels used to fail to pickle with

PicklingError: Can't pickle <function cell_counts at 0x...>: it's not
the same object as spatialtissuepy.summary.population.cell_counts

because ``register()`` stores the raw ``func`` in ``MetricInfo`` while
returning a ``functools.wraps`` wrapper as the module attribute, so
pickle's qualname-based function lookup found a different object than the
one being serialised. Parallel writes to the same panel file also
produced truncated pickle blobs, surfacing on the next read as
``Ran out of input``.

The fix adds ``MetricInfo.__reduce__`` so panels pickle metric entries by
*name* and re-resolve from the live registry on unpickle, plus an atomic
``_atomic_pickle_dump`` helper in ``SessionManager`` that writes through
a temp file + ``os.replace`` so parallel writers cannot leave a partial
file on disk.
"""

from __future__ import annotations

import pickle
import threading
import time
from pathlib import Path
from typing import Dict

import pytest

from spatialtissuepy.summary import (
StatisticsPanel,
load_panel,
register_custom_metric,
unregister_custom_metric,
)
from spatialtissuepy.summary.registry import get_metric


class TestPanelPickle:
    """Panels should pickle cleanly regardless of how metrics were registered."""

    def test_builtin_panel_roundtrips(self):
        original = load_panel("comprehensive")
        assert original.n_metrics > 0

        clone = pickle.loads(pickle.dumps(original))

        assert clone.n_metrics == original.n_metrics
        assert [m.name for m in clone.metrics] == [m.name for m in original.metrics]

    @pytest.mark.parametrize("preset", ["basic", "spatial", "neighborhood", "comprehensive"])
    def test_all_preset_panels_pickle(self, preset: str):
        original = load_panel(preset)
        clone = pickle.loads(pickle.dumps(original))
        assert clone.n_metrics == original.n_metrics

    def test_registered_custom_metric_roundtrips(self):
        name = "test_pickle_ratio_unique"

        @register_custom_metric(name=name, description="Test custom metric")
        def _custom(data) -> Dict[str, float]:
            return {name: 0.5}

        # Unregister even if assertions fail, so the global registry
        # stays clean for other tests.
        try:
            panel = StatisticsPanel(name="test")
            panel.add(name)
            clone = pickle.loads(pickle.dumps(panel))
            assert clone.n_metrics == 1
            assert clone.metrics[0].name == name
        finally:
            unregister_custom_metric(name)

    def test_inline_metric_refuses_to_pickle(self):
        """Inline functions live only in the panel and can't survive a pickle round-trip."""
        panel = StatisticsPanel(name="inline")

        def _inline(data) -> Dict[str, float]:
            return {"v": 1.0}

        panel.add_custom_function("inline_unpicklable", _inline)

        with pytest.raises(TypeError, match="not in the global registry"):
            pickle.dumps(panel)

    def test_missing_registry_entry_fails_loudly_on_unpickle(self):
        """Unpickling a panel whose custom metric isn't re-registered raises a clear error."""
        name = "test_unpickle_missing_unique"

        @register_custom_metric(name=name, description="Test")
        def _m(data) -> Dict[str, float]:
            return {name: 1.0}

        panel = StatisticsPanel(name="test")
        panel.add(name)
        frozen = pickle.dumps(panel)

        # Drop the registration *after* pickling, then try to load.
        unregister_custom_metric(name)
        with pytest.raises(RuntimeError, match="no metric with that name is registered"):
            pickle.loads(frozen)


class TestAtomicPickleDump:
    """Concurrent writes must never leave a truncated pickle on disk."""

    def test_atomic_write_survives_concurrent_writers(self, tmp_path: Path):
        from spatialtissuepy.mcp.session import SessionManager

        target = tmp_path / "payload.pkl"
        payloads = [{"writer": i, "data": list(range(1000))} for i in range(8)]

        workers = [
            threading.Thread(
                target=SessionManager._atomic_pickle_dump,
                args=(payload, target),
            )
            for payload in payloads
        ]
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()

        # Whichever writer won, the result must be a valid, fully-written pickle
        loaded = pickle.loads(target.read_bytes())
        assert loaded["writer"] in {p["writer"] for p in payloads}
        assert loaded["data"] == list(range(1000))

    def test_failed_write_leaves_no_partial_file(self, tmp_path: Path):
        from spatialtissuepy.mcp.session import SessionManager

        target = tmp_path / "will_fail.pkl"

        class Unpicklable:
            def __reduce__(self):
                raise RuntimeError("intentional failure")

        with pytest.raises(RuntimeError, match="intentional"):
            SessionManager._atomic_pickle_dump(Unpicklable(), target)

        assert not target.exists(), "failed write should not have produced a file"
        # Temp file should also be cleaned up
        siblings = list(tmp_path.iterdir())
        assert all("tmp" not in entry.name for entry in siblings), f"temp file leaked: {siblings}"


class TestPanelSessionRoundtrip:
    """End-to-end: SessionManager can store and retrieve a panel with metrics."""

    def test_store_and_load_builtin_panel(self, tmp_path: Path):
        from spatialtissuepy.mcp.session import SessionManager

        manager = SessionManager(base_dir=tmp_path)
        session_id = manager.create_session()

        original = load_panel("basic")
        manager.store_panel(session_id, "fingerprint", original)

        roundtripped = manager.load_panel(session_id, "fingerprint")
        assert roundtripped is not None
        assert roundtripped.n_metrics == original.n_metrics
        assert [m.name for m in roundtripped.metrics] == [m.name for m in original.metrics]
Loading