diff --git a/.coverage b/.coverage index 4d6d4017a..dd581c215 100644 Binary files a/.coverage and b/.coverage differ diff --git a/docs/architecture/GRAPHITI_ALIGNMENT_SPECS.md b/docs/architecture/GRAPHITI_ALIGNMENT_SPECS.md new file mode 100644 index 000000000..d1660e748 --- /dev/null +++ b/docs/architecture/GRAPHITI_ALIGNMENT_SPECS.md @@ -0,0 +1,762 @@ +# Graphiti Alignment Specs — 5 Sanity-Check Gaps + +> Generated from the failing items in the Graphiti Memory Brief (§4 sanity checklist). +> Each spec targets one `[ ]` item and brings it to `[x]`. + +--- + +## SPEC-1: Community / Summary Layer + +**Checklist item**: §4.1 — *Community/summary layer — Graphiti supports community detection for high-level summaries. Not implemented.* + +### Problem + +Graphiti's Tier 3 (Community Subgraph) clusters strongly-connected entities and produces high-level summaries via label propagation. SynapseFlow never consumes these summaries, so for patients with long conversation histories there is no aggregated view — every retrieval must assemble context from individual episodes and entities. + +### Goal + +Expose community summaries from the episodic graph and make them available as an optional context source in `get_conversation_context()`. + +### Design + +``` +EpisodicMemoryService +├── get_conversation_context() ← existing +│ └── community_summaries: [...] 
← NEW field in returned dict +│ +├── get_community_summaries() ← NEW method +│ ├── calls graphiti_core community API +│ ├── filters by group_id (patient isolation) +│ └── returns List[CommunitySummary] +│ +└── CommunitySummary (dataclass) ← NEW + ├── community_id: str + ├── summary: str + ├── entity_count: int + ├── key_entities: List[str] + └── updated_at: datetime +``` + +### Files to Change + +| File | Change | +|------|--------| +| `src/application/services/episodic_memory_service.py` | Add `CommunitySummary` dataclass, `get_community_summaries()` method, extend `get_conversation_context()` return dict | +| `tests/test_episodic_memory.py` | Add tests for `get_community_summaries()` and extended context | + +### Implementation Details + +**`CommunitySummary` dataclass** — add after `ConversationEpisode`: + +```python +@dataclass +class CommunitySummary: + """A community-level summary from Graphiti's community subgraph.""" + community_id: str + summary: str + entity_count: int + key_entities: List[str] + updated_at: Optional[datetime] = None +``` + +**`get_community_summaries()`** — add to `EpisodicMemoryService`: + +```python +async def get_community_summaries( + self, + patient_id: str, + limit: int = 5, +) -> List[CommunitySummary]: +``` + +Implementation strategy: +1. Call `search()` with the `COMBINED_HYBRID_SEARCH_CROSS_ENCODER` recipe, which already returns `results.communities` (a list of `CommunityNode` objects per Graphiti's `SearchResults` schema). +2. Filter by `group_id` matching `patient_id`. +3. Convert to `CommunitySummary` dataclasses. +4. If `results.communities` is empty or not present (version dependent), fall back to returning `[]` with a debug log — this keeps the feature gracefully degradable. + +**Extend `get_conversation_context()`**: + +Add a `community_summaries` key to the returned dict, populated by calling `get_community_summaries()`. 
This is appended *after* the existing recent/related/entities retrieval so it does not block core context assembly on failure. + +```python +# In get_conversation_context(), after entities retrieval: +community_summaries = [] +try: + community_summaries = await self.get_community_summaries( + patient_id=patient_id, + limit=3, + ) +except Exception as e: + logger.debug(f"Community summaries unavailable: {e}") + +return { + "recent_episodes": ..., + "related_episodes": ..., + "entities": entities, + "community_summaries": [self._summary_to_dict(s) for s in community_summaries], + "total_context_items": len(recent) + len(related) + len(entities) + len(community_summaries), +} +``` + +### Test Plan + +| Test | Asserts | +|------|---------| +| `test_get_community_summaries_returns_results` | Mock `search()` returning community nodes → service returns `CommunitySummary` list | +| `test_get_community_summaries_empty_graceful` | Mock `search()` returning no communities → returns `[]` | +| `test_get_community_summaries_error_graceful` | Mock `search()` raising → returns `[]`, no exception propagated | +| `test_context_includes_community_summaries` | `get_conversation_context()` return dict has `community_summaries` key | + +### Acceptance Criteria + +- `get_conversation_context()` response includes `community_summaries` field. +- If Graphiti version doesn't expose communities, the field is `[]` and no error is raised. +- Community summaries are filtered by `patient_id` group — no cross-patient leakage. + +--- + +## SPEC-2: Temporal Conflict Resolution + +**Checklist item**: §4.2 — *Temporal conflict resolution — Graphiti invalidates outdated edges via bi-temporal model, but this is not propagated to Neo4j.* + +### Problem + +Graphiti edges carry four timestamps (`created_at`, `expired_at`, `valid_at`, `invalid_at`). When new information contradicts a prior fact, Graphiti invalidates the old edge by setting `expired_at` — it never deletes it. 
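The four timestamps make the invalidate-don't-delete behavior concrete. Below is an illustrative sketch (field values invented for the example) of a superseded edge and its replacement under the bi-temporal model:

```python
# Illustrative values only — two edges for a medication fact and its later
# contradiction, carrying Graphiti's four temporal fields.
old_edge = {
    "fact": "Patient takes Metformin",
    "created_at": "2025-01-10T09:00:00Z",  # when the edge entered the DB
    "valid_at": "2025-01-10T09:00:00Z",    # when the fact became true in the world
    "invalid_at": None,                    # real-world end, if known
    "expired_at": "2025-06-02T14:30:00Z",  # set by Graphiti when contradicted
}
new_edge = {
    "fact": "Patient stopped Metformin",
    "created_at": "2025-06-02T14:30:00Z",
    "valid_at": "2025-06-01T00:00:00Z",
    "invalid_at": None,
    "expired_at": None,                    # still the current truth
}

def is_current(edge: dict) -> bool:
    # An edge is current only while neither invalidation marker is set.
    return edge["expired_at"] is None and edge["invalid_at"] is None

print(is_current(old_edge), is_current(new_edge))  # False True
```

Both edges remain in the graph; only the markers distinguish superseded from current — and that is precisely the metadata the crystallization pipeline needs to carry into Neo4j.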
The crystallization pipeline ignores these timestamps entirely: it reads entity names and confidence scores but discards the temporal validity markers. This means: + +1. Outdated facts transferred to Neo4j remain active indefinitely. +2. Two contradictory facts (e.g., "Patient takes Metformin" and later "Patient stopped Metformin") both live in the DIKW graph as current truths. +3. Point-in-time queries against the DIKW graph are impossible. + +### Goal + +Propagate Graphiti's bi-temporal metadata through the crystallization pipeline into Neo4j DIKW nodes and edges, and add a mechanism to invalidate superseded facts. + +### Design + +``` +Graphiti Edge +├── created_at (DB time) +├── expired_at (DB invalidation time) +├── valid_at (real-world start) +└── invalid_at (real-world end) + │ + ▼ +CrystallizationService.crystallize_entities() + │ reads temporal fields from Graphiti edges + │ propagates to Neo4j node properties + ▼ +Neo4j DIKW Node +├── valid_from: datetime ← NEW (mapped from valid_at) +├── valid_until: datetime ← NEW (mapped from invalid_at) +├── invalidated_at: datetime← NEW (mapped from expired_at) +├── is_current: bool ← NEW (computed: valid_until is null) +└── (existing fields) + +CrystallizationService._resolve_temporal_conflicts() ← NEW + │ when new fact contradicts existing: + │ mark old entity's valid_until = new fact's valid_from + │ set old entity's is_current = false + ▼ +``` + +### Files to Change + +| File | Change | +|------|--------| +| `src/application/services/crystallization_service.py` | Add `_resolve_temporal_conflicts()`, update `crystallize_entities()` to extract and propagate temporal fields, update `_create_perception_entity()` to include temporal properties | +| `src/application/services/entity_resolver.py` | Update `merge_for_crystallization()` to handle temporal conflict during merge | +| `tests/test_crystallization_pipeline.py` | Add temporal conflict resolution tests | + +### Implementation Details + +**New temporal fields on 
PERCEPTION entities** — update `_create_perception_entity()`: + +```python +properties = { + # ... existing fields ... + "valid_from": source_data.get("valid_at", datetime.utcnow().isoformat()), + "valid_until": source_data.get("invalid_at"), # None = still current + "invalidated_at": source_data.get("expired_at"), # None = not invalidated + "is_current": source_data.get("invalid_at") is None, +} +``` + +**New method `_resolve_temporal_conflicts()`** in `CrystallizationService`: + +```python +async def _resolve_temporal_conflicts( + self, + entity_name: str, + entity_type: str, + new_valid_from: Optional[str], +) -> int: + """ + Invalidate existing DIKW entities that are superseded by a newer fact. + + When the same entity (by name+type) appears with a newer valid_from, + all previous is_current=true versions get marked: + valid_until = new_valid_from + is_current = false + invalidated_at = now() + + Returns: + Number of entities invalidated. + """ +``` + +Query pattern: + +```cypher +MATCH (n:Entity) +WHERE toLower(n.name) = $name + AND n.entity_type = $type + AND n.is_current = true + AND n.valid_from < $new_valid_from +SET n.valid_until = $new_valid_from, + n.is_current = false, + n.invalidated_at = datetime() +RETURN count(n) as invalidated +``` + +**Call site** — in `crystallize_entities()`, after the entity is created or merged, if `source_data` includes a `valid_at` field: + +```python +if entity_data.get("valid_at"): + invalidated = await self._resolve_temporal_conflicts( + entity_name=name, + entity_type=entity_type, + new_valid_from=entity_data["valid_at"], + ) + if invalidated > 0: + logger.info(f"Temporal conflict: invalidated {invalidated} prior version(s) of '{name}'") +``` + +**Update `merge_for_crystallization()` in EntityResolver** — when merging, if the incoming data has `valid_at`/`invalid_at`, store them and update `is_current`: + +```python +# In merge_for_crystallization(), within the updates dict: +if "valid_at" in new_data: + 
updates["valid_from"] = new_data["valid_at"] +if "invalid_at" in new_data: + updates["valid_until"] = new_data["invalid_at"] + updates["is_current"] = False +``` + +**Extracting temporal data from Graphiti** — in `crystallize_from_graphiti()`, after the `search()` call, extract temporal metadata from edges: + +```python +for node in results.nodes[:limit]: + # Gather edge temporal data for this node + node_edges = [e for e in (results.edges or []) if e.source_node_uuid == node.uuid or e.target_node_uuid == node.uuid] + latest_edge = max(node_edges, key=lambda e: e.created_at, default=None) if node_edges else None + + entities.append({ + # ... existing fields ... + "valid_at": latest_edge.valid_at.isoformat() if latest_edge and hasattr(latest_edge, 'valid_at') and latest_edge.valid_at else None, + "invalid_at": latest_edge.invalid_at.isoformat() if latest_edge and hasattr(latest_edge, 'invalid_at') and latest_edge.invalid_at else None, + "expired_at": latest_edge.expired_at.isoformat() if latest_edge and hasattr(latest_edge, 'expired_at') and latest_edge.expired_at else None, + }) +``` + +### Test Plan + +| Test | Asserts | +|------|---------| +| `test_new_entity_gets_temporal_fields` | PERCEPTION entity has `valid_from`, `valid_until=None`, `is_current=True` | +| `test_contradicting_fact_invalidates_old` | Creating "Patient takes Metformin" then "Patient stopped Metformin" → first entity gets `is_current=False`, `valid_until` set | +| `test_merge_preserves_temporal_on_update` | Merging with `valid_at`/`invalid_at` stores them on the Neo4j node | +| `test_no_temporal_data_defaults_safe` | Entity without temporal fields gets `valid_from=now`, `is_current=True` | +| `test_invalidation_scoped_to_name_and_type` | Invalidating "Metformin/Medication" does NOT affect "Metformin/Allergy" | + +### Acceptance Criteria + +- All crystallized PERCEPTION entities have `valid_from`, `valid_until`, `is_current` properties. 
+- When a newer version of the same entity (name+type) is crystallized, older versions are marked `is_current=false`. +- Old entities are never deleted — only invalidated (preserving history per Graphiti best practice). +- Entities without temporal data from Graphiti default to `valid_from=now, is_current=true`. + +--- + +## SPEC-3: Bound Search Results + +**Checklist item**: §4.3 — *Bound search results — Some search paths don't limit results explicitly (e.g., `crystallize_from_graphiti` uses a default limit of 100 but the search itself doesn't pass `num_results`).* + +### Problem + +Several search call sites in the codebase do not pass explicit `num_results` to the Graphiti search API. While the caller may slice results afterward (`results.nodes[:limit]`), the underlying search may return unbounded intermediate results, wasting tokens and compute: + +1. `crystallize_from_graphiti()` — line 536-542: `search()` called without `num_results`, results sliced post-hoc at `[:limit]`. +2. `search_episodes()` — line 430-436: `search()` called without `num_results`, results sliced at `[:limit]`. +3. `get_related_entities()` — line 473-478: `search()` called without `num_results`, results sliced at `[:limit]`. +4. `_get_existing_entities()` in EntityResolver — line 217-221: `LIMIT 100` is hardcoded in the Cypher string rather than parameterized. + +### Goal + +Ensure every search path passes an explicit bound to the underlying search API, and parameterize all Cypher `LIMIT` clauses. 
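A minimal illustration of the difference (toy functions, not SynapseFlow code): post-hoc slicing still pays for every intermediate result, while pushing the bound into the call does not.

```python
produced = {"count": 0}  # tracks how many results the backend materialized

def search_unbounded(query: str) -> list[str]:
    # Stands in for a search that materializes every candidate result.
    results = [f"{query}-{i}" for i in range(10_000)]
    produced["count"] += len(results)
    return results

def search_bounded(query: str, num_results: int) -> list[str]:
    # The caller's limit is pushed down, so only num_results are materialized.
    results = [f"{query}-{i}" for i in range(num_results)]
    produced["count"] += len(results)
    return results

top_sliced = search_unbounded("metformin")[:10]  # 10_000 produced, 9_990 discarded
top_bounded = search_bounded("metformin", 10)    # 10 produced
assert top_sliced == top_bounded
print(produced["count"])  # 10010
```

The callers see identical results either way; only the wasted intermediate work differs, which is why the bound belongs inside the search call, not after it.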
+ +### Files to Change + +| File | Change | +|------|--------| +| `src/application/services/episodic_memory_service.py` | Pass `num_results` to `search()` in `search_episodes()` and `get_related_entities()` | +| `src/application/services/crystallization_service.py` | Pass `num_results` to `search()` in `crystallize_from_graphiti()` | +| `src/application/services/entity_resolver.py` | Parameterize `LIMIT` in `_get_existing_entities()` | +| `tests/test_episodic_memory.py` | Verify `num_results` is passed in mock assertions | +| `tests/test_crystallization_pipeline.py` | Verify `num_results` is passed in mock assertions | + +### Implementation Details + +**`search_episodes()`** — add `num_results` to the search config: + +The `COMBINED_HYBRID_SEARCH_CROSS_ENCODER` recipe is a `SearchConfig` object. Graphiti's `search()` function accepts a `config` parameter which has a `limit` field. We should create a modified config per call: + +```python +from graphiti_core.search.search_config import SearchConfig +from copy import deepcopy + +async def search_episodes(self, patient_id, query, limit=10, session_id=None): + # ... + config = deepcopy(COMBINED_HYBRID_SEARCH_CROSS_ENCODER) + config.limit = limit + + results: SearchResults = await search( + clients=self.graphiti.clients, + query=query, + group_ids=group_ids, + search_filter=SearchFilters(), + config=config, + ) + # No longer need post-hoc slicing on episodes + return [self._convert_episode(ep, patient_id) for ep in results.episodes] +``` + +If `SearchConfig` doesn't have a `limit` field (version-dependent), use `num_results` kwarg: + +```python +results = await search( + clients=self.graphiti.clients, + query=query, + group_ids=group_ids, + search_filter=SearchFilters(), + config=COMBINED_HYBRID_SEARCH_CROSS_ENCODER, + num_results=limit, +) +``` + +Check Graphiti's `search()` signature at implementation time and use whichever parameter is available. 
If neither exists, keep post-hoc slicing but add a `# BOUNDED:` comment for auditability. + +**`get_related_entities()`** — same approach, pass `limit` to the search call. + +**`crystallize_from_graphiti()`** — same approach, pass `limit` to the search call. + +**`_get_existing_entities()`** in EntityResolver — parameterize the Cypher: + +```python +query = f""" +MATCH (e:{entity_type}) +RETURN e.id AS id, e.name AS name, e AS properties +LIMIT $limit +""" +results = await self.backend.query(query, {"limit": limit}) +``` + +Add a `limit` parameter to the method signature (default 100): + +```python +async def _get_existing_entities( + self, + entity_type: str, + context: Dict[str, Any], + limit: int = 100, +) -> List[Dict[str, Any]]: +``` + +### Test Plan + +| Test | Asserts | +|------|---------| +| `test_search_episodes_passes_limit` | Mock `search()` → assert `num_results` or `config.limit` matches caller's `limit` | +| `test_get_related_entities_passes_limit` | Same pattern | +| `test_crystallize_from_graphiti_passes_limit` | Same pattern | +| `test_entity_resolver_parameterized_limit` | Mock `backend.query()` → assert query params include `limit` | + +### Acceptance Criteria + +- No search call site relies solely on post-hoc slicing for bounding. +- Every `search()` call passes an explicit result count. +- Every Cypher `LIMIT` clause uses a parameter, not a hardcoded string. + +--- + +## SPEC-4: Memory Invalidation / Expiration + +**Checklist item**: §4.4 — *Memory invalidation/expiration — No mechanism to mark episodic or DIKW entities as expired/outdated.* + +### Problem + +There is currently no way to expire or invalidate entities in either the episodic graph (FalkorDB/Graphiti) or the DIKW graph (Neo4j). Entities accumulate indefinitely. For a medical domain this is particularly problematic — discontinued medications, resolved conditions, and outdated vitals remain "active" forever. 
+ +This is related to but distinct from SPEC-2 (temporal conflict resolution). SPEC-2 handles automatic invalidation when a contradicting fact arrives. This spec covers explicit invalidation via API and TTL-based expiration for stale entities. + +### Goal + +Add an invalidation mechanism for DIKW entities (explicit API + TTL-based staleness detection), and an expiration sweep for episodic memory. + +### Design + +``` + ┌────────────────────────────┐ + │ Invalidation Sources │ + ├────────────────────────────┤ + │ 1. Explicit API call │ + │ 2. Temporal conflict (SPEC-2)│ + │ 3. TTL staleness sweep │ + └─────────┬──────────────────┘ + │ + ▼ + ┌───────────────────────────────┐ + │ MemoryInvalidationService │ ← NEW + ├───────────────────────────────┤ + │ invalidate_entity() │ + │ invalidate_by_query() │ + │ sweep_stale_entities() │ + │ get_invalidation_stats() │ + └───────────────────────────────┘ + │ + ┌───────────────┼───────────────┐ + ▼ ▼ ▼ + Neo4j DIKW FalkorDB Episodic Stats/Audit + SET is_current (future: TTL on log invalidated + = false episode nodes) entities +``` + +### Files to Change + +| File | Change | +|------|--------| +| `src/application/services/memory_invalidation_service.py` | **NEW** — `MemoryInvalidationService` class | +| `src/application/services/crystallization_service.py` | Import and call `MemoryInvalidationService.sweep_stale_entities()` in the periodic batch | +| `src/composition_root.py` | Wire `MemoryInvalidationService` into the bootstrap | +| `tests/test_memory_invalidation.py` | **NEW** — unit tests | + +### Implementation Details + +**`MemoryInvalidationService`** — new service: + +```python +@dataclass +class InvalidationConfig: + """Configuration for memory invalidation.""" + stale_threshold_days: int = 90 # PERCEPTION entities not seen in N days + stale_check_enabled: bool = True + episodic_ttl_days: Optional[int] = None # None = no auto-expiry for episodes + +@dataclass +class InvalidationResult: + """Result of an invalidation 
operation.""" + entities_invalidated: int + entity_ids: List[str] + reason: str + timestamp: datetime + + +class MemoryInvalidationService: + + def __init__( + self, + neo4j_backend: KnowledgeGraphBackend, + config: Optional[InvalidationConfig] = None, + ): + self.neo4j_backend = neo4j_backend + self.config = config or InvalidationConfig() + + async def invalidate_entity( + self, + entity_id: str, + reason: str = "manual", + ) -> InvalidationResult: + """ + Explicitly invalidate a single DIKW entity. + + Sets is_current=false, invalidated_at=now(), invalidation_reason=reason. + Does NOT delete the entity. + """ +``` + +Cypher for explicit invalidation: + +```cypher +MATCH (n:Entity {id: $entity_id}) +WHERE n.is_current = true OR n.is_current IS NULL +SET n.is_current = false, + n.invalidated_at = datetime(), + n.invalidation_reason = $reason, + n.valid_until = COALESCE(n.valid_until, datetime()) +RETURN n.id as id +``` + +**`invalidate_by_query()`** — invalidate multiple entities matching criteria: + +```python +async def invalidate_by_query( + self, + patient_id: Optional[str] = None, + entity_type: Optional[str] = None, + entity_name: Optional[str] = None, + reason: str = "bulk_invalidation", +) -> InvalidationResult: +``` + +Builds a dynamic Cypher WHERE clause from provided filters. + +**`sweep_stale_entities()`** — TTL-based staleness: + +```python +async def sweep_stale_entities(self) -> InvalidationResult: + """ + Find and invalidate PERCEPTION entities not observed recently. + + Uses config.stale_threshold_days to determine staleness. + Only targets PERCEPTION layer (higher layers are considered validated). 
+ """ +``` + +Cypher: + +```cypher +MATCH (n:Entity) +WHERE n.dikw_layer = 'PERCEPTION' + AND (n.is_current = true OR n.is_current IS NULL) + AND n.last_observed < $cutoff_date +SET n.is_current = false, + n.invalidated_at = datetime(), + n.invalidation_reason = 'stale_ttl' +RETURN n.id as id, n.name as name +``` + +**Integration with periodic crystallization** — in `CrystallizationService._periodic_crystallization()`, after the existing batch logic, optionally run the sweep: + +```python +if self.invalidation_service and self.invalidation_service.config.stale_check_enabled: + sweep_result = await self.invalidation_service.sweep_stale_entities() + if sweep_result.entities_invalidated > 0: + logger.info(f"Stale sweep: invalidated {sweep_result.entities_invalidated} entities") +``` + +**Composition root** — instantiate `MemoryInvalidationService` alongside the crystallization pipeline and inject it: + +```python +invalidation_service = MemoryInvalidationService( + neo4j_backend=neo4j_backend, + config=InvalidationConfig( + stale_threshold_days=int(os.getenv("MEMORY_STALE_THRESHOLD_DAYS", "90")), + stale_check_enabled=os.getenv("ENABLE_STALE_SWEEP", "true").lower() in ("true", "1"), + ), +) +``` + +### Test Plan + +| Test | Asserts | +|------|---------| +| `test_invalidate_entity_sets_flags` | After invalidation, entity has `is_current=false`, `invalidated_at` set, `invalidation_reason` stored | +| `test_invalidate_entity_idempotent` | Invalidating already-invalidated entity returns count 0, no error | +| `test_invalidate_by_query_filters` | With `patient_id` filter, only matching entities invalidated | +| `test_sweep_targets_perception_only` | SEMANTIC/REASONING entities are untouched even if stale | +| `test_sweep_respects_threshold` | Entity observed 30 days ago with 90-day threshold is NOT invalidated | +| `test_sweep_disabled_config` | `stale_check_enabled=False` → sweep returns 0 immediately | + +### Acceptance Criteria + +- Entities can be explicitly 
invalidated via `invalidate_entity()` without deletion. +- Stale PERCEPTION entities are automatically invalidated by the periodic sweep. +- Higher-layer entities (SEMANTIC, REASONING, APPLICATION) are never auto-invalidated. +- All invalidations are auditable: `invalidated_at`, `invalidation_reason` are stored on the node. +- New env vars: `MEMORY_STALE_THRESHOLD_DAYS` (default 90), `ENABLE_STALE_SWEEP` (default true). + +--- + +## SPEC-5: LLM Rate Limit Management + +**Checklist item**: §4.5 — *LLM rate limit management — `SEMAPHORE_LIMIT` not configured; risk of 429 errors during high-throughput ingestion.* + +### Problem + +Graphiti makes concurrent LLM calls during episode ingestion (entity extraction, edge inference, deduplication). The library uses an internal semaphore (`SEMAPHORE_LIMIT` env var, default 10) to cap concurrency. SynapseFlow does not expose or configure this value, meaning: + +1. During burst ingestion (e.g., uploading a full conversation history), 10 concurrent LLM calls may exceed provider rate limits. +2. 429 errors from the LLM provider cause ingestion failures with no retry. +3. There is no visibility into whether rate limiting is occurring. + +### Goal + +Expose `SEMAPHORE_LIMIT` as a configurable environment variable, add a health check for LLM rate status, and document tuning guidance. 
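For the tuning guidance, a back-of-envelope starting point helps: each concurrency slot issues roughly `60 / avg_call_seconds` requests per minute, so the provider's requests-per-minute budget bounds the safe slot count. The helper below is a hypothetical rule of thumb (not a Graphiti or SynapseFlow API), with a safety factor reserving headroom for other consumers of the same API key:

```python
import math

def suggest_semaphore_limit(
    provider_rpm: int,
    avg_call_seconds: float,
    safety_factor: float = 0.5,
) -> int:
    """Hypothetical rule of thumb for picking GRAPHITI_SEMAPHORE_LIMIT."""
    # One busy slot issues ~60 / avg_call_seconds requests per minute.
    per_slot_rpm = 60.0 / avg_call_seconds
    # Spend only safety_factor of the provider budget on ingestion.
    return max(1, math.floor(provider_rpm * safety_factor / per_slot_rpm))

# e.g. a 500 requests/minute tier with ~3 s extraction calls and 50% headroom:
print(suggest_semaphore_limit(500, 3.0))  # 12
```

The exact call duration varies by model and episode size, so the result is a starting value to refine against observed 429 rates, not a guarantee.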
+ +### Design + +``` + ┌────────────────────────────────┐ + │ Environment │ + │ GRAPHITI_SEMAPHORE_LIMIT=5 │ ← NEW env var + │ GRAPHITI_LLM_RETRY_ENABLED=true│ ← NEW env var + │ GRAPHITI_LLM_MAX_RETRIES=3 │ ← NEW env var + └───────────┬────────────────────┘ + │ + ┌───────────▼────────────────────┐ + │ composition_root.py │ + │ bootstrap_episodic_memory() │ + │ os.environ["SEMAPHORE_LIMIT"] │ ← set before Graphiti import + └───────────┬────────────────────┘ + │ + ┌───────────▼────────────────────┐ + │ EpisodicMemoryService │ + │ ├── _rate_limit_stats │ ← NEW tracking dict + │ ├── store_turn_episode() │ + │ │ └── try/except 429 → │ + │ │ log + retry w/ backoff │ + │ └── get_health() │ ← NEW method + └────────────────────────────────┘ +``` + +### Files to Change + +| File | Change | +|------|--------| +| `src/composition_root.py` | Set `SEMAPHORE_LIMIT` env var before Graphiti imports in `bootstrap_episodic_memory()` | +| `src/application/services/episodic_memory_service.py` | Add rate-limit-aware retry wrapper, tracking stats, health method | +| `tests/test_episodic_memory.py` | Add tests for retry and health | + +### Implementation Details + +**Set `SEMAPHORE_LIMIT` at bootstrap** — in `bootstrap_episodic_memory()`, before importing Graphiti: + +```python +async def bootstrap_episodic_memory(event_bus=None): + import os + + # Configure Graphiti's LLM concurrency BEFORE importing graphiti_core + semaphore_limit = os.getenv("GRAPHITI_SEMAPHORE_LIMIT", "5") + os.environ["SEMAPHORE_LIMIT"] = semaphore_limit + logger.info(f"Graphiti SEMAPHORE_LIMIT set to {semaphore_limit}") + + # ... rest of existing bootstrap ... +``` + +Default is lowered from Graphiti's 10 to 5, as SynapseFlow's medical domain use case favors reliability over throughput. 
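The "before importing graphiti_core" ordering matters because the library reads `SEMAPHORE_LIMIT` once at import time. The snippet below simulates that with a stand-in class (not the real module) to show why setting the variable later has no effect:

```python
import os

# Must happen before the first import of the consuming module.
os.environ["SEMAPHORE_LIMIT"] = os.getenv("GRAPHITI_SEMAPHORE_LIMIT", "5")

class SimulatedGraphitiModule:
    # Stand-in for a module-level `SEMAPHORE_LIMIT = int(os.getenv(...))`
    # that runs exactly once, at import time.
    SEMAPHORE_LIMIT = int(os.getenv("SEMAPHORE_LIMIT", "10"))

os.environ["SEMAPHORE_LIMIT"] = "50"  # too late: already captured above
print(SimulatedGraphitiModule.SEMAPHORE_LIMIT)  # 5
```

This is why the bootstrap sets the variable at the top of `bootstrap_episodic_memory()` rather than relying on later configuration.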
**Rate-limit-aware retry in `store_turn_episode()`** — wrap the `add_episode()` call:

```python
async def _add_episode_with_retry(self, **kwargs) -> Any:
    """Call graphiti.add_episode with retry on rate limit errors."""
    max_retries = int(os.getenv("GRAPHITI_LLM_MAX_RETRIES", "3"))
    retry_enabled = os.getenv("GRAPHITI_LLM_RETRY_ENABLED", "true").lower() in ("true", "1")

    for attempt in range(max_retries + 1):
        try:
            return await self.graphiti.add_episode(**kwargs)
        except Exception as e:
            is_rate_limit = "429" in str(e) or "rate" in str(e).lower()
            if is_rate_limit:
                # Record the hit so get_health() has a non-null last_rate_limit.
                self._rate_limit_stats["last_rate_limit"] = datetime.utcnow().isoformat()
            if is_rate_limit and retry_enabled and attempt < max_retries:
                wait = 2 ** (attempt + 1)  # 2, 4, 8 seconds
                self._rate_limit_stats["retries"] += 1
                logger.warning(
                    f"LLM rate limit hit (attempt {attempt + 1}/{max_retries}), "
                    f"retrying in {wait}s"
                )
                await asyncio.sleep(wait)
            else:
                if is_rate_limit:
                    self._rate_limit_stats["failures"] += 1
                raise
```

**Tracking stats** — add to `__init__`:

```python
self._rate_limit_stats = {
    "retries": 0,
    "failures": 0,
    "last_rate_limit": None,
}
```

**Health method** — add to `EpisodicMemoryService`:

```python
def get_health(self) -> Dict[str, Any]:
    """Return health/status information including rate limit stats."""
    return {
        "initialized": self._initialized,
        "rate_limit_retries": self._rate_limit_stats["retries"],
        "rate_limit_failures": self._rate_limit_stats["failures"],
        "last_rate_limit": self._rate_limit_stats["last_rate_limit"],
        "semaphore_limit": int(os.getenv("SEMAPHORE_LIMIT", "10")),
    }
```

**Use the retry wrapper** — in `store_turn_episode()` and `store_session_episode()`, replace direct `self.graphiti.add_episode()` calls with `self._add_episode_with_retry()`.
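Two pieces of the wrapper are worth sanity-checking in isolation — the wait sequence and the string-match heuristic for spotting rate-limit errors (extracted here as standalone functions; the wrapper itself keeps them inline):

```python
def backoff_schedule(max_retries: int) -> list[int]:
    # Mirrors `wait = 2 ** (attempt + 1)` for the attempts that retry.
    return [2 ** (attempt + 1) for attempt in range(max_retries)]

def is_rate_limit_error(exc: Exception) -> bool:
    # Same heuristic as the wrapper: match "429" or "rate" in the message.
    return "429" in str(exc) or "rate" in str(exc).lower()

print(backoff_schedule(3))                                           # [2, 4, 8]
print(is_rate_limit_error(Exception("HTTP 429 Too Many Requests")))  # True
print(is_rate_limit_error(ValueError("malformed payload")))          # False
```

The substring heuristic is deliberately loose; if the LLM client library exposes a typed rate-limit exception, matching on the exception type is the safer check.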
+ +### Test Plan + +| Test | Asserts | +|------|---------| +| `test_semaphore_limit_set_at_bootstrap` | After `bootstrap_episodic_memory()`, `os.environ["SEMAPHORE_LIMIT"]` equals configured value | +| `test_retry_on_rate_limit` | Mock `add_episode` raising 429 twice then succeeding → result returned, `retries=2` | +| `test_retry_exhausted_raises` | Mock `add_episode` raising 429 beyond max_retries → exception propagated, `failures=1` | +| `test_non_rate_limit_error_no_retry` | Mock `add_episode` raising `ValueError` → immediate propagation, `retries=0` | +| `test_retry_disabled_no_retry` | Set `GRAPHITI_LLM_RETRY_ENABLED=false` → 429 error propagated immediately | +| `test_get_health_includes_rate_stats` | `get_health()` returns dict with all expected keys | + +### Acceptance Criteria + +- `GRAPHITI_SEMAPHORE_LIMIT` env var controls Graphiti's internal concurrency (default: 5). +- Rate-limited `add_episode` calls are retried with exponential backoff (2s, 4s, 8s). +- After max retries exhausted, the error propagates normally. +- `get_health()` exposes rate limit stats for monitoring. +- New env vars: `GRAPHITI_SEMAPHORE_LIMIT` (default 5), `GRAPHITI_LLM_RETRY_ENABLED` (default true), `GRAPHITI_LLM_MAX_RETRIES` (default 3). 
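Since SPEC-4 and SPEC-5 together introduce five environment variables, reading them through one settings object keeps the defaults in a single place. The class below is illustrative (names and defaults come from these specs; the class itself is not existing SynapseFlow code):

```python
import os
from dataclasses import dataclass

_TRUTHY = ("true", "1")

@dataclass(frozen=True)
class MemoryEnvSettings:
    """Illustrative one-stop reader for the env vars introduced by these specs."""
    semaphore_limit: int
    llm_retry_enabled: bool
    llm_max_retries: int
    stale_threshold_days: int
    stale_sweep_enabled: bool

    @classmethod
    def from_env(cls) -> "MemoryEnvSettings":
        return cls(
            semaphore_limit=int(os.getenv("GRAPHITI_SEMAPHORE_LIMIT", "5")),
            llm_retry_enabled=os.getenv("GRAPHITI_LLM_RETRY_ENABLED", "true").lower() in _TRUTHY,
            llm_max_retries=int(os.getenv("GRAPHITI_LLM_MAX_RETRIES", "3")),
            stale_threshold_days=int(os.getenv("MEMORY_STALE_THRESHOLD_DAYS", "90")),
            stale_sweep_enabled=os.getenv("ENABLE_STALE_SWEEP", "true").lower() in _TRUTHY,
        )

settings = MemoryEnvSettings.from_env()
print(settings.semaphore_limit, settings.llm_max_retries)
```

Centralizing the parsing also gives one obvious place to fail fast on malformed values instead of discovering them deep inside the services.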
+
+---
+
+## Cross-Cutting Concerns
+
+### New Environment Variables Summary
+
+| Variable | Default | Spec | Purpose |
+|----------|---------|------|---------|
+| `GRAPHITI_SEMAPHORE_LIMIT` | `5` | SPEC-5 | Graphiti LLM concurrency cap |
+| `GRAPHITI_LLM_RETRY_ENABLED` | `true` | SPEC-5 | Enable retry on 429 errors |
+| `GRAPHITI_LLM_MAX_RETRIES` | `3` | SPEC-5 | Max retries before propagating error |
+| `MEMORY_STALE_THRESHOLD_DAYS` | `90` | SPEC-4 | Days before PERCEPTION entity considered stale |
+| `ENABLE_STALE_SWEEP` | `true` | SPEC-4 | Enable periodic staleness sweep |
+
+### New Files Summary
+
+| File | Spec | Description |
+|------|------|-------------|
+| `src/application/services/memory_invalidation_service.py` | SPEC-4 | Entity invalidation + TTL sweep |
+| `tests/test_memory_invalidation.py` | SPEC-4 | Unit tests for invalidation service |
+
+### Modified Files Summary
+
+| File | Specs |
+|------|-------|
+| `src/application/services/episodic_memory_service.py` | SPEC-1, SPEC-3, SPEC-5 |
+| `src/application/services/crystallization_service.py` | SPEC-2, SPEC-3, SPEC-4 |
+| `src/application/services/entity_resolver.py` | SPEC-2, SPEC-3 |
+| `src/composition_root.py` | SPEC-4, SPEC-5 |
+| `tests/test_episodic_memory.py` | SPEC-1, SPEC-3, SPEC-5 |
+| `tests/test_crystallization_pipeline.py` | SPEC-2, SPEC-3 |
+
+### SPEC-4's Dependency on SPEC-2
+
+SPEC-4 (memory invalidation) depends on the `is_current`, `valid_until`, `invalidated_at` fields introduced by SPEC-2 (temporal conflict resolution). Implementation order: **SPEC-2 → SPEC-4**.
+
+### Recommended Implementation Order
+
+1. **SPEC-5** (LLM rate limits) — standalone, zero coupling, immediate operational benefit
+2. **SPEC-3** (bound search results) — standalone, small diff, prevents resource waste
+3. **SPEC-2** (temporal conflict resolution) — adds schema fields needed by SPEC-4
+4. **SPEC-4** (memory invalidation) — depends on SPEC-2's schema
+5. **SPEC-1** (community summaries) — feature addition, version-dependent, lowest urgency
diff --git a/docs/architecture/GRAPHITI_MEMORY_BRIEF.md b/docs/architecture/GRAPHITI_MEMORY_BRIEF.md
new file mode 100644
index 000000000..1a0838b76
--- /dev/null
+++ b/docs/architecture/GRAPHITI_MEMORY_BRIEF.md
@@ -0,0 +1,187 @@
+# Graphiti Memory Handling Brief
+
+## 1. What Is Graphiti?
+
+Graphiti is an open-source Python framework by Zep for building temporally-aware knowledge graphs designed for AI agent memory. Unlike traditional RAG, which treats memories as isolated, static documents, Graphiti continuously integrates user interactions and structured/unstructured data into a coherent, queryable graph with:
+
+- **Bi-temporal model** -- tracks both when an event occurred and when it was ingested, with explicit validity intervals on every edge.
+- **Real-time incremental updates** -- new episodes are integrated without batch recomputation.
+- **Hybrid search** -- combines semantic embeddings, keyword (BM25), and graph traversal at sub-300ms P95 latency with no LLM calls at retrieval time.
+- **Automatic ontology building** -- LLM-driven entity extraction and deduplication.
+
+Graphiti organizes memory into three hierarchical subgraph tiers (per the Zep paper arXiv:2501.13956):
+
+| Tier | Subgraph | Contents |
+|------|----------|----------|
+| 1 | **Episode Subgraph** | Raw events/messages with timestamps -- the ground-truth corpus |
+| 2 | **Semantic Entity Subgraph** | Entities and factual edges extracted from episodes, embedded in high-dimensional space |
+| 3 | **Community Subgraph** | Clusters of strongly connected entities with high-level summaries |
+
+---
+
+## 2. How SynapseFlow Uses Graphiti
+
+### 2.1 Role in the Architecture
+
+Graphiti is **not** the primary knowledge graph backend. SynapseFlow uses a dual-graph architecture:
+
+| System | Backend | Purpose | Lifetime |
+|--------|---------|---------|----------|
+| DIKW Knowledge Graph | **Neo4j** | Persistent knowledge with 4 layers (PERCEPTION / SEMANTIC / REASONING / APPLICATION) | Permanent |
+| Episodic Memory | **Graphiti + FalkorDB** | Conversation memory with automatic entity extraction | Session-bound |
+
+The **CrystallizationService** bridges the two: it transfers entities discovered in episodic memory into the DIKW graph as PERCEPTION-layer nodes, where they can be promoted through confidence-based thresholds.
+
+### 2.2 Core Implementation Files
+
+| File | Responsibility |
+|------|---------------|
+| `src/infrastructure/graphiti.py` | Graphiti client initialization (Neo4j driver) |
+| `src/infrastructure/graphiti_backend.py` | `KnowledgeGraphBackend` adapter for direct KG operations via Graphiti |
+| `src/application/services/episodic_memory_service.py` | Primary Graphiti consumer -- episode storage, retrieval, hybrid search |
+| `src/application/services/crystallization_service.py` | Episodic-to-DIKW transfer pipeline |
+| `src/composition_root.py` (lines 472-510) | Backend selection (`KG_BACKEND` env var) and bootstrap |
+
+### 2.3 EpisodicMemoryService
+
+This is the central integration point.
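As a quick orientation, the group-ID convention described under "Key design choices" can be sketched as a pure function (the helper name `episode_group_id` is hypothetical; the convention itself is taken from the service):

```python
from typing import Optional


def episode_group_id(patient_id: str, session_id: Optional[str] = None) -> str:
    """Build the Graphiti group_id used for multi-tenant isolation.

    Session-level episodes are grouped by patient alone; turn-level
    episodes are scoped to "{patient_id}:{session_id}".
    """
    if session_id is None:
        return patient_id  # session-level episode
    return f"{patient_id}:{session_id}"  # turn-level episode
```

Because both forms share the `patient_id` prefix, session-level and turn-level queries partition the same patient's data without cross-tenant leakage.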
Key design choices:
+
+**Group ID Strategy (multi-tenant isolation):**
+- Session-level episodes: `group_id = patient_id`
+- Turn-level episodes: `group_id = "{patient_id}:{session_id}"`
+
+**Episode Types:**
+- `EpisodeType.message` for conversation turns (user/assistant pairs)
+- `EpisodeType.json` for session summaries
+
+**Search:**
+- Uses `COMBINED_HYBRID_SEARCH_CROSS_ENCODER` recipe (semantic + keyword + cross-encoder reranking)
+- `get_conversation_context()` assembles multi-source context: recent episodes + semantically related episodes + extracted entities
+
+**Event Integration:**
+- Emits `episode_added` events on the `EventBus` after each stored episode
+- These events carry `episode_id`, `patient_id`, `session_id`, and `entities_extracted`
+- The `CrystallizationService` subscribes to these events
+
+### 2.4 Crystallization Pipeline
+
+```
+EpisodicMemoryService (Graphiti + FalkorDB)
+      | emits "episode_added"
+      v
+CrystallizationService
+      |-- Queries FalkorDB for extracted entities
+      |-- Resolves/deduplicates via EntityResolver (exact, fuzzy, embedding)
+      |-- Creates PERCEPTION-layer nodes in Neo4j
+      |-- Evaluates promotion candidates (confidence >= 0.85, observations >= 2)
+      v
+Neo4j DIKW Knowledge Graph
+```
+
+Three processing modes:
+- **EVENT_DRIVEN**: immediate crystallization per episode
+- **BATCH**: periodic processing (default 5 min interval)
+- **HYBRID** (default): queue events, trigger on threshold (10 entities) or interval
+
+### 2.5 GraphitiBackend (Alternative KG Backend)
+
+`GraphitiBackend` implements `KnowledgeGraphBackend` for using Graphiti as the *primary* KG backend (selected via `KG_BACKEND=graphiti`). This is separate from episodic memory usage and maps the generic interface to Graphiti's `EntityNode`/`EntityEdge` structures and `add_triplet()` API.
+
+---
+
+## 3. Implementation vs. Graphiti Best Practices -- Conformance Check
+
+### 3.1 What Aligns Well
+
+| Best Practice | SynapseFlow Implementation | Status |
+|---------------|---------------------------|--------|
+| **Use `group_ids` for multi-tenant isolation** | Patient-level and session-level group IDs | Aligned |
+| **Store episodes with `reference_time`** | Timestamps passed to `add_episode()` | Aligned |
+| **Use hybrid search, not pure vector** | `COMBINED_HYBRID_SEARCH_CROSS_ENCODER` recipe | Aligned |
+| **Separate episodic from semantic memory** | FalkorDB for episodes, Neo4j for persistent KG | Aligned |
+| **Incremental, not batch RAG** | Episodes added in real-time per conversation turn | Aligned |
+| **Handle entity deduplication** | `EntityResolver` with exact/fuzzy/embedding strategies | Aligned |
+| **Avoid LLM calls at retrieval time** | Hybrid search config avoids LLM during search | Aligned |
+| **Event-driven processing** | `EventBus` pub/sub for episode-to-crystallization flow | Aligned |
+
+### 3.2 Gaps and Concerns
+
+| Issue | Description | Severity |
+|-------|-------------|----------|
+| **Graphiti v0.27.1 RediSearch bug** | `build_fulltext_query()` generates invalid syntax for tag fields. SynapseFlow applies a monkey-patch (lines 44-93 of `episodic_memory_service.py`). This is a fragile workaround that will break on Graphiti upgrades. | Medium |
+| **Missing community subgraph usage** | Graphiti's 3-tier model includes community detection (Tier 3), but SynapseFlow does not use community summaries. This means high-level, aggregated views of patient knowledge are unavailable. | Low |
+| **No temporal conflict resolution** | Graphiti's bi-temporal model supports invalidating outdated facts via `t_valid`/`t_invalid` on edges. The crystallization pipeline does not propagate or leverage these temporal markers when transferring to Neo4j. | Medium |
+| **`retrieve_recent_episodes()` scope limitation** | When no `session_id` is provided, only `group_id = patient_id` is queried. Turn-level episodes (which use `patient_id:session_id`) require knowing all session IDs -- the code acknowledges this with a comment but does not solve it. | Medium |
+| **Hardcoded search query in `crystallize_from_graphiti()`** | Line 540: `query="medical entity"` is a broad, hardcoded search term for batch crystallization. This limits entity discovery to medical contexts and may miss non-medical entities. | Low |
+| **No `SEMAPHORE_LIMIT` configuration** | Graphiti recommends tuning `SEMAPHORE_LIMIT` for LLM provider rate limits. SynapseFlow does not expose or configure this. | Low |
+| **Dual `add_entity` Cypher calls in GraphitiBackend** | `graphiti_backend.py` lines 51-90 execute two separate MERGE queries for the same entity (one for attributes-as-JSON, one for attributes-as-properties). This is redundant and could cause race conditions. | Low |
+| **`graphiti.py` initialization uses Neo4j driver only** | The `get_graphiti()` helper connects via Neo4j URI/user/password but does not support FalkorDB initialization, even though `EpisodicMemoryService` uses FalkorDB. These are two separate initialization paths that could diverge. | Low |
+| **Version pinning** | Pinned to `graphiti-core[falkordb]>=0.27.1,<0.28`. The latest stable release is v0.28.1 (Feb 2026), which may contain fixes for the RediSearch bug. Upgrading should be evaluated. | Medium |
+
+### 3.3 Test Coverage Assessment
+
+| Area | Coverage | Notes |
+|------|----------|-------|
+| EpisodicMemoryService | Good | 14 tests covering init, storage, retrieval, search, error handling, helpers |
+| CrystallizationService | Good | Tests for new/existing entity crystallization, batch processing, stats |
+| EntityResolver | Good | Name normalization, type mapping, exact match, merge operations |
+| PromotionGate | Good | Risk levels, approval/rejection criteria, stats |
+| Integration (end-to-end) | Partial | Event bus wiring tested, but no live FalkorDB/Graphiti integration test |
+| GraphitiBackend | Missing | No dedicated unit tests for the `KnowledgeGraphBackend` adapter |
+
+---
+
+## 4. Memory Management Sanity Check -- Best Practices for Agent Systems
+
+Based on Graphiti documentation, the Zep research paper, and industry patterns, here is a sanity checklist for agent memory management:
+
+### 4.1 Memory Architecture
+
+- [x] **Separate episodic from semantic memory** -- SynapseFlow uses FalkorDB for episodes, Neo4j for structured knowledge.
+- [x] **Multi-layer memory with different lifetimes** -- Redis (short-term, 24h TTL), Mem0 (mid-term), Neo4j (long-term).
+- [x] **Entity deduplication across memory layers** -- EntityResolver handles cross-layer dedup with multiple strategies.
+- [ ] **Community/summary layer** -- Graphiti supports community detection for high-level summaries. Not implemented.
+- [x] **Memory isolation per user/tenant** -- Group IDs partition data at the storage layer.
+
+### 4.2 Data Ingestion
+
+- [x] **Incremental updates, not batch recomputation** -- Episodes are added per conversation turn in real-time.
+- [x] **Structured episode format** -- Conversation turns use Graphiti's message format; session summaries use JSON.
+- [ ] **Temporal conflict resolution** -- Graphiti invalidates outdated edges via bi-temporal model, but this is not propagated to Neo4j.
+- [x] **Source traceability** -- Episodes include `source_description` with patient/session context.
+
+### 4.3 Retrieval
+
+- [x] **Hybrid search (semantic + keyword + graph)** -- Uses `COMBINED_HYBRID_SEARCH_CROSS_ENCODER`.
+- [x] **No LLM calls at retrieval time** -- Search avoids LLM inference for low latency.
+- [x] **Context assembly from multiple sources** -- `get_conversation_context()` merges recent, related, and entity data.
+- [x] **Result deduplication** -- Recent and related episodes are deduplicated by ID.
+- [ ] **Bound search results** -- Some search paths don't limit results explicitly (e.g., `crystallize_from_graphiti` uses a default limit of 100 but the search itself doesn't pass `num_results`).
+
+### 4.4 Knowledge Lifecycle
+
+- [x] **Confidence-based promotion** -- PERCEPTION -> SEMANTIC requires confidence >= 0.85 and observations >= 2.
+- [x] **Observation counting** -- Merge operations increment observation counts for promotion eligibility.
+- [x] **Audit trail** -- `first_observed`, `last_observed`, `source` fields track entity provenance.
+- [ ] **Memory invalidation/expiration** -- No mechanism to mark episodic or DIKW entities as expired/outdated.
+- [x] **Event-driven pipeline** -- `episode_added` events trigger crystallization without polling.
+
+### 4.5 Operational
+
+- [x] **Graceful error handling** -- Retrieval methods return empty results on error; storage methods propagate exceptions.
+- [x] **Idempotent initialization** -- `EpisodicMemoryService.initialize()` guards against double initialization.
+- [x] **Configurable processing modes** -- Crystallization supports event-driven, batch, and hybrid modes.
+- [ ] **LLM rate limit management** -- `SEMAPHORE_LIMIT` not configured; risk of 429 errors during high-throughput ingestion.
+- [x] **Flush/quiescence support** -- `flush_now()` and `is_quiescent()` enable evaluation framework integration.
+
+---
+
+## 5. Recommendations
+
+1. **Upgrade to graphiti-core v0.28.x** -- Evaluate whether the RediSearch syntax bug is fixed upstream and remove the monkey-patch if so.
+2. **Add temporal metadata to crystallized entities** -- Propagate Graphiti's `t_valid`/`t_invalid` to Neo4j DIKW nodes to enable temporal queries and fact invalidation.
+3. **Implement cross-session episode retrieval** -- Solve the `retrieve_recent_episodes` gap by querying for all sessions belonging to a patient, or use Graphiti's search which handles this implicitly.
+4. **Replace hardcoded crystallization query** -- Use entity type or timestamp-based queries instead of `"medical entity"` for broader discovery.
+5. **Add GraphitiBackend tests** -- The adapter has zero test coverage; add unit tests covering `add_entity`, `add_relationship`, and `query`.
+6. **Configure `SEMAPHORE_LIMIT`** -- Expose this as an environment variable to prevent LLM rate limit issues during episode ingestion bursts.
+7. **Consider community subgraph** -- For patients with long histories, community summaries could provide useful high-level context without retrieving individual episodes.
diff --git a/src/application/__pycache__/__init__.cpython-311.pyc b/src/application/__pycache__/__init__.cpython-311.pyc
index 98b295326..0d72f9ec8 100644
Binary files a/src/application/__pycache__/__init__.cpython-311.pyc and b/src/application/__pycache__/__init__.cpython-311.pyc differ
diff --git a/src/application/services/crystallization_service.py b/src/application/services/crystallization_service.py
index f9552d468..bd76a4b9a 100644
--- a/src/application/services/crystallization_service.py
+++ b/src/application/services/crystallization_service.py
@@ -122,6 +122,7 @@ def __init__(
         event_bus: Any,
         graphiti_client: Optional[Any] = None,
         config: Optional[CrystallizationConfig] = None,
+        invalidation_service: Optional[Any] = None,
     ):
         """
         Initialize crystallization service.
@@ -132,12 +133,14 @@ def __init__(
             event_bus: EventBus for event-driven processing
             graphiti_client: Graphiti client for querying FalkorDB
             config: Crystallization configuration
+            invalidation_service: Optional MemoryInvalidationService for stale sweeps
         """
         self.neo4j_backend = neo4j_backend
         self.entity_resolver = entity_resolver
         self.event_bus = event_bus
         self.graphiti = graphiti_client
         self.config = config or CrystallizationConfig()
+        self.invalidation_service = invalidation_service
 
         # State tracking
         self._last_crystallization: Optional[datetime] = None
@@ -240,6 +243,15 @@ async def _periodic_crystallization(self) -> None:
                 # Even without pending entities, check for new Graphiti entities
                 await self.crystallize_from_graphiti()
 
+                # SPEC-4: Run stale entity sweep after crystallization
+                if self.invalidation_service:
+                    try:
+                        sweep_result = await self.invalidation_service.sweep_stale_entities()
+                        if sweep_result.entities_invalidated > 0:
+                            logger.info(f"Stale sweep: invalidated {sweep_result.entities_invalidated} entities")
+                    except Exception as sweep_err:
+                        logger.warning(f"Stale sweep failed: {sweep_err}")
+
             except asyncio.CancelledError:
                 break
             except Exception as e:
@@ -316,14 +328,20 @@ async def crystallize_entities(
                 )
 
                 if match.found:
-                    # Merge with existing entity
+                    # Merge with existing entity (SPEC-2: include temporal fields)
+                    merge_data = {
+                        "confidence": confidence,
+                        "graphiti_entity_id": graphiti_id,
+                        "last_seen_in_episodic": datetime.utcnow().isoformat(),
+                    }
+                    if entity_data.get("valid_at"):
+                        merge_data["valid_at"] = entity_data["valid_at"]
+                    if entity_data.get("invalid_at"):
+                        merge_data["invalid_at"] = entity_data["invalid_at"]
+
                     merge_result = await self.entity_resolver.merge_for_crystallization(
                         existing_id=match.entity_id,
-                        new_data={
-                            "confidence": confidence,
-                            "graphiti_entity_id": graphiti_id,
-                            "last_seen_in_episodic": datetime.utcnow().isoformat(),
-                        }
+                        new_data=merge_data,
                     )
 
                     if merge_result.success:
@@ -381,6 +399,14 @@ async def crystallize_entities(
                 else:
                     errors.append(f"Failed to create entity: {name}")
 
+                # SPEC-2: Resolve temporal conflicts for this entity
+                if entity_data.get("valid_at"):
+                    await self._resolve_temporal_conflicts(
+                        entity_name=name,
+                        entity_type=entity_type,
+                        new_valid_from=entity_data["valid_at"],
+                    )
+
             except Exception as e:
                 logger.error(f"Error crystallizing entity {entity_data}: {e}")
                 errors.append(f"Error processing {entity_data}: {str(e)}")
@@ -455,6 +481,11 @@ async def _create_perception_entity(
             "first_observed": datetime.utcnow().isoformat(),
             "last_observed": datetime.utcnow().isoformat(),
             "source": "graphiti_episodic",
+            # SPEC-2: Temporal conflict resolution fields
+            "valid_from": source_data.get("valid_at", datetime.utcnow().isoformat()),
+            "valid_until": source_data.get("invalid_at"),
+            "invalidated_at": source_data.get("expired_at"),
+            "is_current": source_data.get("invalid_at") is None,
         }
 
         if graphiti_id:
@@ -480,6 +511,75 @@ async def _create_perception_entity(
             logger.error(f"Failed to create PERCEPTION entity {name}: {e}")
             return None
 
+    # ========================================
+    # Temporal Conflict Resolution (SPEC-2)
+    # ========================================
+
+    async def _resolve_temporal_conflicts(
+        self,
+        entity_name: str,
+        entity_type: str,
+        new_valid_from: Optional[str],
+    ) -> int:
+        """
+        Invalidate existing DIKW entities superseded by a newer fact.
+
+        When the same entity (by name+type) appears with a newer valid_from,
+        all previous is_current=true versions get marked as invalidated.
+        Old entities are never deleted -- only invalidated (preserving history).
+
+        Args:
+            entity_name: Name of the entity
+            entity_type: Type of the entity
+            new_valid_from: valid_from timestamp of the new fact
+
+        Returns:
+            Number of entities invalidated.
+        """
+        if not new_valid_from:
+            return 0
+
+        normalized_name = self.entity_resolver.normalize_entity_name(entity_name)
+        normalized_type = self.entity_resolver.normalize_entity_type(entity_type)
+
+        query = """
+            MATCH (n:Entity)
+            WHERE toLower(n.name) = $name
+              AND n.entity_type = $type
+              AND (n.is_current = true OR n.is_current IS NULL)
+              AND n.valid_from IS NOT NULL
+              AND n.valid_from < $new_valid_from
+            SET n.valid_until = $new_valid_from,
+                n.is_current = false,
+                n.invalidated_at = datetime()
+            RETURN count(n) as invalidated
+        """
+
+        try:
+            result = await self.neo4j_backend.query(
+                query,
+                {
+                    "name": normalized_name,
+                    "type": normalized_type,
+                    "new_valid_from": new_valid_from,
+                },
+            )
+
+            rows = result.get("rows", []) if isinstance(result, dict) else []
+            invalidated = rows[0].get("invalidated", 0) if rows else 0
+
+            if invalidated > 0:
+                logger.info(
+                    f"Temporal conflict: invalidated {invalidated} prior version(s) "
+                    f"of '{entity_name}' ({entity_type})"
+                )
+
+            return invalidated
+
+        except Exception as e:
+            logger.error(f"Error resolving temporal conflicts for '{entity_name}': {e}")
+            return 0
+
     async def crystallize_from_graphiti(
         self,
         since: Optional[datetime] = None,
@@ -528,29 +628,65 @@ async def crystallize_from_graphiti(
 
         # If we have search capability, use it
         if hasattr(self.graphiti, 'clients'):
+            from copy import deepcopy
             from graphiti_core.search.search import search
             from graphiti_core.search.search_config_recipes import COMBINED_HYBRID_SEARCH_CROSS_ENCODER
             from graphiti_core.search.search_filters import SearchFilters
 
-            # Search for all entities (broad query)
+            # SPEC-3: Pass explicit result bound + use timestamp-based query
+            # instead of hardcoded "medical entity"
+            search_config = deepcopy(COMBINED_HYBRID_SEARCH_CROSS_ENCODER)
+            if hasattr(search_config, 'limit'):
+                search_config.limit = limit
+            if hasattr(search_config, 'num_results'):
+                search_config.num_results = limit
+
+            # Use entity-type-agnostic query based on recency
+            query = f"entities since {since.strftime('%Y-%m-%d')}"
+            if patient_id:
+                query = f"patient {patient_id} {query}"
+
             results = await search(
                 clients=self.graphiti.clients,
-                query="medical entity",  # Broad query
+                query=query,
                 group_ids=[patient_id] if patient_id else None,
                 search_filter=SearchFilters(),
-                config=COMBINED_HYBRID_SEARCH_CROSS_ENCODER,
+                config=search_config,
+                num_results=limit,
             )
 
             for node in results.nodes[:limit]:
                 if node.created_at and node.created_at > since:
-                    entities.append({
+                    # SPEC-2: Extract temporal metadata from edges for this node
+                    node_edges = [
+                        e for e in (results.edges or [])
+                        if e.source_node_uuid == node.uuid
+                        or e.target_node_uuid == node.uuid
+                    ] if hasattr(results, 'edges') and results.edges else []
+                    latest_edge = (
+                        max(node_edges, key=lambda e: e.created_at)
+                        if node_edges else None
+                    )
+
+                    entity_entry = {
                         "name": node.name,
                         "entity_type": node.labels[0] if node.labels else "Entity",
                         "confidence": 0.8,  # Default confidence from Graphiti
                         "graphiti_id": node.uuid,
                         "summary": node.summary,
                         "patient_id": patient_id,
-                    })
+                    }
+
+                    # Propagate bi-temporal fields if available
+                    if latest_edge:
+                        if hasattr(latest_edge, 'valid_at') and latest_edge.valid_at:
+                            entity_entry["valid_at"] = latest_edge.valid_at.isoformat()
+                        if hasattr(latest_edge, 'invalid_at') and latest_edge.invalid_at:
+                            entity_entry["invalid_at"] = latest_edge.invalid_at.isoformat()
+                        if hasattr(latest_edge, 'expired_at') and latest_edge.expired_at:
+                            entity_entry["expired_at"] = latest_edge.expired_at.isoformat()
+
+                    entities.append(entity_entry)
 
         if entities:
             return await self.crystallize_entities(entities, source="graphiti_query")
diff --git a/src/application/services/entity_resolver.py b/src/application/services/entity_resolver.py
index 1c92ec7d7..df6503162 100644
--- a/src/application/services/entity_resolver.py
+++ b/src/application/services/entity_resolver.py
@@ -198,7 +198,8 @@ async def _get_existing_entities(
         self,
         entity_type: str,
-        context: Dict[str, Any]
+        context: Dict[str, Any],
+        limit: int = 100,
     ) -> List[Dict[str, Any]]:
         """
         Retrieve existing entities from the knowledge graph.
@@ -206,6 +207,7 @@ async def _get_existing_entities(
         Args:
             entity_type: Type of entity to retrieve
             context: Context for filtering (e.g., domain)
+            limit: Maximum number of entities to return
 
         Returns:
             List of existing entity dictionaries
@@ -213,14 +215,14 @@ async def _get_existing_entities(
         # Query the graph for entities of this type
         # This is backend-specific; we'll use a generic interface
         try:
-            # Assuming backend has a query method
+            # SPEC-3: Parameterized LIMIT instead of hardcoded value
             query = f"""
                 MATCH (e:{entity_type})
                 RETURN e.id AS id, e.name AS name, e AS properties
-                LIMIT 100
+                LIMIT $limit
             """
-            results = await self.backend.query(query)
+            results = await self.backend.query(query, {"limit": limit})
 
             entities = []
             for record in results:
@@ -862,6 +864,16 @@ async def merge_for_crystallization(
                 properties_updated.append(key)
                 updates[key] = value
 
+        # SPEC-2: Propagate temporal fields during merge
+        if "valid_at" in new_data and new_data["valid_at"]:
+            updates["valid_from"] = new_data["valid_at"]
+            properties_updated.append("valid_from")
+        if "invalid_at" in new_data and new_data["invalid_at"]:
+            updates["valid_until"] = new_data["invalid_at"]
+            updates["is_current"] = False
+            properties_updated.append("valid_until")
+            properties_updated.append("is_current")
+
         # Always update observation tracking
         current_count = existing_props.get("observation_count", 1)
         updates["observation_count"] = current_count + 1
diff --git a/src/application/services/episodic_memory_service.py b/src/application/services/episodic_memory_service.py
index b83a23a74..ee6ae2f78 100644
--- a/src/application/services/episodic_memory_service.py
+++ b/src/application/services/episodic_memory_service.py
@@ -20,8 +20,10 @@
 - Events include extracted entities for DIKW pipeline processing
 """
 
+import asyncio
 import logging
 import os
+from copy import deepcopy
 from datetime import datetime
 from typing import List, Dict, Any, Optional, TYPE_CHECKING
 from dataclasses import dataclass, field
@@ -117,6 +119,16 @@ class ConversationEpisode:
     entities: List[Dict[str, Any]] = field(default_factory=list)
 
 
+@dataclass
+class CommunitySummary:
+    """A community-level summary from Graphiti's community subgraph (SPEC-1)."""
+    community_id: str
+    summary: str
+    entity_count: int
+    key_entities: List[str]
+    updated_at: Optional[datetime] = None
+
+
 class EpisodicMemoryService:
     """
     Episodic memory service using Graphiti with FalkorDB.
@@ -153,6 +165,13 @@ def __init__(
         self.event_bus = event_bus
         self._initialized = False
 
+        # SPEC-5: Rate limit tracking
+        self._rate_limit_stats: Dict[str, Any] = {
+            "retries": 0,
+            "failures": 0,
+            "last_rate_limit": None,
+        }
+
         logger.info("EpisodicMemoryService initialized")
 
     async def initialize(self) -> None:
@@ -172,6 +191,45 @@ async def close(self) -> None:
         """Close the Graphiti connection."""
         await self.graphiti.close()
 
+    # ========================================
+    # Rate Limit Management (SPEC-5)
+    # ========================================
+
+    async def _add_episode_with_retry(self, **kwargs) -> Any:
+        """Call graphiti.add_episode with retry on LLM rate limit errors."""
+        max_retries = int(os.getenv("GRAPHITI_LLM_MAX_RETRIES", "3"))
+        retry_enabled = os.getenv("GRAPHITI_LLM_RETRY_ENABLED", "true").lower() in ("true", "1")
+
+        for attempt in range(max_retries + 1):
+            try:
+                return await self.graphiti.add_episode(**kwargs)
+            except Exception as e:
+                is_rate_limit = "429" in str(e) or "rate" in str(e).lower()
+                if is_rate_limit and retry_enabled and attempt < max_retries:
+                    wait = 2 ** (attempt + 1)  # 2, 4, 8 seconds
+                    self._rate_limit_stats["retries"] += 1
+                    self._rate_limit_stats["last_rate_limit"] = datetime.now().isoformat()
+                    logger.warning(
+                        f"LLM rate limit hit (attempt {attempt + 1}/{max_retries}), "
+                        f"retrying in {wait}s"
+                    )
+                    await asyncio.sleep(wait)
+                else:
+                    if is_rate_limit:
+                        self._rate_limit_stats["failures"] += 1
+                        self._rate_limit_stats["last_rate_limit"] = datetime.now().isoformat()
+                    raise
+
+    def get_health(self) -> Dict[str, Any]:
+        """Return health/status information including rate limit stats."""
+        return {
+            "initialized": self._initialized,
+            "rate_limit_retries": self._rate_limit_stats["retries"],
+            "rate_limit_failures": self._rate_limit_stats["failures"],
+            "last_rate_limit": self._rate_limit_stats["last_rate_limit"],
+            "semaphore_limit": int(os.getenv("SEMAPHORE_LIMIT", "10")),
+        }
+
     # ========================================
     # Episode Storage
     # ========================================
@@ -227,7 +285,7 @@ async def store_turn_episode(
         start_time = datetime.now()
 
         try:
-            result = await self.graphiti.add_episode(
+            result = await self._add_episode_with_retry(
                 name=episode_name,
                 episode_body=episode_body,
                 source_description=source_description,
@@ -320,7 +378,7 @@ async def store_session_episode(
         start_time = datetime.now()
 
         try:
-            result = await self.graphiti.add_episode(
+            result = await self._add_episode_with_retry(
                 name=episode_name,
                 episode_body=episode_body,
                 source_description=source_description,
@@ -427,15 +485,23 @@ async def search_episodes(
             group_ids.append(f"{patient_id}:{session_id}")
 
         try:
+            # SPEC-3: Pass explicit result bound to search
+            search_config = deepcopy(COMBINED_HYBRID_SEARCH_CROSS_ENCODER)
+            if hasattr(search_config, 'limit'):
+                search_config.limit = limit
+            if hasattr(search_config, 'num_results'):
+                search_config.num_results = limit
+
             results: SearchResults = await search(
                 clients=self.graphiti.clients,
                 query=query,
                 group_ids=group_ids,
                 search_filter=SearchFilters(),
-                config=COMBINED_HYBRID_SEARCH_CROSS_ENCODER,
+                config=search_config,
+                num_results=limit,
             )
 
-            # Convert search results to episodes
+            # Convert search results to episodes (bounded by search config above)
             episodes = []
             for ep in results.episodes[:limit]:
                 episodes.append(self._convert_episode(ep, patient_id))
@@ -470,12 +536,20 @@ async def get_related_entities(
             await self.initialize()
 
         try:
+            # SPEC-3: Pass explicit result bound to search
+            search_config = deepcopy(COMBINED_HYBRID_SEARCH_CROSS_ENCODER)
+            if hasattr(search_config, 'limit'):
+                search_config.limit = limit
+            if hasattr(search_config, 'num_results'):
+                search_config.num_results = limit
+
             results: SearchResults = await search(
                 clients=self.graphiti.clients,
                 query=query,
                 group_ids=[patient_id],
                 search_filter=SearchFilters(),
-                config=COMBINED_HYBRID_SEARCH_CROSS_ENCODER,
+                config=search_config,
+                num_results=limit,
             )
 
             entities = []
@@ -548,11 +622,104 @@ async def get_conversation_context(
         recent_ids = {ep.episode_id for ep in recent}
         related = [ep for ep in related if ep.episode_id not in recent_ids]
 
+        # SPEC-1: Get community summaries (gracefully degradable)
+        community_summaries: List[CommunitySummary] = []
+        try:
+            community_summaries = await self.get_community_summaries(
+                patient_id=patient_id,
+                limit=3,
+            )
+        except Exception as e:
+            logger.debug(f"Community summaries unavailable: {e}")
+
         return {
             "recent_episodes": [self._episode_to_dict(ep) for ep in recent],
             "related_episodes": [self._episode_to_dict(ep) for ep in related[:max_episodes - len(recent)]],
             "entities": entities,
-            "total_context_items": len(recent) + len(related) + len(entities),
+            "community_summaries": [self._summary_to_dict(s) for s in community_summaries],
+            "total_context_items": len(recent) + len(related) + len(entities) + len(community_summaries),
         }
+
+    # ========================================
+    # Community Summaries (SPEC-1)
+    # ========================================
+
+    async def get_community_summaries(
+        self,
+        patient_id: str,
+        limit: int = 5,
+    ) -> List[CommunitySummary]:
+        """Get community-level summaries from Graphiti's community subgraph.
+
+        Graphiti clusters strongly-connected entities into communities and
+        generates high-level summaries via label propagation. This method
+        queries those community nodes filtered by patient group_id.
+
+        If the Graphiti version does not expose community nodes, returns [].
+
+        Args:
+            patient_id: Patient identifier for group_id filtering.
+            limit: Maximum number of community summaries to return.
+
+        Returns:
+            List of CommunitySummary objects.
+        """
+        await self.initialize()
+
+        try:
+            driver = self.graphiti.driver if hasattr(self.graphiti, 'driver') else None
+            if driver is None:
+                logger.debug("No graph driver available for community queries")
+                return []
+
+            # Query community nodes from the graph
+            # Graphiti stores communities with the Community_ label prefix
+            query = """
+                MATCH (c:Community)
+                WHERE c.group_id = $group_id OR c.group_id IS NULL
+                OPTIONAL MATCH (c)-[:HAS_MEMBER]->(e:Entity)
+                WITH c, collect(DISTINCT e.name) AS member_names, count(DISTINCT e) AS member_count
+                RETURN c.uuid AS community_id,
+                       c.summary AS summary,
+                       member_count AS entity_count,
+                       member_names AS key_entities,
+                       c.updated_at AS updated_at
+                ORDER BY member_count DESC
+                LIMIT $limit
+            """
+
+            result = await driver.execute_query(
+                query, {"group_id": patient_id, "limit": limit}
+            )
+
+            summaries = []
+            rows = result if isinstance(result, list) else []
+            for row in rows:
+                if not row.get("summary"):
+                    continue
+                summaries.append(CommunitySummary(
+                    community_id=row.get("community_id", ""),
+                    summary=row.get("summary", ""),
+                    entity_count=row.get("entity_count", 0),
+                    key_entities=row.get("key_entities", [])[:10],
+                    updated_at=row.get("updated_at"),
+                ))
+
+            logger.debug(f"Retrieved {len(summaries)} community summaries for patient {patient_id}")
+            return summaries
+
+        except Exception as e:
+            logger.debug(f"Community summaries unavailable: {e}")
+            return []
+
+    def _summary_to_dict(self, summary: CommunitySummary) -> Dict[str, Any]:
+        """Convert CommunitySummary to dictionary."""
+        return {
+            "community_id": summary.community_id,
+            "summary": summary.summary,
+            "entity_count": summary.entity_count,
+            "key_entities": summary.key_entities,
+            "updated_at": summary.updated_at.isoformat() if summary.updated_at else None,
+        }
 
     # ========================================
diff --git a/src/application/services/memory_invalidation_service.py b/src/application/services/memory_invalidation_service.py
new file mode 100644
index 000000000..792f532f9
--- /dev/null
+++ b/src/application/services/memory_invalidation_service.py
@@ -0,0 +1,260 @@
+"""Memory Invalidation Service (SPEC-4).
+
+Provides explicit and TTL-based invalidation of DIKW entities.
+Entities are never deleted -- only marked as invalidated with an audit trail.
+
+Invalidation sources:
+1. Explicit API call (manual invalidation)
+2. Temporal conflict resolution (SPEC-2, automatic)
+3. TTL-based staleness sweep (periodic, PERCEPTION layer only)
+"""
+
+import logging
+from dataclasses import dataclass, field
+from datetime import datetime, timedelta
+from typing import Any, Dict, List, Optional, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from domain.kg_backends import KnowledgeGraphBackend
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class InvalidationConfig:
+    """Configuration for memory invalidation."""
+
+    stale_threshold_days: int = 90
+    stale_check_enabled: bool = True
+    episodic_ttl_days: Optional[int] = None  # None = no auto-expiry for episodes
+
+
+@dataclass
+class InvalidationResult:
+    """Result of an invalidation operation."""
+
+    entities_invalidated: int
+    entity_ids: List[str]
+    reason: str
+    timestamp: datetime = field(default_factory=datetime.utcnow)
+
+
+class MemoryInvalidationService:
+    """Service for invalidating DIKW entities without deletion.
+
+    Supports:
+    - Explicit single-entity invalidation
+    - Bulk invalidation by query filters
+    - Periodic staleness sweep for PERCEPTION entities
+    """
+
+    def __init__(
+        self,
+        neo4j_backend: "KnowledgeGraphBackend",
+        config: Optional[InvalidationConfig] = None,
+    ):
+        self.neo4j_backend = neo4j_backend
+        self.config = config or InvalidationConfig()
+
+        self._stats = {
+            "total_invalidated": 0,
+            "sweep_runs": 0,
+            "last_sweep": None,
+        }
+
+    async def invalidate_entity(
+        self,
+        entity_id: str,
+        reason: str = "manual",
+    ) -> InvalidationResult:
+        """Explicitly invalidate a single DIKW entity.
+
+        Sets is_current=false, invalidated_at=now(), invalidation_reason=reason.
+        Does NOT delete the entity.
+
+        Args:
+            entity_id: ID of the entity to invalidate.
+            reason: Reason for invalidation (stored on the node).
+
+        Returns:
+            InvalidationResult with details.
+        """
+        query = """
+            MATCH (n:Entity {id: $entity_id})
+            WHERE n.is_current = true OR n.is_current IS NULL
+            SET n.is_current = false,
+                n.invalidated_at = datetime(),
+                n.invalidation_reason = $reason,
+                n.valid_until = COALESCE(n.valid_until, datetime())
+            RETURN n.id as id
+        """
+
+        try:
+            result = await self.neo4j_backend.query(
+                query, {"entity_id": entity_id, "reason": reason}
+            )
+
+            rows = result.get("rows", []) if isinstance(result, dict) else []
+            ids = [r.get("id") for r in rows if r.get("id")]
+
+            if ids:
+                self._stats["total_invalidated"] += len(ids)
+                logger.info(f"Invalidated entity {entity_id}: reason={reason}")
+
+            return InvalidationResult(
+                entities_invalidated=len(ids),
+                entity_ids=ids,
+                reason=reason,
+            )
+
+        except Exception as e:
+            logger.error(f"Error invalidating entity {entity_id}: {e}")
+            return InvalidationResult(
+                entities_invalidated=0,
+                entity_ids=[],
+                reason=reason,
+            )
+
+    async def invalidate_by_query(
+        self,
+        patient_id: Optional[str] = None,
+        entity_type: Optional[str] = None,
+        entity_name: Optional[str] = None,
+        reason: str = "bulk_invalidation",
+    ) -> InvalidationResult:
+        """Invalidate multiple entities matching filter criteria.
+
+        Args:
+            patient_id: Optional patient filter.
+            entity_type: Optional entity type filter.
+            entity_name: Optional entity name filter (case-insensitive).
+            reason: Reason for invalidation.
+
+        Returns:
+            InvalidationResult with details.
+        """
+        conditions = ["(n.is_current = true OR n.is_current IS NULL)"]
+        params: Dict[str, Any] = {"reason": reason}
+
+        if patient_id:
+            conditions.append("n.patient_id = $patient_id")
+            params["patient_id"] = patient_id
+        if entity_type:
+            conditions.append("n.entity_type = $entity_type")
+            params["entity_type"] = entity_type
+        if entity_name:
+            conditions.append("toLower(n.name) = $entity_name")
+            params["entity_name"] = entity_name.lower()
+
+        where_clause = " AND ".join(conditions)
+        query = f"""
+            MATCH (n:Entity)
+            WHERE {where_clause}
+            SET n.is_current = false,
+                n.invalidated_at = datetime(),
+                n.invalidation_reason = $reason,
+                n.valid_until = COALESCE(n.valid_until, datetime())
+            RETURN n.id as id
+        """
+
+        try:
+            result = await self.neo4j_backend.query(query, params)
+
+            rows = result.get("rows", []) if isinstance(result, dict) else []
+            ids = [r.get("id") for r in rows if r.get("id")]
+
+            if ids:
+                self._stats["total_invalidated"] += len(ids)
+                logger.info(
+                    f"Bulk invalidated {len(ids)} entities: reason={reason}, "
+                    f"patient_id={patient_id}, entity_type={entity_type}"
+                )
+
+            return InvalidationResult(
+                entities_invalidated=len(ids),
+                entity_ids=ids,
+                reason=reason,
+            )
+
+        except Exception as e:
+            logger.error(f"Error in bulk invalidation: {e}")
+            return InvalidationResult(
+                entities_invalidated=0,
+                entity_ids=[],
+                reason=reason,
+            )
+
+    async def sweep_stale_entities(self) -> InvalidationResult:
+        """Find and invalidate PERCEPTION entities not observed recently.
+
+        Only targets PERCEPTION layer -- higher layers (SEMANTIC, REASONING,
+        APPLICATION) are considered validated and never auto-invalidated.
+
+        Uses config.stale_threshold_days to determine staleness.
+
+        Returns:
+            InvalidationResult with details of invalidated entities.
+        """
+        if not self.config.stale_check_enabled:
+            return InvalidationResult(
+                entities_invalidated=0,
+                entity_ids=[],
+                reason="stale_sweep_disabled",
+            )
+
+        cutoff = datetime.utcnow() - timedelta(days=self.config.stale_threshold_days)
+        cutoff_str = cutoff.isoformat()
+
+        query = """
+            MATCH (n:Entity)
+            WHERE n.dikw_layer = 'PERCEPTION'
+              AND (n.is_current = true OR n.is_current IS NULL)
+              AND n.last_observed IS NOT NULL
+              AND n.last_observed < $cutoff_date
+            SET n.is_current = false,
+                n.invalidated_at = datetime(),
+                n.invalidation_reason = 'stale_ttl'
+            RETURN n.id as id, n.name as name
+        """
+
+        try:
+            result = await self.neo4j_backend.query(
+                query, {"cutoff_date": cutoff_str}
+            )
+
+            rows = result.get("rows", []) if isinstance(result, dict) else []
+            ids = [r.get("id") for r in rows if r.get("id")]
+
+            self._stats["sweep_runs"] += 1
+            self._stats["last_sweep"] = datetime.utcnow().isoformat()
+
+            if ids:
+                self._stats["total_invalidated"] += len(ids)
+                logger.info(
+                    f"Stale sweep: invalidated {len(ids)} PERCEPTION entities "
+                    f"(threshold: {self.config.stale_threshold_days} days)"
+                )
+
+            return InvalidationResult(
+                entities_invalidated=len(ids),
+                entity_ids=ids,
+                reason="stale_ttl",
+            )
+
+        except Exception as e:
+            logger.error(f"Error in stale entity sweep: {e}")
+            return InvalidationResult(
+                entities_invalidated=0,
+                entity_ids=[],
+                reason="stale_ttl",
+            )
+
+    def get_invalidation_stats(self) -> Dict[str, Any]:
+        """Return invalidation service statistics."""
+        return {
+            "total_invalidated": self._stats["total_invalidated"],
+            "sweep_runs": self._stats["sweep_runs"],
+            "last_sweep": self._stats["last_sweep"],
+            "stale_threshold_days": self.config.stale_threshold_days,
+            "stale_check_enabled": self.config.stale_check_enabled,
+        }
diff --git a/src/composition_root.py b/src/composition_root.py
index 
3f150fd9c..285831e5c 100644
--- a/src/composition_root.py
+++ b/src/composition_root.py
@@ -334,7 +334,12 @@ async def bootstrap_episodic_memory(event_bus: Optional[EventBus] = None):
         print("ℹ️ Episodic memory not enabled (set ENABLE_EPISODIC_MEMORY=true to enable)")
         return None
 
+    # SPEC-5: Configure Graphiti's LLM concurrency BEFORE importing graphiti_core
+    semaphore_limit = os.getenv("GRAPHITI_SEMAPHORE_LIMIT", "5")
+    os.environ["SEMAPHORE_LIMIT"] = semaphore_limit
+
     print("🔄 Initializing Episodic Memory Service (Graphiti + FalkorDB)...")
+    print(f"   Graphiti SEMAPHORE_LIMIT set to {semaphore_limit}")
 
     try:
         from application.services.episodic_memory_service import create_episodic_memory_service
@@ -442,12 +447,29 @@ async def bootstrap_crystallization_pipeline(
         ).lower() in ("true", "1", "yes"),
     )
 
+    # SPEC-4: Create MemoryInvalidationService for stale entity sweeps
+    from application.services.memory_invalidation_service import (
+        MemoryInvalidationService,
+        InvalidationConfig,
+    )
+
+    invalidation_service = MemoryInvalidationService(
+        neo4j_backend=neo4j_backend,
+        config=InvalidationConfig(
+            stale_threshold_days=int(os.getenv("MEMORY_STALE_THRESHOLD_DAYS", "90")),
+            stale_check_enabled=os.getenv("ENABLE_STALE_SWEEP", "true").lower()
+            in ("true", "1"),
+        ),
+    )
+    print("   ✅ MemoryInvalidationService initialized")
+
     crystallization_service = CrystallizationService(
         neo4j_backend=neo4j_backend,
         entity_resolver=entity_resolver,
         event_bus=event_bus,
         graphiti_client=graphiti_client,
         config=crystallization_config,
+        invalidation_service=invalidation_service,
     )
 
     # Start the crystallization service
diff --git a/src/domain/__pycache__/__init__.cpython-311.pyc b/src/domain/__pycache__/__init__.cpython-311.pyc
index 0a16ce54f..c47005a7f 100644
Binary files a/src/domain/__pycache__/__init__.cpython-311.pyc and b/src/domain/__pycache__/__init__.cpython-311.pyc differ
diff --git a/src/multi_agent_system.egg-info/SOURCES.txt b/src/multi_agent_system.egg-info/SOURCES.txt
index 20edeeced..10caf94e2 100644
--- a/src/multi_agent_system.egg-info/SOURCES.txt
+++ b/src/multi_agent_system.egg-info/SOURCES.txt
@@ -88,6 +88,8 @@ docs/architecture/ARCHITECTURE_DIAGRAM.md
 docs/architecture/DECISION_TREE.md
 docs/architecture/DIKW_ARCHITECTURE_PLAN.md
 docs/architecture/ENABLE_RAG_AND_REASONING.md
+docs/architecture/GRAPHITI_ALIGNMENT_SPECS.md
+docs/architecture/GRAPHITI_MEMORY_BRIEF.md
 docs/architecture/HYPERGRAPH_BRIDGE_ARCHITECTURE.md
 docs/architecture/KNOWLEDGE_GRAPH_LAYERS_ARCHITECTURE.md
 docs/architecture/LANGGRAPH_ARCHITECTURE.md
@@ -17922,6 +17924,7 @@ tests/test_episodic_memory.py
 tests/test_evaluation_framework.py
 tests/test_event_bus.py
 tests/test_generate_data_map.py
+tests/test_graphiti_alignment_specs.py
 tests/test_improved_reasoning.py
 tests/test_in_memory_backend.py
 tests/test_intelligent_chat_integration.py
diff --git a/tests/test_crystallization_pipeline.py b/tests/test_crystallization_pipeline.py
index 97ea56f14..35e43c0a6 100644
--- a/tests/test_crystallization_pipeline.py
+++ b/tests/test_crystallization_pipeline.py
@@ -349,6 +349,249 @@ async def test_promotion_stats(self, gate):
         assert stats.total_rejected == 0
 
 
+# ========================================
+# SPEC-2: Temporal Conflict Resolution Tests
+# ========================================
+
+class TestTemporalConflictResolution:
+    """Tests for SPEC-2: Temporal conflict resolution in crystallization."""
+
+    @pytest.fixture
+    def mock_backend(self):
+        backend = AsyncMock()
+        backend.add_entity = AsyncMock()
+        backend.query = AsyncMock(return_value={"rows": []})
+        return backend
+
+    @pytest.fixture
+    def mock_resolver(self):
+        resolver = AsyncMock()
+        resolver.find_existing_for_crystallization = AsyncMock(
+            return_value=CrystallizationMatch(found=False)
+        )
+        resolver.merge_for_crystallization = AsyncMock(
+            return_value=MergeResult(success=True, entity_id="merged_123", observation_count=2)
+        )
+        resolver.normalize_entity_type = MagicMock(side_effect=lambda x: x.title())
+        
resolver.normalize_entity_name = MagicMock(side_effect=lambda x: x.strip().lower())
+        return resolver
+
+    @pytest.fixture
+    def service(self, mock_backend, mock_resolver):
+        event_bus = EventBus()
+        config = CrystallizationConfig(mode=CrystallizationMode.BATCH)
+        return CrystallizationService(
+            neo4j_backend=mock_backend,
+            entity_resolver=mock_resolver,
+            event_bus=event_bus,
+            config=config,
+        )
+
+    @pytest.mark.asyncio
+    async def test_new_entity_has_temporal_fields(self, service, mock_backend):
+        """New PERCEPTION entity includes valid_from, valid_until, is_current."""
+        result = await service.crystallize_entities(
+            entities=[{"name": "Metformin", "entity_type": "Medication"}],
+            source="test",
+        )
+
+        assert result.entities_created == 1
+        call_args = mock_backend.add_entity.call_args
+        props = call_args[1].get("properties", call_args[0][1] if len(call_args[0]) > 1 else {})
+        assert "valid_from" in props
+        assert "is_current" in props
+        assert props["is_current"] is True
+        assert props["valid_until"] is None
+
+    @pytest.mark.asyncio
+    async def test_entity_with_invalid_at_is_not_current(self, service, mock_backend):
+        """Entity with invalid_at should have is_current=False."""
+        result = await service.crystallize_entities(
+            entities=[{
+                "name": "Metformin",
+                "entity_type": "Medication",
+                "invalid_at": "2024-06-01T00:00:00",
+            }],
+            source="test",
+        )
+
+        assert result.entities_created == 1
+        call_args = mock_backend.add_entity.call_args
+        props = call_args[1].get("properties", call_args[0][1] if len(call_args[0]) > 1 else {})
+        assert props["is_current"] is False
+        assert props["valid_until"] == "2024-06-01T00:00:00"
+
+    @pytest.mark.asyncio
+    async def test_resolve_temporal_conflicts_invalidates_old(self, service, mock_backend):
+        """_resolve_temporal_conflicts marks older entities as not current."""
+        mock_backend.query.return_value = {"rows": [{"invalidated": 3}]}
+
+        count = await service._resolve_temporal_conflicts(
+            entity_name="Metformin",
+            entity_type="Medication",
+            new_valid_from="2024-06-01T00:00:00",
+        )
+
+        assert count == 3
+        mock_backend.query.assert_called_once()
+        query_str = mock_backend.query.call_args[0][0]
+        assert "is_current = false" in query_str
+
+    @pytest.mark.asyncio
+    async def test_resolve_temporal_conflicts_skips_when_no_valid_from(self, service, mock_backend):
+        """No invalidation happens when new_valid_from is None."""
+        count = await service._resolve_temporal_conflicts(
+            entity_name="Metformin",
+            entity_type="Medication",
+            new_valid_from=None,
+        )
+
+        assert count == 0
+        mock_backend.query.assert_not_called()
+
+    @pytest.mark.asyncio
+    async def test_crystallize_with_valid_at_triggers_conflict_resolution(self, service, mock_backend):
+        """Entity with valid_at triggers _resolve_temporal_conflicts call."""
+        mock_backend.query.return_value = {"rows": [{"invalidated": 0}]}
+
+        await service.crystallize_entities(
+            entities=[{
+                "name": "Metformin",
+                "entity_type": "Medication",
+                "valid_at": "2024-06-01T00:00:00",
+            }],
+            source="test",
+        )
+
+        # The query should have been called for conflict resolution
+        assert mock_backend.query.call_count >= 1
+
+    @pytest.mark.asyncio
+    async def test_merge_passes_temporal_fields(self, service, mock_backend, mock_resolver):
+        """When merging, temporal fields from incoming data are passed through."""
+        mock_resolver.find_existing_for_crystallization.return_value = CrystallizationMatch(
+            found=True,
+            entity_id="existing_123",
+            entity_data={"name": "Metformin", "layer": "PERCEPTION"},
+            match_type="exact",
+            similarity_score=1.0,
+        )
+
+        await service.crystallize_entities(
+            entities=[{
+                "name": "Metformin",
+                "entity_type": "Medication",
+                "valid_at": "2024-06-01T00:00:00",
+                "invalid_at": "2024-12-01T00:00:00",
+            }],
+            source="test",
+        )
+
+        merge_call = mock_resolver.merge_for_crystallization.call_args
+        new_data = merge_call[1].get("new_data", merge_call[0][1] if len(merge_call[0]) > 1 else {})
+        assert new_data.get("valid_at") == "2024-06-01T00:00:00"
+        assert new_data.get("invalid_at") == "2024-12-01T00:00:00"
+
+
+# ========================================
+# SPEC-2: EntityResolver Temporal Merge
+# ========================================
+
+class TestEntityResolverTemporalMerge:
+    """Tests for temporal field handling in EntityResolver."""
+
+    @pytest.fixture
+    def mock_backend(self):
+        backend = AsyncMock()
+        backend.query = AsyncMock(return_value={
+            "rows": [{
+                "properties": {
+                    "name": "Metformin",
+                    "confidence": 0.8,
+                    "observation_count": 2,
+                }
+            }],
+        })
+        return backend
+
+    @pytest.fixture
+    def resolver(self, mock_backend):
+        return EntityResolver(backend=mock_backend)
+
+    @pytest.mark.asyncio
+    async def test_merge_stores_valid_from(self, resolver, mock_backend):
+        """merge_for_crystallization stores valid_from when valid_at provided."""
+        result = await resolver.merge_for_crystallization(
+            existing_id="entity_123",
+            new_data={"confidence": 0.85, "valid_at": "2024-06-01T00:00:00"},
+        )
+
+        assert result.success
+        update_call = mock_backend.query.call_args_list[-1]
+        params = update_call[0][1] if len(update_call[0]) > 1 else {}
+        if "updates" in params:
+            assert "valid_from" in params["updates"]
+
+    @pytest.mark.asyncio
+    async def test_merge_marks_not_current_when_invalid_at(self, resolver, mock_backend):
+        """merge_for_crystallization sets is_current=False with invalid_at."""
+        result = await resolver.merge_for_crystallization(
+            existing_id="entity_123",
+            new_data={"confidence": 0.85, "invalid_at": "2024-12-01T00:00:00"},
+        )
+
+        assert result.success
+        update_call = mock_backend.query.call_args_list[-1]
+        params = update_call[0][1] if len(update_call[0]) > 1 else {}
+        if "updates" in params:
+            assert params["updates"].get("is_current") is False
+            assert "valid_until" in params["updates"]
+
+
+# ========================================
+# SPEC-4: Invalidation Service Integration
+# ========================================
+
+class TestCrystallizationWithInvalidation:
+    """Tests for SPEC-4: MemoryInvalidationService integration."""
+
+    @pytest.mark.asyncio
+    async def test_service_accepts_invalidation_service(self):
+        """CrystallizationService accepts invalidation_service parameter."""
+        mock_backend = AsyncMock()
+        mock_resolver = AsyncMock()
+        mock_resolver.normalize_entity_type = MagicMock(side_effect=lambda x: x.title())
+        event_bus = EventBus()
+        mock_invalidation = AsyncMock()
+
+        service = CrystallizationService(
+            neo4j_backend=mock_backend,
+            entity_resolver=mock_resolver,
+            event_bus=event_bus,
+            config=CrystallizationConfig(mode=CrystallizationMode.BATCH),
+            invalidation_service=mock_invalidation,
+        )
+
+        assert service.invalidation_service is mock_invalidation
+
+    @pytest.mark.asyncio
+    async def test_service_works_without_invalidation_service(self):
+        """CrystallizationService works with invalidation_service=None."""
+        mock_backend = AsyncMock()
+        mock_resolver = AsyncMock()
+        mock_resolver.normalize_entity_type = MagicMock(side_effect=lambda x: x.title())
+        event_bus = EventBus()
+
+        service = CrystallizationService(
+            neo4j_backend=mock_backend,
+            entity_resolver=mock_resolver,
+            event_bus=event_bus,
+            config=CrystallizationConfig(mode=CrystallizationMode.BATCH),
+        )
+
+        assert service.invalidation_service is None
+
+
 # ========================================
 # Integration Tests
 # ========================================
diff --git a/tests/test_episodic_memory.py b/tests/test_episodic_memory.py
index 59fc1ef95..e2b9c3d22 100644
--- a/tests/test_episodic_memory.py
+++ b/tests/test_episodic_memory.py
@@ -13,6 +13,7 @@
     EpisodicMemoryService,
     EpisodeResult,
     ConversationEpisode,
+    CommunitySummary,
 )
@@ -51,6 +52,10 @@ def mock_graphiti():
     # Mock close
     mock.close = AsyncMock()
 
+    # Mock driver for community queries (SPEC-1)
+    mock.driver = AsyncMock()
+    mock.driver.execute_query = AsyncMock(return_value=[])
+
     return mock
@@ -373,3 +378,163 @@ def test_episode_to_dict():
     assert result["turn_number"] == 1
     assert
result["mode"] == "casual_chat"
     assert result["topics"] == ["greeting"]
+
+
+# ============================================================
+# SPEC-1: Community Summary Tests
+# ============================================================
+
+@pytest.mark.asyncio
+async def test_get_community_summaries(episodic_memory_service, mock_graphiti):
+    """Test getting community summaries from episodic memory."""
+    mock_graphiti.driver.execute_query.return_value = [
+        {
+            "community_id": "comm-1",
+            "summary": "Medications and dosages",
+            "entity_count": 4,
+            "key_entities": ["Metformin", "Insulin"],
+            "updated_at": datetime(2024, 6, 1),
+        },
+    ]
+
+    summaries = await episodic_memory_service.get_community_summaries(
+        patient_id="patient-123",
+        limit=5,
+    )
+
+    assert len(summaries) == 1
+    assert isinstance(summaries[0], CommunitySummary)
+    assert summaries[0].summary == "Medications and dosages"
+
+
+@pytest.mark.asyncio
+async def test_get_community_summaries_graceful_on_error(episodic_memory_service, mock_graphiti):
+    """Test that community summary errors don't propagate."""
+    mock_graphiti.driver.execute_query.side_effect = Exception("Not available")
+
+    summaries = await episodic_memory_service.get_community_summaries(
+        patient_id="patient-123",
+    )
+
+    assert summaries == []
+
+
+@pytest.mark.asyncio
+async def test_conversation_context_includes_communities(episodic_memory_service, mock_graphiti):
+    """Test that conversation context includes community_summaries."""
+    mock_graphiti.retrieve_episodes.return_value = []
+    mock_graphiti.driver.execute_query.return_value = []
+
+    with patch("application.services.episodic_memory_service.search") as mock_search:
+        mock_results = MagicMock()
+        mock_results.episodes = []
+        mock_results.nodes = []
+        mock_search.return_value = mock_results
+
+        context = await episodic_memory_service.get_conversation_context(
+            patient_id="patient-123",
+            current_query="test",
+        )
+
+    assert "community_summaries" in context
+
+
+# ============================================================
+# SPEC-3: Bound Search Results Tests
+# ============================================================
+
+@pytest.mark.asyncio
+async def test_search_episodes_passes_num_results(episodic_memory_service):
+    """Test that search_episodes passes num_results to search."""
+    with patch("application.services.episodic_memory_service.search") as mock_search:
+        mock_results = MagicMock()
+        mock_results.episodes = []
+        mock_search.return_value = mock_results
+
+        await episodic_memory_service.search_episodes(
+            patient_id="patient-123",
+            query="test",
+            limit=7,
+        )
+
+        call_kwargs = mock_search.call_args[1]
+        assert call_kwargs["num_results"] == 7
+
+
+@pytest.mark.asyncio
+async def test_get_related_entities_passes_num_results(episodic_memory_service):
+    """Test that get_related_entities passes num_results to search."""
+    with patch("application.services.episodic_memory_service.search") as mock_search:
+        mock_results = MagicMock()
+        mock_results.nodes = []
+        mock_search.return_value = mock_results
+
+        await episodic_memory_service.get_related_entities(
+            patient_id="patient-123",
+            query="test",
+            limit=15,
+        )
+
+        call_kwargs = mock_search.call_args[1]
+        assert call_kwargs["num_results"] == 15
+
+
+# ============================================================
+# SPEC-5: LLM Rate Limit Tests
+# ============================================================
+
+@pytest.mark.asyncio
+async def test_store_episode_uses_retry_wrapper(episodic_memory_service, mock_graphiti):
+    """Test that store_turn_episode routes through _add_episode_with_retry."""
+    # Verify by calling store and confirming the mock was called (via retry wrapper)
+    result = await episodic_memory_service.store_turn_episode(
+        patient_id="patient-123",
+        session_id="session-abc",
+        user_message="Hello",
+        assistant_message="Hi!",
+        turn_number=1,
+    )
+
+    assert result.episode_id == "ep-123"
+    mock_graphiti.add_episode.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_rate_limit_retry_increments_stats(episodic_memory_service, mock_graphiti):
+    """Test that rate limit retries are tracked in stats."""
+    call_count = 0
+
+    async def rate_limit_then_succeed(**kwargs):
+        nonlocal call_count
+        call_count += 1
+        if call_count == 1:
+            raise Exception("429 Rate limit exceeded")
+        mock_result = MagicMock()
+        mock_result.episode = MagicMock(uuid="ep-retry")
+        mock_result.nodes = []
+        mock_result.edges = []
+        return mock_result
+
+    mock_graphiti.add_episode.side_effect = rate_limit_then_succeed
+
+    with patch("application.services.episodic_memory_service.asyncio.sleep", new_callable=AsyncMock):
+        result = await episodic_memory_service.store_turn_episode(
+            patient_id="patient-123",
+            session_id="session-abc",
+            user_message="Hello",
+            assistant_message="Hi!",
+            turn_number=1,
+        )
+
+    assert result.episode_id == "ep-retry"
+    assert episodic_memory_service._rate_limit_stats["retries"] == 1
+
+
+def test_get_health_returns_expected_keys(episodic_memory_service):
+    """Test that get_health returns all expected fields."""
+    health = episodic_memory_service.get_health()
+
+    assert "initialized" in health
+    assert "rate_limit_retries" in health
+    assert "rate_limit_failures" in health
+    assert "semaphore_limit" in health
diff --git a/tests/test_graphiti_alignment_specs.py b/tests/test_graphiti_alignment_specs.py
new file mode 100644
index 000000000..65307f2a9
--- /dev/null
+++ b/tests/test_graphiti_alignment_specs.py
@@ -0,0 +1,645 @@
+"""Tests for the 5 Graphiti Alignment Specs.
+
+SPEC-1: Community/summary layer
+SPEC-2: Temporal conflict resolution
+SPEC-3: Bound search results
+SPEC-4: Memory invalidation/expiration
+SPEC-5: LLM rate limit management
+"""
+
+import os
+import pytest
+from datetime import datetime, timedelta
+from unittest.mock import AsyncMock, MagicMock, patch
+
+from application.services.episodic_memory_service import (
+    EpisodicMemoryService,
+    CommunitySummary,
+)
+from application.services.crystallization_service import (
+    CrystallizationService,
+    CrystallizationConfig,
+    CrystallizationMode,
+)
+from application.services.entity_resolver import (
+    EntityResolver,
+    CrystallizationMatch,
+    MergeResult,
+)
+from application.services.memory_invalidation_service import (
+    MemoryInvalidationService,
+    InvalidationConfig,
+    InvalidationResult,
+)
+
+
+# ============================================================
+# FIXTURES
+# ============================================================
+
+@pytest.fixture
+def mock_graphiti():
+    """Create a mock Graphiti instance."""
+    mock = AsyncMock()
+    mock.build_indices_and_constraints = AsyncMock()
+
+    mock_episode = MagicMock()
+    mock_episode.uuid = "ep-123"
+    mock_result = MagicMock()
+    mock_result.episode = mock_episode
+    # Set .name explicitly: the name= kwarg in the MagicMock constructor
+    # names the mock itself and does NOT create a .name attribute.
+    mock_node = MagicMock()
+    mock_node.name = "Entity1"
+    mock_result.nodes = [mock_node]
+    mock_result.edges = [MagicMock()]
+    mock.add_episode = AsyncMock(return_value=mock_result)
+    mock.retrieve_episodes = AsyncMock(return_value=[])
+    mock.clients = MagicMock()
+    mock.close = AsyncMock()
+
+    # Mock driver for community queries
+    mock.driver = AsyncMock()
+    mock.driver.execute_query = AsyncMock(return_value=[])
+
+    return mock
+
+
+@pytest.fixture
+def episodic_service(mock_graphiti):
+    """Create EpisodicMemoryService with mocked Graphiti."""
+    return EpisodicMemoryService(graphiti=mock_graphiti)
+
+
+@pytest.fixture
+def mock_neo4j_backend():
+    """Create a mock Neo4j backend."""
+    backend = AsyncMock()
+    backend.add_entity = AsyncMock()
+    backend.query = AsyncMock(return_value={"rows": []})
+    return backend
+
+
+@pytest.fixture
+def mock_resolver(mock_neo4j_backend):
+    """Create a mock EntityResolver."""
+    resolver = AsyncMock()
+    resolver.find_existing_for_crystallization = AsyncMock(
+        return_value=CrystallizationMatch(found=False)
+    )
+    resolver.merge_for_crystallization = AsyncMock(
+        return_value=MergeResult(success=True, entity_id="merged_123", observation_count=2)
+    )
+    resolver.normalize_entity_type = MagicMock(side_effect=lambda x: x.title())
+    resolver.normalize_entity_name = MagicMock(side_effect=lambda x: x.strip().lower())
+    return resolver
+
+
+@pytest.fixture
+def mock_event_bus():
+    """Create a mock EventBus."""
+    bus = AsyncMock()
+    bus.subscribe = MagicMock()
+    bus.publish = AsyncMock()
+    return bus
+
+
+# ============================================================
+# SPEC-1: Community / Summary Layer
+# ============================================================
+
+class TestSpec1CommunitySummaries:
+    """Tests for SPEC-1: Community/summary layer."""
+
+    @pytest.mark.asyncio
+    async def test_get_community_summaries_returns_results(self, episodic_service, mock_graphiti):
+        """Community nodes from graph are returned as CommunitySummary list."""
+        mock_graphiti.driver.execute_query.return_value = [
+            {
+                "community_id": "comm-1",
+                "summary": "Patient medications and treatment history",
+                "entity_count": 5,
+                "key_entities": ["Metformin", "Diabetes", "Insulin"],
+                "updated_at": datetime(2024, 6, 1),
+            },
+        ]
+
+        summaries = await episodic_service.get_community_summaries(
+            patient_id="patient-123",
+            limit=5,
+        )
+
+        assert len(summaries) == 1
+        assert isinstance(summaries[0], CommunitySummary)
+        assert summaries[0].community_id == "comm-1"
+        assert summaries[0].entity_count == 5
+        assert "Metformin" in summaries[0].key_entities
+
+    @pytest.mark.asyncio
+    async def test_get_community_summaries_empty_graceful(self, episodic_service, mock_graphiti):
+        """Returns [] when no community nodes exist."""
+        mock_graphiti.driver.execute_query.return_value = []
+
+        summaries = await episodic_service.get_community_summaries(
+            patient_id="patient-123",
+        )
+
+        assert summaries == []
+
+    @pytest.mark.asyncio
+    async def test_get_community_summaries_error_graceful(self, episodic_service, mock_graphiti):
+        """Returns [] on error, no exception propagated."""
+        mock_graphiti.driver.execute_query.side_effect = Exception("DB error")
+
+        summaries = await episodic_service.get_community_summaries(
+            patient_id="patient-123",
+        )
+
+        assert summaries == []
+
+    @pytest.mark.asyncio
+    async def test_get_community_summaries_no_driver(self, mock_graphiti):
+        """Returns [] when Graphiti has no driver attribute."""
+        del mock_graphiti.driver
+        service = EpisodicMemoryService(graphiti=mock_graphiti)
+
+        summaries = await service.get_community_summaries(patient_id="patient-123")
+
+        assert summaries == []
+
+    @pytest.mark.asyncio
+    async def test_context_includes_community_summaries(self, episodic_service, mock_graphiti):
+        """get_conversation_context() response includes community_summaries field."""
+        mock_graphiti.retrieve_episodes.return_value = []
+
+        with patch("application.services.episodic_memory_service.search") as mock_search:
+            mock_results = MagicMock()
+            mock_results.episodes = []
+            mock_results.nodes = []
+            mock_search.return_value = mock_results
+
+            context = await episodic_service.get_conversation_context(
+                patient_id="patient-123",
+                current_query="test query",
+            )
+
+        assert "community_summaries" in context
+        assert isinstance(context["community_summaries"], list)
+
+    @pytest.mark.asyncio
+    async def test_summary_to_dict(self, episodic_service):
+        """CommunitySummary converts to dict correctly."""
+        summary = CommunitySummary(
+            community_id="comm-1",
+            summary="Test summary",
+            entity_count=3,
+            key_entities=["A", "B"],
+            updated_at=datetime(2024, 1, 1),
+        )
+
+        result = episodic_service._summary_to_dict(summary)
+
+        assert result["community_id"] == "comm-1"
+        assert result["summary"] == "Test summary"
+        assert result["entity_count"] == 3
+        assert result["updated_at"] == "2024-01-01T00:00:00"
+
+
+# ============================================================
+# SPEC-2: Temporal Conflict Resolution
+# ============================================================
+
+class TestSpec2TemporalConflictResolution:
+    """Tests for SPEC-2: Temporal conflict resolution."""
+
+    @pytest.fixture
+    def crystallization_service(self, mock_neo4j_backend, mock_resolver, mock_event_bus):
+        config = CrystallizationConfig(mode=CrystallizationMode.BATCH)
+        return CrystallizationService(
+            neo4j_backend=mock_neo4j_backend,
+            entity_resolver=mock_resolver,
+            event_bus=mock_event_bus,
+            config=config,
+        )
+
+    @pytest.mark.asyncio
+    async def test_new_entity_gets_temporal_fields(self, crystallization_service, mock_neo4j_backend):
+        """PERCEPTION entity is created with valid_from, valid_until, is_current."""
+        result = await crystallization_service.crystallize_entities(
+            entities=[{"name": "Metformin", "entity_type": "Medication"}],
+            source="test",
+        )
+
+        assert result.entities_created == 1
+        # Verify add_entity was called with temporal properties
+        call_args = mock_neo4j_backend.add_entity.call_args
+        props = call_args[1]["properties"] if "properties" in call_args[1] else call_args[0][1]
+        assert "valid_from" in props
+        assert "is_current" in props
+        assert props["is_current"] is True
+        assert props["valid_until"] is None
+
+    @pytest.mark.asyncio
+    async def test_entity_with_invalid_at_not_current(self, crystallization_service, mock_neo4j_backend):
+        """Entity with invalid_at is marked is_current=False."""
+        result = await crystallization_service.crystallize_entities(
+            entities=[{
+                "name": "Metformin",
+                "entity_type": "Medication",
+                "invalid_at": "2024-06-01T00:00:00",
+            }],
+            source="test",
+        )
+
+        assert result.entities_created == 1
+        call_args = mock_neo4j_backend.add_entity.call_args
+        props = call_args[1]["properties"] if "properties" in call_args[1] else call_args[0][1]
+        assert props["is_current"] is False
+        assert props["valid_until"] == "2024-06-01T00:00:00"
+
+    @pytest.mark.asyncio
+    async def test_resolve_temporal_conflicts_invalidates_old(
+        self, crystallization_service, mock_neo4j_backend
+    ):
+        """Calling _resolve_temporal_conflicts marks older entities as not current."""
+        mock_neo4j_backend.query.return_value = {"rows": [{"invalidated": 2}]}
+
+        invalidated = await crystallization_service._resolve_temporal_conflicts(
+            entity_name="Metformin",
+            entity_type="Medication",
+            new_valid_from="2024-06-01T00:00:00",
+        )
+
+        assert invalidated == 2
+        mock_neo4j_backend.query.assert_called()
+
+    @pytest.mark.asyncio
+    async def test_resolve_temporal_conflicts_no_valid_from(self, crystallization_service):
+        """No invalidation when new_valid_from is None."""
+        invalidated = await crystallization_service._resolve_temporal_conflicts(
+            entity_name="Metformin",
+            entity_type="Medication",
+            new_valid_from=None,
+        )
+
+        assert invalidated == 0
+
+    @pytest.mark.asyncio
+    async def test_merge_propagates_temporal_fields(self, mock_neo4j_backend):
+        """merge_for_crystallization stores temporal fields from incoming data."""
+        resolver = EntityResolver(backend=mock_neo4j_backend)
+
+        mock_neo4j_backend.query.return_value = {
+            "rows": [{
+                "properties": {
+                    "name": "Metformin",
+                    "confidence": 0.8,
+                    "observation_count": 2,
+                }
+            }],
+        }
+
+        result = await resolver.merge_for_crystallization(
+            existing_id="entity_123",
+            new_data={
+                "confidence": 0.85,
+                "valid_at": "2024-06-01T00:00:00",
+                "invalid_at": "2024-12-01T00:00:00",
+            },
+        )
+
+        assert result.success
+        # Check that the update query was called with temporal fields
+        update_call = mock_neo4j_backend.query.call_args_list[-1]
+        update_params = update_call[0][1] if len(update_call[0]) > 1 else update_call[1].get("params", {})
+        # The updates dict should contain valid_from and valid_until
+        if isinstance(update_params, dict) and "updates" in update_params:
+            updates = update_params["updates"]
+            assert "valid_from" in updates
+            assert "valid_until" in updates
+            assert updates["is_current"] is False
+
+    @pytest.mark.asyncio
+    async def test_crystallize_with_valid_at_triggers_conflict_resolution(
+        self, crystallization_service, mock_neo4j_backend
+    ):
+        """Entity with valid_at triggers _resolve_temporal_conflicts."""
+        mock_neo4j_backend.query.return_value = {"rows": [{"invalidated": 0}]}
+
+        with patch.object(
+            crystallization_service, "_resolve_temporal_conflicts", new_callable=AsyncMock
+        ) as mock_resolve:
+            mock_resolve.return_value = 0
+
+            await crystallization_service.crystallize_entities(
+                entities=[{
+                    "name": "Metformin",
+                    "entity_type": "Medication",
+                    "valid_at": "2024-06-01T00:00:00",
+                }],
+                source="test",
+            )
+
+            mock_resolve.assert_called_once_with(
+                entity_name="Metformin",
+                entity_type="Medication",
+                new_valid_from="2024-06-01T00:00:00",
+            )
+
+
+# ============================================================
+# SPEC-3: Bound Search Results
+# ============================================================
+
+class TestSpec3BoundSearchResults:
+    """Tests for SPEC-3: Bound search results."""
+
+    @pytest.mark.asyncio
+    async def test_search_episodes_passes_num_results(self, episodic_service):
+        """search_episodes passes num_results to the search function."""
+        with patch("application.services.episodic_memory_service.search") as mock_search:
+            mock_results = MagicMock()
+            mock_results.episodes = []
+            mock_search.return_value = mock_results
+
+            await episodic_service.search_episodes(
+                patient_id="patient-123",
+                query="test",
+                limit=7,
+            )
+
+            call_kwargs = mock_search.call_args[1]
+            assert call_kwargs["num_results"] == 7
+
+    @pytest.mark.asyncio
+    async def test_get_related_entities_passes_num_results(self, episodic_service):
+        """get_related_entities passes num_results to the search function."""
+        with patch("application.services.episodic_memory_service.search") as mock_search:
+            mock_results = MagicMock()
+            mock_results.nodes = []
+            mock_search.return_value = mock_results
+
+            await episodic_service.get_related_entities(
+                patient_id="patient-123",
+                query="test",
+                limit=15,
+            )
+
+            call_kwargs = mock_search.call_args[1]
+            assert call_kwargs["num_results"] == 15
+
+    @pytest.mark.asyncio
+    async def test_entity_resolver_parameterized_limit(self, mock_neo4j_backend):
+        """EntityResolver._get_existing_entities uses parameterized LIMIT."""
+        resolver = EntityResolver(backend=mock_neo4j_backend)
+        mock_neo4j_backend.query.return_value = []
+
+        await resolver._get_existing_entities(
+            entity_type="Medication",
+            context={},
+            limit=50,
+        )
+
+        call_args = mock_neo4j_backend.query.call_args
+        query_str = call_args[0][0]
+        params = call_args[0][1] if len(call_args[0]) > 1 else {}
+        assert "$limit" in query_str
+        assert params.get("limit") == 50
+
+
+# ============================================================
+# SPEC-4: Memory Invalidation / Expiration
+# ============================================================
+
+class TestSpec4MemoryInvalidation:
+    """Tests for SPEC-4: Memory invalidation/expiration."""
+
+    @pytest.fixture
+    def invalidation_service(self, mock_neo4j_backend):
+        return MemoryInvalidationService(
+            neo4j_backend=mock_neo4j_backend,
+            config=InvalidationConfig(stale_threshold_days=90),
+        )
+
+    @pytest.mark.asyncio
+    async def test_invalidate_entity_sets_flags(self, invalidation_service, mock_neo4j_backend):
+        """After invalidation, entity has is_current=false and audit fields."""
+        mock_neo4j_backend.query.return_value = {
+            "rows": [{"id": "entity_123"}]
+        }
+
+        result = await invalidation_service.invalidate_entity(
+            entity_id="entity_123",
+            reason="discontinued_medication",
+        )
+
+        assert result.entities_invalidated == 1
+        assert "entity_123" in result.entity_ids
+        assert result.reason == "discontinued_medication"
+
+        # Verify query was called
+        call_args = mock_neo4j_backend.query.call_args
+        query_str = call_args[0][0]
+ assert "is_current = false" in query_str + assert "invalidated_at" in query_str + assert "invalidation_reason" in query_str + + @pytest.mark.asyncio + async def test_invalidate_entity_idempotent(self, invalidation_service, mock_neo4j_backend): + """Invalidating already-invalidated entity returns count 0, no error.""" + mock_neo4j_backend.query.return_value = {"rows": []} + + result = await invalidation_service.invalidate_entity( + entity_id="already_invalid", + reason="test", + ) + + assert result.entities_invalidated == 0 + + @pytest.mark.asyncio + async def test_invalidate_by_query_filters(self, invalidation_service, mock_neo4j_backend): + """Bulk invalidation with patient_id filter only targets matching entities.""" + mock_neo4j_backend.query.return_value = { + "rows": [{"id": "e1"}, {"id": "e2"}] + } + + result = await invalidation_service.invalidate_by_query( + patient_id="patient-123", + entity_type="Medication", + reason="bulk_test", + ) + + assert result.entities_invalidated == 2 + call_args = mock_neo4j_backend.query.call_args + query_str = call_args[0][0] + assert "patient_id" in query_str + assert "entity_type" in query_str + + @pytest.mark.asyncio + async def test_sweep_targets_perception_only(self, invalidation_service, mock_neo4j_backend): + """Stale sweep query only targets PERCEPTION layer entities.""" + mock_neo4j_backend.query.return_value = {"rows": [{"id": "stale_1", "name": "Old Entity"}]} + + result = await invalidation_service.sweep_stale_entities() + + assert result.entities_invalidated == 1 + call_args = mock_neo4j_backend.query.call_args + query_str = call_args[0][0] + assert "PERCEPTION" in query_str + + @pytest.mark.asyncio + async def test_sweep_respects_threshold(self, invalidation_service, mock_neo4j_backend): + """Sweep passes correct cutoff date based on threshold.""" + mock_neo4j_backend.query.return_value = {"rows": []} + + await invalidation_service.sweep_stale_entities() + + call_args = mock_neo4j_backend.query.call_args + 
params = call_args[0][1] if len(call_args[0]) > 1 else {} + cutoff_str = params.get("cutoff_date", "") + # Cutoff should be ~90 days ago + cutoff = datetime.fromisoformat(cutoff_str) + expected_cutoff = datetime.utcnow() - timedelta(days=90) + assert abs((cutoff - expected_cutoff).total_seconds()) < 60 # Within 1 minute + + @pytest.mark.asyncio + async def test_sweep_disabled_config(self, mock_neo4j_backend): + """stale_check_enabled=False -> sweep returns 0 immediately.""" + service = MemoryInvalidationService( + neo4j_backend=mock_neo4j_backend, + config=InvalidationConfig(stale_check_enabled=False), + ) + + result = await service.sweep_stale_entities() + + assert result.entities_invalidated == 0 + assert result.reason == "stale_sweep_disabled" + mock_neo4j_backend.query.assert_not_called() + + def test_get_invalidation_stats(self, invalidation_service): + """Stats dict has all expected keys.""" + stats = invalidation_service.get_invalidation_stats() + + assert "total_invalidated" in stats + assert "sweep_runs" in stats + assert "last_sweep" in stats + assert "stale_threshold_days" in stats + assert stats["stale_threshold_days"] == 90 + + @pytest.mark.asyncio + async def test_invalidate_entity_error_handling(self, invalidation_service, mock_neo4j_backend): + """Error during invalidation returns 0 count, no exception.""" + mock_neo4j_backend.query.side_effect = Exception("DB error") + + result = await invalidation_service.invalidate_entity( + entity_id="entity_123", + reason="test", + ) + + assert result.entities_invalidated == 0 + + +# ============================================================ +# SPEC-5: LLM Rate Limit Management +# ============================================================ + +class TestSpec5LLMRateLimitManagement: + """Tests for SPEC-5: LLM rate limit management.""" + + @pytest.mark.asyncio + async def test_retry_on_rate_limit(self, episodic_service, mock_graphiti): + """Rate limit errors trigger retry with backoff.""" + call_count = 0 + + 
async def side_effect(**kwargs): + nonlocal call_count + call_count += 1 + if call_count <= 2: + raise Exception("Error 429: Rate limit exceeded") + mock_result = MagicMock() + mock_result.episode = MagicMock(uuid="ep-retry") + mock_result.nodes = [] + mock_result.edges = [] + return mock_result + + mock_graphiti.add_episode.side_effect = side_effect + + with patch("application.services.episodic_memory_service.asyncio.sleep", new_callable=AsyncMock): + result = await episodic_service.store_turn_episode( + patient_id="patient-123", + session_id="session-abc", + user_message="Hello", + assistant_message="Hi!", + turn_number=1, + ) + + assert result.episode_id == "ep-retry" + assert episodic_service._rate_limit_stats["retries"] == 2 + + @pytest.mark.asyncio + async def test_retry_exhausted_raises(self, episodic_service, mock_graphiti): + """After max retries, error propagates.""" + mock_graphiti.add_episode.side_effect = Exception("Error 429: Rate limit") + + with patch("application.services.episodic_memory_service.asyncio.sleep", new_callable=AsyncMock): + with patch.dict(os.environ, {"GRAPHITI_LLM_MAX_RETRIES": "2"}): + with pytest.raises(Exception, match="429"): + await episodic_service.store_turn_episode( + patient_id="patient-123", + session_id="session-abc", + user_message="Hello", + assistant_message="Hi!", + turn_number=1, + ) + + assert episodic_service._rate_limit_stats["failures"] == 1 + + @pytest.mark.asyncio + async def test_non_rate_limit_error_no_retry(self, episodic_service, mock_graphiti): + """Non-rate-limit errors are not retried.""" + mock_graphiti.add_episode.side_effect = ValueError("Bad input") + + with pytest.raises(ValueError, match="Bad input"): + await episodic_service.store_turn_episode( + patient_id="patient-123", + session_id="session-abc", + user_message="Hello", + assistant_message="Hi!", + turn_number=1, + ) + + assert episodic_service._rate_limit_stats["retries"] == 0 + assert episodic_service._rate_limit_stats["failures"] == 0 + + 
@pytest.mark.asyncio + async def test_retry_disabled_no_retry(self, episodic_service, mock_graphiti): + """GRAPHITI_LLM_RETRY_ENABLED=false -> 429 error propagated immediately.""" + mock_graphiti.add_episode.side_effect = Exception("Error 429: Rate limit") + + with patch.dict(os.environ, {"GRAPHITI_LLM_RETRY_ENABLED": "false"}): + with pytest.raises(Exception, match="429"): + await episodic_service.store_turn_episode( + patient_id="patient-123", + session_id="session-abc", + user_message="Hello", + assistant_message="Hi!", + turn_number=1, + ) + + assert episodic_service._rate_limit_stats["retries"] == 0 + assert episodic_service._rate_limit_stats["failures"] == 1 + + def test_get_health_includes_rate_stats(self, episodic_service): + """get_health() returns dict with all expected keys.""" + health = episodic_service.get_health() + + assert "initialized" in health + assert "rate_limit_retries" in health + assert "rate_limit_failures" in health + assert "last_rate_limit" in health + assert "semaphore_limit" in health + assert health["rate_limit_retries"] == 0 + assert health["rate_limit_failures"] == 0 + + def test_semaphore_limit_env_var(self): + """SEMAPHORE_LIMIT env var is readable by get_health.""" + with patch.dict(os.environ, {"SEMAPHORE_LIMIT": "7"}): + service = EpisodicMemoryService(graphiti=MagicMock()) + health = service.get_health() + assert health["semaphore_limit"] == 7