Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ demo/workdir/.claude/
dist
.coverage
.evolve
.omc
.omx
.secrets
event.json
Expand Down
95 changes: 0 additions & 95 deletions altk_evolve/frontend/client/evolve_client.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import logging
from typing import Any

from altk_evolve.backend.base import BaseEntityBackend
from altk_evolve.config.evolve import EvolveConfig
from altk_evolve.llm.fact_extraction.fact_extraction import ExtractedFact, extract_facts_from_messages
from altk_evolve.schema.conflict_resolution import EntityUpdate
from altk_evolve.schema.core import Entity, Namespace, RecordedEntity
from altk_evolve.schema.exceptions import NamespaceAlreadyExistsException, NamespaceNotFoundException
Expand Down Expand Up @@ -245,96 +243,3 @@ def ensure_namespace(self, namespace_id: str) -> Namespace:
return self.create_namespace(namespace_id)
except NamespaceAlreadyExistsException:
return self.get_namespace_details(namespace_id)

def store_user_facts(
self,
namespace_id: str,
message: str,
user_id: str,
metadata: dict[str, Any] | None = None,
enable_conflict_resolution: bool = False,
) -> list[EntityUpdate]:
"""Extract facts from a user utterance and persist them as `fact` entities."""
message = (message or "").strip()
if not message:
return []

self.ensure_namespace(namespace_id)

base_metadata: dict[str, Any] = dict(metadata or {})
base_metadata["user_id"] = user_id

extracted = extract_facts_from_messages([{"role": "user", "content": message}])
entities: list[Entity] = []
for one in extracted:
if isinstance(one, ExtractedFact):
fact_metadata = dict(base_metadata)
fact_metadata["category"] = one.category
fact_metadata["key"] = one.key
fact_metadata["value"] = one.value
entities.append(Entity(type="fact", content=one.content, metadata=fact_metadata))
else:
entities.append(Entity(type="fact", content=str(one), metadata=dict(base_metadata)))

if not entities:
return []

return self.update_entities(
namespace_id=namespace_id,
entities=entities,
enable_conflict_resolution=enable_conflict_resolution,
)

def retrieve_user_facts(
self,
namespace_id: str,
user_id: str,
query: str | None = None,
limit: int = 5,
) -> dict[str, list[dict[str, Any]]]:
"""Retrieve categorized user facts for prompt/context usage."""
if limit <= 0 or not self.namespace_exists(namespace_id):
return {}

facts = self.search_entities(
namespace_id=namespace_id,
query=query,
filters={"type": "fact", "metadata.user_id": user_id},
limit=limit,
)
if query and not facts:
facts = self.search_entities(
namespace_id=namespace_id,
query=None,
filters={"type": "fact", "metadata.user_id": user_id},
limit=limit,
)
if not facts and user_id != "default":
facts = self.search_entities(
namespace_id=namespace_id,
query=query,
filters={"type": "fact", "metadata.user_id": "default"},
limit=limit,
)
if query and not facts:
facts = self.search_entities(
namespace_id=namespace_id,
query=None,
filters={"type": "fact", "metadata.user_id": "default"},
limit=limit,
)

categorized_preferences: dict[str, list[dict[str, Any]]] = {}
for fact in facts:
metadata = fact.metadata or {}
category = str(metadata.get("category") or "misc")
categorized_preferences.setdefault(category, []).append(
{
"id": fact.id,
"content": str(fact.content),
"key": metadata.get("key"),
"value": metadata.get("value"),
}
)

return categorized_preferences
117 changes: 104 additions & 13 deletions altk_evolve/frontend/mcp/mcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@
from altk_evolve.config.evolve import evolve_config
from altk_evolve.frontend.client.evolve_client import EvolveClient
from altk_evolve.frontend.api.routes import router as api_router
from altk_evolve.llm.fact_extraction.fact_extraction import (
ExtractedFact,
categorize_facts,
extract_facts_from_messages,
)
from altk_evolve.llm.guidelines.guidelines import generate_guidelines
from altk_evolve.schema.core import Entity, RecordedEntity
from altk_evolve.schema.exceptions import EvolveException, NamespaceNotFoundException
Expand Down Expand Up @@ -278,6 +283,10 @@ def get_guidelines(
return get_entities_logic(task, "guideline", user_id=user_id, namespace_id=namespace_id, session_id=session_id)


def _empty_store_user_facts_response(user_id: str) -> str:
return json.dumps({"user_id": user_id, "stored_count": 0, "updates": []})


@mcp.tool()
def store_user_facts(
user_id: str,
Expand All @@ -297,13 +306,44 @@ def store_user_facts(
}
)

updates = get_client().store_user_facts(
namespace_id=evolve_config.namespace_id,
message=message,
user_id=user_id,
metadata=metadata_dict,
enable_conflict_resolution=enable_conflict_resolution,
)
trimmed_message = (message or "").strip()
if not trimmed_message:
return _empty_store_user_facts_response(user_id)

resolved_ns = _resolve_namespace(None)

base_metadata: dict[str, Any] = dict(metadata_dict)
base_metadata["user_id"] = user_id

extracted = extract_facts_from_messages([{"role": "user", "content": trimmed_message}])
entities: list[Entity] = []
for one in extracted:
if isinstance(one, ExtractedFact):
fact_metadata = dict(base_metadata)
fact_metadata["category"] = one.category
fact_metadata["key"] = one.key
fact_metadata["value"] = one.value
entities.append(Entity(type="fact", content=one.content, metadata=fact_metadata))
else:
entities.append(Entity(type="fact", content=str(one), metadata=dict(base_metadata)))

if not entities:
return _empty_store_user_facts_response(user_id)

try:
updates = get_client().update_entities(
namespace_id=resolved_ns,
entities=entities,
enable_conflict_resolution=enable_conflict_resolution,
)
except NamespaceNotFoundException:
_evict_namespace(resolved_ns)
resolved_ns = _resolve_namespace(None)
updates = get_client().update_entities(
namespace_id=resolved_ns,
entities=entities,
enable_conflict_resolution=enable_conflict_resolution,
)

serialized_updates = [
{
Expand All @@ -325,15 +365,66 @@ def store_user_facts(
)


@mcp.tool()
def retrieve_user_facts(user_id: str, query: str | None = None, limit: int = 5) -> str:
"""Retrieve categorized user facts/preferences for a durable user identity."""
categories = get_client().retrieve_user_facts(
namespace_id=evolve_config.namespace_id,
user_id=user_id,
def _search_facts_with_fallback(
namespace_id: str,
user_id: str,
query: str | None,
limit: int,
) -> list[RecordedEntity]:
"""Fetch fact entities for a user with the legacy fallback chain.

Order: (1) user filter + query, (2) user filter without query, (3) default
user with query, (4) default user without query. The default-user fallback
is skipped when the caller is already ``"default"``.
"""
client = get_client()
facts = client.search_entities(
namespace_id=namespace_id,
query=query,
filters={"type": "fact", "metadata.user_id": user_id},
limit=limit,
)
if query and not facts:
facts = client.search_entities(
namespace_id=namespace_id,
query=None,
filters={"type": "fact", "metadata.user_id": user_id},
limit=limit,
)
if not facts and user_id != "default":
facts = client.search_entities(
namespace_id=namespace_id,
query=query,
filters={"type": "fact", "metadata.user_id": "default"},
limit=limit,
)
if query and not facts:
facts = client.search_entities(
namespace_id=namespace_id,
query=None,
filters={"type": "fact", "metadata.user_id": "default"},
limit=limit,
)
return facts


@mcp.tool()
def retrieve_user_facts(user_id: str, query: str | None = None, limit: int = 5) -> str:
"""Retrieve categorized user facts/preferences for a durable user identity."""
namespace_id = evolve_config.namespace_id

if limit <= 0 or not get_client().namespace_exists(namespace_id):
return json.dumps(
{
"user_id": user_id,
"query": query,
"matched_count": 0,
"categories": {},
}
)

facts = _search_facts_with_fallback(namespace_id, user_id, query, limit)
categories = categorize_facts(facts)
matched_count = sum(len(items) for items in categories.values())

return json.dumps(
Expand Down
24 changes: 24 additions & 0 deletions altk_evolve/llm/fact_extraction/fact_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from altk_evolve.config.llm import llm_settings
from altk_evolve.llm.fact_extraction.categorization import CategoryManager
from altk_evolve.schema.core import RecordedEntity
from altk_evolve.utils.utils import clean_llm_response


Expand Down Expand Up @@ -77,3 +78,26 @@ def extract_facts_from_messages(messages: list[dict], use_categorization: bool |
last_error = exc
continue
raise ValueError(f"Failed to parse extracted facts response: {last_error}")


def categorize_facts(facts: list[RecordedEntity]) -> dict[str, list[dict[str, Any]]]:
"""Group fact entities by their metadata category.

Pure helper with no client or CRUD coupling: takes a list of already-fetched
RecordedEntity facts and returns a dict keyed by ``metadata.category`` (or
``"misc"`` when absent), with each entry exposing the id, content, key, and
value of the fact.
"""
categorized: dict[str, list[dict[str, Any]]] = {}
for fact in facts:
metadata = fact.metadata or {}
category = str(metadata.get("category") or "misc")
categorized.setdefault(category, []).append(
{
"id": fact.id,
"content": str(fact.content),
"key": metadata.get("key"),
"value": metadata.get("value"),
}
)
return categorized
49 changes: 0 additions & 49 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,52 +196,3 @@ def delete_entity_by_id(self, namespace_id, entity_id):

monkeypatch.setattr(evolve_client.backend, "delete_entity_by_id", delete_entity_by_id.__get__(evolve_client.backend, BaseEntityBackend))
evolve_client.delete_entity_by_id(namespace_id="foobar", entity_id="1")


@pytest.mark.unit
@pytest.mark.parametrize("message", [None, "", " \t\n"])
def test_store_user_facts_skips_none_empty_or_whitespace(evolve_client: EvolveClient, monkeypatch, message):
def fail_ensure_namespace(namespace_id: str):
raise AssertionError("ensure_namespace should not be called for blank messages")

def fail_update_entities(namespace_id, entities, enable_conflict_resolution=True):
raise AssertionError("update_entities should not be called for blank messages")

def fail_extract(messages):
raise AssertionError("extract_facts_from_messages should not be called for blank messages")

monkeypatch.setattr(evolve_client, "ensure_namespace", fail_ensure_namespace)
monkeypatch.setattr(evolve_client, "update_entities", fail_update_entities)
monkeypatch.setattr("altk_evolve.frontend.client.evolve_client.extract_facts_from_messages", fail_extract)

result = evolve_client.store_user_facts(namespace_id="foobar", message=message, user_id="u1")

assert result == []


@pytest.mark.unit
def test_store_user_facts_uses_trimmed_message(evolve_client: EvolveClient, monkeypatch):
captured: dict = {"ensure_namespace_called": False}

def ensure_namespace(namespace_id: str):
captured["ensure_namespace_called"] = True
return Namespace(id=namespace_id, created_at=datetime.datetime.now(datetime.UTC))

def extract(messages):
captured["message_content"] = messages[0]["content"]
return ["trimmed fact"]

def update_entities(namespace_id, entities, enable_conflict_resolution=True):
captured["entity_content"] = entities[0].content if entities else None
return [EntityUpdate(id="1", type="fact", content="trimmed fact", event="ADD")]

monkeypatch.setattr(evolve_client, "ensure_namespace", ensure_namespace)
monkeypatch.setattr(evolve_client, "update_entities", update_entities)
monkeypatch.setattr("altk_evolve.frontend.client.evolve_client.extract_facts_from_messages", extract)

result = evolve_client.store_user_facts(namespace_id="foobar", message=" hello world \n", user_id="u1")

assert captured["ensure_namespace_called"] is True
assert captured["message_content"] == "hello world"
assert captured["entity_content"] == "trimmed fact"
assert result[0].event == "ADD"
Loading
Loading