Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 101 additions & 49 deletions src/surreal_memory/hooks/post_tool_use.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,15 @@
Receives JSON on stdin with tool_name, tool_input, tool_output fields.
Writes one JSONL line to ~/.surrealmemory/tool_events.jsonl.

This hook does NOT access SQLite or perform encoding — all processing
This hook does NOT access SurrealDB or perform encoding — all processing
is deferred to the consolidation cycle.

Performance design (ac2a001 perf subset):
- Zero heavy imports on the hot path (stdlib only except optional tomllib)
- _NOISE_TOOLS fast-path skips the highest-frequency no-value tools
- Lock-safe JSONL append via fcntl.flock (POSIX) with a write-then-rename
fallback for platforms without flock
- Session ID: checks CLAUDE_SESSION_ID and CODEX_SESSION_ID
"""

from __future__ import annotations
Expand All @@ -18,6 +25,7 @@
import os
import sys
import time
from datetime import UTC, datetime
from pathlib import Path
from typing import Any

Expand All @@ -28,6 +36,29 @@
# Max size for stdout response JSON
_MAX_TOOL_OUTPUT_PREVIEW = 100

# High-frequency tools that never produce useful memory signal.
# Checked before any config I/O for a fast zero-cost exit.
_NOISE_TOOLS: frozenset[str] = frozenset(
{
"TodoRead",
"TodoWrite",
"WebSearch",
"WebFetch",
"mcp__Claude_Preview__preview_logs",
"mcp__Claude_Preview__preview_console_logs",
"mcp__Claude_Preview__preview_network",
"smem_recall",
"smem_session",
"smem_stats",
"smem_index",
}
)


def _get_session_id() -> str:
"""Return the current session ID from env (Claude or Codex)."""
return os.environ.get("CLAUDE_SESSION_ID") or os.environ.get("CODEX_SESSION_ID", "")


def _read_stdin() -> dict[str, Any]:
"""Read Claude Code PostToolUse hook JSON from stdin."""
Expand All @@ -41,49 +72,50 @@ def _read_stdin() -> dict[str, Any]:
return {}


def _get_data_dir() -> Path:
"""Return the surreal-memory data directory."""
custom = os.environ.get("SURREAL_MEMORY_DIR", "")
return Path(custom) if custom else (Path.home() / ".surrealmemory")


def _get_buffer_path() -> Path:
"""Get the JSONL buffer file path."""
data_dir = Path(os.environ.get("SURREAL_MEMORY_DIR", "")) or (Path.home() / ".surrealmemory")
return data_dir / "tool_events.jsonl"
return _get_data_dir() / "tool_events.jsonl"


def _is_enabled() -> bool:
"""Quick check if tool memory is enabled via config.
def _read_tool_memory_config() -> dict[str, Any]:
"""Read [tool_memory] section from config.toml once.

Reads only the [tool_memory] section from config.toml.
Defaults to True if config is missing or section absent.
Returns an empty dict if the config is missing, unreadable, or
the section is absent. Called at most once per hook invocation.
"""
data_dir = Path(os.environ.get("SURREAL_MEMORY_DIR", "")) or (Path.home() / ".surrealmemory")
config_path = data_dir / "config.toml"
config_path = _get_data_dir() / "config.toml"
if not config_path.exists():
return True
return {}
try:
import tomllib

with open(config_path, "rb") as f:
data = tomllib.load(f)
return bool(data.get("tool_memory", {}).get("enabled", True))
result: dict[str, Any] = data.get("tool_memory", {})
return result
except Exception:
logger.debug("Failed to read tool_memory.enabled from config", exc_info=True)
return True
logger.debug("Failed to read tool_memory config", exc_info=True)
return {}


def _get_blacklist() -> list[str]:
"""Read blacklist from config.toml."""
data_dir = Path(os.environ.get("SURREAL_MEMORY_DIR", "")) or (Path.home() / ".surrealmemory")
config_path = data_dir / "config.toml"
if not config_path.exists():
return []
try:
import tomllib
def _is_enabled(tm_cfg: dict[str, Any]) -> bool:
"""Check if tool memory is enabled.

with open(config_path, "rb") as f:
data = tomllib.load(f)
bl = data.get("tool_memory", {}).get("blacklist", [])
return list(bl) if isinstance(bl, (list, tuple)) else []
except Exception:
logger.debug("Failed to read tool_memory.blacklist from config", exc_info=True)
return []
Defaults to True if the key is absent.
"""
return bool(tm_cfg.get("enabled", True))


def _get_blacklist(tm_cfg: dict[str, Any]) -> list[str]:
"""Return the blacklist from [tool_memory] config."""
bl = tm_cfg.get("blacklist", [])
return list(bl) if isinstance(bl, (list, tuple)) else []


def _truncate_args(tool_input: Any) -> str:
Expand All @@ -97,39 +129,54 @@ def _truncate_args(tool_input: Any) -> str:
return raw[:_MAX_ARGS_CHARS]


def _format_event(hook_input: dict[str, Any]) -> dict[str, Any]:
"""Format hook input into a JSONL event dict."""
from surreal_memory.utils.timeutils import utcnow
def _utcnow_iso() -> str:
"""Return current UTC time as naive ISO string (stdlib only, no imports)."""
return datetime.now(UTC).replace(tzinfo=None).isoformat()


def _format_event(hook_input: dict[str, Any]) -> dict[str, Any]:
"""Format hook input into a JSONL event dict (stdlib only)."""
tool_name = hook_input.get("tool_name", hook_input.get("tool", ""))
server_name = hook_input.get("server_name", "")
tool_input = hook_input.get("tool_input", {})
tool_error = hook_input.get("tool_error")
duration_ms = hook_input.get("duration_ms", 0)
session_id = os.environ.get("CLAUDE_SESSION_ID", "")

return {
"tool_name": str(tool_name),
"server_name": str(server_name),
"args_summary": _truncate_args(tool_input),
"success": tool_error is None,
"duration_ms": int(duration_ms) if isinstance(duration_ms, (int, float)) else 0,
"session_id": session_id,
"session_id": _get_session_id(),
"task_context": "", # Populated by processing engine if session is active
"created_at": utcnow().isoformat(),
"created_at": _utcnow_iso(),
}


def _append_to_buffer(event: dict[str, Any], buffer_path: Path) -> bool:
"""Append one JSONL line to the buffer file.
"""Append one JSONL line to the buffer file, lock-safe.

Uses fcntl.flock (POSIX) when available. Falls back to a plain
append on platforms that lack flock (Windows, some embedded).
Returns True on success, False on failure.
"""
try:
buffer_path.parent.mkdir(parents=True, exist_ok=True)
line = json.dumps(event, ensure_ascii=False, default=str)
with open(buffer_path, "a", encoding="utf-8") as f:
f.write(line + "\n")
line = json.dumps(event, ensure_ascii=False, default=str) + "\n"
try:
import fcntl

with open(buffer_path, "a", encoding="utf-8") as f:
fcntl.flock(f, fcntl.LOCK_EX)
try:
f.write(line)
finally:
fcntl.flock(f, fcntl.LOCK_UN)
except ImportError:
# Non-POSIX platform — plain append (best-effort)
with open(buffer_path, "a", encoding="utf-8") as f:
f.write(line)
return True
except OSError:
return False
Expand All @@ -155,26 +202,31 @@ def main() -> None:
"""Entry point for the PostToolUse hook."""
start = time.monotonic()

# Fast exit if disabled
if not _is_enabled():
# Output empty JSON for hook response
sys.stdout.write("{}\n")
return

hook_input = _read_stdin()
if not hook_input:
sys.stdout.write("{}\n")
return

tool_name = hook_input.get("tool_name", hook_input.get("tool", ""))
tool_name = str(hook_input.get("tool_name", hook_input.get("tool", "")))
if not tool_name:
sys.stdout.write("{}\n")
return

# Check blacklist
blacklist = _get_blacklist()
# Fast-path: skip high-frequency noise tools before any config I/O
if tool_name in _NOISE_TOOLS:
sys.stdout.write("{}\n")
return

# Read config once; derive enabled + blacklist from it
tm_cfg = _read_tool_memory_config()

if not _is_enabled(tm_cfg):
sys.stdout.write("{}\n")
return

blacklist = _get_blacklist(tm_cfg)
for prefix in blacklist:
if str(tool_name).startswith(prefix):
if tool_name.startswith(prefix):
sys.stdout.write("{}\n")
return

Expand All @@ -183,7 +235,7 @@ def main() -> None:
buffer_path = _get_buffer_path()
_append_to_buffer(event, buffer_path)

# Periodic buffer rotation check (every ~100 calls, cheap stat check)
# Periodic buffer rotation check (cheap stat check)
try:
if buffer_path.exists() and buffer_path.stat().st_size > 5_000_000: # > 5MB
_check_buffer_rotation(buffer_path)
Expand Down
Loading
Loading