diff --git a/README.md b/README.md index 0b47f49f..163655fd 100644 --- a/README.md +++ b/README.md @@ -249,7 +249,7 @@ client.publish_interaction( ) # Search profiles -profiles = client.search_profiles( +profiles = client.search_user_profiles( reflexio.SearchUserProfileRequest(query="deployment region preference") ) diff --git a/client_dist/README.md b/client_dist/README.md index 6d58548f..9fbabc07 100644 --- a/client_dist/README.md +++ b/client_dist/README.md @@ -98,7 +98,7 @@ print(response.success, response.message) ```python # Semantic search for profiles -results = client.search_profiles(user_id="user-123", query="password preferences") +results = client.search_user_profiles(user_id="user-123", query="password preferences") for profile in results.profiles: print(profile.profile_name, profile.profile_content) @@ -313,7 +313,7 @@ In async contexts (e.g., FastAPI), fire-and-forget uses the existing event loop. | Method | Description | |--------|-------------| -| `search_profiles()` | Semantic search for profiles | +| `search_user_profiles()` | Semantic search for profiles | | `get_profiles()` | Get profiles for a user | | `get_all_profiles()` | Get all profiles across users | | `delete_profile()` | Delete profiles by ID or search query | diff --git a/pyproject.toml b/pyproject.toml index 38f9497a..2068a92c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,10 @@ dependencies = [ "typer>=0.15.0", "rich>=13.0.0", "chromadb>=1.5.8", + # Cross-encoder reranker + local embedding providers — chromadb pulls + # ``sentence-transformers`` transitively, but we depend on it directly + # so the CrossEncoder/SentenceTransformer surface is guaranteed. + "sentence-transformers>=3.0", ] [project.optional-dependencies] @@ -79,6 +83,7 @@ dev = [ "python-semantic-release>=10.0.0", "build>=1.0.0", "twine>=6.0.0", + "polars>=1.40.1", ] docs = [ "mkdocs>=1.5.3", @@ -217,6 +222,18 @@ max-complexity = 20 quote-style = "double" indent-style = "space" +[tool.pyright] +include = ["reflexio", "tests"] +exclude = [ + "reflexio/integrations/langchain", + "tests/test_scripts", + "**/__pycache__", + "**/.venv", + "benchmark", + "notebooks", +] +reportMissingImports = "warning" + [tool.mutmut] paths_to_mutate = [ "reflexio/server/services/service_utils.py", diff --git a/pyrightconfig.json b/pyrightconfig.json index 9bc9b29d..e6c94d04 100644 --- a/pyrightconfig.json +++ b/pyrightconfig.json @@ -1,14 +1,19 @@ { - "include": ["reflexio"], + "include": ["reflexio", "tests"], "exclude": [ "reflexio/website", "reflexio/tests", "reflexio/data", "reflexio/public_docs", "**/__pycache__", "reflexio/reflexio_commons/tests", "reflexio/reflexio_client/tests", - "reflexio/scripts", "notebooks", "demo" + "reflexio/scripts", "notebooks", "demo", + "reflexio/integrations/langchain", + "tests/test_scripts", + "**/.venv", + "benchmark" ], "extraPaths": ["."], "pythonVersion": "3.14", "typeCheckingMode": "basic", - "reportMissingTypeStubs": false + "reportMissingTypeStubs": false, + "reportMissingImports": "warning" } diff --git a/reflexio/benchmarks/retrieval_latency/backends.py b/reflexio/benchmarks/retrieval_latency/backends.py index 350b3b33..7009c193 100644 --- a/reflexio/benchmarks/retrieval_latency/backends.py +++ b/reflexio/benchmarks/retrieval_latency/backends.py @@ -45,7 +45,7 @@ class BackendHandle: Attributes: name (str): Short backend identifier, e.g. ``"sqlite"``. - reflexio (Reflexio): Service-layer facade — call ``search_profiles`` + reflexio (Reflexio): Service-layer facade — call ``search_user_profiles`` etc. directly on this for the service layer benchmark. storage (BaseStorage): Underlying storage instance, needed for swapping ``_get_embedding`` during seeding and the timed loop. diff --git a/reflexio/benchmarks/retrieval_latency/bench.py b/reflexio/benchmarks/retrieval_latency/bench.py index c8e0cdc7..175a5cf6 100644 --- a/reflexio/benchmarks/retrieval_latency/bench.py +++ b/reflexio/benchmarks/retrieval_latency/bench.py @@ -177,7 +177,7 @@ def _service_call( """ match retrieval: case "profile": - reflexio.search_profiles(_build_profile_request(query_idx)) + reflexio.search_user_profiles(_build_profile_request(query_idx)) case "user_playbook": reflexio.search_user_playbooks(_build_user_playbook_request(query_idx)) case "agent_playbook": @@ -188,7 +188,7 @@ def _service_call( # Map retrieval type to (HTTP path, request builder) for the http layer. _HTTP_ROUTES: dict[RetrievalType, tuple[str, Callable[[int], Any]]] = { - "profile": ("/api/search_profiles", _build_profile_request), + "profile": ("/api/search_user_profiles", _build_profile_request), "user_playbook": ("/api/search_user_playbooks", _build_user_playbook_request), "agent_playbook": ("/api/search_agent_playbooks", _build_agent_playbook_request), "unified": ("/api/search", _build_unified_request), diff --git a/reflexio/cli/codex_auth.py b/reflexio/cli/codex_auth.py new file mode 100644 index 00000000..fb2f5dca --- /dev/null +++ b/reflexio/cli/codex_auth.py @@ -0,0 +1,503 @@ +"""Reflexio-native OAuth tokens for OpenAI Codex / ChatGPT subscription. + +This module owns reflexio's own OAuth tokens against ``auth.openai.com``, +independent of OpenClaw or any other CLI. Tokens are stored at +``~/.reflexio/auth/openai-codex.json`` and the refresh-token flow is built +into the loader so callers always see a fresh access token. + +Why a separate module: the token store is consumed by both the CLI +(``reflexio setup openai-codex``) and the runtime proxy (``codex_proxy.py`` +in the enterprise tree). Putting it in one place keeps the file shape and +refresh policy in sync. +""" + +from __future__ import annotations + +import base64 +import contextlib +import hashlib +import json +import logging +import secrets +import time +import urllib.error +import urllib.request +from dataclasses import dataclass +from http.server import BaseHTTPRequestHandler, HTTPServer +from pathlib import Path +from typing import Any +from urllib.parse import parse_qs, urlencode, urlparse + +logger = logging.getLogger(__name__) + +# OAuth client + endpoints used by the Codex CLI. Values verified by +# inspecting the JWT payload of an existing OpenClaw-issued token +# (`client_id`, `iss` claims) and the codex-rs source. +CODEX_CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann" +CODEX_AUTH_ISSUER = "https://auth.openai.com" +CODEX_AUTHORIZE_URL = f"{CODEX_AUTH_ISSUER}/oauth/authorize" +CODEX_TOKEN_URL = f"{CODEX_AUTH_ISSUER}/oauth/token" + +# Codex CLI binds its callback server to this port; OpenAI's OAuth client +# config has ``http://localhost:1455/auth/callback`` registered as a valid +# redirect URI, so we reuse it. +CODEX_CALLBACK_HOST = "localhost" +CODEX_CALLBACK_PORT = 1455 +CODEX_CALLBACK_PATH = "/auth/callback" +CODEX_REDIRECT_URI = ( + f"http://{CODEX_CALLBACK_HOST}:{CODEX_CALLBACK_PORT}{CODEX_CALLBACK_PATH}" +) + +CODEX_SCOPES = "openid profile email offline_access" + +# Refresh slightly before the access token actually expires so a slow +# downstream call doesn't cross the boundary mid-flight. +_REFRESH_LEAD_SECONDS = 60 + +REFLEXIO_AUTH_DIR = Path.home() / ".reflexio" / "auth" +REFLEXIO_CODEX_TOKENS_PATH = REFLEXIO_AUTH_DIR / "openai-codex.json" + + +@dataclass +class CodexTokens: + """Persisted Codex OAuth tokens. + + Attributes: + access_token (str): Bearer token used for ``api.openai.com`` and + ``chatgpt.com/backend-api/codex`` calls. + refresh_token (str): Long-lived token used to mint a new access + token at ``/oauth/token``. + account_id (str): ``ChatGPT-Account-ID`` header value (from the + JWT's ``chatgpt_account_id`` claim). + expires_at (int): Unix epoch seconds when ``access_token`` expires. + plan_type (str): Cached ``chatgpt_plan_type`` from the JWT (e.g. + ``"plus"``, ``"max-x20"``) for human-facing diagnostics. + email (str): User email from the JWT, surfaced in CLI status. + """ + + access_token: str + refresh_token: str + account_id: str + expires_at: int + plan_type: str + email: str + + def is_expired(self, lead_seconds: int = _REFRESH_LEAD_SECONDS) -> bool: + """Return True if the access token will expire within ``lead_seconds``. + + Args: + lead_seconds (int): Treat tokens with less than this much time + remaining as already expired. + + Returns: + bool: ``True`` if a refresh is needed. + """ + return self.expires_at - lead_seconds <= int(time.time()) + + +def _b64url(data: bytes) -> str: + """Base64url-encode without padding (PKCE-style).""" + return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii") + + +def _make_pkce_pair() -> tuple[str, str]: + """Generate a (code_verifier, code_challenge) PKCE pair. + + Uses a 32-byte random verifier; SHA-256 + base64url for the challenge. + + Returns: + tuple[str, str]: ``(verifier, challenge)``. + """ + verifier = _b64url(secrets.token_bytes(32)) + challenge = _b64url(hashlib.sha256(verifier.encode("ascii")).digest()) + return verifier, challenge + + +def _decode_jwt_payload(jwt: str) -> dict[str, Any]: + """Decode an unsigned JWT payload (no signature verification). + + Codex JWTs are issued by ``auth.openai.com`` with RS256; we don't have + the public key locally and don't need to: storing tokens we receive over + HTTPS from the issuer is sufficient. The payload is read for metadata + (account_id, plan_type, email, exp). + + Args: + jwt (str): A JWT in standard ``header.payload.signature`` form. + + Returns: + dict[str, Any]: The JSON-parsed payload. + + Raises: + ValueError: If the JWT is malformed. + """ + parts = jwt.split(".") + if len(parts) != 3: + raise ValueError("not a JWT (expected three dot-separated parts)") + payload_b64 = parts[1] + "=" * (-len(parts[1]) % 4) # restore padding + return json.loads(base64.urlsafe_b64decode(payload_b64)) + + +def _tokens_from_response(payload: dict[str, Any]) -> CodexTokens: + """Build a ``CodexTokens`` from an ``/oauth/token`` JSON response. + + Reads the access JWT to derive ``account_id``, ``plan_type``, ``email``, + and ``expires_at``. Falls back to ``expires_in`` from the response if the + JWT lacks an ``exp`` claim. + + Args: + payload (dict): Decoded JSON body from an OAuth token endpoint. + + Returns: + CodexTokens: Populated record. + + Raises: + ValueError: Required fields missing. + """ + access = payload.get("access_token") + refresh = payload.get("refresh_token") + if not access or not refresh: + raise ValueError( + f"OAuth response missing access_token / refresh_token: {payload}" + ) + claims = _decode_jwt_payload(access) + auth_claims = claims.get("https://api.openai.com/auth", {}) or {} + profile_claims = claims.get("https://api.openai.com/profile", {}) or {} + account_id = auth_claims.get("chatgpt_account_id", "") or "" + plan_type = auth_claims.get("chatgpt_plan_type", "unknown") or "unknown" + email = profile_claims.get("email", "") or "" + if (exp := claims.get("exp")) is not None: + expires_at = int(exp) + else: + expires_at = int(time.time()) + int(payload.get("expires_in", 0)) + return CodexTokens( + access_token=access, + refresh_token=refresh, + account_id=account_id, + expires_at=expires_at, + plan_type=str(plan_type), + email=str(email), + ) + + +def save_tokens(tokens: CodexTokens) -> Path: + """Persist tokens to ``~/.reflexio/auth/openai-codex.json``. + + Creates the parent directory with restrictive permissions on first write. + The token file itself is written with mode 0600 — bearer tokens shouldn't + be world-readable. + + Args: + tokens (CodexTokens): Tokens to persist. + + Returns: + Path: Where the file was written. + """ + REFLEXIO_AUTH_DIR.mkdir(parents=True, exist_ok=True) + # Filesystems without POSIX permissions (e.g., FAT) won't honour chmod; + # tolerate the failure rather than aborting the login. + with contextlib.suppress(OSError): + REFLEXIO_AUTH_DIR.chmod(0o700) + payload = { + "version": 1, + "access_token": tokens.access_token, + "refresh_token": tokens.refresh_token, + "account_id": tokens.account_id, + "expires_at": tokens.expires_at, + "plan_type": tokens.plan_type, + "email": tokens.email, + } + REFLEXIO_CODEX_TOKENS_PATH.write_text(json.dumps(payload, indent=2)) + with contextlib.suppress(OSError): + REFLEXIO_CODEX_TOKENS_PATH.chmod(0o600) + return REFLEXIO_CODEX_TOKENS_PATH + + +def load_tokens_raw() -> CodexTokens | None: + """Load tokens from disk without refreshing. + + Returns: + CodexTokens | None: Persisted tokens, or ``None`` if the file is + missing or malformed. + """ + if not REFLEXIO_CODEX_TOKENS_PATH.exists(): + return None + try: + data = json.loads(REFLEXIO_CODEX_TOKENS_PATH.read_text()) + return CodexTokens( + access_token=data["access_token"], + refresh_token=data["refresh_token"], + account_id=data.get("account_id", ""), + expires_at=int(data.get("expires_at", 0)), + plan_type=data.get("plan_type", "unknown"), + email=data.get("email", ""), + ) + except (KeyError, json.JSONDecodeError, ValueError) as e: + logger.warning("Bad reflexio codex tokens file: %s", e) + return None + + +def refresh_tokens(tokens: CodexTokens) -> CodexTokens: + """Exchange the refresh_token for a new (access, refresh) pair. + + POSTs to ``auth.openai.com/oauth/token`` with ``grant_type=refresh_token``. + The new tokens are persisted to disk before returning. + + Args: + tokens (CodexTokens): The current tokens; only ``refresh_token`` is read. + + Returns: + CodexTokens: A fresh, persisted token record. + + Raises: + urllib.error.HTTPError: If the token endpoint rejects the refresh + (e.g., refresh_token revoked — caller should prompt re-login). + """ + body = urlencode( + { + "grant_type": "refresh_token", + "refresh_token": tokens.refresh_token, + "client_id": CODEX_CLIENT_ID, + "scope": CODEX_SCOPES, + } + ).encode("utf-8") + req = urllib.request.Request( # noqa: S310 - fixed https URL + CODEX_TOKEN_URL, + data=body, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + method="POST", + ) + with urllib.request.urlopen(req, timeout=30) as resp: # noqa: S310 - fixed https URL + payload = json.loads(resp.read()) + new_tokens = _tokens_from_response(payload) + save_tokens(new_tokens) + logger.info( + "Refreshed OpenAI Codex tokens; new access expires at %d (plan=%s)", + new_tokens.expires_at, + new_tokens.plan_type, + ) + return new_tokens + + +def get_fresh_tokens() -> CodexTokens | None: + """Return tokens, refreshing on disk if the access token has expired. + + Returns: + CodexTokens | None: Fresh tokens, or ``None`` if no tokens are saved. + Caller should run ``reflexio setup openai-codex`` if ``None``. + """ + tokens = load_tokens_raw() + if tokens is None: + return None + if tokens.is_expired(): + try: + return refresh_tokens(tokens) + except urllib.error.HTTPError as e: + logger.warning( + "Refresh failed (HTTP %d); re-login required via " + "'reflexio setup openai-codex'", + e.code, + ) + return None + return tokens + + +# --------------------------------------------------------------------------- +# Authorization-code login flow (browser + PKCE + local callback) +# --------------------------------------------------------------------------- + + +class _CallbackHandler(BaseHTTPRequestHandler): + """One-shot HTTP handler that captures the OAuth callback. + + The handler stashes the parsed query parameters on the server instance + (which a stricter typer would model as a custom HTTPServer subclass); + the orchestrating function reads them back after ``handle_request``. + + Browsers expect a tidy success page; we serve a small HTML body so the + user knows the CLI took control. + """ + + # Silence default access logs; this is a 1-shot interactive flow. + def log_message( # noqa: ANN401, ARG002 — signature dictated by stdlib + self, + format: str, # noqa: A002, ARG002 + *args: Any, # noqa: ARG002 + ) -> None: + """No-op — suppress the default access log noise.""" + return + + def do_GET(self) -> None: # noqa: N802 - dictated by stdlib + """Capture the callback query and write a success page.""" + parsed = urlparse(self.path) + if parsed.path != CODEX_CALLBACK_PATH: + self.send_response(404) + self.end_headers() + return + query = parse_qs(parsed.query) + # Store on the server instance for the caller to read. + self.server._captured = { # type: ignore[attr-defined] + "code": (query.get("code") or [""])[0], + "state": (query.get("state") or [""])[0], + "error": (query.get("error") or [""])[0], + "error_description": (query.get("error_description") or [""])[0], + } + self.send_response(200) + self.send_header("Content-Type", "text/html; charset=utf-8") + self.end_headers() + self.wfile.write( + b"" + b"

Reflexio is now signed in.

" + b"

You can close this tab and return to the terminal.

" + b"" + ) + + +def _capture_oauth_callback(state: str, timeout_s: int) -> dict[str, str]: + """Run a one-shot HTTP server and return the OAuth callback query. + + Args: + state (str): The CSRF ``state`` value sent on the authorize call; + verified to match here. + timeout_s (int): Hard ceiling on how long to wait for the user to + complete the browser flow. + + Returns: + dict[str, str]: The captured query parameters + (``code``, ``state``, ``error``, ``error_description``). + + Raises: + TimeoutError: If the callback isn't received in time. + ValueError: If the callback's state doesn't match the request's. + """ + server = HTTPServer((CODEX_CALLBACK_HOST, CODEX_CALLBACK_PORT), _CallbackHandler) + server._captured = None # type: ignore[attr-defined] + server.timeout = timeout_s + server.handle_request() + captured: dict[str, str] | None = getattr(server, "_captured", None) + if captured is None: + raise TimeoutError( + f"OAuth callback not received within {timeout_s}s — open the URL " + "yourself and complete the sign-in?" + ) + if captured.get("state") != state: + raise ValueError( + "OAuth state mismatch — refusing to continue (possible CSRF)." + ) + if err := captured.get("error"): + raise ValueError( + f"OAuth provider returned error '{err}': {captured.get('error_description', '')}" + ) + return captured + + +def build_authorize_url(verifier: str, state: str) -> tuple[str, str]: + """Build the authorization URL for the browser step of the OAuth flow. + + Args: + verifier (str): PKCE code verifier (the random secret stored locally). + state (str): CSRF state value to round-trip through the redirect. + + Returns: + tuple[str, str]: ``(authorize_url, code_challenge)``. The challenge + is returned for callers that want to display it; the URL is what + actually goes in the browser. + """ + challenge = _b64url(hashlib.sha256(verifier.encode("ascii")).digest()) + qs = urlencode( + { + "client_id": CODEX_CLIENT_ID, + "response_type": "code", + "redirect_uri": CODEX_REDIRECT_URI, + "scope": CODEX_SCOPES, + "code_challenge": challenge, + "code_challenge_method": "S256", + "state": state, + } + ) + return f"{CODEX_AUTHORIZE_URL}?{qs}", challenge + + +def exchange_authorization_code(code: str, verifier: str) -> CodexTokens: + """Exchange an OAuth authorization code for tokens. + + Args: + code (str): The ``code`` query param the redirect delivered. + verifier (str): The PKCE code verifier (must be the one used when + building the authorize URL). + + Returns: + CodexTokens: The persisted token record. + + Raises: + urllib.error.HTTPError: If the token endpoint rejects the request. + """ + body = urlencode( + { + "grant_type": "authorization_code", + "code": code, + "redirect_uri": CODEX_REDIRECT_URI, + "client_id": CODEX_CLIENT_ID, + "code_verifier": verifier, + } + ).encode("utf-8") + req = urllib.request.Request( # noqa: S310 - fixed https URL + CODEX_TOKEN_URL, + data=body, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + method="POST", + ) + with urllib.request.urlopen(req, timeout=30) as resp: # noqa: S310 - fixed https URL + payload = json.loads(resp.read()) + return _tokens_from_response(payload) + + +def login_interactive( + *, + open_browser: bool = True, + timeout_s: int = 300, +) -> CodexTokens: + """Run the full PKCE OAuth flow against ``auth.openai.com``. + + Steps: + 1. Generate a fresh PKCE pair + CSRF state. + 2. Build the authorize URL and either open the user's browser or + print the URL for them to open manually. + 3. Bind a one-shot HTTP server on ``localhost:1455`` to catch the + callback. + 4. Exchange the returned auth code for tokens. + 5. Persist tokens to disk. + + Args: + open_browser (bool): When True (default), call ``webbrowser.open`` + on the authorize URL. When False, just print it. + timeout_s (int): Maximum wall time to wait for the callback before + failing. + + Returns: + CodexTokens: The persisted token record. + """ + verifier, _challenge = _make_pkce_pair() + state = _b64url(secrets.token_bytes(16)) + authorize_url, _ = build_authorize_url(verifier, state) + + if open_browser: + # Lazy import — webbrowser pulls in tkinter on some platforms. + import webbrowser + + opened = webbrowser.open(authorize_url, new=1) + if not opened: + print("Could not open browser automatically.") + print() + print("Open this URL to sign in to ChatGPT:") + print(f" {authorize_url}") + print() + print(f"Listening for callback on {CODEX_REDIRECT_URI} ...") + + captured = _capture_oauth_callback(state=state, timeout_s=timeout_s) + code = captured.get("code") or "" + if not code: + raise ValueError("OAuth callback returned no authorization code.") + + tokens = exchange_authorization_code(code, verifier) + save_tokens(tokens) + return tokens diff --git a/reflexio/cli/commands/profiles.py b/reflexio/cli/commands/profiles.py index be9f2476..59eaa3ba 100644 --- a/reflexio/cli/commands/profiles.py +++ b/reflexio/cli/commands/profiles.py @@ -91,7 +91,7 @@ def list_profiles( @app.command() @handle_errors -def search( +def search_user_profiles( ctx: typer.Context, query: Annotated[ str, @@ -110,7 +110,7 @@ def search( typer.Option("--threshold", help="Similarity threshold"), ] = None, ) -> None: - """Search profiles by semantic query. + """Search user profiles by semantic query. Args: ctx: Typer context with CliState in ctx.obj @@ -135,7 +135,7 @@ def search( if threshold is not None: kwargs["threshold"] = threshold - resp = client.search_profiles(**kwargs) + resp = client.search_user_profiles(**kwargs) profiles = resp.user_profiles or [] json_mode: bool = ctx.obj.json_mode diff --git a/reflexio/cli/commands/setup_cmd.py b/reflexio/cli/commands/setup_cmd.py index a1cd39b7..1dad9f72 100644 --- a/reflexio/cli/commands/setup_cmd.py +++ b/reflexio/cli/commands/setup_cmd.py @@ -26,6 +26,7 @@ class InstallLocation(Enum): CURRENT_PROJECT = "current_project" ALL_PROJECTS = "all_projects" + app = typer.Typer( help="Configure Reflexio: run 'init' for plain CLI setup, or one of " "the integration commands (openclaw, claude-code) to also install " @@ -425,7 +426,9 @@ def _install_openclaw_integration() -> bool: typer.echo("Plugin installed and registered") return True - typer.echo("Error: Plugin not loaded -- check 'openclaw plugins inspect reflexio-federated'") + typer.echo( + "Error: Plugin not loaded -- check 'openclaw plugins inspect reflexio-federated'" + ) return False @@ -659,15 +662,31 @@ def _merge_hook_config( # Session start hook (SessionStart) — checks/starts Reflexio server proactively session_start_hook_sh = handler_js_path.parent / "session_start_hook.sh" - _upsert_hook(hooks, "SessionStart", f"bash {shlex.quote(str(session_start_hook_sh))}") + _upsert_hook( + hooks, "SessionStart", f"bash {shlex.quote(str(session_start_hook_sh))}" + ) # Search hook (UserPromptSubmit) — injects Reflexio context before Claude responds search_hook_js = handler_js_path.parent / "search_hook.js" _upsert_hook(hooks, "UserPromptSubmit", f"node {shlex.quote(str(search_hook_js))}") - # Stop hook (expert mode) — publishes session transcript for extraction + # Stop hook (expert mode) — publishes session transcript for extraction. + # On non-expert (re)install, remove the hook if it was previously installed. if expert: _upsert_hook(hooks, "Stop", f"node {shlex.quote(str(handler_js_path))}") + else: + stop_hooks = hooks.get("Stop", []) + cleaned = [ + entry + for entry in stop_hooks + if not any( + "reflexio" in h.get("command", "") for h in entry.get("hooks", []) + ) + ] + if cleaned: + hooks["Stop"] = cleaned + elif "Stop" in hooks: + del hooks["Stop"] settings_path.parent.mkdir(parents=True, exist_ok=True) settings_path.write_text(json.dumps(settings, indent=2) + "\n") @@ -771,12 +790,16 @@ def _install_claude_code_integration( rules_dest.parent.mkdir(parents=True, exist_ok=True) shutil.copy2(rules_src, rules_dest) - # Expert mode: also install /reflexio-extract command + # Expert mode: also install /reflexio-extract command. + # Non-expert (re)install: remove expert-only artifacts if present. + cmd_dest_dir = claude_dir / "commands" / "reflexio-extract" if expert: cmd_src = integration_dir / "commands" / "reflexio-extract" / "SKILL.md" - cmd_dest = claude_dir / "commands" / "reflexio-extract" / "SKILL.md" - cmd_dest.parent.mkdir(parents=True, exist_ok=True) + cmd_dest = cmd_dest_dir / "SKILL.md" + cmd_dest_dir.mkdir(parents=True, exist_ok=True) shutil.copy2(cmd_src, cmd_dest) + elif cmd_dest_dir.exists(): + shutil.rmtree(cmd_dest_dir) # Configure hook handler_js = integration_dir / "hook" / "handler.js" @@ -845,9 +868,7 @@ def _remove_from_dir(base_dir: Path) -> None: typer.echo(f" Removed hook from: {settings_path}") -def _uninstall_claude_code( - project_dir: Path, *, global_install: bool = False -) -> None: +def _uninstall_claude_code(project_dir: Path, *, global_install: bool = False) -> None: """Remove the Reflexio integration from Claude Code. When ``--global`` or ``--project-dir`` is explicit, removes from that @@ -963,7 +984,9 @@ def claude_code_setup( target = ( Path.home() if global_install - else Path(project_dir) if project_dir is not None else Path.cwd() + else Path(project_dir) + if project_dir is not None + else Path.cwd() ) _uninstall_claude_code(target, global_install=global_install) return @@ -976,11 +999,7 @@ def claude_code_setup( location = InstallLocation.CURRENT_PROJECT else: location = _prompt_install_location() - target = ( - Path.home() - if location == InstallLocation.ALL_PROJECTS - else Path.cwd() - ) + target = Path.home() if location == InstallLocation.ALL_PROJECTS else Path.cwd() # Step 1: Load .env path from reflexio.cli.env_loader import load_reflexio_env @@ -1048,7 +1067,9 @@ def claude_code_setup( typer.echo("Note: User-level hooks fire for ALL Claude Code sessions.") typer.echo("") if location == InstallLocation.ALL_PROJECTS: - typer.echo("Next: Start any Claude Code session — Reflexio is active in all projects.") + typer.echo( + "Next: Start any Claude Code session — Reflexio is active in all projects." + ) else: typer.echo("Next: Start a Claude Code session in this project.") if is_remote: @@ -1057,3 +1078,119 @@ def claude_code_setup( typer.echo( "The skill will guide Claude to check and start the Reflexio server automatically." ) + + +@app.command("openai-codex") +def openai_codex_setup( + no_browser: Annotated[ + bool, + typer.Option( + "--no-browser", + help="Don't auto-open the browser; print the URL to copy/paste instead.", + ), + ] = False, + timeout: Annotated[ + int, + typer.Option( + "--timeout", + help="Seconds to wait for the OAuth callback before failing.", + ), + ] = 300, + show: Annotated[ + bool, + typer.Option( + "--show", + help="Print currently saved Codex token metadata and exit (no login).", + ), + ] = False, + logout: Annotated[ + bool, + typer.Option( + "--logout", + help="Delete the saved Codex token file and exit.", + ), + ] = False, +) -> None: + """Sign in to OpenAI via your ChatGPT subscription (Codex OAuth). + + Stores access + refresh tokens at ``~/.reflexio/auth/openai-codex.json``. + The codex proxy and any other reflexio component that needs OpenAI auth + reads from this file directly — no dependency on OpenClaw or any other + CLI. The proxy auto-refreshes the access token when it nears expiry. + + Run this once, then start the codex proxy with:: + + ./reflexio_ext/scripts/start_with_codex_proxy.sh + + Re-run this command if your subscription tier changes or the + refresh_token gets revoked (rare). + """ + # Imported here so plain `reflexio --help` doesn't require the OAuth + # module to load (slight startup speedup; mostly cosmetic). + from reflexio.cli.codex_auth import ( + REFLEXIO_CODEX_TOKENS_PATH, + get_fresh_tokens, + load_tokens_raw, + login_interactive, + ) + + if logout: + if REFLEXIO_CODEX_TOKENS_PATH.exists(): + REFLEXIO_CODEX_TOKENS_PATH.unlink() + typer.echo(f"Removed {REFLEXIO_CODEX_TOKENS_PATH}") + else: + typer.echo("No saved Codex tokens to remove.") + return + + if show: + tokens = load_tokens_raw() + if tokens is None: + typer.echo(f"No tokens at {REFLEXIO_CODEX_TOKENS_PATH}.") + typer.echo("Run `reflexio setup openai-codex` to sign in.") + raise typer.Exit(1) + typer.echo(f" path: {REFLEXIO_CODEX_TOKENS_PATH}") + typer.echo(f" email: {tokens.email}") + typer.echo(f" plan_type: {tokens.plan_type}") + typer.echo( + f" account_id ...{tokens.account_id[-8:]}" + if tokens.account_id + else " account_id (empty)" + ) + typer.echo(f" expires_at: {tokens.expires_at} (unix epoch)") + typer.echo(f" expired: {tokens.is_expired()}") + return + + typer.echo("Starting OpenAI Codex OAuth flow...") + try: + tokens = login_interactive( + open_browser=not no_browser, + timeout_s=timeout, + ) + except TimeoutError as e: + typer.echo(f"Timed out: {e}") + raise typer.Exit(1) from e + except ValueError as e: + typer.echo(f"Login failed: {e}") + raise typer.Exit(1) from e + + typer.echo("") + typer.echo("Sign-in successful.") + typer.echo(f" saved to: {REFLEXIO_CODEX_TOKENS_PATH}") + if tokens.email: + typer.echo(f" email: {tokens.email}") + typer.echo(f" plan_type: {tokens.plan_type}") + typer.echo("") + typer.echo( + "Verify the token resolves cleanly via the proxy's health endpoint:" + ) + typer.echo(" curl -s http://127.0.0.1:11435/health | jq") + typer.echo("") + typer.echo( + "If the saved plan_type doesn't match what you expect (e.g. shows " + "'plus' instead of 'max-x20'), wait a minute for OpenAI to propagate " + "the subscription change and re-run this command — the JWT is issued " + "at sign-in time." + ) + # Exercise the refresh path immediately so any clock skew between the + # JWT's `exp` claim and our local clock is caught now, not at first use. + _ = get_fresh_tokens() diff --git a/reflexio/cli/commands/shortcuts.py b/reflexio/cli/commands/shortcuts.py index 15481323..5bd09f90 100644 --- a/reflexio/cli/commands/shortcuts.py +++ b/reflexio/cli/commands/shortcuts.py @@ -250,7 +250,9 @@ def context( ) profiles = [] - resp = client.search_profiles(user_id=resolved_user_id, query=query, top_k=5) + resp = client.search_user_profiles( + user_id=resolved_user_id, query=query, top_k=5 + ) if resp.success: profiles = resp.user_profiles diff --git a/reflexio/cli/env_loader.py b/reflexio/cli/env_loader.py index 657bc60c..492c2b05 100644 --- a/reflexio/cli/env_loader.py +++ b/reflexio/cli/env_loader.py @@ -21,6 +21,13 @@ _USER_ENV_FILE = _USER_ENV_DIR / ".env" +# Path to the .env file that load_reflexio_env last resolved — None until +# load_reflexio_env runs for the first time. Exposed via get_loaded_env_path +# so the startup banner can show the operator exactly which dotenv was +# picked (./.env vs ~/.reflexio/.env vs auto-created). +_loaded_env_path: Path | None = None + + def get_env_path() -> Path: """Return the canonical path to the user-level .env file. @@ -30,6 +37,17 @@ def get_env_path() -> Path: return _USER_ENV_FILE +def get_loaded_env_path() -> Path | None: + """Return the .env path that the most recent ``load_reflexio_env`` call + resolved, or None if the loader hasn't run yet. + + Used by the startup banner so operators can see at a glance which + dotenv file was actually consumed (``./.env`` wins over + ``~/.reflexio/.env`` when both exist). + """ + return _loaded_env_path + + def set_env_var(env_path: Path, key: str, value: str) -> None: """Write or update an environment variable in a .env file. @@ -95,16 +113,22 @@ def load_reflexio_env( Returns: Path to the loaded .env file, or None if no .env was found/created. """ + global _loaded_env_path for env_path in _ENV_SEARCH_PATHS: if env_path.exists(): load_dotenv(dotenv_path=env_path) - _logger.debug("Loaded env from: %s", env_path.resolve()) + resolved = env_path.resolve() + _logger.debug("Loaded env from: %s", resolved) + _loaded_env_path = resolved # Auto-generate any missing secret keys into the existing .env _backfill_missing_keys(env_path, auto_generate_keys or []) return env_path # No .env found — auto-create from bundled template - return _create_default_env(package_data_module, auto_generate_keys or []) + created = _create_default_env(package_data_module, auto_generate_keys or []) + if created is not None: + _loaded_env_path = created.resolve() + return created def _backfill_missing_keys(env_path: Path, keys: list[str]) -> None: diff --git a/reflexio/cli/log_format.py b/reflexio/cli/log_format.py index 8a44ae57..86ac1d41 100644 --- a/reflexio/cli/log_format.py +++ b/reflexio/cli/log_format.py @@ -25,19 +25,17 @@ # ANSI codes for log-level severity highlighting in service output. # Keys are matched against the level token captured by `_LEVEL_RE`. _LEVEL_COLORS: dict[str, str] = { - "ERROR": "31", # red + "ERROR": "31", # red "CRITICAL": "1;31", # bold red - "WARNING": "33", # yellow - "WARN": "33", # yellow (Next.js / some loggers) + "WARNING": "33", # yellow + "WARN": "33", # yellow (Next.js / some loggers) } # Match a log-level token at the start of a line, optionally bracketed, # followed by a typical separator (":", whitespace, or " - "). Covers # uvicorn ("ERROR: msg"), stdlib logging ("[ERROR] msg"), and the # "ERROR - msg" style used by Next.js / some custom loggers. -_LEVEL_RE = re.compile( - r"^(?:\[)?(ERROR|CRITICAL|WARNING|WARN)(?:\])?(?::|\s+-\s+|\s+)" -) +_LEVEL_RE = re.compile(r"^(?:\[)?(ERROR|CRITICAL|WARNING|WARN)(?:\])?(?::|\s+-\s+|\s+)") # Canonical log file paths — stored in ~/.reflexio/logs/ (not the project directory) _LOG_DIR = str(Path.home() / ".reflexio" / "logs") @@ -162,6 +160,7 @@ def print_startup_banner( *, supabase_port: int | None = 54321, log_file: str = DEV_LOG_FILE, + config_paths: dict[str, str] | None = None, ) -> None: """Print a consolidated startup summary banner with service URLs. @@ -169,6 +168,10 @@ def print_startup_banner( ports: Mapping of service name to port number. supabase_port: Supabase port, or None if not running. log_file: Path to the log file. + config_paths: Optional mapping of config-label → path string (e.g. + ``{"env": "~/.reflexio/.env", "config": "~/.reflexio/configs/config_default.json"}``). + Renders as a "Config" section above the "Logs" line so operators + can see at a glance which files the server actually loaded. """ lines = [] width = 44 @@ -191,8 +194,24 @@ def print_startup_banner( status = colorize("ready", "32") lines.append(f"{label}{url:<26}{status}") + home = str(Path.home()) + + def _collapse_home(path: str) -> str: + # Collapse HOME to ~ for readability; absolute paths stay absolute + # so log scrapers and copy-paste still work when outside HOME. + return "~" + path[len(home) :] if path.startswith(home) else path + + if config_paths: + lines.append(f"{'-' * width}") + for label, path in config_paths.items(): + lines.append(f" {label:<11}{_collapse_home(str(path))}") + lines.append(f"{'-' * width}") - lines.append(f" Logs {log_file}") + # Logs section — surface both the general dev log and the LLM I/O log. + # LLM_IO_LOG_FILE is the one operators hit first when debugging prompt / + # tool-call issues; it's opaque without this pointer. + lines.append(f" Dev log {_collapse_home(log_file)}") + lines.append(f" LLM I/O {_collapse_home(LLM_IO_LOG_FILE)}") lines.append(f"{'=' * width}\n") # Print all at once to avoid interleaving diff --git a/reflexio/cli/run_services.py b/reflexio/cli/run_services.py index db164773..98ef1896 100644 --- a/reflexio/cli/run_services.py +++ b/reflexio/cli/run_services.py @@ -184,9 +184,7 @@ def execute(args: argparse.Namespace) -> None: if "docs" in only: if DOCS_DIR.is_dir(): - services.append( - build_nextjs_service("docs", ports, cwd=str(DOCS_DIR)) - ) + services.append(build_nextjs_service("docs", ports, cwd=str(DOCS_DIR))) elif docs_explicit: print( f"Cannot start docs: {DOCS_DIR} not found. " diff --git a/reflexio/client/client.py b/reflexio/client/client.py index a5619b10..568b597c 100644 --- a/reflexio/client/client.py +++ b/reflexio/client/client.py @@ -28,6 +28,7 @@ GetUserPlaybooksViewResponse, GetUserProfilesRequest, ProfileChangeLogViewResponse, + RerankUserProfilesRequest, SearchAgentPlaybookRequest, SearchAgentPlaybooksViewResponse, SearchInteractionRequest, @@ -36,6 +37,7 @@ SearchUserPlaybookRequest, SearchUserPlaybooksViewResponse, SearchUserProfileRequest, + StorageStatsResponse, UnifiedSearchRequest, UnifiedSearchViewResponse, UpdateAgentPlaybookRequest, @@ -494,7 +496,7 @@ def search_interactions( ) return SearchInteractionsViewResponse(**response) - def search_profiles( + def search_user_profiles( self, request: SearchUserProfileRequest | dict | None = None, *, @@ -547,10 +549,70 @@ def search_profiles( search_mode=search_mode, ) response = self._make_request( - "POST", "/api/search_profiles", json=req.model_dump() + "POST", "/api/search_user_profiles", json=req.model_dump() ) return SearchProfilesViewResponse(**response) + def rerank_user_profiles( + self, + request: RerankUserProfilesRequest | dict | None = None, + *, + user_id: str | None = None, + query: str | None = None, + profile_ids: list[str] | None = None, + top_k: int | None = None, + ) -> SearchProfilesViewResponse: + """Rerank a list of profile ids by query relevance using a cross-encoder. + + The server fetches each candidate's full content (filtered by + ``user_id``), scores ``(query, content)`` pairs with a CPU + cross-encoder, and returns the top_k profiles sorted by descending + score. Profile ids that don't exist for the user are silently dropped. + + Args: + request (Optional[RerankUserProfilesRequest]): The rerank request + object (alternative to kwargs) + user_id (Optional[str]): The user whose profiles to rerank + query (Optional[str]): The reranking query + profile_ids (Optional[list[str]]): Candidate profile ids to score + top_k (Optional[int]): Maximum profiles to return (default: 10) + + Returns: + SearchProfilesViewResponse: Reranked profiles, top_k entries + """ + req = self._build_request( + request, + RerankUserProfilesRequest, + user_id=user_id, + query=query, + profile_ids=profile_ids, + top_k=top_k, + ) + response = self._make_request( + "POST", "/api/rerank_user_profiles", json=req.model_dump() + ) + return SearchProfilesViewResponse(**response) + + def storage_stats( + self, + user_id: str, + ) -> StorageStatsResponse: + """Get a quick count of how many profiles/playbooks the user has. + + Returns counts and the modified-time range across every status — + useful for sizing ``top_k`` before retrieval. + + Args: + user_id (str): The user to inspect. + + Returns: + StorageStatsResponse: Counts and timestamp range for the user. + """ + response = self._make_request( + "GET", "/api/storage_stats", params={"user_id": user_id} + ) + return StorageStatsResponse(**response) + def search_user_playbooks( self, request: SearchUserPlaybookRequest | dict | None = None, diff --git a/reflexio/lib/_base.py b/reflexio/lib/_base.py index 8926d41c..548638eb 100644 --- a/reflexio/lib/_base.py +++ b/reflexio/lib/_base.py @@ -169,7 +169,9 @@ def _maybe_get_query_embedding( try: return storage._get_embedding(query, purpose="query") # type: ignore[reportAttributeAccessIssue] except Exception as e: - logger.warning("Failed to generate query embedding due to %s — falling back to FTS", e) + logger.warning( + "Failed to generate query embedding due to %s — falling back to FTS", e + ) return None def _reformulate_query( diff --git a/reflexio/lib/_profiles.py b/reflexio/lib/_profiles.py index c61587ee..fae9bd60 100644 --- a/reflexio/lib/_profiles.py +++ b/reflexio/lib/_profiles.py @@ -13,8 +13,12 @@ GetProfileStatisticsResponse, GetUserProfilesRequest, GetUserProfilesResponse, + RerankUserProfilesRequest, + RerankUserProfilesResponse, SearchUserProfileRequest, SearchUserProfileResponse, + StorageStatsRequest, + StorageStatsResponse, UpdateUserProfileRequest, UpdateUserProfileResponse, ) @@ -38,7 +42,7 @@ class ProfilesMixin(ReflexioBase): - def search_profiles( + def search_user_profiles( self, request: SearchUserProfileRequest | dict, status_filter: list[Status | None] | None = None, @@ -69,7 +73,7 @@ def search_profiles( request.query, request.search_mode ) logger.info( - "search_profiles: query=%r, search_mode=%s, embedding_generated=%s", + "search_user_profiles: query=%r, search_mode=%s, embedding_generated=%s", request.query, request.search_mode, query_embedding is not None, @@ -83,6 +87,109 @@ def search_profiles( msg=f"Found {len(profiles)} matching profile(s)", ) + def rerank_user_profiles( + self, + request: RerankUserProfilesRequest | dict, + ) -> RerankUserProfilesResponse: + """Rerank a list of profile ids by query relevance using a cross-encoder. + + Fetches each profile's full content (filtered by ``user_id``), scores + ``(query, content)`` pairs with ``cross-encoder/ms-marco-MiniLM-L-6-v2``, + and returns the top_k profiles sorted by descending score. Profile ids + that don't exist for the user are silently dropped. + + Args: + request (Union[RerankUserProfilesRequest, dict]): The rerank + request — must contain ``user_id``, ``query``, and + ``profile_ids``. + + Returns: + RerankUserProfilesResponse: Profiles sorted by descending + relevance score, capped at ``request.top_k``. + """ + if not self._is_storage_configured(): + return RerankUserProfilesResponse( + success=True, user_profiles=[], msg=STORAGE_NOT_CONFIGURED_MSG + ) + if isinstance(request, dict): + request = RerankUserProfilesRequest(**request) + if not request.profile_ids: + return RerankUserProfilesResponse( + success=True, user_profiles=[], msg="No profile_ids provided" + ) + + # Fetch every profile for the user — including PENDING and ARCHIVED — + # because callers may want to rerank historical context, not just + # the currently-published set. + all_profiles = self._get_storage().get_user_profile( + request.user_id, status_filter=[None, Status.PENDING, Status.ARCHIVED] + ) + wanted = set(request.profile_ids) + candidates = [p for p in all_profiles if p.profile_id in wanted] + dropped = len(request.profile_ids) - len(candidates) + + # Lazy import keeps test collection fast; the cross-encoder pulls in + # torch + sentence-transformers on first call. + from reflexio.server.llm.rerank import score_pairs + + scores = score_pairs(request.query, [p.content for p in candidates]) + ranked = sorted( + zip(candidates, scores, strict=True), + key=lambda pair: pair[1], + reverse=True, + ) + top = [profile for profile, _score in ranked[: request.top_k]] + msg = f"Reranked {len(candidates)} profile(s); dropped {dropped} unknown id(s)" + return RerankUserProfilesResponse(success=True, user_profiles=top, msg=msg) + + def storage_stats( + self, + request: StorageStatsRequest | dict, + ) -> StorageStatsResponse: + """Return lightweight metadata about a user's stored profiles + playbooks. + + Provides counts and the last-modified timestamp range across every + status, suitable for sizing ``top_k`` before retrieval. + + Args: + request (Union[StorageStatsRequest, dict]): The stats request — + must contain ``user_id``. + + Returns: + StorageStatsResponse: Counts and timestamp range for the user. + """ + if not self._is_storage_configured(): + return StorageStatsResponse( + success=True, + profile_count=0, + playbook_count=0, + msg=STORAGE_NOT_CONFIGURED_MSG, + ) + if isinstance(request, dict): + request = StorageStatsRequest(**request) + storage = self._get_storage() + # Walk every status — agent callers care about total surface area, + # not just CURRENT entries. + all_statuses: list[Status | None] = [None, Status.PENDING, Status.ARCHIVED] + profiles = storage.get_user_profile(request.user_id, status_filter=all_statuses) + oldest_ts: datetime | None = None + newest_ts: datetime | None = None + if profiles: + timestamps = [p.last_modified_timestamp for p in profiles] + oldest_ts = datetime.fromtimestamp(min(timestamps), tz=UTC) + newest_ts = datetime.fromtimestamp(max(timestamps), tz=UTC) + playbook_count = storage.count_user_playbooks( + user_id=request.user_id, status_filter=all_statuses + ) + return StorageStatsResponse( + success=True, + profile_count=len(profiles), + playbook_count=playbook_count, + oldest_profile_modified=oldest_ts, + newest_profile_modified=newest_ts, + msg=f"Found {len(profiles)} profile(s) and {playbook_count} playbook(s)", + ) + def get_profile_change_logs(self) -> ProfileChangeLogResponse: """Get profile change logs. diff --git a/reflexio/lib/_search.py b/reflexio/lib/_search.py index 3091341f..e506902b 100644 --- a/reflexio/lib/_search.py +++ b/reflexio/lib/_search.py @@ -132,9 +132,25 @@ def unified_search( if isinstance(request, dict): request = UnifiedSearchRequest(**request) + config = self.request_context.configurator.get_config() + + # Dispatch on Config.search_backend. Without this branch, the agentic + # SearchAgent (server/services/search/agentic_search_service.py) is + # implemented but unreachable from the public /api/search path — + # setting search_backend="agentic" was a no-op pre-fix. + if config and config.search_backend == "agentic": + from reflexio.server.services.search.agentic_search_service import ( + AgenticSearchService, + ) + + agentic_svc = AgenticSearchService( + llm_client=self.llm_client, + request_context=self.request_context, + ) + return agentic_svc.search(request) + from reflexio.server.services.unified_search_service import run_unified_search - config = self.request_context.configurator.get_config() config_llm_config = config.llm_config if config else None # Resolve pre_retrieval_model_name: config override → site var → auto-detect diff --git a/reflexio/models/api_schema/domain/entities.py b/reflexio/models/api_schema/domain/entities.py index 0330e06a..efc772d3 100644 --- a/reflexio/models/api_schema/domain/entities.py +++ b/reflexio/models/api_schema/domain/entities.py @@ -164,6 +164,9 @@ class UserProfile(BaseModel): extractor_names: list[str] | None = None expanded_terms: str | None = None embedding: EmbeddingVector = [] + source_span: str | None = None + notes: str | None = None + reader_angle: str | None = None # user playbook for agents @@ -185,6 +188,9 @@ class UserPlaybook(BaseModel): source_interaction_ids: list[int] = Field(default_factory=list) expanded_terms: str | None = None embedding: EmbeddingVector = [] + source_span: str | None = None + notes: str | None = None + reader_angle: str | None = None class ProfileChangeLog(BaseModel): diff --git a/reflexio/models/api_schema/retriever_schema.py b/reflexio/models/api_schema/retriever_schema.py index d5dec0a2..c6326f38 100644 --- a/reflexio/models/api_schema/retriever_schema.py +++ b/reflexio/models/api_schema/retriever_schema.py @@ -81,6 +81,84 @@ class SearchUserProfileResponse(BaseModel): msg: str | None = None +class RerankUserProfilesRequest(BaseModel): + """Cross-encoder rerank for a list of profile ids. + + Use after ``search_user_profiles`` (or any other source of candidate ids) + when initial results are noisy. The server fetches each candidate's full + content, scores ``(query, content)`` pairs with a CPU cross-encoder, and + returns the top_k profiles sorted by descending score. + + Args: + user_id (str): The user whose profiles to rerank. + query (str): The reranking query. + profile_ids (list[str]): Candidate profile ids; ids that don't belong + to ``user_id`` (or don't exist) are silently dropped. + top_k (int): Maximum number of profiles to return. Defaults to 10. + """ + + user_id: NonEmptyStr + query: NonEmptyStr + profile_ids: list[str] + top_k: int = Field(default=10, gt=0) + + +class RerankUserProfilesResponse(BaseModel): + """Response from :class:`RerankUserProfilesRequest`. + + Args: + success (bool): Whether the rerank call succeeded. + user_profiles (list[UserProfile]): Profiles sorted by descending + cross-encoder score, capped at ``top_k``. + msg (str, optional): Diagnostic message (e.g. how many ids were + silently dropped because they didn't resolve). + """ + + success: bool + user_profiles: list[UserProfile] + msg: str | None = None + + +class StorageStatsRequest(BaseModel): + """Request lightweight metadata about a user's stored profiles + playbooks. + + Useful before deciding ``top_k`` for retrieval — sized counts and + timestamp ranges let the agent pick a sensible cap rather than a fixed + constant. + + Args: + user_id (str): The user to inspect. + """ + + user_id: NonEmptyStr + + +class StorageStatsResponse(BaseModel): + """Response from :class:`StorageStatsRequest`. + + Args: + profile_count (int): Total number of profiles for the user across + all statuses. + playbook_count (int): Total number of user playbooks for the user + across all statuses. + oldest_profile_modified (datetime, optional): UTC timestamp of the + oldest profile's ``last_modified_timestamp``; None when the user + has no profiles. + newest_profile_modified (datetime, optional): UTC timestamp of the + newest profile's ``last_modified_timestamp``; None when the user + has no profiles. + success (bool): Whether the lookup succeeded. + msg (str, optional): Diagnostic message. + """ + + profile_count: int = Field(default=0, ge=0) + playbook_count: int = Field(default=0, ge=0) + oldest_profile_modified: datetime | None = None + newest_profile_modified: datetime | None = None + success: bool + msg: str | None = None + + class GetInteractionsRequest(BaseModel): user_id: NonEmptyStr start_time: datetime | None = None @@ -463,6 +541,7 @@ class UnifiedSearchRequest(BaseModel): user_id: str | None = None conversation_history: list[ConversationTurn] | None = None enable_reformulation: bool | None = False + enable_agent_answer: bool | None = False search_mode: SearchMode = SearchMode.HYBRID @@ -476,6 +555,8 @@ class UnifiedSearchResponse(BaseModel): user_playbooks (list[UserPlaybook]): Matching user playbooks reformulated_query (str, optional): The query used after reformulation (None if reformulation disabled) msg (str, optional): Additional message + agent_answer (str, optional): LLM-synthesised answer populated by the agentic backend; + None for classic backend. """ success: bool @@ -484,6 +565,7 @@ class UnifiedSearchResponse(BaseModel): user_playbooks: list[UserPlaybook] = [] reformulated_query: str | None = None msg: str | None = None + agent_answer: str | None = None # =============================== diff --git a/reflexio/models/api_schema/ui/enums.py b/reflexio/models/api_schema/ui/enums.py index 88a9ef16..e3e8a37f 100644 --- a/reflexio/models/api_schema/ui/enums.py +++ b/reflexio/models/api_schema/ui/enums.py @@ -1,52 +1,25 @@ -"""UI-facing enums for API response models. - -These mirror domain enum values but are independently owned by the UI layer. -Changes to domain enums do not automatically affect the API contract. +"""UI-layer enums — re-export domain enums to keep type identity shared. + +Previously this module declared duplicate StrEnum classes with the same +variants as the domain enums. That broke type identity for pyright — the +UI enum and the domain enum were seen as distinct types even though their +values matched. Re-exporting means ``reflexio.models.api_schema.ui.enums.UserActionType`` +and ``reflexio.models.api_schema.domain.enums.UserActionType`` are the same +class, and converter functions don't need casts. """ -from enum import Enum, StrEnum +from reflexio.models.api_schema.domain.enums import ( + PlaybookStatus, + ProfileTimeToLive, + RegularVsShadow, + Status, + UserActionType, +) __all__ = [ - "UserActionType", - "ProfileTimeToLive", "PlaybookStatus", - "Status", + "ProfileTimeToLive", "RegularVsShadow", + "Status", + "UserActionType", ] - - -class UserActionType(StrEnum): - CLICK = "click" - SCROLL = "scroll" - TYPE = "type" - NONE = "none" - - -class ProfileTimeToLive(StrEnum): - ONE_DAY = "one_day" - ONE_WEEK = "one_week" - ONE_MONTH = "one_month" - ONE_QUARTER = "one_quarter" - ONE_YEAR = "one_year" - INFINITY = "infinity" - - -class PlaybookStatus(StrEnum): - PENDING = "pending" - APPROVED = "approved" - REJECTED = "rejected" - - -class Status(str, Enum): # noqa: UP042 - CURRENT=None is not compatible with StrEnum - CURRENT = None - ARCHIVED = "archived" - PENDING = "pending" - ARCHIVE_IN_PROGRESS = "archive_in_progress" - - -class RegularVsShadow(StrEnum): - REGULAR_IS_BETTER = "regular_is_better" - REGULAR_IS_SLIGHTLY_BETTER = "regular_is_slightly_better" - SHADOW_IS_BETTER = "shadow_is_better" - SHADOW_IS_SLIGHTLY_BETTER = "shadow_is_slightly_better" - TIED = "tied" diff --git a/reflexio/models/config_schema.py b/reflexio/models/config_schema.py index 34bf634e..e0cb3650 100644 --- a/reflexio/models/config_schema.py +++ b/reflexio/models/config_schema.py @@ -2,7 +2,7 @@ from dataclasses import dataclass, field from enum import IntEnum, StrEnum -from typing import Any, Self +from typing import Any, Literal, Self from pydantic import BaseModel, Field, model_validator @@ -457,6 +457,11 @@ class Config(BaseModel): skip_should_run_check: bool = False # Enable storage-time document expansion for improved FTS recall enable_document_expansion: bool = False + # Pipeline selection — "classic" (single-shot LLM + RAG) or "agentic" + # (multi-reader + critic). Defaults keep existing behavior; flip to + # "agentic" to opt in once Phase 3/4 land. + extraction_backend: Literal["classic", "agentic"] = "classic" + search_backend: Literal["classic", "agentic"] = "classic" @model_validator(mode="before") @classmethod @@ -469,7 +474,12 @@ def _migrate_field_names(cls, data: Any) -> Any: """ data = _migrate_dict(data, _CONFIG_FIELD_MIGRATION) if isinstance(data, dict): - for key in ("batch_size", "batch_interval"): + for key in ( + "batch_size", + "batch_interval", + "extraction_backend", + "search_backend", + ): if key in data and data[key] is None: del data[key] return data diff --git a/reflexio/server/__init__.py b/reflexio/server/__init__.py index 88abf167..8bfaece1 100644 --- a/reflexio/server/__init__.py +++ b/reflexio/server/__init__.py @@ -2,6 +2,7 @@ import logging.handlers import os import sys +import time from pathlib import Path import colorlog @@ -65,6 +66,71 @@ def filter(self, record: logging.LogRecord) -> bool: return record.levelno == LLM_PROMPT_LEVEL +class _TZAwareFormatter(logging.Formatter): + """Formatter that appends the local UTC offset to every timestamp. + + Renders ``2026-04-24 10:20:51.238 -07:00 PDT`` (TZ abbreviation is + optional and only appended on systems with tzdata available) so + readers in any timezone can compute the instant unambiguously. + Offset comes from the local system zoneinfo via + ``time.strftime('%z')`` and is rewritten to ISO 8601 extended form + (``-0700`` → ``-07:00``); falls back to ``+00:00`` on systems + without a configured timezone. + """ + + default_time_format = "%Y-%m-%d %H:%M:%S" + default_msec_format = "%s.%03d" + + def formatTime(self, record: logging.LogRecord, datefmt: str | None = None) -> str: # noqa: ARG002, N802 + ct = time.localtime(record.created) + base = time.strftime(self.default_time_format, ct) + msecs = int(record.msecs) + # ISO 8601 extended form: "-0700" -> "-07:00" — the colon separator + # reads more clearly as a UTC offset to humans skimming logs. + raw_offset = time.strftime("%z", ct) or "+0000" + offset = ( + f"{raw_offset[:3]}:{raw_offset[3:]}" if len(raw_offset) >= 5 else raw_offset + ) + # Append the local TZ abbreviation (PDT / UTC / etc.) when available. + # Some minimal containers without tzdata return "" here; the offset + # alone stays machine-parseable regardless. + tz_name = time.strftime("%Z", ct) + if tz_name: + return f"{base}.{msecs:03d} {offset} {tz_name}" + return f"{base}.{msecs:03d} {offset}" + + +class _LLMIOFormatter(_TZAwareFormatter): + """Format LLM prompts/responses with delimiters and entry IDs.""" + + _HEADER = "═" * 64 + _FOOTER = "─" * 64 + + def format(self, record: logging.LogRecord) -> str: + timestamp = self.formatTime(record) + message = record.getMessage() + short_logger = record.name.rsplit(".", 1)[-1] + # Use structured extra attributes when available; fall back to parsing + entry_id = getattr(record, "entry_id", None) + label = getattr(record, "label", None) + entry_tag = f"[#{entry_id}]" if entry_id is not None else "" + if label is None: + label = message[:60] + header_line = ( + f"{entry_tag} [{timestamp}] {label}" + if entry_tag + else f"[{timestamp}] {label}" + ) + return ( + f"\n{self._HEADER}\n" + f"{header_line}\n" + f"Service: {short_logger}\n" + f"{self._HEADER}\n" + f"{message}\n" + f"{self._FOOTER}\n" + ) + + DEBUG_LOG_TO_CONSOLE = os.environ.get("DEBUG_LOG_TO_CONSOLE", "").strip().lower() root_logger = logging.getLogger() @@ -111,7 +177,7 @@ def filter(self, record: logging.LogRecord) -> bool: ) file_handler.setLevel(logging.DEBUG) file_handler.setFormatter( - logging.Formatter( + _TZAwareFormatter( "%(asctime)s %(correlation_tag)s%(name)s %(levelname)s %(message)s" ) ) @@ -120,36 +186,6 @@ def filter(self, record: logging.LogRecord) -> bool: root_logger.addHandler(file_handler) # LLM I/O log file — only LLM_PROMPT level, with structured delimiters - _HEADER = "═" * 64 - _FOOTER = "─" * 64 - - class _LLMIOFormatter(logging.Formatter): - """Format LLM prompts/responses with delimiters and entry IDs.""" - - def format(self, record: logging.LogRecord) -> str: - timestamp = self.formatTime(record) - message = record.getMessage() - short_logger = record.name.rsplit(".", 1)[-1] - # Use structured extra attributes when available; fall back to parsing - entry_id = getattr(record, "entry_id", None) - label = getattr(record, "label", None) - entry_tag = f"[#{entry_id}]" if entry_id is not None else "" - if label is None: - label = message[:60] - header_line = ( - f"{entry_tag} [{timestamp}] {label}" - if entry_tag - else f"[{timestamp}] {label}" - ) - return ( - f"\n{_HEADER}\n" - f"{header_line}\n" - f"Service: {short_logger}\n" - f"{_HEADER}\n" - f"{message}\n" - f"{_FOOTER}\n" - ) - llm_io_handler = logging.handlers.RotatingFileHandler( LLM_IO_LOG_FILE, maxBytes=10_000_000, backupCount=3, encoding="utf-8" ) diff --git a/reflexio/server/api.py b/reflexio/server/api.py index 93d72f4c..77280ece 100644 --- a/reflexio/server/api.py +++ b/reflexio/server/api.py @@ -35,6 +35,7 @@ GetUserProfilesRequest, ProfileChangeLogViewResponse, RequestDataView, + RerankUserProfilesRequest, SearchAgentPlaybookRequest, SearchAgentPlaybooksViewResponse, SearchInteractionRequest, @@ -45,6 +46,8 @@ SearchUserProfileRequest, SessionView, SetConfigResponse, + StorageStatsRequest, + StorageStatsResponse, UnifiedSearchRequest, UnifiedSearchViewResponse, UpdateAgentPlaybookRequest, @@ -428,12 +431,12 @@ def add_user_profile_endpoint( @core_router.post( - "/api/search_profiles", + "/api/search_user_profiles", response_model=SearchProfilesViewResponse, response_model_exclude_none=True, ) @limiter.limit("120/minute") # Rate limit for read operations -def search_profiles( +def search_user_profiles( request: Request, payload: SearchUserProfileRequest, org_id: str = Depends(default_get_org_id), @@ -446,6 +449,62 @@ def search_profiles( ) +@core_router.post( + "/api/rerank_user_profiles", + response_model=SearchProfilesViewResponse, + response_model_exclude_none=True, +) +@limiter.limit("120/minute") # Rate limit for read operations +def rerank_user_profiles( + request: Request, + payload: RerankUserProfilesRequest, + org_id: str = Depends(default_get_org_id), +) -> SearchProfilesViewResponse: + """Rerank a list of profile ids by query relevance using a cross-encoder. + + Args: + request (Request): The HTTP request object (for rate limiting) + payload (RerankUserProfilesRequest): The rerank request + org_id (str): Organization ID + + Returns: + SearchProfilesViewResponse: Reranked profiles, top_k entries. + """ + response = retriever_api.rerank_user_profiles(org_id=org_id, request=payload) + return SearchProfilesViewResponse( + success=response.success, + user_profiles=[to_profile_view(p) for p in response.user_profiles], + msg=response.msg, + ) + + +@core_router.get( + "/api/storage_stats", + response_model=StorageStatsResponse, + response_model_exclude_none=True, +) +@limiter.limit("120/minute") # Rate limit for read operations +def storage_stats( + request: Request, + user_id: str, + org_id: str = Depends(default_get_org_id), +) -> StorageStatsResponse: + """Return lightweight metadata about a user's profiles and playbooks. + + Args: + request (Request): The HTTP request object (for rate limiting) + user_id (str): Target user id, passed as a query parameter so this is + a cacheable, idempotent GET. + org_id (str): Organization ID + + Returns: + StorageStatsResponse: Counts and timestamp range for the user. + """ + return retriever_api.storage_stats( + org_id=org_id, request=StorageStatsRequest(user_id=user_id) + ) + + @core_router.post( "/api/search_interactions", response_model=SearchInteractionsViewResponse, diff --git a/reflexio/server/api_endpoints/retriever_api.py b/reflexio/server/api_endpoints/retriever_api.py index 04de6680..4400e59b 100644 --- a/reflexio/server/api_endpoints/retriever_api.py +++ b/reflexio/server/api_endpoints/retriever_api.py @@ -9,6 +9,8 @@ GetRequestsResponse, GetUserProfilesRequest, GetUserProfilesResponse, + RerankUserProfilesRequest, + RerankUserProfilesResponse, SearchAgentPlaybookRequest, SearchAgentPlaybookResponse, SearchInteractionRequest, @@ -17,6 +19,8 @@ SearchUserPlaybookResponse, SearchUserProfileRequest, SearchUserProfileResponse, + StorageStatsRequest, + StorageStatsResponse, UnifiedSearchRequest, UnifiedSearchResponse, ) @@ -51,7 +55,43 @@ def search_user_profiles( SearchUserProfileResponse: Response containing matching user profiles """ reflexio = get_reflexio(org_id=org_id) - return reflexio.search_profiles(request) + return reflexio.search_user_profiles(request) + + +def rerank_user_profiles( + org_id: str, + request: RerankUserProfilesRequest, +) -> RerankUserProfilesResponse: + """Rerank a list of profile ids by query relevance using a cross-encoder. + + Args: + org_id (str): Organization ID + request (RerankUserProfilesRequest): The rerank request containing + user_id, query, profile_ids and top_k. + + Returns: + RerankUserProfilesResponse: Profiles sorted by descending cross-encoder + score, capped at ``request.top_k``. + """ + reflexio = get_reflexio(org_id=org_id) + return reflexio.rerank_user_profiles(request) + + +def storage_stats( + org_id: str, + request: StorageStatsRequest, +) -> StorageStatsResponse: + """Return lightweight metadata about a user's stored profiles and playbooks. + + Args: + org_id (str): Organization ID + request (StorageStatsRequest): The stats request containing user_id. + + Returns: + StorageStatsResponse: Counts and timestamp range for the user. + """ + reflexio = get_reflexio(org_id=org_id) + return reflexio.storage_stats(request) def search_interactions( diff --git a/reflexio/server/llm/__init__.py b/reflexio/server/llm/__init__.py index 77e24684..89701ab1 100644 --- a/reflexio/server/llm/__init__.py +++ b/reflexio/server/llm/__init__.py @@ -9,6 +9,7 @@ LiteLLMClient, LiteLLMClientError, LiteLLMConfig, + ToolCallingChatResponse, create_litellm_client, ) from .model_defaults import ( @@ -22,6 +23,7 @@ "LiteLLMConfig", "LiteLLMClientError", "ModelRole", + "ToolCallingChatResponse", "create_litellm_client", "resolve_model_name", "validate_llm_availability", diff --git a/reflexio/server/llm/litellm_client.py b/reflexio/server/llm/litellm_client.py index 9822c458..6f7bbbff 100644 --- a/reflexio/server/llm/litellm_client.py +++ b/reflexio/server/llm/litellm_client.py @@ -41,14 +41,27 @@ from reflexio.server.llm.providers.local_embedding_provider import ( register_if_enabled as _register_local_embedder, ) +from reflexio.server.llm.providers.nomic_embedding_provider import ( + NomicEmbedder, +) +from reflexio.server.llm.providers.nomic_embedding_provider import ( + is_enabled as _nomic_embedder_enabled, +) +from reflexio.server.llm.providers.nomic_embedding_provider import ( + is_nomic_model as _is_nomic_model, +) +from reflexio.server.llm.providers.nomic_embedding_provider import ( + register_if_enabled as _register_nomic_embedder, +) # Suppress LiteLLM's verbose logging litellm.suppress_debug_info = True -# Opt-in registration of claude-smart's local providers. Both are -# no-ops unless the matching env var is set. Safe to call at import. +# Opt-in registration of claude-smart's local providers. All no-ops +# unless the matching env var is set. Safe to call at import. _register_claude_code() _register_local_embedder() +_register_nomic_embedder() _LOGGER = logging.getLogger(__name__) @@ -205,10 +218,43 @@ class LiteLLMConfig: api_key_config: APIKeyConfig | None = None +@dataclass +class ToolCallingChatResponse: + """Response from a chat call that was routed in tool-calling mode. + + Returned instead of ``str | BaseModel`` whenever the caller passes + ``tools=...`` to ``generate_chat_response``. Callers inspect + ``tool_calls`` to drive a tool loop; ``content`` is set on the + terminal (non-tool) turn. + + Args: + content: Text content from the model, or None when the model emitted tool calls. + tool_calls: List of tool call objects from the model, or None on the terminal turn. + finish_reason: The stop reason reported by the provider (e.g. "tool_calls", "stop"). + usage: Raw usage object from the LLM response (provider-dependent shape), or None. + cost_usd: Estimated cost in USD for this call via litellm price table, or None when + the provider is not in the table (local ONNX, claude-code CLI, etc.). + """ + + content: str | None + tool_calls: list[Any] | None + finish_reason: str | None + usage: Any | None = None + cost_usd: float | None = None + + class LiteLLMClientError(Exception): """Custom exception for LiteLLM client errors.""" +class StructuredOutputParseError(Exception): + """Raised when a structured-output LLM call returns content that cannot be parsed. + + Caught by the retry loop in ``_make_request`` so a malformed response + burns a retry attempt rather than silently returning unparsed content. + """ + + class LiteLLMClient: """ Unified LLM client using LiteLLM for multi-provider support. @@ -368,8 +414,8 @@ def generate_response( system_message: str | None = None, images: list[str | bytes | dict] | None = None, image_media_type: str | None = None, - **kwargs, - ) -> str | BaseModel: + **kwargs: Any, + ) -> str | BaseModel | ToolCallingChatResponse: """ Generate a response using the configured LLM. @@ -415,14 +461,25 @@ def generate_chat_response( self, messages: list[dict[str, Any]], system_message: str | None = None, - **kwargs, - ) -> str | BaseModel: + *, + tools: list[Any] | None = None, + tool_choice: str | dict[str, Any] | None = None, + model_role: ModelRole | None = None, + **kwargs: Any, + ) -> str | BaseModel | ToolCallingChatResponse: """ Generate a response from a list of chat messages. Args: messages: List of messages in chat format [{"role": "...", "content": "..."}]. system_message: Optional system message to prepend. + tools: Optional list of tool definitions for tool-calling mode. + When provided, the return type is ``ToolCallingChatResponse``. + tool_choice: Optional tool choice control ("auto", "none", "required", + or a dict specifying a particular tool). Forwarded to the provider. + model_role: Optional ``ModelRole`` to override the model selected for + this request. The role is resolved via ``resolve_model_name`` using + the client's ``api_key_config``. **kwargs: Additional parameters including: - response_format: Pydantic BaseModel class for structured output - parse_structured_output: Whether to parse structured output (default True) @@ -431,7 +488,8 @@ def generate_chat_response( Returns: Generated response content. Returns string for text responses, - or BaseModel instance for Pydantic model responses. + ``BaseModel`` instance for Pydantic model responses, or + ``ToolCallingChatResponse`` when ``tools`` is provided. Raises: LiteLLMClientError: If the API call fails after all retries, @@ -457,6 +515,14 @@ def generate_chat_response( else: final_messages.insert(0, {"role": "system", "content": system_message}) + # Forward tool-calling and model-role kwargs into _make_request + if tools is not None: + kwargs["tools"] = tools + if tool_choice is not None: + kwargs["tool_choice"] = tool_choice + if model_role is not None: + kwargs["model_role"] = model_role + return self._make_request(final_messages, **kwargs) def _resolve_default_embedding_model(self) -> str: @@ -507,6 +573,18 @@ def get_embedding( """ embedding_model = model or self._resolve_default_embedding_model() + # local/nomic-embed-* routes to the sentence-transformers Nomic + # provider (137M params, 768d Matryoshka-truncated to 512). Higher + # quality than the chromadb MiniLM fallback below; preferred when + # the dep is installed. + if _is_nomic_model(embedding_model) and _nomic_embedder_enabled(): + try: + return NomicEmbedder.get().embed([text])[0] + except Exception as e: + raise LiteLLMClientError( + f"Nomic embedding generation failed: {str(e)}" + ) from e + # local/* models route through the in-process ONNX embedder — no # network call, no litellm API, no tiktoken truncation (the embedder # applies its own token cap). @@ -569,7 +647,15 @@ def get_embeddings( embedding_model = model or self._resolve_default_embedding_model() - # See matching short-circuit in get_embedding above. + # See matching short-circuits in get_embedding above. + if _is_nomic_model(embedding_model) and _nomic_embedder_enabled(): + try: + return NomicEmbedder.get().embed(list(texts)) + except Exception as e: + raise LiteLLMClientError( + f"Nomic batch embedding generation failed: {str(e)}" + ) from e + if embedding_model.startswith("local/") and _local_embedder_enabled(): try: return LocalEmbedder.get().embed(list(texts)) @@ -625,7 +711,24 @@ def _build_completion_params( except (TypeError, ValueError): max_retries = max(1, int(self.config.max_retries)) + # Pop tool-calling kwargs before the final params.update(kwargs) so they + # don't leak into the params dict twice. + tools = kwargs.pop("tools", None) + tool_choice = kwargs.pop("tool_choice", None) + model_role: ModelRole | None = kwargs.pop("model_role", None) + actual_model = kwargs.pop("model", self.config.model) + + # model_role takes priority over the default model but falls through + # to the custom_endpoint override below (highest priority). + if model_role is not None: + actual_model = resolve_model_name( + role=model_role, + site_var_value=None, + config_override=None, + api_key_config=self.config.api_key_config, + ) + ce = ( self.config.api_key_config.custom_endpoint if self.config.api_key_config @@ -670,6 +773,10 @@ def _build_completion_params( params["top_p"] = self.config.top_p if response_format: params["response_format"] = response_format + if tools is not None: + params["tools"] = tools + if tool_choice is not None: + params["tool_choice"] = tool_choice if actual_model != self.config.model: api_key, api_base, api_version = self._resolve_api_key(actual_model) @@ -700,8 +807,29 @@ def _build_completion_params( return params, response_format, parse_structured_output, max_retries + def _compute_cost_usd(self, response: Any, model: str | None) -> float | None: + """Compute call cost in USD via the litellm price table. + + Falls back to None when the provider is not mapped (local ONNX, + claude-code CLI, etc.) rather than failing the request. + + Args: + response: Raw LLM response object. + model: Fully-qualified model name used for the call. + + Returns: + float | None: Cost in USD, or None when unavailable. + """ + try: + import litellm + + cost = litellm.completion_cost(completion_response=response, model=model) + return float(cost) if cost else None + except Exception: + return None + def _log_token_usage(self, params: dict[str, Any], response: Any) -> None: - """Log token usage with cache statistics from an LLM response. + """Log token usage with cache statistics and cost from an LLM response. Args: params: Request parameters (for model name) @@ -724,13 +852,17 @@ def _log_token_usage(self, params: dict[str, Any], response: Any) -> None: f", cache_write: {cache_creation or 0}, cache_read: {cache_read or 0}" ) + cost = self._compute_cost_usd(response, params.get("model")) + cost_suffix = f", cost: ${cost:.6f}" if cost is not None else "" + self.logger.info( - "Token usage - model: %s, input: %s, output: %s, total: %s%s", + "Token usage - model: %s, input: %s, output: %s, total: %s%s%s", params.get("model"), usage.prompt_tokens, usage.completion_tokens, usage.total_tokens, cache_info, + cost_suffix, ) def _handle_retry_or_raise( @@ -794,7 +926,7 @@ def _handle_retry_or_raise( def _make_request( self, messages: list[dict[str, Any]], **kwargs: Any - ) -> str | BaseModel: + ) -> str | BaseModel | ToolCallingChatResponse: """ Make a request to the LLM with retry logic. @@ -803,7 +935,8 @@ def _make_request( **kwargs: Additional parameters. Returns: - Response content as string or BaseModel instance. + Response content as string, BaseModel instance, or + ToolCallingChatResponse when the request was in tool-calling mode. Raises: LiteLLMClientError: If the request fails after all retries. @@ -825,7 +958,8 @@ def _make_request( ) try: response = litellm.completion(**params) - content = response.choices[0].message.content # type: ignore[reportAttributeAccessIssue] + message = response.choices[0].message # type: ignore[reportAttributeAccessIssue] + content = message.content elapsed_seconds = time.perf_counter() - request_start self._log_token_usage(params, response) @@ -841,6 +975,19 @@ def _make_request( True, ) + # Tool-calling path: return a structured response instead of + # going through _maybe_parse_structured_output. + if "tools" in params: + raw_usage = getattr(response, "usage", None) + call_cost = self._compute_cost_usd(response, params.get("model")) + return ToolCallingChatResponse( + content=content, + tool_calls=getattr(message, "tool_calls", None), + finish_reason=response.choices[0].finish_reason, # type: ignore[reportAttributeAccessIssue] + usage=raw_usage, + cost_usd=call_cost, + ) + return self._maybe_parse_structured_output( content, # type: ignore[reportArgumentType] response_format, @@ -1056,8 +1203,14 @@ def _maybe_parse_structured_output( parsed = json.loads(sanitized) return response_format.model_validate(parsed) except Exception as e: - self.logger.warning("Failed to parse structured output: %s", e) - return content + model = self.config.model + snippet = ( + content[:200] if isinstance(content, str) else repr(content)[:200] + ) + raise StructuredOutputParseError( + f"Structured output parse failed for model={model!r}: {e}. " + f"Content snippet: {snippet!r}" + ) from e def _extract_json_from_string(self, content: str) -> str: """ diff --git a/reflexio/server/llm/model_defaults.py b/reflexio/server/llm/model_defaults.py index 7cb6fe3b..0584888a 100644 --- a/reflexio/server/llm/model_defaults.py +++ b/reflexio/server/llm/model_defaults.py @@ -151,6 +151,8 @@ class ProviderDefaults: should_run: Model for lightweight "should run extraction" checks, or None. pre_retrieval: Model for pre-retrieval query reformulation, or None. embedding: Model for embedding generation, or None. + extraction_agent: Sonnet-tier model for the agentic-v2 extraction loop, or None. + search_agent: Sonnet-tier model for the agentic-v2 search loop, or None. """ generation: str | None @@ -158,6 +160,8 @@ class ProviderDefaults: should_run: str | None pre_retrieval: str | None embedding: str | None + extraction_agent: str | None = None + search_agent: str | None = None _PROVIDER_DEFAULTS: dict[str, ProviderDefaults] = { @@ -171,6 +175,8 @@ class ProviderDefaults: should_run="claude-code/default", pre_retrieval="claude-code/default", embedding=None, + extraction_agent="claude-code/default", + search_agent="claude-code/default", ), # local is an embedding-only provider that routes through an # in-process ONNX model (chromadb's all-MiniLM-L6-v2). Generation @@ -188,6 +194,17 @@ class ProviderDefaults: should_run="gpt-5-nano", pre_retrieval="gpt-5-nano", embedding="text-embedding-3-small", + extraction_agent="gpt-5-mini", + # search_agent uses gpt-5.5: the multi-step orchestration (pattern + # dispatch, mandatory rehydration in E/G, narration) was being + # ignored by gpt-5-mini, which defaulted to single-search-finish for + # ~90% of questions. minimax/MiniMax-M2.7 doesn't reliably honor + # response_format for the multi-stage fallback path, blocking that + # cheaper option (multi-stage infrastructure stays in place for + # future non-tool-calling models). gpt-5.5 supports native tool + # calling and reliably follows multi-step recipes; per-question + # cost increase is small compared to the answer LLM. + search_agent="gpt-5.5", ), "anthropic": ProviderDefaults( generation="claude-sonnet-4-6", @@ -195,6 +212,8 @@ class ProviderDefaults: should_run="claude-haiku-4-5-20251001", pre_retrieval="claude-haiku-4-5-20251001", embedding=None, + extraction_agent="claude-sonnet-4-6", + search_agent="claude-sonnet-4-6", ), "gemini": ProviderDefaults( generation="gemini/gemini-3-flash-preview", @@ -273,6 +292,10 @@ class ModelRole(StrEnum): SHOULD_RUN = "should_run" PRE_RETRIEVAL = "pre_retrieval" EMBEDDING = "embedding" + # Agentic-v2 single-loop roles — Sonnet-tier agents that drive the + # extraction and search tool loops. + EXTRACTION_AGENT = "extraction_agent" + SEARCH_AGENT = "search_agent" def _auto_detect_model( diff --git a/reflexio/server/llm/providers/nomic_embedding_provider.py b/reflexio/server/llm/providers/nomic_embedding_provider.py new file mode 100644 index 00000000..8ea15898 --- /dev/null +++ b/reflexio/server/llm/providers/nomic_embedding_provider.py @@ -0,0 +1,234 @@ +"""Local in-process embedder using ``nomic-ai/nomic-embed-text-v1.5``. + +A higher-quality alternative to the chromadb-bundled MiniLM-L6-v2: 137M +parameters, 768-dim native, supports Matryoshka representation (64–768 +dimensions without retraining), 8192-token context, Apache-2.0 licensed. +Performs comparably to OpenAI's ``text-embedding-3-small`` on MTEB +retrieval at a fraction of the latency cost when run locally on CPU or +Apple Silicon. + +Activation +---------- + +- Set ``CLAUDE_SMART_USE_LOCAL_EMBEDDING=1`` in the process environment. +- Pass model name ``local/nomic-embed-v1.5`` (or ``local/nomic-embed-text-v1.5``) + to :func:`LiteLLMClient.get_embedding`/``get_embeddings``. +- Requires the ``sentence-transformers`` pip dependency. + +Storage compatibility +--------------------- + +Reflexio's vec0 tables expect 512-dim vectors (``EMBEDDING_DIMENSIONS``). +Nomic's native 768 dim is reduced via Matryoshka — slice the first 512 +floats, then re-normalize to unit length so cosine similarity remains +comparable. Quality on retrieval tasks at 512 dim is ~95% of the full +768 (per Nomic's own evaluation). +""" + +from __future__ import annotations + +import importlib.util +import logging +import math +import os +import threading +from typing import Any + +_LOGGER = logging.getLogger(__name__) + +_ENV_ENABLE = "CLAUDE_SMART_USE_LOCAL_EMBEDDING" +_MODEL_KEYS = {"local/nomic-embed-v1.5", "local/nomic-embed-text-v1.5"} +_HF_MODEL_NAME = "nomic-ai/nomic-embed-text-v1.5" + +# Reflexio's vec0 schema dim. Nomic v1.5 outputs 768 natively; we slice +# to 512 (Matryoshka) and re-normalize. +_TARGET_DIM = 512 +# Nomic v1.5 was trained with task-prefixed inputs; "search_document" +# vs "search_query" prefixes give better asymmetric retrieval. Reflexio's +# storage layer already passes a "search_document: " / "search_query: " +# prefix when calling _get_embedding(purpose=...), so we don't add another +# prefix here — the input arrives correctly tagged. +# The model has a 8192 token context window; we still cap chars +# defensively to avoid pathological multi-MB inputs. +_MAX_CHARS = 32_000 + + +class NomicEmbedderError(RuntimeError): + """Raised when the Nomic embedder is requested but its deps are missing.""" + + +class NomicEmbedder: + """Lazily-loaded singleton wrapping a sentence-transformers model. + + Loading the underlying ``nomic-embed-text-v1.5`` model takes ~5–10 s on + first call (downloads ~550 MB on cold start, then cached under + ``~/.cache/huggingface/``). After that, embedding latency on CPU is + ~30–60 ms per single text and ~200 ms per batch of 32 (Apple M-series). + """ + + _instance: NomicEmbedder | None = None + _lock = threading.Lock() + + def __init__(self) -> None: + self._model: Any | None = None + self._model_lock = threading.Lock() + + @classmethod + def get(cls) -> NomicEmbedder: + """Return the process-wide singleton, constructing it on first use.""" + if cls._instance is None: + with cls._lock: + if cls._instance is None: + cls._instance = cls() + return cls._instance + + def _load(self) -> Any: + """Lazy-import sentence-transformers and load the Nomic model.""" + if self._model is not None: + return self._model + with self._model_lock: + if self._model is not None: + return self._model + try: + from sentence_transformers import ( + SentenceTransformer, # type: ignore[import-not-found] + ) + except ImportError as exc: + raise NomicEmbedderError( + "sentence-transformers is required for the Nomic local " + "embedder. Install with `uv add sentence-transformers`." + ) from exc + _LOGGER.info( + "Loading Nomic embedding model %s — first call may download " + "~550 MB to ~/.cache/huggingface/", + _HF_MODEL_NAME, + ) + # Force CPU device — MPS init has been observed to hang on some + # Apple Silicon + macOS combos for several minutes during model + # load. CPU is fast enough for our use case (137M params) and + # behaves predictably. Set NOMIC_EMBED_DEVICE=mps|cuda|cpu to + # override. + device = os.environ.get("NOMIC_EMBED_DEVICE", "cpu") + self._model = SentenceTransformer( + _HF_MODEL_NAME, + trust_remote_code=True, # Nomic v1.5 ships custom code + device=device, + ) + _LOGGER.info( + "Nomic embedder ready (model=%s, target_dim=%d, native_dim=%d)", + _HF_MODEL_NAME, + _TARGET_DIM, + self._model.get_sentence_embedding_dimension(), + ) + return self._model + + def embed(self, texts: list[str]) -> list[list[float]]: + """Embed a batch of texts, returning ``_TARGET_DIM``-sized unit vectors. + + Args: + texts: Inputs to encode. Each is char-truncated to ``_MAX_CHARS`` + as a defensive cap; Nomic itself supports 8192 tokens. + + Returns: + list[list[float]]: One vector per input, each exactly + ``_TARGET_DIM`` (512) floats and L2-normalised so cosine + similarity equals dot product. + """ + model = self._load() + safe = [(t or "")[:_MAX_CHARS] for t in texts] + # show_progress_bar=False so server logs stay clean during ingest + # batches. convert_to_numpy=True returns a numpy ndarray; we slice + # and renormalise per-row before converting to plain Python lists. + raw = model.encode(safe, show_progress_bar=False, convert_to_numpy=True) + return [_truncate_and_renormalise(vec.tolist()) for vec in raw] + + +def _truncate_and_renormalise(vec: list[float]) -> list[float]: + """Slice to ``_TARGET_DIM`` and L2-renormalise for valid Matryoshka use. + + Args: + vec (list[float]): Native-dim Nomic embedding (typically 768 floats, + already L2-unit on the full 768). + + Returns: + list[float]: Exactly ``_TARGET_DIM`` floats, L2-normalised in the + truncated subspace so cosine similarity remains a valid metric. + Zero-padded if the input is shorter than ``_TARGET_DIM``. + """ + if len(vec) >= _TARGET_DIM: + sliced = vec[:_TARGET_DIM] + else: + sliced = vec + [0.0] * (_TARGET_DIM - len(vec)) + norm = math.sqrt(sum(x * x for x in sliced)) + if norm <= 0: + return sliced + return [x / norm for x in sliced] + + +_REGISTERED = False + + +def register_if_enabled() -> bool: + """Make the Nomic embedder available when env + deps allow it. + + Idempotent. Returns ``True`` when the embedder is usable after this + call. Routing happens via prefix-match on the model name in + ``LiteLLMClient.get_embedding(s)``. + + Eagerly pre-warms the model in a daemon thread so the first request + doesn't pay the ~30 s cold-start cost. The thread is fire-and-forget; + callers either land mid-load (and block briefly) or after-load (and + proceed immediately). + """ + global _REGISTERED + if _REGISTERED: + return True + if os.environ.get(_ENV_ENABLE) not in {"1", "true", "True"}: + return False + if importlib.util.find_spec("sentence_transformers") is None: + _LOGGER.warning( + "%s=1 set but `sentence-transformers` not installed; the Nomic " + "local embedder will not be available.", + _ENV_ENABLE, + ) + return False + _REGISTERED = True + _LOGGER.info("Nomic local embedding provider enabled (models=%s)", sorted(_MODEL_KEYS)) + + def _prewarm() -> None: + """Background load + dummy inference so the first real request is fast.""" + try: + embedder = NomicEmbedder.get() + embedder.embed(["warmup"]) + _LOGGER.info("Nomic embedder pre-warmed") + except Exception: # noqa: BLE001 + _LOGGER.exception("Nomic embedder pre-warm failed; first call will pay the cost") + + threading.Thread(target=_prewarm, daemon=True, name="nomic-prewarm").start() + return True + + +def is_enabled() -> bool: + """Return True after a successful :func:`register_if_enabled`.""" + return _REGISTERED + + +def is_nomic_model(model: str) -> bool: + """Predicate used by ``LiteLLMClient`` to route by model name. + + Args: + model (str): The embedding model name passed by the caller. + + Returns: + bool: True when the model resolves to the Nomic provider. + """ + return model in _MODEL_KEYS + + +__all__ = [ + "NomicEmbedder", + "NomicEmbedderError", + "is_enabled", + "is_nomic_model", + "register_if_enabled", +] diff --git a/reflexio/server/llm/rerank/__init__.py b/reflexio/server/llm/rerank/__init__.py new file mode 100644 index 00000000..843d83ae --- /dev/null +++ b/reflexio/server/llm/rerank/__init__.py @@ -0,0 +1,5 @@ +"""Local cross-encoder reranking helpers.""" + +from reflexio.server.llm.rerank.cross_encoder_reranker import score_pairs + +__all__ = ["score_pairs"] diff --git a/reflexio/server/llm/rerank/cross_encoder_reranker.py b/reflexio/server/llm/rerank/cross_encoder_reranker.py new file mode 100644 index 00000000..a4bb96c1 --- /dev/null +++ b/reflexio/server/llm/rerank/cross_encoder_reranker.py @@ -0,0 +1,146 @@ +"""Local cross-encoder reranker for ``(query, document)`` pairs. + +Wraps ``cross-encoder/ms-marco-MiniLM-L-6-v2`` (~25M params) from +``sentence-transformers``. The model is lazy-loaded on first call and +held as a process-wide singleton — load takes ~3 s but only happens +once per server start. Scoring K=30 pairs takes ~50 ms on CPU. + +Usage +----- + +>>> from reflexio.server.llm.rerank import score_pairs +>>> scores = score_pairs("italian food", ["pasta lover", "weather report"]) +>>> scores[0] > scores[1] +True + +The helper is intentionally side-effect free at import time: building +the singleton happens only when ``score_pairs`` is called, so importing +this module never triggers a model download. +""" + +from __future__ import annotations + +import logging +import threading +from typing import Any + +_LOGGER = logging.getLogger(__name__) + +# HuggingFace identifier for the cross-encoder. Chosen for the +# size/quality trade-off: 22M parameters, ~50 ms for K=30 on CPU, +# well-known MS-MARCO benchmark performance. +_MODEL_NAME = "cross-encoder/ms-marco-MiniLM-L-6-v2" + +# Singleton state — never accessed directly outside ``_get_model``. +_MODEL: Any | None = None +_MODEL_LOCK = threading.Lock() + + +class CrossEncoderUnavailableError(RuntimeError): + """Raised when the cross-encoder model cannot be loaded. + + The most common cause is ``sentence-transformers`` being absent from + the runtime environment. Callers should treat this as a soft failure + (log + skip rerank) rather than a 500. + """ + + +def _import_cross_encoder() -> Any: + """Robustly import ``sentence_transformers.CrossEncoder``. + + The ``sentence_transformers`` package is loaded both by the Nomic + local-embedding pre-warm thread (kicked off by ``LiteLLMClient.__init__`` + when ``CLAUDE_SMART_USE_LOCAL_EMBEDDING=1``) and by this reranker. When + those concurrent imports race, one thread can see a half-loaded + ``sentence_transformers`` module in ``sys.modules`` whose ``CrossEncoder`` + attribute was never bound — Python's import machinery hands back the + partial module without re-running ``__init__``. This helper detects that + case, drops the stale entry, and re-imports cleanly. + + Returns: + Any: The ``CrossEncoder`` class. + + Raises: + CrossEncoderUnavailableError: When the package genuinely isn't + installed, or every retry yields a partial module. + """ + import sys + + for _attempt in range(2): + try: + from sentence_transformers import CrossEncoder + except ImportError: + # Partial import — drop the stale entry and try once more. + sys.modules.pop("sentence_transformers", None) + continue + return CrossEncoder + try: + from sentence_transformers import ( + CrossEncoder, # noqa: F401 — final attempt for the error path + ) + except ImportError as e: + raise CrossEncoderUnavailableError( + "sentence-transformers is not installed; cannot use the " + "cross-encoder reranker" + ) from e + return CrossEncoder + + +def _get_model() -> Any: + """Return the lazy-loaded cross-encoder singleton. + + The first caller pays the load cost (~3 s, weights cached under + ``~/.cache/huggingface/`` after first download). Subsequent callers + get the warm instance immediately. + + Returns: + Any: A ``sentence_transformers.CrossEncoder`` instance. + + Raises: + CrossEncoderUnavailableError: If ``sentence-transformers`` is not + importable, or if the underlying model fails to load. + """ + global _MODEL # noqa: PLW0603 — singleton-pattern intentional + if _MODEL is not None: + return _MODEL + with _MODEL_LOCK: + if _MODEL is not None: + return _MODEL + cross_encoder_cls = _import_cross_encoder() + try: + _MODEL = cross_encoder_cls(_MODEL_NAME) + except Exception as e: # noqa: BLE001 — surface as a typed failure + raise CrossEncoderUnavailableError( + f"Failed to load cross-encoder model {_MODEL_NAME!r}: {e}" + ) from e + _LOGGER.info("Loaded cross-encoder model %s", _MODEL_NAME) + return _MODEL + + +def score_pairs(query: str, docs: list[str]) -> list[float]: + """Score ``(query, doc)`` pairs with the cross-encoder. + + Higher score means more relevant. Scores are not bounded to a fixed + range — they are raw model logits — so callers should treat them as + opaque relative-ranking signal, not as probabilities. + + Args: + query (str): The reranking query. + docs (list[str]): Documents to score against ``query``. + + Returns: + list[float]: One score per document, in the same order as + ``docs``. Empty list when ``docs`` is empty. + + Raises: + CrossEncoderUnavailableError: If the cross-encoder cannot be + loaded (re-raised from :func:`_get_model`). + """ + if not docs: + return [] + model = _get_model() + pairs = [(query, doc) for doc in docs] + raw_scores = model.predict(pairs) + # ``predict`` returns a numpy array; convert to plain Python floats so + # the caller can serialise the result without numpy as a dependency. + return [float(s) for s in raw_scores] diff --git a/reflexio/server/llm/tools.py b/reflexio/server/llm/tools.py new file mode 100644 index 00000000..5793e977 --- /dev/null +++ b/reflexio/server/llm/tools.py @@ -0,0 +1,523 @@ +"""Tool-calling primitives shared by agentic extraction and search pipelines.""" + +from __future__ import annotations + +import json +import logging +import time +from collections.abc import Callable +from typing import TYPE_CHECKING, Any, Literal + +logger = logging.getLogger(__name__) + +from pydantic import BaseModel, ConfigDict, ValidationError + +from reflexio.server.llm.model_defaults import ModelRole, resolve_model_name + +if TYPE_CHECKING: + from reflexio.server.llm.litellm_client import LiteLLMClient + + +class Tool(BaseModel): + """A single LLM-callable tool. + + Arguments are defined by a Pydantic model (its schema goes to the LLM, + its docstring becomes the tool description). The handler takes a + validated args instance plus a caller-supplied context object and + returns a JSON-serialisable dict that is fed back as the tool result. + """ + + model_config = ConfigDict(arbitrary_types_allowed=True) + + name: str + args_model: type[BaseModel] + handler: Callable[[BaseModel, Any], dict] + + def openai_spec(self) -> dict: + return { + "type": "function", + "function": { + "name": self.name, + "description": (self.args_model.__doc__ or "").strip(), + "parameters": self.args_model.model_json_schema(), + }, + } + + +class ToolRegistry: + def __init__(self, tools: list[Tool] | None = None) -> None: + self._tools: dict[str, Tool] = {} + for t in tools or []: + self.register(t) + + def register(self, tool: Tool) -> None: + self._tools[tool.name] = tool + + def openai_specs(self) -> list[dict]: + return [t.openai_spec() for t in self._tools.values()] + + def handle(self, name: str, args_json: str, ctx: Any) -> dict: + tool = self._tools.get(name) + if tool is None: + return {"error": f"unknown tool: {name}"} + try: + raw = json.loads(args_json or "{}") + args = tool.args_model.model_validate(raw) + except (ValidationError, json.JSONDecodeError) as e: + return {"error": f"invalid args for {name}: {e}"} + try: + return tool.handler(args, ctx) + except Exception as e: # handler errors are recoverable tool-turn errors + logger.exception("tool handler %s failed", name) + return {"error": f"handler error: {type(e).__name__}"} + + +class ToolLoopTurn(BaseModel): + """A single tool call turn in a tool-loop trace.""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + tool_name: str + args: dict[str, Any] + result: dict[str, Any] + latency_ms: int + # Populated from the LLM response's ``usage`` object when available + # (native tool-call mode). All None in capability-fallback mode and + # when the provider doesn't report usage. + model: str | None = None + prompt_tokens: int | None = None + completion_tokens: int | None = None + total_tokens: int | None = None + cost_usd: float | None = None + + +class ToolLoopTrace(BaseModel): + """Full trace of a tool-loop execution.""" + + turns: list[ToolLoopTurn] = [] + finished: bool = False + + +class ToolLoopResult(BaseModel): + """Outcome of ``run_tool_loop``: final ``ctx``, trace, and terminator reason.""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + ctx: Any + trace: ToolLoopTrace + finished_reason: Literal["finish_tool", "max_steps", "error"] + + +def supports_tool_calling(model: str) -> bool: + """Return True when litellm reports native function-calling support. + + Wrapped so tests can monkeypatch the probe without touching litellm. + On any internal error we optimistically assume support — cheaper to + attempt a real call than to wrongly fall back. + + Args: + model (str): Fully-qualified model name. + + Returns: + bool: True if litellm advertises function-calling for ``model``. + """ + try: + import litellm + + return bool(litellm.supports_function_calling(model=model)) + except Exception as e: + logger.warning( + "supports_function_calling probe failed for %s: %s: %s — assuming True", + model, + type(e).__name__, + e, + ) + return True + + +# Cap on tool-result payload size injected back into the message history +# in multi-stage mode. Without this, a single fat search response could +# blow the model's context window in two or three turns. +_MULTI_STAGE_RESULT_CHAR_CAP = 4000 + + +def _serialize_tool_result_for_history(result: dict[str, Any]) -> str: + """Render a tool result dict as a JSON string capped at a fixed size. + + Args: + result (dict[str, Any]): The tool handler's return value. + + Returns: + str: A JSON string truncated to ``_MULTI_STAGE_RESULT_CHAR_CAP`` + characters with a ``... [truncated]`` marker on overflow. + """ + payload = json.dumps(result, default=str) + if len(payload) <= _MULTI_STAGE_RESULT_CHAR_CAP: + return payload + return f"{payload[:_MULTI_STAGE_RESULT_CHAR_CAP]}... [truncated]" + + +def _run_multi_stage_fallback( + *, + client: LiteLLMClient, + messages: list[dict[str, Any]], + registry: ToolRegistry, + model_role: ModelRole, + max_steps: int, + ctx: Any, + finish_tool_name: str, + multi_stage_schema: type[BaseModel], + log_label: str | None, + trace: ToolLoopTrace, +) -> ToolLoopResult: + """Drive a multi-turn tool loop using one structured-output call per turn. + + Used when the configured model lacks native tool-calling but the + caller wants observe-decide-act semantics (e.g. the search agent on + ``minimax/MiniMax-M2.7``). Each turn: + + 1. Asks the model for a ``multi_stage_schema`` instance whose + ``next_call`` field carries a discriminator literal naming the + desired tool. + 2. Dispatches that call against the registry. + 3. Appends the agent's plan as an assistant message and the tool + result as a user message, so the next turn's model call sees both. + + Loop terminates when ``next_call.tool == finish_tool_name`` or + ``max_steps`` is exhausted. + + Args: + client (LiteLLMClient): Configured client. + messages (list[dict]): Seed message list; extended in place. + registry (ToolRegistry): Tools exposed to the LLM. + model_role (ModelRole): Role used to resolve the target model. + max_steps (int): Cap on tool-calling turns. + ctx (Any): Per-run context passed to each tool handler. + finish_tool_name (str): Sentinel literal that ends the loop. + multi_stage_schema (type[BaseModel]): Schema with a ``next_call`` + discriminated-union field. + log_label (str | None): Optional llm_io.log label. + trace (ToolLoopTrace): Trace to extend with per-turn entries. + + Returns: + ToolLoopResult: ``ctx``, trace, and the terminator reason. + """ + if log_label: + from reflexio.server.services.service_utils import ( + log_llm_messages, + log_model_response, + ) + + for turn_idx in range(max_steps): + turn_label = f"(multi-stage turn {turn_idx + 1})" + if log_label: + log_llm_messages(logger, f"{log_label} {turn_label}", messages) + tool_t0 = time.monotonic() + parsed = client.generate_chat_response( + messages=messages, + response_format=multi_stage_schema, + model_role=model_role, + ) + if log_label: + log_model_response(logger, f"{log_label} {turn_label}", parsed) + if not isinstance(parsed, BaseModel): + raise RuntimeError( + f"Multi-stage structured call returned unexpected type {type(parsed)}" + ) + + next_call = getattr(parsed, "next_call", None) + if next_call is None: + raise RuntimeError( + "Multi-stage schema must expose a 'next_call' field; " + f"got {type(parsed).__name__}" + ) + tool_name = getattr(next_call, "tool", None) + if not isinstance(tool_name, str): + raise RuntimeError( + "Multi-stage next_call must carry a 'tool' discriminator literal; " + f"got {type(next_call).__name__}" + ) + + reasoning = getattr(parsed, "reasoning", "") or "" + args_dict = next_call.model_dump(exclude={"tool"}) + args_json = next_call.model_dump_json(exclude={"tool"}) + + # Echo the agent's plan back into history so subsequent turns can + # reason about what was tried already. + messages.append( + { + "role": "assistant", + "content": ( + f"Reasoning: {reasoning}\nNext call: {tool_name}({args_json})" + ), + } + ) + + if tool_name == finish_tool_name: + # Dispatch finish through the registry so any ctx-side + # bookkeeping (e.g. stashing the answer) still runs. + result = registry.handle(tool_name, args_json, ctx) + trace.turns.append( + ToolLoopTurn( + tool_name=tool_name, + args=args_dict, + result=result, + latency_ms=int((time.monotonic() - tool_t0) * 1000), + ) + ) + trace.finished = True + return ToolLoopResult(ctx=ctx, trace=trace, finished_reason="finish_tool") + + result = registry.handle(tool_name, args_json, ctx) + trace.turns.append( + ToolLoopTurn( + tool_name=tool_name, + args=args_dict, + result=result, + latency_ms=int((time.monotonic() - tool_t0) * 1000), + ) + ) + messages.append( + { + "role": "user", + "content": ( + f"Tool {tool_name} returned: " + f"{_serialize_tool_result_for_history(result)}" + ), + } + ) + + trace.finished = False + return ToolLoopResult(ctx=ctx, trace=trace, finished_reason="max_steps") + + +def run_tool_loop( + client: LiteLLMClient, + messages: list[dict[str, Any]], + registry: ToolRegistry, + model_role: ModelRole, + *, + max_steps: int = 8, + ctx: Any = None, + finish_tool_name: str = "finish", + fallback_schema: type[BaseModel] | None = None, + fallback_tool_name: str | None = None, + multi_stage_schema: type[BaseModel] | None = None, + log_label: str | None = None, +) -> ToolLoopResult: + """Drive an LLM through a tool-calling loop until ``finish_tool_name`` or ``max_steps``. + + For providers that lack native tool-calling there are two fallback + modes (in priority order): + + 1. **Multi-stage** (``multi_stage_schema`` set): one structured-output + call per turn whose parsed schema carries a ``next_call`` + discriminated-union. The server dispatches ``next_call`` against + the registry, appends the result to the message history, and asks + for the next turn — preserving observe-decide-act semantics. + 2. **Single-shot** (``fallback_schema`` + ``fallback_tool_name``): + one structured-output call whose parsed list is converted into + synthetic tool calls dispatched against ``fallback_tool_name``. + All calls are planned upfront so the agent never observes any + tool result. + + Args: + client (LiteLLMClient): Configured client — ``generate_chat_response`` + is invoked with ``tools=`` in native mode and with + ``response_format=`` in either fallback mode. + messages (list[dict]): Seed message list; extended in place per turn. + registry (ToolRegistry): Tools exposed to the LLM. + model_role (ModelRole): Role used to resolve the target model. + max_steps (int): Cap on tool-calling turns. + ctx (Any): Caller-supplied context object passed to each tool handler. + finish_tool_name (str): Name of the sentinel tool that terminates the loop. + fallback_schema (type[BaseModel] | None): Pydantic schema for the + single-shot fallback path. Used only if ``multi_stage_schema`` + is None. + fallback_tool_name (str | None): Name of the tool each single-shot + fallback item is dispatched against. + multi_stage_schema (type[BaseModel] | None): Pydantic schema for + the multi-stage fallback path. The schema must expose a + ``next_call`` field whose value is a Pydantic model carrying a + ``tool`` discriminator literal — that literal names the tool + to dispatch, all other fields become its args. Takes priority + over ``fallback_schema``. + log_label (str | None): When set, each LLM call in the loop is + mirrored into ``~/.reflexio/logs/llm_io.log`` using this label + (suffixed with ``(turn N)``, ``(fallback)``, or + ``(multi-stage turn N)``). Matches classic per-call logging + parity. Leave unset (default) to suppress file-level logging + for tool-loop callers like unit tests. + + Returns: + ToolLoopResult: ``ctx``, trace, and the terminator reason. + + Raises: + RuntimeError: If the model lacks tool-calling AND no fallback + (multi-stage or single-shot) is provided. + """ + model = resolve_model_name( + role=model_role, + site_var_value=None, + config_override=None, + api_key_config=getattr(client.config, "api_key_config", None), + ) + trace = ToolLoopTrace() + + # Lazily import the llm_io helpers only when logging is requested — + # matches classic's per-call lazy-import pattern in profile_deduplicator.py. + if log_label: + from reflexio.server.services.service_utils import ( + log_llm_messages, + log_model_response, + ) + + # ---- Capability fallback ------------------------------------------ + if not supports_tool_calling(model): + if multi_stage_schema is not None: + return _run_multi_stage_fallback( + client=client, + messages=messages, + registry=registry, + model_role=model_role, + max_steps=max_steps, + ctx=ctx, + finish_tool_name=finish_tool_name, + multi_stage_schema=multi_stage_schema, + log_label=log_label, + trace=trace, + ) + if fallback_schema is None or fallback_tool_name is None: + raise RuntimeError( + f"Model {model} lacks tool-calling and no fallback_schema provided" + ) + if log_label: + log_llm_messages(logger, f"{log_label} (fallback)", messages) + parsed = client.generate_chat_response( + messages=messages, + response_format=fallback_schema, + model_role=model_role, + ) + if log_label: + log_model_response(logger, f"{log_label} (fallback)", parsed) + # The fallback path always passes response_format so the client + # returns a parsed BaseModel instance. Narrow the type so pyright + # can see model_fields is available. + if not isinstance(parsed, BaseModel): + raise RuntimeError( + f"Fallback structured call returned unexpected type {type(parsed)}" + ) + # Expect the schema's first field to be a list of items whose + # ``model_dump_json()`` matches the fallback tool's args model. + items = getattr(parsed, next(iter(type(parsed).model_fields))) + # Respect the configured max_steps budget even on the fallback path + # — otherwise a non-tool-calling provider could blow past the loop + # cap when the structured response includes more items than expected. + bounded_items = items[:max_steps] + for item in bounded_items: + tool_t0 = time.monotonic() + res = registry.handle(fallback_tool_name, item.model_dump_json(), ctx) + trace.turns.append( + ToolLoopTurn( + tool_name=fallback_tool_name, + args=item.model_dump(), + result=res, + latency_ms=int((time.monotonic() - tool_t0) * 1000), + ) + ) + exceeded = len(items) > max_steps + trace.finished = not exceeded + return ToolLoopResult( + ctx=ctx, + trace=trace, + finished_reason="max_steps" if exceeded else "finish_tool", + ) + + # ---- Native tool loop --------------------------------------------- + local_msgs = list(messages) + try: + for _step in range(max_steps): + if log_label: + log_llm_messages(logger, f"{log_label} (turn {_step + 1})", local_msgs) + resp = client.generate_chat_response( + messages=local_msgs, + tools=registry.openai_specs(), + tool_choice="auto", + model_role=model_role, + ) + if log_label: + log_model_response(logger, f"{log_label} (turn {_step + 1})", resp) + + # Extract per-turn usage from the response (populated by LiteLLMClient + # when the provider reports it; None otherwise). + turn_usage = getattr(resp, "usage", None) + turn_prompt_tokens = ( + getattr(turn_usage, "prompt_tokens", None) if turn_usage else None + ) + turn_completion_tokens = ( + getattr(turn_usage, "completion_tokens", None) if turn_usage else None + ) + turn_total_tokens = ( + getattr(turn_usage, "total_tokens", None) if turn_usage else None + ) + turn_cost_usd = getattr(resp, "cost_usd", None) + + tool_calls = getattr(resp, "tool_calls", None) + if not tool_calls: + trace.finished = True + return ToolLoopResult( + ctx=ctx, trace=trace, finished_reason="finish_tool" + ) + # Emit ONE assistant message carrying ALL tool_calls from this turn. + # OpenAI/Anthropic strict mode requires this shape. + local_msgs.append( + {"role": "assistant", "content": None, "tool_calls": list(tool_calls)} + ) + # Process every tool call and append per-call tool result messages. + # A single response's usage is attached to every turn it produced — + # the summary helpers dedup by (model, prompt_tokens, completion_tokens). + for tc in tool_calls: + # Time each tool individually — using the turn-start clock + # would inflate later tools' latencies with model time and + # earlier tools' work, masking the actual per-tool cost. + tool_t0 = time.monotonic() + name = tc.function.name + args_json = tc.function.arguments + result = registry.handle(name, args_json, ctx) + try: + args_dict = json.loads(args_json or "{}") + except json.JSONDecodeError: + args_dict = {} + trace.turns.append( + ToolLoopTurn( + tool_name=name, + args=args_dict, + result=result, + latency_ms=int((time.monotonic() - tool_t0) * 1000), + model=model, + prompt_tokens=turn_prompt_tokens, + completion_tokens=turn_completion_tokens, + total_tokens=turn_total_tokens, + cost_usd=turn_cost_usd, + ) + ) + local_msgs.append( + { + "role": "tool", + "tool_call_id": tc.id, + "content": json.dumps(result), + } + ) + # After processing ALL tool calls, check whether the finish sentinel + # appeared in this turn (may be alongside sibling calls). + if any(tc.function.name == finish_tool_name for tc in tool_calls): + trace.finished = True + return ToolLoopResult( + ctx=ctx, trace=trace, finished_reason="finish_tool" + ) + except Exception: + logger.exception("Tool loop raised an unexpected exception") + trace.finished = False + return ToolLoopResult(ctx=ctx, trace=trace, finished_reason="error") + + return ToolLoopResult(ctx=ctx, trace=trace, finished_reason="max_steps") diff --git a/reflexio/server/prompt/prompt_bank/extraction_agent/v1.0.0.prompt.md b/reflexio/server/prompt/prompt_bank/extraction_agent/v1.0.0.prompt.md new file mode 100644 index 00000000..3c5ff5e4 --- /dev/null +++ b/reflexio/server/prompt/prompt_bank/extraction_agent/v1.0.0.prompt.md @@ -0,0 +1,60 @@ +--- +active: false +description: "Agentic-v2 extraction agent — adaptive single-loop over atomic tools" +variables: + - sessions + - extraction_criteria +--- +You are a memory extractor. Read the session transcript below and update the +user's memory — UserProfiles and UserPlaybooks — by calling the tools provided. + +You can mutate two kinds of records: + - **UserProfile** — a factual statement about the user (e.g. "user is a PM at Acme"). + - **UserPlaybook** — a behavioural rule of the form (trigger, content, rationale). + +You cannot create, delete, or otherwise mutate AgentPlaybooks — those are +produced by a separate aggregator from your UserPlaybook outputs. + +## Rules + +1. **Search before you create.** Before calling `create_user_profile` or + `create_user_playbook`, you MUST have called `search_user_profiles` or + `search_user_playbooks` at least once in this run. + +2. **Delete only what you've seen.** Before calling `delete_user_profile` or + `delete_user_playbook`, the id must have come from a prior search or get + result in this run (or a tentative_id your own create call issued earlier + in the same run). + +3. **For supersession** (new fact replaces a stale one): call `delete` on the + stale id, then `create` with the new content. + +4. **For profile merge** (two duplicate profiles): call `delete` on each, + then one `create` with the best merged wording. You may pick the clearest + phrasing — this can be lossy. + +5. **For playbook expansion** (additive, **lossless**): when a new rule + extends an existing playbook (same trigger, additional instruction), call + `delete_user_playbook` on the old one and `create_user_playbook` with a + content that contains BOTH the old instructions AND the new addition. + Every instruction in the old playbook must appear in the new one. + + Example: + existing: trigger="code help", content="show examples" + new signal adds: content="prefer TypeScript" + result: trigger="code help", content="show examples; prefer TypeScript" + +6. **Narrate briefly.** In the assistant `content` field before each mutation + turn, write one or two short sentences describing what you're about to do + and why. Skip narration on pure-search turns. + +7. **Call `finish`** once you have processed the session OR concluded no + updates are warranted (empty plan is a valid outcome). + +## Extraction criteria + +{extraction_criteria} + +## Session transcript + +{sessions} diff --git a/reflexio/server/prompt/prompt_bank/extraction_agent/v1.1.0.prompt.md b/reflexio/server/prompt/prompt_bank/extraction_agent/v1.1.0.prompt.md new file mode 100644 index 00000000..d17b3a69 --- /dev/null +++ b/reflexio/server/prompt/prompt_bank/extraction_agent/v1.1.0.prompt.md @@ -0,0 +1,72 @@ +--- +active: false +description: "Agentic extraction agent — per-entity-kind single-loop over atomic tools" +variables: + - sessions + - extraction_criteria + - extraction_kind +--- +You are a memory extractor. Read the session transcript below and update the +user's memory by calling the tools provided. + +## Scope for THIS run + +You are extracting **{extraction_kind}** records only. + +- **UserProfile runs** — emit factual statements about the user: role, + preferences, stable attributes, environment, tool quirks. Do NOT encode + behavioural rules ("when X, do Y") in the profile content — behavioural + rules are emitted by a different run against a different extractor config. + A profile like "user is on-call this week" is OK; a profile like "prefers + no code review scheduling before 10am" is NOT OK — that's a playbook. + +- **UserPlaybook runs** — emit behavioural rules of the form (trigger, content, + rationale). Do NOT restate factual statements as rules — stable facts belong + in a UserProfile generated by a different run. + +You cannot create, delete, or otherwise mutate AgentPlaybooks — those are +produced by a separate aggregator from your UserPlaybook outputs. + +## Rules + +1. **Search before you create.** Before calling a `create_*` tool, you MUST + have called a `search_*` tool at least once in this run. + +2. **Delete only what you've seen.** Before calling a `delete_*` tool, the id + must have come from a prior search or get result in this run (or a + tentative_id your own create call issued earlier in the same run). + +3. **For supersession** (new fact replaces a stale one): call the matching + delete tool (`delete_user_profile` or `delete_user_playbook`) on the + stale id, then the matching create tool (`create_user_profile` or + `create_user_playbook`) with the new content. + +4. **For profile merge** (two duplicate profiles): call `delete_user_profile` + on each duplicate id, then one `create_user_profile` with the best + merged wording. You may pick the clearest phrasing — this can be lossy. + +5. **For playbook expansion** (additive, **lossless**): when a new rule + extends an existing playbook (same trigger, additional instruction), call + `delete_user_playbook` on the old one and `create_user_playbook` with a + content that contains BOTH the old instructions AND the new addition. + Every instruction in the old playbook must appear in the new one. + + Example: + existing: trigger="code help", content="show examples" + new signal adds: content="prefer TypeScript" + result: trigger="code help", content="show examples; prefer TypeScript" + +6. **Narrate briefly.** In the assistant `content` field before each mutation + turn, write one or two short sentences describing what you're about to do + and why. Skip narration on pure-search turns. + +7. **Call `finish`** once you have processed the session OR concluded no + updates are warranted (empty plan is a valid outcome). + +## Extraction criteria + +{extraction_criteria} + +## Session transcript + +{sessions} diff --git a/reflexio/server/prompt/prompt_bank/extraction_agent/v1.2.0.prompt.md b/reflexio/server/prompt/prompt_bank/extraction_agent/v1.2.0.prompt.md new file mode 100644 index 00000000..35e469d9 --- /dev/null +++ b/reflexio/server/prompt/prompt_bank/extraction_agent/v1.2.0.prompt.md @@ -0,0 +1,82 @@ +--- +active: false +description: "Agentic extraction — build memory that enables the host agent to self-improve" +variables: + - sessions + - extraction_criteria + - extraction_kind +--- +You are helping an AI agent improve over time. Each session the agent has with +a user is a signal — your job is to distill that signal into memory the agent +can act on in future sessions. Better memory here means sharper, more +personalised, more reliably-aligned agent behaviour next time. + +Reflexio keeps three kinds of memory, each serving a distinct axis of +self-improvement: + +- **UserProfile** — stable facts about this specific user (role, environment, + preferences, tool quirks). Lets the agent serve this user without + re-learning who they are each session. +- **UserPlaybook** — behavioural rules learned from THIS user's feedback + (trigger → content → rationale). Lets the agent self-correct from + per-user signal. +- **AgentPlaybook** — behavioural rules aggregated across users. Lets the + agent evolve global behaviour from collective signal. You cannot mutate + these directly — they are produced by a separate aggregator from + UserPlaybook outputs. + +For THIS run you mutate **{extraction_kind}** only. Call the tools provided. + +## Scope for THIS run + +- **UserProfile runs** — emit factual statements about the user: role, + preferences, stable attributes, environment, tool quirks. Do NOT encode + behavioural rules ("when X, do Y") in the profile content — those are + emitted by a different run against a different extractor config. A profile + like "user is on-call this week" is OK; "prefers no code review scheduling + before 10am" is NOT OK — that's a playbook. +- **UserPlaybook runs** — emit behavioural rules of the form (trigger, content, + rationale). Do NOT restate factual statements as rules — stable facts belong + in a UserProfile generated by a different run. + +## Rules + +1. **Search before you create.** Before calling a `create_*` tool, you MUST + have called a `search_*` tool at least once in this run. + +2. **Delete only what you've seen.** Before calling a `delete_*` tool, the id + must have come from a prior search or get result in this run (or a + tentative_id your own create call issued earlier in the same run). + +3. **For supersession** (new fact replaces a stale one): call `delete` on the + stale id, then `create` with the new content. + +4. **For profile merge** (two duplicate profiles): call `delete` on each, + then one `create` with the best merged wording. You may pick the clearest + phrasing — this can be lossy. + +5. **For playbook expansion** (additive, **lossless**): when a new rule + extends an existing playbook (same trigger, additional instruction), call + `delete_user_playbook` on the old one and `create_user_playbook` with a + content that contains BOTH the old instructions AND the new addition. + Every instruction in the old playbook must appear in the new one. + + Example: + existing: trigger="code help", content="show examples" + new signal adds: content="prefer TypeScript" + result: trigger="code help", content="show examples; prefer TypeScript" + +6. **Narrate briefly.** In the assistant `content` field before each mutation + turn, write one or two short sentences describing what you're about to do + and why. Skip narration on pure-search turns. + +7. **Call `finish`** once you have processed the session OR concluded no + updates are warranted (empty plan is a valid outcome). + +## Extraction criteria + +{extraction_criteria} + +## Session transcript + +{sessions} diff --git a/reflexio/server/prompt/prompt_bank/extraction_agent/v1.3.0.prompt.md b/reflexio/server/prompt/prompt_bank/extraction_agent/v1.3.0.prompt.md new file mode 100644 index 00000000..87f99326 --- /dev/null +++ b/reflexio/server/prompt/prompt_bank/extraction_agent/v1.3.0.prompt.md @@ -0,0 +1,132 @@ +--- +active: false +description: "Agentic extraction — atomic facts / clean-split rules for host-agent self-improvement" +variables: + - sessions + - extraction_criteria + - extraction_kind +--- +You are helping an AI agent improve over time. Each session the agent has with +a user is a signal — your job is to distill that signal into memory the agent +can act on in future sessions. Better memory here means sharper, more +personalised, more reliably-aligned agent behaviour next time. + +Reflexio keeps three kinds of memory, each serving a distinct axis of +self-improvement: + +- **UserProfile** — stable **facts** about this specific user: role, skills, + environment, timezone, tools they use, current status. Atomic statements, + not rules. Lets the agent serve this user without re-learning who they + are each session. +- **UserPlaybook** — behavioural **rules** learned from THIS user's feedback + (trigger → content → rationale). Lets the agent self-correct from + per-user signal. +- **AgentPlaybook** — behavioural rules aggregated across users. Lets the + agent evolve global behaviour from collective signal. You cannot mutate + these directly — they are produced by a separate aggregator from + UserPlaybook outputs. + +For THIS run you mutate **{extraction_kind}** only. Call the tools provided. + +## Scope for THIS run + +**UserProfile runs** — emit **atomic factual statements** about the user: +role, skills, environment, ongoing status, timezone, tools they use. Every +profile `content` field is ONE fact. Not a paragraph. Not a preference that's +actually a rule in disguise. + +Fact vs. rule — when in doubt, ask: "Is this *something the user is / has*, +or *what the agent should do when X happens*?" If it's the second, it belongs +in a UserPlaybook generated by a different run; drop it from profile content +entirely. + +**UserPlaybook runs** — emit **behavioural rules** of the form (trigger, +content, rationale). Do NOT restate factual statements as rules — stable +facts belong in a UserProfile generated by a different run. + +### UserProfile examples + +Good — atomic facts, one per create: + +- ✅ `"user is a senior Go engineer"` +- ✅ `"user is on-call this week"` +- ✅ `"user's preferred language is Spanish"` (a stable attribute) +- ✅ `"user works in the US/Pacific timezone"` + +Bad — multi-fact paragraphs or rule-shaped content: + +- ❌ `"user is a senior Go engineer and is on-call this week"` + — two atomic facts bundled; emit as two `create_user_profile` calls with + different TTLs (senior Go engineer = infinity; on-call this week = one_week). +- ❌ `"user is on-call this week; prefers no code review scheduling before 10am"` + — the "prefers no…" clause is a conditional rule, not a fact. Drop it + entirely from profile content — the playbook extractor will capture it. +- ❌ `"when the user asks for code help, prefer TypeScript"` + — pure rule shape. Do NOT emit as a profile, even if the session uses the + word "prefers". + +### UserPlaybook examples + +Good: + +- ✅ trigger="user asks for code help", content="prefer TypeScript over JavaScript" +- ✅ trigger="scheduling code reviews while user is on-call", content="avoid before 10am local" + +Bad — restating facts: + +- ❌ trigger="always", content="user is a senior Go engineer" + — that's a fact, not a rule. Emit as a UserProfile from a different run. + +## Rules + +1. **Search before you create.** Before calling a `create_*` tool, you MUST + have called a `search_*` tool at least once in this run. + +2. **Delete only what you've seen.** Before calling a `delete_*` tool, the id + must have come from a prior search or get result in this run (or a + tentative_id your own create call issued earlier in the same run). + +3. **One fact per profile.** Each `create_user_profile` call emits a single + atomic fact — one role, one location, one preference, one status. If a + session contains three facts, emit three creates. Never bundle facts into + one content string; you'll trap them into a shared TTL and make clean + supersession impossible. + +4. **For supersession** (new fact replaces a stale one): call `delete` on the + stale id, then `create` with the new content. + +5. **For profile merge** (two duplicate profiles): call `delete` on each, + then one `create` with the best merged wording. You may pick the clearest + phrasing — this can be lossy. + +6. **For playbook expansion** (additive, **lossless**): when a new rule + extends an existing playbook (same trigger, additional instruction), call + `delete_user_playbook` on the old one and `create_user_playbook` with a + content that contains BOTH the old instructions AND the new addition. + Every instruction in the old playbook must appear in the new one. + + Example: + existing: trigger="code help", content="show examples" + new signal adds: content="prefer TypeScript" + result: trigger="code help", content="show examples; prefer TypeScript" + +7. **No overlap between profile and playbook content.** If a rule already + belongs in a playbook (this run's or a sibling run's), do NOT also encode + it into profile content. Profile and playbook serve different self-improvement + axes; redundancy breaks the axis separation and risks divergence when one + side updates and the other doesn't. + +8. **Narrate briefly.** In the assistant `content` field before each mutation + turn, write one or two short sentences describing what you're about to do + and why. Skip narration on pure-search turns. + +9. **Call `finish`** once you have processed the session OR concluded no + updates are warranted (empty plan is a valid outcome). + +## Extraction criteria + +{extraction_criteria} + +## Session transcript + +{sessions} diff --git a/reflexio/server/prompt/prompt_bank/extraction_agent/v1.4.0.prompt.md b/reflexio/server/prompt/prompt_bank/extraction_agent/v1.4.0.prompt.md new file mode 100644 index 00000000..b09a5013 --- /dev/null +++ b/reflexio/server/prompt/prompt_bank/extraction_agent/v1.4.0.prompt.md @@ -0,0 +1,140 @@ +--- +active: false +description: "Agentic extraction — atomic facts + structured playbooks for host-agent self-improvement" +variables: + - sessions + - extraction_criteria + - extraction_kind + - max_steps +--- +You are helping an AI agent improve over time by extracting durable, actionable memory from a single user session. Each session is a signal; your job is to distill that signal into memory the agent can act on in future sessions. Better memory here means sharper, more personalised, and more reliably aligned agent behaviour next time. + +Reflexio keeps three kinds of memory, each serving a distinct axis of self-improvement: + +- UserProfile — stable facts about this specific user: role, skills, environment, timezone, tools they use, explicit dates for events when available, and countable items the user mentioned. Atomic statements, not rules. Lets the agent serve this user without re-learning who they are each session. +- UserPlaybook — behavioural rules learned from THIS user's feedback (trigger → content → rationale). Lets the agent self-correct from per-user signal. +- AgentPlaybook — behavioural rules aggregated across users. You cannot mutate these directly — they are produced by a separate aggregator from UserPlaybook outputs. + +For THIS run you mutate {extraction_kind} only. Call the tools provided. + +Primary extraction priorities for this tuning round (highest to lowest): +1) Encode explicit dates from session metadata and conversation timestamps into profile facts whenever they are present. Use ISO-style dates (YYYY-MM-DD) and append "(session date)". This is critical for temporal-reasoning tasks, and the date must be carried into the stored fact whenever the session metadata or conversation contains a concrete date. +2) Emit countable items as separate profile facts so later queries can count or list them. +3) Enforce atomicity: One fact per profile. +4) Avoid over-extraction of transient chatter; prefer durable facts and explicit preferences or events. + +Key invariants (must follow exactly): +- One fact per profile +- No overlap between profile and playbook +- Use imperative conditional phrasing for triggers, and format playbook instructions as a markdown bullet list + +Make these operationally concrete: always check session metadata timestamps and conversation timestamps for explicit dates before deciding a fact lacks a date. If a date exists anywhere in session metadata, include it exactly in the stored fact as YYYY-MM-DD (session date). When the session references multiple dated events or countable items, split them into separate atomic profile facts rather than bundling them. + +Step budget (plan your rounds; {max_steps} is hard limit): +- Round 1 (search): Search existing profiles for duplicates, superseded facts, and date-bearing facts that match the session topic. Always search before any create. +- Round 2 (mutate): Emit creates/deletes/updates. Batch multiple create/delete calls together in one assistant mutation turn. Narrate 1–2 short sentences before the mutation explaining what you will do and why. +- Round 3 (finish): Call `finish` to end the run (or earlier if done). If you need additional searches to avoid duplication, use them but prefer to stay within the {max_steps} rounds. + +Scope for THIS run + +If {extraction_kind} == "UserProfile": emit atomic factual statements about the user: role, skills, environment, ongoing status, timezone, tools they use, and explicit dates for events when session metadata provides them. Every profile `content` field is ONE fact. Not a paragraph. Not a preference that's actually a rule in disguise. + +Concrete guidelines for profiles (do these exactly): +- Encode explicit dates from the session metadata or conversation into the fact when present. Use ISO-style dates and append `(session date)`. + - Good: `user visited MoMA on 2024-08-23 (session date)` + - Good: `user attended "Ancient Civilizations" exhibit at the Metropolitan Museum of Art on 2023-01-08 (session date)` + - Good: `user helped cousin pick out baby shower items on 2023-02-10 (session date)` + - Bad: `user visited MoMA last week` + +- For countable items, emit each item as a separate profile fact so later queries can count or list them accurately. + - Good (three separate creates): + - `user has a navy blue blazer (dry cleaning)` + - `user has exchanged boots from Zara (to pick up on 2024-09-02 (session date))` + - `user has a rented tuxedo to return` + - Bad: `user has a navy blue blazer, exchanged boots from Zara, and a rented tuxedo to return` (bundles three facts into one) + +- Preserve temporal markers and counts. When session metadata contains explicit dates or lists, include the date in the profile fact (ISO + `(session date)`) or emit each countable item as its own `create_user_profile` fact. + +- One fact per profile: each `create_user_profile` call must capture exactly one atomic fact (a single subject-predicate-object or an event with a single timestamp). This enables later systems to count, sort, and supersede facts cleanly. + +- If a fact supersedes a previous fact (e.g., new timezone or changed employer), follow the supersession rule (delete the stale id, then create the new fact). + +- Prefer durable, reusable facts over ephemeral narration. Do not store greetings, acknowledgements, or one-off chat filler unless they clearly encode a stable preference, event, or capability. + +If {extraction_kind} == "UserPlaybook": emit behavioural rules of the form (trigger, content, rationale). Do NOT restate factual statements as rules — stable facts belong in UserProfile runs. + +Playbook format (applies to UserPlaybook runs only): + +trigger — the retrieval key +- Write triggers using imperative conditional phrasing. The trigger is indexed for both full-text and vector search and must be retrieval-friendly. +- Keep it to 1–2 sentences, 150–300 characters. Name the context, not just the event. +- Example (good): `When reviewing the user's code — pull requests, inline comments, pre-merge checks, or any code-review activity.` + +content — the agent's instruction packet +- Format content as a markdown bullet list. Each bullet must begin with an imperative verb and be self-sufficient. +- Use a numbered list only when order is load-bearing. Otherwise, use a markdown bullet list. +- Simple instructions: < ~500 characters each; complex multi-step rules may be up to ~2000; if you hit the cap, split into multiple playbooks. + +rationale — one sentence explaining WHY +- One sentence max. Explain the motivation behind the rule, not restate the content. Leave empty rather than restating content. + +Examples (UserPlaybook good): +- trigger: `When reviewing the user's code — pull requests, inline comments, pre-merge checks.` + content: `- Flag missing test coverage and any new public API without a docstring.` + `- Prioritize type-safety and correctness over style nits (line length, whitespace).` + `- For every suggested change, explain WHY it is better — not just what to change.` + rationale: `The user wants to learn the reasoning, not just apply edits.` + +Bad pattern to avoid: restating facts as rules. Example: trigger="always", content="user is a senior Go engineer" — that's a fact and belongs in a UserProfile run. No overlap between profile and playbook. + +Rules (operational MUSTs) +1. Search before you create. Before calling any `create_*` tool, you MUST have called a `search_*` tool at least once in this run. Do not create duplicates. +2. Delete only what you've seen. Before calling any `delete_*` tool, the id must have come from a prior search or get result in this run (or a tentative_id your own create call issued earlier in the same run). +3. One fact per profile. Enforce atomicity strictly: do not bundle multiple facts into a single profile content. +4. For supersession (new fact replaces a stale one): call `delete` on the stale id, then `create` with the new content. +5. For profile merge (two duplicate profiles): call `delete` on each, then one `create` with the best merged wording. You may pick the clearest phrasing — this can be lossy but must be a single new fact if merging identical facts. +6. For playbook expansion (additive, lossless): when a new rule extends an existing playbook (same trigger, additional instruction), call `delete_user_playbook` on the old one and `create_user_playbook` with a content that contains BOTH the old instructions AND the new addition. Every instruction in the old playbook must appear in the new one. +7. No overlap between profile and playbook. If the information is a rule about how the agent should behave, it belongs in a playbook; if it's a stable fact about the user, it belongs in a profile. Do not duplicate across axes. +8. Narrate briefly. In the assistant `content` field before each mutation turn, write one or two short sentences describing what you're about to do and why. Skip narration on pure-search turns. +9. Call `finish` once you have processed the session OR concluded no updates are warranted (empty plan is a valid outcome). +10. Preserve temporal markers and counts. When session metadata or conversation text contains explicit dates or countable lists, include the date in the profile fact (ISO + `(session date)`) or emit each countable item as its own `create_user_profile` fact. + +Quick pre-create checklist (follow every time before creating a profile fact): +- Did I run a `search_*` for duplicates and likely superseded facts? If not, search now. +- Does the session metadata or conversation contain an explicit date for this event? If yes, include it as YYYY-MM-DD (session date). +- Is this a single atomic fact? If it mentions multiple items or events, split it into separate facts. +- Is it a rule about agent behaviour? If yes, put it into a UserPlaybook run instead (No overlap between profile and playbook). + +Practical extraction heuristics (how to decide what to emit) +- If the sentence describes WHAT the user is/has/does (role, owned items, completed events with dates, preferred tools), treat as a profile fact. +- If the sentence describes WHAT THE AGENT SHOULD DO when X happens, treat as a playbook rule (trigger/content/rationale). Use imperative conditional phrasing for triggers. +- If uncertain, ask a short clarifying question to the user in a follow-up session instead of guessing. + +Temporal & counting examples (focused on correctness) + +Temporal good (convert session metadata / timestamps into ISO): +- Session metadata shows a visit date: `user attended "Ancient Civilizations" exhibit on 2024-03-15 (session date)` → create_user_profile content exactly: `user attended "Ancient Civilizations" exhibit on 2024-03-15 (session date)`. +- Conversation: "I picked up the chandelier on Apr 1" and session metadata date=2023-04-01 → create_user_profile: `user received a crystal chandelier on 2023-04-01 (session date)`. +- Conversation: "I visited MoMA on 2026-04-19" and session metadata includes that timestamp → create_user_profile: `user visited MoMA on 2026-04-19 (session date)`. +- If conversation references "two charity events in a row on 2026-02-10 and 2026-02-11", create two separate facts: + - `user participated in a charity event on 2026-02-10 (session date)` + - `user participated in a charity event on 2026-02-11 (session date)` + This enables queries asking "how many months since those events" to compute intervals. + +Counting good (emit separate facts for each item): +- Conversation: "I need to pick up my blazer, return the rented tuxedo, and pick up exchanged boots." Emit three separate creates, one fact per call: + - `user has a navy blue blazer (dry cleaning)` + - `user has a rented tuxedo to return` + - `user has exchanged boots from Zara (to pick up)` +- Conversation: "How many clothing items do I need to pick up or return?" If the transcript mentions three separate items across sessions, preserve them as three separate profile facts so later queries can count them individually. +- Conversation: "I led the data analysis team for a Marketing Research class project and I'm working on a solo project for Data Mining." Emit two separate facts, one for each project, so later queries can count projects accurately. + +Narration and mutation steps +- Before emitting mutations in a single assistant turn, write 1–2 short sentences that narrate what you're about to do and why (example: "Will create three profile facts capturing the three items the user said they'd pick up or return, including session dates where available."). +- Batch multiple create/delete calls together in one assistant mutation turn (Round 2). Do not spread them across many rounds. + +Extraction criteria +{extraction_criteria} + +Session transcript +{sessions} diff --git a/reflexio/server/prompt/prompt_bank/extraction_agent/v1.5.0.prompt.md b/reflexio/server/prompt/prompt_bank/extraction_agent/v1.5.0.prompt.md new file mode 100644 index 00000000..90b3ceb6 --- /dev/null +++ b/reflexio/server/prompt/prompt_bank/extraction_agent/v1.5.0.prompt.md @@ -0,0 +1,173 @@ +--- +active: false +description: "Agentic extraction — adds relative-time resolution + agent-turn fact capture on top of v1.4.0" +variables: + - sessions + - extraction_criteria + - extraction_kind + - max_steps +--- +You are helping an AI agent improve over time by extracting durable, actionable memory from a single user session. Each session is a signal; your job is to distill that signal into memory the agent can act on in future sessions. Better memory here means sharper, more personalised, and more reliably aligned agent behaviour next time. + +Reflexio keeps three kinds of memory, each serving a distinct axis of self-improvement: + +- UserProfile — stable facts about this specific user OR durable named answers the assistant told the user. Atomic statements, not rules. Lets the agent serve this user without re-learning who they are or what it told them last time. Profiles cover: role, skills, environment, timezone, tools, stated preferences, ongoing situations and constraints, current efforts, plans, explicit dated events, countable items, and concrete named answers the agent provided. +- UserPlaybook — behavioural rules learned from THIS user's feedback (trigger → content → rationale). Lets the agent self-correct from per-user signal. +- AgentPlaybook — behavioural rules aggregated across users. You cannot mutate these directly — they are produced by a separate aggregator from UserPlaybook outputs. + +For THIS run you mutate {extraction_kind} only. Call the tools provided. + +Note on placeholders. Tokens in angle brackets (``, ``, ``, etc.) appear in this prompt as abstract slots. They illustrate STRUCTURE, not content. In your real `create_*` calls, write the concrete text from the actual session — never write a literal angle-bracket placeholder into stored memory. + +Primary extraction priorities (highest to lowest): + +1. **User-side facts and preferences from any session.** A session in which the user only asks for advice still carries facts: their role, situation, constraints, goals, lifestyle, ongoing efforts, plans. Capture these even when the user hasn't explicitly said "remember that I…". The framing of the user's question is itself a signal about who they are and what matters to them. +2. **Resolve relative time to absolute ISO dates.** "X ago", "last ", "yesterday", " before " must be computed against the session_date and stored as `YYYY-MM-DD`. Never persist the relative phrase as text. +3. **Agent-provided named answers.** When the assistant gives the user a concrete identifier (a name, a place, a definition, a schedule, a description, a calculation result), store that as a profile fact phrased to credit the agent — users frequently ask later "what did you tell me about ". +4. **Dated events.** Encode every dated event with an ISO date. Append `(session date)` only when the event date IS the session_date. +5. **Countable items.** Each enumerable thing the user mentions becomes its own profile so later queries can count or list them. Never bundle items. +6. **Atomicity.** One fact per profile. A profile content is a single subject-predicate-object or a single dated event. +7. **No transient chatter.** Skip greetings, acknowledgements, the assistant rephrasing what the user said, and generic advice unattached to the user. + +Key invariants (must follow exactly): +- One fact per profile +- No overlap between profile and playbook +- Use imperative conditional phrasing for triggers, and format playbook instructions as a markdown bullet list + +### Resolving relative time (mandatory) + +The session metadata header carries `session_date`. When the conversation phrases time relative to "now", compute the absolute ISO date and store the resolved date. + +| Conversation phrase shape | session_date | Resolved event date | +|---|---|---| +| " N weeks ago" | 2026-04-26 | session_date − 7N days | +| "last " | 2026-04-26 (Sun) | the most recent prior | +| " before " | (any) | | +| "yesterday" | 2026-04-26 | session_date − 1 day | +| " N days ago" | 2026-04-26 | session_date − N days | + +Rule: in the stored profile, write only the resolved ISO date, never the original relative phrase. + +If you cannot compute the absolute date (no session_date and no anchor in the conversation), DO NOT make one up. Either omit the date or skip the fact. + +### Capturing agent-provided named answers + +Some user follow-up questions later ask the agent to recall what the agent itself said earlier — phrasings like "remind me what you told me about X", "what was that name you mentioned", "what color did you say it was", "what schedule did you give me". To support these, store agent-provided named facts as profiles. + +Capture rule: when the assistant gives a CONCRETE named answer (a name, a place, a description, a schedule, an attribute, a definition, or a calculation result) that the user is likely to ask about again, emit a profile that records that answer crediting the agent. Phrase as `agent recommended for ` or `agent said has ` or `agent provided for `. + +Skip rule: do NOT store assistant pleasantries, generic advice the assistant generated without grounding in this user's situation, or the assistant restating what the user said. + +Step budget (plan your rounds; {max_steps} is hard limit): +- Round 1 (search): Search existing profiles for duplicates or superseded facts. Always search before any create. +- Round 2 (mutate): Emit creates/deletes/updates. Batch multiple create/delete calls together in one assistant mutation turn. Narrate 1–2 short sentences before the mutation explaining what you will do and why. +- Round 3 (finish): Call `finish` to end the run (or earlier if done). If you need additional searches to avoid duplication, use them but prefer to stay within the {max_steps} rounds. + +Scope for THIS run + +If {extraction_kind} == "UserProfile": emit atomic factual statements that the agent will need to recall later. This includes (a) stable user attributes (role, skills, environment, timezone, tools), (b) stated preferences, (c) constraints, situations, ongoing efforts, goals, (d) explicit dated events, (e) countable items, AND (f) concrete named answers the assistant provided to the user. Every profile `content` field is ONE fact. Not a paragraph. Not a preference that's actually a rule in disguise. An empty plan is allowed only when the session has no user-side substantive content (e.g. the user only said "hello"). If the user articulated any role, preference, situation, plan, or asked a question whose framing reveals their domain, you MUST extract. + +Concrete guidelines for profiles (do these exactly): +- **Resolve relative time first.** Apply the table above before deciding what to emit. Never write "last week" / "X weeks ago" as profile text — convert to ISO. +- **Capture both user-said and agent-said facts.** When the agent gives the user a concrete answer, store it. Don't store playbook-style rules — those go in playbook runs. +- Encode explicit dates from session metadata or the conversation into the fact. Use ISO-style dates and append `(session date)` only when the event date IS the session_date. +- Emit each countable item the user mentions as its own profile fact so later queries can count or list them accurately. Never bundle multiple items into one profile. +- Preserve temporal markers and counts. If a session contains multiple dated events, split them into separate atomic facts, one per date and one per event. +- One fact per profile: each `create_user_profile` call must capture exactly one atomic fact (a single subject-predicate-object or an event with a single timestamp). +- If a fact supersedes a previous fact (e.g., new timezone or changed employer), follow the supersession rule (delete the stale id, then create the new fact). +- Prefer durable, reusable facts over ephemeral narration. Do not store greetings, acknowledgements, or one-off chat filler unless they clearly encode a stable preference, event, or capability. + +If {extraction_kind} == "UserPlaybook": emit behavioural rules of the form (trigger, content, rationale). Do NOT restate factual statements as rules — stable facts belong in UserProfile runs. + +Playbook format (applies to UserPlaybook runs only): + +trigger — the retrieval key +- Write triggers using imperative conditional phrasing. The trigger is indexed for both full-text and vector search and must be retrieval-friendly. +- Keep it to 1–2 sentences, 150–300 characters. Name the context, not just the event. +- Example (good): `When reviewing the user's code — pull requests, inline comments, pre-merge checks, or any code-review activity.` + +content — the agent's instruction packet +- Format content as a markdown bullet list. Each bullet must begin with an imperative verb and be self-sufficient. +- Use a numbered list only when order is load-bearing. Otherwise, use a markdown bullet list. +- Simple instructions: < ~500 characters each; complex multi-step rules may be up to ~2000; if you hit the cap, split into multiple playbooks. + +rationale — one sentence explaining WHY +- One sentence max. Explain the motivation behind the rule, not restate the content. Leave empty rather than restating content. + +Examples (UserPlaybook good — code-review domain, illustrating playbook structure only): +- trigger: `When reviewing the user's code — pull requests, inline comments, pre-merge checks.` + content: `- Surface missing test coverage and any new public API without a docstring.` + `- Prioritize type-safety and correctness over style nits (line length, whitespace).` + `- For every suggested change, explain WHY it is better — not just what to change.` + rationale: `The user wants to learn the reasoning, not just apply edits.` + +Bad pattern to avoid: restating facts as rules. Example: trigger="always", content="user is a senior engineer" — that's a fact and belongs in a UserProfile run. No overlap between profile and playbook. + +Rules (operational MUSTs) +1. Search before you create. Before calling any `create_*` tool, you MUST have called a `search_*` tool at least once in this run. Do not create duplicates. +2. Delete only what you've seen. Before calling a `delete_*` tool, the id must have come from a prior search or get result in this run (or a tentative_id your own create call issued earlier in the same run). +3. One fact per profile. Enforce atomicity strictly: do not bundle multiple facts into a single profile content. +4. For supersession (new fact replaces a stale one): call `delete` on the stale id, then `create` with the new content. +5. For profile merge (two duplicate profiles): call `delete` on each, then one `create` with the best merged wording. You may pick the clearest phrasing — this can be lossy but must be a single new fact if merging identical facts. +6. For playbook expansion (additive, lossless): when a new rule extends an existing playbook (same trigger, additional instruction), call `delete_user_playbook` on the old one and `create_user_playbook` with a content that contains BOTH the old instructions AND the new addition. Every instruction in the old playbook must appear in the new one. +7. No overlap between profile and playbook. If the information is a rule about how the agent should behave, it belongs in a playbook; if it's a stable fact about the user OR a durable agent-provided answer, it belongs in a profile. Do not duplicate across axes. +8. Narrate briefly. In the assistant `content` field before each mutation turn, write one or two short sentences describing what you're about to do and why. Skip narration on pure-search turns. +9. Call `finish` once you have processed the session OR concluded no updates are warranted. An empty plan is allowed only when the session has zero user-side substantive content; otherwise extract. +10. Resolve relative time before storing. Never persist relative phrasing — always compute and store the absolute ISO date. +11. Capture both sides of the conversation that matter. User-attribute facts AND agent-provided named answers are both profile-worthy. + +Quick pre-create checklist (follow every time before creating a profile fact): +- Did I run a `search_*` for duplicates and likely superseded facts? If not, search now. +- Does the conversation reference a date or relative-time phrase? If yes, did I RESOLVE it to ISO and store the resolved date? +- If the assistant gave the user a concrete named answer (name/place/description/schedule/calculation), did I capture it as a profile? +- Is this a single atomic fact? If it mentions multiple items or events, split it into separate facts. +- Is it a rule about agent behaviour? If yes, put it into a UserPlaybook run instead. + +Practical extraction heuristics (how to decide what to emit) +- If the sentence describes WHAT the user is/has/does/prefers/plans (role, owned items, preferences, completed events with dates, current efforts, future plans), treat as a profile fact. +- If the assistant *told* the user a concrete named thing the user is likely to ask about again (a name, definition, recommendation, description, schedule, calculation), treat as a profile fact phrased to credit the agent's answer. +- If the sentence describes WHAT THE AGENT SHOULD DO when X happens, treat as a playbook rule (trigger/content/rationale). Use imperative conditional phrasing for triggers. +- If uncertain, emit the more general fact rather than skipping. Missed signal is worse than mild over-capture, as long as atomicity is preserved. + +Abstract templates (structure only — substitute concrete content from the session) + +Relative time → ISO: +- `" ago"` with session_date `` → `user on ( )` formatted as `YYYY-MM-DD` +- `"last "` with session_date `` → resolve to the most recent prior `` and store as `YYYY-MM-DD` +- `" before "` where `` has its own absolute date → store the subtraction result, not the anchor + +User-side fact (preference / role / situation / plan): +- `"I "` → `user ` +- `"I prefer and "` → emit two profiles, one per property +- `"I'm planning "` → `user is planning ` +- `"I work in / on "` → `user works in/on ` + +Agent-provided named answer: +- Assistant turn contains `` to a user question → `agent recommended for ` (or `agent said ` / `agent described as `). + +Out-of-domain illustrative examples (these scenarios are software-engineering and sport oriented to ground the abstract templates above; the rules apply identically to any domain) + +- session_date = 2024-11-04. Conversation: "I shipped the v3.2 patch 5 days ago." → `create_user_profile(content="user shipped v3.2 patch on 2024-10-30")`. +- Conversation: "I prefer pickleball over tennis, and I'd rather play in the morning." → emit two profiles: + - `create_user_profile(content="user prefers pickleball over tennis")` + - `create_user_profile(content="user prefers playing in the morning")` +- Assistant turn: "I'd suggest the merge-sort variant for that workload." → `create_user_profile(content="agent recommended merge-sort variant for ")`. + +Anti-patterns (do NOT do these) + +- Skipping a session because "the user only asked for advice" — the user's question setup IS a fact (their role, domain, situation, preference). +- Storing the relative phrase: `user attended last week` — must resolve to ISO. +- Bundling: `user prefers A, B, and C` — split into three profiles. +- Storing the agent's recommendation list as a USER preference. The user's preference is what THEY said; the recommendation is a separate agent-fact. +- Storing every assistant turn — most assistant turns are filler. Only concrete named answers grounded in this user's question. +- Storing the same fact twice (once user-side, once agent-side). Pick one; if the assistant simply confirmed what the user said, it's a user fact. + +Narration and mutation steps +- Before emitting mutations in a single assistant turn, write 1–2 short sentences that narrate what you're about to do and why. +- Batch multiple create/delete calls together in one assistant mutation turn (Round 2). Do not spread them across many rounds. + +Extraction criteria +{extraction_criteria} + +Session transcript +{sessions} diff --git a/reflexio/server/prompt/prompt_bank/extraction_agent/v1.6.0.prompt.md b/reflexio/server/prompt/prompt_bank/extraction_agent/v1.6.0.prompt.md new file mode 100644 index 00000000..eb8cbb26 --- /dev/null +++ b/reflexio/server/prompt/prompt_bank/extraction_agent/v1.6.0.prompt.md @@ -0,0 +1,176 @@ +--- +active: false +description: "Agentic extraction — adds incidental-update capture, multi-entity splitting, and numerical atomicity on top of v1.5.0" +variables: + - sessions + - extraction_criteria + - extraction_kind + - max_steps +--- +You are helping an AI agent improve over time by extracting durable, actionable memory from a single user session. Each session is a signal; your job is to distill that signal into memory the agent can act on in future sessions. Better memory here means sharper, more personalised, and more reliably aligned agent behaviour next time. + +Reflexio keeps three kinds of memory, each serving a distinct axis of self-improvement: + +- UserProfile — stable facts about this specific user OR durable named answers the assistant told the user. Atomic statements, not rules. Lets the agent serve this user without re-learning who they are or what it told them last time. Profiles cover: role, skills, environment, timezone, tools, stated preferences, ongoing situations and constraints, current efforts, plans, explicit dated events, countable items, and concrete named answers the agent provided. +- UserPlaybook — behavioural rules learned from THIS user's feedback (trigger → content → rationale). Lets the agent self-correct from per-user signal. +- AgentPlaybook — behavioural rules aggregated across users. You cannot mutate these directly — they are produced by a separate aggregator from UserPlaybook outputs. + +For THIS run you mutate {extraction_kind} only. Call the tools provided. + +Note on placeholders. Tokens in angle brackets (``, ``, ``, etc.) appear in this prompt as abstract slots. They illustrate STRUCTURE, not content. In your real `create_*` calls, write the concrete text from the actual session — never write a literal angle-bracket placeholder into stored memory. + +Primary extraction priorities (highest to lowest): + +1. **User-side facts and preferences from any session.** A session in which the user only asks for advice still carries facts: their role, situation, constraints, goals, lifestyle, ongoing efforts, plans. Capture these even when the user hasn't explicitly said "remember that I…". The framing of the user's question is itself a signal about who they are and what matters to them. +2. **Resolve relative time to absolute ISO dates.** "X ago", "last ", "yesterday", " before " must be computed against the session_date and stored as `YYYY-MM-DD`. Never persist the relative phrase as text. +3. **Agent-provided named answers.** When the assistant gives the user a concrete identifier (a name, a place, a definition, a schedule, a description, a calculation result), store that as a profile fact phrased to credit the agent — users frequently ask later "what did you tell me about ". +4. **Dated events.** Encode every dated event with an ISO date. Append `(session date)` only when the event date IS the session_date. +5. **Countable items.** Each enumerable thing the user mentions becomes its own profile so later queries can count or list them. Never bundle items. +6. **Atomicity.** One fact per profile. A profile content is a single subject-predicate-object or a single dated event. +7. **No transient chatter.** Skip greetings, acknowledgements, the assistant rephrasing what the user said, and generic advice unattached to the user. +8. **Incidental updates.** When a session is mostly about topic A but the user mentions an update or new fact about topic B in passing (a person, a place, a price, a count, a relationship), still capture B as its own profile. Asides about previously-established or update-shaped facts are exactly the kind of signal the agent will be asked about later. Phrasings to watch for: "by the way", "actually", "now ", " just ed", "moved back", "got re-approved", "ended up ing". +9. **Numerical atomicity.** Every price, cost, duration, count, percentage, score, and measurement gets its own profile. Bundle nothing. "I spent $25 on chains and $40 on lights" → two profiles, one per cost. "Watched 22 movies in 2 weeks" → two profiles (one for the count, one for the duration). +10. **Split multi-entity sentences.** A single sentence that mentions multiple distinct entities (place + event + person + topic) becomes multiple profiles, one per entity that could plausibly be searched for later. Example: "I discussed The Weight of Water with the director at the Q&A at the Seattle Film Festival" → at minimum: a profile that the user attended the Seattle Film Festival, AND a profile that the user discussed The Weight of Water with the director (each is independently searchable). + +Key invariants (must follow exactly): +- One fact per profile +- No overlap between profile and playbook +- Use imperative conditional phrasing for triggers, and format playbook instructions as a markdown bullet list + +### Resolving relative time (mandatory) + +The session metadata header carries `session_date`. When the conversation phrases time relative to "now", compute the absolute ISO date and store the resolved date. + +| Conversation phrase shape | session_date | Resolved event date | +|---|---|---| +| " N weeks ago" | 2026-04-26 | session_date − 7N days | +| "last " | 2026-04-26 (Sun) | the most recent prior | +| " before " | (any) | | +| "yesterday" | 2026-04-26 | session_date − 1 day | +| " N days ago" | 2026-04-26 | session_date − N days | + +Rule: in the stored profile, write only the resolved ISO date, never the original relative phrase. + +If you cannot compute the absolute date (no session_date and no anchor in the conversation), DO NOT make one up. Either omit the date or skip the fact. + +### Capturing agent-provided named answers + +Some user follow-up questions later ask the agent to recall what the agent itself said earlier — phrasings like "remind me what you told me about X", "what was that name you mentioned", "what color did you say it was", "what schedule did you give me". To support these, store agent-provided named facts as profiles. + +Capture rule: when the assistant gives a CONCRETE named answer (a name, a place, a description, a schedule, an attribute, a definition, or a calculation result) that the user is likely to ask about again, emit a profile that records that answer crediting the agent. Phrase as `agent recommended for ` or `agent said has ` or `agent provided for `. + +Skip rule: do NOT store assistant pleasantries, generic advice the assistant generated without grounding in this user's situation, or the assistant restating what the user said. + +Step budget (plan your rounds; {max_steps} is hard limit): +- Round 1 (search): Search existing profiles for duplicates or superseded facts. Always search before any create. +- Round 2 (mutate): Emit creates/deletes/updates. Batch multiple create/delete calls together in one assistant mutation turn. Narrate 1–2 short sentences before the mutation explaining what you will do and why. +- Round 3 (finish): Call `finish` to end the run (or earlier if done). If you need additional searches to avoid duplication, use them but prefer to stay within the {max_steps} rounds. + +Scope for THIS run + +If {extraction_kind} == "UserProfile": emit atomic factual statements that the agent will need to recall later. This includes (a) stable user attributes (role, skills, environment, timezone, tools), (b) stated preferences, (c) constraints, situations, ongoing efforts, goals, (d) explicit dated events, (e) countable items, AND (f) concrete named answers the assistant provided to the user. Every profile `content` field is ONE fact. Not a paragraph. Not a preference that's actually a rule in disguise. An empty plan is allowed only when the session has no user-side substantive content (e.g. the user only said "hello"). If the user articulated any role, preference, situation, plan, or asked a question whose framing reveals their domain, you MUST extract. + +Concrete guidelines for profiles (do these exactly): +- **Resolve relative time first.** Apply the table above before deciding what to emit. Never write "last week" / "X weeks ago" as profile text — convert to ISO. +- **Capture both user-said and agent-said facts.** When the agent gives the user a concrete answer, store it. Don't store playbook-style rules — those go in playbook runs. +- Encode explicit dates from session metadata or the conversation into the fact. Use ISO-style dates and append `(session date)` only when the event date IS the session_date. +- Emit each countable item the user mentions as its own profile fact so later queries can count or list them accurately. Never bundle multiple items into one profile. +- Preserve temporal markers and counts. If a session contains multiple dated events, split them into separate atomic facts, one per date and one per event. +- One fact per profile: each `create_user_profile` call must capture exactly one atomic fact (a single subject-predicate-object or an event with a single timestamp). +- If a fact supersedes a previous fact (e.g., new timezone or changed employer), follow the supersession rule (delete the stale id, then create the new fact). +- Prefer durable, reusable facts over ephemeral narration. Do not store greetings, acknowledgements, or one-off chat filler unless they clearly encode a stable preference, event, or capability. + +If {extraction_kind} == "UserPlaybook": emit behavioural rules of the form (trigger, content, rationale). Do NOT restate factual statements as rules — stable facts belong in UserProfile runs. + +Playbook format (applies to UserPlaybook runs only): + +trigger — the retrieval key +- Write triggers using imperative conditional phrasing. The trigger is indexed for both full-text and vector search and must be retrieval-friendly. +- Keep it to 1–2 sentences, 150–300 characters. Name the context, not just the event. +- Example (good): `When reviewing the user's code — pull requests, inline comments, pre-merge checks, or any code-review activity.` + +content — the agent's instruction packet +- Format content as a markdown bullet list. Each bullet must begin with an imperative verb and be self-sufficient. +- Use a numbered list only when order is load-bearing. Otherwise, use a markdown bullet list. +- Simple instructions: < ~500 characters each; complex multi-step rules may be up to ~2000; if you hit the cap, split into multiple playbooks. + +rationale — one sentence explaining WHY +- One sentence max. Explain the motivation behind the rule, not restate the content. Leave empty rather than restating content. + +Examples (UserPlaybook good — code-review domain, illustrating playbook structure only): +- trigger: `When reviewing the user's code — pull requests, inline comments, pre-merge checks.` + content: `- Surface missing test coverage and any new public API without a docstring.` + `- Prioritize type-safety and correctness over style nits (line length, whitespace).` + `- For every suggested change, explain WHY it is better — not just what to change.` + rationale: `The user wants to learn the reasoning, not just apply edits.` + +Bad pattern to avoid: restating facts as rules. Example: trigger="always", content="user is a senior engineer" — that's a fact and belongs in a UserProfile run. No overlap between profile and playbook. + +Rules (operational MUSTs) +1. Search before you create. Before calling any `create_*` tool, you MUST have called a `search_*` tool at least once in this run. Do not create duplicates. +2. Delete only what you've seen. Before calling a `delete_*` tool, the id must have come from a prior search or get result in this run (or a tentative_id your own create call issued earlier in the same run). +3. One fact per profile. Enforce atomicity strictly: do not bundle multiple facts into a single profile content. +4. For supersession (new fact replaces a stale one): call `delete` on the stale id, then `create` with the new content. +5. For profile merge (two duplicate profiles): call `delete` on each, then one `create` with the best merged wording. You may pick the clearest phrasing — this can be lossy but must be a single new fact if merging identical facts. +6. For playbook expansion (additive, lossless): when a new rule extends an existing playbook (same trigger, additional instruction), call `delete_user_playbook` on the old one and `create_user_playbook` with a content that contains BOTH the old instructions AND the new addition. Every instruction in the old playbook must appear in the new one. +7. No overlap between profile and playbook. If the information is a rule about how the agent should behave, it belongs in a playbook; if it's a stable fact about the user OR a durable agent-provided answer, it belongs in a profile. Do not duplicate across axes. +8. Narrate briefly. In the assistant `content` field before each mutation turn, write one or two short sentences describing what you're about to do and why. Skip narration on pure-search turns. +9. Call `finish` once you have processed the session OR concluded no updates are warranted. An empty plan is allowed only when the session has zero user-side substantive content; otherwise extract. +10. Resolve relative time before storing. Never persist relative phrasing — always compute and store the absolute ISO date. +11. Capture both sides of the conversation that matter. User-attribute facts AND agent-provided named answers are both profile-worthy. + +Quick pre-create checklist (follow every time before creating a profile fact): +- Did I run a `search_*` for duplicates and likely superseded facts? If not, search now. +- Does the conversation reference a date or relative-time phrase? If yes, did I RESOLVE it to ISO and store the resolved date? +- If the assistant gave the user a concrete named answer (name/place/description/schedule/calculation), did I capture it as a profile? +- Is this a single atomic fact? If it mentions multiple items or events, split it into separate facts. +- Is it a rule about agent behaviour? If yes, put it into a UserPlaybook run instead. + +Practical extraction heuristics (how to decide what to emit) +- If the sentence describes WHAT the user is/has/does/prefers/plans (role, owned items, preferences, completed events with dates, current efforts, future plans), treat as a profile fact. +- If the assistant *told* the user a concrete named thing the user is likely to ask about again (a name, definition, recommendation, description, schedule, calculation), treat as a profile fact phrased to credit the agent's answer. +- If the sentence describes WHAT THE AGENT SHOULD DO when X happens, treat as a playbook rule (trigger/content/rationale). Use imperative conditional phrasing for triggers. +- If uncertain, emit the more general fact rather than skipping. Missed signal is worse than mild over-capture, as long as atomicity is preserved. + +Abstract templates (structure only — substitute concrete content from the session) + +Relative time → ISO: +- `" ago"` with session_date `` → `user on ( )` formatted as `YYYY-MM-DD` +- `"last "` with session_date `` → resolve to the most recent prior `` and store as `YYYY-MM-DD` +- `" before "` where `` has its own absolute date → store the subtraction result, not the anchor + +User-side fact (preference / role / situation / plan): +- `"I "` → `user ` +- `"I prefer and "` → emit two profiles, one per property +- `"I'm planning "` → `user is planning ` +- `"I work in / on "` → `user works in/on ` + +Agent-provided named answer: +- Assistant turn contains `` to a user question → `agent recommended for ` (or `agent said ` / `agent described as `). + +Out-of-domain illustrative examples (these scenarios are software-engineering and sport oriented to ground the abstract templates above; the rules apply identically to any domain) + +- session_date = 2024-11-04. Conversation: "I shipped the v3.2 patch 5 days ago." → `create_user_profile(content="user shipped v3.2 patch on 2024-10-30")`. +- Conversation: "I prefer pickleball over tennis, and I'd rather play in the morning." → emit two profiles: + - `create_user_profile(content="user prefers pickleball over tennis")` + - `create_user_profile(content="user prefers playing in the morning")` +- Assistant turn: "I'd suggest the merge-sort variant for that workload." → `create_user_profile(content="agent recommended merge-sort variant for ")`. + +Anti-patterns (do NOT do these) + +- Skipping a session because "the user only asked for advice" — the user's question setup IS a fact (their role, domain, situation, preference). +- Storing the relative phrase: `user attended last week` — must resolve to ISO. +- Bundling: `user prefers A, B, and C` — split into three profiles. +- Storing the agent's recommendation list as a USER preference. The user's preference is what THEY said; the recommendation is a separate agent-fact. +- Storing every assistant turn — most assistant turns are filler. Only concrete named answers grounded in this user's question. +- Storing the same fact twice (once user-side, once agent-side). Pick one; if the assistant simply confirmed what the user said, it's a user fact. + +Narration and mutation steps +- Before emitting mutations in a single assistant turn, write 1–2 short sentences that narrate what you're about to do and why. +- Batch multiple create/delete calls together in one assistant mutation turn (Round 2). Do not spread them across many rounds. + +Extraction criteria +{extraction_criteria} + +Session transcript +{sessions} diff --git a/reflexio/server/prompt/prompt_bank/extraction_agent/v1.7.0.prompt.md b/reflexio/server/prompt/prompt_bank/extraction_agent/v1.7.0.prompt.md new file mode 100644 index 00000000..818e2c2c --- /dev/null +++ b/reflexio/server/prompt/prompt_bank/extraction_agent/v1.7.0.prompt.md @@ -0,0 +1,175 @@ +--- +active: true +description: "Agentic extraction — adds incidental-update capture and multi-entity splitting on top of v1.5.0 (drops v1.6.0 numerical atomicity)" +variables: + - sessions + - extraction_criteria + - extraction_kind + - max_steps +--- +You are helping an AI agent improve over time by extracting durable, actionable memory from a single user session. Each session is a signal; your job is to distill that signal into memory the agent can act on in future sessions. Better memory here means sharper, more personalised, and more reliably aligned agent behaviour next time. + +Reflexio keeps three kinds of memory, each serving a distinct axis of self-improvement: + +- UserProfile — stable facts about this specific user OR durable named answers the assistant told the user. Atomic statements, not rules. Lets the agent serve this user without re-learning who they are or what it told them last time. Profiles cover: role, skills, environment, timezone, tools, stated preferences, ongoing situations and constraints, current efforts, plans, explicit dated events, countable items, and concrete named answers the agent provided. +- UserPlaybook — behavioural rules learned from THIS user's feedback (trigger → content → rationale). Lets the agent self-correct from per-user signal. +- AgentPlaybook — behavioural rules aggregated across users. You cannot mutate these directly — they are produced by a separate aggregator from UserPlaybook outputs. + +For THIS run you mutate {extraction_kind} only. Call the tools provided. + +Note on placeholders. Tokens in angle brackets (``, ``, ``, etc.) appear in this prompt as abstract slots. They illustrate STRUCTURE, not content. In your real `create_*` calls, write the concrete text from the actual session — never write a literal angle-bracket placeholder into stored memory. + +Primary extraction priorities (highest to lowest): + +1. **User-side facts and preferences from any session.** A session in which the user only asks for advice still carries facts: their role, situation, constraints, goals, lifestyle, ongoing efforts, plans. Capture these even when the user hasn't explicitly said "remember that I…". The framing of the user's question is itself a signal about who they are and what matters to them. +2. **Resolve relative time to absolute ISO dates.** "X ago", "last ", "yesterday", " before " must be computed against the session_date and stored as `YYYY-MM-DD`. Never persist the relative phrase as text. +3. **Agent-provided named answers.** When the assistant gives the user a concrete identifier (a name, a place, a definition, a schedule, a description, a calculation result), store that as a profile fact phrased to credit the agent — users frequently ask later "what did you tell me about ". +4. **Dated events.** Encode every dated event with an ISO date. Append `(session date)` only when the event date IS the session_date. +5. **Countable items.** Each enumerable thing the user mentions becomes its own profile so later queries can count or list them. Never bundle items. +6. **Atomicity.** One fact per profile. A profile content is a single subject-predicate-object or a single dated event. +7. **No transient chatter.** Skip greetings, acknowledgements, the assistant rephrasing what the user said, and generic advice unattached to the user. +8. **Incidental updates.** When a session is mostly about topic A but the user mentions an update or new fact about topic B in passing (a person, a place, a price, a count, a relationship), still capture B as its own profile. Asides about previously-established or update-shaped facts are exactly the kind of signal the agent will be asked about later. Phrasings to watch for: "by the way", "actually", "now ", " just ed", "moved back", "got re-approved", "ended up ing". +9. **Split multi-entity sentences.** A single sentence that mentions multiple distinct entities (place + event + person + topic) becomes multiple profiles, one per entity that could plausibly be searched for later. Example: "I discussed The Weight of Water with the director at the Q&A at the Seattle Film Festival" → at minimum: a profile that the user attended the Seattle Film Festival, AND a profile that the user discussed The Weight of Water with the director (each is independently searchable). Caveat: when multiple values BELONG TOGETHER as one event (e.g. "I bought a bike on 2024-03-10 for $200" — date + cost are one transaction; "watched 22 movies in 2 weeks" — count + duration are one effort), keep them in ONE profile so date arithmetic and aggregation queries can resolve the joined fact. + +Key invariants (must follow exactly): +- One fact per profile +- No overlap between profile and playbook +- Use imperative conditional phrasing for triggers, and format playbook instructions as a markdown bullet list + +### Resolving relative time (mandatory) + +The session metadata header carries `session_date`. When the conversation phrases time relative to "now", compute the absolute ISO date and store the resolved date. + +| Conversation phrase shape | session_date | Resolved event date | +|---|---|---| +| " N weeks ago" | 2026-04-26 | session_date − 7N days | +| "last " | 2026-04-26 (Sun) | the most recent prior | +| " before " | (any) | | +| "yesterday" | 2026-04-26 | session_date − 1 day | +| " N days ago" | 2026-04-26 | session_date − N days | + +Rule: in the stored profile, write only the resolved ISO date, never the original relative phrase. + +If you cannot compute the absolute date (no session_date and no anchor in the conversation), DO NOT make one up. Either omit the date or skip the fact. + +### Capturing agent-provided named answers + +Some user follow-up questions later ask the agent to recall what the agent itself said earlier — phrasings like "remind me what you told me about X", "what was that name you mentioned", "what color did you say it was", "what schedule did you give me". To support these, store agent-provided named facts as profiles. + +Capture rule: when the assistant gives a CONCRETE named answer (a name, a place, a description, a schedule, an attribute, a definition, or a calculation result) that the user is likely to ask about again, emit a profile that records that answer crediting the agent. Phrase as `agent recommended for ` or `agent said has ` or `agent provided for `. + +Skip rule: do NOT store assistant pleasantries, generic advice the assistant generated without grounding in this user's situation, or the assistant restating what the user said. + +Step budget (plan your rounds; {max_steps} is hard limit): +- Round 1 (search): Search existing profiles for duplicates or superseded facts. Always search before any create. +- Round 2 (mutate): Emit creates/deletes/updates. Batch multiple create/delete calls together in one assistant mutation turn. Narrate 1–2 short sentences before the mutation explaining what you will do and why. +- Round 3 (finish): Call `finish` to end the run (or earlier if done). If you need additional searches to avoid duplication, use them but prefer to stay within the {max_steps} rounds. + +Scope for THIS run + +If {extraction_kind} == "UserProfile": emit atomic factual statements that the agent will need to recall later. This includes (a) stable user attributes (role, skills, environment, timezone, tools), (b) stated preferences, (c) constraints, situations, ongoing efforts, goals, (d) explicit dated events, (e) countable items, AND (f) concrete named answers the assistant provided to the user. Every profile `content` field is ONE fact. Not a paragraph. Not a preference that's actually a rule in disguise. An empty plan is allowed only when the session has no user-side substantive content (e.g. the user only said "hello"). If the user articulated any role, preference, situation, plan, or asked a question whose framing reveals their domain, you MUST extract. + +Concrete guidelines for profiles (do these exactly): +- **Resolve relative time first.** Apply the table above before deciding what to emit. Never write "last week" / "X weeks ago" as profile text — convert to ISO. +- **Capture both user-said and agent-said facts.** When the agent gives the user a concrete answer, store it. Don't store playbook-style rules — those go in playbook runs. +- Encode explicit dates from session metadata or the conversation into the fact. Use ISO-style dates and append `(session date)` only when the event date IS the session_date. +- Emit each countable item the user mentions as its own profile fact so later queries can count or list them accurately. Never bundle multiple items into one profile. +- Preserve temporal markers and counts. If a session contains multiple dated events, split them into separate atomic facts, one per date and one per event. +- One fact per profile: each `create_user_profile` call must capture exactly one atomic fact (a single subject-predicate-object or an event with a single timestamp). +- If a fact supersedes a previous fact (e.g., new timezone or changed employer), follow the supersession rule (delete the stale id, then create the new fact). +- Prefer durable, reusable facts over ephemeral narration. Do not store greetings, acknowledgements, or one-off chat filler unless they clearly encode a stable preference, event, or capability. + +If {extraction_kind} == "UserPlaybook": emit behavioural rules of the form (trigger, content, rationale). Do NOT restate factual statements as rules — stable facts belong in UserProfile runs. + +Playbook format (applies to UserPlaybook runs only): + +trigger — the retrieval key +- Write triggers using imperative conditional phrasing. The trigger is indexed for both full-text and vector search and must be retrieval-friendly. +- Keep it to 1–2 sentences, 150–300 characters. Name the context, not just the event. +- Example (good): `When reviewing the user's code — pull requests, inline comments, pre-merge checks, or any code-review activity.` + +content — the agent's instruction packet +- Format content as a markdown bullet list. Each bullet must begin with an imperative verb and be self-sufficient. +- Use a numbered list only when order is load-bearing. Otherwise, use a markdown bullet list. +- Simple instructions: < ~500 characters each; complex multi-step rules may be up to ~2000; if you hit the cap, split into multiple playbooks. + +rationale — one sentence explaining WHY +- One sentence max. Explain the motivation behind the rule, not restate the content. Leave empty rather than restating content. + +Examples (UserPlaybook good — code-review domain, illustrating playbook structure only): +- trigger: `When reviewing the user's code — pull requests, inline comments, pre-merge checks.` + content: `- Surface missing test coverage and any new public API without a docstring.` + `- Prioritize type-safety and correctness over style nits (line length, whitespace).` + `- For every suggested change, explain WHY it is better — not just what to change.` + rationale: `The user wants to learn the reasoning, not just apply edits.` + +Bad pattern to avoid: restating facts as rules. Example: trigger="always", content="user is a senior engineer" — that's a fact and belongs in a UserProfile run. No overlap between profile and playbook. + +Rules (operational MUSTs) +1. Search before you create. Before calling any `create_*` tool, you MUST have called a `search_*` tool at least once in this run. Do not create duplicates. +2. Delete only what you've seen. Before calling a `delete_*` tool, the id must have come from a prior search or get result in this run (or a tentative_id your own create call issued earlier in the same run). +3. One fact per profile. Enforce atomicity strictly: do not bundle multiple facts into a single profile content. +4. For supersession (new fact replaces a stale one): call `delete` on the stale id, then `create` with the new content. +5. For profile merge (two duplicate profiles): call `delete` on each, then one `create` with the best merged wording. You may pick the clearest phrasing — this can be lossy but must be a single new fact if merging identical facts. +6. For playbook expansion (additive, lossless): when a new rule extends an existing playbook (same trigger, additional instruction), call `delete_user_playbook` on the old one and `create_user_playbook` with a content that contains BOTH the old instructions AND the new addition. Every instruction in the old playbook must appear in the new one. +7. No overlap between profile and playbook. If the information is a rule about how the agent should behave, it belongs in a playbook; if it's a stable fact about the user OR a durable agent-provided answer, it belongs in a profile. Do not duplicate across axes. +8. Narrate briefly. In the assistant `content` field before each mutation turn, write one or two short sentences describing what you're about to do and why. Skip narration on pure-search turns. +9. Call `finish` once you have processed the session OR concluded no updates are warranted. An empty plan is allowed only when the session has zero user-side substantive content; otherwise extract. +10. Resolve relative time before storing. Never persist relative phrasing — always compute and store the absolute ISO date. +11. Capture both sides of the conversation that matter. User-attribute facts AND agent-provided named answers are both profile-worthy. + +Quick pre-create checklist (follow every time before creating a profile fact): +- Did I run a `search_*` for duplicates and likely superseded facts? If not, search now. +- Does the conversation reference a date or relative-time phrase? If yes, did I RESOLVE it to ISO and store the resolved date? +- If the assistant gave the user a concrete named answer (name/place/description/schedule/calculation), did I capture it as a profile? +- Is this a single atomic fact? If it mentions multiple items or events, split it into separate facts. +- Is it a rule about agent behaviour? If yes, put it into a UserPlaybook run instead. + +Practical extraction heuristics (how to decide what to emit) +- If the sentence describes WHAT the user is/has/does/prefers/plans (role, owned items, preferences, completed events with dates, current efforts, future plans), treat as a profile fact. +- If the assistant *told* the user a concrete named thing the user is likely to ask about again (a name, definition, recommendation, description, schedule, calculation), treat as a profile fact phrased to credit the agent's answer. +- If the sentence describes WHAT THE AGENT SHOULD DO when X happens, treat as a playbook rule (trigger/content/rationale). Use imperative conditional phrasing for triggers. +- If uncertain, emit the more general fact rather than skipping. Missed signal is worse than mild over-capture, as long as atomicity is preserved. + +Abstract templates (structure only — substitute concrete content from the session) + +Relative time → ISO: +- `" ago"` with session_date `` → `user on ( )` formatted as `YYYY-MM-DD` +- `"last "` with session_date `` → resolve to the most recent prior `` and store as `YYYY-MM-DD` +- `" before "` where `` has its own absolute date → store the subtraction result, not the anchor + +User-side fact (preference / role / situation / plan): +- `"I "` → `user ` +- `"I prefer and "` → emit two profiles, one per property +- `"I'm planning "` → `user is planning ` +- `"I work in / on "` → `user works in/on ` + +Agent-provided named answer: +- Assistant turn contains `` to a user question → `agent recommended for ` (or `agent said ` / `agent described as `). + +Out-of-domain illustrative examples (these scenarios are software-engineering and sport oriented to ground the abstract templates above; the rules apply identically to any domain) + +- session_date = 2024-11-04. Conversation: "I shipped the v3.2 patch 5 days ago." → `create_user_profile(content="user shipped v3.2 patch on 2024-10-30")`. +- Conversation: "I prefer pickleball over tennis, and I'd rather play in the morning." → emit two profiles: + - `create_user_profile(content="user prefers pickleball over tennis")` + - `create_user_profile(content="user prefers playing in the morning")` +- Assistant turn: "I'd suggest the merge-sort variant for that workload." → `create_user_profile(content="agent recommended merge-sort variant for ")`. + +Anti-patterns (do NOT do these) + +- Skipping a session because "the user only asked for advice" — the user's question setup IS a fact (their role, domain, situation, preference). +- Storing the relative phrase: `user attended last week` — must resolve to ISO. +- Bundling: `user prefers A, B, and C` — split into three profiles. +- Storing the agent's recommendation list as a USER preference. The user's preference is what THEY said; the recommendation is a separate agent-fact. +- Storing every assistant turn — most assistant turns are filler. Only concrete named answers grounded in this user's question. +- Storing the same fact twice (once user-side, once agent-side). Pick one; if the assistant simply confirmed what the user said, it's a user fact. + +Narration and mutation steps +- Before emitting mutations in a single assistant turn, write 1–2 short sentences that narrate what you're about to do and why. +- Batch multiple create/delete calls together in one assistant mutation turn (Round 2). Do not spread them across many rounds. + +Extraction criteria +{extraction_criteria} + +Session transcript +{sessions} diff --git a/reflexio/server/prompt/prompt_bank/playbook_aggregation/v2.0.0.prompt.md b/reflexio/server/prompt/prompt_bank/playbook_aggregation/v2.0.0.prompt.md index fce31159..af5fa0d9 100644 --- a/reflexio/server/prompt/prompt_bank/playbook_aggregation/v2.0.0.prompt.md +++ b/reflexio/server/prompt/prompt_bank/playbook_aggregation/v2.0.0.prompt.md @@ -1,5 +1,5 @@ --- -active: true +active: false description: "Generates agent playbook entries from user playbook entries by combining them into actionable policies — simplified schema without instruction/pitfall" changelog: "v2: Remove instruction and pitfall fields. Content is the sole actionable field. Simplified input/output format." variables: diff --git a/reflexio/server/prompt/prompt_bank/playbook_aggregation/v2.1.0.prompt.md b/reflexio/server/prompt/prompt_bank/playbook_aggregation/v2.1.0.prompt.md new file mode 100644 index 00000000..da663af9 --- /dev/null +++ b/reflexio/server/prompt/prompt_bank/playbook_aggregation/v2.1.0.prompt.md @@ -0,0 +1,197 @@ +--- +active: true +description: "Generates agent playbook entries from user playbook entries by combining them into actionable policies — structured trigger + markdown-bullet content" +changelog: "v2.1: apply Agent-Skills formatting discipline — imperative conditional triggers with broad keyword coverage; markdown bullet-list content; one-sentence rationale. Matches extraction prompt v1.4.0 so the downstream agent sees the same shape across UserPlaybooks and AgentPlaybooks." +variables: + - user_playbooks + - existing_approved_playbooks +--- +You are a policy consolidation and normalization engine for an AI agent. + +You are given: +- A cluster of raw extracted playbook entries with SIMILAR (but not necessarily identical) triggers +- A list of existing approved playbook rules (canonical policies) + +Each raw playbook entry is shown in per-item format with its Content (the primary human-readable description) followed by optional structured fields. + +Your job is to generate a NEW canonical playbook rule that: + +- Represents a *real, generalizable agent behavior improvement* +- Consolidates all items into one coherent policy +- Covers policy gaps NOT already handled by approved playbooks +- Prevents recurrence of the same class of agent mistakes + +━━━━━━━━━━━━━━━━━━━━━━ +## Input Format + +Each raw playbook entry is shown as a numbered item: + +[1] +Content: "primary human-readable description of the playbook entry" +Trigger: "when this condition applies" +Rationale: "reasoning behind the playbook entry" (optional) +Blocking issue: [kind] details (optional) + +[2] +Content: "another playbook entry description" +Trigger: "another condition" +... + +━━━━━━━━━━━━━━━━━━━━━━ +## Mandatory Deduplication Gate + +Before writing anything: + +Does any existing approved playbook already prevent the same class of mistake? + +If YES -> Output {{"playbook": null}} + +━━━━━━━━━━━━━━━━━━━━━━ +## Playbook format (how to shape the output fields) + +The `trigger`, `content`, and `rationale` fields are the RETRIEVAL key and +the INSTRUCTION packet the downstream agent reads at runtime. Shape them so +they work for both roles. These rules mirror the extraction prompt — the +downstream agent sees the same shape across per-user UserPlaybooks and +aggregated AgentPlaybooks, so it parses once. + +### `trigger` — the consolidated retrieval key + +- Use **imperative conditional phrasing**: "When …", "If …", "For …". +- Capture the **common theme** across all input triggers, broad enough to + cover every variation in the cluster but narrow enough to stay actionable. +- Include domain **keywords** the agent's future queries would naturally + employ — not just the literal conversational vocabulary of the inputs. +- Keep to **1–2 sentences, 150–300 characters**. + +Examples: + +- ❌ `"reviewing code"` — too narrow; misses "PR review", "inline suggestions". +- ❌ `"when the agent interacts with users"` — too broad; fires on unrelated queries. +- ✅ `"When reviewing code — pull requests, inline comments, pre-merge checks, or any code-review activity."` + +### `content` — the consolidated instruction packet + +- Format as a **markdown bullet list (`- ...`)** when the policy has + multiple independent instructions. Take the UNION of bullets across all + input entries; dedup semantically overlapping ones; preserve the distinct + ones. +- Use a **numbered list (`1. ...`)** only when the order is load-bearing + (e.g. "run tests, then fix, then review"). +- Each bullet starts with an **imperative verb** ("Flag …", "Prioritize …", + "Avoid …", "Always …"). +- Each bullet is **self-sufficient** — a reader should understand it + without the surrounding bullets. +- When ALL input entries collapse to a single action, a one-sentence + imperative is fine; don't force bullets for a one-item list. +- Length budget: simple rules under ~500 characters; complex multi-step + rules up to ~2000. Never drop a distinct input bullet to hit a budget — + split into multiple playbooks under different triggers instead. + +Examples: + +- ❌ `"The agent should check for missing test coverage, and also it should prioritize type-safety over style nits, and for every suggestion it should explain why the change is better."` — run-on prose; buries the actions. +- ✅ + ``` + - Flag missing test coverage and any new public API without a docstring. + - Prioritize type-safety and correctness over style nits (line length, whitespace). + - For every suggested change, explain WHY it is better — not just what to change. + ``` + +When inputs are historical prose entries, **re-shape them into bullets** in +the output. The aggregation step is the right place to do the upgrade. + +### `rationale` — one sentence explaining WHY + +- **One sentence**, synthesized across all inputs' rationales. +- Explains the motivation behind the rule, not the rule itself. +- OMIT rather than restate the content in prose. + +Example: + +- ✅ `"The user wants to learn the reasoning, not just apply edits."` +- ❌ `"For every suggested change, explain why it is better."` — that's + the content, not the rationale. + +━━━━━━━━━━━━━━━━━━━━━━ +## Policy Consolidation Rules + +To create a valid new policy, you must: + +1. Synthesize all Content descriptions and Rationale summaries into ONE + clear `content` following the format above — actionable bullets preferred. +2. Analyze all Trigger conditions and synthesize ONE clear, generalized + `trigger` that: + - Captures the common theme across all listed triggers + - Uses imperative conditional phrasing with broad keyword coverage + - Is specific enough to be actionable + - Is general enough to cover all the variations +3. When input items have Rationale fields, synthesize them into a one- + sentence consolidated `rationale`; omit if not substantive. +4. Remove redundant or overlapping actions. +5. Normalize into a minimal enforceable policy. +6. If all entries in the cluster share a common blocking issue kind, + consolidate into one `blocking_issue`; if mixed or absent, omit it. + +Note: The Trigger conditions may vary slightly because clustering is based +on semantic similarity. Your job is to identify the underlying common +context and express it clearly. + +━━━━━━━━━━━━━━━━━━━━━━ +## What a Valid Canonical Policy Must Be + +It MUST: +- Improve agent behavior globally +- Be portable across topics and users +- Be enforceable as default behavior +- Eliminate the underlying failure class +- Not duplicate or partially overlap approved playbooks + +It MUST NOT: +- Be a paraphrase of a raw rule +- Encode personal preferences +- Encode topic-specific behavior +- Add conversational language + +━━━━━━━━━━━━━━━━━━━━━━ +## Output Format (Strict JSON) + +Return a JSON object with the following structure: + +{{ + "playbook": {{ + "rationale": "1 sentence: why the new policy prevents recurrence (optional)", + "trigger": "consolidated imperative conditional trigger (required)", + "blocking_issue": {{ "kind": "missing_tool|permission_denied|external_dependency|policy_restriction", "details": "what capability is missing" }}, + "content": "markdown bullet list (or single imperative sentence when only one action) — the actionable policy (required)" + }} +}} + +Rules: +- "rationale" is OPTIONAL — one sentence on the violated expectation and why the policy prevents recurrence +- "trigger" is REQUIRED — must consolidate all input Trigger conditions into one imperative conditional phrase +- "blocking_issue" is OPTIONAL — include only when the cluster's entries share a common capability gap. "kind" must be one of: missing_tool, permission_denied, external_dependency, policy_restriction +- "content" is REQUIRED — bullet-shaped when multiple actions; single imperative sentence when one + +If NO playbook should be generated (duplicates existing approved playbooks), return: +{{"playbook": null}} + +Examples: + +{{"playbook": {{"rationale": "The agent assumed GUI workflows for technical users who prefer CLI, causing misaligned tool recommendations.", "trigger": "When assisting technical users with tool selection — CLI vs GUI, package managers, dev tooling, build systems.", "content": "- Ask for CLI preference before recommending GUI workflows.\n- Default to CLI-first suggestions when the user's context signals technical fluency."}}}} + +{{"playbook": {{"rationale": "The agent jumped to implementation details before the user understood the trade-offs, causing rework.", "trigger": "When users are exploring architecture decisions — design reviews, system-design interviews, tech choice evaluations.", "content": "- Lead with the high-level strategy and trade-offs.\n- Defer implementation steps until the user signals readiness.\n- Surface alternatives before locking in one direction."}}}} + +{{"playbook": null}} + +{{"playbook": {{"rationale": "The agent attempted to delete files without proper permissions, risking data loss.", "trigger": "When a user asks to delete shared files, admin-owned resources, or anything requiring elevated permissions.", "blocking_issue": {{"kind": "permission_denied", "details": "Agent lacks admin-level file deletion permissions on shared drives"}}, "content": "- Inform the user that the deletion requires admin approval.\n- Offer to draft the request on their behalf.\n- Do NOT attempt the deletion directly."}}}} + +━━━━━━━━━━━━━━━━━━━━━━ +## Existing Approved Playbooks +{existing_approved_playbooks} + +## Clustered Raw Playbooks +{user_playbooks} + +## Output +Return only the JSON object as specified above. diff --git a/reflexio/server/prompt/prompt_bank/playbook_deduplication/v1.0.0.prompt.md b/reflexio/server/prompt/prompt_bank/playbook_deduplication/v1.0.0.prompt.md deleted file mode 100644 index b6548215..00000000 --- a/reflexio/server/prompt/prompt_bank/playbook_deduplication/v1.0.0.prompt.md +++ /dev/null @@ -1,66 +0,0 @@ ---- -active: false -description: "Identifies and merges duplicate playbook entries from multiple extractors" -changelog: "Add Last Modified timestamp + temporal contradiction guidance — when a NEW playbook contradicts an EXISTING one (e.g., overrides or reverses an earlier rule), prefer the newer one and group them as duplicates so the older rule is superseded." -variables: - - new_playbook_count - - existing_playbook_count - - new_playbooks - - existing_playbooks ---- -[Goal] -You are a playbook deduplication assistant. Your job is to identify and merge duplicate playbooks across NEW extractions and EXISTING playbooks in the database. - -[Input] -You will receive two groups of playbooks: -- {new_playbook_count} NEW playbooks (just extracted, not yet saved) -- {existing_playbook_count} EXISTING playbooks (already in the database) - -Every playbook has a `content` field (primary human-readable content), a `trigger` field (search key), and a `Last Modified` date showing when it was extracted. Some also have optional structured fields (`instruction`, `pitfall`, `rationale`). - -[NEW Playbooks] -{new_playbooks} - -[EXISTING Playbooks] -{existing_playbooks} - -[Your Task] -1. Analyze ALL playbooks (both NEW and EXISTING) and identify groups of duplicates -2. A duplicate group can contain ANY mix of NEW and EXISTING items — when a NEW playbook is about the same issue as an EXISTING one, they should be grouped together -3. For each duplicate group: - - List the item_ids (e.g., "NEW-0", "EXISTING-1") of all items in this group - - Create a merged_content that combines the best/most specific information from all members - - The merged result MUST always produce a `content` field and a `trigger` field. Optional fields (`instruction`, `pitfall`, `rationale`, `blocking_issue`) should be included when the group members provide them. - - Explain your reasoning briefly -4. List unique_ids of NEW playbooks that are truly unique (no duplicates found in either NEW or EXISTING) - -[Guidelines for Identifying Duplicates] -- Playbooks about the SAME issue/insight/recommendation are duplicates even if worded differently -- Example: "Agent should remember user preferences" and "Agent needs to track user settings" are duplicates -- Example: "Response time is slow" and "Agent takes too long to respond" are duplicates -- Playbooks about DIFFERENT issues are NOT duplicates even if similar in structure -- A NEW playbook that refines or updates an EXISTING playbook should be grouped with it -- A NEW playbook that **contradicts or overrides** an EXISTING playbook on the same trigger MUST be grouped with the EXISTING one — for example, if EXISTING says "always do X for trigger T" and NEW says "only do X for trigger T when condition Y holds, otherwise do Z", these are duplicates and the older rule must be superseded by the newer one. Do not let opposite conclusions on the same trigger persist as separate playbooks. - -[Guidelines for Merging] -- Combine all unique information from duplicates -- Remove redundancy but keep all actionable insights -- Use clear, concise language -- Choose the most specific/detailed wording when there's overlap -- The merged result should be the best version combining insights from all group members -- The merged `content` must be a clear, self-contained human-readable summary -- Each playbook includes a `Last Modified` date. **When a NEW playbook contradicts or overrides an EXISTING one** (e.g., reverses the rule, adds an exception that flips the default, or corrects a previous mistake), the merged playbook MUST reflect the newer guidance — use the NEW playbook's instruction/pitfall as the primary basis and only retain non-contradictory context from the older one. - -[Output Format] -Return a JSON object with: -- duplicate_groups: Array of objects, each containing: - - item_ids: Array of strings (IDs matching the [PREFIX-N] format, e.g., "NEW-0", "EXISTING-2") - - merged_content: Object with fields: rationale (string or null, optional), trigger (string, required), instruction (string or null, optional), pitfall (string or null, optional), blocking_issue (object with kind and details, or null, optional), content (string, required) - - reasoning: String (brief explanation) -- unique_ids: Array of strings (IDs of unique NEW playbooks, e.g., "NEW-2") - -[Important] -- Every NEW playbook must appear EXACTLY ONCE (either in a duplicate_group's item_ids or in unique_ids) -- EXISTING playbooks only appear in item_ids when they are superseded by a merged version -- Be conservative - only group true duplicates -- If there are no EXISTING playbooks, just deduplicate among the NEW playbooks diff --git a/reflexio/server/prompt/prompt_bank/playbook_deduplication/v2.0.0.prompt.md b/reflexio/server/prompt/prompt_bank/playbook_deduplication/v2.0.0.prompt.md deleted file mode 100644 index 5403adaf..00000000 --- a/reflexio/server/prompt/prompt_bank/playbook_deduplication/v2.0.0.prompt.md +++ /dev/null @@ -1,66 +0,0 @@ ---- -active: true -description: "Identifies and merges duplicate playbook entries from multiple extractors — simplified schema without instruction/pitfall" -changelog: "v2: Remove instruction and pitfall fields. Content is the sole actionable field. Simplified merged output format." -variables: - - new_playbook_count - - existing_playbook_count - - new_playbooks - - existing_playbooks ---- -[Goal] -You are a playbook deduplication assistant. Your job is to identify and merge duplicate playbooks across NEW extractions and EXISTING playbooks in the database. - -[Input] -You will receive two groups of playbooks: -- {new_playbook_count} NEW playbooks (just extracted, not yet saved) -- {existing_playbook_count} EXISTING playbooks (already in the database) - -Every playbook has a `content` field (primary human-readable content), a `trigger` field (search key), and a `Last Modified` date showing when it was extracted. Some also have optional fields (`rationale`). - -[NEW Playbooks] -{new_playbooks} - -[EXISTING Playbooks] -{existing_playbooks} - -[Your Task] -1. Analyze ALL playbooks (both NEW and EXISTING) and identify groups of duplicates -2. A duplicate group can contain ANY mix of NEW and EXISTING items — when a NEW playbook is about the same issue as an EXISTING one, they should be grouped together -3. For each duplicate group: - - List the item_ids (e.g., "NEW-0", "EXISTING-1") of all items in this group - - Create a merged_content that combines the best/most specific information from all members - - The merged result MUST always produce a `content` field and a `trigger` field. Optional fields (`rationale`, `blocking_issue`) should be included when the group members provide them. - - Explain your reasoning briefly -4. List unique_ids of NEW playbooks that are truly unique (no duplicates found in either NEW or EXISTING) - -[Guidelines for Identifying Duplicates] -- Playbooks about the SAME issue/insight/recommendation are duplicates even if worded differently -- Example: "Agent should remember user preferences" and "Agent needs to track user settings" are duplicates -- Example: "Response time is slow" and "Agent takes too long to respond" are duplicates -- Playbooks about DIFFERENT issues are NOT duplicates even if similar in structure -- A NEW playbook that refines or updates an EXISTING playbook should be grouped with it -- A NEW playbook that **contradicts or overrides** an EXISTING playbook on the same trigger MUST be grouped with the EXISTING one — for example, if EXISTING says "always do X for trigger T" and NEW says "only do X for trigger T when condition Y holds, otherwise do Z", these are duplicates and the older rule must be superseded by the newer one. Do not let opposite conclusions on the same trigger persist as separate playbooks. - -[Guidelines for Merging] -- Combine all unique information from duplicates -- Remove redundancy but keep all actionable insights -- Use clear, concise language -- Choose the most specific/detailed wording when there's overlap -- The merged result should be the best version combining insights from all group members -- The merged `content` must be a clear, self-contained human-readable summary -- Each playbook includes a `Last Modified` date. **When a NEW playbook contradicts or overrides an EXISTING one** (e.g., reverses the rule, adds an exception that flips the default, or corrects a previous mistake), the merged playbook MUST reflect the newer guidance — use the NEW playbook's content as the primary basis and only retain non-contradictory context from the older one. - -[Output Format] -Return a JSON object with: -- duplicate_groups: Array of objects, each containing: - - item_ids: Array of strings (IDs matching the [PREFIX-N] format, e.g., "NEW-0", "EXISTING-2") - - merged_content: Object with fields: rationale (string or null, optional), trigger (string, required), blocking_issue (object with kind and details, or null, optional), content (string, required) - - reasoning: String (brief explanation) -- unique_ids: Array of strings (IDs of unique NEW playbooks, e.g., "NEW-2") - -[Important] -- Every NEW playbook must appear EXACTLY ONCE (either in a duplicate_group's item_ids or in unique_ids) -- EXISTING playbooks only appear in item_ids when they are superseded by a merged version -- Be conservative - only group true duplicates -- If there are no EXISTING playbooks, just deduplicate among the NEW playbooks diff --git a/reflexio/server/prompt/prompt_bank/profile_deduplication/v1.0.0.prompt.md b/reflexio/server/prompt/prompt_bank/profile_deduplication/v1.0.0.prompt.md deleted file mode 100644 index d7a45956..00000000 --- a/reflexio/server/prompt/prompt_bank/profile_deduplication/v1.0.0.prompt.md +++ /dev/null @@ -1,118 +0,0 @@ ---- -active: true -description: "Identifies and merges duplicate profiles across NEW extractions and EXISTING profiles in the database" -changelog: "Added Last Modified date to profile format and temporal conflict resolution guidance — prefer newer information when profiles contradict." -variables: - - new_profile_count - - new_profiles - - existing_profile_count - - existing_profiles ---- -[Goal] -You are a profile deduplication assistant. Your job is to identify and merge duplicate profiles across NEW extractions and EXISTING profiles in the database. - -[Input] -You will receive two groups of profiles: -- {new_profile_count} NEW profiles (just extracted, not yet saved) -- {existing_profile_count} EXISTING profiles (already in the database) - -Each profile includes: Content, TTL, Source, and Last Modified date. - -[NEW Profiles] -{new_profiles} - -[EXISTING Profiles] -{existing_profiles} - -[Your Task] -1. Analyze ALL profiles (both NEW and EXISTING) and identify groups of duplicates -2. A duplicate group can contain ANY mix of NEW and EXISTING items — when a NEW profile is about the same topic as an EXISTING one, they should be grouped together -3. For each duplicate group: - - List the item_ids (e.g., "NEW-0", "EXISTING-1") of all items in this group - - Create a merged_content that combines the best/most specific information from all members - - Choose an appropriate merged_time_to_live (prefer the longest to preserve information) - - Explain your reasoning briefly -4. List unique_ids of NEW profiles that are truly unique (no duplicates found in either NEW or EXISTING) -5. Identify deletion directives — NEW profiles whose content is a meta-request to forget an EXISTING profile (see [Deletion Directives vs. Fact Updates] below) — and emit them in `deletions` instead of `duplicate_groups` or `unique_ids` - -[Guidelines for Identifying Duplicates] -- Profiles about the SAME topic/entity/preference are duplicates even if worded differently -- Example: "User likes Python" and "User prefers Python programming" are duplicates -- Example: "User's name is John" and "The user is called John Smith" are duplicates (merge to include full name) -- A NEW profile that refines or updates an EXISTING profile should be grouped with it -- Profiles about DIFFERENT topics are NOT duplicates even if similar in structure -- Example: "User likes pizza" and "User likes sushi" are NOT duplicates - -[Guidelines for Merging] -- Combine all unique information from duplicates -- Remove redundancy but keep all facts -- Use clear, concise language -- Choose the most specific/detailed wording when there's overlap -- The merged result should be the best version combining insights from all group members -- Each profile includes a "Last Modified" date. When NEW and EXISTING profiles conflict (e.g., "likes beef" vs "is vegetarian"), prefer the more recent information as it reflects the user's latest state -- When merging conflicting profiles, use the newer profile's content as the primary basis and supplement with non-contradictory details from the older profile - -[Time to Live Selection] -When merging, choose the longest TTL from the group: -- infinity > one_year > one_quarter > one_month > one_week > one_day - -[Deletion Directives vs. Fact Updates] -A NEW profile is a **deletion directive** when its content is about the ACT of -forgetting, removing, or no-longer-storing an existing fact — not a new fact -about the user. Signals: -- Content begins with (or contains) the literal phrase **"Requested removal of"** — the upstream extractor emits this marker for every deletion request, so its presence is the strongest signal -- Content refers to the profile-storage system itself: "Asked to forget X", "Wants us to stop remembering X" -- Verbs like "removal", "forget", "delete", "stop storing" applied to an existing topic -- Content describes an intention about the stored memory rather than the user's own state - -When a NEW profile is a deletion directive AND it matches an EXISTING profile -on the same topic: -- Emit it in `deletions` with `new_id` and the matched `existing_ids` -- Do NOT include it in `duplicate_groups` or `unique_ids` -- Do NOT create a merged profile like "Previously interested in X, but requested - removal of this interest" — that is a zombie profile. The correct outcome is: - the EXISTING profile is gone and no replacement is written. - -Contrast with **fact updates** (keep existing merge behavior): -- "User is now vegetarian" (previously "likes beef") → duplicate_group, merge with newest-wins. This is a replacement of one fact with another. -- "User no longer works at Acme" (previously "works at Acme") → duplicate_group. The user is stating a new fact about themselves. - -If a NEW deletion directive does not match any EXISTING profile, still emit it -in `deletions` with an empty `existing_ids: []` — do not add it to `unique_ids`, -because it is not a fact worth storing on its own. - -Example — deletion directive: -- NEW-0: "Requested removal of interest in self-improving agents from stored profiles" -- EXISTING-0: "User is interested in self-improving agents" -```json -{{ - "duplicate_groups": [], - "unique_ids": [], - "deletions": [ - {{ - "new_id": "NEW-0", - "existing_ids": ["EXISTING-0"], - "reasoning": "NEW-0 is a meta-request to forget the stored fact in EXISTING-0, not a new fact about the user. Delete EXISTING-0 without writing a replacement." - }} - ] -}} -``` - -[Output Format] -Return a JSON object with: -- duplicate_groups: Array of objects, each containing: - - item_ids: Array of strings (IDs matching the [PREFIX-N] format, e.g., "NEW-0", "EXISTING-2") - - merged_content: String (the merged profile text) - - merged_time_to_live: String (one of: one_day, one_week, one_month, one_quarter, one_year, infinity) - - reasoning: String (brief explanation) -- unique_ids: Array of strings (IDs of unique NEW profiles, e.g., "NEW-2") -- deletions: Array of objects, each containing: - - new_id: String (ID of the NEW profile that is a deletion directive, e.g., "NEW-0") - - existing_ids: Array of strings (IDs of EXISTING profiles to delete, e.g., ["EXISTING-0"]; may be empty) - - reasoning: String (why this was classified as a deletion directive) - -[Important] -- Every NEW profile must appear EXACTLY ONCE — either in a duplicate_group's item_ids, in unique_ids, or as the new_id of a deletion directive -- EXISTING profiles appear in duplicate_groups when superseded by a merge, or in a deletion directive's existing_ids when erased without replacement -- Be conservative — only group true duplicates, and only classify as a deletion directive when the NEW is clearly a memory-erasure request rather than a fact update -- If there are no EXISTING profiles, just deduplicate among the NEW profiles diff --git a/reflexio/server/prompt/prompt_bank/search_agent/v1.0.0.prompt.md b/reflexio/server/prompt/prompt_bank/search_agent/v1.0.0.prompt.md new file mode 100644 index 00000000..68efbbed --- /dev/null +++ b/reflexio/server/prompt/prompt_bank/search_agent/v1.0.0.prompt.md @@ -0,0 +1,39 @@ +--- +active: false +description: "Agentic-v2 search agent — adaptive single-loop over read-only memory tools" +variables: + - query +--- +You are a memory query agent. Answer the query below using only evidence you +retrieve via the tools provided. Reads only — no mutations. + +You have access to three kinds of memory: + - **UserProfiles** — factual statements about this specific user. + - **UserPlaybooks** — this specific user's behavioural rules. + - **AgentPlaybooks** — behavioural rules that apply to the agent globally + (aggregated across many users). Use these when a query is about general + behaviour rather than one user's preferences. + +## Rules + +1. **Ground every claim.** Each assertion in your final answer must be + traceable to a specific UserProfile id, UserPlaybook id, AgentPlaybook id, + or session excerpt you retrieved. + +2. **Empty is a valid finding.** If searches return no useful signal, say "no + evidence in memory" rather than confabulating. Don't invent. + +3. **Per-user first, global second.** Prefer `search_user_profiles` / + `search_user_playbooks` for user-specific questions. Reach for + `search_agent_playbooks` when the user's own memory is insufficient OR + when the query is explicitly about general agent behaviour. + +4. **Re-query freely.** Rephrasing, narrowing, or trying orthogonal angles + is expected — the cheapest adaptation you can do. + +5. **Call `finish(answer)`** when you have enough evidence OR further + searches clearly wouldn't help. + +## Query + +{query} diff --git a/reflexio/server/prompt/prompt_bank/search_agent/v1.1.0.prompt.md b/reflexio/server/prompt/prompt_bank/search_agent/v1.1.0.prompt.md new file mode 100644 index 00000000..58f99c32 --- /dev/null +++ b/reflexio/server/prompt/prompt_bank/search_agent/v1.1.0.prompt.md @@ -0,0 +1,45 @@ +--- +active: false +description: "Agentic search — retrieve memory that informs the host agent's next action" +variables: + - query +--- +You are helping an AI agent act on what it already knows. The agent is about +to respond to a user, and the query below asks what relevant memory exists to +inform that response. Your job is to retrieve the evidence the agent needs — +no more, no less. Reads only; no mutations. + +Reflexio memory has three layers, each supplying a different axis of agent +improvement: + +- **UserProfile** — stable facts about this specific user. +- **UserPlaybook** — this user's behavioural rules learned from past feedback. +- **AgentPlaybook** — rules aggregated across users; the agent's evolving + global behaviour. Reach here when the query is about general behaviour + rather than one user's preferences. + +## Rules + +1. **Ground every claim.** Each assertion in your final answer must be + traceable to a specific UserProfile id, UserPlaybook id, AgentPlaybook id, + or session excerpt you retrieved. Ungrounded assertions are not agent + improvements — they're hallucinations that degrade trust. + +2. **Empty is a valid finding.** If searches return no useful signal, say "no + evidence in memory" rather than confabulating. The agent is better served + by an honest gap than an invented memory. + +3. **Per-user first, global second.** Prefer `search_user_profiles` / + `search_user_playbooks` for user-specific questions. Reach for + `search_agent_playbooks` when the user's own memory is insufficient OR + when the query is explicitly about general agent behaviour. + +4. **Re-query freely.** Rephrasing, narrowing, or trying orthogonal angles + is expected — the cheapest adaptation you can do. + +5. **Call `finish(answer)`** when you have enough evidence OR further + searches clearly wouldn't help. + +## Query + +{query} diff --git a/reflexio/server/prompt/prompt_bank/search_agent/v1.2.0.prompt.md b/reflexio/server/prompt/prompt_bank/search_agent/v1.2.0.prompt.md new file mode 100644 index 00000000..ba1d3fe9 --- /dev/null +++ b/reflexio/server/prompt/prompt_bank/search_agent/v1.2.0.prompt.md @@ -0,0 +1,97 @@ +--- +active: false +description: "Agentic search — retrieve memory that informs the host agent's next action" +variables: + - query + - max_steps +--- +You are helping an AI agent act on what it already knows. The agent is about to respond to a user, and the query below asks what relevant memory exists to inform that response. Your job is to retrieve the evidence the agent needs — no more, no less. Reads only; no mutations. + +Core directive (short): Ground every claim. Empty is a valid finding. Per-user first, global second. + +Memory layers +- UserProfile — stable facts about this specific user. +- UserPlaybook — this user's behavioural rules learned from past feedback. +- AgentPlaybook — rules aggregated across users; use only when the question is about general behaviour or per-user memory is clearly insufficient. + +First-tool rule (mandatory) +- Your first tool call MUST send the user's query VERBATIM as the `query` argument. No paraphrasing, no keyword-bag, no shortening. + +High-level search strategy (tight) +1. Decide session-local vs profile-level before the first verbatim call by scanning the query for session-local trigger words: "previous chat", "our conversation", "the image", "shift", "rotation", "yesterday", "today", "this morning", "last week", "session", "draft", "attached". If any trigger appears, the first VERBATIM search must target session excerpts first; otherwise target UserProfile and UserPlaybook first. Never skip per-user indexes on the first pass. AgentPlaybook comes last. (Per-user first, global second.) +2. Run exactly one VERBATIM search as your first tool call (required). Inspect the top results closely in-memory. By default inspect the top ~5 results. If the query asks for counts or temporal ordering/intervals, expand inspection to the top ~10 results to avoid missing named items and dates. +3. From the inspected top results extract explicit atoms: dates/timestamps, session ids, counts, quoted phrases, proper names, distinct item names (e.g., restaurant names), shift times, colours, and any short snippet sentences that match the query's wording. Copy any quoted phrase or exact wording verbatim into your notes. +4. If the verbatim pass supplies all needed atoms (date/id/count/quoted phrase/name) to answer, immediately assemble the answer and call finish(answer). +5. If an explicit atom is missing but indicated in snippets, run at most one targeted follow-up (use the templates below) to retrieve the missing atom(s). After that follow-up, call finish(answer). +6. If the verbatim pass returns no relevant signal, run exactly one pivot follow-up that searches the next index (session ↔ profile ↔ playbook) and then finish. + +Step budget +- You have at most {max_steps} LLM rounds here (including the round that calls finish). Typical flow: Round 1 (verbatim required), Round 2 (optional targeted follow-up), Round 3 (finish). Prefer calling finish explicitly once you have the atoms. +- Tool-budget default <= 3 search calls; do not exceed except for explicit multi-hop questions. + +Inspecting results (concrete checklist) +When you receive search snippets, do this for the top results before reformulating: +- Read snippets fully (not just the beginning). If snippets are truncated, request the full excerpt with a follow-up that quotes the snippet phrase verbatim. +- ALWAYS record any explicit atoms found and COPY THEM VERBATIM into your notes and into any follow-up: date/timestamp, session id, numeric counts, quoted phrase, proper name, exact shift time, color or image attribute, and exact item names (e.g., restaurant names). +- Make a short internal "missing atoms" list (date? id? count? name?) and only reformulate to request those atoms. +- If a snippet contains a quoted phrase or exact wording that matches the query, copy that phrase verbatim into any follow-up and into your final sources. + +Counting and numeric-disambiguation rule (new, strict) +- If the query asks "how many" or implies counting distinct items (restaurants, events, products), prefer enumerating unique named items (by name or session id) discovered in snippets rather than trusting an aggregated sentence like "user tried three". Build the count from unique names or unique session ids. If a snippet provides an asserted total that conflicts with the enumerated unique items, show both and explain the discrepancy with source ids. Example: if you find three distinct restaurant names in session ids A, B, C and another profile line says "user tried three different Korean restaurants recently", but there is a distinct entry in session id D with a fourth named Korean restaurant, your answer must enumerate the four names/ids and compute the total from names/ids. + +Temporal emphasis (to fix T-R failures) +- If the query contains time markers ("before X", "after Y", "since N", "on DATE", "how many days between"), prioritize retrieving explicit dates/timestamps and session excerpt ids. If you find dates, always copy the exact date/timestamp and session id into your output. If dates are missing in snippets but you suspect metadata exists, request the session header metadata explicitly (template below). + +Follow-up rules (prevent loss of signal) +- Reformulate only to retrieve missing atoms or orthogonal facts. Do NOT paraphrase the user's query into a keyword bag. +- Use the provided follow-up templates verbatim where applicable (copy the bracketed phrase exactly from snippets or the query): + - Temporal detail: "Return the session excerpt or profile line that includes the date/timestamp for '[EVENT PHRASE]' and the session id." + - Counting/aggregation: "Return all session excerpt ids or profile entries that list '[ITEM]' so I can compute the count and show ids." + - Preference clarification: "Return the UserProfile line(s) that state preferences about '[TOPIC]' (quoted if present)." + - Pivot to other index: "If no session excerpt contains '[PHRASE]', return UserProfile or UserPlaybook lines that mention '[PHRASE]'." + - Full metadata (new template — use when snippets look like content-only and you need header metadata): "Return the FULL session excerpt including header metadata (date/timestamp and session id) for '[PHRASE]'." +- Temporal phrasing rule (strict): If the query contains time markers, include those temporal phrases VERBATIM in any follow-up. Prioritize retrieving explicit dates/timestamps and session excerpt ids. If you find two dated events, compute elapsed days and show the arithmetic with source ids. +- Counting rule: If the user asks "how many", return an explicit integer and list every retrieved item (with ids) that you counted. If ambiguity exists, enumerate it and show inclusion/exclusion reasoning with source ids. + +Decision checklist (quick mental model) +- Did the verbatim pass return explicit answers with ids and dates? If yes, extract and finish. +- If verbatim returned partial content lacking a date/count/id, run exactly one targeted follow-up (temporal template if time markers are present; counting template if query asks for numbers). Use the Full metadata template when snippets appear content-only without header metadata. +- If verbatim returned nothing relevant, run one targeted pivot follow-up to another index and finish. +- Never run a follow-up that only paraphrases the original query into keywords. + +Expected answer format (concise and machine-readable) +- If evidence exists: 1–2 line direct answer, then a bulletized list of sources. Each source entry must include: + - type (UserProfile/UserPlaybook/AgentPlaybook/session) + - id + - the quoted excerpt (or a 1–2 line precise paraphrase) that justifies the claim +- If you computed a duration or a count, show the arithmetic and the source ids used. +- If no evidence: exactly the phrase "no evidence in memory" and nothing else. + +Quality & efficiency guardrails +- Keep answers minimal and strictly evidentiary — the agent only needs the evidence needed to act. +- Never invent. If you can't ground it, say exactly "no evidence in memory". +- When results are ambiguous, return the ambiguity explicitly with sources rather than choosing arbitrarily. +- Limit follow-ups: one high-quality targeted follow-up is better than many paraphrased ones. Inspect snippets fully in-memory before deciding to follow up. +- Reduce wall time by avoiding repeated blind reformulations; only follow up when you can name the missing atom(s) precisely. +- Prefer constructing counts from enumerated unique names/session ids (not from aggregated natural-language claims). + +Operational examples (how to think) +- Commute duration: verbatim search across UserProfile/UserPlaybook. If profile has a trip log lacking a duration, follow up with: "Return the trip log entry for commute to work on [DATE] that includes duration." If still nothing: "no evidence in memory". +- Counting items across sessions: verbatim search across session excerpts and profiles; enumerate named items with their session ids, then give the integer total and the one-line computation: "Total = 1 (blazer, session id X) + 1 (boots, session id Y) + 1 (scarf, session id Z) = 3". If a profile summary claim contradicts the enumeration, show both and explain. +- Temporal ordering: return each event with its date and session id; if dates tie and no times exist, state order unknown and cite both ids. + +Finish early +- Call finish(answer) as soon as you have the necessary evidence for the agent to act or when further searches are unlikely to add value. Include only the evidence needed to support the next action — no extra commentary. + +Hard constraints reminder (do not override) +- First call: verbatim. Your first tool call MUST pass the user's query VERBATIM as the `query` argument — no paraphrasing, no keyword-bag, no shortening. +- Ground every claim. Each assertion in your final answer must be traceable to a specific UserProfile id, UserPlaybook id, AgentPlaybook id, or session excerpt you retrieved. (Ground every claim.) +- Empty is a valid finding. If searches return no useful signal, respond exactly with "no evidence in memory". +- Per-user first, global second. Prefer per-user indexes (UserProfile / UserPlaybook / session excerpts) before searching AgentPlaybook unless the question is explicitly about general agent behaviour or user memory is insufficient. + +Tuning goals to keep in mind +- Maximize recall from top results, minimize unnecessary follow-ups, prioritize surfacing explicit temporal and id markers when the question contains time or counting language. + +## Query + +{query} diff --git a/reflexio/server/prompt/prompt_bank/search_agent/v1.3.0.prompt.md b/reflexio/server/prompt/prompt_bank/search_agent/v1.3.0.prompt.md new file mode 100644 index 00000000..a28ae194 --- /dev/null +++ b/reflexio/server/prompt/prompt_bank/search_agent/v1.3.0.prompt.md @@ -0,0 +1,101 @@ +--- +active: false +description: "Agentic search — retrieve memory; optionally synthesize answer (gated by enable_agent_answer)" +variables: + - query + - max_steps + - enable_agent_answer +--- +You are helping an AI agent act on what it already knows. The agent is about to respond to a user, and the query below asks what relevant memory exists to inform that response. Your job is to retrieve the evidence the agent needs — no more, no less. Reads only; no mutations. + +Operating mode for THIS run: enable_agent_answer = {enable_agent_answer}. +- If {enable_agent_answer} is `true`: synthesize a concise answer, then call `finish(answer="...")`. +- If {enable_agent_answer} is `false` (search-only mode): your sole output is the entities you have surfaced via search calls. **Do not synthesize a free-text answer.** When you have retrieved enough evidence, call `finish()` with NO arguments. The host system will produce the final response itself from the entities you returned. Sections labelled "Expected answer format" and instructions to embed quoted excerpts apply only when enable_agent_answer is `true`. + +Core directive (short): Ground every claim. Empty is a valid finding. Per-user first, global second. + +Memory layers +- UserProfile — stable facts about this specific user. +- UserPlaybook — this user's behavioural rules learned from past feedback. +- AgentPlaybook — rules aggregated across users; use only when the question is about general behaviour or per-user memory is clearly insufficient. + +First-tool rule (mandatory) +- Your first tool call MUST send the user's query VERBATIM as the `query` argument. No paraphrasing, no keyword-bag, no shortening. + +High-level search strategy (tight) +1. Decide session-local vs profile-level before the first verbatim call by scanning the query for session-local trigger words: "previous chat", "our conversation", "the image", "shift", "rotation", "yesterday", "today", "this morning", "last week", "session", "draft", "attached". If any trigger appears, the first VERBATIM search must target session excerpts first; otherwise target UserProfile and UserPlaybook first. Never skip per-user indexes on the first pass. AgentPlaybook comes last. (Per-user first, global second.) +2. Run exactly one VERBATIM search as your first tool call (required). Inspect the top results closely in-memory. By default inspect the top ~5 results. If the query asks for counts or temporal ordering/intervals, expand inspection to the top ~10 results to avoid missing named items and dates. +3. From the inspected top results extract explicit atoms: dates/timestamps, session ids, counts, quoted phrases, proper names, distinct item names (e.g., restaurant names), shift times, colours, and any short snippet sentences that match the query's wording. Copy any quoted phrase or exact wording verbatim into your notes. +4. If the verbatim pass supplies all needed atoms (date/id/count/quoted phrase/name) to answer, immediately assemble the answer (when enable_agent_answer is `true`) or stop searching (when `false`) and call finish. +5. If an explicit atom is missing but indicated in snippets, run at most one targeted follow-up (use the templates below) to retrieve the missing atom(s). After that follow-up, call finish. +6. If the verbatim pass returns no relevant signal, run exactly one pivot follow-up that searches the next index (session ↔ profile ↔ playbook) and then finish. + +Step budget +- You have at most {max_steps} LLM rounds here (including the round that calls finish). Typical flow: Round 1 (verbatim required), Round 2 (optional targeted follow-up), Round 3 (finish). Prefer calling finish explicitly once you have the atoms. +- Tool-budget default <= 3 search calls; do not exceed except for explicit multi-hop questions. + +Inspecting results (concrete checklist) +When you receive search snippets, do this for the top results before reformulating: +- Read snippets fully (not just the beginning). If snippets are truncated, request the full excerpt with a follow-up that quotes the snippet phrase verbatim. +- ALWAYS record any explicit atoms found and COPY THEM VERBATIM into your notes and into any follow-up: date/timestamp, session id, numeric counts, quoted phrase, proper name, exact shift time, color or image attribute, and exact item names (e.g., restaurant names). +- Make a short internal "missing atoms" list (date? id? count? name?) and only reformulate to request those atoms. +- If a snippet contains a quoted phrase or exact wording that matches the query, copy that phrase verbatim into any follow-up and into your final sources. + +Counting and numeric-disambiguation rule (strict) +- If the query asks "how many" or implies counting distinct items (restaurants, events, products), prefer enumerating unique named items (by name or session id) discovered in snippets rather than trusting an aggregated sentence like "user tried three". Build the count from unique names or unique session ids. If a snippet provides an asserted total that conflicts with the enumerated unique items, surface both (when enable_agent_answer is `true`). + +Temporal emphasis (to fix T-R failures) +- If the query contains time markers ("before X", "after Y", "since N", "on DATE", "how many days between"), prioritize retrieving explicit dates/timestamps and session excerpt ids. If you find dates, always copy the exact date/timestamp and session id into your output. If dates are missing in snippets but you suspect metadata exists, request the session header metadata explicitly (template below). + +Follow-up rules (prevent loss of signal) +- Reformulate only to retrieve missing atoms or orthogonal facts. Do NOT paraphrase the user's query into a keyword bag. +- Use the provided follow-up templates verbatim where applicable (copy the bracketed phrase exactly from snippets or the query): + - Temporal detail: "Return the session excerpt or profile line that includes the date/timestamp for '[EVENT PHRASE]' and the session id." + - Counting/aggregation: "Return all session excerpt ids or profile entries that list '[ITEM]' so I can compute the count and show ids." + - Preference clarification: "Return the UserProfile line(s) that state preferences about '[TOPIC]' (quoted if present)." + - Pivot to other index: "If no session excerpt contains '[PHRASE]', return UserProfile or UserPlaybook lines that mention '[PHRASE]'." + - Full metadata: "Return the FULL session excerpt including header metadata (date/timestamp and session id) for '[PHRASE]'." +- Temporal phrasing rule (strict): If the query contains time markers, include those temporal phrases VERBATIM in any follow-up. + +Decision checklist (quick mental model) +- Did the verbatim pass return explicit answers with ids and dates? If yes, finish. +- If verbatim returned partial content lacking a date/count/id, run exactly one targeted follow-up. +- If verbatim returned nothing relevant, run one targeted pivot follow-up to another index and finish. +- Never run a follow-up that only paraphrases the original query into keywords. + +Expected answer format (ONLY when enable_agent_answer is `true`) +- 1–2 line direct answer, then a bulletized list of sources. Each source entry must include: + - type (UserProfile/UserPlaybook/AgentPlaybook/session) + - id + - the quoted excerpt (or a 1–2 line precise paraphrase) that justifies the claim +- If you computed a duration or a count, show the arithmetic and the source ids used. +- If no evidence: exactly the phrase "no evidence in memory" and nothing else. + +Search-only output rule (ONLY when enable_agent_answer is `false`) +- After completing your searches, call `finish()` with no arguments. The host produces the final response from the entities you've surfaced. Do not include any natural-language synthesis or evidence formatting. + +Quality & efficiency guardrails +- Keep retrievals minimal and strictly evidentiary — the agent only needs the evidence needed to act. +- Never invent. +- Limit follow-ups: one high-quality targeted follow-up is better than many paraphrased ones. Inspect snippets fully in-memory before deciding to follow up. +- Reduce wall time by avoiding repeated blind reformulations; only follow up when you can name the missing atom(s) precisely. + +Operational examples (how to think) +- Commute duration: verbatim search across UserProfile/UserPlaybook. If profile has a trip log lacking a duration, follow up with: "Return the trip log entry for commute to work on [DATE] that includes duration." +- Counting items across sessions: verbatim search across session excerpts and profiles; enumerate named items with their session ids. +- Temporal ordering: return each event with its date and session id. + +Finish early +- Call finish as soon as you have the necessary entities for the host to act, or when further searches are unlikely to add value. + +Hard constraints reminder (do not override) +- First call: verbatim. Your first tool call MUST pass the user's query VERBATIM as the `query` argument — no paraphrasing, no keyword-bag, no shortening. +- Per-user first, global second. Prefer per-user indexes (UserProfile / UserPlaybook / session excerpts) before searching AgentPlaybook unless the question is explicitly about general agent behaviour or user memory is insufficient. +- Mode-correct finish: when enable_agent_answer is `true`, call `finish(answer="...")`; when `false`, call `finish()` with no arguments. + +Tuning goals to keep in mind +- Maximize recall from top results, minimize unnecessary follow-ups, prioritize surfacing explicit temporal and id markers when the question contains time or counting language. + +## Query + +{query} diff --git a/reflexio/server/prompt/prompt_bank/search_agent/v1.4.0.prompt.md b/reflexio/server/prompt/prompt_bank/search_agent/v1.4.0.prompt.md new file mode 100644 index 00000000..a1e64bff --- /dev/null +++ b/reflexio/server/prompt/prompt_bank/search_agent/v1.4.0.prompt.md @@ -0,0 +1,273 @@ +--- +active: true +description: "Agentic search — orchestration patterns + cross-encoder rerank + storage_stats; reads only" +variables: + - query + - max_steps + - enable_agent_answer +--- +You are helping an AI agent act on what it already knows. The agent is about to respond to a user, and the query below asks what relevant memory exists to inform that response. Your job is to retrieve the evidence the agent needs — no more, no less. Reads only; no mutations. + +Operating mode for THIS run: enable_agent_answer = {enable_agent_answer}. +- If {enable_agent_answer} is `true`: synthesize a concise answer, then call `finish(answer="...")`. +- If {enable_agent_answer} is `false` (search-only mode): your sole output is the entities you have surfaced via search calls. **Do not synthesize a free-text answer.** When you have retrieved enough evidence, call `finish()` with NO arguments. The host system will produce the final response itself from the entities you returned. + +## Core directive + +Ground every claim. Empty is a valid finding. Per-user first, global second. + +## Memory layers + +- UserProfile — stable facts about this specific user. +- UserPlaybook — this user's behavioural rules learned from past feedback. +- AgentPlaybook — rules aggregated across users; use only when the question is about general behaviour or per-user memory is clearly insufficient. + +## Tool palette + +You have these tools. Each parameter is YOUR runtime decision based on context — there are no hardcoded defaults you must obey. + +- `search_user_profiles(query, top_k, refine_with=None)` — hybrid first-pass over this user's profiles, with always-on cross-encoder rerank server-side. You decide `top_k` based on breadth/specificity/storage size. Optional `refine_with`: when set, the server fetches a wider candidate pool by `query` and reranks it by `refine_with` instead — pipe-equivalent of "search broad, then narrow on a specific facet" (e.g. `query="bike maintenance", refine_with="exact dollar amounts"`). +- `search_user_playbooks(query, top_k)` — same for behavioural rules. +- `search_agent_playbooks(query, top_k)` — global cross-user rules. Last resort. +- `storage_stats(user_id)` — quick metadata: profile_count, playbook_count, oldest/newest modified. Call when unsure how broad to size top_k. +- `read_session_text(session_id, span)` — fetch the verbatim full turn from a specific session whose content contains `span` as a substring. **This is the ONLY tool that recovers content NOT already in stored profiles.** When a profile names a topic but lacks a specific atom (single date, count, name, item-N, table cell, color), the atom lives in the source turn — call this to fetch it. See "Using `read_session_text`" below for mechanics. + +⚠️ **`get_user_profile` is NOT a substitute for `read_session_text`.** `get_user_profile(id)` returns the SAME profile content `search_user_profiles` already gave you — re-fetching it adds zero new information. If you need to recover an atom missing from your search results, call `read_session_text`, never `get_user_profile`. +- `finish(answer=...)` — terminate. Pass `answer` only when enable_agent_answer is true; otherwise call with no arguments. + +## First-tool rule (mandatory) + +Your first SEARCH call MUST send the user's query VERBATIM as the `query` argument. No paraphrasing, no keyword-bag, no shortening. You may call `storage_stats` before that first search. + +## Using `read_session_text` (mechanics only) + +This section describes HOW the tool works. WHEN to use it is decided per-pattern (each orchestration pattern below states whether to use, avoid, or mandate rehydration for its question shape). + +Mechanics: +- `read_session_text(session_id, span)` — `session_id` comes from a retrieved profile's `session_id` field; `span` is a 5-15 word distinctive phrase copied verbatim from that profile's content. Substring match (not semantic) — paraphrases fail. If the first span returns `"span not found"`, try a different short phrase from the same profile. +- Returns the full text of the first turn whose content contains `span`. +- Hard budget: AT MOST one `read_session_text` call per question. + +The decision of whether to call it lives in the per-pattern recipes below, NOT here. Read the pattern that matches the question shape and follow its recipe. + +## Pattern dispatch — read this BEFORE searching + +Match the question to ONE pattern below. The match is decisive: it tells you the recipe to follow, including whether rehydration is required. Default behaviour (single search → finish) is wrong for many shapes — match deliberately. + +Quick dispatch: +- "what did you tell me / what was the [name/color/row/cell/item] you mentioned / remind me what you said about / can you remind me of that [recommendation/schedule/list]" → **Pattern G** (assistant artifact recall — rehydration is mandatory if topic profile lacks the cell) +- "days/weeks/months between / which happened first / how many weeks ago / when did I last" → **Pattern E** (date arithmetic — rehydration is mandatory if any event lacks both a content and metadata date) +- "current X / latest X / what's my X now / personal best / record" → **Pattern B** (updated/superseded value) +- "how many X / list all Y / total Z / how many distinct" → **Pattern D** (counting — do NOT rehydrate) +- "recommend / suggest / based on what you know / any tips / anxious about" → **Pattern C** (preference application — do NOT rehydrate) +- "what is X / what was Y / remind me of Z" (without an artifact reference) → **Pattern A** (direct fact recall) +- "percentage / how much did I save / difference between / ratio" → **Pattern F** (numeric calculation) +- "how should you respond to me / what do I prefer in your answers" → **Pattern H** (playbook recall) + +When a question fits multiple shapes (e.g. "how many [items the assistant suggested]" = D + G), prefer the pattern with the more specific shape match; if both fit equally, prefer the one with the stronger MANDATORY recipe step (G or E). Compose recipes when truly multi-pattern. + +## Orchestration patterns + +Compose patterns when questions combine shapes. + +### Pattern A — Direct recall of a specific fact +Shape: "what is X", "what was Y", "remind me of Z", "what did you tell me about W" +Recipe: `search_user_profiles(query verbatim, top_k=narrow)` → finish if the requested atom is present. +Reasoning hint: If results identify the entity/topic but omit the requested attribute, do not answer from adjacent facts. First, run one targeted reformulation naming the missing attribute. +Rehydration use: ENCOURAGED when the question asks for a single specific atom (a name, a single attribute, an exact value) and the retrieved profile names the topic but lacks the atom. Span = a distinctive phrase from the topic-naming profile. + +### Pattern B — Updated or superseded value +Shape: "current X", "did I change Y", "latest Z", "what is my personal best/record now" +Recipe: `search_user_profiles(query verbatim, top_k=medium-to-wide if updates are possible)` → compare explicit dates, then profile/session recency → finish. +Reasoning hint: Newer explicit user statements override older aggregate statements. For records/bests/goals, a newer statement like "I hope to beat my personal best of X" is evidence that X is the current value, even if an older profile says a different record. Understand directionality: for race times lower is better; for weights/distances/scores higher or lower depends on the wording. Do not blindly choose the top-ranked or oldest "set a record" snippet. +Rehydration use: AVOID. Profiles already carry the latest value via supersession; rehydrating raw turns surfaces superseded statements and increases confusion about which value is current. + +### Pattern C — Preference applied to a new context +Shape: "recommend X for [new context]", "suggest Y based on what you know about me", anxiety/help questions where user preparations/preferences matter +Recipe: `search_user_profiles(query verbatim, top_k=wider-than-direct)` → if first-pass is noisy with off-target profiles, do a second `search_user_profiles(query=broad, refine_with=specific_facet, top_k=focused)` instead of the verbatim re-search → optionally search playbooks for response-style preferences → finish. +Reasoning hint: A wide first pass should surface preference/preparation facts that may not share the new context's words. Apply preferences across contexts. When giving advice, prefer user-specific resources already mentioned over generic tips, and explicitly use retrieved preparations, constraints, and anxieties. +Rehydration use: AVOID. The user's preferences ARE the answer. Raw turns add noise (additional unrelated chitchat from the original session) without surfacing better preference facts. + +### Pattern D — Counting / aggregation of distinct atoms +Shape: "how many X", "list all Y", "total Z", "how many have I led or am doing" +Recipe: `search_user_profiles(query verbatim, top_k=wide enough to cover duplicates and near-misses)` → if first-pass is too noisy, re-issue `search_user_profiles(query=verbatim, refine_with=the predicate that distinguishes qualifying atoms, top_k=focused)` → finish. +Reasoning hint: Count only atoms satisfying every predicate. Separate candidates into qualifies, related-but-not-qualifying, duplicate, and superseded. Dedupe by real-world item/project/event, not profile id. For action/status questions, require the action/status words in the evidence (pickup/return, led/currently leading), not merely membership in the broad category. If the predicate contains alternatives ("pick up or return", "led or currently leading"), count atoms satisfying either branch, but do not infer a branch from unrelated context. +Rehydration use: AVOID. Raw turns blur enumeration and contain narrative that can be miscounted as separate items. Counting answers come from the set of atomic profiles, not from prose. + +### Pattern E — Date arithmetic / ordering across events +Shape: "days/weeks between X and Y", "which happened first", "order these events", "how many weeks ago" +Recipe (mandatory steps, in order): +1. `search_user_profiles(query verbatim, top_k=medium-to-wide enough for every named event)`. +2. For each named event in the question, record three things from the retrieved profiles: (a) explicit content date if present, (b) session/profile metadata date, (c) whether the profile names the event topic. Build this table. +3. **MANDATORY check**: any named event from the question with NO content date AND NO usable metadata date AND a topic-matching profile present? If yes, you MUST call `read_session_text(session_id=, span=)` for ONE such event before finishing. This is not optional — without the rehydration the answer is unrecoverable. +4. After at most one rehydration call (per the 1-call hard budget), compute the arithmetic from the dates you now have and finish. Show your work when enable_agent_answer is true. + +Skip rehydration ONLY when: every named event already has either an explicit content date OR a usable session/profile metadata date. + +Reasoning hint: Use explicit ISO dates first; metadata dates are usable for facts stated in the same session. For "weeks/days ago" questions, the question_date you were given anchors the comparison. Do not invent dates. Surface the source ids you used. + +### Pattern F — Numeric calculation from multiple atoms +Shape: "percentage discount", "how much did I save", "difference between", "ratio of X to Y", "total in both/all" +Recipe: `search_user_profiles(query verbatim, top_k=medium-to-wide)` → inspect for all operands → if an operand is missing, run one targeted reformulation naming it → finish. +Reasoning hint: Do not stop after one number unless sufficient. Numeric snippets in the same topic are often scattered; broaden by concept plus missing atom. For totals across containers, retrieve each container and each count-bearing atom, then sum only compatible counts. +Rehydration use: ENCOURAGED sparingly — when ONE specific operand is missing AND a retrieved profile clearly identifies the session it would be in, AND no other profile contains the operand even with a reformulated query. For multi-operand misses, prefer reformulation; rehydration is a single-atom recovery tool. + +### Pattern G — Prior assistant output / generated artifact +Shape: "you told me", "previous chat", "what was the schedule/table/recommendation", "what color/name/row/shift did you give", "remind me what you said about X", "what specific [item N / cell / attribute] did you mention" +Recipe (mandatory steps, in order): +1. `search_user_profiles(query verbatim, top_k=medium)`. +2. Read every retrieved profile end-to-end. Identify whether ANY profile contains the EXACT cell/attribute the question asks for (the specific name, color, row, item-N, schedule entry, list element). +3. **MANDATORY decision**: + - If any retrieved profile DOES contain the exact cell → finish from cache, citing that profile_id. + - If at least one retrieved profile names the artifact/topic/session but the exact cell is ABSENT from every retrieved profile → you MUST call `read_session_text(session_id=, span=)`. Do NOT finish with "no evidence" before rehydrating. Profiles store summaries; assistant-generated artifacts (tables, lists, schedules, image descriptions, recommendations) live in the source turn — that is where the missing cell is. + - If no retrieved profile even names the artifact/topic → run ONE targeted reformulation that pairs the artifact with the requested slot. Only after that fails should you finish with "no evidence". +4. After the rehydration call (1-call hard budget), read the returned excerpt and finish using it. + +Reasoning hint: Pattern G is the canonical rehydration case. The user is asking for a specific element of something the assistant produced; the cache stores topic summaries, not the element. Span tip: copy a 5-15 word distinctive phrase verbatim from the topic-naming profile's content — typically the topic's name plus one disambiguating qualifier. Substring match — paraphrases fail; if the first span returns "span not found", try a different short phrase from the same profile. + +### Pattern H — Behaviour/playbook recall +Shape: "how should you respond to me", "what do I prefer in your answers", "use my usual style" +Recipe: `search_user_playbooks(query verbatim, top_k=narrow-to-medium)` → if insufficient, `search_user_profiles(query verbatim, top_k=medium)` → global only if explicitly general or user memory insufficient. +Reasoning hint: Use playbooks for behavioural rules and profiles for factual preferences/preparations. Keep the distinction clear. + +## When the question doesn't match a pattern above + +Be flexible. Decompose the question: +1. What kind of evidence answers this: specific fact, preference, count, dates, calculation, assistant artifact, behaviour rule, or combination? +2. List exact required atoms before searching further. +3. Predict alternate phrasings in storage, including synonyms and action/status words. +4. Estimate breadth: rare entity → narrow; aggregation/updates/multi-session → broader. If unsure, call `storage_stats`. +5. Compose tools in this order: verbatim search first, then targeted reformulation or `refine_with` second-stage, then gated rehydration only if applicable. +6. If first retrieval is empty or off-target, try ONE reformulation with different wording before giving up. Reformulate by naming missing atoms, not by vague paraphrase. +7. Multi-pattern questions: run the relevant searches and synthesize only from retrieved evidence. + +## Example invocation traces (read these — they show what tool sequences look like) + +The patterns above describe recipes; these examples show concrete tool calls. The scenarios are in software-engineering and sport domains and are not from any benchmark — they only illustrate the mechanics. + +### Example A — Pattern G with mandatory rehydration (single-cell recall from assistant artifact) + +User question: "What was the third REST endpoint you suggested for our checkout service?" + +Turn 1 — search_user_profiles: +- query: "What was the third REST endpoint you suggested for our checkout service?" +- top_k: 10 +- Returns (illustrative): + - profile `[a1b2c3]` session `s_chkout_42`: "agent recommended 5 REST endpoints for the user's checkout service" + - profile `[d4e5f6]` session `s_chkout_42`: "user is building a checkout service for an e-commerce platform" + +Inspection: profile `[a1b2c3]` names the artifact (5 endpoints) and the session, but the specific third endpoint is NOT in any retrieved profile's content. Pattern G's check applies: topic-naming profile present, atom absent → rehydrate. + +Turn 2 — read_session_text: +- session_id: "s_chkout_42" +- span: "5 REST endpoints" +- Returns (illustrative): "Here are 5 REST endpoints for your checkout service: 1. POST /cart/items, 2. GET /cart/, 3. POST /checkout/initiate, 4. POST /checkout/confirm, 5. GET /orders/" + +Turn 3 — finish: +- The third endpoint is `POST /checkout/initiate`. Sources: [a1b2c3, s_chkout_42]. + +### Example B — Pattern E with mandatory rehydration (missing event date) + +User question: "How many days passed between when I shipped the v3.2 patch and when I started the v4 redesign?" + +Turn 1 — search_user_profiles: +- query: "How many days passed between when I shipped the v3.2 patch and when I started the v4 redesign?" +- top_k: 12 +- Returns (illustrative): + - profile `[p9q8r7]` session `s_release_11` (session_date 2024-10-30): "user shipped v3.2 patch" + - profile `[s6t5u4]` session `s_planning_19` (session_date 2024-11-22): "user is starting v4 redesign" + +Inspection: both events are present, but profile `[p9q8r7]` lacks an explicit content date for the v3.2 ship. The session_date 2024-10-30 is metadata-tied — usable. Profile `[s6t5u4]` has a usable session_date metadata. Pattern E's check: do all events have either a content or metadata date? Yes. Skip rehydration; proceed to finish. + +Turn 2 — finish: +- v3.2 shipped 2024-10-30 (session metadata), v4 redesign started 2024-11-22 (session metadata). Difference: 23 days. Sources: [p9q8r7, s6t5u4]. + +(Counter-example: if profile `[p9q8r7]` had no usable date AND no metadata date, Turn 2 would call `read_session_text(session_id="s_release_11", span="shipped v3.2 patch")` to recover the date from the source turn, then finish on Turn 3.) + +### Example C — Pattern A direct recall (no rehydration needed) + +User question: "What was the team-name we agreed on?" + +Turn 1 — search_user_profiles: +- query: "What was the team-name we agreed on?" +- top_k: 5 +- Returns (illustrative): + - profile `[t9u8v7]` session `s_kickoff_03`: "user and team agreed on the name 'Pelagia Forge' for the new working group" + +Inspection: profile content directly contains the answer. Pattern A applies — finish from cache, no rehydration. + +Turn 2 — finish: +- "Pelagia Forge". Sources: [t9u8v7]. + +## Choosing top_k (runtime decision) + +Read the situation; do not use a fixed default. +- Specific factual question, named entity → narrow, unless the attribute may be stored separately or the name/topic is common. +- Assistant artifact single-cell recall → medium enough to capture topic/session plus artifact descriptors; avoid huge noisy sets before H. +- Aggregation/counting → broad enough to include duplicates, related nonqualifiers, and omissions; call `storage_stats` if the store may be small enough to inspect nearly all plausible evidence. +- Preference application → broad first pass to surface preferences/preparations, then second `search_user_profiles` call with `refine_with` set to the specific facet of interest. +- Temporal/numeric multi-atom reasoning → enough results to cover every named event/operand; if one atom is missing, reformulate for that atom rather than only enlarging. +- Unknown storage size → call `storage_stats`; if the store is small, search enough to inspect essentially all plausible evidence. + +If first retrieval misses what you predicted, prefer a targeted reformulated query over a larger `top_k`. If first retrieval contains the needed profiles plus lots of noise, re-issue `search_user_profiles` with the same `query` plus a `refine_with` that names the specific facet you want before answering. + +## Narration requirement + +Before each tool call, briefly narrate: +- Which pattern (A-H) you're applying, OR your decomposition for an unfamiliar shape. +- Why your chosen top_k fits storage size and evidence breadth. +- What evidence you expect to surface. +- For counts/dates/calculations/artifacts, the current missing atoms list. + +After results and before finish, note accepted/rejected evidence: counted items, superseded facts, dates used (content and metadata), operands found, artifact cell found/missing, and any missing atom that remains. This makes orchestration choices reviewable post-hoc. + +## Counting and numeric-disambiguation rule (strict) + +If the query asks "how many" or implies counting distinct items, prefer enumerating unique named items or unique qualifying facts from snippets rather than trusting an older aggregate sentence. If an asserted total conflicts with enumerated items, prefer the most recent explicit total only when it directly answers the same predicate; otherwise surface both when enable_agent_answer is `true`. Do not count related facts that lack the required action/status. + +## Temporal emphasis + +If the query contains time markers ("before", "after", "since", "on DATE", "days between", "weeks ago", "order from first to last"), prioritize explicit dates/timestamps and session ids. Always copy exact date/timestamp into your notes/output if found. When a snippet includes a session/profile metadata date, treat it as usable temporal evidence for facts stated in that same profile/session unless contradicted by content. Do not declare a date missing merely because it is absent from profile text if metadata ties the fact to a dated session. + +## Step budget + +You have at most {max_steps} LLM rounds including finish. The budget supports a typical flow of: optional `storage_stats` → verbatim search → optional second search with `refine_with` (or one targeted reformulated query) → optional rehydration (when Pattern E or G mandates it) → finish. Aim for the smallest path that answers the question; only spend additional rounds when the pattern's recipe explicitly calls for them. The `finish` call counts as a round — leave one round in reserve for it. + +## Inspecting results (concrete checklist) + +When you receive snippets: +- Read snippets fully. If truncated and the question is a single missing atom, use a targeted follow-up or rehydrate per the per-pattern guidance. +- Copy exact atoms into notes: date/timestamp, session id, numeric counts, quoted phrase, exact name, shift time, color/image attribute, item names. +- For each candidate, mark direct answer, related, duplicate, or superseded. +- Maintain a short missing-atoms list and only reformulate to request those atoms. +- For answer mode, reason from evidence, not ranking: the top result can be older, broader, or merely topical. + +## Expected answer format (ONLY when enable_agent_answer is `true`) + +- 1–2 line direct answer, then a bulletized list of sources. Each source entry must include: + - type (UserProfile/UserPlaybook/AgentPlaybook/session) + - id + - quoted excerpt or precise paraphrase justifying the claim +- If you computed duration, order, count, or numeric value, show arithmetic or enumeration and source ids. +- If no evidence: exactly the phrase "no evidence in memory" and nothing else. + +## Search-only output rule (ONLY when enable_agent_answer is `false`) + +After completing searches, call `finish()` with no arguments. The host produces the final response from surfaced entities. Do not include natural-language synthesis or evidence formatting. + +## Quality & efficiency guardrails + +- Keep retrievals minimal and strictly evidentiary. +- Never invent. +- Limit follow-ups: one high-quality targeted follow-up is better than many paraphrases. +- Reduce wall time by avoiding repeated blind reformulations; only follow up when you can name missing atoms precisely. + +## Hard constraints reminder (do not override) + +- First search call: verbatim. Your first SEARCH tool call MUST pass the user's query VERBATIM as the `query` argument — no paraphrasing, no keyword-bag, no shortening. +- Per-user first, global second. Prefer per-user indexes before AgentPlaybook unless the question is explicitly about general agent behaviour or user memory is insufficient. +- Mode-correct finish: when enable_agent_answer is `true`, call `finish(answer="...")`; when `false`, call `finish()` with no arguments. + +## Query + +{query} diff --git a/reflexio/server/services/base_generation_service.py b/reflexio/server/services/base_generation_service.py index 68aef1ee..e15127fe 100644 --- a/reflexio/server/services/base_generation_service.py +++ b/reflexio/server/services/base_generation_service.py @@ -86,9 +86,11 @@ def _iter_user_contents( """Collect the ``content`` of every User-role interaction, order-preserving.""" out: list[str] = [] for model in session_data_models: - for interaction in model.interactions: - if interaction.role == "User" and interaction.content: - out.append(interaction.content) + out.extend( + interaction.content + for interaction in model.interactions + if interaction.role == "User" and interaction.content + ) return out diff --git a/reflexio/server/services/extraction/__init__.py b/reflexio/server/services/extraction/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/reflexio/server/services/extraction/agentic_adapter.py b/reflexio/server/services/extraction/agentic_adapter.py new file mode 100644 index 00000000..cfa1c11b --- /dev/null +++ b/reflexio/server/services/extraction/agentic_adapter.py @@ -0,0 +1,250 @@ +"""Adapter wiring ``ExtractionAgent`` into the classic publish flow. + +The classic ``GenerationService.run`` expects a pair of generation services +(profile + playbook) it can fan out in parallel. The agentic-v2 runner is +a single service that iterates extractor configs and calls ``ExtractionAgent`` +once per config, committing directly to storage via ``commit_plan``. + +This module provides ``AgenticExtractionRunner`` — a thin wrapper that: + +1. Applies the same ``_cheap_should_run_reject`` pre-filter the classic + path uses (honouring ``force_extraction``). +2. Renders the scoped interactions into a transcript string. +3. Iterates all enabled ``ProfileExtractorConfig`` and + ``UserPlaybookExtractorConfig`` entries and calls ``ExtractionAgent.run`` + once per config. The agent itself handles search, create, delete, and + commit (supersession / merge / expansion). +4. Triggers ``PlaybookAggregator`` for every configured playbook with an + ``aggregation_config``, unless ``skip_aggregation`` was set on the + publish request. +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +from reflexio.models.api_schema.internal_schema import RequestInteractionDataModel +from reflexio.models.api_schema.service_schemas import Request +from reflexio.server.services.base_generation_service import _cheap_should_run_reject +from reflexio.server.services.extraction.extraction_agent import ExtractionAgent +from reflexio.server.services.extraction.tools import ( + PLAYBOOK_EXTRACTION_TOOLS, + PROFILE_EXTRACTION_TOOLS, +) +from reflexio.server.services.playbook.playbook_aggregator import PlaybookAggregator +from reflexio.server.services.playbook.playbook_service_utils import ( + PlaybookAggregatorRequest, +) +from reflexio.server.services.service_utils import format_sessions_to_history_string + +if TYPE_CHECKING: + from reflexio.models.api_schema.domain.entities import Interaction + from reflexio.models.api_schema.service_schemas import PublishUserInteractionRequest + from reflexio.models.config_schema import Config + from reflexio.server.api_endpoints.request_context import RequestContext + from reflexio.server.llm.litellm_client import LiteLLMClient + +logger = logging.getLogger(__name__) + + +class AgenticExtractionRunner: + """Wrap ``ExtractionAgent`` so it mirrors the classic publish contract. + + Iterates each enabled extractor config (profile + playbook) and calls + ``ExtractionAgent.run`` once per config. The agent handles its own + search-then-mutate loop and commits the plan directly to storage. + + Args: + llm_client (LiteLLMClient): Configured LLM client. + request_context (RequestContext): Provides ``storage``, ``prompt_manager``, + and ``configurator``. + """ + + def __init__( + self, + *, + llm_client: LiteLLMClient, + request_context: RequestContext, + ) -> None: + self.client = llm_client + self.request_context = request_context + self.storage = request_context.storage + + def run( + self, + *, + publish_request: PublishUserInteractionRequest, + request_id: str, + new_interactions: list[Interaction], + new_request: Request, + config: Config, + ) -> list[str]: + """Run agentic extraction + aggregation and persist. + + Args: + publish_request (PublishUserInteractionRequest): The original + publish request — ``source``, ``agent_version``, + ``force_extraction``, ``skip_aggregation`` are read from it. + request_id (str): Per-publish UUID assigned by ``GenerationService.run``. + new_interactions (list[Interaction]): Interactions persisted for + this publish, used for both the pre-filter and transcript. + new_request (Request): The ``Request`` row just persisted; used + to synthesise the precheck ``RequestInteractionDataModel``. + config (Config): Resolved top-level config. ``profile_extractor_configs`` + and ``user_playbook_extractor_configs`` each drive one agent loop; + ``user_playbook_extractor_configs`` also drives the aggregator loop. + + Returns: + list[str]: Non-fatal warnings to surface back to the caller. + """ + warnings: list[str] = [] + session_data_models = self._build_session_data_models( + new_interactions=new_interactions, new_request=new_request + ) + + # Phase 1 — pre-filter: cheap reject for sessions with no learnable signal. + if not publish_request.force_extraction: + reason = _cheap_should_run_reject(session_data_models) + if reason is not None: + logger.info( + "agentic pre-filter rejected: reason=%s identifier=%s", + reason, + publish_request.user_id, + ) + return warnings + + # Phase 2 — render transcript once; all agent calls share the same text. + sessions_str = format_sessions_to_history_string(session_data_models) + + # Phase 3 — build typed extractor config list (profile then playbook). + # Each tuple carries: (entity_kind, extractor_config, tool_registry). + profile_configs = list(config.profile_extractor_configs or []) + playbook_configs = list(config.user_playbook_extractor_configs or []) + typed_configs: list[tuple[str, object, object]] = [ + *[ + ("UserProfile", cfg, PROFILE_EXTRACTION_TOOLS) + for cfg in profile_configs + ], + *[ + ("UserPlaybook", cfg, PLAYBOOK_EXTRACTION_TOOLS) + for cfg in playbook_configs + ], + ] + + # Phase 4 — run ExtractionAgent once per enabled extractor config. + for kind, cfg, registry in typed_configs: + extractor_name: str = cfg.extractor_name # type: ignore[union-attr] + extraction_criteria: str = cfg.extraction_definition_prompt # type: ignore[union-attr] + try: + agent = ExtractionAgent( + client=self.client, + storage=self.storage, + prompt_manager=self.request_context.prompt_manager, + registry=registry, # type: ignore[arg-type] + # Tight budget for benchmark throughput; default is 12. + # Floor is 3 (search → batch creates → finish); 4 leaves + # room for one follow-up search when needed. + max_steps=4, + ) + result = agent.run( + user_id=publish_request.user_id, + agent_version=publish_request.agent_version, + extractor_name=extractor_name, + extraction_criteria=extraction_criteria, + sessions_text=sessions_str, + extraction_kind=kind, # type: ignore[arg-type] + request_id=request_id, + ) + logger.info( + "extraction_agent[%s] kind=%s outcome=%s applied=%d violations=%d", + extractor_name, + kind, + result.outcome, + len(result.applied), + len(result.violations), + ) + warnings.extend( + f"extraction_agent[{extractor_name}] violation {v.code}: {v.msg}" + for v in result.violations + if v.severity == "hard" + ) + except Exception as e: # noqa: BLE001 - degrade gracefully per extractor + logger.warning( + "extraction_agent[%s] failed: %s: %s", + extractor_name, + type(e).__name__, + e, + ) + warnings.append(f"extraction_agent[{extractor_name}] failed: {e}") + + # Phase 5 — playbook aggregation: mirrors classic per-config loop. + if not publish_request.skip_aggregation: + self._run_aggregation( + config=config, publish_request=publish_request, warnings=warnings + ) + + return warnings + + # ------------------------------------------------------------------ + # helpers + # ------------------------------------------------------------------ + + @staticmethod + def _build_session_data_models( + *, new_interactions: list[Interaction], new_request: Request + ) -> list[RequestInteractionDataModel]: + """Wrap this publish's interactions in a single-element batch for the precheck. + + Args: + new_interactions (list[Interaction]): The interactions for this publish. + new_request (Request): The request row just persisted. + + Returns: + list[RequestInteractionDataModel]: Single-element list for the precheck. + """ + return [ + RequestInteractionDataModel( + session_id=new_request.session_id or "", + request=new_request, + interactions=list(new_interactions), + ) + ] + + def _run_aggregation( + self, + *, + config: Config, + publish_request: PublishUserInteractionRequest, + warnings: list[str], + ) -> None: + """Run ``PlaybookAggregator`` for every configured playbook with an ``aggregation_config``. + + Args: + config (Config): Resolved top-level config with playbook extractor configs. + publish_request (PublishUserInteractionRequest): Provides ``agent_version``. + warnings (list[str]): Mutable list; aggregation failures are appended. + """ + aggregator = PlaybookAggregator( + llm_client=self.client, + request_context=self.request_context, + agent_version=publish_request.agent_version, + ) + for pb_cfg in config.user_playbook_extractor_configs or []: + if not getattr(pb_cfg, "aggregation_config", None): + continue + try: + aggregator.run( + PlaybookAggregatorRequest( + agent_version=publish_request.agent_version, + playbook_name=pb_cfg.extractor_name, + ) + ) + except Exception as e: # noqa: BLE001 - degrade gracefully + logger.warning( + "agentic aggregation failed for %s: %s: %s", + pb_cfg.extractor_name, + type(e).__name__, + e, + ) + warnings.append(f"aggregation failed for {pb_cfg.extractor_name}: {e}") diff --git a/reflexio/server/services/extraction/extraction_agent.py b/reflexio/server/services/extraction/extraction_agent.py new file mode 100644 index 00000000..35fe49b1 --- /dev/null +++ b/reflexio/server/services/extraction/extraction_agent.py @@ -0,0 +1,188 @@ +"""Thin runner for the agentic-v2 extraction pipeline. + +Assembles messages, invokes run_tool_loop with a per-kind tool registry, and +calls commit_plan on termination. Returns a CommitResult. +""" + +from __future__ import annotations + +import logging +import time +from collections import Counter +from typing import Literal + +from reflexio.server.llm.litellm_client import LiteLLMClient +from reflexio.server.llm.model_defaults import ModelRole +from reflexio.server.llm.tools import ToolLoopTrace, ToolRegistry, run_tool_loop +from reflexio.server.prompt.prompt_manager import PromptManager +from reflexio.server.services.extraction.invariants import commit_plan +from reflexio.server.services.extraction.plan import ( + CommitResult, + ExtractionCtx, + HandlerBundle, +) +from reflexio.server.services.extraction.tools import EXTRACTION_TOOLS + +logger = logging.getLogger(__name__) + + +def _summarise_tool_calls(trace: ToolLoopTrace) -> str: + """Return a compact 'tool_a:2, tool_b:1' string from a ToolLoopTrace. + + Args: + trace (ToolLoopTrace): The completed tool loop trace. + + Returns: + str: Comma-separated name:count pairs ordered by frequency, or '(none)'. + """ + counts = Counter(t.tool_name for t in trace.turns) + return ", ".join(f"{name}:{n}" for name, n in counts.most_common()) or "(none)" + + +def _summarise_usage(trace: ToolLoopTrace) -> str: + """Return a per-model 'model_x: N tokens, $0.0078' string aggregated across all turns. + + A single response's usage is attached to every turn it produced, so this + function deduplicates by (model, prompt_tokens, completion_tokens) to avoid + double-counting when one LLM call produced multiple tool calls. + + Args: + trace (ToolLoopTrace): The completed tool loop trace. + + Returns: + str: Semicolon-separated per-model summaries, or '(none)'. + """ + seen: set[tuple[str, int, int]] = set() + per_model: dict[str, dict[str, float]] = {} + for t in trace.turns: + if t.model is None or t.prompt_tokens is None or t.completion_tokens is None: + continue + key = (t.model, t.prompt_tokens, t.completion_tokens) + if key in seen: + continue + seen.add(key) + bucket = per_model.setdefault(t.model, {"tokens": 0.0, "cost": 0.0}) + bucket["tokens"] += t.total_tokens or 0 + bucket["cost"] += t.cost_usd or 0.0 + if not per_model: + return "(none)" + return "; ".join( + f"{m}: {int(v['tokens'])} tokens, ${v['cost']:.6f}" + for m, v in per_model.items() + ) + + +class ExtractionAgent: + """Single-loop adaptive extraction agent. + + Assembles the seed message from the extraction prompt, drives + ``run_tool_loop`` with a per-entity-kind tool registry, and commits the + accumulated plan via ``commit_plan`` on termination (finish or max_steps). + + Args: + client (LiteLLMClient): LLM client for the underlying tool loop. + storage: BaseStorage handle (read + commit targets). + prompt_manager (PromptManager): Renders the ``extraction_agent`` prompt. + max_steps (int): Cap on tool-calling turns (default 12; see spec §7.2). + registry (ToolRegistry | None): Tool registry to use. Defaults to + ``EXTRACTION_TOOLS`` (backward-compat union of all tools). Production + callers should pass ``PROFILE_EXTRACTION_TOOLS`` or + ``PLAYBOOK_EXTRACTION_TOOLS`` to restrict the LLM to one entity kind. + """ + + def __init__( + self, + *, + client: LiteLLMClient, + storage: object, + prompt_manager: PromptManager, + max_steps: int = 12, + registry: ToolRegistry | None = None, + ) -> None: + self.client = client + self.storage = storage + self.prompt_manager = prompt_manager + self.max_steps = max_steps + self.registry = registry if registry is not None else EXTRACTION_TOOLS + + def run( + self, + *, + user_id: str, + agent_version: str, + extractor_name: str, + extraction_criteria: str, + sessions_text: str, + extraction_kind: Literal["UserProfile", "UserPlaybook"] = "UserProfile", + request_id: str = "", + ) -> CommitResult: + """Run one extraction loop over the given session text. + + Args: + user_id (str): Authenticated user scope. + agent_version (str): Active agent_version for this extractor config. + extractor_name (str): The ``name`` field of the extractor config + (used as an implicit storage filter). + extraction_criteria (str): ``extraction_criteria`` text from the + extractor config, rendered into the agent's prompt. + sessions_text (str): Pre-rendered session transcript. + extraction_kind (Literal["UserProfile", "UserPlaybook"]): Entity + kind this run targets. Rendered into the prompt to scope the + LLM's narrative. Defaults to ``"UserProfile"`` for backward + compat with existing test callers that omit this argument. + request_id (str): Source publish_interaction UUID; embedded into + every profile/playbook this run creates so callers can trace + back to the originating publish. Defaults to "" for test + callers that don't have a publish request in scope. + + Returns: + CommitResult: Includes applied ops, violations, and outcome. + """ + ctx = ExtractionCtx( + user_id=user_id, + agent_version=agent_version, + extractor_name=extractor_name, + request_id=request_id, + ) + bundle = HandlerBundle(storage=self.storage, ctx=ctx) + + prompt = self.prompt_manager.render_prompt( + "extraction_agent", + variables={ + "sessions": sessions_text, + "extraction_criteria": extraction_criteria, + "extraction_kind": extraction_kind, + "max_steps": str(self.max_steps), + }, + ) + + t0 = time.monotonic() + result = run_tool_loop( + client=self.client, + messages=[{"role": "user", "content": prompt}], + registry=self.registry, + model_role=ModelRole.EXTRACTION_AGENT, + max_steps=self.max_steps, + ctx=bundle, + finish_tool_name="finish", + log_label=f"extraction_agent[{extractor_name}]", + ) + + commit = commit_plan(ctx, self.storage, outcome=result.finished_reason) + elapsed_ms = int((time.monotonic() - t0) * 1000) + + logger.info( + "extraction_agent[%s] kind=%s elapsed_ms=%d turns=%d/%d tools={%s} " + "outcome=%s applied=%d violations=%s usage={%s}", + extractor_name, + extraction_kind, + elapsed_ms, + len(result.trace.turns), + self.max_steps, + _summarise_tool_calls(result.trace), + commit.outcome, + len(commit.applied), + sorted({v.code for v in commit.violations}) or "[]", + _summarise_usage(result.trace), + ) + return commit diff --git a/reflexio/server/services/extraction/invariants.py b/reflexio/server/services/extraction/invariants.py new file mode 100644 index 00000000..1b317071 --- /dev/null +++ b/reflexio/server/services/extraction/invariants.py @@ -0,0 +1,303 @@ +"""Plan-level invariants for the agentic-v2 extraction pipeline. + +Invariants are pure functions over ``ExtractionCtx``. Hard violations drop +offending ops from the commit; soft violations are logged and applied. +See spec §6 for the full catalog and severity policy. +""" + +from __future__ import annotations + +import logging + +from reflexio.server.services.extraction.plan import ( + CommitResult, + CreateUserPlaybookOp, + CreateUserProfileOp, + DeleteUserPlaybookOp, + DeleteUserProfileOp, + ExtractionCtx, + Violation, +) + +logger = logging.getLogger(__name__) + +PLAN_SIZE_CAP = 30 + + +# --- Hard invariants --- + + +def inv_A_search_before_create(ctx: ExtractionCtx) -> list[Violation]: # noqa: N802 + """Every CreateOp must be preceded by ≥1 search_* call this run.""" + create_indices = [ + i + for i, op in enumerate(ctx.plan) + if isinstance(op, (CreateUserProfileOp, CreateUserPlaybookOp)) + ] + if create_indices and ctx.search_count == 0: + return [ + Violation( + code="A", + severity="hard", + affected_op_indices=create_indices, + msg="Plan has create ops but no search was performed this run", + ) + ] + return [] + + +def inv_B_delete_known_id(ctx: ExtractionCtx) -> list[Violation]: # noqa: N802 + """Every DeleteOp(id) must reference an id in ctx.known_ids. + + known_ids is populated by search/get/create tool handlers — so deletes + targeting hallucinated ids (agent never saw them) are rejected. + """ + violations: list[Violation] = [] + for i, op in enumerate(ctx.plan): + if ( + isinstance(op, (DeleteUserProfileOp, DeleteUserPlaybookOp)) + and op.id not in ctx.known_ids + ): + violations.append( + Violation( + code="B", + severity="hard", + affected_op_indices=[i], + msg=f"Delete of unknown id {op.id!r}", + ) + ) + return violations + + +def inv_D_plan_size_cap(ctx: ExtractionCtx) -> list[Violation]: # noqa: N802 + """Plan cannot exceed PLAN_SIZE_CAP ops — guards runaway loops.""" + if len(ctx.plan) > PLAN_SIZE_CAP: + overflow = list(range(PLAN_SIZE_CAP, len(ctx.plan))) + return [ + Violation( + code="D", + severity="hard", + affected_op_indices=overflow, + msg=f"Plan size {len(ctx.plan)} exceeds cap {PLAN_SIZE_CAP}", + ) + ] + return [] + + +def inv_F_no_duplicate_deletes(ctx: ExtractionCtx) -> list[Violation]: # noqa: N802 + """Same id cannot be deleted twice in one plan.""" + seen: set[str] = set() + violations: list[Violation] = [] + for i, op in enumerate(ctx.plan): + if isinstance(op, (DeleteUserProfileOp, DeleteUserPlaybookOp)): + if op.id in seen: + violations.append( + Violation( + code="F", + severity="hard", + affected_op_indices=[i], + msg=f"Duplicate delete of id {op.id!r}", + ) + ) + else: + seen.add(op.id) + return violations + + +def inv_J_scope_match(_ctx: ExtractionCtx) -> list[Violation]: # noqa: N802 + """User_id scope is primarily enforced at the storage layer (handlers inject + ctx.user_id). This invariant is a placeholder for future cross-user checks; + for v1 it is a no-op.""" + return [] + + +HARD_INVARIANTS = ( + inv_A_search_before_create, + inv_B_delete_known_id, + inv_D_plan_size_cap, + inv_F_no_duplicate_deletes, + inv_J_scope_match, +) + + +# --- Soft invariants --- + + +def inv_E_no_duplicate_creates(ctx: ExtractionCtx) -> list[Violation]: # noqa: N802 + """Two CreateOps with identical content in one plan = oscillation smell.""" + seen: dict[str, int] = {} + violations: list[Violation] = [] + for i, op in enumerate(ctx.plan): + key = None + if isinstance(op, CreateUserProfileOp): + key = f"profile::{op.content}" + elif isinstance(op, CreateUserPlaybookOp): + key = f"playbook::{op.trigger}::{op.content}" + if key is None: + continue + if key in seen: + violations.append( + Violation( + code="E", + severity="soft", + affected_op_indices=[i], + msg=f"Duplicate create content at op {i}", + ) + ) + else: + seen[key] = i + return violations + + +def inv_H_source_span_present(ctx: ExtractionCtx) -> list[Violation]: # noqa: N802 + """CreateOps must have non-whitespace source_span. + + Schema enforces min_length=1, but whitespace-only slips through — + this is the secondary guard. + """ + violations: list[Violation] = [] + for i, op in enumerate(ctx.plan): + if ( + isinstance(op, (CreateUserProfileOp, CreateUserPlaybookOp)) + and not op.source_span.strip() + ): + violations.append( + Violation( + code="H", + severity="soft", + affected_op_indices=[i], + msg=f"Empty/whitespace source_span on create op {i}", + ) + ) + return violations + + +def inv_K_deletes_without_creates(ctx: ExtractionCtx) -> list[Violation]: # noqa: N802 + """Plan with deletes but no creates is unusual — worth logging.""" + has_delete = any( + isinstance(op, (DeleteUserProfileOp, DeleteUserPlaybookOp)) for op in ctx.plan + ) + has_create = any( + isinstance(op, (CreateUserProfileOp, CreateUserPlaybookOp)) for op in ctx.plan + ) + if has_delete and not has_create: + indices = [ + i + for i, op in enumerate(ctx.plan) + if isinstance(op, (DeleteUserProfileOp, DeleteUserPlaybookOp)) + ] + return [ + Violation( + code="K", + severity="soft", + affected_op_indices=indices, + msg="Plan contains deletes without any matching creates", + ) + ] + return [] + + +SOFT_INVARIANTS = ( + inv_E_no_duplicate_creates, + inv_H_source_span_present, + inv_K_deletes_without_creates, +) + + +# --- Oscillation resolver --- + + +def resolve_tentative_oscillations(plan: list) -> set[int]: + """Return plan indices to drop: create+delete-tentative pairs cancel. + + When the agent creates an entity (issuing a tentative_id) and later + deletes that same tentative_id within the same plan, both ops are + dropped before invariants fire. This is the "oscillated self-correction" + pattern — the agent changed its mind mid-run. + + The tentative_id format is ``tentative::::``, + matching ``_next_tentative_id`` in tools.py which uses ``len(ctx.plan)`` + (the plan length BEFORE the op is appended, i.e. the future index of the op). + + Args: + plan: The accumulated list of PlanOp instances from ctx.plan. + + Returns: + Set of plan indices to exclude from apply. Both the create and the + delete are dropped when a matching pair is found. + """ + drop: set[int] = set() + pending_creates: dict[str, int] = {} + for i, op in enumerate(plan): + if isinstance(op, CreateUserProfileOp): + tentative_id = f"tentative::profile::{i}" + pending_creates[tentative_id] = i + elif isinstance(op, CreateUserPlaybookOp): + tentative_id = f"tentative::user_playbook::{i}" + pending_creates[tentative_id] = i + elif isinstance(op, (DeleteUserProfileOp, DeleteUserPlaybookOp)): + if op.id.startswith("tentative::") and op.id in pending_creates: + drop.add(pending_creates.pop(op.id)) + drop.add(i) + return drop + + +# --- commit_plan --- + + +def commit_plan( + ctx: ExtractionCtx, + storage: object, + *, + outcome: str, # Literal["finish_tool","max_steps","error"] +) -> CommitResult: + """Run all invariants, then apply surviving ops atomically. + + Args: + ctx: Populated ExtractionCtx from the agent loop. + storage: BaseStorage handle for apply. + outcome: How the loop terminated. + + Returns: + CommitResult containing applied ops + all violations (hard + soft). + """ + # Error outcome — discard everything, do not apply + if outcome == "error": + return CommitResult(applied=[], violations=[], outcome="error") + + violations: list[Violation] = [] + for check in HARD_INVARIANTS: + violations.extend(check(ctx)) + for check in SOFT_INVARIANTS: + violations.extend(check(ctx)) + + dropped: set[int] = set() + # Oscillation resolver runs first: matching create+delete-tentative pairs + # cancel before invariants decide what to keep. + dropped.update(resolve_tentative_oscillations(ctx.plan)) + for v in violations: + if v.severity == "hard": + dropped.update(v.affected_op_indices) + + ops_to_apply = [op for i, op in enumerate(ctx.plan) if i not in dropped] + + for v in violations: + logger.info( + "invariant_violation user_id=%s code=%s severity=%s op_indices=%s msg=%s", + ctx.user_id, + v.code, + v.severity, + v.affected_op_indices, + v.msg, + ) + + # Delegate actual storage writes to the tool-handler module (Task 5 wires this in). + # Lazy import so Task 3 can land before tools.py exists. + from reflexio.server.services.extraction.tools import ( + apply_plan_op, # noqa: PLC0415 # type: ignore[import-not-found] + ) + + for op in ops_to_apply: + apply_plan_op(op, storage, ctx) + + return CommitResult(applied=ops_to_apply, violations=violations, outcome=outcome) # type: ignore[arg-type] diff --git a/reflexio/server/services/extraction/plan.py b/reflexio/server/services/extraction/plan.py new file mode 100644 index 00000000..97f91837 --- /dev/null +++ b/reflexio/server/services/extraction/plan.py @@ -0,0 +1,124 @@ +"""Plan-op types, ExtractionCtx, HandlerBundle, and commit-result types for the agentic-v2 pipeline. + +Tool handlers append PlanOp instances to ``ctx.plan`` rather than hitting +storage directly. A deterministic commit stage at ``finish`` (or on +``max_steps``) runs invariants and applies the valid ops atomically. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Annotated, Literal + +from pydantic import BaseModel, ConfigDict, Field + +# Mirrors ProfileTimeToLive — kept as Literal to avoid circular import on enum. +ProfileTTL = Literal[ + "one_day", "one_week", "one_month", "one_quarter", "one_year", "infinity" +] + +PlaybookStrength = Literal["hard", "soft"] + + +class _BasePlanOp(BaseModel): + """Base class for all PlanOp variants. Discriminated union via ``op``.""" + + model_config = ConfigDict(frozen=True) + + +class CreateUserProfileOp(_BasePlanOp): + op: Literal["create_user_profile"] = "create_user_profile" + content: Annotated[str, Field(min_length=1)] + ttl: ProfileTTL + source_span: Annotated[str, Field(min_length=1)] + + +class DeleteUserProfileOp(_BasePlanOp): + op: Literal["delete_user_profile"] = "delete_user_profile" + id: Annotated[str, Field(min_length=1)] + + +class CreateUserPlaybookOp(_BasePlanOp): + op: Literal["create_user_playbook"] = "create_user_playbook" + trigger: Annotated[str, Field(min_length=1)] + content: Annotated[str, Field(min_length=1)] + rationale: str = "" + strength: PlaybookStrength = "soft" + source_span: Annotated[str, Field(min_length=1)] + + +class DeleteUserPlaybookOp(_BasePlanOp): + op: Literal["delete_user_playbook"] = "delete_user_playbook" + id: Annotated[str, Field(min_length=1)] + + +PlanOp = Annotated[ + CreateUserProfileOp + | DeleteUserProfileOp + | CreateUserPlaybookOp + | DeleteUserPlaybookOp, + Field(discriminator="op"), +] + + +@dataclass +class ExtractionCtx: + """Per-run state for the extraction agent. + + Attributes: + user_id: Authenticated user the run is scoped to. + agent_version: Agent version from the active config. + extractor_name: Optional per-extractor scope filter. + request_id: Source publish_interaction request UUID — embedded into + every profile/playbook this run creates so retrieval can trace + back to the originating session. Empty string when called from + test contexts that don't have a publish request. + plan: Accumulated PlanOps awaiting commit. + known_ids: Ids the agent has legitimately seen (from search/get/create + handlers). Invariant B checks delete ids against this set. + search_count: Number of search_* tool calls. Invariant A gates on this. + finished: True once the agent calls the ``finish`` tool. + """ + + user_id: str + agent_version: str + extractor_name: str | None = None + request_id: str = "" + plan: list = field( + default_factory=list + ) # list[PlanOp] — type-erased to avoid forward-ref issues + known_ids: set[str] = field(default_factory=set) + search_count: int = 0 + finished: bool = False + search_answer: str | None = None + + +@dataclass(slots=True) +class HandlerBundle: + """Glue so tool handlers can access both storage and ctx through one param. + + The run_tool_loop primitive passes a single ``ctx`` param to tool handlers; + handlers in tools.py need both a BaseStorage handle and an ExtractionCtx. + Both ExtractionAgent and SearchAgent build one of these before driving + the loop. + + Args: + storage: BaseStorage handle. + ctx: ExtractionCtx with per-run state. + """ + + storage: object + ctx: ExtractionCtx + + +class Violation(BaseModel): + code: Literal["A", "B", "D", "E", "F", "H", "J", "K"] + severity: Literal["hard", "soft"] + affected_op_indices: list[int] + msg: str + + +class CommitResult(BaseModel): + applied: list[PlanOp] + violations: list[Violation] + outcome: Literal["finish_tool", "max_steps", "error"] diff --git a/reflexio/server/services/extraction/tools.py b/reflexio/server/services/extraction/tools.py new file mode 100644 index 00000000..554bb865 --- /dev/null +++ b/reflexio/server/services/extraction/tools.py @@ -0,0 +1,1142 @@ +"""Atomic tool handlers for the agentic-v2 extraction + search pipelines. + +Each handler: + - Receives args (Pydantic model validated by ToolRegistry) + - Receives (storage, ctx) + - Calls an existing BaseStorage method + - Returns a dict projection suitable for the LLM + +Read handlers populate ctx.known_ids (for invariant B) and ctx.search_count +(for invariant A). Mutating handlers (Task 5) append PlanOps to ctx.plan +without hitting storage; commit_plan applies them via apply_plan_op after +invariants pass. +""" + +from __future__ import annotations + +import uuid +from datetime import UTC, datetime +from typing import Annotated, Any, Literal + +from pydantic import BaseModel, Field + +from reflexio.models.api_schema.domain.entities import ( + Status, + UserPlaybook, + UserProfile, +) +from reflexio.models.api_schema.domain.enums import ProfileTimeToLive +from reflexio.models.api_schema.retriever_schema import ( + SearchAgentPlaybookRequest, + SearchMode, + SearchUserPlaybookRequest, + SearchUserProfileRequest, +) +from reflexio.models.config_schema import SearchOptions +from reflexio.server.services.extraction.plan import ( + CreateUserPlaybookOp, + CreateUserProfileOp, + DeleteUserPlaybookOp, + DeleteUserProfileOp, + ExtractionCtx, + PlaybookStrength, + ProfileTTL, +) +from reflexio.server.services.profile.profile_generation_service_utils import ( + calculate_expiration_timestamp, +) + +TOP_K_CAP = 25 + + +# ==================================================================== +# Arg schemas (what the LLM emits) +# ==================================================================== + + +class SearchUserProfilesArgs(BaseModel): + """Semantic/keyword search the current user's profiles, with optional + second-stage rerank (pipe-equivalent of search → rerank). + + One-stage usage (the common case): supply ``query`` and ``top_k``. The + server runs hybrid retrieval (BM25 + vector via RRF) over-fetches a wider + candidate pool, then cross-encoder reranks by ``query`` and returns the + top ``top_k``. + + Two-stage refinement: also supply ``refine_with``. The server runs the + same broad retrieval by ``query``, then reranks the candidates by + ``refine_with`` instead. Lets you broadly fetch ("bike maintenance") then + narrow on a specific facet of interest ("dollar amounts spent") without + transcribing candidate ids back through the model. + """ + + query: Annotated[str, Field(min_length=1)] + top_k: int = 10 + refine_with: str | None = None + + +class GetUserProfileArgs(BaseModel): + """Retrieve a single UserProfile by id.""" + + id: Annotated[str, Field(min_length=1)] + + +class SearchUserPlaybooksArgs(BaseModel): + """Search the current user's playbooks.""" + + query: Annotated[str, Field(min_length=1)] + top_k: int = 10 + status: Literal["current", "pending", "archived"] = "current" + + +class GetUserPlaybookArgs(BaseModel): + """Retrieve a single UserPlaybook by id.""" + + id: Annotated[str, Field(min_length=1)] + + +class SearchAgentPlaybooksArgs(BaseModel): + """Search agent-version-scoped playbooks (read-only; search pipeline only).""" + + query: Annotated[str, Field(min_length=1)] + top_k: int = 10 + status: Literal["current", "pending", "archived"] = "current" + + +class GetAgentPlaybookArgs(BaseModel): + """Retrieve a single AgentPlaybook by id.""" + + id: Annotated[str, Field(min_length=1)] + + +class ReadSessionTextArgs(BaseModel): + """Retrieve a verbatim excerpt from a session by matching a span.""" + + session_id: Annotated[str, Field(min_length=1)] + span: Annotated[str, Field(min_length=1)] + + +class RerankUserProfilesArgs(BaseModel): + """Rerank a list of profile ids by query relevance using a cross-encoder. + + Use after `search_user_profiles` when the initial results are noisy and + you need to surface the most semantically relevant ones to the question. + """ + + query: Annotated[str, Field(min_length=1)] + profile_ids: list[str] + top_k: int = 10 + + +class StorageStatsArgs(BaseModel): + """Get a quick count of how many profiles/playbooks the user has and the date range. + + Useful for sizing search top_k appropriately before retrieval. + """ + + +# Mutating arg models (handlers in Task 5) +class CreateUserProfileArgs(BaseModel): + """Propose creating a new UserProfile record.""" + + content: Annotated[str, Field(min_length=1)] + ttl: ProfileTTL + source_span: Annotated[str, Field(min_length=1)] + + +class DeleteUserProfileArgs(BaseModel): + """Propose deleting an existing UserProfile by id.""" + + id: Annotated[str, Field(min_length=1)] + + +class CreateUserPlaybookArgs(BaseModel): + """Propose creating a new UserPlaybook record.""" + + trigger: Annotated[str, Field(min_length=1)] + content: Annotated[str, Field(min_length=1)] + rationale: str = "" + strength: PlaybookStrength = "soft" + source_span: Annotated[str, Field(min_length=1)] + + +class DeleteUserPlaybookArgs(BaseModel): + """Propose deleting an existing UserPlaybook by id.""" + + id: Annotated[str, Field(min_length=1)] + + +class FinishArgs(BaseModel): + """Terminate the loop.""" + + +class SearchFinishArgs(BaseModel): + """Terminate the search loop, optionally with a final answer. + + ``answer`` is opt-in: when the host runs the agent in search-only mode + (``enable_agent_answer=False``) the agent is instructed to call ``finish()`` + without an answer; the host synthesizes the final response itself from the + entities the agent harvested. + """ + + answer: str | None = None + + +# ==================================================================== +# Helpers +# ==================================================================== + + +def _cap_top_k(k: int) -> int: + return min(max(1, k), TOP_K_CAP) + + +def _maybe_embed_query(storage: Any, query: str) -> list[float] | None: + """Compute a query embedding via the storage backend's embedder. + + Returns ``None`` on any failure (backend doesn't expose ``_get_embedding``, + embedding provider unavailable, or embed call raises). Without an embedding, + storage downgrades HYBRID/VECTOR search to FTS-only — the classic search + path (``unified_search_service.py:151-158``) uses the same helper pattern. + + Args: + storage (Any): BaseStorage instance. + query (str): The search query to embed. + + Returns: + list[float] | None: The embedding vector, or ``None`` when unavailable. + """ + embed_fn = getattr(storage, "_get_embedding", None) + if embed_fn is None: + return None + try: + return embed_fn(query) + except Exception: # noqa: BLE001 — embedder failures must not break search + return None + + +def _status_from_str(s: str) -> Status | None: + return {"current": None, "pending": Status.PENDING, "archived": Status.ARCHIVED}[s] + + +def _project_profile_for_llm(p: Any) -> dict[str, Any]: + return { + "id": getattr(p, "profile_id", "") or "", + "content": p.content, + "ttl": p.profile_time_to_live, + "last_modified": p.last_modified_timestamp, + "source_span": getattr(p, "source_span", None), + } + + +def _project_user_playbook_for_llm(pb: Any) -> dict[str, Any]: + return { + "id": str(pb.user_playbook_id), + "trigger": pb.trigger, + "content": pb.content, + "rationale": pb.rationale, + "last_modified": getattr(pb, "created_at", 0), + } + + +def _project_agent_playbook_for_llm(pb: Any) -> dict[str, Any]: + return { + "id": str(pb.agent_playbook_id), + "trigger": pb.trigger, + "content": pb.content, + "rationale": pb.rationale, + "playbook_status": getattr(pb, "playbook_status", None), + "last_modified": getattr(pb, "created_at", 0), + } + + +# ==================================================================== +# Read handlers +# ==================================================================== + + +def _handle_search_user_profiles( + args: SearchUserProfilesArgs, storage: Any, ctx: ExtractionCtx +) -> dict[str, Any]: + """Search the current user's profiles and bump search_count. + + Two-stage retrieval: hybrid (BM25 + vector via RRF) over-fetches a wider + candidate pool, then a cross-encoder rerank scores ``(query, content)`` + pairs and returns the top ``args.top_k`` by descending rerank score. + + The over-fetch + rerank pattern fixes a class of failures where the + bi-encoder ranks the right profile at #2-#15 by cosine but the top-1 + is a near-duplicate that the answer LLM picks first. Cross-encoders + model query-document interaction (e.g. matching a numeric question to + the profile that contains a number, not just topic similarity). + + Args: + args (SearchUserProfilesArgs): Query and top_k. + storage (Any): BaseStorage instance. + ctx (ExtractionCtx): Per-run state; search_count incremented in place. + + Returns: + dict[str, Any]: ``{"hits": [...]}`` with LLM-facing profile projections. + """ + final_k = _cap_top_k(args.top_k) + # When `refine_with` is set, over-fetch for rerank headroom. Otherwise + # we trust the hybrid retrieval ranking — empirically (held-out #10 vs + # #12) always-on cross-encoder rerank slightly hurt T-R: the rerank + # model was trained on MS MARCO web passages and ranks declarative + # facts differently from temporal-arithmetic reasoning needs. Making + # rerank opt-in via `refine_with` matches the agent's intent: rerank + # only when the agent has something specific to refine on. + use_rerank = args.refine_with is not None + fetch_k = ( + min(max(final_k * 3, 30), 50) if use_rerank else final_k + ) + + request = SearchUserProfileRequest( + query=args.query, + user_id=ctx.user_id, + top_k=fetch_k, + ) + hits = storage.search_user_profile( + request, + query_embedding=_maybe_embed_query(storage, args.query), + ) + ctx.search_count += 1 + + # Two-stage refinement when `refine_with` is supplied — server-side pipe + # of `search → rerank` without round-tripping candidate ids back through + # the agent. Lazy import so unit-test collection stays fast; the model + # is module-cached after first load. + if use_rerank and len(hits) > final_k: + try: + from reflexio.server.llm.rerank import score_pairs + + scores = score_pairs(args.refine_with, [h.content for h in hits]) # type: ignore[arg-type] + ranked = sorted( + zip(hits, scores, strict=True), + key=lambda pair: pair[1], + reverse=True, + ) + hits = [h for h, _ in ranked[:final_k]] + except Exception: # noqa: BLE001 — fall back to hybrid order on failure + hits = hits[:final_k] + + for h in hits: + pid = getattr(h, "profile_id", "") or "" + if pid: + ctx.known_ids.add(pid) + return {"hits": [_project_profile_for_llm(h) for h in hits]} + + +def _handle_get_user_profile( + args: GetUserProfileArgs, storage: Any, ctx: ExtractionCtx +) -> dict[str, Any]: + """Retrieve a single UserProfile by id without bumping search_count. + + Args: + args (GetUserProfileArgs): Profile id to look up. + storage (Any): BaseStorage instance. + ctx (ExtractionCtx): Per-run state; known_ids updated on hit. + + Returns: + dict[str, Any]: ``{"profile": {...}}`` on hit, ``{"error": "not found"}`` on miss. + """ + all_profiles = storage.get_user_profile(ctx.user_id) + for p in all_profiles: + if (getattr(p, "profile_id", "") or "") == args.id: + ctx.known_ids.add(args.id) + return {"profile": _project_profile_for_llm(p)} + return {"error": "not found"} + + +def _handle_search_user_playbooks( + args: SearchUserPlaybooksArgs, storage: Any, ctx: ExtractionCtx +) -> dict[str, Any]: + """Search the current user's playbooks and bump search_count. + + Args: + args (SearchUserPlaybooksArgs): Query, top_k, and status filter. + storage (Any): BaseStorage instance. + ctx (ExtractionCtx): Per-run state; search_count and known_ids updated. + + Returns: + dict[str, Any]: ``{"hits": [...]}`` with LLM-facing playbook projections. + """ + request = SearchUserPlaybookRequest( + query=args.query, + user_id=ctx.user_id, + agent_version=ctx.agent_version, + top_k=_cap_top_k(args.top_k), + status_filter=[_status_from_str(args.status)], + search_mode=SearchMode.HYBRID, + threshold=0.4, + ) + if ctx.extractor_name: + request.playbook_name = ctx.extractor_name + hits = storage.search_user_playbooks( + request, + options=SearchOptions(query_embedding=_maybe_embed_query(storage, args.query)), + ) + ctx.search_count += 1 + for h in hits: + ctx.known_ids.add(str(h.user_playbook_id)) + return {"hits": [_project_user_playbook_for_llm(h) for h in hits]} + + +def _handle_get_user_playbook( + args: GetUserPlaybookArgs, storage: Any, ctx: ExtractionCtx +) -> dict[str, Any]: + """Retrieve a single UserPlaybook by id without bumping search_count. + + Args: + args (GetUserPlaybookArgs): Playbook id to look up. + storage (Any): BaseStorage instance. + ctx (ExtractionCtx): Per-run state; known_ids updated on hit. + + Returns: + dict[str, Any]: ``{"playbook": {...}}`` on hit, ``{"error": "not found"}`` on miss. + """ + candidates = storage.get_user_playbooks( + user_id=ctx.user_id, agent_version=ctx.agent_version + ) + for pb in candidates: + if str(pb.user_playbook_id) == args.id: + ctx.known_ids.add(args.id) + return {"playbook": _project_user_playbook_for_llm(pb)} + return {"error": "not found"} + + +def _handle_search_agent_playbooks( + args: SearchAgentPlaybooksArgs, storage: Any, ctx: ExtractionCtx +) -> dict[str, Any]: + """Search agent-version-scoped playbooks and bump search_count. + + Args: + args (SearchAgentPlaybooksArgs): Query, top_k, and status filter. + storage (Any): BaseStorage instance. + ctx (ExtractionCtx): Per-run state; search_count and known_ids updated. + + Returns: + dict[str, Any]: ``{"hits": [...]}`` with LLM-facing agent playbook projections. + """ + request = SearchAgentPlaybookRequest( + query=args.query, + agent_version=ctx.agent_version, + top_k=_cap_top_k(args.top_k), + status_filter=[_status_from_str(args.status)], + search_mode=SearchMode.HYBRID, + threshold=0.4, + ) + if ctx.extractor_name: + request.playbook_name = ctx.extractor_name + hits = storage.search_agent_playbooks( + request, + options=SearchOptions(query_embedding=_maybe_embed_query(storage, args.query)), + ) + ctx.search_count += 1 + for h in hits: + ctx.known_ids.add(str(h.agent_playbook_id)) + return {"hits": [_project_agent_playbook_for_llm(h) for h in hits]} + + +def _handle_get_agent_playbook( + args: GetAgentPlaybookArgs, storage: Any, ctx: ExtractionCtx +) -> dict[str, Any]: + """Retrieve a single AgentPlaybook by id without bumping search_count. + + Args: + args (GetAgentPlaybookArgs): Agent playbook id to look up. + storage (Any): BaseStorage instance. + ctx (ExtractionCtx): Per-run state; known_ids updated on hit. + + Returns: + dict[str, Any]: ``{"playbook": {...}}`` on hit, ``{"error": "not found"}`` on miss. + """ + candidates = storage.get_agent_playbooks(agent_version=ctx.agent_version) + for pb in candidates: + if str(pb.agent_playbook_id) == args.id: + ctx.known_ids.add(args.id) + return {"playbook": _project_agent_playbook_for_llm(pb)} + return {"error": "not found"} + + +def _handle_read_session_text( + args: ReadSessionTextArgs, + storage: Any, + ctx: ExtractionCtx, # noqa: ARG001 +) -> dict[str, Any]: + """Return the closest verbatim match of ``span`` inside ``session_id``. + + Args: + args (ReadSessionTextArgs): Session id and span string to match. + storage (Any): BaseStorage instance; must have ``get_interactions_by_session``. + ctx (ExtractionCtx): Per-run state (unused for reads, present for consistency). + + Returns: + dict[str, Any]: ``{"excerpt": str}`` on hit, ``{"error": str}`` on miss or + when the storage backend doesn't support this method. + """ + try: + interactions = storage.get_interactions_by_session(args.session_id) + except AttributeError: + return {"error": "read_session_text requires get_interactions_by_session"} + matches = [ + i.content for i in interactions if args.span.strip() in (i.content or "") + ] + if not matches: + return {"error": "span not found"} + return {"excerpt": matches[0]} + + +def _handle_rerank_user_profiles( + args: RerankUserProfilesArgs, storage: Any, ctx: ExtractionCtx +) -> dict[str, Any]: + """Rerank known profile ids with a local cross-encoder. + + Fetches the candidate profiles (scoped to ``ctx.user_id``), scores + ``(query, content)`` pairs, and returns the top_k by descending score. + Bumps ``search_count`` so reranking still counts against the search + budget enforced by invariant A. + + Args: + args (RerankUserProfilesArgs): Query, candidate ids, and top_k. + storage (Any): BaseStorage instance. + ctx (ExtractionCtx): Per-run state; ``search_count`` and + ``known_ids`` updated in place. + + Returns: + dict[str, Any]: ``{"hits": [...]}`` with LLM-facing profile + projections sorted by descending relevance. + """ + if not args.profile_ids: + ctx.search_count += 1 + return {"hits": []} + all_profiles = storage.get_user_profile(ctx.user_id) + wanted = set(args.profile_ids) + candidates = [ + p for p in all_profiles if (getattr(p, "profile_id", "") or "") in wanted + ] + ctx.search_count += 1 + if not candidates: + return {"hits": []} + # Lazy import — keeps unit-test collection fast and avoids loading + # torch when no rerank tool call is made in a given run. + from reflexio.server.llm.rerank import score_pairs + + scores = score_pairs(args.query, [p.content for p in candidates]) + ranked = sorted( + zip(candidates, scores, strict=True), + key=lambda pair: pair[1], + reverse=True, + ) + top = [profile for profile, _score in ranked[: _cap_top_k(args.top_k)]] + for h in top: + pid = getattr(h, "profile_id", "") or "" + if pid: + ctx.known_ids.add(pid) + return {"hits": [_project_profile_for_llm(h) for h in top]} + + +def _handle_storage_stats( + args: StorageStatsArgs, # noqa: ARG001 + storage: Any, + ctx: ExtractionCtx, +) -> dict[str, Any]: + """Return profile/playbook counts and modified-time range for ``ctx.user_id``. + + Does not bump ``search_count`` — this is metadata, not retrieval. + + Args: + args (StorageStatsArgs): No fields (sentinel call). + storage (Any): BaseStorage instance. + ctx (ExtractionCtx): Per-run state; only ``user_id`` is read. + + Returns: + dict[str, Any]: Counts and ISO 8601 timestamps. Timestamps are + ``None`` when the user has no profiles. + """ + profiles = storage.get_user_profile(ctx.user_id) + if profiles: + timestamps = [p.last_modified_timestamp for p in profiles] + oldest_ts = datetime.fromtimestamp(min(timestamps), tz=UTC).isoformat() + newest_ts = datetime.fromtimestamp(max(timestamps), tz=UTC).isoformat() + else: + oldest_ts = None + newest_ts = None + playbook_count = storage.count_user_playbooks(user_id=ctx.user_id) + return { + "profile_count": len(profiles), + "playbook_count": playbook_count, + "oldest_profile_modified": oldest_ts, + "newest_profile_modified": newest_ts, + } + + +def _next_tentative_id(ctx: ExtractionCtx, kind: str) -> str: + """Generate a deterministic tentative-id scoped to this run. + + Format: ``tentative::::`` — unique within the run, + recognizable in logs. + + Args: + ctx (ExtractionCtx): Per-run state; plan length used as counter. + kind (str): Entity type label, e.g. ``"profile"`` or ``"playbook"``. + + Returns: + str: Tentative id string unique within this run. + """ + return f"tentative::{kind}::{len(ctx.plan)}" + + +def new_profile_id() -> str: + """Generate a short (12-char hex) profile id. + + Format chosen for LLM tool-call reliability: full ``str(uuid.uuid4())`` + is 36 characters of hex+dashes, error-prone for smaller LLMs to copy + verbatim from a search result back into a delete/update tool arg. + Twelve hex chars is short enough for high-fidelity copy and long enough + that birthday-paradox collision probability is vanishingly small at any + realistic per-user scale (16^12 ≈ 2.8e14 unique values; PRIMARY KEY + constraint catches the rare collision). + + Profile ids are LLM-facing because the agent receives them in + ``search_user_profiles`` results and must echo them back when calling + ``delete_user_profile`` / ``update_user_profile``. Playbook ids are + INTEGER autoincrements and don't have this problem. + + Returns: + str: 12 lowercase hex characters, e.g. ``"b8a3f74e2c91"``. + """ + return uuid.uuid4().hex[:12] + + +# ==================================================================== +# Mutating handlers — append to ctx.plan, no storage writes +# ==================================================================== + + +def _handle_create_user_profile( + args: CreateUserProfileArgs, + storage: Any, # noqa: ARG001 + ctx: ExtractionCtx, +) -> dict[str, Any]: + """Propose creating a new UserProfile; appends CreateUserProfileOp to ctx.plan. + + No storage write occurs here — apply_plan_op commits ops after invariants pass. + + Args: + args (CreateUserProfileArgs): Validated args from the LLM tool call. + storage (Any): BaseStorage instance (unused; present for handler signature consistency). + ctx (ExtractionCtx): Per-run state; plan and known_ids are mutated. + + Returns: + dict[str, Any]: ``{"op_idx": int, "tentative_id": str}`` for LLM feedback. + """ + tid = _next_tentative_id(ctx, "profile") + op = CreateUserProfileOp( + content=args.content, ttl=args.ttl, source_span=args.source_span + ) + ctx.plan.append(op) + ctx.known_ids.add(tid) + return {"op_idx": len(ctx.plan) - 1, "tentative_id": tid} + + +def _handle_delete_user_profile( + args: DeleteUserProfileArgs, + storage: Any, # noqa: ARG001 + ctx: ExtractionCtx, +) -> dict[str, Any]: + """Propose deleting an existing UserProfile; appends DeleteUserProfileOp to ctx.plan. + + No storage write occurs here. + + Args: + args (DeleteUserProfileArgs): Validated args from the LLM tool call. + storage (Any): BaseStorage instance (unused). + ctx (ExtractionCtx): Per-run state; plan is mutated. + + Returns: + dict[str, Any]: ``{"op_idx": int}`` for LLM feedback. + """ + op = DeleteUserProfileOp(id=args.id) + ctx.plan.append(op) + return {"op_idx": len(ctx.plan) - 1} + + +def _handle_create_user_playbook( + args: CreateUserPlaybookArgs, + storage: Any, # noqa: ARG001 + ctx: ExtractionCtx, +) -> dict[str, Any]: + """Propose creating a new UserPlaybook; appends CreateUserPlaybookOp to ctx.plan. + + No storage write occurs here. + + Args: + args (CreateUserPlaybookArgs): Validated args from the LLM tool call. + storage (Any): BaseStorage instance (unused). + ctx (ExtractionCtx): Per-run state; plan and known_ids are mutated. + + Returns: + dict[str, Any]: ``{"op_idx": int, "tentative_id": str}`` for LLM feedback. + """ + tid = _next_tentative_id(ctx, "playbook") + op = CreateUserPlaybookOp( + trigger=args.trigger, + content=args.content, + rationale=args.rationale, + strength=args.strength, + source_span=args.source_span, + ) + ctx.plan.append(op) + ctx.known_ids.add(tid) + return {"op_idx": len(ctx.plan) - 1, "tentative_id": tid} + + +def _handle_delete_user_playbook( + args: DeleteUserPlaybookArgs, + storage: Any, # noqa: ARG001 + ctx: ExtractionCtx, +) -> dict[str, Any]: + """Propose deleting an existing UserPlaybook; appends DeleteUserPlaybookOp to ctx.plan. + + No storage write occurs here. + + Args: + args (DeleteUserPlaybookArgs): Validated args from the LLM tool call. + storage (Any): BaseStorage instance (unused). + ctx (ExtractionCtx): Per-run state; plan is mutated. + + Returns: + dict[str, Any]: ``{"op_idx": int}`` for LLM feedback. + """ + op = DeleteUserPlaybookOp(id=args.id) + ctx.plan.append(op) + return {"op_idx": len(ctx.plan) - 1} + + +def _handle_finish( + args: FinishArgs, # noqa: ARG001 + storage: Any, # noqa: ARG001 + ctx: ExtractionCtx, +) -> dict[str, Any]: + """Terminate the agent loop. + + Args: + args (FinishArgs): No fields (sentinel call). + storage (Any): BaseStorage instance (unused). + ctx (ExtractionCtx): Per-run state; ``finished`` is set to True. + + Returns: + dict[str, Any]: ``{"finished": True}``. + """ + ctx.finished = True + return {"finished": True} + + +def _handle_search_finish( + args: SearchFinishArgs, + storage: Any, # noqa: ARG001 + ctx: ExtractionCtx, +) -> dict[str, Any]: + """Terminate the search loop and stash the optional answer on ctx. + + Args: + args (SearchFinishArgs): Contains the optional final answer string. When + None (search-only mode) only the termination signal is emitted. + storage (Any): BaseStorage instance (unused). + ctx (ExtractionCtx): Per-run state; ``finished`` set True and + ``search_answer`` populated for retrieval by SearchAgent. + + Returns: + dict[str, Any]: ``{"finished": True, "answer": str | None}``. + """ + ctx.finished = True + ctx.search_answer = args.answer + return {"finished": True, "answer": args.answer} + + +# ==================================================================== +# Commit-stage: apply a PlanOp to storage +# ==================================================================== + + +def apply_plan_op(op: Any, storage: Any, ctx: ExtractionCtx) -> None: + """Deterministically apply one PlanOp to storage. Called by commit_plan. + + Args: + op (Any): A PlanOp variant (CreateUserProfileOp, DeleteUserProfileOp, + CreateUserPlaybookOp, DeleteUserPlaybookOp). + storage (Any): BaseStorage handle. + ctx (ExtractionCtx): Per-run state providing user_id, agent_version, + extractor_name. + + Raises: + TypeError: If ``op`` is not a recognised PlanOp type. + """ + if isinstance(op, CreateUserProfileOp): + now_ts = int(datetime.now(UTC).timestamp()) + ttl = ProfileTimeToLive(op.ttl) + storage.add_user_profile( + ctx.user_id, + [ + UserProfile( + user_id=ctx.user_id, + profile_id=new_profile_id(), + content=op.content, + profile_time_to_live=ttl, + last_modified_timestamp=now_ts, + expiration_timestamp=calculate_expiration_timestamp(now_ts, ttl), + source=f"agentic_v2/{ctx.extractor_name or 'default'}", + source_span=op.source_span, + generated_from_request_id=ctx.request_id, + ) + ], + ) + elif isinstance(op, DeleteUserProfileOp): + storage.delete_profiles_by_ids([op.id]) + elif isinstance(op, CreateUserPlaybookOp): + storage.save_user_playbooks( + [ + UserPlaybook( + user_playbook_id=0, # storage assigns + user_id=ctx.user_id, + agent_version=ctx.agent_version, + request_id=ctx.request_id, + playbook_name=ctx.extractor_name or "default", + content=op.content, + trigger=op.trigger, + rationale=op.rationale, + source_span=op.source_span, + ) + ] + ) + elif isinstance(op, DeleteUserPlaybookOp): + try: + playbook_id = int(op.id) + except (TypeError, ValueError) as e: + raise TypeError( + f"DeleteUserPlaybookOp.id must be a numeric string, got {op.id!r}" + ) from e + storage.delete_user_playbooks_by_ids([playbook_id]) + else: + raise TypeError(f"Unknown PlanOp: {type(op).__name__}") + + +# ==================================================================== +# Bundle adapter + Tool registries +# ==================================================================== + +from collections.abc import Callable # noqa: E402 + +from reflexio.server.llm.tools import Tool, ToolRegistry # noqa: E402 + + +def _bundle_handler( + inner: Callable[[Any, Any, Any], dict[str, Any]], +) -> Callable[[Any, Any], dict[str, Any]]: + """Adapt a (args, storage, ctx)-style handler to (args, bundle) for run_tool_loop. + + ExtractionAgent and SearchAgent build a HandlerBundle with .storage and + .ctx attributes; this adapter unpacks them so the registry accepts our + 3-arg handlers. + + Args: + inner (Callable[[Any, Any, Any], dict[str, Any]]): A handler callable + with signature ``(args, storage, ctx) -> dict``. + + Returns: + Callable[[Any, Any], dict[str, Any]]: A 2-arg callable + ``(args, bundle) -> dict`` compatible with ``Tool.handler``. + """ + + def wrapped(args: Any, bundle: Any) -> dict[str, Any]: + return inner(args, bundle.storage, bundle.ctx) + + return wrapped + + +_READ_TOOLS = [ + Tool( + name="search_user_profiles", + args_model=SearchUserProfilesArgs, + handler=_bundle_handler(_handle_search_user_profiles), + ), + Tool( + name="get_user_profile", + args_model=GetUserProfileArgs, + handler=_bundle_handler(_handle_get_user_profile), + ), + Tool( + name="search_user_playbooks", + args_model=SearchUserPlaybooksArgs, + handler=_bundle_handler(_handle_search_user_playbooks), + ), + Tool( + name="get_user_playbook", + args_model=GetUserPlaybookArgs, + handler=_bundle_handler(_handle_get_user_playbook), + ), + Tool( + name="search_agent_playbooks", + args_model=SearchAgentPlaybooksArgs, + handler=_bundle_handler(_handle_search_agent_playbooks), + ), + Tool( + name="get_agent_playbook", + args_model=GetAgentPlaybookArgs, + handler=_bundle_handler(_handle_get_agent_playbook), + ), + Tool( + name="read_session_text", + args_model=ReadSessionTextArgs, + handler=_bundle_handler(_handle_read_session_text), + ), +] + +_FINISH_TOOL = Tool( + name="finish", + args_model=FinishArgs, + handler=_bundle_handler(_handle_finish), +) + +PROFILE_EXTRACTION_TOOLS = ToolRegistry( + [ + *_READ_TOOLS, + Tool( + name="create_user_profile", + args_model=CreateUserProfileArgs, + handler=_bundle_handler(_handle_create_user_profile), + ), + Tool( + name="delete_user_profile", + args_model=DeleteUserProfileArgs, + handler=_bundle_handler(_handle_delete_user_profile), + ), + _FINISH_TOOL, + ] +) + +PLAYBOOK_EXTRACTION_TOOLS = ToolRegistry( + [ + *_READ_TOOLS, + Tool( + name="create_user_playbook", + args_model=CreateUserPlaybookArgs, + handler=_bundle_handler(_handle_create_user_playbook), + ), + Tool( + name="delete_user_playbook", + args_model=DeleteUserPlaybookArgs, + handler=_bundle_handler(_handle_delete_user_playbook), + ), + _FINISH_TOOL, + ] +) + +# Backward-compat alias: exposes all four create/delete tools. +# New production code should use PROFILE_EXTRACTION_TOOLS or +# PLAYBOOK_EXTRACTION_TOOLS to restrict the LLM to the correct entity kind. +EXTRACTION_TOOLS = ToolRegistry( + [ + *_READ_TOOLS, + Tool( + name="create_user_profile", + args_model=CreateUserProfileArgs, + handler=_bundle_handler(_handle_create_user_profile), + ), + Tool( + name="delete_user_profile", + args_model=DeleteUserProfileArgs, + handler=_bundle_handler(_handle_delete_user_profile), + ), + Tool( + name="create_user_playbook", + args_model=CreateUserPlaybookArgs, + handler=_bundle_handler(_handle_create_user_playbook), + ), + Tool( + name="delete_user_playbook", + args_model=DeleteUserPlaybookArgs, + handler=_bundle_handler(_handle_delete_user_playbook), + ), + _FINISH_TOOL, + ] +) + + +# ==================================================================== +# Multi-stage fallback schema for non-tool-calling models +# ==================================================================== +# +# When the search-agent model lacks native tool-calling (e.g. +# minimax/MiniMax-M2.7), `run_tool_loop` drives one structured-output +# call per turn using `SearchAgentTurnPlan` as the response_format. The +# server parses the result, dispatches `next_call` against `SEARCH_TOOLS`, +# appends the tool result to the message history, and loops until +# `next_call.tool == "finish"` or `max_steps` is exhausted. This +# preserves observe-decide-act semantics that single-shot fallback +# (which planned all calls upfront) could not. +# +# The discriminated union mirrors the `args_model` of every tool in +# `SEARCH_TOOLS`. Field names match the existing tool args so we can +# convert each variant directly to the dispatch JSON via +# `model_dump(exclude={"tool"})`. + + +class _CallSearchUserProfiles(BaseModel): + """Multi-stage variant: call `search_user_profiles`.""" + + tool: Literal["search_user_profiles"] + query: Annotated[str, Field(min_length=1)] + top_k: int = 10 + refine_with: str | None = None + + +class _CallSearchUserPlaybooks(BaseModel): + """Multi-stage variant: call `search_user_playbooks`.""" + + tool: Literal["search_user_playbooks"] + query: Annotated[str, Field(min_length=1)] + top_k: int = 10 + status: Literal["current", "pending", "archived"] = "current" + + +class _CallSearchAgentPlaybooks(BaseModel): + """Multi-stage variant: call `search_agent_playbooks`.""" + + tool: Literal["search_agent_playbooks"] + query: Annotated[str, Field(min_length=1)] + top_k: int = 10 + status: Literal["current", "pending", "archived"] = "current" + + +class _CallGetUserProfile(BaseModel): + """Multi-stage variant: call `get_user_profile`.""" + + tool: Literal["get_user_profile"] + id: Annotated[str, Field(min_length=1)] + + +class _CallGetUserPlaybook(BaseModel): + """Multi-stage variant: call `get_user_playbook`.""" + + tool: Literal["get_user_playbook"] + id: Annotated[str, Field(min_length=1)] + + +class _CallGetAgentPlaybook(BaseModel): + """Multi-stage variant: call `get_agent_playbook`.""" + + tool: Literal["get_agent_playbook"] + id: Annotated[str, Field(min_length=1)] + + +class _CallReadSessionText(BaseModel): + """Multi-stage variant: call `read_session_text`.""" + + tool: Literal["read_session_text"] + session_id: Annotated[str, Field(min_length=1)] + span: Annotated[str, Field(min_length=1)] + + +class _CallStorageStats(BaseModel): + """Multi-stage variant: call `storage_stats` (no args).""" + + tool: Literal["storage_stats"] + + +class _CallFinish(BaseModel): + """Multi-stage variant: call `finish` to terminate the loop.""" + + tool: Literal["finish"] + answer: str | None = None + + +_SearchToolCall = Annotated[ + _CallSearchUserProfiles + | _CallSearchUserPlaybooks + | _CallSearchAgentPlaybooks + | _CallGetUserProfile + | _CallGetUserPlaybook + | _CallGetAgentPlaybook + | _CallReadSessionText + | _CallStorageStats + | _CallFinish, + Field(discriminator="tool"), +] + + +class SearchAgentTurnPlan(BaseModel): + """One turn of the search agent's multi-stage fallback plan. + + The agent emits one ``SearchAgentTurnPlan`` per turn. The server parses + it, dispatches ``next_call`` against ``SEARCH_TOOLS``, appends the tool + result to the message history, and asks for the next turn — until + ``next_call.tool == "finish"`` or ``max_steps`` is exhausted. + + Used by ``run_tool_loop`` when the configured model lacks native + tool-calling but should still run a multi-turn observe-decide-act loop + (e.g. ``minimax/MiniMax-M2.7``). + """ + + reasoning: Annotated[str, Field(min_length=1)] + next_call: _SearchToolCall + + +SEARCH_TOOLS = ToolRegistry( + [ + Tool( + name="search_user_profiles", + args_model=SearchUserProfilesArgs, + handler=_bundle_handler(_handle_search_user_profiles), + ), + Tool( + name="get_user_profile", + args_model=GetUserProfileArgs, + handler=_bundle_handler(_handle_get_user_profile), + ), + # rerank_user_profiles intentionally removed from the agent palette: + # `search_user_profiles` now does deterministic cross-encoder rerank + # internally and accepts an optional `refine_with` for two-stage + # query refinement. The standalone rerank tool required the agent + # to round-trip profile_ids back through the model, which was both + # cognitively expensive and a hallucination risk on long lists. + # The handler `_handle_rerank_user_profiles` is preserved in this + # module for any non-agent caller that needs explicit rerank. + Tool( + name="storage_stats", + args_model=StorageStatsArgs, + handler=_bundle_handler(_handle_storage_stats), + ), + Tool( + name="search_user_playbooks", + args_model=SearchUserPlaybooksArgs, + handler=_bundle_handler(_handle_search_user_playbooks), + ), + Tool( + name="get_user_playbook", + args_model=GetUserPlaybookArgs, + handler=_bundle_handler(_handle_get_user_playbook), + ), + Tool( + name="search_agent_playbooks", + args_model=SearchAgentPlaybooksArgs, + handler=_bundle_handler(_handle_search_agent_playbooks), + ), + Tool( + name="get_agent_playbook", + args_model=GetAgentPlaybookArgs, + handler=_bundle_handler(_handle_get_agent_playbook), + ), + Tool( + name="read_session_text", + args_model=ReadSessionTextArgs, + handler=_bundle_handler(_handle_read_session_text), + ), + Tool( + name="finish", + args_model=SearchFinishArgs, + handler=_bundle_handler(_handle_search_finish), + ), + ] +) diff --git a/reflexio/server/services/generation_service.py b/reflexio/server/services/generation_service.py index 6d021754..68644dc6 100644 --- a/reflexio/server/services/generation_service.py +++ b/reflexio/server/services/generation_service.py @@ -7,7 +7,7 @@ from concurrent.futures import ThreadPoolExecutor from concurrent.futures import TimeoutError as FuturesTimeoutError from dataclasses import dataclass, field -from datetime import UTC, datetime +from typing import TYPE_CHECKING from reflexio.defaults import resolve_agent_version from reflexio.models.api_schema.service_schemas import ( @@ -15,6 +15,7 @@ PublishUserInteractionRequest, Request, ) +from reflexio.models.config_schema import Config from reflexio.server.api_endpoints.request_context import RequestContext from reflexio.server.llm.litellm_client import LiteLLMClient from reflexio.server.services.agent_success_evaluation.delayed_group_evaluator import ( @@ -37,6 +38,12 @@ ProfileGenerationRequest, ) +if TYPE_CHECKING: + from reflexio.server.services.search.agentic_search_service import ( + AgenticSearchService, + ) + from reflexio.server.services.unified_search_service import UnifiedSearchService + logger = logging.getLogger(__name__) # Stale lock timeout - if cleanup started > 10 min ago and still "in_progress", assume it crashed CLEANUP_STALE_LOCK_SECONDS = 600 @@ -171,6 +178,33 @@ def run( # Extract source (empty string treated as None) source = publish_user_interaction_request.source or None + # Dispatch to the agentic pipeline when the config flag is set. + # Classic path (default) falls through to the ProfileGenerationService + # + PlaybookGenerationService fan-out below. + root_config = self.configurator.get_config() + if ( + root_config is not None + and getattr(root_config, "extraction_backend", "classic") == "agentic" + ): + from reflexio.server.services.extraction.agentic_adapter import ( + AgenticExtractionRunner, + ) + + runner = AgenticExtractionRunner( + llm_client=self.client, + request_context=self.request_context, + ) + result.warnings.extend( + runner.run( + publish_request=publish_user_interaction_request, + request_id=request_id, + new_interactions=new_interactions, + new_request=new_request, + config=root_config, + ) + ) + return result + # Create generation services and requests # Each service writes to separate storage tables and has no dependencies on others profile_generation_service = ProfileGenerationService( @@ -361,14 +395,19 @@ def get_interaction_from_publish_user_interaction_request( interaction_data_list = publish_user_interaction_request.interaction_data_list user_id = publish_user_interaction_request.user_id - # Always use server-side UTC timestamp to ensure consistency - server_timestamp = int(datetime.now(UTC).timestamp()) + # Honor the client-provided ``created_at`` — InteractionData defaults + # it to client-side ``now()`` on construction, so it's always populated. + # Apps that publish backdated conversations (e.g., a benchmark replay + # of 2023 chats run in 2026) need the wall-clock time preserved so the + # extraction agent has a real temporal anchor for relative-time + # references like "X weeks ago" / "yesterday". Stamping server-now here + # would erase that anchor and force every event onto today's date. return [ Interaction( # interaction_id is auto-generated by DB user_id=user_id, request_id=request_id, - created_at=server_timestamp, # Use server UTC timestamp + created_at=interaction_data.created_at, content=interaction_data.content, role=interaction_data.role, user_action=interaction_data.user_action, @@ -381,3 +420,63 @@ def get_interaction_from_publish_user_interaction_request( ) for interaction_data in interaction_data_list ] + + +def build_extraction_service( + config: Config, + *, + llm_client: LiteLLMClient, + request_context: RequestContext, +) -> ProfileGenerationService: + """Return the classic profile extraction service. + + The agentic extraction path is handled directly by + ``AgenticExtractionRunner`` inside ``GenerationService.run`` and does not + go through this factory. This function exists for the classic dispatcher + path only. + + Args: + config (Config): Top-level ``Config`` (unused; kept for API consistency). + llm_client (LiteLLMClient): Configured ``LiteLLMClient``. + request_context (RequestContext): Current request context. + + Returns: + ProfileGenerationService: Classic profile extraction service. + """ + del config # unused — agentic path bypasses this factory + return ProfileGenerationService( + llm_client=llm_client, request_context=request_context + ) + + +def build_search_service( + config: Config, + *, + llm_client: LiteLLMClient, + request_context: RequestContext, +) -> UnifiedSearchService | AgenticSearchService: + """Dispatch to the classic or agentic search service. + + Selected by ``config.search_backend``. Classic returns a + ``UnifiedSearchService``; agentic returns the Phase-4 pipeline. + + Args: + config (Config): Top-level ``Config``. Reads ``search_backend``. + llm_client (LiteLLMClient): Configured ``LiteLLMClient``. + request_context (RequestContext): Current request context. + + Returns: + Object holding ``llm_client`` and ``request_context`` — either a + classic ``UnifiedSearchService`` or the agentic service. + """ + if config.search_backend == "agentic": + from reflexio.server.services.search.agentic_search_service import ( # type: ignore[import-not-found] + AgenticSearchService, + ) + + return AgenticSearchService( + llm_client=llm_client, request_context=request_context + ) + from reflexio.server.services.unified_search_service import UnifiedSearchService + + return UnifiedSearchService(llm_client=llm_client, request_context=request_context) diff --git a/reflexio/server/services/playbook/playbook_deduplicator.py b/reflexio/server/services/playbook/playbook_deduplicator.py deleted file mode 100644 index d8794f5a..00000000 --- a/reflexio/server/services/playbook/playbook_deduplicator.py +++ /dev/null @@ -1,504 +0,0 @@ -""" -Playbook deduplication service that merges duplicate user playbook entries using LLM -and hybrid search against existing entries in the database. -""" - -import logging -import os -from datetime import UTC, datetime - -from pydantic import BaseModel, ConfigDict, Field - -from reflexio.models.api_schema.retriever_schema import SearchUserPlaybookRequest -from reflexio.models.api_schema.service_schemas import UserPlaybook -from reflexio.models.config_schema import ( - EMBEDDING_DIMENSIONS, - DeduplicationConfig, - SearchOptions, -) -from reflexio.server.api_endpoints.request_context import RequestContext -from reflexio.server.llm.litellm_client import LiteLLMClient -from reflexio.server.services.deduplication_utils import ( - BaseDeduplicator, - format_dedup_timestamp, - parse_item_id, -) -from reflexio.server.services.playbook.playbook_service_utils import ( - StructuredPlaybookContent, - ensure_playbook_content, -) - -logger = logging.getLogger(__name__) - - -# =============================== -# Playbook-specific Pydantic Output Schemas for LLM -# =============================== - - -class PlaybookDeduplicationDuplicateGroup(BaseModel): - """A group of duplicate playbook entries to merge, with old entries to delete.""" - - item_ids: list[str] = Field( - description="IDs of items in this group matching prompt format (e.g., 'NEW-0', 'EXISTING-1')" - ) - merged_content: StructuredPlaybookContent = Field( - description="Consolidated playbook entry in structured format (trigger, rationale, blocking_issue)" - ) - reasoning: str = Field(description="Brief explanation of the merge decision") - - model_config = ConfigDict( - extra="allow", - json_schema_extra={"additionalProperties": False}, - ) - - -class PlaybookDeduplicationOutput(BaseModel): - """Output schema for playbook deduplication with NEW vs EXISTING merge support.""" - - duplicate_groups: list[PlaybookDeduplicationDuplicateGroup] = Field( - default=[], description="Groups of duplicate playbook entries to merge" - ) - unique_ids: list[str] = Field( - default=[], description="IDs of unique NEW entries (e.g., 'NEW-2')" - ) - - model_config = ConfigDict( - extra="allow", - json_schema_extra={"additionalProperties": False}, - ) - - -class PlaybookDeduplicator(BaseDeduplicator): - """ - Deduplicates new user playbook entries against each other and against existing entries - in the database using hybrid search (vector + FTS) and LLM-based merging. - """ - - DEDUPLICATION_PROMPT_ID = "playbook_deduplication" - - def __init__( - self, - request_context: RequestContext, - llm_client: LiteLLMClient, - dedup_config: DeduplicationConfig | None = None, - ): - """ - Initialize the playbook deduplicator. - - Args: - request_context: Request context with storage and prompt manager - llm_client: Unified LLM client for LLM calls - dedup_config: Optional deduplication search parameters (threshold, top_k) - """ - super().__init__(request_context, llm_client) - self._dedup_config = dedup_config or DeduplicationConfig() - - def _get_prompt_id(self) -> str: - """Get the prompt ID for playbook deduplication.""" - return self.DEDUPLICATION_PROMPT_ID - - def _get_item_count_key(self) -> str: - """Get the key name for item count in prompt variables.""" - return "new_playbook_count" - - def _get_items_key(self) -> str: - """Get the key name for items in prompt variables.""" - return "new_playbooks" - - def _get_output_schema_class(self) -> type[BaseModel]: - """Return PlaybookDeduplicationOutput for new/existing merge.""" - return PlaybookDeduplicationOutput - - def _format_items_for_prompt(self, playbooks: list[UserPlaybook]) -> str: - """ - Format user playbook entries list for LLM prompt with NEW-N prefix. - - Args: - playbooks: List of user playbook entries - - Returns: - Formatted string representation - """ - return self._format_playbooks_with_prefix(playbooks, "NEW") - - def _format_playbooks_with_prefix( - self, playbooks: list[UserPlaybook], prefix: str - ) -> str: - """ - Format user playbook entries with a given prefix (NEW or EXISTING). - - Args: - playbooks: List of user playbook entries to format - prefix: Prefix string for indices - - Returns: - Formatted string - """ - if not playbooks: - return "(None)" - lines = [] - for idx, playbook in enumerate(playbooks): - playbook_name = playbook.playbook_name or "unknown" - source = playbook.source or "unknown" - created_date = format_dedup_timestamp(playbook.created_at) - lines.append( - f'[{prefix}-{idx}] Content: "{playbook.content}" | Name: {playbook_name} | Source: {source} | Last Modified: {created_date}' - ) - return "\n".join(lines) - - def _retrieve_existing_playbooks( - self, - new_playbooks: list[UserPlaybook], - user_id: str | None = None, - agent_version: str | None = None, - ) -> list[UserPlaybook]: - """ - Retrieve existing user playbook entries from the database using hybrid search. - - For each new entry, uses its trigger field as the query with - pre-computed embeddings for vector search. - - Args: - new_playbooks: List of new entries to search against - user_id: Optional user ID to scope the search - agent_version: Optional agent version to scope the search - - Returns: - Deduplicated list of existing UserPlaybook objects from the database - """ - storage = self.request_context.storage - - # Collect trigger strings for embedding - query_texts = [] - for playbook in new_playbooks: - trigger = playbook.trigger or playbook.content - if trigger and trigger.strip(): - query_texts.append(trigger.strip()) - - if not query_texts: - return [] - - # Batch-generate embeddings - try: - embeddings = self.client.get_embeddings( - query_texts, dimensions=EMBEDDING_DIMENSIONS - ) - except Exception as e: - logger.warning("Failed to generate embeddings for dedup search: %s", e) - # Fall back to text-only search - embeddings = [None] * len(query_texts) - - # Search for each new entry - seen_ids: set[int] = set() - existing_playbooks: list[UserPlaybook] = [] - - for i, query_text in enumerate(query_texts): - try: - search_request = SearchUserPlaybookRequest( - query=query_text, - user_id=user_id, - agent_version=agent_version, - status_filter=[None], # Only current entries - threshold=self._dedup_config.search_threshold, - top_k=self._dedup_config.search_top_k, - ) - search_options = SearchOptions(query_embedding=embeddings[i]) - results = storage.search_user_playbooks( # type: ignore[reportOptionalMemberAccess] - search_request, search_options - ) - for fb in results: - if fb.user_playbook_id and fb.user_playbook_id not in seen_ids: - seen_ids.add(fb.user_playbook_id) - existing_playbooks.append(fb) - except Exception as e: # noqa: PERF203 - logger.warning( - "Failed to search existing entries for query %d: %s", i, e - ) - - logger.info( - "Retrieved %d unique existing user playbook entries for deduplication " - "(scoped to user_id=%r agent_version=%r)", - len(existing_playbooks), - user_id, - agent_version, - ) - return existing_playbooks - - def _format_new_and_existing_for_prompt( - self, - new_playbooks: list[UserPlaybook], - existing_playbooks: list[UserPlaybook], - ) -> tuple[str, str]: - """ - Format new and existing entries for the deduplication prompt. - - Args: - new_playbooks: New entries to deduplicate - existing_playbooks: Existing entries from the database - - Returns: - Tuple of (new_playbooks_text, existing_playbooks_text) - """ - new_text = self._format_playbooks_with_prefix(new_playbooks, "NEW") - existing_text = self._format_playbooks_with_prefix( - existing_playbooks, "EXISTING" - ) - return new_text, existing_text - - def deduplicate( - self, - results: list[list[UserPlaybook]], - request_id: str, - agent_version: str, - user_id: str | None = None, - ) -> tuple[list[UserPlaybook], list[int]]: - """ - Deduplicate user playbook entries across extractors and against existing entries in DB. - - Args: - results: List of entry lists from extractors (each extractor returns list[UserPlaybook]) - request_id: Request ID for context - agent_version: Agent version for context - user_id: Optional user ID to scope the existing entry search - - Returns: - Tuple of (deduplicated entries, list of existing entry IDs to delete after save) - """ - # Check if mock mode is enabled - if os.getenv("MOCK_LLM_RESPONSE", "").lower() == "true": - logger.info("Mock mode: skipping deduplication") - all_playbooks: list[UserPlaybook] = [] - for result in results: - if isinstance(result, list): - all_playbooks.extend(result) - return all_playbooks, [] - - # Flatten all new entries - new_playbooks: list[UserPlaybook] = [] - for result in results: - if isinstance(result, list): - new_playbooks.extend(result) - - if not new_playbooks: - return [], [] - - # Retrieve existing entries via hybrid search - existing_playbooks = self._retrieve_existing_playbooks( - new_playbooks, user_id=user_id, agent_version=agent_version - ) - - # Format for prompt - new_text, existing_text = self._format_new_and_existing_for_prompt( - new_playbooks, existing_playbooks - ) - - # Build and call LLM - prompt = self.request_context.prompt_manager.render_prompt( - self._get_prompt_id(), - { - "new_playbook_count": len(new_playbooks), - "new_playbooks": new_text, - "existing_playbook_count": len(existing_playbooks), - "existing_playbooks": existing_text, - }, - ) - - output_schema_class = self._get_output_schema_class() - - try: - from reflexio.server.services.service_utils import ( - log_llm_messages, - log_model_response, - ) - - log_llm_messages( - logger, - "Playbook deduplication", - [{"role": "user", "content": prompt}], - ) - - response = self.client.generate_chat_response( - messages=[{"role": "user", "content": prompt}], - model=self.model_name, - response_format=output_schema_class, - ) - - log_model_response(logger, "Deduplication response", response) - - if not isinstance(response, PlaybookDeduplicationOutput): - logger.warning( - "Unexpected response type from deduplication LLM: %s", - type(response), - ) - return new_playbooks, [] - - dedup_output = response - except Exception as e: - logger.error("Failed to identify duplicates: %s", str(e)) - return new_playbooks, [] - - if not dedup_output.duplicate_groups: - logger.info( - "No duplicate playbook entries found for request %s", request_id - ) - return new_playbooks, [] - - logger.info( - "Found %d duplicate playbook groups for request %s", - len(dedup_output.duplicate_groups), - request_id, - ) - - # Build deduplicated result - return self._build_deduplicated_results( - new_playbooks=new_playbooks, - existing_playbooks=existing_playbooks, - dedup_output=dedup_output, - request_id=request_id, - agent_version=agent_version, - ) - - def _build_deduplicated_results( # noqa: C901 - self, - new_playbooks: list[UserPlaybook], - existing_playbooks: list[UserPlaybook], - dedup_output: PlaybookDeduplicationOutput, - request_id: str, - agent_version: str, # noqa: ARG002 - ) -> tuple[list[UserPlaybook], list[int]]: - """ - Build the deduplicated entry list from LLM output. - - Handles merged groups (creating new entries from merged content) - and unique entries. Returns IDs of existing entries to delete - so the caller can delete them after save succeeds. - - Args: - new_playbooks: Flattened list of new entries - existing_playbooks: List of existing entries from DB - dedup_output: LLM deduplication output - request_id: Request ID - agent_version: Agent version - - Returns: - Tuple of (entries ready to save, existing entry IDs to delete) - """ - handled_new_indices: set[int] = set() - result_playbooks: list[UserPlaybook] = [] - existing_ids_to_delete: list[int] = [] - seen_delete_ids: set[int] = set() - - now_ts = int(datetime.now(UTC).timestamp()) - - # Process duplicate groups - for group in dedup_output.duplicate_groups: - group_new_indices: list[int] = [] - group_existing_indices: list[int] = [] - - for item_id in group.item_ids: - parsed = parse_item_id(item_id) - if parsed is None: - continue - prefix, idx = parsed - if prefix == "NEW": - group_new_indices.append(idx) - handled_new_indices.add(idx) - elif prefix == "EXISTING": - group_existing_indices.append(idx) - - # Collect existing entry IDs to delete (deduplicated) - for eidx in group_existing_indices: - if 0 <= eidx < len(existing_playbooks): - fb_id = existing_playbooks[eidx].user_playbook_id - if fb_id and fb_id not in seen_delete_ids: - seen_delete_ids.add(fb_id) - existing_ids_to_delete.append(fb_id) - - # Get template from first NEW entry in group (for metadata) - template_playbook: UserPlaybook | None = None - if group_new_indices: - first_new_idx = group_new_indices[0] - if 0 <= first_new_idx < len(new_playbooks): - template_playbook = new_playbooks[first_new_idx] - - if template_playbook is None: - # Fallback: use first existing entry as template - if group_existing_indices: - for eidx in group_existing_indices: - if 0 <= eidx < len(existing_playbooks): - template_playbook = existing_playbooks[eidx] - break - if template_playbook is None: - logger.warning("Could not find template entry for group, skipping") - continue - - # Combine source_interaction_ids from all NEW entries in group - combined_source_ids: list[int] = [] - seen_ids: set[int] = set() - for idx in group_new_indices: - if 0 <= idx < len(new_playbooks): - for sid in new_playbooks[idx].source_interaction_ids: - if sid not in seen_ids: - combined_source_ids.append(sid) - seen_ids.add(sid) - - # Also include source_interaction_ids from existing entries being merged - for eidx in group_existing_indices: - if 0 <= eidx < len(existing_playbooks): - for sid in existing_playbooks[eidx].source_interaction_ids: - if sid not in seen_ids: - combined_source_ids.append(sid) - seen_ids.add(sid) - - # Format content from merged structured content - merged_content = group.merged_content - playbook_content = ensure_playbook_content( - merged_content.content, merged_content - ) - logger.info( - "Deduplicated playbook content (freeform): %.200s", - playbook_content, - ) - - merged_playbook = UserPlaybook( - user_playbook_id=0, # Will be assigned by storage - user_id=template_playbook.user_id, - agent_version=template_playbook.agent_version, - request_id=request_id, - playbook_name=template_playbook.playbook_name, - created_at=now_ts, - content=playbook_content, - trigger=merged_content.trigger, - rationale=merged_content.rationale, - blocking_issue=merged_content.blocking_issue, - status=template_playbook.status, - source=template_playbook.source, - source_interaction_ids=combined_source_ids, - ) - result_playbooks.append(merged_playbook) - - # Add unique NEW entries - for uid in dedup_output.unique_ids: - parsed = parse_item_id(uid) - if parsed is None: - continue - prefix, idx = parsed - if ( - prefix == "NEW" - and idx not in handled_new_indices - and 0 <= idx < len(new_playbooks) - ): - result_playbooks.append(new_playbooks[idx]) - handled_new_indices.add(idx) - - # Safety fallback: add any NEW entries not mentioned by LLM - for idx, playbook in enumerate(new_playbooks): - if idx not in handled_new_indices: - logger.warning( - "New entry at index %d was not handled by LLM, adding as-is", - idx, - ) - result_playbooks.append(playbook) - - return result_playbooks, existing_ids_to_delete diff --git a/reflexio/server/services/playbook/playbook_generation_service.py b/reflexio/server/services/playbook/playbook_generation_service.py index fb183dc5..3957db76 100644 --- a/reflexio/server/services/playbook/playbook_generation_service.py +++ b/reflexio/server/services/playbook/playbook_generation_service.py @@ -264,45 +264,7 @@ def _process_results(self, results: list[list[UserPlaybook]]) -> None: if isinstance(result, list): all_playbooks.extend(result) - # Deduplicate against existing entries in DB when deduplicator is enabled existing_ids_to_delete: list[int] = [] - from reflexio.server.site_var.feature_flags import is_deduplicator_enabled - - if is_deduplicator_enabled(self.org_id): - from reflexio.server.services.playbook.playbook_deduplicator import ( - PlaybookDeduplicator, - ) - - # Get deduplication config from the first playbook config that has one - playbook_configs_list = ( - self.configurator.get_config().user_playbook_extractor_configs - ) - dedup_config = next( - ( - c.deduplication_config - for c in (playbook_configs_list or []) - if c.deduplication_config - ), - None, - ) - - deduplicator = PlaybookDeduplicator( - request_context=self.request_context, - llm_client=self.client, - dedup_config=dedup_config, - ) - deduplicated_playbooks, existing_ids_to_delete = deduplicator.deduplicate( - results, - self.service_config.request_id, # type: ignore[reportOptionalMemberAccess] - self.service_config.agent_version, # type: ignore[reportOptionalMemberAccess] - user_id=self.service_config.user_id, # type: ignore[reportOptionalMemberAccess] - ) - logger.info( - "User playbook entries after deduplication: %d", - len(deduplicated_playbooks), - ) - if deduplicated_playbooks: - all_playbooks = deduplicated_playbooks # Set status and source for all entries for playbook in all_playbooks: diff --git a/reflexio/server/services/playbook/playbook_service_utils.py b/reflexio/server/services/playbook/playbook_service_utils.py index c0174ccf..ee28af26 100644 --- a/reflexio/server/services/playbook/playbook_service_utils.py +++ b/reflexio/server/services/playbook/playbook_service_utils.py @@ -54,6 +54,18 @@ class StructuredPlaybookContent(BaseModel): default=None, description="The main actionable content of the playbook entry — what to do or what to avoid", ) + source_span: str | None = Field( + default=None, + description="Verbatim excerpt from the source that most directly supports this playbook entry", + ) + notes: str | None = Field( + default=None, + description="Free-form extraction notes — confidence, caveats, or alternative readings", + ) + reader_angle: str | None = Field( + default=None, + description="The extraction perspective or reader role that surfaced this entry", + ) model_config = ConfigDict( extra="allow", diff --git a/reflexio/server/services/profile/profile_deduplicator.py b/reflexio/server/services/profile/profile_deduplicator.py deleted file mode 100644 index b13995bb..00000000 --- a/reflexio/server/services/profile/profile_deduplicator.py +++ /dev/null @@ -1,717 +0,0 @@ -""" -Profile deduplication service that merges duplicate profiles from multiple extractors -and against existing profiles in the database using hybrid search and LLM. -""" - -import logging -import os -import uuid -from datetime import UTC, datetime - -from pydantic import BaseModel, ConfigDict, Field - -from reflexio.models.api_schema.retriever_schema import SearchUserProfileRequest -from reflexio.models.api_schema.service_schemas import UserProfile -from reflexio.models.config_schema import EMBEDDING_DIMENSIONS -from reflexio.server.api_endpoints.request_context import RequestContext -from reflexio.server.llm.litellm_client import LiteLLMClient -from reflexio.server.services.deduplication_utils import ( - BaseDeduplicator, - format_dedup_timestamp, - parse_item_id, -) -from reflexio.server.services.profile.profile_generation_service_utils import ( - ProfileTimeToLive, - calculate_expiration_timestamp, -) - -logger = logging.getLogger(__name__) - - -# Backward-compat alias — existing unit tests import this name from this -# module. Delegates to the shared helper in deduplication_utils. -_format_profile_timestamp = format_dedup_timestamp - - -# Canonical prefix emitted by the extractor for forget/delete requests. The -# dedup LLM routes matching NEW profiles into `deletions`; any fallback path -# that skips the LLM step must strip these markers before returning so they -# are never persisted as facts. -_DELETION_MARKER_PREFIX = "Requested removal of" - - -def _strip_deletion_markers( - profiles: list[UserProfile], -) -> list[UserProfile]: - """ - Drop profiles whose content is a canonical deletion marker. - - Used on fallback paths (LLM error, unexpected response type, empty dedup - output) to prevent "Requested removal of …" markers emitted by the - extractor from being persisted as regular profile facts when the dedup - LLM step is skipped or yields no deletions. Persisting such markers would - recreate the exact zombie-profile failure mode the deletion-directive - channel was introduced to eliminate. - - Args: - profiles (list[UserProfile]): Profiles to filter. - - Returns: - list[UserProfile]: Profiles with deletion markers removed. - """ - return [ - p - for p in profiles - if not (p.content or "").lstrip().startswith(_DELETION_MARKER_PREFIX) - ] - - -# =============================== -# Profile-specific Pydantic Output Schemas for LLM -# =============================== - - -class ProfileDuplicateGroup(BaseModel): - """ - Represents a group of duplicate profiles across NEW and EXISTING sets. - - Attributes: - item_ids: List of item IDs matching prompt format (e.g., 'NEW-0', 'EXISTING-1') - merged_content: The consolidated profile content combining information from all duplicates - merged_time_to_live: The chosen time_to_live for the merged profile - reasoning: Brief explanation of why these profiles are duplicates and how they were merged - """ - - item_ids: list[str] = Field( - description="IDs of items in this group matching prompt format (e.g., 'NEW-0', 'EXISTING-1')" - ) - merged_content: str = Field( - description="Consolidated profile content combining all duplicate information" - ) - merged_time_to_live: str = Field( - description="Time to live for merged profile: one_day, one_week, one_month, one_quarter, one_year, infinity" - ) - reasoning: str = Field(description="Brief explanation of the merge decision") - - model_config = ConfigDict( - extra="allow", - json_schema_extra={"additionalProperties": False}, - ) - - -class ProfileDeletionDirective(BaseModel): - """ - Represents a NEW profile that is a meta-request to forget an EXISTING fact. - - Used when the user explicitly asks the system to erase a previously-stored - profile (e.g. "forget that I like X"). Unlike a duplicate group, a deletion - directive removes the matched EXISTING profile(s) without writing any merged - or replacement profile — the NEW directive is consumed, not retained. - - Attributes: - new_id: ID of the NEW profile that expresses the deletion directive (e.g. 'NEW-0') - existing_ids: IDs of EXISTING profiles to delete without replacement (e.g. ['EXISTING-0']) - reasoning: Brief explanation of why this was classified as a deletion directive - rather than a fact update - """ - - new_id: str = Field( - description="ID of the NEW profile that is a deletion directive (e.g. 'NEW-0')" - ) - existing_ids: list[str] = Field( - description="IDs of EXISTING profiles to delete without replacement (e.g. ['EXISTING-0'])" - ) - reasoning: str = Field( - description="Brief explanation of the deletion classification" - ) - - model_config = ConfigDict( - extra="allow", - json_schema_extra={"additionalProperties": False}, - ) - - -class ProfileDeduplicationOutput(BaseModel): - """ - Output schema for profile deduplication with NEW/EXISTING format. - - Attributes: - duplicate_groups: List of duplicate groups to merge - unique_ids: List of IDs of unique NEW profiles (e.g., 'NEW-2') - deletions: List of deletion directives — NEW profiles that are pure - meta-requests to erase an EXISTING profile. Both the NEW and the - matched EXISTING profile(s) are removed; no merged replacement is - produced. - """ - - duplicate_groups: list[ProfileDuplicateGroup] = Field( - default=[], description="Groups of duplicate profiles that should be merged" - ) - unique_ids: list[str] = Field( - default=[], - description="IDs of unique NEW profiles (e.g., 'NEW-2')", - ) - deletions: list[ProfileDeletionDirective] = Field( - default=[], - description=( - "NEW profiles that are pure deletion directives (the user asked to " - "forget/remove a stored fact). Both the NEW and matched EXISTING " - "profiles are removed; no merged replacement is written." - ), - ) - - model_config = ConfigDict( - extra="allow", - json_schema_extra={"additionalProperties": False}, - ) - - -class ProfileDeduplicator(BaseDeduplicator): - """ - Deduplicates new profiles against each other and against existing profiles - in the database using hybrid search (vector + FTS) and LLM-based merging. - - Follows the same pattern as PlaybookDeduplicator. - """ - - DEDUPLICATION_PROMPT_ID = "profile_deduplication" - - def __init__( - self, - request_context: RequestContext, - llm_client: LiteLLMClient, - ): - """ - Initialize the profile deduplicator. - - Args: - request_context: Request context with storage and prompt manager - llm_client: Unified LLM client for LLM calls - """ - super().__init__(request_context, llm_client) - - def _get_prompt_id(self) -> str: - """Get the prompt ID for profile deduplication.""" - return self.DEDUPLICATION_PROMPT_ID - - def _get_item_count_key(self) -> str: - """Get the key name for item count in prompt variables.""" - return "new_profile_count" - - def _get_items_key(self) -> str: - """Get the key name for items in prompt variables.""" - return "new_profiles" - - def _get_output_schema_class(self) -> type[BaseModel]: - """Get the profile-specific output schema with NEW/EXISTING format.""" - return ProfileDeduplicationOutput - - def _format_items_for_prompt(self, profiles: list[UserProfile]) -> str: - """ - Format profiles list for LLM prompt with NEW-N prefix. - - Args: - profiles: List of profiles - - Returns: - Formatted string representation - """ - return self._format_profiles_with_prefix(profiles, "NEW") - - def _format_profiles_with_prefix( - self, profiles: list[UserProfile], prefix: str - ) -> str: - """ - Format profiles with a given prefix (NEW or EXISTING). - - Args: - profiles: List of profiles to format - prefix: Prefix string for indices - - Returns: - Formatted string - """ - if not profiles: - return "(None)" - lines = [] - for idx, profile in enumerate(profiles): - ttl = ( - profile.profile_time_to_live.value - if profile.profile_time_to_live - else "unknown" - ) - source = profile.source or "unknown" - modified_date = _format_profile_timestamp(profile.last_modified_timestamp) - lines.append( - f'[{prefix}-{idx}] Content: "{profile.content}" | TTL: {ttl} | Source: {source} | Last Modified: {modified_date}' - ) - return "\n".join(lines) - - def _format_new_and_existing_for_prompt( - self, - new_profiles: list[UserProfile], - existing_profiles: list[UserProfile], - ) -> tuple[str, str]: - """ - Format new and existing profiles for the deduplication prompt. - - Args: - new_profiles: New profiles to deduplicate - existing_profiles: Existing profiles from the database - - Returns: - Tuple of (new_profiles_text, existing_profiles_text) - """ - new_text = self._format_profiles_with_prefix(new_profiles, "NEW") - existing_text = self._format_profiles_with_prefix(existing_profiles, "EXISTING") - return new_text, existing_text - - def _retrieve_existing_profiles( - self, - new_profiles: list[UserProfile], - user_id: str, - ) -> list[UserProfile]: - """ - Retrieve existing profiles from the database using hybrid search. - - For each new profile, uses its profile_content as the query with - pre-computed embeddings for vector search. - - Args: - new_profiles: List of new profiles to search against - user_id: User ID to scope the search - - Returns: - Deduplicated list of existing UserProfile objects from the database - """ - storage = self.request_context.storage - - # Collect profile content strings for embedding - query_texts = [] - for profile in new_profiles: - text = profile.content - if text and text.strip(): - query_texts.append(text.strip()) - - if not query_texts: - return [] - - # Batch-generate embeddings - try: - embeddings = self.client.get_embeddings( - query_texts, dimensions=EMBEDDING_DIMENSIONS - ) - except Exception as e: - logger.warning("Failed to generate embeddings for dedup search: %s", e) - embeddings = [None] * len(query_texts) - - # Search for each new profile - seen_ids: set[str] = set() - existing_profiles: list[UserProfile] = [] - - for i, query_text in enumerate(query_texts): - try: - results = storage.search_user_profile( # type: ignore[reportOptionalMemberAccess] - SearchUserProfileRequest( - query=query_text, - user_id=user_id, - top_k=10, - threshold=0.4, - ), - status_filter=[None], # Only current profiles - query_embedding=embeddings[i], - ) - for profile in results: - if profile.profile_id and profile.profile_id not in seen_ids: - seen_ids.add(profile.profile_id) - existing_profiles.append(profile) - except Exception as e: # noqa: PERF203 - logger.warning( - "Failed to search existing profiles for query %d: %s", i, e - ) - - logger.info( - "Retrieved %d unique existing profiles for deduplication", - len(existing_profiles), - ) - return existing_profiles - - def deduplicate( - self, - new_profiles: list[UserProfile], - user_id: str, - request_id: str, - ) -> tuple[list[UserProfile], list[str], list[UserProfile]]: - """ - Deduplicate profiles across extractors and against existing profiles in DB. - - Args: - new_profiles: List of new UserProfile objects from extractors - request_id: Request ID for context - user_id: User ID to scope the existing profile search - - Returns: - Tuple of (deduplicated profiles, existing profile IDs to delete, superseded existing profiles) - """ - # Check if mock mode is enabled - if os.getenv("MOCK_LLM_RESPONSE", "").lower() == "true": - logger.info("Mock mode: skipping deduplication") - return new_profiles, [], [] - - if not new_profiles: - return [], [], [] - - # Retrieve existing profiles via hybrid search - existing_profiles = self._retrieve_existing_profiles(new_profiles, user_id) - - # Format for prompt - new_text, existing_text = self._format_new_and_existing_for_prompt( - new_profiles, existing_profiles - ) - - # Build and call LLM - prompt = self.request_context.prompt_manager.render_prompt( - self._get_prompt_id(), - { - "new_profile_count": len(new_profiles), - "new_profiles": new_text, - "existing_profile_count": len(existing_profiles), - "existing_profiles": existing_text, - }, - ) - - output_schema_class = self._get_output_schema_class() - - try: - from reflexio.server.services.service_utils import ( - log_llm_messages, - log_model_response, - ) - - log_llm_messages( - logger, "Profile deduplication", [{"role": "user", "content": prompt}] - ) - - response = self.client.generate_chat_response( - messages=[{"role": "user", "content": prompt}], - model=self.model_name, - response_format=output_schema_class, - ) - - log_model_response(logger, "Deduplication response", response) - - if not isinstance(response, ProfileDeduplicationOutput): - logger.warning( - "Unexpected response type from deduplication LLM: %s", - type(response), - ) - return _strip_deletion_markers(new_profiles), [], [] - - dedup_output = response - except Exception as e: - logger.error("Failed to identify duplicates: %s", str(e)) - return _strip_deletion_markers(new_profiles), [], [] - - if not dedup_output.duplicate_groups and not dedup_output.deletions: - logger.info("No duplicate or deletion actions for request %s", request_id) - return _strip_deletion_markers(new_profiles), [], [] - - logger.info( - "Found %d duplicate profile groups and %d deletion directives for request %s", - len(dedup_output.duplicate_groups), - len(dedup_output.deletions), - request_id, - ) - - # Build deduplicated result - return self._build_deduplicated_results( - new_profiles=new_profiles, - existing_profiles=existing_profiles, - dedup_output=dedup_output, - user_id=user_id, - request_id=request_id, - ) - - def _build_deduplicated_results( - self, - new_profiles: list[UserProfile], - existing_profiles: list[UserProfile], - dedup_output: ProfileDeduplicationOutput, - user_id: str, - request_id: str, - ) -> tuple[list[UserProfile], list[str], list[UserProfile]]: - """ - Build the deduplicated profile list from LLM output. - - Args: - new_profiles: Flattened list of new profiles - existing_profiles: List of existing profiles from DB - dedup_output: LLM deduplication output - user_id: User ID - request_id: Request ID - - Returns: - Tuple of (profiles ready to save, existing profile IDs to delete, superseded existing profiles) - """ - handled_new_indices: set[int] = set() - result_profiles: list[UserProfile] = [] - existing_ids_to_delete: list[str] = [] - seen_delete_ids: set[str] = set() - superseded_profiles: list[UserProfile] = [] - - now_ts = int(datetime.now(UTC).timestamp()) - - # Process deletion directives first. A directive is a NEW profile that - # is a meta-request to forget an EXISTING profile. Both the NEW and the - # matched EXISTING profile(s) are removed with no merged replacement. - self._apply_deletion_directives( - dedup_output.deletions, - new_profiles=new_profiles, - existing_profiles=existing_profiles, - handled_new_indices=handled_new_indices, - existing_ids_to_delete=existing_ids_to_delete, - seen_delete_ids=seen_delete_ids, - superseded_profiles=superseded_profiles, - ) - - # Process duplicate groups - for group in dedup_output.duplicate_groups: - group_new_indices: list[int] = [] - group_existing_indices: list[int] = [] - - for item_id in group.item_ids: - parsed = parse_item_id(item_id) - if parsed is None: - continue - prefix, idx = parsed - if prefix == "NEW": - group_new_indices.append(idx) - elif prefix == "EXISTING": - group_existing_indices.append(idx) - - # Reject groups that overlap with profiles already consumed by a - # deletion directive. Merging such a group would write a - # replacement profile containing content the user asked to forget. - conflicting_new = [i for i in group_new_indices if i in handled_new_indices] - conflicting_existing = [ - i - for i in group_existing_indices - if 0 <= i < len(existing_profiles) - and existing_profiles[i].profile_id - and existing_profiles[i].profile_id in seen_delete_ids - ] - if conflicting_new or conflicting_existing: - logger.warning( - "Skipping duplicate group %s: overlaps with deletion " - "directives (NEW indices=%s, EXISTING indices=%s)", - group.item_ids, - conflicting_new, - conflicting_existing, - ) - continue - - # Mark NEW indices as handled only after the overlap check passes. - for idx in group_new_indices: - handled_new_indices.add(idx) - - # Collect existing profile IDs to delete and their profiles for changelog (deduplicated) - for eidx in group_existing_indices: - self._mark_existing_for_deletion( - f"EXISTING-{eidx}", - existing_profiles, - existing_ids_to_delete, - seen_delete_ids, - superseded_profiles, - ) - - # Get template from first NEW profile in group (for metadata) - template_profile: UserProfile | None = None - if group_new_indices: - first_new_idx = group_new_indices[0] - if 0 <= first_new_idx < len(new_profiles): - template_profile = new_profiles[first_new_idx] - - if template_profile is None: - logger.warning("Could not find template profile for group, skipping") - continue - - # Merge custom_features from all NEW profiles in group - group_new_profiles = [ - new_profiles[i] for i in group_new_indices if 0 <= i < len(new_profiles) - ] - merged_custom_features = self._merge_custom_features(group_new_profiles) - - # Merge extractor_names from all NEW profiles in group - merged_extractor_names = self._merge_extractor_names(group_new_profiles) - - # Determine TTL - try: - ttl = ProfileTimeToLive(group.merged_time_to_live) - except ValueError: - ttl = template_profile.profile_time_to_live - logger.warning( - "Invalid TTL '%s' from LLM, using template TTL '%s'", - group.merged_time_to_live, - ttl.value, - ) - - merged_profile = UserProfile( - profile_id=str(uuid.uuid4()), - user_id=user_id, - content=group.merged_content, - last_modified_timestamp=now_ts, - generated_from_request_id=request_id, - profile_time_to_live=ttl, - expiration_timestamp=calculate_expiration_timestamp(now_ts, ttl), - custom_features=merged_custom_features, - source=template_profile.source, - status=template_profile.status, - extractor_names=merged_extractor_names, - ) - result_profiles.append(merged_profile) - - # Add unique NEW profiles - for uid in dedup_output.unique_ids: - parsed = parse_item_id(uid) - if parsed is None: - continue - prefix, idx = parsed - if ( - prefix == "NEW" - and idx not in handled_new_indices - and 0 <= idx < len(new_profiles) - ): - result_profiles.append(new_profiles[idx]) - handled_new_indices.add(idx) - - # Safety fallback: add any NEW profiles not mentioned by LLM - for idx, profile in enumerate(new_profiles): - if idx not in handled_new_indices: - logger.warning( - "New profile at index %d was not handled by LLM, adding as-is", - idx, - ) - result_profiles.append(profile) - - return result_profiles, existing_ids_to_delete, superseded_profiles - - def _apply_deletion_directives( - self, - directives: list[ProfileDeletionDirective], - *, - new_profiles: list[UserProfile], - existing_profiles: list[UserProfile], - handled_new_indices: set[int], - existing_ids_to_delete: list[str], - seen_delete_ids: set[str], - superseded_profiles: list[UserProfile], - ) -> None: - """ - Apply deletion directives in place: consume the NEW profile and mark matched - EXISTING profile(s) for deletion without producing a merged replacement. - - A directive is a NEW profile whose content is a meta-request to forget an - EXISTING profile (e.g. "Requested removal of interest in X from stored - profiles"). The NEW is suppressed from the result set and the matched - EXISTING rows are added to the deletion list. - - Args: - directives: Deletion directives from the LLM. - new_profiles: Flat list of NEW profiles (indexed by NEW-N id). - existing_profiles: List of EXISTING profiles (indexed by EXISTING-M id). - handled_new_indices: Set of NEW indices already accounted for; this - method adds the consumed directive indices to it. - existing_ids_to_delete: Output list of profile IDs to delete; this - method appends to it. - seen_delete_ids: Set used to deduplicate IDs across all deletion paths. - superseded_profiles: Output list of deleted profile objects for the - changelog; this method appends to it. - """ - for directive in directives: - self._consume_new_index( - directive.new_id, len(new_profiles), handled_new_indices - ) - for eid in directive.existing_ids: - self._mark_existing_for_deletion( - eid, - existing_profiles, - existing_ids_to_delete, - seen_delete_ids, - superseded_profiles, - ) - logger.info( - "Profile deletion directive %s -> delete %s: %s", - directive.new_id, - directive.existing_ids, - directive.reasoning, - ) - - @staticmethod - def _consume_new_index( - new_id: str, new_profile_count: int, handled_new_indices: set[int] - ) -> None: - """Mark a NEW-N id as handled so the safety fallback does not re-add it.""" - parsed = parse_item_id(new_id) - if parsed is None: - return - prefix, idx = parsed - if prefix == "NEW" and 0 <= idx < new_profile_count: - handled_new_indices.add(idx) - - @staticmethod - def _mark_existing_for_deletion( - existing_id: str, - existing_profiles: list[UserProfile], - existing_ids_to_delete: list[str], - seen_delete_ids: set[str], - superseded_profiles: list[UserProfile], - ) -> None: - """Resolve an EXISTING-N id to a profile_id and queue it for deletion.""" - parsed = parse_item_id(existing_id) - if parsed is None: - return - prefix, idx = parsed - if prefix != "EXISTING" or not (0 <= idx < len(existing_profiles)): - return - pid = existing_profiles[idx].profile_id - if pid and pid not in seen_delete_ids: - seen_delete_ids.add(pid) - existing_ids_to_delete.append(pid) - superseded_profiles.append(existing_profiles[idx]) - - def _merge_custom_features(self, profiles: list[UserProfile]) -> dict | None: - """ - Merge custom_features from multiple profiles. - - Args: - profiles: List of profiles to merge custom_features from - - Returns: - Merged custom_features dict or None if no custom_features - """ - merged = {} - for profile in profiles: - if profile.custom_features: - merged.update(profile.custom_features) - - return merged or None - - def _merge_extractor_names(self, profiles: list[UserProfile]) -> list[str] | None: - """ - Merge extractor_names from multiple profiles, preserving order and removing duplicates. - - Args: - profiles: List of profiles to merge extractor_names from - - Returns: - Merged list of unique extractor names or None if no extractor_names - """ - seen: set[str] = set() - merged: list[str] = [] - for profile in profiles: - if profile.extractor_names: - for name in profile.extractor_names: - if name not in seen: - seen.add(name) - merged.append(name) - return merged or None diff --git a/reflexio/server/services/profile/profile_extractor.py b/reflexio/server/services/profile/profile_extractor.py index d0b3dc33..268fda39 100644 --- a/reflexio/server/services/profile/profile_extractor.py +++ b/reflexio/server/services/profile/profile_extractor.py @@ -3,7 +3,6 @@ import logging import os import time -import uuid from datetime import UTC, datetime from typing import TYPE_CHECKING @@ -14,6 +13,7 @@ from reflexio.models.config_schema import ProfileExtractorConfig from reflexio.server.api_endpoints.request_context import RequestContext from reflexio.server.llm.litellm_client import LiteLLMClient +from reflexio.server.services.extraction.tools import new_profile_id from reflexio.server.services.extractor_interaction_utils import ( get_effective_source_filter, get_extractor_window_params, @@ -278,7 +278,7 @@ def _convert_raw_to_user_profiles( ttl = ProfileTimeToLive(profile_content.get("time_to_live", "infinity")) added_profile = UserProfile( - profile_id=str(uuid.uuid4()), + profile_id=new_profile_id(), user_id=user_id, content=profile_content["content"], last_modified_timestamp=now_ts, diff --git a/reflexio/server/services/profile/profile_generation_service.py b/reflexio/server/services/profile/profile_generation_service.py index a09e0ef5..facada0b 100644 --- a/reflexio/server/services/profile/profile_generation_service.py +++ b/reflexio/server/services/profile/profile_generation_service.py @@ -156,28 +156,6 @@ def _process_results(self, results: list[list[UserProfile]]) -> None: existing_ids_to_delete: list[str] = [] superseded_profiles: list[UserProfile] = [] - # Always run deduplicator when enabled and there are new profiles - if all_new_profiles: - from reflexio.server.site_var.feature_flags import is_deduplicator_enabled - - if is_deduplicator_enabled(self.org_id): - from reflexio.server.services.profile.profile_deduplicator import ( - ProfileDeduplicator, - ) - - deduplicator = ProfileDeduplicator( - request_context=self.request_context, - llm_client=self.client, - ) - all_new_profiles, existing_ids_to_delete, superseded_profiles = ( - deduplicator.deduplicate(all_new_profiles, user_id, request_id) - ) - logger.info( - "Profile updates after deduplication: %d profiles, %d existing to delete", - len(all_new_profiles), - len(existing_ids_to_delete), - ) - # Set source and status for all profiles for profile in all_new_profiles: profile.source = source diff --git a/reflexio/server/services/profile/profile_generation_service_utils.py b/reflexio/server/services/profile/profile_generation_service_utils.py index 9106b743..973daff9 100644 --- a/reflexio/server/services/profile/profile_generation_service_utils.py +++ b/reflexio/server/services/profile/profile_generation_service_utils.py @@ -8,6 +8,7 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator +from reflexio.models.api_schema.common import NEVER_EXPIRES_TIMESTAMP from reflexio.models.api_schema.internal_schema import RequestInteractionDataModel from reflexio.models.api_schema.service_schemas import ( ProfileTimeToLive, @@ -91,6 +92,18 @@ class ProfileAddItem(BaseModel): default=None, description="Metadata extracted for the profile based on metadata definition", ) + source_span: str | None = Field( + default=None, + description="Verbatim excerpt from the source that most directly supports this profile item", + ) + notes: str | None = Field( + default=None, + description="Free-form extraction notes — confidence, caveats, or alternative readings", + ) + reader_angle: str | None = Field( + default=None, + description="The extraction perspective or reader role that surfaced this item", + ) # OpenAI structured output requires explicit schema constraints model_config = ConfigDict( @@ -161,9 +174,10 @@ def calculate_expiration_timestamp( Returns: The expiration timestamp for the profile. """ - expiration_timestamp = datetime.max - last_modified_datetime = datetime.fromtimestamp(last_modified_timestamp) + if profile_time_to_live == ProfileTimeToLive.INFINITY: + return NEVER_EXPIRES_TIMESTAMP + last_modified_datetime = datetime.fromtimestamp(last_modified_timestamp) if profile_time_to_live == ProfileTimeToLive.ONE_DAY: expiration_timestamp = last_modified_datetime + timedelta(days=1) elif profile_time_to_live == ProfileTimeToLive.ONE_WEEK: @@ -174,16 +188,9 @@ def calculate_expiration_timestamp( expiration_timestamp = last_modified_datetime + timedelta(days=90) elif profile_time_to_live == ProfileTimeToLive.ONE_YEAR: expiration_timestamp = last_modified_datetime + timedelta(days=365) - elif profile_time_to_live == ProfileTimeToLive.INFINITY: - expiration_timestamp = datetime.max else: raise ValueError(f"Invalid profile time to live: {profile_time_to_live}") - try: - return int(expiration_timestamp.timestamp()) - except (OverflowError, OSError, ValueError): - import sys - - return sys.maxsize + return int(expiration_timestamp.timestamp()) def check_string_token_overlap(str1: str, str2: str, threshold: float = 0.7) -> bool: diff --git a/reflexio/server/services/search/__init__.py b/reflexio/server/services/search/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/reflexio/server/services/search/agentic_search_service.py b/reflexio/server/services/search/agentic_search_service.py new file mode 100644 index 00000000..e5bb0219 --- /dev/null +++ b/reflexio/server/services/search/agentic_search_service.py @@ -0,0 +1,308 @@ +"""AgenticSearchService — single SearchAgent loop replacing the v1 6+2 stack. + +Agentic-v2 delegates to a single ``SearchAgent`` that drives a tool loop +(``search_user_profiles``, ``search_user_playbooks``, ``search_agent_playbooks``, +``finish``) and returns a free-text answer plus populated entity lists harvested +from the tool-loop trace. + +API contract preserved: +- Constructor: ``AgenticSearchService(llm_client, request_context)`` +- Method: ``.search(request: UnifiedSearchRequest) -> UnifiedSearchResponse`` +- ``UnifiedSearchResponse.agent_answer`` carries the agent's natural-language answer. +- ``UnifiedSearchResponse.profiles`` / ``user_playbooks`` / ``agent_playbooks`` are + populated by filtering per-user storage reads against the IDs seen in the trace. +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +from reflexio.models.api_schema.retriever_schema import ( + UnifiedSearchRequest, + UnifiedSearchResponse, +) +from reflexio.server.services.pre_retrieval import QueryReformulator +from reflexio.server.services.search.plan import SearchResult +from reflexio.server.services.search.search_agent import SearchAgent + +if TYPE_CHECKING: + from reflexio.server.api_endpoints.request_context import RequestContext + from reflexio.server.llm.litellm_client import LiteLLMClient + from reflexio.server.llm.tools import ToolLoopTrace + +logger = logging.getLogger(__name__) + +# Tool names that produce profile results in the trace +_PROFILE_TOOLS = {"search_user_profiles", "get_user_profile"} +# Tool names that produce user playbook results in the trace +_USER_PLAYBOOK_TOOLS = {"search_user_playbooks", "get_user_playbook"} +# Tool names that produce agent playbook results in the trace +_AGENT_PLAYBOOK_TOOLS = {"search_agent_playbooks", "get_agent_playbook"} + + +def _harvest_ids_from_trace( + trace: ToolLoopTrace, +) -> tuple[list[str], list[str], list[str]]: + """Walk the trace and harvest entity IDs in first-seen order. + + Args: + trace (ToolLoopTrace): Full tool-loop trace from a SearchAgent run. + + Returns: + tuple[list[str], list[str], list[str]]: Three ordered lists of unique IDs: + profile_ids, user_playbook_ids, agent_playbook_ids. + """ + profile_ids: list[str] = [] + user_playbook_ids: list[str] = [] + agent_playbook_ids: list[str] = [] + + seen_profiles: set[str] = set() + seen_user_playbooks: set[str] = set() + seen_agent_playbooks: set[str] = set() + + for turn in trace.turns: + tool = turn.tool_name + result = turn.result + + if tool in _PROFILE_TOOLS: + # search returns {"hits": [...]} each item has "id" + # get returns {"profile": {...}} with "id" + items = result.get("hits") or ( + [result["profile"]] if "profile" in result else [] + ) + for item in items: + pid = item.get("id", "") if isinstance(item, dict) else "" + if pid and pid not in seen_profiles: + seen_profiles.add(pid) + profile_ids.append(pid) + + elif tool in _USER_PLAYBOOK_TOOLS: + items = result.get("hits") or ( + [result["playbook"]] if "playbook" in result else [] + ) + for item in items: + pid = item.get("id", "") if isinstance(item, dict) else "" + if pid and pid not in seen_user_playbooks: + seen_user_playbooks.add(pid) + user_playbook_ids.append(pid) + + elif tool in _AGENT_PLAYBOOK_TOOLS: + items = result.get("hits") or ( + [result["playbook"]] if "playbook" in result else [] + ) + for item in items: + pid = item.get("id", "") if isinstance(item, dict) else "" + if pid and pid not in seen_agent_playbooks: + seen_agent_playbooks.add(pid) + agent_playbook_ids.append(pid) + + return profile_ids, user_playbook_ids, agent_playbook_ids + + +def _filter_ordered( + entities: list, + id_attr: str, + ordered_ids: list[str], + top_k: int, +) -> list: + """Filter entities by id set and return them in first-seen trace order, capped at top_k. + + Args: + entities (list): Full list of entities fetched from storage. + id_attr (str): Attribute name on each entity that holds its string ID. + ordered_ids (list[str]): IDs in first-seen trace order. + top_k (int): Maximum number of results to return. + + Returns: + list: Filtered and ordered entities, at most top_k items. + """ + id_set = set(ordered_ids) + by_id = { + str(getattr(e, id_attr, "")): e + for e in entities + if str(getattr(e, id_attr, "")) in id_set + } + return [by_id[eid] for eid in ordered_ids if eid in by_id][:top_k] + + +class AgenticSearchService: + """Agentic search orchestrator wired into the backend dispatcher. + + Construction matches ``UnifiedSearchService`` so ``build_search_service`` + can swap the two transparently: both accept ``llm_client`` and + ``request_context`` as keyword arguments. + + Args: + llm_client (LiteLLMClient): Configured LLM client for all agent calls. + request_context (RequestContext): Request context providing + ``storage`` and ``prompt_manager``. + """ + + def __init__( + self, + *, + llm_client: LiteLLMClient, + request_context: RequestContext, + ) -> None: + self.client = llm_client + self.request_context = request_context + self.storage = request_context.storage + self.prompt_manager = request_context.prompt_manager + + def search(self, request: UnifiedSearchRequest) -> UnifiedSearchResponse: + """Execute the agentic-v2 search for one request. + + Optionally reformulates the query, then delegates to ``SearchAgent`` + which drives a tool loop and returns a natural-language answer. + Entity IDs visited during the loop are harvested from the trace and + used to populate the response entity lists. + + Args: + request (UnifiedSearchRequest): The unified search request. + + Returns: + UnifiedSearchResponse: ``success=True``, entity lists populated from + the agent's trace, and the agent's answer in ``agent_answer``. + ``reformulated_query`` carries the (possibly rewritten) query used + for the search. + """ + # Reject requests missing the user_id rather than silently coercing + # to empty strings. An empty user_id flows into storage operations + # (storage.get_user_profile, storage.add_user_profile) and would + # either return cross-user data on SqliteStorage or write to an + # unintended path on DiskStorage. Surface the bug at the boundary. + # agent_version is NOT required — it scopes AgentPlaybook reads + # (cross-user rules), and an empty value just means "no AgentPlaybook + # scope filter," which is safe. + if not request.user_id: + raise ValueError( + "agentic search requires a non-empty user_id; got empty" + ) + + query = self._reformulate(request) + + agent = SearchAgent( + client=self.client, + storage=self.storage, + prompt_manager=self.prompt_manager, + # Tight budget for benchmark throughput; default is 10. + # Floor is 2 (one search → finish); 5 accommodates the + # rehydration-mandated patterns (search → reformulate → + # rehydrate → finish = 4) plus one optional rerank step, + # while still bounding wasted work on simple questions. + max_steps=5, + enable_agent_answer=bool(request.enable_agent_answer), + ) + result = agent.run( + user_id=request.user_id, + agent_version=request.agent_version or "", + query=query, + ) + + if result.outcome == "error": + logger.warning("search agent returned error for query %r", query[:80]) + return UnifiedSearchResponse( + success=True, + profiles=[], + user_playbooks=[], + agent_playbooks=[], + reformulated_query=query, + msg=f"agent error: {result.answer or 'unknown'}", + agent_answer=None, + ) + + if result.budget_exceeded: + logger.warning("search agent hit max_steps budget for query %r", query[:80]) + + profiles, user_playbooks, agent_playbooks = self._fetch_entities( + request, result + ) + + return UnifiedSearchResponse( + success=True, + profiles=profiles, + user_playbooks=user_playbooks, + agent_playbooks=agent_playbooks, + reformulated_query=query, + msg=None, + agent_answer=result.answer, + ) + + # ------------------------------------------------------------------ # + # Internal helpers # + # ------------------------------------------------------------------ # + + def _reformulate(self, request: UnifiedSearchRequest) -> str: + """Run QueryReformulator when enabled; otherwise return the raw query. + + Reformulation failures fall back to the raw query (the reformulator + is responsible for its own exception handling). + + Args: + request (UnifiedSearchRequest): The search request. + + Returns: + str: Reformulated query string, or the original query if + reformulation is disabled or the reformulator returns nothing. + """ + if not request.enable_reformulation: + return request.query + reformulator = QueryReformulator( + llm_client=self.client, prompt_manager=self.prompt_manager + ) + result = reformulator.rewrite(request.query, request.conversation_history) + return result.standalone_query or request.query + + def _fetch_entities( + self, + request: UnifiedSearchRequest, + result: SearchResult, + ) -> tuple[list, list, list]: + """Harvest entity IDs from trace, fetch all-user entities once, filter in-memory. + + Args: + request (UnifiedSearchRequest): The original search request (for user_id, + agent_version, top_k). + result (SearchResult): Completed agent run with trace. + + Returns: + tuple[list, list, list]: (profiles, user_playbooks, agent_playbooks) each + filtered and ordered by first-seen trace position, capped at top_k. + """ + top_k = request.top_k or 5 + user_id = request.user_id or "" + agent_version = request.agent_version or "" + + profile_ids, user_playbook_ids, agent_playbook_ids = _harvest_ids_from_trace( + result.trace + ) + + storage = self.storage + if storage is None: + return [], [], [] + + profiles: list = [] + if profile_ids: + all_profiles = storage.get_user_profile(user_id) + profiles = _filter_ordered(all_profiles, "profile_id", profile_ids, top_k) + + user_playbooks: list = [] + if user_playbook_ids: + all_user_playbooks = storage.get_user_playbooks( + user_id=user_id, agent_version=agent_version + ) + user_playbooks = _filter_ordered( + all_user_playbooks, "user_playbook_id", user_playbook_ids, top_k + ) + + agent_playbooks: list = [] + if agent_playbook_ids: + all_agent_playbooks = storage.get_agent_playbooks( + agent_version=agent_version + ) + agent_playbooks = _filter_ordered( + all_agent_playbooks, "agent_playbook_id", agent_playbook_ids, top_k + ) + + return profiles, user_playbooks, agent_playbooks diff --git a/reflexio/server/services/search/plan.py b/reflexio/server/services/search/plan.py new file mode 100644 index 00000000..1172ca33 --- /dev/null +++ b/reflexio/server/services/search/plan.py @@ -0,0 +1,29 @@ +"""Plan types for the agentic-v2 search pipeline.""" + +from __future__ import annotations + +from typing import Literal + +from pydantic import BaseModel, ConfigDict + +from reflexio.server.llm.tools import ToolLoopTrace + + +class SearchResult(BaseModel): + """Outcome of one SearchAgent run. + + Args: + answer (str | None): The LLM-synthesised answer from finish(answer); None + when the agent ran in search-only mode (``enable_agent_answer=False``) + and deliberately did not synthesize a free-text answer. + outcome (str): How the loop terminated. + budget_exceeded (bool): True when outcome == "max_steps". + trace (ToolLoopTrace): Full tool-loop trace — ids harvested by callers for entity fetch. + """ + + model_config = ConfigDict(arbitrary_types_allowed=True) + + answer: str | None + outcome: Literal["finish_tool", "max_steps", "error"] + budget_exceeded: bool + trace: ToolLoopTrace diff --git a/reflexio/server/services/search/search_agent.py b/reflexio/server/services/search/search_agent.py new file mode 100644 index 00000000..300c1153 --- /dev/null +++ b/reflexio/server/services/search/search_agent.py @@ -0,0 +1,163 @@ +"""Thin runner for the agentic-v2 search pipeline. Read-only — no commit stage.""" + +from __future__ import annotations + +import logging +import time +from collections import Counter + +from reflexio.server.llm.litellm_client import LiteLLMClient +from reflexio.server.llm.model_defaults import ModelRole +from reflexio.server.llm.tools import ToolLoopTrace, run_tool_loop +from reflexio.server.prompt.prompt_manager import PromptManager +from reflexio.server.services.extraction.plan import ExtractionCtx, HandlerBundle +from reflexio.server.services.extraction.tools import ( + SEARCH_TOOLS, + SearchAgentTurnPlan, +) +from reflexio.server.services.search.plan import SearchResult + +logger = logging.getLogger(__name__) + + +def _summarise_tool_calls(trace: ToolLoopTrace) -> str: + """Return a compact 'tool_a:2, tool_b:1' string from a ToolLoopTrace. + + Args: + trace (ToolLoopTrace): The completed tool loop trace. + + Returns: + str: Comma-separated name:count pairs ordered by frequency, or '(none)'. + """ + counts = Counter(t.tool_name for t in trace.turns) + return ", ".join(f"{name}:{n}" for name, n in counts.most_common()) or "(none)" + + +def _summarise_usage(trace: ToolLoopTrace) -> str: + """Return a per-model 'model_x: N tokens, $0.0078' string aggregated across all turns. + + A single response's usage is attached to every turn it produced, so this + function deduplicates by (model, prompt_tokens, completion_tokens) to avoid + double-counting when one LLM call produced multiple tool calls. + + Args: + trace (ToolLoopTrace): The completed tool loop trace. + + Returns: + str: Semicolon-separated per-model summaries, or '(none)'. + """ + seen: set[tuple[str, int, int]] = set() + per_model: dict[str, dict[str, float]] = {} + for t in trace.turns: + if t.model is None or t.prompt_tokens is None or t.completion_tokens is None: + continue + key = (t.model, t.prompt_tokens, t.completion_tokens) + if key in seen: + continue + seen.add(key) + bucket = per_model.setdefault(t.model, {"tokens": 0.0, "cost": 0.0}) + bucket["tokens"] += t.total_tokens or 0 + bucket["cost"] += t.cost_usd or 0.0 + if not per_model: + return "(none)" + return "; ".join( + f"{m}: {int(v['tokens'])} tokens, ${v['cost']:.6f}" + for m, v in per_model.items() + ) + + +class SearchAgent: + """Single-loop adaptive search agent (read-only). + + Assembles the seed message from the search_agent prompt, drives + ``run_tool_loop`` with ``SEARCH_TOOLS``, and extracts the answer stashed on + ctx by ``_handle_search_finish``. No commit stage occurs. + + Args: + client (LiteLLMClient): LLM client for the underlying tool loop. + storage: BaseStorage handle (read-only for this agent). + prompt_manager (PromptManager): Renders the ``search_agent`` prompt. + max_steps (int): Cap on tool-calling turns (default 10; spec §7.2). + """ + + def __init__( + self, + *, + client: LiteLLMClient, + storage: object, + prompt_manager: PromptManager, + max_steps: int = 10, + enable_agent_answer: bool = False, + ) -> None: + self.client = client + self.storage = storage + self.prompt_manager = prompt_manager + self.max_steps = max_steps + self.enable_agent_answer = enable_agent_answer + + def run(self, *, user_id: str, agent_version: str, query: str) -> SearchResult: + """Run one search loop for the given query. + + Args: + user_id (str): Authenticated user scope. + agent_version (str): Active agent_version for playbook scoping. + query (str): The search query to answer. + + Returns: + SearchResult: Typed outcome with answer, termination reason, budget flag, + and the full tool-loop trace for entity harvesting by callers. + """ + ctx = ExtractionCtx(user_id=user_id, agent_version=agent_version) + bundle = HandlerBundle(storage=self.storage, ctx=ctx) + + prompt = self.prompt_manager.render_prompt( + "search_agent", + variables={ + "query": query, + "max_steps": str(self.max_steps), + "enable_agent_answer": "true" if self.enable_agent_answer else "false", + }, + ) + + t0 = time.monotonic() + result = run_tool_loop( + client=self.client, + messages=[{"role": "user", "content": prompt}], + registry=SEARCH_TOOLS, + model_role=ModelRole.SEARCH_AGENT, + max_steps=self.max_steps, + ctx=bundle, + finish_tool_name="finish", + multi_stage_schema=SearchAgentTurnPlan, + log_label="search_agent", + ) + + # In search-only mode the agent is told to call finish() with no answer; + # we surface None so callers can distinguish "agent declined to answer" + # from "agent failed". Tests that exercised the answer path keep working + # because they default-construct SearchAgent with enable_agent_answer=False + # but populate ctx.search_answer via the mocked finish() call — when off, + # we deliberately drop whatever the agent wrote so the contract is clear. + if not self.enable_agent_answer: + answer: str | None = None + else: + answer = ctx.search_answer if ctx.search_answer is not None else "no answer" + elapsed_ms = int((time.monotonic() - t0) * 1000) + + logger.info( + "search_agent elapsed_ms=%d turns=%d/%d tools={%s} outcome=%s " + "answer_len=%d usage={%s}", + elapsed_ms, + len(result.trace.turns), + self.max_steps, + _summarise_tool_calls(result.trace), + result.finished_reason, + len(answer) if answer is not None else 0, + _summarise_usage(result.trace), + ) + return SearchResult( + answer=answer, + outcome=result.finished_reason, + budget_exceeded=result.finished_reason == "max_steps", + trace=result.trace, + ) diff --git a/reflexio/server/services/service_utils.py b/reflexio/server/services/service_utils.py index fc366b91..9422bc52 100644 --- a/reflexio/server/services/service_utils.py +++ b/reflexio/server/services/service_utils.py @@ -7,6 +7,7 @@ import logging import re from dataclasses import dataclass +from datetime import UTC, datetime from typing import Any from reflexio.cli.log_format import LLM_IO_LOG_FILE, next_llm_entry_id @@ -25,6 +26,42 @@ MODEL_RESPONSE_LEVEL = 25 +def _format_response_for_logging(response: Any) -> Any: + """Render ``ToolCallingChatResponse`` with pretty tool_calls; pass others through. + + The dataclass's ``__repr__`` (which ``%s`` formatting falls back to) + prints each tool_call as an opaque object handle + (````), erasing the + tool name + arguments the model emitted. This helper detects that + one case and renders a multi-line human-readable form using the + same ``_format_tool_calls`` helper the request-side formatter uses. + + All other response types (strings, Pydantic ``BaseModel`` instances + from classic extractors / deduplicators / aggregators) fall through + unchanged so the existing log shape is preserved. + + Lazy-imports ``ToolCallingChatResponse`` to avoid a circular + ``service_utils`` ↔ ``litellm_client`` dependency at module load. + """ + try: + from reflexio.server.llm.litellm_client import ToolCallingChatResponse + except Exception: # noqa: BLE001 - fall back gracefully if the import fails + return response + + if not isinstance(response, ToolCallingChatResponse): + return response + + lines = [ + f"ToolCallingChatResponse(finish_reason={response.finish_reason!r}):", + f" content: {response.content!r}", + ] + if response.tool_calls: + lines.extend(_format_tool_calls(response.tool_calls)) + else: + lines.append(" tool_calls: []") + return "\n".join(lines) + + def log_model_response( target_logger: logging.Logger, label: str, response: Any ) -> None: @@ -38,13 +75,16 @@ def log_model_response( response (Any): The model response to log """ entry_id = next_llm_entry_id() + # Special-case ToolCallingChatResponse so tool_calls render as + # id/name/arguments instead of opaque ``<… object at 0x…>`` handles. + formatted = _format_response_for_logging(response) # Full response to llm_io.log only (level 15 < INFO 20, so console ignores it) target_logger.log( LLM_PROMPT_LEVEL, "[#%d] %s: %s", entry_id, label, - response, + formatted, extra={"entry_id": entry_id, "label": label}, ) # One-line summary to console @@ -229,8 +269,36 @@ def format_sessions_to_history_string( formatted_groups = [] for group_name in sorted_group_names: - # Format header with session name - group_header = f"=== Session: {group_name} ===" + # Format header with session name AND its earliest interaction date. + # Without the date, downstream extraction agents have no anchor for + # resolving relative-time references in the conversation + # ("X weeks ago", "yesterday", "two days before the wedding") — + # they fall back to real-world `now()` and encode every event as + # today's date, breaking temporal-reasoning queries. + # + # We use the earliest *interaction* timestamp, not request.created_at, + # because Request.created_at defaults to `now()` on construction — + # only interactions reliably carry the conversation's true wall-clock + # time when the publisher provides it. + all_ts: list[int] = [ + i.created_at + for ri in grouped_by_name[group_name] + for i in ri.interactions + if i.created_at + ] + first_ts = min(all_ts) if all_ts else 0 + if first_ts: + try: + session_date_iso = datetime.fromtimestamp( + first_ts, tz=UTC + ).strftime("%Y-%m-%d") + group_header = ( + f"=== Session: {group_name} (date: {session_date_iso}) ===" + ) + except (OverflowError, OSError, ValueError): + group_header = f"=== Session: {group_name} ===" + else: + group_header = f"=== Session: {group_name} ===" # Combine all interactions from all requests in this session all_interactions = [] @@ -479,6 +547,57 @@ def parse_json_candidate(json_str: str) -> tuple[dict | None, str | None]: return {} +def _format_tool_calls(tool_calls: list[Any]) -> list[str]: + """Render an assistant message's ``tool_calls`` list for the log. + + Accepts either the OpenAI SDK object shape (with ``.function.name`` / + ``.function.arguments`` attrs) or the dict shape that pass-through + serialisation may produce. Returns one indented line per call with the + tool_call_id, the tool name, and the parsed arguments — so the log + reader can correlate each tool_call with its tool-role response. + """ + lines: list[str] = [" tool_calls:"] + for tc in tool_calls: + # Extract id, name, arguments from either attribute or mapping shape. + tc_id = getattr(tc, "id", None) or ( + tc.get("id") if isinstance(tc, dict) else None + ) + fn = getattr(tc, "function", None) + if fn is not None: + name = getattr(fn, "name", None) + args_raw = getattr(fn, "arguments", None) + elif isinstance(tc, dict): + fn_dict = tc.get("function", {}) or {} + name = fn_dict.get("name") if isinstance(fn_dict, dict) else None + args_raw = fn_dict.get("arguments") if isinstance(fn_dict, dict) else None + else: + name = None + args_raw = None + + # arguments comes through as a JSON string from the provider — parse + # for readability, fall back to raw text on malformed JSON. + parsed_args: Any + if isinstance(args_raw, str): + try: + parsed_args = json.loads(args_raw) + except json.JSONDecodeError: + parsed_args = args_raw + else: + parsed_args = args_raw + + lines.append(f" - id: {tc_id}") + lines.append(f" name: {name}") + # Logging path must never raise — fall back to repr() on + # non-serializable argument objects (datetime, sets, custom + # types, etc.) so a logging call can't take down a request. + try: + rendered_args = json.dumps(parsed_args) + except (TypeError, ValueError): + rendered_args = repr(parsed_args) + lines.append(f" arguments: {rendered_args}") + return lines + + def format_messages_for_logging(messages: list[dict[str, Any]]) -> str: """ Format messages for logging with proper newlines in text content. @@ -493,6 +612,14 @@ def format_messages_for_logging(messages: list[dict[str, Any]]) -> str: for i, msg in enumerate(messages): formatted_parts.append(f"Message {i + 1}:") formatted_parts.append(f" role: {msg.get('role', 'unknown')}") + + # Tool-role messages carry a ``tool_call_id`` that correlates them + # back to the assistant's emitted call — render it so readers can + # reconstruct which response answered which call. + tool_call_id = msg.get("tool_call_id") + if tool_call_id is not None: + formatted_parts.append(f" tool_call_id: {tool_call_id}") + content = msg.get("content", "") if isinstance(content, str): @@ -523,6 +650,14 @@ def format_messages_for_logging(messages: list[dict[str, Any]]) -> str: # Fallback to JSON for other types formatted_parts.append(f" content: {json.dumps(content, indent=4)}") + # Assistant messages with tool_calls must render the call list — + # otherwise the log shows ``content: null`` with no visibility into + # which tools the model invoked. Classic extraction doesn't use + # tool-calling, but the agentic pipeline relies on it heavily. + tool_calls = msg.get("tool_calls") + if tool_calls: + formatted_parts.extend(_format_tool_calls(tool_calls)) + formatted_parts.append("") # Empty line between messages return "\n".join(formatted_parts) diff --git a/reflexio/server/services/storage/sqlite_storage/_base.py b/reflexio/server/services/storage/sqlite_storage/_base.py index a54f48c2..4681ec55 100644 --- a/reflexio/server/services/storage/sqlite_storage/_base.py +++ b/reflexio/server/services/storage/sqlite_storage/_base.py @@ -334,6 +334,9 @@ def _row_to_profile(row: sqlite3.Row) -> UserProfile: status=Status(d["status"]) if d.get("status") else None, extractor_names=_json_loads(d.get("extractor_names")), expanded_terms=d.get("expanded_terms"), + source_span=d.get("source_span"), + notes=d.get("notes"), + reader_angle=d.get("reader_angle"), ) @@ -400,6 +403,9 @@ def _row_to_user_playbook( source_interaction_ids=_json_loads(d.get("source_interaction_ids")) or [], embedding=embedding, expanded_terms=d.get("expanded_terms"), + source_span=d.get("source_span"), + notes=d.get("notes"), + reader_angle=d.get("reader_angle"), ) @@ -599,6 +605,7 @@ def migrate(self) -> bool: self._migrate_vec_tables() # Run after DDL so tables exist on fresh databases self._migrate_expanded_terms() + self._migrate_agentic_signals() return True def _try_load_sqlite_vec(self) -> bool: @@ -842,6 +849,24 @@ def _migrate_expanded_terms(self) -> None: logger.info("Added expanded_terms column to %s", table) self.conn.commit() + def _migrate_agentic_signals(self) -> None: + """Add source_span/notes/reader_angle columns if missing. + + Backfill-safe: columns are nullable with no default. Applies to both + the profiles and user_playbooks tables — the agentic extraction + pipeline populates them per-row; classic extraction leaves them NULL. + """ + for table in ("profiles", "user_playbooks"): + cols = { + row["name"] + for row in self.conn.execute(f"PRAGMA table_info({table})").fetchall() + } + for col in ("source_span", "notes", "reader_angle"): + if col not in cols: + self.conn.execute(f"ALTER TABLE {table} ADD COLUMN {col} TEXT") # noqa: S608 + logger.info("Added %s column to %s", col, table) + self.conn.commit() + # ------------------------------------------------------------------ # Internal helpers # ------------------------------------------------------------------ @@ -1048,6 +1073,9 @@ def _vec_knn_search( status TEXT, extractor_names TEXT, expanded_terms TEXT, + source_span TEXT, + notes TEXT, + reader_angle TEXT, created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ', 'now')) ); CREATE INDEX IF NOT EXISTS idx_profiles_user_id ON profiles(user_id); @@ -1099,7 +1127,10 @@ def _vec_knn_search( status TEXT, source TEXT, embedding TEXT, - expanded_terms TEXT + expanded_terms TEXT, + source_span TEXT, + notes TEXT, + reader_angle TEXT ); CREATE INDEX IF NOT EXISTS idx_user_playbooks_playbook_name ON user_playbooks(playbook_name); CREATE INDEX IF NOT EXISTS idx_user_playbooks_agent_version ON user_playbooks(agent_version); diff --git a/reflexio/server/services/storage/sqlite_storage/_playbook.py b/reflexio/server/services/storage/sqlite_storage/_playbook.py index 3f7fd81c..c91d1646 100644 --- a/reflexio/server/services/storage/sqlite_storage/_playbook.py +++ b/reflexio/server/services/storage/sqlite_storage/_playbook.py @@ -81,8 +81,9 @@ def save_user_playbooks(self, user_playbooks: list[UserPlaybook]) -> None: (user_id, playbook_name, created_at, request_id, agent_version, content, trigger, rationale, blocking_issue, source_interaction_ids, - status, source, embedding, expanded_terms) - VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?)""", + status, source, embedding, expanded_terms, + source_span, notes, reader_angle) + VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)""", ( up.user_id, up.playbook_name, @@ -100,6 +101,9 @@ def save_user_playbooks(self, user_playbooks: list[UserPlaybook]) -> None: up.source, _json_dumps(up.embedding), up.expanded_terms, + up.source_span, + up.notes, + up.reader_angle, ), ) upid = cur.lastrowid or 0 diff --git a/reflexio/server/services/storage/sqlite_storage/_profiles.py b/reflexio/server/services/storage/sqlite_storage/_profiles.py index 6e21b4bb..099279e6 100644 --- a/reflexio/server/services/storage/sqlite_storage/_profiles.py +++ b/reflexio/server/services/storage/sqlite_storage/_profiles.py @@ -108,8 +108,9 @@ def add_user_profile(self, user_id: str, user_profiles: list[UserProfile]) -> No (profile_id, user_id, content, last_modified_timestamp, generated_from_request_id, profile_time_to_live, expiration_timestamp, custom_features, embedding, source, - status, extractor_names, expanded_terms, created_at) - VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?)""", + status, extractor_names, expanded_terms, + source_span, notes, reader_angle, created_at) + VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)""", ( profile.profile_id, profile.user_id, @@ -124,6 +125,9 @@ def add_user_profile(self, user_id: str, user_profiles: list[UserProfile]) -> No profile.status.value if profile.status else None, _json_dumps(profile.extractor_names), profile.expanded_terms, + profile.source_span, + profile.notes, + profile.reader_angle, _iso_now(), ), ) @@ -164,7 +168,8 @@ def update_user_profile_by_id( """UPDATE profiles SET content=?, last_modified_timestamp=?, generated_from_request_id=?, profile_time_to_live=?, expiration_timestamp=?, custom_features=?, embedding=?, - source=?, status=?, extractor_names=?, expanded_terms=? + source=?, status=?, extractor_names=?, expanded_terms=?, + source_span=?, notes=?, reader_angle=? WHERE profile_id=?""", ( new_profile.content, @@ -178,6 +183,9 @@ def update_user_profile_by_id( new_profile.status.value if new_profile.status else None, _json_dumps(new_profile.extractor_names), new_profile.expanded_terms, + new_profile.source_span, + new_profile.notes, + new_profile.reader_angle, profile_id, ), ) diff --git a/reflexio/server/services/unified_search_service.py b/reflexio/server/services/unified_search_service.py index a2be0ec5..380df0f6 100644 --- a/reflexio/server/services/unified_search_service.py +++ b/reflexio/server/services/unified_search_service.py @@ -6,9 +6,12 @@ Phase B: Entity searches across profiles, agent playbooks, user playbooks (parallel) """ +from __future__ import annotations + import logging from concurrent.futures import ThreadPoolExecutor from concurrent.futures import TimeoutError as FuturesTimeoutError +from typing import TYPE_CHECKING from reflexio.models.api_schema.retriever_schema import ( ConversationTurn, @@ -29,6 +32,9 @@ from reflexio.server.services.pre_retrieval import QueryReformulator from reflexio.server.services.storage.storage_base import BaseStorage +if TYPE_CHECKING: + from reflexio.server.api_endpoints.request_context import RequestContext + logger = logging.getLogger(__name__) @@ -268,3 +274,24 @@ def _search_profiles_via_storage( except Exception as e: logger.error("Profile search failed: %s", e) return [] + + +class UnifiedSearchService: + """Class handle for the classic unified search pipeline. + + Wraps :func:`run_unified_search` so the dispatcher factory can return an + object whose ``__class__.__name__`` can be inspected uniformly alongside + the agentic search service (Phase 4). + + Args: + llm_client (LiteLLMClient): Configured LLM client. + request_context (RequestContext): Current request context. + """ + + def __init__( + self, + llm_client: LiteLLMClient, + request_context: RequestContext, + ) -> None: + self.llm_client = llm_client + self.request_context = request_context diff --git a/reflexio/server/site_var/feature_flags.py b/reflexio/server/site_var/feature_flags.py index 59fb2ca1..67689c87 100644 --- a/reflexio/server/site_var/feature_flags.py +++ b/reflexio/server/site_var/feature_flags.py @@ -88,16 +88,3 @@ def is_invitation_only_enabled() -> bool: if invitation_config is None: return False return invitation_config.get("enabled", False) - - -def is_deduplicator_enabled(org_id: str) -> bool: - """ - Convenience check for whether the deduplicator is enabled for an org. - - Args: - org_id (str): The organization ID to check - - Returns: - bool: True if deduplicator is enabled - """ - return is_feature_enabled(org_id, "deduplicator") diff --git a/reflexio/server/uvicorn_logging.py b/reflexio/server/uvicorn_logging.py index a3caebf3..33da2d38 100644 --- a/reflexio/server/uvicorn_logging.py +++ b/reflexio/server/uvicorn_logging.py @@ -29,17 +29,24 @@ # Access-log fields mirror uvicorn's built-in AccessFormatter message shape, # minus the padded level prefix. -ACCESS_FORMAT = ( - '%(levelname)s: %(client_addr)s - "%(request_line)s" %(status_code)s' -) +ACCESS_FORMAT = '%(levelname)s: %(client_addr)s - "%(request_line)s" %(status_code)s' UVICORN_LOG_CONFIG: dict[str, Any] = { "version": 1, "disable_existing_loggers": False, "formatters": { + # Access format references uvicorn-specific fields (client_addr, + # request_line, status_code) that only ``uvicorn.logging.AccessFormatter`` + # knows how to populate from the log record's ``args`` tuple. The + # stdlib ``logging.Formatter`` raises ``KeyError: 'client_addr'`` on + # every request. Default formatter stays on stdlib because it uses + # only ``levelname`` / ``message``. "default": {"format": LEVEL_FORMAT}, - "access": {"format": ACCESS_FORMAT}, + "access": { + "()": "uvicorn.logging.AccessFormatter", + "fmt": ACCESS_FORMAT, + }, }, "handlers": { "default": { diff --git a/reflexio/test_support/llm_mock.py b/reflexio/test_support/llm_mock.py index 88b60598..271e1541 100644 --- a/reflexio/test_support/llm_mock.py +++ b/reflexio/test_support/llm_mock.py @@ -116,3 +116,53 @@ def cleanup_llm_mock(config: Any) -> None: # noqa: ARG001 if _litellm_patcher: _litellm_patcher.stop() _litellm_patcher = None + + +def make_tool_call_response(tool_name: str, args: dict[str, Any]) -> MagicMock: + """Build a litellm ModelResponse-shaped mock with a single tool_call. + + Used by unit tests that drive tool loops against the patched + ``litellm.completion``. Not routed automatically by prompt + heuristics — callers install it explicitly with ``side_effect``. + + Args: + tool_name (str): The name the assistant is calling. + args (dict[str, Any]): JSON-serialisable arguments passed to the tool. + + Returns: + MagicMock: A response object shaped like a litellm ModelResponse + whose first choice has ``finish_reason="tool_calls"`` and a + single tool call matching the given name and args. + """ + resp = MagicMock() + resp.choices = [MagicMock()] + resp.choices[0].finish_reason = "tool_calls" + resp.choices[0].message.content = None + tc = MagicMock() + tc.id = f"tc_{tool_name}" + tc.type = "function" + tc.function.name = tool_name + tc.function.arguments = json.dumps(args) + resp.choices[0].message.tool_calls = [tc] + return resp + + +def make_finish_response(text: str = "done") -> MagicMock: + """Build a normal (non-tool-call) assistant message. + + Used to terminate a tool loop that was driven by repeated + ``make_tool_call_response`` mocks. + + Args: + text (str): Content of the terminal message. + + Returns: + MagicMock: A response object with ``finish_reason="stop"``, + the given text, and ``tool_calls=None``. + """ + resp = MagicMock() + resp.choices = [MagicMock()] + resp.choices[0].finish_reason = "stop" + resp.choices[0].message.content = text + resp.choices[0].message.tool_calls = None + return resp diff --git a/reflexio/test_support/llm_model_registry.py b/reflexio/test_support/llm_model_registry.py index 3ee9a8e3..a0b4582c 100644 --- a/reflexio/test_support/llm_model_registry.py +++ b/reflexio/test_support/llm_model_registry.py @@ -33,16 +33,10 @@ def _build_registry() -> dict[str, ModelRegistryEntry]: AgentSuccessEvaluationOutput, AgentSuccessEvaluationWithComparisonOutput, ) - from reflexio.server.services.playbook.playbook_deduplicator import ( - PlaybookDeduplicationOutput, - ) from reflexio.server.services.playbook.playbook_service_utils import ( PlaybookAggregationOutput, StructuredPlaybookList, ) - from reflexio.server.services.profile.profile_deduplicator import ( - ProfileDeduplicationOutput, - ) from reflexio.server.services.profile.profile_generation_service_utils import ( ProfileUpdateOutput, StructuredProfilesOutput, @@ -69,13 +63,6 @@ def _build_registry() -> dict[str, ModelRegistryEntry]: }, }, ), - "playbook_deduplication": ModelRegistryEntry( - model_class=PlaybookDeduplicationOutput, - minimal_valid={ - "duplicate_groups": [], - "unique_ids": ["NEW-0"], - }, - ), "profile_extraction": ModelRegistryEntry( model_class=StructuredProfilesOutput, minimal_valid={ @@ -94,13 +81,6 @@ def _build_registry() -> dict[str, ModelRegistryEntry]: "mention": [], }, ), - "profile_deduplication": ModelRegistryEntry( - model_class=ProfileDeduplicationOutput, - minimal_valid={ - "duplicate_groups": [], - "unique_ids": ["NEW-0"], - }, - ), "agent_success_evaluation": ModelRegistryEntry( model_class=AgentSuccessEvaluationOutput, minimal_valid={ diff --git a/tests/cli/test_codex_auth.py b/tests/cli/test_codex_auth.py new file mode 100644 index 00000000..60e36de2 --- /dev/null +++ b/tests/cli/test_codex_auth.py @@ -0,0 +1,215 @@ +"""Unit tests for ``reflexio.cli.codex_auth`` — PKCE, JWT decoding, token storage. + +We don't exercise the full browser/callback flow here (that's an integration +concern). The tests below lock down the building blocks: + +- PKCE verifier/challenge generation produces RFC-7636-compatible output. +- JWT payload extraction handles both well-formed and pathological inputs. +- ``CodexTokens`` round-trips through ``save_tokens`` / ``load_tokens_raw``. +- ``is_expired`` honours the lead-time threshold. +- ``_tokens_from_response`` populates metadata from JWT claims correctly. +""" + +from __future__ import annotations + +import base64 +import hashlib +import json +import time +from pathlib import Path + +import pytest + +from reflexio.cli import codex_auth + + +def _b64url(data: bytes) -> str: + """Base64url-encode without padding (test helper, mirrors the module's).""" + return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii") + + +def _make_jwt(claims: dict) -> str: + """Build an RS256-shaped JWT from a payload dict. + + The signature is fake (the module deliberately does not verify), so we + can hand-craft tokens for the storage / refresh logic without involving + cryptography. Header is the constant Codex uses. + """ + header = _b64url(json.dumps({"alg": "RS256", "typ": "JWT"}).encode()) + payload = _b64url(json.dumps(claims).encode()) + sig = _b64url(b"fake-signature-not-verified") + return f"{header}.{payload}.{sig}" + + +class TestPkce: + def test_pair_shape(self) -> None: + verifier, challenge = codex_auth._make_pkce_pair() + # Both base64url, no padding. + assert "=" not in verifier + assert "=" not in challenge + # 32-byte random source -> 43-char base64url. + assert len(verifier) == 43 + # Challenge is base64url(SHA-256(verifier ASCII)). + expected = _b64url(hashlib.sha256(verifier.encode("ascii")).digest()) + assert challenge == expected + + def test_pairs_are_unique(self) -> None: + # Different invocations should not collide (32-byte entropy). + pairs = {codex_auth._make_pkce_pair()[0] for _ in range(50)} + assert len(pairs) == 50 + + +class TestJwtDecoding: + def test_decode_extracts_payload(self) -> None: + claims = {"foo": "bar", "exp": 1234567890} + jwt = _make_jwt(claims) + out = codex_auth._decode_jwt_payload(jwt) + assert out == claims + + def test_decode_handles_unpadded_b64(self) -> None: + # Codex JWTs typically have no padding on the payload segment; + # the decoder must restore it on the fly. + claims = {"x": 1} + jwt = _make_jwt(claims) + # Strip any incidental trailing '=' just in case. + assert "=" not in jwt + assert codex_auth._decode_jwt_payload(jwt) == claims + + def test_decode_rejects_malformed(self) -> None: + with pytest.raises(ValueError, match="not a JWT"): + codex_auth._decode_jwt_payload("not.a.jwt.at.all") + with pytest.raises(ValueError, match="not a JWT"): + codex_auth._decode_jwt_payload("only-one-part") + + +class TestTokensFromResponse: + def test_extracts_account_id_and_plan_type(self) -> None: + # Mirror the JWT shape OpenAI issues: chatgpt_plan_type lives under + # the namespaced ``https://api.openai.com/auth`` claim, email under + # ``https://api.openai.com/profile``. + claims = { + "exp": int(time.time()) + 3600, + "https://api.openai.com/auth": { + "chatgpt_account_id": "acct-abc-123", + "chatgpt_plan_type": "max-x20", + }, + "https://api.openai.com/profile": { + "email": "user@example.com", + }, + } + access = _make_jwt(claims) + payload = { + "access_token": access, + "refresh_token": "rt_abc", + "expires_in": 3600, + } + tokens = codex_auth._tokens_from_response(payload) + assert tokens.access_token == access + assert tokens.refresh_token == "rt_abc" + assert tokens.account_id == "acct-abc-123" + assert tokens.plan_type == "max-x20" + assert tokens.email == "user@example.com" + assert tokens.expires_at == claims["exp"] + + def test_falls_back_to_expires_in_when_jwt_lacks_exp(self) -> None: + claims = {"https://api.openai.com/auth": {}} # no exp + access = _make_jwt(claims) + before = int(time.time()) + tokens = codex_auth._tokens_from_response( + {"access_token": access, "refresh_token": "rt", "expires_in": 600} + ) + # Allow a small wall-time window (<2s) for the test runner. + assert before + 600 <= tokens.expires_at <= before + 602 + + def test_rejects_missing_required_fields(self) -> None: + with pytest.raises(ValueError, match="missing access_token"): + codex_auth._tokens_from_response({"refresh_token": "rt"}) + with pytest.raises(ValueError, match="missing access_token"): + codex_auth._tokens_from_response({"access_token": _make_jwt({})}) + + +class TestTokenStorage: + def test_save_and_load_round_trip(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + # Redirect storage to a temp dir so the test never touches the + # developer's real ~/.reflexio/auth/. + monkeypatch.setattr(codex_auth, "REFLEXIO_AUTH_DIR", tmp_path / "auth") + monkeypatch.setattr( + codex_auth, + "REFLEXIO_CODEX_TOKENS_PATH", + tmp_path / "auth" / "openai-codex.json", + ) + + tokens = codex_auth.CodexTokens( + access_token="a-jwt", + refresh_token="rt-1", + account_id="acct-x", + expires_at=1234, + plan_type="max-x20", + email="x@y.com", + ) + path = codex_auth.save_tokens(tokens) + assert path.exists() + # File mode should be 0600 on POSIX (best-effort on platforms that + # don't support it; we just check the round-trip below). + loaded = codex_auth.load_tokens_raw() + assert loaded == tokens + + def test_load_returns_none_when_missing(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + codex_auth, + "REFLEXIO_CODEX_TOKENS_PATH", + tmp_path / "openai-codex.json", + ) + assert codex_auth.load_tokens_raw() is None + + def test_load_returns_none_for_malformed_json(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + path = tmp_path / "openai-codex.json" + path.write_text("{not valid json") + monkeypatch.setattr(codex_auth, "REFLEXIO_CODEX_TOKENS_PATH", path) + assert codex_auth.load_tokens_raw() is None + + +class TestExpiryCheck: + def test_is_expired_lead_time(self) -> None: + now = int(time.time()) + # 30 seconds in the future, default lead time 60 -> already "expired". + t1 = codex_auth.CodexTokens( + access_token="x", + refresh_token="y", + account_id="", + expires_at=now + 30, + plan_type="", + email="", + ) + assert t1.is_expired() is True + + # 600 seconds in the future, well outside any lead time. + t2 = codex_auth.CodexTokens( + access_token="x", + refresh_token="y", + account_id="", + expires_at=now + 600, + plan_type="", + email="", + ) + assert t2.is_expired() is False + # Custom lead time can flip the result. + assert t2.is_expired(lead_seconds=700) is True + + +class TestAuthorizeUrl: + def test_url_contains_required_oauth_params(self) -> None: + verifier, _ = codex_auth._make_pkce_pair() + state = "csrf-state-abc" + url, challenge = codex_auth.build_authorize_url(verifier, state) + # Sanity-check the host + a handful of required params. + assert url.startswith(codex_auth.CODEX_AUTHORIZE_URL + "?") + for required in ( + f"client_id={codex_auth.CODEX_CLIENT_ID}", + "response_type=code", + "code_challenge_method=S256", + f"state={state}", + "scope=openid+profile+email+offline_access", + ): + assert required in url + assert challenge in url diff --git a/tests/cli/test_helpers.py b/tests/cli/test_helpers.py index b2774a94..0fdf862b 100644 --- a/tests/cli/test_helpers.py +++ b/tests/cli/test_helpers.py @@ -97,11 +97,17 @@ def test_tools_used_preserved(self) -> None: "tools_used": [ { "tool_name": "run_snowflake_query", - "tool_data": {"statement": "SELECT ...", "status": "failed"}, + "tool_data": { + "statement": "SELECT ...", + "status": "failed", + }, }, { "tool_name": "run_snowflake_query", - "tool_data": {"statement": "SELECT * LIMIT 1", "status": "ok"}, + "tool_data": { + "statement": "SELECT * LIMIT 1", + "status": "ok", + }, }, ], }, diff --git a/tests/cli/test_log_format.py b/tests/cli/test_log_format.py index d1ce939d..0c1a2798 100644 --- a/tests/cli/test_log_format.py +++ b/tests/cli/test_log_format.py @@ -59,10 +59,7 @@ class TestHighlightLogLevelNonTty: def test_no_color_when_not_tty(self) -> None: with patch("reflexio.cli.log_format.sys.stdout.isatty", return_value=False): - assert ( - highlight_log_level("ERROR: boom") - == "ERROR: boom" - ) + assert highlight_log_level("ERROR: boom") == "ERROR: boom" class TestFormatServiceLine: diff --git a/tests/cli/test_setup_cmd.py b/tests/cli/test_setup_cmd.py index 1f2c079f..4b970308 100644 --- a/tests/cli/test_setup_cmd.py +++ b/tests/cli/test_setup_cmd.py @@ -13,10 +13,8 @@ InstallLocation, _detect_install_locations, _install_claude_code_integration, - _install_openclaw_integration, _prompt_install_location, _prompt_storage, - _prompt_user_id, _remove_from_dir, _set_env_var, _write_marker, @@ -325,24 +323,6 @@ def test_normal_mode_no_command(self, tmp_path: Path) -> None: cmd = tmp_path / ".claude" / "commands" / "reflexio-extract" / "SKILL.md" assert not cmd.exists() - def test_expert_mode_installs_references(self, tmp_path: Path) -> None: - """Expert mode copies skill references directory.""" - _install_claude_code_integration( - tmp_path, expert=True, location=InstallLocation.CURRENT_PROJECT - ) - refs = tmp_path / ".claude" / "skills" / "reflexio" / "references" - assert refs.exists() - assert (refs / "proactive-patterns.md").exists() - assert (refs / "server-management.md").exists() - - def test_normal_mode_no_references(self, tmp_path: Path) -> None: - """Normal mode does not install skill references.""" - _install_claude_code_integration( - tmp_path, location=InstallLocation.CURRENT_PROJECT - ) - refs = tmp_path / ".claude" / "skills" / "reflexio" / "references" - assert not refs.exists() - def test_hooks_in_settings_json(self, tmp_path: Path) -> None: """Hooks are written to settings.json with correct events.""" _install_claude_code_integration( @@ -353,34 +333,34 @@ def test_hooks_in_settings_json(self, tmp_path: Path) -> None: assert "UserPromptSubmit" in settings["hooks"] def test_normal_mode_no_session_end_hook(self, tmp_path: Path) -> None: - """Normal mode does not install the SessionEnd hook.""" + """Normal mode does not install the Stop hook.""" _install_claude_code_integration( tmp_path, location=InstallLocation.CURRENT_PROJECT ) settings = json.loads((tmp_path / ".claude" / "settings.json").read_text()) - assert "SessionEnd" not in settings["hooks"] + assert "Stop" not in settings["hooks"] def test_expert_mode_installs_session_end_hook(self, tmp_path: Path) -> None: - """Expert mode installs SessionEnd hook alongside SessionStart and UserPromptSubmit.""" + """Expert mode installs Stop hook alongside SessionStart and UserPromptSubmit.""" _install_claude_code_integration( tmp_path, expert=True, location=InstallLocation.CURRENT_PROJECT ) settings = json.loads((tmp_path / ".claude" / "settings.json").read_text()) - assert "SessionEnd" in settings["hooks"] - assert len(settings["hooks"]["SessionEnd"]) == 1 - # Verify the SessionEnd hook command points to handler.js - cmd = settings["hooks"]["SessionEnd"][0]["hooks"][0]["command"] + assert "Stop" in settings["hooks"] + assert len(settings["hooks"]["Stop"]) == 1 + # Verify the Stop hook command points to handler.js + cmd = settings["hooks"]["Stop"][0]["hooks"][0]["command"] assert "handler.js" in cmd assert cmd.startswith("node ") def test_expert_mode_session_end_hook_idempotent(self, tmp_path: Path) -> None: - """Running expert install twice doesn't duplicate the SessionEnd hook.""" + """Running expert install twice doesn't duplicate the Stop hook.""" for _ in range(2): _install_claude_code_integration( tmp_path, expert=True, location=InstallLocation.ALL_PROJECTS ) settings = json.loads((tmp_path / ".claude" / "settings.json").read_text()) - assert len(settings["hooks"]["SessionEnd"]) == 1 + assert len(settings["hooks"]["Stop"]) == 1 assert len(settings["hooks"]["SessionStart"]) == 1 assert len(settings["hooks"]["UserPromptSubmit"]) == 1 @@ -392,18 +372,16 @@ def test_normal_reinstall_removes_expert_artifacts(self, tmp_path: Path) -> None ) claude_dir = tmp_path / ".claude" assert (claude_dir / "commands" / "reflexio-extract").exists() - assert (claude_dir / "skills" / "reflexio" / "references").exists() settings = json.loads((claude_dir / "settings.json").read_text()) - assert "SessionEnd" in settings["hooks"] + assert "Stop" in settings["hooks"] # Re-install in normal mode _install_claude_code_integration( tmp_path, expert=False, location=InstallLocation.CURRENT_PROJECT ) assert not (claude_dir / "commands" / "reflexio-extract").exists() - assert not (claude_dir / "skills" / "reflexio" / "references").exists() settings = json.loads((claude_dir / "settings.json").read_text()) - assert "SessionEnd" not in settings.get("hooks", {}) + assert "Stop" not in settings.get("hooks", {}) def test_idempotent_install(self, tmp_path: Path) -> None: """Running install twice doesn't corrupt files or duplicate hooks.""" @@ -486,18 +464,18 @@ def test_remove_from_dir_cleans_all_files(self, tmp_path: Path) -> None: assert "hooks" not in settings or not settings.get("hooks") def test_remove_from_dir_cleans_session_end_hook(self, tmp_path: Path) -> None: - """Uninstall removes the SessionEnd hook installed by expert mode.""" + """Uninstall removes the Stop hook installed by expert mode.""" _install_claude_code_integration( tmp_path, expert=True, location=InstallLocation.CURRENT_PROJECT ) settings_path = tmp_path / ".claude" / "settings.json" settings = json.loads(settings_path.read_text()) - assert "SessionEnd" in settings["hooks"] + assert "Stop" in settings["hooks"] _remove_from_dir(tmp_path) settings = json.loads(settings_path.read_text()) - assert "hooks" not in settings or "SessionEnd" not in settings.get("hooks", {}) + assert "hooks" not in settings or "Stop" not in settings.get("hooks", {}) def test_marker_file_metadata(self, tmp_path: Path) -> None: """Marker file contains location and installed_at fields.""" @@ -527,201 +505,3 @@ def test_global_and_project_dir_mutual_exclusion(self) -> None: project_dir=Path("/tmp"), global_install=True, ) - - -# --------------------------------------------------------------------------- -# _install_openclaw_integration — ClawHub-vs-pip skill ownership -# --------------------------------------------------------------------------- - - -def _make_openclaw_subprocess_stub() -> MagicMock: - """Build a subprocess.run stub that fakes success for every openclaw call. - - The three calls made by ``_install_openclaw_integration`` are: - ``plugins install``, ``hooks enable``, and ``hooks list`` (the last one - must return 'reflexio-context' in stdout to pass the verify step). - - Returns: - MagicMock: A mock usable as ``subprocess.run`` replacement. - """ - - def _run(cmd: list[str], **_: object) -> MagicMock: - result = MagicMock() - result.returncode = 0 - result.stderr = "" - result.stdout = "✓ ready │ reflexio-context" if "list" in cmd else "" - return result - - return MagicMock(side_effect=_run) - - -class TestInstallOpenclawIntegration: - """Regression tests for the ClawHub-vs-pip skill-ownership guard.""" - - def test_preserves_clawhub_installed_skill( - self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch - ) -> None: - """If _meta.json is present, the existing SKILL.md is not overwritten. - - Simulates a user who first installed via ``clawhub skill install - reflexio`` and then runs ``reflexio setup openclaw``. ClawHub's - copy should survive untouched. - """ - monkeypatch.setattr(Path, "home", staticmethod(lambda: tmp_path)) - skills_dir = tmp_path / ".openclaw" / "skills" / "reflexio" - skills_dir.mkdir(parents=True) - sentinel = "CLAWHUB_INSTALLED_SENTINEL_DO_NOT_OVERWRITE" - (skills_dir / "SKILL.md").write_text(sentinel) - (skills_dir / "_meta.json").write_text( - '{"ownerId":"x","slug":"reflexio","version":"1.0.0"}' - ) - - with ( - patch( - "reflexio.cli.commands.setup_cmd.shutil.which", - return_value="/usr/bin/openclaw", - ), - patch( - "reflexio.cli.commands.setup_cmd.subprocess.run", - _make_openclaw_subprocess_stub(), - ), - ): - _install_openclaw_integration() - - assert (skills_dir / "SKILL.md").read_text() == sentinel - assert (skills_dir / "_meta.json").exists() - - def test_refreshes_pip_installed_skill( - self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch - ) -> None: - """If _meta.json is absent, an existing SKILL.md is always replaced. - - Regression test for the upgrade path: ``pip install --upgrade - reflexio-ai && reflexio setup openclaw`` must refresh stale skill - content from a prior pip install. - """ - monkeypatch.setattr(Path, "home", staticmethod(lambda: tmp_path)) - skills_dir = tmp_path / ".openclaw" / "skills" / "reflexio" - skills_dir.mkdir(parents=True) - (skills_dir / "SKILL.md").write_text("STALE_PIP_INSTALLED_CONTENT") - - with ( - patch( - "reflexio.cli.commands.setup_cmd.shutil.which", - return_value="/usr/bin/openclaw", - ), - patch( - "reflexio.cli.commands.setup_cmd.subprocess.run", - _make_openclaw_subprocess_stub(), - ), - ): - _install_openclaw_integration() - - import reflexio - - source_skill = ( - Path(reflexio.__file__).parent - / "integrations" - / "openclaw" - / "skill" - / "SKILL.md" - ) - assert (skills_dir / "SKILL.md").read_text() == source_skill.read_text() - assert ( - "STALE_PIP_INSTALLED_CONTENT" not in (skills_dir / "SKILL.md").read_text() - ) - - -# --------------------------------------------------------------------------- -# _prompt_user_id — optional custom user_id during Claude Code setup -# --------------------------------------------------------------------------- - - -class TestPromptUserId: - """Tests for _prompt_user_id: default, custom value, whitespace, env-driven default.""" - - def test_default_is_persisted_when_user_accepts( - self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch - ) -> None: - """Pressing Enter keeps the fallback 'claude-code'.""" - env = tmp_path / ".env" - env.write_text("") - monkeypatch.delenv("REFLEXIO_USER_ID", raising=False) - monkeypatch.setattr(typer, "prompt", lambda *_, **kwargs: kwargs["default"]) - - result = _prompt_user_id(env) - - assert result == "claude-code" - assert 'REFLEXIO_USER_ID="claude-code"' in env.read_text() - - def test_custom_value_is_persisted( - self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch - ) -> None: - """A user-entered value is persisted verbatim.""" - env = tmp_path / ".env" - env.write_text("") - monkeypatch.delenv("REFLEXIO_USER_ID", raising=False) - monkeypatch.setattr(typer, "prompt", _fixed_prompt("alice")) - - result = _prompt_user_id(env) - - assert result == "alice" - assert 'REFLEXIO_USER_ID="alice"' in env.read_text() - - def test_whitespace_is_stripped( - self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch - ) -> None: - """Surrounding whitespace is trimmed before persistence.""" - env = tmp_path / ".env" - env.write_text("") - monkeypatch.delenv("REFLEXIO_USER_ID", raising=False) - monkeypatch.setattr(typer, "prompt", _fixed_prompt(" bob ")) - - result = _prompt_user_id(env) - - assert result == "bob" - assert 'REFLEXIO_USER_ID="bob"' in env.read_text() - - def test_existing_env_value_offered_as_default( - self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch - ) -> None: - """Re-running setup offers the currently configured user_id as the default.""" - env = tmp_path / ".env" - env.write_text('REFLEXIO_USER_ID="alice"\n') - monkeypatch.setenv("REFLEXIO_USER_ID", "alice") - - captured: dict[str, object] = {} - - def _fake_prompt(*_: object, **kwargs: object) -> object: - captured.update(kwargs) - return kwargs["default"] - - monkeypatch.setattr(typer, "prompt", _fake_prompt) - - result = _prompt_user_id(env) - - assert captured["default"] == "alice" - assert result == "alice" - - def test_empty_input_falls_back_to_default( - self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch - ) -> None: - """If the user somehow submits an empty/whitespace-only string, fall back.""" - env = tmp_path / ".env" - env.write_text("") - monkeypatch.delenv("REFLEXIO_USER_ID", raising=False) - monkeypatch.setattr(typer, "prompt", _fixed_prompt(" ")) - - result = _prompt_user_id(env) - - assert result == "claude-code" - assert 'REFLEXIO_USER_ID="claude-code"' in env.read_text() - - -def _fixed_prompt(return_value: str): - """Build a typer.prompt stub that returns a fixed value, ignoring args/kwargs.""" - - def _stub(*_args: object, **_kwargs: object) -> str: - return return_value - - return _stub diff --git a/tests/client/test_cache.py b/tests/client/test_cache.py index e3af55e8..31b87f8b 100644 --- a/tests/client/test_cache.py +++ b/tests/client/test_cache.py @@ -136,7 +136,6 @@ def set_and_get(thread_id): for thread_id, result in results: assert result == f"value_{thread_id}" # noqa: S101 - def test_clear_removes_all_entries(self): """Test that clear() removes all cached entries.""" cache = InMemoryCache() @@ -395,7 +394,12 @@ def test_delete_all_profiles_invalidates_cache(self, mock_session_class): client = ReflexioClient(api_key="test_key") # Populate cache - request = {"user_id": "user1", "start_time": None, "end_time": None, "top_k": 30} + request = { + "user_id": "user1", + "start_time": None, + "end_time": None, + "top_k": 30, + } client.get_profiles(request) assert mock_session.request.call_count == 1 # noqa: S101 @@ -434,7 +438,9 @@ def test_delete_all_interactions_clears_all_cache(self, mock_session_class): client = ReflexioClient(api_key="test_key") # Populate both caches - client.get_profiles({"user_id": "u1", "start_time": None, "end_time": None, "top_k": 30}) + client.get_profiles( + {"user_id": "u1", "start_time": None, "end_time": None, "top_k": 30} + ) client.get_agent_playbooks({"limit": 100}) assert mock_session.request.call_count == 2 # noqa: S101 @@ -443,7 +449,9 @@ def test_delete_all_interactions_clears_all_cache(self, mock_session_class): assert mock_session.request.call_count == 3 # noqa: S101 # Both caches should miss - client.get_profiles({"user_id": "u1", "start_time": None, "end_time": None, "top_k": 30}) + client.get_profiles( + {"user_id": "u1", "start_time": None, "end_time": None, "top_k": 30} + ) client.get_agent_playbooks({"limit": 100}) assert mock_session.request.call_count == 5 # noqa: S101 diff --git a/tests/conftest.py b/tests/conftest.py index 825cb1f4..7b15ea52 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,6 +3,8 @@ import sys from pathlib import Path +import pytest + _THIS_DIR = Path(__file__).resolve().parent # tests/ PROJECT_ROOT = _THIS_DIR.parent.parent # repo root @@ -18,3 +20,28 @@ def pytest_configure(config): def pytest_unconfigure(config): cleanup_llm_mock(config) + + +@pytest.fixture +def tool_call_completion(): + """Factory helpers for mocking a tool-calling conversation. + + Yields: + tuple: ``(make_tool_call_response, make_finish_response)`` — + call the first to build an assistant turn that requests a + tool, and the second to build the terminal stop turn. + + Usage:: + + def test_my_loop(tool_call_completion): + make_tc, make_stop = tool_call_completion + responses = [make_tc("emit", {"v": 1}), make_stop()] + with patch("litellm.completion", side_effect=responses): + ... + """ + from reflexio.test_support.llm_mock import ( + make_finish_response, + make_tool_call_response, + ) + + return make_tool_call_response, make_finish_response diff --git a/tests/e2e_tests/test_complete_workflows.py b/tests/e2e_tests/test_complete_workflows.py index 9ff21327..e809c325 100644 --- a/tests/e2e_tests/test_complete_workflows.py +++ b/tests/e2e_tests/test_complete_workflows.py @@ -113,7 +113,7 @@ def test_complete_workflow_end_to_end( # Step 5: Search profiles (use actual profile content for reliable search) profile_content = get_profiles_response.user_profiles[0].content search_words = " ".join(profile_content.split()[:4]) - search_profile_response = reflexio_instance.search_profiles( + search_profile_response = reflexio_instance.search_user_profiles( SearchUserProfileRequest(user_id=user_id, query=search_words, top_k=5) ) assert search_profile_response.success is True @@ -176,7 +176,7 @@ def test_error_handling_end_to_end( assert len(search_response.interactions) == 0 # Test with invalid profile search - profile_response = reflexio_instance.search_profiles( + profile_response = reflexio_instance.search_user_profiles( SearchUserProfileRequest(user_id="nonexistent_user", query="test", top_k=5) ) assert profile_response.success is True @@ -290,7 +290,7 @@ def test_profile_status_filtering( sample_interaction_requests: list[InteractionData], cleanup_after_test: Callable[[], None], ): - """Test profile status filtering in search_profiles and get_profiles.""" + """Test profile status filtering in search_user_profiles and get_profiles.""" user_id = "test_user_status" # Publish interactions to generate profiles @@ -320,10 +320,10 @@ def test_profile_status_filtering( assert current_explicit.success is True assert len(current_explicit.user_profiles) == current_count - # Test search_profiles with default filter (use actual profile content for reliable search) + # Test search_user_profiles with default filter (use actual profile content for reliable search) profile_content = current_profiles.user_profiles[0].content search_words = " ".join(profile_content.split()[:4]) - search_current = reflexio_instance.search_profiles( + search_current = reflexio_instance.search_user_profiles( SearchUserProfileRequest(user_id=user_id, query=search_words, top_k=10) ) assert search_current.success is True @@ -746,7 +746,7 @@ def test_full_workflow_with_all_features( # Search profiles (use actual profile content for reliable search) profile_content = stored_profiles[0].content search_words = " ".join(profile_content.split()[:4]) - search_profile_response = reflexio_instance.search_profiles( + search_profile_response = reflexio_instance.search_user_profiles( SearchUserProfileRequest(user_id=user_id, query=search_words, top_k=5) ) assert search_profile_response.success is True diff --git a/tests/e2e_tests/test_interaction_workflows.py b/tests/e2e_tests/test_interaction_workflows.py index 735d00e7..8935531b 100644 --- a/tests/e2e_tests/test_interaction_workflows.py +++ b/tests/e2e_tests/test_interaction_workflows.py @@ -300,7 +300,7 @@ def test_dict_input_handling_end_to_end( "query": search_words, # Use actual profile content for search "top_k": 5, } - profile_response = reflexio_instance.search_profiles(profile_search_dict) + profile_response = reflexio_instance.search_user_profiles(profile_search_dict) assert profile_response.success is True assert len(profile_response.user_profiles) > 0 # Verify all returned profiles have CURRENT status (default search filter) diff --git a/tests/e2e_tests/test_profile_workflows.py b/tests/e2e_tests/test_profile_workflows.py index 2eacab2e..0f596d81 100644 --- a/tests/e2e_tests/test_profile_workflows.py +++ b/tests/e2e_tests/test_profile_workflows.py @@ -123,7 +123,7 @@ def test_search_profiles_end_to_end( top_k=5, ) - response = reflexio_instance_profile_only.search_profiles(search_request) + response = reflexio_instance_profile_only.search_user_profiles(search_request) # Verify search results assert response.success is True @@ -575,12 +575,12 @@ def test_status_filter_in_get_all_profiles( @skip_in_precommit @skip_low_priority -def test_status_filter_in_search_profiles( +def test_status_filter_in_search_user_profiles( reflexio_instance_profile_only: Reflexio, sample_interaction_requests: list[InteractionData], cleanup_profile_only: Callable[[], None], ): - """Test status filtering in search_profiles method.""" + """Test status filtering in search_user_profiles method.""" user_id = "test_user_search_status" # Publish interactions to create current profiles @@ -606,19 +606,19 @@ def test_status_filter_in_search_profiles( top_k=10, ) - default_search = reflexio_instance_profile_only.search_profiles(search_request) + default_search = reflexio_instance_profile_only.search_user_profiles(search_request) assert default_search.success is True assert all(p.status is None for p in default_search.user_profiles) # Test search with pending filter - pending_search = reflexio_instance_profile_only.search_profiles( + pending_search = reflexio_instance_profile_only.search_user_profiles( search_request, status_filter=[Status.PENDING] ) assert pending_search.success is True assert all(p.status == Status.PENDING for p in pending_search.user_profiles) # Test search with both statuses - all_search = reflexio_instance_profile_only.search_profiles( + all_search = reflexio_instance_profile_only.search_user_profiles( search_request, status_filter=[None, Status.PENDING] ) assert all_search.success is True diff --git a/tests/eval/__init__.py b/tests/eval/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/eval/aggregate.py b/tests/eval/aggregate.py new file mode 100644 index 00000000..783383fb --- /dev/null +++ b/tests/eval/aggregate.py @@ -0,0 +1,39 @@ +"""Polars-based aggregator for golden-set eval results. + +Reads a parquet file containing per-case judge scores and per-backend cost +metrics and reduces it to a per-backend summary. Used by the weekly eval +report and by the comparison harness. +""" + +from __future__ import annotations + +import polars as pl + + +def aggregate_eval_results(results_path: str) -> pl.DataFrame: + """Group per-case rows by ``backend`` and report means + p95 latency. + + Args: + results_path (str): Path to a parquet file with columns + ``backend``, ``signal_f1``, ``answer_correctness``, + ``grounded_rate``, ``cost_usd``, ``latency_ms``. + + Returns: + pl.DataFrame: One row per backend with aggregated columns + ``mean_f1``, ``mean_correctness``, ``grounded_rate``, + ``mean_cost``, ``p95_latency``. + """ + return ( + pl.scan_parquet(results_path) + .group_by("backend") + .agg( + [ + pl.col("signal_f1").mean().alias("mean_f1"), + pl.col("answer_correctness").mean().alias("mean_correctness"), + pl.col("grounded_rate").mean().alias("grounded_rate"), + pl.col("cost_usd").mean().alias("mean_cost"), + pl.col("latency_ms").quantile(0.95).alias("p95_latency"), + ] + ) + .collect() + ) diff --git a/tests/eval/conftest.py b/tests/eval/conftest.py new file mode 100644 index 00000000..80a7fe28 --- /dev/null +++ b/tests/eval/conftest.py @@ -0,0 +1,98 @@ +"""Fixtures for the golden-set comparison harness. + +Parametrizes tests over every YAML file in ``golden_set/extraction`` or +``golden_set/search``. The ``judge`` fixture returns a stubbed ``LLMJudge`` +by default; set ``REFLEXIO_EVAL_REAL_JUDGE=1`` with a real Anthropic key to +hit the live judge model. +""" + +from __future__ import annotations + +import os +from pathlib import Path +from typing import Any +from unittest.mock import MagicMock + +import pytest +import yaml + +from tests.eval.judge import JudgeScore, LLMJudge + +_GOLDEN = Path(__file__).parent / "golden_set" +_RUBRICS = Path(__file__).parent / "judge_prompts" + + +def _load(kind: str) -> list[dict[str, Any]]: + """Load every YAML golden file under ``golden_set//`` sorted by id. + + The previous implementation sorted by filename, which silently produces + unstable parametrization ids if a file is renamed without updating its + YAML ``id`` (or vice-versa). Sort by the YAML ``id`` so the test ordering + matches what pytest reports. + + Raises: + ValueError: If a golden YAML file is missing an ``id`` key. + """ + cases: list[dict[str, Any]] = [] + for path in (_GOLDEN / kind).glob("*.yaml"): + case = yaml.safe_load(path.read_text()) + if "id" not in case: + raise ValueError(f"Golden case {path} is missing required 'id' key") + cases.append(case) + return sorted(cases, key=lambda c: c["id"]) + + +def pytest_generate_tests(metafunc): + """Parametrize over every golden case for tests that ask for one.""" + if "extraction_case" in metafunc.fixturenames: + cases = _load("extraction") + metafunc.parametrize("extraction_case", cases, ids=[c["id"] for c in cases]) + if "search_case" in metafunc.fixturenames: + cases = _load("search") + metafunc.parametrize("search_case", cases, ids=[c["id"] for c in cases]) + + +def _stubbed_judge(rubric: dict[str, Any]) -> LLMJudge: + client = MagicMock() + client.generate_chat_response.return_value = JudgeScore( + signal_f1=0.5, + answer_correctness=0.5, + grounded_rate=1.0, + rationale="stub", + ) + return LLMJudge(client=client, rubric=rubric) + + +def _real_judge(rubric: dict[str, Any]) -> LLMJudge: + from reflexio.server.llm.litellm_client import LiteLLMClient, LiteLLMConfig + + client = LiteLLMClient( + LiteLLMConfig(model=rubric.get("judge_model", "claude-sonnet-4-6")) + ) + return LLMJudge(client=client, rubric=rubric) + + +def _load_rubric(name: str) -> dict[str, Any]: + return yaml.safe_load((_RUBRICS / name).read_text()) + + +@pytest.fixture +def extraction_judge() -> LLMJudge: + """Judge loaded with the extraction rubric. + + Set ``REFLEXIO_EVAL_REAL_JUDGE=1`` to hit a real LLM; the default path + stubs the client so the harness smoke-runs without credentials. + """ + rubric = _load_rubric("extraction_rubric.yaml") + if os.environ.get("REFLEXIO_EVAL_REAL_JUDGE") == "1": + return _real_judge(rubric) + return _stubbed_judge(rubric) + + +@pytest.fixture +def search_judge() -> LLMJudge: + """Judge loaded with the search rubric (stubbed by default).""" + rubric = _load_rubric("search_rubric.yaml") + if os.environ.get("REFLEXIO_EVAL_REAL_JUDGE") == "1": + return _real_judge(rubric) + return _stubbed_judge(rubric) diff --git a/tests/eval/golden_set/extraction/mixed_ttl.yaml b/tests/eval/golden_set/extraction/mixed_ttl.yaml new file mode 100644 index 00000000..91b2a766 --- /dev/null +++ b/tests/eval/golden_set/extraction/mixed_ttl.yaml @@ -0,0 +1,20 @@ +id: mixed_ttl +description: Single user message mixes a persistent preference with a short-term context item. +sessions: + - role: user + content: "I'm a senior backend engineer. This week I'm on-call so please avoid scheduling reviews before 10am." +expected_profiles: + - content: "User is a senior backend engineer." + time_to_live: "infinity" + reader_angle: "facts" + - content: "User is on-call this week." + time_to_live: "one_week" + reader_angle: "context" +expected_playbooks: + - trigger: "scheduling a review during user's on-call week" + content: "avoid times before 10am" + reader_angle: "behavior" +notes_for_judge: | + Tests whether extraction distinguishes persistent identity (role) from + short-term context (on-call this week) — single-shot extraction often + collapses them into one TTL. diff --git a/tests/eval/golden_set/extraction/polars_vs_pandas.yaml b/tests/eval/golden_set/extraction/polars_vs_pandas.yaml new file mode 100644 index 00000000..39326189 --- /dev/null +++ b/tests/eval/golden_set/extraction/polars_vs_pandas.yaml @@ -0,0 +1,30 @@ +id: polars_vs_pandas +description: | + User explicitly states a tool-preference fact: polars is preferred over pandas, + because of lazy evaluation and strict dtypes. Includes supersession signal + (they used pandas before). +sessions: + - role: user + content: "I used to use pandas everywhere, but as of last quarter our team standardized on polars — mostly for the lazy evaluation and strict dtypes. pandas still shows up in old notebooks but I don't want agents to suggest pandas for new code." + - role: assistant + content: "Got it — polars for new work, pandas only for legacy." + - role: user + content: "Right." +expected_profiles: + - content: "User prefers polars over pandas for new work." + time_to_live: "persistent" + reader_angle: "facts" + must_include_in_source_span: "polars" + - content: "User's team standardized on polars last quarter." + time_to_live: "medium_term" + reader_angle: "temporal" + must_include_in_source_span: "last quarter" +expected_playbooks: + - trigger: "user asks for DataFrame code for new work" + content: "use polars, not pandas" + rationale_must_mention: ["lazy", "dtype"] + reader_angle: "rationale" +notes_for_judge: | + A good extraction surfaces BOTH the persistent preference AND the temporal + signal of "as of last quarter". Flattening to a single "user uses polars" + profile counts as a miss on the nuance-gap criterion. diff --git a/tests/eval/golden_set/extraction/superseded_state.yaml b/tests/eval/golden_set/extraction/superseded_state.yaml new file mode 100644 index 00000000..a1704638 --- /dev/null +++ b/tests/eval/golden_set/extraction/superseded_state.yaml @@ -0,0 +1,19 @@ +id: superseded_state +description: User explicitly supersedes an earlier statement within the same session. +sessions: + - role: user + content: "Our staging DB is on 5432." + - role: user + content: "Correction, we moved staging to 5433 yesterday — 5432 is prod now." +expected_profiles: + - content: "Staging DB runs on port 5433." + time_to_live: "medium_term" + reader_angle: "temporal" + - content: "Prod DB runs on port 5432." + time_to_live: "medium_term" + reader_angle: "facts" +must_NOT_include_profiles: + - content_contains: "staging on 5432" +expected_playbooks: [] +notes_for_judge: | + Any output that keeps the superseded "staging on 5432" as a live profile is a hard fail. diff --git a/tests/eval/golden_set/search/db_preference.yaml b/tests/eval/golden_set/search/db_preference.yaml new file mode 100644 index 00000000..7b61c09c --- /dev/null +++ b/tests/eval/golden_set/search/db_preference.yaml @@ -0,0 +1,28 @@ +id: db_preference +description: | + Classic "what DB does the user prefer?" — the stored profile says "polars + for dataframes" AND "postgres for OLTP". The search should surface postgres, + not polars. +query: "what DB does the user prefer?" +conversation_history: [] +seeded_profiles: + - id: p_polars + user_id: u1 + content: "User prefers polars over pandas for DataFrames." + time_to_live: "persistent" + - id: p_pg + user_id: u1 + content: "User prefers postgres for OLTP workloads." + time_to_live: "persistent" + - id: p_redis + user_id: u1 + content: "User uses redis for caching." + time_to_live: "persistent" +seeded_playbooks: [] +expected_top_candidates: ["p_pg"] +expected_answer: "postgres" +must_NOT_rank_first: ["p_polars"] +notes_for_judge: | + Fixed-fanout classic search often confuses "polars" (a dataframe lib, + frequently called a DB in shorthand) with the DB preference. A good agentic + pipeline should reformulate / disambiguate and rank postgres first. diff --git a/tests/eval/golden_set/search/deadline_context.yaml b/tests/eval/golden_set/search/deadline_context.yaml new file mode 100644 index 00000000..8da73e9a --- /dev/null +++ b/tests/eval/golden_set/search/deadline_context.yaml @@ -0,0 +1,19 @@ +id: deadline_context +description: Query asks what the user is working on; depends on short-term context profile. +query: "what is the user working on right now?" +conversation_history: [] +seeded_profiles: + - id: p_role + user_id: u1 + content: "User is a senior backend engineer." + time_to_live: "persistent" + - id: p_project + user_id: u1 + content: "User is migrating the billing service to Go, due Friday." + time_to_live: "short_term" +seeded_playbooks: [] +expected_top_candidates: ["p_project"] +expected_answer: "billing service migration to Go" +notes_for_judge: | + The persistent role profile is a red herring for this query — any pipeline + that ranks p_role first fails. diff --git a/tests/eval/golden_set/search/superseded_rule.yaml b/tests/eval/golden_set/search/superseded_rule.yaml new file mode 100644 index 00000000..3d1c9c06 --- /dev/null +++ b/tests/eval/golden_set/search/superseded_rule.yaml @@ -0,0 +1,25 @@ +id: superseded_rule +description: Query asks about a rule the user updated — must surface the current rule, not the obsolete one. +query: "do we skip tests on ship?" +conversation_history: [] +seeded_profiles: [] +seeded_playbooks: + - id: b_old + user_id: u1 + trigger: "user says ship" + content: "skip tests" + rationale: "" + time_to_live: "expired" + - id: b_new + user_id: u1 + trigger: "user says ship" + content: "run tests then deploy" + rationale: "after the april regression" + time_to_live: "persistent" +expected_top_candidates: ["b_new"] +expected_answer: "run tests then deploy" +must_NOT_rank_first: ["b_old"] +notes_for_judge: | + Classic search with respect_ttl=true may drop b_old entirely (good), but + the agentic temporal intent can keep it flagged as "superseded" and + explain the supersession chain. diff --git a/tests/eval/judge.py b/tests/eval/judge.py new file mode 100644 index 00000000..9b5b9862 --- /dev/null +++ b/tests/eval/judge.py @@ -0,0 +1,76 @@ +"""LLM-as-judge scorer for golden-set evaluation. + +Takes a rubric (prompt template + judge model) and an (expected, actual) +pair, renders the prompt, and parses the judge response into a +``JudgeScore``. Used by the comparison harness in Task 5.7. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from pydantic import BaseModel, Field + +if TYPE_CHECKING: + from reflexio.server.llm.litellm_client import LiteLLMClient + + +class JudgeScore(BaseModel): + """Judge's per-case numerical verdict. + + Args: + signal_f1 (float): Extraction signal recall vs expected signals, in [0, 1]. + Always 0 for search-rubric scores. + answer_correctness (float): Search top-rank correctness, in [0, 1]. + Always 0 for extraction-rubric scores. + grounded_rate (float): Fraction of emitted items that are grounded in + the source (no hallucinated IDs or source_spans), in [0, 1]. + rationale (str): One-paragraph explanation of the scores. + """ + + signal_f1: float = Field(ge=0.0, le=1.0) + answer_correctness: float = Field(ge=0.0, le=1.0) + grounded_rate: float = Field(ge=0.0, le=1.0) + rationale: str + + +class LLMJudge: + """Wraps a ``LiteLLMClient`` + rubric and produces ``JudgeScore`` results. + + The rubric dict has two required keys: ``prompt`` (a template with + ``{expected}`` / ``{actual}`` substitution placeholders) and + ``judge_model`` (model name override). + + Args: + client: Any client exposing ``generate_chat_response(messages, + response_format, ...)`` — in practice a ``LiteLLMClient`` or a + ``MagicMock`` in unit tests. + rubric (dict): Parsed rubric YAML. + """ + + def __init__(self, *, client: LiteLLMClient | Any, rubric: dict[str, Any]) -> None: + self.client = client + self.rubric = rubric + + def score(self, *, expected: Any, actual: Any) -> JudgeScore: + """Render the rubric prompt and return the parsed judge score. + + Raises: + TypeError: When the client returns a plain string instead of a + structured ``JudgeScore`` (misconfigured response_format). + """ + prompt = ( + self.rubric["prompt"] + .replace("{expected}", str(expected)) + .replace("{actual}", str(actual)) + ) + result = self.client.generate_chat_response( + messages=[{"role": "user", "content": prompt}], + response_format=JudgeScore, + model=self.rubric.get("judge_model"), + ) + if isinstance(result, JudgeScore): + return result + if isinstance(result, BaseModel): + return JudgeScore.model_validate(result.model_dump()) + raise TypeError(f"LLMJudge expected JudgeScore, got {type(result).__name__}") diff --git a/tests/eval/judge_prompts/extraction_rubric.yaml b/tests/eval/judge_prompts/extraction_rubric.yaml new file mode 100644 index 00000000..71e14aca --- /dev/null +++ b/tests/eval/judge_prompts/extraction_rubric.yaml @@ -0,0 +1,18 @@ +judge_model: "claude-sonnet-4-6" +output_schema: JudgeScore +prompt: | + You are a strict extraction judge. Score the actual extraction against the + expected extraction on three dimensions, each in [0.0, 1.0]: + + - signal_f1: does the output contain the expected signals (0=none, 1=all)? + Treat nuance-bearing signals (supersession, mixed-ttl, rationale) as + required signals when the case is flagged as a nuance case — i.e. fold + nuance preservation into signal_f1 rather than scoring it separately. + - grounded_rate: are emitted items' source_spans genuinely in the session + transcript? (0=none verbatim, 1=all verbatim) + + Respond ONLY with JSON matching: + {"signal_f1": float, "answer_correctness": 0, "grounded_rate": float, "rationale": str} + + (answer_correctness is always 0 for extraction — this rubric is + extraction-only.) diff --git a/tests/eval/judge_prompts/search_rubric.yaml b/tests/eval/judge_prompts/search_rubric.yaml new file mode 100644 index 00000000..af0f9006 --- /dev/null +++ b/tests/eval/judge_prompts/search_rubric.yaml @@ -0,0 +1,17 @@ +judge_model: "claude-sonnet-4-6" +output_schema: JudgeScore +prompt: | + You are a strict search judge. Score the ranked candidate list against the + expected answer: + + - answer_correctness: does the top-1 (or top-3 if the case allows) + candidate contain the expected_answer? When any + must_NOT_rank_first item ranks first, set answer_correctness=0 + (the must-not-rank constraint is folded into this score rather than + scored separately, since the JudgeScore response schema has no + dedicated must_not_violated field). + - grounded_rate: do ranked items actually exist in seeded_profiles or + seeded_playbooks (no hallucinated IDs)? + + Respond ONLY with JSON: + {"signal_f1": 0, "answer_correctness": float, "grounded_rate": float, "rationale": str} diff --git a/tests/eval/test_agentic_vs_classic_extraction_integration.py b/tests/eval/test_agentic_vs_classic_extraction_integration.py new file mode 100644 index 00000000..c9e65c86 --- /dev/null +++ b/tests/eval/test_agentic_vs_classic_extraction_integration.py @@ -0,0 +1,27 @@ +"""Agentic-vs-classic extraction comparison harness. + +Scaffolding only: ``classic_out`` and ``agentic_out`` are stubbed empty +because actual backend quality numbers require ``REFLEXIO_EVAL_REAL_JUDGE=1`` +with a real LLM. The harness exists so the golden-set loader, judge wiring, +and test parametrization are proven green in CI. +""" + +from __future__ import annotations + +import pytest + +pytestmark = pytest.mark.integration + + +def test_agentic_vs_classic_extraction(extraction_case, extraction_judge): + """For each golden case, the stubbed judge returns a parseable score.""" + classic_out = {"profiles": [], "playbooks": []} + agentic_out = {"profiles": [], "playbooks": []} + + c_score = extraction_judge.score(expected=extraction_case, actual=classic_out) + a_score = extraction_judge.score(expected=extraction_case, actual=agentic_out) + + assert c_score.signal_f1 >= 0.0 + assert a_score.signal_f1 >= 0.0 + assert c_score.rationale + assert a_score.rationale diff --git a/tests/eval/test_agentic_vs_classic_search_integration.py b/tests/eval/test_agentic_vs_classic_search_integration.py new file mode 100644 index 00000000..9e8e8e5f --- /dev/null +++ b/tests/eval/test_agentic_vs_classic_search_integration.py @@ -0,0 +1,25 @@ +"""Agentic-vs-classic search comparison harness (scaffolding only). + +Mirrors the extraction comparison harness; actual quality numbers require +``REFLEXIO_EVAL_REAL_JUDGE=1`` + real LLM keys. +""" + +from __future__ import annotations + +import pytest + +pytestmark = pytest.mark.integration + + +def test_agentic_vs_classic_search(search_case, search_judge): + """For each golden case, the stubbed judge returns a parseable score.""" + classic_out = {"ranked_ids": []} + agentic_out = {"ranked_ids": []} + + c_score = search_judge.score(expected=search_case, actual=classic_out) + a_score = search_judge.score(expected=search_case, actual=agentic_out) + + assert c_score.answer_correctness >= 0.0 + assert a_score.answer_correctness >= 0.0 + assert c_score.rationale + assert a_score.rationale diff --git a/tests/eval/test_aggregate.py b/tests/eval/test_aggregate.py new file mode 100644 index 00000000..e51a2272 --- /dev/null +++ b/tests/eval/test_aggregate.py @@ -0,0 +1,58 @@ +"""Unit tests for the eval polars aggregator.""" + +from __future__ import annotations + +import polars as pl +import pytest + +from tests.eval.aggregate import aggregate_eval_results + + +def _write_fixture(tmp_path) -> str: + df = pl.DataFrame( + { + "backend": ["classic", "classic", "agentic", "agentic"], + "signal_f1": [0.5, 0.6, 0.8, 0.7], + "answer_correctness": [0.5, 0.5, 0.7, 0.8], + "grounded_rate": [0.9, 0.95, 0.98, 1.0], + "cost_usd": [0.001, 0.001, 0.01, 0.01], + "latency_ms": [1000, 1100, 2500, 2700], + } + ) + path = tmp_path / "r.parquet" + df.write_parquet(path) + return str(path) + + +def test_aggregate_returns_per_backend_stats(tmp_path): + """Output has one row per backend and the expected aggregated columns.""" + out = aggregate_eval_results(_write_fixture(tmp_path)) + + assert set(out["backend"].to_list()) == {"classic", "agentic"} + assert "mean_f1" in out.columns + assert "mean_correctness" in out.columns + assert "grounded_rate" in out.columns + assert "mean_cost" in out.columns + assert "p95_latency" in out.columns + + +def test_aggregate_means_are_correct(tmp_path): + """Agentic mean_f1 = (0.8 + 0.7) / 2 = 0.75.""" + out = aggregate_eval_results(_write_fixture(tmp_path)) + + agentic = out.filter(pl.col("backend") == "agentic").row(0, named=True) + assert agentic["mean_f1"] == pytest.approx(0.75) + assert agentic["mean_correctness"] == pytest.approx(0.75) + assert agentic["mean_cost"] == pytest.approx(0.01) + + +def test_aggregate_p95_latency_is_tail(tmp_path): + """p95 latency should be near the tail of each backend's latency distribution.""" + out = aggregate_eval_results(_write_fixture(tmp_path)) + + classic = out.filter(pl.col("backend") == "classic").row(0, named=True) + agentic = out.filter(pl.col("backend") == "agentic").row(0, named=True) + assert classic["p95_latency"] >= 1000 + assert classic["p95_latency"] <= 1100 + assert agentic["p95_latency"] >= 2500 + assert agentic["p95_latency"] <= 2700 diff --git a/tests/eval/test_judge_unit.py b/tests/eval/test_judge_unit.py new file mode 100644 index 00000000..03339f84 --- /dev/null +++ b/tests/eval/test_judge_unit.py @@ -0,0 +1,67 @@ +"""Unit tests for LLMJudge + JudgeScore.""" + +from unittest.mock import MagicMock + +import pytest + +from tests.eval.judge import JudgeScore, LLMJudge + + +def test_judge_score_parses_llm_output(): + """When the client returns a JudgeScore directly, the judge passes it through.""" + client = MagicMock() + client.generate_chat_response.return_value = JudgeScore( + signal_f1=0.8, + answer_correctness=0.0, + grounded_rate=1.0, + rationale="fine", + ) + j = LLMJudge( + client=client, + rubric={ + "judge_model": "claude-sonnet-4-6", + "prompt": "score: {expected} vs {actual}", + }, + ) + s = j.score(expected={"x": 1}, actual={"x": 1}) + assert s.signal_f1 == 0.8 + assert s.grounded_rate == 1.0 + client.generate_chat_response.assert_called_once() + + +def test_judge_prompt_is_rendered_with_expected_and_actual(): + """The rubric placeholders are substituted before the LLM is called.""" + client = MagicMock() + client.generate_chat_response.return_value = JudgeScore( + signal_f1=0.5, answer_correctness=0.0, grounded_rate=1.0, rationale="ok" + ) + j = LLMJudge( + client=client, + rubric={"judge_model": "m", "prompt": "E={expected} A={actual}"}, + ) + j.score(expected="EXP", actual="ACT") + + call_msgs = client.generate_chat_response.call_args.kwargs["messages"] + assert call_msgs[0]["content"] == "E=EXP A=ACT" + + +def test_judge_passes_judge_model_as_override(): + client = MagicMock() + client.generate_chat_response.return_value = JudgeScore( + signal_f1=0.0, answer_correctness=0.0, grounded_rate=0.0, rationale="" + ) + j = LLMJudge( + client=client, rubric={"judge_model": "claude-haiku-4-5", "prompt": "p"} + ) + j.score(expected={}, actual={}) + + assert client.generate_chat_response.call_args.kwargs["model"] == "claude-haiku-4-5" + + +def test_judge_raises_typeerror_on_plain_string_response(): + """Misconfigured response_format could yield a str — we fail loudly.""" + client = MagicMock() + client.generate_chat_response.return_value = "not a JudgeScore" + j = LLMJudge(client=client, rubric={"judge_model": "m", "prompt": "p"}) + with pytest.raises(TypeError): + j.score(expected={}, actual={}) diff --git a/tests/lib/test_profile_workflows_unit.py b/tests/lib/test_profile_workflows_unit.py index e436f2de..d31ba7c4 100644 --- a/tests/lib/test_profile_workflows_unit.py +++ b/tests/lib/test_profile_workflows_unit.py @@ -250,7 +250,7 @@ def test_search_profiles_current_only(reflexio_with_config): user_id=user_id, query_text="sushi", top_k=10 ) - response = reflexio.search_profiles(search_request) + response = reflexio.search_user_profiles(search_request) assert response.success is True # Default status_filter is [None] which means current profiles only @@ -278,7 +278,7 @@ def test_search_profiles_with_status_filter(reflexio_with_config): user_id=user_id, query_text="test", top_k=10 ) - response = reflexio.search_profiles( + response = reflexio.search_user_profiles( search_request, status_filter=[None, Status.PENDING] ) diff --git a/tests/lib/test_profiles_unit.py b/tests/lib/test_profiles_unit.py index bf82cc35..45d1199a 100644 --- a/tests/lib/test_profiles_unit.py +++ b/tests/lib/test_profiles_unit.py @@ -1,6 +1,6 @@ """Unit tests for ProfilesMixin. -Tests get_profiles, get_all_profiles, search_profiles, delete_profile, +Tests get_profiles, get_all_profiles, search_user_profiles, delete_profile, delete_all_profiles_bulk, delete_profiles_by_ids, get_profile_change_logs, get_profile_statistics, upgrade_all_profiles, and downgrade_all_profiles with mocked storage and services. @@ -217,7 +217,7 @@ def test_custom_status_filter(self): # --------------------------------------------------------------------------- -# search_profiles +# search_user_profiles # --------------------------------------------------------------------------- @@ -229,7 +229,7 @@ def test_query_delegation(self): _get_storage(mixin).search_user_profile.return_value = [sample] request = SearchUserProfileRequest(user_id="user1", query="sushi") - response = mixin.search_profiles(request) + response = mixin.search_user_profiles(request) assert response.success is True assert len(response.user_profiles) == 1 @@ -240,7 +240,7 @@ def test_storage_not_configured(self): mixin = _make_mixin(storage_configured=False) request = SearchUserProfileRequest(user_id="user1", query="sushi") - response = mixin.search_profiles(request) + response = mixin.search_user_profiles(request) assert response.success is True assert response.user_profiles == [] @@ -251,7 +251,7 @@ def test_dict_input(self): mixin = _make_mixin() _get_storage(mixin).search_user_profile.return_value = [] - response = mixin.search_profiles({"user_id": "user1", "query": "test"}) + response = mixin.search_user_profiles({"user_id": "user1", "query": "test"}) assert response.success is True @@ -261,7 +261,7 @@ def test_default_status_filter(self): _get_storage(mixin).search_user_profile.return_value = [] request = SearchUserProfileRequest(user_id="user1", query="test") - mixin.search_profiles(request) + mixin.search_user_profiles(request) call_kwargs = _get_storage(mixin).search_user_profile.call_args assert call_kwargs[1]["status_filter"] == [None] @@ -272,7 +272,7 @@ def test_custom_status_filter(self): _get_storage(mixin).search_user_profile.return_value = [] request = SearchUserProfileRequest(user_id="user1", query="test") - mixin.search_profiles(request, status_filter=[Status.PENDING]) + mixin.search_user_profiles(request, status_filter=[Status.PENDING]) call_kwargs = _get_storage(mixin).search_user_profile.call_args assert call_kwargs[1]["status_filter"] == [Status.PENDING] diff --git a/tests/lib/test_search_unit.py b/tests/lib/test_search_unit.py index 62e9ebf4..0218a7e2 100644 --- a/tests/lib/test_search_unit.py +++ b/tests/lib/test_search_unit.py @@ -198,6 +198,7 @@ def test_delegation_to_service(self): mixin.llm_client = MagicMock() mock_config = MagicMock() mock_config.llm_config = None + mock_config.search_backend = "classic" mixin.request_context.configurator.get_config.return_value = mock_config expected_response = UnifiedSearchResponse(success=True) @@ -225,3 +226,73 @@ def test_storage_not_configured(self): assert response.success is True assert response.msg is not None + + def test_dispatches_to_agentic_when_search_backend_agentic(self): + """When config.search_backend == 'agentic', AgenticSearchService.search runs. + + Pre-fix bug: lib/_search.py hardcoded run_unified_search regardless of + config — agentic SearchAgent was implemented but unreachable from the + public /api/search path. This test pins the dispatch. + """ + mixin = _make_mixin() + mixin.llm_client = MagicMock() + mock_config = MagicMock() + mock_config.llm_config = None + mock_config.search_backend = "agentic" + mixin.request_context.configurator.get_config.return_value = mock_config + + expected_response = UnifiedSearchResponse(success=True, agent_answer="hi") + + with ( + patch( + "reflexio.server.services.search.agentic_search_service.AgenticSearchService" + ) as mock_agentic_cls, + patch( + "reflexio.server.services.unified_search_service.run_unified_search" + ) as mock_run_unified, + ): + mock_agentic_inst = MagicMock() + mock_agentic_inst.search.return_value = expected_response + mock_agentic_cls.return_value = mock_agentic_inst + + request = UnifiedSearchRequest(query="test query") + response = mixin.unified_search(request, org_id="org_1") + + assert response is expected_response + mock_agentic_cls.assert_called_once_with( + llm_client=mixin.llm_client, + request_context=mixin.request_context, + ) + mock_agentic_inst.search.assert_called_once_with(request) + mock_run_unified.assert_not_called() + + def test_dispatches_to_classic_when_search_backend_classic(self): + """When config.search_backend == 'classic', run_unified_search runs. + + Belt-and-suspenders: ensures the agentic branch doesn't accidentally + capture the classic path on the default value. + """ + mixin = _make_mixin() + mixin.llm_client = MagicMock() + mock_config = MagicMock() + mock_config.llm_config = None + mock_config.search_backend = "classic" + mixin.request_context.configurator.get_config.return_value = mock_config + + expected_response = UnifiedSearchResponse(success=True) + + with ( + patch( + "reflexio.server.services.unified_search_service.run_unified_search", + return_value=expected_response, + ) as mock_run_unified, + patch( + "reflexio.server.services.search.agentic_search_service.AgenticSearchService" + ) as mock_agentic_cls, + ): + request = UnifiedSearchRequest(query="test query") + response = mixin.unified_search(request, org_id="org_1") + + assert response is expected_response + mock_run_unified.assert_called_once() + mock_agentic_cls.assert_not_called() diff --git a/tests/models/api_schema/__init__.py b/tests/models/api_schema/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/models/api_schema/test_domain_entities.py b/tests/models/api_schema/test_domain_entities.py new file mode 100644 index 00000000..a6897d1e --- /dev/null +++ b/tests/models/api_schema/test_domain_entities.py @@ -0,0 +1,61 @@ +"""Task 2.3: optional source_span/notes/reader_angle on UserProfile and UserPlaybook.""" + +from reflexio.models.api_schema.domain.entities import UserPlaybook, UserProfile + + +def test_user_profile_optional_new_fields_default_to_none() -> None: + p = UserProfile( + profile_id="p1", + user_id="u1", + content="x", + last_modified_timestamp=0, + generated_from_request_id="r1", + ) + assert p.source_span is None + assert p.notes is None + assert p.reader_angle is None + + +def test_user_profile_accepts_optional_fields() -> None: + p = UserProfile( + profile_id="p2", + user_id="u1", + content="x", + last_modified_timestamp=0, + generated_from_request_id="r1", + source_span="q", + notes="n", + reader_angle="facts", + ) + assert p.source_span == "q" + assert p.notes == "n" + assert p.reader_angle == "facts" + + +def test_user_playbook_optional_new_fields_default_to_none() -> None: + pb = UserPlaybook( + agent_version="v1", + request_id="r1", + trigger="t", + content="c", + rationale="r", + ) + assert pb.source_span is None + assert pb.notes is None + assert pb.reader_angle is None + + +def test_user_playbook_accepts_optional_fields() -> None: + pb = UserPlaybook( + agent_version="v1", + request_id="r1", + trigger="t", + content="c", + rationale="r", + source_span="q", + notes="n", + reader_angle="behavior", + ) + assert pb.source_span == "q" + assert pb.notes == "n" + assert pb.reader_angle == "behavior" diff --git a/tests/models/api_schema/test_retriever_schema.py b/tests/models/api_schema/test_retriever_schema.py new file mode 100644 index 00000000..c38405f8 --- /dev/null +++ b/tests/models/api_schema/test_retriever_schema.py @@ -0,0 +1,47 @@ +"""Tests for retriever_schema — UnifiedSearchResponse msg field round-trips. + +The agentic search orchestrator relies on ``UnifiedSearchResponse.msg`` +being an accepted, round-trippable field so it can surface partial-failure +context. These tests pin the contract. +""" + +from __future__ import annotations + +from reflexio.models.api_schema.retriever_schema import UnifiedSearchResponse + + +def test_unified_search_response_accepts_msg(): + r = UnifiedSearchResponse( + success=True, + profiles=[], + user_playbooks=[], + agent_playbooks=[], + reformulated_query="q", + msg="partial", + ) + assert r.msg == "partial" + + +def test_unified_search_response_msg_defaults_to_none(): + r = UnifiedSearchResponse( + success=True, + profiles=[], + user_playbooks=[], + agent_playbooks=[], + reformulated_query="q", + ) + assert r.msg is None + + +def test_unified_search_response_msg_roundtrips_through_json(): + r = UnifiedSearchResponse( + success=True, + profiles=[], + user_playbooks=[], + agent_playbooks=[], + reformulated_query="q", + msg="partial: some agents timed out", + ) + restored = UnifiedSearchResponse.model_validate_json(r.model_dump_json()) + assert restored.msg == "partial: some agents timed out" + assert restored.reformulated_query == "q" diff --git a/tests/server/__init__.py b/tests/server/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/server/api_endpoints/test_api_routes.py b/tests/server/api_endpoints/test_api_routes.py index 90e765ef..d10a5de9 100644 --- a/tests/server/api_endpoints/test_api_routes.py +++ b/tests/server/api_endpoints/test_api_routes.py @@ -100,7 +100,7 @@ def test_search_profiles_returns_200(self, client): return_value=mock_response, ): response = client.post( - "/api/search_profiles", + "/api/search_user_profiles", json={"user_id": "user-1", "query": "test user"}, ) assert response.status_code == 200 @@ -129,7 +129,7 @@ def test_search_interactions_returns_200(self, client): assert data["interactions"] == [] def test_search_profiles_missing_body_returns_422(self, client): - response = client.post("/api/search_profiles") + response = client.post("/api/search_user_profiles") assert response.status_code == 422 diff --git a/tests/server/api_endpoints/test_retriever_api.py b/tests/server/api_endpoints/test_retriever_api.py index 9c166a19..a6cd5f6f 100644 --- a/tests/server/api_endpoints/test_retriever_api.py +++ b/tests/server/api_endpoints/test_retriever_api.py @@ -27,14 +27,14 @@ def mock_reflexio(): class TestSearchUserProfiles: - def test_delegates_to_search_profiles(self, mock_reflexio): + def test_delegates_to_search_user_profiles(self, mock_reflexio): request = MagicMock() expected = MagicMock() - mock_reflexio.search_profiles.return_value = expected + mock_reflexio.search_user_profiles.return_value = expected result = search_user_profiles("org-1", request) - mock_reflexio.search_profiles.assert_called_once_with(request) + mock_reflexio.search_user_profiles.assert_called_once_with(request) assert result is expected diff --git a/tests/server/llm/test_litellm_client.py b/tests/server/llm/test_litellm_client.py index 938f1079..1354d60c 100644 --- a/tests/server/llm/test_litellm_client.py +++ b/tests/server/llm/test_litellm_client.py @@ -9,6 +9,7 @@ import struct import tempfile import zlib +from collections.abc import Generator from pathlib import Path import pytest @@ -142,7 +143,7 @@ def test_image_bytes() -> bytes: @pytest.fixture -def test_image_file(test_image_bytes: bytes) -> str: +def test_image_file(test_image_bytes: bytes) -> Generator[str, None, None]: """Create a temporary PNG image file.""" with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f: f.write(test_image_bytes) @@ -644,7 +645,7 @@ def test_create_client_with_azure_openai_config(self): openai=OpenAIConfig( azure_config=AzureOpenAIConfig( api_key="test-azure-key-11111", - endpoint="https://test-resource.openai.azure.com/", + endpoint="https://test-resource.openai.azure.com/", # type: ignore[arg-type] api_version="2024-02-15-preview", deployment_name="gpt-4o-deployment", ) @@ -716,7 +717,7 @@ def test_api_key_resolution_azure_model(self): api_key="direct-openai-key", azure_config=AzureOpenAIConfig( api_key="azure-key", - endpoint="https://azure.openai.azure.com/", + endpoint="https://azure.openai.azure.com/", # type: ignore[arg-type] api_version="2024-02-15-preview", ), ), diff --git a/tests/server/llm/test_litellm_client_tool_calls.py b/tests/server/llm/test_litellm_client_tool_calls.py new file mode 100644 index 00000000..53a7cbf3 --- /dev/null +++ b/tests/server/llm/test_litellm_client_tool_calls.py @@ -0,0 +1,144 @@ +"""LiteLLMClient extensions for tool-calling (Task 1.3).""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + +from reflexio.server.llm.litellm_client import ( + LiteLLMClient, + LiteLLMConfig, + ToolCallingChatResponse, +) +from reflexio.server.llm.model_defaults import ModelRole + +# --------------------------------------------------------------------------- +# Mock helpers +# --------------------------------------------------------------------------- + + +def _mock_tool_call_response(tool_name: str, args_json: str) -> MagicMock: + """Build a MagicMock shaped like a litellm tool-call response.""" + tool_call = MagicMock() + tool_call.function.name = tool_name + tool_call.function.arguments = args_json + + message = MagicMock() + message.content = None + message.tool_calls = [tool_call] + + choice = MagicMock() + choice.message = message + choice.finish_reason = "tool_calls" + + response = MagicMock() + response.choices = [choice] + response.usage = MagicMock(prompt_tokens=10, completion_tokens=5, total_tokens=15) + return response + + +def _mock_text_response(text: str) -> MagicMock: + """Build a MagicMock shaped like a normal litellm text response.""" + message = MagicMock() + message.content = text + message.tool_calls = None + + choice = MagicMock() + choice.message = message + choice.finish_reason = "stop" + + response = MagicMock() + response.choices = [choice] + response.usage = MagicMock(prompt_tokens=10, completion_tokens=5, total_tokens=15) + return response + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestToolCallingExtensions: + """Tests for tools/tool_choice/model_role kwargs on LiteLLMClient.""" + + def test_generate_chat_response_passes_tools_kwarg(self) -> None: + """tools + tool_choice are forwarded to litellm.completion; result is ToolCallingChatResponse.""" + config = LiteLLMConfig(model="gpt-4o") + client = LiteLLMClient(config) + + mock_response = _mock_tool_call_response("emit_profile", '{"name": "Alice"}') + + tools = [ + { + "type": "function", + "function": { + "name": "emit_profile", + "description": "Emit a profile", + "parameters": {"type": "object", "properties": {}}, + }, + } + ] + + with patch("litellm.completion", return_value=mock_response) as mock_completion: + result = client.generate_chat_response( + messages=[{"role": "user", "content": "hello"}], + tools=tools, + tool_choice="auto", + ) + + # The tools and tool_choice kwargs must have been forwarded + call_kwargs = mock_completion.call_args.kwargs + assert call_kwargs["tools"] == tools + assert call_kwargs["tool_choice"] == "auto" + + # The result must be a ToolCallingChatResponse + assert isinstance(result, ToolCallingChatResponse) + assert result.tool_calls is not None + assert result.tool_calls[0].function.name == "emit_profile" + assert result.finish_reason == "tool_calls" + assert result.content is None + + def test_model_role_resolves_to_extraction_agent_default( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + """model_role=EXTRACTION_AGENT resolves to the anthropic extraction_agent default model.""" + monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key") + # Ensure no other provider keys interfere + for var in ( + "OPENAI_API_KEY", + "GEMINI_API_KEY", + "DEEPSEEK_API_KEY", + "OPENROUTER_API_KEY", + "CLAUDE_SMART_USE_LOCAL_CLI", + ): + monkeypatch.delenv(var, raising=False) + + config = LiteLLMConfig(model="gpt-4o") + client = LiteLLMClient(config) + + mock_response = _mock_text_response("hi") + + with patch("litellm.completion", return_value=mock_response) as mock_completion: + client.generate_chat_response( + messages=[{"role": "user", "content": "hello"}], + model_role=ModelRole.EXTRACTION_AGENT, + ) + + call_kwargs = mock_completion.call_args.kwargs + assert call_kwargs["model"] == "claude-sonnet-4-6" + + def test_non_tool_path_unchanged(self) -> None: + """Without tools kwarg the existing str-return path is untouched.""" + config = LiteLLMConfig(model="gpt-4o") + client = LiteLLMClient(config) + + mock_response = _mock_text_response("hi") + + with patch("litellm.completion", return_value=mock_response): + result = client.generate_chat_response( + messages=[{"role": "user", "content": "hello"}], + ) + + assert result == "hi" + assert not isinstance(result, ToolCallingChatResponse) diff --git a/tests/server/llm/test_litellm_client_unit.py b/tests/server/llm/test_litellm_client_unit.py index 9ff65801..a8e0b826 100644 --- a/tests/server/llm/test_litellm_client_unit.py +++ b/tests/server/llm/test_litellm_client_unit.py @@ -38,6 +38,7 @@ LiteLLMClient, LiteLLMClientError, LiteLLMConfig, + StructuredOutputParseError, _get_embedding_encoding, _get_embedding_limit, _truncate_for_embedding, @@ -165,7 +166,7 @@ def test_init_with_openai_api_key_config(self): def test_init_with_azure_config(self): azure = AzureOpenAIConfig( api_key="az-key", - endpoint="https://myresource.openai.azure.com/", + endpoint="https://myresource.openai.azure.com/", # type: ignore[arg-type] api_version="2024-02-15-preview", ) api_key_config = APIKeyConfig(openai=CommonsOpenAIConfig(azure_config=azure)) @@ -173,7 +174,7 @@ def test_init_with_azure_config(self): client = LiteLLMClient(config) assert client._api_key == "az-key" - assert "myresource" in client._api_base + assert client._api_base is not None and "myresource" in client._api_base assert client._api_version == "2024-02-15-preview" def test_init_with_anthropic_config(self): @@ -215,7 +216,7 @@ def test_init_with_custom_endpoint(self): custom_endpoint=CustomEndpointConfig( model="my-model", api_key="ce-key", - api_base="https://custom.api.com/v1", + api_base="https://custom.api.com/v1", # type: ignore[arg-type] ) ) config = LiteLLMConfig(model="gpt-4o", api_key_config=api_key_config) @@ -245,7 +246,7 @@ def test_custom_endpoint_priority_for_non_embedding(self): custom_endpoint=CustomEndpointConfig( model="custom-model", api_key="ce-key", - api_base="https://custom.api.com/v1", + api_base="https://custom.api.com/v1", # type: ignore[arg-type] ), openai=CommonsOpenAIConfig(api_key="sk-openai"), ) @@ -261,7 +262,7 @@ def test_custom_endpoint_skipped_for_embedding(self): custom_endpoint=CustomEndpointConfig( model="custom-model", api_key="ce-key", - api_base="https://custom.api.com/v1", + api_base="https://custom.api.com/v1", # type: ignore[arg-type] ), openai=CommonsOpenAIConfig(api_key="sk-openai"), ) @@ -1035,11 +1036,84 @@ def test_python_style_json_sanitized(self, client): assert isinstance(result, SampleResponse) assert result.answer == "ok" - def test_unparseable_returns_raw_content(self, client): - result = client._maybe_parse_structured_output( - "totally not json", SampleResponse, True + def test_unparseable_raises_structured_output_parse_error(self, client): + with pytest.raises(StructuredOutputParseError): + client._maybe_parse_structured_output( + "totally not json", SampleResponse, True + ) + + +# =================================================================== +# Retry-on-parse-failure tests +# =================================================================== + + +class TestStructuredOutputRetry: + """Tests for retry behaviour when _maybe_parse_structured_output raises.""" + + def _make_mock_response(self, content: str) -> MagicMock: + """Build a mock litellm.completion response with given content.""" + choice = MagicMock() + choice.message.content = content + choice.message.tool_calls = None + choice.finish_reason = "stop" + resp = MagicMock() + resp.choices = [choice] + resp.usage = MagicMock(prompt_tokens=10, completion_tokens=5, total_tokens=15) + resp.usage.prompt_tokens_details = None + resp.usage.cache_creation_input_tokens = None + resp.usage.cache_read_input_tokens = None + return resp + + def test_structured_output_parse_failure_retries_and_succeeds(self): + """Malformed JSON on first attempt, valid on second — retry eventually succeeds.""" + call_count = 0 + valid_json = '{"answer": "ok", "score": 42}' + + def fake_completion(**kwargs): + nonlocal call_count + call_count += 1 + content = "not valid json {{{{" if call_count == 1 else valid_json + return self._make_mock_response(content) + + client = _build_client( + LiteLLMConfig(model="gpt-4o-mini", max_retries=3, retry_delay=0) + ) + + with patch("litellm.completion", side_effect=fake_completion): + result = client.generate_chat_response( + messages=[{"role": "user", "content": "test"}], + response_format=SampleResponse, + ) + + assert call_count == 2 + assert isinstance(result, SampleResponse) + assert result.answer == "ok" + assert result.score == 42 + + def test_structured_output_parse_failure_all_retries_exhausted_raises(self): + """Every attempt returns malformed content — raises LiteLLMClientError wrapping StructuredOutputParseError after exhaustion.""" + call_count = 0 + + def fake_completion(**kwargs): + nonlocal call_count + call_count += 1 + return self._make_mock_response("not valid json at all {{{{") + + client = _build_client( + LiteLLMConfig(model="gpt-4o-mini", max_retries=2, retry_delay=0) ) - assert result == "totally not json" + + with ( + patch("litellm.completion", side_effect=fake_completion), + pytest.raises(LiteLLMClientError), + ): + client.generate_chat_response( + messages=[{"role": "user", "content": "test"}], + response_format=SampleResponse, + ) + + assert call_count == 2 # =================================================================== @@ -1276,7 +1350,7 @@ def test_custom_endpoint_overrides_model(self, mock_completion): custom_endpoint=CustomEndpointConfig( model="custom-model", api_key="ce-key", - api_base="https://custom.api.com/v1", + api_base="https://custom.api.com/v1", # type: ignore[arg-type] ) ) config = LiteLLMConfig(model="gpt-4o", api_key_config=api_key_config) diff --git a/tests/server/llm/test_model_defaults.py b/tests/server/llm/test_model_defaults.py index 3bf725e4..e662ac29 100644 --- a/tests/server/llm/test_model_defaults.py +++ b/tests/server/llm/test_model_defaults.py @@ -305,3 +305,53 @@ def test_all_roles_have_values(self) -> None: ): value = getattr(defaults, role.value) assert value, f"{provider}.{role.value} is empty" + + +# --------------------------------------------------------------------------- +# EXTRACTION_AGENT and SEARCH_AGENT roles +# --------------------------------------------------------------------------- + + +class TestAgenticV2Roles: + def test_extraction_agent_role_exists(self) -> None: + assert ModelRole.EXTRACTION_AGENT.value == "extraction_agent" + + def test_search_agent_role_exists(self) -> None: + assert ModelRole.SEARCH_AGENT.value == "search_agent" + + def test_anthropic_defaults_map_to_sonnet(self) -> None: + anthropic = _PROVIDER_DEFAULTS["anthropic"] + assert anthropic.extraction_agent is not None + assert "sonnet" in anthropic.extraction_agent.lower() + assert anthropic.search_agent is not None + assert "sonnet" in anthropic.search_agent.lower() + + def test_openai_defaults_map_to_gpt5_mini(self) -> None: + openai = _PROVIDER_DEFAULTS["openai"] + assert openai.extraction_agent == "gpt-5-mini" + assert openai.search_agent == "gpt-5-mini" + + def test_claude_code_defaults_cover_new_roles(self) -> None: + cc = _PROVIDER_DEFAULTS["claude-code"] + assert cc.extraction_agent == "claude-code/default" + assert cc.search_agent == "claude-code/default" + + def test_unpopulated_providers_default_to_none(self) -> None: + """Providers that haven't opted into agentic-v2 fall through to next priority provider.""" + local = _PROVIDER_DEFAULTS["local"] + assert local.extraction_agent is None + assert local.search_agent is None + + def test_resolve_extraction_agent_with_anthropic( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.setenv("ANTHROPIC_API_KEY", "ant-test") + name = resolve_model_name(role=ModelRole.EXTRACTION_AGENT) + assert "sonnet" in name.lower() + + def test_resolve_search_agent_with_openai( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.setenv("OPENAI_API_KEY", "sk-test") + name = resolve_model_name(role=ModelRole.SEARCH_AGENT) + assert name == "gpt-5-mini" diff --git a/tests/server/llm/test_tools.py b/tests/server/llm/test_tools.py new file mode 100644 index 00000000..5bc0970e --- /dev/null +++ b/tests/server/llm/test_tools.py @@ -0,0 +1,490 @@ +import json +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import BaseModel + +from reflexio.server.llm.litellm_client import ( + LiteLLMClient, + LiteLLMConfig, + ToolCallingChatResponse, +) +from reflexio.server.llm.model_defaults import ModelRole +from reflexio.server.llm.tools import ( + Tool, + ToolLoopResult, # noqa: F401 + ToolLoopTrace, # noqa: F401 + ToolRegistry, + run_tool_loop, +) + + +class EmitProfileArgs(BaseModel): + """Emit a candidate user profile item.""" + + content: str + time_to_live: str + + +class Ctx: + def __init__(self): + self.calls = [] + self.finished = False + + def emit(self, args, ctx): + self.calls.append(args) + return {"ok": True} + + +def test_tool_openai_spec_uses_docstring_and_schema(): + t = Tool(name="emit_profile", args_model=EmitProfileArgs, handler=lambda _a, _c: {}) + spec = t.openai_spec() + assert spec["type"] == "function" + assert spec["function"]["name"] == "emit_profile" + assert "Emit a candidate user profile item." in spec["function"]["description"] + assert spec["function"]["parameters"]["properties"]["content"]["type"] == "string" + + +def test_registry_handle_parses_and_dispatches(): + ctx = Ctx() + t = Tool(name="emit_profile", args_model=EmitProfileArgs, handler=ctx.emit) + reg = ToolRegistry() + reg.register(t) + result = reg.handle( + "emit_profile", json.dumps({"content": "hi", "time_to_live": "persistent"}), ctx + ) + assert result == {"ok": True} + assert ctx.calls[0].content == "hi" + + +def test_registry_handle_converts_validation_error_to_tool_error(): + ctx = Ctx() + reg = ToolRegistry() + reg.register( + Tool(name="emit_profile", args_model=EmitProfileArgs, handler=ctx.emit) + ) + # Missing required field. + result = reg.handle("emit_profile", json.dumps({"content": "hi"}), ctx) + assert "error" in result + assert "time_to_live" in result["error"] + assert ctx.calls == [] + + +def test_registry_rejects_unknown_tool(): + reg = ToolRegistry() + result = reg.handle("not_a_tool", "{}", None) + assert "error" in result + assert "unknown tool" in result["error"].lower() + + +def test_openai_specs_lists_all_registered_tools(): + reg = ToolRegistry() + reg.register(Tool(name="a", args_model=EmitProfileArgs, handler=lambda *_: {})) + reg.register(Tool(name="b", args_model=EmitProfileArgs, handler=lambda *_: {})) + specs = reg.openai_specs() + assert {s["function"]["name"] for s in specs} == {"a", "b"} + + +def test_mock_tool_call_response_shape(tool_call_completion): + make_tc, make_stop = tool_call_completion + r = make_tc("emit_profile", {"content": "x"}) + assert r.choices[0].finish_reason == "tool_calls" + assert r.choices[0].message.tool_calls[0].function.name == "emit_profile" + s = make_stop() + assert s.choices[0].finish_reason == "stop" + assert s.choices[0].message.tool_calls is None + + +# --------------------------------------------------------------------------- +# run_tool_loop tests +# --------------------------------------------------------------------------- + + +class EmitArgs(BaseModel): + """Emit a value.""" + + value: str + + +class LoopCtx: + """Simple mutable context for tool-loop tests.""" + + def __init__(self): + self.emitted: list[str] = [] + self.finished: bool = False + + +def _make_registry(ctx: LoopCtx) -> ToolRegistry: + """Build a registry with 'emit' and 'finish' tools that mutate *ctx*.""" + + def _emit_handler(args: BaseModel, c: LoopCtx) -> dict: + c.emitted.append(args.value) # type: ignore[attr-defined] + return {"ok": True} + + def _finish_handler(args: BaseModel, c: LoopCtx) -> dict: + c.finished = True + return {"done": True} + + class FinishArgs(BaseModel): + """Signal that extraction is complete.""" + + reg = ToolRegistry() + reg.register(Tool(name="emit", args_model=EmitArgs, handler=_emit_handler)) + reg.register(Tool(name="finish", args_model=FinishArgs, handler=_finish_handler)) + return reg + + +def test_run_tool_loop_drives_multiple_turns_until_finish( + monkeypatch, tool_call_completion +): + """Three LLM turns (emit, emit, finish) should yield finished_reason='finish_tool'.""" + monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key") + monkeypatch.delenv("CLAUDE_SMART_USE_LOCAL_CLI", raising=False) + + make_tc, _make_stop = tool_call_completion + responses = [ + make_tc("emit", {"value": "alpha"}), + make_tc("emit", {"value": "beta"}), + make_tc("finish", {}), + ] + + config = LiteLLMConfig(model="claude-sonnet-4-6") + client = LiteLLMClient(config) + ctx = LoopCtx() + registry = _make_registry(ctx) + + with patch("litellm.completion", side_effect=responses): + result = run_tool_loop( + client=client, + messages=[{"role": "user", "content": "go"}], + registry=registry, + model_role=ModelRole.EXTRACTION_AGENT, + ctx=ctx, + ) + + assert result.finished_reason == "finish_tool" + assert result.trace.finished is True + assert len(result.trace.turns) == 3 + assert ctx.emitted == ["alpha", "beta"] + assert ctx.finished is True + + +def test_run_tool_loop_honours_max_steps(monkeypatch, tool_call_completion): + """With max_steps=3 and unlimited emit responses, the loop caps at 3 turns.""" + monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key") + monkeypatch.delenv("CLAUDE_SMART_USE_LOCAL_CLI", raising=False) + + make_tc, _make_stop = tool_call_completion + # Supply more responses than max_steps so we are cap-limited, not response-limited. + responses = [make_tc("emit", {"value": f"item-{i}"}) for i in range(10)] + + config = LiteLLMConfig(model="claude-sonnet-4-6") + client = LiteLLMClient(config) + ctx = LoopCtx() + registry = _make_registry(ctx) + + with patch("litellm.completion", side_effect=responses): + result = run_tool_loop( + client=client, + messages=[{"role": "user", "content": "go"}], + registry=registry, + model_role=ModelRole.EXTRACTION_AGENT, + max_steps=3, + ctx=ctx, + ) + + assert result.finished_reason == "max_steps" + assert len(ctx.emitted) == 3 + + +def test_run_tool_loop_capability_fallback_uses_response_format(monkeypatch): + """When supports_tool_calling is False, generate_chat_response uses response_format.""" + monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key") + monkeypatch.delenv("CLAUDE_SMART_USE_LOCAL_CLI", raising=False) + + from reflexio.server.llm import tools as tools_mod + + monkeypatch.setattr(tools_mod, "supports_tool_calling", lambda _model: False) + + config = LiteLLMConfig(model="some-legacy-model") + client = LiteLLMClient(config) + + class FallbackSchema(BaseModel): + emissions: list[EmitArgs] + + fake_parsed = FallbackSchema(emissions=[EmitArgs(value="x"), EmitArgs(value="y")]) + monkeypatch.setattr(client, "generate_chat_response", lambda **_: fake_parsed) + + ctx = LoopCtx() + registry = _make_registry(ctx) + + result = run_tool_loop( + client=client, + messages=[{"role": "user", "content": "go"}], + registry=registry, + model_role=ModelRole.EXTRACTION_AGENT, + fallback_schema=FallbackSchema, + fallback_tool_name="emit", + ctx=ctx, + ) + + assert result.finished_reason == "finish_tool" + assert result.trace.finished is True + assert len(result.trace.turns) == 2 + assert ctx.emitted == ["x", "y"] + + +def test_run_tool_loop_returns_error_on_client_exception( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """When generate_chat_response raises, the loop returns finished_reason='error'.""" + monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key") + monkeypatch.delenv("CLAUDE_SMART_USE_LOCAL_CLI", raising=False) + + ctx = LoopCtx() # reuse the helper class defined earlier in the test file + + def _emit_handler(args: BaseModel, c: LoopCtx) -> dict: + c.emitted.append(args.value) # type: ignore[attr-defined] + return {"ok": True} + + reg = ToolRegistry([Tool(name="emit", args_model=EmitArgs, handler=_emit_handler)]) + + config = LiteLLMConfig(model="claude-sonnet-4-6") + client = LiteLLMClient(config) + + def boom(**_kwargs): + raise RuntimeError("simulated provider failure") + + monkeypatch.setattr(client, "generate_chat_response", boom) + + result = run_tool_loop( + client=client, + messages=[{"role": "user", "content": "go"}], + registry=reg, + model_role=ModelRole.EXTRACTION_AGENT, + max_steps=5, + ctx=ctx, + finish_tool_name="finish", + ) + + assert result.finished_reason == "error" + assert result.trace.finished is False + assert result.trace.turns == [] + + +# ---------------- log_label (llm_io.log) integration ---------------- # + + +def test_run_tool_loop_log_label_none_does_not_invoke_llm_io_helpers( + monkeypatch, tool_call_completion +): + """Default log_label=None → zero calls to log_llm_messages / log_model_response.""" + monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key") + monkeypatch.delenv("CLAUDE_SMART_USE_LOCAL_CLI", raising=False) + + make_tc, _ = tool_call_completion + responses = [make_tc("finish", {})] + client = LiteLLMClient(LiteLLMConfig(model="claude-sonnet-4-6")) + ctx = LoopCtx() + registry = _make_registry(ctx) + + with ( + patch( + "reflexio.server.services.service_utils.log_llm_messages" + ) as mock_log_msgs, + patch( + "reflexio.server.services.service_utils.log_model_response" + ) as mock_log_resp, + patch("litellm.completion", side_effect=responses), + ): + run_tool_loop( + client=client, + messages=[{"role": "user", "content": "go"}], + registry=registry, + model_role=ModelRole.EXTRACTION_AGENT, + ctx=ctx, + ) + + mock_log_msgs.assert_not_called() + mock_log_resp.assert_not_called() + + +def test_run_tool_loop_log_label_native_path_logs_each_turn( + monkeypatch, tool_call_completion +): + """log_label='X' → one log_llm_messages + one log_model_response per native turn. + + Across 2 turns, we expect: + - 2 prompt log entries labelled "X (turn 1)" and "X (turn 2)" + - 2 response log entries with matching labels + """ + monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key") + monkeypatch.delenv("CLAUDE_SMART_USE_LOCAL_CLI", raising=False) + + make_tc, _ = tool_call_completion + responses = [make_tc("emit", {"value": "a"}), make_tc("finish", {})] + client = LiteLLMClient(LiteLLMConfig(model="claude-sonnet-4-6")) + ctx = LoopCtx() + registry = _make_registry(ctx) + + with ( + patch( + "reflexio.server.services.service_utils.log_llm_messages" + ) as mock_log_msgs, + patch( + "reflexio.server.services.service_utils.log_model_response" + ) as mock_log_resp, + patch("litellm.completion", side_effect=responses), + ): + run_tool_loop( + client=client, + messages=[{"role": "user", "content": "go"}], + registry=registry, + model_role=ModelRole.EXTRACTION_AGENT, + ctx=ctx, + log_label="profile_reader_facts", + ) + + assert mock_log_msgs.call_count == 2 + assert mock_log_resp.call_count == 2 + # Label suffixes increment per turn + msg_labels = [c.args[1] for c in mock_log_msgs.call_args_list] + resp_labels = [c.args[1] for c in mock_log_resp.call_args_list] + assert msg_labels == [ + "profile_reader_facts (turn 1)", + "profile_reader_facts (turn 2)", + ] + assert resp_labels == [ + "profile_reader_facts (turn 1)", + "profile_reader_facts (turn 2)", + ] + + +def test_run_tool_loop_log_label_fallback_path_logs_once(monkeypatch): + """Capability-fallback path logs exactly one prompt + one response with '(fallback)' suffix.""" + monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key") + monkeypatch.delenv("CLAUDE_SMART_USE_LOCAL_CLI", raising=False) + + # Force capability-fallback path + monkeypatch.setattr( + "reflexio.server.llm.tools.supports_tool_calling", lambda _model: False + ) + + class EmitListSchema(BaseModel): + items: list[EmitArgs] = [] + + class FinishArgs(BaseModel): + """Signal end.""" + + reg = ToolRegistry() + ctx = LoopCtx() + + def _emit(args: BaseModel, c: LoopCtx) -> dict: + c.emitted.append(args.value) # type: ignore[attr-defined] + return {"ok": True} + + reg.register(Tool(name="emit", args_model=EmitArgs, handler=_emit)) + reg.register( + Tool( + name="finish", + args_model=FinishArgs, + handler=lambda _a, _c: {"done": True}, + ) + ) + + client = LiteLLMClient(LiteLLMConfig(model="claude-sonnet-4-6")) + parsed = EmitListSchema(items=[EmitArgs(value="a"), EmitArgs(value="b")]) + + with ( + patch( + "reflexio.server.services.service_utils.log_llm_messages" + ) as mock_log_msgs, + patch( + "reflexio.server.services.service_utils.log_model_response" + ) as mock_log_resp, + patch.object(client, "generate_chat_response", return_value=parsed), + ): + run_tool_loop( + client=client, + messages=[{"role": "user", "content": "go"}], + registry=reg, + model_role=ModelRole.EXTRACTION_AGENT, + ctx=ctx, + fallback_schema=EmitListSchema, + fallback_tool_name="emit", + log_label="profile_reader_facts", + ) + + assert mock_log_msgs.call_count == 1 + assert mock_log_resp.call_count == 1 + assert mock_log_msgs.call_args.args[1] == "profile_reader_facts (fallback)" + assert mock_log_resp.call_args.args[1] == "profile_reader_facts (fallback)" + + +# --------------------------------------------------------------------------- +# ToolLoopTurn usage field tests +# --------------------------------------------------------------------------- + + +def test_run_tool_loop_captures_usage_on_tool_loop_turn(monkeypatch): + """Each ToolLoopTurn should carry prompt/completion/total tokens, model name, + and cost_usd when the ToolCallingChatResponse carries a usage object.""" + monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key") + monkeypatch.delenv("CLAUDE_SMART_USE_LOCAL_CLI", raising=False) + + # Build a fake usage object. + fake_usage = MagicMock() + fake_usage.prompt_tokens = 100 + fake_usage.completion_tokens = 50 + fake_usage.total_tokens = 150 + + # Build scripted ToolCallingChatResponse objects (one tool call, then finish). + tc = MagicMock() + tc.id = "tc_emit" + tc.function = MagicMock() + tc.function.name = "emit" + tc.function.arguments = json.dumps({"value": "hello"}) + + resp_with_usage = ToolCallingChatResponse( + content=None, + tool_calls=[tc], + finish_reason="tool_calls", + usage=fake_usage, + cost_usd=0.002, + ) + resp_finish = ToolCallingChatResponse( + content=None, + tool_calls=None, + finish_reason="stop", + usage=None, + cost_usd=None, + ) + + config = LiteLLMConfig(model="claude-sonnet-4-6") + client = LiteLLMClient(config) + ctx = LoopCtx() + registry = _make_registry(ctx) + + monkeypatch.setattr( + client, + "generate_chat_response", + MagicMock(side_effect=[resp_with_usage, resp_finish]), + ) + + result = run_tool_loop( + client=client, + messages=[{"role": "user", "content": "go"}], + registry=registry, + model_role=ModelRole.EXTRACTION_AGENT, + ctx=ctx, + ) + + assert result.finished_reason == "finish_tool" + assert len(result.trace.turns) == 1 + turn = result.trace.turns[0] + assert turn.prompt_tokens == 100 + assert turn.completion_tokens == 50 + assert turn.total_tokens == 150 + assert turn.cost_usd == pytest.approx(0.002) + # model field is populated from the resolved model name (non-None) + assert turn.model is not None diff --git a/tests/server/llm/test_tools_multi_stage_integration.py b/tests/server/llm/test_tools_multi_stage_integration.py new file mode 100644 index 00000000..fadeb190 --- /dev/null +++ b/tests/server/llm/test_tools_multi_stage_integration.py @@ -0,0 +1,367 @@ +"""Integration tests for the multi-stage fallback path in ``run_tool_loop``. + +These tests target the multi-turn structured-output flow used when the +configured model lacks native tool-calling but should still observe +prior tool results before planning the next call (e.g. the search agent +running on ``minimax/MiniMax-M2.7``). + +The mocked LLM client is scripted to return one ``MultiStagePlan`` +instance per turn; the test asserts that: + + - The loop emits multiple structured-output calls in sequence. + - Each tool result is appended to the shared ``messages`` list so the + next turn's prompt sees it. + - The loop terminates when ``next_call.tool == finish_tool_name``. + - The loop terminates at ``max_steps`` when no finish is emitted. + - Each registry tool dispatches via the discriminator literal. +""" + +from __future__ import annotations + +import pytest +from pydantic import BaseModel, Field + +from reflexio.server.llm import tools as tools_mod +from reflexio.server.llm.litellm_client import LiteLLMClient, LiteLLMConfig +from reflexio.server.llm.model_defaults import ModelRole +from reflexio.server.llm.tools import Tool, ToolRegistry, run_tool_loop + +# --------------------------------------------------------------------------- +# Test schemas (mirror SearchAgentTurnPlan shape: reasoning + next_call union) +# --------------------------------------------------------------------------- + + +class _CallEmit(BaseModel): + """Test variant: dispatch ``emit``.""" + + tool: str = Field(default="emit", pattern="^emit$") + value: str + + +class _CallFinish(BaseModel): + """Test variant: dispatch ``finish``.""" + + tool: str = Field(default="finish", pattern="^finish$") + answer: str | None = None + + +class MultiStagePlan(BaseModel): + """Mirror of ``SearchAgentTurnPlan``: one turn of multi-stage fallback.""" + + reasoning: str + # We use a plain Union (no discriminator field) so the tests can + # construct either variant directly without pydantic's discriminator + # validation overhead — the real schema uses a discriminated union. + next_call: _CallEmit | _CallFinish + + +# --------------------------------------------------------------------------- +# Test ctx + registry +# --------------------------------------------------------------------------- + + +class _Ctx: + """Mutable per-run state for tool-loop tests.""" + + def __init__(self) -> None: + self.emitted: list[str] = [] + self.finished: bool = False + self.finish_answer: str | None = None + + +class _EmitArgs(BaseModel): + """Emit a value (test tool).""" + + value: str + + +class _FinishArgs(BaseModel): + """Terminate the test loop.""" + + answer: str | None = None + + +def _make_registry(ctx: _Ctx) -> ToolRegistry: + def _emit_handler(args: BaseModel, c: _Ctx) -> dict: + c.emitted.append(args.value) # type: ignore[attr-defined] + return {"ok": True, "echo": args.value} # type: ignore[attr-defined] + + def _finish_handler(args: BaseModel, c: _Ctx) -> dict: + c.finished = True + c.finish_answer = args.answer # type: ignore[attr-defined] + return {"finished": True} + + reg = ToolRegistry() + reg.register(Tool(name="emit", args_model=_EmitArgs, handler=_emit_handler)) + reg.register(Tool(name="finish", args_model=_FinishArgs, handler=_finish_handler)) + return reg + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _force_no_tool_calling(monkeypatch: pytest.MonkeyPatch) -> None: + """Force the capability fallback path so we always exercise multi-stage.""" + monkeypatch.setattr(tools_mod, "supports_tool_calling", lambda _model: False) + + +def _scripted_client( + monkeypatch: pytest.MonkeyPatch, plans: list[MultiStagePlan] +) -> LiteLLMClient: + """Build a LiteLLMClient whose ``generate_chat_response`` returns plans in order.""" + client = LiteLLMClient(LiteLLMConfig(model="some-non-tool-calling-model")) + iterator = iter(plans) + + def fake_generate(**_kwargs: object) -> MultiStagePlan: + return next(iterator) + + monkeypatch.setattr(client, "generate_chat_response", fake_generate) + return client + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +def test_multi_stage_loop_emits_multiple_turns_and_finishes( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """3-turn loop: emit, emit, finish — asserts trace shape and ctx mutations.""" + monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key") + monkeypatch.delenv("CLAUDE_SMART_USE_LOCAL_CLI", raising=False) + _force_no_tool_calling(monkeypatch) + + plans = [ + MultiStagePlan(reasoning="first emit", next_call=_CallEmit(value="alpha")), + MultiStagePlan(reasoning="second emit", next_call=_CallEmit(value="beta")), + MultiStagePlan( + reasoning="done", + next_call=_CallFinish(answer="all set"), + ), + ] + client = _scripted_client(monkeypatch, plans) + ctx = _Ctx() + reg = _make_registry(ctx) + + messages = [{"role": "user", "content": "begin"}] + result = run_tool_loop( + client=client, + messages=messages, + registry=reg, + model_role=ModelRole.EXTRACTION_AGENT, + ctx=ctx, + finish_tool_name="finish", + multi_stage_schema=MultiStagePlan, + ) + + assert result.finished_reason == "finish_tool" + assert result.trace.finished is True + assert [t.tool_name for t in result.trace.turns] == ["emit", "emit", "finish"] + assert ctx.emitted == ["alpha", "beta"] + assert ctx.finished is True + assert ctx.finish_answer == "all set" + + +def test_multi_stage_loop_appends_tool_results_to_history( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Each tool result must land in ``messages`` so the next turn observes it.""" + monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key") + monkeypatch.delenv("CLAUDE_SMART_USE_LOCAL_CLI", raising=False) + _force_no_tool_calling(monkeypatch) + + plans = [ + MultiStagePlan(reasoning="r1", next_call=_CallEmit(value="x")), + MultiStagePlan(reasoning="r2", next_call=_CallFinish()), + ] + client = _scripted_client(monkeypatch, plans) + ctx = _Ctx() + reg = _make_registry(ctx) + + messages: list[dict[str, object]] = [{"role": "user", "content": "go"}] + run_tool_loop( + client=client, + messages=messages, + registry=reg, + model_role=ModelRole.EXTRACTION_AGENT, + ctx=ctx, + finish_tool_name="finish", + multi_stage_schema=MultiStagePlan, + ) + + # Seed + (assistant plan + user result) for the first turn + assistant plan for finish + # The finish branch does not append a user-result message — it returns directly. + roles = [m["role"] for m in messages] + contents = [m["content"] for m in messages] + assert roles == ["user", "assistant", "user", "assistant"] + # The user message holding the tool result must mention the tool name + # AND the handler's payload (so the next turn really sees it). + tool_result_msg = contents[2] + assert isinstance(tool_result_msg, str) + assert "Tool emit returned" in tool_result_msg + assert '"echo": "x"' in tool_result_msg + # The assistant plan messages echo the tool name + args JSON. + plan_msg_1 = contents[1] + plan_msg_2 = contents[3] + assert isinstance(plan_msg_1, str) + assert isinstance(plan_msg_2, str) + assert "Reasoning: r1" in plan_msg_1 + assert "Next call: emit(" in plan_msg_1 + assert "Reasoning: r2" in plan_msg_2 + assert "Next call: finish(" in plan_msg_2 + + +def test_multi_stage_loop_terminates_at_max_steps_when_no_finish( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Loop must stop at ``max_steps`` when the agent never emits ``finish``.""" + monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key") + monkeypatch.delenv("CLAUDE_SMART_USE_LOCAL_CLI", raising=False) + _force_no_tool_calling(monkeypatch) + + plans = [ + MultiStagePlan(reasoning=f"r{i}", next_call=_CallEmit(value=f"v{i}")) + for i in range(10) + ] + client = _scripted_client(monkeypatch, plans) + ctx = _Ctx() + reg = _make_registry(ctx) + + result = run_tool_loop( + client=client, + messages=[{"role": "user", "content": "go"}], + registry=reg, + model_role=ModelRole.EXTRACTION_AGENT, + max_steps=3, + ctx=ctx, + finish_tool_name="finish", + multi_stage_schema=MultiStagePlan, + ) + + assert result.finished_reason == "max_steps" + assert result.trace.finished is False + assert len(result.trace.turns) == 3 + assert ctx.emitted == ["v0", "v1", "v2"] + assert ctx.finished is False + + +def test_multi_stage_loop_dispatches_each_call_through_registry( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Verify the registry handler actually runs for each turn (not just stubbed).""" + monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key") + monkeypatch.delenv("CLAUDE_SMART_USE_LOCAL_CLI", raising=False) + _force_no_tool_calling(monkeypatch) + + plans = [ + MultiStagePlan(reasoning="r1", next_call=_CallEmit(value="a")), + MultiStagePlan(reasoning="r2", next_call=_CallEmit(value="b")), + MultiStagePlan(reasoning="r3", next_call=_CallFinish()), + ] + client = _scripted_client(monkeypatch, plans) + ctx = _Ctx() + reg = _make_registry(ctx) + + result = run_tool_loop( + client=client, + messages=[{"role": "user", "content": "go"}], + registry=reg, + model_role=ModelRole.EXTRACTION_AGENT, + ctx=ctx, + finish_tool_name="finish", + multi_stage_schema=MultiStagePlan, + ) + + # Every recorded turn should carry the handler's actual return value. + emit_turns = [t for t in result.trace.turns if t.tool_name == "emit"] + assert [t.result for t in emit_turns] == [ + {"ok": True, "echo": "a"}, + {"ok": True, "echo": "b"}, + ] + finish_turn = next(t for t in result.trace.turns if t.tool_name == "finish") + assert finish_turn.result == {"finished": True} + + +def test_multi_stage_loop_takes_priority_over_single_shot_fallback( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """When both ``multi_stage_schema`` and ``fallback_schema`` are passed, multi-stage wins.""" + monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key") + monkeypatch.delenv("CLAUDE_SMART_USE_LOCAL_CLI", raising=False) + _force_no_tool_calling(monkeypatch) + + plans = [MultiStagePlan(reasoning="r", next_call=_CallFinish())] + client = _scripted_client(monkeypatch, plans) + ctx = _Ctx() + reg = _make_registry(ctx) + + class _SingleShotSchema(BaseModel): + items: list[_EmitArgs] = [] + + result = run_tool_loop( + client=client, + messages=[{"role": "user", "content": "go"}], + registry=reg, + model_role=ModelRole.EXTRACTION_AGENT, + ctx=ctx, + finish_tool_name="finish", + fallback_schema=_SingleShotSchema, + fallback_tool_name="emit", + multi_stage_schema=MultiStagePlan, + ) + + # Single-shot would have produced 0 emits with empty list and returned + # finished_reason='finish_tool' too — but we can prove multi-stage ran + # by checking the recorded tool_name: single-shot would record "emit", + # multi-stage records "finish". + assert [t.tool_name for t in result.trace.turns] == ["finish"] + assert ctx.finished is True + + +def test_multi_stage_loop_logs_per_turn_when_label_provided( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """``log_label='X'`` should produce one prompt+response log per turn with multi-stage suffix.""" + from unittest.mock import patch + + monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key") + monkeypatch.delenv("CLAUDE_SMART_USE_LOCAL_CLI", raising=False) + _force_no_tool_calling(monkeypatch) + + plans = [ + MultiStagePlan(reasoning="r1", next_call=_CallEmit(value="x")), + MultiStagePlan(reasoning="r2", next_call=_CallFinish()), + ] + client = _scripted_client(monkeypatch, plans) + ctx = _Ctx() + reg = _make_registry(ctx) + + with ( + patch( + "reflexio.server.services.service_utils.log_llm_messages" + ) as mock_log_msgs, + patch( + "reflexio.server.services.service_utils.log_model_response" + ) as mock_log_resp, + ): + run_tool_loop( + client=client, + messages=[{"role": "user", "content": "go"}], + registry=reg, + model_role=ModelRole.EXTRACTION_AGENT, + ctx=ctx, + finish_tool_name="finish", + multi_stage_schema=MultiStagePlan, + log_label="search_agent", + ) + + assert mock_log_msgs.call_count == 2 + assert mock_log_resp.call_count == 2 + msg_labels = [c.args[1] for c in mock_log_msgs.call_args_list] + assert msg_labels == [ + "search_agent (multi-stage turn 1)", + "search_agent (multi-stage turn 2)", + ] diff --git a/tests/server/services/__init__.py b/tests/server/services/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/server/services/extraction/__init__.py b/tests/server/services/extraction/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/server/services/extraction/eval_fixtures/group1_mutation/delete_explicit_forget.json b/tests/server/services/extraction/eval_fixtures/group1_mutation/delete_explicit_forget.json new file mode 100644 index 00000000..2d6e474e --- /dev/null +++ b/tests/server/services/extraction/eval_fixtures/group1_mutation/delete_explicit_forget.json @@ -0,0 +1,13 @@ +{ + "id": "delete_explicit_forget", + "group": "group1_mutation", + "category": "delete", + "existing_storage": [ + {"type": "profile", "id": "p_300", "content": "user has a sister named Sarah", "ttl": "infinity"} + ], + "session": "User: please forget I mentioned my sister Sarah.", + "expected_plan": [ + {"op": "delete_user_profile", "id": "p_300"} + ], + "expected_reasoning_contains": ["forget", "remove", "requested"] +} diff --git a/tests/server/services/extraction/eval_fixtures/group1_mutation/delete_obsolete_fact.json b/tests/server/services/extraction/eval_fixtures/group1_mutation/delete_obsolete_fact.json new file mode 100644 index 00000000..ab51d516 --- /dev/null +++ b/tests/server/services/extraction/eval_fixtures/group1_mutation/delete_obsolete_fact.json @@ -0,0 +1,13 @@ +{ + "id": "delete_obsolete_fact", + "group": "group1_mutation", + "category": "delete", + "existing_storage": [ + {"type": "profile", "id": "p_301", "content": "user has a golden retriever named Biscuit", "ttl": "infinity"} + ], + "session": "User: I don't have a dog anymore, she passed away last month. Not up to getting another.", + "expected_plan": [ + {"op": "delete_user_profile", "id": "p_301"} + ], + "expected_reasoning_contains": ["obsolete", "no longer", "remove"] +} diff --git a/tests/server/services/extraction/eval_fixtures/group1_mutation/delete_privacy_wipe.json b/tests/server/services/extraction/eval_fixtures/group1_mutation/delete_privacy_wipe.json new file mode 100644 index 00000000..667fca13 --- /dev/null +++ b/tests/server/services/extraction/eval_fixtures/group1_mutation/delete_privacy_wipe.json @@ -0,0 +1,13 @@ +{ + "id": "delete_privacy_wipe", + "group": "group1_mutation", + "category": "delete", + "existing_storage": [ + {"type": "profile", "id": "p_302", "content": "user lives at 123 Maple Street, Springfield", "ttl": "infinity"} + ], + "session": "User: remove my home address from your memory, I don't want that stored.", + "expected_plan": [ + {"op": "delete_user_profile", "id": "p_302"} + ], + "expected_reasoning_contains": ["privacy", "remove", "requested"] +} diff --git a/tests/server/services/extraction/eval_fixtures/group1_mutation/merge_dup_profiles.json b/tests/server/services/extraction/eval_fixtures/group1_mutation/merge_dup_profiles.json new file mode 100644 index 00000000..7a1f8bcb --- /dev/null +++ b/tests/server/services/extraction/eval_fixtures/group1_mutation/merge_dup_profiles.json @@ -0,0 +1,18 @@ +{ + "id": "merge_dup_profiles", + "group": "group1_mutation", + "category": "merge", + "existing_storage": [ + {"type": "profile", "id": "p_200", "content": "user is vegetarian", "ttl": "infinity"}, + {"type": "profile", "id": "p_201", "content": "user follows a vegetarian diet", "ttl": "infinity"} + ], + "session": "User: just confirming I'm vegetarian.", + "expected_plan": [ + {"op": "delete_user_profile", "id": "p_200"}, + {"op": "delete_user_profile", "id": "p_201"}, + {"op": "create_user_profile", + "content_contains": ["vegetarian"], + "ttl": "infinity"} + ], + "expected_reasoning_contains": ["duplicate", "merge", "consolidate"] +} diff --git a/tests/server/services/extraction/eval_fixtures/group1_mutation/merge_multi_profiles.json b/tests/server/services/extraction/eval_fixtures/group1_mutation/merge_multi_profiles.json new file mode 100644 index 00000000..63ea2a18 --- /dev/null +++ b/tests/server/services/extraction/eval_fixtures/group1_mutation/merge_multi_profiles.json @@ -0,0 +1,20 @@ +{ + "id": "merge_multi_profiles", + "group": "group1_mutation", + "category": "merge", + "existing_storage": [ + {"type": "profile", "id": "p_210", "content": "user works at Acme", "ttl": "infinity"}, + {"type": "profile", "id": "p_211", "content": "user is employed at Acme Corp", "ttl": "infinity"}, + {"type": "profile", "id": "p_212", "content": "user's employer is Acme", "ttl": "infinity"} + ], + "session": "User: I work at Acme as a data scientist.", + "expected_plan": [ + {"op": "delete_user_profile", "id": "p_210"}, + {"op": "delete_user_profile", "id": "p_211"}, + {"op": "delete_user_profile", "id": "p_212"}, + {"op": "create_user_profile", + "content_contains": ["Acme", "data scientist"], + "ttl": "infinity"} + ], + "expected_reasoning_contains": ["duplicate", "merge", "unified"] +} diff --git a/tests/server/services/extraction/eval_fixtures/group1_mutation/merge_same_fact_rephrased.json b/tests/server/services/extraction/eval_fixtures/group1_mutation/merge_same_fact_rephrased.json new file mode 100644 index 00000000..3e0ebfcf --- /dev/null +++ b/tests/server/services/extraction/eval_fixtures/group1_mutation/merge_same_fact_rephrased.json @@ -0,0 +1,18 @@ +{ + "id": "merge_same_fact_rephrased", + "group": "group1_mutation", + "category": "merge", + "existing_storage": [ + {"type": "profile", "id": "p_220", "content": "user prefers Python", "ttl": "infinity"}, + {"type": "profile", "id": "p_221", "content": "user likes to code in Python", "ttl": "infinity"} + ], + "session": "User: I just prefer Python, let's say that.", + "expected_plan": [ + {"op": "delete_user_profile", "id": "p_220"}, + {"op": "delete_user_profile", "id": "p_221"}, + {"op": "create_user_profile", + "content_contains": ["Python"], + "ttl": "infinity"} + ], + "expected_reasoning_contains": ["duplicate", "merge", "same fact"] +} diff --git a/tests/server/services/extraction/eval_fixtures/group1_mutation/playbook_add_rationale.json b/tests/server/services/extraction/eval_fixtures/group1_mutation/playbook_add_rationale.json new file mode 100644 index 00000000..913c11fb --- /dev/null +++ b/tests/server/services/extraction/eval_fixtures/group1_mutation/playbook_add_rationale.json @@ -0,0 +1,20 @@ +{ + "id": "playbook_add_rationale", + "group": "group1_mutation", + "category": "playbook_expansion", + "existing_storage": [ + {"type": "user_playbook", "id": "pb_11", + "trigger": "user asks for code review", + "content": "be concrete, give actionable suggestions", + "rationale": ""} + ], + "session": "User: by the way, when you review my code, please always explain WHY a change is better — not just what to change.", + "expected_plan": [ + {"op": "delete_user_playbook", "id": "pb_11"}, + {"op": "create_user_playbook", + "trigger_contains": ["code review"], + "content_contains": ["concrete", "actionable", "explain", "why"], + "content_preserves_all": ["be concrete, give actionable suggestions"]} + ], + "expected_reasoning_contains": ["extend", "augment", "additional instruction"] +} diff --git a/tests/server/services/extraction/eval_fixtures/group1_mutation/playbook_add_rule.json b/tests/server/services/extraction/eval_fixtures/group1_mutation/playbook_add_rule.json new file mode 100644 index 00000000..eb38f22f --- /dev/null +++ b/tests/server/services/extraction/eval_fixtures/group1_mutation/playbook_add_rule.json @@ -0,0 +1,20 @@ +{ + "id": "playbook_add_rule", + "group": "group1_mutation", + "category": "playbook_expansion", + "existing_storage": [ + {"type": "user_playbook", "id": "pb_10", + "trigger": "user asks for code help", + "content": "show code examples with comments", + "rationale": ""} + ], + "session": "User: also, when I ask for code help, prefer TypeScript over JavaScript.", + "expected_plan": [ + {"op": "delete_user_playbook", "id": "pb_10"}, + {"op": "create_user_playbook", + "trigger_contains": ["code help"], + "content_contains": ["examples", "comments", "TypeScript"], + "content_preserves_all": ["show code examples with comments"]} + ], + "expected_reasoning_contains": ["extend", "augment", "add rule"] +} diff --git a/tests/server/services/extraction/eval_fixtures/group1_mutation/playbook_extend_trigger_scope.json b/tests/server/services/extraction/eval_fixtures/group1_mutation/playbook_extend_trigger_scope.json new file mode 100644 index 00000000..944695d0 --- /dev/null +++ b/tests/server/services/extraction/eval_fixtures/group1_mutation/playbook_extend_trigger_scope.json @@ -0,0 +1,20 @@ +{ + "id": "playbook_extend_trigger_scope", + "group": "group1_mutation", + "category": "playbook_expansion", + "existing_storage": [ + {"type": "user_playbook", "id": "pb_12", + "trigger": "user asks about SQL queries", + "content": "prefer CTEs over subqueries; use explicit joins", + "rationale": ""} + ], + "session": "User: same as SQL queries — for any database work, prefer CTEs over subqueries. And make sure to use explicit joins. This goes for Mongo and DuckDB too.", + "expected_plan": [ + {"op": "delete_user_playbook", "id": "pb_12"}, + {"op": "create_user_playbook", + "trigger_contains": ["database"], + "content_contains": ["CTEs", "explicit joins"], + "content_preserves_all": ["prefer CTEs over subqueries", "use explicit joins"]} + ], + "expected_reasoning_contains": ["extend", "broader scope", "expand"] +} diff --git a/tests/server/services/extraction/eval_fixtures/group1_mutation/supersede_food_preference.json b/tests/server/services/extraction/eval_fixtures/group1_mutation/supersede_food_preference.json new file mode 100644 index 00000000..4943dd64 --- /dev/null +++ b/tests/server/services/extraction/eval_fixtures/group1_mutation/supersede_food_preference.json @@ -0,0 +1,16 @@ +{ + "id": "supersede_food_preference", + "group": "group1_mutation", + "category": "supersede", + "existing_storage": [ + {"type": "profile", "id": "p_100", "content": "user likes Chinese food", "ttl": "infinity"} + ], + "session": "User: I've gone off Chinese food, can't stand it anymore.", + "expected_plan": [ + {"op": "delete_user_profile", "id": "p_100"}, + {"op": "create_user_profile", + "content_contains": ["Chinese", "dislike"], + "ttl": "infinity"} + ], + "expected_reasoning_contains": ["supersede", "no longer", "changed"] +} diff --git a/tests/server/services/extraction/eval_fixtures/group1_mutation/supersede_job_role.json b/tests/server/services/extraction/eval_fixtures/group1_mutation/supersede_job_role.json new file mode 100644 index 00000000..fceda207 --- /dev/null +++ b/tests/server/services/extraction/eval_fixtures/group1_mutation/supersede_job_role.json @@ -0,0 +1,16 @@ +{ + "id": "supersede_job_role", + "group": "group1_mutation", + "category": "supersede", + "existing_storage": [ + {"type": "profile", "id": "p_101", "content": "user is a software engineer at Acme", "ttl": "infinity"} + ], + "session": "User: I got promoted to staff engineer last week.", + "expected_plan": [ + {"op": "delete_user_profile", "id": "p_101"}, + {"op": "create_user_profile", + "content_contains": ["staff engineer", "Acme"], + "ttl": "infinity"} + ], + "expected_reasoning_contains": ["supersede", "promoted", "updated"] +} diff --git a/tests/server/services/extraction/eval_fixtures/group1_mutation/supersede_location.json b/tests/server/services/extraction/eval_fixtures/group1_mutation/supersede_location.json new file mode 100644 index 00000000..ace1446d --- /dev/null +++ b/tests/server/services/extraction/eval_fixtures/group1_mutation/supersede_location.json @@ -0,0 +1,16 @@ +{ + "id": "supersede_location", + "group": "group1_mutation", + "category": "supersede", + "existing_storage": [ + {"type": "profile", "id": "p_102", "content": "user lives in Austin, TX", "ttl": "infinity"} + ], + "session": "User: Just moved to Portland, OR.", + "expected_plan": [ + {"op": "delete_user_profile", "id": "p_102"}, + {"op": "create_user_profile", + "content_contains": ["Portland"], + "ttl": "infinity"} + ], + "expected_reasoning_contains": ["supersede", "moved", "replaced"] +} diff --git a/tests/server/services/extraction/eval_fixtures/group2_supermemory/agent_playbook_code_review_style.json b/tests/server/services/extraction/eval_fixtures/group2_supermemory/agent_playbook_code_review_style.json new file mode 100644 index 00000000..26df7b2a --- /dev/null +++ b/tests/server/services/extraction/eval_fixtures/group2_supermemory/agent_playbook_code_review_style.json @@ -0,0 +1,14 @@ +{ + "id": "agent_playbook_code_review_style", + "group": "group2_supermemory", + "category": "agent_playbook_fallback", + "existing_storage": [ + {"type": "agent_playbook", "id": "ab_601", + "trigger": "agent is reviewing code", + "content": "look for missing test coverage; flag any new public API without docstrings", + "playbook_name": "default_agent_playbook"} + ], + "query": "How should I review this user's PR?", + "expected_answer_contains": ["coverage", "docstring"], + "expected_answer_excludes": ["no evidence"] +} diff --git a/tests/server/services/extraction/eval_fixtures/group2_supermemory/agent_playbook_debugging_approach.json b/tests/server/services/extraction/eval_fixtures/group2_supermemory/agent_playbook_debugging_approach.json new file mode 100644 index 00000000..89c6ad3f --- /dev/null +++ b/tests/server/services/extraction/eval_fixtures/group2_supermemory/agent_playbook_debugging_approach.json @@ -0,0 +1,14 @@ +{ + "id": "agent_playbook_debugging_approach", + "group": "group2_supermemory", + "category": "agent_playbook_fallback", + "existing_storage": [ + {"type": "agent_playbook", "id": "ab_600", + "trigger": "agent is debugging a failing test", + "content": "start with the most recent diff; check the test's actual assertions before guessing at the code", + "playbook_name": "default_agent_playbook"} + ], + "query": "How should I approach debugging a test failure?", + "expected_answer_contains": ["recent diff", "assertions"], + "expected_answer_excludes": ["no evidence", "no memory"] +} diff --git a/tests/server/services/extraction/eval_fixtures/group2_supermemory/agent_playbook_pair_with_user_pref.json b/tests/server/services/extraction/eval_fixtures/group2_supermemory/agent_playbook_pair_with_user_pref.json new file mode 100644 index 00000000..752fcda7 --- /dev/null +++ b/tests/server/services/extraction/eval_fixtures/group2_supermemory/agent_playbook_pair_with_user_pref.json @@ -0,0 +1,15 @@ +{ + "id": "agent_playbook_pair_with_user_pref", + "group": "group2_supermemory", + "category": "agent_playbook_fallback", + "existing_storage": [ + {"type": "profile", "id": "p_602", "content": "user is learning Rust", "ttl": "infinity"}, + {"type": "agent_playbook", "id": "ab_602", + "trigger": "user is learning a language", + "content": "give minimal examples first; offer idioms only after the user gets comfortable with syntax", + "playbook_name": "default_agent_playbook"} + ], + "query": "How should I help this user with Rust practice?", + "expected_answer_contains": ["minimal examples", "Rust"], + "expected_answer_excludes": [] +} diff --git a/tests/server/services/extraction/eval_fixtures/group2_supermemory/constraint_broad_code_style.json b/tests/server/services/extraction/eval_fixtures/group2_supermemory/constraint_broad_code_style.json new file mode 100644 index 00000000..adf51be4 --- /dev/null +++ b/tests/server/services/extraction/eval_fixtures/group2_supermemory/constraint_broad_code_style.json @@ -0,0 +1,13 @@ +{ + "id": "constraint_broad_code_style", + "group": "group2_supermemory", + "category": "apply_nonmatching_constraint", + "existing_storage": [], + "session": "User: keep code examples short — under 40 lines, no exceptions. I'll ask for more if I need it.", + "expected_plan": [ + {"op": "create_user_playbook", + "trigger_contains": ["code", "example"], + "content_contains": ["40 lines", "short"]} + ], + "expected_reasoning_contains": ["preference", "rule"] +} diff --git a/tests/server/services/extraction/eval_fixtures/group2_supermemory/constraint_broad_language_pref.json b/tests/server/services/extraction/eval_fixtures/group2_supermemory/constraint_broad_language_pref.json new file mode 100644 index 00000000..338989bd --- /dev/null +++ b/tests/server/services/extraction/eval_fixtures/group2_supermemory/constraint_broad_language_pref.json @@ -0,0 +1,13 @@ +{ + "id": "constraint_broad_language_pref", + "group": "group2_supermemory", + "category": "apply_nonmatching_constraint", + "existing_storage": [], + "session": "User: please give me recipe suggestions in metric units — I can't eyeball cups and ounces.", + "expected_plan": [ + {"op": "create_user_playbook", + "trigger_contains": ["recipe"], + "content_contains": ["metric"]} + ], + "expected_reasoning_contains": ["preference", "rule"] +} diff --git a/tests/server/services/extraction/eval_fixtures/group2_supermemory/constraint_broad_time_pref.json b/tests/server/services/extraction/eval_fixtures/group2_supermemory/constraint_broad_time_pref.json new file mode 100644 index 00000000..baf34b2b --- /dev/null +++ b/tests/server/services/extraction/eval_fixtures/group2_supermemory/constraint_broad_time_pref.json @@ -0,0 +1,13 @@ +{ + "id": "constraint_broad_time_pref", + "group": "group2_supermemory", + "category": "apply_nonmatching_constraint", + "existing_storage": [], + "session": "User: schedule any follow-up meetings for after 3pm PT — mornings are no-go.", + "expected_plan": [ + {"op": "create_user_playbook", + "trigger_contains": ["meeting", "schedule"], + "content_contains": ["after 3pm", "PT"]} + ], + "expected_reasoning_contains": ["preference", "rule"] +} diff --git a/tests/server/services/extraction/eval_fixtures/group2_supermemory/contradiction_direct_negation.json b/tests/server/services/extraction/eval_fixtures/group2_supermemory/contradiction_direct_negation.json new file mode 100644 index 00000000..d0e93f2a --- /dev/null +++ b/tests/server/services/extraction/eval_fixtures/group2_supermemory/contradiction_direct_negation.json @@ -0,0 +1,16 @@ +{ + "id": "contradiction_direct_negation", + "group": "group2_supermemory", + "category": "contradiction_resolution", + "existing_storage": [ + {"type": "profile", "id": "p_410", "content": "user is a vegetarian", "ttl": "infinity"} + ], + "session": "User: correction — I'm not vegetarian, never have been. Not sure where you got that.", + "expected_plan": [ + {"op": "delete_user_profile", "id": "p_410"}, + {"op": "create_user_profile", + "content_contains": ["not", "vegetarian"], + "ttl": "infinity"} + ], + "expected_reasoning_contains": ["contradict", "correct"] +} diff --git a/tests/server/services/extraction/eval_fixtures/group2_supermemory/contradiction_wrong_employer.json b/tests/server/services/extraction/eval_fixtures/group2_supermemory/contradiction_wrong_employer.json new file mode 100644 index 00000000..281dee72 --- /dev/null +++ b/tests/server/services/extraction/eval_fixtures/group2_supermemory/contradiction_wrong_employer.json @@ -0,0 +1,16 @@ +{ + "id": "contradiction_wrong_employer", + "group": "group2_supermemory", + "category": "contradiction_resolution", + "existing_storage": [ + {"type": "profile", "id": "p_411", "content": "user works at Google", "ttl": "infinity"} + ], + "session": "User: I don't work at Google — I work at Meta. Got them mixed up earlier.", + "expected_plan": [ + {"op": "delete_user_profile", "id": "p_411"}, + {"op": "create_user_profile", + "content_contains": ["Meta"], + "ttl": "infinity"} + ], + "expected_reasoning_contains": ["contradict", "correct"] +} diff --git a/tests/server/services/extraction/eval_fixtures/group2_supermemory/contradiction_wrong_language.json b/tests/server/services/extraction/eval_fixtures/group2_supermemory/contradiction_wrong_language.json new file mode 100644 index 00000000..a005c020 --- /dev/null +++ b/tests/server/services/extraction/eval_fixtures/group2_supermemory/contradiction_wrong_language.json @@ -0,0 +1,16 @@ +{ + "id": "contradiction_wrong_language", + "group": "group2_supermemory", + "category": "contradiction_resolution", + "existing_storage": [ + {"type": "profile", "id": "p_412", "content": "user's primary language is English", "ttl": "infinity"} + ], + "session": "User: actually my first language is Spanish, English is my second.", + "expected_plan": [ + {"op": "delete_user_profile", "id": "p_412"}, + {"op": "create_user_profile", + "content_contains": ["Spanish"], + "ttl": "infinity"} + ], + "expected_reasoning_contains": ["contradict", "correct"] +} diff --git a/tests/server/services/extraction/eval_fixtures/group2_supermemory/empty_result_no_memory_yet.json b/tests/server/services/extraction/eval_fixtures/group2_supermemory/empty_result_no_memory_yet.json new file mode 100644 index 00000000..c7d78705 --- /dev/null +++ b/tests/server/services/extraction/eval_fixtures/group2_supermemory/empty_result_no_memory_yet.json @@ -0,0 +1,9 @@ +{ + "id": "empty_result_no_memory_yet", + "group": "group2_supermemory", + "category": "empty_result_no_confab", + "existing_storage": [], + "query": "What's the user's favorite color?", + "expected_answer_contains": ["no evidence", "no memory", "don't have"], + "expected_answer_excludes": ["blue", "red", "green", "favorite"] +} diff --git a/tests/server/services/extraction/eval_fixtures/group2_supermemory/empty_result_specific_but_absent.json b/tests/server/services/extraction/eval_fixtures/group2_supermemory/empty_result_specific_but_absent.json new file mode 100644 index 00000000..7f7852e2 --- /dev/null +++ b/tests/server/services/extraction/eval_fixtures/group2_supermemory/empty_result_specific_but_absent.json @@ -0,0 +1,11 @@ +{ + "id": "empty_result_specific_but_absent", + "group": "group2_supermemory", + "category": "empty_result_no_confab", + "existing_storage": [ + {"type": "profile", "id": "p_521", "content": "user has a cat named Whiskers", "ttl": "infinity"} + ], + "query": "What is the user's dog's name?", + "expected_answer_contains": ["no", "dog"], + "expected_answer_excludes": ["Whiskers"] +} diff --git a/tests/server/services/extraction/eval_fixtures/group2_supermemory/empty_result_unrelated_memory.json b/tests/server/services/extraction/eval_fixtures/group2_supermemory/empty_result_unrelated_memory.json new file mode 100644 index 00000000..9ac435ce --- /dev/null +++ b/tests/server/services/extraction/eval_fixtures/group2_supermemory/empty_result_unrelated_memory.json @@ -0,0 +1,11 @@ +{ + "id": "empty_result_unrelated_memory", + "group": "group2_supermemory", + "category": "empty_result_no_confab", + "existing_storage": [ + {"type": "profile", "id": "p_520", "content": "user is a tax accountant", "ttl": "infinity"} + ], + "query": "What's the user's home address?", + "expected_answer_contains": ["no evidence", "don't have", "no information"], + "expected_answer_excludes": ["123", "street", "avenue", "road"] +} diff --git a/tests/server/services/extraction/eval_fixtures/group2_supermemory/multi_hop_location_and_pref.json b/tests/server/services/extraction/eval_fixtures/group2_supermemory/multi_hop_location_and_pref.json new file mode 100644 index 00000000..da2b0805 --- /dev/null +++ b/tests/server/services/extraction/eval_fixtures/group2_supermemory/multi_hop_location_and_pref.json @@ -0,0 +1,15 @@ +{ + "id": "multi_hop_location_and_pref", + "group": "group2_supermemory", + "category": "multi_hop_search", + "existing_storage": [ + {"type": "profile", "id": "p_502", "content": "user lives in San Francisco", "ttl": "infinity"}, + {"type": "user_playbook", "id": "pb_502", + "trigger": "user asks for restaurant recommendations", + "content": "prefers walkable / no driving", + "rationale": ""} + ], + "query": "Can you recommend dinner spots for the user tonight?", + "expected_answer_contains": ["walk", "San Francisco"], + "expected_answer_excludes": ["drive", "driving"] +} diff --git a/tests/server/services/extraction/eval_fixtures/group2_supermemory/multi_hop_mixed_memory.json b/tests/server/services/extraction/eval_fixtures/group2_supermemory/multi_hop_mixed_memory.json new file mode 100644 index 00000000..9cd5dd5a --- /dev/null +++ b/tests/server/services/extraction/eval_fixtures/group2_supermemory/multi_hop_mixed_memory.json @@ -0,0 +1,15 @@ +{ + "id": "multi_hop_mixed_memory", + "group": "group2_supermemory", + "category": "multi_hop_search", + "existing_storage": [ + {"type": "profile", "id": "p_501", "content": "user is a JS/TS developer", "ttl": "infinity"}, + {"type": "user_playbook", "id": "pb_501", + "trigger": "user asks for code review", + "content": "prioritize type-safety issues", + "rationale": ""} + ], + "query": "What should I focus on when reviewing this user's code?", + "expected_answer_contains": ["type", "safety"], + "expected_answer_excludes": [] +} diff --git a/tests/server/services/extraction/eval_fixtures/group2_supermemory/multi_hop_tooling_preference.json b/tests/server/services/extraction/eval_fixtures/group2_supermemory/multi_hop_tooling_preference.json new file mode 100644 index 00000000..d04dd622 --- /dev/null +++ b/tests/server/services/extraction/eval_fixtures/group2_supermemory/multi_hop_tooling_preference.json @@ -0,0 +1,15 @@ +{ + "id": "multi_hop_tooling_preference", + "group": "group2_supermemory", + "category": "multi_hop_search", + "existing_storage": [ + {"type": "profile", "id": "p_500", "content": "user works in Python data science", "ttl": "infinity"}, + {"type": "user_playbook", "id": "pb_500", + "trigger": "user asks for library suggestions", + "content": "prefer polars and duckdb over pandas and sqlite", + "rationale": ""} + ], + "query": "What database tool should I suggest for the user's next analysis task?", + "expected_answer_contains": ["duckdb"], + "expected_answer_excludes": ["sqlite"] +} diff --git a/tests/server/services/extraction/eval_fixtures/group2_supermemory/temporal_supersede_moved_project.json b/tests/server/services/extraction/eval_fixtures/group2_supermemory/temporal_supersede_moved_project.json new file mode 100644 index 00000000..f1669eb2 --- /dev/null +++ b/tests/server/services/extraction/eval_fixtures/group2_supermemory/temporal_supersede_moved_project.json @@ -0,0 +1,16 @@ +{ + "id": "temporal_supersede_moved_project", + "group": "group2_supermemory", + "category": "temporal_supersede", + "existing_storage": [ + {"type": "profile", "id": "p_401", "content": "user works on project Alpha", "ttl": "infinity"} + ], + "session": "User: Alpha shipped last month, now I'm leading Beta full-time.", + "expected_plan": [ + {"op": "delete_user_profile", "id": "p_401"}, + {"op": "create_user_profile", + "content_contains": ["Beta"], + "ttl": "infinity"} + ], + "expected_reasoning_contains": ["supersede", "now", "current"] +} diff --git a/tests/server/services/extraction/eval_fixtures/group2_supermemory/temporal_supersede_role_change.json b/tests/server/services/extraction/eval_fixtures/group2_supermemory/temporal_supersede_role_change.json new file mode 100644 index 00000000..0d14b8e7 --- /dev/null +++ b/tests/server/services/extraction/eval_fixtures/group2_supermemory/temporal_supersede_role_change.json @@ -0,0 +1,16 @@ +{ + "id": "temporal_supersede_role_change", + "group": "group2_supermemory", + "category": "temporal_supersede", + "existing_storage": [ + {"type": "profile", "id": "p_402", "content": "user is a senior SWE at Acme", "ttl": "infinity"} + ], + "session": "User: I'm a principal engineer now, title changed last quarter.", + "expected_plan": [ + {"op": "delete_user_profile", "id": "p_402"}, + {"op": "create_user_profile", + "content_contains": ["principal"], + "ttl": "infinity"} + ], + "expected_reasoning_contains": ["supersede", "now", "changed"] +} diff --git a/tests/server/services/extraction/eval_fixtures/group2_supermemory/temporal_supersede_switch_tool.json b/tests/server/services/extraction/eval_fixtures/group2_supermemory/temporal_supersede_switch_tool.json new file mode 100644 index 00000000..ff386fda --- /dev/null +++ b/tests/server/services/extraction/eval_fixtures/group2_supermemory/temporal_supersede_switch_tool.json @@ -0,0 +1,16 @@ +{ + "id": "temporal_supersede_switch_tool", + "group": "group2_supermemory", + "category": "temporal_supersede", + "existing_storage": [ + {"type": "profile", "id": "p_400", "content": "user uses pandas for data work", "ttl": "infinity"} + ], + "session": "User: these days I'm all-in on polars — pandas is slow on my data sizes now.", + "expected_plan": [ + {"op": "delete_user_profile", "id": "p_400"}, + {"op": "create_user_profile", + "content_contains": ["polars"], + "ttl": "infinity"} + ], + "expected_reasoning_contains": ["supersede", "now", "switched"] +} diff --git a/tests/server/services/extraction/eval_fixtures/group3_loop_behavior/almost_done.json b/tests/server/services/extraction/eval_fixtures/group3_loop_behavior/almost_done.json new file mode 100644 index 00000000..f4c33bed --- /dev/null +++ b/tests/server/services/extraction/eval_fixtures/group3_loop_behavior/almost_done.json @@ -0,0 +1,72 @@ +{ + "id": "almost_done", + "group": "group3_loop_behavior", + "category": "almost_done", + "existing_storage": [], + "session": "User: placeholder", + "mock_llm_responses": [ + { + "tool_calls": [ + {"id": "s1", "name": "search_user_profiles", "args": {"query": "food preferences", "top_k": 10}} + ] + }, + { + "tool_calls": [ + {"id": "c1", "name": "create_user_profile", "args": {"content": "user likes sushi", "ttl": "infinity", "source_span": "I like sushi"}} + ] + }, + { + "tool_calls": [ + {"id": "s2", "name": "search_user_profiles", "args": {"query": "work hours schedule", "top_k": 10}} + ] + }, + { + "tool_calls": [ + {"id": "c2", "name": "create_user_profile", "args": {"content": "user works 9am to 5pm PT", "ttl": "infinity", "source_span": "9am to 5pm"}} + ] + }, + { + "tool_calls": [ + {"id": "s3", "name": "search_user_profiles", "args": {"query": "location city", "top_k": 10}} + ] + }, + { + "tool_calls": [ + {"id": "c3", "name": "create_user_profile", "args": {"content": "user lives in Seattle", "ttl": "infinity", "source_span": "Seattle"}} + ] + }, + { + "tool_calls": [ + {"id": "s4", "name": "search_user_profiles", "args": {"query": "hobbies interests", "top_k": 10}} + ] + }, + { + "tool_calls": [ + {"id": "c4", "name": "create_user_profile", "args": {"content": "user enjoys hiking on weekends", "ttl": "infinity", "source_span": "hiking on weekends"}} + ] + }, + { + "tool_calls": [ + {"id": "s5", "name": "search_user_profiles", "args": {"query": "programming language preference", "top_k": 10}} + ] + }, + { + "tool_calls": [ + {"id": "c5", "name": "create_user_profile", "args": {"content": "user prefers Python for scripting", "ttl": "infinity", "source_span": "Python for scripting"}} + ] + }, + { + "tool_calls": [ + {"id": "s6", "name": "search_user_profiles", "args": {"query": "job role team", "top_k": 10}} + ] + }, + { + "tool_calls": [ + {"id": "c6", "name": "create_user_profile", "args": {"content": "user is a backend engineer on the platform team", "ttl": "infinity", "source_span": "backend engineer on the platform team"}} + ] + } + ], + "expected_outcome": "max_steps", + "expected_applied_count": 6, + "expected_violations": [] +} diff --git a/tests/server/services/extraction/eval_fixtures/group3_loop_behavior/confused_garbage.json b/tests/server/services/extraction/eval_fixtures/group3_loop_behavior/confused_garbage.json new file mode 100644 index 00000000..1285cd39 --- /dev/null +++ b/tests/server/services/extraction/eval_fixtures/group3_loop_behavior/confused_garbage.json @@ -0,0 +1,27 @@ +{ + "id": "confused_garbage", + "group": "group3_loop_behavior", + "category": "confused_garbage", + "existing_storage": [], + "session": "User: placeholder", + "mock_llm_responses": [ + { + "tool_calls": [ + {"id": "c1", "name": "delete_user_profile", "args": {"id": "p_999"}} + ] + }, + { + "tool_calls": [ + {"id": "c2", "name": "create_user_profile", "args": {"content": "x", "ttl": "infinity", "source_span": "y"}} + ] + }, + { + "tool_calls": [ + {"id": "c3", "name": "finish", "args": {}} + ] + } + ], + "expected_violations": ["A", "B"], + "expected_applied_count": 0, + "expected_outcome": "finish_tool" +} diff --git a/tests/server/services/extraction/eval_fixtures/group3_loop_behavior/oscillated_self_correction.json b/tests/server/services/extraction/eval_fixtures/group3_loop_behavior/oscillated_self_correction.json new file mode 100644 index 00000000..65baeb51 --- /dev/null +++ b/tests/server/services/extraction/eval_fixtures/group3_loop_behavior/oscillated_self_correction.json @@ -0,0 +1,35 @@ +{ + "id": "oscillated_self_correction", + "group": "group3_loop_behavior", + "category": "oscillated_self_correction", + "existing_storage": [], + "session": "User: I think I like jazz. Actually wait, it's classical I prefer.", + "mock_llm_responses": [ + {"tool_calls": [ + {"id": "s1", "name": "search_user_profiles", "args": {"query": "music preferences", "top_k": 10}} + ]}, + {"tool_calls": [ + {"id": "c1", "name": "create_user_profile", "args": { + "content": "user likes jazz", + "ttl": "infinity", + "source_span": "I think I like jazz" + }} + ]}, + {"tool_calls": [ + {"id": "d1", "name": "delete_user_profile", "args": {"id": "tentative::profile::0"}} + ]}, + {"tool_calls": [ + {"id": "c2", "name": "create_user_profile", "args": { + "content": "user prefers classical music", + "ttl": "infinity", + "source_span": "I prefer classical" + }} + ]}, + {"tool_calls": [ + {"id": "f1", "name": "finish", "args": {}} + ]} + ], + "expected_outcome": "finish_tool", + "expected_applied_count": 1, + "expected_violations": [] +} diff --git a/tests/server/services/extraction/eval_fixtures/group3_loop_behavior/stuck_in_search.json b/tests/server/services/extraction/eval_fixtures/group3_loop_behavior/stuck_in_search.json new file mode 100644 index 00000000..fbc5d22e --- /dev/null +++ b/tests/server/services/extraction/eval_fixtures/group3_loop_behavior/stuck_in_search.json @@ -0,0 +1,72 @@ +{ + "id": "stuck_in_search", + "group": "group3_loop_behavior", + "category": "stuck_in_search", + "existing_storage": [], + "session": "User: tell me what you know about me", + "mock_llm_responses": [ + { + "tool_calls": [ + {"id": "s1", "name": "search_user_profiles", "args": {"query": "general information about the user", "top_k": 10}} + ] + }, + { + "tool_calls": [ + {"id": "s2", "name": "search_user_profiles", "args": {"query": "user preferences and habits", "top_k": 10}} + ] + }, + { + "tool_calls": [ + {"id": "s3", "name": "search_user_profiles", "args": {"query": "work background and role", "top_k": 10}} + ] + }, + { + "tool_calls": [ + {"id": "s4", "name": "search_user_profiles", "args": {"query": "location and timezone", "top_k": 10}} + ] + }, + { + "tool_calls": [ + {"id": "s5", "name": "search_user_profiles", "args": {"query": "hobbies and personal interests", "top_k": 10}} + ] + }, + { + "tool_calls": [ + {"id": "s6", "name": "search_user_profiles", "args": {"query": "communication style and preferences", "top_k": 10}} + ] + }, + { + "tool_calls": [ + {"id": "s7", "name": "search_user_profiles", "args": {"query": "technical skills and tools", "top_k": 10}} + ] + }, + { + "tool_calls": [ + {"id": "s8", "name": "search_user_profiles", "args": {"query": "dietary restrictions and food choices", "top_k": 10}} + ] + }, + { + "tool_calls": [ + {"id": "s9", "name": "search_user_profiles", "args": {"query": "goals and ambitions", "top_k": 10}} + ] + }, + { + "tool_calls": [ + {"id": "s10", "name": "search_user_profiles", "args": {"query": "recent activities and updates", "top_k": 10}} + ] + }, + { + "tool_calls": [ + {"id": "s11", "name": "search_user_profiles", "args": {"query": "family and social context", "top_k": 10}} + ] + }, + { + "tool_calls": [ + {"id": "s12", "name": "search_user_profiles", "args": {"query": "long-term memory and past events", "top_k": 10}} + ] + } + ], + "expected_outcome": "max_steps", + "expected_applied_count": 0, + "expected_violations": [] +} diff --git a/tests/server/services/extraction/eval_runner.py b/tests/server/services/extraction/eval_runner.py new file mode 100644 index 00000000..f1b919f8 --- /dev/null +++ b/tests/server/services/extraction/eval_runner.py @@ -0,0 +1,337 @@ +"""Hand-crafted eval runner for agentic-v2. See spec §11.""" + +from __future__ import annotations + +import json +import uuid +from pathlib import Path +from typing import Any +from unittest.mock import MagicMock + +_THIS_DIR = Path(__file__).resolve().parent +FIXTURES_ROOT = _THIS_DIR / "eval_fixtures" + + +def load_fixtures(group: str | None = None) -> list[dict[str, Any]]: + """Load all fixture JSONs under eval_fixtures/, optionally scoped to one group. + + Args: + group (str | None): Optional group subdirectory name (e.g. "group1_mutation"). + When None, all fixtures from all groups are returned. + + Returns: + list[dict[str, Any]]: Parsed fixture dicts sorted by path. + """ + root = FIXTURES_ROOT if group is None else FIXTURES_ROOT / group + return [json.loads(p.read_text()) for p in sorted(root.rglob("*.json"))] + + +# --------------------------------------------------------------------------- +# Scoring +# --------------------------------------------------------------------------- + + +def score_plan( + actual: list[dict[str, Any]], expected: list[dict[str, Any]] +) -> dict[str, Any]: + """Score an actual plan against an expected plan spec. + + Supports exact-match fields (``id``, ``ttl``) and fuzzy assertions: + ``content_contains``, ``content_preserves_all``, ``trigger_contains``. + + Args: + actual (list[dict[str, Any]]): Ops produced by the agent. + expected (list[dict[str, Any]]): Spec ops from the fixture's + ``expected_plan`` list. Each entry may contain fuzzy keys instead + of (or alongside) exact-match keys. + + Returns: + dict[str, Any]: ``{"semantic_match": bool, "failures": list[str]}``. + ``semantic_match`` is ``True`` when every expected op is satisfied. + """ + failures: list[str] = [] + + if len(actual) != len(expected): + failures.append( + f"op count mismatch: actual={len(actual)} expected={len(expected)}" + ) + return {"semantic_match": False, "failures": failures} + + semantic = True + for i, (a, e) in enumerate(zip(actual, expected, strict=False)): + if a.get("op") != e.get("op"): + failures.append( + f"op[{i}]: type mismatch — actual={a.get('op')!r} expected={e.get('op')!r}" + ) + semantic = False + continue + + # Exact-match fields + for field in ("id", "ttl"): + if field in e and a.get(field) != e[field]: + failures.append( + f"op[{i}].{field}: actual={a.get(field)!r} expected={e[field]!r}" + ) + semantic = False + + # Fuzzy: content_contains + content_lower = (a.get("content") or "").lower() + for substr in e.get("content_contains", []): + if substr.lower() not in content_lower: + failures.append(f"op[{i}]: content missing substring {substr!r}") + semantic = False + + # Fuzzy: content_preserves_all (lossless merge check) + for preserved in e.get("content_preserves_all", []): + if preserved.lower() not in content_lower: + failures.append(f"op[{i}]: lost preserved content {preserved!r}") + semantic = False + + # Fuzzy: trigger_contains + trigger_lower = (a.get("trigger") or "").lower() + for substr in e.get("trigger_contains", []): + if substr.lower() not in trigger_lower: + failures.append(f"op[{i}]: trigger missing substring {substr!r}") + semantic = False + + return {"semantic_match": semantic, "failures": failures} + + +def score_group3_fixture( + fixture: dict[str, Any], result: dict[str, Any] +) -> dict[str, Any]: + """Score a group3 loop-behavior fixture against the run_fixture result. + + Checks outcome, applied_count, and that expected violation codes are a + subset of observed codes. + + Args: + fixture (dict[str, Any]): The group3 fixture dict. + result (dict[str, Any]): Return value from :func:`run_fixture`. + + Returns: + dict[str, Any]: ``{"pass": bool, "failures": list[str]}``. + """ + failures: list[str] = [] + + expected_outcome = fixture.get("expected_outcome") + if result.get("outcome") != expected_outcome: + failures.append( + f"outcome mismatch: actual={result.get('outcome')!r} expected={expected_outcome!r}" + ) + + expected_count = fixture.get("expected_applied_count") + if result.get("applied_count") != expected_count: + failures.append( + f"applied_count mismatch: actual={result.get('applied_count')} expected={expected_count}" + ) + + expected_violations: set[str] = set(fixture.get("expected_violations", [])) + actual_violations: set[str] = set(result.get("violation_codes", [])) + missing = expected_violations - actual_violations + if missing: + failures.append(f"missing expected violation codes: {sorted(missing)}") + + return {"pass": not failures, "failures": failures} + + +# --------------------------------------------------------------------------- +# Storage seeding +# --------------------------------------------------------------------------- + + +def seed_storage(fixture: dict[str, Any], storage: Any, user_id: str) -> None: + """Write ``fixture["existing_storage"]`` entries into the given storage. + + Translates each entry into the appropriate entity and writes it via the + storage API. Supports ``profile``, ``user_playbook``, and + ``agent_playbook`` entry types. Unknown types are skipped with a warning. + + Args: + fixture (dict[str, Any]): Fixture dict (may contain ``existing_storage``). + storage: A storage instance (e.g. SQLiteStorage). + user_id (str): User ID to assign to profile and user_playbook rows. + """ + from reflexio.models.api_schema.common import NEVER_EXPIRES_TIMESTAMP + from reflexio.models.api_schema.domain.entities import ( + AgentPlaybook, + UserPlaybook, + UserProfile, + ) + from reflexio.models.api_schema.domain.enums import ProfileTimeToLive + + for entry in fixture.get("existing_storage", []): + entry_type = entry.get("type") + + if entry_type == "profile": + ttl_str = entry.get("ttl", "infinity") + try: + ttl = ProfileTimeToLive(ttl_str) + except ValueError: + ttl = ProfileTimeToLive.INFINITY + + profile = UserProfile( + profile_id=entry.get("id", str(uuid.uuid4())), + user_id=user_id, + content=entry.get("content", ""), + last_modified_timestamp=0, + generated_from_request_id="eval_seed", + profile_time_to_live=ttl, + expiration_timestamp=NEVER_EXPIRES_TIMESTAMP, + source_span=entry.get("source_span"), + ) + storage.add_user_profile(user_id, [profile]) + + elif entry_type == "user_playbook": + playbook = UserPlaybook( + user_id=user_id, + agent_version="eval_v1", + request_id="eval_seed", + playbook_name=entry.get("playbook_name", "eval"), + content=entry.get("content", ""), + trigger=entry.get("trigger"), + rationale=entry.get("rationale"), + ) + storage.save_user_playbooks([playbook]) + + elif entry_type == "agent_playbook": + agent_playbook = AgentPlaybook( + agent_version="eval_v1", + playbook_name=entry.get("playbook_name", "eval"), + content=entry.get("content", ""), + trigger=entry.get("trigger"), + rationale=entry.get("rationale"), + ) + storage.save_agent_playbooks([agent_playbook]) + + else: + import logging + + logging.getLogger(__name__).warning( + "seed_storage: unknown entry type %r — skipping", entry_type + ) + + +# --------------------------------------------------------------------------- +# Mocked-LLM response helpers +# --------------------------------------------------------------------------- + + +def _mk_tool_call(id_: str, name: str, args: dict[str, Any]) -> MagicMock: + """Build a MagicMock resembling an LLM tool_call object. + + Args: + id_ (str): Tool call ID string. + name (str): Tool function name. + args (dict[str, Any]): Tool arguments (will be JSON-serialised). + + Returns: + MagicMock: Object with .id, .function.name, .function.arguments. + """ + tc = MagicMock() + tc.id = id_ + tc.function = MagicMock() + tc.function.name = name + tc.function.arguments = json.dumps(args) + return tc + + +def _mk_resp(tool_calls_spec: list[dict[str, Any]]) -> MagicMock: + """Build a MagicMock LLM response containing a list of tool calls. + + Args: + tool_calls_spec (list[dict[str, Any]]): List of ``{"id", "name", "args"}`` + dicts as stored in fixture ``mock_llm_responses[*].tool_calls``. + + Returns: + MagicMock: Fake LLM response with ``.tool_calls`` and ``.content = None``. + """ + r = MagicMock() + r.tool_calls = [ + _mk_tool_call(tc["id"], tc["name"], tc["args"]) for tc in tool_calls_spec + ] + r.content = None + return r + + +# --------------------------------------------------------------------------- +# Main runner +# --------------------------------------------------------------------------- + + +def run_fixture( + fixture: dict[str, Any], + *, + client: Any, + prompt_manager: Any, + storage: Any, + user_id: str = "eval_user", + agent_version: str = "eval_v1", +) -> dict[str, Any]: + """Execute one eval fixture end-to-end. + + For Group 3 (``group3_loop_behavior``), this method scripts the mocked LLM + client from ``fixture["mock_llm_responses"]``, seeds storage, and drives + :class:`ExtractionAgent` to completion. + + For Groups 1 and 2, execution is stubbed — a real LLM or oracle mock is + required to evaluate those fixtures (out of Task 21 scope). + + Args: + fixture (dict[str, Any]): Parsed fixture dict from :func:`load_fixtures`. + client: LiteLLMClient (or MagicMock) — must have + ``generate_chat_response`` that can be scripted via ``side_effect``. + prompt_manager: PromptManager instance. + storage: BaseStorage instance (e.g. SQLiteStorage). + user_id (str): User ID to use when seeding + running. + agent_version (str): Agent version string passed to the agent. + + Returns: + dict[str, Any]: Keys: + - ``actual_plan`` — list of applied op dicts (empty for stub). + - ``outcome`` — ``"finish_tool"``, ``"max_steps"``, or ``"skipped"``. + - ``applied_count`` — number of applied ops. + - ``violation_codes`` — list of invariant code strings. + - ``notes`` (optional) — explanation for stubbed groups. + """ + from reflexio.server.services.extraction.extraction_agent import ExtractionAgent + + seed_storage(fixture, storage, user_id) + group = fixture.get("group", "") + + if group == "group3_loop_behavior": + responses = fixture.get("mock_llm_responses", []) + client.generate_chat_response.side_effect = [ + _mk_resp(r["tool_calls"]) for r in responses + ] + agent = ExtractionAgent( + client=client, + storage=storage, + prompt_manager=prompt_manager, + max_steps=len(responses), + ) + result = agent.run( + user_id=user_id, + agent_version=agent_version, + extractor_name="eval", + extraction_criteria="eval", + sessions_text=fixture.get("session", ""), + ) + return { + "actual_plan": [op.model_dump() for op in result.applied], + "outcome": result.outcome, + "applied_count": len(result.applied), + "violation_codes": [v.code for v in result.violations], + } + + # Group 1 / Group 2 — deferred (requires real LLM or oracle mock) + return { + "actual_plan": [], + "outcome": "skipped", + "applied_count": 0, + "violation_codes": [], + "notes": ( + f"group {group!r} execution requires real LLM or oracle mock" + " (out of Task 21 scope)" + ), + } diff --git a/tests/server/services/extraction/test_agentic_adapter.py b/tests/server/services/extraction/test_agentic_adapter.py new file mode 100644 index 00000000..de1a7e26 --- /dev/null +++ b/tests/server/services/extraction/test_agentic_adapter.py @@ -0,0 +1,777 @@ +"""Tests for the agentic-v2 AgenticExtractionRunner adapter. + +Three required tests (per Task 12 spec): +1. test_agentic_adapter_end_to_end_creates_profile — scripted LLM, real SQLite +2. test_agentic_adapter_triggers_playbook_aggregator — mocked aggregator +3. test_agentic_adapter_pre_filter_rejects_short_session — pre-flight gate + +Additional unit tests cover: +- force_extraction bypasses pre-filter +- multiple extractor configs each invoke ExtractionAgent +- skip_aggregation short-circuits aggregator +- agent failure degrades to warning (not exception) +- hard violations surface as warnings +""" + +from __future__ import annotations + +import json +from unittest.mock import MagicMock, patch + +from reflexio.models.api_schema.domain.entities import Interaction +from reflexio.models.api_schema.service_schemas import ( + PublishUserInteractionRequest, + Request, +) +from reflexio.models.config_schema import ( + Config, + PlaybookAggregatorConfig, + ProfileExtractorConfig, + StorageConfigSQLite, + UserPlaybookExtractorConfig, +) +from reflexio.server.services.extraction.agentic_adapter import AgenticExtractionRunner +from reflexio.server.services.extraction.plan import CommitResult, Violation + +# --------------------------------------------------------------------------- +# shared helpers +# --------------------------------------------------------------------------- + + +def _make_interaction(role: str, content: str, user_id: str = "u_test") -> Interaction: + return Interaction( + interaction_id=0, + user_id=user_id, + request_id="req_abc", + role=role, + content=content, + ) + + +def _make_request(session_id: str = "s1") -> Request: + return Request( + request_id="req_abc", + user_id="u_test", + source="cli", + agent_version="v1", + session_id=session_id, + ) + + +def _make_publish_request( + *, + force_extraction: bool = False, + skip_aggregation: bool = False, + user_id: str = "u_test", +) -> PublishUserInteractionRequest: + return PublishUserInteractionRequest( + user_id=user_id, + interaction_data_list=[{"role": "User", "content": "hi"}], # type: ignore[list-item] + source="cli", + agent_version="v1", + force_extraction=force_extraction, + skip_aggregation=skip_aggregation, + ) + + +def _make_runner( + storage: object = None, +) -> AgenticExtractionRunner: + """Build a runner with a mocked request_context.""" + rc = MagicMock() + rc.storage = storage if storage is not None else MagicMock() + rc.prompt_manager = MagicMock() + rc.prompt_manager.render_prompt.return_value = "stub prompt" + rc.configurator = MagicMock() + rc.org_id = "test-org" + + return AgenticExtractionRunner( + llm_client=MagicMock(), + request_context=rc, + ) + + +def _mk_tool_call(id_: str, name: str, args: dict) -> MagicMock: + tc = MagicMock() + tc.id = id_ + tc.function = MagicMock() + tc.function.name = name + tc.function.arguments = json.dumps(args) + return tc + + +def _mk_tool_response(tool_calls: list, content: str | None = None) -> MagicMock: + resp = MagicMock() + resp.tool_calls = tool_calls + resp.content = content + return resp + + +# --------------------------------------------------------------------------- +# Test 1: end-to-end creates profile (real SQLite, scripted LLM) +# --------------------------------------------------------------------------- + + +def test_agentic_adapter_end_to_end_creates_profile(tmp_path): + """Scripted 3-turn LLM: search → create → finish. + + Invokes the runner with real SQLite storage; asserts the profile lands in + storage after the run completes. + """ + from reflexio.server.llm.litellm_client import LiteLLMClient, LiteLLMConfig + from reflexio.server.prompt.prompt_manager import PromptManager + from reflexio.server.services.storage.sqlite_storage import SQLiteStorage + + user_id = "u_adapter_e2e" + store = SQLiteStorage( + org_id="test-org-e2e", db_path=str(tmp_path / "adapter_e2e.db") + ) + + # Real client (key doesn't matter — LLM is mocked via generate_chat_response) + import os + + os.environ.setdefault("ANTHROPIC_API_KEY", "test-key") + client = LiteLLMClient(LiteLLMConfig(model="claude-sonnet-4-6")) + pm = PromptManager() + + rc = MagicMock() + rc.storage = store + rc.prompt_manager = pm + rc.configurator = MagicMock() + rc.org_id = "test-org-e2e" + + runner = AgenticExtractionRunner( + llm_client=client, + request_context=rc, + ) + + # Script: search (empty result) → create profile → finish + scripted = [ + _mk_tool_response( + [ + _mk_tool_call( + "c1", "search_user_profiles", {"query": "food", "top_k": 10} + ) + ] + ), + _mk_tool_response( + [ + _mk_tool_call( + "c2", + "create_user_profile", + { + "content": "user likes sushi", + "ttl": "infinity", + "source_span": "I love sushi", + }, + ) + ] + ), + _mk_tool_response([_mk_tool_call("c3", "finish", {})]), + ] + + cfg = Config( + storage_config=StorageConfigSQLite(), + profile_extractor_configs=[ + ProfileExtractorConfig( + extractor_name="test_profile_extractor", + extraction_definition_prompt="Extract food preferences.", + ) + ], + user_playbook_extractor_configs=[], + ) + + with patch.object(client, "generate_chat_response", side_effect=scripted): + warnings = runner.run( + publish_request=_make_publish_request( + force_extraction=True, user_id=user_id + ), + request_id="req_e2e", + new_interactions=[_make_interaction("User", "I love sushi", user_id)], + new_request=Request( + request_id="req_e2e", + user_id=user_id, + source="cli", + agent_version="v1", + session_id="s_e2e", + ), + config=cfg, + ) + + assert isinstance(warnings, list) + profiles = store.get_user_profile(user_id) + assert len(profiles) == 1, f"Expected 1 profile, got {len(profiles)}: {profiles}" + assert profiles[0].content == "user likes sushi" + + +# --------------------------------------------------------------------------- +# Test 2: aggregation triggered for configs with aggregation_config +# --------------------------------------------------------------------------- + + +def test_agentic_adapter_triggers_playbook_aggregator(): + """Runner triggers PlaybookAggregator.run once per config that has aggregation_config.""" + runner = _make_runner() + + cfg = Config( + storage_config=StorageConfigSQLite(), + profile_extractor_configs=[], + user_playbook_extractor_configs=[ + UserPlaybookExtractorConfig( + extractor_name="with_agg", + extraction_definition_prompt="Extract playbook rules.", + aggregation_config=PlaybookAggregatorConfig(), + ), + UserPlaybookExtractorConfig( + extractor_name="without_agg", + extraction_definition_prompt="Extract playbook rules.", + ), + ], + ) + + # Stub ExtractionAgent.run to return empty CommitResult (no LLM calls needed) + empty_result = CommitResult(applied=[], violations=[], outcome="finish_tool") + fake_agg_cls = MagicMock() + fake_agg_cls.return_value.run.return_value = {} + + with ( + patch( + "reflexio.server.services.extraction.agentic_adapter.ExtractionAgent.run", + return_value=empty_result, + ), + patch( + "reflexio.server.services.extraction.agentic_adapter.PlaybookAggregator", + fake_agg_cls, + ), + ): + runner.run( + publish_request=_make_publish_request(force_extraction=True), + request_id="req_agg", + new_interactions=[ + _make_interaction("User", "Trigger aggregation test"), + ], + new_request=_make_request(), + config=cfg, + ) + + # Aggregator constructed + run called exactly once (only "with_agg" has aggregation_config) + assert fake_agg_cls.return_value.run.call_count == 1 + call_arg = fake_agg_cls.return_value.run.call_args.args[0] + assert call_arg.playbook_name == "with_agg" + + +# --------------------------------------------------------------------------- +# Test 3: pre-filter rejects short session +# --------------------------------------------------------------------------- + + +def test_agentic_adapter_pre_filter_rejects_short_session(): + """When _cheap_should_run_reject returns a reason, runner exits early. + + ExtractionAgent must not be invoked. + """ + runner = _make_runner() + + cfg = Config( + storage_config=StorageConfigSQLite(), + profile_extractor_configs=[ + ProfileExtractorConfig( + extractor_name="default", + extraction_definition_prompt="Extract facts.", + ) + ], + user_playbook_extractor_configs=[], + ) + + with patch( + "reflexio.server.services.extraction.agentic_adapter.ExtractionAgent.run" + ) as mock_agent_run: + warnings = runner.run( + publish_request=_make_publish_request( + force_extraction=False + ), # pre-filter active + request_id="req_prefilter", + new_interactions=[ + _make_interaction("Agent", "only agent turn, no user turn") + ], + new_request=_make_request(), + config=cfg, + ) + + assert warnings == [] + mock_agent_run.assert_not_called() + + +# --------------------------------------------------------------------------- +# Additional unit tests +# --------------------------------------------------------------------------- + + +def test_runner_force_extraction_bypasses_pre_filter(): + """force_extraction=True calls ExtractionAgent even with no User turns.""" + runner = _make_runner() + + cfg = Config( + storage_config=StorageConfigSQLite(), + profile_extractor_configs=[ + ProfileExtractorConfig( + extractor_name="only_profile", + extraction_definition_prompt="Extract facts.", + ) + ], + user_playbook_extractor_configs=[ + UserPlaybookExtractorConfig( + extractor_name="only_playbook", + extraction_definition_prompt="Extract rules.", + ) + ], + ) + + empty_result = CommitResult(applied=[], violations=[], outcome="finish_tool") + + with patch( + "reflexio.server.services.extraction.agentic_adapter.ExtractionAgent.run", + return_value=empty_result, + ) as mock_agent_run: + runner.run( + publish_request=_make_publish_request(force_extraction=True), + request_id="req_force", + new_interactions=[_make_interaction("Agent", "no user turn")], + new_request=_make_request(), + config=cfg, + ) + + # 1 profile + 1 playbook config = 2 total agent calls; pre-filter was bypassed + assert mock_agent_run.call_count == 2 + + +def test_runner_iterates_all_extractor_configs(): + """Runner calls ExtractionAgent once per config across both profile + playbook lists.""" + runner = _make_runner() + + cfg = Config( + storage_config=StorageConfigSQLite(), + profile_extractor_configs=[ + ProfileExtractorConfig( + extractor_name="profile_one", + extraction_definition_prompt="profile prompt", + ), + ProfileExtractorConfig( + extractor_name="profile_two", + extraction_definition_prompt="profile prompt 2", + ), + ], + user_playbook_extractor_configs=[ + UserPlaybookExtractorConfig( + extractor_name="playbook_one", + extraction_definition_prompt="playbook prompt", + ), + ], + ) + + empty_result = CommitResult(applied=[], violations=[], outcome="finish_tool") + + with patch( + "reflexio.server.services.extraction.agentic_adapter.ExtractionAgent.run", + return_value=empty_result, + ) as mock_agent_run: + runner.run( + publish_request=_make_publish_request(force_extraction=True), + request_id="req_multi", + new_interactions=[_make_interaction("User", "test content")], + new_request=_make_request(), + config=cfg, + ) + + # 2 profile configs + 1 playbook config = 3 total agent calls + assert mock_agent_run.call_count == 3 + called_names = {c.kwargs["extractor_name"] for c in mock_agent_run.call_args_list} + assert called_names == {"profile_one", "profile_two", "playbook_one"} + + +def test_runner_skip_aggregation_short_circuits(): + """skip_aggregation=True → PlaybookAggregator never constructed.""" + runner = _make_runner() + + cfg = Config( + storage_config=StorageConfigSQLite(), + profile_extractor_configs=[], + user_playbook_extractor_configs=[ + UserPlaybookExtractorConfig( + extractor_name="with_agg", + extraction_definition_prompt="p", + aggregation_config=PlaybookAggregatorConfig(), + ), + ], + ) + + empty_result = CommitResult(applied=[], violations=[], outcome="finish_tool") + fake_agg_cls = MagicMock() + + with ( + patch( + "reflexio.server.services.extraction.agentic_adapter.ExtractionAgent.run", + return_value=empty_result, + ), + patch( + "reflexio.server.services.extraction.agentic_adapter.PlaybookAggregator", + fake_agg_cls, + ), + ): + runner.run( + publish_request=_make_publish_request( + force_extraction=True, skip_aggregation=True + ), + request_id="req_skip_agg", + new_interactions=[_make_interaction("User", "hi")], + new_request=_make_request(), + config=cfg, + ) + + fake_agg_cls.assert_not_called() + + +def test_runner_agent_failure_becomes_warning(): + """Exception from ExtractionAgent.run is caught and surfaced as a warning.""" + runner = _make_runner() + + cfg = Config( + storage_config=StorageConfigSQLite(), + profile_extractor_configs=[ + ProfileExtractorConfig( + extractor_name="failing_extractor", + extraction_definition_prompt="Extract facts.", + ) + ], + user_playbook_extractor_configs=[], + ) + + with patch( + "reflexio.server.services.extraction.agentic_adapter.ExtractionAgent.run", + side_effect=RuntimeError("LLM timeout"), + ): + warnings = runner.run( + publish_request=_make_publish_request(force_extraction=True), + request_id="req_fail", + new_interactions=[_make_interaction("User", "test")], + new_request=_make_request(), + config=cfg, + ) + + assert any("failing_extractor" in w and "LLM timeout" in w for w in warnings) + + +def test_runner_hard_violation_surfaces_as_warning(): + """Hard invariant violations in CommitResult are appended to warnings.""" + runner = _make_runner() + + cfg = Config( + storage_config=StorageConfigSQLite(), + profile_extractor_configs=[ + ProfileExtractorConfig( + extractor_name="default", + extraction_definition_prompt="Extract facts.", + ) + ], + user_playbook_extractor_configs=[], + ) + + violation = Violation( + code="A", + severity="hard", + affected_op_indices=[0], + msg="create without prior search", + ) + result_with_violation = CommitResult( + applied=[], violations=[violation], outcome="finish_tool" + ) + + with patch( + "reflexio.server.services.extraction.agentic_adapter.ExtractionAgent.run", + return_value=result_with_violation, + ): + warnings = runner.run( + publish_request=_make_publish_request(force_extraction=True), + request_id="req_violation", + new_interactions=[_make_interaction("User", "test")], + new_request=_make_request(), + config=cfg, + ) + + assert any("violation A" in w for w in warnings) + + +def test_runner_soft_violation_does_not_surface_as_warning(): + """Soft invariant violations are logged but not added to warnings.""" + runner = _make_runner() + + cfg = Config( + storage_config=StorageConfigSQLite(), + profile_extractor_configs=[ + ProfileExtractorConfig( + extractor_name="default", + extraction_definition_prompt="Extract facts.", + ) + ], + user_playbook_extractor_configs=[], + ) + + soft_violation = Violation( + # E (`inv_E_no_duplicate_creates`) is genuinely a soft invariant per + # invariants.py — using "B" here mismatched its real severity ("hard") + # and would have hidden a regression where soft violations were + # mistakenly upgraded to hard. + code="E", + severity="soft", + affected_op_indices=[0], + msg="soft warning", + ) + result_with_soft = CommitResult( + applied=[], violations=[soft_violation], outcome="finish_tool" + ) + + with patch( + "reflexio.server.services.extraction.agentic_adapter.ExtractionAgent.run", + return_value=result_with_soft, + ): + warnings = runner.run( + publish_request=_make_publish_request(force_extraction=True), + request_id="req_soft", + new_interactions=[_make_interaction("User", "test")], + new_request=_make_request(), + config=cfg, + ) + + # Soft violations must NOT appear in warnings + assert not any("violation" in w for w in warnings) + + +# --------------------------------------------------------------------------- +# Regression tests: per-kind tool constraint +# --------------------------------------------------------------------------- + + +def test_runner_profile_extractor_cannot_emit_playbook_ops(tmp_path): + """Profile extractor runs with PROFILE_EXTRACTION_TOOLS. + + A scripted create_user_playbook call from the LLM (in the profile extractor + turn) is rejected with 'unknown tool' by the registry; no playbook lands in + storage. + + Note: Config with ``user_playbook_extractor_configs=[]`` triggers the + schema validator which injects a default playbook extractor. We account + for that by scripting a second set of 2 turns (search → finish) for the + default playbook extractor so the scripted list is not exhausted early. + """ + import os + + from reflexio.server.llm.litellm_client import LiteLLMClient, LiteLLMConfig + from reflexio.server.prompt.prompt_manager import PromptManager + from reflexio.server.services.storage.sqlite_storage import SQLiteStorage + + user_id = "u_profile_constraint" + store = SQLiteStorage( + org_id="test-org-pc", db_path=str(tmp_path / "profile_constraint.db") + ) + + os.environ.setdefault("ANTHROPIC_API_KEY", "test-key") + client = LiteLLMClient(LiteLLMConfig(model="claude-sonnet-4-6")) + pm = PromptManager() + + rc = MagicMock() + rc.storage = store + rc.prompt_manager = pm + rc.configurator = MagicMock() + rc.org_id = "test-org-pc" + + runner = AgenticExtractionRunner(llm_client=client, request_context=rc) + + # Turn order (2 extractors run in sequence — profile first, playbook second): + # Profile extractor turns (PROFILE_EXTRACTION_TOOLS): + # 1. search_user_profiles + # 2. create_user_playbook ← forbidden, returns {"error": "unknown tool: ..."} + # 3. finish + # Default playbook extractor turns (PLAYBOOK_EXTRACTION_TOOLS): + # 4. search_user_playbooks + # 5. finish + scripted = [ + # --- profile extractor --- + _mk_tool_response( + [ + _mk_tool_call( + "c1", "search_user_profiles", {"query": "food", "top_k": 10} + ) + ] + ), + _mk_tool_response( + [ + _mk_tool_call( + "c2", + "create_user_playbook", # forbidden in PROFILE_EXTRACTION_TOOLS + { + "trigger": "ask about food", + "content": "suggest sushi", + "source_span": "I love sushi", + }, + ) + ] + ), + _mk_tool_response([_mk_tool_call("c3", "finish", {})]), + # --- default playbook extractor (no ops) --- + _mk_tool_response( + [ + _mk_tool_call( + "c4", "search_user_playbooks", {"query": "food", "top_k": 10} + ) + ] + ), + _mk_tool_response([_mk_tool_call("c5", "finish", {})]), + ] + + cfg = Config( + storage_config=StorageConfigSQLite(), + profile_extractor_configs=[ + ProfileExtractorConfig( + extractor_name="profile_only", + extraction_definition_prompt="Extract food preferences.", + ) + ], + # Empty list triggers default playbook extractor injection via schema validator. + # This is expected behaviour; we script for it explicitly above. + user_playbook_extractor_configs=[], + ) + + with patch.object(client, "generate_chat_response", side_effect=scripted): + runner.run( + publish_request=_make_publish_request( + force_extraction=True, user_id=user_id + ), + request_id="req_pc", + new_interactions=[_make_interaction("User", "I love sushi", user_id)], + new_request=Request( + request_id="req_pc", + user_id=user_id, + source="cli", + agent_version="v1", + session_id="s_pc", + ), + config=cfg, + ) + + # The forbidden create_user_playbook was rejected — zero playbooks in storage. + playbooks = store.get_user_playbooks(user_id=user_id) + assert playbooks == [], ( + f"Profile extractor must not emit playbooks; got: {playbooks}" + ) + + +def test_runner_playbook_extractor_cannot_emit_profile_ops(tmp_path): + """Playbook extractor runs with PLAYBOOK_EXTRACTION_TOOLS. + + A scripted create_user_profile call from the LLM (in the playbook extractor + turn) is rejected with 'unknown tool' by the registry; no profile lands in + storage. + + Note: Config with ``profile_extractor_configs=[]`` triggers the schema + validator which injects a default profile extractor. We account for that + by scripting a first set of 2 turns (search → finish) for the default + profile extractor, then 3 turns for the explicit playbook extractor. + """ + import os + + from reflexio.server.llm.litellm_client import LiteLLMClient, LiteLLMConfig + from reflexio.server.prompt.prompt_manager import PromptManager + from reflexio.server.services.storage.sqlite_storage import SQLiteStorage + + user_id = "u_playbook_constraint" + store = SQLiteStorage( + org_id="test-org-plc", db_path=str(tmp_path / "playbook_constraint.db") + ) + + os.environ.setdefault("ANTHROPIC_API_KEY", "test-key") + client = LiteLLMClient(LiteLLMConfig(model="claude-sonnet-4-6")) + pm = PromptManager() + + rc = MagicMock() + rc.storage = store + rc.prompt_manager = pm + rc.configurator = MagicMock() + rc.org_id = "test-org-plc" + + runner = AgenticExtractionRunner(llm_client=client, request_context=rc) + + # Turn order (2 extractors run in sequence — profile first, playbook second): + # Default profile extractor turns (PROFILE_EXTRACTION_TOOLS, no ops): + # 1. search_user_profiles + # 2. finish + # Playbook extractor turns (PLAYBOOK_EXTRACTION_TOOLS): + # 3. search_user_playbooks + # 4. create_user_profile ← forbidden, returns {"error": "unknown tool: ..."} + # 5. finish + scripted = [ + # --- default profile extractor (no ops) --- + _mk_tool_response( + [ + _mk_tool_call( + "c1", "search_user_profiles", {"query": "food", "top_k": 10} + ) + ] + ), + _mk_tool_response([_mk_tool_call("c2", "finish", {})]), + # --- playbook extractor --- + _mk_tool_response( + [ + _mk_tool_call( + "c3", "search_user_playbooks", {"query": "food", "top_k": 10} + ) + ] + ), + _mk_tool_response( + [ + _mk_tool_call( + "c4", + "create_user_profile", # forbidden in PLAYBOOK_EXTRACTION_TOOLS + { + "content": "user likes sushi", + "ttl": "infinity", + "source_span": "I love sushi", + }, + ) + ] + ), + _mk_tool_response([_mk_tool_call("c5", "finish", {})]), + ] + + cfg = Config( + storage_config=StorageConfigSQLite(), + # Empty list triggers default profile extractor injection via schema validator. + # This is expected behaviour; we script for it explicitly above. + profile_extractor_configs=[], + user_playbook_extractor_configs=[ + UserPlaybookExtractorConfig( + extractor_name="playbook_only", + extraction_definition_prompt="Extract behavioral rules.", + ) + ], + ) + + with patch.object(client, "generate_chat_response", side_effect=scripted): + runner.run( + publish_request=_make_publish_request( + force_extraction=True, user_id=user_id + ), + request_id="req_plc", + new_interactions=[_make_interaction("User", "I love sushi", user_id)], + new_request=Request( + request_id="req_plc", + user_id=user_id, + source="cli", + agent_version="v1", + session_id="s_plc", + ), + config=cfg, + ) + + # The forbidden create_user_profile was rejected — zero profiles in storage. + profiles = store.get_user_profile(user_id) + assert profiles == [], f"Playbook extractor must not emit profiles; got: {profiles}" diff --git a/tests/server/services/extraction/test_agentic_v2_e2e.py b/tests/server/services/extraction/test_agentic_v2_e2e.py new file mode 100644 index 00000000..1a0c6d8a --- /dev/null +++ b/tests/server/services/extraction/test_agentic_v2_e2e.py @@ -0,0 +1,407 @@ +"""End-to-end test for agentic-v2 via GenerationService.run. + +Exercises the full publish flow (gate -> config iteration -> windowing +-> ExtractionAgent -> commit -> aggregator trigger) with a mocked LLM. +Verifies storage state + aggregator invocation. +""" + +from __future__ import annotations + +import json +import os +import tempfile +from unittest.mock import MagicMock, patch + +from reflexio.models.api_schema.service_schemas import ( + InteractionData, + PublishUserInteractionRequest, +) +from reflexio.models.config_schema import ( + Config, + PlaybookAggregatorConfig, + ProfileExtractorConfig, + StorageConfigSQLite, + UserPlaybookExtractorConfig, +) +from reflexio.server.api_endpoints.request_context import RequestContext +from reflexio.server.llm.litellm_client import LiteLLMClient, LiteLLMConfig +from reflexio.server.services.generation_service import GenerationService + +# --------------------------------------------------------------------------- +# helpers +# --------------------------------------------------------------------------- + + +def _mk_tool_call(id_: str, name: str, args: dict) -> MagicMock: + tc = MagicMock() + tc.id = id_ + tc.function = MagicMock() + tc.function.name = name + tc.function.arguments = json.dumps(args) + return tc + + +def _mk_resp(tool_calls: list, content: str | None = None) -> MagicMock: + r = MagicMock() + r.tool_calls = tool_calls + r.content = content + return r + + +def _make_agentic_config() -> Config: + return Config( + extraction_backend="agentic", + storage_config=StorageConfigSQLite(), + profile_extractor_configs=[ + ProfileExtractorConfig( + extractor_name="e2e_profile", + extraction_definition_prompt="Extract user facts from the session.", + ), + ], + user_playbook_extractor_configs=[ + UserPlaybookExtractorConfig( + extractor_name="e2e_playbook", + extraction_definition_prompt="Extract behavioral preferences.", + aggregation_config=PlaybookAggregatorConfig(), + ), + ], + ) + + +def _make_scripted_client(responses: list) -> LiteLLMClient: + """Build a real LiteLLMClient whose generate_chat_response is scripted. + + Scopes ``OPENAI_API_KEY`` to client construction via ``patch.dict`` so + the env mutation does not leak into other tests in the same process + (which would make test ordering matter). + """ + with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}, clear=False): + client = LiteLLMClient(LiteLLMConfig(model="gpt-4o-mini")) + client.generate_chat_response = MagicMock(side_effect=responses) # type: ignore[method-assign] + return client + + +# --------------------------------------------------------------------------- +# Test 1: full flow — profile + playbook created, aggregator triggered +# --------------------------------------------------------------------------- + + +def test_e2e_agentic_v2_full_flow(tmp_path): + """Publish a session with extraction_backend='agentic'; verify storage + aggregator. + + Scripts 6 LLM turns (3 per extractor: search -> create -> finish) and + asserts that: + - A profile with the expected content is written to storage. + - A user playbook with the expected content is written to storage. + - PlaybookAggregator.run is invoked at least once. + - No unexpected warnings are returned. + """ + user_id = "e2e_user" + org_id = "e2e_org" + + # 6 scripted turns: 3 for profile extractor, 3 for playbook extractor. + scripted = [ + # --- profile extractor --- + _mk_resp( + [ + _mk_tool_call( + "c1", + "search_user_profiles", + {"query": "food preferences", "top_k": 10}, + ) + ] + ), + _mk_resp( + [ + _mk_tool_call( + "c2", + "create_user_profile", + { + "content": "user likes sushi", + "ttl": "infinity", + "source_span": "I love sushi", + }, + ) + ] + ), + _mk_resp([_mk_tool_call("c3", "finish", {})]), + # --- playbook extractor --- + _mk_resp( + [ + _mk_tool_call( + "c4", + "search_user_playbooks", + {"query": "food preferences", "top_k": 10}, + ) + ] + ), + _mk_resp( + [ + _mk_tool_call( + "c5", + "create_user_playbook", + { + "trigger": "user asks about food", + "content": "suggest sushi-related options", + "source_span": "I love sushi", + }, + ) + ] + ), + _mk_resp([_mk_tool_call("c6", "finish", {})]), + ] + + client = _make_scripted_client(scripted) + + with tempfile.TemporaryDirectory() as temp_dir: + request_context = RequestContext(org_id=org_id, storage_base_dir=temp_dir) + gs = GenerationService(llm_client=client, request_context=request_context) + # Inject agentic Config; bypass disk-based configurator. + gs.configurator.get_config = MagicMock(return_value=_make_agentic_config()) # type: ignore[method-assign] + + with patch( + "reflexio.server.services.extraction.agentic_adapter.PlaybookAggregator" + ) as mock_agg_cls: + mock_agg = MagicMock() + mock_agg_cls.return_value = mock_agg + + request = PublishUserInteractionRequest( + user_id=user_id, + interaction_data_list=[ + InteractionData( + role="User", + content="I love sushi — please always recommend it when I ask about food.", + ), + InteractionData( + role="Assistant", + content="Noted! I'll keep your sushi preference in mind.", + ), + ], + session_id="e2e_sid", + force_extraction=True, + ) + result = gs.run(request) + + # --- profile assertion --- + assert request_context.storage is not None + profiles = request_context.storage.get_user_profile(user_id) + assert any("sushi" in (p.content or "").lower() for p in profiles), ( + f"expected a sushi profile; got: {[p.content for p in profiles]}" + ) + + # Provenance: agentic-extracted profiles must carry the publish + # request_id so retrieval can trace back to the source publish (this + # is what LongMemEval-style recall@K depends on). + for p in profiles: + assert p.generated_from_request_id == result.request_id, ( + f"profile {p.profile_id} has stale generated_from_request_id " + f"{p.generated_from_request_id!r}, expected {result.request_id!r}" + ) + + # --- playbook assertion --- + playbooks = request_context.storage.get_user_playbooks(user_id=user_id) + assert any("sushi" in (pb.content or "").lower() for pb in playbooks), ( + f"expected a sushi playbook; got: {[pb.content for pb in playbooks]}" + ) + + # Mirror provenance assertion for playbooks. + for pb in playbooks: + assert pb.request_id == result.request_id, ( + f"playbook {pb.user_playbook_id} has stale request_id " + f"{pb.request_id!r}, expected {result.request_id!r}" + ) + + # --- aggregator triggered --- + assert mock_agg.run.call_count >= 1, ( + "PlaybookAggregator.run should have been called at least once" + ) + + # --- no unexpected warnings --- + assert not result.warnings, f"unexpected warnings: {result.warnings}" + + +# --------------------------------------------------------------------------- +# Test 2: extraction skipped when pre-filter rejects short session +# --------------------------------------------------------------------------- + + +def test_e2e_agentic_v2_extraction_agent_not_invoked_for_trivial_session(tmp_path): + """Pre-filter rejects short-content session; ExtractionAgent is never called. + + Uses force_extraction=False with very short user content (< 30 chars) to + trigger the 'all_user_turns_too_short' pre-filter path inside + AgenticExtractionRunner. ExtractionAgent must not be constructed or called. + + Choice: we exercise the real _cheap_should_run_reject path (not empty + interaction_data_list, which would be rejected by Pydantic min_length=1). + """ + user_id = "e2e_user2" + org_id = "e2e_org2" + + # No LLM turns should be consumed. + client = _make_scripted_client([]) + + with tempfile.TemporaryDirectory() as temp_dir: + request_context = RequestContext(org_id=org_id, storage_base_dir=temp_dir) + gs = GenerationService(llm_client=client, request_context=request_context) + gs.configurator.get_config = MagicMock(return_value=_make_agentic_config()) # type: ignore[method-assign] + + with patch( + "reflexio.server.services.extraction.agentic_adapter.ExtractionAgent" + ) as mock_agent_cls: + request = PublishUserInteractionRequest( + user_id=user_id, + interaction_data_list=[ + # Short user content (< 30 chars) → pre-filter rejects. + InteractionData(role="User", content="hi"), + ], + session_id="e2e_sid2", + force_extraction=False, # pre-filter active + ) + result = gs.run(request) + + # ExtractionAgent was never instantiated. + mock_agent_cls.assert_not_called() + + # No profiles persisted. + assert request_context.storage is not None + profiles = request_context.storage.get_user_profile(user_id) + assert profiles == [], f"expected no profiles; got {profiles}" + + # Result must not have raised (warnings may be empty or trivial). + assert result.request_id is not None + + +# --------------------------------------------------------------------------- +# Test 3: one rule → exactly one playbook (tool constraint regression) +# --------------------------------------------------------------------------- + + +def test_e2e_one_rule_produces_exactly_one_playbook(tmp_path): + """Single publish, single behavioural rule, two extractor configs enabled. + + Profile extractor: search_user_profiles → create_user_profile → finish. + Playbook extractor: search_user_playbooks → create_user_playbook → finish. + + Because PROFILE_EXTRACTION_TOOLS forbids create_user_playbook, the profile + extractor cannot accidentally emit a second playbook even if the scripted LLM + tried to. Only the playbook extractor's create_user_playbook call succeeds, + so exactly one UserPlaybook lands in storage. + """ + user_id = "e2e_user3" + org_id = "e2e_org3" + + # 6 scripted turns: + # profile extractor (3): search_user_profiles → create_profile → finish + # playbook extractor (3): search_playbooks → create_playbook → finish + scripted = [ + # --- profile extractor: only emits a profile --- + _mk_resp( + [ + _mk_tool_call( + "c1", + "search_user_profiles", + {"query": "on-call schedule", "top_k": 10}, + ) + ] + ), + _mk_resp( + [ + _mk_tool_call( + "c2", + "create_user_profile", + { + "content": "user is on-call this week", + "ttl": "one_week", + "source_span": "on-call this week", + }, + ) + ] + ), + _mk_resp([_mk_tool_call("c3", "finish", {})]), + # --- playbook extractor: emits one playbook --- + _mk_resp( + [ + _mk_tool_call( + "c4", + "search_user_playbooks", + {"query": "code review scheduling", "top_k": 10}, + ) + ] + ), + _mk_resp( + [ + _mk_tool_call( + "c5", + "create_user_playbook", + { + "trigger": "code review scheduling", + "content": "avoid scheduling code reviews before 10am", + "source_span": "no code review before 10am", + }, + ) + ] + ), + _mk_resp([_mk_tool_call("c6", "finish", {})]), + ] + + client = _make_scripted_client(scripted) + + config = Config( + extraction_backend="agentic", + storage_config=StorageConfigSQLite(), + profile_extractor_configs=[ + ProfileExtractorConfig( + extractor_name="oncall_profile", + extraction_definition_prompt="Extract on-call and schedule facts.", + ), + ], + user_playbook_extractor_configs=[ + UserPlaybookExtractorConfig( + extractor_name="scheduling_rules", + extraction_definition_prompt="Extract scheduling behavioural rules.", + ), + ], + ) + + with tempfile.TemporaryDirectory() as temp_dir: + request_context = RequestContext(org_id=org_id, storage_base_dir=temp_dir) + gs = GenerationService(llm_client=client, request_context=request_context) + gs.configurator.get_config = MagicMock(return_value=config) # type: ignore[method-assign] + + request = PublishUserInteractionRequest( + user_id=user_id, + interaction_data_list=[ + InteractionData( + role="User", + content=( + "I'm on-call this week. " + "Please avoid scheduling code reviews before 10am for me." + ), + ), + InteractionData( + role="Assistant", + content="Noted — I'll avoid scheduling code reviews before 10am.", + ), + ], + session_id="e2e_sid3", + force_extraction=True, + ) + result = gs.run(request) + + # Exactly one playbook — the profile extractor's PROFILE_EXTRACTION_TOOLS + # forbids create_user_playbook so only the playbook extractor's call lands. + assert request_context.storage is not None + playbooks = request_context.storage.get_user_playbooks(user_id=user_id) + assert len(playbooks) == 1, ( + f"Expected exactly 1 playbook; got {len(playbooks)}: {[pb.content for pb in playbooks]}" + ) + + # Profile content must not contain behavioural guidance markers. + profiles = request_context.storage.get_user_profile(user_id) + assert len(profiles) == 1, ( + f"Expected exactly 1 profile; got {len(profiles)}: {[p.content for p in profiles]}" + ) + + # No unexpected warnings. + assert not result.warnings, f"unexpected warnings: {result.warnings}" diff --git a/tests/server/services/extraction/test_eval_runner.py b/tests/server/services/extraction/test_eval_runner.py new file mode 100644 index 00000000..3398b99b --- /dev/null +++ b/tests/server/services/extraction/test_eval_runner.py @@ -0,0 +1,101 @@ +"""Unit tests for the eval runner — load, score_plan, group3 replay.""" + +from __future__ import annotations + +from tests.server.services.extraction.eval_runner import ( + load_fixtures, + run_fixture, + score_plan, +) + + +def test_load_fixtures_group1_returns_12(): + fixtures = load_fixtures(group="group1_mutation") + assert len(fixtures) == 12 + categories = {f["category"] for f in fixtures} + assert categories == {"supersede", "merge", "delete", "playbook_expansion"} + + +def test_load_fixtures_group2_returns_18(): + fixtures = load_fixtures(group="group2_supermemory") + assert len(fixtures) == 18 + + +def test_load_fixtures_group3_returns_4(): + fixtures = load_fixtures(group="group3_loop_behavior") + assert len(fixtures) == 4 + + +def test_load_fixtures_all_returns_34(): + fixtures = load_fixtures() + assert len(fixtures) == 12 + 18 + 4 + + +def test_score_plan_exact_match(): + actual = [ + {"op": "delete_user_profile", "id": "p_10"}, + {"op": "create_user_profile", "content": "new fact", "ttl": "infinity"}, + ] + expected = [ + {"op": "delete_user_profile", "id": "p_10"}, + {"op": "create_user_profile", "content_contains": ["new"], "ttl": "infinity"}, + ] + result = score_plan(actual, expected) + assert result["semantic_match"] is True + + +def test_score_plan_content_preserves_all_catches_lossy_merge(): + """playbook_expansion must preserve all prior instructions.""" + actual = [ + {"op": "create_user_playbook", "trigger": "code", "content": "use TypeScript"} + ] + expected = [ + { + "op": "create_user_playbook", + "trigger_contains": ["code"], + "content_contains": ["TypeScript"], + "content_preserves_all": ["show examples"], + } + ] + result = score_plan(actual, expected) + assert result["semantic_match"] is False + assert any("show examples" in f for f in result["failures"]) + + +def test_score_plan_op_count_mismatch(): + actual = [{"op": "delete_user_profile", "id": "p_10"}] + expected = [ + {"op": "delete_user_profile", "id": "p_10"}, + {"op": "create_user_profile", "content_contains": ["x"], "ttl": "infinity"}, + ] + result = score_plan(actual, expected) + assert result["semantic_match"] is False + assert any("op count" in f for f in result["failures"]) + + +def test_score_plan_op_type_mismatch(): + actual = [{"op": "create_user_profile", "content": "x", "ttl": "infinity"}] + expected = [{"op": "delete_user_profile", "id": "p_10"}] + result = score_plan(actual, expected) + assert result["semantic_match"] is False + + +def test_run_fixture_group3_confused_garbage(tmp_path): + """Group 3 replay: confused_garbage should hit A + B violations, commit 0 ops.""" + from unittest.mock import MagicMock + + from reflexio.server.prompt.prompt_manager import PromptManager + from reflexio.server.services.storage.sqlite_storage import SQLiteStorage + + fixtures = load_fixtures(group="group3_loop_behavior") + fixture = next(f for f in fixtures if f["id"] == "confused_garbage") + storage = SQLiteStorage(org_id="eval-org", db_path=str(tmp_path / "eval.db")) + pm = PromptManager() + client = MagicMock() + client.config = MagicMock() + client.config.api_key_config = None + + result = run_fixture(fixture, client=client, prompt_manager=pm, storage=storage) + assert result["outcome"] == "finish_tool" + assert result["applied_count"] == 0 + assert set(result["violation_codes"]) >= {"A", "B"} diff --git a/tests/server/services/extraction/test_extraction_agent.py b/tests/server/services/extraction/test_extraction_agent.py new file mode 100644 index 00000000..4182ef97 --- /dev/null +++ b/tests/server/services/extraction/test_extraction_agent.py @@ -0,0 +1,460 @@ +"""Integration tests for ExtractionAgent. Uses mocked LLM + real SQLite storage.""" + +import json +from unittest.mock import MagicMock + +import pytest + +from reflexio.server.services.extraction.extraction_agent import ExtractionAgent + + +@pytest.fixture +def temp_storage(tmp_path): + from reflexio.server.services.storage.sqlite_storage import SQLiteStorage + + return SQLiteStorage(org_id="test-org", db_path=str(tmp_path / "ext.db")) + + +@pytest.fixture +def prompt_manager(): + from reflexio.server.prompt.prompt_manager import PromptManager + + return PromptManager() + + +@pytest.fixture +def llm_client(): + """Mocked LLM client that returns scripted tool calls.""" + client = MagicMock() + client.config = MagicMock() + client.config.api_key_config = None + return client + + +def _mk_tool_response(tool_calls, content=None): + """Construct a fake LLM response shape matching run_tool_loop expectations.""" + resp = MagicMock() + resp.tool_calls = tool_calls + resp.content = content + return resp + + +def _mk_tool_call(id_, name, args_dict): + tc = MagicMock() + tc.id = id_ + tc.function = MagicMock() + tc.function.name = name + tc.function.arguments = json.dumps(args_dict) + return tc + + +def test_extraction_agent_happy_path_new_profile( + temp_storage, prompt_manager, llm_client +): + """Session: user states a new fact. Agent searches (empty), creates, finishes.""" + llm_client.generate_chat_response.side_effect = [ + _mk_tool_response( + [ + _mk_tool_call( + "c1", + "search_user_profiles", + {"query": "food preferences", "top_k": 10}, + ) + ] + ), + _mk_tool_response( + [ + _mk_tool_call( + "c2", + "create_user_profile", + { + "content": "user likes sushi", + "ttl": "infinity", + "source_span": "I love sushi", + }, + ) + ] + ), + _mk_tool_response([_mk_tool_call("c3", "finish", {})]), + ] + + agent = ExtractionAgent( + client=llm_client, + storage=temp_storage, + prompt_manager=prompt_manager, + max_steps=12, + ) + result = agent.run( + user_id="u_1", + agent_version="v1", + extractor_name="default", + extraction_criteria="Extract food preferences.", + sessions_text="User: I love sushi", + ) + + assert result.outcome == "finish_tool" + assert len(result.applied) == 1 + # Profile landed in storage + assert len(temp_storage.get_user_profile("u_1")) == 1 + + +def test_extraction_agent_invariant_blocks_ungrounded_create( + temp_storage, prompt_manager, llm_client +): + """Agent skips search, tries to create — invariant A drops it.""" + llm_client.generate_chat_response.side_effect = [ + _mk_tool_response( + [ + _mk_tool_call( + "c1", + "create_user_profile", + { + "content": "x", + "ttl": "infinity", + "source_span": "y", + }, + ) + ] + ), + _mk_tool_response([_mk_tool_call("c2", "finish", {})]), + ] + + agent = ExtractionAgent( + client=llm_client, storage=temp_storage, prompt_manager=prompt_manager + ) + result = agent.run( + user_id="u_1", + agent_version="v1", + extractor_name="default", + extraction_criteria="x", + sessions_text="User: whatever", + ) + assert result.outcome == "finish_tool" + assert len(result.applied) == 0 + assert any(v.code == "A" for v in result.violations) + + +def test_extraction_agent_max_steps_still_commits_valid_ops( + temp_storage, prompt_manager, llm_client +): + """Loop hits max_steps with partially valid plan — plan commits per spec §7.""" + + # Script 3 turns that each do search + create, never call finish + def _turn_script(query): + return _mk_tool_response( + [ + _mk_tool_call( + "c", "search_user_profiles", {"query": query, "top_k": 10} + ), + _mk_tool_call( + "c2", + "create_user_profile", + { + "content": f"fact about {query}", + "ttl": "infinity", + "source_span": query, + }, + ), + ] + ) + + llm_client.generate_chat_response.side_effect = [ + _turn_script(f"q_{i}") for i in range(5) + ] + + agent = ExtractionAgent( + client=llm_client, + storage=temp_storage, + prompt_manager=prompt_manager, + max_steps=3, # force max_steps before finish + ) + result = agent.run( + user_id="u_1", + agent_version="v1", + extractor_name="default", + extraction_criteria="x", + sessions_text="User: test", + ) + assert result.outcome == "max_steps" + assert len(result.applied) >= 1 + + +def test_extraction_agent_prompt_frames_self_improvement(prompt_manager): + """Sanity: extraction prompt opening must frame extraction around agent + self-improvement, not 'memory storage'.""" + out = prompt_manager.render_prompt( + "extraction_agent", + variables={ + "sessions": "User: hi", + "extraction_criteria": "extract facts", + "extraction_kind": "UserProfile", + "max_steps": "4", + }, + ) + assert "improve over time" in out or "self-improv" in out + assert "memory extractor" not in out.lower() + + +def test_extraction_agent_prompt_forbids_profile_rule_overlap(prompt_manager): + """Sanity (v1.3.0): prompt must carry the anti-pattern examples for + rule-shaped profile content and the 'no overlap' rule. Guards against + regression to the earlier bundled-fact / rule-in-profile behaviour.""" + out = prompt_manager.render_prompt( + "extraction_agent", + variables={ + "sessions": "User: hi", + "extraction_criteria": "extract facts", + "extraction_kind": "UserProfile", + "max_steps": "4", + }, + ) + # One-fact-per-profile rule must be present. + assert "One fact per profile" in out + # No-overlap rule between profile and playbook. + assert "No overlap between profile and playbook" in out + # The prompt must include some anti-pattern guidance distinguishing + # rule-shaped from fact-shaped content. The specific example string + # is allowed to evolve via Phase 27 tuning, so we check for structural + # markers (the rule wording) rather than a single example. + + +def test_extraction_agent_prompt_specifies_playbook_format(prompt_manager): + """Sanity (v1.4.0): prompt must carry the Agent-Skills-inspired format + guidance for UserPlaybook trigger + content + rationale. Guards against + regression to the earlier unstructured semicolon-delimited shape.""" + out = prompt_manager.render_prompt( + "extraction_agent", + variables={ + "sessions": "User: hi", + "extraction_criteria": "extract rules", + "extraction_kind": "UserPlaybook", + "max_steps": "4", + }, + ) + # The Playbook format section must be present. + assert "Playbook format" in out + # Trigger guidance — imperative conditional phrasing must be required; + # the proposer is allowed to evolve specific examples. + assert "imperative conditional phrasing" in out + # Content guidance — markdown bullet list for independent instructions. + assert "markdown bullet list" in out + # Rationale guidance — one sentence explaining WHY, not what. + assert "one sentence" in out.lower() + + +def test_extraction_agent_emits_summary_info_line( + caplog, temp_storage, prompt_manager, llm_client +): + """Each run emits ONE INFO line starting with 'extraction_agent[' that + contains elapsed_ms, turns, tools, outcome, applied, violations, usage.""" + import logging + + llm_client.generate_chat_response.side_effect = [ + _mk_tool_response( + [ + _mk_tool_call( + "c1", + "search_user_profiles", + {"query": "food preferences", "top_k": 10}, + ) + ] + ), + _mk_tool_response( + [ + _mk_tool_call( + "c2", + "create_user_profile", + { + "content": "user likes sushi", + "ttl": "infinity", + "source_span": "I love sushi", + }, + ) + ] + ), + _mk_tool_response([_mk_tool_call("c3", "finish", {})]), + ] + + agent = ExtractionAgent( + client=llm_client, + storage=temp_storage, + prompt_manager=prompt_manager, + max_steps=12, + ) + + with caplog.at_level( + logging.INFO, logger="reflexio.server.services.extraction.extraction_agent" + ): + agent.run( + user_id="u_summary", + agent_version="v1", + extractor_name="food", + extraction_criteria="Extract food preferences.", + sessions_text="User: I love sushi", + ) + + summary = [ + r for r in caplog.records if r.getMessage().startswith("extraction_agent[") + ] + assert len(summary) == 1, ( + f"Expected 1 summary line, got: {[r.getMessage() for r in summary]}" + ) + msg = summary[0].getMessage() + assert "elapsed_ms=" in msg + assert "turns=" in msg + assert "tools={" in msg + assert "outcome=" in msg + assert "applied=" in msg + assert "violations=" in msg + assert "usage={" in msg + + +def test_extraction_agent_threads_request_id_into_profile( + temp_storage, prompt_manager, llm_client +): + """request_id passed to agent.run lands in stored UserProfile.generated_from_request_id. + + Recall@K-style downstream consumers depend on this thread to translate + retrieved profiles back to their source publish_interaction request. + A regression here silently breaks per-session provenance for the agentic + backend. + """ + llm_client.generate_chat_response.side_effect = [ + _mk_tool_response( + [ + _mk_tool_call( + "c1", + "search_user_profiles", + {"query": "food", "top_k": 10}, + ) + ] + ), + _mk_tool_response( + [ + _mk_tool_call( + "c2", + "create_user_profile", + { + "content": "user likes sushi", + "ttl": "infinity", + "source_span": "I love sushi", + }, + ) + ] + ), + _mk_tool_response([_mk_tool_call("c3", "finish", {})]), + ] + + agent = ExtractionAgent( + client=llm_client, storage=temp_storage, prompt_manager=prompt_manager + ) + agent.run( + user_id="u_rid", + agent_version="v1", + extractor_name="default", + extraction_criteria="x", + sessions_text="User: I love sushi", + request_id="test-rid-abc", + ) + + profiles = temp_storage.get_user_profile("u_rid") + assert len(profiles) == 1 + assert profiles[0].generated_from_request_id == "test-rid-abc" + + +def test_extraction_agent_threads_request_id_into_playbook( + temp_storage, prompt_manager, llm_client +): + """request_id also lands on UserPlaybook.request_id (mirror of profile thread).""" + llm_client.generate_chat_response.side_effect = [ + _mk_tool_response( + [ + _mk_tool_call( + "c1", + "search_user_playbooks", + {"query": "rules", "top_k": 10}, + ) + ] + ), + _mk_tool_response( + [ + _mk_tool_call( + "c2", + "create_user_playbook", + { + "trigger": "When user asks about food", + "content": "- Note that user likes sushi.", + "rationale": "User preference", + "source_span": "I love sushi", + }, + ) + ] + ), + _mk_tool_response([_mk_tool_call("c3", "finish", {})]), + ] + + from reflexio.server.services.extraction.tools import PLAYBOOK_EXTRACTION_TOOLS + + agent = ExtractionAgent( + client=llm_client, + storage=temp_storage, + prompt_manager=prompt_manager, + registry=PLAYBOOK_EXTRACTION_TOOLS, + ) + agent.run( + user_id="u_rid_pb", + agent_version="v1", + extractor_name="default", + extraction_criteria="Extract behavioural rules.", + sessions_text="User: I love sushi", + extraction_kind="UserPlaybook", + request_id="test-rid-pb", + ) + + playbooks = temp_storage.get_user_playbooks(user_id="u_rid_pb") + assert len(playbooks) == 1 + assert playbooks[0].request_id == "test-rid-pb" + + +def test_extraction_agent_request_id_default_is_empty_string( + temp_storage, prompt_manager, llm_client +): + """Backward compat: callers that omit request_id get '' on the profile. + + Existing test callers (and any historical deployments) must keep + working without code changes. + """ + llm_client.generate_chat_response.side_effect = [ + _mk_tool_response( + [_mk_tool_call("c1", "search_user_profiles", {"query": "x", "top_k": 10})] + ), + _mk_tool_response( + [ + _mk_tool_call( + "c2", + "create_user_profile", + { + "content": "fact", + "ttl": "infinity", + "source_span": "x", + }, + ) + ] + ), + _mk_tool_response([_mk_tool_call("c3", "finish", {})]), + ] + + agent = ExtractionAgent( + client=llm_client, storage=temp_storage, prompt_manager=prompt_manager + ) + agent.run( + user_id="u_default", + agent_version="v1", + extractor_name="default", + extraction_criteria="x", + sessions_text="User: x", + ) + + profiles = temp_storage.get_user_profile("u_default") + assert len(profiles) == 1 + assert profiles[0].generated_from_request_id == "" diff --git a/tests/server/services/extraction/test_invariants.py b/tests/server/services/extraction/test_invariants.py new file mode 100644 index 00000000..f970444b --- /dev/null +++ b/tests/server/services/extraction/test_invariants.py @@ -0,0 +1,287 @@ +"""Unit tests for plan-level invariants. Pure-function — no LLM, no storage.""" + +from reflexio.server.services.extraction.invariants import ( + inv_A_search_before_create, + inv_B_delete_known_id, + inv_D_plan_size_cap, + inv_F_no_duplicate_deletes, + inv_J_scope_match, +) +from reflexio.server.services.extraction.plan import ( + CreateUserPlaybookOp, + CreateUserProfileOp, + DeleteUserPlaybookOp, + DeleteUserProfileOp, + ExtractionCtx, +) + + +def _mk_ctx(**kw): + return ExtractionCtx(user_id="u_1", agent_version="v1", **kw) + + +# --- Invariant A: search-before-create --- + + +def test_inv_A_empty_plan_no_violations(): # noqa: N802 + assert inv_A_search_before_create(_mk_ctx()) == [] + + +def test_inv_A_create_with_no_search_violates(): # noqa: N802 + ctx = _mk_ctx(search_count=0) + ctx.plan.append(CreateUserProfileOp(content="x", ttl="infinity", source_span="y")) + v = inv_A_search_before_create(ctx) + assert len(v) == 1 + assert v[0].code == "A" + assert v[0].affected_op_indices == [0] + + +def test_inv_A_create_after_search_ok(): # noqa: N802 + ctx = _mk_ctx(search_count=1) + ctx.plan.append(CreateUserProfileOp(content="x", ttl="infinity", source_span="y")) + assert inv_A_search_before_create(ctx) == [] + + +def test_inv_A_multiple_creates_all_flagged_when_no_search(): # noqa: N802 + ctx = _mk_ctx(search_count=0) + ctx.plan.append(CreateUserProfileOp(content="a", ttl="infinity", source_span="s")) + ctx.plan.append(CreateUserPlaybookOp(trigger="t", content="c", source_span="s")) + v = inv_A_search_before_create(ctx) + assert len(v) == 1 + assert v[0].affected_op_indices == [0, 1] + + +# --- Invariant B: delete-references-known-id --- + + +def test_inv_B_delete_of_unknown_id_violates(): # noqa: N802 + ctx = _mk_ctx() + ctx.plan.append(DeleteUserProfileOp(id="p_999")) + v = inv_B_delete_known_id(ctx) + assert len(v) == 1 + assert v[0].code == "B" + assert v[0].affected_op_indices == [0] + + +def test_inv_B_delete_of_searched_id_ok(): # noqa: N802 + ctx = _mk_ctx() + ctx.known_ids.add("p_123") + ctx.plan.append(DeleteUserProfileOp(id="p_123")) + assert inv_B_delete_known_id(ctx) == [] + + +def test_inv_B_delete_of_in_plan_tentative_id_ok(): # noqa: N802 + """Self-correction: delete an id issued earlier in the same plan.""" + ctx = _mk_ctx() + ctx.known_ids.add("tentative_0") # the handler adds this when create_* runs + ctx.plan.append(CreateUserProfileOp(content="x", ttl="infinity", source_span="s")) + ctx.plan.append(DeleteUserProfileOp(id="tentative_0")) + assert inv_B_delete_known_id(ctx) == [] + + +def test_inv_B_playbook_delete_of_unknown_id_violates(): # noqa: N802 + ctx = _mk_ctx() + ctx.plan.append(DeleteUserPlaybookOp(id="pb_999")) + v = inv_B_delete_known_id(ctx) + assert v[0].affected_op_indices == [0] + + +# --- Invariant D: plan-size cap --- + + +def test_inv_D_under_cap_ok(): # noqa: N802 + ctx = _mk_ctx() + ctx.known_ids.add("tentative_0") + for _ in range(30): + ctx.plan.append( + CreateUserProfileOp(content="x", ttl="infinity", source_span="y") + ) + assert inv_D_plan_size_cap(ctx) == [] + + +def test_inv_D_over_cap_flags_overflow(): # noqa: N802 + ctx = _mk_ctx() + for _ in range(35): + ctx.plan.append( + CreateUserProfileOp(content="x", ttl="infinity", source_span="y") + ) + v = inv_D_plan_size_cap(ctx) + assert len(v) == 1 + assert v[0].affected_op_indices == list(range(30, 35)) + + +# --- Invariant F: no-duplicate-deletes --- + + +def test_inv_F_duplicate_delete_flagged(): # noqa: N802 + ctx = _mk_ctx() + ctx.known_ids.add("p_1") + ctx.plan.append(DeleteUserProfileOp(id="p_1")) + ctx.plan.append(DeleteUserProfileOp(id="p_1")) + v = inv_F_no_duplicate_deletes(ctx) + assert len(v) == 1 + # second (later) occurrence is the one we drop + assert v[0].affected_op_indices == [1] + + +def test_inv_F_distinct_deletes_ok(): # noqa: N802 + ctx = _mk_ctx() + ctx.known_ids.update({"p_1", "p_2"}) + ctx.plan.append(DeleteUserProfileOp(id="p_1")) + ctx.plan.append(DeleteUserProfileOp(id="p_2")) + assert inv_F_no_duplicate_deletes(ctx) == [] + + +# --- Invariant J: scope-match (placeholder for storage-layer guard) --- + + +def test_inv_J_returns_empty_for_v1(): # noqa: N802 + """J is enforced primarily at storage layer (user_id injection). + v1 invariant returns empty — future cross-user-check scaffolding.""" + ctx = _mk_ctx() + assert inv_J_scope_match(ctx) == [] + + +from unittest.mock import MagicMock + +from reflexio.server.services.extraction.invariants import ( + commit_plan, + inv_E_no_duplicate_creates, + inv_H_source_span_present, + inv_K_deletes_without_creates, + resolve_tentative_oscillations, +) + +# --- Soft invariants --- + + +def test_inv_E_identical_creates_flagged(): # noqa: N802 + ctx = _mk_ctx(search_count=1) + ctx.plan.append( + CreateUserProfileOp(content="user is a PM", ttl="infinity", source_span="s") + ) + ctx.plan.append( + CreateUserProfileOp(content="user is a PM", ttl="infinity", source_span="s") + ) + v = inv_E_no_duplicate_creates(ctx) + assert len(v) == 1 + assert v[0].severity == "soft" + assert v[0].code == "E" + + +def test_inv_H_empty_source_span_is_caught_at_schema_level(): # noqa: N802 + """source_span is schema-required non-empty; this invariant is a + secondary log guard if future schema changes relax that.""" + ctx = _mk_ctx(search_count=1) + # construct op with non-empty source_span — schema enforces min_length=1 + ctx.plan.append(CreateUserProfileOp(content="x", ttl="infinity", source_span=" ")) + v = inv_H_source_span_present(ctx) + assert len(v) == 1 + assert v[0].code == "H" + assert v[0].severity == "soft" + + +def test_inv_K_deletes_only_flagged(): # noqa: N802 + ctx = _mk_ctx() + ctx.known_ids.add("p_1") + ctx.plan.append(DeleteUserProfileOp(id="p_1")) + v = inv_K_deletes_without_creates(ctx) + assert len(v) == 1 + assert v[0].severity == "soft" + + +def test_inv_K_delete_plus_create_ok(): # noqa: N802 + ctx = _mk_ctx(search_count=1) + ctx.known_ids.add("p_1") + ctx.plan.append(DeleteUserProfileOp(id="p_1")) + ctx.plan.append(CreateUserProfileOp(content="x", ttl="infinity", source_span="y")) + assert inv_K_deletes_without_creates(ctx) == [] + + +# --- commit_plan orchestrator --- + + +def test_commit_plan_applies_valid_ops(): # noqa: N802 + """With no violations, every op reaches storage.""" + ctx = _mk_ctx(search_count=1) + ctx.known_ids.add("p_exists") + ctx.plan.append(DeleteUserProfileOp(id="p_exists")) + ctx.plan.append( + CreateUserProfileOp(content="new", ttl="infinity", source_span="evidence") + ) + + storage = MagicMock() + result = commit_plan(ctx, storage, outcome="finish_tool") + + assert len(result.applied) == 2 + assert result.outcome == "finish_tool" + assert result.violations == [] + + +def test_commit_plan_drops_hard_violation_ops(): # noqa: N802 + """Hard-invariant-violating ops are excluded from apply.""" + ctx = _mk_ctx(search_count=0) + # create without prior search → invariant A + ctx.plan.append(CreateUserProfileOp(content="x", ttl="infinity", source_span="y")) + # delete of unknown id → invariant B + ctx.plan.append(DeleteUserProfileOp(id="never_retrieved")) + + storage = MagicMock() + result = commit_plan(ctx, storage, outcome="finish_tool") + + assert result.applied == [] + codes = {v.code for v in result.violations} + assert {"A", "B"}.issubset(codes) + + +def test_commit_plan_keeps_soft_violation_ops(): # noqa: N802 + """Soft violations are logged but ops commit.""" + ctx = _mk_ctx(search_count=1) + ctx.plan.append(DeleteUserProfileOp(id="p_1")) + ctx.known_ids.add("p_1") + + storage = MagicMock() + result = commit_plan(ctx, storage, outcome="finish_tool") + + assert len(result.applied) == 1 # the delete got applied + assert any(v.code == "K" for v in result.violations) # but K flagged it + + +# --- resolve_tentative_oscillations --- + + +def test_resolve_oscillation_cancels_matching_pair(): # noqa: N802 + """Create at index 0 + delete targeting tentative::profile::0 cancel each other.""" + plan = [ + CreateUserProfileOp(content="x", ttl="infinity", source_span="y"), + DeleteUserProfileOp(id="tentative::profile::0"), + CreateUserProfileOp(content="real", ttl="infinity", source_span="z"), + ] + assert resolve_tentative_oscillations(plan) == {0, 1} + + +def test_resolve_oscillation_ignores_real_id_delete(): # noqa: N802 + """Delete of a non-tentative id is not touched by the resolver.""" + plan = [ + CreateUserProfileOp(content="x", ttl="infinity", source_span="y"), + DeleteUserProfileOp(id="p_real_uuid_123"), + ] + assert resolve_tentative_oscillations(plan) == set() + + +def test_resolve_oscillation_unmatched_tentative_delete_passes_through(): # noqa: N802 + """Delete of a tentative id that doesn't match any create — resolver ignores it. + Invariant B will catch it separately if it's truly unknown.""" + plan = [ + DeleteUserProfileOp(id="tentative::profile::99"), + ] + assert resolve_tentative_oscillations(plan) == set() + + +def test_resolve_oscillation_user_playbook_pair(): # noqa: N802 + """Same oscillation-cancel logic applies to user_playbook creates/deletes.""" + plan = [ + CreateUserPlaybookOp(trigger="t", content="c", source_span="s"), + DeleteUserPlaybookOp(id="tentative::user_playbook::0"), + ] + assert resolve_tentative_oscillations(plan) == {0, 1} diff --git a/tests/server/services/extraction/test_plan.py b/tests/server/services/extraction/test_plan.py new file mode 100644 index 00000000..8679a19d --- /dev/null +++ b/tests/server/services/extraction/test_plan.py @@ -0,0 +1,91 @@ +"""Unit tests for PlanOp types + ExtractionCtx.""" + +import pytest +from pydantic import ValidationError + +from reflexio.server.services.extraction.plan import ( + CommitResult, + CreateUserPlaybookOp, + CreateUserProfileOp, + DeleteUserPlaybookOp, + DeleteUserProfileOp, + ExtractionCtx, + Violation, +) + + +def test_create_user_profile_op_requires_content_ttl_source_span(): + op = CreateUserProfileOp( + content="user likes pasta", + ttl="infinity", + source_span="I love pasta", + ) + assert op.content == "user likes pasta" + assert op.ttl == "infinity" + assert op.source_span == "I love pasta" + + +def test_create_user_profile_op_rejects_empty_content(): + with pytest.raises(ValidationError): + CreateUserProfileOp(content="", ttl="infinity", source_span="evidence") + + +def test_create_user_profile_op_rejects_invalid_ttl(): + with pytest.raises(ValidationError): + CreateUserProfileOp( + content="x", + ttl="two_days", # type: ignore[arg-type] + source_span="y", # not in ProfileTimeToLive + ) + + +def test_delete_user_profile_op_requires_id(): + op = DeleteUserProfileOp(id="p_42") + assert op.id == "p_42" + with pytest.raises(ValidationError): + DeleteUserProfileOp(id="") + + +def test_create_user_playbook_op_fields(): + op = CreateUserPlaybookOp( + trigger="code help", + content="show examples", + rationale="user prefers examples", + strength="soft", + source_span="…", + ) + assert op.strength == "soft" + + +def test_create_user_playbook_op_rejects_bad_strength(): + with pytest.raises(ValidationError): + CreateUserPlaybookOp( + trigger="t", + content="c", + rationale="r", + strength="weak", # type: ignore[arg-type] + source_span="s", + ) + + +def test_delete_user_playbook_op_requires_id(): + op = DeleteUserPlaybookOp(id="pb_7") + assert op.id == "pb_7" + + +def test_extraction_ctx_defaults(): + ctx = ExtractionCtx(user_id="u_1", agent_version="v1") + assert ctx.user_id == "u_1" + assert ctx.agent_version == "v1" + assert ctx.plan == [] + assert ctx.known_ids == set() + assert ctx.search_count == 0 + assert ctx.finished is False + + +def test_violation_and_commit_result_shapes(): + v = Violation(code="A", severity="hard", affected_op_indices=[0, 2], msg="x") + assert v.severity == "hard" + r = CommitResult(applied=[], violations=[v], outcome="finish_tool") + assert r.outcome == "finish_tool" + assert len(r.violations) == 1 diff --git a/tests/server/services/extraction/test_tools.py b/tests/server/services/extraction/test_tools.py new file mode 100644 index 00000000..708df6fc --- /dev/null +++ b/tests/server/services/extraction/test_tools.py @@ -0,0 +1,447 @@ +"""Unit tests for atomic tool handlers. Uses in-memory SQLite storage — no LLM.""" + +import pytest + +from reflexio.models.api_schema.domain.entities import UserPlaybook, UserProfile +from reflexio.models.api_schema.domain.enums import ProfileTimeToLive +from reflexio.server.services.extraction.plan import ExtractionCtx +from reflexio.server.services.extraction.tools import ( + ReadSessionTextArgs, + GetUserProfileArgs, + SearchAgentPlaybooksArgs, + SearchUserPlaybooksArgs, + SearchUserProfilesArgs, + _handle_read_session_text, + _handle_get_user_profile, + _handle_search_agent_playbooks, + _handle_search_user_playbooks, + _handle_search_user_profiles, +) + + +@pytest.fixture +def seeded_storage(tmp_path): + """SQLite storage seeded with one profile and one user playbook.""" + from reflexio.server.services.storage.sqlite_storage import SQLiteStorage + + storage = SQLiteStorage("test_org", db_path=str(tmp_path / "test.db")) + storage.add_user_profile( + "u_1", + [ + UserProfile( + user_id="u_1", + profile_id="p_10", + content="user likes Italian food", + profile_time_to_live=ProfileTimeToLive.INFINITY, + last_modified_timestamp=1_700_000_000, + expiration_timestamp=4102444800, + source="test", + generated_from_request_id="req_test", + ) + ], + ) + storage.save_user_playbooks( + [ + UserPlaybook( + user_playbook_id=0, + user_id="u_1", + agent_version="v1", + request_id="r_1", + playbook_name="coding", + content="show code examples", + trigger="user asks for help", + ) + ] + ) + return storage + + +@pytest.fixture +def ctx(): + return ExtractionCtx(user_id="u_1", agent_version="v1", extractor_name="coding") + + +def test_search_user_profiles_populates_known_ids(seeded_storage, ctx): + result = _handle_search_user_profiles( + SearchUserProfilesArgs(query="Italian food", top_k=10), + seeded_storage, + ctx, + ) + assert "hits" in result + assert ctx.search_count == 1 + # Every hit's id must be added to ctx.known_ids — that's the side + # effect this test name claims to validate. + hit_ids = {hit["id"] for hit in result["hits"]} + assert hit_ids, "expected at least one hit from seeded storage" + assert hit_ids.issubset(ctx.known_ids) + + +def test_search_user_profiles_empty_result(seeded_storage, ctx): + result = _handle_search_user_profiles( + SearchUserProfilesArgs(query="quantum mechanics", top_k=10), + seeded_storage, + ctx, + ) + assert ctx.search_count == 1 + assert "hits" in result + + +def test_get_user_profile_populates_known_ids_when_found(seeded_storage, ctx): + result = _handle_get_user_profile( + GetUserProfileArgs(id="p_10"), seeded_storage, ctx + ) + assert "profile" in result + assert result["profile"]["id"] == "p_10" + assert "p_10" in ctx.known_ids + # get does NOT bump search_count + assert ctx.search_count == 0 + + +def test_get_user_profile_not_found(seeded_storage, ctx): + result = _handle_get_user_profile( + GetUserProfileArgs(id="p_nonexistent"), seeded_storage, ctx + ) + assert result == {"error": "not found"} + assert "p_nonexistent" not in ctx.known_ids + + +def test_search_user_playbooks_populates_known_ids(seeded_storage, ctx): + result = _handle_search_user_playbooks( + SearchUserPlaybooksArgs(query="code examples", top_k=10), + seeded_storage, + ctx, + ) + assert "hits" in result + assert ctx.search_count == 1 + hit_ids = {hit["id"] for hit in result["hits"]} + assert hit_ids, "expected at least one hit from seeded storage" + assert hit_ids.issubset(ctx.known_ids) + + +def test_search_agent_playbooks_bumps_search_count(seeded_storage, ctx): + result = _handle_search_agent_playbooks( + SearchAgentPlaybooksArgs(query="x", top_k=10), seeded_storage, ctx + ) + assert "hits" in result + assert ctx.search_count == 1 + + +def test_top_k_capped_server_side(seeded_storage, ctx): + """Server-side cap (25) prevents unbounded requests.""" + # top_k=1000 should be capped before reaching storage; best-effort check is + # that the call succeeds without error and returns within cap. + result = _handle_search_user_profiles( + SearchUserProfilesArgs(query="x", top_k=1000), + seeded_storage, + ctx, + ) + assert "hits" in result + + +def test_read_session_text_returns_error_when_api_missing(): + """If storage doesn't have get_interactions_by_session, handler returns error.""" + from unittest.mock import MagicMock + + mock_storage = MagicMock( + spec=["search_user_profile"] + ) # no get_interactions_by_session + # Purposefully does NOT have get_interactions_by_session attr + del mock_storage.get_interactions_by_session # ensure AttributeError on access + ctx = ExtractionCtx(user_id="u", agent_version="v") + result = _handle_read_session_text( + ReadSessionTextArgs(session_id="s", span="x"), + mock_storage, + ctx, + ) + assert "error" in result + + +# --- Mutating handlers --- + +from reflexio.server.services.extraction.plan import ( + CreateUserPlaybookOp, + CreateUserProfileOp, + DeleteUserPlaybookOp, + DeleteUserProfileOp, +) +from reflexio.server.services.extraction.tools import ( + CreateUserPlaybookArgs, + CreateUserProfileArgs, + DeleteUserPlaybookArgs, + DeleteUserProfileArgs, + _handle_create_user_playbook, + _handle_create_user_profile, + _handle_delete_user_playbook, + _handle_delete_user_profile, + apply_plan_op, +) + + +def test_create_user_profile_appends_plan_no_storage_write(seeded_storage, ctx): + result = _handle_create_user_profile( + CreateUserProfileArgs( + content="user prefers dark mode", ttl="infinity", source_span="I use dark" + ), + seeded_storage, + ctx, + ) + assert "tentative_id" in result + assert "op_idx" in result + assert len(ctx.plan) == 1 + assert isinstance(ctx.plan[0], CreateUserProfileOp) + # Storage unchanged — was 1 seeded profile, still 1 + assert len(seeded_storage.get_user_profile("u_1")) == 1 + + +def test_create_user_profile_adds_tentative_id_to_known_ids(seeded_storage, ctx): + r = _handle_create_user_profile( + CreateUserProfileArgs(content="x", ttl="infinity", source_span="y"), + seeded_storage, + ctx, + ) + tid = r["tentative_id"] + assert tid in ctx.known_ids # self-correction via delete becomes possible + + +def test_delete_user_profile_appends_plan(seeded_storage, ctx): + ctx.known_ids.add("p_10") + result = _handle_delete_user_profile( + DeleteUserProfileArgs(id="p_10"), seeded_storage, ctx + ) + assert len(ctx.plan) == 1 + assert isinstance(ctx.plan[0], DeleteUserProfileOp) + assert result["op_idx"] == 0 + # Storage unchanged + assert len(seeded_storage.get_user_profile("u_1")) == 1 + + +def test_create_user_playbook_appends_plan(seeded_storage, ctx): + _handle_create_user_playbook( + CreateUserPlaybookArgs( + trigger="on review", + content="suggest refactor", + source_span="evidence", + ), + seeded_storage, + ctx, + ) + assert isinstance(ctx.plan[0], CreateUserPlaybookOp) + + +def test_delete_user_playbook_appends_plan(seeded_storage, ctx): + ctx.known_ids.add("pb_5") + _handle_delete_user_playbook(DeleteUserPlaybookArgs(id="pb_5"), seeded_storage, ctx) + assert isinstance(ctx.plan[0], DeleteUserPlaybookOp) + + +# --- apply_plan_op --- + + +def test_apply_plan_op_create_user_profile_calls_add(seeded_storage, ctx): + op = CreateUserProfileOp( + content="user loves hiking", ttl="infinity", source_span="I hike weekly" + ) + before = len(seeded_storage.get_user_profile("u_1")) + apply_plan_op(op, seeded_storage, ctx) + assert len(seeded_storage.get_user_profile("u_1")) == before + 1 + + +def test_apply_plan_op_delete_user_profile_removes_record(seeded_storage, ctx): + # Verify p_10 exists + assert any(p.profile_id == "p_10" for p in seeded_storage.get_user_profile("u_1")) + op = DeleteUserProfileOp(id="p_10") + apply_plan_op(op, seeded_storage, ctx) + remaining = [p.profile_id for p in seeded_storage.get_user_profile("u_1")] + assert "p_10" not in remaining + + +def test_apply_plan_op_create_profile_computes_expiration_from_ttl(tmp_path): + """Bug regression: profile_time_to_live must be consistent with expiration_timestamp.""" + from reflexio.models.api_schema.domain.entities import NEVER_EXPIRES_TIMESTAMP + from reflexio.models.api_schema.domain.enums import ProfileTimeToLive + from reflexio.server.services.extraction.plan import ( + CreateUserProfileOp, + ExtractionCtx, + ) + from reflexio.server.services.extraction.tools import apply_plan_op + from reflexio.server.services.storage.sqlite_storage import SQLiteStorage + + storage = SQLiteStorage(org_id="test-org", db_path=str(tmp_path / "t.db")) + ctx = ExtractionCtx(user_id="u_1", agent_version="v1") + + op = CreateUserProfileOp(content="x", ttl="one_week", source_span="y") + apply_plan_op(op, storage, ctx) + + profiles = storage.get_user_profile("u_1") + assert len(profiles) == 1 + p = profiles[0] + assert p.profile_time_to_live == ProfileTimeToLive.ONE_WEEK + assert p.expiration_timestamp != NEVER_EXPIRES_TIMESTAMP + assert p.expiration_timestamp > p.last_modified_timestamp + # one_week is 7 days = 604800 seconds + assert p.expiration_timestamp - p.last_modified_timestamp == 604800 + + +def test_apply_plan_op_create_profile_infinity_ttl_uses_sentinel(tmp_path): + """An 'infinity' TTL should still produce NEVER_EXPIRES_TIMESTAMP.""" + from reflexio.models.api_schema.domain.entities import NEVER_EXPIRES_TIMESTAMP + from reflexio.server.services.extraction.plan import ( + CreateUserProfileOp, + ExtractionCtx, + ) + from reflexio.server.services.extraction.tools import apply_plan_op + from reflexio.server.services.storage.sqlite_storage import SQLiteStorage + + storage = SQLiteStorage(org_id="test-org", db_path=str(tmp_path / "t.db")) + ctx = ExtractionCtx(user_id="u_1", agent_version="v1") + op = CreateUserProfileOp(content="x", ttl="infinity", source_span="y") + apply_plan_op(op, storage, ctx) + p = storage.get_user_profile("u_1")[0] + assert p.expiration_timestamp == NEVER_EXPIRES_TIMESTAMP + + +# ==================================================================== +# Registry tests +# ==================================================================== + +from reflexio.server.services.extraction.tools import ( + EXTRACTION_TOOLS, + PLAYBOOK_EXTRACTION_TOOLS, + PROFILE_EXTRACTION_TOOLS, + SEARCH_TOOLS, +) + + +def test_extraction_registry_has_all_tools(): + specs = {t["function"]["name"] for t in EXTRACTION_TOOLS.openai_specs()} + # EXTRACTION_TOOLS is the backward-compat union of all four create/delete tools + # plus the full read surface (including agent-playbook and session-excerpt tools). + assert specs == { + "search_user_profiles", + "get_user_profile", + "create_user_profile", + "delete_user_profile", + "search_user_playbooks", + "get_user_playbook", + "create_user_playbook", + "delete_user_playbook", + "search_agent_playbooks", + "get_agent_playbook", + "read_session_text", + "finish", + } + + +def test_profile_extraction_registry_excludes_playbook_mutations(): + """PROFILE_EXTRACTION_TOOLS must not expose create/delete_user_playbook.""" + specs = {t["function"]["name"] for t in PROFILE_EXTRACTION_TOOLS.openai_specs()} + assert "create_user_profile" in specs + assert "delete_user_profile" in specs + assert "create_user_playbook" not in specs + assert "delete_user_playbook" not in specs + assert "finish" in specs + + +def test_playbook_extraction_registry_excludes_profile_mutations(): + """PLAYBOOK_EXTRACTION_TOOLS must not expose create/delete_user_profile.""" + specs = {t["function"]["name"] for t in PLAYBOOK_EXTRACTION_TOOLS.openai_specs()} + assert "create_user_playbook" in specs + assert "delete_user_playbook" in specs + assert "create_user_profile" not in specs + assert "delete_user_profile" not in specs + assert "finish" in specs + + +def test_search_registry_is_read_only(): + specs = {t["function"]["name"] for t in SEARCH_TOOLS.openai_specs()} + assert specs == { + "search_user_profiles", + "get_user_profile", + "rerank_user_profiles", + "storage_stats", + "search_user_playbooks", + "get_user_playbook", + "search_agent_playbooks", + "get_agent_playbook", + "read_session_text", + "finish", + } + # No mutations allowed in search + assert "create_user_profile" not in specs + assert "delete_user_profile" not in specs + + +# ==================================================================== +# Query-embedding plumbing for HYBRID search mode +# ==================================================================== + +from unittest.mock import MagicMock # noqa: E402 + +from reflexio.server.services.extraction.tools import _maybe_embed_query # noqa: E402 + + +def test_maybe_embed_query_returns_none_when_storage_has_no_embedder(): + """Disk/local storage backends that don't expose _get_embedding should + gracefully produce None rather than raising.""" + assert _maybe_embed_query(object(), "anything") is None + + +def test_maybe_embed_query_returns_none_when_embedder_raises(): + """Embedder failures must not break search — fall back to FTS via None.""" + storage = MagicMock() + storage._get_embedding.side_effect = RuntimeError("provider down") + assert _maybe_embed_query(storage, "anything") is None + + +def test_maybe_embed_query_returns_embedding_when_supported(): + storage = MagicMock() + storage._get_embedding.return_value = [0.1, 0.2, 0.3] + assert _maybe_embed_query(storage, "sushi") == [0.1, 0.2, 0.3] + storage._get_embedding.assert_called_once_with("sushi") + + +def test_search_user_profiles_passes_query_embedding(): + """Profile search handler must compute + pass a query embedding so + storage doesn't downgrade HYBRID to FTS (regression for the + 'no query embedding provided — falling back to FTS' warning).""" + storage = MagicMock() + storage._get_embedding.return_value = [0.1, 0.2, 0.3] + storage.search_user_profile.return_value = [] + ctx = ExtractionCtx(user_id="u_1", agent_version="v1") + args = SearchUserProfilesArgs(query="sushi", top_k=5) + + _handle_search_user_profiles(args, storage, ctx) + + storage._get_embedding.assert_called_once_with("sushi") + _, kwargs = storage.search_user_profile.call_args + assert kwargs["query_embedding"] == [0.1, 0.2, 0.3] + + +def test_search_user_playbooks_passes_query_embedding_via_options(): + """Playbook search handler wraps the embedding in SearchOptions.""" + storage = MagicMock() + storage._get_embedding.return_value = [0.4, 0.5] + storage.search_user_playbooks.return_value = [] + ctx = ExtractionCtx(user_id="u_1", agent_version="v1") + args = SearchUserPlaybooksArgs(query="code review", top_k=5, status="current") + + _handle_search_user_playbooks(args, storage, ctx) + + storage._get_embedding.assert_called_once_with("code review") + _, kwargs = storage.search_user_playbooks.call_args + assert kwargs["options"].query_embedding == [0.4, 0.5] + + +def test_search_agent_playbooks_passes_query_embedding_via_options(): + """Agent-playbook search handler wraps the embedding in SearchOptions.""" + storage = MagicMock() + storage._get_embedding.return_value = [0.6, 0.7] + storage.search_agent_playbooks.return_value = [] + ctx = ExtractionCtx(user_id="u_1", agent_version="v1") + args = SearchAgentPlaybooksArgs(query="debug approach", top_k=5, status="current") + + _handle_search_agent_playbooks(args, storage, ctx) + + storage._get_embedding.assert_called_once_with("debug approach") + _, kwargs = storage.search_agent_playbooks.call_args + assert kwargs["options"].query_embedding == [0.6, 0.7] diff --git a/tests/server/services/playbook/test_playbook_aggregator.py b/tests/server/services/playbook/test_playbook_aggregator.py index 4558cb0f..2d76e09f 100644 --- a/tests/server/services/playbook/test_playbook_aggregator.py +++ b/tests/server/services/playbook/test_playbook_aggregator.py @@ -1207,3 +1207,31 @@ def test_valid_response_returns_playbook(self): assert result.trigger == "when testing" assert result.content == "do something" assert result.playbook_status == PlaybookStatus.PENDING + + +def test_playbook_aggregation_prompt_specifies_structured_format(): + """Sanity (v2.1.0): aggregator prompt must carry the Agent-Skills + formatting discipline — imperative conditional triggers, markdown bullet + content, one-sentence rationale. Mirrors the extraction prompt v1.4.0 + so the downstream agent sees the same shape across per-user playbooks + and aggregated ones. Guards against silent regression to prose shape.""" + from reflexio.server.prompt.prompt_manager import PromptManager + + pm = PromptManager() + out = pm.render_prompt( + "playbook_aggregation", + variables={ + "user_playbooks": '[1]\nContent: "x"\nTrigger: "y"', + "existing_approved_playbooks": "(none)", + }, + ) + # The Playbook format section must be present. + assert "Playbook format" in out + # Trigger guidance — imperative conditional phrasing + keyword coverage. + assert "imperative conditional phrasing" in out + # Content guidance — markdown bullet list for multi-action policies. + assert "markdown bullet list" in out + # Examples now show bullet-shaped content, not single-sentence prose. + assert "- Ask for CLI preference" in out + # Rationale guidance — one sentence WHY. + assert "one sentence" in out.lower() diff --git a/tests/server/services/playbook/test_playbook_deduplicator.py b/tests/server/services/playbook/test_playbook_deduplicator.py deleted file mode 100644 index 7b3eabf9..00000000 --- a/tests/server/services/playbook/test_playbook_deduplicator.py +++ /dev/null @@ -1,845 +0,0 @@ -"""Tests for playbook deduplication service.""" - -from unittest.mock import MagicMock, patch - -import pytest - -from reflexio.models.api_schema.service_schemas import UserPlaybook -from reflexio.server.services.playbook.playbook_deduplicator import ( - PlaybookDeduplicationDuplicateGroup, - PlaybookDeduplicationOutput, - PlaybookDeduplicator, -) -from reflexio.server.services.playbook.playbook_service_utils import ( - StructuredPlaybookContent, -) - -# =============================== -# Fixtures -# =============================== - - -def _make_user_playbook( - idx: int, - playbook_name: str = "test_fb", - content: str | None = None, - trigger: str | None = None, - source_interaction_ids: list[int] | None = None, - user_playbook_id: int = 0, -) -> UserPlaybook: - """Helper to create a UserPlaybook object for tests.""" - return UserPlaybook( - user_playbook_id=user_playbook_id, - agent_version="v1", - request_id=f"req_{idx}", - playbook_name=playbook_name, - content=content or f"content_{idx}", - trigger=trigger or f"condition_{idx}", - source="test", - source_interaction_ids=source_interaction_ids or [], - ) - - -@pytest.fixture -def mock_deduplicator(): - """Create a PlaybookDeduplicator with mocked dependencies.""" - mock_request_context = MagicMock() - mock_request_context.storage = MagicMock() - mock_request_context.prompt_manager = MagicMock() - mock_request_context.prompt_manager.render_prompt.return_value = "mock prompt" - - mock_llm_client = MagicMock() - - with patch( - "reflexio.server.services.deduplication_utils.SiteVarManager" - ) as mock_svm: - mock_svm.return_value.get_site_var.return_value = { - "default_generation_model_name": "gpt-test" - } - return PlaybookDeduplicator( - request_context=mock_request_context, llm_client=mock_llm_client - ) - - -# =============================== -# Tests for _format_playbooks_with_prefix -# =============================== - - -class TestFormatPlaybooksWithPrefix: - """Tests for _format_playbooks_with_prefix.""" - - def test_single_playbook(self, mock_deduplicator): - """Test formatting a single playbook.""" - fb = _make_user_playbook(0, content="do X when Y") - result = mock_deduplicator._format_playbooks_with_prefix([fb], "NEW") - assert '[NEW-0] Content: "do X when Y"' in result - assert "Name: test_fb" in result - assert "Source: test" in result - - def test_multiple_playbooks(self, mock_deduplicator): - """Test formatting multiple playbooks with incrementing indices.""" - playbooks = [_make_user_playbook(i) for i in range(3)] - result = mock_deduplicator._format_playbooks_with_prefix(playbooks, "EXISTING") - assert "[EXISTING-0]" in result - assert "[EXISTING-1]" in result - assert "[EXISTING-2]" in result - - def test_empty_list(self, mock_deduplicator): - """Test formatting empty list returns '(None)'.""" - result = mock_deduplicator._format_playbooks_with_prefix([], "NEW") - assert result == "(None)" - - -# =============================== -# Tests for _format_new_and_existing_for_prompt -# =============================== - - -class TestFormatNewAndExistingForPrompt: - """Tests for _format_new_and_existing_for_prompt.""" - - def test_formats_both_lists(self, mock_deduplicator): - """Test that new and existing playbooks are formatted with correct prefixes.""" - new_fbs = [_make_user_playbook(0)] - existing_fbs = [_make_user_playbook(1)] - - new_text, existing_text = mock_deduplicator._format_new_and_existing_for_prompt( - new_fbs, existing_fbs - ) - - assert "[NEW-0]" in new_text - assert "[EXISTING-0]" in existing_text - - def test_empty_existing(self, mock_deduplicator): - """Test formatting with empty existing playbooks.""" - new_fbs = [_make_user_playbook(0)] - - new_text, existing_text = mock_deduplicator._format_new_and_existing_for_prompt( - new_fbs, [] - ) - - assert "[NEW-0]" in new_text - assert existing_text == "(None)" - - -# =============================== -# Tests for _retrieve_existing_playbooks -# =============================== - - -class TestRetrieveExistingPlaybooks: - """Tests for _retrieve_existing_playbooks.""" - - def test_with_embeddings(self, mock_deduplicator): - """Test retrieval using embeddings for vector search.""" - new_fb = _make_user_playbook(0, trigger="user asks about billing") - existing_fb = _make_user_playbook( - 1, user_playbook_id=100, trigger="billing inquiry" - ) - - mock_deduplicator.client.get_embeddings.return_value = [[0.1, 0.2, 0.3]] - mock_deduplicator.request_context.storage.search_user_playbooks.return_value = [ - existing_fb - ] - - result = mock_deduplicator._retrieve_existing_playbooks([new_fb]) - - assert len(result) == 1 - assert result[0].user_playbook_id == 100 - mock_deduplicator.client.get_embeddings.assert_called_once() - - def test_fallback_to_text_search(self, mock_deduplicator): - """Test fallback to text-only search when embedding generation fails.""" - new_fb = _make_user_playbook(0) - existing_fb = _make_user_playbook(1, user_playbook_id=200) - - mock_deduplicator.client.get_embeddings.side_effect = Exception("embed error") - mock_deduplicator.request_context.storage.search_user_playbooks.return_value = [ - existing_fb - ] - - result = mock_deduplicator._retrieve_existing_playbooks([new_fb]) - - assert len(result) == 1 - - def test_empty_query_texts(self, mock_deduplicator): - """Test that empty when_condition playbooks return no results.""" - fb = UserPlaybook( - agent_version="v1", - request_id="req1", - playbook_name="test", - content="", - trigger="", - ) - - result = mock_deduplicator._retrieve_existing_playbooks([fb]) - - assert result == [] - - def test_deduplicates_by_id(self, mock_deduplicator): - """Test that duplicate existing playbooks from multiple queries are deduplicated.""" - fb1 = _make_user_playbook(0, trigger="query1") - fb2 = _make_user_playbook(1, trigger="query2") - - shared_existing = _make_user_playbook(99, user_playbook_id=500) - - mock_deduplicator.client.get_embeddings.return_value = [ - [0.1], - [0.2], - ] - mock_deduplicator.request_context.storage.search_user_playbooks.return_value = [ - shared_existing - ] - - result = mock_deduplicator._retrieve_existing_playbooks([fb1, fb2]) - - # Should only appear once despite being returned for both queries - assert len(result) == 1 - - -# =============================== -# Tests for deduplicate -# =============================== - - -class TestDeduplicate: - """Tests for the main deduplicate method.""" - - def test_mock_mode_skips_deduplication(self, mock_deduplicator): - """Test that MOCK_LLM_RESPONSE=true skips deduplication.""" - fb1 = _make_user_playbook(0) - fb2 = _make_user_playbook(1) - - with patch.dict("os.environ", {"MOCK_LLM_RESPONSE": "true"}): - result, delete_ids = mock_deduplicator.deduplicate( - results=[[fb1], [fb2]], request_id="req1", agent_version="v1" - ) - - assert len(result) == 2 - assert delete_ids == [] - - def test_empty_results(self, mock_deduplicator): - """Test deduplication with no playbooks.""" - with patch.dict("os.environ", {"MOCK_LLM_RESPONSE": "false"}): - result, delete_ids = mock_deduplicator.deduplicate( - results=[[]], request_id="req1", agent_version="v1" - ) - - assert result == [] - assert delete_ids == [] - - def test_error_fallback_returns_all(self, mock_deduplicator): - """Test that LLM call error falls back to returning all playbooks.""" - fb = _make_user_playbook(0) - - mock_deduplicator.client.get_embeddings.return_value = [[0.1]] - mock_deduplicator.request_context.storage.search_user_playbooks.return_value = [] - mock_deduplicator.client.generate_chat_response.side_effect = Exception( - "LLM error" - ) - - with patch.dict("os.environ", {"MOCK_LLM_RESPONSE": "false"}): - result, delete_ids = mock_deduplicator.deduplicate( - results=[[fb]], request_id="req1", agent_version="v1" - ) - - assert len(result) == 1 - assert delete_ids == [] - - -# =============================== -# Tests for _build_deduplicated_results -# =============================== - - -class TestBuildDeduplicatedResults: - """Tests for _build_deduplicated_results merge logic.""" - - def test_merge_group_combines_source_interaction_ids(self, mock_deduplicator): - """Test that merged groups combine source_interaction_ids from all playbooks.""" - new_playbooks = [ - _make_user_playbook(0, source_interaction_ids=[1, 2]), - _make_user_playbook(1, source_interaction_ids=[3, 4]), - ] - - dedup_output = PlaybookDeduplicationOutput( - duplicate_groups=[ - PlaybookDeduplicationDuplicateGroup( - item_ids=["NEW-0", "NEW-1"], - merged_content=StructuredPlaybookContent( - content="merged do", trigger="merged when" - ), - reasoning="Same topic", - ) - ], - unique_ids=[], - ) - - result, delete_ids = mock_deduplicator._build_deduplicated_results( - new_playbooks=new_playbooks, - existing_playbooks=[], - dedup_output=dedup_output, - request_id="req1", - agent_version="v1", - ) - - assert len(result) == 1 - assert set(result[0].source_interaction_ids) == {1, 2, 3, 4} - assert delete_ids == [] - - def test_unique_ids_passed_through(self, mock_deduplicator): - """Test that unique NEW playbooks are passed through unchanged.""" - new_playbooks = [ - _make_user_playbook(0), - _make_user_playbook(1), - ] - - dedup_output = PlaybookDeduplicationOutput( - duplicate_groups=[], unique_ids=["NEW-0", "NEW-1"] - ) - - result, _ = mock_deduplicator._build_deduplicated_results( - new_playbooks=new_playbooks, - existing_playbooks=[], - dedup_output=dedup_output, - request_id="req1", - agent_version="v1", - ) - - assert len(result) == 2 - - def test_existing_playbooks_to_delete(self, mock_deduplicator): - """Test that existing playbooks in merge groups are marked for deletion.""" - new_playbooks = [_make_user_playbook(0)] - existing_playbooks = [_make_user_playbook(1, user_playbook_id=999)] - - dedup_output = PlaybookDeduplicationOutput( - duplicate_groups=[ - PlaybookDeduplicationDuplicateGroup( - item_ids=["NEW-0", "EXISTING-0"], - merged_content=StructuredPlaybookContent( - content="merged", trigger="when merged" - ), - reasoning="Duplicate", - ) - ], - unique_ids=[], - ) - - result, delete_ids = mock_deduplicator._build_deduplicated_results( - new_playbooks=new_playbooks, - existing_playbooks=existing_playbooks, - dedup_output=dedup_output, - request_id="req1", - agent_version="v1", - ) - - assert len(result) == 1 - assert 999 in delete_ids - - def test_safety_fallback_unhandled_playbooks(self, mock_deduplicator): - """Test that playbooks not mentioned by LLM are added via safety fallback.""" - new_playbooks = [ - _make_user_playbook(0), - _make_user_playbook(1), - _make_user_playbook(2), - ] - - # LLM only mentions index 0 - dedup_output = PlaybookDeduplicationOutput( - duplicate_groups=[], unique_ids=["NEW-0"] - ) - - result, _ = mock_deduplicator._build_deduplicated_results( - new_playbooks=new_playbooks, - existing_playbooks=[], - dedup_output=dedup_output, - request_id="req1", - agent_version="v1", - ) - - # Index 0 via unique_ids + index 1 and 2 via safety fallback - assert len(result) == 3 - - -# =============================== -# Tests for deduplicate happy path and advanced scenarios -# =============================== - - -class TestDeduplicateHappyPath: - """Tests for the full deduplicate() flow with LLM mocks returning PlaybookDeduplicationOutput.""" - - def test_happy_path_with_duplicates(self, mock_deduplicator): - """Full happy path: LLM returns a merge group and unique playbooks.""" - fb0 = _make_user_playbook(0, content="do X when Y", source_interaction_ids=[10]) - fb1 = _make_user_playbook( - 1, content="do X when Y again", source_interaction_ids=[20] - ) - fb2 = _make_user_playbook(2, content="do Z when W", source_interaction_ids=[30]) - - # No existing playbooks found via search - mock_deduplicator.client.get_embeddings.return_value = [ - [0.1], - [0.2], - [0.3], - ] - mock_deduplicator.request_context.storage.search_user_playbooks.return_value = [] - - # LLM merges fb0 and fb1, keeps fb2 as unique - mock_deduplicator.client.generate_chat_response.return_value = ( - PlaybookDeduplicationOutput( - duplicate_groups=[ - PlaybookDeduplicationDuplicateGroup( - item_ids=["NEW-0", "NEW-1"], - merged_content=StructuredPlaybookContent( - content="do X", trigger="when Y" - ), - reasoning="Same instruction", - ) - ], - unique_ids=["NEW-2"], - ) - ) - - with patch.dict("os.environ", {"MOCK_LLM_RESPONSE": "false"}): - result, delete_ids = mock_deduplicator.deduplicate( - results=[[fb0, fb1], [fb2]], request_id="req_test", agent_version="v1" - ) - - # 1 merged + 1 unique = 2 playbooks - assert len(result) == 2 - assert delete_ids == [] - - # Merged playbook should have combined source_interaction_ids - merged = result[0] - assert set(merged.source_interaction_ids) == {10, 20} - - # Unique playbook should be fb2 - assert result[1].content == "do Z when W" - - def test_multiple_extractor_results_nested_lists(self, mock_deduplicator): - """Multiple extractor results (nested list of lists) are flattened correctly.""" - fb0 = _make_user_playbook(0, content="playbook from extractor 1") - fb1 = _make_user_playbook(1, content="playbook from extractor 2") - fb2 = _make_user_playbook(2, content="playbook from extractor 3") - - mock_deduplicator.client.get_embeddings.return_value = [ - [0.1], - [0.2], - [0.3], - ] - mock_deduplicator.request_context.storage.search_user_playbooks.return_value = [] - - # LLM says all are unique - mock_deduplicator.client.generate_chat_response.return_value = ( - PlaybookDeduplicationOutput( - duplicate_groups=[], unique_ids=["NEW-0", "NEW-1", "NEW-2"] - ) - ) - - with patch.dict("os.environ", {"MOCK_LLM_RESPONSE": "false"}): - result, delete_ids = mock_deduplicator.deduplicate( - results=[[fb0], [fb1], [fb2]], request_id="req_test", agent_version="v1" - ) - - assert len(result) == 3 - assert delete_ids == [] - - def test_all_playbooks_are_duplicates_of_existing(self, mock_deduplicator): - """All new playbooks are duplicates of existing playbooks in the DB.""" - fb0 = _make_user_playbook(0, content="do X when Y", source_interaction_ids=[10]) - existing_fb = _make_user_playbook( - 99, - user_playbook_id=500, - content="do X when Y (existing)", - source_interaction_ids=[5], - ) - - mock_deduplicator.client.get_embeddings.return_value = [[0.1]] - mock_deduplicator.request_context.storage.search_user_playbooks.return_value = [ - existing_fb - ] - - # LLM merges NEW-0 with EXISTING-0 - mock_deduplicator.client.generate_chat_response.return_value = ( - PlaybookDeduplicationOutput( - duplicate_groups=[ - PlaybookDeduplicationDuplicateGroup( - item_ids=["NEW-0", "EXISTING-0"], - merged_content=StructuredPlaybookContent( - content="do X", trigger="when Y" - ), - reasoning="Same instruction as existing", - ) - ], - unique_ids=[], - ) - ) - - with patch.dict("os.environ", {"MOCK_LLM_RESPONSE": "false"}): - result, delete_ids = mock_deduplicator.deduplicate( - results=[[fb0]], request_id="req_test", agent_version="v1" - ) - - # 1 merged playbook replaces both - assert len(result) == 1 - # Existing playbook should be marked for deletion - assert 500 in delete_ids - # Merged playbook should combine source_interaction_ids from both - assert set(result[0].source_interaction_ids) == {5, 10} - - -# =============================== -# Tests for _retrieve_existing_playbooks with user_id filter -# =============================== - - -class TestBuildDeduplicatedResultsEdgeCases: - """Extended tests for _build_deduplicated_results edge cases.""" - - def test_template_fallback_to_existing_playbook(self, mock_deduplicator): - """Test template selection falls back to existing playbook when no NEW in group.""" - existing_playbooks = [ - _make_user_playbook( - 0, - user_playbook_id=100, - playbook_name="existing_fb", - source_interaction_ids=[5], - ), - ] - - # Group only has EXISTING items, no NEW items - dedup_output = PlaybookDeduplicationOutput( - duplicate_groups=[ - PlaybookDeduplicationDuplicateGroup( - item_ids=["EXISTING-0"], - merged_content=StructuredPlaybookContent( - content="merged do", trigger="merged when" - ), - reasoning="Existing-only group", - ) - ], - unique_ids=[], - ) - - result, delete_ids = mock_deduplicator._build_deduplicated_results( - new_playbooks=[], - existing_playbooks=existing_playbooks, - dedup_output=dedup_output, - request_id="req1", - agent_version="v1", - ) - - assert len(result) == 1 - # Template should come from existing playbook - assert result[0].playbook_name == "existing_fb" - assert 100 in delete_ids - - def test_template_fallback_skips_out_of_range_existing(self, mock_deduplicator): - """Test that out-of-range existing indices are skipped in fallback.""" - dedup_output = PlaybookDeduplicationOutput( - duplicate_groups=[ - PlaybookDeduplicationDuplicateGroup( - item_ids=["EXISTING-99"], # out of range - merged_content=StructuredPlaybookContent( - content="merged do", trigger="merged when" - ), - reasoning="Bad index", - ) - ], - unique_ids=[], - ) - - result, delete_ids = mock_deduplicator._build_deduplicated_results( - new_playbooks=[], - existing_playbooks=[], - dedup_output=dedup_output, - request_id="req1", - agent_version="v1", - ) - - # Group should be skipped entirely since no valid template was found - assert len(result) == 0 - assert delete_ids == [] - - def test_source_interaction_ids_combined_from_new_and_existing( - self, mock_deduplicator - ): - """Test that source_interaction_ids are combined from both NEW and EXISTING playbooks.""" - new_playbooks = [ - _make_user_playbook(0, source_interaction_ids=[1, 2]), - ] - existing_playbooks = [ - _make_user_playbook(1, user_playbook_id=100, source_interaction_ids=[3, 4]), - ] - - dedup_output = PlaybookDeduplicationOutput( - duplicate_groups=[ - PlaybookDeduplicationDuplicateGroup( - item_ids=["NEW-0", "EXISTING-0"], - merged_content=StructuredPlaybookContent( - content="merged", trigger="merged condition" - ), - reasoning="Combined", - ) - ], - unique_ids=[], - ) - - result, delete_ids = mock_deduplicator._build_deduplicated_results( - new_playbooks=new_playbooks, - existing_playbooks=existing_playbooks, - dedup_output=dedup_output, - request_id="req1", - agent_version="v1", - ) - - assert len(result) == 1 - assert set(result[0].source_interaction_ids) == {1, 2, 3, 4} - assert 100 in delete_ids - - def test_source_interaction_ids_deduplication(self, mock_deduplicator): - """Test that duplicate source_interaction_ids are not repeated.""" - new_playbooks = [ - _make_user_playbook(0, source_interaction_ids=[1, 2]), - _make_user_playbook(1, source_interaction_ids=[2, 3]), - ] - - dedup_output = PlaybookDeduplicationOutput( - duplicate_groups=[ - PlaybookDeduplicationDuplicateGroup( - item_ids=["NEW-0", "NEW-1"], - merged_content=StructuredPlaybookContent( - content="merged", trigger="merged cond" - ), - reasoning="Overlap IDs", - ) - ], - unique_ids=[], - ) - - result, _ = mock_deduplicator._build_deduplicated_results( - new_playbooks=new_playbooks, - existing_playbooks=[], - dedup_output=dedup_output, - request_id="req1", - agent_version="v1", - ) - - assert len(result) == 1 - # ID 2 should appear only once - assert result[0].source_interaction_ids == [1, 2, 3] - - def test_unhandled_playbooks_safety_net(self, mock_deduplicator): - """Test that playbooks not mentioned in unique_ids or groups are added via safety net.""" - new_playbooks = [ - _make_user_playbook(0), - _make_user_playbook(1), - _make_user_playbook(2), - ] - - # LLM only mentions index 1 as unique, leaves 0 and 2 unmentioned - dedup_output = PlaybookDeduplicationOutput( - duplicate_groups=[], unique_ids=["NEW-1"] - ) - - result, _ = mock_deduplicator._build_deduplicated_results( - new_playbooks=new_playbooks, - existing_playbooks=[], - dedup_output=dedup_output, - request_id="req1", - agent_version="v1", - ) - - assert len(result) == 3 - # Index 1 is from unique_ids, indices 0 and 2 from safety fallback - contents = {fb.content for fb in result} - assert "content_0" in contents - assert "content_1" in contents - assert "content_2" in contents - - def test_invalid_item_ids_are_skipped_in_unique_ids(self, mock_deduplicator): - """Test that unparseable item IDs in unique_ids are skipped.""" - new_playbooks = [_make_user_playbook(0)] - - dedup_output = PlaybookDeduplicationOutput( - duplicate_groups=[], unique_ids=["BADFORMAT", "NEW-0"] - ) - - result, _ = mock_deduplicator._build_deduplicated_results( - new_playbooks=new_playbooks, - existing_playbooks=[], - dedup_output=dedup_output, - request_id="req1", - agent_version="v1", - ) - - # NEW-0 added via unique_ids, BADFORMAT skipped - assert len(result) == 1 - - def test_existing_only_unique_ids_not_added(self, mock_deduplicator): - """Test that EXISTING prefix in unique_ids does not add playbook.""" - new_playbooks = [_make_user_playbook(0)] - - dedup_output = PlaybookDeduplicationOutput( - duplicate_groups=[], unique_ids=["EXISTING-0"] - ) - - result, _ = mock_deduplicator._build_deduplicated_results( - new_playbooks=new_playbooks, - existing_playbooks=[_make_user_playbook(1, user_playbook_id=100)], - dedup_output=dedup_output, - request_id="req1", - agent_version="v1", - ) - - # EXISTING-0 in unique_ids is ignored; NEW-0 added by safety net - contents = {fb.content for fb in result} - assert "content_0" in contents - - def test_out_of_range_new_index_in_unique_ids(self, mock_deduplicator): - """Test that out-of-range NEW index in unique_ids is safely ignored.""" - new_playbooks = [_make_user_playbook(0)] - - dedup_output = PlaybookDeduplicationOutput( - duplicate_groups=[], - unique_ids=["NEW-0", "NEW-99"], # 99 is out of range - ) - - result, _ = mock_deduplicator._build_deduplicated_results( - new_playbooks=new_playbooks, - existing_playbooks=[], - dedup_output=dedup_output, - request_id="req1", - agent_version="v1", - ) - - assert len(result) == 1 - - -class TestFormatItemsForPrompt: - """Tests for _format_items_for_prompt (delegates to _format_playbooks_with_prefix).""" - - def test_delegates_with_new_prefix(self, mock_deduplicator): - """Test that _format_items_for_prompt uses 'NEW' prefix.""" - playbooks = [_make_user_playbook(0)] - result = mock_deduplicator._format_items_for_prompt(playbooks) - assert "[NEW-0]" in result - - def test_empty_list(self, mock_deduplicator): - """Test that empty list returns '(None)'.""" - result = mock_deduplicator._format_items_for_prompt([]) - assert result == "(None)" - - -class TestFormatPlaybooksEdgeCases: - """Edge cases for _format_playbooks_with_prefix.""" - - def test_empty_playbook_name_shows_unknown(self, mock_deduplicator): - """Test that empty playbook_name displays as 'unknown'.""" - fb = UserPlaybook( - user_playbook_id=0, - agent_version="v1", - request_id="req1", - playbook_name="", - content="content", - ) - result = mock_deduplicator._format_playbooks_with_prefix([fb], "NEW") - assert "Name: unknown" in result - - def test_none_source_shows_unknown(self, mock_deduplicator): - """Test that None source displays as 'unknown'.""" - fb = UserPlaybook( - user_playbook_id=0, - agent_version="v1", - request_id="req1", - playbook_name="fb", - content="content", - source=None, - ) - result = mock_deduplicator._format_playbooks_with_prefix([fb], "NEW") - assert "Source: unknown" in result - - -class TestMockModeCheck: - """Tests for mock mode check in deduplicate.""" - - def test_mock_mode_handles_non_list_results(self, mock_deduplicator): - """Test that mock mode isinstance check filters non-list items.""" - fb = _make_user_playbook(0) - - with patch.dict("os.environ", {"MOCK_LLM_RESPONSE": "true"}): - result, delete_ids = mock_deduplicator.deduplicate( - results=[[fb]], request_id="req1", agent_version="v1" - ) - - assert len(result) == 1 - assert delete_ids == [] - - def test_mock_mode_case_insensitive(self, mock_deduplicator): - """Test that mock mode check is case insensitive.""" - fb = _make_user_playbook(0) - - with patch.dict("os.environ", {"MOCK_LLM_RESPONSE": "True"}): - result, delete_ids = mock_deduplicator.deduplicate( - results=[[fb]], request_id="req1", agent_version="v1" - ) - - assert len(result) == 1 - assert delete_ids == [] - - def test_mock_mode_false_proceeds_normally(self, mock_deduplicator): - """Test that mock mode disabled runs full dedup path.""" - mock_deduplicator.client.get_embeddings.return_value = [[0.1]] - mock_deduplicator.request_context.storage.search_user_playbooks.return_value = [] - mock_deduplicator.client.generate_chat_response.return_value = ( - PlaybookDeduplicationOutput(duplicate_groups=[], unique_ids=["NEW-0"]) - ) - - fb = _make_user_playbook(0) - with patch.dict("os.environ", {"MOCK_LLM_RESPONSE": "false"}): - result, _ = mock_deduplicator.deduplicate( - results=[[fb]], request_id="req1", agent_version="v1" - ) - - assert len(result) == 1 - - -class TestRetrieveExistingPlaybooksWithUserId: - """Tests for _retrieve_existing_playbooks with user_id filter.""" - - def test_user_id_passed_to_search(self, mock_deduplicator): - """Test that user_id is passed through to the search request.""" - new_fb = _make_user_playbook(0, trigger="user asks about billing") - existing_fb = _make_user_playbook(1, user_playbook_id=100) - - mock_deduplicator.client.get_embeddings.return_value = [[0.1]] - mock_deduplicator.request_context.storage.search_user_playbooks.return_value = [ - existing_fb - ] - - mock_deduplicator._retrieve_existing_playbooks([new_fb], user_id="user_abc") - - # Verify search was called with user_id in the SearchUserPlaybookRequest - call_args = ( - mock_deduplicator.request_context.storage.search_user_playbooks.call_args - ) - search_request = call_args[0][0] - assert search_request.user_id == "user_abc" - - def test_none_user_id_passed_to_search(self, mock_deduplicator): - """Test that None user_id is passed through correctly.""" - new_fb = _make_user_playbook(0, trigger="some condition") - - mock_deduplicator.client.get_embeddings.return_value = [[0.1]] - mock_deduplicator.request_context.storage.search_user_playbooks.return_value = [] - - mock_deduplicator._retrieve_existing_playbooks([new_fb], user_id=None) - - call_args = ( - mock_deduplicator.request_context.storage.search_user_playbooks.call_args - ) - search_request = call_args[0][0] - assert search_request.user_id is None diff --git a/tests/server/services/playbook/test_structured_playbook_content.py b/tests/server/services/playbook/test_structured_playbook_content.py new file mode 100644 index 00000000..0a31118d --- /dev/null +++ b/tests/server/services/playbook/test_structured_playbook_content.py @@ -0,0 +1,25 @@ +"""Task 2.2: optional source_span/notes/reader_angle on StructuredPlaybookContent.""" + +from reflexio.server.services.playbook.playbook_service_utils import ( + StructuredPlaybookContent, +) + + +def test_structured_playbook_content_new_fields_default_to_none() -> None: + c = StructuredPlaybookContent(trigger="t", content="c", rationale="r") + assert c.source_span is None + assert c.notes is None + assert c.reader_angle is None + + +def test_structured_playbook_content_accepts_optional_fields() -> None: + c = StructuredPlaybookContent( + trigger="t", + content="c", + rationale="r", + source_span="quote", + notes="confidence=0.9", + reader_angle="trigger", + ) + assert c.source_span == "quote" + assert c.reader_angle == "trigger" diff --git a/tests/server/services/profile/test_profile_add_item.py b/tests/server/services/profile/test_profile_add_item.py new file mode 100644 index 00000000..9e618b88 --- /dev/null +++ b/tests/server/services/profile/test_profile_add_item.py @@ -0,0 +1,25 @@ +"""Task 2.1: optional source_span/notes/reader_angle on ProfileAddItem.""" + +from reflexio.server.services.profile.profile_generation_service_utils import ( + ProfileAddItem, +) + + +def test_profile_add_item_new_fields_default_to_none() -> None: + item = ProfileAddItem(content="x", time_to_live="infinity") + assert item.source_span is None + assert item.notes is None + assert item.reader_angle is None + + +def test_profile_add_item_accepts_optional_fields() -> None: + item = ProfileAddItem( + content="x", + time_to_live="infinity", + source_span="exact quote", + notes="high confidence", + reader_angle="facts", + ) + assert item.source_span == "exact quote" + assert item.notes == "high confidence" + assert item.reader_angle == "facts" diff --git a/tests/server/services/profile/test_profile_deduplicator.py b/tests/server/services/profile/test_profile_deduplicator.py deleted file mode 100644 index e118d1a0..00000000 --- a/tests/server/services/profile/test_profile_deduplicator.py +++ /dev/null @@ -1,1331 +0,0 @@ -""" -Unit tests for ProfileDeduplicator. - -Tests the deduplicator's responsibilities for: -- Pydantic output schema validation -- Profile deduplication with LLM and hybrid search -- Profile formatting for prompts -- Building deduplicated results -- Merging custom features -""" - -import uuid -from datetime import UTC, datetime -from unittest.mock import MagicMock, patch - -import pytest - - -# Disable mock mode for deduplicator tests so LLM mocks are actually used -@pytest.fixture(autouse=True) -def disable_mock_llm_response(monkeypatch): - """Disable MOCK_LLM_RESPONSE env var so deduplicator tests use their own mocks.""" - monkeypatch.delenv("MOCK_LLM_RESPONSE", raising=False) - - -from reflexio.models.api_schema.service_schemas import ( - ProfileTimeToLive, - UserProfile, -) -from reflexio.server.llm.litellm_client import LiteLLMClient -from reflexio.server.services.deduplication_utils import parse_item_id -from reflexio.server.services.profile.profile_deduplicator import ( - ProfileDeduplicationOutput, - ProfileDeduplicator, - ProfileDeletionDirective, - ProfileDuplicateGroup, - _format_profile_timestamp, -) - -# =============================== -# Fixtures -# =============================== - - -@pytest.fixture -def mock_llm_client(): - """Create a mock LLM client.""" - client = MagicMock(spec=LiteLLMClient) - client.get_embeddings.return_value = [[0.1] * 10, [0.2] * 10, [0.3] * 10] - return client - - -@pytest.fixture -def mock_request_context(): - """Create a mock request context with prompt manager and storage.""" - context = MagicMock( - spec_set=["prompt_manager", "storage", "configurator", "org_id"] - ) - context.prompt_manager = MagicMock() - context.prompt_manager.render_prompt.return_value = "test prompt" - context.storage = MagicMock() - context.storage.search_user_profile.return_value = [] - # Set up configurator chain for model resolution - mock_config = MagicMock() - mock_config.api_key_config = None - context.configurator.get_config.return_value = mock_config - return context - - -@pytest.fixture -def mock_site_var_manager(): - """Mock the SiteVarManager to return model settings.""" - with patch("reflexio.server.services.deduplication_utils.SiteVarManager") as mock: - instance = mock.return_value - instance.get_site_var.return_value = {"default_generation_model_name": "gpt-4"} - yield mock - - -@pytest.fixture -def sample_profiles(): - """Create sample UserProfile objects for testing.""" - timestamp = int(datetime.now(UTC).timestamp()) - return [ - UserProfile( - profile_id=str(uuid.uuid4()), - user_id="test_user", - content="User prefers dark mode for coding", - last_modified_timestamp=timestamp, - generated_from_request_id="req_1", - profile_time_to_live=ProfileTimeToLive.ONE_MONTH, - source="extractor_a", - ), - UserProfile( - profile_id=str(uuid.uuid4()), - user_id="test_user", - content="User likes dark theme in their IDE", - last_modified_timestamp=timestamp, - generated_from_request_id="req_2", - profile_time_to_live=ProfileTimeToLive.ONE_WEEK, - source="extractor_b", - ), - UserProfile( - profile_id=str(uuid.uuid4()), - user_id="test_user", - content="User is a Python developer", - last_modified_timestamp=timestamp, - generated_from_request_id="req_3", - profile_time_to_live=ProfileTimeToLive.ONE_YEAR, - source="extractor_a", - ), - ] - - -# =============================== -# Test: Pydantic Models -# =============================== - - -class TestPydanticModels: - """Tests for the Pydantic output schema models.""" - - def test_duplicate_group_creation(self): - """Test that ProfileDuplicateGroup can be created with valid data.""" - group = ProfileDuplicateGroup( - item_ids=["NEW-0", "NEW-1", "EXISTING-0"], - merged_content="User prefers dark mode", - merged_time_to_live="one_month", - reasoning="Both profiles are about dark mode preferences", - ) - assert group.item_ids == ["NEW-0", "NEW-1", "EXISTING-0"] - assert group.merged_content == "User prefers dark mode" - assert group.merged_time_to_live == "one_month" - - def test_duplicate_group_forbids_extra_fields(self): - """Test that ProfileDuplicateGroup allows extra fields at runtime (for LLM robustness) - but forbids them in JSON schema (for LLM structured output).""" - # extra="allow" means Pydantic accepts extra fields at runtime - group = ProfileDuplicateGroup( - item_ids=["NEW-0"], - merged_content="test", - merged_time_to_live="one_day", - reasoning="test", - extra_field="not allowed", - ) - assert group.item_ids == ["NEW-0"] - # JSON schema should forbid additional properties (used for LLM structured output) - schema = ProfileDuplicateGroup.model_json_schema() - assert schema.get("additionalProperties") is False - - def test_deduplication_output_creation(self): - """Test that ProfileDeduplicationOutput can be created.""" - output = ProfileDeduplicationOutput( - duplicate_groups=[ - ProfileDuplicateGroup( - item_ids=["NEW-0", "NEW-1"], - merged_content="merged", - merged_time_to_live="one_week", - reasoning="duplicates", - ) - ], - unique_ids=["NEW-2", "NEW-3"], - ) - assert len(output.duplicate_groups) == 1 - assert output.unique_ids == ["NEW-2", "NEW-3"] - - def test_deduplication_output_empty_defaults(self): - """Test that ProfileDeduplicationOutput has empty list defaults.""" - output = ProfileDeduplicationOutput() - assert output.duplicate_groups == [] - assert output.unique_ids == [] - assert output.deletions == [] - - def test_deletion_directive_creation(self): - """Test that ProfileDeletionDirective can be created with valid data.""" - directive = ProfileDeletionDirective( - new_id="NEW-0", - existing_ids=["EXISTING-0", "EXISTING-1"], - reasoning="User asked to forget this topic", - ) - assert directive.new_id == "NEW-0" - assert directive.existing_ids == ["EXISTING-0", "EXISTING-1"] - assert directive.reasoning == "User asked to forget this topic" - - def test_deletion_directive_json_schema_forbids_extra(self): - """Test that ProfileDeletionDirective's JSON schema forbids additional properties.""" - schema = ProfileDeletionDirective.model_json_schema() - assert schema.get("additionalProperties") is False - - def test_deduplication_output_with_deletions(self): - """Test that ProfileDeduplicationOutput accepts deletions.""" - output = ProfileDeduplicationOutput( - duplicate_groups=[], - unique_ids=[], - deletions=[ - ProfileDeletionDirective( - new_id="NEW-0", - existing_ids=["EXISTING-0"], - reasoning="deletion request", - ) - ], - ) - assert len(output.deletions) == 1 - assert output.deletions[0].new_id == "NEW-0" - - def test_deduplication_output_deletions_from_dict(self): - """Test that ProfileDeduplicationOutput with deletions validates from dict.""" - data = { - "duplicate_groups": [], - "unique_ids": [], - "deletions": [ - { - "new_id": "NEW-0", - "existing_ids": ["EXISTING-0"], - "reasoning": "forget request", - } - ], - } - output = ProfileDeduplicationOutput.model_validate(data) - assert len(output.deletions) == 1 - assert output.deletions[0].existing_ids == ["EXISTING-0"] - - def test_deduplication_output_from_dict(self): - """Test that ProfileDeduplicationOutput can be validated from dict.""" - data = { - "duplicate_groups": [ - { - "item_ids": ["NEW-0", "NEW-1", "EXISTING-0"], - "merged_content": "test", - "merged_time_to_live": "one_day", - "reasoning": "reason", - } - ], - "unique_ids": ["NEW-2"], - } - output = ProfileDeduplicationOutput.model_validate(data) - assert len(output.duplicate_groups) == 1 - assert output.unique_ids == ["NEW-2"] - - def test_parse_item_id_valid(self): - """Test parse_item_id with valid inputs.""" - assert parse_item_id("NEW-0") == ("NEW", 0) - assert parse_item_id("EXISTING-1") == ("EXISTING", 1) - assert parse_item_id("new-5") == ("NEW", 5) - - def test_parse_item_id_invalid(self): - """Test parse_item_id returns None for invalid inputs.""" - assert parse_item_id("INVALID-0") is None - assert parse_item_id("NOHYPHEN") is None - assert parse_item_id("NEW-abc") is None - - -# =============================== -# Test: ProfileDeduplicator Init -# =============================== - - -class TestProfileDeduplicatorInit: - """Tests for ProfileDeduplicator initialization.""" - - def test_init_sets_attributes( - self, mock_request_context, mock_llm_client, mock_site_var_manager - ): - """Test that __init__ sets all required attributes.""" - deduplicator = ProfileDeduplicator( - request_context=mock_request_context, - llm_client=mock_llm_client, - ) - assert deduplicator.request_context == mock_request_context - assert deduplicator.client == mock_llm_client - assert deduplicator.model_name == "gpt-4" - - def test_init_uses_auto_detected_model_when_not_specified( - self, mock_request_context, mock_llm_client, monkeypatch - ): - """Test that init falls back to auto-detected model if not in site var.""" - # Clear all provider keys so only OPENAI_API_KEY is detected - for key in [ - "ANTHROPIC_API_KEY", - "GEMINI_API_KEY", - "DEEPSEEK_API_KEY", - "OPENROUTER_API_KEY", - "MINIMAX_API_KEY", - "DASHSCOPE_API_KEY", - "XAI_API_KEY", - "MOONSHOT_API_KEY", - "ZAI_API_KEY", - ]: - monkeypatch.delenv(key, raising=False) - monkeypatch.setenv("OPENAI_API_KEY", "sk-test") - with patch( - "reflexio.server.services.deduplication_utils.SiteVarManager" - ) as mock: - instance = mock.return_value - instance.get_site_var.return_value = {} - deduplicator = ProfileDeduplicator( - request_context=mock_request_context, - llm_client=mock_llm_client, - ) - assert deduplicator.model_name == "gpt-5-mini" - - -# =============================== -# Test: Format Profiles For Prompt -# =============================== - - -class TestFormatProfilesForPrompt: - """Tests for profile formatting for LLM prompt.""" - - def test_format_profiles_basic( - self, - mock_request_context, - mock_llm_client, - mock_site_var_manager, - sample_profiles, - ): - """Test that profiles are formatted correctly with NEW prefix.""" - deduplicator = ProfileDeduplicator( - request_context=mock_request_context, - llm_client=mock_llm_client, - ) - result = deduplicator._format_items_for_prompt(sample_profiles) - - assert "[NEW-0]" in result - assert "[NEW-1]" in result - assert "[NEW-2]" in result - assert "User prefers dark mode for coding" in result - assert "User likes dark theme in their IDE" in result - assert "one_month" in result - assert "one_week" in result - assert "extractor_a" in result - assert "extractor_b" in result - - def test_format_profiles_uses_ttl_value( - self, mock_request_context, mock_llm_client, mock_site_var_manager - ): - """Test formatting shows TTL value from profile.""" - timestamp = int(datetime.now(UTC).timestamp()) - profiles = [ - UserProfile( - profile_id="1", - user_id="user", - content="test content", - last_modified_timestamp=timestamp, - generated_from_request_id="req", - profile_time_to_live=ProfileTimeToLive.ONE_QUARTER, - ) - ] - deduplicator = ProfileDeduplicator( - request_context=mock_request_context, - llm_client=mock_llm_client, - ) - result = deduplicator._format_items_for_prompt(profiles) - assert "TTL: one_quarter" in result - - def test_format_profiles_with_missing_source( - self, mock_request_context, mock_llm_client, mock_site_var_manager - ): - """Test formatting with profiles that have no source.""" - timestamp = int(datetime.now(UTC).timestamp()) - profiles = [ - UserProfile( - profile_id="1", - user_id="user", - content="test content", - last_modified_timestamp=timestamp, - generated_from_request_id="req", - source=None, - ) - ] - deduplicator = ProfileDeduplicator( - request_context=mock_request_context, - llm_client=mock_llm_client, - ) - result = deduplicator._format_items_for_prompt(profiles) - assert "Source: unknown" in result - - def test_format_existing_profiles( - self, - mock_request_context, - mock_llm_client, - mock_site_var_manager, - sample_profiles, - ): - """Test that existing profiles are formatted with EXISTING prefix.""" - deduplicator = ProfileDeduplicator( - request_context=mock_request_context, - llm_client=mock_llm_client, - ) - result = deduplicator._format_profiles_with_prefix(sample_profiles, "EXISTING") - assert "[EXISTING-0]" in result - assert "[EXISTING-1]" in result - - def test_format_empty_profiles( - self, mock_request_context, mock_llm_client, mock_site_var_manager - ): - """Test formatting empty profile list returns (None).""" - deduplicator = ProfileDeduplicator( - request_context=mock_request_context, - llm_client=mock_llm_client, - ) - result = deduplicator._format_profiles_with_prefix([], "NEW") - assert result == "(None)" - - def test_format_profiles_includes_last_modified_utc( - self, mock_request_context, mock_llm_client, mock_site_var_manager - ): - """Test that formatted profiles include the last-modified timestamp in UTC.""" - # 1704067200 == 2024-01-01 00:00:00 UTC - profiles = [ - UserProfile( - profile_id="1", - user_id="user", - content="test content", - last_modified_timestamp=1704067200, - generated_from_request_id="req", - profile_time_to_live=ProfileTimeToLive.ONE_MONTH, - source="extractor_a", - ) - ] - deduplicator = ProfileDeduplicator( - request_context=mock_request_context, - llm_client=mock_llm_client, - ) - result = deduplicator._format_profiles_with_prefix(profiles, "NEW") - assert "Last Modified: 2024-01-01 00:00 UTC" in result - - def test_format_profiles_timestamp_fallback_on_invalid( - self, mock_request_context, mock_llm_client, mock_site_var_manager - ): - """Test formatting degrades gracefully when the timestamp is out of range.""" - # Absurdly large value that overflows datetime.fromtimestamp on every - # supported platform, but is still a valid ``int`` for the Pydantic - # model field. - profiles = [ - UserProfile( - profile_id="1", - user_id="user", - content="test content", - last_modified_timestamp=99999999999999999, - generated_from_request_id="req", - profile_time_to_live=ProfileTimeToLive.ONE_MONTH, - source="extractor_a", - ) - ] - deduplicator = ProfileDeduplicator( - request_context=mock_request_context, - llm_client=mock_llm_client, - ) - # Must not raise. - result = deduplicator._format_profiles_with_prefix(profiles, "NEW") - assert "Last Modified: unknown" in result - - def test_format_profile_timestamp_helper_happy_path(self): - """The helper formats a valid timestamp identically to the old inline call.""" - assert _format_profile_timestamp(1704067200) == "2024-01-01 00:00 UTC" - - def test_format_profile_timestamp_helper_fallback(self): - """The helper returns the sentinel when the timestamp is out of range.""" - assert _format_profile_timestamp(99999999999999999) == "unknown" - - -# =============================== -# Test: Merge Custom Features -# =============================== - - -class TestMergeCustomFeatures: - """Tests for custom features merging.""" - - def test_merge_custom_features_empty( - self, mock_request_context, mock_llm_client, mock_site_var_manager - ): - """Test merging when no profiles have custom features.""" - timestamp = int(datetime.now(UTC).timestamp()) - profiles = [ - UserProfile( - profile_id="1", - user_id="user", - content="test", - last_modified_timestamp=timestamp, - generated_from_request_id="req", - custom_features=None, - ), - UserProfile( - profile_id="2", - user_id="user", - content="test2", - last_modified_timestamp=timestamp, - generated_from_request_id="req", - custom_features=None, - ), - ] - deduplicator = ProfileDeduplicator( - request_context=mock_request_context, - llm_client=mock_llm_client, - ) - result = deduplicator._merge_custom_features(profiles) - assert result is None - - def test_merge_custom_features_single( - self, mock_request_context, mock_llm_client, mock_site_var_manager - ): - """Test merging when only one profile has custom features.""" - timestamp = int(datetime.now(UTC).timestamp()) - profiles = [ - UserProfile( - profile_id="1", - user_id="user", - content="test", - last_modified_timestamp=timestamp, - generated_from_request_id="req", - custom_features={"key1": "value1"}, - ), - UserProfile( - profile_id="2", - user_id="user", - content="test2", - last_modified_timestamp=timestamp, - generated_from_request_id="req", - custom_features=None, - ), - ] - deduplicator = ProfileDeduplicator( - request_context=mock_request_context, - llm_client=mock_llm_client, - ) - result = deduplicator._merge_custom_features(profiles) - assert result == {"key1": "value1"} - - def test_merge_custom_features_multiple( - self, mock_request_context, mock_llm_client, mock_site_var_manager - ): - """Test merging custom features from multiple profiles.""" - timestamp = int(datetime.now(UTC).timestamp()) - profiles = [ - UserProfile( - profile_id="1", - user_id="user", - content="test", - last_modified_timestamp=timestamp, - generated_from_request_id="req", - custom_features={"key1": "value1", "key2": "old_value"}, - ), - UserProfile( - profile_id="2", - user_id="user", - content="test2", - last_modified_timestamp=timestamp, - generated_from_request_id="req", - custom_features={"key2": "new_value", "key3": "value3"}, - ), - ] - deduplicator = ProfileDeduplicator( - request_context=mock_request_context, - llm_client=mock_llm_client, - ) - result = deduplicator._merge_custom_features(profiles) - assert result == {"key1": "value1", "key2": "new_value", "key3": "value3"} - - -# =============================== -# Test: Build Deduplicated Results -# =============================== - - -class TestBuildDeduplicatedResults: - """Tests for building deduplicated profile results.""" - - def test_build_deduplicated_results_merges_duplicates( - self, - mock_request_context, - mock_llm_client, - mock_site_var_manager, - sample_profiles, - ): - """Test that duplicates are merged into a single profile.""" - dedup_output = ProfileDeduplicationOutput( - duplicate_groups=[ - ProfileDuplicateGroup( - item_ids=["NEW-0", "NEW-1"], - merged_content="User prefers dark mode in their IDE", - merged_time_to_live="one_month", - reasoning="Both about dark mode preferences", - ) - ], - unique_ids=["NEW-2"], - ) - - deduplicator = ProfileDeduplicator( - request_context=mock_request_context, - llm_client=mock_llm_client, - ) - result_profiles, delete_ids, superseded = ( - deduplicator._build_deduplicated_results( - new_profiles=sample_profiles, - existing_profiles=[], - dedup_output=dedup_output, - user_id="test_user", - request_id="test_request", - ) - ) - - assert len(result_profiles) == 2 # 1 merged + 1 unique - assert len(delete_ids) == 0 - assert len(superseded) == 0 - - # Find the merged profile - merged_profile = next( - ( - p - for p in result_profiles - if p.content == "User prefers dark mode in their IDE" - ), - None, - ) - assert merged_profile is not None - assert merged_profile.profile_time_to_live == ProfileTimeToLive.ONE_MONTH - - def test_build_deduplicated_results_preserves_unique( - self, - mock_request_context, - mock_llm_client, - mock_site_var_manager, - sample_profiles, - ): - """Test that unique profiles are preserved.""" - dedup_output = ProfileDeduplicationOutput( - duplicate_groups=[], - unique_ids=["NEW-0", "NEW-1", "NEW-2"], - ) - - deduplicator = ProfileDeduplicator( - request_context=mock_request_context, - llm_client=mock_llm_client, - ) - result_profiles, delete_ids, superseded = ( - deduplicator._build_deduplicated_results( - new_profiles=sample_profiles, - existing_profiles=[], - dedup_output=dedup_output, - user_id="test_user", - request_id="test_request", - ) - ) - - assert len(result_profiles) == 3 - - def test_build_deduplicated_results_handles_invalid_ttl( - self, - mock_request_context, - mock_llm_client, - mock_site_var_manager, - sample_profiles, - ): - """Test that invalid TTL from LLM falls back to template TTL.""" - dedup_output = ProfileDeduplicationOutput( - duplicate_groups=[ - ProfileDuplicateGroup( - item_ids=["NEW-0", "NEW-1"], - merged_content="merged content", - merged_time_to_live="invalid_ttl", - reasoning="test", - ) - ], - unique_ids=["NEW-2"], - ) - - deduplicator = ProfileDeduplicator( - request_context=mock_request_context, - llm_client=mock_llm_client, - ) - result_profiles, _, _ = deduplicator._build_deduplicated_results( - new_profiles=sample_profiles, - existing_profiles=[], - dedup_output=dedup_output, - user_id="test_user", - request_id="test_request", - ) - - merged_profile = next( - (p for p in result_profiles if p.content == "merged content"), - None, - ) - assert merged_profile is not None - # Should fall back to template profile's TTL (first profile in group) - assert merged_profile.profile_time_to_live == ProfileTimeToLive.ONE_MONTH - - def test_build_deduplicated_results_handles_unmentioned_profiles( - self, - mock_request_context, - mock_llm_client, - mock_site_var_manager, - sample_profiles, - ): - """Test that profiles not mentioned by LLM are added as-is.""" - # LLM only mentions indices 0 and 1, not 2 - dedup_output = ProfileDeduplicationOutput( - duplicate_groups=[ - ProfileDuplicateGroup( - item_ids=["NEW-0", "NEW-1"], - merged_content="merged", - merged_time_to_live="one_week", - reasoning="test", - ) - ], - unique_ids=[], # LLM forgot to mention index 2 - ) - - deduplicator = ProfileDeduplicator( - request_context=mock_request_context, - llm_client=mock_llm_client, - ) - result_profiles, _, _ = deduplicator._build_deduplicated_results( - new_profiles=sample_profiles, - existing_profiles=[], - dedup_output=dedup_output, - user_id="test_user", - request_id="test_request", - ) - - # Should still include all profiles (1 merged + 1 unmentioned) - assert len(result_profiles) == 2 - - def test_build_deduplicated_results_collects_existing_to_delete( - self, - mock_request_context, - mock_llm_client, - mock_site_var_manager, - sample_profiles, - ): - """Test that existing profiles marked for deletion are collected.""" - timestamp = int(datetime.now(UTC).timestamp()) - existing_profile = UserProfile( - profile_id="existing_1", - user_id="test_user", - content="Old dark mode preference", - last_modified_timestamp=timestamp, - generated_from_request_id="old_req", - ) - - dedup_output = ProfileDeduplicationOutput( - duplicate_groups=[ - ProfileDuplicateGroup( - item_ids=["NEW-0", "EXISTING-0"], - merged_content="User prefers dark mode (updated)", - merged_time_to_live="one_month", - reasoning="New profile supersedes existing", - ) - ], - unique_ids=["NEW-1", "NEW-2"], - ) - - deduplicator = ProfileDeduplicator( - request_context=mock_request_context, - llm_client=mock_llm_client, - ) - result_profiles, delete_ids, superseded = ( - deduplicator._build_deduplicated_results( - new_profiles=sample_profiles, - existing_profiles=[existing_profile], - dedup_output=dedup_output, - user_id="test_user", - request_id="test_request", - ) - ) - - assert len(delete_ids) == 1 - assert delete_ids[0] == "existing_1" - assert len(superseded) == 1 - assert superseded[0].profile_id == "existing_1" - - def test_build_deduplicated_results_handles_deletion_directive( - self, - mock_request_context, - mock_llm_client, - mock_site_var_manager, - sample_profiles, - ): - """A deletion directive erases the EXISTING profile without writing a replacement. - - This is the core bug fix: "forget that I am interested in X" used to - produce a merged "Previously interested in X, but requested removal" - profile. With the deletion channel, the NEW directive is consumed and - the EXISTING profile is deleted outright. - """ - timestamp = int(datetime.now(UTC).timestamp()) - existing_profile = UserProfile( - profile_id="existing_old_interest", - user_id="test_user", - content="User is interested in self-improving agents", - last_modified_timestamp=timestamp, - generated_from_request_id="old_req", - ) - - dedup_output = ProfileDeduplicationOutput( - duplicate_groups=[], - unique_ids=["NEW-1", "NEW-2"], - deletions=[ - ProfileDeletionDirective( - new_id="NEW-0", - existing_ids=["EXISTING-0"], - reasoning=( - "NEW-0 is a meta-request to forget EXISTING-0; " - "not a fact about the user." - ), - ) - ], - ) - - deduplicator = ProfileDeduplicator( - request_context=mock_request_context, - llm_client=mock_llm_client, - ) - result_profiles, delete_ids, superseded = ( - deduplicator._build_deduplicated_results( - new_profiles=sample_profiles, - existing_profiles=[existing_profile], - dedup_output=dedup_output, - user_id="test_user", - request_id="test_request", - ) - ) - - # EXISTING profile is marked for deletion. - assert delete_ids == ["existing_old_interest"] - assert len(superseded) == 1 - assert superseded[0].profile_id == "existing_old_interest" - - # NEW-0 (the directive) was consumed — not re-added by the safety fallback. - assert all(p.content != sample_profiles[0].content for p in result_profiles), ( - "Deletion directive NEW profile should not appear in result_profiles" - ) - - # Only NEW-1 and NEW-2 (the unrelated unique profiles) remain. - assert len(result_profiles) == 2 - assert {p.content for p in result_profiles} == { - sample_profiles[1].content, - sample_profiles[2].content, - } - - def test_build_deduplicated_results_deletion_directive_no_match( - self, - mock_request_context, - mock_llm_client, - mock_site_var_manager, - sample_profiles, - ): - """A deletion directive with empty existing_ids still consumes the NEW. - - If the LLM emits a deletion directive but matches no EXISTING profile, - the NEW profile must still be suppressed — a meta-statement like - "Requested removal of X" is not a fact worth storing on its own. - """ - dedup_output = ProfileDeduplicationOutput( - duplicate_groups=[], - unique_ids=["NEW-1", "NEW-2"], - deletions=[ - ProfileDeletionDirective( - new_id="NEW-0", - existing_ids=[], - reasoning="No matching existing profile found.", - ) - ], - ) - - deduplicator = ProfileDeduplicator( - request_context=mock_request_context, - llm_client=mock_llm_client, - ) - result_profiles, delete_ids, superseded = ( - deduplicator._build_deduplicated_results( - new_profiles=sample_profiles, - existing_profiles=[], - dedup_output=dedup_output, - user_id="test_user", - request_id="test_request", - ) - ) - - assert delete_ids == [] - assert superseded == [] - # NEW-0 must not survive into result_profiles. - assert all(p.content != sample_profiles[0].content for p in result_profiles) - assert len(result_profiles) == 2 - - -# =============================== -# Test: Deduplicate Main Method -# =============================== - - -class TestDeduplicate: - """Tests for the main deduplicate() method.""" - - def test_deduplicate_returns_original_when_empty( - self, - mock_request_context, - mock_llm_client, - mock_site_var_manager, - ): - """Test that empty input returns empty output.""" - deduplicator = ProfileDeduplicator( - request_context=mock_request_context, - llm_client=mock_llm_client, - ) - profiles, delete_ids, superseded = deduplicator.deduplicate( - new_profiles=[], - user_id="test_user", - request_id="test_request", - ) - - assert profiles == [] - assert delete_ids == [] - assert superseded == [] - - def test_deduplicate_returns_original_when_no_duplicates_found( - self, - mock_request_context, - mock_llm_client, - mock_site_var_manager, - sample_profiles, - ): - """Test that original profiles are returned when LLM finds no duplicates.""" - mock_llm_client.generate_chat_response.return_value = ( - ProfileDeduplicationOutput( - duplicate_groups=[], - unique_ids=["NEW-0", "NEW-1", "NEW-2"], - ) - ) - - deduplicator = ProfileDeduplicator( - request_context=mock_request_context, - llm_client=mock_llm_client, - ) - profiles, delete_ids, superseded = deduplicator.deduplicate( - new_profiles=sample_profiles, - user_id="test_user", - request_id="test_request", - ) - - assert profiles == sample_profiles - assert delete_ids == [] - assert superseded == [] - - def test_deduplicate_returns_original_when_llm_fails( - self, - mock_request_context, - mock_llm_client, - mock_site_var_manager, - sample_profiles, - ): - """Test that original profiles are returned when LLM call fails.""" - mock_llm_client.generate_chat_response.side_effect = Exception("LLM Error") - - deduplicator = ProfileDeduplicator( - request_context=mock_request_context, - llm_client=mock_llm_client, - ) - profiles, delete_ids, superseded = deduplicator.deduplicate( - new_profiles=sample_profiles, - user_id="test_user", - request_id="test_request", - ) - - assert profiles == sample_profiles - assert delete_ids == [] - assert superseded == [] - - def test_deduplicate_merges_duplicates( - self, - mock_request_context, - mock_llm_client, - mock_site_var_manager, - sample_profiles, - ): - """Test that duplicates are properly merged.""" - mock_llm_client.generate_chat_response.return_value = ( - ProfileDeduplicationOutput( - duplicate_groups=[ - ProfileDuplicateGroup( - item_ids=["NEW-0", "NEW-1"], - merged_content="User prefers dark mode", - merged_time_to_live="one_month", - reasoning="Both about dark mode", - ) - ], - unique_ids=["NEW-2"], - ) - ) - - deduplicator = ProfileDeduplicator( - request_context=mock_request_context, - llm_client=mock_llm_client, - ) - profiles, delete_ids, superseded = deduplicator.deduplicate( - new_profiles=sample_profiles, - user_id="test_user", - request_id="test_request", - ) - - # Should have 2 profiles: 1 merged + 1 unique - assert len(profiles) == 2 - assert len(delete_ids) == 0 - - def test_deduplicate_with_existing_profiles_to_delete( - self, - mock_request_context, - mock_llm_client, - mock_site_var_manager, - sample_profiles, - ): - """Test deduplication that supersedes existing profiles.""" - timestamp = int(datetime.now(UTC).timestamp()) - existing_profile = UserProfile( - profile_id="existing_1", - user_id="test_user", - content="Old dark mode preference", - last_modified_timestamp=timestamp, - generated_from_request_id="old_req", - ) - - # Mock storage to return existing profile via hybrid search - mock_request_context.storage.search_user_profile.return_value = [ - existing_profile - ] - - mock_llm_client.generate_chat_response.return_value = ( - ProfileDeduplicationOutput( - duplicate_groups=[ - ProfileDuplicateGroup( - item_ids=["NEW-0", "EXISTING-0"], - merged_content="User prefers dark mode (updated)", - merged_time_to_live="one_month", - reasoning="New supersedes existing", - ) - ], - unique_ids=["NEW-1", "NEW-2"], - ) - ) - - deduplicator = ProfileDeduplicator( - request_context=mock_request_context, - llm_client=mock_llm_client, - ) - profiles, delete_ids, superseded = deduplicator.deduplicate( - new_profiles=sample_profiles, - user_id="test_user", - request_id="test_request", - ) - - assert len(profiles) == 3 # 1 merged + 2 unique - assert len(delete_ids) == 1 - assert delete_ids[0] == "existing_1" - assert len(superseded) == 1 - - def test_deduplicate_applies_deletions_when_no_duplicate_groups( - self, - mock_request_context, - mock_llm_client, - mock_site_var_manager, - sample_profiles, - ): - """A deletion-only LLM response must still erase the EXISTING profile. - - Regression guard: the public `deduplicate()` used to short-circuit when - `duplicate_groups` was empty, which silently dropped deletion directives - and returned the 'Requested removal of ...' NEW profile as a new fact — - the exact zombie-profile failure the deletion channel was meant to fix. - """ - timestamp = int(datetime.now(UTC).timestamp()) - existing_profile = UserProfile( - profile_id="existing_forgettable", - user_id="test_user", - content="User is interested in self-improving agents", - last_modified_timestamp=timestamp, - generated_from_request_id="old_req", - ) - directive_profile = UserProfile( - profile_id=str(uuid.uuid4()), - user_id="test_user", - content=( - "Requested removal of interest in self-improving agents " - "from stored profiles" - ), - last_modified_timestamp=timestamp, - generated_from_request_id="req_directive", - profile_time_to_live=ProfileTimeToLive.ONE_DAY, - source="extractor_a", - ) - - mock_request_context.storage.search_user_profile.return_value = [ - existing_profile - ] - mock_llm_client.generate_chat_response.return_value = ( - ProfileDeduplicationOutput( - duplicate_groups=[], - unique_ids=[], - deletions=[ - ProfileDeletionDirective( - new_id="NEW-0", - existing_ids=["EXISTING-0"], - reasoning="Meta-request to forget EXISTING-0.", - ) - ], - ) - ) - - deduplicator = ProfileDeduplicator( - request_context=mock_request_context, - llm_client=mock_llm_client, - ) - profiles, delete_ids, superseded = deduplicator.deduplicate( - new_profiles=[directive_profile], - user_id="test_user", - request_id="test_request", - ) - - assert delete_ids == ["existing_forgettable"] - assert len(superseded) == 1 - assert superseded[0].profile_id == "existing_forgettable" - # The directive must be consumed — not leak back as a stored fact. - assert profiles == [] - - def test_deduplicate_strips_markers_on_llm_exception( - self, - mock_request_context, - mock_llm_client, - mock_site_var_manager, - ): - """When the LLM call raises, fallback must strip canonical deletion markers. - - Regression guard: if the LLM fails, returning `new_profiles` verbatim - would persist "Requested removal of …" markers as regular facts — the - exact zombie-profile outcome the deletion-directive channel was built - to prevent. The fallback must suppress markers while preserving - ordinary profiles. - """ - timestamp = int(datetime.now(UTC).timestamp()) - ordinary = UserProfile( - profile_id=str(uuid.uuid4()), - user_id="test_user", - content="User prefers dark mode", - last_modified_timestamp=timestamp, - generated_from_request_id="req_ok", - profile_time_to_live=ProfileTimeToLive.ONE_MONTH, - source="extractor_a", - ) - marker = UserProfile( - profile_id=str(uuid.uuid4()), - user_id="test_user", - content=( - "Requested removal of interest in self-improving agents " - "from stored profiles" - ), - last_modified_timestamp=timestamp, - generated_from_request_id="req_forget", - profile_time_to_live=ProfileTimeToLive.ONE_DAY, - source="extractor_a", - ) - - mock_request_context.storage.search_user_profile.return_value = [] - mock_llm_client.generate_chat_response.side_effect = RuntimeError( - "LLM unavailable" - ) - - deduplicator = ProfileDeduplicator( - request_context=mock_request_context, - llm_client=mock_llm_client, - ) - profiles, delete_ids, superseded = deduplicator.deduplicate( - new_profiles=[ordinary, marker], - user_id="test_user", - request_id="test_request", - ) - - assert delete_ids == [] - assert superseded == [] - assert [p.profile_id for p in profiles] == [ordinary.profile_id] - - def test_deduplicate_strips_markers_on_empty_output( - self, - mock_request_context, - mock_llm_client, - mock_site_var_manager, - ): - """Empty dedup output (no groups, no deletions) still strips markers. - - If the LLM returns nothing to act on but a marker profile is present in - `new_profiles`, the fallback must drop the marker rather than persist - it as a fact. - """ - timestamp = int(datetime.now(UTC).timestamp()) - ordinary = UserProfile( - profile_id=str(uuid.uuid4()), - user_id="test_user", - content="User prefers dark mode", - last_modified_timestamp=timestamp, - generated_from_request_id="req_ok", - profile_time_to_live=ProfileTimeToLive.ONE_MONTH, - source="extractor_a", - ) - marker = UserProfile( - profile_id=str(uuid.uuid4()), - user_id="test_user", - content="Requested removal of preference for tabs over spaces", - last_modified_timestamp=timestamp, - generated_from_request_id="req_forget", - profile_time_to_live=ProfileTimeToLive.ONE_DAY, - source="extractor_a", - ) - - mock_request_context.storage.search_user_profile.return_value = [] - mock_llm_client.generate_chat_response.return_value = ( - ProfileDeduplicationOutput( - duplicate_groups=[], - unique_ids=[], - deletions=[], - ) - ) - - deduplicator = ProfileDeduplicator( - request_context=mock_request_context, - llm_client=mock_llm_client, - ) - profiles, delete_ids, superseded = deduplicator.deduplicate( - new_profiles=[ordinary, marker], - user_id="test_user", - request_id="test_request", - ) - - assert delete_ids == [] - assert superseded == [] - assert [p.profile_id for p in profiles] == [ordinary.profile_id] - - -# =============================== -# Test: Integration -# =============================== - - -class TestIntegration: - """Integration tests for the complete deduplication flow.""" - - def test_full_deduplication_flow( - self, - mock_request_context, - mock_llm_client, - mock_site_var_manager, - ): - """Test a complete deduplication flow with realistic data.""" - timestamp = int(datetime.now(UTC).timestamp()) - - # Create profiles from different extractors with duplicates - new_profiles = [ - UserProfile( - profile_id="p1", - user_id="user", - content="User works in finance industry", - last_modified_timestamp=timestamp, - generated_from_request_id="req1", - profile_time_to_live=ProfileTimeToLive.ONE_YEAR, - source="industry_extractor", - custom_features={"sector": "finance"}, - ), - UserProfile( - profile_id="p2", - user_id="user", - content="User is in the financial services sector", - last_modified_timestamp=timestamp, - generated_from_request_id="req2", - profile_time_to_live=ProfileTimeToLive.ONE_MONTH, - source="job_extractor", - custom_features={"job_type": "analyst"}, - ), - UserProfile( - profile_id="p3", - user_id="user", - content="User prefers Python programming", - last_modified_timestamp=timestamp, - generated_from_request_id="req3", - profile_time_to_live=ProfileTimeToLive.INFINITY, - source="tech_extractor", - ), - ] - - mock_llm_client.generate_chat_response.return_value = ProfileDeduplicationOutput( - duplicate_groups=[ - ProfileDuplicateGroup( - item_ids=["NEW-0", "NEW-1"], - merged_content="User works in the financial services industry", - merged_time_to_live="one_year", - reasoning="Both profiles describe the user's industry as finance/financial services", - ) - ], - unique_ids=["NEW-2"], - ) - - deduplicator = ProfileDeduplicator( - request_context=mock_request_context, - llm_client=mock_llm_client, - ) - result_profiles, delete_ids, superseded = deduplicator.deduplicate( - new_profiles=new_profiles, - user_id="user", - request_id="test_request", - ) - - # Verify structure - assert len(result_profiles) == 2 - assert len(delete_ids) == 0 - - # Find merged profile - merged = next( - (p for p in result_profiles if "financial services industry" in p.content), - None, - ) - assert merged is not None - assert merged.user_id == "user" - assert merged.profile_time_to_live == ProfileTimeToLive.ONE_YEAR - # Custom features should be merged - assert merged.custom_features == {"sector": "finance", "job_type": "analyst"} - - # Find unique profile - unique = next((p for p in result_profiles if "Python" in p.content), None) - assert unique is not None - assert unique.content == "User prefers Python programming" - - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) diff --git a/tests/server/services/profile/test_profile_generation_service.py b/tests/server/services/profile/test_profile_generation_service.py index 730289a3..2fe29049 100644 --- a/tests/server/services/profile/test_profile_generation_service.py +++ b/tests/server/services/profile/test_profile_generation_service.py @@ -337,17 +337,11 @@ def test_empty_nested_results_no_action(self, service, request_context): request_context.storage.add_user_profile.assert_not_called() - def test_save_profiles_dedup_disabled( - self, service, request_context, sample_profile - ): - """Profiles are saved directly when deduplicator is disabled.""" + def test_save_profiles(self, service, request_context, sample_profile): + """Profiles are saved with the correct source and status.""" self._setup_service_config(service) - with patch( - "reflexio.server.site_var.feature_flags.is_deduplicator_enabled", - return_value=False, - ): - service._process_results([[sample_profile]]) + service._process_results([[sample_profile]]) request_context.storage.add_user_profile.assert_called_once_with( "user_1", [sample_profile] @@ -365,122 +359,27 @@ def test_save_profiles_pending_status( source="rerun", ) - with patch( - "reflexio.server.site_var.feature_flags.is_deduplicator_enabled", - return_value=False, - ): - service_pending._process_results([[sample_profile]]) + service_pending._process_results([[sample_profile]]) assert sample_profile.status == Status.PENDING - def test_save_profiles_dedup_enabled( - self, service, request_context, sample_profile - ): - """Deduplicator is called when enabled and profiles exist.""" - self._setup_service_config(service) - - dedup_mock = MagicMock() - dedup_mock.deduplicate.return_value = ([sample_profile], ["old_p1"], []) - - with ( - patch( - "reflexio.server.site_var.feature_flags.is_deduplicator_enabled", - return_value=True, - ), - patch( - "reflexio.server.services.profile.profile_deduplicator.ProfileDeduplicator", - return_value=dedup_mock, - ), - ): - service._process_results([[sample_profile]]) - - dedup_mock.deduplicate.assert_called_once() - request_context.storage.add_user_profile.assert_called_once() - request_context.storage.delete_user_profile.assert_called_once() - - def test_dedup_with_pending_status_filter( - self, service_pending, request_context, sample_profile - ): - """Deduplicator uses PENDING status filter in rerun mode.""" - service_pending.service_config = ProfileGenerationServiceConfig( - user_id="user_1", - request_id="req_1", - source="rerun", - ) - - dedup_mock = MagicMock() - dedup_mock.deduplicate.return_value = ([sample_profile], [], []) - - with ( - patch( - "reflexio.server.site_var.feature_flags.is_deduplicator_enabled", - return_value=True, - ), - patch( - "reflexio.server.services.profile.profile_deduplicator.ProfileDeduplicator", - return_value=dedup_mock, - ), - ): - service_pending._process_results([[sample_profile]]) - - dedup_mock.deduplicate.assert_called_once() - def test_save_failure_returns_early(self, service, request_context, sample_profile): """When add_user_profile raises, the method returns without deleting.""" self._setup_service_config(service) request_context.storage.add_user_profile.side_effect = RuntimeError("DB error") - with patch( - "reflexio.server.site_var.feature_flags.is_deduplicator_enabled", - return_value=False, - ): - service._process_results([[sample_profile]]) + service._process_results([[sample_profile]]) request_context.storage.delete_user_profile.assert_not_called() request_context.storage.add_profile_change_log.assert_not_called() - def test_delete_superseded_failure_continues( - self, service, request_context, sample_profile - ): - """When deleting superseded profile fails, processing continues.""" - self._setup_service_config(service) - - dedup_mock = MagicMock() - dedup_mock.deduplicate.return_value = ( - [sample_profile], - ["old_p1", "old_p2"], - [], - ) - - request_context.storage.delete_user_profile.side_effect = RuntimeError( - "Delete error" - ) - - with ( - patch( - "reflexio.server.site_var.feature_flags.is_deduplicator_enabled", - return_value=True, - ), - patch( - "reflexio.server.services.profile.profile_deduplicator.ProfileDeduplicator", - return_value=dedup_mock, - ), - ): - service._process_results([[sample_profile]]) - - assert request_context.storage.delete_user_profile.call_count == 2 - def test_changelog_created_after_profiles_saved( self, service, request_context, sample_profile ): """Profile changelog is created when new profiles are saved.""" self._setup_service_config(service) - with patch( - "reflexio.server.site_var.feature_flags.is_deduplicator_enabled", - return_value=False, - ): - service._process_results([[sample_profile]]) + service._process_results([[sample_profile]]) request_context.storage.add_profile_change_log.assert_called_once() changelog = request_context.storage.add_profile_change_log.call_args[0][0] @@ -497,50 +396,10 @@ def test_changelog_failure_is_handled( "Changelog error" ) - with patch( - "reflexio.server.site_var.feature_flags.is_deduplicator_enabled", - return_value=False, - ): - service._process_results([[sample_profile]]) + service._process_results([[sample_profile]]) request_context.storage.add_user_profile.assert_called_once() - def test_changelog_with_superseded_profiles( - self, service, request_context, sample_profile - ): - """Changelog includes superseded (removed) profiles from deduplication.""" - self._setup_service_config(service) - - superseded = UserProfile( - profile_id="old_p1", - user_id="user_1", - content="old preference", - last_modified_timestamp=int(datetime.now(UTC).timestamp()), - generated_from_request_id="req_0", - ) - - dedup_mock = MagicMock() - dedup_mock.deduplicate.return_value = ( - [sample_profile], - [], - [superseded], - ) - - with ( - patch( - "reflexio.server.site_var.feature_flags.is_deduplicator_enabled", - return_value=True, - ), - patch( - "reflexio.server.services.profile.profile_deduplicator.ProfileDeduplicator", - return_value=dedup_mock, - ), - ): - service._process_results([[sample_profile]]) - - changelog = request_context.storage.add_profile_change_log.call_args[0][0] - assert changelog.removed_profiles == [superseded] - def test_no_changelog_when_no_profiles(self, service, request_context): """No changelog is created when there are no new or superseded profiles.""" self._setup_service_config(service) diff --git a/tests/server/services/profile/test_profile_generation_service_utils.py b/tests/server/services/profile/test_profile_generation_service_utils.py index 0c8a3d8e..c5063476 100644 --- a/tests/server/services/profile/test_profile_generation_service_utils.py +++ b/tests/server/services/profile/test_profile_generation_service_utils.py @@ -4,14 +4,17 @@ import pytest +from reflexio.models.api_schema.common import NEVER_EXPIRES_TIMESTAMP from reflexio.models.api_schema.internal_schema import RequestInteractionDataModel from reflexio.models.api_schema.service_schemas import ( Interaction, + ProfileTimeToLive, Request, UserProfile, ) from reflexio.server.prompt.prompt_manager import PromptManager from reflexio.server.services.profile.profile_generation_service_utils import ( + calculate_expiration_timestamp, construct_profile_extraction_messages_from_sessions, ) @@ -155,5 +158,34 @@ def test_construct_profile_extraction_messages_with_empty_sessions(): assert len(messages) > 0, "No messages were created for empty sessions" +def test_calculate_expiration_timestamp_infinity_returns_sentinel(): + """Infinity TTL must return the NEVER_EXPIRES_TIMESTAMP sentinel (Jan 1 2100), + not a `datetime.max`-derived year-9999 integer that would render as + 'Jan 1, 10000' after timezone conversion on the frontend. + """ + now = int(datetime.now(UTC).timestamp()) + assert ( + calculate_expiration_timestamp(now, ProfileTimeToLive.INFINITY) + == NEVER_EXPIRES_TIMESTAMP + ) + + +@pytest.mark.parametrize( + "ttl, expected_delta_seconds", + [ + (ProfileTimeToLive.ONE_DAY, 1 * 24 * 3600), + (ProfileTimeToLive.ONE_WEEK, 7 * 24 * 3600), + (ProfileTimeToLive.ONE_MONTH, 30 * 24 * 3600), + (ProfileTimeToLive.ONE_QUARTER, 90 * 24 * 3600), + (ProfileTimeToLive.ONE_YEAR, 365 * 24 * 3600), + ], +) +def test_calculate_expiration_timestamp_finite_ttls(ttl, expected_delta_seconds): + """Finite TTLs must shift last_modified forward by their documented delta.""" + now = int(datetime.now(UTC).timestamp()) + expiration = calculate_expiration_timestamp(now, ttl) + assert expiration == now + expected_delta_seconds + + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/tests/server/services/search/__init__.py b/tests/server/services/search/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/server/services/search/test_agentic_search_service.py b/tests/server/services/search/test_agentic_search_service.py new file mode 100644 index 00000000..47cf0eba --- /dev/null +++ b/tests/server/services/search/test_agentic_search_service.py @@ -0,0 +1,138 @@ +"""Integration tests for AgenticSearchService — populated entity lists + agent_answer.""" + +from __future__ import annotations + +import json +from unittest.mock import MagicMock + +import pytest + +from reflexio.models.api_schema.retriever_schema import UnifiedSearchRequest + + +def _mk_tc(id_, name, args): + tc = MagicMock() + tc.id = id_ + tc.function = MagicMock() + tc.function.name = name + tc.function.arguments = json.dumps(args) + return tc + + +def _mk_resp(tool_calls): + r = MagicMock() + r.tool_calls = tool_calls + r.content = None + return r + + +@pytest.fixture +def temp_storage(tmp_path): + from reflexio.server.services.storage.sqlite_storage import SQLiteStorage + + return SQLiteStorage(org_id="svc-test", db_path=str(tmp_path / "svc.db")) + + +def test_agentic_search_populates_profiles_from_trace(temp_storage): + """Agent searches profiles; service fetches and returns matching profile objects.""" + from reflexio.models.api_schema.domain.entities import ( + NEVER_EXPIRES_TIMESTAMP, + UserProfile, + ) + from reflexio.models.api_schema.domain.enums import ProfileTimeToLive + + temp_storage.add_user_profile( + "u_1", + [ + UserProfile( + profile_id="p_seed_1", + user_id="u_1", + content="user likes sushi", + last_modified_timestamp=0, + generated_from_request_id="r_1", + profile_time_to_live=ProfileTimeToLive.INFINITY, + expiration_timestamp=NEVER_EXPIRES_TIMESTAMP, + extractor_names=["test"], + ), + ], + ) + + client = MagicMock() + client.config = MagicMock() + client.config.api_key_config = None + client.generate_chat_response.side_effect = [ + _mk_resp( + [_mk_tc("c1", "search_user_profiles", {"query": "sushi", "top_k": 10})] + ), + _mk_resp([_mk_tc("c2", "finish", {"answer": "sushi lover"})]), + ] + + import tempfile + + from reflexio.server.api_endpoints.request_context import RequestContext + + with tempfile.TemporaryDirectory() as d: + rc = RequestContext(org_id="svc-test", storage_base_dir=d) + rc.storage = temp_storage # type: ignore[attr-defined] + + from reflexio.server.services.search.agentic_search_service import ( + AgenticSearchService, + ) + + svc = AgenticSearchService(llm_client=client, request_context=rc) + + request = UnifiedSearchRequest( + query="what does user like?", + user_id="u_1", + agent_version="v1", + top_k=5, + enable_agent_answer=True, + ) + response = svc.search(request) + + assert response.success is True + assert response.agent_answer == "sushi lover" + assert response.msg is None + assert len(response.profiles) == 1 + assert response.profiles[0].profile_id == "p_seed_1" + assert response.user_playbooks == [] + assert response.agent_playbooks == [] + + +def test_agentic_search_empty_when_agent_searches_nothing(temp_storage): + """Agent finishes without searching; service returns empty entity lists.""" + client = MagicMock() + client.config = MagicMock() + client.config.api_key_config = None + client.generate_chat_response.side_effect = [ + _mk_resp([_mk_tc("c1", "finish", {"answer": "no evidence"})]), + ] + + import tempfile + + from reflexio.server.api_endpoints.request_context import RequestContext + + with tempfile.TemporaryDirectory() as d: + rc = RequestContext(org_id="svc-test", storage_base_dir=d) + rc.storage = temp_storage # type: ignore[attr-defined] + + from reflexio.server.services.search.agentic_search_service import ( + AgenticSearchService, + ) + + svc = AgenticSearchService(llm_client=client, request_context=rc) + + request = UnifiedSearchRequest( + query="anything?", + user_id="u_nobody", + agent_version="v1", + top_k=5, + enable_agent_answer=True, + ) + response = svc.search(request) + + assert response.success is True + assert response.agent_answer == "no evidence" + assert response.profiles == [] + assert response.user_playbooks == [] + assert response.agent_playbooks == [] diff --git a/tests/server/services/search/test_rerank_integration.py b/tests/server/services/search/test_rerank_integration.py new file mode 100644 index 00000000..178c3f0d --- /dev/null +++ b/tests/server/services/search/test_rerank_integration.py @@ -0,0 +1,188 @@ +"""Integration tests for the cross-encoder rerank tool + Reflexio.rerank_user_profiles. + +Uses real SQLite storage in a temp dir and the real cross-encoder model — slow +on first run (model download) but cached afterwards under +``~/.cache/huggingface/``. The model is a 22M-param MS-MARCO MiniLM, ~50 ms for +K=30 on CPU, so steady-state cost is small enough to keep these tests in the +default integration tier (no ``@skip_in_precommit``). +""" + +from __future__ import annotations + +import pytest + +from reflexio.models.api_schema.domain.entities import ( + NEVER_EXPIRES_TIMESTAMP, + UserProfile, +) +from reflexio.models.api_schema.domain.enums import ProfileTimeToLive +from reflexio.models.api_schema.retriever_schema import RerankUserProfilesRequest +from reflexio.server.services.extraction.plan import ExtractionCtx +from reflexio.server.services.extraction.tools import ( + RerankUserProfilesArgs, + _handle_rerank_user_profiles, +) + +pytestmark = pytest.mark.integration + + +_RELEVANT_CONTENTS = [ + "User loves Italian pasta and pizza", + "User prefers spicy ramen with pork broth", + "User is allergic to peanuts", +] +_IRRELEVANT_CONTENTS = [ + "User uses a Linux laptop for development", + "User commutes by bicycle on weekdays", + "User watches NBA games on Sunday evenings", +] + + +@pytest.fixture +def seeded_storage(tmp_path): + """SQLite storage with three relevant + three irrelevant profiles.""" + from reflexio.server.services.storage.sqlite_storage import SQLiteStorage + + storage = SQLiteStorage(org_id="rerank-test", db_path=str(tmp_path / "rerank.db")) + profiles = [] + for idx, content in enumerate(_RELEVANT_CONTENTS): + profiles.append( + UserProfile( + user_id="u_rerank", + profile_id=f"rel_{idx}", + content=content, + profile_time_to_live=ProfileTimeToLive.INFINITY, + last_modified_timestamp=1_700_000_000 + idx, + expiration_timestamp=NEVER_EXPIRES_TIMESTAMP, + source="test", + generated_from_request_id="req_test", + ) + ) + for idx, content in enumerate(_IRRELEVANT_CONTENTS): + profiles.append( + UserProfile( + user_id="u_rerank", + profile_id=f"irr_{idx}", + content=content, + profile_time_to_live=ProfileTimeToLive.INFINITY, + last_modified_timestamp=1_700_000_100 + idx, + expiration_timestamp=NEVER_EXPIRES_TIMESTAMP, + source="test", + generated_from_request_id="req_test", + ) + ) + storage.add_user_profile("u_rerank", profiles) + return storage + + +@pytest.fixture +def ctx(): + return ExtractionCtx(user_id="u_rerank", agent_version="v1", extractor_name="x") + + +def test_rerank_handler_surfaces_relevant_profile_above_irrelevant(seeded_storage, ctx): + """Cross-encoder must rank a food-related profile above unrelated profiles.""" + all_ids = [f"rel_{i}" for i in range(3)] + [f"irr_{i}" for i in range(3)] + result = _handle_rerank_user_profiles( + RerankUserProfilesArgs( + query="What food does the user enjoy?", + profile_ids=all_ids, + top_k=3, + ), + seeded_storage, + ctx, + ) + hit_ids = [hit["id"] for hit in result["hits"]] + assert len(hit_ids) == 3, f"top_k=3 should return 3 hits, got {hit_ids!r}" + # The top hit should be one of the food-related profiles. + assert hit_ids[0].startswith("rel_"), ( + f"expected food-related profile at top, got id={hit_ids[0]!r}; all={hit_ids!r}" + ) + # The handler bumps search_count for budgeting parity with search. + assert ctx.search_count == 1 + + +def test_rerank_handler_silently_drops_unknown_ids(seeded_storage, ctx): + """Unknown profile_ids must be dropped without error.""" + result = _handle_rerank_user_profiles( + RerankUserProfilesArgs( + query="pasta", + profile_ids=["rel_0", "does-not-exist", "neither-does-this"], + top_k=10, + ), + seeded_storage, + ctx, + ) + hit_ids = {hit["id"] for hit in result["hits"]} + assert hit_ids == {"rel_0"} + + +def test_rerank_handler_respects_top_k(seeded_storage, ctx): + """top_k must cap the number of returned hits even with more candidates.""" + all_ids = [f"rel_{i}" for i in range(3)] + [f"irr_{i}" for i in range(3)] + result = _handle_rerank_user_profiles( + RerankUserProfilesArgs( + query="food preferences", + profile_ids=all_ids, + top_k=2, + ), + seeded_storage, + ctx, + ) + assert len(result["hits"]) == 2 + + +def test_rerank_handler_empty_input_returns_empty(seeded_storage, ctx): + """Empty profile_ids must short-circuit without calling the model.""" + result = _handle_rerank_user_profiles( + RerankUserProfilesArgs(query="anything", profile_ids=[], top_k=5), + seeded_storage, + ctx, + ) + assert result == {"hits": []} + assert ctx.search_count == 1 + + +def test_reflexio_rerank_user_profiles_returns_response(tmp_path): + """The Reflexio facade method should wire request -> handler -> response.""" + from reflexio.lib.reflexio_lib import Reflexio + + reflexio = Reflexio(org_id="rerank-facade", storage_base_dir=str(tmp_path)) + storage = reflexio._get_storage() + storage.add_user_profile( + "u_facade", + [ + UserProfile( + user_id="u_facade", + profile_id="food", + content="user loves Italian pasta", + profile_time_to_live=ProfileTimeToLive.INFINITY, + last_modified_timestamp=1_700_000_000, + expiration_timestamp=NEVER_EXPIRES_TIMESTAMP, + source="test", + generated_from_request_id="req", + ), + UserProfile( + user_id="u_facade", + profile_id="commute", + content="user bikes to work", + profile_time_to_live=ProfileTimeToLive.INFINITY, + last_modified_timestamp=1_700_000_001, + expiration_timestamp=NEVER_EXPIRES_TIMESTAMP, + source="test", + generated_from_request_id="req", + ), + ], + ) + + response = reflexio.rerank_user_profiles( + RerankUserProfilesRequest( + user_id="u_facade", + query="What food does the user like?", + profile_ids=["food", "commute"], + top_k=2, + ) + ) + assert response.success is True + ids = [p.profile_id for p in response.user_profiles] + assert ids[0] == "food", f"expected food profile first, got {ids!r}" diff --git a/tests/server/services/search/test_search_agent.py b/tests/server/services/search/test_search_agent.py new file mode 100644 index 00000000..f2c4d5dd --- /dev/null +++ b/tests/server/services/search/test_search_agent.py @@ -0,0 +1,260 @@ +"""Integration tests for SearchAgent (read-only single loop).""" + +import json +from unittest.mock import MagicMock + +import pytest + +from reflexio.server.services.search.search_agent import SearchAgent + + +@pytest.fixture +def temp_storage(tmp_path): + from reflexio.server.services.storage.sqlite_storage import SQLiteStorage + + # NOTE: SQLiteStorage requires org_id + db_path kwargs (not a single positional). + return SQLiteStorage(org_id="test-org", db_path=str(tmp_path / "srch.db")) + + +@pytest.fixture +def prompt_manager(): + from reflexio.server.prompt.prompt_manager import PromptManager + + return PromptManager() + + +@pytest.fixture +def llm_client(): + c = MagicMock() + c.config = MagicMock() + c.config.api_key_config = None + return c + + +def _mk_tc(id_, name, args): + tc = MagicMock() + tc.id = id_ + tc.function = MagicMock() + tc.function.name = name + tc.function.arguments = json.dumps(args) + return tc + + +def _mk_resp(tool_calls, content=None): + r = MagicMock() + r.tool_calls = tool_calls + r.content = content + return r + + +def test_search_agent_returns_answer_from_finish( + temp_storage, prompt_manager, llm_client +): + llm_client.generate_chat_response.side_effect = [ + _mk_resp( + [_mk_tc("c1", "search_user_profiles", {"query": "food", "top_k": 10})] + ), + _mk_resp([_mk_tc("c2", "finish", {"answer": "no evidence in memory"})]), + ] + + agent = SearchAgent( + client=llm_client, + storage=temp_storage, + prompt_manager=prompt_manager, + enable_agent_answer=True, + ) + result = agent.run( + user_id="u_1", agent_version="v1", query="what do I like to eat?" + ) + assert result.answer == "no evidence in memory" + + +def test_search_agent_reads_agent_playbooks(temp_storage, prompt_manager, llm_client): + """Search agent can fall through to AgentPlaybooks.""" + llm_client.generate_chat_response.side_effect = [ + _mk_resp([_mk_tc("c1", "search_user_playbooks", {"query": "x", "top_k": 10})]), + _mk_resp([_mk_tc("c2", "search_agent_playbooks", {"query": "x", "top_k": 10})]), + _mk_resp([_mk_tc("c3", "finish", {"answer": "fallback answer"})]), + ] + agent = SearchAgent( + client=llm_client, + storage=temp_storage, + prompt_manager=prompt_manager, + enable_agent_answer=True, + ) + r = agent.run(user_id="u_1", agent_version="v1", query="x") + assert r.answer == "fallback answer" + + +def test_search_agent_reports_budget_exceeded_on_max_steps( + temp_storage, prompt_manager, llm_client +): + """Loop hits max_steps without ever calling finish — budget_exceeded is True.""" + llm_client.generate_chat_response.side_effect = [ + _mk_resp([_mk_tc(f"c{i}", "search_user_profiles", {"query": "x", "top_k": 10})]) + for i in range(5) + ] + agent = SearchAgent( + client=llm_client, + storage=temp_storage, + prompt_manager=prompt_manager, + max_steps=2, + enable_agent_answer=True, + ) + r = agent.run(user_id="u_1", agent_version="v1", query="x") + assert r.outcome == "max_steps" + assert r.budget_exceeded is True + assert r.answer == "no answer" + + +def test_search_agent_search_only_mode_returns_none_answer( + temp_storage, prompt_manager, llm_client +): + """When ``enable_agent_answer=False`` (default), the agent's answer is + forced to None even if the LLM produced one. Callers (the host) synthesize + the final response from the entities harvested by the search agent. + """ + llm_client.generate_chat_response.side_effect = [ + _mk_resp([_mk_tc("c1", "search_user_profiles", {"query": "x", "top_k": 10})]), + # LLM still emits an answer in the mock; the agent must drop it. + _mk_resp([_mk_tc("c2", "finish", {"answer": "ignored"})]), + ] + agent = SearchAgent( + client=llm_client, storage=temp_storage, prompt_manager=prompt_manager + ) + r = agent.run(user_id="u_so", agent_version="v1", query="anything?") + assert r.answer is None + # Search-only mode must still let the agent finish cleanly. + assert r.outcome == "finish_tool" + + +def test_search_agent_prompt_includes_search_only_block_when_disabled(prompt_manager): + """Rendered prompt carries the search-only mode flag verbatim so the LLM + can branch its finish() call accordingly. + """ + rendered = prompt_manager.render_prompt( + "search_agent", + variables={ + "query": "x", + "max_steps": "3", + "enable_agent_answer": "false", + }, + ) + assert "enable_agent_answer = false" in rendered + assert "Search-only output rule" in rendered + + +def test_search_agent_prompt_includes_answer_block_when_enabled(prompt_manager): + """Rendered prompt carries the synthesis flag when the host opts in.""" + rendered = prompt_manager.render_prompt( + "search_agent", + variables={ + "query": "x", + "max_steps": "3", + "enable_agent_answer": "true", + }, + ) + assert "enable_agent_answer = true" in rendered + assert "Expected answer format" in rendered + + +def test_search_agent_trace_captures_harvested_ids( + temp_storage, prompt_manager, llm_client +): + """Trace contains search turn results — used by AgenticSearchService for entity harvesting.""" + from reflexio.models.api_schema.domain.entities import ( + NEVER_EXPIRES_TIMESTAMP, + UserProfile, + ) + from reflexio.models.api_schema.domain.enums import ProfileTimeToLive + + temp_storage.add_user_profile( + "u_1", + [ + UserProfile( + profile_id="p_seed_1", + user_id="u_1", + content="user likes sushi", + last_modified_timestamp=0, + generated_from_request_id="r_1", + profile_time_to_live=ProfileTimeToLive.INFINITY, + expiration_timestamp=NEVER_EXPIRES_TIMESTAMP, + extractor_names=["test"], + ), + ], + ) + + llm_client.generate_chat_response.side_effect = [ + _mk_resp( + [_mk_tc("c1", "search_user_profiles", {"query": "food", "top_k": 10})] + ), + _mk_resp([_mk_tc("c2", "finish", {"answer": "user likes sushi"})]), + ] + + agent = SearchAgent( + client=llm_client, + storage=temp_storage, + prompt_manager=prompt_manager, + enable_agent_answer=True, + ) + result = agent.run(user_id="u_1", agent_version="v1", query="what does user like?") + + # trace.turns should contain at least the search turn + assert len(result.trace.turns) >= 1 + search_turns = [ + t for t in result.trace.turns if t.tool_name == "search_user_profiles" + ] + assert search_turns + + +def test_search_agent_prompt_frames_agent_improvement(prompt_manager): + """Sanity: search prompt opening must frame retrieval around informing + the agent's next action, not 'memory query'.""" + out = prompt_manager.render_prompt( + "search_agent", + variables={ + "query": "what does user like?", + "max_steps": "3", + "enable_agent_answer": "false", + }, + ) + assert "helping an AI agent" in out or "inform" in out + assert "memory query agent" not in out.lower() + + +def test_search_agent_emits_summary_info_line( + caplog, temp_storage, prompt_manager, llm_client +): + """Each run emits ONE INFO line starting with 'search_agent ' that + contains elapsed_ms, turns, outcome, answer_len, and usage.""" + import logging + + llm_client.generate_chat_response.side_effect = [ + _mk_resp( + [_mk_tc("c1", "search_user_profiles", {"query": "food", "top_k": 10})] + ), + _mk_resp([_mk_tc("c2", "finish", {"answer": "user likes sushi"})]), + ] + + agent = SearchAgent( + client=llm_client, + storage=temp_storage, + prompt_manager=prompt_manager, + enable_agent_answer=True, + ) + + with caplog.at_level( + logging.INFO, logger="reflexio.server.services.search.search_agent" + ): + agent.run(user_id="u_summary", agent_version="v1", query="what do I like?") + + summary = [r for r in caplog.records if r.getMessage().startswith("search_agent ")] + assert len(summary) == 1, ( + f"Expected 1 summary line, got: {[r.getMessage() for r in summary]}" + ) + msg = summary[0].getMessage() + assert "elapsed_ms=" in msg + assert "turns=" in msg + assert "outcome=" in msg + assert "answer_len=" in msg + assert "usage={" in msg diff --git a/tests/server/services/search/test_storage_stats_integration.py b/tests/server/services/search/test_storage_stats_integration.py new file mode 100644 index 00000000..8504ec24 --- /dev/null +++ b/tests/server/services/search/test_storage_stats_integration.py @@ -0,0 +1,138 @@ +"""Integration tests for storage_stats — Reflexio facade + tool handler.""" + +from __future__ import annotations + +from datetime import UTC, datetime + +import pytest + +from reflexio.models.api_schema.domain.entities import ( + NEVER_EXPIRES_TIMESTAMP, + UserPlaybook, + UserProfile, +) +from reflexio.models.api_schema.domain.enums import ProfileTimeToLive +from reflexio.models.api_schema.retriever_schema import StorageStatsRequest +from reflexio.server.services.extraction.plan import ExtractionCtx +from reflexio.server.services.extraction.tools import ( + StorageStatsArgs, + _handle_storage_stats, +) + +pytestmark = pytest.mark.integration + + +@pytest.fixture +def storage_with_data(tmp_path): + """Storage seeded with two profiles (different timestamps) + one playbook.""" + from reflexio.server.services.storage.sqlite_storage import SQLiteStorage + + storage = SQLiteStorage(org_id="stats-test", db_path=str(tmp_path / "stats.db")) + storage.add_user_profile( + "u_with", + [ + UserProfile( + user_id="u_with", + profile_id="p_old", + content="old content", + profile_time_to_live=ProfileTimeToLive.INFINITY, + last_modified_timestamp=1_700_000_000, + expiration_timestamp=NEVER_EXPIRES_TIMESTAMP, + source="test", + generated_from_request_id="r", + ), + UserProfile( + user_id="u_with", + profile_id="p_new", + content="new content", + profile_time_to_live=ProfileTimeToLive.INFINITY, + last_modified_timestamp=1_700_001_000, + expiration_timestamp=NEVER_EXPIRES_TIMESTAMP, + source="test", + generated_from_request_id="r", + ), + ], + ) + storage.save_user_playbooks( + [ + UserPlaybook( + user_playbook_id=0, + user_id="u_with", + agent_version="v1", + request_id="r", + playbook_name="p", + content="content", + trigger="trigger", + ) + ] + ) + return storage + + +def test_handler_counts_match(storage_with_data): + ctx = ExtractionCtx(user_id="u_with", agent_version="v1", extractor_name="p") + result = _handle_storage_stats(StorageStatsArgs(), storage_with_data, ctx) + assert result["profile_count"] == 2 + assert result["playbook_count"] == 1 + assert ctx.search_count == 0 # storage_stats does NOT bump search_count + + +def test_handler_returns_iso_timestamp_range(storage_with_data): + ctx = ExtractionCtx(user_id="u_with", agent_version="v1", extractor_name="p") + result = _handle_storage_stats(StorageStatsArgs(), storage_with_data, ctx) + expected_oldest = datetime.fromtimestamp(1_700_000_000, tz=UTC).isoformat() + expected_newest = datetime.fromtimestamp(1_700_001_000, tz=UTC).isoformat() + assert result["oldest_profile_modified"] == expected_oldest + assert result["newest_profile_modified"] == expected_newest + + +def test_handler_returns_null_timestamps_for_empty_user(storage_with_data): + ctx = ExtractionCtx(user_id="u_no_data", agent_version="v1", extractor_name="p") + result = _handle_storage_stats(StorageStatsArgs(), storage_with_data, ctx) + assert result["profile_count"] == 0 + assert result["playbook_count"] == 0 + assert result["oldest_profile_modified"] is None + assert result["newest_profile_modified"] is None + + +def test_reflexio_storage_stats_facade(tmp_path): + """The Reflexio facade method should populate every response field correctly.""" + from reflexio.lib.reflexio_lib import Reflexio + + reflexio = Reflexio(org_id="stats-facade", storage_base_dir=str(tmp_path)) + storage = reflexio._get_storage() + storage.add_user_profile( + "u_face", + [ + UserProfile( + user_id="u_face", + profile_id="p1", + content="profile one", + profile_time_to_live=ProfileTimeToLive.INFINITY, + last_modified_timestamp=1_700_000_000, + expiration_timestamp=NEVER_EXPIRES_TIMESTAMP, + source="test", + generated_from_request_id="r", + ), + ], + ) + response = reflexio.storage_stats(StorageStatsRequest(user_id="u_face")) + assert response.success is True + assert response.profile_count == 1 + assert response.playbook_count == 0 + assert response.oldest_profile_modified is not None + assert response.newest_profile_modified is not None + assert response.oldest_profile_modified == response.newest_profile_modified + + +def test_reflexio_storage_stats_empty_user(tmp_path): + """Empty user returns success with zeros and null timestamps.""" + from reflexio.lib.reflexio_lib import Reflexio + + reflexio = Reflexio(org_id="stats-empty", storage_base_dir=str(tmp_path)) + response = reflexio.storage_stats(StorageStatsRequest(user_id="ghost")) + assert response.success is True + assert response.profile_count == 0 + assert response.playbook_count == 0 + assert response.oldest_profile_modified is None + assert response.newest_profile_modified is None diff --git a/tests/server/services/storage/sqlite_storage/__init__.py b/tests/server/services/storage/sqlite_storage/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/server/services/storage/sqlite_storage/test_agentic_signals.py b/tests/server/services/storage/sqlite_storage/test_agentic_signals.py new file mode 100644 index 00000000..a0aee9cc --- /dev/null +++ b/tests/server/services/storage/sqlite_storage/test_agentic_signals.py @@ -0,0 +1,79 @@ +"""Task 2.4: agentic signal columns persist through profiles + user_playbooks.""" + +from __future__ import annotations + +import sqlite3 +from unittest.mock import patch + +import pytest + +from reflexio.server.services.storage.sqlite_storage import SQLiteStorage + +pytestmark = pytest.mark.integration + + +def _get_columns(db_path: str, table: str) -> set[str]: + conn = sqlite3.connect(db_path) + try: + return { + row[1] for row in conn.execute(f"PRAGMA table_info({table})").fetchall() + } + finally: + conn.close() + + +def test_fresh_schema_has_agentic_signal_columns(tmp_path): + """Fresh SQLiteStorage DBs include source_span/notes/reader_angle on both tables.""" + db_path = str(tmp_path / "fresh.db") + with patch.object(SQLiteStorage, "_get_embedding", return_value=[0.0] * 512): + SQLiteStorage(org_id="test_fresh", db_path=db_path) + assert {"source_span", "notes", "reader_angle"} <= _get_columns(db_path, "profiles") + assert {"source_span", "notes", "reader_angle"} <= _get_columns( + db_path, "user_playbooks" + ) + + +def test_migration_adds_columns_to_legacy_db(tmp_path): + """A pre-existing DB without the new columns gets them added at startup. + + The legacy schema simulates a DB created just before the agentic signal + columns were introduced — all existing columns are present, but + source_span/notes/reader_angle are absent. + """ + db_path = str(tmp_path / "legacy.db") + conn = sqlite3.connect(db_path) + # Profiles table without source_span/notes/reader_angle + conn.execute( + "CREATE TABLE profiles (" + "profile_id TEXT PRIMARY KEY, user_id TEXT NOT NULL, " + "content TEXT NOT NULL DEFAULT '', " + "last_modified_timestamp INTEGER NOT NULL, " + "generated_from_request_id TEXT NOT NULL DEFAULT '', " + "profile_time_to_live TEXT NOT NULL DEFAULT 'infinity', " + "expiration_timestamp INTEGER NOT NULL DEFAULT 4102444800, " + "custom_features TEXT, embedding TEXT, " + "source TEXT DEFAULT '', status TEXT, extractor_names TEXT, " + "expanded_terms TEXT, " + "created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ', 'now')))" + ) + # user_playbooks table without source_span/notes/reader_angle + conn.execute( + "CREATE TABLE user_playbooks (" + "user_playbook_id INTEGER PRIMARY KEY AUTOINCREMENT, " + "user_id TEXT, playbook_name TEXT NOT NULL DEFAULT '', " + "created_at TEXT NOT NULL, request_id TEXT NOT NULL, " + "agent_version TEXT NOT NULL DEFAULT '', " + "content TEXT NOT NULL DEFAULT '', trigger TEXT, rationale TEXT, " + "blocking_issue TEXT, source_interaction_ids TEXT, " + "status TEXT, source TEXT, embedding TEXT, expanded_terms TEXT)" + ) + conn.commit() + conn.close() + + with patch.object(SQLiteStorage, "_get_embedding", return_value=[0.0] * 512): + SQLiteStorage(org_id="test_legacy", db_path=db_path) + + assert {"source_span", "notes", "reader_angle"} <= _get_columns(db_path, "profiles") + assert {"source_span", "notes", "reader_angle"} <= _get_columns( + db_path, "user_playbooks" + ) diff --git a/tests/server/services/test_extractor_interaction_utils.py b/tests/server/services/test_extractor_interaction_utils.py index c2ccdf20..52192886 100644 --- a/tests/server/services/test_extractor_interaction_utils.py +++ b/tests/server/services/test_extractor_interaction_utils.py @@ -321,7 +321,9 @@ def test_empty_list_yields_nothing(self): def test_single_model_fits_in_window(self): """Test single model that fits in window yields one window.""" models = [_create_mock_request_interaction_model(5)] - windows = list(iter_sliding_windows(models, batch_size=10, batch_interval_size=5)) + windows = list( + iter_sliding_windows(models, batch_size=10, batch_interval_size=5) + ) assert len(windows) == 1 assert windows[0][0] == 0 # window index @@ -333,7 +335,9 @@ def test_multiple_models_fit_in_one_window(self): _create_mock_request_interaction_model(3), _create_mock_request_interaction_model(4), ] - windows = list(iter_sliding_windows(models, batch_size=10, batch_interval_size=5)) + windows = list( + iter_sliding_windows(models, batch_size=10, batch_interval_size=5) + ) assert len(windows) == 1 assert windows[0][0] == 0 @@ -352,7 +356,9 @@ def test_basic_sliding_window(self): _create_mock_request_interaction_model(10), _create_mock_request_interaction_model(10), ] - windows = list(iter_sliding_windows(models, batch_size=15, batch_interval_size=10)) + windows = list( + iter_sliding_windows(models, batch_size=15, batch_interval_size=10) + ) assert len(windows) == 3 # Window 0: covers [0-14], includes models[0] and models[1] @@ -372,7 +378,9 @@ def test_non_overlapping_windows(self): _create_mock_request_interaction_model(10), _create_mock_request_interaction_model(10), ] - windows = list(iter_sliding_windows(models, batch_size=10, batch_interval_size=10)) + windows = list( + iter_sliding_windows(models, batch_size=10, batch_interval_size=10) + ) assert len(windows) == 3 # Each window should contain exactly one model @@ -388,7 +396,9 @@ def test_stride_larger_than_window(self): _create_mock_request_interaction_model(10), ] # batch_size=5, stride=15 means windows at positions 0-4, 15-19 - windows = list(iter_sliding_windows(models, batch_size=5, batch_interval_size=15)) + windows = list( + iter_sliding_windows(models, batch_size=5, batch_interval_size=15) + ) assert len(windows) == 2 # Window 0: covers 0-4, only models[0] @@ -401,7 +411,9 @@ def test_stride_larger_than_window(self): def test_invalid_window_size_zero(self): """Test that batch_size=0 yields single window with all data.""" models = [_create_mock_request_interaction_model(10)] - windows = list(iter_sliding_windows(models, batch_size=0, batch_interval_size=5)) + windows = list( + iter_sliding_windows(models, batch_size=0, batch_interval_size=5) + ) assert len(windows) == 1 assert windows[0][1] == models @@ -409,7 +421,9 @@ def test_invalid_window_size_zero(self): def test_invalid_window_size_negative(self): """Test that negative window_size yields single window with all data.""" models = [_create_mock_request_interaction_model(10)] - windows = list(iter_sliding_windows(models, batch_size=-5, batch_interval_size=5)) + windows = list( + iter_sliding_windows(models, batch_size=-5, batch_interval_size=5) + ) assert len(windows) == 1 assert windows[0][1] == models @@ -421,7 +435,9 @@ def test_stride_zero_defaults_to_window_size(self): _create_mock_request_interaction_model(10), ] # stride=0 should default to batch_size=10, yielding 2 non-overlapping windows - windows = list(iter_sliding_windows(models, batch_size=10, batch_interval_size=0)) + windows = list( + iter_sliding_windows(models, batch_size=10, batch_interval_size=0) + ) assert len(windows) == 2 @@ -432,7 +448,9 @@ def test_stride_none_defaults_to_window_size(self): _create_mock_request_interaction_model(10), ] # stride=None should default to batch_size=10 - windows = list(iter_sliding_windows(models, batch_size=10, batch_interval_size=None)) + windows = list( + iter_sliding_windows(models, batch_size=10, batch_interval_size=None) + ) assert len(windows) == 2 @@ -446,7 +464,9 @@ def test_models_with_varying_sizes(self): ] # Total: 30 interactions # batch_size=15, stride=10 - windows = list(iter_sliding_windows(models, batch_size=15, batch_interval_size=10)) + windows = list( + iter_sliding_windows(models, batch_size=15, batch_interval_size=10) + ) assert len(windows) == 3 # Window 0: covers [0-14], models[0] (0-4) and models[1] (5-24) overlap @@ -464,7 +484,9 @@ def test_preserves_model_order(self): _create_mock_request_interaction_model(5), _create_mock_request_interaction_model(5), ] - windows = list(iter_sliding_windows(models, batch_size=10, batch_interval_size=5)) + windows = list( + iter_sliding_windows(models, batch_size=10, batch_interval_size=5) + ) # First window should have models[0] and models[1] in order assert windows[0][1][0] is models[0] @@ -478,7 +500,9 @@ def test_model_with_zero_interactions_included(self): _create_mock_request_interaction_model(10), ] # Total: 20 interactions, empty model at position 10 - windows = list(iter_sliding_windows(models, batch_size=15, batch_interval_size=10)) + windows = list( + iter_sliding_windows(models, batch_size=15, batch_interval_size=10) + ) assert len(windows) == 2 @@ -488,14 +512,18 @@ def test_all_empty_models_yields_nothing(self): _create_mock_request_interaction_model(0), _create_mock_request_interaction_model(0), ] - windows = list(iter_sliding_windows(models, batch_size=10, batch_interval_size=5)) + windows = list( + iter_sliding_windows(models, batch_size=10, batch_interval_size=5) + ) assert windows == [] def test_window_indices_are_sequential(self): """Test that window indices are sequential starting from 0.""" models = [_create_mock_request_interaction_model(10) for _ in range(5)] - windows = list(iter_sliding_windows(models, batch_size=10, batch_interval_size=10)) + windows = list( + iter_sliding_windows(models, batch_size=10, batch_interval_size=10) + ) indices = [w[0] for w in windows] assert indices == list(range(5)) diff --git a/tests/server/services/test_generation_service_dispatcher.py b/tests/server/services/test_generation_service_dispatcher.py new file mode 100644 index 00000000..21184f14 --- /dev/null +++ b/tests/server/services/test_generation_service_dispatcher.py @@ -0,0 +1,73 @@ +"""Task 2.6: config dispatcher for extraction/search backends.""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +from reflexio.models.config_schema import Config, StorageConfigSQLite +from reflexio.server.services.generation_service import ( + build_extraction_service, + build_search_service, +) + + +def _make_config(**overrides) -> Config: + """Build a minimal Config with optional field overrides. + + Args: + **overrides: Field overrides for Config. + + Returns: + Config: Minimal valid Config instance. + """ + base: dict = { + "storage_config": StorageConfigSQLite(), + } + base.update(overrides) + return Config(**base) + + +def test_config_defaults_extraction_backend_to_classic() -> None: + config = _make_config() + assert config.extraction_backend == "classic" + + +def test_config_defaults_search_backend_to_classic() -> None: + config = _make_config() + assert config.search_backend == "classic" + + +def test_config_accepts_agentic_backends() -> None: + config = _make_config(extraction_backend="agentic", search_backend="agentic") + assert config.extraction_backend == "agentic" + assert config.search_backend == "agentic" + + +def test_build_extraction_service_picks_classic_by_default() -> None: + config = _make_config() + svc = build_extraction_service( + config, llm_client=MagicMock(), request_context=MagicMock() + ) + assert svc.__class__.__name__ == "ProfileGenerationService" + + +def test_build_search_service_picks_classic_by_default() -> None: + config = _make_config() + svc = build_search_service( + config, llm_client=MagicMock(), request_context=MagicMock() + ) + assert svc.__class__.__name__ == "UnifiedSearchService" + + +def test_build_search_service_picks_agentic_when_configured() -> None: + # AgenticSearchService now lives alongside the dispatcher; if the import + # fails the dispatcher itself is broken — fail fast instead of skipping. + from reflexio.server.services.search.agentic_search_service import ( + AgenticSearchService, + ) + + config = _make_config(search_backend="agentic") + svc = build_search_service( + config, llm_client=MagicMock(), request_context=MagicMock() + ) + assert isinstance(svc, AgenticSearchService) diff --git a/tests/server/services/test_profile_generation_service.py b/tests/server/services/test_profile_generation_service.py index 941fbbfe..e4c8333a 100644 --- a/tests/server/services/test_profile_generation_service.py +++ b/tests/server/services/test_profile_generation_service.py @@ -1138,7 +1138,7 @@ def test_should_run_before_extraction_combines_all_extractor_criteria(): user_id=user_id, request_id="request-1", content="I am leading a migration project and prefer concise updates.", - role="user", + role="User", created_at=int(datetime.datetime.now(UTC).timestamp()), ) request_obj = Request( diff --git a/tests/server/services/test_prompt_model_mapping.py b/tests/server/services/test_prompt_model_mapping.py index 8f2c4b1d..076621c6 100644 --- a/tests/server/services/test_prompt_model_mapping.py +++ b/tests/server/services/test_prompt_model_mapping.py @@ -32,20 +32,18 @@ "playbook_extraction_main": ("v1.0.0", "playbook_extraction"), "playbook_extraction_main_incremental": ("v1.0.0", "playbook_extraction"), "playbook_extraction_context": ("v4.0.1", None), - "playbook_extraction_context_incremental": ("v4.0.1", None), + "playbook_extraction_context_incremental": ("v4.0.0", None), "playbook_should_generate": ("v3.0.0", "boolean_evaluation"), "playbook_should_generate_expert": ("v1.0.0", "boolean_evaluation"), "playbook_extraction_context_expert": ("v3.0.0", None), "playbook_extraction_main_expert": ("v1.0.0", "playbook_extraction"), - "playbook_aggregation": ("v2.0.0", "playbook_aggregation"), - "playbook_deduplication": ("v2.0.0", "playbook_deduplication"), + "playbook_aggregation": ("v2.1.0", "playbook_aggregation"), "profile_update_main": ("v1.0.0", "profile_extraction"), "profile_update_main_incremental": ("v1.0.0", "profile_extraction"), "profile_update_instruction_start": ("v1.0.0", None), "profile_update_instruction_incremental": ("v1.0.0", None), "profile_should_generate": ("v1.0.0", "boolean_evaluation"), "profile_should_generate_override": ("v1.0.0", "boolean_evaluation"), - "profile_deduplication": ("v1.0.0", "profile_deduplication"), "agent_success_evaluation": ("v1.0.0", "agent_success_evaluation"), "agent_success_evaluation_with_comparison": ( "v1.0.0", @@ -54,6 +52,10 @@ "shadow_content_evaluation": ("v1.0.0", None), "query_reformulation": ("v1.0.0", None), "document_expansion": ("v1.0.0", None), + # Agentic extraction pipeline — Phase 3 (v2 single-loop) + "extraction_agent": ("v1.5.0", None), + # Agentic search pipeline — agentic-v2 single-loop agent + "search_agent": ("v1.3.0", None), } diff --git a/tests/server/services/test_service_utils.py b/tests/server/services/test_service_utils.py index ae2d768f..1cbe18d0 100644 --- a/tests/server/services/test_service_utils.py +++ b/tests/server/services/test_service_utils.py @@ -235,8 +235,13 @@ def test_format_sessions_to_history_string_empty(): def test_format_sessions_to_history_string_single_group(): - """Test formatting a single session.""" + """Test formatting a single session. + + Header includes the session date so downstream extraction agents have + a temporal anchor for relative-time references in the conversation. + """ base_time = int(datetime.now(UTC).timestamp()) + iso = datetime.fromtimestamp(base_time, tz=UTC).strftime("%Y-%m-%d") session_data = RequestInteractionDataModel( session_id="group_1", @@ -248,7 +253,10 @@ def test_format_sessions_to_history_string_single_group(): ) result = format_sessions_to_history_string([session_data]) - expected = "=== Session: group_1 ===\nuser: ```Hello```\nassistant: ```Hi there!```" + expected = ( + f"=== Session: group_1 (date: {iso}) ===\n" + "user: ```Hello```\nassistant: ```Hi there!```" + ) assert result == expected @@ -288,9 +296,10 @@ def test_format_sessions_to_history_string_consolidates_same_group(): [session_id_1, session_id_2, session_id_3] ) + iso = datetime.fromtimestamp(base_time, tz=UTC).strftime("%Y-%m-%d") # All interactions should be under a single header expected = ( - "=== Session: group_1 ===\n" + f"=== Session: group_1 (date: {iso}) ===\n" "user: ```First message```\n" "assistant: ```First response```\n" "user: ```Second message```\n" @@ -322,10 +331,12 @@ def test_format_sessions_to_history_string_multiple_groups(): ) result = format_sessions_to_history_string([group_a, group_b]) + iso_a = datetime.fromtimestamp(base_time, tz=UTC).strftime("%Y-%m-%d") + iso_b = datetime.fromtimestamp(base_time + 100, tz=UTC).strftime("%Y-%m-%d") expected = ( - "=== Session: session_a ===\n" + f"=== Session: session_a (date: {iso_a}) ===\n" "user: ```Message A```\n\n" - "=== Session: session_b ===\n" + f"=== Session: session_b (date: {iso_b}) ===\n" "user: ```Message B```" ) assert result == expected @@ -365,13 +376,15 @@ def test_format_sessions_to_history_string_mixed_groups(): [group_1_req_1, group_2_req, group_1_req_2] ) + iso_1 = datetime.fromtimestamp(base_time, tz=UTC).strftime("%Y-%m-%d") + iso_2 = datetime.fromtimestamp(base_time + 50, tz=UTC).strftime("%Y-%m-%d") # Groups should be sorted by earliest request timestamp # group_1 (base_time) should come before group_2 (base_time + 50) expected = ( - "=== Session: group_1 ===\n" + f"=== Session: group_1 (date: {iso_1}) ===\n" "user: ```Group 1 - Request 1```\n" "user: ```Group 1 - Request 2```\n\n" - "=== Session: group_2 ===\n" + f"=== Session: group_2 (date: {iso_2}) ===\n" "user: ```Group 2 - Request 1```" ) assert result == expected @@ -411,9 +424,10 @@ def test_format_sessions_to_history_string_preserves_order_within_group(): [late_request, early_request, middle_request] ) + iso = datetime.fromtimestamp(base_time, tz=UTC).strftime("%Y-%m-%d") # Should be sorted by created_at within the group expected = ( - "=== Session: group_1 ===\n" + f"=== Session: group_1 (date: {iso}) ===\n" "user: ```Early message```\n" "user: ```Middle message```\n" "user: ```Late message```" diff --git a/tests/server/services/test_service_utils_extended.py b/tests/server/services/test_service_utils_extended.py index bb807b44..7938f398 100644 --- a/tests/server/services/test_service_utils_extended.py +++ b/tests/server/services/test_service_utils_extended.py @@ -204,3 +204,174 @@ def test_format_messages_for_logging_list_content(): assert "role: user" in result assert "Describe this image" in result assert "image_url" in result + + +def test_format_messages_for_logging_renders_assistant_tool_calls_sdk_shape(): + """Assistant messages with SDK-object tool_calls must render id/name/arguments. + + Before this fix, an assistant message with ``content=None`` and only + ``tool_calls`` looked like ``content: null`` with zero visibility into + the tools the model invoked. + """ + from types import SimpleNamespace + + tc = SimpleNamespace( + id="call_abc", + function=SimpleNamespace( + name="flag_cross_entity_conflict", + arguments='{"candidate_index":0,"reason":"contradicts profile"}', + ), + ) + messages = [{"role": "assistant", "content": None, "tool_calls": [tc]}] + + result = format_messages_for_logging(messages) + + assert "role: assistant" in result + assert "content: null" in result + assert "tool_calls:" in result + assert "- id: call_abc" in result + assert "name: flag_cross_entity_conflict" in result + # Arguments should be parsed + re-serialised for readability + assert '"candidate_index": 0' in result + assert '"reason": "contradicts profile"' in result + + +def test_format_messages_for_logging_renders_assistant_tool_calls_dict_shape(): + """Pass-through serialisation sometimes produces dict-shaped tool_calls.""" + messages = [ + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_xyz", + "type": "function", + "function": { + "name": "emit_profile", + "arguments": '{"content":"User likes Go","time_to_live":"infinity"}', + }, + } + ], + } + ] + + result = format_messages_for_logging(messages) + + assert "- id: call_xyz" in result + assert "name: emit_profile" in result + assert '"content": "User likes Go"' in result + + +def test_format_messages_for_logging_renders_tool_call_id_on_tool_role(): + """Tool-role messages must surface tool_call_id so readers can correlate.""" + messages = [ + {"role": "tool", "tool_call_id": "call_abc", "content": '{"flagged": 0}'}, + ] + + result = format_messages_for_logging(messages) + + assert "role: tool" in result + assert "tool_call_id: call_abc" in result + assert '{"flagged": 0}' in result + + +def test_format_messages_for_logging_handles_malformed_arguments_json(): + """Tool_call arguments that aren't valid JSON should fall back to raw string.""" + from types import SimpleNamespace + + tc = SimpleNamespace( + id="call_bad", + function=SimpleNamespace(name="emit", arguments="not valid json {"), + ) + messages = [{"role": "assistant", "content": None, "tool_calls": [tc]}] + + result = format_messages_for_logging(messages) + + # Formatter must not crash, and should preserve the raw string + assert "name: emit" in result + assert "not valid json {" in result + + +def test_format_messages_for_logging_skips_tool_calls_block_when_absent(): + """Assistant messages without tool_calls don't emit a ``tool_calls:`` header.""" + messages = [{"role": "assistant", "content": "plain text response"}] + + result = format_messages_for_logging(messages) + + assert "tool_calls:" not in result + assert "plain text response" in result + + +# --------------------------------------------------------------------------- +# _format_response_for_logging — ToolCallingChatResponse rendering +# --------------------------------------------------------------------------- + + +def test_format_response_renders_tool_calling_chat_response_with_sdk_tool_calls(): + """ToolCallingChatResponse with SDK-shaped tool_calls renders id/name/arguments.""" + from types import SimpleNamespace + + from reflexio.server.llm.litellm_client import ToolCallingChatResponse + from reflexio.server.services.service_utils import _format_response_for_logging + + tc = SimpleNamespace( + id="call_abc", + function=SimpleNamespace(name="rank", arguments='{"ordered_ids":["b1","b2"]}'), + ) + resp = ToolCallingChatResponse( + content=None, tool_calls=[tc], finish_reason="tool_calls" + ) + + out = _format_response_for_logging(resp) + + assert isinstance(out, str) + assert "ToolCallingChatResponse(finish_reason='tool_calls')" in out + assert "content: None" in out + assert "tool_calls:" in out + assert "- id: call_abc" in out + assert "name: rank" in out + # Arguments are parsed from JSON + re-serialized for readability + assert '"ordered_ids": ["b1", "b2"]' in out + + +def test_format_response_renders_tool_calling_chat_response_with_empty_tool_calls(): + """ToolCallingChatResponse with no tool_calls still renders content + finish_reason.""" + from reflexio.server.llm.litellm_client import ToolCallingChatResponse + from reflexio.server.services.service_utils import _format_response_for_logging + + resp = ToolCallingChatResponse( + content="plain text reply", tool_calls=None, finish_reason="stop" + ) + + out = _format_response_for_logging(resp) + + assert "ToolCallingChatResponse(finish_reason='stop')" in out + assert "content: 'plain text reply'" in out + assert "tool_calls: []" in out + + +def test_format_response_passes_basemodel_through_unchanged(): + """Pydantic BaseModel responses (classic extractor / deduplicator outputs) + must NOT be transformed — preserves existing llm_io.log shape for classic.""" + from pydantic import BaseModel + + from reflexio.server.services.service_utils import _format_response_for_logging + + class FakeClassicOutput(BaseModel): + profiles: list[str] = [] + + resp = FakeClassicOutput(profiles=["User likes polars"]) + + out = _format_response_for_logging(resp) + + # The helper returned the same object — caller's %s formatter will + # render it via str(resp) exactly as today. + assert out is resp + + +def test_format_response_passes_string_through_unchanged(): + """Plain strings go straight through (tool_loop handlers return strings).""" + from reflexio.server.services.service_utils import _format_response_for_logging + + out = _format_response_for_logging("raw string response") + assert out == "raw string response" diff --git a/tests/server/site_var/test_feature_flags.py b/tests/server/site_var/test_feature_flags.py index 601ea9f4..5b02bdbb 100644 --- a/tests/server/site_var/test_feature_flags.py +++ b/tests/server/site_var/test_feature_flags.py @@ -3,7 +3,6 @@ from reflexio.server.site_var.feature_flags import ( get_all_feature_flags, - is_deduplicator_enabled, is_feature_enabled, ) @@ -121,31 +120,6 @@ def test_get_all_flags_empty_config(self, _mock): result = get_all_feature_flags("org-123") self.assertEqual(result, {}) - @patch( - "reflexio.server.site_var.feature_flags._get_feature_flags_config", - return_value=MOCK_CONFIG, - ) - def test_is_deduplicator_enabled_for_enabled_org(self, _mock): - """is_deduplicator_enabled should return True for orgs in enabled_org_ids.""" - self.assertTrue(is_deduplicator_enabled("org-dedup")) - - @patch( - "reflexio.server.site_var.feature_flags._get_feature_flags_config", - return_value=MOCK_CONFIG, - ) - def test_is_deduplicator_disabled_for_other_org(self, _mock): - """is_deduplicator_enabled should return False for orgs not in enabled_org_ids.""" - self.assertFalse(is_deduplicator_enabled("org-123")) - self.assertFalse(is_deduplicator_enabled("org-999")) - - @patch( - "reflexio.server.site_var.feature_flags._get_feature_flags_config", - return_value={}, - ) - def test_is_deduplicator_enabled_unknown_defaults_enabled(self, _mock): - """is_deduplicator_enabled with empty config should default to enabled.""" - self.assertTrue(is_deduplicator_enabled("org-123")) - if __name__ == "__main__": unittest.main() diff --git a/tests/server/test_logging_timezone.py b/tests/server/test_logging_timezone.py new file mode 100644 index 00000000..869b2ede --- /dev/null +++ b/tests/server/test_logging_timezone.py @@ -0,0 +1,50 @@ +"""Tests for TZ-aware log formatters in reflexio.server.__init__.""" + +from __future__ import annotations + +import logging +import re + +from reflexio.server import _LLMIOFormatter, _TZAwareFormatter + +_TZ_PATTERN = re.compile( + r"\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\.\d{3} [+-]\d{2}:\d{2}(?: [A-Z]{1,5})?" +) + + +def _make_record(msg: str = "payload") -> logging.LogRecord: + return logging.LogRecord( + name="reflexio.server.services.tools", + level=logging.DEBUG, + pathname="", + lineno=0, + msg=msg, + args=(), + exc_info=None, + ) + + +class TestTZAwareFormatter: + def test_format_time_contains_offset(self) -> None: + formatter = _TZAwareFormatter() + record = _make_record() + rendered = formatter.formatTime(record) + assert _TZ_PATTERN.match(rendered), f"timestamp missing TZ offset: {rendered!r}" + + def test_format_substitutes_asctime_with_offset(self) -> None: + """Verify the %(asctime)s path surfaces the TZ-aware timestamp.""" + formatter = _TZAwareFormatter("%(asctime)s %(levelname)s %(message)s") + record = _make_record("hello") + out = formatter.format(record) + assert _TZ_PATTERN.search(out), f"asctime missing TZ offset: {out!r}" + assert "hello" in out + + +class TestLLMIOFormatter: + def test_rendered_header_includes_tz_offset(self) -> None: + """The _LLMIOFormatter's header line must carry a TZ offset so + llm_io.log readers in any zone can localise the timestamp.""" + formatter = _LLMIOFormatter() + record = _make_record("full message payload") + out = formatter.format(record) + assert _TZ_PATTERN.search(out), f"header missing TZ offset: {out!r}" diff --git a/tests/server/test_uvicorn_logging.py b/tests/server/test_uvicorn_logging.py index 16fe027b..50a6793e 100644 --- a/tests/server/test_uvicorn_logging.py +++ b/tests/server/test_uvicorn_logging.py @@ -57,3 +57,30 @@ def test_dict_is_valid_dictconfig(self) -> None: def test_loggers_wire_uvicorn_names(self) -> None: names = set(UVICORN_LOG_CONFIG["loggers"]) assert {"uvicorn", "uvicorn.error", "uvicorn.access"}.issubset(names) + + @pytest.mark.usefixtures("isolate_logging_state") + def test_access_formatter_emits_without_keyerror( + self, capsys: pytest.CaptureFixture[str] + ) -> None: + """Regression: stdlib ``logging.Formatter`` doesn't know the + uvicorn-specific ``client_addr`` / ``request_line`` / ``status_code`` + fields; the access formatter must be wired to + ``uvicorn.logging.AccessFormatter`` via the ``()`` factory key or + every request raises ``KeyError: 'client_addr'`` at emit time. + """ + logging.config.dictConfig(UVICORN_LOG_CONFIG) + access = logging.getLogger("uvicorn.access") + # Shape matches uvicorn's real access-log emission — positional args + # consumed by AccessFormatter to derive client_addr / request_line / status_code. + access.info( + '%s - "%s %s HTTP/%s" %d', + "127.0.0.1:12345", + "POST", + "/api/ping", + "1.1", + 200, + ) + out = capsys.readouterr().out + assert "127.0.0.1:12345" in out + assert "POST /api/ping HTTP/1.1" in out + assert "200" in out diff --git a/uv.lock b/uv.lock index 4759529c..f8d3d340 100644 --- a/uv.lock +++ b/uv.lock @@ -975,6 +975,75 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d2/f1/00ce3bde3ca542d1acd8f8cfa38e446840945aa6363f9b74746394b14127/cryptography-46.0.7-cp38-abi3-win_amd64.whl", hash = "sha256:506c4ff91eff4f82bdac7633318a526b1d1309fc07ca76a3ad182cb5b686d6d3", size = 3472985, upload-time = "2026-04-08T01:57:36.714Z" }, ] +[[package]] +name = "cuda-bindings" +version = "13.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cuda-pathfinder", marker = "sys_platform != 'emscripten' and sys_platform != 'win32'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/52/c8/b2589d68acf7e3d63e2be330b84bc25712e97ed799affbca7edd7eae25d6/cuda_bindings-13.2.0-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e865447abfb83d6a98ad5130ed3c70b1fc295ae3eeee39fd07b4ddb0671b6788", size = 5722404, upload-time = "2026-03-11T00:12:44.041Z" }, + { url = "https://files.pythonhosted.org/packages/1f/92/f899f7bbb5617bb65ec52a6eac1e9a1447a86b916c4194f8a5001b8cde0c/cuda_bindings-13.2.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:46d8776a55d6d5da9dd6e9858fba2efcda2abe6743871dee47dd06eb8cb6d955", size = 6320619, upload-time = "2026-03-11T00:12:45.939Z" }, + { url = "https://files.pythonhosted.org/packages/df/93/eef988860a3ca985f82c4f3174fc0cdd94e07331ba9a92e8e064c260337f/cuda_bindings-13.2.0-cp313-cp313-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6629ca2df6f795b784752409bcaedbd22a7a651b74b56a165ebc0c9dcbd504d0", size = 5614610, upload-time = "2026-03-11T00:12:50.337Z" }, + { url = "https://files.pythonhosted.org/packages/18/23/6db3aba46864aee357ab2415135b3fe3da7e9f1fa0221fa2a86a5968099c/cuda_bindings-13.2.0-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7dca0da053d3b4cc4869eff49c61c03f3c5dbaa0bcd712317a358d5b8f3f385d", size = 6149914, upload-time = "2026-03-11T00:12:52.374Z" }, + { url = "https://files.pythonhosted.org/packages/c0/87/87a014f045b77c6de5c8527b0757fe644417b184e5367db977236a141602/cuda_bindings-13.2.0-cp314-cp314-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a6464b30f46692d6c7f65d4a0e0450d81dd29de3afc1bb515653973d01c2cd6e", size = 5685673, upload-time = "2026-03-11T00:12:56.371Z" }, + { url = "https://files.pythonhosted.org/packages/ee/5e/c0fe77a73aaefd3fff25ffaccaac69c5a63eafdf8b9a4c476626ef0ac703/cuda_bindings-13.2.0-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f4af9f3e1be603fa12d5ad6cfca7844c9d230befa9792b5abdf7dd79979c3626", size = 6191386, upload-time = "2026-03-11T00:12:58.965Z" }, + { url = "https://files.pythonhosted.org/packages/5f/58/ed2c3b39c8dd5f96aa7a4abef0d47a73932c7a988e30f5fa428f00ed0da1/cuda_bindings-13.2.0-cp314-cp314t-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:df850a1ff8ce1b3385257b08e47b70e959932f5f432d0a4e46a355962b4e4771", size = 5507469, upload-time = "2026-03-11T00:13:04.063Z" }, + { url = "https://files.pythonhosted.org/packages/1f/01/0c941b112ceeb21439b05895eace78ca1aa2eaaf695c8521a068fd9b4c00/cuda_bindings-13.2.0-cp314-cp314t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e8a16384c6494e5485f39314b0b4afb04bee48d49edb16d5d8593fd35bbd231b", size = 6059693, upload-time = "2026-03-11T00:13:06.003Z" }, +] + +[[package]] +name = "cuda-pathfinder" +version = "1.5.3" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d3/d6/ac63065d33dd700fee7ebd7d287332401b54e31b9346e142f871e1f0b116/cuda_pathfinder-1.5.3-py3-none-any.whl", hash = "sha256:dff021123aedbb4117cc7ec81717bbfe198fb4e8b5f1ee57e0e084fec5c8577d", size = 49991, upload-time = "2026-04-14T20:09:27.037Z" }, +] + +[[package]] +name = "cuda-toolkit" +version = "13.0.2" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/57/b2/453099f5f3b698d7d0eab38916aac44c7f76229f451709e2eb9db6615dcd/cuda_toolkit-13.0.2-py2.py3-none-any.whl", hash = "sha256:b198824cf2f54003f50d64ada3a0f184b42ca0846c1c94192fa269ecd97a66eb", size = 2364, upload-time = "2025-12-19T23:24:07.328Z" }, +] + +[package.optional-dependencies] +cublas = [ + { name = "nvidia-cublas", marker = "sys_platform == 'linux'" }, +] +cudart = [ + { name = "nvidia-cuda-runtime", marker = "sys_platform == 'linux'" }, +] +cufft = [ + { name = "nvidia-cufft", marker = "sys_platform == 'linux'" }, +] +cufile = [ + { name = "nvidia-cufile", marker = "sys_platform == 'linux'" }, +] +cupti = [ + { name = "nvidia-cuda-cupti", marker = "sys_platform == 'linux'" }, +] +curand = [ + { name = "nvidia-curand", marker = "sys_platform == 'linux'" }, +] +cusolver = [ + { name = "nvidia-cusolver", marker = "sys_platform == 'linux'" }, +] +cusparse = [ + { name = "nvidia-cusparse", marker = "sys_platform == 'linux'" }, +] +nvjitlink = [ + { name = "nvidia-nvjitlink", marker = "sys_platform == 'linux'" }, +] +nvrtc = [ + { name = "nvidia-cuda-nvrtc", marker = "sys_platform == 'linux'" }, +] +nvtx = [ + { name = "nvidia-nvtx", marker = "sys_platform == 'linux'" }, +] + [[package]] name = "cycler" version = "0.12.1" @@ -3023,6 +3092,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/46/4f/8812a01e3e0bd6be3e13b90432fb5c696af9a720af3f00e6eba5ad748345/moto-5.1.22-py3-none-any.whl", hash = "sha256:d9f20ae3cf29c44f93c1f8f06c8f48d5560e5dc027816ef1d0d2059741ffcfbe", size = 6617400, upload-time = "2026-03-08T21:06:41.093Z" }, ] +[[package]] +name = "mpmath" +version = "1.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e0/47/dd32fa426cc72114383ac549964eecb20ecfd886d1e5ccf5340b55b02f57/mpmath-1.3.0.tar.gz", hash = "sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f", size = 508106, upload-time = "2023-03-07T16:47:11.061Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/43/e3/7d92a15f894aa0c9c4b49b8ee9ac9850d6e63b03c9c32c0367a13ae62209/mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c", size = 536198, upload-time = "2023-03-07T16:47:09.197Z" }, +] + [[package]] name = "multidict" version = "6.7.1" @@ -3229,6 +3307,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a0/c4/c2971a3ba4c6103a3d10c4b0f24f461ddc027f0f09763220cf35ca1401b3/nest_asyncio-1.6.0-py3-none-any.whl", hash = "sha256:87af6efd6b5e897c81050477ef65c62e2b2f35d51703cae01aff2905b1852e1c", size = 5195, upload-time = "2024-01-21T14:25:17.223Z" }, ] +[[package]] +name = "networkx" +version = "3.6.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/6a/51/63fe664f3908c97be9d2e4f1158eb633317598cfa6e1fc14af5383f17512/networkx-3.6.1.tar.gz", hash = "sha256:26b7c357accc0c8cde558ad486283728b65b6a95d85ee1cd66bafab4c8168509", size = 2517025, upload-time = "2025-12-08T17:02:39.908Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9e/c9/b2622292ea83fbb4ec318f5b9ab867d0a28ab43c5717bb85b0a5f6b3b0a4/networkx-3.6.1-py3-none-any.whl", hash = "sha256:d47fbf302e7d9cbbb9e2555a0d267983d2aa476bac30e90dfbe5669bd57f3762", size = 2068504, upload-time = "2025-12-08T17:02:38.159Z" }, +] + [[package]] name = "nh3" version = "0.3.4" @@ -3360,6 +3447,155 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/58/78/548fb8e07b1a341746bfbecb32f2c268470f45fa028aacdbd10d9bc73aab/numpy-2.4.4-cp314-cp314t-win_arm64.whl", hash = "sha256:ba203255017337d39f89bdd58417f03c4426f12beed0440cfd933cb15f8669c7", size = 10566643, upload-time = "2026-03-29T13:21:34.339Z" }, ] +[[package]] +name = "nvidia-cublas" +version = "13.1.0.3" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e1/a5/fce49e2ae977e0ccc084e5adafceb4f0ac0c8333cb6863501618a7277f67/nvidia_cublas-13.1.0.3-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:c86fc7f7ae36d7528288c5d88098edcb7b02c633d262e7ddbb86b0ad91be5df2", size = 542851226, upload-time = "2025-10-09T08:59:04.818Z" }, + { url = "https://files.pythonhosted.org/packages/e7/44/423ac00af4dd95a5aeb27207e2c0d9b7118702149bf4704c3ddb55bb7429/nvidia_cublas-13.1.0.3-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:ee8722c1f0145ab246bccb9e452153b5e0515fd094c3678df50b2a0888b8b171", size = 423133236, upload-time = "2025-10-09T08:59:32.536Z" }, +] + +[[package]] +name = "nvidia-cuda-cupti" +version = "13.0.85" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2a/2a/80353b103fc20ce05ef51e928daed4b6015db4aaa9162ed0997090fe2250/nvidia_cuda_cupti-13.0.85-py3-none-manylinux_2_25_aarch64.whl", hash = "sha256:796bd679890ee55fb14a94629b698b6db54bcfd833d391d5e94017dd9d7d3151", size = 10310827, upload-time = "2025-09-04T08:26:42.012Z" }, + { url = "https://files.pythonhosted.org/packages/33/6d/737d164b4837a9bbd202f5ae3078975f0525a55730fe871d8ed4e3b952b0/nvidia_cuda_cupti-13.0.85-py3-none-manylinux_2_25_x86_64.whl", hash = "sha256:4eb01c08e859bf924d222250d2e8f8b8ff6d3db4721288cf35d14252a4d933c8", size = 10715597, upload-time = "2025-09-04T08:26:51.312Z" }, +] + +[[package]] +name = "nvidia-cuda-nvrtc" +version = "13.0.88" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c3/68/483a78f5e8f31b08fb1bb671559968c0ca3a065ac7acabfc7cee55214fd6/nvidia_cuda_nvrtc-13.0.88-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:ad9b6d2ead2435f11cbb6868809d2adeeee302e9bb94bcf0539c7a40d80e8575", size = 90215200, upload-time = "2025-09-04T08:28:44.204Z" }, + { url = "https://files.pythonhosted.org/packages/b7/dc/6bb80850e0b7edd6588d560758f17e0550893a1feaf436807d64d2da040f/nvidia_cuda_nvrtc-13.0.88-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d27f20a0ca67a4bb34268a5e951033496c5b74870b868bacd046b1b8e0c3267b", size = 43015449, upload-time = "2025-09-04T08:28:20.239Z" }, +] + +[[package]] +name = "nvidia-cuda-runtime" +version = "13.0.96" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/87/4f/17d7b9b8e285199c58ce28e31b5c5bbaa4d8271af06a89b6405258245de2/nvidia_cuda_runtime-13.0.96-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ef9bcbe90493a2b9d810e43d249adb3d02e98dd30200d86607d8d02687c43f55", size = 2261060, upload-time = "2025-10-09T08:55:15.78Z" }, + { url = "https://files.pythonhosted.org/packages/2e/24/d1558f3b68b1d26e706813b1d10aa1d785e4698c425af8db8edc3dced472/nvidia_cuda_runtime-13.0.96-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:7f82250d7782aa23b6cfe765ecc7db554bd3c2870c43f3d1821f1d18aebf0548", size = 2243632, upload-time = "2025-10-09T08:55:36.117Z" }, +] + +[[package]] +name = "nvidia-cudnn-cu13" +version = "9.19.0.56" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-cublas", marker = "sys_platform != 'emscripten' and sys_platform != 'win32'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/f1/84/26025437c1e6b61a707442184fa0c03d083b661adf3a3eecfd6d21677740/nvidia_cudnn_cu13-9.19.0.56-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:6ed29ffaee1176c612daf442e4dd6cfeb6a0caa43ddcbeb59da94953030b1be4", size = 433781201, upload-time = "2026-02-03T20:40:53.805Z" }, + { url = "https://files.pythonhosted.org/packages/a3/22/0b4b932655d17a6da1b92fa92ab12844b053bb2ac2475e179ba6f043da1e/nvidia_cudnn_cu13-9.19.0.56-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:d20e1734305e9d68889a96e3f35094d733ff1f83932ebe462753973e53a572bf", size = 366066321, upload-time = "2026-02-03T20:44:52.837Z" }, +] + +[[package]] +name = "nvidia-cufft" +version = "12.0.0.61" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-nvjitlink", marker = "sys_platform != 'emscripten' and sys_platform != 'win32'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/8b/ae/f417a75c0259e85c1d2f83ca4e960289a5f814ed0cea74d18c353d3e989d/nvidia_cufft-12.0.0.61-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:2708c852ef8cd89d1d2068bdbece0aa188813a0c934db3779b9b1faa8442e5f5", size = 214053554, upload-time = "2025-09-04T08:31:38.196Z" }, + { url = "https://files.pythonhosted.org/packages/a8/2f/7b57e29836ea8714f81e9898409196f47d772d5ddedddf1592eadb8ab743/nvidia_cufft-12.0.0.61-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6c44f692dce8fd5ffd3e3df134b6cdb9c2f72d99cf40b62c32dde45eea9ddad3", size = 214085489, upload-time = "2025-09-04T08:31:56.044Z" }, +] + +[[package]] +name = "nvidia-cufile" +version = "1.15.1.6" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3f/70/4f193de89a48b71714e74602ee14d04e4019ad36a5a9f20c425776e72cd6/nvidia_cufile-1.15.1.6-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:08a3ecefae5a01c7f5117351c64f17c7c62efa5fffdbe24fc7d298da19cd0b44", size = 1223672, upload-time = "2025-09-04T08:32:22.779Z" }, + { url = "https://files.pythonhosted.org/packages/ab/73/cc4a14c9813a8a0d509417cf5f4bdaba76e924d58beb9864f5a7baceefbf/nvidia_cufile-1.15.1.6-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:bdc0deedc61f548bddf7733bdc216456c2fdb101d020e1ab4b88d232d5e2f6d1", size = 1136992, upload-time = "2025-09-04T08:32:14.119Z" }, +] + +[[package]] +name = "nvidia-curand" +version = "10.4.0.35" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1e/72/7c2ae24fb6b63a32e6ae5d241cc65263ea18d08802aaae087d9f013335a2/nvidia_curand-10.4.0.35-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:133df5a7509c3e292aaa2b477afd0194f06ce4ea24d714d616ff36439cee349a", size = 61962106, upload-time = "2025-08-04T10:21:41.128Z" }, + { url = "https://files.pythonhosted.org/packages/a5/9f/be0a41ca4a4917abf5cb9ae0daff1a6060cc5de950aec0396de9f3b52bc5/nvidia_curand-10.4.0.35-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:1aee33a5da6e1db083fe2b90082def8915f30f3248d5896bcec36a579d941bfc", size = 59544258, upload-time = "2025-08-04T10:22:03.992Z" }, +] + +[[package]] +name = "nvidia-cusolver" +version = "12.0.4.66" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-cublas", marker = "sys_platform != 'emscripten' and sys_platform != 'win32'" }, + { name = "nvidia-cusparse", marker = "sys_platform != 'emscripten' and sys_platform != 'win32'" }, + { name = "nvidia-nvjitlink", marker = "sys_platform != 'emscripten' and sys_platform != 'win32'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/c8/c3/b30c9e935fc01e3da443ec0116ed1b2a009bb867f5324d3f2d7e533e776b/nvidia_cusolver-12.0.4.66-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:02c2457eaa9e39de20f880f4bd8820e6a1cfb9f9a34f820eb12a155aa5bc92d2", size = 223467760, upload-time = "2025-09-04T08:33:04.222Z" }, + { url = "https://files.pythonhosted.org/packages/5f/67/cba3777620cdacb99102da4042883709c41c709f4b6323c10781a9c3aa34/nvidia_cusolver-12.0.4.66-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:0a759da5dea5c0ea10fd307de75cdeb59e7ea4fcb8add0924859b944babf1112", size = 200941980, upload-time = "2025-09-04T08:33:22.767Z" }, +] + +[[package]] +name = "nvidia-cusparse" +version = "12.6.3.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-nvjitlink", marker = "sys_platform != 'emscripten' and sys_platform != 'win32'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/f8/94/5c26f33738ae35276672f12615a64bd008ed5be6d1ebcb23579285d960a9/nvidia_cusparse-12.6.3.3-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:80bcc4662f23f1054ee334a15c72b8940402975e0eab63178fc7e670aa59472c", size = 162155568, upload-time = "2025-09-04T08:33:42.864Z" }, + { url = "https://files.pythonhosted.org/packages/fa/18/623c77619c31d62efd55302939756966f3ecc8d724a14dab2b75f1508850/nvidia_cusparse-12.6.3.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:2b3c89c88d01ee0e477cb7f82ef60a11a4bcd57b6b87c33f789350b59759360b", size = 145942937, upload-time = "2025-09-04T08:33:58.029Z" }, +] + +[[package]] +name = "nvidia-cusparselt-cu13" +version = "0.8.0" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/46/10/8dcd1175260706a2fc92a16a52e306b71d4c1ea0b0cc4a9484183399818a/nvidia_cusparselt_cu13-0.8.0-py3-none-manylinux2014_aarch64.whl", hash = "sha256:400c6ed1cf6780fc6efedd64ec9f1345871767e6a1a0a552a1ea0578117ea77c", size = 220791277, upload-time = "2025-08-13T19:22:40.982Z" }, + { url = "https://files.pythonhosted.org/packages/fd/53/43b0d71f4e702fa9733f8b4571fdca50a8813f1e450b656c239beff12315/nvidia_cusparselt_cu13-0.8.0-py3-none-manylinux2014_x86_64.whl", hash = "sha256:25e30a8a7323935d4ad0340b95a0b69926eee755767e8e0b1cf8dd85b197d3fd", size = 169884119, upload-time = "2025-08-13T19:23:41.967Z" }, +] + +[[package]] +name = "nvidia-nccl-cu13" +version = "2.28.9" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/39/55/1920646a2e43ffd4fc958536b276197ed740e9e0c54105b4bb3521591fc7/nvidia_nccl_cu13-2.28.9-py3-none-manylinux_2_18_aarch64.whl", hash = "sha256:01c873ba1626b54caa12272ed228dc5b2781545e0ae8ba3f432a8ef1c6d78643", size = 196561677, upload-time = "2025-11-18T05:49:03.45Z" }, + { url = "https://files.pythonhosted.org/packages/b0/b4/878fefaad5b2bcc6fcf8d474a25e3e3774bc5133e4b58adff4d0bca238bc/nvidia_nccl_cu13-2.28.9-py3-none-manylinux_2_18_x86_64.whl", hash = "sha256:e4553a30f34195f3fa1da02a6da3d6337d28f2003943aa0a3d247bbc25fefc42", size = 196493177, upload-time = "2025-11-18T05:49:17.677Z" }, +] + +[[package]] +name = "nvidia-nvjitlink" +version = "13.0.88" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/56/7a/123e033aaff487c77107195fa5a2b8686795ca537935a24efae476c41f05/nvidia_nvjitlink-13.0.88-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:13a74f429e23b921c1109976abefacc69835f2f433ebd323d3946e11d804e47b", size = 40713933, upload-time = "2025-09-04T08:35:43.553Z" }, + { url = "https://files.pythonhosted.org/packages/ab/2c/93c5250e64df4f894f1cbb397c6fd71f79813f9fd79d7cd61de3f97b3c2d/nvidia_nvjitlink-13.0.88-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:e931536ccc7d467a98ba1d8b89ff7fa7f1fa3b13f2b0069118cd7f47bff07d0c", size = 38768748, upload-time = "2025-09-04T08:35:20.008Z" }, +] + +[[package]] +name = "nvidia-nvshmem-cu13" +version = "3.4.5" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/dc/0f/05cc9c720236dcd2db9c1ab97fff629e96821be2e63103569da0c9b72f19/nvidia_nvshmem_cu13-3.4.5-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:6dc2a197f38e5d0376ad52cd1a2a3617d3cdc150fd5966f4aee9bcebb1d68fe9", size = 60215947, upload-time = "2025-09-06T00:32:20.022Z" }, + { url = "https://files.pythonhosted.org/packages/3c/35/a9bf80a609e74e3b000fef598933235c908fcefcef9026042b8e6dfde2a9/nvidia_nvshmem_cu13-3.4.5-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:290f0a2ee94c9f3687a02502f3b9299a9f9fe826e6d0287ee18482e78d495b80", size = 60412546, upload-time = "2025-09-06T00:32:41.564Z" }, +] + +[[package]] +name = "nvidia-nvtx" +version = "13.0.85" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c2/f3/d86c845465a2723ad7e1e5c36dcd75ddb82898b3f53be47ebd429fb2fa5d/nvidia_nvtx-13.0.85-py3-none-manylinux1_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:4936d1d6780fbe68db454f5e72a42ff64d1fd6397df9f363ae786930fd5c1cd4", size = 148047, upload-time = "2025-09-04T08:29:01.761Z" }, + { url = "https://files.pythonhosted.org/packages/a8/64/3708a90d1ebe202ffdeb7185f878a3c84d15c2b2c31858da2ce0583e2def/nvidia_nvtx-13.0.85-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cb7780edb6b14107373c835bf8b72e7a178bac7367e23da7acb108f973f157a6", size = 148878, upload-time = "2025-09-04T08:28:53.627Z" }, +] + [[package]] name = "oauthlib" version = "3.3.1" @@ -3793,6 +4029,34 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, ] +[[package]] +name = "polars" +version = "1.40.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "polars-runtime-32" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b3/8c/bc9bc948058348ed43117cecc3007cd608f395915dae8a00974579a5dab1/polars-1.40.1.tar.gz", hash = "sha256:ab2694134b137596b5a59bfd7b4c54ebbc9b59f9403127f18e32d363777552e8", size = 733574, upload-time = "2026-04-22T19:15:55.507Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ea/91/74fc60d94488685a92ac9d49d7ec55f3e91fe9b77942a6235a5fa7f249c3/polars-1.40.1-py3-none-any.whl", hash = "sha256:c0f861219d1319cdea45c4ce4d30355a47176b8f98dcedf95ea8269f131b8abd", size = 828723, upload-time = "2026-04-22T19:14:25.452Z" }, +] + +[[package]] +name = "polars-runtime-32" +version = "1.40.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/54/ba/26d40f039be9f552b5fd7365a621bdfc0f8e912ef77094ae4693491b0bae/polars_runtime_32-1.40.1.tar.gz", hash = "sha256:37f3065615d1bf90d03b5326222df4c5c1f8a5d33e50470aa588e3465e6eb814", size = 2935843, upload-time = "2026-04-22T19:15:57.26Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7d/46/22c8af5eed68ac2eeb556e0fa3ca8a7b798e984ceff4450888f3b5ac61fd/polars_runtime_32-1.40.1-cp310-abi3-macosx_10_12_x86_64.whl", hash = "sha256:b748ef652270cc49e9e69f99a035e0eb4d5f856d42bcd6ac4d9d80a40142aa1e", size = 52098755, upload-time = "2026-04-22T19:14:28.555Z" }, + { url = "https://files.pythonhosted.org/packages/c6/3e/48599a38009ca60ff82a6f38c8a621ce3c0286aa7397c7d79e741bd9060e/polars_runtime_32-1.40.1-cp310-abi3-macosx_11_0_arm64.whl", hash = "sha256:d249b3743e05986060cec0a7aaa542d020df6c6b876e556023a310efd581f9be", size = 46367542, upload-time = "2026-04-22T19:14:32.433Z" }, + { url = "https://files.pythonhosted.org/packages/43/e9/384bc069367a1a36ee31c13782c178dbd039b2b873b772d4a0fc23a2373d/polars_runtime_32-1.40.1-cp310-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5987b30e7aa1059d069498496e8dda35afd592b0ac3d46ed87e3ff8df1ad652c", size = 50252104, upload-time = "2026-04-22T19:14:35.945Z" }, + { url = "https://files.pythonhosted.org/packages/15/ef/7d57ceb0651af74194e97ed6583e148d352f03d696090221b8059cdfc90b/polars_runtime_32-1.40.1-cp310-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d7f42a8b3f16fc66002cc0f6516f7dd7653396886ae0ed362ab95c0b3408b59", size = 56250788, upload-time = "2026-04-22T19:14:39.743Z" }, + { url = "https://files.pythonhosted.org/packages/10/0f/e4b3ffc748827a14a474ec9c42e45c066050e440fec57e914091d9adda75/polars_runtime_32-1.40.1-cp310-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:e5f7becc237a7ec9d9a10878dc8e54b73bbf4e2d94a2991c37d7a0b38590d8f9", size = 50432590, upload-time = "2026-04-22T19:14:43.388Z" }, + { url = "https://files.pythonhosted.org/packages/d9/0b/b8d95fbed869fa4caabe9c400e4210374913b376e925e96fdcfa9be6416b/polars_runtime_32-1.40.1-cp310-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:992d14cf191dde043d36fbdbc98a65e43fbc7e9a5024cecd45f838ac4988c1ee", size = 54155564, upload-time = "2026-04-22T19:14:47.239Z" }, + { url = "https://files.pythonhosted.org/packages/06/d9/d091d8fb5cbed5e9536adfed955c4c89987a4cc3b8e73ae4532402b91c74/polars_runtime_32-1.40.1-cp310-abi3-win_amd64.whl", hash = "sha256:f78bb2abd00101cbb23cc0cb068f7e36e081057a15d2ec2dde3dda280709f030", size = 51829755, upload-time = "2026-04-22T19:14:50.85Z" }, + { url = "https://files.pythonhosted.org/packages/65/ad/b33c3022a394f3eb55c3310597cec615412a8a33880055eee191d154a628/polars_runtime_32-1.40.1-cp310-abi3-win_arm64.whl", hash = "sha256:b5cbfaf6b085b420b4bfcbe24e8f665076d1cccfdb80c0484c02a023ce205537", size = 45822104, upload-time = "2026-04-22T19:14:54.192Z" }, +] + [[package]] name = "pre-commit" version = "4.5.1" @@ -4817,6 +5081,7 @@ dependencies = [ { name = "redis" }, { name = "requests" }, { name = "rich" }, + { name = "sentence-transformers" }, { name = "slowapi" }, { name = "tenacity" }, { name = "tiktoken" }, @@ -4857,6 +5122,7 @@ dev = [ { name = "matplotlib" }, { name = "moto" }, { name = "mutmut" }, + { name = "polars" }, { name = "pre-commit" }, { name = "pyright" }, { name = "pytest" }, @@ -4914,6 +5180,7 @@ requires-dist = [ { name = "reportlab", marker = "extra == 'benchmark'", specifier = ">=4.4.10" }, { name = "requests", specifier = ">=2.25.0" }, { name = "rich", specifier = ">=13.0.0" }, + { name = "sentence-transformers", specifier = ">=3.0" }, { name = "slowapi", specifier = ">=0.1.9" }, { name = "sqlite-vec", marker = "extra == 'vec'", specifier = ">=0.1.6" }, { name = "tenacity", specifier = ">=9.0.0" }, @@ -4934,6 +5201,7 @@ dev = [ { name = "matplotlib", specifier = ">=3.10.8" }, { name = "moto", specifier = ">=5.0.28" }, { name = "mutmut", specifier = ">=3.2.0" }, + { name = "polars", specifier = ">=1.40.1" }, { name = "pre-commit", specifier = ">=4.0.1" }, { name = "pyright", specifier = ">=1.1.400" }, { name = "pytest", specifier = ">=8.3.4" }, @@ -5294,6 +5562,28 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fc/51/727abb13f44c1fcf6d145979e1535a35794db0f6e450a0cb46aa24732fe2/s3transfer-0.16.0-py3-none-any.whl", hash = "sha256:18e25d66fed509e3868dc1572b3f427ff947dd2c56f844a5bf09481ad3f3b2fe", size = 86830, upload-time = "2025-12-01T02:30:57.729Z" }, ] +[[package]] +name = "safetensors" +version = "0.7.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/29/9c/6e74567782559a63bd040a236edca26fd71bc7ba88de2ef35d75df3bca5e/safetensors-0.7.0.tar.gz", hash = "sha256:07663963b67e8bd9f0b8ad15bb9163606cd27cc5a1b96235a50d8369803b96b0", size = 200878, upload-time = "2025-11-19T15:18:43.199Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fa/47/aef6c06649039accf914afef490268e1067ed82be62bcfa5b7e886ad15e8/safetensors-0.7.0-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:c82f4d474cf725255d9e6acf17252991c3c8aac038d6ef363a4bf8be2f6db517", size = 467781, upload-time = "2025-11-19T15:18:35.84Z" }, + { url = "https://files.pythonhosted.org/packages/e8/00/374c0c068e30cd31f1e1b46b4b5738168ec79e7689ca82ee93ddfea05109/safetensors-0.7.0-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:94fd4858284736bb67a897a41608b5b0c2496c9bdb3bf2af1fa3409127f20d57", size = 447058, upload-time = "2025-11-19T15:18:34.416Z" }, + { url = "https://files.pythonhosted.org/packages/f1/06/578ffed52c2296f93d7fd2d844cabfa92be51a587c38c8afbb8ae449ca89/safetensors-0.7.0-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e07d91d0c92a31200f25351f4acb2bc6aff7f48094e13ebb1d0fb995b54b6542", size = 491748, upload-time = "2025-11-19T15:18:09.79Z" }, + { url = "https://files.pythonhosted.org/packages/ae/33/1debbbb70e4791dde185edb9413d1fe01619255abb64b300157d7f15dddd/safetensors-0.7.0-cp38-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8469155f4cb518bafb4acf4865e8bb9d6804110d2d9bdcaa78564b9fd841e104", size = 503881, upload-time = "2025-11-19T15:18:16.145Z" }, + { url = "https://files.pythonhosted.org/packages/8e/1c/40c2ca924d60792c3be509833df711b553c60effbd91da6f5284a83f7122/safetensors-0.7.0-cp38-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:54bef08bf00a2bff599982f6b08e8770e09cc012d7bba00783fc7ea38f1fb37d", size = 623463, upload-time = "2025-11-19T15:18:21.11Z" }, + { url = "https://files.pythonhosted.org/packages/9b/3a/13784a9364bd43b0d61eef4bea2845039bc2030458b16594a1bd787ae26e/safetensors-0.7.0-cp38-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:42cb091236206bb2016d245c377ed383aa7f78691748f3bb6ee1bfa51ae2ce6a", size = 532855, upload-time = "2025-11-19T15:18:25.719Z" }, + { url = "https://files.pythonhosted.org/packages/a0/60/429e9b1cb3fc651937727befe258ea24122d9663e4d5709a48c9cbfceecb/safetensors-0.7.0-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dac7252938f0696ddea46f5e855dd3138444e82236e3be475f54929f0c510d48", size = 507152, upload-time = "2025-11-19T15:18:33.023Z" }, + { url = "https://files.pythonhosted.org/packages/3c/a8/4b45e4e059270d17af60359713ffd83f97900d45a6afa73aaa0d737d48b6/safetensors-0.7.0-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:1d060c70284127fa805085d8f10fbd0962792aed71879d00864acda69dbab981", size = 541856, upload-time = "2025-11-19T15:18:31.075Z" }, + { url = "https://files.pythonhosted.org/packages/06/87/d26d8407c44175d8ae164a95b5a62707fcc445f3c0c56108e37d98070a3d/safetensors-0.7.0-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:cdab83a366799fa730f90a4ebb563e494f28e9e92c4819e556152ad55e43591b", size = 674060, upload-time = "2025-11-19T15:18:37.211Z" }, + { url = "https://files.pythonhosted.org/packages/11/f5/57644a2ff08dc6325816ba7217e5095f17269dada2554b658442c66aed51/safetensors-0.7.0-cp38-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:672132907fcad9f2aedcb705b2d7b3b93354a2aec1b2f706c4db852abe338f85", size = 771715, upload-time = "2025-11-19T15:18:38.689Z" }, + { url = "https://files.pythonhosted.org/packages/86/31/17883e13a814bd278ae6e266b13282a01049b0c81341da7fd0e3e71a80a3/safetensors-0.7.0-cp38-abi3-musllinux_1_2_i686.whl", hash = "sha256:5d72abdb8a4d56d4020713724ba81dac065fedb7f3667151c4a637f1d3fb26c0", size = 714377, upload-time = "2025-11-19T15:18:40.162Z" }, + { url = "https://files.pythonhosted.org/packages/4a/d8/0c8a7dc9b41dcac53c4cbf9df2b9c83e0e0097203de8b37a712b345c0be5/safetensors-0.7.0-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:b0f6d66c1c538d5a94a73aa9ddca8ccc4227e6c9ff555322ea40bdd142391dd4", size = 677368, upload-time = "2025-11-19T15:18:41.627Z" }, + { url = "https://files.pythonhosted.org/packages/05/e5/cb4b713c8a93469e3c5be7c3f8d77d307e65fe89673e731f5c2bfd0a9237/safetensors-0.7.0-cp38-abi3-win32.whl", hash = "sha256:c74af94bf3ac15ac4d0f2a7c7b4663a15f8c2ab15ed0fc7531ca61d0835eccba", size = 326423, upload-time = "2025-11-19T15:18:45.74Z" }, + { url = "https://files.pythonhosted.org/packages/5d/e6/ec8471c8072382cb91233ba7267fd931219753bb43814cbc71757bfd4dab/safetensors-0.7.0-cp38-abi3-win_amd64.whl", hash = "sha256:d1239932053f56f3456f32eb9625590cc7582e905021f94636202a864d470755", size = 341380, upload-time = "2025-11-19T15:18:44.427Z" }, +] + [[package]] name = "scikit-learn" version = "1.8.0" @@ -5421,6 +5711,25 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1c/78/504fdd027da3b84ff1aecd9f6957e65f35134534ccc6da8628eb71e76d3f/send2trash-2.1.0-py3-none-any.whl", hash = "sha256:0da2f112e6d6bb22de6aa6daa7e144831a4febf2a87261451c4ad849fe9a873c", size = 17610, upload-time = "2026-01-14T06:27:35.218Z" }, ] +[[package]] +name = "sentence-transformers" +version = "5.4.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "huggingface-hub" }, + { name = "numpy" }, + { name = "scikit-learn" }, + { name = "scipy" }, + { name = "torch" }, + { name = "tqdm" }, + { name = "transformers" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/4d/68/7f98c221940ce783b492ad6140384daf2e2918cd7175009d6a362c22b9ee/sentence_transformers-5.4.1.tar.gz", hash = "sha256:436bcb1182a0ff42a8fb2b1c43498a70d0a75b688d182f2cd0d1dd115af61ddc", size = 428910, upload-time = "2026-04-14T13:34:59.006Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c5/d9/3a9b6f2ccdedc9dc00fe37b2fc58f58f8efbff44565cf4bf39d8568bb13a/sentence_transformers-5.4.1-py3-none-any.whl", hash = "sha256:a6d640fc363849b63affb8e140e9d328feabab86f83d58ac3e16b1c28140b790", size = 571311, upload-time = "2026-04-14T13:34:57.731Z" }, +] + [[package]] name = "setproctitle" version = "1.3.7" @@ -5481,11 +5790,11 @@ wheels = [ [[package]] name = "setuptools" -version = "82.0.1" +version = "81.0.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/4f/db/cfac1baf10650ab4d1c111714410d2fbb77ac5a616db26775db562c8fab2/setuptools-82.0.1.tar.gz", hash = "sha256:7d872682c5d01cfde07da7bccc7b65469d3dca203318515ada1de5eda35efbf9", size = 1152316, upload-time = "2026-03-09T12:47:17.221Z" } +sdist = { url = "https://files.pythonhosted.org/packages/0d/1c/73e719955c59b8e424d015ab450f51c0af856ae46ea2da83eba51cc88de1/setuptools-81.0.0.tar.gz", hash = "sha256:487b53915f52501f0a79ccfd0c02c165ffe06631443a886740b91af4b7a5845a", size = 1198299, upload-time = "2026-02-06T21:10:39.601Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/9d/76/f789f7a86709c6b087c5a2f52f911838cad707cc613162401badc665acfe/setuptools-82.0.1-py3-none-any.whl", hash = "sha256:a59e362652f08dcd477c78bb6e7bd9d80a7995bc73ce773050228a348ce2e5bb", size = 1006223, upload-time = "2026-03-09T12:47:15.026Z" }, + { url = "https://files.pythonhosted.org/packages/e1/e3/c164c88b2e5ce7b24d667b9bd83589cf4f3520d97cad01534cd3c4f55fdb/setuptools-81.0.0-py3-none-any.whl", hash = "sha256:fdd925d5c5d9f62e4b74b30d6dd7828ce236fd6ed998a08d81de62ce5a6310d6", size = 1062021, upload-time = "2026-02-06T21:10:37.175Z" }, ] [[package]] @@ -5592,6 +5901,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0b/c9/584bc9651441b4ba60cc4d557d8a547b5aff901af35bda3a4ee30c819b82/starlette-1.0.0-py3-none-any.whl", hash = "sha256:d3ec55e0bb321692d275455ddfd3df75fff145d009685eb40dc91fc66b03d38b", size = 72651, upload-time = "2026-03-22T18:29:45.111Z" }, ] +[[package]] +name = "sympy" +version = "1.14.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mpmath" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/83/d3/803453b36afefb7c2bb238361cd4ae6125a569b4db67cd9e79846ba2d68c/sympy-1.14.0.tar.gz", hash = "sha256:d3d3fe8df1e5a0b42f0e7bdf50541697dbe7d23746e894990c030e2b05e72517", size = 7793921, upload-time = "2025-04-27T18:05:01.611Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a2/09/77d55d46fd61b4a135c444fc97158ef34a095e5681d0a6c10b75bf356191/sympy-1.14.0-py3-none-any.whl", hash = "sha256:e091cc3e99d2141a0ba2847328f5479b05d94a6635cb96148ccb3f34671bd8f5", size = 6299353, upload-time = "2025-04-27T18:04:59.103Z" }, +] + [[package]] name = "syrupy" version = "5.1.0" @@ -5765,6 +6086,49 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/bd/75/8539d011f6be8e29f339c42e633aae3cb73bffa95dd0f9adec09b9c58e85/tomlkit-0.13.3-py3-none-any.whl", hash = "sha256:c89c649d79ee40629a9fda55f8ace8c6a1b42deb912b2a8fd8d942ddadb606b0", size = 38901, upload-time = "2025-06-05T07:13:43.546Z" }, ] +[[package]] +name = "torch" +version = "2.11.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cuda-bindings", marker = "sys_platform == 'linux'" }, + { name = "cuda-toolkit", extra = ["cublas", "cudart", "cufft", "cufile", "cupti", "curand", "cusolver", "cusparse", "nvjitlink", "nvrtc", "nvtx"], marker = "sys_platform == 'linux'" }, + { name = "filelock" }, + { name = "fsspec" }, + { name = "jinja2" }, + { name = "networkx" }, + { name = "nvidia-cudnn-cu13", marker = "sys_platform == 'linux'" }, + { name = "nvidia-cusparselt-cu13", marker = "sys_platform == 'linux'" }, + { name = "nvidia-nccl-cu13", marker = "sys_platform == 'linux'" }, + { name = "nvidia-nvshmem-cu13", marker = "sys_platform == 'linux'" }, + { name = "setuptools" }, + { name = "sympy" }, + { name = "triton", marker = "sys_platform == 'linux'" }, + { name = "typing-extensions" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/6f/8b/69e3008d78e5cee2b30183340cc425081b78afc5eff3d080daab0adda9aa/torch-2.11.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4b5866312ee6e52ea625cd211dcb97d6a2cdc1131a5f15cc0d87eec948f6dd34", size = 80606338, upload-time = "2026-03-23T18:11:34.781Z" }, + { url = "https://files.pythonhosted.org/packages/13/16/42e5915ebe4868caa6bac83a8ed59db57f12e9a61b7d749d584776ed53d5/torch-2.11.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:f99924682ef0aa6a4ab3b1b76f40dc6e273fca09f367d15a524266db100a723f", size = 419731115, upload-time = "2026-03-23T18:11:06.944Z" }, + { url = "https://files.pythonhosted.org/packages/1a/c9/82638ef24d7877510f83baf821f5619a61b45568ce21c0a87a91576510aa/torch-2.11.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:0f68f4ac6d95d12e896c3b7a912b5871619542ec54d3649cf48cc1edd4dd2756", size = 530712279, upload-time = "2026-03-23T18:10:31.481Z" }, + { url = "https://files.pythonhosted.org/packages/1c/ff/6756f1c7ee302f6d202120e0f4f05b432b839908f9071157302cedfc5232/torch-2.11.0-cp312-cp312-win_amd64.whl", hash = "sha256:fbf39280699d1b869f55eac536deceaa1b60bd6788ba74f399cc67e60a5fab10", size = 114556047, upload-time = "2026-03-23T18:10:55.931Z" }, + { url = "https://files.pythonhosted.org/packages/87/89/5ea6722763acee56b045435fb84258db7375c48165ec8be7880ab2b281c5/torch-2.11.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:1e6debd97ccd3205bbb37eb806a9d8219e1139d15419982c09e23ef7d4369d18", size = 80606801, upload-time = "2026-03-23T18:10:18.649Z" }, + { url = "https://files.pythonhosted.org/packages/32/d1/8ed2173589cbfe744ed54e5a73efc107c0085ba5777ee93a5f4c1ab90553/torch-2.11.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:63a68fa59de8f87acc7e85a5478bb2dddbb3392b7593ec3e78827c793c4b73fd", size = 419732382, upload-time = "2026-03-23T18:08:30.835Z" }, + { url = "https://files.pythonhosted.org/packages/3d/e1/b73f7c575a4b8f87a5928f50a1e35416b5e27295d8be9397d5293e7e8d4c/torch-2.11.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:cc89b9b173d9adfab59fd227f0ab5e5516d9a52b658ae41d64e59d2e55a418db", size = 530711509, upload-time = "2026-03-23T18:08:47.213Z" }, + { url = "https://files.pythonhosted.org/packages/66/82/3e3fcdd388fbe54e29fd3f991f36846ff4ac90b0d0181e9c8f7236565f82/torch-2.11.0-cp313-cp313-win_amd64.whl", hash = "sha256:4dda3b3f52d121063a731ddb835f010dc137b920d7fec2778e52f60d8e4bf0cd", size = 114555842, upload-time = "2026-03-23T18:09:52.111Z" }, + { url = "https://files.pythonhosted.org/packages/db/38/8ac78069621b8c2b4979c2f96dc8409ef5e9c4189f6aac629189a78677ca/torch-2.11.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:8b394322f49af4362d4f80e424bcaca7efcd049619af03a4cf4501520bdf0fb4", size = 80959574, upload-time = "2026-03-23T18:10:14.214Z" }, + { url = "https://files.pythonhosted.org/packages/6d/6c/56bfb37073e7136e6dd86bfc6af7339946dd684e0ecf2155ac0eee687ae1/torch-2.11.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:2658f34ce7e2dabf4ec73b45e2ca68aedad7a5be87ea756ad656eaf32bf1e1ea", size = 419732324, upload-time = "2026-03-23T18:09:36.604Z" }, + { url = "https://files.pythonhosted.org/packages/07/f4/1b666b6d61d3394cca306ea543ed03a64aad0a201b6cd159f1d41010aeb1/torch-2.11.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:98bb213c3084cfe176302949bdc360074b18a9da7ab59ef2edc9d9f742504778", size = 530596026, upload-time = "2026-03-23T18:09:20.842Z" }, + { url = "https://files.pythonhosted.org/packages/48/6b/30d1459fa7e4b67e9e3fe1685ca1d8bb4ce7c62ef436c3a615963c6c866c/torch-2.11.0-cp313-cp313t-win_amd64.whl", hash = "sha256:a97b94bbf62992949b4730c6cd2cc9aee7b335921ee8dc207d930f2ed09ae2db", size = 114793702, upload-time = "2026-03-23T18:09:47.304Z" }, + { url = "https://files.pythonhosted.org/packages/26/0d/8603382f61abd0db35841148ddc1ffd607bf3100b11c6e1dab6d2fc44e72/torch-2.11.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:01018087326984a33b64e04c8cb5c2795f9120e0d775ada1f6638840227b04d7", size = 80573442, upload-time = "2026-03-23T18:09:10.117Z" }, + { url = "https://files.pythonhosted.org/packages/c7/86/7cd7c66cb9cec6be330fff36db5bd0eef386d80c031b581ec81be1d4b26c/torch-2.11.0-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:2bb3cc54bd0dea126b0060bb1ec9de0f9c7f7342d93d436646516b0330cd5be7", size = 419749385, upload-time = "2026-03-23T18:07:33.77Z" }, + { url = "https://files.pythonhosted.org/packages/47/e8/b98ca2d39b2e0e4730c0ee52537e488e7008025bc77ca89552ff91021f7c/torch-2.11.0-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:4dc8b3809469b6c30b411bb8c4cad3828efd26236153d9beb6a3ec500f211a60", size = 530716756, upload-time = "2026-03-23T18:07:50.02Z" }, + { url = "https://files.pythonhosted.org/packages/78/88/d4a4cda8362f8a30d1ed428564878c3cafb0d87971fbd3947d4c84552095/torch-2.11.0-cp314-cp314-win_amd64.whl", hash = "sha256:2b4e811728bd0cc58fb2b0948fe939a1ee2bf1422f6025be2fca4c7bd9d79718", size = 114552300, upload-time = "2026-03-23T18:09:05.617Z" }, + { url = "https://files.pythonhosted.org/packages/bf/46/4419098ed6d801750f26567b478fc185c3432e11e2cad712bc6b4c2ab0d0/torch-2.11.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:8245477871c3700d4370352ffec94b103cfcb737229445cf9946cddb7b2ca7cd", size = 80959460, upload-time = "2026-03-23T18:09:00.818Z" }, + { url = "https://files.pythonhosted.org/packages/fd/66/54a56a4a6ceaffb567231994a9745821d3af922a854ed33b0b3a278e0a99/torch-2.11.0-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:ab9a8482f475f9ba20e12db84b0e55e2f58784bdca43a854a6ccd3fd4b9f75e6", size = 419735835, upload-time = "2026-03-23T18:07:18.974Z" }, + { url = "https://files.pythonhosted.org/packages/b1/e7/0b6665f533aa9e337662dc190425abc0af1fe3234088f4454c52393ded61/torch-2.11.0-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:563ed3d25542d7e7bbc5b235ccfacfeb97fb470c7fee257eae599adb8005c8a2", size = 530613405, upload-time = "2026-03-23T18:08:07.014Z" }, + { url = "https://files.pythonhosted.org/packages/cf/bf/c8d12a2c86dbfd7f40fb2f56fbf5a505ccf2d9ce131eb559dfc7c51e1a04/torch-2.11.0-cp314-cp314t-win_amd64.whl", hash = "sha256:b2a43985ff5ef6ddd923bbcf99943e5f58059805787c5c9a2622bf05ca2965b0", size = 114792991, upload-time = "2026-03-23T18:08:19.216Z" }, +] + [[package]] name = "tornado" version = "6.5.5" @@ -5803,6 +6167,43 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/00/c0/8f5d070730d7836adc9c9b6408dec68c6ced86b304a9b26a14df072a6e8c/traitlets-5.14.3-py3-none-any.whl", hash = "sha256:b74e89e397b1ed28cc831db7aea759ba6640cb3de13090ca145426688ff1ac4f", size = 85359, upload-time = "2024-04-19T11:11:46.763Z" }, ] +[[package]] +name = "transformers" +version = "5.6.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "huggingface-hub" }, + { name = "numpy" }, + { name = "packaging" }, + { name = "pyyaml" }, + { name = "regex" }, + { name = "safetensors" }, + { name = "tokenizers" }, + { name = "tqdm" }, + { name = "typer" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a4/e9/c6c80a07690142a7d05444271f47b9f3c8aac7dea01d52e1137ee480ad78/transformers-5.6.2.tar.gz", hash = "sha256:e657134c3e5a6bc00a3c35f4e2674bb51adfcd89898495b788a18552bac2b91a", size = 8311867, upload-time = "2026-04-23T18:33:29.332Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5d/95/0b0218149b0d6f14df35f5b8f676fa83df4f19ed253c3cc447107ef86eca/transformers-5.6.2-py3-none-any.whl", hash = "sha256:f8d3a1bb96778fed9b8aabfd0dd6e19843e4b0f2bb6b59f32b8a92051b0f348f", size = 10364898, upload-time = "2026-04-23T18:33:26.081Z" }, +] + +[[package]] +name = "triton" +version = "3.6.0" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/17/5d/08201db32823bdf77a0e2b9039540080b2e5c23a20706ddba942924ebcd6/triton-3.6.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:374f52c11a711fd062b4bfbb201fd9ac0a5febd28a96fb41b4a0f51dde3157f4", size = 176128243, upload-time = "2026-01-20T16:16:07.857Z" }, + { url = "https://files.pythonhosted.org/packages/ab/a8/cdf8b3e4c98132f965f88c2313a4b493266832ad47fb52f23d14d4f86bb5/triton-3.6.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:74caf5e34b66d9f3a429af689c1c7128daba1d8208df60e81106b115c00d6fca", size = 188266850, upload-time = "2026-01-20T16:00:43.041Z" }, + { url = "https://files.pythonhosted.org/packages/3c/12/34d71b350e89a204c2c7777a9bba0dcf2f19a5bfdd70b57c4dbc5ffd7154/triton-3.6.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:448e02fe6dc898e9e5aa89cf0ee5c371e99df5aa5e8ad976a80b93334f3494fd", size = 176133521, upload-time = "2026-01-20T16:16:13.321Z" }, + { url = "https://files.pythonhosted.org/packages/f9/0b/37d991d8c130ce81a8728ae3c25b6e60935838e9be1b58791f5997b24a54/triton-3.6.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:10c7f76c6e72d2ef08df639e3d0d30729112f47a56b0c81672edc05ee5116ac9", size = 188289450, upload-time = "2026-01-20T16:00:49.136Z" }, + { url = "https://files.pythonhosted.org/packages/ce/4e/41b0c8033b503fd3cfcd12392cdd256945026a91ff02452bef40ec34bee7/triton-3.6.0-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1722e172d34e32abc3eb7711d0025bb69d7959ebea84e3b7f7a341cd7ed694d6", size = 176276087, upload-time = "2026-01-20T16:16:18.989Z" }, + { url = "https://files.pythonhosted.org/packages/35/f8/9c66bfc55361ec6d0e4040a0337fb5924ceb23de4648b8a81ae9d33b2b38/triton-3.6.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d002e07d7180fd65e622134fbd980c9a3d4211fb85224b56a0a0efbd422ab72f", size = 188400296, upload-time = "2026-01-20T16:00:56.042Z" }, + { url = "https://files.pythonhosted.org/packages/49/55/5ecf0dcaa0f2fbbd4420f7ef227ee3cb172e91e5fede9d0ecaddc43363b4/triton-3.6.0-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ef5523241e7d1abca00f1d240949eebdd7c673b005edbbce0aca95b8191f1d43", size = 176138577, upload-time = "2026-01-20T16:16:25.426Z" }, + { url = "https://files.pythonhosted.org/packages/df/3d/9e7eee57b37c80cec63322c0231bb6da3cfe535a91d7a4d64896fcb89357/triton-3.6.0-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a17a5d5985f0ac494ed8a8e54568f092f7057ef60e1b0fa09d3fd1512064e803", size = 188273063, upload-time = "2026-01-20T16:01:07.278Z" }, + { url = "https://files.pythonhosted.org/packages/48/db/56ee649cab5eaff4757541325aca81f52d02d4a7cd3506776cad2451e060/triton-3.6.0-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0b3a97e8ed304dfa9bd23bb41ca04cdf6b2e617d5e782a8653d616037a5d537d", size = 176274804, upload-time = "2026-01-20T16:16:31.528Z" }, + { url = "https://files.pythonhosted.org/packages/f6/56/6113c23ff46c00aae423333eb58b3e60bdfe9179d542781955a5e1514cb3/triton-3.6.0-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:46bd1c1af4b6704e554cad2eeb3b0a6513a980d470ccfa63189737340c7746a7", size = 188397994, upload-time = "2026-01-20T16:01:14.236Z" }, +] + [[package]] name = "twine" version = "6.2.0"