From 38aeb73a48d4e38035bbdfdfd09860feb1faaab2 Mon Sep 17 00:00:00 2001 From: Sam Date: Fri, 13 Mar 2026 23:24:13 +0100 Subject: [PATCH 001/135] Sprint 5A: task workspace records and provisioning (#1) Co-authored-by: Sami Rusani --- .env.example | 12 + .gitignore | 9 + ARCHITECTURE.md | 207 ++ CHANGELOG.md | 27 + PRODUCT_BRIEF.md | 77 + README.md | 74 +- ROADMAP.md | 97 + RULES.md | 52 + apps/api/.gitkeep | 1 + apps/api/alembic.ini | 37 + apps/api/alembic/env.py | 71 + .../20260310_0001_foundation_continuity.py | 167 + ...0260311_0002_tighten_runtime_privileges.py | 39 + .../versions/20260311_0003_trace_backbone.py | 117 + .../20260311_0004_memory_admission.py | 123 + .../20260312_0005_memory_review_labels.py | 99 + .../20260312_0006_entities_backbone.py | 72 + .../versions/20260312_0007_entity_edges.py | 83 + .../20260312_0008_embedding_substrate.py | 115 + .../20260312_0009_policy_and_consent_core.py | 111 + ...60312_0010_tools_registry_and_allowlist.py | 96 + .../20260312_0011_approval_request_records.py | 88 + .../20260312_0012_approval_resolution.py | 63 + .../versions/20260313_0013_tool_executions.py | 118 + .../20260313_0014_execution_budgets.py | 71 + ...0260313_0015_execution_budget_lifecycle.py | 80 + ...13_0016_execution_budget_rolling_window.py | 47 + .../20260313_0017_tasks_lifecycle_records.py | 112 + .../versions/20260313_0018_task_steps.py | 93 + .../20260313_0019_task_step_lineage.py | 58 + ...0260313_0020_approval_task_step_linkage.py | 41 + ...3_0021_tool_execution_task_step_linkage.py | 81 + .../versions/20260313_0022_task_workspaces.py | 77 + apps/api/src/alicebot_api/__init__.py | 2 + apps/api/src/alicebot_api/approvals.py | 490 +++ apps/api/src/alicebot_api/compiler.py | 832 +++++ apps/api/src/alicebot_api/config.py | 119 + apps/api/src/alicebot_api/contracts.py | 2080 +++++++++++++ apps/api/src/alicebot_api/db.py | 37 + apps/api/src/alicebot_api/embedding.py | 242 ++ apps/api/src/alicebot_api/entity.py | 117 + apps/api/src/alicebot_api/entity_edge.py | 134 + .../api/src/alicebot_api/execution_budgets.py | 818 +++++ apps/api/src/alicebot_api/executions.py | 68 + .../src/alicebot_api/explicit_preferences.py | 262 ++ apps/api/src/alicebot_api/main.py | 1837 +++++++++++ apps/api/src/alicebot_api/memory.py | 483 +++ apps/api/src/alicebot_api/migrations.py | 17 + apps/api/src/alicebot_api/policy.py | 421 +++ apps/api/src/alicebot_api/proxy_execution.py | 557 ++++ .../src/alicebot_api/response_generation.py | 474 +++ .../src/alicebot_api/semantic_retrieval.py | 107 + apps/api/src/alicebot_api/store.py | 2713 +++++++++++++++++ apps/api/src/alicebot_api/tasks.py | 1170 +++++++ apps/api/src/alicebot_api/tools.py | 553 ++++ apps/api/src/alicebot_api/workspaces.py | 144 + apps/web/.gitkeep | 1 + apps/web/app/layout.tsx | 10 + apps/web/app/page.tsx | 51 + apps/web/next-env.d.ts | 5 + apps/web/next.config.mjs | 6 + apps/web/package.json | 25 + apps/web/tsconfig.json | 19 + docker-compose.yml | 36 + docs/adr/.gitkeep | 1 + docs/archive/.gitkeep | 1 + docs/runbooks/.gitkeep | 1 + infra/postgres/init/001_roles.sql | 16 + pyproject.toml | 32 + scripts/.gitkeep | 1 + scripts/api_dev.sh | 30 + scripts/dev_up.sh | 49 + scripts/migrate.sh | 20 + tests/.gitkeep | 1 + tests/integration/conftest.py | 55 + tests/integration/test_approval_api.py | 929 ++++++ tests/integration/test_context_compile.py | 890 ++++++ tests/integration/test_continuity_store.py | 244 ++ tests/integration/test_embeddings_api.py | 793 +++++ tests/integration/test_entities_api.py | 309 ++ 
tests/integration/test_entity_edges_api.py | 376 +++ .../integration/test_execution_budgets_api.py | 432 +++ .../test_explicit_preferences_api.py | 398 +++ tests/integration/test_healthcheck.py | 175 ++ tests/integration/test_memory_admission.py | 252 ++ tests/integration/test_memory_review_api.py | 526 ++++ .../test_memory_review_labels_api.py | 333 ++ tests/integration/test_migrations.py | 798 +++++ tests/integration/test_policy_api.py | 424 +++ tests/integration/test_proxy_execution_api.py | 1478 +++++++++ tests/integration/test_responses_api.py | 315 ++ tests/integration/test_task_workspaces_api.py | 218 ++ tests/integration/test_tasks_api.py | 946 ++++++ tests/integration/test_tool_api.py | 930 ++++++ ...est_20260310_0001_foundation_continuity.py | 59 + ...0260311_0002_tighten_runtime_privileges.py | 42 + .../unit/test_20260311_0003_trace_backbone.py | 51 + .../test_20260311_0004_memory_admission.py | 51 + ...test_20260312_0005_memory_review_labels.py | 48 + .../test_20260312_0006_entities_backbone.py | 46 + tests/unit/test_20260312_0007_entity_edges.py | 46 + .../test_20260312_0008_embedding_substrate.py | 49 + ...t_20260312_0009_policy_and_consent_core.py | 49 + ...60312_0010_tools_registry_and_allowlist.py | 46 + ..._20260312_0011_approval_request_records.py | 46 + .../test_20260312_0012_approval_resolution.py | 43 + .../test_20260313_0013_tool_executions.py | 46 + .../test_20260313_0014_execution_budgets.py | 46 + ...0260313_0015_execution_budget_lifecycle.py | 38 + ...13_0016_execution_budget_rolling_window.py | 32 + tests/unit/test_20260313_0018_task_steps.py | 46 + .../test_20260313_0019_task_step_lineage.py | 32 + ...0260313_0020_approval_task_step_linkage.py | 32 + ...3_0021_tool_execution_task_step_linkage.py | 32 + .../test_20260313_0022_task_workspaces.py | 46 + tests/unit/test_approval_store.py | 152 + tests/unit/test_approvals.py | 1200 ++++++++ tests/unit/test_approvals_main.py | 376 +++ tests/unit/test_compiler.py | 760 +++++ tests/unit/test_config.py | 91 + tests/unit/test_db.py | 122 + tests/unit/test_embedding.py | 437 +++ tests/unit/test_embedding_store.py | 223 ++ tests/unit/test_entity.py | 170 ++ tests/unit/test_entity_edge.py | 231 ++ tests/unit/test_entity_store.py | 207 ++ tests/unit/test_env.py | 178 ++ tests/unit/test_events.py | 20 + tests/unit/test_execution_budget_store.py | 96 + tests/unit/test_execution_budgets.py | 709 +++++ tests/unit/test_execution_budgets_main.py | 373 +++ tests/unit/test_executions.py | 251 ++ tests/unit/test_executions_main.py | 166 + tests/unit/test_explicit_preferences.py | 266 ++ tests/unit/test_main.py | 2378 +++++++++++++++ tests/unit/test_memory.py | 897 ++++++ tests/unit/test_memory_store.py | 357 +++ tests/unit/test_ops_assets.py | 20 + tests/unit/test_policy.py | 447 +++ tests/unit/test_policy_main.py | 206 ++ tests/unit/test_policy_store.py | 184 ++ tests/unit/test_proxy_execution.py | 783 +++++ tests/unit/test_proxy_execution_main.py | 289 ++ tests/unit/test_response_generation.py | 267 ++ tests/unit/test_semantic_retrieval.py | 176 ++ tests/unit/test_store.py | 157 + tests/unit/test_task_step_store.py | 293 ++ tests/unit/test_task_workspace_store.py | 184 ++ tests/unit/test_tasks.py | 1663 ++++++++++ tests/unit/test_tasks_main.py | 178 ++ tests/unit/test_tool_execution_store.py | 185 ++ tests/unit/test_tool_store.py | 165 + tests/unit/test_tools.py | 688 +++++ tests/unit/test_tools_main.py | 315 ++ tests/unit/test_trace_store.py | 148 + tests/unit/test_worker_main.py | 40 + tests/unit/test_workspaces.py | 207 ++ 
tests/unit/test_workspaces_main.py | 112 + workers/.gitkeep | 1 + workers/alicebot_worker/__init__.py | 2 + workers/alicebot_worker/main.py | 15 + 161 files changed, 44706 insertions(+), 23 deletions(-) create mode 100644 .env.example create mode 100644 .gitignore create mode 100644 ARCHITECTURE.md create mode 100644 CHANGELOG.md create mode 100644 PRODUCT_BRIEF.md create mode 100644 ROADMAP.md create mode 100644 RULES.md create mode 100644 apps/api/.gitkeep create mode 100644 apps/api/alembic.ini create mode 100644 apps/api/alembic/env.py create mode 100644 apps/api/alembic/versions/20260310_0001_foundation_continuity.py create mode 100644 apps/api/alembic/versions/20260311_0002_tighten_runtime_privileges.py create mode 100644 apps/api/alembic/versions/20260311_0003_trace_backbone.py create mode 100644 apps/api/alembic/versions/20260311_0004_memory_admission.py create mode 100644 apps/api/alembic/versions/20260312_0005_memory_review_labels.py create mode 100644 apps/api/alembic/versions/20260312_0006_entities_backbone.py create mode 100644 apps/api/alembic/versions/20260312_0007_entity_edges.py create mode 100644 apps/api/alembic/versions/20260312_0008_embedding_substrate.py create mode 100644 apps/api/alembic/versions/20260312_0009_policy_and_consent_core.py create mode 100644 apps/api/alembic/versions/20260312_0010_tools_registry_and_allowlist.py create mode 100644 apps/api/alembic/versions/20260312_0011_approval_request_records.py create mode 100644 apps/api/alembic/versions/20260312_0012_approval_resolution.py create mode 100644 apps/api/alembic/versions/20260313_0013_tool_executions.py create mode 100644 apps/api/alembic/versions/20260313_0014_execution_budgets.py create mode 100644 apps/api/alembic/versions/20260313_0015_execution_budget_lifecycle.py create mode 100644 apps/api/alembic/versions/20260313_0016_execution_budget_rolling_window.py create mode 100644 apps/api/alembic/versions/20260313_0017_tasks_lifecycle_records.py create mode 100644 apps/api/alembic/versions/20260313_0018_task_steps.py create mode 100644 apps/api/alembic/versions/20260313_0019_task_step_lineage.py create mode 100644 apps/api/alembic/versions/20260313_0020_approval_task_step_linkage.py create mode 100644 apps/api/alembic/versions/20260313_0021_tool_execution_task_step_linkage.py create mode 100644 apps/api/alembic/versions/20260313_0022_task_workspaces.py create mode 100644 apps/api/src/alicebot_api/__init__.py create mode 100644 apps/api/src/alicebot_api/approvals.py create mode 100644 apps/api/src/alicebot_api/compiler.py create mode 100644 apps/api/src/alicebot_api/config.py create mode 100644 apps/api/src/alicebot_api/contracts.py create mode 100644 apps/api/src/alicebot_api/db.py create mode 100644 apps/api/src/alicebot_api/embedding.py create mode 100644 apps/api/src/alicebot_api/entity.py create mode 100644 apps/api/src/alicebot_api/entity_edge.py create mode 100644 apps/api/src/alicebot_api/execution_budgets.py create mode 100644 apps/api/src/alicebot_api/executions.py create mode 100644 apps/api/src/alicebot_api/explicit_preferences.py create mode 100644 apps/api/src/alicebot_api/main.py create mode 100644 apps/api/src/alicebot_api/memory.py create mode 100644 apps/api/src/alicebot_api/migrations.py create mode 100644 apps/api/src/alicebot_api/policy.py create mode 100644 apps/api/src/alicebot_api/proxy_execution.py create mode 100644 apps/api/src/alicebot_api/response_generation.py create mode 100644 apps/api/src/alicebot_api/semantic_retrieval.py create mode 100644 
apps/api/src/alicebot_api/store.py create mode 100644 apps/api/src/alicebot_api/tasks.py create mode 100644 apps/api/src/alicebot_api/tools.py create mode 100644 apps/api/src/alicebot_api/workspaces.py create mode 100644 apps/web/.gitkeep create mode 100644 apps/web/app/layout.tsx create mode 100644 apps/web/app/page.tsx create mode 100644 apps/web/next-env.d.ts create mode 100644 apps/web/next.config.mjs create mode 100644 apps/web/package.json create mode 100644 apps/web/tsconfig.json create mode 100644 docker-compose.yml create mode 100644 docs/adr/.gitkeep create mode 100644 docs/archive/.gitkeep create mode 100644 docs/runbooks/.gitkeep create mode 100644 infra/postgres/init/001_roles.sql create mode 100644 pyproject.toml create mode 100644 scripts/.gitkeep create mode 100755 scripts/api_dev.sh create mode 100755 scripts/dev_up.sh create mode 100755 scripts/migrate.sh create mode 100644 tests/.gitkeep create mode 100644 tests/integration/conftest.py create mode 100644 tests/integration/test_approval_api.py create mode 100644 tests/integration/test_context_compile.py create mode 100644 tests/integration/test_continuity_store.py create mode 100644 tests/integration/test_embeddings_api.py create mode 100644 tests/integration/test_entities_api.py create mode 100644 tests/integration/test_entity_edges_api.py create mode 100644 tests/integration/test_execution_budgets_api.py create mode 100644 tests/integration/test_explicit_preferences_api.py create mode 100644 tests/integration/test_healthcheck.py create mode 100644 tests/integration/test_memory_admission.py create mode 100644 tests/integration/test_memory_review_api.py create mode 100644 tests/integration/test_memory_review_labels_api.py create mode 100644 tests/integration/test_migrations.py create mode 100644 tests/integration/test_policy_api.py create mode 100644 tests/integration/test_proxy_execution_api.py create mode 100644 tests/integration/test_responses_api.py create mode 100644 tests/integration/test_task_workspaces_api.py create mode 100644 tests/integration/test_tasks_api.py create mode 100644 tests/integration/test_tool_api.py create mode 100644 tests/unit/test_20260310_0001_foundation_continuity.py create mode 100644 tests/unit/test_20260311_0002_tighten_runtime_privileges.py create mode 100644 tests/unit/test_20260311_0003_trace_backbone.py create mode 100644 tests/unit/test_20260311_0004_memory_admission.py create mode 100644 tests/unit/test_20260312_0005_memory_review_labels.py create mode 100644 tests/unit/test_20260312_0006_entities_backbone.py create mode 100644 tests/unit/test_20260312_0007_entity_edges.py create mode 100644 tests/unit/test_20260312_0008_embedding_substrate.py create mode 100644 tests/unit/test_20260312_0009_policy_and_consent_core.py create mode 100644 tests/unit/test_20260312_0010_tools_registry_and_allowlist.py create mode 100644 tests/unit/test_20260312_0011_approval_request_records.py create mode 100644 tests/unit/test_20260312_0012_approval_resolution.py create mode 100644 tests/unit/test_20260313_0013_tool_executions.py create mode 100644 tests/unit/test_20260313_0014_execution_budgets.py create mode 100644 tests/unit/test_20260313_0015_execution_budget_lifecycle.py create mode 100644 tests/unit/test_20260313_0016_execution_budget_rolling_window.py create mode 100644 tests/unit/test_20260313_0018_task_steps.py create mode 100644 tests/unit/test_20260313_0019_task_step_lineage.py create mode 100644 tests/unit/test_20260313_0020_approval_task_step_linkage.py create mode 100644 
tests/unit/test_20260313_0021_tool_execution_task_step_linkage.py create mode 100644 tests/unit/test_20260313_0022_task_workspaces.py create mode 100644 tests/unit/test_approval_store.py create mode 100644 tests/unit/test_approvals.py create mode 100644 tests/unit/test_approvals_main.py create mode 100644 tests/unit/test_compiler.py create mode 100644 tests/unit/test_config.py create mode 100644 tests/unit/test_db.py create mode 100644 tests/unit/test_embedding.py create mode 100644 tests/unit/test_embedding_store.py create mode 100644 tests/unit/test_entity.py create mode 100644 tests/unit/test_entity_edge.py create mode 100644 tests/unit/test_entity_store.py create mode 100644 tests/unit/test_env.py create mode 100644 tests/unit/test_events.py create mode 100644 tests/unit/test_execution_budget_store.py create mode 100644 tests/unit/test_execution_budgets.py create mode 100644 tests/unit/test_execution_budgets_main.py create mode 100644 tests/unit/test_executions.py create mode 100644 tests/unit/test_executions_main.py create mode 100644 tests/unit/test_explicit_preferences.py create mode 100644 tests/unit/test_main.py create mode 100644 tests/unit/test_memory.py create mode 100644 tests/unit/test_memory_store.py create mode 100644 tests/unit/test_ops_assets.py create mode 100644 tests/unit/test_policy.py create mode 100644 tests/unit/test_policy_main.py create mode 100644 tests/unit/test_policy_store.py create mode 100644 tests/unit/test_proxy_execution.py create mode 100644 tests/unit/test_proxy_execution_main.py create mode 100644 tests/unit/test_response_generation.py create mode 100644 tests/unit/test_semantic_retrieval.py create mode 100644 tests/unit/test_store.py create mode 100644 tests/unit/test_task_step_store.py create mode 100644 tests/unit/test_task_workspace_store.py create mode 100644 tests/unit/test_tasks.py create mode 100644 tests/unit/test_tasks_main.py create mode 100644 tests/unit/test_tool_execution_store.py create mode 100644 tests/unit/test_tool_store.py create mode 100644 tests/unit/test_tools.py create mode 100644 tests/unit/test_tools_main.py create mode 100644 tests/unit/test_trace_store.py create mode 100644 tests/unit/test_worker_main.py create mode 100644 tests/unit/test_workspaces.py create mode 100644 tests/unit/test_workspaces_main.py create mode 100644 workers/.gitkeep create mode 100644 workers/alicebot_worker/__init__.py create mode 100644 workers/alicebot_worker/main.py diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..923dfc8 --- /dev/null +++ b/.env.example @@ -0,0 +1,12 @@ +APP_ENV=development +APP_HOST=127.0.0.1 +APP_PORT=8000 +DATABASE_URL=postgresql://alicebot_app:alicebot_app@localhost:5432/alicebot +DATABASE_ADMIN_URL=postgresql://alicebot_admin:alicebot_admin@localhost:5432/alicebot +REDIS_URL=redis://localhost:6379/0 +S3_ENDPOINT_URL=http://localhost:9000 +S3_ACCESS_KEY=alicebot +S3_SECRET_KEY=alicebot-secret +S3_BUCKET=alicebot-local +HEALTHCHECK_TIMEOUT_SECONDS=2 +TASK_WORKSPACE_ROOT=/tmp/alicebot/task-workspaces diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..6acc4a7 --- /dev/null +++ b/.gitignore @@ -0,0 +1,9 @@ +.DS_Store +.env +.pytest_cache/ +.venv/ +*.egg-info/ +__pycache__/ +*.pyc +apps/web/.next/ +apps/web/node_modules/ diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md new file mode 100644 index 0000000..9cefafb --- /dev/null +++ b/ARCHITECTURE.md @@ -0,0 +1,207 @@ +# Architecture + +## Current Implemented Slice + +AliceBot now implements the accepted repo slice through Sprint 5A. 
The shipped backend includes: + +- foundation continuity storage over `users`, `threads`, `sessions`, and append-only `events` +- deterministic tracing and context compilation over durable continuity, memory, entity, and entity-edge records +- governed memory admission, explicit-preference extraction, memory review labels, review queue reads, evaluation summary reads, explicit embedding config and memory-embedding storage, direct semantic retrieval, and deterministic hybrid compile-path memory merge +- deterministic prompt assembly and one no-tools response path that persists assistant replies as immutable continuity events +- user-scoped consents, policies, policy evaluation, tool registry, allowlist evaluation, tool routing, approval request persistence, approval resolution, approved-only proxy execution through the in-process `proxy.echo` handler, durable execution review, and execution-budget lifecycle plus enforcement +- durable `tasks`, `task_steps`, and `task_workspaces`, deterministic task-step sequencing, explicit task-step transitions, explicit manual continuation with lineage through `parent_step_id`, `source_approval_id`, and `source_execution_id`, explicit `tool_executions.task_step_id` linkage for execution synchronization, and deterministic rooted local task-workspace provisioning + +The current multi-step boundary is narrow and explicit. Manual continuation is implemented and review-passed. Approval resolution and proxy execution now both use explicit task-step linkage rather than first-step inference. Task workspaces are now implemented only as deterministic rooted local boundaries. Broader runner-style orchestration, automatic multi-step progression, artifact indexing, document ingestion, connectors, and new side-effect surfaces are still planned later and must not be described as live behavior. + +## Implemented Now + +### Runtime + +- `docker-compose.yml` starts local Postgres with `pgvector`, Redis, and MinIO. +- `scripts/dev_up.sh`, `scripts/migrate.sh`, and `scripts/api_dev.sh` provide the local startup path, with readiness gating before migrations. 
+- `apps/api` exposes FastAPI endpoints for: + - health and compile: `/healthz`, `POST /v0/context/compile`, `POST /v0/responses` + - memory and retrieval: `POST /v0/memories/admit`, `POST /v0/memories/extract-explicit-preferences`, `GET /v0/memories`, `GET /v0/memories/review-queue`, `GET /v0/memories/evaluation-summary`, `POST /v0/memories/semantic-retrieval`, `GET /v0/memories/{memory_id}`, `GET /v0/memories/{memory_id}/revisions`, `POST /v0/memories/{memory_id}/labels`, `GET /v0/memories/{memory_id}/labels` + - embeddings and graph seams: `POST /v0/embedding-configs`, `GET /v0/embedding-configs`, `POST /v0/memory-embeddings`, `GET /v0/memories/{memory_id}/embeddings`, `GET /v0/memory-embeddings/{memory_embedding_id}`, `POST /v0/entities`, `GET /v0/entities`, `GET /v0/entities/{entity_id}`, `POST /v0/entity-edges`, `GET /v0/entities/{entity_id}/edges` + - governance: `POST /v0/consents`, `GET /v0/consents`, `POST /v0/policies`, `GET /v0/policies`, `GET /v0/policies/{policy_id}`, `POST /v0/policies/evaluate`, `POST /v0/tools`, `GET /v0/tools`, `GET /v0/tools/{tool_id}`, `POST /v0/tools/allowlist/evaluate`, `POST /v0/tools/route`, `POST /v0/approvals/requests`, `GET /v0/approvals`, `GET /v0/approvals/{approval_id}`, `POST /v0/approvals/{approval_id}/approve`, `POST /v0/approvals/{approval_id}/reject`, `POST /v0/approvals/{approval_id}/execute` + - task and execution review: `GET /v0/tasks`, `GET /v0/tasks/{task_id}`, `POST /v0/tasks/{task_id}/workspace`, `GET /v0/task-workspaces`, `GET /v0/task-workspaces/{task_workspace_id}`, `GET /v0/tasks/{task_id}/steps`, `GET /v0/task-steps/{task_step_id}`, `POST /v0/tasks/{task_id}/steps`, `POST /v0/task-steps/{task_step_id}/transition`, `POST /v0/execution-budgets`, `GET /v0/execution-budgets`, `GET /v0/execution-budgets/{execution_budget_id}`, `POST /v0/execution-budgets/{execution_budget_id}/deactivate`, `POST /v0/execution-budgets/{execution_budget_id}/supersede`, `GET /v0/tool-executions`, `GET /v0/tool-executions/{execution_id}` +- `apps/web` and `workers` remain starter shells only. + +### Data Foundation + +- Postgres is the current system of record. +- Alembic manages schema changes through `apps/api/alembic`. +- The live schema includes: + - continuity tables: `users`, `threads`, `sessions`, `events` + - trace tables: `traces`, `trace_events` + - memory and retrieval tables: `memories`, `memory_revisions`, `memory_review_labels`, `embedding_configs`, `memory_embeddings` + - graph tables: `entities`, `entity_edges` + - governance tables: `consents`, `policies`, `tools`, `approvals`, `tool_executions`, `execution_budgets` + - task lifecycle tables: `tasks`, `task_steps`, `task_workspaces` +- `events`, `trace_events`, and `memory_revisions` are append-only by application contract and database enforcement. +- `memory_review_labels` are append-only by database enforcement. +- `tasks` are explicit user-scoped lifecycle records keyed to one thread and one tool, with durable request/tool snapshots, status in `pending_approval | approved | executed | denied | blocked`, and latest approval/execution pointers for the current narrow lifecycle seam. +- `task_steps` are explicit user-scoped ordered lifecycle records keyed by `(user_id, task_id, sequence_no)`, with `kind = 'governed_request'`, status in `created | approved | executed | blocked | denied`, durable request/outcome snapshots, and one trace reference describing the latest mutation. 
+- Sprint 4O added lineage columns on `task_steps`: + - `parent_step_id` + - `source_approval_id` + - `source_execution_id` +- Lineage fields are guarded by composite user-scoped foreign keys and a self-reference check so a step cannot cite itself as its parent. +- `tool_executions` now persist an explicit `task_step_id` linked by a composite foreign key to `task_steps(id, user_id)`. +- `task_workspaces` persist one active workspace record per visible task and user, store a deterministic `local_path`, and enforce that active uniqueness through a partial unique index on `(user_id, task_id)`. +- `execution_budgets` enforce at most one active budget per `(user_id, tool_key, domain_hint)` selector scope through a partial unique index. +- Per-request user context is set in the database through `app.current_user_id()`. +- `TASK_WORKSPACE_ROOT` defines the only allowed base directory for workspace provisioning, and the live path rule is `resolved_root / user_id / task_id`. + +### Repo Boundaries In This Slice + +- `apps/api`: implemented API, store, contracts, service logic, and migrations for continuity, tracing, memory, embeddings, entities, policies, tools, approvals, proxy execution, execution budgets, tasks, task steps, and task workspaces. +- `apps/web`: minimal shell only; no shipped workflow UI. +- `workers`: scaffold only; no background jobs or runner logic are implemented. +- `infra`: local development bootstrap assets only. +- `tests`: unit and Postgres-backed integration coverage for the shipped seams above, including Sprint 4O task-step lineage/manual continuation, Sprint 4S step-linked execution synchronization, and Sprint 5A task-workspace provisioning. + +## Core Flows Implemented Now + +### Deterministic Context Compilation + +1. Accept a user-scoped `POST /v0/context/compile` request. +2. Read durable continuity records in deterministic order. +3. Merge in active memories, entities, and entity edges through the currently shipped symbolic and optional semantic retrieval paths. +4. Persist a `context.compile` trace plus explicit inclusion and exclusion events. +5. Return one deterministic `context_pack` describing scope, limits, selected context, and trace metadata. + +### Governed Memory And Retrieval + +1. Accept explicit memory candidates through `POST /v0/memories/admit`. +2. Require cited source events, default to `NOOP`, and persist `memory_revisions` only for evidence-backed non-`NOOP` mutations. +3. Support a narrow deterministic explicit-preference extractor over stored `message.user` events. +4. Persist user-scoped embedding configs and memory embeddings explicitly. +5. Support direct semantic retrieval over active memories for a caller-selected embedding config. +6. Merge symbolic and semantic memory results deterministically into the compile path with trace-visible source provenance. +7. Expose review reads, unlabeled review queue reads, evaluation summary reads, and append-only memory-review labels. + +### Policy, Tool, Approval, And Execution Governance + +1. Evaluate policies deterministically over active user-scoped policy and consent state. +2. Evaluate tool allowlists against active tool metadata plus policy decisions. +3. Route one requested invocation deterministically to `ready`, `denied`, or `approval_required`. +4. Persist durable approval rows only for `approval_required` outcomes. +5. Resolve approvals explicitly through approve and reject endpoints. +6. Execute approved requests only through the registered proxy-handler map. +7. 
In the current repo, only `proxy.echo` is enabled, and it performs no external I/O. +8. Persist one durable `tool_executions` row for every approved execution attempt, including budget-blocked attempts. +9. Enforce narrow execution budgets by selector scope and optional rolling window before approved dispatch. + +### Task Lifecycle Creation + +1. `POST /v0/approvals/requests` always creates one durable `tasks` row and one initial `task_steps` row, even when no approval row is persisted. +2. The initial task and task step reflect the routing decision: + - `approval_required` creates `task.status = pending_approval` and `task_step.status = created` + - `ready` creates `task.status = approved` and `task_step.status = approved` + - `denied` creates `task.status = denied` and `task_step.status = denied` +3. The initial task step is always `sequence_no = 1`. +4. Approval-request traces include task lifecycle and task-step lifecycle events alongside the approval request events. + +### Approval Resolution And Proxy Execution Synchronization + +1. Approval resolution reuses the existing task seam and updates the durable task plus the explicitly linked task step from `approvals.task_step_id`. +2. Approval resolution rejects missing, invisible, cross-task, and inconsistent approval-to-step linkage deterministically. +3. Approved proxy execution validates the approval’s linked task step before dispatch and persists `tool_executions.task_step_id` on every durable execution row. +4. Execution synchronization now reuses `tool_executions.task_step_id` and updates the explicitly linked step by id rather than inferring `sequence_no = 1`. +5. Execution synchronization rejects missing, invisible, cross-task, and inconsistent execution-to-step linkage deterministically before mutating task or task-step state. + +### Task-Step Manual Continuation + +1. Accept a user-scoped `POST /v0/tasks/{task_id}/steps` request to append exactly one next step to an existing task. +2. Lock the task-step sequence before allocating the next `sequence_no`. +3. Require the task to already have visible steps. +4. Allow append only when the latest visible step is in `executed`, `blocked`, or `denied`. +5. Require explicit lineage: + - `lineage.parent_step_id` must be present + - the parent step must belong to the same visible task + - the parent step must be the latest visible task step +6. Optionally allow `lineage.source_approval_id` and `lineage.source_execution_id`, but only when: + - the referenced records are visible in the current user scope + - the referenced records already appear on the parent step outcome +7. Persist the new `task_steps` row with the lineage fields and incremented `sequence_no`. +8. Update the parent `tasks` row to the task status implied by the appended step status. +9. Persist one `task.step.continuation` trace plus request, lineage, summary, task lifecycle, and task-step lifecycle events. +10. Return the updated task, the appended step, deterministic sequencing metadata, and trace summary. + +### Task-Step Transition + +1. Accept a user-scoped `POST /v0/task-steps/{task_step_id}/transition` request. +2. Require the referenced step to be the latest visible step on its task. +3. Enforce the explicit status graph: + - `created -> approved | denied` + - `approved -> executed | blocked` + - terminal states have no further transitions +4. Require approval linkage when the step must reflect approval state and execution linkage when the step must reflect execution state. +5. 
Update the target step in place with a new trace reference and outcome snapshot. +6. Update the parent task status and latest approval/execution pointers consistently. +7. Persist one `task.step.transition` trace plus request, state, summary, task lifecycle, and task-step lifecycle events. + +### Task And Task-Step Reads + +1. `GET /v0/tasks` lists durable task rows in deterministic `created_at ASC, id ASC` order. +2. `GET /v0/tasks/{task_id}` returns one user-visible task detail record. +3. `GET /v0/tasks/{task_id}/steps` returns task steps in deterministic `sequence_no ASC, created_at ASC, id ASC` order plus sequencing summary metadata. +4. `GET /v0/task-steps/{task_step_id}` returns one user-visible task-step detail record. +5. Task-step list and detail reads expose lineage fields directly. + +### Task Workspace Provisioning + +1. Accept a user-scoped `POST /v0/tasks/{task_id}/workspace` request for one visible task. +2. Resolve the configured `TASK_WORKSPACE_ROOT`. +3. Build the deterministic local path as `resolved_root / user_id / task_id`. +4. Reject provisioning if the resolved workspace path escapes the resolved workspace root. +5. Lock workspace creation for the target task before checking for an existing active workspace. +6. Reject duplicate active workspace creation for the same visible task deterministically. +7. Create the local directory boundary and persist one `task_workspaces` row with `status = active` and the rooted `local_path`. +8. `GET /v0/task-workspaces` lists visible workspaces in deterministic `created_at ASC, id ASC` order. +9. `GET /v0/task-workspaces/{task_workspace_id}` returns one user-visible workspace detail record. + +## Security Model Implemented Now + +- User-owned continuity, trace, memory, embedding, entity, governance, task, task-step, and task-workspace tables enforce row-level security. +- The runtime role is limited to the narrow `SELECT` / `INSERT` / `UPDATE` permissions required by the shipped seams; there is no broad DDL or unrestricted table access at runtime. +- Cross-user references are constrained through composite foreign keys on `(id, user_id)` where the schema needs ownership-linked joins. +- Approval, execution, memory, entity, task/task-step, and task-workspace reads all operate only inside the current user scope. +- Task-step manual continuation adds both schema-level and service-level lineage protection: + - schema-level: user-scoped foreign keys and parent-not-self check + - service-level: same-task, latest-step, visible-approval, visible-execution, and parent-outcome-match validation +- In-place updates and deletes remain blocked for append-only continuity and trace records. + +## Testing Coverage Implemented Now + +- Unit and integration tests cover continuity, compiler, response generation, memory admission, review labels, review queue, embeddings, semantic retrieval, entities, policies, tools, approvals, proxy execution, execution budgets, and execution review. 
+- Sprint 4O, Sprint 4S, and Sprint 5A added explicit task lifecycle coverage: + - migrations for `tasks`, `task_steps`, and task-step lineage + - staged/backfilled migration coverage for `tool_executions.task_step_id` + - task and task-step store contracts + - task list/detail and task-step list/detail reads + - deterministic sequencing summaries + - manual continuation success paths + - task-step transition success paths + - explicit later-step execution synchronization by linked `task_step_id` + - deterministic task-workspace path generation and rooted-path enforcement + - workspace create/list/detail response shape + - duplicate active workspace rejection + - task-workspace per-user isolation + - trace visibility for continuation and transition events + - user isolation for task and task-step reads and mutations + - adversarial lineage validation for cross-task, cross-user, and parent-step mismatch cases + +## Planned Later + +The following areas remain planned later and must not be described as implemented: + +- runner-style orchestration and automatic multi-step progression beyond the current explicit manual continuation seam +- artifact storage, artifact indexing, and document ingestion beyond the current rooted local workspace boundary +- read-only Gmail and Calendar connectors +- broader tool proxying and real-world side effects beyond the current no-I/O `proxy.echo` handler +- model-driven extraction, reranking, and broader memory review automation +- production deployment automation beyond the local developer stack + +Future docs and code should continue to distinguish the implemented seams above from these later milestones. diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..3fd0cc0 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,27 @@ +# Changelog + +## 2026-03-11 + +- Redacted embedded Redis credentials from `/healthz` so the endpoint no longer echoes `REDIS_URL` secrets back to callers. +- Added readiness gating to `./scripts/dev_up.sh` so bootstrap waits for Postgres and `alicebot_app` role initialization before running migrations. +- Bound local Postgres, Redis, and MinIO ports to `127.0.0.1` by default and removed the unnecessary runtime-role `CONNECT` grant on the shared `postgres` database. +- Removed the redundant `(thread_id, sequence_no)` events index from the base continuity migration because the unique constraint already provides that index. +- Tightened architecture, roadmap, handoff, and builder-report wording so exposed routes and environment-specific verification claims stay accurate. +- Tightened the runtime Postgres role so the continuity tables are insert/select-only in the migration chain and for upgraded databases. +- Stopped the base migration downgrade from dropping shared `pgcrypto` and `vector` extensions. +- Made the local helper scripts prefer `.venv/bin/python` when the project virtualenv exists, falling back to `python3` otherwise. +- Corrected `/healthz` so only Postgres is reported as live-checked, while Redis and MinIO are surfaced as configured but `not_checked`. +- Fixed Alembic runtime URL handling so migrations use the installed `psycopg` SQLAlchemy driver instead of the missing `psycopg2` default. +- Fixed concurrent event append sequencing by acquiring the per-thread advisory lock before reading the next `sequence_no`. +- Verified the local foundation runtime with `docker compose up -d`, `./scripts/migrate.sh`, `./.venv/bin/python -m pytest tests/unit tests/integration`, and a live `GET /healthz`. 
+ +## 2026-03-10 + +- Bootstrapped the canonical project operating files. +- Created the initial AI handoff snapshot and first sprint packet. +- Added the recommended repo scaffolding directories for implementation work. +- Added local Docker Compose infrastructure for Postgres with `pgvector`, Redis, and MinIO. +- Added the FastAPI foundation scaffold, configuration loading, `/healthz`, and Alembic migration plumbing. +- Added continuity tables for `users`, `threads`, `sessions`, and append-only `events` with RLS and isolation tests. +- Fixed the local quick-start path so repo scripts source `.env`, use `python3`, and keep migrations pointed at the `alicebot` database. +- Serialized same-thread event appends before sequence allocation and added an integration test for concurrent event numbering. diff --git a/PRODUCT_BRIEF.md b/PRODUCT_BRIEF.md new file mode 100644 index 0000000..1735c23 --- /dev/null +++ b/PRODUCT_BRIEF.md @@ -0,0 +1,77 @@ +# Product Brief + +## Product Summary + +AliceBot is a private, permissioned personal AI operating system for a single primary user. It is designed to preserve durable personal context, retrieve the right context at the right time, and move safely from conversation to action without hiding why it acted. + +## Problem + +General-purpose assistants forget preferences, prior decisions, and relationships across sessions. They also make it difficult to audit why they answered a certain way or whether a tool action was properly governed. The result is low trust, repeated user effort, and unsafe action handling. + +## Target Users + +- Primary v1 user: one power user with recurring life and work workflows. +- Delivery model: a human lead working with AI builders and reviewers. +- Architectural assumption: v1 UX is single-user, but the data model must support strict per-user isolation from day one. + +## Core Value Proposition + +- Durable memory for preferences, relationships, prior decisions, and recurring tasks. +- Deterministic context compilation instead of ad hoc prompt stuffing. +- Safe action orchestration with policy checks, approvals, and budgets. +- Clear explainability through traces, memory evidence, and tool history. + +## V1 Scope + +- Web-based chat and task orchestration. +- Immutable thread and session continuity. +- Structured memory with admission controls, revision history, and user review. +- Entity and relationship tracking for people, merchants, products, projects, and routines. +- Hybrid retrieval across memories, entities, relationships, and documents. +- Policy engine, tool proxy, approval workflows, and task budgets. +- Scoped task workspaces and artifact storage. +- Read-only document ingestion plus read-only Gmail and Calendar connectors. +- Hot consolidation for immediate truth updates and cold consolidation for cleanup and summarization. +- Explain-why views for important responses and actions. + +## Non-Goals + +- Autonomous side effects without user approval. +- Multi-user collaboration UX in v1. +- Mobile-first delivery. +- Dedicated graph or vector infrastructure in v1. +- Browser automation, write-capable connectors, proactive automations, and voice at launch. + +## Key User Journeys + +1. Ask a question that depends on prior preferences, purchases, or relationships and get a context-aware answer without restating history. +2. Correct a preference or fact and have the next turn reflect the new truth immediately. +3. 
Inspect why the system answered or proposed an action by reviewing memories, retrieval choices, and tool traces. +4. Run a repeat-purchase workflow that gathers prior context, proposes the order, pauses for approval, and records the outcome. +5. Retrieve relevant context from documents, Gmail, or Calendar without granting write access. + +## Constraints + +- Single-user product experience, multi-tenant-safe architecture. +- Web-first v1. +- Explicit approval for consequential actions. +- Operational simplicity beats platform sprawl in v1. +- Memory quality, retrieval quality, and explainability are ship-gating concerns. + +## Success Criteria + +- The system recalls relevant preferences, past purchases, relationships, and prior decisions without repeated user restatement. +- The repeat magnesium reorder workflow succeeds end to end with approval gating and memory write-back. +- Every consequential action is explainable through trace, memory, rule, and tool evidence. +- Purchases, emails, bookings, and other side effects never occur without explicit approval. +- Standard retrieval-plus-response interactions reach p95 latency under 5 seconds. +- Prompt and cache reuse exceeds 70% on repeated patterns. +- Memory extraction precision exceeds 80% at ship. + +## Product Non-Negotiables + +- The user stays in control of consequential actions. +- Durable context must come from governed storage, not raw transcript stuffing. +- Explainability is a product requirement, not a debugging feature. +- Preference contradictions must be reflected immediately. +- The repeat magnesium reorder scenario is the canonical v1 ship gate. diff --git a/README.md b/README.md index 5ea9f40..97fd320 100644 --- a/README.md +++ b/README.md @@ -1,19 +1,36 @@ # AliceBot -AliceBot is a private, permissioned personal AI operating system. This repository currently holds the canonical product, architecture, roadmap, and AI handoff documents that future implementation work should follow. +AliceBot is a private, permissioned personal AI operating system. 
The repository now includes the runnable foundation slice plus the first tracing/context-compilation seam, the first governed memory/admissions-and-embeddings slice, the first deterministic response-generation seam, the first governance routing seam for non-executing tool requests, the first durable approval-request persistence seam for `approval_required` routing outcomes, the explicit approval-resolution seam, the first minimal approved-only proxy-execution seam, the first durable execution-review seam over that proxy path, the narrow execution-budget lifecycle seam over approved proxy execution, and the first deterministic task-workspace provisioning seam: local infrastructure, an API scaffold, migration tooling, continuity primitives, persisted traces, a deterministic continuity-only compiler, explicit memory admission, a narrow deterministic explicit-preference extraction path, explicit embedding-config and memory-embedding storage paths, a direct semantic memory retrieval primitive, deterministic hybrid compile-path memory merge, a no-tools model invocation path over deterministically assembled prompts, deterministic policy and tool-governance seams, a narrow no-side-effect proxy handler path, durable `tool_executions` records, durable `execution_budgets` records, durable `task_workspaces` records, execution-budget create/list/detail reads, budget deactivate/supersede lifecycle operations, active-only budget enforcement, budget-blocked execution persistence, task-workspace create/list/detail reads, and backend verification coverage. ## Status -- Planning has been distilled into durable operating docs. -- Application code has not been scaffolded yet. -- The first execution target is the foundation sprint in [.ai/active/SPRINT_PACKET.md](.ai/active/SPRINT_PACKET.md). +- Local Docker Compose infrastructure is defined for Postgres with `pgvector`, Redis, and MinIO. +- `apps/api` contains FastAPI health, compile, response-generation, memory-admission, explicit-preference extraction, semantic-memory-retrieval, policy, tool-registry, tool-allowlist, tool-routing, approval-request, approval-resolution, proxy-execution, execution-budget, execution-review, task, and task-workspace endpoints, configuration loading, Alembic migrations, continuity storage primitives, the Sprint 2A trace/compiler path, the Sprint 3A memory-admission path, the Sprint 3I deterministic extraction path, the Sprint 3K embedding substrate, the Sprint 3L semantic retrieval primitive, the Sprint 3M compile-path semantic retrieval adoption, the Sprint 3N deterministic hybrid memory merge, the Sprint 4A deterministic prompt-assembly and no-tools response path, the Sprint 4D deterministic non-executing tool-routing seam, the Sprint 4E durable approval-request persistence seam, the Sprint 4F approval-resolution seam, the Sprint 4G minimal approved-only proxy-execution seam, the Sprint 4H durable execution-review seam, the Sprint 4I execution-budget guard seam, the Sprint 4J execution-budget lifecycle seam, the Sprint 4K time-windowed execution-budget seam, the Sprint 4S explicit execution-to-task-step linkage seam, and the Sprint 5A task-workspace provisioning seam. +- `apps/web` and `workers` contain minimal starter scaffolds for later milestone work. +- The active sprint is documented in [.ai/active/SPRINT_PACKET.md](.ai/active/SPRINT_PACKET.md). -## Quick Start Assumptions +## Quick Start -- Assumption: local development will use Docker Compose for Postgres, Redis, and S3-compatible storage. 
-- Assumption: backend work will use Python 3.12 and FastAPI.
-- Assumption: frontend work will use Node.js 20, `pnpm`, and Next.js.
-- Secrets must stay out of the repo; use `.env` files locally and a secret manager in deployed environments.
+1. Create a local env file: `cp .env.example .env`
+2. Start required infrastructure with one command: `docker compose up -d`
+3. Create a project virtualenv and install Python dependencies: `python3 -m venv .venv && ./.venv/bin/python -m pip install -e '.[dev]'`
+4. Run database migrations: `./scripts/migrate.sh`
+5. Start the API locally: `./scripts/api_dev.sh`
+
+The health endpoint is exposed at [http://127.0.0.1:8000/healthz](http://127.0.0.1:8000/healthz).
+The minimal context-compilation API path is `POST /v0/context/compile`.
+The minimal response-generation API path is `POST /v0/responses`.
+The minimal memory-admission API path is `POST /v0/memories/admit`.
+The explicit-preference extraction API path is `POST /v0/memories/extract-explicit-preferences`.
+The minimal non-executing tool-routing API path is `POST /v0/tools/route`.
+The minimal approval API paths are `POST /v0/approvals/requests`, `GET /v0/approvals`, `GET /v0/approvals/{approval_id}`, `POST /v0/approvals/{approval_id}/approve`, `POST /v0/approvals/{approval_id}/reject`, and `POST /v0/approvals/{approval_id}/execute`.
+The execution-budget API paths are `POST /v0/execution-budgets`, `GET /v0/execution-budgets`, `GET /v0/execution-budgets/{execution_budget_id}`, `POST /v0/execution-budgets/{execution_budget_id}/deactivate`, and `POST /v0/execution-budgets/{execution_budget_id}/supersede`.
+The execution-review API paths are `GET /v0/tool-executions` and `GET /v0/tool-executions/{execution_id}`.
+The task-workspace API paths are `POST /v0/tasks/{task_id}/workspace`, `GET /v0/task-workspaces`, and `GET /v0/task-workspaces/{task_workspace_id}`.
+The helper scripts load the repo-root `.env` automatically and prefer `.venv/bin/python` when that virtualenv exists, falling back to `python3` otherwise. The default migration/admin URL targets the same local `alicebot` database as the app runtime.
+`/healthz` currently performs a live Postgres check only. Redis and MinIO are reported as configured endpoints with `not_checked` status.
+`TASK_WORKSPACE_ROOT` controls the single rooted base directory used for deterministic local task-workspace provisioning. By default it is `/tmp/alicebot/task-workspaces`, and each workspace path is created as `<TASK_WORKSPACE_ROOT>/<user_id>/<task_id>`.
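+
+As a minimal sketch of that rooted-path rule (illustrative only; the function name and exact error handling here are assumptions, not the repository's actual `workspaces` code), the derivation and escape check can be written as:
+
+```python
+from pathlib import Path
+
+
+def resolve_workspace_path(workspace_root: str, user_id: str, task_id: str) -> Path:
+    """Derive the deterministic rooted workspace path and reject root escapes."""
+    root = Path(workspace_root).resolve()
+    candidate = (root / user_id / task_id).resolve()
+    # Refuse any candidate that does not stay under the resolved root, for example
+    # when an identifier smuggles in ".." segments.
+    if not candidate.is_relative_to(root):
+        raise ValueError("workspace path escapes TASK_WORKSPACE_ROOT")
+    return candidate
+```
+
+In the shipped provisioning flow, this rejection happens before the local directory is created or the `task_workspaces` row is persisted.
+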
+The current backend path has been verified in a local developer environment with `docker compose up -d`, `./scripts/migrate.sh`, `./.venv/bin/python -m pytest tests/unit tests/integration`, a live `GET /healthz`, and the Postgres-backed `POST /v0/context/compile`, `POST /v0/responses`, `POST /v0/memories/admit`, `POST /v0/memories/extract-explicit-preferences`, `POST /v0/memories/semantic-retrieval`, `POST /v0/tools/allowlist/evaluate`, `POST /v0/tools/route`, `POST /v0/approvals/requests`, `POST /v0/approvals/{approval_id}/execute`, `POST /v0/execution-budgets`, `GET /v0/execution-budgets`, `POST /v0/execution-budgets/{execution_budget_id}/deactivate`, `POST /v0/execution-budgets/{execution_budget_id}/supersede`, `GET /v0/tool-executions`, and `GET /v0/tool-executions/{execution_id}` integration paths, including compile requests that explicitly enable the hybrid memory merge, response requests that persist assistant events and response traces, deterministic non-executing tool-routing requests that persist `tool.route.*` traces, approval-request persistence requests that persist `approval.request.*` traces plus durable approval rows only for `approval_required` outcomes, approved proxy execution that persists `tool.proxy.execute.*` traces plus durable `tool_executions` rows for approved execution attempts, deterministic budget-management requests over durable `execution_budgets` rows, lifecycle requests that persist `execution_budget.lifecycle.*` traces and change budget status deterministically, budget-prechecked proxy execution that emits `tool.proxy.execute.budget` trace events against active budgets only, and execution-review reads over those durable records including budget-blocked attempts. ## Repo Structure @@ -23,23 +40,34 @@ AliceBot is a private, permissioned personal AI operating system. This repositor - [RULES.md](RULES.md): durable engineering and scope rules. - [.ai/handoff/CURRENT_STATE.md](.ai/handoff/CURRENT_STATE.md): fresh-thread recovery snapshot. - [.ai/active/SPRINT_PACKET.md](.ai/active/SPRINT_PACKET.md): current builder sprint. -- `docs/adr/`: architecture decision records. -- `docs/runbooks/`: operational procedures. -- `docs/archive/`: source material and retired planning docs. -- `apps/api/`, `apps/web/`, `workers/`, `tests/`, `scripts/`: planned implementation areas. +- `docker-compose.yml`: local Postgres, Redis, and MinIO stack. +- `infra/postgres/init/`: Postgres bootstrap SQL, including the non-superuser app role. +- `apps/api/`: FastAPI app, config, continuity store, and Alembic migrations. +- `apps/web/`: minimal Next.js shell for later dashboard work. +- `workers/`: placeholder Python worker package for future background jobs. +- `tests/`: unit and Postgres-backed integration tests for the foundation slice. +- `scripts/`: local development and migration entrypoints. ## Essential Commands -- `docker compose up -d`: expected local infra start command once the foundation sprint lands. -- `alembic upgrade head`: expected database migration command once the API scaffold exists. -- `pytest`: expected backend and integration test entrypoint. -- `pnpm test`: expected frontend test entrypoint. -- `pnpm lint`: expected frontend lint entrypoint. +- `docker compose up -d`: start Postgres, Redis, and MinIO on `127.0.0.1`. +- `./scripts/dev_up.sh`: start local infrastructure, wait for Postgres and role bootstrap readiness, and apply Alembic migrations. +- `./scripts/migrate.sh`: apply Alembic migrations with the admin database URL from `.env` or the built-in defaults. 
+- `./scripts/api_dev.sh`: run the FastAPI service with auto-reload. +- `./.venv/bin/python -m pytest tests/unit tests/integration`: run backend tests from the project virtualenv. +- `pnpm --dir apps/web dev`: start the web shell after frontend dependencies are installed. ## Environment Notes -- Postgres is the planned system of record and must support `pgvector`. -- Redis is planned for queues, locks, and short-lived cache data. -- Object storage is planned for documents and task artifacts. -- Authentication, row-level security, and approval boundaries are first-class requirements from the start. -# AliceBot +- Postgres is the system of record and the live schema now includes continuity tables, trace tables, policy-governance tables including `approvals`, `tool_executions`, and `execution_budgets`, task lifecycle tables including `tasks`, `task_steps`, and `task_workspaces`, memory tables, entity tables, and the embedding substrate tables `embedding_configs` and `memory_embeddings`. +- Sprint 2A adds persisted `traces` and `trace_events` plus a deterministic continuity-only context compiler over existing durable continuity records. +- Sprint 3A adds governed `memories` and append-only `memory_revisions` plus an explicit `NOOP`-first admission path over cited source events. +- The app and migration defaults both target the local `alicebot` database to keep quick-start behavior deterministic. +- `TASK_WORKSPACE_ROOT` defaults to `/tmp/alicebot/task-workspaces` and defines the only allowed root for deterministic local task-workspace provisioning. +- Local service ports are bound to `127.0.0.1` by default to avoid exposing fixed development credentials on non-loopback interfaces. +- Redis is reserved for future queue, lock, and cache work; no retrieval or orchestration features are enabled in this sprint. +- MinIO provides the local S3-compatible endpoint for future document and artifact storage. +- Continuity tables enforce row-level security from the start and `events` are append-only by application contract plus database trigger, with concurrent appends serialized per thread. +- Trace tables follow the same per-user isolation model, with append-only `trace_events` for compiler explainability. +- Memory admission remains explicit and evidence-backed, automatic extraction is currently limited to a narrow deterministic explicit-preference path over stored user messages, and the repo now includes explicit versioned embedding-config storage, direct memory-embedding persistence, a direct semantic retrieval API over active durable memories, compile-path hybrid memory merge into one `context_pack["memories"]` section with `memory_summary.hybrid_retrieval` metadata, one deterministic no-tools response path that assembles prompts from durable compiled context and persists assistant replies plus response traces, one deterministic approval-request persistence path over `approval_required` tool-routing outcomes, explicit approval resolution, one minimal approved-only proxy execution path through the no-side-effect `proxy.echo` handler, durable execution-review records plus list/detail reads for approved execution attempts, one narrow deterministic execution-budget seam that can activate, deactivate, supersede, and enforce both lifetime and rolling-window limits using durable `tool_executions` history while keeping blocked attempts reviewable, and one narrow deterministic task-workspace seam that provisions rooted local workspace directories and persists durable `task_workspaces` rows. 
Broader extraction, reranking, external-connector tool execution, artifact indexing, document ingestion, orchestration, and review UI remain deferred. +- The runtime database role is limited to `SELECT`/`INSERT` on continuity and trace tables, `SELECT`/`INSERT` on `memory_revisions`, `memory_review_labels`, `embedding_configs`, `entities`, and `entity_edges`, plus `SELECT`/`INSERT`/`UPDATE` on `consents`, `memories`, `memory_embeddings`, and `execution_budgets`. diff --git a/ROADMAP.md b/ROADMAP.md new file mode 100644 index 0000000..ed7c2ba --- /dev/null +++ b/ROADMAP.md @@ -0,0 +1,97 @@ +# Roadmap + +## Current State + +- The repo has shipped the implementation slices originally planned as Milestones 1 through 4. +- Sprint 4O added the latest accepted backend seam: durable `tasks` and `task_steps` with explicit manual continuation lineage and deterministic task-step transitions. +- The project is no longer at Foundation. The current repo state is a post-Milestone-4 checkpoint, and this sprint is synchronizing project-truth docs before Milestone 5 work begins. +- No task runner, workspace/artifact layer, document ingestion, read-only connector, or broader side-effect surface has landed yet. + +## Completed Milestones + +### Milestone 1: Foundation + +- Repo scaffold, local Docker Compose infra, FastAPI app shell, config loading, migration tooling, and backend test harness. +- Postgres continuity primitives: `users`, `threads`, `sessions`, and append-only `events`. +- Row-level-security foundation and concurrent event sequencing hardening. + +Status on March 13, 2026: +- Complete. + +### Milestone 2: Context Compiler and Tracing + +- Deterministic context compilation over durable continuity records. +- Persisted `traces` and append-only `trace_events`. +- Trace-visible inclusion and exclusion reasoning for compiled context. + +Status on March 13, 2026: +- Complete. + +### Milestone 3: Memory and Retrieval + +- Governed memory admission with append-only revisions. +- Narrow deterministic explicit-preference extraction from stored user events. +- Memory review labels, review queue reads, and evaluation summary reads. +- Explicit entities and temporal entity edges backed by cited memories. +- Versioned embedding configs, durable memory embeddings, direct semantic retrieval, and deterministic hybrid compile-path memory merge. + +Status on March 13, 2026: +- Complete. + +### Milestone 4: Governance and Safe Action + +- Deterministic response generation over compiled context. +- User-scoped consents, policies, policy evaluation, tool registry, allowlist evaluation, and tool routing. +- Durable approval requests and explicit approval resolution. +- Approved-only proxy execution through the in-process `proxy.echo` handler. +- Durable execution review, execution-budget enforcement, lifecycle mutations, and optional rolling-window limits. +- Durable `tasks` and `task_steps`, deterministic task-step reads, explicit task-step transitions, and explicit manual continuation with lineage. + +Status on March 13, 2026: +- Complete through Sprint 4O. + +## Current Milestone Position + +- The repo is at the boundary after Milestone 4. +- Milestone 5 has not started in shipped code yet. +- The immediate work is documentation synchronization and narrow lifecycle-boundary hardening so Milestone 5 planning and review start from truthful artifacts. + +## Next Milestones + +### Immediate Next Narrow Boundary + +- Preserve the current manual-continuation seam as the only shipped multi-step task path. 
+- Remove or explicitly constrain the remaining approval/execution helpers that still synchronize against `task_steps.sequence_no = 1` before starting runner-style orchestration or workspace-heavy task flows. + +### Milestone 5: Documents, Workspaces, and Read-Only Connectors + +- Add document ingestion and chunk retrieval. +- Add scoped task workspaces and artifact handling. +- Add read-only Gmail and Calendar sync. +- Keep connector scope read-only and approval-aware. + +### Sequencing After Milestone 5 + +- Generalize task lifecycle handling beyond the current manual continuation seam. +- Introduce runner-style orchestration only after the first-step lifecycle assumption is removed. +- Expand tool execution breadth only after the governance and task seams stay deterministic under multi-step flows. + +## Dependencies + +- Truth artifacts must stay synchronized before milestone planning and review work can be trusted. +- The current first-step lifecycle assumption must be resolved before broader runner or workspace work can safely depend on `tasks` / `task_steps`. +- Scoped workspace and artifact boundaries should land before document-heavy or connector-heavy flows rely on them. +- Connector scope should remain deferred until the core memory, governance, and task seams stay stable under the shipped workload. + +## Blockers and Risks + +- Memory extraction and retrieval quality remain the biggest product risk. +- Auth beyond DB user context is still unimplemented. +- The remaining first-step approval/execution synchronization helpers are a forward-compatibility risk for broader multi-step orchestration. +- Workspace or connector work could create hidden scope drift if it starts before the current task-lifecycle boundary is hardened. + +## Recently Completed + +- Durable approval, execution review, and execution-budget seams over the approved proxy path. +- Durable `tasks` and `task_steps` with deterministic reads and status transitions. +- Explicit task-step lineage and manual continuation, including adversarial validation for cross-task, cross-user, and parent-step mismatch cases. diff --git a/RULES.md b/RULES.md new file mode 100644 index 0000000..f6ac44b --- /dev/null +++ b/RULES.md @@ -0,0 +1,52 @@ +# Rules + +## Product / Scope Rules + +- The active sprint packet is the top priority scope boundary for implementation work and overrides broader roadmap intent when they conflict. +- Never represent planned architecture as implemented behavior in docs, handoffs, or build reports. +- Never execute a consequential external action without explicit user approval. +- Always treat explainability as a product feature, not an internal debugging aid. +- Treat the repeat magnesium reorder as the v1 ship-gate scenario. +- Never expand v1 scope with proactive automation, write-capable connectors, voice, or browser automation without an explicit roadmap change. +- Do not start runner, workspace/artifact, document-ingestion, or connector work unless the active sprint explicitly opens that boundary. + +## Architecture Rules + +- Treat the immutable event store as ground truth; memories, tasks, and summaries are derived or governed views over durable records. +- Always compile context per invocation from durable sources. +- Keep prompt prefixes, tool schemas, and serialized context ordering deterministic. +- Treat Postgres as the v1 system of record unless measured constraints justify a platform split. +- Appended task steps must carry explicit lineage to a prior visible task step. 
Do not relink approvals or executions heuristically from broader task history. +- Manual continuation is the current multi-step boundary. Until the older first-step lifecycle helpers are removed or constrained, do not describe broader automatic multi-step orchestration as implemented. + +## Coding Rules + +- Always build against typed contracts and migration-backed schemas first. +- Never mutate tool schemas mid-session; enforce access through policy and proxy layers. +- Keep changes small, module-scoped, and test-backed. +- Stop long-running tasks with a clear progress summary when budgets or circuit breakers trip. +- Sprint-scoped docs must clearly separate what exists now from what is only planned later. + +## Data / Schema Rules + +- Enforce row-level security on every user-owned table from the start. +- Default memory admission to `NOOP`; promote only evidence-backed changes. +- Always keep memory revision history for non-`NOOP` changes. +- Task-step lineage references must stay inside the current user scope and must validate against the intended parent step and its recorded outcome. +- Apply domain and sensitivity filters before semantic retrieval. + +## Deployment / Ops Rules + +- Keep v1 operations simple: one modular monolith, one primary database, one cache, one object store. +- Never store secrets in source control, committed config, or logs. +- Any repo-advertised bootstrap script that starts dependencies and then runs dependent commands must wait for service readiness before proceeding. +- When external side effects are introduced, route them through approval-aware tool execution paths. +- Backups and object versioning are required before production use. + +## Testing Rules + +- Schema changes are not complete without forward and rollback coverage. +- Every module needs unit tests and at least one integration boundary test. +- Approval boundaries, RLS isolation, and audit logging require adversarial tests. +- Lineage changes require adversarial tests for cross-task, cross-user, and parent-step mismatch cases. +- Memory quality and retrieval quality need labeled evaluations before release claims. 
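Taken together, the RLS and append-only rules above imply one concrete runtime pattern: a request-scoped transaction first pins `app.current_user_id`, and only then reads or appends user-owned rows, leaning on the database policies and triggers rather than application-side filtering. The sketch below illustrates that pattern only; the DSN, password, and `append_event` helper are assumptions for the local dev setup, not code that ships in this patch.

```python
# Illustrative sketch, not repo code: request-scoped RLS plus an append-only write.
import json
import uuid

import psycopg

# Assumed local dev DSN for the restricted runtime role; real credentials live in .env.
RUNTIME_DSN = "postgresql://alicebot_app:alicebot_app@127.0.0.1:5432/alicebot"


def append_event(user_id: uuid.UUID, thread_id: uuid.UUID, kind: str, payload: dict) -> None:
    """Append one event for the acting user; RLS hides every other user's rows."""
    with psycopg.connect(RUNTIME_DSN) as conn:
        with conn.transaction():
            # Pin the acting user for this transaction; the policies created in
            # 20260310_0001 compare user_id against app.current_user_id().
            conn.execute(
                "SELECT set_config('app.current_user_id', %s, true)",
                (str(user_id),),
            )
            # Next per-thread sequence number. UNIQUE (thread_id, sequence_no)
            # makes the loser of a concurrent append fail, so callers retry.
            next_seq = conn.execute(
                "SELECT COALESCE(MAX(sequence_no), 0) + 1 FROM events WHERE thread_id = %s",
                (thread_id,),
            ).fetchone()[0]
            conn.execute(
                "INSERT INTO events (user_id, thread_id, sequence_no, kind, payload)"
                " VALUES (%s, %s, %s, %s, %s::jsonb)",
                (user_id, thread_id, next_seq, kind, json.dumps(payload)),
            )
            # A later UPDATE or DELETE on this row would raise
            # 'events are append-only' via the events_append_only trigger.
```

The serialization guarantee comes from the unique constraint rather than the MAX+1 read, which is why the environment notes describe concurrent appends as serialized per thread.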
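The budget rule above ("stop long-running tasks ... when budgets or circuit breakers trip") likewise has a concrete shape in the shipped schema: an active `execution_budgets` row caps completed `tool_executions` over its lifetime or over an optional rolling window. The sketch below shows roughly what that check implies; the function name, the `tool_key`-only matching, and the caller's retry semantics are assumptions, and the authoritative enforcement lives in `apps/api/src/alicebot_api/execution_budgets.py`.

```python
# Illustrative sketch, not the repo's enforcement code: would a new attempt for
# this tool_key exceed the user's single active tool_key-scoped budget?
import uuid

import psycopg


def would_exceed_tool_budget(conn: psycopg.Connection, user_id: uuid.UUID, tool_key: str) -> bool:
    # conn is assumed to already have app.current_user_id pinned (see the previous sketch).
    budget = conn.execute(
        """
        SELECT max_completed_executions, rolling_window_seconds
        FROM execution_budgets
        WHERE user_id = %s AND status = 'active'
          AND tool_key = %s AND domain_hint IS NULL
        """,
        (user_id, tool_key),
    ).fetchone()
    if budget is None:
        return False  # no active budget for this scope
    max_completed, window_seconds = budget
    completed = conn.execute(
        """
        SELECT count(*)
        FROM tool_executions te
        JOIN tools t ON t.id = te.tool_id AND t.user_id = te.user_id
        WHERE te.user_id = %s
          AND te.status = 'completed'
          AND t.tool_key = %s
          AND (%s::integer IS NULL
               OR te.executed_at >= now() - make_interval(secs => %s))
        """,
        (user_id, tool_key, window_seconds, window_seconds),
    ).fetchone()[0]
    # Lifetime budgets count all completed executions; rolling-window budgets
    # only count those inside the trailing window.
    return completed >= max_completed
```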
diff --git a/apps/api/.gitkeep b/apps/api/.gitkeep new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/apps/api/.gitkeep @@ -0,0 +1 @@ + diff --git a/apps/api/alembic.ini b/apps/api/alembic.ini new file mode 100644 index 0000000..2ca852a --- /dev/null +++ b/apps/api/alembic.ini @@ -0,0 +1,37 @@ +[alembic] +script_location = apps/api/alembic +prepend_sys_path = apps/api/src +path_separator = os +sqlalchemy.url = postgresql://alicebot_admin:alicebot_admin@localhost:5432/alicebot + +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = console +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s diff --git a/apps/api/alembic/env.py b/apps/api/alembic/env.py new file mode 100644 index 0000000..b8880aa --- /dev/null +++ b/apps/api/alembic/env.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +from logging.config import fileConfig +import os + +from alembic import context +from sqlalchemy import engine_from_config, pool + + +config = context.config + +target_metadata = None + + +def normalize_sqlalchemy_url(database_url: str) -> str: + if database_url.startswith("postgresql://"): + return database_url.replace("postgresql://", "postgresql+psycopg://", 1) + return database_url + + +def get_url() -> str: + database_url = ( + os.getenv("DATABASE_ADMIN_URL") + or os.getenv("DATABASE_URL") + or config.get_main_option("sqlalchemy.url") + ) + return normalize_sqlalchemy_url(database_url) + + +def configure_logging() -> None: + if config.config_file_name is not None: + fileConfig(config.config_file_name) + + +def run_migrations_offline() -> None: + context.configure( + url=get_url(), + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online() -> None: + configuration = config.get_section(config.config_ini_section, {}) + configuration["sqlalchemy.url"] = get_url() + connectable = engine_from_config( + configuration, + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + with connectable.connect() as connection: + context.configure(connection=connection, target_metadata=target_metadata) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations() -> None: + configure_logging() + if context.is_offline_mode(): + run_migrations_offline() + else: + run_migrations_online() + + +run_migrations() diff --git a/apps/api/alembic/versions/20260310_0001_foundation_continuity.py b/apps/api/alembic/versions/20260310_0001_foundation_continuity.py new file mode 100644 index 0000000..eeb1d3b --- /dev/null +++ b/apps/api/alembic/versions/20260310_0001_foundation_continuity.py @@ -0,0 +1,167 @@ +"""Create continuity foundation tables with RLS and append-only events.""" + +from __future__ import annotations + +from alembic import op + + +revision = "20260310_0001" +down_revision = None +branch_labels = None +depends_on = None + +_RLS_TABLES = ("users", "threads", "sessions", "events") + +_UPGRADE_BOOTSTRAP_STATEMENTS = ( + "CREATE EXTENSION IF NOT EXISTS pgcrypto", + "CREATE EXTENSION IF NOT EXISTS vector", + "CREATE SCHEMA IF NOT EXISTS app", + """ + CREATE OR REPLACE 
FUNCTION app.current_user_id() + RETURNS uuid + LANGUAGE sql + STABLE + AS $$ + SELECT NULLIF(current_setting('app.current_user_id', true), '')::uuid + $$; + """, + """ + CREATE OR REPLACE FUNCTION app.reject_event_mutation() + RETURNS trigger + LANGUAGE plpgsql + AS $$ + BEGIN + RAISE EXCEPTION 'events are append-only'; + END; + $$; + """, +) + +_UPGRADE_SCHEMA_STATEMENT = """ + CREATE TABLE users ( + id uuid PRIMARY KEY, + email text NOT NULL UNIQUE, + display_name text, + created_at timestamptz NOT NULL DEFAULT now() + ); + + CREATE TABLE threads ( + id uuid PRIMARY KEY DEFAULT gen_random_uuid(), + user_id uuid NOT NULL REFERENCES users(id) ON DELETE CASCADE, + title text NOT NULL, + created_at timestamptz NOT NULL DEFAULT now(), + updated_at timestamptz NOT NULL DEFAULT now(), + UNIQUE (id, user_id) + ); + + CREATE TABLE sessions ( + id uuid PRIMARY KEY DEFAULT gen_random_uuid(), + user_id uuid NOT NULL, + thread_id uuid NOT NULL, + status text NOT NULL DEFAULT 'active', + started_at timestamptz NOT NULL DEFAULT now(), + ended_at timestamptz, + created_at timestamptz NOT NULL DEFAULT now(), + UNIQUE (id, user_id), + FOREIGN KEY (thread_id, user_id) + REFERENCES threads(id, user_id) + ON DELETE CASCADE + ); + + CREATE TABLE events ( + id uuid PRIMARY KEY DEFAULT gen_random_uuid(), + user_id uuid NOT NULL, + thread_id uuid NOT NULL, + session_id uuid, + sequence_no bigint NOT NULL, + kind text NOT NULL, + payload jsonb NOT NULL, + created_at timestamptz NOT NULL DEFAULT now(), + UNIQUE (id, user_id), + UNIQUE (thread_id, sequence_no), + FOREIGN KEY (thread_id, user_id) + REFERENCES threads(id, user_id) + ON DELETE CASCADE, + FOREIGN KEY (session_id, user_id) + REFERENCES sessions(id, user_id) + ON DELETE CASCADE + ); + + CREATE INDEX sessions_thread_created_idx + ON sessions (thread_id, created_at); + CREATE INDEX threads_user_created_idx + ON threads (user_id, created_at); + """ + +_UPGRADE_TRIGGER_STATEMENT = """ + CREATE TRIGGER events_append_only + BEFORE UPDATE OR DELETE ON events + FOR EACH ROW + EXECUTE FUNCTION app.reject_event_mutation(); + """ + +_UPGRADE_GRANT_STATEMENTS = ( + "GRANT USAGE ON SCHEMA public TO alicebot_app", + "GRANT USAGE ON SCHEMA app TO alicebot_app", + "GRANT SELECT, INSERT ON users TO alicebot_app", + "GRANT SELECT, INSERT ON threads TO alicebot_app", + "GRANT SELECT, INSERT ON sessions TO alicebot_app", + "GRANT SELECT, INSERT ON events TO alicebot_app", +) + +_UPGRADE_POLICY_STATEMENT = """ + CREATE POLICY users_is_owner ON users + USING (id = app.current_user_id()) + WITH CHECK (id = app.current_user_id()); + + CREATE POLICY threads_is_owner ON threads + USING (user_id = app.current_user_id()) + WITH CHECK (user_id = app.current_user_id()); + + CREATE POLICY sessions_is_owner ON sessions + USING (user_id = app.current_user_id()) + WITH CHECK (user_id = app.current_user_id()); + + CREATE POLICY events_read_own ON events + FOR SELECT + USING (user_id = app.current_user_id()); + + CREATE POLICY events_insert_own ON events + FOR INSERT + WITH CHECK (user_id = app.current_user_id()); + """ + +_DOWNGRADE_STATEMENTS = ( + "DROP TRIGGER IF EXISTS events_append_only ON events", + "DROP TABLE IF EXISTS events", + "DROP TABLE IF EXISTS sessions", + "DROP TABLE IF EXISTS threads", + "DROP TABLE IF EXISTS users", + "DROP FUNCTION IF EXISTS app.reject_event_mutation()", + "DROP FUNCTION IF EXISTS app.current_user_id()", + "DROP SCHEMA IF EXISTS app", +) + + +def _execute_statements(statements: tuple[str, ...]) -> None: + for statement in statements: + 
op.execute(statement) + + +def _enable_row_level_security() -> None: + for table_name in _RLS_TABLES: + op.execute(f"ALTER TABLE {table_name} ENABLE ROW LEVEL SECURITY") + op.execute(f"ALTER TABLE {table_name} FORCE ROW LEVEL SECURITY") + + +def upgrade() -> None: + _execute_statements(_UPGRADE_BOOTSTRAP_STATEMENTS) + op.execute(_UPGRADE_SCHEMA_STATEMENT) + op.execute(_UPGRADE_TRIGGER_STATEMENT) + _execute_statements(_UPGRADE_GRANT_STATEMENTS) + _enable_row_level_security() + op.execute(_UPGRADE_POLICY_STATEMENT) + + +def downgrade() -> None: + _execute_statements(_DOWNGRADE_STATEMENTS) diff --git a/apps/api/alembic/versions/20260311_0002_tighten_runtime_privileges.py b/apps/api/alembic/versions/20260311_0002_tighten_runtime_privileges.py new file mode 100644 index 0000000..5935399 --- /dev/null +++ b/apps/api/alembic/versions/20260311_0002_tighten_runtime_privileges.py @@ -0,0 +1,39 @@ +"""Tighten the runtime role to insert/select-only continuity access.""" + +from __future__ import annotations + +from alembic import op + + +revision = "20260311_0002" +down_revision = "20260310_0001" +branch_labels = None +depends_on = None + +_UPGRADE_STATEMENTS = ( + "REVOKE UPDATE ON users FROM alicebot_app", + "REVOKE UPDATE ON threads FROM alicebot_app", + "REVOKE UPDATE ON sessions FROM alicebot_app", +) + +# Revision 20260310_0001 already leaves the runtime role with no UPDATE grants +# on these tables. Downgrading back to that revision should therefore preserve +# the same privilege floor explicitly rather than re-introducing broader access. +_DOWNGRADE_STATEMENTS = ( + "REVOKE UPDATE ON users FROM alicebot_app", + "REVOKE UPDATE ON threads FROM alicebot_app", + "REVOKE UPDATE ON sessions FROM alicebot_app", +) + + +def _execute_statements(statements: tuple[str, ...]) -> None: + for statement in statements: + op.execute(statement) + + +def upgrade() -> None: + _execute_statements(_UPGRADE_STATEMENTS) + + +def downgrade() -> None: + _execute_statements(_DOWNGRADE_STATEMENTS) diff --git a/apps/api/alembic/versions/20260311_0003_trace_backbone.py b/apps/api/alembic/versions/20260311_0003_trace_backbone.py new file mode 100644 index 0000000..6028ff4 --- /dev/null +++ b/apps/api/alembic/versions/20260311_0003_trace_backbone.py @@ -0,0 +1,117 @@ +"""Add persisted traces and trace events for context compilation.""" + +from __future__ import annotations + +from alembic import op + + +revision = "20260311_0003" +down_revision = "20260311_0002" +branch_labels = None +depends_on = None + +_RLS_TABLES = ("traces", "trace_events") + +_UPGRADE_BOOTSTRAP_STATEMENTS = ( + """ + CREATE OR REPLACE FUNCTION app.reject_trace_event_mutation() + RETURNS trigger + LANGUAGE plpgsql + AS $$ + BEGIN + RAISE EXCEPTION 'trace events are append-only'; + END; + $$; + """, +) + +_UPGRADE_SCHEMA_STATEMENT = """ + CREATE TABLE traces ( + id uuid PRIMARY KEY DEFAULT gen_random_uuid(), + user_id uuid NOT NULL REFERENCES users(id) ON DELETE CASCADE, + thread_id uuid NOT NULL, + kind text NOT NULL, + compiler_version text NOT NULL, + status text NOT NULL, + limits jsonb NOT NULL, + created_at timestamptz NOT NULL DEFAULT now(), + UNIQUE (id, user_id), + FOREIGN KEY (thread_id, user_id) + REFERENCES threads(id, user_id) + ON DELETE CASCADE + ); + + CREATE TABLE trace_events ( + id uuid PRIMARY KEY DEFAULT gen_random_uuid(), + user_id uuid NOT NULL, + trace_id uuid NOT NULL, + sequence_no bigint NOT NULL, + kind text NOT NULL, + payload jsonb NOT NULL, + created_at timestamptz NOT NULL DEFAULT now(), + UNIQUE (trace_id, sequence_no), + 
FOREIGN KEY (trace_id, user_id) + REFERENCES traces(id, user_id) + ON DELETE CASCADE + ); + + CREATE INDEX traces_thread_created_idx + ON traces (thread_id, created_at); + """ + +_UPGRADE_TRIGGER_STATEMENT = """ + CREATE TRIGGER trace_events_append_only + BEFORE UPDATE OR DELETE ON trace_events + FOR EACH ROW + EXECUTE FUNCTION app.reject_trace_event_mutation(); + """ + +_UPGRADE_GRANT_STATEMENTS = ( + "GRANT SELECT, INSERT ON traces TO alicebot_app", + "GRANT SELECT, INSERT ON trace_events TO alicebot_app", +) + +_UPGRADE_POLICY_STATEMENT = """ + CREATE POLICY traces_is_owner ON traces + USING (user_id = app.current_user_id()) + WITH CHECK (user_id = app.current_user_id()); + + CREATE POLICY trace_events_read_own ON trace_events + FOR SELECT + USING (user_id = app.current_user_id()); + + CREATE POLICY trace_events_insert_own ON trace_events + FOR INSERT + WITH CHECK (user_id = app.current_user_id()); + """ + +_DOWNGRADE_STATEMENTS = ( + "DROP TRIGGER IF EXISTS trace_events_append_only ON trace_events", + "DROP TABLE IF EXISTS trace_events", + "DROP TABLE IF EXISTS traces", + "DROP FUNCTION IF EXISTS app.reject_trace_event_mutation()", +) + + +def _execute_statements(statements: tuple[str, ...]) -> None: + for statement in statements: + op.execute(statement) + + +def _enable_row_level_security() -> None: + for table_name in _RLS_TABLES: + op.execute(f"ALTER TABLE {table_name} ENABLE ROW LEVEL SECURITY") + op.execute(f"ALTER TABLE {table_name} FORCE ROW LEVEL SECURITY") + + +def upgrade() -> None: + _execute_statements(_UPGRADE_BOOTSTRAP_STATEMENTS) + op.execute(_UPGRADE_SCHEMA_STATEMENT) + op.execute(_UPGRADE_TRIGGER_STATEMENT) + _execute_statements(_UPGRADE_GRANT_STATEMENTS) + _enable_row_level_security() + op.execute(_UPGRADE_POLICY_STATEMENT) + + +def downgrade() -> None: + _execute_statements(_DOWNGRADE_STATEMENTS) diff --git a/apps/api/alembic/versions/20260311_0004_memory_admission.py b/apps/api/alembic/versions/20260311_0004_memory_admission.py new file mode 100644 index 0000000..c782d3b --- /dev/null +++ b/apps/api/alembic/versions/20260311_0004_memory_admission.py @@ -0,0 +1,123 @@ +"""Add governed memory tables and append-only memory revisions.""" + +from __future__ import annotations + +from alembic import op + + +revision = "20260311_0004" +down_revision = "20260311_0003" +branch_labels = None +depends_on = None + +_RLS_TABLES = ("memories", "memory_revisions") + +_UPGRADE_BOOTSTRAP_STATEMENTS = ( + """ + CREATE OR REPLACE FUNCTION app.reject_memory_revision_mutation() + RETURNS trigger + LANGUAGE plpgsql + AS $$ + BEGIN + RAISE EXCEPTION 'memory revisions are append-only'; + END; + $$; + """, +) + +_UPGRADE_SCHEMA_STATEMENT = """ + CREATE TABLE memories ( + id uuid PRIMARY KEY DEFAULT gen_random_uuid(), + user_id uuid NOT NULL REFERENCES users(id) ON DELETE CASCADE, + memory_key text NOT NULL, + value jsonb NOT NULL, + status text NOT NULL, + source_event_ids jsonb NOT NULL, + created_at timestamptz NOT NULL DEFAULT now(), + updated_at timestamptz NOT NULL DEFAULT now(), + deleted_at timestamptz, + UNIQUE (id, user_id), + UNIQUE (user_id, memory_key) + ); + + CREATE TABLE memory_revisions ( + id uuid PRIMARY KEY DEFAULT gen_random_uuid(), + user_id uuid NOT NULL, + memory_id uuid NOT NULL, + sequence_no bigint NOT NULL, + action text NOT NULL, + memory_key text NOT NULL, + previous_value jsonb, + new_value jsonb, + source_event_ids jsonb NOT NULL, + candidate jsonb NOT NULL, + created_at timestamptz NOT NULL DEFAULT now(), + UNIQUE (id, user_id), + UNIQUE (memory_id, 
sequence_no), + FOREIGN KEY (memory_id, user_id) + REFERENCES memories(id, user_id) + ON DELETE CASCADE + ); + + CREATE INDEX memories_user_status_updated_idx + ON memories (user_id, status, updated_at); + CREATE INDEX memory_revisions_memory_created_idx + ON memory_revisions (memory_id, created_at); + """ + +_UPGRADE_TRIGGER_STATEMENT = """ + CREATE TRIGGER memory_revisions_append_only + BEFORE UPDATE OR DELETE ON memory_revisions + FOR EACH ROW + EXECUTE FUNCTION app.reject_memory_revision_mutation(); + """ + +_UPGRADE_GRANT_STATEMENTS = ( + "GRANT SELECT, INSERT, UPDATE ON memories TO alicebot_app", + "GRANT SELECT, INSERT ON memory_revisions TO alicebot_app", +) + +_UPGRADE_POLICY_STATEMENT = """ + CREATE POLICY memories_is_owner ON memories + USING (user_id = app.current_user_id()) + WITH CHECK (user_id = app.current_user_id()); + + CREATE POLICY memory_revisions_read_own ON memory_revisions + FOR SELECT + USING (user_id = app.current_user_id()); + + CREATE POLICY memory_revisions_insert_own ON memory_revisions + FOR INSERT + WITH CHECK (user_id = app.current_user_id()); + """ + +_DOWNGRADE_STATEMENTS = ( + "DROP TRIGGER IF EXISTS memory_revisions_append_only ON memory_revisions", + "DROP TABLE IF EXISTS memory_revisions", + "DROP TABLE IF EXISTS memories", + "DROP FUNCTION IF EXISTS app.reject_memory_revision_mutation()", +) + + +def _execute_statements(statements: tuple[str, ...]) -> None: + for statement in statements: + op.execute(statement) + + +def _enable_row_level_security() -> None: + for table_name in _RLS_TABLES: + op.execute(f"ALTER TABLE {table_name} ENABLE ROW LEVEL SECURITY") + op.execute(f"ALTER TABLE {table_name} FORCE ROW LEVEL SECURITY") + + +def upgrade() -> None: + _execute_statements(_UPGRADE_BOOTSTRAP_STATEMENTS) + op.execute(_UPGRADE_SCHEMA_STATEMENT) + op.execute(_UPGRADE_TRIGGER_STATEMENT) + _execute_statements(_UPGRADE_GRANT_STATEMENTS) + _enable_row_level_security() + op.execute(_UPGRADE_POLICY_STATEMENT) + + +def downgrade() -> None: + _execute_statements(_DOWNGRADE_STATEMENTS) diff --git a/apps/api/alembic/versions/20260312_0005_memory_review_labels.py b/apps/api/alembic/versions/20260312_0005_memory_review_labels.py new file mode 100644 index 0000000..2b7ede5 --- /dev/null +++ b/apps/api/alembic/versions/20260312_0005_memory_review_labels.py @@ -0,0 +1,99 @@ +"""Add append-only memory review labels for human evaluation.""" + +from __future__ import annotations + +from alembic import op + + +revision = "20260312_0005" +down_revision = "20260311_0004" +branch_labels = None +depends_on = None + +_RLS_TABLES = ("memory_review_labels",) + +_UPGRADE_BOOTSTRAP_STATEMENTS = ( + """ + CREATE OR REPLACE FUNCTION app.reject_memory_review_label_mutation() + RETURNS trigger + LANGUAGE plpgsql + AS $$ + BEGIN + RAISE EXCEPTION 'memory review labels are append-only'; + END; + $$; + """, +) + +_UPGRADE_SCHEMA_STATEMENT = """ + CREATE TABLE memory_review_labels ( + id uuid PRIMARY KEY DEFAULT gen_random_uuid(), + user_id uuid NOT NULL, + memory_id uuid NOT NULL, + label text NOT NULL, + note text, + created_at timestamptz NOT NULL DEFAULT now(), + UNIQUE (id, user_id), + FOREIGN KEY (memory_id, user_id) + REFERENCES memories(id, user_id) + ON DELETE CASCADE, + CONSTRAINT memory_review_labels_label_check + CHECK (label IN ('correct', 'incorrect', 'outdated', 'insufficient_evidence')), + CONSTRAINT memory_review_labels_note_length_check + CHECK (note IS NULL OR char_length(note) <= 280) + ); + + CREATE INDEX memory_review_labels_memory_created_idx + ON memory_review_labels 
(memory_id, created_at, id); + """ + +_UPGRADE_TRIGGER_STATEMENT = """ + CREATE TRIGGER memory_review_labels_append_only + BEFORE UPDATE OR DELETE ON memory_review_labels + FOR EACH ROW + EXECUTE FUNCTION app.reject_memory_review_label_mutation(); + """ + +_UPGRADE_GRANT_STATEMENTS = ( + "GRANT SELECT, INSERT ON memory_review_labels TO alicebot_app", +) + +_UPGRADE_POLICY_STATEMENT = """ + CREATE POLICY memory_review_labels_read_own ON memory_review_labels + FOR SELECT + USING (user_id = app.current_user_id()); + + CREATE POLICY memory_review_labels_insert_own ON memory_review_labels + FOR INSERT + WITH CHECK (user_id = app.current_user_id()); + """ + +_DOWNGRADE_STATEMENTS = ( + "DROP TRIGGER IF EXISTS memory_review_labels_append_only ON memory_review_labels", + "DROP TABLE IF EXISTS memory_review_labels", + "DROP FUNCTION IF EXISTS app.reject_memory_review_label_mutation()", +) + + +def _execute_statements(statements: tuple[str, ...]) -> None: + for statement in statements: + op.execute(statement) + + +def _enable_row_level_security() -> None: + for table_name in _RLS_TABLES: + op.execute(f"ALTER TABLE {table_name} ENABLE ROW LEVEL SECURITY") + op.execute(f"ALTER TABLE {table_name} FORCE ROW LEVEL SECURITY") + + +def upgrade() -> None: + _execute_statements(_UPGRADE_BOOTSTRAP_STATEMENTS) + op.execute(_UPGRADE_SCHEMA_STATEMENT) + op.execute(_UPGRADE_TRIGGER_STATEMENT) + _execute_statements(_UPGRADE_GRANT_STATEMENTS) + _enable_row_level_security() + op.execute(_UPGRADE_POLICY_STATEMENT) + + +def downgrade() -> None: + _execute_statements(_DOWNGRADE_STATEMENTS) diff --git a/apps/api/alembic/versions/20260312_0006_entities_backbone.py b/apps/api/alembic/versions/20260312_0006_entities_backbone.py new file mode 100644 index 0000000..a1d3bcb --- /dev/null +++ b/apps/api/alembic/versions/20260312_0006_entities_backbone.py @@ -0,0 +1,72 @@ +"""Add explicit user-scoped entities backed by durable source memories.""" + +from __future__ import annotations + +from alembic import op + + +revision = "20260312_0006" +down_revision = "20260312_0005" +branch_labels = None +depends_on = None + +_RLS_TABLES = ("entities",) + +_UPGRADE_SCHEMA_STATEMENT = """ + CREATE TABLE entities ( + id uuid PRIMARY KEY DEFAULT gen_random_uuid(), + user_id uuid NOT NULL REFERENCES users(id) ON DELETE CASCADE, + entity_type text NOT NULL, + name text NOT NULL, + source_memory_ids jsonb NOT NULL, + created_at timestamptz NOT NULL DEFAULT now(), + UNIQUE (id, user_id), + CONSTRAINT entities_type_check + CHECK (entity_type IN ('person', 'merchant', 'product', 'project', 'routine')), + CONSTRAINT entities_name_length_check + CHECK (char_length(name) BETWEEN 1 AND 200), + CONSTRAINT entities_source_memory_ids_array_check + CHECK (jsonb_typeof(source_memory_ids) = 'array'), + CONSTRAINT entities_source_memory_ids_nonempty_check + CHECK (jsonb_array_length(source_memory_ids) > 0) + ); + + CREATE INDEX entities_user_created_idx + ON entities (user_id, created_at, id); + """ + +_UPGRADE_GRANT_STATEMENTS = ( + "GRANT SELECT, INSERT ON entities TO alicebot_app", +) + +_UPGRADE_POLICY_STATEMENT = """ + CREATE POLICY entities_is_owner ON entities + USING (user_id = app.current_user_id()) + WITH CHECK (user_id = app.current_user_id()); + """ + +_DOWNGRADE_STATEMENTS = ( + "DROP TABLE IF EXISTS entities", +) + + +def _execute_statements(statements: tuple[str, ...]) -> None: + for statement in statements: + op.execute(statement) + + +def _enable_row_level_security() -> None: + for table_name in _RLS_TABLES: + op.execute(f"ALTER TABLE 
{table_name} ENABLE ROW LEVEL SECURITY") + op.execute(f"ALTER TABLE {table_name} FORCE ROW LEVEL SECURITY") + + +def upgrade() -> None: + op.execute(_UPGRADE_SCHEMA_STATEMENT) + _execute_statements(_UPGRADE_GRANT_STATEMENTS) + _enable_row_level_security() + op.execute(_UPGRADE_POLICY_STATEMENT) + + +def downgrade() -> None: + _execute_statements(_DOWNGRADE_STATEMENTS) diff --git a/apps/api/alembic/versions/20260312_0007_entity_edges.py b/apps/api/alembic/versions/20260312_0007_entity_edges.py new file mode 100644 index 0000000..fa08bda --- /dev/null +++ b/apps/api/alembic/versions/20260312_0007_entity_edges.py @@ -0,0 +1,83 @@ +"""Add explicit user-scoped entity edges with simple temporal metadata.""" + +from __future__ import annotations + +from alembic import op + + +revision = "20260312_0007" +down_revision = "20260312_0006" +branch_labels = None +depends_on = None + +_RLS_TABLES = ("entity_edges",) + +_UPGRADE_SCHEMA_STATEMENT = """ + CREATE TABLE entity_edges ( + id uuid PRIMARY KEY DEFAULT gen_random_uuid(), + user_id uuid NOT NULL REFERENCES users(id) ON DELETE CASCADE, + from_entity_id uuid NOT NULL, + to_entity_id uuid NOT NULL, + relationship_type text NOT NULL, + valid_from timestamptz NULL, + valid_to timestamptz NULL, + source_memory_ids jsonb NOT NULL, + created_at timestamptz NOT NULL DEFAULT now(), + UNIQUE (id, user_id), + CONSTRAINT entity_edges_from_entity_fkey + FOREIGN KEY (from_entity_id, user_id) REFERENCES entities(id, user_id) ON DELETE CASCADE, + CONSTRAINT entity_edges_to_entity_fkey + FOREIGN KEY (to_entity_id, user_id) REFERENCES entities(id, user_id) ON DELETE CASCADE, + CONSTRAINT entity_edges_relationship_type_length_check + CHECK (char_length(relationship_type) BETWEEN 1 AND 100), + CONSTRAINT entity_edges_source_memory_ids_array_check + CHECK (jsonb_typeof(source_memory_ids) = 'array'), + CONSTRAINT entity_edges_source_memory_ids_nonempty_check + CHECK (jsonb_array_length(source_memory_ids) > 0), + CONSTRAINT entity_edges_valid_range_check + CHECK (valid_from IS NULL OR valid_to IS NULL OR valid_to >= valid_from) + ); + + CREATE INDEX entity_edges_user_created_idx + ON entity_edges (user_id, created_at, id); + CREATE INDEX entity_edges_user_from_created_idx + ON entity_edges (user_id, from_entity_id, created_at, id); + CREATE INDEX entity_edges_user_to_created_idx + ON entity_edges (user_id, to_entity_id, created_at, id); + """ + +_UPGRADE_GRANT_STATEMENTS = ( + "GRANT SELECT, INSERT ON entity_edges TO alicebot_app", +) + +_UPGRADE_POLICY_STATEMENT = """ + CREATE POLICY entity_edges_is_owner ON entity_edges + USING (user_id = app.current_user_id()) + WITH CHECK (user_id = app.current_user_id()); + """ + +_DOWNGRADE_STATEMENTS = ( + "DROP TABLE IF EXISTS entity_edges", +) + + +def _execute_statements(statements: tuple[str, ...]) -> None: + for statement in statements: + op.execute(statement) + + +def _enable_row_level_security() -> None: + for table_name in _RLS_TABLES: + op.execute(f"ALTER TABLE {table_name} ENABLE ROW LEVEL SECURITY") + op.execute(f"ALTER TABLE {table_name} FORCE ROW LEVEL SECURITY") + + +def upgrade() -> None: + op.execute(_UPGRADE_SCHEMA_STATEMENT) + _execute_statements(_UPGRADE_GRANT_STATEMENTS) + _enable_row_level_security() + op.execute(_UPGRADE_POLICY_STATEMENT) + + +def downgrade() -> None: + _execute_statements(_DOWNGRADE_STATEMENTS) diff --git a/apps/api/alembic/versions/20260312_0008_embedding_substrate.py b/apps/api/alembic/versions/20260312_0008_embedding_substrate.py new file mode 100644 index 0000000..d83551e --- /dev/null 
+++ b/apps/api/alembic/versions/20260312_0008_embedding_substrate.py @@ -0,0 +1,115 @@ +"""Add versioned embedding configs and user-scoped memory embeddings.""" + +from __future__ import annotations + +from alembic import op + + +revision = "20260312_0008" +down_revision = "20260312_0007" +branch_labels = None +depends_on = None + +_RLS_TABLES = ("embedding_configs", "memory_embeddings") + +_UPGRADE_SCHEMA_STATEMENT = """ + CREATE TABLE embedding_configs ( + id uuid PRIMARY KEY DEFAULT gen_random_uuid(), + user_id uuid NOT NULL REFERENCES users(id) ON DELETE CASCADE, + provider text NOT NULL, + model text NOT NULL, + version text NOT NULL, + dimensions integer NOT NULL, + status text NOT NULL, + metadata jsonb NOT NULL DEFAULT '{}'::jsonb, + created_at timestamptz NOT NULL DEFAULT now(), + UNIQUE (id, user_id), + UNIQUE (user_id, provider, model, version), + CONSTRAINT embedding_configs_provider_length_check + CHECK (char_length(provider) BETWEEN 1 AND 100), + CONSTRAINT embedding_configs_model_length_check + CHECK (char_length(model) BETWEEN 1 AND 200), + CONSTRAINT embedding_configs_version_length_check + CHECK (char_length(version) BETWEEN 1 AND 100), + CONSTRAINT embedding_configs_dimensions_check + CHECK (dimensions > 0), + CONSTRAINT embedding_configs_status_check + CHECK (status IN ('active', 'deprecated', 'disabled')), + CONSTRAINT embedding_configs_metadata_object_check + CHECK (jsonb_typeof(metadata) = 'object') + ); + + CREATE INDEX embedding_configs_user_created_idx + ON embedding_configs (user_id, created_at, id); + + CREATE TABLE memory_embeddings ( + id uuid PRIMARY KEY DEFAULT gen_random_uuid(), + user_id uuid NOT NULL REFERENCES users(id) ON DELETE CASCADE, + memory_id uuid NOT NULL, + embedding_config_id uuid NOT NULL, + dimensions integer NOT NULL, + vector jsonb NOT NULL, + created_at timestamptz NOT NULL DEFAULT now(), + updated_at timestamptz NOT NULL DEFAULT now(), + UNIQUE (id, user_id), + UNIQUE (user_id, memory_id, embedding_config_id), + CONSTRAINT memory_embeddings_memory_fkey + FOREIGN KEY (memory_id, user_id) REFERENCES memories(id, user_id) ON DELETE CASCADE, + CONSTRAINT memory_embeddings_embedding_config_fkey + FOREIGN KEY (embedding_config_id, user_id) + REFERENCES embedding_configs(id, user_id) ON DELETE CASCADE, + CONSTRAINT memory_embeddings_dimensions_check + CHECK (dimensions > 0), + CONSTRAINT memory_embeddings_vector_array_check + CHECK (jsonb_typeof(vector) = 'array'), + CONSTRAINT memory_embeddings_vector_nonempty_check + CHECK (jsonb_array_length(vector) > 0), + CONSTRAINT memory_embeddings_vector_dimensions_match_check + CHECK (jsonb_array_length(vector) = dimensions) + ); + + CREATE INDEX memory_embeddings_user_memory_created_idx + ON memory_embeddings (user_id, memory_id, created_at, id); + """ + +_UPGRADE_GRANT_STATEMENTS = ( + "GRANT SELECT, INSERT ON embedding_configs TO alicebot_app", + "GRANT SELECT, INSERT, UPDATE ON memory_embeddings TO alicebot_app", +) + +_UPGRADE_POLICY_STATEMENT = """ + CREATE POLICY embedding_configs_is_owner ON embedding_configs + USING (user_id = app.current_user_id()) + WITH CHECK (user_id = app.current_user_id()); + + CREATE POLICY memory_embeddings_is_owner ON memory_embeddings + USING (user_id = app.current_user_id()) + WITH CHECK (user_id = app.current_user_id()); + """ + +_DOWNGRADE_STATEMENTS = ( + "DROP TABLE IF EXISTS memory_embeddings", + "DROP TABLE IF EXISTS embedding_configs", +) + + +def _execute_statements(statements: tuple[str, ...]) -> None: + for statement in statements: + op.execute(statement) 
+ + +def _enable_row_level_security() -> None: + for table_name in _RLS_TABLES: + op.execute(f"ALTER TABLE {table_name} ENABLE ROW LEVEL SECURITY") + op.execute(f"ALTER TABLE {table_name} FORCE ROW LEVEL SECURITY") + + +def upgrade() -> None: + op.execute(_UPGRADE_SCHEMA_STATEMENT) + _execute_statements(_UPGRADE_GRANT_STATEMENTS) + _enable_row_level_security() + op.execute(_UPGRADE_POLICY_STATEMENT) + + +def downgrade() -> None: + _execute_statements(_DOWNGRADE_STATEMENTS) diff --git a/apps/api/alembic/versions/20260312_0009_policy_and_consent_core.py b/apps/api/alembic/versions/20260312_0009_policy_and_consent_core.py new file mode 100644 index 0000000..25fcf20 --- /dev/null +++ b/apps/api/alembic/versions/20260312_0009_policy_and_consent_core.py @@ -0,0 +1,111 @@ +"""Add user-scoped consents and deterministic policy storage.""" + +from __future__ import annotations + +from alembic import op + + +revision = "20260312_0009" +down_revision = "20260312_0008" +branch_labels = None +depends_on = None + +_RLS_TABLES = ("consents", "policies") + +_UPGRADE_SCHEMA_STATEMENT = """ + CREATE TABLE consents ( + id uuid PRIMARY KEY DEFAULT gen_random_uuid(), + user_id uuid NOT NULL REFERENCES users(id) ON DELETE CASCADE, + consent_key text NOT NULL, + status text NOT NULL, + metadata jsonb NOT NULL DEFAULT '{}'::jsonb, + created_at timestamptz NOT NULL DEFAULT now(), + updated_at timestamptz NOT NULL DEFAULT now(), + UNIQUE (id, user_id), + UNIQUE (user_id, consent_key), + CONSTRAINT consents_key_length_check + CHECK (char_length(consent_key) BETWEEN 1 AND 200), + CONSTRAINT consents_status_check + CHECK (status IN ('granted', 'revoked')), + CONSTRAINT consents_metadata_object_check + CHECK (jsonb_typeof(metadata) = 'object') + ); + + CREATE INDEX consents_user_key_created_idx + ON consents (user_id, consent_key, created_at, id); + + CREATE TABLE policies ( + id uuid PRIMARY KEY DEFAULT gen_random_uuid(), + user_id uuid NOT NULL REFERENCES users(id) ON DELETE CASCADE, + name text NOT NULL, + action text NOT NULL, + scope text NOT NULL, + effect text NOT NULL, + priority integer NOT NULL, + active boolean NOT NULL DEFAULT TRUE, + conditions jsonb NOT NULL DEFAULT '{}'::jsonb, + required_consents jsonb NOT NULL DEFAULT '[]'::jsonb, + created_at timestamptz NOT NULL DEFAULT now(), + updated_at timestamptz NOT NULL DEFAULT now(), + UNIQUE (id, user_id), + CONSTRAINT policies_name_length_check + CHECK (char_length(name) BETWEEN 1 AND 200), + CONSTRAINT policies_action_length_check + CHECK (char_length(action) BETWEEN 1 AND 100), + CONSTRAINT policies_scope_length_check + CHECK (char_length(scope) BETWEEN 1 AND 200), + CONSTRAINT policies_effect_check + CHECK (effect IN ('allow', 'deny', 'require_approval')), + CONSTRAINT policies_priority_check + CHECK (priority >= 0), + CONSTRAINT policies_conditions_object_check + CHECK (jsonb_typeof(conditions) = 'object'), + CONSTRAINT policies_required_consents_array_check + CHECK (jsonb_typeof(required_consents) = 'array') + ); + + CREATE INDEX policies_user_active_priority_created_idx + ON policies (user_id, active, priority, created_at, id); + """ + +_UPGRADE_GRANT_STATEMENTS = ( + "GRANT SELECT, INSERT, UPDATE ON consents TO alicebot_app", + "GRANT SELECT, INSERT ON policies TO alicebot_app", +) + +_UPGRADE_POLICY_STATEMENT = """ + CREATE POLICY consents_is_owner ON consents + USING (user_id = app.current_user_id()) + WITH CHECK (user_id = app.current_user_id()); + + CREATE POLICY policies_is_owner ON policies + USING (user_id = app.current_user_id()) + WITH CHECK 
(user_id = app.current_user_id()); + """ + +_DOWNGRADE_STATEMENTS = ( + "DROP TABLE IF EXISTS policies", + "DROP TABLE IF EXISTS consents", +) + + +def _execute_statements(statements: tuple[str, ...]) -> None: + for statement in statements: + op.execute(statement) + + +def _enable_row_level_security() -> None: + for table_name in _RLS_TABLES: + op.execute(f"ALTER TABLE {table_name} ENABLE ROW LEVEL SECURITY") + op.execute(f"ALTER TABLE {table_name} FORCE ROW LEVEL SECURITY") + + +def upgrade() -> None: + op.execute(_UPGRADE_SCHEMA_STATEMENT) + _execute_statements(_UPGRADE_GRANT_STATEMENTS) + _enable_row_level_security() + op.execute(_UPGRADE_POLICY_STATEMENT) + + +def downgrade() -> None: + _execute_statements(_DOWNGRADE_STATEMENTS) diff --git a/apps/api/alembic/versions/20260312_0010_tools_registry_and_allowlist.py b/apps/api/alembic/versions/20260312_0010_tools_registry_and_allowlist.py new file mode 100644 index 0000000..6d58470 --- /dev/null +++ b/apps/api/alembic/versions/20260312_0010_tools_registry_and_allowlist.py @@ -0,0 +1,96 @@ +"""Add stable tool registry storage for deterministic allowlist evaluation.""" + +from __future__ import annotations + +from alembic import op + + +revision = "20260312_0010" +down_revision = "20260312_0009" +branch_labels = None +depends_on = None + +_RLS_TABLES = ("tools",) + +_UPGRADE_SCHEMA_STATEMENT = """ + CREATE TABLE tools ( + id uuid PRIMARY KEY DEFAULT gen_random_uuid(), + user_id uuid NOT NULL REFERENCES users(id) ON DELETE CASCADE, + tool_key text NOT NULL, + name text NOT NULL, + description text NOT NULL, + version text NOT NULL, + metadata_version text NOT NULL, + active boolean NOT NULL DEFAULT TRUE, + tags jsonb NOT NULL DEFAULT '[]'::jsonb, + action_hints jsonb NOT NULL DEFAULT '[]'::jsonb, + scope_hints jsonb NOT NULL DEFAULT '[]'::jsonb, + domain_hints jsonb NOT NULL DEFAULT '[]'::jsonb, + risk_hints jsonb NOT NULL DEFAULT '[]'::jsonb, + metadata jsonb NOT NULL DEFAULT '{}'::jsonb, + created_at timestamptz NOT NULL DEFAULT now(), + UNIQUE (id, user_id), + UNIQUE (user_id, tool_key, version), + CONSTRAINT tools_key_length_check + CHECK (char_length(tool_key) BETWEEN 1 AND 200), + CONSTRAINT tools_name_length_check + CHECK (char_length(name) BETWEEN 1 AND 200), + CONSTRAINT tools_description_length_check + CHECK (char_length(description) BETWEEN 1 AND 500), + CONSTRAINT tools_version_length_check + CHECK (char_length(version) BETWEEN 1 AND 100), + CONSTRAINT tools_metadata_version_check + CHECK (metadata_version = 'tool_metadata_v0'), + CONSTRAINT tools_tags_array_check + CHECK (jsonb_typeof(tags) = 'array'), + CONSTRAINT tools_action_hints_array_check + CHECK (jsonb_typeof(action_hints) = 'array'), + CONSTRAINT tools_scope_hints_array_check + CHECK (jsonb_typeof(scope_hints) = 'array'), + CONSTRAINT tools_domain_hints_array_check + CHECK (jsonb_typeof(domain_hints) = 'array'), + CONSTRAINT tools_risk_hints_array_check + CHECK (jsonb_typeof(risk_hints) = 'array'), + CONSTRAINT tools_metadata_object_check + CHECK (jsonb_typeof(metadata) = 'object') + ); + + CREATE INDEX tools_user_active_key_version_created_idx + ON tools (user_id, active, tool_key, version, created_at, id); + """ + +_UPGRADE_GRANT_STATEMENTS = ( + "GRANT SELECT, INSERT ON tools TO alicebot_app", +) + +_UPGRADE_POLICY_STATEMENT = """ + CREATE POLICY tools_is_owner ON tools + USING (user_id = app.current_user_id()) + WITH CHECK (user_id = app.current_user_id()); + """ + +_DOWNGRADE_STATEMENTS = ( + "DROP TABLE IF EXISTS tools", +) + + +def _execute_statements(statements: 
tuple[str, ...]) -> None: + for statement in statements: + op.execute(statement) + + +def _enable_row_level_security() -> None: + for table_name in _RLS_TABLES: + op.execute(f"ALTER TABLE {table_name} ENABLE ROW LEVEL SECURITY") + op.execute(f"ALTER TABLE {table_name} FORCE ROW LEVEL SECURITY") + + +def upgrade() -> None: + op.execute(_UPGRADE_SCHEMA_STATEMENT) + _execute_statements(_UPGRADE_GRANT_STATEMENTS) + _enable_row_level_security() + op.execute(_UPGRADE_POLICY_STATEMENT) + + +def downgrade() -> None: + _execute_statements(_DOWNGRADE_STATEMENTS) diff --git a/apps/api/alembic/versions/20260312_0011_approval_request_records.py b/apps/api/alembic/versions/20260312_0011_approval_request_records.py new file mode 100644 index 0000000..49aff5c --- /dev/null +++ b/apps/api/alembic/versions/20260312_0011_approval_request_records.py @@ -0,0 +1,88 @@ +"""Add durable approval request records.""" + +from __future__ import annotations + +from alembic import op + + +revision = "20260312_0011" +down_revision = "20260312_0010" +branch_labels = None +depends_on = None + +_RLS_TABLES = ("approvals",) + +_UPGRADE_SCHEMA_STATEMENT = """ + CREATE TABLE approvals ( + id uuid PRIMARY KEY DEFAULT gen_random_uuid(), + user_id uuid NOT NULL REFERENCES users(id) ON DELETE CASCADE, + thread_id uuid NOT NULL, + tool_id uuid NOT NULL, + status text NOT NULL DEFAULT 'pending', + request jsonb NOT NULL, + tool jsonb NOT NULL, + routing jsonb NOT NULL, + routing_trace_id uuid NOT NULL, + created_at timestamptz NOT NULL DEFAULT now(), + UNIQUE (id, user_id), + CONSTRAINT approvals_thread_user_fk + FOREIGN KEY (thread_id, user_id) + REFERENCES threads(id, user_id) + ON DELETE CASCADE, + CONSTRAINT approvals_tool_user_fk + FOREIGN KEY (tool_id, user_id) + REFERENCES tools(id, user_id) + ON DELETE RESTRICT, + CONSTRAINT approvals_routing_trace_user_fk + FOREIGN KEY (routing_trace_id, user_id) + REFERENCES traces(id, user_id) + ON DELETE RESTRICT, + CONSTRAINT approvals_status_check + CHECK (status = 'pending'), + CONSTRAINT approvals_request_object_check + CHECK (jsonb_typeof(request) = 'object'), + CONSTRAINT approvals_tool_object_check + CHECK (jsonb_typeof(tool) = 'object'), + CONSTRAINT approvals_routing_object_check + CHECK (jsonb_typeof(routing) = 'object') + ); + + CREATE INDEX approvals_user_created_idx + ON approvals (user_id, created_at, id); + """ + +_UPGRADE_GRANT_STATEMENTS = ( + "GRANT SELECT, INSERT ON approvals TO alicebot_app", +) + +_UPGRADE_POLICY_STATEMENT = """ + CREATE POLICY approvals_is_owner ON approvals + USING (user_id = app.current_user_id()) + WITH CHECK (user_id = app.current_user_id()); + """ + +_DOWNGRADE_STATEMENTS = ( + "DROP TABLE IF EXISTS approvals", +) + + +def _execute_statements(statements: tuple[str, ...]) -> None: + for statement in statements: + op.execute(statement) + + +def _enable_row_level_security() -> None: + for table_name in _RLS_TABLES: + op.execute(f"ALTER TABLE {table_name} ENABLE ROW LEVEL SECURITY") + op.execute(f"ALTER TABLE {table_name} FORCE ROW LEVEL SECURITY") + + +def upgrade() -> None: + op.execute(_UPGRADE_SCHEMA_STATEMENT) + _execute_statements(_UPGRADE_GRANT_STATEMENTS) + _enable_row_level_security() + op.execute(_UPGRADE_POLICY_STATEMENT) + + +def downgrade() -> None: + _execute_statements(_DOWNGRADE_STATEMENTS) diff --git a/apps/api/alembic/versions/20260312_0012_approval_resolution.py b/apps/api/alembic/versions/20260312_0012_approval_resolution.py new file mode 100644 index 0000000..7ef2907 --- /dev/null +++ 
b/apps/api/alembic/versions/20260312_0012_approval_resolution.py @@ -0,0 +1,63 @@ +"""Add approval resolution state and runtime update access.""" + +from __future__ import annotations + +from alembic import op + + +revision = "20260312_0012" +down_revision = "20260312_0011" +branch_labels = None +depends_on = None + +_UPGRADE_SCHEMA_STATEMENT = """ + ALTER TABLE approvals + DROP CONSTRAINT approvals_status_check, + ADD COLUMN resolved_at timestamptz, + ADD COLUMN resolved_by_user_id uuid REFERENCES users(id) ON DELETE RESTRICT, + ADD CONSTRAINT approvals_status_check + CHECK (status IN ('pending', 'approved', 'rejected')), + ADD CONSTRAINT approvals_resolution_consistency_check + CHECK ( + (status = 'pending' AND resolved_at IS NULL AND resolved_by_user_id IS NULL) + OR ( + status IN ('approved', 'rejected') + AND resolved_at IS NOT NULL + AND resolved_by_user_id IS NOT NULL + ) + ), + ADD CONSTRAINT approvals_resolved_by_owner_check + CHECK (resolved_by_user_id IS NULL OR resolved_by_user_id = user_id); + """ + +_UPGRADE_GRANT_STATEMENTS = ( + "GRANT UPDATE ON approvals TO alicebot_app", +) + +_DOWNGRADE_STATEMENTS = ( + "REVOKE UPDATE ON approvals FROM alicebot_app", + """ + ALTER TABLE approvals + DROP CONSTRAINT approvals_resolved_by_owner_check, + DROP CONSTRAINT approvals_resolution_consistency_check, + DROP CONSTRAINT approvals_status_check, + DROP COLUMN resolved_by_user_id, + DROP COLUMN resolved_at, + ADD CONSTRAINT approvals_status_check + CHECK (status = 'pending'); + """, +) + + +def _execute_statements(statements: tuple[str, ...]) -> None: + for statement in statements: + op.execute(statement) + + +def upgrade() -> None: + op.execute(_UPGRADE_SCHEMA_STATEMENT) + _execute_statements(_UPGRADE_GRANT_STATEMENTS) + + +def downgrade() -> None: + _execute_statements(_DOWNGRADE_STATEMENTS) diff --git a/apps/api/alembic/versions/20260313_0013_tool_executions.py b/apps/api/alembic/versions/20260313_0013_tool_executions.py new file mode 100644 index 0000000..9bcfdfe --- /dev/null +++ b/apps/api/alembic/versions/20260313_0013_tool_executions.py @@ -0,0 +1,118 @@ +"""Add durable tool execution review records.""" + +from __future__ import annotations + +from alembic import op + + +revision = "20260313_0013" +down_revision = "20260312_0012" +branch_labels = None +depends_on = None + +_RLS_TABLES = ("tool_executions",) + +_UPGRADE_SCHEMA_STATEMENT = """ + CREATE TABLE tool_executions ( + id uuid PRIMARY KEY DEFAULT gen_random_uuid(), + user_id uuid NOT NULL REFERENCES users(id) ON DELETE CASCADE, + approval_id uuid NOT NULL, + thread_id uuid NOT NULL, + tool_id uuid NOT NULL, + trace_id uuid NOT NULL, + request_event_id uuid, + result_event_id uuid, + status text NOT NULL, + handler_key text, + request jsonb NOT NULL, + tool jsonb NOT NULL, + result jsonb NOT NULL, + executed_at timestamptz NOT NULL DEFAULT now(), + UNIQUE (id, user_id), + CONSTRAINT tool_executions_approval_user_fk + FOREIGN KEY (approval_id, user_id) + REFERENCES approvals(id, user_id) + ON DELETE RESTRICT, + CONSTRAINT tool_executions_thread_user_fk + FOREIGN KEY (thread_id, user_id) + REFERENCES threads(id, user_id) + ON DELETE CASCADE, + CONSTRAINT tool_executions_tool_user_fk + FOREIGN KEY (tool_id, user_id) + REFERENCES tools(id, user_id) + ON DELETE RESTRICT, + CONSTRAINT tool_executions_trace_user_fk + FOREIGN KEY (trace_id, user_id) + REFERENCES traces(id, user_id) + ON DELETE RESTRICT, + CONSTRAINT tool_executions_request_event_user_fk + FOREIGN KEY (request_event_id, user_id) + REFERENCES events(id, user_id) + ON 
DELETE RESTRICT, + CONSTRAINT tool_executions_result_event_user_fk + FOREIGN KEY (result_event_id, user_id) + REFERENCES events(id, user_id) + ON DELETE RESTRICT, + CONSTRAINT tool_executions_status_check + CHECK (status IN ('completed', 'blocked')), + CONSTRAINT tool_executions_request_object_check + CHECK (jsonb_typeof(request) = 'object'), + CONSTRAINT tool_executions_tool_object_check + CHECK (jsonb_typeof(tool) = 'object'), + CONSTRAINT tool_executions_result_object_check + CHECK (jsonb_typeof(result) = 'object'), + CONSTRAINT tool_executions_status_event_consistency_check + CHECK ( + ( + status = 'completed' + AND handler_key IS NOT NULL + AND request_event_id IS NOT NULL + AND result_event_id IS NOT NULL + ) + OR ( + status = 'blocked' + AND request_event_id IS NULL + AND result_event_id IS NULL + ) + ) + ); + + CREATE INDEX tool_executions_user_executed_idx + ON tool_executions (user_id, executed_at, id); + """ + +_UPGRADE_GRANT_STATEMENTS = ( + "GRANT SELECT, INSERT ON tool_executions TO alicebot_app", +) + +_UPGRADE_POLICY_STATEMENT = """ + CREATE POLICY tool_executions_is_owner ON tool_executions + USING (user_id = app.current_user_id()) + WITH CHECK (user_id = app.current_user_id()); + """ + +_DOWNGRADE_STATEMENTS = ( + "DROP TABLE IF EXISTS tool_executions", +) + + +def _execute_statements(statements: tuple[str, ...]) -> None: + for statement in statements: + op.execute(statement) + + +def _enable_row_level_security() -> None: + for table_name in _RLS_TABLES: + op.execute(f"ALTER TABLE {table_name} ENABLE ROW LEVEL SECURITY") + op.execute(f"ALTER TABLE {table_name} FORCE ROW LEVEL SECURITY") + + +def upgrade() -> None: + op.execute(_UPGRADE_SCHEMA_STATEMENT) + _execute_statements(_UPGRADE_GRANT_STATEMENTS) + _enable_row_level_security() + op.execute(_UPGRADE_POLICY_STATEMENT) + + +def downgrade() -> None: + _execute_statements(_DOWNGRADE_STATEMENTS) diff --git a/apps/api/alembic/versions/20260313_0014_execution_budgets.py b/apps/api/alembic/versions/20260313_0014_execution_budgets.py new file mode 100644 index 0000000..f6c3519 --- /dev/null +++ b/apps/api/alembic/versions/20260313_0014_execution_budgets.py @@ -0,0 +1,71 @@ +"""Add deterministic execution budget records.""" + +from __future__ import annotations + +from alembic import op + + +revision = "20260313_0014" +down_revision = "20260313_0013" +branch_labels = None +depends_on = None + +_RLS_TABLES = ("execution_budgets",) + +_UPGRADE_SCHEMA_STATEMENT = """ + CREATE TABLE execution_budgets ( + id uuid PRIMARY KEY DEFAULT gen_random_uuid(), + user_id uuid NOT NULL REFERENCES users(id) ON DELETE CASCADE, + tool_key text, + domain_hint text, + max_completed_executions integer NOT NULL, + created_at timestamptz NOT NULL DEFAULT now(), + UNIQUE (id, user_id), + CONSTRAINT execution_budgets_selector_check + CHECK (tool_key IS NOT NULL OR domain_hint IS NOT NULL), + CONSTRAINT execution_budgets_max_completed_executions_check + CHECK (max_completed_executions > 0) + ); + + CREATE INDEX execution_budgets_user_created_idx + ON execution_budgets (user_id, created_at, id); + + CREATE INDEX execution_budgets_user_match_idx + ON execution_budgets (user_id, tool_key, domain_hint, created_at, id); + """ + +_UPGRADE_GRANT_STATEMENTS = ( + "GRANT SELECT, INSERT ON execution_budgets TO alicebot_app", +) + +_UPGRADE_POLICY_STATEMENT = """ + CREATE POLICY execution_budgets_is_owner ON execution_budgets + USING (user_id = app.current_user_id()) + WITH CHECK (user_id = app.current_user_id()); + """ + +_DOWNGRADE_STATEMENTS = ( + "DROP TABLE IF 
EXISTS execution_budgets", +) + + +def _execute_statements(statements: tuple[str, ...]) -> None: + for statement in statements: + op.execute(statement) + + +def _enable_row_level_security() -> None: + for table_name in _RLS_TABLES: + op.execute(f"ALTER TABLE {table_name} ENABLE ROW LEVEL SECURITY") + op.execute(f"ALTER TABLE {table_name} FORCE ROW LEVEL SECURITY") + + +def upgrade() -> None: + op.execute(_UPGRADE_SCHEMA_STATEMENT) + _execute_statements(_UPGRADE_GRANT_STATEMENTS) + _enable_row_level_security() + op.execute(_UPGRADE_POLICY_STATEMENT) + + +def downgrade() -> None: + _execute_statements(_DOWNGRADE_STATEMENTS) diff --git a/apps/api/alembic/versions/20260313_0015_execution_budget_lifecycle.py b/apps/api/alembic/versions/20260313_0015_execution_budget_lifecycle.py new file mode 100644 index 0000000..ffacd8c --- /dev/null +++ b/apps/api/alembic/versions/20260313_0015_execution_budget_lifecycle.py @@ -0,0 +1,80 @@ +"""Add execution budget lifecycle controls.""" + +from __future__ import annotations + +from alembic import op + + +revision = "20260313_0015" +down_revision = "20260313_0014" +branch_labels = None +depends_on = None + +_UPGRADE_STATEMENTS = ( + """ + ALTER TABLE execution_budgets + ADD COLUMN status text NOT NULL DEFAULT 'active', + ADD COLUMN deactivated_at timestamptz, + ADD COLUMN superseded_by_budget_id uuid REFERENCES execution_budgets(id) ON DELETE SET NULL DEFERRABLE INITIALLY DEFERRED, + ADD COLUMN supersedes_budget_id uuid REFERENCES execution_budgets(id) ON DELETE SET NULL DEFERRABLE INITIALLY DEFERRED; + """, + """ + ALTER TABLE execution_budgets + ADD CONSTRAINT execution_budgets_status_check + CHECK (status IN ('active', 'inactive', 'superseded')), + ADD CONSTRAINT execution_budgets_lifecycle_state_check + CHECK ( + (status = 'active' AND deactivated_at IS NULL AND superseded_by_budget_id IS NULL) + OR (status = 'inactive' AND deactivated_at IS NOT NULL AND superseded_by_budget_id IS NULL) + OR (status = 'superseded' AND deactivated_at IS NOT NULL AND superseded_by_budget_id IS NOT NULL) + ), + ADD CONSTRAINT execution_budgets_supersedes_budget_unique + UNIQUE (supersedes_budget_id); + """, + """ + CREATE INDEX execution_budgets_user_status_created_idx + ON execution_budgets (user_id, status, created_at, id); + """, + """ + CREATE UNIQUE INDEX execution_budgets_one_active_scope_idx + ON execution_budgets ( + user_id, + COALESCE(tool_key, ''), + COALESCE(domain_hint, '') + ) + WHERE status = 'active'; + """, + "GRANT SELECT, INSERT, UPDATE ON execution_budgets TO alicebot_app", +) + +_DOWNGRADE_STATEMENTS = ( + "REVOKE UPDATE ON execution_budgets FROM alicebot_app", + "DROP INDEX IF EXISTS execution_budgets_one_active_scope_idx", + "DROP INDEX IF EXISTS execution_budgets_user_status_created_idx", + """ + ALTER TABLE execution_budgets + DROP CONSTRAINT IF EXISTS execution_budgets_supersedes_budget_unique, + DROP CONSTRAINT IF EXISTS execution_budgets_lifecycle_state_check, + DROP CONSTRAINT IF EXISTS execution_budgets_status_check; + """, + """ + ALTER TABLE execution_budgets + DROP COLUMN IF EXISTS supersedes_budget_id, + DROP COLUMN IF EXISTS superseded_by_budget_id, + DROP COLUMN IF EXISTS deactivated_at, + DROP COLUMN IF EXISTS status; + """, +) + + +def _execute_statements(statements: tuple[str, ...]) -> None: + for statement in statements: + op.execute(statement) + + +def upgrade() -> None: + _execute_statements(_UPGRADE_STATEMENTS) + + +def downgrade() -> None: + _execute_statements(_DOWNGRADE_STATEMENTS) diff --git 
a/apps/api/alembic/versions/20260313_0016_execution_budget_rolling_window.py b/apps/api/alembic/versions/20260313_0016_execution_budget_rolling_window.py new file mode 100644 index 0000000..31a842c --- /dev/null +++ b/apps/api/alembic/versions/20260313_0016_execution_budget_rolling_window.py @@ -0,0 +1,47 @@ +"""Add optional rolling-window execution budget support.""" + +from __future__ import annotations + +from alembic import op + + +revision = "20260313_0016" +down_revision = "20260313_0015" +branch_labels = None +depends_on = None + +_UPGRADE_STATEMENTS = ( + """ + ALTER TABLE execution_budgets + ADD COLUMN rolling_window_seconds integer; + """, + """ + ALTER TABLE execution_budgets + ADD CONSTRAINT execution_budgets_rolling_window_seconds_check + CHECK (rolling_window_seconds IS NULL OR rolling_window_seconds > 0); + """, +) + +_DOWNGRADE_STATEMENTS = ( + """ + ALTER TABLE execution_budgets + DROP CONSTRAINT IF EXISTS execution_budgets_rolling_window_seconds_check; + """, + """ + ALTER TABLE execution_budgets + DROP COLUMN IF EXISTS rolling_window_seconds; + """, +) + + +def _execute_statements(statements: tuple[str, ...]) -> None: + for statement in statements: + op.execute(statement) + + +def upgrade() -> None: + _execute_statements(_UPGRADE_STATEMENTS) + + +def downgrade() -> None: + _execute_statements(_DOWNGRADE_STATEMENTS) diff --git a/apps/api/alembic/versions/20260313_0017_tasks_lifecycle_records.py b/apps/api/alembic/versions/20260313_0017_tasks_lifecycle_records.py new file mode 100644 index 0000000..f00f07c --- /dev/null +++ b/apps/api/alembic/versions/20260313_0017_tasks_lifecycle_records.py @@ -0,0 +1,112 @@ +"""Add durable task records with deterministic lifecycle status.""" + +from __future__ import annotations + +from alembic import op + + +revision = "20260313_0017" +down_revision = "20260313_0016" +branch_labels = None +depends_on = None + +_RLS_TABLES = ("tasks",) + +_UPGRADE_SCHEMA_STATEMENT = """ + CREATE TABLE tasks ( + id uuid PRIMARY KEY DEFAULT gen_random_uuid(), + user_id uuid NOT NULL REFERENCES users(id) ON DELETE CASCADE, + thread_id uuid NOT NULL, + tool_id uuid NOT NULL, + status text NOT NULL, + request jsonb NOT NULL, + tool jsonb NOT NULL, + latest_approval_id uuid, + latest_execution_id uuid, + created_at timestamptz NOT NULL DEFAULT now(), + updated_at timestamptz NOT NULL DEFAULT now(), + UNIQUE (id, user_id), + CONSTRAINT tasks_thread_user_fk + FOREIGN KEY (thread_id, user_id) + REFERENCES threads(id, user_id) + ON DELETE CASCADE, + CONSTRAINT tasks_tool_user_fk + FOREIGN KEY (tool_id, user_id) + REFERENCES tools(id, user_id) + ON DELETE RESTRICT, + CONSTRAINT tasks_latest_approval_user_fk + FOREIGN KEY (latest_approval_id, user_id) + REFERENCES approvals(id, user_id) + ON DELETE RESTRICT, + CONSTRAINT tasks_latest_execution_user_fk + FOREIGN KEY (latest_execution_id, user_id) + REFERENCES tool_executions(id, user_id) + ON DELETE RESTRICT, + CONSTRAINT tasks_status_check + CHECK (status IN ('pending_approval', 'approved', 'executed', 'denied', 'blocked')), + CONSTRAINT tasks_request_object_check + CHECK (jsonb_typeof(request) = 'object'), + CONSTRAINT tasks_tool_object_check + CHECK (jsonb_typeof(tool) = 'object'), + CONSTRAINT tasks_pending_approval_link_check + CHECK (status <> 'pending_approval' OR latest_approval_id IS NOT NULL), + CONSTRAINT tasks_execution_link_check + CHECK ( + ( + status IN ('executed', 'blocked') + AND latest_execution_id IS NOT NULL + ) + OR ( + status NOT IN ('executed', 'blocked') + AND latest_execution_id IS NULL + ) 
+ ) + ); + + CREATE INDEX tasks_user_created_idx + ON tasks (user_id, created_at, id); + + CREATE UNIQUE INDEX tasks_latest_approval_unique_idx + ON tasks (user_id, latest_approval_id) + WHERE latest_approval_id IS NOT NULL; + + CREATE UNIQUE INDEX tasks_latest_execution_unique_idx + ON tasks (user_id, latest_execution_id) + WHERE latest_execution_id IS NOT NULL; + """ + +_UPGRADE_GRANT_STATEMENTS = ( + "GRANT SELECT, INSERT, UPDATE ON tasks TO alicebot_app", +) + +_UPGRADE_POLICY_STATEMENT = """ + CREATE POLICY tasks_is_owner ON tasks + USING (user_id = app.current_user_id()) + WITH CHECK (user_id = app.current_user_id()); + """ + +_DOWNGRADE_STATEMENTS = ( + "DROP TABLE IF EXISTS tasks", +) + + +def _execute_statements(statements: tuple[str, ...]) -> None: + for statement in statements: + op.execute(statement) + + +def _enable_row_level_security() -> None: + for table_name in _RLS_TABLES: + op.execute(f"ALTER TABLE {table_name} ENABLE ROW LEVEL SECURITY") + op.execute(f"ALTER TABLE {table_name} FORCE ROW LEVEL SECURITY") + + +def upgrade() -> None: + op.execute(_UPGRADE_SCHEMA_STATEMENT) + _execute_statements(_UPGRADE_GRANT_STATEMENTS) + _enable_row_level_security() + op.execute(_UPGRADE_POLICY_STATEMENT) + + +def downgrade() -> None: + _execute_statements(_DOWNGRADE_STATEMENTS) diff --git a/apps/api/alembic/versions/20260313_0018_task_steps.py b/apps/api/alembic/versions/20260313_0018_task_steps.py new file mode 100644 index 0000000..9467472 --- /dev/null +++ b/apps/api/alembic/versions/20260313_0018_task_steps.py @@ -0,0 +1,93 @@ +"""Add durable task-step review records.""" + +from __future__ import annotations + +from alembic import op + + +revision = "20260313_0018" +down_revision = "20260313_0017" +branch_labels = None +depends_on = None + +_RLS_TABLES = ("task_steps",) + +_UPGRADE_SCHEMA_STATEMENT = """ + CREATE TABLE task_steps ( + id uuid PRIMARY KEY DEFAULT gen_random_uuid(), + user_id uuid NOT NULL REFERENCES users(id) ON DELETE CASCADE, + task_id uuid NOT NULL, + sequence_no integer NOT NULL, + kind text NOT NULL, + status text NOT NULL, + request jsonb NOT NULL, + outcome jsonb NOT NULL, + trace_id uuid NOT NULL, + trace_kind text NOT NULL, + created_at timestamptz NOT NULL DEFAULT now(), + updated_at timestamptz NOT NULL DEFAULT now(), + UNIQUE (id, user_id), + CONSTRAINT task_steps_task_user_fk + FOREIGN KEY (task_id, user_id) + REFERENCES tasks(id, user_id) + ON DELETE CASCADE, + CONSTRAINT task_steps_trace_user_fk + FOREIGN KEY (trace_id, user_id) + REFERENCES traces(id, user_id) + ON DELETE RESTRICT, + CONSTRAINT task_steps_sequence_no_check + CHECK (sequence_no > 0), + CONSTRAINT task_steps_kind_check + CHECK (kind IN ('governed_request')), + CONSTRAINT task_steps_status_check + CHECK (status IN ('created', 'approved', 'executed', 'blocked', 'denied')), + CONSTRAINT task_steps_request_object_check + CHECK (jsonb_typeof(request) = 'object'), + CONSTRAINT task_steps_outcome_object_check + CHECK (jsonb_typeof(outcome) = 'object'), + CONSTRAINT task_steps_trace_kind_nonempty_check + CHECK (length(trace_kind) > 0) + ); + + CREATE UNIQUE INDEX task_steps_task_sequence_idx + ON task_steps (user_id, task_id, sequence_no); + + CREATE INDEX task_steps_user_created_idx + ON task_steps (user_id, created_at, id); + """ + +_UPGRADE_GRANT_STATEMENTS = ( + "GRANT SELECT, INSERT, UPDATE ON task_steps TO alicebot_app", +) + +_UPGRADE_POLICY_STATEMENT = """ + CREATE POLICY task_steps_is_owner ON task_steps + USING (user_id = app.current_user_id()) + WITH CHECK (user_id = 
app.current_user_id()); + """ + +_DOWNGRADE_STATEMENTS = ( + "DROP TABLE IF EXISTS task_steps", +) + + +def _execute_statements(statements: tuple[str, ...]) -> None: + for statement in statements: + op.execute(statement) + + +def _enable_row_level_security() -> None: + for table_name in _RLS_TABLES: + op.execute(f"ALTER TABLE {table_name} ENABLE ROW LEVEL SECURITY") + op.execute(f"ALTER TABLE {table_name} FORCE ROW LEVEL SECURITY") + + +def upgrade() -> None: + op.execute(_UPGRADE_SCHEMA_STATEMENT) + _execute_statements(_UPGRADE_GRANT_STATEMENTS) + _enable_row_level_security() + op.execute(_UPGRADE_POLICY_STATEMENT) + + +def downgrade() -> None: + _execute_statements(_DOWNGRADE_STATEMENTS) diff --git a/apps/api/alembic/versions/20260313_0019_task_step_lineage.py b/apps/api/alembic/versions/20260313_0019_task_step_lineage.py new file mode 100644 index 0000000..b0d98a5 --- /dev/null +++ b/apps/api/alembic/versions/20260313_0019_task_step_lineage.py @@ -0,0 +1,58 @@ +"""Add explicit lineage fields for manual task-step continuation.""" + +from __future__ import annotations + +from alembic import op + + +revision = "20260313_0019" +down_revision = "20260313_0018" +branch_labels = None +depends_on = None + +_UPGRADE_SCHEMA_STATEMENT = """ + ALTER TABLE task_steps + ADD COLUMN parent_step_id uuid, + ADD COLUMN source_approval_id uuid, + ADD COLUMN source_execution_id uuid, + ADD CONSTRAINT task_steps_parent_step_user_fk + FOREIGN KEY (parent_step_id, user_id) + REFERENCES task_steps(id, user_id) + ON DELETE RESTRICT, + ADD CONSTRAINT task_steps_source_approval_user_fk + FOREIGN KEY (source_approval_id, user_id) + REFERENCES approvals(id, user_id) + ON DELETE RESTRICT, + ADD CONSTRAINT task_steps_source_execution_user_fk + FOREIGN KEY (source_execution_id, user_id) + REFERENCES tool_executions(id, user_id) + ON DELETE RESTRICT, + ADD CONSTRAINT task_steps_parent_step_not_self_check + CHECK (parent_step_id IS NULL OR parent_step_id <> id); + """ + +_DOWNGRADE_STATEMENTS = ( + """ + ALTER TABLE task_steps + DROP CONSTRAINT task_steps_parent_step_not_self_check, + DROP CONSTRAINT task_steps_source_execution_user_fk, + DROP CONSTRAINT task_steps_source_approval_user_fk, + DROP CONSTRAINT task_steps_parent_step_user_fk, + DROP COLUMN source_execution_id, + DROP COLUMN source_approval_id, + DROP COLUMN parent_step_id; + """, +) + + +def _execute_statements(statements: tuple[str, ...]) -> None: + for statement in statements: + op.execute(statement) + + +def upgrade() -> None: + op.execute(_UPGRADE_SCHEMA_STATEMENT) + + +def downgrade() -> None: + _execute_statements(_DOWNGRADE_STATEMENTS) diff --git a/apps/api/alembic/versions/20260313_0020_approval_task_step_linkage.py b/apps/api/alembic/versions/20260313_0020_approval_task_step_linkage.py new file mode 100644 index 0000000..fe8a270 --- /dev/null +++ b/apps/api/alembic/versions/20260313_0020_approval_task_step_linkage.py @@ -0,0 +1,41 @@ +"""Link approvals directly to their durable task step.""" + +from __future__ import annotations + +from alembic import op + + +revision = "20260313_0020" +down_revision = "20260313_0019" +branch_labels = None +depends_on = None + +_UPGRADE_SCHEMA_STATEMENT = """ + ALTER TABLE approvals + ADD COLUMN task_step_id uuid, + ADD CONSTRAINT approvals_task_step_user_fk + FOREIGN KEY (task_step_id, user_id) + REFERENCES task_steps(id, user_id) + ON DELETE RESTRICT; + """ + +_DOWNGRADE_STATEMENTS = ( + """ + ALTER TABLE approvals + DROP CONSTRAINT approvals_task_step_user_fk, + DROP COLUMN task_step_id; + """, +) + + +def 
_execute_statements(statements: tuple[str, ...]) -> None: + for statement in statements: + op.execute(statement) + + +def upgrade() -> None: + op.execute(_UPGRADE_SCHEMA_STATEMENT) + + +def downgrade() -> None: + _execute_statements(_DOWNGRADE_STATEMENTS) diff --git a/apps/api/alembic/versions/20260313_0021_tool_execution_task_step_linkage.py b/apps/api/alembic/versions/20260313_0021_tool_execution_task_step_linkage.py new file mode 100644 index 0000000..6a26d41 --- /dev/null +++ b/apps/api/alembic/versions/20260313_0021_tool_execution_task_step_linkage.py @@ -0,0 +1,81 @@ +"""Link tool executions directly to their durable task step.""" + +from __future__ import annotations + +from alembic import op + + +revision = "20260313_0021" +down_revision = "20260313_0020" +branch_labels = None +depends_on = None + +_UPGRADE_STATEMENTS = ( + """ + ALTER TABLE tool_executions + ADD COLUMN task_step_id uuid; + """, + """ + UPDATE tool_executions AS executions + SET task_step_id = COALESCE( + approvals.task_step_id, + ( + SELECT task_steps.id + FROM task_steps + WHERE task_steps.user_id = executions.user_id + AND task_steps.outcome ->> 'approval_id' = approvals.id::text + ORDER BY task_steps.created_at ASC, task_steps.id ASC + LIMIT 1 + ) + ) + FROM approvals + WHERE approvals.id = executions.approval_id + AND approvals.user_id = executions.user_id; + """, + """ + DO $$ + BEGIN + IF EXISTS ( + SELECT 1 + FROM tool_executions + WHERE task_step_id IS NULL + ) THEN + RAISE EXCEPTION + 'tool_executions.task_step_id backfill failed for existing rows'; + END IF; + END; + $$; + """, + """ + ALTER TABLE tool_executions + ADD CONSTRAINT tool_executions_task_step_user_fk + FOREIGN KEY (task_step_id, user_id) + REFERENCES task_steps(id, user_id) + ON DELETE RESTRICT; + """, + """ + ALTER TABLE tool_executions + ALTER COLUMN task_step_id SET NOT NULL; + """, +) + +_DOWNGRADE_STATEMENTS = ( + """ + ALTER TABLE tool_executions + DROP CONSTRAINT tool_executions_task_step_user_fk, + DROP COLUMN task_step_id; + """, +) + + +def _execute_statements(statements: tuple[str, ...]) -> None: + for statement in statements: + op.execute(statement) + + +def upgrade() -> None: + _execute_statements(_UPGRADE_STATEMENTS) + + +def downgrade() -> None: + _execute_statements(_DOWNGRADE_STATEMENTS) diff --git a/apps/api/alembic/versions/20260313_0022_task_workspaces.py b/apps/api/alembic/versions/20260313_0022_task_workspaces.py new file mode 100644 index 0000000..626224f --- /dev/null +++ b/apps/api/alembic/versions/20260313_0022_task_workspaces.py @@ -0,0 +1,77 @@ +"""Add user-scoped task workspace records.""" + +from __future__ import annotations + +from alembic import op + + +revision = "20260313_0022" +down_revision = "20260313_0021" +branch_labels = None +depends_on = None + +_RLS_TABLES = ("task_workspaces",) + +_UPGRADE_SCHEMA_STATEMENT = """ + CREATE TABLE task_workspaces ( + id uuid PRIMARY KEY DEFAULT gen_random_uuid(), + user_id uuid NOT NULL REFERENCES users(id) ON DELETE CASCADE, + task_id uuid NOT NULL, + status text NOT NULL, + local_path text NOT NULL, + created_at timestamptz NOT NULL DEFAULT now(), + updated_at timestamptz NOT NULL DEFAULT now(), + UNIQUE (id, user_id), + CONSTRAINT task_workspaces_task_user_fk + FOREIGN KEY (task_id, user_id) + REFERENCES tasks(id, user_id) + ON DELETE CASCADE, + CONSTRAINT task_workspaces_status_check + CHECK (status IN ('active')), + CONSTRAINT task_workspaces_local_path_nonempty_check + CHECK (length(local_path) > 0) + ); + + CREATE INDEX task_workspaces_user_created_idx + ON 
task_workspaces (user_id, created_at, id); + + CREATE UNIQUE INDEX task_workspaces_active_task_idx + ON task_workspaces (user_id, task_id) + WHERE status = 'active'; + """ + +_UPGRADE_GRANT_STATEMENTS = ( + "GRANT SELECT, INSERT ON task_workspaces TO alicebot_app", +) + +_UPGRADE_POLICY_STATEMENT = """ + CREATE POLICY task_workspaces_is_owner ON task_workspaces + USING (user_id = app.current_user_id()) + WITH CHECK (user_id = app.current_user_id()); + """ + +_DOWNGRADE_STATEMENTS = ( + "DROP TABLE IF EXISTS task_workspaces", +) + + +def _execute_statements(statements: tuple[str, ...]) -> None: + for statement in statements: + op.execute(statement) + + +def _enable_row_level_security() -> None: + for table_name in _RLS_TABLES: + op.execute(f"ALTER TABLE {table_name} ENABLE ROW LEVEL SECURITY") + op.execute(f"ALTER TABLE {table_name} FORCE ROW LEVEL SECURITY") + + +def upgrade() -> None: + op.execute(_UPGRADE_SCHEMA_STATEMENT) + _execute_statements(_UPGRADE_GRANT_STATEMENTS) + _enable_row_level_security() + op.execute(_UPGRADE_POLICY_STATEMENT) + + +def downgrade() -> None: + _execute_statements(_DOWNGRADE_STATEMENTS) diff --git a/apps/api/src/alicebot_api/__init__.py b/apps/api/src/alicebot_api/__init__.py new file mode 100644 index 0000000..39a8a6c --- /dev/null +++ b/apps/api/src/alicebot_api/__init__.py @@ -0,0 +1,2 @@ +"""AliceBot foundation API package.""" + diff --git a/apps/api/src/alicebot_api/approvals.py b/apps/api/src/alicebot_api/approvals.py new file mode 100644 index 0000000..7ef1827 --- /dev/null +++ b/apps/api/src/alicebot_api/approvals.py @@ -0,0 +1,490 @@ +from __future__ import annotations + +from typing import cast +from uuid import UUID + +from alicebot_api.contracts import ( + APPROVAL_LIST_ORDER, + APPROVAL_REQUEST_VERSION_V0, + APPROVAL_RESOLUTION_VERSION_V0, + TRACE_KIND_APPROVAL_REQUEST, + TRACE_KIND_APPROVAL_RESOLUTION, + ApprovalApproveInput, + ApprovalDetailResponse, + ApprovalListResponse, + ApprovalListSummary, + ApprovalRecord, + ApprovalRejectInput, + ApprovalResolutionAction, + ApprovalResolutionOutcome, + ApprovalResolutionRecord, + ApprovalResolutionRequestTracePayload, + ApprovalResolutionResponse, + ApprovalResolutionStateTracePayload, + ApprovalResolutionSummaryTracePayload, + ApprovalRequestCreateInput, + ApprovalRequestCreateResponse, + ApprovalRequestTraceSummary, + ApprovalRoutingRecord, + TaskCreateInput, + TaskStepCreateInput, + ToolRoutingRequestInput, +) +from alicebot_api.store import ApprovalRow, ContinuityStore +from alicebot_api.tasks import ( + DEFAULT_TASK_STEP_KIND, + DEFAULT_TASK_STEP_SEQUENCE_NO, + create_task_step_for_governed_request, + create_task_for_governed_request, + sync_task_step_with_approval, + task_step_lifecycle_trace_events, + task_step_outcome_snapshot, + task_step_status_for_routing_decision, + sync_task_with_approval, + task_lifecycle_trace_events, + task_status_for_routing_decision, + validate_linked_task_step_for_approval, +) +from alicebot_api.tools import route_tool_invocation + + +class ApprovalNotFoundError(LookupError): + """Raised when an approval record is not visible inside the current user scope.""" + + +class ApprovalResolutionConflictError(RuntimeError): + """Raised when a visible approval record is no longer pending.""" + + +def _serialize_resolution(row: ApprovalRow) -> ApprovalResolutionRecord | None: + if row["resolved_at"] is None or row["resolved_by_user_id"] is None: + return None + return { + "resolved_at": row["resolved_at"].isoformat(), + "resolved_by_user_id": str(row["resolved_by_user_id"]), + 
} + + +def serialize_approval_row(row: ApprovalRow) -> ApprovalRecord: + return { + "id": str(row["id"]), + "thread_id": str(row["thread_id"]), + "task_step_id": None if row["task_step_id"] is None else str(row["task_step_id"]), + "status": cast(str, row["status"]), + "request": cast(dict[str, object], row["request"]), + "tool": cast(dict[str, object], row["tool"]), + "routing": cast(ApprovalRoutingRecord, row["routing"]), + "created_at": row["created_at"].isoformat(), + "resolution": _serialize_resolution(row), + } + + +_serialize_approval = serialize_approval_row + + +def _append_trace_events( + store: ContinuityStore, + *, + trace_id: UUID, + trace_events: list[tuple[str, dict[str, object]]], +) -> None: + for sequence_no, (kind, payload) in enumerate(trace_events, start=1): + store.append_trace_event( + trace_id=trace_id, + sequence_no=sequence_no, + kind=kind, + payload=payload, + ) + + +def _resolution_outcome( + *, + requested_action: ApprovalResolutionAction, + current_status: str, +) -> ApprovalResolutionOutcome: + if ( + requested_action == "approve" + and current_status == "approved" + ) or ( + requested_action == "reject" + and current_status == "rejected" + ): + return "duplicate_rejected" + return "conflict_rejected" + + +def _resolution_error( + approval_id: UUID, + *, + requested_action: ApprovalResolutionAction, + current_status: str, +) -> ApprovalResolutionConflictError: + if ( + requested_action == "approve" + and current_status == "approved" + ) or ( + requested_action == "reject" + and current_status == "rejected" + ): + return ApprovalResolutionConflictError(f"approval {approval_id} was already {current_status}") + + requested_status = "approved" if requested_action == "approve" else "rejected" + return ApprovalResolutionConflictError( + f"approval {approval_id} was already {current_status} and cannot be {requested_status}" + ) + + +def _resolve_approval( + store: ContinuityStore, + *, + user_id: UUID, + approval_id: UUID, + requested_action: ApprovalResolutionAction, + resolved_status: str, +) -> ApprovalResolutionResponse: + del user_id + + approval = store.get_approval_optional(approval_id) + if approval is None: + raise ApprovalNotFoundError(f"approval {approval_id} was not found") + validate_linked_task_step_for_approval( + store, + approval_id=approval_id, + task_step_id=cast(UUID | None, approval["task_step_id"]), + ) + + previous_status = cast(str, approval["status"]) + current = approval + outcome: ApprovalResolutionOutcome + + if approval["status"] == "pending": + resolved = store.resolve_approval_optional( + approval_id=approval_id, + status=resolved_status, + ) + if resolved is None: + current = store.get_approval_optional(approval_id) + if current is None: + raise ApprovalNotFoundError(f"approval {approval_id} was not found") + outcome = _resolution_outcome( + requested_action=requested_action, + current_status=cast(str, current["status"]), + ) + else: + current = resolved + outcome = "resolved" + else: + outcome = _resolution_outcome( + requested_action=requested_action, + current_status=previous_status, + ) + + trace = store.create_trace( + user_id=current["user_id"], + thread_id=current["thread_id"], + kind=TRACE_KIND_APPROVAL_RESOLUTION, + compiler_version=APPROVAL_RESOLUTION_VERSION_V0, + status="completed", + limits={ + "order": list(APPROVAL_LIST_ORDER), + "requested_action": requested_action, + "outcome": outcome, + }, + ) + + resolution = _serialize_resolution(current) + linked_task_step_id = None if current["task_step_id"] is None else 
str(current["task_step_id"]) + request_payload: ApprovalResolutionRequestTracePayload = { + "approval_id": str(approval_id), + "task_step_id": linked_task_step_id, + "requested_action": requested_action, + } + state_payload: ApprovalResolutionStateTracePayload = { + "approval_id": str(current["id"]), + "task_step_id": linked_task_step_id, + "requested_action": requested_action, + "previous_status": previous_status, + "outcome": outcome, + "current_status": cast(str, current["status"]), + "resolved_at": None if resolution is None else resolution["resolved_at"], + "resolved_by_user_id": None if resolution is None else resolution["resolved_by_user_id"], + } + summary_payload: ApprovalResolutionSummaryTracePayload = { + "approval_id": str(current["id"]), + "task_step_id": linked_task_step_id, + "requested_action": requested_action, + "outcome": outcome, + "final_status": cast(str, current["status"]), + } + task_transition = sync_task_with_approval( + store, + approval_id=current["id"], + approval_status=cast(str, current["status"]), + ) + task_step_transition = sync_task_step_with_approval( + store, + approval_id=current["id"], + task_step_id=cast(UUID | None, current["task_step_id"]), + approval_status=cast(str, current["status"]), + trace_id=trace["id"], + trace_kind=TRACE_KIND_APPROVAL_RESOLUTION, + ) + trace_events: list[tuple[str, dict[str, object]]] = [ + ("approval.resolution.request", cast(dict[str, object], request_payload)), + ("approval.resolution.state", cast(dict[str, object], state_payload)), + ("approval.resolution.summary", cast(dict[str, object], summary_payload)), + ] + trace_events.extend( + task_lifecycle_trace_events( + task=task_transition.task, + previous_status=task_transition.previous_status, + source="approval_resolution", + ) + ) + trace_events.extend( + task_step_lifecycle_trace_events( + task_step=task_step_transition.task_step, + previous_status=task_step_transition.previous_status, + source="approval_resolution", + ) + ) + _append_trace_events(store, trace_id=trace["id"], trace_events=trace_events) + + if outcome != "resolved": + raise _resolution_error( + approval_id, + requested_action=requested_action, + current_status=cast(str, current["status"]), + ) + + return { + "approval": _serialize_approval(current), + "trace": { + "trace_id": str(trace["id"]), + "trace_event_count": len(trace_events), + }, + } + + +def submit_approval_request( + store: ContinuityStore, + *, + user_id: UUID, + request: ApprovalRequestCreateInput, +) -> ApprovalRequestCreateResponse: + routing = route_tool_invocation( + store, + user_id=user_id, + request=ToolRoutingRequestInput( + thread_id=request.thread_id, + tool_id=request.tool_id, + action=request.action, + scope=request.scope, + domain_hint=request.domain_hint, + risk_hint=request.risk_hint, + attributes=request.attributes, + ), + ) + + thread = store.get_thread_optional(request.thread_id) + if thread is None: + raise RuntimeError("validated thread disappeared before approval request trace creation") + + approval_persist_requested = routing["decision"] == "approval_required" + approval = None + approval_created = False + if routing["decision"] == "approval_required": + approval_row = store.create_approval( + thread_id=request.thread_id, + tool_id=request.tool_id, + task_step_id=None, + status="pending", + request=routing["request"], + tool=routing["tool"], + routing={ + "decision": routing["decision"], + "reasons": routing["reasons"], + "trace": routing["trace"], + }, + routing_trace_id=UUID(routing["trace"]["trace_id"]), + ) + 
approval = _serialize_approval(approval_row) + approval_created = True + + task = create_task_for_governed_request( + store, + request=TaskCreateInput( + thread_id=request.thread_id, + tool_id=request.tool_id, + status=task_status_for_routing_decision(routing["decision"]), + request=routing["request"], + tool=routing["tool"], + latest_approval_id=None if approval is None else UUID(approval["id"]), + ), + )["task"] + + trace = store.create_trace( + user_id=thread["user_id"], + thread_id=thread["id"], + kind=TRACE_KIND_APPROVAL_REQUEST, + compiler_version=APPROVAL_REQUEST_VERSION_V0, + status="completed", + limits={ + "order": list(APPROVAL_LIST_ORDER), + "persisted": approval_persist_requested, + }, + ) + task_step = create_task_step_for_governed_request( + store, + request=TaskStepCreateInput( + task_id=UUID(task["id"]), + sequence_no=DEFAULT_TASK_STEP_SEQUENCE_NO, + kind=DEFAULT_TASK_STEP_KIND, + status=task_step_status_for_routing_decision(routing["decision"]), + request=routing["request"], + outcome=task_step_outcome_snapshot( + routing_decision=routing["decision"], + approval_id=None if approval is None else approval["id"], + approval_status=None if approval is None else approval["status"], + execution_id=None, + execution_status=None, + blocked_reason=None, + ), + trace_id=trace["id"], + trace_kind=TRACE_KIND_APPROVAL_REQUEST, + ), + )["task_step"] + if approval is not None: + updated_approval = store.update_approval_task_step_optional( + approval_id=UUID(approval["id"]), + task_step_id=UUID(task_step["id"]), + ) + if updated_approval is None: + raise RuntimeError("approval disappeared while linking it to its originating task step") + approval = _serialize_approval(updated_approval) + + trace_events: list[tuple[str, dict[str, object]]] = [ + ("approval.request.request", request.as_payload()), + ( + "approval.request.routing", + { + "decision": routing["decision"], + "tool_id": routing["tool"]["id"], + "tool_key": routing["tool"]["tool_key"], + "tool_version": routing["tool"]["version"], + "routing_trace_id": routing["trace"]["trace_id"], + "routing_trace_event_count": routing["trace"]["trace_event_count"], + "reasons": routing["reasons"], + }, + ), + ( + "approval.request.persisted" if approval_created else "approval.request.skipped", + { + "approval_id": None if approval is None else approval["id"], + "task_step_id": None if approval is None else approval["task_step_id"], + "decision": routing["decision"], + "persisted": approval_created, + }, + ), + ( + "approval.request.summary", + { + "decision": routing["decision"], + "persisted": approval_created, + "approval_id": None if approval is None else approval["id"], + "task_step_id": None if approval is None else approval["task_step_id"], + }, + ), + ] + trace_events.extend( + task_lifecycle_trace_events( + task=task, + previous_status=None, + source="approval_request", + ) + ) + trace_events.extend( + task_step_lifecycle_trace_events( + task_step=task_step, + previous_status=None, + source="approval_request", + ) + ) + _append_trace_events(store, trace_id=trace["id"], trace_events=trace_events) + + trace_summary: ApprovalRequestTraceSummary = { + "trace_id": str(trace["id"]), + "trace_event_count": len(trace_events), + } + return { + "request": routing["request"], + "decision": routing["decision"], + "tool": routing["tool"], + "reasons": routing["reasons"], + "task": task, + "approval": approval, + "routing_trace": routing["trace"], + "trace": trace_summary, + } + + +def approve_approval_record( + store: ContinuityStore, + *, + 
user_id: UUID, + request: ApprovalApproveInput, +) -> ApprovalResolutionResponse: + return _resolve_approval( + store, + user_id=user_id, + approval_id=request.approval_id, + requested_action="approve", + resolved_status="approved", + ) + + +def reject_approval_record( + store: ContinuityStore, + *, + user_id: UUID, + request: ApprovalRejectInput, +) -> ApprovalResolutionResponse: + return _resolve_approval( + store, + user_id=user_id, + approval_id=request.approval_id, + requested_action="reject", + resolved_status="rejected", + ) + + +def list_approval_records( + store: ContinuityStore, + *, + user_id: UUID, +) -> ApprovalListResponse: + del user_id + + items = [_serialize_approval(row) for row in store.list_approvals()] + summary: ApprovalListSummary = { + "total_count": len(items), + "order": list(APPROVAL_LIST_ORDER), + } + return { + "items": items, + "summary": summary, + } + + +def get_approval_record( + store: ContinuityStore, + *, + user_id: UUID, + approval_id: UUID, +) -> ApprovalDetailResponse: + del user_id + + approval = store.get_approval_optional(approval_id) + if approval is None: + raise ApprovalNotFoundError(f"approval {approval_id} was not found") + return {"approval": _serialize_approval(approval)} diff --git a/apps/api/src/alicebot_api/compiler.py b/apps/api/src/alicebot_api/compiler.py new file mode 100644 index 0000000..3626eed --- /dev/null +++ b/apps/api/src/alicebot_api/compiler.py @@ -0,0 +1,832 @@ +from __future__ import annotations + +from dataclasses import dataclass +from uuid import UUID + +from alicebot_api.contracts import ( + COMPILER_VERSION_V0, + CompilerDecision, + CompileContextSemanticRetrievalInput, + CompilerRunResult, + CompiledContextPack, + ContextCompilerLimits, + ContextPackHybridMemorySummary, + ContextPackMemory, + ContextPackMemorySummary, + HybridMemoryDecisionTracePayload, + MemorySelectionSource, + SEMANTIC_MEMORY_RETRIEVAL_ORDER, + SemanticMemoryRetrievalRequestInput, + TRACE_KIND_CONTEXT_COMPILE, + TraceEventRecord, + isoformat_or_none, +) +from alicebot_api.semantic_retrieval import validate_semantic_memory_retrieval_request +from alicebot_api.store import ( + ContinuityStore, + EntityEdgeRow, + EntityRow, + EventRow, + MemoryRow, + SemanticMemoryRetrievalRow, + SessionRow, + ThreadRow, + UserRow, +) + +SUMMARY_TRACE_EVENT_KIND = "context.summary" +_UNBOUNDED_SEMANTIC_RETRIEVAL_LIMIT = 2_147_483_647 +HYBRID_MEMORY_SOURCE_PRECEDENCE: list[MemorySelectionSource] = ["symbolic", "semantic"] +HYBRID_SYMBOLIC_ORDER = ["updated_at_asc", "created_at_asc", "id_asc"] + + +@dataclass(frozen=True, slots=True) +class CompiledTraceRun: + trace_id: str + context_pack: CompiledContextPack + trace_event_count: int + + +@dataclass(frozen=True, slots=True) +class CompiledMemorySection: + items: list[ContextPackMemory] + summary: ContextPackMemorySummary + decisions: list[CompilerDecision] + + +@dataclass(slots=True) +class HybridMemoryCandidate: + memory: MemoryRow + sources: list[MemorySelectionSource] + semantic_score: float | None = None + + +def _session_sort_key( + session: SessionRow, + latest_session_sequence: dict[UUID, int], +) -> tuple[int, str, str, str]: + latest_sequence = latest_session_sequence.get(session["id"], -1) + started_at = isoformat_or_none(session["started_at"]) or "" + created_at = session["created_at"].isoformat() + return (latest_sequence, started_at, created_at, str(session["id"])) + + +def _serialize_user(user: UserRow) -> dict[str, str | None]: + return { + "id": str(user["id"]), + "email": user["email"], + 
"display_name": user["display_name"], + "created_at": user["created_at"].isoformat(), + } + + +def _serialize_thread(thread: ThreadRow) -> dict[str, str]: + return { + "id": str(thread["id"]), + "title": thread["title"], + "created_at": thread["created_at"].isoformat(), + "updated_at": thread["updated_at"].isoformat(), + } + + +def _serialize_session(session: SessionRow) -> dict[str, str | None]: + return { + "id": str(session["id"]), + "status": session["status"], + "started_at": isoformat_or_none(session["started_at"]), + "ended_at": isoformat_or_none(session["ended_at"]), + "created_at": session["created_at"].isoformat(), + } + + +def _serialize_event(event: EventRow) -> dict[str, object]: + return { + "id": str(event["id"]), + "session_id": None if event["session_id"] is None else str(event["session_id"]), + "sequence_no": event["sequence_no"], + "kind": event["kind"], + "payload": event["payload"], + "created_at": event["created_at"].isoformat(), + } + + +def _memory_sort_key(memory: MemoryRow) -> tuple[str, str, str]: + return ( + memory["updated_at"].isoformat(), + memory["created_at"].isoformat(), + str(memory["id"]), + ) + + +def _serialize_memory(memory: MemoryRow) -> dict[str, object]: + return { + "id": str(memory["id"]), + "memory_key": memory["memory_key"], + "value": memory["value"], + "status": memory["status"], + "source_event_ids": memory["source_event_ids"], + "created_at": memory["created_at"].isoformat(), + "updated_at": memory["updated_at"].isoformat(), + "source_provenance": { + "sources": ["symbolic"], + "semantic_score": None, + }, + } + + +def _entity_sort_key(entity: EntityRow) -> tuple[str, str]: + return (entity["created_at"].isoformat(), str(entity["id"])) + + +def _serialize_entity(entity: EntityRow) -> dict[str, object]: + return { + "id": str(entity["id"]), + "entity_type": entity["entity_type"], + "name": entity["name"], + "source_memory_ids": entity["source_memory_ids"], + "created_at": entity["created_at"].isoformat(), + } + + +def _entity_edge_sort_key(edge: EntityEdgeRow) -> tuple[str, str]: + return (edge["created_at"].isoformat(), str(edge["id"])) + + +def _serialize_entity_edge(edge: EntityEdgeRow) -> dict[str, object]: + return { + "id": str(edge["id"]), + "from_entity_id": str(edge["from_entity_id"]), + "to_entity_id": str(edge["to_entity_id"]), + "relationship_type": edge["relationship_type"], + "valid_from": isoformat_or_none(edge["valid_from"]), + "valid_to": isoformat_or_none(edge["valid_to"]), + "source_memory_ids": edge["source_memory_ids"], + "created_at": edge["created_at"].isoformat(), + } + + +def _semantic_memory_sort_key(memory: SemanticMemoryRetrievalRow) -> tuple[float, str, str]: + return (-float(memory["score"]), memory["created_at"].isoformat(), str(memory["id"])) + + +def _semantic_deleted_memory_sort_key(memory: MemoryRow) -> tuple[str, str, str]: + return ( + memory["updated_at"].isoformat(), + memory["created_at"].isoformat(), + str(memory["id"]), + ) + + +def _empty_hybrid_memory_summary() -> ContextPackHybridMemorySummary: + return { + "requested": False, + "embedding_config_id": None, + "query_vector_dimensions": 0, + "semantic_limit": 0, + "symbolic_selected_count": 0, + "semantic_selected_count": 0, + "merged_candidate_count": 0, + "deduplicated_count": 0, + "included_symbolic_only_count": 0, + "included_semantic_only_count": 0, + "included_dual_source_count": 0, + "similarity_metric": None, + "source_precedence": list(HYBRID_MEMORY_SOURCE_PRECEDENCE), + "symbolic_order": list(HYBRID_SYMBOLIC_ORDER), + "semantic_order": 
list(SEMANTIC_MEMORY_RETRIEVAL_ORDER), + } + + +def _hybrid_memory_decision_metadata( + *, + embedding_config_id: UUID | None, + memory_key: str, + status: str, + source_event_ids: list[str], + selected_sources: list[MemorySelectionSource], + semantic_score: float | None, +) -> HybridMemoryDecisionTracePayload: + return { + "embedding_config_id": None if embedding_config_id is None else str(embedding_config_id), + "memory_key": memory_key, + "status": status, + "source_event_ids": source_event_ids, + "selected_sources": list(selected_sources), + "semantic_score": semantic_score, + } + + +def _serialize_hybrid_memory(candidate: HybridMemoryCandidate) -> ContextPackMemory: + memory = candidate.memory + return { + "id": str(memory["id"]), + "memory_key": memory["memory_key"], + "value": memory["value"], + "status": memory["status"], + "source_event_ids": memory["source_event_ids"], + "created_at": memory["created_at"].isoformat(), + "updated_at": memory["updated_at"].isoformat(), + "source_provenance": { + "sources": list(candidate.sources), + "semantic_score": candidate.semantic_score, + }, + } + + +def _build_symbolic_memory_section( + *, + memories: list[MemoryRow], + limits: ContextCompilerLimits, +) -> CompiledMemorySection: + ordered_memories = sorted(memories, key=_memory_sort_key) + active_memories = [memory for memory in ordered_memories if memory["status"] == "active"] + deleted_memories = [memory for memory in ordered_memories if memory["status"] != "active"] + symbolic_candidates = active_memories[-limits.max_memories :] if limits.max_memories > 0 else [] + memory_candidates = [ + HybridMemoryCandidate(memory=memory, sources=["symbolic"]) + for memory in symbolic_candidates + ] + decisions: list[CompilerDecision] = [] + + for position, candidate in enumerate(memory_candidates, start=1): + decisions.append( + CompilerDecision( + "included", + "memory", + candidate.memory["id"], + "within_hybrid_memory_limit", + position, + metadata=_hybrid_memory_decision_metadata( + embedding_config_id=None, + memory_key=candidate.memory["memory_key"], + status=candidate.memory["status"], + source_event_ids=candidate.memory["source_event_ids"], + selected_sources=candidate.sources, + semantic_score=None, + ), + ) + ) + + for position, memory in enumerate(deleted_memories, start=1): + decisions.append( + CompilerDecision( + "excluded", + "memory", + memory["id"], + "hybrid_memory_deleted", + position, + metadata=_hybrid_memory_decision_metadata( + embedding_config_id=None, + memory_key=memory["memory_key"], + status=memory["status"], + source_event_ids=memory["source_event_ids"], + selected_sources=["symbolic"], + semantic_score=None, + ), + ) + ) + + included_items = [_serialize_hybrid_memory(candidate) for candidate in memory_candidates] + return CompiledMemorySection( + items=included_items, + summary={ + "candidate_count": len(memory_candidates) + len(deleted_memories), + "included_count": len(included_items), + "excluded_deleted_count": len(deleted_memories), + "excluded_limit_count": 0, + "hybrid_retrieval": { + **_empty_hybrid_memory_summary(), + "symbolic_selected_count": len(memory_candidates), + "merged_candidate_count": len(memory_candidates), + "included_symbolic_only_count": len(included_items), + }, + }, + decisions=decisions, + ) + + +def _compile_memory_section( + store: ContinuityStore, + *, + memories: list[MemoryRow], + limits: ContextCompilerLimits, + semantic_retrieval: CompileContextSemanticRetrievalInput | None, +) -> CompiledMemorySection: + if semantic_retrieval is None: + 
return _build_symbolic_memory_section(memories=memories, limits=limits) + + ordered_memories = sorted(memories, key=_memory_sort_key) + active_memories = [memory for memory in ordered_memories if memory["status"] == "active"] + deleted_memories = [memory for memory in ordered_memories if memory["status"] != "active"] + symbolic_candidates = active_memories[-limits.max_memories :] if limits.max_memories > 0 else [] + active_memories_by_id = {memory["id"]: memory for memory in active_memories} + + request = SemanticMemoryRetrievalRequestInput( + embedding_config_id=semantic_retrieval.embedding_config_id, + query_vector=semantic_retrieval.query_vector, + limit=semantic_retrieval.limit, + ) + _config, query_vector = validate_semantic_memory_retrieval_request(store, request=request) + ordered_semantic_candidates = sorted( + store.retrieve_semantic_memory_matches( + embedding_config_id=semantic_retrieval.embedding_config_id, + query_vector=query_vector, + limit=_UNBOUNDED_SEMANTIC_RETRIEVAL_LIMIT, + ), + key=_semantic_memory_sort_key, + ) + selected_semantic_candidates = ordered_semantic_candidates[: semantic_retrieval.limit] + + merged_candidates: list[HybridMemoryCandidate] = [ + HybridMemoryCandidate(memory=memory, sources=["symbolic"]) + for memory in symbolic_candidates + ] + merged_candidate_ids = {candidate.memory["id"] for candidate in merged_candidates} + deduplication_decisions: list[CompilerDecision] = [] + deduplicated_count = 0 + + for position, semantic_candidate in enumerate(selected_semantic_candidates, start=1): + memory = active_memories_by_id.get(semantic_candidate["id"], semantic_candidate) + if semantic_candidate["id"] in merged_candidate_ids: + deduplicated_count += 1 + for candidate in merged_candidates: + if candidate.memory["id"] != semantic_candidate["id"]: + continue + if "semantic" not in candidate.sources: + candidate.sources.append("semantic") + candidate.semantic_score = float(semantic_candidate["score"]) + deduplication_decisions.append( + CompilerDecision( + "included", + "memory", + semantic_candidate["id"], + "hybrid_memory_deduplicated", + position, + metadata=_hybrid_memory_decision_metadata( + embedding_config_id=semantic_retrieval.embedding_config_id, + memory_key=candidate.memory["memory_key"], + status=candidate.memory["status"], + source_event_ids=candidate.memory["source_event_ids"], + selected_sources=candidate.sources, + semantic_score=candidate.semantic_score, + ), + ) + ) + break + continue + + merged_candidate_ids.add(semantic_candidate["id"]) + merged_candidates.append( + HybridMemoryCandidate( + memory=memory, + sources=["semantic"], + semantic_score=float(semantic_candidate["score"]), + ) + ) + + deleted_candidates = [ + HybridMemoryCandidate( + memory=memory, + sources=["symbolic"], + ) + for memory in sorted(deleted_memories, key=_semantic_deleted_memory_sort_key) + ] + + decisions = list(deduplication_decisions) + included_candidates = merged_candidates[: limits.max_memories] if limits.max_memories > 0 else [] + excluded_candidates = merged_candidates[limits.max_memories :] if limits.max_memories > 0 else merged_candidates + included_symbolic_only_count = 0 + included_semantic_only_count = 0 + included_dual_source_count = 0 + + for position, candidate in enumerate(merged_candidates, start=1): + if position <= limits.max_memories and limits.max_memories > 0: + if candidate.sources == ["symbolic"]: + included_symbolic_only_count += 1 + elif candidate.sources == ["semantic"]: + included_semantic_only_count += 1 + else: + 
included_dual_source_count += 1 + decisions.append( + CompilerDecision( + "included", + "memory", + candidate.memory["id"], + "within_hybrid_memory_limit", + position, + metadata=_hybrid_memory_decision_metadata( + embedding_config_id=semantic_retrieval.embedding_config_id, + memory_key=candidate.memory["memory_key"], + status=candidate.memory["status"], + source_event_ids=candidate.memory["source_event_ids"], + selected_sources=candidate.sources, + semantic_score=candidate.semantic_score, + ), + ) + ) + continue + + decisions.append( + CompilerDecision( + "excluded", + "memory", + candidate.memory["id"], + "hybrid_memory_limit_exceeded", + position, + metadata=_hybrid_memory_decision_metadata( + embedding_config_id=semantic_retrieval.embedding_config_id, + memory_key=candidate.memory["memory_key"], + status=candidate.memory["status"], + source_event_ids=candidate.memory["source_event_ids"], + selected_sources=candidate.sources, + semantic_score=candidate.semantic_score, + ), + ) + ) + + for position, candidate in enumerate(deleted_candidates, start=1): + decisions.append( + CompilerDecision( + "excluded", + "memory", + candidate.memory["id"], + "hybrid_memory_deleted", + position, + metadata=_hybrid_memory_decision_metadata( + embedding_config_id=semantic_retrieval.embedding_config_id, + memory_key=candidate.memory["memory_key"], + status=candidate.memory["status"], + source_event_ids=candidate.memory["source_event_ids"], + selected_sources=candidate.sources, + semantic_score=None, + ), + ) + ) + + return CompiledMemorySection( + items=[_serialize_hybrid_memory(candidate) for candidate in included_candidates], + summary={ + "candidate_count": len(merged_candidates) + len(deleted_candidates), + "included_count": len(included_candidates), + "excluded_deleted_count": len(deleted_candidates), + "excluded_limit_count": len(excluded_candidates), + "hybrid_retrieval": { + "requested": True, + "embedding_config_id": str(semantic_retrieval.embedding_config_id), + "query_vector_dimensions": len(query_vector), + "semantic_limit": semantic_retrieval.limit, + "symbolic_selected_count": len(symbolic_candidates), + "semantic_selected_count": len(selected_semantic_candidates), + "merged_candidate_count": len(merged_candidates), + "deduplicated_count": deduplicated_count, + "included_symbolic_only_count": included_symbolic_only_count, + "included_semantic_only_count": included_semantic_only_count, + "included_dual_source_count": included_dual_source_count, + "similarity_metric": "cosine_similarity", + "source_precedence": list(HYBRID_MEMORY_SOURCE_PRECEDENCE), + "symbolic_order": list(HYBRID_SYMBOLIC_ORDER), + "semantic_order": list(SEMANTIC_MEMORY_RETRIEVAL_ORDER), + }, + }, + decisions=decisions, + ) + + +def compile_continuity_context( + *, + user: UserRow, + thread: ThreadRow, + sessions: list[SessionRow], + events: list[EventRow], + memories: list[MemoryRow], + entities: list[EntityRow], + entity_edges: list[EntityEdgeRow], + limits: ContextCompilerLimits, + memory_section: CompiledMemorySection | None = None, +) -> CompilerRunResult: + latest_session_sequence: dict[UUID, int] = {} + for event in events: + session_id = event["session_id"] + if session_id is None: + continue + latest_session_sequence[session_id] = max( + latest_session_sequence.get(session_id, -1), + event["sequence_no"], + ) + + ordered_sessions = sorted( + sessions, + key=lambda session: _session_sort_key(session, latest_session_sequence), + ) + included_sessions = ordered_sessions[-limits.max_sessions :] if limits.max_sessions > 0 
else [] + included_session_ids = {session["id"] for session in included_sessions} + + decisions: list[CompilerDecision] = [ + CompilerDecision("included", "user", user["id"], "scope_user", 1), + CompilerDecision("included", "thread", thread["id"], "scope_thread", 1), + ] + + for position, session in enumerate(included_sessions, start=1): + decisions.append( + CompilerDecision( + "included", + "session", + session["id"], + "within_session_limit", + position, + ) + ) + + excluded_sessions = ordered_sessions[: max(len(ordered_sessions) - len(included_sessions), 0)] + for position, session in enumerate(excluded_sessions, start=1): + decisions.append( + CompilerDecision( + "excluded", + "session", + session["id"], + "session_limit_exceeded", + position, + ) + ) + + eligible_events: list[EventRow] = [] + for event in events: + if event["session_id"] is not None and event["session_id"] not in included_session_ids: + decisions.append( + CompilerDecision( + "excluded", + "event", + event["id"], + "session_not_included", + event["sequence_no"], + ) + ) + continue + eligible_events.append(event) + + included_events = eligible_events[-limits.max_events :] if limits.max_events > 0 else [] + included_event_ids = {event["id"] for event in included_events} + + for event in eligible_events: + if event["id"] in included_event_ids: + decisions.append( + CompilerDecision( + "included", + "event", + event["id"], + "within_event_limit", + event["sequence_no"], + ) + ) + continue + + decisions.append( + CompilerDecision( + "excluded", + "event", + event["id"], + "event_limit_exceeded", + event["sequence_no"], + ) + ) + + resolved_memory_section = memory_section or _build_symbolic_memory_section( + memories=memories, + limits=limits, + ) + decisions.extend(resolved_memory_section.decisions) + ordered_entities = sorted(entities, key=_entity_sort_key) + included_entities = ordered_entities[-limits.max_entities :] if limits.max_entities > 0 else [] + included_entity_ids = {entity["id"] for entity in included_entities} + excluded_entity_limit_count = max(len(ordered_entities) - len(included_entities), 0) + + for position, entity in enumerate(ordered_entities, start=1): + if entity["id"] in included_entity_ids: + decisions.append( + CompilerDecision( + "included", + "entity", + entity["id"], + "within_entity_limit", + position, + metadata={ + "record_entity_type": entity["entity_type"], + "name": entity["name"], + "source_memory_ids": entity["source_memory_ids"], + }, + ) + ) + continue + + decisions.append( + CompilerDecision( + "excluded", + "entity", + entity["id"], + "entity_limit_exceeded", + position, + metadata={ + "record_entity_type": entity["entity_type"], + "name": entity["name"], + "source_memory_ids": entity["source_memory_ids"], + }, + ) + ) + + ordered_candidate_entity_edges = sorted( + [ + edge + for edge in entity_edges + if edge["from_entity_id"] in included_entity_ids + or edge["to_entity_id"] in included_entity_ids + ], + key=_entity_edge_sort_key, + ) + included_entity_edges = ( + ordered_candidate_entity_edges[-limits.max_entity_edges :] + if limits.max_entity_edges > 0 + else [] + ) + included_entity_edge_ids = {edge["id"] for edge in included_entity_edges} + excluded_entity_edge_limit_count = max( + len(ordered_candidate_entity_edges) - len(included_entity_edges), + 0, + ) + + for position, edge in enumerate(ordered_candidate_entity_edges, start=1): + attached_included_entity_ids = [ + str(entity_id) + for entity_id in (edge["from_entity_id"], edge["to_entity_id"]) + if entity_id in 
included_entity_ids + ] + metadata = { + "from_entity_id": str(edge["from_entity_id"]), + "to_entity_id": str(edge["to_entity_id"]), + "relationship_type": edge["relationship_type"], + "valid_from": isoformat_or_none(edge["valid_from"]), + "valid_to": isoformat_or_none(edge["valid_to"]), + "source_memory_ids": edge["source_memory_ids"], + "attached_included_entity_ids": attached_included_entity_ids, + } + if edge["id"] in included_entity_edge_ids: + decisions.append( + CompilerDecision( + "included", + "entity_edge", + edge["id"], + "within_entity_edge_limit", + position, + metadata=metadata, + ) + ) + continue + + decisions.append( + CompilerDecision( + "excluded", + "entity_edge", + edge["id"], + "entity_edge_limit_exceeded", + position, + metadata=metadata, + ) + ) + + trace_events = [decision.to_trace_event() for decision in decisions] + trace_events.append( + TraceEventRecord( + kind=SUMMARY_TRACE_EVENT_KIND, + payload={ + "included_session_count": len(included_sessions), + "excluded_session_count": len(excluded_sessions), + "included_event_count": len(included_events), + "excluded_event_count": len(events) - len(included_events), + "included_memory_count": resolved_memory_section.summary["included_count"], + "excluded_memory_count": ( + resolved_memory_section.summary["excluded_deleted_count"] + + resolved_memory_section.summary["excluded_limit_count"] + ), + "excluded_deleted_memory_count": resolved_memory_section.summary[ + "excluded_deleted_count" + ], + "excluded_memory_limit_count": resolved_memory_section.summary[ + "excluded_limit_count" + ], + "hybrid_memory_requested": resolved_memory_section.summary["hybrid_retrieval"][ + "requested" + ], + "hybrid_memory_candidate_count": resolved_memory_section.summary["candidate_count"], + "hybrid_memory_merged_candidate_count": resolved_memory_section.summary[ + "hybrid_retrieval" + ]["merged_candidate_count"], + "hybrid_memory_deduplicated_count": resolved_memory_section.summary[ + "hybrid_retrieval" + ]["deduplicated_count"], + "included_dual_source_memory_count": resolved_memory_section.summary[ + "hybrid_retrieval" + ]["included_dual_source_count"], + "included_entity_count": len(included_entities), + "excluded_entity_count": excluded_entity_limit_count, + "excluded_entity_limit_count": excluded_entity_limit_count, + "included_entity_edge_count": len(included_entity_edges), + "excluded_entity_edge_count": excluded_entity_edge_limit_count, + "excluded_entity_edge_limit_count": excluded_entity_edge_limit_count, + "compiler_version": COMPILER_VERSION_V0, + }, + ) + ) + + return CompilerRunResult( + context_pack={ + "compiler_version": COMPILER_VERSION_V0, + "scope": { + "user_id": str(user["id"]), + "thread_id": str(thread["id"]), + }, + "limits": { + "max_sessions": limits.max_sessions, + "max_events": limits.max_events, + "max_memories": limits.max_memories, + "max_entities": limits.max_entities, + "max_entity_edges": limits.max_entity_edges, + }, + "user": _serialize_user(user), + "thread": _serialize_thread(thread), + "sessions": [_serialize_session(session) for session in included_sessions], + "events": [_serialize_event(event) for event in included_events], + "memories": list(resolved_memory_section.items), + "memory_summary": resolved_memory_section.summary, + "entities": [_serialize_entity(entity) for entity in included_entities], + "entity_summary": { + "candidate_count": len(ordered_entities), + "included_count": len(included_entities), + "excluded_limit_count": excluded_entity_limit_count, + }, + "entity_edges": 
[_serialize_entity_edge(edge) for edge in included_entity_edges], + "entity_edge_summary": { + "anchor_entity_count": len(included_entities), + "candidate_count": len(ordered_candidate_entity_edges), + "included_count": len(included_entity_edges), + "excluded_limit_count": excluded_entity_edge_limit_count, + }, + }, + trace_events=trace_events, + ) + + +def compile_and_persist_trace( + store: ContinuityStore, + *, + user_id: UUID, + thread_id: UUID, + limits: ContextCompilerLimits, + semantic_retrieval: CompileContextSemanticRetrievalInput | None = None, +) -> CompiledTraceRun: + user = store.get_user(user_id) + thread = store.get_thread(thread_id) + sessions = store.list_thread_sessions(thread_id) + events = store.list_thread_events(thread_id) + memories = store.list_context_memories() + memory_section = _compile_memory_section( + store, + memories=memories, + limits=limits, + semantic_retrieval=semantic_retrieval, + ) + entities = store.list_entities() + ordered_entities = sorted(entities, key=_entity_sort_key) + included_entities = ordered_entities[-limits.max_entities :] if limits.max_entities > 0 else [] + entity_edges = store.list_entity_edges_for_entities([entity["id"] for entity in included_entities]) + compiler_run = compile_continuity_context( + user=user, + thread=thread, + sessions=sessions, + events=events, + memories=memories, + entities=entities, + entity_edges=entity_edges, + limits=limits, + memory_section=memory_section, + ) + trace = store.create_trace( + user_id=user_id, + thread_id=thread_id, + kind=TRACE_KIND_CONTEXT_COMPILE, + compiler_version=COMPILER_VERSION_V0, + status="completed", + limits=limits.as_payload(), + ) + + for sequence_no, trace_event in enumerate(compiler_run.trace_events, start=1): + store.append_trace_event( + trace_id=trace["id"], + sequence_no=sequence_no, + kind=trace_event.kind, + payload=trace_event.payload, + ) + + return CompiledTraceRun( + trace_id=str(trace["id"]), + context_pack=compiler_run.context_pack, + trace_event_count=len(compiler_run.trace_events), + ) diff --git a/apps/api/src/alicebot_api/config.py b/apps/api/src/alicebot_api/config.py new file mode 100644 index 0000000..3f41bd7 --- /dev/null +++ b/apps/api/src/alicebot_api/config.py @@ -0,0 +1,119 @@ +from __future__ import annotations + +from collections.abc import Mapping +from dataclasses import dataclass +from functools import lru_cache +import os + +DEFAULT_APP_ENV = "development" +DEFAULT_APP_HOST = "127.0.0.1" +DEFAULT_APP_PORT = 8000 +DEFAULT_DATABASE_NAME = "alicebot" +DEFAULT_DATABASE_HOST = "localhost" +DEFAULT_DATABASE_PORT = 5432 +DEFAULT_DATABASE_URL = ( + f"postgresql://alicebot_app:alicebot_app@{DEFAULT_DATABASE_HOST}:" + f"{DEFAULT_DATABASE_PORT}/{DEFAULT_DATABASE_NAME}" +) +DEFAULT_DATABASE_ADMIN_URL = ( + f"postgresql://alicebot_admin:alicebot_admin@{DEFAULT_DATABASE_HOST}:" + f"{DEFAULT_DATABASE_PORT}/{DEFAULT_DATABASE_NAME}" +) +DEFAULT_REDIS_URL = f"redis://{DEFAULT_DATABASE_HOST}:6379/0" +DEFAULT_S3_ENDPOINT_URL = "http://localhost:9000" +DEFAULT_S3_ACCESS_KEY = "alicebot" +DEFAULT_S3_SECRET_KEY = "alicebot-secret" +DEFAULT_S3_BUCKET = "alicebot-local" +DEFAULT_HEALTHCHECK_TIMEOUT_SECONDS = 2 +DEFAULT_MODEL_PROVIDER = "openai_responses" +DEFAULT_MODEL_BASE_URL = "https://api.openai.com/v1" +DEFAULT_MODEL_NAME = "gpt-5-mini" +DEFAULT_MODEL_API_KEY = "" +DEFAULT_MODEL_TIMEOUT_SECONDS = 30 +DEFAULT_TASK_WORKSPACE_ROOT = "/tmp/alicebot/task-workspaces" + +Environment = Mapping[str, str] + + +def _get_env_value(env: Environment, key: str, default: str) -> 
str: + return env.get(key, default) + + +def _get_env_int(env: Environment, key: str, default: int) -> int: + raw_value = env.get(key) + if raw_value is None: + return default + + try: + return int(raw_value) + except ValueError as exc: + raise ValueError(f"{key} must be an integer") from exc + + +@dataclass(frozen=True) +class Settings: + app_env: str = DEFAULT_APP_ENV + app_host: str = DEFAULT_APP_HOST + app_port: int = DEFAULT_APP_PORT + database_url: str = DEFAULT_DATABASE_URL + database_admin_url: str = DEFAULT_DATABASE_ADMIN_URL + redis_url: str = DEFAULT_REDIS_URL + s3_endpoint_url: str = DEFAULT_S3_ENDPOINT_URL + s3_access_key: str = DEFAULT_S3_ACCESS_KEY + s3_secret_key: str = DEFAULT_S3_SECRET_KEY + s3_bucket: str = DEFAULT_S3_BUCKET + healthcheck_timeout_seconds: int = DEFAULT_HEALTHCHECK_TIMEOUT_SECONDS + model_provider: str = DEFAULT_MODEL_PROVIDER + model_base_url: str = DEFAULT_MODEL_BASE_URL + model_name: str = DEFAULT_MODEL_NAME + model_api_key: str = DEFAULT_MODEL_API_KEY + model_timeout_seconds: int = DEFAULT_MODEL_TIMEOUT_SECONDS + task_workspace_root: str = DEFAULT_TASK_WORKSPACE_ROOT + + @classmethod + def from_env(cls, env: Environment | None = None) -> "Settings": + current_env = os.environ if env is None else env + return cls( + app_env=_get_env_value(current_env, "APP_ENV", cls.app_env), + app_host=_get_env_value(current_env, "APP_HOST", cls.app_host), + app_port=_get_env_int(current_env, "APP_PORT", cls.app_port), + database_url=_get_env_value(current_env, "DATABASE_URL", cls.database_url), + database_admin_url=_get_env_value( + current_env, + "DATABASE_ADMIN_URL", + cls.database_admin_url, + ), + redis_url=_get_env_value(current_env, "REDIS_URL", cls.redis_url), + s3_endpoint_url=_get_env_value( + current_env, + "S3_ENDPOINT_URL", + cls.s3_endpoint_url, + ), + s3_access_key=_get_env_value(current_env, "S3_ACCESS_KEY", cls.s3_access_key), + s3_secret_key=_get_env_value(current_env, "S3_SECRET_KEY", cls.s3_secret_key), + s3_bucket=_get_env_value(current_env, "S3_BUCKET", cls.s3_bucket), + healthcheck_timeout_seconds=_get_env_int( + current_env, + "HEALTHCHECK_TIMEOUT_SECONDS", + cls.healthcheck_timeout_seconds, + ), + model_provider=_get_env_value(current_env, "MODEL_PROVIDER", cls.model_provider), + model_base_url=_get_env_value(current_env, "MODEL_BASE_URL", cls.model_base_url), + model_name=_get_env_value(current_env, "MODEL_NAME", cls.model_name), + model_api_key=_get_env_value(current_env, "MODEL_API_KEY", cls.model_api_key), + model_timeout_seconds=_get_env_int( + current_env, + "MODEL_TIMEOUT_SECONDS", + cls.model_timeout_seconds, + ), + task_workspace_root=_get_env_value( + current_env, + "TASK_WORKSPACE_ROOT", + cls.task_workspace_root, + ), + ) + + +@lru_cache(maxsize=1) +def get_settings() -> Settings: + return Settings.from_env() diff --git a/apps/api/src/alicebot_api/contracts.py b/apps/api/src/alicebot_api/contracts.py new file mode 100644 index 0000000..fc794c2 --- /dev/null +++ b/apps/api/src/alicebot_api/contracts.py @@ -0,0 +1,2080 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from datetime import datetime +from typing import Literal, NotRequired, TypedDict +from uuid import UUID + +from alicebot_api.store import JsonObject, JsonValue + +DecisionKind = Literal["included", "excluded"] +AdmissionAction = Literal["NOOP", "ADD", "UPDATE", "DELETE"] +MemoryStatus = Literal["active", "deleted"] +MemoryReviewStatusFilter = Literal["active", "deleted", "all"] +MemoryReviewLabelValue = Literal["correct", "incorrect", 
"outdated", "insufficient_evidence"] +EntityType = Literal["person", "merchant", "product", "project", "routine"] +EmbeddingConfigStatus = Literal["active", "deprecated", "disabled"] +ConsentStatus = Literal["granted", "revoked"] +ApprovalStatus = Literal["pending", "approved", "rejected"] +ApprovalResolutionAction = Literal["approve", "reject"] +ApprovalResolutionOutcome = Literal["resolved", "duplicate_rejected", "conflict_rejected"] +TaskStatus = Literal["pending_approval", "approved", "executed", "denied", "blocked"] +TaskWorkspaceStatus = Literal["active"] +TaskLifecycleSource = Literal[ + "approval_request", + "approval_resolution", + "proxy_execution", + "task_step_continuation", + "task_step_sequence", + "task_step_transition", +] +TaskStepKind = Literal["governed_request"] +TaskStepStatus = Literal["created", "approved", "executed", "blocked", "denied"] +ProxyExecutionStatus = Literal["completed", "blocked"] +ExecutionBudgetStatus = Literal["active", "inactive", "superseded"] +ExecutionBudgetDecision = Literal["allow", "block"] +ExecutionBudgetDecisionReason = Literal["no_matching_budget", "within_budget", "budget_exceeded"] +ExecutionBudgetCountScope = Literal["lifetime", "rolling_window"] +ExecutionBudgetLifecycleAction = Literal["deactivate", "supersede"] +ExecutionBudgetLifecycleOutcome = Literal["deactivated", "superseded", "rejected"] +PolicyEffect = Literal["allow", "deny", "require_approval"] +PolicyEvaluationReasonCode = Literal[ + "matched_policy", + "policy_effect_allow", + "policy_effect_deny", + "policy_effect_require_approval", + "consent_missing", + "consent_revoked", + "no_matching_policy", +] +ToolMetadataVersion = Literal["tool_metadata_v0"] +ToolAllowlistReasonCode = Literal[ + "tool_metadata_matched", + "tool_action_unsupported", + "tool_scope_unsupported", + "tool_domain_mismatch", + "tool_risk_mismatch", + "matched_policy", + "policy_effect_allow", + "policy_effect_deny", + "policy_effect_require_approval", + "consent_missing", + "consent_revoked", + "no_matching_policy", +] +ToolAllowlistDecision = Literal["allowed", "denied", "approval_required"] +ToolRoutingDecision = Literal["ready", "denied", "approval_required"] +PromptSectionName = Literal["system", "developer", "context", "conversation"] +ModelProvider = Literal["openai_responses"] +ModelFinishReason = Literal["completed", "incomplete"] +ExplicitPreferencePattern = Literal[ + "i_like", + "i_dont_like", + "i_prefer", + "remember_that_i_like", + "remember_that_i_dont_like", + "remember_that_i_prefer", +] +MemorySelectionSource = Literal["symbolic", "semantic"] + +DEFAULT_MAX_SESSIONS = 3 +DEFAULT_MAX_EVENTS = 8 +DEFAULT_MAX_MEMORIES = 5 +DEFAULT_MAX_ENTITIES = 5 +DEFAULT_MAX_ENTITY_EDGES = 10 +DEFAULT_MEMORY_REVIEW_LIMIT = 20 +MAX_MEMORY_REVIEW_LIMIT = 100 +DEFAULT_SEMANTIC_MEMORY_RETRIEVAL_LIMIT = 5 +MAX_SEMANTIC_MEMORY_RETRIEVAL_LIMIT = 50 +COMPILER_VERSION_V0 = "continuity_v0" +PROMPT_ASSEMBLY_VERSION_V0 = "prompt_assembly_v0" +RESPONSE_GENERATION_VERSION_V0 = "response_generation_v0" +TRACE_KIND_CONTEXT_COMPILE = "context.compile" +TRACE_KIND_RESPONSE_GENERATE = "response.generate" +MEMORY_REVIEW_ORDER = ["updated_at_desc", "created_at_desc", "id_desc"] +MEMORY_REVIEW_QUEUE_ORDER = ["updated_at_desc", "created_at_desc", "id_desc"] +MEMORY_REVISION_REVIEW_ORDER = ["sequence_no_asc"] +MEMORY_REVIEW_LABEL_VALUES = [ + "correct", + "incorrect", + "outdated", + "insufficient_evidence", +] +MEMORY_REVIEW_LABEL_ORDER = ["created_at_asc", "id_asc"] +ENTITY_TYPES = [ + "person", + "merchant", + "product", + 
"project", + "routine", +] +ENTITY_LIST_ORDER = ["created_at_asc", "id_asc"] +ENTITY_EDGE_LIST_ORDER = ["created_at_asc", "id_asc"] +EMBEDDING_CONFIG_LIST_ORDER = ["created_at_asc", "id_asc"] +MEMORY_EMBEDDING_LIST_ORDER = ["created_at_asc", "id_asc"] +SEMANTIC_MEMORY_RETRIEVAL_ORDER = ["score_desc", "created_at_asc", "id_asc"] +EMBEDDING_CONFIG_STATUSES = ["active", "deprecated", "disabled"] +CONSENT_STATUSES = ["granted", "revoked"] +CONSENT_LIST_ORDER = ["consent_key_asc", "created_at_asc", "id_asc"] +POLICY_EFFECTS = ["allow", "deny", "require_approval"] +POLICY_LIST_ORDER = ["priority_asc", "created_at_asc", "id_asc"] +POLICY_EVALUATION_VERSION_V0 = "policy_evaluation_v0" +TRACE_KIND_POLICY_EVALUATE = "policy.evaluate" +TOOL_METADATA_VERSION_V0 = "tool_metadata_v0" +TOOL_LIST_ORDER = ["tool_key_asc", "version_asc", "created_at_asc", "id_asc"] +TOOL_ALLOWLIST_EVALUATION_VERSION_V0 = "tool_allowlist_evaluation_v0" +TRACE_KIND_TOOL_ALLOWLIST_EVALUATE = "tool.allowlist.evaluate" +TOOL_ROUTING_VERSION_V0 = "tool_routing_v0" +TRACE_KIND_TOOL_ROUTE = "tool.route" +APPROVAL_LIST_ORDER = ["created_at_asc", "id_asc"] +TASK_LIST_ORDER = ["created_at_asc", "id_asc"] +TASK_WORKSPACE_LIST_ORDER = ["created_at_asc", "id_asc"] +TASK_STEP_LIST_ORDER = ["sequence_no_asc", "created_at_asc", "id_asc"] +TOOL_EXECUTION_LIST_ORDER = ["executed_at_asc", "id_asc"] +EXECUTION_BUDGET_LIST_ORDER = ["created_at_asc", "id_asc"] +EXECUTION_BUDGET_MATCH_ORDER = ["specificity_desc", "created_at_asc", "id_asc"] +EXECUTION_BUDGET_STATUSES = ["active", "inactive", "superseded"] +TASK_STATUSES = ["pending_approval", "approved", "executed", "denied", "blocked"] +TASK_WORKSPACE_STATUSES = ["active"] +TASK_STEP_KINDS = ["governed_request"] +TASK_STEP_STATUSES = ["created", "approved", "executed", "blocked", "denied"] +APPROVAL_REQUEST_VERSION_V0 = "approval_request_v0" +TRACE_KIND_APPROVAL_REQUEST = "approval.request" +APPROVAL_RESOLUTION_VERSION_V0 = "approval_resolution_v0" +TRACE_KIND_APPROVAL_RESOLUTION = "approval.resolve" +TRACE_KIND_APPROVAL_RESOLVE = TRACE_KIND_APPROVAL_RESOLUTION +PROXY_EXECUTION_VERSION_V0 = "proxy_execution_v0" +TRACE_KIND_PROXY_EXECUTE = "tool.proxy.execute" +TASK_STEP_SEQUENCE_VERSION_V0 = "task_step_sequence_v0" +TRACE_KIND_TASK_STEP_SEQUENCE = "task.step.sequence" +TASK_STEP_CONTINUATION_VERSION_V0 = "task_step_continuation_v0" +TRACE_KIND_TASK_STEP_CONTINUATION = "task.step.continuation" +TASK_STEP_TRANSITION_VERSION_V0 = "task_step_transition_v0" +TRACE_KIND_TASK_STEP_TRANSITION = "task.step.transition" +EXECUTION_BUDGET_LIFECYCLE_VERSION_V0 = "execution_budget_lifecycle_v0" +TRACE_KIND_EXECUTION_BUDGET_LIFECYCLE = "execution_budget.lifecycle" + + +@dataclass(frozen=True, slots=True) +class ContextCompilerLimits: + max_sessions: int = DEFAULT_MAX_SESSIONS + max_events: int = DEFAULT_MAX_EVENTS + max_memories: int = DEFAULT_MAX_MEMORIES + max_entities: int = DEFAULT_MAX_ENTITIES + max_entity_edges: int = DEFAULT_MAX_ENTITY_EDGES + + def as_payload(self) -> JsonObject: + return { + "max_sessions": self.max_sessions, + "max_events": self.max_events, + "max_memories": self.max_memories, + "max_entities": self.max_entities, + "max_entity_edges": self.max_entity_edges, + } + + +@dataclass(frozen=True, slots=True) +class CompileContextSemanticRetrievalInput: + embedding_config_id: UUID + query_vector: tuple[float, ...] 
+ limit: int = DEFAULT_SEMANTIC_MEMORY_RETRIEVAL_LIMIT + + def as_payload(self) -> JsonObject: + return { + "embedding_config_id": str(self.embedding_config_id), + "query_vector": [float(value) for value in self.query_vector], + "limit": self.limit, + } + + +@dataclass(frozen=True, slots=True) +class TraceCreate: + user_id: UUID + thread_id: UUID + kind: str + compiler_version: str + status: str + limits: ContextCompilerLimits + + +@dataclass(frozen=True, slots=True) +class TraceEventRecord: + kind: str + payload: JsonObject + + +@dataclass(frozen=True, slots=True) +class CompilerDecision: + kind: DecisionKind + entity_type: str + entity_id: UUID + reason: str + position: int + metadata: JsonObject | None = None + + def to_trace_event(self) -> TraceEventRecord: + payload: JsonObject = { + "entity_type": self.entity_type, + "entity_id": str(self.entity_id), + "reason": self.reason, + "position": self.position, + } + if self.metadata is not None: + payload.update(self.metadata) + return TraceEventRecord(kind=f"context.{self.kind}", payload=payload) + + +class ContextPackScope(TypedDict): + user_id: str + thread_id: str + + +class ContextPackLimits(TypedDict): + max_sessions: int + max_events: int + max_memories: int + max_entities: int + max_entity_edges: int + + +class ContextPackUser(TypedDict): + id: str + email: str + display_name: str | None + created_at: str + + +class ContextPackThread(TypedDict): + id: str + title: str + created_at: str + updated_at: str + + +class ContextPackSession(TypedDict): + id: str + status: str + started_at: str | None + ended_at: str | None + created_at: str + + +class ContextPackEvent(TypedDict): + id: str + session_id: str | None + sequence_no: int + kind: str + payload: JsonObject + created_at: str + + +class ContextPackMemory(TypedDict): + id: str + memory_key: str + value: JsonValue + status: MemoryStatus + source_event_ids: list[str] + created_at: str + updated_at: str + source_provenance: "ContextPackMemorySourceProvenance" + + +class ContextPackMemorySourceProvenance(TypedDict): + sources: list[MemorySelectionSource] + semantic_score: float | None + + +class ContextPackHybridMemorySummary(TypedDict): + requested: bool + embedding_config_id: str | None + query_vector_dimensions: int + semantic_limit: int + symbolic_selected_count: int + semantic_selected_count: int + merged_candidate_count: int + deduplicated_count: int + included_symbolic_only_count: int + included_semantic_only_count: int + included_dual_source_count: int + similarity_metric: Literal["cosine_similarity"] | None + source_precedence: list[MemorySelectionSource] + symbolic_order: list[str] + semantic_order: list[str] + + +class ContextPackMemorySummary(TypedDict): + candidate_count: int + included_count: int + excluded_deleted_count: int + excluded_limit_count: int + hybrid_retrieval: ContextPackHybridMemorySummary + + +class HybridMemoryDecisionTracePayload(TypedDict): + embedding_config_id: str | None + memory_key: str + status: MemoryStatus + source_event_ids: list[str] + selected_sources: list[MemorySelectionSource] + semantic_score: float | None + + +class ContextPackEntity(TypedDict): + id: str + entity_type: EntityType + name: str + source_memory_ids: list[str] + created_at: str + + +class ContextPackEntitySummary(TypedDict): + candidate_count: int + included_count: int + excluded_limit_count: int + + +class EntityDecisionTracePayload(TypedDict): + entity_type: str + entity_id: str + reason: str + position: int + record_entity_type: EntityType + name: str + source_memory_ids: 
list[str] + + +class ContextPackEntityEdge(TypedDict): + id: str + from_entity_id: str + to_entity_id: str + relationship_type: str + valid_from: str | None + valid_to: str | None + source_memory_ids: list[str] + created_at: str + + +class ContextPackEntityEdgeSummary(TypedDict): + anchor_entity_count: int + candidate_count: int + included_count: int + excluded_limit_count: int + + +class EntityEdgeDecisionTracePayload(TypedDict): + entity_type: str + entity_id: str + reason: str + position: int + from_entity_id: str + to_entity_id: str + relationship_type: str + valid_from: str | None + valid_to: str | None + source_memory_ids: list[str] + attached_included_entity_ids: list[str] + + +class CompiledContextPack(TypedDict): + compiler_version: str + scope: ContextPackScope + limits: ContextPackLimits + user: ContextPackUser + thread: ContextPackThread + sessions: list[ContextPackSession] + events: list[ContextPackEvent] + memories: list[ContextPackMemory] + memory_summary: ContextPackMemorySummary + entities: list[ContextPackEntity] + entity_summary: ContextPackEntitySummary + entity_edges: list[ContextPackEntityEdge] + entity_edge_summary: ContextPackEntityEdgeSummary + + +@dataclass(frozen=True, slots=True) +class CompilerRunResult: + context_pack: CompiledContextPack + trace_events: list[TraceEventRecord] + + +@dataclass(frozen=True, slots=True) +class PromptAssemblyInput: + context_pack: CompiledContextPack + system_instruction: str + developer_instruction: str + + +@dataclass(frozen=True, slots=True) +class PromptSection: + name: PromptSectionName + content: str + + +class PromptAssemblyTracePayload(TypedDict): + version: str + compile_trace_id: str + compiler_version: str + prompt_sha256: str + prompt_char_count: int + section_order: list[PromptSectionName] + section_characters: dict[PromptSectionName, int] + included_session_count: int + included_event_count: int + included_memory_count: int + included_entity_count: int + included_entity_edge_count: int + + +@dataclass(frozen=True, slots=True) +class PromptAssemblyResult: + sections: tuple[PromptSection, ...] 
+ prompt_text: str + prompt_sha256: str + trace_payload: PromptAssemblyTracePayload + + +class ModelInvocationRequestPayload(TypedDict): + provider: ModelProvider + model: str + tool_choice: Literal["none"] + tools: list[JsonObject] + store: bool + sections: list[PromptSectionName] + prompt: str + + +@dataclass(frozen=True, slots=True) +class ModelInvocationRequest: + provider: ModelProvider + model: str + prompt: PromptAssemblyResult + tool_choice: Literal["none"] = "none" + store: bool = False + + def as_payload(self) -> ModelInvocationRequestPayload: + return { + "provider": self.provider, + "model": self.model, + "tool_choice": self.tool_choice, + "tools": [], + "store": self.store, + "sections": [section.name for section in self.prompt.sections], + "prompt": self.prompt.prompt_text, + } + + +class ModelUsagePayload(TypedDict): + input_tokens: int | None + output_tokens: int | None + total_tokens: int | None + + +class ModelInvocationTracePayload(TypedDict): + provider: ModelProvider + model: str + tool_choice: Literal["none"] + tools_enabled: Literal[False] + response_id: str | None + finish_reason: ModelFinishReason + output_text_char_count: int + usage: ModelUsagePayload + error_message: str | None + + +@dataclass(frozen=True, slots=True) +class ModelInvocationResponse: + provider: ModelProvider + model: str + response_id: str | None + finish_reason: ModelFinishReason + output_text: str + usage: ModelUsagePayload + + def to_trace_payload(self, *, error_message: str | None = None) -> ModelInvocationTracePayload: + return { + "provider": self.provider, + "model": self.model, + "tool_choice": "none", + "tools_enabled": False, + "response_id": self.response_id, + "finish_reason": self.finish_reason, + "output_text_char_count": len(self.output_text), + "usage": self.usage, + "error_message": error_message, + } + + +class AssistantResponseModelRecord(TypedDict): + provider: ModelProvider + model: str + response_id: str | None + finish_reason: ModelFinishReason + usage: ModelUsagePayload + + +class AssistantResponsePromptRecord(TypedDict): + assembly_version: str + prompt_sha256: str + section_order: list[PromptSectionName] + + +class AssistantResponseEventPayload(TypedDict): + text: str + model: AssistantResponseModelRecord + prompt: AssistantResponsePromptRecord + + +class GeneratedAssistantRecord(TypedDict): + event_id: str + sequence_no: int + text: str + model_provider: ModelProvider + model: str + + +class ResponseTraceSummary(TypedDict): + compile_trace_id: str + compile_trace_event_count: int + response_trace_id: str + response_trace_event_count: int + + +class GenerateResponseSuccess(TypedDict): + assistant: GeneratedAssistantRecord + trace: ResponseTraceSummary + + +@dataclass(frozen=True, slots=True) +class MemoryCandidateInput: + memory_key: str + value: JsonValue | None + source_event_ids: tuple[UUID, ...] 
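+ # Editor note (assumption): `delete_requested` below appears to flag a candidate that asks the
+ # admission pipeline to retire an existing memory instead of writing a new value; the actual
+ # tombstoning behaviour presumably lives in memory.py, not in this contract module.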
+ delete_requested: bool = False + + def as_payload(self) -> JsonObject: + payload: JsonObject = { + "memory_key": self.memory_key, + "source_event_ids": [str(source_event_id) for source_event_id in self.source_event_ids], + "delete_requested": self.delete_requested, + } + payload["value"] = self.value + return payload + + +@dataclass(frozen=True, slots=True) +class ExplicitPreferenceExtractionRequestInput: + source_event_id: UUID + + def as_payload(self) -> JsonObject: + return { + "source_event_id": str(self.source_event_id), + } + + +class ExtractedPreferenceCandidateRecord(TypedDict): + memory_key: str + value: JsonValue + source_event_ids: list[str] + delete_requested: bool + pattern: ExplicitPreferencePattern + subject_text: str + + +@dataclass(frozen=True, slots=True) +class EntityCreateInput: + entity_type: EntityType + name: str + source_memory_ids: tuple[UUID, ...] + + def as_payload(self) -> JsonObject: + return { + "entity_type": self.entity_type, + "name": self.name, + "source_memory_ids": [str(source_memory_id) for source_memory_id in self.source_memory_ids], + } + + +@dataclass(frozen=True, slots=True) +class EntityEdgeCreateInput: + from_entity_id: UUID + to_entity_id: UUID + relationship_type: str + valid_from: datetime | None + valid_to: datetime | None + source_memory_ids: tuple[UUID, ...] + + def as_payload(self) -> JsonObject: + payload: JsonObject = { + "from_entity_id": str(self.from_entity_id), + "to_entity_id": str(self.to_entity_id), + "relationship_type": self.relationship_type, + "source_memory_ids": [str(source_memory_id) for source_memory_id in self.source_memory_ids], + } + payload["valid_from"] = isoformat_or_none(self.valid_from) + payload["valid_to"] = isoformat_or_none(self.valid_to) + return payload + + +@dataclass(frozen=True, slots=True) +class EmbeddingConfigCreateInput: + provider: str + model: str + version: str + dimensions: int + status: EmbeddingConfigStatus + metadata: JsonObject + + def as_payload(self) -> JsonObject: + return { + "provider": self.provider, + "model": self.model, + "version": self.version, + "dimensions": self.dimensions, + "status": self.status, + "metadata": self.metadata, + } + + +@dataclass(frozen=True, slots=True) +class MemoryEmbeddingUpsertInput: + memory_id: UUID + embedding_config_id: UUID + vector: tuple[float, ...] + + def as_payload(self) -> JsonObject: + return { + "memory_id": str(self.memory_id), + "embedding_config_id": str(self.embedding_config_id), + "vector": [float(value) for value in self.vector], + } + + +@dataclass(frozen=True, slots=True) +class SemanticMemoryRetrievalRequestInput: + embedding_config_id: UUID + query_vector: tuple[float, ...] + limit: int = DEFAULT_SEMANTIC_MEMORY_RETRIEVAL_LIMIT + + def as_payload(self) -> JsonObject: + return { + "embedding_config_id": str(self.embedding_config_id), + "query_vector": [float(value) for value in self.query_vector], + "limit": self.limit, + } + + +@dataclass(frozen=True, slots=True) +class ConsentUpsertInput: + consent_key: str + status: ConsentStatus + metadata: JsonObject + + def as_payload(self) -> JsonObject: + return { + "consent_key": self.consent_key, + "status": self.status, + "metadata": self.metadata, + } + + +@dataclass(frozen=True, slots=True) +class PolicyCreateInput: + name: str + action: str + scope: str + effect: PolicyEffect + priority: int + active: bool + conditions: JsonObject + required_consents: tuple[str, ...] 
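+ # Editor note (assumption): `required_consents` above likely lists consent keys that must be in the
+ # "granted" state for the policy to apply; the consent_missing / consent_revoked reason codes defined
+ # earlier suggest that check happens during policy evaluation, not in this contract.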
+ + def as_payload(self) -> JsonObject: + return { + "name": self.name, + "action": self.action, + "scope": self.scope, + "effect": self.effect, + "priority": self.priority, + "active": self.active, + "conditions": self.conditions, + "required_consents": list(self.required_consents), + } + + +@dataclass(frozen=True, slots=True) +class PolicyEvaluationRequestInput: + thread_id: UUID + action: str + scope: str + attributes: JsonObject + + def as_payload(self) -> JsonObject: + return { + "thread_id": str(self.thread_id), + "action": self.action, + "scope": self.scope, + "attributes": self.attributes, + } + + +@dataclass(frozen=True, slots=True) +class ToolCreateInput: + tool_key: str + name: str + description: str + version: str + metadata_version: ToolMetadataVersion = TOOL_METADATA_VERSION_V0 + active: bool = True + tags: tuple[str, ...] = field(default_factory=tuple) + action_hints: tuple[str, ...] = field(default_factory=tuple) + scope_hints: tuple[str, ...] = field(default_factory=tuple) + domain_hints: tuple[str, ...] = field(default_factory=tuple) + risk_hints: tuple[str, ...] = field(default_factory=tuple) + metadata: JsonObject = field(default_factory=dict) + + def as_payload(self) -> JsonObject: + return { + "tool_key": self.tool_key, + "name": self.name, + "description": self.description, + "version": self.version, + "metadata_version": self.metadata_version, + "active": self.active, + "tags": list(self.tags), + "action_hints": list(self.action_hints), + "scope_hints": list(self.scope_hints), + "domain_hints": list(self.domain_hints), + "risk_hints": list(self.risk_hints), + "metadata": self.metadata, + } + + +@dataclass(frozen=True, slots=True) +class ToolAllowlistEvaluationRequestInput: + thread_id: UUID + action: str + scope: str + domain_hint: str | None = None + risk_hint: str | None = None + attributes: JsonObject = field(default_factory=dict) + + def as_payload(self) -> JsonObject: + payload: JsonObject = { + "thread_id": str(self.thread_id), + "action": self.action, + "scope": self.scope, + "attributes": self.attributes, + } + payload["domain_hint"] = self.domain_hint + payload["risk_hint"] = self.risk_hint + return payload + + +@dataclass(frozen=True, slots=True) +class ToolRoutingRequestInput: + thread_id: UUID + tool_id: UUID + action: str + scope: str + domain_hint: str | None = None + risk_hint: str | None = None + attributes: JsonObject = field(default_factory=dict) + + def as_payload(self) -> JsonObject: + payload: JsonObject = { + "thread_id": str(self.thread_id), + "tool_id": str(self.tool_id), + "action": self.action, + "scope": self.scope, + "attributes": self.attributes, + } + payload["domain_hint"] = self.domain_hint + payload["risk_hint"] = self.risk_hint + return payload + + +@dataclass(frozen=True, slots=True) +class ApprovalRequestCreateInput: + thread_id: UUID + tool_id: UUID + action: str + scope: str + domain_hint: str | None = None + risk_hint: str | None = None + attributes: JsonObject = field(default_factory=dict) + + def as_payload(self) -> JsonObject: + payload: JsonObject = { + "thread_id": str(self.thread_id), + "tool_id": str(self.tool_id), + "action": self.action, + "scope": self.scope, + "attributes": self.attributes, + } + payload["domain_hint"] = self.domain_hint + payload["risk_hint"] = self.risk_hint + return payload + + +@dataclass(frozen=True, slots=True) +class ApprovalApproveInput: + approval_id: UUID + + def as_payload(self) -> JsonObject: + return { + "approval_id": str(self.approval_id), + "requested_action": "approve", + } + + 
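+ # Editor sketch (illustrative only, not part of the runtime contract): the approval resolution inputs
+ # on either side of this note serialize to minimal payloads, e.g.
+ #   ApprovalApproveInput(approval_id=some_uuid).as_payload()
+ #   -> {"approval_id": "<uuid>", "requested_action": "approve"}
+ # while the reject variant below emits "requested_action": "reject" for the same approval_id shape.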
+@dataclass(frozen=True, slots=True) +class ApprovalRejectInput: + approval_id: UUID + + def as_payload(self) -> JsonObject: + return { + "approval_id": str(self.approval_id), + "requested_action": "reject", + } + + +@dataclass(frozen=True, slots=True) +class ProxyExecutionRequestInput: + approval_id: UUID + + def as_payload(self) -> JsonObject: + return { + "approval_id": str(self.approval_id), + } + + +@dataclass(frozen=True, slots=True) +class ExecutionBudgetCreateInput: + max_completed_executions: int + tool_key: str | None = None + domain_hint: str | None = None + rolling_window_seconds: int | None = None + + def as_payload(self) -> JsonObject: + payload: JsonObject = { + "max_completed_executions": self.max_completed_executions, + } + payload["tool_key"] = self.tool_key + payload["domain_hint"] = self.domain_hint + payload["rolling_window_seconds"] = self.rolling_window_seconds + return payload + + +@dataclass(frozen=True, slots=True) +class ExecutionBudgetDeactivateInput: + thread_id: UUID + execution_budget_id: UUID + + def as_payload(self) -> JsonObject: + return { + "thread_id": str(self.thread_id), + "execution_budget_id": str(self.execution_budget_id), + "requested_action": "deactivate", + } + + +@dataclass(frozen=True, slots=True) +class ExecutionBudgetSupersedeInput: + thread_id: UUID + execution_budget_id: UUID + max_completed_executions: int + + def as_payload(self) -> JsonObject: + return { + "thread_id": str(self.thread_id), + "execution_budget_id": str(self.execution_budget_id), + "requested_action": "supersede", + "max_completed_executions": self.max_completed_executions, + } + + +class PersistedMemoryRecord(TypedDict): + id: str + user_id: str + memory_key: str + value: JsonValue + status: MemoryStatus + source_event_ids: list[str] + created_at: str + updated_at: str + deleted_at: str | None + + +class PersistedMemoryRevisionRecord(TypedDict): + id: str + user_id: str + memory_id: str + sequence_no: int + action: AdmissionAction + memory_key: str + previous_value: JsonValue | None + new_value: JsonValue | None + source_event_ids: list[str] + candidate: JsonObject + created_at: str + + +@dataclass(frozen=True, slots=True) +class AdmissionDecisionOutput: + action: AdmissionAction + reason: str + memory: PersistedMemoryRecord | None + revision: PersistedMemoryRevisionRecord | None + + +class ExplicitPreferenceAdmissionRecord(TypedDict): + decision: AdmissionAction + reason: str + memory: PersistedMemoryRecord | None + revision: PersistedMemoryRevisionRecord | None + + +class ExplicitPreferenceExtractionSummary(TypedDict): + source_event_id: str + source_event_kind: str + candidate_count: int + admission_count: int + persisted_change_count: int + noop_count: int + + +class ExplicitPreferenceExtractionResponse(TypedDict): + candidates: list[ExtractedPreferenceCandidateRecord] + admissions: list[ExplicitPreferenceAdmissionRecord] + summary: ExplicitPreferenceExtractionSummary + + +class MemoryReviewRecord(TypedDict): + id: str + memory_key: str + value: JsonValue + status: MemoryStatus + source_event_ids: list[str] + created_at: str + updated_at: str + deleted_at: str | None + + +class MemoryReviewListSummary(TypedDict): + status: MemoryReviewStatusFilter + limit: int + returned_count: int + total_count: int + has_more: bool + order: list[str] + + +class MemoryReviewListResponse(TypedDict): + items: list[MemoryReviewRecord] + summary: MemoryReviewListSummary + + +class MemoryReviewDetailResponse(TypedDict): + memory: MemoryReviewRecord + + +class 
MemoryRevisionReviewRecord(TypedDict): + id: str + memory_id: str + sequence_no: int + action: AdmissionAction + memory_key: str + previous_value: JsonValue | None + new_value: JsonValue | None + source_event_ids: list[str] + created_at: str + + +class MemoryRevisionReviewListSummary(TypedDict): + memory_id: str + limit: int + returned_count: int + total_count: int + has_more: bool + order: list[str] + + +class MemoryRevisionReviewListResponse(TypedDict): + items: list[MemoryRevisionReviewRecord] + summary: MemoryRevisionReviewListSummary + + +class MemoryReviewLabelCounts(TypedDict): + correct: int + incorrect: int + outdated: int + insufficient_evidence: int + + +class MemoryReviewLabelRecord(TypedDict): + id: str + memory_id: str + reviewer_user_id: str + label: MemoryReviewLabelValue + note: str | None + created_at: str + + +class MemoryReviewLabelSummary(TypedDict): + memory_id: str + total_count: int + counts_by_label: MemoryReviewLabelCounts + order: list[str] + + +class MemoryReviewLabelCreateResponse(TypedDict): + label: MemoryReviewLabelRecord + summary: MemoryReviewLabelSummary + + +class MemoryReviewLabelListResponse(TypedDict): + items: list[MemoryReviewLabelRecord] + summary: MemoryReviewLabelSummary + + +class MemoryReviewQueueItem(TypedDict): + id: str + memory_key: str + value: JsonValue + status: Literal["active"] + source_event_ids: list[str] + created_at: str + updated_at: str + + +class MemoryReviewQueueSummary(TypedDict): + memory_status: Literal["active"] + review_state: Literal["unlabeled"] + limit: int + returned_count: int + total_count: int + has_more: bool + order: list[str] + + +class MemoryReviewQueueResponse(TypedDict): + items: list[MemoryReviewQueueItem] + summary: MemoryReviewQueueSummary + + +class MemoryEvaluationSummary(TypedDict): + total_memory_count: int + active_memory_count: int + deleted_memory_count: int + labeled_memory_count: int + unlabeled_memory_count: int + total_label_row_count: int + label_row_counts_by_value: MemoryReviewLabelCounts + label_value_order: list[MemoryReviewLabelValue] + + +class MemoryEvaluationSummaryResponse(TypedDict): + summary: MemoryEvaluationSummary + + +class EntityRecord(TypedDict): + id: str + entity_type: EntityType + name: str + source_memory_ids: list[str] + created_at: str + + +class EntityCreateResponse(TypedDict): + entity: EntityRecord + + +class EntityListSummary(TypedDict): + total_count: int + order: list[str] + + +class EntityListResponse(TypedDict): + items: list[EntityRecord] + summary: EntityListSummary + + +class EntityDetailResponse(TypedDict): + entity: EntityRecord + + +class EntityEdgeRecord(ContextPackEntityEdge): + pass + + +class EntityEdgeCreateResponse(TypedDict): + edge: EntityEdgeRecord + + +class EntityEdgeListSummary(TypedDict): + entity_id: str + total_count: int + order: list[str] + + +class EntityEdgeListResponse(TypedDict): + items: list[EntityEdgeRecord] + summary: EntityEdgeListSummary + + +class EmbeddingConfigRecord(TypedDict): + id: str + provider: str + model: str + version: str + dimensions: int + status: EmbeddingConfigStatus + metadata: JsonObject + created_at: str + + +class EmbeddingConfigCreateResponse(TypedDict): + embedding_config: EmbeddingConfigRecord + + +class EmbeddingConfigListSummary(TypedDict): + total_count: int + order: list[str] + + +class EmbeddingConfigListResponse(TypedDict): + items: list[EmbeddingConfigRecord] + summary: EmbeddingConfigListSummary + + +class MemoryEmbeddingRecord(TypedDict): + id: str + memory_id: str + embedding_config_id: str + 
dimensions: int + vector: list[float] + created_at: str + updated_at: str + + +class MemoryEmbeddingUpsertResponse(TypedDict): + embedding: MemoryEmbeddingRecord + write_mode: Literal["created", "updated"] + + +class MemoryEmbeddingDetailResponse(TypedDict): + embedding: MemoryEmbeddingRecord + + +class MemoryEmbeddingListSummary(TypedDict): + memory_id: str + total_count: int + order: list[str] + + +class MemoryEmbeddingListResponse(TypedDict): + items: list[MemoryEmbeddingRecord] + summary: MemoryEmbeddingListSummary + + +class SemanticMemoryRetrievalResultItem(TypedDict): + memory_id: str + memory_key: str + value: JsonValue + source_event_ids: list[str] + created_at: str + updated_at: str + score: float + + +class SemanticMemoryRetrievalSummary(TypedDict): + embedding_config_id: str + limit: int + returned_count: int + similarity_metric: Literal["cosine_similarity"] + order: list[str] + + +class SemanticMemoryRetrievalResponse(TypedDict): + items: list[SemanticMemoryRetrievalResultItem] + summary: SemanticMemoryRetrievalSummary + + +class ConsentRecord(TypedDict): + id: str + consent_key: str + status: ConsentStatus + metadata: JsonObject + created_at: str + updated_at: str + + +class ConsentUpsertResponse(TypedDict): + consent: ConsentRecord + write_mode: Literal["created", "updated"] + + +class ConsentListSummary(TypedDict): + total_count: int + order: list[str] + + +class ConsentListResponse(TypedDict): + items: list[ConsentRecord] + summary: ConsentListSummary + + +class PolicyRecord(TypedDict): + id: str + name: str + action: str + scope: str + effect: PolicyEffect + priority: int + active: bool + conditions: JsonObject + required_consents: list[str] + created_at: str + updated_at: str + + +class PolicyCreateResponse(TypedDict): + policy: PolicyRecord + + +class PolicyListSummary(TypedDict): + total_count: int + order: list[str] + + +class PolicyListResponse(TypedDict): + items: list[PolicyRecord] + summary: PolicyListSummary + + +class PolicyDetailResponse(TypedDict): + policy: PolicyRecord + + +class PolicyEvaluationReason(TypedDict): + code: PolicyEvaluationReasonCode + source: Literal["policy", "consent", "system"] + message: str + policy_id: str | None + consent_key: str | None + + +class PolicyEvaluationSummary(TypedDict): + action: str + scope: str + evaluated_policy_count: int + matched_policy_id: str | None + order: list[str] + + +class PolicyEvaluationTraceSummary(TypedDict): + trace_id: str + trace_event_count: int + + +class PolicyEvaluationResponse(TypedDict): + decision: PolicyEffect + matched_policy: PolicyRecord | None + reasons: list[PolicyEvaluationReason] + evaluation: PolicyEvaluationSummary + trace: PolicyEvaluationTraceSummary + + +class ToolRecord(TypedDict): + id: str + tool_key: str + name: str + description: str + version: str + metadata_version: ToolMetadataVersion + active: bool + tags: list[str] + action_hints: list[str] + scope_hints: list[str] + domain_hints: list[str] + risk_hints: list[str] + metadata: JsonObject + created_at: str + + +class ToolCreateResponse(TypedDict): + tool: ToolRecord + + +class ToolListSummary(TypedDict): + total_count: int + order: list[str] + + +class ToolListResponse(TypedDict): + items: list[ToolRecord] + summary: ToolListSummary + + +class ToolDetailResponse(TypedDict): + tool: ToolRecord + + +class ToolAllowlistReason(TypedDict): + code: ToolAllowlistReasonCode + source: Literal["tool", "policy", "consent", "system"] + message: str + tool_id: str | None + policy_id: str | None + consent_key: str | None + + +class 
ToolAllowlistDecisionRecord(TypedDict): + decision: ToolAllowlistDecision + tool: ToolRecord + reasons: list[ToolAllowlistReason] + + +class ToolAllowlistEvaluationSummary(TypedDict): + action: str + scope: str + domain_hint: str | None + risk_hint: str | None + evaluated_tool_count: int + allowed_count: int + denied_count: int + approval_required_count: int + order: list[str] + + +class ToolAllowlistTraceSummary(TypedDict): + trace_id: str + trace_event_count: int + + +class ToolAllowlistEvaluationResponse(TypedDict): + allowed: list[ToolAllowlistDecisionRecord] + denied: list[ToolAllowlistDecisionRecord] + approval_required: list[ToolAllowlistDecisionRecord] + summary: ToolAllowlistEvaluationSummary + trace: ToolAllowlistTraceSummary + + +class ToolRoutingRequestRecord(TypedDict): + thread_id: str + tool_id: str + action: str + scope: str + domain_hint: str | None + risk_hint: str | None + attributes: JsonObject + + +class ToolRoutingRequestTracePayload(TypedDict): + thread_id: str + tool_id: str + action: str + scope: str + domain_hint: str | None + risk_hint: str | None + attributes: JsonObject + + +class ToolRoutingDecisionTracePayload(TypedDict): + tool_id: str + tool_key: str + tool_version: str + allowlist_decision: ToolAllowlistDecision + routing_decision: ToolRoutingDecision + matched_policy_id: str | None + reasons: list[ToolAllowlistReason] + + +class ToolRoutingSummaryTracePayload(TypedDict): + decision: ToolRoutingDecision + evaluated_tool_count: int + active_policy_count: int + consent_count: int + + +class ToolRoutingSummary(TypedDict): + thread_id: str + tool_id: str + action: str + scope: str + domain_hint: str | None + risk_hint: str | None + decision: ToolRoutingDecision + evaluated_tool_count: int + active_policy_count: int + consent_count: int + order: list[str] + + +class ToolRoutingTraceSummary(TypedDict): + trace_id: str + trace_event_count: int + + +class ToolRoutingResponse(TypedDict): + request: ToolRoutingRequestRecord + decision: ToolRoutingDecision + tool: ToolRecord + reasons: list[ToolAllowlistReason] + summary: ToolRoutingSummary + trace: ToolRoutingTraceSummary + + +class ApprovalRoutingRecord(TypedDict): + decision: ToolRoutingDecision + reasons: list[ToolAllowlistReason] + trace: ToolRoutingTraceSummary + + +class ApprovalResolutionRecord(TypedDict): + resolved_at: str + resolved_by_user_id: str + + +class ApprovalRecord(TypedDict): + id: str + thread_id: str + task_step_id: str | None + status: ApprovalStatus + request: ToolRoutingRequestRecord + tool: ToolRecord + routing: ApprovalRoutingRecord + created_at: str + resolution: ApprovalResolutionRecord | None + + +class ApprovalRequestTraceSummary(TypedDict): + trace_id: str + trace_event_count: int + + +class ApprovalResolutionTraceSummary(TypedDict): + trace_id: str + trace_event_count: int + + +class ApprovalResolutionRequestTracePayload(TypedDict): + approval_id: str + task_step_id: str | None + requested_action: ApprovalResolutionAction + + +class ApprovalResolutionStateTracePayload(TypedDict): + approval_id: str + task_step_id: str | None + requested_action: ApprovalResolutionAction + previous_status: ApprovalStatus + outcome: ApprovalResolutionOutcome + current_status: ApprovalStatus + resolved_at: str | None + resolved_by_user_id: str | None + + +class ApprovalResolutionSummaryTracePayload(TypedDict): + approval_id: str + task_step_id: str | None + requested_action: ApprovalResolutionAction + outcome: ApprovalResolutionOutcome + final_status: ApprovalStatus + + +@dataclass(frozen=True, 
slots=True) +class TaskCreateInput: + thread_id: UUID + tool_id: UUID + status: TaskStatus + request: ToolRoutingRequestRecord + tool: ToolRecord + latest_approval_id: UUID | None = None + latest_execution_id: UUID | None = None + + +class TaskRecord(TypedDict): + id: str + thread_id: str + tool_id: str + status: TaskStatus + request: ToolRoutingRequestRecord + tool: ToolRecord + latest_approval_id: str | None + latest_execution_id: str | None + created_at: str + updated_at: str + + +class TaskCreateResponse(TypedDict): + task: TaskRecord + + +@dataclass(frozen=True, slots=True) +class TaskStepCreateInput: + task_id: UUID + sequence_no: int + kind: TaskStepKind + status: TaskStepStatus + request: ToolRoutingRequestRecord + outcome: "TaskStepOutcomeSnapshot" + trace_id: UUID + trace_kind: str + + +@dataclass(frozen=True, slots=True) +class TaskStepNextCreateInput: + task_id: UUID + kind: TaskStepKind + status: TaskStepStatus + request: ToolRoutingRequestRecord + outcome: "TaskStepOutcomeSnapshot" + lineage: "TaskStepLineageInput" + + +@dataclass(frozen=True, slots=True) +class TaskStepTransitionInput: + task_step_id: UUID + status: TaskStepStatus + outcome: "TaskStepOutcomeSnapshot" + + +@dataclass(frozen=True, slots=True) +class TaskStepLineageInput: + parent_step_id: UUID + source_approval_id: UUID | None = None + source_execution_id: UUID | None = None + + +class TaskListSummary(TypedDict): + total_count: int + order: list[str] + + +class TaskListResponse(TypedDict): + items: list[TaskRecord] + summary: TaskListSummary + + +class TaskDetailResponse(TypedDict): + task: TaskRecord + + +@dataclass(frozen=True, slots=True) +class TaskWorkspaceCreateInput: + task_id: UUID + status: TaskWorkspaceStatus + + +class TaskWorkspaceRecord(TypedDict): + id: str + task_id: str + status: TaskWorkspaceStatus + local_path: str + created_at: str + updated_at: str + + +class TaskWorkspaceCreateResponse(TypedDict): + workspace: TaskWorkspaceRecord + + +class TaskWorkspaceListSummary(TypedDict): + total_count: int + order: list[str] + + +class TaskWorkspaceListResponse(TypedDict): + items: list[TaskWorkspaceRecord] + summary: TaskWorkspaceListSummary + + +class TaskWorkspaceDetailResponse(TypedDict): + workspace: TaskWorkspaceRecord + + +class TaskStepTraceLink(TypedDict): + trace_id: str + trace_kind: str + + +class TaskStepOutcomeSnapshot(TypedDict): + routing_decision: ToolRoutingDecision + approval_id: str | None + approval_status: ApprovalStatus | None + execution_id: str | None + execution_status: ProxyExecutionStatus | None + blocked_reason: str | None + + +class TaskStepLineageRecord(TypedDict): + parent_step_id: str | None + source_approval_id: str | None + source_execution_id: str | None + + +class TaskStepRecord(TypedDict): + id: str + task_id: str + sequence_no: int + kind: TaskStepKind + status: TaskStepStatus + request: ToolRoutingRequestRecord + outcome: TaskStepOutcomeSnapshot + lineage: TaskStepLineageRecord + trace: TaskStepTraceLink + created_at: str + updated_at: str + + +class TaskStepCreateResponse(TypedDict): + task_step: TaskStepRecord + + +class TaskStepSequencingSummary(TypedDict): + task_id: str + total_count: int + latest_sequence_no: int | None + latest_status: TaskStepStatus | None + next_sequence_no: int + append_allowed: bool + order: list[str] + + +class TaskStepListSummary(TaskStepSequencingSummary): + pass + + +class TaskStepListResponse(TypedDict): + items: list[TaskStepRecord] + summary: TaskStepListSummary + + +class TaskStepDetailResponse(TypedDict): + task_step: 
TaskStepRecord + + +class TaskStepMutationTraceSummary(TypedDict): + trace_id: str + trace_event_count: int + + +class TaskStepNextCreateResponse(TypedDict): + task: TaskRecord + task_step: TaskStepRecord + sequencing: TaskStepSequencingSummary + trace: TaskStepMutationTraceSummary + + +class TaskStepTransitionResponse(TypedDict): + task: TaskRecord + task_step: TaskStepRecord + sequencing: TaskStepSequencingSummary + trace: TaskStepMutationTraceSummary + + +class TaskLifecycleStateTracePayload(TypedDict): + task_id: str + source: TaskLifecycleSource + previous_status: TaskStatus | None + current_status: TaskStatus + latest_approval_id: str | None + latest_execution_id: str | None + + +class TaskLifecycleSummaryTracePayload(TypedDict): + task_id: str + source: TaskLifecycleSource + final_status: TaskStatus + latest_approval_id: str | None + latest_execution_id: str | None + + +class TaskStepLifecycleStateTracePayload(TypedDict): + task_id: str + task_step_id: str + source: TaskLifecycleSource + sequence_no: int + kind: TaskStepKind + previous_status: TaskStepStatus | None + current_status: TaskStepStatus + trace: TaskStepTraceLink + + +class TaskStepLifecycleSummaryTracePayload(TypedDict): + task_id: str + task_step_id: str + source: TaskLifecycleSource + sequence_no: int + kind: TaskStepKind + final_status: TaskStepStatus + trace: TaskStepTraceLink + + +class TaskStepSequenceRequestTracePayload(TypedDict): + task_id: str + previous_task_step_id: str + previous_sequence_no: int + previous_status: TaskStepStatus + requested_kind: TaskStepKind + requested_status: TaskStepStatus + + +class TaskStepSequenceStateTracePayload(TypedDict): + task_id: str + previous_task_step_id: str + previous_sequence_no: int + previous_status: TaskStepStatus + task_step_id: str + assigned_sequence_no: int + kind: TaskStepKind + current_status: TaskStepStatus + + +class TaskStepSequenceSummaryTracePayload(TypedDict): + task_id: str + task_step_id: str + latest_sequence_no: int + next_sequence_no: int + append_allowed: bool + + +class TaskStepContinuationRequestTracePayload(TypedDict): + task_id: str + parent_task_step_id: str + parent_sequence_no: int + parent_status: TaskStepStatus + requested_kind: TaskStepKind + requested_status: TaskStepStatus + requested_source_approval_id: str | None + requested_source_execution_id: str | None + + +class TaskStepContinuationLineageTracePayload(TypedDict): + task_id: str + parent_task_step_id: str + parent_sequence_no: int + parent_status: TaskStepStatus + source_approval_id: str | None + source_execution_id: str | None + + +class TaskStepContinuationSummaryTracePayload(TypedDict): + task_id: str + task_step_id: str + latest_sequence_no: int + next_sequence_no: int + append_allowed: bool + lineage: TaskStepLineageRecord + + +class TaskStepTransitionRequestTracePayload(TypedDict): + task_id: str + task_step_id: str + sequence_no: int + previous_status: TaskStepStatus + requested_status: TaskStepStatus + + +class TaskStepTransitionStateTracePayload(TypedDict): + task_id: str + task_step_id: str + sequence_no: int + previous_status: TaskStepStatus + current_status: TaskStepStatus + allowed_next_statuses: list[TaskStepStatus] + trace: TaskStepTraceLink + + +class TaskStepTransitionSummaryTracePayload(TypedDict): + task_id: str + task_step_id: str + sequence_no: int + final_status: TaskStepStatus + parent_task_status: TaskStatus + trace: TaskStepTraceLink + + +class ApprovalRequestCreateResponse(TypedDict): + request: ToolRoutingRequestRecord + decision: ToolRoutingDecision + 
tool: ToolRecord + reasons: list[ToolAllowlistReason] + task: TaskRecord + approval: ApprovalRecord | None + routing_trace: ToolRoutingTraceSummary + trace: ApprovalRequestTraceSummary + + +class ApprovalListSummary(TypedDict): + total_count: int + order: list[str] + + +class ApprovalListResponse(TypedDict): + items: list[ApprovalRecord] + summary: ApprovalListSummary + + +class ApprovalDetailResponse(TypedDict): + approval: ApprovalRecord + + +class ApprovalResolutionResponse(TypedDict): + approval: ApprovalRecord + trace: ApprovalResolutionTraceSummary + + +class ExecutionBudgetRecord(TypedDict): + id: str + tool_key: str | None + domain_hint: str | None + max_completed_executions: int + rolling_window_seconds: int | None + status: ExecutionBudgetStatus + deactivated_at: str | None + superseded_by_budget_id: str | None + supersedes_budget_id: str | None + created_at: str + + +class ExecutionBudgetCreateResponse(TypedDict): + execution_budget: ExecutionBudgetRecord + + +class ExecutionBudgetListSummary(TypedDict): + total_count: int + order: list[str] + + +class ExecutionBudgetListResponse(TypedDict): + items: list[ExecutionBudgetRecord] + summary: ExecutionBudgetListSummary + + +class ExecutionBudgetDetailResponse(TypedDict): + execution_budget: ExecutionBudgetRecord + + +class ExecutionBudgetLifecycleTraceSummary(TypedDict): + trace_id: str + trace_event_count: int + + +class ExecutionBudgetDeactivateResponse(TypedDict): + execution_budget: ExecutionBudgetRecord + trace: ExecutionBudgetLifecycleTraceSummary + + +class ExecutionBudgetSupersedeResponse(TypedDict): + superseded_budget: ExecutionBudgetRecord + replacement_budget: ExecutionBudgetRecord + trace: ExecutionBudgetLifecycleTraceSummary + + +class ExecutionBudgetDecisionRecord(TypedDict): + matched_budget_id: str | None + tool_key: str + domain_hint: str | None + budget_tool_key: str | None + budget_domain_hint: str | None + max_completed_executions: int | None + rolling_window_seconds: int | None + count_scope: ExecutionBudgetCountScope + window_started_at: str | None + completed_execution_count: int + projected_completed_execution_count: int + decision: ExecutionBudgetDecision + reason: ExecutionBudgetDecisionReason + order: list[str] + history_order: list[str] + + +class ExecutionBudgetLifecycleRequestTracePayload(TypedDict): + thread_id: str + execution_budget_id: str + requested_action: ExecutionBudgetLifecycleAction + replacement_max_completed_executions: int | None + + +class ExecutionBudgetLifecycleStateTracePayload(TypedDict): + execution_budget_id: str + requested_action: ExecutionBudgetLifecycleAction + previous_status: ExecutionBudgetStatus + current_status: ExecutionBudgetStatus + tool_key: str | None + domain_hint: str | None + max_completed_executions: int + rolling_window_seconds: int | None + deactivated_at: str | None + superseded_by_budget_id: str | None + supersedes_budget_id: str | None + replacement_budget_id: str | None + replacement_status: ExecutionBudgetStatus | None + replacement_max_completed_executions: int | None + replacement_rolling_window_seconds: int | None + rejection_reason: str | None + + +class ExecutionBudgetLifecycleSummaryTracePayload(TypedDict): + execution_budget_id: str + requested_action: ExecutionBudgetLifecycleAction + outcome: ExecutionBudgetLifecycleOutcome + replacement_budget_id: str | None + active_budget_id: str | None + + +@dataclass(frozen=True, slots=True) +class ToolExecutionCreateInput: + approval_id: UUID + task_step_id: UUID + thread_id: UUID + tool_id: UUID + trace_id: 
UUID + request_event_id: UUID | None + result_event_id: UUID | None + status: ProxyExecutionStatus + handler_key: str | None + request: ToolRoutingRequestRecord + tool: ToolRecord + result: "ToolExecutionResultRecord" + + +class ToolExecutionRecord(TypedDict): + id: str + approval_id: str + task_step_id: str + thread_id: str + tool_id: str + trace_id: str + request_event_id: str | None + result_event_id: str | None + status: ProxyExecutionStatus + handler_key: str | None + request: ToolRoutingRequestRecord + tool: ToolRecord + result: "ToolExecutionResultRecord" + executed_at: str + + +class ToolExecutionListSummary(TypedDict): + total_count: int + order: list[str] + + +class ToolExecutionListResponse(TypedDict): + items: list[ToolExecutionRecord] + summary: ToolExecutionListSummary + + +class ToolExecutionDetailResponse(TypedDict): + execution: ToolExecutionRecord + + +class ProxyExecutionRequestRecord(TypedDict): + approval_id: str + task_step_id: str + + +class ProxyExecutionRequestEventPayload(TypedDict): + approval_id: str + task_step_id: str + tool_id: str + tool_key: str + request: ToolRoutingRequestRecord + + +class ProxyExecutionResultRecord(TypedDict): + handler_key: str + status: Literal["completed"] + output: JsonObject + + +class ProxyExecutionResultEventPayload(TypedDict): + approval_id: str + task_step_id: str + tool_id: str + tool_key: str + handler_key: str + status: Literal["completed"] + output: JsonObject + + +class ToolExecutionResultRecord(TypedDict): + handler_key: str | None + status: ProxyExecutionStatus + output: JsonObject | None + reason: str | None + budget_decision: NotRequired[ExecutionBudgetDecisionRecord] + + +class ProxyExecutionEventSummary(TypedDict): + request_event_id: str + request_sequence_no: int + result_event_id: str + result_sequence_no: int + + +class ProxyExecutionTraceSummary(TypedDict): + trace_id: str + trace_event_count: int + + +class ProxyExecutionBudgetPrecheckTracePayload(ExecutionBudgetDecisionRecord): + pass + + +class ProxyExecutionApprovalTracePayload(TypedDict): + approval_id: str + task_step_id: str + approval_status: ApprovalStatus + eligible_for_execution: bool + + +class ProxyExecutionDispatchTracePayload(TypedDict): + approval_id: str + task_step_id: str + tool_id: str + tool_key: str + handler_key: str | None + dispatch_status: Literal["executed", "blocked"] + reason: str | None + result_status: ProxyExecutionStatus | None + output: JsonObject | None + + +class ProxyExecutionSummaryTracePayload(TypedDict): + approval_id: str + task_step_id: str + tool_id: str + tool_key: str + approval_status: ApprovalStatus + execution_status: Literal["completed", "blocked"] + handler_key: str | None + request_event_id: str | None + result_event_id: str | None + + +class ProxyExecutionResponse(TypedDict): + request: ProxyExecutionRequestRecord + approval: ApprovalRecord + tool: ToolRecord + result: ProxyExecutionResultRecord | ToolExecutionResultRecord + events: ProxyExecutionEventSummary | None + trace: ProxyExecutionTraceSummary + + +class ProxyExecutionBudgetBlockedResponse(TypedDict): + request: ProxyExecutionRequestRecord + approval: ApprovalRecord + tool: ToolRecord + result: ToolExecutionResultRecord + events: None + trace: ProxyExecutionTraceSummary + + +def isoformat_or_none(value: datetime | None) -> str | None: + if value is None: + return None + return value.isoformat() diff --git a/apps/api/src/alicebot_api/db.py b/apps/api/src/alicebot_api/db.py new file mode 100644 index 0000000..cc6e87b --- /dev/null +++ 
b/apps/api/src/alicebot_api/db.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +from collections.abc import Iterator +from contextlib import contextmanager +from uuid import UUID + +import psycopg +from psycopg.rows import dict_row + +PING_DATABASE_SQL = "SELECT 1" +SET_CURRENT_USER_SQL = "SELECT set_config('app.current_user_id', %s, true)" +ConnectionRow = dict[str, object] +UserConnection = psycopg.Connection[ConnectionRow] + + +def ping_database(database_url: str, timeout_seconds: int) -> bool: + try: + with psycopg.connect(database_url, connect_timeout=timeout_seconds) as conn: + with conn.cursor() as cur: + cur.execute(PING_DATABASE_SQL) + cur.fetchone() + return True + except psycopg.Error: + return False + + +def set_current_user(conn: psycopg.Connection, user_id: UUID) -> None: + with conn.cursor() as cur: + cur.execute(SET_CURRENT_USER_SQL, (str(user_id),)) + + +@contextmanager +def user_connection(database_url: str, user_id: UUID) -> Iterator[UserConnection]: + with psycopg.connect(database_url, row_factory=dict_row) as conn: + with conn.transaction(): + set_current_user(conn, user_id) + yield conn diff --git a/apps/api/src/alicebot_api/embedding.py b/apps/api/src/alicebot_api/embedding.py new file mode 100644 index 0000000..5248197 --- /dev/null +++ b/apps/api/src/alicebot_api/embedding.py @@ -0,0 +1,242 @@ +from __future__ import annotations + +import math +from uuid import UUID + +import psycopg + +from alicebot_api.contracts import ( + EMBEDDING_CONFIG_LIST_ORDER, + MEMORY_EMBEDDING_LIST_ORDER, + EmbeddingConfigCreateInput, + EmbeddingConfigCreateResponse, + EmbeddingConfigListResponse, + EmbeddingConfigListSummary, + EmbeddingConfigRecord, + MemoryEmbeddingDetailResponse, + MemoryEmbeddingListResponse, + MemoryEmbeddingListSummary, + MemoryEmbeddingRecord, + MemoryEmbeddingUpsertInput, + MemoryEmbeddingUpsertResponse, +) +from alicebot_api.store import ContinuityStore, EmbeddingConfigRow, MemoryEmbeddingRow + + +class EmbeddingConfigValidationError(ValueError): + """Raised when an embedding-config request fails explicit validation.""" + + +class MemoryEmbeddingValidationError(ValueError): + """Raised when a memory-embedding request fails explicit validation.""" + + +class MemoryEmbeddingNotFoundError(LookupError): + """Raised when a requested memory embedding is not visible inside the current user scope.""" + + +def _duplicate_embedding_config_message( + *, + provider: str, + model: str, + version: str, +) -> str: + return ( + "embedding config already exists for provider/model/version under the user scope: " + f"{provider}/{model}/{version}" + ) + + +def _serialize_embedding_config(config: EmbeddingConfigRow) -> EmbeddingConfigRecord: + return { + "id": str(config["id"]), + "provider": config["provider"], + "model": config["model"], + "version": config["version"], + "dimensions": config["dimensions"], + "status": config["status"], + "metadata": config["metadata"], + "created_at": config["created_at"].isoformat(), + } + + +def _serialize_memory_embedding(embedding: MemoryEmbeddingRow) -> MemoryEmbeddingRecord: + return { + "id": str(embedding["id"]), + "memory_id": str(embedding["memory_id"]), + "embedding_config_id": str(embedding["embedding_config_id"]), + "dimensions": embedding["dimensions"], + "vector": [float(value) for value in embedding["vector"]], + "created_at": embedding["created_at"].isoformat(), + "updated_at": embedding["updated_at"].isoformat(), + } + + +def _validate_vector(vector: tuple[float, ...]) -> list[float]: + if not vector: + raise 
MemoryEmbeddingValidationError("vector must include at least one numeric value") + + normalized: list[float] = [] + for value in vector: + normalized_value = float(value) + if not math.isfinite(normalized_value): + raise MemoryEmbeddingValidationError("vector must contain only finite numeric values") + normalized.append(normalized_value) + + return normalized + + +def create_embedding_config_record( + store: ContinuityStore, + *, + user_id: UUID, + config: EmbeddingConfigCreateInput, +) -> EmbeddingConfigCreateResponse: + del user_id + + existing = store.get_embedding_config_by_identity_optional( + provider=config.provider, + model=config.model, + version=config.version, + ) + if existing is not None: + raise EmbeddingConfigValidationError( + _duplicate_embedding_config_message( + provider=config.provider, + model=config.model, + version=config.version, + ) + ) + + try: + created = store.create_embedding_config( + provider=config.provider, + model=config.model, + version=config.version, + dimensions=config.dimensions, + status=config.status, + metadata=config.metadata, + ) + except psycopg.errors.UniqueViolation as exc: + raise EmbeddingConfigValidationError( + _duplicate_embedding_config_message( + provider=config.provider, + model=config.model, + version=config.version, + ) + ) from exc + return {"embedding_config": _serialize_embedding_config(created)} + + +def list_embedding_config_records( + store: ContinuityStore, + *, + user_id: UUID, +) -> EmbeddingConfigListResponse: + del user_id + + configs = store.list_embedding_configs() + items = [_serialize_embedding_config(config) for config in configs] + summary: EmbeddingConfigListSummary = { + "total_count": len(items), + "order": list(EMBEDDING_CONFIG_LIST_ORDER), + } + return { + "items": items, + "summary": summary, + } + + +def upsert_memory_embedding_record( + store: ContinuityStore, + *, + user_id: UUID, + request: MemoryEmbeddingUpsertInput, +) -> MemoryEmbeddingUpsertResponse: + del user_id + + memory = store.get_memory_optional(request.memory_id) + if memory is None: + raise MemoryEmbeddingValidationError( + f"memory_id must reference an existing memory owned by the user: {request.memory_id}" + ) + + config = store.get_embedding_config_optional(request.embedding_config_id) + if config is None: + raise MemoryEmbeddingValidationError( + "embedding_config_id must reference an existing embedding config owned by the user: " + f"{request.embedding_config_id}" + ) + + vector = _validate_vector(request.vector) + if len(vector) != config["dimensions"]: + raise MemoryEmbeddingValidationError( + "vector length must match embedding config dimensions " + f"({config['dimensions']}): {len(vector)}" + ) + + existing = store.get_memory_embedding_by_memory_and_config_optional( + memory_id=request.memory_id, + embedding_config_id=request.embedding_config_id, + ) + if existing is None: + created = store.create_memory_embedding( + memory_id=request.memory_id, + embedding_config_id=request.embedding_config_id, + dimensions=config["dimensions"], + vector=vector, + ) + return { + "embedding": _serialize_memory_embedding(created), + "write_mode": "created", + } + + updated = store.update_memory_embedding( + memory_embedding_id=existing["id"], + dimensions=config["dimensions"], + vector=vector, + ) + return { + "embedding": _serialize_memory_embedding(updated), + "write_mode": "updated", + } + + +def get_memory_embedding_record( + store: ContinuityStore, + *, + user_id: UUID, + memory_embedding_id: UUID, +) -> MemoryEmbeddingDetailResponse: + del user_id 
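+ # Editor note (assumption): `user_id` is deleted here (and in the sibling record helpers) because row
+ # visibility is presumably enforced by the database via the app.current_user_id setting applied in
+ # db.user_connection; the parameter is kept so every record helper shares the same signature.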
+ + embedding = store.get_memory_embedding_optional(memory_embedding_id) + if embedding is None: + raise MemoryEmbeddingNotFoundError(f"memory embedding {memory_embedding_id} was not found") + + return {"embedding": _serialize_memory_embedding(embedding)} + + +def list_memory_embedding_records( + store: ContinuityStore, + *, + user_id: UUID, + memory_id: UUID, +) -> MemoryEmbeddingListResponse: + del user_id + + memory = store.get_memory_optional(memory_id) + if memory is None: + raise MemoryEmbeddingNotFoundError(f"memory {memory_id} was not found") + + embeddings = store.list_memory_embeddings_for_memory(memory_id) + items = [_serialize_memory_embedding(embedding) for embedding in embeddings] + summary: MemoryEmbeddingListSummary = { + "memory_id": str(memory_id), + "total_count": len(items), + "order": list(MEMORY_EMBEDDING_LIST_ORDER), + } + return { + "items": items, + "summary": summary, + } diff --git a/apps/api/src/alicebot_api/entity.py b/apps/api/src/alicebot_api/entity.py new file mode 100644 index 0000000..8e811eb --- /dev/null +++ b/apps/api/src/alicebot_api/entity.py @@ -0,0 +1,117 @@ +from __future__ import annotations + +from uuid import UUID + +from alicebot_api.contracts import ( + ENTITY_LIST_ORDER, + EntityCreateInput, + EntityCreateResponse, + EntityDetailResponse, + EntityListResponse, + EntityListSummary, + EntityRecord, +) +from alicebot_api.store import ContinuityStore, EntityRow + + +class EntityValidationError(ValueError): + """Raised when an entity create request fails explicit validation.""" + + +class EntityNotFoundError(LookupError): + """Raised when a requested entity is not visible inside the current user scope.""" + + +def _serialize_entity(entity: EntityRow) -> EntityRecord: + return { + "id": str(entity["id"]), + "entity_type": entity["entity_type"], + "name": entity["name"], + "source_memory_ids": entity["source_memory_ids"], + "created_at": entity["created_at"].isoformat(), + } + + +def _dedupe_source_memory_ids(source_memory_ids: tuple[UUID, ...]) -> tuple[UUID, ...]: + deduped: list[UUID] = [] + seen: set[UUID] = set() + for source_memory_id in source_memory_ids: + if source_memory_id in seen: + continue + seen.add(source_memory_id) + deduped.append(source_memory_id) + return tuple(deduped) + + +def _validate_source_memories(store: ContinuityStore, source_memory_ids: tuple[UUID, ...]) -> list[str]: + normalized_memory_ids = _dedupe_source_memory_ids(source_memory_ids) + if not normalized_memory_ids: + raise EntityValidationError( + "source_memory_ids must include at least one existing memory owned by the user" + ) + + source_memories = store.list_memories_by_ids(list(normalized_memory_ids)) + found_memory_ids = {memory["id"] for memory in source_memories} + missing_memory_ids = [ + str(source_memory_id) + for source_memory_id in normalized_memory_ids + if source_memory_id not in found_memory_ids + ] + if missing_memory_ids: + raise EntityValidationError( + "source_memory_ids must all reference existing memories owned by the user: " + + ", ".join(missing_memory_ids) + ) + + return [str(source_memory_id) for source_memory_id in normalized_memory_ids] + + +def create_entity_record( + store: ContinuityStore, + *, + user_id: UUID, + entity: EntityCreateInput, +) -> EntityCreateResponse: + del user_id + + source_memory_ids = _validate_source_memories(store, entity.source_memory_ids) + created = store.create_entity( + entity_type=entity.entity_type, + name=entity.name, + source_memory_ids=source_memory_ids, + ) + return {"entity": 
_serialize_entity(created)} + + +def list_entity_records( + store: ContinuityStore, + *, + user_id: UUID, +) -> EntityListResponse: + del user_id + + entities = store.list_entities() + items = [_serialize_entity(entity) for entity in entities] + summary: EntityListSummary = { + "total_count": len(items), + "order": list(ENTITY_LIST_ORDER), + } + return { + "items": items, + "summary": summary, + } + + +def get_entity_record( + store: ContinuityStore, + *, + user_id: UUID, + entity_id: UUID, +) -> EntityDetailResponse: + del user_id + + entity = store.get_entity_optional(entity_id) + if entity is None: + raise EntityNotFoundError(f"entity {entity_id} was not found") + + return {"entity": _serialize_entity(entity)} diff --git a/apps/api/src/alicebot_api/entity_edge.py b/apps/api/src/alicebot_api/entity_edge.py new file mode 100644 index 0000000..84731a2 --- /dev/null +++ b/apps/api/src/alicebot_api/entity_edge.py @@ -0,0 +1,134 @@ +from __future__ import annotations + +from datetime import datetime +from uuid import UUID + +from alicebot_api.contracts import ( + ENTITY_EDGE_LIST_ORDER, + EntityEdgeCreateInput, + EntityEdgeCreateResponse, + EntityEdgeListResponse, + EntityEdgeListSummary, + EntityEdgeRecord, + isoformat_or_none, +) +from alicebot_api.entity import EntityNotFoundError +from alicebot_api.store import ContinuityStore, EntityEdgeRow + + +class EntityEdgeValidationError(ValueError): + """Raised when an entity-edge request fails explicit validation.""" + + +def _serialize_entity_edge(edge: EntityEdgeRow) -> EntityEdgeRecord: + return { + "id": str(edge["id"]), + "from_entity_id": str(edge["from_entity_id"]), + "to_entity_id": str(edge["to_entity_id"]), + "relationship_type": edge["relationship_type"], + "valid_from": isoformat_or_none(edge["valid_from"]), + "valid_to": isoformat_or_none(edge["valid_to"]), + "source_memory_ids": edge["source_memory_ids"], + "created_at": edge["created_at"].isoformat(), + } + + +def _dedupe_source_memory_ids(source_memory_ids: tuple[UUID, ...]) -> tuple[UUID, ...]: + deduped: list[UUID] = [] + seen: set[UUID] = set() + for source_memory_id in source_memory_ids: + if source_memory_id in seen: + continue + seen.add(source_memory_id) + deduped.append(source_memory_id) + return tuple(deduped) + + +def _validate_source_memories(store: ContinuityStore, source_memory_ids: tuple[UUID, ...]) -> list[str]: + normalized_memory_ids = _dedupe_source_memory_ids(source_memory_ids) + if not normalized_memory_ids: + raise EntityEdgeValidationError( + "source_memory_ids must include at least one existing memory owned by the user" + ) + + source_memories = store.list_memories_by_ids(list(normalized_memory_ids)) + found_memory_ids = {memory["id"] for memory in source_memories} + missing_memory_ids = [ + str(source_memory_id) + for source_memory_id in normalized_memory_ids + if source_memory_id not in found_memory_ids + ] + if missing_memory_ids: + raise EntityEdgeValidationError( + "source_memory_ids must all reference existing memories owned by the user: " + + ", ".join(missing_memory_ids) + ) + + return [str(source_memory_id) for source_memory_id in normalized_memory_ids] + + +def _validate_entity_exists( + store: ContinuityStore, + *, + field_name: str, + entity_id: UUID, +) -> None: + entity = store.get_entity_optional(entity_id) + if entity is None: + raise EntityEdgeValidationError( + f"{field_name} must reference an existing entity owned by the user: {entity_id}" + ) + + +def _validate_temporal_range(valid_from: datetime | None, valid_to: datetime | None) -> 
None: + if valid_from is not None and valid_to is not None and valid_to < valid_from: + raise EntityEdgeValidationError("valid_to must be greater than or equal to valid_from") + + +def create_entity_edge_record( + store: ContinuityStore, + *, + user_id: UUID, + edge: EntityEdgeCreateInput, +) -> EntityEdgeCreateResponse: + del user_id + + _validate_entity_exists(store, field_name="from_entity_id", entity_id=edge.from_entity_id) + _validate_entity_exists(store, field_name="to_entity_id", entity_id=edge.to_entity_id) + _validate_temporal_range(edge.valid_from, edge.valid_to) + source_memory_ids = _validate_source_memories(store, edge.source_memory_ids) + + created = store.create_entity_edge( + from_entity_id=edge.from_entity_id, + to_entity_id=edge.to_entity_id, + relationship_type=edge.relationship_type, + valid_from=edge.valid_from, + valid_to=edge.valid_to, + source_memory_ids=source_memory_ids, + ) + return {"edge": _serialize_entity_edge(created)} + + +def list_entity_edge_records( + store: ContinuityStore, + *, + user_id: UUID, + entity_id: UUID, +) -> EntityEdgeListResponse: + del user_id + + entity = store.get_entity_optional(entity_id) + if entity is None: + raise EntityNotFoundError(f"entity {entity_id} was not found") + + edges = store.list_entity_edges_for_entity(entity_id) + items = [_serialize_entity_edge(edge) for edge in edges] + summary: EntityEdgeListSummary = { + "entity_id": str(entity["id"]), + "total_count": len(items), + "order": list(ENTITY_EDGE_LIST_ORDER), + } + return { + "items": items, + "summary": summary, + } diff --git a/apps/api/src/alicebot_api/execution_budgets.py b/apps/api/src/alicebot_api/execution_budgets.py new file mode 100644 index 0000000..870bd13 --- /dev/null +++ b/apps/api/src/alicebot_api/execution_budgets.py @@ -0,0 +1,818 @@ +from __future__ import annotations + +from dataclasses import dataclass +from datetime import UTC, datetime, timedelta +from typing import cast +from uuid import UUID, uuid4 + +import psycopg + +from alicebot_api.contracts import ( + EXECUTION_BUDGET_LIFECYCLE_VERSION_V0, + EXECUTION_BUDGET_LIST_ORDER, + EXECUTION_BUDGET_MATCH_ORDER, + EXECUTION_BUDGET_STATUSES, + TOOL_EXECUTION_LIST_ORDER, + TRACE_KIND_EXECUTION_BUDGET_LIFECYCLE, + ExecutionBudgetCreateInput, + ExecutionBudgetCreateResponse, + ExecutionBudgetDeactivateInput, + ExecutionBudgetDeactivateResponse, + ExecutionBudgetDecisionRecord, + ExecutionBudgetDetailResponse, + ExecutionBudgetLifecycleAction, + ExecutionBudgetLifecycleOutcome, + ExecutionBudgetLifecycleRequestTracePayload, + ExecutionBudgetLifecycleStateTracePayload, + ExecutionBudgetLifecycleSummaryTracePayload, + ExecutionBudgetListResponse, + ExecutionBudgetListSummary, + ExecutionBudgetRecord, + ExecutionBudgetSupersedeInput, + ExecutionBudgetSupersedeResponse, + ToolExecutionResultRecord, + ToolRecord, + ToolRoutingRequestRecord, +) +from alicebot_api.store import ContinuityStore, ExecutionBudgetRow, ToolExecutionRow + + +class ExecutionBudgetValidationError(ValueError): + """Raised when an execution-budget request fails explicit validation.""" + + +class ExecutionBudgetNotFoundError(LookupError): + """Raised when an execution budget is not visible inside the current user scope.""" + + +class ExecutionBudgetLifecycleError(RuntimeError): + """Raised when an execution budget lifecycle transition is invalid.""" + + +@dataclass(frozen=True, slots=True) +class ExecutionBudgetDecision: + record: ExecutionBudgetDecisionRecord + blocked_result: ToolExecutionResultRecord | None + + +def 
serialize_execution_budget_row(row: ExecutionBudgetRow) -> ExecutionBudgetRecord: + return { + "id": str(row["id"]), + "tool_key": row["tool_key"], + "domain_hint": row["domain_hint"], + "max_completed_executions": row["max_completed_executions"], + "rolling_window_seconds": row["rolling_window_seconds"], + "status": cast(str, row["status"]), + "deactivated_at": None if row["deactivated_at"] is None else row["deactivated_at"].isoformat(), + "superseded_by_budget_id": ( + None if row["superseded_by_budget_id"] is None else str(row["superseded_by_budget_id"]) + ), + "supersedes_budget_id": ( + None if row["supersedes_budget_id"] is None else str(row["supersedes_budget_id"]) + ), + "created_at": row["created_at"].isoformat(), + } + + +def _validate_budget_scope(*, tool_key: str | None, domain_hint: str | None) -> None: + if tool_key is None and domain_hint is None: + raise ExecutionBudgetValidationError( + "execution budget requires at least one selector: tool_key or domain_hint" + ) + + +def _validate_rolling_window_seconds(rolling_window_seconds: int | None) -> None: + if rolling_window_seconds is not None and rolling_window_seconds <= 0: + raise ExecutionBudgetValidationError( + "rolling_window_seconds must be greater than 0 when provided" + ) + + +def _validate_lifecycle_thread(store: ContinuityStore, *, thread_id: UUID) -> dict[str, object]: + thread = store.get_thread_optional(thread_id) + if thread is None: + raise ExecutionBudgetValidationError( + "thread_id must reference an existing thread owned by the user" + ) + return cast(dict[str, object], thread) + + +def _append_trace_events( + store: ContinuityStore, + *, + trace_id: UUID, + trace_events: list[tuple[str, dict[str, object]]], +) -> None: + for sequence_no, (kind, payload) in enumerate(trace_events, start=1): + store.append_trace_event( + trace_id=trace_id, + sequence_no=sequence_no, + kind=kind, + payload=payload, + ) + + +def _trace_summary(trace_id: UUID, trace_events: list[tuple[str, dict[str, object]]]) -> dict[str, object]: + return { + "trace_id": str(trace_id), + "trace_event_count": len(trace_events), + } + + +def _active_budget_rows_for_scope( + store: ContinuityStore, + *, + tool_key: str | None, + domain_hint: str | None, +) -> list[ExecutionBudgetRow]: + rows = [ + row + for row in store.list_execution_budgets() + if row["tool_key"] == tool_key + and row["domain_hint"] == domain_hint + and cast(str, row["status"]) == "active" + ] + return sorted(rows, key=lambda row: (row["created_at"], row["id"])) + + +def _scope_label(*, tool_key: str | None, domain_hint: str | None) -> str: + return f"tool_key={tool_key!r}, domain_hint={domain_hint!r}" + + +def _duplicate_active_scope_message(*, tool_key: str | None, domain_hint: str | None) -> str: + return ( + "active execution budget already exists for selector scope " + f"{_scope_label(tool_key=tool_key, domain_hint=domain_hint)}" + ) + + +def _is_active_scope_uniqueness_error(exc: psycopg.Error) -> bool: + diag = getattr(exc, "diag", None) + return getattr(diag, "constraint_name", None) == "execution_budgets_one_active_scope_idx" + + +def _invalid_transition_error( + *, + row: ExecutionBudgetRow, + requested_action: ExecutionBudgetLifecycleAction, +) -> ExecutionBudgetLifecycleError: + return ExecutionBudgetLifecycleError( + f"execution budget {row['id']} is {row['status']} and cannot be {requested_action}d" + ) + + +def _record_lifecycle_trace( + store: ContinuityStore, + *, + thread: dict[str, object], + request_payload: ExecutionBudgetLifecycleRequestTracePayload, + 
state_payload: ExecutionBudgetLifecycleStateTracePayload, + summary_payload: ExecutionBudgetLifecycleSummaryTracePayload, + requested_action: ExecutionBudgetLifecycleAction, + outcome: ExecutionBudgetLifecycleOutcome, +) -> dict[str, object]: + trace = store.create_trace( + user_id=cast(UUID, thread["user_id"]), + thread_id=cast(UUID, thread["id"]), + kind=TRACE_KIND_EXECUTION_BUDGET_LIFECYCLE, + compiler_version=EXECUTION_BUDGET_LIFECYCLE_VERSION_V0, + status="completed", + limits={ + "order": list(EXECUTION_BUDGET_LIST_ORDER), + "match_order": list(EXECUTION_BUDGET_MATCH_ORDER), + "statuses": list(EXECUTION_BUDGET_STATUSES), + "requested_action": requested_action, + "outcome": outcome, + }, + ) + trace_events: list[tuple[str, dict[str, object]]] = [ + ("execution_budget.lifecycle.request", cast(dict[str, object], request_payload)), + ("execution_budget.lifecycle.state", cast(dict[str, object], state_payload)), + ("execution_budget.lifecycle.summary", cast(dict[str, object], summary_payload)), + ] + _append_trace_events(store, trace_id=trace["id"], trace_events=trace_events) + return _trace_summary(trace["id"], trace_events) + + +def create_execution_budget_record( + store: ContinuityStore, + *, + user_id: UUID, + request: ExecutionBudgetCreateInput, +) -> ExecutionBudgetCreateResponse: + del user_id + + _validate_budget_scope(tool_key=request.tool_key, domain_hint=request.domain_hint) + _validate_rolling_window_seconds(request.rolling_window_seconds) + if _active_budget_rows_for_scope( + store, + tool_key=request.tool_key, + domain_hint=request.domain_hint, + ): + raise ExecutionBudgetValidationError( + _duplicate_active_scope_message( + tool_key=request.tool_key, + domain_hint=request.domain_hint, + ) + ) + try: + row = store.create_execution_budget( + tool_key=request.tool_key, + domain_hint=request.domain_hint, + max_completed_executions=request.max_completed_executions, + rolling_window_seconds=request.rolling_window_seconds, + ) + except psycopg.IntegrityError as exc: + if _is_active_scope_uniqueness_error(exc): + raise ExecutionBudgetValidationError( + _duplicate_active_scope_message( + tool_key=request.tool_key, + domain_hint=request.domain_hint, + ) + ) from exc + raise + return {"execution_budget": serialize_execution_budget_row(row)} + + +def list_execution_budget_records( + store: ContinuityStore, + *, + user_id: UUID, +) -> ExecutionBudgetListResponse: + del user_id + + items = [serialize_execution_budget_row(row) for row in store.list_execution_budgets()] + summary: ExecutionBudgetListSummary = { + "total_count": len(items), + "order": list(EXECUTION_BUDGET_LIST_ORDER), + } + return { + "items": items, + "summary": summary, + } + + +def get_execution_budget_record( + store: ContinuityStore, + *, + user_id: UUID, + execution_budget_id: UUID, +) -> ExecutionBudgetDetailResponse: + del user_id + + row = store.get_execution_budget_optional(execution_budget_id) + if row is None: + raise ExecutionBudgetNotFoundError(f"execution budget {execution_budget_id} was not found") + return {"execution_budget": serialize_execution_budget_row(row)} + + +def deactivate_execution_budget_record( + store: ContinuityStore, + *, + user_id: UUID, + request: ExecutionBudgetDeactivateInput, +) -> ExecutionBudgetDeactivateResponse: + del user_id + + thread = _validate_lifecycle_thread(store, thread_id=request.thread_id) + row = store.get_execution_budget_optional(request.execution_budget_id) + if row is None: + raise ExecutionBudgetNotFoundError( + f"execution budget {request.execution_budget_id} was 
not found" + ) + + request_payload: ExecutionBudgetLifecycleRequestTracePayload = { + "thread_id": str(request.thread_id), + "execution_budget_id": str(request.execution_budget_id), + "requested_action": "deactivate", + "replacement_max_completed_executions": None, + } + + if cast(str, row["status"]) != "active": + error = _invalid_transition_error(row=row, requested_action="deactivate") + trace = _record_lifecycle_trace( + store, + thread=thread, + request_payload=request_payload, + state_payload={ + "execution_budget_id": str(row["id"]), + "requested_action": "deactivate", + "previous_status": cast(str, row["status"]), + "current_status": cast(str, row["status"]), + "tool_key": row["tool_key"], + "domain_hint": row["domain_hint"], + "max_completed_executions": row["max_completed_executions"], + "rolling_window_seconds": row["rolling_window_seconds"], + "deactivated_at": ( + None if row["deactivated_at"] is None else row["deactivated_at"].isoformat() + ), + "superseded_by_budget_id": ( + None if row["superseded_by_budget_id"] is None else str(row["superseded_by_budget_id"]) + ), + "supersedes_budget_id": ( + None if row["supersedes_budget_id"] is None else str(row["supersedes_budget_id"]) + ), + "replacement_budget_id": None, + "replacement_status": None, + "replacement_max_completed_executions": None, + "replacement_rolling_window_seconds": None, + "rejection_reason": str(error), + }, + summary_payload={ + "execution_budget_id": str(row["id"]), + "requested_action": "deactivate", + "outcome": "rejected", + "replacement_budget_id": None, + "active_budget_id": None, + }, + requested_action="deactivate", + outcome="rejected", + ) + del trace + raise error + + updated = store.deactivate_execution_budget_optional(request.execution_budget_id) + if updated is None: + raise ExecutionBudgetLifecycleError( + f"execution budget {request.execution_budget_id} could not be deactivated" + ) + + trace = _record_lifecycle_trace( + store, + thread=thread, + request_payload=request_payload, + state_payload={ + "execution_budget_id": str(updated["id"]), + "requested_action": "deactivate", + "previous_status": "active", + "current_status": cast(str, updated["status"]), + "tool_key": updated["tool_key"], + "domain_hint": updated["domain_hint"], + "max_completed_executions": updated["max_completed_executions"], + "rolling_window_seconds": updated["rolling_window_seconds"], + "deactivated_at": ( + None if updated["deactivated_at"] is None else updated["deactivated_at"].isoformat() + ), + "superseded_by_budget_id": ( + None if updated["superseded_by_budget_id"] is None else str(updated["superseded_by_budget_id"]) + ), + "supersedes_budget_id": ( + None if updated["supersedes_budget_id"] is None else str(updated["supersedes_budget_id"]) + ), + "replacement_budget_id": None, + "replacement_status": None, + "replacement_max_completed_executions": None, + "replacement_rolling_window_seconds": None, + "rejection_reason": None, + }, + summary_payload={ + "execution_budget_id": str(updated["id"]), + "requested_action": "deactivate", + "outcome": "deactivated", + "replacement_budget_id": None, + "active_budget_id": None, + }, + requested_action="deactivate", + outcome="deactivated", + ) + return { + "execution_budget": serialize_execution_budget_row(updated), + "trace": cast(dict[str, object], trace), + } + + +def supersede_execution_budget_record( + store: ContinuityStore, + *, + user_id: UUID, + request: ExecutionBudgetSupersedeInput, +) -> ExecutionBudgetSupersedeResponse: + del user_id + + thread = 
_validate_lifecycle_thread(store, thread_id=request.thread_id) + current = store.get_execution_budget_optional(request.execution_budget_id) + if current is None: + raise ExecutionBudgetNotFoundError( + f"execution budget {request.execution_budget_id} was not found" + ) + + request_payload: ExecutionBudgetLifecycleRequestTracePayload = { + "thread_id": str(request.thread_id), + "execution_budget_id": str(request.execution_budget_id), + "requested_action": "supersede", + "replacement_max_completed_executions": request.max_completed_executions, + } + + if cast(str, current["status"]) != "active": + error = _invalid_transition_error(row=current, requested_action="supersede") + trace = _record_lifecycle_trace( + store, + thread=thread, + request_payload=request_payload, + state_payload={ + "execution_budget_id": str(current["id"]), + "requested_action": "supersede", + "previous_status": cast(str, current["status"]), + "current_status": cast(str, current["status"]), + "tool_key": current["tool_key"], + "domain_hint": current["domain_hint"], + "max_completed_executions": current["max_completed_executions"], + "rolling_window_seconds": current["rolling_window_seconds"], + "deactivated_at": ( + None if current["deactivated_at"] is None else current["deactivated_at"].isoformat() + ), + "superseded_by_budget_id": ( + None if current["superseded_by_budget_id"] is None else str(current["superseded_by_budget_id"]) + ), + "supersedes_budget_id": ( + None if current["supersedes_budget_id"] is None else str(current["supersedes_budget_id"]) + ), + "replacement_budget_id": None, + "replacement_status": None, + "replacement_max_completed_executions": request.max_completed_executions, + "replacement_rolling_window_seconds": current["rolling_window_seconds"], + "rejection_reason": str(error), + }, + summary_payload={ + "execution_budget_id": str(current["id"]), + "requested_action": "supersede", + "outcome": "rejected", + "replacement_budget_id": None, + "active_budget_id": str(current["id"]) if cast(str, current["status"]) == "active" else None, + }, + requested_action="supersede", + outcome="rejected", + ) + del trace + raise error + + active_scope_rows = _active_budget_rows_for_scope( + store, + tool_key=current["tool_key"], + domain_hint=current["domain_hint"], + ) + if [row["id"] for row in active_scope_rows] != [current["id"]]: + error = ExecutionBudgetLifecycleError( + "execution budget selector scope must have exactly one active budget to supersede: " + f"{_scope_label(tool_key=current['tool_key'], domain_hint=current['domain_hint'])}" + ) + trace = _record_lifecycle_trace( + store, + thread=thread, + request_payload=request_payload, + state_payload={ + "execution_budget_id": str(current["id"]), + "requested_action": "supersede", + "previous_status": "active", + "current_status": "active", + "tool_key": current["tool_key"], + "domain_hint": current["domain_hint"], + "max_completed_executions": current["max_completed_executions"], + "rolling_window_seconds": current["rolling_window_seconds"], + "deactivated_at": None, + "superseded_by_budget_id": None, + "supersedes_budget_id": ( + None if current["supersedes_budget_id"] is None else str(current["supersedes_budget_id"]) + ), + "replacement_budget_id": None, + "replacement_status": None, + "replacement_max_completed_executions": request.max_completed_executions, + "replacement_rolling_window_seconds": current["rolling_window_seconds"], + "rejection_reason": str(error), + }, + summary_payload={ + "execution_budget_id": str(current["id"]), + 
"requested_action": "supersede", + "outcome": "rejected", + "replacement_budget_id": None, + "active_budget_id": str(current["id"]), + }, + requested_action="supersede", + outcome="rejected", + ) + del trace + raise error + + replacement_budget_id = uuid4() + try: + with store.conn.transaction(): + superseded = store.supersede_execution_budget_optional( + execution_budget_id=request.execution_budget_id, + superseded_by_budget_id=replacement_budget_id, + ) + if superseded is None: + raise ExecutionBudgetLifecycleError( + f"execution budget {request.execution_budget_id} could not be superseded" + ) + replacement = store.create_execution_budget( + budget_id=replacement_budget_id, + tool_key=current["tool_key"], + domain_hint=current["domain_hint"], + max_completed_executions=request.max_completed_executions, + rolling_window_seconds=current["rolling_window_seconds"], + supersedes_budget_id=current["id"], + ) + except psycopg.IntegrityError as exc: + if _is_active_scope_uniqueness_error(exc): + error = ExecutionBudgetLifecycleError( + _duplicate_active_scope_message( + tool_key=current["tool_key"], + domain_hint=current["domain_hint"], + ) + ) + else: + raise + except ExecutionBudgetLifecycleError as exc: + error = exc + else: + error = None + + if error is not None: + current_state = store.get_execution_budget_optional(request.execution_budget_id) + if current_state is None: + raise ExecutionBudgetNotFoundError( + f"execution budget {request.execution_budget_id} was not found" + ) + trace = _record_lifecycle_trace( + store, + thread=thread, + request_payload=request_payload, + state_payload={ + "execution_budget_id": str(current_state["id"]), + "requested_action": "supersede", + "previous_status": cast(str, current["status"]), + "current_status": cast(str, current_state["status"]), + "tool_key": current_state["tool_key"], + "domain_hint": current_state["domain_hint"], + "max_completed_executions": current_state["max_completed_executions"], + "rolling_window_seconds": current_state["rolling_window_seconds"], + "deactivated_at": ( + None + if current_state["deactivated_at"] is None + else current_state["deactivated_at"].isoformat() + ), + "superseded_by_budget_id": ( + None + if current_state["superseded_by_budget_id"] is None + else str(current_state["superseded_by_budget_id"]) + ), + "supersedes_budget_id": ( + None + if current_state["supersedes_budget_id"] is None + else str(current_state["supersedes_budget_id"]) + ), + "replacement_budget_id": None, + "replacement_status": None, + "replacement_max_completed_executions": request.max_completed_executions, + "replacement_rolling_window_seconds": current["rolling_window_seconds"], + "rejection_reason": str(error), + }, + summary_payload={ + "execution_budget_id": str(current_state["id"]), + "requested_action": "supersede", + "outcome": "rejected", + "replacement_budget_id": None, + "active_budget_id": ( + str(current_state["id"]) + if cast(str, current_state["status"]) == "active" + else None + ), + }, + requested_action="supersede", + outcome="rejected", + ) + del trace + raise error + + trace = _record_lifecycle_trace( + store, + thread=thread, + request_payload=request_payload, + state_payload={ + "execution_budget_id": str(superseded["id"]), + "requested_action": "supersede", + "previous_status": "active", + "current_status": cast(str, superseded["status"]), + "tool_key": superseded["tool_key"], + "domain_hint": superseded["domain_hint"], + "max_completed_executions": superseded["max_completed_executions"], + "rolling_window_seconds": 
superseded["rolling_window_seconds"], + "deactivated_at": ( + None if superseded["deactivated_at"] is None else superseded["deactivated_at"].isoformat() + ), + "superseded_by_budget_id": ( + None if superseded["superseded_by_budget_id"] is None else str(superseded["superseded_by_budget_id"]) + ), + "supersedes_budget_id": ( + None if superseded["supersedes_budget_id"] is None else str(superseded["supersedes_budget_id"]) + ), + "replacement_budget_id": str(replacement["id"]), + "replacement_status": cast(str, replacement["status"]), + "replacement_max_completed_executions": replacement["max_completed_executions"], + "replacement_rolling_window_seconds": replacement["rolling_window_seconds"], + "rejection_reason": None, + }, + summary_payload={ + "execution_budget_id": str(superseded["id"]), + "requested_action": "supersede", + "outcome": "superseded", + "replacement_budget_id": str(replacement["id"]), + "active_budget_id": str(replacement["id"]), + }, + requested_action="supersede", + outcome="superseded", + ) + return { + "superseded_budget": serialize_execution_budget_row(superseded), + "replacement_budget": serialize_execution_budget_row(replacement), + "trace": cast(dict[str, object], trace), + } + + +def _budget_specificity(row: ExecutionBudgetRow) -> int: + return int(row["tool_key"] is not None) + int(row["domain_hint"] is not None) + + +def _matches_budget( + row: ExecutionBudgetRow, + *, + tool_key: str, + domain_hint: str | None, +) -> bool: + if row["tool_key"] is not None and row["tool_key"] != tool_key: + return False + if row["domain_hint"] is not None and row["domain_hint"] != domain_hint: + return False + return True + + +def _matching_budget_rows( + store: ContinuityStore, + *, + tool_key: str, + domain_hint: str | None, +) -> list[ExecutionBudgetRow]: + rows = [ + row + for row in store.list_execution_budgets() + if cast(str, row["status"]) == "active" + and _matches_budget(row, tool_key=tool_key, domain_hint=domain_hint) + ] + return sorted( + rows, + key=lambda row: (-_budget_specificity(row), row["created_at"], row["id"]), + ) + + +def _execution_matches_budget(row: ToolExecutionRow, budget: ExecutionBudgetRow) -> bool: + if cast(str, row["status"]) != "completed": + return False + + tool = cast(dict[str, object], row["tool"]) + request = cast(dict[str, object], row["request"]) + + if budget["tool_key"] is not None and tool.get("tool_key") != budget["tool_key"]: + return False + if budget["domain_hint"] is not None and request.get("domain_hint") != budget["domain_hint"]: + return False + return True + + +def _current_time(store: ContinuityStore) -> datetime: + current_time = getattr(store, "current_time", None) + if callable(current_time): + value = current_time() + if isinstance(value, datetime): + return value + return datetime.now(UTC) + + +def _window_started_at( + *, + evaluation_time: datetime, + rolling_window_seconds: int | None, +) -> datetime | None: + if rolling_window_seconds is None: + return None + return evaluation_time - timedelta(seconds=rolling_window_seconds) + + +def _counted_completed_execution_rows( + store: ContinuityStore, + *, + matched_budget: ExecutionBudgetRow, + evaluation_time: datetime, +) -> list[ToolExecutionRow]: + window_started_at = _window_started_at( + evaluation_time=evaluation_time, + rolling_window_seconds=matched_budget["rolling_window_seconds"], + ) + counted_rows: list[ToolExecutionRow] = [] + for row in store.list_tool_executions(): + execution_row = cast(ToolExecutionRow, row) + if not 
_execution_matches_budget(execution_row, matched_budget): + continue + if window_started_at is not None and execution_row["executed_at"] < window_started_at: + continue + counted_rows.append(execution_row) + return counted_rows + + +def _blocked_result( + decision: ExecutionBudgetDecisionRecord, +) -> ToolExecutionResultRecord: + matched_budget_id = decision["matched_budget_id"] + max_completed_executions = decision["max_completed_executions"] + projected_completed_execution_count = decision["projected_completed_execution_count"] + rolling_window_seconds = decision["rolling_window_seconds"] + if rolling_window_seconds is None: + reason = ( + f"execution budget {matched_budget_id} blocks execution: projected completed executions " + f"{projected_completed_execution_count} would exceed limit {max_completed_executions}" + ) + else: + reason = ( + f"execution budget {matched_budget_id} blocks execution: projected completed executions " + f"{projected_completed_execution_count} within rolling window {rolling_window_seconds} " + f"seconds would exceed limit {max_completed_executions}" + ) + return { + "handler_key": None, + "status": "blocked", + "output": None, + "reason": reason, + "budget_decision": decision, + } + + +def evaluate_execution_budget( + store: ContinuityStore, + *, + tool: ToolRecord, + request: ToolRoutingRequestRecord, +) -> ExecutionBudgetDecision: + matching_budgets = _matching_budget_rows( + store, + tool_key=tool["tool_key"], + domain_hint=request["domain_hint"], + ) + matched_budget = matching_budgets[0] if matching_budgets else None + evaluation_time = _current_time(store) + window_started_at = ( + None + if matched_budget is None + else _window_started_at( + evaluation_time=evaluation_time, + rolling_window_seconds=matched_budget["rolling_window_seconds"], + ) + ) + completed_execution_count = 0 + projected_completed_execution_count = 1 + + if matched_budget is not None: + completed_execution_count = len( + _counted_completed_execution_rows( + store, + matched_budget=matched_budget, + evaluation_time=evaluation_time, + ) + ) + projected_completed_execution_count = completed_execution_count + 1 + + record: ExecutionBudgetDecisionRecord = { + "matched_budget_id": None if matched_budget is None else str(matched_budget["id"]), + "tool_key": tool["tool_key"], + "domain_hint": request["domain_hint"], + "budget_tool_key": None if matched_budget is None else matched_budget["tool_key"], + "budget_domain_hint": None if matched_budget is None else matched_budget["domain_hint"], + "max_completed_executions": ( + None if matched_budget is None else matched_budget["max_completed_executions"] + ), + "rolling_window_seconds": ( + None if matched_budget is None else matched_budget["rolling_window_seconds"] + ), + "count_scope": ( + "lifetime" + if matched_budget is None or matched_budget["rolling_window_seconds"] is None + else "rolling_window" + ), + "window_started_at": None if window_started_at is None else window_started_at.isoformat(), + "completed_execution_count": completed_execution_count, + "projected_completed_execution_count": projected_completed_execution_count, + "decision": "allow", + "reason": "no_matching_budget", + "order": list(EXECUTION_BUDGET_MATCH_ORDER), + "history_order": list(TOOL_EXECUTION_LIST_ORDER), + } + + if matched_budget is None: + return ExecutionBudgetDecision(record=record, blocked_result=None) + + if projected_completed_execution_count <= matched_budget["max_completed_executions"]: + record["reason"] = "within_budget" + return 
ExecutionBudgetDecision(record=record, blocked_result=None) + + record["decision"] = "block" + record["reason"] = "budget_exceeded" + blocked_result = _blocked_result(record) + return ExecutionBudgetDecision(record=record, blocked_result=blocked_result) diff --git a/apps/api/src/alicebot_api/executions.py b/apps/api/src/alicebot_api/executions.py new file mode 100644 index 0000000..5bb0740 --- /dev/null +++ b/apps/api/src/alicebot_api/executions.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +from typing import cast +from uuid import UUID + +from alicebot_api.contracts import ( + TOOL_EXECUTION_LIST_ORDER, + ToolExecutionDetailResponse, + ToolExecutionListResponse, + ToolExecutionListSummary, + ToolExecutionRecord, +) +from alicebot_api.store import ContinuityStore, ToolExecutionRow + + +class ToolExecutionNotFoundError(LookupError): + """Raised when an execution record is not visible inside the current user scope.""" + + +def serialize_tool_execution_row(row: ToolExecutionRow) -> ToolExecutionRecord: + return { + "id": str(row["id"]), + "approval_id": str(row["approval_id"]), + "task_step_id": str(row["task_step_id"]), + "thread_id": str(row["thread_id"]), + "tool_id": str(row["tool_id"]), + "trace_id": str(row["trace_id"]), + "request_event_id": None if row["request_event_id"] is None else str(row["request_event_id"]), + "result_event_id": None if row["result_event_id"] is None else str(row["result_event_id"]), + "status": cast(str, row["status"]), + "handler_key": row["handler_key"], + "request": cast(dict[str, object], row["request"]), + "tool": cast(dict[str, object], row["tool"]), + "result": cast(dict[str, object], row["result"]), + "executed_at": row["executed_at"].isoformat(), + } + + +def list_tool_execution_records( + store: ContinuityStore, + *, + user_id: UUID, +) -> ToolExecutionListResponse: + del user_id + + items = [serialize_tool_execution_row(row) for row in store.list_tool_executions()] + summary: ToolExecutionListSummary = { + "total_count": len(items), + "order": list(TOOL_EXECUTION_LIST_ORDER), + } + return { + "items": items, + "summary": summary, + } + + +def get_tool_execution_record( + store: ContinuityStore, + *, + user_id: UUID, + execution_id: UUID, +) -> ToolExecutionDetailResponse: + del user_id + + execution = store.get_tool_execution_optional(execution_id) + if execution is None: + raise ToolExecutionNotFoundError(f"tool execution {execution_id} was not found") + return {"execution": serialize_tool_execution_row(execution)} diff --git a/apps/api/src/alicebot_api/explicit_preferences.py b/apps/api/src/alicebot_api/explicit_preferences.py new file mode 100644 index 0000000..f451426 --- /dev/null +++ b/apps/api/src/alicebot_api/explicit_preferences.py @@ -0,0 +1,262 @@ +from __future__ import annotations + +import hashlib +import re +from collections.abc import Sequence +from typing import Literal +from uuid import UUID + +from alicebot_api.contracts import ( + AdmissionDecisionOutput, + ExplicitPreferenceAdmissionRecord, + ExplicitPreferenceExtractionRequestInput, + ExplicitPreferenceExtractionResponse, + ExplicitPreferenceExtractionSummary, + ExplicitPreferencePattern, + ExtractedPreferenceCandidateRecord, + MemoryCandidateInput, +) +from alicebot_api.memory import admit_memory_candidate +from alicebot_api.store import ContinuityStore, EventRow, JsonObject + +PreferenceKind = Literal["like", "dislike", "prefer"] +_DIRECT_PATTERNS: tuple[tuple[ExplicitPreferencePattern, PreferenceKind, re.Pattern[str]], ...] 
= ( + ("i_like", "like", re.compile(r"^i like (?P.+)$", re.IGNORECASE)), + ("i_dont_like", "dislike", re.compile(r"^i don't like (?P.+)$", re.IGNORECASE)), + ("i_prefer", "prefer", re.compile(r"^i prefer (?P.+)$", re.IGNORECASE)), +) +_REMEMBER_PREFIX = "remember that " +_TRAILING_PUNCTUATION = ".!?" +_MEMORY_KEY_PREFIX = "user.preference." +_MAX_MEMORY_KEY_LENGTH = 200 +_MEMORY_KEY_HASH_LENGTH = 12 +_MAX_SUBJECT_TOKENS = 6 +_ALLOWED_SUBJECT_TOKEN = re.compile(r"^[a-z0-9][a-z0-9+#&./+'-]*$", re.IGNORECASE) +_DISALLOWED_SUBJECT_PREFIX_TOKENS = { + "that", + "to", + "if", + "when", + "because", + "whether", + "we", + "you", + "they", + "he", + "she", + "it", + "there", + "this", +} +_REMEMBER_PATTERN_MAP: dict[ExplicitPreferencePattern, ExplicitPreferencePattern] = { + "i_like": "remember_that_i_like", + "i_dont_like": "remember_that_i_dont_like", + "i_prefer": "remember_that_i_prefer", + "remember_that_i_like": "remember_that_i_like", + "remember_that_i_dont_like": "remember_that_i_dont_like", + "remember_that_i_prefer": "remember_that_i_prefer", +} + + +class ExplicitPreferenceExtractionValidationError(ValueError): + """Raised when an explicit-preference extraction request is invalid.""" + + +def _normalize_whitespace(value: str) -> str: + return re.sub(r"\s+", " ", value).strip() + + +def _normalize_subject(subject: str) -> str: + normalized = _normalize_whitespace(subject) + normalized = normalized.rstrip(_TRAILING_PUNCTUATION).strip() + return normalized + + +def _canonicalize_subject_for_key(subject: str) -> str: + return subject.casefold() + + +def _subject_has_supported_shape(subject: str) -> bool: + tokens = subject.split(" ") + if not tokens or len(tokens) > _MAX_SUBJECT_TOKENS: + return False + + if tokens[0].casefold() in _DISALLOWED_SUBJECT_PREFIX_TOKENS: + return False + + return all(_ALLOWED_SUBJECT_TOKEN.fullmatch(token) is not None for token in tokens) + + +def _slugify_subject(subject: str, *, max_length: int) -> str: + slug = subject.casefold() + slug = slug.replace("'", "") + slug = re.sub(r"[^a-z0-9]+", "_", slug) + slug = slug.strip("_") + if len(slug) > max_length: + slug = slug[:max_length].rstrip("_") + return slug + + +def _build_memory_key(subject: str) -> str: + canonical_subject = _canonicalize_subject_for_key(subject) + digest = hashlib.sha256(canonical_subject.encode("utf-8")).hexdigest()[:_MEMORY_KEY_HASH_LENGTH] + max_slug_length = _MAX_MEMORY_KEY_LENGTH - len(_MEMORY_KEY_PREFIX) - len("__") - len(digest) + slug = _slugify_subject(canonical_subject, max_length=max_slug_length) + if not slug: + return f"{_MEMORY_KEY_PREFIX}{digest}" + return f"{_MEMORY_KEY_PREFIX}{slug}__{digest}" + + +def _build_candidate( + *, + source_event_id: UUID, + pattern: ExplicitPreferencePattern, + preference: PreferenceKind, + subject_text: str, +) -> ExtractedPreferenceCandidateRecord | None: + normalized_subject = _normalize_subject(subject_text) + if not normalized_subject: + return None + + if not _subject_has_supported_shape(normalized_subject): + return None + + value: JsonObject = { + "kind": "explicit_preference", + "preference": preference, + "text": normalized_subject, + } + return { + "memory_key": _build_memory_key(normalized_subject), + "value": value, + "source_event_ids": [str(source_event_id)], + "delete_requested": False, + "pattern": pattern, + "subject_text": normalized_subject, + } + + +def extract_explicit_preference_candidates( + *, + source_event_id: UUID, + text: str, +) -> list[ExtractedPreferenceCandidateRecord]: + normalized_text = 
_normalize_whitespace(text) + if not normalized_text: + return [] + + for pattern_name, preference, pattern in _DIRECT_PATTERNS: + match = pattern.fullmatch(normalized_text) + if match is not None: + candidate = _build_candidate( + source_event_id=source_event_id, + pattern=pattern_name, + preference=preference, + subject_text=match.group("subject"), + ) + return [] if candidate is None else [candidate] + + lowered_text = normalized_text.lower() + if lowered_text.startswith(_REMEMBER_PREFIX): + nested_text = normalized_text[len(_REMEMBER_PREFIX) :] + nested_candidates = extract_explicit_preference_candidates( + source_event_id=source_event_id, + text=nested_text, + ) + if not nested_candidates: + return [] + candidate = dict(nested_candidates[0]) + candidate["pattern"] = _REMEMBER_PATTERN_MAP[candidate["pattern"]] + return [candidate] + + return [] + + +def _get_single_source_event(store: ContinuityStore, source_event_id: UUID) -> EventRow: + events = store.list_events_by_ids([source_event_id]) + if not events: + raise ExplicitPreferenceExtractionValidationError( + "source_event_id must reference an existing message.user event owned by the user" + ) + return events[0] + + +def _extract_text_payload(event: EventRow) -> str: + if event["kind"] != "message.user": + raise ExplicitPreferenceExtractionValidationError( + "source_event_id must reference an existing message.user event owned by the user" + ) + + payload_text = event["payload"].get("text") + if not isinstance(payload_text, str): + raise ExplicitPreferenceExtractionValidationError( + "source_event_id must reference a message.user event with string payload.text" + ) + + return payload_text + + +def _serialize_admission(decision: AdmissionDecisionOutput) -> ExplicitPreferenceAdmissionRecord: + return { + "decision": decision.action, + "reason": decision.reason, + "memory": decision.memory, + "revision": decision.revision, + } + + +def _build_summary( + *, + source_event_id: UUID, + source_event_kind: str, + admissions: Sequence[ExplicitPreferenceAdmissionRecord], + candidates: Sequence[ExtractedPreferenceCandidateRecord], +) -> ExplicitPreferenceExtractionSummary: + noop_count = sum(1 for admission in admissions if admission["decision"] == "NOOP") + return { + "source_event_id": str(source_event_id), + "source_event_kind": source_event_kind, + "candidate_count": len(candidates), + "admission_count": len(admissions), + "persisted_change_count": len(admissions) - noop_count, + "noop_count": noop_count, + } + + +def extract_and_admit_explicit_preferences( + store: ContinuityStore, + *, + user_id: UUID, + request: ExplicitPreferenceExtractionRequestInput, +) -> ExplicitPreferenceExtractionResponse: + source_event = _get_single_source_event(store, request.source_event_id) + payload_text = _extract_text_payload(source_event) + candidates = extract_explicit_preference_candidates( + source_event_id=request.source_event_id, + text=payload_text, + ) + + admissions: list[ExplicitPreferenceAdmissionRecord] = [] + for candidate in candidates: + decision = admit_memory_candidate( + store, + user_id=user_id, + candidate=MemoryCandidateInput( + memory_key=candidate["memory_key"], + value=candidate["value"], + source_event_ids=(request.source_event_id,), + delete_requested=candidate["delete_requested"], + ), + ) + admissions.append(_serialize_admission(decision)) + + return { + "candidates": list(candidates), + "admissions": admissions, + "summary": _build_summary( + source_event_id=request.source_event_id, + source_event_kind=source_event["kind"], + 
admissions=admissions, + candidates=candidates, + ), + } diff --git a/apps/api/src/alicebot_api/main.py b/apps/api/src/alicebot_api/main.py new file mode 100644 index 0000000..764812d --- /dev/null +++ b/apps/api/src/alicebot_api/main.py @@ -0,0 +1,1837 @@ +from __future__ import annotations + +from datetime import datetime +from typing import Literal, TypedDict +from uuid import UUID +from fastapi import FastAPI, Query +from fastapi.encoders import jsonable_encoder +from pydantic import BaseModel, Field +from fastapi.responses import JSONResponse +from urllib.parse import urlsplit, urlunsplit + +from alicebot_api.compiler import compile_and_persist_trace +from alicebot_api.config import Settings, get_settings +from alicebot_api.contracts import ( + ApprovalApproveInput, + ApprovalRejectInput, + ApprovalRequestCreateInput, + ConsentStatus, + ConsentUpsertInput, + CompileContextSemanticRetrievalInput, + DEFAULT_MAX_EVENTS, + DEFAULT_MAX_ENTITY_EDGES, + DEFAULT_MAX_ENTITIES, + DEFAULT_MAX_MEMORIES, + DEFAULT_MEMORY_REVIEW_LIMIT, + DEFAULT_MAX_SESSIONS, + DEFAULT_SEMANTIC_MEMORY_RETRIEVAL_LIMIT, + MAX_MEMORY_REVIEW_LIMIT, + MAX_SEMANTIC_MEMORY_RETRIEVAL_LIMIT, + ContextCompilerLimits, + EmbeddingConfigStatus, + EmbeddingConfigCreateInput, + ExecutionBudgetCreateInput, + ExecutionBudgetDeactivateInput, + ExecutionBudgetSupersedeInput, + EntityEdgeCreateInput, + EntityCreateInput, + EntityType, + ExplicitPreferenceExtractionRequestInput, + MemoryCandidateInput, + MemoryEmbeddingUpsertInput, + MemoryReviewLabelValue, + MemoryReviewStatusFilter, + PolicyCreateInput, + PolicyEffect, + PolicyEvaluationRequestInput, + SemanticMemoryRetrievalRequestInput, + TOOL_METADATA_VERSION_V0, + ApprovalStatus, + ProxyExecutionStatus, + ToolAllowlistEvaluationRequestInput, + ProxyExecutionRequestInput, + TaskStepKind, + TaskStepLineageInput, + TaskStepNextCreateInput, + TaskStepStatus, + TaskStepTransitionInput, + TaskWorkspaceCreateInput, + ToolRoutingDecision, + ToolRoutingRequestInput, + ToolCreateInput, +) +from alicebot_api.approvals import ( + ApprovalNotFoundError, + ApprovalResolutionConflictError, + approve_approval_record, + get_approval_record, + list_approval_records, + reject_approval_record, + submit_approval_request, +) +from alicebot_api.db import ping_database, user_connection +from alicebot_api.executions import ( + ToolExecutionNotFoundError, + get_tool_execution_record, + list_tool_execution_records, +) +from alicebot_api.tasks import ( + TaskNotFoundError, + TaskStepApprovalLinkageError, + TaskStepExecutionLinkageError, + TaskStepLifecycleBoundaryError, + TaskStepSequenceError, + TaskStepNotFoundError, + TaskStepTransitionError, + create_next_task_step_record, + get_task_record, + get_task_step_record, + list_task_records, + list_task_step_records, + transition_task_step_record, +) +from alicebot_api.workspaces import ( + TaskWorkspaceAlreadyExistsError, + TaskWorkspaceNotFoundError, + TaskWorkspaceProvisioningError, + create_task_workspace_record, + get_task_workspace_record, + list_task_workspace_records, +) +from alicebot_api.execution_budgets import ( + ExecutionBudgetLifecycleError, + ExecutionBudgetNotFoundError, + ExecutionBudgetValidationError, + create_execution_budget_record, + deactivate_execution_budget_record, + get_execution_budget_record, + list_execution_budget_records, + supersede_execution_budget_record, +) +from alicebot_api.embedding import ( + EmbeddingConfigValidationError, + MemoryEmbeddingNotFoundError, + MemoryEmbeddingValidationError, + 
create_embedding_config_record, + get_memory_embedding_record, + list_embedding_config_records, + list_memory_embedding_records, + upsert_memory_embedding_record, +) +from alicebot_api.entity import ( + EntityNotFoundError, + EntityValidationError, + create_entity_record, + get_entity_record, + list_entity_records, +) +from alicebot_api.entity_edge import ( + EntityEdgeValidationError, + create_entity_edge_record, + list_entity_edge_records, +) +from alicebot_api.explicit_preferences import ( + ExplicitPreferenceExtractionValidationError, + extract_and_admit_explicit_preferences, +) +from alicebot_api.memory import ( + MemoryAdmissionValidationError, + MemoryReviewNotFoundError, + admit_memory_candidate, + create_memory_review_label_record, + get_memory_evaluation_summary, + get_memory_review_record, + list_memory_review_queue_records, + list_memory_review_label_records, + list_memory_review_records, + list_memory_revision_review_records, +) +from alicebot_api.policy import ( + PolicyEvaluationValidationError, + PolicyNotFoundError, + PolicyValidationError, + create_policy_record, + evaluate_policy_request, + get_policy_record, + list_consent_records, + list_policy_records, + upsert_consent_record, +) +from alicebot_api.tools import ( + ToolAllowlistValidationError, + ToolNotFoundError, + ToolRoutingValidationError, + ToolValidationError, + create_tool_record, + evaluate_tool_allowlist, + get_tool_record, + list_tool_records, + route_tool_invocation, +) +from alicebot_api.semantic_retrieval import ( + SemanticMemoryRetrievalValidationError, + retrieve_semantic_memory_records, +) +from alicebot_api.response_generation import ( + ResponseFailure, + generate_response, +) +from alicebot_api.proxy_execution import ( + ProxyExecutionApprovalStateError, + ProxyExecutionHandlerNotFoundError, + execute_approved_proxy_request, +) +from alicebot_api.store import ContinuityStore, ContinuityStoreInvariantError + + +app = FastAPI(title="AliceBot API", version="0.1.0") +HealthStatus = Literal["ok", "degraded"] +ServiceStatus = Literal["ok", "unreachable", "not_checked"] + + +class DatabaseServicePayload(TypedDict): + status: Literal["ok", "unreachable"] + + +class RedisServicePayload(TypedDict): + status: Literal["not_checked"] + url: str + + +class ObjectStorageServicePayload(TypedDict): + status: Literal["not_checked"] + endpoint_url: str + + +class HealthServicesPayload(TypedDict): + database: DatabaseServicePayload + redis: RedisServicePayload + object_storage: ObjectStorageServicePayload + + +class HealthcheckPayload(TypedDict): + status: HealthStatus + environment: str + services: HealthServicesPayload + + +class CompileContextSemanticRequest(BaseModel): + embedding_config_id: UUID + query_vector: list[float] = Field(min_length=1, max_length=20000) + limit: int = Field( + default=DEFAULT_SEMANTIC_MEMORY_RETRIEVAL_LIMIT, + ge=1, + le=MAX_SEMANTIC_MEMORY_RETRIEVAL_LIMIT, + ) + + +class CompileContextRequest(BaseModel): + user_id: UUID + thread_id: UUID + max_sessions: int = Field(default=DEFAULT_MAX_SESSIONS, ge=0, le=25) + max_events: int = Field(default=DEFAULT_MAX_EVENTS, ge=0, le=200) + max_memories: int = Field(default=DEFAULT_MAX_MEMORIES, ge=0, le=50) + max_entities: int = Field(default=DEFAULT_MAX_ENTITIES, ge=0, le=50) + max_entity_edges: int = Field(default=DEFAULT_MAX_ENTITY_EDGES, ge=0, le=100) + semantic: CompileContextSemanticRequest | None = None + + +class GenerateResponseRequest(BaseModel): + user_id: UUID + thread_id: UUID + message: str = Field(min_length=1, max_length=8000) + 
max_sessions: int = Field(default=DEFAULT_MAX_SESSIONS, ge=0, le=25) + max_events: int = Field(default=DEFAULT_MAX_EVENTS, ge=0, le=200) + max_memories: int = Field(default=DEFAULT_MAX_MEMORIES, ge=0, le=50) + max_entities: int = Field(default=DEFAULT_MAX_ENTITIES, ge=0, le=50) + max_entity_edges: int = Field(default=DEFAULT_MAX_ENTITY_EDGES, ge=0, le=100) + + +class AdmitMemoryRequest(BaseModel): + user_id: UUID + memory_key: str = Field(min_length=1, max_length=200) + value: object | None = None + source_event_ids: list[UUID] = Field(min_length=1) + delete_requested: bool = False + + +class ExtractExplicitPreferencesRequest(BaseModel): + user_id: UUID + source_event_id: UUID + + +class CreateMemoryReviewLabelRequest(BaseModel): + user_id: UUID + label: MemoryReviewLabelValue + note: str | None = Field(default=None, min_length=1, max_length=280) + + +class CreateEntityRequest(BaseModel): + user_id: UUID + entity_type: EntityType + name: str = Field(min_length=1, max_length=200) + source_memory_ids: list[UUID] = Field(min_length=1) + + +class CreateEntityEdgeRequest(BaseModel): + user_id: UUID + from_entity_id: UUID + to_entity_id: UUID + relationship_type: str = Field(min_length=1, max_length=100) + valid_from: datetime | None = None + valid_to: datetime | None = None + source_memory_ids: list[UUID] = Field(min_length=1) + + +class CreateEmbeddingConfigRequest(BaseModel): + user_id: UUID + provider: str = Field(min_length=1, max_length=100) + model: str = Field(min_length=1, max_length=200) + version: str = Field(min_length=1, max_length=100) + dimensions: int = Field(ge=1, le=20000) + status: EmbeddingConfigStatus = "active" + metadata: dict[str, object] = Field(default_factory=dict) + + +class UpsertMemoryEmbeddingRequest(BaseModel): + user_id: UUID + memory_id: UUID + embedding_config_id: UUID + vector: list[float] = Field(min_length=1, max_length=20000) + + +class RetrieveSemanticMemoriesRequest(BaseModel): + user_id: UUID + embedding_config_id: UUID + query_vector: list[float] = Field(min_length=1, max_length=20000) + limit: int = Field( + default=DEFAULT_SEMANTIC_MEMORY_RETRIEVAL_LIMIT, + ge=1, + le=MAX_SEMANTIC_MEMORY_RETRIEVAL_LIMIT, + ) + + +class UpsertConsentRequest(BaseModel): + user_id: UUID + consent_key: str = Field(min_length=1, max_length=200) + status: ConsentStatus + metadata: dict[str, object] = Field(default_factory=dict) + + +class CreatePolicyRequest(BaseModel): + user_id: UUID + name: str = Field(min_length=1, max_length=200) + action: str = Field(min_length=1, max_length=100) + scope: str = Field(min_length=1, max_length=200) + effect: PolicyEffect + priority: int = Field(ge=0, le=1000000) + active: bool = True + conditions: dict[str, object] = Field(default_factory=dict) + required_consents: list[str] = Field(default_factory=list) + + +class EvaluatePolicyRequest(BaseModel): + user_id: UUID + thread_id: UUID + action: str = Field(min_length=1, max_length=100) + scope: str = Field(min_length=1, max_length=200) + attributes: dict[str, object] = Field(default_factory=dict) + + +class CreateToolRequest(BaseModel): + user_id: UUID + tool_key: str = Field(min_length=1, max_length=200) + name: str = Field(min_length=1, max_length=200) + description: str = Field(min_length=1, max_length=500) + version: str = Field(min_length=1, max_length=100) + metadata_version: str = Field(default=TOOL_METADATA_VERSION_V0, pattern=f"^{TOOL_METADATA_VERSION_V0}$") + active: bool = True + tags: list[str] = Field(default_factory=list) + action_hints: list[str] = 
Field(default_factory=list, min_length=1) + scope_hints: list[str] = Field(default_factory=list, min_length=1) + domain_hints: list[str] = Field(default_factory=list) + risk_hints: list[str] = Field(default_factory=list) + metadata: dict[str, object] = Field(default_factory=dict) + + +class EvaluateToolAllowlistRequest(BaseModel): + user_id: UUID + thread_id: UUID + action: str = Field(min_length=1, max_length=100) + scope: str = Field(min_length=1, max_length=200) + domain_hint: str | None = Field(default=None, min_length=1, max_length=200) + risk_hint: str | None = Field(default=None, min_length=1, max_length=100) + attributes: dict[str, object] = Field(default_factory=dict) + + +class RouteToolRequest(BaseModel): + user_id: UUID + thread_id: UUID + tool_id: UUID + action: str = Field(min_length=1, max_length=100) + scope: str = Field(min_length=1, max_length=200) + domain_hint: str | None = Field(default=None, min_length=1, max_length=200) + risk_hint: str | None = Field(default=None, min_length=1, max_length=100) + attributes: dict[str, object] = Field(default_factory=dict) + + +class CreateApprovalRequest(BaseModel): + user_id: UUID + thread_id: UUID + tool_id: UUID + action: str = Field(min_length=1, max_length=100) + scope: str = Field(min_length=1, max_length=200) + domain_hint: str | None = Field(default=None, min_length=1, max_length=200) + risk_hint: str | None = Field(default=None, min_length=1, max_length=100) + attributes: dict[str, object] = Field(default_factory=dict) + + +class ResolveApprovalRequest(BaseModel): + user_id: UUID + + +class ExecuteApprovedProxyRequest(BaseModel): + user_id: UUID + + +class CreateTaskWorkspaceRequest(BaseModel): + user_id: UUID + + +class TaskStepRequestSnapshot(BaseModel): + thread_id: UUID + tool_id: UUID + action: str = Field(min_length=1, max_length=100) + scope: str = Field(min_length=1, max_length=200) + domain_hint: str | None = Field(default=None, min_length=1, max_length=200) + risk_hint: str | None = Field(default=None, min_length=1, max_length=100) + attributes: dict[str, object] = Field(default_factory=dict) + + +class TaskStepOutcomeRequest(BaseModel): + routing_decision: ToolRoutingDecision + approval_id: UUID | None = None + approval_status: ApprovalStatus | None = None + execution_id: UUID | None = None + execution_status: ProxyExecutionStatus | None = None + blocked_reason: str | None = Field(default=None, min_length=1, max_length=500) + + +class TaskStepLineageRequest(BaseModel): + parent_step_id: UUID + source_approval_id: UUID | None = None + source_execution_id: UUID | None = None + + +class CreateNextTaskStepRequest(BaseModel): + user_id: UUID + kind: TaskStepKind = "governed_request" + status: TaskStepStatus + request: TaskStepRequestSnapshot + outcome: TaskStepOutcomeRequest + lineage: TaskStepLineageRequest + + +class TransitionTaskStepRequest(BaseModel): + user_id: UUID + status: TaskStepStatus + outcome: TaskStepOutcomeRequest + + +class CreateExecutionBudgetRequest(BaseModel): + user_id: UUID + tool_key: str | None = Field(default=None, min_length=1, max_length=200) + domain_hint: str | None = Field(default=None, min_length=1, max_length=200) + max_completed_executions: int = Field(ge=1, le=1000000) + rolling_window_seconds: int | None = Field(default=None, ge=1) + + +class DeactivateExecutionBudgetRequest(BaseModel): + user_id: UUID + thread_id: UUID + + +class SupersedeExecutionBudgetRequest(BaseModel): + user_id: UUID + thread_id: UUID + max_completed_executions: int = Field(ge=1, le=1000000) + + +def 
redact_url_credentials(raw_url: str) -> str: + parsed = urlsplit(raw_url) + + if parsed.hostname is None or (parsed.username is None and parsed.password is None): + return raw_url + + hostname = parsed.hostname + if ":" in hostname and not hostname.startswith("["): + hostname = f"[{hostname}]" + + netloc = hostname + if parsed.port is not None: + netloc = f"{hostname}:{parsed.port}" + + return urlunsplit((parsed.scheme, netloc, parsed.path, parsed.query, parsed.fragment)) + + +def build_healthcheck_payload(settings: Settings, database_ok: bool) -> HealthcheckPayload: + status: HealthStatus = "ok" if database_ok else "degraded" + database_status: Literal["ok", "unreachable"] = "ok" if database_ok else "unreachable" + + return { + "status": status, + "environment": settings.app_env, + "services": { + "database": { + "status": database_status, + }, + "redis": { + "status": "not_checked", + "url": redact_url_credentials(settings.redis_url), + }, + "object_storage": { + "status": "not_checked", + "endpoint_url": settings.s3_endpoint_url, + }, + }, + } + + +@app.get("/healthz") +def healthcheck() -> JSONResponse: + settings = get_settings() + database_ok = ping_database( + settings.database_url, + settings.healthcheck_timeout_seconds, + ) + payload = build_healthcheck_payload(settings, database_ok) + status_code = 200 if payload["status"] == "ok" else 503 + return JSONResponse( + status_code=status_code, + content=payload, + ) + + +@app.post("/v0/context/compile") +def compile_context(request: CompileContextRequest) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, request.user_id) as conn: + result = compile_and_persist_trace( + ContinuityStore(conn), + user_id=request.user_id, + thread_id=request.thread_id, + limits=ContextCompilerLimits( + max_sessions=request.max_sessions, + max_events=request.max_events, + max_memories=request.max_memories, + max_entities=request.max_entities, + max_entity_edges=request.max_entity_edges, + ), + semantic_retrieval=( + None + if request.semantic is None + else CompileContextSemanticRetrievalInput( + embedding_config_id=request.semantic.embedding_config_id, + query_vector=tuple(request.semantic.query_vector), + limit=request.semantic.limit, + ) + ), + ) + except SemanticMemoryRetrievalValidationError as exc: + return JSONResponse(status_code=400, content={"detail": str(exc)}) + except ContinuityStoreInvariantError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder( + { + "trace_id": result.trace_id, + "trace_event_count": result.trace_event_count, + "context_pack": result.context_pack, + } + ), + ) + + +@app.post("/v0/responses") +def generate_assistant_response(request: GenerateResponseRequest) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, request.user_id) as conn: + result = generate_response( + store=ContinuityStore(conn), + settings=settings, + user_id=request.user_id, + thread_id=request.thread_id, + message_text=request.message, + limits=ContextCompilerLimits( + max_sessions=request.max_sessions, + max_events=request.max_events, + max_memories=request.max_memories, + max_entities=request.max_entities, + max_entity_edges=request.max_entity_edges, + ), + ) + except ContinuityStoreInvariantError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + + if isinstance(result, ResponseFailure): + return JSONResponse( + status_code=502, + 
content=jsonable_encoder( + { + "detail": result.detail, + "trace": result.trace, + } + ), + ) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(result), + ) + + +@app.post("/v0/memories/admit") +def admit_memory(request: AdmitMemoryRequest) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, request.user_id) as conn: + decision = admit_memory_candidate( + ContinuityStore(conn), + user_id=request.user_id, + candidate=MemoryCandidateInput( + memory_key=request.memory_key, + value=request.value, + source_event_ids=tuple(request.source_event_ids), + delete_requested=request.delete_requested, + ), + ) + except MemoryAdmissionValidationError as exc: + return JSONResponse(status_code=400, content={"detail": str(exc)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder( + { + "decision": decision.action, + "reason": decision.reason, + "memory": decision.memory, + "revision": decision.revision, + } + ), + ) + + +@app.post("/v0/consents") +def upsert_consent(request: UpsertConsentRequest) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, request.user_id) as conn: + payload = upsert_consent_record( + ContinuityStore(conn), + user_id=request.user_id, + consent=ConsentUpsertInput( + consent_key=request.consent_key, + status=request.status, + metadata=request.metadata, + ), + ) + except PolicyValidationError as exc: + return JSONResponse(status_code=400, content={"detail": str(exc)}) + + status_code = 201 if payload["write_mode"] == "created" else 200 + return JSONResponse( + status_code=status_code, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/consents") +def list_consents(user_id: UUID) -> JSONResponse: + settings = get_settings() + + with user_connection(settings.database_url, user_id) as conn: + payload = list_consent_records( + ContinuityStore(conn), + user_id=user_id, + ) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.post("/v0/policies") +def create_policy(request: CreatePolicyRequest) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, request.user_id) as conn: + payload = create_policy_record( + ContinuityStore(conn), + user_id=request.user_id, + policy=PolicyCreateInput( + name=request.name, + action=request.action, + scope=request.scope, + effect=request.effect, + priority=request.priority, + active=request.active, + conditions=request.conditions, + required_consents=tuple(request.required_consents), + ), + ) + except PolicyValidationError as exc: + return JSONResponse(status_code=400, content={"detail": str(exc)}) + + return JSONResponse( + status_code=201, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/policies") +def list_policies(user_id: UUID) -> JSONResponse: + settings = get_settings() + + with user_connection(settings.database_url, user_id) as conn: + payload = list_policy_records( + ContinuityStore(conn), + user_id=user_id, + ) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/policies/{policy_id}") +def get_policy(policy_id: UUID, user_id: UUID) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, user_id) as conn: + payload = get_policy_record( + ContinuityStore(conn), + user_id=user_id, + policy_id=policy_id, + ) + except PolicyNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + + 
return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.post("/v0/policies/evaluate") +def evaluate_policy(request: EvaluatePolicyRequest) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, request.user_id) as conn: + payload = evaluate_policy_request( + ContinuityStore(conn), + user_id=request.user_id, + request=PolicyEvaluationRequestInput( + thread_id=request.thread_id, + action=request.action, + scope=request.scope, + attributes=request.attributes, + ), + ) + except PolicyEvaluationValidationError as exc: + return JSONResponse(status_code=400, content={"detail": str(exc)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.post("/v0/tools") +def create_tool(request: CreateToolRequest) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, request.user_id) as conn: + payload = create_tool_record( + ContinuityStore(conn), + user_id=request.user_id, + tool=ToolCreateInput( + tool_key=request.tool_key, + name=request.name, + description=request.description, + version=request.version, + metadata_version=request.metadata_version, + active=request.active, + tags=tuple(request.tags), + action_hints=tuple(request.action_hints), + scope_hints=tuple(request.scope_hints), + domain_hints=tuple(request.domain_hints), + risk_hints=tuple(request.risk_hints), + metadata=request.metadata, + ), + ) + except ToolValidationError as exc: + return JSONResponse(status_code=400, content={"detail": str(exc)}) + + return JSONResponse( + status_code=201, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/tools") +def list_tools(user_id: UUID) -> JSONResponse: + settings = get_settings() + + with user_connection(settings.database_url, user_id) as conn: + payload = list_tool_records( + ContinuityStore(conn), + user_id=user_id, + ) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.post("/v0/tools/allowlist/evaluate") +def evaluate_tools_allowlist(request: EvaluateToolAllowlistRequest) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, request.user_id) as conn: + payload = evaluate_tool_allowlist( + ContinuityStore(conn), + user_id=request.user_id, + request=ToolAllowlistEvaluationRequestInput( + thread_id=request.thread_id, + action=request.action, + scope=request.scope, + domain_hint=request.domain_hint, + risk_hint=request.risk_hint, + attributes=request.attributes, + ), + ) + except ToolAllowlistValidationError as exc: + return JSONResponse(status_code=400, content={"detail": str(exc)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.post("/v0/tools/route") +def route_tool(request: RouteToolRequest) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, request.user_id) as conn: + payload = route_tool_invocation( + ContinuityStore(conn), + user_id=request.user_id, + request=ToolRoutingRequestInput( + thread_id=request.thread_id, + tool_id=request.tool_id, + action=request.action, + scope=request.scope, + domain_hint=request.domain_hint, + risk_hint=request.risk_hint, + attributes=request.attributes, + ), + ) + except ToolRoutingValidationError as exc: + return JSONResponse(status_code=400, content={"detail": str(exc)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.post("/v0/approvals/requests") +def 
create_approval_request(request: CreateApprovalRequest) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, request.user_id) as conn: + payload = submit_approval_request( + ContinuityStore(conn), + user_id=request.user_id, + request=ApprovalRequestCreateInput( + thread_id=request.thread_id, + tool_id=request.tool_id, + action=request.action, + scope=request.scope, + domain_hint=request.domain_hint, + risk_hint=request.risk_hint, + attributes=request.attributes, + ), + ) + except ToolRoutingValidationError as exc: + return JSONResponse(status_code=400, content={"detail": str(exc)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/approvals") +def list_approvals(user_id: UUID) -> JSONResponse: + settings = get_settings() + + with user_connection(settings.database_url, user_id) as conn: + payload = list_approval_records( + ContinuityStore(conn), + user_id=user_id, + ) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/approvals/{approval_id}") +def get_approval(approval_id: UUID, user_id: UUID) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, user_id) as conn: + payload = get_approval_record( + ContinuityStore(conn), + user_id=user_id, + approval_id=approval_id, + ) + except ApprovalNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.post("/v0/approvals/{approval_id}/approve") +def approve_approval(approval_id: UUID, request: ResolveApprovalRequest) -> JSONResponse: + settings = get_settings() + resolution_error: ( + ApprovalResolutionConflictError | TaskStepApprovalLinkageError | TaskStepLifecycleBoundaryError | None + ) = None + + try: + with user_connection(settings.database_url, request.user_id) as conn: + try: + payload = approve_approval_record( + ContinuityStore(conn), + user_id=request.user_id, + request=ApprovalApproveInput(approval_id=approval_id), + ) + except ( + ApprovalResolutionConflictError, + TaskStepApprovalLinkageError, + TaskStepLifecycleBoundaryError, + ) as exc: + resolution_error = exc + payload = None + except ApprovalNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + + if resolution_error is not None: + return JSONResponse(status_code=409, content={"detail": str(resolution_error)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.post("/v0/approvals/{approval_id}/reject") +def reject_approval(approval_id: UUID, request: ResolveApprovalRequest) -> JSONResponse: + settings = get_settings() + resolution_error: ( + ApprovalResolutionConflictError | TaskStepApprovalLinkageError | TaskStepLifecycleBoundaryError | None + ) = None + + try: + with user_connection(settings.database_url, request.user_id) as conn: + try: + payload = reject_approval_record( + ContinuityStore(conn), + user_id=request.user_id, + request=ApprovalRejectInput(approval_id=approval_id), + ) + except ( + ApprovalResolutionConflictError, + TaskStepApprovalLinkageError, + TaskStepLifecycleBoundaryError, + ) as exc: + resolution_error = exc + payload = None + except ApprovalNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + + if resolution_error is not None: + return JSONResponse(status_code=409, content={"detail": str(resolution_error)}) + + 
return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.post("/v0/approvals/{approval_id}/execute") +def execute_approved_proxy(approval_id: UUID, request: ExecuteApprovedProxyRequest) -> JSONResponse: + settings = get_settings() + execution_error: ( + ProxyExecutionApprovalStateError + | ProxyExecutionHandlerNotFoundError + | TaskStepApprovalLinkageError + | TaskStepExecutionLinkageError + | None + ) = None + + try: + with user_connection(settings.database_url, request.user_id) as conn: + try: + payload = execute_approved_proxy_request( + ContinuityStore(conn), + user_id=request.user_id, + request=ProxyExecutionRequestInput(approval_id=approval_id), + ) + except ( + ProxyExecutionApprovalStateError, + ProxyExecutionHandlerNotFoundError, + TaskStepApprovalLinkageError, + TaskStepExecutionLinkageError, + ) as exc: + execution_error = exc + payload = None + except ApprovalNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + + if execution_error is not None: + return JSONResponse(status_code=409, content={"detail": str(execution_error)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/tasks") +def list_tasks(user_id: UUID) -> JSONResponse: + settings = get_settings() + + with user_connection(settings.database_url, user_id) as conn: + payload = list_task_records( + ContinuityStore(conn), + user_id=user_id, + ) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/tasks/{task_id}") +def get_task(task_id: UUID, user_id: UUID) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, user_id) as conn: + payload = get_task_record( + ContinuityStore(conn), + user_id=user_id, + task_id=task_id, + ) + except TaskNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.post("/v0/tasks/{task_id}/workspace") +def create_task_workspace(task_id: UUID, request: CreateTaskWorkspaceRequest) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, request.user_id) as conn: + payload = create_task_workspace_record( + ContinuityStore(conn), + settings=settings, + user_id=request.user_id, + request=TaskWorkspaceCreateInput( + task_id=task_id, + status="active", + ), + ) + except TaskNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + except (TaskWorkspaceAlreadyExistsError, TaskWorkspaceProvisioningError) as exc: + return JSONResponse(status_code=409, content={"detail": str(exc)}) + + return JSONResponse( + status_code=201, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/task-workspaces") +def list_task_workspaces(user_id: UUID) -> JSONResponse: + settings = get_settings() + + with user_connection(settings.database_url, user_id) as conn: + payload = list_task_workspace_records( + ContinuityStore(conn), + user_id=user_id, + ) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/task-workspaces/{task_workspace_id}") +def get_task_workspace(task_workspace_id: UUID, user_id: UUID) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, user_id) as conn: + payload = get_task_workspace_record( + ContinuityStore(conn), + user_id=user_id, + task_workspace_id=task_workspace_id, + ) + 
except TaskWorkspaceNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/tasks/{task_id}/steps") +def list_task_steps(task_id: UUID, user_id: UUID) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, user_id) as conn: + payload = list_task_step_records( + ContinuityStore(conn), + user_id=user_id, + task_id=task_id, + ) + except TaskNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/task-steps/{task_step_id}") +def get_task_step(task_step_id: UUID, user_id: UUID) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, user_id) as conn: + payload = get_task_step_record( + ContinuityStore(conn), + user_id=user_id, + task_step_id=task_step_id, + ) + except TaskStepNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.post("/v0/tasks/{task_id}/steps") +def create_next_task_step(task_id: UUID, request: CreateNextTaskStepRequest) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, request.user_id) as conn: + payload = create_next_task_step_record( + ContinuityStore(conn), + user_id=request.user_id, + request=TaskStepNextCreateInput( + task_id=task_id, + kind=request.kind, + status=request.status, + request=request.request.model_dump(mode="json"), + outcome=request.outcome.model_dump(mode="json"), + lineage=TaskStepLineageInput( + parent_step_id=request.lineage.parent_step_id, + source_approval_id=request.lineage.source_approval_id, + source_execution_id=request.lineage.source_execution_id, + ), + ), + ) + except TaskNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + except TaskStepSequenceError as exc: + return JSONResponse(status_code=409, content={"detail": str(exc)}) + + return JSONResponse( + status_code=201, + content=jsonable_encoder(payload), + ) + + +@app.post("/v0/task-steps/{task_step_id}/transition") +def transition_task_step(task_step_id: UUID, request: TransitionTaskStepRequest) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, request.user_id) as conn: + payload = transition_task_step_record( + ContinuityStore(conn), + user_id=request.user_id, + request=TaskStepTransitionInput( + task_step_id=task_step_id, + status=request.status, + outcome=request.outcome.model_dump(mode="json"), + ), + ) + except TaskStepNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + except TaskStepTransitionError as exc: + return JSONResponse(status_code=409, content={"detail": str(exc)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.post("/v0/execution-budgets") +def create_execution_budget(request: CreateExecutionBudgetRequest) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, request.user_id) as conn: + payload = create_execution_budget_record( + ContinuityStore(conn), + user_id=request.user_id, + request=ExecutionBudgetCreateInput( + tool_key=request.tool_key, + domain_hint=request.domain_hint, + 
max_completed_executions=request.max_completed_executions, + rolling_window_seconds=request.rolling_window_seconds, + ), + ) + except ExecutionBudgetValidationError as exc: + return JSONResponse(status_code=400, content={"detail": str(exc)}) + + return JSONResponse( + status_code=201, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/execution-budgets") +def list_execution_budgets(user_id: UUID) -> JSONResponse: + settings = get_settings() + + with user_connection(settings.database_url, user_id) as conn: + payload = list_execution_budget_records( + ContinuityStore(conn), + user_id=user_id, + ) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/execution-budgets/{execution_budget_id}") +def get_execution_budget(execution_budget_id: UUID, user_id: UUID) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, user_id) as conn: + payload = get_execution_budget_record( + ContinuityStore(conn), + user_id=user_id, + execution_budget_id=execution_budget_id, + ) + except ExecutionBudgetNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.post("/v0/execution-budgets/{execution_budget_id}/deactivate") +def deactivate_execution_budget( + execution_budget_id: UUID, + request: DeactivateExecutionBudgetRequest, +) -> JSONResponse: + settings = get_settings() + lifecycle_error: ExecutionBudgetLifecycleError | None = None + + try: + with user_connection(settings.database_url, request.user_id) as conn: + try: + payload = deactivate_execution_budget_record( + ContinuityStore(conn), + user_id=request.user_id, + request=ExecutionBudgetDeactivateInput( + thread_id=request.thread_id, + execution_budget_id=execution_budget_id, + ), + ) + except ExecutionBudgetLifecycleError as exc: + lifecycle_error = exc + payload = None + except ExecutionBudgetValidationError as exc: + return JSONResponse(status_code=400, content={"detail": str(exc)}) + except ExecutionBudgetNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + + if lifecycle_error is not None: + return JSONResponse(status_code=409, content={"detail": str(lifecycle_error)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.post("/v0/execution-budgets/{execution_budget_id}/supersede") +def supersede_execution_budget( + execution_budget_id: UUID, + request: SupersedeExecutionBudgetRequest, +) -> JSONResponse: + settings = get_settings() + lifecycle_error: ExecutionBudgetLifecycleError | None = None + + try: + with user_connection(settings.database_url, request.user_id) as conn: + try: + payload = supersede_execution_budget_record( + ContinuityStore(conn), + user_id=request.user_id, + request=ExecutionBudgetSupersedeInput( + thread_id=request.thread_id, + execution_budget_id=execution_budget_id, + max_completed_executions=request.max_completed_executions, + ), + ) + except ExecutionBudgetLifecycleError as exc: + lifecycle_error = exc + payload = None + except ExecutionBudgetValidationError as exc: + return JSONResponse(status_code=400, content={"detail": str(exc)}) + except ExecutionBudgetNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + + if lifecycle_error is not None: + return JSONResponse(status_code=409, content={"detail": str(lifecycle_error)}) + + return JSONResponse( + status_code=200, + 
content=jsonable_encoder(payload), + ) + + +@app.get("/v0/tool-executions") +def list_tool_executions(user_id: UUID) -> JSONResponse: + settings = get_settings() + + with user_connection(settings.database_url, user_id) as conn: + payload = list_tool_execution_records( + ContinuityStore(conn), + user_id=user_id, + ) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/tool-executions/{execution_id}") +def get_tool_execution(execution_id: UUID, user_id: UUID) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, user_id) as conn: + payload = get_tool_execution_record( + ContinuityStore(conn), + user_id=user_id, + execution_id=execution_id, + ) + except ToolExecutionNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/tools/{tool_id}") +def get_tool(tool_id: UUID, user_id: UUID) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, user_id) as conn: + payload = get_tool_record( + ContinuityStore(conn), + user_id=user_id, + tool_id=tool_id, + ) + except ToolNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.post("/v0/memories/extract-explicit-preferences") +def extract_explicit_preferences(request: ExtractExplicitPreferencesRequest) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, request.user_id) as conn: + payload = extract_and_admit_explicit_preferences( + ContinuityStore(conn), + user_id=request.user_id, + request=ExplicitPreferenceExtractionRequestInput( + source_event_id=request.source_event_id, + ), + ) + except ExplicitPreferenceExtractionValidationError as exc: + return JSONResponse(status_code=400, content={"detail": str(exc)}) + except MemoryAdmissionValidationError as exc: + return JSONResponse(status_code=400, content={"detail": str(exc)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/memories") +def list_memories( + user_id: UUID, + status: MemoryReviewStatusFilter = Query(default="active"), + limit: int = Query(default=DEFAULT_MEMORY_REVIEW_LIMIT, ge=1, le=MAX_MEMORY_REVIEW_LIMIT), +) -> JSONResponse: + settings = get_settings() + + with user_connection(settings.database_url, user_id) as conn: + payload = list_memory_review_records( + ContinuityStore(conn), + user_id=user_id, + status=status, + limit=limit, + ) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/memories/review-queue") +def list_memory_review_queue( + user_id: UUID, + limit: int = Query(default=DEFAULT_MEMORY_REVIEW_LIMIT, ge=1, le=MAX_MEMORY_REVIEW_LIMIT), +) -> JSONResponse: + settings = get_settings() + + with user_connection(settings.database_url, user_id) as conn: + payload = list_memory_review_queue_records( + ContinuityStore(conn), + user_id=user_id, + limit=limit, + ) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/memories/evaluation-summary") +def get_memories_evaluation_summary(user_id: UUID) -> JSONResponse: + settings = get_settings() + + with user_connection(settings.database_url, user_id) as conn: + payload = get_memory_evaluation_summary( + ContinuityStore(conn), + 
user_id=user_id, + ) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.post("/v0/memories/semantic-retrieval") +def retrieve_semantic_memories(request: RetrieveSemanticMemoriesRequest) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, request.user_id) as conn: + payload = retrieve_semantic_memory_records( + ContinuityStore(conn), + user_id=request.user_id, + request=SemanticMemoryRetrievalRequestInput( + embedding_config_id=request.embedding_config_id, + query_vector=tuple(request.query_vector), + limit=request.limit, + ), + ) + except SemanticMemoryRetrievalValidationError as exc: + return JSONResponse(status_code=400, content={"detail": str(exc)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/memories/{memory_id}") +def get_memory( + memory_id: UUID, + user_id: UUID, +) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, user_id) as conn: + payload = get_memory_review_record( + ContinuityStore(conn), + user_id=user_id, + memory_id=memory_id, + ) + except MemoryReviewNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/memories/{memory_id}/revisions") +def list_memory_revisions( + memory_id: UUID, + user_id: UUID, + limit: int = Query(default=DEFAULT_MEMORY_REVIEW_LIMIT, ge=1, le=MAX_MEMORY_REVIEW_LIMIT), +) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, user_id) as conn: + payload = list_memory_revision_review_records( + ContinuityStore(conn), + user_id=user_id, + memory_id=memory_id, + limit=limit, + ) + except MemoryReviewNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.post("/v0/memories/{memory_id}/labels") +def create_memory_review_label( + memory_id: UUID, + request: CreateMemoryReviewLabelRequest, +) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, request.user_id) as conn: + payload = create_memory_review_label_record( + ContinuityStore(conn), + user_id=request.user_id, + memory_id=memory_id, + label=request.label, + note=request.note, + ) + except MemoryReviewNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + + return JSONResponse( + status_code=201, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/memories/{memory_id}/labels") +def list_memory_review_labels( + memory_id: UUID, + user_id: UUID, +) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, user_id) as conn: + payload = list_memory_review_label_records( + ContinuityStore(conn), + user_id=user_id, + memory_id=memory_id, + ) + except MemoryReviewNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.post("/v0/embedding-configs") +def create_embedding_config(request: CreateEmbeddingConfigRequest) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, request.user_id) as conn: + payload = create_embedding_config_record( + ContinuityStore(conn), + user_id=request.user_id, + 
config=EmbeddingConfigCreateInput( + provider=request.provider, + model=request.model, + version=request.version, + dimensions=request.dimensions, + status=request.status, + metadata=request.metadata, + ), + ) + except EmbeddingConfigValidationError as exc: + return JSONResponse(status_code=400, content={"detail": str(exc)}) + + return JSONResponse( + status_code=201, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/embedding-configs") +def list_embedding_configs(user_id: UUID) -> JSONResponse: + settings = get_settings() + + with user_connection(settings.database_url, user_id) as conn: + payload = list_embedding_config_records( + ContinuityStore(conn), + user_id=user_id, + ) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.post("/v0/memory-embeddings") +def upsert_memory_embedding(request: UpsertMemoryEmbeddingRequest) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, request.user_id) as conn: + payload = upsert_memory_embedding_record( + ContinuityStore(conn), + user_id=request.user_id, + request=MemoryEmbeddingUpsertInput( + memory_id=request.memory_id, + embedding_config_id=request.embedding_config_id, + vector=tuple(request.vector), + ), + ) + except MemoryEmbeddingValidationError as exc: + return JSONResponse(status_code=400, content={"detail": str(exc)}) + + return JSONResponse( + status_code=201, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/memories/{memory_id}/embeddings") +def list_memory_embeddings(memory_id: UUID, user_id: UUID) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, user_id) as conn: + payload = list_memory_embedding_records( + ContinuityStore(conn), + user_id=user_id, + memory_id=memory_id, + ) + except MemoryEmbeddingNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/memory-embeddings/{memory_embedding_id}") +def get_memory_embedding(memory_embedding_id: UUID, user_id: UUID) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, user_id) as conn: + payload = get_memory_embedding_record( + ContinuityStore(conn), + user_id=user_id, + memory_embedding_id=memory_embedding_id, + ) + except MemoryEmbeddingNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.post("/v0/entities") +def create_entity(request: CreateEntityRequest) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, request.user_id) as conn: + payload = create_entity_record( + ContinuityStore(conn), + user_id=request.user_id, + entity=EntityCreateInput( + entity_type=request.entity_type, + name=request.name, + source_memory_ids=tuple(request.source_memory_ids), + ), + ) + except EntityValidationError as exc: + return JSONResponse(status_code=400, content={"detail": str(exc)}) + + return JSONResponse( + status_code=201, + content=jsonable_encoder(payload), + ) + + +@app.post("/v0/entity-edges") +def create_entity_edge(request: CreateEntityEdgeRequest) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, request.user_id) as conn: + payload = create_entity_edge_record( + ContinuityStore(conn), + user_id=request.user_id, + 
edge=EntityEdgeCreateInput( + from_entity_id=request.from_entity_id, + to_entity_id=request.to_entity_id, + relationship_type=request.relationship_type, + valid_from=request.valid_from, + valid_to=request.valid_to, + source_memory_ids=tuple(request.source_memory_ids), + ), + ) + except EntityEdgeValidationError as exc: + return JSONResponse(status_code=400, content={"detail": str(exc)}) + + return JSONResponse( + status_code=201, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/entities") +def list_entities(user_id: UUID) -> JSONResponse: + settings = get_settings() + + with user_connection(settings.database_url, user_id) as conn: + payload = list_entity_records( + ContinuityStore(conn), + user_id=user_id, + ) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/entities/{entity_id}/edges") +def list_entity_edges(entity_id: UUID, user_id: UUID) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, user_id) as conn: + payload = list_entity_edge_records( + ContinuityStore(conn), + user_id=user_id, + entity_id=entity_id, + ) + except EntityNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/entities/{entity_id}") +def get_entity(entity_id: UUID, user_id: UUID) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, user_id) as conn: + payload = get_entity_record( + ContinuityStore(conn), + user_id=user_id, + entity_id=entity_id, + ) + except EntityNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) diff --git a/apps/api/src/alicebot_api/memory.py b/apps/api/src/alicebot_api/memory.py new file mode 100644 index 0000000..3c5ebc3 --- /dev/null +++ b/apps/api/src/alicebot_api/memory.py @@ -0,0 +1,483 @@ +from __future__ import annotations + +from uuid import UUID + +from alicebot_api.contracts import ( + AdmissionDecisionOutput, + DEFAULT_MEMORY_REVIEW_LIMIT, + MEMORY_REVIEW_LABEL_ORDER, + MEMORY_REVIEW_LABEL_VALUES, + MEMORY_REVIEW_QUEUE_ORDER, + MEMORY_REVISION_REVIEW_ORDER, + MEMORY_REVIEW_ORDER, + MemoryCandidateInput, + MemoryEvaluationSummary, + MemoryEvaluationSummaryResponse, + MemoryReviewLabelCounts, + MemoryReviewLabelCreateResponse, + MemoryReviewLabelListResponse, + MemoryReviewLabelRecord, + MemoryReviewLabelSummary, + MemoryReviewLabelValue, + MemoryReviewQueueItem, + MemoryReviewQueueResponse, + MemoryReviewQueueSummary, + MemoryRevisionReviewListResponse, + MemoryRevisionReviewListSummary, + MemoryRevisionReviewRecord, + MemoryReviewDetailResponse, + MemoryReviewListResponse, + MemoryReviewListSummary, + MemoryReviewRecord, + MemoryReviewStatusFilter, + PersistedMemoryRecord, + PersistedMemoryRevisionRecord, + isoformat_or_none, +) +from alicebot_api.store import ContinuityStore, JsonObject, LabelCountRow, MemoryReviewLabelRow, MemoryRevisionRow, MemoryRow + + +class MemoryAdmissionValidationError(ValueError): + """Raised when an admission request fails explicit candidate validation.""" + + +class MemoryReviewNotFoundError(LookupError): + """Raised when a requested memory is not visible inside the current user scope.""" + + +def _serialize_memory(memory: MemoryRow) -> PersistedMemoryRecord: + return { + "id": str(memory["id"]), + "user_id": str(memory["user_id"]), + 
"memory_key": memory["memory_key"], + "value": memory["value"], + "status": memory["status"], + "source_event_ids": memory["source_event_ids"], + "created_at": memory["created_at"].isoformat(), + "updated_at": memory["updated_at"].isoformat(), + "deleted_at": isoformat_or_none(memory["deleted_at"]), + } + + +def _serialize_memory_revision(revision: MemoryRevisionRow) -> PersistedMemoryRevisionRecord: + return { + "id": str(revision["id"]), + "user_id": str(revision["user_id"]), + "memory_id": str(revision["memory_id"]), + "sequence_no": revision["sequence_no"], + "action": revision["action"], + "memory_key": revision["memory_key"], + "previous_value": revision["previous_value"], + "new_value": revision["new_value"], + "source_event_ids": revision["source_event_ids"], + "candidate": revision["candidate"], + "created_at": revision["created_at"].isoformat(), + } + + +def _serialize_memory_review(memory: MemoryRow) -> MemoryReviewRecord: + return { + "id": str(memory["id"]), + "memory_key": memory["memory_key"], + "value": memory["value"], + "status": memory["status"], + "source_event_ids": memory["source_event_ids"], + "created_at": memory["created_at"].isoformat(), + "updated_at": memory["updated_at"].isoformat(), + "deleted_at": isoformat_or_none(memory["deleted_at"]), + } + + +def _serialize_memory_review_queue_item(memory: MemoryRow) -> MemoryReviewQueueItem: + return { + "id": str(memory["id"]), + "memory_key": memory["memory_key"], + "value": memory["value"], + "status": memory["status"], + "source_event_ids": memory["source_event_ids"], + "created_at": memory["created_at"].isoformat(), + "updated_at": memory["updated_at"].isoformat(), + } + + +def _serialize_memory_revision_review(revision: MemoryRevisionRow) -> MemoryRevisionReviewRecord: + return { + "id": str(revision["id"]), + "memory_id": str(revision["memory_id"]), + "sequence_no": revision["sequence_no"], + "action": revision["action"], + "memory_key": revision["memory_key"], + "previous_value": revision["previous_value"], + "new_value": revision["new_value"], + "source_event_ids": revision["source_event_ids"], + "created_at": revision["created_at"].isoformat(), + } + + +def _serialize_memory_review_label(label: MemoryReviewLabelRow) -> MemoryReviewLabelRecord: + return { + "id": str(label["id"]), + "memory_id": str(label["memory_id"]), + "reviewer_user_id": str(label["user_id"]), + "label": label["label"], + "note": label["note"], + "created_at": label["created_at"].isoformat(), + } + + +def _empty_memory_review_label_counts() -> MemoryReviewLabelCounts: + return { + "correct": 0, + "incorrect": 0, + "outdated": 0, + "insufficient_evidence": 0, + } + + +def _summarize_memory_review_label_counts(rows: list[LabelCountRow]) -> MemoryReviewLabelCounts: + counts = _empty_memory_review_label_counts() + for row in rows: + label = row["label"] + if label in counts: + counts[label] = row["count"] + return counts + + +def _build_memory_review_label_summary( + *, + memory_id: UUID, + counts: MemoryReviewLabelCounts, +) -> MemoryReviewLabelSummary: + return { + "memory_id": str(memory_id), + "total_count": sum(counts.values()), + "counts_by_label": counts, + "order": list(MEMORY_REVIEW_LABEL_ORDER), + } + + +def _normalize_memory_status_filter(status: MemoryReviewStatusFilter) -> str | None: + if status == "all": + return None + return status + + +def list_memory_review_records( + store: ContinuityStore, + *, + user_id: UUID, + status: MemoryReviewStatusFilter = "active", + limit: int = DEFAULT_MEMORY_REVIEW_LIMIT, +) -> 
MemoryReviewListResponse: + del user_id + + normalized_status = _normalize_memory_status_filter(status) + total_count = store.count_memories(status=normalized_status) + memories = store.list_review_memories(status=normalized_status, limit=limit) + items = [_serialize_memory_review(memory) for memory in memories] + summary: MemoryReviewListSummary = { + "status": status, + "limit": limit, + "returned_count": len(items), + "total_count": total_count, + "has_more": len(items) < total_count, + "order": list(MEMORY_REVIEW_ORDER), + } + return { + "items": items, + "summary": summary, + } + + +def list_memory_review_queue_records( + store: ContinuityStore, + *, + user_id: UUID, + limit: int = DEFAULT_MEMORY_REVIEW_LIMIT, +) -> MemoryReviewQueueResponse: + del user_id + + total_count = store.count_unlabeled_review_memories() + memories = store.list_unlabeled_review_memories(limit=limit) + items = [_serialize_memory_review_queue_item(memory) for memory in memories] + summary: MemoryReviewQueueSummary = { + "memory_status": "active", + "review_state": "unlabeled", + "limit": limit, + "returned_count": len(items), + "total_count": total_count, + "has_more": len(items) < total_count, + "order": list(MEMORY_REVIEW_QUEUE_ORDER), + } + return { + "items": items, + "summary": summary, + } + + +def get_memory_review_record( + store: ContinuityStore, + *, + user_id: UUID, + memory_id: UUID, +) -> MemoryReviewDetailResponse: + del user_id + + memory = store.get_memory_optional(memory_id) + if memory is None: + raise MemoryReviewNotFoundError(f"memory {memory_id} was not found") + + return { + "memory": _serialize_memory_review(memory), + } + + +def list_memory_revision_review_records( + store: ContinuityStore, + *, + user_id: UUID, + memory_id: UUID, + limit: int = DEFAULT_MEMORY_REVIEW_LIMIT, +) -> MemoryRevisionReviewListResponse: + del user_id + + memory = store.get_memory_optional(memory_id) + if memory is None: + raise MemoryReviewNotFoundError(f"memory {memory_id} was not found") + + total_count = store.count_memory_revisions(memory_id) + revisions = store.list_memory_revisions(memory_id, limit=limit) + items = [_serialize_memory_revision_review(revision) for revision in revisions] + summary: MemoryRevisionReviewListSummary = { + "memory_id": str(memory["id"]), + "limit": limit, + "returned_count": len(items), + "total_count": total_count, + "has_more": len(items) < total_count, + "order": list(MEMORY_REVISION_REVIEW_ORDER), + } + return { + "items": items, + "summary": summary, + } + + +def create_memory_review_label_record( + store: ContinuityStore, + *, + user_id: UUID, + memory_id: UUID, + label: MemoryReviewLabelValue, + note: str | None, +) -> MemoryReviewLabelCreateResponse: + del user_id + + memory = store.get_memory_optional(memory_id) + if memory is None: + raise MemoryReviewNotFoundError(f"memory {memory_id} was not found") + + created_label = store.create_memory_review_label( + memory_id=memory_id, + label=label, + note=note, + ) + counts = _summarize_memory_review_label_counts(store.list_memory_review_label_counts(memory_id)) + return { + "label": _serialize_memory_review_label(created_label), + "summary": _build_memory_review_label_summary(memory_id=memory_id, counts=counts), + } + + +def list_memory_review_label_records( + store: ContinuityStore, + *, + user_id: UUID, + memory_id: UUID, +) -> MemoryReviewLabelListResponse: + del user_id + + memory = store.get_memory_optional(memory_id) + if memory is None: + raise MemoryReviewNotFoundError(f"memory {memory_id} was not found") + + items 
= [_serialize_memory_review_label(label) for label in store.list_memory_review_labels(memory_id)] + counts = _summarize_memory_review_label_counts(store.list_memory_review_label_counts(memory_id)) + return { + "items": items, + "summary": _build_memory_review_label_summary(memory_id=memory_id, counts=counts), + } + + +def get_memory_evaluation_summary( + store: ContinuityStore, + *, + user_id: UUID, +) -> MemoryEvaluationSummaryResponse: + del user_id + + total_memory_count = store.count_memories() + active_memory_count = store.count_memories(status="active") + deleted_memory_count = store.count_memories(status="deleted") + labeled_memory_count = store.count_labeled_memories() + unlabeled_memory_count = store.count_unlabeled_memories() + label_row_counts = _summarize_memory_review_label_counts(store.list_all_memory_review_label_counts()) + summary: MemoryEvaluationSummary = { + "total_memory_count": total_memory_count, + "active_memory_count": active_memory_count, + "deleted_memory_count": deleted_memory_count, + "labeled_memory_count": labeled_memory_count, + "unlabeled_memory_count": unlabeled_memory_count, + "total_label_row_count": sum(label_row_counts.values()), + "label_row_counts_by_value": label_row_counts, + "label_value_order": list(MEMORY_REVIEW_LABEL_VALUES), + } + return { + "summary": summary, + } + + +def _dedupe_source_event_ids(source_event_ids: tuple[UUID, ...]) -> tuple[UUID, ...]: + deduped: list[UUID] = [] + seen: set[UUID] = set() + for source_event_id in source_event_ids: + if source_event_id in seen: + continue + seen.add(source_event_id) + deduped.append(source_event_id) + return tuple(deduped) + + +def _validate_source_events(store: ContinuityStore, source_event_ids: tuple[UUID, ...]) -> list[str]: + normalized_event_ids = _dedupe_source_event_ids(source_event_ids) + if not normalized_event_ids: + raise MemoryAdmissionValidationError( + "source_event_ids must include at least one existing event owned by the user" + ) + source_events = store.list_events_by_ids(list(normalized_event_ids)) + found_event_ids = {event["id"] for event in source_events} + missing_event_ids = [ + str(source_event_id) + for source_event_id in normalized_event_ids + if source_event_id not in found_event_ids + ] + if missing_event_ids: + raise MemoryAdmissionValidationError( + "source_event_ids must all reference existing events owned by the user: " + + ", ".join(missing_event_ids) + ) + return [str(source_event_id) for source_event_id in normalized_event_ids] + + +def _candidate_payload(candidate: MemoryCandidateInput) -> JsonObject: + return candidate.as_payload() + + +def admit_memory_candidate( + store: ContinuityStore, + *, + user_id: UUID, + candidate: MemoryCandidateInput, +) -> AdmissionDecisionOutput: + del user_id + + source_event_ids = _validate_source_events(store, candidate.source_event_ids) + existing_memory = store.get_memory_by_key(candidate.memory_key) + + noop_decision = AdmissionDecisionOutput( + action="NOOP", + reason="candidate_default_noop", + memory=None, + revision=None, + ) + + if candidate.delete_requested: + if existing_memory is None or existing_memory["status"] == "deleted": + return AdmissionDecisionOutput( + action=noop_decision.action, + reason="memory_not_found_for_delete", + memory=None if existing_memory is None else _serialize_memory(existing_memory), + revision=None, + ) + + memory = store.update_memory( + memory_id=existing_memory["id"], + value=existing_memory["value"], + status="deleted", + source_event_ids=source_event_ids, + ) + revision = 
store.append_memory_revision( + memory_id=memory["id"], + action="DELETE", + memory_key=memory["memory_key"], + previous_value=existing_memory["value"], + new_value=None, + source_event_ids=source_event_ids, + candidate=_candidate_payload(candidate), + ) + return AdmissionDecisionOutput( + action="DELETE", + reason="source_backed_delete", + memory=_serialize_memory(memory), + revision=_serialize_memory_revision(revision), + ) + + if candidate.value is None: + return AdmissionDecisionOutput( + action=noop_decision.action, + reason="candidate_value_missing", + memory=None if existing_memory is None else _serialize_memory(existing_memory), + revision=None, + ) + + if existing_memory is None: + memory = store.create_memory( + memory_key=candidate.memory_key, + value=candidate.value, + status="active", + source_event_ids=source_event_ids, + ) + revision = store.append_memory_revision( + memory_id=memory["id"], + action="ADD", + memory_key=memory["memory_key"], + previous_value=None, + new_value=candidate.value, + source_event_ids=source_event_ids, + candidate=_candidate_payload(candidate), + ) + return AdmissionDecisionOutput( + action="ADD", + reason="source_backed_add", + memory=_serialize_memory(memory), + revision=_serialize_memory_revision(revision), + ) + + if existing_memory["status"] == "active" and existing_memory["value"] == candidate.value: + return AdmissionDecisionOutput( + action=noop_decision.action, + reason="memory_unchanged", + memory=_serialize_memory(existing_memory), + revision=None, + ) + + memory = store.update_memory( + memory_id=existing_memory["id"], + value=candidate.value, + status="active", + source_event_ids=source_event_ids, + ) + revision = store.append_memory_revision( + memory_id=memory["id"], + action="UPDATE", + memory_key=memory["memory_key"], + previous_value=existing_memory["value"], + new_value=candidate.value, + source_event_ids=source_event_ids, + candidate=_candidate_payload(candidate), + ) + return AdmissionDecisionOutput( + action="UPDATE", + reason="source_backed_update", + memory=_serialize_memory(memory), + revision=_serialize_memory_revision(revision), + ) diff --git a/apps/api/src/alicebot_api/migrations.py b/apps/api/src/alicebot_api/migrations.py new file mode 100644 index 0000000..52a5d15 --- /dev/null +++ b/apps/api/src/alicebot_api/migrations.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +from pathlib import Path + +from alembic.config import Config + + +PROJECT_ROOT = Path(__file__).resolve().parents[4] +ALEMBIC_INI_PATH = PROJECT_ROOT / "apps" / "api" / "alembic.ini" + + +def make_alembic_config(database_url: str | None = None) -> Config: + config = Config(str(ALEMBIC_INI_PATH)) + if database_url: + config.set_main_option("sqlalchemy.url", database_url) + return config + diff --git a/apps/api/src/alicebot_api/policy.py b/apps/api/src/alicebot_api/policy.py new file mode 100644 index 0000000..68988c4 --- /dev/null +++ b/apps/api/src/alicebot_api/policy.py @@ -0,0 +1,421 @@ +from __future__ import annotations + +from dataclasses import dataclass +from uuid import UUID + +from alicebot_api.contracts import ( + CONSENT_LIST_ORDER, + POLICY_EVALUATION_VERSION_V0, + POLICY_LIST_ORDER, + TRACE_KIND_POLICY_EVALUATE, + ConsentListResponse, + ConsentListSummary, + ConsentRecord, + ConsentUpsertInput, + ConsentUpsertResponse, + PolicyCreateInput, + PolicyCreateResponse, + PolicyDetailResponse, + PolicyEvaluationReason, + PolicyEvaluationRequestInput, + PolicyEvaluationResponse, + PolicyEvaluationSummary, + PolicyEvaluationTraceSummary, + 
PolicyListResponse, + PolicyListSummary, + PolicyRecord, + isoformat_or_none, +) +from alicebot_api.store import ConsentRow, ContinuityStore, PolicyRow + + +class PolicyValidationError(ValueError): + """Raised when a policy or consent request fails explicit validation.""" + + +class PolicyNotFoundError(LookupError): + """Raised when a requested policy is not visible inside the current user scope.""" + + +class PolicyEvaluationValidationError(ValueError): + """Raised when a policy-evaluation request fails explicit validation.""" + + +@dataclass(frozen=True, slots=True) +class PolicyEvaluationContext: + active_policies: tuple[PolicyRow, ...] + consents_by_key: dict[str, ConsentRow] + + +@dataclass(frozen=True, slots=True) +class PolicyEvaluationCoreDecision: + decision: str + matched_policy: PolicyRow | None + reasons: list[PolicyEvaluationReason] + + +def _serialize_consent(consent: ConsentRow) -> ConsentRecord: + return { + "id": str(consent["id"]), + "consent_key": consent["consent_key"], + "status": consent["status"], + "metadata": consent["metadata"], + "created_at": consent["created_at"].isoformat(), + "updated_at": consent["updated_at"].isoformat(), + } + + +def _serialize_policy(policy: PolicyRow) -> PolicyRecord: + return { + "id": str(policy["id"]), + "name": policy["name"], + "action": policy["action"], + "scope": policy["scope"], + "effect": policy["effect"], + "priority": policy["priority"], + "active": policy["active"], + "conditions": policy["conditions"], + "required_consents": policy["required_consents"], + "created_at": policy["created_at"].isoformat(), + "updated_at": policy["updated_at"].isoformat(), + } + + +def _dedupe_required_consents(required_consents: tuple[str, ...]) -> list[str]: + deduped: list[str] = [] + seen: set[str] = set() + for consent_key in required_consents: + if consent_key in seen: + continue + seen.add(consent_key) + deduped.append(consent_key) + return deduped + + +def _policy_matches(policy: PolicyRow, request: PolicyEvaluationRequestInput) -> bool: + if policy["action"] != request.action or policy["scope"] != request.scope: + return False + + conditions = policy["conditions"] + for key, expected_value in conditions.items(): + if key not in request.attributes: + return False + if request.attributes[key] != expected_value: + return False + + return True + + +def _build_reason( + *, + code: str, + source: str, + message: str, + policy_id: UUID | None = None, + consent_key: str | None = None, +) -> PolicyEvaluationReason: + return { + "code": code, + "source": source, + "message": message, + "policy_id": None if policy_id is None else str(policy_id), + "consent_key": consent_key, + } + + +def upsert_consent_record( + store: ContinuityStore, + *, + user_id: UUID, + consent: ConsentUpsertInput, +) -> ConsentUpsertResponse: + del user_id + + existing = store.get_consent_by_key_optional(consent.consent_key) + if existing is None: + created = store.create_consent( + consent_key=consent.consent_key, + status=consent.status, + metadata=consent.metadata, + ) + return { + "consent": _serialize_consent(created), + "write_mode": "created", + } + + updated = store.update_consent( + consent_id=existing["id"], + status=consent.status, + metadata=consent.metadata, + ) + return { + "consent": _serialize_consent(updated), + "write_mode": "updated", + } + + +def list_consent_records( + store: ContinuityStore, + *, + user_id: UUID, +) -> ConsentListResponse: + del user_id + + items = [_serialize_consent(consent) for consent in store.list_consents()] + summary: 
ConsentListSummary = { + "total_count": len(items), + "order": list(CONSENT_LIST_ORDER), + } + return { + "items": items, + "summary": summary, + } + + +def create_policy_record( + store: ContinuityStore, + *, + user_id: UUID, + policy: PolicyCreateInput, +) -> PolicyCreateResponse: + del user_id + + required_consents = _dedupe_required_consents(policy.required_consents) + created = store.create_policy( + name=policy.name, + action=policy.action, + scope=policy.scope, + effect=policy.effect, + priority=policy.priority, + active=policy.active, + conditions=policy.conditions, + required_consents=required_consents, + ) + return {"policy": _serialize_policy(created)} + + +def list_policy_records( + store: ContinuityStore, + *, + user_id: UUID, +) -> PolicyListResponse: + del user_id + + items = [_serialize_policy(policy) for policy in store.list_policies()] + summary: PolicyListSummary = { + "total_count": len(items), + "order": list(POLICY_LIST_ORDER), + } + return { + "items": items, + "summary": summary, + } + + +def get_policy_record( + store: ContinuityStore, + *, + user_id: UUID, + policy_id: UUID, +) -> PolicyDetailResponse: + del user_id + + policy = store.get_policy_optional(policy_id) + if policy is None: + raise PolicyNotFoundError(f"policy {policy_id} was not found") + + return {"policy": _serialize_policy(policy)} + + +def load_policy_evaluation_context(store: ContinuityStore) -> PolicyEvaluationContext: + return PolicyEvaluationContext( + active_policies=tuple(store.list_active_policies()), + consents_by_key={consent["consent_key"]: consent for consent in store.list_consents()}, + ) + + +def evaluate_policy_against_context( + context: PolicyEvaluationContext, + *, + request: PolicyEvaluationRequestInput, +) -> PolicyEvaluationCoreDecision: + matched_policy = next( + (policy for policy in context.active_policies if _policy_matches(policy, request)), + None, + ) + + reasons: list[PolicyEvaluationReason] = [] + decision = "deny" + + if matched_policy is None: + reasons.append( + _build_reason( + code="no_matching_policy", + source="system", + message="No active policy matched the requested action, scope, and attributes.", + ) + ) + return PolicyEvaluationCoreDecision( + decision=decision, + matched_policy=None, + reasons=reasons, + ) + + reasons.append( + _build_reason( + code="matched_policy", + source="policy", + message=f"Matched policy '{matched_policy['name']}' at priority {matched_policy['priority']}.", + policy_id=matched_policy["id"], + ) + ) + + missing_or_revoked = False + for consent_key in matched_policy["required_consents"]: + consent = context.consents_by_key.get(consent_key) + if consent is None: + missing_or_revoked = True + reasons.append( + _build_reason( + code="consent_missing", + source="consent", + message=f"Required consent '{consent_key}' is missing.", + policy_id=matched_policy["id"], + consent_key=consent_key, + ) + ) + continue + if consent["status"] != "granted": + missing_or_revoked = True + reasons.append( + _build_reason( + code="consent_revoked", + source="consent", + message=f"Required consent '{consent_key}' is not granted (status={consent['status']}).", + policy_id=matched_policy["id"], + consent_key=consent_key, + ) + ) + + if not missing_or_revoked: + decision = matched_policy["effect"] + effect_code = { + "allow": "policy_effect_allow", + "deny": "policy_effect_deny", + "require_approval": "policy_effect_require_approval", + }[decision] + reasons.append( + _build_reason( + code=effect_code, + source="policy", + message=f"Policy effect resolved 
the decision to '{decision}'.", + policy_id=matched_policy["id"], + ) + ) + + return PolicyEvaluationCoreDecision( + decision=decision, + matched_policy=matched_policy, + reasons=reasons, + ) + + +def evaluate_policy_request( + store: ContinuityStore, + *, + user_id: UUID, + request: PolicyEvaluationRequestInput, +) -> PolicyEvaluationResponse: + del user_id + + thread = store.get_thread_optional(request.thread_id) + if thread is None: + raise PolicyEvaluationValidationError( + "thread_id must reference an existing thread owned by the user" + ) + + context = load_policy_evaluation_context(store) + core_decision = evaluate_policy_against_context( + context, + request=request, + ) + + trace = store.create_trace( + user_id=thread["user_id"], + thread_id=thread["id"], + kind=TRACE_KIND_POLICY_EVALUATE, + compiler_version=POLICY_EVALUATION_VERSION_V0, + status="completed", + limits={ + "order": list(POLICY_LIST_ORDER), + "active_policy_count": len(context.active_policies), + "consent_count": len(context.consents_by_key), + }, + ) + + trace_events = [ + ( + "policy.evaluate.request", + { + "thread_id": str(request.thread_id), + "action": request.action, + "scope": request.scope, + "attributes": request.attributes, + }, + ), + ( + "policy.evaluate.order", + { + "order": list(POLICY_LIST_ORDER), + "policy_ids": [str(policy["id"]) for policy in context.active_policies], + }, + ), + ( + "policy.evaluate.decision", + { + "decision": core_decision.decision, + "matched_policy_id": ( + None if core_decision.matched_policy is None else str(core_decision.matched_policy["id"]) + ), + "reasons": core_decision.reasons, + "evaluated_policy_count": len(context.active_policies), + "consent_states": { + consent_key: { + "status": consent["status"], + "updated_at": isoformat_or_none(consent["updated_at"]), + } + for consent_key, consent in context.consents_by_key.items() + }, + }, + ), + ] + for sequence_no, (kind, payload) in enumerate(trace_events, start=1): + store.append_trace_event( + trace_id=trace["id"], + sequence_no=sequence_no, + kind=kind, + payload=payload, + ) + + evaluation: PolicyEvaluationSummary = { + "action": request.action, + "scope": request.scope, + "evaluated_policy_count": len(context.active_policies), + "matched_policy_id": ( + None if core_decision.matched_policy is None else str(core_decision.matched_policy["id"]) + ), + "order": list(POLICY_LIST_ORDER), + } + trace_summary: PolicyEvaluationTraceSummary = { + "trace_id": str(trace["id"]), + "trace_event_count": len(trace_events), + } + return { + "decision": core_decision.decision, + "matched_policy": ( + None if core_decision.matched_policy is None else _serialize_policy(core_decision.matched_policy) + ), + "reasons": core_decision.reasons, + "evaluation": evaluation, + "trace": trace_summary, + } diff --git a/apps/api/src/alicebot_api/proxy_execution.py b/apps/api/src/alicebot_api/proxy_execution.py new file mode 100644 index 0000000..ce0a54c --- /dev/null +++ b/apps/api/src/alicebot_api/proxy_execution.py @@ -0,0 +1,557 @@ +from __future__ import annotations + +from collections.abc import Callable +from typing import cast +from uuid import UUID + +from alicebot_api.approvals import ApprovalNotFoundError, serialize_approval_row +from alicebot_api.contracts import ( + PROXY_EXECUTION_VERSION_V0, + EXECUTION_BUDGET_MATCH_ORDER, + TRACE_KIND_PROXY_EXECUTE, + ApprovalRecord, + ProxyExecutionApprovalTracePayload, + ProxyExecutionBudgetPrecheckTracePayload, + ProxyExecutionDispatchTracePayload, + ProxyExecutionEventSummary, + 
ProxyExecutionRequestEventPayload, + ProxyExecutionRequestInput, + ProxyExecutionResponse, + ProxyExecutionResultEventPayload, + ProxyExecutionResultRecord, + ProxyExecutionStatus, + ProxyExecutionSummaryTracePayload, + ProxyExecutionTraceSummary, + ToolRecord, + ToolExecutionCreateInput, + ToolExecutionResultRecord, + ToolRoutingRequestRecord, +) +from alicebot_api.execution_budgets import evaluate_execution_budget +from alicebot_api.store import ContinuityStore, JsonObject, ToolExecutionRow +from alicebot_api.tasks import ( + validate_linked_task_step_for_approval, + sync_task_step_with_execution, + sync_task_with_execution, + task_lifecycle_trace_events, + task_step_lifecycle_trace_events, +) + +PROXY_EXECUTION_REQUEST_EVENT_KIND = "tool.proxy.execution.request" +PROXY_EXECUTION_RESULT_EVENT_KIND = "tool.proxy.execution.result" + + +class ProxyExecutionApprovalStateError(RuntimeError): + """Raised when an approval is visible but not executable in its current state.""" + + +class ProxyExecutionHandlerNotFoundError(RuntimeError): + """Raised when an approved tool has no registered proxy handler.""" + + +ProxyHandler = Callable[[ToolRoutingRequestRecord, ToolRecord], ProxyExecutionResultRecord] + + +def _append_trace_events( + store: ContinuityStore, + *, + trace_id: UUID, + trace_events: list[tuple[str, dict[str, object]]], +) -> None: + for sequence_no, (kind, payload) in enumerate(trace_events, start=1): + store.append_trace_event( + trace_id=trace_id, + sequence_no=sequence_no, + kind=kind, + payload=payload, + ) + + +def _proxy_echo_handler( + request: ToolRoutingRequestRecord, + tool: ToolRecord, +) -> ProxyExecutionResultRecord: + output: JsonObject = { + "mode": "no_side_effect", + "tool_key": tool["tool_key"], + "action": request["action"], + "scope": request["scope"], + "domain_hint": request["domain_hint"], + "risk_hint": request["risk_hint"], + "attributes": request["attributes"], + } + return { + "handler_key": "proxy.echo", + "status": "completed", + "output": output, + } + + +REGISTERED_PROXY_HANDLERS: dict[str, ProxyHandler] = { + "proxy.echo": _proxy_echo_handler, +} + + +def registered_proxy_handler_keys() -> tuple[str, ...]: + return tuple(sorted(REGISTERED_PROXY_HANDLERS)) + + +def _trace_summary(trace_id: UUID, trace_events: list[tuple[str, dict[str, object]]]) -> ProxyExecutionTraceSummary: + return { + "trace_id": str(trace_id), + "trace_event_count": len(trace_events), + } + + +def _blocked_state_error(*, approval: ApprovalRecord) -> ProxyExecutionApprovalStateError: + return ProxyExecutionApprovalStateError( + f"approval {approval['id']} is {approval['status']} and cannot be executed" + ) + + +def _missing_handler_error(*, tool: ToolRecord) -> ProxyExecutionHandlerNotFoundError: + return ProxyExecutionHandlerNotFoundError( + f"tool '{tool['tool_key']}' has no registered proxy handler" + ) + + +def _tool_execution_result( + *, + handler_key: str | None, + status: ProxyExecutionStatus, + output: JsonObject | None, + reason: str | None, + budget_decision: dict[str, object] | None = None, +) -> ToolExecutionResultRecord: + payload: ToolExecutionResultRecord = { + "handler_key": handler_key, + "status": status, + "output": output, + "reason": reason, + } + if budget_decision is not None: + payload["budget_decision"] = cast(dict[str, object], budget_decision) + return payload + + +def _persist_tool_execution( + store: ContinuityStore, + *, + approval_row: dict[str, object], + task_step_id: UUID, + trace_id: UUID, + handler_key: str | None, + request: 
ToolRoutingRequestRecord, + tool: ToolRecord, + result: ToolExecutionResultRecord, + request_event_id: UUID | None, + result_event_id: UUID | None, +) -> ToolExecutionRow: + execution = ToolExecutionCreateInput( + approval_id=cast(UUID, approval_row["id"]), + task_step_id=task_step_id, + thread_id=cast(UUID, approval_row["thread_id"]), + tool_id=cast(UUID, approval_row["tool_id"]), + trace_id=trace_id, + request_event_id=request_event_id, + result_event_id=result_event_id, + status=result["status"], + handler_key=handler_key, + request=request, + tool=tool, + result=result, + ) + return store.create_tool_execution( + approval_id=execution.approval_id, + task_step_id=execution.task_step_id, + thread_id=execution.thread_id, + tool_id=execution.tool_id, + trace_id=execution.trace_id, + request_event_id=execution.request_event_id, + result_event_id=execution.result_event_id, + status=execution.status, + handler_key=execution.handler_key, + request=cast(JsonObject, execution.request), + tool=cast(JsonObject, execution.tool), + result=cast(JsonObject, execution.result), + ) + + +def execute_approved_proxy_request( + store: ContinuityStore, + *, + user_id: UUID, + request: ProxyExecutionRequestInput, +) -> ProxyExecutionResponse: + del user_id + + approval_row = store.get_approval_optional(request.approval_id) + if approval_row is None: + raise ApprovalNotFoundError(f"approval {request.approval_id} was not found") + _, linked_task_step = validate_linked_task_step_for_approval( + store, + approval_id=request.approval_id, + task_step_id=cast(UUID | None, approval_row["task_step_id"]), + ) + + approval = serialize_approval_row(approval_row) + linked_task_step_id = cast(str, approval["task_step_id"]) + tool = cast(ToolRecord, approval["tool"]) + routed_request = cast(ToolRoutingRequestRecord, approval["request"]) + handler = REGISTERED_PROXY_HANDLERS.get(tool["tool_key"]) + + trace = store.create_trace( + user_id=approval_row["user_id"], + thread_id=approval_row["thread_id"], + kind=TRACE_KIND_PROXY_EXECUTE, + compiler_version=PROXY_EXECUTION_VERSION_V0, + status="completed", + limits={ + "approval_status": approval["status"], + "enabled_handler_keys": list(registered_proxy_handler_keys()), + "budget_match_order": list(EXECUTION_BUDGET_MATCH_ORDER), + }, + ) + + approval_trace_payload: ProxyExecutionApprovalTracePayload = { + "approval_id": approval["id"], + "task_step_id": linked_task_step_id, + "approval_status": approval["status"], + "eligible_for_execution": approval["status"] == "approved", + } + + trace_events: list[tuple[str, dict[str, object]]] = [ + ( + "tool.proxy.execute.request", + { + "approval_id": approval["id"], + "task_step_id": linked_task_step_id, + }, + ), + ("tool.proxy.execute.approval", cast(dict[str, object], approval_trace_payload)), + ] + + if approval["status"] != "approved": + error = _blocked_state_error(approval=approval) + dispatch_payload: ProxyExecutionDispatchTracePayload = { + "approval_id": approval["id"], + "task_step_id": linked_task_step_id, + "tool_id": tool["id"], + "tool_key": tool["tool_key"], + "handler_key": None, + "dispatch_status": "blocked", + "reason": str(error), + "result_status": None, + "output": None, + } + summary_payload: ProxyExecutionSummaryTracePayload = { + "approval_id": approval["id"], + "task_step_id": linked_task_step_id, + "tool_id": tool["id"], + "tool_key": tool["tool_key"], + "approval_status": approval["status"], + "execution_status": "blocked", + "handler_key": None, + "request_event_id": None, + "result_event_id": None, + } + 
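+        # Blocked attempts are still appended to the trace before the error is raised,
+        # so a refused dispatch remains auditable alongside approved executions.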
trace_events.extend( + [ + ("tool.proxy.execute.dispatch", cast(dict[str, object], dispatch_payload)), + ("tool.proxy.execute.summary", cast(dict[str, object], summary_payload)), + ] + ) + _append_trace_events(store, trace_id=trace["id"], trace_events=trace_events) + raise error + + budget_decision = evaluate_execution_budget( + store, + tool=tool, + request=routed_request, + ) + budget_trace_payload: ProxyExecutionBudgetPrecheckTracePayload = budget_decision.record + trace_events.append( + ("tool.proxy.execute.budget", cast(dict[str, object], budget_trace_payload)) + ) + + if budget_decision.blocked_result is not None: + dispatch_payload: ProxyExecutionDispatchTracePayload = { + "approval_id": approval["id"], + "task_step_id": linked_task_step_id, + "tool_id": tool["id"], + "tool_key": tool["tool_key"], + "handler_key": None, + "dispatch_status": "blocked", + "reason": budget_decision.blocked_result["reason"], + "result_status": budget_decision.blocked_result["status"], + "output": budget_decision.blocked_result["output"], + } + summary_payload: ProxyExecutionSummaryTracePayload = { + "approval_id": approval["id"], + "task_step_id": linked_task_step_id, + "tool_id": tool["id"], + "tool_key": tool["tool_key"], + "approval_status": approval["status"], + "execution_status": "blocked", + "handler_key": None, + "request_event_id": None, + "result_event_id": None, + } + trace_events.extend( + [ + ("tool.proxy.execute.dispatch", cast(dict[str, object], dispatch_payload)), + ("tool.proxy.execute.summary", cast(dict[str, object], summary_payload)), + ] + ) + execution = _persist_tool_execution( + store, + approval_row=cast(dict[str, object], approval_row), + task_step_id=cast(UUID, linked_task_step["id"]), + trace_id=trace["id"], + handler_key=None, + request=routed_request, + tool=tool, + result=budget_decision.blocked_result, + request_event_id=None, + result_event_id=None, + ) + task_transition = sync_task_with_execution( + store, + approval_id=cast(UUID, approval_row["id"]), + execution_id=execution["id"], + execution_status=execution["status"], + ) + task_step_transition = sync_task_step_with_execution( + store, + task_id=UUID(task_transition.task["id"]), + execution=execution, + trace_id=trace["id"], + trace_kind=TRACE_KIND_PROXY_EXECUTE, + ) + trace_events.extend( + task_lifecycle_trace_events( + task=task_transition.task, + previous_status=task_transition.previous_status, + source="proxy_execution", + ) + ) + trace_events.extend( + task_step_lifecycle_trace_events( + task_step=task_step_transition.task_step, + previous_status=task_step_transition.previous_status, + source="proxy_execution", + ) + ) + _append_trace_events(store, trace_id=trace["id"], trace_events=trace_events) + return { + "request": { + "approval_id": approval["id"], + "task_step_id": linked_task_step_id, + }, + "approval": approval, + "tool": tool, + "result": budget_decision.blocked_result, + "events": None, + "trace": _trace_summary(trace["id"], trace_events), + } + + if handler is None: + error = _missing_handler_error(tool=tool) + result = _tool_execution_result( + handler_key=None, + status="blocked", + output=None, + reason=str(error), + ) + dispatch_payload: ProxyExecutionDispatchTracePayload = { + "approval_id": approval["id"], + "task_step_id": linked_task_step_id, + "tool_id": tool["id"], + "tool_key": tool["tool_key"], + "handler_key": None, + "dispatch_status": "blocked", + "reason": str(error), + "result_status": result["status"], + "output": None, + } + summary_payload: ProxyExecutionSummaryTracePayload = { 
+ "approval_id": approval["id"], + "task_step_id": linked_task_step_id, + "tool_id": tool["id"], + "tool_key": tool["tool_key"], + "approval_status": approval["status"], + "execution_status": "blocked", + "handler_key": None, + "request_event_id": None, + "result_event_id": None, + } + trace_events.extend( + [ + ("tool.proxy.execute.dispatch", cast(dict[str, object], dispatch_payload)), + ("tool.proxy.execute.summary", cast(dict[str, object], summary_payload)), + ] + ) + execution = _persist_tool_execution( + store, + approval_row=cast(dict[str, object], approval_row), + task_step_id=cast(UUID, linked_task_step["id"]), + trace_id=trace["id"], + handler_key=None, + request=routed_request, + tool=tool, + result=result, + request_event_id=None, + result_event_id=None, + ) + task_transition = sync_task_with_execution( + store, + approval_id=cast(UUID, approval_row["id"]), + execution_id=execution["id"], + execution_status=execution["status"], + ) + task_step_transition = sync_task_step_with_execution( + store, + task_id=UUID(task_transition.task["id"]), + execution=execution, + trace_id=trace["id"], + trace_kind=TRACE_KIND_PROXY_EXECUTE, + ) + trace_events.extend( + task_lifecycle_trace_events( + task=task_transition.task, + previous_status=task_transition.previous_status, + source="proxy_execution", + ) + ) + trace_events.extend( + task_step_lifecycle_trace_events( + task_step=task_step_transition.task_step, + previous_status=task_step_transition.previous_status, + source="proxy_execution", + ) + ) + _append_trace_events(store, trace_id=trace["id"], trace_events=trace_events) + raise error + + request_event_payload: ProxyExecutionRequestEventPayload = { + "approval_id": approval["id"], + "task_step_id": linked_task_step_id, + "tool_id": tool["id"], + "tool_key": tool["tool_key"], + "request": routed_request, + } + request_event = store.append_event( + approval_row["thread_id"], + None, + PROXY_EXECUTION_REQUEST_EVENT_KIND, + cast(JsonObject, request_event_payload), + ) + + result = handler(routed_request, tool) + result_event_payload: ProxyExecutionResultEventPayload = { + "approval_id": approval["id"], + "task_step_id": linked_task_step_id, + "tool_id": tool["id"], + "tool_key": tool["tool_key"], + "handler_key": result["handler_key"], + "status": result["status"], + "output": result["output"], + } + result_event = store.append_event( + approval_row["thread_id"], + None, + PROXY_EXECUTION_RESULT_EVENT_KIND, + cast(JsonObject, result_event_payload), + ) + execution = _persist_tool_execution( + store, + approval_row=cast(dict[str, object], approval_row), + task_step_id=cast(UUID, linked_task_step["id"]), + trace_id=trace["id"], + handler_key=result["handler_key"], + request=routed_request, + tool=tool, + result=_tool_execution_result( + handler_key=result["handler_key"], + status=result["status"], + output=result["output"], + reason=None, + ), + request_event_id=request_event["id"], + result_event_id=result_event["id"], + ) + + events: ProxyExecutionEventSummary = { + "request_event_id": str(request_event["id"]), + "request_sequence_no": request_event["sequence_no"], + "result_event_id": str(result_event["id"]), + "result_sequence_no": result_event["sequence_no"], + } + dispatch_payload: ProxyExecutionDispatchTracePayload = { + "approval_id": approval["id"], + "task_step_id": linked_task_step_id, + "tool_id": tool["id"], + "tool_key": tool["tool_key"], + "handler_key": result["handler_key"], + "dispatch_status": "executed", + "reason": None, + "result_status": result["status"], + "output": 
result["output"], + } + summary_payload: ProxyExecutionSummaryTracePayload = { + "approval_id": approval["id"], + "task_step_id": linked_task_step_id, + "tool_id": tool["id"], + "tool_key": tool["tool_key"], + "approval_status": approval["status"], + "execution_status": "completed", + "handler_key": result["handler_key"], + "request_event_id": events["request_event_id"], + "result_event_id": events["result_event_id"], + } + trace_events.extend( + [ + ("tool.proxy.execute.dispatch", cast(dict[str, object], dispatch_payload)), + ("tool.proxy.execute.summary", cast(dict[str, object], summary_payload)), + ] + ) + task_transition = sync_task_with_execution( + store, + approval_id=cast(UUID, approval_row["id"]), + execution_id=execution["id"], + execution_status=execution["status"], + ) + task_step_transition = sync_task_step_with_execution( + store, + task_id=UUID(task_transition.task["id"]), + execution=execution, + trace_id=trace["id"], + trace_kind=TRACE_KIND_PROXY_EXECUTE, + ) + trace_events.extend( + task_lifecycle_trace_events( + task=task_transition.task, + previous_status=task_transition.previous_status, + source="proxy_execution", + ) + ) + trace_events.extend( + task_step_lifecycle_trace_events( + task_step=task_step_transition.task_step, + previous_status=task_step_transition.previous_status, + source="proxy_execution", + ) + ) + _append_trace_events(store, trace_id=trace["id"], trace_events=trace_events) + + return { + "request": { + "approval_id": approval["id"], + "task_step_id": linked_task_step_id, + }, + "approval": approval, + "tool": tool, + "result": result, + "events": events, + "trace": _trace_summary(trace["id"], trace_events), + } diff --git a/apps/api/src/alicebot_api/response_generation.py b/apps/api/src/alicebot_api/response_generation.py new file mode 100644 index 0000000..7652a5d --- /dev/null +++ b/apps/api/src/alicebot_api/response_generation.py @@ -0,0 +1,474 @@ +from __future__ import annotations + +from dataclasses import dataclass +import hashlib +import json +from typing import Any, TypedDict, cast +from urllib.error import HTTPError, URLError +from urllib.request import Request, urlopen +from uuid import UUID + +from alicebot_api.compiler import compile_and_persist_trace +from alicebot_api.config import Settings +from alicebot_api.contracts import ( + AssistantResponseEventPayload, + CompiledContextPack, + ContextCompilerLimits, + GenerateResponseSuccess, + ModelInvocationRequest, + ModelInvocationResponse, + ModelUsagePayload, + PROMPT_ASSEMBLY_VERSION_V0, + PromptAssemblyInput, + PromptAssemblyResult, + PromptAssemblyTracePayload, + PromptSection, + RESPONSE_GENERATION_VERSION_V0, + ResponseTraceSummary, + TRACE_KIND_RESPONSE_GENERATE, + TraceEventRecord, +) +from alicebot_api.store import ContinuityStore, JsonObject + +PROMPT_TRACE_EVENT_KIND = "response.prompt.assembled" +MODEL_COMPLETED_TRACE_EVENT_KIND = "response.model.completed" +MODEL_FAILED_TRACE_EVENT_KIND = "response.model.failed" +SYSTEM_INSTRUCTION = ( + "You are AliceBot. Reply to the latest user message using the provided durable context. " + "If the context is insufficient, say so briefly instead of inventing facts." +) +DEVELOPER_INSTRUCTION = ( + "Treat the CONTEXT and CONVERSATION sections as authoritative durable state. " + "Do not call tools, do not describe hidden chain-of-thought, and keep the reply concise." 
+) + + +class ModelInvocationError(RuntimeError): + """Raised when the configured model provider cannot produce a response.""" + + +@dataclass(frozen=True, slots=True) +class ResponseFailure: + detail: str + trace: ResponseTraceSummary + + +class _OpenAIResponseContentItem(TypedDict, total=False): + type: str + text: str + + +class _OpenAIResponseOutputItem(TypedDict, total=False): + type: str + content: list[_OpenAIResponseContentItem] + + +class _OpenAIResponseUsage(TypedDict, total=False): + input_tokens: int | None + output_tokens: int | None + total_tokens: int | None + + +class _OpenAIResponsePayload(TypedDict, total=False): + id: str + status: str + output: list[_OpenAIResponseOutputItem] + usage: _OpenAIResponseUsage + + +def _deterministic_json(value: JsonObject | list[object]) -> str: + return json.dumps(value, sort_keys=True, ensure_ascii=True, separators=(",", ":")) + + +def _context_section_payload(context_pack: CompiledContextPack) -> JsonObject: + return { + "compiler_version": context_pack["compiler_version"], + "scope": context_pack["scope"], + "limits": context_pack["limits"], + "user": context_pack["user"], + "thread": context_pack["thread"], + "sessions": context_pack["sessions"], + "memories": context_pack["memories"], + "memory_summary": context_pack["memory_summary"], + "entities": context_pack["entities"], + "entity_summary": context_pack["entity_summary"], + "entity_edges": context_pack["entity_edges"], + "entity_edge_summary": context_pack["entity_edge_summary"], + } + + +def assemble_prompt( + *, + request: PromptAssemblyInput, + compile_trace_id: str, +) -> PromptAssemblyResult: + sections = ( + PromptSection(name="system", content=request.system_instruction), + PromptSection(name="developer", content=request.developer_instruction), + PromptSection( + name="context", + content=_deterministic_json(_context_section_payload(request.context_pack)), + ), + PromptSection( + name="conversation", + content=_deterministic_json({"events": request.context_pack["events"]}), + ), + ) + prompt_text = "\n\n".join( + f"[{section.name.upper()}]\n{section.content}" for section in sections + ) + prompt_sha256 = hashlib.sha256(prompt_text.encode("utf-8")).hexdigest() + trace_payload: PromptAssemblyTracePayload = { + "version": PROMPT_ASSEMBLY_VERSION_V0, + "compile_trace_id": compile_trace_id, + "compiler_version": request.context_pack["compiler_version"], + "prompt_sha256": prompt_sha256, + "prompt_char_count": len(prompt_text), + "section_order": [section.name for section in sections], + "section_characters": {section.name: len(section.content) for section in sections}, + "included_session_count": len(request.context_pack["sessions"]), + "included_event_count": len(request.context_pack["events"]), + "included_memory_count": len(request.context_pack["memories"]), + "included_entity_count": len(request.context_pack["entities"]), + "included_entity_edge_count": len(request.context_pack["entity_edges"]), + } + return PromptAssemblyResult( + sections=sections, + prompt_text=prompt_text, + prompt_sha256=prompt_sha256, + trace_payload=trace_payload, + ) + + +def _openai_input_message(role: str, content: str) -> JsonObject: + return { + "role": role, + "content": [{"type": "input_text", "text": content}], + } + + +def _build_openai_responses_payload(request: ModelInvocationRequest) -> JsonObject: + sections = {section.name: section.content for section in request.prompt.sections} + return { + "model": request.model, + "store": request.store, + "tool_choice": request.tool_choice, + 
"tools": [], + "input": [ + _openai_input_message("system", sections["system"]), + _openai_input_message("developer", sections["developer"]), + _openai_input_message("user", f"[CONTEXT]\n{sections['context']}"), + _openai_input_message("user", f"[CONVERSATION]\n{sections['conversation']}"), + ], + "text": {"format": {"type": "text"}}, + } + + +def _extract_output_text(response_payload: _OpenAIResponsePayload) -> str: + output_items = response_payload.get("output", []) + for output_item in output_items: + if output_item.get("type") != "message": + continue + for content_item in output_item.get("content", []): + if content_item.get("type") == "output_text": + text = content_item.get("text") + if isinstance(text, str) and text: + return text + raise ModelInvocationError("model response did not include assistant output text") + + +def _parse_usage(response_payload: _OpenAIResponsePayload) -> ModelUsagePayload: + usage = response_payload.get("usage", {}) + if not isinstance(usage, dict): + return {"input_tokens": None, "output_tokens": None, "total_tokens": None} + return { + "input_tokens": usage.get("input_tokens"), + "output_tokens": usage.get("output_tokens"), + "total_tokens": usage.get("total_tokens"), + } + + +def _parse_openai_response_payload(raw_payload: bytes) -> _OpenAIResponsePayload: + try: + parsed_payload = json.loads(raw_payload) + except json.JSONDecodeError as exc: + raise ModelInvocationError("model provider returned invalid JSON") from exc + + if not isinstance(parsed_payload, dict): + raise ModelInvocationError("model provider returned invalid JSON") + + return cast(_OpenAIResponsePayload, parsed_payload) + + +def _extract_http_error_detail(exc: HTTPError) -> str | None: + raw_body = exc.read().decode("utf-8", errors="replace") + try: + parsed_error = json.loads(raw_body) + except json.JSONDecodeError: + return None + + if not isinstance(parsed_error, dict): + return None + + error = parsed_error.get("error", {}) + if not isinstance(error, dict): + return None + + detail = error.get("message") + if isinstance(detail, str) and detail: + return detail + return None + + +def _build_model_http_request(*, settings: Settings, payload: JsonObject) -> Request: + endpoint = settings.model_base_url.rstrip("/") + "/responses" + return Request( + endpoint, + data=json.dumps(payload).encode("utf-8"), + headers={ + "Authorization": f"Bearer {settings.model_api_key}", + "Content-Type": "application/json", + }, + method="POST", + ) + + +def _model_failure_trace_payload( + *, + request: ModelInvocationRequest, + error_message: str, +) -> JsonObject: + return { + "provider": request.provider, + "model": request.model, + "tool_choice": "none", + "tools_enabled": False, + "response_id": None, + "finish_reason": "incomplete", + "output_text_char_count": 0, + "usage": { + "input_tokens": None, + "output_tokens": None, + "total_tokens": None, + }, + "error_message": error_message, + } + + +def _create_linked_response_trace( + *, + store: ContinuityStore, + user_id: UUID, + thread_id: UUID, + limits: ContextCompilerLimits, + compiled_trace_id: str, + compiled_trace_event_count: int, + status: str, + trace_events: list[TraceEventRecord], +) -> ResponseTraceSummary: + trace = _create_response_trace( + store=store, + user_id=user_id, + thread_id=thread_id, + limits=limits, + status=status, + trace_events=trace_events, + ) + trace["compile_trace_id"] = compiled_trace_id + trace["compile_trace_event_count"] = compiled_trace_event_count + return trace + + +def invoke_model( + *, + settings: Settings, + 
request: ModelInvocationRequest, +) -> ModelInvocationResponse: + if request.provider != "openai_responses": + raise ModelInvocationError(f"unsupported model provider: {request.provider}") + if not settings.model_api_key: + raise ModelInvocationError("MODEL_API_KEY is not configured") + + payload = _build_openai_responses_payload(request) + http_request = _build_model_http_request(settings=settings, payload=payload) + + try: + with urlopen(http_request, timeout=settings.model_timeout_seconds) as response: + raw_payload = response.read() + except HTTPError as exc: + detail = _extract_http_error_detail(exc) + if detail is not None: + raise ModelInvocationError(detail) from exc + raise ModelInvocationError(f"model provider returned HTTP {exc.code}") from exc + except URLError as exc: + raise ModelInvocationError(f"model provider request failed: {exc.reason}") from exc + + response_payload = _parse_openai_response_payload(raw_payload) + output_text = _extract_output_text(response_payload) + finish_reason = "completed" if response_payload.get("status") == "completed" else "incomplete" + return ModelInvocationResponse( + provider=request.provider, + model=request.model, + response_id=response_payload.get("id"), + finish_reason=finish_reason, + output_text=output_text, + usage=_parse_usage(response_payload), + ) + + +def build_assistant_response_payload( + *, + prompt: PromptAssemblyResult, + model_response: ModelInvocationResponse, +) -> AssistantResponseEventPayload: + return { + "text": model_response.output_text, + "model": { + "provider": model_response.provider, + "model": model_response.model, + "response_id": model_response.response_id, + "finish_reason": model_response.finish_reason, + "usage": model_response.usage, + }, + "prompt": { + "assembly_version": PROMPT_ASSEMBLY_VERSION_V0, + "prompt_sha256": prompt.prompt_sha256, + "section_order": [section.name for section in prompt.sections], + }, + } + + +def _create_response_trace( + *, + store: ContinuityStore, + user_id: UUID, + thread_id: UUID, + limits: ContextCompilerLimits, + status: str, + trace_events: list[TraceEventRecord], +) -> ResponseTraceSummary: + trace = store.create_trace( + user_id=user_id, + thread_id=thread_id, + kind=TRACE_KIND_RESPONSE_GENERATE, + compiler_version=RESPONSE_GENERATION_VERSION_V0, + status=status, + limits=limits.as_payload(), + ) + for sequence_no, trace_event in enumerate(trace_events, start=1): + store.append_trace_event( + trace_id=trace["id"], + sequence_no=sequence_no, + kind=trace_event.kind, + payload=trace_event.payload, + ) + return { + "compile_trace_id": "", + "compile_trace_event_count": 0, + "response_trace_id": str(trace["id"]), + "response_trace_event_count": len(trace_events), + } + + +def generate_response( + *, + store: ContinuityStore, + settings: Settings, + user_id: UUID, + thread_id: UUID, + message_text: str, + limits: ContextCompilerLimits, +) -> GenerateResponseSuccess | ResponseFailure: + store.get_user(user_id) + store.get_thread(thread_id) + + store.append_event( + thread_id, + None, + "message.user", + {"text": message_text}, + ) + compiled_trace = compile_and_persist_trace( + store, + user_id=user_id, + thread_id=thread_id, + limits=limits, + ) + prompt = assemble_prompt( + request=PromptAssemblyInput( + context_pack=compiled_trace.context_pack, + system_instruction=SYSTEM_INSTRUCTION, + developer_instruction=DEVELOPER_INSTRUCTION, + ), + compile_trace_id=compiled_trace.trace_id, + ) + request = ModelInvocationRequest( + provider=settings.model_provider, # type: 
ignore[arg-type] + model=settings.model_name, + prompt=prompt, + ) + prompt_trace_event = TraceEventRecord( + kind=PROMPT_TRACE_EVENT_KIND, + payload=prompt.trace_payload, + ) + + try: + model_response = invoke_model(settings=settings, request=request) + except ModelInvocationError as exc: + trace = _create_linked_response_trace( + store=store, + user_id=user_id, + thread_id=thread_id, + limits=limits, + compiled_trace_id=compiled_trace.trace_id, + compiled_trace_event_count=compiled_trace.trace_event_count, + status="failed", + trace_events=[ + prompt_trace_event, + TraceEventRecord( + kind=MODEL_FAILED_TRACE_EVENT_KIND, + payload=_model_failure_trace_payload( + request=request, + error_message=str(exc), + ), + ), + ], + ) + return ResponseFailure(detail=str(exc), trace=trace) + + assistant_payload = build_assistant_response_payload( + prompt=prompt, + model_response=model_response, + ) + assistant_event = store.append_event( + thread_id, + None, + "message.assistant", + assistant_payload, + ) + trace = _create_linked_response_trace( + store=store, + user_id=user_id, + thread_id=thread_id, + limits=limits, + compiled_trace_id=compiled_trace.trace_id, + compiled_trace_event_count=compiled_trace.trace_event_count, + status="completed", + trace_events=[ + prompt_trace_event, + TraceEventRecord( + kind=MODEL_COMPLETED_TRACE_EVENT_KIND, + payload=model_response.to_trace_payload(), + ), + ], + ) + return { + "assistant": { + "event_id": str(assistant_event["id"]), + "sequence_no": assistant_event["sequence_no"], + "text": model_response.output_text, + "model_provider": model_response.provider, + "model": model_response.model, + }, + "trace": trace, + } diff --git a/apps/api/src/alicebot_api/semantic_retrieval.py b/apps/api/src/alicebot_api/semantic_retrieval.py new file mode 100644 index 0000000..5384e3d --- /dev/null +++ b/apps/api/src/alicebot_api/semantic_retrieval.py @@ -0,0 +1,107 @@ +from __future__ import annotations + +import math +from uuid import UUID + +from alicebot_api.contracts import ( + SEMANTIC_MEMORY_RETRIEVAL_ORDER, + SemanticMemoryRetrievalRequestInput, + SemanticMemoryRetrievalResponse, + SemanticMemoryRetrievalResultItem, + SemanticMemoryRetrievalSummary, +) +from alicebot_api.store import ContinuityStore, SemanticMemoryRetrievalRow + + +class SemanticMemoryRetrievalValidationError(ValueError): + """Raised when semantic memory retrieval fails explicit validation.""" + + +def _validate_query_vector(query_vector: tuple[float, ...]) -> list[float]: + if not query_vector: + raise SemanticMemoryRetrievalValidationError( + "query_vector must include at least one numeric value" + ) + + normalized: list[float] = [] + for value in query_vector: + normalized_value = float(value) + if not math.isfinite(normalized_value): + raise SemanticMemoryRetrievalValidationError( + "query_vector must contain only finite numeric values" + ) + normalized.append(normalized_value) + + return normalized + + +def validate_semantic_memory_retrieval_request( + store: ContinuityStore, + *, + request: SemanticMemoryRetrievalRequestInput, +) -> tuple[dict[str, object], list[float]]: + config = store.get_embedding_config_optional(request.embedding_config_id) + if config is None: + raise SemanticMemoryRetrievalValidationError( + "embedding_config_id must reference an existing embedding config owned by the user: " + f"{request.embedding_config_id}" + ) + + query_vector = _validate_query_vector(request.query_vector) + if len(query_vector) != config["dimensions"]: + raise SemanticMemoryRetrievalValidationError( 
+ "query_vector length must match embedding config dimensions " + f"({config['dimensions']}): {len(query_vector)}" + ) + + return config, query_vector + + +def serialize_semantic_memory_result_item( + row: SemanticMemoryRetrievalRow, +) -> SemanticMemoryRetrievalResultItem: + if row["status"] != "active": + raise SemanticMemoryRetrievalValidationError( + f"semantic retrieval only supports active memories: {row['id']}" + ) + + return { + "memory_id": str(row["id"]), + "memory_key": row["memory_key"], + "value": row["value"], + "source_event_ids": row["source_event_ids"], + "created_at": row["created_at"].isoformat(), + "updated_at": row["updated_at"].isoformat(), + "score": float(row["score"]), + } + + +def retrieve_semantic_memory_records( + store: ContinuityStore, + *, + user_id: UUID, + request: SemanticMemoryRetrievalRequestInput, +) -> SemanticMemoryRetrievalResponse: + del user_id + + _config, query_vector = validate_semantic_memory_retrieval_request(store, request=request) + + items = [ + serialize_semantic_memory_result_item(row) + for row in store.retrieve_semantic_memory_matches( + embedding_config_id=request.embedding_config_id, + query_vector=query_vector, + limit=request.limit, + ) + ] + summary: SemanticMemoryRetrievalSummary = { + "embedding_config_id": str(request.embedding_config_id), + "limit": request.limit, + "returned_count": len(items), + "similarity_metric": "cosine_similarity", + "order": list(SEMANTIC_MEMORY_RETRIEVAL_ORDER), + } + return { + "items": items, + "summary": summary, + } diff --git a/apps/api/src/alicebot_api/store.py b/apps/api/src/alicebot_api/store.py new file mode 100644 index 0000000..8c3b551 --- /dev/null +++ b/apps/api/src/alicebot_api/store.py @@ -0,0 +1,2713 @@ +from __future__ import annotations + +from datetime import datetime +from typing import Any, TypedDict, TypeVar, cast +from uuid import UUID + +import psycopg +from psycopg.types.json import Jsonb + +JsonScalar = str | int | float | bool | None +JsonValue = JsonScalar | list["JsonValue"] | dict[str, "JsonValue"] +JsonObject = dict[str, JsonValue] +RowT = TypeVar("RowT") + + +class UserRow(TypedDict): + id: UUID + email: str + display_name: str | None + created_at: datetime + + +class ThreadRow(TypedDict): + id: UUID + user_id: UUID + title: str + created_at: datetime + updated_at: datetime + + +class SessionRow(TypedDict): + id: UUID + user_id: UUID + thread_id: UUID + status: str + started_at: datetime | None + ended_at: datetime | None + created_at: datetime + + +class EventRow(TypedDict): + id: UUID + user_id: UUID + thread_id: UUID + session_id: UUID | None + sequence_no: int + kind: str + payload: JsonObject + created_at: datetime + + +class TraceRow(TypedDict): + id: UUID + user_id: UUID + thread_id: UUID + kind: str + compiler_version: str + status: str + limits: JsonObject + created_at: datetime + + +class TraceEventRow(TypedDict): + id: UUID + user_id: UUID + trace_id: UUID + sequence_no: int + kind: str + payload: JsonObject + created_at: datetime + + +class MemoryRow(TypedDict): + id: UUID + user_id: UUID + memory_key: str + value: JsonValue + status: str + source_event_ids: list[str] + created_at: datetime + updated_at: datetime + deleted_at: datetime | None + + +class MemoryRevisionRow(TypedDict): + id: UUID + user_id: UUID + memory_id: UUID + sequence_no: int + action: str + memory_key: str + previous_value: JsonValue | None + new_value: JsonValue | None + source_event_ids: list[str] + candidate: JsonObject + created_at: datetime + + +class MemoryReviewLabelRow(TypedDict): 
+ id: UUID + user_id: UUID + memory_id: UUID + label: str + note: str | None + created_at: datetime + + +class EmbeddingConfigRow(TypedDict): + id: UUID + user_id: UUID + provider: str + model: str + version: str + dimensions: int + status: str + metadata: JsonObject + created_at: datetime + + +class MemoryEmbeddingRow(TypedDict): + id: UUID + user_id: UUID + memory_id: UUID + embedding_config_id: UUID + dimensions: int + vector: list[float] + created_at: datetime + updated_at: datetime + + +class SemanticMemoryRetrievalRow(TypedDict): + id: UUID + user_id: UUID + memory_key: str + value: JsonValue + status: str + source_event_ids: list[str] + created_at: datetime + updated_at: datetime + deleted_at: datetime | None + score: float + + +class EntityRow(TypedDict): + id: UUID + user_id: UUID + entity_type: str + name: str + source_memory_ids: list[str] + created_at: datetime + + +class EntityEdgeRow(TypedDict): + id: UUID + user_id: UUID + from_entity_id: UUID + to_entity_id: UUID + relationship_type: str + valid_from: datetime | None + valid_to: datetime | None + source_memory_ids: list[str] + created_at: datetime + + +class ConsentRow(TypedDict): + id: UUID + user_id: UUID + consent_key: str + status: str + metadata: JsonObject + created_at: datetime + updated_at: datetime + + +class PolicyRow(TypedDict): + id: UUID + user_id: UUID + name: str + action: str + scope: str + effect: str + priority: int + active: bool + conditions: JsonObject + required_consents: list[str] + created_at: datetime + updated_at: datetime + + +class ToolRow(TypedDict): + id: UUID + user_id: UUID + tool_key: str + name: str + description: str + version: str + metadata_version: str + active: bool + tags: list[str] + action_hints: list[str] + scope_hints: list[str] + domain_hints: list[str] + risk_hints: list[str] + metadata: JsonObject + created_at: datetime + + +class ApprovalRow(TypedDict): + id: UUID + user_id: UUID + thread_id: UUID + tool_id: UUID + task_step_id: UUID | None + status: str + request: JsonObject + tool: JsonObject + routing: JsonObject + routing_trace_id: UUID + created_at: datetime + resolved_at: datetime | None + resolved_by_user_id: UUID | None + + +class TaskRow(TypedDict): + id: UUID + user_id: UUID + thread_id: UUID + tool_id: UUID + status: str + request: JsonObject + tool: JsonObject + latest_approval_id: UUID | None + latest_execution_id: UUID | None + created_at: datetime + updated_at: datetime + + +class TaskWorkspaceRow(TypedDict): + id: UUID + user_id: UUID + task_id: UUID + status: str + local_path: str + created_at: datetime + updated_at: datetime + + +class TaskStepRow(TypedDict): + id: UUID + user_id: UUID + task_id: UUID + sequence_no: int + parent_step_id: UUID | None + source_approval_id: UUID | None + source_execution_id: UUID | None + kind: str + status: str + request: JsonObject + outcome: JsonObject + trace_id: UUID + trace_kind: str + created_at: datetime + updated_at: datetime + + +class ToolExecutionRow(TypedDict): + id: UUID + user_id: UUID + approval_id: UUID + task_step_id: UUID + thread_id: UUID + tool_id: UUID + trace_id: UUID + request_event_id: UUID | None + result_event_id: UUID | None + status: str + handler_key: str | None + request: JsonObject + tool: JsonObject + result: JsonObject + executed_at: datetime + + +class ExecutionBudgetRow(TypedDict): + id: UUID + user_id: UUID + tool_key: str | None + domain_hint: str | None + max_completed_executions: int + rolling_window_seconds: int | None + status: str + deactivated_at: datetime | None + 
superseded_by_budget_id: UUID | None + supersedes_budget_id: UUID | None + created_at: datetime + + +class CountRow(TypedDict): + count: int + + +class LabelCountRow(TypedDict): + label: str + count: int + + +INSERT_USER_SQL = """ + INSERT INTO users (id, email, display_name) + VALUES (%s, %s, %s) + RETURNING id, email, display_name, created_at + """ + +GET_USER_SQL = """ + SELECT id, email, display_name, created_at + FROM users + WHERE id = %s + """ + +INSERT_THREAD_SQL = """ + INSERT INTO threads (user_id, title) + VALUES (app.current_user_id(), %s) + RETURNING id, user_id, title, created_at, updated_at + """ + +GET_THREAD_SQL = """ + SELECT id, user_id, title, created_at, updated_at + FROM threads + WHERE id = %s + """ + +INSERT_SESSION_SQL = """ + INSERT INTO sessions (user_id, thread_id, status) + VALUES (app.current_user_id(), %s, %s) + RETURNING id, user_id, thread_id, status, started_at, ended_at, created_at + """ + +LIST_THREAD_SESSIONS_SQL = """ + SELECT id, user_id, thread_id, status, started_at, ended_at, created_at + FROM sessions + WHERE thread_id = %s + ORDER BY started_at ASC, created_at ASC, id ASC + """ + +LOCK_THREAD_EVENTS_SQL = "SELECT pg_advisory_xact_lock(hashtextextended(%s::text, 0))" +LOCK_TASK_STEPS_SQL = "SELECT pg_advisory_xact_lock(hashtextextended(%s::text, 2))" +LOCK_TASK_WORKSPACES_SQL = "SELECT pg_advisory_xact_lock(hashtextextended(%s::text, 3))" + +INSERT_EVENT_SQL = """ + WITH next_sequence AS ( + SELECT COALESCE(MAX(sequence_no) + 1, 1) AS sequence_no + FROM events + WHERE thread_id = %s + AND user_id = app.current_user_id() + ) + INSERT INTO events (user_id, thread_id, session_id, sequence_no, kind, payload) + SELECT app.current_user_id(), %s, %s, next_sequence.sequence_no, %s, %s + FROM next_sequence + RETURNING id, user_id, thread_id, session_id, sequence_no, kind, payload, created_at + """ + +LIST_THREAD_EVENTS_SQL = """ + SELECT id, user_id, thread_id, session_id, sequence_no, kind, payload, created_at + FROM events + WHERE thread_id = %s + ORDER BY sequence_no ASC + """ + +LIST_EVENTS_BY_IDS_SQL = """ + SELECT id, user_id, thread_id, session_id, sequence_no, kind, payload, created_at + FROM events + WHERE id = ANY(%s) + ORDER BY sequence_no ASC + """ + +INSERT_TRACE_SQL = """ + INSERT INTO traces (user_id, thread_id, kind, compiler_version, status, limits) + VALUES (%s, %s, %s, %s, %s, %s) + RETURNING id, user_id, thread_id, kind, compiler_version, status, limits, created_at + """ + +GET_TRACE_SQL = """ + SELECT id, user_id, thread_id, kind, compiler_version, status, limits, created_at + FROM traces + WHERE id = %s + """ + +INSERT_TRACE_EVENT_SQL = """ + INSERT INTO trace_events (user_id, trace_id, sequence_no, kind, payload) + VALUES (app.current_user_id(), %s, %s, %s, %s) + RETURNING id, user_id, trace_id, sequence_no, kind, payload, created_at + """ + +LIST_TRACE_EVENTS_SQL = """ + SELECT id, user_id, trace_id, sequence_no, kind, payload, created_at + FROM trace_events + WHERE trace_id = %s + ORDER BY sequence_no ASC + """ + +INSERT_MEMORY_SQL = """ + INSERT INTO memories ( + user_id, + memory_key, + value, + status, + source_event_ids, + created_at, + updated_at + ) + VALUES (app.current_user_id(), %s, %s, %s, %s, clock_timestamp(), clock_timestamp()) + RETURNING id, user_id, memory_key, value, status, source_event_ids, created_at, updated_at, deleted_at + """ + +GET_MEMORY_SQL = """ + SELECT id, user_id, memory_key, value, status, source_event_ids, created_at, updated_at, deleted_at + FROM memories + WHERE id = %s + """ + 
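+# LIST_MEMORIES_BY_IDS_SQL below hydrates memories for an explicit id array (ANY(%s))
+# and returns rows in stable (created_at, id) order so serialization stays deterministic.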
+LIST_MEMORIES_BY_IDS_SQL = """ + SELECT id, user_id, memory_key, value, status, source_event_ids, created_at, updated_at, deleted_at + FROM memories + WHERE id = ANY(%s) + ORDER BY created_at ASC, id ASC + """ + +GET_MEMORY_BY_KEY_SQL = """ + SELECT id, user_id, memory_key, value, status, source_event_ids, created_at, updated_at, deleted_at + FROM memories + WHERE memory_key = %s + """ + +LIST_MEMORIES_SQL = """ + SELECT id, user_id, memory_key, value, status, source_event_ids, created_at, updated_at, deleted_at + FROM memories + ORDER BY created_at ASC, id ASC + """ + +COUNT_MEMORIES_SQL = """ + SELECT COUNT(*) AS count + FROM memories + """ + +COUNT_MEMORIES_BY_STATUS_SQL = """ + SELECT COUNT(*) AS count + FROM memories + WHERE status = %s + """ + +COUNT_UNLABELED_REVIEW_MEMORIES_SQL = """ + SELECT COUNT(*) AS count + FROM memories + WHERE status = 'active' + AND NOT EXISTS ( + SELECT 1 + FROM memory_review_labels + WHERE memory_review_labels.memory_id = memories.id + ) + """ + +LIST_REVIEW_MEMORIES_SQL = """ + SELECT id, user_id, memory_key, value, status, source_event_ids, created_at, updated_at, deleted_at + FROM memories + ORDER BY updated_at DESC, created_at DESC, id DESC + LIMIT %s + """ + +LIST_REVIEW_MEMORIES_BY_STATUS_SQL = """ + SELECT id, user_id, memory_key, value, status, source_event_ids, created_at, updated_at, deleted_at + FROM memories + WHERE status = %s + ORDER BY updated_at DESC, created_at DESC, id DESC + LIMIT %s + """ + +LIST_UNLABELED_REVIEW_MEMORIES_SQL = """ + SELECT id, user_id, memory_key, value, status, source_event_ids, created_at, updated_at, deleted_at + FROM memories + WHERE status = 'active' + AND NOT EXISTS ( + SELECT 1 + FROM memory_review_labels + WHERE memory_review_labels.memory_id = memories.id + ) + ORDER BY updated_at DESC, created_at DESC, id DESC + LIMIT %s + """ + +LIST_CONTEXT_MEMORIES_SQL = """ + SELECT id, user_id, memory_key, value, status, source_event_ids, created_at, updated_at, deleted_at + FROM memories + ORDER BY updated_at ASC, created_at ASC, id ASC + """ + +UPDATE_MEMORY_SQL = """ + UPDATE memories + SET value = %s, + status = %s, + source_event_ids = %s, + updated_at = clock_timestamp(), + deleted_at = CASE + WHEN %s = 'deleted' THEN clock_timestamp() + ELSE NULL + END + WHERE id = %s + RETURNING id, user_id, memory_key, value, status, source_event_ids, created_at, updated_at, deleted_at + """ + +LOCK_MEMORY_REVISIONS_SQL = "SELECT pg_advisory_xact_lock(hashtextextended(%s::text, 1))" + +INSERT_MEMORY_REVISION_SQL = """ + WITH next_sequence AS ( + SELECT COALESCE(MAX(sequence_no) + 1, 1) AS sequence_no + FROM memory_revisions + WHERE memory_id = %s + AND user_id = app.current_user_id() + ) + INSERT INTO memory_revisions ( + user_id, + memory_id, + sequence_no, + action, + memory_key, + previous_value, + new_value, + source_event_ids, + candidate + ) + SELECT + app.current_user_id(), + %s, + next_sequence.sequence_no, + %s, + %s, + %s, + %s, + %s, + %s + FROM next_sequence + RETURNING id, user_id, memory_id, sequence_no, action, memory_key, previous_value, new_value, source_event_ids, candidate, created_at + """ + +LIST_MEMORY_REVISIONS_SQL = """ + SELECT id, user_id, memory_id, sequence_no, action, memory_key, previous_value, new_value, source_event_ids, candidate, created_at + FROM memory_revisions + WHERE memory_id = %s + ORDER BY sequence_no ASC + """ + +COUNT_MEMORY_REVISIONS_SQL = """ + SELECT COUNT(*) AS count + FROM memory_revisions + WHERE memory_id = %s + """ + +LIST_LIMITED_MEMORY_REVISIONS_SQL = """ + SELECT id, 
user_id, memory_id, sequence_no, action, memory_key, previous_value, new_value, source_event_ids, candidate, created_at + FROM memory_revisions + WHERE memory_id = %s + ORDER BY sequence_no ASC + LIMIT %s + """ + +INSERT_MEMORY_REVIEW_LABEL_SQL = """ + INSERT INTO memory_review_labels (user_id, memory_id, label, note) + VALUES (app.current_user_id(), %s, %s, %s) + RETURNING id, user_id, memory_id, label, note, created_at + """ + +LIST_MEMORY_REVIEW_LABELS_SQL = """ + SELECT id, user_id, memory_id, label, note, created_at + FROM memory_review_labels + WHERE memory_id = %s + ORDER BY created_at ASC, id ASC + """ + +LIST_MEMORY_REVIEW_LABEL_COUNTS_SQL = """ + SELECT label, COUNT(*) AS count + FROM memory_review_labels + WHERE memory_id = %s + GROUP BY label + ORDER BY label ASC + """ + +COUNT_LABELED_MEMORIES_SQL = """ + SELECT COUNT(*) AS count + FROM memories + WHERE EXISTS ( + SELECT 1 + FROM memory_review_labels + WHERE memory_review_labels.memory_id = memories.id + ) + """ + +COUNT_UNLABELED_MEMORIES_SQL = """ + SELECT COUNT(*) AS count + FROM memories + WHERE NOT EXISTS ( + SELECT 1 + FROM memory_review_labels + WHERE memory_review_labels.memory_id = memories.id + ) + """ + +LIST_ALL_MEMORY_REVIEW_LABEL_COUNTS_SQL = """ + SELECT label, COUNT(*) AS count + FROM memory_review_labels + GROUP BY label + ORDER BY label ASC + """ + +INSERT_EMBEDDING_CONFIG_SQL = """ + INSERT INTO embedding_configs ( + user_id, + provider, + model, + version, + dimensions, + status, + metadata, + created_at + ) + VALUES (app.current_user_id(), %s, %s, %s, %s, %s, %s, clock_timestamp()) + RETURNING id, user_id, provider, model, version, dimensions, status, metadata, created_at + """ + +GET_EMBEDDING_CONFIG_SQL = """ + SELECT id, user_id, provider, model, version, dimensions, status, metadata, created_at + FROM embedding_configs + WHERE id = %s + """ + +GET_EMBEDDING_CONFIG_BY_IDENTITY_SQL = """ + SELECT id, user_id, provider, model, version, dimensions, status, metadata, created_at + FROM embedding_configs + WHERE provider = %s + AND model = %s + AND version = %s + """ + +LIST_EMBEDDING_CONFIGS_SQL = """ + SELECT id, user_id, provider, model, version, dimensions, status, metadata, created_at + FROM embedding_configs + ORDER BY created_at ASC, id ASC + """ + +INSERT_MEMORY_EMBEDDING_SQL = """ + INSERT INTO memory_embeddings ( + user_id, + memory_id, + embedding_config_id, + dimensions, + vector, + created_at, + updated_at + ) + VALUES (app.current_user_id(), %s, %s, %s, %s, clock_timestamp(), clock_timestamp()) + RETURNING + id, + user_id, + memory_id, + embedding_config_id, + dimensions, + vector, + created_at, + updated_at + """ + +GET_MEMORY_EMBEDDING_SQL = """ + SELECT + id, + user_id, + memory_id, + embedding_config_id, + dimensions, + vector, + created_at, + updated_at + FROM memory_embeddings + WHERE id = %s + """ + +GET_MEMORY_EMBEDDING_BY_MEMORY_AND_CONFIG_SQL = """ + SELECT + id, + user_id, + memory_id, + embedding_config_id, + dimensions, + vector, + created_at, + updated_at + FROM memory_embeddings + WHERE memory_id = %s + AND embedding_config_id = %s + """ + +LIST_MEMORY_EMBEDDINGS_FOR_MEMORY_SQL = """ + SELECT + id, + user_id, + memory_id, + embedding_config_id, + dimensions, + vector, + created_at, + updated_at + FROM memory_embeddings + WHERE memory_id = %s + ORDER BY created_at ASC, id ASC + """ + +LIST_MEMORY_EMBEDDINGS_FOR_CONFIG_SQL = """ + SELECT + id, + user_id, + memory_id, + embedding_config_id, + dimensions, + vector, + created_at, + updated_at + FROM memory_embeddings + WHERE 
embedding_config_id = %s + ORDER BY created_at ASC, id ASC + """ + +UPDATE_MEMORY_EMBEDDING_SQL = """ + UPDATE memory_embeddings + SET dimensions = %s, + vector = %s, + updated_at = clock_timestamp() + WHERE id = %s + RETURNING + id, + user_id, + memory_id, + embedding_config_id, + dimensions, + vector, + created_at, + updated_at + """ + +RETRIEVE_SEMANTIC_MEMORY_MATCHES_SQL = """ + SELECT + memories.id, + memories.user_id, + memories.memory_key, + memories.value, + memories.status, + memories.source_event_ids, + memories.created_at, + memories.updated_at, + memories.deleted_at, + 1 - ( + replace(memory_embeddings.vector::text, ' ', '')::vector <=> %s::vector + ) AS score + FROM memory_embeddings + JOIN memories + ON memories.id = memory_embeddings.memory_id + AND memories.user_id = memory_embeddings.user_id + WHERE memory_embeddings.embedding_config_id = %s + AND memory_embeddings.dimensions = %s + AND memories.status = 'active' + ORDER BY score DESC, memories.created_at ASC, memories.id ASC + LIMIT %s + """ + +INSERT_ENTITY_SQL = """ + INSERT INTO entities (user_id, entity_type, name, source_memory_ids, created_at) + VALUES (app.current_user_id(), %s, %s, %s, clock_timestamp()) + RETURNING id, user_id, entity_type, name, source_memory_ids, created_at + """ + +GET_ENTITY_SQL = """ + SELECT id, user_id, entity_type, name, source_memory_ids, created_at + FROM entities + WHERE id = %s + """ + +LIST_ENTITIES_SQL = """ + SELECT id, user_id, entity_type, name, source_memory_ids, created_at + FROM entities + ORDER BY created_at ASC, id ASC + """ + +INSERT_ENTITY_EDGE_SQL = """ + INSERT INTO entity_edges ( + user_id, + from_entity_id, + to_entity_id, + relationship_type, + valid_from, + valid_to, + source_memory_ids, + created_at + ) + VALUES (app.current_user_id(), %s, %s, %s, %s, %s, %s, clock_timestamp()) + RETURNING + id, + user_id, + from_entity_id, + to_entity_id, + relationship_type, + valid_from, + valid_to, + source_memory_ids, + created_at + """ + +LIST_ENTITY_EDGES_FOR_ENTITY_SQL = """ + SELECT + id, + user_id, + from_entity_id, + to_entity_id, + relationship_type, + valid_from, + valid_to, + source_memory_ids, + created_at + FROM entity_edges + WHERE from_entity_id = %s OR to_entity_id = %s + ORDER BY created_at ASC, id ASC + """ + +LIST_ENTITY_EDGES_FOR_ENTITIES_SQL = """ + SELECT + id, + user_id, + from_entity_id, + to_entity_id, + relationship_type, + valid_from, + valid_to, + source_memory_ids, + created_at + FROM entity_edges + WHERE from_entity_id = ANY(%s) OR to_entity_id = ANY(%s) + ORDER BY created_at ASC, id ASC + """ + +INSERT_CONSENT_SQL = """ + INSERT INTO consents ( + user_id, + consent_key, + status, + metadata, + created_at, + updated_at + ) + VALUES (app.current_user_id(), %s, %s, %s, clock_timestamp(), clock_timestamp()) + RETURNING id, user_id, consent_key, status, metadata, created_at, updated_at + """ + +GET_CONSENT_BY_KEY_SQL = """ + SELECT id, user_id, consent_key, status, metadata, created_at, updated_at + FROM consents + WHERE consent_key = %s + """ + +LIST_CONSENTS_SQL = """ + SELECT id, user_id, consent_key, status, metadata, created_at, updated_at + FROM consents + ORDER BY consent_key ASC, created_at ASC, id ASC + """ + +UPDATE_CONSENT_SQL = """ + UPDATE consents + SET status = %s, + metadata = %s, + updated_at = clock_timestamp() + WHERE id = %s + RETURNING id, user_id, consent_key, status, metadata, created_at, updated_at + """ + +INSERT_POLICY_SQL = """ + INSERT INTO policies ( + user_id, + name, + action, + scope, + effect, + priority, + active, + 
conditions, + required_consents, + created_at, + updated_at + ) + VALUES (app.current_user_id(), %s, %s, %s, %s, %s, %s, %s, %s, clock_timestamp(), clock_timestamp()) + RETURNING + id, + user_id, + name, + action, + scope, + effect, + priority, + active, + conditions, + required_consents, + created_at, + updated_at + """ + +GET_POLICY_SQL = """ + SELECT + id, + user_id, + name, + action, + scope, + effect, + priority, + active, + conditions, + required_consents, + created_at, + updated_at + FROM policies + WHERE id = %s + """ + +LIST_POLICIES_SQL = """ + SELECT + id, + user_id, + name, + action, + scope, + effect, + priority, + active, + conditions, + required_consents, + created_at, + updated_at + FROM policies + ORDER BY priority ASC, created_at ASC, id ASC + """ + +LIST_ACTIVE_POLICIES_SQL = """ + SELECT + id, + user_id, + name, + action, + scope, + effect, + priority, + active, + conditions, + required_consents, + created_at, + updated_at + FROM policies + WHERE active = TRUE + ORDER BY priority ASC, created_at ASC, id ASC + """ + +INSERT_TOOL_SQL = """ + INSERT INTO tools ( + user_id, + tool_key, + name, + description, + version, + metadata_version, + active, + tags, + action_hints, + scope_hints, + domain_hints, + risk_hints, + metadata, + created_at + ) + VALUES ( + app.current_user_id(), + %s, + %s, + %s, + %s, + %s, + %s, + %s, + %s, + %s, + %s, + %s, + %s, + clock_timestamp() + ) + RETURNING + id, + user_id, + tool_key, + name, + description, + version, + metadata_version, + active, + tags, + action_hints, + scope_hints, + domain_hints, + risk_hints, + metadata, + created_at + """ + +GET_TOOL_SQL = """ + SELECT + id, + user_id, + tool_key, + name, + description, + version, + metadata_version, + active, + tags, + action_hints, + scope_hints, + domain_hints, + risk_hints, + metadata, + created_at + FROM tools + WHERE id = %s + """ + +LIST_TOOLS_SQL = """ + SELECT + id, + user_id, + tool_key, + name, + description, + version, + metadata_version, + active, + tags, + action_hints, + scope_hints, + domain_hints, + risk_hints, + metadata, + created_at + FROM tools + ORDER BY tool_key ASC, version ASC, created_at ASC, id ASC + """ + +LIST_ACTIVE_TOOLS_SQL = """ + SELECT + id, + user_id, + tool_key, + name, + description, + version, + metadata_version, + active, + tags, + action_hints, + scope_hints, + domain_hints, + risk_hints, + metadata, + created_at + FROM tools + WHERE active = TRUE + ORDER BY tool_key ASC, version ASC, created_at ASC, id ASC + """ + +INSERT_APPROVAL_SQL = """ + INSERT INTO approvals ( + user_id, + thread_id, + tool_id, + task_step_id, + status, + request, + tool, + routing, + routing_trace_id, + created_at + ) + VALUES ( + app.current_user_id(), + %s, + %s, + %s, + %s, + %s, + %s, + %s, + %s, + clock_timestamp() + ) + RETURNING + id, + user_id, + thread_id, + tool_id, + task_step_id, + status, + request, + tool, + routing, + routing_trace_id, + created_at, + resolved_at, + resolved_by_user_id + """ + +GET_APPROVAL_SQL = """ + SELECT + id, + user_id, + thread_id, + tool_id, + task_step_id, + status, + request, + tool, + routing, + routing_trace_id, + created_at, + resolved_at, + resolved_by_user_id + FROM approvals + WHERE id = %s + """ + +LIST_APPROVALS_SQL = """ + SELECT + id, + user_id, + thread_id, + tool_id, + task_step_id, + status, + request, + tool, + routing, + routing_trace_id, + created_at, + resolved_at, + resolved_by_user_id + FROM approvals + ORDER BY created_at ASC, id ASC + """ + +UPDATE_APPROVAL_RESOLUTION_SQL = """ + UPDATE approvals + SET status = 
%s, + resolved_at = clock_timestamp(), + resolved_by_user_id = app.current_user_id() + WHERE id = %s + AND status = 'pending' + RETURNING + id, + user_id, + thread_id, + tool_id, + task_step_id, + status, + request, + tool, + routing, + routing_trace_id, + created_at, + resolved_at, + resolved_by_user_id + """ + +UPDATE_APPROVAL_TASK_STEP_SQL = """ + UPDATE approvals + SET task_step_id = %s + WHERE id = %s + RETURNING + id, + user_id, + thread_id, + tool_id, + task_step_id, + status, + request, + tool, + routing, + routing_trace_id, + created_at, + resolved_at, + resolved_by_user_id + """ + +INSERT_TASK_SQL = """ + INSERT INTO tasks ( + user_id, + thread_id, + tool_id, + status, + request, + tool, + latest_approval_id, + latest_execution_id, + created_at, + updated_at + ) + VALUES ( + app.current_user_id(), + %s, + %s, + %s, + %s, + %s, + %s, + %s, + clock_timestamp(), + clock_timestamp() + ) + RETURNING + id, + user_id, + thread_id, + tool_id, + status, + request, + tool, + latest_approval_id, + latest_execution_id, + created_at, + updated_at + """ + +GET_TASK_SQL = """ + SELECT + id, + user_id, + thread_id, + tool_id, + status, + request, + tool, + latest_approval_id, + latest_execution_id, + created_at, + updated_at + FROM tasks + WHERE id = %s + """ + +GET_TASK_BY_APPROVAL_SQL = """ + SELECT + id, + user_id, + thread_id, + tool_id, + status, + request, + tool, + latest_approval_id, + latest_execution_id, + created_at, + updated_at + FROM tasks + WHERE latest_approval_id = %s + """ + +LIST_TASKS_SQL = """ + SELECT + id, + user_id, + thread_id, + tool_id, + status, + request, + tool, + latest_approval_id, + latest_execution_id, + created_at, + updated_at + FROM tasks + ORDER BY created_at ASC, id ASC + """ + +UPDATE_TASK_STATUS_BY_APPROVAL_SQL = """ + UPDATE tasks + SET status = %s, + updated_at = clock_timestamp() + WHERE latest_approval_id = %s + RETURNING + id, + user_id, + thread_id, + tool_id, + status, + request, + tool, + latest_approval_id, + latest_execution_id, + created_at, + updated_at + """ + +UPDATE_TASK_EXECUTION_BY_APPROVAL_SQL = """ + UPDATE tasks + SET status = %s, + latest_execution_id = %s, + updated_at = clock_timestamp() + WHERE latest_approval_id = %s + RETURNING + id, + user_id, + thread_id, + tool_id, + status, + request, + tool, + latest_approval_id, + latest_execution_id, + created_at, + updated_at + """ + +UPDATE_TASK_STATUS_SQL = """ + UPDATE tasks + SET status = %s, + latest_approval_id = %s, + latest_execution_id = %s, + updated_at = clock_timestamp() + WHERE id = %s + RETURNING + id, + user_id, + thread_id, + tool_id, + status, + request, + tool, + latest_approval_id, + latest_execution_id, + created_at, + updated_at + """ + +INSERT_TASK_WORKSPACE_SQL = """ + INSERT INTO task_workspaces ( + user_id, + task_id, + status, + local_path, + created_at, + updated_at + ) + VALUES ( + app.current_user_id(), + %s, + %s, + %s, + clock_timestamp(), + clock_timestamp() + ) + RETURNING + id, + user_id, + task_id, + status, + local_path, + created_at, + updated_at + """ + +GET_TASK_WORKSPACE_SQL = """ + SELECT + id, + user_id, + task_id, + status, + local_path, + created_at, + updated_at + FROM task_workspaces + WHERE id = %s + """ + +GET_ACTIVE_TASK_WORKSPACE_FOR_TASK_SQL = """ + SELECT + id, + user_id, + task_id, + status, + local_path, + created_at, + updated_at + FROM task_workspaces + WHERE task_id = %s + AND status = 'active' + ORDER BY created_at ASC, id ASC + LIMIT 1 + """ + +LIST_TASK_WORKSPACES_SQL = """ + SELECT + id, + user_id, + task_id, + status, + 
local_path, + created_at, + updated_at + FROM task_workspaces + ORDER BY created_at ASC, id ASC + """ + +INSERT_TASK_STEP_SQL = """ + INSERT INTO task_steps ( + user_id, + task_id, + sequence_no, + parent_step_id, + source_approval_id, + source_execution_id, + kind, + status, + request, + outcome, + trace_id, + trace_kind, + created_at, + updated_at + ) + VALUES ( + app.current_user_id(), + %s, + %s, + %s, + %s, + %s, + %s, + %s, + %s, + %s, + %s, + %s, + clock_timestamp(), + clock_timestamp() + ) + RETURNING + id, + user_id, + task_id, + sequence_no, + parent_step_id, + source_approval_id, + source_execution_id, + kind, + status, + request, + outcome, + trace_id, + trace_kind, + created_at, + updated_at + """ + +GET_TASK_STEP_SQL = """ + SELECT + id, + user_id, + task_id, + sequence_no, + parent_step_id, + source_approval_id, + source_execution_id, + kind, + status, + request, + outcome, + trace_id, + trace_kind, + created_at, + updated_at + FROM task_steps + WHERE id = %s + """ + +GET_TASK_STEP_FOR_TASK_SEQUENCE_SQL = """ + SELECT + id, + user_id, + task_id, + sequence_no, + parent_step_id, + source_approval_id, + source_execution_id, + kind, + status, + request, + outcome, + trace_id, + trace_kind, + created_at, + updated_at + FROM task_steps + WHERE task_id = %s + AND sequence_no = %s + """ + +LIST_TASK_STEPS_FOR_TASK_SQL = """ + SELECT + id, + user_id, + task_id, + sequence_no, + parent_step_id, + source_approval_id, + source_execution_id, + kind, + status, + request, + outcome, + trace_id, + trace_kind, + created_at, + updated_at + FROM task_steps + WHERE task_id = %s + ORDER BY sequence_no ASC, created_at ASC, id ASC + """ + +UPDATE_TASK_STEP_FOR_TASK_SEQUENCE_SQL = """ + UPDATE task_steps + SET status = %s, + outcome = %s, + trace_id = %s, + trace_kind = %s, + updated_at = clock_timestamp() + WHERE task_id = %s + AND sequence_no = %s + RETURNING + id, + user_id, + task_id, + sequence_no, + parent_step_id, + source_approval_id, + source_execution_id, + kind, + status, + request, + outcome, + trace_id, + trace_kind, + created_at, + updated_at + """ + +UPDATE_TASK_STEP_SQL = """ + UPDATE task_steps + SET status = %s, + outcome = %s, + trace_id = %s, + trace_kind = %s, + updated_at = clock_timestamp() + WHERE id = %s + RETURNING + id, + user_id, + task_id, + sequence_no, + parent_step_id, + source_approval_id, + source_execution_id, + kind, + status, + request, + outcome, + trace_id, + trace_kind, + created_at, + updated_at + """ + +INSERT_TOOL_EXECUTION_SQL = """ + INSERT INTO tool_executions ( + user_id, + approval_id, + task_step_id, + thread_id, + tool_id, + trace_id, + request_event_id, + result_event_id, + status, + handler_key, + request, + tool, + result, + executed_at + ) + VALUES ( + app.current_user_id(), + %s, + %s, + %s, + %s, + %s, + %s, + %s, + %s, + %s, + %s, + %s, + %s, + clock_timestamp() + ) + RETURNING + id, + user_id, + approval_id, + task_step_id, + thread_id, + tool_id, + trace_id, + request_event_id, + result_event_id, + status, + handler_key, + request, + tool, + result, + executed_at + """ + +GET_TOOL_EXECUTION_SQL = """ + SELECT + id, + user_id, + approval_id, + task_step_id, + thread_id, + tool_id, + trace_id, + request_event_id, + result_event_id, + status, + handler_key, + request, + tool, + result, + executed_at + FROM tool_executions + WHERE id = %s + """ + +LIST_TOOL_EXECUTIONS_SQL = """ + SELECT + id, + user_id, + approval_id, + task_step_id, + thread_id, + tool_id, + trace_id, + request_event_id, + result_event_id, + status, + handler_key, + request, 
+ tool, + result, + executed_at + FROM tool_executions + ORDER BY executed_at ASC, id ASC + """ + +INSERT_EXECUTION_BUDGET_SQL = """ + INSERT INTO execution_budgets ( + id, + user_id, + tool_key, + domain_hint, + max_completed_executions, + rolling_window_seconds, + supersedes_budget_id + ) + VALUES ( + COALESCE(%s, gen_random_uuid()), + app.current_user_id(), + %s, + %s, + %s, + %s, + %s + ) + RETURNING + id, + user_id, + tool_key, + domain_hint, + max_completed_executions, + rolling_window_seconds, + status, + deactivated_at, + superseded_by_budget_id, + supersedes_budget_id, + created_at + """ + +GET_EXECUTION_BUDGET_SQL = """ + SELECT + id, + user_id, + tool_key, + domain_hint, + max_completed_executions, + rolling_window_seconds, + status, + deactivated_at, + superseded_by_budget_id, + supersedes_budget_id, + created_at + FROM execution_budgets + WHERE id = %s + """ + +LIST_EXECUTION_BUDGETS_SQL = """ + SELECT + id, + user_id, + tool_key, + domain_hint, + max_completed_executions, + rolling_window_seconds, + status, + deactivated_at, + superseded_by_budget_id, + supersedes_budget_id, + created_at + FROM execution_budgets + ORDER BY created_at ASC, id ASC + """ + +DEACTIVATE_EXECUTION_BUDGET_SQL = """ + UPDATE execution_budgets + SET status = 'inactive', + deactivated_at = now() + WHERE id = %s + AND status = 'active' + RETURNING + id, + user_id, + tool_key, + domain_hint, + max_completed_executions, + rolling_window_seconds, + status, + deactivated_at, + superseded_by_budget_id, + supersedes_budget_id, + created_at + """ + +SUPERSEDE_EXECUTION_BUDGET_SQL = """ + UPDATE execution_budgets + SET status = 'superseded', + deactivated_at = now(), + superseded_by_budget_id = %s + WHERE id = %s + AND status = 'active' + RETURNING + id, + user_id, + tool_key, + domain_hint, + max_completed_executions, + rolling_window_seconds, + status, + deactivated_at, + superseded_by_budget_id, + supersedes_budget_id, + created_at + """ + +UPDATE_EVENT_ERROR = "events are append-only and must be superseded by new records" +DELETE_EVENT_ERROR = "events are append-only and must not be deleted in place" +UPDATE_TRACE_EVENT_ERROR = "trace events are append-only and must be superseded by new records" +DELETE_TRACE_EVENT_ERROR = "trace events are append-only and must not be deleted in place" + + +class AppendOnlyViolation(RuntimeError): + """Raised when a caller attempts to mutate an immutable event.""" + + +class ContinuityStoreInvariantError(RuntimeError): + """Raised when a write query does not return the row its contract promises.""" + + +class ContinuityStore: + def __init__(self, conn: psycopg.Connection): + self.conn = conn + + def _fetch_one( + self, + operation_name: str, + query: str, + params: tuple[object, ...] | None = None, + ) -> RowT: + with self.conn.cursor() as cur: + cur.execute(query, params) + row = cur.fetchone() + + if row is None: + raise ContinuityStoreInvariantError( + f"{operation_name} did not return a row from the database", + ) + + return cast(RowT, row) + + def _fetch_all( + self, + query: str, + params: tuple[object, ...] | None = None, + ) -> list[RowT]: + with self.conn.cursor() as cur: + cur.execute(query, params) + return cast(list[RowT], list(cur.fetchall())) + + def _fetch_optional_one( + self, + query: str, + params: tuple[object, ...] | None = None, + ) -> RowT | None: + with self.conn.cursor() as cur: + cur.execute(query, params) + row = cur.fetchone() + return cast(RowT | None, row) + + def _fetch_count( + self, + query: str, + params: tuple[object, ...] 
| None = None, + ) -> int: + with self.conn.cursor() as cur: + cur.execute(query, params) + row = cur.fetchone() + + if row is None: + raise ContinuityStoreInvariantError( + "count query did not return a row from the database", + ) + + return cast(CountRow, row)["count"] + + @staticmethod + def _vector_literal(vector: list[float]) -> str: + return "[" + ",".join(repr(value) for value in vector) + "]" + + def create_user(self, user_id: UUID, email: str, display_name: str | None = None) -> UserRow: + return self._fetch_one( + "create_user", + INSERT_USER_SQL, + (user_id, email, display_name), + ) + + def get_user(self, user_id: UUID) -> UserRow: + return self._fetch_one("get_user", GET_USER_SQL, (user_id,)) + + def create_thread(self, title: str) -> ThreadRow: + return self._fetch_one("create_thread", INSERT_THREAD_SQL, (title,)) + + def get_thread(self, thread_id: UUID) -> ThreadRow: + return self._fetch_one("get_thread", GET_THREAD_SQL, (thread_id,)) + + def get_thread_optional(self, thread_id: UUID) -> ThreadRow | None: + return self._fetch_optional_one(GET_THREAD_SQL, (thread_id,)) + + def create_session(self, thread_id: UUID, status: str = "active") -> SessionRow: + return self._fetch_one("create_session", INSERT_SESSION_SQL, (thread_id, status)) + + def list_thread_sessions(self, thread_id: UUID) -> list[SessionRow]: + return self._fetch_all(LIST_THREAD_SESSIONS_SQL, (thread_id,)) + + def append_event( + self, + thread_id: UUID, + session_id: UUID | None, + kind: str, + payload: JsonObject, + ) -> EventRow: + with self.conn.cursor() as cur: + cur.execute(LOCK_THREAD_EVENTS_SQL, (str(thread_id),)) + cur.execute( + INSERT_EVENT_SQL, + (thread_id, thread_id, session_id, kind, Jsonb(payload)), + ) + row = cur.fetchone() + + if row is None: + raise ContinuityStoreInvariantError( + "append_event did not return a row from the database", + ) + + return cast(EventRow, row) + + def list_thread_events(self, thread_id: UUID) -> list[EventRow]: + return self._fetch_all(LIST_THREAD_EVENTS_SQL, (thread_id,)) + + def list_events_by_ids(self, event_ids: list[UUID]) -> list[EventRow]: + if not event_ids: + return [] + return self._fetch_all(LIST_EVENTS_BY_IDS_SQL, (event_ids,)) + + def create_trace( + self, + *, + user_id: UUID, + thread_id: UUID, + kind: str, + compiler_version: str, + status: str, + limits: JsonObject, + ) -> TraceRow: + return self._fetch_one( + "create_trace", + INSERT_TRACE_SQL, + (user_id, thread_id, kind, compiler_version, status, Jsonb(limits)), + ) + + def get_trace(self, trace_id: UUID) -> TraceRow: + return self._fetch_one("get_trace", GET_TRACE_SQL, (trace_id,)) + + def append_trace_event( + self, + *, + trace_id: UUID, + sequence_no: int, + kind: str, + payload: JsonObject, + ) -> TraceEventRow: + return self._fetch_one( + "append_trace_event", + INSERT_TRACE_EVENT_SQL, + (trace_id, sequence_no, kind, Jsonb(payload)), + ) + + def list_trace_events(self, trace_id: UUID) -> list[TraceEventRow]: + return self._fetch_all(LIST_TRACE_EVENTS_SQL, (trace_id,)) + + def create_memory( + self, + *, + memory_key: str, + value: JsonValue, + status: str, + source_event_ids: list[str], + ) -> MemoryRow: + return self._fetch_one( + "create_memory", + INSERT_MEMORY_SQL, + (memory_key, Jsonb(value), status, Jsonb(source_event_ids)), + ) + + def get_memory(self, memory_id: UUID) -> MemoryRow: + return self._fetch_one("get_memory", GET_MEMORY_SQL, (memory_id,)) + + def get_memory_optional(self, memory_id: UUID) -> MemoryRow | None: + return self._fetch_optional_one(GET_MEMORY_SQL, 
(memory_id,)) + + def list_memories_by_ids(self, memory_ids: list[UUID]) -> list[MemoryRow]: + if not memory_ids: + return [] + return self._fetch_all(LIST_MEMORIES_BY_IDS_SQL, (memory_ids,)) + + def get_memory_by_key(self, memory_key: str) -> MemoryRow | None: + return self._fetch_optional_one(GET_MEMORY_BY_KEY_SQL, (memory_key,)) + + def list_memories(self) -> list[MemoryRow]: + return self._fetch_all(LIST_MEMORIES_SQL) + + def count_memories(self, *, status: str | None = None) -> int: + if status is None: + return self._fetch_count(COUNT_MEMORIES_SQL) + return self._fetch_count(COUNT_MEMORIES_BY_STATUS_SQL, (status,)) + + def count_unlabeled_review_memories(self) -> int: + return self._fetch_count(COUNT_UNLABELED_REVIEW_MEMORIES_SQL) + + def list_review_memories(self, *, status: str | None = None, limit: int) -> list[MemoryRow]: + if status is None: + return self._fetch_all(LIST_REVIEW_MEMORIES_SQL, (limit,)) + return self._fetch_all(LIST_REVIEW_MEMORIES_BY_STATUS_SQL, (status, limit)) + + def list_unlabeled_review_memories(self, *, limit: int) -> list[MemoryRow]: + return self._fetch_all(LIST_UNLABELED_REVIEW_MEMORIES_SQL, (limit,)) + + def list_context_memories(self) -> list[MemoryRow]: + return self._fetch_all(LIST_CONTEXT_MEMORIES_SQL) + + def update_memory( + self, + *, + memory_id: UUID, + value: JsonValue, + status: str, + source_event_ids: list[str], + ) -> MemoryRow: + return self._fetch_one( + "update_memory", + UPDATE_MEMORY_SQL, + (Jsonb(value), status, Jsonb(source_event_ids), status, memory_id), + ) + + def append_memory_revision( + self, + *, + memory_id: UUID, + action: str, + memory_key: str, + previous_value: JsonValue | None, + new_value: JsonValue | None, + source_event_ids: list[str], + candidate: JsonObject, + ) -> MemoryRevisionRow: + with self.conn.cursor() as cur: + cur.execute(LOCK_MEMORY_REVISIONS_SQL, (str(memory_id),)) + cur.execute( + INSERT_MEMORY_REVISION_SQL, + ( + memory_id, + memory_id, + action, + memory_key, + Jsonb(previous_value), + Jsonb(new_value), + Jsonb(source_event_ids), + Jsonb(candidate), + ), + ) + row = cur.fetchone() + + if row is None: + raise ContinuityStoreInvariantError( + "append_memory_revision did not return a row from the database", + ) + + return cast(MemoryRevisionRow, row) + + def count_memory_revisions(self, memory_id: UUID) -> int: + return self._fetch_count(COUNT_MEMORY_REVISIONS_SQL, (memory_id,)) + + def list_memory_revisions( + self, + memory_id: UUID, + *, + limit: int | None = None, + ) -> list[MemoryRevisionRow]: + if limit is None: + return self._fetch_all(LIST_MEMORY_REVISIONS_SQL, (memory_id,)) + return self._fetch_all(LIST_LIMITED_MEMORY_REVISIONS_SQL, (memory_id, limit)) + + def create_memory_review_label( + self, + *, + memory_id: UUID, + label: str, + note: str | None, + ) -> MemoryReviewLabelRow: + return self._fetch_one( + "create_memory_review_label", + INSERT_MEMORY_REVIEW_LABEL_SQL, + (memory_id, label, note), + ) + + def list_memory_review_labels(self, memory_id: UUID) -> list[MemoryReviewLabelRow]: + return self._fetch_all(LIST_MEMORY_REVIEW_LABELS_SQL, (memory_id,)) + + def list_memory_review_label_counts(self, memory_id: UUID) -> list[LabelCountRow]: + return self._fetch_all(LIST_MEMORY_REVIEW_LABEL_COUNTS_SQL, (memory_id,)) + + def count_labeled_memories(self) -> int: + return self._fetch_count(COUNT_LABELED_MEMORIES_SQL) + + def count_unlabeled_memories(self) -> int: + return self._fetch_count(COUNT_UNLABELED_MEMORIES_SQL) + + def list_all_memory_review_label_counts(self) -> list[LabelCountRow]: + 
return self._fetch_all(LIST_ALL_MEMORY_REVIEW_LABEL_COUNTS_SQL) + + def create_embedding_config( + self, + *, + provider: str, + model: str, + version: str, + dimensions: int, + status: str, + metadata: JsonObject, + ) -> EmbeddingConfigRow: + return self._fetch_one( + "create_embedding_config", + INSERT_EMBEDDING_CONFIG_SQL, + (provider, model, version, dimensions, status, Jsonb(metadata)), + ) + + def get_embedding_config_optional(self, embedding_config_id: UUID) -> EmbeddingConfigRow | None: + return self._fetch_optional_one(GET_EMBEDDING_CONFIG_SQL, (embedding_config_id,)) + + def get_embedding_config_by_identity_optional( + self, + *, + provider: str, + model: str, + version: str, + ) -> EmbeddingConfigRow | None: + return self._fetch_optional_one( + GET_EMBEDDING_CONFIG_BY_IDENTITY_SQL, + (provider, model, version), + ) + + def list_embedding_configs(self) -> list[EmbeddingConfigRow]: + return self._fetch_all(LIST_EMBEDDING_CONFIGS_SQL) + + def create_memory_embedding( + self, + *, + memory_id: UUID, + embedding_config_id: UUID, + dimensions: int, + vector: list[float], + ) -> MemoryEmbeddingRow: + return self._fetch_one( + "create_memory_embedding", + INSERT_MEMORY_EMBEDDING_SQL, + (memory_id, embedding_config_id, dimensions, Jsonb(vector)), + ) + + def get_memory_embedding_optional(self, memory_embedding_id: UUID) -> MemoryEmbeddingRow | None: + return self._fetch_optional_one(GET_MEMORY_EMBEDDING_SQL, (memory_embedding_id,)) + + def get_memory_embedding_by_memory_and_config_optional( + self, + *, + memory_id: UUID, + embedding_config_id: UUID, + ) -> MemoryEmbeddingRow | None: + return self._fetch_optional_one( + GET_MEMORY_EMBEDDING_BY_MEMORY_AND_CONFIG_SQL, + (memory_id, embedding_config_id), + ) + + def list_memory_embeddings_for_memory(self, memory_id: UUID) -> list[MemoryEmbeddingRow]: + return self._fetch_all(LIST_MEMORY_EMBEDDINGS_FOR_MEMORY_SQL, (memory_id,)) + + def list_memory_embeddings_for_config( + self, + embedding_config_id: UUID, + ) -> list[MemoryEmbeddingRow]: + return self._fetch_all(LIST_MEMORY_EMBEDDINGS_FOR_CONFIG_SQL, (embedding_config_id,)) + + def update_memory_embedding( + self, + *, + memory_embedding_id: UUID, + dimensions: int, + vector: list[float], + ) -> MemoryEmbeddingRow: + return self._fetch_one( + "update_memory_embedding", + UPDATE_MEMORY_EMBEDDING_SQL, + (dimensions, Jsonb(vector), memory_embedding_id), + ) + + def retrieve_semantic_memory_matches( + self, + *, + embedding_config_id: UUID, + query_vector: list[float], + limit: int, + ) -> list[SemanticMemoryRetrievalRow]: + return self._fetch_all( + RETRIEVE_SEMANTIC_MEMORY_MATCHES_SQL, + ( + self._vector_literal(query_vector), + embedding_config_id, + len(query_vector), + limit, + ), + ) + + def create_entity( + self, + *, + entity_type: str, + name: str, + source_memory_ids: list[str], + ) -> EntityRow: + return self._fetch_one( + "create_entity", + INSERT_ENTITY_SQL, + (entity_type, name, Jsonb(source_memory_ids)), + ) + + def get_entity_optional(self, entity_id: UUID) -> EntityRow | None: + return self._fetch_optional_one(GET_ENTITY_SQL, (entity_id,)) + + def list_entities(self) -> list[EntityRow]: + return self._fetch_all(LIST_ENTITIES_SQL) + + def create_entity_edge( + self, + *, + from_entity_id: UUID, + to_entity_id: UUID, + relationship_type: str, + valid_from: datetime | None, + valid_to: datetime | None, + source_memory_ids: list[str], + ) -> EntityEdgeRow: + return self._fetch_one( + "create_entity_edge", + INSERT_ENTITY_EDGE_SQL, + ( + from_entity_id, + to_entity_id, + 
relationship_type, + valid_from, + valid_to, + Jsonb(source_memory_ids), + ), + ) + + def list_entity_edges_for_entity(self, entity_id: UUID) -> list[EntityEdgeRow]: + return self._fetch_all(LIST_ENTITY_EDGES_FOR_ENTITY_SQL, (entity_id, entity_id)) + + def list_entity_edges_for_entities(self, entity_ids: list[UUID]) -> list[EntityEdgeRow]: + if not entity_ids: + return [] + return self._fetch_all(LIST_ENTITY_EDGES_FOR_ENTITIES_SQL, (entity_ids, entity_ids)) + + def create_consent( + self, + *, + consent_key: str, + status: str, + metadata: JsonObject, + ) -> ConsentRow: + return self._fetch_one( + "create_consent", + INSERT_CONSENT_SQL, + (consent_key, status, Jsonb(metadata)), + ) + + def get_consent_by_key_optional(self, consent_key: str) -> ConsentRow | None: + return self._fetch_optional_one(GET_CONSENT_BY_KEY_SQL, (consent_key,)) + + def list_consents(self) -> list[ConsentRow]: + return self._fetch_all(LIST_CONSENTS_SQL) + + def update_consent( + self, + *, + consent_id: UUID, + status: str, + metadata: JsonObject, + ) -> ConsentRow: + return self._fetch_one( + "update_consent", + UPDATE_CONSENT_SQL, + (status, Jsonb(metadata), consent_id), + ) + + def create_policy( + self, + *, + name: str, + action: str, + scope: str, + effect: str, + priority: int, + active: bool, + conditions: JsonObject, + required_consents: list[str], + ) -> PolicyRow: + return self._fetch_one( + "create_policy", + INSERT_POLICY_SQL, + ( + name, + action, + scope, + effect, + priority, + active, + Jsonb(conditions), + Jsonb(required_consents), + ), + ) + + def get_policy_optional(self, policy_id: UUID) -> PolicyRow | None: + return self._fetch_optional_one(GET_POLICY_SQL, (policy_id,)) + + def list_policies(self) -> list[PolicyRow]: + return self._fetch_all(LIST_POLICIES_SQL) + + def list_active_policies(self) -> list[PolicyRow]: + return self._fetch_all(LIST_ACTIVE_POLICIES_SQL) + + def create_tool( + self, + *, + tool_key: str, + name: str, + description: str, + version: str, + metadata_version: str, + active: bool, + tags: list[str], + action_hints: list[str], + scope_hints: list[str], + domain_hints: list[str], + risk_hints: list[str], + metadata: JsonObject, + ) -> ToolRow: + return self._fetch_one( + "create_tool", + INSERT_TOOL_SQL, + ( + tool_key, + name, + description, + version, + metadata_version, + active, + Jsonb(tags), + Jsonb(action_hints), + Jsonb(scope_hints), + Jsonb(domain_hints), + Jsonb(risk_hints), + Jsonb(metadata), + ), + ) + + def get_tool_optional(self, tool_id: UUID) -> ToolRow | None: + return self._fetch_optional_one(GET_TOOL_SQL, (tool_id,)) + + def list_tools(self) -> list[ToolRow]: + return self._fetch_all(LIST_TOOLS_SQL) + + def list_active_tools(self) -> list[ToolRow]: + return self._fetch_all(LIST_ACTIVE_TOOLS_SQL) + + def create_approval( + self, + *, + thread_id: UUID, + tool_id: UUID, + task_step_id: UUID | None, + status: str, + request: JsonObject, + tool: JsonObject, + routing: JsonObject, + routing_trace_id: UUID, + ) -> ApprovalRow: + return self._fetch_one( + "create_approval", + INSERT_APPROVAL_SQL, + ( + thread_id, + tool_id, + task_step_id, + status, + Jsonb(request), + Jsonb(tool), + Jsonb(routing), + routing_trace_id, + ), + ) + + def get_approval_optional(self, approval_id: UUID) -> ApprovalRow | None: + return self._fetch_optional_one(GET_APPROVAL_SQL, (approval_id,)) + + def list_approvals(self) -> list[ApprovalRow]: + return self._fetch_all(LIST_APPROVALS_SQL) + + def resolve_approval_optional( + self, + *, + approval_id: UUID, + status: str, + ) -> 
ApprovalRow | None: + return self._fetch_optional_one( + UPDATE_APPROVAL_RESOLUTION_SQL, + (status, approval_id), + ) + + def update_approval_task_step_optional( + self, + *, + approval_id: UUID, + task_step_id: UUID, + ) -> ApprovalRow | None: + return self._fetch_optional_one( + UPDATE_APPROVAL_TASK_STEP_SQL, + (task_step_id, approval_id), + ) + + def create_task( + self, + *, + thread_id: UUID, + tool_id: UUID, + status: str, + request: JsonObject, + tool: JsonObject, + latest_approval_id: UUID | None, + latest_execution_id: UUID | None, + ) -> TaskRow: + return self._fetch_one( + "create_task", + INSERT_TASK_SQL, + ( + thread_id, + tool_id, + status, + Jsonb(request), + Jsonb(tool), + latest_approval_id, + latest_execution_id, + ), + ) + + def get_task_optional(self, task_id: UUID) -> TaskRow | None: + return self._fetch_optional_one(GET_TASK_SQL, (task_id,)) + + def get_task_by_approval_optional(self, approval_id: UUID) -> TaskRow | None: + return self._fetch_optional_one(GET_TASK_BY_APPROVAL_SQL, (approval_id,)) + + def list_tasks(self) -> list[TaskRow]: + return self._fetch_all(LIST_TASKS_SQL) + + def update_task_status_by_approval_optional( + self, + *, + approval_id: UUID, + status: str, + ) -> TaskRow | None: + return self._fetch_optional_one( + UPDATE_TASK_STATUS_BY_APPROVAL_SQL, + (status, approval_id), + ) + + def update_task_execution_by_approval_optional( + self, + *, + approval_id: UUID, + latest_execution_id: UUID, + status: str, + ) -> TaskRow | None: + return self._fetch_optional_one( + UPDATE_TASK_EXECUTION_BY_APPROVAL_SQL, + (status, latest_execution_id, approval_id), + ) + + def update_task_status_optional( + self, + *, + task_id: UUID, + status: str, + latest_approval_id: UUID | None, + latest_execution_id: UUID | None, + ) -> TaskRow | None: + return self._fetch_optional_one( + UPDATE_TASK_STATUS_SQL, + (status, latest_approval_id, latest_execution_id, task_id), + ) + + def lock_task_workspaces(self, task_id: UUID) -> None: + with self.conn.cursor() as cur: + cur.execute(LOCK_TASK_WORKSPACES_SQL, (str(task_id),)) + + def create_task_workspace( + self, + *, + task_id: UUID, + status: str, + local_path: str, + ) -> TaskWorkspaceRow: + return self._fetch_one( + "create_task_workspace", + INSERT_TASK_WORKSPACE_SQL, + (task_id, status, local_path), + ) + + def get_task_workspace_optional(self, task_workspace_id: UUID) -> TaskWorkspaceRow | None: + return self._fetch_optional_one(GET_TASK_WORKSPACE_SQL, (task_workspace_id,)) + + def get_active_task_workspace_for_task_optional(self, task_id: UUID) -> TaskWorkspaceRow | None: + return self._fetch_optional_one(GET_ACTIVE_TASK_WORKSPACE_FOR_TASK_SQL, (task_id,)) + + def list_task_workspaces(self) -> list[TaskWorkspaceRow]: + return self._fetch_all(LIST_TASK_WORKSPACES_SQL) + + def lock_task_steps(self, task_id: UUID) -> None: + with self.conn.cursor() as cur: + cur.execute(LOCK_TASK_STEPS_SQL, (str(task_id),)) + + def create_task_step( + self, + *, + task_id: UUID, + sequence_no: int, + parent_step_id: UUID | None = None, + source_approval_id: UUID | None = None, + source_execution_id: UUID | None = None, + kind: str, + status: str, + request: JsonObject, + outcome: JsonObject, + trace_id: UUID, + trace_kind: str, + ) -> TaskStepRow: + with self.conn.cursor() as cur: + cur.execute(LOCK_TASK_STEPS_SQL, (str(task_id),)) + cur.execute( + INSERT_TASK_STEP_SQL, + ( + task_id, + sequence_no, + parent_step_id, + source_approval_id, + source_execution_id, + kind, + status, + Jsonb(request), + Jsonb(outcome), + trace_id, + trace_kind, 
+ ), + ) + row = cur.fetchone() + + if row is None: + raise ContinuityStoreInvariantError( + "create_task_step did not return a row from the database", + ) + + return cast(TaskStepRow, row) + + def get_task_step_optional(self, task_step_id: UUID) -> TaskStepRow | None: + return self._fetch_optional_one(GET_TASK_STEP_SQL, (task_step_id,)) + + def get_task_step_for_task_sequence_optional( + self, + *, + task_id: UUID, + sequence_no: int, + ) -> TaskStepRow | None: + return self._fetch_optional_one( + GET_TASK_STEP_FOR_TASK_SEQUENCE_SQL, + (task_id, sequence_no), + ) + + def list_task_steps_for_task(self, task_id: UUID) -> list[TaskStepRow]: + return self._fetch_all(LIST_TASK_STEPS_FOR_TASK_SQL, (task_id,)) + + def update_task_step_for_task_sequence_optional( + self, + *, + task_id: UUID, + sequence_no: int, + status: str, + outcome: JsonObject, + trace_id: UUID, + trace_kind: str, + ) -> TaskStepRow | None: + return self._fetch_optional_one( + UPDATE_TASK_STEP_FOR_TASK_SEQUENCE_SQL, + ( + status, + Jsonb(outcome), + trace_id, + trace_kind, + task_id, + sequence_no, + ), + ) + + def update_task_step_optional( + self, + *, + task_step_id: UUID, + status: str, + outcome: JsonObject, + trace_id: UUID, + trace_kind: str, + ) -> TaskStepRow | None: + return self._fetch_optional_one( + UPDATE_TASK_STEP_SQL, + ( + status, + Jsonb(outcome), + trace_id, + trace_kind, + task_step_id, + ), + ) + + def create_tool_execution( + self, + *, + approval_id: UUID, + task_step_id: UUID, + thread_id: UUID, + tool_id: UUID, + trace_id: UUID, + request_event_id: UUID | None, + result_event_id: UUID | None, + status: str, + handler_key: str | None, + request: JsonObject, + tool: JsonObject, + result: JsonObject, + ) -> ToolExecutionRow: + return self._fetch_one( + "create_tool_execution", + INSERT_TOOL_EXECUTION_SQL, + ( + approval_id, + task_step_id, + thread_id, + tool_id, + trace_id, + request_event_id, + result_event_id, + status, + handler_key, + Jsonb(request), + Jsonb(tool), + Jsonb(result), + ), + ) + + def get_tool_execution_optional(self, execution_id: UUID) -> ToolExecutionRow | None: + return self._fetch_optional_one(GET_TOOL_EXECUTION_SQL, (execution_id,)) + + def list_tool_executions(self) -> list[ToolExecutionRow]: + return self._fetch_all(LIST_TOOL_EXECUTIONS_SQL) + + def create_execution_budget( + self, + *, + budget_id: UUID | None = None, + tool_key: str | None, + domain_hint: str | None, + max_completed_executions: int, + rolling_window_seconds: int | None = None, + supersedes_budget_id: UUID | None = None, + ) -> ExecutionBudgetRow: + return self._fetch_one( + "create_execution_budget", + INSERT_EXECUTION_BUDGET_SQL, + ( + budget_id, + tool_key, + domain_hint, + max_completed_executions, + rolling_window_seconds, + supersedes_budget_id, + ), + ) + + def get_execution_budget_optional(self, execution_budget_id: UUID) -> ExecutionBudgetRow | None: + return self._fetch_optional_one(GET_EXECUTION_BUDGET_SQL, (execution_budget_id,)) + + def list_execution_budgets(self) -> list[ExecutionBudgetRow]: + return self._fetch_all(LIST_EXECUTION_BUDGETS_SQL) + + def deactivate_execution_budget_optional( + self, + execution_budget_id: UUID, + ) -> ExecutionBudgetRow | None: + return self._fetch_optional_one(DEACTIVATE_EXECUTION_BUDGET_SQL, (execution_budget_id,)) + + def supersede_execution_budget_optional( + self, + *, + execution_budget_id: UUID, + superseded_by_budget_id: UUID, + ) -> ExecutionBudgetRow | None: + return self._fetch_optional_one( + SUPERSEDE_EXECUTION_BUDGET_SQL, + ( + 
superseded_by_budget_id, + execution_budget_id, + ), + ) + + def update_event(self, *_args: Any, **_kwargs: Any) -> None: + raise AppendOnlyViolation(UPDATE_EVENT_ERROR) + + def delete_event(self, *_args: Any, **_kwargs: Any) -> None: + raise AppendOnlyViolation(DELETE_EVENT_ERROR) + + def update_trace_event(self, *_args: Any, **_kwargs: Any) -> None: + raise AppendOnlyViolation(UPDATE_TRACE_EVENT_ERROR) + + def delete_trace_event(self, *_args: Any, **_kwargs: Any) -> None: + raise AppendOnlyViolation(DELETE_TRACE_EVENT_ERROR) diff --git a/apps/api/src/alicebot_api/tasks.py b/apps/api/src/alicebot_api/tasks.py new file mode 100644 index 0000000..da88e5f --- /dev/null +++ b/apps/api/src/alicebot_api/tasks.py @@ -0,0 +1,1170 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import cast +from uuid import UUID + +import psycopg + +from alicebot_api.contracts import ( + TASK_LIST_ORDER, + TASK_STEP_CONTINUATION_VERSION_V0, + TASK_STEP_LIST_ORDER, + TASK_STEP_TRANSITION_VERSION_V0, + TRACE_KIND_TASK_STEP_CONTINUATION, + TRACE_KIND_TASK_STEP_TRANSITION, + TaskCreateInput, + TaskCreateResponse, + TaskDetailResponse, + TaskLifecycleSource, + TaskLifecycleStateTracePayload, + TaskLifecycleSummaryTracePayload, + TaskListResponse, + TaskListSummary, + TaskRecord, + TaskStatus, + TaskStepCreateInput, + TaskStepCreateResponse, + TaskStepDetailResponse, + TaskStepContinuationLineageTracePayload, + TaskStepContinuationRequestTracePayload, + TaskStepContinuationSummaryTracePayload, + TaskStepLifecycleStateTracePayload, + TaskStepLifecycleSummaryTracePayload, + TaskStepLineageRecord, + TaskStepListSummary, + TaskStepListResponse, + TaskStepMutationTraceSummary, + TaskStepNextCreateInput, + TaskStepNextCreateResponse, + TaskStepOutcomeSnapshot, + TaskStepRecord, + TaskStepStatus, + TaskStepTransitionInput, + TaskStepTransitionRequestTracePayload, + TaskStepTransitionResponse, + TaskStepTransitionStateTracePayload, + TaskStepTransitionSummaryTracePayload, +) +from alicebot_api.store import ( + ContinuityStore, + ContinuityStoreInvariantError, + TaskRow, + TaskStepRow, + ToolExecutionRow, +) + +TASK_LIFECYCLE_STATE_EVENT_KIND = "task.lifecycle.state" +TASK_LIFECYCLE_SUMMARY_EVENT_KIND = "task.lifecycle.summary" +TASK_STEP_LIFECYCLE_STATE_EVENT_KIND = "task.step.lifecycle.state" +TASK_STEP_LIFECYCLE_SUMMARY_EVENT_KIND = "task.step.lifecycle.summary" +TASK_STEP_CONTINUATION_REQUEST_EVENT_KIND = "task.step.continuation.request" +TASK_STEP_CONTINUATION_LINEAGE_EVENT_KIND = "task.step.continuation.lineage" +TASK_STEP_CONTINUATION_SUMMARY_EVENT_KIND = "task.step.continuation.summary" +TASK_STEP_TRANSITION_REQUEST_EVENT_KIND = "task.step.transition.request" +TASK_STEP_TRANSITION_STATE_EVENT_KIND = "task.step.transition.state" +TASK_STEP_TRANSITION_SUMMARY_EVENT_KIND = "task.step.transition.summary" +DEFAULT_TASK_STEP_SEQUENCE_NO = 1 +DEFAULT_TASK_STEP_KIND = "governed_request" +TASK_STEP_APPENDABLE_STATUSES = frozenset({"executed", "blocked", "denied"}) +TASK_STEP_INITIAL_STATUSES = frozenset({"created", "approved", "denied"}) +TASK_STEP_STATUS_GRAPH: dict[TaskStepStatus, tuple[TaskStepStatus, ...]] = { + "created": ("approved", "denied"), + "approved": ("executed", "blocked"), + "executed": (), + "blocked": (), + "denied": (), +} + + +class TaskNotFoundError(LookupError): + """Raised when a task record is not visible inside the current user scope.""" + + +class TaskStepNotFoundError(LookupError): + """Raised when a task-step record is not visible inside the current user 
scope.""" + + +class TaskStepSequenceError(RuntimeError): + """Raised when a task-step append request violates deterministic sequencing rules.""" + + +class TaskStepTransitionError(RuntimeError): + """Raised when a task-step transition request violates the explicit status graph.""" + + +class TaskStepLifecycleBoundaryError(RuntimeError): + """Raised when first-step-only lifecycle helpers are routed a later-step context.""" + + +class TaskStepApprovalLinkageError(RuntimeError): + """Raised when approval resolution cannot validate its linked task step.""" + + +class TaskStepExecutionLinkageError(RuntimeError): + """Raised when execution synchronization cannot validate its linked task step.""" + + +@dataclass(frozen=True, slots=True) +class TaskTransitionResult: + task: TaskRecord + previous_status: TaskStatus | None + + +@dataclass(frozen=True, slots=True) +class TaskStepTransitionResult: + task_step: TaskStepRecord + previous_status: TaskStepStatus | None + + +def _append_trace_events( + store: ContinuityStore, + *, + trace_id: UUID, + trace_events: list[tuple[str, dict[str, object]]], +) -> None: + for sequence_no, (kind, payload) in enumerate(trace_events, start=1): + store.append_trace_event( + trace_id=trace_id, + sequence_no=sequence_no, + kind=kind, + payload=payload, + ) + + +def _trace_summary( + trace_id: UUID, + trace_events: list[tuple[str, dict[str, object]]], +) -> TaskStepMutationTraceSummary: + return { + "trace_id": str(trace_id), + "trace_event_count": len(trace_events), + } + + +def validate_linked_task_step_for_approval( + store: ContinuityStore, + *, + approval_id: UUID, + task_step_id: UUID | None, +) -> tuple[TaskRow, TaskStepRow]: + if task_step_id is None: + raise TaskStepApprovalLinkageError(f"approval {approval_id} is missing linked task_step_id") + + unlocked_task = store.get_task_by_approval_optional(approval_id) + if unlocked_task is None: + raise TaskStepApprovalLinkageError(f"approval {approval_id} is not linked to a visible task") + store.lock_task_steps(cast(UUID, unlocked_task["id"])) + + task = store.get_task_optional(cast(UUID, unlocked_task["id"])) + if task is None: + raise ContinuityStoreInvariantError( + f"task for approval {approval_id} disappeared during approval linkage validation" + ) + + task_step = store.get_task_step_optional(task_step_id) + if task_step is None: + raise TaskStepApprovalLinkageError( + f"approval {approval_id} references linked task step {task_step_id} that was not found" + ) + if task_step["task_id"] != task["id"]: + raise TaskStepApprovalLinkageError( + f"approval {approval_id} links task step {task_step_id} outside task {task['id']}" + ) + + outcome = cast(TaskStepOutcomeSnapshot, task_step["outcome"]) + if outcome["approval_id"] != str(approval_id): + raise TaskStepApprovalLinkageError( + f"approval {approval_id} is inconsistent with linked task step {task_step_id}" + ) + + return task, task_step + + +def validate_linked_task_step_for_execution( + store: ContinuityStore, + *, + task_id: UUID, + execution: ToolExecutionRow, +) -> TaskStepRow: + store.lock_task_steps(task_id) + + execution_id = cast(UUID, execution["id"]) + task_step_id = cast(UUID | None, execution["task_step_id"]) + if task_step_id is None: + raise TaskStepExecutionLinkageError( + f"tool execution {execution_id} is missing linked task_step_id" + ) + + task_step = store.get_task_step_optional(task_step_id) + if task_step is None: + raise TaskStepExecutionLinkageError( + f"tool execution {execution_id} references linked task step {task_step_id} that was not 
found" + ) + if task_step["task_id"] != task_id: + raise TaskStepExecutionLinkageError( + f"tool execution {execution_id} links task step {task_step_id} outside task {task_id}" + ) + + outcome = cast(TaskStepOutcomeSnapshot, task_step["outcome"]) + if outcome["approval_id"] != str(execution["approval_id"]): + raise TaskStepExecutionLinkageError( + f"tool execution {execution_id} is inconsistent with linked task step {task_step_id}" + ) + + return task_step + + +def serialize_task_row(row: TaskRow) -> TaskRecord: + return { + "id": str(row["id"]), + "thread_id": str(row["thread_id"]), + "tool_id": str(row["tool_id"]), + "status": cast(TaskStatus, row["status"]), + "request": cast(dict[str, object], row["request"]), + "tool": cast(dict[str, object], row["tool"]), + "latest_approval_id": None if row["latest_approval_id"] is None else str(row["latest_approval_id"]), + "latest_execution_id": None if row["latest_execution_id"] is None else str(row["latest_execution_id"]), + "created_at": row["created_at"].isoformat(), + "updated_at": row["updated_at"].isoformat(), + } + + +def serialize_task_step_row(row: TaskStepRow) -> TaskStepRecord: + return { + "id": str(row["id"]), + "task_id": str(row["task_id"]), + "sequence_no": row["sequence_no"], + "lineage": { + "parent_step_id": None if row["parent_step_id"] is None else str(row["parent_step_id"]), + "source_approval_id": ( + None if row["source_approval_id"] is None else str(row["source_approval_id"]) + ), + "source_execution_id": ( + None if row["source_execution_id"] is None else str(row["source_execution_id"]) + ), + }, + "kind": cast(str, row["kind"]), + "status": cast(TaskStepStatus, row["status"]), + "request": cast(dict[str, object], row["request"]), + "outcome": cast(TaskStepOutcomeSnapshot, row["outcome"]), + "trace": { + "trace_id": str(row["trace_id"]), + "trace_kind": row["trace_kind"], + }, + "created_at": row["created_at"].isoformat(), + "updated_at": row["updated_at"].isoformat(), + } + + +def task_status_for_routing_decision(decision: str) -> TaskStatus: + return { + "approval_required": "pending_approval", + "ready": "approved", + "denied": "denied", + }[decision] + + +def task_status_for_approval_status(approval_status: str) -> TaskStatus: + return { + "pending": "pending_approval", + "approved": "approved", + "rejected": "denied", + }[approval_status] + + +def next_task_status_for_approval( + *, + current_status: TaskStatus, + approval_status: str, +) -> TaskStatus: + if current_status in {"executed", "blocked"}: + return current_status + return task_status_for_approval_status(approval_status) + + +def task_status_for_execution_status(execution_status: str) -> TaskStatus: + return { + "completed": "executed", + "blocked": "blocked", + }[execution_status] + + +def task_status_for_step_status(step_status: TaskStepStatus) -> TaskStatus: + return { + "created": "pending_approval", + "approved": "approved", + "executed": "executed", + "blocked": "blocked", + "denied": "denied", + }[step_status] + + +def task_step_status_for_routing_decision(decision: str) -> TaskStepStatus: + return { + "approval_required": "created", + "ready": "approved", + "denied": "denied", + }[decision] + + +def task_step_status_for_approval_status(approval_status: str) -> TaskStepStatus: + return { + "pending": "created", + "approved": "approved", + "rejected": "denied", + }[approval_status] + + +def next_task_step_status_for_approval( + *, + current_status: TaskStepStatus, + approval_status: str, +) -> TaskStepStatus: + if current_status in {"executed", 
"blocked"}: + return current_status + return task_step_status_for_approval_status(approval_status) + + +def task_step_status_for_execution_status(execution_status: str) -> TaskStepStatus: + return { + "completed": "executed", + "blocked": "blocked", + }[execution_status] + + +def allowed_task_step_transitions(current_status: TaskStepStatus) -> list[TaskStepStatus]: + return list(TASK_STEP_STATUS_GRAPH[current_status]) + + +def task_step_outcome_snapshot( + *, + routing_decision: str, + approval_id: str | None, + approval_status: str | None, + execution_id: str | None, + execution_status: str | None, + blocked_reason: str | None, +) -> TaskStepOutcomeSnapshot: + return { + "routing_decision": cast(str, routing_decision), + "approval_id": approval_id, + "approval_status": cast(str | None, approval_status), + "execution_id": execution_id, + "execution_status": cast(str | None, execution_status), + "blocked_reason": blocked_reason, + } + + +def create_task_for_governed_request( + store: ContinuityStore, + *, + request: TaskCreateInput, +) -> TaskCreateResponse: + task = store.create_task( + thread_id=request.thread_id, + tool_id=request.tool_id, + status=request.status, + request=cast(dict[str, object], request.request), + tool=cast(dict[str, object], request.tool), + latest_approval_id=request.latest_approval_id, + latest_execution_id=request.latest_execution_id, + ) + return {"task": serialize_task_row(task)} + + +def create_task_step_for_governed_request( + store: ContinuityStore, + *, + request: TaskStepCreateInput, +) -> TaskStepCreateResponse: + task_step = store.create_task_step( + task_id=request.task_id, + sequence_no=request.sequence_no, + kind=request.kind, + status=request.status, + request=cast(dict[str, object], request.request), + outcome=cast(dict[str, object], request.outcome), + trace_id=request.trace_id, + trace_kind=request.trace_kind, + ) + return {"task_step": serialize_task_step_row(task_step)} + + +def _task_step_sequencing_summary( + *, + task_id: str, + items: list[TaskStepRecord], +) -> TaskStepListSummary: + latest = items[-1] if items else None + latest_status = None if latest is None else latest["status"] + latest_sequence_no = None if latest is None else latest["sequence_no"] + return { + "task_id": task_id, + "total_count": len(items), + "latest_sequence_no": latest_sequence_no, + "latest_status": latest_status, + "next_sequence_no": 1 if latest_sequence_no is None else latest_sequence_no + 1, + "append_allowed": latest_status in TASK_STEP_APPENDABLE_STATUSES if latest_status is not None else False, + "order": list(TASK_STEP_LIST_ORDER), + } + + +def _validated_optional_approval_id( + store: ContinuityStore, + *, + approval_id: str | None, + current_approval_id: UUID | None, + task: TaskRow, + require_existing: bool, + missing_error: str, + error_cls: type[TaskStepSequenceError] | type[TaskStepTransitionError], +) -> UUID | None: + def _approval_belongs_to_task(approval_uuid: UUID) -> bool: + if current_approval_id == approval_uuid: + return True + for task_step in store.list_task_steps_for_task(task["id"]): + outcome = cast(dict[str, object], task_step["outcome"]) + linked_approval_id = outcome.get("approval_id") + if linked_approval_id is not None and str(linked_approval_id) == str(approval_uuid): + return True + return False + + if approval_id is None: + if require_existing and current_approval_id is None: + raise error_cls(missing_error) + approval_uuid = current_approval_id + else: + approval_uuid = UUID(approval_id) + if not 
_approval_belongs_to_task(approval_uuid): + raise error_cls(f"approval {approval_uuid} does not belong to task {task['id']}") + + if approval_uuid is None: + return None + + approval_row = store.get_approval_optional(approval_uuid) + if approval_row is None: + raise error_cls(f"approval {approval_uuid} was not found") + return approval_uuid + + +def _validated_optional_execution_id( + store: ContinuityStore, + *, + execution_id: str | None, + current_execution_id: UUID | None, + task: TaskRow, + require_existing: bool, + missing_error: str, + error_cls: type[TaskStepSequenceError] | type[TaskStepTransitionError], +) -> UUID | None: + def _execution_belongs_to_task(execution_uuid: UUID) -> bool: + if current_execution_id == execution_uuid: + return True + for task_step in store.list_task_steps_for_task(task["id"]): + outcome = cast(dict[str, object], task_step["outcome"]) + linked_execution_id = outcome.get("execution_id") + if linked_execution_id is not None and str(linked_execution_id) == str(execution_uuid): + return True + return False + + if execution_id is None: + if require_existing and current_execution_id is None: + raise error_cls(missing_error) + execution_uuid = current_execution_id + else: + execution_uuid = UUID(execution_id) + if not _execution_belongs_to_task(execution_uuid): + raise error_cls(f"tool execution {execution_uuid} does not belong to task {task['id']}") + + if execution_uuid is None: + return None + + execution_row = store.get_tool_execution_optional(execution_uuid) + if execution_row is None: + raise error_cls(f"tool execution {execution_uuid} was not found") + return execution_uuid + + +def _validated_continuation_parent_step( + *, + task_id: UUID, + latest: TaskStepRecord, + existing_items: list[TaskStepRecord], + parent_step_id: UUID, +) -> TaskStepRecord: + parent_step = next( + ( + item + for item in existing_items + if item["id"] == str(parent_step_id) + ), + None, + ) + if parent_step is None: + raise TaskStepSequenceError(f"task step {parent_step_id} does not belong to task {task_id}") + if parent_step["id"] != latest["id"]: + raise TaskStepSequenceError( + f"task {task_id} continuation must reference latest step {latest['id']}; received {parent_step_id}" + ) + return parent_step + + +def _validated_continuation_lineage( + *, + parent_step: TaskStepRecord, + source_approval_id: UUID | None, + source_execution_id: UUID | None, +) -> TaskStepLineageRecord: + parent_outcome = parent_step["outcome"] + if source_approval_id is not None and parent_outcome["approval_id"] != str(source_approval_id): + raise TaskStepSequenceError( + f"approval {source_approval_id} is not linked from parent step {parent_step['id']}" + ) + if source_execution_id is not None and parent_outcome["execution_id"] != str(source_execution_id): + raise TaskStepSequenceError( + f"tool execution {source_execution_id} is not linked from parent step {parent_step['id']}" + ) + + return { + "parent_step_id": parent_step["id"], + "source_approval_id": None if source_approval_id is None else str(source_approval_id), + "source_execution_id": None if source_execution_id is None else str(source_execution_id), + } + + +def sync_task_with_task_step_status( + store: ContinuityStore, + *, + task_id: UUID, + task_step_status: TaskStepStatus, + linked_approval_id: UUID | None, + linked_execution_id: UUID | None, +) -> TaskTransitionResult: + current = store.get_task_optional(task_id) + if current is None: + raise ContinuityStoreInvariantError( + f"task {task_id} disappeared before task-step lifecycle 
synchronization" + ) + previous_status = cast(TaskStatus, current["status"]) + target_status = task_status_for_step_status(task_step_status) + latest_execution_id = ( + current["latest_execution_id"] if linked_execution_id is None else linked_execution_id + ) if target_status in {"executed", "blocked"} else None + updated = store.update_task_status_optional( + task_id=task_id, + status=target_status, + latest_approval_id=linked_approval_id, + latest_execution_id=latest_execution_id, + ) + if updated is None: + raise ContinuityStoreInvariantError( + f"task {task_id} disappeared during task-step lifecycle synchronization" + ) + return TaskTransitionResult( + task=serialize_task_row(updated), + previous_status=previous_status, + ) + + +def sync_task_with_approval( + store: ContinuityStore, + *, + approval_id: UUID, + approval_status: str, +) -> TaskTransitionResult: + current = store.get_task_by_approval_optional(approval_id) + if current is None: + raise ContinuityStoreInvariantError( + f"task for approval {approval_id} disappeared before lifecycle synchronization" + ) + previous_status = cast(TaskStatus, current["status"]) + + updated = store.update_task_status_by_approval_optional( + approval_id=approval_id, + status=next_task_status_for_approval( + current_status=previous_status, + approval_status=approval_status, + ), + ) + if updated is None: + raise ContinuityStoreInvariantError( + f"task for approval {approval_id} disappeared during lifecycle synchronization" + ) + + return TaskTransitionResult( + task=serialize_task_row(updated), + previous_status=previous_status, + ) + + +def sync_task_step_with_approval( + store: ContinuityStore, + *, + approval_id: UUID, + task_step_id: UUID | None, + approval_status: str, + trace_id: UUID, + trace_kind: str, +) -> TaskStepTransitionResult: + _, current = validate_linked_task_step_for_approval( + store, + approval_id=approval_id, + task_step_id=task_step_id, + ) + previous_status = cast(TaskStepStatus, current["status"]) + current_outcome = cast(TaskStepOutcomeSnapshot, current["outcome"]) + updated_outcome = task_step_outcome_snapshot( + routing_decision=current_outcome["routing_decision"], + approval_id=str(approval_id), + approval_status=approval_status, + execution_id=current_outcome["execution_id"], + execution_status=current_outcome["execution_status"], + blocked_reason=current_outcome["blocked_reason"], + ) + + updated = store.update_task_step_optional( + task_step_id=cast(UUID, current["id"]), + status=next_task_step_status_for_approval( + current_status=previous_status, + approval_status=approval_status, + ), + outcome=cast(dict[str, object], updated_outcome), + trace_id=trace_id, + trace_kind=trace_kind, + ) + if updated is None: + raise ContinuityStoreInvariantError( + f"linked task step {current['id']} disappeared during approval lifecycle synchronization" + ) + + return TaskStepTransitionResult( + task_step=serialize_task_step_row(updated), + previous_status=previous_status, + ) + + +def sync_task_with_execution( + store: ContinuityStore, + *, + approval_id: UUID, + execution_id: UUID, + execution_status: str, +) -> TaskTransitionResult: + current = store.get_task_by_approval_optional(approval_id) + if current is None: + raise ContinuityStoreInvariantError( + f"task for approval {approval_id} disappeared before execution synchronization" + ) + previous_status = cast(TaskStatus, current["status"]) + + updated = store.update_task_execution_by_approval_optional( + approval_id=approval_id, + latest_execution_id=execution_id, + 
status=task_status_for_execution_status(execution_status), + ) + if updated is None: + raise ContinuityStoreInvariantError( + f"task for approval {approval_id} disappeared during execution synchronization" + ) + + return TaskTransitionResult( + task=serialize_task_row(updated), + previous_status=previous_status, + ) + + +def sync_task_step_with_execution( + store: ContinuityStore, + *, + task_id: UUID, + execution: ToolExecutionRow, + trace_id: UUID, + trace_kind: str, +) -> TaskStepTransitionResult: + current = validate_linked_task_step_for_execution( + store, + task_id=task_id, + execution=execution, + ) + previous_status = cast(TaskStepStatus, current["status"]) + current_outcome = cast(TaskStepOutcomeSnapshot, current["outcome"]) + execution_result = cast(dict[str, object], execution["result"]) + updated_outcome = task_step_outcome_snapshot( + routing_decision=current_outcome["routing_decision"], + approval_id=current_outcome["approval_id"], + approval_status=current_outcome["approval_status"], + execution_id=str(execution["id"]), + execution_status=cast(str, execution["status"]), + blocked_reason=cast(str | None, execution_result.get("reason")), + ) + + updated = store.update_task_step_optional( + task_step_id=cast(UUID, current["id"]), + status=task_step_status_for_execution_status(cast(str, execution["status"])), + outcome=cast(dict[str, object], updated_outcome), + trace_id=trace_id, + trace_kind=trace_kind, + ) + if updated is None: + raise ContinuityStoreInvariantError( + f"linked task step {current['id']} disappeared during execution lifecycle synchronization" + ) + + return TaskStepTransitionResult( + task_step=serialize_task_step_row(updated), + previous_status=previous_status, + ) + + +def create_next_task_step_record( + store: ContinuityStore, + *, + user_id: UUID, + request: TaskStepNextCreateInput, +) -> TaskStepNextCreateResponse: + del user_id + + task_row = store.get_task_optional(request.task_id) + if task_row is None: + raise TaskNotFoundError(f"task {request.task_id} was not found") + + store.lock_task_steps(request.task_id) + existing_items = [serialize_task_step_row(row) for row in store.list_task_steps_for_task(request.task_id)] + if not existing_items: + raise TaskStepSequenceError(f"task {request.task_id} has no existing steps and cannot append a next step") + + latest = existing_items[-1] + if latest["status"] not in TASK_STEP_APPENDABLE_STATUSES: + raise TaskStepSequenceError( + f"task {request.task_id} latest step {latest['id']} is {latest['status']} and cannot append a next step" + ) + if request.status not in TASK_STEP_INITIAL_STATUSES: + allowed = ", ".join(sorted(TASK_STEP_INITIAL_STATUSES)) + raise TaskStepSequenceError( + f"new task step for task {request.task_id} must start in one of {allowed}; received {request.status}" + ) + parent_step = _validated_continuation_parent_step( + task_id=request.task_id, + latest=latest, + existing_items=existing_items, + parent_step_id=request.lineage.parent_step_id, + ) + source_approval_id = _validated_optional_approval_id( + store, + approval_id=( + None if request.lineage.source_approval_id is None else str(request.lineage.source_approval_id) + ), + current_approval_id=None, + task=task_row, + require_existing=False, + missing_error="", + error_cls=TaskStepSequenceError, + ) + source_execution_id = _validated_optional_execution_id( + store, + execution_id=( + None if request.lineage.source_execution_id is None else str(request.lineage.source_execution_id) + ), + current_execution_id=None, + task=task_row, + 
require_existing=False, + missing_error="", + error_cls=TaskStepSequenceError, + ) + lineage = _validated_continuation_lineage( + parent_step=parent_step, + source_approval_id=source_approval_id, + source_execution_id=source_execution_id, + ) + linked_approval_id = _validated_optional_approval_id( + store, + approval_id=request.outcome["approval_id"], + current_approval_id=None, + task=task_row, + require_existing=False, + missing_error="", + error_cls=TaskStepSequenceError, + ) + linked_execution_id = _validated_optional_execution_id( + store, + execution_id=request.outcome["execution_id"], + current_execution_id=None, + task=task_row, + require_existing=False, + missing_error="", + error_cls=TaskStepSequenceError, + ) + + trace = store.create_trace( + user_id=task_row["user_id"], + thread_id=task_row["thread_id"], + kind=TRACE_KIND_TASK_STEP_CONTINUATION, + compiler_version=TASK_STEP_CONTINUATION_VERSION_V0, + status="completed", + limits={ + "order": list(TASK_STEP_LIST_ORDER), + "appendable_statuses": sorted(TASK_STEP_APPENDABLE_STATUSES), + "initial_statuses": sorted(TASK_STEP_INITIAL_STATUSES), + "parent_step_id": parent_step["id"], + "parent_sequence_no": parent_step["sequence_no"], + }, + ) + try: + created = store.create_task_step( + task_id=request.task_id, + sequence_no=latest["sequence_no"] + 1, + parent_step_id=request.lineage.parent_step_id, + source_approval_id=source_approval_id, + source_execution_id=source_execution_id, + kind=request.kind, + status=request.status, + request=cast(dict[str, object], request.request), + outcome=cast(dict[str, object], request.outcome), + trace_id=trace["id"], + trace_kind=TRACE_KIND_TASK_STEP_CONTINUATION, + ) + except psycopg.IntegrityError as exc: + raise TaskStepSequenceError( + f"task {request.task_id} next-step creation conflicted with a concurrent append" + ) from exc + task_step = serialize_task_step_row(created) + task_transition = sync_task_with_task_step_status( + store, + task_id=request.task_id, + task_step_status=request.status, + linked_approval_id=( + source_approval_id if request.status == "created" and linked_approval_id is None else linked_approval_id + ), + linked_execution_id=linked_execution_id, + ) + updated_items = [*existing_items, task_step] + sequencing = _task_step_sequencing_summary( + task_id=str(task_row["id"]), + items=updated_items, + ) + + request_payload: TaskStepContinuationRequestTracePayload = { + "task_id": str(task_row["id"]), + "parent_task_step_id": parent_step["id"], + "parent_sequence_no": parent_step["sequence_no"], + "parent_status": parent_step["status"], + "requested_kind": request.kind, + "requested_status": request.status, + "requested_source_approval_id": lineage["source_approval_id"], + "requested_source_execution_id": lineage["source_execution_id"], + } + lineage_payload: TaskStepContinuationLineageTracePayload = { + "task_id": str(task_row["id"]), + "parent_task_step_id": parent_step["id"], + "parent_sequence_no": parent_step["sequence_no"], + "parent_status": parent_step["status"], + "source_approval_id": lineage["source_approval_id"], + "source_execution_id": lineage["source_execution_id"], + } + summary_payload: TaskStepContinuationSummaryTracePayload = { + "task_id": str(task_row["id"]), + "task_step_id": task_step["id"], + "latest_sequence_no": task_step["sequence_no"], + "next_sequence_no": sequencing["next_sequence_no"], + "append_allowed": sequencing["append_allowed"], + "lineage": task_step["lineage"], + } + trace_events: list[tuple[str, dict[str, object]]] = [ + 
(TASK_STEP_CONTINUATION_REQUEST_EVENT_KIND, cast(dict[str, object], request_payload)), + (TASK_STEP_CONTINUATION_LINEAGE_EVENT_KIND, cast(dict[str, object], lineage_payload)), + (TASK_STEP_CONTINUATION_SUMMARY_EVENT_KIND, cast(dict[str, object], summary_payload)), + ] + trace_events.extend( + task_lifecycle_trace_events( + task=task_transition.task, + previous_status=task_transition.previous_status, + source="task_step_continuation", + ) + ) + trace_events.extend( + task_step_lifecycle_trace_events( + task_step=task_step, + previous_status=None, + source="task_step_continuation", + ) + ) + _append_trace_events(store, trace_id=trace["id"], trace_events=trace_events) + + return { + "task": task_transition.task, + "task_step": task_step, + "sequencing": sequencing, + "trace": _trace_summary(trace["id"], trace_events), + } + + +def transition_task_step_record( + store: ContinuityStore, + *, + user_id: UUID, + request: TaskStepTransitionInput, +) -> TaskStepTransitionResponse: + del user_id + + step_row = store.get_task_step_optional(request.task_step_id) + if step_row is None: + raise TaskStepNotFoundError(f"task step {request.task_step_id} was not found") + + task_row = store.get_task_optional(step_row["task_id"]) + if task_row is None: + raise ContinuityStoreInvariantError( + f"task {step_row['task_id']} disappeared before task-step transition" + ) + + existing_items = [serialize_task_step_row(row) for row in store.list_task_steps_for_task(step_row["task_id"])] + latest = existing_items[-1] if existing_items else None + if latest is None: + raise ContinuityStoreInvariantError( + f"task {step_row['task_id']} has no visible steps during transition" + ) + if latest["id"] != str(step_row["id"]): + raise TaskStepTransitionError( + f"task step {request.task_step_id} is not the latest step on task {step_row['task_id']}" + ) + + previous_status = cast(TaskStepStatus, step_row["status"]) + allowed_next_statuses = allowed_task_step_transitions(previous_status) + if request.status not in allowed_next_statuses: + allowed = ", ".join(allowed_next_statuses) or "no further statuses" + raise TaskStepTransitionError( + f"task step {request.task_step_id} is {previous_status} and cannot transition to {request.status}; allowed: {allowed}" + ) + linked_approval_id = _validated_optional_approval_id( + store, + approval_id=request.outcome["approval_id"], + current_approval_id=task_row["latest_approval_id"], + task=task_row, + require_existing=request.status == "created", + missing_error=f"task {task_row['id']} cannot reflect created without an approval link", + error_cls=TaskStepTransitionError, + ) + linked_execution_id = _validated_optional_execution_id( + store, + execution_id=request.outcome["execution_id"], + current_execution_id=task_row["latest_execution_id"], + task=task_row, + require_existing=request.status in {"executed", "blocked"}, + missing_error=f"task {task_row['id']} cannot reflect {request.status} without an existing execution link", + error_cls=TaskStepTransitionError, + ) + + trace = store.create_trace( + user_id=task_row["user_id"], + thread_id=task_row["thread_id"], + kind=TRACE_KIND_TASK_STEP_TRANSITION, + compiler_version=TASK_STEP_TRANSITION_VERSION_V0, + status="completed", + limits={ + "order": list(TASK_STEP_LIST_ORDER), + "status_graph": {status: list(next_statuses) for status, next_statuses in TASK_STEP_STATUS_GRAPH.items()}, + "requested_status": request.status, + }, + ) + updated_row = store.update_task_step_for_task_sequence_optional( + task_id=step_row["task_id"], + 
sequence_no=step_row["sequence_no"], + status=request.status, + outcome=cast(dict[str, object], request.outcome), + trace_id=trace["id"], + trace_kind=TRACE_KIND_TASK_STEP_TRANSITION, + ) + if updated_row is None: + raise ContinuityStoreInvariantError( + f"task step {request.task_step_id} disappeared during transition" + ) + + updated_step = serialize_task_step_row(updated_row) + task_transition = sync_task_with_task_step_status( + store, + task_id=step_row["task_id"], + task_step_status=request.status, + linked_approval_id=linked_approval_id, + linked_execution_id=linked_execution_id, + ) + updated_items = [*existing_items[:-1], updated_step] + sequencing = _task_step_sequencing_summary( + task_id=str(task_row["id"]), + items=updated_items, + ) + + request_payload: TaskStepTransitionRequestTracePayload = { + "task_id": str(task_row["id"]), + "task_step_id": updated_step["id"], + "sequence_no": updated_step["sequence_no"], + "previous_status": previous_status, + "requested_status": request.status, + } + state_payload: TaskStepTransitionStateTracePayload = { + "task_id": str(task_row["id"]), + "task_step_id": updated_step["id"], + "sequence_no": updated_step["sequence_no"], + "previous_status": previous_status, + "current_status": updated_step["status"], + "allowed_next_statuses": allowed_next_statuses, + "trace": updated_step["trace"], + } + summary_payload: TaskStepTransitionSummaryTracePayload = { + "task_id": str(task_row["id"]), + "task_step_id": updated_step["id"], + "sequence_no": updated_step["sequence_no"], + "final_status": updated_step["status"], + "parent_task_status": task_transition.task["status"], + "trace": updated_step["trace"], + } + trace_events: list[tuple[str, dict[str, object]]] = [ + (TASK_STEP_TRANSITION_REQUEST_EVENT_KIND, cast(dict[str, object], request_payload)), + (TASK_STEP_TRANSITION_STATE_EVENT_KIND, cast(dict[str, object], state_payload)), + (TASK_STEP_TRANSITION_SUMMARY_EVENT_KIND, cast(dict[str, object], summary_payload)), + ] + trace_events.extend( + task_lifecycle_trace_events( + task=task_transition.task, + previous_status=task_transition.previous_status, + source="task_step_transition", + ) + ) + trace_events.extend( + task_step_lifecycle_trace_events( + task_step=updated_step, + previous_status=previous_status, + source="task_step_transition", + ) + ) + _append_trace_events(store, trace_id=trace["id"], trace_events=trace_events) + + return { + "task": task_transition.task, + "task_step": updated_step, + "sequencing": sequencing, + "trace": _trace_summary(trace["id"], trace_events), + } + + +def task_lifecycle_trace_events( + *, + task: TaskRecord, + previous_status: TaskStatus | None, + source: TaskLifecycleSource, +) -> list[tuple[str, dict[str, object]]]: + state_payload: TaskLifecycleStateTracePayload = { + "task_id": task["id"], + "source": source, + "previous_status": previous_status, + "current_status": task["status"], + "latest_approval_id": task["latest_approval_id"], + "latest_execution_id": task["latest_execution_id"], + } + summary_payload: TaskLifecycleSummaryTracePayload = { + "task_id": task["id"], + "source": source, + "final_status": task["status"], + "latest_approval_id": task["latest_approval_id"], + "latest_execution_id": task["latest_execution_id"], + } + return [ + (TASK_LIFECYCLE_STATE_EVENT_KIND, cast(dict[str, object], state_payload)), + (TASK_LIFECYCLE_SUMMARY_EVENT_KIND, cast(dict[str, object], summary_payload)), + ] + + +def task_step_lifecycle_trace_events( + *, + task_step: TaskStepRecord, + previous_status: TaskStepStatus 
| None, + source: TaskLifecycleSource, +) -> list[tuple[str, dict[str, object]]]: + state_payload: TaskStepLifecycleStateTracePayload = { + "task_id": task_step["task_id"], + "task_step_id": task_step["id"], + "source": source, + "sequence_no": task_step["sequence_no"], + "kind": task_step["kind"], + "previous_status": previous_status, + "current_status": task_step["status"], + "trace": task_step["trace"], + } + summary_payload: TaskStepLifecycleSummaryTracePayload = { + "task_id": task_step["task_id"], + "task_step_id": task_step["id"], + "source": source, + "sequence_no": task_step["sequence_no"], + "kind": task_step["kind"], + "final_status": task_step["status"], + "trace": task_step["trace"], + } + return [ + (TASK_STEP_LIFECYCLE_STATE_EVENT_KIND, cast(dict[str, object], state_payload)), + (TASK_STEP_LIFECYCLE_SUMMARY_EVENT_KIND, cast(dict[str, object], summary_payload)), + ] + + +def list_task_records( + store: ContinuityStore, + *, + user_id: UUID, +) -> TaskListResponse: + del user_id + + items = [serialize_task_row(row) for row in store.list_tasks()] + summary: TaskListSummary = { + "total_count": len(items), + "order": list(TASK_LIST_ORDER), + } + return { + "items": items, + "summary": summary, + } + + +def get_task_record( + store: ContinuityStore, + *, + user_id: UUID, + task_id: UUID, +) -> TaskDetailResponse: + del user_id + + task = store.get_task_optional(task_id) + if task is None: + raise TaskNotFoundError(f"task {task_id} was not found") + return {"task": serialize_task_row(task)} + + +def list_task_step_records( + store: ContinuityStore, + *, + user_id: UUID, + task_id: UUID, +) -> TaskStepListResponse: + del user_id + + task = store.get_task_optional(task_id) + if task is None: + raise TaskNotFoundError(f"task {task_id} was not found") + + items = [serialize_task_step_row(row) for row in store.list_task_steps_for_task(task_id)] + summary = _task_step_sequencing_summary(task_id=str(task["id"]), items=items) + return { + "items": items, + "summary": summary, + } + + +def get_task_step_record( + store: ContinuityStore, + *, + user_id: UUID, + task_step_id: UUID, +) -> TaskStepDetailResponse: + del user_id + + task_step = store.get_task_step_optional(task_step_id) + if task_step is None: + raise TaskStepNotFoundError(f"task step {task_step_id} was not found") + return {"task_step": serialize_task_step_row(task_step)} diff --git a/apps/api/src/alicebot_api/tools.py b/apps/api/src/alicebot_api/tools.py new file mode 100644 index 0000000..0634990 --- /dev/null +++ b/apps/api/src/alicebot_api/tools.py @@ -0,0 +1,553 @@ +from __future__ import annotations + +from dataclasses import dataclass +from uuid import UUID + +from alicebot_api.contracts import ( + TOOL_ALLOWLIST_EVALUATION_VERSION_V0, + TOOL_ROUTING_VERSION_V0, + TOOL_LIST_ORDER, + TRACE_KIND_TOOL_ALLOWLIST_EVALUATE, + TRACE_KIND_TOOL_ROUTE, + PolicyEvaluationRequestInput, + ToolAllowlistDecisionRecord, + ToolAllowlistEvaluationRequestInput, + ToolAllowlistEvaluationResponse, + ToolAllowlistEvaluationSummary, + ToolAllowlistReason, + ToolAllowlistTraceSummary, + ToolRoutingDecision, + ToolRoutingDecisionTracePayload, + ToolRoutingRequestInput, + ToolRoutingRequestTracePayload, + ToolRoutingResponse, + ToolRoutingSummary, + ToolRoutingSummaryTracePayload, + ToolRoutingTraceSummary, + ToolCreateInput, + ToolCreateResponse, + ToolDetailResponse, + ToolListResponse, + ToolListSummary, + ToolRecord, + isoformat_or_none, +) +from alicebot_api.policy import ( + evaluate_policy_against_context, + load_policy_evaluation_context, 
+) +from alicebot_api.store import ContinuityStore, ToolRow + + +class ToolValidationError(ValueError): + """Raised when a tool-registry request fails explicit validation.""" + + +class ToolNotFoundError(LookupError): + """Raised when a requested tool is not visible inside the current user scope.""" + + +class ToolAllowlistValidationError(ValueError): + """Raised when a tool-allowlist evaluation request fails explicit validation.""" + + +class ToolRoutingValidationError(ValueError): + """Raised when a tool-routing request fails explicit validation.""" + + +@dataclass(frozen=True, slots=True) +class ToolClassificationResult: + decision: str + tool: ToolRecord + reasons: list[ToolAllowlistReason] + matched_policy_id: str | None + + +def _serialize_tool(tool: ToolRow) -> ToolRecord: + return { + "id": str(tool["id"]), + "tool_key": tool["tool_key"], + "name": tool["name"], + "description": tool["description"], + "version": tool["version"], + "metadata_version": tool["metadata_version"], + "active": tool["active"], + "tags": list(tool["tags"]), + "action_hints": list(tool["action_hints"]), + "scope_hints": list(tool["scope_hints"]), + "domain_hints": list(tool["domain_hints"]), + "risk_hints": list(tool["risk_hints"]), + "metadata": tool["metadata"], + "created_at": tool["created_at"].isoformat(), + } + + +def _build_tool_reason( + *, + code: str, + source: str, + message: str, + tool_id: UUID, + policy_id: str | None = None, + consent_key: str | None = None, +) -> ToolAllowlistReason: + return { + "code": code, + "source": source, + "message": message, + "tool_id": str(tool_id), + "policy_id": policy_id, + "consent_key": consent_key, + } + + +def _metadata_match_reasons( + *, + tool: ToolRow, + request: ToolAllowlistEvaluationRequestInput, +) -> tuple[bool, list[ToolAllowlistReason]]: + reasons: list[ToolAllowlistReason] = [] + matched = True + + if request.action not in tool["action_hints"]: + matched = False + reasons.append( + _build_tool_reason( + code="tool_action_unsupported", + source="tool", + message=f"Tool '{tool['tool_key']}' does not declare support for action '{request.action}'.", + tool_id=tool["id"], + ) + ) + + if request.scope not in tool["scope_hints"]: + matched = False + reasons.append( + _build_tool_reason( + code="tool_scope_unsupported", + source="tool", + message=f"Tool '{tool['tool_key']}' does not declare support for scope '{request.scope}'.", + tool_id=tool["id"], + ) + ) + + if request.domain_hint is not None and tool["domain_hints"] and request.domain_hint not in tool["domain_hints"]: + matched = False + reasons.append( + _build_tool_reason( + code="tool_domain_mismatch", + source="tool", + message=( + f"Tool '{tool['tool_key']}' does not declare domain hint '{request.domain_hint}'." 
+ ), + tool_id=tool["id"], + ) + ) + + if request.risk_hint is not None and tool["risk_hints"] and request.risk_hint not in tool["risk_hints"]: + matched = False + reasons.append( + _build_tool_reason( + code="tool_risk_mismatch", + source="tool", + message=f"Tool '{tool['tool_key']}' does not declare risk hint '{request.risk_hint}'.", + tool_id=tool["id"], + ) + ) + + if matched: + reasons.append( + _build_tool_reason( + code="tool_metadata_matched", + source="tool", + message="Tool metadata matched the requested action, scope, and optional hints.", + tool_id=tool["id"], + ) + ) + + return matched, reasons + + +def _policy_attributes( + *, + tool: ToolRow, + request: ToolAllowlistEvaluationRequestInput, +) -> dict[str, object]: + attributes: dict[str, object] = dict(request.attributes) + attributes["tool_key"] = tool["tool_key"] + attributes["tool_version"] = tool["version"] + attributes["metadata_version"] = tool["metadata_version"] + if request.domain_hint is not None: + attributes["domain_hint"] = request.domain_hint + if request.risk_hint is not None: + attributes["risk_hint"] = request.risk_hint + return attributes + + +def _classify_tool_request( + *, + tool: ToolRow, + request: ToolAllowlistEvaluationRequestInput, + policy_context, +) -> ToolClassificationResult: + metadata_matched, metadata_reasons = _metadata_match_reasons(tool=tool, request=request) + serialized_tool = _serialize_tool(tool) + + if not metadata_matched: + return ToolClassificationResult( + decision="denied", + tool=serialized_tool, + reasons=metadata_reasons, + matched_policy_id=None, + ) + + policy_decision = evaluate_policy_against_context( + policy_context, + request=PolicyEvaluationRequestInput( + thread_id=request.thread_id, + action=request.action, + scope=request.scope, + attributes=_policy_attributes(tool=tool, request=request), + ), + ) + reasons = metadata_reasons + [ + { + "code": reason["code"], + "source": reason["source"], + "message": reason["message"], + "tool_id": str(tool["id"]), + "policy_id": reason["policy_id"], + "consent_key": reason["consent_key"], + } + for reason in policy_decision.reasons + ] + return ToolClassificationResult( + decision={ + "allow": "allowed", + "deny": "denied", + "require_approval": "approval_required", + }[policy_decision.decision], + tool=serialized_tool, + reasons=reasons, + matched_policy_id=( + None if policy_decision.matched_policy is None else str(policy_decision.matched_policy["id"]) + ), + ) + + +def _decision_record_from_classification( + classification: ToolClassificationResult, +) -> ToolAllowlistDecisionRecord: + return { + "decision": classification.decision, + "tool": classification.tool, + "reasons": classification.reasons, + } + + +def _allowlist_trace_payload( + classification: ToolClassificationResult, +) -> dict[str, object]: + return { + "tool_id": classification.tool["id"], + "tool_key": classification.tool["tool_key"], + "tool_version": classification.tool["version"], + "decision": classification.decision, + "matched_policy_id": classification.matched_policy_id, + "reasons": classification.reasons, + } + + +def _allowlist_request_from_routing( + request: ToolRoutingRequestInput, +) -> ToolAllowlistEvaluationRequestInput: + return ToolAllowlistEvaluationRequestInput( + thread_id=request.thread_id, + action=request.action, + scope=request.scope, + domain_hint=request.domain_hint, + risk_hint=request.risk_hint, + attributes=request.attributes, + ) + + +def _routing_decision_from_allowlist(allowlist_decision: str) -> ToolRoutingDecision: + return { 
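+        # Allowlist "allowed" maps to routing "ready"; "denied" and "approval_required" pass through unchanged.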
+ "allowed": "ready", + "denied": "denied", + "approval_required": "approval_required", + }[allowlist_decision] + + +def create_tool_record( + store: ContinuityStore, + *, + user_id: UUID, + tool: ToolCreateInput, +) -> ToolCreateResponse: + del user_id + + created = store.create_tool( + tool_key=tool.tool_key, + name=tool.name, + description=tool.description, + version=tool.version, + metadata_version=tool.metadata_version, + active=tool.active, + tags=list(tool.tags), + action_hints=list(tool.action_hints), + scope_hints=list(tool.scope_hints), + domain_hints=list(tool.domain_hints), + risk_hints=list(tool.risk_hints), + metadata=tool.metadata, + ) + return {"tool": _serialize_tool(created)} + + +def list_tool_records( + store: ContinuityStore, + *, + user_id: UUID, +) -> ToolListResponse: + del user_id + + items = [_serialize_tool(tool) for tool in store.list_tools()] + summary: ToolListSummary = { + "total_count": len(items), + "order": list(TOOL_LIST_ORDER), + } + return { + "items": items, + "summary": summary, + } + + +def get_tool_record( + store: ContinuityStore, + *, + user_id: UUID, + tool_id: UUID, +) -> ToolDetailResponse: + del user_id + + tool = store.get_tool_optional(tool_id) + if tool is None: + raise ToolNotFoundError(f"tool {tool_id} was not found") + return {"tool": _serialize_tool(tool)} + + +def evaluate_tool_allowlist( + store: ContinuityStore, + *, + user_id: UUID, + request: ToolAllowlistEvaluationRequestInput, +) -> ToolAllowlistEvaluationResponse: + del user_id + + thread = store.get_thread_optional(request.thread_id) + if thread is None: + raise ToolAllowlistValidationError( + "thread_id must reference an existing thread owned by the user" + ) + + active_tools = store.list_active_tools() + policy_context = load_policy_evaluation_context(store) + + allowed: list[ToolAllowlistDecisionRecord] = [] + denied: list[ToolAllowlistDecisionRecord] = [] + approval_required: list[ToolAllowlistDecisionRecord] = [] + tool_trace_events: list[tuple[str, dict[str, object]]] = [] + + for tool in active_tools: + classification = _classify_tool_request( + tool=tool, + request=request, + policy_context=policy_context, + ) + decision_record = _decision_record_from_classification(classification) + + if classification.decision == "allowed": + allowed.append(decision_record) + elif classification.decision == "approval_required": + approval_required.append(decision_record) + else: + denied.append(decision_record) + + tool_trace_events.append( + ( + "tool.allowlist.decision", + _allowlist_trace_payload(classification), + ) + ) + + trace = store.create_trace( + user_id=thread["user_id"], + thread_id=thread["id"], + kind=TRACE_KIND_TOOL_ALLOWLIST_EVALUATE, + compiler_version=TOOL_ALLOWLIST_EVALUATION_VERSION_V0, + status="completed", + limits={ + "order": list(TOOL_LIST_ORDER), + "active_tool_count": len(active_tools), + "active_policy_count": len(policy_context.active_policies), + "consent_count": len(policy_context.consents_by_key), + }, + ) + + trace_events: list[tuple[str, dict[str, object]]] = [ + ( + "tool.allowlist.request", + { + "thread_id": str(request.thread_id), + "action": request.action, + "scope": request.scope, + "domain_hint": request.domain_hint, + "risk_hint": request.risk_hint, + "attributes": request.attributes, + }, + ), + ( + "tool.allowlist.order", + { + "order": list(TOOL_LIST_ORDER), + "tool_ids": [str(tool["id"]) for tool in active_tools], + }, + ), + *tool_trace_events, + ( + "tool.allowlist.summary", + { + "allowed_count": len(allowed), + "denied_count": 
len(denied), + "approval_required_count": len(approval_required), + }, + ), + ] + for sequence_no, (kind, payload) in enumerate(trace_events, start=1): + store.append_trace_event( + trace_id=trace["id"], + sequence_no=sequence_no, + kind=kind, + payload=payload, + ) + + summary: ToolAllowlistEvaluationSummary = { + "action": request.action, + "scope": request.scope, + "domain_hint": request.domain_hint, + "risk_hint": request.risk_hint, + "evaluated_tool_count": len(active_tools), + "allowed_count": len(allowed), + "denied_count": len(denied), + "approval_required_count": len(approval_required), + "order": list(TOOL_LIST_ORDER), + } + trace_summary: ToolAllowlistTraceSummary = { + "trace_id": str(trace["id"]), + "trace_event_count": len(trace_events), + } + return { + "allowed": allowed, + "denied": denied, + "approval_required": approval_required, + "summary": summary, + "trace": trace_summary, + } + + +def route_tool_invocation( + store: ContinuityStore, + *, + user_id: UUID, + request: ToolRoutingRequestInput, +) -> ToolRoutingResponse: + del user_id + + thread = store.get_thread_optional(request.thread_id) + if thread is None: + raise ToolRoutingValidationError( + "thread_id must reference an existing thread owned by the user" + ) + + tool = store.get_tool_optional(request.tool_id) + if tool is None or tool["active"] is not True: + raise ToolRoutingValidationError( + "tool_id must reference an existing active tool owned by the user" + ) + + policy_context = load_policy_evaluation_context(store) + classification = _classify_tool_request( + tool=tool, + request=_allowlist_request_from_routing(request), + policy_context=policy_context, + ) + routing_decision = _routing_decision_from_allowlist(classification.decision) + + trace = store.create_trace( + user_id=thread["user_id"], + thread_id=thread["id"], + kind=TRACE_KIND_TOOL_ROUTE, + compiler_version=TOOL_ROUTING_VERSION_V0, + status="completed", + limits={ + "order": list(TOOL_LIST_ORDER), + "evaluated_tool_count": 1, + "active_policy_count": len(policy_context.active_policies), + "consent_count": len(policy_context.consents_by_key), + }, + ) + + request_payload: ToolRoutingRequestTracePayload = request.as_payload() + decision_payload: ToolRoutingDecisionTracePayload = { + "tool_id": classification.tool["id"], + "tool_key": classification.tool["tool_key"], + "tool_version": classification.tool["version"], + "allowlist_decision": classification.decision, + "routing_decision": routing_decision, + "matched_policy_id": classification.matched_policy_id, + "reasons": classification.reasons, + } + summary_payload: ToolRoutingSummaryTracePayload = { + "decision": routing_decision, + "evaluated_tool_count": 1, + "active_policy_count": len(policy_context.active_policies), + "consent_count": len(policy_context.consents_by_key), + } + trace_events = [ + ("tool.route.request", request_payload), + ("tool.route.decision", decision_payload), + ("tool.route.summary", summary_payload), + ] + for sequence_no, (kind, payload) in enumerate(trace_events, start=1): + store.append_trace_event( + trace_id=trace["id"], + sequence_no=sequence_no, + kind=kind, + payload=payload, + ) + + summary: ToolRoutingSummary = { + "thread_id": str(request.thread_id), + "tool_id": classification.tool["id"], + "action": request.action, + "scope": request.scope, + "domain_hint": request.domain_hint, + "risk_hint": request.risk_hint, + "decision": routing_decision, + "evaluated_tool_count": 1, + "active_policy_count": len(policy_context.active_policies), + "consent_count": 
len(policy_context.consents_by_key), + "order": list(TOOL_LIST_ORDER), + } + trace_summary: ToolRoutingTraceSummary = { + "trace_id": str(trace["id"]), + "trace_event_count": len(trace_events), + } + return { + "request": request_payload, + "decision": routing_decision, + "tool": classification.tool, + "reasons": classification.reasons, + "summary": summary, + "trace": trace_summary, + } diff --git a/apps/api/src/alicebot_api/workspaces.py b/apps/api/src/alicebot_api/workspaces.py new file mode 100644 index 0000000..d058fb0 --- /dev/null +++ b/apps/api/src/alicebot_api/workspaces.py @@ -0,0 +1,144 @@ +from __future__ import annotations + +from pathlib import Path +from typing import cast +from uuid import UUID + +from alicebot_api.config import Settings +from alicebot_api.contracts import ( + TASK_WORKSPACE_LIST_ORDER, + TaskWorkspaceCreateInput, + TaskWorkspaceCreateResponse, + TaskWorkspaceDetailResponse, + TaskWorkspaceListResponse, + TaskWorkspaceRecord, + TaskWorkspaceStatus, +) +from alicebot_api.tasks import TaskNotFoundError +from alicebot_api.store import ContinuityStore, TaskWorkspaceRow + + +class TaskWorkspaceNotFoundError(LookupError): + """Raised when a task workspace record is not visible inside the current user scope.""" + + +class TaskWorkspaceAlreadyExistsError(RuntimeError): + """Raised when an active task workspace already exists for a task.""" + + +class TaskWorkspaceProvisioningError(RuntimeError): + """Raised when local workspace provisioning cannot satisfy rooted path rules.""" + + +def resolve_workspace_root(workspace_root: str) -> Path: + return Path(workspace_root).expanduser().resolve() + + +def build_task_workspace_path( + *, + workspace_root: Path, + user_id: UUID, + task_id: UUID, +) -> Path: + return workspace_root / str(user_id) / str(task_id) + + +def ensure_workspace_path_is_rooted( + *, + workspace_root: Path, + workspace_path: Path, +) -> None: + resolved_root = workspace_root.resolve() + resolved_path = workspace_path.resolve() + try: + resolved_path.relative_to(resolved_root) + except ValueError as exc: + raise TaskWorkspaceProvisioningError( + f"workspace path {resolved_path} escapes configured root {resolved_root}" + ) from exc + + +def serialize_task_workspace_row(row: TaskWorkspaceRow) -> TaskWorkspaceRecord: + return { + "id": str(row["id"]), + "task_id": str(row["task_id"]), + "status": cast(TaskWorkspaceStatus, row["status"]), + "local_path": row["local_path"], + "created_at": row["created_at"].isoformat(), + "updated_at": row["updated_at"].isoformat(), + } + + +def create_task_workspace_record( + store: ContinuityStore, + *, + settings: Settings, + user_id: UUID, + request: TaskWorkspaceCreateInput, +) -> TaskWorkspaceCreateResponse: + task = store.get_task_optional(request.task_id) + if task is None: + raise TaskNotFoundError(f"task {request.task_id} was not found") + + workspace_root = resolve_workspace_root(settings.task_workspace_root) + workspace_path = build_task_workspace_path( + workspace_root=workspace_root, + user_id=user_id, + task_id=request.task_id, + ) + ensure_workspace_path_is_rooted( + workspace_root=workspace_root, + workspace_path=workspace_path, + ) + + store.lock_task_workspaces(request.task_id) + existing_workspace = store.get_active_task_workspace_for_task_optional(request.task_id) + if existing_workspace is not None: + raise TaskWorkspaceAlreadyExistsError( + f"task {request.task_id} already has active workspace {existing_workspace['id']}" + ) + + try: + workspace_path.mkdir(parents=True, exist_ok=True) + except OSError 
as exc: + raise TaskWorkspaceProvisioningError( + f"workspace path {workspace_path} could not be provisioned" + ) from exc + + row = store.create_task_workspace( + task_id=request.task_id, + status=request.status, + local_path=str(workspace_path), + ) + return {"workspace": serialize_task_workspace_row(row)} + + +def list_task_workspace_records( + store: ContinuityStore, + *, + user_id: UUID, +) -> TaskWorkspaceListResponse: + del user_id + + items = [serialize_task_workspace_row(row) for row in store.list_task_workspaces()] + return { + "items": items, + "summary": { + "total_count": len(items), + "order": list(TASK_WORKSPACE_LIST_ORDER), + }, + } + + +def get_task_workspace_record( + store: ContinuityStore, + *, + user_id: UUID, + task_workspace_id: UUID, +) -> TaskWorkspaceDetailResponse: + del user_id + + row = store.get_task_workspace_optional(task_workspace_id) + if row is None: + raise TaskWorkspaceNotFoundError(f"task workspace {task_workspace_id} was not found") + return {"workspace": serialize_task_workspace_row(row)} diff --git a/apps/web/.gitkeep b/apps/web/.gitkeep new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/apps/web/.gitkeep @@ -0,0 +1 @@ + diff --git a/apps/web/app/layout.tsx b/apps/web/app/layout.tsx new file mode 100644 index 0000000..ed6cafd --- /dev/null +++ b/apps/web/app/layout.tsx @@ -0,0 +1,10 @@ +export default function RootLayout({ + children, +}: Readonly<{ children: React.ReactNode }>) { + return ( + + {children} + + ); +} + diff --git a/apps/web/app/page.tsx b/apps/web/app/page.tsx new file mode 100644 index 0000000..7a46a7a --- /dev/null +++ b/apps/web/app/page.tsx @@ -0,0 +1,51 @@ +const milestones = [ + "API foundation and migrations", + "Continuity event store", + "Web dashboard shell", + "Worker orchestration", +]; + +export default function HomePage() { + return ( +
+    <main>
+      <header>
+        <h1>AliceBot Foundation</h1>
+        <p>Operational shell for the modular monolith</p>
+      </header>
+      <p>
+        The web app is intentionally minimal in this sprint. It exists to prove repository
+        structure while continuity, migrations, and safety primitives land in the API layer.
+      </p>
+      <ul>
+        {milestones.map((item) => (
+          <li key={item}>{item}</li>
+        ))}
+      </ul>
+    </main>
+ ); +} + diff --git a/apps/web/next-env.d.ts b/apps/web/next-env.d.ts new file mode 100644 index 0000000..dc86238 --- /dev/null +++ b/apps/web/next-env.d.ts @@ -0,0 +1,5 @@ +/// +/// + +// This file is managed by Next.js. + diff --git a/apps/web/next.config.mjs b/apps/web/next.config.mjs new file mode 100644 index 0000000..06cd07e --- /dev/null +++ b/apps/web/next.config.mjs @@ -0,0 +1,6 @@ +const nextConfig = { + reactStrictMode: true, +}; + +export default nextConfig; + diff --git a/apps/web/package.json b/apps/web/package.json new file mode 100644 index 0000000..7f5ec8b --- /dev/null +++ b/apps/web/package.json @@ -0,0 +1,25 @@ +{ + "name": "@alicebot/web", + "private": true, + "version": "0.1.0", + "scripts": { + "dev": "next dev", + "build": "next build", + "start": "next start", + "lint": "next lint" + }, + "dependencies": { + "next": "15.2.0", + "react": "19.0.0", + "react-dom": "19.0.0" + }, + "devDependencies": { + "@types/node": "22.13.10", + "@types/react": "19.0.10", + "@types/react-dom": "19.0.4", + "eslint": "9.22.0", + "eslint-config-next": "15.2.0", + "typescript": "5.8.2" + } +} + diff --git a/apps/web/tsconfig.json b/apps/web/tsconfig.json new file mode 100644 index 0000000..bbd3768 --- /dev/null +++ b/apps/web/tsconfig.json @@ -0,0 +1,19 @@ +{ + "compilerOptions": { + "target": "ES2022", + "lib": ["dom", "dom.iterable", "es2022"], + "allowJs": false, + "skipLibCheck": true, + "strict": true, + "noEmit": true, + "module": "esnext", + "moduleResolution": "bundler", + "resolveJsonModule": true, + "isolatedModules": true, + "jsx": "preserve", + "incremental": true + }, + "include": ["next-env.d.ts", "**/*.ts", "**/*.tsx"], + "exclude": ["node_modules"] +} + diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..2066a2b --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,36 @@ +services: + postgres: + image: pgvector/pgvector:pg16 + container_name: alicebot-postgres + environment: + POSTGRES_USER: alicebot_admin + POSTGRES_PASSWORD: alicebot_admin + POSTGRES_DB: alicebot + ports: + - "127.0.0.1:5432:5432" + volumes: + - postgres-data:/var/lib/postgresql/data + - ./infra/postgres/init:/docker-entrypoint-initdb.d:ro + + redis: + image: redis:7-alpine + container_name: alicebot-redis + ports: + - "127.0.0.1:6379:6379" + + minio: + image: minio/minio:RELEASE.2025-02-28T09-55-16Z + container_name: alicebot-minio + command: server /data --console-address ":9001" + environment: + MINIO_ROOT_USER: alicebot + MINIO_ROOT_PASSWORD: alicebot-secret + ports: + - "127.0.0.1:9000:9000" + - "127.0.0.1:9001:9001" + volumes: + - minio-data:/data + +volumes: + postgres-data: + minio-data: diff --git a/docs/adr/.gitkeep b/docs/adr/.gitkeep new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/docs/adr/.gitkeep @@ -0,0 +1 @@ + diff --git a/docs/archive/.gitkeep b/docs/archive/.gitkeep new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/docs/archive/.gitkeep @@ -0,0 +1 @@ + diff --git a/docs/runbooks/.gitkeep b/docs/runbooks/.gitkeep new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/docs/runbooks/.gitkeep @@ -0,0 +1 @@ + diff --git a/infra/postgres/init/001_roles.sql b/infra/postgres/init/001_roles.sql new file mode 100644 index 0000000..78f9d49 --- /dev/null +++ b/infra/postgres/init/001_roles.sql @@ -0,0 +1,16 @@ +DO +$$ +BEGIN + IF NOT EXISTS (SELECT 1 FROM pg_roles WHERE rolname = 'alicebot_app') THEN + CREATE ROLE alicebot_app + LOGIN + PASSWORD 'alicebot_app' + NOSUPERUSER + NOCREATEDB + NOCREATEROLE + NOINHERIT; + END IF; 
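+  -- The pg_roles guard above keeps role creation idempotent if this init script runs against an existing data volume.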
+END +$$; + +GRANT CONNECT ON DATABASE alicebot TO alicebot_app; diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..51cf232 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,32 @@ +[build-system] +requires = ["setuptools>=69", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "alicebot-foundation" +version = "0.1.0" +description = "Foundation scaffold for the AliceBot modular monolith." +requires-python = ">=3.12" +dependencies = [ + "alembic>=1.14,<2.0", + "fastapi>=0.115,<1.0", + "psycopg[binary]>=3.2,<4.0", + "sqlalchemy>=2.0,<3.0", + "uvicorn>=0.34,<1.0", +] + +[project.optional-dependencies] +dev = [ + "pytest>=8.3,<9.0", +] + +[tool.setuptools.package-dir] +"" = "." + +[tool.setuptools.packages.find] +where = ["apps/api/src", "workers"] + +[tool.pytest.ini_options] +pythonpath = ["apps/api/src", "workers"] +testpaths = ["tests"] + diff --git a/scripts/.gitkeep b/scripts/.gitkeep new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/scripts/.gitkeep @@ -0,0 +1 @@ + diff --git a/scripts/api_dev.sh b/scripts/api_dev.sh new file mode 100755 index 0000000..17ee7f4 --- /dev/null +++ b/scripts/api_dev.sh @@ -0,0 +1,30 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd -- "${SCRIPT_DIR}/.." && pwd)" + +if [ -f "${REPO_ROOT}/.env" ]; then + set -a + . "${REPO_ROOT}/.env" + set +a +fi + +PYTHON_BIN="python3" +if [ -x "${REPO_ROOT}/.venv/bin/python" ]; then + PYTHON_BIN="${REPO_ROOT}/.venv/bin/python" +fi + +cd "${REPO_ROOT}" + +UVICORN_ARGS=( + --app-dir "${REPO_ROOT}/apps/api/src" + --host "${APP_HOST:-127.0.0.1}" + --port "${APP_PORT:-8000}" +) + +if [ "${APP_RELOAD:-true}" = "true" ]; then + UVICORN_ARGS+=(--reload) +fi + +exec "${PYTHON_BIN}" -m uvicorn alicebot_api.main:app "${UVICORN_ARGS[@]}" diff --git a/scripts/dev_up.sh b/scripts/dev_up.sh new file mode 100755 index 0000000..983ce1b --- /dev/null +++ b/scripts/dev_up.sh @@ -0,0 +1,49 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd -- "${SCRIPT_DIR}/.." && pwd)" + +if [ -f "${REPO_ROOT}/.env" ]; then + set -a + . "${REPO_ROOT}/.env" + set +a +fi + +PYTHON_BIN="python3" +if [ -x "${REPO_ROOT}/.venv/bin/python" ]; then + PYTHON_BIN="${REPO_ROOT}/.venv/bin/python" +fi + +cd "${REPO_ROOT}" + +docker compose up -d + +"${PYTHON_BIN}" -c ' +import os +import sys +import time + +import psycopg + +database_url = os.getenv( + "DATABASE_ADMIN_URL", + "postgresql://alicebot_admin:alicebot_admin@localhost:5432/alicebot", +) +deadline = time.time() + 60 + +while time.time() < deadline: + try: + with psycopg.connect(database_url, connect_timeout=1) as conn: + with conn.cursor() as cur: + cur.execute("SELECT 1 FROM pg_roles WHERE rolname = %s", ("alicebot_app",)) + if cur.fetchone() == (1,): + sys.exit(0) + except psycopg.Error: + pass + time.sleep(1) + +raise SystemExit("Timed out waiting for Postgres readiness and alicebot_app bootstrap") +' + +"${PYTHON_BIN}" -m alembic -c "${REPO_ROOT}/apps/api/alembic.ini" upgrade head diff --git a/scripts/migrate.sh b/scripts/migrate.sh new file mode 100755 index 0000000..ef2401b --- /dev/null +++ b/scripts/migrate.sh @@ -0,0 +1,20 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd -- "${SCRIPT_DIR}/.." && pwd)" + +if [ -f "${REPO_ROOT}/.env" ]; then + set -a + . 
"${REPO_ROOT}/.env" + set +a +fi + +PYTHON_BIN="python3" +if [ -x "${REPO_ROOT}/.venv/bin/python" ]; then + PYTHON_BIN="${REPO_ROOT}/.venv/bin/python" +fi + +cd "${REPO_ROOT}" + +"${PYTHON_BIN}" -m alembic -c "${REPO_ROOT}/apps/api/alembic.ini" upgrade "${1:-head}" diff --git a/tests/.gitkeep b/tests/.gitkeep new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/tests/.gitkeep @@ -0,0 +1 @@ + diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py new file mode 100644 index 0000000..f413549 --- /dev/null +++ b/tests/integration/conftest.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +from collections.abc import Iterator +import os +from urllib.parse import urlsplit, urlunsplit +from uuid import uuid4 + +from alembic import command +import psycopg +from psycopg import sql +import pytest + +from alicebot_api.migrations import make_alembic_config + + +DEFAULT_ADMIN_URL = "postgresql://alicebot_admin:alicebot_admin@localhost:5432/alicebot" +DEFAULT_APP_URL = "postgresql://alicebot_app:alicebot_app@localhost:5432/alicebot" + + +def swap_database_name(database_url: str, database_name: str) -> str: + parsed = urlsplit(database_url) + return urlunsplit((parsed.scheme, parsed.netloc, f"/{database_name}", parsed.query, parsed.fragment)) + + +@pytest.fixture +def database_urls() -> Iterator[dict[str, str]]: + admin_root_url = os.getenv("DATABASE_ADMIN_URL", DEFAULT_ADMIN_URL) + app_root_url = os.getenv("DATABASE_URL", DEFAULT_APP_URL) + database_name = f"alicebot_test_{uuid4().hex[:12]}" + admin_database_url = swap_database_name(admin_root_url, database_name) + app_database_url = swap_database_name(app_root_url, database_name) + + with psycopg.connect(admin_root_url, autocommit=True) as conn: + with conn.cursor() as cur: + cur.execute(sql.SQL("CREATE DATABASE {}").format(sql.Identifier(database_name))) + cur.execute( + sql.SQL("GRANT CONNECT, TEMPORARY ON DATABASE {} TO alicebot_app").format( + sql.Identifier(database_name) + ) + ) + + yield {"admin": admin_database_url, "app": app_database_url} + + with psycopg.connect(admin_root_url, autocommit=True) as conn: + with conn.cursor() as cur: + cur.execute( + sql.SQL("DROP DATABASE IF EXISTS {} WITH (FORCE)").format(sql.Identifier(database_name)) + ) + + +@pytest.fixture +def migrated_database_urls(database_urls: dict[str, str]) -> Iterator[dict[str, str]]: + config = make_alembic_config(database_urls["admin"]) + command.upgrade(config, "head") + yield database_urls diff --git a/tests/integration/test_approval_api.py b/tests/integration/test_approval_api.py new file mode 100644 index 0000000..0f6e980 --- /dev/null +++ b/tests/integration/test_approval_api.py @@ -0,0 +1,929 @@ +from __future__ import annotations + +import json +from typing import Any +from urllib.parse import urlencode +from uuid import UUID, uuid4 + +import anyio + +import apps.api.src.alicebot_api.main as main_module +from apps.api.src.alicebot_api.config import Settings +from alicebot_api.db import user_connection +from alicebot_api.store import ContinuityStore + + +def invoke_request( + method: str, + path: str, + *, + query_params: dict[str, str] | None = None, + payload: dict[str, Any] | None = None, +) -> tuple[int, dict[str, Any]]: + messages: list[dict[str, object]] = [] + encoded_body = b"" if payload is None else json.dumps(payload).encode() + request_received = False + + async def receive() -> dict[str, object]: + nonlocal request_received + if request_received: + return {"type": "http.disconnect"} + + request_received = True + return 
{"type": "http.request", "body": encoded_body, "more_body": False} + + async def send(message: dict[str, object]) -> None: + messages.append(message) + + query_string = urlencode(query_params or {}).encode() + scope = { + "type": "http", + "asgi": {"version": "3.0"}, + "http_version": "1.1", + "method": method, + "scheme": "http", + "path": path, + "raw_path": path.encode(), + "query_string": query_string, + "headers": [(b"content-type", b"application/json")], + "client": ("testclient", 50000), + "server": ("testserver", 80), + "root_path": "", + } + + anyio.run(main_module.app, scope, receive, send) + + start_message = next(message for message in messages if message["type"] == "http.response.start") + body = b"".join( + message.get("body", b"") + for message in messages + if message["type"] == "http.response.body" + ) + return start_message["status"], json.loads(body) + + +def seed_user(database_url: str, *, email: str) -> dict[str, UUID]: + user_id = uuid4() + + with user_connection(database_url, user_id) as conn: + store = ContinuityStore(conn) + store.create_user(user_id, email, email.split("@", 1)[0].title()) + thread = store.create_thread("Approval thread") + + return { + "user_id": user_id, + "thread_id": thread["id"], + } + + +def test_approval_request_persists_record_for_approval_required_route( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_user(migrated_database_urls["app"], email="owner@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + store = ContinuityStore(conn) + tool = store.create_tool( + tool_key="shell.exec", + name="Shell Exec", + description="Run shell commands.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["shell"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=[], + risk_hints=[], + metadata={"transport": "local"}, + ) + policy = store.create_policy( + name="Require shell approval", + action="tool.run", + scope="workspace", + effect="require_approval", + priority=10, + active=True, + conditions={"tool_key": "shell.exec"}, + required_consents=[], + ) + + status, payload = invoke_request( + "POST", + "/v0/approvals/requests", + payload={ + "user_id": str(seeded["user_id"]), + "thread_id": str(seeded["thread_id"]), + "tool_id": str(tool["id"]), + "action": "tool.run", + "scope": "workspace", + "attributes": {"command": "ls"}, + }, + ) + + assert status == 200 + assert list(payload) == [ + "request", + "decision", + "tool", + "reasons", + "task", + "approval", + "routing_trace", + "trace", + ] + assert payload["decision"] == "approval_required" + assert payload["task"]["status"] == "pending_approval" + assert payload["task"]["latest_approval_id"] == payload["approval"]["id"] + assert payload["task"]["latest_execution_id"] is None + assert payload["approval"] is not None + assert payload["approval"]["status"] == "pending" + assert payload["approval"]["task_step_id"] is not None + assert payload["approval"]["resolution"] is None + assert payload["approval"]["request"] == payload["request"] + assert payload["approval"]["tool"] == payload["tool"] + assert payload["approval"]["routing"] == { + "decision": "approval_required", + "reasons": payload["reasons"], + "trace": payload["routing_trace"], + } + assert payload["reasons"][-1] == { + "code": "policy_effect_require_approval", + "source": "policy", + "message": "Policy effect 
resolved the decision to 'require_approval'.", + "tool_id": str(tool["id"]), + "policy_id": str(policy["id"]), + "consent_key": None, + } + assert payload["routing_trace"]["trace_event_count"] == 3 + assert payload["trace"]["trace_event_count"] == 8 + + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + store = ContinuityStore(conn) + approvals = store.list_approvals() + tasks = store.list_tasks() + task_steps = store.list_task_steps_for_task(tasks[0]["id"]) + approval_trace = store.get_trace(UUID(payload["trace"]["trace_id"])) + approval_trace_events = store.list_trace_events(UUID(payload["trace"]["trace_id"])) + + assert len(approvals) == 1 + assert len(tasks) == 1 + assert len(task_steps) == 1 + assert approvals[0]["id"] == UUID(payload["approval"]["id"]) + assert approvals[0]["task_step_id"] == task_steps[0]["id"] + assert tasks[0]["id"] == UUID(payload["task"]["id"]) + assert approval_trace["kind"] == "approval.request" + assert approval_trace["compiler_version"] == "approval_request_v0" + assert approval_trace["limits"] == { + "order": ["created_at_asc", "id_asc"], + "persisted": True, + } + assert [event["kind"] for event in approval_trace_events] == [ + "approval.request.request", + "approval.request.routing", + "approval.request.persisted", + "approval.request.summary", + "task.lifecycle.state", + "task.lifecycle.summary", + "task.step.lifecycle.state", + "task.step.lifecycle.summary", + ] + assert approval_trace_events[1]["payload"] == { + "decision": "approval_required", + "tool_id": str(tool["id"]), + "tool_key": "shell.exec", + "tool_version": "1.0.0", + "routing_trace_id": payload["routing_trace"]["trace_id"], + "routing_trace_event_count": 3, + "reasons": payload["reasons"], + } + assert approval_trace_events[4]["payload"] == { + "task_id": payload["task"]["id"], + "source": "approval_request", + "previous_status": None, + "current_status": "pending_approval", + "latest_approval_id": payload["approval"]["id"], + "latest_execution_id": None, + } + assert approval_trace_events[2]["payload"] == { + "approval_id": payload["approval"]["id"], + "task_step_id": payload["approval"]["task_step_id"], + "decision": "approval_required", + "persisted": True, + } + assert approval_trace_events[6]["payload"] == { + "task_id": payload["task"]["id"], + "task_step_id": str(task_steps[0]["id"]), + "source": "approval_request", + "sequence_no": 1, + "kind": "governed_request", + "previous_status": None, + "current_status": "created", + "trace": { + "trace_id": payload["trace"]["trace_id"], + "trace_kind": "approval.request", + }, + } + + +def test_approval_request_does_not_create_records_for_ready_or_denied_routes( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_user(migrated_database_urls["app"], email="owner@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + store = ContinuityStore(conn) + store.create_consent( + consent_key="web_access", + status="granted", + metadata={"source": "settings"}, + ) + ready_tool = store.create_tool( + tool_key="browser.open", + name="Browser Open", + description="Open documentation pages.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["browser"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=["docs"], + risk_hints=[], + metadata={"transport": "proxy"}, + ) + denied_tool = store.create_tool( 
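+            # Action/scope hints below do not cover "tool.run" in "workspace", so requests against this tool resolve to "denied".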
+ tool_key="calendar.read", + name="Calendar Read", + description="Read calendars.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["calendar"], + action_hints=["calendar.read"], + scope_hints=["calendar"], + domain_hints=[], + risk_hints=[], + metadata={}, + ) + store.create_policy( + name="Allow docs browser", + action="tool.run", + scope="workspace", + effect="allow", + priority=10, + active=True, + conditions={"tool_key": "browser.open", "domain_hint": "docs"}, + required_consents=["web_access"], + ) + + ready_status, ready_payload = invoke_request( + "POST", + "/v0/approvals/requests", + payload={ + "user_id": str(seeded["user_id"]), + "thread_id": str(seeded["thread_id"]), + "tool_id": str(ready_tool["id"]), + "action": "tool.run", + "scope": "workspace", + "domain_hint": "docs", + "attributes": {}, + }, + ) + denied_status, denied_payload = invoke_request( + "POST", + "/v0/approvals/requests", + payload={ + "user_id": str(seeded["user_id"]), + "thread_id": str(seeded["thread_id"]), + "tool_id": str(denied_tool["id"]), + "action": "tool.run", + "scope": "workspace", + "attributes": {}, + }, + ) + + assert ready_status == 200 + assert ready_payload["decision"] == "ready" + assert ready_payload["task"]["status"] == "approved" + assert ready_payload["approval"] is None + assert denied_status == 200 + assert denied_payload["decision"] == "denied" + assert denied_payload["task"]["status"] == "denied" + assert denied_payload["approval"] is None + + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + store = ContinuityStore(conn) + approvals = store.list_approvals() + tasks = store.list_tasks() + ready_task_steps = store.list_task_steps_for_task(tasks[0]["id"]) + denied_task_steps = store.list_task_steps_for_task(tasks[1]["id"]) + + assert approvals == [] + assert [task["status"] for task in tasks] == ["approved", "denied"] + assert [task_step["status"] for task_step in ready_task_steps] == ["approved"] + assert [task_step["status"] for task_step in denied_task_steps] == ["denied"] + + +def test_approval_endpoints_list_and_detail_are_deterministic_and_user_scoped( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner@example.com") + intruder = seed_user(migrated_database_urls["app"], email="intruder@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + first_tool = store.create_tool( + tool_key="shell.exec", + name="Shell Exec", + description="Run shell commands.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["shell"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=[], + risk_hints=[], + metadata={"transport": "local"}, + ) + second_tool = store.create_tool( + tool_key="shell.exec", + name="Shell Exec", + description="Run shell commands.", + version="2.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["shell"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=[], + risk_hints=[], + metadata={"transport": "local"}, + ) + store.create_policy( + name="Require shell approval", + action="tool.run", + scope="workspace", + effect="require_approval", + priority=10, + active=True, + conditions={"tool_key": "shell.exec"}, + required_consents=[], + ) + + first_status, first_payload = 
invoke_request( + "POST", + "/v0/approvals/requests", + payload={ + "user_id": str(owner["user_id"]), + "thread_id": str(owner["thread_id"]), + "tool_id": str(first_tool["id"]), + "action": "tool.run", + "scope": "workspace", + "attributes": {"command": "pwd"}, + }, + ) + second_status, second_payload = invoke_request( + "POST", + "/v0/approvals/requests", + payload={ + "user_id": str(owner["user_id"]), + "thread_id": str(owner["thread_id"]), + "tool_id": str(second_tool["id"]), + "action": "tool.run", + "scope": "workspace", + "attributes": {"command": "ls"}, + }, + ) + list_status, list_payload = invoke_request( + "GET", + "/v0/approvals", + query_params={"user_id": str(owner["user_id"])}, + ) + detail_status, detail_payload = invoke_request( + "GET", + f"/v0/approvals/{second_payload['approval']['id']}", + query_params={"user_id": str(owner['user_id'])}, + ) + isolated_list_status, isolated_list_payload = invoke_request( + "GET", + "/v0/approvals", + query_params={"user_id": str(intruder["user_id"])}, + ) + isolated_detail_status, isolated_detail_payload = invoke_request( + "GET", + f"/v0/approvals/{first_payload['approval']['id']}", + query_params={"user_id": str(intruder['user_id'])}, + ) + + assert first_status == 200 + assert second_status == 200 + assert list_status == 200 + assert [item["id"] for item in list_payload["items"]] == [ + first_payload["approval"]["id"], + second_payload["approval"]["id"], + ] + assert list_payload["summary"] == { + "total_count": 2, + "order": ["created_at_asc", "id_asc"], + } + assert detail_status == 200 + assert detail_payload == {"approval": second_payload["approval"]} + + assert isolated_list_status == 200 + assert isolated_list_payload == { + "items": [], + "summary": {"total_count": 0, "order": ["created_at_asc", "id_asc"]}, + } + assert isolated_detail_status == 404 + assert isolated_detail_payload == { + "detail": f"approval {first_payload['approval']['id']} was not found" + } + + +def test_approval_resolution_endpoints_update_reads_and_emit_trace( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + first_tool = store.create_tool( + tool_key="shell.exec", + name="Shell Exec", + description="Run shell commands.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["shell"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=[], + risk_hints=[], + metadata={"transport": "local"}, + ) + second_tool = store.create_tool( + tool_key="shell.exec", + name="Shell Exec", + description="Run shell commands.", + version="2.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["shell"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=[], + risk_hints=[], + metadata={"transport": "local"}, + ) + store.create_policy( + name="Require shell approval", + action="tool.run", + scope="workspace", + effect="require_approval", + priority=10, + active=True, + conditions={"tool_key": "shell.exec"}, + required_consents=[], + ) + + _, first_request_payload = invoke_request( + "POST", + "/v0/approvals/requests", + payload={ + "user_id": str(owner["user_id"]), + "thread_id": str(owner["thread_id"]), + "tool_id": str(first_tool["id"]), + "action": "tool.run", + "scope": 
"workspace", + "attributes": {"command": "pwd"}, + }, + ) + _, second_request_payload = invoke_request( + "POST", + "/v0/approvals/requests", + payload={ + "user_id": str(owner["user_id"]), + "thread_id": str(owner["thread_id"]), + "tool_id": str(second_tool["id"]), + "action": "tool.run", + "scope": "workspace", + "attributes": {"command": "ls"}, + }, + ) + approve_status, approve_payload = invoke_request( + "POST", + f"/v0/approvals/{first_request_payload['approval']['id']}/approve", + payload={"user_id": str(owner["user_id"])}, + ) + reject_status, reject_payload = invoke_request( + "POST", + f"/v0/approvals/{second_request_payload['approval']['id']}/reject", + payload={"user_id": str(owner['user_id'])}, + ) + list_status, list_payload = invoke_request( + "GET", + "/v0/approvals", + query_params={"user_id": str(owner["user_id"])}, + ) + detail_status, detail_payload = invoke_request( + "GET", + f"/v0/approvals/{second_request_payload['approval']['id']}", + query_params={"user_id": str(owner['user_id'])}, + ) + + assert approve_status == 200 + assert list(approve_payload) == ["approval", "trace"] + assert approve_payload["approval"]["status"] == "approved" + assert approve_payload["approval"]["task_step_id"] == first_request_payload["approval"]["task_step_id"] + assert approve_payload["approval"]["resolution"] is not None + assert approve_payload["trace"]["trace_event_count"] == 7 + + assert reject_status == 200 + assert list(reject_payload) == ["approval", "trace"] + assert reject_payload["approval"]["status"] == "rejected" + assert reject_payload["approval"]["task_step_id"] == second_request_payload["approval"]["task_step_id"] + assert reject_payload["approval"]["resolution"] is not None + assert reject_payload["trace"]["trace_event_count"] == 7 + + assert list_status == 200 + assert [item["id"] for item in list_payload["items"]] == [ + first_request_payload["approval"]["id"], + second_request_payload["approval"]["id"], + ] + assert [item["status"] for item in list_payload["items"]] == ["approved", "rejected"] + assert list_payload["summary"] == { + "total_count": 2, + "order": ["created_at_asc", "id_asc"], + } + assert detail_status == 200 + assert detail_payload == {"approval": reject_payload["approval"]} + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + approve_trace = store.get_trace(UUID(approve_payload["trace"]["trace_id"])) + approve_trace_events = store.list_trace_events(UUID(approve_payload["trace"]["trace_id"])) + reject_trace = store.get_trace(UUID(reject_payload["trace"]["trace_id"])) + reject_trace_events = store.list_trace_events(UUID(reject_payload["trace"]["trace_id"])) + + assert approve_trace["kind"] == "approval.resolve" + assert approve_trace["compiler_version"] == "approval_resolution_v0" + assert approve_trace["limits"] == { + "order": ["created_at_asc", "id_asc"], + "requested_action": "approve", + "outcome": "resolved", + } + assert [event["kind"] for event in approve_trace_events] == [ + "approval.resolution.request", + "approval.resolution.state", + "approval.resolution.summary", + "task.lifecycle.state", + "task.lifecycle.summary", + "task.step.lifecycle.state", + "task.step.lifecycle.summary", + ] + assert approve_trace_events[1]["payload"]["current_status"] == "approved" + assert approve_trace_events[1]["payload"]["task_step_id"] == first_request_payload["approval"]["task_step_id"] + assert approve_trace_events[1]["payload"]["resolved_by_user_id"] == str(owner["user_id"]) + + assert 
reject_trace["kind"] == "approval.resolve" + assert reject_trace["compiler_version"] == "approval_resolution_v0" + assert reject_trace["limits"] == { + "order": ["created_at_asc", "id_asc"], + "requested_action": "reject", + "outcome": "resolved", + } + assert [event["kind"] for event in reject_trace_events] == [ + "approval.resolution.request", + "approval.resolution.state", + "approval.resolution.summary", + "task.lifecycle.state", + "task.lifecycle.summary", + "task.step.lifecycle.state", + "task.step.lifecycle.summary", + ] + assert reject_trace_events[1]["payload"]["current_status"] == "rejected" + assert reject_trace_events[1]["payload"]["task_step_id"] == second_request_payload["approval"]["task_step_id"] + assert reject_trace_events[1]["payload"]["resolved_by_user_id"] == str(owner["user_id"]) + + +def test_approval_resolution_rejects_duplicate_conflicting_and_cross_user_attempts( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner@example.com") + intruder = seed_user(migrated_database_urls["app"], email="intruder@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + tool = store.create_tool( + tool_key="shell.exec", + name="Shell Exec", + description="Run shell commands.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["shell"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=[], + risk_hints=[], + metadata={"transport": "local"}, + ) + store.create_policy( + name="Require shell approval", + action="tool.run", + scope="workspace", + effect="require_approval", + priority=10, + active=True, + conditions={"tool_key": "shell.exec"}, + required_consents=[], + ) + + _, request_payload = invoke_request( + "POST", + "/v0/approvals/requests", + payload={ + "user_id": str(owner["user_id"]), + "thread_id": str(owner["thread_id"]), + "tool_id": str(tool["id"]), + "action": "tool.run", + "scope": "workspace", + "attributes": {"command": "ls"}, + }, + ) + approval_id = request_payload["approval"]["id"] + + first_approve_status, _ = invoke_request( + "POST", + f"/v0/approvals/{approval_id}/approve", + payload={"user_id": str(owner["user_id"])}, + ) + duplicate_status, duplicate_payload = invoke_request( + "POST", + f"/v0/approvals/{approval_id}/approve", + payload={"user_id": str(owner["user_id"])}, + ) + conflict_status, conflict_payload = invoke_request( + "POST", + f"/v0/approvals/{approval_id}/reject", + payload={"user_id": str(owner["user_id"])}, + ) + intruder_status, intruder_payload = invoke_request( + "POST", + f"/v0/approvals/{approval_id}/reject", + payload={"user_id": str(intruder["user_id"])}, + ) + + assert first_approve_status == 200 + assert duplicate_status == 409 + assert duplicate_payload == {"detail": f"approval {approval_id} was already approved"} + assert conflict_status == 409 + assert conflict_payload == { + "detail": f"approval {approval_id} was already approved and cannot be rejected" + } + assert intruder_status == 404 + assert intruder_payload == {"detail": f"approval {approval_id} was not found"} + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + approval = store.get_approval_optional(UUID(approval_id)) + with conn.cursor() as cur: + cur.execute( + """ + SELECT id, limits + FROM traces + WHERE 
thread_id = %s + AND kind = 'approval.resolve' + ORDER BY created_at ASC, id ASC + """, + (owner["thread_id"],), + ) + trace_rows = cur.fetchall() + duplicate_trace = trace_rows[-2] + conflict_trace = trace_rows[-1] + duplicate_events = store.list_trace_events(duplicate_trace["id"]) + conflict_events = store.list_trace_events(conflict_trace["id"]) + + assert approval is not None + assert approval["status"] == "approved" + assert duplicate_trace["limits"] == { + "order": ["created_at_asc", "id_asc"], + "requested_action": "approve", + "outcome": "duplicate_rejected", + } + assert [event["kind"] for event in duplicate_events] == [ + "approval.resolution.request", + "approval.resolution.state", + "approval.resolution.summary", + "task.lifecycle.state", + "task.lifecycle.summary", + "task.step.lifecycle.state", + "task.step.lifecycle.summary", + ] + assert duplicate_events[1]["payload"] == { + "approval_id": approval_id, + "task_step_id": str(approval["task_step_id"]), + "requested_action": "approve", + "previous_status": "approved", + "outcome": "duplicate_rejected", + "current_status": "approved", + "resolved_at": approval["resolved_at"].isoformat(), + "resolved_by_user_id": str(owner["user_id"]), + } + assert conflict_trace["limits"] == { + "order": ["created_at_asc", "id_asc"], + "requested_action": "reject", + "outcome": "conflict_rejected", + } + assert [event["kind"] for event in conflict_events] == [ + "approval.resolution.request", + "approval.resolution.state", + "approval.resolution.summary", + "task.lifecycle.state", + "task.lifecycle.summary", + "task.step.lifecycle.state", + "task.step.lifecycle.summary", + ] + assert conflict_events[1]["payload"] == { + "approval_id": approval_id, + "task_step_id": str(approval["task_step_id"]), + "requested_action": "reject", + "previous_status": "approved", + "outcome": "conflict_rejected", + "current_status": "approved", + "resolved_at": approval["resolved_at"].isoformat(), + "resolved_by_user_id": str(owner["user_id"]), + } + + +def test_approval_resolution_rejects_inconsistent_linkage_without_mutating_task_steps( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner-boundary@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + tool = store.create_tool( + tool_key="proxy.echo", + name="Proxy Echo", + description="Deterministic proxy handler.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["proxy"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=[], + risk_hints=[], + metadata={"transport": "proxy"}, + ) + store.create_policy( + name="Require proxy approval", + action="tool.run", + scope="workspace", + effect="require_approval", + priority=10, + active=True, + conditions={"tool_key": "proxy.echo"}, + required_consents=[], + ) + + _, request_payload = invoke_request( + "POST", + "/v0/approvals/requests", + payload={ + "user_id": str(owner["user_id"]), + "thread_id": str(owner["thread_id"]), + "tool_id": str(tool["id"]), + "action": "tool.run", + "scope": "workspace", + "attributes": {"message": "initial"}, + }, + ) + approve_status, approve_payload = invoke_request( + "POST", + f"/v0/approvals/{request_payload['approval']['id']}/approve", + payload={"user_id": str(owner["user_id"])}, + ) + assert approve_status == 200 + 
execute_status, execute_payload = invoke_request( + "POST", + f"/v0/approvals/{request_payload['approval']['id']}/execute", + payload={"user_id": str(owner["user_id"])}, + ) + assert execute_status == 200 + + detail_status, detail_payload = invoke_request( + "GET", + f"/v0/tasks/{request_payload['task']['id']}", + query_params={"user_id": str(owner["user_id"])}, + ) + step_list_status, step_list_payload = invoke_request( + "GET", + f"/v0/tasks/{request_payload['task']['id']}/steps", + query_params={"user_id": str(owner["user_id"])}, + ) + assert detail_status == 200 + assert step_list_status == 200 + initial_execution_id = detail_payload["task"]["latest_execution_id"] + assert initial_execution_id is not None + + create_step_status, create_step_payload = invoke_request( + "POST", + f"/v0/tasks/{request_payload['task']['id']}/steps", + payload={ + "user_id": str(owner["user_id"]), + "kind": "governed_request", + "status": "created", + "request": { + "thread_id": str(owner["thread_id"]), + "tool_id": str(tool["id"]), + "action": "tool.run", + "scope": "workspace", + "attributes": {"message": "step-2"}, + }, + "outcome": { + "routing_decision": "approval_required", + "approval_status": None, + "approval_id": None, + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + }, + "lineage": { + "parent_step_id": step_list_payload["items"][0]["id"], + "source_approval_id": request_payload["approval"]["id"], + "source_execution_id": initial_execution_id, + }, + }, + ) + assert create_step_status == 201 + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + conn.execute( + "UPDATE approvals SET task_step_id = %s WHERE id = %s", + ( + create_step_payload["task_step"]["id"], + request_payload["approval"]["id"], + ), + ) + + boundary_status, boundary_payload = invoke_request( + "POST", + f"/v0/approvals/{request_payload['approval']['id']}/approve", + payload={"user_id": str(owner["user_id"])}, + ) + + assert boundary_status == 409 + assert boundary_payload == { + "detail": ( + f"approval {request_payload['approval']['id']} is inconsistent with linked task step " + f"{create_step_payload['task_step']['id']}" + ) + } + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + task = store.get_task_optional(UUID(request_payload["task"]["id"])) + task_steps = store.list_task_steps_for_task(UUID(request_payload["task"]["id"])) + approval = store.get_approval_optional(UUID(request_payload["approval"]["id"])) + approval_resolve_traces = store.conn.execute( + """ + SELECT id + FROM traces + WHERE thread_id = %s + AND kind = 'approval.resolve' + ORDER BY created_at ASC, id ASC + """, + (owner["thread_id"],), + ).fetchall() + + assert task is not None + assert approval is not None + assert task["status"] == "pending_approval" + assert task["latest_approval_id"] == UUID(request_payload["approval"]["id"]) + assert task["latest_execution_id"] is None + assert len(task_steps) == 2 + assert task_steps[0]["status"] == "executed" + assert task_steps[0]["trace_id"] == UUID(execute_payload["trace"]["trace_id"]) + assert task_steps[0]["outcome"]["execution_id"] == initial_execution_id + assert task_steps[1]["status"] == "created" + assert task_steps[1]["id"] == UUID(create_step_payload["task_step"]["id"]) + assert task_steps[1]["trace_kind"] == "task.step.continuation" + assert approval["status"] == "approved" + assert approval["task_step_id"] == UUID(create_step_payload["task_step"]["id"]) + assert 
len(approval_resolve_traces) == 1 diff --git a/tests/integration/test_context_compile.py b/tests/integration/test_context_compile.py new file mode 100644 index 0000000..f86bfe7 --- /dev/null +++ b/tests/integration/test_context_compile.py @@ -0,0 +1,890 @@ +from __future__ import annotations + +import json +from datetime import UTC, datetime +from typing import Any +from uuid import UUID, uuid4 + +import anyio +import psycopg +import pytest + +import apps.api.src.alicebot_api.main as main_module +from apps.api.src.alicebot_api.config import Settings +from alicebot_api.db import user_connection +from alicebot_api.store import ContinuityStore + + +def invoke_compile_context(payload: dict[str, Any]) -> tuple[int, dict[str, Any]]: + messages: list[dict[str, object]] = [] + encoded_body = json.dumps(payload).encode() + request_received = False + + async def receive() -> dict[str, object]: + nonlocal request_received + if request_received: + return {"type": "http.disconnect"} + + request_received = True + return {"type": "http.request", "body": encoded_body, "more_body": False} + + async def send(message: dict[str, object]) -> None: + messages.append(message) + + scope = { + "type": "http", + "asgi": {"version": "3.0"}, + "http_version": "1.1", + "method": "POST", + "scheme": "http", + "path": "/v0/context/compile", + "raw_path": b"/v0/context/compile", + "query_string": b"", + "headers": [(b"content-type", b"application/json")], + "client": ("testclient", 50000), + "server": ("testserver", 80), + "root_path": "", + } + + anyio.run(main_module.app, scope, receive, send) + + start_message = next(message for message in messages if message["type"] == "http.response.start") + body = b"".join( + message.get("body", b"") + for message in messages + if message["type"] == "http.response.body" + ) + return start_message["status"], json.loads(body) + + +def seed_traceable_thread( + database_url: str, + *, + email: str = "owner@example.com", + display_name: str = "Owner", +) -> dict[str, object]: + user_id = uuid4() + included_edge_valid_from = datetime(2026, 3, 12, 10, 0, tzinfo=UTC) + + with user_connection(database_url, user_id) as conn: + store = ContinuityStore(conn) + store.create_user(user_id, email, display_name) + thread = store.create_thread("Context thread") + first_session = store.create_session(thread["id"], status="complete") + second_session = store.create_session(thread["id"], status="active") + event_ids = [ + store.append_event(thread["id"], first_session["id"], "message.user", {"text": "old"})["id"], + store.append_event(thread["id"], second_session["id"], "message.assistant", {"text": "newer"})["id"], + store.append_event(thread["id"], second_session["id"], "message.user", {"text": "newest"})["id"], + ] + breakfast_memory = store.create_memory( + memory_key="user.preference.breakfast", + value={"likes": "toast"}, + status="active", + source_event_ids=[str(event_ids[0])], + ) + coffee_memory = store.create_memory( + memory_key="user.preference.coffee", + value={"likes": "oat milk"}, + status="active", + source_event_ids=[str(event_ids[1])], + ) + deleted_memory = store.create_memory( + memory_key="user.preference.old", + value={"likes": "black"}, + status="active", + source_event_ids=[str(event_ids[1])], + ) + deleted_memory = store.update_memory( + memory_id=deleted_memory["id"], + value=deleted_memory["value"], + status="deleted", + source_event_ids=[str(event_ids[2])], + ) + person = store.create_entity( + entity_type="person", + name="Samir", + 
source_memory_ids=[str(breakfast_memory["id"])], + ) + merchant = store.create_entity( + entity_type="merchant", + name="Neighborhood Cafe", + source_memory_ids=[str(coffee_memory["id"])], + ) + project = store.create_entity( + entity_type="project", + name="AliceBot", + source_memory_ids=[str(breakfast_memory["id"]), str(coffee_memory["id"])], + ) + excluded_edge = store.create_entity_edge( + from_entity_id=person["id"], + to_entity_id=project["id"], + relationship_type="visited_by", + valid_from=None, + valid_to=None, + source_memory_ids=[str(breakfast_memory["id"])], + ) + included_edge = store.create_entity_edge( + from_entity_id=project["id"], + to_entity_id=merchant["id"], + relationship_type="depends_on", + valid_from=included_edge_valid_from, + valid_to=None, + source_memory_ids=[str(coffee_memory["id"])], + ) + ignored_when_project_only_edge = store.create_entity_edge( + from_entity_id=person["id"], + to_entity_id=merchant["id"], + relationship_type="introduced_to", + valid_from=None, + valid_to=None, + source_memory_ids=[str(breakfast_memory["id"])], + ) + entities = store.list_entities() + entity_edges = store.list_entity_edges_for_entities([person["id"], merchant["id"], project["id"]]) + + return { + "user_id": user_id, + "thread_id": thread["id"], + "event_ids": event_ids, + "memories": { + "breakfast": breakfast_memory, + "coffee": coffee_memory, + "deleted": deleted_memory, + }, + "entities": entities, + "entity_edges": entity_edges, + "project_only_candidate_edges": { + "excluded": excluded_edge, + "included": included_edge, + "ignored": ignored_when_project_only_edge, + }, + "included_edge_valid_from": included_edge_valid_from, + } + + +def seed_thread_with_updated_active_memory(database_url: str) -> dict[str, object]: + user_id = uuid4() + included_edge_valid_from = datetime(2026, 3, 12, 11, 0, tzinfo=UTC) + + with user_connection(database_url, user_id) as conn: + store = ContinuityStore(conn) + store.create_user(user_id, "owner@example.com", "Owner") + thread = store.create_thread("Updated memory thread") + session = store.create_session(thread["id"], status="active") + event_ids = [ + store.append_event( + thread["id"], + session["id"], + "message.user", + {"text": "baseline memory evidence"}, + )["id"], + store.append_event( + thread["id"], + session["id"], + "message.user", + {"text": "updated memory evidence"}, + )["id"], + ] + store.create_memory( + memory_key="user.preference.breakfast", + value={"likes": "toast"}, + status="active", + source_event_ids=[str(event_ids[0])], + ) + coffee_memory = store.create_memory( + memory_key="user.preference.coffee", + value={"likes": "black"}, + status="active", + source_event_ids=[str(event_ids[0])], + ) + store.update_memory( + memory_id=coffee_memory["id"], + value={"likes": "oat milk"}, + status="active", + source_event_ids=[str(event_ids[1])], + ) + routine = store.create_entity( + entity_type="routine", + name="Breakfast", + source_memory_ids=[str(coffee_memory["id"])], + ) + project = store.create_entity( + entity_type="project", + name="AliceBot", + source_memory_ids=[str(coffee_memory["id"])], + ) + included_edge = store.create_entity_edge( + from_entity_id=project["id"], + to_entity_id=routine["id"], + relationship_type="references", + valid_from=included_edge_valid_from, + valid_to=None, + source_memory_ids=[str(coffee_memory["id"])], + ) + store.create_entity_edge( + from_entity_id=routine["id"], + to_entity_id=routine["id"], + relationship_type="superseded_by", + valid_from=None, + valid_to=None, + 
source_memory_ids=[str(coffee_memory["id"])], + ) + entities = store.list_entities() + + return { + "user_id": user_id, + "thread_id": thread["id"], + "event_ids": event_ids, + "entities": entities, + "included_edge": included_edge, + "included_edge_valid_from": included_edge_valid_from, + } + + +def seed_embedding_config_for_user( + database_url: str, + *, + user_id: UUID, + dimensions: int = 3, +) -> UUID: + with user_connection(database_url, user_id) as conn: + store = ContinuityStore(conn) + config = store.create_embedding_config( + provider="openai", + model="text-embedding-3-large", + version="2026-03-12", + dimensions=dimensions, + status="active", + metadata={"task": "compile_semantic_retrieval"}, + ) + return config["id"] + + +def seed_memory_embedding_for_user( + database_url: str, + *, + user_id: UUID, + memory_id: UUID, + embedding_config_id: UUID, + vector: list[float], +) -> None: + with user_connection(database_url, user_id) as conn: + ContinuityStore(conn).create_memory_embedding( + memory_id=memory_id, + embedding_config_id=embedding_config_id, + dimensions=len(vector), + vector=vector, + ) + + +def test_compile_context_endpoint_persists_trace_and_trace_events(migrated_database_urls, monkeypatch) -> None: + seeded = seed_traceable_thread(migrated_database_urls["app"]) + user_id = seeded["user_id"] + thread_id = seeded["thread_id"] + event_ids = seeded["event_ids"] + entities = seeded["entities"] + included_entity = entities[-1] + project_only_candidate_edges = seeded["project_only_candidate_edges"] + included_entity_edge = project_only_candidate_edges["included"] + excluded_entity_edge = project_only_candidate_edges["excluded"] + ignored_entity_edge = project_only_candidate_edges["ignored"] + included_edge_valid_from = seeded["included_edge_valid_from"] + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + status_code, payload = invoke_compile_context( + { + "user_id": str(user_id), + "thread_id": str(thread_id), + "max_sessions": 1, + "max_events": 1, + "max_memories": 1, + "max_entities": 1, + "max_entity_edges": 1, + } + ) + + assert status_code == 200 + assert payload["trace_event_count"] > 0 + assert payload["context_pack"]["limits"] == { + "max_sessions": 1, + "max_events": 1, + "max_memories": 1, + "max_entities": 1, + "max_entity_edges": 1, + } + assert [session["status"] for session in payload["context_pack"]["sessions"]] == ["active"] + assert [event["sequence_no"] for event in payload["context_pack"]["events"]] == [3] + assert payload["context_pack"]["memories"] == [ + { + "id": payload["context_pack"]["memories"][0]["id"], + "memory_key": "user.preference.coffee", + "value": {"likes": "oat milk"}, + "status": "active", + "source_event_ids": [str(event_ids[1])], + "created_at": payload["context_pack"]["memories"][0]["created_at"], + "updated_at": payload["context_pack"]["memories"][0]["updated_at"], + "source_provenance": {"sources": ["symbolic"], "semantic_score": None}, + } + ] + assert payload["context_pack"]["memory_summary"] == { + "candidate_count": 2, + "included_count": 1, + "excluded_deleted_count": 1, + "excluded_limit_count": 0, + "hybrid_retrieval": { + "requested": False, + "embedding_config_id": None, + "query_vector_dimensions": 0, + "semantic_limit": 0, + "symbolic_selected_count": 1, + "semantic_selected_count": 0, + "merged_candidate_count": 1, + "deduplicated_count": 0, + "included_symbolic_only_count": 1, + "included_semantic_only_count": 0, + 
"included_dual_source_count": 0, + "similarity_metric": None, + "source_precedence": ["symbolic", "semantic"], + "symbolic_order": ["updated_at_asc", "created_at_asc", "id_asc"], + "semantic_order": ["score_desc", "created_at_asc", "id_asc"], + }, + } + assert payload["context_pack"]["entities"] == [ + { + "id": str(included_entity["id"]), + "entity_type": included_entity["entity_type"], + "name": included_entity["name"], + "source_memory_ids": included_entity["source_memory_ids"], + "created_at": included_entity["created_at"].isoformat(), + } + ] + assert payload["context_pack"]["entity_summary"] == { + "candidate_count": 3, + "included_count": 1, + "excluded_limit_count": 2, + } + assert payload["context_pack"]["entity_edges"] == [ + { + "id": str(included_entity_edge["id"]), + "from_entity_id": str(included_entity_edge["from_entity_id"]), + "to_entity_id": str(included_entity_edge["to_entity_id"]), + "relationship_type": included_entity_edge["relationship_type"], + "valid_from": included_edge_valid_from.isoformat(), + "valid_to": None, + "source_memory_ids": included_entity_edge["source_memory_ids"], + "created_at": payload["context_pack"]["entity_edges"][0]["created_at"], + } + ] + assert payload["context_pack"]["entity_edge_summary"] == { + "anchor_entity_count": 1, + "candidate_count": 2, + "included_count": 1, + "excluded_limit_count": 1, + } + + trace_id = UUID(payload["trace_id"]) + with user_connection(migrated_database_urls["app"], user_id) as conn: + store = ContinuityStore(conn) + trace = store.get_trace(trace_id) + trace_events = store.list_trace_events(trace_id) + + assert trace["thread_id"] == thread_id + assert trace["kind"] == "context.compile" + assert trace["limits"] == { + "max_sessions": 1, + "max_events": 1, + "max_memories": 1, + "max_entities": 1, + "max_entity_edges": 1, + } + assert trace_events[0]["kind"] == "context.included" + assert trace_events[-1]["kind"] == "context.summary" + assert any( + event["payload"]["reason"] == "session_limit_exceeded" + for event in trace_events + if event["kind"] == "context.excluded" + ) + assert any( + event["payload"]["reason"] == "event_limit_exceeded" + for event in trace_events + if event["kind"] == "context.excluded" + ) + assert any( + event["payload"]["reason"] == "hybrid_memory_deleted" + for event in trace_events + if event["kind"] == "context.excluded" + ) + assert any( + event["payload"]["reason"] == "within_hybrid_memory_limit" + and event["payload"]["memory_key"] == "user.preference.coffee" + and event["payload"]["selected_sources"] == ["symbolic"] + for event in trace_events + if event["kind"] == "context.included" + ) + assert any( + event["payload"]["reason"] == "entity_limit_exceeded" + for event in trace_events + if event["kind"] == "context.excluded" + ) + assert any( + event["payload"]["reason"] == "within_entity_limit" + and event["payload"]["name"] == included_entity["name"] + and event["payload"]["record_entity_type"] == included_entity["entity_type"] + for event in trace_events + if event["kind"] == "context.included" + ) + assert any( + event["payload"]["reason"] == "entity_edge_limit_exceeded" + and event["payload"]["entity_id"] == str(excluded_entity_edge["id"]) + for event in trace_events + if event["kind"] == "context.excluded" + ) + assert any( + event["payload"]["reason"] == "within_entity_edge_limit" + and event["payload"]["entity_id"] == str(included_entity_edge["id"]) + and event["payload"]["valid_from"] == included_edge_valid_from.isoformat() + for event in trace_events + if event["kind"] == 
"context.included" + ) + assert all( + event["payload"].get("entity_id") != str(ignored_entity_edge["id"]) + for event in trace_events + ) + assert trace_events[-1]["payload"]["included_memory_count"] == 1 + assert trace_events[-1]["payload"]["excluded_deleted_memory_count"] == 1 + assert trace_events[-1]["payload"]["excluded_memory_limit_count"] == 0 + assert trace_events[-1]["payload"]["hybrid_memory_requested"] is False + assert trace_events[-1]["payload"]["hybrid_memory_candidate_count"] == 2 + assert trace_events[-1]["payload"]["hybrid_memory_merged_candidate_count"] == 1 + assert trace_events[-1]["payload"]["hybrid_memory_deduplicated_count"] == 0 + assert trace_events[-1]["payload"]["included_entity_count"] == 1 + assert trace_events[-1]["payload"]["excluded_entity_limit_count"] == 2 + assert trace_events[-1]["payload"]["included_entity_edge_count"] == 1 + assert trace_events[-1]["payload"]["excluded_entity_edge_limit_count"] == 1 + + with psycopg.connect(migrated_database_urls["admin"]) as conn: + with conn.cursor() as cur: + with pytest.raises(psycopg.Error, match="append-only"): + cur.execute("UPDATE trace_events SET kind = 'mutated' WHERE trace_id = %s", (trace_id,)) + + +def test_compile_context_prefers_updated_active_memory_within_same_transaction( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_thread_with_updated_active_memory(migrated_database_urls["app"]) + user_id = seeded["user_id"] + thread_id = seeded["thread_id"] + event_ids = seeded["event_ids"] + entities = seeded["entities"] + excluded_entity = entities[0] + included_edge = seeded["included_edge"] + included_edge_valid_from = seeded["included_edge_valid_from"] + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + status_code, payload = invoke_compile_context( + { + "user_id": str(user_id), + "thread_id": str(thread_id), + "max_sessions": 1, + "max_events": 2, + "max_memories": 1, + "max_entities": 1, + "max_entity_edges": 1, + } + ) + + assert status_code == 200 + assert payload["trace_event_count"] > 0 + assert payload["context_pack"]["memories"] == [ + { + "id": payload["context_pack"]["memories"][0]["id"], + "memory_key": "user.preference.coffee", + "value": {"likes": "oat milk"}, + "status": "active", + "source_event_ids": [str(event_ids[1])], + "created_at": payload["context_pack"]["memories"][0]["created_at"], + "updated_at": payload["context_pack"]["memories"][0]["updated_at"], + "source_provenance": {"sources": ["symbolic"], "semantic_score": None}, + } + ] + assert payload["context_pack"]["memory_summary"] == { + "candidate_count": 1, + "included_count": 1, + "excluded_deleted_count": 0, + "excluded_limit_count": 0, + "hybrid_retrieval": { + "requested": False, + "embedding_config_id": None, + "query_vector_dimensions": 0, + "semantic_limit": 0, + "symbolic_selected_count": 1, + "semantic_selected_count": 0, + "merged_candidate_count": 1, + "deduplicated_count": 0, + "included_symbolic_only_count": 1, + "included_semantic_only_count": 0, + "included_dual_source_count": 0, + "similarity_metric": None, + "source_precedence": ["symbolic", "semantic"], + "symbolic_order": ["updated_at_asc", "created_at_asc", "id_asc"], + "semantic_order": ["score_desc", "created_at_asc", "id_asc"], + }, + } + assert payload["context_pack"]["entity_summary"] == { + "candidate_count": 2, + "included_count": 1, + "excluded_limit_count": 1, + } + assert payload["context_pack"]["entity_edges"] == [ + { + "id": str(included_edge["id"]), + 
"from_entity_id": str(included_edge["from_entity_id"]), + "to_entity_id": str(included_edge["to_entity_id"]), + "relationship_type": included_edge["relationship_type"], + "valid_from": included_edge_valid_from.isoformat(), + "valid_to": None, + "source_memory_ids": included_edge["source_memory_ids"], + "created_at": payload["context_pack"]["entity_edges"][0]["created_at"], + } + ] + assert payload["context_pack"]["entity_edge_summary"] == { + "anchor_entity_count": 1, + "candidate_count": 1, + "included_count": 1, + "excluded_limit_count": 0, + } + + trace_id = UUID(payload["trace_id"]) + with user_connection(migrated_database_urls["app"], user_id) as conn: + trace_events = ContinuityStore(conn).list_trace_events(trace_id) + + assert any( + event["payload"]["reason"] == "within_hybrid_memory_limit" + and event["payload"]["memory_key"] == "user.preference.coffee" + for event in trace_events + if event["kind"] == "context.included" + ) + assert any( + event["payload"]["reason"] == "entity_limit_exceeded" + and event["payload"]["name"] == excluded_entity["name"] + for event in trace_events + if event["kind"] == "context.excluded" + ) + assert any( + event["payload"]["reason"] == "within_entity_edge_limit" + and event["payload"]["entity_id"] == str(included_edge["id"]) + for event in trace_events + if event["kind"] == "context.included" + ) + + +def test_compile_context_endpoint_merges_hybrid_memory_provenance_and_trace_events( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_traceable_thread(migrated_database_urls["app"]) + user_id = seeded["user_id"] + thread_id = seeded["thread_id"] + memories = seeded["memories"] + config_id = seed_embedding_config_for_user( + migrated_database_urls["app"], + user_id=user_id, + ) + seed_memory_embedding_for_user( + migrated_database_urls["app"], + user_id=user_id, + memory_id=memories["breakfast"]["id"], + embedding_config_id=config_id, + vector=[1.0, 0.0, 0.0], + ) + seed_memory_embedding_for_user( + migrated_database_urls["app"], + user_id=user_id, + memory_id=memories["coffee"]["id"], + embedding_config_id=config_id, + vector=[1.0, 0.0, 0.0], + ) + seed_memory_embedding_for_user( + migrated_database_urls["app"], + user_id=user_id, + memory_id=memories["deleted"]["id"], + embedding_config_id=config_id, + vector=[1.0, 0.0, 0.0], + ) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + status_code, payload = invoke_compile_context( + { + "user_id": str(user_id), + "thread_id": str(thread_id), + "max_sessions": 1, + "max_events": 1, + "max_memories": 1, + "max_entities": 1, + "max_entity_edges": 1, + "semantic": { + "embedding_config_id": str(config_id), + "query_vector": [1.0, 0.0, 0.0], + "limit": 2, + }, + } + ) + + assert status_code == 200 + assert payload["trace_event_count"] > 0 + assert payload["context_pack"]["memories"] == [ + { + "id": str(memories["coffee"]["id"]), + "memory_key": "user.preference.coffee", + "value": {"likes": "oat milk"}, + "status": "active", + "source_event_ids": memories["coffee"]["source_event_ids"], + "created_at": memories["coffee"]["created_at"].isoformat(), + "updated_at": memories["coffee"]["updated_at"].isoformat(), + "source_provenance": { + "sources": ["symbolic", "semantic"], + "semantic_score": 1.0, + }, + } + ] + assert payload["context_pack"]["memory_summary"] == { + "candidate_count": 3, + "included_count": 1, + "excluded_deleted_count": 1, + "excluded_limit_count": 1, + "hybrid_retrieval": { + "requested": True, + 
"embedding_config_id": str(config_id), + "query_vector_dimensions": 3, + "semantic_limit": 2, + "symbolic_selected_count": 1, + "semantic_selected_count": 2, + "merged_candidate_count": 2, + "deduplicated_count": 1, + "included_symbolic_only_count": 0, + "included_semantic_only_count": 0, + "included_dual_source_count": 1, + "similarity_metric": "cosine_similarity", + "source_precedence": ["symbolic", "semantic"], + "symbolic_order": ["updated_at_asc", "created_at_asc", "id_asc"], + "semantic_order": ["score_desc", "created_at_asc", "id_asc"], + }, + } + + trace_id = UUID(payload["trace_id"]) + with user_connection(migrated_database_urls["app"], user_id) as conn: + trace_events = ContinuityStore(conn).list_trace_events(trace_id) + + assert any( + event["payload"]["reason"] == "within_hybrid_memory_limit" + and event["payload"]["entity_id"] == str(memories["coffee"]["id"]) + and event["payload"]["embedding_config_id"] == str(config_id) + and event["payload"]["semantic_score"] == 1.0 + and event["payload"]["selected_sources"] == ["symbolic", "semantic"] + for event in trace_events + if event["kind"] == "context.included" + ) + assert any( + event["payload"]["reason"] == "hybrid_memory_deduplicated" + and event["payload"]["entity_id"] == str(memories["coffee"]["id"]) + and event["payload"]["embedding_config_id"] == str(config_id) + and event["payload"]["semantic_score"] == 1.0 + for event in trace_events + if event["kind"] == "context.included" + ) + assert any( + event["payload"]["reason"] == "hybrid_memory_limit_exceeded" + and event["payload"]["entity_id"] == str(memories["breakfast"]["id"]) + and event["payload"]["embedding_config_id"] == str(config_id) + and event["payload"]["semantic_score"] == 1.0 + and event["payload"]["selected_sources"] == ["semantic"] + for event in trace_events + if event["kind"] == "context.excluded" + ) + assert any( + event["payload"]["reason"] == "hybrid_memory_deleted" + and event["payload"]["entity_id"] == str(memories["deleted"]["id"]) + and event["payload"]["embedding_config_id"] == str(config_id) + and event["payload"]["semantic_score"] is None + and event["payload"]["selected_sources"] == ["symbolic"] + for event in trace_events + if event["kind"] == "context.excluded" + ) + assert trace_events[-1]["payload"]["hybrid_memory_requested"] is True + assert trace_events[-1]["payload"]["hybrid_memory_candidate_count"] == 3 + assert trace_events[-1]["payload"]["hybrid_memory_merged_candidate_count"] == 2 + assert trace_events[-1]["payload"]["hybrid_memory_deduplicated_count"] == 1 + assert trace_events[-1]["payload"]["included_dual_source_memory_count"] == 1 + + +def test_compile_context_semantic_validation_rejects_missing_config_dimension_mismatch_and_cross_user_access( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_traceable_thread(migrated_database_urls["app"]) + intruder = seed_traceable_thread( + migrated_database_urls["app"], + email="intruder@example.com", + display_name="Intruder", + ) + owner_config_id = seed_embedding_config_for_user( + migrated_database_urls["app"], + user_id=owner["user_id"], + ) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + missing_status, missing_payload = invoke_compile_context( + { + "user_id": str(owner["user_id"]), + "thread_id": str(owner["thread_id"]), + "semantic": { + "embedding_config_id": str(uuid4()), + "query_vector": [1.0, 0.0, 0.0], + "limit": 1, + }, + } + ) + mismatch_status, mismatch_payload = 
invoke_compile_context( + { + "user_id": str(owner["user_id"]), + "thread_id": str(owner["thread_id"]), + "semantic": { + "embedding_config_id": str(owner_config_id), + "query_vector": [1.0, 0.0], + "limit": 1, + }, + } + ) + cross_user_status, cross_user_payload = invoke_compile_context( + { + "user_id": str(intruder["user_id"]), + "thread_id": str(intruder["thread_id"]), + "semantic": { + "embedding_config_id": str(owner_config_id), + "query_vector": [1.0, 0.0, 0.0], + "limit": 1, + }, + } + ) + + assert missing_status == 400 + assert missing_payload["detail"].startswith( + "embedding_config_id must reference an existing embedding config owned by the user" + ) + assert mismatch_status == 400 + assert mismatch_payload["detail"] == "query_vector length must match embedding config dimensions (3): 2" + assert cross_user_status == 400 + assert cross_user_payload["detail"] == ( + "embedding_config_id must reference an existing embedding config owned by the user: " + f"{owner_config_id}" + ) + + +def test_traces_and_trace_events_respect_per_user_isolation(migrated_database_urls, monkeypatch) -> None: + seeded = seed_traceable_thread(migrated_database_urls["app"]) + owner_id = seeded["user_id"] + thread_id = seeded["thread_id"] + owner_event_ids = seeded["event_ids"] + owner_entities = seeded["entities"] + owner_entity_edges = seeded["entity_edges"] + intruder_id = uuid4() + with user_connection(migrated_database_urls["app"], intruder_id) as conn: + store = ContinuityStore(conn) + store.create_user(intruder_id, "intruder@example.com", "Intruder") + intruder_thread = store.create_thread("Intruder thread") + intruder_session = store.create_session(intruder_thread["id"], status="active") + intruder_event = store.append_event( + intruder_thread["id"], + intruder_session["id"], + "message.user", + {"text": "intruder memory"}, + ) + store.create_memory( + memory_key="user.preference.coffee", + value={"likes": "black"}, + status="active", + source_event_ids=[str(intruder_event["id"])], + ) + intruder_memory = store.create_memory( + memory_key="user.preference.tea", + value={"likes": "green"}, + status="active", + source_event_ids=[str(intruder_event["id"])], + ) + store.create_entity( + entity_type="merchant", + name="Intruder Cafe", + source_memory_ids=[str(intruder_memory["id"])], + ) + intruder_project = store.create_entity( + entity_type="project", + name="Intruder Project", + source_memory_ids=[str(intruder_memory["id"])], + ) + store.create_entity_edge( + from_entity_id=intruder_project["id"], + to_entity_id=store.list_entities()[0]["id"], + relationship_type="hidden_from_owner", + valid_from=None, + valid_to=None, + source_memory_ids=[str(intruder_memory["id"])], + ) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + status_code, payload = invoke_compile_context( + { + "user_id": str(owner_id), + "thread_id": str(thread_id), + } + ) + + assert status_code == 200 + trace_id = UUID(payload["trace_id"]) + assert [memory["source_event_ids"] for memory in payload["context_pack"]["memories"]] == [ + [str(owner_event_ids[0])], + [str(owner_event_ids[1])], + ] + assert [memory["source_provenance"] for memory in payload["context_pack"]["memories"]] == [ + {"sources": ["symbolic"], "semantic_score": None}, + {"sources": ["symbolic"], "semantic_score": None}, + ] + assert [entity["id"] for entity in payload["context_pack"]["entities"]] == [ + str(entity["id"]) for entity in owner_entities + ] + assert [edge["id"] for edge in 
payload["context_pack"]["entity_edges"]] == [ + str(edge["id"]) for edge in owner_entity_edges + ] + + with user_connection(migrated_database_urls["app"], intruder_id) as conn: + store = ContinuityStore(conn) + with conn.cursor() as cur: + cur.execute("SELECT COUNT(*) AS count FROM traces WHERE id = %s", (trace_id,)) + trace_count = cur.fetchone() + cur.execute("SELECT COUNT(*) AS count FROM trace_events WHERE trace_id = %s", (trace_id,)) + trace_event_count = cur.fetchone() + + assert trace_count["count"] == 0 + assert trace_event_count["count"] == 0 + assert store.list_trace_events(trace_id) == [] diff --git a/tests/integration/test_continuity_store.py b/tests/integration/test_continuity_store.py new file mode 100644 index 0000000..9561563 --- /dev/null +++ b/tests/integration/test_continuity_store.py @@ -0,0 +1,244 @@ +from __future__ import annotations + +from concurrent.futures import ThreadPoolExecutor, TimeoutError +from uuid import uuid4 + +import psycopg +from psycopg.rows import dict_row +import pytest + +from alicebot_api.db import set_current_user, user_connection +from alicebot_api.store import ContinuityStore + + +def test_thread_session_and_event_persistence(migrated_database_urls): + user_id = uuid4() + + with user_connection(migrated_database_urls["app"], user_id) as conn: + store = ContinuityStore(conn) + user = store.create_user(user_id, "owner@example.com", "Owner") + thread = store.create_thread("Starter thread") + session = store.create_session(thread["id"]) + first_event = store.append_event( + thread["id"], + session["id"], + "message.user", + {"text": "hello"}, + ) + second_event = store.append_event( + thread["id"], + session["id"], + "message.assistant", + {"text": "hi"}, + ) + events = store.list_thread_events(thread["id"]) + + assert user["id"] == user_id + assert session["thread_id"] == thread["id"] + assert [first_event["sequence_no"], second_event["sequence_no"]] == [1, 2] + assert [event["kind"] for event in events] == ["message.user", "message.assistant"] + assert events[0]["payload"]["text"] == "hello" + + with psycopg.connect(migrated_database_urls["admin"]) as conn: + with pytest.raises(psycopg.Error, match="append-only"): + with conn.cursor() as cur: + cur.execute( + "UPDATE events SET kind = 'message.mutated' WHERE id = %s", + (first_event["id"],), + ) + + +def test_event_deletes_are_rejected_at_database_level(migrated_database_urls): + user_id = uuid4() + + with user_connection(migrated_database_urls["app"], user_id) as conn: + store = ContinuityStore(conn) + store.create_user(user_id, "owner@example.com", "Owner") + thread = store.create_thread("Delete-protected thread") + session = store.create_session(thread["id"]) + event = store.append_event(thread["id"], session["id"], "message.user", {"text": "keep"}) + + with psycopg.connect(migrated_database_urls["admin"]) as conn: + with pytest.raises(psycopg.Error, match="append-only"): + with conn.cursor() as cur: + cur.execute("DELETE FROM events WHERE id = %s", (event["id"],)) + + +def test_continuity_rls_blocks_cross_user_access(migrated_database_urls): + owner_id = uuid4() + intruder_id = uuid4() + + with user_connection(migrated_database_urls["app"], owner_id) as owner_conn: + owner_store = ContinuityStore(owner_conn) + owner_store.create_user(owner_id, "owner@example.com", "Owner") + thread = owner_store.create_thread("Private thread") + session = owner_store.create_session(thread["id"]) + owner_store.append_event(thread["id"], session["id"], "message.user", {"text": "secret"}) + + with 
user_connection(migrated_database_urls["app"], intruder_id) as intruder_conn: + intruder_store = ContinuityStore(intruder_conn) + intruder_store.create_user(intruder_id, "intruder@example.com", "Intruder") + + with intruder_conn.cursor() as cur: + cur.execute("SELECT COUNT(*) AS count FROM users WHERE id = %s", (owner_id,)) + user_count_row = cur.fetchone() + cur.execute("SELECT COUNT(*) AS count FROM threads WHERE id = %s", (thread["id"],)) + thread_count_row = cur.fetchone() + cur.execute("SELECT COUNT(*) AS count FROM sessions WHERE id = %s", (session["id"],)) + session_count_row = cur.fetchone() + + visible_events = intruder_store.list_thread_events(thread["id"]) + + assert user_count_row["count"] == 0 + assert thread_count_row["count"] == 0 + assert session_count_row["count"] == 0 + assert visible_events == [] + + with pytest.raises(psycopg.Error): + intruder_store.append_event( + thread["id"], + None, + "message.user", + {"text": "tamper"}, + ) + + +def test_runtime_role_is_insert_select_only_for_continuity_tables(migrated_database_urls): + with psycopg.connect(migrated_database_urls["app"]) as conn: + with conn.cursor() as cur: + cur.execute( + """ + SELECT + has_table_privilege(current_user, 'users', 'SELECT'), + has_table_privilege(current_user, 'users', 'INSERT'), + has_table_privilege(current_user, 'users', 'UPDATE'), + has_table_privilege(current_user, 'threads', 'UPDATE'), + has_table_privilege(current_user, 'sessions', 'UPDATE'), + has_table_privilege(current_user, 'events', 'UPDATE'), + has_table_privilege(current_user, 'events', 'DELETE'), + has_table_privilege(current_user, 'traces', 'SELECT'), + has_table_privilege(current_user, 'traces', 'INSERT'), + has_table_privilege(current_user, 'traces', 'UPDATE'), + has_table_privilege(current_user, 'trace_events', 'SELECT'), + has_table_privilege(current_user, 'trace_events', 'INSERT'), + has_table_privilege(current_user, 'trace_events', 'UPDATE'), + has_table_privilege(current_user, 'trace_events', 'DELETE'), + has_table_privilege(current_user, 'consents', 'SELECT'), + has_table_privilege(current_user, 'consents', 'INSERT'), + has_table_privilege(current_user, 'consents', 'UPDATE'), + has_table_privilege(current_user, 'consents', 'DELETE'), + has_table_privilege(current_user, 'policies', 'SELECT'), + has_table_privilege(current_user, 'policies', 'INSERT'), + has_table_privilege(current_user, 'policies', 'UPDATE'), + has_table_privilege(current_user, 'policies', 'DELETE'), + has_table_privilege(current_user, 'tools', 'SELECT'), + has_table_privilege(current_user, 'tools', 'INSERT'), + has_table_privilege(current_user, 'tools', 'UPDATE'), + has_table_privilege(current_user, 'tools', 'DELETE') + """ + ) + ( + users_select, + users_insert, + users_update, + threads_update, + sessions_update, + events_update, + events_delete, + traces_select, + traces_insert, + traces_update, + trace_events_select, + trace_events_insert, + trace_events_update, + trace_events_delete, + consents_select, + consents_insert, + consents_update, + consents_delete, + policies_select, + policies_insert, + policies_update, + policies_delete, + tools_select, + tools_insert, + tools_update, + tools_delete, + ) = cur.fetchone() + + assert users_select is True + assert users_insert is True + assert users_update is False + assert threads_update is False + assert sessions_update is False + assert events_update is False + assert events_delete is False + assert traces_select is True + assert traces_insert is True + assert traces_update is False + assert 
trace_events_select is True + assert trace_events_insert is True + assert trace_events_update is False + assert trace_events_delete is False + assert consents_select is True + assert consents_insert is True + assert consents_update is True + assert consents_delete is False + assert policies_select is True + assert policies_insert is True + assert policies_update is False + assert policies_delete is False + assert tools_select is True + assert tools_insert is True + assert tools_update is False + assert tools_delete is False + + +def test_concurrent_event_appends_keep_monotonic_sequence_numbers(migrated_database_urls): + user_id = uuid4() + + with user_connection(migrated_database_urls["app"], user_id) as conn: + store = ContinuityStore(conn) + store.create_user(user_id, "owner@example.com", "Owner") + thread = store.create_thread("Concurrent thread") + session = store.create_session(thread["id"]) + + with ( + psycopg.connect(migrated_database_urls["app"], row_factory=dict_row) as first_conn, + psycopg.connect(migrated_database_urls["app"], row_factory=dict_row) as second_conn, + ): + set_current_user(first_conn, user_id) + set_current_user(second_conn, user_id) + + first_store = ContinuityStore(first_conn) + second_store = ContinuityStore(second_conn) + first_event = first_store.append_event( + thread["id"], + session["id"], + "message.user", + {"text": "first"}, + ) + + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit( + second_store.append_event, + thread["id"], + session["id"], + "message.assistant", + {"text": "second"}, + ) + + with pytest.raises(TimeoutError): + future.result(timeout=0.2) + + first_conn.commit() + second_event = future.result(timeout=5) + + second_conn.commit() + + with user_connection(migrated_database_urls["app"], user_id) as conn: + store = ContinuityStore(conn) + events = store.list_thread_events(thread["id"]) + + assert [first_event["sequence_no"], second_event["sequence_no"]] == [1, 2] + assert [event["sequence_no"] for event in events] == [1, 2] diff --git a/tests/integration/test_embeddings_api.py b/tests/integration/test_embeddings_api.py new file mode 100644 index 0000000..974c2c5 --- /dev/null +++ b/tests/integration/test_embeddings_api.py @@ -0,0 +1,793 @@ +from __future__ import annotations + +import json +from typing import Any +from urllib.parse import urlencode +from uuid import UUID, uuid4 + +import anyio + +import apps.api.src.alicebot_api.main as main_module +from apps.api.src.alicebot_api.config import Settings +from alicebot_api.contracts import MemoryCandidateInput +from alicebot_api.db import user_connection +from alicebot_api.memory import admit_memory_candidate +from alicebot_api.store import ContinuityStore + + +def invoke_request( + method: str, + path: str, + *, + query_params: dict[str, str] | None = None, + payload: dict[str, Any] | None = None, +) -> tuple[int, dict[str, Any]]: + messages: list[dict[str, object]] = [] + encoded_body = b"" if payload is None else json.dumps(payload).encode() + request_received = False + + async def receive() -> dict[str, object]: + nonlocal request_received + if request_received: + return {"type": "http.disconnect"} + + request_received = True + return {"type": "http.request", "body": encoded_body, "more_body": False} + + async def send(message: dict[str, object]) -> None: + messages.append(message) + + query_string = urlencode(query_params or {}).encode() + scope = { + "type": "http", + "asgi": {"version": "3.0"}, + "http_version": "1.1", + "method": method, + "scheme": 
"http", + "path": path, + "raw_path": path.encode(), + "query_string": query_string, + "headers": [(b"content-type", b"application/json")], + "client": ("testclient", 50000), + "server": ("testserver", 80), + "root_path": "", + } + + anyio.run(main_module.app, scope, receive, send) + + start_message = next(message for message in messages if message["type"] == "http.response.start") + body = b"".join( + message.get("body", b"") + for message in messages + if message["type"] == "http.response.body" + ) + return start_message["status"], json.loads(body) + + +def seed_user_with_memory(database_url: str, *, email: str) -> dict[str, object]: + user_id = uuid4() + + with user_connection(database_url, user_id) as conn: + store = ContinuityStore(conn) + store.create_user(user_id, email, email.split("@", 1)[0].title()) + thread = store.create_thread("Embedding source thread") + session = store.create_session(thread["id"], status="active") + event_id = store.append_event( + thread["id"], + session["id"], + "message.user", + {"text": "likes oat milk"}, + )["id"] + memory = admit_memory_candidate( + store, + user_id=user_id, + candidate=MemoryCandidateInput( + memory_key="user.preference.coffee", + value={"likes": "oat milk"}, + source_event_ids=(event_id,), + ), + ) + + return { + "user_id": user_id, + "memory_id": UUID(memory.memory["id"]), + } + + +def seed_embedding_config( + database_url: str, + *, + user_id: UUID, + provider: str, + model: str, + version: str, + dimensions: int, +) -> UUID: + with user_connection(database_url, user_id) as conn: + created = ContinuityStore(conn).create_embedding_config( + provider=provider, + model=model, + version=version, + dimensions=dimensions, + status="active", + metadata={"task": "memory_retrieval"}, + ) + return created["id"] + + +def seed_memory_with_embedding( + database_url: str, + *, + user_id: UUID, + memory_key: str, + value: dict[str, object], + embedding_config_id: UUID, + vector: list[float], + delete_requested: bool = False, +) -> UUID: + with user_connection(database_url, user_id) as conn: + store = ContinuityStore(conn) + thread = store.create_thread(f"Semantic retrieval thread for {memory_key}") + session = store.create_session(thread["id"], status="active") + event_id = store.append_event( + thread["id"], + session["id"], + "message.user", + {"memory_key": memory_key, "value": value}, + )["id"] + admitted = admit_memory_candidate( + store, + user_id=user_id, + candidate=MemoryCandidateInput( + memory_key=memory_key, + value=value, + source_event_ids=(event_id,), + ), + ) + memory_id = UUID(admitted.memory["id"]) + store.create_memory_embedding( + memory_id=memory_id, + embedding_config_id=embedding_config_id, + dimensions=len(vector), + vector=vector, + ) + if delete_requested: + delete_event_id = store.append_event( + thread["id"], + session["id"], + "message.user", + {"memory_key": memory_key, "delete_requested": True}, + )["id"] + admit_memory_candidate( + store, + user_id=user_id, + candidate=MemoryCandidateInput( + memory_key=memory_key, + value=None, + source_event_ids=(delete_event_id,), + delete_requested=True, + ), + ) + return memory_id + + +def test_embedding_config_endpoints_create_and_list_in_deterministic_order( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_user_with_memory(migrated_database_urls["app"], email="owner@example.com") + seed_embedding_config( + migrated_database_urls["app"], + user_id=seeded["user_id"], + provider="openai", + model="text-embedding-3-small", + version="2026-03-11", + 
dimensions=1536, + ) + seed_embedding_config( + migrated_database_urls["app"], + user_id=seeded["user_id"], + provider="openai", + model="text-embedding-3-large", + version="2026-03-12", + dimensions=3072, + ) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + create_status, create_payload = invoke_request( + "POST", + "/v0/embedding-configs", + payload={ + "user_id": str(seeded["user_id"]), + "provider": "openai", + "model": "text-embedding-3-large", + "version": "2026-03-13", + "dimensions": 3, + "status": "active", + "metadata": {"task": "memory_retrieval"}, + }, + ) + list_status, list_payload = invoke_request( + "GET", + "/v0/embedding-configs", + query_params={"user_id": str(seeded["user_id"])}, + ) + + assert create_status == 201 + assert create_payload["embedding_config"]["provider"] == "openai" + assert create_payload["embedding_config"]["version"] == "2026-03-13" + assert list_status == 200 + assert list_payload["summary"] == { + "total_count": 3, + "order": ["created_at_asc", "id_asc"], + } + + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + expected_configs = ContinuityStore(conn).list_embedding_configs() + + assert [item["id"] for item in list_payload["items"]] == [ + str(config["id"]) for config in expected_configs + ] + + +def test_embedding_config_create_rejects_duplicate_provider_model_version( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_user_with_memory(migrated_database_urls["app"], email="owner@example.com") + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + first_status, first_payload = invoke_request( + "POST", + "/v0/embedding-configs", + payload={ + "user_id": str(seeded["user_id"]), + "provider": "openai", + "model": "text-embedding-3-large", + "version": "2026-03-12", + "dimensions": 3, + "status": "active", + "metadata": {"task": "memory_retrieval"}, + }, + ) + second_status, second_payload = invoke_request( + "POST", + "/v0/embedding-configs", + payload={ + "user_id": str(seeded["user_id"]), + "provider": "openai", + "model": "text-embedding-3-large", + "version": "2026-03-12", + "dimensions": 3, + "status": "active", + "metadata": {"task": "memory_retrieval"}, + }, + ) + + assert first_status == 201 + assert first_payload["embedding_config"]["version"] == "2026-03-12" + assert second_status == 400 + assert second_payload == { + "detail": ( + "embedding config already exists for provider/model/version under the user scope: " + "openai/text-embedding-3-large/2026-03-12" + ) + } + + +def test_memory_embedding_endpoints_persist_and_read_embeddings( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_user_with_memory(migrated_database_urls["app"], email="owner@example.com") + first_config_id = seed_embedding_config( + migrated_database_urls["app"], + user_id=seeded["user_id"], + provider="openai", + model="text-embedding-3-small", + version="2026-03-11", + dimensions=3, + ) + second_config_id = seed_embedding_config( + migrated_database_urls["app"], + user_id=seeded["user_id"], + provider="openai", + model="text-embedding-3-large", + version="2026-03-12", + dimensions=3, + ) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + first_write_status, first_write_payload = invoke_request( + "POST", + "/v0/memory-embeddings", + payload={ + "user_id": 
str(seeded["user_id"]), + "memory_id": str(seeded["memory_id"]), + "embedding_config_id": str(first_config_id), + "vector": [0.1, 0.2, 0.3], + }, + ) + second_write_status, second_write_payload = invoke_request( + "POST", + "/v0/memory-embeddings", + payload={ + "user_id": str(seeded["user_id"]), + "memory_id": str(seeded["memory_id"]), + "embedding_config_id": str(second_config_id), + "vector": [0.4, 0.5, 0.6], + }, + ) + update_status, update_payload = invoke_request( + "POST", + "/v0/memory-embeddings", + payload={ + "user_id": str(seeded["user_id"]), + "memory_id": str(seeded["memory_id"]), + "embedding_config_id": str(first_config_id), + "vector": [0.9, 0.8, 0.7], + }, + ) + list_status, list_payload = invoke_request( + "GET", + f"/v0/memories/{seeded['memory_id']}/embeddings", + query_params={"user_id": str(seeded["user_id"])}, + ) + detail_status, detail_payload = invoke_request( + "GET", + f"/v0/memory-embeddings/{first_write_payload['embedding']['id']}", + query_params={"user_id": str(seeded["user_id"])}, + ) + + assert first_write_status == 201 + assert first_write_payload["write_mode"] == "created" + assert second_write_status == 201 + assert second_write_payload["write_mode"] == "created" + assert update_status == 201 + assert update_payload["write_mode"] == "updated" + assert update_payload["embedding"]["id"] == first_write_payload["embedding"]["id"] + assert update_payload["embedding"]["vector"] == [0.9, 0.8, 0.7] + assert list_status == 200 + assert list_payload["summary"] == { + "memory_id": str(seeded["memory_id"]), + "total_count": 2, + "order": ["created_at_asc", "id_asc"], + } + assert detail_status == 200 + assert detail_payload["embedding"]["id"] == first_write_payload["embedding"]["id"] + assert detail_payload["embedding"]["vector"] == [0.9, 0.8, 0.7] + + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + stored = ContinuityStore(conn).list_memory_embeddings_for_memory(seeded["memory_id"]) + + assert [item["id"] for item in list_payload["items"]] == [ + str(embedding["id"]) for embedding in stored + ] + assert len(stored) == 2 + assert stored[0]["embedding_config_id"] == first_config_id + assert stored[0]["vector"] == [0.9, 0.8, 0.7] + assert stored[1]["embedding_config_id"] == second_config_id + + +def test_memory_embedding_writes_reject_invalid_references_dimension_mismatches_and_cross_user_refs( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user_with_memory(migrated_database_urls["app"], email="owner@example.com") + intruder = seed_user_with_memory(migrated_database_urls["app"], email="intruder@example.com") + owner_config_id = seed_embedding_config( + migrated_database_urls["app"], + user_id=owner["user_id"], + provider="openai", + model="text-embedding-3-large", + version="2026-03-12", + dimensions=3, + ) + intruder_config_id = seed_embedding_config( + migrated_database_urls["app"], + user_id=intruder["user_id"], + provider="openai", + model="text-embedding-3-large", + version="2026-03-12", + dimensions=3, + ) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + missing_config_status, missing_config_payload = invoke_request( + "POST", + "/v0/memory-embeddings", + payload={ + "user_id": str(owner["user_id"]), + "memory_id": str(owner["memory_id"]), + "embedding_config_id": str(uuid4()), + "vector": [0.1, 0.2, 0.3], + }, + ) + missing_memory_status, missing_memory_payload = invoke_request( + "POST", + "/v0/memory-embeddings", + payload={ 
+ "user_id": str(owner["user_id"]), + "memory_id": str(uuid4()), + "embedding_config_id": str(owner_config_id), + "vector": [0.1, 0.2, 0.3], + }, + ) + mismatch_status, mismatch_payload = invoke_request( + "POST", + "/v0/memory-embeddings", + payload={ + "user_id": str(owner["user_id"]), + "memory_id": str(owner["memory_id"]), + "embedding_config_id": str(owner_config_id), + "vector": [0.1, 0.2], + }, + ) + cross_user_status, cross_user_payload = invoke_request( + "POST", + "/v0/memory-embeddings", + payload={ + "user_id": str(intruder["user_id"]), + "memory_id": str(owner["memory_id"]), + "embedding_config_id": str(intruder_config_id), + "vector": [0.1, 0.2, 0.3], + }, + ) + cross_user_config_status, cross_user_config_payload = invoke_request( + "POST", + "/v0/memory-embeddings", + payload={ + "user_id": str(intruder["user_id"]), + "memory_id": str(intruder["memory_id"]), + "embedding_config_id": str(owner_config_id), + "vector": [0.1, 0.2, 0.3], + }, + ) + + assert missing_config_status == 400 + assert missing_config_payload["detail"].startswith( + "embedding_config_id must reference an existing embedding config owned by the user" + ) + assert missing_memory_status == 400 + assert missing_memory_payload["detail"].startswith( + "memory_id must reference an existing memory owned by the user" + ) + assert mismatch_status == 400 + assert mismatch_payload["detail"] == "vector length must match embedding config dimensions (3): 2" + assert cross_user_status == 400 + assert cross_user_payload["detail"] == ( + f"memory_id must reference an existing memory owned by the user: {owner['memory_id']}" + ) + assert cross_user_config_status == 400 + assert cross_user_config_payload["detail"] == ( + "embedding_config_id must reference an existing embedding config owned by the user: " + f"{owner_config_id}" + ) + + +def test_embedding_reads_respect_per_user_isolation( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user_with_memory(migrated_database_urls["app"], email="owner@example.com") + intruder = seed_user_with_memory(migrated_database_urls["app"], email="intruder@example.com") + owner_config_id = seed_embedding_config( + migrated_database_urls["app"], + user_id=owner["user_id"], + provider="openai", + model="text-embedding-3-large", + version="2026-03-12", + dimensions=3, + ) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + write_status, write_payload = invoke_request( + "POST", + "/v0/memory-embeddings", + payload={ + "user_id": str(owner["user_id"]), + "memory_id": str(owner["memory_id"]), + "embedding_config_id": str(owner_config_id), + "vector": [0.1, 0.2, 0.3], + }, + ) + config_list_status, config_list_payload = invoke_request( + "GET", + "/v0/embedding-configs", + query_params={"user_id": str(intruder["user_id"])}, + ) + list_status, list_payload = invoke_request( + "GET", + f"/v0/memories/{owner['memory_id']}/embeddings", + query_params={"user_id": str(intruder["user_id"])}, + ) + detail_status, detail_payload = invoke_request( + "GET", + f"/v0/memory-embeddings/{write_payload['embedding']['id']}", + query_params={"user_id": str(intruder["user_id"])}, + ) + + assert write_status == 201 + assert config_list_status == 200 + assert config_list_payload == { + "items": [], + "summary": { + "total_count": 0, + "order": ["created_at_asc", "id_asc"], + }, + } + assert list_status == 404 + assert list_payload == {"detail": f"memory {owner['memory_id']} was not found"} + assert detail_status == 404 + 
assert detail_payload == { + "detail": f"memory embedding {write_payload['embedding']['id']} was not found" + } + + +def test_semantic_memory_retrieval_returns_deterministic_results_and_excludes_deleted_memories( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_user_with_memory(migrated_database_urls["app"], email="owner@example.com") + config_id = seed_embedding_config( + migrated_database_urls["app"], + user_id=seeded["user_id"], + provider="openai", + model="text-embedding-3-large", + version="2026-03-12", + dimensions=3, + ) + first_memory_id = seed_memory_with_embedding( + migrated_database_urls["app"], + user_id=seeded["user_id"], + memory_key="user.preference.breakfast", + value={"likes": "porridge"}, + embedding_config_id=config_id, + vector=[1.0, 0.0, 0.0], + ) + deleted_memory_id = seed_memory_with_embedding( + migrated_database_urls["app"], + user_id=seeded["user_id"], + memory_key="user.preference.deleted", + value={"likes": "hidden"}, + embedding_config_id=config_id, + vector=[1.0, 0.0, 0.0], + delete_requested=True, + ) + second_memory_id = seed_memory_with_embedding( + migrated_database_urls["app"], + user_id=seeded["user_id"], + memory_key="user.preference.lunch", + value={"likes": "ramen"}, + embedding_config_id=config_id, + vector=[1.0, 0.0, 0.0], + ) + third_memory_id = seed_memory_with_embedding( + migrated_database_urls["app"], + user_id=seeded["user_id"], + memory_key="user.preference.music", + value={"likes": "jazz"}, + embedding_config_id=config_id, + vector=[0.0, 1.0, 0.0], + ) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + status, payload = invoke_request( + "POST", + "/v0/memories/semantic-retrieval", + payload={ + "user_id": str(seeded["user_id"]), + "embedding_config_id": str(config_id), + "query_vector": [1.0, 0.0, 0.0], + "limit": 10, + }, + ) + + assert status == 200 + assert payload["summary"] == { + "embedding_config_id": str(config_id), + "limit": 10, + "returned_count": 3, + "similarity_metric": "cosine_similarity", + "order": ["score_desc", "created_at_asc", "id_asc"], + } + assert [item["memory_id"] for item in payload["items"]] == [ + str(first_memory_id), + str(second_memory_id), + str(third_memory_id), + ] + assert str(deleted_memory_id) not in {item["memory_id"] for item in payload["items"]} + assert payload["items"][0]["score"] == payload["items"][1]["score"] + assert payload["items"][0]["score"] > payload["items"][2]["score"] + assert set(payload["items"][0]) == { + "memory_id", + "memory_key", + "value", + "source_event_ids", + "created_at", + "updated_at", + "score", + } + + +def test_semantic_memory_retrieval_rejects_invalid_config_dimension_mismatch_and_cross_user_access( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user_with_memory(migrated_database_urls["app"], email="owner@example.com") + intruder = seed_user_with_memory(migrated_database_urls["app"], email="intruder@example.com") + owner_config_id = seed_embedding_config( + migrated_database_urls["app"], + user_id=owner["user_id"], + provider="openai", + model="text-embedding-3-large", + version="2026-03-12", + dimensions=3, + ) + intruder_config_id = seed_embedding_config( + migrated_database_urls["app"], + user_id=intruder["user_id"], + provider="openai", + model="text-embedding-3-large", + version="2026-03-12", + dimensions=3, + ) + seed_memory_with_embedding( + migrated_database_urls["app"], + user_id=owner["user_id"], + memory_key="user.preference.owner", + 
value={"likes": "oat milk"}, + embedding_config_id=owner_config_id, + vector=[1.0, 0.0, 0.0], + ) + seed_memory_with_embedding( + migrated_database_urls["app"], + user_id=intruder["user_id"], + memory_key="user.preference.intruder", + value={"likes": "almond milk"}, + embedding_config_id=intruder_config_id, + vector=[1.0, 0.0, 0.0], + ) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + missing_status, missing_payload = invoke_request( + "POST", + "/v0/memories/semantic-retrieval", + payload={ + "user_id": str(owner["user_id"]), + "embedding_config_id": str(uuid4()), + "query_vector": [1.0, 0.0, 0.0], + "limit": 5, + }, + ) + mismatch_status, mismatch_payload = invoke_request( + "POST", + "/v0/memories/semantic-retrieval", + payload={ + "user_id": str(owner["user_id"]), + "embedding_config_id": str(owner_config_id), + "query_vector": [1.0, 0.0], + "limit": 5, + }, + ) + cross_user_status, cross_user_payload = invoke_request( + "POST", + "/v0/memories/semantic-retrieval", + payload={ + "user_id": str(intruder["user_id"]), + "embedding_config_id": str(owner_config_id), + "query_vector": [1.0, 0.0, 0.0], + "limit": 5, + }, + ) + + assert missing_status == 400 + assert missing_payload["detail"].startswith( + "embedding_config_id must reference an existing embedding config owned by the user" + ) + assert mismatch_status == 400 + assert mismatch_payload["detail"] == "query_vector length must match embedding config dimensions (3): 2" + assert cross_user_status == 400 + assert cross_user_payload["detail"] == ( + "embedding_config_id must reference an existing embedding config owned by the user: " + f"{owner_config_id}" + ) + + +def test_semantic_memory_retrieval_scopes_results_per_user( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user_with_memory(migrated_database_urls["app"], email="owner@example.com") + intruder = seed_user_with_memory(migrated_database_urls["app"], email="intruder@example.com") + owner_config_id = seed_embedding_config( + migrated_database_urls["app"], + user_id=owner["user_id"], + provider="openai", + model="text-embedding-3-large", + version="2026-03-12", + dimensions=3, + ) + intruder_config_id = seed_embedding_config( + migrated_database_urls["app"], + user_id=intruder["user_id"], + provider="openai", + model="text-embedding-3-large", + version="2026-03-12", + dimensions=3, + ) + owner_memory_id = seed_memory_with_embedding( + migrated_database_urls["app"], + user_id=owner["user_id"], + memory_key="user.preference.owner.semantic", + value={"likes": "espresso"}, + embedding_config_id=owner_config_id, + vector=[1.0, 0.0, 0.0], + ) + intruder_memory_id = seed_memory_with_embedding( + migrated_database_urls["app"], + user_id=intruder["user_id"], + memory_key="user.preference.intruder.semantic", + value={"likes": "matcha"}, + embedding_config_id=intruder_config_id, + vector=[1.0, 0.0, 0.0], + ) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + owner_status, owner_payload = invoke_request( + "POST", + "/v0/memories/semantic-retrieval", + payload={ + "user_id": str(owner["user_id"]), + "embedding_config_id": str(owner_config_id), + "query_vector": [1.0, 0.0, 0.0], + "limit": 5, + }, + ) + intruder_status, intruder_payload = invoke_request( + "POST", + "/v0/memories/semantic-retrieval", + payload={ + "user_id": str(intruder["user_id"]), + "embedding_config_id": str(intruder_config_id), + "query_vector": 
[1.0, 0.0, 0.0], + "limit": 5, + }, + ) + + assert owner_status == 200 + assert [item["memory_id"] for item in owner_payload["items"]] == [str(owner_memory_id)] + assert intruder_status == 200 + assert [item["memory_id"] for item in intruder_payload["items"]] == [str(intruder_memory_id)] diff --git a/tests/integration/test_entities_api.py b/tests/integration/test_entities_api.py new file mode 100644 index 0000000..4236c1f --- /dev/null +++ b/tests/integration/test_entities_api.py @@ -0,0 +1,309 @@ +from __future__ import annotations + +import json +from typing import Any +from urllib.parse import urlencode +from uuid import UUID, uuid4 + +import anyio + +import apps.api.src.alicebot_api.main as main_module +from apps.api.src.alicebot_api.config import Settings +from alicebot_api.contracts import MemoryCandidateInput +from alicebot_api.db import user_connection +from alicebot_api.memory import admit_memory_candidate +from alicebot_api.store import ContinuityStore + + +def invoke_request( + method: str, + path: str, + *, + query_params: dict[str, str] | None = None, + payload: dict[str, Any] | None = None, +) -> tuple[int, dict[str, Any]]: + messages: list[dict[str, object]] = [] + encoded_body = b"" if payload is None else json.dumps(payload).encode() + request_received = False + + async def receive() -> dict[str, object]: + nonlocal request_received + if request_received: + return {"type": "http.disconnect"} + + request_received = True + return {"type": "http.request", "body": encoded_body, "more_body": False} + + async def send(message: dict[str, object]) -> None: + messages.append(message) + + query_string = urlencode(query_params or {}).encode() + scope = { + "type": "http", + "asgi": {"version": "3.0"}, + "http_version": "1.1", + "method": method, + "scheme": "http", + "path": path, + "raw_path": path.encode(), + "query_string": query_string, + "headers": [(b"content-type", b"application/json")], + "client": ("testclient", 50000), + "server": ("testserver", 80), + "root_path": "", + } + + anyio.run(main_module.app, scope, receive, send) + + start_message = next(message for message in messages if message["type"] == "http.response.start") + body = b"".join( + message.get("body", b"") + for message in messages + if message["type"] == "http.response.body" + ) + return start_message["status"], json.loads(body) + + +def seed_user_with_source_memories(database_url: str, *, email: str) -> dict[str, object]: + user_id = uuid4() + + with user_connection(database_url, user_id) as conn: + store = ContinuityStore(conn) + store.create_user(user_id, email, email.split("@", 1)[0].title()) + thread = store.create_thread("Entity source thread") + session = store.create_session(thread["id"], status="active") + event_ids = [ + store.append_event(thread["id"], session["id"], "message.user", {"text": "works on AliceBot"})["id"], + store.append_event(thread["id"], session["id"], "message.user", {"text": "drinks oat milk"})["id"], + store.append_event(thread["id"], session["id"], "message.user", {"text": "shops at cafe"})["id"], + ] + + first_memory = admit_memory_candidate( + store, + user_id=user_id, + candidate=MemoryCandidateInput( + memory_key="user.project.current", + value={"name": "AliceBot"}, + source_event_ids=(event_ids[0],), + ), + ) + second_memory = admit_memory_candidate( + store, + user_id=user_id, + candidate=MemoryCandidateInput( + memory_key="user.preference.coffee", + value={"likes": "oat milk"}, + source_event_ids=(event_ids[1],), + ), + ) + third_memory = admit_memory_candidate( + store, 
+ user_id=user_id, + candidate=MemoryCandidateInput( + memory_key="user.preference.merchant", + value={"name": "Neighborhood Cafe"}, + source_event_ids=(event_ids[2],), + ), + ) + + return { + "user_id": user_id, + "memory_ids": [ + UUID(first_memory.memory["id"]), + UUID(second_memory.memory["id"]), + UUID(third_memory.memory["id"]), + ], + } + + +def test_create_entity_endpoint_persists_entity_backed_by_user_owned_source_memories( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_user_with_source_memories(migrated_database_urls["app"], email="owner@example.com") + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + status_code, payload = invoke_request( + "POST", + "/v0/entities", + payload={ + "user_id": str(seeded["user_id"]), + "entity_type": "project", + "name": "AliceBot", + "source_memory_ids": [str(seeded["memory_ids"][0]), str(seeded["memory_ids"][1])], + }, + ) + + assert status_code == 201 + assert payload["entity"]["entity_type"] == "project" + assert payload["entity"]["name"] == "AliceBot" + assert payload["entity"]["source_memory_ids"] == [ + str(seeded["memory_ids"][0]), + str(seeded["memory_ids"][1]), + ] + + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + stored_entities = ContinuityStore(conn).list_entities() + + assert len(stored_entities) == 1 + assert stored_entities[0]["id"] == UUID(payload["entity"]["id"]) + assert stored_entities[0]["entity_type"] == "project" + assert stored_entities[0]["name"] == "AliceBot" + assert stored_entities[0]["source_memory_ids"] == [ + str(seeded["memory_ids"][0]), + str(seeded["memory_ids"][1]), + ] + + +def test_entity_endpoints_list_and_get_entities_in_deterministic_user_scoped_order( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_user_with_source_memories(migrated_database_urls["app"], email="owner@example.com") + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + store = ContinuityStore(conn) + created_entities = [ + store.create_entity( + entity_type="person", + name="Samir", + source_memory_ids=[str(seeded["memory_ids"][0])], + ), + store.create_entity( + entity_type="merchant", + name="Neighborhood Cafe", + source_memory_ids=[str(seeded["memory_ids"][2])], + ), + store.create_entity( + entity_type="project", + name="AliceBot", + source_memory_ids=[str(seeded["memory_ids"][0]), str(seeded["memory_ids"][1])], + ), + ] + + list_status, list_payload = invoke_request( + "GET", + "/v0/entities", + query_params={"user_id": str(seeded["user_id"])}, + ) + + expected_entities = sorted(created_entities, key=lambda entity: (entity["created_at"], entity["id"])) + + assert list_status == 200 + assert [item["id"] for item in list_payload["items"]] == [str(entity["id"]) for entity in expected_entities] + assert list_payload["summary"] == { + "total_count": 3, + "order": ["created_at_asc", "id_asc"], + } + + target_entity = expected_entities[1] + detail_status, detail_payload = invoke_request( + "GET", + f"/v0/entities/{target_entity['id']}", + query_params={"user_id": str(seeded["user_id"])}, + ) + + assert detail_status == 200 + assert detail_payload == { + "entity": { + "id": str(target_entity["id"]), + "entity_type": target_entity["entity_type"], + "name": target_entity["name"], + "source_memory_ids": target_entity["source_memory_ids"], + 
"created_at": target_entity["created_at"].isoformat(), + } + } + + +def test_entity_endpoints_enforce_per_user_isolation_and_not_found_behavior( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user_with_source_memories(migrated_database_urls["app"], email="owner@example.com") + intruder = seed_user_with_source_memories(migrated_database_urls["app"], email="intruder@example.com") + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + entity = ContinuityStore(conn).create_entity( + entity_type="project", + name="AliceBot", + source_memory_ids=[str(owner["memory_ids"][0])], + ) + + list_status, list_payload = invoke_request( + "GET", + "/v0/entities", + query_params={"user_id": str(intruder["user_id"])}, + ) + detail_status, detail_payload = invoke_request( + "GET", + f"/v0/entities/{entity['id']}", + query_params={"user_id": str(intruder["user_id"])}, + ) + create_status, create_payload = invoke_request( + "POST", + "/v0/entities", + payload={ + "user_id": str(intruder["user_id"]), + "entity_type": "project", + "name": "Hidden Project", + "source_memory_ids": [str(owner["memory_ids"][0])], + }, + ) + + assert list_status == 200 + assert list_payload == { + "items": [], + "summary": { + "total_count": 0, + "order": ["created_at_asc", "id_asc"], + }, + } + assert detail_status == 404 + assert detail_payload == { + "detail": f"entity {entity['id']} was not found", + } + assert create_status == 400 + assert create_payload["detail"].startswith( + "source_memory_ids must all reference existing memories owned by the user" + ) + + +def test_create_entity_endpoint_rejects_missing_source_memory_ids(migrated_database_urls, monkeypatch) -> None: + seeded = seed_user_with_source_memories(migrated_database_urls["app"], email="owner@example.com") + missing_memory_id = uuid4() + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + status_code, payload = invoke_request( + "POST", + "/v0/entities", + payload={ + "user_id": str(seeded["user_id"]), + "entity_type": "routine", + "name": "Morning Coffee", + "source_memory_ids": [str(missing_memory_id)], + }, + ) + + assert status_code == 400 + assert payload == { + "detail": "source_memory_ids must all reference existing memories owned by the user: " + f"{missing_memory_id}" + } diff --git a/tests/integration/test_entity_edges_api.py b/tests/integration/test_entity_edges_api.py new file mode 100644 index 0000000..d8ea5be --- /dev/null +++ b/tests/integration/test_entity_edges_api.py @@ -0,0 +1,376 @@ +from __future__ import annotations + +import json +from typing import Any +from urllib.parse import urlencode +from uuid import UUID, uuid4 + +import anyio + +import apps.api.src.alicebot_api.main as main_module +from apps.api.src.alicebot_api.config import Settings +from alicebot_api.contracts import MemoryCandidateInput +from alicebot_api.db import user_connection +from alicebot_api.memory import admit_memory_candidate +from alicebot_api.store import ContinuityStore + + +def invoke_request( + method: str, + path: str, + *, + query_params: dict[str, str] | None = None, + payload: dict[str, Any] | None = None, +) -> tuple[int, dict[str, Any]]: + messages: list[dict[str, object]] = [] + encoded_body = b"" if payload is None else json.dumps(payload).encode() + request_received = False + + async def receive() -> dict[str, 
object]: + nonlocal request_received + if request_received: + return {"type": "http.disconnect"} + + request_received = True + return {"type": "http.request", "body": encoded_body, "more_body": False} + + async def send(message: dict[str, object]) -> None: + messages.append(message) + + query_string = urlencode(query_params or {}).encode() + scope = { + "type": "http", + "asgi": {"version": "3.0"}, + "http_version": "1.1", + "method": method, + "scheme": "http", + "path": path, + "raw_path": path.encode(), + "query_string": query_string, + "headers": [(b"content-type", b"application/json")], + "client": ("testclient", 50000), + "server": ("testserver", 80), + "root_path": "", + } + + anyio.run(main_module.app, scope, receive, send) + + start_message = next(message for message in messages if message["type"] == "http.response.start") + body = b"".join( + message.get("body", b"") + for message in messages + if message["type"] == "http.response.body" + ) + return start_message["status"], json.loads(body) + + +def seed_user_with_source_memories(database_url: str, *, email: str) -> dict[str, object]: + user_id = uuid4() + + with user_connection(database_url, user_id) as conn: + store = ContinuityStore(conn) + store.create_user(user_id, email, email.split("@", 1)[0].title()) + thread = store.create_thread("Entity edge source thread") + session = store.create_session(thread["id"], status="active") + event_ids = [ + store.append_event(thread["id"], session["id"], "message.user", {"text": "works on AliceBot"})["id"], + store.append_event( + thread["id"], session["id"], "message.user", {"text": "works with Neighborhood Cafe"} + )["id"], + store.append_event(thread["id"], session["id"], "message.user", {"text": "coffee preference"})["id"], + ] + + first_memory = admit_memory_candidate( + store, + user_id=user_id, + candidate=MemoryCandidateInput( + memory_key="user.project.current", + value={"name": "AliceBot"}, + source_event_ids=(event_ids[0],), + ), + ) + second_memory = admit_memory_candidate( + store, + user_id=user_id, + candidate=MemoryCandidateInput( + memory_key="user.preference.merchant", + value={"name": "Neighborhood Cafe"}, + source_event_ids=(event_ids[1],), + ), + ) + third_memory = admit_memory_candidate( + store, + user_id=user_id, + candidate=MemoryCandidateInput( + memory_key="user.preference.coffee", + value={"likes": "oat milk"}, + source_event_ids=(event_ids[2],), + ), + ) + + return { + "user_id": user_id, + "memory_ids": [ + UUID(first_memory.memory["id"]), + UUID(second_memory.memory["id"]), + UUID(third_memory.memory["id"]), + ], + } + + +def seed_entities( + database_url: str, + *, + user_id: UUID, + memory_ids: list[UUID], +) -> dict[str, UUID]: + with user_connection(database_url, user_id) as conn: + store = ContinuityStore(conn) + person = store.create_entity( + entity_type="person", + name="Samir", + source_memory_ids=[str(memory_ids[2])], + ) + merchant = store.create_entity( + entity_type="merchant", + name="Neighborhood Cafe", + source_memory_ids=[str(memory_ids[1])], + ) + project = store.create_entity( + entity_type="project", + name="AliceBot", + source_memory_ids=[str(memory_ids[0])], + ) + + return { + "person": person["id"], + "merchant": merchant["id"], + "project": project["id"], + } + + +def test_create_entity_edge_endpoint_persists_user_scoped_edge_with_temporal_metadata( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_user_with_source_memories(migrated_database_urls["app"], email="owner@example.com") + entities = seed_entities( + 
migrated_database_urls["app"], + user_id=seeded["user_id"], + memory_ids=seeded["memory_ids"], + ) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + status_code, payload = invoke_request( + "POST", + "/v0/entity-edges", + payload={ + "user_id": str(seeded["user_id"]), + "from_entity_id": str(entities["person"]), + "to_entity_id": str(entities["project"]), + "relationship_type": "works_on", + "valid_from": "2026-03-12T10:00:00+00:00", + "valid_to": "2026-03-12T12:00:00+00:00", + "source_memory_ids": [str(seeded["memory_ids"][0]), str(seeded["memory_ids"][2])], + }, + ) + + assert status_code == 201 + assert payload["edge"]["from_entity_id"] == str(entities["person"]) + assert payload["edge"]["to_entity_id"] == str(entities["project"]) + assert payload["edge"]["relationship_type"] == "works_on" + assert payload["edge"]["valid_from"] == "2026-03-12T10:00:00+00:00" + assert payload["edge"]["valid_to"] == "2026-03-12T12:00:00+00:00" + assert payload["edge"]["source_memory_ids"] == [ + str(seeded["memory_ids"][0]), + str(seeded["memory_ids"][2]), + ] + + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + stored_edges = ContinuityStore(conn).list_entity_edges_for_entity(entities["person"]) + + assert len(stored_edges) == 1 + assert stored_edges[0]["id"] == UUID(payload["edge"]["id"]) + assert stored_edges[0]["relationship_type"] == "works_on" + assert stored_edges[0]["source_memory_ids"] == [ + str(seeded["memory_ids"][0]), + str(seeded["memory_ids"][2]), + ] + + +def test_entity_edge_list_endpoint_returns_incident_edges_in_deterministic_order( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_user_with_source_memories(migrated_database_urls["app"], email="owner@example.com") + entities = seed_entities( + migrated_database_urls["app"], + user_id=seeded["user_id"], + memory_ids=seeded["memory_ids"], + ) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + store = ContinuityStore(conn) + created_edges = [ + store.create_entity_edge( + from_entity_id=entities["person"], + to_entity_id=entities["project"], + relationship_type="works_on", + valid_from=None, + valid_to=None, + source_memory_ids=[str(seeded["memory_ids"][0])], + ), + store.create_entity_edge( + from_entity_id=entities["merchant"], + to_entity_id=entities["project"], + relationship_type="supplies", + valid_from=None, + valid_to=None, + source_memory_ids=[str(seeded["memory_ids"][1])], + ), + store.create_entity_edge( + from_entity_id=entities["project"], + to_entity_id=entities["merchant"], + relationship_type="references", + valid_from=None, + valid_to=None, + source_memory_ids=[str(seeded["memory_ids"][2])], + ), + ] + + status_code, payload = invoke_request( + "GET", + f"/v0/entities/{entities['project']}/edges", + query_params={"user_id": str(seeded["user_id"])}, + ) + + expected_edges = sorted(created_edges, key=lambda edge: (edge["created_at"], edge["id"])) + + assert status_code == 200 + assert [item["id"] for item in payload["items"]] == [str(edge["id"]) for edge in expected_edges] + assert payload["summary"] == { + "entity_id": str(entities["project"]), + "total_count": 3, + "order": ["created_at_asc", "id_asc"], + } + + +def test_entity_edge_endpoints_enforce_per_user_isolation_and_reference_validation( + migrated_database_urls, + monkeypatch, +) 
-> None: + owner = seed_user_with_source_memories(migrated_database_urls["app"], email="owner@example.com") + owner_entities = seed_entities( + migrated_database_urls["app"], + user_id=owner["user_id"], + memory_ids=owner["memory_ids"], + ) + intruder = seed_user_with_source_memories(migrated_database_urls["app"], email="intruder@example.com") + intruder_entities = seed_entities( + migrated_database_urls["app"], + user_id=intruder["user_id"], + memory_ids=intruder["memory_ids"], + ) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + ContinuityStore(conn).create_entity_edge( + from_entity_id=owner_entities["person"], + to_entity_id=owner_entities["project"], + relationship_type="works_on", + valid_from=None, + valid_to=None, + source_memory_ids=[str(owner["memory_ids"][0])], + ) + + list_status, list_payload = invoke_request( + "GET", + f"/v0/entities/{owner_entities['project']}/edges", + query_params={"user_id": str(intruder['user_id'])}, + ) + entity_status, entity_payload = invoke_request( + "POST", + "/v0/entity-edges", + payload={ + "user_id": str(intruder["user_id"]), + "from_entity_id": str(owner_entities["person"]), + "to_entity_id": str(intruder_entities["project"]), + "relationship_type": "works_on", + "source_memory_ids": [str(intruder["memory_ids"][0])], + }, + ) + memory_status, memory_payload = invoke_request( + "POST", + "/v0/entity-edges", + payload={ + "user_id": str(intruder["user_id"]), + "from_entity_id": str(intruder_entities["person"]), + "to_entity_id": str(intruder_entities["project"]), + "relationship_type": "works_on", + "source_memory_ids": [str(owner["memory_ids"][0])], + }, + ) + + assert list_status == 404 + assert list_payload == { + "detail": f"entity {owner_entities['project']} was not found", + } + assert entity_status == 400 + assert entity_payload == { + "detail": "from_entity_id must reference an existing entity owned by the user: " + f"{owner_entities['person']}" + } + assert memory_status == 400 + assert memory_payload == { + "detail": "source_memory_ids must all reference existing memories owned by the user: " + f"{owner['memory_ids'][0]}" + } + + +def test_create_entity_edge_endpoint_rejects_invalid_temporal_range( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_user_with_source_memories(migrated_database_urls["app"], email="owner@example.com") + entities = seed_entities( + migrated_database_urls["app"], + user_id=seeded["user_id"], + memory_ids=seeded["memory_ids"], + ) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + status_code, payload = invoke_request( + "POST", + "/v0/entity-edges", + payload={ + "user_id": str(seeded["user_id"]), + "from_entity_id": str(entities["person"]), + "to_entity_id": str(entities["project"]), + "relationship_type": "works_on", + "valid_from": "2026-03-12T12:00:00+00:00", + "valid_to": "2026-03-12T10:00:00+00:00", + "source_memory_ids": [str(seeded["memory_ids"][0])], + }, + ) + + assert status_code == 400 + assert payload == { + "detail": "valid_to must be greater than or equal to valid_from", + } diff --git a/tests/integration/test_execution_budgets_api.py b/tests/integration/test_execution_budgets_api.py new file mode 100644 index 0000000..5fe3e4f --- /dev/null +++ b/tests/integration/test_execution_budgets_api.py @@ -0,0 +1,432 @@ +from __future__ import 
annotations + +import json +from typing import Any +from urllib.parse import urlencode +from uuid import UUID, uuid4 + +import anyio +import psycopg +import pytest + +import apps.api.src.alicebot_api.main as main_module +from apps.api.src.alicebot_api.config import Settings +from alicebot_api.db import user_connection +from alicebot_api.store import ContinuityStore + + +def invoke_request( + method: str, + path: str, + *, + query_params: dict[str, str] | None = None, + payload: dict[str, Any] | None = None, +) -> tuple[int, dict[str, Any]]: + messages: list[dict[str, object]] = [] + encoded_body = b"" if payload is None else json.dumps(payload).encode() + request_received = False + + async def receive() -> dict[str, object]: + nonlocal request_received + if request_received: + return {"type": "http.disconnect"} + + request_received = True + return {"type": "http.request", "body": encoded_body, "more_body": False} + + async def send(message: dict[str, object]) -> None: + messages.append(message) + + query_string = urlencode(query_params or {}).encode() + scope = { + "type": "http", + "asgi": {"version": "3.0"}, + "http_version": "1.1", + "method": method, + "scheme": "http", + "path": path, + "raw_path": path.encode(), + "query_string": query_string, + "headers": [(b"content-type", b"application/json")], + "client": ("testclient", 50000), + "server": ("testserver", 80), + "root_path": "", + } + + anyio.run(main_module.app, scope, receive, send) + + start_message = next(message for message in messages if message["type"] == "http.response.start") + body = b"".join( + message.get("body", b"") + for message in messages + if message["type"] == "http.response.body" + ) + return start_message["status"], json.loads(body) + + +def seed_user(database_url: str, *, email: str) -> dict[str, UUID]: + user_id = uuid4() + + with user_connection(database_url, user_id) as conn: + store = ContinuityStore(conn) + store.create_user(user_id, email, email.split("@", 1)[0].title()) + thread = store.create_thread("Budget lifecycle thread") + + return { + "user_id": user_id, + "thread_id": thread["id"], + } + + +def create_budget( + *, + user_id: UUID, + tool_key: str | None, + domain_hint: str | None, + max_completed_executions: int, + rolling_window_seconds: int | None = None, +) -> tuple[int, dict[str, Any]]: + payload: dict[str, Any] = { + "user_id": str(user_id), + "max_completed_executions": max_completed_executions, + } + if tool_key is not None: + payload["tool_key"] = tool_key + if domain_hint is not None: + payload["domain_hint"] = domain_hint + if rolling_window_seconds is not None: + payload["rolling_window_seconds"] = rolling_window_seconds + return invoke_request("POST", "/v0/execution-budgets", payload=payload) + + +def test_execution_budget_endpoints_create_list_and_get_in_deterministic_order( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner@example.com") + intruder = seed_user(migrated_database_urls["app"], email="intruder@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + + second_status, second_payload = create_budget( + user_id=owner["user_id"], + tool_key="proxy.echo", + domain_hint=None, + max_completed_executions=2, + rolling_window_seconds=3600, + ) + first_status, first_payload = create_budget( + user_id=owner["user_id"], + tool_key=None, + domain_hint="docs", + max_completed_executions=1, + ) + + list_status, list_payload = invoke_request( + 
"GET", + "/v0/execution-budgets", + query_params={"user_id": str(owner["user_id"])}, + ) + detail_status, detail_payload = invoke_request( + "GET", + f"/v0/execution-budgets/{second_payload['execution_budget']['id']}", + query_params={"user_id": str(owner['user_id'])}, + ) + isolated_list_status, isolated_list_payload = invoke_request( + "GET", + "/v0/execution-budgets", + query_params={"user_id": str(intruder["user_id"])}, + ) + + assert first_status == 201 + assert second_status == 201 + assert second_payload["execution_budget"]["status"] == "active" + assert second_payload["execution_budget"]["deactivated_at"] is None + assert second_payload["execution_budget"]["rolling_window_seconds"] == 3600 + assert list_status == 200 + assert [item["id"] for item in list_payload["items"]] == [ + second_payload["execution_budget"]["id"], + first_payload["execution_budget"]["id"], + ] + assert list_payload["summary"] == { + "total_count": 2, + "order": ["created_at_asc", "id_asc"], + } + assert detail_status == 200 + assert detail_payload == {"execution_budget": second_payload["execution_budget"]} + assert isolated_list_status == 200 + assert isolated_list_payload == { + "items": [], + "summary": {"total_count": 0, "order": ["created_at_asc", "id_asc"]}, + } + + isolated_detail_status, isolated_detail_payload = invoke_request( + "GET", + f"/v0/execution-budgets/{first_payload['execution_budget']['id']}", + query_params={"user_id": str(intruder['user_id'])}, + ) + + assert isolated_detail_status == 404 + assert isolated_detail_payload == { + "detail": f"execution budget {first_payload['execution_budget']['id']} was not found" + } + + +def test_create_execution_budget_endpoint_requires_at_least_one_selector( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + + status_code, payload = invoke_request( + "POST", + "/v0/execution-budgets", + payload={ + "user_id": str(owner["user_id"]), + "max_completed_executions": 1, + }, + ) + + assert status_code == 400 + assert payload == { + "detail": "execution budget requires at least one selector: tool_key or domain_hint" + } + + +def test_create_execution_budget_endpoint_rejects_duplicate_active_scope( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + + first_status, _ = create_budget( + user_id=owner["user_id"], + tool_key="proxy.echo", + domain_hint="docs", + max_completed_executions=1, + ) + second_status, second_payload = create_budget( + user_id=owner["user_id"], + tool_key="proxy.echo", + domain_hint="docs", + max_completed_executions=2, + ) + + assert first_status == 201 + assert second_status == 400 + assert second_payload == { + "detail": "active execution budget already exists for selector scope tool_key='proxy.echo', domain_hint='docs'" + } + + +def test_deactivate_execution_budget_endpoint_updates_reads_and_emits_trace( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner@example.com") + intruder = seed_user(migrated_database_urls["app"], email="intruder@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + + create_status, 
create_payload = create_budget( + user_id=owner["user_id"], + tool_key="proxy.echo", + domain_hint=None, + max_completed_executions=1, + ) + assert create_status == 201 + + deactivate_status, deactivate_payload = invoke_request( + "POST", + f"/v0/execution-budgets/{create_payload['execution_budget']['id']}/deactivate", + payload={ + "user_id": str(owner["user_id"]), + "thread_id": str(owner["thread_id"]), + }, + ) + + list_status, list_payload = invoke_request( + "GET", + "/v0/execution-budgets", + query_params={"user_id": str(owner["user_id"])}, + ) + detail_status, detail_payload = invoke_request( + "GET", + f"/v0/execution-budgets/{create_payload['execution_budget']['id']}", + query_params={"user_id": str(owner["user_id"])}, + ) + isolated_status, isolated_payload = invoke_request( + "POST", + f"/v0/execution-budgets/{create_payload['execution_budget']['id']}/deactivate", + payload={ + "user_id": str(intruder["user_id"]), + "thread_id": str(intruder["thread_id"]), + }, + ) + + assert deactivate_status == 200 + assert deactivate_payload["execution_budget"]["status"] == "inactive" + assert deactivate_payload["execution_budget"]["deactivated_at"] is not None + assert deactivate_payload["trace"]["trace_event_count"] == 3 + assert list_status == 200 + assert list_payload["items"][0] == deactivate_payload["execution_budget"] + assert detail_status == 200 + assert detail_payload == {"execution_budget": deactivate_payload["execution_budget"]} + assert isolated_status == 404 + assert isolated_payload == { + "detail": f"execution budget {create_payload['execution_budget']['id']} was not found" + } + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + trace = store.get_trace(UUID(deactivate_payload["trace"]["trace_id"])) + trace_events = store.list_trace_events(UUID(deactivate_payload["trace"]["trace_id"])) + + assert trace["kind"] == "execution_budget.lifecycle" + assert trace["compiler_version"] == "execution_budget_lifecycle_v0" + assert trace["limits"]["requested_action"] == "deactivate" + assert [event["kind"] for event in trace_events] == [ + "execution_budget.lifecycle.request", + "execution_budget.lifecycle.state", + "execution_budget.lifecycle.summary", + ] + assert trace_events[1]["payload"]["current_status"] == "inactive" + assert trace_events[2]["payload"]["outcome"] == "deactivated" + + +def test_supersede_execution_budget_endpoint_replaces_active_budget_and_emits_trace( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + + create_status, create_payload = create_budget( + user_id=owner["user_id"], + tool_key="proxy.echo", + domain_hint="docs", + max_completed_executions=1, + rolling_window_seconds=1800, + ) + assert create_status == 201 + + supersede_status, supersede_payload = invoke_request( + "POST", + f"/v0/execution-budgets/{create_payload['execution_budget']['id']}/supersede", + payload={ + "user_id": str(owner["user_id"]), + "thread_id": str(owner["thread_id"]), + "max_completed_executions": 3, + }, + ) + + list_status, list_payload = invoke_request( + "GET", + "/v0/execution-budgets", + query_params={"user_id": str(owner["user_id"])}, + ) + original_detail_status, original_detail_payload = invoke_request( + "GET", + f"/v0/execution-budgets/{create_payload['execution_budget']['id']}", + query_params={"user_id": 
str(owner["user_id"])}, + ) + replacement_detail_status, replacement_detail_payload = invoke_request( + "GET", + f"/v0/execution-budgets/{supersede_payload['replacement_budget']['id']}", + query_params={"user_id": str(owner['user_id'])}, + ) + + assert supersede_status == 200 + assert supersede_payload["superseded_budget"]["status"] == "superseded" + assert supersede_payload["replacement_budget"]["status"] == "active" + assert supersede_payload["replacement_budget"]["rolling_window_seconds"] == 1800 + assert supersede_payload["replacement_budget"]["supersedes_budget_id"] == create_payload["execution_budget"]["id"] + assert supersede_payload["superseded_budget"]["superseded_by_budget_id"] == supersede_payload["replacement_budget"]["id"] + assert list_status == 200 + assert [item["status"] for item in list_payload["items"]] == ["superseded", "active"] + assert original_detail_status == 200 + assert original_detail_payload == {"execution_budget": supersede_payload["superseded_budget"]} + assert replacement_detail_status == 200 + assert replacement_detail_payload == {"execution_budget": supersede_payload["replacement_budget"]} + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + trace = store.get_trace(UUID(supersede_payload["trace"]["trace_id"])) + trace_events = store.list_trace_events(UUID(supersede_payload["trace"]["trace_id"])) + + assert trace["limits"]["requested_action"] == "supersede" + assert trace["limits"]["outcome"] == "superseded" + assert trace_events[1]["payload"]["replacement_budget_id"] == supersede_payload["replacement_budget"]["id"] + assert trace_events[2]["payload"]["active_budget_id"] == supersede_payload["replacement_budget"]["id"] + + +def test_execution_budget_lifecycle_rejects_invalid_transition_deterministically( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + + create_status, create_payload = create_budget( + user_id=owner["user_id"], + tool_key="proxy.echo", + domain_hint=None, + max_completed_executions=1, + ) + assert create_status == 201 + + first_status, _ = invoke_request( + "POST", + f"/v0/execution-budgets/{create_payload['execution_budget']['id']}/deactivate", + payload={ + "user_id": str(owner["user_id"]), + "thread_id": str(owner["thread_id"]), + }, + ) + second_status, second_payload = invoke_request( + "POST", + f"/v0/execution-budgets/{create_payload['execution_budget']['id']}/deactivate", + payload={ + "user_id": str(owner["user_id"]), + "thread_id": str(owner["thread_id"]), + }, + ) + + assert first_status == 200 + assert second_status == 409 + assert second_payload == { + "detail": f"execution budget {create_payload['execution_budget']['id']} is inactive and cannot be deactivated" + } + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + trace_rows = store.conn.execute( + "SELECT id FROM traces WHERE kind = %s ORDER BY created_at ASC, id ASC", + ("execution_budget.lifecycle",), + ).fetchall() + rejected_trace_events = store.list_trace_events(trace_rows[-1]["id"]) + + assert rejected_trace_events[1]["payload"]["rejection_reason"] == second_payload["detail"] + assert rejected_trace_events[2]["payload"]["outcome"] == "rejected" + + +def test_execution_budget_active_scope_uniqueness_is_enforced_in_database( + migrated_database_urls, 
+) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner@example.com") + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + store.create_execution_budget( + tool_key="proxy.echo", + domain_hint="docs", + max_completed_executions=1, + ) + + with pytest.raises(psycopg.IntegrityError): + with conn.transaction(): + store.create_execution_budget( + tool_key="proxy.echo", + domain_hint="docs", + max_completed_executions=2, + ) diff --git a/tests/integration/test_explicit_preferences_api.py b/tests/integration/test_explicit_preferences_api.py new file mode 100644 index 0000000..67c4db2 --- /dev/null +++ b/tests/integration/test_explicit_preferences_api.py @@ -0,0 +1,398 @@ +from __future__ import annotations + +import json +from uuid import UUID, uuid4 + +import anyio + +import apps.api.src.alicebot_api.main as main_module +from apps.api.src.alicebot_api.config import Settings +from alicebot_api.db import user_connection +from alicebot_api.explicit_preferences import _build_memory_key +from alicebot_api.store import ContinuityStore + + +def invoke_extract_explicit_preferences(payload: dict[str, str]) -> tuple[int, dict[str, object]]: + messages: list[dict[str, object]] = [] + encoded_body = json.dumps(payload).encode() + request_received = False + + async def receive() -> dict[str, object]: + nonlocal request_received + if request_received: + return {"type": "http.disconnect"} + + request_received = True + return {"type": "http.request", "body": encoded_body, "more_body": False} + + async def send(message: dict[str, object]) -> None: + messages.append(message) + + scope = { + "type": "http", + "asgi": {"version": "3.0"}, + "http_version": "1.1", + "method": "POST", + "scheme": "http", + "path": "/v0/memories/extract-explicit-preferences", + "raw_path": b"/v0/memories/extract-explicit-preferences", + "query_string": b"", + "headers": [(b"content-type", b"application/json")], + "client": ("testclient", 50000), + "server": ("testserver", 80), + "root_path": "", + } + + anyio.run(main_module.app, scope, receive, send) + + start_message = next(message for message in messages if message["type"] == "http.response.start") + body = b"".join( + message.get("body", b"") + for message in messages + if message["type"] == "http.response.body" + ) + return start_message["status"], json.loads(body) + + +def seed_explicit_preference_events(database_url: str) -> dict[str, UUID]: + user_id = uuid4() + + with user_connection(database_url, user_id) as conn: + store = ContinuityStore(conn) + store.create_user(user_id, "owner@example.com", "Owner") + thread = store.create_thread("Explicit preference extraction") + session = store.create_session(thread["id"], status="active") + like_event = store.append_event( + thread["id"], + session["id"], + "message.user", + {"text": "I like black coffee."}, + )["id"] + dislike_event = store.append_event( + thread["id"], + session["id"], + "message.user", + {"text": "I don't like black coffee."}, + )["id"] + unsupported_event = store.append_event( + thread["id"], + session["id"], + "message.user", + {"text": "I had coffee yesterday."}, + )["id"] + clause_event = store.append_event( + thread["id"], + session["id"], + "message.user", + {"text": "I prefer that we meet tomorrow."}, + )["id"] + cpp_event = store.append_event( + thread["id"], + session["id"], + "message.user", + {"text": "I like C++."}, + )["id"] + csharp_event = store.append_event( + thread["id"], + session["id"], + "message.user", + 
{"text": "I like C#."}, + )["id"] + assistant_event = store.append_event( + thread["id"], + session["id"], + "message.assistant", + {"text": "I like black coffee."}, + )["id"] + + return { + "user_id": user_id, + "like_event_id": like_event, + "dislike_event_id": dislike_event, + "unsupported_event_id": unsupported_event, + "clause_event_id": clause_event, + "cpp_event_id": cpp_event, + "csharp_event_id": csharp_event, + "assistant_event_id": assistant_event, + } + + +def test_extract_explicit_preferences_endpoint_admits_supported_candidates_and_persists_revisions( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_explicit_preference_events(migrated_database_urls["app"]) + memory_key = _build_memory_key("black coffee") + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + add_status, add_payload = invoke_extract_explicit_preferences( + { + "user_id": str(seeded["user_id"]), + "source_event_id": str(seeded["like_event_id"]), + } + ) + update_status, update_payload = invoke_extract_explicit_preferences( + { + "user_id": str(seeded["user_id"]), + "source_event_id": str(seeded["dislike_event_id"]), + } + ) + + assert add_status == 200 + assert add_payload["candidates"] == [ + { + "memory_key": memory_key, + "value": { + "kind": "explicit_preference", + "preference": "like", + "text": "black coffee", + }, + "source_event_ids": [str(seeded["like_event_id"])], + "delete_requested": False, + "pattern": "i_like", + "subject_text": "black coffee", + } + ] + assert add_payload["admissions"][0]["decision"] == "ADD" + assert add_payload["summary"] == { + "source_event_id": str(seeded["like_event_id"]), + "source_event_kind": "message.user", + "candidate_count": 1, + "admission_count": 1, + "persisted_change_count": 1, + "noop_count": 0, + } + + assert update_status == 200 + assert update_payload["candidates"] == [ + { + "memory_key": memory_key, + "value": { + "kind": "explicit_preference", + "preference": "dislike", + "text": "black coffee", + }, + "source_event_ids": [str(seeded["dislike_event_id"])], + "delete_requested": False, + "pattern": "i_dont_like", + "subject_text": "black coffee", + } + ] + assert update_payload["admissions"][0]["decision"] == "UPDATE" + assert update_payload["summary"] == { + "source_event_id": str(seeded["dislike_event_id"]), + "source_event_kind": "message.user", + "candidate_count": 1, + "admission_count": 1, + "persisted_change_count": 1, + "noop_count": 0, + } + + memory_id = UUID(str(update_payload["admissions"][0]["memory"]["id"])) + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + store = ContinuityStore(conn) + memories = store.list_memories() + revisions = store.list_memory_revisions(memory_id) + + assert len(memories) == 1 + assert memories[0]["id"] == memory_id + assert memories[0]["memory_key"] == memory_key + assert memories[0]["value"] == { + "kind": "explicit_preference", + "preference": "dislike", + "text": "black coffee", + } + assert [revision["action"] for revision in revisions] == ["ADD", "UPDATE"] + assert revisions[0]["candidate"] == { + "memory_key": memory_key, + "value": { + "kind": "explicit_preference", + "preference": "like", + "text": "black coffee", + }, + "source_event_ids": [str(seeded["like_event_id"])], + "delete_requested": False, + } + assert revisions[1]["candidate"] == { + "memory_key": memory_key, + "value": { + "kind": "explicit_preference", + "preference": "dislike", + "text": "black coffee", + }, + 
"source_event_ids": [str(seeded["dislike_event_id"])], + "delete_requested": False, + } + + +def test_extract_explicit_preferences_endpoint_returns_no_candidates_for_unsupported_text( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_explicit_preference_events(migrated_database_urls["app"]) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + status_code, payload = invoke_extract_explicit_preferences( + { + "user_id": str(seeded["user_id"]), + "source_event_id": str(seeded["unsupported_event_id"]), + } + ) + + assert status_code == 200 + assert payload == { + "candidates": [], + "admissions": [], + "summary": { + "source_event_id": str(seeded["unsupported_event_id"]), + "source_event_kind": "message.user", + "candidate_count": 0, + "admission_count": 0, + "persisted_change_count": 0, + "noop_count": 0, + }, + } + + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + store = ContinuityStore(conn) + assert store.list_memories() == [] + + +def test_extract_explicit_preferences_endpoint_rejects_clause_style_tail( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_explicit_preference_events(migrated_database_urls["app"]) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + status_code, payload = invoke_extract_explicit_preferences( + { + "user_id": str(seeded["user_id"]), + "source_event_id": str(seeded["clause_event_id"]), + } + ) + + assert status_code == 200 + assert payload == { + "candidates": [], + "admissions": [], + "summary": { + "source_event_id": str(seeded["clause_event_id"]), + "source_event_kind": "message.user", + "candidate_count": 0, + "admission_count": 0, + "persisted_change_count": 0, + "noop_count": 0, + }, + } + + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + store = ContinuityStore(conn) + assert store.list_memories() == [] + + +def test_extract_explicit_preferences_endpoint_keeps_symbol_subjects_in_distinct_memories( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_explicit_preference_events(migrated_database_urls["app"]) + cpp_key = _build_memory_key("C++") + csharp_key = _build_memory_key("C#") + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + cpp_status, cpp_payload = invoke_extract_explicit_preferences( + { + "user_id": str(seeded["user_id"]), + "source_event_id": str(seeded["cpp_event_id"]), + } + ) + csharp_status, csharp_payload = invoke_extract_explicit_preferences( + { + "user_id": str(seeded["user_id"]), + "source_event_id": str(seeded["csharp_event_id"]), + } + ) + + assert cpp_status == 200 + assert cpp_payload["candidates"][0]["memory_key"] == cpp_key + assert cpp_payload["admissions"][0]["decision"] == "ADD" + assert csharp_status == 200 + assert csharp_payload["candidates"][0]["memory_key"] == csharp_key + assert csharp_payload["admissions"][0]["decision"] == "ADD" + assert cpp_key != csharp_key + + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + store = ContinuityStore(conn) + memories = sorted(store.list_memories(), key=lambda memory: memory["memory_key"]) + + assert [memory["memory_key"] for memory in memories] == sorted([cpp_key, csharp_key]) + assert {memory["memory_key"]: memory["value"] for memory in memories} == { + cpp_key: { + "kind": "explicit_preference", + 
"preference": "like", + "text": "C++", + }, + csharp_key: { + "kind": "explicit_preference", + "preference": "like", + "text": "C#", + }, + } + + +def test_extract_explicit_preferences_endpoint_validates_source_event_and_user_scope( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_explicit_preference_events(migrated_database_urls["app"]) + intruder_id = uuid4() + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + with user_connection(migrated_database_urls["app"], intruder_id) as conn: + ContinuityStore(conn).create_user(intruder_id, "intruder@example.com", "Intruder") + + assistant_status, assistant_payload = invoke_extract_explicit_preferences( + { + "user_id": str(seeded["user_id"]), + "source_event_id": str(seeded["assistant_event_id"]), + } + ) + intruder_status, intruder_payload = invoke_extract_explicit_preferences( + { + "user_id": str(intruder_id), + "source_event_id": str(seeded["like_event_id"]), + } + ) + + assert assistant_status == 400 + assert assistant_payload == { + "detail": "source_event_id must reference an existing message.user event owned by the user" + } + assert intruder_status == 400 + assert intruder_payload == { + "detail": "source_event_id must reference an existing message.user event owned by the user" + } + + with user_connection(migrated_database_urls["app"], intruder_id) as conn: + store = ContinuityStore(conn) + assert store.list_memories() == [] diff --git a/tests/integration/test_healthcheck.py b/tests/integration/test_healthcheck.py new file mode 100644 index 0000000..47801f1 --- /dev/null +++ b/tests/integration/test_healthcheck.py @@ -0,0 +1,175 @@ +from __future__ import annotations + +import json +import os +from pathlib import Path +import socket +import subprocess +import time +from urllib import error, request + +import anyio + +import apps.api.src.alicebot_api.main as main_module +from apps.api.src.alicebot_api.config import Settings + + +REPO_ROOT = Path(__file__).resolve().parents[2] + + +def invoke_healthcheck() -> tuple[int, dict[str, object]]: + messages: list[dict[str, object]] = [] + + async def receive() -> dict[str, object]: + return {"type": "http.request", "body": b"", "more_body": False} + + async def send(message: dict[str, object]) -> None: + messages.append(message) + + scope = { + "type": "http", + "asgi": {"version": "3.0"}, + "http_version": "1.1", + "method": "GET", + "scheme": "http", + "path": "/healthz", + "raw_path": b"/healthz", + "query_string": b"", + "headers": [], + "client": ("testclient", 50000), + "server": ("testserver", 80), + "root_path": "", + } + + anyio.run(main_module.app, scope, receive, send) + + start_message = next(message for message in messages if message["type"] == "http.response.start") + body = b"".join( + message.get("body", b"") + for message in messages + if message["type"] == "http.response.body" + ) + return start_message["status"], json.loads(body) + + +def test_healthcheck_endpoint_returns_ok_response(monkeypatch) -> None: + settings = Settings( + app_env="test", + database_url="postgresql://db", + redis_url="redis://alicebot:supersecret@cache:6379/0", + s3_endpoint_url="http://object-store", + healthcheck_timeout_seconds=2, + ) + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "ping_database", lambda *_args, **_kwargs: True) + + status_code, payload = invoke_healthcheck() + + assert status_code == 200 + assert payload["status"] == "ok" + assert 
payload["services"]["database"]["status"] == "ok" + assert payload["services"]["redis"]["status"] == "not_checked" + assert payload["services"]["redis"]["url"] == "redis://cache:6379/0" + assert payload["services"]["object_storage"]["status"] == "not_checked" + + +def test_healthcheck_endpoint_returns_degraded_response(monkeypatch) -> None: + settings = Settings( + app_env="test", + database_url="postgresql://db", + redis_url="redis://alicebot:supersecret@cache:6379/0", + s3_endpoint_url="http://object-store", + healthcheck_timeout_seconds=2, + ) + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "ping_database", lambda *_args, **_kwargs: False) + + status_code, payload = invoke_healthcheck() + + assert status_code == 503 + assert payload["status"] == "degraded" + assert payload["services"]["database"]["status"] == "unreachable" + assert payload["services"]["redis"]["status"] == "not_checked" + assert payload["services"]["redis"]["url"] == "redis://cache:6379/0" + assert payload["services"]["object_storage"]["status"] == "not_checked" + + +def test_api_dev_script_serves_live_healthcheck() -> None: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("127.0.0.1", 0)) + port = sock.getsockname()[1] + + env = os.environ.copy() + env.update( + { + "APP_HOST": "127.0.0.1", + "APP_PORT": str(port), + "APP_RELOAD": "false", + "APP_ENV": "test", + "DATABASE_URL": "postgresql://invalid:invalid@127.0.0.1:1/invalid", + "REDIS_URL": "redis://alicebot:supersecret@localhost:6379/0", + "HEALTHCHECK_TIMEOUT_SECONDS": "1", + } + ) + + process = subprocess.Popen( + ["/bin/bash", str(REPO_ROOT / "scripts" / "api_dev.sh")], + cwd=REPO_ROOT, + env=env, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + + payload: dict[str, object] | None = None + status_code: int | None = None + + try: + deadline = time.time() + 15 + url = f"http://127.0.0.1:{port}/healthz" + + while time.time() < deadline: + if process.poll() is not None: + stdout, stderr = process.communicate(timeout=1) + raise AssertionError( + "api_dev.sh exited before serving /healthz\n" + f"stdout:\n{stdout}\n" + f"stderr:\n{stderr}" + ) + + try: + with request.urlopen(url, timeout=0.5) as response: + status_code = response.status + payload = json.loads(response.read()) + break + except error.HTTPError as exc: + status_code = exc.code + payload = json.loads(exc.read()) + break + except OSError: + time.sleep(0.1) + else: + raise AssertionError("Timed out waiting for api_dev.sh to serve /healthz") + finally: + process.terminate() + try: + process.communicate(timeout=5) + except subprocess.TimeoutExpired: + process.kill() + process.communicate(timeout=5) + + assert status_code == 503 + assert payload == { + "status": "degraded", + "environment": "test", + "services": { + "database": {"status": "unreachable"}, + "redis": {"status": "not_checked", "url": "redis://localhost:6379/0"}, + "object_storage": { + "status": "not_checked", + "endpoint_url": "http://localhost:9000", + }, + }, + } diff --git a/tests/integration/test_memory_admission.py b/tests/integration/test_memory_admission.py new file mode 100644 index 0000000..43c4e75 --- /dev/null +++ b/tests/integration/test_memory_admission.py @@ -0,0 +1,252 @@ +from __future__ import annotations + +import json +from typing import Any +from uuid import UUID, uuid4 + +import anyio +import psycopg +import pytest + +import apps.api.src.alicebot_api.main as main_module +from apps.api.src.alicebot_api.config import 
Settings +from alicebot_api.db import user_connection +from alicebot_api.store import ContinuityStore + + +def invoke_admit_memory(payload: dict[str, Any]) -> tuple[int, dict[str, Any]]: + messages: list[dict[str, object]] = [] + encoded_body = json.dumps(payload).encode() + request_received = False + + async def receive() -> dict[str, object]: + nonlocal request_received + if request_received: + return {"type": "http.disconnect"} + + request_received = True + return {"type": "http.request", "body": encoded_body, "more_body": False} + + async def send(message: dict[str, object]) -> None: + messages.append(message) + + scope = { + "type": "http", + "asgi": {"version": "3.0"}, + "http_version": "1.1", + "method": "POST", + "scheme": "http", + "path": "/v0/memories/admit", + "raw_path": b"/v0/memories/admit", + "query_string": b"", + "headers": [(b"content-type", b"application/json")], + "client": ("testclient", 50000), + "server": ("testserver", 80), + "root_path": "", + } + + anyio.run(main_module.app, scope, receive, send) + + start_message = next(message for message in messages if message["type"] == "http.response.start") + body = b"".join( + message.get("body", b"") + for message in messages + if message["type"] == "http.response.body" + ) + return start_message["status"], json.loads(body) + + +def seed_memory_evidence(database_url: str) -> tuple[UUID, list[UUID]]: + user_id = uuid4() + + with user_connection(database_url, user_id) as conn: + store = ContinuityStore(conn) + store.create_user(user_id, "owner@example.com", "Owner") + thread = store.create_thread("Memory thread") + session = store.create_session(thread["id"], status="active") + event_ids = [ + store.append_event(thread["id"], session["id"], "message.user", {"text": "likes black coffee"})["id"], + store.append_event(thread["id"], session["id"], "message.user", {"text": "actually likes oat milk"})["id"], + store.append_event(thread["id"], session["id"], "message.user", {"text": "stop remembering coffee"})["id"], + ] + + return user_id, event_ids + + +def test_admit_memory_endpoint_returns_noop_and_persists_nothing_without_value( + migrated_database_urls, + monkeypatch, +) -> None: + user_id, event_ids = seed_memory_evidence(migrated_database_urls["app"]) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + status_code, payload = invoke_admit_memory( + { + "user_id": str(user_id), + "memory_key": "user.preference.coffee", + "value": None, + "source_event_ids": [str(event_ids[0])], + } + ) + + assert status_code == 200 + assert payload == { + "decision": "NOOP", + "reason": "candidate_value_missing", + "memory": None, + "revision": None, + } + + with user_connection(migrated_database_urls["app"], user_id) as conn: + store = ContinuityStore(conn) + assert store.list_memories() == [] + + +def test_admit_memory_endpoint_rejects_unknown_source_events(migrated_database_urls, monkeypatch) -> None: + user_id = uuid4() + + with user_connection(migrated_database_urls["app"], user_id) as conn: + ContinuityStore(conn).create_user(user_id, "owner@example.com", "Owner") + + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + status_code, payload = invoke_admit_memory( + { + "user_id": str(user_id), + "memory_key": "user.preference.coffee", + "value": {"likes": "black"}, + "source_event_ids": [str(uuid4())], + } + ) + + assert status_code == 400 + assert payload["detail"].startswith( + 
"source_event_ids must all reference existing events owned by the user" + ) + + +def test_admit_memory_endpoint_persists_add_update_and_delete_revisions( + migrated_database_urls, + monkeypatch, +) -> None: + user_id, event_ids = seed_memory_evidence(migrated_database_urls["app"]) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + add_status, add_payload = invoke_admit_memory( + { + "user_id": str(user_id), + "memory_key": "user.preference.coffee", + "value": {"likes": "black"}, + "source_event_ids": [str(event_ids[0])], + } + ) + update_status, update_payload = invoke_admit_memory( + { + "user_id": str(user_id), + "memory_key": "user.preference.coffee", + "value": {"likes": "oat milk"}, + "source_event_ids": [str(event_ids[1])], + } + ) + delete_status, delete_payload = invoke_admit_memory( + { + "user_id": str(user_id), + "memory_key": "user.preference.coffee", + "value": None, + "source_event_ids": [str(event_ids[2])], + "delete_requested": True, + } + ) + + assert add_status == 200 + assert add_payload["decision"] == "ADD" + assert update_status == 200 + assert update_payload["decision"] == "UPDATE" + assert delete_status == 200 + assert delete_payload["decision"] == "DELETE" + + memory_id = UUID(delete_payload["memory"]["id"]) + with user_connection(migrated_database_urls["app"], user_id) as conn: + store = ContinuityStore(conn) + memories = store.list_memories() + revisions = store.list_memory_revisions(memory_id) + + assert len(memories) == 1 + assert memories[0]["id"] == memory_id + assert memories[0]["status"] == "deleted" + assert memories[0]["source_event_ids"] == [str(event_ids[2])] + assert [revision["sequence_no"] for revision in revisions] == [1, 2, 3] + assert [revision["action"] for revision in revisions] == ["ADD", "UPDATE", "DELETE"] + assert revisions[0]["new_value"] == {"likes": "black"} + assert revisions[1]["previous_value"] == {"likes": "black"} + assert revisions[1]["new_value"] == {"likes": "oat milk"} + assert revisions[2]["previous_value"] == {"likes": "oat milk"} + assert revisions[2]["new_value"] is None + + with psycopg.connect(migrated_database_urls["admin"]) as conn: + with conn.cursor() as cur: + with pytest.raises(psycopg.Error, match="append-only"): + cur.execute( + "UPDATE memory_revisions SET action = 'MUTATED' WHERE memory_id = %s", + (memory_id,), + ) + + +def test_memories_and_memory_revisions_respect_per_user_isolation( + migrated_database_urls, + monkeypatch, +) -> None: + owner_id, event_ids = seed_memory_evidence(migrated_database_urls["app"]) + intruder_id = uuid4() + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + status_code, payload = invoke_admit_memory( + { + "user_id": str(owner_id), + "memory_key": "user.preference.coffee", + "value": {"likes": "black"}, + "source_event_ids": [str(event_ids[0])], + } + ) + + assert status_code == 200 + memory_id = UUID(payload["memory"]["id"]) + + with user_connection(migrated_database_urls["app"], intruder_id) as conn: + store = ContinuityStore(conn) + store.create_user(intruder_id, "intruder@example.com", "Intruder") + with conn.cursor() as cur: + cur.execute("SELECT COUNT(*) AS count FROM memories WHERE id = %s", (memory_id,)) + memory_count = cur.fetchone() + cur.execute( + "SELECT COUNT(*) AS count FROM memory_revisions WHERE memory_id = %s", + (memory_id,), + ) + revision_count = cur.fetchone() + cur.execute( + "UPDATE memories SET status = 
'deleted' WHERE id = %s RETURNING id", + (memory_id,), + ) + updated_rows = cur.fetchall() + + assert memory_count["count"] == 0 + assert revision_count["count"] == 0 + assert updated_rows == [] + assert store.list_memories() == [] + assert store.list_memory_revisions(memory_id) == [] diff --git a/tests/integration/test_memory_review_api.py b/tests/integration/test_memory_review_api.py new file mode 100644 index 0000000..c096817 --- /dev/null +++ b/tests/integration/test_memory_review_api.py @@ -0,0 +1,526 @@ +from __future__ import annotations + +import json +from typing import Any +from urllib.parse import urlencode +from uuid import UUID, uuid4 + +import anyio + +import apps.api.src.alicebot_api.main as main_module +from apps.api.src.alicebot_api.config import Settings +from alicebot_api.contracts import MemoryCandidateInput +from alicebot_api.db import user_connection +from alicebot_api.memory import admit_memory_candidate +from alicebot_api.store import ContinuityStore + + +def invoke_request( + method: str, + path: str, + *, + query_params: dict[str, str] | None = None, + payload: dict[str, Any] | None = None, +) -> tuple[int, dict[str, Any]]: + messages: list[dict[str, object]] = [] + encoded_body = b"" if payload is None else json.dumps(payload).encode() + request_received = False + + async def receive() -> dict[str, object]: + nonlocal request_received + if request_received: + return {"type": "http.disconnect"} + + request_received = True + return {"type": "http.request", "body": encoded_body, "more_body": False} + + async def send(message: dict[str, object]) -> None: + messages.append(message) + + query_string = urlencode(query_params or {}).encode() + scope = { + "type": "http", + "asgi": {"version": "3.0"}, + "http_version": "1.1", + "method": method, + "scheme": "http", + "path": path, + "raw_path": path.encode(), + "query_string": query_string, + "headers": [(b"content-type", b"application/json")], + "client": ("testclient", 50000), + "server": ("testserver", 80), + "root_path": "", + } + + anyio.run(main_module.app, scope, receive, send) + + start_message = next(message for message in messages if message["type"] == "http.response.start") + body = b"".join( + message.get("body", b"") + for message in messages + if message["type"] == "http.response.body" + ) + return start_message["status"], json.loads(body) + + +def seed_review_memories(database_url: str) -> dict[str, str]: + user_id = uuid4() + + with user_connection(database_url, user_id) as conn: + store = ContinuityStore(conn) + store.create_user(user_id, "reviewer@example.com", "Reviewer") + thread = store.create_thread("Memory review thread") + session = store.create_session(thread["id"], status="active") + event_ids = [ + store.append_event(thread["id"], session["id"], "message.user", {"text": "likes black coffee"})["id"], + store.append_event(thread["id"], session["id"], "message.user", {"text": "likes salty snacks"})["id"], + store.append_event(thread["id"], session["id"], "message.user", {"text": "reads science fiction"})["id"], + store.append_event(thread["id"], session["id"], "message.user", {"text": "enjoys hiking"})["id"], + store.append_event(thread["id"], session["id"], "message.user", {"text": "forget the snack preference"})["id"], + store.append_event(thread["id"], session["id"], "message.user", {"text": "actually likes oat milk"})["id"], + ] + + coffee = admit_memory_candidate( + store, + user_id=user_id, + candidate=MemoryCandidateInput( + memory_key="user.preference.coffee", + value={"likes": "black"}, + 
source_event_ids=(event_ids[0],), + ), + ) + snack = admit_memory_candidate( + store, + user_id=user_id, + candidate=MemoryCandidateInput( + memory_key="user.preference.snack", + value={"likes": "chips"}, + source_event_ids=(event_ids[1],), + ), + ) + book = admit_memory_candidate( + store, + user_id=user_id, + candidate=MemoryCandidateInput( + memory_key="user.preference.book", + value={"genre": "science fiction"}, + source_event_ids=(event_ids[2],), + ), + ) + hobby = admit_memory_candidate( + store, + user_id=user_id, + candidate=MemoryCandidateInput( + memory_key="user.preference.hobby", + value={"likes": "hiking"}, + source_event_ids=(event_ids[3],), + ), + ) + admit_memory_candidate( + store, + user_id=user_id, + candidate=MemoryCandidateInput( + memory_key="user.preference.snack", + value=None, + source_event_ids=(event_ids[4],), + delete_requested=True, + ), + ) + admit_memory_candidate( + store, + user_id=user_id, + candidate=MemoryCandidateInput( + memory_key="user.preference.coffee", + value={"likes": "oat milk"}, + source_event_ids=(event_ids[5],), + ), + ) + + return { + "user_id": str(user_id), + "coffee_memory_id": coffee.memory["id"], + "snack_memory_id": snack.memory["id"], + "book_memory_id": book.memory["id"], + "hobby_memory_id": hobby.memory["id"], + "coffee_add_event_id": str(event_ids[0]), + "coffee_update_event_id": str(event_ids[5]), + "book_add_event_id": str(event_ids[2]), + "hobby_add_event_id": str(event_ids[3]), + "snack_delete_event_id": str(event_ids[4]), + } + + +def seed_review_queue_state(database_url: str) -> dict[str, str]: + seeded = seed_review_memories(database_url) + + with user_connection(database_url, UUID(seeded["user_id"])) as conn: + store = ContinuityStore(conn) + store.create_memory_review_label( + memory_id=UUID(seeded["hobby_memory_id"]), + label="correct", + note="Already reviewed.", + ) + store.create_memory_review_label( + memory_id=UUID(seeded["snack_memory_id"]), + label="outdated", + note="Deleted memory remains part of evaluation counts only.", + ) + + return seeded + + +def seed_memory_evaluation_state(database_url: str) -> dict[str, str]: + seeded = seed_review_memories(database_url) + + with user_connection(database_url, UUID(seeded["user_id"])) as conn: + store = ContinuityStore(conn) + store.create_memory_review_label( + memory_id=UUID(seeded["coffee_memory_id"]), + label="correct", + note="Matches the latest coffee preference.", + ) + store.create_memory_review_label( + memory_id=UUID(seeded["coffee_memory_id"]), + label="insufficient_evidence", + note="One source event is still a thin basis.", + ) + store.create_memory_review_label( + memory_id=UUID(seeded["snack_memory_id"]), + label="outdated", + note="The deleted snack preference is superseded.", + ) + + return seeded + + +def test_list_memories_endpoint_returns_filtered_memories_with_deterministic_order( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_review_memories(migrated_database_urls["app"]) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + status_code, payload = invoke_request( + "GET", + "/v0/memories", + query_params={ + "user_id": seeded["user_id"], + "status": "active", + "limit": "2", + }, + ) + + assert status_code == 200 + assert [item["id"] for item in payload["items"]] == [ + seeded["coffee_memory_id"], + seeded["hobby_memory_id"], + ] + assert payload["items"][0]["status"] == "active" + assert payload["items"][0]["value"] == {"likes": "oat milk"} + assert 
payload["items"][0]["source_event_ids"] == [seeded["coffee_update_event_id"]] + assert payload["summary"] == { + "status": "active", + "limit": 2, + "returned_count": 2, + "total_count": 3, + "has_more": True, + "order": ["updated_at_desc", "created_at_desc", "id_desc"], + } + + deleted_status, deleted_payload = invoke_request( + "GET", + "/v0/memories", + query_params={ + "user_id": seeded["user_id"], + "status": "deleted", + "limit": "5", + }, + ) + + assert deleted_status == 200 + assert deleted_payload["items"] == [ + { + "id": seeded["snack_memory_id"], + "memory_key": "user.preference.snack", + "value": {"likes": "chips"}, + "status": "deleted", + "source_event_ids": [seeded["snack_delete_event_id"]], + "created_at": deleted_payload["items"][0]["created_at"], + "updated_at": deleted_payload["items"][0]["updated_at"], + "deleted_at": deleted_payload["items"][0]["deleted_at"], + } + ] + assert deleted_payload["summary"] == { + "status": "deleted", + "limit": 5, + "returned_count": 1, + "total_count": 1, + "has_more": False, + "order": ["updated_at_desc", "created_at_desc", "id_desc"], + } + + +def test_memory_review_endpoints_return_current_memory_and_revision_history( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_review_memories(migrated_database_urls["app"]) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + memory_status, memory_payload = invoke_request( + "GET", + f"/v0/memories/{seeded['coffee_memory_id']}", + query_params={"user_id": seeded["user_id"]}, + ) + revisions_status, revisions_payload = invoke_request( + "GET", + f"/v0/memories/{seeded['coffee_memory_id']}/revisions", + query_params={"user_id": seeded["user_id"], "limit": "5"}, + ) + + assert memory_status == 200 + assert memory_payload["memory"]["id"] == seeded["coffee_memory_id"] + assert memory_payload["memory"]["memory_key"] == "user.preference.coffee" + assert memory_payload["memory"]["status"] == "active" + assert memory_payload["memory"]["value"] == {"likes": "oat milk"} + assert memory_payload["memory"]["source_event_ids"] == [seeded["coffee_update_event_id"]] + + assert revisions_status == 200 + assert [item["sequence_no"] for item in revisions_payload["items"]] == [1, 2] + assert [item["action"] for item in revisions_payload["items"]] == ["ADD", "UPDATE"] + assert revisions_payload["items"][0]["new_value"] == {"likes": "black"} + assert revisions_payload["items"][0]["source_event_ids"] == [seeded["coffee_add_event_id"]] + assert revisions_payload["items"][1]["previous_value"] == {"likes": "black"} + assert revisions_payload["items"][1]["new_value"] == {"likes": "oat milk"} + assert revisions_payload["items"][1]["source_event_ids"] == [seeded["coffee_update_event_id"]] + assert revisions_payload["summary"] == { + "memory_id": seeded["coffee_memory_id"], + "limit": 5, + "returned_count": 2, + "total_count": 2, + "has_more": False, + "order": ["sequence_no_asc"], + } + + +def test_memory_review_endpoints_enforce_per_user_isolation_and_not_found_behavior( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_review_memories(migrated_database_urls["app"]) + intruder_id = uuid4() + with user_connection(migrated_database_urls["app"], intruder_id) as conn: + ContinuityStore(conn).create_user(intruder_id, "intruder@example.com", "Intruder") + + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + list_status, list_payload = 
invoke_request( + "GET", + "/v0/memories", + query_params={"user_id": str(intruder_id), "status": "all", "limit": "10"}, + ) + memory_status, memory_payload = invoke_request( + "GET", + f"/v0/memories/{seeded['coffee_memory_id']}", + query_params={"user_id": str(intruder_id)}, + ) + revisions_status, revisions_payload = invoke_request( + "GET", + f"/v0/memories/{seeded['coffee_memory_id']}/revisions", + query_params={"user_id": str(intruder_id), "limit": "10"}, + ) + + assert list_status == 200 + assert list_payload == { + "items": [], + "summary": { + "status": "all", + "limit": 10, + "returned_count": 0, + "total_count": 0, + "has_more": False, + "order": ["updated_at_desc", "created_at_desc", "id_desc"], + }, + } + assert memory_status == 404 + assert memory_payload == { + "detail": f"memory {seeded['coffee_memory_id']} was not found", + } + assert revisions_status == 404 + assert revisions_payload == { + "detail": f"memory {seeded['coffee_memory_id']} was not found", + } + + +def test_memory_review_queue_endpoint_returns_only_active_unlabeled_memories_in_deterministic_order( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_review_queue_state(migrated_database_urls["app"]) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + status_code, payload = invoke_request( + "GET", + "/v0/memories/review-queue", + query_params={ + "user_id": seeded["user_id"], + "limit": "2", + }, + ) + + assert status_code == 200 + assert payload == { + "items": [ + { + "id": seeded["coffee_memory_id"], + "memory_key": "user.preference.coffee", + "value": {"likes": "oat milk"}, + "status": "active", + "source_event_ids": [seeded["coffee_update_event_id"]], + "created_at": payload["items"][0]["created_at"], + "updated_at": payload["items"][0]["updated_at"], + }, + { + "id": seeded["book_memory_id"], + "memory_key": "user.preference.book", + "value": {"genre": "science fiction"}, + "status": "active", + "source_event_ids": [seeded["book_add_event_id"]], + "created_at": payload["items"][1]["created_at"], + "updated_at": payload["items"][1]["updated_at"], + }, + ], + "summary": { + "memory_status": "active", + "review_state": "unlabeled", + "limit": 2, + "returned_count": 2, + "total_count": 2, + "has_more": False, + "order": ["updated_at_desc", "created_at_desc", "id_desc"], + }, + } + + +def test_memory_evaluation_summary_endpoint_returns_explicit_consistent_counts( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_memory_evaluation_state(migrated_database_urls["app"]) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + status_code, payload = invoke_request( + "GET", + "/v0/memories/evaluation-summary", + query_params={"user_id": seeded["user_id"]}, + ) + + assert status_code == 200 + assert payload == { + "summary": { + "total_memory_count": 4, + "active_memory_count": 3, + "deleted_memory_count": 1, + "labeled_memory_count": 2, + "unlabeled_memory_count": 2, + "total_label_row_count": 3, + "label_row_counts_by_value": { + "correct": 1, + "incorrect": 0, + "outdated": 1, + "insufficient_evidence": 1, + }, + "label_value_order": [ + "correct", + "incorrect", + "outdated", + "insufficient_evidence", + ], + } + } + + +def test_memory_review_queue_and_evaluation_summary_enforce_per_user_isolation( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_memory_evaluation_state(migrated_database_urls["app"]) + 
intruder_id = uuid4() + with user_connection(migrated_database_urls["app"], intruder_id) as conn: + ContinuityStore(conn).create_user(intruder_id, "intruder@example.com", "Intruder") + + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + queue_status, queue_payload = invoke_request( + "GET", + "/v0/memories/review-queue", + query_params={"user_id": str(intruder_id), "limit": "10"}, + ) + summary_status, summary_payload = invoke_request( + "GET", + "/v0/memories/evaluation-summary", + query_params={"user_id": str(intruder_id)}, + ) + + assert seeded["user_id"] != str(intruder_id) + assert queue_status == 200 + assert queue_payload == { + "items": [], + "summary": { + "memory_status": "active", + "review_state": "unlabeled", + "limit": 10, + "returned_count": 0, + "total_count": 0, + "has_more": False, + "order": ["updated_at_desc", "created_at_desc", "id_desc"], + }, + } + assert summary_status == 200 + assert summary_payload == { + "summary": { + "total_memory_count": 0, + "active_memory_count": 0, + "deleted_memory_count": 0, + "labeled_memory_count": 0, + "unlabeled_memory_count": 0, + "total_label_row_count": 0, + "label_row_counts_by_value": { + "correct": 0, + "incorrect": 0, + "outdated": 0, + "insufficient_evidence": 0, + }, + "label_value_order": [ + "correct", + "incorrect", + "outdated", + "insufficient_evidence", + ], + } + } diff --git a/tests/integration/test_memory_review_labels_api.py b/tests/integration/test_memory_review_labels_api.py new file mode 100644 index 0000000..1b184e8 --- /dev/null +++ b/tests/integration/test_memory_review_labels_api.py @@ -0,0 +1,333 @@ +from __future__ import annotations + +import json +from typing import Any +from urllib.parse import urlencode +from uuid import UUID, uuid4 + +import anyio +import psycopg +import pytest + +import apps.api.src.alicebot_api.main as main_module +from apps.api.src.alicebot_api.config import Settings +from alicebot_api.contracts import MemoryCandidateInput +from alicebot_api.db import user_connection +from alicebot_api.memory import admit_memory_candidate +from alicebot_api.store import ContinuityStore + + +def invoke_request( + method: str, + path: str, + *, + query_params: dict[str, str] | None = None, + payload: dict[str, Any] | None = None, +) -> tuple[int, dict[str, Any]]: + messages: list[dict[str, object]] = [] + encoded_body = b"" if payload is None else json.dumps(payload).encode() + request_received = False + + async def receive() -> dict[str, object]: + nonlocal request_received + if request_received: + return {"type": "http.disconnect"} + + request_received = True + return {"type": "http.request", "body": encoded_body, "more_body": False} + + async def send(message: dict[str, object]) -> None: + messages.append(message) + + query_string = urlencode(query_params or {}).encode() + scope = { + "type": "http", + "asgi": {"version": "3.0"}, + "http_version": "1.1", + "method": method, + "scheme": "http", + "path": path, + "raw_path": path.encode(), + "query_string": query_string, + "headers": [(b"content-type", b"application/json")], + "client": ("testclient", 50000), + "server": ("testserver", 80), + "root_path": "", + } + + anyio.run(main_module.app, scope, receive, send) + + start_message = next(message for message in messages if message["type"] == "http.response.start") + body = b"".join( + message.get("body", b"") + for message in messages + if message["type"] == "http.response.body" + ) + return start_message["status"], 
json.loads(body) + + +def seed_memory_for_review_labels(database_url: str) -> dict[str, str]: + user_id = uuid4() + + with user_connection(database_url, user_id) as conn: + store = ContinuityStore(conn) + store.create_user(user_id, "reviewer@example.com", "Reviewer") + thread = store.create_thread("Memory review labels thread") + session = store.create_session(thread["id"], status="active") + event = store.append_event( + thread["id"], + session["id"], + "message.user", + {"text": "likes oat milk in coffee"}, + ) + decision = admit_memory_candidate( + store, + user_id=user_id, + candidate=MemoryCandidateInput( + memory_key="user.preference.coffee", + value={"likes": "oat milk"}, + source_event_ids=(event["id"],), + ), + ) + + assert decision.memory is not None + return { + "user_id": str(user_id), + "memory_id": decision.memory["id"], + } + + +def seed_intruder(database_url: str) -> UUID: + intruder_id = uuid4() + with user_connection(database_url, intruder_id) as conn: + ContinuityStore(conn).create_user(intruder_id, "intruder@example.com", "Intruder") + return intruder_id + + +def test_memory_review_label_endpoints_create_and_list_labels_with_stable_summary_counts( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_memory_for_review_labels(migrated_database_urls["app"]) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + first_status, first_payload = invoke_request( + "POST", + f"/v0/memories/{seeded['memory_id']}/labels", + payload={ + "user_id": seeded["user_id"], + "label": "correct", + "note": "Matches the latest admitted evidence.", + }, + ) + second_status, second_payload = invoke_request( + "POST", + f"/v0/memories/{seeded['memory_id']}/labels", + payload={ + "user_id": seeded["user_id"], + "label": "outdated", + "note": None, + }, + ) + list_status, list_payload = invoke_request( + "GET", + f"/v0/memories/{seeded['memory_id']}/labels", + query_params={"user_id": seeded["user_id"]}, + ) + + assert first_status == 201 + assert first_payload["label"]["memory_id"] == seeded["memory_id"] + assert first_payload["label"]["reviewer_user_id"] == seeded["user_id"] + assert first_payload["label"]["label"] == "correct" + assert first_payload["label"]["note"] == "Matches the latest admitted evidence." 
+ assert first_payload["summary"] == { + "memory_id": seeded["memory_id"], + "total_count": 1, + "counts_by_label": { + "correct": 1, + "incorrect": 0, + "outdated": 0, + "insufficient_evidence": 0, + }, + "order": ["created_at_asc", "id_asc"], + } + + assert second_status == 201 + assert second_payload["label"]["label"] == "outdated" + assert second_payload["label"]["note"] is None + assert second_payload["summary"] == { + "memory_id": seeded["memory_id"], + "total_count": 2, + "counts_by_label": { + "correct": 1, + "incorrect": 0, + "outdated": 1, + "insufficient_evidence": 0, + }, + "order": ["created_at_asc", "id_asc"], + } + + assert list_status == 200 + assert [item["id"] for item in list_payload["items"]] == [ + first_payload["label"]["id"], + second_payload["label"]["id"], + ] + assert list_payload["summary"] == second_payload["summary"] + + +def test_memory_review_label_listing_uses_deterministic_created_at_then_id_order( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_memory_for_review_labels(migrated_database_urls["app"]) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + with user_connection(migrated_database_urls["app"], UUID(seeded["user_id"])) as conn: + store = ContinuityStore(conn) + created_labels = [ + store.create_memory_review_label( + memory_id=UUID(seeded["memory_id"]), + label="incorrect", + note="Conflicts with the source event.", + ), + store.create_memory_review_label( + memory_id=UUID(seeded["memory_id"]), + label="insufficient_evidence", + note="The evidence is too weak.", + ), + store.create_memory_review_label( + memory_id=UUID(seeded["memory_id"]), + label="outdated", + note="Superseded by newer behavior.", + ), + ] + + status_code, payload = invoke_request( + "GET", + f"/v0/memories/{seeded['memory_id']}/labels", + query_params={"user_id": seeded["user_id"]}, + ) + + expected_ids = [ + str(label["id"]) + for label in sorted( + created_labels, + key=lambda label: (label["created_at"], label["id"]), + ) + ] + + assert status_code == 200 + assert [item["id"] for item in payload["items"]] == expected_ids + assert payload["summary"] == { + "memory_id": seeded["memory_id"], + "total_count": 3, + "counts_by_label": { + "correct": 0, + "incorrect": 1, + "outdated": 1, + "insufficient_evidence": 1, + }, + "order": ["created_at_asc", "id_asc"], + } + + +def test_memory_review_label_list_returns_empty_items_and_zero_filled_summary_for_unlabeled_memory( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_memory_for_review_labels(migrated_database_urls["app"]) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + status_code, payload = invoke_request( + "GET", + f"/v0/memories/{seeded['memory_id']}/labels", + query_params={"user_id": seeded["user_id"]}, + ) + + assert status_code == 200 + assert payload == { + "items": [], + "summary": { + "memory_id": seeded["memory_id"], + "total_count": 0, + "counts_by_label": { + "correct": 0, + "incorrect": 0, + "outdated": 0, + "insufficient_evidence": 0, + }, + "order": ["created_at_asc", "id_asc"], + }, + } + + +def test_memory_review_labels_reject_update_and_delete_at_database_level(migrated_database_urls) -> None: + seeded = seed_memory_for_review_labels(migrated_database_urls["app"]) + + with user_connection(migrated_database_urls["app"], UUID(seeded["user_id"])) as conn: + label = ContinuityStore(conn).create_memory_review_label( + 
memory_id=UUID(seeded["memory_id"]), + label="correct", + note="Initial review label.", + ) + + with psycopg.connect(migrated_database_urls["admin"]) as conn: + with pytest.raises(psycopg.Error, match="append-only"): + with conn.cursor() as cur: + cur.execute( + "UPDATE memory_review_labels SET label = 'incorrect' WHERE id = %s", + (label["id"],), + ) + + with psycopg.connect(migrated_database_urls["admin"]) as conn: + with pytest.raises(psycopg.Error, match="append-only"): + with conn.cursor() as cur: + cur.execute( + "DELETE FROM memory_review_labels WHERE id = %s", + (label["id"],), + ) + + +def test_memory_review_label_endpoints_enforce_per_user_isolation_and_not_found_behavior( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_memory_for_review_labels(migrated_database_urls["app"]) + intruder_id = seed_intruder(migrated_database_urls["app"]) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + create_status, create_payload = invoke_request( + "POST", + f"/v0/memories/{seeded['memory_id']}/labels", + payload={ + "user_id": str(intruder_id), + "label": "incorrect", + "note": "Should not be able to label another user's memory.", + }, + ) + list_status, list_payload = invoke_request( + "GET", + f"/v0/memories/{seeded['memory_id']}/labels", + query_params={"user_id": str(intruder_id)}, + ) + + assert create_status == 404 + assert create_payload == {"detail": f"memory {seeded['memory_id']} was not found"} + assert list_status == 404 + assert list_payload == {"detail": f"memory {seeded['memory_id']} was not found"} diff --git a/tests/integration/test_migrations.py b/tests/integration/test_migrations.py new file mode 100644 index 0000000..434645e --- /dev/null +++ b/tests/integration/test_migrations.py @@ -0,0 +1,798 @@ +from __future__ import annotations + +from alembic import command +import psycopg + +from alicebot_api.migrations import make_alembic_config + + +def test_tool_execution_task_step_linkage_migration_backfills_existing_rows(database_urls): + config = make_alembic_config(database_urls["admin"]) + user_id = "00000000-0000-0000-0000-000000000001" + thread_id = "00000000-0000-0000-0000-000000000002" + trace_id = "00000000-0000-0000-0000-000000000003" + tool_id = "00000000-0000-0000-0000-000000000004" + approval_id = "00000000-0000-0000-0000-000000000005" + task_id = "00000000-0000-0000-0000-000000000006" + task_step_id = "00000000-0000-0000-0000-000000000007" + execution_id = "00000000-0000-0000-0000-000000000008" + + command.upgrade(config, "20260313_0020") + + with psycopg.connect(database_urls["admin"]) as conn: + with conn.cursor() as cur: + cur.execute( + """ + INSERT INTO users (id, email, display_name) + VALUES (%s, 'migration@example.com', 'Migration User') + """, + (user_id,), + ) + cur.execute( + """ + INSERT INTO threads (id, user_id, title) + VALUES (%s, %s, 'Migration Thread') + """, + (thread_id, user_id), + ) + cur.execute( + """ + INSERT INTO traces ( + id, + user_id, + thread_id, + kind, + compiler_version, + status, + limits + ) + VALUES ( + %s, + %s, + %s, + 'migration.seed', + 'v0', + 'completed', + '{}'::jsonb + ) + """, + (trace_id, user_id, thread_id), + ) + cur.execute( + """ + INSERT INTO tools ( + id, + user_id, + tool_key, + name, + description, + version, + metadata_version, + active, + tags, + action_hints, + scope_hints, + domain_hints, + risk_hints, + metadata + ) + VALUES ( + %s, + %s, + 'proxy.echo', + 'Proxy Echo', + 'Seed tool for migration coverage', + 
'1.0.0', + 'tool_metadata_v0', + TRUE, + '[]'::jsonb, + '[]'::jsonb, + '[]'::jsonb, + '[]'::jsonb, + '[]'::jsonb, + '{}'::jsonb + ) + """, + (tool_id, user_id), + ) + cur.execute( + """ + INSERT INTO approvals ( + id, + user_id, + thread_id, + tool_id, + task_step_id, + status, + request, + tool, + routing, + routing_trace_id, + resolved_at, + resolved_by_user_id + ) + VALUES ( + %s, + %s, + %s, + %s, + NULL, + 'approved', + '{"action":"echo"}'::jsonb, + '{"id":"tool"}'::jsonb, + '{"decision":"approval_required"}'::jsonb, + %s, + now(), + %s + ) + """, + (approval_id, user_id, thread_id, tool_id, trace_id, user_id), + ) + cur.execute( + """ + INSERT INTO tasks ( + id, + user_id, + thread_id, + tool_id, + status, + request, + tool, + latest_approval_id, + latest_execution_id + ) + VALUES ( + %s, + %s, + %s, + %s, + 'approved', + '{"action":"echo"}'::jsonb, + '{"id":"tool"}'::jsonb, + %s, + NULL + ) + """, + (task_id, user_id, thread_id, tool_id, approval_id), + ) + cur.execute( + """ + INSERT INTO task_steps ( + id, + user_id, + task_id, + sequence_no, + kind, + status, + request, + outcome, + trace_id, + trace_kind + ) + VALUES ( + %s, + %s, + %s, + 1, + 'governed_request', + 'approved', + '{"action":"echo"}'::jsonb, + '{"routing_decision":"approval_required","approval_id":"00000000-0000-0000-0000-000000000005","approval_status":"approved","execution_id":null,"execution_status":null,"blocked_reason":null}'::jsonb, + %s, + 'migration.seed' + ) + """, + (task_step_id, user_id, task_id, trace_id), + ) + cur.execute( + """ + INSERT INTO tool_executions ( + id, + user_id, + approval_id, + thread_id, + tool_id, + trace_id, + request_event_id, + result_event_id, + status, + handler_key, + request, + tool, + result + ) + VALUES ( + %s, + %s, + %s, + %s, + %s, + %s, + NULL, + NULL, + 'blocked', + NULL, + '{"action":"echo"}'::jsonb, + '{"id":"tool"}'::jsonb, + '{"blocked_reason":"seed"}'::jsonb + ) + """, + (execution_id, user_id, approval_id, thread_id, tool_id, trace_id), + ) + conn.commit() + + command.upgrade(config, "head") + + with psycopg.connect(database_urls["admin"]) as conn: + with conn.cursor() as cur: + cur.execute( + """ + SELECT task_step_id + FROM tool_executions + WHERE id = %s + """, + (execution_id,), + ) + row = cur.fetchone() + assert row is not None + assert str(row[0]) == task_step_id + cur.execute( + """ + SELECT is_nullable + FROM information_schema.columns + WHERE table_schema = 'public' + AND table_name = 'tool_executions' + AND column_name = 'task_step_id' + """ + ) + assert cur.fetchone() == ("NO",) + + +def test_migrations_upgrade_and_downgrade(database_urls): + config = make_alembic_config(database_urls["admin"]) + + command.upgrade(config, "head") + + with psycopg.connect(database_urls["admin"]) as conn: + with conn.cursor() as cur: + cur.execute("SELECT to_regclass('public.users')") + assert cur.fetchone()[0] == "users" + cur.execute("SELECT to_regclass('public.threads')") + assert cur.fetchone()[0] == "threads" + cur.execute("SELECT to_regclass('public.sessions')") + assert cur.fetchone()[0] == "sessions" + cur.execute("SELECT to_regclass('public.events')") + assert cur.fetchone()[0] == "events" + cur.execute("SELECT to_regclass('public.memories')") + assert cur.fetchone()[0] == "memories" + cur.execute("SELECT to_regclass('public.memory_revisions')") + assert cur.fetchone()[0] == "memory_revisions" + cur.execute("SELECT to_regclass('public.memory_review_labels')") + assert cur.fetchone()[0] == "memory_review_labels" + cur.execute("SELECT 
to_regclass('public.entities')") + assert cur.fetchone()[0] == "entities" + cur.execute("SELECT to_regclass('public.entity_edges')") + assert cur.fetchone()[0] == "entity_edges" + cur.execute("SELECT to_regclass('public.embedding_configs')") + assert cur.fetchone()[0] == "embedding_configs" + cur.execute("SELECT to_regclass('public.memory_embeddings')") + assert cur.fetchone()[0] == "memory_embeddings" + cur.execute("SELECT to_regclass('public.consents')") + assert cur.fetchone()[0] == "consents" + cur.execute("SELECT to_regclass('public.policies')") + assert cur.fetchone()[0] == "policies" + cur.execute("SELECT to_regclass('public.tools')") + assert cur.fetchone()[0] == "tools" + cur.execute("SELECT to_regclass('public.approvals')") + assert cur.fetchone()[0] == "approvals" + cur.execute( + """ + SELECT column_name + FROM information_schema.columns + WHERE table_schema = 'public' + AND table_name = 'approvals' + AND column_name = 'task_step_id' + """ + ) + assert cur.fetchall() == [("task_step_id",)] + cur.execute("SELECT to_regclass('public.tasks')") + assert cur.fetchone()[0] == "tasks" + cur.execute("SELECT to_regclass('public.task_workspaces')") + assert cur.fetchone()[0] == "task_workspaces" + cur.execute("SELECT to_regclass('public.task_steps')") + assert cur.fetchone()[0] == "task_steps" + cur.execute( + """ + SELECT column_name + FROM information_schema.columns + WHERE table_schema = 'public' + AND table_name = 'task_steps' + AND column_name IN ( + 'parent_step_id', + 'source_approval_id', + 'source_execution_id' + ) + ORDER BY column_name + """ + ) + assert cur.fetchall() == [ + ("parent_step_id",), + ("source_approval_id",), + ("source_execution_id",), + ] + cur.execute("SELECT to_regclass('public.tool_executions')") + assert cur.fetchone()[0] == "tool_executions" + cur.execute( + """ + SELECT column_name + FROM information_schema.columns + WHERE table_schema = 'public' + AND table_name = 'tool_executions' + AND column_name = 'task_step_id' + """ + ) + assert cur.fetchall() == [("task_step_id",)] + cur.execute("SELECT to_regclass('public.execution_budgets')") + assert cur.fetchone()[0] == "execution_budgets" + cur.execute( + """ + SELECT column_name + FROM information_schema.columns + WHERE table_schema = 'public' + AND table_name = 'execution_budgets' + AND column_name IN ( + 'status', + 'deactivated_at', + 'superseded_by_budget_id', + 'supersedes_budget_id' + ) + ORDER BY column_name + """ + ) + assert cur.fetchall() == [ + ("deactivated_at",), + ("status",), + ("superseded_by_budget_id",), + ("supersedes_budget_id",), + ] + cur.execute( + """ + SELECT c.relname, c.relrowsecurity, c.relforcerowsecurity + FROM pg_class AS c + JOIN pg_namespace AS n + ON n.oid = c.relnamespace + WHERE n.nspname = 'public' + AND c.relname IN ( + 'users', + 'threads', + 'sessions', + 'events', + 'memories', + 'memory_revisions', + 'memory_review_labels', + 'entities', + 'entity_edges', + 'embedding_configs', + 'memory_embeddings', + 'consents', + 'policies', + 'tools', + 'approvals', + 'tasks', + 'task_workspaces', + 'task_steps', + 'execution_budgets', + 'tool_executions' + ) + ORDER BY c.relname + """ + ) + assert cur.fetchall() == [ + ("approvals", True, True), + ("consents", True, True), + ("embedding_configs", True, True), + ("entities", True, True), + ("entity_edges", True, True), + ("events", True, True), + ("execution_budgets", True, True), + ("memories", True, True), + ("memory_embeddings", True, True), + ("memory_review_labels", True, True), + ("memory_revisions", True, True), + 
("policies", True, True), + ("sessions", True, True), + ("task_steps", True, True), + ("task_workspaces", True, True), + ("tasks", True, True), + ("threads", True, True), + ("tool_executions", True, True), + ("tools", True, True), + ("users", True, True), + ] + cur.execute( + """ + SELECT tgname + FROM pg_trigger + WHERE tgrelid = 'events'::regclass + AND NOT tgisinternal + """ + ) + assert cur.fetchall() == [("events_append_only",)] + cur.execute( + """ + SELECT tgname + FROM pg_trigger + WHERE tgrelid = 'memory_revisions'::regclass + AND NOT tgisinternal + """ + ) + assert cur.fetchall() == [("memory_revisions_append_only",)] + cur.execute( + """ + SELECT tgname + FROM pg_trigger + WHERE tgrelid = 'memory_review_labels'::regclass + AND NOT tgisinternal + """ + ) + assert cur.fetchall() == [("memory_review_labels_append_only",)] + cur.execute( + """ + SELECT + has_table_privilege('alicebot_app', 'users', 'UPDATE'), + has_table_privilege('alicebot_app', 'threads', 'UPDATE'), + has_table_privilege('alicebot_app', 'sessions', 'UPDATE'), + has_table_privilege('alicebot_app', 'memories', 'UPDATE'), + has_table_privilege('alicebot_app', 'memory_revisions', 'UPDATE'), + has_table_privilege('alicebot_app', 'memory_revisions', 'DELETE'), + has_table_privilege('alicebot_app', 'memory_review_labels', 'UPDATE'), + has_table_privilege('alicebot_app', 'memory_review_labels', 'DELETE'), + has_table_privilege('alicebot_app', 'entities', 'UPDATE'), + has_table_privilege('alicebot_app', 'entities', 'DELETE'), + has_table_privilege('alicebot_app', 'entity_edges', 'UPDATE'), + has_table_privilege('alicebot_app', 'entity_edges', 'DELETE'), + has_table_privilege('alicebot_app', 'embedding_configs', 'UPDATE'), + has_table_privilege('alicebot_app', 'embedding_configs', 'DELETE'), + has_table_privilege('alicebot_app', 'memory_embeddings', 'UPDATE'), + has_table_privilege('alicebot_app', 'memory_embeddings', 'DELETE'), + has_table_privilege('alicebot_app', 'consents', 'UPDATE'), + has_table_privilege('alicebot_app', 'consents', 'DELETE'), + has_table_privilege('alicebot_app', 'policies', 'UPDATE'), + has_table_privilege('alicebot_app', 'policies', 'DELETE'), + has_table_privilege('alicebot_app', 'tools', 'UPDATE'), + has_table_privilege('alicebot_app', 'tools', 'DELETE'), + has_table_privilege('alicebot_app', 'approvals', 'UPDATE'), + has_table_privilege('alicebot_app', 'approvals', 'DELETE'), + has_table_privilege('alicebot_app', 'tasks', 'UPDATE'), + has_table_privilege('alicebot_app', 'tasks', 'DELETE'), + has_table_privilege('alicebot_app', 'task_workspaces', 'UPDATE'), + has_table_privilege('alicebot_app', 'task_workspaces', 'DELETE'), + has_table_privilege('alicebot_app', 'task_steps', 'UPDATE'), + has_table_privilege('alicebot_app', 'task_steps', 'DELETE'), + has_table_privilege('alicebot_app', 'execution_budgets', 'UPDATE'), + has_table_privilege('alicebot_app', 'execution_budgets', 'DELETE'), + has_table_privilege('alicebot_app', 'tool_executions', 'UPDATE'), + has_table_privilege('alicebot_app', 'tool_executions', 'DELETE') + """ + ) + assert cur.fetchone() == ( + False, + False, + False, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + True, + False, + True, + False, + False, + False, + False, + False, + True, + False, + True, + False, + False, + False, + True, + False, + True, + False, + False, + False, + ) + + command.downgrade(config, "20260313_0021") + + with psycopg.connect(database_urls["admin"]) as conn: + with conn.cursor() as cur: + 
cur.execute("SELECT to_regclass('public.task_workspaces')") + assert cur.fetchone()[0] is None + cur.execute( + """ + SELECT column_name + FROM information_schema.columns + WHERE table_schema = 'public' + AND table_name = 'approvals' + AND column_name = 'task_step_id' + """ + ) + assert cur.fetchall() == [("task_step_id",)] + + command.downgrade(config, "20260313_0018") + + with psycopg.connect(database_urls["admin"]) as conn: + with conn.cursor() as cur: + cur.execute( + """ + SELECT column_name + FROM information_schema.columns + WHERE table_schema = 'public' + AND table_name = 'approvals' + AND column_name = 'task_step_id' + """ + ) + assert cur.fetchall() == [] + cur.execute("SELECT to_regclass('public.task_steps')") + assert cur.fetchone()[0] == "task_steps" + cur.execute( + """ + SELECT column_name + FROM information_schema.columns + WHERE table_schema = 'public' + AND table_name = 'tool_executions' + AND column_name = 'task_step_id' + """ + ) + assert cur.fetchall() == [] + cur.execute( + """ + SELECT column_name + FROM information_schema.columns + WHERE table_schema = 'public' + AND table_name = 'task_steps' + AND column_name IN ( + 'parent_step_id', + 'source_approval_id', + 'source_execution_id' + ) + ORDER BY column_name + """ + ) + assert cur.fetchall() == [] + cur.execute("SELECT to_regclass('public.tasks')") + assert cur.fetchone()[0] == "tasks" + + command.downgrade(config, "20260313_0017") + + with psycopg.connect(database_urls["admin"]) as conn: + with conn.cursor() as cur: + cur.execute("SELECT to_regclass('public.task_steps')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.tasks')") + assert cur.fetchone()[0] == "tasks" + + command.downgrade(config, "20260313_0014") + + with psycopg.connect(database_urls["admin"]) as conn: + with conn.cursor() as cur: + cur.execute("SELECT to_regclass('public.execution_budgets')") + assert cur.fetchone()[0] == "execution_budgets" + cur.execute("SELECT to_regclass('public.tasks')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.task_steps')") + assert cur.fetchone()[0] is None + cur.execute( + """ + SELECT column_name + FROM information_schema.columns + WHERE table_schema = 'public' + AND table_name = 'execution_budgets' + AND column_name IN ( + 'status', + 'deactivated_at', + 'superseded_by_budget_id', + 'supersedes_budget_id' + ) + ORDER BY column_name + """ + ) + assert cur.fetchall() == [] + cur.execute( + "SELECT has_table_privilege('alicebot_app', 'execution_budgets', 'UPDATE')" + ) + assert cur.fetchone()[0] is False + + command.downgrade(config, "20260313_0013") + + with psycopg.connect(database_urls["admin"]) as conn: + with conn.cursor() as cur: + cur.execute("SELECT to_regclass('public.execution_budgets')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.tool_executions')") + assert cur.fetchone()[0] == "tool_executions" + + command.downgrade(config, "20260312_0012") + + with psycopg.connect(database_urls["admin"]) as conn: + with conn.cursor() as cur: + cur.execute("SELECT to_regclass('public.tool_executions')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.approvals')") + assert cur.fetchone()[0] == "approvals" + + command.downgrade(config, "20260312_0011") + + with psycopg.connect(database_urls["admin"]) as conn: + with conn.cursor() as cur: + cur.execute("SELECT to_regclass('public.approvals')") + assert cur.fetchone()[0] == "approvals" + cur.execute( + """ + SELECT + has_table_privilege('alicebot_app', 
'approvals', 'UPDATE'), + EXISTS ( + SELECT 1 + FROM information_schema.columns + WHERE table_schema = 'public' + AND table_name = 'approvals' + AND column_name = 'resolved_at' + ), + EXISTS ( + SELECT 1 + FROM information_schema.columns + WHERE table_schema = 'public' + AND table_name = 'approvals' + AND column_name = 'resolved_by_user_id' + ) + """ + ) + assert cur.fetchone() == ( + False, + False, + False, + ) + + command.downgrade(config, "20260312_0010") + + with psycopg.connect(database_urls["admin"]) as conn: + with conn.cursor() as cur: + cur.execute("SELECT to_regclass('public.approvals')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.tools')") + assert cur.fetchone()[0] == "tools" + cur.execute("SELECT to_regclass('public.consents')") + assert cur.fetchone()[0] == "consents" + cur.execute("SELECT to_regclass('public.policies')") + assert cur.fetchone()[0] == "policies" + + command.downgrade(config, "20260312_0009") + + with psycopg.connect(database_urls["admin"]) as conn: + with conn.cursor() as cur: + cur.execute("SELECT to_regclass('public.approvals')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.tools')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.consents')") + assert cur.fetchone()[0] == "consents" + cur.execute("SELECT to_regclass('public.policies')") + assert cur.fetchone()[0] == "policies" + + command.downgrade(config, "20260312_0008") + + with psycopg.connect(database_urls["admin"]) as conn: + with conn.cursor() as cur: + cur.execute("SELECT to_regclass('public.consents')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.policies')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.tools')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.approvals')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.embedding_configs')") + assert cur.fetchone()[0] == "embedding_configs" + cur.execute("SELECT to_regclass('public.memory_embeddings')") + assert cur.fetchone()[0] == "memory_embeddings" + + command.downgrade(config, "20260312_0007") + + with psycopg.connect(database_urls["admin"]) as conn: + with conn.cursor() as cur: + cur.execute("SELECT to_regclass('public.embedding_configs')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.memory_embeddings')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.consents')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.policies')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.tools')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.approvals')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.memories')") + assert cur.fetchone()[0] == "memories" + cur.execute("SELECT to_regclass('public.entity_edges')") + assert cur.fetchone()[0] == "entity_edges" + + command.downgrade(config, "20260311_0003") + + with psycopg.connect(database_urls["admin"]) as conn: + with conn.cursor() as cur: + cur.execute("SELECT to_regclass('public.memories')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.memory_revisions')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.memory_review_labels')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.entities')") + assert cur.fetchone()[0] is None + 
cur.execute("SELECT to_regclass('public.entity_edges')") + assert cur.fetchone()[0] is None + cur.execute( + """ + SELECT + has_table_privilege('alicebot_app', 'users', 'UPDATE'), + has_table_privilege('alicebot_app', 'threads', 'UPDATE'), + has_table_privilege('alicebot_app', 'sessions', 'UPDATE') + """ + ) + # Revision 20260310_0001 already leaves the runtime role without UPDATE + # access, so downgrading from head must preserve that same privilege floor. + assert cur.fetchone() == (False, False, False) + + command.downgrade(config, "20260310_0001") + + command.downgrade(config, "base") + + with psycopg.connect(database_urls["admin"]) as conn: + with conn.cursor() as cur: + cur.execute("SELECT to_regclass('public.users')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.threads')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.sessions')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.events')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.memories')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.memory_revisions')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.memory_review_labels')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.entities')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.entity_edges')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.embedding_configs')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.memory_embeddings')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.consents')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.policies')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.tools')") + assert cur.fetchone()[0] is None + cur.execute( + """ + SELECT extname + FROM pg_extension + WHERE extname IN ('pgcrypto', 'vector') + ORDER BY extname + """ + ) + assert [row[0] for row in cur.fetchall()] == ["pgcrypto", "vector"] diff --git a/tests/integration/test_policy_api.py b/tests/integration/test_policy_api.py new file mode 100644 index 0000000..0ae0b37 --- /dev/null +++ b/tests/integration/test_policy_api.py @@ -0,0 +1,424 @@ +from __future__ import annotations + +import json +from typing import Any +from urllib.parse import urlencode +from uuid import UUID, uuid4 + +import anyio + +import apps.api.src.alicebot_api.main as main_module +from apps.api.src.alicebot_api.config import Settings +from alicebot_api.db import user_connection +from alicebot_api.store import ContinuityStore + + +def invoke_request( + method: str, + path: str, + *, + query_params: dict[str, str] | None = None, + payload: dict[str, Any] | None = None, +) -> tuple[int, dict[str, Any]]: + messages: list[dict[str, object]] = [] + encoded_body = b"" if payload is None else json.dumps(payload).encode() + request_received = False + + async def receive() -> dict[str, object]: + nonlocal request_received + if request_received: + return {"type": "http.disconnect"} + + request_received = True + return {"type": "http.request", "body": encoded_body, "more_body": False} + + async def send(message: dict[str, object]) -> None: + messages.append(message) + + query_string = urlencode(query_params or {}).encode() + scope = { + "type": "http", + "asgi": {"version": "3.0"}, + "http_version": "1.1", + "method": method, + 
"scheme": "http", + "path": path, + "raw_path": path.encode(), + "query_string": query_string, + "headers": [(b"content-type", b"application/json")], + "client": ("testclient", 50000), + "server": ("testserver", 80), + "root_path": "", + } + + anyio.run(main_module.app, scope, receive, send) + + start_message = next(message for message in messages if message["type"] == "http.response.start") + body = b"".join( + message.get("body", b"") + for message in messages + if message["type"] == "http.response.body" + ) + return start_message["status"], json.loads(body) + + +def seed_user(database_url: str, *, email: str) -> dict[str, UUID]: + user_id = uuid4() + + with user_connection(database_url, user_id) as conn: + store = ContinuityStore(conn) + store.create_user(user_id, email, email.split("@", 1)[0].title()) + thread = store.create_thread("Policy thread") + + return { + "user_id": user_id, + "thread_id": thread["id"], + } + + +def test_consent_endpoints_upsert_and_list_deterministically(migrated_database_urls, monkeypatch) -> None: + seeded = seed_user(migrated_database_urls["app"], email="owner@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + + first_status, first_payload = invoke_request( + "POST", + "/v0/consents", + payload={ + "user_id": str(seeded["user_id"]), + "consent_key": "email_marketing", + "status": "granted", + "metadata": {"source": "settings"}, + }, + ) + second_status, second_payload = invoke_request( + "POST", + "/v0/consents", + payload={ + "user_id": str(seeded["user_id"]), + "consent_key": "analytics_tracking", + "status": "revoked", + "metadata": {"source": "banner"}, + }, + ) + third_status, third_payload = invoke_request( + "POST", + "/v0/consents", + payload={ + "user_id": str(seeded["user_id"]), + "consent_key": "email_marketing", + "status": "revoked", + "metadata": {"source": "preferences"}, + }, + ) + list_status, list_payload = invoke_request( + "GET", + "/v0/consents", + query_params={"user_id": str(seeded["user_id"])}, + ) + + assert first_status == 201 + assert second_status == 201 + assert third_status == 200 + assert first_payload["write_mode"] == "created" + assert second_payload["write_mode"] == "created" + assert third_payload["write_mode"] == "updated" + assert third_payload["consent"]["id"] == first_payload["consent"]["id"] + assert list_status == 200 + assert [item["consent_key"] for item in list_payload["items"]] == [ + "analytics_tracking", + "email_marketing", + ] + assert list_payload["summary"] == { + "total_count": 2, + "order": ["consent_key_asc", "created_at_asc", "id_asc"], + } + + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + stored_consents = ContinuityStore(conn).list_consents() + + assert [consent["consent_key"] for consent in stored_consents] == [ + "analytics_tracking", + "email_marketing", + ] + assert stored_consents[1]["status"] == "revoked" + assert stored_consents[1]["metadata"] == {"source": "preferences"} + + +def test_policy_endpoints_create_list_and_get_in_priority_order(migrated_database_urls, monkeypatch) -> None: + seeded = seed_user(migrated_database_urls["app"], email="owner@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + + low_priority_status, low_priority_payload = invoke_request( + "POST", + "/v0/policies", + payload={ + "user_id": str(seeded["user_id"]), + "name": "Require approval for export", + "action": "memory.export", + 
"scope": "profile", + "effect": "require_approval", + "priority": 20, + "active": True, + "conditions": {"channel": "email"}, + "required_consents": ["email_marketing", "email_marketing"], + }, + ) + high_priority_status, high_priority_payload = invoke_request( + "POST", + "/v0/policies", + payload={ + "user_id": str(seeded["user_id"]), + "name": "Allow profile read", + "action": "memory.read", + "scope": "profile", + "effect": "allow", + "priority": 10, + "active": True, + "conditions": {}, + "required_consents": [], + }, + ) + list_status, list_payload = invoke_request( + "GET", + "/v0/policies", + query_params={"user_id": str(seeded["user_id"])}, + ) + detail_status, detail_payload = invoke_request( + "GET", + f"/v0/policies/{low_priority_payload['policy']['id']}", + query_params={"user_id": str(seeded['user_id'])}, + ) + + assert low_priority_status == 201 + assert high_priority_status == 201 + assert low_priority_payload["policy"]["required_consents"] == ["email_marketing"] + assert list_status == 200 + assert [item["id"] for item in list_payload["items"]] == [ + high_priority_payload["policy"]["id"], + low_priority_payload["policy"]["id"], + ] + assert list_payload["summary"] == { + "total_count": 2, + "order": ["priority_asc", "created_at_asc", "id_asc"], + } + assert detail_status == 200 + assert detail_payload == {"policy": low_priority_payload["policy"]} + + +def test_policy_evaluation_allow_records_trace_events(migrated_database_urls, monkeypatch) -> None: + seeded = seed_user(migrated_database_urls["app"], email="owner@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + store = ContinuityStore(conn) + store.create_consent( + consent_key="email_marketing", + status="granted", + metadata={"source": "settings"}, + ) + created_policy = store.create_policy( + name="Allow export", + action="memory.export", + scope="profile", + effect="allow", + priority=10, + active=True, + conditions={"channel": "email"}, + required_consents=["email_marketing"], + ) + + status_code, payload = invoke_request( + "POST", + "/v0/policies/evaluate", + payload={ + "user_id": str(seeded["user_id"]), + "thread_id": str(seeded["thread_id"]), + "action": "memory.export", + "scope": "profile", + "attributes": {"channel": "email"}, + }, + ) + + assert status_code == 200 + assert payload["decision"] == "allow" + assert payload["matched_policy"]["id"] == str(created_policy["id"]) + assert payload["evaluation"] == { + "action": "memory.export", + "scope": "profile", + "evaluated_policy_count": 1, + "matched_policy_id": str(created_policy["id"]), + "order": ["priority_asc", "created_at_asc", "id_asc"], + } + assert [reason["code"] for reason in payload["reasons"]] == [ + "matched_policy", + "policy_effect_allow", + ] + assert payload["trace"]["trace_event_count"] == 3 + + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + store = ContinuityStore(conn) + trace = store.get_trace(UUID(payload["trace"]["trace_id"])) + trace_events = store.list_trace_events(UUID(payload["trace"]["trace_id"])) + + assert trace["kind"] == "policy.evaluate" + assert trace["compiler_version"] == "policy_evaluation_v0" + assert trace["limits"] == { + "order": ["priority_asc", "created_at_asc", "id_asc"], + "active_policy_count": 1, + "consent_count": 1, + } + assert [event["kind"] for event in trace_events] == [ + "policy.evaluate.request", + 
"policy.evaluate.order", + "policy.evaluate.decision", + ] + assert trace_events[2]["payload"]["decision"] == "allow" + assert trace_events[2]["payload"]["matched_policy_id"] == str(created_policy["id"]) + + +def test_policy_evaluation_denies_when_required_consent_is_missing(migrated_database_urls, monkeypatch) -> None: + seeded = seed_user(migrated_database_urls["app"], email="owner@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + ContinuityStore(conn).create_policy( + name="Allow export with consent", + action="memory.export", + scope="profile", + effect="allow", + priority=10, + active=True, + conditions={}, + required_consents=["email_marketing"], + ) + + status_code, payload = invoke_request( + "POST", + "/v0/policies/evaluate", + payload={ + "user_id": str(seeded["user_id"]), + "thread_id": str(seeded["thread_id"]), + "action": "memory.export", + "scope": "profile", + "attributes": {}, + }, + ) + + assert status_code == 200 + assert payload["decision"] == "deny" + assert [reason["code"] for reason in payload["reasons"]] == [ + "matched_policy", + "consent_missing", + ] + + +def test_policy_evaluation_returns_require_approval(migrated_database_urls, monkeypatch) -> None: + seeded = seed_user(migrated_database_urls["app"], email="owner@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + created_policy = ContinuityStore(conn).create_policy( + name="Escalate export", + action="memory.export", + scope="profile", + effect="require_approval", + priority=10, + active=True, + conditions={}, + required_consents=[], + ) + + status_code, payload = invoke_request( + "POST", + "/v0/policies/evaluate", + payload={ + "user_id": str(seeded["user_id"]), + "thread_id": str(seeded["thread_id"]), + "action": "memory.export", + "scope": "profile", + "attributes": {}, + }, + ) + + assert status_code == 200 + assert payload["decision"] == "require_approval" + assert payload["matched_policy"]["id"] == str(created_policy["id"]) + assert payload["reasons"][-1]["code"] == "policy_effect_require_approval" + + +def test_policy_and_consent_endpoints_enforce_per_user_isolation(migrated_database_urls, monkeypatch) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner@example.com") + intruder = seed_user(migrated_database_urls["app"], email="intruder@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + store.create_consent(consent_key="email_marketing", status="granted", metadata={}) + owner_policy = store.create_policy( + name="Allow export", + action="memory.export", + scope="profile", + effect="allow", + priority=10, + active=True, + conditions={}, + required_consents=["email_marketing"], + ) + + consent_status, consent_payload = invoke_request( + "GET", + "/v0/consents", + query_params={"user_id": str(intruder["user_id"])}, + ) + policy_status, policy_payload = invoke_request( + "GET", + "/v0/policies", + query_params={"user_id": str(intruder["user_id"])}, + ) + detail_status, detail_payload = invoke_request( + "GET", + f"/v0/policies/{owner_policy['id']}", + query_params={"user_id": 
str(intruder["user_id"])}, + ) + evaluation_status, evaluation_payload = invoke_request( + "POST", + "/v0/policies/evaluate", + payload={ + "user_id": str(intruder["user_id"]), + "thread_id": str(intruder["thread_id"]), + "action": "memory.export", + "scope": "profile", + "attributes": {}, + }, + ) + + assert consent_status == 200 + assert consent_payload == { + "items": [], + "summary": { + "total_count": 0, + "order": ["consent_key_asc", "created_at_asc", "id_asc"], + }, + } + assert policy_status == 200 + assert policy_payload == { + "items": [], + "summary": { + "total_count": 0, + "order": ["priority_asc", "created_at_asc", "id_asc"], + }, + } + assert detail_status == 404 + assert detail_payload == {"detail": f"policy {owner_policy['id']} was not found"} + assert evaluation_status == 200 + assert evaluation_payload["decision"] == "deny" + assert evaluation_payload["matched_policy"] is None + assert evaluation_payload["reasons"] == [ + { + "code": "no_matching_policy", + "source": "system", + "message": "No active policy matched the requested action, scope, and attributes.", + "policy_id": None, + "consent_key": None, + } + ] diff --git a/tests/integration/test_proxy_execution_api.py b/tests/integration/test_proxy_execution_api.py new file mode 100644 index 0000000..755f5f3 --- /dev/null +++ b/tests/integration/test_proxy_execution_api.py @@ -0,0 +1,1478 @@ +from __future__ import annotations + +import json +from typing import Any +from urllib.parse import urlencode +from uuid import UUID, uuid4 + +import anyio +import psycopg + +import apps.api.src.alicebot_api.main as main_module +from apps.api.src.alicebot_api.config import Settings +from alicebot_api.db import user_connection +from alicebot_api.store import ContinuityStore + + +def invoke_request( + method: str, + path: str, + *, + query_params: dict[str, str] | None = None, + payload: dict[str, Any] | None = None, +) -> tuple[int, dict[str, Any]]: + messages: list[dict[str, object]] = [] + encoded_body = b"" if payload is None else json.dumps(payload).encode() + request_received = False + + async def receive() -> dict[str, object]: + nonlocal request_received + if request_received: + return {"type": "http.disconnect"} + + request_received = True + return {"type": "http.request", "body": encoded_body, "more_body": False} + + async def send(message: dict[str, object]) -> None: + messages.append(message) + + query_string = urlencode(query_params or {}).encode() + scope = { + "type": "http", + "asgi": {"version": "3.0"}, + "http_version": "1.1", + "method": method, + "scheme": "http", + "path": path, + "raw_path": path.encode(), + "query_string": query_string, + "headers": [(b"content-type", b"application/json")], + "client": ("testclient", 50000), + "server": ("testserver", 80), + "root_path": "", + } + + anyio.run(main_module.app, scope, receive, send) + + start_message = next(message for message in messages if message["type"] == "http.response.start") + body = b"".join( + message.get("body", b"") + for message in messages + if message["type"] == "http.response.body" + ) + return start_message["status"], json.loads(body) + + +def seed_user(database_url: str, *, email: str) -> dict[str, UUID]: + user_id = uuid4() + + with user_connection(database_url, user_id) as conn: + store = ContinuityStore(conn) + store.create_user(user_id, email, email.split("@", 1)[0].title()) + thread = store.create_thread("Proxy execution thread") + + return { + "user_id": user_id, + "thread_id": thread["id"], + } + + +def create_tool_and_policy( + 
database_url: str, + *, + user_id: UUID, + tool_key: str, +) -> UUID: + with user_connection(database_url, user_id) as conn: + store = ContinuityStore(conn) + tool = store.create_tool( + tool_key=tool_key, + name="Proxy Tool", + description="Deterministic proxy tool.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["proxy"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=[], + risk_hints=[], + metadata={"transport": "proxy"}, + ) + store.create_policy( + name=f"Require approval for {tool_key}", + action="tool.run", + scope="workspace", + effect="require_approval", + priority=10, + active=True, + conditions={"tool_key": tool_key}, + required_consents=[], + ) + return tool["id"] + + +def create_pending_approval( + *, + user_id: UUID, + thread_id: UUID, + tool_id: UUID, +) -> tuple[int, dict[str, Any]]: + return invoke_request( + "POST", + "/v0/approvals/requests", + payload={ + "user_id": str(user_id), + "thread_id": str(thread_id), + "tool_id": str(tool_id), + "action": "tool.run", + "scope": "workspace", + "attributes": {"message": "hello", "count": 2}, + }, + ) + + +def create_execution_budget( + database_url: str, + *, + user_id: UUID, + tool_key: str | None, + domain_hint: str | None, + max_completed_executions: int, + rolling_window_seconds: int | None = None, +) -> UUID: + with user_connection(database_url, user_id) as conn: + store = ContinuityStore(conn) + budget = store.create_execution_budget( + tool_key=tool_key, + domain_hint=domain_hint, + max_completed_executions=max_completed_executions, + rolling_window_seconds=rolling_window_seconds, + supersedes_budget_id=None, + ) + return budget["id"] + + +def set_execution_executed_at( + admin_database_url: str, + *, + execution_id: UUID, + executed_at_sql: str, +) -> None: + with psycopg.connect(admin_database_url) as conn: + conn.execute( + f"UPDATE tool_executions SET executed_at = {executed_at_sql} WHERE id = %s", + (execution_id,), + ) + conn.commit() + + +def test_execute_approved_proxy_endpoint_executes_only_approved_requests( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + tool_id = create_tool_and_policy( + migrated_database_urls["app"], + user_id=owner["user_id"], + tool_key="proxy.echo", + ) + + create_status, create_payload = create_pending_approval( + user_id=owner["user_id"], + thread_id=owner["thread_id"], + tool_id=tool_id, + ) + assert create_status == 200 + + approve_status, approve_payload = invoke_request( + "POST", + f"/v0/approvals/{create_payload['approval']['id']}/approve", + payload={"user_id": str(owner["user_id"])}, + ) + assert approve_status == 200 + assert approve_payload["approval"]["status"] == "approved" + + execute_status, execute_payload = invoke_request( + "POST", + f"/v0/approvals/{create_payload['approval']['id']}/execute", + payload={"user_id": str(owner["user_id"])}, + ) + + assert execute_status == 200 + assert list(execute_payload) == ["request", "approval", "tool", "result", "events", "trace"] + assert execute_payload["request"] == { + "approval_id": create_payload["approval"]["id"], + "task_step_id": create_payload["approval"]["task_step_id"], + } + assert execute_payload["approval"]["id"] == create_payload["approval"]["id"] + assert execute_payload["approval"]["status"] == "approved" + assert execute_payload["tool"]["id"] == str(tool_id) + 
assert execute_payload["tool"]["tool_key"] == "proxy.echo" + assert execute_payload["result"] == { + "handler_key": "proxy.echo", + "status": "completed", + "output": { + "mode": "no_side_effect", + "tool_key": "proxy.echo", + "action": "tool.run", + "scope": "workspace", + "domain_hint": None, + "risk_hint": None, + "attributes": {"message": "hello", "count": 2}, + }, + } + assert execute_payload["events"]["request_sequence_no"] == 1 + assert execute_payload["events"]["result_sequence_no"] == 2 + assert execute_payload["trace"]["trace_event_count"] == 9 + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + thread_events = store.list_thread_events(owner["thread_id"]) + tasks = store.list_tasks() + task_steps = store.list_task_steps_for_task(tasks[0]["id"]) + tool_executions = store.list_tool_executions() + execute_trace = store.get_trace(UUID(execute_payload["trace"]["trace_id"])) + execute_trace_events = store.list_trace_events(UUID(execute_payload["trace"]["trace_id"])) + + assert [event["kind"] for event in thread_events] == [ + "tool.proxy.execution.request", + "tool.proxy.execution.result", + ] + assert len(tool_executions) == 1 + assert len(tasks) == 1 + assert len(task_steps) == 1 + assert tasks[0]["status"] == "executed" + assert tasks[0]["latest_execution_id"] == tool_executions[0]["id"] + assert task_steps[0]["status"] == "executed" + assert tool_executions[0]["approval_id"] == UUID(create_payload["approval"]["id"]) + assert tool_executions[0]["task_step_id"] == task_steps[0]["id"] + assert tool_executions[0]["thread_id"] == owner["thread_id"] + assert tool_executions[0]["tool_id"] == tool_id + assert tool_executions[0]["trace_id"] == UUID(execute_payload["trace"]["trace_id"]) + assert tool_executions[0]["handler_key"] == "proxy.echo" + assert tool_executions[0]["status"] == "completed" + assert tool_executions[0]["request"] == thread_events[0]["payload"]["request"] + assert tool_executions[0]["tool"]["tool_key"] == "proxy.echo" + assert tool_executions[0]["result"] == { + "handler_key": "proxy.echo", + "status": "completed", + "output": execute_payload["result"]["output"], + "reason": None, + } + assert thread_events[0]["payload"] == { + "approval_id": create_payload["approval"]["id"], + "task_step_id": create_payload["approval"]["task_step_id"], + "tool_id": str(tool_id), + "tool_key": "proxy.echo", + "request": { + "thread_id": str(owner["thread_id"]), + "tool_id": str(tool_id), + "action": "tool.run", + "scope": "workspace", + "domain_hint": None, + "risk_hint": None, + "attributes": {"message": "hello", "count": 2}, + }, + } + assert execute_trace["kind"] == "tool.proxy.execute" + assert execute_trace["compiler_version"] == "proxy_execution_v0" + assert execute_trace["limits"] == { + "approval_status": "approved", + "enabled_handler_keys": ["proxy.echo"], + "budget_match_order": ["specificity_desc", "created_at_asc", "id_asc"], + } + assert [event["kind"] for event in execute_trace_events] == [ + "tool.proxy.execute.request", + "tool.proxy.execute.approval", + "tool.proxy.execute.budget", + "tool.proxy.execute.dispatch", + "tool.proxy.execute.summary", + "task.lifecycle.state", + "task.lifecycle.summary", + "task.step.lifecycle.state", + "task.step.lifecycle.summary", + ] + assert execute_trace_events[0]["payload"] == { + "approval_id": create_payload["approval"]["id"], + "task_step_id": create_payload["approval"]["task_step_id"], + } + assert execute_trace_events[1]["payload"]["task_step_id"] == 
create_payload["approval"]["task_step_id"] + assert execute_trace_events[2]["payload"]["decision"] == "allow" + assert execute_trace_events[3]["payload"]["dispatch_status"] == "executed" + assert execute_trace_events[3]["payload"]["task_step_id"] == create_payload["approval"]["task_step_id"] + assert execute_trace_events[4]["payload"]["request_event_id"] == execute_payload["events"]["request_event_id"] + assert execute_trace_events[4]["payload"]["task_step_id"] == create_payload["approval"]["task_step_id"] + assert execute_trace_events[7]["payload"] == { + "task_id": create_payload["task"]["id"], + "task_step_id": str(task_steps[0]["id"]), + "source": "proxy_execution", + "sequence_no": 1, + "kind": "governed_request", + "previous_status": "approved", + "current_status": "executed", + "trace": { + "trace_id": execute_payload["trace"]["trace_id"], + "trace_kind": "tool.proxy.execute", + }, + } + + +def test_execute_approved_proxy_endpoint_rejects_pending_approval( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + tool_id = create_tool_and_policy( + migrated_database_urls["app"], + user_id=owner["user_id"], + tool_key="proxy.echo", + ) + + create_status, create_payload = create_pending_approval( + user_id=owner["user_id"], + thread_id=owner["thread_id"], + tool_id=tool_id, + ) + assert create_status == 200 + + execute_status, execute_payload = invoke_request( + "POST", + f"/v0/approvals/{create_payload['approval']['id']}/execute", + payload={"user_id": str(owner["user_id"])}, + ) + + assert execute_status == 409 + assert execute_payload == { + "detail": f"approval {create_payload['approval']['id']} is pending and cannot be executed" + } + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + trace_rows = store.conn.execute( + "SELECT id, kind, limits FROM traces WHERE kind = %s ORDER BY created_at ASC, id ASC", + ("tool.proxy.execute",), + ).fetchall() + trace_events = store.list_trace_events(trace_rows[-1]["id"]) + thread_events = store.list_thread_events(owner["thread_id"]) + + assert thread_events == [] + assert trace_rows[-1]["limits"] == { + "approval_status": "pending", + "enabled_handler_keys": ["proxy.echo"], + "budget_match_order": ["specificity_desc", "created_at_asc", "id_asc"], + } + assert trace_events[2]["payload"]["dispatch_status"] == "blocked" + assert trace_events[3]["payload"]["execution_status"] == "blocked" + + +def test_execute_approved_proxy_endpoint_rejects_rejected_approval( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + tool_id = create_tool_and_policy( + migrated_database_urls["app"], + user_id=owner["user_id"], + tool_key="proxy.echo", + ) + + create_status, create_payload = create_pending_approval( + user_id=owner["user_id"], + thread_id=owner["thread_id"], + tool_id=tool_id, + ) + assert create_status == 200 + + reject_status, reject_payload = invoke_request( + "POST", + f"/v0/approvals/{create_payload['approval']['id']}/reject", + payload={"user_id": str(owner["user_id"])}, + ) + assert reject_status == 200 + assert reject_payload["approval"]["status"] == "rejected" + + execute_status, execute_payload = 
invoke_request( + "POST", + f"/v0/approvals/{create_payload['approval']['id']}/execute", + payload={"user_id": str(owner["user_id"])}, + ) + + assert execute_status == 409 + assert execute_payload == { + "detail": f"approval {create_payload['approval']['id']} is rejected and cannot be executed" + } + + +def test_execute_approved_proxy_endpoint_rejects_missing_handler( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + tool_id = create_tool_and_policy( + migrated_database_urls["app"], + user_id=owner["user_id"], + tool_key="proxy.missing", + ) + + create_status, create_payload = create_pending_approval( + user_id=owner["user_id"], + thread_id=owner["thread_id"], + tool_id=tool_id, + ) + assert create_status == 200 + + approve_status, approve_payload = invoke_request( + "POST", + f"/v0/approvals/{create_payload['approval']['id']}/approve", + payload={"user_id": str(owner["user_id"])}, + ) + assert approve_status == 200 + assert approve_payload["approval"]["status"] == "approved" + + execute_status, execute_payload = invoke_request( + "POST", + f"/v0/approvals/{create_payload['approval']['id']}/execute", + payload={"user_id": str(owner["user_id"])}, + ) + + assert execute_status == 409 + assert execute_payload == { + "detail": "tool 'proxy.missing' has no registered proxy handler" + } + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + trace_rows = store.conn.execute( + "SELECT id FROM traces WHERE kind = %s ORDER BY created_at ASC, id ASC", + ("tool.proxy.execute",), + ).fetchall() + trace_events = store.list_trace_events(trace_rows[-1]["id"]) + tool_executions = store.list_tool_executions() + thread_events = store.list_thread_events(owner["thread_id"]) + + assert thread_events == [] + assert len(tool_executions) == 1 + assert tool_executions[0]["approval_id"] == UUID(create_payload["approval"]["id"]) + assert tool_executions[0]["task_step_id"] == UUID(create_payload["approval"]["task_step_id"]) + assert tool_executions[0]["handler_key"] is None + assert tool_executions[0]["status"] == "blocked" + assert tool_executions[0]["request_event_id"] is None + assert tool_executions[0]["result_event_id"] is None + assert tool_executions[0]["result"] == { + "handler_key": None, + "status": "blocked", + "output": None, + "reason": "tool 'proxy.missing' has no registered proxy handler", + } + assert trace_events[2]["payload"]["decision"] == "allow" + assert trace_events[3]["payload"] == { + "approval_id": create_payload["approval"]["id"], + "task_step_id": create_payload["approval"]["task_step_id"], + "tool_id": str(tool_id), + "tool_key": "proxy.missing", + "handler_key": None, + "dispatch_status": "blocked", + "reason": "tool 'proxy.missing' has no registered proxy handler", + "result_status": "blocked", + "output": None, + } + + list_status, list_payload = invoke_request( + "GET", + "/v0/tool-executions", + query_params={"user_id": str(owner["user_id"])}, + ) + detail_status, detail_payload = invoke_request( + "GET", + f"/v0/tool-executions/{tool_executions[0]['id']}", + query_params={"user_id": str(owner['user_id'])}, + ) + + assert list_status == 200 + assert list_payload["items"][0]["task_step_id"] == create_payload["approval"]["task_step_id"] + assert list_payload["items"][0]["status"] == "blocked" + assert 
list_payload["items"][0]["request_event_id"] is None + assert list_payload["items"][0]["result_event_id"] is None + assert list_payload["items"][0]["result"]["reason"] == "tool 'proxy.missing' has no registered proxy handler" + assert detail_status == 200 + assert detail_payload == {"execution": list_payload["items"][0]} + + +def test_execute_approved_proxy_endpoint_enforces_user_isolation( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner@example.com") + other_user = seed_user(migrated_database_urls["app"], email="other@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + tool_id = create_tool_and_policy( + migrated_database_urls["app"], + user_id=owner["user_id"], + tool_key="proxy.echo", + ) + + create_status, create_payload = create_pending_approval( + user_id=owner["user_id"], + thread_id=owner["thread_id"], + tool_id=tool_id, + ) + assert create_status == 200 + + approve_status, _ = invoke_request( + "POST", + f"/v0/approvals/{create_payload['approval']['id']}/approve", + payload={"user_id": str(owner["user_id"])}, + ) + assert approve_status == 200 + + execute_status, execute_payload = invoke_request( + "POST", + f"/v0/approvals/{create_payload['approval']['id']}/execute", + payload={"user_id": str(other_user["user_id"])}, + ) + + assert execute_status == 404 + assert execute_payload == { + "detail": f"approval {create_payload['approval']['id']} was not found" + } + + +def test_execute_approved_proxy_endpoint_updates_the_explicitly_linked_later_step( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner-step-linkage@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + tool_id = create_tool_and_policy( + migrated_database_urls["app"], + user_id=owner["user_id"], + tool_key="proxy.echo", + ) + + create_status, create_payload = create_pending_approval( + user_id=owner["user_id"], + thread_id=owner["thread_id"], + tool_id=tool_id, + ) + assert create_status == 200 + + approve_status, _ = invoke_request( + "POST", + f"/v0/approvals/{create_payload['approval']['id']}/approve", + payload={"user_id": str(owner["user_id"])}, + ) + assert approve_status == 200 + + execute_status, execute_payload = invoke_request( + "POST", + f"/v0/approvals/{create_payload['approval']['id']}/execute", + payload={"user_id": str(owner["user_id"])}, + ) + assert execute_status == 200 + + detail_status, detail_payload = invoke_request( + "GET", + f"/v0/tasks/{create_payload['task']['id']}", + query_params={"user_id": str(owner["user_id"])}, + ) + step_list_status, step_list_payload = invoke_request( + "GET", + f"/v0/tasks/{create_payload['task']['id']}/steps", + query_params={"user_id": str(owner["user_id"])}, + ) + assert detail_status == 200 + assert step_list_status == 200 + initial_execution_id = detail_payload["task"]["latest_execution_id"] + assert initial_execution_id is not None + + create_step_status, create_step_payload = invoke_request( + "POST", + f"/v0/tasks/{create_payload['task']['id']}/steps", + payload={ + "user_id": str(owner["user_id"]), + "kind": "governed_request", + "status": "created", + "request": { + "thread_id": str(owner["thread_id"]), + "tool_id": str(tool_id), + "action": "tool.run", + "scope": "workspace", + "attributes": {"message": "step-2"}, + }, + "outcome": { + "routing_decision": "approval_required", 
+ "approval_status": None, + "approval_id": None, + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + }, + "lineage": { + "parent_step_id": step_list_payload["items"][0]["id"], + "source_approval_id": create_payload["approval"]["id"], + "source_execution_id": initial_execution_id, + }, + }, + ) + assert create_step_status == 201 + + transition_status, transition_payload = invoke_request( + "POST", + f"/v0/task-steps/{create_step_payload['task_step']['id']}/transition", + payload={ + "user_id": str(owner["user_id"]), + "status": "approved", + "outcome": { + "routing_decision": "approval_required", + "approval_status": "approved", + "approval_id": create_payload["approval"]["id"], + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + }, + }, + ) + assert transition_status == 200 + assert transition_payload["task_step"]["status"] == "approved" + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + relinked = store.update_approval_task_step_optional( + approval_id=UUID(create_payload["approval"]["id"]), + task_step_id=UUID(create_step_payload["task_step"]["id"]), + ) + assert relinked is not None + + second_execute_status, second_execute_payload = invoke_request( + "POST", + f"/v0/approvals/{create_payload['approval']['id']}/execute", + payload={"user_id": str(owner["user_id"])}, + ) + assert second_execute_status == 200 + assert second_execute_payload["request"] == { + "approval_id": create_payload["approval"]["id"], + "task_step_id": create_step_payload["task_step"]["id"], + } + assert second_execute_payload["approval"]["task_step_id"] == create_step_payload["task_step"]["id"] + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + task = store.get_task_optional(UUID(create_payload["task"]["id"])) + task_steps = store.list_task_steps_for_task(UUID(create_payload["task"]["id"])) + tool_executions = store.list_tool_executions() + proxy_traces = store.conn.execute( + """ + SELECT id + FROM traces + WHERE thread_id = %s + AND kind = 'tool.proxy.execute' + ORDER BY created_at ASC, id ASC + """, + (owner["thread_id"],), + ).fetchall() + + assert task is not None + assert task["status"] == "executed" + assert task["latest_approval_id"] == UUID(create_payload["approval"]["id"]) + assert len(task_steps) == 2 + assert task_steps[0]["status"] == "executed" + assert task_steps[0]["trace_id"] == UUID(execute_payload["trace"]["trace_id"]) + assert task_steps[0]["outcome"]["execution_id"] == initial_execution_id + assert task_steps[1]["status"] == "executed" + assert task_steps[1]["id"] == UUID(create_step_payload["task_step"]["id"]) + assert task_steps[1]["trace_id"] == UUID(second_execute_payload["trace"]["trace_id"]) + assert task_steps[1]["outcome"]["approval_id"] == create_payload["approval"]["id"] + assert task_steps[1]["outcome"]["execution_status"] == "completed" + assert len(tool_executions) == 2 + assert task["latest_execution_id"] == tool_executions[1]["id"] + assert tool_executions[1]["task_step_id"] == UUID(create_step_payload["task_step"]["id"]) + assert task_steps[1]["outcome"]["execution_id"] == str(tool_executions[1]["id"]) + assert len(proxy_traces) == 2 + + +def test_execute_approved_proxy_endpoint_blocks_when_execution_budget_is_exceeded( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner@example.com") + monkeypatch.setattr(main_module, 
"get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + tool_id = create_tool_and_policy( + migrated_database_urls["app"], + user_id=owner["user_id"], + tool_key="proxy.echo", + ) + budget_id = create_execution_budget( + migrated_database_urls["app"], + user_id=owner["user_id"], + tool_key="proxy.echo", + domain_hint=None, + max_completed_executions=1, + ) + + first_create_status, first_create_payload = create_pending_approval( + user_id=owner["user_id"], + thread_id=owner["thread_id"], + tool_id=tool_id, + ) + second_create_status, second_create_payload = create_pending_approval( + user_id=owner["user_id"], + thread_id=owner["thread_id"], + tool_id=tool_id, + ) + assert first_create_status == 200 + assert second_create_status == 200 + + first_approve_status, _ = invoke_request( + "POST", + f"/v0/approvals/{first_create_payload['approval']['id']}/approve", + payload={"user_id": str(owner["user_id"])}, + ) + second_approve_status, _ = invoke_request( + "POST", + f"/v0/approvals/{second_create_payload['approval']['id']}/approve", + payload={"user_id": str(owner["user_id"])}, + ) + assert first_approve_status == 200 + assert second_approve_status == 200 + + first_execute_status, first_execute_payload = invoke_request( + "POST", + f"/v0/approvals/{first_create_payload['approval']['id']}/execute", + payload={"user_id": str(owner["user_id"])}, + ) + second_execute_status, second_execute_payload = invoke_request( + "POST", + f"/v0/approvals/{second_create_payload['approval']['id']}/execute", + payload={"user_id": str(owner["user_id"])}, + ) + + assert first_execute_status == 200 + assert second_execute_status == 200 + assert second_execute_payload["events"] is None + assert second_execute_payload["result"] == { + "handler_key": None, + "status": "blocked", + "output": None, + "reason": ( + f"execution budget {budget_id} blocks execution: projected completed executions " + "2 would exceed limit 1" + ), + "budget_decision": { + "matched_budget_id": str(budget_id), + "tool_key": "proxy.echo", + "domain_hint": None, + "budget_tool_key": "proxy.echo", + "budget_domain_hint": None, + "max_completed_executions": 1, + "rolling_window_seconds": None, + "count_scope": "lifetime", + "window_started_at": None, + "completed_execution_count": 1, + "projected_completed_execution_count": 2, + "decision": "block", + "reason": "budget_exceeded", + "order": ["specificity_desc", "created_at_asc", "id_asc"], + "history_order": ["executed_at_asc", "id_asc"], + }, + } + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + stored_executions = store.list_tool_executions() + blocked_trace = store.get_trace(UUID(second_execute_payload["trace"]["trace_id"])) + blocked_trace_events = store.list_trace_events(UUID(second_execute_payload["trace"]["trace_id"])) + thread_events = store.list_thread_events(owner["thread_id"]) + + list_status, list_payload = invoke_request( + "GET", + "/v0/tool-executions", + query_params={"user_id": str(owner["user_id"])}, + ) + detail_status, detail_payload = invoke_request( + "GET", + f"/v0/tool-executions/{stored_executions[1]['id']}", + query_params={"user_id": str(owner['user_id'])}, + ) + + assert len(stored_executions) == 2 + assert [row["status"] for row in stored_executions] == ["completed", "blocked"] + assert stored_executions[1]["task_step_id"] == UUID(second_execute_payload["request"]["task_step_id"]) + assert stored_executions[1]["result"] == second_execute_payload["result"] + assert 
stored_executions[1]["request_event_id"] is None + assert stored_executions[1]["result_event_id"] is None + assert [event["kind"] for event in thread_events] == [ + "tool.proxy.execution.request", + "tool.proxy.execution.result", + ] + assert blocked_trace["limits"] == { + "approval_status": "approved", + "enabled_handler_keys": ["proxy.echo"], + "budget_match_order": ["specificity_desc", "created_at_asc", "id_asc"], + } + assert [event["kind"] for event in blocked_trace_events] == [ + "tool.proxy.execute.request", + "tool.proxy.execute.approval", + "tool.proxy.execute.budget", + "tool.proxy.execute.dispatch", + "tool.proxy.execute.summary", + "task.lifecycle.state", + "task.lifecycle.summary", + "task.step.lifecycle.state", + "task.step.lifecycle.summary", + ] + assert blocked_trace_events[0]["payload"] == second_execute_payload["request"] + assert blocked_trace_events[1]["payload"]["task_step_id"] == second_execute_payload["request"]["task_step_id"] + assert blocked_trace_events[2]["payload"] == second_execute_payload["result"]["budget_decision"] + assert blocked_trace_events[3]["payload"]["dispatch_status"] == "blocked" + assert blocked_trace_events[3]["payload"]["task_step_id"] == second_execute_payload["request"]["task_step_id"] + assert list_status == 200 + assert list_payload["items"][1]["task_step_id"] == second_execute_payload["request"]["task_step_id"] + assert [item["status"] for item in list_payload["items"]] == ["completed", "blocked"] + assert list_payload["items"][1]["result"] == second_execute_payload["result"] + assert detail_status == 200 + assert detail_payload == {"execution": list_payload["items"][1]} + + +def test_execute_approved_proxy_endpoint_allows_when_recent_history_is_within_rolling_window_limit( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + tool_id = create_tool_and_policy( + migrated_database_urls["app"], + user_id=owner["user_id"], + tool_key="proxy.echo", + ) + create_execution_budget( + migrated_database_urls["app"], + user_id=owner["user_id"], + tool_key="proxy.echo", + domain_hint=None, + max_completed_executions=2, + rolling_window_seconds=3600, + ) + + first_create_status, first_create_payload = create_pending_approval( + user_id=owner["user_id"], + thread_id=owner["thread_id"], + tool_id=tool_id, + ) + second_create_status, second_create_payload = create_pending_approval( + user_id=owner["user_id"], + thread_id=owner["thread_id"], + tool_id=tool_id, + ) + assert first_create_status == 200 + assert second_create_status == 200 + + first_approve_status, _ = invoke_request( + "POST", + f"/v0/approvals/{first_create_payload['approval']['id']}/approve", + payload={"user_id": str(owner["user_id"])}, + ) + second_approve_status, _ = invoke_request( + "POST", + f"/v0/approvals/{second_create_payload['approval']['id']}/approve", + payload={"user_id": str(owner["user_id"])}, + ) + assert first_approve_status == 200 + assert second_approve_status == 200 + + first_execute_status, _ = invoke_request( + "POST", + f"/v0/approvals/{first_create_payload['approval']['id']}/execute", + payload={"user_id": str(owner["user_id"])}, + ) + second_execute_status, second_execute_payload = invoke_request( + "POST", + f"/v0/approvals/{second_create_payload['approval']['id']}/execute", + payload={"user_id": str(owner["user_id"])}, + ) + + assert first_execute_status == 200 + assert 
second_execute_status == 200 + assert second_execute_payload["result"]["status"] == "completed" + assert second_execute_payload["events"] is not None + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + execute_trace_events = store.list_trace_events(UUID(second_execute_payload["trace"]["trace_id"])) + + assert execute_trace_events[2]["payload"]["matched_budget_id"] is not None + assert execute_trace_events[2]["payload"]["rolling_window_seconds"] == 3600 + assert execute_trace_events[2]["payload"]["count_scope"] == "rolling_window" + assert execute_trace_events[2]["payload"]["window_started_at"] is not None + assert execute_trace_events[2]["payload"]["completed_execution_count"] == 1 + assert execute_trace_events[2]["payload"]["projected_completed_execution_count"] == 2 + assert execute_trace_events[2]["payload"]["decision"] == "allow" + assert execute_trace_events[2]["payload"]["reason"] == "within_budget" + assert execute_trace_events[2]["payload"]["history_order"] == ["executed_at_asc", "id_asc"] + + +def test_execute_approved_proxy_endpoint_blocks_when_recent_window_history_exceeds_limit( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + tool_id = create_tool_and_policy( + migrated_database_urls["app"], + user_id=owner["user_id"], + tool_key="proxy.echo", + ) + budget_id = create_execution_budget( + migrated_database_urls["app"], + user_id=owner["user_id"], + tool_key="proxy.echo", + domain_hint=None, + max_completed_executions=1, + rolling_window_seconds=3600, + ) + + first_create_status, first_create_payload = create_pending_approval( + user_id=owner["user_id"], + thread_id=owner["thread_id"], + tool_id=tool_id, + ) + second_create_status, second_create_payload = create_pending_approval( + user_id=owner["user_id"], + thread_id=owner["thread_id"], + tool_id=tool_id, + ) + assert first_create_status == 200 + assert second_create_status == 200 + + first_approve_status, _ = invoke_request( + "POST", + f"/v0/approvals/{first_create_payload['approval']['id']}/approve", + payload={"user_id": str(owner["user_id"])}, + ) + second_approve_status, _ = invoke_request( + "POST", + f"/v0/approvals/{second_create_payload['approval']['id']}/approve", + payload={"user_id": str(owner["user_id"])}, + ) + assert first_approve_status == 200 + assert second_approve_status == 200 + + first_execute_status, _ = invoke_request( + "POST", + f"/v0/approvals/{first_create_payload['approval']['id']}/execute", + payload={"user_id": str(owner["user_id"])}, + ) + second_execute_status, second_execute_payload = invoke_request( + "POST", + f"/v0/approvals/{second_create_payload['approval']['id']}/execute", + payload={"user_id": str(owner["user_id"])}, + ) + + assert first_execute_status == 200 + assert second_execute_status == 200 + assert second_execute_payload["events"] is None + assert list(second_execute_payload["result"]) == [ + "handler_key", + "status", + "output", + "reason", + "budget_decision", + ] + assert second_execute_payload["result"]["handler_key"] is None + assert second_execute_payload["result"]["status"] == "blocked" + assert second_execute_payload["result"]["output"] is None + assert second_execute_payload["result"]["reason"] == ( + f"execution budget {budget_id} blocks execution: projected completed executions " + "2 within rolling window 3600 
seconds would exceed limit 1" + ) + assert second_execute_payload["result"]["budget_decision"]["matched_budget_id"] == str(budget_id) + assert second_execute_payload["result"]["budget_decision"]["rolling_window_seconds"] == 3600 + assert second_execute_payload["result"]["budget_decision"]["count_scope"] == "rolling_window" + assert second_execute_payload["result"]["budget_decision"]["window_started_at"] is not None + assert second_execute_payload["result"]["budget_decision"]["completed_execution_count"] == 1 + assert second_execute_payload["result"]["budget_decision"]["projected_completed_execution_count"] == 2 + assert second_execute_payload["result"]["budget_decision"]["history_order"] == [ + "executed_at_asc", + "id_asc", + ] + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + stored_executions = store.list_tool_executions() + blocked_trace_events = store.list_trace_events(UUID(second_execute_payload["trace"]["trace_id"])) + + assert [row["status"] for row in stored_executions] == ["completed", "blocked"] + assert stored_executions[1]["task_step_id"] == UUID(second_execute_payload["request"]["task_step_id"]) + assert stored_executions[1]["result"] == second_execute_payload["result"] + assert blocked_trace_events[2]["payload"] == second_execute_payload["result"]["budget_decision"] + + +def test_execute_approved_proxy_endpoint_excludes_old_window_history_and_keeps_counts_user_scoped( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner@example.com") + other_user = seed_user(migrated_database_urls["app"], email="other@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + owner_tool_id = create_tool_and_policy( + migrated_database_urls["app"], + user_id=owner["user_id"], + tool_key="proxy.echo", + ) + other_tool_id = create_tool_and_policy( + migrated_database_urls["app"], + user_id=other_user["user_id"], + tool_key="proxy.echo", + ) + budget_id = create_execution_budget( + migrated_database_urls["app"], + user_id=owner["user_id"], + tool_key="proxy.echo", + domain_hint=None, + max_completed_executions=1, + rolling_window_seconds=60, + ) + + owner_first_status, owner_first_payload = create_pending_approval( + user_id=owner["user_id"], + thread_id=owner["thread_id"], + tool_id=owner_tool_id, + ) + owner_second_status, owner_second_payload = create_pending_approval( + user_id=owner["user_id"], + thread_id=owner["thread_id"], + tool_id=owner_tool_id, + ) + other_status, other_payload = create_pending_approval( + user_id=other_user["user_id"], + thread_id=other_user["thread_id"], + tool_id=other_tool_id, + ) + assert owner_first_status == 200 + assert owner_second_status == 200 + assert other_status == 200 + + for approval_payload, user_id in ( + (owner_first_payload, owner["user_id"]), + (owner_second_payload, owner["user_id"]), + (other_payload, other_user["user_id"]), + ): + approve_status, _ = invoke_request( + "POST", + f"/v0/approvals/{approval_payload['approval']['id']}/approve", + payload={"user_id": str(user_id)}, + ) + assert approve_status == 200 + + owner_first_execute_status, owner_first_execute_payload = invoke_request( + "POST", + f"/v0/approvals/{owner_first_payload['approval']['id']}/execute", + payload={"user_id": str(owner["user_id"])}, + ) + other_execute_status, _ = invoke_request( + "POST", + f"/v0/approvals/{other_payload['approval']['id']}/execute", + payload={"user_id": 
str(other_user["user_id"])}, + ) + assert owner_first_execute_status == 200 + assert other_execute_status == 200 + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + owner_first_execution_id = store.list_tool_executions()[0]["id"] + + set_execution_executed_at( + migrated_database_urls["admin"], + execution_id=owner_first_execution_id, + executed_at_sql="clock_timestamp() - interval '2 hours'", + ) + + owner_second_execute_status, owner_second_execute_payload = invoke_request( + "POST", + f"/v0/approvals/{owner_second_payload['approval']['id']}/execute", + payload={"user_id": str(owner["user_id"])}, + ) + + assert owner_second_execute_status == 200 + assert owner_second_execute_payload["result"]["status"] == "completed" + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + execute_trace_events = store.list_trace_events(UUID(owner_second_execute_payload["trace"]["trace_id"])) + + assert execute_trace_events[2]["payload"]["matched_budget_id"] == str(budget_id) + assert execute_trace_events[2]["payload"]["rolling_window_seconds"] == 60 + assert execute_trace_events[2]["payload"]["count_scope"] == "rolling_window" + assert execute_trace_events[2]["payload"]["window_started_at"] is not None + assert execute_trace_events[2]["payload"]["completed_execution_count"] == 0 + assert execute_trace_events[2]["payload"]["projected_completed_execution_count"] == 1 + assert execute_trace_events[2]["payload"]["reason"] == "within_budget" + + +def test_execute_approved_proxy_endpoint_ignores_deactivated_budget( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + tool_id = create_tool_and_policy( + migrated_database_urls["app"], + user_id=owner["user_id"], + tool_key="proxy.echo", + ) + budget_id = create_execution_budget( + migrated_database_urls["app"], + user_id=owner["user_id"], + tool_key="proxy.echo", + domain_hint=None, + max_completed_executions=1, + ) + + first_create_status, first_create_payload = create_pending_approval( + user_id=owner["user_id"], + thread_id=owner["thread_id"], + tool_id=tool_id, + ) + second_create_status, second_create_payload = create_pending_approval( + user_id=owner["user_id"], + thread_id=owner["thread_id"], + tool_id=tool_id, + ) + assert first_create_status == 200 + assert second_create_status == 200 + + first_approve_status, _ = invoke_request( + "POST", + f"/v0/approvals/{first_create_payload['approval']['id']}/approve", + payload={"user_id": str(owner["user_id"])}, + ) + second_approve_status, _ = invoke_request( + "POST", + f"/v0/approvals/{second_create_payload['approval']['id']}/approve", + payload={"user_id": str(owner["user_id"])}, + ) + assert first_approve_status == 200 + assert second_approve_status == 200 + + first_execute_status, _ = invoke_request( + "POST", + f"/v0/approvals/{first_create_payload['approval']['id']}/execute", + payload={"user_id": str(owner["user_id"])}, + ) + deactivate_status, deactivate_payload = invoke_request( + "POST", + f"/v0/execution-budgets/{budget_id}/deactivate", + payload={ + "user_id": str(owner["user_id"]), + "thread_id": str(owner["thread_id"]), + }, + ) + second_execute_status, second_execute_payload = invoke_request( + "POST", + f"/v0/approvals/{second_create_payload['approval']['id']}/execute", + 
payload={"user_id": str(owner["user_id"])}, + ) + + assert first_execute_status == 200 + assert deactivate_status == 200 + assert deactivate_payload["execution_budget"]["status"] == "inactive" + assert second_execute_status == 200 + assert second_execute_payload["result"]["status"] == "completed" + assert second_execute_payload["trace"]["trace_event_count"] == 9 + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + execute_trace_events = store.list_trace_events(UUID(second_execute_payload["trace"]["trace_id"])) + + assert execute_trace_events[2]["payload"] == { + "matched_budget_id": None, + "tool_key": "proxy.echo", + "domain_hint": None, + "budget_tool_key": None, + "budget_domain_hint": None, + "max_completed_executions": None, + "rolling_window_seconds": None, + "count_scope": "lifetime", + "window_started_at": None, + "completed_execution_count": 0, + "projected_completed_execution_count": 1, + "decision": "allow", + "reason": "no_matching_budget", + "order": ["specificity_desc", "created_at_asc", "id_asc"], + "history_order": ["executed_at_asc", "id_asc"], + } + + +def test_execute_approved_proxy_endpoint_uses_replacement_budget_after_supersession( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + tool_id = create_tool_and_policy( + migrated_database_urls["app"], + user_id=owner["user_id"], + tool_key="proxy.echo", + ) + budget_id = create_execution_budget( + migrated_database_urls["app"], + user_id=owner["user_id"], + tool_key="proxy.echo", + domain_hint=None, + max_completed_executions=1, + ) + + first_create_status, first_create_payload = create_pending_approval( + user_id=owner["user_id"], + thread_id=owner["thread_id"], + tool_id=tool_id, + ) + second_create_status, second_create_payload = create_pending_approval( + user_id=owner["user_id"], + thread_id=owner["thread_id"], + tool_id=tool_id, + ) + assert first_create_status == 200 + assert second_create_status == 200 + + first_approve_status, _ = invoke_request( + "POST", + f"/v0/approvals/{first_create_payload['approval']['id']}/approve", + payload={"user_id": str(owner["user_id"])}, + ) + second_approve_status, _ = invoke_request( + "POST", + f"/v0/approvals/{second_create_payload['approval']['id']}/approve", + payload={"user_id": str(owner["user_id"])}, + ) + assert first_approve_status == 200 + assert second_approve_status == 200 + + first_execute_status, _ = invoke_request( + "POST", + f"/v0/approvals/{first_create_payload['approval']['id']}/execute", + payload={"user_id": str(owner["user_id"])}, + ) + supersede_status, supersede_payload = invoke_request( + "POST", + f"/v0/execution-budgets/{budget_id}/supersede", + payload={ + "user_id": str(owner["user_id"]), + "thread_id": str(owner["thread_id"]), + "max_completed_executions": 2, + }, + ) + second_execute_status, second_execute_payload = invoke_request( + "POST", + f"/v0/approvals/{second_create_payload['approval']['id']}/execute", + payload={"user_id": str(owner["user_id"])}, + ) + + assert first_execute_status == 200 + assert supersede_status == 200 + assert supersede_payload["superseded_budget"]["status"] == "superseded" + assert supersede_payload["replacement_budget"]["status"] == "active" + assert second_execute_status == 200 + assert second_execute_payload["result"]["status"] == "completed" + + with 
user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + execute_trace_events = store.list_trace_events(UUID(second_execute_payload["trace"]["trace_id"])) + + assert execute_trace_events[2]["payload"] == { + "matched_budget_id": supersede_payload["replacement_budget"]["id"], + "tool_key": "proxy.echo", + "domain_hint": None, + "budget_tool_key": "proxy.echo", + "budget_domain_hint": None, + "max_completed_executions": 2, + "rolling_window_seconds": None, + "count_scope": "lifetime", + "window_started_at": None, + "completed_execution_count": 1, + "projected_completed_execution_count": 2, + "decision": "allow", + "reason": "within_budget", + "order": ["specificity_desc", "created_at_asc", "id_asc"], + "history_order": ["executed_at_asc", "id_asc"], + } + + +def test_execute_approved_proxy_execution_budget_is_user_scoped( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner@example.com") + other_user = seed_user(migrated_database_urls["app"], email="other@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + owner_tool_id = create_tool_and_policy( + migrated_database_urls["app"], + user_id=owner["user_id"], + tool_key="proxy.echo", + ) + other_tool_id = create_tool_and_policy( + migrated_database_urls["app"], + user_id=other_user["user_id"], + tool_key="proxy.echo", + ) + create_execution_budget( + migrated_database_urls["app"], + user_id=owner["user_id"], + tool_key="proxy.echo", + domain_hint=None, + max_completed_executions=1, + ) + + owner_create_status, owner_create_payload = create_pending_approval( + user_id=owner["user_id"], + thread_id=owner["thread_id"], + tool_id=owner_tool_id, + ) + other_create_status, other_create_payload = create_pending_approval( + user_id=other_user["user_id"], + thread_id=other_user["thread_id"], + tool_id=other_tool_id, + ) + assert owner_create_status == 200 + assert other_create_status == 200 + + owner_approve_status, _ = invoke_request( + "POST", + f"/v0/approvals/{owner_create_payload['approval']['id']}/approve", + payload={"user_id": str(owner["user_id"])}, + ) + other_approve_status, _ = invoke_request( + "POST", + f"/v0/approvals/{other_create_payload['approval']['id']}/approve", + payload={"user_id": str(other_user["user_id"])}, + ) + assert owner_approve_status == 200 + assert other_approve_status == 200 + + owner_execute_status, _ = invoke_request( + "POST", + f"/v0/approvals/{owner_create_payload['approval']['id']}/execute", + payload={"user_id": str(owner["user_id"])}, + ) + other_execute_status, other_execute_payload = invoke_request( + "POST", + f"/v0/approvals/{other_create_payload['approval']['id']}/execute", + payload={"user_id": str(other_user["user_id"])}, + ) + + assert owner_execute_status == 200 + assert other_execute_status == 200 + assert other_execute_payload["result"]["status"] == "completed" + + +def test_tool_execution_review_endpoints_are_deterministic_and_user_scoped( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner@example.com") + intruder = seed_user(migrated_database_urls["app"], email="intruder@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + tool_id = create_tool_and_policy( + migrated_database_urls["app"], + user_id=owner["user_id"], + tool_key="proxy.echo", + ) + + first_status, 
first_payload = create_pending_approval( + user_id=owner["user_id"], + thread_id=owner["thread_id"], + tool_id=tool_id, + ) + second_status, second_payload = create_pending_approval( + user_id=owner["user_id"], + thread_id=owner["thread_id"], + tool_id=tool_id, + ) + assert first_status == 200 + assert second_status == 200 + + first_approve_status, _ = invoke_request( + "POST", + f"/v0/approvals/{first_payload['approval']['id']}/approve", + payload={"user_id": str(owner["user_id"])}, + ) + second_approve_status, _ = invoke_request( + "POST", + f"/v0/approvals/{second_payload['approval']['id']}/approve", + payload={"user_id": str(owner["user_id"])}, + ) + assert first_approve_status == 200 + assert second_approve_status == 200 + + first_execute_status, _ = invoke_request( + "POST", + f"/v0/approvals/{first_payload['approval']['id']}/execute", + payload={"user_id": str(owner["user_id"])}, + ) + second_execute_status, _ = invoke_request( + "POST", + f"/v0/approvals/{second_payload['approval']['id']}/execute", + payload={"user_id": str(owner["user_id"])}, + ) + assert first_execute_status == 200 + assert second_execute_status == 200 + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + stored_executions = store.list_tool_executions() + + list_status, list_payload = invoke_request( + "GET", + "/v0/tool-executions", + query_params={"user_id": str(owner["user_id"])}, + ) + detail_status, detail_payload = invoke_request( + "GET", + f"/v0/tool-executions/{stored_executions[1]['id']}", + query_params={"user_id": str(owner["user_id"])}, + ) + isolated_list_status, isolated_list_payload = invoke_request( + "GET", + "/v0/tool-executions", + query_params={"user_id": str(intruder["user_id"])}, + ) + + assert list_status == 200 + assert [item["id"] for item in list_payload["items"]] == [ + str(stored_executions[0]["id"]), + str(stored_executions[1]["id"]), + ] + assert [item["task_step_id"] for item in list_payload["items"]] == [ + str(stored_executions[0]["task_step_id"]), + str(stored_executions[1]["task_step_id"]), + ] + assert list_payload["summary"] == { + "total_count": 2, + "order": ["executed_at_asc", "id_asc"], + } + assert detail_status == 200 + assert detail_payload == { + "execution": next( + item for item in list_payload["items"] if item["id"] == str(stored_executions[1]["id"]) + ) + } + assert isolated_list_status == 200 + assert isolated_list_payload == { + "items": [], + "summary": {"total_count": 0, "order": ["executed_at_asc", "id_asc"]}, + } + + isolated_detail_status, isolated_detail_payload = invoke_request( + "GET", + f"/v0/tool-executions/{stored_executions[0]['id']}", + query_params={"user_id": str(intruder['user_id'])}, + ) + + assert isolated_detail_status == 404 + assert isolated_detail_payload == { + "detail": f"tool execution {stored_executions[0]['id']} was not found" + } diff --git a/tests/integration/test_responses_api.py b/tests/integration/test_responses_api.py new file mode 100644 index 0000000..1ed051f --- /dev/null +++ b/tests/integration/test_responses_api.py @@ -0,0 +1,315 @@ +from __future__ import annotations + +import json +from uuid import UUID, uuid4 + +import anyio +import psycopg +import pytest + +import apps.api.src.alicebot_api.main as main_module +import alicebot_api.response_generation as response_generation_module +from apps.api.src.alicebot_api.config import Settings +from alicebot_api.db import user_connection +from alicebot_api.store import ContinuityStore + + +def 
invoke_generate_response(payload: dict[str, object]) -> tuple[int, dict[str, object]]: + messages: list[dict[str, object]] = [] + encoded_body = json.dumps(payload).encode() + request_received = False + + async def receive() -> dict[str, object]: + nonlocal request_received + if request_received: + return {"type": "http.disconnect"} + + request_received = True + return {"type": "http.request", "body": encoded_body, "more_body": False} + + async def send(message: dict[str, object]) -> None: + messages.append(message) + + scope = { + "type": "http", + "asgi": {"version": "3.0"}, + "http_version": "1.1", + "method": "POST", + "scheme": "http", + "path": "/v0/responses", + "raw_path": b"/v0/responses", + "query_string": b"", + "headers": [(b"content-type", b"application/json")], + "client": ("testclient", 50000), + "server": ("testserver", 80), + "root_path": "", + } + + anyio.run(main_module.app, scope, receive, send) + + start_message = next(message for message in messages if message["type"] == "http.response.start") + body = b"".join( + message.get("body", b"") + for message in messages + if message["type"] == "http.response.body" + ) + return start_message["status"], json.loads(body) + + +def seed_response_thread( + database_url: str, + *, + email: str = "owner@example.com", + display_name: str = "Owner", +) -> dict[str, object]: + user_id = uuid4() + + with user_connection(database_url, user_id) as conn: + store = ContinuityStore(conn) + store.create_user(user_id, email, display_name) + thread = store.create_thread("Response thread") + session = store.create_session(thread["id"], status="active") + prior_event = store.append_event( + thread["id"], + session["id"], + "message.user", + {"text": "Remember that I like oat milk."}, + ) + memory = store.create_memory( + memory_key="user.preference.coffee", + value={"likes": "oat milk"}, + status="active", + source_event_ids=[str(prior_event["id"])], + ) + + return { + "user_id": user_id, + "thread_id": thread["id"], + "session_id": session["id"], + "prior_event_id": prior_event["id"], + "memory_id": memory["id"], + } + + +def test_generate_response_persists_user_and_assistant_events_and_trace_metadata( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_response_thread(migrated_database_urls["app"]) + captured: dict[str, object] = {} + + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings( + database_url=migrated_database_urls["app"], + model_provider="openai_responses", + model_name="gpt-5-mini", + model_api_key="test-key", + ), + ) + + def fake_invoke_model(*, settings, request): + captured["settings"] = settings + captured["request_payload"] = request.as_payload() + return response_generation_module.ModelInvocationResponse( + provider="openai_responses", + model="gpt-5-mini", + response_id="resp_123", + finish_reason="completed", + output_text="You prefer oat milk.", + usage={"input_tokens": 20, "output_tokens": 6, "total_tokens": 26}, + ) + + monkeypatch.setattr(response_generation_module, "invoke_model", fake_invoke_model) + + status_code, payload = invoke_generate_response( + { + "user_id": str(seeded["user_id"]), + "thread_id": str(seeded["thread_id"]), + "message": "What do I usually take in coffee?", + } + ) + + assert status_code == 200 + assert payload["assistant"] == { + "event_id": payload["assistant"]["event_id"], + "sequence_no": 3, + "text": "You prefer oat milk.", + "model_provider": "openai_responses", + "model": "gpt-5-mini", + } + assert payload["trace"]["compile_trace_event_count"] > 0 
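+    # Exactly two response-trace events are expected here: prompt assembly followed by
+    # model completion (their kinds are asserted in detail further below).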
+ assert payload["trace"]["response_trace_event_count"] == 2 + assert captured["request_payload"]["tool_choice"] == "none" + assert captured["request_payload"]["tools"] == [] + assert captured["request_payload"]["store"] is False + assert captured["request_payload"]["sections"] == [ + "system", + "developer", + "context", + "conversation", + ] + + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + store = ContinuityStore(conn) + events = store.list_thread_events(seeded["thread_id"]) + compile_trace = store.get_trace(UUID(payload["trace"]["compile_trace_id"])) + response_trace = store.get_trace(UUID(payload["trace"]["response_trace_id"])) + response_trace_events = store.list_trace_events(UUID(payload["trace"]["response_trace_id"])) + + assert [event["sequence_no"] for event in events] == [1, 2, 3] + assert [event["kind"] for event in events] == [ + "message.user", + "message.user", + "message.assistant", + ] + assert events[1]["payload"] == {"text": "What do I usually take in coffee?"} + assert events[2]["payload"] == { + "text": "You prefer oat milk.", + "model": { + "provider": "openai_responses", + "model": "gpt-5-mini", + "response_id": "resp_123", + "finish_reason": "completed", + "usage": {"input_tokens": 20, "output_tokens": 6, "total_tokens": 26}, + }, + "prompt": { + "assembly_version": "prompt_assembly_v0", + "prompt_sha256": events[2]["payload"]["prompt"]["prompt_sha256"], + "section_order": ["system", "developer", "context", "conversation"], + }, + } + assert compile_trace["kind"] == "context.compile" + assert response_trace["kind"] == "response.generate" + assert response_trace["compiler_version"] == "response_generation_v0" + assert [event["kind"] for event in response_trace_events] == [ + "response.prompt.assembled", + "response.model.completed", + ] + assert response_trace_events[0]["payload"]["compile_trace_id"] == payload["trace"]["compile_trace_id"] + assert response_trace_events[1]["payload"] == { + "provider": "openai_responses", + "model": "gpt-5-mini", + "tool_choice": "none", + "tools_enabled": False, + "response_id": "resp_123", + "finish_reason": "completed", + "output_text_char_count": len("You prefer oat milk."), + "usage": {"input_tokens": 20, "output_tokens": 6, "total_tokens": 26}, + "error_message": None, + } + + with psycopg.connect(migrated_database_urls["admin"]) as conn: + with pytest.raises(psycopg.Error, match="append-only"): + with conn.cursor() as cur: + cur.execute( + "UPDATE events SET kind = 'message.mutated' WHERE id = %s", + (UUID(payload["assistant"]["event_id"]),), + ) + + +def test_generate_response_returns_clean_failure_without_persisting_assistant_event( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_response_thread(migrated_database_urls["app"]) + + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings( + database_url=migrated_database_urls["app"], + model_provider="openai_responses", + model_name="gpt-5-mini", + model_api_key="test-key", + ), + ) + monkeypatch.setattr( + response_generation_module, + "invoke_model", + lambda **_kwargs: (_ for _ in ()).throw( + response_generation_module.ModelInvocationError("upstream timeout") + ), + ) + + status_code, payload = invoke_generate_response( + { + "user_id": str(seeded["user_id"]), + "thread_id": str(seeded["thread_id"]), + "message": "What do I usually take in coffee?", + } + ) + + assert status_code == 502 + assert payload["detail"] == "upstream timeout" + assert payload["trace"]["compile_trace_event_count"] > 0 + assert 
payload["trace"]["response_trace_event_count"] == 2 + + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + store = ContinuityStore(conn) + events = store.list_thread_events(seeded["thread_id"]) + response_trace_events = store.list_trace_events(UUID(payload["trace"]["response_trace_id"])) + + assert [event["sequence_no"] for event in events] == [1, 2] + assert [event["kind"] for event in events] == ["message.user", "message.user"] + assert events[-1]["payload"] == {"text": "What do I usually take in coffee?"} + assert [event["kind"] for event in response_trace_events] == [ + "response.prompt.assembled", + "response.model.failed", + ] + assert response_trace_events[1]["payload"] == { + "provider": "openai_responses", + "model": "gpt-5-mini", + "tool_choice": "none", + "tools_enabled": False, + "response_id": None, + "finish_reason": "incomplete", + "output_text_char_count": 0, + "usage": {"input_tokens": None, "output_tokens": None, "total_tokens": None}, + "error_message": "upstream timeout", + } + + +def test_generate_response_respects_per_user_isolation(migrated_database_urls, monkeypatch) -> None: + owner = seed_response_thread(migrated_database_urls["app"]) + intruder = seed_response_thread( + migrated_database_urls["app"], + email="intruder@example.com", + display_name="Intruder", + ) + captured = {"invoke_model_called": False} + + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings( + database_url=migrated_database_urls["app"], + model_provider="openai_responses", + model_name="gpt-5-mini", + model_api_key="test-key", + ), + ) + + def fake_invoke_model(**_kwargs): + captured["invoke_model_called"] = True + raise AssertionError("invoke_model should not be called for cross-user access") + + monkeypatch.setattr(response_generation_module, "invoke_model", fake_invoke_model) + + status_code, payload = invoke_generate_response( + { + "user_id": str(intruder["user_id"]), + "thread_id": str(owner["thread_id"]), + "message": "Tell me their preferences.", + } + ) + + assert status_code == 404 + assert payload == {"detail": "get_thread did not return a row from the database"} + assert captured["invoke_model_called"] is False + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + owner_events = ContinuityStore(conn).list_thread_events(owner["thread_id"]) + + assert [event["sequence_no"] for event in owner_events] == [1] diff --git a/tests/integration/test_task_workspaces_api.py b/tests/integration/test_task_workspaces_api.py new file mode 100644 index 0000000..31aa9d5 --- /dev/null +++ b/tests/integration/test_task_workspaces_api.py @@ -0,0 +1,218 @@ +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any +from urllib.parse import urlencode +from uuid import UUID, uuid4 + +import anyio + +import apps.api.src.alicebot_api.main as main_module +from apps.api.src.alicebot_api.config import Settings +from alicebot_api.db import user_connection +from alicebot_api.store import ContinuityStore + + +def invoke_request( + method: str, + path: str, + *, + query_params: dict[str, str] | None = None, + payload: dict[str, Any] | None = None, +) -> tuple[int, dict[str, Any]]: + messages: list[dict[str, object]] = [] + encoded_body = b"" if payload is None else json.dumps(payload).encode() + request_received = False + + async def receive() -> dict[str, object]: + nonlocal request_received + if request_received: + return {"type": "http.disconnect"} + + request_received = True + return 
{"type": "http.request", "body": encoded_body, "more_body": False} + + async def send(message: dict[str, object]) -> None: + messages.append(message) + + query_string = urlencode(query_params or {}).encode() + scope = { + "type": "http", + "asgi": {"version": "3.0"}, + "http_version": "1.1", + "method": method, + "scheme": "http", + "path": path, + "raw_path": path.encode(), + "query_string": query_string, + "headers": [(b"content-type", b"application/json")], + "client": ("testclient", 50000), + "server": ("testserver", 80), + "root_path": "", + } + + anyio.run(main_module.app, scope, receive, send) + + start_message = next(message for message in messages if message["type"] == "http.response.start") + body = b"".join( + message.get("body", b"") + for message in messages + if message["type"] == "http.response.body" + ) + return start_message["status"], json.loads(body) + + +def seed_task(database_url: str, *, email: str) -> dict[str, UUID]: + user_id = uuid4() + + with user_connection(database_url, user_id) as conn: + store = ContinuityStore(conn) + store.create_user(user_id, email, email.split("@", 1)[0].title()) + thread = store.create_thread("Workspace thread") + tool = store.create_tool( + tool_key="proxy.echo", + name="Proxy Echo", + description="Deterministic proxy handler.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["proxy"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=[], + risk_hints=[], + metadata={"transport": "proxy"}, + ) + task = store.create_task( + thread_id=thread["id"], + tool_id=tool["id"], + status="approved", + request={ + "thread_id": str(thread["id"]), + "tool_id": str(tool["id"]), + "action": "tool.run", + "scope": "workspace", + "domain_hint": None, + "risk_hint": None, + "attributes": {}, + }, + tool={ + "id": str(tool["id"]), + "tool_key": "proxy.echo", + "name": "Proxy Echo", + "description": "Deterministic proxy handler.", + "version": "1.0.0", + "metadata_version": "tool_metadata_v0", + "active": True, + "tags": ["proxy"], + "action_hints": ["tool.run"], + "scope_hints": ["workspace"], + "domain_hints": [], + "risk_hints": [], + "metadata": {"transport": "proxy"}, + "created_at": tool["created_at"].isoformat(), + }, + latest_approval_id=None, + latest_execution_id=None, + ) + + return { + "user_id": user_id, + "task_id": task["id"], + } + + +def test_task_workspace_endpoints_provision_read_isolate_and_reject_duplicates( + migrated_database_urls, + monkeypatch, + tmp_path, +) -> None: + owner = seed_task(migrated_database_urls["app"], email="owner@example.com") + intruder = seed_task(migrated_database_urls["app"], email="intruder@example.com") + workspace_root = tmp_path / "task-workspaces" + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings( + database_url=migrated_database_urls["app"], + task_workspace_root=str(workspace_root), + ), + ) + + create_status, create_payload = invoke_request( + "POST", + f"/v0/tasks/{owner['task_id']}/workspace", + payload={"user_id": str(owner["user_id"])}, + ) + list_status, list_payload = invoke_request( + "GET", + "/v0/task-workspaces", + query_params={"user_id": str(owner["user_id"])}, + ) + detail_status, detail_payload = invoke_request( + "GET", + f"/v0/task-workspaces/{create_payload['workspace']['id']}", + query_params={"user_id": str(owner["user_id"])}, + ) + duplicate_status, duplicate_payload = invoke_request( + "POST", + f"/v0/tasks/{owner['task_id']}/workspace", + payload={"user_id": str(owner["user_id"])}, + ) + 
isolated_list_status, isolated_list_payload = invoke_request( + "GET", + "/v0/task-workspaces", + query_params={"user_id": str(intruder["user_id"])}, + ) + isolated_detail_status, isolated_detail_payload = invoke_request( + "GET", + f"/v0/task-workspaces/{create_payload['workspace']['id']}", + query_params={"user_id": str(intruder["user_id"])}, + ) + isolated_create_status, isolated_create_payload = invoke_request( + "POST", + f"/v0/tasks/{owner['task_id']}/workspace", + payload={"user_id": str(intruder["user_id"])}, + ) + + expected_path = (workspace_root / str(owner["user_id"]) / str(owner["task_id"])).resolve() + + assert create_status == 201 + assert create_payload["workspace"] == { + "id": create_payload["workspace"]["id"], + "task_id": str(owner["task_id"]), + "status": "active", + "local_path": str(expected_path), + "created_at": create_payload["workspace"]["created_at"], + "updated_at": create_payload["workspace"]["updated_at"], + } + assert Path(create_payload["workspace"]["local_path"]).is_dir() + + assert list_status == 200 + assert list_payload == { + "items": [create_payload["workspace"]], + "summary": {"total_count": 1, "order": ["created_at_asc", "id_asc"]}, + } + + assert detail_status == 200 + assert detail_payload == {"workspace": create_payload["workspace"]} + + assert duplicate_status == 409 + assert duplicate_payload == { + "detail": f"task {owner['task_id']} already has active workspace {create_payload['workspace']['id']}" + } + + assert isolated_list_status == 200 + assert isolated_list_payload == { + "items": [], + "summary": {"total_count": 0, "order": ["created_at_asc", "id_asc"]}, + } + + assert isolated_detail_status == 404 + assert isolated_detail_payload == { + "detail": f"task workspace {create_payload['workspace']['id']} was not found" + } + + assert isolated_create_status == 404 + assert isolated_create_payload == {"detail": f"task {owner['task_id']} was not found"} diff --git a/tests/integration/test_tasks_api.py b/tests/integration/test_tasks_api.py new file mode 100644 index 0000000..2987567 --- /dev/null +++ b/tests/integration/test_tasks_api.py @@ -0,0 +1,946 @@ +from __future__ import annotations + +import json +from typing import Any +from urllib.parse import urlencode +from uuid import UUID, uuid4 + +import anyio + +import apps.api.src.alicebot_api.main as main_module +from apps.api.src.alicebot_api.config import Settings +from alicebot_api.db import user_connection +from alicebot_api.store import ContinuityStore + + +def invoke_request( + method: str, + path: str, + *, + query_params: dict[str, str] | None = None, + payload: dict[str, Any] | None = None, +) -> tuple[int, dict[str, Any]]: + messages: list[dict[str, object]] = [] + encoded_body = b"" if payload is None else json.dumps(payload).encode() + request_received = False + + async def receive() -> dict[str, object]: + nonlocal request_received + if request_received: + return {"type": "http.disconnect"} + + request_received = True + return {"type": "http.request", "body": encoded_body, "more_body": False} + + async def send(message: dict[str, object]) -> None: + messages.append(message) + + query_string = urlencode(query_params or {}).encode() + scope = { + "type": "http", + "asgi": {"version": "3.0"}, + "http_version": "1.1", + "method": method, + "scheme": "http", + "path": path, + "raw_path": path.encode(), + "query_string": query_string, + "headers": [(b"content-type", b"application/json")], + "client": ("testclient", 50000), + "server": ("testserver", 80), + "root_path": "", + } + + 
anyio.run(main_module.app, scope, receive, send) + + start_message = next(message for message in messages if message["type"] == "http.response.start") + body = b"".join( + message.get("body", b"") + for message in messages + if message["type"] == "http.response.body" + ) + return start_message["status"], json.loads(body) + + +def seed_user(database_url: str, *, email: str) -> dict[str, UUID]: + user_id = uuid4() + + with user_connection(database_url, user_id) as conn: + store = ContinuityStore(conn) + store.create_user(user_id, email, email.split("@", 1)[0].title()) + thread = store.create_thread("Task thread") + + return { + "user_id": user_id, + "thread_id": thread["id"], + } + + +def test_task_endpoints_list_detail_lifecycle_and_user_isolation( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner@example.com") + intruder = seed_user(migrated_database_urls["app"], email="intruder@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + store.create_consent( + consent_key="web_access", + status="granted", + metadata={"source": "settings"}, + ) + approval_tool = store.create_tool( + tool_key="proxy.echo", + name="Proxy Echo", + description="Deterministic proxy handler.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["proxy"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=[], + risk_hints=[], + metadata={"transport": "proxy"}, + ) + ready_tool = store.create_tool( + tool_key="browser.open", + name="Browser Open", + description="Open documentation pages.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["browser"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=["docs"], + risk_hints=[], + metadata={"transport": "proxy"}, + ) + denied_tool = store.create_tool( + tool_key="calendar.read", + name="Calendar Read", + description="Read calendars.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["calendar"], + action_hints=["calendar.read"], + scope_hints=["calendar"], + domain_hints=[], + risk_hints=[], + metadata={}, + ) + store.create_policy( + name="Require proxy approval", + action="tool.run", + scope="workspace", + effect="require_approval", + priority=10, + active=True, + conditions={"tool_key": "proxy.echo"}, + required_consents=[], + ) + store.create_policy( + name="Allow docs browser", + action="tool.run", + scope="workspace", + effect="allow", + priority=20, + active=True, + conditions={"tool_key": "browser.open", "domain_hint": "docs"}, + required_consents=["web_access"], + ) + + pending_status, pending_payload = invoke_request( + "POST", + "/v0/approvals/requests", + payload={ + "user_id": str(owner["user_id"]), + "thread_id": str(owner["thread_id"]), + "tool_id": str(approval_tool["id"]), + "action": "tool.run", + "scope": "workspace", + "attributes": {"message": "hello"}, + }, + ) + assert pending_status == 200 + assert pending_payload["task"]["status"] == "pending_approval" + + approve_status, approve_payload = invoke_request( + "POST", + f"/v0/approvals/{pending_payload['approval']['id']}/approve", + payload={"user_id": str(owner["user_id"])}, + ) + assert approve_status == 200 + assert approve_payload["approval"]["status"] == "approved" + + execute_status, execute_payload = 
invoke_request( + "POST", + f"/v0/approvals/{pending_payload['approval']['id']}/execute", + payload={"user_id": str(owner["user_id"])}, + ) + assert execute_status == 200 + assert execute_payload["result"]["status"] == "completed" + + ready_status, ready_payload = invoke_request( + "POST", + "/v0/approvals/requests", + payload={ + "user_id": str(owner["user_id"]), + "thread_id": str(owner["thread_id"]), + "tool_id": str(ready_tool["id"]), + "action": "tool.run", + "scope": "workspace", + "domain_hint": "docs", + "attributes": {}, + }, + ) + assert ready_status == 200 + assert ready_payload["task"]["status"] == "approved" + + denied_status, denied_payload = invoke_request( + "POST", + "/v0/approvals/requests", + payload={ + "user_id": str(owner["user_id"]), + "thread_id": str(owner["thread_id"]), + "tool_id": str(denied_tool["id"]), + "action": "tool.run", + "scope": "workspace", + "attributes": {}, + }, + ) + assert denied_status == 200 + assert denied_payload["task"]["status"] == "denied" + + list_status, list_payload = invoke_request( + "GET", + "/v0/tasks", + query_params={"user_id": str(owner["user_id"])}, + ) + detail_status, detail_payload = invoke_request( + "GET", + f"/v0/tasks/{pending_payload['task']['id']}", + query_params={"user_id": str(owner["user_id"])}, + ) + step_list_status, step_list_payload = invoke_request( + "GET", + f"/v0/tasks/{pending_payload['task']['id']}/steps", + query_params={"user_id": str(owner["user_id"])}, + ) + step_detail_status, step_detail_payload = invoke_request( + "GET", + f"/v0/task-steps/{step_list_payload['items'][0]['id']}", + query_params={"user_id": str(owner['user_id'])}, + ) + isolated_list_status, isolated_list_payload = invoke_request( + "GET", + "/v0/tasks", + query_params={"user_id": str(intruder["user_id"])}, + ) + isolated_detail_status, isolated_detail_payload = invoke_request( + "GET", + f"/v0/tasks/{pending_payload['task']['id']}", + query_params={"user_id": str(intruder['user_id'])}, + ) + isolated_step_list_status, isolated_step_list_payload = invoke_request( + "GET", + f"/v0/tasks/{pending_payload['task']['id']}/steps", + query_params={"user_id": str(intruder['user_id'])}, + ) + isolated_step_detail_status, isolated_step_detail_payload = invoke_request( + "GET", + f"/v0/task-steps/{step_list_payload['items'][0]['id']}", + query_params={"user_id": str(intruder['user_id'])}, + ) + + assert list_status == 200 + assert [item["id"] for item in list_payload["items"]] == [ + pending_payload["task"]["id"], + ready_payload["task"]["id"], + denied_payload["task"]["id"], + ] + assert [item["status"] for item in list_payload["items"]] == [ + "executed", + "approved", + "denied", + ] + assert list_payload["summary"] == { + "total_count": 3, + "order": ["created_at_asc", "id_asc"], + } + + assert detail_status == 200 + assert detail_payload["task"]["id"] == pending_payload["task"]["id"] + assert detail_payload["task"]["status"] == "executed" + assert detail_payload["task"]["latest_approval_id"] == pending_payload["approval"]["id"] + assert detail_payload["task"]["latest_execution_id"] is not None + assert step_list_status == 200 + assert [item["sequence_no"] for item in step_list_payload["items"]] == [1] + assert step_list_payload["summary"] == { + "task_id": pending_payload["task"]["id"], + "total_count": 1, + "latest_sequence_no": 1, + "latest_status": "executed", + "next_sequence_no": 2, + "append_allowed": True, + "order": ["sequence_no_asc", "created_at_asc", "id_asc"], + } + assert step_list_payload["items"][0] == { + "id": 
step_list_payload["items"][0]["id"], + "task_id": pending_payload["task"]["id"], + "sequence_no": 1, + "lineage": { + "parent_step_id": None, + "source_approval_id": None, + "source_execution_id": None, + }, + "kind": "governed_request", + "status": "executed", + "request": pending_payload["request"], + "outcome": { + "routing_decision": "approval_required", + "approval_id": pending_payload["approval"]["id"], + "approval_status": "approved", + "execution_id": detail_payload["task"]["latest_execution_id"], + "execution_status": "completed", + "blocked_reason": None, + }, + "trace": { + "trace_id": execute_payload["trace"]["trace_id"], + "trace_kind": "tool.proxy.execute", + }, + "created_at": step_list_payload["items"][0]["created_at"], + "updated_at": step_list_payload["items"][0]["updated_at"], + } + assert step_detail_status == 200 + assert step_detail_payload == {"task_step": step_list_payload["items"][0]} + + assert isolated_list_status == 200 + assert isolated_list_payload == { + "items": [], + "summary": {"total_count": 0, "order": ["created_at_asc", "id_asc"]}, + } + assert isolated_detail_status == 404 + assert isolated_detail_payload == { + "detail": f"task {pending_payload['task']['id']} was not found" + } + assert isolated_step_list_status == 404 + assert isolated_step_list_payload == { + "detail": f"task {pending_payload['task']['id']} was not found" + } + assert isolated_step_detail_status == 404 + assert isolated_step_detail_payload == { + "detail": f"task step {step_list_payload['items'][0]['id']} was not found" + } + + +def test_task_step_sequence_and_transition_endpoints_preserve_parent_consistency_trace_and_isolation( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner-sequence@example.com") + intruder = seed_user(migrated_database_urls["app"], email="intruder-sequence@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + tool = store.create_tool( + tool_key="proxy.echo", + name="Proxy Echo", + description="Deterministic proxy handler.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["proxy"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=[], + risk_hints=[], + metadata={"transport": "proxy"}, + ) + store.create_policy( + name="Require proxy approval", + action="tool.run", + scope="workspace", + effect="require_approval", + priority=10, + active=True, + conditions={"tool_key": "proxy.echo"}, + required_consents=[], + ) + + request_status, request_payload = invoke_request( + "POST", + "/v0/approvals/requests", + payload={ + "user_id": str(owner["user_id"]), + "thread_id": str(owner["thread_id"]), + "tool_id": str(tool["id"]), + "action": "tool.run", + "scope": "workspace", + "attributes": {"message": "seed-step"}, + }, + ) + assert request_status == 200 + approve_status, approve_payload = invoke_request( + "POST", + f"/v0/approvals/{request_payload['approval']['id']}/approve", + payload={"user_id": str(owner["user_id"])}, + ) + assert approve_status == 200 + execute_status, execute_payload = invoke_request( + "POST", + f"/v0/approvals/{request_payload['approval']['id']}/execute", + payload={"user_id": str(owner["user_id"])}, + ) + assert execute_status == 200 + initial_detail_status, initial_detail_payload = invoke_request( + "GET", + 
f"/v0/tasks/{request_payload['task']['id']}", + query_params={"user_id": str(owner["user_id"])}, + ) + assert initial_detail_status == 200 + initial_step_list_status, initial_step_list_payload = invoke_request( + "GET", + f"/v0/tasks/{request_payload['task']['id']}/steps", + query_params={"user_id": str(owner["user_id"])}, + ) + assert initial_step_list_status == 200 + initial_execution_id = initial_detail_payload["task"]["latest_execution_id"] + assert initial_execution_id is not None + + create_status, create_payload = invoke_request( + "POST", + f"/v0/tasks/{request_payload['task']['id']}/steps", + payload={ + "user_id": str(owner["user_id"]), + "kind": "governed_request", + "status": "created", + "request": { + "thread_id": str(owner["thread_id"]), + "tool_id": str(tool["id"]), + "action": "tool.run", + "scope": "workspace", + "attributes": {"message": "step-2"}, + }, + "outcome": { + "routing_decision": "approval_required", + "approval_status": None, + "approval_id": None, + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + }, + "lineage": { + "parent_step_id": initial_step_list_payload["items"][0]["id"], + "source_approval_id": request_payload["approval"]["id"], + "source_execution_id": initial_execution_id, + }, + }, + ) + + assert create_status == 201 + assert create_payload["task"]["status"] == "pending_approval" + assert create_payload["task"]["latest_approval_id"] == request_payload["approval"]["id"] + assert create_payload["task_step"]["sequence_no"] == 2 + assert create_payload["task_step"]["status"] == "created" + assert create_payload["task_step"]["lineage"] == { + "parent_step_id": initial_step_list_payload["items"][0]["id"], + "source_approval_id": request_payload["approval"]["id"], + "source_execution_id": initial_execution_id, + } + assert create_payload["sequencing"] == { + "task_id": request_payload["task"]["id"], + "total_count": 2, + "latest_sequence_no": 2, + "latest_status": "created", + "next_sequence_no": 3, + "append_allowed": False, + "order": ["sequence_no_asc", "created_at_asc", "id_asc"], + } + + duplicate_create_status, duplicate_create_payload = invoke_request( + "POST", + f"/v0/tasks/{request_payload['task']['id']}/steps", + payload={ + "user_id": str(owner["user_id"]), + "kind": "governed_request", + "status": "created", + "request": { + "thread_id": str(owner["thread_id"]), + "tool_id": str(tool["id"]), + "action": "tool.run", + "scope": "workspace", + "attributes": {"message": "step-3"}, + }, + "outcome": { + "routing_decision": "approval_required", + "approval_status": None, + "approval_id": None, + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + }, + "lineage": { + "parent_step_id": create_payload["task_step"]["id"], + "source_approval_id": request_payload["approval"]["id"], + "source_execution_id": initial_execution_id, + }, + }, + ) + assert duplicate_create_status == 409 + assert duplicate_create_payload["detail"] == ( + f"task {request_payload['task']['id']} latest step {create_payload['task_step']['id']} is created and cannot append a next step" + ) + + invalid_transition_status, invalid_transition_payload = invoke_request( + "POST", + f"/v0/task-steps/{create_payload['task_step']['id']}/transition", + payload={ + "user_id": str(owner["user_id"]), + "status": "executed", + "outcome": { + "routing_decision": "approval_required", + "approval_status": "approved", + "approval_id": str(uuid4()), + "execution_id": str(uuid4()), + "execution_status": "completed", + "blocked_reason": None, + }, + 
}, + ) + assert invalid_transition_status == 409 + assert invalid_transition_payload["detail"] == ( + f"task step {create_payload['task_step']['id']} is created and cannot transition to executed; allowed: approved, denied" + ) + + approve_step_status, approve_step_payload = invoke_request( + "POST", + f"/v0/task-steps/{create_payload['task_step']['id']}/transition", + payload={ + "user_id": str(owner["user_id"]), + "status": "approved", + "outcome": { + "routing_decision": "approval_required", + "approval_status": "approved", + "approval_id": request_payload["approval"]["id"], + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + }, + }, + ) + assert approve_step_status == 200 + assert approve_step_payload["task"]["status"] == "approved" + assert approve_step_payload["task"]["latest_approval_id"] == request_payload["approval"]["id"] + assert approve_step_payload["task"]["latest_execution_id"] is None + assert approve_step_payload["task_step"]["status"] == "approved" + + execute_step_status, execute_step_payload = invoke_request( + "POST", + f"/v0/task-steps/{create_payload['task_step']['id']}/transition", + payload={ + "user_id": str(owner["user_id"]), + "status": "executed", + "outcome": { + "routing_decision": "approval_required", + "approval_status": "approved", + "approval_id": request_payload["approval"]["id"], + "execution_id": initial_execution_id, + "execution_status": "completed", + "blocked_reason": None, + }, + }, + ) + assert execute_step_status == 200 + assert execute_step_payload["task"]["status"] == "executed" + assert execute_step_payload["task"]["latest_approval_id"] == request_payload["approval"]["id"] + assert execute_step_payload["task"]["latest_execution_id"] == initial_execution_id + assert execute_step_payload["task_step"]["status"] == "executed" + assert execute_step_payload["sequencing"] == { + "task_id": request_payload["task"]["id"], + "total_count": 2, + "latest_sequence_no": 2, + "latest_status": "executed", + "next_sequence_no": 3, + "append_allowed": True, + "order": ["sequence_no_asc", "created_at_asc", "id_asc"], + } + + detail_status, detail_payload = invoke_request( + "GET", + f"/v0/tasks/{request_payload['task']['id']}", + query_params={"user_id": str(owner["user_id"])}, + ) + step_list_status, step_list_payload = invoke_request( + "GET", + f"/v0/tasks/{request_payload['task']['id']}/steps", + query_params={"user_id": str(owner["user_id"])}, + ) + step_detail_status, step_detail_payload = invoke_request( + "GET", + f"/v0/task-steps/{create_payload['task_step']['id']}", + query_params={"user_id": str(owner["user_id"])}, + ) + assert detail_status == 200 + assert detail_payload["task"]["status"] == "executed" + assert detail_payload["task"]["latest_approval_id"] == request_payload["approval"]["id"] + assert detail_payload["task"]["latest_execution_id"] == initial_execution_id + assert step_list_status == 200 + assert [item["sequence_no"] for item in step_list_payload["items"]] == [1, 2] + assert step_list_payload["items"][1]["lineage"] == create_payload["task_step"]["lineage"] + assert step_list_payload["summary"] == { + "task_id": request_payload["task"]["id"], + "total_count": 2, + "latest_sequence_no": 2, + "latest_status": "executed", + "next_sequence_no": 3, + "append_allowed": True, + "order": ["sequence_no_asc", "created_at_asc", "id_asc"], + } + assert step_detail_status == 200 + assert step_detail_payload["task_step"] == step_list_payload["items"][1] + assert step_detail_payload["task_step"]["lineage"] == 
create_payload["task_step"]["lineage"] + assert step_detail_payload["task_step"]["outcome"] == { + "routing_decision": "approval_required", + "approval_id": request_payload["approval"]["id"], + "approval_status": "approved", + "execution_id": initial_execution_id, + "execution_status": "completed", + "blocked_reason": None, + } + + isolated_create_status, isolated_create_payload = invoke_request( + "POST", + f"/v0/tasks/{request_payload['task']['id']}/steps", + payload={ + "user_id": str(intruder["user_id"]), + "kind": "governed_request", + "status": "created", + "request": { + "thread_id": str(owner["thread_id"]), + "tool_id": str(tool["id"]), + "action": "tool.run", + "scope": "workspace", + "attributes": {}, + }, + "outcome": { + "routing_decision": "approval_required", + "approval_status": None, + "approval_id": None, + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + }, + "lineage": { + "parent_step_id": create_payload["task_step"]["id"], + "source_approval_id": request_payload["approval"]["id"], + "source_execution_id": initial_execution_id, + }, + }, + ) + isolated_transition_status, isolated_transition_payload = invoke_request( + "POST", + f"/v0/task-steps/{create_payload['task_step']['id']}/transition", + payload={ + "user_id": str(intruder["user_id"]), + "status": "approved", + "outcome": { + "routing_decision": "approval_required", + "approval_status": "approved", + "approval_id": str(uuid4()), + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + }, + }, + ) + assert isolated_create_status == 404 + assert isolated_create_payload == { + "detail": f"task {request_payload['task']['id']} was not found" + } + assert isolated_transition_status == 404 + assert isolated_transition_payload == { + "detail": f"task step {create_payload['task_step']['id']} was not found" + } + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + create_trace_events = store.list_trace_events(UUID(create_payload["trace"]["trace_id"])) + transition_trace_events = store.list_trace_events(UUID(execute_step_payload["trace"]["trace_id"])) + + assert [event["kind"] for event in create_trace_events] == [ + "task.step.continuation.request", + "task.step.continuation.lineage", + "task.step.continuation.summary", + "task.lifecycle.state", + "task.lifecycle.summary", + "task.step.lifecycle.state", + "task.step.lifecycle.summary", + ] + assert create_trace_events[1]["payload"] == { + "task_id": request_payload["task"]["id"], + "parent_task_step_id": step_list_payload["items"][0]["id"], + "parent_sequence_no": 1, + "parent_status": "executed", + "source_approval_id": request_payload["approval"]["id"], + "source_execution_id": initial_execution_id, + } + assert [event["kind"] for event in transition_trace_events] == [ + "task.step.transition.request", + "task.step.transition.state", + "task.step.transition.summary", + "task.lifecycle.state", + "task.lifecycle.summary", + "task.step.lifecycle.state", + "task.step.lifecycle.summary", + ] + assert transition_trace_events[1]["payload"] == { + "task_id": request_payload["task"]["id"], + "task_step_id": create_payload["task_step"]["id"], + "sequence_no": 2, + "previous_status": "approved", + "current_status": "executed", + "allowed_next_statuses": ["executed", "blocked"], + "trace": { + "trace_id": execute_step_payload["trace"]["trace_id"], + "trace_kind": "task.step.transition", + }, + } + assert transition_trace_events[2]["payload"] == { + "task_id": 
request_payload["task"]["id"], + "task_step_id": create_payload["task_step"]["id"], + "sequence_no": 2, + "final_status": "executed", + "parent_task_status": "executed", + "trace": { + "trace_id": execute_step_payload["trace"]["trace_id"], + "trace_kind": "task.step.transition", + }, + } + + +def test_task_step_mutations_reject_visible_links_from_other_task_lineages( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner-lineage@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + tool = store.create_tool( + tool_key="proxy.echo", + name="Proxy Echo", + description="Deterministic proxy handler.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["proxy"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=[], + risk_hints=[], + metadata={"transport": "proxy"}, + ) + store.create_policy( + name="Require proxy approval", + action="tool.run", + scope="workspace", + effect="require_approval", + priority=10, + active=True, + conditions={"tool_key": "proxy.echo"}, + required_consents=[], + ) + + first_request_status, first_request_payload = invoke_request( + "POST", + "/v0/approvals/requests", + payload={ + "user_id": str(owner["user_id"]), + "thread_id": str(owner["thread_id"]), + "tool_id": str(tool["id"]), + "action": "tool.run", + "scope": "workspace", + "attributes": {"message": "first"}, + }, + ) + assert first_request_status == 200 + first_approve_status, _ = invoke_request( + "POST", + f"/v0/approvals/{first_request_payload['approval']['id']}/approve", + payload={"user_id": str(owner["user_id"])}, + ) + assert first_approve_status == 200 + first_execute_status, _ = invoke_request( + "POST", + f"/v0/approvals/{first_request_payload['approval']['id']}/execute", + payload={"user_id": str(owner["user_id"])}, + ) + assert first_execute_status == 200 + first_detail_status, first_detail_payload = invoke_request( + "GET", + f"/v0/tasks/{first_request_payload['task']['id']}", + query_params={"user_id": str(owner["user_id"])}, + ) + assert first_detail_status == 200 + first_step_list_status, first_step_list_payload = invoke_request( + "GET", + f"/v0/tasks/{first_request_payload['task']['id']}/steps", + query_params={"user_id": str(owner["user_id"])}, + ) + assert first_step_list_status == 200 + first_step_id = first_step_list_payload["items"][0]["id"] + first_execution_id = first_detail_payload["task"]["latest_execution_id"] + assert first_execution_id is not None + + second_request_status, second_request_payload = invoke_request( + "POST", + "/v0/approvals/requests", + payload={ + "user_id": str(owner["user_id"]), + "thread_id": str(owner["thread_id"]), + "tool_id": str(tool["id"]), + "action": "tool.run", + "scope": "workspace", + "attributes": {"message": "second"}, + }, + ) + assert second_request_status == 200 + second_approve_status, _ = invoke_request( + "POST", + f"/v0/approvals/{second_request_payload['approval']['id']}/approve", + payload={"user_id": str(owner["user_id"])}, + ) + assert second_approve_status == 200 + second_execute_status, _ = invoke_request( + "POST", + f"/v0/approvals/{second_request_payload['approval']['id']}/execute", + payload={"user_id": str(owner["user_id"])}, + ) + assert second_execute_status == 200 + second_detail_status, second_detail_payload = invoke_request( 
+ "GET", + f"/v0/tasks/{second_request_payload['task']['id']}", + query_params={"user_id": str(owner["user_id"])}, + ) + assert second_detail_status == 200 + second_execution_id = second_detail_payload["task"]["latest_execution_id"] + assert second_execution_id is not None + + wrong_create_status, wrong_create_payload = invoke_request( + "POST", + f"/v0/tasks/{first_request_payload['task']['id']}/steps", + payload={ + "user_id": str(owner["user_id"]), + "kind": "governed_request", + "status": "created", + "request": { + "thread_id": str(owner["thread_id"]), + "tool_id": str(tool["id"]), + "action": "tool.run", + "scope": "workspace", + "attributes": {"message": "lineage-mismatch"}, + }, + "outcome": { + "routing_decision": "approval_required", + "approval_status": None, + "approval_id": None, + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + }, + "lineage": { + "parent_step_id": first_step_id, + "source_approval_id": second_request_payload["approval"]["id"], + "source_execution_id": None, + }, + }, + ) + assert wrong_create_status == 409 + assert wrong_create_payload == { + "detail": ( + f"approval {second_request_payload['approval']['id']} does not belong to task {first_request_payload['task']['id']}" + ) + } + + create_status, create_payload = invoke_request( + "POST", + f"/v0/tasks/{first_request_payload['task']['id']}/steps", + payload={ + "user_id": str(owner["user_id"]), + "kind": "governed_request", + "status": "created", + "request": { + "thread_id": str(owner["thread_id"]), + "tool_id": str(tool["id"]), + "action": "tool.run", + "scope": "workspace", + "attributes": {"message": "valid"}, + }, + "outcome": { + "routing_decision": "approval_required", + "approval_status": None, + "approval_id": None, + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + }, + "lineage": { + "parent_step_id": first_step_id, + "source_approval_id": first_request_payload["approval"]["id"], + "source_execution_id": first_execution_id, + }, + }, + ) + assert create_status == 201 + + approve_status, approve_payload = invoke_request( + "POST", + f"/v0/task-steps/{create_payload['task_step']['id']}/transition", + payload={ + "user_id": str(owner["user_id"]), + "status": "approved", + "outcome": { + "routing_decision": "approval_required", + "approval_status": "approved", + "approval_id": first_request_payload["approval"]["id"], + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + }, + }, + ) + assert approve_status == 200 + + wrong_execute_status, wrong_execute_payload = invoke_request( + "POST", + f"/v0/task-steps/{create_payload['task_step']['id']}/transition", + payload={ + "user_id": str(owner["user_id"]), + "status": "executed", + "outcome": { + "routing_decision": "approval_required", + "approval_status": "approved", + "approval_id": first_request_payload["approval"]["id"], + "execution_id": second_execution_id, + "execution_status": "completed", + "blocked_reason": None, + }, + }, + ) + assert wrong_execute_status == 409 + assert wrong_execute_payload == { + "detail": ( + f"tool execution {second_execution_id} does not belong to task {first_request_payload['task']['id']}" + ) + } + + assert first_execution_id != second_execution_id + assert first_request_payload["approval"]["id"] != second_request_payload["approval"]["id"] + assert approve_payload["task"]["latest_approval_id"] == first_request_payload["approval"]["id"] diff --git a/tests/integration/test_tool_api.py b/tests/integration/test_tool_api.py new file mode 
100644 index 0000000..df7afd3 --- /dev/null +++ b/tests/integration/test_tool_api.py @@ -0,0 +1,930 @@ +from __future__ import annotations + +import json +from typing import Any +from urllib.parse import urlencode +from uuid import UUID, uuid4 + +import anyio + +import apps.api.src.alicebot_api.main as main_module +from apps.api.src.alicebot_api.config import Settings +from alicebot_api.db import user_connection +from alicebot_api.store import ContinuityStore + + +def invoke_request( + method: str, + path: str, + *, + query_params: dict[str, str] | None = None, + payload: dict[str, Any] | None = None, +) -> tuple[int, dict[str, Any]]: + messages: list[dict[str, object]] = [] + encoded_body = b"" if payload is None else json.dumps(payload).encode() + request_received = False + + async def receive() -> dict[str, object]: + nonlocal request_received + if request_received: + return {"type": "http.disconnect"} + + request_received = True + return {"type": "http.request", "body": encoded_body, "more_body": False} + + async def send(message: dict[str, object]) -> None: + messages.append(message) + + query_string = urlencode(query_params or {}).encode() + scope = { + "type": "http", + "asgi": {"version": "3.0"}, + "http_version": "1.1", + "method": method, + "scheme": "http", + "path": path, + "raw_path": path.encode(), + "query_string": query_string, + "headers": [(b"content-type", b"application/json")], + "client": ("testclient", 50000), + "server": ("testserver", 80), + "root_path": "", + } + + anyio.run(main_module.app, scope, receive, send) + + start_message = next(message for message in messages if message["type"] == "http.response.start") + body = b"".join( + message.get("body", b"") + for message in messages + if message["type"] == "http.response.body" + ) + return start_message["status"], json.loads(body) + + +def seed_user(database_url: str, *, email: str) -> dict[str, UUID]: + user_id = uuid4() + + with user_connection(database_url, user_id) as conn: + store = ContinuityStore(conn) + store.create_user(user_id, email, email.split("@", 1)[0].title()) + thread = store.create_thread("Tool thread") + + return { + "user_id": user_id, + "thread_id": thread["id"], + } + + +def test_tool_endpoints_create_list_and_get_in_deterministic_order(migrated_database_urls, monkeypatch) -> None: + seeded = seed_user(migrated_database_urls["app"], email="owner@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + + second_status, second_payload = invoke_request( + "POST", + "/v0/tools", + payload={ + "user_id": str(seeded["user_id"]), + "tool_key": "zeta.fetch", + "name": "Zeta Fetch", + "description": "Fetch zeta records.", + "version": "2.0.0", + "active": True, + "tags": ["fetch"], + "action_hints": ["tool.run"], + "scope_hints": ["workspace"], + "domain_hints": [], + "risk_hints": [], + "metadata": {"transport": "proxy"}, + }, + ) + first_status, first_payload = invoke_request( + "POST", + "/v0/tools", + payload={ + "user_id": str(seeded["user_id"]), + "tool_key": "alpha.open", + "name": "Alpha Open", + "description": "Open alpha pages.", + "version": "1.0.0", + "active": True, + "tags": ["browser"], + "action_hints": ["tool.run"], + "scope_hints": ["workspace"], + "domain_hints": ["docs"], + "risk_hints": [], + "metadata": {"transport": "proxy"}, + }, + ) + list_status, list_payload = invoke_request( + "GET", + "/v0/tools", + query_params={"user_id": str(seeded["user_id"])}, + ) + detail_status, detail_payload = invoke_request( 
+ "GET", + f"/v0/tools/{second_payload['tool']['id']}", + query_params={"user_id": str(seeded['user_id'])}, + ) + + assert first_status == 201 + assert second_status == 201 + assert list_status == 200 + assert [item["tool_key"] for item in list_payload["items"]] == ["alpha.open", "zeta.fetch"] + assert list_payload["summary"] == { + "total_count": 2, + "order": ["tool_key_asc", "version_asc", "created_at_asc", "id_asc"], + } + assert detail_status == 200 + assert detail_payload == {"tool": second_payload["tool"]} + assert first_payload["tool"]["metadata_version"] == "tool_metadata_v0" + + +def test_tool_allowlist_evaluation_returns_allowed_denied_and_approval_required_with_trace( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_user(migrated_database_urls["app"], email="owner@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + store = ContinuityStore(conn) + store.create_consent( + consent_key="web_access", + status="granted", + metadata={"source": "settings"}, + ) + allowed_tool = store.create_tool( + tool_key="browser.open", + name="Browser Open", + description="Open documentation pages.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["browser"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=["docs"], + risk_hints=[], + metadata={"transport": "proxy"}, + ) + denied_by_metadata_tool = store.create_tool( + tool_key="calendar.read", + name="Calendar Read", + description="Read calendars.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["calendar"], + action_hints=["calendar.read"], + scope_hints=["calendar"], + domain_hints=[], + risk_hints=[], + metadata={}, + ) + denied_by_consent_tool = store.create_tool( + tool_key="contacts.export", + name="Contacts Export", + description="Export contacts.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["contacts"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=["docs"], + risk_hints=[], + metadata={"transport": "proxy"}, + ) + approval_tool = store.create_tool( + tool_key="shell.exec", + name="Shell Exec", + description="Run shell commands.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["shell"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=[], + risk_hints=[], + metadata={"transport": "local"}, + ) + store.create_policy( + name="Allow docs browser", + action="tool.run", + scope="workspace", + effect="allow", + priority=10, + active=True, + conditions={"tool_key": "browser.open", "domain_hint": "docs"}, + required_consents=["web_access"], + ) + store.create_policy( + name="Allow contacts export with consent", + action="tool.run", + scope="workspace", + effect="allow", + priority=20, + active=True, + conditions={"tool_key": "contacts.export", "domain_hint": "docs"}, + required_consents=["contacts_consent"], + ) + store.create_policy( + name="Require shell approval", + action="tool.run", + scope="workspace", + effect="require_approval", + priority=30, + active=True, + conditions={"tool_key": "shell.exec"}, + required_consents=[], + ) + + status_code, payload = invoke_request( + "POST", + "/v0/tools/allowlist/evaluate", + payload={ + "user_id": str(seeded["user_id"]), + "thread_id": str(seeded["thread_id"]), + "action": "tool.run", + "scope": "workspace", + 
"domain_hint": "docs", + "attributes": {"channel": "chat"}, + }, + ) + + assert status_code == 200 + assert [item["tool"]["id"] for item in payload["allowed"]] == [str(allowed_tool["id"])] + assert [item["tool"]["id"] for item in payload["approval_required"]] == [str(approval_tool["id"])] + assert [item["tool"]["id"] for item in payload["denied"]] == [ + str(denied_by_metadata_tool["id"]), + str(denied_by_consent_tool["id"]), + ] + assert [reason["code"] for reason in payload["denied"][0]["reasons"]] == [ + "tool_action_unsupported", + "tool_scope_unsupported", + ] + assert [reason["code"] for reason in payload["denied"][1]["reasons"]] == [ + "tool_metadata_matched", + "matched_policy", + "consent_missing", + ] + assert payload["summary"] == { + "action": "tool.run", + "scope": "workspace", + "domain_hint": "docs", + "risk_hint": None, + "evaluated_tool_count": 4, + "allowed_count": 1, + "denied_count": 2, + "approval_required_count": 1, + "order": ["tool_key_asc", "version_asc", "created_at_asc", "id_asc"], + } + assert payload["trace"]["trace_event_count"] == 7 + + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + store = ContinuityStore(conn) + trace = store.get_trace(UUID(payload["trace"]["trace_id"])) + trace_events = store.list_trace_events(UUID(payload["trace"]["trace_id"])) + + assert trace["kind"] == "tool.allowlist.evaluate" + assert trace["compiler_version"] == "tool_allowlist_evaluation_v0" + assert trace["limits"] == { + "order": ["tool_key_asc", "version_asc", "created_at_asc", "id_asc"], + "active_tool_count": 4, + "active_policy_count": 3, + "consent_count": 1, + } + assert [event["kind"] for event in trace_events] == [ + "tool.allowlist.request", + "tool.allowlist.order", + "tool.allowlist.decision", + "tool.allowlist.decision", + "tool.allowlist.decision", + "tool.allowlist.decision", + "tool.allowlist.summary", + ] + assert trace_events[2]["payload"]["decision"] == "allowed" + assert trace_events[-1]["payload"] == { + "allowed_count": 1, + "denied_count": 2, + "approval_required_count": 1, + } + + +def test_tool_route_returns_ready_denied_and_approval_required_with_trace( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_user(migrated_database_urls["app"], email="owner@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + store = ContinuityStore(conn) + store.create_consent( + consent_key="web_access", + status="granted", + metadata={"source": "settings"}, + ) + ready_tool = store.create_tool( + tool_key="browser.open", + name="Browser Open", + description="Open documentation pages.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["browser"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=["docs"], + risk_hints=[], + metadata={"transport": "proxy"}, + ) + denied_tool = store.create_tool( + tool_key="calendar.read", + name="Calendar Read", + description="Read calendars.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["calendar"], + action_hints=["calendar.read"], + scope_hints=["calendar"], + domain_hints=[], + risk_hints=[], + metadata={}, + ) + approval_tool = store.create_tool( + tool_key="shell.exec", + name="Shell Exec", + description="Run shell commands.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["shell"], + 
action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=[], + risk_hints=[], + metadata={"transport": "local"}, + ) + ready_policy = store.create_policy( + name="Allow docs browser", + action="tool.run", + scope="workspace", + effect="allow", + priority=10, + active=True, + conditions={"tool_key": "browser.open", "domain_hint": "docs"}, + required_consents=["web_access"], + ) + approval_policy = store.create_policy( + name="Require shell approval", + action="tool.run", + scope="workspace", + effect="require_approval", + priority=20, + active=True, + conditions={"tool_key": "shell.exec"}, + required_consents=[], + ) + + ready_status, ready_payload = invoke_request( + "POST", + "/v0/tools/route", + payload={ + "user_id": str(seeded["user_id"]), + "thread_id": str(seeded["thread_id"]), + "tool_id": str(ready_tool["id"]), + "action": "tool.run", + "scope": "workspace", + "domain_hint": "docs", + "attributes": {"channel": "chat"}, + }, + ) + denied_status, denied_payload = invoke_request( + "POST", + "/v0/tools/route", + payload={ + "user_id": str(seeded["user_id"]), + "thread_id": str(seeded["thread_id"]), + "tool_id": str(denied_tool["id"]), + "action": "tool.run", + "scope": "workspace", + "attributes": {}, + }, + ) + approval_status, approval_payload = invoke_request( + "POST", + "/v0/tools/route", + payload={ + "user_id": str(seeded["user_id"]), + "thread_id": str(seeded["thread_id"]), + "tool_id": str(approval_tool["id"]), + "action": "tool.run", + "scope": "workspace", + "attributes": {}, + }, + ) + + assert ready_status == 200 + assert list(ready_payload) == ["request", "decision", "tool", "reasons", "summary", "trace"] + assert ready_payload["decision"] == "ready" + assert ready_payload["request"] == { + "thread_id": str(seeded["thread_id"]), + "tool_id": str(ready_tool["id"]), + "action": "tool.run", + "scope": "workspace", + "domain_hint": "docs", + "risk_hint": None, + "attributes": {"channel": "chat"}, + } + assert ready_payload["tool"]["id"] == str(ready_tool["id"]) + assert [reason["code"] for reason in ready_payload["reasons"]] == [ + "tool_metadata_matched", + "matched_policy", + "policy_effect_allow", + ] + assert ready_payload["summary"] == { + "thread_id": str(seeded["thread_id"]), + "tool_id": str(ready_tool["id"]), + "action": "tool.run", + "scope": "workspace", + "domain_hint": "docs", + "risk_hint": None, + "decision": "ready", + "evaluated_tool_count": 1, + "active_policy_count": 2, + "consent_count": 1, + "order": ["tool_key_asc", "version_asc", "created_at_asc", "id_asc"], + } + assert ready_payload["trace"]["trace_event_count"] == 3 + + assert denied_status == 200 + assert denied_payload["decision"] == "denied" + assert [reason["code"] for reason in denied_payload["reasons"]] == [ + "tool_action_unsupported", + "tool_scope_unsupported", + ] + assert denied_payload["summary"]["decision"] == "denied" + + assert approval_status == 200 + assert approval_payload["decision"] == "approval_required" + assert approval_payload["summary"]["decision"] == "approval_required" + assert approval_payload["reasons"][-1] == { + "code": "policy_effect_require_approval", + "source": "policy", + "message": "Policy effect resolved the decision to 'require_approval'.", + "tool_id": str(approval_tool["id"]), + "policy_id": str(approval_policy["id"]), + "consent_key": None, + } + + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + store = ContinuityStore(conn) + ready_trace = store.get_trace(UUID(ready_payload["trace"]["trace_id"])) + 
ready_trace_events = store.list_trace_events(UUID(ready_payload["trace"]["trace_id"])) + + assert ready_trace["kind"] == "tool.route" + assert ready_trace["compiler_version"] == "tool_routing_v0" + assert ready_trace["limits"] == { + "order": ["tool_key_asc", "version_asc", "created_at_asc", "id_asc"], + "evaluated_tool_count": 1, + "active_policy_count": 2, + "consent_count": 1, + } + assert [event["kind"] for event in ready_trace_events] == [ + "tool.route.request", + "tool.route.decision", + "tool.route.summary", + ] + assert ready_trace_events[1]["payload"] == { + "tool_id": str(ready_tool["id"]), + "tool_key": "browser.open", + "tool_version": "1.0.0", + "allowlist_decision": "allowed", + "routing_decision": "ready", + "matched_policy_id": str(ready_policy["id"]), + "reasons": ready_payload["reasons"], + } + assert ready_trace_events[2]["payload"] == { + "decision": "ready", + "evaluated_tool_count": 1, + "active_policy_count": 2, + "consent_count": 1, + } + + +def test_tool_route_validates_invalid_thread_and_tool(migrated_database_urls, monkeypatch) -> None: + seeded = seed_user(migrated_database_urls["app"], email="owner@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + tool = ContinuityStore(conn).create_tool( + tool_key="browser.open", + name="Browser Open", + description="Open documentation pages.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["browser"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=[], + risk_hints=[], + metadata={"transport": "proxy"}, + ) + + invalid_thread_status, invalid_thread_payload = invoke_request( + "POST", + "/v0/tools/route", + payload={ + "user_id": str(seeded["user_id"]), + "thread_id": str(uuid4()), + "tool_id": str(tool["id"]), + "action": "tool.run", + "scope": "workspace", + "attributes": {}, + }, + ) + invalid_tool_status, invalid_tool_payload = invoke_request( + "POST", + "/v0/tools/route", + payload={ + "user_id": str(seeded["user_id"]), + "thread_id": str(seeded["thread_id"]), + "tool_id": str(uuid4()), + "action": "tool.run", + "scope": "workspace", + "attributes": {}, + }, + ) + + assert invalid_thread_status == 400 + assert invalid_thread_payload == { + "detail": "thread_id must reference an existing thread owned by the user" + } + assert invalid_tool_status == 400 + assert invalid_tool_payload == { + "detail": "tool_id must reference an existing active tool owned by the user" + } + + +def test_tool_endpoints_and_allowlist_enforce_per_user_isolation(migrated_database_urls, monkeypatch) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner@example.com") + intruder = seed_user(migrated_database_urls["app"], email="intruder@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + owner_tool = ContinuityStore(conn).create_tool( + tool_key="browser.open", + name="Browser Open", + description="Open documentation pages.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["browser"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=["docs"], + risk_hints=[], + metadata={"transport": "proxy"}, + ) + + list_status, list_payload = invoke_request( + "GET", + "/v0/tools", + query_params={"user_id": 
str(intruder["user_id"])}, + ) + detail_status, detail_payload = invoke_request( + "GET", + f"/v0/tools/{owner_tool['id']}", + query_params={"user_id": str(intruder["user_id"])}, + ) + evaluation_status, evaluation_payload = invoke_request( + "POST", + "/v0/tools/allowlist/evaluate", + payload={ + "user_id": str(intruder["user_id"]), + "thread_id": str(intruder["thread_id"]), + "action": "tool.run", + "scope": "workspace", + "attributes": {}, + }, + ) + + assert list_status == 200 + assert list_payload == { + "items": [], + "summary": { + "total_count": 0, + "order": ["tool_key_asc", "version_asc", "created_at_asc", "id_asc"], + }, + } + assert detail_status == 404 + assert detail_payload == {"detail": f"tool {owner_tool['id']} was not found"} + assert evaluation_status == 200 + assert evaluation_payload["allowed"] == [] + assert evaluation_payload["denied"] == [] + assert evaluation_payload["approval_required"] == [] + assert evaluation_payload["summary"]["evaluated_tool_count"] == 0 + + +def test_tool_routing_returns_ready_denied_and_approval_required_with_trace( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_user(migrated_database_urls["app"], email="owner@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + store = ContinuityStore(conn) + store.create_consent( + consent_key="web_access", + status="granted", + metadata={"source": "settings"}, + ) + ready_tool = store.create_tool( + tool_key="browser.open", + name="Browser Open", + description="Open documentation pages.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["browser"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=["docs"], + risk_hints=[], + metadata={"transport": "proxy"}, + ) + denied_tool = store.create_tool( + tool_key="contacts.export", + name="Contacts Export", + description="Export contacts.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["contacts"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=["docs"], + risk_hints=[], + metadata={"transport": "proxy"}, + ) + approval_tool = store.create_tool( + tool_key="shell.exec", + name="Shell Exec", + description="Run shell commands.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["shell"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=[], + risk_hints=[], + metadata={"transport": "local"}, + ) + store.create_policy( + name="Allow docs browser", + action="tool.run", + scope="workspace", + effect="allow", + priority=10, + active=True, + conditions={"tool_key": "browser.open", "domain_hint": "docs"}, + required_consents=["web_access"], + ) + store.create_policy( + name="Allow contacts export with consent", + action="tool.run", + scope="workspace", + effect="allow", + priority=20, + active=True, + conditions={"tool_key": "contacts.export", "domain_hint": "docs"}, + required_consents=["contacts_consent"], + ) + store.create_policy( + name="Require shell approval", + action="tool.run", + scope="workspace", + effect="require_approval", + priority=30, + active=True, + conditions={"tool_key": "shell.exec"}, + required_consents=[], + ) + + ready_status, ready_payload = invoke_request( + "POST", + "/v0/tools/route", + payload={ + "user_id": str(seeded["user_id"]), + "thread_id": str(seeded["thread_id"]), + "tool_id": 
str(ready_tool["id"]), + "action": "tool.run", + "scope": "workspace", + "domain_hint": "docs", + "attributes": {"channel": "chat"}, + }, + ) + denied_status, denied_payload = invoke_request( + "POST", + "/v0/tools/route", + payload={ + "user_id": str(seeded["user_id"]), + "thread_id": str(seeded["thread_id"]), + "tool_id": str(denied_tool["id"]), + "action": "tool.run", + "scope": "workspace", + "domain_hint": "docs", + "attributes": {}, + }, + ) + approval_status, approval_payload = invoke_request( + "POST", + "/v0/tools/route", + payload={ + "user_id": str(seeded["user_id"]), + "thread_id": str(seeded["thread_id"]), + "tool_id": str(approval_tool["id"]), + "action": "tool.run", + "scope": "workspace", + "attributes": {}, + }, + ) + + assert ready_status == 200 + assert ready_payload["decision"] == "ready" + assert ready_payload["tool"]["id"] == str(ready_tool["id"]) + assert ready_payload["summary"] == { + "thread_id": str(seeded["thread_id"]), + "tool_id": str(ready_tool["id"]), + "action": "tool.run", + "scope": "workspace", + "domain_hint": "docs", + "risk_hint": None, + "decision": "ready", + "evaluated_tool_count": 1, + "active_policy_count": 3, + "consent_count": 1, + "order": ["tool_key_asc", "version_asc", "created_at_asc", "id_asc"], + } + assert ready_payload["trace"]["trace_event_count"] == 3 + + assert denied_status == 200 + assert denied_payload["decision"] == "denied" + assert [reason["code"] for reason in denied_payload["reasons"]] == [ + "tool_metadata_matched", + "matched_policy", + "consent_missing", + ] + + assert approval_status == 200 + assert approval_payload["decision"] == "approval_required" + assert approval_payload["reasons"][-1]["code"] == "policy_effect_require_approval" + + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + store = ContinuityStore(conn) + trace = store.get_trace(UUID(ready_payload["trace"]["trace_id"])) + trace_events = store.list_trace_events(UUID(ready_payload["trace"]["trace_id"])) + + assert trace["kind"] == "tool.route" + assert trace["compiler_version"] == "tool_routing_v0" + assert trace["limits"] == { + "order": ["tool_key_asc", "version_asc", "created_at_asc", "id_asc"], + "evaluated_tool_count": 1, + "active_policy_count": 3, + "consent_count": 1, + } + assert [event["kind"] for event in trace_events] == [ + "tool.route.request", + "tool.route.decision", + "tool.route.summary", + ] + assert trace_events[1]["payload"]["allowlist_decision"] == "allowed" + assert trace_events[1]["payload"]["routing_decision"] == "ready" + assert trace_events[2]["payload"] == { + "decision": "ready", + "evaluated_tool_count": 1, + "active_policy_count": 3, + "consent_count": 1, + } + + +def test_tool_routing_validates_invalid_references_and_per_user_isolation( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner@example.com") + intruder = seed_user(migrated_database_urls["app"], email="intruder@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + owner_tool = ContinuityStore(conn).create_tool( + tool_key="browser.open", + name="Browser Open", + description="Open documentation pages.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["browser"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=["docs"], + risk_hints=[], + metadata={"transport": 
"proxy"}, + ) + + invalid_thread_status, invalid_thread_payload = invoke_request( + "POST", + "/v0/tools/route", + payload={ + "user_id": str(owner["user_id"]), + "thread_id": str(uuid4()), + "tool_id": str(owner_tool["id"]), + "action": "tool.run", + "scope": "workspace", + "attributes": {}, + }, + ) + invalid_tool_status, invalid_tool_payload = invoke_request( + "POST", + "/v0/tools/route", + payload={ + "user_id": str(owner["user_id"]), + "thread_id": str(owner["thread_id"]), + "tool_id": str(uuid4()), + "action": "tool.run", + "scope": "workspace", + "attributes": {}, + }, + ) + isolation_status, isolation_payload = invoke_request( + "POST", + "/v0/tools/route", + payload={ + "user_id": str(intruder["user_id"]), + "thread_id": str(intruder["thread_id"]), + "tool_id": str(owner_tool["id"]), + "action": "tool.run", + "scope": "workspace", + "attributes": {}, + }, + ) + + assert invalid_thread_status == 400 + assert invalid_thread_payload == { + "detail": "thread_id must reference an existing thread owned by the user" + } + assert invalid_tool_status == 400 + assert invalid_tool_payload == { + "detail": "tool_id must reference an existing active tool owned by the user" + } + assert isolation_status == 400 + assert isolation_payload == { + "detail": "tool_id must reference an existing active tool owned by the user" + } + + +def test_tool_route_enforces_per_user_isolation(migrated_database_urls, monkeypatch) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner@example.com") + intruder = seed_user(migrated_database_urls["app"], email="intruder@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + owner_tool = ContinuityStore(conn).create_tool( + tool_key="browser.open", + name="Browser Open", + description="Open documentation pages.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["browser"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=["docs"], + risk_hints=[], + metadata={"transport": "proxy"}, + ) + + route_status, route_payload = invoke_request( + "POST", + "/v0/tools/route", + payload={ + "user_id": str(intruder["user_id"]), + "thread_id": str(intruder["thread_id"]), + "tool_id": str(owner_tool["id"]), + "action": "tool.run", + "scope": "workspace", + "attributes": {}, + }, + ) + + assert route_status == 400 + assert route_payload == { + "detail": "tool_id must reference an existing active tool owned by the user" + } diff --git a/tests/unit/test_20260310_0001_foundation_continuity.py b/tests/unit/test_20260310_0001_foundation_continuity.py new file mode 100644 index 0000000..9ac3fc7 --- /dev/null +++ b/tests/unit/test_20260310_0001_foundation_continuity.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +import importlib + + +MODULE_NAME = "apps.api.alembic.versions.20260310_0001_foundation_continuity" + + +def load_migration_module(): + return importlib.import_module(MODULE_NAME) + + +def test_upgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.upgrade() + + assert executed == [ + *module._UPGRADE_BOOTSTRAP_STATEMENTS, + module._UPGRADE_SCHEMA_STATEMENT, + module._UPGRADE_TRIGGER_STATEMENT, + *module._UPGRADE_GRANT_STATEMENTS, + "ALTER TABLE users ENABLE ROW LEVEL SECURITY", + "ALTER TABLE 
users FORCE ROW LEVEL SECURITY", + "ALTER TABLE threads ENABLE ROW LEVEL SECURITY", + "ALTER TABLE threads FORCE ROW LEVEL SECURITY", + "ALTER TABLE sessions ENABLE ROW LEVEL SECURITY", + "ALTER TABLE sessions FORCE ROW LEVEL SECURITY", + "ALTER TABLE events ENABLE ROW LEVEL SECURITY", + "ALTER TABLE events FORCE ROW LEVEL SECURITY", + module._UPGRADE_POLICY_STATEMENT, + ] + + +def test_downgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.downgrade() + + assert executed == list(module._DOWNGRADE_STATEMENTS) + + +def test_base_downgrade_does_not_drop_global_extensions() -> None: + module = load_migration_module() + + assert "DROP EXTENSION IF EXISTS vector" not in module._DOWNGRADE_STATEMENTS + assert "DROP EXTENSION IF EXISTS pgcrypto" not in module._DOWNGRADE_STATEMENTS + + +def test_base_schema_does_not_create_redundant_events_sequence_index() -> None: + module = load_migration_module() + + assert "CREATE INDEX events_thread_sequence_idx" not in module._UPGRADE_SCHEMA_STATEMENT diff --git a/tests/unit/test_20260311_0002_tighten_runtime_privileges.py b/tests/unit/test_20260311_0002_tighten_runtime_privileges.py new file mode 100644 index 0000000..af0925d --- /dev/null +++ b/tests/unit/test_20260311_0002_tighten_runtime_privileges.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +import importlib + + +MODULE_NAME = "apps.api.alembic.versions.20260311_0002_tighten_runtime_privileges" + + +def load_migration_module(): + return importlib.import_module(MODULE_NAME) + + +def test_upgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.upgrade() + + assert executed == list(module._UPGRADE_STATEMENTS) + + +def test_downgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.downgrade() + + assert executed == list(module._DOWNGRADE_STATEMENTS) + + +def test_downgrade_reasserts_revision_0001_privilege_floor() -> None: + module = load_migration_module() + + assert module._DOWNGRADE_STATEMENTS == ( + "REVOKE UPDATE ON users FROM alicebot_app", + "REVOKE UPDATE ON threads FROM alicebot_app", + "REVOKE UPDATE ON sessions FROM alicebot_app", + ) diff --git a/tests/unit/test_20260311_0003_trace_backbone.py b/tests/unit/test_20260311_0003_trace_backbone.py new file mode 100644 index 0000000..5780912 --- /dev/null +++ b/tests/unit/test_20260311_0003_trace_backbone.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +import importlib + + +MODULE_NAME = "apps.api.alembic.versions.20260311_0003_trace_backbone" + + +def load_migration_module(): + return importlib.import_module(MODULE_NAME) + + +def test_upgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.upgrade() + + assert executed == [ + *module._UPGRADE_BOOTSTRAP_STATEMENTS, + module._UPGRADE_SCHEMA_STATEMENT, + module._UPGRADE_TRIGGER_STATEMENT, + *module._UPGRADE_GRANT_STATEMENTS, + "ALTER TABLE traces ENABLE ROW LEVEL SECURITY", + "ALTER TABLE traces FORCE ROW LEVEL SECURITY", + "ALTER TABLE trace_events ENABLE ROW LEVEL SECURITY", + "ALTER TABLE 
trace_events FORCE ROW LEVEL SECURITY", + module._UPGRADE_POLICY_STATEMENT, + ] + + +def test_downgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.downgrade() + + assert executed == list(module._DOWNGRADE_STATEMENTS) + + +def test_trace_tables_keep_runtime_role_at_select_insert_only() -> None: + module = load_migration_module() + + assert module._UPGRADE_GRANT_STATEMENTS == ( + "GRANT SELECT, INSERT ON traces TO alicebot_app", + "GRANT SELECT, INSERT ON trace_events TO alicebot_app", + ) diff --git a/tests/unit/test_20260311_0004_memory_admission.py b/tests/unit/test_20260311_0004_memory_admission.py new file mode 100644 index 0000000..fe561e1 --- /dev/null +++ b/tests/unit/test_20260311_0004_memory_admission.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +import importlib + + +MODULE_NAME = "apps.api.alembic.versions.20260311_0004_memory_admission" + + +def load_migration_module(): + return importlib.import_module(MODULE_NAME) + + +def test_upgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.upgrade() + + assert executed == [ + *module._UPGRADE_BOOTSTRAP_STATEMENTS, + module._UPGRADE_SCHEMA_STATEMENT, + module._UPGRADE_TRIGGER_STATEMENT, + *module._UPGRADE_GRANT_STATEMENTS, + "ALTER TABLE memories ENABLE ROW LEVEL SECURITY", + "ALTER TABLE memories FORCE ROW LEVEL SECURITY", + "ALTER TABLE memory_revisions ENABLE ROW LEVEL SECURITY", + "ALTER TABLE memory_revisions FORCE ROW LEVEL SECURITY", + module._UPGRADE_POLICY_STATEMENT, + ] + + +def test_downgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.downgrade() + + assert executed == list(module._DOWNGRADE_STATEMENTS) + + +def test_memory_table_privileges_stay_narrow() -> None: + module = load_migration_module() + + assert module._UPGRADE_GRANT_STATEMENTS == ( + "GRANT SELECT, INSERT, UPDATE ON memories TO alicebot_app", + "GRANT SELECT, INSERT ON memory_revisions TO alicebot_app", + ) diff --git a/tests/unit/test_20260312_0005_memory_review_labels.py b/tests/unit/test_20260312_0005_memory_review_labels.py new file mode 100644 index 0000000..2476797 --- /dev/null +++ b/tests/unit/test_20260312_0005_memory_review_labels.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +import importlib + + +MODULE_NAME = "apps.api.alembic.versions.20260312_0005_memory_review_labels" + + +def load_migration_module(): + return importlib.import_module(MODULE_NAME) + + +def test_upgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.upgrade() + + assert executed == [ + *module._UPGRADE_BOOTSTRAP_STATEMENTS, + module._UPGRADE_SCHEMA_STATEMENT, + module._UPGRADE_TRIGGER_STATEMENT, + *module._UPGRADE_GRANT_STATEMENTS, + "ALTER TABLE memory_review_labels ENABLE ROW LEVEL SECURITY", + "ALTER TABLE memory_review_labels FORCE ROW LEVEL SECURITY", + module._UPGRADE_POLICY_STATEMENT, + ] + + +def test_downgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + 
monkeypatch.setattr(module.op, "execute", executed.append) + + module.downgrade() + + assert executed == list(module._DOWNGRADE_STATEMENTS) + + +def test_memory_review_label_table_privileges_stay_append_only() -> None: + module = load_migration_module() + + assert module._UPGRADE_GRANT_STATEMENTS == ( + "GRANT SELECT, INSERT ON memory_review_labels TO alicebot_app", + ) diff --git a/tests/unit/test_20260312_0006_entities_backbone.py b/tests/unit/test_20260312_0006_entities_backbone.py new file mode 100644 index 0000000..d099878 --- /dev/null +++ b/tests/unit/test_20260312_0006_entities_backbone.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +import importlib + + +MODULE_NAME = "apps.api.alembic.versions.20260312_0006_entities_backbone" + + +def load_migration_module(): + return importlib.import_module(MODULE_NAME) + + +def test_upgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.upgrade() + + assert executed == [ + module._UPGRADE_SCHEMA_STATEMENT, + *module._UPGRADE_GRANT_STATEMENTS, + "ALTER TABLE entities ENABLE ROW LEVEL SECURITY", + "ALTER TABLE entities FORCE ROW LEVEL SECURITY", + module._UPGRADE_POLICY_STATEMENT, + ] + + +def test_downgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.downgrade() + + assert executed == list(module._DOWNGRADE_STATEMENTS) + + +def test_entities_table_privileges_stay_insert_select_only() -> None: + module = load_migration_module() + + assert module._UPGRADE_GRANT_STATEMENTS == ( + "GRANT SELECT, INSERT ON entities TO alicebot_app", + ) diff --git a/tests/unit/test_20260312_0007_entity_edges.py b/tests/unit/test_20260312_0007_entity_edges.py new file mode 100644 index 0000000..255b9fb --- /dev/null +++ b/tests/unit/test_20260312_0007_entity_edges.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +import importlib + + +MODULE_NAME = "apps.api.alembic.versions.20260312_0007_entity_edges" + + +def load_migration_module(): + return importlib.import_module(MODULE_NAME) + + +def test_upgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.upgrade() + + assert executed == [ + module._UPGRADE_SCHEMA_STATEMENT, + *module._UPGRADE_GRANT_STATEMENTS, + "ALTER TABLE entity_edges ENABLE ROW LEVEL SECURITY", + "ALTER TABLE entity_edges FORCE ROW LEVEL SECURITY", + module._UPGRADE_POLICY_STATEMENT, + ] + + +def test_downgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.downgrade() + + assert executed == list(module._DOWNGRADE_STATEMENTS) + + +def test_entity_edges_table_privileges_stay_insert_select_only() -> None: + module = load_migration_module() + + assert module._UPGRADE_GRANT_STATEMENTS == ( + "GRANT SELECT, INSERT ON entity_edges TO alicebot_app", + ) diff --git a/tests/unit/test_20260312_0008_embedding_substrate.py b/tests/unit/test_20260312_0008_embedding_substrate.py new file mode 100644 index 0000000..240286f --- /dev/null +++ b/tests/unit/test_20260312_0008_embedding_substrate.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +import 
importlib + + +MODULE_NAME = "apps.api.alembic.versions.20260312_0008_embedding_substrate" + + +def load_migration_module(): + return importlib.import_module(MODULE_NAME) + + +def test_upgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.upgrade() + + assert executed == [ + module._UPGRADE_SCHEMA_STATEMENT, + *module._UPGRADE_GRANT_STATEMENTS, + "ALTER TABLE embedding_configs ENABLE ROW LEVEL SECURITY", + "ALTER TABLE embedding_configs FORCE ROW LEVEL SECURITY", + "ALTER TABLE memory_embeddings ENABLE ROW LEVEL SECURITY", + "ALTER TABLE memory_embeddings FORCE ROW LEVEL SECURITY", + module._UPGRADE_POLICY_STATEMENT, + ] + + +def test_downgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.downgrade() + + assert executed == list(module._DOWNGRADE_STATEMENTS) + + +def test_embedding_tables_privileges_stay_narrow() -> None: + module = load_migration_module() + + assert module._UPGRADE_GRANT_STATEMENTS == ( + "GRANT SELECT, INSERT ON embedding_configs TO alicebot_app", + "GRANT SELECT, INSERT, UPDATE ON memory_embeddings TO alicebot_app", + ) diff --git a/tests/unit/test_20260312_0009_policy_and_consent_core.py b/tests/unit/test_20260312_0009_policy_and_consent_core.py new file mode 100644 index 0000000..b926485 --- /dev/null +++ b/tests/unit/test_20260312_0009_policy_and_consent_core.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +import importlib + + +MODULE_NAME = "apps.api.alembic.versions.20260312_0009_policy_and_consent_core" + + +def load_migration_module(): + return importlib.import_module(MODULE_NAME) + + +def test_upgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.upgrade() + + assert executed == [ + module._UPGRADE_SCHEMA_STATEMENT, + *module._UPGRADE_GRANT_STATEMENTS, + "ALTER TABLE consents ENABLE ROW LEVEL SECURITY", + "ALTER TABLE consents FORCE ROW LEVEL SECURITY", + "ALTER TABLE policies ENABLE ROW LEVEL SECURITY", + "ALTER TABLE policies FORCE ROW LEVEL SECURITY", + module._UPGRADE_POLICY_STATEMENT, + ] + + +def test_downgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.downgrade() + + assert executed == list(module._DOWNGRADE_STATEMENTS) + + +def test_policy_and_consent_privileges_stay_narrow() -> None: + module = load_migration_module() + + assert module._UPGRADE_GRANT_STATEMENTS == ( + "GRANT SELECT, INSERT, UPDATE ON consents TO alicebot_app", + "GRANT SELECT, INSERT ON policies TO alicebot_app", + ) diff --git a/tests/unit/test_20260312_0010_tools_registry_and_allowlist.py b/tests/unit/test_20260312_0010_tools_registry_and_allowlist.py new file mode 100644 index 0000000..b7c4215 --- /dev/null +++ b/tests/unit/test_20260312_0010_tools_registry_and_allowlist.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +import importlib + + +MODULE_NAME = "apps.api.alembic.versions.20260312_0010_tools_registry_and_allowlist" + + +def load_migration_module(): + return importlib.import_module(MODULE_NAME) + + +def 
test_upgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.upgrade() + + assert executed == [ + module._UPGRADE_SCHEMA_STATEMENT, + *module._UPGRADE_GRANT_STATEMENTS, + "ALTER TABLE tools ENABLE ROW LEVEL SECURITY", + "ALTER TABLE tools FORCE ROW LEVEL SECURITY", + module._UPGRADE_POLICY_STATEMENT, + ] + + +def test_downgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.downgrade() + + assert executed == list(module._DOWNGRADE_STATEMENTS) + + +def test_tools_privileges_stay_narrow() -> None: + module = load_migration_module() + + assert module._UPGRADE_GRANT_STATEMENTS == ( + "GRANT SELECT, INSERT ON tools TO alicebot_app", + ) diff --git a/tests/unit/test_20260312_0011_approval_request_records.py b/tests/unit/test_20260312_0011_approval_request_records.py new file mode 100644 index 0000000..00c051b --- /dev/null +++ b/tests/unit/test_20260312_0011_approval_request_records.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +import importlib + + +MODULE_NAME = "apps.api.alembic.versions.20260312_0011_approval_request_records" + + +def load_migration_module(): + return importlib.import_module(MODULE_NAME) + + +def test_upgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.upgrade() + + assert executed == [ + module._UPGRADE_SCHEMA_STATEMENT, + *module._UPGRADE_GRANT_STATEMENTS, + "ALTER TABLE approvals ENABLE ROW LEVEL SECURITY", + "ALTER TABLE approvals FORCE ROW LEVEL SECURITY", + module._UPGRADE_POLICY_STATEMENT, + ] + + +def test_downgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.downgrade() + + assert executed == list(module._DOWNGRADE_STATEMENTS) + + +def test_approvals_privileges_stay_narrow() -> None: + module = load_migration_module() + + assert module._UPGRADE_GRANT_STATEMENTS == ( + "GRANT SELECT, INSERT ON approvals TO alicebot_app", + ) diff --git a/tests/unit/test_20260312_0012_approval_resolution.py b/tests/unit/test_20260312_0012_approval_resolution.py new file mode 100644 index 0000000..7e37cd5 --- /dev/null +++ b/tests/unit/test_20260312_0012_approval_resolution.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +import importlib + + +MODULE_NAME = "apps.api.alembic.versions.20260312_0012_approval_resolution" + + +def load_migration_module(): + return importlib.import_module(MODULE_NAME) + + +def test_upgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.upgrade() + + assert executed == [ + module._UPGRADE_SCHEMA_STATEMENT, + *module._UPGRADE_GRANT_STATEMENTS, + ] + + +def test_downgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.downgrade() + + assert executed == list(module._DOWNGRADE_STATEMENTS) + + +def test_approvals_resolution_privileges_stay_narrow() -> None: + 
module = load_migration_module() + + assert module._UPGRADE_GRANT_STATEMENTS == ( + "GRANT UPDATE ON approvals TO alicebot_app", + ) diff --git a/tests/unit/test_20260313_0013_tool_executions.py b/tests/unit/test_20260313_0013_tool_executions.py new file mode 100644 index 0000000..84e4f67 --- /dev/null +++ b/tests/unit/test_20260313_0013_tool_executions.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +import importlib + + +MODULE_NAME = "apps.api.alembic.versions.20260313_0013_tool_executions" + + +def load_migration_module(): + return importlib.import_module(MODULE_NAME) + + +def test_upgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.upgrade() + + assert executed == [ + module._UPGRADE_SCHEMA_STATEMENT, + *module._UPGRADE_GRANT_STATEMENTS, + "ALTER TABLE tool_executions ENABLE ROW LEVEL SECURITY", + "ALTER TABLE tool_executions FORCE ROW LEVEL SECURITY", + module._UPGRADE_POLICY_STATEMENT, + ] + + +def test_downgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.downgrade() + + assert executed == list(module._DOWNGRADE_STATEMENTS) + + +def test_tool_executions_privileges_stay_narrow() -> None: + module = load_migration_module() + + assert module._UPGRADE_GRANT_STATEMENTS == ( + "GRANT SELECT, INSERT ON tool_executions TO alicebot_app", + ) diff --git a/tests/unit/test_20260313_0014_execution_budgets.py b/tests/unit/test_20260313_0014_execution_budgets.py new file mode 100644 index 0000000..a1cadf3 --- /dev/null +++ b/tests/unit/test_20260313_0014_execution_budgets.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +import importlib + + +MODULE_NAME = "apps.api.alembic.versions.20260313_0014_execution_budgets" + + +def load_migration_module(): + return importlib.import_module(MODULE_NAME) + + +def test_upgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.upgrade() + + assert executed == [ + module._UPGRADE_SCHEMA_STATEMENT, + *module._UPGRADE_GRANT_STATEMENTS, + "ALTER TABLE execution_budgets ENABLE ROW LEVEL SECURITY", + "ALTER TABLE execution_budgets FORCE ROW LEVEL SECURITY", + module._UPGRADE_POLICY_STATEMENT, + ] + + +def test_downgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.downgrade() + + assert executed == list(module._DOWNGRADE_STATEMENTS) + + +def test_execution_budgets_privileges_stay_narrow() -> None: + module = load_migration_module() + + assert module._UPGRADE_GRANT_STATEMENTS == ( + "GRANT SELECT, INSERT ON execution_budgets TO alicebot_app", + ) diff --git a/tests/unit/test_20260313_0015_execution_budget_lifecycle.py b/tests/unit/test_20260313_0015_execution_budget_lifecycle.py new file mode 100644 index 0000000..f1a7468 --- /dev/null +++ b/tests/unit/test_20260313_0015_execution_budget_lifecycle.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +import importlib + + +MODULE_NAME = "apps.api.alembic.versions.20260313_0015_execution_budget_lifecycle" + + +def load_migration_module(): + return importlib.import_module(MODULE_NAME) + + +def 
test_upgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.upgrade() + + assert executed == list(module._UPGRADE_STATEMENTS) + + +def test_downgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.downgrade() + + assert executed == list(module._DOWNGRADE_STATEMENTS) + + +def test_execution_budget_lifecycle_privileges_stay_narrow() -> None: + module = load_migration_module() + + assert module._UPGRADE_STATEMENTS[-1] == "GRANT SELECT, INSERT, UPDATE ON execution_budgets TO alicebot_app" diff --git a/tests/unit/test_20260313_0016_execution_budget_rolling_window.py b/tests/unit/test_20260313_0016_execution_budget_rolling_window.py new file mode 100644 index 0000000..631b0bb --- /dev/null +++ b/tests/unit/test_20260313_0016_execution_budget_rolling_window.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +import importlib + + +MODULE_NAME = "apps.api.alembic.versions.20260313_0016_execution_budget_rolling_window" + + +def load_migration_module(): + return importlib.import_module(MODULE_NAME) + + +def test_upgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.upgrade() + + assert executed == list(module._UPGRADE_STATEMENTS) + + +def test_downgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.downgrade() + + assert executed == list(module._DOWNGRADE_STATEMENTS) diff --git a/tests/unit/test_20260313_0018_task_steps.py b/tests/unit/test_20260313_0018_task_steps.py new file mode 100644 index 0000000..c3ab793 --- /dev/null +++ b/tests/unit/test_20260313_0018_task_steps.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +import importlib + + +MODULE_NAME = "apps.api.alembic.versions.20260313_0018_task_steps" + + +def load_migration_module(): + return importlib.import_module(MODULE_NAME) + + +def test_upgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.upgrade() + + assert executed == [ + module._UPGRADE_SCHEMA_STATEMENT, + *module._UPGRADE_GRANT_STATEMENTS, + "ALTER TABLE task_steps ENABLE ROW LEVEL SECURITY", + "ALTER TABLE task_steps FORCE ROW LEVEL SECURITY", + module._UPGRADE_POLICY_STATEMENT, + ] + + +def test_downgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.downgrade() + + assert executed == list(module._DOWNGRADE_STATEMENTS) + + +def test_task_step_privileges_allow_only_expected_runtime_writes() -> None: + module = load_migration_module() + + assert module._UPGRADE_GRANT_STATEMENTS == ( + "GRANT SELECT, INSERT, UPDATE ON task_steps TO alicebot_app", + ) diff --git a/tests/unit/test_20260313_0019_task_step_lineage.py b/tests/unit/test_20260313_0019_task_step_lineage.py new file mode 100644 index 0000000..68fdcca --- /dev/null +++ b/tests/unit/test_20260313_0019_task_step_lineage.py @@ -0,0 
+1,32 @@ +from __future__ import annotations + +import importlib + + +MODULE_NAME = "apps.api.alembic.versions.20260313_0019_task_step_lineage" + + +def load_migration_module(): + return importlib.import_module(MODULE_NAME) + + +def test_upgrade_executes_expected_statement(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.upgrade() + + assert executed == [module._UPGRADE_SCHEMA_STATEMENT] + + +def test_downgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.downgrade() + + assert executed == list(module._DOWNGRADE_STATEMENTS) diff --git a/tests/unit/test_20260313_0020_approval_task_step_linkage.py b/tests/unit/test_20260313_0020_approval_task_step_linkage.py new file mode 100644 index 0000000..5f7816a --- /dev/null +++ b/tests/unit/test_20260313_0020_approval_task_step_linkage.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +import importlib + + +MODULE_NAME = "apps.api.alembic.versions.20260313_0020_approval_task_step_linkage" + + +def load_migration_module(): + return importlib.import_module(MODULE_NAME) + + +def test_upgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.upgrade() + + assert executed == [module._UPGRADE_SCHEMA_STATEMENT] + + +def test_downgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.downgrade() + + assert executed == list(module._DOWNGRADE_STATEMENTS) diff --git a/tests/unit/test_20260313_0021_tool_execution_task_step_linkage.py b/tests/unit/test_20260313_0021_tool_execution_task_step_linkage.py new file mode 100644 index 0000000..31f5330 --- /dev/null +++ b/tests/unit/test_20260313_0021_tool_execution_task_step_linkage.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +import importlib + + +MODULE_NAME = "apps.api.alembic.versions.20260313_0021_tool_execution_task_step_linkage" + + +def load_migration_module(): + return importlib.import_module(MODULE_NAME) + + +def test_upgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.upgrade() + + assert executed == list(module._UPGRADE_STATEMENTS) + + +def test_downgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.downgrade() + + assert executed == list(module._DOWNGRADE_STATEMENTS) diff --git a/tests/unit/test_20260313_0022_task_workspaces.py b/tests/unit/test_20260313_0022_task_workspaces.py new file mode 100644 index 0000000..6e352b9 --- /dev/null +++ b/tests/unit/test_20260313_0022_task_workspaces.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +import importlib + + +MODULE_NAME = "apps.api.alembic.versions.20260313_0022_task_workspaces" + + +def load_migration_module(): + return importlib.import_module(MODULE_NAME) + + +def test_upgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: 
list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.upgrade() + + assert executed == [ + module._UPGRADE_SCHEMA_STATEMENT, + *module._UPGRADE_GRANT_STATEMENTS, + "ALTER TABLE task_workspaces ENABLE ROW LEVEL SECURITY", + "ALTER TABLE task_workspaces FORCE ROW LEVEL SECURITY", + module._UPGRADE_POLICY_STATEMENT, + ] + + +def test_downgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.downgrade() + + assert executed == list(module._DOWNGRADE_STATEMENTS) + + +def test_task_workspace_privileges_allow_only_expected_runtime_writes() -> None: + module = load_migration_module() + + assert module._UPGRADE_GRANT_STATEMENTS == ( + "GRANT SELECT, INSERT ON task_workspaces TO alicebot_app", + ) diff --git a/tests/unit/test_approval_store.py b/tests/unit/test_approval_store.py new file mode 100644 index 0000000..7a944e7 --- /dev/null +++ b/tests/unit/test_approval_store.py @@ -0,0 +1,152 @@ +from __future__ import annotations + +from typing import Any +from uuid import uuid4 + +from psycopg.types.json import Jsonb + +from alicebot_api.store import ContinuityStore + + +class RecordingCursor: + def __init__(self, fetchone_results: list[dict[str, Any]], fetchall_result: list[dict[str, Any]] | None = None) -> None: + self.executed: list[tuple[str, tuple[object, ...] | None]] = [] + self.fetchone_results = list(fetchone_results) + self.fetchall_result = fetchall_result or [] + + def __enter__(self) -> "RecordingCursor": + return self + + def __exit__(self, exc_type, exc, tb) -> None: + return None + + def execute(self, query: str, params: tuple[object, ...] | None = None) -> None: + self.executed.append((query, params)) + + def fetchone(self) -> dict[str, Any] | None: + if not self.fetchone_results: + return None + return self.fetchone_results.pop(0) + + def fetchall(self) -> list[dict[str, Any]]: + return self.fetchall_result + + +class RecordingConnection: + def __init__(self, cursor: RecordingCursor) -> None: + self.cursor_instance = cursor + + def cursor(self) -> RecordingCursor: + return self.cursor_instance + + +def test_approval_store_methods_use_expected_queries_and_jsonb_parameters() -> None: + approval_id = uuid4() + thread_id = uuid4() + tool_id = uuid4() + task_step_id = uuid4() + routing_trace_id = uuid4() + resolved_by_user_id = uuid4() + cursor = RecordingCursor( + fetchone_results=[ + { + "id": approval_id, + "thread_id": thread_id, + "tool_id": tool_id, + "task_step_id": task_step_id, + "status": "pending", + "request": {"thread_id": str(thread_id), "tool_id": str(tool_id)}, + "tool": {"id": str(tool_id), "tool_key": "shell.exec"}, + "routing": {"decision": "approval_required", "trace": {"trace_id": str(routing_trace_id)}}, + "routing_trace_id": routing_trace_id, + "resolved_at": None, + "resolved_by_user_id": None, + }, + { + "id": approval_id, + "thread_id": thread_id, + "tool_id": tool_id, + "task_step_id": task_step_id, + "status": "pending", + "request": {"thread_id": str(thread_id), "tool_id": str(tool_id)}, + "tool": {"id": str(tool_id), "tool_key": "shell.exec"}, + "routing": {"decision": "approval_required", "trace": {"trace_id": str(routing_trace_id)}}, + "routing_trace_id": routing_trace_id, + "resolved_at": None, + "resolved_by_user_id": None, + }, + { + "id": approval_id, + "thread_id": thread_id, + "tool_id": tool_id, + "task_step_id": task_step_id, + "status": "approved", + "request": 
{"thread_id": str(thread_id), "tool_id": str(tool_id)}, + "tool": {"id": str(tool_id), "tool_key": "shell.exec"}, + "routing": {"decision": "approval_required", "trace": {"trace_id": str(routing_trace_id)}}, + "routing_trace_id": routing_trace_id, + "resolved_at": "2026-03-12T10:00:00+00:00", + "resolved_by_user_id": resolved_by_user_id, + }, + ], + fetchall_result=[ + { + "id": approval_id, + "thread_id": thread_id, + "tool_id": tool_id, + "task_step_id": task_step_id, + "status": "pending", + "request": {"thread_id": str(thread_id), "tool_id": str(tool_id)}, + "tool": {"id": str(tool_id), "tool_key": "shell.exec"}, + "routing": {"decision": "approval_required", "trace": {"trace_id": str(routing_trace_id)}}, + "routing_trace_id": routing_trace_id, + "resolved_at": None, + "resolved_by_user_id": None, + } + ], + ) + store = ContinuityStore(RecordingConnection(cursor)) + + created = store.create_approval( + thread_id=thread_id, + tool_id=tool_id, + task_step_id=task_step_id, + status="pending", + request={"thread_id": str(thread_id), "tool_id": str(tool_id)}, + tool={"id": str(tool_id), "tool_key": "shell.exec"}, + routing={"decision": "approval_required", "trace": {"trace_id": str(routing_trace_id)}}, + routing_trace_id=routing_trace_id, + ) + fetched = store.get_approval_optional(approval_id) + listed = store.list_approvals() + resolved = store.resolve_approval_optional(approval_id=approval_id, status="approved") + + assert created["id"] == approval_id + assert created["resolved_at"] is None + assert fetched is not None + assert listed[0]["id"] == approval_id + assert resolved is not None + assert resolved["status"] == "approved" + + create_query, create_params = cursor.executed[0] + assert "INSERT INTO approvals" in create_query + assert create_params is not None + assert create_params[:4] == (thread_id, tool_id, task_step_id, "pending") + assert isinstance(create_params[4], Jsonb) + assert create_params[4].obj == {"thread_id": str(thread_id), "tool_id": str(tool_id)} + assert isinstance(create_params[5], Jsonb) + assert create_params[5].obj == {"id": str(tool_id), "tool_key": "shell.exec"} + assert isinstance(create_params[6], Jsonb) + assert create_params[6].obj == { + "decision": "approval_required", + "trace": {"trace_id": str(routing_trace_id)}, + } + assert create_params[7] == routing_trace_id + assert "resolved_at" in cursor.executed[1][0] + assert "ORDER BY created_at ASC, id ASC" in cursor.executed[2][0] + + resolve_query, resolve_params = cursor.executed[3] + assert "UPDATE approvals" in resolve_query + assert "WHERE id = %s" in resolve_query + assert "AND status = 'pending'" in resolve_query + assert resolve_params == ("approved", approval_id) diff --git a/tests/unit/test_approvals.py b/tests/unit/test_approvals.py new file mode 100644 index 0000000..2ac7b2f --- /dev/null +++ b/tests/unit/test_approvals.py @@ -0,0 +1,1200 @@ +from __future__ import annotations + +from datetime import UTC, datetime, timedelta +from uuid import UUID, uuid4 + +from alicebot_api.approvals import ( + ApprovalNotFoundError, + ApprovalResolutionConflictError, + approve_approval_record, + get_approval_record, + list_approval_records, + reject_approval_record, + submit_approval_request, +) +from alicebot_api.contracts import ApprovalApproveInput, ApprovalRejectInput, ApprovalRequestCreateInput +from alicebot_api.tasks import TaskStepApprovalLinkageError + + +class ApprovalStoreStub: + def __init__(self) -> None: + self.base_time = datetime(2026, 3, 12, 9, 0, tzinfo=UTC) + self.user_id = uuid4() + 
self.thread_id = uuid4() + self.locked_task_ids: list[UUID] = [] + self.consents: dict[str, dict[str, object]] = {} + self.policies: list[dict[str, object]] = [] + self.tools: list[dict[str, object]] = [] + self.approvals: list[dict[str, object]] = [] + self.tasks: list[dict[str, object]] = [] + self.task_steps: list[dict[str, object]] = [] + self.traces: list[dict[str, object]] = [] + self.trace_events: list[dict[str, object]] = [] + + def create_consent(self, *, consent_key: str, status: str, metadata: dict[str, object]) -> dict[str, object]: + consent = { + "id": uuid4(), + "user_id": self.user_id, + "consent_key": consent_key, + "status": status, + "metadata": metadata, + "created_at": self.base_time + timedelta(minutes=len(self.consents)), + "updated_at": self.base_time + timedelta(minutes=len(self.consents)), + } + self.consents[consent_key] = consent + return consent + + def list_consents(self) -> list[dict[str, object]]: + return sorted( + self.consents.values(), + key=lambda consent: (consent["consent_key"], consent["created_at"], consent["id"]), + ) + + def create_policy( + self, + *, + name: str, + action: str, + scope: str, + effect: str, + priority: int, + active: bool, + conditions: dict[str, object], + required_consents: list[str], + ) -> dict[str, object]: + policy = { + "id": uuid4(), + "user_id": self.user_id, + "name": name, + "action": action, + "scope": scope, + "effect": effect, + "priority": priority, + "active": active, + "conditions": conditions, + "required_consents": required_consents, + "created_at": self.base_time + timedelta(minutes=len(self.policies)), + "updated_at": self.base_time + timedelta(minutes=len(self.policies)), + } + self.policies.append(policy) + return policy + + def list_active_policies(self) -> list[dict[str, object]]: + return sorted( + [policy for policy in self.policies if policy["active"] is True], + key=lambda policy: (policy["priority"], policy["created_at"], policy["id"]), + ) + + def create_tool( + self, + *, + tool_key: str, + name: str, + description: str, + version: str, + metadata_version: str, + active: bool, + tags: list[str], + action_hints: list[str], + scope_hints: list[str], + domain_hints: list[str], + risk_hints: list[str], + metadata: dict[str, object], + ) -> dict[str, object]: + tool = { + "id": uuid4(), + "user_id": self.user_id, + "tool_key": tool_key, + "name": name, + "description": description, + "version": version, + "metadata_version": metadata_version, + "active": active, + "tags": tags, + "action_hints": action_hints, + "scope_hints": scope_hints, + "domain_hints": domain_hints, + "risk_hints": risk_hints, + "metadata": metadata, + "created_at": self.base_time + timedelta(minutes=len(self.tools)), + } + self.tools.append(tool) + return tool + + def get_tool_optional(self, tool_id: UUID) -> dict[str, object] | None: + return next((tool for tool in self.tools if tool["id"] == tool_id), None) + + def list_active_tools(self) -> list[dict[str, object]]: + return [tool for tool in self.tools if tool["active"] is True] + + def get_thread_optional(self, thread_id: UUID) -> dict[str, object] | None: + if thread_id != self.thread_id: + return None + return { + "id": self.thread_id, + "user_id": self.user_id, + "title": "Approval thread", + "created_at": self.base_time, + "updated_at": self.base_time, + } + + def create_trace( + self, + *, + user_id: UUID, + thread_id: UUID, + kind: str, + compiler_version: str, + status: str, + limits: dict[str, object], + ) -> dict[str, object]: + trace = { + "id": uuid4(), + "user_id": 
user_id,
+            "thread_id": thread_id,
+            "kind": kind,
+            "compiler_version": compiler_version,
+            "status": status,
+            "limits": limits,
+            "created_at": self.base_time + timedelta(minutes=len(self.traces)),
+        }
+        self.traces.append(trace)
+        return trace
+
+    def append_trace_event(
+        self,
+        *,
+        trace_id: UUID,
+        sequence_no: int,
+        kind: str,
+        payload: dict[str, object],
+    ) -> dict[str, object]:
+        event = {
+            "id": uuid4(),
+            "trace_id": trace_id,
+            "sequence_no": sequence_no,
+            "kind": kind,
+            "payload": payload,
+            "created_at": self.base_time + timedelta(minutes=len(self.trace_events)),
+        }
+        self.trace_events.append(event)
+        return event
+
+    def create_approval(
+        self,
+        *,
+        thread_id: UUID,
+        tool_id: UUID,
+        task_step_id: UUID | None,
+        status: str,
+        request: dict[str, object],
+        tool: dict[str, object],
+        routing: dict[str, object],
+        routing_trace_id: UUID,
+    ) -> dict[str, object]:
+        approval = {
+            "id": uuid4(),
+            "user_id": self.user_id,
+            "thread_id": thread_id,
+            "tool_id": tool_id,
+            "task_step_id": task_step_id,
+            "status": status,
+            "request": request,
+            "tool": tool,
+            "routing": routing,
+            "routing_trace_id": routing_trace_id,
+            "created_at": self.base_time + timedelta(minutes=len(self.approvals)),
+            "resolved_at": None,
+            "resolved_by_user_id": None,
+        }
+        self.approvals.append(approval)
+        return approval
+
+    def get_approval_optional(self, approval_id: UUID) -> dict[str, object] | None:
+        return next((approval for approval in self.approvals if approval["id"] == approval_id), None)
+
+    def list_approvals(self) -> list[dict[str, object]]:
+        return sorted(
+            self.approvals,
+            key=lambda approval: (approval["created_at"], approval["id"]),
+        )
+
+    def get_task_optional(self, task_id: UUID) -> dict[str, object] | None:
+        return next((task for task in self.tasks if task["id"] == task_id), None)
+
+    def create_task(
+        self,
+        *,
+        thread_id: UUID,
+        tool_id: UUID,
+        status: str,
+        request: dict[str, object],
+        tool: dict[str, object],
+        latest_approval_id: UUID | None,
+        latest_execution_id: UUID | None,
+    ) -> dict[str, object]:
+        task = {
+            "id": uuid4(),
+            "user_id": self.user_id,
+            "thread_id": thread_id,
+            "tool_id": tool_id,
+            "status": status,
+            "request": request,
+            "tool": tool,
+            "latest_approval_id": latest_approval_id,
+            "latest_execution_id": latest_execution_id,
+            "created_at": self.base_time + timedelta(minutes=len(self.tasks)),
+            "updated_at": self.base_time + timedelta(minutes=len(self.tasks)),
+        }
+        self.tasks.append(task)
+        return task
+
+    def get_task_by_approval_optional(self, approval_id: UUID) -> dict[str, object] | None:
+        return next((task for task in self.tasks if task["latest_approval_id"] == approval_id), None)
+
+    def lock_task_steps(self, task_id: UUID) -> None:
+        self.locked_task_ids.append(task_id)
+
+    def list_tasks(self) -> list[dict[str, object]]:
+        return sorted(
+            self.tasks,
+            key=lambda task: (task["created_at"], task["id"]),
+        )
+
+    def create_task_step(
+        self,
+        *,
+        task_id: UUID,
+        sequence_no: int,
+        parent_step_id: UUID | None = None,
+        source_approval_id: UUID | None = None,
+        source_execution_id: UUID | None = None,
+        kind: str,
+        status: str,
+        request: dict[str, object],
+        outcome: dict[str, object],
+        trace_id: UUID,
+        trace_kind: str,
+    ) -> dict[str, object]:
+        task_step = {
+            "id": uuid4(),
+            "user_id": self.user_id,
+            "task_id": task_id,
+            "sequence_no": sequence_no,
+
"parent_step_id": parent_step_id, + "source_approval_id": source_approval_id, + "source_execution_id": source_execution_id, + "kind": kind, + "status": status, + "request": request, + "outcome": outcome, + "trace_id": trace_id, + "trace_kind": trace_kind, + "created_at": self.base_time + timedelta(minutes=len(self.task_steps)), + "updated_at": self.base_time + timedelta(minutes=len(self.task_steps)), + } + self.task_steps.append(task_step) + return task_step + + def get_task_step_for_task_sequence_optional( + self, + *, + task_id: UUID, + sequence_no: int, + ) -> dict[str, object] | None: + return next( + ( + task_step + for task_step in self.task_steps + if task_step["task_id"] == task_id and task_step["sequence_no"] == sequence_no + ), + None, + ) + + def get_task_step_optional(self, task_step_id: UUID) -> dict[str, object] | None: + return next((task_step for task_step in self.task_steps if task_step["id"] == task_step_id), None) + + def list_task_steps_for_task(self, task_id: UUID) -> list[dict[str, object]]: + return sorted( + [task_step for task_step in self.task_steps if task_step["task_id"] == task_id], + key=lambda task_step: (task_step["sequence_no"], task_step["created_at"], task_step["id"]), + ) + + def update_task_step_for_task_sequence_optional( + self, + *, + task_id: UUID, + sequence_no: int, + status: str, + outcome: dict[str, object], + trace_id: UUID, + trace_kind: str, + ) -> dict[str, object] | None: + task_step = self.get_task_step_for_task_sequence_optional(task_id=task_id, sequence_no=sequence_no) + if task_step is None: + return None + + task_step["status"] = status + task_step["outcome"] = outcome + task_step["trace_id"] = trace_id + task_step["trace_kind"] = trace_kind + task_step["updated_at"] = self.base_time + timedelta(hours=1, minutes=len(self.trace_events)) + return task_step + + def update_task_step_optional( + self, + *, + task_step_id: UUID, + status: str, + outcome: dict[str, object], + trace_id: UUID, + trace_kind: str, + ) -> dict[str, object] | None: + task_step = self.get_task_step_optional(task_step_id) + if task_step is None: + return None + + task_step["status"] = status + task_step["outcome"] = outcome + task_step["trace_id"] = trace_id + task_step["trace_kind"] = trace_kind + task_step["updated_at"] = self.base_time + timedelta(hours=1, minutes=len(self.trace_events)) + return task_step + + def update_task_status_by_approval_optional( + self, + *, + approval_id: UUID, + status: str, + ) -> dict[str, object] | None: + task = self.get_task_by_approval_optional(approval_id) + if task is None: + return None + task["status"] = status + task["updated_at"] = self.base_time + timedelta(hours=1, minutes=len(self.trace_events)) + return task + + def update_task_status_optional( + self, + *, + task_id: UUID, + status: str, + latest_approval_id: UUID | None, + latest_execution_id: UUID | None, + ) -> dict[str, object] | None: + task = self.get_task_optional(task_id) + if task is None: + return None + task["status"] = status + task["latest_approval_id"] = latest_approval_id + task["latest_execution_id"] = latest_execution_id + task["updated_at"] = self.base_time + timedelta(hours=1, minutes=len(self.trace_events)) + return task + + def resolve_approval_optional(self, *, approval_id: UUID, status: str) -> dict[str, object] | None: + approval = self.get_approval_optional(approval_id) + if approval is None or approval["status"] != "pending": + return None + + approval["status"] = status + approval["resolved_at"] = self.base_time + timedelta(hours=1, 
minutes=len(self.trace_events)) + approval["resolved_by_user_id"] = self.user_id + return approval + + def update_approval_task_step_optional( + self, + *, + approval_id: UUID, + task_step_id: UUID, + ) -> dict[str, object] | None: + approval = self.get_approval_optional(approval_id) + if approval is None: + return None + approval["task_step_id"] = task_step_id + return approval + + +def test_submit_approval_request_persists_record_for_approval_required_route() -> None: + store = ApprovalStoreStub() + tool = store.create_tool( + tool_key="shell.exec", + name="Shell Exec", + description="Run shell commands.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["shell"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=[], + risk_hints=[], + metadata={"transport": "local"}, + ) + policy = store.create_policy( + name="Require shell approval", + action="tool.run", + scope="workspace", + effect="require_approval", + priority=10, + active=True, + conditions={"tool_key": "shell.exec"}, + required_consents=[], + ) + + payload = submit_approval_request( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ApprovalRequestCreateInput( + thread_id=store.thread_id, + tool_id=tool["id"], + action="tool.run", + scope="workspace", + attributes={"command": "ls"}, + ), + ) + + assert payload["decision"] == "approval_required" + assert payload["task"]["status"] == "pending_approval" + assert payload["task"]["latest_approval_id"] == payload["approval"]["id"] + assert payload["task"]["latest_execution_id"] is None + assert payload["approval"] is not None + assert payload["approval"]["status"] == "pending" + assert payload["approval"]["resolution"] is None + assert payload["approval"]["thread_id"] == str(store.thread_id) + assert payload["approval"]["task_step_id"] == str(store.task_steps[0]["id"]) + assert payload["approval"]["request"] == payload["request"] + assert payload["approval"]["tool"] == payload["tool"] + assert payload["approval"]["routing"] == { + "decision": "approval_required", + "reasons": payload["reasons"], + "trace": payload["routing_trace"], + } + assert payload["routing_trace"]["trace_event_count"] == 3 + assert payload["trace"]["trace_event_count"] == 8 + assert len(store.approvals) == 1 + assert len(store.tasks) == 1 + assert len(store.task_steps) == 1 + assert store.traces[0]["kind"] == "tool.route" + assert store.traces[1]["kind"] == "approval.request" + assert store.traces[1]["compiler_version"] == "approval_request_v0" + assert store.traces[1]["limits"] == { + "order": ["created_at_asc", "id_asc"], + "persisted": True, + } + assert [event["kind"] for event in store.trace_events[-8:]] == [ + "approval.request.request", + "approval.request.routing", + "approval.request.persisted", + "approval.request.summary", + "task.lifecycle.state", + "task.lifecycle.summary", + "task.step.lifecycle.state", + "task.step.lifecycle.summary", + ] + assert store.trace_events[-7]["payload"]["routing_trace_id"] == payload["routing_trace"]["trace_id"] + assert store.trace_events[-6]["payload"] == { + "approval_id": payload["approval"]["id"], + "task_step_id": payload["approval"]["task_step_id"], + "decision": "approval_required", + "persisted": True, + } + assert store.trace_events[-4]["payload"] == { + "task_id": payload["task"]["id"], + "source": "approval_request", + "previous_status": None, + "current_status": "pending_approval", + "latest_approval_id": payload["approval"]["id"], + "latest_execution_id": None, + } + assert 
store.trace_events[-2]["payload"] == { + "task_id": payload["task"]["id"], + "task_step_id": str(store.task_steps[0]["id"]), + "source": "approval_request", + "sequence_no": 1, + "kind": "governed_request", + "previous_status": None, + "current_status": "created", + "trace": { + "trace_id": payload["trace"]["trace_id"], + "trace_kind": "approval.request", + }, + } + assert payload["reasons"][-1] == { + "code": "policy_effect_require_approval", + "source": "policy", + "message": "Policy effect resolved the decision to 'require_approval'.", + "tool_id": str(tool["id"]), + "policy_id": str(policy["id"]), + "consent_key": None, + } + + +def test_submit_approval_request_does_not_persist_for_ready_or_denied_routes() -> None: + ready_store = ApprovalStoreStub() + ready_store.create_consent( + consent_key="web_access", + status="granted", + metadata={"source": "settings"}, + ) + ready_tool = ready_store.create_tool( + tool_key="browser.open", + name="Browser Open", + description="Open documentation pages.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["browser"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=["docs"], + risk_hints=[], + metadata={"transport": "proxy"}, + ) + ready_store.create_policy( + name="Allow docs browser", + action="tool.run", + scope="workspace", + effect="allow", + priority=10, + active=True, + conditions={"tool_key": "browser.open", "domain_hint": "docs"}, + required_consents=["web_access"], + ) + + ready_payload = submit_approval_request( + ready_store, # type: ignore[arg-type] + user_id=ready_store.user_id, + request=ApprovalRequestCreateInput( + thread_id=ready_store.thread_id, + tool_id=ready_tool["id"], + action="tool.run", + scope="workspace", + domain_hint="docs", + attributes={}, + ), + ) + + denied_store = ApprovalStoreStub() + denied_tool = denied_store.create_tool( + tool_key="calendar.read", + name="Calendar Read", + description="Read calendars.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["calendar"], + action_hints=["calendar.read"], + scope_hints=["calendar"], + domain_hints=[], + risk_hints=[], + metadata={}, + ) + denied_payload = submit_approval_request( + denied_store, # type: ignore[arg-type] + user_id=denied_store.user_id, + request=ApprovalRequestCreateInput( + thread_id=denied_store.thread_id, + tool_id=denied_tool["id"], + action="tool.run", + scope="workspace", + attributes={}, + ), + ) + + assert ready_payload["decision"] == "ready" + assert ready_payload["task"]["status"] == "approved" + assert ready_payload["task"]["latest_approval_id"] is None + assert ready_payload["approval"] is None + assert ready_store.approvals == [] + assert len(ready_store.task_steps) == 1 + assert [event["kind"] for event in ready_store.trace_events[-8:]] == [ + "approval.request.request", + "approval.request.routing", + "approval.request.skipped", + "approval.request.summary", + "task.lifecycle.state", + "task.lifecycle.summary", + "task.step.lifecycle.state", + "task.step.lifecycle.summary", + ] + + assert denied_payload["decision"] == "denied" + assert denied_payload["task"]["status"] == "denied" + assert denied_payload["task"]["latest_approval_id"] is None + assert denied_payload["approval"] is None + assert denied_store.approvals == [] + assert [reason["code"] for reason in denied_payload["reasons"]] == [ + "tool_action_unsupported", + "tool_scope_unsupported", + ] + + +def test_approve_approval_record_resolves_pending_and_records_trace() -> None: + store = 
ApprovalStoreStub() + approval = store.create_approval( + thread_id=store.thread_id, + tool_id=uuid4(), + task_step_id=None, + status="pending", + request={"thread_id": str(store.thread_id), "tool_id": "tool-1"}, + tool={"id": "tool-1", "tool_key": "shell.exec"}, + routing={"decision": "approval_required", "reasons": [], "trace": {"trace_id": "trace-1", "trace_event_count": 3}}, + routing_trace_id=uuid4(), + ) + created_task = store.create_task( + thread_id=store.thread_id, + tool_id=approval["tool_id"], + status="pending_approval", + request=approval["request"], + tool=approval["tool"], + latest_approval_id=approval["id"], + latest_execution_id=None, + ) + created_step = store.create_task_step( + task_id=created_task["id"], + sequence_no=1, + kind="governed_request", + status="created", + request=approval["request"], + outcome={ + "routing_decision": "approval_required", + "approval_id": str(approval["id"]), + "approval_status": "pending", + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + }, + trace_id=uuid4(), + trace_kind="approval.request", + ) + approval["task_step_id"] = created_step["id"] + + payload = approve_approval_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ApprovalApproveInput(approval_id=approval["id"]), + ) + + assert payload["approval"]["id"] == str(approval["id"]) + assert payload["approval"]["task_step_id"] == str(created_step["id"]) + assert payload["approval"]["status"] == "approved" + assert payload["approval"]["resolution"] == { + "resolved_at": "2026-03-12T10:00:00+00:00", + "resolved_by_user_id": str(store.user_id), + } + assert payload["trace"]["trace_event_count"] == 7 + assert store.traces[0]["kind"] == "approval.resolve" + assert store.traces[0]["compiler_version"] == "approval_resolution_v0" + assert store.traces[0]["limits"] == { + "order": ["created_at_asc", "id_asc"], + "requested_action": "approve", + "outcome": "resolved", + } + assert [event["kind"] for event in store.trace_events] == [ + "approval.resolution.request", + "approval.resolution.state", + "approval.resolution.summary", + "task.lifecycle.state", + "task.lifecycle.summary", + "task.step.lifecycle.state", + "task.step.lifecycle.summary", + ] + assert store.trace_events[1]["payload"] == { + "approval_id": str(approval["id"]), + "task_step_id": str(approval["task_step_id"]), + "requested_action": "approve", + "previous_status": "pending", + "outcome": "resolved", + "current_status": "approved", + "resolved_at": "2026-03-12T10:00:00+00:00", + "resolved_by_user_id": str(store.user_id), + } + assert store.trace_events[3]["payload"] == { + "task_id": str(store.tasks[0]["id"]), + "source": "approval_resolution", + "previous_status": "pending_approval", + "current_status": "approved", + "latest_approval_id": str(approval["id"]), + "latest_execution_id": None, + } + assert store.trace_events[5]["payload"] == { + "task_id": str(store.tasks[0]["id"]), + "task_step_id": str(store.task_steps[0]["id"]), + "source": "approval_resolution", + "sequence_no": 1, + "kind": "governed_request", + "previous_status": "created", + "current_status": "approved", + "trace": { + "trace_id": str(store.traces[0]["id"]), + "trace_kind": "approval.resolve", + }, + } + + +def test_reject_approval_record_resolves_pending_and_records_trace() -> None: + store = ApprovalStoreStub() + approval = store.create_approval( + thread_id=store.thread_id, + tool_id=uuid4(), + task_step_id=None, + status="pending", + request={"thread_id": str(store.thread_id), "tool_id": "tool-2"}, 
+ tool={"id": "tool-2", "tool_key": "browser.open"}, + routing={"decision": "approval_required", "reasons": [], "trace": {"trace_id": "trace-2", "trace_event_count": 3}}, + routing_trace_id=uuid4(), + ) + created_task = store.create_task( + thread_id=store.thread_id, + tool_id=approval["tool_id"], + status="pending_approval", + request=approval["request"], + tool=approval["tool"], + latest_approval_id=approval["id"], + latest_execution_id=None, + ) + created_step = store.create_task_step( + task_id=created_task["id"], + sequence_no=1, + kind="governed_request", + status="created", + request=approval["request"], + outcome={ + "routing_decision": "approval_required", + "approval_id": str(approval["id"]), + "approval_status": "pending", + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + }, + trace_id=uuid4(), + trace_kind="approval.request", + ) + approval["task_step_id"] = created_step["id"] + + payload = reject_approval_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ApprovalRejectInput(approval_id=approval["id"]), + ) + + assert payload["approval"]["status"] == "rejected" + assert payload["approval"]["task_step_id"] == str(created_step["id"]) + assert payload["approval"]["resolution"] == { + "resolved_at": "2026-03-12T10:00:00+00:00", + "resolved_by_user_id": str(store.user_id), + } + assert store.trace_events[1]["payload"]["requested_action"] == "reject" + assert store.trace_events[1]["payload"]["current_status"] == "rejected" + + +def test_approval_resolution_locks_task_steps_before_task_and_step_mutation() -> None: + class LockingApprovalStoreStub(ApprovalStoreStub): + def list_task_steps_for_task(self, task_id: UUID) -> list[dict[str, object]]: + if task_id not in self.locked_task_ids: + raise AssertionError("task-step boundary was checked before the task-step lock was taken") + return super().list_task_steps_for_task(task_id) + + def update_task_status_by_approval_optional( + self, + *, + approval_id: UUID, + status: str, + ) -> dict[str, object] | None: + task = self.get_task_by_approval_optional(approval_id) + if task is None: + return None + if task["id"] not in self.locked_task_ids: + raise AssertionError("task status changed before the task-step lock was taken") + return super().update_task_status_by_approval_optional( + approval_id=approval_id, + status=status, + ) + + store = LockingApprovalStoreStub() + approval = store.create_approval( + thread_id=store.thread_id, + tool_id=uuid4(), + task_step_id=None, + status="pending", + request={"thread_id": str(store.thread_id), "tool_id": "tool-lock"}, + tool={"id": "tool-lock", "tool_key": "shell.exec"}, + routing={"decision": "approval_required", "reasons": [], "trace": {"trace_id": "trace-lock", "trace_event_count": 3}}, + routing_trace_id=uuid4(), + ) + task = store.create_task( + thread_id=store.thread_id, + tool_id=approval["tool_id"], + status="pending_approval", + request=approval["request"], + tool=approval["tool"], + latest_approval_id=approval["id"], + latest_execution_id=None, + ) + created_step = store.create_task_step( + task_id=task["id"], + sequence_no=1, + kind="governed_request", + status="created", + request=approval["request"], + outcome={ + "routing_decision": "approval_required", + "approval_id": str(approval["id"]), + "approval_status": "pending", + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + }, + trace_id=uuid4(), + trace_kind="approval.request", + ) + approval["task_step_id"] = created_step["id"] + + payload = 
approve_approval_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ApprovalApproveInput(approval_id=approval["id"]), + ) + + assert payload["approval"]["status"] == "approved" + assert task["id"] in store.locked_task_ids + + +def test_resolution_rejects_duplicate_and_conflicting_updates_deterministically() -> None: + duplicate_store = ApprovalStoreStub() + duplicate_approval = duplicate_store.create_approval( + thread_id=duplicate_store.thread_id, + tool_id=uuid4(), + task_step_id=None, + status="pending", + request={"thread_id": str(duplicate_store.thread_id), "tool_id": "tool-3"}, + tool={"id": "tool-3", "tool_key": "shell.exec"}, + routing={"decision": "approval_required", "reasons": [], "trace": {"trace_id": "trace-3", "trace_event_count": 3}}, + routing_trace_id=uuid4(), + ) + duplicate_task = duplicate_store.create_task( + thread_id=duplicate_store.thread_id, + tool_id=duplicate_approval["tool_id"], + status="pending_approval", + request=duplicate_approval["request"], + tool=duplicate_approval["tool"], + latest_approval_id=duplicate_approval["id"], + latest_execution_id=None, + ) + duplicate_step = duplicate_store.create_task_step( + task_id=duplicate_task["id"], + sequence_no=1, + kind="governed_request", + status="created", + request=duplicate_approval["request"], + outcome={ + "routing_decision": "approval_required", + "approval_id": str(duplicate_approval["id"]), + "approval_status": "pending", + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + }, + trace_id=uuid4(), + trace_kind="approval.request", + ) + duplicate_approval["task_step_id"] = duplicate_step["id"] + approve_approval_record( + duplicate_store, # type: ignore[arg-type] + user_id=duplicate_store.user_id, + request=ApprovalApproveInput(approval_id=duplicate_approval["id"]), + ) + + try: + approve_approval_record( + duplicate_store, # type: ignore[arg-type] + user_id=duplicate_store.user_id, + request=ApprovalApproveInput(approval_id=duplicate_approval["id"]), + ) + except ApprovalResolutionConflictError as exc: + assert str(exc) == f"approval {duplicate_approval['id']} was already approved" + else: + raise AssertionError("expected ApprovalResolutionConflictError for duplicate approval") + + assert duplicate_store.trace_events[-6]["payload"]["outcome"] == "duplicate_rejected" + + conflict_store = ApprovalStoreStub() + conflict_approval = conflict_store.create_approval( + thread_id=conflict_store.thread_id, + tool_id=uuid4(), + task_step_id=None, + status="pending", + request={"thread_id": str(conflict_store.thread_id), "tool_id": "tool-4"}, + tool={"id": "tool-4", "tool_key": "shell.exec"}, + routing={"decision": "approval_required", "reasons": [], "trace": {"trace_id": "trace-4", "trace_event_count": 3}}, + routing_trace_id=uuid4(), + ) + conflict_task = conflict_store.create_task( + thread_id=conflict_store.thread_id, + tool_id=conflict_approval["tool_id"], + status="pending_approval", + request=conflict_approval["request"], + tool=conflict_approval["tool"], + latest_approval_id=conflict_approval["id"], + latest_execution_id=None, + ) + conflict_step = conflict_store.create_task_step( + task_id=conflict_task["id"], + sequence_no=1, + kind="governed_request", + status="created", + request=conflict_approval["request"], + outcome={ + "routing_decision": "approval_required", + "approval_id": str(conflict_approval["id"]), + "approval_status": "pending", + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + }, + trace_id=uuid4(), + 
trace_kind="approval.request", + ) + conflict_approval["task_step_id"] = conflict_step["id"] + approve_approval_record( + conflict_store, # type: ignore[arg-type] + user_id=conflict_store.user_id, + request=ApprovalApproveInput(approval_id=conflict_approval["id"]), + ) + + try: + reject_approval_record( + conflict_store, # type: ignore[arg-type] + user_id=conflict_store.user_id, + request=ApprovalRejectInput(approval_id=conflict_approval["id"]), + ) + except ApprovalResolutionConflictError as exc: + assert str(exc) == ( + f"approval {conflict_approval['id']} was already approved and cannot be rejected" + ) + else: + raise AssertionError("expected ApprovalResolutionConflictError for conflicting rejection") + + assert conflict_store.trace_events[-6]["payload"]["outcome"] == "conflict_rejected" + + +def test_approval_resolution_rejects_inconsistent_linkage_without_mutating_task_state() -> None: + store = ApprovalStoreStub() + approval = store.create_approval( + thread_id=store.thread_id, + tool_id=uuid4(), + task_step_id=None, + status="approved", + request={"thread_id": str(store.thread_id), "tool_id": "tool-boundary"}, + tool={"id": "tool-boundary", "tool_key": "shell.exec"}, + routing={"decision": "approval_required", "reasons": [], "trace": {"trace_id": "trace-boundary", "trace_event_count": 3}}, + routing_trace_id=uuid4(), + ) + task = store.create_task( + thread_id=store.thread_id, + tool_id=approval["tool_id"], + status="pending_approval", + request=approval["request"], + tool=approval["tool"], + latest_approval_id=approval["id"], + latest_execution_id=None, + ) + first_step = store.create_task_step( + task_id=task["id"], + sequence_no=1, + kind="governed_request", + status="executed", + request=approval["request"], + outcome={ + "routing_decision": "approval_required", + "approval_id": str(approval["id"]), + "approval_status": "approved", + "execution_id": str(uuid4()), + "execution_status": "completed", + "blocked_reason": None, + }, + trace_id=uuid4(), + trace_kind="tool.proxy.execute", + ) + later_step = store.create_task_step( + task_id=task["id"], + sequence_no=2, + parent_step_id=first_step["id"], + source_approval_id=approval["id"], + source_execution_id=uuid4(), + kind="governed_request", + status="created", + request=approval["request"], + outcome={ + "routing_decision": "approval_required", + "approval_id": None, + "approval_status": None, + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + }, + trace_id=uuid4(), + trace_kind="task.step.continuation", + ) + approval["task_step_id"] = later_step["id"] + + original_first_trace_id = first_step["trace_id"] + original_first_outcome = dict(first_step["outcome"]) + original_later_trace_id = later_step["trace_id"] + + try: + approve_approval_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ApprovalApproveInput(approval_id=approval["id"]), + ) + except TaskStepApprovalLinkageError as exc: + assert str(exc) == ( + f"approval {approval['id']} is inconsistent with linked task step {later_step['id']}" + ) + else: + raise AssertionError("expected TaskStepApprovalLinkageError") + + assert task["status"] == "pending_approval" + assert task["latest_execution_id"] is None + assert first_step["status"] == "executed" + assert first_step["trace_id"] == original_first_trace_id + assert first_step["outcome"] == original_first_outcome + assert later_step["status"] == "created" + assert later_step["trace_id"] == original_later_trace_id + assert store.traces == [] + assert store.trace_events == 
[] + + +def test_list_and_get_approval_records_use_deterministic_order_after_resolution() -> None: + store = ApprovalStoreStub() + first = store.create_approval( + thread_id=store.thread_id, + tool_id=uuid4(), + task_step_id=None, + status="pending", + request={"thread_id": str(store.thread_id), "tool_id": "tool-1"}, + tool={"id": "tool-1", "tool_key": "shell.exec"}, + routing={"decision": "approval_required", "reasons": [], "trace": {"trace_id": "trace-1", "trace_event_count": 3}}, + routing_trace_id=uuid4(), + ) + first_task = store.create_task( + thread_id=store.thread_id, + tool_id=first["tool_id"], + status="pending_approval", + request=first["request"], + tool=first["tool"], + latest_approval_id=first["id"], + latest_execution_id=None, + ) + first_step = store.create_task_step( + task_id=first_task["id"], + sequence_no=1, + kind="governed_request", + status="created", + request=first["request"], + outcome={ + "routing_decision": "approval_required", + "approval_id": str(first["id"]), + "approval_status": "pending", + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + }, + trace_id=uuid4(), + trace_kind="approval.request", + ) + first["task_step_id"] = first_step["id"] + second = store.create_approval( + thread_id=store.thread_id, + tool_id=uuid4(), + task_step_id=None, + status="pending", + request={"thread_id": str(store.thread_id), "tool_id": "tool-2"}, + tool={"id": "tool-2", "tool_key": "browser.open"}, + routing={"decision": "approval_required", "reasons": [], "trace": {"trace_id": "trace-2", "trace_event_count": 3}}, + routing_trace_id=uuid4(), + ) + second_task = store.create_task( + thread_id=store.thread_id, + tool_id=second["tool_id"], + status="pending_approval", + request=second["request"], + tool=second["tool"], + latest_approval_id=second["id"], + latest_execution_id=None, + ) + second_step = store.create_task_step( + task_id=second_task["id"], + sequence_no=1, + kind="governed_request", + status="created", + request=second["request"], + outcome={ + "routing_decision": "approval_required", + "approval_id": str(second["id"]), + "approval_status": "pending", + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + }, + trace_id=uuid4(), + trace_kind="approval.request", + ) + second["task_step_id"] = second_step["id"] + + approve_approval_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ApprovalApproveInput(approval_id=first["id"]), + ) + reject_approval_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ApprovalRejectInput(approval_id=second["id"]), + ) + + listed = list_approval_records( + store, # type: ignore[arg-type] + user_id=store.user_id, + ) + detail = get_approval_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + approval_id=UUID(str(second["id"])), + ) + + assert [item["id"] for item in listed["items"]] == [str(first["id"]), str(second["id"])] + assert [item["task_step_id"] for item in listed["items"]] == [str(first_step["id"]), str(second_step["id"])] + assert [item["status"] for item in listed["items"]] == ["approved", "rejected"] + assert listed["items"][0]["resolution"] is not None + assert listed["items"][1]["resolution"] is not None + assert listed["summary"] == { + "total_count": 2, + "order": ["created_at_asc", "id_asc"], + } + assert detail["approval"]["id"] == str(second["id"]) + assert detail["approval"]["task_step_id"] == str(second_step["id"]) + assert detail["approval"]["status"] == "rejected" + assert 
detail["approval"]["resolution"] == { + "resolved_at": "2026-03-12T10:07:00+00:00", + "resolved_by_user_id": str(store.user_id), + } + + +def test_get_approval_record_raises_not_found_when_missing() -> None: + store = ApprovalStoreStub() + + try: + get_approval_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + approval_id=uuid4(), + ) + except ApprovalNotFoundError as exc: + assert "approval" in str(exc) + else: + raise AssertionError("expected ApprovalNotFoundError") diff --git a/tests/unit/test_approvals_main.py b/tests/unit/test_approvals_main.py new file mode 100644 index 0000000..833f78d --- /dev/null +++ b/tests/unit/test_approvals_main.py @@ -0,0 +1,376 @@ +from __future__ import annotations + +import json +from contextlib import contextmanager +from uuid import uuid4 + +import apps.api.src.alicebot_api.main as main_module +from apps.api.src.alicebot_api.config import Settings +from alicebot_api.approvals import ApprovalNotFoundError, ApprovalResolutionConflictError +from alicebot_api.tasks import TaskStepApprovalLinkageError +from alicebot_api.tools import ToolRoutingValidationError + + +def test_create_approval_request_endpoint_translates_request_and_returns_trace_payload(monkeypatch) -> None: + user_id = uuid4() + thread_id = uuid4() + tool_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_submit_approval_request(store, *, user_id, request): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["request"] = request + return { + "request": { + "thread_id": str(thread_id), + "tool_id": str(tool_id), + "action": "tool.run", + "scope": "workspace", + "domain_hint": None, + "risk_hint": None, + "attributes": {"command": "ls"}, + }, + "decision": "approval_required", + "tool": {"id": str(tool_id), "tool_key": "shell.exec"}, + "reasons": [], + "approval": { + "id": "approval-123", + "thread_id": str(thread_id), + "task_step_id": "task-step-123", + "status": "pending", + "resolution": None, + "request": { + "thread_id": str(thread_id), + "tool_id": str(tool_id), + "action": "tool.run", + "scope": "workspace", + "domain_hint": None, + "risk_hint": None, + "attributes": {"command": "ls"}, + }, + "tool": {"id": str(tool_id), "tool_key": "shell.exec"}, + "routing": { + "decision": "approval_required", + "reasons": [], + "trace": {"trace_id": "routing-trace-123", "trace_event_count": 3}, + }, + "created_at": "2026-03-12T09:00:00+00:00", + }, + "routing_trace": {"trace_id": "routing-trace-123", "trace_event_count": 3}, + "trace": {"trace_id": "approval-trace-123", "trace_event_count": 4}, + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "submit_approval_request", fake_submit_approval_request) + + response = main_module.create_approval_request( + main_module.CreateApprovalRequest( + user_id=user_id, + thread_id=thread_id, + tool_id=tool_id, + action="tool.run", + scope="workspace", + attributes={"command": "ls"}, + ) + ) + + assert response.status_code == 200 + assert json.loads(response.body)["trace"] == { + "trace_id": "approval-trace-123", + "trace_event_count": 4, + } + assert captured["database_url"] == "postgresql://app" + assert 
captured["current_user_id"] == user_id + assert captured["store_type"] == "ContinuityStore" + assert captured["user_id"] == user_id + assert captured["request"].thread_id == thread_id + assert captured["request"].tool_id == tool_id + assert captured["request"].attributes == {"command": "ls"} + + +def test_create_approval_request_endpoint_maps_validation_errors_to_400(monkeypatch) -> None: + user_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_submit_approval_request(*_args, **_kwargs): + raise ToolRoutingValidationError("tool_id must reference an existing active tool owned by the user") + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "submit_approval_request", fake_submit_approval_request) + + response = main_module.create_approval_request( + main_module.CreateApprovalRequest( + user_id=user_id, + thread_id=uuid4(), + tool_id=uuid4(), + action="tool.run", + scope="workspace", + attributes={}, + ) + ) + + assert response.status_code == 400 + assert json.loads(response.body) == { + "detail": "tool_id must reference an existing active tool owned by the user" + } + + +def test_list_approvals_endpoint_returns_payload(monkeypatch) -> None: + user_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr( + main_module, + "list_approval_records", + lambda *_args, **_kwargs: { + "items": [], + "summary": {"total_count": 0, "order": ["created_at_asc", "id_asc"]}, + }, + ) + + response = main_module.list_approvals(user_id) + + assert response.status_code == 200 + assert json.loads(response.body) == { + "items": [], + "summary": {"total_count": 0, "order": ["created_at_asc", "id_asc"]}, + } + + +def test_get_approval_endpoint_maps_not_found_to_404(monkeypatch) -> None: + user_id = uuid4() + approval_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_get_approval_record(*_args, **_kwargs): + raise ApprovalNotFoundError(f"approval {approval_id} was not found") + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "get_approval_record", fake_get_approval_record) + + response = main_module.get_approval(approval_id, user_id) + + assert response.status_code == 404 + assert json.loads(response.body) == {"detail": f"approval {approval_id} was not found"} + + +def test_approve_approval_endpoint_returns_payload(monkeypatch) -> None: + user_id = uuid4() + approval_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_approve_approval_record(store, *, user_id, request): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["request"] = request + return { + "approval": { + "id": 
str(approval_id), + "thread_id": "thread-123", + "task_step_id": "task-step-123", + "status": "approved", + "resolution": { + "resolved_at": "2026-03-12T10:00:00+00:00", + "resolved_by_user_id": str(user_id), + }, + "request": {"thread_id": "thread-123", "tool_id": "tool-123"}, + "tool": {"id": "tool-123", "tool_key": "shell.exec"}, + "routing": { + "decision": "approval_required", + "reasons": [], + "trace": {"trace_id": "routing-trace-123", "trace_event_count": 3}, + }, + "created_at": "2026-03-12T09:00:00+00:00", + }, + "trace": {"trace_id": "approval-resolution-trace-123", "trace_event_count": 3}, + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "approve_approval_record", fake_approve_approval_record) + + response = main_module.approve_approval( + approval_id, + main_module.ResolveApprovalRequest(user_id=user_id), + ) + + assert response.status_code == 200 + assert json.loads(response.body)["trace"] == { + "trace_id": "approval-resolution-trace-123", + "trace_event_count": 3, + } + assert captured["database_url"] == "postgresql://app" + assert captured["current_user_id"] == user_id + assert captured["store_type"] == "ContinuityStore" + assert captured["user_id"] == user_id + assert captured["request"].approval_id == approval_id + + +def test_approve_approval_endpoint_maps_conflicts_to_409(monkeypatch) -> None: + user_id = uuid4() + approval_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_approve_approval_record(*_args, **_kwargs): + raise ApprovalResolutionConflictError(f"approval {approval_id} was already approved") + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "approve_approval_record", fake_approve_approval_record) + + response = main_module.approve_approval( + approval_id, + main_module.ResolveApprovalRequest(user_id=user_id), + ) + + assert response.status_code == 409 + assert json.loads(response.body) == {"detail": f"approval {approval_id} was already approved"} + + +def test_approve_approval_endpoint_maps_linkage_errors_to_409(monkeypatch) -> None: + user_id = uuid4() + approval_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_approve_approval_record(*_args, **_kwargs): + raise TaskStepApprovalLinkageError( + f"approval {approval_id} is inconsistent with linked task step task-step-123" + ) + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "approve_approval_record", fake_approve_approval_record) + + response = main_module.approve_approval( + approval_id, + main_module.ResolveApprovalRequest(user_id=user_id), + ) + + assert response.status_code == 409 + assert json.loads(response.body) == { + "detail": f"approval {approval_id} is inconsistent with linked task step task-step-123" + } + + +def test_reject_approval_endpoint_returns_payload(monkeypatch) -> None: + user_id = uuid4() + approval_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, 
current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_reject_approval_record(store, *, user_id, request): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["request"] = request + return { + "approval": { + "id": str(approval_id), + "thread_id": "thread-123", + "task_step_id": "task-step-456", + "status": "rejected", + "resolution": { + "resolved_at": "2026-03-12T10:01:00+00:00", + "resolved_by_user_id": str(user_id), + }, + "request": {"thread_id": "thread-123", "tool_id": "tool-123"}, + "tool": {"id": "tool-123", "tool_key": "shell.exec"}, + "routing": { + "decision": "approval_required", + "reasons": [], + "trace": {"trace_id": "routing-trace-123", "trace_event_count": 3}, + }, + "created_at": "2026-03-12T09:00:00+00:00", + }, + "trace": {"trace_id": "approval-resolution-trace-456", "trace_event_count": 3}, + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "reject_approval_record", fake_reject_approval_record) + + response = main_module.reject_approval( + approval_id, + main_module.ResolveApprovalRequest(user_id=user_id), + ) + + assert response.status_code == 200 + assert json.loads(response.body)["trace"] == { + "trace_id": "approval-resolution-trace-456", + "trace_event_count": 3, + } + assert captured["database_url"] == "postgresql://app" + assert captured["current_user_id"] == user_id + assert captured["store_type"] == "ContinuityStore" + assert captured["user_id"] == user_id + assert captured["request"].approval_id == approval_id + + +def test_reject_approval_endpoint_maps_not_found_to_404(monkeypatch) -> None: + user_id = uuid4() + approval_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_reject_approval_record(*_args, **_kwargs): + raise ApprovalNotFoundError(f"approval {approval_id} was not found") + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "reject_approval_record", fake_reject_approval_record) + + response = main_module.reject_approval( + approval_id, + main_module.ResolveApprovalRequest(user_id=user_id), + ) + + assert response.status_code == 404 + assert json.loads(response.body) == {"detail": f"approval {approval_id} was not found"} diff --git a/tests/unit/test_compiler.py b/tests/unit/test_compiler.py new file mode 100644 index 0000000..c221707 --- /dev/null +++ b/tests/unit/test_compiler.py @@ -0,0 +1,760 @@ +from __future__ import annotations + +from datetime import UTC, datetime, timedelta +from uuid import uuid4 + +from alicebot_api.compiler import ( + SUMMARY_TRACE_EVENT_KIND, + _compile_memory_section, + compile_continuity_context, +) +from alicebot_api.contracts import CompileContextSemanticRetrievalInput, ContextCompilerLimits + + +def test_compile_continuity_context_is_deterministic_and_stably_ordered() -> None: + user_id = uuid4() + thread_id = uuid4() + base_time = datetime(2026, 3, 11, 9, 0, tzinfo=UTC) + session_ids = [uuid4(), uuid4(), uuid4()] + event_ids = [uuid4(), uuid4(), uuid4(), uuid4()] + memory_ids = [uuid4(), uuid4(), uuid4()] + entity_ids = [uuid4(), uuid4(), uuid4()] + edge_ids = [uuid4(), uuid4(), uuid4(), uuid4()] + + user = { + "id": 
user_id, + "email": "owner@example.com", + "display_name": "Owner", + "created_at": base_time, + } + thread = { + "id": thread_id, + "user_id": user_id, + "title": "Traceable thread", + "created_at": base_time, + "updated_at": base_time + timedelta(minutes=4), + } + sessions = [ + { + "id": session_ids[0], + "user_id": user_id, + "thread_id": thread_id, + "status": "done", + "started_at": base_time, + "ended_at": base_time + timedelta(minutes=1), + "created_at": base_time, + }, + { + "id": session_ids[1], + "user_id": user_id, + "thread_id": thread_id, + "status": "done", + "started_at": base_time + timedelta(minutes=2), + "ended_at": base_time + timedelta(minutes=3), + "created_at": base_time + timedelta(minutes=2), + }, + { + "id": session_ids[2], + "user_id": user_id, + "thread_id": thread_id, + "status": "active", + "started_at": base_time + timedelta(minutes=4), + "ended_at": None, + "created_at": base_time + timedelta(minutes=4), + }, + ] + events = [ + { + "id": event_ids[0], + "user_id": user_id, + "thread_id": thread_id, + "session_id": session_ids[0], + "sequence_no": 1, + "kind": "message.user", + "payload": {"text": "one"}, + "created_at": base_time, + }, + { + "id": event_ids[1], + "user_id": user_id, + "thread_id": thread_id, + "session_id": session_ids[1], + "sequence_no": 2, + "kind": "message.assistant", + "payload": {"text": "two"}, + "created_at": base_time + timedelta(minutes=2), + }, + { + "id": event_ids[2], + "user_id": user_id, + "thread_id": thread_id, + "session_id": session_ids[2], + "sequence_no": 3, + "kind": "message.user", + "payload": {"text": "three"}, + "created_at": base_time + timedelta(minutes=4), + }, + { + "id": event_ids[3], + "user_id": user_id, + "thread_id": thread_id, + "session_id": session_ids[2], + "sequence_no": 4, + "kind": "message.assistant", + "payload": {"text": "four"}, + "created_at": base_time + timedelta(minutes=5), + }, + ] + memories = [ + { + "id": memory_ids[0], + "user_id": user_id, + "memory_key": "user.preference.tea", + "value": {"likes": "green"}, + "status": "active", + "source_event_ids": [str(event_ids[0])], + "created_at": base_time, + "updated_at": base_time + timedelta(minutes=1), + "deleted_at": None, + }, + { + "id": memory_ids[1], + "user_id": user_id, + "memory_key": "user.preference.coffee", + "value": {"likes": "oat milk"}, + "status": "active", + "source_event_ids": [str(event_ids[1])], + "created_at": base_time + timedelta(minutes=1), + "updated_at": base_time + timedelta(minutes=4), + "deleted_at": None, + }, + { + "id": memory_ids[2], + "user_id": user_id, + "memory_key": "user.preference.snacks", + "value": {"likes": "almonds"}, + "status": "active", + "source_event_ids": [str(event_ids[2])], + "created_at": base_time + timedelta(minutes=2), + "updated_at": base_time + timedelta(minutes=5), + "deleted_at": None, + }, + ] + entities = [ + { + "id": entity_ids[0], + "user_id": user_id, + "entity_type": "person", + "name": "Samir", + "source_memory_ids": [str(memory_ids[0])], + "created_at": base_time, + }, + { + "id": entity_ids[1], + "user_id": user_id, + "entity_type": "merchant", + "name": "Neighborhood Cafe", + "source_memory_ids": [str(memory_ids[1])], + "created_at": base_time + timedelta(minutes=3), + }, + { + "id": entity_ids[2], + "user_id": user_id, + "entity_type": "project", + "name": "AliceBot", + "source_memory_ids": [str(memory_ids[1]), str(memory_ids[2])], + "created_at": base_time + timedelta(minutes=6), + }, + ] + entity_edges = [ + { + "id": edge_ids[0], + "user_id": user_id, + 
"from_entity_id": entity_ids[0], + "to_entity_id": entity_ids[1], + "relationship_type": "visits", + "valid_from": None, + "valid_to": None, + "source_memory_ids": [str(memory_ids[0])], + "created_at": base_time + timedelta(minutes=2), + }, + { + "id": edge_ids[1], + "user_id": user_id, + "from_entity_id": entity_ids[2], + "to_entity_id": entity_ids[0], + "relationship_type": "references", + "valid_from": base_time + timedelta(minutes=5), + "valid_to": None, + "source_memory_ids": [str(memory_ids[2])], + "created_at": base_time + timedelta(minutes=5), + }, + { + "id": edge_ids[2], + "user_id": user_id, + "from_entity_id": entity_ids[1], + "to_entity_id": entity_ids[2], + "relationship_type": "works_on", + "valid_from": None, + "valid_to": base_time + timedelta(minutes=8), + "source_memory_ids": [str(memory_ids[1]), str(memory_ids[2])], + "created_at": base_time + timedelta(minutes=8), + }, + { + "id": edge_ids[3], + "user_id": user_id, + "from_entity_id": entity_ids[0], + "to_entity_id": entity_ids[0], + "relationship_type": "self_loop", + "valid_from": None, + "valid_to": None, + "source_memory_ids": [str(memory_ids[0])], + "created_at": base_time + timedelta(minutes=9), + }, + ] + limits = ContextCompilerLimits( + max_sessions=2, + max_events=2, + max_memories=2, + max_entities=2, + max_entity_edges=2, + ) + + first_run = compile_continuity_context( + user=user, + thread=thread, + sessions=sessions, + events=events, + memories=memories, + entities=entities, + entity_edges=entity_edges, + limits=limits, + ) + second_run = compile_continuity_context( + user=user, + thread=thread, + sessions=sessions, + events=events, + memories=memories, + entities=entities, + entity_edges=entity_edges, + limits=limits, + ) + + assert first_run.context_pack == second_run.context_pack + assert first_run.trace_events == second_run.trace_events + assert [session["id"] for session in first_run.context_pack["sessions"]] == [ + str(session_ids[1]), + str(session_ids[2]), + ] + assert [event["sequence_no"] for event in first_run.context_pack["events"]] == [3, 4] + assert [memory["memory_key"] for memory in first_run.context_pack["memories"]] == [ + "user.preference.coffee", + "user.preference.snacks", + ] + assert [memory["source_provenance"] for memory in first_run.context_pack["memories"]] == [ + {"sources": ["symbolic"], "semantic_score": None}, + {"sources": ["symbolic"], "semantic_score": None}, + ] + assert [entity["id"] for entity in first_run.context_pack["entities"]] == [ + str(entity_ids[1]), + str(entity_ids[2]), + ] + assert [edge["id"] for edge in first_run.context_pack["entity_edges"]] == [ + str(edge_ids[1]), + str(edge_ids[2]), + ] + assert first_run.context_pack["memory_summary"] == { + "candidate_count": 2, + "included_count": 2, + "excluded_deleted_count": 0, + "excluded_limit_count": 0, + "hybrid_retrieval": { + "requested": False, + "embedding_config_id": None, + "query_vector_dimensions": 0, + "semantic_limit": 0, + "symbolic_selected_count": 2, + "semantic_selected_count": 0, + "merged_candidate_count": 2, + "deduplicated_count": 0, + "included_symbolic_only_count": 2, + "included_semantic_only_count": 0, + "included_dual_source_count": 0, + "similarity_metric": None, + "source_precedence": ["symbolic", "semantic"], + "symbolic_order": ["updated_at_asc", "created_at_asc", "id_asc"], + "semantic_order": ["score_desc", "created_at_asc", "id_asc"], + }, + } + assert first_run.context_pack["entity_summary"] == { + "candidate_count": 3, + "included_count": 2, + "excluded_limit_count": 1, + } + 
assert first_run.context_pack["entity_edge_summary"] == { + "anchor_entity_count": 2, + "candidate_count": 3, + "included_count": 2, + "excluded_limit_count": 1, + } + + +def test_compile_continuity_context_records_included_and_excluded_reasons() -> None: + user_id = uuid4() + thread_id = uuid4() + base_time = datetime(2026, 3, 11, 9, 0, tzinfo=UTC) + kept_session_id = uuid4() + dropped_session_id = uuid4() + dropped_by_session_event_id = uuid4() + dropped_by_event_limit_id = uuid4() + kept_event_id = uuid4() + dropped_by_memory_limit_id = uuid4() + kept_memory_id = uuid4() + deleted_memory_id = uuid4() + dropped_entity_id = uuid4() + kept_entity_id = uuid4() + dropped_entity_edge_id = uuid4() + kept_entity_edge_id = uuid4() + ignored_entity_edge_id = uuid4() + external_entity_id = uuid4() + kept_edge_valid_from = base_time + timedelta(minutes=5) + + compiler_run = compile_continuity_context( + user={ + "id": user_id, + "email": "owner@example.com", + "display_name": "Owner", + "created_at": base_time, + }, + thread={ + "id": thread_id, + "user_id": user_id, + "title": "Traceable thread", + "created_at": base_time, + "updated_at": base_time, + }, + sessions=[ + { + "id": dropped_session_id, + "user_id": user_id, + "thread_id": thread_id, + "status": "done", + "started_at": base_time, + "ended_at": base_time, + "created_at": base_time, + }, + { + "id": kept_session_id, + "user_id": user_id, + "thread_id": thread_id, + "status": "active", + "started_at": base_time + timedelta(minutes=1), + "ended_at": None, + "created_at": base_time + timedelta(minutes=1), + }, + ], + events=[ + { + "id": dropped_by_session_event_id, + "user_id": user_id, + "thread_id": thread_id, + "session_id": dropped_session_id, + "sequence_no": 1, + "kind": "message.user", + "payload": {"text": "old session"}, + "created_at": base_time, + }, + { + "id": dropped_by_event_limit_id, + "user_id": user_id, + "thread_id": thread_id, + "session_id": kept_session_id, + "sequence_no": 2, + "kind": "message.assistant", + "payload": {"text": "too old"}, + "created_at": base_time + timedelta(minutes=1), + }, + { + "id": kept_event_id, + "user_id": user_id, + "thread_id": thread_id, + "session_id": kept_session_id, + "sequence_no": 3, + "kind": "message.user", + "payload": {"text": "keep"}, + "created_at": base_time + timedelta(minutes=2), + }, + ], + memories=[ + { + "id": dropped_by_memory_limit_id, + "user_id": user_id, + "memory_key": "user.preference.old", + "value": {"likes": "black"}, + "status": "active", + "source_event_ids": [str(dropped_by_session_event_id)], + "created_at": base_time, + "updated_at": base_time, + "deleted_at": None, + }, + { + "id": kept_memory_id, + "user_id": user_id, + "memory_key": "user.preference.keep", + "value": {"likes": "oat milk"}, + "status": "active", + "source_event_ids": [str(kept_event_id)], + "created_at": base_time + timedelta(minutes=1), + "updated_at": base_time + timedelta(minutes=2), + "deleted_at": None, + }, + { + "id": deleted_memory_id, + "user_id": user_id, + "memory_key": "user.preference.deleted", + "value": {"likes": "espresso"}, + "status": "deleted", + "source_event_ids": [str(dropped_by_event_limit_id)], + "created_at": base_time + timedelta(minutes=2), + "updated_at": base_time + timedelta(minutes=3), + "deleted_at": base_time + timedelta(minutes=3), + }, + ], + entities=[ + { + "id": dropped_entity_id, + "user_id": user_id, + "entity_type": "person", + "name": "Samir", + "source_memory_ids": [str(dropped_by_memory_limit_id)], + "created_at": base_time, + }, + { + "id": 
kept_entity_id, + "user_id": user_id, + "entity_type": "project", + "name": "AliceBot", + "source_memory_ids": [str(kept_memory_id)], + "created_at": base_time + timedelta(minutes=4), + }, + ], + entity_edges=[ + { + "id": dropped_entity_edge_id, + "user_id": user_id, + "from_entity_id": dropped_entity_id, + "to_entity_id": kept_entity_id, + "relationship_type": "related_to", + "valid_from": None, + "valid_to": None, + "source_memory_ids": [str(kept_memory_id)], + "created_at": base_time + timedelta(minutes=3), + }, + { + "id": kept_entity_edge_id, + "user_id": user_id, + "from_entity_id": kept_entity_id, + "to_entity_id": external_entity_id, + "relationship_type": "depends_on", + "valid_from": kept_edge_valid_from, + "valid_to": None, + "source_memory_ids": [str(kept_memory_id)], + "created_at": base_time + timedelta(minutes=5), + }, + { + "id": ignored_entity_edge_id, + "user_id": user_id, + "from_entity_id": dropped_entity_id, + "to_entity_id": external_entity_id, + "relationship_type": "ignored", + "valid_from": None, + "valid_to": None, + "source_memory_ids": [str(dropped_by_memory_limit_id)], + "created_at": base_time + timedelta(minutes=6), + }, + ], + limits=ContextCompilerLimits( + max_sessions=1, + max_events=1, + max_memories=1, + max_entities=1, + max_entity_edges=1, + ), + ) + + trace_payloads = [trace_event.payload for trace_event in compiler_run.trace_events] + + assert {"entity_type": "session", "entity_id": str(kept_session_id), "reason": "within_session_limit", "position": 1} in trace_payloads + assert {"entity_type": "session", "entity_id": str(dropped_session_id), "reason": "session_limit_exceeded", "position": 1} in trace_payloads + assert {"entity_type": "event", "entity_id": str(dropped_by_session_event_id), "reason": "session_not_included", "position": 1} in trace_payloads + assert {"entity_type": "event", "entity_id": str(dropped_by_event_limit_id), "reason": "event_limit_exceeded", "position": 2} in trace_payloads + assert {"entity_type": "event", "entity_id": str(kept_event_id), "reason": "within_event_limit", "position": 3} in trace_payloads + assert { + "entity_type": "memory", + "entity_id": str(kept_memory_id), + "reason": "within_hybrid_memory_limit", + "position": 1, + "memory_key": "user.preference.keep", + "status": "active", + "source_event_ids": [str(kept_event_id)], + "embedding_config_id": None, + "selected_sources": ["symbolic"], + "semantic_score": None, + } in trace_payloads + assert { + "entity_type": "memory", + "entity_id": str(deleted_memory_id), + "reason": "hybrid_memory_deleted", + "position": 1, + "memory_key": "user.preference.deleted", + "status": "deleted", + "source_event_ids": [str(dropped_by_event_limit_id)], + "embedding_config_id": None, + "selected_sources": ["symbolic"], + "semantic_score": None, + } in trace_payloads + assert { + "entity_type": "entity", + "entity_id": str(dropped_entity_id), + "reason": "entity_limit_exceeded", + "position": 1, + "record_entity_type": "person", + "name": "Samir", + "source_memory_ids": [str(dropped_by_memory_limit_id)], + } in trace_payloads + assert { + "entity_type": "entity", + "entity_id": str(kept_entity_id), + "reason": "within_entity_limit", + "position": 2, + "record_entity_type": "project", + "name": "AliceBot", + "source_memory_ids": [str(kept_memory_id)], + } in trace_payloads + assert { + "entity_type": "entity_edge", + "entity_id": str(dropped_entity_edge_id), + "reason": "entity_edge_limit_exceeded", + "position": 1, + "from_entity_id": str(dropped_entity_id), + "to_entity_id": 
str(kept_entity_id), + "relationship_type": "related_to", + "valid_from": None, + "valid_to": None, + "source_memory_ids": [str(kept_memory_id)], + "attached_included_entity_ids": [str(kept_entity_id)], + } in trace_payloads + assert { + "entity_type": "entity_edge", + "entity_id": str(kept_entity_edge_id), + "reason": "within_entity_edge_limit", + "position": 2, + "from_entity_id": str(kept_entity_id), + "to_entity_id": str(external_entity_id), + "relationship_type": "depends_on", + "valid_from": kept_edge_valid_from.isoformat(), + "valid_to": None, + "source_memory_ids": [str(kept_memory_id)], + "attached_included_entity_ids": [str(kept_entity_id)], + } in trace_payloads + assert all(payload.get("entity_id") != str(ignored_entity_edge_id) for payload in trace_payloads) + assert compiler_run.trace_events[-1].kind == SUMMARY_TRACE_EVENT_KIND + assert compiler_run.context_pack["events"][0]["id"] == str(kept_event_id) + assert compiler_run.context_pack["memories"] == [ + { + "id": str(kept_memory_id), + "memory_key": "user.preference.keep", + "value": {"likes": "oat milk"}, + "status": "active", + "source_event_ids": [str(kept_event_id)], + "created_at": (base_time + timedelta(minutes=1)).isoformat(), + "updated_at": (base_time + timedelta(minutes=2)).isoformat(), + "source_provenance": {"sources": ["symbolic"], "semantic_score": None}, + } + ] + assert compiler_run.context_pack["memory_summary"] == { + "candidate_count": 2, + "included_count": 1, + "excluded_deleted_count": 1, + "excluded_limit_count": 0, + "hybrid_retrieval": { + "requested": False, + "embedding_config_id": None, + "query_vector_dimensions": 0, + "semantic_limit": 0, + "symbolic_selected_count": 1, + "semantic_selected_count": 0, + "merged_candidate_count": 1, + "deduplicated_count": 0, + "included_symbolic_only_count": 1, + "included_semantic_only_count": 0, + "included_dual_source_count": 0, + "similarity_metric": None, + "source_precedence": ["symbolic", "semantic"], + "symbolic_order": ["updated_at_asc", "created_at_asc", "id_asc"], + "semantic_order": ["score_desc", "created_at_asc", "id_asc"], + }, + } + assert compiler_run.context_pack["entities"] == [ + { + "id": str(kept_entity_id), + "entity_type": "project", + "name": "AliceBot", + "source_memory_ids": [str(kept_memory_id)], + "created_at": (base_time + timedelta(minutes=4)).isoformat(), + } + ] + assert compiler_run.context_pack["entity_edges"] == [ + { + "id": str(kept_entity_edge_id), + "from_entity_id": str(kept_entity_id), + "to_entity_id": str(external_entity_id), + "relationship_type": "depends_on", + "valid_from": kept_edge_valid_from.isoformat(), + "valid_to": None, + "source_memory_ids": [str(kept_memory_id)], + "created_at": (base_time + timedelta(minutes=5)).isoformat(), + } + ] + assert compiler_run.context_pack["entity_edge_summary"] == { + "anchor_entity_count": 1, + "candidate_count": 2, + "included_count": 1, + "excluded_limit_count": 1, + } + assert compiler_run.trace_events[-1].payload["included_entity_edge_count"] == 1 + assert compiler_run.trace_events[-1].payload["excluded_entity_edge_limit_count"] == 1 + assert compiler_run.trace_events[-1].payload["hybrid_memory_requested"] is False + assert compiler_run.trace_events[-1].payload["hybrid_memory_candidate_count"] == 2 + assert compiler_run.trace_events[-1].payload["hybrid_memory_merged_candidate_count"] == 1 + assert compiler_run.trace_events[-1].payload["hybrid_memory_deduplicated_count"] == 0 + + +class SemanticCompileStoreStub: + def __init__(self) -> None: + self.base_time = 
datetime(2026, 3, 12, 12, 0, tzinfo=UTC) + self.config_id = uuid4() + self.memory_ids = [uuid4(), uuid4(), uuid4()] + self.event_ids = [uuid4(), uuid4(), uuid4()] + + def get_embedding_config_optional(self, embedding_config_id): + if embedding_config_id != self.config_id: + return None + return {"id": self.config_id, "dimensions": 3} + + def retrieve_semantic_memory_matches(self, *, embedding_config_id, query_vector, limit): + assert embedding_config_id == self.config_id + assert query_vector == [1.0, 0.0, 0.0] + assert limit > 1000 + return [ + { + "id": self.memory_ids[0], + "user_id": uuid4(), + "memory_key": "user.preference.breakfast", + "value": {"likes": "porridge"}, + "status": "active", + "source_event_ids": [str(self.event_ids[0])], + "created_at": self.base_time, + "updated_at": self.base_time, + "deleted_at": None, + "score": 1.0, + }, + { + "id": self.memory_ids[1], + "user_id": uuid4(), + "memory_key": "user.preference.lunch", + "value": {"likes": "ramen"}, + "status": "active", + "source_event_ids": [str(self.event_ids[1])], + "created_at": self.base_time + timedelta(minutes=1), + "updated_at": self.base_time + timedelta(minutes=1), + "deleted_at": None, + "score": 1.0, + }, + ] + + def list_memory_embeddings_for_config(self, embedding_config_id): + assert embedding_config_id == self.config_id + return [ + { + "id": uuid4(), + "user_id": uuid4(), + "memory_id": self.memory_ids[2], + "embedding_config_id": self.config_id, + "dimensions": 3, + "vector": [1.0, 0.0, 0.0], + "created_at": self.base_time + timedelta(minutes=2), + "updated_at": self.base_time + timedelta(minutes=2), + } + ] + + +def test_compile_memory_section_orders_limits_and_excludes_deleted() -> None: + store = SemanticCompileStoreStub() + deleted_memory = { + "id": store.memory_ids[2], + "user_id": uuid4(), + "memory_key": "user.preference.deleted", + "value": {"likes": "hidden"}, + "status": "deleted", + "source_event_ids": [str(store.event_ids[2])], + "created_at": store.base_time + timedelta(minutes=2), + "updated_at": store.base_time + timedelta(minutes=3), + "deleted_at": store.base_time + timedelta(minutes=3), + } + + memory_section = _compile_memory_section( + store, # type: ignore[arg-type] + memories=[deleted_memory], + limits=ContextCompilerLimits(max_memories=1), + semantic_retrieval=CompileContextSemanticRetrievalInput( + embedding_config_id=store.config_id, + query_vector=(1.0, 0.0, 0.0), + limit=1, + ), + ) + + assert memory_section.items == [ + { + "id": str(store.memory_ids[0]), + "memory_key": "user.preference.breakfast", + "value": {"likes": "porridge"}, + "status": "active", + "source_event_ids": [str(store.event_ids[0])], + "created_at": store.base_time.isoformat(), + "updated_at": store.base_time.isoformat(), + "source_provenance": { + "sources": ["semantic"], + "semantic_score": 1.0, + }, + } + ] + assert memory_section.summary == { + "candidate_count": 2, + "included_count": 1, + "excluded_deleted_count": 1, + "excluded_limit_count": 0, + "hybrid_retrieval": { + "requested": True, + "embedding_config_id": str(store.config_id), + "query_vector_dimensions": 3, + "semantic_limit": 1, + "symbolic_selected_count": 0, + "semantic_selected_count": 1, + "merged_candidate_count": 1, + "deduplicated_count": 0, + "included_symbolic_only_count": 0, + "included_semantic_only_count": 1, + "included_dual_source_count": 0, + "similarity_metric": "cosine_similarity", + "source_precedence": ["symbolic", "semantic"], + "symbolic_order": ["updated_at_asc", "created_at_asc", "id_asc"], + "semantic_order": 
["score_desc", "created_at_asc", "id_asc"], + }, + } + assert [decision.reason for decision in memory_section.decisions] == [ + "within_hybrid_memory_limit", + "hybrid_memory_deleted", + ] + assert memory_section.decisions[-1].metadata["selected_sources"] == ["symbolic"] diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py new file mode 100644 index 0000000..6d10d22 --- /dev/null +++ b/tests/unit/test_config.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +import pytest + +from alicebot_api.config import Settings + + +def test_settings_defaults(monkeypatch): + for key in ( + "APP_ENV", + "APP_HOST", + "APP_PORT", + "DATABASE_URL", + "DATABASE_ADMIN_URL", + "REDIS_URL", + "S3_ENDPOINT_URL", + "S3_ACCESS_KEY", + "S3_SECRET_KEY", + "S3_BUCKET", + "HEALTHCHECK_TIMEOUT_SECONDS", + "MODEL_PROVIDER", + "MODEL_BASE_URL", + "MODEL_NAME", + "MODEL_API_KEY", + "MODEL_TIMEOUT_SECONDS", + "TASK_WORKSPACE_ROOT", + ): + monkeypatch.delenv(key, raising=False) + + settings = Settings.from_env() + + assert settings.app_env == "development" + assert settings.app_port == 8000 + assert settings.database_url.endswith("/alicebot") + assert settings.database_admin_url.endswith("/alicebot") + assert settings.s3_bucket == "alicebot-local" + assert settings.model_provider == "openai_responses" + assert settings.model_base_url == "https://api.openai.com/v1" + assert settings.model_name == "gpt-5-mini" + assert settings.model_timeout_seconds == 30 + assert settings.task_workspace_root == "/tmp/alicebot/task-workspaces" + + +def test_settings_honor_environment_overrides(monkeypatch): + monkeypatch.setenv("APP_ENV", "test") + monkeypatch.setenv("APP_PORT", "8100") + monkeypatch.setenv("DATABASE_URL", "postgresql://app:secret@localhost:5432/custom") + monkeypatch.setenv("HEALTHCHECK_TIMEOUT_SECONDS", "9") + monkeypatch.setenv("MODEL_BASE_URL", "https://example.test/v1") + monkeypatch.setenv("MODEL_NAME", "gpt-5") + monkeypatch.setenv("MODEL_TIMEOUT_SECONDS", "45") + monkeypatch.setenv("TASK_WORKSPACE_ROOT", "/tmp/custom-workspaces") + + settings = Settings.from_env() + + assert settings.app_env == "test" + assert settings.app_port == 8100 + assert settings.database_url == "postgresql://app:secret@localhost:5432/custom" + assert settings.healthcheck_timeout_seconds == 9 + assert settings.model_base_url == "https://example.test/v1" + assert settings.model_name == "gpt-5" + assert settings.model_timeout_seconds == 45 + assert settings.task_workspace_root == "/tmp/custom-workspaces" + + +def test_settings_can_be_loaded_from_an_explicit_environment_mapping() -> None: + settings = Settings.from_env( + { + "APP_ENV": "test", + "APP_PORT": "8200", + "DATABASE_URL": "postgresql://app:secret@localhost:5432/mapped", + "MODEL_PROVIDER": "openai_responses", + "MODEL_NAME": "gpt-5-mini", + "TASK_WORKSPACE_ROOT": "/tmp/mapped-workspaces", + } + ) + + assert settings.app_env == "test" + assert settings.app_port == 8200 + assert settings.database_url == "postgresql://app:secret@localhost:5432/mapped" + assert settings.model_provider == "openai_responses" + assert settings.model_name == "gpt-5-mini" + assert settings.task_workspace_root == "/tmp/mapped-workspaces" + + +def test_settings_raise_clear_error_for_invalid_integer_values() -> None: + with pytest.raises(ValueError, match="APP_PORT must be an integer"): + Settings.from_env({"APP_PORT": "not-an-integer"}) + + with pytest.raises(ValueError, match="MODEL_TIMEOUT_SECONDS must be an integer"): + Settings.from_env({"MODEL_TIMEOUT_SECONDS": "not-an-integer"}) diff 
--git a/tests/unit/test_db.py b/tests/unit/test_db.py new file mode 100644 index 0000000..95559eb --- /dev/null +++ b/tests/unit/test_db.py @@ -0,0 +1,122 @@ +from __future__ import annotations + +from collections.abc import Iterator +from uuid import uuid4 + +import psycopg + +from alicebot_api import db + + +class RecordingCursor: + def __init__(self) -> None: + self.executed: list[tuple[str, tuple[object, ...] | None]] = [] + + def __enter__(self) -> "RecordingCursor": + return self + + def __exit__(self, exc_type, exc, tb) -> None: + return None + + def execute(self, query: str, params: tuple[object, ...] | None = None) -> None: + self.executed.append((query, params)) + + def fetchone(self) -> tuple[int]: + return (1,) + + +class TransactionContext: + def __init__(self) -> None: + self.entered = False + self.exited = False + + def __enter__(self) -> None: + self.entered = True + return None + + def __exit__(self, exc_type, exc, tb) -> None: + self.exited = True + return None + + +class RecordingConnection: + def __init__(self) -> None: + self.cursor_instance = RecordingCursor() + self.transaction_context = TransactionContext() + + def __enter__(self) -> "RecordingConnection": + return self + + def __exit__(self, exc_type, exc, tb) -> None: + return None + + def cursor(self) -> RecordingCursor: + return self.cursor_instance + + def transaction(self) -> TransactionContext: + return self.transaction_context + + +def test_ping_database_returns_true_when_select_succeeds(monkeypatch) -> None: + connection = RecordingConnection() + captured: dict[str, object] = {} + + def fake_connect(database_url: str, **kwargs: object) -> RecordingConnection: + captured["database_url"] = database_url + captured["kwargs"] = kwargs + return connection + + monkeypatch.setattr(db.psycopg, "connect", fake_connect) + + assert db.ping_database("postgresql://example", timeout_seconds=3) is True + assert captured["database_url"] == "postgresql://example" + assert captured["kwargs"] == {"connect_timeout": 3} + assert connection.cursor_instance.executed == [("SELECT 1", None)] + + +def test_ping_database_returns_false_on_psycopg_error(monkeypatch) -> None: + def fake_connect(_database_url: str, **_kwargs: object) -> RecordingConnection: + raise psycopg.Error("boom") + + monkeypatch.setattr(db.psycopg, "connect", fake_connect) + + assert db.ping_database("postgresql://example", timeout_seconds=3) is False + + +def test_set_current_user_sets_database_context() -> None: + connection = RecordingConnection() + user_id = uuid4() + + db.set_current_user(connection, user_id) + + assert connection.cursor_instance.executed == [ + ("SELECT set_config('app.current_user_id', %s, true)", (str(user_id),)), + ] + + +def test_user_connection_sets_current_user_inside_transaction(monkeypatch) -> None: + connection = RecordingConnection() + user_id = uuid4() + captured: dict[str, object] = {} + set_current_user_calls: list[tuple[RecordingConnection, object]] = [] + + def fake_connect(database_url: str, **kwargs: object) -> RecordingConnection: + captured["database_url"] = database_url + captured["kwargs"] = kwargs + return connection + + def fake_set_current_user(conn: RecordingConnection, current_user_id: object) -> None: + set_current_user_calls.append((conn, current_user_id)) + + monkeypatch.setattr(db.psycopg, "connect", fake_connect) + monkeypatch.setattr(db, "set_current_user", fake_set_current_user) + + with db.user_connection("postgresql://example", user_id) as conn: + assert conn is connection + assert 
connection.transaction_context.entered is True + assert connection.transaction_context.exited is False + + assert captured["database_url"] == "postgresql://example" + assert captured["kwargs"] == {"row_factory": db.dict_row} + assert set_current_user_calls == [(connection, user_id)] + assert connection.transaction_context.exited is True diff --git a/tests/unit/test_embedding.py b/tests/unit/test_embedding.py new file mode 100644 index 0000000..44401d4 --- /dev/null +++ b/tests/unit/test_embedding.py @@ -0,0 +1,437 @@ +from __future__ import annotations + +from datetime import UTC, datetime, timedelta +from uuid import UUID, uuid4 + +import psycopg +import pytest + +from alicebot_api.contracts import EmbeddingConfigCreateInput, MemoryEmbeddingUpsertInput +from alicebot_api.embedding import ( + EmbeddingConfigValidationError, + MemoryEmbeddingNotFoundError, + MemoryEmbeddingValidationError, + create_embedding_config_record, + get_memory_embedding_record, + list_embedding_config_records, + list_memory_embedding_records, + upsert_memory_embedding_record, +) + + +class EmbeddingStoreStub: + def __init__(self) -> None: + self.base_time = datetime(2026, 3, 12, 9, 0, tzinfo=UTC) + self.memories: dict[UUID, dict[str, object]] = {} + self.configs: list[dict[str, object]] = [] + self.config_by_id: dict[UUID, dict[str, object]] = {} + self.embeddings: list[dict[str, object]] = [] + self.embedding_by_id: dict[UUID, dict[str, object]] = {} + + def create_embedding_config( + self, + *, + provider: str, + model: str, + version: str, + dimensions: int, + status: str, + metadata: dict[str, object], + ) -> dict[str, object]: + config_id = uuid4() + record = { + "id": config_id, + "user_id": uuid4(), + "provider": provider, + "model": model, + "version": version, + "dimensions": dimensions, + "status": status, + "metadata": metadata, + "created_at": self.base_time + timedelta(minutes=len(self.configs)), + } + self.configs.append(record) + self.config_by_id[config_id] = record + return record + + def list_embedding_configs(self) -> list[dict[str, object]]: + return list(self.configs) + + def get_embedding_config_optional(self, embedding_config_id: UUID) -> dict[str, object] | None: + return self.config_by_id.get(embedding_config_id) + + def get_embedding_config_by_identity_optional( + self, + *, + provider: str, + model: str, + version: str, + ) -> dict[str, object] | None: + for config in self.configs: + if ( + config["provider"] == provider + and config["model"] == model + and config["version"] == version + ): + return config + return None + + def get_memory_optional(self, memory_id: UUID) -> dict[str, object] | None: + return self.memories.get(memory_id) + + def get_memory_embedding_by_memory_and_config_optional( + self, + *, + memory_id: UUID, + embedding_config_id: UUID, + ) -> dict[str, object] | None: + for embedding in self.embeddings: + if ( + embedding["memory_id"] == memory_id + and embedding["embedding_config_id"] == embedding_config_id + ): + return embedding + return None + + def create_memory_embedding( + self, + *, + memory_id: UUID, + embedding_config_id: UUID, + dimensions: int, + vector: list[float], + ) -> dict[str, object]: + embedding_id = uuid4() + record = { + "id": embedding_id, + "user_id": uuid4(), + "memory_id": memory_id, + "embedding_config_id": embedding_config_id, + "dimensions": dimensions, + "vector": vector, + "created_at": self.base_time + timedelta(minutes=len(self.embeddings)), + "updated_at": self.base_time + timedelta(minutes=len(self.embeddings)), + } + 
self.embeddings.append(record) + self.embedding_by_id[embedding_id] = record + return record + + def update_memory_embedding( + self, + *, + memory_embedding_id: UUID, + dimensions: int, + vector: list[float], + ) -> dict[str, object]: + record = self.embedding_by_id[memory_embedding_id] + updated = { + **record, + "dimensions": dimensions, + "vector": vector, + "updated_at": self.base_time + timedelta(minutes=10), + } + self.embedding_by_id[memory_embedding_id] = updated + for index, existing in enumerate(self.embeddings): + if existing["id"] == memory_embedding_id: + self.embeddings[index] = updated + return updated + + def get_memory_embedding_optional(self, memory_embedding_id: UUID) -> dict[str, object] | None: + return self.embedding_by_id.get(memory_embedding_id) + + def list_memory_embeddings_for_memory(self, memory_id: UUID) -> list[dict[str, object]]: + return [embedding for embedding in self.embeddings if embedding["memory_id"] == memory_id] + + +def seed_memory(store: EmbeddingStoreStub) -> UUID: + memory_id = uuid4() + store.memories[memory_id] = { + "id": memory_id, + "memory_key": "user.preference.coffee", + } + return memory_id + + +def seed_config(store: EmbeddingStoreStub, *, dimensions: int = 3) -> UUID: + created = store.create_embedding_config( + provider="openai", + model="text-embedding-3-large", + version="2026-03-12", + dimensions=dimensions, + status="active", + metadata={"task": "memory_retrieval"}, + ) + return created["id"] # type: ignore[return-value] + + +def test_create_and_list_embedding_configs_return_deterministic_shape() -> None: + store = EmbeddingStoreStub() + first = create_embedding_config_record( + store, # type: ignore[arg-type] + user_id=uuid4(), + config=EmbeddingConfigCreateInput( + provider="openai", + model="text-embedding-3-small", + version="2026-03-11", + dimensions=1536, + status="active", + metadata={"task": "memory_retrieval"}, + ), + ) + second = create_embedding_config_record( + store, # type: ignore[arg-type] + user_id=uuid4(), + config=EmbeddingConfigCreateInput( + provider="openai", + model="text-embedding-3-large", + version="2026-03-12", + dimensions=3072, + status="deprecated", + metadata={"task": "memory_retrieval"}, + ), + ) + + payload = list_embedding_config_records( + store, # type: ignore[arg-type] + user_id=uuid4(), + ) + + assert first["embedding_config"]["provider"] == "openai" + assert second["embedding_config"]["status"] == "deprecated" + assert payload == { + "items": [ + first["embedding_config"], + second["embedding_config"], + ], + "summary": { + "total_count": 2, + "order": ["created_at_asc", "id_asc"], + }, + } + + +def test_create_embedding_config_rejects_duplicate_provider_model_version() -> None: + store = EmbeddingStoreStub() + create_embedding_config_record( + store, # type: ignore[arg-type] + user_id=uuid4(), + config=EmbeddingConfigCreateInput( + provider="openai", + model="text-embedding-3-large", + version="2026-03-12", + dimensions=3072, + status="active", + metadata={"task": "memory_retrieval"}, + ), + ) + + with pytest.raises( + EmbeddingConfigValidationError, + match="embedding config already exists for provider/model/version under the user scope", + ): + create_embedding_config_record( + store, # type: ignore[arg-type] + user_id=uuid4(), + config=EmbeddingConfigCreateInput( + provider="openai", + model="text-embedding-3-large", + version="2026-03-12", + dimensions=3072, + status="active", + metadata={"task": "memory_retrieval"}, + ), + ) + + +def 
test_create_embedding_config_translates_database_unique_violation_into_validation_error() -> None: + class DuplicateConfigStoreStub(EmbeddingStoreStub): + def get_embedding_config_by_identity_optional( + self, + *, + provider: str, + model: str, + version: str, + ) -> dict[str, object] | None: + return None + + def create_embedding_config( + self, + *, + provider: str, + model: str, + version: str, + dimensions: int, + status: str, + metadata: dict[str, object], + ) -> dict[str, object]: + raise psycopg.errors.UniqueViolation("duplicate key value violates unique constraint") + + with pytest.raises( + EmbeddingConfigValidationError, + match="embedding config already exists for provider/model/version under the user scope", + ): + create_embedding_config_record( + DuplicateConfigStoreStub(), # type: ignore[arg-type] + user_id=uuid4(), + config=EmbeddingConfigCreateInput( + provider="openai", + model="text-embedding-3-large", + version="2026-03-12", + dimensions=3072, + status="active", + metadata={"task": "memory_retrieval"}, + ), + ) + + +def test_upsert_memory_embedding_creates_then_updates_existing_record() -> None: + store = EmbeddingStoreStub() + memory_id = seed_memory(store) + config_id = seed_config(store, dimensions=3) + + created = upsert_memory_embedding_record( + store, # type: ignore[arg-type] + user_id=uuid4(), + request=MemoryEmbeddingUpsertInput( + memory_id=memory_id, + embedding_config_id=config_id, + vector=(0.1, 0.2, 0.3), + ), + ) + updated = upsert_memory_embedding_record( + store, # type: ignore[arg-type] + user_id=uuid4(), + request=MemoryEmbeddingUpsertInput( + memory_id=memory_id, + embedding_config_id=config_id, + vector=(0.3, 0.2, 0.1), + ), + ) + + assert created["write_mode"] == "created" + assert created["embedding"]["vector"] == [0.1, 0.2, 0.3] + assert updated["write_mode"] == "updated" + assert updated["embedding"]["id"] == created["embedding"]["id"] + assert updated["embedding"]["vector"] == [0.3, 0.2, 0.1] + + +def test_upsert_memory_embedding_rejects_missing_memory() -> None: + store = EmbeddingStoreStub() + config_id = seed_config(store) + + with pytest.raises( + MemoryEmbeddingValidationError, + match="memory_id must reference an existing memory owned by the user", + ): + upsert_memory_embedding_record( + store, # type: ignore[arg-type] + user_id=uuid4(), + request=MemoryEmbeddingUpsertInput( + memory_id=uuid4(), + embedding_config_id=config_id, + vector=(0.1, 0.2, 0.3), + ), + ) + + +def test_upsert_memory_embedding_rejects_missing_embedding_config() -> None: + store = EmbeddingStoreStub() + memory_id = seed_memory(store) + + with pytest.raises( + MemoryEmbeddingValidationError, + match="embedding_config_id must reference an existing embedding config owned by the user", + ): + upsert_memory_embedding_record( + store, # type: ignore[arg-type] + user_id=uuid4(), + request=MemoryEmbeddingUpsertInput( + memory_id=memory_id, + embedding_config_id=uuid4(), + vector=(0.1, 0.2, 0.3), + ), + ) + + +def test_upsert_memory_embedding_rejects_dimension_mismatch_and_non_finite_values() -> None: + store = EmbeddingStoreStub() + memory_id = seed_memory(store) + config_id = seed_config(store, dimensions=2) + + with pytest.raises( + MemoryEmbeddingValidationError, + match="vector length must match embedding config dimensions", + ): + upsert_memory_embedding_record( + store, # type: ignore[arg-type] + user_id=uuid4(), + request=MemoryEmbeddingUpsertInput( + memory_id=memory_id, + embedding_config_id=config_id, + vector=(0.1, 0.2, 0.3), + ), + ) + + with pytest.raises( + 
MemoryEmbeddingValidationError, + match="vector must contain only finite numeric values", + ): + upsert_memory_embedding_record( + store, # type: ignore[arg-type] + user_id=uuid4(), + request=MemoryEmbeddingUpsertInput( + memory_id=memory_id, + embedding_config_id=config_id, + vector=(0.1, float("inf")), + ), + ) + + +def test_memory_embedding_reads_return_deterministic_shape_and_not_found() -> None: + store = EmbeddingStoreStub() + memory_id = seed_memory(store) + config_id = seed_config(store, dimensions=3) + created = upsert_memory_embedding_record( + store, # type: ignore[arg-type] + user_id=uuid4(), + request=MemoryEmbeddingUpsertInput( + memory_id=memory_id, + embedding_config_id=config_id, + vector=(0.1, 0.2, 0.3), + ), + ) + + listed = list_memory_embedding_records( + store, # type: ignore[arg-type] + user_id=uuid4(), + memory_id=memory_id, + ) + detail = get_memory_embedding_record( + store, # type: ignore[arg-type] + user_id=uuid4(), + memory_embedding_id=UUID(created["embedding"]["id"]), + ) + + assert listed == { + "items": [created["embedding"]], + "summary": { + "memory_id": str(memory_id), + "total_count": 1, + "order": ["created_at_asc", "id_asc"], + }, + } + assert detail == {"embedding": created["embedding"]} + + with pytest.raises(MemoryEmbeddingNotFoundError, match="memory .* was not found"): + list_memory_embedding_records( + store, # type: ignore[arg-type] + user_id=uuid4(), + memory_id=uuid4(), + ) + + with pytest.raises(MemoryEmbeddingNotFoundError, match="memory embedding .* was not found"): + get_memory_embedding_record( + store, # type: ignore[arg-type] + user_id=uuid4(), + memory_embedding_id=uuid4(), + ) diff --git a/tests/unit/test_embedding_store.py b/tests/unit/test_embedding_store.py new file mode 100644 index 0000000..5a2b695 --- /dev/null +++ b/tests/unit/test_embedding_store.py @@ -0,0 +1,223 @@ +from __future__ import annotations + +from datetime import UTC, datetime +from typing import Any +from uuid import uuid4 + +from psycopg.types.json import Jsonb + +from alicebot_api.store import ContinuityStore + + +class RecordingCursor: + def __init__( + self, + fetchone_results: list[dict[str, Any]], + fetchall_results: list[list[dict[str, Any]]] | None = None, + ) -> None: + self.executed: list[tuple[str, tuple[object, ...] | None]] = [] + self.fetchone_results = list(fetchone_results) + self.fetchall_results = list(fetchall_results or []) + + def __enter__(self) -> "RecordingCursor": + return self + + def __exit__(self, exc_type, exc, tb) -> None: + return None + + def execute(self, query: str, params: tuple[object, ...] 
| None = None) -> None: + self.executed.append((query, params)) + + def fetchone(self) -> dict[str, Any] | None: + if not self.fetchone_results: + return None + return self.fetchone_results.pop(0) + + def fetchall(self) -> list[dict[str, Any]]: + if not self.fetchall_results: + return [] + return self.fetchall_results.pop(0) + + +class RecordingConnection: + def __init__(self, cursor: RecordingCursor) -> None: + self.cursor_instance = cursor + + def cursor(self) -> RecordingCursor: + return self.cursor_instance + + +def test_embedding_store_methods_use_expected_queries_and_serialization() -> None: + config_id = uuid4() + memory_id = uuid4() + embedding_id = uuid4() + created_at = datetime(2026, 3, 12, 9, 0, tzinfo=UTC) + updated_at = datetime(2026, 3, 12, 9, 5, tzinfo=UTC) + cursor = RecordingCursor( + fetchone_results=[ + { + "id": config_id, + "user_id": uuid4(), + "provider": "openai", + "model": "text-embedding-3-large", + "version": "2026-03-12", + "dimensions": 3, + "status": "active", + "metadata": {"task": "memory_retrieval"}, + "created_at": created_at, + }, + { + "id": embedding_id, + "user_id": uuid4(), + "memory_id": memory_id, + "embedding_config_id": config_id, + "dimensions": 3, + "vector": [0.1, 0.2, 0.3], + "created_at": created_at, + "updated_at": created_at, + }, + { + "id": embedding_id, + "user_id": uuid4(), + "memory_id": memory_id, + "embedding_config_id": config_id, + "dimensions": 3, + "vector": [0.3, 0.2, 0.1], + "created_at": created_at, + "updated_at": updated_at, + }, + ], + fetchall_results=[ + [ + { + "id": config_id, + "provider": "openai", + "version": "2026-03-12", + } + ], + [ + { + "id": embedding_id, + "memory_id": memory_id, + "embedding_config_id": config_id, + } + ], + [ + { + "id": memory_id, + "user_id": uuid4(), + "memory_key": "user.preference.coffee", + "value": {"likes": "oat milk"}, + "status": "active", + "source_event_ids": [str(uuid4())], + "created_at": created_at, + "updated_at": updated_at, + "deleted_at": None, + "score": 1.0, + } + ], + ], + ) + store = ContinuityStore(RecordingConnection(cursor)) + + created_config = store.create_embedding_config( + provider="openai", + model="text-embedding-3-large", + version="2026-03-12", + dimensions=3, + status="active", + metadata={"task": "memory_retrieval"}, + ) + listed_configs = store.list_embedding_configs() + created_embedding = store.create_memory_embedding( + memory_id=memory_id, + embedding_config_id=config_id, + dimensions=3, + vector=[0.1, 0.2, 0.3], + ) + updated_embedding = store.update_memory_embedding( + memory_embedding_id=embedding_id, + dimensions=3, + vector=[0.3, 0.2, 0.1], + ) + listed_embeddings = store.list_memory_embeddings_for_memory(memory_id) + retrieval_matches = store.retrieve_semantic_memory_matches( + embedding_config_id=config_id, + query_vector=[0.1, 0.2, 0.3], + limit=5, + ) + + assert created_config["id"] == config_id + assert listed_configs == [{"id": config_id, "provider": "openai", "version": "2026-03-12"}] + assert created_embedding["id"] == embedding_id + assert updated_embedding["updated_at"] == updated_at + assert listed_embeddings == [ + {"id": embedding_id, "memory_id": memory_id, "embedding_config_id": config_id} + ] + assert len(retrieval_matches) == 1 + assert retrieval_matches[0]["id"] == memory_id + assert retrieval_matches[0]["memory_key"] == "user.preference.coffee" + assert retrieval_matches[0]["status"] == "active" + assert retrieval_matches[0]["score"] == 1.0 + + create_config_query, create_config_params = cursor.executed[0] + assert "INSERT 
INTO embedding_configs" in create_config_query + assert create_config_params is not None + assert create_config_params[:5] == ( + "openai", + "text-embedding-3-large", + "2026-03-12", + 3, + "active", + ) + assert isinstance(create_config_params[5], Jsonb) + assert create_config_params[5].obj == {"task": "memory_retrieval"} + + list_config_query, list_config_params = cursor.executed[1] + assert "FROM embedding_configs" in list_config_query + assert "ORDER BY created_at ASC, id ASC" in list_config_query + assert list_config_params is None + + create_embedding_query, create_embedding_params = cursor.executed[2] + assert "INSERT INTO memory_embeddings" in create_embedding_query + assert create_embedding_params is not None + assert create_embedding_params[:3] == (memory_id, config_id, 3) + assert isinstance(create_embedding_params[3], Jsonb) + assert create_embedding_params[3].obj == [0.1, 0.2, 0.3] + + update_embedding_query, update_embedding_params = cursor.executed[3] + assert "UPDATE memory_embeddings" in update_embedding_query + assert update_embedding_params is not None + assert update_embedding_params[0] == 3 + assert isinstance(update_embedding_params[1], Jsonb) + assert update_embedding_params[1].obj == [0.3, 0.2, 0.1] + assert update_embedding_params[2] == embedding_id + + list_embedding_query, list_embedding_params = cursor.executed[4] + assert "FROM memory_embeddings" in list_embedding_query + assert "ORDER BY created_at ASC, id ASC" in list_embedding_query + assert list_embedding_params == (memory_id,) + + retrieval_query, retrieval_params = cursor.executed[5] + assert "replace(memory_embeddings.vector::text, ' ', '')::vector <=> %s::vector" in retrieval_query + assert "JOIN memories" in retrieval_query + assert "memories.status = 'active'" in retrieval_query + assert "ORDER BY score DESC, memories.created_at ASC, memories.id ASC" in retrieval_query + assert retrieval_params == ("[0.1,0.2,0.3]", config_id, 3, 5) + + +def test_embedding_store_optional_reads_return_none_when_row_is_missing() -> None: + cursor = RecordingCursor(fetchone_results=[]) + store = ContinuityStore(RecordingConnection(cursor)) + + assert store.get_embedding_config_optional(uuid4()) is None + assert store.get_embedding_config_by_identity_optional( + provider="openai", + model="text-embedding-3-large", + version="2026-03-12", + ) is None + assert store.get_memory_embedding_optional(uuid4()) is None + assert store.get_memory_embedding_by_memory_and_config_optional( + memory_id=uuid4(), + embedding_config_id=uuid4(), + ) is None diff --git a/tests/unit/test_entity.py b/tests/unit/test_entity.py new file mode 100644 index 0000000..c417b55 --- /dev/null +++ b/tests/unit/test_entity.py @@ -0,0 +1,170 @@ +from __future__ import annotations + +from datetime import UTC, datetime, timedelta +from uuid import UUID, uuid4 + +import pytest + +from alicebot_api.contracts import EntityCreateInput +from alicebot_api.entity import ( + EntityNotFoundError, + EntityValidationError, + create_entity_record, + get_entity_record, + list_entity_records, +) + + +class EntityStoreStub: + def __init__(self) -> None: + self.base_time = datetime(2026, 3, 12, 9, 0, tzinfo=UTC) + self.memories: dict[UUID, dict[str, object]] = {} + self.created_entities: list[dict[str, object]] = [] + self.entity_by_id: dict[UUID, dict[str, object]] = {} + + def list_memories_by_ids(self, memory_ids: list[UUID]) -> list[dict[str, object]]: + return [self.memories[memory_id] for memory_id in memory_ids if memory_id in self.memories] + + def create_entity( + 
self, + *, + entity_type: str, + name: str, + source_memory_ids: list[str], + ) -> dict[str, object]: + entity_id = uuid4() + entity = { + "id": entity_id, + "user_id": uuid4(), + "entity_type": entity_type, + "name": name, + "source_memory_ids": source_memory_ids, + "created_at": self.base_time + timedelta(minutes=len(self.created_entities)), + } + self.created_entities.append(entity) + self.entity_by_id[entity_id] = entity + return entity + + def list_entities(self) -> list[dict[str, object]]: + return list(self.created_entities) + + def get_entity_optional(self, entity_id: UUID) -> dict[str, object] | None: + return self.entity_by_id.get(entity_id) + + +def seed_memory(store: EntityStoreStub) -> UUID: + memory_id = uuid4() + store.memories[memory_id] = { + "id": memory_id, + "memory_key": "user.preference.coffee", + } + return memory_id + + +def test_create_entity_record_rejects_empty_source_memory_ids() -> None: + store = EntityStoreStub() + + with pytest.raises( + EntityValidationError, + match="source_memory_ids must include at least one existing memory owned by the user", + ): + create_entity_record( + store, # type: ignore[arg-type] + user_id=uuid4(), + entity=EntityCreateInput( + entity_type="person", + name="Samir", + source_memory_ids=(), + ), + ) + + +def test_create_entity_record_rejects_missing_source_memories() -> None: + store = EntityStoreStub() + + with pytest.raises( + EntityValidationError, + match="source_memory_ids must all reference existing memories owned by the user", + ): + create_entity_record( + store, # type: ignore[arg-type] + user_id=uuid4(), + entity=EntityCreateInput( + entity_type="project", + name="AliceBot", + source_memory_ids=(uuid4(),), + ), + ) + + +def test_create_entity_record_creates_entity_with_deduped_source_memories() -> None: + store = EntityStoreStub() + first_memory_id = seed_memory(store) + second_memory_id = seed_memory(store) + + payload = create_entity_record( + store, # type: ignore[arg-type] + user_id=uuid4(), + entity=EntityCreateInput( + entity_type="project", + name="AliceBot", + source_memory_ids=(first_memory_id, first_memory_id, second_memory_id), + ), + ) + + assert payload["entity"]["entity_type"] == "project" + assert payload["entity"]["name"] == "AliceBot" + assert payload["entity"]["source_memory_ids"] == [str(first_memory_id), str(second_memory_id)] + + +def test_list_entity_records_returns_deterministic_shape() -> None: + store = EntityStoreStub() + first_memory_id = seed_memory(store) + second_memory_id = seed_memory(store) + first_entity = store.create_entity( + entity_type="person", + name="Samir", + source_memory_ids=[str(first_memory_id)], + ) + second_entity = store.create_entity( + entity_type="project", + name="AliceBot", + source_memory_ids=[str(second_memory_id)], + ) + + payload = list_entity_records( + store, # type: ignore[arg-type] + user_id=uuid4(), + ) + + assert payload == { + "items": [ + { + "id": str(first_entity["id"]), + "entity_type": "person", + "name": "Samir", + "source_memory_ids": [str(first_memory_id)], + "created_at": first_entity["created_at"].isoformat(), + }, + { + "id": str(second_entity["id"]), + "entity_type": "project", + "name": "AliceBot", + "source_memory_ids": [str(second_memory_id)], + "created_at": second_entity["created_at"].isoformat(), + }, + ], + "summary": { + "total_count": 2, + "order": ["created_at_asc", "id_asc"], + }, + } + + +def test_get_entity_record_raises_not_found_for_inaccessible_entity() -> None: + with pytest.raises(EntityNotFoundError, match="entity .* was not 
found"): + get_entity_record( + EntityStoreStub(), # type: ignore[arg-type] + user_id=uuid4(), + entity_id=uuid4(), + ) diff --git a/tests/unit/test_entity_edge.py b/tests/unit/test_entity_edge.py new file mode 100644 index 0000000..d30f376 --- /dev/null +++ b/tests/unit/test_entity_edge.py @@ -0,0 +1,231 @@ +from __future__ import annotations + +from datetime import UTC, datetime, timedelta +from uuid import UUID, uuid4 + +import pytest + +from alicebot_api.contracts import EntityEdgeCreateInput +from alicebot_api.entity import EntityNotFoundError +from alicebot_api.entity_edge import ( + EntityEdgeValidationError, + create_entity_edge_record, + list_entity_edge_records, +) + + +class EntityEdgeStoreStub: + def __init__(self) -> None: + self.base_time = datetime(2026, 3, 12, 9, 0, tzinfo=UTC) + self.memories: dict[UUID, dict[str, object]] = {} + self.entities: dict[UUID, dict[str, object]] = {} + self.created_edges: list[dict[str, object]] = [] + + def list_memories_by_ids(self, memory_ids: list[UUID]) -> list[dict[str, object]]: + return [self.memories[memory_id] for memory_id in memory_ids if memory_id in self.memories] + + def get_entity_optional(self, entity_id: UUID) -> dict[str, object] | None: + return self.entities.get(entity_id) + + def create_entity_edge( + self, + *, + from_entity_id: UUID, + to_entity_id: UUID, + relationship_type: str, + valid_from: datetime | None, + valid_to: datetime | None, + source_memory_ids: list[str], + ) -> dict[str, object]: + edge_id = uuid4() + edge = { + "id": edge_id, + "user_id": uuid4(), + "from_entity_id": from_entity_id, + "to_entity_id": to_entity_id, + "relationship_type": relationship_type, + "valid_from": valid_from, + "valid_to": valid_to, + "source_memory_ids": source_memory_ids, + "created_at": self.base_time + timedelta(minutes=len(self.created_edges)), + } + self.created_edges.append(edge) + return edge + + def list_entity_edges_for_entity(self, entity_id: UUID) -> list[dict[str, object]]: + return [ + edge + for edge in self.created_edges + if edge["from_entity_id"] == entity_id or edge["to_entity_id"] == entity_id + ] + + +def seed_memory(store: EntityEdgeStoreStub) -> UUID: + memory_id = uuid4() + store.memories[memory_id] = { + "id": memory_id, + "memory_key": "user.project.current", + } + return memory_id + + +def seed_entity(store: EntityEdgeStoreStub) -> UUID: + entity_id = uuid4() + store.entities[entity_id] = { + "id": entity_id, + "name": "entity", + } + return entity_id + + +def test_create_entity_edge_record_rejects_missing_entities() -> None: + store = EntityEdgeStoreStub() + memory_id = seed_memory(store) + + with pytest.raises( + EntityEdgeValidationError, + match="from_entity_id must reference an existing entity owned by the user", + ): + create_entity_edge_record( + store, # type: ignore[arg-type] + user_id=uuid4(), + edge=EntityEdgeCreateInput( + from_entity_id=uuid4(), + to_entity_id=uuid4(), + relationship_type="works_on", + valid_from=None, + valid_to=None, + source_memory_ids=(memory_id,), + ), + ) + + +def test_create_entity_edge_record_rejects_invalid_temporal_range() -> None: + store = EntityEdgeStoreStub() + from_entity_id = seed_entity(store) + to_entity_id = seed_entity(store) + memory_id = seed_memory(store) + + with pytest.raises( + EntityEdgeValidationError, + match="valid_to must be greater than or equal to valid_from", + ): + create_entity_edge_record( + store, # type: ignore[arg-type] + user_id=uuid4(), + edge=EntityEdgeCreateInput( + from_entity_id=from_entity_id, + to_entity_id=to_entity_id, + 
relationship_type="works_on", + valid_from=datetime(2026, 3, 12, 11, 0, tzinfo=UTC), + valid_to=datetime(2026, 3, 12, 10, 0, tzinfo=UTC), + source_memory_ids=(memory_id,), + ), + ) + + +def test_create_entity_edge_record_creates_edge_with_deduped_source_memories() -> None: + store = EntityEdgeStoreStub() + from_entity_id = seed_entity(store) + to_entity_id = seed_entity(store) + first_memory_id = seed_memory(store) + second_memory_id = seed_memory(store) + valid_from = datetime(2026, 3, 12, 9, 0, tzinfo=UTC) + valid_to = datetime(2026, 3, 12, 10, 0, tzinfo=UTC) + + payload = create_entity_edge_record( + store, # type: ignore[arg-type] + user_id=uuid4(), + edge=EntityEdgeCreateInput( + from_entity_id=from_entity_id, + to_entity_id=to_entity_id, + relationship_type="works_on", + valid_from=valid_from, + valid_to=valid_to, + source_memory_ids=(first_memory_id, first_memory_id, second_memory_id), + ), + ) + + assert payload == { + "edge": { + "id": payload["edge"]["id"], + "from_entity_id": str(from_entity_id), + "to_entity_id": str(to_entity_id), + "relationship_type": "works_on", + "valid_from": valid_from.isoformat(), + "valid_to": valid_to.isoformat(), + "source_memory_ids": [str(first_memory_id), str(second_memory_id)], + "created_at": store.created_edges[0]["created_at"].isoformat(), + } + } + + +def test_list_entity_edge_records_returns_deterministic_shape() -> None: + store = EntityEdgeStoreStub() + primary_entity_id = seed_entity(store) + secondary_entity_id = seed_entity(store) + tertiary_entity_id = seed_entity(store) + first_memory_id = seed_memory(store) + second_memory_id = seed_memory(store) + + first_edge = store.create_entity_edge( + from_entity_id=primary_entity_id, + to_entity_id=secondary_entity_id, + relationship_type="works_on", + valid_from=None, + valid_to=None, + source_memory_ids=[str(first_memory_id)], + ) + second_edge = store.create_entity_edge( + from_entity_id=tertiary_entity_id, + to_entity_id=primary_entity_id, + relationship_type="references", + valid_from=None, + valid_to=None, + source_memory_ids=[str(second_memory_id)], + ) + + payload = list_entity_edge_records( + store, # type: ignore[arg-type] + user_id=uuid4(), + entity_id=primary_entity_id, + ) + + assert payload == { + "items": [ + { + "id": str(first_edge["id"]), + "from_entity_id": str(primary_entity_id), + "to_entity_id": str(secondary_entity_id), + "relationship_type": "works_on", + "valid_from": None, + "valid_to": None, + "source_memory_ids": [str(first_memory_id)], + "created_at": first_edge["created_at"].isoformat(), + }, + { + "id": str(second_edge["id"]), + "from_entity_id": str(tertiary_entity_id), + "to_entity_id": str(primary_entity_id), + "relationship_type": "references", + "valid_from": None, + "valid_to": None, + "source_memory_ids": [str(second_memory_id)], + "created_at": second_edge["created_at"].isoformat(), + }, + ], + "summary": { + "entity_id": str(primary_entity_id), + "total_count": 2, + "order": ["created_at_asc", "id_asc"], + }, + } + + +def test_list_entity_edge_records_raises_not_found_for_inaccessible_entity() -> None: + with pytest.raises(EntityNotFoundError, match="entity .* was not found"): + list_entity_edge_records( + EntityEdgeStoreStub(), # type: ignore[arg-type] + user_id=uuid4(), + entity_id=uuid4(), + ) diff --git a/tests/unit/test_entity_store.py b/tests/unit/test_entity_store.py new file mode 100644 index 0000000..7b377ca --- /dev/null +++ b/tests/unit/test_entity_store.py @@ -0,0 +1,207 @@ +from __future__ import annotations + +from datetime import UTC, 
datetime +from typing import Any +from uuid import uuid4 + +from psycopg.types.json import Jsonb + +from alicebot_api.store import ContinuityStore + + +class RecordingCursor: + def __init__( + self, + fetchone_results: list[dict[str, Any]], + fetchall_results: list[list[dict[str, Any]]] | None = None, + ) -> None: + self.executed: list[tuple[str, tuple[object, ...] | None]] = [] + self.fetchone_results = list(fetchone_results) + self.fetchall_results = list(fetchall_results or []) + + def __enter__(self) -> "RecordingCursor": + return self + + def __exit__(self, exc_type, exc, tb) -> None: + return None + + def execute(self, query: str, params: tuple[object, ...] | None = None) -> None: + self.executed.append((query, params)) + + def fetchone(self) -> dict[str, Any] | None: + if not self.fetchone_results: + return None + return self.fetchone_results.pop(0) + + def fetchall(self) -> list[dict[str, Any]]: + if not self.fetchall_results: + return [] + return self.fetchall_results.pop(0) + + +class RecordingConnection: + def __init__(self, cursor: RecordingCursor) -> None: + self.cursor_instance = cursor + + def cursor(self) -> RecordingCursor: + return self.cursor_instance + + +def test_entity_methods_use_expected_queries_and_deterministic_order() -> None: + entity_id = uuid4() + first_memory_id = uuid4() + second_memory_id = uuid4() + cursor = RecordingCursor( + fetchone_results=[ + { + "id": entity_id, + "user_id": uuid4(), + "entity_type": "project", + "name": "AliceBot", + "source_memory_ids": [str(first_memory_id), str(second_memory_id)], + "created_at": "ignored", + } + ], + fetchall_results=[ + [{"id": first_memory_id}, {"id": second_memory_id}], + [{"id": entity_id, "name": "AliceBot"}], + ], + ) + store = ContinuityStore(RecordingConnection(cursor)) + + created = store.create_entity( + entity_type="project", + name="AliceBot", + source_memory_ids=[str(first_memory_id), str(second_memory_id)], + ) + listed_memories = store.list_memories_by_ids([first_memory_id, second_memory_id]) + listed_entities = store.list_entities() + + assert created["id"] == entity_id + assert listed_memories == [{"id": first_memory_id}, {"id": second_memory_id}] + assert listed_entities == [{"id": entity_id, "name": "AliceBot"}] + + create_query, create_params = cursor.executed[0] + assert "INSERT INTO entities" in create_query + assert create_params is not None + assert create_params[0] == "project" + assert create_params[1] == "AliceBot" + assert isinstance(create_params[2], Jsonb) + assert create_params[2].obj == [str(first_memory_id), str(second_memory_id)] + + assert cursor.executed[1] == ( + """ + SELECT id, user_id, memory_key, value, status, source_event_ids, created_at, updated_at, deleted_at + FROM memories + WHERE id = ANY(%s) + ORDER BY created_at ASC, id ASC + """, + ([first_memory_id, second_memory_id],), + ) + assert cursor.executed[2] == ( + """ + SELECT id, user_id, entity_type, name, source_memory_ids, created_at + FROM entities + ORDER BY created_at ASC, id ASC + """, + None, + ) + + +def test_get_entity_optional_returns_none_when_row_is_missing() -> None: + cursor = RecordingCursor(fetchone_results=[]) + store = ContinuityStore(RecordingConnection(cursor)) + + assert store.get_entity_optional(uuid4()) is None + + +def test_entity_edge_methods_use_expected_queries_and_deterministic_order() -> None: + edge_id = uuid4() + from_entity_id = uuid4() + to_entity_id = uuid4() + related_entity_id = uuid4() + source_memory_id = uuid4() + valid_from = datetime(2026, 3, 12, 10, 0, tzinfo=UTC) + cursor 
= RecordingCursor( + fetchone_results=[ + { + "id": edge_id, + "user_id": uuid4(), + "from_entity_id": from_entity_id, + "to_entity_id": to_entity_id, + "relationship_type": "works_on", + "valid_from": valid_from, + "valid_to": None, + "source_memory_ids": [str(source_memory_id)], + "created_at": "ignored", + } + ], + fetchall_results=[ + [{"id": edge_id, "relationship_type": "works_on"}], + [{"id": edge_id, "relationship_type": "works_on"}], + ], + ) + store = ContinuityStore(RecordingConnection(cursor)) + + created = store.create_entity_edge( + from_entity_id=from_entity_id, + to_entity_id=to_entity_id, + relationship_type="works_on", + valid_from=valid_from, + valid_to=None, + source_memory_ids=[str(source_memory_id)], + ) + listed_edges = store.list_entity_edges_for_entity(from_entity_id) + listed_edges_for_entities = store.list_entity_edges_for_entities([from_entity_id, related_entity_id]) + + assert created["id"] == edge_id + assert listed_edges == [{"id": edge_id, "relationship_type": "works_on"}] + assert listed_edges_for_entities == [{"id": edge_id, "relationship_type": "works_on"}] + + create_query, create_params = cursor.executed[0] + assert "INSERT INTO entity_edges" in create_query + assert create_params is not None + assert create_params[0] == from_entity_id + assert create_params[1] == to_entity_id + assert create_params[2] == "works_on" + assert create_params[3] == valid_from + assert create_params[4] is None + assert isinstance(create_params[5], Jsonb) + assert create_params[5].obj == [str(source_memory_id)] + + assert cursor.executed[1] == ( + """ + SELECT + id, + user_id, + from_entity_id, + to_entity_id, + relationship_type, + valid_from, + valid_to, + source_memory_ids, + created_at + FROM entity_edges + WHERE from_entity_id = %s OR to_entity_id = %s + ORDER BY created_at ASC, id ASC + """, + (from_entity_id, from_entity_id), + ) + assert cursor.executed[2] == ( + """ + SELECT + id, + user_id, + from_entity_id, + to_entity_id, + relationship_type, + valid_from, + valid_to, + source_memory_ids, + created_at + FROM entity_edges + WHERE from_entity_id = ANY(%s) OR to_entity_id = ANY(%s) + ORDER BY created_at ASC, id ASC + """, + ([from_entity_id, related_entity_id], [from_entity_id, related_entity_id]), + ) diff --git a/tests/unit/test_env.py b/tests/unit/test_env.py new file mode 100644 index 0000000..b0fdb49 --- /dev/null +++ b/tests/unit/test_env.py @@ -0,0 +1,178 @@ +from __future__ import annotations + +from contextlib import contextmanager +import importlib +import sys +from typing import Any + + +MODULE_NAME = "apps.api.alembic.env" + + +class FakeAlembicConfig: + def __init__(self, sqlalchemy_url: str, section: dict[str, Any] | None = None) -> None: + self.config_file_name = "alembic.ini" + self.config_ini_section = "alembic" + self.sqlalchemy_url = sqlalchemy_url + self.section = section or {} + + def get_main_option(self, option: str) -> str: + assert option == "sqlalchemy.url" + return self.sqlalchemy_url + + def get_section(self, section_name: str, default: dict[str, Any] | None = None) -> dict[str, Any]: + assert section_name == self.config_ini_section + base = dict(default or {}) + base.update(self.section) + return base + + +class RecordingConnectable: + def __init__(self) -> None: + self.connection = object() + self.connected = False + + @contextmanager + def connect(self): + self.connected = True + yield self.connection + + +def load_env_module( + monkeypatch, + *, + offline_mode: bool, + admin_url: str | None = None, + app_url: str | None = None, + 
config_url: str = "postgresql://config-user:secret@localhost:5432/configdb", + config_section: dict[str, Any] | None = None, +) -> tuple[Any, dict[str, Any]]: + records: dict[str, Any] = { + "file_config_calls": [], + "configure_calls": [], + "run_migrations_calls": 0, + "begin_calls": 0, + "engine_calls": [], + } + fake_config = FakeAlembicConfig(config_url, config_section) + connectable = RecordingConnectable() + + if admin_url is None: + monkeypatch.delenv("DATABASE_ADMIN_URL", raising=False) + else: + monkeypatch.setenv("DATABASE_ADMIN_URL", admin_url) + if app_url is None: + monkeypatch.delenv("DATABASE_URL", raising=False) + else: + monkeypatch.setenv("DATABASE_URL", app_url) + + monkeypatch.setattr("logging.config.fileConfig", records["file_config_calls"].append) + monkeypatch.setattr("alembic.context.config", fake_config, raising=False) + monkeypatch.setattr("alembic.context.is_offline_mode", lambda: offline_mode, raising=False) + monkeypatch.setattr( + "alembic.context.configure", + lambda **kwargs: records["configure_calls"].append(kwargs), + raising=False, + ) + + @contextmanager + def begin_transaction(): + records["begin_calls"] += 1 + yield + + monkeypatch.setattr("alembic.context.begin_transaction", begin_transaction, raising=False) + monkeypatch.setattr( + "alembic.context.run_migrations", + lambda: records.__setitem__("run_migrations_calls", records["run_migrations_calls"] + 1), + raising=False, + ) + + def fake_engine_from_config(configuration: dict[str, Any], **kwargs: Any) -> RecordingConnectable: + records["engine_calls"].append((dict(configuration), kwargs)) + return connectable + + monkeypatch.setattr("sqlalchemy.engine_from_config", fake_engine_from_config) + + sys.modules.pop(MODULE_NAME, None) + module = importlib.import_module(MODULE_NAME) + records["connectable"] = connectable + return module, records + + +def test_normalize_sqlalchemy_url_rewrites_postgresql_scheme(monkeypatch) -> None: + module, _records = load_env_module(monkeypatch, offline_mode=True) + + assert module.normalize_sqlalchemy_url("postgresql://user:pw@localhost/db") == ( + "postgresql+psycopg://user:pw@localhost/db" + ) + assert module.normalize_sqlalchemy_url("sqlite:///tmp/test.db") == "sqlite:///tmp/test.db" + + +def test_get_url_prefers_admin_env_then_database_env_then_config(monkeypatch) -> None: + module, _records = load_env_module( + monkeypatch, + offline_mode=True, + admin_url="postgresql://admin-user:secret@localhost:5432/admin_db", + app_url="postgresql://app-user:secret@localhost:5432/app_db", + ) + + assert module.get_url() == "postgresql+psycopg://admin-user:secret@localhost:5432/admin_db" + + module, _records = load_env_module( + monkeypatch, + offline_mode=True, + admin_url=None, + app_url="postgresql://app-user:secret@localhost:5432/app_db", + ) + + assert module.get_url() == "postgresql+psycopg://app-user:secret@localhost:5432/app_db" + + module, _records = load_env_module(monkeypatch, offline_mode=True, admin_url=None, app_url=None) + + assert module.get_url() == "postgresql+psycopg://config-user:secret@localhost:5432/configdb" + + +def test_run_migrations_offline_configures_context_with_normalized_url(monkeypatch) -> None: + _module, records = load_env_module( + monkeypatch, + offline_mode=True, + admin_url="postgresql://admin-user:secret@localhost:5432/admin_db", + ) + + assert records["file_config_calls"] == ["alembic.ini"] + assert records["begin_calls"] == 1 + assert records["run_migrations_calls"] == 1 + assert records["configure_calls"] == [ + { + "url": 
"postgresql+psycopg://admin-user:secret@localhost:5432/admin_db", + "target_metadata": None, + "literal_binds": True, + "dialect_opts": {"paramstyle": "named"}, + } + ] + assert records["engine_calls"] == [] + + +def test_run_migrations_online_builds_engine_configuration(monkeypatch) -> None: + _module, records = load_env_module( + monkeypatch, + offline_mode=False, + app_url="postgresql://app-user:secret@localhost:5432/app_db", + config_section={"sqlalchemy.echo": "false"}, + ) + + configuration, engine_kwargs = records["engine_calls"][0] + + assert records["file_config_calls"] == ["alembic.ini"] + assert configuration == { + "sqlalchemy.echo": "false", + "sqlalchemy.url": "postgresql+psycopg://app-user:secret@localhost:5432/app_db", + } + assert engine_kwargs["prefix"] == "sqlalchemy." + assert engine_kwargs["poolclass"].__name__ == "NullPool" + assert records["connectable"].connected is True + assert records["configure_calls"] == [ + {"connection": records["connectable"].connection, "target_metadata": None} + ] + assert records["begin_calls"] == 1 + assert records["run_migrations_calls"] == 1 diff --git a/tests/unit/test_events.py b/tests/unit/test_events.py new file mode 100644 index 0000000..7e64d9d --- /dev/null +++ b/tests/unit/test_events.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +import pytest + +from alicebot_api.store import AppendOnlyViolation, ContinuityStore + + +def test_event_updates_are_rejected_by_contract(): + store = ContinuityStore(conn=None) # type: ignore[arg-type] + + with pytest.raises(AppendOnlyViolation, match="append-only"): + store.update_event("event-id", {"text": "mutated"}) + + +def test_event_deletes_are_rejected_by_contract(): + store = ContinuityStore(conn=None) # type: ignore[arg-type] + + with pytest.raises(AppendOnlyViolation, match="append-only"): + store.delete_event("event-id") + diff --git a/tests/unit/test_execution_budget_store.py b/tests/unit/test_execution_budget_store.py new file mode 100644 index 0000000..05e7b2e --- /dev/null +++ b/tests/unit/test_execution_budget_store.py @@ -0,0 +1,96 @@ +from __future__ import annotations + +from typing import Any +from uuid import uuid4 + +from alicebot_api.store import ContinuityStore + + +class RecordingCursor: + def __init__(self, fetchone_results: list[dict[str, Any]], fetchall_result: list[dict[str, Any]] | None = None) -> None: + self.executed: list[tuple[str, tuple[object, ...] | None]] = [] + self.fetchone_results = list(fetchone_results) + self.fetchall_result = fetchall_result or [] + + def __enter__(self) -> "RecordingCursor": + return self + + def __exit__(self, exc_type, exc, tb) -> None: + return None + + def execute(self, query: str, params: tuple[object, ...] 
| None = None) -> None: + self.executed.append((query, params)) + + def fetchone(self) -> dict[str, Any] | None: + if not self.fetchone_results: + return None + return self.fetchone_results.pop(0) + + def fetchall(self) -> list[dict[str, Any]]: + return self.fetchall_result + + +class RecordingConnection: + def __init__(self, cursor: RecordingCursor) -> None: + self.cursor_instance = cursor + + def cursor(self) -> RecordingCursor: + return self.cursor_instance + + +def test_execution_budget_store_methods_use_expected_queries_and_parameters() -> None: + execution_budget_id = uuid4() + replacement_budget_id = uuid4() + row = { + "id": execution_budget_id, + "tool_key": "proxy.echo", + "domain_hint": "docs", + "max_completed_executions": 2, + "rolling_window_seconds": 3600, + "status": "active", + "deactivated_at": None, + "superseded_by_budget_id": None, + "supersedes_budget_id": None, + "created_at": "2026-03-13T11:00:00+00:00", + } + cursor = RecordingCursor( + fetchone_results=[ + row, + row, + {**row, "status": "inactive"}, + {**row, "status": "superseded", "superseded_by_budget_id": replacement_budget_id}, + ], + fetchall_result=[row], + ) + store = ContinuityStore(RecordingConnection(cursor)) + + created = store.create_execution_budget( + tool_key="proxy.echo", + domain_hint="docs", + max_completed_executions=2, + rolling_window_seconds=3600, + ) + fetched = store.get_execution_budget_optional(execution_budget_id) + listed = store.list_execution_budgets() + deactivated = store.deactivate_execution_budget_optional(execution_budget_id) + superseded = store.supersede_execution_budget_optional( + execution_budget_id=execution_budget_id, + superseded_by_budget_id=replacement_budget_id, + ) + + assert created["id"] == execution_budget_id + assert fetched is not None + assert fetched["id"] == execution_budget_id + assert listed[0]["id"] == execution_budget_id + assert deactivated is not None + assert deactivated["status"] == "inactive" + assert superseded is not None + assert superseded["superseded_by_budget_id"] == replacement_budget_id + + create_query, create_params = cursor.executed[0] + assert "INSERT INTO execution_budgets" in create_query + assert create_params == (None, "proxy.echo", "docs", 2, 3600, None) + assert "FROM execution_budgets" in cursor.executed[1][0] + assert "ORDER BY created_at ASC, id ASC" in cursor.executed[2][0] + assert "UPDATE execution_budgets" in cursor.executed[3][0] + assert cursor.executed[4][1] == (replacement_budget_id, execution_budget_id) diff --git a/tests/unit/test_execution_budgets.py b/tests/unit/test_execution_budgets.py new file mode 100644 index 0000000..752f7d1 --- /dev/null +++ b/tests/unit/test_execution_budgets.py @@ -0,0 +1,709 @@ +from __future__ import annotations + +from datetime import UTC, datetime, timedelta +from uuid import UUID, uuid4 + +import pytest + +from alicebot_api.contracts import ( + ExecutionBudgetCreateInput, + ExecutionBudgetDeactivateInput, + ExecutionBudgetSupersedeInput, +) +from alicebot_api.execution_budgets import ( + ExecutionBudgetLifecycleError, + ExecutionBudgetNotFoundError, + ExecutionBudgetValidationError, + create_execution_budget_record, + deactivate_execution_budget_record, + evaluate_execution_budget, + get_execution_budget_record, + list_execution_budget_records, + supersede_execution_budget_record, +) + + +class _SavepointConnection: + def __init__(self, store: "ExecutionBudgetStoreStub") -> None: + self.store = store + + def transaction(self) -> "_Savepoint": + return _Savepoint(self.store) + + +class 
_Savepoint: + def __init__(self, store: "ExecutionBudgetStoreStub") -> None: + self.store = store + self.snapshot: list[dict[str, object]] | None = None + + def __enter__(self) -> "_Savepoint": + self.snapshot = [dict(row) for row in self.store.budgets] + return self + + def __exit__(self, exc_type, exc, tb) -> bool: + if exc_type is not None and self.snapshot is not None: + self.store.budgets = [dict(row) for row in self.snapshot] + return False + + +class ExecutionBudgetStoreStub: + def __init__(self) -> None: + self.base_time = datetime(2026, 3, 13, 11, 0, tzinfo=UTC) + self.user_id = uuid4() + self.thread_id = uuid4() + self.budgets: list[dict[str, object]] = [] + self.executions: list[dict[str, object]] = [] + self.traces: list[dict[str, object]] = [] + self.trace_events: list[dict[str, object]] = [] + self.fail_next_supersede_update = False + self.conn = _SavepointConnection(self) + + def current_time(self) -> datetime: + return self.base_time + timedelta(minutes=len(self.executions)) + + def get_thread_optional(self, thread_id: UUID) -> dict[str, object] | None: + if thread_id != self.thread_id: + return None + return { + "id": self.thread_id, + "user_id": self.user_id, + "title": "Budget lifecycle thread", + "created_at": self.base_time, + "updated_at": self.base_time, + } + + def create_trace( + self, + *, + user_id: UUID, + thread_id: UUID, + kind: str, + compiler_version: str, + status: str, + limits: dict[str, object], + ) -> dict[str, object]: + trace = { + "id": uuid4(), + "user_id": user_id, + "thread_id": thread_id, + "kind": kind, + "compiler_version": compiler_version, + "status": status, + "limits": limits, + "created_at": self.base_time + timedelta(minutes=len(self.traces)), + } + self.traces.append(trace) + return trace + + def append_trace_event( + self, + *, + trace_id: UUID, + sequence_no: int, + kind: str, + payload: dict[str, object], + ) -> dict[str, object]: + event = { + "id": uuid4(), + "trace_id": trace_id, + "sequence_no": sequence_no, + "kind": kind, + "payload": payload, + "created_at": self.base_time + timedelta(minutes=len(self.trace_events)), + } + self.trace_events.append(event) + return event + + def create_execution_budget( + self, + *, + budget_id: UUID | None = None, + tool_key: str | None, + domain_hint: str | None, + max_completed_executions: int, + rolling_window_seconds: int | None = None, + supersedes_budget_id: UUID | None = None, + ) -> dict[str, object]: + row = { + "id": uuid4() if budget_id is None else budget_id, + "user_id": self.user_id, + "tool_key": tool_key, + "domain_hint": domain_hint, + "max_completed_executions": max_completed_executions, + "rolling_window_seconds": rolling_window_seconds, + "status": "active", + "deactivated_at": None, + "superseded_by_budget_id": None, + "supersedes_budget_id": supersedes_budget_id, + "created_at": self.base_time + timedelta(minutes=len(self.budgets)), + } + self.budgets.append(row) + self.budgets.sort(key=lambda item: (item["created_at"], item["id"])) + return row + + def deactivate_execution_budget_optional( + self, + execution_budget_id: UUID, + ) -> dict[str, object] | None: + row = self.get_execution_budget_optional(execution_budget_id) + if row is None or row["status"] != "active": + return None + row["status"] = "inactive" + row["deactivated_at"] = self.base_time + timedelta(hours=1, minutes=len(self.trace_events)) + return row + + def supersede_execution_budget_optional( + self, + *, + execution_budget_id: UUID, + superseded_by_budget_id: UUID, + ) -> dict[str, object] | None: + if 
self.fail_next_supersede_update: + self.fail_next_supersede_update = False + return None + row = self.get_execution_budget_optional(execution_budget_id) + if row is None or row["status"] != "active": + return None + row["status"] = "superseded" + row["deactivated_at"] = self.base_time + timedelta(hours=1, minutes=len(self.trace_events)) + row["superseded_by_budget_id"] = superseded_by_budget_id + return row + + def get_execution_budget_optional(self, execution_budget_id: UUID) -> dict[str, object] | None: + return next((row for row in self.budgets if row["id"] == execution_budget_id), None) + + def list_execution_budgets(self) -> list[dict[str, object]]: + return list(self.budgets) + + def seed_execution( + self, + *, + tool_key: str, + domain_hint: str | None, + status: str, + offset_minutes: int, + ) -> None: + tool_id = uuid4() + self.executions.append( + { + "id": uuid4(), + "user_id": self.user_id, + "approval_id": uuid4(), + "thread_id": self.thread_id, + "tool_id": tool_id, + "trace_id": uuid4(), + "request_event_id": None, + "result_event_id": None, + "status": status, + "handler_key": None if status == "blocked" else tool_key, + "request": { + "thread_id": str(self.thread_id), + "tool_id": str(tool_id), + "action": "tool.run", + "scope": "workspace", + "domain_hint": domain_hint, + "risk_hint": None, + "attributes": {}, + }, + "tool": { + "id": str(tool_id), + "tool_key": tool_key, + }, + "result": { + "handler_key": None if status == "blocked" else tool_key, + "status": status, + "output": None, + "reason": None, + }, + "executed_at": self.base_time + timedelta(minutes=offset_minutes), + } + ) + + def list_tool_executions(self) -> list[dict[str, object]]: + return list(self.executions) + + +def test_create_execution_budget_requires_at_least_one_selector() -> None: + store = ExecutionBudgetStoreStub() + + with pytest.raises( + ExecutionBudgetValidationError, + match="execution budget requires at least one selector: tool_key or domain_hint", + ): + create_execution_budget_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ExecutionBudgetCreateInput( + tool_key=None, + domain_hint=None, + max_completed_executions=1, + ), + ) + + +def test_create_execution_budget_rejects_duplicate_active_scope() -> None: + store = ExecutionBudgetStoreStub() + create_execution_budget_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ExecutionBudgetCreateInput( + tool_key="proxy.echo", + domain_hint="docs", + max_completed_executions=1, + ), + ) + + with pytest.raises( + ExecutionBudgetValidationError, + match="active execution budget already exists for selector scope", + ): + create_execution_budget_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ExecutionBudgetCreateInput( + tool_key="proxy.echo", + domain_hint="docs", + max_completed_executions=2, + ), + ) + + +def test_create_execution_budget_includes_optional_rolling_window_seconds() -> None: + store = ExecutionBudgetStoreStub() + + payload = create_execution_budget_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ExecutionBudgetCreateInput( + tool_key="proxy.echo", + domain_hint=None, + max_completed_executions=2, + rolling_window_seconds=3600, + ), + ) + + assert payload["execution_budget"]["rolling_window_seconds"] == 3600 + assert store.budgets[0]["rolling_window_seconds"] == 3600 + + +def test_create_list_and_get_execution_budget_records_are_deterministic() -> None: + store = ExecutionBudgetStoreStub() + second = 
create_execution_budget_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ExecutionBudgetCreateInput( + tool_key="proxy.echo", + domain_hint=None, + max_completed_executions=2, + ), + ) + first = create_execution_budget_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ExecutionBudgetCreateInput( + tool_key=None, + domain_hint="docs", + max_completed_executions=1, + ), + ) + + listed = list_execution_budget_records( + store, # type: ignore[arg-type] + user_id=store.user_id, + ) + detail = get_execution_budget_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + execution_budget_id=UUID(second["execution_budget"]["id"]), + ) + + assert [item["id"] for item in listed["items"]] == [ + second["execution_budget"]["id"], + first["execution_budget"]["id"], + ] + assert listed["summary"] == { + "total_count": 2, + "order": ["created_at_asc", "id_asc"], + } + assert detail == {"execution_budget": second["execution_budget"]} + assert detail["execution_budget"]["status"] == "active" + assert detail["execution_budget"]["deactivated_at"] is None + assert detail["execution_budget"]["rolling_window_seconds"] is None + + +def test_deactivate_execution_budget_marks_row_inactive_and_records_trace() -> None: + store = ExecutionBudgetStoreStub() + created = create_execution_budget_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ExecutionBudgetCreateInput( + tool_key="proxy.echo", + domain_hint=None, + max_completed_executions=1, + ), + ) + + payload = deactivate_execution_budget_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ExecutionBudgetDeactivateInput( + thread_id=store.thread_id, + execution_budget_id=UUID(created["execution_budget"]["id"]), + ), + ) + + assert payload["execution_budget"]["status"] == "inactive" + assert payload["execution_budget"]["deactivated_at"] == "2026-03-13T12:00:00+00:00" + assert payload["trace"]["trace_event_count"] == 3 + assert store.traces[0]["kind"] == "execution_budget.lifecycle" + assert store.traces[0]["compiler_version"] == "execution_budget_lifecycle_v0" + assert [event["kind"] for event in store.trace_events] == [ + "execution_budget.lifecycle.request", + "execution_budget.lifecycle.state", + "execution_budget.lifecycle.summary", + ] + assert store.trace_events[1]["payload"] == { + "execution_budget_id": created["execution_budget"]["id"], + "requested_action": "deactivate", + "previous_status": "active", + "current_status": "inactive", + "tool_key": "proxy.echo", + "domain_hint": None, + "max_completed_executions": 1, + "rolling_window_seconds": None, + "deactivated_at": "2026-03-13T12:00:00+00:00", + "superseded_by_budget_id": None, + "supersedes_budget_id": None, + "replacement_budget_id": None, + "replacement_status": None, + "replacement_max_completed_executions": None, + "replacement_rolling_window_seconds": None, + "rejection_reason": None, + } + + +def test_supersede_execution_budget_replaces_active_budget_and_records_trace() -> None: + store = ExecutionBudgetStoreStub() + created = create_execution_budget_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ExecutionBudgetCreateInput( + tool_key="proxy.echo", + domain_hint="docs", + max_completed_executions=1, + ), + ) + + payload = supersede_execution_budget_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ExecutionBudgetSupersedeInput( + thread_id=store.thread_id, + execution_budget_id=UUID(created["execution_budget"]["id"]), + 
max_completed_executions=3, + ), + ) + + assert payload["superseded_budget"]["status"] == "superseded" + assert payload["replacement_budget"]["status"] == "active" + assert payload["replacement_budget"]["max_completed_executions"] == 3 + assert payload["replacement_budget"]["tool_key"] == "proxy.echo" + assert payload["replacement_budget"]["domain_hint"] == "docs" + assert payload["replacement_budget"]["rolling_window_seconds"] is None + assert payload["replacement_budget"]["supersedes_budget_id"] == created["execution_budget"]["id"] + assert payload["superseded_budget"]["superseded_by_budget_id"] == payload["replacement_budget"]["id"] + assert payload["trace"]["trace_event_count"] == 3 + assert store.trace_events[1]["payload"]["replacement_budget_id"] == payload["replacement_budget"]["id"] + assert store.trace_events[2]["payload"] == { + "execution_budget_id": created["execution_budget"]["id"], + "requested_action": "supersede", + "outcome": "superseded", + "replacement_budget_id": payload["replacement_budget"]["id"], + "active_budget_id": payload["replacement_budget"]["id"], + } + + +def test_lifecycle_rejects_invalid_transition_and_records_trace() -> None: + store = ExecutionBudgetStoreStub() + created = create_execution_budget_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ExecutionBudgetCreateInput( + tool_key="proxy.echo", + domain_hint=None, + max_completed_executions=1, + ), + ) + deactivate_execution_budget_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ExecutionBudgetDeactivateInput( + thread_id=store.thread_id, + execution_budget_id=UUID(created["execution_budget"]["id"]), + ), + ) + + with pytest.raises( + ExecutionBudgetLifecycleError, + match=f"execution budget {created['execution_budget']['id']} is inactive and cannot be deactivated", + ): + deactivate_execution_budget_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ExecutionBudgetDeactivateInput( + thread_id=store.thread_id, + execution_budget_id=UUID(created["execution_budget"]["id"]), + ), + ) + + assert store.trace_events[-2]["payload"]["current_status"] == "inactive" + assert store.trace_events[-2]["payload"]["rejection_reason"] == ( + f"execution budget {created['execution_budget']['id']} is inactive and cannot be deactivated" + ) + assert store.trace_events[-1]["payload"]["outcome"] == "rejected" + + +def test_supersede_execution_budget_rolls_back_replacement_when_source_update_fails() -> None: + store = ExecutionBudgetStoreStub() + created = create_execution_budget_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ExecutionBudgetCreateInput( + tool_key="proxy.echo", + domain_hint=None, + max_completed_executions=1, + ), + ) + store.fail_next_supersede_update = True + + with pytest.raises( + ExecutionBudgetLifecycleError, + match=f"execution budget {created['execution_budget']['id']} could not be superseded", + ): + supersede_execution_budget_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ExecutionBudgetSupersedeInput( + thread_id=store.thread_id, + execution_budget_id=UUID(created["execution_budget"]["id"]), + max_completed_executions=3, + ), + ) + + assert len(store.budgets) == 1 + assert store.budgets[0]["id"] == UUID(created["execution_budget"]["id"]) + assert store.budgets[0]["status"] == "active" + assert store.budgets[0]["superseded_by_budget_id"] is None + assert store.trace_events[-1]["payload"]["outcome"] == "rejected" + + +def 
test_get_execution_budget_record_raises_clear_error_when_missing() -> None: + store = ExecutionBudgetStoreStub() + + with pytest.raises(ExecutionBudgetNotFoundError, match="execution budget .* was not found"): + get_execution_budget_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + execution_budget_id=uuid4(), + ) + + +def test_evaluate_execution_budget_prefers_more_specific_active_match_and_ignores_inactive_rows() -> None: + store = ExecutionBudgetStoreStub() + inactive = store.create_execution_budget( + tool_key="proxy.echo", + domain_hint="docs", + max_completed_executions=1, + ) + store.deactivate_execution_budget_optional(inactive["id"]) + store.create_execution_budget(tool_key=None, domain_hint="docs", max_completed_executions=1) + matched = store.create_execution_budget( + tool_key="proxy.echo", + domain_hint="docs", + max_completed_executions=2, + ) + store.seed_execution(tool_key="proxy.echo", domain_hint="docs", status="completed", offset_minutes=0) + store.seed_execution(tool_key="proxy.echo", domain_hint="docs", status="blocked", offset_minutes=1) + + decision = evaluate_execution_budget( + store, # type: ignore[arg-type] + tool={"id": str(uuid4()), "tool_key": "proxy.echo"}, + request={ + "thread_id": str(store.thread_id), + "tool_id": str(uuid4()), + "action": "tool.run", + "scope": "workspace", + "domain_hint": "docs", + "risk_hint": None, + "attributes": {}, + }, + ) + + assert decision.record == { + "matched_budget_id": str(matched["id"]), + "tool_key": "proxy.echo", + "domain_hint": "docs", + "budget_tool_key": "proxy.echo", + "budget_domain_hint": "docs", + "max_completed_executions": 2, + "rolling_window_seconds": None, + "count_scope": "lifetime", + "window_started_at": None, + "completed_execution_count": 1, + "projected_completed_execution_count": 2, + "decision": "allow", + "reason": "within_budget", + "order": ["specificity_desc", "created_at_asc", "id_asc"], + "history_order": ["executed_at_asc", "id_asc"], + } + assert decision.blocked_result is None + + +def test_evaluate_execution_budget_blocks_when_projected_completed_count_would_exceed_limit() -> None: + store = ExecutionBudgetStoreStub() + matched = store.create_execution_budget( + tool_key="proxy.echo", + domain_hint=None, + max_completed_executions=1, + ) + store.seed_execution(tool_key="proxy.echo", domain_hint=None, status="completed", offset_minutes=0) + + decision = evaluate_execution_budget( + store, # type: ignore[arg-type] + tool={"id": str(uuid4()), "tool_key": "proxy.echo"}, + request={ + "thread_id": str(store.thread_id), + "tool_id": str(uuid4()), + "action": "tool.run", + "scope": "workspace", + "domain_hint": None, + "risk_hint": None, + "attributes": {}, + }, + ) + + assert decision.record == { + "matched_budget_id": str(matched["id"]), + "tool_key": "proxy.echo", + "domain_hint": None, + "budget_tool_key": "proxy.echo", + "budget_domain_hint": None, + "max_completed_executions": 1, + "rolling_window_seconds": None, + "count_scope": "lifetime", + "window_started_at": None, + "completed_execution_count": 1, + "projected_completed_execution_count": 2, + "decision": "block", + "reason": "budget_exceeded", + "order": ["specificity_desc", "created_at_asc", "id_asc"], + "history_order": ["executed_at_asc", "id_asc"], + } + assert decision.blocked_result == { + "handler_key": None, + "status": "blocked", + "output": None, + "reason": ( + f"execution budget {matched['id']} blocks execution: projected completed executions " + "2 would exceed limit 1" + ), + "budget_decision": 
decision.record, + } + + +def test_evaluate_execution_budget_uses_only_recent_completed_history_inside_window() -> None: + store = ExecutionBudgetStoreStub() + matched = store.create_execution_budget( + tool_key="proxy.echo", + domain_hint=None, + max_completed_executions=2, + rolling_window_seconds=3600, + ) + store.seed_execution(tool_key="proxy.echo", domain_hint=None, status="completed", offset_minutes=-120) + store.seed_execution(tool_key="proxy.echo", domain_hint=None, status="completed", offset_minutes=-10) + + decision = evaluate_execution_budget( + store, # type: ignore[arg-type] + tool={"id": str(uuid4()), "tool_key": "proxy.echo"}, + request={ + "thread_id": str(store.thread_id), + "tool_id": str(uuid4()), + "action": "tool.run", + "scope": "workspace", + "domain_hint": None, + "risk_hint": None, + "attributes": {}, + }, + ) + + assert decision.record == { + "matched_budget_id": str(matched["id"]), + "tool_key": "proxy.echo", + "domain_hint": None, + "budget_tool_key": "proxy.echo", + "budget_domain_hint": None, + "max_completed_executions": 2, + "rolling_window_seconds": 3600, + "count_scope": "rolling_window", + "window_started_at": "2026-03-13T10:02:00+00:00", + "completed_execution_count": 1, + "projected_completed_execution_count": 2, + "decision": "allow", + "reason": "within_budget", + "order": ["specificity_desc", "created_at_asc", "id_asc"], + "history_order": ["executed_at_asc", "id_asc"], + } + assert decision.blocked_result is None + + +def test_evaluate_execution_budget_blocks_when_recent_window_history_exceeds_limit() -> None: + store = ExecutionBudgetStoreStub() + matched = store.create_execution_budget( + tool_key="proxy.echo", + domain_hint=None, + max_completed_executions=1, + rolling_window_seconds=900, + ) + store.seed_execution(tool_key="proxy.echo", domain_hint=None, status="completed", offset_minutes=-5) + + decision = evaluate_execution_budget( + store, # type: ignore[arg-type] + tool={"id": str(uuid4()), "tool_key": "proxy.echo"}, + request={ + "thread_id": str(store.thread_id), + "tool_id": str(uuid4()), + "action": "tool.run", + "scope": "workspace", + "domain_hint": None, + "risk_hint": None, + "attributes": {}, + }, + ) + + assert decision.record == { + "matched_budget_id": str(matched["id"]), + "tool_key": "proxy.echo", + "domain_hint": None, + "budget_tool_key": "proxy.echo", + "budget_domain_hint": None, + "max_completed_executions": 1, + "rolling_window_seconds": 900, + "count_scope": "rolling_window", + "window_started_at": "2026-03-13T10:46:00+00:00", + "completed_execution_count": 1, + "projected_completed_execution_count": 2, + "decision": "block", + "reason": "budget_exceeded", + "order": ["specificity_desc", "created_at_asc", "id_asc"], + "history_order": ["executed_at_asc", "id_asc"], + } + assert decision.blocked_result == { + "handler_key": None, + "status": "blocked", + "output": None, + "reason": ( + f"execution budget {matched['id']} blocks execution: projected completed executions " + "2 within rolling window 900 seconds would exceed limit 1" + ), + "budget_decision": decision.record, + } diff --git a/tests/unit/test_execution_budgets_main.py b/tests/unit/test_execution_budgets_main.py new file mode 100644 index 0000000..bf7c1cf --- /dev/null +++ b/tests/unit/test_execution_budgets_main.py @@ -0,0 +1,373 @@ +from __future__ import annotations + +import json +from contextlib import contextmanager +from uuid import uuid4 + +import apps.api.src.alicebot_api.main as main_module +from apps.api.src.alicebot_api.config import Settings +from 
alicebot_api.execution_budgets import ( + ExecutionBudgetLifecycleError, + ExecutionBudgetNotFoundError, + ExecutionBudgetValidationError, +) + + +def test_create_execution_budget_endpoint_returns_payload(monkeypatch) -> None: + user_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_create_execution_budget_record(store, *, user_id, request): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["request"] = request + return { + "execution_budget": { + "id": "budget-123", + "tool_key": "proxy.echo", + "domain_hint": None, + "max_completed_executions": 1, + "rolling_window_seconds": 3600, + "status": "active", + "deactivated_at": None, + "superseded_by_budget_id": None, + "supersedes_budget_id": None, + "created_at": "2026-03-13T11:00:00+00:00", + } + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "create_execution_budget_record", fake_create_execution_budget_record) + + response = main_module.create_execution_budget( + main_module.CreateExecutionBudgetRequest( + user_id=user_id, + tool_key="proxy.echo", + domain_hint=None, + max_completed_executions=1, + rolling_window_seconds=3600, + ) + ) + + assert response.status_code == 201 + assert json.loads(response.body)["execution_budget"]["id"] == "budget-123" + assert captured["database_url"] == "postgresql://app" + assert captured["current_user_id"] == user_id + assert captured["store_type"] == "ContinuityStore" + assert captured["user_id"] == user_id + assert captured["request"].tool_key == "proxy.echo" + assert captured["request"].rolling_window_seconds == 3600 + + +def test_create_execution_budget_endpoint_maps_validation_error_to_400(monkeypatch) -> None: + user_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_create_execution_budget_record(*_args, **_kwargs): + raise ExecutionBudgetValidationError( + "execution budget requires at least one selector: tool_key or domain_hint" + ) + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "create_execution_budget_record", fake_create_execution_budget_record) + + response = main_module.create_execution_budget( + main_module.CreateExecutionBudgetRequest( + user_id=user_id, + tool_key=None, + domain_hint="docs", + max_completed_executions=1, + ) + ) + + assert response.status_code == 400 + assert json.loads(response.body) == { + "detail": "execution budget requires at least one selector: tool_key or domain_hint" + } + + +def test_list_execution_budgets_endpoint_returns_payload(monkeypatch) -> None: + user_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_list_execution_budget_records(store, *, user_id): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + 
return { + "items": [ + { + "id": "budget-123", + "tool_key": "proxy.echo", + "domain_hint": None, + "max_completed_executions": 1, + "rolling_window_seconds": None, + "status": "active", + "deactivated_at": None, + "superseded_by_budget_id": None, + "supersedes_budget_id": None, + "created_at": "2026-03-13T11:00:00+00:00", + } + ], + "summary": {"total_count": 1, "order": ["created_at_asc", "id_asc"]}, + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "list_execution_budget_records", fake_list_execution_budget_records) + + response = main_module.list_execution_budgets(user_id) + + assert response.status_code == 200 + assert json.loads(response.body)["summary"] == { + "total_count": 1, + "order": ["created_at_asc", "id_asc"], + } + assert captured == { + "database_url": "postgresql://app", + "current_user_id": user_id, + "store_type": "ContinuityStore", + "user_id": user_id, + } + + +def test_get_execution_budget_endpoint_returns_payload(monkeypatch) -> None: + user_id = uuid4() + execution_budget_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_get_execution_budget_record(store, *, user_id, execution_budget_id): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["execution_budget_id"] = execution_budget_id + return { + "execution_budget": { + "id": str(execution_budget_id), + "tool_key": "proxy.echo", + "domain_hint": None, + "max_completed_executions": 1, + "rolling_window_seconds": None, + "status": "active", + "deactivated_at": None, + "superseded_by_budget_id": None, + "supersedes_budget_id": None, + "created_at": "2026-03-13T11:00:00+00:00", + } + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "get_execution_budget_record", fake_get_execution_budget_record) + + response = main_module.get_execution_budget(execution_budget_id, user_id) + + assert response.status_code == 200 + assert json.loads(response.body)["execution_budget"]["id"] == str(execution_budget_id) + assert captured == { + "database_url": "postgresql://app", + "current_user_id": user_id, + "store_type": "ContinuityStore", + "user_id": user_id, + "execution_budget_id": execution_budget_id, + } + + +def test_get_execution_budget_endpoint_maps_missing_record_to_404(monkeypatch) -> None: + user_id = uuid4() + execution_budget_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_get_execution_budget_record(*_args, **_kwargs): + raise ExecutionBudgetNotFoundError(f"execution budget {execution_budget_id} was not found") + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "get_execution_budget_record", fake_get_execution_budget_record) + + response = main_module.get_execution_budget(execution_budget_id, user_id) + + assert response.status_code == 404 + assert json.loads(response.body) == { + "detail": f"execution budget 
{execution_budget_id} was not found" + } + + +def test_deactivate_execution_budget_endpoint_returns_payload(monkeypatch) -> None: + user_id = uuid4() + thread_id = uuid4() + execution_budget_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_deactivate_execution_budget_record(store, *, user_id, request): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["request"] = request + return { + "execution_budget": { + "id": str(execution_budget_id), + "tool_key": "proxy.echo", + "domain_hint": None, + "max_completed_executions": 1, + "rolling_window_seconds": None, + "status": "inactive", + "deactivated_at": "2026-03-13T12:00:00+00:00", + "superseded_by_budget_id": None, + "supersedes_budget_id": None, + "created_at": "2026-03-13T11:00:00+00:00", + }, + "trace": {"trace_id": "trace-123", "trace_event_count": 3}, + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "deactivate_execution_budget_record", fake_deactivate_execution_budget_record) + + response = main_module.deactivate_execution_budget( + execution_budget_id, + main_module.DeactivateExecutionBudgetRequest( + user_id=user_id, + thread_id=thread_id, + ), + ) + + assert response.status_code == 200 + assert json.loads(response.body)["execution_budget"]["status"] == "inactive" + assert captured["request"].thread_id == thread_id + assert captured["request"].execution_budget_id == execution_budget_id + + +def test_deactivate_execution_budget_endpoint_maps_lifecycle_error_to_409(monkeypatch) -> None: + user_id = uuid4() + thread_id = uuid4() + execution_budget_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_deactivate_execution_budget_record(*_args, **_kwargs): + raise ExecutionBudgetLifecycleError( + f"execution budget {execution_budget_id} is inactive and cannot be deactivated" + ) + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "deactivate_execution_budget_record", fake_deactivate_execution_budget_record) + + response = main_module.deactivate_execution_budget( + execution_budget_id, + main_module.DeactivateExecutionBudgetRequest( + user_id=user_id, + thread_id=thread_id, + ), + ) + + assert response.status_code == 409 + assert json.loads(response.body) == { + "detail": f"execution budget {execution_budget_id} is inactive and cannot be deactivated" + } + + +def test_supersede_execution_budget_endpoint_returns_payload(monkeypatch) -> None: + user_id = uuid4() + thread_id = uuid4() + execution_budget_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_supersede_execution_budget_record(store, *, user_id, request): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["request"] = request + 
return { + "superseded_budget": { + "id": str(execution_budget_id), + "tool_key": "proxy.echo", + "domain_hint": None, + "max_completed_executions": 1, + "rolling_window_seconds": 1800, + "status": "superseded", + "deactivated_at": "2026-03-13T12:00:00+00:00", + "superseded_by_budget_id": "budget-456", + "supersedes_budget_id": None, + "created_at": "2026-03-13T11:00:00+00:00", + }, + "replacement_budget": { + "id": "budget-456", + "tool_key": "proxy.echo", + "domain_hint": None, + "max_completed_executions": 3, + "rolling_window_seconds": 1800, + "status": "active", + "deactivated_at": None, + "superseded_by_budget_id": None, + "supersedes_budget_id": str(execution_budget_id), + "created_at": "2026-03-13T11:01:00+00:00", + }, + "trace": {"trace_id": "trace-456", "trace_event_count": 3}, + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "supersede_execution_budget_record", fake_supersede_execution_budget_record) + + response = main_module.supersede_execution_budget( + execution_budget_id, + main_module.SupersedeExecutionBudgetRequest( + user_id=user_id, + thread_id=thread_id, + max_completed_executions=3, + ), + ) + + assert response.status_code == 200 + body = json.loads(response.body) + assert body["superseded_budget"]["status"] == "superseded" + assert body["replacement_budget"]["status"] == "active" + assert captured["request"].thread_id == thread_id + assert captured["request"].execution_budget_id == execution_budget_id + assert captured["request"].max_completed_executions == 3 diff --git a/tests/unit/test_executions.py b/tests/unit/test_executions.py new file mode 100644 index 0000000..01dac78 --- /dev/null +++ b/tests/unit/test_executions.py @@ -0,0 +1,251 @@ +from __future__ import annotations + +from datetime import UTC, datetime, timedelta +from uuid import UUID, uuid4 + +import pytest + +from alicebot_api.executions import ( + ToolExecutionNotFoundError, + get_tool_execution_record, + list_tool_execution_records, +) + + +class ToolExecutionStoreStub: + def __init__(self) -> None: + self.base_time = datetime(2026, 3, 13, 10, 0, tzinfo=UTC) + self.user_id = uuid4() + self.thread_id = uuid4() + self.executions: list[dict[str, object]] = [] + + def seed_execution(self, *, tool_key: str, offset_minutes: int) -> dict[str, object]: + tool_id = uuid4() + execution = { + "id": uuid4(), + "user_id": self.user_id, + "approval_id": uuid4(), + "task_step_id": uuid4(), + "thread_id": self.thread_id, + "tool_id": tool_id, + "trace_id": uuid4(), + "request_event_id": uuid4(), + "result_event_id": uuid4(), + "status": "completed", + "handler_key": tool_key, + "request": { + "thread_id": str(self.thread_id), + "tool_id": str(tool_id), + "action": "tool.run", + "scope": "workspace", + "domain_hint": None, + "risk_hint": None, + "attributes": {"message": tool_key}, + }, + "tool": { + "id": str(tool_id), + "tool_key": tool_key, + "name": "Proxy Echo", + "description": "Deterministic proxy handler.", + "version": "1.0.0", + "metadata_version": "tool_metadata_v0", + "active": True, + "tags": ["proxy"], + "action_hints": ["tool.run"], + "scope_hints": ["workspace"], + "domain_hints": [], + "risk_hints": [], + "metadata": {"transport": "proxy"}, + "created_at": (self.base_time + timedelta(minutes=offset_minutes)).isoformat(), + }, + "result": { + "handler_key": tool_key, + "status": "completed", + "output": {"mode": "no_side_effect", "tool_key": tool_key}, + "reason": 
None, + }, + "executed_at": self.base_time + timedelta(minutes=offset_minutes), + } + self.executions.append(execution) + self.executions.sort(key=lambda row: (row["executed_at"], row["id"])) + return execution + + def seed_blocked_execution(self, *, tool_key: str, offset_minutes: int) -> dict[str, object]: + tool_id = uuid4() + execution = { + "id": uuid4(), + "user_id": self.user_id, + "approval_id": uuid4(), + "task_step_id": uuid4(), + "thread_id": self.thread_id, + "tool_id": tool_id, + "trace_id": uuid4(), + "request_event_id": None, + "result_event_id": None, + "status": "blocked", + "handler_key": None, + "request": { + "thread_id": str(self.thread_id), + "tool_id": str(tool_id), + "action": "tool.run", + "scope": "workspace", + "domain_hint": None, + "risk_hint": None, + "attributes": {"message": tool_key}, + }, + "tool": { + "id": str(tool_id), + "tool_key": tool_key, + "name": "Missing Proxy", + "description": "Missing handler.", + "version": "1.0.0", + "metadata_version": "tool_metadata_v0", + "active": True, + "tags": ["proxy"], + "action_hints": ["tool.run"], + "scope_hints": ["workspace"], + "domain_hints": [], + "risk_hints": [], + "metadata": {"transport": "proxy"}, + "created_at": (self.base_time + timedelta(minutes=offset_minutes)).isoformat(), + }, + "result": { + "handler_key": None, + "status": "blocked", + "output": None, + "reason": f"tool '{tool_key}' has no registered proxy handler", + }, + "executed_at": self.base_time + timedelta(minutes=offset_minutes), + } + self.executions.append(execution) + self.executions.sort(key=lambda row: (row["executed_at"], row["id"])) + return execution + + def list_tool_executions(self) -> list[dict[str, object]]: + return list(self.executions) + + def get_tool_execution_optional(self, execution_id: UUID) -> dict[str, object] | None: + return next((row for row in self.executions if row["id"] == execution_id), None) + + +def test_list_tool_execution_records_uses_explicit_order_and_summary() -> None: + store = ToolExecutionStoreStub() + first = store.seed_execution(tool_key="proxy.echo", offset_minutes=0) + second = store.seed_execution(tool_key="proxy.echo", offset_minutes=5) + + payload = list_tool_execution_records( + store, # type: ignore[arg-type] + user_id=store.user_id, + ) + + assert [item["id"] for item in payload["items"]] == [str(first["id"]), str(second["id"])] + assert payload["summary"] == { + "total_count": 2, + "order": ["executed_at_asc", "id_asc"], + } + + +def test_get_tool_execution_record_returns_detail_shape() -> None: + store = ToolExecutionStoreStub() + execution = store.seed_execution(tool_key="proxy.echo", offset_minutes=0) + + payload = get_tool_execution_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + execution_id=execution["id"], + ) + + assert payload["execution"]["id"] == str(execution["id"]) + assert payload["execution"]["approval_id"] == str(execution["approval_id"]) + assert payload["execution"]["task_step_id"] == str(execution["task_step_id"]) + assert payload["execution"]["status"] == "completed" + assert payload["execution"]["tool"]["tool_key"] == "proxy.echo" + assert payload["execution"]["result"]["output"] == { + "mode": "no_side_effect", + "tool_key": "proxy.echo", + } + + +def test_get_tool_execution_record_preserves_blocked_attempt_shape() -> None: + store = ToolExecutionStoreStub() + execution = store.seed_blocked_execution(tool_key="proxy.missing", offset_minutes=0) + + payload = get_tool_execution_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + 
execution_id=execution["id"], + ) + + assert payload["execution"]["status"] == "blocked" + assert payload["execution"]["handler_key"] is None + assert payload["execution"]["request_event_id"] is None + assert payload["execution"]["result_event_id"] is None + assert payload["execution"]["result"] == { + "handler_key": None, + "status": "blocked", + "output": None, + "reason": "tool 'proxy.missing' has no registered proxy handler", + } + + +def test_get_tool_execution_record_preserves_budget_blocked_attempt_shape() -> None: + store = ToolExecutionStoreStub() + execution = store.seed_blocked_execution(tool_key="proxy.echo", offset_minutes=0) + execution["result"] = { + "handler_key": None, + "status": "blocked", + "output": None, + "reason": "execution budget budget-123 blocks execution: projected completed executions 2 would exceed limit 1", + "budget_decision": { + "matched_budget_id": "budget-123", + "tool_key": "proxy.echo", + "domain_hint": None, + "budget_tool_key": "proxy.echo", + "budget_domain_hint": None, + "max_completed_executions": 1, + "rolling_window_seconds": None, + "count_scope": "lifetime", + "window_started_at": None, + "completed_execution_count": 1, + "projected_completed_execution_count": 2, + "decision": "block", + "reason": "budget_exceeded", + "order": ["specificity_desc", "created_at_asc", "id_asc"], + "history_order": ["executed_at_asc", "id_asc"], + }, + } + + payload = get_tool_execution_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + execution_id=execution["id"], + ) + + assert payload["execution"]["result"]["budget_decision"] == { + "matched_budget_id": "budget-123", + "tool_key": "proxy.echo", + "domain_hint": None, + "budget_tool_key": "proxy.echo", + "budget_domain_hint": None, + "max_completed_executions": 1, + "rolling_window_seconds": None, + "count_scope": "lifetime", + "window_started_at": None, + "completed_execution_count": 1, + "projected_completed_execution_count": 2, + "decision": "block", + "reason": "budget_exceeded", + "order": ["specificity_desc", "created_at_asc", "id_asc"], + "history_order": ["executed_at_asc", "id_asc"], + } + + +def test_get_tool_execution_record_raises_clear_error_when_missing() -> None: + store = ToolExecutionStoreStub() + + with pytest.raises(ToolExecutionNotFoundError, match="tool execution .* was not found"): + get_tool_execution_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + execution_id=uuid4(), + ) diff --git a/tests/unit/test_executions_main.py b/tests/unit/test_executions_main.py new file mode 100644 index 0000000..9070c0e --- /dev/null +++ b/tests/unit/test_executions_main.py @@ -0,0 +1,166 @@ +from __future__ import annotations + +import json +from contextlib import contextmanager +from uuid import uuid4 + +import apps.api.src.alicebot_api.main as main_module +from apps.api.src.alicebot_api.config import Settings +from alicebot_api.executions import ToolExecutionNotFoundError + + +def test_list_tool_executions_endpoint_returns_payload(monkeypatch) -> None: + user_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_list_tool_execution_records(store, *, user_id): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + return { + "items": [ + { + "id": "execution-123", + "approval_id": 
"approval-123", + "task_step_id": "task-step-123", + "thread_id": "thread-123", + "tool_id": "tool-123", + "trace_id": "trace-123", + "request_event_id": "event-1", + "result_event_id": "event-2", + "status": "completed", + "handler_key": "proxy.echo", + "request": { + "thread_id": "thread-123", + "tool_id": "tool-123", + "action": "tool.run", + "scope": "workspace", + "domain_hint": None, + "risk_hint": None, + "attributes": {"message": "hello"}, + }, + "tool": {"id": "tool-123", "tool_key": "proxy.echo"}, + "result": { + "handler_key": "proxy.echo", + "status": "completed", + "output": {"mode": "no_side_effect"}, + "reason": None, + }, + "executed_at": "2026-03-13T10:00:00+00:00", + } + ], + "summary": {"total_count": 1, "order": ["executed_at_asc", "id_asc"]}, + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "list_tool_execution_records", fake_list_tool_execution_records) + + response = main_module.list_tool_executions(user_id) + + assert response.status_code == 200 + assert json.loads(response.body)["summary"] == { + "total_count": 1, + "order": ["executed_at_asc", "id_asc"], + } + assert captured == { + "database_url": "postgresql://app", + "current_user_id": user_id, + "store_type": "ContinuityStore", + "user_id": user_id, + } + + +def test_get_tool_execution_endpoint_returns_payload(monkeypatch) -> None: + user_id = uuid4() + execution_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_get_tool_execution_record(store, *, user_id, execution_id): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["execution_id"] = execution_id + return { + "execution": { + "id": str(execution_id), + "approval_id": "approval-123", + "task_step_id": "task-step-123", + "thread_id": "thread-123", + "tool_id": "tool-123", + "trace_id": "trace-123", + "request_event_id": "event-1", + "result_event_id": "event-2", + "status": "completed", + "handler_key": "proxy.echo", + "request": { + "thread_id": "thread-123", + "tool_id": "tool-123", + "action": "tool.run", + "scope": "workspace", + "domain_hint": None, + "risk_hint": None, + "attributes": {"message": "hello"}, + }, + "tool": {"id": "tool-123", "tool_key": "proxy.echo"}, + "result": { + "handler_key": "proxy.echo", + "status": "completed", + "output": {"mode": "no_side_effect"}, + "reason": None, + }, + "executed_at": "2026-03-13T10:00:00+00:00", + } + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "get_tool_execution_record", fake_get_tool_execution_record) + + response = main_module.get_tool_execution(execution_id, user_id) + + assert response.status_code == 200 + assert json.loads(response.body)["execution"]["id"] == str(execution_id) + assert captured == { + "database_url": "postgresql://app", + "current_user_id": user_id, + "store_type": "ContinuityStore", + "user_id": user_id, + "execution_id": execution_id, + } + + +def test_get_tool_execution_endpoint_maps_missing_record_to_404(monkeypatch) -> None: + user_id = uuid4() + execution_id = uuid4() + settings = 
Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_get_tool_execution_record(*_args, **_kwargs): + raise ToolExecutionNotFoundError(f"tool execution {execution_id} was not found") + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "get_tool_execution_record", fake_get_tool_execution_record) + + response = main_module.get_tool_execution(execution_id, user_id) + + assert response.status_code == 404 + assert json.loads(response.body) == { + "detail": f"tool execution {execution_id} was not found" + } diff --git a/tests/unit/test_explicit_preferences.py b/tests/unit/test_explicit_preferences.py new file mode 100644 index 0000000..7fb5a31 --- /dev/null +++ b/tests/unit/test_explicit_preferences.py @@ -0,0 +1,266 @@ +from __future__ import annotations + +from datetime import UTC, datetime +from uuid import UUID, uuid4 + +import pytest + +from alicebot_api.contracts import AdmissionDecisionOutput, ExplicitPreferenceExtractionRequestInput +from alicebot_api.explicit_preferences import ( + ExplicitPreferenceExtractionValidationError, + _build_memory_key, + extract_and_admit_explicit_preferences, + extract_explicit_preference_candidates, +) + + +class ExplicitPreferenceStoreStub: + def __init__(self) -> None: + self.base_time = datetime(2026, 3, 12, 9, 0, tzinfo=UTC) + self.events: dict[UUID, dict[str, object]] = {} + + def list_events_by_ids(self, event_ids: list[UUID]) -> list[dict[str, object]]: + return [self.events[event_id] for event_id in event_ids if event_id in self.events] + + +def seed_event( + store: ExplicitPreferenceStoreStub, + *, + kind: str = "message.user", + text: str = "I like black coffee.", +) -> UUID: + event_id = uuid4() + store.events[event_id] = { + "id": event_id, + "sequence_no": 1, + "kind": kind, + "payload": {"text": text}, + "created_at": store.base_time, + } + return event_id + + +def test_extract_explicit_preference_candidates_returns_supported_candidate_shape() -> None: + event_id = UUID("11111111-1111-1111-1111-111111111111") + memory_key = _build_memory_key("black coffee") + + payload = extract_explicit_preference_candidates( + source_event_id=event_id, + text="I like black coffee.", + ) + + assert payload == [ + { + "memory_key": memory_key, + "value": { + "kind": "explicit_preference", + "preference": "like", + "text": "black coffee", + }, + "source_event_ids": [str(event_id)], + "delete_requested": False, + "pattern": "i_like", + "subject_text": "black coffee", + } + ] + + +def test_extract_explicit_preference_candidates_keeps_remember_pattern_deterministic() -> None: + event_id = UUID("22222222-2222-2222-2222-222222222222") + memory_key = _build_memory_key("oat milk") + + payload = extract_explicit_preference_candidates( + source_event_id=event_id, + text=" remember that I prefer oat milk!! 
", + ) + + assert payload == [ + { + "memory_key": memory_key, + "value": { + "kind": "explicit_preference", + "preference": "prefer", + "text": "oat milk", + }, + "source_event_ids": [str(event_id)], + "delete_requested": False, + "pattern": "remember_that_i_prefer", + "subject_text": "oat milk", + } + ] + + +def test_extract_explicit_preference_candidates_returns_empty_for_unsupported_text() -> None: + assert extract_explicit_preference_candidates( + source_event_id=uuid4(), + text="I had coffee yesterday.", + ) == [] + + +def test_extract_explicit_preference_candidates_rejects_clause_style_text() -> None: + assert extract_explicit_preference_candidates( + source_event_id=uuid4(), + text="I prefer that we meet tomorrow.", + ) == [] + + +def test_build_memory_key_keeps_symbol_bearing_subjects_distinct() -> None: + c_plus_plus_key = _build_memory_key("C++") + c_hash_key = _build_memory_key("C#") + + assert c_plus_plus_key != c_hash_key + assert c_plus_plus_key.startswith("user.preference.c__") + assert c_hash_key.startswith("user.preference.c__") + + +def test_build_memory_key_is_case_insensitive_for_the_same_subject() -> None: + assert _build_memory_key("Black Coffee") == _build_memory_key("black coffee") + + +def test_extract_and_admit_explicit_preferences_rejects_invalid_source_event() -> None: + store = ExplicitPreferenceStoreStub() + event_id = seed_event(store, kind="message.assistant", text="I like black coffee.") + + with pytest.raises( + ExplicitPreferenceExtractionValidationError, + match="source_event_id must reference an existing message.user event owned by the user", + ): + extract_and_admit_explicit_preferences( + store, # type: ignore[arg-type] + user_id=uuid4(), + request=ExplicitPreferenceExtractionRequestInput(source_event_id=event_id), + ) + + +def test_extract_and_admit_explicit_preferences_routes_candidate_through_memory_admission( + monkeypatch, +) -> None: + store = ExplicitPreferenceStoreStub() + user_id = uuid4() + event_id = seed_event(store, text="I don't like black coffee.") + memory_key = _build_memory_key("black coffee") + captured: dict[str, object] = {} + + def fake_admit_memory_candidate(store_arg, *, user_id, candidate): + captured["store"] = store_arg + captured["user_id"] = user_id + captured["candidate"] = candidate + return AdmissionDecisionOutput( + action="ADD", + reason="source_backed_add", + memory={ + "id": "memory-123", + "user_id": str(user_id), + "memory_key": candidate.memory_key, + "value": candidate.value, + "status": "active", + "source_event_ids": [str(event_id)], + "created_at": "2026-03-12T09:00:00+00:00", + "updated_at": "2026-03-12T09:00:00+00:00", + "deleted_at": None, + }, + revision={ + "id": "revision-123", + "user_id": str(user_id), + "memory_id": "memory-123", + "sequence_no": 1, + "action": "ADD", + "memory_key": candidate.memory_key, + "previous_value": None, + "new_value": candidate.value, + "source_event_ids": [str(event_id)], + "candidate": candidate.as_payload(), + "created_at": "2026-03-12T09:00:00+00:00", + }, + ) + + monkeypatch.setattr( + "alicebot_api.explicit_preferences.admit_memory_candidate", + fake_admit_memory_candidate, + ) + + payload = extract_and_admit_explicit_preferences( + store, # type: ignore[arg-type] + user_id=user_id, + request=ExplicitPreferenceExtractionRequestInput(source_event_id=event_id), + ) + + assert captured["store"] is store + assert captured["user_id"] == user_id + assert captured["candidate"].memory_key == memory_key + assert captured["candidate"].value == { + "kind": 
"explicit_preference", + "preference": "dislike", + "text": "black coffee", + } + assert payload == { + "candidates": [ + { + "memory_key": memory_key, + "value": { + "kind": "explicit_preference", + "preference": "dislike", + "text": "black coffee", + }, + "source_event_ids": [str(event_id)], + "delete_requested": False, + "pattern": "i_dont_like", + "subject_text": "black coffee", + } + ], + "admissions": [ + { + "decision": "ADD", + "reason": "source_backed_add", + "memory": { + "id": "memory-123", + "user_id": str(user_id), + "memory_key": memory_key, + "value": { + "kind": "explicit_preference", + "preference": "dislike", + "text": "black coffee", + }, + "status": "active", + "source_event_ids": [str(event_id)], + "created_at": "2026-03-12T09:00:00+00:00", + "updated_at": "2026-03-12T09:00:00+00:00", + "deleted_at": None, + }, + "revision": { + "id": "revision-123", + "user_id": str(user_id), + "memory_id": "memory-123", + "sequence_no": 1, + "action": "ADD", + "memory_key": memory_key, + "previous_value": None, + "new_value": { + "kind": "explicit_preference", + "preference": "dislike", + "text": "black coffee", + }, + "source_event_ids": [str(event_id)], + "candidate": { + "memory_key": memory_key, + "value": { + "kind": "explicit_preference", + "preference": "dislike", + "text": "black coffee", + }, + "source_event_ids": [str(event_id)], + "delete_requested": False, + }, + "created_at": "2026-03-12T09:00:00+00:00", + }, + } + ], + "summary": { + "source_event_id": str(event_id), + "source_event_kind": "message.user", + "candidate_count": 1, + "admission_count": 1, + "persisted_change_count": 1, + "noop_count": 0, + }, + } diff --git a/tests/unit/test_main.py b/tests/unit/test_main.py new file mode 100644 index 0000000..dc1e5ca --- /dev/null +++ b/tests/unit/test_main.py @@ -0,0 +1,2378 @@ +from __future__ import annotations + +import json +from contextlib import contextmanager +from uuid import uuid4 + +import apps.api.src.alicebot_api.main as main_module +from apps.api.src.alicebot_api.config import Settings +from alicebot_api.compiler import CompiledTraceRun +from alicebot_api.contracts import AdmissionDecisionOutput +from alicebot_api.embedding import ( + EmbeddingConfigValidationError, + MemoryEmbeddingNotFoundError, + MemoryEmbeddingValidationError, +) +from alicebot_api.entity import EntityNotFoundError, EntityValidationError +from alicebot_api.entity_edge import EntityEdgeValidationError +from alicebot_api.memory import MemoryAdmissionValidationError, MemoryReviewNotFoundError +from alicebot_api.response_generation import ResponseFailure +from alicebot_api.semantic_retrieval import SemanticMemoryRetrievalValidationError +from alicebot_api.store import ContinuityStoreInvariantError + + +def test_healthcheck_reports_ok_when_database_is_reachable(monkeypatch) -> None: + settings = Settings( + app_env="test", + database_url="postgresql://db", + redis_url="redis://alicebot:supersecret@cache:6379/0", + s3_endpoint_url="http://object-store", + healthcheck_timeout_seconds=7, + ) + ping_calls: list[tuple[str, int]] = [] + + def fake_get_settings() -> Settings: + return settings + + def fake_ping_database(database_url: str, timeout_seconds: int) -> bool: + ping_calls.append((database_url, timeout_seconds)) + return True + + monkeypatch.setattr(main_module, "get_settings", fake_get_settings) + monkeypatch.setattr(main_module, "ping_database", fake_ping_database) + + response = main_module.healthcheck() + + assert response.status_code == 200 + assert json.loads(response.body) == { + 
"status": "ok", + "environment": "test", + "services": { + "database": {"status": "ok"}, + "redis": {"status": "not_checked", "url": "redis://cache:6379/0"}, + "object_storage": { + "status": "not_checked", + "endpoint_url": "http://object-store", + }, + }, + } + assert ping_calls == [("postgresql://db", 7)] + + +def test_healthcheck_reports_degraded_when_database_is_unreachable(monkeypatch) -> None: + settings = Settings( + app_env="test", + database_url="postgresql://db", + redis_url="redis://alicebot:supersecret@cache:6379/0", + s3_endpoint_url="http://object-store", + healthcheck_timeout_seconds=4, + ) + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "ping_database", lambda *_args, **_kwargs: False) + + response = main_module.healthcheck() + + assert response.status_code == 503 + assert json.loads(response.body) == { + "status": "degraded", + "environment": "test", + "services": { + "database": {"status": "unreachable"}, + "redis": {"status": "not_checked", "url": "redis://cache:6379/0"}, + "object_storage": { + "status": "not_checked", + "endpoint_url": "http://object-store", + }, + }, + } + + +def test_healthcheck_route_is_registered() -> None: + route_paths = {route.path for route in main_module.app.routes} + + assert "/healthz" in route_paths + assert "/v0/context/compile" in route_paths + assert "/v0/responses" in route_paths + assert "/v0/memories/admit" in route_paths + assert "/v0/consents" in route_paths + assert "/v0/policies" in route_paths + assert "/v0/policies/{policy_id}" in route_paths + assert "/v0/policies/evaluate" in route_paths + assert "/v0/memories/extract-explicit-preferences" in route_paths + assert "/v0/memories" in route_paths + assert "/v0/memories/review-queue" in route_paths + assert "/v0/memories/evaluation-summary" in route_paths + assert "/v0/memories/semantic-retrieval" in route_paths + assert "/v0/memories/{memory_id}" in route_paths + assert "/v0/memories/{memory_id}/revisions" in route_paths + assert "/v0/memories/{memory_id}/labels" in route_paths + assert "/v0/embedding-configs" in route_paths + assert "/v0/memory-embeddings" in route_paths + assert "/v0/memories/{memory_id}/embeddings" in route_paths + assert "/v0/memory-embeddings/{memory_embedding_id}" in route_paths + assert "/v0/entities" in route_paths + assert "/v0/entity-edges" in route_paths + assert "/v0/tools/route" in route_paths + assert "/v0/execution-budgets" in route_paths + assert "/v0/execution-budgets/{execution_budget_id}" in route_paths + assert "/v0/execution-budgets/{execution_budget_id}/deactivate" in route_paths + assert "/v0/execution-budgets/{execution_budget_id}/supersede" in route_paths + assert "/v0/tool-executions" in route_paths + assert "/v0/tool-executions/{execution_id}" in route_paths + assert "/v0/tasks" in route_paths + assert "/v0/tasks/{task_id}" in route_paths + assert "/v0/tasks/{task_id}/workspace" in route_paths + assert "/v0/tasks/{task_id}/steps" in route_paths + assert "/v0/task-workspaces" in route_paths + assert "/v0/task-workspaces/{task_workspace_id}" in route_paths + assert "/v0/task-steps/{task_step_id}" in route_paths + assert "/v0/task-steps/{task_step_id}/transition" in route_paths + assert "/v0/entities/{entity_id}" in route_paths + assert "/v0/entities/{entity_id}/edges" in route_paths + + +def test_redact_url_credentials_strips_embedded_secrets() -> None: + assert main_module.redact_url_credentials("redis://alicebot:supersecret@cache:6379/0") == ( + "redis://cache:6379/0" + ) + 
assert main_module.redact_url_credentials("redis://cache:6379/0") == "redis://cache:6379/0" + + +def test_build_healthcheck_payload_keeps_boundary_statuses_consistent() -> None: + settings = Settings( + app_env="test", + redis_url="redis://alicebot:supersecret@cache:6379/0", + s3_endpoint_url="http://object-store", + ) + + assert main_module.build_healthcheck_payload(settings, database_ok=True) == { + "status": "ok", + "environment": "test", + "services": { + "database": {"status": "ok"}, + "redis": {"status": "not_checked", "url": "redis://cache:6379/0"}, + "object_storage": { + "status": "not_checked", + "endpoint_url": "http://object-store", + }, + }, + } + assert main_module.build_healthcheck_payload(settings, database_ok=False)["services"][ + "database" + ] == {"status": "unreachable"} + + +def test_compile_context_returns_trace_and_context_pack(monkeypatch) -> None: + user_id = uuid4() + thread_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_compile_and_persist_trace(store, *, user_id, thread_id, limits, semantic_retrieval): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["thread_id"] = thread_id + captured["limits"] = limits + captured["semantic_retrieval"] = semantic_retrieval + return CompiledTraceRun( + trace_id="trace-123", + trace_event_count=5, + context_pack={ + "compiler_version": "continuity_v0", + "scope": {"user_id": str(user_id), "thread_id": str(thread_id)}, + "limits": { + "max_sessions": 2, + "max_events": 4, + "max_memories": 3, + "max_entities": 2, + "max_entity_edges": 6, + }, + "user": { + "id": str(user_id), + "email": "owner@example.com", + "display_name": "Owner", + "created_at": "2026-03-11T09:00:00+00:00", + }, + "thread": { + "id": str(thread_id), + "title": "Thread", + "created_at": "2026-03-11T09:00:00+00:00", + "updated_at": "2026-03-11T09:01:00+00:00", + }, + "sessions": [], + "events": [], + "memories": [ + { + "id": "memory-123", + "memory_key": "user.preference.coffee", + "value": {"likes": "oat milk"}, + "status": "active", + "source_event_ids": ["event-1"], + "created_at": "2026-03-11T09:00:00+00:00", + "updated_at": "2026-03-11T09:02:00+00:00", + "source_provenance": {"sources": ["symbolic"], "semantic_score": None}, + } + ], + "memory_summary": { + "candidate_count": 2, + "included_count": 1, + "excluded_deleted_count": 1, + "excluded_limit_count": 0, + "hybrid_retrieval": { + "requested": False, + "embedding_config_id": None, + "query_vector_dimensions": 0, + "semantic_limit": 0, + "symbolic_selected_count": 1, + "semantic_selected_count": 0, + "merged_candidate_count": 1, + "deduplicated_count": 0, + "included_symbolic_only_count": 1, + "included_semantic_only_count": 0, + "included_dual_source_count": 0, + "similarity_metric": None, + "source_precedence": ["symbolic", "semantic"], + "symbolic_order": ["updated_at_asc", "created_at_asc", "id_asc"], + "semantic_order": ["score_desc", "created_at_asc", "id_asc"], + }, + }, + "entities": [ + { + "id": "entity-123", + "entity_type": "project", + "name": "AliceBot", + "source_memory_ids": ["memory-123"], + "created_at": "2026-03-11T09:03:00+00:00", + } + ], + "entity_summary": { + "candidate_count": 2, + "included_count": 1, + "excluded_limit_count": 1, + }, + "entity_edges": [ + { + "id": "edge-123", + 
"from_entity_id": "entity-123", + "to_entity_id": "entity-999", + "relationship_type": "depends_on", + "valid_from": "2026-03-11T09:04:00+00:00", + "valid_to": None, + "source_memory_ids": ["memory-123"], + "created_at": "2026-03-11T09:04:00+00:00", + } + ], + "entity_edge_summary": { + "anchor_entity_count": 1, + "candidate_count": 2, + "included_count": 1, + "excluded_limit_count": 1, + }, + }, + ) + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "compile_and_persist_trace", fake_compile_and_persist_trace) + + response = main_module.compile_context( + main_module.CompileContextRequest( + user_id=user_id, + thread_id=thread_id, + max_sessions=2, + max_events=4, + max_memories=3, + max_entities=2, + max_entity_edges=6, + ) + ) + + assert response.status_code == 200 + assert json.loads(response.body) == { + "trace_id": "trace-123", + "trace_event_count": 5, + "context_pack": { + "compiler_version": "continuity_v0", + "scope": {"user_id": str(user_id), "thread_id": str(thread_id)}, + "limits": { + "max_sessions": 2, + "max_events": 4, + "max_memories": 3, + "max_entities": 2, + "max_entity_edges": 6, + }, + "user": { + "id": str(user_id), + "email": "owner@example.com", + "display_name": "Owner", + "created_at": "2026-03-11T09:00:00+00:00", + }, + "thread": { + "id": str(thread_id), + "title": "Thread", + "created_at": "2026-03-11T09:00:00+00:00", + "updated_at": "2026-03-11T09:01:00+00:00", + }, + "sessions": [], + "events": [], + "memories": [ + { + "id": "memory-123", + "memory_key": "user.preference.coffee", + "value": {"likes": "oat milk"}, + "status": "active", + "source_event_ids": ["event-1"], + "created_at": "2026-03-11T09:00:00+00:00", + "updated_at": "2026-03-11T09:02:00+00:00", + "source_provenance": {"sources": ["symbolic"], "semantic_score": None}, + } + ], + "memory_summary": { + "candidate_count": 2, + "included_count": 1, + "excluded_deleted_count": 1, + "excluded_limit_count": 0, + "hybrid_retrieval": { + "requested": False, + "embedding_config_id": None, + "query_vector_dimensions": 0, + "semantic_limit": 0, + "symbolic_selected_count": 1, + "semantic_selected_count": 0, + "merged_candidate_count": 1, + "deduplicated_count": 0, + "included_symbolic_only_count": 1, + "included_semantic_only_count": 0, + "included_dual_source_count": 0, + "similarity_metric": None, + "source_precedence": ["symbolic", "semantic"], + "symbolic_order": ["updated_at_asc", "created_at_asc", "id_asc"], + "semantic_order": ["score_desc", "created_at_asc", "id_asc"], + }, + }, + "entities": [ + { + "id": "entity-123", + "entity_type": "project", + "name": "AliceBot", + "source_memory_ids": ["memory-123"], + "created_at": "2026-03-11T09:03:00+00:00", + } + ], + "entity_summary": { + "candidate_count": 2, + "included_count": 1, + "excluded_limit_count": 1, + }, + "entity_edges": [ + { + "id": "edge-123", + "from_entity_id": "entity-123", + "to_entity_id": "entity-999", + "relationship_type": "depends_on", + "valid_from": "2026-03-11T09:04:00+00:00", + "valid_to": None, + "source_memory_ids": ["memory-123"], + "created_at": "2026-03-11T09:04:00+00:00", + } + ], + "entity_edge_summary": { + "anchor_entity_count": 1, + "candidate_count": 2, + "included_count": 1, + "excluded_limit_count": 1, + }, + }, + } + assert captured["database_url"] == "postgresql://app" + assert captured["current_user_id"] == user_id + assert captured["user_id"] == user_id + assert 
captured["thread_id"] == thread_id + assert captured["limits"].max_sessions == 2 + assert captured["limits"].max_events == 4 + assert captured["limits"].max_memories == 3 + assert captured["limits"].max_entities == 2 + assert captured["limits"].max_entity_edges == 6 + assert captured["semantic_retrieval"] is None + + +def test_compile_context_returns_not_found_when_scope_row_is_missing(monkeypatch) -> None: + @contextmanager + def fake_user_connection(_database_url: str, _current_user_id): + yield object() + + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url="postgresql://app")) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr( + main_module, + "compile_and_persist_trace", + lambda *_args, **_kwargs: (_ for _ in ()).throw( + ContinuityStoreInvariantError("get_thread did not return a row from the database") + ), + ) + + response = main_module.compile_context( + main_module.CompileContextRequest(user_id=uuid4(), thread_id=uuid4()) + ) + + assert response.status_code == 404 + assert json.loads(response.body) == { + "detail": "get_thread did not return a row from the database", + } + + +def test_compile_context_routes_semantic_inputs_and_validation_errors(monkeypatch) -> None: + user_id = uuid4() + thread_id = uuid4() + config_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_compile_and_persist_trace(store, *, user_id, thread_id, limits, semantic_retrieval): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["thread_id"] = thread_id + captured["limits"] = limits + captured["semantic_retrieval"] = semantic_retrieval + return CompiledTraceRun( + trace_id="trace-semantic", + trace_event_count=7, + context_pack={ + "compiler_version": "continuity_v0", + "scope": {"user_id": str(user_id), "thread_id": str(thread_id)}, + "limits": { + "max_sessions": 3, + "max_events": 8, + "max_memories": 5, + "max_entities": 5, + "max_entity_edges": 10, + }, + "user": { + "id": str(user_id), + "email": "owner@example.com", + "display_name": "Owner", + "created_at": "2026-03-12T09:00:00+00:00", + }, + "thread": { + "id": str(thread_id), + "title": "Thread", + "created_at": "2026-03-12T09:00:00+00:00", + "updated_at": "2026-03-12T09:01:00+00:00", + }, + "sessions": [], + "events": [], + "memories": [ + { + "id": "memory-123", + "memory_key": "user.preference.coffee", + "value": {"likes": "oat milk"}, + "status": "active", + "source_event_ids": ["event-123"], + "created_at": "2026-03-12T09:00:00+00:00", + "updated_at": "2026-03-12T09:00:00+00:00", + "source_provenance": { + "sources": ["symbolic", "semantic"], + "semantic_score": 0.99, + }, + } + ], + "memory_summary": { + "candidate_count": 1, + "included_count": 1, + "excluded_deleted_count": 0, + "excluded_limit_count": 0, + "hybrid_retrieval": { + "requested": True, + "embedding_config_id": str(config_id), + "query_vector_dimensions": 3, + "semantic_limit": 2, + "symbolic_selected_count": 1, + "semantic_selected_count": 1, + "merged_candidate_count": 1, + "deduplicated_count": 1, + "included_symbolic_only_count": 0, + "included_semantic_only_count": 0, + "included_dual_source_count": 1, + "similarity_metric": "cosine_similarity", + "source_precedence": ["symbolic", "semantic"], + 
"symbolic_order": ["updated_at_asc", "created_at_asc", "id_asc"], + "semantic_order": ["score_desc", "created_at_asc", "id_asc"], + }, + }, + "entities": [], + "entity_summary": { + "candidate_count": 0, + "included_count": 0, + "excluded_limit_count": 0, + }, + "entity_edges": [], + "entity_edge_summary": { + "anchor_entity_count": 0, + "candidate_count": 0, + "included_count": 0, + "excluded_limit_count": 0, + }, + }, + ) + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "compile_and_persist_trace", fake_compile_and_persist_trace) + + response = main_module.compile_context( + main_module.CompileContextRequest( + user_id=user_id, + thread_id=thread_id, + semantic=main_module.CompileContextSemanticRequest( + embedding_config_id=config_id, + query_vector=[0.1, 0.2, 0.3], + limit=2, + ), + ) + ) + + assert response.status_code == 200 + assert json.loads(response.body)["context_pack"]["memory_summary"]["hybrid_retrieval"] == { + "requested": True, + "embedding_config_id": str(config_id), + "query_vector_dimensions": 3, + "semantic_limit": 2, + "symbolic_selected_count": 1, + "semantic_selected_count": 1, + "merged_candidate_count": 1, + "deduplicated_count": 1, + "included_symbolic_only_count": 0, + "included_semantic_only_count": 0, + "included_dual_source_count": 1, + "similarity_metric": "cosine_similarity", + "source_precedence": ["symbolic", "semantic"], + "symbolic_order": ["updated_at_asc", "created_at_asc", "id_asc"], + "semantic_order": ["score_desc", "created_at_asc", "id_asc"], + } + assert captured["database_url"] == "postgresql://app" + assert captured["current_user_id"] == user_id + assert captured["semantic_retrieval"].embedding_config_id == config_id + assert captured["semantic_retrieval"].query_vector == (0.1, 0.2, 0.3) + assert captured["semantic_retrieval"].limit == 2 + + monkeypatch.setattr( + main_module, + "compile_and_persist_trace", + lambda *_args, **_kwargs: (_ for _ in ()).throw( + SemanticMemoryRetrievalValidationError( + "embedding_config_id must reference an existing embedding config owned by the user" + ) + ), + ) + + error_response = main_module.compile_context( + main_module.CompileContextRequest( + user_id=user_id, + thread_id=thread_id, + semantic=main_module.CompileContextSemanticRequest( + embedding_config_id=config_id, + query_vector=[0.1, 0.2, 0.3], + limit=2, + ), + ) + ) + + assert error_response.status_code == 400 + assert json.loads(error_response.body) == { + "detail": "embedding_config_id must reference an existing embedding config owned by the user" + } + + +def test_generate_assistant_response_returns_assistant_and_trace_payload(monkeypatch) -> None: + user_id = uuid4() + thread_id = uuid4() + settings = Settings( + database_url="postgresql://app", + model_provider="openai_responses", + model_name="gpt-5-mini", + ) + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_generate_response(store, *, settings, user_id, thread_id, message_text, limits): + captured["store_type"] = type(store).__name__ + captured["settings"] = settings + captured["user_id"] = user_id + captured["thread_id"] = thread_id + captured["message_text"] = message_text + captured["limits"] = limits + return { + "assistant": { + "event_id": "assistant-event-123", + 
"sequence_no": 5, + "text": "Hello back.", + "model_provider": "openai_responses", + "model": "gpt-5-mini", + }, + "trace": { + "compile_trace_id": "compile-trace-123", + "compile_trace_event_count": 11, + "response_trace_id": "response-trace-123", + "response_trace_event_count": 2, + }, + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "generate_response", fake_generate_response) + + response = main_module.generate_assistant_response( + main_module.GenerateResponseRequest( + user_id=user_id, + thread_id=thread_id, + message="Hello?", + max_sessions=2, + max_events=4, + max_memories=3, + max_entities=2, + max_entity_edges=6, + ) + ) + + assert response.status_code == 200 + assert json.loads(response.body) == { + "assistant": { + "event_id": "assistant-event-123", + "sequence_no": 5, + "text": "Hello back.", + "model_provider": "openai_responses", + "model": "gpt-5-mini", + }, + "trace": { + "compile_trace_id": "compile-trace-123", + "compile_trace_event_count": 11, + "response_trace_id": "response-trace-123", + "response_trace_event_count": 2, + }, + } + assert captured["database_url"] == "postgresql://app" + assert captured["current_user_id"] == user_id + assert captured["user_id"] == user_id + assert captured["thread_id"] == thread_id + assert captured["message_text"] == "Hello?" + assert captured["limits"].max_sessions == 2 + assert captured["limits"].max_events == 4 + assert captured["limits"].max_memories == 3 + assert captured["limits"].max_entities == 2 + assert captured["limits"].max_entity_edges == 6 + + +def test_generate_assistant_response_returns_502_with_trace_when_model_invocation_fails( + monkeypatch, +) -> None: + user_id = uuid4() + thread_id = uuid4() + + @contextmanager + def fake_user_connection(_database_url: str, _current_user_id): + yield object() + + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url="postgresql://app")) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr( + main_module, + "generate_response", + lambda *_args, **_kwargs: ResponseFailure( + detail="upstream timeout", + trace={ + "compile_trace_id": "compile-trace-123", + "compile_trace_event_count": 9, + "response_trace_id": "response-trace-123", + "response_trace_event_count": 2, + }, + ), + ) + + response = main_module.generate_assistant_response( + main_module.GenerateResponseRequest( + user_id=user_id, + thread_id=thread_id, + message="Hello?", + ) + ) + + assert response.status_code == 502 + assert json.loads(response.body) == { + "detail": "upstream timeout", + "trace": { + "compile_trace_id": "compile-trace-123", + "compile_trace_event_count": 9, + "response_trace_id": "response-trace-123", + "response_trace_event_count": 2, + }, + } + + +def test_admit_memory_returns_decision_payload(monkeypatch) -> None: + user_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_admit_memory_candidate(store, *, user_id, candidate): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["candidate"] = candidate + return AdmissionDecisionOutput( + action="ADD", + reason="source_backed_add", + memory={ + "id": 
"memory-123", + "user_id": str(user_id), + "memory_key": "user.preference.coffee", + "value": {"likes": "oat milk"}, + "status": "active", + "source_event_ids": ["event-1"], + "created_at": "2026-03-11T09:00:00+00:00", + "updated_at": "2026-03-11T09:00:00+00:00", + "deleted_at": None, + }, + revision={ + "id": "revision-123", + "user_id": str(user_id), + "memory_id": "memory-123", + "sequence_no": 1, + "action": "ADD", + "memory_key": "user.preference.coffee", + "previous_value": None, + "new_value": {"likes": "oat milk"}, + "source_event_ids": ["event-1"], + "candidate": { + "memory_key": "user.preference.coffee", + "value": {"likes": "oat milk"}, + "source_event_ids": ["event-1"], + "delete_requested": False, + }, + "created_at": "2026-03-11T09:00:00+00:00", + }, + ) + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "admit_memory_candidate", fake_admit_memory_candidate) + + response = main_module.admit_memory( + main_module.AdmitMemoryRequest( + user_id=user_id, + memory_key="user.preference.coffee", + value={"likes": "oat milk"}, + source_event_ids=[uuid4()], + ) + ) + + assert response.status_code == 200 + assert json.loads(response.body) == { + "decision": "ADD", + "reason": "source_backed_add", + "memory": { + "id": "memory-123", + "user_id": str(user_id), + "memory_key": "user.preference.coffee", + "value": {"likes": "oat milk"}, + "status": "active", + "source_event_ids": ["event-1"], + "created_at": "2026-03-11T09:00:00+00:00", + "updated_at": "2026-03-11T09:00:00+00:00", + "deleted_at": None, + }, + "revision": { + "id": "revision-123", + "user_id": str(user_id), + "memory_id": "memory-123", + "sequence_no": 1, + "action": "ADD", + "memory_key": "user.preference.coffee", + "previous_value": None, + "new_value": {"likes": "oat milk"}, + "source_event_ids": ["event-1"], + "candidate": { + "memory_key": "user.preference.coffee", + "value": {"likes": "oat milk"}, + "source_event_ids": ["event-1"], + "delete_requested": False, + }, + "created_at": "2026-03-11T09:00:00+00:00", + }, + } + assert captured["database_url"] == "postgresql://app" + assert captured["current_user_id"] == user_id + assert captured["user_id"] == user_id + assert captured["candidate"].memory_key == "user.preference.coffee" + + +def test_admit_memory_returns_bad_request_when_source_validation_fails(monkeypatch) -> None: + @contextmanager + def fake_user_connection(_database_url: str, _current_user_id): + yield object() + + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url="postgresql://app")) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr( + main_module, + "admit_memory_candidate", + lambda *_args, **_kwargs: (_ for _ in ()).throw( + MemoryAdmissionValidationError("source_event_ids must all reference existing events owned by the user") + ), + ) + + response = main_module.admit_memory( + main_module.AdmitMemoryRequest( + user_id=uuid4(), + memory_key="user.preference.coffee", + value={"likes": "black"}, + source_event_ids=[uuid4()], + ) + ) + + assert response.status_code == 400 + assert json.loads(response.body) == { + "detail": "source_event_ids must all reference existing events owned by the user", + } + + +def test_extract_explicit_preferences_returns_payload(monkeypatch) -> None: + user_id = uuid4() + source_event_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, 
object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_extract_and_admit_explicit_preferences(store, *, user_id, request): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["request"] = request + return { + "candidates": [ + { + "memory_key": "user.preference.black_coffee", + "value": { + "kind": "explicit_preference", + "preference": "like", + "text": "black coffee", + }, + "source_event_ids": [str(source_event_id)], + "delete_requested": False, + "pattern": "i_like", + "subject_text": "black coffee", + } + ], + "admissions": [ + { + "decision": "ADD", + "reason": "source_backed_add", + "memory": { + "id": "memory-123", + "user_id": str(user_id), + "memory_key": "user.preference.black_coffee", + "value": { + "kind": "explicit_preference", + "preference": "like", + "text": "black coffee", + }, + "status": "active", + "source_event_ids": [str(source_event_id)], + "created_at": "2026-03-12T09:00:00+00:00", + "updated_at": "2026-03-12T09:00:00+00:00", + "deleted_at": None, + }, + "revision": { + "id": "revision-123", + "user_id": str(user_id), + "memory_id": "memory-123", + "sequence_no": 1, + "action": "ADD", + "memory_key": "user.preference.black_coffee", + "previous_value": None, + "new_value": { + "kind": "explicit_preference", + "preference": "like", + "text": "black coffee", + }, + "source_event_ids": [str(source_event_id)], + "candidate": { + "memory_key": "user.preference.black_coffee", + "value": { + "kind": "explicit_preference", + "preference": "like", + "text": "black coffee", + }, + "source_event_ids": [str(source_event_id)], + "delete_requested": False, + }, + "created_at": "2026-03-12T09:00:00+00:00", + }, + } + ], + "summary": { + "source_event_id": str(source_event_id), + "source_event_kind": "message.user", + "candidate_count": 1, + "admission_count": 1, + "persisted_change_count": 1, + "noop_count": 0, + }, + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr( + main_module, + "extract_and_admit_explicit_preferences", + fake_extract_and_admit_explicit_preferences, + ) + + response = main_module.extract_explicit_preferences( + main_module.ExtractExplicitPreferencesRequest( + user_id=user_id, + source_event_id=source_event_id, + ) + ) + + assert response.status_code == 200 + assert json.loads(response.body) == { + "candidates": [ + { + "memory_key": "user.preference.black_coffee", + "value": { + "kind": "explicit_preference", + "preference": "like", + "text": "black coffee", + }, + "source_event_ids": [str(source_event_id)], + "delete_requested": False, + "pattern": "i_like", + "subject_text": "black coffee", + } + ], + "admissions": [ + { + "decision": "ADD", + "reason": "source_backed_add", + "memory": { + "id": "memory-123", + "user_id": str(user_id), + "memory_key": "user.preference.black_coffee", + "value": { + "kind": "explicit_preference", + "preference": "like", + "text": "black coffee", + }, + "status": "active", + "source_event_ids": [str(source_event_id)], + "created_at": "2026-03-12T09:00:00+00:00", + "updated_at": "2026-03-12T09:00:00+00:00", + "deleted_at": None, + }, + "revision": { + "id": "revision-123", + "user_id": str(user_id), + "memory_id": "memory-123", + "sequence_no": 1, + "action": "ADD", + "memory_key": 
"user.preference.black_coffee", + "previous_value": None, + "new_value": { + "kind": "explicit_preference", + "preference": "like", + "text": "black coffee", + }, + "source_event_ids": [str(source_event_id)], + "candidate": { + "memory_key": "user.preference.black_coffee", + "value": { + "kind": "explicit_preference", + "preference": "like", + "text": "black coffee", + }, + "source_event_ids": [str(source_event_id)], + "delete_requested": False, + }, + "created_at": "2026-03-12T09:00:00+00:00", + }, + } + ], + "summary": { + "source_event_id": str(source_event_id), + "source_event_kind": "message.user", + "candidate_count": 1, + "admission_count": 1, + "persisted_change_count": 1, + "noop_count": 0, + }, + } + assert captured["database_url"] == "postgresql://app" + assert captured["current_user_id"] == user_id + assert captured["user_id"] == user_id + assert captured["request"].source_event_id == source_event_id + + +def test_extract_explicit_preferences_returns_bad_request_when_source_event_is_invalid( + monkeypatch, +) -> None: + @contextmanager + def fake_user_connection(_database_url: str, _current_user_id): + yield object() + + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url="postgresql://app")) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr( + main_module, + "extract_and_admit_explicit_preferences", + lambda *_args, **_kwargs: (_ for _ in ()).throw( + main_module.ExplicitPreferenceExtractionValidationError( + "source_event_id must reference an existing message.user event owned by the user" + ) + ), + ) + + response = main_module.extract_explicit_preferences( + main_module.ExtractExplicitPreferencesRequest( + user_id=uuid4(), + source_event_id=uuid4(), + ) + ) + + assert response.status_code == 400 + assert json.loads(response.body) == { + "detail": "source_event_id must reference an existing message.user event owned by the user", + } + + +def test_list_memories_returns_review_payload(monkeypatch) -> None: + user_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_list_memory_review_records(store, *, user_id, status, limit): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["status"] = status + captured["limit"] = limit + return { + "items": [ + { + "id": "memory-123", + "memory_key": "user.preference.coffee", + "value": {"likes": "oat milk"}, + "status": "active", + "source_event_ids": ["event-1"], + "created_at": "2026-03-11T09:00:00+00:00", + "updated_at": "2026-03-11T09:02:00+00:00", + "deleted_at": None, + } + ], + "summary": { + "status": "active", + "limit": 10, + "returned_count": 1, + "total_count": 1, + "has_more": False, + "order": ["updated_at_desc", "created_at_desc", "id_desc"], + }, + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "list_memory_review_records", fake_list_memory_review_records) + + response = main_module.list_memories(user_id=user_id, status="active", limit=10) + + assert response.status_code == 200 + assert json.loads(response.body) == { + "items": [ + { + "id": "memory-123", + "memory_key": "user.preference.coffee", + "value": {"likes": "oat 
milk"}, + "status": "active", + "source_event_ids": ["event-1"], + "created_at": "2026-03-11T09:00:00+00:00", + "updated_at": "2026-03-11T09:02:00+00:00", + "deleted_at": None, + } + ], + "summary": { + "status": "active", + "limit": 10, + "returned_count": 1, + "total_count": 1, + "has_more": False, + "order": ["updated_at_desc", "created_at_desc", "id_desc"], + }, + } + assert captured["database_url"] == "postgresql://app" + assert captured["current_user_id"] == user_id + assert captured["user_id"] == user_id + assert captured["status"] == "active" + assert captured["limit"] == 10 + + +def test_get_memory_returns_not_found_when_memory_is_inaccessible(monkeypatch) -> None: + memory_id = uuid4() + + @contextmanager + def fake_user_connection(_database_url: str, _current_user_id): + yield object() + + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url="postgresql://app")) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr( + main_module, + "get_memory_review_record", + lambda *_args, **_kwargs: (_ for _ in ()).throw( + main_module.MemoryReviewNotFoundError(f"memory {memory_id} was not found") + ), + ) + + response = main_module.get_memory(memory_id=memory_id, user_id=uuid4()) + + assert response.status_code == 404 + assert json.loads(response.body) == { + "detail": f"memory {memory_id} was not found", + } + + +def test_list_memory_review_queue_returns_unlabeled_active_queue_payload(monkeypatch) -> None: + user_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_list_memory_review_queue_records(store, *, user_id, limit): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["limit"] = limit + return { + "items": [ + { + "id": "memory-123", + "memory_key": "user.preference.coffee", + "value": {"likes": "oat milk"}, + "status": "active", + "source_event_ids": ["event-1"], + "created_at": "2026-03-12T09:00:00+00:00", + "updated_at": "2026-03-12T09:02:00+00:00", + } + ], + "summary": { + "memory_status": "active", + "review_state": "unlabeled", + "limit": 7, + "returned_count": 1, + "total_count": 1, + "has_more": False, + "order": ["updated_at_desc", "created_at_desc", "id_desc"], + }, + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "list_memory_review_queue_records", fake_list_memory_review_queue_records) + + response = main_module.list_memory_review_queue(user_id=user_id, limit=7) + + assert response.status_code == 200 + assert json.loads(response.body) == { + "items": [ + { + "id": "memory-123", + "memory_key": "user.preference.coffee", + "value": {"likes": "oat milk"}, + "status": "active", + "source_event_ids": ["event-1"], + "created_at": "2026-03-12T09:00:00+00:00", + "updated_at": "2026-03-12T09:02:00+00:00", + } + ], + "summary": { + "memory_status": "active", + "review_state": "unlabeled", + "limit": 7, + "returned_count": 1, + "total_count": 1, + "has_more": False, + "order": ["updated_at_desc", "created_at_desc", "id_desc"], + }, + } + assert captured["database_url"] == "postgresql://app" + assert captured["current_user_id"] == user_id + assert captured["user_id"] == user_id + assert 
captured["limit"] == 7 + + +def test_get_memories_evaluation_summary_returns_aggregate_payload(monkeypatch) -> None: + user_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_get_memory_evaluation_summary(store, *, user_id): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + return { + "summary": { + "total_memory_count": 4, + "active_memory_count": 3, + "deleted_memory_count": 1, + "labeled_memory_count": 2, + "unlabeled_memory_count": 2, + "total_label_row_count": 3, + "label_row_counts_by_value": { + "correct": 1, + "incorrect": 0, + "outdated": 1, + "insufficient_evidence": 1, + }, + "label_value_order": [ + "correct", + "incorrect", + "outdated", + "insufficient_evidence", + ], + } + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "get_memory_evaluation_summary", fake_get_memory_evaluation_summary) + + response = main_module.get_memories_evaluation_summary(user_id=user_id) + + assert response.status_code == 200 + assert json.loads(response.body) == { + "summary": { + "total_memory_count": 4, + "active_memory_count": 3, + "deleted_memory_count": 1, + "labeled_memory_count": 2, + "unlabeled_memory_count": 2, + "total_label_row_count": 3, + "label_row_counts_by_value": { + "correct": 1, + "incorrect": 0, + "outdated": 1, + "insufficient_evidence": 1, + }, + "label_value_order": [ + "correct", + "incorrect", + "outdated", + "insufficient_evidence", + ], + } + } + assert captured["database_url"] == "postgresql://app" + assert captured["current_user_id"] == user_id + assert captured["user_id"] == user_id + + +def test_list_memory_revisions_returns_review_payload(monkeypatch) -> None: + user_id = uuid4() + memory_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_list_memory_revision_review_records(store, *, user_id, memory_id, limit): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["memory_id"] = memory_id + captured["limit"] = limit + return { + "items": [ + { + "id": "revision-123", + "memory_id": str(memory_id), + "sequence_no": 1, + "action": "ADD", + "memory_key": "user.preference.coffee", + "previous_value": None, + "new_value": {"likes": "black"}, + "source_event_ids": ["event-1"], + "created_at": "2026-03-11T09:00:00+00:00", + } + ], + "summary": { + "memory_id": str(memory_id), + "limit": 5, + "returned_count": 1, + "total_count": 1, + "has_more": False, + "order": ["sequence_no_asc"], + }, + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr( + main_module, + "list_memory_revision_review_records", + fake_list_memory_revision_review_records, + ) + + response = main_module.list_memory_revisions(memory_id=memory_id, user_id=user_id, limit=5) + + assert response.status_code == 200 + assert json.loads(response.body) == { + "items": [ + { + "id": 
"revision-123", + "memory_id": str(memory_id), + "sequence_no": 1, + "action": "ADD", + "memory_key": "user.preference.coffee", + "previous_value": None, + "new_value": {"likes": "black"}, + "source_event_ids": ["event-1"], + "created_at": "2026-03-11T09:00:00+00:00", + } + ], + "summary": { + "memory_id": str(memory_id), + "limit": 5, + "returned_count": 1, + "total_count": 1, + "has_more": False, + "order": ["sequence_no_asc"], + }, + } + assert captured["database_url"] == "postgresql://app" + assert captured["current_user_id"] == user_id + assert captured["user_id"] == user_id + assert captured["memory_id"] == memory_id + assert captured["limit"] == 5 + + +def test_create_memory_review_label_returns_created_payload(monkeypatch) -> None: + memory_id = uuid4() + user_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_create_memory_review_label_record(store, *, user_id, memory_id, label, note): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["memory_id"] = memory_id + captured["label"] = label + captured["note"] = note + return { + "label": { + "id": "label-123", + "memory_id": str(memory_id), + "reviewer_user_id": str(user_id), + "label": "correct", + "note": "Backed by the latest source.", + "created_at": "2026-03-12T09:00:00+00:00", + }, + "summary": { + "memory_id": str(memory_id), + "total_count": 1, + "counts_by_label": { + "correct": 1, + "incorrect": 0, + "outdated": 0, + "insufficient_evidence": 0, + }, + "order": ["created_at_asc", "id_asc"], + }, + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr( + main_module, + "create_memory_review_label_record", + fake_create_memory_review_label_record, + ) + + response = main_module.create_memory_review_label( + memory_id, + main_module.CreateMemoryReviewLabelRequest( + user_id=user_id, + label="correct", + note="Backed by the latest source.", + ), + ) + + assert response.status_code == 201 + assert json.loads(response.body) == { + "label": { + "id": "label-123", + "memory_id": str(memory_id), + "reviewer_user_id": str(user_id), + "label": "correct", + "note": "Backed by the latest source.", + "created_at": "2026-03-12T09:00:00+00:00", + }, + "summary": { + "memory_id": str(memory_id), + "total_count": 1, + "counts_by_label": { + "correct": 1, + "incorrect": 0, + "outdated": 0, + "insufficient_evidence": 0, + }, + "order": ["created_at_asc", "id_asc"], + }, + } + assert captured["database_url"] == "postgresql://app" + assert captured["current_user_id"] == user_id + assert captured["memory_id"] == memory_id + assert captured["label"] == "correct" + assert captured["note"] == "Backed by the latest source." 
+ + +def test_create_memory_review_label_returns_not_found_for_inaccessible_memory(monkeypatch) -> None: + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url="postgresql://app")) + + @contextmanager + def fake_user_connection(_database_url: str, _current_user_id): + yield object() + + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr( + main_module, + "create_memory_review_label_record", + lambda *_args, **_kwargs: (_ for _ in ()).throw(MemoryReviewNotFoundError("memory missing")), + ) + + response = main_module.create_memory_review_label( + uuid4(), + main_module.CreateMemoryReviewLabelRequest(user_id=uuid4(), label="incorrect"), + ) + + assert response.status_code == 404 + assert json.loads(response.body) == {"detail": "memory missing"} + + +def test_list_memory_review_labels_returns_deterministic_items_and_summary(monkeypatch) -> None: + memory_id = uuid4() + user_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_list_memory_review_label_records(store, *, user_id, memory_id): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["memory_id"] = memory_id + return { + "items": [ + { + "id": "label-123", + "memory_id": str(memory_id), + "reviewer_user_id": str(user_id), + "label": "incorrect", + "note": "Conflicts with the latest event.", + "created_at": "2026-03-12T09:00:00+00:00", + }, + { + "id": "label-124", + "memory_id": str(memory_id), + "reviewer_user_id": str(user_id), + "label": "outdated", + "note": None, + "created_at": "2026-03-12T09:01:00+00:00", + }, + ], + "summary": { + "memory_id": str(memory_id), + "total_count": 2, + "counts_by_label": { + "correct": 0, + "incorrect": 1, + "outdated": 1, + "insufficient_evidence": 0, + }, + "order": ["created_at_asc", "id_asc"], + }, + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr( + main_module, + "list_memory_review_label_records", + fake_list_memory_review_label_records, + ) + + response = main_module.list_memory_review_labels(memory_id=memory_id, user_id=user_id) + + assert response.status_code == 200 + assert json.loads(response.body) == { + "items": [ + { + "id": "label-123", + "memory_id": str(memory_id), + "reviewer_user_id": str(user_id), + "label": "incorrect", + "note": "Conflicts with the latest event.", + "created_at": "2026-03-12T09:00:00+00:00", + }, + { + "id": "label-124", + "memory_id": str(memory_id), + "reviewer_user_id": str(user_id), + "label": "outdated", + "note": None, + "created_at": "2026-03-12T09:01:00+00:00", + }, + ], + "summary": { + "memory_id": str(memory_id), + "total_count": 2, + "counts_by_label": { + "correct": 0, + "incorrect": 1, + "outdated": 1, + "insufficient_evidence": 0, + }, + "order": ["created_at_asc", "id_asc"], + }, + } + assert captured["database_url"] == "postgresql://app" + assert captured["current_user_id"] == user_id + assert captured["memory_id"] == memory_id + + +def test_list_memory_review_labels_returns_not_found_for_inaccessible_memory(monkeypatch) -> None: + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url="postgresql://app")) + + @contextmanager + def 
fake_user_connection(_database_url: str, _current_user_id): + yield object() + + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr( + main_module, + "list_memory_review_label_records", + lambda *_args, **_kwargs: (_ for _ in ()).throw(MemoryReviewNotFoundError("memory hidden")), + ) + + response = main_module.list_memory_review_labels(uuid4(), uuid4()) + + assert response.status_code == 404 + assert json.loads(response.body) == {"detail": "memory hidden"} + + +def test_create_embedding_config_returns_created_payload(monkeypatch) -> None: + user_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_create_embedding_config_record(store, *, user_id, config): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["config"] = config + return { + "embedding_config": { + "id": "config-123", + "provider": "openai", + "model": "text-embedding-3-large", + "version": "2026-03-12", + "dimensions": 3, + "status": "active", + "metadata": {"task": "memory_retrieval"}, + "created_at": "2026-03-12T10:00:00+00:00", + } + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "create_embedding_config_record", fake_create_embedding_config_record) + + response = main_module.create_embedding_config( + main_module.CreateEmbeddingConfigRequest( + user_id=user_id, + provider="openai", + model="text-embedding-3-large", + version="2026-03-12", + dimensions=3, + status="active", + metadata={"task": "memory_retrieval"}, + ) + ) + + assert response.status_code == 201 + assert json.loads(response.body) == { + "embedding_config": { + "id": "config-123", + "provider": "openai", + "model": "text-embedding-3-large", + "version": "2026-03-12", + "dimensions": 3, + "status": "active", + "metadata": {"task": "memory_retrieval"}, + "created_at": "2026-03-12T10:00:00+00:00", + } + } + assert captured["database_url"] == "postgresql://app" + assert captured["current_user_id"] == user_id + assert captured["user_id"] == user_id + assert captured["config"].provider == "openai" + + +def test_create_embedding_config_returns_bad_request_for_validation_failure(monkeypatch) -> None: + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url="postgresql://app")) + + @contextmanager + def fake_user_connection(_database_url: str, _current_user_id): + yield object() + + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr( + main_module, + "create_embedding_config_record", + lambda *_args, **_kwargs: (_ for _ in ()).throw( + EmbeddingConfigValidationError( + "embedding config already exists for provider/model/version under the user scope: " + "openai/text-embedding-3-large/2026-03-12" + ) + ), + ) + + response = main_module.create_embedding_config( + main_module.CreateEmbeddingConfigRequest( + user_id=uuid4(), + provider="openai", + model="text-embedding-3-large", + version="2026-03-12", + dimensions=3, + status="active", + metadata={"task": "memory_retrieval"}, + ) + ) + + assert response.status_code == 400 + assert json.loads(response.body) == { + "detail": ( + "embedding config already exists for provider/model/version 
under the user scope: " + "openai/text-embedding-3-large/2026-03-12" + ) + } + + +def test_upsert_memory_embedding_routes_success_and_validation_errors(monkeypatch) -> None: + user_id = uuid4() + memory_id = uuid4() + config_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_upsert_memory_embedding_record(store, *, user_id, request): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["request"] = request + return { + "embedding": { + "id": "embedding-123", + "memory_id": str(memory_id), + "embedding_config_id": str(config_id), + "dimensions": 3, + "vector": [0.1, 0.2, 0.3], + "created_at": "2026-03-12T10:00:00+00:00", + "updated_at": "2026-03-12T10:00:00+00:00", + }, + "write_mode": "created", + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "upsert_memory_embedding_record", fake_upsert_memory_embedding_record) + + response = main_module.upsert_memory_embedding( + main_module.UpsertMemoryEmbeddingRequest( + user_id=user_id, + memory_id=memory_id, + embedding_config_id=config_id, + vector=[0.1, 0.2, 0.3], + ) + ) + + assert response.status_code == 201 + assert json.loads(response.body)["write_mode"] == "created" + assert captured["database_url"] == "postgresql://app" + assert captured["current_user_id"] == user_id + assert captured["user_id"] == user_id + assert captured["request"].memory_id == memory_id + + monkeypatch.setattr( + main_module, + "upsert_memory_embedding_record", + lambda *_args, **_kwargs: (_ for _ in ()).throw( + MemoryEmbeddingValidationError( + "embedding_config_id must reference an existing embedding config owned by the user" + ) + ), + ) + + error_response = main_module.upsert_memory_embedding( + main_module.UpsertMemoryEmbeddingRequest( + user_id=user_id, + memory_id=memory_id, + embedding_config_id=config_id, + vector=[0.1, 0.2, 0.3], + ) + ) + + assert error_response.status_code == 400 + assert json.loads(error_response.body) == { + "detail": "embedding_config_id must reference an existing embedding config owned by the user" + } + + +def test_retrieve_semantic_memories_routes_success_and_validation_errors(monkeypatch) -> None: + user_id = uuid4() + config_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_retrieve_semantic_memory_records(store, *, user_id, request): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["request"] = request + return { + "items": [ + { + "memory_id": "memory-123", + "memory_key": "user.preference.coffee", + "value": {"likes": "oat milk"}, + "source_event_ids": ["event-123"], + "created_at": "2026-03-12T10:00:00+00:00", + "updated_at": "2026-03-12T10:00:00+00:00", + "score": 0.99, + } + ], + "summary": { + "embedding_config_id": str(config_id), + "limit": 5, + "returned_count": 1, + "similarity_metric": "cosine_similarity", + "order": ["score_desc", "created_at_asc", "id_asc"], + }, + } + + monkeypatch.setattr(main_module, 
"get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr( + main_module, + "retrieve_semantic_memory_records", + fake_retrieve_semantic_memory_records, + ) + + response = main_module.retrieve_semantic_memories( + main_module.RetrieveSemanticMemoriesRequest( + user_id=user_id, + embedding_config_id=config_id, + query_vector=[0.1, 0.2, 0.3], + limit=5, + ) + ) + + assert response.status_code == 200 + assert json.loads(response.body)["summary"] == { + "embedding_config_id": str(config_id), + "limit": 5, + "returned_count": 1, + "similarity_metric": "cosine_similarity", + "order": ["score_desc", "created_at_asc", "id_asc"], + } + assert captured["database_url"] == "postgresql://app" + assert captured["current_user_id"] == user_id + assert captured["user_id"] == user_id + assert captured["request"].embedding_config_id == config_id + assert captured["request"].query_vector == (0.1, 0.2, 0.3) + + monkeypatch.setattr( + main_module, + "retrieve_semantic_memory_records", + lambda *_args, **_kwargs: (_ for _ in ()).throw( + SemanticMemoryRetrievalValidationError( + "embedding_config_id must reference an existing embedding config owned by the user" + ) + ), + ) + + error_response = main_module.retrieve_semantic_memories( + main_module.RetrieveSemanticMemoriesRequest( + user_id=user_id, + embedding_config_id=config_id, + query_vector=[0.1, 0.2, 0.3], + limit=5, + ) + ) + + assert error_response.status_code == 400 + assert json.loads(error_response.body) == { + "detail": "embedding_config_id must reference an existing embedding config owned by the user" + } + + +def test_memory_embedding_read_routes_return_payload_and_not_found(monkeypatch) -> None: + user_id = uuid4() + memory_id = uuid4() + embedding_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(_database_url: str, _current_user_id): + yield object() + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr( + main_module, + "list_memory_embedding_records", + lambda *_args, **_kwargs: { + "items": [ + { + "id": str(embedding_id), + "memory_id": str(memory_id), + "embedding_config_id": "config-123", + "dimensions": 3, + "vector": [0.1, 0.2, 0.3], + "created_at": "2026-03-12T10:00:00+00:00", + "updated_at": "2026-03-12T10:00:00+00:00", + } + ], + "summary": { + "memory_id": str(memory_id), + "total_count": 1, + "order": ["created_at_asc", "id_asc"], + }, + }, + ) + monkeypatch.setattr( + main_module, + "get_memory_embedding_record", + lambda *_args, **_kwargs: { + "embedding": { + "id": str(embedding_id), + "memory_id": str(memory_id), + "embedding_config_id": "config-123", + "dimensions": 3, + "vector": [0.1, 0.2, 0.3], + "created_at": "2026-03-12T10:00:00+00:00", + "updated_at": "2026-03-12T10:00:00+00:00", + } + }, + ) + + list_response = main_module.list_memory_embeddings(memory_id=memory_id, user_id=user_id) + detail_response = main_module.get_memory_embedding(memory_embedding_id=embedding_id, user_id=user_id) + + assert list_response.status_code == 200 + assert json.loads(list_response.body)["summary"]["memory_id"] == str(memory_id) + assert detail_response.status_code == 200 + assert json.loads(detail_response.body)["embedding"]["id"] == str(embedding_id) + + monkeypatch.setattr( + main_module, + "get_memory_embedding_record", + lambda *_args, **_kwargs: (_ for _ in ()).throw( + 
MemoryEmbeddingNotFoundError(f"memory embedding {embedding_id} was not found") + ), + ) + + not_found_response = main_module.get_memory_embedding( + memory_embedding_id=embedding_id, + user_id=user_id, + ) + + assert not_found_response.status_code == 404 + assert json.loads(not_found_response.body) == { + "detail": f"memory embedding {embedding_id} was not found" + } + + +def test_create_entity_returns_created_payload(monkeypatch) -> None: + user_id = uuid4() + first_memory_id = uuid4() + second_memory_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_create_entity_record(store, *, user_id, entity): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["entity"] = entity + return { + "entity": { + "id": "entity-123", + "entity_type": "project", + "name": "AliceBot", + "source_memory_ids": [str(first_memory_id), str(second_memory_id)], + "created_at": "2026-03-12T10:00:00+00:00", + } + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "create_entity_record", fake_create_entity_record) + + response = main_module.create_entity( + main_module.CreateEntityRequest( + user_id=user_id, + entity_type="project", + name="AliceBot", + source_memory_ids=[first_memory_id, second_memory_id], + ) + ) + + assert response.status_code == 201 + assert json.loads(response.body) == { + "entity": { + "id": "entity-123", + "entity_type": "project", + "name": "AliceBot", + "source_memory_ids": [str(first_memory_id), str(second_memory_id)], + "created_at": "2026-03-12T10:00:00+00:00", + } + } + assert captured["database_url"] == "postgresql://app" + assert captured["current_user_id"] == user_id + assert captured["user_id"] == user_id + assert captured["entity"].entity_type == "project" + assert captured["entity"].name == "AliceBot" + + +def test_create_entity_returns_bad_request_when_source_memory_validation_fails(monkeypatch) -> None: + @contextmanager + def fake_user_connection(_database_url: str, _current_user_id): + yield object() + + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url="postgresql://app")) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr( + main_module, + "create_entity_record", + lambda *_args, **_kwargs: (_ for _ in ()).throw( + EntityValidationError("source_memory_ids must all reference existing memories owned by the user") + ), + ) + + response = main_module.create_entity( + main_module.CreateEntityRequest( + user_id=uuid4(), + entity_type="person", + name="Samir", + source_memory_ids=[uuid4()], + ) + ) + + assert response.status_code == 400 + assert json.loads(response.body) == { + "detail": "source_memory_ids must all reference existing memories owned by the user", + } + + +def test_create_entity_edge_returns_created_payload(monkeypatch) -> None: + user_id = uuid4() + from_entity_id = uuid4() + to_entity_id = uuid4() + source_memory_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = 
current_user_id + yield object() + + def fake_create_entity_edge_record(store, *, user_id, edge): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["edge"] = edge + return { + "edge": { + "id": "edge-123", + "from_entity_id": str(from_entity_id), + "to_entity_id": str(to_entity_id), + "relationship_type": "works_on", + "valid_from": "2026-03-12T10:00:00+00:00", + "valid_to": None, + "source_memory_ids": [str(source_memory_id)], + "created_at": "2026-03-12T10:01:00+00:00", + } + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "create_entity_edge_record", fake_create_entity_edge_record) + + response = main_module.create_entity_edge( + main_module.CreateEntityEdgeRequest( + user_id=user_id, + from_entity_id=from_entity_id, + to_entity_id=to_entity_id, + relationship_type="works_on", + valid_from="2026-03-12T10:00:00+00:00", + source_memory_ids=[source_memory_id], + ) + ) + + assert response.status_code == 201 + assert json.loads(response.body) == { + "edge": { + "id": "edge-123", + "from_entity_id": str(from_entity_id), + "to_entity_id": str(to_entity_id), + "relationship_type": "works_on", + "valid_from": "2026-03-12T10:00:00+00:00", + "valid_to": None, + "source_memory_ids": [str(source_memory_id)], + "created_at": "2026-03-12T10:01:00+00:00", + } + } + assert captured["database_url"] == "postgresql://app" + assert captured["current_user_id"] == user_id + assert captured["user_id"] == user_id + assert captured["edge"].from_entity_id == from_entity_id + assert captured["edge"].to_entity_id == to_entity_id + + +def test_create_entity_edge_returns_bad_request_for_validation_failure(monkeypatch) -> None: + @contextmanager + def fake_user_connection(_database_url: str, _current_user_id): + yield object() + + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url="postgresql://app")) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr( + main_module, + "create_entity_edge_record", + lambda *_args, **_kwargs: (_ for _ in ()).throw( + EntityEdgeValidationError("valid_to must be greater than or equal to valid_from") + ), + ) + + response = main_module.create_entity_edge( + main_module.CreateEntityEdgeRequest( + user_id=uuid4(), + from_entity_id=uuid4(), + to_entity_id=uuid4(), + relationship_type="works_on", + valid_from="2026-03-12T11:00:00+00:00", + valid_to="2026-03-12T10:00:00+00:00", + source_memory_ids=[uuid4()], + ) + ) + + assert response.status_code == 400 + assert json.loads(response.body) == { + "detail": "valid_to must be greater than or equal to valid_from", + } + + +def test_list_entities_returns_deterministic_payload(monkeypatch) -> None: + user_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_list_entity_records(store, *, user_id): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + return { + "items": [ + { + "id": "entity-123", + "entity_type": "project", + "name": "AliceBot", + "source_memory_ids": ["memory-1"], + "created_at": "2026-03-12T10:00:00+00:00", + } + ], + "summary": { + "total_count": 1, + "order": ["created_at_asc", "id_asc"], + }, + } + + 
monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "list_entity_records", fake_list_entity_records) + + response = main_module.list_entities(user_id=user_id) + + assert response.status_code == 200 + assert json.loads(response.body) == { + "items": [ + { + "id": "entity-123", + "entity_type": "project", + "name": "AliceBot", + "source_memory_ids": ["memory-1"], + "created_at": "2026-03-12T10:00:00+00:00", + } + ], + "summary": { + "total_count": 1, + "order": ["created_at_asc", "id_asc"], + }, + } + assert captured["database_url"] == "postgresql://app" + assert captured["current_user_id"] == user_id + assert captured["user_id"] == user_id + + +def test_list_entity_edges_returns_deterministic_payload(monkeypatch) -> None: + user_id = uuid4() + entity_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_list_entity_edge_records(store, *, user_id, entity_id): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["entity_id"] = entity_id + return { + "items": [ + { + "id": "edge-123", + "from_entity_id": str(entity_id), + "to_entity_id": "entity-456", + "relationship_type": "works_on", + "valid_from": None, + "valid_to": None, + "source_memory_ids": ["memory-1"], + "created_at": "2026-03-12T10:00:00+00:00", + } + ], + "summary": { + "entity_id": str(entity_id), + "total_count": 1, + "order": ["created_at_asc", "id_asc"], + }, + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "list_entity_edge_records", fake_list_entity_edge_records) + + response = main_module.list_entity_edges(entity_id=entity_id, user_id=user_id) + + assert response.status_code == 200 + assert json.loads(response.body) == { + "items": [ + { + "id": "edge-123", + "from_entity_id": str(entity_id), + "to_entity_id": "entity-456", + "relationship_type": "works_on", + "valid_from": None, + "valid_to": None, + "source_memory_ids": ["memory-1"], + "created_at": "2026-03-12T10:00:00+00:00", + } + ], + "summary": { + "entity_id": str(entity_id), + "total_count": 1, + "order": ["created_at_asc", "id_asc"], + }, + } + assert captured["database_url"] == "postgresql://app" + assert captured["current_user_id"] == user_id + assert captured["user_id"] == user_id + assert captured["entity_id"] == entity_id + + +def test_list_entity_edges_returns_not_found_for_inaccessible_entity(monkeypatch) -> None: + entity_id = uuid4() + + @contextmanager + def fake_user_connection(_database_url: str, _current_user_id): + yield object() + + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url="postgresql://app")) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr( + main_module, + "list_entity_edge_records", + lambda *_args, **_kwargs: (_ for _ in ()).throw( + EntityNotFoundError(f"entity {entity_id} was not found") + ), + ) + + response = main_module.list_entity_edges(entity_id=entity_id, user_id=uuid4()) + + assert response.status_code == 404 + assert json.loads(response.body) == { + "detail": f"entity {entity_id} was not found", + } + + 
+def test_get_entity_returns_detail_payload(monkeypatch) -> None: + user_id = uuid4() + entity_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_get_entity_record(store, *, user_id, entity_id): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["entity_id"] = entity_id + return { + "entity": { + "id": str(entity_id), + "entity_type": "person", + "name": "Samir", + "source_memory_ids": ["memory-1"], + "created_at": "2026-03-12T10:00:00+00:00", + } + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "get_entity_record", fake_get_entity_record) + + response = main_module.get_entity(entity_id=entity_id, user_id=user_id) + + assert response.status_code == 200 + assert json.loads(response.body) == { + "entity": { + "id": str(entity_id), + "entity_type": "person", + "name": "Samir", + "source_memory_ids": ["memory-1"], + "created_at": "2026-03-12T10:00:00+00:00", + } + } + assert captured["database_url"] == "postgresql://app" + assert captured["current_user_id"] == user_id + assert captured["user_id"] == user_id + assert captured["entity_id"] == entity_id + + +def test_get_entity_returns_not_found_for_inaccessible_entity(monkeypatch) -> None: + entity_id = uuid4() + + @contextmanager + def fake_user_connection(_database_url: str, _current_user_id): + yield object() + + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url="postgresql://app")) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr( + main_module, + "get_entity_record", + lambda *_args, **_kwargs: (_ for _ in ()).throw( + EntityNotFoundError(f"entity {entity_id} was not found") + ), + ) + + response = main_module.get_entity(entity_id=entity_id, user_id=uuid4()) + + assert response.status_code == 404 + assert json.loads(response.body) == { + "detail": f"entity {entity_id} was not found", + } diff --git a/tests/unit/test_memory.py b/tests/unit/test_memory.py new file mode 100644 index 0000000..1ce8211 --- /dev/null +++ b/tests/unit/test_memory.py @@ -0,0 +1,897 @@ +from __future__ import annotations + +from datetime import UTC, datetime, timedelta +from uuid import UUID, uuid4 + +import pytest + +from alicebot_api.contracts import MemoryCandidateInput +from alicebot_api.memory import ( + MemoryAdmissionValidationError, + MemoryReviewNotFoundError, + admit_memory_candidate, + create_memory_review_label_record, + get_memory_evaluation_summary, + get_memory_review_record, + list_memory_review_queue_records, + list_memory_review_label_records, + list_memory_review_records, + list_memory_revision_review_records, +) + + +class MemoryStoreStub: + def __init__(self) -> None: + self.base_time = datetime(2026, 3, 11, 12, 0, tzinfo=UTC) + self.events: dict[UUID, dict[str, object]] = {} + self.memory: dict[str, object] | None = None + self.revisions: list[dict[str, object]] = [] + + def list_events_by_ids(self, event_ids: list[UUID]) -> list[dict[str, object]]: + return [self.events[event_id] for event_id in event_ids if event_id in self.events] + + def get_memory_by_key(self, memory_key: str) -> dict[str, object] | None: + if self.memory is None or 
self.memory["memory_key"] != memory_key: + return None + return self.memory + + def create_memory( + self, + *, + memory_key: str, + value, + status: str, + source_event_ids: list[str], + ) -> dict[str, object]: + self.memory = { + "id": uuid4(), + "user_id": uuid4(), + "memory_key": memory_key, + "value": value, + "status": status, + "source_event_ids": source_event_ids, + "created_at": self.base_time, + "updated_at": self.base_time, + "deleted_at": None, + } + return self.memory + + def update_memory( + self, + *, + memory_id: UUID, + value, + status: str, + source_event_ids: list[str], + ) -> dict[str, object]: + assert self.memory is not None + assert self.memory["id"] == memory_id + updated_at = self.base_time + timedelta(minutes=len(self.revisions) + 1) + self.memory = { + **self.memory, + "value": value, + "status": status, + "source_event_ids": source_event_ids, + "updated_at": updated_at, + "deleted_at": updated_at if status == "deleted" else None, + } + return self.memory + + def append_memory_revision( + self, + *, + memory_id: UUID, + action: str, + memory_key: str, + previous_value, + new_value, + source_event_ids: list[str], + candidate: dict[str, object], + ) -> dict[str, object]: + revision = { + "id": uuid4(), + "user_id": self.memory["user_id"] if self.memory is not None else uuid4(), + "memory_id": memory_id, + "sequence_no": len(self.revisions) + 1, + "action": action, + "memory_key": memory_key, + "previous_value": previous_value, + "new_value": new_value, + "source_event_ids": source_event_ids, + "candidate": candidate, + "created_at": self.base_time + timedelta(minutes=len(self.revisions) + 1), + } + self.revisions.append(revision) + return revision + + +def seed_event(store: MemoryStoreStub) -> UUID: + event_id = uuid4() + store.events[event_id] = { + "id": event_id, + "sequence_no": 1, + "kind": "message.user", + "payload": {"text": "evidence"}, + "created_at": store.base_time, + } + return event_id + + +def test_admit_memory_candidate_defaults_to_noop_when_value_is_missing() -> None: + store = MemoryStoreStub() + event_id = seed_event(store) + + decision = admit_memory_candidate( + store, # type: ignore[arg-type] + user_id=uuid4(), + candidate=MemoryCandidateInput( + memory_key="user.preference.coffee", + value=None, + source_event_ids=(event_id,), + ), + ) + + assert decision.action == "NOOP" + assert decision.reason == "candidate_value_missing" + assert decision.memory is None + assert decision.revision is None + + +def test_admit_memory_candidate_rejects_missing_source_events() -> None: + store = MemoryStoreStub() + + with pytest.raises( + MemoryAdmissionValidationError, + match="source_event_ids must all reference existing events owned by the user", + ): + admit_memory_candidate( + store, # type: ignore[arg-type] + user_id=uuid4(), + candidate=MemoryCandidateInput( + memory_key="user.preference.tea", + value={"likes": True}, + source_event_ids=(uuid4(),), + ), + ) + + +def test_admit_memory_candidate_rejects_empty_source_event_ids() -> None: + store = MemoryStoreStub() + + with pytest.raises( + MemoryAdmissionValidationError, + match="source_event_ids must include at least one existing event owned by the user", + ): + admit_memory_candidate( + store, # type: ignore[arg-type] + user_id=uuid4(), + candidate=MemoryCandidateInput( + memory_key="user.preference.tea", + value={"likes": True}, + source_event_ids=(), + ), + ) + + +def test_admit_memory_candidate_adds_new_memory_with_first_revision() -> None: + store = MemoryStoreStub() + event_id = seed_event(store) + 
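    # Admitting a new memory_key backed by an existing event should yield an ADD decision plus an initial revision.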
+ decision = admit_memory_candidate( + store, # type: ignore[arg-type] + user_id=uuid4(), + candidate=MemoryCandidateInput( + memory_key="user.preference.coffee", + value={"likes": "oat milk"}, + source_event_ids=(event_id,), + ), + ) + + assert decision.action == "ADD" + assert decision.reason == "source_backed_add" + assert decision.memory is not None + assert decision.memory["memory_key"] == "user.preference.coffee" + assert decision.memory["status"] == "active" + assert decision.revision is not None + assert decision.revision["sequence_no"] == 1 + assert decision.revision["action"] == "ADD" + assert decision.revision["new_value"] == {"likes": "oat milk"} + + +def test_admit_memory_candidate_updates_existing_memory_and_appends_revision() -> None: + store = MemoryStoreStub() + event_id = seed_event(store) + created = store.create_memory( + memory_key="user.preference.coffee", + value={"likes": "black"}, + status="active", + source_event_ids=[str(event_id)], + ) + store.append_memory_revision( + memory_id=created["id"], + action="ADD", + memory_key="user.preference.coffee", + previous_value=None, + new_value={"likes": "black"}, + source_event_ids=[str(event_id)], + candidate={ + "memory_key": "user.preference.coffee", + "value": {"likes": "black"}, + "source_event_ids": [str(event_id)], + "delete_requested": False, + }, + ) + + decision = admit_memory_candidate( + store, # type: ignore[arg-type] + user_id=uuid4(), + candidate=MemoryCandidateInput( + memory_key="user.preference.coffee", + value={"likes": "oat milk"}, + source_event_ids=(event_id,), + ), + ) + + assert decision.action == "UPDATE" + assert decision.reason == "source_backed_update" + assert decision.memory is not None + assert decision.memory["value"] == {"likes": "oat milk"} + assert decision.revision is not None + assert decision.revision["sequence_no"] == 2 + assert decision.revision["previous_value"] == {"likes": "black"} + assert decision.revision["new_value"] == {"likes": "oat milk"} + + +def test_admit_memory_candidate_marks_memory_deleted_and_appends_revision() -> None: + store = MemoryStoreStub() + event_id = seed_event(store) + created = store.create_memory( + memory_key="user.preference.coffee", + value={"likes": "black"}, + status="active", + source_event_ids=[str(event_id)], + ) + + decision = admit_memory_candidate( + store, # type: ignore[arg-type] + user_id=uuid4(), + candidate=MemoryCandidateInput( + memory_key="user.preference.coffee", + value=None, + source_event_ids=(event_id,), + delete_requested=True, + ), + ) + + assert decision.action == "DELETE" + assert decision.reason == "source_backed_delete" + assert decision.memory is not None + assert UUID(decision.memory["id"]) == created["id"] + assert decision.memory["status"] == "deleted" + assert decision.revision is not None + assert decision.revision["sequence_no"] == 1 + assert decision.revision["action"] == "DELETE" + assert decision.revision["new_value"] is None + + +class MemoryReviewStoreStub: + def __init__(self) -> None: + self.memories: list[dict[str, object]] = [] + self.revisions: dict[UUID, list[dict[str, object]]] = {} + self.labels: dict[UUID, list[dict[str, object]]] = {} + + def count_memories(self, *, status: str | None = None) -> int: + return len(self._filtered_memories(status)) + + def list_review_memories(self, *, status: str | None = None, limit: int) -> list[dict[str, object]]: + return self._review_sorted_memories(self._filtered_memories(status))[:limit] + + def count_unlabeled_review_memories(self) -> int: + return len( + [memory 
for memory in self.memories if memory["status"] == "active" and not self.labels.get(memory["id"])] + ) + + def list_unlabeled_review_memories(self, *, limit: int) -> list[dict[str, object]]: + return self._review_sorted_memories( + [ + memory + for memory in self.memories + if memory["status"] == "active" and not self.labels.get(memory["id"]) + ] + )[:limit] + + def get_memory_optional(self, memory_id: UUID) -> dict[str, object] | None: + for memory in self.memories: + if memory["id"] == memory_id: + return memory + return None + + def count_memory_revisions(self, memory_id: UUID) -> int: + return len(self.revisions.get(memory_id, [])) + + def list_memory_revisions(self, memory_id: UUID, *, limit: int | None = None) -> list[dict[str, object]]: + revisions = self.revisions.get(memory_id, []) + if limit is None: + return revisions + return revisions[:limit] + + def create_memory_review_label( + self, + *, + memory_id: UUID, + label: str, + note: str | None, + ) -> dict[str, object]: + memory = self.get_memory_optional(memory_id) + created = { + "id": uuid4(), + "user_id": uuid4() if memory is None else memory["user_id"], + "memory_id": memory_id, + "label": label, + "note": note, + "created_at": datetime(2026, 3, 11, 13, len(self.labels.get(memory_id, [])), tzinfo=UTC), + } + self.labels.setdefault(memory_id, []).append(created) + return created + + def list_memory_review_labels(self, memory_id: UUID) -> list[dict[str, object]]: + return list(self.labels.get(memory_id, [])) + + def list_memory_review_label_counts(self, memory_id: UUID) -> list[dict[str, object]]: + counts: dict[str, int] = {} + for label in self.labels.get(memory_id, []): + label_name = label["label"] + counts[label_name] = counts.get(label_name, 0) + 1 + return [{"label": label, "count": count} for label, count in sorted(counts.items())] + + def count_labeled_memories(self) -> int: + return len([memory for memory in self.memories if self.labels.get(memory["id"])]) + + def count_unlabeled_memories(self) -> int: + return len([memory for memory in self.memories if not self.labels.get(memory["id"])]) + + def list_all_memory_review_label_counts(self) -> list[dict[str, object]]: + counts: dict[str, int] = {} + for labels in self.labels.values(): + for label in labels: + label_name = label["label"] + counts[label_name] = counts.get(label_name, 0) + 1 + return [{"label": label, "count": count} for label, count in sorted(counts.items())] + + def _filtered_memories(self, status: str | None) -> list[dict[str, object]]: + if status is None: + return list(self.memories) + return [memory for memory in self.memories if memory["status"] == status] + + def _review_sorted_memories(self, memories: list[dict[str, object]]) -> list[dict[str, object]]: + return sorted( + memories, + key=lambda memory: (memory["updated_at"], memory["created_at"], memory["id"]), + reverse=True, + ) + + +def test_list_memory_review_records_returns_summary_and_stable_shape() -> None: + store = MemoryReviewStoreStub() + base_time = datetime(2026, 3, 11, 12, 0, tzinfo=UTC) + deleted_time = base_time + timedelta(minutes=1) + active_time = base_time + timedelta(minutes=2) + deleted_id = uuid4() + active_id = uuid4() + store.memories = [ + { + "id": active_id, + "user_id": uuid4(), + "memory_key": "user.preference.coffee", + "value": {"likes": "oat milk"}, + "status": "active", + "source_event_ids": ["event-2"], + "created_at": base_time, + "updated_at": active_time, + "deleted_at": None, + }, + { + "id": deleted_id, + "user_id": uuid4(), + "memory_key": 
"user.preference.tea", + "value": {"likes": "green"}, + "status": "deleted", + "source_event_ids": ["event-1"], + "created_at": base_time, + "updated_at": deleted_time, + "deleted_at": deleted_time, + }, + ] + + payload = list_memory_review_records( + store, # type: ignore[arg-type] + user_id=uuid4(), + status="all", + limit=1, + ) + + assert payload == { + "items": [ + { + "id": str(active_id), + "memory_key": "user.preference.coffee", + "value": {"likes": "oat milk"}, + "status": "active", + "source_event_ids": ["event-2"], + "created_at": "2026-03-11T12:00:00+00:00", + "updated_at": "2026-03-11T12:02:00+00:00", + "deleted_at": None, + } + ], + "summary": { + "status": "all", + "limit": 1, + "returned_count": 1, + "total_count": 2, + "has_more": True, + "order": ["updated_at_desc", "created_at_desc", "id_desc"], + }, + } + + +def test_get_memory_review_record_raises_not_found_for_inaccessible_memory() -> None: + store = MemoryReviewStoreStub() + + with pytest.raises(MemoryReviewNotFoundError, match="was not found"): + get_memory_review_record( + store, # type: ignore[arg-type] + user_id=uuid4(), + memory_id=uuid4(), + ) + + +def test_list_memory_review_queue_records_returns_only_active_unlabeled_memories_in_stable_order() -> None: + store = MemoryReviewStoreStub() + base_time = datetime(2026, 3, 11, 12, 0, tzinfo=UTC) + deleted_id = uuid4() + labeled_id = uuid4() + newest_unlabeled_id = uuid4() + older_unlabeled_id = uuid4() + store.memories = [ + { + "id": newest_unlabeled_id, + "user_id": uuid4(), + "memory_key": "user.preference.coffee", + "value": {"likes": "oat milk"}, + "status": "active", + "source_event_ids": ["event-4"], + "created_at": base_time + timedelta(minutes=3), + "updated_at": base_time + timedelta(minutes=6), + "deleted_at": None, + }, + { + "id": labeled_id, + "user_id": uuid4(), + "memory_key": "user.preference.snack", + "value": {"likes": "chips"}, + "status": "active", + "source_event_ids": ["event-3"], + "created_at": base_time + timedelta(minutes=2), + "updated_at": base_time + timedelta(minutes=5), + "deleted_at": None, + }, + { + "id": older_unlabeled_id, + "user_id": uuid4(), + "memory_key": "user.preference.book", + "value": {"genre": "science fiction"}, + "status": "active", + "source_event_ids": ["event-2"], + "created_at": base_time + timedelta(minutes=1), + "updated_at": base_time + timedelta(minutes=4), + "deleted_at": None, + }, + { + "id": deleted_id, + "user_id": uuid4(), + "memory_key": "user.preference.tea", + "value": {"likes": "green"}, + "status": "deleted", + "source_event_ids": ["event-1"], + "created_at": base_time, + "updated_at": base_time + timedelta(minutes=7), + "deleted_at": base_time + timedelta(minutes=7), + }, + ] + store.labels[labeled_id] = [ + { + "id": uuid4(), + "user_id": uuid4(), + "memory_id": labeled_id, + "label": "correct", + "note": "Already reviewed.", + "created_at": base_time + timedelta(minutes=8), + } + ] + + payload = list_memory_review_queue_records( + store, # type: ignore[arg-type] + user_id=uuid4(), + limit=2, + ) + + assert payload == { + "items": [ + { + "id": str(newest_unlabeled_id), + "memory_key": "user.preference.coffee", + "value": {"likes": "oat milk"}, + "status": "active", + "source_event_ids": ["event-4"], + "created_at": "2026-03-11T12:03:00+00:00", + "updated_at": "2026-03-11T12:06:00+00:00", + }, + { + "id": str(older_unlabeled_id), + "memory_key": "user.preference.book", + "value": {"genre": "science fiction"}, + "status": "active", + "source_event_ids": ["event-2"], + "created_at": 
"2026-03-11T12:01:00+00:00", + "updated_at": "2026-03-11T12:04:00+00:00", + }, + ], + "summary": { + "memory_status": "active", + "review_state": "unlabeled", + "limit": 2, + "returned_count": 2, + "total_count": 2, + "has_more": False, + "order": ["updated_at_desc", "created_at_desc", "id_desc"], + }, + } + + +def test_list_memory_revision_review_records_returns_deterministic_revision_order() -> None: + store = MemoryReviewStoreStub() + memory_id = uuid4() + base_time = datetime(2026, 3, 11, 12, 0, tzinfo=UTC) + store.memories = [ + { + "id": memory_id, + "user_id": uuid4(), + "memory_key": "user.preference.coffee", + "value": {"likes": "oat milk"}, + "status": "active", + "source_event_ids": ["event-2"], + "created_at": base_time, + "updated_at": base_time + timedelta(minutes=2), + "deleted_at": None, + } + ] + store.revisions[memory_id] = [ + { + "id": uuid4(), + "user_id": uuid4(), + "memory_id": memory_id, + "sequence_no": 1, + "action": "ADD", + "memory_key": "user.preference.coffee", + "previous_value": None, + "new_value": {"likes": "black"}, + "source_event_ids": ["event-1"], + "candidate": {"memory_key": "user.preference.coffee"}, + "created_at": base_time, + }, + { + "id": uuid4(), + "user_id": uuid4(), + "memory_id": memory_id, + "sequence_no": 2, + "action": "UPDATE", + "memory_key": "user.preference.coffee", + "previous_value": {"likes": "black"}, + "new_value": {"likes": "oat milk"}, + "source_event_ids": ["event-2"], + "candidate": {"memory_key": "user.preference.coffee"}, + "created_at": base_time + timedelta(minutes=1), + }, + ] + + payload = list_memory_revision_review_records( + store, # type: ignore[arg-type] + user_id=uuid4(), + memory_id=memory_id, + limit=10, + ) + + assert payload == { + "items": [ + { + "id": str(store.revisions[memory_id][0]["id"]), + "memory_id": str(memory_id), + "sequence_no": 1, + "action": "ADD", + "memory_key": "user.preference.coffee", + "previous_value": None, + "new_value": {"likes": "black"}, + "source_event_ids": ["event-1"], + "created_at": "2026-03-11T12:00:00+00:00", + }, + { + "id": str(store.revisions[memory_id][1]["id"]), + "memory_id": str(memory_id), + "sequence_no": 2, + "action": "UPDATE", + "memory_key": "user.preference.coffee", + "previous_value": {"likes": "black"}, + "new_value": {"likes": "oat milk"}, + "source_event_ids": ["event-2"], + "created_at": "2026-03-11T12:01:00+00:00", + }, + ], + "summary": { + "memory_id": str(memory_id), + "limit": 10, + "returned_count": 2, + "total_count": 2, + "has_more": False, + "order": ["sequence_no_asc"], + }, + } + + +def test_create_memory_review_label_record_returns_created_label_and_summary_counts() -> None: + store = MemoryReviewStoreStub() + memory_id = uuid4() + reviewer_user_id = uuid4() + base_time = datetime(2026, 3, 11, 12, 0, tzinfo=UTC) + store.memories = [ + { + "id": memory_id, + "user_id": reviewer_user_id, + "memory_key": "user.preference.coffee", + "value": {"likes": "oat milk"}, + "status": "active", + "source_event_ids": ["event-2"], + "created_at": base_time, + "updated_at": base_time, + "deleted_at": None, + } + ] + store.labels[memory_id] = [ + { + "id": uuid4(), + "user_id": reviewer_user_id, + "memory_id": memory_id, + "label": "correct", + "note": "Matches the latest cited event.", + "created_at": datetime(2026, 3, 11, 12, 30, tzinfo=UTC), + } + ] + + payload = create_memory_review_label_record( + store, # type: ignore[arg-type] + user_id=reviewer_user_id, + memory_id=memory_id, + label="outdated", + note="Superseded by the newer milk preference.", + ) + + 
assert payload == { + "label": { + "id": payload["label"]["id"], + "memory_id": str(memory_id), + "reviewer_user_id": payload["label"]["reviewer_user_id"], + "label": "outdated", + "note": "Superseded by the newer milk preference.", + "created_at": "2026-03-11T13:01:00+00:00", + }, + "summary": { + "memory_id": str(memory_id), + "total_count": 2, + "counts_by_label": { + "correct": 1, + "incorrect": 0, + "outdated": 1, + "insufficient_evidence": 0, + }, + "order": ["created_at_asc", "id_asc"], + }, + } + + +def test_create_memory_review_label_record_raises_not_found_for_inaccessible_memory() -> None: + store = MemoryReviewStoreStub() + + with pytest.raises(MemoryReviewNotFoundError, match="was not found"): + create_memory_review_label_record( + store, # type: ignore[arg-type] + user_id=uuid4(), + memory_id=uuid4(), + label="correct", + note=None, + ) + + +def test_list_memory_review_label_records_returns_deterministic_order_and_zero_filled_counts() -> None: + store = MemoryReviewStoreStub() + memory_id = uuid4() + reviewer_user_id = uuid4() + base_time = datetime(2026, 3, 11, 12, 0, tzinfo=UTC) + store.memories = [ + { + "id": memory_id, + "user_id": reviewer_user_id, + "memory_key": "user.preference.coffee", + "value": {"likes": "oat milk"}, + "status": "active", + "source_event_ids": ["event-2"], + "created_at": base_time, + "updated_at": base_time, + "deleted_at": None, + } + ] + store.labels[memory_id] = [ + { + "id": uuid4(), + "user_id": reviewer_user_id, + "memory_id": memory_id, + "label": "incorrect", + "note": "The source event only mentions tea.", + "created_at": datetime(2026, 3, 11, 12, 15, tzinfo=UTC), + }, + { + "id": uuid4(), + "user_id": reviewer_user_id, + "memory_id": memory_id, + "label": "insufficient_evidence", + "note": None, + "created_at": datetime(2026, 3, 11, 12, 16, tzinfo=UTC), + }, + ] + + payload = list_memory_review_label_records( + store, # type: ignore[arg-type] + user_id=reviewer_user_id, + memory_id=memory_id, + ) + + assert payload == { + "items": [ + { + "id": str(store.labels[memory_id][0]["id"]), + "memory_id": str(memory_id), + "reviewer_user_id": str(reviewer_user_id), + "label": "incorrect", + "note": "The source event only mentions tea.", + "created_at": "2026-03-11T12:15:00+00:00", + }, + { + "id": str(store.labels[memory_id][1]["id"]), + "memory_id": str(memory_id), + "reviewer_user_id": str(reviewer_user_id), + "label": "insufficient_evidence", + "note": None, + "created_at": "2026-03-11T12:16:00+00:00", + }, + ], + "summary": { + "memory_id": str(memory_id), + "total_count": 2, + "counts_by_label": { + "correct": 0, + "incorrect": 1, + "outdated": 0, + "insufficient_evidence": 1, + }, + "order": ["created_at_asc", "id_asc"], + }, + } + + +def test_get_memory_evaluation_summary_returns_explicit_consistent_counts() -> None: + store = MemoryReviewStoreStub() + base_time = datetime(2026, 3, 11, 12, 0, tzinfo=UTC) + active_labeled_id = uuid4() + active_unlabeled_id = uuid4() + deleted_labeled_id = uuid4() + deleted_unlabeled_id = uuid4() + store.memories = [ + { + "id": active_labeled_id, + "user_id": uuid4(), + "memory_key": "user.preference.coffee", + "value": {"likes": "oat milk"}, + "status": "active", + "source_event_ids": ["event-1"], + "created_at": base_time, + "updated_at": base_time, + "deleted_at": None, + }, + { + "id": active_unlabeled_id, + "user_id": uuid4(), + "memory_key": "user.preference.book", + "value": {"genre": "science fiction"}, + "status": "active", + "source_event_ids": ["event-2"], + "created_at": base_time + 
timedelta(minutes=1), + "updated_at": base_time + timedelta(minutes=1), + "deleted_at": None, + }, + { + "id": deleted_labeled_id, + "user_id": uuid4(), + "memory_key": "user.preference.snack", + "value": {"likes": "chips"}, + "status": "deleted", + "source_event_ids": ["event-3"], + "created_at": base_time + timedelta(minutes=2), + "updated_at": base_time + timedelta(minutes=2), + "deleted_at": base_time + timedelta(minutes=2), + }, + { + "id": deleted_unlabeled_id, + "user_id": uuid4(), + "memory_key": "user.preference.tea", + "value": {"likes": "green"}, + "status": "deleted", + "source_event_ids": ["event-4"], + "created_at": base_time + timedelta(minutes=3), + "updated_at": base_time + timedelta(minutes=3), + "deleted_at": base_time + timedelta(minutes=3), + }, + ] + store.labels[active_labeled_id] = [ + { + "id": uuid4(), + "user_id": uuid4(), + "memory_id": active_labeled_id, + "label": "correct", + "note": "Looks right.", + "created_at": base_time + timedelta(minutes=4), + }, + { + "id": uuid4(), + "user_id": uuid4(), + "memory_id": active_labeled_id, + "label": "insufficient_evidence", + "note": "Needs another source.", + "created_at": base_time + timedelta(minutes=5), + }, + ] + store.labels[deleted_labeled_id] = [ + { + "id": uuid4(), + "user_id": uuid4(), + "memory_id": deleted_labeled_id, + "label": "outdated", + "note": None, + "created_at": base_time + timedelta(minutes=6), + } + ] + + payload = get_memory_evaluation_summary( + store, # type: ignore[arg-type] + user_id=uuid4(), + ) + + assert payload == { + "summary": { + "total_memory_count": 4, + "active_memory_count": 2, + "deleted_memory_count": 2, + "labeled_memory_count": 2, + "unlabeled_memory_count": 2, + "total_label_row_count": 3, + "label_row_counts_by_value": { + "correct": 1, + "incorrect": 0, + "outdated": 1, + "insufficient_evidence": 1, + }, + "label_value_order": [ + "correct", + "incorrect", + "outdated", + "insufficient_evidence", + ], + } + } diff --git a/tests/unit/test_memory_store.py b/tests/unit/test_memory_store.py new file mode 100644 index 0000000..9b09755 --- /dev/null +++ b/tests/unit/test_memory_store.py @@ -0,0 +1,357 @@ +from __future__ import annotations + +from typing import Any +from uuid import uuid4 + +from psycopg.types.json import Jsonb +import pytest + +from alicebot_api.store import ContinuityStore, ContinuityStoreInvariantError + + +class RecordingCursor: + def __init__( + self, + fetchone_results: list[dict[str, Any]], + fetchall_results: list[list[dict[str, Any]]] | None = None, + ) -> None: + self.executed: list[tuple[str, tuple[object, ...] | None]] = [] + self.fetchone_results = list(fetchone_results) + self.fetchall_results = list(fetchall_results or []) + + def __enter__(self) -> "RecordingCursor": + return self + + def __exit__(self, exc_type, exc, tb) -> None: + return None + + def execute(self, query: str, params: tuple[object, ...] 
| None = None) -> None: + self.executed.append((query, params)) + + def fetchone(self) -> dict[str, Any] | None: + if not self.fetchone_results: + return None + return self.fetchone_results.pop(0) + + def fetchall(self) -> list[dict[str, Any]]: + if not self.fetchall_results: + return [] + return self.fetchall_results.pop(0) + + +class RecordingConnection: + def __init__(self, cursor: RecordingCursor) -> None: + self.cursor_instance = cursor + + def cursor(self) -> RecordingCursor: + return self.cursor_instance + + +def test_memory_methods_use_expected_queries_and_payload_serialization() -> None: + memory_id = uuid4() + event_id = uuid4() + cursor = RecordingCursor( + fetchone_results=[ + { + "id": memory_id, + "user_id": uuid4(), + "memory_key": "user.preference.coffee", + "value": {"likes": "black"}, + "status": "active", + "source_event_ids": [str(event_id)], + }, + { + "id": memory_id, + "user_id": uuid4(), + "memory_key": "user.preference.coffee", + "value": {"likes": "oat milk"}, + "status": "active", + "source_event_ids": [str(event_id)], + }, + { + "id": uuid4(), + "memory_id": memory_id, + "sequence_no": 1, + "action": "ADD", + "memory_key": "user.preference.coffee", + "previous_value": None, + "new_value": {"likes": "black"}, + "source_event_ids": [str(event_id)], + "candidate": {"memory_key": "user.preference.coffee"}, + }, + ], + fetchall_results=[ + [{"id": event_id, "sequence_no": 1}], + [{"sequence_no": 1, "action": "ADD"}], + ], + ) + store = ContinuityStore(RecordingConnection(cursor)) + + created = store.create_memory( + memory_key="user.preference.coffee", + value={"likes": "black"}, + status="active", + source_event_ids=[str(event_id)], + ) + updated = store.update_memory( + memory_id=memory_id, + value={"likes": "oat milk"}, + status="active", + source_event_ids=[str(event_id)], + ) + revision = store.append_memory_revision( + memory_id=memory_id, + action="ADD", + memory_key="user.preference.coffee", + previous_value=None, + new_value={"likes": "black"}, + source_event_ids=[str(event_id)], + candidate={"memory_key": "user.preference.coffee"}, + ) + listed_events = store.list_events_by_ids([event_id]) + listed_revisions = store.list_memory_revisions(memory_id) + listed_context_memories = store.list_context_memories() + + assert created["id"] == memory_id + assert updated["value"] == {"likes": "oat milk"} + assert revision["sequence_no"] == 1 + assert listed_events == [{"id": event_id, "sequence_no": 1}] + assert listed_revisions == [{"sequence_no": 1, "action": "ADD"}] + assert listed_context_memories == [] + + create_query, create_params = cursor.executed[0] + assert "INSERT INTO memories" in create_query + assert "clock_timestamp()" in create_query + assert create_params is not None + assert create_params[0] == "user.preference.coffee" + assert isinstance(create_params[1], Jsonb) + assert create_params[1].obj == {"likes": "black"} + assert create_params[2] == "active" + assert isinstance(create_params[3], Jsonb) + assert create_params[3].obj == [str(event_id)] + + update_query, update_params = cursor.executed[1] + assert "UPDATE memories" in update_query + assert "updated_at = clock_timestamp()" in update_query + assert "THEN clock_timestamp()" in update_query + assert update_params is not None + assert isinstance(update_params[0], Jsonb) + assert update_params[0].obj == {"likes": "oat milk"} + assert update_params[1] == "active" + assert isinstance(update_params[2], Jsonb) + assert update_params[2].obj == [str(event_id)] + assert update_params[3] == "active" + 
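    # The final UPDATE parameter is the target memory id; the revision insert below is guarded by a per-memory advisory lock. +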
assert update_params[4] == memory_id + + assert cursor.executed[2] == ( + "SELECT pg_advisory_xact_lock(hashtextextended(%s::text, 1))", + (str(memory_id),), + ) + append_revision_query, append_revision_params = cursor.executed[3] + assert "INSERT INTO memory_revisions" in append_revision_query + assert append_revision_params is not None + assert append_revision_params[:4] == ( + memory_id, + memory_id, + "ADD", + "user.preference.coffee", + ) + assert isinstance(append_revision_params[4], Jsonb) + assert append_revision_params[4].obj is None + assert isinstance(append_revision_params[5], Jsonb) + assert append_revision_params[5].obj == {"likes": "black"} + assert isinstance(append_revision_params[6], Jsonb) + assert append_revision_params[6].obj == [str(event_id)] + assert isinstance(append_revision_params[7], Jsonb) + assert append_revision_params[7].obj == {"memory_key": "user.preference.coffee"} + assert cursor.executed[6] == ( + """ + SELECT id, user_id, memory_key, value, status, source_event_ids, created_at, updated_at, deleted_at + FROM memories + ORDER BY updated_at ASC, created_at ASC, id ASC + """, + None, + ) + + +def test_get_memory_by_key_returns_none_when_row_is_missing() -> None: + cursor = RecordingCursor(fetchone_results=[]) + store = ContinuityStore(RecordingConnection(cursor)) + + assert store.get_memory_by_key("user.preference.coffee") is None + + +def test_append_memory_revision_raises_clear_error_when_returning_row_is_missing() -> None: + cursor = RecordingCursor(fetchone_results=[]) + store = ContinuityStore(RecordingConnection(cursor)) + + with pytest.raises( + ContinuityStoreInvariantError, + match="append_memory_revision did not return a row", + ): + store.append_memory_revision( + memory_id=uuid4(), + action="ADD", + memory_key="user.preference.coffee", + previous_value=None, + new_value={"likes": "black"}, + source_event_ids=["event-1"], + candidate={"memory_key": "user.preference.coffee"}, + ) + + +def test_memory_review_read_methods_use_explicit_order_filter_and_limit() -> None: + memory_id = uuid4() + cursor = RecordingCursor( + fetchone_results=[ + { + "id": memory_id, + "user_id": uuid4(), + "memory_key": "user.preference.coffee", + "value": {"likes": "black"}, + "status": "active", + "source_event_ids": ["event-1"], + }, + {"count": 2}, + {"count": 3}, + ], + fetchall_results=[ + [{"id": memory_id, "memory_key": "user.preference.coffee"}], + [{"sequence_no": 1, "action": "ADD"}], + ], + ) + store = ContinuityStore(RecordingConnection(cursor)) + + memory = store.get_memory_optional(memory_id) + memory_count = store.count_memories(status="active") + listed_memories = store.list_review_memories(status="active", limit=5) + revision_count = store.count_memory_revisions(memory_id) + listed_revisions = store.list_memory_revisions(memory_id, limit=2) + + assert memory is not None + assert memory["id"] == memory_id + assert memory_count == 2 + assert listed_memories == [{"id": memory_id, "memory_key": "user.preference.coffee"}] + assert revision_count == 3 + assert listed_revisions == [{"sequence_no": 1, "action": "ADD"}] + assert cursor.executed == [ + ( + """ + SELECT id, user_id, memory_key, value, status, source_event_ids, created_at, updated_at, deleted_at + FROM memories + WHERE id = %s + """, + (memory_id,), + ), + ( + """ + SELECT COUNT(*) AS count + FROM memories + WHERE status = %s + """, + ("active",), + ), + ( + """ + SELECT id, user_id, memory_key, value, status, source_event_ids, created_at, updated_at, deleted_at + FROM memories + WHERE status = %s 
+ ORDER BY updated_at DESC, created_at DESC, id DESC + LIMIT %s + """, + ("active", 5), + ), + ( + """ + SELECT COUNT(*) AS count + FROM memory_revisions + WHERE memory_id = %s + """, + (memory_id,), + ), + ( + """ + SELECT id, user_id, memory_id, sequence_no, action, memory_key, previous_value, new_value, source_event_ids, candidate, created_at + FROM memory_revisions + WHERE memory_id = %s + ORDER BY sequence_no ASC + LIMIT %s + """, + (memory_id, 2), + ), + ] + + +def test_memory_review_label_methods_use_append_only_queries_and_deterministic_order() -> None: + memory_id = uuid4() + reviewer_user_id = uuid4() + cursor = RecordingCursor( + fetchone_results=[ + { + "id": uuid4(), + "user_id": reviewer_user_id, + "memory_id": memory_id, + "label": "correct", + "note": "Supported by the latest event.", + "created_at": "2026-03-12T09:00:00+00:00", + } + ], + fetchall_results=[ + [ + { + "id": uuid4(), + "user_id": reviewer_user_id, + "memory_id": memory_id, + "label": "correct", + "note": "Supported by the latest event.", + "created_at": "2026-03-12T09:00:00+00:00", + } + ], + [ + {"label": "correct", "count": 1}, + {"label": "outdated", "count": 2}, + ], + ], + ) + store = ContinuityStore(RecordingConnection(cursor)) + + created = store.create_memory_review_label( + memory_id=memory_id, + label="correct", + note="Supported by the latest event.", + ) + listed = store.list_memory_review_labels(memory_id) + counts = store.list_memory_review_label_counts(memory_id) + + assert created["memory_id"] == memory_id + assert listed[0]["label"] == "correct" + assert counts == [{"label": "correct", "count": 1}, {"label": "outdated", "count": 2}] + assert cursor.executed == [ + ( + """ + INSERT INTO memory_review_labels (user_id, memory_id, label, note) + VALUES (app.current_user_id(), %s, %s, %s) + RETURNING id, user_id, memory_id, label, note, created_at + """, + (memory_id, "correct", "Supported by the latest event."), + ), + ( + """ + SELECT id, user_id, memory_id, label, note, created_at + FROM memory_review_labels + WHERE memory_id = %s + ORDER BY created_at ASC, id ASC + """, + (memory_id,), + ), + ( + """ + SELECT label, COUNT(*) AS count + FROM memory_review_labels + WHERE memory_id = %s + GROUP BY label + ORDER BY label ASC + """, + (memory_id,), + ), + ] diff --git a/tests/unit/test_ops_assets.py b/tests/unit/test_ops_assets.py new file mode 100644 index 0000000..ddabcbb --- /dev/null +++ b/tests/unit/test_ops_assets.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +from pathlib import Path + + +REPO_ROOT = Path(__file__).resolve().parents[2] + + +def test_dev_up_waits_for_postgres_and_role_bootstrap() -> None: + script = (REPO_ROOT / "scripts" / "dev_up.sh").read_text() + + assert "Timed out waiting for Postgres readiness and alicebot_app bootstrap" in script + assert "SELECT 1 FROM pg_roles WHERE rolname = %s" in script + + +def test_runtime_role_init_only_grants_connect_on_alicebot_database() -> None: + init_sql = (REPO_ROOT / "infra" / "postgres" / "init" / "001_roles.sql").read_text() + + assert "GRANT CONNECT ON DATABASE alicebot TO alicebot_app;" in init_sql + assert "GRANT CONNECT ON DATABASE postgres TO alicebot_app;" not in init_sql diff --git a/tests/unit/test_policy.py b/tests/unit/test_policy.py new file mode 100644 index 0000000..9f5c40e --- /dev/null +++ b/tests/unit/test_policy.py @@ -0,0 +1,447 @@ +from __future__ import annotations + +from datetime import UTC, datetime, timedelta +from uuid import UUID, uuid4 + +import pytest + +from alicebot_api.contracts import 
ConsentUpsertInput, PolicyCreateInput, PolicyEvaluationRequestInput +from alicebot_api.policy import ( + PolicyEvaluationValidationError, + PolicyNotFoundError, + create_policy_record, + evaluate_policy_request, + get_policy_record, + list_consent_records, + list_policy_records, + upsert_consent_record, +) + + +class PolicyStoreStub: + def __init__(self) -> None: + self.base_time = datetime(2026, 3, 12, 9, 0, tzinfo=UTC) + self.user_id = uuid4() + self.thread_id = uuid4() + self.consents: dict[str, dict[str, object]] = {} + self.policies: list[dict[str, object]] = [] + self.traces: list[dict[str, object]] = [] + self.trace_events: list[dict[str, object]] = [] + + def create_consent(self, *, consent_key: str, status: str, metadata: dict[str, object]) -> dict[str, object]: + consent = { + "id": uuid4(), + "user_id": self.user_id, + "consent_key": consent_key, + "status": status, + "metadata": metadata, + "created_at": self.base_time + timedelta(minutes=len(self.consents)), + "updated_at": self.base_time + timedelta(minutes=len(self.consents)), + } + self.consents[consent_key] = consent + return consent + + def get_consent_by_key_optional(self, consent_key: str) -> dict[str, object] | None: + return self.consents.get(consent_key) + + def list_consents(self) -> list[dict[str, object]]: + return sorted( + self.consents.values(), + key=lambda consent: (consent["consent_key"], consent["created_at"], consent["id"]), + ) + + def update_consent(self, *, consent_id: UUID, status: str, metadata: dict[str, object]) -> dict[str, object]: + for consent in self.consents.values(): + if consent["id"] != consent_id: + continue + consent["status"] = status + consent["metadata"] = metadata + consent["updated_at"] = consent["updated_at"] + timedelta(minutes=5) + return consent + raise AssertionError("missing consent") + + def create_policy( + self, + *, + name: str, + action: str, + scope: str, + effect: str, + priority: int, + active: bool, + conditions: dict[str, object], + required_consents: list[str], + ) -> dict[str, object]: + policy = { + "id": uuid4(), + "user_id": self.user_id, + "name": name, + "action": action, + "scope": scope, + "effect": effect, + "priority": priority, + "active": active, + "conditions": conditions, + "required_consents": required_consents, + "created_at": self.base_time + timedelta(minutes=len(self.policies)), + "updated_at": self.base_time + timedelta(minutes=len(self.policies)), + } + self.policies.append(policy) + return policy + + def list_policies(self) -> list[dict[str, object]]: + return sorted( + self.policies, + key=lambda policy: (policy["priority"], policy["created_at"], policy["id"]), + ) + + def get_policy_optional(self, policy_id: UUID) -> dict[str, object] | None: + return next((policy for policy in self.policies if policy["id"] == policy_id), None) + + def list_active_policies(self) -> list[dict[str, object]]: + return [policy for policy in self.list_policies() if policy["active"] is True] + + def get_thread_optional(self, thread_id: UUID) -> dict[str, object] | None: + if thread_id != self.thread_id: + return None + return { + "id": self.thread_id, + "user_id": self.user_id, + "title": "Policy thread", + "created_at": self.base_time, + "updated_at": self.base_time, + } + + def create_trace( + self, + *, + user_id: UUID, + thread_id: UUID, + kind: str, + compiler_version: str, + status: str, + limits: dict[str, object], + ) -> dict[str, object]: + trace = { + "id": uuid4(), + "user_id": user_id, + "thread_id": thread_id, + "kind": kind, + "compiler_version": 
compiler_version, + "status": status, + "limits": limits, + "created_at": self.base_time, + } + self.traces.append(trace) + return trace + + def append_trace_event( + self, + *, + trace_id: UUID, + sequence_no: int, + kind: str, + payload: dict[str, object], + ) -> dict[str, object]: + event = { + "id": uuid4(), + "trace_id": trace_id, + "sequence_no": sequence_no, + "kind": kind, + "payload": payload, + "created_at": self.base_time, + } + self.trace_events.append(event) + return event + + +def test_upsert_consent_record_creates_and_updates_in_place() -> None: + store = PolicyStoreStub() + + created = upsert_consent_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + consent=ConsentUpsertInput( + consent_key="email_marketing", + status="granted", + metadata={"source": "settings"}, + ), + ) + updated = upsert_consent_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + consent=ConsentUpsertInput( + consent_key="email_marketing", + status="revoked", + metadata={"source": "banner"}, + ), + ) + + assert created["write_mode"] == "created" + assert updated["write_mode"] == "updated" + assert updated["consent"]["id"] == created["consent"]["id"] + assert updated["consent"]["status"] == "revoked" + assert updated["consent"]["metadata"] == {"source": "banner"} + + +def test_list_consent_records_returns_deterministic_shape() -> None: + store = PolicyStoreStub() + zeta = store.create_consent(consent_key="zeta", status="granted", metadata={}) + alpha = store.create_consent(consent_key="alpha", status="revoked", metadata={"reason": "user"}) + + payload = list_consent_records( + store, # type: ignore[arg-type] + user_id=store.user_id, + ) + + assert payload == { + "items": [ + { + "id": str(alpha["id"]), + "consent_key": "alpha", + "status": "revoked", + "metadata": {"reason": "user"}, + "created_at": alpha["created_at"].isoformat(), + "updated_at": alpha["updated_at"].isoformat(), + }, + { + "id": str(zeta["id"]), + "consent_key": "zeta", + "status": "granted", + "metadata": {}, + "created_at": zeta["created_at"].isoformat(), + "updated_at": zeta["updated_at"].isoformat(), + }, + ], + "summary": { + "total_count": 2, + "order": ["consent_key_asc", "created_at_asc", "id_asc"], + }, + } + + +def test_create_and_list_policy_records_preserve_priority_order_and_shape() -> None: + store = PolicyStoreStub() + first = create_policy_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + policy=PolicyCreateInput( + name="Require approval for exports", + action="memory.export", + scope="profile", + effect="require_approval", + priority=20, + active=True, + conditions={"channel": "email"}, + required_consents=("email_marketing", "email_marketing"), + ), + ) + second_policy = store.create_policy( + name="Allow low risk read", + action="memory.read", + scope="profile", + effect="allow", + priority=10, + active=True, + conditions={}, + required_consents=[], + ) + + list_payload = list_policy_records( + store, # type: ignore[arg-type] + user_id=store.user_id, + ) + detail_payload = get_policy_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + policy_id=UUID(first["policy"]["id"]), + ) + + assert first["policy"]["required_consents"] == ["email_marketing"] + assert [item["id"] for item in list_payload["items"]] == [ + str(second_policy["id"]), + first["policy"]["id"], + ] + assert list_payload["summary"] == { + "total_count": 2, + "order": ["priority_asc", "created_at_asc", "id_asc"], + } + assert detail_payload == {"policy": first["policy"]} + + +def 
test_get_policy_record_raises_not_found_for_inaccessible_policy() -> None: + with pytest.raises(PolicyNotFoundError, match="policy .* was not found"): + get_policy_record( + PolicyStoreStub(), # type: ignore[arg-type] + user_id=uuid4(), + policy_id=uuid4(), + ) + + +def test_evaluate_policy_request_uses_first_matching_policy_and_emits_trace() -> None: + store = PolicyStoreStub() + store.create_consent(consent_key="email_marketing", status="granted", metadata={"source": "settings"}) + higher_priority_match = store.create_policy( + name="Allow email export", + action="memory.export", + scope="profile", + effect="allow", + priority=10, + active=True, + conditions={"channel": "email"}, + required_consents=["email_marketing"], + ) + store.create_policy( + name="Deny fallback export", + action="memory.export", + scope="profile", + effect="deny", + priority=20, + active=True, + conditions={"channel": "email"}, + required_consents=[], + ) + + payload = evaluate_policy_request( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=PolicyEvaluationRequestInput( + thread_id=store.thread_id, + action="memory.export", + scope="profile", + attributes={"channel": "email"}, + ), + ) + + assert payload["decision"] == "allow" + assert payload["matched_policy"]["id"] == str(higher_priority_match["id"]) + assert [reason["code"] for reason in payload["reasons"]] == [ + "matched_policy", + "policy_effect_allow", + ] + assert payload["evaluation"] == { + "action": "memory.export", + "scope": "profile", + "evaluated_policy_count": 2, + "matched_policy_id": str(higher_priority_match["id"]), + "order": ["priority_asc", "created_at_asc", "id_asc"], + } + assert payload["trace"]["trace_event_count"] == 3 + assert [event["kind"] for event in store.trace_events] == [ + "policy.evaluate.request", + "policy.evaluate.order", + "policy.evaluate.decision", + ] + + +def test_evaluate_policy_request_denies_when_required_consent_is_missing() -> None: + store = PolicyStoreStub() + matched_policy = store.create_policy( + name="Allow export with consent", + action="memory.export", + scope="profile", + effect="allow", + priority=10, + active=True, + conditions={}, + required_consents=["email_marketing"], + ) + + payload = evaluate_policy_request( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=PolicyEvaluationRequestInput( + thread_id=store.thread_id, + action="memory.export", + scope="profile", + attributes={}, + ), + ) + + assert payload["decision"] == "deny" + assert payload["matched_policy"]["id"] == str(matched_policy["id"]) + assert [reason["code"] for reason in payload["reasons"]] == [ + "matched_policy", + "consent_missing", + ] + + +def test_evaluate_policy_request_denies_when_required_consent_is_revoked() -> None: + store = PolicyStoreStub() + matched_policy = store.create_policy( + name="Allow export with consent", + action="memory.export", + scope="profile", + effect="allow", + priority=10, + active=True, + conditions={}, + required_consents=["email_marketing"], + ) + store.create_consent( + consent_key="email_marketing", + status="revoked", + metadata={"source": "settings"}, + ) + + payload = evaluate_policy_request( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=PolicyEvaluationRequestInput( + thread_id=store.thread_id, + action="memory.export", + scope="profile", + attributes={}, + ), + ) + + assert payload["decision"] == "deny" + assert payload["matched_policy"]["id"] == str(matched_policy["id"]) + assert [reason["code"] for reason in payload["reasons"]] == 
[ + "matched_policy", + "consent_revoked", + ] + + +def test_evaluate_policy_request_returns_require_approval_and_validates_thread_scope() -> None: + store = PolicyStoreStub() + matched_policy = store.create_policy( + name="Escalate export", + action="memory.export", + scope="profile", + effect="require_approval", + priority=10, + active=True, + conditions={}, + required_consents=[], + ) + + payload = evaluate_policy_request( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=PolicyEvaluationRequestInput( + thread_id=store.thread_id, + action="memory.export", + scope="profile", + attributes={}, + ), + ) + + assert payload["decision"] == "require_approval" + assert payload["matched_policy"]["id"] == str(matched_policy["id"]) + assert payload["reasons"][-1]["code"] == "policy_effect_require_approval" + + with pytest.raises( + PolicyEvaluationValidationError, + match="thread_id must reference an existing thread owned by the user", + ): + evaluate_policy_request( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=PolicyEvaluationRequestInput( + thread_id=uuid4(), + action="memory.export", + scope="profile", + attributes={}, + ), + ) diff --git a/tests/unit/test_policy_main.py b/tests/unit/test_policy_main.py new file mode 100644 index 0000000..fa3e4e5 --- /dev/null +++ b/tests/unit/test_policy_main.py @@ -0,0 +1,206 @@ +from __future__ import annotations + +import json +from contextlib import contextmanager +from uuid import uuid4 + +import apps.api.src.alicebot_api.main as main_module +from apps.api.src.alicebot_api.config import Settings +from alicebot_api.policy import PolicyEvaluationValidationError, PolicyNotFoundError + + +def test_upsert_consent_endpoint_translates_request_and_returns_created_status(monkeypatch) -> None: + user_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_upsert_consent_record(store, *, user_id, consent): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["consent"] = consent + return { + "consent": { + "id": "consent-123", + "consent_key": "email_marketing", + "status": "granted", + "metadata": {"source": "settings"}, + "created_at": "2026-03-12T09:00:00+00:00", + "updated_at": "2026-03-12T09:00:00+00:00", + }, + "write_mode": "created", + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "upsert_consent_record", fake_upsert_consent_record) + + response = main_module.upsert_consent( + main_module.UpsertConsentRequest( + user_id=user_id, + consent_key="email_marketing", + status="granted", + metadata={"source": "settings"}, + ) + ) + + assert response.status_code == 201 + assert json.loads(response.body) == { + "consent": { + "id": "consent-123", + "consent_key": "email_marketing", + "status": "granted", + "metadata": {"source": "settings"}, + "created_at": "2026-03-12T09:00:00+00:00", + "updated_at": "2026-03-12T09:00:00+00:00", + }, + "write_mode": "created", + } + assert captured["database_url"] == "postgresql://app" + assert captured["current_user_id"] == user_id + assert captured["store_type"] == "ContinuityStore" + assert captured["user_id"] == user_id + assert captured["consent"].consent_key == 
"email_marketing" + assert captured["consent"].status == "granted" + assert captured["consent"].metadata == {"source": "settings"} + + +def test_get_policy_endpoint_maps_not_found_to_404(monkeypatch) -> None: + user_id = uuid4() + policy_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_get_policy_record(*_args, **_kwargs): + raise PolicyNotFoundError(f"policy {policy_id} was not found") + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "get_policy_record", fake_get_policy_record) + + response = main_module.get_policy(policy_id, user_id) + + assert response.status_code == 404 + assert json.loads(response.body) == {"detail": f"policy {policy_id} was not found"} + + +def test_evaluate_policy_endpoint_translates_request_and_returns_trace_payload(monkeypatch) -> None: + user_id = uuid4() + thread_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_evaluate_policy_request(store, *, user_id, request): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["request"] = request + return { + "decision": "allow", + "matched_policy": { + "id": "policy-123", + "name": "Allow export", + "action": "memory.export", + "scope": "profile", + "effect": "allow", + "priority": 10, + "active": True, + "conditions": {"channel": "email"}, + "required_consents": ["email_marketing"], + "created_at": "2026-03-12T09:00:00+00:00", + "updated_at": "2026-03-12T09:00:00+00:00", + }, + "reasons": [ + { + "code": "matched_policy", + "source": "policy", + "message": "Matched policy 'Allow export' at priority 10.", + "policy_id": "policy-123", + "consent_key": None, + }, + { + "code": "policy_effect_allow", + "source": "policy", + "message": "Policy effect resolved the decision to 'allow'.", + "policy_id": "policy-123", + "consent_key": None, + }, + ], + "evaluation": { + "action": "memory.export", + "scope": "profile", + "evaluated_policy_count": 1, + "matched_policy_id": "policy-123", + "order": ["priority_asc", "created_at_asc", "id_asc"], + }, + "trace": {"trace_id": "trace-123", "trace_event_count": 3}, + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "evaluate_policy_request", fake_evaluate_policy_request) + + response = main_module.evaluate_policy( + main_module.EvaluatePolicyRequest( + user_id=user_id, + thread_id=thread_id, + action="memory.export", + scope="profile", + attributes={"channel": "email"}, + ) + ) + + assert response.status_code == 200 + assert json.loads(response.body)["trace"] == {"trace_id": "trace-123", "trace_event_count": 3} + assert captured["database_url"] == "postgresql://app" + assert captured["current_user_id"] == user_id + assert captured["store_type"] == "ContinuityStore" + assert captured["user_id"] == user_id + assert captured["request"].thread_id == thread_id + assert captured["request"].action == "memory.export" + assert captured["request"].scope == "profile" + assert captured["request"].attributes == {"channel": "email"} + 
+ +def test_evaluate_policy_endpoint_maps_validation_errors_to_400(monkeypatch) -> None: + user_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_evaluate_policy_request(*_args, **_kwargs): + raise PolicyEvaluationValidationError("thread_id must reference an existing thread owned by the user") + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "evaluate_policy_request", fake_evaluate_policy_request) + + response = main_module.evaluate_policy( + main_module.EvaluatePolicyRequest( + user_id=user_id, + thread_id=uuid4(), + action="memory.export", + scope="profile", + attributes={}, + ) + ) + + assert response.status_code == 400 + assert json.loads(response.body) == { + "detail": "thread_id must reference an existing thread owned by the user" + } diff --git a/tests/unit/test_policy_store.py b/tests/unit/test_policy_store.py new file mode 100644 index 0000000..6d33734 --- /dev/null +++ b/tests/unit/test_policy_store.py @@ -0,0 +1,184 @@ +from __future__ import annotations + +from typing import Any +from uuid import uuid4 + +from psycopg.types.json import Jsonb + +from alicebot_api.store import ContinuityStore + + +class RecordingCursor: + def __init__(self, fetchone_results: list[dict[str, Any]], fetchall_result: list[dict[str, Any]] | None = None) -> None: + self.executed: list[tuple[str, tuple[object, ...] | None]] = [] + self.fetchone_results = list(fetchone_results) + self.fetchall_result = fetchall_result or [] + + def __enter__(self) -> "RecordingCursor": + return self + + def __exit__(self, exc_type, exc, tb) -> None: + return None + + def execute(self, query: str, params: tuple[object, ...] 
| None = None) -> None: + self.executed.append((query, params)) + + def fetchone(self) -> dict[str, Any] | None: + if not self.fetchone_results: + return None + return self.fetchone_results.pop(0) + + def fetchall(self) -> list[dict[str, Any]]: + return self.fetchall_result + + +class RecordingConnection: + def __init__(self, cursor: RecordingCursor) -> None: + self.cursor_instance = cursor + + def cursor(self) -> RecordingCursor: + return self.cursor_instance + + +def test_consent_store_methods_use_expected_queries_and_jsonb_parameters() -> None: + consent_id = uuid4() + cursor = RecordingCursor( + fetchone_results=[ + {"id": consent_id, "consent_key": "email_marketing", "status": "granted", "metadata": {}}, + {"id": consent_id, "consent_key": "email_marketing", "status": "revoked", "metadata": {"source": "banner"}}, + ], + fetchall_result=[{"id": consent_id, "consent_key": "email_marketing", "status": "revoked", "metadata": {}}], + ) + store = ContinuityStore(RecordingConnection(cursor)) + + created = store.create_consent( + consent_key="email_marketing", + status="granted", + metadata={"source": "settings"}, + ) + updated = store.update_consent( + consent_id=consent_id, + status="revoked", + metadata={"source": "banner"}, + ) + listed = store.list_consents() + + assert created["id"] == consent_id + assert updated["status"] == "revoked" + assert listed == [{"id": consent_id, "consent_key": "email_marketing", "status": "revoked", "metadata": {}}] + + create_query, create_params = cursor.executed[0] + assert "INSERT INTO consents" in create_query + assert create_params is not None + assert create_params[:2] == ("email_marketing", "granted") + assert isinstance(create_params[2], Jsonb) + assert create_params[2].obj == {"source": "settings"} + + update_query, update_params = cursor.executed[1] + assert "UPDATE consents" in update_query + assert update_params is not None + assert update_params[0] == "revoked" + assert isinstance(update_params[1], Jsonb) + assert update_params[1].obj == {"source": "banner"} + assert update_params[2] == consent_id + + assert cursor.executed[2] == ( + """ + SELECT id, user_id, consent_key, status, metadata, created_at, updated_at + FROM consents + ORDER BY consent_key ASC, created_at ASC, id ASC + """, + None, + ) + + +def test_policy_store_methods_use_expected_queries_and_jsonb_parameters() -> None: + policy_id = uuid4() + cursor = RecordingCursor( + fetchone_results=[ + { + "id": policy_id, + "name": "Allow export", + "action": "memory.export", + "scope": "profile", + "effect": "allow", + "priority": 10, + "active": True, + "conditions": {"channel": "email"}, + "required_consents": ["email_marketing"], + }, + { + "id": policy_id, + "name": "Allow export", + "action": "memory.export", + "scope": "profile", + "effect": "allow", + "priority": 10, + "active": True, + "conditions": {"channel": "email"}, + "required_consents": ["email_marketing"], + }, + ], + fetchall_result=[ + { + "id": policy_id, + "name": "Allow export", + "action": "memory.export", + "scope": "profile", + "effect": "allow", + "priority": 10, + "active": True, + "conditions": {"channel": "email"}, + "required_consents": ["email_marketing"], + } + ], + ) + store = ContinuityStore(RecordingConnection(cursor)) + + created = store.create_policy( + name="Allow export", + action="memory.export", + scope="profile", + effect="allow", + priority=10, + active=True, + conditions={"channel": "email"}, + required_consents=["email_marketing"], + ) + fetched = store.get_policy_optional(policy_id) + listed = 
store.list_active_policies() + + assert created["id"] == policy_id + assert fetched is not None + assert listed[0]["id"] == policy_id + + create_query, create_params = cursor.executed[0] + assert "INSERT INTO policies" in create_query + assert create_params is not None + assert create_params[:6] == ("Allow export", "memory.export", "profile", "allow", 10, True) + assert isinstance(create_params[6], Jsonb) + assert create_params[6].obj == {"channel": "email"} + assert isinstance(create_params[7], Jsonb) + assert create_params[7].obj == ["email_marketing"] + + assert cursor.executed[1] == ( + """ + SELECT + id, + user_id, + name, + action, + scope, + effect, + priority, + active, + conditions, + required_consents, + created_at, + updated_at + FROM policies + WHERE id = %s + """, + (policy_id,), + ) + assert "WHERE active = TRUE" in cursor.executed[2][0] diff --git a/tests/unit/test_proxy_execution.py b/tests/unit/test_proxy_execution.py new file mode 100644 index 0000000..d1f4a61 --- /dev/null +++ b/tests/unit/test_proxy_execution.py @@ -0,0 +1,783 @@ +from __future__ import annotations + +from datetime import UTC, datetime, timedelta +from uuid import UUID, uuid4 + +import pytest + +from alicebot_api.approvals import ApprovalNotFoundError +from alicebot_api.contracts import ProxyExecutionRequestInput +from alicebot_api.proxy_execution import ( + PROXY_EXECUTION_REQUEST_EVENT_KIND, + PROXY_EXECUTION_RESULT_EVENT_KIND, + ProxyExecutionApprovalStateError, + ProxyExecutionHandlerNotFoundError, + execute_approved_proxy_request, + registered_proxy_handler_keys, +) + + +class ProxyExecutionStoreStub: + def __init__(self) -> None: + self.base_time = datetime(2026, 3, 13, 9, 0, tzinfo=UTC) + self.user_id = uuid4() + self.thread_id = uuid4() + self.locked_task_ids: list[UUID] = [] + self.approvals: dict[UUID, dict[str, object]] = {} + self.tasks: list[dict[str, object]] = [] + self.task_steps: list[dict[str, object]] = [] + self.events: list[dict[str, object]] = [] + self.tool_executions: list[dict[str, object]] = [] + self.execution_budgets: list[dict[str, object]] = [] + self.traces: list[dict[str, object]] = [] + self.trace_events: list[dict[str, object]] = [] + + def current_time(self) -> datetime: + return self.base_time + timedelta(minutes=len(self.tool_executions)) + + def seed_approval(self, *, status: str, tool_key: str) -> dict[str, object]: + approval_id = uuid4() + tool_id = uuid4() + created_at = self.base_time + timedelta(minutes=len(self.approvals)) + approval = { + "id": approval_id, + "user_id": self.user_id, + "thread_id": self.thread_id, + "tool_id": tool_id, + "task_step_id": None, + "status": status, + "request": { + "thread_id": str(self.thread_id), + "tool_id": str(tool_id), + "action": "tool.run", + "scope": "workspace", + "domain_hint": None, + "risk_hint": None, + "attributes": {"message": "hello", "count": 2}, + }, + "tool": { + "id": str(tool_id), + "tool_key": tool_key, + "name": "Proxy Echo" if tool_key == "proxy.echo" else "Unregistered Proxy", + "description": "Deterministic proxy handler.", + "version": "1.0.0", + "metadata_version": "tool_metadata_v0", + "active": True, + "tags": ["proxy"], + "action_hints": ["tool.run"], + "scope_hints": ["workspace"], + "domain_hints": [], + "risk_hints": [], + "metadata": {"transport": "proxy"}, + "created_at": created_at.isoformat(), + }, + "routing": { + "decision": "approval_required", + "reasons": [], + "trace": {"trace_id": str(uuid4()), "trace_event_count": 3}, + }, + "routing_trace_id": uuid4(), + "created_at": created_at, 
+ "resolved_at": None if status == "pending" else created_at + timedelta(minutes=30), + "resolved_by_user_id": None if status == "pending" else self.user_id, + } + self.approvals[approval_id] = approval + task = self.create_task( + thread_id=self.thread_id, + tool_id=tool_id, + status={ + "pending": "pending_approval", + "approved": "approved", + "rejected": "denied", + }[status], + request=approval["request"], + tool=approval["tool"], + latest_approval_id=approval_id, + latest_execution_id=None, + ) + task_step = self.create_task_step( + task_id=task["id"], + sequence_no=1, + kind="governed_request", + status={ + "pending": "created", + "approved": "approved", + "rejected": "denied", + }[status], + request=approval["request"], + outcome={ + "routing_decision": "approval_required", + "approval_id": str(approval_id), + "approval_status": status, + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + }, + trace_id=uuid4(), + trace_kind="approval.request" if status == "pending" else "approval.resolve", + ) + approval["task_step_id"] = task_step["id"] + return approval + + def seed_execution_budget( + self, + *, + tool_key: str | None, + domain_hint: str | None, + max_completed_executions: int, + rolling_window_seconds: int | None = None, + supersedes_budget_id: UUID | None = None, + ) -> dict[str, object]: + row = { + "id": uuid4(), + "user_id": self.user_id, + "tool_key": tool_key, + "domain_hint": domain_hint, + "max_completed_executions": max_completed_executions, + "rolling_window_seconds": rolling_window_seconds, + "status": "active", + "deactivated_at": None, + "superseded_by_budget_id": None, + "supersedes_budget_id": supersedes_budget_id, + "created_at": self.base_time + timedelta(minutes=len(self.execution_budgets)), + } + self.execution_budgets.append(row) + self.execution_budgets.sort(key=lambda item: (item["created_at"], item["id"])) + return row + + def get_approval_optional(self, approval_id: UUID) -> dict[str, object] | None: + return self.approvals.get(approval_id) + + def create_trace( + self, + *, + user_id: UUID, + thread_id: UUID, + kind: str, + compiler_version: str, + status: str, + limits: dict[str, object], + ) -> dict[str, object]: + trace = { + "id": uuid4(), + "user_id": user_id, + "thread_id": thread_id, + "kind": kind, + "compiler_version": compiler_version, + "status": status, + "limits": limits, + "created_at": self.base_time + timedelta(minutes=len(self.traces)), + } + self.traces.append(trace) + return trace + + def append_trace_event( + self, + *, + trace_id: UUID, + sequence_no: int, + kind: str, + payload: dict[str, object], + ) -> dict[str, object]: + event = { + "id": uuid4(), + "trace_id": trace_id, + "sequence_no": sequence_no, + "kind": kind, + "payload": payload, + "created_at": self.base_time + timedelta(minutes=len(self.trace_events)), + } + self.trace_events.append(event) + return event + + def create_task( + self, + *, + thread_id: UUID, + tool_id: UUID, + status: str, + request: dict[str, object], + tool: dict[str, object], + latest_approval_id: UUID | None, + latest_execution_id: UUID | None, + ) -> dict[str, object]: + task = { + "id": uuid4(), + "user_id": self.user_id, + "thread_id": thread_id, + "tool_id": tool_id, + "status": status, + "request": request, + "tool": tool, + "latest_approval_id": latest_approval_id, + "latest_execution_id": latest_execution_id, + "created_at": self.base_time + timedelta(minutes=len(self.tasks)), + "updated_at": self.base_time + timedelta(minutes=len(self.tasks)), + } + 
self.tasks.append(task) + return task + + def get_task_by_approval_optional(self, approval_id: UUID) -> dict[str, object] | None: + return next((task for task in self.tasks if task["latest_approval_id"] == approval_id), None) + + def get_task_optional(self, task_id: UUID) -> dict[str, object] | None: + return next((task for task in self.tasks if task["id"] == task_id), None) + + def get_task_step_optional(self, task_step_id: UUID) -> dict[str, object] | None: + return next((task_step for task_step in self.task_steps if task_step["id"] == task_step_id), None) + + def lock_task_steps(self, task_id: UUID) -> None: + self.locked_task_ids.append(task_id) + + def update_task_execution_by_approval_optional( + self, + *, + approval_id: UUID, + latest_execution_id: UUID, + status: str, + ) -> dict[str, object] | None: + task = self.get_task_by_approval_optional(approval_id) + if task is None: + return None + task["status"] = status + task["latest_execution_id"] = latest_execution_id + task["updated_at"] = self.base_time + timedelta(hours=1, minutes=len(self.trace_events)) + return task + + def append_event( + self, + thread_id: UUID, + session_id: UUID | None, + kind: str, + payload: dict[str, object], + ) -> dict[str, object]: + event = { + "id": uuid4(), + "user_id": self.user_id, + "thread_id": thread_id, + "session_id": session_id, + "sequence_no": len(self.events) + 1, + "kind": kind, + "payload": payload, + "created_at": self.base_time + timedelta(minutes=len(self.events)), + } + self.events.append(event) + return event + + def create_task_step( + self, + *, + task_id: UUID, + sequence_no: int, + parent_step_id: UUID | None = None, + source_approval_id: UUID | None = None, + source_execution_id: UUID | None = None, + kind: str, + status: str, + request: dict[str, object], + outcome: dict[str, object], + trace_id: UUID, + trace_kind: str, + ) -> dict[str, object]: + task_step = { + "id": uuid4(), + "user_id": self.user_id, + "task_id": task_id, + "sequence_no": sequence_no, + "parent_step_id": parent_step_id, + "source_approval_id": source_approval_id, + "source_execution_id": source_execution_id, + "kind": kind, + "status": status, + "request": request, + "outcome": outcome, + "trace_id": trace_id, + "trace_kind": trace_kind, + "created_at": self.base_time + timedelta(minutes=len(self.task_steps)), + "updated_at": self.base_time + timedelta(minutes=len(self.task_steps)), + } + self.task_steps.append(task_step) + return task_step + + def get_task_step_for_task_sequence_optional( + self, + *, + task_id: UUID, + sequence_no: int, + ) -> dict[str, object] | None: + return next( + ( + task_step + for task_step in self.task_steps + if task_step["task_id"] == task_id and task_step["sequence_no"] == sequence_no + ), + None, + ) + + def list_task_steps_for_task(self, task_id: UUID) -> list[dict[str, object]]: + return sorted( + [task_step for task_step in self.task_steps if task_step["task_id"] == task_id], + key=lambda task_step: (task_step["sequence_no"], task_step["created_at"], task_step["id"]), + ) + + def update_task_step_for_task_sequence_optional( + self, + *, + task_id: UUID, + sequence_no: int, + status: str, + outcome: dict[str, object], + trace_id: UUID, + trace_kind: str, + ) -> dict[str, object] | None: + task_step = self.get_task_step_for_task_sequence_optional(task_id=task_id, sequence_no=sequence_no) + if task_step is None: + return None + task_step["status"] = status + task_step["outcome"] = outcome + task_step["trace_id"] = trace_id + task_step["trace_kind"] = trace_kind + 
task_step["updated_at"] = self.base_time + timedelta(hours=1, minutes=len(self.trace_events)) + return task_step + + def update_task_step_optional( + self, + *, + task_step_id: UUID, + status: str, + outcome: dict[str, object], + trace_id: UUID, + trace_kind: str, + ) -> dict[str, object] | None: + task_step = self.get_task_step_optional(task_step_id) + if task_step is None: + return None + task_step["status"] = status + task_step["outcome"] = outcome + task_step["trace_id"] = trace_id + task_step["trace_kind"] = trace_kind + task_step["updated_at"] = self.base_time + timedelta(hours=1, minutes=len(self.trace_events)) + return task_step + + def create_tool_execution( + self, + *, + approval_id: UUID, + task_step_id: UUID, + thread_id: UUID, + tool_id: UUID, + trace_id: UUID, + request_event_id: UUID | None, + result_event_id: UUID | None, + status: str, + handler_key: str | None, + request: dict[str, object], + tool: dict[str, object], + result: dict[str, object], + ) -> dict[str, object]: + execution = { + "id": uuid4(), + "user_id": self.user_id, + "approval_id": approval_id, + "task_step_id": task_step_id, + "thread_id": thread_id, + "tool_id": tool_id, + "trace_id": trace_id, + "request_event_id": request_event_id, + "result_event_id": result_event_id, + "status": status, + "handler_key": handler_key, + "request": request, + "tool": tool, + "result": result, + "executed_at": self.base_time + timedelta(minutes=len(self.tool_executions)), + } + self.tool_executions.append(execution) + return execution + + def list_execution_budgets(self) -> list[dict[str, object]]: + return list(self.execution_budgets) + + def list_tool_executions(self) -> list[dict[str, object]]: + return list(self.tool_executions) + + +def test_execute_approved_proxy_request_returns_result_and_persists_events() -> None: + store = ProxyExecutionStoreStub() + approval = store.seed_approval(status="approved", tool_key="proxy.echo") + + payload = execute_approved_proxy_request( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ProxyExecutionRequestInput(approval_id=approval["id"]), + ) + + assert list(payload) == ["request", "approval", "tool", "result", "events", "trace"] + assert payload["request"] == { + "approval_id": str(approval["id"]), + "task_step_id": str(approval["task_step_id"]), + } + assert payload["approval"]["status"] == "approved" + assert payload["tool"]["tool_key"] == "proxy.echo" + assert payload["result"] == { + "handler_key": "proxy.echo", + "status": "completed", + "output": { + "mode": "no_side_effect", + "tool_key": "proxy.echo", + "action": "tool.run", + "scope": "workspace", + "domain_hint": None, + "risk_hint": None, + "attributes": {"message": "hello", "count": 2}, + }, + } + assert payload["events"]["request_sequence_no"] == 1 + assert payload["events"]["result_sequence_no"] == 2 + assert payload["trace"]["trace_event_count"] == 9 + assert len(store.tool_executions) == 1 + assert store.tool_executions[0]["approval_id"] == approval["id"] + assert store.tool_executions[0]["task_step_id"] == approval["task_step_id"] + assert store.tool_executions[0]["trace_id"] == UUID(payload["trace"]["trace_id"]) + assert store.tool_executions[0]["handler_key"] == "proxy.echo" + assert store.tasks[0]["status"] == "executed" + assert store.task_steps[0]["status"] == "executed" + assert store.tasks[0]["latest_execution_id"] == store.tool_executions[0]["id"] + assert store.tool_executions[0]["result"] == { + "handler_key": "proxy.echo", + "status": "completed", + "output": 
payload["result"]["output"], + "reason": None, + } + assert [event["kind"] for event in store.events] == [ + PROXY_EXECUTION_REQUEST_EVENT_KIND, + PROXY_EXECUTION_RESULT_EVENT_KIND, + ] + assert [event["kind"] for event in store.trace_events] == [ + "tool.proxy.execute.request", + "tool.proxy.execute.approval", + "tool.proxy.execute.budget", + "tool.proxy.execute.dispatch", + "tool.proxy.execute.summary", + "task.lifecycle.state", + "task.lifecycle.summary", + "task.step.lifecycle.state", + "task.step.lifecycle.summary", + ] + + +def test_execute_approved_proxy_request_locks_task_steps_before_persisting_execution_state() -> None: + class LockingProxyExecutionStoreStub(ProxyExecutionStoreStub): + def list_task_steps_for_task(self, task_id: UUID) -> list[dict[str, object]]: + if task_id not in self.locked_task_ids: + raise AssertionError("task-step boundary was checked before the task-step lock was taken") + return super().list_task_steps_for_task(task_id) + + def create_tool_execution( + self, + *, + approval_id: UUID, + task_step_id: UUID, + thread_id: UUID, + tool_id: UUID, + trace_id: UUID, + request_event_id: UUID | None, + result_event_id: UUID | None, + status: str, + handler_key: str | None, + request: dict[str, object], + tool: dict[str, object], + result: dict[str, object], + ) -> dict[str, object]: + task = self.get_task_by_approval_optional(approval_id) + if task is None: + raise AssertionError("expected task for approval before execution persistence") + if task["id"] not in self.locked_task_ids: + raise AssertionError("tool execution persisted before the task-step lock was taken") + return super().create_tool_execution( + approval_id=approval_id, + task_step_id=task_step_id, + thread_id=thread_id, + tool_id=tool_id, + trace_id=trace_id, + request_event_id=request_event_id, + result_event_id=result_event_id, + status=status, + handler_key=handler_key, + request=request, + tool=tool, + result=result, + ) + + store = LockingProxyExecutionStoreStub() + approval = store.seed_approval(status="approved", tool_key="proxy.echo") + + payload = execute_approved_proxy_request( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ProxyExecutionRequestInput(approval_id=approval["id"]), + ) + + assert payload["result"]["status"] == "completed" + assert store.tasks[0]["id"] in store.locked_task_ids + + +def test_execute_approved_proxy_request_updates_the_linked_later_step_without_mutating_the_original_step() -> None: + store = ProxyExecutionStoreStub() + approval = store.seed_approval(status="approved", tool_key="proxy.echo") + task = store.tasks[0] + first_step = store.task_steps[0] + initial_execution_id = uuid4() + task["status"] = "pending_approval" + task["latest_execution_id"] = None + first_step["status"] = "executed" + first_step["outcome"] = { + "routing_decision": "approval_required", + "approval_id": str(approval["id"]), + "approval_status": "approved", + "execution_id": str(initial_execution_id), + "execution_status": "completed", + "blocked_reason": None, + } + later_step = store.create_task_step( + task_id=task["id"], + sequence_no=2, + parent_step_id=first_step["id"], + source_approval_id=approval["id"], + source_execution_id=initial_execution_id, + kind="governed_request", + status="created", + request=approval["request"], + outcome={ + "routing_decision": "approval_required", + "approval_id": str(approval["id"]), + "approval_status": "approved", + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + }, + trace_id=uuid4(), + 
trace_kind="task.step.continuation", + ) + + original_first_trace_id = first_step["trace_id"] + original_first_outcome = dict(first_step["outcome"]) + original_later_trace_id = later_step["trace_id"] + approval["task_step_id"] = later_step["id"] + + payload = execute_approved_proxy_request( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ProxyExecutionRequestInput(approval_id=approval["id"]), + ) + + assert payload["result"]["status"] == "completed" + assert task["status"] == "executed" + assert task["latest_execution_id"] == store.tool_executions[0]["id"] + assert first_step["status"] == "executed" + assert first_step["trace_id"] == original_first_trace_id + assert first_step["outcome"] == original_first_outcome + assert later_step["status"] == "executed" + assert later_step["trace_id"] == UUID(payload["trace"]["trace_id"]) + assert later_step["trace_id"] != original_later_trace_id + assert later_step["outcome"]["execution_id"] == str(store.tool_executions[0]["id"]) + assert later_step["outcome"]["execution_status"] == "completed" + assert store.tool_executions[0]["task_step_id"] == later_step["id"] + assert store.events[0]["payload"]["task_step_id"] == str(later_step["id"]) + assert store.events[1]["payload"]["task_step_id"] == str(later_step["id"]) + assert store.trace_events[0]["payload"] == { + "approval_id": str(approval["id"]), + "task_step_id": str(later_step["id"]), + } + assert store.trace_events[3]["payload"]["task_step_id"] == str(later_step["id"]) + assert store.trace_events[4]["payload"]["task_step_id"] == str(later_step["id"]) + + +@pytest.mark.parametrize("status", ["pending", "rejected"]) +def test_execute_approved_proxy_request_rejects_non_approved_statuses(status: str) -> None: + store = ProxyExecutionStoreStub() + approval = store.seed_approval(status=status, tool_key="proxy.echo") + + with pytest.raises( + ProxyExecutionApprovalStateError, + match=rf"approval {approval['id']} is {status} and cannot be executed", + ): + execute_approved_proxy_request( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ProxyExecutionRequestInput(approval_id=approval["id"]), + ) + + assert store.events == [] + assert store.tool_executions == [] + assert [event["kind"] for event in store.trace_events] == [ + "tool.proxy.execute.request", + "tool.proxy.execute.approval", + "tool.proxy.execute.dispatch", + "tool.proxy.execute.summary", + ] + assert store.trace_events[2]["payload"]["dispatch_status"] == "blocked" + assert store.trace_events[3]["payload"]["execution_status"] == "blocked" + + +def test_execute_approved_proxy_request_rejects_missing_handlers() -> None: + store = ProxyExecutionStoreStub() + approval = store.seed_approval(status="approved", tool_key="proxy.missing") + + with pytest.raises( + ProxyExecutionHandlerNotFoundError, + match="tool 'proxy.missing' has no registered proxy handler", + ): + execute_approved_proxy_request( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ProxyExecutionRequestInput(approval_id=approval["id"]), + ) + + assert store.events == [] + assert len(store.tool_executions) == 1 + assert store.tool_executions[0]["status"] == "blocked" + assert store.tool_executions[0]["task_step_id"] == approval["task_step_id"] + assert store.tool_executions[0]["handler_key"] is None + assert store.tool_executions[0]["request_event_id"] is None + assert store.tool_executions[0]["result_event_id"] is None + assert store.tasks[0]["status"] == "blocked" + assert store.task_steps[0]["status"] == "blocked" + assert 
store.tasks[0]["latest_execution_id"] == store.tool_executions[0]["id"] + assert store.tool_executions[0]["result"] == { + "handler_key": None, + "status": "blocked", + "output": None, + "reason": "tool 'proxy.missing' has no registered proxy handler", + } + assert store.trace_events[2]["payload"]["decision"] == "allow" + assert store.trace_events[3]["payload"] == { + "approval_id": str(approval["id"]), + "task_step_id": str(approval["task_step_id"]), + "tool_id": approval["tool"]["id"], + "tool_key": "proxy.missing", + "handler_key": None, + "dispatch_status": "blocked", + "reason": "tool 'proxy.missing' has no registered proxy handler", + "result_status": "blocked", + "output": None, + } + assert [event["kind"] for event in store.trace_events] == [ + "tool.proxy.execute.request", + "tool.proxy.execute.approval", + "tool.proxy.execute.budget", + "tool.proxy.execute.dispatch", + "tool.proxy.execute.summary", + "task.lifecycle.state", + "task.lifecycle.summary", + "task.step.lifecycle.state", + "task.step.lifecycle.summary", + ] + + +def test_execute_approved_proxy_request_returns_blocked_budget_response_and_persists_review_record() -> None: + store = ProxyExecutionStoreStub() + approval = store.seed_approval(status="approved", tool_key="proxy.echo") + budget = store.seed_execution_budget( + tool_key="proxy.echo", + domain_hint=None, + max_completed_executions=1, + ) + store.create_tool_execution( + approval_id=uuid4(), + task_step_id=uuid4(), + thread_id=store.thread_id, + tool_id=UUID(approval["tool"]["id"]), + trace_id=uuid4(), + request_event_id=uuid4(), + result_event_id=uuid4(), + status="completed", + handler_key="proxy.echo", + request={ + "thread_id": str(store.thread_id), + "tool_id": approval["tool"]["id"], + "action": "tool.run", + "scope": "workspace", + "domain_hint": None, + "risk_hint": None, + "attributes": {"message": "seed"}, + }, + tool=approval["tool"], + result={ + "handler_key": "proxy.echo", + "status": "completed", + "output": {"mode": "no_side_effect"}, + "reason": None, + }, + ) + + payload = execute_approved_proxy_request( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ProxyExecutionRequestInput(approval_id=approval["id"]), + ) + + assert payload["events"] is None + assert payload["trace"]["trace_event_count"] == 9 + assert payload["result"] == { + "handler_key": None, + "status": "blocked", + "output": None, + "reason": ( + f"execution budget {budget['id']} blocks execution: projected completed executions " + "2 would exceed limit 1" + ), + "budget_decision": { + "matched_budget_id": str(budget["id"]), + "tool_key": "proxy.echo", + "domain_hint": None, + "budget_tool_key": "proxy.echo", + "budget_domain_hint": None, + "max_completed_executions": 1, + "rolling_window_seconds": None, + "count_scope": "lifetime", + "window_started_at": None, + "completed_execution_count": 1, + "projected_completed_execution_count": 2, + "decision": "block", + "reason": "budget_exceeded", + "order": ["specificity_desc", "created_at_asc", "id_asc"], + "history_order": ["executed_at_asc", "id_asc"], + }, + } + assert len(store.events) == 0 + assert len(store.tool_executions) == 2 + assert store.tool_executions[-1]["status"] == "blocked" + assert store.tool_executions[-1]["request_event_id"] is None + assert store.tool_executions[-1]["result_event_id"] is None + assert store.tasks[0]["status"] == "blocked" + assert store.task_steps[0]["status"] == "blocked" + assert store.tasks[0]["latest_execution_id"] == store.tool_executions[-1]["id"] + assert 
store.tool_executions[-1]["result"] == payload["result"] + assert [event["kind"] for event in store.trace_events] == [ + "tool.proxy.execute.request", + "tool.proxy.execute.approval", + "tool.proxy.execute.budget", + "tool.proxy.execute.dispatch", + "tool.proxy.execute.summary", + "task.lifecycle.state", + "task.lifecycle.summary", + "task.step.lifecycle.state", + "task.step.lifecycle.summary", + ] + assert store.trace_events[2]["payload"] == payload["result"]["budget_decision"] + + +def test_execute_approved_proxy_request_rejects_missing_visible_approval() -> None: + store = ProxyExecutionStoreStub() + + with pytest.raises(ApprovalNotFoundError, match="was not found"): + execute_approved_proxy_request( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ProxyExecutionRequestInput(approval_id=uuid4()), + ) + + +def test_registered_proxy_handler_keys_are_sorted_and_explicit() -> None: + assert registered_proxy_handler_keys() == ("proxy.echo",) diff --git a/tests/unit/test_proxy_execution_main.py b/tests/unit/test_proxy_execution_main.py new file mode 100644 index 0000000..b98b7bb --- /dev/null +++ b/tests/unit/test_proxy_execution_main.py @@ -0,0 +1,289 @@ +from __future__ import annotations + +import json +from contextlib import contextmanager +from uuid import uuid4 + +import apps.api.src.alicebot_api.main as main_module +from apps.api.src.alicebot_api.config import Settings +from alicebot_api.approvals import ApprovalNotFoundError +from alicebot_api.proxy_execution import ( + ProxyExecutionApprovalStateError, + ProxyExecutionHandlerNotFoundError, +) +from alicebot_api.tasks import TaskStepApprovalLinkageError + + +def test_execute_approved_proxy_endpoint_returns_payload(monkeypatch) -> None: + user_id = uuid4() + approval_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_execute_approved_proxy_request(store, *, user_id, request): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["request"] = request + return { + "request": {"approval_id": str(approval_id), "task_step_id": "task-step-123"}, + "approval": { + "id": str(approval_id), + "thread_id": "thread-123", + "task_step_id": "task-step-123", + "status": "approved", + "request": { + "thread_id": "thread-123", + "tool_id": "tool-123", + "action": "tool.run", + "scope": "workspace", + "domain_hint": None, + "risk_hint": None, + "attributes": {"message": "hello"}, + }, + "tool": {"id": "tool-123", "tool_key": "proxy.echo"}, + "routing": { + "decision": "approval_required", + "reasons": [], + "trace": {"trace_id": "routing-trace-123", "trace_event_count": 3}, + }, + "created_at": "2026-03-13T09:00:00+00:00", + "resolution": { + "resolved_at": "2026-03-13T09:30:00+00:00", + "resolved_by_user_id": str(user_id), + }, + }, + "tool": {"id": "tool-123", "tool_key": "proxy.echo"}, + "result": { + "handler_key": "proxy.echo", + "status": "completed", + "output": {"mode": "no_side_effect"}, + }, + "events": { + "request_event_id": "event-request-123", + "request_sequence_no": 1, + "result_event_id": "event-result-123", + "result_sequence_no": 2, + }, + "trace": {"trace_id": "proxy-trace-123", "trace_event_count": 5}, + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, 
"user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "execute_approved_proxy_request", fake_execute_approved_proxy_request) + + response = main_module.execute_approved_proxy( + approval_id, + main_module.ExecuteApprovedProxyRequest(user_id=user_id), + ) + + assert response.status_code == 200 + assert json.loads(response.body)["request"] == { + "approval_id": str(approval_id), + "task_step_id": "task-step-123", + } + assert json.loads(response.body)["trace"] == { + "trace_id": "proxy-trace-123", + "trace_event_count": 5, + } + assert captured["database_url"] == "postgresql://app" + assert captured["current_user_id"] == user_id + assert captured["store_type"] == "ContinuityStore" + assert captured["user_id"] == user_id + assert captured["request"].approval_id == approval_id + + +def test_execute_approved_proxy_endpoint_maps_missing_approval_to_404(monkeypatch) -> None: + user_id = uuid4() + approval_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_execute_approved_proxy_request(*_args, **_kwargs): + raise ApprovalNotFoundError(f"approval {approval_id} was not found") + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "execute_approved_proxy_request", fake_execute_approved_proxy_request) + + response = main_module.execute_approved_proxy( + approval_id, + main_module.ExecuteApprovedProxyRequest(user_id=user_id), + ) + + assert response.status_code == 404 + assert json.loads(response.body) == {"detail": f"approval {approval_id} was not found"} + + +def test_execute_approved_proxy_endpoint_maps_blocked_approval_to_409(monkeypatch) -> None: + user_id = uuid4() + approval_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_execute_approved_proxy_request(*_args, **_kwargs): + raise ProxyExecutionApprovalStateError( + f"approval {approval_id} is pending and cannot be executed" + ) + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "execute_approved_proxy_request", fake_execute_approved_proxy_request) + + response = main_module.execute_approved_proxy( + approval_id, + main_module.ExecuteApprovedProxyRequest(user_id=user_id), + ) + + assert response.status_code == 409 + assert json.loads(response.body) == { + "detail": f"approval {approval_id} is pending and cannot be executed" + } + + +def test_execute_approved_proxy_endpoint_maps_missing_handler_to_409(monkeypatch) -> None: + user_id = uuid4() + approval_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_execute_approved_proxy_request(*_args, **_kwargs): + raise ProxyExecutionHandlerNotFoundError( + "tool 'proxy.missing' has no registered proxy handler" + ) + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "execute_approved_proxy_request", fake_execute_approved_proxy_request) + + response = main_module.execute_approved_proxy( + approval_id, + main_module.ExecuteApprovedProxyRequest(user_id=user_id), + ) 
+ + assert response.status_code == 409 + assert json.loads(response.body) == { + "detail": "tool 'proxy.missing' has no registered proxy handler" + } + + +def test_execute_approved_proxy_endpoint_maps_linkage_error_to_409(monkeypatch) -> None: + user_id = uuid4() + approval_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_execute_approved_proxy_request(*_args, **_kwargs): + raise TaskStepApprovalLinkageError( + f"approval {approval_id} is missing linked task_step_id" + ) + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "execute_approved_proxy_request", fake_execute_approved_proxy_request) + + response = main_module.execute_approved_proxy( + approval_id, + main_module.ExecuteApprovedProxyRequest(user_id=user_id), + ) + + assert response.status_code == 409 + assert json.loads(response.body) == { + "detail": f"approval {approval_id} is missing linked task_step_id" + } + + +def test_execute_approved_proxy_endpoint_returns_budget_blocked_payload(monkeypatch) -> None: + user_id = uuid4() + approval_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_execute_approved_proxy_request(*_args, **_kwargs): + return { + "request": {"approval_id": str(approval_id), "task_step_id": "task-step-123"}, + "approval": { + "id": str(approval_id), + "thread_id": "thread-123", + "task_step_id": "task-step-123", + "status": "approved", + "request": { + "thread_id": "thread-123", + "tool_id": "tool-123", + "action": "tool.run", + "scope": "workspace", + "domain_hint": None, + "risk_hint": None, + "attributes": {"message": "hello"}, + }, + "tool": {"id": "tool-123", "tool_key": "proxy.echo"}, + "routing": { + "decision": "approval_required", + "reasons": [], + "trace": {"trace_id": "routing-trace-123", "trace_event_count": 3}, + }, + "created_at": "2026-03-13T09:00:00+00:00", + "resolution": { + "resolved_at": "2026-03-13T09:30:00+00:00", + "resolved_by_user_id": str(user_id), + }, + }, + "tool": {"id": "tool-123", "tool_key": "proxy.echo"}, + "result": { + "handler_key": None, + "status": "blocked", + "output": None, + "reason": "execution budget budget-123 blocks execution: projected completed executions 2 would exceed limit 1", + "budget_decision": { + "matched_budget_id": "budget-123", + "tool_key": "proxy.echo", + "domain_hint": None, + "budget_tool_key": "proxy.echo", + "budget_domain_hint": None, + "max_completed_executions": 1, + "rolling_window_seconds": None, + "count_scope": "lifetime", + "window_started_at": None, + "completed_execution_count": 1, + "projected_completed_execution_count": 2, + "decision": "block", + "reason": "budget_exceeded", + "order": ["specificity_desc", "created_at_asc", "id_asc"], + "history_order": ["executed_at_asc", "id_asc"], + }, + }, + "events": None, + "trace": {"trace_id": "proxy-trace-456", "trace_event_count": 5}, + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "execute_approved_proxy_request", fake_execute_approved_proxy_request) + + response = main_module.execute_approved_proxy( + approval_id, + main_module.ExecuteApprovedProxyRequest(user_id=user_id), + ) + + assert 
response.status_code == 200 + assert json.loads(response.body)["events"] is None diff --git a/tests/unit/test_response_generation.py b/tests/unit/test_response_generation.py new file mode 100644 index 0000000..f91c051 --- /dev/null +++ b/tests/unit/test_response_generation.py @@ -0,0 +1,267 @@ +from __future__ import annotations + +import json + +from apps.api.src.alicebot_api.config import Settings +from alicebot_api.contracts import ( + ModelInvocationRequest, + ModelInvocationResponse, + PROMPT_ASSEMBLY_VERSION_V0, + PromptAssemblyInput, +) +from alicebot_api.response_generation import ( + assemble_prompt, + build_assistant_response_payload, + invoke_model, +) + + +def make_context_pack() -> dict[str, object]: + return { + "compiler_version": "continuity_v0", + "scope": { + "user_id": "11111111-1111-1111-8111-111111111111", + "thread_id": "22222222-2222-2222-8222-222222222222", + }, + "limits": { + "max_sessions": 3, + "max_events": 8, + "max_memories": 5, + "max_entities": 5, + "max_entity_edges": 10, + }, + "user": { + "id": "11111111-1111-1111-8111-111111111111", + "email": "owner@example.com", + "display_name": "Owner", + "created_at": "2026-03-12T09:00:00+00:00", + }, + "thread": { + "id": "22222222-2222-2222-8222-222222222222", + "title": "Thread", + "created_at": "2026-03-12T09:00:00+00:00", + "updated_at": "2026-03-12T09:05:00+00:00", + }, + "sessions": [], + "events": [ + { + "id": "33333333-3333-3333-8333-333333333333", + "session_id": None, + "sequence_no": 1, + "kind": "message.user", + "payload": {"text": "Hello"}, + "created_at": "2026-03-12T09:06:00+00:00", + } + ], + "memories": [ + { + "id": "44444444-4444-4444-8444-444444444444", + "memory_key": "user.preference.coffee", + "value": {"likes": "oat milk"}, + "status": "active", + "source_event_ids": ["33333333-3333-3333-8333-333333333333"], + "created_at": "2026-03-12T09:04:00+00:00", + "updated_at": "2026-03-12T09:05:00+00:00", + "source_provenance": {"sources": ["symbolic"], "semantic_score": None}, + } + ], + "memory_summary": { + "candidate_count": 1, + "included_count": 1, + "excluded_deleted_count": 0, + "excluded_limit_count": 0, + "hybrid_retrieval": { + "requested": False, + "embedding_config_id": None, + "query_vector_dimensions": 0, + "semantic_limit": 0, + "symbolic_selected_count": 1, + "semantic_selected_count": 0, + "merged_candidate_count": 1, + "deduplicated_count": 0, + "included_symbolic_only_count": 1, + "included_semantic_only_count": 0, + "included_dual_source_count": 0, + "similarity_metric": None, + "source_precedence": ["symbolic", "semantic"], + "symbolic_order": ["updated_at_asc", "created_at_asc", "id_asc"], + "semantic_order": ["score_desc", "created_at_asc", "id_asc"], + }, + }, + "entities": [], + "entity_summary": { + "candidate_count": 0, + "included_count": 0, + "excluded_limit_count": 0, + }, + "entity_edges": [], + "entity_edge_summary": { + "anchor_entity_count": 0, + "candidate_count": 0, + "included_count": 0, + "excluded_limit_count": 0, + }, + } + + +def test_assemble_prompt_is_deterministic_and_explicit() -> None: + first = assemble_prompt( + request=PromptAssemblyInput( + context_pack=make_context_pack(), + system_instruction="System instruction", + developer_instruction="Developer instruction", + ), + compile_trace_id="compile-trace-123", + ) + second = assemble_prompt( + request=PromptAssemblyInput( + context_pack=make_context_pack(), + system_instruction="System instruction", + developer_instruction="Developer instruction", + ), + compile_trace_id="compile-trace-123", + ) + + 
assert first.prompt_text == second.prompt_text + assert first.prompt_sha256 == second.prompt_sha256 + assert first.trace_payload == second.trace_payload + assert [section.name for section in first.sections] == [ + "system", + "developer", + "context", + "conversation", + ] + assert "[SYSTEM]\nSystem instruction" in first.prompt_text + assert "[DEVELOPER]\nDeveloper instruction" in first.prompt_text + assert '"memory_key":"user.preference.coffee"' in first.prompt_text + assert first.trace_payload["version"] == PROMPT_ASSEMBLY_VERSION_V0 + assert first.trace_payload["compile_trace_id"] == "compile-trace-123" + assert first.trace_payload["included_event_count"] == 1 + assert first.trace_payload["included_memory_count"] == 1 + + +class FakeHTTPResponse: + def __init__(self, body: bytes) -> None: + self.body = body + + def __enter__(self) -> "FakeHTTPResponse": + return self + + def __exit__(self, exc_type, exc, tb) -> None: + return None + + def read(self) -> bytes: + return self.body + + +def test_invoke_model_sends_tools_disabled_request_and_parses_response(monkeypatch) -> None: + captured: dict[str, object] = {} + + def fake_urlopen(request, timeout): + captured["url"] = request.full_url + captured["timeout"] = timeout + captured["headers"] = dict(request.header_items()) + captured["body"] = json.loads(request.data.decode("utf-8")) + return FakeHTTPResponse( + json.dumps( + { + "id": "resp_123", + "status": "completed", + "output": [ + { + "type": "message", + "content": [{"type": "output_text", "text": "Assistant reply"}], + } + ], + "usage": { + "input_tokens": 12, + "output_tokens": 4, + "total_tokens": 16, + }, + } + ).encode("utf-8") + ) + + monkeypatch.setattr("alicebot_api.response_generation.urlopen", fake_urlopen) + + prompt = assemble_prompt( + request=PromptAssemblyInput( + context_pack=make_context_pack(), + system_instruction="System instruction", + developer_instruction="Developer instruction", + ), + compile_trace_id="compile-trace-123", + ) + response = invoke_model( + settings=Settings( + model_provider="openai_responses", + model_base_url="https://example.test/v1", + model_name="gpt-5-mini", + model_api_key="secret-key", + model_timeout_seconds=17, + ), + request=ModelInvocationRequest( + provider="openai_responses", + model="gpt-5-mini", + prompt=prompt, + ), + ) + + assert captured["url"] == "https://example.test/v1/responses" + assert captured["timeout"] == 17 + assert captured["headers"]["Authorization"] == "Bearer secret-key" + assert captured["body"]["tool_choice"] == "none" + assert captured["body"]["tools"] == [] + assert captured["body"]["store"] is False + assert [item["role"] for item in captured["body"]["input"]] == [ + "system", + "developer", + "user", + "user", + ] + assert response == ModelInvocationResponse( + provider="openai_responses", + model="gpt-5-mini", + response_id="resp_123", + finish_reason="completed", + output_text="Assistant reply", + usage={"input_tokens": 12, "output_tokens": 4, "total_tokens": 16}, + ) + + +def test_build_assistant_response_payload_captures_model_and_prompt_metadata() -> None: + prompt = assemble_prompt( + request=PromptAssemblyInput( + context_pack=make_context_pack(), + system_instruction="System instruction", + developer_instruction="Developer instruction", + ), + compile_trace_id="compile-trace-123", + ) + payload = build_assistant_response_payload( + prompt=prompt, + model_response=ModelInvocationResponse( + provider="openai_responses", + model="gpt-5-mini", + response_id="resp_123", + finish_reason="completed", + 
output_text="Assistant reply", + usage={"input_tokens": 12, "output_tokens": 4, "total_tokens": 16}, + ), + ) + + assert payload == { + "text": "Assistant reply", + "model": { + "provider": "openai_responses", + "model": "gpt-5-mini", + "response_id": "resp_123", + "finish_reason": "completed", + "usage": {"input_tokens": 12, "output_tokens": 4, "total_tokens": 16}, + }, + "prompt": { + "assembly_version": "prompt_assembly_v0", + "prompt_sha256": prompt.prompt_sha256, + "section_order": ["system", "developer", "context", "conversation"], + }, + } diff --git a/tests/unit/test_semantic_retrieval.py b/tests/unit/test_semantic_retrieval.py new file mode 100644 index 0000000..780b4e4 --- /dev/null +++ b/tests/unit/test_semantic_retrieval.py @@ -0,0 +1,176 @@ +from __future__ import annotations + +from datetime import UTC, datetime, timedelta +from uuid import UUID, uuid4 + +import pytest + +from alicebot_api.contracts import SemanticMemoryRetrievalRequestInput +from alicebot_api.semantic_retrieval import ( + SemanticMemoryRetrievalValidationError, + retrieve_semantic_memory_records, +) + + +class SemanticRetrievalStoreStub: + def __init__(self) -> None: + self.base_time = datetime(2026, 3, 12, 9, 0, tzinfo=UTC) + self.config_by_id: dict[UUID, dict[str, object]] = {} + self.retrieval_rows: list[dict[str, object]] = [] + self.last_query: dict[str, object] | None = None + + def get_embedding_config_optional(self, embedding_config_id: UUID) -> dict[str, object] | None: + return self.config_by_id.get(embedding_config_id) + + def retrieve_semantic_memory_matches( + self, + *, + embedding_config_id: UUID, + query_vector: list[float], + limit: int, + ) -> list[dict[str, object]]: + self.last_query = { + "embedding_config_id": embedding_config_id, + "query_vector": query_vector, + "limit": limit, + } + return list(self.retrieval_rows[:limit]) + + +def seed_config(store: SemanticRetrievalStoreStub, *, dimensions: int = 3) -> UUID: + config_id = uuid4() + store.config_by_id[config_id] = { + "id": config_id, + "dimensions": dimensions, + } + return config_id + + +def active_row( + store: SemanticRetrievalStoreStub, + *, + memory_key: str, + score: float, + minute_offset: int, +) -> dict[str, object]: + return { + "id": uuid4(), + "user_id": uuid4(), + "memory_key": memory_key, + "value": {"memory_key": memory_key}, + "status": "active", + "source_event_ids": [str(uuid4())], + "created_at": store.base_time + timedelta(minutes=minute_offset), + "updated_at": store.base_time + timedelta(minutes=minute_offset + 1), + "deleted_at": None, + "score": score, + } + + +def test_retrieve_semantic_memory_records_returns_stable_shape_and_summary() -> None: + store = SemanticRetrievalStoreStub() + config_id = seed_config(store, dimensions=3) + first_row = active_row(store, memory_key="user.preference.coffee", score=1.0, minute_offset=0) + second_row = active_row(store, memory_key="user.preference.tea", score=0.75, minute_offset=1) + store.retrieval_rows = [first_row, second_row] + + payload = retrieve_semantic_memory_records( + store, # type: ignore[arg-type] + user_id=uuid4(), + request=SemanticMemoryRetrievalRequestInput( + embedding_config_id=config_id, + query_vector=(0.1, 0.2, 0.3), + limit=2, + ), + ) + + assert payload == { + "items": [ + { + "memory_id": str(first_row["id"]), + "memory_key": "user.preference.coffee", + "value": {"memory_key": "user.preference.coffee"}, + "source_event_ids": first_row["source_event_ids"], + "created_at": first_row["created_at"].isoformat(), + "updated_at": 
first_row["updated_at"].isoformat(), + "score": 1.0, + }, + { + "memory_id": str(second_row["id"]), + "memory_key": "user.preference.tea", + "value": {"memory_key": "user.preference.tea"}, + "source_event_ids": second_row["source_event_ids"], + "created_at": second_row["created_at"].isoformat(), + "updated_at": second_row["updated_at"].isoformat(), + "score": 0.75, + }, + ], + "summary": { + "embedding_config_id": str(config_id), + "limit": 2, + "returned_count": 2, + "similarity_metric": "cosine_similarity", + "order": ["score_desc", "created_at_asc", "id_asc"], + }, + } + assert store.last_query == { + "embedding_config_id": config_id, + "query_vector": [0.1, 0.2, 0.3], + "limit": 2, + } + + +def test_retrieve_semantic_memory_records_rejects_missing_config() -> None: + store = SemanticRetrievalStoreStub() + + with pytest.raises( + SemanticMemoryRetrievalValidationError, + match="embedding_config_id must reference an existing embedding config owned by the user", + ): + retrieve_semantic_memory_records( + store, # type: ignore[arg-type] + user_id=uuid4(), + request=SemanticMemoryRetrievalRequestInput( + embedding_config_id=uuid4(), + query_vector=(0.1, 0.2, 0.3), + ), + ) + + +def test_retrieve_semantic_memory_records_rejects_dimension_mismatch() -> None: + store = SemanticRetrievalStoreStub() + config_id = seed_config(store, dimensions=3) + + with pytest.raises( + SemanticMemoryRetrievalValidationError, + match="query_vector length must match embedding config dimensions \\(3\\): 2", + ): + retrieve_semantic_memory_records( + store, # type: ignore[arg-type] + user_id=uuid4(), + request=SemanticMemoryRetrievalRequestInput( + embedding_config_id=config_id, + query_vector=(0.1, 0.2), + ), + ) + + +def test_retrieve_semantic_memory_records_rejects_non_active_memory_rows() -> None: + store = SemanticRetrievalStoreStub() + config_id = seed_config(store, dimensions=3) + invalid_row = active_row(store, memory_key="user.preference.music", score=0.5, minute_offset=0) + invalid_row["status"] = "deleted" + store.retrieval_rows = [invalid_row] + + with pytest.raises( + SemanticMemoryRetrievalValidationError, + match="semantic retrieval only supports active memories", + ): + retrieve_semantic_memory_records( + store, # type: ignore[arg-type] + user_id=uuid4(), + request=SemanticMemoryRetrievalRequestInput( + embedding_config_id=config_id, + query_vector=(0.1, 0.2, 0.3), + ), + ) diff --git a/tests/unit/test_store.py b/tests/unit/test_store.py new file mode 100644 index 0000000..15c1f1e --- /dev/null +++ b/tests/unit/test_store.py @@ -0,0 +1,157 @@ +from __future__ import annotations + +from typing import Any +from uuid import uuid4 + +from psycopg.types.json import Jsonb +import pytest + +from alicebot_api.store import ContinuityStore, ContinuityStoreInvariantError + + +class RecordingCursor: + def __init__(self, fetchone_results: list[dict[str, Any]], fetchall_result: list[dict[str, Any]] | None = None) -> None: + self.executed: list[tuple[str, tuple[object, ...] | None]] = [] + self.fetchone_results = list(fetchone_results) + self.fetchall_result = fetchall_result or [] + + def __enter__(self) -> "RecordingCursor": + return self + + def __exit__(self, exc_type, exc, tb) -> None: + return None + + def execute(self, query: str, params: tuple[object, ...] 
| None = None) -> None: + self.executed.append((query, params)) + + def fetchone(self) -> dict[str, Any] | None: + if not self.fetchone_results: + return None + return self.fetchone_results.pop(0) + + def fetchall(self) -> list[dict[str, Any]]: + return self.fetchall_result + + +class RecordingConnection: + def __init__(self, cursor: RecordingCursor) -> None: + self.cursor_instance = cursor + + def cursor(self) -> RecordingCursor: + return self.cursor_instance + + +def test_create_methods_return_cursor_rows_and_use_expected_parameters() -> None: + user_id = uuid4() + thread_id = uuid4() + cursor = RecordingCursor( + fetchone_results=[ + {"id": user_id, "email": "owner@example.com", "display_name": "Owner"}, + {"id": thread_id, "title": "Starter thread"}, + {"id": uuid4(), "thread_id": thread_id, "status": "active"}, + ] + ) + store = ContinuityStore(RecordingConnection(cursor)) + + user = store.create_user(user_id, "owner@example.com", "Owner") + thread = store.create_thread("Starter thread") + session = store.create_session(thread_id) + + assert user["id"] == user_id + assert thread["id"] == thread_id + assert session["thread_id"] == thread_id + assert cursor.executed == [ + ( + """ + INSERT INTO users (id, email, display_name) + VALUES (%s, %s, %s) + RETURNING id, email, display_name, created_at + """, + (user_id, "owner@example.com", "Owner"), + ), + ( + """ + INSERT INTO threads (user_id, title) + VALUES (app.current_user_id(), %s) + RETURNING id, user_id, title, created_at, updated_at + """, + ("Starter thread",), + ), + ( + """ + INSERT INTO sessions (user_id, thread_id, status) + VALUES (app.current_user_id(), %s, %s) + RETURNING id, user_id, thread_id, status, started_at, ended_at, created_at + """, + (thread_id, "active"), + ), + ] + + +def test_append_event_locks_thread_and_serializes_payload() -> None: + thread_id = uuid4() + session_id = uuid4() + payload = {"text": "hello"} + cursor = RecordingCursor( + fetchone_results=[ + { + "id": uuid4(), + "thread_id": thread_id, + "session_id": session_id, + "sequence_no": 1, + "kind": "message.user", + "payload": payload, + } + ] + ) + store = ContinuityStore(RecordingConnection(cursor)) + + event = store.append_event(thread_id, session_id, "message.user", payload) + + assert event["sequence_no"] == 1 + assert cursor.executed[0] == ( + "SELECT pg_advisory_xact_lock(hashtextextended(%s::text, 0))", + (str(thread_id),), + ) + insert_query, insert_params = cursor.executed[1] + assert "WITH next_sequence AS" in insert_query + assert insert_params is not None + assert insert_params[:4] == (thread_id, thread_id, session_id, "message.user") + assert isinstance(insert_params[4], Jsonb) + assert insert_params[4].obj == payload + + +def test_list_thread_events_returns_all_rows_in_order() -> None: + thread_id = uuid4() + events = [ + {"sequence_no": 1, "kind": "message.user"}, + {"sequence_no": 2, "kind": "message.assistant"}, + ] + cursor = RecordingCursor(fetchone_results=[], fetchall_result=events) + store = ContinuityStore(RecordingConnection(cursor)) + + result = store.list_thread_events(thread_id) + + assert result == events + assert cursor.executed == [ + ( + """ + SELECT id, user_id, thread_id, session_id, sequence_no, kind, payload, created_at + FROM events + WHERE thread_id = %s + ORDER BY sequence_no ASC + """, + (thread_id,), + ), + ] + + +def test_create_user_raises_clear_error_when_returning_row_is_missing() -> None: + cursor = RecordingCursor(fetchone_results=[]) + store = ContinuityStore(RecordingConnection(cursor)) + + with 
pytest.raises( + ContinuityStoreInvariantError, + match="create_user did not return a row", + ): + store.create_user(uuid4(), "owner@example.com") diff --git a/tests/unit/test_task_step_store.py b/tests/unit/test_task_step_store.py new file mode 100644 index 0000000..af764b4 --- /dev/null +++ b/tests/unit/test_task_step_store.py @@ -0,0 +1,293 @@ +from __future__ import annotations + +from typing import Any +from uuid import uuid4 + +from psycopg.types.json import Jsonb + +from alicebot_api.store import ContinuityStore + + +class RecordingCursor: + def __init__(self, fetchone_results: list[dict[str, Any]], fetchall_result: list[dict[str, Any]] | None = None) -> None: + self.executed: list[tuple[str, tuple[object, ...] | None]] = [] + self.fetchone_results = list(fetchone_results) + self.fetchall_result = fetchall_result or [] + + def __enter__(self) -> "RecordingCursor": + return self + + def __exit__(self, exc_type, exc, tb) -> None: + return None + + def execute(self, query: str, params: tuple[object, ...] | None = None) -> None: + self.executed.append((query, params)) + + def fetchone(self) -> dict[str, Any] | None: + if not self.fetchone_results: + return None + return self.fetchone_results.pop(0) + + def fetchall(self) -> list[dict[str, Any]]: + return self.fetchall_result + + +class RecordingConnection: + def __init__(self, cursor: RecordingCursor) -> None: + self.cursor_instance = cursor + + def cursor(self) -> RecordingCursor: + return self.cursor_instance + + +def test_task_step_store_methods_use_expected_queries_and_jsonb_parameters() -> None: + task_step_id = uuid4() + task_id = uuid4() + thread_id = uuid4() + tool_id = uuid4() + trace_id = uuid4() + cursor = RecordingCursor( + fetchone_results=[ + { + "id": task_step_id, + "user_id": uuid4(), + "task_id": task_id, + "sequence_no": 1, + "parent_step_id": None, + "source_approval_id": None, + "source_execution_id": None, + "kind": "governed_request", + "status": "created", + "request": {"thread_id": str(uuid4()), "tool_id": str(uuid4())}, + "outcome": { + "routing_decision": "approval_required", + "approval_id": None, + "approval_status": None, + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + }, + "trace_id": trace_id, + "trace_kind": "approval.request", + }, + { + "id": task_step_id, + "user_id": uuid4(), + "task_id": task_id, + "sequence_no": 1, + "parent_step_id": None, + "source_approval_id": None, + "source_execution_id": None, + "kind": "governed_request", + "status": "created", + "request": {"thread_id": str(uuid4()), "tool_id": str(uuid4())}, + "outcome": { + "routing_decision": "approval_required", + "approval_id": None, + "approval_status": None, + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + }, + "trace_id": trace_id, + "trace_kind": "approval.request", + }, + { + "id": task_step_id, + "user_id": uuid4(), + "task_id": task_id, + "sequence_no": 1, + "parent_step_id": None, + "source_approval_id": None, + "source_execution_id": None, + "kind": "governed_request", + "status": "approved", + "request": {"thread_id": str(uuid4()), "tool_id": str(uuid4())}, + "outcome": { + "routing_decision": "approval_required", + "approval_id": str(uuid4()), + "approval_status": "approved", + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + }, + "trace_id": trace_id, + "trace_kind": "approval.resolve", + }, + { + "id": task_step_id, + "user_id": uuid4(), + "task_id": task_id, + "sequence_no": 1, + "parent_step_id": None, + "source_approval_id": 
None, + "source_execution_id": None, + "kind": "governed_request", + "status": "approved", + "request": {"thread_id": str(uuid4()), "tool_id": str(uuid4())}, + "outcome": { + "routing_decision": "approval_required", + "approval_id": str(uuid4()), + "approval_status": "approved", + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + }, + "trace_id": trace_id, + "trace_kind": "approval.resolve", + }, + { + "id": task_id, + "user_id": uuid4(), + "thread_id": thread_id, + "tool_id": tool_id, + "status": "approved", + "request": {"thread_id": str(uuid4()), "tool_id": str(uuid4())}, + "tool": {"id": str(tool_id), "tool_key": "proxy.echo"}, + "latest_approval_id": None, + "latest_execution_id": None, + "created_at": "2026-03-13T10:00:00+00:00", + "updated_at": "2026-03-13T10:05:00+00:00", + }, + ], + fetchall_result=[ + { + "id": task_step_id, + "user_id": uuid4(), + "task_id": task_id, + "sequence_no": 1, + "parent_step_id": None, + "source_approval_id": None, + "source_execution_id": None, + "kind": "governed_request", + "status": "created", + "request": {"thread_id": str(uuid4()), "tool_id": str(uuid4())}, + "outcome": { + "routing_decision": "approval_required", + "approval_id": None, + "approval_status": None, + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + }, + "trace_id": trace_id, + "trace_kind": "approval.request", + } + ], + ) + store = ContinuityStore(RecordingConnection(cursor)) + + created = store.create_task_step( + task_id=task_id, + sequence_no=1, + kind="governed_request", + status="created", + request={"thread_id": "thread-123", "tool_id": "tool-123"}, + outcome={ + "routing_decision": "approval_required", + "approval_id": None, + "approval_status": None, + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + }, + trace_id=trace_id, + trace_kind="approval.request", + ) + fetched = store.get_task_step_optional(task_step_id) + listed = store.list_task_steps_for_task(task_id) + updated = store.update_task_step_for_task_sequence_optional( + task_id=task_id, + sequence_no=1, + status="approved", + outcome={ + "routing_decision": "approval_required", + "approval_id": "approval-123", + "approval_status": "approved", + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + }, + trace_id=trace_id, + trace_kind="approval.resolve", + ) + updated_by_id = store.update_task_step_optional( + task_step_id=task_step_id, + status="approved", + outcome={ + "routing_decision": "approval_required", + "approval_id": "approval-123", + "approval_status": "approved", + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + }, + trace_id=trace_id, + trace_kind="approval.resolve", + ) + updated_task = store.update_task_status_optional( + task_id=task_id, + status="approved", + latest_approval_id=None, + latest_execution_id=None, + ) + + assert created["id"] == task_step_id + assert fetched is not None + assert listed[0]["id"] == task_step_id + assert updated is not None + assert updated["status"] == "approved" + assert updated_by_id is not None + assert updated_by_id["status"] == "approved" + assert updated_task is not None + assert updated_task["status"] == "approved" + + lock_query, lock_params = cursor.executed[0] + assert "pg_advisory_xact_lock" in lock_query + assert lock_params == (str(task_id),) + + create_query, create_params = cursor.executed[1] + assert "INSERT INTO task_steps" in create_query + assert create_params is not None + assert create_params[:7] == (task_id, 
1, None, None, None, "governed_request", "created") + assert isinstance(create_params[7], Jsonb) + assert create_params[7].obj == {"thread_id": "thread-123", "tool_id": "tool-123"} + assert isinstance(create_params[8], Jsonb) + assert create_params[8].obj == { + "routing_decision": "approval_required", + "approval_id": None, + "approval_status": None, + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + } + assert create_params[9] == trace_id + assert create_params[10] == "approval.request" + assert "FROM task_steps" in cursor.executed[2][0] + assert "ORDER BY sequence_no ASC, created_at ASC, id ASC" in cursor.executed[3][0] + + update_query, update_params = cursor.executed[4] + assert "UPDATE task_steps" in update_query + assert "WHERE task_id = %s" in update_query + assert update_params is not None + assert update_params[0] == "approved" + assert isinstance(update_params[1], Jsonb) + assert update_params[1].obj["approval_status"] == "approved" + assert update_params[2] == trace_id + assert update_params[3] == "approval.resolve" + assert update_params[4:] == (task_id, 1) + + update_by_id_query, update_by_id_params = cursor.executed[5] + assert "UPDATE task_steps" in update_by_id_query + assert "WHERE id = %s" in update_by_id_query + assert update_by_id_params is not None + assert update_by_id_params[0] == "approved" + assert isinstance(update_by_id_params[1], Jsonb) + assert update_by_id_params[1].obj["approval_status"] == "approved" + assert update_by_id_params[2] == trace_id + assert update_by_id_params[3] == "approval.resolve" + assert update_by_id_params[4] == task_step_id + + task_update_query, task_update_params = cursor.executed[6] + assert "UPDATE tasks" in task_update_query + assert task_update_params == ("approved", None, None, task_id) diff --git a/tests/unit/test_task_workspace_store.py b/tests/unit/test_task_workspace_store.py new file mode 100644 index 0000000..16bd0ae --- /dev/null +++ b/tests/unit/test_task_workspace_store.py @@ -0,0 +1,184 @@ +from __future__ import annotations + +from typing import Any +from uuid import uuid4 + +from alicebot_api.store import ContinuityStore + + +class RecordingCursor: + def __init__(self, fetchone_results: list[dict[str, Any]], fetchall_result: list[dict[str, Any]] | None = None) -> None: + self.executed: list[tuple[str, tuple[object, ...] | None]] = [] + self.fetchone_results = list(fetchone_results) + self.fetchall_result = fetchall_result or [] + + def __enter__(self) -> "RecordingCursor": + return self + + def __exit__(self, exc_type, exc, tb) -> None: + return None + + def execute(self, query: str, params: tuple[object, ...] 
| None = None) -> None: + self.executed.append((query, params)) + + def fetchone(self) -> dict[str, Any] | None: + if not self.fetchone_results: + return None + return self.fetchone_results.pop(0) + + def fetchall(self) -> list[dict[str, Any]]: + return self.fetchall_result + + +class RecordingConnection: + def __init__(self, cursor: RecordingCursor) -> None: + self.cursor_instance = cursor + + def cursor(self) -> RecordingCursor: + return self.cursor_instance + + +def test_task_workspace_store_methods_use_expected_queries() -> None: + task_workspace_id = uuid4() + task_id = uuid4() + cursor = RecordingCursor( + fetchone_results=[ + { + "id": task_workspace_id, + "user_id": uuid4(), + "task_id": task_id, + "status": "active", + "local_path": "/tmp/alicebot/task-workspaces/user/task", + "created_at": "2026-03-13T10:00:00+00:00", + "updated_at": "2026-03-13T10:00:00+00:00", + }, + { + "id": task_workspace_id, + "user_id": uuid4(), + "task_id": task_id, + "status": "active", + "local_path": "/tmp/alicebot/task-workspaces/user/task", + "created_at": "2026-03-13T10:00:00+00:00", + "updated_at": "2026-03-13T10:00:00+00:00", + }, + { + "id": task_workspace_id, + "user_id": uuid4(), + "task_id": task_id, + "status": "active", + "local_path": "/tmp/alicebot/task-workspaces/user/task", + "created_at": "2026-03-13T10:00:00+00:00", + "updated_at": "2026-03-13T10:00:00+00:00", + }, + ], + fetchall_result=[ + { + "id": task_workspace_id, + "user_id": uuid4(), + "task_id": task_id, + "status": "active", + "local_path": "/tmp/alicebot/task-workspaces/user/task", + "created_at": "2026-03-13T10:00:00+00:00", + "updated_at": "2026-03-13T10:00:00+00:00", + } + ], + ) + store = ContinuityStore(RecordingConnection(cursor)) + + created = store.create_task_workspace( + task_id=task_id, + status="active", + local_path="/tmp/alicebot/task-workspaces/user/task", + ) + fetched = store.get_task_workspace_optional(task_workspace_id) + active = store.get_active_task_workspace_for_task_optional(task_id) + listed = store.list_task_workspaces() + store.lock_task_workspaces(task_id) + + assert created["id"] == task_workspace_id + assert fetched is not None + assert active is not None + assert listed[0]["id"] == task_workspace_id + assert cursor.executed == [ + ( + """ + INSERT INTO task_workspaces ( + user_id, + task_id, + status, + local_path, + created_at, + updated_at + ) + VALUES ( + app.current_user_id(), + %s, + %s, + %s, + clock_timestamp(), + clock_timestamp() + ) + RETURNING + id, + user_id, + task_id, + status, + local_path, + created_at, + updated_at + """, + (task_id, "active", "/tmp/alicebot/task-workspaces/user/task"), + ), + ( + """ + SELECT + id, + user_id, + task_id, + status, + local_path, + created_at, + updated_at + FROM task_workspaces + WHERE id = %s + """, + (task_workspace_id,), + ), + ( + """ + SELECT + id, + user_id, + task_id, + status, + local_path, + created_at, + updated_at + FROM task_workspaces + WHERE task_id = %s + AND status = 'active' + ORDER BY created_at ASC, id ASC + LIMIT 1 + """, + (task_id,), + ), + ( + """ + SELECT + id, + user_id, + task_id, + status, + local_path, + created_at, + updated_at + FROM task_workspaces + ORDER BY created_at ASC, id ASC + """, + None, + ), + ( + "SELECT pg_advisory_xact_lock(hashtextextended(%s::text, 3))", + (str(task_id),), + ), + ] diff --git a/tests/unit/test_tasks.py b/tests/unit/test_tasks.py new file mode 100644 index 0000000..142f048 --- /dev/null +++ b/tests/unit/test_tasks.py @@ -0,0 +1,1663 @@ +from __future__ import annotations + +from 
datetime import UTC, datetime, timedelta +from uuid import UUID, uuid4 + +from alicebot_api.tasks import ( + TaskNotFoundError, + TaskStepApprovalLinkageError, + TaskStepExecutionLinkageError, + TaskStepNotFoundError, + TaskStepSequenceError, + TaskStepTransitionError, + allowed_task_step_transitions, + create_next_task_step_record, + create_task_step_for_governed_request, + get_task_step_record, + get_task_record, + list_task_records, + list_task_step_records, + sync_task_with_task_step_status, + sync_task_step_with_approval, + sync_task_step_with_execution, + task_status_for_step_status, + next_task_status_for_approval, + task_lifecycle_trace_events, + task_step_lifecycle_trace_events, + task_step_outcome_snapshot, + task_step_status_for_approval_status, + task_step_status_for_execution_status, + task_step_status_for_routing_decision, + task_status_for_approval_status, + task_status_for_execution_status, + task_status_for_routing_decision, + transition_task_step_record, +) +from alicebot_api.contracts import ( + TaskStepCreateInput, + TaskStepLineageInput, + TaskStepNextCreateInput, + TaskStepTransitionInput, +) + + +class TaskStoreStub: + def __init__(self) -> None: + self.base_time = datetime(2026, 3, 13, 10, 0, tzinfo=UTC) + self.user_id = uuid4() + self.tasks: list[dict[str, object]] = [] + self.task_steps: list[dict[str, object]] = [] + self.approvals: list[dict[str, object]] = [] + self.tool_executions: list[dict[str, object]] = [] + self.traces: list[dict[str, object]] = [] + self.trace_events: list[dict[str, object]] = [] + self.locked_task_ids: list[UUID] = [] + + def create_task( + self, + *, + status: str, + latest_approval_id: UUID | None, + latest_execution_id: UUID | None, + ) -> dict[str, object]: + task = { + "id": uuid4(), + "user_id": self.user_id, + "thread_id": uuid4(), + "tool_id": uuid4(), + "status": status, + "request": { + "thread_id": str(uuid4()), + "tool_id": str(uuid4()), + "action": "tool.run", + "scope": "workspace", + "domain_hint": None, + "risk_hint": None, + "attributes": {}, + }, + "tool": { + "id": str(uuid4()), + "tool_key": "proxy.echo", + "name": "Proxy Echo", + "description": "Deterministic proxy handler.", + "version": "1.0.0", + "metadata_version": "tool_metadata_v0", + "active": True, + "tags": ["proxy"], + "action_hints": ["tool.run"], + "scope_hints": ["workspace"], + "domain_hints": [], + "risk_hints": [], + "metadata": {"transport": "proxy"}, + "created_at": self.base_time.isoformat(), + }, + "latest_approval_id": latest_approval_id, + "latest_execution_id": latest_execution_id, + "created_at": self.base_time + timedelta(minutes=len(self.tasks)), + "updated_at": self.base_time + timedelta(minutes=len(self.tasks)), + } + self.tasks.append(task) + return task + + def list_tasks(self) -> list[dict[str, object]]: + return sorted(self.tasks, key=lambda task: (task["created_at"], task["id"])) + + def get_task_optional(self, task_id: UUID) -> dict[str, object] | None: + return next((task for task in self.tasks if task["id"] == task_id), None) + + def get_task_by_approval_optional(self, approval_id: UUID) -> dict[str, object] | None: + return next((task for task in self.tasks if task["latest_approval_id"] == approval_id), None) + + def update_task_status_optional( + self, + *, + task_id: UUID, + status: str, + latest_approval_id: UUID | None, + latest_execution_id: UUID | None, + ) -> dict[str, object] | None: + task = self.get_task_optional(task_id) + if task is None: + return None + task["status"] = status + task["latest_approval_id"] = 
latest_approval_id + task["latest_execution_id"] = latest_execution_id + task["updated_at"] = self.base_time + timedelta(hours=1, minutes=len(self.trace_events)) + return task + + def lock_task_steps(self, task_id: UUID) -> None: + self.locked_task_ids.append(task_id) + + def create_task_step( + self, + *, + task_id: UUID, + sequence_no: int, + parent_step_id: UUID | None = None, + source_approval_id: UUID | None = None, + source_execution_id: UUID | None = None, + kind: str, + status: str, + request: dict[str, object], + outcome: dict[str, object], + trace_id: UUID, + trace_kind: str, + ) -> dict[str, object]: + task_step = { + "id": uuid4(), + "user_id": self.user_id, + "task_id": task_id, + "sequence_no": sequence_no, + "parent_step_id": parent_step_id, + "source_approval_id": source_approval_id, + "source_execution_id": source_execution_id, + "kind": kind, + "status": status, + "request": request, + "outcome": outcome, + "trace_id": trace_id, + "trace_kind": trace_kind, + "created_at": self.base_time + timedelta(minutes=len(self.task_steps)), + "updated_at": self.base_time + timedelta(minutes=len(self.task_steps)), + } + self.task_steps.append(task_step) + return task_step + + def get_task_step_optional(self, task_step_id: UUID) -> dict[str, object] | None: + return next((task_step for task_step in self.task_steps if task_step["id"] == task_step_id), None) + + def get_approval_optional(self, approval_id: UUID) -> dict[str, object] | None: + return next((approval for approval in self.approvals if approval["id"] == approval_id), None) + + def get_tool_execution_optional(self, execution_id: UUID) -> dict[str, object] | None: + return next((execution for execution in self.tool_executions if execution["id"] == execution_id), None) + + def get_task_step_for_task_sequence_optional( + self, + *, + task_id: UUID, + sequence_no: int, + ) -> dict[str, object] | None: + return next( + ( + task_step + for task_step in self.task_steps + if task_step["task_id"] == task_id and task_step["sequence_no"] == sequence_no + ), + None, + ) + + def list_task_steps_for_task(self, task_id: UUID) -> list[dict[str, object]]: + return sorted( + [task_step for task_step in self.task_steps if task_step["task_id"] == task_id], + key=lambda task_step: (task_step["sequence_no"], task_step["created_at"], task_step["id"]), + ) + + def update_task_step_for_task_sequence_optional( + self, + *, + task_id: UUID, + sequence_no: int, + status: str, + outcome: dict[str, object], + trace_id: UUID, + trace_kind: str, + ) -> dict[str, object] | None: + task_step = self.get_task_step_for_task_sequence_optional(task_id=task_id, sequence_no=sequence_no) + if task_step is None: + return None + task_step["status"] = status + task_step["outcome"] = outcome + task_step["trace_id"] = trace_id + task_step["trace_kind"] = trace_kind + task_step["updated_at"] = self.base_time + timedelta(hours=1, minutes=len(self.task_steps)) + return task_step + + def update_task_step_optional( + self, + *, + task_step_id: UUID, + status: str, + outcome: dict[str, object], + trace_id: UUID, + trace_kind: str, + ) -> dict[str, object] | None: + task_step = self.get_task_step_optional(task_step_id) + if task_step is None: + return None + task_step["status"] = status + task_step["outcome"] = outcome + task_step["trace_id"] = trace_id + task_step["trace_kind"] = trace_kind + task_step["updated_at"] = self.base_time + timedelta(hours=1, minutes=len(self.task_steps)) + return task_step + + def create_trace( + self, + *, + user_id: UUID, + thread_id: UUID, + kind: 
str, + compiler_version: str, + status: str, + limits: dict[str, object], + ) -> dict[str, object]: + trace = { + "id": uuid4(), + "user_id": user_id, + "thread_id": thread_id, + "kind": kind, + "compiler_version": compiler_version, + "status": status, + "limits": limits, + "created_at": self.base_time + timedelta(minutes=len(self.traces)), + } + self.traces.append(trace) + return trace + + def append_trace_event( + self, + *, + trace_id: UUID, + sequence_no: int, + kind: str, + payload: dict[str, object], + ) -> dict[str, object]: + event = { + "id": uuid4(), + "trace_id": trace_id, + "sequence_no": sequence_no, + "kind": kind, + "payload": payload, + "created_at": self.base_time + timedelta(minutes=len(self.trace_events)), + } + self.trace_events.append(event) + return event + + +def test_list_and_get_task_records_are_deterministic() -> None: + store = TaskStoreStub() + first = store.create_task( + status="approved", + latest_approval_id=None, + latest_execution_id=None, + ) + second = store.create_task( + status="blocked", + latest_approval_id=uuid4(), + latest_execution_id=uuid4(), + ) + + listed = list_task_records( + store, # type: ignore[arg-type] + user_id=store.user_id, + ) + detail = get_task_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + task_id=second["id"], + ) + + assert [item["id"] for item in listed["items"]] == [str(first["id"]), str(second["id"])] + assert [item["status"] for item in listed["items"]] == ["approved", "blocked"] + assert listed["summary"] == { + "total_count": 2, + "order": ["created_at_asc", "id_asc"], + } + assert detail["task"]["id"] == str(second["id"]) + assert detail["task"]["status"] == "blocked" + assert detail["task"]["latest_approval_id"] == str(second["latest_approval_id"]) + assert detail["task"]["latest_execution_id"] == str(second["latest_execution_id"]) + + +def test_task_lifecycle_helpers_return_deterministic_statuses_and_trace_payloads() -> None: + assert task_status_for_routing_decision("approval_required") == "pending_approval" + assert task_status_for_routing_decision("ready") == "approved" + assert task_status_for_routing_decision("denied") == "denied" + assert task_status_for_approval_status("approved") == "approved" + assert task_status_for_approval_status("rejected") == "denied" + assert next_task_status_for_approval(current_status="pending_approval", approval_status="approved") == "approved" + assert next_task_status_for_approval(current_status="executed", approval_status="approved") == "executed" + assert task_status_for_execution_status("completed") == "executed" + assert task_status_for_execution_status("blocked") == "blocked" + assert task_step_status_for_routing_decision("approval_required") == "created" + assert task_step_status_for_routing_decision("ready") == "approved" + assert task_step_status_for_routing_decision("denied") == "denied" + assert task_step_status_for_approval_status("approved") == "approved" + assert task_step_status_for_approval_status("rejected") == "denied" + assert task_step_status_for_execution_status("completed") == "executed" + assert task_step_status_for_execution_status("blocked") == "blocked" + assert task_status_for_step_status("created") == "pending_approval" + assert task_status_for_step_status("approved") == "approved" + assert task_status_for_step_status("executed") == "executed" + assert allowed_task_step_transitions("created") == ["approved", "denied"] + assert allowed_task_step_transitions("approved") == ["executed", "blocked"] + assert 
allowed_task_step_transitions("executed") == [] + + task = { + "id": str(uuid4()), + "thread_id": str(uuid4()), + "tool_id": str(uuid4()), + "status": "executed", + "request": { + "thread_id": str(uuid4()), + "tool_id": str(uuid4()), + "action": "tool.run", + "scope": "workspace", + "domain_hint": None, + "risk_hint": None, + "attributes": {}, + }, + "tool": { + "id": str(uuid4()), + "tool_key": "proxy.echo", + "name": "Proxy Echo", + "description": "Deterministic proxy handler.", + "version": "1.0.0", + "metadata_version": "tool_metadata_v0", + "active": True, + "tags": ["proxy"], + "action_hints": ["tool.run"], + "scope_hints": ["workspace"], + "domain_hints": [], + "risk_hints": [], + "metadata": {"transport": "proxy"}, + "created_at": "2026-03-13T10:00:00+00:00", + }, + "latest_approval_id": str(uuid4()), + "latest_execution_id": str(uuid4()), + "created_at": "2026-03-13T10:00:00+00:00", + "updated_at": "2026-03-13T10:05:00+00:00", + } + + events = task_lifecycle_trace_events( + task=task, + previous_status="approved", + source="proxy_execution", + ) + + assert events == [ + ( + "task.lifecycle.state", + { + "task_id": task["id"], + "source": "proxy_execution", + "previous_status": "approved", + "current_status": "executed", + "latest_approval_id": task["latest_approval_id"], + "latest_execution_id": task["latest_execution_id"], + }, + ), + ( + "task.lifecycle.summary", + { + "task_id": task["id"], + "source": "proxy_execution", + "final_status": "executed", + "latest_approval_id": task["latest_approval_id"], + "latest_execution_id": task["latest_execution_id"], + }, + ), + ] + + task_step = { + "id": str(uuid4()), + "task_id": task["id"], + "sequence_no": 1, + "lineage": { + "parent_step_id": None, + "source_approval_id": None, + "source_execution_id": None, + }, + "kind": "governed_request", + "status": "executed", + "request": task["request"], + "outcome": task_step_outcome_snapshot( + routing_decision="approval_required", + approval_id=task["latest_approval_id"], + approval_status="approved", + execution_id=task["latest_execution_id"], + execution_status="completed", + blocked_reason=None, + ), + "trace": { + "trace_id": str(uuid4()), + "trace_kind": "tool.proxy.execute", + }, + "created_at": "2026-03-13T10:00:00+00:00", + "updated_at": "2026-03-13T10:05:00+00:00", + } + + task_step_events = task_step_lifecycle_trace_events( + task_step=task_step, + previous_status="approved", + source="proxy_execution", + ) + + assert task_step_events == [ + ( + "task.step.lifecycle.state", + { + "task_id": task["id"], + "task_step_id": task_step["id"], + "source": "proxy_execution", + "sequence_no": 1, + "kind": "governed_request", + "previous_status": "approved", + "current_status": "executed", + "trace": task_step["trace"], + }, + ), + ( + "task.step.lifecycle.summary", + { + "task_id": task["id"], + "task_step_id": task_step["id"], + "source": "proxy_execution", + "sequence_no": 1, + "kind": "governed_request", + "final_status": "executed", + "trace": task_step["trace"], + }, + ), + ] + + +def test_get_task_record_raises_not_found_when_missing() -> None: + store = TaskStoreStub() + + try: + get_task_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + task_id=uuid4(), + ) + except TaskNotFoundError as exc: + assert "task" in str(exc) + else: + raise AssertionError("expected TaskNotFoundError") + + +def test_task_step_list_get_and_lifecycle_updates_are_deterministic() -> None: + store = TaskStoreStub() + task = store.create_task( + status="pending_approval", + 
latest_approval_id=uuid4(), + latest_execution_id=None, + ) + first_trace_id = uuid4() + create_payload = create_task_step_for_governed_request( + store, # type: ignore[arg-type] + request=TaskStepCreateInput( + task_id=task["id"], + sequence_no=1, + kind="governed_request", + status="created", + request=task["request"], + outcome=task_step_outcome_snapshot( + routing_decision="approval_required", + approval_id=str(task["latest_approval_id"]), + approval_status="pending", + execution_id=None, + execution_status=None, + blocked_reason=None, + ), + trace_id=first_trace_id, + trace_kind="approval.request", + ), + ) + second_trace_id = uuid4() + approval_transition = sync_task_step_with_approval( + store, # type: ignore[arg-type] + approval_id=UUID(str(task["latest_approval_id"])), + task_step_id=UUID(create_payload["task_step"]["id"]), + approval_status="approved", + trace_id=second_trace_id, + trace_kind="approval.resolve", + ) + execution = { + "id": uuid4(), + "approval_id": task["latest_approval_id"], + "task_step_id": UUID(create_payload["task_step"]["id"]), + "status": "completed", + "result": { + "handler_key": "proxy.echo", + "status": "completed", + "output": {"mode": "no_side_effect"}, + "reason": None, + }, + } + third_trace_id = uuid4() + execution_transition = sync_task_step_with_execution( + store, # type: ignore[arg-type] + task_id=task["id"], + execution=execution, # type: ignore[arg-type] + trace_id=third_trace_id, + trace_kind="tool.proxy.execute", + ) + store.create_task_step( + task_id=task["id"], + sequence_no=2, + kind="governed_request", + status="denied", + request=task["request"], + outcome=task_step_outcome_snapshot( + routing_decision="denied", + approval_id=None, + approval_status=None, + execution_id=None, + execution_status=None, + blocked_reason=None, + ), + trace_id=uuid4(), + trace_kind="approval.request", + ) + + listed = list_task_step_records( + store, # type: ignore[arg-type] + user_id=store.user_id, + task_id=task["id"], + ) + detail = get_task_step_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + task_step_id=UUID(create_payload["task_step"]["id"]), + ) + + assert [item["sequence_no"] for item in listed["items"]] == [1, 2] + assert listed["summary"] == { + "task_id": str(task["id"]), + "total_count": 2, + "latest_sequence_no": 2, + "latest_status": "denied", + "next_sequence_no": 3, + "append_allowed": True, + "order": ["sequence_no_asc", "created_at_asc", "id_asc"], + } + assert detail["task_step"]["id"] == create_payload["task_step"]["id"] + assert detail["task_step"]["status"] == "executed" + assert detail["task_step"]["trace"] == { + "trace_id": str(third_trace_id), + "trace_kind": "tool.proxy.execute", + } + assert detail["task_step"]["outcome"] == { + "routing_decision": "approval_required", + "approval_id": str(task["latest_approval_id"]), + "approval_status": "approved", + "execution_id": str(execution["id"]), + "execution_status": "completed", + "blocked_reason": None, + } + + +def test_sync_task_step_with_approval_updates_explicitly_linked_later_step_only() -> None: + store = TaskStoreStub() + approval_id = uuid4() + initial_execution_id = uuid4() + task = store.create_task( + status="pending_approval", + latest_approval_id=approval_id, + latest_execution_id=None, + ) + first_step = store.create_task_step( + task_id=task["id"], + sequence_no=1, + kind="governed_request", + status="executed", + request=task["request"], + outcome=task_step_outcome_snapshot( + routing_decision="approval_required", + 
approval_id=str(approval_id), + approval_status="approved", + execution_id=str(initial_execution_id), + execution_status="completed", + blocked_reason=None, + ), + trace_id=uuid4(), + trace_kind="tool.proxy.execute", + ) + later_step = store.create_task_step( + task_id=task["id"], + sequence_no=2, + parent_step_id=first_step["id"], + source_approval_id=approval_id, + source_execution_id=initial_execution_id, + kind="governed_request", + status="created", + request=task["request"], + outcome=task_step_outcome_snapshot( + routing_decision="approval_required", + approval_id=str(approval_id), + approval_status="pending", + execution_id=None, + execution_status=None, + blocked_reason=None, + ), + trace_id=uuid4(), + trace_kind="task.step.continuation", + ) + + original_first_trace_id = first_step["trace_id"] + original_first_trace_kind = first_step["trace_kind"] + original_first_outcome = dict(first_step["outcome"]) + later_trace_id = uuid4() + + transition = sync_task_step_with_approval( + store, # type: ignore[arg-type] + approval_id=approval_id, + task_step_id=later_step["id"], + approval_status="approved", + trace_id=later_trace_id, + trace_kind="approval.resolve", + ) + + assert transition.previous_status == "created" + assert transition.task_step["id"] == str(later_step["id"]) + assert transition.task_step["status"] == "approved" + assert first_step["status"] == "executed" + assert first_step["trace_id"] == original_first_trace_id + assert first_step["trace_kind"] == original_first_trace_kind + assert first_step["outcome"] == original_first_outcome + assert later_step["status"] == "approved" + assert later_step["trace_id"] == later_trace_id + assert later_step["trace_kind"] == "approval.resolve" + assert later_step["outcome"] == { + "routing_decision": "approval_required", + "approval_id": str(approval_id), + "approval_status": "approved", + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + } + assert task["status"] == "pending_approval" + assert task["latest_execution_id"] is None + + +def test_sync_task_step_with_approval_rejects_inconsistent_linkage_without_mutating_steps() -> None: + store = TaskStoreStub() + approval_id = uuid4() + initial_execution_id = uuid4() + task = store.create_task( + status="pending_approval", + latest_approval_id=approval_id, + latest_execution_id=None, + ) + first_step = store.create_task_step( + task_id=task["id"], + sequence_no=1, + kind="governed_request", + status="executed", + request=task["request"], + outcome=task_step_outcome_snapshot( + routing_decision="approval_required", + approval_id=str(approval_id), + approval_status="approved", + execution_id=str(initial_execution_id), + execution_status="completed", + blocked_reason=None, + ), + trace_id=uuid4(), + trace_kind="tool.proxy.execute", + ) + later_step = store.create_task_step( + task_id=task["id"], + sequence_no=2, + parent_step_id=first_step["id"], + source_approval_id=approval_id, + source_execution_id=initial_execution_id, + kind="governed_request", + status="created", + request=task["request"], + outcome=task_step_outcome_snapshot( + routing_decision="approval_required", + approval_id=None, + approval_status=None, + execution_id=None, + execution_status=None, + blocked_reason=None, + ), + trace_id=uuid4(), + trace_kind="task.step.continuation", + ) + + original_first_outcome = dict(first_step["outcome"]) + original_later_trace_id = later_step["trace_id"] + + try: + sync_task_step_with_approval( + store, # type: ignore[arg-type] + approval_id=approval_id, + 
task_step_id=later_step["id"], + approval_status="approved", + trace_id=uuid4(), + trace_kind="approval.resolve", + ) + except TaskStepApprovalLinkageError as exc: + assert str(exc) == ( + f"approval {approval_id} is inconsistent with linked task step {later_step['id']}" + ) + else: + raise AssertionError("expected TaskStepApprovalLinkageError") + + assert first_step["outcome"] == original_first_outcome + assert later_step["status"] == "created" + assert later_step["trace_id"] == original_later_trace_id + assert later_step["trace_kind"] == "task.step.continuation" + + +def test_sync_task_step_with_execution_updates_the_linked_later_step_without_mutating_initial_step() -> None: + store = TaskStoreStub() + approval_id = uuid4() + initial_execution_id = uuid4() + task = store.create_task( + status="approved", + latest_approval_id=approval_id, + latest_execution_id=None, + ) + first_step = store.create_task_step( + task_id=task["id"], + sequence_no=1, + kind="governed_request", + status="executed", + request=task["request"], + outcome=task_step_outcome_snapshot( + routing_decision="approval_required", + approval_id=str(approval_id), + approval_status="approved", + execution_id=str(initial_execution_id), + execution_status="completed", + blocked_reason=None, + ), + trace_id=uuid4(), + trace_kind="tool.proxy.execute", + ) + later_step = store.create_task_step( + task_id=task["id"], + sequence_no=2, + parent_step_id=first_step["id"], + source_approval_id=approval_id, + source_execution_id=initial_execution_id, + kind="governed_request", + status="approved", + request=task["request"], + outcome=task_step_outcome_snapshot( + routing_decision="approval_required", + approval_id=str(approval_id), + approval_status="approved", + execution_id=None, + execution_status=None, + blocked_reason=None, + ), + trace_id=uuid4(), + trace_kind="task.step.transition", + ) + + original_first_trace_id = first_step["trace_id"] + original_first_trace_kind = first_step["trace_kind"] + original_first_outcome = dict(first_step["outcome"]) + execution = { + "id": uuid4(), + "approval_id": approval_id, + "task_step_id": later_step["id"], + "status": "completed", + "result": { + "handler_key": "proxy.echo", + "status": "completed", + "output": {"mode": "no_side_effect"}, + "reason": None, + }, + } + + transition = sync_task_step_with_execution( + store, # type: ignore[arg-type] + task_id=task["id"], + execution=execution, # type: ignore[arg-type] + trace_id=uuid4(), + trace_kind="tool.proxy.execute", + ) + + assert transition.previous_status == "approved" + assert transition.task_step["id"] == str(later_step["id"]) + assert transition.task_step["status"] == "executed" + assert first_step["status"] == "executed" + assert first_step["trace_id"] == original_first_trace_id + assert first_step["trace_kind"] == original_first_trace_kind + assert first_step["outcome"] == original_first_outcome + assert later_step["status"] == "executed" + assert later_step["trace_kind"] == "tool.proxy.execute" + assert later_step["outcome"] == { + "routing_decision": "approval_required", + "approval_id": str(approval_id), + "approval_status": "approved", + "execution_id": str(execution["id"]), + "execution_status": "completed", + "blocked_reason": None, + } + assert task["status"] == "approved" + assert task["latest_execution_id"] is None + + +def test_sync_task_step_with_execution_rejects_missing_linkage_without_mutating_steps() -> None: + store = TaskStoreStub() + approval_id = uuid4() + task = store.create_task( + status="approved", + 
latest_approval_id=approval_id, + latest_execution_id=None, + ) + first_step = store.create_task_step( + task_id=task["id"], + sequence_no=1, + kind="governed_request", + status="approved", + request=task["request"], + outcome=task_step_outcome_snapshot( + routing_decision="approval_required", + approval_id=str(approval_id), + approval_status="approved", + execution_id=None, + execution_status=None, + blocked_reason=None, + ), + trace_id=uuid4(), + trace_kind="approval.resolve", + ) + execution_id = uuid4() + + try: + sync_task_step_with_execution( + store, # type: ignore[arg-type] + task_id=task["id"], + execution={ + "id": execution_id, + "approval_id": approval_id, + "task_step_id": None, + "status": "completed", + "result": { + "handler_key": "proxy.echo", + "status": "completed", + "output": {"mode": "no_side_effect"}, + "reason": None, + }, + }, # type: ignore[arg-type] + trace_id=uuid4(), + trace_kind="tool.proxy.execute", + ) + except TaskStepExecutionLinkageError as exc: + assert str(exc) == f"tool execution {execution_id} is missing linked task_step_id" + else: + raise AssertionError("expected TaskStepExecutionLinkageError") + + assert first_step["status"] == "approved" + assert first_step["outcome"]["execution_id"] is None + + +def test_sync_task_step_with_execution_rejects_unknown_or_out_of_task_linkage() -> None: + store = TaskStoreStub() + approval_id = uuid4() + task = store.create_task( + status="approved", + latest_approval_id=approval_id, + latest_execution_id=None, + ) + store.create_task_step( + task_id=task["id"], + sequence_no=1, + kind="governed_request", + status="approved", + request=task["request"], + outcome=task_step_outcome_snapshot( + routing_decision="approval_required", + approval_id=str(approval_id), + approval_status="approved", + execution_id=None, + execution_status=None, + blocked_reason=None, + ), + trace_id=uuid4(), + trace_kind="approval.resolve", + ) + other_task = store.create_task( + status="approved", + latest_approval_id=approval_id, + latest_execution_id=None, + ) + other_step = store.create_task_step( + task_id=other_task["id"], + sequence_no=1, + kind="governed_request", + status="approved", + request=other_task["request"], + outcome=task_step_outcome_snapshot( + routing_decision="approval_required", + approval_id=str(approval_id), + approval_status="approved", + execution_id=None, + execution_status=None, + blocked_reason=None, + ), + trace_id=uuid4(), + trace_kind="approval.resolve", + ) + + missing_execution_id = uuid4() + missing_task_step_id = uuid4() + try: + sync_task_step_with_execution( + store, # type: ignore[arg-type] + task_id=task["id"], + execution={ + "id": missing_execution_id, + "approval_id": approval_id, + "task_step_id": missing_task_step_id, + "status": "completed", + "result": { + "handler_key": "proxy.echo", + "status": "completed", + "output": {"mode": "no_side_effect"}, + "reason": None, + }, + }, # type: ignore[arg-type] + trace_id=uuid4(), + trace_kind="tool.proxy.execute", + ) + except TaskStepExecutionLinkageError as exc: + assert str(exc) == ( + f"tool execution {missing_execution_id} references linked task step " + f"{missing_task_step_id} that was not found" + ) + else: + raise AssertionError("expected TaskStepExecutionLinkageError") + + outside_execution_id = uuid4() + try: + sync_task_step_with_execution( + store, # type: ignore[arg-type] + task_id=task["id"], + execution={ + "id": outside_execution_id, + "approval_id": approval_id, + "task_step_id": other_step["id"], + "status": "completed", + "result": { + 
"handler_key": "proxy.echo", + "status": "completed", + "output": {"mode": "no_side_effect"}, + "reason": None, + }, + }, # type: ignore[arg-type] + trace_id=uuid4(), + trace_kind="tool.proxy.execute", + ) + except TaskStepExecutionLinkageError as exc: + assert str(exc) == ( + f"tool execution {outside_execution_id} links task step {other_step['id']} " + f"outside task {task['id']}" + ) + else: + raise AssertionError("expected TaskStepExecutionLinkageError") + + +def test_sync_task_step_with_execution_rejects_inconsistent_linkage_without_mutating_steps() -> None: + store = TaskStoreStub() + approval_id = uuid4() + task = store.create_task( + status="approved", + latest_approval_id=approval_id, + latest_execution_id=None, + ) + step = store.create_task_step( + task_id=task["id"], + sequence_no=1, + kind="governed_request", + status="approved", + request=task["request"], + outcome=task_step_outcome_snapshot( + routing_decision="approval_required", + approval_id=str(approval_id), + approval_status="approved", + execution_id=None, + execution_status=None, + blocked_reason=None, + ), + trace_id=uuid4(), + trace_kind="approval.resolve", + ) + inconsistent_execution_id = uuid4() + + try: + sync_task_step_with_execution( + store, # type: ignore[arg-type] + task_id=task["id"], + execution={ + "id": inconsistent_execution_id, + "approval_id": uuid4(), + "task_step_id": step["id"], + "status": "completed", + "result": { + "handler_key": "proxy.echo", + "status": "completed", + "output": {"mode": "no_side_effect"}, + "reason": None, + }, + }, # type: ignore[arg-type] + trace_id=uuid4(), + trace_kind="tool.proxy.execute", + ) + except TaskStepExecutionLinkageError as exc: + assert str(exc) == ( + f"tool execution {inconsistent_execution_id} is inconsistent with linked task step {step['id']}" + ) + else: + raise AssertionError("expected TaskStepExecutionLinkageError") + + assert step["status"] == "approved" + assert step["outcome"]["execution_id"] is None + + +def test_sync_task_with_task_step_status_updates_parent_through_task_seam() -> None: + store = TaskStoreStub() + task = store.create_task( + status="executed", + latest_approval_id=uuid4(), + latest_execution_id=uuid4(), + ) + + transition = sync_task_with_task_step_status( + store, # type: ignore[arg-type] + task_id=task["id"], + task_step_status="created", + linked_approval_id=task["latest_approval_id"], + linked_execution_id=None, + ) + + assert transition.previous_status == "executed" + assert transition.task["status"] == "pending_approval" + assert transition.task["latest_execution_id"] is None + assert store.tasks[0]["status"] == "pending_approval" + assert store.tasks[0]["latest_execution_id"] is None + + +def test_create_next_task_step_assigns_deterministic_sequence_updates_parent_and_records_trace() -> None: + store = TaskStoreStub() + approval_id = uuid4() + initial_execution_id = uuid4() + task = store.create_task( + status="executed", + latest_approval_id=approval_id, + latest_execution_id=initial_execution_id, + ) + store.approvals.append({"id": approval_id, "thread_id": task["thread_id"], "tool_id": task["tool_id"]}) + store.tool_executions.append( + { + "id": task["latest_execution_id"], + "thread_id": task["thread_id"], + "tool_id": task["tool_id"], + "approval_id": approval_id, + } + ) + store.create_task_step( + task_id=task["id"], + sequence_no=1, + kind="governed_request", + status="executed", + request=task["request"], + outcome=task_step_outcome_snapshot( + routing_decision="approval_required", + approval_id=str(approval_id), 
+ approval_status="approved", + execution_id=str(task["latest_execution_id"]), + execution_status="completed", + blocked_reason=None, + ), + trace_id=uuid4(), + trace_kind="tool.proxy.execute", + ) + + payload = create_next_task_step_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=TaskStepNextCreateInput( + task_id=task["id"], + kind="governed_request", + status="created", + request=task["request"], + outcome=task_step_outcome_snapshot( + routing_decision="approval_required", + approval_id=None, + approval_status=None, + execution_id=None, + execution_status=None, + blocked_reason=None, + ), + lineage=TaskStepLineageInput( + parent_step_id=store.task_steps[0]["id"], + source_approval_id=approval_id, + source_execution_id=initial_execution_id, + ), + ), + ) + + assert payload["task"]["status"] == "pending_approval" + assert payload["task"]["latest_approval_id"] == str(approval_id) + assert payload["task"]["latest_execution_id"] is None + assert payload["task_step"]["sequence_no"] == 2 + assert payload["task_step"]["status"] == "created" + assert payload["task_step"]["lineage"] == { + "parent_step_id": str(store.task_steps[0]["id"]), + "source_approval_id": str(approval_id), + "source_execution_id": str(initial_execution_id), + } + assert payload["task_step"]["trace"]["trace_kind"] == "task.step.continuation" + assert payload["sequencing"] == { + "task_id": str(task["id"]), + "total_count": 2, + "latest_sequence_no": 2, + "latest_status": "created", + "next_sequence_no": 3, + "append_allowed": False, + "order": ["sequence_no_asc", "created_at_asc", "id_asc"], + } + assert payload["trace"]["trace_event_count"] == 7 + assert [event["kind"] for event in store.trace_events] == [ + "task.step.continuation.request", + "task.step.continuation.lineage", + "task.step.continuation.summary", + "task.lifecycle.state", + "task.lifecycle.summary", + "task.step.lifecycle.state", + "task.step.lifecycle.summary", + ] + assert store.trace_events[1]["payload"] == { + "task_id": str(task["id"]), + "parent_task_step_id": str(store.task_steps[0]["id"]), + "parent_sequence_no": 1, + "parent_status": "executed", + "source_approval_id": str(approval_id), + "source_execution_id": str(initial_execution_id), + } + + +def test_create_next_task_step_rejects_when_latest_step_is_not_terminal() -> None: + store = TaskStoreStub() + task = store.create_task( + status="pending_approval", + latest_approval_id=uuid4(), + latest_execution_id=None, + ) + store.create_task_step( + task_id=task["id"], + sequence_no=1, + kind="governed_request", + status="created", + request=task["request"], + outcome=task_step_outcome_snapshot( + routing_decision="approval_required", + approval_id=str(task["latest_approval_id"]), + approval_status="pending", + execution_id=None, + execution_status=None, + blocked_reason=None, + ), + trace_id=uuid4(), + trace_kind="approval.request", + ) + + try: + create_next_task_step_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=TaskStepNextCreateInput( + task_id=task["id"], + kind="governed_request", + status="created", + request=task["request"], + outcome=task_step_outcome_snapshot( + routing_decision="approval_required", + approval_id=None, + approval_status="pending", + execution_id=None, + execution_status=None, + blocked_reason=None, + ), + lineage=TaskStepLineageInput(parent_step_id=store.task_steps[0]["id"]), + ), + ) + except TaskStepSequenceError as exc: + assert str(exc) == ( + f"task {task['id']} latest step {store.task_steps[0]['id']} is 
created and cannot append a next step" + ) + else: + raise AssertionError("expected TaskStepSequenceError") + + +def test_transition_task_step_updates_latest_step_parent_and_trace() -> None: + store = TaskStoreStub() + first_approval_id = uuid4() + first_execution_id = uuid4() + task = store.create_task( + status="approved", + latest_approval_id=first_approval_id, + latest_execution_id=first_execution_id, + ) + store.approvals.extend( + [ + {"id": first_approval_id, "thread_id": task["thread_id"], "tool_id": task["tool_id"]}, + ] + ) + store.tool_executions.extend( + [ + { + "id": first_execution_id, + "thread_id": task["thread_id"], + "tool_id": task["tool_id"], + "approval_id": first_approval_id, + }, + ] + ) + first_step = store.create_task_step( + task_id=task["id"], + sequence_no=1, + kind="governed_request", + status="executed", + request=task["request"], + outcome=task_step_outcome_snapshot( + routing_decision="approval_required", + approval_id=str(first_approval_id), + approval_status="approved", + execution_id=str(first_execution_id), + execution_status="completed", + blocked_reason=None, + ), + trace_id=uuid4(), + trace_kind="tool.proxy.execute", + ) + second_step = store.create_task_step( + task_id=task["id"], + sequence_no=2, + kind="governed_request", + status="approved", + request=task["request"], + outcome=task_step_outcome_snapshot( + routing_decision="ready", + approval_id=None, + approval_status=None, + execution_id=None, + execution_status=None, + blocked_reason=None, + ), + trace_id=uuid4(), + trace_kind="task.step.sequence", + ) + + payload = transition_task_step_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=TaskStepTransitionInput( + task_step_id=second_step["id"], + status="executed", + outcome=task_step_outcome_snapshot( + routing_decision="ready", + approval_id=str(first_approval_id), + approval_status="approved", + execution_id=str(first_execution_id), + execution_status="completed", + blocked_reason=None, + ), + ), + ) + + assert first_step["status"] == "executed" + assert payload["task"]["status"] == "executed" + assert payload["task"]["latest_approval_id"] == str(first_approval_id) + assert payload["task"]["latest_execution_id"] == str(first_execution_id) + assert payload["task_step"]["id"] == str(second_step["id"]) + assert payload["task_step"]["status"] == "executed" + assert payload["task_step"]["trace"]["trace_kind"] == "task.step.transition" + assert payload["sequencing"] == { + "task_id": str(task["id"]), + "total_count": 2, + "latest_sequence_no": 2, + "latest_status": "executed", + "next_sequence_no": 3, + "append_allowed": True, + "order": ["sequence_no_asc", "created_at_asc", "id_asc"], + } + assert [event["kind"] for event in store.trace_events] == [ + "task.step.transition.request", + "task.step.transition.state", + "task.step.transition.summary", + "task.lifecycle.state", + "task.lifecycle.summary", + "task.step.lifecycle.state", + "task.step.lifecycle.summary", + ] + assert store.trace_events[1]["payload"]["allowed_next_statuses"] == ["executed", "blocked"] + + +def test_create_next_task_step_locks_before_listing_existing_steps() -> None: + class LockingTaskStoreStub(TaskStoreStub): + def list_task_steps_for_task(self, task_id: UUID) -> list[dict[str, object]]: + if task_id not in self.locked_task_ids: + raise AssertionError("task steps were listed before the advisory lock was taken") + return super().list_task_steps_for_task(task_id) + + store = LockingTaskStoreStub() + approval_id = uuid4() + initial_execution_id = 
uuid4() + task = store.create_task( + status="executed", + latest_approval_id=approval_id, + latest_execution_id=initial_execution_id, + ) + store.approvals.append({"id": approval_id, "thread_id": task["thread_id"], "tool_id": task["tool_id"]}) + store.tool_executions.append( + { + "id": task["latest_execution_id"], + "thread_id": task["thread_id"], + "tool_id": task["tool_id"], + "approval_id": approval_id, + } + ) + store.create_task_step( + task_id=task["id"], + sequence_no=1, + kind="governed_request", + status="executed", + request=task["request"], + outcome=task_step_outcome_snapshot( + routing_decision="approval_required", + approval_id=str(approval_id), + approval_status="approved", + execution_id=str(task["latest_execution_id"]), + execution_status="completed", + blocked_reason=None, + ), + trace_id=uuid4(), + trace_kind="tool.proxy.execute", + ) + + payload = create_next_task_step_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=TaskStepNextCreateInput( + task_id=task["id"], + kind="governed_request", + status="created", + request=task["request"], + outcome=task_step_outcome_snapshot( + routing_decision="approval_required", + approval_id=None, + approval_status=None, + execution_id=None, + execution_status=None, + blocked_reason=None, + ), + lineage=TaskStepLineageInput( + parent_step_id=store.task_steps[0]["id"], + source_approval_id=approval_id, + source_execution_id=initial_execution_id, + ), + ), + ) + + assert payload["task_step"]["sequence_no"] == 2 + + +def test_create_next_task_step_rejects_visible_approval_from_unrelated_task_lineage() -> None: + store = TaskStoreStub() + task = store.create_task( + status="executed", + latest_approval_id=uuid4(), + latest_execution_id=uuid4(), + ) + unrelated_approval_id = uuid4() + store.approvals.append( + { + "id": unrelated_approval_id, + "thread_id": uuid4(), + "tool_id": uuid4(), + } + ) + store.create_task_step( + task_id=task["id"], + sequence_no=1, + kind="governed_request", + status="executed", + request=task["request"], + outcome=task_step_outcome_snapshot( + routing_decision="approval_required", + approval_id=str(task["latest_approval_id"]), + approval_status="approved", + execution_id=str(task["latest_execution_id"]), + execution_status="completed", + blocked_reason=None, + ), + trace_id=uuid4(), + trace_kind="tool.proxy.execute", + ) + + try: + create_next_task_step_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=TaskStepNextCreateInput( + task_id=task["id"], + kind="governed_request", + status="created", + request=task["request"], + outcome=task_step_outcome_snapshot( + routing_decision="approval_required", + approval_id=None, + approval_status=None, + execution_id=None, + execution_status=None, + blocked_reason=None, + ), + lineage=TaskStepLineageInput( + parent_step_id=store.task_steps[0]["id"], + source_approval_id=unrelated_approval_id, + ), + ), + ) + except TaskStepSequenceError as exc: + assert str(exc) == f"approval {unrelated_approval_id} does not belong to task {task['id']}" + else: + raise AssertionError("expected TaskStepSequenceError") + + +def test_create_next_task_step_rejects_parent_step_from_unrelated_task() -> None: + store = TaskStoreStub() + task = store.create_task( + status="executed", + latest_approval_id=None, + latest_execution_id=None, + ) + unrelated_task = store.create_task( + status="executed", + latest_approval_id=None, + latest_execution_id=None, + ) + store.create_task_step( + task_id=task["id"], + sequence_no=1, + 
kind="governed_request", + status="executed", + request=task["request"], + outcome=task_step_outcome_snapshot( + routing_decision="ready", + approval_id=None, + approval_status=None, + execution_id=None, + execution_status=None, + blocked_reason=None, + ), + trace_id=uuid4(), + trace_kind="tool.proxy.execute", + ) + unrelated_step = store.create_task_step( + task_id=unrelated_task["id"], + sequence_no=1, + kind="governed_request", + status="executed", + request=unrelated_task["request"], + outcome=task_step_outcome_snapshot( + routing_decision="ready", + approval_id=None, + approval_status=None, + execution_id=None, + execution_status=None, + blocked_reason=None, + ), + trace_id=uuid4(), + trace_kind="tool.proxy.execute", + ) + + try: + create_next_task_step_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=TaskStepNextCreateInput( + task_id=task["id"], + kind="governed_request", + status="approved", + request=task["request"], + outcome=task_step_outcome_snapshot( + routing_decision="ready", + approval_id=None, + approval_status=None, + execution_id=None, + execution_status=None, + blocked_reason=None, + ), + lineage=TaskStepLineageInput(parent_step_id=unrelated_step["id"]), + ), + ) + except TaskStepSequenceError as exc: + assert str(exc) == f"task step {unrelated_step['id']} does not belong to task {task['id']}" + else: + raise AssertionError("expected TaskStepSequenceError") + + +def test_transition_task_step_rejects_invalid_status_graph_edge() -> None: + store = TaskStoreStub() + task = store.create_task( + status="pending_approval", + latest_approval_id=uuid4(), + latest_execution_id=None, + ) + step = store.create_task_step( + task_id=task["id"], + sequence_no=1, + kind="governed_request", + status="created", + request=task["request"], + outcome=task_step_outcome_snapshot( + routing_decision="approval_required", + approval_id=str(task["latest_approval_id"]), + approval_status="pending", + execution_id=None, + execution_status=None, + blocked_reason=None, + ), + trace_id=uuid4(), + trace_kind="approval.request", + ) + + try: + transition_task_step_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=TaskStepTransitionInput( + task_step_id=step["id"], + status="executed", + outcome=task_step_outcome_snapshot( + routing_decision="approval_required", + approval_id=str(task["latest_approval_id"]), + approval_status="approved", + execution_id=str(uuid4()), + execution_status="completed", + blocked_reason=None, + ), + ), + ) + except TaskStepTransitionError as exc: + assert str(exc) == ( + f"task step {step['id']} is created and cannot transition to executed; allowed: approved, denied" + ) + else: + raise AssertionError("expected TaskStepTransitionError") + + +def test_transition_task_step_rejects_visible_execution_from_unrelated_task_lineage() -> None: + store = TaskStoreStub() + approval_id = uuid4() + task = store.create_task( + status="approved", + latest_approval_id=approval_id, + latest_execution_id=None, + ) + step = store.create_task_step( + task_id=task["id"], + sequence_no=1, + kind="governed_request", + status="approved", + request=task["request"], + outcome=task_step_outcome_snapshot( + routing_decision="approval_required", + approval_id=str(approval_id), + approval_status="approved", + execution_id=None, + execution_status=None, + blocked_reason=None, + ), + trace_id=uuid4(), + trace_kind="task.step.sequence", + ) + store.approvals.append({"id": approval_id, "thread_id": task["thread_id"], "tool_id": task["tool_id"]}) + 
unrelated_execution_id = uuid4() + store.tool_executions.append( + { + "id": unrelated_execution_id, + "thread_id": uuid4(), + "tool_id": uuid4(), + "approval_id": approval_id, + } + ) + + try: + transition_task_step_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=TaskStepTransitionInput( + task_step_id=step["id"], + status="executed", + outcome=task_step_outcome_snapshot( + routing_decision="approval_required", + approval_id=str(approval_id), + approval_status="approved", + execution_id=str(unrelated_execution_id), + execution_status="completed", + blocked_reason=None, + ), + ), + ) + except TaskStepTransitionError as exc: + assert str(exc) == f"tool execution {unrelated_execution_id} does not belong to task {task['id']}" + else: + raise AssertionError("expected TaskStepTransitionError") + + +def test_get_task_step_record_raises_not_found_when_missing() -> None: + store = TaskStoreStub() + + try: + get_task_step_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + task_step_id=uuid4(), + ) + except TaskStepNotFoundError as exc: + assert "task step" in str(exc) + else: + raise AssertionError("expected TaskStepNotFoundError") diff --git a/tests/unit/test_tasks_main.py b/tests/unit/test_tasks_main.py new file mode 100644 index 0000000..0be9960 --- /dev/null +++ b/tests/unit/test_tasks_main.py @@ -0,0 +1,178 @@ +from __future__ import annotations + +import json +from contextlib import contextmanager +from uuid import uuid4 + +import apps.api.src.alicebot_api.main as main_module +from apps.api.src.alicebot_api.config import Settings +from alicebot_api.tasks import ( + TaskNotFoundError, + TaskStepNotFoundError, + TaskStepSequenceError, + TaskStepTransitionError, +) + + +def test_list_task_steps_endpoint_returns_payload(monkeypatch) -> None: + user_id = uuid4() + task_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr( + main_module, + "list_task_step_records", + lambda *_args, **_kwargs: { + "items": [], + "summary": { + "task_id": str(task_id), + "total_count": 0, + "latest_sequence_no": None, + "latest_status": None, + "next_sequence_no": 1, + "append_allowed": False, + "order": ["sequence_no_asc", "created_at_asc", "id_asc"], + }, + }, + ) + + response = main_module.list_task_steps(task_id, user_id) + + assert response.status_code == 200 + assert json.loads(response.body) == { + "items": [], + "summary": { + "task_id": str(task_id), + "total_count": 0, + "latest_sequence_no": None, + "latest_status": None, + "next_sequence_no": 1, + "append_allowed": False, + "order": ["sequence_no_asc", "created_at_asc", "id_asc"], + }, + } + + +def test_list_task_steps_endpoint_maps_task_not_found_to_404(monkeypatch) -> None: + user_id = uuid4() + task_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_list_task_step_records(*_args, **_kwargs): + raise TaskNotFoundError(f"task {task_id} was not found") + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "list_task_step_records", fake_list_task_step_records) + + response = 
main_module.list_task_steps(task_id, user_id) + + assert response.status_code == 404 + assert json.loads(response.body) == {"detail": f"task {task_id} was not found"} + + +def test_get_task_step_endpoint_maps_not_found_to_404(monkeypatch) -> None: + user_id = uuid4() + task_step_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_get_task_step_record(*_args, **_kwargs): + raise TaskStepNotFoundError(f"task step {task_step_id} was not found") + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "get_task_step_record", fake_get_task_step_record) + + response = main_module.get_task_step(task_step_id, user_id) + + assert response.status_code == 404 + assert json.loads(response.body) == {"detail": f"task step {task_step_id} was not found"} + + +def test_create_next_task_step_endpoint_maps_sequence_conflict_to_409(monkeypatch) -> None: + task_id = uuid4() + user_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_create_next_task_step_record(*_args, **_kwargs): + raise TaskStepSequenceError(f"task {task_id} latest step blocked append") + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "create_next_task_step_record", fake_create_next_task_step_record) + + response = main_module.create_next_task_step( + task_id, + main_module.CreateNextTaskStepRequest( + user_id=user_id, + kind="governed_request", + status="created", + request=main_module.TaskStepRequestSnapshot( + thread_id=uuid4(), + tool_id=uuid4(), + action="tool.run", + scope="workspace", + attributes={}, + ), + outcome=main_module.TaskStepOutcomeRequest( + routing_decision="approval_required", + approval_status="pending", + ), + lineage=main_module.TaskStepLineageRequest(parent_step_id=uuid4()), + ), + ) + + assert response.status_code == 409 + assert json.loads(response.body) == {"detail": f"task {task_id} latest step blocked append"} + + +def test_transition_task_step_endpoint_maps_transition_conflict_to_409(monkeypatch) -> None: + task_step_id = uuid4() + user_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_transition_task_step_record(*_args, **_kwargs): + raise TaskStepTransitionError(f"task step {task_step_id} is created and cannot transition") + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "transition_task_step_record", fake_transition_task_step_record) + + response = main_module.transition_task_step( + task_step_id, + main_module.TransitionTaskStepRequest( + user_id=user_id, + status="approved", + outcome=main_module.TaskStepOutcomeRequest( + routing_decision="approval_required", + approval_status="approved", + ), + ), + ) + + assert response.status_code == 409 + assert json.loads(response.body) == { + "detail": f"task step {task_step_id} is created and cannot transition" + } diff --git a/tests/unit/test_tool_execution_store.py b/tests/unit/test_tool_execution_store.py new file mode 100644 index 
0000000..f0715fc --- /dev/null +++ b/tests/unit/test_tool_execution_store.py @@ -0,0 +1,185 @@ +from __future__ import annotations + +from typing import Any +from uuid import uuid4 + +from psycopg.types.json import Jsonb + +from alicebot_api.store import ContinuityStore + + +class RecordingCursor: + def __init__(self, fetchone_results: list[dict[str, Any]], fetchall_result: list[dict[str, Any]] | None = None) -> None: + self.executed: list[tuple[str, tuple[object, ...] | None]] = [] + self.fetchone_results = list(fetchone_results) + self.fetchall_result = fetchall_result or [] + + def __enter__(self) -> "RecordingCursor": + return self + + def __exit__(self, exc_type, exc, tb) -> None: + return None + + def execute(self, query: str, params: tuple[object, ...] | None = None) -> None: + self.executed.append((query, params)) + + def fetchone(self) -> dict[str, Any] | None: + if not self.fetchone_results: + return None + return self.fetchone_results.pop(0) + + def fetchall(self) -> list[dict[str, Any]]: + return self.fetchall_result + + +class RecordingConnection: + def __init__(self, cursor: RecordingCursor) -> None: + self.cursor_instance = cursor + + def cursor(self) -> RecordingCursor: + return self.cursor_instance + + +def test_tool_execution_store_methods_use_expected_queries_and_jsonb_parameters() -> None: + execution_id = uuid4() + approval_id = uuid4() + task_step_id = uuid4() + thread_id = uuid4() + tool_id = uuid4() + trace_id = uuid4() + request_event_id = uuid4() + result_event_id = uuid4() + row = { + "id": execution_id, + "approval_id": approval_id, + "task_step_id": task_step_id, + "thread_id": thread_id, + "tool_id": tool_id, + "trace_id": trace_id, + "request_event_id": request_event_id, + "result_event_id": result_event_id, + "status": "completed", + "handler_key": "proxy.echo", + "request": {"thread_id": str(thread_id), "tool_id": str(tool_id)}, + "tool": {"id": str(tool_id), "tool_key": "proxy.echo"}, + "result": {"handler_key": "proxy.echo", "status": "completed", "output": {"mode": "no_side_effect"}, "reason": None}, + "executed_at": "2026-03-13T10:00:00+00:00", + } + cursor = RecordingCursor( + fetchone_results=[row, row], + fetchall_result=[row], + ) + store = ContinuityStore(RecordingConnection(cursor)) + + created = store.create_tool_execution( + approval_id=approval_id, + task_step_id=task_step_id, + thread_id=thread_id, + tool_id=tool_id, + trace_id=trace_id, + request_event_id=request_event_id, + result_event_id=result_event_id, + status="completed", + handler_key="proxy.echo", + request={"thread_id": str(thread_id), "tool_id": str(tool_id)}, + tool={"id": str(tool_id), "tool_key": "proxy.echo"}, + result={"handler_key": "proxy.echo", "status": "completed", "output": {"mode": "no_side_effect"}, "reason": None}, + ) + fetched = store.get_tool_execution_optional(execution_id) + listed = store.list_tool_executions() + + assert created["id"] == execution_id + assert fetched is not None + assert fetched["id"] == execution_id + assert listed[0]["id"] == execution_id + + create_query, create_params = cursor.executed[0] + assert "INSERT INTO tool_executions" in create_query + assert create_params is not None + assert create_params[:9] == ( + approval_id, + task_step_id, + thread_id, + tool_id, + trace_id, + request_event_id, + result_event_id, + "completed", + "proxy.echo", + ) + assert isinstance(create_params[9], Jsonb) + assert create_params[9].obj == {"thread_id": str(thread_id), "tool_id": str(tool_id)} + assert isinstance(create_params[10], Jsonb) + assert 
create_params[10].obj == {"id": str(tool_id), "tool_key": "proxy.echo"} + assert isinstance(create_params[11], Jsonb) + assert create_params[11].obj == { + "handler_key": "proxy.echo", + "status": "completed", + "output": {"mode": "no_side_effect"}, + "reason": None, + } + assert "FROM tool_executions" in cursor.executed[1][0] + assert "ORDER BY executed_at ASC, id ASC" in cursor.executed[2][0] + + +def test_create_tool_execution_accepts_blocked_attempt_without_event_ids() -> None: + execution_id = uuid4() + approval_id = uuid4() + task_step_id = uuid4() + thread_id = uuid4() + tool_id = uuid4() + trace_id = uuid4() + cursor = RecordingCursor( + fetchone_results=[ + { + "id": execution_id, + "approval_id": approval_id, + "task_step_id": task_step_id, + "thread_id": thread_id, + "tool_id": tool_id, + "trace_id": trace_id, + "request_event_id": None, + "result_event_id": None, + "status": "blocked", + "handler_key": None, + "request": {"thread_id": str(thread_id), "tool_id": str(tool_id)}, + "tool": {"id": str(tool_id), "tool_key": "proxy.missing"}, + "result": { + "handler_key": None, + "status": "blocked", + "output": None, + "reason": "tool 'proxy.missing' has no registered proxy handler", + }, + "executed_at": "2026-03-13T10:05:00+00:00", + } + ] + ) + store = ContinuityStore(RecordingConnection(cursor)) + + created = store.create_tool_execution( + approval_id=approval_id, + task_step_id=task_step_id, + thread_id=thread_id, + tool_id=tool_id, + trace_id=trace_id, + request_event_id=None, + result_event_id=None, + status="blocked", + handler_key=None, + request={"thread_id": str(thread_id), "tool_id": str(tool_id)}, + tool={"id": str(tool_id), "tool_key": "proxy.missing"}, + result={ + "handler_key": None, + "status": "blocked", + "output": None, + "reason": "tool 'proxy.missing' has no registered proxy handler", + }, + ) + + assert created["status"] == "blocked" + create_query, create_params = cursor.executed[0] + assert "INSERT INTO tool_executions" in create_query + assert create_params is not None + assert create_params[5] is None + assert create_params[6] is None + assert create_params[8] is None diff --git a/tests/unit/test_tool_store.py b/tests/unit/test_tool_store.py new file mode 100644 index 0000000..6f9fc09 --- /dev/null +++ b/tests/unit/test_tool_store.py @@ -0,0 +1,165 @@ +from __future__ import annotations + +from typing import Any +from uuid import uuid4 + +from psycopg.types.json import Jsonb + +from alicebot_api.store import ContinuityStore + + +class RecordingCursor: + def __init__(self, fetchone_results: list[dict[str, Any]], fetchall_result: list[dict[str, Any]] | None = None) -> None: + self.executed: list[tuple[str, tuple[object, ...] | None]] = [] + self.fetchone_results = list(fetchone_results) + self.fetchall_result = fetchall_result or [] + + def __enter__(self) -> "RecordingCursor": + return self + + def __exit__(self, exc_type, exc, tb) -> None: + return None + + def execute(self, query: str, params: tuple[object, ...] 
| None = None) -> None: + self.executed.append((query, params)) + + def fetchone(self) -> dict[str, Any] | None: + if not self.fetchone_results: + return None + return self.fetchone_results.pop(0) + + def fetchall(self) -> list[dict[str, Any]]: + return self.fetchall_result + + +class RecordingConnection: + def __init__(self, cursor: RecordingCursor) -> None: + self.cursor_instance = cursor + + def cursor(self) -> RecordingCursor: + return self.cursor_instance + + +def test_tool_store_methods_use_expected_queries_and_jsonb_parameters() -> None: + tool_id = uuid4() + cursor = RecordingCursor( + fetchone_results=[ + { + "id": tool_id, + "tool_key": "browser.open", + "name": "Browser Open", + "description": "Open documentation pages.", + "version": "1.0.0", + "metadata_version": "tool_metadata_v0", + "active": True, + "tags": ["browser"], + "action_hints": ["tool.run"], + "scope_hints": ["workspace"], + "domain_hints": ["docs"], + "risk_hints": [], + "metadata": {"transport": "proxy"}, + }, + { + "id": tool_id, + "tool_key": "browser.open", + "name": "Browser Open", + "description": "Open documentation pages.", + "version": "1.0.0", + "metadata_version": "tool_metadata_v0", + "active": True, + "tags": ["browser"], + "action_hints": ["tool.run"], + "scope_hints": ["workspace"], + "domain_hints": ["docs"], + "risk_hints": [], + "metadata": {"transport": "proxy"}, + }, + ], + fetchall_result=[ + { + "id": tool_id, + "tool_key": "browser.open", + "name": "Browser Open", + "description": "Open documentation pages.", + "version": "1.0.0", + "metadata_version": "tool_metadata_v0", + "active": True, + "tags": ["browser"], + "action_hints": ["tool.run"], + "scope_hints": ["workspace"], + "domain_hints": ["docs"], + "risk_hints": [], + "metadata": {"transport": "proxy"}, + } + ], + ) + store = ContinuityStore(RecordingConnection(cursor)) + + created = store.create_tool( + tool_key="browser.open", + name="Browser Open", + description="Open documentation pages.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["browser"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=["docs"], + risk_hints=[], + metadata={"transport": "proxy"}, + ) + fetched = store.get_tool_optional(tool_id) + listed = store.list_active_tools() + + assert created["id"] == tool_id + assert fetched is not None + assert listed[0]["id"] == tool_id + + create_query, create_params = cursor.executed[0] + assert "INSERT INTO tools" in create_query + assert create_params is not None + assert create_params[:6] == ( + "browser.open", + "Browser Open", + "Open documentation pages.", + "1.0.0", + "tool_metadata_v0", + True, + ) + for index, expected in ( + (6, ["browser"]), + (7, ["tool.run"]), + (8, ["workspace"]), + (9, ["docs"]), + (10, []), + ): + assert isinstance(create_params[index], Jsonb) + assert create_params[index].obj == expected + assert isinstance(create_params[11], Jsonb) + assert create_params[11].obj == {"transport": "proxy"} + + assert cursor.executed[1] == ( + """ + SELECT + id, + user_id, + tool_key, + name, + description, + version, + metadata_version, + active, + tags, + action_hints, + scope_hints, + domain_hints, + risk_hints, + metadata, + created_at + FROM tools + WHERE id = %s + """, + (tool_id,), + ) + assert "WHERE active = TRUE" in cursor.executed[2][0] diff --git a/tests/unit/test_tools.py b/tests/unit/test_tools.py new file mode 100644 index 0000000..169e24e --- /dev/null +++ b/tests/unit/test_tools.py @@ -0,0 +1,688 @@ +from __future__ import annotations + 
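+# Descriptive note (comment only, no behavior change): this module exercises tool record creation/listing,
+# allowlist evaluation, and routing decisions against the in-memory ToolStoreStub defined below.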
+from datetime import UTC, datetime, timedelta +from uuid import UUID, uuid4 + +import pytest + +from alicebot_api.contracts import ( + ToolAllowlistEvaluationRequestInput, + ToolCreateInput, + ToolRoutingRequestInput, +) +from alicebot_api.tools import ( + ToolAllowlistValidationError, + create_tool_record, + evaluate_tool_allowlist, + get_tool_record, + list_tool_records, + route_tool_invocation, + ToolRoutingValidationError, +) + + +class ToolStoreStub: + def __init__(self) -> None: + self.base_time = datetime(2026, 3, 12, 9, 0, tzinfo=UTC) + self.user_id = uuid4() + self.thread_id = uuid4() + self.consents: dict[str, dict[str, object]] = {} + self.policies: list[dict[str, object]] = [] + self.tools: list[dict[str, object]] = [] + self.traces: list[dict[str, object]] = [] + self.trace_events: list[dict[str, object]] = [] + + def create_consent(self, *, consent_key: str, status: str, metadata: dict[str, object]) -> dict[str, object]: + consent = { + "id": uuid4(), + "user_id": self.user_id, + "consent_key": consent_key, + "status": status, + "metadata": metadata, + "created_at": self.base_time + timedelta(minutes=len(self.consents)), + "updated_at": self.base_time + timedelta(minutes=len(self.consents)), + } + self.consents[consent_key] = consent + return consent + + def list_consents(self) -> list[dict[str, object]]: + return sorted( + self.consents.values(), + key=lambda consent: (consent["consent_key"], consent["created_at"], consent["id"]), + ) + + def create_policy( + self, + *, + name: str, + action: str, + scope: str, + effect: str, + priority: int, + active: bool, + conditions: dict[str, object], + required_consents: list[str], + ) -> dict[str, object]: + policy = { + "id": uuid4(), + "user_id": self.user_id, + "name": name, + "action": action, + "scope": scope, + "effect": effect, + "priority": priority, + "active": active, + "conditions": conditions, + "required_consents": required_consents, + "created_at": self.base_time + timedelta(minutes=len(self.policies)), + "updated_at": self.base_time + timedelta(minutes=len(self.policies)), + } + self.policies.append(policy) + return policy + + def list_active_policies(self) -> list[dict[str, object]]: + return sorted( + [policy for policy in self.policies if policy["active"] is True], + key=lambda policy: (policy["priority"], policy["created_at"], policy["id"]), + ) + + def create_tool( + self, + *, + tool_key: str, + name: str, + description: str, + version: str, + metadata_version: str, + active: bool, + tags: list[str], + action_hints: list[str], + scope_hints: list[str], + domain_hints: list[str], + risk_hints: list[str], + metadata: dict[str, object], + ) -> dict[str, object]: + tool = { + "id": uuid4(), + "user_id": self.user_id, + "tool_key": tool_key, + "name": name, + "description": description, + "version": version, + "metadata_version": metadata_version, + "active": active, + "tags": tags, + "action_hints": action_hints, + "scope_hints": scope_hints, + "domain_hints": domain_hints, + "risk_hints": risk_hints, + "metadata": metadata, + "created_at": self.base_time + timedelta(minutes=len(self.tools)), + } + self.tools.append(tool) + return tool + + def get_tool_optional(self, tool_id: UUID) -> dict[str, object] | None: + return next((tool for tool in self.tools if tool["id"] == tool_id), None) + + def list_tools(self) -> list[dict[str, object]]: + return sorted( + self.tools, + key=lambda tool: (tool["tool_key"], tool["version"], tool["created_at"], tool["id"]), + ) + + def list_active_tools(self) -> list[dict[str, object]]: 
+ return [tool for tool in self.list_tools() if tool["active"] is True] + + def get_thread_optional(self, thread_id: UUID) -> dict[str, object] | None: + if thread_id != self.thread_id: + return None + return { + "id": self.thread_id, + "user_id": self.user_id, + "title": "Tool thread", + "created_at": self.base_time, + "updated_at": self.base_time, + } + + def create_trace( + self, + *, + user_id: UUID, + thread_id: UUID, + kind: str, + compiler_version: str, + status: str, + limits: dict[str, object], + ) -> dict[str, object]: + trace = { + "id": uuid4(), + "user_id": user_id, + "thread_id": thread_id, + "kind": kind, + "compiler_version": compiler_version, + "status": status, + "limits": limits, + "created_at": self.base_time, + } + self.traces.append(trace) + return trace + + def append_trace_event( + self, + *, + trace_id: UUID, + sequence_no: int, + kind: str, + payload: dict[str, object], + ) -> dict[str, object]: + event = { + "id": uuid4(), + "trace_id": trace_id, + "sequence_no": sequence_no, + "kind": kind, + "payload": payload, + "created_at": self.base_time, + } + self.trace_events.append(event) + return event + + +def test_create_list_and_get_tool_records_preserve_deterministic_order() -> None: + store = ToolStoreStub() + later = create_tool_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + tool=ToolCreateInput( + tool_key="zeta.fetch", + name="Zeta Fetch", + description="Fetch zeta records.", + version="2.0.0", + action_hints=("tool.run",), + scope_hints=("workspace",), + ), + ) + earlier = store.create_tool( + tool_key="alpha.open", + name="Alpha Open", + description="Open alpha pages.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["browser"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=["docs"], + risk_hints=[], + metadata={"transport": "proxy"}, + ) + + listed = list_tool_records( + store, # type: ignore[arg-type] + user_id=store.user_id, + ) + detail = get_tool_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + tool_id=UUID(later["tool"]["id"]), + ) + + assert [item["tool_key"] for item in listed["items"]] == ["alpha.open", "zeta.fetch"] + assert listed["summary"] == { + "total_count": 2, + "order": ["tool_key_asc", "version_asc", "created_at_asc", "id_asc"], + } + assert detail == {"tool": later["tool"]} + assert listed["items"][0]["id"] == str(earlier["id"]) + + +def test_evaluate_tool_allowlist_splits_allowed_denied_and_approval_required() -> None: + store = ToolStoreStub() + store.create_consent(consent_key="web_access", status="granted", metadata={"source": "settings"}) + allowed_tool = store.create_tool( + tool_key="browser.open", + name="Browser Open", + description="Open documentation pages.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["browser"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=["docs"], + risk_hints=[], + metadata={"transport": "proxy"}, + ) + denied_tool = store.create_tool( + tool_key="calendar.read", + name="Calendar Read", + description="Read a calendar.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["calendar"], + action_hints=["calendar.read"], + scope_hints=["calendar"], + domain_hints=[], + risk_hints=[], + metadata={}, + ) + approval_tool = store.create_tool( + tool_key="shell.exec", + name="Shell Exec", + description="Run shell commands.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["shell"], + 
action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=[], + risk_hints=[], + metadata={"transport": "local"}, + ) + + store.create_policy( + name="Allow docs browser", + action="tool.run", + scope="workspace", + effect="allow", + priority=10, + active=True, + conditions={"tool_key": "browser.open", "domain_hint": "docs"}, + required_consents=["web_access"], + ) + store.create_policy( + name="Require shell approval", + action="tool.run", + scope="workspace", + effect="require_approval", + priority=20, + active=True, + conditions={"tool_key": "shell.exec"}, + required_consents=[], + ) + + payload = evaluate_tool_allowlist( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ToolAllowlistEvaluationRequestInput( + thread_id=store.thread_id, + action="tool.run", + scope="workspace", + domain_hint="docs", + attributes={}, + ), + ) + + assert payload["allowed"] == [ + { + "decision": "allowed", + "tool": { + "id": str(allowed_tool["id"]), + "tool_key": "browser.open", + "name": "Browser Open", + "description": "Open documentation pages.", + "version": "1.0.0", + "metadata_version": "tool_metadata_v0", + "active": True, + "tags": ["browser"], + "action_hints": ["tool.run"], + "scope_hints": ["workspace"], + "domain_hints": ["docs"], + "risk_hints": [], + "metadata": {"transport": "proxy"}, + "created_at": allowed_tool["created_at"].isoformat(), + }, + "reasons": [ + { + "code": "tool_metadata_matched", + "source": "tool", + "message": "Tool metadata matched the requested action, scope, and optional hints.", + "tool_id": str(allowed_tool["id"]), + "policy_id": None, + "consent_key": None, + }, + { + "code": "matched_policy", + "source": "policy", + "message": "Matched policy 'Allow docs browser' at priority 10.", + "tool_id": str(allowed_tool["id"]), + "policy_id": str(store.policies[0]["id"]), + "consent_key": None, + }, + { + "code": "policy_effect_allow", + "source": "policy", + "message": "Policy effect resolved the decision to 'allow'.", + "tool_id": str(allowed_tool["id"]), + "policy_id": str(store.policies[0]["id"]), + "consent_key": None, + }, + ], + } + ] + assert [item["tool"]["id"] for item in payload["approval_required"]] == [str(approval_tool["id"])] + assert payload["approval_required"][0]["reasons"][-1]["code"] == "policy_effect_require_approval" + assert [item["tool"]["id"] for item in payload["denied"]] == [str(denied_tool["id"])] + assert [reason["code"] for reason in payload["denied"][0]["reasons"]] == [ + "tool_action_unsupported", + "tool_scope_unsupported", + ] + assert payload["summary"] == { + "action": "tool.run", + "scope": "workspace", + "domain_hint": "docs", + "risk_hint": None, + "evaluated_tool_count": 3, + "allowed_count": 1, + "denied_count": 1, + "approval_required_count": 1, + "order": ["tool_key_asc", "version_asc", "created_at_asc", "id_asc"], + } + assert payload["trace"]["trace_event_count"] == 6 + assert [event["kind"] for event in store.trace_events] == [ + "tool.allowlist.request", + "tool.allowlist.order", + "tool.allowlist.decision", + "tool.allowlist.decision", + "tool.allowlist.decision", + "tool.allowlist.summary", + ] + + +def test_evaluate_tool_allowlist_validates_thread_scope() -> None: + with pytest.raises( + ToolAllowlistValidationError, + match="thread_id must reference an existing thread owned by the user", + ): + evaluate_tool_allowlist( + ToolStoreStub(), # type: ignore[arg-type] + user_id=uuid4(), + request=ToolAllowlistEvaluationRequestInput( + thread_id=uuid4(), + action="tool.run", + scope="workspace", + 
attributes={}, + ), + ) + + +def test_route_tool_invocation_returns_ready_with_trace() -> None: + store = ToolStoreStub() + store.create_consent(consent_key="web_access", status="granted", metadata={"source": "settings"}) + tool = store.create_tool( + tool_key="browser.open", + name="Browser Open", + description="Open documentation pages.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["browser"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=["docs"], + risk_hints=[], + metadata={"transport": "proxy"}, + ) + policy = store.create_policy( + name="Allow docs browser", + action="tool.run", + scope="workspace", + effect="allow", + priority=10, + active=True, + conditions={"tool_key": "browser.open", "domain_hint": "docs"}, + required_consents=["web_access"], + ) + + payload = route_tool_invocation( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ToolRoutingRequestInput( + thread_id=store.thread_id, + tool_id=tool["id"], + action="tool.run", + scope="workspace", + domain_hint="docs", + attributes={"channel": "chat"}, + ), + ) + + assert payload == { + "request": { + "thread_id": str(store.thread_id), + "tool_id": str(tool["id"]), + "action": "tool.run", + "scope": "workspace", + "domain_hint": "docs", + "risk_hint": None, + "attributes": {"channel": "chat"}, + }, + "decision": "ready", + "tool": { + "id": str(tool["id"]), + "tool_key": "browser.open", + "name": "Browser Open", + "description": "Open documentation pages.", + "version": "1.0.0", + "metadata_version": "tool_metadata_v0", + "active": True, + "tags": ["browser"], + "action_hints": ["tool.run"], + "scope_hints": ["workspace"], + "domain_hints": ["docs"], + "risk_hints": [], + "metadata": {"transport": "proxy"}, + "created_at": tool["created_at"].isoformat(), + }, + "reasons": [ + { + "code": "tool_metadata_matched", + "source": "tool", + "message": "Tool metadata matched the requested action, scope, and optional hints.", + "tool_id": str(tool["id"]), + "policy_id": None, + "consent_key": None, + }, + { + "code": "matched_policy", + "source": "policy", + "message": "Matched policy 'Allow docs browser' at priority 10.", + "tool_id": str(tool["id"]), + "policy_id": str(policy["id"]), + "consent_key": None, + }, + { + "code": "policy_effect_allow", + "source": "policy", + "message": "Policy effect resolved the decision to 'allow'.", + "tool_id": str(tool["id"]), + "policy_id": str(policy["id"]), + "consent_key": None, + }, + ], + "summary": { + "thread_id": str(store.thread_id), + "tool_id": str(tool["id"]), + "action": "tool.run", + "scope": "workspace", + "domain_hint": "docs", + "risk_hint": None, + "decision": "ready", + "evaluated_tool_count": 1, + "active_policy_count": 1, + "consent_count": 1, + "order": ["tool_key_asc", "version_asc", "created_at_asc", "id_asc"], + }, + "trace": { + "trace_id": str(store.traces[0]["id"]), + "trace_event_count": 3, + }, + } + assert store.traces[0]["kind"] == "tool.route" + assert store.traces[0]["compiler_version"] == "tool_routing_v0" + assert [event["kind"] for event in store.trace_events] == [ + "tool.route.request", + "tool.route.decision", + "tool.route.summary", + ] + assert store.trace_events[1]["payload"]["allowlist_decision"] == "allowed" + assert store.trace_events[1]["payload"]["routing_decision"] == "ready" + + +def test_route_tool_invocation_returns_denied_for_metadata_or_policy_denial() -> None: + store = ToolStoreStub() + tool = store.create_tool( + tool_key="calendar.read", + name="Calendar Read", 
+ description="Read calendars.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["calendar"], + action_hints=["calendar.read"], + scope_hints=["calendar"], + domain_hints=[], + risk_hints=[], + metadata={}, + ) + + payload = route_tool_invocation( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ToolRoutingRequestInput( + thread_id=store.thread_id, + tool_id=tool["id"], + action="tool.run", + scope="workspace", + attributes={}, + ), + ) + + assert payload["decision"] == "denied" + assert [reason["code"] for reason in payload["reasons"]] == [ + "tool_action_unsupported", + "tool_scope_unsupported", + ] + assert payload["summary"]["decision"] == "denied" + assert payload["trace"]["trace_event_count"] == 3 + + +def test_route_tool_invocation_returns_approval_required() -> None: + store = ToolStoreStub() + tool = store.create_tool( + tool_key="shell.exec", + name="Shell Exec", + description="Run shell commands.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["shell"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=[], + risk_hints=[], + metadata={"transport": "local"}, + ) + policy = store.create_policy( + name="Require shell approval", + action="tool.run", + scope="workspace", + effect="require_approval", + priority=10, + active=True, + conditions={"tool_key": "shell.exec"}, + required_consents=[], + ) + + payload = route_tool_invocation( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ToolRoutingRequestInput( + thread_id=store.thread_id, + tool_id=tool["id"], + action="tool.run", + scope="workspace", + attributes={}, + ), + ) + + assert payload["decision"] == "approval_required" + assert payload["summary"]["decision"] == "approval_required" + assert payload["reasons"][-1] == { + "code": "policy_effect_require_approval", + "source": "policy", + "message": "Policy effect resolved the decision to 'require_approval'.", + "tool_id": str(tool["id"]), + "policy_id": str(policy["id"]), + "consent_key": None, + } + + +def test_route_tool_invocation_validates_thread_scope() -> None: + store = ToolStoreStub() + tool = store.create_tool( + tool_key="browser.open", + name="Browser Open", + description="Open documentation pages.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["browser"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=[], + risk_hints=[], + metadata={}, + ) + + with pytest.raises( + ToolRoutingValidationError, + match="thread_id must reference an existing thread owned by the user", + ): + route_tool_invocation( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ToolRoutingRequestInput( + thread_id=uuid4(), + tool_id=tool["id"], + action="tool.run", + scope="workspace", + attributes={}, + ), + ) + + +def test_route_tool_invocation_validates_active_tool_scope() -> None: + store = ToolStoreStub() + inactive_tool = store.create_tool( + tool_key="browser.open", + name="Browser Open", + description="Open documentation pages.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=False, + tags=["browser"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=[], + risk_hints=[], + metadata={}, + ) + + with pytest.raises( + ToolRoutingValidationError, + match="tool_id must reference an existing active tool owned by the user", + ): + route_tool_invocation( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ToolRoutingRequestInput( + 
thread_id=store.thread_id, + tool_id=inactive_tool["id"], + action="tool.run", + scope="workspace", + attributes={}, + ), + ) diff --git a/tests/unit/test_tools_main.py b/tests/unit/test_tools_main.py new file mode 100644 index 0000000..fc86a9c --- /dev/null +++ b/tests/unit/test_tools_main.py @@ -0,0 +1,315 @@ +from __future__ import annotations + +import json +from contextlib import contextmanager +from uuid import uuid4 + +import apps.api.src.alicebot_api.main as main_module +from apps.api.src.alicebot_api.config import Settings +from alicebot_api.tools import ( + ToolAllowlistValidationError, + ToolNotFoundError, + ToolRoutingValidationError, +) + + +def test_create_tool_endpoint_translates_request_and_returns_created_status(monkeypatch) -> None: + user_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_create_tool_record(store, *, user_id, tool): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["tool"] = tool + return { + "tool": { + "id": "tool-123", + "tool_key": "browser.open", + "name": "Browser Open", + "description": "Open documentation pages.", + "version": "1.0.0", + "metadata_version": "tool_metadata_v0", + "active": True, + "tags": ["browser"], + "action_hints": ["tool.run"], + "scope_hints": ["workspace"], + "domain_hints": ["docs"], + "risk_hints": [], + "metadata": {"transport": "proxy"}, + "created_at": "2026-03-12T09:00:00+00:00", + } + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "create_tool_record", fake_create_tool_record) + + response = main_module.create_tool( + main_module.CreateToolRequest( + user_id=user_id, + tool_key="browser.open", + name="Browser Open", + description="Open documentation pages.", + version="1.0.0", + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=["docs"], + risk_hints=[], + metadata={"transport": "proxy"}, + ) + ) + + assert response.status_code == 201 + assert json.loads(response.body)["tool"]["tool_key"] == "browser.open" + assert captured["database_url"] == "postgresql://app" + assert captured["current_user_id"] == user_id + assert captured["store_type"] == "ContinuityStore" + assert captured["user_id"] == user_id + assert captured["tool"].tool_key == "browser.open" + assert captured["tool"].action_hints == ("tool.run",) + assert captured["tool"].scope_hints == ("workspace",) + + +def test_get_tool_endpoint_maps_not_found_to_404(monkeypatch) -> None: + user_id = uuid4() + tool_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_get_tool_record(*_args, **_kwargs): + raise ToolNotFoundError(f"tool {tool_id} was not found") + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "get_tool_record", fake_get_tool_record) + + response = main_module.get_tool(tool_id, user_id) + + assert response.status_code == 404 + assert json.loads(response.body) == {"detail": f"tool {tool_id} was not found"} + + +def 
test_evaluate_tool_allowlist_endpoint_translates_request_and_returns_trace_payload(monkeypatch) -> None: + user_id = uuid4() + thread_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_evaluate_tool_allowlist(store, *, user_id, request): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["request"] = request + return { + "allowed": [], + "denied": [], + "approval_required": [], + "summary": { + "action": "tool.run", + "scope": "workspace", + "domain_hint": "docs", + "risk_hint": None, + "evaluated_tool_count": 0, + "allowed_count": 0, + "denied_count": 0, + "approval_required_count": 0, + "order": ["tool_key_asc", "version_asc", "created_at_asc", "id_asc"], + }, + "trace": {"trace_id": "trace-123", "trace_event_count": 3}, + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "evaluate_tool_allowlist", fake_evaluate_tool_allowlist) + + response = main_module.evaluate_tools_allowlist( + main_module.EvaluateToolAllowlistRequest( + user_id=user_id, + thread_id=thread_id, + action="tool.run", + scope="workspace", + domain_hint="docs", + attributes={"channel": "chat"}, + ) + ) + + assert response.status_code == 200 + assert json.loads(response.body)["trace"] == {"trace_id": "trace-123", "trace_event_count": 3} + assert captured["database_url"] == "postgresql://app" + assert captured["current_user_id"] == user_id + assert captured["store_type"] == "ContinuityStore" + assert captured["user_id"] == user_id + assert captured["request"].thread_id == thread_id + assert captured["request"].action == "tool.run" + assert captured["request"].scope == "workspace" + assert captured["request"].domain_hint == "docs" + assert captured["request"].attributes == {"channel": "chat"} + + +def test_evaluate_tool_allowlist_endpoint_maps_validation_errors_to_400(monkeypatch) -> None: + user_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_evaluate_tool_allowlist(*_args, **_kwargs): + raise ToolAllowlistValidationError("thread_id must reference an existing thread owned by the user") + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "evaluate_tool_allowlist", fake_evaluate_tool_allowlist) + + response = main_module.evaluate_tools_allowlist( + main_module.EvaluateToolAllowlistRequest( + user_id=user_id, + thread_id=uuid4(), + action="tool.run", + scope="workspace", + attributes={}, + ) + ) + + assert response.status_code == 400 + assert json.loads(response.body) == { + "detail": "thread_id must reference an existing thread owned by the user" + } + + +def test_route_tool_endpoint_translates_request_and_returns_trace_payload(monkeypatch) -> None: + user_id = uuid4() + thread_id = uuid4() + tool_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = 
current_user_id + yield object() + + def fake_route_tool_invocation(store, *, user_id, request): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["request"] = request + return { + "request": { + "thread_id": str(thread_id), + "tool_id": str(tool_id), + "action": "tool.run", + "scope": "workspace", + "domain_hint": "docs", + "risk_hint": None, + "attributes": {"channel": "chat"}, + }, + "decision": "ready", + "tool": { + "id": str(tool_id), + "tool_key": "browser.open", + "name": "Browser Open", + "description": "Open documentation pages.", + "version": "1.0.0", + "metadata_version": "tool_metadata_v0", + "active": True, + "tags": ["browser"], + "action_hints": ["tool.run"], + "scope_hints": ["workspace"], + "domain_hints": ["docs"], + "risk_hints": [], + "metadata": {"transport": "proxy"}, + "created_at": "2026-03-12T09:00:00+00:00", + }, + "reasons": [], + "summary": { + "thread_id": str(thread_id), + "tool_id": str(tool_id), + "action": "tool.run", + "scope": "workspace", + "domain_hint": "docs", + "risk_hint": None, + "decision": "ready", + "evaluated_tool_count": 1, + "active_policy_count": 1, + "consent_count": 1, + "order": ["tool_key_asc", "version_asc", "created_at_asc", "id_asc"], + }, + "trace": {"trace_id": "trace-123", "trace_event_count": 3}, + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "route_tool_invocation", fake_route_tool_invocation) + + response = main_module.route_tool( + main_module.RouteToolRequest( + user_id=user_id, + thread_id=thread_id, + tool_id=tool_id, + action="tool.run", + scope="workspace", + domain_hint="docs", + attributes={"channel": "chat"}, + ) + ) + + assert response.status_code == 200 + assert json.loads(response.body)["trace"] == {"trace_id": "trace-123", "trace_event_count": 3} + assert captured["database_url"] == "postgresql://app" + assert captured["current_user_id"] == user_id + assert captured["store_type"] == "ContinuityStore" + assert captured["user_id"] == user_id + assert captured["request"].thread_id == thread_id + assert captured["request"].tool_id == tool_id + assert captured["request"].action == "tool.run" + assert captured["request"].scope == "workspace" + assert captured["request"].domain_hint == "docs" + assert captured["request"].attributes == {"channel": "chat"} + + +def test_route_tool_endpoint_maps_validation_errors_to_400(monkeypatch) -> None: + user_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_route_tool_invocation(*_args, **_kwargs): + raise ToolRoutingValidationError("tool_id must reference an existing active tool owned by the user") + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "route_tool_invocation", fake_route_tool_invocation) + + response = main_module.route_tool( + main_module.RouteToolRequest( + user_id=user_id, + thread_id=uuid4(), + tool_id=uuid4(), + action="tool.run", + scope="workspace", + attributes={}, + ) + ) + + assert response.status_code == 400 + assert json.loads(response.body) == { + "detail": "tool_id must reference an existing active tool owned by the user" + } diff --git a/tests/unit/test_trace_store.py b/tests/unit/test_trace_store.py new file mode 100644 index 
0000000..28bbac2 --- /dev/null +++ b/tests/unit/test_trace_store.py @@ -0,0 +1,148 @@ +from __future__ import annotations + +from typing import Any +from uuid import uuid4 + +from psycopg.types.json import Jsonb +import pytest + +from alicebot_api.store import AppendOnlyViolation, ContinuityStore, ContinuityStoreInvariantError + + +class RecordingCursor: + def __init__(self, fetchone_results: list[dict[str, Any]], fetchall_result: list[dict[str, Any]] | None = None) -> None: + self.executed: list[tuple[str, tuple[object, ...] | None]] = [] + self.fetchone_results = list(fetchone_results) + self.fetchall_result = fetchall_result or [] + + def __enter__(self) -> "RecordingCursor": + return self + + def __exit__(self, exc_type, exc, tb) -> None: + return None + + def execute(self, query: str, params: tuple[object, ...] | None = None) -> None: + self.executed.append((query, params)) + + def fetchone(self) -> dict[str, Any] | None: + if not self.fetchone_results: + return None + return self.fetchone_results.pop(0) + + def fetchall(self) -> list[dict[str, Any]]: + return self.fetchall_result + + +class RecordingConnection: + def __init__(self, cursor: RecordingCursor) -> None: + self.cursor_instance = cursor + + def cursor(self) -> RecordingCursor: + return self.cursor_instance + + +def test_trace_methods_use_expected_queries_and_payload_serialization() -> None: + user_id = uuid4() + thread_id = uuid4() + trace_id = uuid4() + payload = {"reason": "within_event_limit"} + cursor = RecordingCursor( + fetchone_results=[ + {"id": user_id, "email": "owner@example.com", "display_name": "Owner"}, + {"id": thread_id, "user_id": user_id, "title": "Thread"}, + {"id": trace_id, "user_id": user_id, "thread_id": thread_id, "kind": "context.compile"}, + { + "id": uuid4(), + "user_id": user_id, + "trace_id": trace_id, + "sequence_no": 1, + "kind": "context.include", + "payload": payload, + }, + ], + fetchall_result=[ + {"sequence_no": 1, "kind": "context.include", "payload": payload}, + ], + ) + store = ContinuityStore(RecordingConnection(cursor)) + + user = store.get_user(user_id) + thread = store.get_thread(thread_id) + trace = store.create_trace( + user_id=user_id, + thread_id=thread_id, + kind="context.compile", + compiler_version="continuity_v0", + status="completed", + limits={"max_sessions": 3, "max_events": 8}, + ) + trace_event = store.append_trace_event( + trace_id=trace_id, + sequence_no=1, + kind="context.include", + payload=payload, + ) + listed_trace_events = store.list_trace_events(trace_id) + + assert user["id"] == user_id + assert thread["id"] == thread_id + assert trace["id"] == trace_id + assert trace_event["sequence_no"] == 1 + assert listed_trace_events == [{"sequence_no": 1, "kind": "context.include", "payload": payload}] + + assert cursor.executed[0] == ( + """ + SELECT id, email, display_name, created_at + FROM users + WHERE id = %s + """, + (user_id,), + ) + assert cursor.executed[1] == ( + """ + SELECT id, user_id, title, created_at, updated_at + FROM threads + WHERE id = %s + """, + (thread_id,), + ) + create_trace_query, create_trace_params = cursor.executed[2] + assert "INSERT INTO traces" in create_trace_query + assert create_trace_params is not None + assert create_trace_params[:5] == ( + user_id, + thread_id, + "context.compile", + "continuity_v0", + "completed", + ) + assert isinstance(create_trace_params[5], Jsonb) + assert create_trace_params[5].obj == {"max_sessions": 3, "max_events": 8} + + append_trace_query, append_trace_params = cursor.executed[3] + assert "INSERT INTO 
trace_events" in append_trace_query + assert append_trace_params is not None + assert append_trace_params[:3] == (trace_id, 1, "context.include") + assert isinstance(append_trace_params[3], Jsonb) + assert append_trace_params[3].obj == payload + + +def test_trace_event_updates_and_deletes_are_rejected_by_contract() -> None: + store = ContinuityStore(conn=None) # type: ignore[arg-type] + + with pytest.raises(AppendOnlyViolation, match="append-only"): + store.update_trace_event("trace-event-id", {"text": "mutated"}) + + with pytest.raises(AppendOnlyViolation, match="append-only"): + store.delete_trace_event("trace-event-id") + + +def test_get_trace_raises_clear_error_when_missing() -> None: + cursor = RecordingCursor(fetchone_results=[]) + store = ContinuityStore(RecordingConnection(cursor)) + + with pytest.raises( + ContinuityStoreInvariantError, + match="get_trace did not return a row", + ): + store.get_trace(uuid4()) diff --git a/tests/unit/test_worker_main.py b/tests/unit/test_worker_main.py new file mode 100644 index 0000000..b6391d2 --- /dev/null +++ b/tests/unit/test_worker_main.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +import logging +import os +from pathlib import Path +import subprocess +import sys + +from workers.alicebot_worker.main import run + + +def test_run_logs_scaffold_message(caplog) -> None: + with caplog.at_level(logging.INFO, logger="alicebot.worker"): + run() + + assert caplog.messages == [ + "Worker scaffold initialized; no background jobs are in scope for this sprint." + ] + + +def test_module_entrypoint_logs_scaffold_message() -> None: + repo_root = Path(__file__).resolve().parents[2] + env = os.environ.copy() + pythonpath_entries = [str(repo_root / "apps" / "api" / "src"), str(repo_root / "workers")] + existing_pythonpath = env.get("PYTHONPATH") + if existing_pythonpath: + pythonpath_entries.append(existing_pythonpath) + env["PYTHONPATH"] = os.pathsep.join(pythonpath_entries) + + result = subprocess.run( + [sys.executable, "-m", "alicebot_worker.main"], + cwd=repo_root, + env=env, + capture_output=True, + text=True, + check=False, + ) + + assert result.returncode == 0 + assert "Worker scaffold initialized; no background jobs are in scope for this sprint." 
in result.stderr diff --git a/tests/unit/test_workspaces.py b/tests/unit/test_workspaces.py new file mode 100644 index 0000000..67cb1bc --- /dev/null +++ b/tests/unit/test_workspaces.py @@ -0,0 +1,207 @@ +from __future__ import annotations + +from datetime import UTC, datetime, timedelta +from pathlib import Path +from uuid import UUID, uuid4 + +import pytest + +from alicebot_api.config import Settings +from alicebot_api.contracts import TaskWorkspaceCreateInput +from alicebot_api.tasks import TaskNotFoundError +from alicebot_api.workspaces import ( + TaskWorkspaceAlreadyExistsError, + TaskWorkspaceNotFoundError, + TaskWorkspaceProvisioningError, + build_task_workspace_path, + create_task_workspace_record, + ensure_workspace_path_is_rooted, + get_task_workspace_record, + list_task_workspace_records, + serialize_task_workspace_row, +) + + +class WorkspaceStoreStub: + def __init__(self) -> None: + self.base_time = datetime(2026, 3, 13, 10, 0, tzinfo=UTC) + self.tasks: list[dict[str, object]] = [] + self.workspaces: list[dict[str, object]] = [] + self.locked_task_ids: list[UUID] = [] + + def create_task(self, *, task_id: UUID, user_id: UUID) -> None: + self.tasks.append( + { + "id": task_id, + "user_id": user_id, + "thread_id": uuid4(), + "tool_id": uuid4(), + "status": "approved", + "request": {}, + "tool": {}, + "latest_approval_id": None, + "latest_execution_id": None, + "created_at": self.base_time, + "updated_at": self.base_time, + } + ) + + def get_task_optional(self, task_id: UUID) -> dict[str, object] | None: + return next((task for task in self.tasks if task["id"] == task_id), None) + + def lock_task_workspaces(self, task_id: UUID) -> None: + self.locked_task_ids.append(task_id) + + def get_active_task_workspace_for_task_optional(self, task_id: UUID) -> dict[str, object] | None: + return next( + ( + workspace + for workspace in self.workspaces + if workspace["task_id"] == task_id and workspace["status"] == "active" + ), + None, + ) + + def create_task_workspace( + self, + *, + task_id: UUID, + status: str, + local_path: str, + ) -> dict[str, object]: + workspace = { + "id": uuid4(), + "user_id": self.tasks[0]["user_id"], + "task_id": task_id, + "status": status, + "local_path": local_path, + "created_at": self.base_time + timedelta(minutes=len(self.workspaces)), + "updated_at": self.base_time + timedelta(minutes=len(self.workspaces)), + } + self.workspaces.append(workspace) + return workspace + + def list_task_workspaces(self) -> list[dict[str, object]]: + return sorted(self.workspaces, key=lambda workspace: (workspace["created_at"], workspace["id"])) + + def get_task_workspace_optional(self, task_workspace_id: UUID) -> dict[str, object] | None: + return next((workspace for workspace in self.workspaces if workspace["id"] == task_workspace_id), None) + + +def test_build_task_workspace_path_is_deterministic() -> None: + user_id = UUID("00000000-0000-0000-0000-000000000111") + task_id = UUID("00000000-0000-0000-0000-000000000222") + root = Path("/tmp/alicebot/task-workspaces") + + path = build_task_workspace_path( + workspace_root=root, + user_id=user_id, + task_id=task_id, + ) + + assert path == Path("/tmp/alicebot/task-workspaces") / str(user_id) / str(task_id) + + +def test_ensure_workspace_path_is_rooted_rejects_escape() -> None: + with pytest.raises(TaskWorkspaceProvisioningError, match="escapes configured root"): + ensure_workspace_path_is_rooted( + workspace_root=Path("/tmp/alicebot/task-workspaces"), + workspace_path=Path("/tmp/alicebot/task-workspaces/../escape"), + ) + + +def 
test_create_task_workspace_record_provisions_directory_and_returns_record(tmp_path) -> None: + store = WorkspaceStoreStub() + user_id = uuid4() + task_id = uuid4() + store.create_task(task_id=task_id, user_id=user_id) + settings = Settings(task_workspace_root=str(tmp_path / "workspaces")) + + response = create_task_workspace_record( + store, + settings=settings, + user_id=user_id, + request=TaskWorkspaceCreateInput(task_id=task_id, status="active"), + ) + + expected_path = tmp_path / "workspaces" / str(user_id) / str(task_id) + assert response == { + "workspace": { + "id": response["workspace"]["id"], + "task_id": str(task_id), + "status": "active", + "local_path": str(expected_path.resolve()), + "created_at": "2026-03-13T10:00:00+00:00", + "updated_at": "2026-03-13T10:00:00+00:00", + } + } + assert expected_path.is_dir() + assert store.locked_task_ids == [task_id] + + +def test_create_task_workspace_record_rejects_duplicate_active_workspace(tmp_path) -> None: + store = WorkspaceStoreStub() + user_id = uuid4() + task_id = uuid4() + store.create_task(task_id=task_id, user_id=user_id) + settings = Settings(task_workspace_root=str(tmp_path / "workspaces")) + create_task_workspace_record( + store, + settings=settings, + user_id=user_id, + request=TaskWorkspaceCreateInput(task_id=task_id, status="active"), + ) + + with pytest.raises(TaskWorkspaceAlreadyExistsError, match=f"task {task_id} already has active workspace"): + create_task_workspace_record( + store, + settings=settings, + user_id=user_id, + request=TaskWorkspaceCreateInput(task_id=task_id, status="active"), + ) + + +def test_create_task_workspace_record_requires_visible_task(tmp_path) -> None: + store = WorkspaceStoreStub() + + with pytest.raises(TaskNotFoundError, match="was not found"): + create_task_workspace_record( + store, + settings=Settings(task_workspace_root=str(tmp_path / "workspaces")), + user_id=uuid4(), + request=TaskWorkspaceCreateInput(task_id=uuid4(), status="active"), + ) + + +def test_list_and_get_task_workspace_records_are_deterministic() -> None: + store = WorkspaceStoreStub() + user_id = uuid4() + task_id = uuid4() + store.create_task(task_id=task_id, user_id=user_id) + workspace = store.create_task_workspace( + task_id=task_id, + status="active", + local_path="/tmp/alicebot/task-workspaces/user/task", + ) + + assert list_task_workspace_records(store, user_id=user_id) == { + "items": [serialize_task_workspace_row(workspace)], + "summary": { + "total_count": 1, + "order": ["created_at_asc", "id_asc"], + }, + } + assert get_task_workspace_record( + store, + user_id=user_id, + task_workspace_id=workspace["id"], + ) == {"workspace": serialize_task_workspace_row(workspace)} + + +def test_get_task_workspace_record_raises_when_workspace_is_missing() -> None: + with pytest.raises(TaskWorkspaceNotFoundError, match="was not found"): + get_task_workspace_record( + WorkspaceStoreStub(), + user_id=uuid4(), + task_workspace_id=uuid4(), + ) diff --git a/tests/unit/test_workspaces_main.py b/tests/unit/test_workspaces_main.py new file mode 100644 index 0000000..b5f19ff --- /dev/null +++ b/tests/unit/test_workspaces_main.py @@ -0,0 +1,112 @@ +from __future__ import annotations + +import json +from contextlib import contextmanager +from uuid import uuid4 + +import apps.api.src.alicebot_api.main as main_module +from apps.api.src.alicebot_api.config import Settings +from alicebot_api.tasks import TaskNotFoundError +from alicebot_api.workspaces import TaskWorkspaceAlreadyExistsError, TaskWorkspaceNotFoundError + + +def 
test_list_task_workspaces_endpoint_returns_payload(monkeypatch) -> None: + user_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr( + main_module, + "list_task_workspace_records", + lambda *_args, **_kwargs: { + "items": [], + "summary": {"total_count": 0, "order": ["created_at_asc", "id_asc"]}, + }, + ) + + response = main_module.list_task_workspaces(user_id) + + assert response.status_code == 200 + assert json.loads(response.body) == { + "items": [], + "summary": {"total_count": 0, "order": ["created_at_asc", "id_asc"]}, + } + + +def test_get_task_workspace_endpoint_maps_not_found_to_404(monkeypatch) -> None: + user_id = uuid4() + task_workspace_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_get_task_workspace_record(*_args, **_kwargs): + raise TaskWorkspaceNotFoundError(f"task workspace {task_workspace_id} was not found") + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "get_task_workspace_record", fake_get_task_workspace_record) + + response = main_module.get_task_workspace(task_workspace_id, user_id) + + assert response.status_code == 404 + assert json.loads(response.body) == {"detail": f"task workspace {task_workspace_id} was not found"} + + +def test_create_task_workspace_endpoint_maps_task_not_found_to_404(monkeypatch) -> None: + task_id = uuid4() + user_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_create_task_workspace_record(*_args, **_kwargs): + raise TaskNotFoundError(f"task {task_id} was not found") + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "create_task_workspace_record", fake_create_task_workspace_record) + + response = main_module.create_task_workspace( + task_id, + main_module.CreateTaskWorkspaceRequest(user_id=user_id), + ) + + assert response.status_code == 404 + assert json.loads(response.body) == {"detail": f"task {task_id} was not found"} + + +def test_create_task_workspace_endpoint_maps_duplicate_to_409(monkeypatch) -> None: + task_id = uuid4() + user_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_create_task_workspace_record(*_args, **_kwargs): + raise TaskWorkspaceAlreadyExistsError(f"task {task_id} already has active workspace workspace-123") + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "create_task_workspace_record", fake_create_task_workspace_record) + + response = main_module.create_task_workspace( + task_id, + main_module.CreateTaskWorkspaceRequest(user_id=user_id), + ) + + assert response.status_code == 409 + assert json.loads(response.body) == { + "detail": f"task {task_id} already has active workspace workspace-123" + } diff --git 
a/workers/.gitkeep b/workers/.gitkeep new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/workers/.gitkeep @@ -0,0 +1 @@ + diff --git a/workers/alicebot_worker/__init__.py b/workers/alicebot_worker/__init__.py new file mode 100644 index 0000000..462d476 --- /dev/null +++ b/workers/alicebot_worker/__init__.py @@ -0,0 +1,2 @@ +"""Worker scaffold for future asynchronous jobs.""" + diff --git a/workers/alicebot_worker/main.py b/workers/alicebot_worker/main.py new file mode 100644 index 0000000..21d01ff --- /dev/null +++ b/workers/alicebot_worker/main.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +import logging + + +LOGGER = logging.getLogger("alicebot.worker") + + +def run() -> None: + LOGGER.info("Worker scaffold initialized; no background jobs are in scope for this sprint.") + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + run() From 8103d5e731bd14bd3660723443e770ba4000693d Mon Sep 17 00:00:00 2001 From: Sam Date: Fri, 13 Mar 2026 23:47:59 +0100 Subject: [PATCH 002/135] Sprint 5B: Project Truth Compaction 01 (#2) --- README.md | 84 +++++++++++++----------------------------- ROADMAP.md | 106 ++++++++++++----------------------------------------- RULES.md | 63 ++++++++++++------------------- 3 files changed, 74 insertions(+), 179 deletions(-) diff --git a/README.md b/README.md index 97fd320..7e25b7a 100644 --- a/README.md +++ b/README.md @@ -1,73 +1,41 @@ # AliceBot -AliceBot is a private, permissioned personal AI operating system. The repository now includes the runnable foundation slice plus the first tracing/context-compilation seam, the first governed memory/admissions-and-embeddings slice, the first deterministic response-generation seam, the first governance routing seam for non-executing tool requests, the first durable approval-request persistence seam for `approval_required` routing outcomes, the explicit approval-resolution seam, the first minimal approved-only proxy-execution seam, the first durable execution-review seam over that proxy path, the narrow execution-budget lifecycle seam over approved proxy execution, and the first deterministic task-workspace provisioning seam: local infrastructure, an API scaffold, migration tooling, continuity primitives, persisted traces, a deterministic continuity-only compiler, explicit memory admission, a narrow deterministic explicit-preference extraction path, explicit embedding-config and memory-embedding storage paths, a direct semantic memory retrieval primitive, deterministic hybrid compile-path memory merge, a no-tools model invocation path over deterministically assembled prompts, deterministic policy and tool-governance seams, a narrow no-side-effect proxy handler path, durable `tool_executions` records, durable `execution_budgets` records, durable `task_workspaces` records, execution-budget create/list/detail reads, budget deactivate/supersede lifecycle operations, active-only budget enforcement, budget-blocked execution persistence, task-workspace create/list/detail reads, and backend verification coverage. +AliceBot is a private, permissioned personal AI operating system. The current repo contains the accepted backend slice through Sprint 5A plus local developer tooling. -## Status +## Current Implemented Slice -- Local Docker Compose infrastructure is defined for Postgres with `pgvector`, Redis, and MinIO. 
-- `apps/api` contains FastAPI health, compile, response-generation, memory-admission, explicit-preference extraction, semantic-memory-retrieval, policy, tool-registry, tool-allowlist, tool-routing, approval-request, approval-resolution, proxy-execution, execution-budget, execution-review, task, and task-workspace endpoints, configuration loading, Alembic migrations, continuity storage primitives, the Sprint 2A trace/compiler path, the Sprint 3A memory-admission path, the Sprint 3I deterministic extraction path, the Sprint 3K embedding substrate, the Sprint 3L semantic retrieval primitive, the Sprint 3M compile-path semantic retrieval adoption, the Sprint 3N deterministic hybrid memory merge, the Sprint 4A deterministic prompt-assembly and no-tools response path, the Sprint 4D deterministic non-executing tool-routing seam, the Sprint 4E durable approval-request persistence seam, the Sprint 4F approval-resolution seam, the Sprint 4G minimal approved-only proxy-execution seam, the Sprint 4H durable execution-review seam, the Sprint 4I execution-budget guard seam, the Sprint 4J execution-budget lifecycle seam, the Sprint 4K time-windowed execution-budget seam, the Sprint 4S explicit execution-to-task-step linkage seam, and the Sprint 5A task-workspace provisioning seam. -- `apps/web` and `workers` contain minimal starter scaffolds for later milestone work. -- The active sprint is documented in [.ai/active/SPRINT_PACKET.md](.ai/active/SPRINT_PACKET.md). +- `apps/api` is the shipped surface. It includes continuity storage, tracing, deterministic context compilation, governed memory admission and review, embeddings, semantic retrieval, entities, policy and tool governance, approval persistence and resolution, approved-only `proxy.echo` execution, execution budgets, tasks, task steps, explicit manual continuation lineage, step-linked approval/execution synchronization, and deterministic rooted local task-workspace provisioning. +- `apps/web` and `workers` are starter scaffolds only. +- Task workspaces are currently local rooted directories plus durable records. Artifact indexing, document ingestion, connectors, and runner orchestration are not shipped. ## Quick Start 1. Create a local env file: `cp .env.example .env` -2. Start required infrastructure with one command: `docker compose up -d` -3. Create a project virtualenv and install Python dependencies: `python3 -m venv .venv && ./.venv/bin/python -m pip install -e '.[dev]'` -4. Run database migrations: `./scripts/migrate.sh` -5. Start the API locally: `./scripts/api_dev.sh` +2. Start infrastructure: `docker compose up -d` +3. Create a virtualenv and install dependencies: `python3 -m venv .venv && ./.venv/bin/python -m pip install -e '.[dev]'` +4. Apply migrations: `./scripts/migrate.sh` +5. Start the API: `./scripts/api_dev.sh` -The health endpoint is exposed at [http://127.0.0.1:8000/healthz](http://127.0.0.1:8000/healthz). -The minimal context-compilation API path is `POST /v0/context/compile`. -The minimal response-generation API path is `POST /v0/responses`. -The minimal memory-admission API path is `POST /v0/memories/admit`. -The explicit-preference extraction API path is `POST /v0/memories/extract-explicit-preferences`. -The minimal non-executing tool-routing API path is `POST /v0/tools/route`. 
-The minimal approval API paths are `POST /v0/approvals/requests`, `GET /v0/approvals`, `GET /v0/approvals/{approval_id}`, `POST /v0/approvals/{approval_id}/approve`, `POST /v0/approvals/{approval_id}/reject`, and `POST /v0/approvals/{approval_id}/execute`. -The execution-budget API paths are `POST /v0/execution-budgets`, `GET /v0/execution-budgets`, `GET /v0/execution-budgets/{execution_budget_id}`, `POST /v0/execution-budgets/{execution_budget_id}/deactivate`, and `POST /v0/execution-budgets/{execution_budget_id}/supersede`. -The execution-review API paths are `GET /v0/tool-executions` and `GET /v0/tool-executions/{execution_id}`. -The task-workspace API paths are `POST /v0/tasks/{task_id}/workspace`, `GET /v0/task-workspaces`, and `GET /v0/task-workspaces/{task_workspace_id}`. -The helper scripts load the repo-root `.env` automatically and prefer `.venv/bin/python` when that virtualenv exists, falling back to `python3` otherwise. The default migration/admin URL targets the same local `alicebot` database as the app runtime. -`/healthz` currently performs a live Postgres check only. Redis and MinIO are reported as configured endpoints with `not_checked` status. -`TASK_WORKSPACE_ROOT` controls the single rooted base directory used for deterministic local task-workspace provisioning. By default it is `/tmp/alicebot/task-workspaces`, and each workspace path is created as `//`. -The current backend path has been verified in a local developer environment with `docker compose up -d`, `./scripts/migrate.sh`, `./.venv/bin/python -m pytest tests/unit tests/integration`, a live `GET /healthz`, and the Postgres-backed `POST /v0/context/compile`, `POST /v0/responses`, `POST /v0/memories/admit`, `POST /v0/memories/extract-explicit-preferences`, `POST /v0/memories/semantic-retrieval`, `POST /v0/tools/allowlist/evaluate`, `POST /v0/tools/route`, `POST /v0/approvals/requests`, `POST /v0/approvals/{approval_id}/execute`, `POST /v0/execution-budgets`, `GET /v0/execution-budgets`, `POST /v0/execution-budgets/{execution_budget_id}/deactivate`, `POST /v0/execution-budgets/{execution_budget_id}/supersede`, `GET /v0/tool-executions`, and `GET /v0/tool-executions/{execution_id}` integration paths, including compile requests that explicitly enable the hybrid memory merge, response requests that persist assistant events and response traces, deterministic non-executing tool-routing requests that persist `tool.route.*` traces, approval-request persistence requests that persist `approval.request.*` traces plus durable approval rows only for `approval_required` outcomes, approved proxy execution that persists `tool.proxy.execute.*` traces plus durable `tool_executions` rows for approved execution attempts, deterministic budget-management requests over durable `execution_budgets` rows, lifecycle requests that persist `execution_budget.lifecycle.*` traces and change budget status deterministically, budget-prechecked proxy execution that emits `tool.proxy.execute.budget` trace events against active budgets only, and execution-review reads over those durable records including budget-blocked attempts. +Useful checks: -## Repo Structure +- API health: [http://127.0.0.1:8000/healthz](http://127.0.0.1:8000/healthz) +- Full backend tests: `./.venv/bin/python -m pytest tests/unit tests/integration` +- Web shell: `pnpm --dir apps/web dev` -- [PRODUCT_BRIEF.md](PRODUCT_BRIEF.md): permanent product truth. -- [ARCHITECTURE.md](ARCHITECTURE.md): permanent technical truth. 
-- [ROADMAP.md](ROADMAP.md): milestone sequence and delivery risks. -- [RULES.md](RULES.md): durable engineering and scope rules. -- [.ai/handoff/CURRENT_STATE.md](.ai/handoff/CURRENT_STATE.md): fresh-thread recovery snapshot. -- [.ai/active/SPRINT_PACKET.md](.ai/active/SPRINT_PACKET.md): current builder sprint. -- `docker-compose.yml`: local Postgres, Redis, and MinIO stack. -- `infra/postgres/init/`: Postgres bootstrap SQL, including the non-superuser app role. -- `apps/api/`: FastAPI app, config, continuity store, and Alembic migrations. -- `apps/web/`: minimal Next.js shell for later dashboard work. -- `workers/`: placeholder Python worker package for future background jobs. -- `tests/`: unit and Postgres-backed integration tests for the foundation slice. -- `scripts/`: local development and migration entrypoints. - -## Essential Commands +## Repo Map -- `docker compose up -d`: start Postgres, Redis, and MinIO on `127.0.0.1`. -- `./scripts/dev_up.sh`: start local infrastructure, wait for Postgres and role bootstrap readiness, and apply Alembic migrations. -- `./scripts/migrate.sh`: apply Alembic migrations with the admin database URL from `.env` or the built-in defaults. -- `./scripts/api_dev.sh`: run the FastAPI service with auto-reload. -- `./.venv/bin/python -m pytest tests/unit tests/integration`: run backend tests from the project virtualenv. -- `pnpm --dir apps/web dev`: start the web shell after frontend dependencies are installed. +- [PRODUCT_BRIEF.md](PRODUCT_BRIEF.md): stable product scope and ship gates. +- [ARCHITECTURE.md](ARCHITECTURE.md): implemented technical boundaries and planned-later boundaries. +- [ROADMAP.md](ROADMAP.md): forward-looking milestone direction from the current repo position. +- [RULES.md](RULES.md): durable engineering and scope rules. +- [.ai/handoff/CURRENT_STATE.md](.ai/handoff/CURRENT_STATE.md): compact current-state recovery snapshot. +- [.ai/active/SPRINT_PACKET.md](.ai/active/SPRINT_PACKET.md): active builder scope. +- [docs/archive/sprints](docs/archive/sprints): archived sprint build and review history. ## Environment Notes -- Postgres is the system of record and the live schema now includes continuity tables, trace tables, policy-governance tables including `approvals`, `tool_executions`, and `execution_budgets`, task lifecycle tables including `tasks`, `task_steps`, and `task_workspaces`, memory tables, entity tables, and the embedding substrate tables `embedding_configs` and `memory_embeddings`. -- Sprint 2A adds persisted `traces` and `trace_events` plus a deterministic continuity-only context compiler over existing durable continuity records. -- Sprint 3A adds governed `memories` and append-only `memory_revisions` plus an explicit `NOOP`-first admission path over cited source events. -- The app and migration defaults both target the local `alicebot` database to keep quick-start behavior deterministic. -- `TASK_WORKSPACE_ROOT` defaults to `/tmp/alicebot/task-workspaces` and defines the only allowed root for deterministic local task-workspace provisioning. -- Local service ports are bound to `127.0.0.1` by default to avoid exposing fixed development credentials on non-loopback interfaces. -- Redis is reserved for future queue, lock, and cache work; no retrieval or orchestration features are enabled in this sprint. -- MinIO provides the local S3-compatible endpoint for future document and artifact storage. 
-- Continuity tables enforce row-level security from the start and `events` are append-only by application contract plus database trigger, with concurrent appends serialized per thread. -- Trace tables follow the same per-user isolation model, with append-only `trace_events` for compiler explainability. -- Memory admission remains explicit and evidence-backed, automatic extraction is currently limited to a narrow deterministic explicit-preference path over stored user messages, and the repo now includes explicit versioned embedding-config storage, direct memory-embedding persistence, a direct semantic retrieval API over active durable memories, compile-path hybrid memory merge into one `context_pack["memories"]` section with `memory_summary.hybrid_retrieval` metadata, one deterministic no-tools response path that assembles prompts from durable compiled context and persists assistant replies plus response traces, one deterministic approval-request persistence path over `approval_required` tool-routing outcomes, explicit approval resolution, one minimal approved-only proxy execution path through the no-side-effect `proxy.echo` handler, durable execution-review records plus list/detail reads for approved execution attempts, one narrow deterministic execution-budget seam that can activate, deactivate, supersede, and enforce both lifetime and rolling-window limits using durable `tool_executions` history while keeping blocked attempts reviewable, and one narrow deterministic task-workspace seam that provisions rooted local workspace directories and persists durable `task_workspaces` rows. Broader extraction, reranking, external-connector tool execution, artifact indexing, document ingestion, orchestration, and review UI remain deferred. -- The runtime database role is limited to `SELECT`/`INSERT` on continuity and trace tables, `SELECT`/`INSERT` on `memory_revisions`, `memory_review_labels`, `embedding_configs`, `entities`, and `entity_edges`, plus `SELECT`/`INSERT`/`UPDATE` on `consents`, `memories`, `memory_embeddings`, and `execution_budgets`. +- Postgres is the system of record. +- Local Docker Compose includes Postgres with `pgvector`, Redis, and MinIO. +- The helper scripts source the repo-root `.env` and prefer `.venv/bin/python` when present. +- `TASK_WORKSPACE_ROOT` defaults to `/tmp/alicebot/task-workspaces` and is the only allowed root for task-workspace provisioning. +- `/healthz` performs a live Postgres check; Redis and MinIO are reported as configured but not live-checked. diff --git a/ROADMAP.md b/ROADMAP.md index ed7c2ba..a963227 100644 --- a/ROADMAP.md +++ b/ROADMAP.md @@ -1,97 +1,39 @@ # Roadmap -## Current State +## Current Position -- The repo has shipped the implementation slices originally planned as Milestones 1 through 4. -- Sprint 4O added the latest accepted backend seam: durable `tasks` and `task_steps` with explicit manual continuation lineage and deterministic task-step transitions. -- The project is no longer at Foundation. The current repo state is a post-Milestone-4 checkpoint, and this sprint is synchronizing project-truth docs before Milestone 5 work begins. -- No task runner, workspace/artifact layer, document ingestion, read-only connector, or broader side-effect surface has landed yet. +- The accepted repo state is current through Sprint 5A. 
+- The backend foundation through governance, execution review, task/task-step lifecycle, explicit manual continuation, step-linked approval/execution synchronization, and deterministic rooted task-workspace provisioning is already shipped. +- This roadmap is future-facing from that position; milestone history lives in archived sprint reports, not here. -## Completed Milestones +## Next Delivery Focus -### Milestone 1: Foundation +### Finish Milestone 5 On Top Of The Shipped Workspace Boundary -- Repo scaffold, local Docker Compose infra, FastAPI app shell, config loading, migration tooling, and backend test harness. -- Postgres continuity primitives: `users`, `threads`, `sessions`, and append-only `events`. -- Row-level-security foundation and concurrent event sequencing hardening. +- Add artifact records and artifact-handling rules that reuse `task_workspaces` instead of inventing a parallel storage seam. +- Add document ingestion and retrieval only after the artifact/workspace boundary is explicit and reviewable. +- Add read-only Gmail and Calendar connectors only after document and workspace boundaries remain deterministic under the current governance model. -Status on March 13, 2026: -- Complete. +### Preserve Current Governance And Task Guarantees -### Milestone 2: Context Compiler and Tracing +- Keep approvals, execution budgets, task/task-step state, and trace visibility deterministic as new Milestone 5 work lands. +- Do not widen the current no-external-I/O proxy surface or introduce new consequential side effects without an explicit sprint opening that scope. -- Deterministic context compilation over durable continuity records. -- Persisted `traces` and append-only `trace_events`. -- Trace-visible inclusion and exclusion reasoning for compiled context. +## After Milestone 5 -Status on March 13, 2026: -- Complete. - -### Milestone 3: Memory and Retrieval - -- Governed memory admission with append-only revisions. -- Narrow deterministic explicit-preference extraction from stored user events. -- Memory review labels, review queue reads, and evaluation summary reads. -- Explicit entities and temporal entity edges backed by cited memories. -- Versioned embedding configs, durable memory embeddings, direct semantic retrieval, and deterministic hybrid compile-path memory merge. - -Status on March 13, 2026: -- Complete. - -### Milestone 4: Governance and Safe Action - -- Deterministic response generation over compiled context. -- User-scoped consents, policies, policy evaluation, tool registry, allowlist evaluation, and tool routing. -- Durable approval requests and explicit approval resolution. -- Approved-only proxy execution through the in-process `proxy.echo` handler. -- Durable execution review, execution-budget enforcement, lifecycle mutations, and optional rolling-window limits. -- Durable `tasks` and `task_steps`, deterministic task-step reads, explicit task-step transitions, and explicit manual continuation with lineage. - -Status on March 13, 2026: -- Complete through Sprint 4O. - -## Current Milestone Position - -- The repo is at the boundary after Milestone 4. -- Milestone 5 has not started in shipped code yet. -- The immediate work is documentation synchronization and narrow lifecycle-boundary hardening so Milestone 5 planning and review start from truthful artifacts. - -## Next Milestones - -### Immediate Next Narrow Boundary - -- Preserve the current manual-continuation seam as the only shipped multi-step task path. 
-- Remove or explicitly constrain the remaining approval/execution helpers that still synchronize against `task_steps.sequence_no = 1` before starting runner-style orchestration or workspace-heavy task flows. - -### Milestone 5: Documents, Workspaces, and Read-Only Connectors - -- Add document ingestion and chunk retrieval. -- Add scoped task workspaces and artifact handling. -- Add read-only Gmail and Calendar sync. -- Keep connector scope read-only and approval-aware. - -### Sequencing After Milestone 5 - -- Generalize task lifecycle handling beyond the current manual continuation seam. -- Introduce runner-style orchestration only after the first-step lifecycle assumption is removed. -- Expand tool execution breadth only after the governance and task seams stay deterministic under multi-step flows. +- Revisit broader task orchestration only after the current explicit task-step seams remain stable under workspace, artifact, and document flows. +- Expand tool execution breadth only after governance, review, and budget controls still hold under the wider task surface. +- Address production-facing auth and deployment hardening as the product approaches broader real-world use. ## Dependencies -- Truth artifacts must stay synchronized before milestone planning and review work can be trusted. -- The current first-step lifecycle assumption must be resolved before broader runner or workspace work can safely depend on `tasks` / `task_steps`. -- Scoped workspace and artifact boundaries should land before document-heavy or connector-heavy flows rely on them. -- Connector scope should remain deferred until the core memory, governance, and task seams stay stable under the shipped workload. - -## Blockers and Risks - -- Memory extraction and retrieval quality remain the biggest product risk. -- Auth beyond DB user context is still unimplemented. -- The remaining first-step approval/execution synchronization helpers are a forward-compatibility risk for broader multi-step orchestration. -- Workspace or connector work could create hidden scope drift if it starts before the current task-lifecycle boundary is hardened. +- Live truth docs must stay synchronized with accepted repo state so sprint planning does not start from stale assumptions. +- Artifact and document work should build on the existing rooted local workspace contract. +- Connector work should remain read-only and approval-aware. +- Runner-style orchestration should stay deferred until the repo no longer depends on narrow current-step assumptions for safety and explainability. -## Recently Completed +## Ongoing Risks -- Durable approval, execution review, and execution-budget seams over the approved proxy path. -- Durable `tasks` and `task_steps` with deterministic reads and status transitions. -- Explicit task-step lineage and manual continuation, including adversarial validation for cross-task, cross-user, and parent-step mismatch cases. +- Memory extraction and retrieval quality remain the largest product risk. +- Auth beyond database user context is still missing. +- Milestone 5 can drift if artifact, document, connector, and orchestration work are mixed into one sprint instead of landing as narrow seams. diff --git a/RULES.md b/RULES.md index f6ac44b..6c02e05 100644 --- a/RULES.md +++ b/RULES.md @@ -1,52 +1,37 @@ # Rules -## Product / Scope Rules +## Truth And Scope + +- The active sprint packet is the top scope boundary for implementation work. +- Never describe planned behavior as already implemented. 
+- Keep canonical truth files concise, current, and durable. +- Archive stale planning or history material instead of deleting it when traceability still matters. +- Do not widen product scope without an explicit roadmap or sprint change. + +## Product And Safety -- The active sprint packet is the top priority scope boundary for implementation work and overrides broader roadmap intent when they conflict. -- Never represent planned architecture as implemented behavior in docs, handoffs, or build reports. - Never execute a consequential external action without explicit user approval. -- Always treat explainability as a product feature, not an internal debugging aid. +- Treat explainability as a product feature, not an internal debugging aid. - Treat the repeat magnesium reorder as the v1 ship-gate scenario. -- Never expand v1 scope with proactive automation, write-capable connectors, voice, or browser automation without an explicit roadmap change. -- Do not start runner, workspace/artifact, document-ingestion, or connector work unless the active sprint explicitly opens that boundary. +- Do not add proactive automation, write-capable connectors, voice, or browser automation without an explicit roadmap change. -## Architecture Rules +## Architecture And Data -- Treat the immutable event store as ground truth; memories, tasks, and summaries are derived or governed views over durable records. +- Treat the immutable event store as ground truth; downstream memories, tasks, and summaries are derived or governed views. - Always compile context per invocation from durable sources. -- Keep prompt prefixes, tool schemas, and serialized context ordering deterministic. -- Treat Postgres as the v1 system of record unless measured constraints justify a platform split. -- Appended task steps must carry explicit lineage to a prior visible task step. Do not relink approvals or executions heuristically from broader task history. -- Manual continuation is the current multi-step boundary. Until the older first-step lifecycle helpers are removed or constrained, do not describe broader automatic multi-step orchestration as implemented. - -## Coding Rules - -- Always build against typed contracts and migration-backed schemas first. -- Never mutate tool schemas mid-session; enforce access through policy and proxy layers. -- Keep changes small, module-scoped, and test-backed. -- Stop long-running tasks with a clear progress summary when budgets or circuit breakers trip. -- Sprint-scoped docs must clearly separate what exists now from what is only planned later. - -## Data / Schema Rules - -- Enforce row-level security on every user-owned table from the start. -- Default memory admission to `NOOP`; promote only evidence-backed changes. -- Always keep memory revision history for non-`NOOP` changes. -- Task-step lineage references must stay inside the current user scope and must validate against the intended parent step and its recorded outcome. +- Keep prompt assembly, tool schemas, and serialized context ordering deterministic. +- Treat Postgres as the v1 system of record unless measured constraints justify a change. +- Task-step lineage and execution linkage must stay explicit; do not reconstruct them heuristically from broader task history. +- Enforce row-level security on every user-owned table. +- Default memory admission to `NOOP`; promote only evidence-backed changes and preserve revision history for non-`NOOP` updates. - Apply domain and sensitivity filters before semantic retrieval. 
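
As a concrete illustration of the row-level-security rule above, every user-owned table ships with an owner policy keyed to `app.current_user_id()`. The sketch below mirrors the statements the Sprint 5C migration later in this series applies to `task_artifacts`; the table name is simply the nearest example in this patch set, not a template to reuse verbatim.

```sql
-- Enable row-level security and FORCE it so even the table owner is scoped.
ALTER TABLE task_artifacts ENABLE ROW LEVEL SECURITY;
ALTER TABLE task_artifacts FORCE ROW LEVEL SECURITY;

-- Restrict both reads (USING) and writes (WITH CHECK) to rows owned by the
-- user bound to the current request via app.current_user_id().
CREATE POLICY task_artifacts_is_owner ON task_artifacts
    USING (user_id = app.current_user_id())
    WITH CHECK (user_id = app.current_user_id());
```

The `FORCE` clause keeps the policy in effect for the table owner as well, so the runtime role never sees or writes rows outside the current request's user scope.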
-## Deployment / Ops Rules - -- Keep v1 operations simple: one modular monolith, one primary database, one cache, one object store. -- Never store secrets in source control, committed config, or logs. -- Any repo-advertised bootstrap script that starts dependencies and then runs dependent commands must wait for service readiness before proceeding. -- When external side effects are introduced, route them through approval-aware tool execution paths. -- Backups and object versioning are required before production use. - -## Testing Rules +## Delivery And Testing +- Build against typed contracts and migration-backed schemas first. +- Keep changes small, module-scoped, and test-backed. +- Never bypass policy, approval, or proxy boundaries to introduce side effects. - Schema changes are not complete without forward and rollback coverage. - Every module needs unit tests and at least one integration boundary test. -- Approval boundaries, RLS isolation, and audit logging require adversarial tests. -- Lineage changes require adversarial tests for cross-task, cross-user, and parent-step mismatch cases. -- Memory quality and retrieval quality need labeled evaluations before release claims. +- Approval boundaries, row-level security, audit logging, and lineage changes require adversarial tests. +- Do not make memory-quality or retrieval-quality release claims without labeled evaluation evidence. From 389e22bc3b1f28e6c22c29d37796e3e949ee17af Mon Sep 17 00:00:00 2001 From: Sam Date: Sat, 14 Mar 2026 09:53:29 +0100 Subject: [PATCH 003/135] Sprint 5C: task artifact records and registration (#3) Co-authored-by: Sami Rusani --- ARCHITECTURE.md | 39 ++- .../versions/20260313_0023_task_artifacts.py | 87 ++++++ apps/api/src/alicebot_api/artifacts.py | 173 ++++++++++ apps/api/src/alicebot_api/contracts.py | 42 +++ apps/api/src/alicebot_api/main.py | 82 +++++ apps/api/src/alicebot_api/store.py | 145 +++++++++ tests/integration/test_migrations.py | 10 + tests/integration/test_task_artifacts_api.py | 293 +++++++++++++++++ .../unit/test_20260313_0023_task_artifacts.py | 46 +++ tests/unit/test_artifacts.py | 295 ++++++++++++++++++ tests/unit/test_artifacts_main.py | 154 +++++++++ tests/unit/test_main.py | 3 + tests/unit/test_task_artifact_store.py | 228 ++++++++++++++ 13 files changed, 1586 insertions(+), 11 deletions(-) create mode 100644 apps/api/alembic/versions/20260313_0023_task_artifacts.py create mode 100644 apps/api/src/alicebot_api/artifacts.py create mode 100644 tests/integration/test_task_artifacts_api.py create mode 100644 tests/unit/test_20260313_0023_task_artifacts.py create mode 100644 tests/unit/test_artifacts.py create mode 100644 tests/unit/test_artifacts_main.py create mode 100644 tests/unit/test_task_artifact_store.py diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md index 9cefafb..f097b55 100644 --- a/ARCHITECTURE.md +++ b/ARCHITECTURE.md @@ -2,16 +2,16 @@ ## Current Implemented Slice -AliceBot now implements the accepted repo slice through Sprint 5A. The shipped backend includes: +AliceBot now implements the accepted repo slice through Sprint 5C. 
The shipped backend includes: - foundation continuity storage over `users`, `threads`, `sessions`, and append-only `events` - deterministic tracing and context compilation over durable continuity, memory, entity, and entity-edge records - governed memory admission, explicit-preference extraction, memory review labels, review queue reads, evaluation summary reads, explicit embedding config and memory-embedding storage, direct semantic retrieval, and deterministic hybrid compile-path memory merge - deterministic prompt assembly and one no-tools response path that persists assistant replies as immutable continuity events - user-scoped consents, policies, policy evaluation, tool registry, allowlist evaluation, tool routing, approval request persistence, approval resolution, approved-only proxy execution through the in-process `proxy.echo` handler, durable execution review, and execution-budget lifecycle plus enforcement -- durable `tasks`, `task_steps`, and `task_workspaces`, deterministic task-step sequencing, explicit task-step transitions, explicit manual continuation with lineage through `parent_step_id`, `source_approval_id`, and `source_execution_id`, explicit `tool_executions.task_step_id` linkage for execution synchronization, and deterministic rooted local task-workspace provisioning +- durable `tasks`, `task_steps`, `task_workspaces`, and `task_artifacts`, deterministic task-step sequencing, explicit task-step transitions, explicit manual continuation with lineage through `parent_step_id`, `source_approval_id`, and `source_execution_id`, explicit `tool_executions.task_step_id` linkage for execution synchronization, deterministic rooted local task-workspace provisioning, and explicit rooted local artifact registration plus deterministic artifact reads -The current multi-step boundary is narrow and explicit. Manual continuation is implemented and review-passed. Approval resolution and proxy execution now both use explicit task-step linkage rather than first-step inference. Task workspaces are now implemented only as deterministic rooted local boundaries. Broader runner-style orchestration, automatic multi-step progression, artifact indexing, document ingestion, connectors, and new side-effect surfaces are still planned later and must not be described as live behavior. +The current multi-step boundary is narrow and explicit. Manual continuation is implemented and review-passed. Approval resolution and proxy execution now both use explicit task-step linkage rather than first-step inference. Task workspaces are now implemented only as deterministic rooted local boundaries, and task artifacts are now implemented only as explicit rooted local-file registrations under those workspaces. Broader runner-style orchestration, automatic multi-step progression, artifact indexing, document ingestion, connectors, and new side-effect surfaces are still planned later and must not be described as live behavior. ## Implemented Now @@ -24,7 +24,7 @@ The current multi-step boundary is narrow and explicit. 
Manual continuation is i - memory and retrieval: `POST /v0/memories/admit`, `POST /v0/memories/extract-explicit-preferences`, `GET /v0/memories`, `GET /v0/memories/review-queue`, `GET /v0/memories/evaluation-summary`, `POST /v0/memories/semantic-retrieval`, `GET /v0/memories/{memory_id}`, `GET /v0/memories/{memory_id}/revisions`, `POST /v0/memories/{memory_id}/labels`, `GET /v0/memories/{memory_id}/labels` - embeddings and graph seams: `POST /v0/embedding-configs`, `GET /v0/embedding-configs`, `POST /v0/memory-embeddings`, `GET /v0/memories/{memory_id}/embeddings`, `GET /v0/memory-embeddings/{memory_embedding_id}`, `POST /v0/entities`, `GET /v0/entities`, `GET /v0/entities/{entity_id}`, `POST /v0/entity-edges`, `GET /v0/entities/{entity_id}/edges` - governance: `POST /v0/consents`, `GET /v0/consents`, `POST /v0/policies`, `GET /v0/policies`, `GET /v0/policies/{policy_id}`, `POST /v0/policies/evaluate`, `POST /v0/tools`, `GET /v0/tools`, `GET /v0/tools/{tool_id}`, `POST /v0/tools/allowlist/evaluate`, `POST /v0/tools/route`, `POST /v0/approvals/requests`, `GET /v0/approvals`, `GET /v0/approvals/{approval_id}`, `POST /v0/approvals/{approval_id}/approve`, `POST /v0/approvals/{approval_id}/reject`, `POST /v0/approvals/{approval_id}/execute` - - task and execution review: `GET /v0/tasks`, `GET /v0/tasks/{task_id}`, `POST /v0/tasks/{task_id}/workspace`, `GET /v0/task-workspaces`, `GET /v0/task-workspaces/{task_workspace_id}`, `GET /v0/tasks/{task_id}/steps`, `GET /v0/task-steps/{task_step_id}`, `POST /v0/tasks/{task_id}/steps`, `POST /v0/task-steps/{task_step_id}/transition`, `POST /v0/execution-budgets`, `GET /v0/execution-budgets`, `GET /v0/execution-budgets/{execution_budget_id}`, `POST /v0/execution-budgets/{execution_budget_id}/deactivate`, `POST /v0/execution-budgets/{execution_budget_id}/supersede`, `GET /v0/tool-executions`, `GET /v0/tool-executions/{execution_id}` + - task and execution review: `GET /v0/tasks`, `GET /v0/tasks/{task_id}`, `POST /v0/tasks/{task_id}/workspace`, `GET /v0/task-workspaces`, `GET /v0/task-workspaces/{task_workspace_id}`, `POST /v0/task-workspaces/{task_workspace_id}/artifacts`, `GET /v0/task-artifacts`, `GET /v0/task-artifacts/{task_artifact_id}`, `GET /v0/tasks/{task_id}/steps`, `GET /v0/task-steps/{task_step_id}`, `POST /v0/tasks/{task_id}/steps`, `POST /v0/task-steps/{task_step_id}/transition`, `POST /v0/execution-budgets`, `GET /v0/execution-budgets`, `GET /v0/execution-budgets/{execution_budget_id}`, `POST /v0/execution-budgets/{execution_budget_id}/deactivate`, `POST /v0/execution-budgets/{execution_budget_id}/supersede`, `GET /v0/tool-executions`, `GET /v0/tool-executions/{execution_id}` - `apps/web` and `workers` remain starter shells only. ### Data Foundation @@ -37,7 +37,7 @@ The current multi-step boundary is narrow and explicit. Manual continuation is i - memory and retrieval tables: `memories`, `memory_revisions`, `memory_review_labels`, `embedding_configs`, `memory_embeddings` - graph tables: `entities`, `entity_edges` - governance tables: `consents`, `policies`, `tools`, `approvals`, `tool_executions`, `execution_budgets` - - task lifecycle tables: `tasks`, `task_steps`, `task_workspaces` + - task lifecycle tables: `tasks`, `task_steps`, `task_workspaces`, `task_artifacts` - `events`, `trace_events`, and `memory_revisions` are append-only by application contract and database enforcement. - `memory_review_labels` are append-only by database enforcement. 
- `tasks` are explicit user-scoped lifecycle records keyed to one thread and one tool, with durable request/tool snapshots, status in `pending_approval | approved | executed | denied | blocked`, and latest approval/execution pointers for the current narrow lifecycle seam. @@ -49,17 +49,18 @@ The current multi-step boundary is narrow and explicit. Manual continuation is i - Lineage fields are guarded by composite user-scoped foreign keys and a self-reference check so a step cannot cite itself as its parent. - `tool_executions` now persist an explicit `task_step_id` linked by a composite foreign key to `task_steps(id, user_id)`. - `task_workspaces` persist one active workspace record per visible task and user, store a deterministic `local_path`, and enforce that active uniqueness through a partial unique index on `(user_id, task_id)`. +- `task_artifacts` persist explicit user-scoped artifact rows linked to both `tasks` and `task_workspaces`, store `status = registered`, `ingestion_status = pending`, store only a workspace-relative `relative_path` plus optional `media_type_hint`, and enforce deterministic duplicate rejection through a unique index on `(user_id, task_workspace_id, relative_path)`. - `execution_budgets` enforce at most one active budget per `(user_id, tool_key, domain_hint)` selector scope through a partial unique index. - Per-request user context is set in the database through `app.current_user_id()`. - `TASK_WORKSPACE_ROOT` defines the only allowed base directory for workspace provisioning, and the live path rule is `resolved_root / user_id / task_id`. ### Repo Boundaries In This Slice -- `apps/api`: implemented API, store, contracts, service logic, and migrations for continuity, tracing, memory, embeddings, entities, policies, tools, approvals, proxy execution, execution budgets, tasks, task steps, and task workspaces. +- `apps/api`: implemented API, store, contracts, service logic, and migrations for continuity, tracing, memory, embeddings, entities, policies, tools, approvals, proxy execution, execution budgets, tasks, task steps, task workspaces, and task artifacts. - `apps/web`: minimal shell only; no shipped workflow UI. - `workers`: scaffold only; no background jobs or runner logic are implemented. - `infra`: local development bootstrap assets only. -- `tests`: unit and Postgres-backed integration coverage for the shipped seams above, including Sprint 4O task-step lineage/manual continuation, Sprint 4S step-linked execution synchronization, and Sprint 5A task-workspace provisioning. +- `tests`: unit and Postgres-backed integration coverage for the shipped seams above, including Sprint 4O task-step lineage/manual continuation, Sprint 4S step-linked execution synchronization, Sprint 5A task-workspace provisioning, and Sprint 5C task-artifact registration. ## Core Flows Implemented Now @@ -162,12 +163,24 @@ The current multi-step boundary is narrow and explicit. Manual continuation is i 8. `GET /v0/task-workspaces` lists visible workspaces in deterministic `created_at ASC, id ASC` order. 9. `GET /v0/task-workspaces/{task_workspace_id}` returns one user-visible workspace detail record. +### Task Artifact Registration And Reads + +1. Accept a user-scoped `POST /v0/task-workspaces/{task_workspace_id}/artifacts` request for one visible workspace. +2. Resolve the provided local file path and require it to exist as a regular file. +3. Resolve the persisted workspace `local_path` and reject registration if the file path escapes that rooted workspace boundary. +4. 
Persist only the workspace-relative POSIX path plus optional `media_type_hint`; no absolute artifact path is stored. +5. Lock registration for the target workspace before checking for an existing artifact with the same relative path. +6. Reject duplicate registration for the same visible `(task_workspace_id, relative_path)` deterministically. +7. Persist one `task_artifacts` row with `status = registered` and `ingestion_status = pending`. +8. `GET /v0/task-artifacts` lists visible artifact rows in deterministic `created_at ASC, id ASC` order. +9. `GET /v0/task-artifacts/{task_artifact_id}` returns one user-visible artifact detail record. + ## Security Model Implemented Now -- User-owned continuity, trace, memory, embedding, entity, governance, task, task-step, and task-workspace tables enforce row-level security. +- User-owned continuity, trace, memory, embedding, entity, governance, task, task-step, task-workspace, and task-artifact tables enforce row-level security. - The runtime role is limited to the narrow `SELECT` / `INSERT` / `UPDATE` permissions required by the shipped seams; there is no broad DDL or unrestricted table access at runtime. - Cross-user references are constrained through composite foreign keys on `(id, user_id)` where the schema needs ownership-linked joins. -- Approval, execution, memory, entity, task/task-step, and task-workspace reads all operate only inside the current user scope. +- Approval, execution, memory, entity, task/task-step, task-workspace, and task-artifact reads all operate only inside the current user scope. - Task-step manual continuation adds both schema-level and service-level lineage protection: - schema-level: user-scoped foreign keys and parent-not-self check - service-level: same-task, latest-step, visible-approval, visible-execution, and parent-outcome-match validation @@ -176,7 +189,7 @@ The current multi-step boundary is narrow and explicit. Manual continuation is i ## Testing Coverage Implemented Now - Unit and integration tests cover continuity, compiler, response generation, memory admission, review labels, review queue, embeddings, semantic retrieval, entities, policies, tools, approvals, proxy execution, execution budgets, and execution review. -- Sprint 4O, Sprint 4S, and Sprint 5A added explicit task lifecycle coverage: +- Sprint 4O, Sprint 4S, Sprint 5A, and Sprint 5C added explicit task lifecycle coverage: - migrations for `tasks`, `task_steps`, and task-step lineage - staged/backfilled migration coverage for `tool_executions.task_step_id` - task and task-step store contracts @@ -189,6 +202,10 @@ The current multi-step boundary is narrow and explicit. Manual continuation is i - workspace create/list/detail response shape - duplicate active workspace rejection - task-workspace per-user isolation + - artifact register/list/detail response shape + - rooted artifact-path enforcement beneath the persisted workspace path + - duplicate artifact registration rejection for the same workspace-relative path + - task-artifact per-user isolation - trace visibility for continuation and transition events - user isolation for task and task-step reads and mutations - adversarial lineage validation for cross-task, cross-user, and parent-step mismatch cases @@ -198,7 +215,7 @@ The current multi-step boundary is narrow and explicit. 
Manual continuation is i The following areas remain planned later and must not be described as implemented: - runner-style orchestration and automatic multi-step progression beyond the current explicit manual continuation seam -- artifact storage, artifact indexing, and document ingestion beyond the current rooted local workspace boundary +- artifact indexing, artifact content processing, and document ingestion beyond the current explicit rooted local registration boundary - read-only Gmail and Calendar connectors - broader tool proxying and real-world side effects beyond the current no-I/O `proxy.echo` handler - model-driven extraction, reranking, and broader memory review automation diff --git a/apps/api/alembic/versions/20260313_0023_task_artifacts.py b/apps/api/alembic/versions/20260313_0023_task_artifacts.py new file mode 100644 index 0000000..a3d32c1 --- /dev/null +++ b/apps/api/alembic/versions/20260313_0023_task_artifacts.py @@ -0,0 +1,87 @@ +"""Add user-scoped task artifact records.""" + +from __future__ import annotations + +from alembic import op + + +revision = "20260313_0023" +down_revision = "20260313_0022" +branch_labels = None +depends_on = None + +_RLS_TABLES = ("task_artifacts",) + +_UPGRADE_SCHEMA_STATEMENT = """ + CREATE TABLE task_artifacts ( + id uuid PRIMARY KEY DEFAULT gen_random_uuid(), + user_id uuid NOT NULL REFERENCES users(id) ON DELETE CASCADE, + task_id uuid NOT NULL, + task_workspace_id uuid NOT NULL, + status text NOT NULL, + ingestion_status text NOT NULL, + relative_path text NOT NULL, + media_type_hint text, + created_at timestamptz NOT NULL DEFAULT now(), + updated_at timestamptz NOT NULL DEFAULT now(), + UNIQUE (id, user_id), + CONSTRAINT task_artifacts_task_user_fk + FOREIGN KEY (task_id, user_id) + REFERENCES tasks(id, user_id) + ON DELETE CASCADE, + CONSTRAINT task_artifacts_workspace_user_fk + FOREIGN KEY (task_workspace_id, user_id) + REFERENCES task_workspaces(id, user_id) + ON DELETE CASCADE, + CONSTRAINT task_artifacts_status_check + CHECK (status IN ('registered')), + CONSTRAINT task_artifacts_ingestion_status_check + CHECK (ingestion_status IN ('pending')), + CONSTRAINT task_artifacts_relative_path_nonempty_check + CHECK (length(relative_path) > 0), + CONSTRAINT task_artifacts_media_type_hint_nonempty_check + CHECK (media_type_hint IS NULL OR length(media_type_hint) > 0) + ); + + CREATE INDEX task_artifacts_user_created_idx + ON task_artifacts (user_id, created_at, id); + + CREATE UNIQUE INDEX task_artifacts_workspace_relative_path_idx + ON task_artifacts (user_id, task_workspace_id, relative_path); + """ + +_UPGRADE_GRANT_STATEMENTS = ( + "GRANT SELECT, INSERT ON task_artifacts TO alicebot_app", +) + +_UPGRADE_POLICY_STATEMENT = """ + CREATE POLICY task_artifacts_is_owner ON task_artifacts + USING (user_id = app.current_user_id()) + WITH CHECK (user_id = app.current_user_id()); + """ + +_DOWNGRADE_STATEMENTS = ( + "DROP TABLE IF EXISTS task_artifacts", +) + + +def _execute_statements(statements: tuple[str, ...]) -> None: + for statement in statements: + op.execute(statement) + + +def _enable_row_level_security() -> None: + for table_name in _RLS_TABLES: + op.execute(f"ALTER TABLE {table_name} ENABLE ROW LEVEL SECURITY") + op.execute(f"ALTER TABLE {table_name} FORCE ROW LEVEL SECURITY") + + +def upgrade() -> None: + op.execute(_UPGRADE_SCHEMA_STATEMENT) + _execute_statements(_UPGRADE_GRANT_STATEMENTS) + _enable_row_level_security() + op.execute(_UPGRADE_POLICY_STATEMENT) + + +def downgrade() -> None: + 
_execute_statements(_DOWNGRADE_STATEMENTS) diff --git a/apps/api/src/alicebot_api/artifacts.py b/apps/api/src/alicebot_api/artifacts.py new file mode 100644 index 0000000..c91f77c --- /dev/null +++ b/apps/api/src/alicebot_api/artifacts.py @@ -0,0 +1,173 @@ +from __future__ import annotations + +from pathlib import Path +from typing import cast +from uuid import UUID + +import psycopg + +from alicebot_api.contracts import ( + TASK_ARTIFACT_LIST_ORDER, + TaskArtifactCreateResponse, + TaskArtifactDetailResponse, + TaskArtifactListResponse, + TaskArtifactRecord, + TaskArtifactRegisterInput, + TaskArtifactStatus, + TaskArtifactIngestionStatus, +) +from alicebot_api.store import ContinuityStore, TaskArtifactRow +from alicebot_api.workspaces import TaskWorkspaceNotFoundError + + +class TaskArtifactNotFoundError(LookupError): + """Raised when a task artifact is not visible inside the current user scope.""" + + +class TaskArtifactAlreadyExistsError(RuntimeError): + """Raised when the same workspace-relative artifact path is registered twice.""" + + +class TaskArtifactValidationError(ValueError): + """Raised when a local artifact path cannot satisfy registration constraints.""" + + +def resolve_artifact_path(local_path: str) -> Path: + return Path(local_path).expanduser().resolve() + + +def ensure_artifact_path_is_rooted(*, workspace_path: Path, artifact_path: Path) -> None: + resolved_workspace_path = workspace_path.resolve() + resolved_artifact_path = artifact_path.resolve() + try: + resolved_artifact_path.relative_to(resolved_workspace_path) + except ValueError as exc: + raise TaskArtifactValidationError( + f"artifact path {resolved_artifact_path} escapes workspace root {resolved_workspace_path}" + ) from exc + + +def build_workspace_relative_artifact_path(*, workspace_path: Path, artifact_path: Path) -> str: + relative_path = artifact_path.relative_to(workspace_path).as_posix() + if relative_path in ("", "."): + raise TaskArtifactValidationError( + f"artifact path {artifact_path} must point to a file beneath workspace root {workspace_path}" + ) + return relative_path + + +def _require_existing_file(artifact_path: Path) -> None: + if not artifact_path.exists(): + raise TaskArtifactValidationError(f"artifact path {artifact_path} was not found") + if not artifact_path.is_file(): + raise TaskArtifactValidationError(f"artifact path {artifact_path} is not a regular file") + + +def _duplicate_registration_message(*, task_workspace_id: UUID, relative_path: str) -> str: + return ( + f"artifact {relative_path} is already registered for task workspace {task_workspace_id}" + ) + + +def serialize_task_artifact_row(row: TaskArtifactRow) -> TaskArtifactRecord: + return { + "id": str(row["id"]), + "task_id": str(row["task_id"]), + "task_workspace_id": str(row["task_workspace_id"]), + "status": cast(TaskArtifactStatus, row["status"]), + "ingestion_status": cast(TaskArtifactIngestionStatus, row["ingestion_status"]), + "relative_path": row["relative_path"], + "media_type_hint": row["media_type_hint"], + "created_at": row["created_at"].isoformat(), + "updated_at": row["updated_at"].isoformat(), + } + + +def register_task_artifact_record( + store: ContinuityStore, + *, + user_id: UUID, + request: TaskArtifactRegisterInput, +) -> TaskArtifactCreateResponse: + del user_id + + workspace = store.get_task_workspace_optional(request.task_workspace_id) + if workspace is None: + raise TaskWorkspaceNotFoundError( + f"task workspace {request.task_workspace_id} was not found" + ) + + workspace_path = 
Path(workspace["local_path"]).expanduser().resolve() + artifact_path = resolve_artifact_path(request.local_path) + _require_existing_file(artifact_path) + ensure_artifact_path_is_rooted( + workspace_path=workspace_path, + artifact_path=artifact_path, + ) + relative_path = build_workspace_relative_artifact_path( + workspace_path=workspace_path, + artifact_path=artifact_path, + ) + + store.lock_task_artifacts(workspace["id"]) + existing = store.get_task_artifact_by_workspace_relative_path_optional( + task_workspace_id=workspace["id"], + relative_path=relative_path, + ) + if existing is not None: + raise TaskArtifactAlreadyExistsError( + _duplicate_registration_message( + task_workspace_id=workspace["id"], + relative_path=relative_path, + ) + ) + + try: + row = store.create_task_artifact( + task_id=workspace["task_id"], + task_workspace_id=workspace["id"], + status="registered", + ingestion_status="pending", + relative_path=relative_path, + media_type_hint=request.media_type_hint, + ) + except psycopg.errors.UniqueViolation as exc: + raise TaskArtifactAlreadyExistsError( + _duplicate_registration_message( + task_workspace_id=workspace["id"], + relative_path=relative_path, + ) + ) from exc + + return {"artifact": serialize_task_artifact_row(row)} + + +def list_task_artifact_records( + store: ContinuityStore, + *, + user_id: UUID, +) -> TaskArtifactListResponse: + del user_id + + items = [serialize_task_artifact_row(row) for row in store.list_task_artifacts()] + return { + "items": items, + "summary": { + "total_count": len(items), + "order": list(TASK_ARTIFACT_LIST_ORDER), + }, + } + + +def get_task_artifact_record( + store: ContinuityStore, + *, + user_id: UUID, + task_artifact_id: UUID, +) -> TaskArtifactDetailResponse: + del user_id + + row = store.get_task_artifact_optional(task_artifact_id) + if row is None: + raise TaskArtifactNotFoundError(f"task artifact {task_artifact_id} was not found") + return {"artifact": serialize_task_artifact_row(row)} diff --git a/apps/api/src/alicebot_api/contracts.py b/apps/api/src/alicebot_api/contracts.py index fc794c2..06113c4 100644 --- a/apps/api/src/alicebot_api/contracts.py +++ b/apps/api/src/alicebot_api/contracts.py @@ -20,6 +20,8 @@ ApprovalResolutionOutcome = Literal["resolved", "duplicate_rejected", "conflict_rejected"] TaskStatus = Literal["pending_approval", "approved", "executed", "denied", "blocked"] TaskWorkspaceStatus = Literal["active"] +TaskArtifactStatus = Literal["registered"] +TaskArtifactIngestionStatus = Literal["pending"] TaskLifecycleSource = Literal[ "approval_request", "approval_resolution", @@ -129,6 +131,7 @@ APPROVAL_LIST_ORDER = ["created_at_asc", "id_asc"] TASK_LIST_ORDER = ["created_at_asc", "id_asc"] TASK_WORKSPACE_LIST_ORDER = ["created_at_asc", "id_asc"] +TASK_ARTIFACT_LIST_ORDER = ["created_at_asc", "id_asc"] TASK_STEP_LIST_ORDER = ["sequence_no_asc", "created_at_asc", "id_asc"] TOOL_EXECUTION_LIST_ORDER = ["executed_at_asc", "id_asc"] EXECUTION_BUDGET_LIST_ORDER = ["created_at_asc", "id_asc"] @@ -136,6 +139,8 @@ EXECUTION_BUDGET_STATUSES = ["active", "inactive", "superseded"] TASK_STATUSES = ["pending_approval", "approved", "executed", "denied", "blocked"] TASK_WORKSPACE_STATUSES = ["active"] +TASK_ARTIFACT_STATUSES = ["registered"] +TASK_ARTIFACT_INGESTION_STATUSES = ["pending"] TASK_STEP_KINDS = ["governed_request"] TASK_STEP_STATUSES = ["created", "approved", "executed", "blocked", "denied"] APPROVAL_REQUEST_VERSION_V0 = "approval_request_v0" @@ -1594,6 +1599,43 @@ class TaskWorkspaceDetailResponse(TypedDict): 
workspace: TaskWorkspaceRecord +@dataclass(frozen=True, slots=True) +class TaskArtifactRegisterInput: + task_workspace_id: UUID + local_path: str + media_type_hint: str | None = None + + +class TaskArtifactRecord(TypedDict): + id: str + task_id: str + task_workspace_id: str + status: TaskArtifactStatus + ingestion_status: TaskArtifactIngestionStatus + relative_path: str + media_type_hint: str | None + created_at: str + updated_at: str + + +class TaskArtifactCreateResponse(TypedDict): + artifact: TaskArtifactRecord + + +class TaskArtifactListSummary(TypedDict): + total_count: int + order: list[str] + + +class TaskArtifactListResponse(TypedDict): + items: list[TaskArtifactRecord] + summary: TaskArtifactListSummary + + +class TaskArtifactDetailResponse(TypedDict): + artifact: TaskArtifactRecord + + class TaskStepTraceLink(TypedDict): trace_id: str trace_kind: str diff --git a/apps/api/src/alicebot_api/main.py b/apps/api/src/alicebot_api/main.py index 764812d..5bce9bf 100644 --- a/apps/api/src/alicebot_api/main.py +++ b/apps/api/src/alicebot_api/main.py @@ -50,6 +50,7 @@ ProxyExecutionStatus, ToolAllowlistEvaluationRequestInput, ProxyExecutionRequestInput, + TaskArtifactRegisterInput, TaskStepKind, TaskStepLineageInput, TaskStepNextCreateInput, @@ -60,6 +61,14 @@ ToolRoutingRequestInput, ToolCreateInput, ) +from alicebot_api.artifacts import ( + TaskArtifactAlreadyExistsError, + TaskArtifactNotFoundError, + TaskArtifactValidationError, + get_task_artifact_record, + list_task_artifact_records, + register_task_artifact_record, +) from alicebot_api.approvals import ( ApprovalNotFoundError, ApprovalResolutionConflictError, @@ -398,6 +407,12 @@ class CreateTaskWorkspaceRequest(BaseModel): user_id: UUID +class RegisterTaskArtifactRequest(BaseModel): + user_id: UUID + local_path: str = Field(min_length=1, max_length=4000) + media_type_hint: str | None = Field(default=None, min_length=1, max_length=200) + + class TaskStepRequestSnapshot(BaseModel): thread_id: UUID tool_id: UUID @@ -1172,6 +1187,73 @@ def get_task_step(task_step_id: UUID, user_id: UUID) -> JSONResponse: ) +@app.post("/v0/task-workspaces/{task_workspace_id}/artifacts") +def register_task_artifact( + task_workspace_id: UUID, + request: RegisterTaskArtifactRequest, +) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, request.user_id) as conn: + payload = register_task_artifact_record( + ContinuityStore(conn), + user_id=request.user_id, + request=TaskArtifactRegisterInput( + task_workspace_id=task_workspace_id, + local_path=request.local_path, + media_type_hint=request.media_type_hint, + ), + ) + except TaskWorkspaceNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + except TaskArtifactValidationError as exc: + return JSONResponse(status_code=400, content={"detail": str(exc)}) + except TaskArtifactAlreadyExistsError as exc: + return JSONResponse(status_code=409, content={"detail": str(exc)}) + + return JSONResponse( + status_code=201, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/task-artifacts") +def list_task_artifacts(user_id: UUID) -> JSONResponse: + settings = get_settings() + + with user_connection(settings.database_url, user_id) as conn: + payload = list_task_artifact_records( + ContinuityStore(conn), + user_id=user_id, + ) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/task-artifacts/{task_artifact_id}") +def get_task_artifact(task_artifact_id: UUID, user_id: UUID) -> 
JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, user_id) as conn: + payload = get_task_artifact_record( + ContinuityStore(conn), + user_id=user_id, + task_artifact_id=task_artifact_id, + ) + except TaskArtifactNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + @app.post("/v0/tasks/{task_id}/steps") def create_next_task_step(task_id: UUID, request: CreateNextTaskStepRequest) -> JSONResponse: settings = get_settings() diff --git a/apps/api/src/alicebot_api/store.py b/apps/api/src/alicebot_api/store.py index 8c3b551..c849a35 100644 --- a/apps/api/src/alicebot_api/store.py +++ b/apps/api/src/alicebot_api/store.py @@ -245,6 +245,19 @@ class TaskWorkspaceRow(TypedDict): updated_at: datetime +class TaskArtifactRow(TypedDict): + id: UUID + user_id: UUID + task_id: UUID + task_workspace_id: UUID + status: str + ingestion_status: str + relative_path: str + media_type_hint: str | None + created_at: datetime + updated_at: datetime + + class TaskStepRow(TypedDict): id: UUID user_id: UUID @@ -344,6 +357,7 @@ class LabelCountRow(TypedDict): LOCK_THREAD_EVENTS_SQL = "SELECT pg_advisory_xact_lock(hashtextextended(%s::text, 0))" LOCK_TASK_STEPS_SQL = "SELECT pg_advisory_xact_lock(hashtextextended(%s::text, 2))" LOCK_TASK_WORKSPACES_SQL = "SELECT pg_advisory_xact_lock(hashtextextended(%s::text, 3))" +LOCK_TASK_ARTIFACTS_SQL = "SELECT pg_advisory_xact_lock(hashtextextended(%s::text, 4))" INSERT_EVENT_SQL = """ WITH next_sequence AS ( @@ -1416,6 +1430,93 @@ class LabelCountRow(TypedDict): ORDER BY created_at ASC, id ASC """ +INSERT_TASK_ARTIFACT_SQL = """ + INSERT INTO task_artifacts ( + user_id, + task_id, + task_workspace_id, + status, + ingestion_status, + relative_path, + media_type_hint, + created_at, + updated_at + ) + VALUES ( + app.current_user_id(), + %s, + %s, + %s, + %s, + %s, + %s, + clock_timestamp(), + clock_timestamp() + ) + RETURNING + id, + user_id, + task_id, + task_workspace_id, + status, + ingestion_status, + relative_path, + media_type_hint, + created_at, + updated_at + """ + +GET_TASK_ARTIFACT_SQL = """ + SELECT + id, + user_id, + task_id, + task_workspace_id, + status, + ingestion_status, + relative_path, + media_type_hint, + created_at, + updated_at + FROM task_artifacts + WHERE id = %s + """ + +GET_TASK_ARTIFACT_BY_WORKSPACE_RELATIVE_PATH_SQL = """ + SELECT + id, + user_id, + task_id, + task_workspace_id, + status, + ingestion_status, + relative_path, + media_type_hint, + created_at, + updated_at + FROM task_artifacts + WHERE task_workspace_id = %s + AND relative_path = %s + ORDER BY created_at ASC, id ASC + LIMIT 1 + """ + +LIST_TASK_ARTIFACTS_SQL = """ + SELECT + id, + user_id, + task_id, + task_workspace_id, + status, + ingestion_status, + relative_path, + media_type_hint, + created_at, + updated_at + FROM task_artifacts + ORDER BY created_at ASC, id ASC + """ + INSERT_TASK_STEP_SQL = """ INSERT INTO task_steps ( user_id, @@ -2505,6 +2606,50 @@ def get_active_task_workspace_for_task_optional(self, task_id: UUID) -> TaskWork def list_task_workspaces(self) -> list[TaskWorkspaceRow]: return self._fetch_all(LIST_TASK_WORKSPACES_SQL) + def lock_task_artifacts(self, task_workspace_id: UUID) -> None: + with self.conn.cursor() as cur: + cur.execute(LOCK_TASK_ARTIFACTS_SQL, (str(task_workspace_id),)) + + def create_task_artifact( + self, + *, + task_id: UUID, + task_workspace_id: UUID, + status: str, + 
ingestion_status: str, + relative_path: str, + media_type_hint: str | None, + ) -> TaskArtifactRow: + return self._fetch_one( + "create_task_artifact", + INSERT_TASK_ARTIFACT_SQL, + ( + task_id, + task_workspace_id, + status, + ingestion_status, + relative_path, + media_type_hint, + ), + ) + + def get_task_artifact_optional(self, task_artifact_id: UUID) -> TaskArtifactRow | None: + return self._fetch_optional_one(GET_TASK_ARTIFACT_SQL, (task_artifact_id,)) + + def get_task_artifact_by_workspace_relative_path_optional( + self, + *, + task_workspace_id: UUID, + relative_path: str, + ) -> TaskArtifactRow | None: + return self._fetch_optional_one( + GET_TASK_ARTIFACT_BY_WORKSPACE_RELATIVE_PATH_SQL, + (task_workspace_id, relative_path), + ) + + def list_task_artifacts(self) -> list[TaskArtifactRow]: + return self._fetch_all(LIST_TASK_ARTIFACTS_SQL) + def lock_task_steps(self, task_id: UUID) -> None: with self.conn.cursor() as cur: cur.execute(LOCK_TASK_STEPS_SQL, (str(task_id),)) diff --git a/tests/integration/test_migrations.py b/tests/integration/test_migrations.py index 434645e..250c270 100644 --- a/tests/integration/test_migrations.py +++ b/tests/integration/test_migrations.py @@ -299,6 +299,8 @@ def test_migrations_upgrade_and_downgrade(database_urls): assert cur.fetchone()[0] == "tasks" cur.execute("SELECT to_regclass('public.task_workspaces')") assert cur.fetchone()[0] == "task_workspaces" + cur.execute("SELECT to_regclass('public.task_artifacts')") + assert cur.fetchone()[0] == "task_artifacts" cur.execute("SELECT to_regclass('public.task_steps')") assert cur.fetchone()[0] == "task_steps" cur.execute( @@ -380,6 +382,7 @@ def test_migrations_upgrade_and_downgrade(database_urls): 'approvals', 'tasks', 'task_workspaces', + 'task_artifacts', 'task_steps', 'execution_budgets', 'tool_executions' @@ -401,6 +404,7 @@ def test_migrations_upgrade_and_downgrade(database_urls): ("memory_revisions", True, True), ("policies", True, True), ("sessions", True, True), + ("task_artifacts", True, True), ("task_steps", True, True), ("task_workspaces", True, True), ("tasks", True, True), @@ -467,6 +471,8 @@ def test_migrations_upgrade_and_downgrade(database_urls): has_table_privilege('alicebot_app', 'tasks', 'DELETE'), has_table_privilege('alicebot_app', 'task_workspaces', 'UPDATE'), has_table_privilege('alicebot_app', 'task_workspaces', 'DELETE'), + has_table_privilege('alicebot_app', 'task_artifacts', 'UPDATE'), + has_table_privilege('alicebot_app', 'task_artifacts', 'DELETE'), has_table_privilege('alicebot_app', 'task_steps', 'UPDATE'), has_table_privilege('alicebot_app', 'task_steps', 'DELETE'), has_table_privilege('alicebot_app', 'execution_budgets', 'UPDATE'), @@ -504,6 +510,8 @@ def test_migrations_upgrade_and_downgrade(database_urls): False, False, False, + False, + False, True, False, True, @@ -516,6 +524,8 @@ def test_migrations_upgrade_and_downgrade(database_urls): with psycopg.connect(database_urls["admin"]) as conn: with conn.cursor() as cur: + cur.execute("SELECT to_regclass('public.task_artifacts')") + assert cur.fetchone()[0] is None cur.execute("SELECT to_regclass('public.task_workspaces')") assert cur.fetchone()[0] is None cur.execute( diff --git a/tests/integration/test_task_artifacts_api.py b/tests/integration/test_task_artifacts_api.py new file mode 100644 index 0000000..ff78f47 --- /dev/null +++ b/tests/integration/test_task_artifacts_api.py @@ -0,0 +1,293 @@ +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any +from urllib.parse import 
urlencode +from uuid import UUID, uuid4 + +import anyio + +import apps.api.src.alicebot_api.main as main_module +from apps.api.src.alicebot_api.config import Settings +from alicebot_api.db import user_connection +from alicebot_api.store import ContinuityStore + + +def invoke_request( + method: str, + path: str, + *, + query_params: dict[str, str] | None = None, + payload: dict[str, Any] | None = None, +) -> tuple[int, dict[str, Any]]: + messages: list[dict[str, object]] = [] + encoded_body = b"" if payload is None else json.dumps(payload).encode() + request_received = False + + async def receive() -> dict[str, object]: + nonlocal request_received + if request_received: + return {"type": "http.disconnect"} + + request_received = True + return {"type": "http.request", "body": encoded_body, "more_body": False} + + async def send(message: dict[str, object]) -> None: + messages.append(message) + + query_string = urlencode(query_params or {}).encode() + scope = { + "type": "http", + "asgi": {"version": "3.0"}, + "http_version": "1.1", + "method": method, + "scheme": "http", + "path": path, + "raw_path": path.encode(), + "query_string": query_string, + "headers": [(b"content-type", b"application/json")], + "client": ("testclient", 50000), + "server": ("testserver", 80), + "root_path": "", + } + + anyio.run(main_module.app, scope, receive, send) + + start_message = next(message for message in messages if message["type"] == "http.response.start") + body = b"".join( + message.get("body", b"") + for message in messages + if message["type"] == "http.response.body" + ) + return start_message["status"], json.loads(body) + + +def seed_task(database_url: str, *, email: str) -> dict[str, UUID]: + user_id = uuid4() + + with user_connection(database_url, user_id) as conn: + store = ContinuityStore(conn) + store.create_user(user_id, email, email.split("@", 1)[0].title()) + thread = store.create_thread("Artifact thread") + tool = store.create_tool( + tool_key="proxy.echo", + name="Proxy Echo", + description="Deterministic proxy handler.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["proxy"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=[], + risk_hints=[], + metadata={"transport": "proxy"}, + ) + task = store.create_task( + thread_id=thread["id"], + tool_id=tool["id"], + status="approved", + request={ + "thread_id": str(thread["id"]), + "tool_id": str(tool["id"]), + "action": "tool.run", + "scope": "workspace", + "domain_hint": None, + "risk_hint": None, + "attributes": {}, + }, + tool={ + "id": str(tool["id"]), + "tool_key": "proxy.echo", + "name": "Proxy Echo", + "description": "Deterministic proxy handler.", + "version": "1.0.0", + "metadata_version": "tool_metadata_v0", + "active": True, + "tags": ["proxy"], + "action_hints": ["tool.run"], + "scope_hints": ["workspace"], + "domain_hints": [], + "risk_hints": [], + "metadata": {"transport": "proxy"}, + "created_at": tool["created_at"].isoformat(), + }, + latest_approval_id=None, + latest_execution_id=None, + ) + + return { + "user_id": user_id, + "task_id": task["id"], + } + + +def test_task_artifact_endpoints_register_list_detail_isolate_and_reject_duplicates( + migrated_database_urls, + monkeypatch, + tmp_path, +) -> None: + owner = seed_task(migrated_database_urls["app"], email="owner@example.com") + intruder = seed_task(migrated_database_urls["app"], email="intruder@example.com") + workspace_root = tmp_path / "task-workspaces" + monkeypatch.setattr( + main_module, + "get_settings", + 
lambda: Settings( + database_url=migrated_database_urls["app"], + task_workspace_root=str(workspace_root), + ), + ) + + workspace_status, workspace_payload = invoke_request( + "POST", + f"/v0/tasks/{owner['task_id']}/workspace", + payload={"user_id": str(owner["user_id"])}, + ) + assert workspace_status == 201 + + workspace_path = Path(workspace_payload["workspace"]["local_path"]) + first_file = workspace_path / "docs" / "spec.txt" + first_file.parent.mkdir(parents=True) + first_file.write_text("spec") + second_file = workspace_path / "notes" / "plan.md" + second_file.parent.mkdir(parents=True) + second_file.write_text("plan") + outside_file = tmp_path / "escape.txt" + outside_file.write_text("escape") + + first_create_status, first_create_payload = invoke_request( + "POST", + f"/v0/task-workspaces/{workspace_payload['workspace']['id']}/artifacts", + payload={ + "user_id": str(owner["user_id"]), + "local_path": str(first_file), + "media_type_hint": "text/plain", + }, + ) + second_create_status, second_create_payload = invoke_request( + "POST", + f"/v0/task-workspaces/{workspace_payload['workspace']['id']}/artifacts", + payload={ + "user_id": str(owner["user_id"]), + "local_path": str(second_file), + "media_type_hint": "text/markdown", + }, + ) + list_status, list_payload = invoke_request( + "GET", + "/v0/task-artifacts", + query_params={"user_id": str(owner["user_id"])}, + ) + detail_status, detail_payload = invoke_request( + "GET", + f"/v0/task-artifacts/{first_create_payload['artifact']['id']}", + query_params={"user_id": str(owner["user_id"])}, + ) + duplicate_status, duplicate_payload = invoke_request( + "POST", + f"/v0/task-workspaces/{workspace_payload['workspace']['id']}/artifacts", + payload={ + "user_id": str(owner["user_id"]), + "local_path": str(first_file), + "media_type_hint": "text/plain", + }, + ) + escaped_status, escaped_payload = invoke_request( + "POST", + f"/v0/task-workspaces/{workspace_payload['workspace']['id']}/artifacts", + payload={ + "user_id": str(owner["user_id"]), + "local_path": str(outside_file), + }, + ) + isolated_list_status, isolated_list_payload = invoke_request( + "GET", + "/v0/task-artifacts", + query_params={"user_id": str(intruder["user_id"])}, + ) + isolated_detail_status, isolated_detail_payload = invoke_request( + "GET", + f"/v0/task-artifacts/{first_create_payload['artifact']['id']}", + query_params={"user_id": str(intruder["user_id"])}, + ) + isolated_create_status, isolated_create_payload = invoke_request( + "POST", + f"/v0/task-workspaces/{workspace_payload['workspace']['id']}/artifacts", + payload={ + "user_id": str(intruder["user_id"]), + "local_path": str(first_file), + }, + ) + + assert first_create_status == 201 + assert first_create_payload == { + "artifact": { + "id": first_create_payload["artifact"]["id"], + "task_id": str(owner["task_id"]), + "task_workspace_id": workspace_payload["workspace"]["id"], + "status": "registered", + "ingestion_status": "pending", + "relative_path": "docs/spec.txt", + "media_type_hint": "text/plain", + "created_at": first_create_payload["artifact"]["created_at"], + "updated_at": first_create_payload["artifact"]["updated_at"], + } + } + + assert second_create_status == 201 + assert second_create_payload == { + "artifact": { + "id": second_create_payload["artifact"]["id"], + "task_id": str(owner["task_id"]), + "task_workspace_id": workspace_payload["workspace"]["id"], + "status": "registered", + "ingestion_status": "pending", + "relative_path": "notes/plan.md", + "media_type_hint": "text/markdown", + 
"created_at": second_create_payload["artifact"]["created_at"], + "updated_at": second_create_payload["artifact"]["updated_at"], + } + } + + assert list_status == 200 + assert list_payload == { + "items": [ + first_create_payload["artifact"], + second_create_payload["artifact"], + ], + "summary": {"total_count": 2, "order": ["created_at_asc", "id_asc"]}, + } + + assert detail_status == 200 + assert detail_payload == {"artifact": first_create_payload["artifact"]} + + assert duplicate_status == 409 + assert duplicate_payload == { + "detail": ( + "artifact docs/spec.txt is already registered for task workspace " + f"{workspace_payload['workspace']['id']}" + ) + } + + assert escaped_status == 400 + assert escaped_payload == { + "detail": f"artifact path {outside_file.resolve()} escapes workspace root {workspace_path.resolve()}" + } + + assert isolated_list_status == 200 + assert isolated_list_payload == { + "items": [], + "summary": {"total_count": 0, "order": ["created_at_asc", "id_asc"]}, + } + + assert isolated_detail_status == 404 + assert isolated_detail_payload == { + "detail": f"task artifact {first_create_payload['artifact']['id']} was not found" + } + + assert isolated_create_status == 404 + assert isolated_create_payload == { + "detail": f"task workspace {workspace_payload['workspace']['id']} was not found" + } diff --git a/tests/unit/test_20260313_0023_task_artifacts.py b/tests/unit/test_20260313_0023_task_artifacts.py new file mode 100644 index 0000000..f6fd17d --- /dev/null +++ b/tests/unit/test_20260313_0023_task_artifacts.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +import importlib + + +MODULE_NAME = "apps.api.alembic.versions.20260313_0023_task_artifacts" + + +def load_migration_module(): + return importlib.import_module(MODULE_NAME) + + +def test_upgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.upgrade() + + assert executed == [ + module._UPGRADE_SCHEMA_STATEMENT, + *module._UPGRADE_GRANT_STATEMENTS, + "ALTER TABLE task_artifacts ENABLE ROW LEVEL SECURITY", + "ALTER TABLE task_artifacts FORCE ROW LEVEL SECURITY", + module._UPGRADE_POLICY_STATEMENT, + ] + + +def test_downgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.downgrade() + + assert executed == list(module._DOWNGRADE_STATEMENTS) + + +def test_task_artifact_privileges_allow_only_expected_runtime_writes() -> None: + module = load_migration_module() + + assert module._UPGRADE_GRANT_STATEMENTS == ( + "GRANT SELECT, INSERT ON task_artifacts TO alicebot_app", + ) diff --git a/tests/unit/test_artifacts.py b/tests/unit/test_artifacts.py new file mode 100644 index 0000000..33d6fb7 --- /dev/null +++ b/tests/unit/test_artifacts.py @@ -0,0 +1,295 @@ +from __future__ import annotations + +from datetime import UTC, datetime, timedelta +from pathlib import Path +from uuid import UUID, uuid4 + +import pytest + +from alicebot_api.artifacts import ( + TaskArtifactAlreadyExistsError, + TaskArtifactNotFoundError, + TaskArtifactValidationError, + build_workspace_relative_artifact_path, + ensure_artifact_path_is_rooted, + get_task_artifact_record, + list_task_artifact_records, + register_task_artifact_record, + serialize_task_artifact_row, +) +from alicebot_api.contracts import TaskArtifactRegisterInput +from 
alicebot_api.workspaces import TaskWorkspaceNotFoundError + + +class ArtifactStoreStub: + def __init__(self) -> None: + self.base_time = datetime(2026, 3, 13, 10, 0, tzinfo=UTC) + self.workspaces: list[dict[str, object]] = [] + self.artifacts: list[dict[str, object]] = [] + self.locked_workspace_ids: list[UUID] = [] + + def create_task_workspace(self, *, task_workspace_id: UUID, task_id: UUID, user_id: UUID, local_path: str) -> dict[str, object]: + workspace = { + "id": task_workspace_id, + "user_id": user_id, + "task_id": task_id, + "status": "active", + "local_path": local_path, + "created_at": self.base_time, + "updated_at": self.base_time, + } + self.workspaces.append(workspace) + return workspace + + def get_task_workspace_optional(self, task_workspace_id: UUID) -> dict[str, object] | None: + return next((workspace for workspace in self.workspaces if workspace["id"] == task_workspace_id), None) + + def lock_task_artifacts(self, task_workspace_id: UUID) -> None: + self.locked_workspace_ids.append(task_workspace_id) + + def get_task_artifact_by_workspace_relative_path_optional( + self, + *, + task_workspace_id: UUID, + relative_path: str, + ) -> dict[str, object] | None: + return next( + ( + artifact + for artifact in self.artifacts + if artifact["task_workspace_id"] == task_workspace_id + and artifact["relative_path"] == relative_path + ), + None, + ) + + def create_task_artifact( + self, + *, + task_id: UUID, + task_workspace_id: UUID, + status: str, + ingestion_status: str, + relative_path: str, + media_type_hint: str | None, + ) -> dict[str, object]: + artifact = { + "id": uuid4(), + "user_id": self.workspaces[0]["user_id"], + "task_id": task_id, + "task_workspace_id": task_workspace_id, + "status": status, + "ingestion_status": ingestion_status, + "relative_path": relative_path, + "media_type_hint": media_type_hint, + "created_at": self.base_time + timedelta(minutes=len(self.artifacts)), + "updated_at": self.base_time + timedelta(minutes=len(self.artifacts)), + } + self.artifacts.append(artifact) + return artifact + + def list_task_artifacts(self) -> list[dict[str, object]]: + return sorted(self.artifacts, key=lambda artifact: (artifact["created_at"], artifact["id"])) + + def get_task_artifact_optional(self, task_artifact_id: UUID) -> dict[str, object] | None: + return next((artifact for artifact in self.artifacts if artifact["id"] == task_artifact_id), None) + + +def test_ensure_artifact_path_is_rooted_rejects_escape() -> None: + with pytest.raises(TaskArtifactValidationError, match="escapes workspace root"): + ensure_artifact_path_is_rooted( + workspace_path=Path("/tmp/alicebot/task-workspaces/user/task"), + artifact_path=Path("/tmp/alicebot/task-workspaces/user/task/../escape.txt"), + ) + + +def test_build_workspace_relative_artifact_path_returns_posix_path() -> None: + relative_path = build_workspace_relative_artifact_path( + workspace_path=Path("/tmp/alicebot/task-workspaces/user/task"), + artifact_path=Path("/tmp/alicebot/task-workspaces/user/task/docs/spec.txt"), + ) + + assert relative_path == "docs/spec.txt" + + +def test_register_task_artifact_record_persists_relative_path_and_returns_record(tmp_path) -> None: + store = ArtifactStoreStub() + user_id = uuid4() + task_id = uuid4() + task_workspace_id = uuid4() + workspace_path = tmp_path / "workspaces" / str(user_id) / str(task_id) + workspace_path.mkdir(parents=True) + artifact_path = workspace_path / "docs" / "spec.txt" + artifact_path.parent.mkdir(parents=True) + artifact_path.write_text("spec") + 
store.create_task_workspace( + task_workspace_id=task_workspace_id, + task_id=task_id, + user_id=user_id, + local_path=str(workspace_path), + ) + + response = register_task_artifact_record( + store, + user_id=user_id, + request=TaskArtifactRegisterInput( + task_workspace_id=task_workspace_id, + local_path=str(artifact_path), + media_type_hint="text/plain", + ), + ) + + assert response == { + "artifact": { + "id": response["artifact"]["id"], + "task_id": str(task_id), + "task_workspace_id": str(task_workspace_id), + "status": "registered", + "ingestion_status": "pending", + "relative_path": "docs/spec.txt", + "media_type_hint": "text/plain", + "created_at": "2026-03-13T10:00:00+00:00", + "updated_at": "2026-03-13T10:00:00+00:00", + } + } + assert store.locked_workspace_ids == [task_workspace_id] + + +def test_register_task_artifact_record_rejects_duplicate_relative_path(tmp_path) -> None: + store = ArtifactStoreStub() + user_id = uuid4() + task_id = uuid4() + task_workspace_id = uuid4() + workspace_path = tmp_path / "workspaces" / str(user_id) / str(task_id) + workspace_path.mkdir(parents=True) + artifact_path = workspace_path / "docs" / "spec.txt" + artifact_path.parent.mkdir(parents=True) + artifact_path.write_text("spec") + store.create_task_workspace( + task_workspace_id=task_workspace_id, + task_id=task_id, + user_id=user_id, + local_path=str(workspace_path), + ) + + register_task_artifact_record( + store, + user_id=user_id, + request=TaskArtifactRegisterInput( + task_workspace_id=task_workspace_id, + local_path=str(artifact_path), + media_type_hint="text/plain", + ), + ) + + with pytest.raises( + TaskArtifactAlreadyExistsError, + match=f"artifact docs/spec.txt is already registered for task workspace {task_workspace_id}", + ): + register_task_artifact_record( + store, + user_id=user_id, + request=TaskArtifactRegisterInput( + task_workspace_id=task_workspace_id, + local_path=str(artifact_path), + media_type_hint="text/plain", + ), + ) + + +def test_register_task_artifact_record_requires_visible_workspace(tmp_path) -> None: + artifact_path = tmp_path / "spec.txt" + artifact_path.write_text("spec") + + with pytest.raises(TaskWorkspaceNotFoundError, match="was not found"): + register_task_artifact_record( + ArtifactStoreStub(), + user_id=uuid4(), + request=TaskArtifactRegisterInput( + task_workspace_id=uuid4(), + local_path=str(artifact_path), + media_type_hint=None, + ), + ) + + +def test_register_task_artifact_record_rejects_paths_outside_workspace(tmp_path) -> None: + store = ArtifactStoreStub() + user_id = uuid4() + task_id = uuid4() + task_workspace_id = uuid4() + workspace_path = tmp_path / "workspaces" / str(user_id) / str(task_id) + workspace_path.mkdir(parents=True) + outside_path = tmp_path / "escape.txt" + outside_path.write_text("escape") + store.create_task_workspace( + task_workspace_id=task_workspace_id, + task_id=task_id, + user_id=user_id, + local_path=str(workspace_path), + ) + + with pytest.raises(TaskArtifactValidationError, match="escapes workspace root"): + register_task_artifact_record( + store, + user_id=user_id, + request=TaskArtifactRegisterInput( + task_workspace_id=task_workspace_id, + local_path=str(outside_path), + media_type_hint=None, + ), + ) + + +def test_list_and_get_task_artifact_records_are_deterministic() -> None: + store = ArtifactStoreStub() + user_id = uuid4() + task_id = uuid4() + task_workspace_id = uuid4() + store.create_task_workspace( + task_workspace_id=task_workspace_id, + task_id=task_id, + user_id=user_id, + 
local_path="/tmp/alicebot/task-workspaces/user/task", + ) + first = store.create_task_artifact( + task_id=task_id, + task_workspace_id=task_workspace_id, + status="registered", + ingestion_status="pending", + relative_path="docs/a.txt", + media_type_hint="text/plain", + ) + second = store.create_task_artifact( + task_id=task_id, + task_workspace_id=task_workspace_id, + status="registered", + ingestion_status="pending", + relative_path="docs/b.txt", + media_type_hint=None, + ) + + assert list_task_artifact_records(store, user_id=user_id) == { + "items": [ + serialize_task_artifact_row(first), + serialize_task_artifact_row(second), + ], + "summary": { + "total_count": 2, + "order": ["created_at_asc", "id_asc"], + }, + } + assert get_task_artifact_record( + store, + user_id=user_id, + task_artifact_id=first["id"], + ) == {"artifact": serialize_task_artifact_row(first)} + + +def test_get_task_artifact_record_raises_when_artifact_is_missing() -> None: + with pytest.raises(TaskArtifactNotFoundError, match="was not found"): + get_task_artifact_record( + ArtifactStoreStub(), + user_id=uuid4(), + task_artifact_id=uuid4(), + ) diff --git a/tests/unit/test_artifacts_main.py b/tests/unit/test_artifacts_main.py new file mode 100644 index 0000000..a791655 --- /dev/null +++ b/tests/unit/test_artifacts_main.py @@ -0,0 +1,154 @@ +from __future__ import annotations + +import json +from contextlib import contextmanager +from uuid import uuid4 + +import apps.api.src.alicebot_api.main as main_module +from apps.api.src.alicebot_api.config import Settings +from alicebot_api.artifacts import ( + TaskArtifactAlreadyExistsError, + TaskArtifactNotFoundError, + TaskArtifactValidationError, +) +from alicebot_api.workspaces import TaskWorkspaceNotFoundError + + +def test_list_task_artifacts_endpoint_returns_payload(monkeypatch) -> None: + user_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr( + main_module, + "list_task_artifact_records", + lambda *_args, **_kwargs: { + "items": [], + "summary": {"total_count": 0, "order": ["created_at_asc", "id_asc"]}, + }, + ) + + response = main_module.list_task_artifacts(user_id) + + assert response.status_code == 200 + assert json.loads(response.body) == { + "items": [], + "summary": {"total_count": 0, "order": ["created_at_asc", "id_asc"]}, + } + + +def test_get_task_artifact_endpoint_maps_not_found_to_404(monkeypatch) -> None: + user_id = uuid4() + task_artifact_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_get_task_artifact_record(*_args, **_kwargs): + raise TaskArtifactNotFoundError(f"task artifact {task_artifact_id} was not found") + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "get_task_artifact_record", fake_get_task_artifact_record) + + response = main_module.get_task_artifact(task_artifact_id, user_id) + + assert response.status_code == 404 + assert json.loads(response.body) == {"detail": f"task artifact {task_artifact_id} was not found"} + + +def test_register_task_artifact_endpoint_maps_workspace_not_found_to_404(monkeypatch) -> None: + user_id = uuid4() 
+ task_workspace_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_register_task_artifact_record(*_args, **_kwargs): + raise TaskWorkspaceNotFoundError(f"task workspace {task_workspace_id} was not found") + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "register_task_artifact_record", fake_register_task_artifact_record) + + response = main_module.register_task_artifact( + task_workspace_id, + main_module.RegisterTaskArtifactRequest( + user_id=user_id, + local_path="/tmp/example.txt", + ), + ) + + assert response.status_code == 404 + assert json.loads(response.body) == {"detail": f"task workspace {task_workspace_id} was not found"} + + +def test_register_task_artifact_endpoint_maps_validation_to_400(monkeypatch) -> None: + user_id = uuid4() + task_workspace_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_register_task_artifact_record(*_args, **_kwargs): + raise TaskArtifactValidationError("artifact path /tmp/escape.txt escapes workspace root /tmp/workspace") + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "register_task_artifact_record", fake_register_task_artifact_record) + + response = main_module.register_task_artifact( + task_workspace_id, + main_module.RegisterTaskArtifactRequest( + user_id=user_id, + local_path="/tmp/escape.txt", + ), + ) + + assert response.status_code == 400 + assert json.loads(response.body) == { + "detail": "artifact path /tmp/escape.txt escapes workspace root /tmp/workspace" + } + + +def test_register_task_artifact_endpoint_maps_duplicate_to_409(monkeypatch) -> None: + user_id = uuid4() + task_workspace_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_register_task_artifact_record(*_args, **_kwargs): + raise TaskArtifactAlreadyExistsError( + f"artifact docs/spec.txt is already registered for task workspace {task_workspace_id}" + ) + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "register_task_artifact_record", fake_register_task_artifact_record) + + response = main_module.register_task_artifact( + task_workspace_id, + main_module.RegisterTaskArtifactRequest( + user_id=user_id, + local_path="/tmp/docs/spec.txt", + ), + ) + + assert response.status_code == 409 + assert json.loads(response.body) == { + "detail": f"artifact docs/spec.txt is already registered for task workspace {task_workspace_id}" + } diff --git a/tests/unit/test_main.py b/tests/unit/test_main.py index dc1e5ca..20b7a00 100644 --- a/tests/unit/test_main.py +++ b/tests/unit/test_main.py @@ -126,6 +126,9 @@ def test_healthcheck_route_is_registered() -> None: assert "/v0/tasks/{task_id}/steps" in route_paths assert "/v0/task-workspaces" in route_paths assert "/v0/task-workspaces/{task_workspace_id}" in route_paths + assert "/v0/task-workspaces/{task_workspace_id}/artifacts" in route_paths + assert "/v0/task-artifacts" in route_paths + assert 
"/v0/task-artifacts/{task_artifact_id}" in route_paths assert "/v0/task-steps/{task_step_id}" in route_paths assert "/v0/task-steps/{task_step_id}/transition" in route_paths assert "/v0/entities/{entity_id}" in route_paths diff --git a/tests/unit/test_task_artifact_store.py b/tests/unit/test_task_artifact_store.py new file mode 100644 index 0000000..c6f6277 --- /dev/null +++ b/tests/unit/test_task_artifact_store.py @@ -0,0 +1,228 @@ +from __future__ import annotations + +from typing import Any +from uuid import uuid4 + +from alicebot_api.store import ContinuityStore + + +class RecordingCursor: + def __init__(self, fetchone_results: list[dict[str, Any]], fetchall_result: list[dict[str, Any]] | None = None) -> None: + self.executed: list[tuple[str, tuple[object, ...] | None]] = [] + self.fetchone_results = list(fetchone_results) + self.fetchall_result = fetchall_result or [] + + def __enter__(self) -> "RecordingCursor": + return self + + def __exit__(self, exc_type, exc, tb) -> None: + return None + + def execute(self, query: str, params: tuple[object, ...] | None = None) -> None: + self.executed.append((query, params)) + + def fetchone(self) -> dict[str, Any] | None: + if not self.fetchone_results: + return None + return self.fetchone_results.pop(0) + + def fetchall(self) -> list[dict[str, Any]]: + return self.fetchall_result + + +class RecordingConnection: + def __init__(self, cursor: RecordingCursor) -> None: + self.cursor_instance = cursor + + def cursor(self) -> RecordingCursor: + return self.cursor_instance + + +def test_task_artifact_store_methods_use_expected_queries() -> None: + task_artifact_id = uuid4() + task_id = uuid4() + task_workspace_id = uuid4() + cursor = RecordingCursor( + fetchone_results=[ + { + "id": task_artifact_id, + "user_id": uuid4(), + "task_id": task_id, + "task_workspace_id": task_workspace_id, + "status": "registered", + "ingestion_status": "pending", + "relative_path": "docs/spec.txt", + "media_type_hint": "text/plain", + "created_at": "2026-03-13T10:00:00+00:00", + "updated_at": "2026-03-13T10:00:00+00:00", + }, + { + "id": task_artifact_id, + "user_id": uuid4(), + "task_id": task_id, + "task_workspace_id": task_workspace_id, + "status": "registered", + "ingestion_status": "pending", + "relative_path": "docs/spec.txt", + "media_type_hint": "text/plain", + "created_at": "2026-03-13T10:00:00+00:00", + "updated_at": "2026-03-13T10:00:00+00:00", + }, + { + "id": task_artifact_id, + "user_id": uuid4(), + "task_id": task_id, + "task_workspace_id": task_workspace_id, + "status": "registered", + "ingestion_status": "pending", + "relative_path": "docs/spec.txt", + "media_type_hint": "text/plain", + "created_at": "2026-03-13T10:00:00+00:00", + "updated_at": "2026-03-13T10:00:00+00:00", + }, + ], + fetchall_result=[ + { + "id": task_artifact_id, + "user_id": uuid4(), + "task_id": task_id, + "task_workspace_id": task_workspace_id, + "status": "registered", + "ingestion_status": "pending", + "relative_path": "docs/spec.txt", + "media_type_hint": "text/plain", + "created_at": "2026-03-13T10:00:00+00:00", + "updated_at": "2026-03-13T10:00:00+00:00", + } + ], + ) + store = ContinuityStore(RecordingConnection(cursor)) + + created = store.create_task_artifact( + task_id=task_id, + task_workspace_id=task_workspace_id, + status="registered", + ingestion_status="pending", + relative_path="docs/spec.txt", + media_type_hint="text/plain", + ) + fetched = store.get_task_artifact_optional(task_artifact_id) + duplicate = store.get_task_artifact_by_workspace_relative_path_optional( + 
task_workspace_id=task_workspace_id, + relative_path="docs/spec.txt", + ) + listed = store.list_task_artifacts() + store.lock_task_artifacts(task_workspace_id) + + assert created["id"] == task_artifact_id + assert fetched is not None + assert duplicate is not None + assert listed[0]["id"] == task_artifact_id + assert cursor.executed == [ + ( + """ + INSERT INTO task_artifacts ( + user_id, + task_id, + task_workspace_id, + status, + ingestion_status, + relative_path, + media_type_hint, + created_at, + updated_at + ) + VALUES ( + app.current_user_id(), + %s, + %s, + %s, + %s, + %s, + %s, + clock_timestamp(), + clock_timestamp() + ) + RETURNING + id, + user_id, + task_id, + task_workspace_id, + status, + ingestion_status, + relative_path, + media_type_hint, + created_at, + updated_at + """, + ( + task_id, + task_workspace_id, + "registered", + "pending", + "docs/spec.txt", + "text/plain", + ), + ), + ( + """ + SELECT + id, + user_id, + task_id, + task_workspace_id, + status, + ingestion_status, + relative_path, + media_type_hint, + created_at, + updated_at + FROM task_artifacts + WHERE id = %s + """, + (task_artifact_id,), + ), + ( + """ + SELECT + id, + user_id, + task_id, + task_workspace_id, + status, + ingestion_status, + relative_path, + media_type_hint, + created_at, + updated_at + FROM task_artifacts + WHERE task_workspace_id = %s + AND relative_path = %s + ORDER BY created_at ASC, id ASC + LIMIT 1 + """, + (task_workspace_id, "docs/spec.txt"), + ), + ( + """ + SELECT + id, + user_id, + task_id, + task_workspace_id, + status, + ingestion_status, + relative_path, + media_type_hint, + created_at, + updated_at + FROM task_artifacts + ORDER BY created_at ASC, id ASC + """, + None, + ), + ( + "SELECT pg_advisory_xact_lock(hashtextextended(%s::text, 4))", + (str(task_workspace_id),), + ), + ] From ec8055ed6a4e1c0e817107195fd1b83a5ed95ddc Mon Sep 17 00:00:00 2001 From: Sam Date: Sat, 14 Mar 2026 13:55:31 +0100 Subject: [PATCH 004/135] Sprint 5D: Local Artifact Ingestion V0 (#4) --- ARCHITECTURE.md | 44 +- .../20260314_0024_task_artifact_chunks.py | 97 ++++ apps/api/src/alicebot_api/artifacts.py | 180 +++++++- apps/api/src/alicebot_api/contracts.py | 39 +- apps/api/src/alicebot_api/main.py | 54 +++ apps/api/src/alicebot_api/store.py | 115 +++++ tests/integration/test_migrations.py | 10 + tests/integration/test_task_artifacts_api.py | 372 +++++++++++++++ ...test_20260314_0024_task_artifact_chunks.py | 48 ++ tests/unit/test_artifacts.py | 431 +++++++++++++++++- tests/unit/test_artifacts_main.py | 99 ++++ tests/unit/test_main.py | 2 + tests/unit/test_task_artifact_store.py | 142 ++++++ 13 files changed, 1617 insertions(+), 16 deletions(-) create mode 100644 apps/api/alembic/versions/20260314_0024_task_artifact_chunks.py create mode 100644 tests/unit/test_20260314_0024_task_artifact_chunks.py diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md index f097b55..9899d45 100644 --- a/ARCHITECTURE.md +++ b/ARCHITECTURE.md @@ -2,16 +2,16 @@ ## Current Implemented Slice -AliceBot now implements the accepted repo slice through Sprint 5C. The shipped backend includes: +AliceBot now implements the accepted repo slice through Sprint 5D. 
The shipped backend includes: - foundation continuity storage over `users`, `threads`, `sessions`, and append-only `events` - deterministic tracing and context compilation over durable continuity, memory, entity, and entity-edge records - governed memory admission, explicit-preference extraction, memory review labels, review queue reads, evaluation summary reads, explicit embedding config and memory-embedding storage, direct semantic retrieval, and deterministic hybrid compile-path memory merge - deterministic prompt assembly and one no-tools response path that persists assistant replies as immutable continuity events - user-scoped consents, policies, policy evaluation, tool registry, allowlist evaluation, tool routing, approval request persistence, approval resolution, approved-only proxy execution through the in-process `proxy.echo` handler, durable execution review, and execution-budget lifecycle plus enforcement -- durable `tasks`, `task_steps`, `task_workspaces`, and `task_artifacts`, deterministic task-step sequencing, explicit task-step transitions, explicit manual continuation with lineage through `parent_step_id`, `source_approval_id`, and `source_execution_id`, explicit `tool_executions.task_step_id` linkage for execution synchronization, deterministic rooted local task-workspace provisioning, and explicit rooted local artifact registration plus deterministic artifact reads +- durable `tasks`, `task_steps`, `task_workspaces`, `task_artifacts`, and `task_artifact_chunks`, deterministic task-step sequencing, explicit task-step transitions, explicit manual continuation with lineage through `parent_step_id`, `source_approval_id`, and `source_execution_id`, explicit `tool_executions.task_step_id` linkage for execution synchronization, deterministic rooted local task-workspace provisioning, explicit rooted local artifact registration, deterministic local text-artifact ingestion into durable chunk rows, and deterministic artifact plus chunk reads -The current multi-step boundary is narrow and explicit. Manual continuation is implemented and review-passed. Approval resolution and proxy execution now both use explicit task-step linkage rather than first-step inference. Task workspaces are now implemented only as deterministic rooted local boundaries, and task artifacts are now implemented only as explicit rooted local-file registrations under those workspaces. Broader runner-style orchestration, automatic multi-step progression, artifact indexing, document ingestion, connectors, and new side-effect surfaces are still planned later and must not be described as live behavior. +The current multi-step boundary is narrow and explicit. Manual continuation is implemented and review-passed. Approval resolution and proxy execution now both use explicit task-step linkage rather than first-step inference. Task workspaces are now implemented only as deterministic rooted local boundaries, and task artifacts are now implemented only as explicit rooted local-file registrations plus narrow deterministic text ingestion under those workspaces. Broader runner-style orchestration, automatic multi-step progression, retrieval over artifact chunks, embeddings, rich-document parsing, connectors, and new side-effect surfaces are still planned later and must not be described as live behavior. ## Implemented Now @@ -24,7 +24,7 @@ The current multi-step boundary is narrow and explicit. 
Manual continuation is i - memory and retrieval: `POST /v0/memories/admit`, `POST /v0/memories/extract-explicit-preferences`, `GET /v0/memories`, `GET /v0/memories/review-queue`, `GET /v0/memories/evaluation-summary`, `POST /v0/memories/semantic-retrieval`, `GET /v0/memories/{memory_id}`, `GET /v0/memories/{memory_id}/revisions`, `POST /v0/memories/{memory_id}/labels`, `GET /v0/memories/{memory_id}/labels` - embeddings and graph seams: `POST /v0/embedding-configs`, `GET /v0/embedding-configs`, `POST /v0/memory-embeddings`, `GET /v0/memories/{memory_id}/embeddings`, `GET /v0/memory-embeddings/{memory_embedding_id}`, `POST /v0/entities`, `GET /v0/entities`, `GET /v0/entities/{entity_id}`, `POST /v0/entity-edges`, `GET /v0/entities/{entity_id}/edges` - governance: `POST /v0/consents`, `GET /v0/consents`, `POST /v0/policies`, `GET /v0/policies`, `GET /v0/policies/{policy_id}`, `POST /v0/policies/evaluate`, `POST /v0/tools`, `GET /v0/tools`, `GET /v0/tools/{tool_id}`, `POST /v0/tools/allowlist/evaluate`, `POST /v0/tools/route`, `POST /v0/approvals/requests`, `GET /v0/approvals`, `GET /v0/approvals/{approval_id}`, `POST /v0/approvals/{approval_id}/approve`, `POST /v0/approvals/{approval_id}/reject`, `POST /v0/approvals/{approval_id}/execute` - - task and execution review: `GET /v0/tasks`, `GET /v0/tasks/{task_id}`, `POST /v0/tasks/{task_id}/workspace`, `GET /v0/task-workspaces`, `GET /v0/task-workspaces/{task_workspace_id}`, `POST /v0/task-workspaces/{task_workspace_id}/artifacts`, `GET /v0/task-artifacts`, `GET /v0/task-artifacts/{task_artifact_id}`, `GET /v0/tasks/{task_id}/steps`, `GET /v0/task-steps/{task_step_id}`, `POST /v0/tasks/{task_id}/steps`, `POST /v0/task-steps/{task_step_id}/transition`, `POST /v0/execution-budgets`, `GET /v0/execution-budgets`, `GET /v0/execution-budgets/{execution_budget_id}`, `POST /v0/execution-budgets/{execution_budget_id}/deactivate`, `POST /v0/execution-budgets/{execution_budget_id}/supersede`, `GET /v0/tool-executions`, `GET /v0/tool-executions/{execution_id}` +- task and execution review: `GET /v0/tasks`, `GET /v0/tasks/{task_id}`, `POST /v0/tasks/{task_id}/workspace`, `GET /v0/task-workspaces`, `GET /v0/task-workspaces/{task_workspace_id}`, `POST /v0/task-workspaces/{task_workspace_id}/artifacts`, `GET /v0/task-artifacts`, `GET /v0/task-artifacts/{task_artifact_id}`, `POST /v0/task-artifacts/{task_artifact_id}/ingest`, `GET /v0/task-artifacts/{task_artifact_id}/chunks`, `GET /v0/tasks/{task_id}/steps`, `GET /v0/task-steps/{task_step_id}`, `POST /v0/tasks/{task_id}/steps`, `POST /v0/task-steps/{task_step_id}/transition`, `POST /v0/execution-budgets`, `GET /v0/execution-budgets`, `GET /v0/execution-budgets/{execution_budget_id}`, `POST /v0/execution-budgets/{execution_budget_id}/deactivate`, `POST /v0/execution-budgets/{execution_budget_id}/supersede`, `GET /v0/tool-executions`, `GET /v0/tool-executions/{execution_id}` - `apps/web` and `workers` remain starter shells only. ### Data Foundation @@ -37,7 +37,7 @@ The current multi-step boundary is narrow and explicit. 
Manual continuation is i - memory and retrieval tables: `memories`, `memory_revisions`, `memory_review_labels`, `embedding_configs`, `memory_embeddings` - graph tables: `entities`, `entity_edges` - governance tables: `consents`, `policies`, `tools`, `approvals`, `tool_executions`, `execution_budgets` - - task lifecycle tables: `tasks`, `task_steps`, `task_workspaces`, `task_artifacts` + - task lifecycle tables: `tasks`, `task_steps`, `task_workspaces`, `task_artifacts`, `task_artifact_chunks` - `events`, `trace_events`, and `memory_revisions` are append-only by application contract and database enforcement. - `memory_review_labels` are append-only by database enforcement. - `tasks` are explicit user-scoped lifecycle records keyed to one thread and one tool, with durable request/tool snapshots, status in `pending_approval | approved | executed | denied | blocked`, and latest approval/execution pointers for the current narrow lifecycle seam. @@ -49,18 +49,19 @@ The current multi-step boundary is narrow and explicit. Manual continuation is i - Lineage fields are guarded by composite user-scoped foreign keys and a self-reference check so a step cannot cite itself as its parent. - `tool_executions` now persist an explicit `task_step_id` linked by a composite foreign key to `task_steps(id, user_id)`. - `task_workspaces` persist one active workspace record per visible task and user, store a deterministic `local_path`, and enforce that active uniqueness through a partial unique index on `(user_id, task_id)`. -- `task_artifacts` persist explicit user-scoped artifact rows linked to both `tasks` and `task_workspaces`, store `status = registered`, `ingestion_status = pending`, store only a workspace-relative `relative_path` plus optional `media_type_hint`, and enforce deterministic duplicate rejection through a unique index on `(user_id, task_workspace_id, relative_path)`. +- `task_artifacts` persist explicit user-scoped artifact rows linked to both `tasks` and `task_workspaces`, store `status = registered`, `ingestion_status in ('pending', 'ingested')`, store only a workspace-relative `relative_path` plus optional `media_type_hint`, and enforce deterministic duplicate rejection through a unique index on `(user_id, task_workspace_id, relative_path)`. +- `task_artifact_chunks` persist explicit user-scoped durable chunk rows linked to one artifact, store ordered `sequence_no`, zero-based `char_start`, exclusive `char_end_exclusive`, and chunk `text`, and enforce deterministic uniqueness through a unique index on `(user_id, task_artifact_id, sequence_no)`. - `execution_budgets` enforce at most one active budget per `(user_id, tool_key, domain_hint)` selector scope through a partial unique index. - Per-request user context is set in the database through `app.current_user_id()`. - `TASK_WORKSPACE_ROOT` defines the only allowed base directory for workspace provisioning, and the live path rule is `resolved_root / user_id / task_id`. ### Repo Boundaries In This Slice -- `apps/api`: implemented API, store, contracts, service logic, and migrations for continuity, tracing, memory, embeddings, entities, policies, tools, approvals, proxy execution, execution budgets, tasks, task steps, task workspaces, and task artifacts. +- `apps/api`: implemented API, store, contracts, service logic, and migrations for continuity, tracing, memory, embeddings, entities, policies, tools, approvals, proxy execution, execution budgets, tasks, task steps, task workspaces, task artifacts, and narrow local artifact chunk ingestion. 
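The rooted-path rule described above can be made concrete with a small sketch. This is an illustrative assumption of how such a containment check can be expressed with `pathlib`, not a copy of the repository's `ensure_artifact_path_is_rooted` helper (whose body is introduced earlier in the series); the helper name below is hypothetical:

```python
from pathlib import Path

# Hypothetical containment check: a candidate artifact path is accepted only if
# its fully resolved form stays at or beneath the resolved workspace root
# (resolved_root / user_id / task_id in the live path rule).
def is_rooted_under_workspace(workspace_path: Path, artifact_path: Path) -> bool:
    workspace = workspace_path.expanduser().resolve()
    artifact = artifact_path.expanduser().resolve()
    return artifact == workspace or artifact.is_relative_to(workspace)
```

Under this sketch, `<root>/user/task/docs/spec.txt` would be accepted, while `<root>/user/task/../escape.txt` resolves outside the workspace and would be rejected, which lines up with the 400 "escapes workspace root" behaviour exercised by the tests in this patch.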
- `apps/web`: minimal shell only; no shipped workflow UI. - `workers`: scaffold only; no background jobs or runner logic are implemented. - `infra`: local development bootstrap assets only. -- `tests`: unit and Postgres-backed integration coverage for the shipped seams above, including Sprint 4O task-step lineage/manual continuation, Sprint 4S step-linked execution synchronization, Sprint 5A task-workspace provisioning, and Sprint 5C task-artifact registration. +- `tests`: unit and Postgres-backed integration coverage for the shipped seams above, including Sprint 4O task-step lineage/manual continuation, Sprint 4S step-linked execution synchronization, Sprint 5A task-workspace provisioning, Sprint 5C task-artifact registration, and Sprint 5D local artifact ingestion plus chunk reads. ## Core Flows Implemented Now @@ -175,12 +176,26 @@ The current multi-step boundary is narrow and explicit. Manual continuation is i 8. `GET /v0/task-artifacts` lists visible artifact rows in deterministic `created_at ASC, id ASC` order. 9. `GET /v0/task-artifacts/{task_artifact_id}` returns one user-visible artifact detail record. +### Task Artifact Ingestion And Chunk Reads + +1. Accept a user-scoped `POST /v0/task-artifacts/{task_artifact_id}/ingest` request for one visible registered artifact. +2. Lock ingestion for that artifact before deciding whether work is needed. +3. Resolve the persisted workspace `local_path` plus persisted artifact `relative_path`, and reject any rooted-path escape deterministically. +4. Support only the narrow explicit text set: `text/plain` and `text/markdown`. +5. Read file bytes deterministically and require valid UTF-8 text. +6. Normalize line endings by rewriting `\r\n` and `\r` to `\n`. +7. Chunk normalized text deterministically with rule `normalized_utf8_text_fixed_window_1000_chars_v1`. +8. Persist ordered `task_artifact_chunks` rows with `sequence_no`, `char_start`, `char_end_exclusive`, and `text`. +9. Update the parent artifact to `ingestion_status = ingested`. +10. If the artifact is already ingested, return the existing artifact and chunk summary without reinserting chunks. +11. `GET /v0/task-artifacts/{task_artifact_id}/chunks` returns visible chunk rows in deterministic `sequence_no ASC, id ASC` order plus stable summary metadata. + ## Security Model Implemented Now -- User-owned continuity, trace, memory, embedding, entity, governance, task, task-step, task-workspace, and task-artifact tables enforce row-level security. +- User-owned continuity, trace, memory, embedding, entity, governance, task, task-step, task-workspace, task-artifact, and task-artifact-chunk tables enforce row-level security. - The runtime role is limited to the narrow `SELECT` / `INSERT` / `UPDATE` permissions required by the shipped seams; there is no broad DDL or unrestricted table access at runtime. - Cross-user references are constrained through composite foreign keys on `(id, user_id)` where the schema needs ownership-linked joins. -- Approval, execution, memory, entity, task/task-step, task-workspace, and task-artifact reads all operate only inside the current user scope. +- Approval, execution, memory, entity, task/task-step, task-workspace, task-artifact, and task-artifact-chunk reads all operate only inside the current user scope. 
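A short worked sketch may help make the deterministic chunk boundaries in the ingestion flow above concrete. It mirrors the `normalize_artifact_text` and `chunk_normalized_artifact_text` helpers added later in this patch and assumes nothing beyond them:

```python
# Chunking rule normalized_utf8_text_fixed_window_1000_chars_v1 on a small example.
raw = "a\r\nb\r" + "x" * 2496                               # 2,501 raw chars with CRLF and CR endings
normalized = raw.replace("\r\n", "\n").replace("\r", "\n")   # 2,500 chars, "\n" only
chunks = [
    (start, min(start + 1000, len(normalized)), normalized[start:start + 1000])
    for start in range(0, len(normalized), 1000)
]
# (char_start, char_end_exclusive) per chunk: (0, 1000), (1000, 2000), (2000, 2500);
# persisted sequence_no values are 1, 2, 3, matching the unique
# (user_id, task_artifact_id, sequence_no) index on task_artifact_chunks.
```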
- Task-step manual continuation adds both schema-level and service-level lineage protection: - schema-level: user-scoped foreign keys and parent-not-self check - service-level: same-task, latest-step, visible-approval, visible-execution, and parent-outcome-match validation @@ -205,7 +220,11 @@ The current multi-step boundary is narrow and explicit. Manual continuation is i - artifact register/list/detail response shape - rooted artifact-path enforcement beneath the persisted workspace path - duplicate artifact registration rejection for the same workspace-relative path - - task-artifact per-user isolation + - supported `text/plain` and `text/markdown` ingestion + - deterministic line-ending normalization and fixed-window chunk boundaries + - invalid UTF-8 rejection + - idempotent re-ingestion of already ingested artifacts + - task-artifact and task-artifact-chunk per-user isolation - trace visibility for continuation and transition events - user isolation for task and task-step reads and mutations - adversarial lineage validation for cross-task, cross-user, and parent-step mismatch cases @@ -215,7 +234,8 @@ The current multi-step boundary is narrow and explicit. Manual continuation is i The following areas remain planned later and must not be described as implemented: - runner-style orchestration and automatic multi-step progression beyond the current explicit manual continuation seam -- artifact indexing, artifact content processing, and document ingestion beyond the current explicit rooted local registration boundary +- retrieval over artifact chunks, chunk ranking, and embeddings beyond the current explicit rooted local ingestion boundary +- rich document parsing beyond the current narrow UTF-8 text and markdown ingestion boundary - read-only Gmail and Calendar connectors - broader tool proxying and real-world side effects beyond the current no-I/O `proxy.echo` handler - model-driven extraction, reranking, and broader memory review automation diff --git a/apps/api/alembic/versions/20260314_0024_task_artifact_chunks.py b/apps/api/alembic/versions/20260314_0024_task_artifact_chunks.py new file mode 100644 index 0000000..ee51410 --- /dev/null +++ b/apps/api/alembic/versions/20260314_0024_task_artifact_chunks.py @@ -0,0 +1,97 @@ +"""Add user-scoped task artifact chunk records.""" + +from __future__ import annotations + +from alembic import op + + +revision = "20260314_0024" +down_revision = "20260313_0023" +branch_labels = None +depends_on = None + +_RLS_TABLES = ("task_artifact_chunks",) + +_UPGRADE_SCHEMA_STATEMENT = """ + CREATE TABLE task_artifact_chunks ( + id uuid PRIMARY KEY DEFAULT gen_random_uuid(), + user_id uuid NOT NULL REFERENCES users(id) ON DELETE CASCADE, + task_artifact_id uuid NOT NULL, + sequence_no integer NOT NULL, + char_start integer NOT NULL, + char_end_exclusive integer NOT NULL, + text text NOT NULL, + created_at timestamptz NOT NULL DEFAULT now(), + updated_at timestamptz NOT NULL DEFAULT now(), + UNIQUE (id, user_id), + CONSTRAINT task_artifact_chunks_artifact_user_fk + FOREIGN KEY (task_artifact_id, user_id) + REFERENCES task_artifacts(id, user_id) + ON DELETE CASCADE, + CONSTRAINT task_artifact_chunks_sequence_no_check + CHECK (sequence_no >= 1), + CONSTRAINT task_artifact_chunks_char_start_check + CHECK (char_start >= 0), + CONSTRAINT task_artifact_chunks_char_end_exclusive_check + CHECK (char_end_exclusive > char_start), + CONSTRAINT task_artifact_chunks_text_nonempty_check + CHECK (length(text) > 0) + ); + + CREATE UNIQUE INDEX 
task_artifact_chunks_artifact_sequence_idx + ON task_artifact_chunks (user_id, task_artifact_id, sequence_no); + """ + +_UPGRADE_GRANT_STATEMENTS = ( + "GRANT UPDATE ON task_artifacts TO alicebot_app", + "GRANT SELECT, INSERT ON task_artifact_chunks TO alicebot_app", +) + +_UPGRADE_POLICY_STATEMENT = """ + CREATE POLICY task_artifact_chunks_is_owner ON task_artifact_chunks + USING (user_id = app.current_user_id()) + WITH CHECK (user_id = app.current_user_id()); + """ + +_UPGRADE_TASK_ARTIFACTS_STATEMENTS = ( + "ALTER TABLE task_artifacts DROP CONSTRAINT task_artifacts_ingestion_status_check", + """ + ALTER TABLE task_artifacts + ADD CONSTRAINT task_artifacts_ingestion_status_check + CHECK (ingestion_status IN ('pending', 'ingested')) + """, +) + +_DOWNGRADE_STATEMENTS = ( + "REVOKE UPDATE ON task_artifacts FROM alicebot_app", + "DROP TABLE IF EXISTS task_artifact_chunks", + "ALTER TABLE task_artifacts DROP CONSTRAINT task_artifacts_ingestion_status_check", + """ + ALTER TABLE task_artifacts + ADD CONSTRAINT task_artifacts_ingestion_status_check + CHECK (ingestion_status IN ('pending')) + """, +) + + +def _execute_statements(statements: tuple[str, ...]) -> None: + for statement in statements: + op.execute(statement) + + +def _enable_row_level_security() -> None: + for table_name in _RLS_TABLES: + op.execute(f"ALTER TABLE {table_name} ENABLE ROW LEVEL SECURITY") + op.execute(f"ALTER TABLE {table_name} FORCE ROW LEVEL SECURITY") + + +def upgrade() -> None: + _execute_statements(_UPGRADE_TASK_ARTIFACTS_STATEMENTS) + op.execute(_UPGRADE_SCHEMA_STATEMENT) + _execute_statements(_UPGRADE_GRANT_STATEMENTS) + _enable_row_level_security() + op.execute(_UPGRADE_POLICY_STATEMENT) + + +def downgrade() -> None: + _execute_statements(_DOWNGRADE_STATEMENTS) diff --git a/apps/api/src/alicebot_api/artifacts.py b/apps/api/src/alicebot_api/artifacts.py index c91f77c..611dbe1 100644 --- a/apps/api/src/alicebot_api/artifacts.py +++ b/apps/api/src/alicebot_api/artifacts.py @@ -8,17 +8,33 @@ from alicebot_api.contracts import ( TASK_ARTIFACT_LIST_ORDER, + TASK_ARTIFACT_CHUNK_LIST_ORDER, TaskArtifactCreateResponse, TaskArtifactDetailResponse, + TaskArtifactChunkListResponse, + TaskArtifactChunkListSummary, + TaskArtifactChunkRecord, TaskArtifactListResponse, TaskArtifactRecord, + TaskArtifactIngestInput, + TaskArtifactIngestionResponse, TaskArtifactRegisterInput, TaskArtifactStatus, TaskArtifactIngestionStatus, ) -from alicebot_api.store import ContinuityStore, TaskArtifactRow +from alicebot_api.store import ContinuityStore, TaskArtifactChunkRow, TaskArtifactRow from alicebot_api.workspaces import TaskWorkspaceNotFoundError +SUPPORTED_TEXT_ARTIFACT_MEDIA_TYPES = ("text/plain", "text/markdown") +SUPPORTED_TEXT_ARTIFACT_EXTENSIONS = { + ".txt": "text/plain", + ".text": "text/plain", + ".md": "text/markdown", + ".markdown": "text/markdown", +} +TASK_ARTIFACT_CHUNK_MAX_CHARS = 1000 +TASK_ARTIFACT_CHUNKING_RULE = "normalized_utf8_text_fixed_window_1000_chars_v1" + class TaskArtifactNotFoundError(LookupError): """Raised when a task artifact is not visible inside the current user scope.""" @@ -83,6 +99,79 @@ def serialize_task_artifact_row(row: TaskArtifactRow) -> TaskArtifactRecord: } +def serialize_task_artifact_chunk_row(row: TaskArtifactChunkRow) -> TaskArtifactChunkRecord: + return { + "id": str(row["id"]), + "task_artifact_id": str(row["task_artifact_id"]), + "sequence_no": row["sequence_no"], + "char_start": row["char_start"], + "char_end_exclusive": row["char_end_exclusive"], + "text": row["text"], + "created_at": 
row["created_at"].isoformat(), + "updated_at": row["updated_at"].isoformat(), + } + + +def infer_task_artifact_media_type(row: TaskArtifactRow) -> str | None: + if row["media_type_hint"] is not None: + return row["media_type_hint"] + + artifact_path = Path(row["relative_path"]) + return SUPPORTED_TEXT_ARTIFACT_EXTENSIONS.get(artifact_path.suffix.lower()) + + +def resolve_supported_task_artifact_media_type(row: TaskArtifactRow) -> str: + media_type = infer_task_artifact_media_type(row) + if media_type in SUPPORTED_TEXT_ARTIFACT_MEDIA_TYPES: + return cast(str, media_type) + + supported_types = ", ".join(SUPPORTED_TEXT_ARTIFACT_MEDIA_TYPES) + raise TaskArtifactValidationError( + f"artifact {row['relative_path']} has unsupported media type " + f"{media_type or 'unknown'}; supported types: {supported_types}" + ) + + +def normalize_artifact_text(text: str) -> str: + return text.replace("\r\n", "\n").replace("\r", "\n") + + +def chunk_normalized_artifact_text( + text: str, + *, + chunk_size: int = TASK_ARTIFACT_CHUNK_MAX_CHARS, +) -> list[tuple[int, int, str]]: + chunks: list[tuple[int, int, str]] = [] + for char_start in range(0, len(text), chunk_size): + char_end_exclusive = min(char_start + chunk_size, len(text)) + chunks.append((char_start, char_end_exclusive, text[char_start:char_end_exclusive])) + return chunks + + +def resolve_registered_artifact_path(*, workspace_path: Path, relative_path: str) -> Path: + artifact_path = (workspace_path / relative_path).resolve() + ensure_artifact_path_is_rooted( + workspace_path=workspace_path, + artifact_path=artifact_path, + ) + return artifact_path + + +def build_task_artifact_chunk_list_summary( + chunk_rows: list[TaskArtifactChunkRow], + *, + media_type: str, +) -> TaskArtifactChunkListSummary: + total_characters = sum(len(row["text"]) for row in chunk_rows) + return { + "total_count": len(chunk_rows), + "total_characters": total_characters, + "media_type": media_type, + "chunking_rule": TASK_ARTIFACT_CHUNKING_RULE, + "order": list(TASK_ARTIFACT_CHUNK_LIST_ORDER), + } + + def register_task_artifact_record( store: ContinuityStore, *, @@ -171,3 +260,92 @@ def get_task_artifact_record( if row is None: raise TaskArtifactNotFoundError(f"task artifact {task_artifact_id} was not found") return {"artifact": serialize_task_artifact_row(row)} + + +def ingest_task_artifact_record( + store: ContinuityStore, + *, + user_id: UUID, + request: TaskArtifactIngestInput, +) -> TaskArtifactIngestionResponse: + del user_id + + row = store.get_task_artifact_optional(request.task_artifact_id) + if row is None: + raise TaskArtifactNotFoundError(f"task artifact {request.task_artifact_id} was not found") + + store.lock_task_artifact_ingestion(row["id"]) + row = store.get_task_artifact_optional(request.task_artifact_id) + if row is None: + raise TaskArtifactNotFoundError(f"task artifact {request.task_artifact_id} was not found") + + media_type = resolve_supported_task_artifact_media_type(row) + chunk_rows = store.list_task_artifact_chunks(row["id"]) + if row["ingestion_status"] == "ingested": + return { + "artifact": serialize_task_artifact_row(row), + "summary": build_task_artifact_chunk_list_summary(chunk_rows, media_type=media_type), + } + + workspace = store.get_task_workspace_optional(row["task_workspace_id"]) + if workspace is None: + raise TaskWorkspaceNotFoundError( + f"task workspace {row['task_workspace_id']} was not found" + ) + + workspace_path = Path(workspace["local_path"]).expanduser().resolve() + artifact_path = resolve_registered_artifact_path( + 
workspace_path=workspace_path, + relative_path=row["relative_path"], + ) + _require_existing_file(artifact_path) + + try: + text = artifact_path.read_bytes().decode("utf-8") + except UnicodeDecodeError as exc: + raise TaskArtifactValidationError( + f"artifact {row['relative_path']} is not valid UTF-8 text" + ) from exc + + normalized_text = normalize_artifact_text(text) + for index, (char_start, char_end_exclusive, chunk_text) in enumerate( + chunk_normalized_artifact_text(normalized_text), + start=1, + ): + store.create_task_artifact_chunk( + task_artifact_id=row["id"], + sequence_no=index, + char_start=char_start, + char_end_exclusive=char_end_exclusive, + text=chunk_text, + ) + + artifact_row = store.update_task_artifact_ingestion_status( + task_artifact_id=row["id"], + ingestion_status="ingested", + ) + chunk_rows = store.list_task_artifact_chunks(row["id"]) + return { + "artifact": serialize_task_artifact_row(artifact_row), + "summary": build_task_artifact_chunk_list_summary(chunk_rows, media_type=media_type), + } + + +def list_task_artifact_chunk_records( + store: ContinuityStore, + *, + user_id: UUID, + task_artifact_id: UUID, +) -> TaskArtifactChunkListResponse: + del user_id + + row = store.get_task_artifact_optional(task_artifact_id) + if row is None: + raise TaskArtifactNotFoundError(f"task artifact {task_artifact_id} was not found") + + chunk_rows = store.list_task_artifact_chunks(task_artifact_id) + media_type = infer_task_artifact_media_type(row) or "unknown" + return { + "items": [serialize_task_artifact_chunk_row(chunk_row) for chunk_row in chunk_rows], + "summary": build_task_artifact_chunk_list_summary(chunk_rows, media_type=media_type), + } diff --git a/apps/api/src/alicebot_api/contracts.py b/apps/api/src/alicebot_api/contracts.py index 06113c4..aa68f2e 100644 --- a/apps/api/src/alicebot_api/contracts.py +++ b/apps/api/src/alicebot_api/contracts.py @@ -21,7 +21,7 @@ TaskStatus = Literal["pending_approval", "approved", "executed", "denied", "blocked"] TaskWorkspaceStatus = Literal["active"] TaskArtifactStatus = Literal["registered"] -TaskArtifactIngestionStatus = Literal["pending"] +TaskArtifactIngestionStatus = Literal["pending", "ingested"] TaskLifecycleSource = Literal[ "approval_request", "approval_resolution", @@ -132,6 +132,7 @@ TASK_LIST_ORDER = ["created_at_asc", "id_asc"] TASK_WORKSPACE_LIST_ORDER = ["created_at_asc", "id_asc"] TASK_ARTIFACT_LIST_ORDER = ["created_at_asc", "id_asc"] +TASK_ARTIFACT_CHUNK_LIST_ORDER = ["sequence_no_asc", "id_asc"] TASK_STEP_LIST_ORDER = ["sequence_no_asc", "created_at_asc", "id_asc"] TOOL_EXECUTION_LIST_ORDER = ["executed_at_asc", "id_asc"] EXECUTION_BUDGET_LIST_ORDER = ["created_at_asc", "id_asc"] @@ -140,7 +141,7 @@ TASK_STATUSES = ["pending_approval", "approved", "executed", "denied", "blocked"] TASK_WORKSPACE_STATUSES = ["active"] TASK_ARTIFACT_STATUSES = ["registered"] -TASK_ARTIFACT_INGESTION_STATUSES = ["pending"] +TASK_ARTIFACT_INGESTION_STATUSES = ["pending", "ingested"] TASK_STEP_KINDS = ["governed_request"] TASK_STEP_STATUSES = ["created", "approved", "executed", "blocked", "denied"] APPROVAL_REQUEST_VERSION_V0 = "approval_request_v0" @@ -1606,6 +1607,11 @@ class TaskArtifactRegisterInput: media_type_hint: str | None = None +@dataclass(frozen=True, slots=True) +class TaskArtifactIngestInput: + task_artifact_id: UUID + + class TaskArtifactRecord(TypedDict): id: str task_id: str @@ -1636,6 +1642,35 @@ class TaskArtifactDetailResponse(TypedDict): artifact: TaskArtifactRecord +class TaskArtifactChunkRecord(TypedDict): + id: 
str + task_artifact_id: str + sequence_no: int + char_start: int + char_end_exclusive: int + text: str + created_at: str + updated_at: str + + +class TaskArtifactChunkListSummary(TypedDict): + total_count: int + total_characters: int + media_type: str + chunking_rule: str + order: list[str] + + +class TaskArtifactChunkListResponse(TypedDict): + items: list[TaskArtifactChunkRecord] + summary: TaskArtifactChunkListSummary + + +class TaskArtifactIngestionResponse(TypedDict): + artifact: TaskArtifactRecord + summary: TaskArtifactChunkListSummary + + class TaskStepTraceLink(TypedDict): trace_id: str trace_kind: str diff --git a/apps/api/src/alicebot_api/main.py b/apps/api/src/alicebot_api/main.py index 5bce9bf..ebdf2d3 100644 --- a/apps/api/src/alicebot_api/main.py +++ b/apps/api/src/alicebot_api/main.py @@ -50,6 +50,7 @@ ProxyExecutionStatus, ToolAllowlistEvaluationRequestInput, ProxyExecutionRequestInput, + TaskArtifactIngestInput, TaskArtifactRegisterInput, TaskStepKind, TaskStepLineageInput, @@ -66,6 +67,8 @@ TaskArtifactNotFoundError, TaskArtifactValidationError, get_task_artifact_record, + ingest_task_artifact_record, + list_task_artifact_chunk_records, list_task_artifact_records, register_task_artifact_record, ) @@ -413,6 +416,10 @@ class RegisterTaskArtifactRequest(BaseModel): media_type_hint: str | None = Field(default=None, min_length=1, max_length=200) +class IngestTaskArtifactRequest(BaseModel): + user_id: UUID + + class TaskStepRequestSnapshot(BaseModel): thread_id: UUID tool_id: UUID @@ -1254,6 +1261,53 @@ def get_task_artifact(task_artifact_id: UUID, user_id: UUID) -> JSONResponse: ) +@app.post("/v0/task-artifacts/{task_artifact_id}/ingest") +def ingest_task_artifact( + task_artifact_id: UUID, + request: IngestTaskArtifactRequest, +) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, request.user_id) as conn: + payload = ingest_task_artifact_record( + ContinuityStore(conn), + user_id=request.user_id, + request=TaskArtifactIngestInput(task_artifact_id=task_artifact_id), + ) + except TaskArtifactNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + except TaskWorkspaceNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + except TaskArtifactValidationError as exc: + return JSONResponse(status_code=400, content={"detail": str(exc)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/task-artifacts/{task_artifact_id}/chunks") +def list_task_artifact_chunks(task_artifact_id: UUID, user_id: UUID) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, user_id) as conn: + payload = list_task_artifact_chunk_records( + ContinuityStore(conn), + user_id=user_id, + task_artifact_id=task_artifact_id, + ) + except TaskArtifactNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + @app.post("/v0/tasks/{task_id}/steps") def create_next_task_step(task_id: UUID, request: CreateNextTaskStepRequest) -> JSONResponse: settings = get_settings() diff --git a/apps/api/src/alicebot_api/store.py b/apps/api/src/alicebot_api/store.py index c849a35..13c7cc3 100644 --- a/apps/api/src/alicebot_api/store.py +++ b/apps/api/src/alicebot_api/store.py @@ -258,6 +258,18 @@ class TaskArtifactRow(TypedDict): updated_at: datetime +class TaskArtifactChunkRow(TypedDict): + id: 
UUID + user_id: UUID + task_artifact_id: UUID + sequence_no: int + char_start: int + char_end_exclusive: int + text: str + created_at: datetime + updated_at: datetime + + class TaskStepRow(TypedDict): id: UUID user_id: UUID @@ -1517,6 +1529,75 @@ class LabelCountRow(TypedDict): ORDER BY created_at ASC, id ASC """ +LOCK_TASK_ARTIFACT_INGESTION_SQL = "SELECT pg_advisory_xact_lock(hashtextextended(%s::text, 5))" + +INSERT_TASK_ARTIFACT_CHUNK_SQL = """ + INSERT INTO task_artifact_chunks ( + user_id, + task_artifact_id, + sequence_no, + char_start, + char_end_exclusive, + text, + created_at, + updated_at + ) + VALUES ( + app.current_user_id(), + %s, + %s, + %s, + %s, + %s, + clock_timestamp(), + clock_timestamp() + ) + RETURNING + id, + user_id, + task_artifact_id, + sequence_no, + char_start, + char_end_exclusive, + text, + created_at, + updated_at + """ + +LIST_TASK_ARTIFACT_CHUNKS_SQL = """ + SELECT + id, + user_id, + task_artifact_id, + sequence_no, + char_start, + char_end_exclusive, + text, + created_at, + updated_at + FROM task_artifact_chunks + WHERE task_artifact_id = %s + ORDER BY sequence_no ASC, id ASC + """ + +UPDATE_TASK_ARTIFACT_INGESTION_STATUS_SQL = """ + UPDATE task_artifacts + SET ingestion_status = %s, + updated_at = clock_timestamp() + WHERE id = %s + RETURNING + id, + user_id, + task_id, + task_workspace_id, + status, + ingestion_status, + relative_path, + media_type_hint, + created_at, + updated_at + """ + INSERT_TASK_STEP_SQL = """ INSERT INTO task_steps ( user_id, @@ -2650,6 +2731,40 @@ def get_task_artifact_by_workspace_relative_path_optional( def list_task_artifacts(self) -> list[TaskArtifactRow]: return self._fetch_all(LIST_TASK_ARTIFACTS_SQL) + def lock_task_artifact_ingestion(self, task_artifact_id: UUID) -> None: + with self.conn.cursor() as cur: + cur.execute(LOCK_TASK_ARTIFACT_INGESTION_SQL, (str(task_artifact_id),)) + + def create_task_artifact_chunk( + self, + *, + task_artifact_id: UUID, + sequence_no: int, + char_start: int, + char_end_exclusive: int, + text: str, + ) -> TaskArtifactChunkRow: + return self._fetch_one( + "create_task_artifact_chunk", + INSERT_TASK_ARTIFACT_CHUNK_SQL, + (task_artifact_id, sequence_no, char_start, char_end_exclusive, text), + ) + + def list_task_artifact_chunks(self, task_artifact_id: UUID) -> list[TaskArtifactChunkRow]: + return self._fetch_all(LIST_TASK_ARTIFACT_CHUNKS_SQL, (task_artifact_id,)) + + def update_task_artifact_ingestion_status( + self, + *, + task_artifact_id: UUID, + ingestion_status: str, + ) -> TaskArtifactRow: + return self._fetch_one( + "update_task_artifact_ingestion_status", + UPDATE_TASK_ARTIFACT_INGESTION_STATUS_SQL, + (ingestion_status, task_artifact_id), + ) + def lock_task_steps(self, task_id: UUID) -> None: with self.conn.cursor() as cur: cur.execute(LOCK_TASK_STEPS_SQL, (str(task_id),)) diff --git a/tests/integration/test_migrations.py b/tests/integration/test_migrations.py index 250c270..001f1a0 100644 --- a/tests/integration/test_migrations.py +++ b/tests/integration/test_migrations.py @@ -301,6 +301,8 @@ def test_migrations_upgrade_and_downgrade(database_urls): assert cur.fetchone()[0] == "task_workspaces" cur.execute("SELECT to_regclass('public.task_artifacts')") assert cur.fetchone()[0] == "task_artifacts" + cur.execute("SELECT to_regclass('public.task_artifact_chunks')") + assert cur.fetchone()[0] == "task_artifact_chunks" cur.execute("SELECT to_regclass('public.task_steps')") assert cur.fetchone()[0] == "task_steps" cur.execute( @@ -383,6 +385,7 @@ def 
test_migrations_upgrade_and_downgrade(database_urls): 'tasks', 'task_workspaces', 'task_artifacts', + 'task_artifact_chunks', 'task_steps', 'execution_budgets', 'tool_executions' @@ -404,6 +407,7 @@ def test_migrations_upgrade_and_downgrade(database_urls): ("memory_revisions", True, True), ("policies", True, True), ("sessions", True, True), + ("task_artifact_chunks", True, True), ("task_artifacts", True, True), ("task_steps", True, True), ("task_workspaces", True, True), @@ -473,6 +477,8 @@ def test_migrations_upgrade_and_downgrade(database_urls): has_table_privilege('alicebot_app', 'task_workspaces', 'DELETE'), has_table_privilege('alicebot_app', 'task_artifacts', 'UPDATE'), has_table_privilege('alicebot_app', 'task_artifacts', 'DELETE'), + has_table_privilege('alicebot_app', 'task_artifact_chunks', 'UPDATE'), + has_table_privilege('alicebot_app', 'task_artifact_chunks', 'DELETE'), has_table_privilege('alicebot_app', 'task_steps', 'UPDATE'), has_table_privilege('alicebot_app', 'task_steps', 'DELETE'), has_table_privilege('alicebot_app', 'execution_budgets', 'UPDATE'), @@ -510,6 +516,8 @@ def test_migrations_upgrade_and_downgrade(database_urls): False, False, False, + True, + False, False, False, True, @@ -524,6 +532,8 @@ def test_migrations_upgrade_and_downgrade(database_urls): with psycopg.connect(database_urls["admin"]) as conn: with conn.cursor() as cur: + cur.execute("SELECT to_regclass('public.task_artifact_chunks')") + assert cur.fetchone()[0] is None cur.execute("SELECT to_regclass('public.task_artifacts')") assert cur.fetchone()[0] is None cur.execute("SELECT to_regclass('public.task_workspaces')") diff --git a/tests/integration/test_task_artifacts_api.py b/tests/integration/test_task_artifacts_api.py index ff78f47..cf9626b 100644 --- a/tests/integration/test_task_artifacts_api.py +++ b/tests/integration/test_task_artifacts_api.py @@ -7,6 +7,7 @@ from uuid import UUID, uuid4 import anyio +import psycopg import apps.api.src.alicebot_api.main as main_module from apps.api.src.alicebot_api.config import Settings @@ -291,3 +292,374 @@ def test_task_artifact_endpoints_register_list_detail_isolate_and_reject_duplica assert isolated_create_payload == { "detail": f"task workspace {workspace_payload['workspace']['id']} was not found" } + + +def test_task_artifact_ingestion_and_chunk_endpoints_are_deterministic_and_isolated( + migrated_database_urls, + monkeypatch, + tmp_path, +) -> None: + owner = seed_task(migrated_database_urls["app"], email="owner@example.com") + intruder = seed_task(migrated_database_urls["app"], email="intruder@example.com") + workspace_root = tmp_path / "task-workspaces" + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings( + database_url=migrated_database_urls["app"], + task_workspace_root=str(workspace_root), + ), + ) + + workspace_status, workspace_payload = invoke_request( + "POST", + f"/v0/tasks/{owner['task_id']}/workspace", + payload={"user_id": str(owner["user_id"])}, + ) + assert workspace_status == 201 + + workspace_path = Path(workspace_payload["workspace"]["local_path"]) + supported_file = workspace_path / "docs" / "spec.txt" + supported_file.parent.mkdir(parents=True) + supported_file.write_text(("A" * 998) + "\r\n" + ("B" * 5) + "\rC") + unsupported_file = workspace_path / "docs" / "manual.pdf" + unsupported_file.write_text("not really a pdf") + + register_status, register_payload = invoke_request( + "POST", + f"/v0/task-workspaces/{workspace_payload['workspace']['id']}/artifacts", + payload={ + "user_id": str(owner["user_id"]), + 
"local_path": str(supported_file), + "media_type_hint": "text/plain", + }, + ) + assert register_status == 201 + + ingest_status, ingest_payload = invoke_request( + "POST", + f"/v0/task-artifacts/{register_payload['artifact']['id']}/ingest", + payload={"user_id": str(owner["user_id"])}, + ) + chunk_list_status, chunk_list_payload = invoke_request( + "GET", + f"/v0/task-artifacts/{register_payload['artifact']['id']}/chunks", + query_params={"user_id": str(owner["user_id"])}, + ) + isolated_chunk_list_status, isolated_chunk_list_payload = invoke_request( + "GET", + f"/v0/task-artifacts/{register_payload['artifact']['id']}/chunks", + query_params={"user_id": str(intruder["user_id"])}, + ) + isolated_ingest_status, isolated_ingest_payload = invoke_request( + "POST", + f"/v0/task-artifacts/{register_payload['artifact']['id']}/ingest", + payload={"user_id": str(intruder["user_id"])}, + ) + + unsupported_register_status, unsupported_register_payload = invoke_request( + "POST", + f"/v0/task-workspaces/{workspace_payload['workspace']['id']}/artifacts", + payload={ + "user_id": str(owner["user_id"]), + "local_path": str(unsupported_file), + "media_type_hint": "application/pdf", + }, + ) + assert unsupported_register_status == 201 + unsupported_ingest_status, unsupported_ingest_payload = invoke_request( + "POST", + f"/v0/task-artifacts/{unsupported_register_payload['artifact']['id']}/ingest", + payload={"user_id": str(owner["user_id"])}, + ) + + assert ingest_status == 200 + assert ingest_payload == { + "artifact": { + "id": register_payload["artifact"]["id"], + "task_id": str(owner["task_id"]), + "task_workspace_id": workspace_payload["workspace"]["id"], + "status": "registered", + "ingestion_status": "ingested", + "relative_path": "docs/spec.txt", + "media_type_hint": "text/plain", + "created_at": register_payload["artifact"]["created_at"], + "updated_at": ingest_payload["artifact"]["updated_at"], + }, + "summary": { + "total_count": 2, + "total_characters": 1006, + "media_type": "text/plain", + "chunking_rule": "normalized_utf8_text_fixed_window_1000_chars_v1", + "order": ["sequence_no_asc", "id_asc"], + }, + } + + assert chunk_list_status == 200 + assert chunk_list_payload == { + "items": [ + { + "id": chunk_list_payload["items"][0]["id"], + "task_artifact_id": register_payload["artifact"]["id"], + "sequence_no": 1, + "char_start": 0, + "char_end_exclusive": 1000, + "text": ("A" * 998) + "\n" + "B", + "created_at": chunk_list_payload["items"][0]["created_at"], + "updated_at": chunk_list_payload["items"][0]["updated_at"], + }, + { + "id": chunk_list_payload["items"][1]["id"], + "task_artifact_id": register_payload["artifact"]["id"], + "sequence_no": 2, + "char_start": 1000, + "char_end_exclusive": 1006, + "text": "BBBB\nC", + "created_at": chunk_list_payload["items"][1]["created_at"], + "updated_at": chunk_list_payload["items"][1]["updated_at"], + }, + ], + "summary": { + "total_count": 2, + "total_characters": 1006, + "media_type": "text/plain", + "chunking_rule": "normalized_utf8_text_fixed_window_1000_chars_v1", + "order": ["sequence_no_asc", "id_asc"], + }, + } + + assert isolated_chunk_list_status == 404 + assert isolated_chunk_list_payload == { + "detail": f"task artifact {register_payload['artifact']['id']} was not found" + } + + assert isolated_ingest_status == 404 + assert isolated_ingest_payload == { + "detail": f"task artifact {register_payload['artifact']['id']} was not found" + } + + assert unsupported_ingest_status == 400 + assert unsupported_ingest_payload == { + "detail": ( + 
"artifact docs/manual.pdf has unsupported media type application/pdf; " + "supported types: text/plain, text/markdown" + ) + } + + +def test_task_artifact_ingestion_supports_markdown_and_reingest_is_idempotent( + migrated_database_urls, + monkeypatch, + tmp_path, +) -> None: + owner = seed_task(migrated_database_urls["app"], email="owner@example.com") + workspace_root = tmp_path / "task-workspaces" + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings( + database_url=migrated_database_urls["app"], + task_workspace_root=str(workspace_root), + ), + ) + + workspace_status, workspace_payload = invoke_request( + "POST", + f"/v0/tasks/{owner['task_id']}/workspace", + payload={"user_id": str(owner["user_id"])}, + ) + assert workspace_status == 201 + + workspace_path = Path(workspace_payload["workspace"]["local_path"]) + markdown_file = workspace_path / "notes" / "plan.md" + markdown_file.parent.mkdir(parents=True) + markdown_file.write_text("# Plan\r\n\r\n- Ship ingestion\n- Keep scope narrow\r") + + register_status, register_payload = invoke_request( + "POST", + f"/v0/task-workspaces/{workspace_payload['workspace']['id']}/artifacts", + payload={ + "user_id": str(owner["user_id"]), + "local_path": str(markdown_file), + "media_type_hint": "text/markdown", + }, + ) + assert register_status == 201 + + first_ingest_status, first_ingest_payload = invoke_request( + "POST", + f"/v0/task-artifacts/{register_payload['artifact']['id']}/ingest", + payload={"user_id": str(owner["user_id"])}, + ) + second_ingest_status, second_ingest_payload = invoke_request( + "POST", + f"/v0/task-artifacts/{register_payload['artifact']['id']}/ingest", + payload={"user_id": str(owner["user_id"])}, + ) + chunk_list_status, chunk_list_payload = invoke_request( + "GET", + f"/v0/task-artifacts/{register_payload['artifact']['id']}/chunks", + query_params={"user_id": str(owner["user_id"])}, + ) + + assert first_ingest_status == 200 + assert first_ingest_payload == { + "artifact": { + "id": register_payload["artifact"]["id"], + "task_id": str(owner["task_id"]), + "task_workspace_id": workspace_payload["workspace"]["id"], + "status": "registered", + "ingestion_status": "ingested", + "relative_path": "notes/plan.md", + "media_type_hint": "text/markdown", + "created_at": register_payload["artifact"]["created_at"], + "updated_at": first_ingest_payload["artifact"]["updated_at"], + }, + "summary": { + "total_count": 1, + "total_characters": 45, + "media_type": "text/markdown", + "chunking_rule": "normalized_utf8_text_fixed_window_1000_chars_v1", + "order": ["sequence_no_asc", "id_asc"], + }, + } + assert second_ingest_status == 200 + assert second_ingest_payload == first_ingest_payload + assert chunk_list_status == 200 + assert chunk_list_payload == { + "items": [ + { + "id": chunk_list_payload["items"][0]["id"], + "task_artifact_id": register_payload["artifact"]["id"], + "sequence_no": 1, + "char_start": 0, + "char_end_exclusive": 45, + "text": "# Plan\n\n- Ship ingestion\n- Keep scope narrow\n", + "created_at": chunk_list_payload["items"][0]["created_at"], + "updated_at": chunk_list_payload["items"][0]["updated_at"], + } + ], + "summary": { + "total_count": 1, + "total_characters": 45, + "media_type": "text/markdown", + "chunking_rule": "normalized_utf8_text_fixed_window_1000_chars_v1", + "order": ["sequence_no_asc", "id_asc"], + }, + } + + +def test_task_artifact_ingestion_rejects_invalid_utf8_content( + migrated_database_urls, + monkeypatch, + tmp_path, +) -> None: + owner = 
seed_task(migrated_database_urls["app"], email="owner@example.com") + workspace_root = tmp_path / "task-workspaces" + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings( + database_url=migrated_database_urls["app"], + task_workspace_root=str(workspace_root), + ), + ) + + workspace_status, workspace_payload = invoke_request( + "POST", + f"/v0/tasks/{owner['task_id']}/workspace", + payload={"user_id": str(owner["user_id"])}, + ) + assert workspace_status == 201 + + workspace_path = Path(workspace_payload["workspace"]["local_path"]) + broken_file = workspace_path / "docs" / "broken.txt" + broken_file.parent.mkdir(parents=True) + broken_file.write_bytes(b"\xff\xfe\xfd") + + register_status, register_payload = invoke_request( + "POST", + f"/v0/task-workspaces/{workspace_payload['workspace']['id']}/artifacts", + payload={ + "user_id": str(owner["user_id"]), + "local_path": str(broken_file), + "media_type_hint": "text/plain", + }, + ) + assert register_status == 201 + + ingest_status, ingest_payload = invoke_request( + "POST", + f"/v0/task-artifacts/{register_payload['artifact']['id']}/ingest", + payload={"user_id": str(owner["user_id"])}, + ) + + assert ingest_status == 400 + assert ingest_payload == { + "detail": "artifact docs/broken.txt is not valid UTF-8 text" + } + + +def test_task_artifact_ingestion_enforces_rooted_workspace_paths( + migrated_database_urls, + monkeypatch, + tmp_path, +) -> None: + owner = seed_task(migrated_database_urls["app"], email="owner@example.com") + workspace_root = tmp_path / "task-workspaces" + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings( + database_url=migrated_database_urls["app"], + task_workspace_root=str(workspace_root), + ), + ) + + workspace_status, workspace_payload = invoke_request( + "POST", + f"/v0/tasks/{owner['task_id']}/workspace", + payload={"user_id": str(owner["user_id"])}, + ) + assert workspace_status == 201 + + workspace_path = Path(workspace_payload["workspace"]["local_path"]) + safe_file = workspace_path / "docs" / "spec.txt" + safe_file.parent.mkdir(parents=True) + safe_file.write_text("spec") + outside_file = tmp_path / "escape.txt" + outside_file.write_text("escape") + + register_status, register_payload = invoke_request( + "POST", + f"/v0/task-workspaces/{workspace_payload['workspace']['id']}/artifacts", + payload={ + "user_id": str(owner["user_id"]), + "local_path": str(safe_file), + "media_type_hint": "text/plain", + }, + ) + assert register_status == 201 + + with psycopg.connect(migrated_database_urls["admin"]) as conn: + with conn.cursor() as cur: + cur.execute( + """ + UPDATE task_artifacts + SET relative_path = '../../../escape.txt' + WHERE id = %s + """, + (register_payload["artifact"]["id"],), + ) + conn.commit() + + ingest_status, ingest_payload = invoke_request( + "POST", + f"/v0/task-artifacts/{register_payload['artifact']['id']}/ingest", + payload={"user_id": str(owner["user_id"])}, + ) + + assert ingest_status == 400 + assert ingest_payload == { + "detail": f"artifact path {outside_file.resolve()} escapes workspace root {workspace_path.resolve()}" + } diff --git a/tests/unit/test_20260314_0024_task_artifact_chunks.py b/tests/unit/test_20260314_0024_task_artifact_chunks.py new file mode 100644 index 0000000..5e9e6da --- /dev/null +++ b/tests/unit/test_20260314_0024_task_artifact_chunks.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +import importlib + + +MODULE_NAME = "apps.api.alembic.versions.20260314_0024_task_artifact_chunks" + + +def load_migration_module(): + 
return importlib.import_module(MODULE_NAME) + + +def test_upgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.upgrade() + + assert executed == [ + *module._UPGRADE_TASK_ARTIFACTS_STATEMENTS, + module._UPGRADE_SCHEMA_STATEMENT, + *module._UPGRADE_GRANT_STATEMENTS, + "ALTER TABLE task_artifact_chunks ENABLE ROW LEVEL SECURITY", + "ALTER TABLE task_artifact_chunks FORCE ROW LEVEL SECURITY", + module._UPGRADE_POLICY_STATEMENT, + ] + + +def test_downgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.downgrade() + + assert executed == list(module._DOWNGRADE_STATEMENTS) + + +def test_task_artifact_chunk_privileges_allow_only_expected_runtime_writes() -> None: + module = load_migration_module() + + assert module._UPGRADE_GRANT_STATEMENTS == ( + "GRANT UPDATE ON task_artifacts TO alicebot_app", + "GRANT SELECT, INSERT ON task_artifact_chunks TO alicebot_app", + ) diff --git a/tests/unit/test_artifacts.py b/tests/unit/test_artifacts.py index 33d6fb7..e6ed44c 100644 --- a/tests/unit/test_artifacts.py +++ b/tests/unit/test_artifacts.py @@ -7,17 +7,22 @@ import pytest from alicebot_api.artifacts import ( + TASK_ARTIFACT_CHUNKING_RULE, TaskArtifactAlreadyExistsError, TaskArtifactNotFoundError, TaskArtifactValidationError, build_workspace_relative_artifact_path, + chunk_normalized_artifact_text, ensure_artifact_path_is_rooted, get_task_artifact_record, + ingest_task_artifact_record, + list_task_artifact_chunk_records, list_task_artifact_records, + normalize_artifact_text, register_task_artifact_record, serialize_task_artifact_row, ) -from alicebot_api.contracts import TaskArtifactRegisterInput +from alicebot_api.contracts import TaskArtifactIngestInput, TaskArtifactRegisterInput from alicebot_api.workspaces import TaskWorkspaceNotFoundError @@ -26,7 +31,9 @@ def __init__(self) -> None: self.base_time = datetime(2026, 3, 13, 10, 0, tzinfo=UTC) self.workspaces: list[dict[str, object]] = [] self.artifacts: list[dict[str, object]] = [] + self.artifact_chunks: list[dict[str, object]] = [] self.locked_workspace_ids: list[UUID] = [] + self.locked_artifact_ids: list[UUID] = [] def create_task_workspace(self, *, task_workspace_id: UUID, task_id: UUID, user_id: UUID, local_path: str) -> dict[str, object]: workspace = { @@ -94,6 +101,54 @@ def list_task_artifacts(self) -> list[dict[str, object]]: def get_task_artifact_optional(self, task_artifact_id: UUID) -> dict[str, object] | None: return next((artifact for artifact in self.artifacts if artifact["id"] == task_artifact_id), None) + def lock_task_artifact_ingestion(self, task_artifact_id: UUID) -> None: + self.locked_artifact_ids.append(task_artifact_id) + + def create_task_artifact_chunk( + self, + *, + task_artifact_id: UUID, + sequence_no: int, + char_start: int, + char_end_exclusive: int, + text: str, + ) -> dict[str, object]: + chunk = { + "id": uuid4(), + "user_id": self.workspaces[0]["user_id"], + "task_artifact_id": task_artifact_id, + "sequence_no": sequence_no, + "char_start": char_start, + "char_end_exclusive": char_end_exclusive, + "text": text, + "created_at": self.base_time + timedelta(seconds=len(self.artifact_chunks)), + "updated_at": self.base_time + timedelta(seconds=len(self.artifact_chunks)), + } + self.artifact_chunks.append(chunk) + return chunk 
+ + def list_task_artifact_chunks(self, task_artifact_id: UUID) -> list[dict[str, object]]: + return sorted( + ( + chunk + for chunk in self.artifact_chunks + if chunk["task_artifact_id"] == task_artifact_id + ), + key=lambda chunk: (chunk["sequence_no"], chunk["id"]), + ) + + def update_task_artifact_ingestion_status( + self, + *, + task_artifact_id: UUID, + ingestion_status: str, + ) -> dict[str, object]: + artifact = self.get_task_artifact_optional(task_artifact_id) + assert artifact is not None + artifact["ingestion_status"] = ingestion_status + artifact["updated_at"] = self.base_time + timedelta(minutes=30) + return artifact + def test_ensure_artifact_path_is_rooted_rejects_escape() -> None: with pytest.raises(TaskArtifactValidationError, match="escapes workspace root"): @@ -241,6 +296,380 @@ def test_register_task_artifact_record_rejects_paths_outside_workspace(tmp_path) ) +def test_normalize_and_chunk_artifact_text_are_deterministic() -> None: + normalized = normalize_artifact_text("ab\r\ncd\ref") + + assert normalized == "ab\ncd\nef" + assert chunk_normalized_artifact_text(normalized, chunk_size=4) == [ + (0, 4, "ab\nc"), + (4, 8, "d\nef"), + ] + + +def test_ingest_task_artifact_record_persists_deterministic_chunks(tmp_path) -> None: + store = ArtifactStoreStub() + user_id = uuid4() + task_id = uuid4() + task_workspace_id = uuid4() + workspace_path = tmp_path / "workspaces" / str(user_id) / str(task_id) + workspace_path.mkdir(parents=True) + artifact_path = workspace_path / "docs" / "spec.txt" + artifact_path.parent.mkdir(parents=True) + artifact_path.write_text(("A" * 998) + "\r\n" + ("B" * 5) + "\rC") + store.create_task_workspace( + task_workspace_id=task_workspace_id, + task_id=task_id, + user_id=user_id, + local_path=str(workspace_path), + ) + artifact = store.create_task_artifact( + task_id=task_id, + task_workspace_id=task_workspace_id, + status="registered", + ingestion_status="pending", + relative_path="docs/spec.txt", + media_type_hint="text/plain", + ) + + response = ingest_task_artifact_record( + store, + user_id=user_id, + request=TaskArtifactIngestInput(task_artifact_id=artifact["id"]), + ) + + assert response == { + "artifact": { + "id": str(artifact["id"]), + "task_id": str(task_id), + "task_workspace_id": str(task_workspace_id), + "status": "registered", + "ingestion_status": "ingested", + "relative_path": "docs/spec.txt", + "media_type_hint": "text/plain", + "created_at": "2026-03-13T10:00:00+00:00", + "updated_at": "2026-03-13T10:30:00+00:00", + }, + "summary": { + "total_count": 2, + "total_characters": 1006, + "media_type": "text/plain", + "chunking_rule": TASK_ARTIFACT_CHUNKING_RULE, + "order": ["sequence_no_asc", "id_asc"], + }, + } + assert store.locked_artifact_ids == [artifact["id"]] + assert store.list_task_artifact_chunks(artifact["id"]) == [ + { + "id": store.artifact_chunks[0]["id"], + "user_id": user_id, + "task_artifact_id": artifact["id"], + "sequence_no": 1, + "char_start": 0, + "char_end_exclusive": 1000, + "text": ("A" * 998) + "\n" + "B", + "created_at": datetime(2026, 3, 13, 10, 0, tzinfo=UTC), + "updated_at": datetime(2026, 3, 13, 10, 0, tzinfo=UTC), + }, + { + "id": store.artifact_chunks[1]["id"], + "user_id": user_id, + "task_artifact_id": artifact["id"], + "sequence_no": 2, + "char_start": 1000, + "char_end_exclusive": 1006, + "text": "BBBB\nC", + "created_at": datetime(2026, 3, 13, 10, 0, 1, tzinfo=UTC), + "updated_at": datetime(2026, 3, 13, 10, 0, 1, tzinfo=UTC), + }, + ] + + +def 
test_ingest_task_artifact_record_supports_markdown(tmp_path) -> None: + store = ArtifactStoreStub() + user_id = uuid4() + task_id = uuid4() + task_workspace_id = uuid4() + workspace_path = tmp_path / "workspaces" / str(user_id) / str(task_id) + workspace_path.mkdir(parents=True) + artifact_path = workspace_path / "notes" / "plan.md" + artifact_path.parent.mkdir(parents=True) + artifact_path.write_text("# Plan\r\n\r\n- Ship ingestion\n- Keep scope narrow\r") + store.create_task_workspace( + task_workspace_id=task_workspace_id, + task_id=task_id, + user_id=user_id, + local_path=str(workspace_path), + ) + artifact = store.create_task_artifact( + task_id=task_id, + task_workspace_id=task_workspace_id, + status="registered", + ingestion_status="pending", + relative_path="notes/plan.md", + media_type_hint="text/markdown", + ) + + response = ingest_task_artifact_record( + store, + user_id=user_id, + request=TaskArtifactIngestInput(task_artifact_id=artifact["id"]), + ) + + assert response["artifact"]["ingestion_status"] == "ingested" + assert response["summary"] == { + "total_count": 1, + "total_characters": 45, + "media_type": "text/markdown", + "chunking_rule": TASK_ARTIFACT_CHUNKING_RULE, + "order": ["sequence_no_asc", "id_asc"], + } + assert store.list_task_artifact_chunks(artifact["id"]) == [ + { + "id": store.artifact_chunks[0]["id"], + "user_id": user_id, + "task_artifact_id": artifact["id"], + "sequence_no": 1, + "char_start": 0, + "char_end_exclusive": 45, + "text": "# Plan\n\n- Ship ingestion\n- Keep scope narrow\n", + "created_at": datetime(2026, 3, 13, 10, 0, tzinfo=UTC), + "updated_at": datetime(2026, 3, 13, 10, 0, tzinfo=UTC), + } + ] + + +def test_ingest_task_artifact_record_is_idempotent_for_already_ingested_artifact() -> None: + store = ArtifactStoreStub() + user_id = uuid4() + task_id = uuid4() + task_workspace_id = uuid4() + store.create_task_workspace( + task_workspace_id=task_workspace_id, + task_id=task_id, + user_id=user_id, + local_path="/tmp/alicebot/task-workspaces/user/task", + ) + artifact = store.create_task_artifact( + task_id=task_id, + task_workspace_id=task_workspace_id, + status="registered", + ingestion_status="ingested", + relative_path="docs/spec.txt", + media_type_hint="text/plain", + ) + store.create_task_artifact_chunk( + task_artifact_id=artifact["id"], + sequence_no=1, + char_start=0, + char_end_exclusive=4, + text="spec", + ) + + response = ingest_task_artifact_record( + store, + user_id=user_id, + request=TaskArtifactIngestInput(task_artifact_id=artifact["id"]), + ) + + assert response == { + "artifact": { + "id": str(artifact["id"]), + "task_id": str(task_id), + "task_workspace_id": str(task_workspace_id), + "status": "registered", + "ingestion_status": "ingested", + "relative_path": "docs/spec.txt", + "media_type_hint": "text/plain", + "created_at": "2026-03-13T10:00:00+00:00", + "updated_at": "2026-03-13T10:00:00+00:00", + }, + "summary": { + "total_count": 1, + "total_characters": 4, + "media_type": "text/plain", + "chunking_rule": TASK_ARTIFACT_CHUNKING_RULE, + "order": ["sequence_no_asc", "id_asc"], + }, + } + assert store.locked_artifact_ids == [artifact["id"]] + assert len(store.artifact_chunks) == 1 + + +def test_ingest_task_artifact_record_rejects_unsupported_media_type(tmp_path) -> None: + store = ArtifactStoreStub() + user_id = uuid4() + task_id = uuid4() + task_workspace_id = uuid4() + workspace_path = tmp_path / "workspaces" / str(user_id) / str(task_id) + workspace_path.mkdir(parents=True) + artifact_path = workspace_path / "docs" / 
"spec.pdf" + artifact_path.parent.mkdir(parents=True) + artifact_path.write_text("not really a pdf") + store.create_task_workspace( + task_workspace_id=task_workspace_id, + task_id=task_id, + user_id=user_id, + local_path=str(workspace_path), + ) + artifact = store.create_task_artifact( + task_id=task_id, + task_workspace_id=task_workspace_id, + status="registered", + ingestion_status="pending", + relative_path="docs/spec.pdf", + media_type_hint="application/pdf", + ) + + with pytest.raises( + TaskArtifactValidationError, + match="artifact docs/spec.pdf has unsupported media type application/pdf", + ): + ingest_task_artifact_record( + store, + user_id=user_id, + request=TaskArtifactIngestInput(task_artifact_id=artifact["id"]), + ) + + +def test_ingest_task_artifact_record_rejects_invalid_utf8_content(tmp_path) -> None: + store = ArtifactStoreStub() + user_id = uuid4() + task_id = uuid4() + task_workspace_id = uuid4() + workspace_path = tmp_path / "workspaces" / str(user_id) / str(task_id) + workspace_path.mkdir(parents=True) + artifact_path = workspace_path / "docs" / "broken.txt" + artifact_path.parent.mkdir(parents=True) + artifact_path.write_bytes(b"\xff\xfe\xfd") + store.create_task_workspace( + task_workspace_id=task_workspace_id, + task_id=task_id, + user_id=user_id, + local_path=str(workspace_path), + ) + artifact = store.create_task_artifact( + task_id=task_id, + task_workspace_id=task_workspace_id, + status="registered", + ingestion_status="pending", + relative_path="docs/broken.txt", + media_type_hint="text/plain", + ) + + with pytest.raises( + TaskArtifactValidationError, + match="artifact docs/broken.txt is not valid UTF-8 text", + ): + ingest_task_artifact_record( + store, + user_id=user_id, + request=TaskArtifactIngestInput(task_artifact_id=artifact["id"]), + ) + + +def test_ingest_task_artifact_record_rejects_paths_outside_workspace(tmp_path) -> None: + store = ArtifactStoreStub() + user_id = uuid4() + task_id = uuid4() + task_workspace_id = uuid4() + workspace_path = tmp_path / "workspaces" / str(user_id) / str(task_id) + workspace_path.mkdir(parents=True) + outside_path = tmp_path / "escape.txt" + outside_path.write_text("escape") + store.create_task_workspace( + task_workspace_id=task_workspace_id, + task_id=task_id, + user_id=user_id, + local_path=str(workspace_path), + ) + artifact = store.create_task_artifact( + task_id=task_id, + task_workspace_id=task_workspace_id, + status="registered", + ingestion_status="pending", + relative_path="../escape.txt", + media_type_hint="text/plain", + ) + + with pytest.raises(TaskArtifactValidationError, match="escapes workspace root"): + ingest_task_artifact_record( + store, + user_id=user_id, + request=TaskArtifactIngestInput(task_artifact_id=artifact["id"]), + ) + + +def test_list_task_artifact_chunk_records_are_deterministic() -> None: + store = ArtifactStoreStub() + user_id = uuid4() + task_id = uuid4() + task_workspace_id = uuid4() + store.create_task_workspace( + task_workspace_id=task_workspace_id, + task_id=task_id, + user_id=user_id, + local_path="/tmp/alicebot/task-workspaces/user/task", + ) + artifact = store.create_task_artifact( + task_id=task_id, + task_workspace_id=task_workspace_id, + status="registered", + ingestion_status="ingested", + relative_path="docs/spec.txt", + media_type_hint="text/plain", + ) + store.create_task_artifact_chunk( + task_artifact_id=artifact["id"], + sequence_no=1, + char_start=0, + char_end_exclusive=4, + text="spec", + ) + store.create_task_artifact_chunk( + task_artifact_id=artifact["id"], + 
sequence_no=2, + char_start=4, + char_end_exclusive=8, + text="plan", + ) + + assert list_task_artifact_chunk_records( + store, + user_id=user_id, + task_artifact_id=artifact["id"], + ) == { + "items": [ + { + "id": str(store.artifact_chunks[0]["id"]), + "task_artifact_id": str(artifact["id"]), + "sequence_no": 1, + "char_start": 0, + "char_end_exclusive": 4, + "text": "spec", + "created_at": "2026-03-13T10:00:00+00:00", + "updated_at": "2026-03-13T10:00:00+00:00", + }, + { + "id": str(store.artifact_chunks[1]["id"]), + "task_artifact_id": str(artifact["id"]), + "sequence_no": 2, + "char_start": 4, + "char_end_exclusive": 8, + "text": "plan", + "created_at": "2026-03-13T10:00:01+00:00", + "updated_at": "2026-03-13T10:00:01+00:00", + }, + ], + "summary": { + "total_count": 2, + "total_characters": 8, + "media_type": "text/plain", + "chunking_rule": TASK_ARTIFACT_CHUNKING_RULE, + "order": ["sequence_no_asc", "id_asc"], + }, + } + + def test_list_and_get_task_artifact_records_are_deterministic() -> None: store = ArtifactStoreStub() user_id = uuid4() diff --git a/tests/unit/test_artifacts_main.py b/tests/unit/test_artifacts_main.py index a791655..9e6e1b7 100644 --- a/tests/unit/test_artifacts_main.py +++ b/tests/unit/test_artifacts_main.py @@ -64,6 +64,47 @@ def fake_get_task_artifact_record(*_args, **_kwargs): assert json.loads(response.body) == {"detail": f"task artifact {task_artifact_id} was not found"} +def test_list_task_artifact_chunks_endpoint_returns_payload(monkeypatch) -> None: + user_id = uuid4() + task_artifact_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr( + main_module, + "list_task_artifact_chunk_records", + lambda *_args, **_kwargs: { + "items": [], + "summary": { + "total_count": 0, + "total_characters": 0, + "media_type": "text/plain", + "chunking_rule": "normalized_utf8_text_fixed_window_1000_chars_v1", + "order": ["sequence_no_asc", "id_asc"], + }, + }, + ) + + response = main_module.list_task_artifact_chunks(task_artifact_id, user_id) + + assert response.status_code == 200 + assert json.loads(response.body) == { + "items": [], + "summary": { + "total_count": 0, + "total_characters": 0, + "media_type": "text/plain", + "chunking_rule": "normalized_utf8_text_fixed_window_1000_chars_v1", + "order": ["sequence_no_asc", "id_asc"], + }, + } + + def test_register_task_artifact_endpoint_maps_workspace_not_found_to_404(monkeypatch) -> None: user_id = uuid4() task_workspace_id = uuid4() @@ -152,3 +193,61 @@ def fake_register_task_artifact_record(*_args, **_kwargs): assert json.loads(response.body) == { "detail": f"artifact docs/spec.txt is already registered for task workspace {task_workspace_id}" } + + +def test_ingest_task_artifact_endpoint_maps_validation_to_400(monkeypatch) -> None: + user_id = uuid4() + task_artifact_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_ingest_task_artifact_record(*_args, **_kwargs): + raise TaskArtifactValidationError( + "artifact docs/spec.txt has unsupported media type application/pdf; " + "supported types: text/plain, text/markdown" + ) + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", 
fake_user_connection) + monkeypatch.setattr(main_module, "ingest_task_artifact_record", fake_ingest_task_artifact_record) + + response = main_module.ingest_task_artifact( + task_artifact_id, + main_module.IngestTaskArtifactRequest(user_id=user_id), + ) + + assert response.status_code == 400 + assert json.loads(response.body) == { + "detail": ( + "artifact docs/spec.txt has unsupported media type application/pdf; " + "supported types: text/plain, text/markdown" + ) + } + + +def test_ingest_task_artifact_endpoint_maps_not_found_to_404(monkeypatch) -> None: + user_id = uuid4() + task_artifact_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_ingest_task_artifact_record(*_args, **_kwargs): + raise TaskArtifactNotFoundError(f"task artifact {task_artifact_id} was not found") + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "ingest_task_artifact_record", fake_ingest_task_artifact_record) + + response = main_module.ingest_task_artifact( + task_artifact_id, + main_module.IngestTaskArtifactRequest(user_id=user_id), + ) + + assert response.status_code == 404 + assert json.loads(response.body) == {"detail": f"task artifact {task_artifact_id} was not found"} diff --git a/tests/unit/test_main.py b/tests/unit/test_main.py index 20b7a00..446f108 100644 --- a/tests/unit/test_main.py +++ b/tests/unit/test_main.py @@ -129,6 +129,8 @@ def test_healthcheck_route_is_registered() -> None: assert "/v0/task-workspaces/{task_workspace_id}/artifacts" in route_paths assert "/v0/task-artifacts" in route_paths assert "/v0/task-artifacts/{task_artifact_id}" in route_paths + assert "/v0/task-artifacts/{task_artifact_id}/ingest" in route_paths + assert "/v0/task-artifacts/{task_artifact_id}/chunks" in route_paths assert "/v0/task-steps/{task_step_id}" in route_paths assert "/v0/task-steps/{task_step_id}/transition" in route_paths assert "/v0/entities/{entity_id}" in route_paths diff --git a/tests/unit/test_task_artifact_store.py b/tests/unit/test_task_artifact_store.py index c6f6277..df841c0 100644 --- a/tests/unit/test_task_artifact_store.py +++ b/tests/unit/test_task_artifact_store.py @@ -226,3 +226,145 @@ def test_task_artifact_store_methods_use_expected_queries() -> None: (str(task_workspace_id),), ), ] + + +def test_task_artifact_chunk_store_methods_use_expected_queries() -> None: + task_artifact_id = uuid4() + cursor = RecordingCursor( + fetchone_results=[ + { + "id": uuid4(), + "user_id": uuid4(), + "task_artifact_id": task_artifact_id, + "sequence_no": 1, + "char_start": 0, + "char_end_exclusive": 4, + "text": "spec", + "created_at": "2026-03-14T10:00:00+00:00", + "updated_at": "2026-03-14T10:00:00+00:00", + }, + { + "id": task_artifact_id, + "user_id": uuid4(), + "task_id": uuid4(), + "task_workspace_id": uuid4(), + "status": "registered", + "ingestion_status": "ingested", + "relative_path": "docs/spec.txt", + "media_type_hint": "text/plain", + "created_at": "2026-03-14T10:00:00+00:00", + "updated_at": "2026-03-14T10:01:00+00:00", + }, + ], + fetchall_result=[ + { + "id": uuid4(), + "user_id": uuid4(), + "task_artifact_id": task_artifact_id, + "sequence_no": 1, + "char_start": 0, + "char_end_exclusive": 4, + "text": "spec", + "created_at": "2026-03-14T10:00:00+00:00", + "updated_at": "2026-03-14T10:00:00+00:00", + } + ], + ) + store = 
ContinuityStore(RecordingConnection(cursor)) + + created = store.create_task_artifact_chunk( + task_artifact_id=task_artifact_id, + sequence_no=1, + char_start=0, + char_end_exclusive=4, + text="spec", + ) + updated = store.update_task_artifact_ingestion_status( + task_artifact_id=task_artifact_id, + ingestion_status="ingested", + ) + listed = store.list_task_artifact_chunks(task_artifact_id) + store.lock_task_artifact_ingestion(task_artifact_id) + + assert created["task_artifact_id"] == task_artifact_id + assert updated["ingestion_status"] == "ingested" + assert listed[0]["task_artifact_id"] == task_artifact_id + assert cursor.executed == [ + ( + """ + INSERT INTO task_artifact_chunks ( + user_id, + task_artifact_id, + sequence_no, + char_start, + char_end_exclusive, + text, + created_at, + updated_at + ) + VALUES ( + app.current_user_id(), + %s, + %s, + %s, + %s, + %s, + clock_timestamp(), + clock_timestamp() + ) + RETURNING + id, + user_id, + task_artifact_id, + sequence_no, + char_start, + char_end_exclusive, + text, + created_at, + updated_at + """, + (task_artifact_id, 1, 0, 4, "spec"), + ), + ( + """ + UPDATE task_artifacts + SET ingestion_status = %s, + updated_at = clock_timestamp() + WHERE id = %s + RETURNING + id, + user_id, + task_id, + task_workspace_id, + status, + ingestion_status, + relative_path, + media_type_hint, + created_at, + updated_at + """, + ("ingested", task_artifact_id), + ), + ( + """ + SELECT + id, + user_id, + task_artifact_id, + sequence_no, + char_start, + char_end_exclusive, + text, + created_at, + updated_at + FROM task_artifact_chunks + WHERE task_artifact_id = %s + ORDER BY sequence_no ASC, id ASC + """, + (task_artifact_id,), + ), + ( + "SELECT pg_advisory_xact_lock(hashtextextended(%s::text, 5))", + (str(task_artifact_id),), + ), + ] From ec1320244de9d6fde4c19835ef6625aee04d48b4 Mon Sep 17 00:00:00 2001 From: Sam Date: Sat, 14 Mar 2026 21:41:33 +0100 Subject: [PATCH 005/135] Sprint 5E: add artifact chunk retrieval (#5) Co-authored-by: Sami Rusani --- apps/api/src/alicebot_api/artifacts.py | 233 ++++++++++++ apps/api/src/alicebot_api/contracts.py | 60 +++ apps/api/src/alicebot_api/main.py | 66 ++++ apps/api/src/alicebot_api/store.py | 20 + tests/integration/test_task_artifacts_api.py | 308 +++++++++++++++ tests/unit/test_artifacts.py | 378 ++++++++++++++++++- tests/unit/test_artifacts_main.py | 155 ++++++++ tests/unit/test_task_artifact_store.py | 21 ++ 8 files changed, 1240 insertions(+), 1 deletion(-) diff --git a/apps/api/src/alicebot_api/artifacts.py b/apps/api/src/alicebot_api/artifacts.py index 611dbe1..d3b794f 100644 --- a/apps/api/src/alicebot_api/artifacts.py +++ b/apps/api/src/alicebot_api/artifacts.py @@ -1,5 +1,6 @@ from __future__ import annotations +import re from pathlib import Path from typing import cast from uuid import UUID @@ -9,6 +10,14 @@ from alicebot_api.contracts import ( TASK_ARTIFACT_LIST_ORDER, TASK_ARTIFACT_CHUNK_LIST_ORDER, + TASK_ARTIFACT_CHUNK_RETRIEVAL_ORDER, + ArtifactScopedArtifactChunkRetrievalInput, + TaskArtifactChunkRetrievalItem, + TaskArtifactChunkRetrievalMatch, + TaskArtifactChunkRetrievalResponse, + TaskArtifactChunkRetrievalScope, + TaskArtifactChunkRetrievalScopeKind, + TaskArtifactChunkRetrievalSummary, TaskArtifactCreateResponse, TaskArtifactDetailResponse, TaskArtifactChunkListResponse, @@ -21,8 +30,10 @@ TaskArtifactRegisterInput, TaskArtifactStatus, TaskArtifactIngestionStatus, + TaskScopedArtifactChunkRetrievalInput, ) from alicebot_api.store import ContinuityStore, TaskArtifactChunkRow, 
TaskArtifactRow +from alicebot_api.tasks import TaskNotFoundError from alicebot_api.workspaces import TaskWorkspaceNotFoundError SUPPORTED_TEXT_ARTIFACT_MEDIA_TYPES = ("text/plain", "text/markdown") @@ -34,6 +45,10 @@ } TASK_ARTIFACT_CHUNK_MAX_CHARS = 1000 TASK_ARTIFACT_CHUNKING_RULE = "normalized_utf8_text_fixed_window_1000_chars_v1" +TASK_ARTIFACT_CHUNK_RETRIEVAL_MATCHING_RULE = ( + "casefolded_unicode_word_overlap_unique_query_terms_v1" +) +_LEXICAL_TERM_PATTERN = re.compile(r"\w+") class TaskArtifactNotFoundError(LookupError): @@ -48,6 +63,10 @@ class TaskArtifactValidationError(ValueError): """Raised when a local artifact path cannot satisfy registration constraints.""" +class TaskArtifactChunkRetrievalValidationError(ValueError): + """Raised when an artifact chunk retrieval request cannot be evaluated safely.""" + + def resolve_artifact_path(local_path: str) -> Path: return Path(local_path).expanduser().resolve() @@ -172,6 +191,150 @@ def build_task_artifact_chunk_list_summary( } +def extract_unique_lexical_terms(text: str) -> list[str]: + terms: list[str] = [] + seen: set[str] = set() + for match in _LEXICAL_TERM_PATTERN.finditer(text.casefold()): + term = match.group(0) + if term in seen: + continue + seen.add(term) + terms.append(term) + return terms + + +def resolve_artifact_chunk_retrieval_query_terms(query: str) -> list[str]: + terms = extract_unique_lexical_terms(query) + if not terms: + raise TaskArtifactChunkRetrievalValidationError( + "artifact chunk retrieval query must include at least one word" + ) + return terms + + +def build_task_artifact_chunk_retrieval_scope( + *, + kind: str, + task_id: UUID, + task_artifact_id: UUID | None = None, +) -> TaskArtifactChunkRetrievalScope: + scope: TaskArtifactChunkRetrievalScope = { + "kind": cast(TaskArtifactChunkRetrievalScopeKind, kind), + "task_id": str(task_id), + } + if task_artifact_id is not None: + scope["task_artifact_id"] = str(task_artifact_id) + return scope + + +def build_task_artifact_chunk_retrieval_summary( + *, + total_count: int, + searched_artifact_count: int, + query: str, + query_terms: list[str], + scope: TaskArtifactChunkRetrievalScope, +) -> TaskArtifactChunkRetrievalSummary: + return { + "total_count": total_count, + "searched_artifact_count": searched_artifact_count, + "query": query, + "query_terms": list(query_terms), + "matching_rule": TASK_ARTIFACT_CHUNK_RETRIEVAL_MATCHING_RULE, + "order": list(TASK_ARTIFACT_CHUNK_RETRIEVAL_ORDER), + "scope": scope, + } + + +def match_artifact_chunk_text( + *, + query_terms: list[str], + chunk_text: str, +) -> TaskArtifactChunkRetrievalMatch | None: + first_positions: dict[str, int] = {} + for match in _LEXICAL_TERM_PATTERN.finditer(chunk_text.casefold()): + term = match.group(0) + if term not in first_positions: + first_positions[term] = match.start() + + matched_terms = [term for term in query_terms if term in first_positions] + if not matched_terms: + return None + + return { + "matched_query_terms": matched_terms, + "matched_query_term_count": len(matched_terms), + "first_match_char_start": min(first_positions[term] for term in matched_terms), + } + + +def serialize_task_artifact_chunk_retrieval_item( + *, + artifact_row: TaskArtifactRow, + chunk_row: TaskArtifactChunkRow, + match: TaskArtifactChunkRetrievalMatch, +) -> TaskArtifactChunkRetrievalItem: + return { + "id": str(chunk_row["id"]), + "task_id": str(artifact_row["task_id"]), + "task_artifact_id": str(chunk_row["task_artifact_id"]), + "relative_path": artifact_row["relative_path"], + "media_type": 
infer_task_artifact_media_type(artifact_row) or "unknown", + "sequence_no": chunk_row["sequence_no"], + "char_start": chunk_row["char_start"], + "char_end_exclusive": chunk_row["char_end_exclusive"], + "text": chunk_row["text"], + "match": match, + } + + +def retrieve_matching_task_artifact_chunks( + store: ContinuityStore, + *, + artifact_rows: list[TaskArtifactRow], + query_terms: list[str], +) -> tuple[list[TaskArtifactChunkRetrievalItem], int]: + matched_items_with_keys: list[ + tuple[tuple[int, int, str, int, str], TaskArtifactChunkRetrievalItem] + ] = [] + searched_artifact_count = 0 + + for artifact_row in artifact_rows: + if artifact_row["ingestion_status"] != "ingested": + continue + + searched_artifact_count += 1 + chunk_rows = store.list_task_artifact_chunks(artifact_row["id"]) + for chunk_row in chunk_rows: + match = match_artifact_chunk_text( + query_terms=query_terms, + chunk_text=chunk_row["text"], + ) + if match is None: + continue + + item = serialize_task_artifact_chunk_retrieval_item( + artifact_row=artifact_row, + chunk_row=chunk_row, + match=match, + ) + matched_items_with_keys.append( + ( + ( + -match["matched_query_term_count"], + match["first_match_char_start"], + artifact_row["relative_path"], + chunk_row["sequence_no"], + str(chunk_row["id"]), + ), + item, + ) + ) + + matched_items_with_keys.sort(key=lambda entry: entry[0]) + return [item for _, item in matched_items_with_keys], searched_artifact_count + + def register_task_artifact_record( store: ContinuityStore, *, @@ -349,3 +512,73 @@ def list_task_artifact_chunk_records( "items": [serialize_task_artifact_chunk_row(chunk_row) for chunk_row in chunk_rows], "summary": build_task_artifact_chunk_list_summary(chunk_rows, media_type=media_type), } + + +def retrieve_task_scoped_artifact_chunk_records( + store: ContinuityStore, + *, + user_id: UUID, + request: TaskScopedArtifactChunkRetrievalInput, +) -> TaskArtifactChunkRetrievalResponse: + del user_id + + task = store.get_task_optional(request.task_id) + if task is None: + raise TaskNotFoundError(f"task {request.task_id} was not found") + + query_terms = resolve_artifact_chunk_retrieval_query_terms(request.query) + artifact_rows = store.list_task_artifacts_for_task(request.task_id) + items, searched_artifact_count = retrieve_matching_task_artifact_chunks( + store, + artifact_rows=artifact_rows, + query_terms=query_terms, + ) + scope = build_task_artifact_chunk_retrieval_scope( + kind="task", + task_id=request.task_id, + ) + return { + "items": items, + "summary": build_task_artifact_chunk_retrieval_summary( + total_count=len(items), + searched_artifact_count=searched_artifact_count, + query=request.query, + query_terms=query_terms, + scope=scope, + ), + } + + +def retrieve_artifact_scoped_artifact_chunk_records( + store: ContinuityStore, + *, + user_id: UUID, + request: ArtifactScopedArtifactChunkRetrievalInput, +) -> TaskArtifactChunkRetrievalResponse: + del user_id + + artifact_row = store.get_task_artifact_optional(request.task_artifact_id) + if artifact_row is None: + raise TaskArtifactNotFoundError(f"task artifact {request.task_artifact_id} was not found") + + query_terms = resolve_artifact_chunk_retrieval_query_terms(request.query) + items, searched_artifact_count = retrieve_matching_task_artifact_chunks( + store, + artifact_rows=[artifact_row], + query_terms=query_terms, + ) + scope = build_task_artifact_chunk_retrieval_scope( + kind="artifact", + task_id=artifact_row["task_id"], + task_artifact_id=artifact_row["id"], + ) + return { + "items": items, + 
"summary": build_task_artifact_chunk_retrieval_summary( + total_count=len(items), + searched_artifact_count=searched_artifact_count, + query=request.query, + query_terms=query_terms, + scope=scope, + ), + } diff --git a/apps/api/src/alicebot_api/contracts.py b/apps/api/src/alicebot_api/contracts.py index aa68f2e..c86549c 100644 --- a/apps/api/src/alicebot_api/contracts.py +++ b/apps/api/src/alicebot_api/contracts.py @@ -22,6 +22,7 @@ TaskWorkspaceStatus = Literal["active"] TaskArtifactStatus = Literal["registered"] TaskArtifactIngestionStatus = Literal["pending", "ingested"] +TaskArtifactChunkRetrievalScopeKind = Literal["task", "artifact"] TaskLifecycleSource = Literal[ "approval_request", "approval_resolution", @@ -133,6 +134,13 @@ TASK_WORKSPACE_LIST_ORDER = ["created_at_asc", "id_asc"] TASK_ARTIFACT_LIST_ORDER = ["created_at_asc", "id_asc"] TASK_ARTIFACT_CHUNK_LIST_ORDER = ["sequence_no_asc", "id_asc"] +TASK_ARTIFACT_CHUNK_RETRIEVAL_ORDER = [ + "matched_query_term_count_desc", + "first_match_char_start_asc", + "relative_path_asc", + "sequence_no_asc", + "id_asc", +] TASK_STEP_LIST_ORDER = ["sequence_no_asc", "created_at_asc", "id_asc"] TOOL_EXECUTION_LIST_ORDER = ["executed_at_asc", "id_asc"] EXECUTION_BUDGET_LIST_ORDER = ["created_at_asc", "id_asc"] @@ -1612,6 +1620,18 @@ class TaskArtifactIngestInput: task_artifact_id: UUID +@dataclass(frozen=True, slots=True) +class TaskScopedArtifactChunkRetrievalInput: + task_id: UUID + query: str + + +@dataclass(frozen=True, slots=True) +class ArtifactScopedArtifactChunkRetrievalInput: + task_artifact_id: UUID + query: str + + class TaskArtifactRecord(TypedDict): id: str task_id: str @@ -1671,6 +1691,46 @@ class TaskArtifactIngestionResponse(TypedDict): summary: TaskArtifactChunkListSummary +class TaskArtifactChunkRetrievalMatch(TypedDict): + matched_query_terms: list[str] + matched_query_term_count: int + first_match_char_start: int + + +class TaskArtifactChunkRetrievalItem(TypedDict): + id: str + task_id: str + task_artifact_id: str + relative_path: str + media_type: str + sequence_no: int + char_start: int + char_end_exclusive: int + text: str + match: TaskArtifactChunkRetrievalMatch + + +class TaskArtifactChunkRetrievalScope(TypedDict): + kind: TaskArtifactChunkRetrievalScopeKind + task_id: str + task_artifact_id: NotRequired[str] + + +class TaskArtifactChunkRetrievalSummary(TypedDict): + total_count: int + searched_artifact_count: int + query: str + query_terms: list[str] + matching_rule: str + order: list[str] + scope: TaskArtifactChunkRetrievalScope + + +class TaskArtifactChunkRetrievalResponse(TypedDict): + items: list[TaskArtifactChunkRetrievalItem] + summary: TaskArtifactChunkRetrievalSummary + + class TaskStepTraceLink(TypedDict): trace_id: str trace_kind: str diff --git a/apps/api/src/alicebot_api/main.py b/apps/api/src/alicebot_api/main.py index ebdf2d3..982becb 100644 --- a/apps/api/src/alicebot_api/main.py +++ b/apps/api/src/alicebot_api/main.py @@ -47,11 +47,13 @@ SemanticMemoryRetrievalRequestInput, TOOL_METADATA_VERSION_V0, ApprovalStatus, + ArtifactScopedArtifactChunkRetrievalInput, ProxyExecutionStatus, ToolAllowlistEvaluationRequestInput, ProxyExecutionRequestInput, TaskArtifactIngestInput, TaskArtifactRegisterInput, + TaskScopedArtifactChunkRetrievalInput, TaskStepKind, TaskStepLineageInput, TaskStepNextCreateInput, @@ -64,6 +66,7 @@ ) from alicebot_api.artifacts import ( TaskArtifactAlreadyExistsError, + TaskArtifactChunkRetrievalValidationError, TaskArtifactNotFoundError, TaskArtifactValidationError, 
get_task_artifact_record, @@ -71,6 +74,8 @@ list_task_artifact_chunk_records, list_task_artifact_records, register_task_artifact_record, + retrieve_artifact_scoped_artifact_chunk_records, + retrieve_task_scoped_artifact_chunk_records, ) from alicebot_api.approvals import ( ApprovalNotFoundError, @@ -420,6 +425,11 @@ class IngestTaskArtifactRequest(BaseModel): user_id: UUID +class RetrieveArtifactChunksRequest(BaseModel): + user_id: UUID + query: str = Field(min_length=1, max_length=1000) + + class TaskStepRequestSnapshot(BaseModel): thread_id: UUID tool_id: UUID @@ -1308,6 +1318,62 @@ def list_task_artifact_chunks(task_artifact_id: UUID, user_id: UUID) -> JSONResp ) +@app.post("/v0/tasks/{task_id}/artifact-chunks/retrieve") +def retrieve_task_artifact_chunks( + task_id: UUID, + request: RetrieveArtifactChunksRequest, +) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, request.user_id) as conn: + payload = retrieve_task_scoped_artifact_chunk_records( + ContinuityStore(conn), + user_id=request.user_id, + request=TaskScopedArtifactChunkRetrievalInput( + task_id=task_id, + query=request.query, + ), + ) + except TaskNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + except TaskArtifactChunkRetrievalValidationError as exc: + return JSONResponse(status_code=400, content={"detail": str(exc)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.post("/v0/task-artifacts/{task_artifact_id}/chunks/retrieve") +def retrieve_task_artifact_chunks_for_artifact( + task_artifact_id: UUID, + request: RetrieveArtifactChunksRequest, +) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, request.user_id) as conn: + payload = retrieve_artifact_scoped_artifact_chunk_records( + ContinuityStore(conn), + user_id=request.user_id, + request=ArtifactScopedArtifactChunkRetrievalInput( + task_artifact_id=task_artifact_id, + query=request.query, + ), + ) + except TaskArtifactNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + except TaskArtifactChunkRetrievalValidationError as exc: + return JSONResponse(status_code=400, content={"detail": str(exc)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + @app.post("/v0/tasks/{task_id}/steps") def create_next_task_step(task_id: UUID, request: CreateNextTaskStepRequest) -> JSONResponse: settings = get_settings() diff --git a/apps/api/src/alicebot_api/store.py b/apps/api/src/alicebot_api/store.py index 13c7cc3..d18ced9 100644 --- a/apps/api/src/alicebot_api/store.py +++ b/apps/api/src/alicebot_api/store.py @@ -1529,6 +1529,23 @@ class LabelCountRow(TypedDict): ORDER BY created_at ASC, id ASC """ +LIST_TASK_ARTIFACTS_FOR_TASK_SQL = """ + SELECT + id, + user_id, + task_id, + task_workspace_id, + status, + ingestion_status, + relative_path, + media_type_hint, + created_at, + updated_at + FROM task_artifacts + WHERE task_id = %s + ORDER BY created_at ASC, id ASC + """ + LOCK_TASK_ARTIFACT_INGESTION_SQL = "SELECT pg_advisory_xact_lock(hashtextextended(%s::text, 5))" INSERT_TASK_ARTIFACT_CHUNK_SQL = """ @@ -2731,6 +2748,9 @@ def get_task_artifact_by_workspace_relative_path_optional( def list_task_artifacts(self) -> list[TaskArtifactRow]: return self._fetch_all(LIST_TASK_ARTIFACTS_SQL) + def list_task_artifacts_for_task(self, task_id: UUID) -> list[TaskArtifactRow]: + return self._fetch_all(LIST_TASK_ARTIFACTS_FOR_TASK_SQL, 
(task_id,)) + def lock_task_artifact_ingestion(self, task_artifact_id: UUID) -> None: with self.conn.cursor() as cur: cur.execute(LOCK_TASK_ARTIFACT_INGESTION_SQL, (str(task_artifact_id),)) diff --git a/tests/integration/test_task_artifacts_api.py b/tests/integration/test_task_artifacts_api.py index cf9626b..1aab4e1 100644 --- a/tests/integration/test_task_artifacts_api.py +++ b/tests/integration/test_task_artifacts_api.py @@ -11,6 +11,7 @@ import apps.api.src.alicebot_api.main as main_module from apps.api.src.alicebot_api.config import Settings +from alicebot_api.artifacts import TASK_ARTIFACT_CHUNK_RETRIEVAL_MATCHING_RULE from alicebot_api.db import user_connection from alicebot_api.store import ContinuityStore @@ -663,3 +664,310 @@ def test_task_artifact_ingestion_enforces_rooted_workspace_paths( assert ingest_payload == { "detail": f"artifact path {outside_file.resolve()} escapes workspace root {workspace_path.resolve()}" } + + +def test_task_artifact_chunk_retrieval_endpoints_are_scoped_deterministic_and_isolated( + migrated_database_urls, + monkeypatch, + tmp_path, +) -> None: + owner = seed_task(migrated_database_urls["app"], email="owner@example.com") + intruder = seed_task(migrated_database_urls["app"], email="intruder@example.com") + workspace_root = tmp_path / "task-workspaces" + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings( + database_url=migrated_database_urls["app"], + task_workspace_root=str(workspace_root), + ), + ) + + owner_workspace_status, owner_workspace_payload = invoke_request( + "POST", + f"/v0/tasks/{owner['task_id']}/workspace", + payload={"user_id": str(owner["user_id"])}, + ) + assert owner_workspace_status == 201 + owner_workspace_path = Path(owner_workspace_payload["workspace"]["local_path"]) + + docs_file = owner_workspace_path / "docs" / "a.txt" + docs_file.parent.mkdir(parents=True) + docs_file.write_text("beta alpha doc") + notes_file = owner_workspace_path / "notes" / "b.md" + notes_file.parent.mkdir(parents=True) + notes_file.write_text("alpha beta note") + weak_file = owner_workspace_path / "notes" / "c.txt" + weak_file.write_text("beta only") + pending_file = owner_workspace_path / "notes" / "hidden.txt" + pending_file.write_text("alpha beta hidden") + + docs_register_status, docs_register_payload = invoke_request( + "POST", + f"/v0/task-workspaces/{owner_workspace_payload['workspace']['id']}/artifacts", + payload={ + "user_id": str(owner["user_id"]), + "local_path": str(docs_file), + "media_type_hint": "text/plain", + }, + ) + notes_register_status, notes_register_payload = invoke_request( + "POST", + f"/v0/task-workspaces/{owner_workspace_payload['workspace']['id']}/artifacts", + payload={ + "user_id": str(owner["user_id"]), + "local_path": str(notes_file), + "media_type_hint": "text/markdown", + }, + ) + weak_register_status, weak_register_payload = invoke_request( + "POST", + f"/v0/task-workspaces/{owner_workspace_payload['workspace']['id']}/artifacts", + payload={ + "user_id": str(owner["user_id"]), + "local_path": str(weak_file), + "media_type_hint": "text/plain", + }, + ) + pending_register_status, pending_register_payload = invoke_request( + "POST", + f"/v0/task-workspaces/{owner_workspace_payload['workspace']['id']}/artifacts", + payload={ + "user_id": str(owner["user_id"]), + "local_path": str(pending_file), + "media_type_hint": "text/plain", + }, + ) + assert docs_register_status == 201 + assert notes_register_status == 201 + assert weak_register_status == 201 + assert pending_register_status == 201 + + 
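# Ingest only three of the four registered artifacts; notes/hidden.txt deliberately stays pending so retrieval must skip it even though a chunk row is inserted for it directly below. +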
docs_ingest_status, _ = invoke_request( + "POST", + f"/v0/task-artifacts/{docs_register_payload['artifact']['id']}/ingest", + payload={"user_id": str(owner["user_id"])}, + ) + notes_ingest_status, _ = invoke_request( + "POST", + f"/v0/task-artifacts/{notes_register_payload['artifact']['id']}/ingest", + payload={"user_id": str(owner["user_id"])}, + ) + weak_ingest_status, _ = invoke_request( + "POST", + f"/v0/task-artifacts/{weak_register_payload['artifact']['id']}/ingest", + payload={"user_id": str(owner["user_id"])}, + ) + assert docs_ingest_status == 200 + assert notes_ingest_status == 200 + assert weak_ingest_status == 200 + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + store.create_task_artifact_chunk( + task_artifact_id=UUID(pending_register_payload["artifact"]["id"]), + sequence_no=1, + char_start=0, + char_end_exclusive=17, + text="alpha beta hidden", + ) + + intruder_workspace_status, intruder_workspace_payload = invoke_request( + "POST", + f"/v0/tasks/{intruder['task_id']}/workspace", + payload={"user_id": str(intruder["user_id"])}, + ) + assert intruder_workspace_status == 201 + intruder_workspace_path = Path(intruder_workspace_payload["workspace"]["local_path"]) + intruder_file = intruder_workspace_path / "docs" / "secret.txt" + intruder_file.parent.mkdir(parents=True) + intruder_file.write_text("alpha beta intruder") + + intruder_register_status, intruder_register_payload = invoke_request( + "POST", + f"/v0/task-workspaces/{intruder_workspace_payload['workspace']['id']}/artifacts", + payload={ + "user_id": str(intruder["user_id"]), + "local_path": str(intruder_file), + "media_type_hint": "text/plain", + }, + ) + assert intruder_register_status == 201 + intruder_ingest_status, _ = invoke_request( + "POST", + f"/v0/task-artifacts/{intruder_register_payload['artifact']['id']}/ingest", + payload={"user_id": str(intruder["user_id"])}, + ) + assert intruder_ingest_status == 200 + + task_retrieve_status, task_retrieve_payload = invoke_request( + "POST", + f"/v0/tasks/{owner['task_id']}/artifact-chunks/retrieve", + payload={"user_id": str(owner["user_id"]), "query": "Alpha beta"}, + ) + artifact_retrieve_status, artifact_retrieve_payload = invoke_request( + "POST", + f"/v0/task-artifacts/{notes_register_payload['artifact']['id']}/chunks/retrieve", + payload={"user_id": str(owner["user_id"]), "query": "Alpha beta"}, + ) + empty_retrieve_status, empty_retrieve_payload = invoke_request( + "POST", + f"/v0/tasks/{owner['task_id']}/artifact-chunks/retrieve", + payload={"user_id": str(owner["user_id"]), "query": "missing"}, + ) + isolated_task_retrieve_status, isolated_task_retrieve_payload = invoke_request( + "POST", + f"/v0/tasks/{owner['task_id']}/artifact-chunks/retrieve", + payload={"user_id": str(intruder["user_id"]), "query": "Alpha beta"}, + ) + isolated_artifact_retrieve_status, isolated_artifact_retrieve_payload = invoke_request( + "POST", + f"/v0/task-artifacts/{notes_register_payload['artifact']['id']}/chunks/retrieve", + payload={"user_id": str(intruder["user_id"]), "query": "Alpha beta"}, + ) + + assert task_retrieve_status == 200 + assert task_retrieve_payload == { + "items": [ + { + "id": task_retrieve_payload["items"][0]["id"], + "task_id": str(owner["task_id"]), + "task_artifact_id": docs_register_payload["artifact"]["id"], + "relative_path": "docs/a.txt", + "media_type": "text/plain", + "sequence_no": 1, + "char_start": 0, + "char_end_exclusive": 14, + "text": "beta alpha doc", + "match": { + 
"matched_query_terms": ["alpha", "beta"], + "matched_query_term_count": 2, + "first_match_char_start": 0, + }, + }, + { + "id": task_retrieve_payload["items"][1]["id"], + "task_id": str(owner["task_id"]), + "task_artifact_id": notes_register_payload["artifact"]["id"], + "relative_path": "notes/b.md", + "media_type": "text/markdown", + "sequence_no": 1, + "char_start": 0, + "char_end_exclusive": 15, + "text": "alpha beta note", + "match": { + "matched_query_terms": ["alpha", "beta"], + "matched_query_term_count": 2, + "first_match_char_start": 0, + }, + }, + { + "id": task_retrieve_payload["items"][2]["id"], + "task_id": str(owner["task_id"]), + "task_artifact_id": weak_register_payload["artifact"]["id"], + "relative_path": "notes/c.txt", + "media_type": "text/plain", + "sequence_no": 1, + "char_start": 0, + "char_end_exclusive": 9, + "text": "beta only", + "match": { + "matched_query_terms": ["beta"], + "matched_query_term_count": 1, + "first_match_char_start": 0, + }, + }, + ], + "summary": { + "total_count": 3, + "searched_artifact_count": 3, + "query": "Alpha beta", + "query_terms": ["alpha", "beta"], + "matching_rule": TASK_ARTIFACT_CHUNK_RETRIEVAL_MATCHING_RULE, + "order": [ + "matched_query_term_count_desc", + "first_match_char_start_asc", + "relative_path_asc", + "sequence_no_asc", + "id_asc", + ], + "scope": { + "kind": "task", + "task_id": str(owner["task_id"]), + }, + }, + } + + assert artifact_retrieve_status == 200 + assert artifact_retrieve_payload == { + "items": [ + { + "id": artifact_retrieve_payload["items"][0]["id"], + "task_id": str(owner["task_id"]), + "task_artifact_id": notes_register_payload["artifact"]["id"], + "relative_path": "notes/b.md", + "media_type": "text/markdown", + "sequence_no": 1, + "char_start": 0, + "char_end_exclusive": 15, + "text": "alpha beta note", + "match": { + "matched_query_terms": ["alpha", "beta"], + "matched_query_term_count": 2, + "first_match_char_start": 0, + }, + } + ], + "summary": { + "total_count": 1, + "searched_artifact_count": 1, + "query": "Alpha beta", + "query_terms": ["alpha", "beta"], + "matching_rule": TASK_ARTIFACT_CHUNK_RETRIEVAL_MATCHING_RULE, + "order": [ + "matched_query_term_count_desc", + "first_match_char_start_asc", + "relative_path_asc", + "sequence_no_asc", + "id_asc", + ], + "scope": { + "kind": "artifact", + "task_id": str(owner["task_id"]), + "task_artifact_id": notes_register_payload["artifact"]["id"], + }, + }, + } + + assert empty_retrieve_status == 200 + assert empty_retrieve_payload == { + "items": [], + "summary": { + "total_count": 0, + "searched_artifact_count": 3, + "query": "missing", + "query_terms": ["missing"], + "matching_rule": TASK_ARTIFACT_CHUNK_RETRIEVAL_MATCHING_RULE, + "order": [ + "matched_query_term_count_desc", + "first_match_char_start_asc", + "relative_path_asc", + "sequence_no_asc", + "id_asc", + ], + "scope": { + "kind": "task", + "task_id": str(owner["task_id"]), + }, + }, + } + + assert isolated_task_retrieve_status == 404 + assert isolated_task_retrieve_payload == { + "detail": f"task {owner['task_id']} was not found" + } + + assert isolated_artifact_retrieve_status == 404 + assert isolated_artifact_retrieve_payload == { + "detail": f"task artifact {notes_register_payload['artifact']['id']} was not found" + } diff --git a/tests/unit/test_artifacts.py b/tests/unit/test_artifacts.py index e6ed44c..07dc3de 100644 --- a/tests/unit/test_artifacts.py +++ b/tests/unit/test_artifacts.py @@ -8,33 +8,66 @@ from alicebot_api.artifacts import ( TASK_ARTIFACT_CHUNKING_RULE, + 
TASK_ARTIFACT_CHUNK_RETRIEVAL_MATCHING_RULE, TaskArtifactAlreadyExistsError, + TaskArtifactChunkRetrievalValidationError, TaskArtifactNotFoundError, TaskArtifactValidationError, build_workspace_relative_artifact_path, chunk_normalized_artifact_text, ensure_artifact_path_is_rooted, + extract_unique_lexical_terms, get_task_artifact_record, ingest_task_artifact_record, list_task_artifact_chunk_records, list_task_artifact_records, + match_artifact_chunk_text, normalize_artifact_text, register_task_artifact_record, + retrieve_artifact_scoped_artifact_chunk_records, + retrieve_task_scoped_artifact_chunk_records, serialize_task_artifact_row, ) -from alicebot_api.contracts import TaskArtifactIngestInput, TaskArtifactRegisterInput +from alicebot_api.contracts import ( + ArtifactScopedArtifactChunkRetrievalInput, + TaskArtifactIngestInput, + TaskArtifactRegisterInput, + TaskScopedArtifactChunkRetrievalInput, +) +from alicebot_api.tasks import TaskNotFoundError from alicebot_api.workspaces import TaskWorkspaceNotFoundError class ArtifactStoreStub: def __init__(self) -> None: self.base_time = datetime(2026, 3, 13, 10, 0, tzinfo=UTC) + self.tasks: list[dict[str, object]] = [] self.workspaces: list[dict[str, object]] = [] self.artifacts: list[dict[str, object]] = [] self.artifact_chunks: list[dict[str, object]] = [] self.locked_workspace_ids: list[UUID] = [] self.locked_artifact_ids: list[UUID] = [] + def create_task(self, *, task_id: UUID, user_id: UUID) -> dict[str, object]: + task = { + "id": task_id, + "user_id": user_id, + "thread_id": uuid4(), + "tool_id": uuid4(), + "status": "approved", + "request": {}, + "tool": {}, + "latest_approval_id": None, + "latest_execution_id": None, + "created_at": self.base_time, + "updated_at": self.base_time, + } + self.tasks.append(task) + return task + + def get_task_optional(self, task_id: UUID) -> dict[str, object] | None: + return next((task for task in self.tasks if task["id"] == task_id), None) + def create_task_workspace(self, *, task_workspace_id: UUID, task_id: UUID, user_id: UUID, local_path: str) -> dict[str, object]: workspace = { "id": task_workspace_id, @@ -98,6 +131,12 @@ def create_task_artifact( def list_task_artifacts(self) -> list[dict[str, object]]: return sorted(self.artifacts, key=lambda artifact: (artifact["created_at"], artifact["id"])) + def list_task_artifacts_for_task(self, task_id: UUID) -> list[dict[str, object]]: + return sorted( + (artifact for artifact in self.artifacts if artifact["task_id"] == task_id), + key=lambda artifact: (artifact["created_at"], artifact["id"]), + ) + def get_task_artifact_optional(self, task_artifact_id: UUID) -> dict[str, object] | None: return next((artifact for artifact in self.artifacts if artifact["id"] == task_artifact_id), None) @@ -670,6 +709,343 @@ def test_list_task_artifact_chunk_records_are_deterministic() -> None: } +def test_extract_unique_lexical_terms_preserves_first_occurrence_order() -> None: + assert extract_unique_lexical_terms("Alpha beta, alpha\nbeta gamma") == [ + "alpha", + "beta", + "gamma", + ] + + +def test_match_artifact_chunk_text_returns_explicit_metadata() -> None: + assert match_artifact_chunk_text( + query_terms=["alpha", "beta", "delta"], + chunk_text="beta alpha release", + ) == { + "matched_query_terms": ["alpha", "beta"], + "matched_query_term_count": 2, + "first_match_char_start": 0, + } + + +def test_task_scoped_chunk_retrieval_orders_matches_deterministically_and_skips_pending() -> None: + store = ArtifactStoreStub() + user_id = uuid4() + task_id = uuid4() + 
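# Fixture layout: two artifacts that match both query terms, one weak single-term match, and one pending artifact whose chunk must never be returned. +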
task_workspace_id = uuid4() + store.create_task(task_id=task_id, user_id=user_id) + store.create_task_workspace( + task_workspace_id=task_workspace_id, + task_id=task_id, + user_id=user_id, + local_path="/tmp/alicebot/task-workspaces/user/task", + ) + docs_artifact = store.create_task_artifact( + task_id=task_id, + task_workspace_id=task_workspace_id, + status="registered", + ingestion_status="ingested", + relative_path="docs/a.txt", + media_type_hint="text/plain", + ) + notes_artifact = store.create_task_artifact( + task_id=task_id, + task_workspace_id=task_workspace_id, + status="registered", + ingestion_status="ingested", + relative_path="notes/b.md", + media_type_hint="text/markdown", + ) + pending_artifact = store.create_task_artifact( + task_id=task_id, + task_workspace_id=task_workspace_id, + status="registered", + ingestion_status="pending", + relative_path="notes/hidden.txt", + media_type_hint="text/plain", + ) + weak_match_artifact = store.create_task_artifact( + task_id=task_id, + task_workspace_id=task_workspace_id, + status="registered", + ingestion_status="ingested", + relative_path="notes/c.txt", + media_type_hint="text/plain", + ) + store.create_task_artifact_chunk( + task_artifact_id=docs_artifact["id"], + sequence_no=1, + char_start=0, + char_end_exclusive=14, + text="beta alpha doc", + ) + store.create_task_artifact_chunk( + task_artifact_id=notes_artifact["id"], + sequence_no=1, + char_start=0, + char_end_exclusive=15, + text="alpha beta note", + ) + store.create_task_artifact_chunk( + task_artifact_id=pending_artifact["id"], + sequence_no=1, + char_start=0, + char_end_exclusive=17, + text="alpha beta hidden", + ) + store.create_task_artifact_chunk( + task_artifact_id=weak_match_artifact["id"], + sequence_no=1, + char_start=0, + char_end_exclusive=9, + text="beta only", + ) + + assert retrieve_task_scoped_artifact_chunk_records( + store, + user_id=user_id, + request=TaskScopedArtifactChunkRetrievalInput( + task_id=task_id, + query="Alpha beta", + ), + ) == { + "items": [ + { + "id": str(store.artifact_chunks[0]["id"]), + "task_id": str(task_id), + "task_artifact_id": str(docs_artifact["id"]), + "relative_path": "docs/a.txt", + "media_type": "text/plain", + "sequence_no": 1, + "char_start": 0, + "char_end_exclusive": 14, + "text": "beta alpha doc", + "match": { + "matched_query_terms": ["alpha", "beta"], + "matched_query_term_count": 2, + "first_match_char_start": 0, + }, + }, + { + "id": str(store.artifact_chunks[1]["id"]), + "task_id": str(task_id), + "task_artifact_id": str(notes_artifact["id"]), + "relative_path": "notes/b.md", + "media_type": "text/markdown", + "sequence_no": 1, + "char_start": 0, + "char_end_exclusive": 15, + "text": "alpha beta note", + "match": { + "matched_query_terms": ["alpha", "beta"], + "matched_query_term_count": 2, + "first_match_char_start": 0, + }, + }, + { + "id": str(store.artifact_chunks[3]["id"]), + "task_id": str(task_id), + "task_artifact_id": str(weak_match_artifact["id"]), + "relative_path": "notes/c.txt", + "media_type": "text/plain", + "sequence_no": 1, + "char_start": 0, + "char_end_exclusive": 9, + "text": "beta only", + "match": { + "matched_query_terms": ["beta"], + "matched_query_term_count": 1, + "first_match_char_start": 0, + }, + }, + ], + "summary": { + "total_count": 3, + "searched_artifact_count": 3, + "query": "Alpha beta", + "query_terms": ["alpha", "beta"], + "matching_rule": TASK_ARTIFACT_CHUNK_RETRIEVAL_MATCHING_RULE, + "order": [ + "matched_query_term_count_desc", + "first_match_char_start_asc", + 
"relative_path_asc", + "sequence_no_asc", + "id_asc", + ], + "scope": { + "kind": "task", + "task_id": str(task_id), + }, + }, + } + + +def test_artifact_scoped_chunk_retrieval_returns_empty_for_non_ingested_artifact() -> None: + store = ArtifactStoreStub() + user_id = uuid4() + task_id = uuid4() + task_workspace_id = uuid4() + store.create_task(task_id=task_id, user_id=user_id) + store.create_task_workspace( + task_workspace_id=task_workspace_id, + task_id=task_id, + user_id=user_id, + local_path="/tmp/alicebot/task-workspaces/user/task", + ) + artifact = store.create_task_artifact( + task_id=task_id, + task_workspace_id=task_workspace_id, + status="registered", + ingestion_status="pending", + relative_path="docs/spec.txt", + media_type_hint="text/plain", + ) + store.create_task_artifact_chunk( + task_artifact_id=artifact["id"], + sequence_no=1, + char_start=0, + char_end_exclusive=10, + text="alpha beta", + ) + + assert retrieve_artifact_scoped_artifact_chunk_records( + store, + user_id=user_id, + request=ArtifactScopedArtifactChunkRetrievalInput( + task_artifact_id=artifact["id"], + query="alpha", + ), + ) == { + "items": [], + "summary": { + "total_count": 0, + "searched_artifact_count": 0, + "query": "alpha", + "query_terms": ["alpha"], + "matching_rule": TASK_ARTIFACT_CHUNK_RETRIEVAL_MATCHING_RULE, + "order": [ + "matched_query_term_count_desc", + "first_match_char_start_asc", + "relative_path_asc", + "sequence_no_asc", + "id_asc", + ], + "scope": { + "kind": "artifact", + "task_id": str(task_id), + "task_artifact_id": str(artifact["id"]), + }, + }, + } + + +def test_task_scoped_chunk_retrieval_returns_empty_when_no_chunks_match() -> None: + store = ArtifactStoreStub() + user_id = uuid4() + task_id = uuid4() + task_workspace_id = uuid4() + store.create_task(task_id=task_id, user_id=user_id) + store.create_task_workspace( + task_workspace_id=task_workspace_id, + task_id=task_id, + user_id=user_id, + local_path="/tmp/alicebot/task-workspaces/user/task", + ) + artifact = store.create_task_artifact( + task_id=task_id, + task_workspace_id=task_workspace_id, + status="registered", + ingestion_status="ingested", + relative_path="docs/spec.txt", + media_type_hint="text/plain", + ) + store.create_task_artifact_chunk( + task_artifact_id=artifact["id"], + sequence_no=1, + char_start=0, + char_end_exclusive=11, + text="release plan", + ) + + response = retrieve_task_scoped_artifact_chunk_records( + store, + user_id=user_id, + request=TaskScopedArtifactChunkRetrievalInput( + task_id=task_id, + query="alpha", + ), + ) + + assert response == { + "items": [], + "summary": { + "total_count": 0, + "searched_artifact_count": 1, + "query": "alpha", + "query_terms": ["alpha"], + "matching_rule": TASK_ARTIFACT_CHUNK_RETRIEVAL_MATCHING_RULE, + "order": [ + "matched_query_term_count_desc", + "first_match_char_start_asc", + "relative_path_asc", + "sequence_no_asc", + "id_asc", + ], + "scope": { + "kind": "task", + "task_id": str(task_id), + }, + }, + } + + +def test_task_scoped_chunk_retrieval_raises_when_task_is_missing() -> None: + with pytest.raises(TaskNotFoundError, match="was not found"): + retrieve_task_scoped_artifact_chunk_records( + ArtifactStoreStub(), + user_id=uuid4(), + request=TaskScopedArtifactChunkRetrievalInput( + task_id=uuid4(), + query="alpha", + ), + ) + + +def test_artifact_chunk_retrieval_rejects_query_without_words() -> None: + store = ArtifactStoreStub() + user_id = uuid4() + task_id = uuid4() + task_workspace_id = uuid4() + store.create_task(task_id=task_id, user_id=user_id) + 
store.create_task_workspace( + task_workspace_id=task_workspace_id, + task_id=task_id, + user_id=user_id, + local_path="/tmp/alicebot/task-workspaces/user/task", + ) + artifact = store.create_task_artifact( + task_id=task_id, + task_workspace_id=task_workspace_id, + status="registered", + ingestion_status="ingested", + relative_path="docs/spec.txt", + media_type_hint="text/plain", + ) + + with pytest.raises( + TaskArtifactChunkRetrievalValidationError, + match="must include at least one word", + ): + retrieve_artifact_scoped_artifact_chunk_records( + store, + user_id=user_id, + request=ArtifactScopedArtifactChunkRetrievalInput( + task_artifact_id=artifact["id"], + query=" ... ", + ), + ) + + def test_list_and_get_task_artifact_records_are_deterministic() -> None: store = ArtifactStoreStub() user_id = uuid4() diff --git a/tests/unit/test_artifacts_main.py b/tests/unit/test_artifacts_main.py index 9e6e1b7..a009b60 100644 --- a/tests/unit/test_artifacts_main.py +++ b/tests/unit/test_artifacts_main.py @@ -8,9 +8,11 @@ from apps.api.src.alicebot_api.config import Settings from alicebot_api.artifacts import ( TaskArtifactAlreadyExistsError, + TaskArtifactChunkRetrievalValidationError, TaskArtifactNotFoundError, TaskArtifactValidationError, ) +from alicebot_api.tasks import TaskNotFoundError from alicebot_api.workspaces import TaskWorkspaceNotFoundError @@ -105,6 +107,159 @@ def fake_user_connection(*_args, **_kwargs): } +def test_retrieve_task_artifact_chunks_endpoint_returns_payload(monkeypatch) -> None: + user_id = uuid4() + task_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr( + main_module, + "retrieve_task_scoped_artifact_chunk_records", + lambda *_args, **_kwargs: { + "items": [], + "summary": { + "total_count": 0, + "searched_artifact_count": 1, + "query": "alpha", + "query_terms": ["alpha"], + "matching_rule": "casefolded_unicode_word_overlap_unique_query_terms_v1", + "order": [ + "matched_query_term_count_desc", + "first_match_char_start_asc", + "relative_path_asc", + "sequence_no_asc", + "id_asc", + ], + "scope": {"kind": "task", "task_id": str(task_id)}, + }, + }, + ) + + response = main_module.retrieve_task_artifact_chunks( + task_id, + main_module.RetrieveArtifactChunksRequest(user_id=user_id, query="alpha"), + ) + + assert response.status_code == 200 + assert json.loads(response.body) == { + "items": [], + "summary": { + "total_count": 0, + "searched_artifact_count": 1, + "query": "alpha", + "query_terms": ["alpha"], + "matching_rule": "casefolded_unicode_word_overlap_unique_query_terms_v1", + "order": [ + "matched_query_term_count_desc", + "first_match_char_start_asc", + "relative_path_asc", + "sequence_no_asc", + "id_asc", + ], + "scope": {"kind": "task", "task_id": str(task_id)}, + }, + } + + +def test_retrieve_task_artifact_chunks_endpoint_maps_task_not_found_to_404(monkeypatch) -> None: + user_id = uuid4() + task_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_retrieve_task_scoped_artifact_chunk_records(*_args, **_kwargs): + raise TaskNotFoundError(f"task {task_id} was not found") + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, 
"user_connection", fake_user_connection) + monkeypatch.setattr( + main_module, + "retrieve_task_scoped_artifact_chunk_records", + fake_retrieve_task_scoped_artifact_chunk_records, + ) + + response = main_module.retrieve_task_artifact_chunks( + task_id, + main_module.RetrieveArtifactChunksRequest(user_id=user_id, query="alpha"), + ) + + assert response.status_code == 404 + assert json.loads(response.body) == {"detail": f"task {task_id} was not found"} + + +def test_retrieve_task_artifact_chunks_endpoint_maps_validation_to_400(monkeypatch) -> None: + user_id = uuid4() + task_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_retrieve_task_scoped_artifact_chunk_records(*_args, **_kwargs): + raise TaskArtifactChunkRetrievalValidationError( + "artifact chunk retrieval query must include at least one word" + ) + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr( + main_module, + "retrieve_task_scoped_artifact_chunk_records", + fake_retrieve_task_scoped_artifact_chunk_records, + ) + + response = main_module.retrieve_task_artifact_chunks( + task_id, + main_module.RetrieveArtifactChunksRequest(user_id=user_id, query="alpha"), + ) + + assert response.status_code == 400 + assert json.loads(response.body) == { + "detail": "artifact chunk retrieval query must include at least one word" + } + + +def test_retrieve_artifact_chunk_endpoint_maps_not_found_to_404(monkeypatch) -> None: + user_id = uuid4() + task_artifact_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_retrieve_artifact_scoped_artifact_chunk_records(*_args, **_kwargs): + raise TaskArtifactNotFoundError(f"task artifact {task_artifact_id} was not found") + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr( + main_module, + "retrieve_artifact_scoped_artifact_chunk_records", + fake_retrieve_artifact_scoped_artifact_chunk_records, + ) + + response = main_module.retrieve_task_artifact_chunks_for_artifact( + task_artifact_id, + main_module.RetrieveArtifactChunksRequest(user_id=user_id, query="alpha"), + ) + + assert response.status_code == 404 + assert json.loads(response.body) == { + "detail": f"task artifact {task_artifact_id} was not found" + } + + def test_register_task_artifact_endpoint_maps_workspace_not_found_to_404(monkeypatch) -> None: user_id = uuid4() task_workspace_id = uuid4() diff --git a/tests/unit/test_task_artifact_store.py b/tests/unit/test_task_artifact_store.py index df841c0..938c680 100644 --- a/tests/unit/test_task_artifact_store.py +++ b/tests/unit/test_task_artifact_store.py @@ -112,12 +112,14 @@ def test_task_artifact_store_methods_use_expected_queries() -> None: relative_path="docs/spec.txt", ) listed = store.list_task_artifacts() + listed_for_task = store.list_task_artifacts_for_task(task_id) store.lock_task_artifacts(task_workspace_id) assert created["id"] == task_artifact_id assert fetched is not None assert duplicate is not None assert listed[0]["id"] == task_artifact_id + assert listed_for_task[0]["id"] == task_artifact_id assert cursor.executed == [ ( """ @@ -221,6 +223,25 @@ def test_task_artifact_store_methods_use_expected_queries() -> None: """, None, ), + ( + 
""" + SELECT + id, + user_id, + task_id, + task_workspace_id, + status, + ingestion_status, + relative_path, + media_type_hint, + created_at, + updated_at + FROM task_artifacts + WHERE task_id = %s + ORDER BY created_at ASC, id ASC + """, + (task_id,), + ), ( "SELECT pg_advisory_xact_lock(hashtextextended(%s::text, 4))", (str(task_workspace_id),), From a740d52d990ca85e4e525e7cfe19460b76b60d60 Mon Sep 17 00:00:00 2001 From: Sam Date: Sat, 14 Mar 2026 22:29:24 +0100 Subject: [PATCH 006/135] Sprint 5F: artifact chunk compile integration (#6) Co-authored-by: Sami Rusani --- ARCHITECTURE.md | 32 +- apps/api/src/alicebot_api/compiler.py | 228 ++++++++++++ apps/api/src/alicebot_api/contracts.py | 86 ++++- apps/api/src/alicebot_api/main.py | 63 +++- tests/integration/test_context_compile.py | 424 ++++++++++++++++++++++ tests/unit/test_compiler.py | 234 +++++++++++- tests/unit/test_main.py | 132 ++++++- 7 files changed, 1183 insertions(+), 16 deletions(-) diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md index 9899d45..b362714 100644 --- a/ARCHITECTURE.md +++ b/ARCHITECTURE.md @@ -2,16 +2,16 @@ ## Current Implemented Slice -AliceBot now implements the accepted repo slice through Sprint 5D. The shipped backend includes: +AliceBot now implements the accepted repo slice through Sprint 5F. The shipped backend includes: - foundation continuity storage over `users`, `threads`, `sessions`, and append-only `events` - deterministic tracing and context compilation over durable continuity, memory, entity, and entity-edge records - governed memory admission, explicit-preference extraction, memory review labels, review queue reads, evaluation summary reads, explicit embedding config and memory-embedding storage, direct semantic retrieval, and deterministic hybrid compile-path memory merge - deterministic prompt assembly and one no-tools response path that persists assistant replies as immutable continuity events - user-scoped consents, policies, policy evaluation, tool registry, allowlist evaluation, tool routing, approval request persistence, approval resolution, approved-only proxy execution through the in-process `proxy.echo` handler, durable execution review, and execution-budget lifecycle plus enforcement -- durable `tasks`, `task_steps`, `task_workspaces`, `task_artifacts`, and `task_artifact_chunks`, deterministic task-step sequencing, explicit task-step transitions, explicit manual continuation with lineage through `parent_step_id`, `source_approval_id`, and `source_execution_id`, explicit `tool_executions.task_step_id` linkage for execution synchronization, deterministic rooted local task-workspace provisioning, explicit rooted local artifact registration, deterministic local text-artifact ingestion into durable chunk rows, and deterministic artifact plus chunk reads +- durable `tasks`, `task_steps`, `task_workspaces`, `task_artifacts`, and `task_artifact_chunks`, deterministic task-step sequencing, explicit task-step transitions, explicit manual continuation with lineage through `parent_step_id`, `source_approval_id`, and `source_execution_id`, explicit `tool_executions.task_step_id` linkage for execution synchronization, deterministic rooted local task-workspace provisioning, explicit rooted local artifact registration, deterministic local text-artifact ingestion into durable chunk rows, deterministic lexical artifact-chunk retrieval over durable chunk rows, and optional compile-path artifact chunk inclusion as a separate context section -The current multi-step boundary is narrow and explicit. 
Manual continuation is implemented and review-passed. Approval resolution and proxy execution now both use explicit task-step linkage rather than first-step inference. Task workspaces are now implemented only as deterministic rooted local boundaries, and task artifacts are now implemented only as explicit rooted local-file registrations plus narrow deterministic text ingestion under those workspaces. Broader runner-style orchestration, automatic multi-step progression, retrieval over artifact chunks, embeddings, rich-document parsing, connectors, and new side-effect surfaces are still planned later and must not be described as live behavior. +The current multi-step boundary is narrow and explicit. Manual continuation is implemented and review-passed. Approval resolution and proxy execution now both use explicit task-step linkage rather than first-step inference. Task workspaces are now implemented only as deterministic rooted local boundaries, and task artifacts are now implemented only as explicit rooted local-file registrations, narrow deterministic text ingestion under those workspaces, lexical retrieval over persisted chunk rows, and optional compile-path inclusion of retrieved artifact chunks in a separate response section. Broader runner-style orchestration, automatic multi-step progression, artifact chunk embeddings and semantic retrieval, rich-document parsing, connectors, and new side-effect surfaces are still planned later and must not be described as live behavior. ## Implemented Now @@ -24,7 +24,7 @@ The current multi-step boundary is narrow and explicit. Manual continuation is i - memory and retrieval: `POST /v0/memories/admit`, `POST /v0/memories/extract-explicit-preferences`, `GET /v0/memories`, `GET /v0/memories/review-queue`, `GET /v0/memories/evaluation-summary`, `POST /v0/memories/semantic-retrieval`, `GET /v0/memories/{memory_id}`, `GET /v0/memories/{memory_id}/revisions`, `POST /v0/memories/{memory_id}/labels`, `GET /v0/memories/{memory_id}/labels` - embeddings and graph seams: `POST /v0/embedding-configs`, `GET /v0/embedding-configs`, `POST /v0/memory-embeddings`, `GET /v0/memories/{memory_id}/embeddings`, `GET /v0/memory-embeddings/{memory_embedding_id}`, `POST /v0/entities`, `GET /v0/entities`, `GET /v0/entities/{entity_id}`, `POST /v0/entity-edges`, `GET /v0/entities/{entity_id}/edges` - governance: `POST /v0/consents`, `GET /v0/consents`, `POST /v0/policies`, `GET /v0/policies`, `GET /v0/policies/{policy_id}`, `POST /v0/policies/evaluate`, `POST /v0/tools`, `GET /v0/tools`, `GET /v0/tools/{tool_id}`, `POST /v0/tools/allowlist/evaluate`, `POST /v0/tools/route`, `POST /v0/approvals/requests`, `GET /v0/approvals`, `GET /v0/approvals/{approval_id}`, `POST /v0/approvals/{approval_id}/approve`, `POST /v0/approvals/{approval_id}/reject`, `POST /v0/approvals/{approval_id}/execute` -- task and execution review: `GET /v0/tasks`, `GET /v0/tasks/{task_id}`, `POST /v0/tasks/{task_id}/workspace`, `GET /v0/task-workspaces`, `GET /v0/task-workspaces/{task_workspace_id}`, `POST /v0/task-workspaces/{task_workspace_id}/artifacts`, `GET /v0/task-artifacts`, `GET /v0/task-artifacts/{task_artifact_id}`, `POST /v0/task-artifacts/{task_artifact_id}/ingest`, `GET /v0/task-artifacts/{task_artifact_id}/chunks`, `GET /v0/tasks/{task_id}/steps`, `GET /v0/task-steps/{task_step_id}`, `POST /v0/tasks/{task_id}/steps`, `POST /v0/task-steps/{task_step_id}/transition`, `POST /v0/execution-budgets`, `GET /v0/execution-budgets`, `GET /v0/execution-budgets/{execution_budget_id}`, `POST 
/v0/execution-budgets/{execution_budget_id}/deactivate`, `POST /v0/execution-budgets/{execution_budget_id}/supersede`, `GET /v0/tool-executions`, `GET /v0/tool-executions/{execution_id}` +- task and execution review: `GET /v0/tasks`, `GET /v0/tasks/{task_id}`, `POST /v0/tasks/{task_id}/workspace`, `GET /v0/task-workspaces`, `GET /v0/task-workspaces/{task_workspace_id}`, `POST /v0/task-workspaces/{task_workspace_id}/artifacts`, `GET /v0/task-artifacts`, `GET /v0/task-artifacts/{task_artifact_id}`, `POST /v0/task-artifacts/{task_artifact_id}/ingest`, `GET /v0/task-artifacts/{task_artifact_id}/chunks`, `POST /v0/tasks/{task_id}/artifact-chunks/retrieve`, `POST /v0/task-artifacts/{task_artifact_id}/chunks/retrieve`, `GET /v0/tasks/{task_id}/steps`, `GET /v0/task-steps/{task_step_id}`, `POST /v0/tasks/{task_id}/steps`, `POST /v0/task-steps/{task_step_id}/transition`, `POST /v0/execution-budgets`, `GET /v0/execution-budgets`, `GET /v0/execution-budgets/{execution_budget_id}`, `POST /v0/execution-budgets/{execution_budget_id}/deactivate`, `POST /v0/execution-budgets/{execution_budget_id}/supersede`, `GET /v0/tool-executions`, `GET /v0/tool-executions/{execution_id}` - `apps/web` and `workers` remain starter shells only. ### Data Foundation @@ -57,11 +57,11 @@ The current multi-step boundary is narrow and explicit. Manual continuation is i ### Repo Boundaries In This Slice -- `apps/api`: implemented API, store, contracts, service logic, and migrations for continuity, tracing, memory, embeddings, entities, policies, tools, approvals, proxy execution, execution budgets, tasks, task steps, task workspaces, task artifacts, and narrow local artifact chunk ingestion. +- `apps/api`: implemented API, store, contracts, service logic, and migrations for continuity, tracing, memory, embeddings, entities, policies, tools, approvals, proxy execution, execution budgets, tasks, task steps, task workspaces, task artifacts, deterministic lexical artifact chunk retrieval, and narrow compile-path artifact chunk inclusion. - `apps/web`: minimal shell only; no shipped workflow UI. - `workers`: scaffold only; no background jobs or runner logic are implemented. - `infra`: local development bootstrap assets only. -- `tests`: unit and Postgres-backed integration coverage for the shipped seams above, including Sprint 4O task-step lineage/manual continuation, Sprint 4S step-linked execution synchronization, Sprint 5A task-workspace provisioning, Sprint 5C task-artifact registration, and Sprint 5D local artifact ingestion plus chunk reads. +- `tests`: unit and Postgres-backed integration coverage for the shipped seams above, including Sprint 4O task-step lineage/manual continuation, Sprint 4S step-linked execution synchronization, Sprint 5A task-workspace provisioning, Sprint 5C task-artifact registration, Sprint 5D local artifact ingestion plus chunk reads, Sprint 5E lexical artifact-chunk retrieval, and Sprint 5F compile-path artifact chunk integration. ## Core Flows Implemented Now @@ -70,8 +70,10 @@ The current multi-step boundary is narrow and explicit. Manual continuation is i 1. Accept a user-scoped `POST /v0/context/compile` request. 2. Read durable continuity records in deterministic order. 3. Merge in active memories, entities, and entity edges through the currently shipped symbolic and optional semantic retrieval paths. -4. Persist a `context.compile` trace plus explicit inclusion and exclusion events. -5. Return one deterministic `context_pack` describing scope, limits, selected context, and trace metadata. +4. 
Optionally retrieve artifact chunks through the existing lexical artifact-chunk retrieval seam, scoped to exactly one visible task or one visible artifact per request. +5. Keep retrieved artifact chunks separate from memory and entity sections, with deterministic per-section limits and ordering. +6. Persist a `context.compile` trace plus explicit inclusion and exclusion events, including artifact chunk include/exclude decisions. +7. Return one deterministic `context_pack` describing scope, limits, selected context, artifact chunk results, and trace metadata. ### Governed Memory And Retrieval @@ -190,6 +192,16 @@ The current multi-step boundary is narrow and explicit. Manual continuation is i 10. If the artifact is already ingested, return the existing artifact and chunk summary without reinserting chunks. 11. `GET /v0/task-artifacts/{task_artifact_id}/chunks` returns visible chunk rows in deterministic `sequence_no ASC, id ASC` order plus stable summary metadata. +### Artifact Chunk Retrieval + +1. Accept a user-scoped retrieval request scoped to exactly one visible task or one visible artifact. +2. Normalize the query deterministically by casefolding and extracting unique lexical `\w+` terms in first-occurrence order. +3. Read only persisted `task_artifact_chunks` rows for visible artifacts; compile and retrieval paths do not read raw files. +4. Exclude artifacts whose `ingestion_status != 'ingested'`. +5. Match chunks by lexical query-term overlap and record match metadata including matched query terms and first match offset. +6. Order matches deterministically by matched query term count desc, first match offset asc, relative path asc, sequence no asc, and id asc. +7. Return stable summary metadata describing query terms, scope, searched artifact count, and ordering. + ## Security Model Implemented Now - User-owned continuity, trace, memory, embedding, entity, governance, task, task-step, task-workspace, task-artifact, and task-artifact-chunk tables enforce row-level security. @@ -224,6 +236,8 @@ The current multi-step boundary is narrow and explicit. Manual continuation is i - deterministic line-ending normalization and fixed-window chunk boundaries - invalid UTF-8 rejection - idempotent re-ingestion of already ingested artifacts + - deterministic lexical artifact-chunk retrieval by task and by artifact + - compile-path artifact chunk inclusion, exclusion, ordering, and per-user isolation - task-artifact and task-artifact-chunk per-user isolation - trace visibility for continuation and transition events - user isolation for task and task-step reads and mutations @@ -234,7 +248,7 @@ The current multi-step boundary is narrow and explicit. 
Manual continuation is i The following areas remain planned later and must not be described as implemented: - runner-style orchestration and automatic multi-step progression beyond the current explicit manual continuation seam -- retrieval over artifact chunks, chunk ranking, and embeddings beyond the current explicit rooted local ingestion boundary +- artifact chunk ranking beyond the current lexical match ordering, plus embeddings and semantic retrieval for artifact chunks - rich document parsing beyond the current narrow UTF-8 text and markdown ingestion boundary - read-only Gmail and Calendar connectors - broader tool proxying and real-world side effects beyond the current no-I/O `proxy.echo` handler diff --git a/apps/api/src/alicebot_api/compiler.py b/apps/api/src/alicebot_api/compiler.py index 3626eed..770319c 100644 --- a/apps/api/src/alicebot_api/compiler.py +++ b/apps/api/src/alicebot_api/compiler.py @@ -5,10 +5,16 @@ from alicebot_api.contracts import ( COMPILER_VERSION_V0, + ArtifactRetrievalDecisionTracePayload, CompilerDecision, + CompileContextArtifactRetrievalInput, + CompileContextArtifactScopedArtifactRetrievalInput, CompileContextSemanticRetrievalInput, + CompileContextTaskScopedArtifactRetrievalInput, CompilerRunResult, CompiledContextPack, + ContextPackArtifactChunk, + ContextPackArtifactChunkSummary, ContextCompilerLimits, ContextPackHybridMemorySummary, ContextPackMemory, @@ -16,11 +22,20 @@ HybridMemoryDecisionTracePayload, MemorySelectionSource, SEMANTIC_MEMORY_RETRIEVAL_ORDER, + TASK_ARTIFACT_CHUNK_RETRIEVAL_ORDER, SemanticMemoryRetrievalRequestInput, TRACE_KIND_CONTEXT_COMPILE, TraceEventRecord, isoformat_or_none, ) +from alicebot_api.artifacts import ( + TASK_ARTIFACT_CHUNK_RETRIEVAL_MATCHING_RULE, + TaskArtifactNotFoundError, + build_task_artifact_chunk_retrieval_scope, + infer_task_artifact_media_type, + resolve_artifact_chunk_retrieval_query_terms, + retrieve_matching_task_artifact_chunks, +) from alicebot_api.semantic_retrieval import validate_semantic_memory_retrieval_request from alicebot_api.store import ( ContinuityStore, @@ -33,6 +48,7 @@ ThreadRow, UserRow, ) +from alicebot_api.tasks import TaskNotFoundError SUMMARY_TRACE_EVENT_KIND = "context.summary" _UNBOUNDED_SEMANTIC_RETRIEVAL_LIMIT = 2_147_483_647 @@ -54,6 +70,13 @@ class CompiledMemorySection: decisions: list[CompilerDecision] +@dataclass(frozen=True, slots=True) +class CompiledArtifactChunkSection: + items: list[ContextPackArtifactChunk] + summary: ContextPackArtifactChunkSummary + decisions: list[CompilerDecision] + + @dataclass(slots=True) class HybridMemoryCandidate: memory: MemoryRow @@ -197,6 +220,59 @@ def _empty_hybrid_memory_summary() -> ContextPackHybridMemorySummary: } +def _empty_artifact_chunk_summary() -> ContextPackArtifactChunkSummary: + return { + "requested": False, + "scope": None, + "query": None, + "query_terms": [], + "matching_rule": TASK_ARTIFACT_CHUNK_RETRIEVAL_MATCHING_RULE, + "limit": 0, + "searched_artifact_count": 0, + "candidate_count": 0, + "included_count": 0, + "excluded_uningested_artifact_count": 0, + "excluded_limit_count": 0, + "order": list(TASK_ARTIFACT_CHUNK_RETRIEVAL_ORDER), + } + + +def _artifact_retrieval_decision_metadata( + *, + scope_kind: str, + task_id: UUID, + task_artifact_id: UUID, + relative_path: str, + media_type: str | None, + ingestion_status: str, + limit: int, + match: dict[str, object] | None = None, + sequence_no: int | None = None, + char_start: int | None = None, + char_end_exclusive: int | None = None, +) -> 
ArtifactRetrievalDecisionTracePayload: + payload: ArtifactRetrievalDecisionTracePayload = { + "scope_kind": scope_kind, # type: ignore[typeddict-item] + "task_id": str(task_id), + "task_artifact_id": str(task_artifact_id), + "relative_path": relative_path, + "media_type": media_type, + "ingestion_status": ingestion_status, # type: ignore[typeddict-item] + "limit": limit, + } + if match is not None: + payload["matched_query_terms"] = list(match["matched_query_terms"]) # type: ignore[index] + payload["matched_query_term_count"] = int(match["matched_query_term_count"]) # type: ignore[index] + payload["first_match_char_start"] = int(match["first_match_char_start"]) # type: ignore[index] + if sequence_no is not None: + payload["sequence_no"] = sequence_no + if char_start is not None: + payload["char_start"] = char_start + if char_end_exclusive is not None: + payload["char_end_exclusive"] = char_end_exclusive + return payload + + def _hybrid_memory_decision_metadata( *, embedding_config_id: UUID | None, @@ -492,6 +568,125 @@ def _compile_memory_section( ) +def _compile_artifact_chunk_section( + store: ContinuityStore, + *, + artifact_retrieval: CompileContextArtifactRetrievalInput | None, +) -> CompiledArtifactChunkSection: + if artifact_retrieval is None: + return CompiledArtifactChunkSection( + items=[], + summary=_empty_artifact_chunk_summary(), + decisions=[], + ) + + if isinstance(artifact_retrieval, CompileContextTaskScopedArtifactRetrievalInput): + task = store.get_task_optional(artifact_retrieval.task_id) + if task is None: + raise TaskNotFoundError(f"task {artifact_retrieval.task_id} was not found") + artifact_rows = store.list_task_artifacts_for_task(artifact_retrieval.task_id) + scope = build_task_artifact_chunk_retrieval_scope( + kind="task", + task_id=artifact_retrieval.task_id, + ) + scope_kind = "task" + else: + artifact_row = store.get_task_artifact_optional(artifact_retrieval.task_artifact_id) + if artifact_row is None: + raise TaskArtifactNotFoundError( + f"task artifact {artifact_retrieval.task_artifact_id} was not found" + ) + artifact_rows = [artifact_row] + scope = build_task_artifact_chunk_retrieval_scope( + kind="artifact", + task_id=artifact_row["task_id"], + task_artifact_id=artifact_row["id"], + ) + scope_kind = "artifact" + + query_terms = resolve_artifact_chunk_retrieval_query_terms(artifact_retrieval.query) + matched_items, searched_artifact_count = retrieve_matching_task_artifact_chunks( + store, + artifact_rows=artifact_rows, + query_terms=query_terms, + ) + included_items = matched_items[: artifact_retrieval.limit] + excluded_uningested_artifact_count = 0 + decisions: list[CompilerDecision] = [] + + for position, artifact_row in enumerate(artifact_rows, start=1): + if artifact_row["ingestion_status"] == "ingested": + continue + excluded_uningested_artifact_count += 1 + decisions.append( + CompilerDecision( + "excluded", + "task_artifact", + artifact_row["id"], + "artifact_not_ingested", + position, + metadata=_artifact_retrieval_decision_metadata( + scope_kind=scope_kind, + task_id=artifact_row["task_id"], + task_artifact_id=artifact_row["id"], + relative_path=artifact_row["relative_path"], + media_type=infer_task_artifact_media_type(artifact_row), + ingestion_status=artifact_row["ingestion_status"], + limit=artifact_retrieval.limit, + ), + ) + ) + + for position, item in enumerate(matched_items, start=1): + decision_kind = "included" if position <= artifact_retrieval.limit else "excluded" + decision_reason = ( + "within_artifact_chunk_limit" + if position <= 
artifact_retrieval.limit + else "artifact_chunk_limit_exceeded" + ) + decisions.append( + CompilerDecision( + decision_kind, + "artifact_chunk", + UUID(item["id"]), + decision_reason, + position, + metadata=_artifact_retrieval_decision_metadata( + scope_kind=scope_kind, + task_id=UUID(item["task_id"]), + task_artifact_id=UUID(item["task_artifact_id"]), + relative_path=item["relative_path"], + media_type=item["media_type"], + ingestion_status="ingested", + limit=artifact_retrieval.limit, + match=item["match"], + sequence_no=item["sequence_no"], + char_start=item["char_start"], + char_end_exclusive=item["char_end_exclusive"], + ), + ) + ) + + return CompiledArtifactChunkSection( + items=list(included_items), + summary={ + "requested": True, + "scope": scope, + "query": artifact_retrieval.query, + "query_terms": list(query_terms), + "matching_rule": TASK_ARTIFACT_CHUNK_RETRIEVAL_MATCHING_RULE, + "limit": artifact_retrieval.limit, + "searched_artifact_count": searched_artifact_count, + "candidate_count": len(matched_items), + "included_count": len(included_items), + "excluded_uningested_artifact_count": excluded_uningested_artifact_count, + "excluded_limit_count": max(len(matched_items) - len(included_items), 0), + "order": list(TASK_ARTIFACT_CHUNK_RETRIEVAL_ORDER), + }, + decisions=decisions, + ) + + def compile_continuity_context( *, user: UserRow, @@ -503,6 +698,7 @@ def compile_continuity_context( entity_edges: list[EntityEdgeRow], limits: ContextCompilerLimits, memory_section: CompiledMemorySection | None = None, + artifact_chunk_section: CompiledArtifactChunkSection | None = None, ) -> CompilerRunResult: latest_session_sequence: dict[UUID, int] = {} for event in events: @@ -595,6 +791,12 @@ def compile_continuity_context( limits=limits, ) decisions.extend(resolved_memory_section.decisions) + resolved_artifact_chunk_section = artifact_chunk_section or CompiledArtifactChunkSection( + items=[], + summary=_empty_artifact_chunk_summary(), + decisions=[], + ) + decisions.extend(resolved_artifact_chunk_section.decisions) ordered_entities = sorted(entities, key=_entity_sort_key) included_entities = ordered_entities[-limits.max_entities :] if limits.max_entities > 0 else [] included_entity_ids = {entity["id"] for entity in included_entities} @@ -725,6 +927,24 @@ def compile_continuity_context( "included_dual_source_memory_count": resolved_memory_section.summary[ "hybrid_retrieval" ]["included_dual_source_count"], + "artifact_retrieval_requested": resolved_artifact_chunk_section.summary["requested"], + "artifact_retrieval_scope_kind": ( + None + if resolved_artifact_chunk_section.summary["scope"] is None + else resolved_artifact_chunk_section.summary["scope"]["kind"] + ), + "artifact_chunk_candidate_count": resolved_artifact_chunk_section.summary[ + "candidate_count" + ], + "included_artifact_chunk_count": resolved_artifact_chunk_section.summary[ + "included_count" + ], + "excluded_artifact_chunk_limit_count": resolved_artifact_chunk_section.summary[ + "excluded_limit_count" + ], + "excluded_uningested_artifact_count": resolved_artifact_chunk_section.summary[ + "excluded_uningested_artifact_count" + ], "included_entity_count": len(included_entities), "excluded_entity_count": excluded_entity_limit_count, "excluded_entity_limit_count": excluded_entity_limit_count, @@ -756,6 +976,8 @@ def compile_continuity_context( "events": [_serialize_event(event) for event in included_events], "memories": list(resolved_memory_section.items), "memory_summary": resolved_memory_section.summary, + "artifact_chunks": 
list(resolved_artifact_chunk_section.items), + "artifact_chunk_summary": resolved_artifact_chunk_section.summary, "entities": [_serialize_entity(entity) for entity in included_entities], "entity_summary": { "candidate_count": len(ordered_entities), @@ -781,6 +1003,7 @@ def compile_and_persist_trace( thread_id: UUID, limits: ContextCompilerLimits, semantic_retrieval: CompileContextSemanticRetrievalInput | None = None, + artifact_retrieval: CompileContextArtifactRetrievalInput | None = None, ) -> CompiledTraceRun: user = store.get_user(user_id) thread = store.get_thread(thread_id) @@ -793,6 +1016,10 @@ def compile_and_persist_trace( limits=limits, semantic_retrieval=semantic_retrieval, ) + artifact_chunk_section = _compile_artifact_chunk_section( + store, + artifact_retrieval=artifact_retrieval, + ) entities = store.list_entities() ordered_entities = sorted(entities, key=_entity_sort_key) included_entities = ordered_entities[-limits.max_entities :] if limits.max_entities > 0 else [] @@ -807,6 +1034,7 @@ def compile_and_persist_trace( entity_edges=entity_edges, limits=limits, memory_section=memory_section, + artifact_chunk_section=artifact_chunk_section, ) trace = store.create_trace( user_id=user_id, diff --git a/apps/api/src/alicebot_api/contracts.py b/apps/api/src/alicebot_api/contracts.py index c86549c..86e7934 100644 --- a/apps/api/src/alicebot_api/contracts.py +++ b/apps/api/src/alicebot_api/contracts.py @@ -2,7 +2,7 @@ from dataclasses import dataclass, field from datetime import datetime -from typing import Literal, NotRequired, TypedDict +from typing import Literal, NotRequired, TypeAlias, TypedDict from uuid import UUID from alicebot_api.store import JsonObject, JsonValue @@ -89,6 +89,8 @@ MAX_MEMORY_REVIEW_LIMIT = 100 DEFAULT_SEMANTIC_MEMORY_RETRIEVAL_LIMIT = 5 MAX_SEMANTIC_MEMORY_RETRIEVAL_LIMIT = 50 +DEFAULT_ARTIFACT_CHUNK_RETRIEVAL_LIMIT = 5 +MAX_ARTIFACT_CHUNK_RETRIEVAL_LIMIT = 50 COMPILER_VERSION_V0 = "continuity_v0" PROMPT_ASSEMBLY_VERSION_V0 = "prompt_assembly_v0" RESPONSE_GENERATION_VERSION_V0 = "response_generation_v0" @@ -201,6 +203,42 @@ def as_payload(self) -> JsonObject: } +@dataclass(frozen=True, slots=True) +class CompileContextTaskScopedArtifactRetrievalInput: + task_id: UUID + query: str + limit: int = DEFAULT_ARTIFACT_CHUNK_RETRIEVAL_LIMIT + + def as_payload(self) -> JsonObject: + return { + "kind": "task", + "task_id": str(self.task_id), + "query": self.query, + "limit": self.limit, + } + + +@dataclass(frozen=True, slots=True) +class CompileContextArtifactScopedArtifactRetrievalInput: + task_artifact_id: UUID + query: str + limit: int = DEFAULT_ARTIFACT_CHUNK_RETRIEVAL_LIMIT + + def as_payload(self) -> JsonObject: + return { + "kind": "artifact", + "task_artifact_id": str(self.task_artifact_id), + "query": self.query, + "limit": self.limit, + } + + +CompileContextArtifactRetrievalInput: TypeAlias = ( + CompileContextTaskScopedArtifactRetrievalInput + | CompileContextArtifactScopedArtifactRetrievalInput +) + + @dataclass(frozen=True, slots=True) class TraceCreate: user_id: UUID @@ -316,6 +354,50 @@ class ContextPackHybridMemorySummary(TypedDict): semantic_order: list[str] +class ContextPackArtifactChunk(TypedDict): + id: str + task_id: str + task_artifact_id: str + relative_path: str + media_type: str + sequence_no: int + char_start: int + char_end_exclusive: int + text: str + match: "TaskArtifactChunkRetrievalMatch" + + +class ContextPackArtifactChunkSummary(TypedDict): + requested: bool + scope: TaskArtifactChunkRetrievalScope | None + query: str | None + 
query_terms: list[str] + matching_rule: str + limit: int + searched_artifact_count: int + candidate_count: int + included_count: int + excluded_uningested_artifact_count: int + excluded_limit_count: int + order: list[str] + + +class ArtifactRetrievalDecisionTracePayload(TypedDict): + scope_kind: TaskArtifactChunkRetrievalScopeKind + task_id: str + task_artifact_id: str + relative_path: str + media_type: str | None + ingestion_status: TaskArtifactIngestionStatus + limit: int + matched_query_terms: NotRequired[list[str]] + matched_query_term_count: NotRequired[int] + first_match_char_start: NotRequired[int] + sequence_no: NotRequired[int] + char_start: NotRequired[int] + char_end_exclusive: NotRequired[int] + + class ContextPackMemorySummary(TypedDict): candidate_count: int included_count: int @@ -399,6 +481,8 @@ class CompiledContextPack(TypedDict): events: list[ContextPackEvent] memories: list[ContextPackMemory] memory_summary: ContextPackMemorySummary + artifact_chunks: list[ContextPackArtifactChunk] + artifact_chunk_summary: ContextPackArtifactChunkSummary entities: list[ContextPackEntity] entity_summary: ContextPackEntitySummary entity_edges: list[ContextPackEntityEdge] diff --git a/apps/api/src/alicebot_api/main.py b/apps/api/src/alicebot_api/main.py index 982becb..bab17c2 100644 --- a/apps/api/src/alicebot_api/main.py +++ b/apps/api/src/alicebot_api/main.py @@ -1,11 +1,11 @@ from __future__ import annotations from datetime import datetime -from typing import Literal, TypedDict +from typing import Annotated, Literal, TypedDict from uuid import UUID from fastapi import FastAPI, Query from fastapi.encoders import jsonable_encoder -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field from fastapi.responses import JSONResponse from urllib.parse import urlsplit, urlunsplit @@ -15,9 +15,12 @@ ApprovalApproveInput, ApprovalRejectInput, ApprovalRequestCreateInput, + CompileContextArtifactScopedArtifactRetrievalInput, + CompileContextTaskScopedArtifactRetrievalInput, ConsentStatus, ConsentUpsertInput, CompileContextSemanticRetrievalInput, + DEFAULT_ARTIFACT_CHUNK_RETRIEVAL_LIMIT, DEFAULT_MAX_EVENTS, DEFAULT_MAX_ENTITY_EDGES, DEFAULT_MAX_ENTITIES, @@ -26,6 +29,7 @@ DEFAULT_MAX_SESSIONS, DEFAULT_SEMANTIC_MEMORY_RETRIEVAL_LIMIT, MAX_MEMORY_REVIEW_LIMIT, + MAX_ARTIFACT_CHUNK_RETRIEVAL_LIMIT, MAX_SEMANTIC_MEMORY_RETRIEVAL_LIMIT, ContextCompilerLimits, EmbeddingConfigStatus, @@ -242,6 +246,39 @@ class CompileContextSemanticRequest(BaseModel): ) +class CompileContextTaskScopedArtifactRetrievalRequest(BaseModel): + model_config = ConfigDict(extra="forbid") + + kind: Literal["task"] + task_id: UUID + query: str = Field(min_length=1, max_length=4000) + limit: int = Field( + default=DEFAULT_ARTIFACT_CHUNK_RETRIEVAL_LIMIT, + ge=1, + le=MAX_ARTIFACT_CHUNK_RETRIEVAL_LIMIT, + ) + + +class CompileContextArtifactScopedArtifactRetrievalRequest(BaseModel): + model_config = ConfigDict(extra="forbid") + + kind: Literal["artifact"] + task_artifact_id: UUID + query: str = Field(min_length=1, max_length=4000) + limit: int = Field( + default=DEFAULT_ARTIFACT_CHUNK_RETRIEVAL_LIMIT, + ge=1, + le=MAX_ARTIFACT_CHUNK_RETRIEVAL_LIMIT, + ) + + +CompileContextArtifactRetrievalRequest = Annotated[ + CompileContextTaskScopedArtifactRetrievalRequest + | CompileContextArtifactScopedArtifactRetrievalRequest, + Field(discriminator="kind"), +] + + class CompileContextRequest(BaseModel): user_id: UUID thread_id: UUID @@ -251,6 +288,7 @@ class CompileContextRequest(BaseModel): max_entities: int = 
Field(default=DEFAULT_MAX_ENTITIES, ge=0, le=50) max_entity_edges: int = Field(default=DEFAULT_MAX_ENTITY_EDGES, ge=0, le=100) semantic: CompileContextSemanticRequest | None = None + artifact_retrieval: CompileContextArtifactRetrievalRequest | None = None class GenerateResponseRequest(BaseModel): @@ -547,6 +585,22 @@ def healthcheck() -> JSONResponse: @app.post("/v0/context/compile") def compile_context(request: CompileContextRequest) -> JSONResponse: settings = get_settings() + artifact_retrieval = None + if isinstance(request.artifact_retrieval, CompileContextTaskScopedArtifactRetrievalRequest): + artifact_retrieval = CompileContextTaskScopedArtifactRetrievalInput( + task_id=request.artifact_retrieval.task_id, + query=request.artifact_retrieval.query, + limit=request.artifact_retrieval.limit, + ) + elif isinstance( + request.artifact_retrieval, + CompileContextArtifactScopedArtifactRetrievalRequest, + ): + artifact_retrieval = CompileContextArtifactScopedArtifactRetrievalInput( + task_artifact_id=request.artifact_retrieval.task_artifact_id, + query=request.artifact_retrieval.query, + limit=request.artifact_retrieval.limit, + ) try: with user_connection(settings.database_url, request.user_id) as conn: @@ -570,9 +624,14 @@ def compile_context(request: CompileContextRequest) -> JSONResponse: limit=request.semantic.limit, ) ), + artifact_retrieval=artifact_retrieval, ) + except TaskArtifactChunkRetrievalValidationError as exc: + return JSONResponse(status_code=400, content={"detail": str(exc)}) except SemanticMemoryRetrievalValidationError as exc: return JSONResponse(status_code=400, content={"detail": str(exc)}) + except (TaskNotFoundError, TaskArtifactNotFoundError) as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) except ContinuityStoreInvariantError as exc: return JSONResponse(status_code=404, content={"detail": str(exc)}) diff --git a/tests/integration/test_context_compile.py b/tests/integration/test_context_compile.py index f86bfe7..4b6913b 100644 --- a/tests/integration/test_context_compile.py +++ b/tests/integration/test_context_compile.py @@ -278,6 +278,120 @@ def seed_memory_embedding_for_user( ) +def seed_compile_artifact_scope( + database_url: str, + *, + user_id: UUID, + thread_id: UUID, +) -> dict[str, object]: + with user_connection(database_url, user_id) as conn: + store = ContinuityStore(conn) + tool = store.create_tool( + tool_key="artifact.search", + name="Artifact Search", + description="Compile artifact retrieval fixture", + version="2026-03-14", + metadata_version="tool_metadata_v0", + active=True, + tags=[], + action_hints=["retrieve"], + scope_hints=["task"], + domain_hints=[], + risk_hints=[], + metadata={}, + ) + task = store.create_task( + thread_id=thread_id, + tool_id=tool["id"], + status="approved", + request={"action": "retrieve"}, + tool={"tool_key": "artifact.search"}, + latest_approval_id=None, + latest_execution_id=None, + ) + workspace = store.create_task_workspace( + task_id=task["id"], + status="active", + local_path=f"/tmp/alicebot/{task['id']}", + ) + docs_artifact = store.create_task_artifact( + task_id=task["id"], + task_workspace_id=workspace["id"], + status="registered", + ingestion_status="ingested", + relative_path="docs/a.txt", + media_type_hint="text/plain", + ) + notes_artifact = store.create_task_artifact( + task_id=task["id"], + task_workspace_id=workspace["id"], + status="registered", + ingestion_status="ingested", + relative_path="notes/b.md", + media_type_hint="text/markdown", + ) + pending_artifact = 
store.create_task_artifact( + task_id=task["id"], + task_workspace_id=workspace["id"], + status="registered", + ingestion_status="pending", + relative_path="notes/hidden.txt", + media_type_hint="text/plain", + ) + weak_artifact = store.create_task_artifact( + task_id=task["id"], + task_workspace_id=workspace["id"], + status="registered", + ingestion_status="ingested", + relative_path="notes/c.txt", + media_type_hint="text/plain", + ) + docs_chunk = store.create_task_artifact_chunk( + task_artifact_id=docs_artifact["id"], + sequence_no=1, + char_start=0, + char_end_exclusive=14, + text="beta alpha doc", + ) + notes_chunk = store.create_task_artifact_chunk( + task_artifact_id=notes_artifact["id"], + sequence_no=1, + char_start=0, + char_end_exclusive=15, + text="alpha beta note", + ) + pending_chunk = store.create_task_artifact_chunk( + task_artifact_id=pending_artifact["id"], + sequence_no=1, + char_start=0, + char_end_exclusive=17, + text="alpha beta hidden", + ) + weak_chunk = store.create_task_artifact_chunk( + task_artifact_id=weak_artifact["id"], + sequence_no=1, + char_start=0, + char_end_exclusive=9, + text="beta only", + ) + + return { + "task_id": task["id"], + "artifact_ids": { + "docs": docs_artifact["id"], + "notes": notes_artifact["id"], + "pending": pending_artifact["id"], + "weak": weak_artifact["id"], + }, + "chunk_ids": { + "docs": docs_chunk["id"], + "notes": notes_chunk["id"], + "pending": pending_chunk["id"], + "weak": weak_chunk["id"], + }, + } + + def test_compile_context_endpoint_persists_trace_and_trace_events(migrated_database_urls, monkeypatch) -> None: seeded = seed_traceable_thread(migrated_database_urls["app"]) user_id = seeded["user_id"] @@ -798,6 +912,316 @@ def test_compile_context_semantic_validation_rejects_missing_config_dimension_mi ) +def test_compile_context_artifact_retrieval_integrates_chunks_traces_and_exclusion_rules( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_traceable_thread(migrated_database_urls["app"]) + artifact_scope = seed_compile_artifact_scope( + migrated_database_urls["app"], + user_id=seeded["user_id"], + thread_id=seeded["thread_id"], + ) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + status_code, payload = invoke_compile_context( + { + "user_id": str(seeded["user_id"]), + "thread_id": str(seeded["thread_id"]), + "artifact_retrieval": { + "kind": "task", + "task_id": str(artifact_scope["task_id"]), + "query": "Alpha beta", + "limit": 2, + }, + } + ) + + assert status_code == 200 + assert payload["context_pack"]["artifact_chunks"] == [ + { + "id": str(artifact_scope["chunk_ids"]["docs"]), + "task_id": str(artifact_scope["task_id"]), + "task_artifact_id": str(artifact_scope["artifact_ids"]["docs"]), + "relative_path": "docs/a.txt", + "media_type": "text/plain", + "sequence_no": 1, + "char_start": 0, + "char_end_exclusive": 14, + "text": "beta alpha doc", + "match": { + "matched_query_terms": ["alpha", "beta"], + "matched_query_term_count": 2, + "first_match_char_start": 0, + }, + }, + { + "id": str(artifact_scope["chunk_ids"]["notes"]), + "task_id": str(artifact_scope["task_id"]), + "task_artifact_id": str(artifact_scope["artifact_ids"]["notes"]), + "relative_path": "notes/b.md", + "media_type": "text/markdown", + "sequence_no": 1, + "char_start": 0, + "char_end_exclusive": 15, + "text": "alpha beta note", + "match": { + "matched_query_terms": ["alpha", "beta"], + "matched_query_term_count": 2, + "first_match_char_start": 0, + }, + 
}, + ] + assert payload["context_pack"]["artifact_chunk_summary"] == { + "requested": True, + "scope": {"kind": "task", "task_id": str(artifact_scope["task_id"])}, + "query": "Alpha beta", + "query_terms": ["alpha", "beta"], + "matching_rule": "casefolded_unicode_word_overlap_unique_query_terms_v1", + "limit": 2, + "searched_artifact_count": 3, + "candidate_count": 3, + "included_count": 2, + "excluded_uningested_artifact_count": 1, + "excluded_limit_count": 1, + "order": [ + "matched_query_term_count_desc", + "first_match_char_start_asc", + "relative_path_asc", + "sequence_no_asc", + "id_asc", + ], + } + assert payload["context_pack"]["memories"] + assert payload["context_pack"]["entities"] + + trace_id = UUID(payload["trace_id"]) + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + trace_events = ContinuityStore(conn).list_trace_events(trace_id) + + assert any( + event["payload"]["reason"] == "within_artifact_chunk_limit" + and event["payload"]["entity_id"] == str(artifact_scope["chunk_ids"]["docs"]) + and event["payload"]["relative_path"] == "docs/a.txt" + and event["payload"]["matched_query_terms"] == ["alpha", "beta"] + for event in trace_events + if event["kind"] == "context.included" + ) + assert any( + event["payload"]["reason"] == "within_artifact_chunk_limit" + and event["payload"]["entity_id"] == str(artifact_scope["chunk_ids"]["notes"]) + and event["payload"]["relative_path"] == "notes/b.md" + for event in trace_events + if event["kind"] == "context.included" + ) + assert any( + event["payload"]["reason"] == "artifact_chunk_limit_exceeded" + and event["payload"]["entity_id"] == str(artifact_scope["chunk_ids"]["weak"]) + and event["payload"]["relative_path"] == "notes/c.txt" + for event in trace_events + if event["kind"] == "context.excluded" + ) + assert any( + event["payload"]["reason"] == "artifact_not_ingested" + and event["payload"]["entity_id"] == str(artifact_scope["artifact_ids"]["pending"]) + and event["payload"]["relative_path"] == "notes/hidden.txt" + and event["payload"]["ingestion_status"] == "pending" + for event in trace_events + if event["kind"] == "context.excluded" + ) + assert trace_events[-1]["payload"]["artifact_retrieval_requested"] is True + assert trace_events[-1]["payload"]["artifact_retrieval_scope_kind"] == "task" + assert trace_events[-1]["payload"]["artifact_chunk_candidate_count"] == 3 + assert trace_events[-1]["payload"]["included_artifact_chunk_count"] == 2 + assert trace_events[-1]["payload"]["excluded_artifact_chunk_limit_count"] == 1 + assert trace_events[-1]["payload"]["excluded_uningested_artifact_count"] == 1 + + +def test_compile_context_artifact_scoped_retrieval_returns_only_visible_artifact_chunks( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_traceable_thread(migrated_database_urls["app"]) + artifact_scope = seed_compile_artifact_scope( + migrated_database_urls["app"], + user_id=seeded["user_id"], + thread_id=seeded["thread_id"], + ) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + status_code, payload = invoke_compile_context( + { + "user_id": str(seeded["user_id"]), + "thread_id": str(seeded["thread_id"]), + "artifact_retrieval": { + "kind": "artifact", + "task_artifact_id": str(artifact_scope["artifact_ids"]["notes"]), + "query": "Alpha beta", + "limit": 2, + }, + } + ) + + assert status_code == 200 + assert payload["context_pack"]["artifact_chunks"] == [ + { + "id": 
str(artifact_scope["chunk_ids"]["notes"]), + "task_id": str(artifact_scope["task_id"]), + "task_artifact_id": str(artifact_scope["artifact_ids"]["notes"]), + "relative_path": "notes/b.md", + "media_type": "text/markdown", + "sequence_no": 1, + "char_start": 0, + "char_end_exclusive": 15, + "text": "alpha beta note", + "match": { + "matched_query_terms": ["alpha", "beta"], + "matched_query_term_count": 2, + "first_match_char_start": 0, + }, + } + ] + assert payload["context_pack"]["artifact_chunk_summary"] == { + "requested": True, + "scope": { + "kind": "artifact", + "task_id": str(artifact_scope["task_id"]), + "task_artifact_id": str(artifact_scope["artifact_ids"]["notes"]), + }, + "query": "Alpha beta", + "query_terms": ["alpha", "beta"], + "matching_rule": "casefolded_unicode_word_overlap_unique_query_terms_v1", + "limit": 2, + "searched_artifact_count": 1, + "candidate_count": 1, + "included_count": 1, + "excluded_uningested_artifact_count": 0, + "excluded_limit_count": 0, + "order": [ + "matched_query_term_count_desc", + "first_match_char_start_asc", + "relative_path_asc", + "sequence_no_asc", + "id_asc", + ], + } + + trace_id = UUID(payload["trace_id"]) + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + trace_events = ContinuityStore(conn).list_trace_events(trace_id) + + assert any( + event["payload"]["reason"] == "within_artifact_chunk_limit" + and event["payload"]["entity_id"] == str(artifact_scope["chunk_ids"]["notes"]) + and event["payload"]["scope_kind"] == "artifact" + and event["payload"]["task_artifact_id"] == str(artifact_scope["artifact_ids"]["notes"]) + for event in trace_events + if event["kind"] == "context.included" + ) + assert trace_events[-1]["payload"]["artifact_retrieval_requested"] is True + assert trace_events[-1]["payload"]["artifact_retrieval_scope_kind"] == "artifact" + assert trace_events[-1]["payload"]["artifact_chunk_candidate_count"] == 1 + assert trace_events[-1]["payload"]["included_artifact_chunk_count"] == 1 + assert trace_events[-1]["payload"]["excluded_artifact_chunk_limit_count"] == 0 + assert trace_events[-1]["payload"]["excluded_uningested_artifact_count"] == 0 + + +def test_compile_context_artifact_retrieval_validation_and_isolation( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_traceable_thread(migrated_database_urls["app"]) + intruder = seed_traceable_thread( + migrated_database_urls["app"], + email="intruder@example.com", + display_name="Intruder", + ) + owner_artifact_scope = seed_compile_artifact_scope( + migrated_database_urls["app"], + user_id=owner["user_id"], + thread_id=owner["thread_id"], + ) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + blank_query_status, blank_query_payload = invoke_compile_context( + { + "user_id": str(owner["user_id"]), + "thread_id": str(owner["thread_id"]), + "artifact_retrieval": { + "kind": "task", + "task_id": str(owner_artifact_scope["task_id"]), + "query": " ", + "limit": 2, + }, + } + ) + invalid_shape_status, invalid_shape_payload = invoke_compile_context( + { + "user_id": str(owner["user_id"]), + "thread_id": str(owner["thread_id"]), + "artifact_retrieval": { + "kind": "task", + "task_artifact_id": str(owner_artifact_scope["artifact_ids"]["docs"]), + "query": "alpha beta", + }, + } + ) + isolated_task_status, isolated_task_payload = invoke_compile_context( + { + "user_id": str(intruder["user_id"]), + "thread_id": str(intruder["thread_id"]), + "artifact_retrieval": { + 
"kind": "task", + "task_id": str(owner_artifact_scope["task_id"]), + "query": "alpha beta", + "limit": 2, + }, + } + ) + isolated_artifact_status, isolated_artifact_payload = invoke_compile_context( + { + "user_id": str(intruder["user_id"]), + "thread_id": str(intruder["thread_id"]), + "artifact_retrieval": { + "kind": "artifact", + "task_artifact_id": str(owner_artifact_scope["artifact_ids"]["docs"]), + "query": "alpha beta", + "limit": 2, + }, + } + ) + + assert blank_query_status == 400 + assert blank_query_payload == { + "detail": "artifact chunk retrieval query must include at least one word" + } + assert invalid_shape_status == 422 + assert "task_id" in json.dumps(invalid_shape_payload) + assert isolated_task_status == 404 + assert isolated_task_payload == { + "detail": f"task {owner_artifact_scope['task_id']} was not found" + } + assert isolated_artifact_status == 404 + assert isolated_artifact_payload == { + "detail": ( + "task artifact " + f"{owner_artifact_scope['artifact_ids']['docs']} was not found" + ) + } + + def test_traces_and_trace_events_respect_per_user_isolation(migrated_database_urls, monkeypatch) -> None: seeded = seed_traceable_thread(migrated_database_urls["app"]) owner_id = seeded["user_id"] diff --git a/tests/unit/test_compiler.py b/tests/unit/test_compiler.py index c221707..e0c2cae 100644 --- a/tests/unit/test_compiler.py +++ b/tests/unit/test_compiler.py @@ -5,10 +5,15 @@ from alicebot_api.compiler import ( SUMMARY_TRACE_EVENT_KIND, + _compile_artifact_chunk_section, _compile_memory_section, compile_continuity_context, ) -from alicebot_api.contracts import CompileContextSemanticRetrievalInput, ContextCompilerLimits +from alicebot_api.contracts import ( + CompileContextSemanticRetrievalInput, + CompileContextTaskScopedArtifactRetrievalInput, + ContextCompilerLimits, +) def test_compile_continuity_context_is_deterministic_and_stably_ordered() -> None: @@ -287,6 +292,27 @@ def test_compile_continuity_context_is_deterministic_and_stably_ordered() -> Non "semantic_order": ["score_desc", "created_at_asc", "id_asc"], }, } + assert first_run.context_pack["artifact_chunks"] == [] + assert first_run.context_pack["artifact_chunk_summary"] == { + "requested": False, + "scope": None, + "query": None, + "query_terms": [], + "matching_rule": "casefolded_unicode_word_overlap_unique_query_terms_v1", + "limit": 0, + "searched_artifact_count": 0, + "candidate_count": 0, + "included_count": 0, + "excluded_uningested_artifact_count": 0, + "excluded_limit_count": 0, + "order": [ + "matched_query_term_count_desc", + "first_match_char_start_asc", + "relative_path_asc", + "sequence_no_asc", + "id_asc", + ], + } assert first_run.context_pack["entity_summary"] == { "candidate_count": 3, "included_count": 2, @@ -596,6 +622,27 @@ def test_compile_continuity_context_records_included_and_excluded_reasons() -> N "semantic_order": ["score_desc", "created_at_asc", "id_asc"], }, } + assert compiler_run.context_pack["artifact_chunks"] == [] + assert compiler_run.context_pack["artifact_chunk_summary"] == { + "requested": False, + "scope": None, + "query": None, + "query_terms": [], + "matching_rule": "casefolded_unicode_word_overlap_unique_query_terms_v1", + "limit": 0, + "searched_artifact_count": 0, + "candidate_count": 0, + "included_count": 0, + "excluded_uningested_artifact_count": 0, + "excluded_limit_count": 0, + "order": [ + "matched_query_term_count_desc", + "first_match_char_start_asc", + "relative_path_asc", + "sequence_no_asc", + "id_asc", + ], + } assert 
compiler_run.context_pack["entities"] == [ { "id": str(kept_entity_id), @@ -629,6 +676,11 @@ def test_compile_continuity_context_records_included_and_excluded_reasons() -> N assert compiler_run.trace_events[-1].payload["hybrid_memory_candidate_count"] == 2 assert compiler_run.trace_events[-1].payload["hybrid_memory_merged_candidate_count"] == 1 assert compiler_run.trace_events[-1].payload["hybrid_memory_deduplicated_count"] == 0 + assert compiler_run.trace_events[-1].payload["artifact_retrieval_requested"] is False + assert compiler_run.trace_events[-1].payload["artifact_chunk_candidate_count"] == 0 + assert compiler_run.trace_events[-1].payload["included_artifact_chunk_count"] == 0 + assert compiler_run.trace_events[-1].payload["excluded_artifact_chunk_limit_count"] == 0 + assert compiler_run.trace_events[-1].payload["excluded_uningested_artifact_count"] == 0 class SemanticCompileStoreStub: @@ -690,6 +742,186 @@ def list_memory_embeddings_for_config(self, embedding_config_id): ] +class ArtifactCompileStoreStub: + def __init__(self) -> None: + self.base_time = datetime(2026, 3, 14, 12, 0, tzinfo=UTC) + self.task_id = uuid4() + self.artifact_ids = [uuid4(), uuid4(), uuid4(), uuid4()] + self.chunk_ids = [uuid4(), uuid4(), uuid4()] + + def get_task_optional(self, task_id): + if task_id != self.task_id: + return None + return {"id": self.task_id} + + def list_task_artifacts_for_task(self, task_id): + assert task_id == self.task_id + return [ + { + "id": self.artifact_ids[0], + "task_id": self.task_id, + "task_workspace_id": uuid4(), + "status": "registered", + "ingestion_status": "ingested", + "relative_path": "docs/a.txt", + "media_type_hint": "text/plain", + "created_at": self.base_time, + "updated_at": self.base_time, + }, + { + "id": self.artifact_ids[1], + "task_id": self.task_id, + "task_workspace_id": uuid4(), + "status": "registered", + "ingestion_status": "ingested", + "relative_path": "notes/b.md", + "media_type_hint": "text/markdown", + "created_at": self.base_time + timedelta(minutes=1), + "updated_at": self.base_time + timedelta(minutes=1), + }, + { + "id": self.artifact_ids[2], + "task_id": self.task_id, + "task_workspace_id": uuid4(), + "status": "registered", + "ingestion_status": "pending", + "relative_path": "notes/hidden.txt", + "media_type_hint": "text/plain", + "created_at": self.base_time + timedelta(minutes=2), + "updated_at": self.base_time + timedelta(minutes=2), + }, + { + "id": self.artifact_ids[3], + "task_id": self.task_id, + "task_workspace_id": uuid4(), + "status": "registered", + "ingestion_status": "ingested", + "relative_path": "notes/c.txt", + "media_type_hint": "text/plain", + "created_at": self.base_time + timedelta(minutes=3), + "updated_at": self.base_time + timedelta(minutes=3), + }, + ] + + def list_task_artifact_chunks(self, task_artifact_id): + if task_artifact_id == self.artifact_ids[0]: + return [ + { + "id": self.chunk_ids[0], + "task_artifact_id": task_artifact_id, + "sequence_no": 1, + "char_start": 0, + "char_end_exclusive": 14, + "text": "beta alpha doc", + "created_at": self.base_time, + "updated_at": self.base_time, + } + ] + if task_artifact_id == self.artifact_ids[1]: + return [ + { + "id": self.chunk_ids[1], + "task_artifact_id": task_artifact_id, + "sequence_no": 1, + "char_start": 0, + "char_end_exclusive": 15, + "text": "alpha beta note", + "created_at": self.base_time, + "updated_at": self.base_time, + } + ] + if task_artifact_id == self.artifact_ids[3]: + return [ + { + "id": self.chunk_ids[2], + "task_artifact_id": task_artifact_id, 
+ "sequence_no": 1, + "char_start": 0, + "char_end_exclusive": 9, + "text": "beta only", + "created_at": self.base_time, + "updated_at": self.base_time, + } + ] + return [] + + +def test_compile_artifact_chunk_section_orders_limits_and_excludes_non_ingested() -> None: + store = ArtifactCompileStoreStub() + + artifact_section = _compile_artifact_chunk_section( + store, # type: ignore[arg-type] + artifact_retrieval=CompileContextTaskScopedArtifactRetrievalInput( + task_id=store.task_id, + query="Alpha beta", + limit=2, + ), + ) + + assert artifact_section.items == [ + { + "id": str(store.chunk_ids[0]), + "task_id": str(store.task_id), + "task_artifact_id": str(store.artifact_ids[0]), + "relative_path": "docs/a.txt", + "media_type": "text/plain", + "sequence_no": 1, + "char_start": 0, + "char_end_exclusive": 14, + "text": "beta alpha doc", + "match": { + "matched_query_terms": ["alpha", "beta"], + "matched_query_term_count": 2, + "first_match_char_start": 0, + }, + }, + { + "id": str(store.chunk_ids[1]), + "task_id": str(store.task_id), + "task_artifact_id": str(store.artifact_ids[1]), + "relative_path": "notes/b.md", + "media_type": "text/markdown", + "sequence_no": 1, + "char_start": 0, + "char_end_exclusive": 15, + "text": "alpha beta note", + "match": { + "matched_query_terms": ["alpha", "beta"], + "matched_query_term_count": 2, + "first_match_char_start": 0, + }, + }, + ] + assert artifact_section.summary == { + "requested": True, + "scope": {"kind": "task", "task_id": str(store.task_id)}, + "query": "Alpha beta", + "query_terms": ["alpha", "beta"], + "matching_rule": "casefolded_unicode_word_overlap_unique_query_terms_v1", + "limit": 2, + "searched_artifact_count": 3, + "candidate_count": 3, + "included_count": 2, + "excluded_uningested_artifact_count": 1, + "excluded_limit_count": 1, + "order": [ + "matched_query_term_count_desc", + "first_match_char_start_asc", + "relative_path_asc", + "sequence_no_asc", + "id_asc", + ], + } + assert [decision.reason for decision in artifact_section.decisions] == [ + "artifact_not_ingested", + "within_artifact_chunk_limit", + "within_artifact_chunk_limit", + "artifact_chunk_limit_exceeded", + ] + assert artifact_section.decisions[0].metadata["relative_path"] == "notes/hidden.txt" + assert artifact_section.decisions[-1].metadata["relative_path"] == "notes/c.txt" + + def test_compile_memory_section_orders_limits_and_excludes_deleted() -> None: store = SemanticCompileStoreStub() deleted_memory = { diff --git a/tests/unit/test_main.py b/tests/unit/test_main.py index 446f108..7afeb2c 100644 --- a/tests/unit/test_main.py +++ b/tests/unit/test_main.py @@ -4,6 +4,7 @@ from contextlib import contextmanager from uuid import uuid4 +import pytest import apps.api.src.alicebot_api.main as main_module from apps.api.src.alicebot_api.config import Settings from alicebot_api.compiler import CompiledTraceRun @@ -180,12 +181,21 @@ def fake_user_connection(database_url: str, current_user_id): captured["current_user_id"] = current_user_id yield object() - def fake_compile_and_persist_trace(store, *, user_id, thread_id, limits, semantic_retrieval): + def fake_compile_and_persist_trace( + store, + *, + user_id, + thread_id, + limits, + semantic_retrieval, + artifact_retrieval, + ): captured["store_type"] = type(store).__name__ captured["user_id"] = user_id captured["thread_id"] = thread_id captured["limits"] = limits captured["semantic_retrieval"] = semantic_retrieval + captured["artifact_retrieval"] = artifact_retrieval return CompiledTraceRun( trace_id="trace-123", 
trace_event_count=5, @@ -248,6 +258,27 @@ def fake_compile_and_persist_trace(store, *, user_id, thread_id, limits, semanti "semantic_order": ["score_desc", "created_at_asc", "id_asc"], }, }, + "artifact_chunks": [], + "artifact_chunk_summary": { + "requested": False, + "scope": None, + "query": None, + "query_terms": [], + "matching_rule": "casefolded_unicode_word_overlap_unique_query_terms_v1", + "limit": 0, + "searched_artifact_count": 0, + "candidate_count": 0, + "included_count": 0, + "excluded_uningested_artifact_count": 0, + "excluded_limit_count": 0, + "order": [ + "matched_query_term_count_desc", + "first_match_char_start_asc", + "relative_path_asc", + "sequence_no_asc", + "id_asc", + ], + }, "entities": [ { "id": "entity-123", @@ -362,6 +393,27 @@ def fake_compile_and_persist_trace(store, *, user_id, thread_id, limits, semanti "semantic_order": ["score_desc", "created_at_asc", "id_asc"], }, }, + "artifact_chunks": [], + "artifact_chunk_summary": { + "requested": False, + "scope": None, + "query": None, + "query_terms": [], + "matching_rule": "casefolded_unicode_word_overlap_unique_query_terms_v1", + "limit": 0, + "searched_artifact_count": 0, + "candidate_count": 0, + "included_count": 0, + "excluded_uningested_artifact_count": 0, + "excluded_limit_count": 0, + "order": [ + "matched_query_term_count_desc", + "first_match_char_start_asc", + "relative_path_asc", + "sequence_no_asc", + "id_asc", + ], + }, "entities": [ { "id": "entity-123", @@ -406,6 +458,7 @@ def fake_compile_and_persist_trace(store, *, user_id, thread_id, limits, semanti assert captured["limits"].max_entities == 2 assert captured["limits"].max_entity_edges == 6 assert captured["semantic_retrieval"] is None + assert captured["artifact_retrieval"] is None def test_compile_context_returns_not_found_when_scope_row_is_missing(monkeypatch) -> None: @@ -433,7 +486,9 @@ def fake_user_connection(_database_url: str, _current_user_id): } -def test_compile_context_routes_semantic_inputs_and_validation_errors(monkeypatch) -> None: +def test_compile_context_routes_semantic_and_artifact_inputs_and_validation_errors( + monkeypatch, +) -> None: user_id = uuid4() thread_id = uuid4() config_id = uuid4() @@ -446,12 +501,21 @@ def fake_user_connection(database_url: str, current_user_id): captured["current_user_id"] = current_user_id yield object() - def fake_compile_and_persist_trace(store, *, user_id, thread_id, limits, semantic_retrieval): + def fake_compile_and_persist_trace( + store, + *, + user_id, + thread_id, + limits, + semantic_retrieval, + artifact_retrieval, + ): captured["store_type"] = type(store).__name__ captured["user_id"] = user_id captured["thread_id"] = thread_id captured["limits"] = limits captured["semantic_retrieval"] = semantic_retrieval + captured["artifact_retrieval"] = artifact_retrieval return CompiledTraceRun( trace_id="trace-semantic", trace_event_count=7, @@ -517,6 +581,44 @@ def fake_compile_and_persist_trace(store, *, user_id, thread_id, limits, semanti "semantic_order": ["score_desc", "created_at_asc", "id_asc"], }, }, + "artifact_chunks": [ + { + "id": "chunk-123", + "task_id": "task-123", + "task_artifact_id": "artifact-123", + "relative_path": "docs/spec.txt", + "media_type": "text/plain", + "sequence_no": 1, + "char_start": 0, + "char_end_exclusive": 16, + "text": "alpha beta spec", + "match": { + "matched_query_terms": ["alpha", "beta"], + "matched_query_term_count": 2, + "first_match_char_start": 0, + }, + } + ], + "artifact_chunk_summary": { + "requested": True, + "scope": {"kind": "task", 
"task_id": "task-123"}, + "query": "alpha beta", + "query_terms": ["alpha", "beta"], + "matching_rule": "casefolded_unicode_word_overlap_unique_query_terms_v1", + "limit": 2, + "searched_artifact_count": 1, + "candidate_count": 1, + "included_count": 1, + "excluded_uningested_artifact_count": 0, + "excluded_limit_count": 0, + "order": [ + "matched_query_term_count_desc", + "first_match_char_start_asc", + "relative_path_asc", + "sequence_no_asc", + "id_asc", + ], + }, "entities": [], "entity_summary": { "candidate_count": 0, @@ -546,6 +648,12 @@ def fake_compile_and_persist_trace(store, *, user_id, thread_id, limits, semanti query_vector=[0.1, 0.2, 0.3], limit=2, ), + artifact_retrieval=main_module.CompileContextTaskScopedArtifactRetrievalRequest( + kind="task", + task_id=uuid4(), + query="alpha beta", + limit=2, + ), ) ) @@ -572,6 +680,9 @@ def fake_compile_and_persist_trace(store, *, user_id, thread_id, limits, semanti assert captured["semantic_retrieval"].embedding_config_id == config_id assert captured["semantic_retrieval"].query_vector == (0.1, 0.2, 0.3) assert captured["semantic_retrieval"].limit == 2 + assert captured["artifact_retrieval"].task_id is not None + assert captured["artifact_retrieval"].query == "alpha beta" + assert captured["artifact_retrieval"].limit == 2 monkeypatch.setattr( main_module, @@ -601,6 +712,21 @@ def fake_compile_and_persist_trace(store, *, user_id, thread_id, limits, semanti } +def test_compile_context_request_rejects_invalid_artifact_scope_shape() -> None: + with pytest.raises(Exception) as exc_info: + main_module.CompileContextRequest( + user_id=uuid4(), + thread_id=uuid4(), + artifact_retrieval={ + "kind": "task", + "task_artifact_id": str(uuid4()), + "query": "alpha beta", + }, + ) + + assert "task_id" in str(exc_info.value) + + def test_generate_assistant_response_returns_assistant_and_trace_payload(monkeypatch) -> None: user_id = uuid4() thread_id = uuid4() From 50012a93c812f143eeea575499ca92a79a10c5e2 Mon Sep 17 00:00:00 2001 From: Sam Date: Sun, 15 Mar 2026 00:26:19 +0100 Subject: [PATCH 007/135] Sprint 5G: artifact chunk embedding substrate (#7) Co-authored-by: Sami Rusani --- ARCHITECTURE.md | 30 +- RULES.md | 1 + ...314_0025_task_artifact_chunk_embeddings.py | 81 +++ apps/api/src/alicebot_api/contracts.py | 58 +++ apps/api/src/alicebot_api/embedding.py | 208 +++++++- apps/api/src/alicebot_api/main.py | 109 ++++ apps/api/src/alicebot_api/store.py | 256 ++++++++++ tests/integration/test_migrations.py | 23 + ...test_task_artifact_chunk_embeddings_api.py | 474 ++++++++++++++++++ ...314_0025_task_artifact_chunk_embeddings.py | 46 ++ tests/unit/test_main.py | 227 +++++++++ .../test_task_artifact_chunk_embedding.py | 344 +++++++++++++ ...est_task_artifact_chunk_embedding_store.py | 236 +++++++++ 13 files changed, 2079 insertions(+), 14 deletions(-) create mode 100644 apps/api/alembic/versions/20260314_0025_task_artifact_chunk_embeddings.py create mode 100644 tests/integration/test_task_artifact_chunk_embeddings_api.py create mode 100644 tests/unit/test_20260314_0025_task_artifact_chunk_embeddings.py create mode 100644 tests/unit/test_task_artifact_chunk_embedding.py create mode 100644 tests/unit/test_task_artifact_chunk_embedding_store.py diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md index b362714..0e9f72d 100644 --- a/ARCHITECTURE.md +++ b/ARCHITECTURE.md @@ -2,16 +2,16 @@ ## Current Implemented Slice -AliceBot now implements the accepted repo slice through Sprint 5F. 
The shipped backend includes: +AliceBot now implements the accepted repo slice through Sprint 5G. The shipped backend includes: - foundation continuity storage over `users`, `threads`, `sessions`, and append-only `events` - deterministic tracing and context compilation over durable continuity, memory, entity, and entity-edge records - governed memory admission, explicit-preference extraction, memory review labels, review queue reads, evaluation summary reads, explicit embedding config and memory-embedding storage, direct semantic retrieval, and deterministic hybrid compile-path memory merge - deterministic prompt assembly and one no-tools response path that persists assistant replies as immutable continuity events - user-scoped consents, policies, policy evaluation, tool registry, allowlist evaluation, tool routing, approval request persistence, approval resolution, approved-only proxy execution through the in-process `proxy.echo` handler, durable execution review, and execution-budget lifecycle plus enforcement -- durable `tasks`, `task_steps`, `task_workspaces`, `task_artifacts`, and `task_artifact_chunks`, deterministic task-step sequencing, explicit task-step transitions, explicit manual continuation with lineage through `parent_step_id`, `source_approval_id`, and `source_execution_id`, explicit `tool_executions.task_step_id` linkage for execution synchronization, deterministic rooted local task-workspace provisioning, explicit rooted local artifact registration, deterministic local text-artifact ingestion into durable chunk rows, deterministic lexical artifact-chunk retrieval over durable chunk rows, and optional compile-path artifact chunk inclusion as a separate context section +- durable `tasks`, `task_steps`, `task_workspaces`, `task_artifacts`, `task_artifact_chunks`, and `task_artifact_chunk_embeddings`, deterministic task-step sequencing, explicit task-step transitions, explicit manual continuation with lineage through `parent_step_id`, `source_approval_id`, and `source_execution_id`, explicit `tool_executions.task_step_id` linkage for execution synchronization, deterministic rooted local task-workspace provisioning, explicit rooted local artifact registration, deterministic local text-artifact ingestion into durable chunk rows, deterministic lexical artifact-chunk retrieval over durable chunk rows, optional compile-path artifact chunk inclusion as a separate context section, and explicit user-scoped artifact-chunk embedding persistence tied to existing embedding configs -The current multi-step boundary is narrow and explicit. Manual continuation is implemented and review-passed. Approval resolution and proxy execution now both use explicit task-step linkage rather than first-step inference. Task workspaces are now implemented only as deterministic rooted local boundaries, and task artifacts are now implemented only as explicit rooted local-file registrations, narrow deterministic text ingestion under those workspaces, lexical retrieval over persisted chunk rows, and optional compile-path inclusion of retrieved artifact chunks in a separate response section. Broader runner-style orchestration, automatic multi-step progression, artifact chunk embeddings and semantic retrieval, rich-document parsing, connectors, and new side-effect surfaces are still planned later and must not be described as live behavior. +The current multi-step boundary is narrow and explicit. Manual continuation is implemented and review-passed. 
Approval resolution and proxy execution now both use explicit task-step linkage rather than first-step inference. Task workspaces are now implemented only as deterministic rooted local boundaries, and task artifacts are now implemented only as explicit rooted local-file registrations, narrow deterministic text ingestion under those workspaces, lexical retrieval over persisted chunk rows, optional compile-path inclusion of retrieved artifact chunks in a separate response section, and explicit artifact-chunk embedding storage tied to existing embedding configs. Broader runner-style orchestration, automatic multi-step progression, artifact-chunk semantic retrieval, rich-document parsing, connectors, and new side-effect surfaces are still planned later and must not be described as live behavior. ## Implemented Now @@ -24,7 +24,7 @@ The current multi-step boundary is narrow and explicit. Manual continuation is i - memory and retrieval: `POST /v0/memories/admit`, `POST /v0/memories/extract-explicit-preferences`, `GET /v0/memories`, `GET /v0/memories/review-queue`, `GET /v0/memories/evaluation-summary`, `POST /v0/memories/semantic-retrieval`, `GET /v0/memories/{memory_id}`, `GET /v0/memories/{memory_id}/revisions`, `POST /v0/memories/{memory_id}/labels`, `GET /v0/memories/{memory_id}/labels` - embeddings and graph seams: `POST /v0/embedding-configs`, `GET /v0/embedding-configs`, `POST /v0/memory-embeddings`, `GET /v0/memories/{memory_id}/embeddings`, `GET /v0/memory-embeddings/{memory_embedding_id}`, `POST /v0/entities`, `GET /v0/entities`, `GET /v0/entities/{entity_id}`, `POST /v0/entity-edges`, `GET /v0/entities/{entity_id}/edges` - governance: `POST /v0/consents`, `GET /v0/consents`, `POST /v0/policies`, `GET /v0/policies`, `GET /v0/policies/{policy_id}`, `POST /v0/policies/evaluate`, `POST /v0/tools`, `GET /v0/tools`, `GET /v0/tools/{tool_id}`, `POST /v0/tools/allowlist/evaluate`, `POST /v0/tools/route`, `POST /v0/approvals/requests`, `GET /v0/approvals`, `GET /v0/approvals/{approval_id}`, `POST /v0/approvals/{approval_id}/approve`, `POST /v0/approvals/{approval_id}/reject`, `POST /v0/approvals/{approval_id}/execute` -- task and execution review: `GET /v0/tasks`, `GET /v0/tasks/{task_id}`, `POST /v0/tasks/{task_id}/workspace`, `GET /v0/task-workspaces`, `GET /v0/task-workspaces/{task_workspace_id}`, `POST /v0/task-workspaces/{task_workspace_id}/artifacts`, `GET /v0/task-artifacts`, `GET /v0/task-artifacts/{task_artifact_id}`, `POST /v0/task-artifacts/{task_artifact_id}/ingest`, `GET /v0/task-artifacts/{task_artifact_id}/chunks`, `POST /v0/tasks/{task_id}/artifact-chunks/retrieve`, `POST /v0/task-artifacts/{task_artifact_id}/chunks/retrieve`, `GET /v0/tasks/{task_id}/steps`, `GET /v0/task-steps/{task_step_id}`, `POST /v0/tasks/{task_id}/steps`, `POST /v0/task-steps/{task_step_id}/transition`, `POST /v0/execution-budgets`, `GET /v0/execution-budgets`, `GET /v0/execution-budgets/{execution_budget_id}`, `POST /v0/execution-budgets/{execution_budget_id}/deactivate`, `POST /v0/execution-budgets/{execution_budget_id}/supersede`, `GET /v0/tool-executions`, `GET /v0/tool-executions/{execution_id}` +- task and execution review: `GET /v0/tasks`, `GET /v0/tasks/{task_id}`, `POST /v0/tasks/{task_id}/workspace`, `GET /v0/task-workspaces`, `GET /v0/task-workspaces/{task_workspace_id}`, `POST /v0/task-workspaces/{task_workspace_id}/artifacts`, `GET /v0/task-artifacts`, `GET /v0/task-artifacts/{task_artifact_id}`, `POST /v0/task-artifacts/{task_artifact_id}/ingest`, `GET 
/v0/task-artifacts/{task_artifact_id}/chunks`, `POST /v0/tasks/{task_id}/artifact-chunks/retrieve`, `POST /v0/task-artifacts/{task_artifact_id}/chunks/retrieve`, `POST /v0/task-artifact-chunk-embeddings`, `GET /v0/task-artifacts/{task_artifact_id}/chunk-embeddings`, `GET /v0/task-artifact-chunks/{task_artifact_chunk_id}/embeddings`, `GET /v0/task-artifact-chunk-embeddings/{task_artifact_chunk_embedding_id}`, `GET /v0/tasks/{task_id}/steps`, `GET /v0/task-steps/{task_step_id}`, `POST /v0/tasks/{task_id}/steps`, `POST /v0/task-steps/{task_step_id}/transition`, `POST /v0/execution-budgets`, `GET /v0/execution-budgets`, `GET /v0/execution-budgets/{execution_budget_id}`, `POST /v0/execution-budgets/{execution_budget_id}/deactivate`, `POST /v0/execution-budgets/{execution_budget_id}/supersede`, `GET /v0/tool-executions`, `GET /v0/tool-executions/{execution_id}` - `apps/web` and `workers` remain starter shells only. ### Data Foundation @@ -37,7 +37,7 @@ The current multi-step boundary is narrow and explicit. Manual continuation is i - memory and retrieval tables: `memories`, `memory_revisions`, `memory_review_labels`, `embedding_configs`, `memory_embeddings` - graph tables: `entities`, `entity_edges` - governance tables: `consents`, `policies`, `tools`, `approvals`, `tool_executions`, `execution_budgets` - - task lifecycle tables: `tasks`, `task_steps`, `task_workspaces`, `task_artifacts`, `task_artifact_chunks` + - task lifecycle tables: `tasks`, `task_steps`, `task_workspaces`, `task_artifacts`, `task_artifact_chunks`, `task_artifact_chunk_embeddings` - `events`, `trace_events`, and `memory_revisions` are append-only by application contract and database enforcement. - `memory_review_labels` are append-only by database enforcement. - `tasks` are explicit user-scoped lifecycle records keyed to one thread and one tool, with durable request/tool snapshots, status in `pending_approval | approved | executed | denied | blocked`, and latest approval/execution pointers for the current narrow lifecycle seam. @@ -51,17 +51,18 @@ The current multi-step boundary is narrow and explicit. Manual continuation is i - `task_workspaces` persist one active workspace record per visible task and user, store a deterministic `local_path`, and enforce that active uniqueness through a partial unique index on `(user_id, task_id)`. - `task_artifacts` persist explicit user-scoped artifact rows linked to both `tasks` and `task_workspaces`, store `status = registered`, `ingestion_status in ('pending', 'ingested')`, store only a workspace-relative `relative_path` plus optional `media_type_hint`, and enforce deterministic duplicate rejection through a unique index on `(user_id, task_workspace_id, relative_path)`. - `task_artifact_chunks` persist explicit user-scoped durable chunk rows linked to one artifact, store ordered `sequence_no`, zero-based `char_start`, exclusive `char_end_exclusive`, and chunk `text`, and enforce deterministic uniqueness through a unique index on `(user_id, task_artifact_id, sequence_no)`. +- `task_artifact_chunk_embeddings` persist explicit user-scoped durable embedding rows linked to one visible chunk and one visible embedding config, store validated `dimensions` and `vector`, and enforce deterministic uniqueness through a unique index on `(user_id, task_artifact_chunk_id, embedding_config_id)`. - `execution_budgets` enforce at most one active budget per `(user_id, tool_key, domain_hint)` selector scope through a partial unique index. 
- Per-request user context is set in the database through `app.current_user_id()`. - `TASK_WORKSPACE_ROOT` defines the only allowed base directory for workspace provisioning, and the live path rule is `resolved_root / user_id / task_id`. ### Repo Boundaries In This Slice -- `apps/api`: implemented API, store, contracts, service logic, and migrations for continuity, tracing, memory, embeddings, entities, policies, tools, approvals, proxy execution, execution budgets, tasks, task steps, task workspaces, task artifacts, deterministic lexical artifact chunk retrieval, and narrow compile-path artifact chunk inclusion. +- `apps/api`: implemented API, store, contracts, service logic, and migrations for continuity, tracing, memory, embeddings, entities, policies, tools, approvals, proxy execution, execution budgets, tasks, task steps, task workspaces, task artifacts, artifact-chunk embeddings, deterministic lexical artifact chunk retrieval, and narrow compile-path artifact chunk inclusion. - `apps/web`: minimal shell only; no shipped workflow UI. - `workers`: scaffold only; no background jobs or runner logic are implemented. - `infra`: local development bootstrap assets only. -- `tests`: unit and Postgres-backed integration coverage for the shipped seams above, including Sprint 4O task-step lineage/manual continuation, Sprint 4S step-linked execution synchronization, Sprint 5A task-workspace provisioning, Sprint 5C task-artifact registration, Sprint 5D local artifact ingestion plus chunk reads, Sprint 5E lexical artifact-chunk retrieval, and Sprint 5F compile-path artifact chunk integration. +- `tests`: unit and Postgres-backed integration coverage for the shipped seams above, including Sprint 4O task-step lineage/manual continuation, Sprint 4S step-linked execution synchronization, Sprint 5A task-workspace provisioning, Sprint 5C task-artifact registration, Sprint 5D local artifact ingestion plus chunk reads, Sprint 5E lexical artifact-chunk retrieval, Sprint 5F compile-path artifact chunk integration, and Sprint 5G artifact-chunk embedding persistence and reads. ## Core Flows Implemented Now @@ -202,12 +203,23 @@ The current multi-step boundary is narrow and explicit. Manual continuation is i 6. Order matches deterministically by matched query term count desc, first match offset asc, relative path asc, sequence no asc, and id asc. 7. Return stable summary metadata describing query terms, scope, searched artifact count, and ordering. +### Artifact Chunk Embedding Storage + +1. Accept a user-scoped `POST /v0/task-artifact-chunk-embeddings` request. +2. Require `task_artifact_chunk_id` to reference one visible persisted chunk. +3. Require `embedding_config_id` to reference one visible persisted embedding config. +4. Normalize the submitted vector as finite numeric values only. +5. Reject writes unless the vector length matches `embedding_config.dimensions`. +6. Persist or update exactly one embedding per visible `(task_artifact_chunk_id, embedding_config_id)` pair. +7. Expose deterministic reads by artifact scope, chunk scope, and embedding id. +8. Order list reads by chunk sequence first, then `created_at ASC`, then `id ASC`. + ## Security Model Implemented Now -- User-owned continuity, trace, memory, embedding, entity, governance, task, task-step, task-workspace, task-artifact, and task-artifact-chunk tables enforce row-level security. 
+- User-owned continuity, trace, memory, embedding, entity, governance, task, task-step, task-workspace, task-artifact, task-artifact-chunk, and task-artifact-chunk-embedding tables enforce row-level security. - The runtime role is limited to the narrow `SELECT` / `INSERT` / `UPDATE` permissions required by the shipped seams; there is no broad DDL or unrestricted table access at runtime. - Cross-user references are constrained through composite foreign keys on `(id, user_id)` where the schema needs ownership-linked joins. -- Approval, execution, memory, entity, task/task-step, task-workspace, task-artifact, and task-artifact-chunk reads all operate only inside the current user scope. +- Approval, execution, memory, entity, task/task-step, task-workspace, task-artifact, task-artifact-chunk, and task-artifact-chunk-embedding reads all operate only inside the current user scope. - Task-step manual continuation adds both schema-level and service-level lineage protection: - schema-level: user-scoped foreign keys and parent-not-self check - service-level: same-task, latest-step, visible-approval, visible-execution, and parent-outcome-match validation diff --git a/RULES.md b/RULES.md index 6c02e05..ad493d8 100644 --- a/RULES.md +++ b/RULES.md @@ -3,6 +3,7 @@ ## Truth And Scope - The active sprint packet is the top scope boundary for implementation work. +- Treat `.ai/active/SPRINT_PACKET.md` as an input/control artifact: do not edit it during implementation unless Control Tower explicitly changes the sprint. - Never describe planned behavior as already implemented. - Keep canonical truth files concise, current, and durable. - Archive stale planning or history material instead of deleting it when traceability still matters. diff --git a/apps/api/alembic/versions/20260314_0025_task_artifact_chunk_embeddings.py b/apps/api/alembic/versions/20260314_0025_task_artifact_chunk_embeddings.py new file mode 100644 index 0000000..23cba0e --- /dev/null +++ b/apps/api/alembic/versions/20260314_0025_task_artifact_chunk_embeddings.py @@ -0,0 +1,81 @@ +"""Add user-scoped task artifact chunk embedding records.""" + +from __future__ import annotations + +from alembic import op + + +revision = "20260314_0025" +down_revision = "20260314_0024" +branch_labels = None +depends_on = None + +_RLS_TABLES = ("task_artifact_chunk_embeddings",) + +_UPGRADE_SCHEMA_STATEMENT = """ + CREATE TABLE task_artifact_chunk_embeddings ( + id uuid PRIMARY KEY DEFAULT gen_random_uuid(), + user_id uuid NOT NULL REFERENCES users(id) ON DELETE CASCADE, + task_artifact_chunk_id uuid NOT NULL, + embedding_config_id uuid NOT NULL, + dimensions integer NOT NULL, + vector jsonb NOT NULL, + created_at timestamptz NOT NULL DEFAULT now(), + updated_at timestamptz NOT NULL DEFAULT now(), + UNIQUE (id, user_id), + UNIQUE (user_id, task_artifact_chunk_id, embedding_config_id), + CONSTRAINT task_artifact_chunk_embeddings_chunk_fkey + FOREIGN KEY (task_artifact_chunk_id, user_id) + REFERENCES task_artifact_chunks(id, user_id) ON DELETE CASCADE, + CONSTRAINT task_artifact_chunk_embeddings_embedding_config_fkey + FOREIGN KEY (embedding_config_id, user_id) + REFERENCES embedding_configs(id, user_id) ON DELETE CASCADE, + CONSTRAINT task_artifact_chunk_embeddings_dimensions_check + CHECK (dimensions > 0), + CONSTRAINT task_artifact_chunk_embeddings_vector_array_check + CHECK (jsonb_typeof(vector) = 'array'), + CONSTRAINT task_artifact_chunk_embeddings_vector_nonempty_check + CHECK (jsonb_array_length(vector) > 0), + CONSTRAINT 
task_artifact_chunk_embeddings_vector_dimensions_match_check + CHECK (jsonb_array_length(vector) = dimensions) + ); + + CREATE INDEX task_artifact_chunk_embeddings_user_chunk_created_idx + ON task_artifact_chunk_embeddings (user_id, task_artifact_chunk_id, created_at, id); + """ + +_UPGRADE_GRANT_STATEMENTS = ( + "GRANT SELECT, INSERT, UPDATE ON task_artifact_chunk_embeddings TO alicebot_app", +) + +_UPGRADE_POLICY_STATEMENT = """ + CREATE POLICY task_artifact_chunk_embeddings_is_owner ON task_artifact_chunk_embeddings + USING (user_id = app.current_user_id()) + WITH CHECK (user_id = app.current_user_id()); + """ + +_DOWNGRADE_STATEMENTS = ( + "DROP TABLE IF EXISTS task_artifact_chunk_embeddings", +) + + +def _execute_statements(statements: tuple[str, ...]) -> None: + for statement in statements: + op.execute(statement) + + +def _enable_row_level_security() -> None: + for table_name in _RLS_TABLES: + op.execute(f"ALTER TABLE {table_name} ENABLE ROW LEVEL SECURITY") + op.execute(f"ALTER TABLE {table_name} FORCE ROW LEVEL SECURITY") + + +def upgrade() -> None: + op.execute(_UPGRADE_SCHEMA_STATEMENT) + _execute_statements(_UPGRADE_GRANT_STATEMENTS) + _enable_row_level_security() + op.execute(_UPGRADE_POLICY_STATEMENT) + + +def downgrade() -> None: + _execute_statements(_DOWNGRADE_STATEMENTS) diff --git a/apps/api/src/alicebot_api/contracts.py b/apps/api/src/alicebot_api/contracts.py index 86e7934..c624578 100644 --- a/apps/api/src/alicebot_api/contracts.py +++ b/apps/api/src/alicebot_api/contracts.py @@ -23,6 +23,7 @@ TaskArtifactStatus = Literal["registered"] TaskArtifactIngestionStatus = Literal["pending", "ingested"] TaskArtifactChunkRetrievalScopeKind = Literal["task", "artifact"] +TaskArtifactChunkEmbeddingListScopeKind = Literal["artifact", "chunk"] TaskLifecycleSource = Literal[ "approval_request", "approval_resolution", @@ -136,6 +137,11 @@ TASK_WORKSPACE_LIST_ORDER = ["created_at_asc", "id_asc"] TASK_ARTIFACT_LIST_ORDER = ["created_at_asc", "id_asc"] TASK_ARTIFACT_CHUNK_LIST_ORDER = ["sequence_no_asc", "id_asc"] +TASK_ARTIFACT_CHUNK_EMBEDDING_LIST_ORDER = [ + "task_artifact_chunk_sequence_no_asc", + "created_at_asc", + "id_asc", +] TASK_ARTIFACT_CHUNK_RETRIEVAL_ORDER = [ "matched_query_term_count_desc", "first_match_char_start_asc", @@ -747,6 +753,20 @@ def as_payload(self) -> JsonObject: } +@dataclass(frozen=True, slots=True) +class TaskArtifactChunkEmbeddingUpsertInput: + task_artifact_chunk_id: UUID + embedding_config_id: UUID + vector: tuple[float, ...] 
+ + def as_payload(self) -> JsonObject: + return { + "task_artifact_chunk_id": str(self.task_artifact_chunk_id), + "embedding_config_id": str(self.embedding_config_id), + "vector": [float(value) for value in self.vector], + } + + @dataclass(frozen=True, slots=True) class SemanticMemoryRetrievalRequestInput: embedding_config_id: UUID @@ -1770,6 +1790,44 @@ class TaskArtifactChunkListResponse(TypedDict): summary: TaskArtifactChunkListSummary +class TaskArtifactChunkEmbeddingRecord(TypedDict): + id: str + task_artifact_id: str + task_artifact_chunk_id: str + task_artifact_chunk_sequence_no: int + embedding_config_id: str + dimensions: int + vector: list[float] + created_at: str + updated_at: str + + +class TaskArtifactChunkEmbeddingWriteResponse(TypedDict): + embedding: TaskArtifactChunkEmbeddingRecord + write_mode: Literal["created", "updated"] + + +class TaskArtifactChunkEmbeddingDetailResponse(TypedDict): + embedding: TaskArtifactChunkEmbeddingRecord + + +class TaskArtifactChunkEmbeddingListScope(TypedDict): + kind: TaskArtifactChunkEmbeddingListScopeKind + task_artifact_id: str + task_artifact_chunk_id: NotRequired[str] + + +class TaskArtifactChunkEmbeddingListSummary(TypedDict): + total_count: int + order: list[str] + scope: TaskArtifactChunkEmbeddingListScope + + +class TaskArtifactChunkEmbeddingListResponse(TypedDict): + items: list[TaskArtifactChunkEmbeddingRecord] + summary: TaskArtifactChunkEmbeddingListSummary + + class TaskArtifactIngestionResponse(TypedDict): artifact: TaskArtifactRecord summary: TaskArtifactChunkListSummary diff --git a/apps/api/src/alicebot_api/embedding.py b/apps/api/src/alicebot_api/embedding.py index 5248197..320d5fb 100644 --- a/apps/api/src/alicebot_api/embedding.py +++ b/apps/api/src/alicebot_api/embedding.py @@ -5,9 +5,11 @@ import psycopg +from alicebot_api.artifacts import TaskArtifactNotFoundError from alicebot_api.contracts import ( EMBEDDING_CONFIG_LIST_ORDER, MEMORY_EMBEDDING_LIST_ORDER, + TASK_ARTIFACT_CHUNK_EMBEDDING_LIST_ORDER, EmbeddingConfigCreateInput, EmbeddingConfigCreateResponse, EmbeddingConfigListResponse, @@ -19,8 +21,21 @@ MemoryEmbeddingRecord, MemoryEmbeddingUpsertInput, MemoryEmbeddingUpsertResponse, + TaskArtifactChunkEmbeddingDetailResponse, + TaskArtifactChunkEmbeddingListResponse, + TaskArtifactChunkEmbeddingListScope, + TaskArtifactChunkEmbeddingListScopeKind, + TaskArtifactChunkEmbeddingListSummary, + TaskArtifactChunkEmbeddingRecord, + TaskArtifactChunkEmbeddingUpsertInput, + TaskArtifactChunkEmbeddingWriteResponse, +) +from alicebot_api.store import ( + ContinuityStore, + EmbeddingConfigRow, + MemoryEmbeddingRow, + TaskArtifactChunkEmbeddingRow, ) -from alicebot_api.store import ContinuityStore, EmbeddingConfigRow, MemoryEmbeddingRow class EmbeddingConfigValidationError(ValueError): @@ -35,6 +50,14 @@ class MemoryEmbeddingNotFoundError(LookupError): """Raised when a requested memory embedding is not visible inside the current user scope.""" +class TaskArtifactChunkEmbeddingValidationError(ValueError): + """Raised when an artifact-chunk embedding request fails explicit validation.""" + + +class TaskArtifactChunkEmbeddingNotFoundError(LookupError): + """Raised when an artifact-chunk embedding read target is not visible inside the current user scope.""" + + def _duplicate_embedding_config_message( *, provider: str, @@ -72,20 +95,67 @@ def _serialize_memory_embedding(embedding: MemoryEmbeddingRow) -> MemoryEmbeddin } -def _validate_vector(vector: tuple[float, ...]) -> list[float]: +def _serialize_task_artifact_chunk_embedding( + 
embedding: TaskArtifactChunkEmbeddingRow, +) -> TaskArtifactChunkEmbeddingRecord: + return { + "id": str(embedding["id"]), + "task_artifact_id": str(embedding["task_artifact_id"]), + "task_artifact_chunk_id": str(embedding["task_artifact_chunk_id"]), + "task_artifact_chunk_sequence_no": embedding["task_artifact_chunk_sequence_no"], + "embedding_config_id": str(embedding["embedding_config_id"]), + "dimensions": embedding["dimensions"], + "vector": [float(value) for value in embedding["vector"]], + "created_at": embedding["created_at"].isoformat(), + "updated_at": embedding["updated_at"].isoformat(), + } + + +def _validate_vector( + vector: tuple[float, ...], + *, + error_type: type[ValueError], +) -> list[float]: if not vector: - raise MemoryEmbeddingValidationError("vector must include at least one numeric value") + raise error_type("vector must include at least one numeric value") normalized: list[float] = [] for value in vector: normalized_value = float(value) if not math.isfinite(normalized_value): - raise MemoryEmbeddingValidationError("vector must contain only finite numeric values") + raise error_type("vector must contain only finite numeric values") normalized.append(normalized_value) return normalized +def _build_task_artifact_chunk_embedding_scope( + *, + kind: TaskArtifactChunkEmbeddingListScopeKind, + task_artifact_id: UUID, + task_artifact_chunk_id: UUID | None = None, +) -> TaskArtifactChunkEmbeddingListScope: + scope: TaskArtifactChunkEmbeddingListScope = { + "kind": kind, + "task_artifact_id": str(task_artifact_id), + } + if task_artifact_chunk_id is not None: + scope["task_artifact_chunk_id"] = str(task_artifact_chunk_id) + return scope + + +def _build_task_artifact_chunk_embedding_summary( + *, + items: list[TaskArtifactChunkEmbeddingRecord], + scope: TaskArtifactChunkEmbeddingListScope, +) -> TaskArtifactChunkEmbeddingListSummary: + return { + "total_count": len(items), + "order": list(TASK_ARTIFACT_CHUNK_EMBEDDING_LIST_ORDER), + "scope": scope, + } + + def create_embedding_config_record( store: ContinuityStore, *, @@ -168,7 +238,7 @@ def upsert_memory_embedding_record( f"{request.embedding_config_id}" ) - vector = _validate_vector(request.vector) + vector = _validate_vector(request.vector, error_type=MemoryEmbeddingValidationError) if len(vector) != config["dimensions"]: raise MemoryEmbeddingValidationError( "vector length must match embedding config dimensions " @@ -240,3 +310,131 @@ def list_memory_embedding_records( "items": items, "summary": summary, } + + +def upsert_task_artifact_chunk_embedding_record( + store: ContinuityStore, + *, + user_id: UUID, + request: TaskArtifactChunkEmbeddingUpsertInput, +) -> TaskArtifactChunkEmbeddingWriteResponse: + del user_id + + chunk = store.get_task_artifact_chunk_optional(request.task_artifact_chunk_id) + if chunk is None: + raise TaskArtifactChunkEmbeddingValidationError( + "task_artifact_chunk_id must reference an existing task artifact chunk owned by the " + f"user: {request.task_artifact_chunk_id}" + ) + + config = store.get_embedding_config_optional(request.embedding_config_id) + if config is None: + raise TaskArtifactChunkEmbeddingValidationError( + "embedding_config_id must reference an existing embedding config owned by the user: " + f"{request.embedding_config_id}" + ) + + vector = _validate_vector(request.vector, error_type=TaskArtifactChunkEmbeddingValidationError) + if len(vector) != config["dimensions"]: + raise TaskArtifactChunkEmbeddingValidationError( + "vector length must match embedding config dimensions " + 
f"({config['dimensions']}): {len(vector)}" + ) + + existing = store.get_task_artifact_chunk_embedding_by_chunk_and_config_optional( + task_artifact_chunk_id=request.task_artifact_chunk_id, + embedding_config_id=request.embedding_config_id, + ) + if existing is None: + created = store.create_task_artifact_chunk_embedding( + task_artifact_chunk_id=request.task_artifact_chunk_id, + embedding_config_id=request.embedding_config_id, + dimensions=config["dimensions"], + vector=vector, + ) + return { + "embedding": _serialize_task_artifact_chunk_embedding(created), + "write_mode": "created", + } + + updated = store.update_task_artifact_chunk_embedding( + task_artifact_chunk_embedding_id=existing["id"], + dimensions=config["dimensions"], + vector=vector, + ) + return { + "embedding": _serialize_task_artifact_chunk_embedding(updated), + "write_mode": "updated", + } + + +def get_task_artifact_chunk_embedding_record( + store: ContinuityStore, + *, + user_id: UUID, + task_artifact_chunk_embedding_id: UUID, +) -> TaskArtifactChunkEmbeddingDetailResponse: + del user_id + + embedding = store.get_task_artifact_chunk_embedding_optional(task_artifact_chunk_embedding_id) + if embedding is None: + raise TaskArtifactChunkEmbeddingNotFoundError( + f"task artifact chunk embedding {task_artifact_chunk_embedding_id} was not found" + ) + + return {"embedding": _serialize_task_artifact_chunk_embedding(embedding)} + + +def list_task_artifact_chunk_embedding_records_for_artifact( + store: ContinuityStore, + *, + user_id: UUID, + task_artifact_id: UUID, +) -> TaskArtifactChunkEmbeddingListResponse: + del user_id + + artifact = store.get_task_artifact_optional(task_artifact_id) + if artifact is None: + raise TaskArtifactNotFoundError(f"task artifact {task_artifact_id} was not found") + + items = [ + _serialize_task_artifact_chunk_embedding(embedding) + for embedding in store.list_task_artifact_chunk_embeddings_for_artifact(task_artifact_id) + ] + scope = _build_task_artifact_chunk_embedding_scope( + kind="artifact", + task_artifact_id=task_artifact_id, + ) + return { + "items": items, + "summary": _build_task_artifact_chunk_embedding_summary(items=items, scope=scope), + } + + +def list_task_artifact_chunk_embedding_records_for_chunk( + store: ContinuityStore, + *, + user_id: UUID, + task_artifact_chunk_id: UUID, +) -> TaskArtifactChunkEmbeddingListResponse: + del user_id + + chunk = store.get_task_artifact_chunk_optional(task_artifact_chunk_id) + if chunk is None: + raise TaskArtifactChunkEmbeddingNotFoundError( + f"task artifact chunk {task_artifact_chunk_id} was not found" + ) + + items = [ + _serialize_task_artifact_chunk_embedding(embedding) + for embedding in store.list_task_artifact_chunk_embeddings_for_chunk(task_artifact_chunk_id) + ] + scope = _build_task_artifact_chunk_embedding_scope( + kind="chunk", + task_artifact_id=chunk["task_artifact_id"], + task_artifact_chunk_id=task_artifact_chunk_id, + ) + return { + "items": items, + "summary": _build_task_artifact_chunk_embedding_summary(items=items, scope=scope), + } diff --git a/apps/api/src/alicebot_api/main.py b/apps/api/src/alicebot_api/main.py index bab17c2..0106917 100644 --- a/apps/api/src/alicebot_api/main.py +++ b/apps/api/src/alicebot_api/main.py @@ -49,6 +49,7 @@ PolicyEffect, PolicyEvaluationRequestInput, SemanticMemoryRetrievalRequestInput, + TaskArtifactChunkEmbeddingUpsertInput, TOOL_METADATA_VERSION_V0, ApprovalStatus, ArtifactScopedArtifactChunkRetrievalInput, @@ -133,10 +134,16 @@ EmbeddingConfigValidationError, MemoryEmbeddingNotFoundError, 
MemoryEmbeddingValidationError, + TaskArtifactChunkEmbeddingNotFoundError, + TaskArtifactChunkEmbeddingValidationError, create_embedding_config_record, get_memory_embedding_record, + get_task_artifact_chunk_embedding_record, list_embedding_config_records, list_memory_embedding_records, + list_task_artifact_chunk_embedding_records_for_artifact, + list_task_artifact_chunk_embedding_records_for_chunk, + upsert_task_artifact_chunk_embedding_record, upsert_memory_embedding_record, ) from alicebot_api.entity import ( @@ -355,6 +362,13 @@ class UpsertMemoryEmbeddingRequest(BaseModel): vector: list[float] = Field(min_length=1, max_length=20000) +class UpsertTaskArtifactChunkEmbeddingRequest(BaseModel): + user_id: UUID + task_artifact_chunk_id: UUID + embedding_config_id: UUID + vector: list[float] = Field(min_length=1, max_length=20000) + + class RetrieveSemanticMemoriesRequest(BaseModel): user_id: UUID embedding_config_id: UUID @@ -1951,6 +1965,32 @@ def upsert_memory_embedding(request: UpsertMemoryEmbeddingRequest) -> JSONRespon ) +@app.post("/v0/task-artifact-chunk-embeddings") +def upsert_task_artifact_chunk_embedding( + request: UpsertTaskArtifactChunkEmbeddingRequest, +) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, request.user_id) as conn: + payload = upsert_task_artifact_chunk_embedding_record( + ContinuityStore(conn), + user_id=request.user_id, + request=TaskArtifactChunkEmbeddingUpsertInput( + task_artifact_chunk_id=request.task_artifact_chunk_id, + embedding_config_id=request.embedding_config_id, + vector=tuple(request.vector), + ), + ) + except TaskArtifactChunkEmbeddingValidationError as exc: + return JSONResponse(status_code=400, content={"detail": str(exc)}) + + return JSONResponse( + status_code=201, + content=jsonable_encoder(payload), + ) + + @app.get("/v0/memories/{memory_id}/embeddings") def list_memory_embeddings(memory_id: UUID, user_id: UUID) -> JSONResponse: settings = get_settings() @@ -1971,6 +2011,52 @@ def list_memory_embeddings(memory_id: UUID, user_id: UUID) -> JSONResponse: ) +@app.get("/v0/task-artifacts/{task_artifact_id}/chunk-embeddings") +def list_task_artifact_chunk_embeddings_for_artifact( + task_artifact_id: UUID, + user_id: UUID, +) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, user_id) as conn: + payload = list_task_artifact_chunk_embedding_records_for_artifact( + ContinuityStore(conn), + user_id=user_id, + task_artifact_id=task_artifact_id, + ) + except TaskArtifactNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/task-artifact-chunks/{task_artifact_chunk_id}/embeddings") +def list_task_artifact_chunk_embeddings( + task_artifact_chunk_id: UUID, + user_id: UUID, +) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, user_id) as conn: + payload = list_task_artifact_chunk_embedding_records_for_chunk( + ContinuityStore(conn), + user_id=user_id, + task_artifact_chunk_id=task_artifact_chunk_id, + ) + except TaskArtifactChunkEmbeddingNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + @app.get("/v0/memory-embeddings/{memory_embedding_id}") def get_memory_embedding(memory_embedding_id: UUID, user_id: UUID) -> JSONResponse: settings = 
get_settings() @@ -1991,6 +2077,29 @@ def get_memory_embedding(memory_embedding_id: UUID, user_id: UUID) -> JSONRespon ) +@app.get("/v0/task-artifact-chunk-embeddings/{task_artifact_chunk_embedding_id}") +def get_task_artifact_chunk_embedding( + task_artifact_chunk_embedding_id: UUID, + user_id: UUID, +) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, user_id) as conn: + payload = get_task_artifact_chunk_embedding_record( + ContinuityStore(conn), + user_id=user_id, + task_artifact_chunk_embedding_id=task_artifact_chunk_embedding_id, + ) + except TaskArtifactChunkEmbeddingNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + @app.post("/v0/entities") def create_entity(request: CreateEntityRequest) -> JSONResponse: settings = get_settings() diff --git a/apps/api/src/alicebot_api/store.py b/apps/api/src/alicebot_api/store.py index d18ced9..206d168 100644 --- a/apps/api/src/alicebot_api/store.py +++ b/apps/api/src/alicebot_api/store.py @@ -270,6 +270,19 @@ class TaskArtifactChunkRow(TypedDict): updated_at: datetime +class TaskArtifactChunkEmbeddingRow(TypedDict): + id: UUID + user_id: UUID + task_artifact_id: UUID + task_artifact_chunk_id: UUID + task_artifact_chunk_sequence_no: int + embedding_config_id: UUID + dimensions: int + vector: list[float] + created_at: datetime + updated_at: datetime + + class TaskStepRow(TypedDict): id: UUID user_id: UUID @@ -1597,6 +1610,181 @@ class LabelCountRow(TypedDict): ORDER BY sequence_no ASC, id ASC """ +GET_TASK_ARTIFACT_CHUNK_SQL = """ + SELECT + id, + user_id, + task_artifact_id, + sequence_no, + char_start, + char_end_exclusive, + text, + created_at, + updated_at + FROM task_artifact_chunks + WHERE id = %s + """ + +INSERT_TASK_ARTIFACT_CHUNK_EMBEDDING_SQL = """ + WITH inserted AS ( + INSERT INTO task_artifact_chunk_embeddings ( + user_id, + task_artifact_chunk_id, + embedding_config_id, + dimensions, + vector, + created_at, + updated_at + ) + VALUES ( + app.current_user_id(), + %s, + %s, + %s, + %s, + clock_timestamp(), + clock_timestamp() + ) + RETURNING + id, + user_id, + task_artifact_chunk_id, + embedding_config_id, + dimensions, + vector, + created_at, + updated_at + ) + SELECT + inserted.id, + inserted.user_id, + chunks.task_artifact_id, + inserted.task_artifact_chunk_id, + chunks.sequence_no AS task_artifact_chunk_sequence_no, + inserted.embedding_config_id, + inserted.dimensions, + inserted.vector, + inserted.created_at, + inserted.updated_at + FROM inserted + JOIN task_artifact_chunks AS chunks + ON chunks.id = inserted.task_artifact_chunk_id + AND chunks.user_id = inserted.user_id + """ + +GET_TASK_ARTIFACT_CHUNK_EMBEDDING_SQL = """ + SELECT + embeddings.id, + embeddings.user_id, + chunks.task_artifact_id, + embeddings.task_artifact_chunk_id, + chunks.sequence_no AS task_artifact_chunk_sequence_no, + embeddings.embedding_config_id, + embeddings.dimensions, + embeddings.vector, + embeddings.created_at, + embeddings.updated_at + FROM task_artifact_chunk_embeddings AS embeddings + JOIN task_artifact_chunks AS chunks + ON chunks.id = embeddings.task_artifact_chunk_id + AND chunks.user_id = embeddings.user_id + WHERE embeddings.id = %s + """ + +GET_TASK_ARTIFACT_CHUNK_EMBEDDING_BY_CHUNK_AND_CONFIG_SQL = """ + SELECT + embeddings.id, + embeddings.user_id, + chunks.task_artifact_id, + embeddings.task_artifact_chunk_id, + chunks.sequence_no AS 
task_artifact_chunk_sequence_no, + embeddings.embedding_config_id, + embeddings.dimensions, + embeddings.vector, + embeddings.created_at, + embeddings.updated_at + FROM task_artifact_chunk_embeddings AS embeddings + JOIN task_artifact_chunks AS chunks + ON chunks.id = embeddings.task_artifact_chunk_id + AND chunks.user_id = embeddings.user_id + WHERE embeddings.task_artifact_chunk_id = %s + AND embeddings.embedding_config_id = %s + """ + +LIST_TASK_ARTIFACT_CHUNK_EMBEDDINGS_FOR_CHUNK_SQL = """ + SELECT + embeddings.id, + embeddings.user_id, + chunks.task_artifact_id, + embeddings.task_artifact_chunk_id, + chunks.sequence_no AS task_artifact_chunk_sequence_no, + embeddings.embedding_config_id, + embeddings.dimensions, + embeddings.vector, + embeddings.created_at, + embeddings.updated_at + FROM task_artifact_chunk_embeddings AS embeddings + JOIN task_artifact_chunks AS chunks + ON chunks.id = embeddings.task_artifact_chunk_id + AND chunks.user_id = embeddings.user_id + WHERE embeddings.task_artifact_chunk_id = %s + ORDER BY chunks.sequence_no ASC, embeddings.created_at ASC, embeddings.id ASC + """ + +LIST_TASK_ARTIFACT_CHUNK_EMBEDDINGS_FOR_ARTIFACT_SQL = """ + SELECT + embeddings.id, + embeddings.user_id, + chunks.task_artifact_id, + embeddings.task_artifact_chunk_id, + chunks.sequence_no AS task_artifact_chunk_sequence_no, + embeddings.embedding_config_id, + embeddings.dimensions, + embeddings.vector, + embeddings.created_at, + embeddings.updated_at + FROM task_artifact_chunk_embeddings AS embeddings + JOIN task_artifact_chunks AS chunks + ON chunks.id = embeddings.task_artifact_chunk_id + AND chunks.user_id = embeddings.user_id + WHERE chunks.task_artifact_id = %s + ORDER BY chunks.sequence_no ASC, embeddings.created_at ASC, embeddings.id ASC + """ + +UPDATE_TASK_ARTIFACT_CHUNK_EMBEDDING_SQL = """ + WITH updated AS ( + UPDATE task_artifact_chunk_embeddings + SET dimensions = %s, + vector = %s, + updated_at = clock_timestamp() + WHERE id = %s + RETURNING + id, + user_id, + task_artifact_chunk_id, + embedding_config_id, + dimensions, + vector, + created_at, + updated_at + ) + SELECT + updated.id, + updated.user_id, + chunks.task_artifact_id, + updated.task_artifact_chunk_id, + chunks.sequence_no AS task_artifact_chunk_sequence_no, + updated.embedding_config_id, + updated.dimensions, + updated.vector, + updated.created_at, + updated.updated_at + FROM updated + JOIN task_artifact_chunks AS chunks + ON chunks.id = updated.task_artifact_chunk_id + AND chunks.user_id = updated.user_id + """ + UPDATE_TASK_ARTIFACT_INGESTION_STATUS_SQL = """ UPDATE task_artifacts SET ingestion_status = %s, @@ -2770,9 +2958,77 @@ def create_task_artifact_chunk( (task_artifact_id, sequence_no, char_start, char_end_exclusive, text), ) + def get_task_artifact_chunk_optional(self, task_artifact_chunk_id: UUID) -> TaskArtifactChunkRow | None: + return self._fetch_optional_one(GET_TASK_ARTIFACT_CHUNK_SQL, (task_artifact_chunk_id,)) + def list_task_artifact_chunks(self, task_artifact_id: UUID) -> list[TaskArtifactChunkRow]: return self._fetch_all(LIST_TASK_ARTIFACT_CHUNKS_SQL, (task_artifact_id,)) + def create_task_artifact_chunk_embedding( + self, + *, + task_artifact_chunk_id: UUID, + embedding_config_id: UUID, + dimensions: int, + vector: list[float], + ) -> TaskArtifactChunkEmbeddingRow: + return self._fetch_one( + "create_task_artifact_chunk_embedding", + INSERT_TASK_ARTIFACT_CHUNK_EMBEDDING_SQL, + (task_artifact_chunk_id, embedding_config_id, dimensions, Jsonb(vector)), + ) + + def 
get_task_artifact_chunk_embedding_optional( + self, + task_artifact_chunk_embedding_id: UUID, + ) -> TaskArtifactChunkEmbeddingRow | None: + return self._fetch_optional_one( + GET_TASK_ARTIFACT_CHUNK_EMBEDDING_SQL, + (task_artifact_chunk_embedding_id,), + ) + + def get_task_artifact_chunk_embedding_by_chunk_and_config_optional( + self, + *, + task_artifact_chunk_id: UUID, + embedding_config_id: UUID, + ) -> TaskArtifactChunkEmbeddingRow | None: + return self._fetch_optional_one( + GET_TASK_ARTIFACT_CHUNK_EMBEDDING_BY_CHUNK_AND_CONFIG_SQL, + (task_artifact_chunk_id, embedding_config_id), + ) + + def list_task_artifact_chunk_embeddings_for_chunk( + self, + task_artifact_chunk_id: UUID, + ) -> list[TaskArtifactChunkEmbeddingRow]: + return self._fetch_all( + LIST_TASK_ARTIFACT_CHUNK_EMBEDDINGS_FOR_CHUNK_SQL, + (task_artifact_chunk_id,), + ) + + def list_task_artifact_chunk_embeddings_for_artifact( + self, + task_artifact_id: UUID, + ) -> list[TaskArtifactChunkEmbeddingRow]: + return self._fetch_all( + LIST_TASK_ARTIFACT_CHUNK_EMBEDDINGS_FOR_ARTIFACT_SQL, + (task_artifact_id,), + ) + + def update_task_artifact_chunk_embedding( + self, + *, + task_artifact_chunk_embedding_id: UUID, + dimensions: int, + vector: list[float], + ) -> TaskArtifactChunkEmbeddingRow: + return self._fetch_one( + "update_task_artifact_chunk_embedding", + UPDATE_TASK_ARTIFACT_CHUNK_EMBEDDING_SQL, + (dimensions, Jsonb(vector), task_artifact_chunk_embedding_id), + ) + def update_task_artifact_ingestion_status( self, *, diff --git a/tests/integration/test_migrations.py b/tests/integration/test_migrations.py index 001f1a0..e8e7011 100644 --- a/tests/integration/test_migrations.py +++ b/tests/integration/test_migrations.py @@ -303,6 +303,8 @@ def test_migrations_upgrade_and_downgrade(database_urls): assert cur.fetchone()[0] == "task_artifacts" cur.execute("SELECT to_regclass('public.task_artifact_chunks')") assert cur.fetchone()[0] == "task_artifact_chunks" + cur.execute("SELECT to_regclass('public.task_artifact_chunk_embeddings')") + assert cur.fetchone()[0] == "task_artifact_chunk_embeddings" cur.execute("SELECT to_regclass('public.task_steps')") assert cur.fetchone()[0] == "task_steps" cur.execute( @@ -386,6 +388,7 @@ def test_migrations_upgrade_and_downgrade(database_urls): 'task_workspaces', 'task_artifacts', 'task_artifact_chunks', + 'task_artifact_chunk_embeddings', 'task_steps', 'execution_budgets', 'tool_executions' @@ -407,6 +410,7 @@ def test_migrations_upgrade_and_downgrade(database_urls): ("memory_revisions", True, True), ("policies", True, True), ("sessions", True, True), + ("task_artifact_chunk_embeddings", True, True), ("task_artifact_chunks", True, True), ("task_artifacts", True, True), ("task_steps", True, True), @@ -479,6 +483,8 @@ def test_migrations_upgrade_and_downgrade(database_urls): has_table_privilege('alicebot_app', 'task_artifacts', 'DELETE'), has_table_privilege('alicebot_app', 'task_artifact_chunks', 'UPDATE'), has_table_privilege('alicebot_app', 'task_artifact_chunks', 'DELETE'), + has_table_privilege('alicebot_app', 'task_artifact_chunk_embeddings', 'UPDATE'), + has_table_privilege('alicebot_app', 'task_artifact_chunk_embeddings', 'DELETE'), has_table_privilege('alicebot_app', 'task_steps', 'UPDATE'), has_table_privilege('alicebot_app', 'task_steps', 'DELETE'), has_table_privilege('alicebot_app', 'execution_budgets', 'UPDATE'), @@ -524,16 +530,33 @@ def test_migrations_upgrade_and_downgrade(database_urls): False, True, False, + True, + False, False, False, ) + command.downgrade(config, 
"20260314_0024") + + with psycopg.connect(database_urls["admin"]) as conn: + with conn.cursor() as cur: + cur.execute("SELECT to_regclass('public.task_artifact_chunk_embeddings')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.task_artifact_chunks')") + assert cur.fetchone()[0] == "task_artifact_chunks" + cur.execute("SELECT to_regclass('public.task_artifacts')") + assert cur.fetchone()[0] == "task_artifacts" + cur.execute("SELECT to_regclass('public.task_workspaces')") + assert cur.fetchone()[0] == "task_workspaces" + command.downgrade(config, "20260313_0021") with psycopg.connect(database_urls["admin"]) as conn: with conn.cursor() as cur: cur.execute("SELECT to_regclass('public.task_artifact_chunks')") assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.task_artifact_chunk_embeddings')") + assert cur.fetchone()[0] is None cur.execute("SELECT to_regclass('public.task_artifacts')") assert cur.fetchone()[0] is None cur.execute("SELECT to_regclass('public.task_workspaces')") diff --git a/tests/integration/test_task_artifact_chunk_embeddings_api.py b/tests/integration/test_task_artifact_chunk_embeddings_api.py new file mode 100644 index 0000000..97559cf --- /dev/null +++ b/tests/integration/test_task_artifact_chunk_embeddings_api.py @@ -0,0 +1,474 @@ +from __future__ import annotations + +import json +from typing import Any +from urllib.parse import urlencode +from uuid import UUID, uuid4 + +import anyio + +import apps.api.src.alicebot_api.main as main_module +from apps.api.src.alicebot_api.config import Settings +from alicebot_api.db import user_connection +from alicebot_api.store import ContinuityStore + + +def invoke_request( + method: str, + path: str, + *, + query_params: dict[str, str] | None = None, + payload: dict[str, Any] | None = None, +) -> tuple[int, dict[str, Any]]: + messages: list[dict[str, object]] = [] + encoded_body = b"" if payload is None else json.dumps(payload).encode() + request_received = False + + async def receive() -> dict[str, object]: + nonlocal request_received + if request_received: + return {"type": "http.disconnect"} + + request_received = True + return {"type": "http.request", "body": encoded_body, "more_body": False} + + async def send(message: dict[str, object]) -> None: + messages.append(message) + + query_string = urlencode(query_params or {}).encode() + scope = { + "type": "http", + "asgi": {"version": "3.0"}, + "http_version": "1.1", + "method": method, + "scheme": "http", + "path": path, + "raw_path": path.encode(), + "query_string": query_string, + "headers": [(b"content-type", b"application/json")], + "client": ("testclient", 50000), + "server": ("testserver", 80), + "root_path": "", + } + + anyio.run(main_module.app, scope, receive, send) + + start_message = next(message for message in messages if message["type"] == "http.response.start") + body = b"".join( + message.get("body", b"") + for message in messages + if message["type"] == "http.response.body" + ) + return start_message["status"], json.loads(body) + + +def seed_task_artifact_with_chunks(database_url: str, *, email: str) -> dict[str, UUID]: + user_id = uuid4() + + with user_connection(database_url, user_id) as conn: + store = ContinuityStore(conn) + store.create_user(user_id, email, email.split("@", 1)[0].title()) + thread = store.create_thread("Artifact chunk embedding thread") + tool = store.create_tool( + tool_key="proxy.echo", + name="Proxy Echo", + description="Deterministic proxy handler.", + version="1.0.0", + 
metadata_version="tool_metadata_v0", + active=True, + tags=["proxy"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=[], + risk_hints=[], + metadata={"transport": "proxy"}, + ) + task = store.create_task( + thread_id=thread["id"], + tool_id=tool["id"], + status="approved", + request={ + "thread_id": str(thread["id"]), + "tool_id": str(tool["id"]), + "action": "tool.run", + "scope": "workspace", + "domain_hint": None, + "risk_hint": None, + "attributes": {}, + }, + tool={ + "id": str(tool["id"]), + "tool_key": "proxy.echo", + "name": "Proxy Echo", + "description": "Deterministic proxy handler.", + "version": "1.0.0", + "metadata_version": "tool_metadata_v0", + "active": True, + "tags": ["proxy"], + "action_hints": ["tool.run"], + "scope_hints": ["workspace"], + "domain_hints": [], + "risk_hints": [], + "metadata": {"transport": "proxy"}, + "created_at": tool["created_at"].isoformat(), + }, + latest_approval_id=None, + latest_execution_id=None, + ) + workspace = store.create_task_workspace( + task_id=task["id"], + status="active", + local_path=f"/tmp/task-workspaces/{user_id}/{task['id']}", + ) + artifact = store.create_task_artifact( + task_id=task["id"], + task_workspace_id=workspace["id"], + status="registered", + ingestion_status="ingested", + relative_path="docs/spec.txt", + media_type_hint="text/plain", + ) + first_chunk = store.create_task_artifact_chunk( + task_artifact_id=artifact["id"], + sequence_no=1, + char_start=0, + char_end_exclusive=12, + text="alpha chunk", + ) + second_chunk = store.create_task_artifact_chunk( + task_artifact_id=artifact["id"], + sequence_no=2, + char_start=12, + char_end_exclusive=24, + text="beta chunk", + ) + + return { + "user_id": user_id, + "task_id": task["id"], + "task_artifact_id": artifact["id"], + "first_chunk_id": first_chunk["id"], + "second_chunk_id": second_chunk["id"], + } + + +def seed_embedding_config( + database_url: str, + *, + user_id: UUID, + provider: str, + model: str, + version: str, + dimensions: int, +) -> UUID: + with user_connection(database_url, user_id) as conn: + created = ContinuityStore(conn).create_embedding_config( + provider=provider, + model=model, + version=version, + dimensions=dimensions, + status="active", + metadata={"task": "artifact_chunk_retrieval"}, + ) + return created["id"] + + +def test_task_artifact_chunk_embedding_endpoints_persist_and_read_embeddings( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_task_artifact_with_chunks( + migrated_database_urls["app"], + email="owner@example.com", + ) + first_config_id = seed_embedding_config( + migrated_database_urls["app"], + user_id=seeded["user_id"], + provider="openai", + model="text-embedding-3-small", + version="2026-03-14", + dimensions=3, + ) + second_config_id = seed_embedding_config( + migrated_database_urls["app"], + user_id=seeded["user_id"], + provider="openai", + model="text-embedding-3-large", + version="2026-03-15", + dimensions=3, + ) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + second_write_status, second_write_payload = invoke_request( + "POST", + "/v0/task-artifact-chunk-embeddings", + payload={ + "user_id": str(seeded["user_id"]), + "task_artifact_chunk_id": str(seeded["second_chunk_id"]), + "embedding_config_id": str(first_config_id), + "vector": [0.4, 0.5, 0.6], + }, + ) + first_write_status, first_write_payload = invoke_request( + "POST", + "/v0/task-artifact-chunk-embeddings", + payload={ + "user_id": 
str(seeded["user_id"]), + "task_artifact_chunk_id": str(seeded["first_chunk_id"]), + "embedding_config_id": str(second_config_id), + "vector": [0.1, 0.2, 0.3], + }, + ) + update_status, update_payload = invoke_request( + "POST", + "/v0/task-artifact-chunk-embeddings", + payload={ + "user_id": str(seeded["user_id"]), + "task_artifact_chunk_id": str(seeded["second_chunk_id"]), + "embedding_config_id": str(first_config_id), + "vector": [0.9, 0.8, 0.7], + }, + ) + artifact_list_status, artifact_list_payload = invoke_request( + "GET", + f"/v0/task-artifacts/{seeded['task_artifact_id']}/chunk-embeddings", + query_params={"user_id": str(seeded["user_id"])}, + ) + chunk_list_status, chunk_list_payload = invoke_request( + "GET", + f"/v0/task-artifact-chunks/{seeded['second_chunk_id']}/embeddings", + query_params={"user_id": str(seeded["user_id"])}, + ) + detail_status, detail_payload = invoke_request( + "GET", + f"/v0/task-artifact-chunk-embeddings/{update_payload['embedding']['id']}", + query_params={"user_id": str(seeded["user_id"])}, + ) + + assert second_write_status == 201 + assert second_write_payload["write_mode"] == "created" + assert first_write_status == 201 + assert first_write_payload["write_mode"] == "created" + assert update_status == 201 + assert update_payload["write_mode"] == "updated" + assert update_payload["embedding"]["vector"] == [0.9, 0.8, 0.7] + assert artifact_list_status == 200 + assert artifact_list_payload["summary"] == { + "total_count": 2, + "order": ["task_artifact_chunk_sequence_no_asc", "created_at_asc", "id_asc"], + "scope": { + "kind": "artifact", + "task_artifact_id": str(seeded["task_artifact_id"]), + }, + } + assert chunk_list_status == 200 + assert chunk_list_payload["summary"] == { + "total_count": 1, + "order": ["task_artifact_chunk_sequence_no_asc", "created_at_asc", "id_asc"], + "scope": { + "kind": "chunk", + "task_artifact_id": str(seeded["task_artifact_id"]), + "task_artifact_chunk_id": str(seeded["second_chunk_id"]), + }, + } + assert detail_status == 200 + assert detail_payload["embedding"]["id"] == update_payload["embedding"]["id"] + assert detail_payload["embedding"]["task_artifact_chunk_sequence_no"] == 2 + assert set(detail_payload["embedding"]) == { + "id", + "task_artifact_id", + "task_artifact_chunk_id", + "task_artifact_chunk_sequence_no", + "embedding_config_id", + "dimensions", + "vector", + "created_at", + "updated_at", + } + + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + stored = ContinuityStore(conn).list_task_artifact_chunk_embeddings_for_artifact( + seeded["task_artifact_id"] + ) + + assert [item["id"] for item in artifact_list_payload["items"]] == [ + str(embedding["id"]) for embedding in stored + ] + assert [item["task_artifact_chunk_id"] for item in artifact_list_payload["items"]] == [ + str(seeded["first_chunk_id"]), + str(seeded["second_chunk_id"]), + ] + + +def test_task_artifact_chunk_embedding_writes_reject_invalid_refs_dimension_mismatches_and_cross_user_refs( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_task_artifact_with_chunks(migrated_database_urls["app"], email="owner@example.com") + intruder = seed_task_artifact_with_chunks( + migrated_database_urls["app"], + email="intruder@example.com", + ) + owner_config_id = seed_embedding_config( + migrated_database_urls["app"], + user_id=owner["user_id"], + provider="openai", + model="text-embedding-3-large", + version="2026-03-14", + dimensions=3, + ) + intruder_config_id = seed_embedding_config( + 
migrated_database_urls["app"], + user_id=intruder["user_id"], + provider="openai", + model="text-embedding-3-large", + version="2026-03-14", + dimensions=3, + ) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + missing_config_status, missing_config_payload = invoke_request( + "POST", + "/v0/task-artifact-chunk-embeddings", + payload={ + "user_id": str(owner["user_id"]), + "task_artifact_chunk_id": str(owner["first_chunk_id"]), + "embedding_config_id": str(uuid4()), + "vector": [0.1, 0.2, 0.3], + }, + ) + missing_chunk_status, missing_chunk_payload = invoke_request( + "POST", + "/v0/task-artifact-chunk-embeddings", + payload={ + "user_id": str(owner["user_id"]), + "task_artifact_chunk_id": str(uuid4()), + "embedding_config_id": str(owner_config_id), + "vector": [0.1, 0.2, 0.3], + }, + ) + mismatch_status, mismatch_payload = invoke_request( + "POST", + "/v0/task-artifact-chunk-embeddings", + payload={ + "user_id": str(owner["user_id"]), + "task_artifact_chunk_id": str(owner["first_chunk_id"]), + "embedding_config_id": str(owner_config_id), + "vector": [0.1, 0.2], + }, + ) + cross_user_chunk_status, cross_user_chunk_payload = invoke_request( + "POST", + "/v0/task-artifact-chunk-embeddings", + payload={ + "user_id": str(intruder["user_id"]), + "task_artifact_chunk_id": str(owner["first_chunk_id"]), + "embedding_config_id": str(intruder_config_id), + "vector": [0.1, 0.2, 0.3], + }, + ) + cross_user_config_status, cross_user_config_payload = invoke_request( + "POST", + "/v0/task-artifact-chunk-embeddings", + payload={ + "user_id": str(intruder["user_id"]), + "task_artifact_chunk_id": str(intruder["first_chunk_id"]), + "embedding_config_id": str(owner_config_id), + "vector": [0.1, 0.2, 0.3], + }, + ) + + assert missing_config_status == 400 + assert missing_config_payload["detail"].startswith( + "embedding_config_id must reference an existing embedding config owned by the user" + ) + assert missing_chunk_status == 400 + assert missing_chunk_payload["detail"].startswith( + "task_artifact_chunk_id must reference an existing task artifact chunk owned by the user" + ) + assert mismatch_status == 400 + assert mismatch_payload["detail"] == "vector length must match embedding config dimensions (3): 2" + assert cross_user_chunk_status == 400 + assert cross_user_chunk_payload["detail"] == ( + "task_artifact_chunk_id must reference an existing task artifact chunk owned by the " + f"user: {owner['first_chunk_id']}" + ) + assert cross_user_config_status == 400 + assert cross_user_config_payload["detail"] == ( + "embedding_config_id must reference an existing embedding config owned by the user: " + f"{owner_config_id}" + ) + + +def test_task_artifact_chunk_embedding_reads_respect_per_user_isolation( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_task_artifact_with_chunks(migrated_database_urls["app"], email="owner@example.com") + intruder = seed_task_artifact_with_chunks( + migrated_database_urls["app"], + email="intruder@example.com", + ) + owner_config_id = seed_embedding_config( + migrated_database_urls["app"], + user_id=owner["user_id"], + provider="openai", + model="text-embedding-3-large", + version="2026-03-14", + dimensions=3, + ) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + write_status, write_payload = invoke_request( + "POST", + "/v0/task-artifact-chunk-embeddings", + payload={ + "user_id": str(owner["user_id"]), + 
"task_artifact_chunk_id": str(owner["first_chunk_id"]), + "embedding_config_id": str(owner_config_id), + "vector": [0.1, 0.2, 0.3], + }, + ) + artifact_list_status, artifact_list_payload = invoke_request( + "GET", + f"/v0/task-artifacts/{owner['task_artifact_id']}/chunk-embeddings", + query_params={"user_id": str(intruder["user_id"])}, + ) + chunk_list_status, chunk_list_payload = invoke_request( + "GET", + f"/v0/task-artifact-chunks/{owner['first_chunk_id']}/embeddings", + query_params={"user_id": str(intruder["user_id"])}, + ) + detail_status, detail_payload = invoke_request( + "GET", + f"/v0/task-artifact-chunk-embeddings/{write_payload['embedding']['id']}", + query_params={"user_id": str(intruder["user_id"])}, + ) + + assert write_status == 201 + assert artifact_list_status == 404 + assert artifact_list_payload == { + "detail": f"task artifact {owner['task_artifact_id']} was not found" + } + assert chunk_list_status == 404 + assert chunk_list_payload == { + "detail": f"task artifact chunk {owner['first_chunk_id']} was not found" + } + assert detail_status == 404 + assert detail_payload == { + "detail": ( + f"task artifact chunk embedding {write_payload['embedding']['id']} was not found" + ) + } diff --git a/tests/unit/test_20260314_0025_task_artifact_chunk_embeddings.py b/tests/unit/test_20260314_0025_task_artifact_chunk_embeddings.py new file mode 100644 index 0000000..0844b8e --- /dev/null +++ b/tests/unit/test_20260314_0025_task_artifact_chunk_embeddings.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +import importlib + + +MODULE_NAME = "apps.api.alembic.versions.20260314_0025_task_artifact_chunk_embeddings" + + +def load_migration_module(): + return importlib.import_module(MODULE_NAME) + + +def test_upgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.upgrade() + + assert executed == [ + module._UPGRADE_SCHEMA_STATEMENT, + *module._UPGRADE_GRANT_STATEMENTS, + "ALTER TABLE task_artifact_chunk_embeddings ENABLE ROW LEVEL SECURITY", + "ALTER TABLE task_artifact_chunk_embeddings FORCE ROW LEVEL SECURITY", + module._UPGRADE_POLICY_STATEMENT, + ] + + +def test_downgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.downgrade() + + assert executed == list(module._DOWNGRADE_STATEMENTS) + + +def test_task_artifact_chunk_embedding_privileges_stay_narrow() -> None: + module = load_migration_module() + + assert module._UPGRADE_GRANT_STATEMENTS == ( + "GRANT SELECT, INSERT, UPDATE ON task_artifact_chunk_embeddings TO alicebot_app", + ) diff --git a/tests/unit/test_main.py b/tests/unit/test_main.py index 7afeb2c..9c7a926 100644 --- a/tests/unit/test_main.py +++ b/tests/unit/test_main.py @@ -7,12 +7,15 @@ import pytest import apps.api.src.alicebot_api.main as main_module from apps.api.src.alicebot_api.config import Settings +from alicebot_api.artifacts import TaskArtifactNotFoundError from alicebot_api.compiler import CompiledTraceRun from alicebot_api.contracts import AdmissionDecisionOutput from alicebot_api.embedding import ( EmbeddingConfigValidationError, MemoryEmbeddingNotFoundError, MemoryEmbeddingValidationError, + TaskArtifactChunkEmbeddingNotFoundError, + TaskArtifactChunkEmbeddingValidationError, ) from alicebot_api.entity import EntityNotFoundError, EntityValidationError from 
alicebot_api.entity_edge import EntityEdgeValidationError @@ -112,6 +115,10 @@ def test_healthcheck_route_is_registered() -> None: assert "/v0/memory-embeddings" in route_paths assert "/v0/memories/{memory_id}/embeddings" in route_paths assert "/v0/memory-embeddings/{memory_embedding_id}" in route_paths + assert "/v0/task-artifact-chunk-embeddings" in route_paths + assert "/v0/task-artifacts/{task_artifact_id}/chunk-embeddings" in route_paths + assert "/v0/task-artifact-chunks/{task_artifact_chunk_id}/embeddings" in route_paths + assert "/v0/task-artifact-chunk-embeddings/{task_artifact_chunk_embedding_id}" in route_paths assert "/v0/entities" in route_paths assert "/v0/entity-edges" in route_paths assert "/v0/tools/route" in route_paths @@ -2100,6 +2107,226 @@ def fake_user_connection(_database_url: str, _current_user_id): } +def test_task_artifact_chunk_embedding_routes_success_and_validation_errors(monkeypatch) -> None: + user_id = uuid4() + chunk_id = uuid4() + config_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_upsert_task_artifact_chunk_embedding_record(store, *, user_id, request): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["request"] = request + return { + "embedding": { + "id": "artifact-embedding-123", + "task_artifact_id": "artifact-123", + "task_artifact_chunk_id": str(chunk_id), + "task_artifact_chunk_sequence_no": 2, + "embedding_config_id": str(config_id), + "dimensions": 3, + "vector": [0.1, 0.2, 0.3], + "created_at": "2026-03-14T12:00:00+00:00", + "updated_at": "2026-03-14T12:00:00+00:00", + }, + "write_mode": "created", + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr( + main_module, + "upsert_task_artifact_chunk_embedding_record", + fake_upsert_task_artifact_chunk_embedding_record, + ) + + response = main_module.upsert_task_artifact_chunk_embedding( + main_module.UpsertTaskArtifactChunkEmbeddingRequest( + user_id=user_id, + task_artifact_chunk_id=chunk_id, + embedding_config_id=config_id, + vector=[0.1, 0.2, 0.3], + ) + ) + + assert response.status_code == 201 + assert json.loads(response.body)["write_mode"] == "created" + assert captured["database_url"] == "postgresql://app" + assert captured["current_user_id"] == user_id + assert captured["user_id"] == user_id + assert captured["request"].task_artifact_chunk_id == chunk_id + + monkeypatch.setattr( + main_module, + "upsert_task_artifact_chunk_embedding_record", + lambda *_args, **_kwargs: (_ for _ in ()).throw( + TaskArtifactChunkEmbeddingValidationError( + "task_artifact_chunk_id must reference an existing task artifact chunk owned by the user" + ) + ), + ) + + error_response = main_module.upsert_task_artifact_chunk_embedding( + main_module.UpsertTaskArtifactChunkEmbeddingRequest( + user_id=user_id, + task_artifact_chunk_id=chunk_id, + embedding_config_id=config_id, + vector=[0.1, 0.2, 0.3], + ) + ) + + assert error_response.status_code == 400 + assert json.loads(error_response.body) == { + "detail": "task_artifact_chunk_id must reference an existing task artifact chunk owned by the user" + } + + +def test_task_artifact_chunk_embedding_read_routes_return_payload_and_not_found(monkeypatch) -> 
None: + user_id = uuid4() + artifact_id = uuid4() + chunk_id = uuid4() + embedding_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(_database_url: str, _current_user_id): + yield object() + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr( + main_module, + "list_task_artifact_chunk_embedding_records_for_artifact", + lambda *_args, **_kwargs: { + "items": [], + "summary": { + "total_count": 0, + "order": ["task_artifact_chunk_sequence_no_asc", "created_at_asc", "id_asc"], + "scope": { + "kind": "artifact", + "task_artifact_id": str(artifact_id), + }, + }, + }, + ) + monkeypatch.setattr( + main_module, + "list_task_artifact_chunk_embedding_records_for_chunk", + lambda *_args, **_kwargs: { + "items": [], + "summary": { + "total_count": 0, + "order": ["task_artifact_chunk_sequence_no_asc", "created_at_asc", "id_asc"], + "scope": { + "kind": "chunk", + "task_artifact_id": str(artifact_id), + "task_artifact_chunk_id": str(chunk_id), + }, + }, + }, + ) + monkeypatch.setattr( + main_module, + "get_task_artifact_chunk_embedding_record", + lambda *_args, **_kwargs: { + "embedding": { + "id": str(embedding_id), + "task_artifact_id": str(artifact_id), + "task_artifact_chunk_id": str(chunk_id), + "task_artifact_chunk_sequence_no": 2, + "embedding_config_id": "config-123", + "dimensions": 3, + "vector": [0.1, 0.2, 0.3], + "created_at": "2026-03-14T12:00:00+00:00", + "updated_at": "2026-03-14T12:00:00+00:00", + } + }, + ) + + artifact_response = main_module.list_task_artifact_chunk_embeddings_for_artifact( + task_artifact_id=artifact_id, + user_id=user_id, + ) + chunk_response = main_module.list_task_artifact_chunk_embeddings( + task_artifact_chunk_id=chunk_id, + user_id=user_id, + ) + detail_response = main_module.get_task_artifact_chunk_embedding( + task_artifact_chunk_embedding_id=embedding_id, + user_id=user_id, + ) + + assert artifact_response.status_code == 200 + assert json.loads(artifact_response.body)["summary"]["scope"]["task_artifact_id"] == str( + artifact_id + ) + assert chunk_response.status_code == 200 + assert json.loads(chunk_response.body)["summary"]["scope"]["task_artifact_chunk_id"] == str( + chunk_id + ) + assert detail_response.status_code == 200 + assert json.loads(detail_response.body)["embedding"]["id"] == str(embedding_id) + + monkeypatch.setattr( + main_module, + "list_task_artifact_chunk_embedding_records_for_artifact", + lambda *_args, **_kwargs: (_ for _ in ()).throw( + TaskArtifactNotFoundError(f"task artifact {artifact_id} was not found") + ), + ) + monkeypatch.setattr( + main_module, + "list_task_artifact_chunk_embedding_records_for_chunk", + lambda *_args, **_kwargs: (_ for _ in ()).throw( + TaskArtifactChunkEmbeddingNotFoundError( + f"task artifact chunk {chunk_id} was not found" + ) + ), + ) + monkeypatch.setattr( + main_module, + "get_task_artifact_chunk_embedding_record", + lambda *_args, **_kwargs: (_ for _ in ()).throw( + TaskArtifactChunkEmbeddingNotFoundError( + f"task artifact chunk embedding {embedding_id} was not found" + ) + ), + ) + + missing_artifact_response = main_module.list_task_artifact_chunk_embeddings_for_artifact( + task_artifact_id=artifact_id, + user_id=user_id, + ) + missing_chunk_response = main_module.list_task_artifact_chunk_embeddings( + task_artifact_chunk_id=chunk_id, + user_id=user_id, + ) + missing_detail_response = 
main_module.get_task_artifact_chunk_embedding( + task_artifact_chunk_embedding_id=embedding_id, + user_id=user_id, + ) + + assert missing_artifact_response.status_code == 404 + assert json.loads(missing_artifact_response.body) == { + "detail": f"task artifact {artifact_id} was not found" + } + assert missing_chunk_response.status_code == 404 + assert json.loads(missing_chunk_response.body) == { + "detail": f"task artifact chunk {chunk_id} was not found" + } + assert missing_detail_response.status_code == 404 + assert json.loads(missing_detail_response.body) == { + "detail": f"task artifact chunk embedding {embedding_id} was not found" + } + + def test_create_entity_returns_created_payload(monkeypatch) -> None: user_id = uuid4() first_memory_id = uuid4() diff --git a/tests/unit/test_task_artifact_chunk_embedding.py b/tests/unit/test_task_artifact_chunk_embedding.py new file mode 100644 index 0000000..d70366a --- /dev/null +++ b/tests/unit/test_task_artifact_chunk_embedding.py @@ -0,0 +1,344 @@ +from __future__ import annotations + +from datetime import UTC, datetime, timedelta +from uuid import UUID, uuid4 + +import pytest + +from alicebot_api.artifacts import TaskArtifactNotFoundError +from alicebot_api.contracts import TaskArtifactChunkEmbeddingUpsertInput +from alicebot_api.embedding import ( + TaskArtifactChunkEmbeddingNotFoundError, + TaskArtifactChunkEmbeddingValidationError, + get_task_artifact_chunk_embedding_record, + list_task_artifact_chunk_embedding_records_for_artifact, + list_task_artifact_chunk_embedding_records_for_chunk, + upsert_task_artifact_chunk_embedding_record, +) + + +class TaskArtifactChunkEmbeddingStoreStub: + def __init__(self) -> None: + self.base_time = datetime(2026, 3, 14, 12, 0, tzinfo=UTC) + self.artifacts: dict[UUID, dict[str, object]] = {} + self.chunks: dict[UUID, dict[str, object]] = {} + self.configs: dict[UUID, dict[str, object]] = {} + self.embeddings: list[dict[str, object]] = [] + self.embedding_by_id: dict[UUID, dict[str, object]] = {} + + def create_artifact(self) -> UUID: + artifact_id = uuid4() + self.artifacts[artifact_id] = { + "id": artifact_id, + "task_id": uuid4(), + "task_workspace_id": uuid4(), + "status": "registered", + "ingestion_status": "ingested", + "relative_path": "docs/spec.txt", + "media_type_hint": "text/plain", + "created_at": self.base_time, + "updated_at": self.base_time, + } + return artifact_id + + def create_chunk(self, *, task_artifact_id: UUID, sequence_no: int) -> UUID: + chunk_id = uuid4() + self.chunks[chunk_id] = { + "id": chunk_id, + "task_artifact_id": task_artifact_id, + "sequence_no": sequence_no, + "char_start": (sequence_no - 1) * 10, + "char_end_exclusive": sequence_no * 10, + "text": f"chunk-{sequence_no}", + "created_at": self.base_time + timedelta(minutes=sequence_no), + "updated_at": self.base_time + timedelta(minutes=sequence_no), + } + return chunk_id + + def create_config(self, *, dimensions: int = 3) -> UUID: + config_id = uuid4() + self.configs[config_id] = { + "id": config_id, + "provider": "openai", + "model": "text-embedding-3-large", + "version": "2026-03-14", + "dimensions": dimensions, + "status": "active", + "metadata": {"task": "artifact_chunk_retrieval"}, + "created_at": self.base_time, + } + return config_id + + def get_task_artifact_optional(self, task_artifact_id: UUID) -> dict[str, object] | None: + return self.artifacts.get(task_artifact_id) + + def get_task_artifact_chunk_optional( + self, + task_artifact_chunk_id: UUID, + ) -> dict[str, object] | None: + return 
self.chunks.get(task_artifact_chunk_id) + + def get_embedding_config_optional(self, embedding_config_id: UUID) -> dict[str, object] | None: + return self.configs.get(embedding_config_id) + + def get_task_artifact_chunk_embedding_by_chunk_and_config_optional( + self, + *, + task_artifact_chunk_id: UUID, + embedding_config_id: UUID, + ) -> dict[str, object] | None: + return next( + ( + embedding + for embedding in self.embeddings + if embedding["task_artifact_chunk_id"] == task_artifact_chunk_id + and embedding["embedding_config_id"] == embedding_config_id + ), + None, + ) + + def create_task_artifact_chunk_embedding( + self, + *, + task_artifact_chunk_id: UUID, + embedding_config_id: UUID, + dimensions: int, + vector: list[float], + ) -> dict[str, object]: + chunk = self.chunks[task_artifact_chunk_id] + embedding_id = uuid4() + record = { + "id": embedding_id, + "user_id": uuid4(), + "task_artifact_id": chunk["task_artifact_id"], + "task_artifact_chunk_id": task_artifact_chunk_id, + "task_artifact_chunk_sequence_no": chunk["sequence_no"], + "embedding_config_id": embedding_config_id, + "dimensions": dimensions, + "vector": vector, + "created_at": self.base_time + timedelta(seconds=len(self.embeddings)), + "updated_at": self.base_time + timedelta(seconds=len(self.embeddings)), + } + self.embeddings.append(record) + self.embedding_by_id[embedding_id] = record + return record + + def update_task_artifact_chunk_embedding( + self, + *, + task_artifact_chunk_embedding_id: UUID, + dimensions: int, + vector: list[float], + ) -> dict[str, object]: + record = self.embedding_by_id[task_artifact_chunk_embedding_id] + updated = { + **record, + "dimensions": dimensions, + "vector": vector, + "updated_at": self.base_time + timedelta(minutes=10), + } + self.embedding_by_id[task_artifact_chunk_embedding_id] = updated + for index, existing in enumerate(self.embeddings): + if existing["id"] == task_artifact_chunk_embedding_id: + self.embeddings[index] = updated + return updated + + def get_task_artifact_chunk_embedding_optional( + self, + task_artifact_chunk_embedding_id: UUID, + ) -> dict[str, object] | None: + return self.embedding_by_id.get(task_artifact_chunk_embedding_id) + + def list_task_artifact_chunk_embeddings_for_artifact( + self, + task_artifact_id: UUID, + ) -> list[dict[str, object]]: + return sorted( + ( + embedding + for embedding in self.embeddings + if embedding["task_artifact_id"] == task_artifact_id + ), + key=lambda embedding: ( + embedding["task_artifact_chunk_sequence_no"], + embedding["created_at"], + embedding["id"], + ), + ) + + def list_task_artifact_chunk_embeddings_for_chunk( + self, + task_artifact_chunk_id: UUID, + ) -> list[dict[str, object]]: + return sorted( + ( + embedding + for embedding in self.embeddings + if embedding["task_artifact_chunk_id"] == task_artifact_chunk_id + ), + key=lambda embedding: ( + embedding["task_artifact_chunk_sequence_no"], + embedding["created_at"], + embedding["id"], + ), + ) + + +def test_task_artifact_chunk_embedding_writes_and_reads_are_deterministic() -> None: + store = TaskArtifactChunkEmbeddingStoreStub() + artifact_id = store.create_artifact() + first_chunk_id = store.create_chunk(task_artifact_id=artifact_id, sequence_no=1) + second_chunk_id = store.create_chunk(task_artifact_id=artifact_id, sequence_no=2) + config_id = store.create_config() + + second_write = upsert_task_artifact_chunk_embedding_record( + store, # type: ignore[arg-type] + user_id=uuid4(), + request=TaskArtifactChunkEmbeddingUpsertInput( + 
task_artifact_chunk_id=second_chunk_id, + embedding_config_id=config_id, + vector=(0.4, 0.5, 0.6), + ), + ) + first_write = upsert_task_artifact_chunk_embedding_record( + store, # type: ignore[arg-type] + user_id=uuid4(), + request=TaskArtifactChunkEmbeddingUpsertInput( + task_artifact_chunk_id=first_chunk_id, + embedding_config_id=config_id, + vector=(0.1, 0.2, 0.3), + ), + ) + updated = upsert_task_artifact_chunk_embedding_record( + store, # type: ignore[arg-type] + user_id=uuid4(), + request=TaskArtifactChunkEmbeddingUpsertInput( + task_artifact_chunk_id=second_chunk_id, + embedding_config_id=config_id, + vector=(0.9, 0.8, 0.7), + ), + ) + + artifact_payload = list_task_artifact_chunk_embedding_records_for_artifact( + store, # type: ignore[arg-type] + user_id=uuid4(), + task_artifact_id=artifact_id, + ) + chunk_payload = list_task_artifact_chunk_embedding_records_for_chunk( + store, # type: ignore[arg-type] + user_id=uuid4(), + task_artifact_chunk_id=second_chunk_id, + ) + detail_payload = get_task_artifact_chunk_embedding_record( + store, # type: ignore[arg-type] + user_id=uuid4(), + task_artifact_chunk_embedding_id=UUID(updated["embedding"]["id"]), + ) + + assert second_write["write_mode"] == "created" + assert first_write["write_mode"] == "created" + assert updated["write_mode"] == "updated" + assert updated["embedding"]["vector"] == [0.9, 0.8, 0.7] + assert [item["task_artifact_chunk_id"] for item in artifact_payload["items"]] == [ + str(first_chunk_id), + str(second_chunk_id), + ] + assert artifact_payload["summary"] == { + "total_count": 2, + "order": ["task_artifact_chunk_sequence_no_asc", "created_at_asc", "id_asc"], + "scope": { + "kind": "artifact", + "task_artifact_id": str(artifact_id), + }, + } + assert chunk_payload["summary"] == { + "total_count": 1, + "order": ["task_artifact_chunk_sequence_no_asc", "created_at_asc", "id_asc"], + "scope": { + "kind": "chunk", + "task_artifact_id": str(artifact_id), + "task_artifact_chunk_id": str(second_chunk_id), + }, + } + assert detail_payload["embedding"]["id"] == updated["embedding"]["id"] + assert detail_payload["embedding"]["task_artifact_chunk_sequence_no"] == 2 + + +def test_task_artifact_chunk_embedding_writes_reject_missing_refs_and_dimension_mismatch() -> None: + store = TaskArtifactChunkEmbeddingStoreStub() + artifact_id = store.create_artifact() + chunk_id = store.create_chunk(task_artifact_id=artifact_id, sequence_no=1) + config_id = store.create_config(dimensions=3) + + with pytest.raises( + TaskArtifactChunkEmbeddingValidationError, + match="task_artifact_chunk_id must reference an existing task artifact chunk owned by the user", + ): + upsert_task_artifact_chunk_embedding_record( + store, # type: ignore[arg-type] + user_id=uuid4(), + request=TaskArtifactChunkEmbeddingUpsertInput( + task_artifact_chunk_id=uuid4(), + embedding_config_id=config_id, + vector=(0.1, 0.2, 0.3), + ), + ) + + with pytest.raises( + TaskArtifactChunkEmbeddingValidationError, + match="embedding_config_id must reference an existing embedding config owned by the user", + ): + upsert_task_artifact_chunk_embedding_record( + store, # type: ignore[arg-type] + user_id=uuid4(), + request=TaskArtifactChunkEmbeddingUpsertInput( + task_artifact_chunk_id=chunk_id, + embedding_config_id=uuid4(), + vector=(0.1, 0.2, 0.3), + ), + ) + + with pytest.raises( + TaskArtifactChunkEmbeddingValidationError, + match=r"vector length must match embedding config dimensions \(3\): 2", + ): + upsert_task_artifact_chunk_embedding_record( + store, # type: ignore[arg-type] + 
user_id=uuid4(), + request=TaskArtifactChunkEmbeddingUpsertInput( + task_artifact_chunk_id=chunk_id, + embedding_config_id=config_id, + vector=(0.1, 0.2), + ), + ) + + +def test_task_artifact_chunk_embedding_reads_raise_not_found_when_scope_is_missing() -> None: + store = TaskArtifactChunkEmbeddingStoreStub() + + with pytest.raises(TaskArtifactNotFoundError, match="task artifact .* was not found"): + list_task_artifact_chunk_embedding_records_for_artifact( + store, # type: ignore[arg-type] + user_id=uuid4(), + task_artifact_id=uuid4(), + ) + + with pytest.raises( + TaskArtifactChunkEmbeddingNotFoundError, + match="task artifact chunk .* was not found", + ): + list_task_artifact_chunk_embedding_records_for_chunk( + store, # type: ignore[arg-type] + user_id=uuid4(), + task_artifact_chunk_id=uuid4(), + ) + + with pytest.raises( + TaskArtifactChunkEmbeddingNotFoundError, + match="task artifact chunk embedding .* was not found", + ): + get_task_artifact_chunk_embedding_record( + store, # type: ignore[arg-type] + user_id=uuid4(), + task_artifact_chunk_embedding_id=uuid4(), + ) diff --git a/tests/unit/test_task_artifact_chunk_embedding_store.py b/tests/unit/test_task_artifact_chunk_embedding_store.py new file mode 100644 index 0000000..227a191 --- /dev/null +++ b/tests/unit/test_task_artifact_chunk_embedding_store.py @@ -0,0 +1,236 @@ +from __future__ import annotations + +from datetime import UTC, datetime +from typing import Any +from uuid import uuid4 + +from psycopg.types.json import Jsonb + +from alicebot_api.store import ContinuityStore + + +class RecordingCursor: + def __init__( + self, + fetchone_results: list[dict[str, Any]], + fetchall_results: list[list[dict[str, Any]]] | None = None, + ) -> None: + self.executed: list[tuple[str, tuple[object, ...] | None]] = [] + self.fetchone_results = list(fetchone_results) + self.fetchall_results = list(fetchall_results or []) + + def __enter__(self) -> "RecordingCursor": + return self + + def __exit__(self, exc_type, exc, tb) -> None: + return None + + def execute(self, query: str, params: tuple[object, ...] 
| None = None) -> None: + self.executed.append((query, params)) + + def fetchone(self) -> dict[str, Any] | None: + if not self.fetchone_results: + return None + return self.fetchone_results.pop(0) + + def fetchall(self) -> list[dict[str, Any]]: + if not self.fetchall_results: + return [] + return self.fetchall_results.pop(0) + + +class RecordingConnection: + def __init__(self, cursor: RecordingCursor) -> None: + self.cursor_instance = cursor + + def cursor(self) -> RecordingCursor: + return self.cursor_instance + + +def test_task_artifact_chunk_embedding_store_methods_use_expected_queries() -> None: + task_artifact_id = uuid4() + task_artifact_chunk_id = uuid4() + embedding_config_id = uuid4() + embedding_id = uuid4() + created_at = datetime(2026, 3, 14, 12, 0, tzinfo=UTC) + updated_at = datetime(2026, 3, 14, 12, 5, tzinfo=UTC) + cursor = RecordingCursor( + fetchone_results=[ + { + "id": task_artifact_chunk_id, + "user_id": uuid4(), + "task_artifact_id": task_artifact_id, + "sequence_no": 2, + "char_start": 10, + "char_end_exclusive": 20, + "text": "chunk-2", + "created_at": created_at, + "updated_at": created_at, + }, + { + "id": embedding_id, + "user_id": uuid4(), + "task_artifact_id": task_artifact_id, + "task_artifact_chunk_id": task_artifact_chunk_id, + "task_artifact_chunk_sequence_no": 2, + "embedding_config_id": embedding_config_id, + "dimensions": 3, + "vector": [0.1, 0.2, 0.3], + "created_at": created_at, + "updated_at": created_at, + }, + { + "id": embedding_id, + "user_id": uuid4(), + "task_artifact_id": task_artifact_id, + "task_artifact_chunk_id": task_artifact_chunk_id, + "task_artifact_chunk_sequence_no": 2, + "embedding_config_id": embedding_config_id, + "dimensions": 3, + "vector": [0.3, 0.2, 0.1], + "created_at": created_at, + "updated_at": updated_at, + }, + { + "id": embedding_id, + "user_id": uuid4(), + "task_artifact_id": task_artifact_id, + "task_artifact_chunk_id": task_artifact_chunk_id, + "task_artifact_chunk_sequence_no": 2, + "embedding_config_id": embedding_config_id, + "dimensions": 3, + "vector": [0.3, 0.2, 0.1], + "created_at": created_at, + "updated_at": updated_at, + }, + { + "id": embedding_id, + "user_id": uuid4(), + "task_artifact_id": task_artifact_id, + "task_artifact_chunk_id": task_artifact_chunk_id, + "task_artifact_chunk_sequence_no": 2, + "embedding_config_id": embedding_config_id, + "dimensions": 3, + "vector": [0.3, 0.2, 0.1], + "created_at": created_at, + "updated_at": updated_at, + }, + ], + fetchall_results=[ + [ + { + "id": embedding_id, + "task_artifact_id": task_artifact_id, + "task_artifact_chunk_id": task_artifact_chunk_id, + "task_artifact_chunk_sequence_no": 2, + "embedding_config_id": embedding_config_id, + } + ], + [ + { + "id": embedding_id, + "task_artifact_id": task_artifact_id, + "task_artifact_chunk_id": task_artifact_chunk_id, + "task_artifact_chunk_sequence_no": 2, + "embedding_config_id": embedding_config_id, + } + ], + ], + ) + store = ContinuityStore(RecordingConnection(cursor)) + + fetched_chunk = store.get_task_artifact_chunk_optional(task_artifact_chunk_id) + created = store.create_task_artifact_chunk_embedding( + task_artifact_chunk_id=task_artifact_chunk_id, + embedding_config_id=embedding_config_id, + dimensions=3, + vector=[0.1, 0.2, 0.3], + ) + updated = store.update_task_artifact_chunk_embedding( + task_artifact_chunk_embedding_id=embedding_id, + dimensions=3, + vector=[0.3, 0.2, 0.1], + ) + fetched_embedding = store.get_task_artifact_chunk_embedding_optional(embedding_id) + existing = 
store.get_task_artifact_chunk_embedding_by_chunk_and_config_optional( + task_artifact_chunk_id=task_artifact_chunk_id, + embedding_config_id=embedding_config_id, + ) + listed_for_chunk = store.list_task_artifact_chunk_embeddings_for_chunk(task_artifact_chunk_id) + listed_for_artifact = store.list_task_artifact_chunk_embeddings_for_artifact(task_artifact_id) + + assert fetched_chunk is not None + assert fetched_chunk["id"] == task_artifact_chunk_id + assert created["id"] == embedding_id + assert updated["updated_at"] == updated_at + assert fetched_embedding is not None + assert existing is not None + assert listed_for_chunk == [ + { + "id": embedding_id, + "task_artifact_id": task_artifact_id, + "task_artifact_chunk_id": task_artifact_chunk_id, + "task_artifact_chunk_sequence_no": 2, + "embedding_config_id": embedding_config_id, + } + ] + assert listed_for_artifact == [ + { + "id": embedding_id, + "task_artifact_id": task_artifact_id, + "task_artifact_chunk_id": task_artifact_chunk_id, + "task_artifact_chunk_sequence_no": 2, + "embedding_config_id": embedding_config_id, + } + ] + + get_chunk_query, get_chunk_params = cursor.executed[0] + assert "FROM task_artifact_chunks" in get_chunk_query + assert get_chunk_params == (task_artifact_chunk_id,) + + create_query, create_params = cursor.executed[1] + assert "INSERT INTO task_artifact_chunk_embeddings" in create_query + assert "JOIN task_artifact_chunks AS chunks" in create_query + assert create_params is not None + assert create_params[:3] == (task_artifact_chunk_id, embedding_config_id, 3) + assert isinstance(create_params[3], Jsonb) + assert create_params[3].obj == [0.1, 0.2, 0.3] + + update_query, update_params = cursor.executed[2] + assert "UPDATE task_artifact_chunk_embeddings" in update_query + assert update_params is not None + assert update_params[0] == 3 + assert isinstance(update_params[1], Jsonb) + assert update_params[1].obj == [0.3, 0.2, 0.1] + assert update_params[2] == embedding_id + + get_embedding_query, get_embedding_params = cursor.executed[3] + assert "FROM task_artifact_chunk_embeddings AS embeddings" in get_embedding_query + assert get_embedding_params == (embedding_id,) + + get_existing_query, get_existing_params = cursor.executed[4] + assert "WHERE embeddings.task_artifact_chunk_id = %s" in get_existing_query + assert "AND embeddings.embedding_config_id = %s" in get_existing_query + assert get_existing_params == (task_artifact_chunk_id, embedding_config_id) + + list_chunk_query, list_chunk_params = cursor.executed[5] + assert "WHERE embeddings.task_artifact_chunk_id = %s" in list_chunk_query + assert "ORDER BY chunks.sequence_no ASC, embeddings.created_at ASC, embeddings.id ASC" in list_chunk_query + assert list_chunk_params == (task_artifact_chunk_id,) + + list_artifact_query, list_artifact_params = cursor.executed[6] + assert "WHERE chunks.task_artifact_id = %s" in list_artifact_query + assert "ORDER BY chunks.sequence_no ASC, embeddings.created_at ASC, embeddings.id ASC" in list_artifact_query + assert list_artifact_params == (task_artifact_id,) + + +def test_task_artifact_chunk_embedding_store_optional_reads_return_none_when_row_is_missing() -> None: + cursor = RecordingCursor(fetchone_results=[]) + store = ContinuityStore(RecordingConnection(cursor)) + + assert store.get_task_artifact_chunk_optional(uuid4()) is None + assert store.get_task_artifact_chunk_embedding_optional(uuid4()) is None + assert store.get_task_artifact_chunk_embedding_by_chunk_and_config_optional( + task_artifact_chunk_id=uuid4(), + 
embedding_config_id=uuid4(), + ) is None From ba6b98243634f1f92c843bc1dfee877359f43957 Mon Sep 17 00:00:00 2001 From: Sam Date: Sun, 15 Mar 2026 21:08:13 +0100 Subject: [PATCH 008/135] Sprint 5H: semantic artifact chunk retrieval primitive (#8) * Sprint 5H: semantic artifact chunk retrieval packet * Sprint 5H: semantic artifact chunk retrieval primitive --------- Co-authored-by: Sami Rusani --- ARCHITECTURE.md | 25 +- apps/api/src/alicebot_api/contracts.py | 67 +++ apps/api/src/alicebot_api/main.py | 76 +++ .../src/alicebot_api/semantic_retrieval.py | 243 +++++++- apps/api/src/alicebot_api/store.py | 121 ++++ ...t_semantic_artifact_chunk_retrieval_api.py | 569 ++++++++++++++++++ tests/unit/test_artifacts_main.py | 136 +++++ tests/unit/test_main.py | 2 + tests/unit/test_semantic_retrieval.py | 289 ++++++++- ...est_task_artifact_chunk_embedding_store.py | 82 +++ 10 files changed, 1587 insertions(+), 23 deletions(-) create mode 100644 tests/integration/test_semantic_artifact_chunk_retrieval_api.py diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md index 0e9f72d..0d939f7 100644 --- a/ARCHITECTURE.md +++ b/ARCHITECTURE.md @@ -2,16 +2,16 @@ ## Current Implemented Slice -AliceBot now implements the accepted repo slice through Sprint 5G. The shipped backend includes: +AliceBot now implements the accepted repo slice through Sprint 5H. The shipped backend includes: - foundation continuity storage over `users`, `threads`, `sessions`, and append-only `events` - deterministic tracing and context compilation over durable continuity, memory, entity, and entity-edge records - governed memory admission, explicit-preference extraction, memory review labels, review queue reads, evaluation summary reads, explicit embedding config and memory-embedding storage, direct semantic retrieval, and deterministic hybrid compile-path memory merge - deterministic prompt assembly and one no-tools response path that persists assistant replies as immutable continuity events - user-scoped consents, policies, policy evaluation, tool registry, allowlist evaluation, tool routing, approval request persistence, approval resolution, approved-only proxy execution through the in-process `proxy.echo` handler, durable execution review, and execution-budget lifecycle plus enforcement -- durable `tasks`, `task_steps`, `task_workspaces`, `task_artifacts`, `task_artifact_chunks`, and `task_artifact_chunk_embeddings`, deterministic task-step sequencing, explicit task-step transitions, explicit manual continuation with lineage through `parent_step_id`, `source_approval_id`, and `source_execution_id`, explicit `tool_executions.task_step_id` linkage for execution synchronization, deterministic rooted local task-workspace provisioning, explicit rooted local artifact registration, deterministic local text-artifact ingestion into durable chunk rows, deterministic lexical artifact-chunk retrieval over durable chunk rows, optional compile-path artifact chunk inclusion as a separate context section, and explicit user-scoped artifact-chunk embedding persistence tied to existing embedding configs +- durable `tasks`, `task_steps`, `task_workspaces`, `task_artifacts`, `task_artifact_chunks`, and `task_artifact_chunk_embeddings`, deterministic task-step sequencing, explicit task-step transitions, explicit manual continuation with lineage through `parent_step_id`, `source_approval_id`, and `source_execution_id`, explicit `tool_executions.task_step_id` linkage for execution synchronization, deterministic rooted local task-workspace provisioning, explicit 
rooted local artifact registration, deterministic local text-artifact ingestion into durable chunk rows, deterministic lexical artifact-chunk retrieval over durable chunk rows, optional compile-path artifact chunk inclusion as a separate context section, explicit user-scoped artifact-chunk embedding persistence tied to existing embedding configs, and explicit task-scoped or artifact-scoped semantic artifact-chunk retrieval over those durable embeddings -The current multi-step boundary is narrow and explicit. Manual continuation is implemented and review-passed. Approval resolution and proxy execution now both use explicit task-step linkage rather than first-step inference. Task workspaces are now implemented only as deterministic rooted local boundaries, and task artifacts are now implemented only as explicit rooted local-file registrations, narrow deterministic text ingestion under those workspaces, lexical retrieval over persisted chunk rows, optional compile-path inclusion of retrieved artifact chunks in a separate response section, and explicit artifact-chunk embedding storage tied to existing embedding configs. Broader runner-style orchestration, automatic multi-step progression, artifact-chunk semantic retrieval, rich-document parsing, connectors, and new side-effect surfaces are still planned later and must not be described as live behavior. +The current multi-step boundary is narrow and explicit. Manual continuation is implemented and review-passed. Approval resolution and proxy execution now both use explicit task-step linkage rather than first-step inference. Task workspaces are now implemented only as deterministic rooted local boundaries, and task artifacts are now implemented only as explicit rooted local-file registrations, narrow deterministic text ingestion under those workspaces, lexical retrieval over persisted chunk rows, optional compile-path inclusion of retrieved artifact chunks in a separate response section, explicit artifact-chunk embedding storage tied to existing embedding configs, and direct semantic retrieval over those durable artifact-chunk embeddings for one visible task or one visible artifact at a time. Broader runner-style orchestration, automatic multi-step progression, compile-path semantic artifact use, hybrid artifact retrieval, richer document parsing, connectors, and new side-effect surfaces are still planned later and must not be described as live behavior. ## Implemented Now @@ -24,7 +24,7 @@ The current multi-step boundary is narrow and explicit. 
Manual continuation is i - memory and retrieval: `POST /v0/memories/admit`, `POST /v0/memories/extract-explicit-preferences`, `GET /v0/memories`, `GET /v0/memories/review-queue`, `GET /v0/memories/evaluation-summary`, `POST /v0/memories/semantic-retrieval`, `GET /v0/memories/{memory_id}`, `GET /v0/memories/{memory_id}/revisions`, `POST /v0/memories/{memory_id}/labels`, `GET /v0/memories/{memory_id}/labels` - embeddings and graph seams: `POST /v0/embedding-configs`, `GET /v0/embedding-configs`, `POST /v0/memory-embeddings`, `GET /v0/memories/{memory_id}/embeddings`, `GET /v0/memory-embeddings/{memory_embedding_id}`, `POST /v0/entities`, `GET /v0/entities`, `GET /v0/entities/{entity_id}`, `POST /v0/entity-edges`, `GET /v0/entities/{entity_id}/edges` - governance: `POST /v0/consents`, `GET /v0/consents`, `POST /v0/policies`, `GET /v0/policies`, `GET /v0/policies/{policy_id}`, `POST /v0/policies/evaluate`, `POST /v0/tools`, `GET /v0/tools`, `GET /v0/tools/{tool_id}`, `POST /v0/tools/allowlist/evaluate`, `POST /v0/tools/route`, `POST /v0/approvals/requests`, `GET /v0/approvals`, `GET /v0/approvals/{approval_id}`, `POST /v0/approvals/{approval_id}/approve`, `POST /v0/approvals/{approval_id}/reject`, `POST /v0/approvals/{approval_id}/execute` -- task and execution review: `GET /v0/tasks`, `GET /v0/tasks/{task_id}`, `POST /v0/tasks/{task_id}/workspace`, `GET /v0/task-workspaces`, `GET /v0/task-workspaces/{task_workspace_id}`, `POST /v0/task-workspaces/{task_workspace_id}/artifacts`, `GET /v0/task-artifacts`, `GET /v0/task-artifacts/{task_artifact_id}`, `POST /v0/task-artifacts/{task_artifact_id}/ingest`, `GET /v0/task-artifacts/{task_artifact_id}/chunks`, `POST /v0/tasks/{task_id}/artifact-chunks/retrieve`, `POST /v0/task-artifacts/{task_artifact_id}/chunks/retrieve`, `POST /v0/task-artifact-chunk-embeddings`, `GET /v0/task-artifacts/{task_artifact_id}/chunk-embeddings`, `GET /v0/task-artifact-chunks/{task_artifact_chunk_id}/embeddings`, `GET /v0/task-artifact-chunk-embeddings/{task_artifact_chunk_embedding_id}`, `GET /v0/tasks/{task_id}/steps`, `GET /v0/task-steps/{task_step_id}`, `POST /v0/tasks/{task_id}/steps`, `POST /v0/task-steps/{task_step_id}/transition`, `POST /v0/execution-budgets`, `GET /v0/execution-budgets`, `GET /v0/execution-budgets/{execution_budget_id}`, `POST /v0/execution-budgets/{execution_budget_id}/deactivate`, `POST /v0/execution-budgets/{execution_budget_id}/supersede`, `GET /v0/tool-executions`, `GET /v0/tool-executions/{execution_id}` +- task and execution review: `GET /v0/tasks`, `GET /v0/tasks/{task_id}`, `POST /v0/tasks/{task_id}/workspace`, `GET /v0/task-workspaces`, `GET /v0/task-workspaces/{task_workspace_id}`, `POST /v0/task-workspaces/{task_workspace_id}/artifacts`, `GET /v0/task-artifacts`, `GET /v0/task-artifacts/{task_artifact_id}`, `POST /v0/task-artifacts/{task_artifact_id}/ingest`, `GET /v0/task-artifacts/{task_artifact_id}/chunks`, `POST /v0/tasks/{task_id}/artifact-chunks/retrieve`, `POST /v0/task-artifacts/{task_artifact_id}/chunks/retrieve`, `POST /v0/tasks/{task_id}/artifact-chunks/semantic-retrieval`, `POST /v0/task-artifacts/{task_artifact_id}/chunks/semantic-retrieval`, `POST /v0/task-artifact-chunk-embeddings`, `GET /v0/task-artifacts/{task_artifact_id}/chunk-embeddings`, `GET /v0/task-artifact-chunks/{task_artifact_chunk_id}/embeddings`, `GET /v0/task-artifact-chunk-embeddings/{task_artifact_chunk_embedding_id}`, `GET /v0/tasks/{task_id}/steps`, `GET /v0/task-steps/{task_step_id}`, `POST /v0/tasks/{task_id}/steps`, `POST 
/v0/task-steps/{task_step_id}/transition`, `POST /v0/execution-budgets`, `GET /v0/execution-budgets`, `GET /v0/execution-budgets/{execution_budget_id}`, `POST /v0/execution-budgets/{execution_budget_id}/deactivate`, `POST /v0/execution-budgets/{execution_budget_id}/supersede`, `GET /v0/tool-executions`, `GET /v0/tool-executions/{execution_id}` - `apps/web` and `workers` remain starter shells only. ### Data Foundation @@ -58,11 +58,11 @@ The current multi-step boundary is narrow and explicit. Manual continuation is i ### Repo Boundaries In This Slice -- `apps/api`: implemented API, store, contracts, service logic, and migrations for continuity, tracing, memory, embeddings, entities, policies, tools, approvals, proxy execution, execution budgets, tasks, task steps, task workspaces, task artifacts, artifact-chunk embeddings, deterministic lexical artifact chunk retrieval, and narrow compile-path artifact chunk inclusion. +- `apps/api`: implemented API, store, contracts, service logic, and migrations for continuity, tracing, memory, embeddings, entities, policies, tools, approvals, proxy execution, execution budgets, tasks, task steps, task workspaces, task artifacts, artifact-chunk embeddings, deterministic lexical artifact chunk retrieval, deterministic semantic artifact chunk retrieval over durable embeddings, and narrow compile-path artifact chunk inclusion. - `apps/web`: minimal shell only; no shipped workflow UI. - `workers`: scaffold only; no background jobs or runner logic are implemented. - `infra`: local development bootstrap assets only. -- `tests`: unit and Postgres-backed integration coverage for the shipped seams above, including Sprint 4O task-step lineage/manual continuation, Sprint 4S step-linked execution synchronization, Sprint 5A task-workspace provisioning, Sprint 5C task-artifact registration, Sprint 5D local artifact ingestion plus chunk reads, Sprint 5E lexical artifact-chunk retrieval, Sprint 5F compile-path artifact chunk integration, and Sprint 5G artifact-chunk embedding persistence and reads. +- `tests`: unit and Postgres-backed integration coverage for the shipped seams above, including Sprint 4O task-step lineage/manual continuation, Sprint 4S step-linked execution synchronization, Sprint 5A task-workspace provisioning, Sprint 5C task-artifact registration, Sprint 5D local artifact ingestion plus chunk reads, Sprint 5E lexical artifact-chunk retrieval, Sprint 5F compile-path artifact chunk integration, Sprint 5G artifact-chunk embedding persistence and reads, and Sprint 5H semantic artifact-chunk retrieval. ## Core Flows Implemented Now @@ -76,6 +76,15 @@ The current multi-step boundary is narrow and explicit. Manual continuation is i 6. Persist a `context.compile` trace plus explicit inclusion and exclusion events, including artifact chunk include/exclude decisions. 7. Return one deterministic `context_pack` describing scope, limits, selected context, artifact chunk results, and trace metadata. +### Artifact Chunk Retrieval + +1. Register and ingest visible local artifacts into durable `task_artifacts` and `task_artifact_chunks`. +2. Persist explicit artifact-chunk embeddings in `task_artifact_chunk_embeddings`, keyed to an existing visible embedding config. +3. Support deterministic lexical artifact-chunk retrieval for one visible task or one visible artifact. +4. Support deterministic semantic artifact-chunk retrieval for one visible task or one visible artifact, using a caller-supplied query vector plus explicit `embedding_config_id`. +5. 
Exclude artifacts whose `ingestion_status` is not `ingested`. +6. Keep compile-path artifact retrieval separate and lexical-only for now; semantic artifact retrieval remains a direct read seam outside compile. + ### Governed Memory And Retrieval 1. Accept explicit memory candidates through `POST /v0/memories/admit`. @@ -227,7 +236,7 @@ The current multi-step boundary is narrow and explicit. Manual continuation is i ## Testing Coverage Implemented Now -- Unit and integration tests cover continuity, compiler, response generation, memory admission, review labels, review queue, embeddings, semantic retrieval, entities, policies, tools, approvals, proxy execution, execution budgets, and execution review. +- Unit and integration tests cover continuity, compiler, response generation, memory admission, review labels, review queue, embeddings, semantic retrieval, artifact semantic retrieval, entities, policies, tools, approvals, proxy execution, execution budgets, and execution review. - Sprint 4O, Sprint 4S, Sprint 5A, and Sprint 5C added explicit task lifecycle coverage: - migrations for `tasks`, `task_steps`, and task-step lineage - staged/backfilled migration coverage for `tool_executions.task_step_id` @@ -260,7 +269,7 @@ The current multi-step boundary is narrow and explicit. Manual continuation is i The following areas remain planned later and must not be described as implemented: - runner-style orchestration and automatic multi-step progression beyond the current explicit manual continuation seam -- artifact chunk ranking beyond the current lexical match ordering, plus embeddings and semantic retrieval for artifact chunks +- hybrid lexical plus semantic artifact retrieval, compile-path semantic artifact use, and reranking beyond the current direct lexical and direct semantic ordering seams - rich document parsing beyond the current narrow UTF-8 text and markdown ingestion boundary - read-only Gmail and Calendar connectors - broader tool proxying and real-world side effects beyond the current no-I/O `proxy.echo` handler diff --git a/apps/api/src/alicebot_api/contracts.py b/apps/api/src/alicebot_api/contracts.py index c624578..07fa214 100644 --- a/apps/api/src/alicebot_api/contracts.py +++ b/apps/api/src/alicebot_api/contracts.py @@ -149,6 +149,12 @@ "sequence_no_asc", "id_asc", ] +TASK_ARTIFACT_CHUNK_SEMANTIC_RETRIEVAL_ORDER = [ + "score_desc", + "relative_path_asc", + "sequence_no_asc", + "id_asc", +] TASK_STEP_LIST_ORDER = ["sequence_no_asc", "created_at_asc", "id_asc"] TOOL_EXECUTION_LIST_ORDER = ["executed_at_asc", "id_asc"] EXECUTION_BUDGET_LIST_ORDER = ["created_at_asc", "id_asc"] @@ -1736,6 +1742,38 @@ class ArtifactScopedArtifactChunkRetrievalInput: query: str +@dataclass(frozen=True, slots=True) +class TaskScopedSemanticArtifactChunkRetrievalInput: + task_id: UUID + embedding_config_id: UUID + query_vector: tuple[float, ...] + limit: int = DEFAULT_ARTIFACT_CHUNK_RETRIEVAL_LIMIT + + def as_payload(self) -> JsonObject: + return { + "task_id": str(self.task_id), + "embedding_config_id": str(self.embedding_config_id), + "query_vector": [float(value) for value in self.query_vector], + "limit": self.limit, + } + + +@dataclass(frozen=True, slots=True) +class ArtifactScopedSemanticArtifactChunkRetrievalInput: + task_artifact_id: UUID + embedding_config_id: UUID + query_vector: tuple[float, ...] 
+ limit: int = DEFAULT_ARTIFACT_CHUNK_RETRIEVAL_LIMIT + + def as_payload(self) -> JsonObject: + return { + "task_artifact_id": str(self.task_artifact_id), + "embedding_config_id": str(self.embedding_config_id), + "query_vector": [float(value) for value in self.query_vector], + "limit": self.limit, + } + + class TaskArtifactRecord(TypedDict): id: str task_id: str @@ -1873,6 +1911,35 @@ class TaskArtifactChunkRetrievalResponse(TypedDict): summary: TaskArtifactChunkRetrievalSummary +class TaskArtifactChunkSemanticRetrievalItem(TypedDict): + id: str + task_id: str + task_artifact_id: str + relative_path: str + media_type: str + sequence_no: int + char_start: int + char_end_exclusive: int + text: str + score: float + + +class TaskArtifactChunkSemanticRetrievalSummary(TypedDict): + embedding_config_id: str + query_vector_dimensions: int + limit: int + returned_count: int + searched_artifact_count: int + similarity_metric: Literal["cosine_similarity"] + order: list[str] + scope: TaskArtifactChunkRetrievalScope + + +class TaskArtifactChunkSemanticRetrievalResponse(TypedDict): + items: list[TaskArtifactChunkSemanticRetrievalItem] + summary: TaskArtifactChunkSemanticRetrievalSummary + + class TaskStepTraceLink(TypedDict): trace_id: str trace_kind: str diff --git a/apps/api/src/alicebot_api/main.py b/apps/api/src/alicebot_api/main.py index 0106917..fb487e6 100644 --- a/apps/api/src/alicebot_api/main.py +++ b/apps/api/src/alicebot_api/main.py @@ -15,6 +15,7 @@ ApprovalApproveInput, ApprovalRejectInput, ApprovalRequestCreateInput, + ArtifactScopedSemanticArtifactChunkRetrievalInput, CompileContextArtifactScopedArtifactRetrievalInput, CompileContextTaskScopedArtifactRetrievalInput, ConsentStatus, @@ -58,6 +59,7 @@ ProxyExecutionRequestInput, TaskArtifactIngestInput, TaskArtifactRegisterInput, + TaskScopedSemanticArtifactChunkRetrievalInput, TaskScopedArtifactChunkRetrievalInput, TaskStepKind, TaskStepLineageInput, @@ -197,8 +199,11 @@ route_tool_invocation, ) from alicebot_api.semantic_retrieval import ( + SemanticArtifactChunkRetrievalValidationError, SemanticMemoryRetrievalValidationError, + retrieve_artifact_scoped_semantic_artifact_chunk_records, retrieve_semantic_memory_records, + retrieve_task_scoped_semantic_artifact_chunk_records, ) from alicebot_api.response_generation import ( ResponseFailure, @@ -380,6 +385,17 @@ class RetrieveSemanticMemoriesRequest(BaseModel): ) +class RetrieveSemanticArtifactChunksRequest(BaseModel): + user_id: UUID + embedding_config_id: UUID + query_vector: list[float] = Field(min_length=1, max_length=20000) + limit: int = Field( + default=DEFAULT_ARTIFACT_CHUNK_RETRIEVAL_LIMIT, + ge=1, + le=MAX_ARTIFACT_CHUNK_RETRIEVAL_LIMIT, + ) + + class UpsertConsentRequest(BaseModel): user_id: UUID consent_key: str = Field(min_length=1, max_length=200) @@ -1447,6 +1463,66 @@ def retrieve_task_artifact_chunks_for_artifact( ) +@app.post("/v0/tasks/{task_id}/artifact-chunks/semantic-retrieval") +def retrieve_semantic_task_artifact_chunks( + task_id: UUID, + request: RetrieveSemanticArtifactChunksRequest, +) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, request.user_id) as conn: + payload = retrieve_task_scoped_semantic_artifact_chunk_records( + ContinuityStore(conn), + user_id=request.user_id, + request=TaskScopedSemanticArtifactChunkRetrievalInput( + task_id=task_id, + embedding_config_id=request.embedding_config_id, + query_vector=tuple(request.query_vector), + limit=request.limit, + ), + ) + except TaskNotFoundError as exc: + 
return JSONResponse(status_code=404, content={"detail": str(exc)}) + except SemanticArtifactChunkRetrievalValidationError as exc: + return JSONResponse(status_code=400, content={"detail": str(exc)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.post("/v0/task-artifacts/{task_artifact_id}/chunks/semantic-retrieval") +def retrieve_semantic_artifact_chunks_for_artifact( + task_artifact_id: UUID, + request: RetrieveSemanticArtifactChunksRequest, +) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, request.user_id) as conn: + payload = retrieve_artifact_scoped_semantic_artifact_chunk_records( + ContinuityStore(conn), + user_id=request.user_id, + request=ArtifactScopedSemanticArtifactChunkRetrievalInput( + task_artifact_id=task_artifact_id, + embedding_config_id=request.embedding_config_id, + query_vector=tuple(request.query_vector), + limit=request.limit, + ), + ) + except TaskArtifactNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + except SemanticArtifactChunkRetrievalValidationError as exc: + return JSONResponse(status_code=400, content={"detail": str(exc)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + @app.post("/v0/tasks/{task_id}/steps") def create_next_task_step(task_id: UUID, request: CreateNextTaskStepRequest) -> JSONResponse: settings = get_settings() diff --git a/apps/api/src/alicebot_api/semantic_retrieval.py b/apps/api/src/alicebot_api/semantic_retrieval.py index 5384e3d..4fe066d 100644 --- a/apps/api/src/alicebot_api/semantic_retrieval.py +++ b/apps/api/src/alicebot_api/semantic_retrieval.py @@ -1,25 +1,56 @@ from __future__ import annotations import math +from pathlib import Path +from typing import cast from uuid import UUID +from alicebot_api.artifacts import TaskArtifactNotFoundError from alicebot_api.contracts import ( SEMANTIC_MEMORY_RETRIEVAL_ORDER, + TASK_ARTIFACT_CHUNK_SEMANTIC_RETRIEVAL_ORDER, + ArtifactScopedSemanticArtifactChunkRetrievalInput, SemanticMemoryRetrievalRequestInput, SemanticMemoryRetrievalResponse, SemanticMemoryRetrievalResultItem, SemanticMemoryRetrievalSummary, + TaskArtifactChunkRetrievalScope, + TaskArtifactChunkRetrievalScopeKind, + TaskArtifactChunkSemanticRetrievalItem, + TaskArtifactChunkSemanticRetrievalResponse, + TaskArtifactChunkSemanticRetrievalSummary, + TaskScopedSemanticArtifactChunkRetrievalInput, ) -from alicebot_api.store import ContinuityStore, SemanticMemoryRetrievalRow +from alicebot_api.store import ( + ContinuityStore, + SemanticMemoryRetrievalRow, + TaskArtifactChunkSemanticRetrievalRow, +) +from alicebot_api.tasks import TaskNotFoundError + +SUPPORTED_TEXT_ARTIFACT_EXTENSIONS = { + ".txt": "text/plain", + ".text": "text/plain", + ".md": "text/markdown", + ".markdown": "text/markdown", +} class SemanticMemoryRetrievalValidationError(ValueError): """Raised when semantic memory retrieval fails explicit validation.""" -def _validate_query_vector(query_vector: tuple[float, ...]) -> list[float]: +class SemanticArtifactChunkRetrievalValidationError(ValueError): + """Raised when semantic artifact chunk retrieval fails explicit validation.""" + + +def _validate_query_vector( + query_vector: tuple[float, ...], + *, + error_type: type[ValueError], +) -> list[float]: if not query_vector: - raise SemanticMemoryRetrievalValidationError( + raise error_type( "query_vector must include at least one numeric value" ) @@ -27,7 +58,7 @@ def 
_validate_query_vector(query_vector: tuple[float, ...]) -> list[float]: for value in query_vector: normalized_value = float(value) if not math.isfinite(normalized_value): - raise SemanticMemoryRetrievalValidationError( + raise error_type( "query_vector must contain only finite numeric values" ) normalized.append(normalized_value) @@ -35,26 +66,41 @@ def _validate_query_vector(query_vector: tuple[float, ...]) -> list[float]: return normalized -def validate_semantic_memory_retrieval_request( +def _validate_embedding_config_and_query_vector( store: ContinuityStore, *, - request: SemanticMemoryRetrievalRequestInput, + embedding_config_id: UUID, + query_vector: tuple[float, ...], + error_type: type[ValueError], ) -> tuple[dict[str, object], list[float]]: - config = store.get_embedding_config_optional(request.embedding_config_id) + config = store.get_embedding_config_optional(embedding_config_id) if config is None: - raise SemanticMemoryRetrievalValidationError( + raise error_type( "embedding_config_id must reference an existing embedding config owned by the user: " - f"{request.embedding_config_id}" + f"{embedding_config_id}" ) - query_vector = _validate_query_vector(request.query_vector) - if len(query_vector) != config["dimensions"]: - raise SemanticMemoryRetrievalValidationError( + normalized_query_vector = _validate_query_vector(query_vector, error_type=error_type) + if len(normalized_query_vector) != config["dimensions"]: + raise error_type( "query_vector length must match embedding config dimensions " - f"({config['dimensions']}): {len(query_vector)}" + f"({config['dimensions']}): {len(normalized_query_vector)}" ) - return config, query_vector + return config, normalized_query_vector + + +def validate_semantic_memory_retrieval_request( + store: ContinuityStore, + *, + request: SemanticMemoryRetrievalRequestInput, +) -> tuple[dict[str, object], list[float]]: + return _validate_embedding_config_and_query_vector( + store, + embedding_config_id=request.embedding_config_id, + query_vector=request.query_vector, + error_type=SemanticMemoryRetrievalValidationError, + ) def serialize_semantic_memory_result_item( @@ -76,6 +122,175 @@ def serialize_semantic_memory_result_item( } +def _infer_media_type(*, relative_path: str, media_type_hint: str | None) -> str: + if media_type_hint is not None: + return media_type_hint + return SUPPORTED_TEXT_ARTIFACT_EXTENSIONS.get(Path(relative_path).suffix.lower(), "unknown") + + +def _build_task_artifact_chunk_retrieval_scope( + *, + kind: str, + task_id: UUID, + task_artifact_id: UUID | None = None, +) -> TaskArtifactChunkRetrievalScope: + scope: TaskArtifactChunkRetrievalScope = { + "kind": cast(TaskArtifactChunkRetrievalScopeKind, kind), + "task_id": str(task_id), + } + if task_artifact_id is not None: + scope["task_artifact_id"] = str(task_artifact_id) + return scope + + +def _serialize_semantic_artifact_chunk_result_item( + row: TaskArtifactChunkSemanticRetrievalRow, +) -> TaskArtifactChunkSemanticRetrievalItem: + return { + "id": str(row["id"]), + "task_id": str(row["task_id"]), + "task_artifact_id": str(row["task_artifact_id"]), + "relative_path": row["relative_path"], + "media_type": _infer_media_type( + relative_path=row["relative_path"], + media_type_hint=row["media_type_hint"], + ), + "sequence_no": row["sequence_no"], + "char_start": row["char_start"], + "char_end_exclusive": row["char_end_exclusive"], + "text": row["text"], + "score": float(row["score"]), + } + + +def validate_semantic_artifact_chunk_retrieval_request( + store: ContinuityStore, + *, 
+ embedding_config_id: UUID, + query_vector: tuple[float, ...], +) -> tuple[dict[str, object], list[float]]: + return _validate_embedding_config_and_query_vector( + store, + embedding_config_id=embedding_config_id, + query_vector=query_vector, + error_type=SemanticArtifactChunkRetrievalValidationError, + ) + + +def _count_ingested_artifacts(artifact_rows: list[dict[str, object]]) -> int: + return sum(1 for artifact_row in artifact_rows if artifact_row["ingestion_status"] == "ingested") + + +def _build_semantic_artifact_chunk_summary( + *, + embedding_config_id: UUID, + query_vector_dimensions: int, + limit: int, + searched_artifact_count: int, + scope: TaskArtifactChunkRetrievalScope, + items: list[TaskArtifactChunkSemanticRetrievalItem], +) -> TaskArtifactChunkSemanticRetrievalSummary: + return { + "embedding_config_id": str(embedding_config_id), + "query_vector_dimensions": query_vector_dimensions, + "limit": limit, + "returned_count": len(items), + "searched_artifact_count": searched_artifact_count, + "similarity_metric": "cosine_similarity", + "order": list(TASK_ARTIFACT_CHUNK_SEMANTIC_RETRIEVAL_ORDER), + "scope": scope, + } + + +def retrieve_task_scoped_semantic_artifact_chunk_records( + store: ContinuityStore, + *, + user_id: UUID, + request: TaskScopedSemanticArtifactChunkRetrievalInput, +) -> TaskArtifactChunkSemanticRetrievalResponse: + del user_id + + task = store.get_task_optional(request.task_id) + if task is None: + raise TaskNotFoundError(f"task {request.task_id} was not found") + + _config, query_vector = validate_semantic_artifact_chunk_retrieval_request( + store, + embedding_config_id=request.embedding_config_id, + query_vector=request.query_vector, + ) + items = [ + _serialize_semantic_artifact_chunk_result_item(row) + for row in store.retrieve_task_scoped_semantic_artifact_chunk_matches( + task_id=request.task_id, + embedding_config_id=request.embedding_config_id, + query_vector=query_vector, + limit=request.limit, + ) + ] + artifact_rows = store.list_task_artifacts_for_task(request.task_id) + scope = _build_task_artifact_chunk_retrieval_scope( + kind="task", + task_id=request.task_id, + ) + return { + "items": items, + "summary": _build_semantic_artifact_chunk_summary( + embedding_config_id=request.embedding_config_id, + query_vector_dimensions=len(query_vector), + limit=request.limit, + searched_artifact_count=_count_ingested_artifacts(artifact_rows), + scope=scope, + items=items, + ), + } + + +def retrieve_artifact_scoped_semantic_artifact_chunk_records( + store: ContinuityStore, + *, + user_id: UUID, + request: ArtifactScopedSemanticArtifactChunkRetrievalInput, +) -> TaskArtifactChunkSemanticRetrievalResponse: + del user_id + + artifact_row = store.get_task_artifact_optional(request.task_artifact_id) + if artifact_row is None: + raise TaskArtifactNotFoundError(f"task artifact {request.task_artifact_id} was not found") + + _config, query_vector = validate_semantic_artifact_chunk_retrieval_request( + store, + embedding_config_id=request.embedding_config_id, + query_vector=request.query_vector, + ) + items = [ + _serialize_semantic_artifact_chunk_result_item(row) + for row in store.retrieve_artifact_scoped_semantic_artifact_chunk_matches( + task_artifact_id=request.task_artifact_id, + embedding_config_id=request.embedding_config_id, + query_vector=query_vector, + limit=request.limit, + ) + ] + scope = _build_task_artifact_chunk_retrieval_scope( + kind="artifact", + task_id=artifact_row["task_id"], + task_artifact_id=artifact_row["id"], + ) + searched_artifact_count = 1 
if artifact_row["ingestion_status"] == "ingested" else 0 + return { + "items": items, + "summary": _build_semantic_artifact_chunk_summary( + embedding_config_id=request.embedding_config_id, + query_vector_dimensions=len(query_vector), + limit=request.limit, + searched_artifact_count=searched_artifact_count, + scope=scope, + items=items, + ), + } + + def retrieve_semantic_memory_records( store: ContinuityStore, *, diff --git a/apps/api/src/alicebot_api/store.py b/apps/api/src/alicebot_api/store.py index 206d168..c2f0771 100644 --- a/apps/api/src/alicebot_api/store.py +++ b/apps/api/src/alicebot_api/store.py @@ -283,6 +283,23 @@ class TaskArtifactChunkEmbeddingRow(TypedDict): updated_at: datetime +class TaskArtifactChunkSemanticRetrievalRow(TypedDict): + id: UUID + user_id: UUID + task_id: UUID + task_artifact_id: UUID + relative_path: str + media_type_hint: str | None + sequence_no: int + char_start: int + char_end_exclusive: int + text: str + created_at: datetime + updated_at: datetime + embedding_config_id: UUID + score: float + + class TaskStepRow(TypedDict): id: UUID user_id: UUID @@ -806,6 +823,72 @@ class LabelCountRow(TypedDict): LIMIT %s """ +RETRIEVE_TASK_SCOPED_SEMANTIC_ARTIFACT_CHUNK_MATCHES_SQL = """ + SELECT + chunks.id, + chunks.user_id, + artifacts.task_id, + artifacts.id AS task_artifact_id, + artifacts.relative_path, + artifacts.media_type_hint, + chunks.sequence_no, + chunks.char_start, + chunks.char_end_exclusive, + chunks.text, + chunks.created_at, + chunks.updated_at, + embeddings.embedding_config_id, + 1 - ( + replace(embeddings.vector::text, ' ', '')::vector <=> %s::vector + ) AS score + FROM task_artifact_chunk_embeddings AS embeddings + JOIN task_artifact_chunks AS chunks + ON chunks.id = embeddings.task_artifact_chunk_id + AND chunks.user_id = embeddings.user_id + JOIN task_artifacts AS artifacts + ON artifacts.id = chunks.task_artifact_id + AND artifacts.user_id = chunks.user_id + WHERE embeddings.embedding_config_id = %s + AND embeddings.dimensions = %s + AND artifacts.task_id = %s + AND artifacts.ingestion_status = 'ingested' + ORDER BY score DESC, artifacts.relative_path ASC, chunks.sequence_no ASC, chunks.id ASC + LIMIT %s + """ + +RETRIEVE_ARTIFACT_SCOPED_SEMANTIC_ARTIFACT_CHUNK_MATCHES_SQL = """ + SELECT + chunks.id, + chunks.user_id, + artifacts.task_id, + artifacts.id AS task_artifact_id, + artifacts.relative_path, + artifacts.media_type_hint, + chunks.sequence_no, + chunks.char_start, + chunks.char_end_exclusive, + chunks.text, + chunks.created_at, + chunks.updated_at, + embeddings.embedding_config_id, + 1 - ( + replace(embeddings.vector::text, ' ', '')::vector <=> %s::vector + ) AS score + FROM task_artifact_chunk_embeddings AS embeddings + JOIN task_artifact_chunks AS chunks + ON chunks.id = embeddings.task_artifact_chunk_id + AND chunks.user_id = embeddings.user_id + JOIN task_artifacts AS artifacts + ON artifacts.id = chunks.task_artifact_id + AND artifacts.user_id = chunks.user_id + WHERE embeddings.embedding_config_id = %s + AND embeddings.dimensions = %s + AND artifacts.id = %s + AND artifacts.ingestion_status = 'ingested' + ORDER BY score DESC, artifacts.relative_path ASC, chunks.sequence_no ASC, chunks.id ASC + LIMIT %s + """ + INSERT_ENTITY_SQL = """ INSERT INTO entities (user_id, entity_type, name, source_memory_ids, created_at) VALUES (app.current_user_id(), %s, %s, %s, clock_timestamp()) @@ -2579,6 +2662,44 @@ def retrieve_semantic_memory_matches( ), ) + def retrieve_task_scoped_semantic_artifact_chunk_matches( + self, + *, + task_id: UUID, + 
embedding_config_id: UUID, + query_vector: list[float], + limit: int, + ) -> list[TaskArtifactChunkSemanticRetrievalRow]: + return self._fetch_all( + RETRIEVE_TASK_SCOPED_SEMANTIC_ARTIFACT_CHUNK_MATCHES_SQL, + ( + self._vector_literal(query_vector), + embedding_config_id, + len(query_vector), + task_id, + limit, + ), + ) + + def retrieve_artifact_scoped_semantic_artifact_chunk_matches( + self, + *, + task_artifact_id: UUID, + embedding_config_id: UUID, + query_vector: list[float], + limit: int, + ) -> list[TaskArtifactChunkSemanticRetrievalRow]: + return self._fetch_all( + RETRIEVE_ARTIFACT_SCOPED_SEMANTIC_ARTIFACT_CHUNK_MATCHES_SQL, + ( + self._vector_literal(query_vector), + embedding_config_id, + len(query_vector), + task_artifact_id, + limit, + ), + ) + def create_entity( self, *, diff --git a/tests/integration/test_semantic_artifact_chunk_retrieval_api.py b/tests/integration/test_semantic_artifact_chunk_retrieval_api.py new file mode 100644 index 0000000..7ee8cbd --- /dev/null +++ b/tests/integration/test_semantic_artifact_chunk_retrieval_api.py @@ -0,0 +1,569 @@ +from __future__ import annotations + +import json +from typing import Any +from urllib.parse import urlencode +from uuid import UUID, uuid4 + +import anyio +import pytest + +import apps.api.src.alicebot_api.main as main_module +from apps.api.src.alicebot_api.config import Settings +from alicebot_api.db import user_connection +from alicebot_api.store import ContinuityStore + + +def invoke_request( + method: str, + path: str, + *, + query_params: dict[str, str] | None = None, + payload: dict[str, Any] | None = None, +) -> tuple[int, dict[str, Any]]: + messages: list[dict[str, object]] = [] + encoded_body = b"" if payload is None else json.dumps(payload).encode() + request_received = False + + async def receive() -> dict[str, object]: + nonlocal request_received + if request_received: + return {"type": "http.disconnect"} + + request_received = True + return {"type": "http.request", "body": encoded_body, "more_body": False} + + async def send(message: dict[str, object]) -> None: + messages.append(message) + + query_string = urlencode(query_params or {}).encode() + scope = { + "type": "http", + "asgi": {"version": "3.0"}, + "http_version": "1.1", + "method": method, + "scheme": "http", + "path": path, + "raw_path": path.encode(), + "query_string": query_string, + "headers": [(b"content-type", b"application/json")], + "client": ("testclient", 50000), + "server": ("testserver", 80), + "root_path": "", + } + + anyio.run(main_module.app, scope, receive, send) + + start_message = next(message for message in messages if message["type"] == "http.response.start") + body = b"".join( + message.get("body", b"") + for message in messages + if message["type"] == "http.response.body" + ) + return start_message["status"], json.loads(body) + + +def seed_task_with_workspace(database_url: str, *, email: str) -> dict[str, UUID]: + user_id = uuid4() + + with user_connection(database_url, user_id) as conn: + store = ContinuityStore(conn) + store.create_user(user_id, email, email.split("@", 1)[0].title()) + thread = store.create_thread("Semantic artifact retrieval thread") + tool = store.create_tool( + tool_key="proxy.echo", + name="Proxy Echo", + description="Deterministic proxy handler.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["proxy"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=[], + risk_hints=[], + metadata={"transport": "proxy"}, + ) + task = store.create_task( + 
thread_id=thread["id"], + tool_id=tool["id"], + status="approved", + request={ + "thread_id": str(thread["id"]), + "tool_id": str(tool["id"]), + "action": "tool.run", + "scope": "workspace", + "domain_hint": None, + "risk_hint": None, + "attributes": {}, + }, + tool={ + "id": str(tool["id"]), + "tool_key": "proxy.echo", + "name": "Proxy Echo", + "description": "Deterministic proxy handler.", + "version": "1.0.0", + "metadata_version": "tool_metadata_v0", + "active": True, + "tags": ["proxy"], + "action_hints": ["tool.run"], + "scope_hints": ["workspace"], + "domain_hints": [], + "risk_hints": [], + "metadata": {"transport": "proxy"}, + "created_at": tool["created_at"].isoformat(), + }, + latest_approval_id=None, + latest_execution_id=None, + ) + workspace = store.create_task_workspace( + task_id=task["id"], + status="active", + local_path=f"/tmp/task-workspaces/{user_id}/{task['id']}", + ) + + return { + "user_id": user_id, + "task_id": task["id"], + "task_workspace_id": workspace["id"], + } + + +def seed_embedding_config( + database_url: str, + *, + user_id: UUID, + provider: str, + model: str, + version: str, + dimensions: int, +) -> UUID: + with user_connection(database_url, user_id) as conn: + created = ContinuityStore(conn).create_embedding_config( + provider=provider, + model=model, + version=version, + dimensions=dimensions, + status="active", + metadata={"task": "semantic_artifact_chunk_retrieval"}, + ) + return created["id"] + + +def create_artifact_with_chunk_embeddings( + database_url: str, + *, + user_id: UUID, + task_id: UUID, + task_workspace_id: UUID, + embedding_config_id: UUID | None, + relative_path: str, + chunks: list[tuple[str, list[float] | None]], + ingestion_status: str = "ingested", + media_type_hint: str | None = "text/plain", +) -> dict[str, object]: + with user_connection(database_url, user_id) as conn: + store = ContinuityStore(conn) + artifact = store.create_task_artifact( + task_id=task_id, + task_workspace_id=task_workspace_id, + status="registered", + ingestion_status=ingestion_status, + relative_path=relative_path, + media_type_hint=media_type_hint, + ) + created_chunks: list[dict[str, object]] = [] + char_start = 0 + for sequence_no, (text, vector) in enumerate(chunks, start=1): + chunk = store.create_task_artifact_chunk( + task_artifact_id=artifact["id"], + sequence_no=sequence_no, + char_start=char_start, + char_end_exclusive=char_start + len(text), + text=text, + ) + char_start += len(text) + created_chunks.append(chunk) + if embedding_config_id is not None and vector is not None: + store.create_task_artifact_chunk_embedding( + task_artifact_chunk_id=chunk["id"], + embedding_config_id=embedding_config_id, + dimensions=len(vector), + vector=vector, + ) + + return { + "artifact_id": artifact["id"], + "chunk_ids": [chunk["id"] for chunk in created_chunks], + } + + +def test_semantic_artifact_chunk_retrieval_endpoints_return_deterministic_task_and_artifact_results( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_task_with_workspace(migrated_database_urls["app"], email="owner@example.com") + config_id = seed_embedding_config( + migrated_database_urls["app"], + user_id=owner["user_id"], + provider="openai", + model="text-embedding-3-large", + version="2026-03-15", + dimensions=3, + ) + docs = create_artifact_with_chunk_embeddings( + migrated_database_urls["app"], + user_id=owner["user_id"], + task_id=owner["task_id"], + task_workspace_id=owner["task_workspace_id"], + embedding_config_id=config_id, + relative_path="docs/a.txt", + 
chunks=[("alpha doc", [1.0, 0.0, 0.0])], + media_type_hint="text/plain", + ) + notes = create_artifact_with_chunk_embeddings( + migrated_database_urls["app"], + user_id=owner["user_id"], + task_id=owner["task_id"], + task_workspace_id=owner["task_workspace_id"], + embedding_config_id=config_id, + relative_path="notes/b.md", + chunks=[("alpha note", [1.0, 0.0, 0.0])], + media_type_hint="text/markdown", + ) + weak = create_artifact_with_chunk_embeddings( + migrated_database_urls["app"], + user_id=owner["user_id"], + task_id=owner["task_id"], + task_workspace_id=owner["task_workspace_id"], + embedding_config_id=config_id, + relative_path="notes/c.txt", + chunks=[("beta weak", [0.0, 1.0, 0.0])], + media_type_hint="text/plain", + ) + pending = create_artifact_with_chunk_embeddings( + migrated_database_urls["app"], + user_id=owner["user_id"], + task_id=owner["task_id"], + task_workspace_id=owner["task_workspace_id"], + embedding_config_id=config_id, + relative_path="notes/pending.txt", + chunks=[("hidden pending", [1.0, 0.0, 0.0])], + ingestion_status="pending", + media_type_hint="text/plain", + ) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + task_status, task_payload = invoke_request( + "POST", + f"/v0/tasks/{owner['task_id']}/artifact-chunks/semantic-retrieval", + payload={ + "user_id": str(owner["user_id"]), + "embedding_config_id": str(config_id), + "query_vector": [1.0, 0.0, 0.0], + "limit": 10, + }, + ) + artifact_status, artifact_payload = invoke_request( + "POST", + f"/v0/task-artifacts/{notes['artifact_id']}/chunks/semantic-retrieval", + payload={ + "user_id": str(owner["user_id"]), + "embedding_config_id": str(config_id), + "query_vector": [1.0, 0.0, 0.0], + "limit": 10, + }, + ) + + assert task_status == 200 + assert task_payload["summary"] == { + "embedding_config_id": str(config_id), + "query_vector_dimensions": 3, + "limit": 10, + "returned_count": 3, + "searched_artifact_count": 3, + "similarity_metric": "cosine_similarity", + "order": ["score_desc", "relative_path_asc", "sequence_no_asc", "id_asc"], + "scope": {"kind": "task", "task_id": str(owner["task_id"])}, + } + assert [item["id"] for item in task_payload["items"]] == [ + str(docs["chunk_ids"][0]), + str(notes["chunk_ids"][0]), + str(weak["chunk_ids"][0]), + ] + assert str(pending["chunk_ids"][0]) not in {item["id"] for item in task_payload["items"]} + assert task_payload["items"][0]["relative_path"] == "docs/a.txt" + assert task_payload["items"][1]["relative_path"] == "notes/b.md" + assert task_payload["items"][0]["score"] == pytest.approx(1.0) + assert task_payload["items"][1]["score"] == pytest.approx(1.0) + assert task_payload["items"][2]["score"] == pytest.approx(0.0) + assert set(task_payload["items"][0]) == { + "id", + "task_id", + "task_artifact_id", + "relative_path", + "media_type", + "sequence_no", + "char_start", + "char_end_exclusive", + "text", + "score", + } + + assert artifact_status == 200 + assert artifact_payload["summary"] == { + "embedding_config_id": str(config_id), + "query_vector_dimensions": 3, + "limit": 10, + "returned_count": 1, + "searched_artifact_count": 1, + "similarity_metric": "cosine_similarity", + "order": ["score_desc", "relative_path_asc", "sequence_no_asc", "id_asc"], + "scope": { + "kind": "artifact", + "task_id": str(owner["task_id"]), + "task_artifact_id": str(notes["artifact_id"]), + }, + } + assert artifact_payload["items"] == [ + { + "id": str(notes["chunk_ids"][0]), + "task_id": 
str(owner["task_id"]), + "task_artifact_id": str(notes["artifact_id"]), + "relative_path": "notes/b.md", + "media_type": "text/markdown", + "sequence_no": 1, + "char_start": 0, + "char_end_exclusive": len("alpha note"), + "text": "alpha note", + "score": artifact_payload["items"][0]["score"], + } + ] + assert artifact_payload["items"][0]["score"] == pytest.approx(1.0) + + +def test_semantic_artifact_chunk_retrieval_rejects_invalid_config_dimension_mismatch_and_cross_user_scope( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_task_with_workspace(migrated_database_urls["app"], email="owner@example.com") + intruder = seed_task_with_workspace(migrated_database_urls["app"], email="intruder@example.com") + owner_config_id = seed_embedding_config( + migrated_database_urls["app"], + user_id=owner["user_id"], + provider="openai", + model="text-embedding-3-large", + version="2026-03-15", + dimensions=3, + ) + intruder_config_id = seed_embedding_config( + migrated_database_urls["app"], + user_id=intruder["user_id"], + provider="openai", + model="text-embedding-3-large", + version="2026-03-15", + dimensions=3, + ) + owner_artifact = create_artifact_with_chunk_embeddings( + migrated_database_urls["app"], + user_id=owner["user_id"], + task_id=owner["task_id"], + task_workspace_id=owner["task_workspace_id"], + embedding_config_id=owner_config_id, + relative_path="docs/spec.txt", + chunks=[("owner chunk", [1.0, 0.0, 0.0])], + ) + create_artifact_with_chunk_embeddings( + migrated_database_urls["app"], + user_id=intruder["user_id"], + task_id=intruder["task_id"], + task_workspace_id=intruder["task_workspace_id"], + embedding_config_id=intruder_config_id, + relative_path="docs/intruder.txt", + chunks=[("intruder chunk", [1.0, 0.0, 0.0])], + ) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + missing_status, missing_payload = invoke_request( + "POST", + f"/v0/tasks/{owner['task_id']}/artifact-chunks/semantic-retrieval", + payload={ + "user_id": str(owner["user_id"]), + "embedding_config_id": str(uuid4()), + "query_vector": [1.0, 0.0, 0.0], + "limit": 5, + }, + ) + mismatch_status, mismatch_payload = invoke_request( + "POST", + f"/v0/tasks/{owner['task_id']}/artifact-chunks/semantic-retrieval", + payload={ + "user_id": str(owner["user_id"]), + "embedding_config_id": str(owner_config_id), + "query_vector": [1.0, 0.0], + "limit": 5, + }, + ) + cross_user_task_status, cross_user_task_payload = invoke_request( + "POST", + f"/v0/tasks/{owner['task_id']}/artifact-chunks/semantic-retrieval", + payload={ + "user_id": str(intruder["user_id"]), + "embedding_config_id": str(intruder_config_id), + "query_vector": [1.0, 0.0, 0.0], + "limit": 5, + }, + ) + cross_user_artifact_status, cross_user_artifact_payload = invoke_request( + "POST", + f"/v0/task-artifacts/{owner_artifact['artifact_id']}/chunks/semantic-retrieval", + payload={ + "user_id": str(intruder["user_id"]), + "embedding_config_id": str(intruder_config_id), + "query_vector": [1.0, 0.0, 0.0], + "limit": 5, + }, + ) + cross_user_config_status, cross_user_config_payload = invoke_request( + "POST", + f"/v0/tasks/{owner['task_id']}/artifact-chunks/semantic-retrieval", + payload={ + "user_id": str(owner["user_id"]), + "embedding_config_id": str(intruder_config_id), + "query_vector": [1.0, 0.0, 0.0], + "limit": 5, + }, + ) + + assert missing_status == 400 + assert missing_payload["detail"].startswith( + "embedding_config_id must reference an existing embedding config owned 
by the user" + ) + assert mismatch_status == 400 + assert mismatch_payload["detail"] == "query_vector length must match embedding config dimensions (3): 2" + assert cross_user_task_status == 404 + assert cross_user_task_payload == {"detail": f"task {owner['task_id']} was not found"} + assert cross_user_artifact_status == 404 + assert cross_user_artifact_payload == { + "detail": f"task artifact {owner_artifact['artifact_id']} was not found" + } + assert cross_user_config_status == 400 + assert cross_user_config_payload["detail"] == ( + "embedding_config_id must reference an existing embedding config owned by the user: " + f"{intruder_config_id}" + ) + + +def test_semantic_artifact_chunk_retrieval_supports_empty_results_and_per_user_isolation( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_task_with_workspace(migrated_database_urls["app"], email="owner@example.com") + intruder = seed_task_with_workspace(migrated_database_urls["app"], email="intruder@example.com") + owner_config_id = seed_embedding_config( + migrated_database_urls["app"], + user_id=owner["user_id"], + provider="openai", + model="text-embedding-3-large", + version="2026-03-15", + dimensions=3, + ) + owner_empty_config_id = seed_embedding_config( + migrated_database_urls["app"], + user_id=owner["user_id"], + provider="openai", + model="text-embedding-3-small", + version="2026-03-15", + dimensions=3, + ) + intruder_config_id = seed_embedding_config( + migrated_database_urls["app"], + user_id=intruder["user_id"], + provider="openai", + model="text-embedding-3-large", + version="2026-03-15", + dimensions=3, + ) + owner_artifact = create_artifact_with_chunk_embeddings( + migrated_database_urls["app"], + user_id=owner["user_id"], + task_id=owner["task_id"], + task_workspace_id=owner["task_workspace_id"], + embedding_config_id=owner_config_id, + relative_path="docs/owner.txt", + chunks=[("owner semantic", [1.0, 0.0, 0.0])], + ) + intruder_artifact = create_artifact_with_chunk_embeddings( + migrated_database_urls["app"], + user_id=intruder["user_id"], + task_id=intruder["task_id"], + task_workspace_id=intruder["task_workspace_id"], + embedding_config_id=intruder_config_id, + relative_path="docs/intruder.txt", + chunks=[("intruder semantic", [1.0, 0.0, 0.0])], + ) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + owner_status, owner_payload = invoke_request( + "POST", + f"/v0/tasks/{owner['task_id']}/artifact-chunks/semantic-retrieval", + payload={ + "user_id": str(owner["user_id"]), + "embedding_config_id": str(owner_config_id), + "query_vector": [1.0, 0.0, 0.0], + "limit": 5, + }, + ) + intruder_status, intruder_payload = invoke_request( + "POST", + f"/v0/tasks/{intruder['task_id']}/artifact-chunks/semantic-retrieval", + payload={ + "user_id": str(intruder["user_id"]), + "embedding_config_id": str(intruder_config_id), + "query_vector": [1.0, 0.0, 0.0], + "limit": 5, + }, + ) + empty_status, empty_payload = invoke_request( + "POST", + f"/v0/tasks/{owner['task_id']}/artifact-chunks/semantic-retrieval", + payload={ + "user_id": str(owner["user_id"]), + "embedding_config_id": str(owner_empty_config_id), + "query_vector": [1.0, 0.0, 0.0], + "limit": 5, + }, + ) + + assert owner_status == 200 + assert [item["id"] for item in owner_payload["items"]] == [str(owner_artifact["chunk_ids"][0])] + assert intruder_status == 200 + assert [item["id"] for item in intruder_payload["items"]] == [ + str(intruder_artifact["chunk_ids"][0]) + ] + assert 
empty_status == 200 + assert empty_payload == { + "items": [], + "summary": { + "embedding_config_id": str(owner_empty_config_id), + "query_vector_dimensions": 3, + "limit": 5, + "returned_count": 0, + "searched_artifact_count": 1, + "similarity_metric": "cosine_similarity", + "order": ["score_desc", "relative_path_asc", "sequence_no_asc", "id_asc"], + "scope": {"kind": "task", "task_id": str(owner["task_id"])}, + }, + } diff --git a/tests/unit/test_artifacts_main.py b/tests/unit/test_artifacts_main.py index a009b60..634d9b6 100644 --- a/tests/unit/test_artifacts_main.py +++ b/tests/unit/test_artifacts_main.py @@ -12,6 +12,7 @@ TaskArtifactNotFoundError, TaskArtifactValidationError, ) +from alicebot_api.semantic_retrieval import SemanticArtifactChunkRetrievalValidationError from alicebot_api.tasks import TaskNotFoundError from alicebot_api.workspaces import TaskWorkspaceNotFoundError @@ -229,6 +230,141 @@ def fake_retrieve_task_scoped_artifact_chunk_records(*_args, **_kwargs): } +def test_retrieve_semantic_task_artifact_chunks_endpoint_returns_payload(monkeypatch) -> None: + user_id = uuid4() + task_id = uuid4() + config_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr( + main_module, + "retrieve_task_scoped_semantic_artifact_chunk_records", + lambda *_args, **_kwargs: { + "items": [], + "summary": { + "embedding_config_id": str(config_id), + "query_vector_dimensions": 3, + "limit": 5, + "returned_count": 0, + "searched_artifact_count": 1, + "similarity_metric": "cosine_similarity", + "order": ["score_desc", "relative_path_asc", "sequence_no_asc", "id_asc"], + "scope": {"kind": "task", "task_id": str(task_id)}, + }, + }, + ) + + response = main_module.retrieve_semantic_task_artifact_chunks( + task_id, + main_module.RetrieveSemanticArtifactChunksRequest( + user_id=user_id, + embedding_config_id=config_id, + query_vector=[1.0, 0.0, 0.0], + limit=5, + ), + ) + + assert response.status_code == 200 + assert json.loads(response.body) == { + "items": [], + "summary": { + "embedding_config_id": str(config_id), + "query_vector_dimensions": 3, + "limit": 5, + "returned_count": 0, + "searched_artifact_count": 1, + "similarity_metric": "cosine_similarity", + "order": ["score_desc", "relative_path_asc", "sequence_no_asc", "id_asc"], + "scope": {"kind": "task", "task_id": str(task_id)}, + }, + } + + +def test_retrieve_semantic_task_artifact_chunks_endpoint_maps_validation_to_400(monkeypatch) -> None: + user_id = uuid4() + task_id = uuid4() + config_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_retrieve_task_scoped_semantic_artifact_chunk_records(*_args, **_kwargs): + raise SemanticArtifactChunkRetrievalValidationError( + f"embedding_config_id must reference an existing embedding config owned by the user: {config_id}" + ) + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr( + main_module, + "retrieve_task_scoped_semantic_artifact_chunk_records", + fake_retrieve_task_scoped_semantic_artifact_chunk_records, + ) + + response = main_module.retrieve_semantic_task_artifact_chunks( + task_id, + 
main_module.RetrieveSemanticArtifactChunksRequest( + user_id=user_id, + embedding_config_id=config_id, + query_vector=[1.0, 0.0, 0.0], + limit=5, + ), + ) + + assert response.status_code == 400 + assert json.loads(response.body) == { + "detail": ( + "embedding_config_id must reference an existing embedding config owned by the user: " + f"{config_id}" + ) + } + + +def test_retrieve_semantic_artifact_chunk_endpoint_maps_not_found_to_404(monkeypatch) -> None: + user_id = uuid4() + task_artifact_id = uuid4() + config_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_retrieve_artifact_scoped_semantic_artifact_chunk_records(*_args, **_kwargs): + raise TaskArtifactNotFoundError(f"task artifact {task_artifact_id} was not found") + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr( + main_module, + "retrieve_artifact_scoped_semantic_artifact_chunk_records", + fake_retrieve_artifact_scoped_semantic_artifact_chunk_records, + ) + + response = main_module.retrieve_semantic_artifact_chunks_for_artifact( + task_artifact_id, + main_module.RetrieveSemanticArtifactChunksRequest( + user_id=user_id, + embedding_config_id=config_id, + query_vector=[1.0, 0.0, 0.0], + limit=5, + ), + ) + + assert response.status_code == 404 + assert json.loads(response.body) == { + "detail": f"task artifact {task_artifact_id} was not found" + } + + def test_retrieve_artifact_chunk_endpoint_maps_not_found_to_404(monkeypatch) -> None: user_id = uuid4() task_artifact_id = uuid4() diff --git a/tests/unit/test_main.py b/tests/unit/test_main.py index 9c7a926..0b3441d 100644 --- a/tests/unit/test_main.py +++ b/tests/unit/test_main.py @@ -139,6 +139,8 @@ def test_healthcheck_route_is_registered() -> None: assert "/v0/task-artifacts/{task_artifact_id}" in route_paths assert "/v0/task-artifacts/{task_artifact_id}/ingest" in route_paths assert "/v0/task-artifacts/{task_artifact_id}/chunks" in route_paths + assert "/v0/tasks/{task_id}/artifact-chunks/semantic-retrieval" in route_paths + assert "/v0/task-artifacts/{task_artifact_id}/chunks/semantic-retrieval" in route_paths assert "/v0/task-steps/{task_step_id}" in route_paths assert "/v0/task-steps/{task_step_id}/transition" in route_paths assert "/v0/entities/{entity_id}" in route_paths diff --git a/tests/unit/test_semantic_retrieval.py b/tests/unit/test_semantic_retrieval.py index 780b4e4..7f3b26f 100644 --- a/tests/unit/test_semantic_retrieval.py +++ b/tests/unit/test_semantic_retrieval.py @@ -5,11 +5,19 @@ import pytest -from alicebot_api.contracts import SemanticMemoryRetrievalRequestInput +from alicebot_api.contracts import ( + ArtifactScopedSemanticArtifactChunkRetrievalInput, + SemanticMemoryRetrievalRequestInput, + TaskScopedSemanticArtifactChunkRetrievalInput, +) from alicebot_api.semantic_retrieval import ( + SemanticArtifactChunkRetrievalValidationError, SemanticMemoryRetrievalValidationError, + retrieve_artifact_scoped_semantic_artifact_chunk_records, retrieve_semantic_memory_records, + retrieve_task_scoped_semantic_artifact_chunk_records, ) +from alicebot_api.tasks import TaskNotFoundError class SemanticRetrievalStoreStub: @@ -17,6 +25,10 @@ def __init__(self) -> None: self.base_time = datetime(2026, 3, 12, 9, 0, tzinfo=UTC) self.config_by_id: dict[UUID, dict[str, object]] = {} self.retrieval_rows: list[dict[str, object]] = [] + 
self.task_artifact_retrieval_rows: list[dict[str, object]] = [] + self.tasks: dict[UUID, dict[str, object]] = {} + self.artifacts_by_id: dict[UUID, dict[str, object]] = {} + self.artifacts_by_task_id: dict[UUID, list[dict[str, object]]] = {} self.last_query: dict[str, object] | None = None def get_embedding_config_optional(self, embedding_config_id: UUID) -> dict[str, object] | None: @@ -36,6 +48,49 @@ def retrieve_semantic_memory_matches( } return list(self.retrieval_rows[:limit]) + def get_task_optional(self, task_id: UUID) -> dict[str, object] | None: + return self.tasks.get(task_id) + + def get_task_artifact_optional(self, task_artifact_id: UUID) -> dict[str, object] | None: + return self.artifacts_by_id.get(task_artifact_id) + + def list_task_artifacts_for_task(self, task_id: UUID) -> list[dict[str, object]]: + return list(self.artifacts_by_task_id.get(task_id, [])) + + def retrieve_task_scoped_semantic_artifact_chunk_matches( + self, + *, + task_id: UUID, + embedding_config_id: UUID, + query_vector: list[float], + limit: int, + ) -> list[dict[str, object]]: + self.last_query = { + "scope": "task", + "task_id": task_id, + "embedding_config_id": embedding_config_id, + "query_vector": query_vector, + "limit": limit, + } + return list(self.task_artifact_retrieval_rows[:limit]) + + def retrieve_artifact_scoped_semantic_artifact_chunk_matches( + self, + *, + task_artifact_id: UUID, + embedding_config_id: UUID, + query_vector: list[float], + limit: int, + ) -> list[dict[str, object]]: + self.last_query = { + "scope": "artifact", + "task_artifact_id": task_artifact_id, + "embedding_config_id": embedding_config_id, + "query_vector": query_vector, + "limit": limit, + } + return list(self.task_artifact_retrieval_rows[:limit]) + def seed_config(store: SemanticRetrievalStoreStub, *, dimensions: int = 3) -> UUID: config_id = uuid4() @@ -67,6 +122,60 @@ def active_row( } +def seed_task(store: SemanticRetrievalStoreStub) -> UUID: + task_id = uuid4() + store.tasks[task_id] = {"id": task_id} + return task_id + + +def seed_artifact( + store: SemanticRetrievalStoreStub, + *, + task_id: UUID, + ingestion_status: str = "ingested", + relative_path: str = "docs/spec.txt", + media_type_hint: str | None = "text/plain", +) -> UUID: + task_artifact_id = uuid4() + artifact = { + "id": task_artifact_id, + "task_id": task_id, + "ingestion_status": ingestion_status, + "relative_path": relative_path, + "media_type_hint": media_type_hint, + } + store.artifacts_by_id[task_artifact_id] = artifact + store.artifacts_by_task_id.setdefault(task_id, []).append(artifact) + return task_artifact_id + + +def semantic_artifact_row( + store: SemanticRetrievalStoreStub, + *, + task_id: UUID, + task_artifact_id: UUID, + relative_path: str, + score: float, + sequence_no: int, +) -> dict[str, object]: + return { + "id": uuid4(), + "user_id": uuid4(), + "task_id": task_id, + "task_artifact_id": task_artifact_id, + "relative_path": relative_path, + "media_type_hint": "text/plain", + "sequence_no": sequence_no, + "char_start": 0, + "char_end_exclusive": 11, + "text": f"{relative_path}-chunk", + "created_at": store.base_time + timedelta(minutes=sequence_no), + "updated_at": store.base_time + timedelta(minutes=sequence_no + 1), + "embedding_config_id": uuid4(), + "score": score, + } + + def test_retrieve_semantic_memory_records_returns_stable_shape_and_summary() -> None: store = SemanticRetrievalStoreStub() config_id = seed_config(store, dimensions=3) @@ -174,3 +283,181 @@ def 
test_retrieve_semantic_memory_records_rejects_non_active_memory_rows() -> No query_vector=(0.1, 0.2, 0.3), ), ) + + +def test_retrieve_task_scoped_semantic_artifact_chunk_records_returns_stable_shape_and_summary() -> None: + store = SemanticRetrievalStoreStub() + config_id = seed_config(store, dimensions=3) + task_id = seed_task(store) + first_artifact_id = seed_artifact( + store, + task_id=task_id, + relative_path="docs/a.txt", + ) + second_artifact_id = seed_artifact( + store, + task_id=task_id, + relative_path="notes/b.txt", + ) + pending_artifact_id = seed_artifact( + store, + task_id=task_id, + ingestion_status="pending", + relative_path="notes/pending.txt", + ) + first_row = semantic_artifact_row( + store, + task_id=task_id, + task_artifact_id=first_artifact_id, + relative_path="docs/a.txt", + score=1.0, + sequence_no=1, + ) + second_row = semantic_artifact_row( + store, + task_id=task_id, + task_artifact_id=second_artifact_id, + relative_path="notes/b.txt", + score=0.25, + sequence_no=1, + ) + store.task_artifact_retrieval_rows = [first_row, second_row] + + payload = retrieve_task_scoped_semantic_artifact_chunk_records( + store, # type: ignore[arg-type] + user_id=uuid4(), + request=TaskScopedSemanticArtifactChunkRetrievalInput( + task_id=task_id, + embedding_config_id=config_id, + query_vector=(1.0, 0.0, 0.0), + limit=2, + ), + ) + + assert payload == { + "items": [ + { + "id": str(first_row["id"]), + "task_id": str(task_id), + "task_artifact_id": str(first_artifact_id), + "relative_path": "docs/a.txt", + "media_type": "text/plain", + "sequence_no": 1, + "char_start": 0, + "char_end_exclusive": 11, + "text": "docs/a.txt-chunk", + "score": 1.0, + }, + { + "id": str(second_row["id"]), + "task_id": str(task_id), + "task_artifact_id": str(second_artifact_id), + "relative_path": "notes/b.txt", + "media_type": "text/plain", + "sequence_no": 1, + "char_start": 0, + "char_end_exclusive": 11, + "text": "notes/b.txt-chunk", + "score": 0.25, + }, + ], + "summary": { + "embedding_config_id": str(config_id), + "query_vector_dimensions": 3, + "limit": 2, + "returned_count": 2, + "searched_artifact_count": 2, + "similarity_metric": "cosine_similarity", + "order": ["score_desc", "relative_path_asc", "sequence_no_asc", "id_asc"], + "scope": {"kind": "task", "task_id": str(task_id)}, + }, + } + assert pending_artifact_id in store.artifacts_by_id + assert store.last_query == { + "scope": "task", + "task_id": task_id, + "embedding_config_id": config_id, + "query_vector": [1.0, 0.0, 0.0], + "limit": 2, + } + + +def test_retrieve_task_scoped_semantic_artifact_chunk_records_rejects_missing_task_and_dimension_mismatch() -> None: + store = SemanticRetrievalStoreStub() + config_id = seed_config(store, dimensions=3) + + with pytest.raises(TaskNotFoundError, match="task .* was not found"): + retrieve_task_scoped_semantic_artifact_chunk_records( + store, # type: ignore[arg-type] + user_id=uuid4(), + request=TaskScopedSemanticArtifactChunkRetrievalInput( + task_id=uuid4(), + embedding_config_id=config_id, + query_vector=(1.0, 0.0, 0.0), + ), + ) + + task_id = seed_task(store) + seed_artifact(store, task_id=task_id) + with pytest.raises( + SemanticArtifactChunkRetrievalValidationError, + match="query_vector length must match embedding config dimensions \\(3\\): 2", + ): + retrieve_task_scoped_semantic_artifact_chunk_records( + store, # type: ignore[arg-type] + user_id=uuid4(), + request=TaskScopedSemanticArtifactChunkRetrievalInput( + task_id=task_id, + embedding_config_id=config_id, + query_vector=(1.0, 0.0), + 
), + ) + + +def test_retrieve_artifact_scoped_semantic_artifact_chunk_records_returns_empty_for_pending_artifact() -> None: + store = SemanticRetrievalStoreStub() + config_id = seed_config(store, dimensions=3) + task_id = seed_task(store) + artifact_id = seed_artifact( + store, + task_id=task_id, + ingestion_status="pending", + relative_path="notes/pending.txt", + media_type_hint="text/markdown", + ) + + payload = retrieve_artifact_scoped_semantic_artifact_chunk_records( + store, # type: ignore[arg-type] + user_id=uuid4(), + request=ArtifactScopedSemanticArtifactChunkRetrievalInput( + task_artifact_id=artifact_id, + embedding_config_id=config_id, + query_vector=(0.0, 1.0, 0.0), + limit=5, + ), + ) + + assert payload == { + "items": [], + "summary": { + "embedding_config_id": str(config_id), + "query_vector_dimensions": 3, + "limit": 5, + "returned_count": 0, + "searched_artifact_count": 0, + "similarity_metric": "cosine_similarity", + "order": ["score_desc", "relative_path_asc", "sequence_no_asc", "id_asc"], + "scope": { + "kind": "artifact", + "task_id": str(task_id), + "task_artifact_id": str(artifact_id), + }, + }, + } + assert store.last_query == { + "scope": "artifact", + "task_artifact_id": artifact_id, + "embedding_config_id": config_id, + "query_vector": [0.0, 1.0, 0.0], + "limit": 5, + } diff --git a/tests/unit/test_task_artifact_chunk_embedding_store.py b/tests/unit/test_task_artifact_chunk_embedding_store.py index 227a191..08704fb 100644 --- a/tests/unit/test_task_artifact_chunk_embedding_store.py +++ b/tests/unit/test_task_artifact_chunk_embedding_store.py @@ -234,3 +234,85 @@ def test_task_artifact_chunk_embedding_store_optional_reads_return_none_when_row task_artifact_chunk_id=uuid4(), embedding_config_id=uuid4(), ) is None + + +def test_semantic_artifact_chunk_retrieval_store_methods_use_expected_queries() -> None: + task_id = uuid4() + task_artifact_id = uuid4() + task_artifact_chunk_id = uuid4() + embedding_config_id = uuid4() + created_at = datetime(2026, 3, 15, 9, 0, tzinfo=UTC) + cursor = RecordingCursor( + fetchone_results=[], + fetchall_results=[ + [ + { + "id": task_artifact_chunk_id, + "user_id": uuid4(), + "task_id": task_id, + "task_artifact_id": task_artifact_id, + "relative_path": "docs/spec.txt", + "media_type_hint": "text/plain", + "sequence_no": 1, + "char_start": 0, + "char_end_exclusive": 11, + "text": "alpha chunk", + "created_at": created_at, + "updated_at": created_at, + "embedding_config_id": embedding_config_id, + "score": 1.0, + } + ], + [ + { + "id": task_artifact_chunk_id, + "user_id": uuid4(), + "task_id": task_id, + "task_artifact_id": task_artifact_id, + "relative_path": "docs/spec.txt", + "media_type_hint": "text/plain", + "sequence_no": 1, + "char_start": 0, + "char_end_exclusive": 11, + "text": "alpha chunk", + "created_at": created_at, + "updated_at": created_at, + "embedding_config_id": embedding_config_id, + "score": 1.0, + } + ], + ], + ) + store = ContinuityStore(RecordingConnection(cursor)) + + task_rows = store.retrieve_task_scoped_semantic_artifact_chunk_matches( + task_id=task_id, + embedding_config_id=embedding_config_id, + query_vector=[1.0, 0.0, 0.0], + limit=5, + ) + artifact_rows = store.retrieve_artifact_scoped_semantic_artifact_chunk_matches( + task_artifact_id=task_artifact_id, + embedding_config_id=embedding_config_id, + query_vector=[1.0, 0.0, 0.0], + limit=3, + ) + + assert task_rows[0]["task_id"] == task_id + assert artifact_rows[0]["task_artifact_id"] == task_artifact_id + + task_query, task_params = cursor.executed[0] + 
assert "FROM task_artifact_chunk_embeddings AS embeddings" in task_query + assert "JOIN task_artifacts AS artifacts" in task_query + assert "artifacts.task_id = %s" in task_query + assert "artifacts.ingestion_status = 'ingested'" in task_query + assert "ORDER BY score DESC, artifacts.relative_path ASC, chunks.sequence_no ASC, chunks.id ASC" in task_query + assert task_params == ("[1.0,0.0,0.0]", embedding_config_id, 3, task_id, 5) + + artifact_query, artifact_params = cursor.executed[1] + assert "FROM task_artifact_chunk_embeddings AS embeddings" in artifact_query + assert "JOIN task_artifacts AS artifacts" in artifact_query + assert "artifacts.id = %s" in artifact_query + assert "artifacts.ingestion_status = 'ingested'" in artifact_query + assert "ORDER BY score DESC, artifacts.relative_path ASC, chunks.sequence_no ASC, chunks.id ASC" in artifact_query + assert artifact_params == ("[1.0,0.0,0.0]", embedding_config_id, 3, task_artifact_id, 3) From d716653343d9b677cebd4bfc4329ed29df8e091c Mon Sep 17 00:00:00 2001 From: Sam Date: Sun, 15 Mar 2026 22:18:30 +0100 Subject: [PATCH 009/135] Sprint 5I: compile semantic artifact retrieval adoption (#9) * Sprint 5I: compile semantic artifact retrieval packet * Sprint 5I: compile semantic artifact retrieval adoption --------- Co-authored-by: Sami Rusani --- apps/api/src/alicebot_api/compiler.py | 269 +++++++++++- apps/api/src/alicebot_api/contracts.py | 87 ++++ apps/api/src/alicebot_api/main.py | 65 +++ .../src/alicebot_api/response_generation.py | 4 + .../src/alicebot_api/semantic_retrieval.py | 6 +- tests/integration/test_context_compile.py | 399 ++++++++++++++++++ tests/unit/test_compiler.py | 251 ++++++++++- tests/unit/test_main.py | 128 +++++- tests/unit/test_response_generation.py | 36 ++ 9 files changed, 1239 insertions(+), 6 deletions(-) diff --git a/apps/api/src/alicebot_api/compiler.py b/apps/api/src/alicebot_api/compiler.py index 770319c..46a89b1 100644 --- a/apps/api/src/alicebot_api/compiler.py +++ b/apps/api/src/alicebot_api/compiler.py @@ -6,10 +6,13 @@ from alicebot_api.contracts import ( COMPILER_VERSION_V0, ArtifactRetrievalDecisionTracePayload, + CompileContextArtifactScopedSemanticArtifactRetrievalInput, CompilerDecision, CompileContextArtifactRetrievalInput, CompileContextArtifactScopedArtifactRetrievalInput, CompileContextSemanticRetrievalInput, + CompileContextSemanticArtifactRetrievalInput, + CompileContextTaskScopedSemanticArtifactRetrievalInput, CompileContextTaskScopedArtifactRetrievalInput, CompilerRunResult, CompiledContextPack, @@ -19,10 +22,14 @@ ContextPackHybridMemorySummary, ContextPackMemory, ContextPackMemorySummary, + ContextPackSemanticArtifactChunk, + ContextPackSemanticArtifactChunkSummary, HybridMemoryDecisionTracePayload, MemorySelectionSource, SEMANTIC_MEMORY_RETRIEVAL_ORDER, TASK_ARTIFACT_CHUNK_RETRIEVAL_ORDER, + TASK_ARTIFACT_CHUNK_SEMANTIC_RETRIEVAL_ORDER, + SemanticArtifactRetrievalDecisionTracePayload, SemanticMemoryRetrievalRequestInput, TRACE_KIND_CONTEXT_COMPILE, TraceEventRecord, @@ -36,7 +43,13 @@ resolve_artifact_chunk_retrieval_query_terms, retrieve_matching_task_artifact_chunks, ) -from alicebot_api.semantic_retrieval import validate_semantic_memory_retrieval_request +from alicebot_api.semantic_retrieval import ( + retrieve_artifact_scoped_semantic_artifact_chunk_records, + retrieve_task_scoped_semantic_artifact_chunk_records, + serialize_semantic_artifact_chunk_result_item, + validate_semantic_artifact_chunk_retrieval_request, + validate_semantic_memory_retrieval_request, +) from 
alicebot_api.store import ( ContinuityStore, EntityEdgeRow, @@ -52,6 +65,7 @@ SUMMARY_TRACE_EVENT_KIND = "context.summary" _UNBOUNDED_SEMANTIC_RETRIEVAL_LIMIT = 2_147_483_647 +_UNBOUNDED_SEMANTIC_ARTIFACT_RETRIEVAL_LIMIT = 2_147_483_647 HYBRID_MEMORY_SOURCE_PRECEDENCE: list[MemorySelectionSource] = ["symbolic", "semantic"] HYBRID_SYMBOLIC_ORDER = ["updated_at_asc", "created_at_asc", "id_asc"] @@ -77,6 +91,13 @@ class CompiledArtifactChunkSection: decisions: list[CompilerDecision] +@dataclass(frozen=True, slots=True) +class CompiledSemanticArtifactChunkSection: + items: list[ContextPackSemanticArtifactChunk] + summary: ContextPackSemanticArtifactChunkSummary + decisions: list[CompilerDecision] + + @dataclass(slots=True) class HybridMemoryCandidate: memory: MemoryRow @@ -237,6 +258,23 @@ def _empty_artifact_chunk_summary() -> ContextPackArtifactChunkSummary: } +def _empty_semantic_artifact_chunk_summary() -> ContextPackSemanticArtifactChunkSummary: + return { + "requested": False, + "scope": None, + "embedding_config_id": None, + "query_vector_dimensions": 0, + "limit": 0, + "searched_artifact_count": 0, + "candidate_count": 0, + "included_count": 0, + "excluded_uningested_artifact_count": 0, + "excluded_limit_count": 0, + "similarity_metric": None, + "order": list(TASK_ARTIFACT_CHUNK_SEMANTIC_RETRIEVAL_ORDER), + } + + def _artifact_retrieval_decision_metadata( *, scope_kind: str, @@ -273,6 +311,45 @@ def _artifact_retrieval_decision_metadata( return payload +def _semantic_artifact_retrieval_decision_metadata( + *, + scope_kind: str, + task_id: UUID, + task_artifact_id: UUID, + relative_path: str, + media_type: str | None, + ingestion_status: str, + embedding_config_id: UUID, + query_vector_dimensions: int, + limit: int, + score: float | None = None, + sequence_no: int | None = None, + char_start: int | None = None, + char_end_exclusive: int | None = None, +) -> SemanticArtifactRetrievalDecisionTracePayload: + payload: SemanticArtifactRetrievalDecisionTracePayload = { + "scope_kind": scope_kind, # type: ignore[typeddict-item] + "task_id": str(task_id), + "task_artifact_id": str(task_artifact_id), + "relative_path": relative_path, + "media_type": media_type, + "ingestion_status": ingestion_status, # type: ignore[typeddict-item] + "embedding_config_id": str(embedding_config_id), + "query_vector_dimensions": query_vector_dimensions, + "limit": limit, + "similarity_metric": "cosine_similarity", + } + if score is not None: + payload["score"] = score + if sequence_no is not None: + payload["sequence_no"] = sequence_no + if char_start is not None: + payload["char_start"] = char_start + if char_end_exclusive is not None: + payload["char_end_exclusive"] = char_end_exclusive + return payload + + def _hybrid_memory_decision_metadata( *, embedding_config_id: UUID | None, @@ -687,6 +764,158 @@ def _compile_artifact_chunk_section( ) +def _compile_semantic_artifact_chunk_section( + store: ContinuityStore, + *, + semantic_artifact_retrieval: CompileContextSemanticArtifactRetrievalInput | None, +) -> CompiledSemanticArtifactChunkSection: + if semantic_artifact_retrieval is None: + return CompiledSemanticArtifactChunkSection( + items=[], + summary=_empty_semantic_artifact_chunk_summary(), + decisions=[], + ) + + if isinstance( + semantic_artifact_retrieval, + CompileContextTaskScopedSemanticArtifactRetrievalInput, + ): + task = store.get_task_optional(semantic_artifact_retrieval.task_id) + if task is None: + raise TaskNotFoundError(f"task {semantic_artifact_retrieval.task_id} was not found") + artifact_rows = 
store.list_task_artifacts_for_task(semantic_artifact_retrieval.task_id) + scope_kind = "task" + section_payload = retrieve_task_scoped_semantic_artifact_chunk_records( + store, + user_id=task["id"], + request=semantic_artifact_retrieval, + ) + _config, query_vector = validate_semantic_artifact_chunk_retrieval_request( + store, + embedding_config_id=semantic_artifact_retrieval.embedding_config_id, + query_vector=semantic_artifact_retrieval.query_vector, + ) + matched_items = [ + serialize_semantic_artifact_chunk_result_item(row) + for row in store.retrieve_task_scoped_semantic_artifact_chunk_matches( + task_id=semantic_artifact_retrieval.task_id, + embedding_config_id=semantic_artifact_retrieval.embedding_config_id, + query_vector=query_vector, + limit=_UNBOUNDED_SEMANTIC_ARTIFACT_RETRIEVAL_LIMIT, + ) + ] + else: + artifact_row = store.get_task_artifact_optional( + semantic_artifact_retrieval.task_artifact_id + ) + if artifact_row is None: + raise TaskArtifactNotFoundError( + f"task artifact {semantic_artifact_retrieval.task_artifact_id} was not found" + ) + artifact_rows = [artifact_row] + scope_kind = "artifact" + section_payload = retrieve_artifact_scoped_semantic_artifact_chunk_records( + store, + user_id=artifact_row["task_id"], + request=semantic_artifact_retrieval, + ) + _config, query_vector = validate_semantic_artifact_chunk_retrieval_request( + store, + embedding_config_id=semantic_artifact_retrieval.embedding_config_id, + query_vector=semantic_artifact_retrieval.query_vector, + ) + matched_items = [ + serialize_semantic_artifact_chunk_result_item(row) + for row in store.retrieve_artifact_scoped_semantic_artifact_chunk_matches( + task_artifact_id=semantic_artifact_retrieval.task_artifact_id, + embedding_config_id=semantic_artifact_retrieval.embedding_config_id, + query_vector=query_vector, + limit=_UNBOUNDED_SEMANTIC_ARTIFACT_RETRIEVAL_LIMIT, + ) + ] + + included_items = list(section_payload["items"]) + excluded_uningested_artifact_count = 0 + decisions: list[CompilerDecision] = [] + + for position, artifact_row in enumerate(artifact_rows, start=1): + if artifact_row["ingestion_status"] == "ingested": + continue + excluded_uningested_artifact_count += 1 + decisions.append( + CompilerDecision( + "excluded", + "task_artifact", + artifact_row["id"], + "semantic_artifact_not_ingested", + position, + metadata=_semantic_artifact_retrieval_decision_metadata( + scope_kind=scope_kind, + task_id=artifact_row["task_id"], + task_artifact_id=artifact_row["id"], + relative_path=artifact_row["relative_path"], + media_type=infer_task_artifact_media_type(artifact_row), + ingestion_status=artifact_row["ingestion_status"], + embedding_config_id=semantic_artifact_retrieval.embedding_config_id, + query_vector_dimensions=len(query_vector), + limit=semantic_artifact_retrieval.limit, + ), + ) + ) + + for position, item in enumerate(matched_items, start=1): + decision_kind = "included" if position <= semantic_artifact_retrieval.limit else "excluded" + decision_reason = ( + "within_semantic_artifact_chunk_limit" + if position <= semantic_artifact_retrieval.limit + else "semantic_artifact_chunk_limit_exceeded" + ) + decisions.append( + CompilerDecision( + decision_kind, + "semantic_artifact_chunk", + UUID(item["id"]), + decision_reason, + position, + metadata=_semantic_artifact_retrieval_decision_metadata( + scope_kind=scope_kind, + task_id=UUID(item["task_id"]), + task_artifact_id=UUID(item["task_artifact_id"]), + relative_path=item["relative_path"], + media_type=item["media_type"], + 
ingestion_status="ingested", + embedding_config_id=semantic_artifact_retrieval.embedding_config_id, + query_vector_dimensions=len(query_vector), + limit=semantic_artifact_retrieval.limit, + score=item["score"], + sequence_no=item["sequence_no"], + char_start=item["char_start"], + char_end_exclusive=item["char_end_exclusive"], + ), + ) + ) + + section_summary = section_payload["summary"] + return CompiledSemanticArtifactChunkSection( + items=included_items, + summary={ + "requested": True, + "scope": section_summary["scope"], + "embedding_config_id": section_summary["embedding_config_id"], + "query_vector_dimensions": section_summary["query_vector_dimensions"], + "limit": section_summary["limit"], + "searched_artifact_count": section_summary["searched_artifact_count"], + "candidate_count": len(matched_items), + "included_count": len(included_items), + "excluded_uningested_artifact_count": excluded_uningested_artifact_count, + "excluded_limit_count": max(len(matched_items) - len(included_items), 0), + "similarity_metric": section_summary["similarity_metric"], + "order": list(section_summary["order"]), + }, + decisions=decisions, + ) + + def compile_continuity_context( *, user: UserRow, @@ -699,6 +928,7 @@ def compile_continuity_context( limits: ContextCompilerLimits, memory_section: CompiledMemorySection | None = None, artifact_chunk_section: CompiledArtifactChunkSection | None = None, + semantic_artifact_chunk_section: CompiledSemanticArtifactChunkSection | None = None, ) -> CompilerRunResult: latest_session_sequence: dict[UUID, int] = {} for event in events: @@ -797,6 +1027,15 @@ def compile_continuity_context( decisions=[], ) decisions.extend(resolved_artifact_chunk_section.decisions) + resolved_semantic_artifact_chunk_section = ( + semantic_artifact_chunk_section + or CompiledSemanticArtifactChunkSection( + items=[], + summary=_empty_semantic_artifact_chunk_summary(), + decisions=[], + ) + ) + decisions.extend(resolved_semantic_artifact_chunk_section.decisions) ordered_entities = sorted(entities, key=_entity_sort_key) included_entities = ordered_entities[-limits.max_entities :] if limits.max_entities > 0 else [] included_entity_ids = {entity["id"] for entity in included_entities} @@ -945,6 +1184,26 @@ def compile_continuity_context( "excluded_uningested_artifact_count": resolved_artifact_chunk_section.summary[ "excluded_uningested_artifact_count" ], + "semantic_artifact_retrieval_requested": resolved_semantic_artifact_chunk_section.summary[ + "requested" + ], + "semantic_artifact_retrieval_scope_kind": ( + None + if resolved_semantic_artifact_chunk_section.summary["scope"] is None + else resolved_semantic_artifact_chunk_section.summary["scope"]["kind"] + ), + "semantic_artifact_chunk_candidate_count": resolved_semantic_artifact_chunk_section.summary[ + "candidate_count" + ], + "included_semantic_artifact_chunk_count": resolved_semantic_artifact_chunk_section.summary[ + "included_count" + ], + "excluded_semantic_artifact_chunk_limit_count": resolved_semantic_artifact_chunk_section.summary[ + "excluded_limit_count" + ], + "excluded_semantic_uningested_artifact_count": resolved_semantic_artifact_chunk_section.summary[ + "excluded_uningested_artifact_count" + ], "included_entity_count": len(included_entities), "excluded_entity_count": excluded_entity_limit_count, "excluded_entity_limit_count": excluded_entity_limit_count, @@ -978,6 +1237,8 @@ def compile_continuity_context( "memory_summary": resolved_memory_section.summary, "artifact_chunks": list(resolved_artifact_chunk_section.items), 
"artifact_chunk_summary": resolved_artifact_chunk_section.summary, + "semantic_artifact_chunks": list(resolved_semantic_artifact_chunk_section.items), + "semantic_artifact_chunk_summary": resolved_semantic_artifact_chunk_section.summary, "entities": [_serialize_entity(entity) for entity in included_entities], "entity_summary": { "candidate_count": len(ordered_entities), @@ -1004,6 +1265,7 @@ def compile_and_persist_trace( limits: ContextCompilerLimits, semantic_retrieval: CompileContextSemanticRetrievalInput | None = None, artifact_retrieval: CompileContextArtifactRetrievalInput | None = None, + semantic_artifact_retrieval: CompileContextSemanticArtifactRetrievalInput | None = None, ) -> CompiledTraceRun: user = store.get_user(user_id) thread = store.get_thread(thread_id) @@ -1020,6 +1282,10 @@ def compile_and_persist_trace( store, artifact_retrieval=artifact_retrieval, ) + semantic_artifact_chunk_section = _compile_semantic_artifact_chunk_section( + store, + semantic_artifact_retrieval=semantic_artifact_retrieval, + ) entities = store.list_entities() ordered_entities = sorted(entities, key=_entity_sort_key) included_entities = ordered_entities[-limits.max_entities :] if limits.max_entities > 0 else [] @@ -1035,6 +1301,7 @@ def compile_and_persist_trace( limits=limits, memory_section=memory_section, artifact_chunk_section=artifact_chunk_section, + semantic_artifact_chunk_section=semantic_artifact_chunk_section, ) trace = store.create_trace( user_id=user_id, diff --git a/apps/api/src/alicebot_api/contracts.py b/apps/api/src/alicebot_api/contracts.py index 07fa214..8d4882b 100644 --- a/apps/api/src/alicebot_api/contracts.py +++ b/apps/api/src/alicebot_api/contracts.py @@ -251,6 +251,46 @@ def as_payload(self) -> JsonObject: ) +@dataclass(frozen=True, slots=True) +class CompileContextTaskScopedSemanticArtifactRetrievalInput: + task_id: UUID + embedding_config_id: UUID + query_vector: tuple[float, ...] + limit: int = DEFAULT_ARTIFACT_CHUNK_RETRIEVAL_LIMIT + + def as_payload(self) -> JsonObject: + return { + "kind": "task", + "task_id": str(self.task_id), + "embedding_config_id": str(self.embedding_config_id), + "query_vector": [float(value) for value in self.query_vector], + "limit": self.limit, + } + + +@dataclass(frozen=True, slots=True) +class CompileContextArtifactScopedSemanticArtifactRetrievalInput: + task_artifact_id: UUID + embedding_config_id: UUID + query_vector: tuple[float, ...] 
+ limit: int = DEFAULT_ARTIFACT_CHUNK_RETRIEVAL_LIMIT + + def as_payload(self) -> JsonObject: + return { + "kind": "artifact", + "task_artifact_id": str(self.task_artifact_id), + "embedding_config_id": str(self.embedding_config_id), + "query_vector": [float(value) for value in self.query_vector], + "limit": self.limit, + } + + +CompileContextSemanticArtifactRetrievalInput: TypeAlias = ( + CompileContextTaskScopedSemanticArtifactRetrievalInput + | CompileContextArtifactScopedSemanticArtifactRetrievalInput +) + + @dataclass(frozen=True, slots=True) class TraceCreate: user_id: UUID @@ -394,6 +434,34 @@ class ContextPackArtifactChunkSummary(TypedDict): order: list[str] +class ContextPackSemanticArtifactChunk(TypedDict): + id: str + task_id: str + task_artifact_id: str + relative_path: str + media_type: str + sequence_no: int + char_start: int + char_end_exclusive: int + text: str + score: float + + +class ContextPackSemanticArtifactChunkSummary(TypedDict): + requested: bool + scope: TaskArtifactChunkRetrievalScope | None + embedding_config_id: str | None + query_vector_dimensions: int + limit: int + searched_artifact_count: int + candidate_count: int + included_count: int + excluded_uningested_artifact_count: int + excluded_limit_count: int + similarity_metric: Literal["cosine_similarity"] | None + order: list[str] + + class ArtifactRetrievalDecisionTracePayload(TypedDict): scope_kind: TaskArtifactChunkRetrievalScopeKind task_id: str @@ -410,6 +478,23 @@ class ArtifactRetrievalDecisionTracePayload(TypedDict): char_end_exclusive: NotRequired[int] +class SemanticArtifactRetrievalDecisionTracePayload(TypedDict): + scope_kind: TaskArtifactChunkRetrievalScopeKind + task_id: str + task_artifact_id: str + relative_path: str + media_type: str | None + ingestion_status: TaskArtifactIngestionStatus + embedding_config_id: str + query_vector_dimensions: int + limit: int + similarity_metric: Literal["cosine_similarity"] + score: NotRequired[float] + sequence_no: NotRequired[int] + char_start: NotRequired[int] + char_end_exclusive: NotRequired[int] + + class ContextPackMemorySummary(TypedDict): candidate_count: int included_count: int @@ -495,6 +580,8 @@ class CompiledContextPack(TypedDict): memory_summary: ContextPackMemorySummary artifact_chunks: list[ContextPackArtifactChunk] artifact_chunk_summary: ContextPackArtifactChunkSummary + semantic_artifact_chunks: list[ContextPackSemanticArtifactChunk] + semantic_artifact_chunk_summary: ContextPackSemanticArtifactChunkSummary entities: list[ContextPackEntity] entity_summary: ContextPackEntitySummary entity_edges: list[ContextPackEntityEdge] diff --git a/apps/api/src/alicebot_api/main.py b/apps/api/src/alicebot_api/main.py index fb487e6..e530778 100644 --- a/apps/api/src/alicebot_api/main.py +++ b/apps/api/src/alicebot_api/main.py @@ -16,8 +16,11 @@ ApprovalRejectInput, ApprovalRequestCreateInput, ArtifactScopedSemanticArtifactChunkRetrievalInput, + CompileContextArtifactScopedSemanticArtifactRetrievalInput, CompileContextArtifactScopedArtifactRetrievalInput, + CompileContextSemanticArtifactRetrievalInput, CompileContextTaskScopedArtifactRetrievalInput, + CompileContextTaskScopedSemanticArtifactRetrievalInput, ConsentStatus, ConsentUpsertInput, CompileContextSemanticRetrievalInput, @@ -291,6 +294,41 @@ class CompileContextArtifactScopedArtifactRetrievalRequest(BaseModel): ] +class CompileContextTaskScopedSemanticArtifactRetrievalRequest(BaseModel): + model_config = ConfigDict(extra="forbid") + + kind: Literal["task"] + task_id: UUID + embedding_config_id: UUID 
+ query_vector: list[float] = Field(min_length=1, max_length=20000) + limit: int = Field( + default=DEFAULT_ARTIFACT_CHUNK_RETRIEVAL_LIMIT, + ge=1, + le=MAX_ARTIFACT_CHUNK_RETRIEVAL_LIMIT, + ) + + +class CompileContextArtifactScopedSemanticArtifactRetrievalRequest(BaseModel): + model_config = ConfigDict(extra="forbid") + + kind: Literal["artifact"] + task_artifact_id: UUID + embedding_config_id: UUID + query_vector: list[float] = Field(min_length=1, max_length=20000) + limit: int = Field( + default=DEFAULT_ARTIFACT_CHUNK_RETRIEVAL_LIMIT, + ge=1, + le=MAX_ARTIFACT_CHUNK_RETRIEVAL_LIMIT, + ) + + +CompileContextSemanticArtifactRetrievalRequest = Annotated[ + CompileContextTaskScopedSemanticArtifactRetrievalRequest + | CompileContextArtifactScopedSemanticArtifactRetrievalRequest, + Field(discriminator="kind"), +] + + class CompileContextRequest(BaseModel): user_id: UUID thread_id: UUID @@ -301,6 +339,7 @@ class CompileContextRequest(BaseModel): max_entity_edges: int = Field(default=DEFAULT_MAX_ENTITY_EDGES, ge=0, le=100) semantic: CompileContextSemanticRequest | None = None artifact_retrieval: CompileContextArtifactRetrievalRequest | None = None + semantic_artifact_retrieval: CompileContextSemanticArtifactRetrievalRequest | None = None class GenerateResponseRequest(BaseModel): @@ -616,6 +655,7 @@ def healthcheck() -> JSONResponse: def compile_context(request: CompileContextRequest) -> JSONResponse: settings = get_settings() artifact_retrieval = None + semantic_artifact_retrieval = None if isinstance(request.artifact_retrieval, CompileContextTaskScopedArtifactRetrievalRequest): artifact_retrieval = CompileContextTaskScopedArtifactRetrievalInput( task_id=request.artifact_retrieval.task_id, @@ -631,6 +671,28 @@ def compile_context(request: CompileContextRequest) -> JSONResponse: query=request.artifact_retrieval.query, limit=request.artifact_retrieval.limit, ) + if isinstance( + request.semantic_artifact_retrieval, + CompileContextTaskScopedSemanticArtifactRetrievalRequest, + ): + semantic_artifact_retrieval = CompileContextTaskScopedSemanticArtifactRetrievalInput( + task_id=request.semantic_artifact_retrieval.task_id, + embedding_config_id=request.semantic_artifact_retrieval.embedding_config_id, + query_vector=tuple(request.semantic_artifact_retrieval.query_vector), + limit=request.semantic_artifact_retrieval.limit, + ) + elif isinstance( + request.semantic_artifact_retrieval, + CompileContextArtifactScopedSemanticArtifactRetrievalRequest, + ): + semantic_artifact_retrieval = ( + CompileContextArtifactScopedSemanticArtifactRetrievalInput( + task_artifact_id=request.semantic_artifact_retrieval.task_artifact_id, + embedding_config_id=request.semantic_artifact_retrieval.embedding_config_id, + query_vector=tuple(request.semantic_artifact_retrieval.query_vector), + limit=request.semantic_artifact_retrieval.limit, + ) + ) try: with user_connection(settings.database_url, request.user_id) as conn: @@ -655,9 +717,12 @@ def compile_context(request: CompileContextRequest) -> JSONResponse: ) ), artifact_retrieval=artifact_retrieval, + semantic_artifact_retrieval=semantic_artifact_retrieval, ) except TaskArtifactChunkRetrievalValidationError as exc: return JSONResponse(status_code=400, content={"detail": str(exc)}) + except SemanticArtifactChunkRetrievalValidationError as exc: + return JSONResponse(status_code=400, content={"detail": str(exc)}) except SemanticMemoryRetrievalValidationError as exc: return JSONResponse(status_code=400, content={"detail": str(exc)}) except (TaskNotFoundError, 
TaskArtifactNotFoundError) as exc: diff --git a/apps/api/src/alicebot_api/response_generation.py b/apps/api/src/alicebot_api/response_generation.py index 7652a5d..78f2ee6 100644 --- a/apps/api/src/alicebot_api/response_generation.py +++ b/apps/api/src/alicebot_api/response_generation.py @@ -90,6 +90,10 @@ def _context_section_payload(context_pack: CompiledContextPack) -> JsonObject: "sessions": context_pack["sessions"], "memories": context_pack["memories"], "memory_summary": context_pack["memory_summary"], + "artifact_chunks": context_pack["artifact_chunks"], + "artifact_chunk_summary": context_pack["artifact_chunk_summary"], + "semantic_artifact_chunks": context_pack["semantic_artifact_chunks"], + "semantic_artifact_chunk_summary": context_pack["semantic_artifact_chunk_summary"], "entities": context_pack["entities"], "entity_summary": context_pack["entity_summary"], "entity_edges": context_pack["entity_edges"], diff --git a/apps/api/src/alicebot_api/semantic_retrieval.py b/apps/api/src/alicebot_api/semantic_retrieval.py index 4fe066d..50059a1 100644 --- a/apps/api/src/alicebot_api/semantic_retrieval.py +++ b/apps/api/src/alicebot_api/semantic_retrieval.py @@ -143,7 +143,7 @@ def _build_task_artifact_chunk_retrieval_scope( return scope -def _serialize_semantic_artifact_chunk_result_item( +def serialize_semantic_artifact_chunk_result_item( row: TaskArtifactChunkSemanticRetrievalRow, ) -> TaskArtifactChunkSemanticRetrievalItem: return { @@ -220,7 +220,7 @@ def retrieve_task_scoped_semantic_artifact_chunk_records( query_vector=request.query_vector, ) items = [ - _serialize_semantic_artifact_chunk_result_item(row) + serialize_semantic_artifact_chunk_result_item(row) for row in store.retrieve_task_scoped_semantic_artifact_chunk_matches( task_id=request.task_id, embedding_config_id=request.embedding_config_id, @@ -264,7 +264,7 @@ def retrieve_artifact_scoped_semantic_artifact_chunk_records( query_vector=request.query_vector, ) items = [ - _serialize_semantic_artifact_chunk_result_item(row) + serialize_semantic_artifact_chunk_result_item(row) for row in store.retrieve_artifact_scoped_semantic_artifact_chunk_matches( task_artifact_id=request.task_artifact_id, embedding_config_id=request.embedding_config_id, diff --git a/tests/integration/test_context_compile.py b/tests/integration/test_context_compile.py index 4b6913b..cc43cd5 100644 --- a/tests/integration/test_context_compile.py +++ b/tests/integration/test_context_compile.py @@ -278,6 +278,23 @@ def seed_memory_embedding_for_user( ) +def seed_task_artifact_chunk_embedding_for_user( + database_url: str, + *, + user_id: UUID, + task_artifact_chunk_id: UUID, + embedding_config_id: UUID, + vector: list[float], +) -> None: + with user_connection(database_url, user_id) as conn: + ContinuityStore(conn).create_task_artifact_chunk_embedding( + task_artifact_chunk_id=task_artifact_chunk_id, + embedding_config_id=embedding_config_id, + dimensions=len(vector), + vector=vector, + ) + + def seed_compile_artifact_scope( database_url: str, *, @@ -468,6 +485,21 @@ def test_compile_context_endpoint_persists_trace_and_trace_events(migrated_datab "semantic_order": ["score_desc", "created_at_asc", "id_asc"], }, } + assert payload["context_pack"]["semantic_artifact_chunks"] == [] + assert payload["context_pack"]["semantic_artifact_chunk_summary"] == { + "requested": False, + "scope": None, + "embedding_config_id": None, + "query_vector_dimensions": 0, + "limit": 0, + "searched_artifact_count": 0, + "candidate_count": 0, + "included_count": 0, + 
"excluded_uningested_artifact_count": 0, + "excluded_limit_count": 0, + "similarity_metric": None, + "order": ["score_desc", "relative_path_asc", "sequence_no_asc", "id_asc"], + } assert payload["context_pack"]["entities"] == [ { "id": str(included_entity["id"]), @@ -576,6 +608,11 @@ def test_compile_context_endpoint_persists_trace_and_trace_events(migrated_datab assert trace_events[-1]["payload"]["hybrid_memory_candidate_count"] == 2 assert trace_events[-1]["payload"]["hybrid_memory_merged_candidate_count"] == 1 assert trace_events[-1]["payload"]["hybrid_memory_deduplicated_count"] == 0 + assert trace_events[-1]["payload"]["semantic_artifact_retrieval_requested"] is False + assert trace_events[-1]["payload"]["semantic_artifact_chunk_candidate_count"] == 0 + assert trace_events[-1]["payload"]["included_semantic_artifact_chunk_count"] == 0 + assert trace_events[-1]["payload"]["excluded_semantic_artifact_chunk_limit_count"] == 0 + assert trace_events[-1]["payload"]["excluded_semantic_uningested_artifact_count"] == 0 assert trace_events[-1]["payload"]["included_entity_count"] == 1 assert trace_events[-1]["payload"]["excluded_entity_limit_count"] == 2 assert trace_events[-1]["payload"]["included_entity_edge_count"] == 1 @@ -654,6 +691,21 @@ def test_compile_context_prefers_updated_active_memory_within_same_transaction( "semantic_order": ["score_desc", "created_at_asc", "id_asc"], }, } + assert payload["context_pack"]["semantic_artifact_chunks"] == [] + assert payload["context_pack"]["semantic_artifact_chunk_summary"] == { + "requested": False, + "scope": None, + "embedding_config_id": None, + "query_vector_dimensions": 0, + "limit": 0, + "searched_artifact_count": 0, + "candidate_count": 0, + "included_count": 0, + "excluded_uningested_artifact_count": 0, + "excluded_limit_count": 0, + "similarity_metric": None, + "order": ["score_desc", "relative_path_asc", "sequence_no_asc", "id_asc"], + } assert payload["context_pack"]["entity_summary"] == { "candidate_count": 2, "included_count": 1, @@ -1134,6 +1186,353 @@ def test_compile_context_artifact_scoped_retrieval_returns_only_visible_artifact assert trace_events[-1]["payload"]["excluded_uningested_artifact_count"] == 0 +def test_compile_context_semantic_artifact_retrieval_integrates_chunks_traces_and_exclusion_rules( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_traceable_thread(migrated_database_urls["app"]) + artifact_scope = seed_compile_artifact_scope( + migrated_database_urls["app"], + user_id=seeded["user_id"], + thread_id=seeded["thread_id"], + ) + config_id = seed_embedding_config_for_user( + migrated_database_urls["app"], + user_id=seeded["user_id"], + ) + seed_task_artifact_chunk_embedding_for_user( + migrated_database_urls["app"], + user_id=seeded["user_id"], + task_artifact_chunk_id=artifact_scope["chunk_ids"]["docs"], + embedding_config_id=config_id, + vector=[1.0, 0.0, 0.0], + ) + seed_task_artifact_chunk_embedding_for_user( + migrated_database_urls["app"], + user_id=seeded["user_id"], + task_artifact_chunk_id=artifact_scope["chunk_ids"]["notes"], + embedding_config_id=config_id, + vector=[1.0, 0.0, 0.0], + ) + seed_task_artifact_chunk_embedding_for_user( + migrated_database_urls["app"], + user_id=seeded["user_id"], + task_artifact_chunk_id=artifact_scope["chunk_ids"]["weak"], + embedding_config_id=config_id, + vector=[0.0, 1.0, 0.0], + ) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + status_code, payload = invoke_compile_context( 
+ { + "user_id": str(seeded["user_id"]), + "thread_id": str(seeded["thread_id"]), + "semantic_artifact_retrieval": { + "kind": "task", + "task_id": str(artifact_scope["task_id"]), + "embedding_config_id": str(config_id), + "query_vector": [1.0, 0.0, 0.0], + "limit": 2, + }, + } + ) + + assert status_code == 200 + assert payload["context_pack"]["semantic_artifact_chunks"] == [ + { + "id": str(artifact_scope["chunk_ids"]["docs"]), + "task_id": str(artifact_scope["task_id"]), + "task_artifact_id": str(artifact_scope["artifact_ids"]["docs"]), + "relative_path": "docs/a.txt", + "media_type": "text/plain", + "sequence_no": 1, + "char_start": 0, + "char_end_exclusive": 14, + "text": "beta alpha doc", + "score": 1.0, + }, + { + "id": str(artifact_scope["chunk_ids"]["notes"]), + "task_id": str(artifact_scope["task_id"]), + "task_artifact_id": str(artifact_scope["artifact_ids"]["notes"]), + "relative_path": "notes/b.md", + "media_type": "text/markdown", + "sequence_no": 1, + "char_start": 0, + "char_end_exclusive": 15, + "text": "alpha beta note", + "score": 1.0, + }, + ] + assert payload["context_pack"]["semantic_artifact_chunk_summary"] == { + "requested": True, + "scope": {"kind": "task", "task_id": str(artifact_scope["task_id"])}, + "embedding_config_id": str(config_id), + "query_vector_dimensions": 3, + "limit": 2, + "searched_artifact_count": 3, + "candidate_count": 3, + "included_count": 2, + "excluded_uningested_artifact_count": 1, + "excluded_limit_count": 1, + "similarity_metric": "cosine_similarity", + "order": ["score_desc", "relative_path_asc", "sequence_no_asc", "id_asc"], + } + assert payload["context_pack"]["artifact_chunks"] == [] + + trace_id = UUID(payload["trace_id"]) + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + trace_events = ContinuityStore(conn).list_trace_events(trace_id) + + assert any( + event["payload"]["reason"] == "within_semantic_artifact_chunk_limit" + and event["payload"]["entity_id"] == str(artifact_scope["chunk_ids"]["docs"]) + and event["payload"]["relative_path"] == "docs/a.txt" + and event["payload"]["score"] == 1.0 + for event in trace_events + if event["kind"] == "context.included" + ) + assert any( + event["payload"]["reason"] == "within_semantic_artifact_chunk_limit" + and event["payload"]["entity_id"] == str(artifact_scope["chunk_ids"]["notes"]) + and event["payload"]["relative_path"] == "notes/b.md" + for event in trace_events + if event["kind"] == "context.included" + ) + assert any( + event["payload"]["reason"] == "semantic_artifact_chunk_limit_exceeded" + and event["payload"]["entity_id"] == str(artifact_scope["chunk_ids"]["weak"]) + and event["payload"]["relative_path"] == "notes/c.txt" + and event["payload"]["score"] == 0.0 + for event in trace_events + if event["kind"] == "context.excluded" + ) + assert any( + event["payload"]["reason"] == "semantic_artifact_not_ingested" + and event["payload"]["entity_id"] == str(artifact_scope["artifact_ids"]["pending"]) + and event["payload"]["relative_path"] == "notes/hidden.txt" + and event["payload"]["ingestion_status"] == "pending" + for event in trace_events + if event["kind"] == "context.excluded" + ) + assert trace_events[-1]["payload"]["semantic_artifact_retrieval_requested"] is True + assert trace_events[-1]["payload"]["semantic_artifact_retrieval_scope_kind"] == "task" + assert trace_events[-1]["payload"]["semantic_artifact_chunk_candidate_count"] == 3 + assert trace_events[-1]["payload"]["included_semantic_artifact_chunk_count"] == 2 + assert 
trace_events[-1]["payload"]["excluded_semantic_artifact_chunk_limit_count"] == 1 + assert trace_events[-1]["payload"]["excluded_semantic_uningested_artifact_count"] == 1 + + +def test_compile_context_semantic_artifact_scoped_retrieval_returns_only_visible_artifact_chunks( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_traceable_thread(migrated_database_urls["app"]) + artifact_scope = seed_compile_artifact_scope( + migrated_database_urls["app"], + user_id=seeded["user_id"], + thread_id=seeded["thread_id"], + ) + config_id = seed_embedding_config_for_user( + migrated_database_urls["app"], + user_id=seeded["user_id"], + ) + seed_task_artifact_chunk_embedding_for_user( + migrated_database_urls["app"], + user_id=seeded["user_id"], + task_artifact_chunk_id=artifact_scope["chunk_ids"]["notes"], + embedding_config_id=config_id, + vector=[1.0, 0.0, 0.0], + ) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + status_code, payload = invoke_compile_context( + { + "user_id": str(seeded["user_id"]), + "thread_id": str(seeded["thread_id"]), + "semantic_artifact_retrieval": { + "kind": "artifact", + "task_artifact_id": str(artifact_scope["artifact_ids"]["notes"]), + "embedding_config_id": str(config_id), + "query_vector": [1.0, 0.0, 0.0], + "limit": 2, + }, + } + ) + + assert status_code == 200 + assert payload["context_pack"]["semantic_artifact_chunks"] == [ + { + "id": str(artifact_scope["chunk_ids"]["notes"]), + "task_id": str(artifact_scope["task_id"]), + "task_artifact_id": str(artifact_scope["artifact_ids"]["notes"]), + "relative_path": "notes/b.md", + "media_type": "text/markdown", + "sequence_no": 1, + "char_start": 0, + "char_end_exclusive": 15, + "text": "alpha beta note", + "score": 1.0, + } + ] + assert payload["context_pack"]["semantic_artifact_chunk_summary"] == { + "requested": True, + "scope": { + "kind": "artifact", + "task_id": str(artifact_scope["task_id"]), + "task_artifact_id": str(artifact_scope["artifact_ids"]["notes"]), + }, + "embedding_config_id": str(config_id), + "query_vector_dimensions": 3, + "limit": 2, + "searched_artifact_count": 1, + "candidate_count": 1, + "included_count": 1, + "excluded_uningested_artifact_count": 0, + "excluded_limit_count": 0, + "similarity_metric": "cosine_similarity", + "order": ["score_desc", "relative_path_asc", "sequence_no_asc", "id_asc"], + } + + trace_id = UUID(payload["trace_id"]) + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + trace_events = ContinuityStore(conn).list_trace_events(trace_id) + + assert any( + event["payload"]["reason"] == "within_semantic_artifact_chunk_limit" + and event["payload"]["entity_id"] == str(artifact_scope["chunk_ids"]["notes"]) + and event["payload"]["scope_kind"] == "artifact" + and event["payload"]["task_artifact_id"] == str(artifact_scope["artifact_ids"]["notes"]) + for event in trace_events + if event["kind"] == "context.included" + ) + assert trace_events[-1]["payload"]["semantic_artifact_retrieval_requested"] is True + assert trace_events[-1]["payload"]["semantic_artifact_retrieval_scope_kind"] == "artifact" + assert trace_events[-1]["payload"]["semantic_artifact_chunk_candidate_count"] == 1 + assert trace_events[-1]["payload"]["included_semantic_artifact_chunk_count"] == 1 + assert trace_events[-1]["payload"]["excluded_semantic_artifact_chunk_limit_count"] == 0 + assert trace_events[-1]["payload"]["excluded_semantic_uningested_artifact_count"] == 0 + + +def 
test_compile_context_semantic_artifact_retrieval_validation_and_isolation( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_traceable_thread(migrated_database_urls["app"]) + intruder = seed_traceable_thread( + migrated_database_urls["app"], + email="intruder@example.com", + display_name="Intruder", + ) + owner_artifact_scope = seed_compile_artifact_scope( + migrated_database_urls["app"], + user_id=owner["user_id"], + thread_id=owner["thread_id"], + ) + owner_config_id = seed_embedding_config_for_user( + migrated_database_urls["app"], + user_id=owner["user_id"], + ) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + invalid_shape_status, invalid_shape_payload = invoke_compile_context( + { + "user_id": str(owner["user_id"]), + "thread_id": str(owner["thread_id"]), + "semantic_artifact_retrieval": { + "kind": "task", + "task_artifact_id": str(owner_artifact_scope["artifact_ids"]["docs"]), + "embedding_config_id": str(owner_config_id), + "query_vector": [1.0, 0.0, 0.0], + }, + } + ) + missing_status, missing_payload = invoke_compile_context( + { + "user_id": str(owner["user_id"]), + "thread_id": str(owner["thread_id"]), + "semantic_artifact_retrieval": { + "kind": "task", + "task_id": str(owner_artifact_scope["task_id"]), + "embedding_config_id": str(uuid4()), + "query_vector": [1.0, 0.0, 0.0], + "limit": 2, + }, + } + ) + mismatch_status, mismatch_payload = invoke_compile_context( + { + "user_id": str(owner["user_id"]), + "thread_id": str(owner["thread_id"]), + "semantic_artifact_retrieval": { + "kind": "task", + "task_id": str(owner_artifact_scope["task_id"]), + "embedding_config_id": str(owner_config_id), + "query_vector": [1.0, 0.0], + "limit": 2, + }, + } + ) + isolated_task_status, isolated_task_payload = invoke_compile_context( + { + "user_id": str(intruder["user_id"]), + "thread_id": str(intruder["thread_id"]), + "semantic_artifact_retrieval": { + "kind": "task", + "task_id": str(owner_artifact_scope["task_id"]), + "embedding_config_id": str(owner_config_id), + "query_vector": [1.0, 0.0, 0.0], + "limit": 2, + }, + } + ) + isolated_artifact_status, isolated_artifact_payload = invoke_compile_context( + { + "user_id": str(intruder["user_id"]), + "thread_id": str(intruder["thread_id"]), + "semantic_artifact_retrieval": { + "kind": "artifact", + "task_artifact_id": str(owner_artifact_scope["artifact_ids"]["docs"]), + "embedding_config_id": str(owner_config_id), + "query_vector": [1.0, 0.0, 0.0], + "limit": 2, + }, + } + ) + + assert invalid_shape_status == 422 + assert "task_id" in json.dumps(invalid_shape_payload) + assert missing_status == 400 + assert missing_payload["detail"].startswith( + "embedding_config_id must reference an existing embedding config owned by the user" + ) + assert mismatch_status == 400 + assert mismatch_payload["detail"] == "query_vector length must match embedding config dimensions (3): 2" + assert isolated_task_status == 404 + assert isolated_task_payload == { + "detail": f"task {owner_artifact_scope['task_id']} was not found" + } + assert isolated_artifact_status == 404 + assert isolated_artifact_payload == { + "detail": ( + "task artifact " + f"{owner_artifact_scope['artifact_ids']['docs']} was not found" + ) + } + + def test_compile_context_artifact_retrieval_validation_and_isolation( migrated_database_urls, monkeypatch, diff --git a/tests/unit/test_compiler.py b/tests/unit/test_compiler.py index e0c2cae..7ff19c9 100644 --- a/tests/unit/test_compiler.py +++ 
b/tests/unit/test_compiler.py @@ -7,10 +7,13 @@ SUMMARY_TRACE_EVENT_KIND, _compile_artifact_chunk_section, _compile_memory_section, + _compile_semantic_artifact_chunk_section, compile_continuity_context, ) from alicebot_api.contracts import ( + CompileContextArtifactScopedSemanticArtifactRetrievalInput, CompileContextSemanticRetrievalInput, + CompileContextTaskScopedSemanticArtifactRetrievalInput, CompileContextTaskScopedArtifactRetrievalInput, ContextCompilerLimits, ) @@ -313,6 +316,21 @@ def test_compile_continuity_context_is_deterministic_and_stably_ordered() -> Non "id_asc", ], } + assert first_run.context_pack["semantic_artifact_chunks"] == [] + assert first_run.context_pack["semantic_artifact_chunk_summary"] == { + "requested": False, + "scope": None, + "embedding_config_id": None, + "query_vector_dimensions": 0, + "limit": 0, + "searched_artifact_count": 0, + "candidate_count": 0, + "included_count": 0, + "excluded_uningested_artifact_count": 0, + "excluded_limit_count": 0, + "similarity_metric": None, + "order": ["score_desc", "relative_path_asc", "sequence_no_asc", "id_asc"], + } assert first_run.context_pack["entity_summary"] == { "candidate_count": 3, "included_count": 2, @@ -643,6 +661,21 @@ def test_compile_continuity_context_records_included_and_excluded_reasons() -> N "id_asc", ], } + assert compiler_run.context_pack["semantic_artifact_chunks"] == [] + assert compiler_run.context_pack["semantic_artifact_chunk_summary"] == { + "requested": False, + "scope": None, + "embedding_config_id": None, + "query_vector_dimensions": 0, + "limit": 0, + "searched_artifact_count": 0, + "candidate_count": 0, + "included_count": 0, + "excluded_uningested_artifact_count": 0, + "excluded_limit_count": 0, + "similarity_metric": None, + "order": ["score_desc", "relative_path_asc", "sequence_no_asc", "id_asc"], + } assert compiler_run.context_pack["entities"] == [ { "id": str(kept_entity_id), @@ -681,6 +714,14 @@ def test_compile_continuity_context_records_included_and_excluded_reasons() -> N assert compiler_run.trace_events[-1].payload["included_artifact_chunk_count"] == 0 assert compiler_run.trace_events[-1].payload["excluded_artifact_chunk_limit_count"] == 0 assert compiler_run.trace_events[-1].payload["excluded_uningested_artifact_count"] == 0 + assert compiler_run.trace_events[-1].payload["semantic_artifact_retrieval_requested"] is False + assert compiler_run.trace_events[-1].payload["semantic_artifact_chunk_candidate_count"] == 0 + assert compiler_run.trace_events[-1].payload["included_semantic_artifact_chunk_count"] == 0 + assert ( + compiler_run.trace_events[-1].payload["excluded_semantic_artifact_chunk_limit_count"] + == 0 + ) + assert compiler_run.trace_events[-1].payload["excluded_semantic_uningested_artifact_count"] == 0 class SemanticCompileStoreStub: @@ -745,9 +786,15 @@ def list_memory_embeddings_for_config(self, embedding_config_id): class ArtifactCompileStoreStub: def __init__(self) -> None: self.base_time = datetime(2026, 3, 14, 12, 0, tzinfo=UTC) + self.config_id = uuid4() self.task_id = uuid4() self.artifact_ids = [uuid4(), uuid4(), uuid4(), uuid4()] - self.chunk_ids = [uuid4(), uuid4(), uuid4()] + self.chunk_ids = [uuid4(), uuid4(), uuid4(), uuid4()] + + def get_embedding_config_optional(self, embedding_config_id): + if embedding_config_id != self.config_id: + return None + return {"id": self.config_id, "dimensions": 3} def get_task_optional(self, task_id): if task_id != self.task_id: @@ -845,6 +892,97 @@ def list_task_artifact_chunks(self, task_artifact_id): ] return [] + 
def get_task_artifact_optional(self, task_artifact_id): + for artifact_row in self.list_task_artifacts_for_task(self.task_id): + if artifact_row["id"] == task_artifact_id: + return artifact_row + return None + + def retrieve_task_scoped_semantic_artifact_chunk_matches( + self, + *, + task_id, + embedding_config_id, + query_vector, + limit, + ): + assert task_id == self.task_id + assert embedding_config_id == self.config_id + assert query_vector == [1.0, 0.0, 0.0] + rows = [ + { + "id": self.chunk_ids[0], + "user_id": uuid4(), + "task_id": self.task_id, + "task_artifact_id": self.artifact_ids[0], + "relative_path": "docs/a.txt", + "media_type_hint": "text/plain", + "sequence_no": 1, + "char_start": 0, + "char_end_exclusive": 14, + "text": "beta alpha doc", + "created_at": self.base_time, + "updated_at": self.base_time, + "embedding_config_id": self.config_id, + "score": 1.0, + }, + { + "id": self.chunk_ids[1], + "user_id": uuid4(), + "task_id": self.task_id, + "task_artifact_id": self.artifact_ids[1], + "relative_path": "notes/b.md", + "media_type_hint": "text/markdown", + "sequence_no": 1, + "char_start": 0, + "char_end_exclusive": 15, + "text": "alpha beta note", + "created_at": self.base_time + timedelta(minutes=1), + "updated_at": self.base_time + timedelta(minutes=1), + "embedding_config_id": self.config_id, + "score": 1.0, + }, + { + "id": self.chunk_ids[3], + "user_id": uuid4(), + "task_id": self.task_id, + "task_artifact_id": self.artifact_ids[3], + "relative_path": "notes/c.txt", + "media_type_hint": "text/plain", + "sequence_no": 1, + "char_start": 0, + "char_end_exclusive": 9, + "text": "beta only", + "created_at": self.base_time + timedelta(minutes=3), + "updated_at": self.base_time + timedelta(minutes=3), + "embedding_config_id": self.config_id, + "score": 0.25, + }, + ] + return list(rows[:limit]) + + def retrieve_artifact_scoped_semantic_artifact_chunk_matches( + self, + *, + task_artifact_id, + embedding_config_id, + query_vector, + limit, + ): + assert embedding_config_id == self.config_id + assert query_vector == [1.0, 0.0, 0.0] + rows = [ + row + for row in self.retrieve_task_scoped_semantic_artifact_chunk_matches( + task_id=self.task_id, + embedding_config_id=embedding_config_id, + query_vector=query_vector, + limit=10, + ) + if row["task_artifact_id"] == task_artifact_id + ] + return list(rows[:limit]) + def test_compile_artifact_chunk_section_orders_limits_and_excludes_non_ingested() -> None: store = ArtifactCompileStoreStub() @@ -922,6 +1060,117 @@ def test_compile_artifact_chunk_section_orders_limits_and_excludes_non_ingested( assert artifact_section.decisions[-1].metadata["relative_path"] == "notes/c.txt" +def test_compile_semantic_artifact_chunk_section_orders_limits_and_excludes_non_ingested() -> None: + store = ArtifactCompileStoreStub() + + semantic_artifact_section = _compile_semantic_artifact_chunk_section( + store, # type: ignore[arg-type] + semantic_artifact_retrieval=CompileContextTaskScopedSemanticArtifactRetrievalInput( + task_id=store.task_id, + embedding_config_id=store.config_id, + query_vector=(1.0, 0.0, 0.0), + limit=2, + ), + ) + + assert semantic_artifact_section.items == [ + { + "id": str(store.chunk_ids[0]), + "task_id": str(store.task_id), + "task_artifact_id": str(store.artifact_ids[0]), + "relative_path": "docs/a.txt", + "media_type": "text/plain", + "sequence_no": 1, + "char_start": 0, + "char_end_exclusive": 14, + "text": "beta alpha doc", + "score": 1.0, + }, + { + "id": str(store.chunk_ids[1]), + "task_id": str(store.task_id), + 
"task_artifact_id": str(store.artifact_ids[1]), + "relative_path": "notes/b.md", + "media_type": "text/markdown", + "sequence_no": 1, + "char_start": 0, + "char_end_exclusive": 15, + "text": "alpha beta note", + "score": 1.0, + }, + ] + assert semantic_artifact_section.summary == { + "requested": True, + "scope": {"kind": "task", "task_id": str(store.task_id)}, + "embedding_config_id": str(store.config_id), + "query_vector_dimensions": 3, + "limit": 2, + "searched_artifact_count": 3, + "candidate_count": 3, + "included_count": 2, + "excluded_uningested_artifact_count": 1, + "excluded_limit_count": 1, + "similarity_metric": "cosine_similarity", + "order": ["score_desc", "relative_path_asc", "sequence_no_asc", "id_asc"], + } + assert [decision.reason for decision in semantic_artifact_section.decisions] == [ + "semantic_artifact_not_ingested", + "within_semantic_artifact_chunk_limit", + "within_semantic_artifact_chunk_limit", + "semantic_artifact_chunk_limit_exceeded", + ] + assert semantic_artifact_section.decisions[0].metadata["relative_path"] == "notes/hidden.txt" + assert semantic_artifact_section.decisions[-1].metadata["relative_path"] == "notes/c.txt" + + +def test_compile_semantic_artifact_chunk_section_supports_artifact_scope() -> None: + store = ArtifactCompileStoreStub() + + semantic_artifact_section = _compile_semantic_artifact_chunk_section( + store, # type: ignore[arg-type] + semantic_artifact_retrieval=CompileContextArtifactScopedSemanticArtifactRetrievalInput( + task_artifact_id=store.artifact_ids[1], + embedding_config_id=store.config_id, + query_vector=(1.0, 0.0, 0.0), + limit=2, + ), + ) + + assert semantic_artifact_section.items == [ + { + "id": str(store.chunk_ids[1]), + "task_id": str(store.task_id), + "task_artifact_id": str(store.artifact_ids[1]), + "relative_path": "notes/b.md", + "media_type": "text/markdown", + "sequence_no": 1, + "char_start": 0, + "char_end_exclusive": 15, + "text": "alpha beta note", + "score": 1.0, + } + ] + assert semantic_artifact_section.summary == { + "requested": True, + "scope": { + "kind": "artifact", + "task_id": str(store.task_id), + "task_artifact_id": str(store.artifact_ids[1]), + }, + "embedding_config_id": str(store.config_id), + "query_vector_dimensions": 3, + "limit": 2, + "searched_artifact_count": 1, + "candidate_count": 1, + "included_count": 1, + "excluded_uningested_artifact_count": 0, + "excluded_limit_count": 0, + "similarity_metric": "cosine_similarity", + "order": ["score_desc", "relative_path_asc", "sequence_no_asc", "id_asc"], + } + assert semantic_artifact_section.decisions[0].metadata["scope_kind"] == "artifact" + + def test_compile_memory_section_orders_limits_and_excludes_deleted() -> None: store = SemanticCompileStoreStub() deleted_memory = { diff --git a/tests/unit/test_main.py b/tests/unit/test_main.py index 0b3441d..c8e63b8 100644 --- a/tests/unit/test_main.py +++ b/tests/unit/test_main.py @@ -21,7 +21,10 @@ from alicebot_api.entity_edge import EntityEdgeValidationError from alicebot_api.memory import MemoryAdmissionValidationError, MemoryReviewNotFoundError from alicebot_api.response_generation import ResponseFailure -from alicebot_api.semantic_retrieval import SemanticMemoryRetrievalValidationError +from alicebot_api.semantic_retrieval import ( + SemanticArtifactChunkRetrievalValidationError, + SemanticMemoryRetrievalValidationError, +) from alicebot_api.store import ContinuityStoreInvariantError @@ -198,6 +201,7 @@ def fake_compile_and_persist_trace( limits, semantic_retrieval, artifact_retrieval, + 
semantic_artifact_retrieval, ): captured["store_type"] = type(store).__name__ captured["user_id"] = user_id @@ -205,6 +209,7 @@ def fake_compile_and_persist_trace( captured["limits"] = limits captured["semantic_retrieval"] = semantic_retrieval captured["artifact_retrieval"] = artifact_retrieval + captured["semantic_artifact_retrieval"] = semantic_artifact_retrieval return CompiledTraceRun( trace_id="trace-123", trace_event_count=5, @@ -288,6 +293,21 @@ def fake_compile_and_persist_trace( "id_asc", ], }, + "semantic_artifact_chunks": [], + "semantic_artifact_chunk_summary": { + "requested": False, + "scope": None, + "embedding_config_id": None, + "query_vector_dimensions": 0, + "limit": 0, + "searched_artifact_count": 0, + "candidate_count": 0, + "included_count": 0, + "excluded_uningested_artifact_count": 0, + "excluded_limit_count": 0, + "similarity_metric": None, + "order": ["score_desc", "relative_path_asc", "sequence_no_asc", "id_asc"], + }, "entities": [ { "id": "entity-123", @@ -423,6 +443,21 @@ def fake_compile_and_persist_trace( "id_asc", ], }, + "semantic_artifact_chunks": [], + "semantic_artifact_chunk_summary": { + "requested": False, + "scope": None, + "embedding_config_id": None, + "query_vector_dimensions": 0, + "limit": 0, + "searched_artifact_count": 0, + "candidate_count": 0, + "included_count": 0, + "excluded_uningested_artifact_count": 0, + "excluded_limit_count": 0, + "similarity_metric": None, + "order": ["score_desc", "relative_path_asc", "sequence_no_asc", "id_asc"], + }, "entities": [ { "id": "entity-123", @@ -468,6 +503,7 @@ def fake_compile_and_persist_trace( assert captured["limits"].max_entity_edges == 6 assert captured["semantic_retrieval"] is None assert captured["artifact_retrieval"] is None + assert captured["semantic_artifact_retrieval"] is None def test_compile_context_returns_not_found_when_scope_row_is_missing(monkeypatch) -> None: @@ -518,6 +554,7 @@ def fake_compile_and_persist_trace( limits, semantic_retrieval, artifact_retrieval, + semantic_artifact_retrieval, ): captured["store_type"] = type(store).__name__ captured["user_id"] = user_id @@ -525,6 +562,7 @@ def fake_compile_and_persist_trace( captured["limits"] = limits captured["semantic_retrieval"] = semantic_retrieval captured["artifact_retrieval"] = artifact_retrieval + captured["semantic_artifact_retrieval"] = semantic_artifact_retrieval return CompiledTraceRun( trace_id="trace-semantic", trace_event_count=7, @@ -628,6 +666,34 @@ def fake_compile_and_persist_trace( "id_asc", ], }, + "semantic_artifact_chunks": [ + { + "id": "semantic-chunk-123", + "task_id": "task-123", + "task_artifact_id": "artifact-123", + "relative_path": "docs/spec.txt", + "media_type": "text/plain", + "sequence_no": 1, + "char_start": 0, + "char_end_exclusive": 16, + "text": "alpha beta spec", + "score": 0.99, + } + ], + "semantic_artifact_chunk_summary": { + "requested": True, + "scope": {"kind": "task", "task_id": "task-123"}, + "embedding_config_id": str(config_id), + "query_vector_dimensions": 3, + "limit": 2, + "searched_artifact_count": 1, + "candidate_count": 1, + "included_count": 1, + "excluded_uningested_artifact_count": 0, + "excluded_limit_count": 0, + "similarity_metric": "cosine_similarity", + "order": ["score_desc", "relative_path_asc", "sequence_no_asc", "id_asc"], + }, "entities": [], "entity_summary": { "candidate_count": 0, @@ -663,6 +729,15 @@ def fake_compile_and_persist_trace( query="alpha beta", limit=2, ), + semantic_artifact_retrieval=( + 
main_module.CompileContextTaskScopedSemanticArtifactRetrievalRequest( + kind="task", + task_id=uuid4(), + embedding_config_id=config_id, + query_vector=[0.1, 0.2, 0.3], + limit=2, + ) + ), ) ) @@ -692,6 +767,10 @@ def fake_compile_and_persist_trace( assert captured["artifact_retrieval"].task_id is not None assert captured["artifact_retrieval"].query == "alpha beta" assert captured["artifact_retrieval"].limit == 2 + assert captured["semantic_artifact_retrieval"].task_id is not None + assert captured["semantic_artifact_retrieval"].embedding_config_id == config_id + assert captured["semantic_artifact_retrieval"].query_vector == (0.1, 0.2, 0.3) + assert captured["semantic_artifact_retrieval"].limit == 2 monkeypatch.setattr( main_module, @@ -720,6 +799,37 @@ def fake_compile_and_persist_trace( "detail": "embedding_config_id must reference an existing embedding config owned by the user" } + monkeypatch.setattr( + main_module, + "compile_and_persist_trace", + lambda *_args, **_kwargs: (_ for _ in ()).throw( + SemanticArtifactChunkRetrievalValidationError( + "query_vector length must match embedding config dimensions (3): 2" + ) + ), + ) + + semantic_artifact_error_response = main_module.compile_context( + main_module.CompileContextRequest( + user_id=user_id, + thread_id=thread_id, + semantic_artifact_retrieval=( + main_module.CompileContextTaskScopedSemanticArtifactRetrievalRequest( + kind="task", + task_id=uuid4(), + embedding_config_id=config_id, + query_vector=[0.1, 0.2], + limit=2, + ) + ), + ) + ) + + assert semantic_artifact_error_response.status_code == 400 + assert json.loads(semantic_artifact_error_response.body) == { + "detail": "query_vector length must match embedding config dimensions (3): 2" + } + def test_compile_context_request_rejects_invalid_artifact_scope_shape() -> None: with pytest.raises(Exception) as exc_info: @@ -736,6 +846,22 @@ def test_compile_context_request_rejects_invalid_artifact_scope_shape() -> None: assert "task_id" in str(exc_info.value) +def test_compile_context_request_rejects_invalid_semantic_artifact_scope_shape() -> None: + with pytest.raises(Exception) as exc_info: + main_module.CompileContextRequest( + user_id=uuid4(), + thread_id=uuid4(), + semantic_artifact_retrieval={ + "kind": "task", + "task_artifact_id": str(uuid4()), + "embedding_config_id": str(uuid4()), + "query_vector": [0.1, 0.2, 0.3], + }, + ) + + assert "task_id" in str(exc_info.value) + + def test_generate_assistant_response_returns_assistant_and_trace_payload(monkeypatch) -> None: user_id = uuid4() thread_id = uuid4() diff --git a/tests/unit/test_response_generation.py b/tests/unit/test_response_generation.py index f91c051..59cbd40 100644 --- a/tests/unit/test_response_generation.py +++ b/tests/unit/test_response_generation.py @@ -88,6 +88,42 @@ def make_context_pack() -> dict[str, object]: "semantic_order": ["score_desc", "created_at_asc", "id_asc"], }, }, + "artifact_chunks": [], + "artifact_chunk_summary": { + "requested": False, + "scope": None, + "query": None, + "query_terms": [], + "matching_rule": "casefolded_unicode_word_overlap_unique_query_terms_v1", + "limit": 0, + "searched_artifact_count": 0, + "candidate_count": 0, + "included_count": 0, + "excluded_uningested_artifact_count": 0, + "excluded_limit_count": 0, + "order": [ + "matched_query_term_count_desc", + "first_match_char_start_asc", + "relative_path_asc", + "sequence_no_asc", + "id_asc", + ], + }, + "semantic_artifact_chunks": [], + "semantic_artifact_chunk_summary": { + "requested": False, + "scope": None, + 
"embedding_config_id": None, + "query_vector_dimensions": 0, + "limit": 0, + "searched_artifact_count": 0, + "candidate_count": 0, + "included_count": 0, + "excluded_uningested_artifact_count": 0, + "excluded_limit_count": 0, + "similarity_metric": None, + "order": ["score_desc", "relative_path_asc", "sequence_no_asc", "id_asc"], + }, "entities": [], "entity_summary": { "candidate_count": 0, From 9d7451e743e4be67b912222cad5662ffbefda700 Mon Sep 17 00:00:00 2001 From: Sam Date: Mon, 16 Mar 2026 08:54:45 +0100 Subject: [PATCH 010/135] Sprint 5J: deterministic hybrid artifact merge (#10) * Sprint 5J: hybrid artifact merge packet * Sprint 5J: deterministic hybrid artifact merge --------- Co-authored-by: Sami Rusani --- apps/api/src/alicebot_api/compiler.py | 763 +++++++++++------- apps/api/src/alicebot_api/contracts.py | 67 +- .../src/alicebot_api/response_generation.py | 2 - tests/integration/test_context_compile.py | 534 ++++++++++-- tests/unit/test_compiler.py | 298 +++++-- tests/unit/test_main.py | 157 ++-- tests/unit/test_response_generation.py | 44 +- 7 files changed, 1286 insertions(+), 579 deletions(-) diff --git a/apps/api/src/alicebot_api/compiler.py b/apps/api/src/alicebot_api/compiler.py index 46a89b1..e8eac76 100644 --- a/apps/api/src/alicebot_api/compiler.py +++ b/apps/api/src/alicebot_api/compiler.py @@ -5,7 +5,7 @@ from alicebot_api.contracts import ( COMPILER_VERSION_V0, - ArtifactRetrievalDecisionTracePayload, + ArtifactSelectionSource, CompileContextArtifactScopedSemanticArtifactRetrievalInput, CompilerDecision, CompileContextArtifactRetrievalInput, @@ -22,14 +22,12 @@ ContextPackHybridMemorySummary, ContextPackMemory, ContextPackMemorySummary, - ContextPackSemanticArtifactChunk, - ContextPackSemanticArtifactChunkSummary, HybridMemoryDecisionTracePayload, + HybridArtifactRetrievalDecisionTracePayload, MemorySelectionSource, SEMANTIC_MEMORY_RETRIEVAL_ORDER, TASK_ARTIFACT_CHUNK_RETRIEVAL_ORDER, TASK_ARTIFACT_CHUNK_SEMANTIC_RETRIEVAL_ORDER, - SemanticArtifactRetrievalDecisionTracePayload, SemanticMemoryRetrievalRequestInput, TRACE_KIND_CONTEXT_COMPILE, TraceEventRecord, @@ -44,8 +42,6 @@ retrieve_matching_task_artifact_chunks, ) from alicebot_api.semantic_retrieval import ( - retrieve_artifact_scoped_semantic_artifact_chunk_records, - retrieve_task_scoped_semantic_artifact_chunk_records, serialize_semantic_artifact_chunk_result_item, validate_semantic_artifact_chunk_retrieval_request, validate_semantic_memory_retrieval_request, @@ -68,6 +64,15 @@ _UNBOUNDED_SEMANTIC_ARTIFACT_RETRIEVAL_LIMIT = 2_147_483_647 HYBRID_MEMORY_SOURCE_PRECEDENCE: list[MemorySelectionSource] = ["symbolic", "semantic"] HYBRID_SYMBOLIC_ORDER = ["updated_at_asc", "created_at_asc", "id_asc"] +HYBRID_ARTIFACT_SOURCE_PRECEDENCE: list[ArtifactSelectionSource] = ["lexical", "semantic"] +HYBRID_ARTIFACT_MERGED_ORDER = [ + "source_precedence_asc", + "lexical_rank_asc", + "semantic_rank_asc", + "relative_path_asc", + "sequence_no_asc", + "id_asc", +] @dataclass(frozen=True, slots=True) @@ -91,13 +96,6 @@ class CompiledArtifactChunkSection: decisions: list[CompilerDecision] -@dataclass(frozen=True, slots=True) -class CompiledSemanticArtifactChunkSection: - items: list[ContextPackSemanticArtifactChunk] - summary: ContextPackSemanticArtifactChunkSummary - decisions: list[CompilerDecision] - - @dataclass(slots=True) class HybridMemoryCandidate: memory: MemoryRow @@ -105,6 +103,14 @@ class HybridMemoryCandidate: semantic_score: float | None = None +@dataclass(slots=True) +class HybridArtifactChunkCandidate: + item: 
ContextPackArtifactChunk + sources: list[ArtifactSelectionSource] + lexical_rank: int | None = None + semantic_rank: int | None = None + + def _session_sort_key( session: SessionRow, latest_session_sequence: dict[UUID, int], @@ -244,38 +250,37 @@ def _empty_hybrid_memory_summary() -> ContextPackHybridMemorySummary: def _empty_artifact_chunk_summary() -> ContextPackArtifactChunkSummary: return { "requested": False, + "lexical_requested": False, + "semantic_requested": False, "scope": None, "query": None, "query_terms": [], - "matching_rule": TASK_ARTIFACT_CHUNK_RETRIEVAL_MATCHING_RULE, - "limit": 0, - "searched_artifact_count": 0, - "candidate_count": 0, - "included_count": 0, - "excluded_uningested_artifact_count": 0, - "excluded_limit_count": 0, - "order": list(TASK_ARTIFACT_CHUNK_RETRIEVAL_ORDER), - } - - -def _empty_semantic_artifact_chunk_summary() -> ContextPackSemanticArtifactChunkSummary: - return { - "requested": False, - "scope": None, "embedding_config_id": None, "query_vector_dimensions": 0, "limit": 0, + "lexical_limit": 0, + "semantic_limit": 0, "searched_artifact_count": 0, - "candidate_count": 0, + "lexical_candidate_count": 0, + "semantic_candidate_count": 0, + "merged_candidate_count": 0, + "deduplicated_count": 0, "included_count": 0, + "included_lexical_only_count": 0, + "included_semantic_only_count": 0, + "included_dual_source_count": 0, "excluded_uningested_artifact_count": 0, "excluded_limit_count": 0, + "matching_rule": None, "similarity_metric": None, - "order": list(TASK_ARTIFACT_CHUNK_SEMANTIC_RETRIEVAL_ORDER), + "source_precedence": list(HYBRID_ARTIFACT_SOURCE_PRECEDENCE), + "lexical_order": list(TASK_ARTIFACT_CHUNK_RETRIEVAL_ORDER), + "semantic_order": list(TASK_ARTIFACT_CHUNK_SEMANTIC_RETRIEVAL_ORDER), + "merged_order": list(HYBRID_ARTIFACT_MERGED_ORDER), } -def _artifact_retrieval_decision_metadata( +def _hybrid_artifact_retrieval_decision_metadata( *, scope_kind: str, task_id: UUID, @@ -284,63 +289,34 @@ def _artifact_retrieval_decision_metadata( media_type: str | None, ingestion_status: str, limit: int, + selected_sources: list[ArtifactSelectionSource], + embedding_config_id: UUID | None = None, + query_vector_dimensions: int = 0, match: dict[str, object] | None = None, + score: float | None = None, sequence_no: int | None = None, char_start: int | None = None, char_end_exclusive: int | None = None, -) -> ArtifactRetrievalDecisionTracePayload: - payload: ArtifactRetrievalDecisionTracePayload = { +) -> HybridArtifactRetrievalDecisionTracePayload: + payload: HybridArtifactRetrievalDecisionTracePayload = { "scope_kind": scope_kind, # type: ignore[typeddict-item] "task_id": str(task_id), "task_artifact_id": str(task_artifact_id), "relative_path": relative_path, "media_type": media_type, "ingestion_status": ingestion_status, # type: ignore[typeddict-item] + "selected_sources": list(selected_sources), + "embedding_config_id": None if embedding_config_id is None else str(embedding_config_id), + "query_vector_dimensions": query_vector_dimensions, "limit": limit, } if match is not None: payload["matched_query_terms"] = list(match["matched_query_terms"]) # type: ignore[index] payload["matched_query_term_count"] = int(match["matched_query_term_count"]) # type: ignore[index] payload["first_match_char_start"] = int(match["first_match_char_start"]) # type: ignore[index] - if sequence_no is not None: - payload["sequence_no"] = sequence_no - if char_start is not None: - payload["char_start"] = char_start - if char_end_exclusive is not None: - payload["char_end_exclusive"] 
= char_end_exclusive - return payload - - -def _semantic_artifact_retrieval_decision_metadata( - *, - scope_kind: str, - task_id: UUID, - task_artifact_id: UUID, - relative_path: str, - media_type: str | None, - ingestion_status: str, - embedding_config_id: UUID, - query_vector_dimensions: int, - limit: int, - score: float | None = None, - sequence_no: int | None = None, - char_start: int | None = None, - char_end_exclusive: int | None = None, -) -> SemanticArtifactRetrievalDecisionTracePayload: - payload: SemanticArtifactRetrievalDecisionTracePayload = { - "scope_kind": scope_kind, # type: ignore[typeddict-item] - "task_id": str(task_id), - "task_artifact_id": str(task_artifact_id), - "relative_path": relative_path, - "media_type": media_type, - "ingestion_status": ingestion_status, # type: ignore[typeddict-item] - "embedding_config_id": str(embedding_config_id), - "query_vector_dimensions": query_vector_dimensions, - "limit": limit, - "similarity_metric": "cosine_similarity", - } if score is not None: payload["score"] = score + payload["similarity_metric"] = "cosine_similarity" if sequence_no is not None: payload["sequence_no"] = sequence_no if char_start is not None: @@ -386,6 +362,110 @@ def _serialize_hybrid_memory(candidate: HybridMemoryCandidate) -> ContextPackMem } +def _serialize_hybrid_artifact_chunk(candidate: HybridArtifactChunkCandidate) -> ContextPackArtifactChunk: + item = candidate.item + return { + "id": item["id"], + "task_id": item["task_id"], + "task_artifact_id": item["task_artifact_id"], + "relative_path": item["relative_path"], + "media_type": item["media_type"], + "sequence_no": item["sequence_no"], + "char_start": item["char_start"], + "char_end_exclusive": item["char_end_exclusive"], + "text": item["text"], + "source_provenance": { + "sources": list(candidate.sources), + "lexical_match": item["source_provenance"]["lexical_match"], + "semantic_score": item["source_provenance"]["semantic_score"], + }, + } + + +def _resolve_artifact_scope( + store: ContinuityStore, + *, + artifact_retrieval: CompileContextArtifactRetrievalInput | None, + semantic_artifact_retrieval: CompileContextSemanticArtifactRetrievalInput | None, +) -> tuple[list[dict[str, object]], dict[str, str] | None, str | None]: + lexical_scope: tuple[list[dict[str, object]], dict[str, str], str] | None = None + semantic_scope: tuple[list[dict[str, object]], dict[str, str], str] | None = None + + if isinstance(artifact_retrieval, CompileContextTaskScopedArtifactRetrievalInput): + task = store.get_task_optional(artifact_retrieval.task_id) + if task is None: + raise TaskNotFoundError(f"task {artifact_retrieval.task_id} was not found") + lexical_scope = ( + store.list_task_artifacts_for_task(artifact_retrieval.task_id), + build_task_artifact_chunk_retrieval_scope( + kind="task", + task_id=artifact_retrieval.task_id, + ), + "task", + ) + elif isinstance(artifact_retrieval, CompileContextArtifactScopedArtifactRetrievalInput): + artifact_row = store.get_task_artifact_optional(artifact_retrieval.task_artifact_id) + if artifact_row is None: + raise TaskArtifactNotFoundError( + f"task artifact {artifact_retrieval.task_artifact_id} was not found" + ) + lexical_scope = ( + [artifact_row], + build_task_artifact_chunk_retrieval_scope( + kind="artifact", + task_id=artifact_row["task_id"], + task_artifact_id=artifact_row["id"], + ), + "artifact", + ) + + if isinstance( + semantic_artifact_retrieval, + CompileContextTaskScopedSemanticArtifactRetrievalInput, + ): + task = 
store.get_task_optional(semantic_artifact_retrieval.task_id) + if task is None: + raise TaskNotFoundError(f"task {semantic_artifact_retrieval.task_id} was not found") + semantic_scope = ( + store.list_task_artifacts_for_task(semantic_artifact_retrieval.task_id), + build_task_artifact_chunk_retrieval_scope( + kind="task", + task_id=semantic_artifact_retrieval.task_id, + ), + "task", + ) + elif isinstance( + semantic_artifact_retrieval, + CompileContextArtifactScopedSemanticArtifactRetrievalInput, + ): + artifact_row = store.get_task_artifact_optional( + semantic_artifact_retrieval.task_artifact_id + ) + if artifact_row is None: + raise TaskArtifactNotFoundError( + f"task artifact {semantic_artifact_retrieval.task_artifact_id} was not found" + ) + semantic_scope = ( + [artifact_row], + build_task_artifact_chunk_retrieval_scope( + kind="artifact", + task_id=artifact_row["task_id"], + task_artifact_id=artifact_row["id"], + ), + "artifact", + ) + + if lexical_scope is not None and semantic_scope is not None and lexical_scope[1] != semantic_scope[1]: + raise TaskArtifactChunkRetrievalValidationError( + "artifact_retrieval and semantic_artifact_retrieval must target the same scope" + ) + + resolved_scope = lexical_scope or semantic_scope + if resolved_scope is None: + return [], None, None + return resolved_scope + + def _build_symbolic_memory_section( *, memories: list[MemoryRow], @@ -649,194 +729,145 @@ def _compile_artifact_chunk_section( store: ContinuityStore, *, artifact_retrieval: CompileContextArtifactRetrievalInput | None, + semantic_artifact_retrieval: CompileContextSemanticArtifactRetrievalInput | None, ) -> CompiledArtifactChunkSection: - if artifact_retrieval is None: + if artifact_retrieval is None and semantic_artifact_retrieval is None: return CompiledArtifactChunkSection( items=[], summary=_empty_artifact_chunk_summary(), decisions=[], ) - if isinstance(artifact_retrieval, CompileContextTaskScopedArtifactRetrievalInput): - task = store.get_task_optional(artifact_retrieval.task_id) - if task is None: - raise TaskNotFoundError(f"task {artifact_retrieval.task_id} was not found") - artifact_rows = store.list_task_artifacts_for_task(artifact_retrieval.task_id) - scope = build_task_artifact_chunk_retrieval_scope( - kind="task", - task_id=artifact_retrieval.task_id, - ) - scope_kind = "task" - else: - artifact_row = store.get_task_artifact_optional(artifact_retrieval.task_artifact_id) - if artifact_row is None: - raise TaskArtifactNotFoundError( - f"task artifact {artifact_retrieval.task_artifact_id} was not found" - ) - artifact_rows = [artifact_row] - scope = build_task_artifact_chunk_retrieval_scope( - kind="artifact", - task_id=artifact_row["task_id"], - task_artifact_id=artifact_row["id"], - ) - scope_kind = "artifact" - - query_terms = resolve_artifact_chunk_retrieval_query_terms(artifact_retrieval.query) - matched_items, searched_artifact_count = retrieve_matching_task_artifact_chunks( + artifact_rows, scope, scope_kind = _resolve_artifact_scope( store, - artifact_rows=artifact_rows, - query_terms=query_terms, + artifact_retrieval=artifact_retrieval, + semantic_artifact_retrieval=semantic_artifact_retrieval, ) - included_items = matched_items[: artifact_retrieval.limit] - excluded_uningested_artifact_count = 0 - decisions: list[CompilerDecision] = [] - - for position, artifact_row in enumerate(artifact_rows, start=1): - if artifact_row["ingestion_status"] == "ingested": - continue - excluded_uningested_artifact_count += 1 - decisions.append( - CompilerDecision( - "excluded", - 
"task_artifact", - artifact_row["id"], - "artifact_not_ingested", - position, - metadata=_artifact_retrieval_decision_metadata( - scope_kind=scope_kind, - task_id=artifact_row["task_id"], - task_artifact_id=artifact_row["id"], - relative_path=artifact_row["relative_path"], - media_type=infer_task_artifact_media_type(artifact_row), - ingestion_status=artifact_row["ingestion_status"], - limit=artifact_retrieval.limit, - ), - ) - ) - - for position, item in enumerate(matched_items, start=1): - decision_kind = "included" if position <= artifact_retrieval.limit else "excluded" - decision_reason = ( - "within_artifact_chunk_limit" - if position <= artifact_retrieval.limit - else "artifact_chunk_limit_exceeded" - ) - decisions.append( - CompilerDecision( - decision_kind, - "artifact_chunk", - UUID(item["id"]), - decision_reason, - position, - metadata=_artifact_retrieval_decision_metadata( - scope_kind=scope_kind, - task_id=UUID(item["task_id"]), - task_artifact_id=UUID(item["task_artifact_id"]), - relative_path=item["relative_path"], - media_type=item["media_type"], - ingestion_status="ingested", - limit=artifact_retrieval.limit, - match=item["match"], - sequence_no=item["sequence_no"], - char_start=item["char_start"], - char_end_exclusive=item["char_end_exclusive"], - ), - ) - ) - - return CompiledArtifactChunkSection( - items=list(included_items), - summary={ - "requested": True, - "scope": scope, - "query": artifact_retrieval.query, - "query_terms": list(query_terms), - "matching_rule": TASK_ARTIFACT_CHUNK_RETRIEVAL_MATCHING_RULE, - "limit": artifact_retrieval.limit, - "searched_artifact_count": searched_artifact_count, - "candidate_count": len(matched_items), - "included_count": len(included_items), - "excluded_uningested_artifact_count": excluded_uningested_artifact_count, - "excluded_limit_count": max(len(matched_items) - len(included_items), 0), - "order": list(TASK_ARTIFACT_CHUNK_RETRIEVAL_ORDER), - }, - decisions=decisions, + assert scope is not None + assert scope_kind is not None + + query = None if artifact_retrieval is None else artifact_retrieval.query + query_terms: list[str] = [] + lexical_items: list[ContextPackArtifactChunk] = [] + searched_artifact_count = sum( + 1 for artifact_row in artifact_rows if artifact_row["ingestion_status"] == "ingested" ) - - -def _compile_semantic_artifact_chunk_section( - store: ContinuityStore, - *, - semantic_artifact_retrieval: CompileContextSemanticArtifactRetrievalInput | None, -) -> CompiledSemanticArtifactChunkSection: - if semantic_artifact_retrieval is None: - return CompiledSemanticArtifactChunkSection( - items=[], - summary=_empty_semantic_artifact_chunk_summary(), - decisions=[], + if artifact_retrieval is not None: + query_terms = resolve_artifact_chunk_retrieval_query_terms(artifact_retrieval.query) + lexical_matches, searched_artifact_count = retrieve_matching_task_artifact_chunks( + store, + artifact_rows=artifact_rows, + query_terms=query_terms, ) + lexical_items = [ + { + "id": item["id"], + "task_id": item["task_id"], + "task_artifact_id": item["task_artifact_id"], + "relative_path": item["relative_path"], + "media_type": item["media_type"], + "sequence_no": item["sequence_no"], + "char_start": item["char_start"], + "char_end_exclusive": item["char_end_exclusive"], + "text": item["text"], + "source_provenance": { + "sources": ["lexical"], + "lexical_match": item["match"], + "semantic_score": None, + }, + } + for item in lexical_matches + ] + semantic_items: list[ContextPackArtifactChunk] = [] + query_vector_dimensions = 0 if 
isinstance( semantic_artifact_retrieval, CompileContextTaskScopedSemanticArtifactRetrievalInput, ): - task = store.get_task_optional(semantic_artifact_retrieval.task_id) - if task is None: - raise TaskNotFoundError(f"task {semantic_artifact_retrieval.task_id} was not found") - artifact_rows = store.list_task_artifacts_for_task(semantic_artifact_retrieval.task_id) - scope_kind = "task" - section_payload = retrieve_task_scoped_semantic_artifact_chunk_records( - store, - user_id=task["id"], - request=semantic_artifact_retrieval, - ) _config, query_vector = validate_semantic_artifact_chunk_retrieval_request( store, embedding_config_id=semantic_artifact_retrieval.embedding_config_id, query_vector=semantic_artifact_retrieval.query_vector, ) - matched_items = [ - serialize_semantic_artifact_chunk_result_item(row) - for row in store.retrieve_task_scoped_semantic_artifact_chunk_matches( - task_id=semantic_artifact_retrieval.task_id, - embedding_config_id=semantic_artifact_retrieval.embedding_config_id, - query_vector=query_vector, - limit=_UNBOUNDED_SEMANTIC_ARTIFACT_RETRIEVAL_LIMIT, - ) + query_vector_dimensions = len(query_vector) + semantic_items = [ + { + "id": item["id"], + "task_id": item["task_id"], + "task_artifact_id": item["task_artifact_id"], + "relative_path": item["relative_path"], + "media_type": item["media_type"], + "sequence_no": item["sequence_no"], + "char_start": item["char_start"], + "char_end_exclusive": item["char_end_exclusive"], + "text": item["text"], + "source_provenance": { + "sources": ["semantic"], + "lexical_match": None, + "semantic_score": item["score"], + }, + } + for item in [ + serialize_semantic_artifact_chunk_result_item(row) + for row in store.retrieve_task_scoped_semantic_artifact_chunk_matches( + task_id=semantic_artifact_retrieval.task_id, + embedding_config_id=semantic_artifact_retrieval.embedding_config_id, + query_vector=query_vector, + limit=_UNBOUNDED_SEMANTIC_ARTIFACT_RETRIEVAL_LIMIT, + ) + ] ] - else: - artifact_row = store.get_task_artifact_optional( - semantic_artifact_retrieval.task_artifact_id - ) - if artifact_row is None: - raise TaskArtifactNotFoundError( - f"task artifact {semantic_artifact_retrieval.task_artifact_id} was not found" - ) - artifact_rows = [artifact_row] - scope_kind = "artifact" - section_payload = retrieve_artifact_scoped_semantic_artifact_chunk_records( - store, - user_id=artifact_row["task_id"], - request=semantic_artifact_retrieval, - ) + elif isinstance( + semantic_artifact_retrieval, + CompileContextArtifactScopedSemanticArtifactRetrievalInput, + ): _config, query_vector = validate_semantic_artifact_chunk_retrieval_request( store, embedding_config_id=semantic_artifact_retrieval.embedding_config_id, query_vector=semantic_artifact_retrieval.query_vector, ) - matched_items = [ - serialize_semantic_artifact_chunk_result_item(row) - for row in store.retrieve_artifact_scoped_semantic_artifact_chunk_matches( - task_artifact_id=semantic_artifact_retrieval.task_artifact_id, - embedding_config_id=semantic_artifact_retrieval.embedding_config_id, - query_vector=query_vector, - limit=_UNBOUNDED_SEMANTIC_ARTIFACT_RETRIEVAL_LIMIT, - ) + query_vector_dimensions = len(query_vector) + semantic_items = [ + { + "id": item["id"], + "task_id": item["task_id"], + "task_artifact_id": item["task_artifact_id"], + "relative_path": item["relative_path"], + "media_type": item["media_type"], + "sequence_no": item["sequence_no"], + "char_start": item["char_start"], + "char_end_exclusive": item["char_end_exclusive"], + "text": item["text"], + 
"source_provenance": { + "sources": ["semantic"], + "lexical_match": None, + "semantic_score": item["score"], + }, + } + for item in [ + serialize_semantic_artifact_chunk_result_item(row) + for row in store.retrieve_artifact_scoped_semantic_artifact_chunk_matches( + task_artifact_id=semantic_artifact_retrieval.task_artifact_id, + embedding_config_id=semantic_artifact_retrieval.embedding_config_id, + query_vector=query_vector, + limit=_UNBOUNDED_SEMANTIC_ARTIFACT_RETRIEVAL_LIMIT, + ) + ] ] - included_items = list(section_payload["items"]) + merged_candidates: list[HybridArtifactChunkCandidate] = [] + merged_candidates_by_id: dict[str, HybridArtifactChunkCandidate] = {} + deduplicated_count = 0 excluded_uningested_artifact_count = 0 decisions: list[CompilerDecision] = [] + final_limit = ( + artifact_retrieval.limit + if artifact_retrieval is not None + else semantic_artifact_retrieval.limit + if semantic_artifact_retrieval is not None + else 0 + ) for position, artifact_row in enumerate(artifact_rows, start=1): if artifact_row["ingestion_status"] == "ingested": @@ -847,70 +878,216 @@ def _compile_semantic_artifact_chunk_section( "excluded", "task_artifact", artifact_row["id"], - "semantic_artifact_not_ingested", + "hybrid_artifact_not_ingested", position, - metadata=_semantic_artifact_retrieval_decision_metadata( + metadata=_hybrid_artifact_retrieval_decision_metadata( scope_kind=scope_kind, task_id=artifact_row["task_id"], task_artifact_id=artifact_row["id"], relative_path=artifact_row["relative_path"], media_type=infer_task_artifact_media_type(artifact_row), ingestion_status=artifact_row["ingestion_status"], + limit=final_limit, + selected_sources=[], + embedding_config_id=( + None + if semantic_artifact_retrieval is None + else semantic_artifact_retrieval.embedding_config_id + ), + query_vector_dimensions=query_vector_dimensions, + ), + ) + ) + + for lexical_rank, item in enumerate(lexical_items, start=1): + candidate = HybridArtifactChunkCandidate( + item=item, + sources=["lexical"], + lexical_rank=lexical_rank, + ) + merged_candidates.append(candidate) + merged_candidates_by_id[item["id"]] = candidate + + for semantic_rank, item in enumerate(semantic_items, start=1): + existing_candidate = merged_candidates_by_id.get(item["id"]) + if existing_candidate is None: + candidate = HybridArtifactChunkCandidate( + item=item, + sources=["semantic"], + semantic_rank=semantic_rank, + ) + merged_candidates.append(candidate) + merged_candidates_by_id[item["id"]] = candidate + continue + + deduplicated_count += 1 + if "semantic" not in existing_candidate.sources: + existing_candidate.sources.append("semantic") + existing_candidate.semantic_rank = semantic_rank + existing_candidate.item["source_provenance"]["semantic_score"] = item["source_provenance"][ + "semantic_score" + ] + decisions.append( + CompilerDecision( + "included", + "artifact_chunk", + UUID(existing_candidate.item["id"]), + "hybrid_artifact_chunk_deduplicated", + semantic_rank, + metadata=_hybrid_artifact_retrieval_decision_metadata( + scope_kind=scope_kind, + task_id=UUID(existing_candidate.item["task_id"]), + task_artifact_id=UUID(existing_candidate.item["task_artifact_id"]), + relative_path=existing_candidate.item["relative_path"], + media_type=existing_candidate.item["media_type"], + ingestion_status="ingested", + limit=final_limit, + selected_sources=existing_candidate.sources, embedding_config_id=semantic_artifact_retrieval.embedding_config_id, - query_vector_dimensions=len(query_vector), - limit=semantic_artifact_retrieval.limit, 
+ query_vector_dimensions=query_vector_dimensions, + match=existing_candidate.item["source_provenance"]["lexical_match"], + score=existing_candidate.item["source_provenance"]["semantic_score"], + sequence_no=existing_candidate.item["sequence_no"], + char_start=existing_candidate.item["char_start"], + char_end_exclusive=existing_candidate.item["char_end_exclusive"], ), ) ) - for position, item in enumerate(matched_items, start=1): - decision_kind = "included" if position <= semantic_artifact_retrieval.limit else "excluded" - decision_reason = ( - "within_semantic_artifact_chunk_limit" - if position <= semantic_artifact_retrieval.limit - else "semantic_artifact_chunk_limit_exceeded" + merged_candidates.sort( + key=lambda candidate: ( + min( + HYBRID_ARTIFACT_SOURCE_PRECEDENCE.index(source) + for source in candidate.sources + ), + candidate.lexical_rank if candidate.lexical_rank is not None else 2_147_483_647, + candidate.semantic_rank if candidate.semantic_rank is not None else 2_147_483_647, + candidate.item["relative_path"], + candidate.item["sequence_no"], + candidate.item["id"], ) + ) + + included_candidates = merged_candidates[:final_limit] if final_limit > 0 else [] + excluded_candidates = merged_candidates[final_limit:] if final_limit > 0 else merged_candidates + included_lexical_only_count = 0 + included_semantic_only_count = 0 + included_dual_source_count = 0 + + for position, candidate in enumerate(merged_candidates, start=1): + if position <= final_limit and final_limit > 0: + if candidate.sources == ["lexical"]: + included_lexical_only_count += 1 + elif candidate.sources == ["semantic"]: + included_semantic_only_count += 1 + else: + included_dual_source_count += 1 + decisions.append( + CompilerDecision( + "included", + "artifact_chunk", + UUID(candidate.item["id"]), + "within_hybrid_artifact_chunk_limit", + position, + metadata=_hybrid_artifact_retrieval_decision_metadata( + scope_kind=scope_kind, + task_id=UUID(candidate.item["task_id"]), + task_artifact_id=UUID(candidate.item["task_artifact_id"]), + relative_path=candidate.item["relative_path"], + media_type=candidate.item["media_type"], + ingestion_status="ingested", + limit=final_limit, + selected_sources=candidate.sources, + embedding_config_id=( + None + if semantic_artifact_retrieval is None + else semantic_artifact_retrieval.embedding_config_id + ), + query_vector_dimensions=query_vector_dimensions, + match=candidate.item["source_provenance"]["lexical_match"], + score=candidate.item["source_provenance"]["semantic_score"], + sequence_no=candidate.item["sequence_no"], + char_start=candidate.item["char_start"], + char_end_exclusive=candidate.item["char_end_exclusive"], + ), + ) + ) + continue + decisions.append( CompilerDecision( - decision_kind, - "semantic_artifact_chunk", - UUID(item["id"]), - decision_reason, + "excluded", + "artifact_chunk", + UUID(candidate.item["id"]), + "hybrid_artifact_chunk_limit_exceeded", position, - metadata=_semantic_artifact_retrieval_decision_metadata( + metadata=_hybrid_artifact_retrieval_decision_metadata( scope_kind=scope_kind, - task_id=UUID(item["task_id"]), - task_artifact_id=UUID(item["task_artifact_id"]), - relative_path=item["relative_path"], - media_type=item["media_type"], + task_id=UUID(candidate.item["task_id"]), + task_artifact_id=UUID(candidate.item["task_artifact_id"]), + relative_path=candidate.item["relative_path"], + media_type=candidate.item["media_type"], ingestion_status="ingested", - embedding_config_id=semantic_artifact_retrieval.embedding_config_id, - 
query_vector_dimensions=len(query_vector), - limit=semantic_artifact_retrieval.limit, - score=item["score"], - sequence_no=item["sequence_no"], - char_start=item["char_start"], - char_end_exclusive=item["char_end_exclusive"], + limit=final_limit, + selected_sources=candidate.sources, + embedding_config_id=( + None + if semantic_artifact_retrieval is None + else semantic_artifact_retrieval.embedding_config_id + ), + query_vector_dimensions=query_vector_dimensions, + match=candidate.item["source_provenance"]["lexical_match"], + score=candidate.item["source_provenance"]["semantic_score"], + sequence_no=candidate.item["sequence_no"], + char_start=candidate.item["char_start"], + char_end_exclusive=candidate.item["char_end_exclusive"], ), ) ) - section_summary = section_payload["summary"] - return CompiledSemanticArtifactChunkSection( - items=included_items, + return CompiledArtifactChunkSection( + items=[_serialize_hybrid_artifact_chunk(candidate) for candidate in included_candidates], summary={ "requested": True, - "scope": section_summary["scope"], - "embedding_config_id": section_summary["embedding_config_id"], - "query_vector_dimensions": section_summary["query_vector_dimensions"], - "limit": section_summary["limit"], - "searched_artifact_count": section_summary["searched_artifact_count"], - "candidate_count": len(matched_items), - "included_count": len(included_items), + "lexical_requested": artifact_retrieval is not None, + "semantic_requested": semantic_artifact_retrieval is not None, + "scope": scope, + "query": query, + "query_terms": list(query_terms), + "embedding_config_id": ( + None + if semantic_artifact_retrieval is None + else str(semantic_artifact_retrieval.embedding_config_id) + ), + "query_vector_dimensions": query_vector_dimensions, + "limit": final_limit, + "lexical_limit": 0 if artifact_retrieval is None else artifact_retrieval.limit, + "semantic_limit": ( + 0 if semantic_artifact_retrieval is None else semantic_artifact_retrieval.limit + ), + "searched_artifact_count": searched_artifact_count, + "lexical_candidate_count": len(lexical_items), + "semantic_candidate_count": len(semantic_items), + "merged_candidate_count": len(merged_candidates), + "deduplicated_count": deduplicated_count, + "included_count": len(included_candidates), + "included_lexical_only_count": included_lexical_only_count, + "included_semantic_only_count": included_semantic_only_count, + "included_dual_source_count": included_dual_source_count, "excluded_uningested_artifact_count": excluded_uningested_artifact_count, - "excluded_limit_count": max(len(matched_items) - len(included_items), 0), - "similarity_metric": section_summary["similarity_metric"], - "order": list(section_summary["order"]), + "excluded_limit_count": len(excluded_candidates), + "matching_rule": ( + None + if artifact_retrieval is None + else TASK_ARTIFACT_CHUNK_RETRIEVAL_MATCHING_RULE + ), + "similarity_metric": ( + None if semantic_artifact_retrieval is None else "cosine_similarity" + ), + "source_precedence": list(HYBRID_ARTIFACT_SOURCE_PRECEDENCE), + "lexical_order": list(TASK_ARTIFACT_CHUNK_RETRIEVAL_ORDER), + "semantic_order": list(TASK_ARTIFACT_CHUNK_SEMANTIC_RETRIEVAL_ORDER), + "merged_order": list(HYBRID_ARTIFACT_MERGED_ORDER), }, decisions=decisions, ) @@ -928,7 +1105,6 @@ def compile_continuity_context( limits: ContextCompilerLimits, memory_section: CompiledMemorySection | None = None, artifact_chunk_section: CompiledArtifactChunkSection | None = None, - semantic_artifact_chunk_section: CompiledSemanticArtifactChunkSection | 
None = None, ) -> CompilerRunResult: latest_session_sequence: dict[UUID, int] = {} for event in events: @@ -1027,15 +1203,6 @@ def compile_continuity_context( decisions=[], ) decisions.extend(resolved_artifact_chunk_section.decisions) - resolved_semantic_artifact_chunk_section = ( - semantic_artifact_chunk_section - or CompiledSemanticArtifactChunkSection( - items=[], - summary=_empty_semantic_artifact_chunk_summary(), - decisions=[], - ) - ) - decisions.extend(resolved_semantic_artifact_chunk_section.decisions) ordered_entities = sorted(entities, key=_entity_sort_key) included_entities = ordered_entities[-limits.max_entities :] if limits.max_entities > 0 else [] included_entity_ids = {entity["id"] for entity in included_entities} @@ -1172,36 +1339,34 @@ def compile_continuity_context( if resolved_artifact_chunk_section.summary["scope"] is None else resolved_artifact_chunk_section.summary["scope"]["kind"] ), - "artifact_chunk_candidate_count": resolved_artifact_chunk_section.summary[ - "candidate_count" + "artifact_lexical_retrieval_requested": resolved_artifact_chunk_section.summary[ + "lexical_requested" ], - "included_artifact_chunk_count": resolved_artifact_chunk_section.summary[ - "included_count" + "artifact_semantic_retrieval_requested": resolved_artifact_chunk_section.summary[ + "semantic_requested" ], - "excluded_artifact_chunk_limit_count": resolved_artifact_chunk_section.summary[ - "excluded_limit_count" + "artifact_lexical_candidate_count": resolved_artifact_chunk_section.summary[ + "lexical_candidate_count" ], - "excluded_uningested_artifact_count": resolved_artifact_chunk_section.summary[ - "excluded_uningested_artifact_count" + "artifact_semantic_candidate_count": resolved_artifact_chunk_section.summary[ + "semantic_candidate_count" ], - "semantic_artifact_retrieval_requested": resolved_semantic_artifact_chunk_section.summary[ - "requested" + "artifact_merged_candidate_count": resolved_artifact_chunk_section.summary[ + "merged_candidate_count" ], - "semantic_artifact_retrieval_scope_kind": ( - None - if resolved_semantic_artifact_chunk_section.summary["scope"] is None - else resolved_semantic_artifact_chunk_section.summary["scope"]["kind"] - ), - "semantic_artifact_chunk_candidate_count": resolved_semantic_artifact_chunk_section.summary[ - "candidate_count" + "artifact_deduplicated_count": resolved_artifact_chunk_section.summary[ + "deduplicated_count" ], - "included_semantic_artifact_chunk_count": resolved_semantic_artifact_chunk_section.summary[ + "included_artifact_chunk_count": resolved_artifact_chunk_section.summary[ "included_count" ], - "excluded_semantic_artifact_chunk_limit_count": resolved_semantic_artifact_chunk_section.summary[ + "included_dual_source_artifact_chunk_count": resolved_artifact_chunk_section.summary[ + "included_dual_source_count" + ], + "excluded_artifact_chunk_limit_count": resolved_artifact_chunk_section.summary[ "excluded_limit_count" ], - "excluded_semantic_uningested_artifact_count": resolved_semantic_artifact_chunk_section.summary[ + "excluded_uningested_artifact_count": resolved_artifact_chunk_section.summary[ "excluded_uningested_artifact_count" ], "included_entity_count": len(included_entities), @@ -1237,8 +1402,6 @@ def compile_continuity_context( "memory_summary": resolved_memory_section.summary, "artifact_chunks": list(resolved_artifact_chunk_section.items), "artifact_chunk_summary": resolved_artifact_chunk_section.summary, - "semantic_artifact_chunks": list(resolved_semantic_artifact_chunk_section.items), - 
"semantic_artifact_chunk_summary": resolved_semantic_artifact_chunk_section.summary, "entities": [_serialize_entity(entity) for entity in included_entities], "entity_summary": { "candidate_count": len(ordered_entities), @@ -1281,9 +1444,6 @@ def compile_and_persist_trace( artifact_chunk_section = _compile_artifact_chunk_section( store, artifact_retrieval=artifact_retrieval, - ) - semantic_artifact_chunk_section = _compile_semantic_artifact_chunk_section( - store, semantic_artifact_retrieval=semantic_artifact_retrieval, ) entities = store.list_entities() @@ -1301,7 +1461,6 @@ def compile_and_persist_trace( limits=limits, memory_section=memory_section, artifact_chunk_section=artifact_chunk_section, - semantic_artifact_chunk_section=semantic_artifact_chunk_section, ) trace = store.create_trace( user_id=user_id, diff --git a/apps/api/src/alicebot_api/contracts.py b/apps/api/src/alicebot_api/contracts.py index 8d4882b..7362392 100644 --- a/apps/api/src/alicebot_api/contracts.py +++ b/apps/api/src/alicebot_api/contracts.py @@ -80,6 +80,7 @@ "remember_that_i_prefer", ] MemorySelectionSource = Literal["symbolic", "semantic"] +ArtifactSelectionSource = Literal["lexical", "semantic"] DEFAULT_MAX_SESSIONS = 3 DEFAULT_MAX_EVENTS = 8 @@ -416,50 +417,44 @@ class ContextPackArtifactChunk(TypedDict): char_start: int char_end_exclusive: int text: str - match: "TaskArtifactChunkRetrievalMatch" + source_provenance: "ContextPackArtifactChunkSourceProvenance" + + +class ContextPackArtifactChunkSourceProvenance(TypedDict): + sources: list[ArtifactSelectionSource] + lexical_match: "TaskArtifactChunkRetrievalMatch | None" + semantic_score: float | None class ContextPackArtifactChunkSummary(TypedDict): requested: bool + lexical_requested: bool + semantic_requested: bool scope: TaskArtifactChunkRetrievalScope | None query: str | None query_terms: list[str] - matching_rule: str - limit: int - searched_artifact_count: int - candidate_count: int - included_count: int - excluded_uningested_artifact_count: int - excluded_limit_count: int - order: list[str] - - -class ContextPackSemanticArtifactChunk(TypedDict): - id: str - task_id: str - task_artifact_id: str - relative_path: str - media_type: str - sequence_no: int - char_start: int - char_end_exclusive: int - text: str - score: float - - -class ContextPackSemanticArtifactChunkSummary(TypedDict): - requested: bool - scope: TaskArtifactChunkRetrievalScope | None embedding_config_id: str | None query_vector_dimensions: int limit: int + lexical_limit: int + semantic_limit: int searched_artifact_count: int - candidate_count: int + lexical_candidate_count: int + semantic_candidate_count: int + merged_candidate_count: int + deduplicated_count: int included_count: int + included_lexical_only_count: int + included_semantic_only_count: int + included_dual_source_count: int excluded_uningested_artifact_count: int excluded_limit_count: int + matching_rule: str | None similarity_metric: Literal["cosine_similarity"] | None - order: list[str] + source_precedence: list[ArtifactSelectionSource] + lexical_order: list[str] + semantic_order: list[str] + merged_order: list[str] class ArtifactRetrievalDecisionTracePayload(TypedDict): @@ -478,18 +473,22 @@ class ArtifactRetrievalDecisionTracePayload(TypedDict): char_end_exclusive: NotRequired[int] -class SemanticArtifactRetrievalDecisionTracePayload(TypedDict): +class HybridArtifactRetrievalDecisionTracePayload(TypedDict): scope_kind: TaskArtifactChunkRetrievalScopeKind task_id: str task_artifact_id: str relative_path: str media_type: str | 
None ingestion_status: TaskArtifactIngestionStatus - embedding_config_id: str - query_vector_dimensions: int limit: int - similarity_metric: Literal["cosine_similarity"] + selected_sources: list[ArtifactSelectionSource] + embedding_config_id: str | None + query_vector_dimensions: int + matched_query_terms: NotRequired[list[str]] + matched_query_term_count: NotRequired[int] + first_match_char_start: NotRequired[int] score: NotRequired[float] + similarity_metric: NotRequired[Literal["cosine_similarity"]] sequence_no: NotRequired[int] char_start: NotRequired[int] char_end_exclusive: NotRequired[int] @@ -580,8 +579,6 @@ class CompiledContextPack(TypedDict): memory_summary: ContextPackMemorySummary artifact_chunks: list[ContextPackArtifactChunk] artifact_chunk_summary: ContextPackArtifactChunkSummary - semantic_artifact_chunks: list[ContextPackSemanticArtifactChunk] - semantic_artifact_chunk_summary: ContextPackSemanticArtifactChunkSummary entities: list[ContextPackEntity] entity_summary: ContextPackEntitySummary entity_edges: list[ContextPackEntityEdge] diff --git a/apps/api/src/alicebot_api/response_generation.py b/apps/api/src/alicebot_api/response_generation.py index 78f2ee6..7f6c055 100644 --- a/apps/api/src/alicebot_api/response_generation.py +++ b/apps/api/src/alicebot_api/response_generation.py @@ -92,8 +92,6 @@ def _context_section_payload(context_pack: CompiledContextPack) -> JsonObject: "memory_summary": context_pack["memory_summary"], "artifact_chunks": context_pack["artifact_chunks"], "artifact_chunk_summary": context_pack["artifact_chunk_summary"], - "semantic_artifact_chunks": context_pack["semantic_artifact_chunks"], - "semantic_artifact_chunk_summary": context_pack["semantic_artifact_chunk_summary"], "entities": context_pack["entities"], "entity_summary": context_pack["entity_summary"], "entity_edges": context_pack["entity_edges"], diff --git a/tests/integration/test_context_compile.py b/tests/integration/test_context_compile.py index cc43cd5..77979c5 100644 --- a/tests/integration/test_context_compile.py +++ b/tests/integration/test_context_compile.py @@ -485,20 +485,49 @@ def test_compile_context_endpoint_persists_trace_and_trace_events(migrated_datab "semantic_order": ["score_desc", "created_at_asc", "id_asc"], }, } - assert payload["context_pack"]["semantic_artifact_chunks"] == [] - assert payload["context_pack"]["semantic_artifact_chunk_summary"] == { + assert payload["context_pack"]["artifact_chunks"] == [] + assert payload["context_pack"]["artifact_chunk_summary"] == { "requested": False, + "lexical_requested": False, + "semantic_requested": False, "scope": None, + "query": None, + "query_terms": [], "embedding_config_id": None, "query_vector_dimensions": 0, "limit": 0, + "lexical_limit": 0, + "semantic_limit": 0, "searched_artifact_count": 0, - "candidate_count": 0, + "lexical_candidate_count": 0, + "semantic_candidate_count": 0, + "merged_candidate_count": 0, + "deduplicated_count": 0, "included_count": 0, + "included_lexical_only_count": 0, + "included_semantic_only_count": 0, + "included_dual_source_count": 0, "excluded_uningested_artifact_count": 0, "excluded_limit_count": 0, + "matching_rule": None, "similarity_metric": None, - "order": ["score_desc", "relative_path_asc", "sequence_no_asc", "id_asc"], + "source_precedence": ["lexical", "semantic"], + "lexical_order": [ + "matched_query_term_count_desc", + "first_match_char_start_asc", + "relative_path_asc", + "sequence_no_asc", + "id_asc", + ], + "semantic_order": ["score_desc", "relative_path_asc", 
"sequence_no_asc", "id_asc"], + "merged_order": [ + "source_precedence_asc", + "lexical_rank_asc", + "semantic_rank_asc", + "relative_path_asc", + "sequence_no_asc", + "id_asc", + ], } assert payload["context_pack"]["entities"] == [ { @@ -608,11 +637,13 @@ def test_compile_context_endpoint_persists_trace_and_trace_events(migrated_datab assert trace_events[-1]["payload"]["hybrid_memory_candidate_count"] == 2 assert trace_events[-1]["payload"]["hybrid_memory_merged_candidate_count"] == 1 assert trace_events[-1]["payload"]["hybrid_memory_deduplicated_count"] == 0 - assert trace_events[-1]["payload"]["semantic_artifact_retrieval_requested"] is False - assert trace_events[-1]["payload"]["semantic_artifact_chunk_candidate_count"] == 0 - assert trace_events[-1]["payload"]["included_semantic_artifact_chunk_count"] == 0 - assert trace_events[-1]["payload"]["excluded_semantic_artifact_chunk_limit_count"] == 0 - assert trace_events[-1]["payload"]["excluded_semantic_uningested_artifact_count"] == 0 + assert trace_events[-1]["payload"]["artifact_lexical_retrieval_requested"] is False + assert trace_events[-1]["payload"]["artifact_semantic_retrieval_requested"] is False + assert trace_events[-1]["payload"]["artifact_lexical_candidate_count"] == 0 + assert trace_events[-1]["payload"]["artifact_semantic_candidate_count"] == 0 + assert trace_events[-1]["payload"]["artifact_merged_candidate_count"] == 0 + assert trace_events[-1]["payload"]["artifact_deduplicated_count"] == 0 + assert trace_events[-1]["payload"]["included_dual_source_artifact_chunk_count"] == 0 assert trace_events[-1]["payload"]["included_entity_count"] == 1 assert trace_events[-1]["payload"]["excluded_entity_limit_count"] == 2 assert trace_events[-1]["payload"]["included_entity_edge_count"] == 1 @@ -691,20 +722,49 @@ def test_compile_context_prefers_updated_active_memory_within_same_transaction( "semantic_order": ["score_desc", "created_at_asc", "id_asc"], }, } - assert payload["context_pack"]["semantic_artifact_chunks"] == [] - assert payload["context_pack"]["semantic_artifact_chunk_summary"] == { + assert payload["context_pack"]["artifact_chunks"] == [] + assert payload["context_pack"]["artifact_chunk_summary"] == { "requested": False, + "lexical_requested": False, + "semantic_requested": False, "scope": None, + "query": None, + "query_terms": [], "embedding_config_id": None, "query_vector_dimensions": 0, "limit": 0, + "lexical_limit": 0, + "semantic_limit": 0, "searched_artifact_count": 0, - "candidate_count": 0, + "lexical_candidate_count": 0, + "semantic_candidate_count": 0, + "merged_candidate_count": 0, + "deduplicated_count": 0, "included_count": 0, + "included_lexical_only_count": 0, + "included_semantic_only_count": 0, + "included_dual_source_count": 0, "excluded_uningested_artifact_count": 0, "excluded_limit_count": 0, + "matching_rule": None, "similarity_metric": None, - "order": ["score_desc", "relative_path_asc", "sequence_no_asc", "id_asc"], + "source_precedence": ["lexical", "semantic"], + "lexical_order": [ + "matched_query_term_count_desc", + "first_match_char_start_asc", + "relative_path_asc", + "sequence_no_asc", + "id_asc", + ], + "semantic_order": ["score_desc", "relative_path_asc", "sequence_no_asc", "id_asc"], + "merged_order": [ + "source_precedence_asc", + "lexical_rank_asc", + "semantic_rank_asc", + "relative_path_asc", + "sequence_no_asc", + "id_asc", + ], } assert payload["context_pack"]["entity_summary"] == { "candidate_count": 2, @@ -1005,10 +1065,14 @@ def 
test_compile_context_artifact_retrieval_integrates_chunks_traces_and_exclusi "char_start": 0, "char_end_exclusive": 14, "text": "beta alpha doc", - "match": { - "matched_query_terms": ["alpha", "beta"], - "matched_query_term_count": 2, - "first_match_char_start": 0, + "source_provenance": { + "sources": ["lexical"], + "lexical_match": { + "matched_query_terms": ["alpha", "beta"], + "matched_query_term_count": 2, + "first_match_char_start": 0, + }, + "semantic_score": None, }, }, { @@ -1021,32 +1085,59 @@ def test_compile_context_artifact_retrieval_integrates_chunks_traces_and_exclusi "char_start": 0, "char_end_exclusive": 15, "text": "alpha beta note", - "match": { - "matched_query_terms": ["alpha", "beta"], - "matched_query_term_count": 2, - "first_match_char_start": 0, + "source_provenance": { + "sources": ["lexical"], + "lexical_match": { + "matched_query_terms": ["alpha", "beta"], + "matched_query_term_count": 2, + "first_match_char_start": 0, + }, + "semantic_score": None, }, }, ] assert payload["context_pack"]["artifact_chunk_summary"] == { "requested": True, + "lexical_requested": True, + "semantic_requested": False, "scope": {"kind": "task", "task_id": str(artifact_scope["task_id"])}, "query": "Alpha beta", "query_terms": ["alpha", "beta"], - "matching_rule": "casefolded_unicode_word_overlap_unique_query_terms_v1", + "embedding_config_id": None, + "query_vector_dimensions": 0, "limit": 2, + "lexical_limit": 2, + "semantic_limit": 0, "searched_artifact_count": 3, - "candidate_count": 3, + "lexical_candidate_count": 3, + "semantic_candidate_count": 0, + "merged_candidate_count": 3, + "deduplicated_count": 0, "included_count": 2, + "included_lexical_only_count": 2, + "included_semantic_only_count": 0, + "included_dual_source_count": 0, "excluded_uningested_artifact_count": 1, "excluded_limit_count": 1, - "order": [ + "matching_rule": "casefolded_unicode_word_overlap_unique_query_terms_v1", + "similarity_metric": None, + "source_precedence": ["lexical", "semantic"], + "lexical_order": [ "matched_query_term_count_desc", "first_match_char_start_asc", "relative_path_asc", "sequence_no_asc", "id_asc", ], + "semantic_order": ["score_desc", "relative_path_asc", "sequence_no_asc", "id_asc"], + "merged_order": [ + "source_precedence_asc", + "lexical_rank_asc", + "semantic_rank_asc", + "relative_path_asc", + "sequence_no_asc", + "id_asc", + ], } assert payload["context_pack"]["memories"] assert payload["context_pack"]["entities"] @@ -1056,7 +1147,7 @@ def test_compile_context_artifact_retrieval_integrates_chunks_traces_and_exclusi trace_events = ContinuityStore(conn).list_trace_events(trace_id) assert any( - event["payload"]["reason"] == "within_artifact_chunk_limit" + event["payload"]["reason"] == "within_hybrid_artifact_chunk_limit" and event["payload"]["entity_id"] == str(artifact_scope["chunk_ids"]["docs"]) and event["payload"]["relative_path"] == "docs/a.txt" and event["payload"]["matched_query_terms"] == ["alpha", "beta"] @@ -1064,21 +1155,21 @@ def test_compile_context_artifact_retrieval_integrates_chunks_traces_and_exclusi if event["kind"] == "context.included" ) assert any( - event["payload"]["reason"] == "within_artifact_chunk_limit" + event["payload"]["reason"] == "within_hybrid_artifact_chunk_limit" and event["payload"]["entity_id"] == str(artifact_scope["chunk_ids"]["notes"]) and event["payload"]["relative_path"] == "notes/b.md" for event in trace_events if event["kind"] == "context.included" ) assert any( - event["payload"]["reason"] == "artifact_chunk_limit_exceeded" + 
event["payload"]["reason"] == "hybrid_artifact_chunk_limit_exceeded" and event["payload"]["entity_id"] == str(artifact_scope["chunk_ids"]["weak"]) and event["payload"]["relative_path"] == "notes/c.txt" for event in trace_events if event["kind"] == "context.excluded" ) assert any( - event["payload"]["reason"] == "artifact_not_ingested" + event["payload"]["reason"] == "hybrid_artifact_not_ingested" and event["payload"]["entity_id"] == str(artifact_scope["artifact_ids"]["pending"]) and event["payload"]["relative_path"] == "notes/hidden.txt" and event["payload"]["ingestion_status"] == "pending" @@ -1087,8 +1178,14 @@ def test_compile_context_artifact_retrieval_integrates_chunks_traces_and_exclusi ) assert trace_events[-1]["payload"]["artifact_retrieval_requested"] is True assert trace_events[-1]["payload"]["artifact_retrieval_scope_kind"] == "task" - assert trace_events[-1]["payload"]["artifact_chunk_candidate_count"] == 3 + assert trace_events[-1]["payload"]["artifact_lexical_retrieval_requested"] is True + assert trace_events[-1]["payload"]["artifact_semantic_retrieval_requested"] is False + assert trace_events[-1]["payload"]["artifact_lexical_candidate_count"] == 3 + assert trace_events[-1]["payload"]["artifact_semantic_candidate_count"] == 0 + assert trace_events[-1]["payload"]["artifact_merged_candidate_count"] == 3 + assert trace_events[-1]["payload"]["artifact_deduplicated_count"] == 0 assert trace_events[-1]["payload"]["included_artifact_chunk_count"] == 2 + assert trace_events[-1]["payload"]["included_dual_source_artifact_chunk_count"] == 0 assert trace_events[-1]["payload"]["excluded_artifact_chunk_limit_count"] == 1 assert trace_events[-1]["payload"]["excluded_uningested_artifact_count"] == 1 @@ -1134,15 +1231,21 @@ def test_compile_context_artifact_scoped_retrieval_returns_only_visible_artifact "char_start": 0, "char_end_exclusive": 15, "text": "alpha beta note", - "match": { - "matched_query_terms": ["alpha", "beta"], - "matched_query_term_count": 2, - "first_match_char_start": 0, + "source_provenance": { + "sources": ["lexical"], + "lexical_match": { + "matched_query_terms": ["alpha", "beta"], + "matched_query_term_count": 2, + "first_match_char_start": 0, + }, + "semantic_score": None, }, } ] assert payload["context_pack"]["artifact_chunk_summary"] == { "requested": True, + "lexical_requested": True, + "semantic_requested": False, "scope": { "kind": "artifact", "task_id": str(artifact_scope["task_id"]), @@ -1150,20 +1253,41 @@ def test_compile_context_artifact_scoped_retrieval_returns_only_visible_artifact }, "query": "Alpha beta", "query_terms": ["alpha", "beta"], - "matching_rule": "casefolded_unicode_word_overlap_unique_query_terms_v1", + "embedding_config_id": None, + "query_vector_dimensions": 0, "limit": 2, + "lexical_limit": 2, + "semantic_limit": 0, "searched_artifact_count": 1, - "candidate_count": 1, + "lexical_candidate_count": 1, + "semantic_candidate_count": 0, + "merged_candidate_count": 1, + "deduplicated_count": 0, "included_count": 1, + "included_lexical_only_count": 1, + "included_semantic_only_count": 0, + "included_dual_source_count": 0, "excluded_uningested_artifact_count": 0, "excluded_limit_count": 0, - "order": [ + "matching_rule": "casefolded_unicode_word_overlap_unique_query_terms_v1", + "similarity_metric": None, + "source_precedence": ["lexical", "semantic"], + "lexical_order": [ "matched_query_term_count_desc", "first_match_char_start_asc", "relative_path_asc", "sequence_no_asc", "id_asc", ], + "semantic_order": ["score_desc", "relative_path_asc", 
"sequence_no_asc", "id_asc"], + "merged_order": [ + "source_precedence_asc", + "lexical_rank_asc", + "semantic_rank_asc", + "relative_path_asc", + "sequence_no_asc", + "id_asc", + ], } trace_id = UUID(payload["trace_id"]) @@ -1171,7 +1295,7 @@ def test_compile_context_artifact_scoped_retrieval_returns_only_visible_artifact trace_events = ContinuityStore(conn).list_trace_events(trace_id) assert any( - event["payload"]["reason"] == "within_artifact_chunk_limit" + event["payload"]["reason"] == "within_hybrid_artifact_chunk_limit" and event["payload"]["entity_id"] == str(artifact_scope["chunk_ids"]["notes"]) and event["payload"]["scope_kind"] == "artifact" and event["payload"]["task_artifact_id"] == str(artifact_scope["artifact_ids"]["notes"]) @@ -1180,12 +1304,207 @@ def test_compile_context_artifact_scoped_retrieval_returns_only_visible_artifact ) assert trace_events[-1]["payload"]["artifact_retrieval_requested"] is True assert trace_events[-1]["payload"]["artifact_retrieval_scope_kind"] == "artifact" - assert trace_events[-1]["payload"]["artifact_chunk_candidate_count"] == 1 + assert trace_events[-1]["payload"]["artifact_lexical_retrieval_requested"] is True + assert trace_events[-1]["payload"]["artifact_semantic_retrieval_requested"] is False + assert trace_events[-1]["payload"]["artifact_lexical_candidate_count"] == 1 + assert trace_events[-1]["payload"]["artifact_semantic_candidate_count"] == 0 + assert trace_events[-1]["payload"]["artifact_merged_candidate_count"] == 1 + assert trace_events[-1]["payload"]["artifact_deduplicated_count"] == 0 assert trace_events[-1]["payload"]["included_artifact_chunk_count"] == 1 + assert trace_events[-1]["payload"]["included_dual_source_artifact_chunk_count"] == 0 assert trace_events[-1]["payload"]["excluded_artifact_chunk_limit_count"] == 0 assert trace_events[-1]["payload"]["excluded_uningested_artifact_count"] == 0 +def test_compile_context_hybrid_artifact_merge_preserves_dual_source_provenance_and_limits( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_traceable_thread(migrated_database_urls["app"]) + artifact_scope = seed_compile_artifact_scope( + migrated_database_urls["app"], + user_id=seeded["user_id"], + thread_id=seeded["thread_id"], + ) + config_id = seed_embedding_config_for_user( + migrated_database_urls["app"], + user_id=seeded["user_id"], + ) + seed_task_artifact_chunk_embedding_for_user( + migrated_database_urls["app"], + user_id=seeded["user_id"], + task_artifact_chunk_id=artifact_scope["chunk_ids"]["docs"], + embedding_config_id=config_id, + vector=[1.0, 0.0, 0.0], + ) + seed_task_artifact_chunk_embedding_for_user( + migrated_database_urls["app"], + user_id=seeded["user_id"], + task_artifact_chunk_id=artifact_scope["chunk_ids"]["notes"], + embedding_config_id=config_id, + vector=[1.0, 0.0, 0.0], + ) + seed_task_artifact_chunk_embedding_for_user( + migrated_database_urls["app"], + user_id=seeded["user_id"], + task_artifact_chunk_id=artifact_scope["chunk_ids"]["weak"], + embedding_config_id=config_id, + vector=[0.0, 1.0, 0.0], + ) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + status_code, payload = invoke_compile_context( + { + "user_id": str(seeded["user_id"]), + "thread_id": str(seeded["thread_id"]), + "artifact_retrieval": { + "kind": "task", + "task_id": str(artifact_scope["task_id"]), + "query": "Alpha beta", + "limit": 2, + }, + "semantic_artifact_retrieval": { + "kind": "task", + "task_id": str(artifact_scope["task_id"]), + 
"embedding_config_id": str(config_id), + "query_vector": [1.0, 0.0, 0.0], + "limit": 2, + }, + } + ) + + assert status_code == 200 + assert payload["context_pack"]["artifact_chunks"] == [ + { + "id": str(artifact_scope["chunk_ids"]["docs"]), + "task_id": str(artifact_scope["task_id"]), + "task_artifact_id": str(artifact_scope["artifact_ids"]["docs"]), + "relative_path": "docs/a.txt", + "media_type": "text/plain", + "sequence_no": 1, + "char_start": 0, + "char_end_exclusive": 14, + "text": "beta alpha doc", + "source_provenance": { + "sources": ["lexical", "semantic"], + "lexical_match": { + "matched_query_terms": ["alpha", "beta"], + "matched_query_term_count": 2, + "first_match_char_start": 0, + }, + "semantic_score": 1.0, + }, + }, + { + "id": str(artifact_scope["chunk_ids"]["notes"]), + "task_id": str(artifact_scope["task_id"]), + "task_artifact_id": str(artifact_scope["artifact_ids"]["notes"]), + "relative_path": "notes/b.md", + "media_type": "text/markdown", + "sequence_no": 1, + "char_start": 0, + "char_end_exclusive": 15, + "text": "alpha beta note", + "source_provenance": { + "sources": ["lexical", "semantic"], + "lexical_match": { + "matched_query_terms": ["alpha", "beta"], + "matched_query_term_count": 2, + "first_match_char_start": 0, + }, + "semantic_score": 1.0, + }, + }, + ] + assert payload["context_pack"]["artifact_chunk_summary"] == { + "requested": True, + "lexical_requested": True, + "semantic_requested": True, + "scope": {"kind": "task", "task_id": str(artifact_scope["task_id"])}, + "query": "Alpha beta", + "query_terms": ["alpha", "beta"], + "embedding_config_id": str(config_id), + "query_vector_dimensions": 3, + "limit": 2, + "lexical_limit": 2, + "semantic_limit": 2, + "searched_artifact_count": 3, + "lexical_candidate_count": 3, + "semantic_candidate_count": 3, + "merged_candidate_count": 3, + "deduplicated_count": 3, + "included_count": 2, + "included_lexical_only_count": 0, + "included_semantic_only_count": 0, + "included_dual_source_count": 2, + "excluded_uningested_artifact_count": 1, + "excluded_limit_count": 1, + "matching_rule": "casefolded_unicode_word_overlap_unique_query_terms_v1", + "similarity_metric": "cosine_similarity", + "source_precedence": ["lexical", "semantic"], + "lexical_order": [ + "matched_query_term_count_desc", + "first_match_char_start_asc", + "relative_path_asc", + "sequence_no_asc", + "id_asc", + ], + "semantic_order": ["score_desc", "relative_path_asc", "sequence_no_asc", "id_asc"], + "merged_order": [ + "source_precedence_asc", + "lexical_rank_asc", + "semantic_rank_asc", + "relative_path_asc", + "sequence_no_asc", + "id_asc", + ], + } + + trace_id = UUID(payload["trace_id"]) + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + trace_events = ContinuityStore(conn).list_trace_events(trace_id) + + assert any( + event["payload"]["reason"] == "hybrid_artifact_chunk_deduplicated" + and event["payload"]["entity_id"] == str(artifact_scope["chunk_ids"]["docs"]) + and event["payload"]["selected_sources"] == ["lexical", "semantic"] + and event["payload"]["score"] == 1.0 + for event in trace_events + if event["kind"] == "context.included" + ) + assert any( + event["payload"]["reason"] == "within_hybrid_artifact_chunk_limit" + and event["payload"]["entity_id"] == str(artifact_scope["chunk_ids"]["notes"]) + and event["payload"]["selected_sources"] == ["lexical", "semantic"] + for event in trace_events + if event["kind"] == "context.included" + ) + assert any( + event["payload"]["reason"] == 
"hybrid_artifact_chunk_limit_exceeded" + and event["payload"]["entity_id"] == str(artifact_scope["chunk_ids"]["weak"]) + and event["payload"]["selected_sources"] == ["lexical", "semantic"] + and event["payload"]["score"] == 0.0 + for event in trace_events + if event["kind"] == "context.excluded" + ) + assert trace_events[-1]["payload"]["artifact_retrieval_requested"] is True + assert trace_events[-1]["payload"]["artifact_retrieval_scope_kind"] == "task" + assert trace_events[-1]["payload"]["artifact_lexical_retrieval_requested"] is True + assert trace_events[-1]["payload"]["artifact_semantic_retrieval_requested"] is True + assert trace_events[-1]["payload"]["artifact_lexical_candidate_count"] == 3 + assert trace_events[-1]["payload"]["artifact_semantic_candidate_count"] == 3 + assert trace_events[-1]["payload"]["artifact_merged_candidate_count"] == 3 + assert trace_events[-1]["payload"]["artifact_deduplicated_count"] == 3 + assert trace_events[-1]["payload"]["included_artifact_chunk_count"] == 2 + assert trace_events[-1]["payload"]["included_dual_source_artifact_chunk_count"] == 2 + assert trace_events[-1]["payload"]["excluded_artifact_chunk_limit_count"] == 1 + assert trace_events[-1]["payload"]["excluded_uningested_artifact_count"] == 1 + + def test_compile_context_semantic_artifact_retrieval_integrates_chunks_traces_and_exclusion_rules( migrated_database_urls, monkeypatch, @@ -1242,7 +1561,7 @@ def test_compile_context_semantic_artifact_retrieval_integrates_chunks_traces_an ) assert status_code == 200 - assert payload["context_pack"]["semantic_artifact_chunks"] == [ + assert payload["context_pack"]["artifact_chunks"] == [ { "id": str(artifact_scope["chunk_ids"]["docs"]), "task_id": str(artifact_scope["task_id"]), @@ -1253,7 +1572,11 @@ def test_compile_context_semantic_artifact_retrieval_integrates_chunks_traces_an "char_start": 0, "char_end_exclusive": 14, "text": "beta alpha doc", - "score": 1.0, + "source_provenance": { + "sources": ["semantic"], + "lexical_match": None, + "semantic_score": 1.0, + }, }, { "id": str(artifact_scope["chunk_ids"]["notes"]), @@ -1265,31 +1588,63 @@ def test_compile_context_semantic_artifact_retrieval_integrates_chunks_traces_an "char_start": 0, "char_end_exclusive": 15, "text": "alpha beta note", - "score": 1.0, + "source_provenance": { + "sources": ["semantic"], + "lexical_match": None, + "semantic_score": 1.0, + }, }, ] - assert payload["context_pack"]["semantic_artifact_chunk_summary"] == { + assert payload["context_pack"]["artifact_chunk_summary"] == { "requested": True, + "lexical_requested": False, + "semantic_requested": True, "scope": {"kind": "task", "task_id": str(artifact_scope["task_id"])}, + "query": None, + "query_terms": [], "embedding_config_id": str(config_id), "query_vector_dimensions": 3, "limit": 2, + "lexical_limit": 0, + "semantic_limit": 2, "searched_artifact_count": 3, - "candidate_count": 3, + "lexical_candidate_count": 0, + "semantic_candidate_count": 3, + "merged_candidate_count": 3, + "deduplicated_count": 0, "included_count": 2, + "included_lexical_only_count": 0, + "included_semantic_only_count": 2, + "included_dual_source_count": 0, "excluded_uningested_artifact_count": 1, "excluded_limit_count": 1, + "matching_rule": None, "similarity_metric": "cosine_similarity", - "order": ["score_desc", "relative_path_asc", "sequence_no_asc", "id_asc"], + "source_precedence": ["lexical", "semantic"], + "lexical_order": [ + "matched_query_term_count_desc", + "first_match_char_start_asc", + "relative_path_asc", + "sequence_no_asc", + 
"id_asc", + ], + "semantic_order": ["score_desc", "relative_path_asc", "sequence_no_asc", "id_asc"], + "merged_order": [ + "source_precedence_asc", + "lexical_rank_asc", + "semantic_rank_asc", + "relative_path_asc", + "sequence_no_asc", + "id_asc", + ], } - assert payload["context_pack"]["artifact_chunks"] == [] trace_id = UUID(payload["trace_id"]) with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: trace_events = ContinuityStore(conn).list_trace_events(trace_id) assert any( - event["payload"]["reason"] == "within_semantic_artifact_chunk_limit" + event["payload"]["reason"] == "within_hybrid_artifact_chunk_limit" and event["payload"]["entity_id"] == str(artifact_scope["chunk_ids"]["docs"]) and event["payload"]["relative_path"] == "docs/a.txt" and event["payload"]["score"] == 1.0 @@ -1297,14 +1652,14 @@ def test_compile_context_semantic_artifact_retrieval_integrates_chunks_traces_an if event["kind"] == "context.included" ) assert any( - event["payload"]["reason"] == "within_semantic_artifact_chunk_limit" + event["payload"]["reason"] == "within_hybrid_artifact_chunk_limit" and event["payload"]["entity_id"] == str(artifact_scope["chunk_ids"]["notes"]) and event["payload"]["relative_path"] == "notes/b.md" for event in trace_events if event["kind"] == "context.included" ) assert any( - event["payload"]["reason"] == "semantic_artifact_chunk_limit_exceeded" + event["payload"]["reason"] == "hybrid_artifact_chunk_limit_exceeded" and event["payload"]["entity_id"] == str(artifact_scope["chunk_ids"]["weak"]) and event["payload"]["relative_path"] == "notes/c.txt" and event["payload"]["score"] == 0.0 @@ -1312,19 +1667,25 @@ def test_compile_context_semantic_artifact_retrieval_integrates_chunks_traces_an if event["kind"] == "context.excluded" ) assert any( - event["payload"]["reason"] == "semantic_artifact_not_ingested" + event["payload"]["reason"] == "hybrid_artifact_not_ingested" and event["payload"]["entity_id"] == str(artifact_scope["artifact_ids"]["pending"]) and event["payload"]["relative_path"] == "notes/hidden.txt" and event["payload"]["ingestion_status"] == "pending" for event in trace_events if event["kind"] == "context.excluded" ) - assert trace_events[-1]["payload"]["semantic_artifact_retrieval_requested"] is True - assert trace_events[-1]["payload"]["semantic_artifact_retrieval_scope_kind"] == "task" - assert trace_events[-1]["payload"]["semantic_artifact_chunk_candidate_count"] == 3 - assert trace_events[-1]["payload"]["included_semantic_artifact_chunk_count"] == 2 - assert trace_events[-1]["payload"]["excluded_semantic_artifact_chunk_limit_count"] == 1 - assert trace_events[-1]["payload"]["excluded_semantic_uningested_artifact_count"] == 1 + assert trace_events[-1]["payload"]["artifact_retrieval_requested"] is True + assert trace_events[-1]["payload"]["artifact_retrieval_scope_kind"] == "task" + assert trace_events[-1]["payload"]["artifact_lexical_retrieval_requested"] is False + assert trace_events[-1]["payload"]["artifact_semantic_retrieval_requested"] is True + assert trace_events[-1]["payload"]["artifact_lexical_candidate_count"] == 0 + assert trace_events[-1]["payload"]["artifact_semantic_candidate_count"] == 3 + assert trace_events[-1]["payload"]["artifact_merged_candidate_count"] == 3 + assert trace_events[-1]["payload"]["artifact_deduplicated_count"] == 0 + assert trace_events[-1]["payload"]["included_artifact_chunk_count"] == 2 + assert trace_events[-1]["payload"]["included_dual_source_artifact_chunk_count"] == 0 + assert 
trace_events[-1]["payload"]["excluded_artifact_chunk_limit_count"] == 1 + assert trace_events[-1]["payload"]["excluded_uningested_artifact_count"] == 1 def test_compile_context_semantic_artifact_scoped_retrieval_returns_only_visible_artifact_chunks( @@ -1369,7 +1730,7 @@ def test_compile_context_semantic_artifact_scoped_retrieval_returns_only_visible ) assert status_code == 200 - assert payload["context_pack"]["semantic_artifact_chunks"] == [ + assert payload["context_pack"]["artifact_chunks"] == [ { "id": str(artifact_scope["chunk_ids"]["notes"]), "task_id": str(artifact_scope["task_id"]), @@ -1380,26 +1741,59 @@ def test_compile_context_semantic_artifact_scoped_retrieval_returns_only_visible "char_start": 0, "char_end_exclusive": 15, "text": "alpha beta note", - "score": 1.0, + "source_provenance": { + "sources": ["semantic"], + "lexical_match": None, + "semantic_score": 1.0, + }, } ] - assert payload["context_pack"]["semantic_artifact_chunk_summary"] == { + assert payload["context_pack"]["artifact_chunk_summary"] == { "requested": True, + "lexical_requested": False, + "semantic_requested": True, "scope": { "kind": "artifact", "task_id": str(artifact_scope["task_id"]), "task_artifact_id": str(artifact_scope["artifact_ids"]["notes"]), }, + "query": None, + "query_terms": [], "embedding_config_id": str(config_id), "query_vector_dimensions": 3, "limit": 2, + "lexical_limit": 0, + "semantic_limit": 2, "searched_artifact_count": 1, - "candidate_count": 1, + "lexical_candidate_count": 0, + "semantic_candidate_count": 1, + "merged_candidate_count": 1, + "deduplicated_count": 0, "included_count": 1, + "included_lexical_only_count": 0, + "included_semantic_only_count": 1, + "included_dual_source_count": 0, "excluded_uningested_artifact_count": 0, "excluded_limit_count": 0, + "matching_rule": None, "similarity_metric": "cosine_similarity", - "order": ["score_desc", "relative_path_asc", "sequence_no_asc", "id_asc"], + "source_precedence": ["lexical", "semantic"], + "lexical_order": [ + "matched_query_term_count_desc", + "first_match_char_start_asc", + "relative_path_asc", + "sequence_no_asc", + "id_asc", + ], + "semantic_order": ["score_desc", "relative_path_asc", "sequence_no_asc", "id_asc"], + "merged_order": [ + "source_precedence_asc", + "lexical_rank_asc", + "semantic_rank_asc", + "relative_path_asc", + "sequence_no_asc", + "id_asc", + ], } trace_id = UUID(payload["trace_id"]) @@ -1407,19 +1801,25 @@ def test_compile_context_semantic_artifact_scoped_retrieval_returns_only_visible trace_events = ContinuityStore(conn).list_trace_events(trace_id) assert any( - event["payload"]["reason"] == "within_semantic_artifact_chunk_limit" + event["payload"]["reason"] == "within_hybrid_artifact_chunk_limit" and event["payload"]["entity_id"] == str(artifact_scope["chunk_ids"]["notes"]) and event["payload"]["scope_kind"] == "artifact" and event["payload"]["task_artifact_id"] == str(artifact_scope["artifact_ids"]["notes"]) for event in trace_events if event["kind"] == "context.included" ) - assert trace_events[-1]["payload"]["semantic_artifact_retrieval_requested"] is True - assert trace_events[-1]["payload"]["semantic_artifact_retrieval_scope_kind"] == "artifact" - assert trace_events[-1]["payload"]["semantic_artifact_chunk_candidate_count"] == 1 - assert trace_events[-1]["payload"]["included_semantic_artifact_chunk_count"] == 1 - assert trace_events[-1]["payload"]["excluded_semantic_artifact_chunk_limit_count"] == 0 - assert trace_events[-1]["payload"]["excluded_semantic_uningested_artifact_count"] == 0 + 
assert trace_events[-1]["payload"]["artifact_retrieval_requested"] is True + assert trace_events[-1]["payload"]["artifact_retrieval_scope_kind"] == "artifact" + assert trace_events[-1]["payload"]["artifact_lexical_retrieval_requested"] is False + assert trace_events[-1]["payload"]["artifact_semantic_retrieval_requested"] is True + assert trace_events[-1]["payload"]["artifact_lexical_candidate_count"] == 0 + assert trace_events[-1]["payload"]["artifact_semantic_candidate_count"] == 1 + assert trace_events[-1]["payload"]["artifact_merged_candidate_count"] == 1 + assert trace_events[-1]["payload"]["artifact_deduplicated_count"] == 0 + assert trace_events[-1]["payload"]["included_artifact_chunk_count"] == 1 + assert trace_events[-1]["payload"]["included_dual_source_artifact_chunk_count"] == 0 + assert trace_events[-1]["payload"]["excluded_artifact_chunk_limit_count"] == 0 + assert trace_events[-1]["payload"]["excluded_uningested_artifact_count"] == 0 def test_compile_context_semantic_artifact_retrieval_validation_and_isolation( diff --git a/tests/unit/test_compiler.py b/tests/unit/test_compiler.py index 7ff19c9..8267391 100644 --- a/tests/unit/test_compiler.py +++ b/tests/unit/test_compiler.py @@ -7,10 +7,10 @@ SUMMARY_TRACE_EVENT_KIND, _compile_artifact_chunk_section, _compile_memory_section, - _compile_semantic_artifact_chunk_section, compile_continuity_context, ) from alicebot_api.contracts import ( + CompileContextArtifactScopedArtifactRetrievalInput, CompileContextArtifactScopedSemanticArtifactRetrievalInput, CompileContextSemanticRetrievalInput, CompileContextTaskScopedSemanticArtifactRetrievalInput, @@ -298,38 +298,46 @@ def test_compile_continuity_context_is_deterministic_and_stably_ordered() -> Non assert first_run.context_pack["artifact_chunks"] == [] assert first_run.context_pack["artifact_chunk_summary"] == { "requested": False, + "lexical_requested": False, + "semantic_requested": False, "scope": None, "query": None, "query_terms": [], - "matching_rule": "casefolded_unicode_word_overlap_unique_query_terms_v1", + "embedding_config_id": None, + "query_vector_dimensions": 0, "limit": 0, + "lexical_limit": 0, + "semantic_limit": 0, "searched_artifact_count": 0, - "candidate_count": 0, + "lexical_candidate_count": 0, + "semantic_candidate_count": 0, + "merged_candidate_count": 0, + "deduplicated_count": 0, "included_count": 0, + "included_lexical_only_count": 0, + "included_semantic_only_count": 0, + "included_dual_source_count": 0, "excluded_uningested_artifact_count": 0, "excluded_limit_count": 0, - "order": [ + "matching_rule": None, + "similarity_metric": None, + "source_precedence": ["lexical", "semantic"], + "lexical_order": [ "matched_query_term_count_desc", "first_match_char_start_asc", "relative_path_asc", "sequence_no_asc", "id_asc", ], - } - assert first_run.context_pack["semantic_artifact_chunks"] == [] - assert first_run.context_pack["semantic_artifact_chunk_summary"] == { - "requested": False, - "scope": None, - "embedding_config_id": None, - "query_vector_dimensions": 0, - "limit": 0, - "searched_artifact_count": 0, - "candidate_count": 0, - "included_count": 0, - "excluded_uningested_artifact_count": 0, - "excluded_limit_count": 0, - "similarity_metric": None, - "order": ["score_desc", "relative_path_asc", "sequence_no_asc", "id_asc"], + "semantic_order": ["score_desc", "relative_path_asc", "sequence_no_asc", "id_asc"], + "merged_order": [ + "source_precedence_asc", + "lexical_rank_asc", + "semantic_rank_asc", + "relative_path_asc", + "sequence_no_asc", + "id_asc", + ], 
} assert first_run.context_pack["entity_summary"] == { "candidate_count": 3, @@ -643,38 +651,46 @@ def test_compile_continuity_context_records_included_and_excluded_reasons() -> N assert compiler_run.context_pack["artifact_chunks"] == [] assert compiler_run.context_pack["artifact_chunk_summary"] == { "requested": False, + "lexical_requested": False, + "semantic_requested": False, "scope": None, "query": None, "query_terms": [], - "matching_rule": "casefolded_unicode_word_overlap_unique_query_terms_v1", + "embedding_config_id": None, + "query_vector_dimensions": 0, "limit": 0, + "lexical_limit": 0, + "semantic_limit": 0, "searched_artifact_count": 0, - "candidate_count": 0, + "lexical_candidate_count": 0, + "semantic_candidate_count": 0, + "merged_candidate_count": 0, + "deduplicated_count": 0, "included_count": 0, + "included_lexical_only_count": 0, + "included_semantic_only_count": 0, + "included_dual_source_count": 0, "excluded_uningested_artifact_count": 0, "excluded_limit_count": 0, - "order": [ + "matching_rule": None, + "similarity_metric": None, + "source_precedence": ["lexical", "semantic"], + "lexical_order": [ "matched_query_term_count_desc", "first_match_char_start_asc", "relative_path_asc", "sequence_no_asc", "id_asc", ], - } - assert compiler_run.context_pack["semantic_artifact_chunks"] == [] - assert compiler_run.context_pack["semantic_artifact_chunk_summary"] == { - "requested": False, - "scope": None, - "embedding_config_id": None, - "query_vector_dimensions": 0, - "limit": 0, - "searched_artifact_count": 0, - "candidate_count": 0, - "included_count": 0, - "excluded_uningested_artifact_count": 0, - "excluded_limit_count": 0, - "similarity_metric": None, - "order": ["score_desc", "relative_path_asc", "sequence_no_asc", "id_asc"], + "semantic_order": ["score_desc", "relative_path_asc", "sequence_no_asc", "id_asc"], + "merged_order": [ + "source_precedence_asc", + "lexical_rank_asc", + "semantic_rank_asc", + "relative_path_asc", + "sequence_no_asc", + "id_asc", + ], } assert compiler_run.context_pack["entities"] == [ { @@ -710,18 +726,16 @@ def test_compile_continuity_context_records_included_and_excluded_reasons() -> N assert compiler_run.trace_events[-1].payload["hybrid_memory_merged_candidate_count"] == 1 assert compiler_run.trace_events[-1].payload["hybrid_memory_deduplicated_count"] == 0 assert compiler_run.trace_events[-1].payload["artifact_retrieval_requested"] is False - assert compiler_run.trace_events[-1].payload["artifact_chunk_candidate_count"] == 0 + assert compiler_run.trace_events[-1].payload["artifact_lexical_retrieval_requested"] is False + assert compiler_run.trace_events[-1].payload["artifact_semantic_retrieval_requested"] is False + assert compiler_run.trace_events[-1].payload["artifact_lexical_candidate_count"] == 0 + assert compiler_run.trace_events[-1].payload["artifact_semantic_candidate_count"] == 0 + assert compiler_run.trace_events[-1].payload["artifact_merged_candidate_count"] == 0 + assert compiler_run.trace_events[-1].payload["artifact_deduplicated_count"] == 0 assert compiler_run.trace_events[-1].payload["included_artifact_chunk_count"] == 0 + assert compiler_run.trace_events[-1].payload["included_dual_source_artifact_chunk_count"] == 0 assert compiler_run.trace_events[-1].payload["excluded_artifact_chunk_limit_count"] == 0 assert compiler_run.trace_events[-1].payload["excluded_uningested_artifact_count"] == 0 - assert compiler_run.trace_events[-1].payload["semantic_artifact_retrieval_requested"] is False - assert 
compiler_run.trace_events[-1].payload["semantic_artifact_chunk_candidate_count"] == 0 - assert compiler_run.trace_events[-1].payload["included_semantic_artifact_chunk_count"] == 0 - assert ( - compiler_run.trace_events[-1].payload["excluded_semantic_artifact_chunk_limit_count"] - == 0 - ) - assert compiler_run.trace_events[-1].payload["excluded_semantic_uningested_artifact_count"] == 0 class SemanticCompileStoreStub: @@ -994,6 +1008,7 @@ def test_compile_artifact_chunk_section_orders_limits_and_excludes_non_ingested( query="Alpha beta", limit=2, ), + semantic_artifact_retrieval=None, ) assert artifact_section.items == [ @@ -1007,10 +1022,14 @@ def test_compile_artifact_chunk_section_orders_limits_and_excludes_non_ingested( "char_start": 0, "char_end_exclusive": 14, "text": "beta alpha doc", - "match": { - "matched_query_terms": ["alpha", "beta"], - "matched_query_term_count": 2, - "first_match_char_start": 0, + "source_provenance": { + "sources": ["lexical"], + "lexical_match": { + "matched_query_terms": ["alpha", "beta"], + "matched_query_term_count": 2, + "first_match_char_start": 0, + }, + "semantic_score": None, }, }, { @@ -1023,48 +1042,76 @@ def test_compile_artifact_chunk_section_orders_limits_and_excludes_non_ingested( "char_start": 0, "char_end_exclusive": 15, "text": "alpha beta note", - "match": { - "matched_query_terms": ["alpha", "beta"], - "matched_query_term_count": 2, - "first_match_char_start": 0, + "source_provenance": { + "sources": ["lexical"], + "lexical_match": { + "matched_query_terms": ["alpha", "beta"], + "matched_query_term_count": 2, + "first_match_char_start": 0, + }, + "semantic_score": None, }, }, ] assert artifact_section.summary == { "requested": True, + "lexical_requested": True, + "semantic_requested": False, "scope": {"kind": "task", "task_id": str(store.task_id)}, "query": "Alpha beta", "query_terms": ["alpha", "beta"], - "matching_rule": "casefolded_unicode_word_overlap_unique_query_terms_v1", + "embedding_config_id": None, + "query_vector_dimensions": 0, "limit": 2, + "lexical_limit": 2, + "semantic_limit": 0, "searched_artifact_count": 3, - "candidate_count": 3, + "lexical_candidate_count": 3, + "semantic_candidate_count": 0, + "merged_candidate_count": 3, + "deduplicated_count": 0, "included_count": 2, + "included_lexical_only_count": 2, + "included_semantic_only_count": 0, + "included_dual_source_count": 0, "excluded_uningested_artifact_count": 1, "excluded_limit_count": 1, - "order": [ + "matching_rule": "casefolded_unicode_word_overlap_unique_query_terms_v1", + "similarity_metric": None, + "source_precedence": ["lexical", "semantic"], + "lexical_order": [ "matched_query_term_count_desc", "first_match_char_start_asc", "relative_path_asc", "sequence_no_asc", "id_asc", ], + "semantic_order": ["score_desc", "relative_path_asc", "sequence_no_asc", "id_asc"], + "merged_order": [ + "source_precedence_asc", + "lexical_rank_asc", + "semantic_rank_asc", + "relative_path_asc", + "sequence_no_asc", + "id_asc", + ], } assert [decision.reason for decision in artifact_section.decisions] == [ - "artifact_not_ingested", - "within_artifact_chunk_limit", - "within_artifact_chunk_limit", - "artifact_chunk_limit_exceeded", + "hybrid_artifact_not_ingested", + "within_hybrid_artifact_chunk_limit", + "within_hybrid_artifact_chunk_limit", + "hybrid_artifact_chunk_limit_exceeded", ] assert artifact_section.decisions[0].metadata["relative_path"] == "notes/hidden.txt" assert artifact_section.decisions[-1].metadata["relative_path"] == "notes/c.txt" -def 
test_compile_semantic_artifact_chunk_section_orders_limits_and_excludes_non_ingested() -> None: +def test_compile_artifact_chunk_section_supports_semantic_only_scope() -> None: store = ArtifactCompileStoreStub() - semantic_artifact_section = _compile_semantic_artifact_chunk_section( + artifact_section = _compile_artifact_chunk_section( store, # type: ignore[arg-type] + artifact_retrieval=None, semantic_artifact_retrieval=CompileContextTaskScopedSemanticArtifactRetrievalInput( task_id=store.task_id, embedding_config_id=store.config_id, @@ -1073,7 +1120,7 @@ def test_compile_semantic_artifact_chunk_section_orders_limits_and_excludes_non_ ), ) - assert semantic_artifact_section.items == [ + assert artifact_section.items == [ { "id": str(store.chunk_ids[0]), "task_id": str(store.task_id), @@ -1084,7 +1131,11 @@ def test_compile_semantic_artifact_chunk_section_orders_limits_and_excludes_non_ "char_start": 0, "char_end_exclusive": 14, "text": "beta alpha doc", - "score": 1.0, + "source_provenance": { + "sources": ["semantic"], + "lexical_match": None, + "semantic_score": 1.0, + }, }, { "id": str(store.chunk_ids[1]), @@ -1096,38 +1147,76 @@ def test_compile_semantic_artifact_chunk_section_orders_limits_and_excludes_non_ "char_start": 0, "char_end_exclusive": 15, "text": "alpha beta note", - "score": 1.0, + "source_provenance": { + "sources": ["semantic"], + "lexical_match": None, + "semantic_score": 1.0, + }, }, ] - assert semantic_artifact_section.summary == { + assert artifact_section.summary == { "requested": True, + "lexical_requested": False, + "semantic_requested": True, "scope": {"kind": "task", "task_id": str(store.task_id)}, + "query": None, + "query_terms": [], "embedding_config_id": str(store.config_id), "query_vector_dimensions": 3, "limit": 2, + "lexical_limit": 0, + "semantic_limit": 2, "searched_artifact_count": 3, - "candidate_count": 3, + "lexical_candidate_count": 0, + "semantic_candidate_count": 3, + "merged_candidate_count": 3, + "deduplicated_count": 0, "included_count": 2, + "included_lexical_only_count": 0, + "included_semantic_only_count": 2, + "included_dual_source_count": 0, "excluded_uningested_artifact_count": 1, "excluded_limit_count": 1, + "matching_rule": None, "similarity_metric": "cosine_similarity", - "order": ["score_desc", "relative_path_asc", "sequence_no_asc", "id_asc"], + "source_precedence": ["lexical", "semantic"], + "lexical_order": [ + "matched_query_term_count_desc", + "first_match_char_start_asc", + "relative_path_asc", + "sequence_no_asc", + "id_asc", + ], + "semantic_order": ["score_desc", "relative_path_asc", "sequence_no_asc", "id_asc"], + "merged_order": [ + "source_precedence_asc", + "lexical_rank_asc", + "semantic_rank_asc", + "relative_path_asc", + "sequence_no_asc", + "id_asc", + ], } - assert [decision.reason for decision in semantic_artifact_section.decisions] == [ - "semantic_artifact_not_ingested", - "within_semantic_artifact_chunk_limit", - "within_semantic_artifact_chunk_limit", - "semantic_artifact_chunk_limit_exceeded", + assert [decision.reason for decision in artifact_section.decisions] == [ + "hybrid_artifact_not_ingested", + "within_hybrid_artifact_chunk_limit", + "within_hybrid_artifact_chunk_limit", + "hybrid_artifact_chunk_limit_exceeded", ] - assert semantic_artifact_section.decisions[0].metadata["relative_path"] == "notes/hidden.txt" - assert semantic_artifact_section.decisions[-1].metadata["relative_path"] == "notes/c.txt" + assert artifact_section.decisions[0].metadata["relative_path"] == "notes/hidden.txt" + assert 
artifact_section.decisions[-1].metadata["relative_path"] == "notes/c.txt" -def test_compile_semantic_artifact_chunk_section_supports_artifact_scope() -> None: +def test_compile_artifact_chunk_section_merges_dual_source_provenance_for_artifact_scope() -> None: store = ArtifactCompileStoreStub() - semantic_artifact_section = _compile_semantic_artifact_chunk_section( + artifact_section = _compile_artifact_chunk_section( store, # type: ignore[arg-type] + artifact_retrieval=CompileContextArtifactScopedArtifactRetrievalInput( + task_artifact_id=store.artifact_ids[1], + query="Alpha beta", + limit=2, + ), semantic_artifact_retrieval=CompileContextArtifactScopedSemanticArtifactRetrievalInput( task_artifact_id=store.artifact_ids[1], embedding_config_id=store.config_id, @@ -1136,7 +1225,7 @@ def test_compile_semantic_artifact_chunk_section_supports_artifact_scope() -> No ), ) - assert semantic_artifact_section.items == [ + assert artifact_section.items == [ { "id": str(store.chunk_ids[1]), "task_id": str(store.task_id), @@ -1147,28 +1236,69 @@ def test_compile_semantic_artifact_chunk_section_supports_artifact_scope() -> No "char_start": 0, "char_end_exclusive": 15, "text": "alpha beta note", - "score": 1.0, + "source_provenance": { + "sources": ["lexical", "semantic"], + "lexical_match": { + "matched_query_terms": ["alpha", "beta"], + "matched_query_term_count": 2, + "first_match_char_start": 0, + }, + "semantic_score": 1.0, + }, } ] - assert semantic_artifact_section.summary == { + assert artifact_section.summary == { "requested": True, + "lexical_requested": True, + "semantic_requested": True, "scope": { "kind": "artifact", "task_id": str(store.task_id), "task_artifact_id": str(store.artifact_ids[1]), }, + "query": "Alpha beta", + "query_terms": ["alpha", "beta"], "embedding_config_id": str(store.config_id), "query_vector_dimensions": 3, "limit": 2, + "lexical_limit": 2, + "semantic_limit": 2, "searched_artifact_count": 1, - "candidate_count": 1, + "lexical_candidate_count": 1, + "semantic_candidate_count": 1, + "merged_candidate_count": 1, + "deduplicated_count": 1, "included_count": 1, + "included_lexical_only_count": 0, + "included_semantic_only_count": 0, + "included_dual_source_count": 1, "excluded_uningested_artifact_count": 0, "excluded_limit_count": 0, + "matching_rule": "casefolded_unicode_word_overlap_unique_query_terms_v1", "similarity_metric": "cosine_similarity", - "order": ["score_desc", "relative_path_asc", "sequence_no_asc", "id_asc"], + "source_precedence": ["lexical", "semantic"], + "lexical_order": [ + "matched_query_term_count_desc", + "first_match_char_start_asc", + "relative_path_asc", + "sequence_no_asc", + "id_asc", + ], + "semantic_order": ["score_desc", "relative_path_asc", "sequence_no_asc", "id_asc"], + "merged_order": [ + "source_precedence_asc", + "lexical_rank_asc", + "semantic_rank_asc", + "relative_path_asc", + "sequence_no_asc", + "id_asc", + ], } - assert semantic_artifact_section.decisions[0].metadata["scope_kind"] == "artifact" + assert [decision.reason for decision in artifact_section.decisions] == [ + "hybrid_artifact_chunk_deduplicated", + "within_hybrid_artifact_chunk_limit", + ] + assert artifact_section.decisions[1].metadata["selected_sources"] == ["lexical", "semantic"] def test_compile_memory_section_orders_limits_and_excludes_deleted() -> None: diff --git a/tests/unit/test_main.py b/tests/unit/test_main.py index c8e63b8..67ad9a8 100644 --- a/tests/unit/test_main.py +++ b/tests/unit/test_main.py @@ -275,38 +275,46 @@ def fake_compile_and_persist_trace( 
"artifact_chunks": [], "artifact_chunk_summary": { "requested": False, + "lexical_requested": False, + "semantic_requested": False, "scope": None, "query": None, "query_terms": [], - "matching_rule": "casefolded_unicode_word_overlap_unique_query_terms_v1", + "embedding_config_id": None, + "query_vector_dimensions": 0, "limit": 0, + "lexical_limit": 0, + "semantic_limit": 0, "searched_artifact_count": 0, - "candidate_count": 0, + "lexical_candidate_count": 0, + "semantic_candidate_count": 0, + "merged_candidate_count": 0, + "deduplicated_count": 0, "included_count": 0, + "included_lexical_only_count": 0, + "included_semantic_only_count": 0, + "included_dual_source_count": 0, "excluded_uningested_artifact_count": 0, "excluded_limit_count": 0, - "order": [ + "matching_rule": None, + "similarity_metric": None, + "source_precedence": ["lexical", "semantic"], + "lexical_order": [ "matched_query_term_count_desc", "first_match_char_start_asc", "relative_path_asc", "sequence_no_asc", "id_asc", ], - }, - "semantic_artifact_chunks": [], - "semantic_artifact_chunk_summary": { - "requested": False, - "scope": None, - "embedding_config_id": None, - "query_vector_dimensions": 0, - "limit": 0, - "searched_artifact_count": 0, - "candidate_count": 0, - "included_count": 0, - "excluded_uningested_artifact_count": 0, - "excluded_limit_count": 0, - "similarity_metric": None, - "order": ["score_desc", "relative_path_asc", "sequence_no_asc", "id_asc"], + "semantic_order": ["score_desc", "relative_path_asc", "sequence_no_asc", "id_asc"], + "merged_order": [ + "source_precedence_asc", + "lexical_rank_asc", + "semantic_rank_asc", + "relative_path_asc", + "sequence_no_asc", + "id_asc", + ], }, "entities": [ { @@ -425,38 +433,46 @@ def fake_compile_and_persist_trace( "artifact_chunks": [], "artifact_chunk_summary": { "requested": False, + "lexical_requested": False, + "semantic_requested": False, "scope": None, "query": None, "query_terms": [], - "matching_rule": "casefolded_unicode_word_overlap_unique_query_terms_v1", + "embedding_config_id": None, + "query_vector_dimensions": 0, "limit": 0, + "lexical_limit": 0, + "semantic_limit": 0, "searched_artifact_count": 0, - "candidate_count": 0, + "lexical_candidate_count": 0, + "semantic_candidate_count": 0, + "merged_candidate_count": 0, + "deduplicated_count": 0, "included_count": 0, + "included_lexical_only_count": 0, + "included_semantic_only_count": 0, + "included_dual_source_count": 0, "excluded_uningested_artifact_count": 0, "excluded_limit_count": 0, - "order": [ + "matching_rule": None, + "similarity_metric": None, + "source_precedence": ["lexical", "semantic"], + "lexical_order": [ "matched_query_term_count_desc", "first_match_char_start_asc", "relative_path_asc", "sequence_no_asc", "id_asc", ], - }, - "semantic_artifact_chunks": [], - "semantic_artifact_chunk_summary": { - "requested": False, - "scope": None, - "embedding_config_id": None, - "query_vector_dimensions": 0, - "limit": 0, - "searched_artifact_count": 0, - "candidate_count": 0, - "included_count": 0, - "excluded_uningested_artifact_count": 0, - "excluded_limit_count": 0, - "similarity_metric": None, - "order": ["score_desc", "relative_path_asc", "sequence_no_asc", "id_asc"], + "semantic_order": ["score_desc", "relative_path_asc", "sequence_no_asc", "id_asc"], + "merged_order": [ + "source_precedence_asc", + "lexical_rank_asc", + "semantic_rank_asc", + "relative_path_asc", + "sequence_no_asc", + "id_asc", + ], }, "entities": [ { @@ -639,60 +655,59 @@ def fake_compile_and_persist_trace( "char_start": 
0, "char_end_exclusive": 16, "text": "alpha beta spec", - "match": { - "matched_query_terms": ["alpha", "beta"], - "matched_query_term_count": 2, - "first_match_char_start": 0, + "source_provenance": { + "sources": ["lexical", "semantic"], + "lexical_match": { + "matched_query_terms": ["alpha", "beta"], + "matched_query_term_count": 2, + "first_match_char_start": 0, + }, + "semantic_score": 0.99, }, } ], "artifact_chunk_summary": { "requested": True, + "lexical_requested": True, + "semantic_requested": True, "scope": {"kind": "task", "task_id": "task-123"}, "query": "alpha beta", "query_terms": ["alpha", "beta"], - "matching_rule": "casefolded_unicode_word_overlap_unique_query_terms_v1", + "embedding_config_id": str(config_id), + "query_vector_dimensions": 3, "limit": 2, + "lexical_limit": 2, + "semantic_limit": 2, "searched_artifact_count": 1, - "candidate_count": 1, + "lexical_candidate_count": 1, + "semantic_candidate_count": 1, + "merged_candidate_count": 1, + "deduplicated_count": 1, "included_count": 1, + "included_lexical_only_count": 0, + "included_semantic_only_count": 0, + "included_dual_source_count": 1, "excluded_uningested_artifact_count": 0, "excluded_limit_count": 0, - "order": [ + "matching_rule": "casefolded_unicode_word_overlap_unique_query_terms_v1", + "similarity_metric": "cosine_similarity", + "source_precedence": ["lexical", "semantic"], + "lexical_order": [ "matched_query_term_count_desc", "first_match_char_start_asc", "relative_path_asc", "sequence_no_asc", "id_asc", ], - }, - "semantic_artifact_chunks": [ - { - "id": "semantic-chunk-123", - "task_id": "task-123", - "task_artifact_id": "artifact-123", - "relative_path": "docs/spec.txt", - "media_type": "text/plain", - "sequence_no": 1, - "char_start": 0, - "char_end_exclusive": 16, - "text": "alpha beta spec", - "score": 0.99, - } - ], - "semantic_artifact_chunk_summary": { - "requested": True, - "scope": {"kind": "task", "task_id": "task-123"}, - "embedding_config_id": str(config_id), - "query_vector_dimensions": 3, - "limit": 2, - "searched_artifact_count": 1, - "candidate_count": 1, - "included_count": 1, - "excluded_uningested_artifact_count": 0, - "excluded_limit_count": 0, - "similarity_metric": "cosine_similarity", - "order": ["score_desc", "relative_path_asc", "sequence_no_asc", "id_asc"], + "semantic_order": ["score_desc", "relative_path_asc", "sequence_no_asc", "id_asc"], + "merged_order": [ + "source_precedence_asc", + "lexical_rank_asc", + "semantic_rank_asc", + "relative_path_asc", + "sequence_no_asc", + "id_asc", + ], }, "entities": [], "entity_summary": { diff --git a/tests/unit/test_response_generation.py b/tests/unit/test_response_generation.py index 59cbd40..9c6c22c 100644 --- a/tests/unit/test_response_generation.py +++ b/tests/unit/test_response_generation.py @@ -91,38 +91,46 @@ def make_context_pack() -> dict[str, object]: "artifact_chunks": [], "artifact_chunk_summary": { "requested": False, + "lexical_requested": False, + "semantic_requested": False, "scope": None, "query": None, "query_terms": [], - "matching_rule": "casefolded_unicode_word_overlap_unique_query_terms_v1", + "embedding_config_id": None, + "query_vector_dimensions": 0, "limit": 0, + "lexical_limit": 0, + "semantic_limit": 0, "searched_artifact_count": 0, - "candidate_count": 0, + "lexical_candidate_count": 0, + "semantic_candidate_count": 0, + "merged_candidate_count": 0, + "deduplicated_count": 0, "included_count": 0, + "included_lexical_only_count": 0, + "included_semantic_only_count": 0, + "included_dual_source_count": 0, 
"excluded_uningested_artifact_count": 0, "excluded_limit_count": 0, - "order": [ + "matching_rule": None, + "similarity_metric": None, + "source_precedence": ["lexical", "semantic"], + "lexical_order": [ "matched_query_term_count_desc", "first_match_char_start_asc", "relative_path_asc", "sequence_no_asc", "id_asc", ], - }, - "semantic_artifact_chunks": [], - "semantic_artifact_chunk_summary": { - "requested": False, - "scope": None, - "embedding_config_id": None, - "query_vector_dimensions": 0, - "limit": 0, - "searched_artifact_count": 0, - "candidate_count": 0, - "included_count": 0, - "excluded_uningested_artifact_count": 0, - "excluded_limit_count": 0, - "similarity_metric": None, - "order": ["score_desc", "relative_path_asc", "sequence_no_asc", "id_asc"], + "semantic_order": ["score_desc", "relative_path_asc", "sequence_no_asc", "id_asc"], + "merged_order": [ + "source_precedence_asc", + "lexical_rank_asc", + "semantic_rank_asc", + "relative_path_asc", + "sequence_no_asc", + "id_asc", + ], }, "entities": [], "entity_summary": { From dfd36943c40204d87c0b7c29a608a6bb3bd760eb Mon Sep 17 00:00:00 2001 From: Sam Date: Mon, 16 Mar 2026 09:26:19 +0100 Subject: [PATCH 011/135] Sprint 5K: project truth synchronization (#11) * Sprint 5K: project truth sync packet * Sprint 5K: project truth synchronization --------- Co-authored-by: Sami Rusani --- ARCHITECTURE.md | 35 ++++++++++++++++++++++------------- ROADMAP.md | 36 +++++++++++++++++++----------------- 2 files changed, 41 insertions(+), 30 deletions(-) diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md index 0d939f7..168429a 100644 --- a/ARCHITECTURE.md +++ b/ARCHITECTURE.md @@ -2,16 +2,16 @@ ## Current Implemented Slice -AliceBot now implements the accepted repo slice through Sprint 5H. The shipped backend includes: +AliceBot now implements the accepted repo slice through Sprint 5J. 
The shipped backend includes: - foundation continuity storage over `users`, `threads`, `sessions`, and append-only `events` - deterministic tracing and context compilation over durable continuity, memory, entity, and entity-edge records - governed memory admission, explicit-preference extraction, memory review labels, review queue reads, evaluation summary reads, explicit embedding config and memory-embedding storage, direct semantic retrieval, and deterministic hybrid compile-path memory merge - deterministic prompt assembly and one no-tools response path that persists assistant replies as immutable continuity events - user-scoped consents, policies, policy evaluation, tool registry, allowlist evaluation, tool routing, approval request persistence, approval resolution, approved-only proxy execution through the in-process `proxy.echo` handler, durable execution review, and execution-budget lifecycle plus enforcement -- durable `tasks`, `task_steps`, `task_workspaces`, `task_artifacts`, `task_artifact_chunks`, and `task_artifact_chunk_embeddings`, deterministic task-step sequencing, explicit task-step transitions, explicit manual continuation with lineage through `parent_step_id`, `source_approval_id`, and `source_execution_id`, explicit `tool_executions.task_step_id` linkage for execution synchronization, deterministic rooted local task-workspace provisioning, explicit rooted local artifact registration, deterministic local text-artifact ingestion into durable chunk rows, deterministic lexical artifact-chunk retrieval over durable chunk rows, optional compile-path artifact chunk inclusion as a separate context section, explicit user-scoped artifact-chunk embedding persistence tied to existing embedding configs, and explicit task-scoped or artifact-scoped semantic artifact-chunk retrieval over those durable embeddings +- durable `tasks`, `task_steps`, `task_workspaces`, `task_artifacts`, `task_artifact_chunks`, and `task_artifact_chunk_embeddings`, deterministic task-step sequencing, explicit task-step transitions, explicit manual continuation with lineage through `parent_step_id`, `source_approval_id`, and `source_execution_id`, explicit `tool_executions.task_step_id` linkage for execution synchronization, deterministic rooted local task-workspace provisioning, explicit rooted local artifact registration, deterministic local text-artifact ingestion into durable chunk rows, deterministic lexical artifact-chunk retrieval over durable chunk rows, explicit user-scoped artifact-chunk embedding persistence tied to existing embedding configs, explicit task-scoped or artifact-scoped semantic artifact-chunk retrieval over those durable embeddings, and compile-path artifact retrieval that can include lexical results, semantic results, or one deterministic hybrid lexical-plus-semantic merged artifact section with per-chunk source provenance -The current multi-step boundary is narrow and explicit. Manual continuation is implemented and review-passed. Approval resolution and proxy execution now both use explicit task-step linkage rather than first-step inference. 
Task workspaces are now implemented only as deterministic rooted local boundaries, and task artifacts are now implemented only as explicit rooted local-file registrations, narrow deterministic text ingestion under those workspaces, lexical retrieval over persisted chunk rows, optional compile-path inclusion of retrieved artifact chunks in a separate response section, explicit artifact-chunk embedding storage tied to existing embedding configs, and direct semantic retrieval over those durable artifact-chunk embeddings for one visible task or one visible artifact at a time. Broader runner-style orchestration, automatic multi-step progression, compile-path semantic artifact use, hybrid artifact retrieval, richer document parsing, connectors, and new side-effect surfaces are still planned later and must not be described as live behavior. +The current multi-step boundary is narrow and explicit. Manual continuation is implemented and review-passed. Approval resolution and proxy execution now both use explicit task-step linkage rather than first-step inference. Task workspaces are now implemented only as deterministic rooted local boundaries, and task artifacts are now implemented only as explicit rooted local-file registrations, narrow deterministic text ingestion under those workspaces, lexical retrieval over persisted chunk rows, explicit artifact-chunk embedding storage tied to existing embedding configs, direct semantic retrieval over those durable artifact-chunk embeddings for one visible task or one visible artifact at a time, and compile-path artifact retrieval that deterministically merges lexical and semantic candidates into one artifact section when both are requested for the same scope. Broader runner-style orchestration, automatic multi-step progression, richer document parsing, connectors, reranking beyond the current lexical-first hybrid merge, and new side-effect surfaces are still planned later and must not be described as live behavior. ## Implemented Now @@ -58,11 +58,11 @@ The current multi-step boundary is narrow and explicit. Manual continuation is i ### Repo Boundaries In This Slice -- `apps/api`: implemented API, store, contracts, service logic, and migrations for continuity, tracing, memory, embeddings, entities, policies, tools, approvals, proxy execution, execution budgets, tasks, task steps, task workspaces, task artifacts, artifact-chunk embeddings, deterministic lexical artifact chunk retrieval, deterministic semantic artifact chunk retrieval over durable embeddings, and narrow compile-path artifact chunk inclusion. +- `apps/api`: implemented API, store, contracts, service logic, and migrations for continuity, tracing, memory, embeddings, entities, policies, tools, approvals, proxy execution, execution budgets, tasks, task steps, task workspaces, task artifacts, artifact-chunk embeddings, deterministic lexical artifact chunk retrieval, deterministic semantic artifact chunk retrieval over durable embeddings, compile-path semantic artifact retrieval, and deterministic hybrid lexical-plus-semantic artifact merge in compile. - `apps/web`: minimal shell only; no shipped workflow UI. - `workers`: scaffold only; no background jobs or runner logic are implemented. - `infra`: local development bootstrap assets only. 
-- `tests`: unit and Postgres-backed integration coverage for the shipped seams above, including Sprint 4O task-step lineage/manual continuation, Sprint 4S step-linked execution synchronization, Sprint 5A task-workspace provisioning, Sprint 5C task-artifact registration, Sprint 5D local artifact ingestion plus chunk reads, Sprint 5E lexical artifact-chunk retrieval, Sprint 5F compile-path artifact chunk integration, Sprint 5G artifact-chunk embedding persistence and reads, and Sprint 5H semantic artifact-chunk retrieval. +- `tests`: unit and Postgres-backed integration coverage for the shipped seams above, including Sprint 4O task-step lineage/manual continuation, Sprint 4S step-linked execution synchronization, Sprint 5A task-workspace provisioning, Sprint 5C task-artifact registration, Sprint 5D local artifact ingestion plus chunk reads, Sprint 5E lexical artifact-chunk retrieval, Sprint 5F compile-path artifact chunk integration, Sprint 5G artifact-chunk embedding persistence and reads, Sprint 5H direct semantic artifact-chunk retrieval, Sprint 5I compile-path semantic artifact retrieval, and Sprint 5J deterministic hybrid lexical-plus-semantic artifact merge in compile. ## Core Flows Implemented Now @@ -71,10 +71,12 @@ The current multi-step boundary is narrow and explicit. Manual continuation is i 1. Accept a user-scoped `POST /v0/context/compile` request. 2. Read durable continuity records in deterministic order. 3. Merge in active memories, entities, and entity edges through the currently shipped symbolic and optional semantic retrieval paths. -4. Optionally retrieve artifact chunks through the existing lexical artifact-chunk retrieval seam, scoped to exactly one visible task or one visible artifact per request. -5. Keep retrieved artifact chunks separate from memory and entity sections, with deterministic per-section limits and ordering. -6. Persist a `context.compile` trace plus explicit inclusion and exclusion events, including artifact chunk include/exclude decisions. -7. Return one deterministic `context_pack` describing scope, limits, selected context, artifact chunk results, and trace metadata. +4. Optionally retrieve artifact chunks through lexical retrieval, semantic retrieval, or both, scoped to exactly one visible task or one visible artifact per request. +5. Reuse only persisted `task_artifact_chunks` rows and persisted artifact-chunk embeddings during compile; compile does not read raw files. +6. When both artifact retrieval modes are present for the same scope, merge candidates by durable chunk id into one `artifact_chunks` section, preserve lexical match and semantic score provenance, and apply deterministic lexical-first source precedence. +7. Keep retrieved artifact chunks separate from memory and entity sections, with deterministic per-section limits, ordering, and summary metadata. +8. Persist a `context.compile` trace plus explicit inclusion and exclusion events, including artifact chunk deduplication, inclusion, and exclusion decisions. +9. Return one deterministic `context_pack` describing scope, limits, selected context, artifact chunk results, and trace metadata. ### Artifact Chunk Retrieval @@ -83,7 +85,10 @@ The current multi-step boundary is narrow and explicit. Manual continuation is i 3. Support deterministic lexical artifact-chunk retrieval for one visible task or one visible artifact. 4. 
Support deterministic semantic artifact-chunk retrieval for one visible task or one visible artifact, using a caller-supplied query vector plus explicit `embedding_config_id`. 5. Exclude artifacts whose `ingestion_status` is not `ingested`. -6. Keep compile-path artifact retrieval separate and lexical-only for now; semantic artifact retrieval remains a direct read seam outside compile. +6. Reuse those same persisted lexical and semantic retrieval seams inside compile for one visible task or one visible artifact. +7. When compile receives both lexical and semantic artifact retrieval for the same scope, deduplicate by durable chunk id, preserve per-chunk source provenance, and count dual-source inclusions explicitly. +8. Order hybrid compile candidates deterministically by source precedence, lexical rank, semantic rank, `relative_path`, `sequence_no`, and `id`. +9. Return stable summary metadata covering scope, query terms, embedding config, query-vector dimensions, candidate counts, deduplication counts, inclusion counts, exclusion counts, and ordering rules. ### Governed Memory And Retrieval @@ -236,8 +241,8 @@ The current multi-step boundary is narrow and explicit. Manual continuation is i ## Testing Coverage Implemented Now -- Unit and integration tests cover continuity, compiler, response generation, memory admission, review labels, review queue, embeddings, semantic retrieval, artifact semantic retrieval, entities, policies, tools, approvals, proxy execution, execution budgets, and execution review. -- Sprint 4O, Sprint 4S, Sprint 5A, and Sprint 5C added explicit task lifecycle coverage: +- Unit and integration tests cover continuity, compiler, response generation, memory admission, review labels, review queue, embeddings, semantic retrieval, compile-path hybrid memory retrieval, artifact lexical retrieval, artifact semantic retrieval, compile-path semantic artifact retrieval, hybrid artifact compile merge, entities, policies, tools, approvals, proxy execution, execution budgets, and execution review. +- Sprints 4O through 5J added explicit task lifecycle and artifact retrieval coverage: - migrations for `tasks`, `task_steps`, and task-step lineage - staged/backfilled migration coverage for `tool_executions.task_step_id` - task and task-step store contracts @@ -259,6 +264,10 @@ The current multi-step boundary is narrow and explicit. Manual continuation is i - idempotent re-ingestion of already ingested artifacts - deterministic lexical artifact-chunk retrieval by task and by artifact - compile-path artifact chunk inclusion, exclusion, ordering, and per-user isolation + - artifact-chunk embedding write and read coverage + - direct semantic artifact-chunk retrieval by task and by artifact + - compile-path semantic artifact retrieval including trace visibility, exclusion rules, and scope isolation + - deterministic hybrid artifact compile merge with dual-source provenance, deduplication, lexical-first precedence, and shared limit enforcement - task-artifact and task-artifact-chunk per-user isolation - trace visibility for continuation and transition events - user isolation for task and task-step reads and mutations @@ -269,7 +278,7 @@ The current multi-step boundary is narrow and explicit. 
Manual continuation is i The following areas remain planned later and must not be described as implemented: - runner-style orchestration and automatic multi-step progression beyond the current explicit manual continuation seam -- hybrid lexical plus semantic artifact retrieval, compile-path semantic artifact use, and reranking beyond the current direct lexical and direct semantic ordering seams +- artifact reranking, weighted fusion, or precedence changes beyond the current lexical-first hybrid compile merge and direct lexical/direct semantic ordering seams - rich document parsing beyond the current narrow UTF-8 text and markdown ingestion boundary - read-only Gmail and Calendar connectors - broader tool proxying and real-world side effects beyond the current no-I/O `proxy.echo` handler diff --git a/ROADMAP.md b/ROADMAP.md index a963227..a3d9e21 100644 --- a/ROADMAP.md +++ b/ROADMAP.md @@ -2,38 +2,40 @@ ## Current Position -- The accepted repo state is current through Sprint 5A. -- The backend foundation through governance, execution review, task/task-step lifecycle, explicit manual continuation, step-linked approval/execution synchronization, and deterministic rooted task-workspace provisioning is already shipped. -- This roadmap is future-facing from that position; milestone history lives in archived sprint reports, not here. +- The accepted repo state is current through Sprint 5J. +- Milestone 5 now ships the rooted local workspace and artifact baseline end to end: workspace provisioning, artifact registration, narrow text ingestion, durable chunk storage, lexical artifact retrieval, compile-path artifact inclusion, artifact-chunk embeddings, direct semantic artifact retrieval, compile-path semantic artifact retrieval, and deterministic hybrid lexical-plus-semantic artifact merge in compile. +- This roadmap is future-facing from that shipped baseline; historical sprint-by-sprint detail lives in accepted build and review artifacts, not here. ## Next Delivery Focus -### Finish Milestone 5 On Top Of The Shipped Workspace Boundary +### Open Richer Document Parsing On Top Of The Shipped Artifact Retrieval Baseline -- Add artifact records and artifact-handling rules that reuse `task_workspaces` instead of inventing a parallel storage seam. -- Add document ingestion and retrieval only after the artifact/workspace boundary is explicit and reviewable. -- Add read-only Gmail and Calendar connectors only after document and workspace boundaries remain deterministic under the current governance model. +- Extend ingestion beyond the current `text/plain` and `text/markdown` seam without changing the rooted `task_workspaces` and durable `task_artifact_chunks` contracts. +- Keep retrieval building on persisted chunk rows and persisted embeddings; new parsing work should feed the existing compile-path lexical/semantic/hybrid artifact retrieval seam rather than inventing a parallel context path. +- Keep the next sprint narrow: richer document parsing first, then reassess connectors only after the parsing seam is accepted. -### Preserve Current Governance And Task Guarantees +### Preserve Current Compile, Governance, And Task Guarantees -- Keep approvals, execution budgets, task/task-step state, and trace visibility deterministic as new Milestone 5 work lands. -- Do not widen the current no-external-I/O proxy surface or introduce new consequential side effects without an explicit sprint opening that scope. 
+- Keep approvals, execution budgets, task/task-step state, and trace visibility deterministic as Milestone 5 continues. +- Preserve the shipped compile contract of one merged artifact section with explicit source provenance, deterministic lexical-first precedence, and trace-visible inclusion and exclusion decisions. +- Do not widen the current no-external-I/O proxy surface or introduce runner, connector, or UI scope until those areas are explicitly opened. -## After Milestone 5 +## After The Next Narrow Sprint -- Revisit broader task orchestration only after the current explicit task-step seams remain stable under workspace, artifact, and document flows. -- Expand tool execution breadth only after governance, review, and budget controls still hold under the wider task surface. -- Address production-facing auth and deployment hardening as the product approaches broader real-world use. +- Open read-only connector work only after richer document parsing remains deterministic under the current artifact and governance seams. +- Revisit workflow UI only after backend document and connector seams are accepted and the truth artifacts stay current. +- Revisit broader task orchestration only after the current explicit task-step seams remain stable under workspace, artifact, document, and connector flows. +- Continue to defer broader tool execution breadth and production auth/deployment hardening until the current governed surface remains stable. ## Dependencies - Live truth docs must stay synchronized with accepted repo state so sprint planning does not start from stale assumptions. -- Artifact and document work should build on the existing rooted local workspace contract. -- Connector work should remain read-only and approval-aware. +- Rich document parsing should build on the shipped rooted local workspace, durable artifact chunk, and hybrid compile retrieval contracts. +- Connector work should remain read-only, approval-aware, and downstream of the document parsing seam. - Runner-style orchestration should stay deferred until the repo no longer depends on narrow current-step assumptions for safety and explainability. ## Ongoing Risks - Memory extraction and retrieval quality remain the largest product risk. - Auth beyond database user context is still missing. -- Milestone 5 can drift if artifact, document, connector, and orchestration work are mixed into one sprint instead of landing as narrow seams. +- Milestone 5 can drift if richer document parsing, connectors, UI, and orchestration work are mixed into one sprint instead of landing as narrow seams on top of the shipped artifact retrieval baseline. 
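The Sprint 5K truth updates above describe the hybrid compile merge only declaratively: deduplicate by durable chunk id, keep per-chunk source provenance, apply lexical-first precedence, then order by the merged_order tie-breakers. A minimal standalone sketch of that merge rule follows. The candidate dict shape, the helper name, and deriving ranks from each source's already-deterministic input order are illustrative assumptions for this sketch, not the shipped compiler.py code.

"""Minimal sketch of the hybrid artifact-chunk merge described above.

Assumptions (not taken from the shipped compiler): lexical and semantic
candidates arrive as plain dicts already sorted by their own deterministic
per-source ordering, each carrying the durable chunk "id", "relative_path",
and "sequence_no"; semantic candidates also carry a "semantic_score".
"""

from __future__ import annotations


def merge_artifact_chunk_candidates(
    lexical: list[dict],
    semantic: list[dict],
    *,
    limit: int,
) -> list[dict]:
    # Deduplicate by durable chunk id while preserving per-source provenance.
    merged: dict[str, dict] = {}
    for rank, candidate in enumerate(lexical, start=1):
        merged[candidate["id"]] = {
            **candidate,
            "sources": ["lexical"],
            "lexical_rank": rank,
            "semantic_rank": None,
        }
    for rank, candidate in enumerate(semantic, start=1):
        existing = merged.get(candidate["id"])
        if existing is None:
            merged[candidate["id"]] = {
                **candidate,
                "sources": ["semantic"],
                "lexical_rank": None,
                "semantic_rank": rank,
            }
        else:
            # Dual-source chunk: keep lexical provenance, record semantic score.
            existing["sources"] = ["lexical", "semantic"]
            existing["semantic_rank"] = rank
            existing["semantic_score"] = candidate.get("semantic_score")

    def sort_key(chunk: dict) -> tuple:
        # source_precedence_asc: lexically sourced chunks sort ahead of
        # semantic-only chunks, then lexical rank, semantic rank, and the
        # stable path / sequence / id tie-breakers named in merged_order.
        source_precedence = 0 if "lexical" in chunk["sources"] else 1
        lexical_rank = chunk["lexical_rank"] if chunk["lexical_rank"] is not None else float("inf")
        semantic_rank = chunk["semantic_rank"] if chunk["semantic_rank"] is not None else float("inf")
        return (
            source_precedence,
            lexical_rank,
            semantic_rank,
            chunk["relative_path"],
            chunk["sequence_no"],
            chunk["id"],
        )

    # The shared limit applies only after the deterministic sort.
    return sorted(merged.values(), key=sort_key)[:limit]

Under these assumptions a chunk returned by both retrieval modes appears exactly once with both sources recorded, which is what the included_dual_source_count field in the summaries above counts.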
From 9635d08c6f234c35bc89dd9b3550d13ea532fce7 Mon Sep 17 00:00:00 2001 From: Sam Date: Mon, 16 Mar 2026 09:56:10 +0100 Subject: [PATCH 012/135] Sprint 5L: PDF artifact parsing v0 (#12) * Sprint 5L: PDF artifact parsing packet * Sprint 5L: PDF artifact parsing v0 --------- Co-authored-by: Sami Rusani --- apps/api/src/alicebot_api/artifacts.py | 601 ++++++++++++++++++- tests/integration/test_task_artifacts_api.py | 283 ++++++++- tests/unit/test_artifacts.py | 216 ++++++- tests/unit/test_artifacts_main.py | 8 +- 4 files changed, 1073 insertions(+), 35 deletions(-) diff --git a/apps/api/src/alicebot_api/artifacts.py b/apps/api/src/alicebot_api/artifacts.py index d3b794f..3e12638 100644 --- a/apps/api/src/alicebot_api/artifacts.py +++ b/apps/api/src/alicebot_api/artifacts.py @@ -1,6 +1,8 @@ from __future__ import annotations import re +import zlib +from dataclasses import dataclass from pathlib import Path from typing import cast from uuid import UUID @@ -37,11 +39,17 @@ from alicebot_api.workspaces import TaskWorkspaceNotFoundError SUPPORTED_TEXT_ARTIFACT_MEDIA_TYPES = ("text/plain", "text/markdown") -SUPPORTED_TEXT_ARTIFACT_EXTENSIONS = { +SUPPORTED_PDF_ARTIFACT_MEDIA_TYPE = "application/pdf" +SUPPORTED_ARTIFACT_MEDIA_TYPES = ( + *SUPPORTED_TEXT_ARTIFACT_MEDIA_TYPES, + SUPPORTED_PDF_ARTIFACT_MEDIA_TYPE, +) +SUPPORTED_ARTIFACT_EXTENSIONS = { ".txt": "text/plain", ".text": "text/plain", ".md": "text/markdown", ".markdown": "text/markdown", + ".pdf": SUPPORTED_PDF_ARTIFACT_MEDIA_TYPE, } TASK_ARTIFACT_CHUNK_MAX_CHARS = 1000 TASK_ARTIFACT_CHUNKING_RULE = "normalized_utf8_text_fixed_window_1000_chars_v1" @@ -49,6 +57,90 @@ "casefolded_unicode_word_overlap_unique_query_terms_v1" ) _LEXICAL_TERM_PATTERN = re.compile(r"\w+") +_PDF_INDIRECT_OBJECT_PATTERN = re.compile(rb"(?s)(\d+)\s+(\d+)\s+obj\b(.*?)\bendobj\b") +_PDF_REFERENCE_PATTERN = re.compile(rb"(\d+)\s+(\d+)\s+R") +_PDF_NUMERIC_TOKEN_PATTERN = re.compile(rb"[+-]?(?:\d+(?:\.\d+)?|\.\d+)") +_PDF_CONTENT_OPERATORS = { + b'"', + b"'", + b"*", + b"B", + b"BT", + b"BX", + b"B*", + b"BI", + b"BMC", + b"BDC", + b"b", + b"b*", + b"cm", + b"CS", + b"cs", + b"Do", + b"DP", + b"EI", + b"EMC", + b"ET", + b"EX", + b"f", + b"F", + b"f*", + b"G", + b"g", + b"gs", + b"h", + b"i", + b"ID", + b"j", + b"J", + b"K", + b"k", + b"l", + b"M", + b"m", + b"MP", + b"n", + b"q", + b"Q", + b"re", + b"RG", + b"rg", + b"ri", + b"s", + b"S", + b"SC", + b"sc", + b"SCN", + b"scn", + b"sh", + b"T*", + b"Tc", + b"Td", + b"TD", + b"Tf", + b"TJ", + b"Tj", + b"TL", + b"Tm", + b"Tr", + b"Ts", + b"Tw", + b"Tz", + b"v", + b"w", + b"W", + b"W*", + b"y", +} + + +@dataclass(frozen=True, slots=True) +class _PdfObject: + object_id: int + generation: int + dictionary: bytes + stream: bytes | None + raw_content: bytes class TaskArtifactNotFoundError(LookupError): @@ -136,15 +228,15 @@ def infer_task_artifact_media_type(row: TaskArtifactRow) -> str | None: return row["media_type_hint"] artifact_path = Path(row["relative_path"]) - return SUPPORTED_TEXT_ARTIFACT_EXTENSIONS.get(artifact_path.suffix.lower()) + return SUPPORTED_ARTIFACT_EXTENSIONS.get(artifact_path.suffix.lower()) def resolve_supported_task_artifact_media_type(row: TaskArtifactRow) -> str: media_type = infer_task_artifact_media_type(row) - if media_type in SUPPORTED_TEXT_ARTIFACT_MEDIA_TYPES: + if media_type in SUPPORTED_ARTIFACT_MEDIA_TYPES: return cast(str, media_type) - supported_types = ", ".join(SUPPORTED_TEXT_ARTIFACT_MEDIA_TYPES) + supported_types = ", ".join(SUPPORTED_ARTIFACT_MEDIA_TYPES) raise TaskArtifactValidationError( 
f"artifact {row['relative_path']} has unsupported media type " f"{media_type or 'unknown'}; supported types: {supported_types}" @@ -167,6 +259,494 @@ def chunk_normalized_artifact_text( return chunks +def _extract_text_from_utf8_artifact_bytes(*, relative_path: str, payload: bytes) -> str: + try: + return payload.decode("utf-8") + except UnicodeDecodeError as exc: + raise TaskArtifactValidationError( + f"artifact {relative_path} is not valid UTF-8 text" + ) from exc + + +def _extract_pdf_name(dictionary: bytes, key: bytes) -> bytes | None: + match = re.search(rb"/" + re.escape(key) + rb"\s*/([A-Za-z0-9_.#-]+)", dictionary) + if match is None: + return None + return match.group(1) + + +def _extract_pdf_reference(dictionary: bytes, key: bytes) -> tuple[int, int] | None: + match = re.search(rb"/" + re.escape(key) + rb"\s+(\d+)\s+(\d+)\s+R", dictionary) + if match is None: + return None + return int(match.group(1)), int(match.group(2)) + + +def _extract_pdf_reference_array(dictionary: bytes, key: bytes) -> list[tuple[int, int]]: + match = re.search(rb"/" + re.escape(key) + rb"\s*\[(.*?)\]", dictionary, re.DOTALL) + if match is None: + return [] + return [ + (int(ref_match.group(1)), int(ref_match.group(2))) + for ref_match in _PDF_REFERENCE_PATTERN.finditer(match.group(1)) + ] + + +def _extract_pdf_filter_names(dictionary: bytes) -> list[bytes]: + array_match = re.search(rb"/Filter\s*\[(.*?)\]", dictionary, re.DOTALL) + if array_match is not None: + return re.findall(rb"/([A-Za-z0-9_.#-]+)", array_match.group(1)) + + filter_name = _extract_pdf_name(dictionary, b"Filter") + if filter_name is None: + return [] + return [filter_name] + + +def _extract_pdf_stream_payload( + *, + relative_path: str, + dictionary: bytes, + body: bytes, + stream_start: int, +) -> bytes: + length_match = re.search(rb"/Length\s+(\d+)", dictionary) + if length_match is not None: + stream_length = int(length_match.group(1)) + stream_end = stream_start + stream_length + if stream_end <= len(body): + return body[stream_start:stream_end] + + stream_end = body.rfind(b"endstream") + if stream_end == -1 or stream_end < stream_start: + raise TaskArtifactValidationError( + f"artifact {relative_path} contains an unreadable PDF stream" + ) + + payload = body[stream_start:stream_end] + if payload.endswith(b"\r\n"): + return payload[:-2] + if payload.endswith((b"\n", b"\r")): + return payload[:-1] + return payload + + +def _parse_pdf_objects(*, relative_path: str, payload: bytes) -> dict[tuple[int, int], _PdfObject]: + objects: dict[tuple[int, int], _PdfObject] = {} + for match in _PDF_INDIRECT_OBJECT_PATTERN.finditer(payload): + object_id = int(match.group(1)) + generation = int(match.group(2)) + body = match.group(3).strip() + dictionary = body + stream: bytes | None = None + stream_index = body.find(b"stream") + if stream_index != -1: + dictionary = body[:stream_index].rstrip() + stream_start = stream_index + len(b"stream") + if body[stream_start : stream_start + 2] == b"\r\n": + stream_start += 2 + elif body[stream_start : stream_start + 1] in (b"\r", b"\n"): + stream_start += 1 + stream = _extract_pdf_stream_payload( + relative_path=relative_path, + dictionary=dictionary, + body=body, + stream_start=stream_start, + ) + objects[(object_id, generation)] = _PdfObject( + object_id=object_id, + generation=generation, + dictionary=dictionary, + stream=stream, + raw_content=body, + ) + + if not objects: + raise TaskArtifactValidationError(f"artifact {relative_path} is not a valid PDF") + return objects + + +def 
_read_pdf_literal_string(payload: bytes, start: int) -> tuple[bytes, int]: + cursor = start + 1 + depth = 1 + result = bytearray() + while cursor < len(payload): + current = payload[cursor] + if current == ord("\\"): + cursor += 1 + if cursor >= len(payload): + break + escaped = payload[cursor] + if escaped in b"nrtbf()\\": + result.extend( + { + ord("n"): b"\n", + ord("r"): b"\r", + ord("t"): b"\t", + ord("b"): b"\b", + ord("f"): b"\f", + ord("("): b"(", + ord(")"): b")", + ord("\\"): b"\\", + }[escaped] + ) + cursor += 1 + continue + if escaped in b"\r\n": + if escaped == ord("\r") and payload[cursor : cursor + 2] == b"\r\n": + cursor += 2 + else: + cursor += 1 + continue + if chr(escaped).isdigit(): + octal_digits = bytes([escaped]) + cursor += 1 + while cursor < len(payload) and len(octal_digits) < 3 and chr(payload[cursor]).isdigit(): + octal_digits += bytes([payload[cursor]]) + cursor += 1 + result.append(int(octal_digits, 8)) + continue + result.append(escaped) + cursor += 1 + continue + if current == ord("("): + depth += 1 + result.append(current) + cursor += 1 + continue + if current == ord(")"): + depth -= 1 + cursor += 1 + if depth == 0: + return bytes(result), cursor + result.append(current) + continue + result.append(current) + cursor += 1 + + raise TaskArtifactValidationError("PDF literal string terminated unexpectedly") + + +def _read_pdf_hex_string(payload: bytes, start: int) -> tuple[bytes, int]: + cursor = start + 1 + hex_digits = bytearray() + while cursor < len(payload): + current = payload[cursor] + if current == ord(">"): + cursor += 1 + break + if chr(current).isspace(): + cursor += 1 + continue + hex_digits.append(current) + cursor += 1 + + if len(hex_digits) % 2 == 1: + hex_digits.append(ord("0")) + return bytes.fromhex(hex_digits.decode("ascii")), cursor + + +def _skip_pdf_whitespace_and_comments(payload: bytes, start: int) -> int: + cursor = start + while cursor < len(payload): + current = payload[cursor] + if chr(current).isspace(): + cursor += 1 + continue + if current == ord("%"): + while cursor < len(payload) and payload[cursor] not in b"\r\n": + cursor += 1 + continue + break + return cursor + + +def _read_pdf_content_token(payload: bytes, start: int) -> tuple[object | None, int]: + cursor = _skip_pdf_whitespace_and_comments(payload, start) + if cursor >= len(payload): + return None, cursor + + current = payload[cursor] + if current == ord("("): + return _read_pdf_literal_string(payload, cursor) + if current == ord("<") and payload[cursor : cursor + 2] != b"<<": + return _read_pdf_hex_string(payload, cursor) + if current == ord("["): + items: list[object] = [] + cursor += 1 + while True: + cursor = _skip_pdf_whitespace_and_comments(payload, cursor) + if cursor >= len(payload): + raise TaskArtifactValidationError("PDF array terminated unexpectedly") + if payload[cursor] == ord("]"): + return items, cursor + 1 + item, cursor = _read_pdf_content_token(payload, cursor) + if item is None: + raise TaskArtifactValidationError("PDF array terminated unexpectedly") + items.append(item) + if current == ord("/"): + cursor += 1 + token_start = cursor + while cursor < len(payload) and not chr(payload[cursor]).isspace() and payload[cursor] not in b"()<>[]{}/%": + cursor += 1 + return payload[token_start - 1 : cursor], cursor + + token_start = cursor + while cursor < len(payload) and not chr(payload[cursor]).isspace() and payload[cursor] not in b"()<>[]{}/%": + cursor += 1 + return payload[token_start:cursor], cursor + + +def _decode_pdf_text_bytes(raw: bytes) -> str: + if 
raw.startswith(b"\xfe\xff"): + return raw[2:].decode("utf-16-be", errors="ignore") + if raw.startswith(b"\xff\xfe"): + return raw[2:].decode("utf-16-le", errors="ignore") + return raw.decode("latin-1", errors="ignore") + + +def _decode_pdf_text_operand(value: object | None) -> str: + if isinstance(value, bytes): + return _decode_pdf_text_bytes(value) + if isinstance(value, list): + return "".join( + _decode_pdf_text_bytes(item) for item in value if isinstance(item, bytes) + ) + return "" + + +def _pop_last_pdf_text_operand(operands: list[object]) -> object | None: + for index in range(len(operands) - 1, -1, -1): + candidate = operands[index] + if isinstance(candidate, (bytes, list)): + return operands.pop(index) + return None + + +def _extract_text_from_pdf_content_stream(stream: bytes) -> str: + operands: list[object] = [] + fragments: list[str] = [] + inside_text_block = False + pending_newline = False + cursor = 0 + + def request_newline() -> None: + nonlocal pending_newline + if fragments: + pending_newline = True + + def append_text(text: str) -> None: + nonlocal pending_newline + if text == "": + return + if pending_newline and fragments and fragments[-1] != "\n": + fragments.append("\n") + pending_newline = False + fragments.append(text) + + while True: + token, cursor = _read_pdf_content_token(stream, cursor) + if token is None: + break + if isinstance(token, list) or ( + isinstance(token, bytes) + and ( + token.startswith(b"/") + or _PDF_NUMERIC_TOKEN_PATTERN.fullmatch(token) is not None + or token in {b"true", b"false", b"null"} + ) + ): + operands.append(token) + continue + if not isinstance(token, bytes): + operands.append(token) + continue + + operator = token + if operator == b"BT": + inside_text_block = True + operands.clear() + continue + if operator == b"ET": + inside_text_block = False + operands.clear() + continue + if operator not in _PDF_CONTENT_OPERATORS: + operands.append(token) + continue + if not inside_text_block: + operands.clear() + continue + if operator in {b"T*", b"Td", b"TD", b"Tm"}: + request_newline() + operands.clear() + continue + if operator in {b"Tj", b"TJ"}: + append_text(_decode_pdf_text_operand(_pop_last_pdf_text_operand(operands))) + operands.clear() + continue + if operator in {b"'", b'"'}: + request_newline() + append_text(_decode_pdf_text_operand(_pop_last_pdf_text_operand(operands))) + operands.clear() + continue + operands.clear() + + return "".join(fragments).strip() + + +def _decode_pdf_stream(*, relative_path: str, pdf_object: _PdfObject) -> bytes: + if pdf_object.stream is None: + raise TaskArtifactValidationError( + f"artifact {relative_path} contains a PDF content reference without a stream" + ) + + filters = _extract_pdf_filter_names(pdf_object.dictionary) + if not filters: + return pdf_object.stream + if filters == [b"FlateDecode"]: + try: + return zlib.decompress(pdf_object.stream) + except zlib.error as exc: + raise TaskArtifactValidationError( + f"artifact {relative_path} contains an unreadable FlateDecode PDF stream" + ) from exc + + filter_names = ", ".join(f"/{name.decode('ascii', errors='ignore')}" for name in filters) + raise TaskArtifactValidationError( + f"artifact {relative_path} uses unsupported PDF stream filters {filter_names}" + ) + + +def _collect_pdf_page_refs( + *, + relative_path: str, + objects: dict[tuple[int, int], _PdfObject], + current_ref: tuple[int, int], + collected_refs: list[tuple[int, int]], + visited_refs: set[tuple[int, int]], +) -> None: + if current_ref in visited_refs: + return + 
visited_refs.add(current_ref) + current_object = objects.get(current_ref) + if current_object is None: + raise TaskArtifactValidationError( + f"artifact {relative_path} references a missing PDF object {current_ref[0]} {current_ref[1]} R" + ) + + object_type = _extract_pdf_name(current_object.dictionary, b"Type") + if object_type == b"Page": + collected_refs.append(current_ref) + return + if object_type != b"Pages": + raise TaskArtifactValidationError( + f"artifact {relative_path} uses unsupported PDF page tree structure" + ) + + child_refs = _extract_pdf_reference_array(current_object.dictionary, b"Kids") + if not child_refs: + raise TaskArtifactValidationError( + f"artifact {relative_path} uses unsupported PDF page tree structure" + ) + for child_ref in child_refs: + _collect_pdf_page_refs( + relative_path=relative_path, + objects=objects, + current_ref=child_ref, + collected_refs=collected_refs, + visited_refs=visited_refs, + ) + + +def _resolve_pdf_page_refs( + *, + relative_path: str, + objects: dict[tuple[int, int], _PdfObject], +) -> list[tuple[int, int]]: + catalog_ref = next( + ( + object_ref + for object_ref, pdf_object in objects.items() + if _extract_pdf_name(pdf_object.dictionary, b"Type") == b"Catalog" + ), + None, + ) + if catalog_ref is None: + raise TaskArtifactValidationError(f"artifact {relative_path} is not a valid PDF") + + pages_ref = _extract_pdf_reference(objects[catalog_ref].dictionary, b"Pages") + if pages_ref is None: + raise TaskArtifactValidationError( + f"artifact {relative_path} uses unsupported PDF page tree structure" + ) + + page_refs: list[tuple[int, int]] = [] + _collect_pdf_page_refs( + relative_path=relative_path, + objects=objects, + current_ref=pages_ref, + collected_refs=page_refs, + visited_refs=set(), + ) + return page_refs + + +def _extract_text_from_pdf_artifact_bytes(*, relative_path: str, payload: bytes) -> str: + if not payload.startswith(b"%PDF-"): + raise TaskArtifactValidationError(f"artifact {relative_path} is not a valid PDF") + + objects = _parse_pdf_objects(relative_path=relative_path, payload=payload) + page_refs = _resolve_pdf_page_refs(relative_path=relative_path, objects=objects) + page_fragments: list[str] = [] + for page_ref in page_refs: + page_object = objects[page_ref] + content_refs = _extract_pdf_reference_array(page_object.dictionary, b"Contents") + if not content_refs: + single_content_ref = _extract_pdf_reference(page_object.dictionary, b"Contents") + if single_content_ref is not None: + content_refs = [single_content_ref] + + stream_fragments: list[str] = [] + for content_ref in content_refs: + content_object = objects.get(content_ref) + if content_object is None: + raise TaskArtifactValidationError( + f"artifact {relative_path} references a missing PDF object {content_ref[0]} {content_ref[1]} R" + ) + extracted = _extract_text_from_pdf_content_stream( + _decode_pdf_stream(relative_path=relative_path, pdf_object=content_object) + ) + if extracted != "": + stream_fragments.append(extracted) + if stream_fragments: + page_fragments.append("\n".join(stream_fragments)) + + extracted_text = "\n".join(page_fragments).strip() + if extracted_text == "": + raise TaskArtifactValidationError( + f"artifact {relative_path} does not contain extractable PDF text" + ) + return extracted_text + + +def extract_artifact_text(*, row: TaskArtifactRow, artifact_path: Path, media_type: str) -> str: + payload = artifact_path.read_bytes() + if media_type in SUPPORTED_TEXT_ARTIFACT_MEDIA_TYPES: + return _extract_text_from_utf8_artifact_bytes( + 
relative_path=row["relative_path"], + payload=payload, + ) + if media_type == SUPPORTED_PDF_ARTIFACT_MEDIA_TYPE: + return _extract_text_from_pdf_artifact_bytes( + relative_path=row["relative_path"], + payload=payload, + ) + raise TaskArtifactValidationError( + f"artifact {row['relative_path']} has unsupported media type {media_type}" + ) + + def resolve_registered_artifact_path(*, workspace_path: Path, relative_path: str) -> Path: artifact_path = (workspace_path / relative_path).resolve() ensure_artifact_path_is_rooted( @@ -462,14 +1042,11 @@ def ingest_task_artifact_record( relative_path=row["relative_path"], ) _require_existing_file(artifact_path) - - try: - text = artifact_path.read_bytes().decode("utf-8") - except UnicodeDecodeError as exc: - raise TaskArtifactValidationError( - f"artifact {row['relative_path']} is not valid UTF-8 text" - ) from exc - + text = extract_artifact_text( + row=row, + artifact_path=artifact_path, + media_type=media_type, + ) normalized_text = normalize_artifact_text(text) for index, (char_start, char_end_exclusive, chunk_text) in enumerate( chunk_normalized_artifact_text(normalized_text), diff --git a/tests/integration/test_task_artifacts_api.py b/tests/integration/test_task_artifacts_api.py index 1aab4e1..d23c5c7 100644 --- a/tests/integration/test_task_artifacts_api.py +++ b/tests/integration/test_task_artifacts_api.py @@ -1,6 +1,7 @@ from __future__ import annotations import json +import zlib from pathlib import Path from typing import Any from urllib.parse import urlencode @@ -16,6 +17,88 @@ from alicebot_api.store import ContinuityStore +def _escape_pdf_literal_string(value: str) -> str: + return value.replace("\\", "\\\\").replace("(", "\\(").replace(")", "\\)") + + +def _build_pdf_bytes( + pages: list[list[str]], + *, + compress_streams: bool = True, + textless: bool = False, +) -> bytes: + objects: dict[int, bytes] = { + 1: b"<< /Type /Catalog /Pages 2 0 R >>", + 3: b"<< /Type /Font /Subtype /Type1 /BaseFont /Helvetica >>", + } + page_refs: list[str] = [] + next_object_id = 4 + for page_lines in pages: + page_object_id = next_object_id + content_object_id = next_object_id + 1 + next_object_id += 2 + page_refs.append(f"{page_object_id} 0 R") + + if textless: + content_stream = b"q 10 10 100 100 re S Q\n" + else: + commands = [b"BT", b"/F1 12 Tf", b"72 720 Td"] + for index, line in enumerate(page_lines): + if index > 0: + commands.append(b"T*") + commands.append(f"({_escape_pdf_literal_string(line)}) Tj".encode("latin-1")) + commands.append(b"ET") + content_stream = b"\n".join(commands) + b"\n" + + if compress_streams: + encoded_stream = zlib.compress(content_stream) + content_body = ( + f"<< /Length {len(encoded_stream)} /Filter /FlateDecode >>\n".encode("ascii") + + b"stream\n" + + encoded_stream + + b"\nendstream" + ) + else: + content_body = ( + f"<< /Length {len(content_stream)} >>\n".encode("ascii") + + b"stream\n" + + content_stream + + b"endstream" + ) + + objects[page_object_id] = ( + f"<< /Type /Page /Parent 2 0 R /Resources << /Font << /F1 3 0 R >> >> " + f"/MediaBox [0 0 612 792] /Contents {content_object_id} 0 R >>" + ).encode("ascii") + objects[content_object_id] = content_body + + objects[2] = ( + f"<< /Type /Pages /Count {len(page_refs)} /Kids [{' '.join(page_refs)}] >>" + ).encode("ascii") + + document = bytearray(b"%PDF-1.4\n%\xe2\xe3\xcf\xd3\n") + max_object_id = max(objects) + offsets = [0] * (max_object_id + 1) + for object_id in range(1, max_object_id + 1): + offsets[object_id] = len(document) + document.extend(f"{object_id} 0 
obj\n".encode("ascii")) + document.extend(objects[object_id]) + document.extend(b"\nendobj\n") + + xref_offset = len(document) + document.extend(f"xref\n0 {max_object_id + 1}\n".encode("ascii")) + document.extend(b"0000000000 65535 f \n") + for object_id in range(1, max_object_id + 1): + document.extend(f"{offsets[object_id]:010d} 00000 n \n".encode("ascii")) + document.extend( + ( + f"trailer\n<< /Size {max_object_id + 1} /Root 1 0 R >>\n" + f"startxref\n{xref_offset}\n%%EOF\n" + ).encode("ascii") + ) + return bytes(document) + + def invoke_request( method: str, path: str, @@ -323,8 +406,8 @@ def test_task_artifact_ingestion_and_chunk_endpoints_are_deterministic_and_isola supported_file = workspace_path / "docs" / "spec.txt" supported_file.parent.mkdir(parents=True) supported_file.write_text(("A" * 998) + "\r\n" + ("B" * 5) + "\rC") - unsupported_file = workspace_path / "docs" / "manual.pdf" - unsupported_file.write_text("not really a pdf") + unsupported_file = workspace_path / "docs" / "manual.bin" + unsupported_file.write_bytes(b"\x00\x01\x02") register_status, register_payload = invoke_request( "POST", @@ -364,7 +447,7 @@ def test_task_artifact_ingestion_and_chunk_endpoints_are_deterministic_and_isola payload={ "user_id": str(owner["user_id"]), "local_path": str(unsupported_file), - "media_type_hint": "application/pdf", + "media_type_hint": "application/octet-stream", }, ) assert unsupported_register_status == 201 @@ -442,12 +525,139 @@ def test_task_artifact_ingestion_and_chunk_endpoints_are_deterministic_and_isola assert unsupported_ingest_status == 400 assert unsupported_ingest_payload == { "detail": ( - "artifact docs/manual.pdf has unsupported media type application/pdf; " - "supported types: text/plain, text/markdown" + "artifact docs/manual.bin has unsupported media type application/octet-stream; " + "supported types: text/plain, text/markdown, application/pdf" ) } +def test_task_artifact_pdf_ingestion_and_chunk_endpoints_are_deterministic_and_isolated( + migrated_database_urls, + monkeypatch, + tmp_path, +) -> None: + owner = seed_task(migrated_database_urls["app"], email="owner@example.com") + intruder = seed_task(migrated_database_urls["app"], email="intruder@example.com") + workspace_root = tmp_path / "task-workspaces" + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings( + database_url=migrated_database_urls["app"], + task_workspace_root=str(workspace_root), + ), + ) + + workspace_status, workspace_payload = invoke_request( + "POST", + f"/v0/tasks/{owner['task_id']}/workspace", + payload={"user_id": str(owner["user_id"])}, + ) + assert workspace_status == 201 + + workspace_path = Path(workspace_payload["workspace"]["local_path"]) + pdf_file = workspace_path / "docs" / "spec.pdf" + pdf_file.parent.mkdir(parents=True) + pdf_file.write_bytes(_build_pdf_bytes([["A" * 998, "B" * 5, "C"]])) + + register_status, register_payload = invoke_request( + "POST", + f"/v0/task-workspaces/{workspace_payload['workspace']['id']}/artifacts", + payload={ + "user_id": str(owner["user_id"]), + "local_path": str(pdf_file), + "media_type_hint": "application/pdf", + }, + ) + assert register_status == 201 + + ingest_status, ingest_payload = invoke_request( + "POST", + f"/v0/task-artifacts/{register_payload['artifact']['id']}/ingest", + payload={"user_id": str(owner["user_id"])}, + ) + chunk_list_status, chunk_list_payload = invoke_request( + "GET", + f"/v0/task-artifacts/{register_payload['artifact']['id']}/chunks", + query_params={"user_id": str(owner["user_id"])}, + ) + 
isolated_chunk_list_status, isolated_chunk_list_payload = invoke_request( + "GET", + f"/v0/task-artifacts/{register_payload['artifact']['id']}/chunks", + query_params={"user_id": str(intruder["user_id"])}, + ) + isolated_ingest_status, isolated_ingest_payload = invoke_request( + "POST", + f"/v0/task-artifacts/{register_payload['artifact']['id']}/ingest", + payload={"user_id": str(intruder["user_id"])}, + ) + + assert ingest_status == 200 + assert ingest_payload == { + "artifact": { + "id": register_payload["artifact"]["id"], + "task_id": str(owner["task_id"]), + "task_workspace_id": workspace_payload["workspace"]["id"], + "status": "registered", + "ingestion_status": "ingested", + "relative_path": "docs/spec.pdf", + "media_type_hint": "application/pdf", + "created_at": register_payload["artifact"]["created_at"], + "updated_at": ingest_payload["artifact"]["updated_at"], + }, + "summary": { + "total_count": 2, + "total_characters": 1006, + "media_type": "application/pdf", + "chunking_rule": "normalized_utf8_text_fixed_window_1000_chars_v1", + "order": ["sequence_no_asc", "id_asc"], + }, + } + + assert chunk_list_status == 200 + assert chunk_list_payload == { + "items": [ + { + "id": chunk_list_payload["items"][0]["id"], + "task_artifact_id": register_payload["artifact"]["id"], + "sequence_no": 1, + "char_start": 0, + "char_end_exclusive": 1000, + "text": ("A" * 998) + "\n" + "B", + "created_at": chunk_list_payload["items"][0]["created_at"], + "updated_at": chunk_list_payload["items"][0]["updated_at"], + }, + { + "id": chunk_list_payload["items"][1]["id"], + "task_artifact_id": register_payload["artifact"]["id"], + "sequence_no": 2, + "char_start": 1000, + "char_end_exclusive": 1006, + "text": "BBBB\nC", + "created_at": chunk_list_payload["items"][1]["created_at"], + "updated_at": chunk_list_payload["items"][1]["updated_at"], + }, + ], + "summary": { + "total_count": 2, + "total_characters": 1006, + "media_type": "application/pdf", + "chunking_rule": "normalized_utf8_text_fixed_window_1000_chars_v1", + "order": ["sequence_no_asc", "id_asc"], + }, + } + + assert isolated_chunk_list_status == 404 + assert isolated_chunk_list_payload == { + "detail": f"task artifact {register_payload['artifact']['id']} was not found" + } + + assert isolated_ingest_status == 404 + assert isolated_ingest_payload == { + "detail": f"task artifact {register_payload['artifact']['id']} was not found" + } + + def test_task_artifact_ingestion_supports_markdown_and_reingest_is_idempotent( migrated_database_urls, monkeypatch, @@ -601,6 +811,57 @@ def test_task_artifact_ingestion_rejects_invalid_utf8_content( } +def test_task_artifact_ingestion_rejects_textless_pdf_content( + migrated_database_urls, + monkeypatch, + tmp_path, +) -> None: + owner = seed_task(migrated_database_urls["app"], email="owner@example.com") + workspace_root = tmp_path / "task-workspaces" + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings( + database_url=migrated_database_urls["app"], + task_workspace_root=str(workspace_root), + ), + ) + + workspace_status, workspace_payload = invoke_request( + "POST", + f"/v0/tasks/{owner['task_id']}/workspace", + payload={"user_id": str(owner["user_id"])}, + ) + assert workspace_status == 201 + + workspace_path = Path(workspace_payload["workspace"]["local_path"]) + textless_pdf = workspace_path / "docs" / "scanned.pdf" + textless_pdf.parent.mkdir(parents=True) + textless_pdf.write_bytes(_build_pdf_bytes([[]], textless=True)) + + register_status, register_payload = invoke_request( + "POST", + 
f"/v0/task-workspaces/{workspace_payload['workspace']['id']}/artifacts", + payload={ + "user_id": str(owner["user_id"]), + "local_path": str(textless_pdf), + "media_type_hint": "application/pdf", + }, + ) + assert register_status == 201 + + ingest_status, ingest_payload = invoke_request( + "POST", + f"/v0/task-artifacts/{register_payload['artifact']['id']}/ingest", + payload={"user_id": str(owner["user_id"])}, + ) + + assert ingest_status == 400 + assert ingest_payload == { + "detail": "artifact docs/scanned.pdf does not contain extractable PDF text" + } + + def test_task_artifact_ingestion_enforces_rooted_workspace_paths( migrated_database_urls, monkeypatch, @@ -625,11 +886,11 @@ def test_task_artifact_ingestion_enforces_rooted_workspace_paths( assert workspace_status == 201 workspace_path = Path(workspace_payload["workspace"]["local_path"]) - safe_file = workspace_path / "docs" / "spec.txt" + safe_file = workspace_path / "docs" / "spec.pdf" safe_file.parent.mkdir(parents=True) - safe_file.write_text("spec") - outside_file = tmp_path / "escape.txt" - outside_file.write_text("escape") + safe_file.write_bytes(_build_pdf_bytes([["spec"]])) + outside_file = tmp_path / "escape.pdf" + outside_file.write_bytes(_build_pdf_bytes([["escape"]])) register_status, register_payload = invoke_request( "POST", @@ -637,7 +898,7 @@ def test_task_artifact_ingestion_enforces_rooted_workspace_paths( payload={ "user_id": str(owner["user_id"]), "local_path": str(safe_file), - "media_type_hint": "text/plain", + "media_type_hint": "application/pdf", }, ) assert register_status == 201 @@ -647,7 +908,7 @@ def test_task_artifact_ingestion_enforces_rooted_workspace_paths( cur.execute( """ UPDATE task_artifacts - SET relative_path = '../../../escape.txt' + SET relative_path = '../../../escape.pdf' WHERE id = %s """, (register_payload["artifact"]["id"],), diff --git a/tests/unit/test_artifacts.py b/tests/unit/test_artifacts.py index 07dc3de..35da2b0 100644 --- a/tests/unit/test_artifacts.py +++ b/tests/unit/test_artifacts.py @@ -1,5 +1,6 @@ from __future__ import annotations +import zlib from datetime import UTC, datetime, timedelta from pathlib import Path from uuid import UUID, uuid4 @@ -38,6 +39,88 @@ from alicebot_api.workspaces import TaskWorkspaceNotFoundError +def _escape_pdf_literal_string(value: str) -> str: + return value.replace("\\", "\\\\").replace("(", "\\(").replace(")", "\\)") + + +def _build_pdf_bytes( + pages: list[list[str]], + *, + compress_streams: bool = True, + textless: bool = False, +) -> bytes: + objects: dict[int, bytes] = { + 1: b"<< /Type /Catalog /Pages 2 0 R >>", + 3: b"<< /Type /Font /Subtype /Type1 /BaseFont /Helvetica >>", + } + page_refs: list[str] = [] + next_object_id = 4 + for page_lines in pages: + page_object_id = next_object_id + content_object_id = next_object_id + 1 + next_object_id += 2 + page_refs.append(f"{page_object_id} 0 R") + + if textless: + content_stream = b"q 10 10 100 100 re S Q\n" + else: + commands = [b"BT", b"/F1 12 Tf", b"72 720 Td"] + for index, line in enumerate(page_lines): + if index > 0: + commands.append(b"T*") + commands.append(f"({_escape_pdf_literal_string(line)}) Tj".encode("latin-1")) + commands.append(b"ET") + content_stream = b"\n".join(commands) + b"\n" + + if compress_streams: + encoded_stream = zlib.compress(content_stream) + content_body = ( + f"<< /Length {len(encoded_stream)} /Filter /FlateDecode >>\n".encode("ascii") + + b"stream\n" + + encoded_stream + + b"\nendstream" + ) + else: + content_body = ( + f"<< /Length {len(content_stream)} 
>>\n".encode("ascii") + + b"stream\n" + + content_stream + + b"endstream" + ) + + objects[page_object_id] = ( + f"<< /Type /Page /Parent 2 0 R /Resources << /Font << /F1 3 0 R >> >> " + f"/MediaBox [0 0 612 792] /Contents {content_object_id} 0 R >>" + ).encode("ascii") + objects[content_object_id] = content_body + + objects[2] = ( + f"<< /Type /Pages /Count {len(page_refs)} /Kids [{' '.join(page_refs)}] >>" + ).encode("ascii") + + document = bytearray(b"%PDF-1.4\n%\xe2\xe3\xcf\xd3\n") + max_object_id = max(objects) + offsets = [0] * (max_object_id + 1) + for object_id in range(1, max_object_id + 1): + offsets[object_id] = len(document) + document.extend(f"{object_id} 0 obj\n".encode("ascii")) + document.extend(objects[object_id]) + document.extend(b"\nendobj\n") + + xref_offset = len(document) + document.extend(f"xref\n0 {max_object_id + 1}\n".encode("ascii")) + document.extend(b"0000000000 65535 f \n") + for object_id in range(1, max_object_id + 1): + document.extend(f"{offsets[object_id]:010d} 00000 n \n".encode("ascii")) + document.extend( + ( + f"trailer\n<< /Size {max_object_id + 1} /Root 1 0 R >>\n" + f"startxref\n{xref_offset}\n%%EOF\n" + ).encode("ascii") + ) + return bytes(document) + + class ArtifactStoreStub: def __init__(self) -> None: self.base_time = datetime(2026, 3, 13, 10, 0, tzinfo=UTC) @@ -477,6 +560,84 @@ def test_ingest_task_artifact_record_supports_markdown(tmp_path) -> None: ] +def test_ingest_task_artifact_record_persists_deterministic_pdf_chunks(tmp_path) -> None: + store = ArtifactStoreStub() + user_id = uuid4() + task_id = uuid4() + task_workspace_id = uuid4() + workspace_path = tmp_path / "workspaces" / str(user_id) / str(task_id) + workspace_path.mkdir(parents=True) + artifact_path = workspace_path / "docs" / "spec.pdf" + artifact_path.parent.mkdir(parents=True) + artifact_path.write_bytes(_build_pdf_bytes([["A" * 998, "B" * 5, "C"]])) + store.create_task_workspace( + task_workspace_id=task_workspace_id, + task_id=task_id, + user_id=user_id, + local_path=str(workspace_path), + ) + artifact = store.create_task_artifact( + task_id=task_id, + task_workspace_id=task_workspace_id, + status="registered", + ingestion_status="pending", + relative_path="docs/spec.pdf", + media_type_hint="application/pdf", + ) + + response = ingest_task_artifact_record( + store, + user_id=user_id, + request=TaskArtifactIngestInput(task_artifact_id=artifact["id"]), + ) + + assert response == { + "artifact": { + "id": str(artifact["id"]), + "task_id": str(task_id), + "task_workspace_id": str(task_workspace_id), + "status": "registered", + "ingestion_status": "ingested", + "relative_path": "docs/spec.pdf", + "media_type_hint": "application/pdf", + "created_at": "2026-03-13T10:00:00+00:00", + "updated_at": "2026-03-13T10:30:00+00:00", + }, + "summary": { + "total_count": 2, + "total_characters": 1006, + "media_type": "application/pdf", + "chunking_rule": TASK_ARTIFACT_CHUNKING_RULE, + "order": ["sequence_no_asc", "id_asc"], + }, + } + assert store.locked_artifact_ids == [artifact["id"]] + assert store.list_task_artifact_chunks(artifact["id"]) == [ + { + "id": store.artifact_chunks[0]["id"], + "user_id": user_id, + "task_artifact_id": artifact["id"], + "sequence_no": 1, + "char_start": 0, + "char_end_exclusive": 1000, + "text": ("A" * 998) + "\n" + "B", + "created_at": datetime(2026, 3, 13, 10, 0, tzinfo=UTC), + "updated_at": datetime(2026, 3, 13, 10, 0, tzinfo=UTC), + }, + { + "id": store.artifact_chunks[1]["id"], + "user_id": user_id, + "task_artifact_id": artifact["id"], + "sequence_no": 
2, + "char_start": 1000, + "char_end_exclusive": 1006, + "text": "BBBB\nC", + "created_at": datetime(2026, 3, 13, 10, 0, 1, tzinfo=UTC), + "updated_at": datetime(2026, 3, 13, 10, 0, 1, tzinfo=UTC), + }, + ] + + def test_ingest_task_artifact_record_is_idempotent_for_already_ingested_artifact() -> None: store = ArtifactStoreStub() user_id = uuid4() @@ -541,9 +702,9 @@ def test_ingest_task_artifact_record_rejects_unsupported_media_type(tmp_path) -> task_workspace_id = uuid4() workspace_path = tmp_path / "workspaces" / str(user_id) / str(task_id) workspace_path.mkdir(parents=True) - artifact_path = workspace_path / "docs" / "spec.pdf" + artifact_path = workspace_path / "docs" / "spec.bin" artifact_path.parent.mkdir(parents=True) - artifact_path.write_text("not really a pdf") + artifact_path.write_bytes(b"\x00\x01\x02") store.create_task_workspace( task_workspace_id=task_workspace_id, task_id=task_id, @@ -555,13 +716,52 @@ def test_ingest_task_artifact_record_rejects_unsupported_media_type(tmp_path) -> task_workspace_id=task_workspace_id, status="registered", ingestion_status="pending", - relative_path="docs/spec.pdf", + relative_path="docs/spec.bin", + media_type_hint="application/octet-stream", + ) + + with pytest.raises( + TaskArtifactValidationError, + match=( + "artifact docs/spec.bin has unsupported media type application/octet-stream; " + "supported types: text/plain, text/markdown, application/pdf" + ), + ): + ingest_task_artifact_record( + store, + user_id=user_id, + request=TaskArtifactIngestInput(task_artifact_id=artifact["id"]), + ) + + +def test_ingest_task_artifact_record_rejects_textless_pdf(tmp_path) -> None: + store = ArtifactStoreStub() + user_id = uuid4() + task_id = uuid4() + task_workspace_id = uuid4() + workspace_path = tmp_path / "workspaces" / str(user_id) / str(task_id) + workspace_path.mkdir(parents=True) + artifact_path = workspace_path / "docs" / "scanned.pdf" + artifact_path.parent.mkdir(parents=True) + artifact_path.write_bytes(_build_pdf_bytes([[]], textless=True)) + store.create_task_workspace( + task_workspace_id=task_workspace_id, + task_id=task_id, + user_id=user_id, + local_path=str(workspace_path), + ) + artifact = store.create_task_artifact( + task_id=task_id, + task_workspace_id=task_workspace_id, + status="registered", + ingestion_status="pending", + relative_path="docs/scanned.pdf", media_type_hint="application/pdf", ) with pytest.raises( TaskArtifactValidationError, - match="artifact docs/spec.pdf has unsupported media type application/pdf", + match="artifact docs/scanned.pdf does not contain extractable PDF text", ): ingest_task_artifact_record( store, @@ -613,8 +813,8 @@ def test_ingest_task_artifact_record_rejects_paths_outside_workspace(tmp_path) - task_workspace_id = uuid4() workspace_path = tmp_path / "workspaces" / str(user_id) / str(task_id) workspace_path.mkdir(parents=True) - outside_path = tmp_path / "escape.txt" - outside_path.write_text("escape") + outside_path = tmp_path / "escape.pdf" + outside_path.write_bytes(_build_pdf_bytes([["escape"]])) store.create_task_workspace( task_workspace_id=task_workspace_id, task_id=task_id, @@ -626,8 +826,8 @@ def test_ingest_task_artifact_record_rejects_paths_outside_workspace(tmp_path) - task_workspace_id=task_workspace_id, status="registered", ingestion_status="pending", - relative_path="../escape.txt", - media_type_hint="text/plain", + relative_path="../escape.pdf", + media_type_hint="application/pdf", ) with pytest.raises(TaskArtifactValidationError, match="escapes workspace root"): diff --git 
a/tests/unit/test_artifacts_main.py b/tests/unit/test_artifacts_main.py index 634d9b6..f43b78c 100644 --- a/tests/unit/test_artifacts_main.py +++ b/tests/unit/test_artifacts_main.py @@ -497,8 +497,8 @@ def fake_user_connection(*_args, **_kwargs): def fake_ingest_task_artifact_record(*_args, **_kwargs): raise TaskArtifactValidationError( - "artifact docs/spec.txt has unsupported media type application/pdf; " - "supported types: text/plain, text/markdown" + "artifact docs/spec.bin has unsupported media type application/octet-stream; " + "supported types: text/plain, text/markdown, application/pdf" ) monkeypatch.setattr(main_module, "get_settings", lambda: settings) @@ -513,8 +513,8 @@ def fake_ingest_task_artifact_record(*_args, **_kwargs): assert response.status_code == 400 assert json.loads(response.body) == { "detail": ( - "artifact docs/spec.txt has unsupported media type application/pdf; " - "supported types: text/plain, text/markdown" + "artifact docs/spec.bin has unsupported media type application/octet-stream; " + "supported types: text/plain, text/markdown, application/pdf" ) } From 914014d2fac0dfc5de7249afa48370a69d10f08f Mon Sep 17 00:00:00 2001 From: Sam Date: Mon, 16 Mar 2026 12:10:18 +0100 Subject: [PATCH 013/135] Sprint 5M: DOCX artifact parsing v0 (#13) * Sprint 5M: DOCX artifact parsing packet * Sprint 5M: DOCX artifact parsing v0 --------- Co-authored-by: Sami Rusani --- ARCHITECTURE.md | 27 +- apps/api/src/alicebot_api/artifacts.py | 64 ++++ .../src/alicebot_api/semantic_retrieval.py | 1 + tests/integration/test_task_artifacts_api.py | 336 +++++++++++++++++- tests/unit/test_artifacts.py | 254 ++++++++++++- tests/unit/test_artifacts_main.py | 6 +- tests/unit/test_semantic_retrieval.py | 48 +++ 7 files changed, 720 insertions(+), 16 deletions(-) diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md index 168429a..faf17a5 100644 --- a/ARCHITECTURE.md +++ b/ARCHITECTURE.md @@ -2,16 +2,16 @@ ## Current Implemented Slice -AliceBot now implements the accepted repo slice through Sprint 5J. The shipped backend includes: +AliceBot now implements the accepted repo slice through Sprint 5M. 
The shipped backend includes: - foundation continuity storage over `users`, `threads`, `sessions`, and append-only `events` - deterministic tracing and context compilation over durable continuity, memory, entity, and entity-edge records - governed memory admission, explicit-preference extraction, memory review labels, review queue reads, evaluation summary reads, explicit embedding config and memory-embedding storage, direct semantic retrieval, and deterministic hybrid compile-path memory merge - deterministic prompt assembly and one no-tools response path that persists assistant replies as immutable continuity events - user-scoped consents, policies, policy evaluation, tool registry, allowlist evaluation, tool routing, approval request persistence, approval resolution, approved-only proxy execution through the in-process `proxy.echo` handler, durable execution review, and execution-budget lifecycle plus enforcement -- durable `tasks`, `task_steps`, `task_workspaces`, `task_artifacts`, `task_artifact_chunks`, and `task_artifact_chunk_embeddings`, deterministic task-step sequencing, explicit task-step transitions, explicit manual continuation with lineage through `parent_step_id`, `source_approval_id`, and `source_execution_id`, explicit `tool_executions.task_step_id` linkage for execution synchronization, deterministic rooted local task-workspace provisioning, explicit rooted local artifact registration, deterministic local text-artifact ingestion into durable chunk rows, deterministic lexical artifact-chunk retrieval over durable chunk rows, explicit user-scoped artifact-chunk embedding persistence tied to existing embedding configs, explicit task-scoped or artifact-scoped semantic artifact-chunk retrieval over those durable embeddings, and compile-path artifact retrieval that can include lexical results, semantic results, or one deterministic hybrid lexical-plus-semantic merged artifact section with per-chunk source provenance +- durable `tasks`, `task_steps`, `task_workspaces`, `task_artifacts`, `task_artifact_chunks`, and `task_artifact_chunk_embeddings`, deterministic task-step sequencing, explicit task-step transitions, explicit manual continuation with lineage through `parent_step_id`, `source_approval_id`, and `source_execution_id`, explicit `tool_executions.task_step_id` linkage for execution synchronization, deterministic rooted local task-workspace provisioning, explicit rooted local artifact registration, deterministic local plain-text, markdown, narrow PDF text, and narrow DOCX text ingestion into durable chunk rows, deterministic lexical artifact-chunk retrieval over durable chunk rows, explicit user-scoped artifact-chunk embedding persistence tied to existing embedding configs, explicit task-scoped or artifact-scoped semantic artifact-chunk retrieval over those durable embeddings, and compile-path artifact retrieval that can include lexical results, semantic results, or one deterministic hybrid lexical-plus-semantic merged artifact section with per-chunk source provenance -The current multi-step boundary is narrow and explicit. Manual continuation is implemented and review-passed. Approval resolution and proxy execution now both use explicit task-step linkage rather than first-step inference. 
Task workspaces are now implemented only as deterministic rooted local boundaries, and task artifacts are now implemented only as explicit rooted local-file registrations, narrow deterministic text ingestion under those workspaces, lexical retrieval over persisted chunk rows, explicit artifact-chunk embedding storage tied to existing embedding configs, direct semantic retrieval over those durable artifact-chunk embeddings for one visible task or one visible artifact at a time, and compile-path artifact retrieval that deterministically merges lexical and semantic candidates into one artifact section when both are requested for the same scope. Broader runner-style orchestration, automatic multi-step progression, richer document parsing, connectors, reranking beyond the current lexical-first hybrid merge, and new side-effect surfaces are still planned later and must not be described as live behavior. +The current multi-step boundary is narrow and explicit. Manual continuation is implemented and review-passed. Approval resolution and proxy execution now both use explicit task-step linkage rather than first-step inference. Task workspaces are now implemented only as deterministic rooted local boundaries, and task artifacts are now implemented only as explicit rooted local-file registrations, narrow deterministic artifact ingestion under those workspaces, lexical retrieval over persisted chunk rows, explicit artifact-chunk embedding storage tied to existing embedding configs, direct semantic retrieval over those durable artifact-chunk embeddings for one visible task or one visible artifact at a time, and compile-path artifact retrieval that deterministically merges lexical and semantic candidates into one artifact section when both are requested for the same scope. The live richer-document boundary is still intentionally narrow: plain text and markdown ingest directly, PDF support is limited to narrow local text extraction, and DOCX support is limited to narrow local text extraction from `word/document.xml`; OCR, image extraction, layout reconstruction, connectors, reranking beyond the current lexical-first hybrid merge, and new side-effect surfaces are still planned later and must not be described as live behavior. ## Implemented Now @@ -62,7 +62,7 @@ The current multi-step boundary is narrow and explicit. Manual continuation is i - `apps/web`: minimal shell only; no shipped workflow UI. - `workers`: scaffold only; no background jobs or runner logic are implemented. - `infra`: local development bootstrap assets only. -- `tests`: unit and Postgres-backed integration coverage for the shipped seams above, including Sprint 4O task-step lineage/manual continuation, Sprint 4S step-linked execution synchronization, Sprint 5A task-workspace provisioning, Sprint 5C task-artifact registration, Sprint 5D local artifact ingestion plus chunk reads, Sprint 5E lexical artifact-chunk retrieval, Sprint 5F compile-path artifact chunk integration, Sprint 5G artifact-chunk embedding persistence and reads, Sprint 5H direct semantic artifact-chunk retrieval, Sprint 5I compile-path semantic artifact retrieval, and Sprint 5J deterministic hybrid lexical-plus-semantic artifact merge in compile. 
+- `tests`: unit and Postgres-backed integration coverage for the shipped seams above, including Sprint 4O task-step lineage/manual continuation, Sprint 4S step-linked execution synchronization, Sprint 5A task-workspace provisioning, Sprint 5C task-artifact registration, Sprint 5D local artifact ingestion plus chunk reads, Sprint 5E lexical artifact-chunk retrieval, Sprint 5F compile-path artifact chunk integration, Sprint 5G artifact-chunk embedding persistence and reads, Sprint 5H direct semantic artifact-chunk retrieval, Sprint 5I compile-path semantic artifact retrieval, Sprint 5J deterministic hybrid lexical-plus-semantic artifact merge in compile, Sprint 5L narrow PDF artifact ingestion, and Sprint 5M narrow DOCX artifact ingestion. ## Core Flows Implemented Now @@ -198,14 +198,17 @@ The current multi-step boundary is narrow and explicit. Manual continuation is i 1. Accept a user-scoped `POST /v0/task-artifacts/{task_artifact_id}/ingest` request for one visible registered artifact. 2. Lock ingestion for that artifact before deciding whether work is needed. 3. Resolve the persisted workspace `local_path` plus persisted artifact `relative_path`, and reject any rooted-path escape deterministically. -4. Support only the narrow explicit text set: `text/plain` and `text/markdown`. -5. Read file bytes deterministically and require valid UTF-8 text. -6. Normalize line endings by rewriting `\r\n` and `\r` to `\n`. -7. Chunk normalized text deterministically with rule `normalized_utf8_text_fixed_window_1000_chars_v1`. -8. Persist ordered `task_artifact_chunks` rows with `sequence_no`, `char_start`, `char_end_exclusive`, and `text`. -9. Update the parent artifact to `ingestion_status = ingested`. -10. If the artifact is already ingested, return the existing artifact and chunk summary without reinserting chunks. -11. `GET /v0/task-artifacts/{task_artifact_id}/chunks` returns visible chunk rows in deterministic `sequence_no ASC, id ASC` order plus stable summary metadata. +4. Support only the current narrow explicit set: `text/plain`, `text/markdown`, narrow local `application/pdf` text extraction, and narrow local `application/vnd.openxmlformats-officedocument.wordprocessingml.document` text extraction from `word/document.xml`. +5. For plain text and markdown, read file bytes deterministically and require valid UTF-8 text. +6. For PDFs, extract only narrow local text content; OCR, image extraction, and broader PDF compatibility remain out of scope. +7. For DOCX, extract only narrow local text from `word/document.xml`; OCR, image extraction, headers/footers/comments expansion, and layout reconstruction remain out of scope. +8. Reject malformed or textless richer-document inputs deterministically instead of producing misleading chunks. +9. Normalize line endings by rewriting `\r\n` and `\r` to `\n`. +10. Chunk normalized text deterministically with rule `normalized_utf8_text_fixed_window_1000_chars_v1`. +11. Persist ordered `task_artifact_chunks` rows with `sequence_no`, `char_start`, `char_end_exclusive`, and `text`. +12. Update the parent artifact to `ingestion_status = ingested`. +13. If the artifact is already ingested, return the existing artifact and chunk summary without reinserting chunks. +14. `GET /v0/task-artifacts/{task_artifact_id}/chunks` returns visible chunk rows in deterministic `sequence_no ASC, id ASC` order plus stable summary metadata. 
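A minimal sketch of the fixed-window rule named in step 10, inferred only from the chunk offsets asserted in the PDF and DOCX ingestion tests in this patch series (a 1006-character normalized text yields windows [0, 1000) and [1000, 1006)); this is an illustration, not the shipped `artifacts.py` implementation, and the helper names are hypothetical:

```python
def normalize_line_endings(text: str) -> str:
    # Step 9: rewrite \r\n and bare \r to \n before chunking.
    return text.replace("\r\n", "\n").replace("\r", "\n")


def fixed_window_chunks(normalized_text: str, window: int = 1000) -> list[dict]:
    # Steps 10-11 as asserted by the tests: non-overlapping 1000-character
    # windows, emitted in deterministic sequence_no order with
    # [char_start, char_end_exclusive) offsets into the normalized text.
    chunks: list[dict] = []
    for sequence_no, char_start in enumerate(range(0, len(normalized_text), window), start=1):
        char_end_exclusive = min(char_start + window, len(normalized_text))
        chunks.append(
            {
                "sequence_no": sequence_no,
                "char_start": char_start,
                "char_end_exclusive": char_end_exclusive,
                "text": normalized_text[char_start:char_end_exclusive],
            }
        )
    return chunks


# Mirroring the DOCX fixture below: "\n".join(["A" * 998, "B" * 5, "C"]) has
# 1006 characters, so this sketch yields two chunks covering [0, 1000) and
# [1000, 1006), matching the asserted chunk rows.
```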
### Artifact Chunk Retrieval diff --git a/apps/api/src/alicebot_api/artifacts.py b/apps/api/src/alicebot_api/artifacts.py index 3e12638..734f251 100644 --- a/apps/api/src/alicebot_api/artifacts.py +++ b/apps/api/src/alicebot_api/artifacts.py @@ -1,11 +1,14 @@ from __future__ import annotations +import io import re import zlib from dataclasses import dataclass from pathlib import Path from typing import cast from uuid import UUID +import xml.etree.ElementTree as ET +import zipfile import psycopg @@ -40,9 +43,13 @@ SUPPORTED_TEXT_ARTIFACT_MEDIA_TYPES = ("text/plain", "text/markdown") SUPPORTED_PDF_ARTIFACT_MEDIA_TYPE = "application/pdf" +SUPPORTED_DOCX_ARTIFACT_MEDIA_TYPE = ( + "application/vnd.openxmlformats-officedocument.wordprocessingml.document" +) SUPPORTED_ARTIFACT_MEDIA_TYPES = ( *SUPPORTED_TEXT_ARTIFACT_MEDIA_TYPES, SUPPORTED_PDF_ARTIFACT_MEDIA_TYPE, + SUPPORTED_DOCX_ARTIFACT_MEDIA_TYPE, ) SUPPORTED_ARTIFACT_EXTENSIONS = { ".txt": "text/plain", @@ -50,6 +57,7 @@ ".md": "text/markdown", ".markdown": "text/markdown", ".pdf": SUPPORTED_PDF_ARTIFACT_MEDIA_TYPE, + ".docx": SUPPORTED_DOCX_ARTIFACT_MEDIA_TYPE, } TASK_ARTIFACT_CHUNK_MAX_CHARS = 1000 TASK_ARTIFACT_CHUNKING_RULE = "normalized_utf8_text_fixed_window_1000_chars_v1" @@ -132,6 +140,14 @@ b"W*", b"y", } +_DOCX_DOCUMENT_XML_PATH = "word/document.xml" +_DOCX_WORDPROCESSING_NAMESPACE = "http://schemas.openxmlformats.org/wordprocessingml/2006/main" +_DOCX_PARAGRAPH_TAG = f"{{{_DOCX_WORDPROCESSING_NAMESPACE}}}p" +_DOCX_TEXT_TAG = f"{{{_DOCX_WORDPROCESSING_NAMESPACE}}}t" +_DOCX_TAB_TAG = f"{{{_DOCX_WORDPROCESSING_NAMESPACE}}}tab" +_DOCX_BREAK_TAG = f"{{{_DOCX_WORDPROCESSING_NAMESPACE}}}br" +_DOCX_CARRIAGE_RETURN_TAG = f"{{{_DOCX_WORDPROCESSING_NAMESPACE}}}cr" +_DOCX_BODY_TAG = f"{{{_DOCX_WORDPROCESSING_NAMESPACE}}}body" @dataclass(frozen=True, slots=True) @@ -268,6 +284,49 @@ def _extract_text_from_utf8_artifact_bytes(*, relative_path: str, payload: bytes ) from exc +def _extract_text_from_docx_paragraph(paragraph: ET.Element) -> str: + fragments: list[str] = [] + for element in paragraph.iter(): + if element.tag == _DOCX_TEXT_TAG: + fragments.append(element.text or "") + continue + if element.tag == _DOCX_TAB_TAG: + fragments.append("\t") + continue + if element.tag in {_DOCX_BREAK_TAG, _DOCX_CARRIAGE_RETURN_TAG}: + fragments.append("\n") + return "".join(fragments) + + +def _extract_text_from_docx_artifact_bytes(*, relative_path: str, payload: bytes) -> str: + try: + with zipfile.ZipFile(io.BytesIO(payload)) as archive: + document_xml = archive.read(_DOCX_DOCUMENT_XML_PATH) + except (KeyError, zipfile.BadZipFile) as exc: + raise TaskArtifactValidationError(f"artifact {relative_path} is not a valid DOCX") from exc + + try: + document_root = ET.fromstring(document_xml) + except ET.ParseError as exc: + raise TaskArtifactValidationError(f"artifact {relative_path} is not a valid DOCX") from exc + + document_body = document_root.find(_DOCX_BODY_TAG) + if document_body is None: + raise TaskArtifactValidationError(f"artifact {relative_path} is not a valid DOCX") + + paragraphs = [ + paragraph_text + for paragraph in document_body.iter(_DOCX_PARAGRAPH_TAG) + if (paragraph_text := _extract_text_from_docx_paragraph(paragraph)) != "" + ] + extracted_text = "\n".join(paragraphs).strip() + if extracted_text == "": + raise TaskArtifactValidationError( + f"artifact {relative_path} does not contain extractable DOCX text" + ) + return extracted_text + + def _extract_pdf_name(dictionary: bytes, key: bytes) -> bytes | None: match = re.search(rb"/" + 
re.escape(key) + rb"\s*/([A-Za-z0-9_.#-]+)", dictionary) if match is None: @@ -742,6 +801,11 @@ def extract_artifact_text(*, row: TaskArtifactRow, artifact_path: Path, media_ty relative_path=row["relative_path"], payload=payload, ) + if media_type == SUPPORTED_DOCX_ARTIFACT_MEDIA_TYPE: + return _extract_text_from_docx_artifact_bytes( + relative_path=row["relative_path"], + payload=payload, + ) raise TaskArtifactValidationError( f"artifact {row['relative_path']} has unsupported media type {media_type}" ) diff --git a/apps/api/src/alicebot_api/semantic_retrieval.py b/apps/api/src/alicebot_api/semantic_retrieval.py index 50059a1..5145062 100644 --- a/apps/api/src/alicebot_api/semantic_retrieval.py +++ b/apps/api/src/alicebot_api/semantic_retrieval.py @@ -33,6 +33,7 @@ ".text": "text/plain", ".md": "text/markdown", ".markdown": "text/markdown", + ".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document", } diff --git a/tests/integration/test_task_artifacts_api.py b/tests/integration/test_task_artifacts_api.py index d23c5c7..618bc24 100644 --- a/tests/integration/test_task_artifacts_api.py +++ b/tests/integration/test_task_artifacts_api.py @@ -1,11 +1,14 @@ from __future__ import annotations +import io import json import zlib from pathlib import Path from typing import Any from urllib.parse import urlencode from uuid import UUID, uuid4 +from xml.sax.saxutils import escape +import zipfile import anyio import psycopg @@ -99,6 +102,72 @@ def _build_pdf_bytes( return bytes(document) +def _build_docx_bytes( + paragraphs: list[str], + *, + include_document_xml: bool = True, + malformed_document_xml: bool = False, +) -> bytes: + document_xml = ( + b"' + '' + "" + + "".join( + ( + "" + f"{escape(paragraph)}" + "" + ) + for paragraph in paragraphs + ) + + ( + "" + "" + "" + "" + "" + "" + ) + ) + ) + + archive_buffer = io.BytesIO() + with zipfile.ZipFile(archive_buffer, "w", compression=zipfile.ZIP_STORED) as archive: + entries = { + "[Content_Types].xml": ( + '' + '' + '' + '' + '' + "" + ).encode("utf-8"), + "_rels/.rels": ( + '' + '' + '' + "" + ).encode("utf-8"), + } + if include_document_xml: + entries["word/document.xml"] = document_xml + + for name, payload in entries.items(): + info = zipfile.ZipInfo(filename=name) + info.date_time = (2026, 3, 13, 10, 0, 0) + info.compress_type = zipfile.ZIP_STORED + archive.writestr(info, payload) + + return archive_buffer.getvalue() + + def invoke_request( method: str, path: str, @@ -526,7 +595,8 @@ def test_task_artifact_ingestion_and_chunk_endpoints_are_deterministic_and_isola assert unsupported_ingest_payload == { "detail": ( "artifact docs/manual.bin has unsupported media type application/octet-stream; " - "supported types: text/plain, text/markdown, application/pdf" + "supported types: text/plain, text/markdown, application/pdf, " + "application/vnd.openxmlformats-officedocument.wordprocessingml.document" ) } @@ -658,6 +728,133 @@ def test_task_artifact_pdf_ingestion_and_chunk_endpoints_are_deterministic_and_i } +def test_task_artifact_docx_ingestion_and_chunk_endpoints_are_deterministic_and_isolated( + migrated_database_urls, + monkeypatch, + tmp_path, +) -> None: + owner = seed_task(migrated_database_urls["app"], email="owner@example.com") + intruder = seed_task(migrated_database_urls["app"], email="intruder@example.com") + workspace_root = tmp_path / "task-workspaces" + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings( + database_url=migrated_database_urls["app"], + 
task_workspace_root=str(workspace_root), + ), + ) + + workspace_status, workspace_payload = invoke_request( + "POST", + f"/v0/tasks/{owner['task_id']}/workspace", + payload={"user_id": str(owner["user_id"])}, + ) + assert workspace_status == 201 + + workspace_path = Path(workspace_payload["workspace"]["local_path"]) + docx_file = workspace_path / "docs" / "spec.docx" + docx_file.parent.mkdir(parents=True) + docx_file.write_bytes(_build_docx_bytes(["A" * 998, "B" * 5, "C"])) + + register_status, register_payload = invoke_request( + "POST", + f"/v0/task-workspaces/{workspace_payload['workspace']['id']}/artifacts", + payload={ + "user_id": str(owner["user_id"]), + "local_path": str(docx_file), + "media_type_hint": "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + }, + ) + assert register_status == 201 + + ingest_status, ingest_payload = invoke_request( + "POST", + f"/v0/task-artifacts/{register_payload['artifact']['id']}/ingest", + payload={"user_id": str(owner["user_id"])}, + ) + chunk_list_status, chunk_list_payload = invoke_request( + "GET", + f"/v0/task-artifacts/{register_payload['artifact']['id']}/chunks", + query_params={"user_id": str(owner["user_id"])}, + ) + isolated_chunk_list_status, isolated_chunk_list_payload = invoke_request( + "GET", + f"/v0/task-artifacts/{register_payload['artifact']['id']}/chunks", + query_params={"user_id": str(intruder["user_id"])}, + ) + isolated_ingest_status, isolated_ingest_payload = invoke_request( + "POST", + f"/v0/task-artifacts/{register_payload['artifact']['id']}/ingest", + payload={"user_id": str(intruder["user_id"])}, + ) + + assert ingest_status == 200 + assert ingest_payload == { + "artifact": { + "id": register_payload["artifact"]["id"], + "task_id": str(owner["task_id"]), + "task_workspace_id": workspace_payload["workspace"]["id"], + "status": "registered", + "ingestion_status": "ingested", + "relative_path": "docs/spec.docx", + "media_type_hint": "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + "created_at": register_payload["artifact"]["created_at"], + "updated_at": ingest_payload["artifact"]["updated_at"], + }, + "summary": { + "total_count": 2, + "total_characters": 1006, + "media_type": "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + "chunking_rule": "normalized_utf8_text_fixed_window_1000_chars_v1", + "order": ["sequence_no_asc", "id_asc"], + }, + } + + assert chunk_list_status == 200 + assert chunk_list_payload == { + "items": [ + { + "id": chunk_list_payload["items"][0]["id"], + "task_artifact_id": register_payload["artifact"]["id"], + "sequence_no": 1, + "char_start": 0, + "char_end_exclusive": 1000, + "text": ("A" * 998) + "\n" + "B", + "created_at": chunk_list_payload["items"][0]["created_at"], + "updated_at": chunk_list_payload["items"][0]["updated_at"], + }, + { + "id": chunk_list_payload["items"][1]["id"], + "task_artifact_id": register_payload["artifact"]["id"], + "sequence_no": 2, + "char_start": 1000, + "char_end_exclusive": 1006, + "text": "BBBB\nC", + "created_at": chunk_list_payload["items"][1]["created_at"], + "updated_at": chunk_list_payload["items"][1]["updated_at"], + }, + ], + "summary": { + "total_count": 2, + "total_characters": 1006, + "media_type": "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + "chunking_rule": "normalized_utf8_text_fixed_window_1000_chars_v1", + "order": ["sequence_no_asc", "id_asc"], + }, + } + + assert isolated_chunk_list_status == 404 + assert isolated_chunk_list_payload == { + 
"detail": f"task artifact {register_payload['artifact']['id']} was not found" + } + + assert isolated_ingest_status == 404 + assert isolated_ingest_payload == { + "detail": f"task artifact {register_payload['artifact']['id']} was not found" + } + + def test_task_artifact_ingestion_supports_markdown_and_reingest_is_idempotent( migrated_database_urls, monkeypatch, @@ -862,6 +1059,78 @@ def test_task_artifact_ingestion_rejects_textless_pdf_content( } +def test_task_artifact_ingestion_rejects_textless_or_malformed_docx( + migrated_database_urls, + monkeypatch, + tmp_path, +) -> None: + owner = seed_task(migrated_database_urls["app"], email="owner@example.com") + workspace_root = tmp_path / "task-workspaces" + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings( + database_url=migrated_database_urls["app"], + task_workspace_root=str(workspace_root), + ), + ) + + workspace_status, workspace_payload = invoke_request( + "POST", + f"/v0/tasks/{owner['task_id']}/workspace", + payload={"user_id": str(owner["user_id"])}, + ) + assert workspace_status == 201 + + workspace_path = Path(workspace_payload["workspace"]["local_path"]) + textless_docx = workspace_path / "docs" / "empty.docx" + textless_docx.parent.mkdir(parents=True) + textless_docx.write_bytes(_build_docx_bytes(["", ""])) + malformed_docx = workspace_path / "docs" / "broken.docx" + malformed_docx.write_bytes(_build_docx_bytes(["broken"], malformed_document_xml=True)) + + textless_register_status, textless_register_payload = invoke_request( + "POST", + f"/v0/task-workspaces/{workspace_payload['workspace']['id']}/artifacts", + payload={ + "user_id": str(owner["user_id"]), + "local_path": str(textless_docx), + "media_type_hint": "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + }, + ) + malformed_register_status, malformed_register_payload = invoke_request( + "POST", + f"/v0/task-workspaces/{workspace_payload['workspace']['id']}/artifacts", + payload={ + "user_id": str(owner["user_id"]), + "local_path": str(malformed_docx), + "media_type_hint": "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + }, + ) + assert textless_register_status == 201 + assert malformed_register_status == 201 + + textless_ingest_status, textless_ingest_payload = invoke_request( + "POST", + f"/v0/task-artifacts/{textless_register_payload['artifact']['id']}/ingest", + payload={"user_id": str(owner["user_id"])}, + ) + malformed_ingest_status, malformed_ingest_payload = invoke_request( + "POST", + f"/v0/task-artifacts/{malformed_register_payload['artifact']['id']}/ingest", + payload={"user_id": str(owner["user_id"])}, + ) + + assert textless_ingest_status == 400 + assert textless_ingest_payload == { + "detail": "artifact docs/empty.docx does not contain extractable DOCX text" + } + assert malformed_ingest_status == 400 + assert malformed_ingest_payload == { + "detail": "artifact docs/broken.docx is not a valid DOCX" + } + + def test_task_artifact_ingestion_enforces_rooted_workspace_paths( migrated_database_urls, monkeypatch, @@ -927,6 +1196,71 @@ def test_task_artifact_ingestion_enforces_rooted_workspace_paths( } +def test_task_artifact_docx_ingestion_enforces_rooted_workspace_paths( + migrated_database_urls, + monkeypatch, + tmp_path, +) -> None: + owner = seed_task(migrated_database_urls["app"], email="owner@example.com") + workspace_root = tmp_path / "task-workspaces" + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings( + database_url=migrated_database_urls["app"], + 
task_workspace_root=str(workspace_root), + ), + ) + + workspace_status, workspace_payload = invoke_request( + "POST", + f"/v0/tasks/{owner['task_id']}/workspace", + payload={"user_id": str(owner["user_id"])}, + ) + assert workspace_status == 201 + + workspace_path = Path(workspace_payload["workspace"]["local_path"]) + safe_file = workspace_path / "docs" / "spec.docx" + safe_file.parent.mkdir(parents=True) + safe_file.write_bytes(_build_docx_bytes(["spec"])) + outside_file = tmp_path / "escape.docx" + outside_file.write_bytes(_build_docx_bytes(["escape"])) + + register_status, register_payload = invoke_request( + "POST", + f"/v0/task-workspaces/{workspace_payload['workspace']['id']}/artifacts", + payload={ + "user_id": str(owner["user_id"]), + "local_path": str(safe_file), + "media_type_hint": "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + }, + ) + assert register_status == 201 + + with psycopg.connect(migrated_database_urls["admin"]) as conn: + with conn.cursor() as cur: + cur.execute( + """ + UPDATE task_artifacts + SET relative_path = '../../../escape.docx' + WHERE id = %s + """, + (register_payload["artifact"]["id"],), + ) + conn.commit() + + ingest_status, ingest_payload = invoke_request( + "POST", + f"/v0/task-artifacts/{register_payload['artifact']['id']}/ingest", + payload={"user_id": str(owner["user_id"])}, + ) + + assert ingest_status == 400 + assert ingest_payload == { + "detail": f"artifact path {outside_file.resolve()} escapes workspace root {workspace_path.resolve()}" + } + + def test_task_artifact_chunk_retrieval_endpoints_are_scoped_deterministic_and_isolated( migrated_database_urls, monkeypatch, diff --git a/tests/unit/test_artifacts.py b/tests/unit/test_artifacts.py index 35da2b0..442a08d 100644 --- a/tests/unit/test_artifacts.py +++ b/tests/unit/test_artifacts.py @@ -1,9 +1,12 @@ from __future__ import annotations +import io import zlib from datetime import UTC, datetime, timedelta from pathlib import Path from uuid import UUID, uuid4 +from xml.sax.saxutils import escape +import zipfile import pytest @@ -121,6 +124,72 @@ def _build_pdf_bytes( return bytes(document) +def _build_docx_bytes( + paragraphs: list[str], + *, + include_document_xml: bool = True, + malformed_document_xml: bool = False, +) -> bytes: + document_xml = ( + b"' + '' + "" + + "".join( + ( + "" + f"{escape(paragraph)}" + "" + ) + for paragraph in paragraphs + ) + + ( + "" + "" + "" + "" + "" + "" + ) + ) + ) + + archive_buffer = io.BytesIO() + with zipfile.ZipFile(archive_buffer, "w", compression=zipfile.ZIP_STORED) as archive: + entries = { + "[Content_Types].xml": ( + '' + '' + '' + '' + '' + "" + ).encode("utf-8"), + "_rels/.rels": ( + '' + '' + '' + "" + ).encode("utf-8"), + } + if include_document_xml: + entries["word/document.xml"] = document_xml + + for name, payload in entries.items(): + info = zipfile.ZipInfo(filename=name) + info.date_time = (2026, 3, 13, 10, 0, 0) + info.compress_type = zipfile.ZIP_STORED + archive.writestr(info, payload) + + return archive_buffer.getvalue() + + class ArtifactStoreStub: def __init__(self) -> None: self.base_time = datetime(2026, 3, 13, 10, 0, tzinfo=UTC) @@ -638,6 +707,84 @@ def test_ingest_task_artifact_record_persists_deterministic_pdf_chunks(tmp_path) ] +def test_ingest_task_artifact_record_persists_deterministic_docx_chunks(tmp_path) -> None: + store = ArtifactStoreStub() + user_id = uuid4() + task_id = uuid4() + task_workspace_id = uuid4() + workspace_path = tmp_path / "workspaces" / str(user_id) / str(task_id) + 
workspace_path.mkdir(parents=True) + artifact_path = workspace_path / "docs" / "spec.docx" + artifact_path.parent.mkdir(parents=True) + artifact_path.write_bytes(_build_docx_bytes(["A" * 998, "B" * 5, "C"])) + store.create_task_workspace( + task_workspace_id=task_workspace_id, + task_id=task_id, + user_id=user_id, + local_path=str(workspace_path), + ) + artifact = store.create_task_artifact( + task_id=task_id, + task_workspace_id=task_workspace_id, + status="registered", + ingestion_status="pending", + relative_path="docs/spec.docx", + media_type_hint="application/vnd.openxmlformats-officedocument.wordprocessingml.document", + ) + + response = ingest_task_artifact_record( + store, + user_id=user_id, + request=TaskArtifactIngestInput(task_artifact_id=artifact["id"]), + ) + + assert response == { + "artifact": { + "id": str(artifact["id"]), + "task_id": str(task_id), + "task_workspace_id": str(task_workspace_id), + "status": "registered", + "ingestion_status": "ingested", + "relative_path": "docs/spec.docx", + "media_type_hint": "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + "created_at": "2026-03-13T10:00:00+00:00", + "updated_at": "2026-03-13T10:30:00+00:00", + }, + "summary": { + "total_count": 2, + "total_characters": 1006, + "media_type": "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + "chunking_rule": TASK_ARTIFACT_CHUNKING_RULE, + "order": ["sequence_no_asc", "id_asc"], + }, + } + assert store.locked_artifact_ids == [artifact["id"]] + assert store.list_task_artifact_chunks(artifact["id"]) == [ + { + "id": store.artifact_chunks[0]["id"], + "user_id": user_id, + "task_artifact_id": artifact["id"], + "sequence_no": 1, + "char_start": 0, + "char_end_exclusive": 1000, + "text": ("A" * 998) + "\n" + "B", + "created_at": datetime(2026, 3, 13, 10, 0, tzinfo=UTC), + "updated_at": datetime(2026, 3, 13, 10, 0, tzinfo=UTC), + }, + { + "id": store.artifact_chunks[1]["id"], + "user_id": user_id, + "task_artifact_id": artifact["id"], + "sequence_no": 2, + "char_start": 1000, + "char_end_exclusive": 1006, + "text": "BBBB\nC", + "created_at": datetime(2026, 3, 13, 10, 0, 1, tzinfo=UTC), + "updated_at": datetime(2026, 3, 13, 10, 0, 1, tzinfo=UTC), + }, + ] + + def test_ingest_task_artifact_record_is_idempotent_for_already_ingested_artifact() -> None: store = ArtifactStoreStub() user_id = uuid4() @@ -724,7 +871,8 @@ def test_ingest_task_artifact_record_rejects_unsupported_media_type(tmp_path) -> TaskArtifactValidationError, match=( "artifact docs/spec.bin has unsupported media type application/octet-stream; " - "supported types: text/plain, text/markdown, application/pdf" + "supported types: text/plain, text/markdown, application/pdf, " + "application/vnd.openxmlformats-officedocument.wordprocessingml.document" ), ): ingest_task_artifact_record( @@ -770,6 +918,78 @@ def test_ingest_task_artifact_record_rejects_textless_pdf(tmp_path) -> None: ) +def test_ingest_task_artifact_record_rejects_textless_docx(tmp_path) -> None: + store = ArtifactStoreStub() + user_id = uuid4() + task_id = uuid4() + task_workspace_id = uuid4() + workspace_path = tmp_path / "workspaces" / str(user_id) / str(task_id) + workspace_path.mkdir(parents=True) + artifact_path = workspace_path / "docs" / "empty.docx" + artifact_path.parent.mkdir(parents=True) + artifact_path.write_bytes(_build_docx_bytes(["", ""])) + store.create_task_workspace( + task_workspace_id=task_workspace_id, + task_id=task_id, + user_id=user_id, + local_path=str(workspace_path), + ) + artifact = 
store.create_task_artifact( + task_id=task_id, + task_workspace_id=task_workspace_id, + status="registered", + ingestion_status="pending", + relative_path="docs/empty.docx", + media_type_hint="application/vnd.openxmlformats-officedocument.wordprocessingml.document", + ) + + with pytest.raises( + TaskArtifactValidationError, + match="artifact docs/empty.docx does not contain extractable DOCX text", + ): + ingest_task_artifact_record( + store, + user_id=user_id, + request=TaskArtifactIngestInput(task_artifact_id=artifact["id"]), + ) + + +def test_ingest_task_artifact_record_rejects_malformed_docx(tmp_path) -> None: + store = ArtifactStoreStub() + user_id = uuid4() + task_id = uuid4() + task_workspace_id = uuid4() + workspace_path = tmp_path / "workspaces" / str(user_id) / str(task_id) + workspace_path.mkdir(parents=True) + artifact_path = workspace_path / "docs" / "broken.docx" + artifact_path.parent.mkdir(parents=True) + artifact_path.write_bytes(_build_docx_bytes(["broken"], malformed_document_xml=True)) + store.create_task_workspace( + task_workspace_id=task_workspace_id, + task_id=task_id, + user_id=user_id, + local_path=str(workspace_path), + ) + artifact = store.create_task_artifact( + task_id=task_id, + task_workspace_id=task_workspace_id, + status="registered", + ingestion_status="pending", + relative_path="docs/broken.docx", + media_type_hint="application/vnd.openxmlformats-officedocument.wordprocessingml.document", + ) + + with pytest.raises( + TaskArtifactValidationError, + match="artifact docs/broken.docx is not a valid DOCX", + ): + ingest_task_artifact_record( + store, + user_id=user_id, + request=TaskArtifactIngestInput(task_artifact_id=artifact["id"]), + ) + + def test_ingest_task_artifact_record_rejects_invalid_utf8_content(tmp_path) -> None: store = ArtifactStoreStub() user_id = uuid4() @@ -838,6 +1058,38 @@ def test_ingest_task_artifact_record_rejects_paths_outside_workspace(tmp_path) - ) +def test_ingest_task_artifact_record_rejects_docx_paths_outside_workspace(tmp_path) -> None: + store = ArtifactStoreStub() + user_id = uuid4() + task_id = uuid4() + task_workspace_id = uuid4() + workspace_path = tmp_path / "workspaces" / str(user_id) / str(task_id) + workspace_path.mkdir(parents=True) + outside_path = tmp_path / "escape.docx" + outside_path.write_bytes(_build_docx_bytes(["escape"])) + store.create_task_workspace( + task_workspace_id=task_workspace_id, + task_id=task_id, + user_id=user_id, + local_path=str(workspace_path), + ) + artifact = store.create_task_artifact( + task_id=task_id, + task_workspace_id=task_workspace_id, + status="registered", + ingestion_status="pending", + relative_path="../escape.docx", + media_type_hint="application/vnd.openxmlformats-officedocument.wordprocessingml.document", + ) + + with pytest.raises(TaskArtifactValidationError, match="escapes workspace root"): + ingest_task_artifact_record( + store, + user_id=user_id, + request=TaskArtifactIngestInput(task_artifact_id=artifact["id"]), + ) + + def test_list_task_artifact_chunk_records_are_deterministic() -> None: store = ArtifactStoreStub() user_id = uuid4() diff --git a/tests/unit/test_artifacts_main.py b/tests/unit/test_artifacts_main.py index f43b78c..dac0244 100644 --- a/tests/unit/test_artifacts_main.py +++ b/tests/unit/test_artifacts_main.py @@ -498,7 +498,8 @@ def fake_user_connection(*_args, **_kwargs): def fake_ingest_task_artifact_record(*_args, **_kwargs): raise TaskArtifactValidationError( "artifact docs/spec.bin has unsupported media type application/octet-stream; " - "supported 
types: text/plain, text/markdown, application/pdf" + "supported types: text/plain, text/markdown, application/pdf, " + "application/vnd.openxmlformats-officedocument.wordprocessingml.document" ) monkeypatch.setattr(main_module, "get_settings", lambda: settings) @@ -514,7 +515,8 @@ def fake_ingest_task_artifact_record(*_args, **_kwargs): assert json.loads(response.body) == { "detail": ( "artifact docs/spec.bin has unsupported media type application/octet-stream; " - "supported types: text/plain, text/markdown, application/pdf" + "supported types: text/plain, text/markdown, application/pdf, " + "application/vnd.openxmlformats-officedocument.wordprocessingml.document" ) } diff --git a/tests/unit/test_semantic_retrieval.py b/tests/unit/test_semantic_retrieval.py index 7f3b26f..4404e47 100644 --- a/tests/unit/test_semantic_retrieval.py +++ b/tests/unit/test_semantic_retrieval.py @@ -461,3 +461,51 @@ def test_retrieve_artifact_scoped_semantic_artifact_chunk_records_returns_empty_ "query_vector": [0.0, 1.0, 0.0], "limit": 5, } + + +def test_retrieve_task_scoped_semantic_artifact_chunk_records_infers_docx_media_type_without_hint() -> None: + store = SemanticRetrievalStoreStub() + config_id = seed_config(store, dimensions=3) + task_id = seed_task(store) + artifact_id = seed_artifact( + store, + task_id=task_id, + relative_path="docs/spec.docx", + media_type_hint=None, + ) + docx_row = semantic_artifact_row( + store, + task_id=task_id, + task_artifact_id=artifact_id, + relative_path="docs/spec.docx", + score=0.9, + sequence_no=1, + ) + docx_row["media_type_hint"] = None + store.task_artifact_retrieval_rows = [docx_row] + + payload = retrieve_task_scoped_semantic_artifact_chunk_records( + store, # type: ignore[arg-type] + user_id=uuid4(), + request=TaskScopedSemanticArtifactChunkRetrievalInput( + task_id=task_id, + embedding_config_id=config_id, + query_vector=(1.0, 0.0, 0.0), + limit=1, + ), + ) + + assert payload["items"] == [ + { + "id": str(docx_row["id"]), + "task_id": str(task_id), + "task_artifact_id": str(artifact_id), + "relative_path": "docs/spec.docx", + "media_type": "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + "sequence_no": 1, + "char_start": 0, + "char_end_exclusive": 11, + "text": "docs/spec.docx-chunk", + "score": 0.9, + } + ] From 64c1c94760e883aac6b75d20d1a65c645cf035af Mon Sep 17 00:00:00 2001 From: Sam Date: Mon, 16 Mar 2026 13:13:28 +0100 Subject: [PATCH 014/135] Sprint 5N: RFC822 email artifact parsing v0 (#14) * Sprint 5N: RFC822 email artifact parsing packet * Sprint 5N: RFC822 email artifact parsing v0 --------- Co-authored-by: Sami Rusani --- ARCHITECTURE.md | 25 +- apps/api/src/alicebot_api/artifacts.py | 107 ++++ .../src/alicebot_api/semantic_retrieval.py | 1 + tests/integration/test_task_artifacts_api.py | 500 +++++++++++++++++- tests/unit/test_artifacts.py | 441 ++++++++++++++- tests/unit/test_artifacts_main.py | 6 +- tests/unit/test_semantic_retrieval.py | 48 ++ 7 files changed, 1112 insertions(+), 16 deletions(-) diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md index faf17a5..acb2a51 100644 --- a/ARCHITECTURE.md +++ b/ARCHITECTURE.md @@ -2,16 +2,16 @@ ## Current Implemented Slice -AliceBot now implements the accepted repo slice through Sprint 5M. The shipped backend includes: +AliceBot now implements the accepted repo slice through Sprint 5N. 
The shipped backend includes: - foundation continuity storage over `users`, `threads`, `sessions`, and append-only `events` - deterministic tracing and context compilation over durable continuity, memory, entity, and entity-edge records - governed memory admission, explicit-preference extraction, memory review labels, review queue reads, evaluation summary reads, explicit embedding config and memory-embedding storage, direct semantic retrieval, and deterministic hybrid compile-path memory merge - deterministic prompt assembly and one no-tools response path that persists assistant replies as immutable continuity events - user-scoped consents, policies, policy evaluation, tool registry, allowlist evaluation, tool routing, approval request persistence, approval resolution, approved-only proxy execution through the in-process `proxy.echo` handler, durable execution review, and execution-budget lifecycle plus enforcement -- durable `tasks`, `task_steps`, `task_workspaces`, `task_artifacts`, `task_artifact_chunks`, and `task_artifact_chunk_embeddings`, deterministic task-step sequencing, explicit task-step transitions, explicit manual continuation with lineage through `parent_step_id`, `source_approval_id`, and `source_execution_id`, explicit `tool_executions.task_step_id` linkage for execution synchronization, deterministic rooted local task-workspace provisioning, explicit rooted local artifact registration, deterministic local plain-text, markdown, narrow PDF text, and narrow DOCX text ingestion into durable chunk rows, deterministic lexical artifact-chunk retrieval over durable chunk rows, explicit user-scoped artifact-chunk embedding persistence tied to existing embedding configs, explicit task-scoped or artifact-scoped semantic artifact-chunk retrieval over those durable embeddings, and compile-path artifact retrieval that can include lexical results, semantic results, or one deterministic hybrid lexical-plus-semantic merged artifact section with per-chunk source provenance +- durable `tasks`, `task_steps`, `task_workspaces`, `task_artifacts`, `task_artifact_chunks`, and `task_artifact_chunk_embeddings`, deterministic task-step sequencing, explicit task-step transitions, explicit manual continuation with lineage through `parent_step_id`, `source_approval_id`, and `source_execution_id`, explicit `tool_executions.task_step_id` linkage for execution synchronization, deterministic rooted local task-workspace provisioning, explicit rooted local artifact registration, deterministic local plain-text, markdown, narrow PDF text, narrow DOCX text, and narrow RFC822 email text ingestion into durable chunk rows, deterministic lexical artifact-chunk retrieval over durable chunk rows, explicit user-scoped artifact-chunk embedding persistence tied to existing embedding configs, explicit task-scoped or artifact-scoped semantic artifact-chunk retrieval over those durable embeddings, and compile-path artifact retrieval that can include lexical results, semantic results, or one deterministic hybrid lexical-plus-semantic merged artifact section with per-chunk source provenance -The current multi-step boundary is narrow and explicit. Manual continuation is implemented and review-passed. Approval resolution and proxy execution now both use explicit task-step linkage rather than first-step inference. 
Task workspaces are now implemented only as deterministic rooted local boundaries, and task artifacts are now implemented only as explicit rooted local-file registrations, narrow deterministic artifact ingestion under those workspaces, lexical retrieval over persisted chunk rows, explicit artifact-chunk embedding storage tied to existing embedding configs, direct semantic retrieval over those durable artifact-chunk embeddings for one visible task or one visible artifact at a time, and compile-path artifact retrieval that deterministically merges lexical and semantic candidates into one artifact section when both are requested for the same scope. The live richer-document boundary is still intentionally narrow: plain text and markdown ingest directly, PDF support is limited to narrow local text extraction, and DOCX support is limited to narrow local text extraction from `word/document.xml`; OCR, image extraction, layout reconstruction, connectors, reranking beyond the current lexical-first hybrid merge, and new side-effect surfaces are still planned later and must not be described as live behavior. +The current multi-step boundary is narrow and explicit. Manual continuation is implemented and review-passed. Approval resolution and proxy execution now both use explicit task-step linkage rather than first-step inference. Task workspaces are now implemented only as deterministic rooted local boundaries, and task artifacts are now implemented only as explicit rooted local-file registrations, narrow deterministic artifact ingestion under those workspaces, lexical retrieval over persisted chunk rows, explicit artifact-chunk embedding storage tied to existing embedding configs, direct semantic retrieval over those durable artifact-chunk embeddings for one visible task or one visible artifact at a time, and compile-path artifact retrieval that deterministically merges lexical and semantic candidates into one artifact section when both are requested for the same scope. The live richer-document boundary is still intentionally narrow: plain text and markdown ingest directly, PDF support is limited to narrow local text extraction, DOCX support is limited to narrow local text extraction from `word/document.xml`, and RFC822 email support is limited to top-level selected headers plus extractable plain-text body content while excluding nested `message/rfc822` content; OCR, image extraction, layout reconstruction, connectors, reranking beyond the current lexical-first hybrid merge, and new side-effect surfaces are still planned later and must not be described as live behavior. ## Implemented Now @@ -62,7 +62,7 @@ The current multi-step boundary is narrow and explicit. Manual continuation is i - `apps/web`: minimal shell only; no shipped workflow UI. - `workers`: scaffold only; no background jobs or runner logic are implemented. - `infra`: local development bootstrap assets only. 
-- `tests`: unit and Postgres-backed integration coverage for the shipped seams above, including Sprint 4O task-step lineage/manual continuation, Sprint 4S step-linked execution synchronization, Sprint 5A task-workspace provisioning, Sprint 5C task-artifact registration, Sprint 5D local artifact ingestion plus chunk reads, Sprint 5E lexical artifact-chunk retrieval, Sprint 5F compile-path artifact chunk integration, Sprint 5G artifact-chunk embedding persistence and reads, Sprint 5H direct semantic artifact-chunk retrieval, Sprint 5I compile-path semantic artifact retrieval, Sprint 5J deterministic hybrid lexical-plus-semantic artifact merge in compile, Sprint 5L narrow PDF artifact ingestion, and Sprint 5M narrow DOCX artifact ingestion. +- `tests`: unit and Postgres-backed integration coverage for the shipped seams above, including Sprint 4O task-step lineage/manual continuation, Sprint 4S step-linked execution synchronization, Sprint 5A task-workspace provisioning, Sprint 5C task-artifact registration, Sprint 5D local artifact ingestion plus chunk reads, Sprint 5E lexical artifact-chunk retrieval, Sprint 5F compile-path artifact chunk integration, Sprint 5G artifact-chunk embedding persistence and reads, Sprint 5H direct semantic artifact-chunk retrieval, Sprint 5I compile-path semantic artifact retrieval, Sprint 5J deterministic hybrid lexical-plus-semantic artifact merge in compile, Sprint 5L narrow PDF artifact ingestion, Sprint 5M narrow DOCX artifact ingestion, and Sprint 5N narrow RFC822 email artifact ingestion. ## Core Flows Implemented Now @@ -198,17 +198,18 @@ The current multi-step boundary is narrow and explicit. Manual continuation is i 1. Accept a user-scoped `POST /v0/task-artifacts/{task_artifact_id}/ingest` request for one visible registered artifact. 2. Lock ingestion for that artifact before deciding whether work is needed. 3. Resolve the persisted workspace `local_path` plus persisted artifact `relative_path`, and reject any rooted-path escape deterministically. -4. Support only the current narrow explicit set: `text/plain`, `text/markdown`, narrow local `application/pdf` text extraction, and narrow local `application/vnd.openxmlformats-officedocument.wordprocessingml.document` text extraction from `word/document.xml`. +4. Support only the current narrow explicit set: `text/plain`, `text/markdown`, narrow local `application/pdf` text extraction, narrow local `application/vnd.openxmlformats-officedocument.wordprocessingml.document` text extraction from `word/document.xml`, and narrow local `message/rfc822` extraction. 5. For plain text and markdown, read file bytes deterministically and require valid UTF-8 text. 6. For PDFs, extract only narrow local text content; OCR, image extraction, and broader PDF compatibility remain out of scope. 7. For DOCX, extract only narrow local text from `word/document.xml`; OCR, image extraction, headers/footers/comments expansion, and layout reconstruction remain out of scope. -8. Reject malformed or textless richer-document inputs deterministically instead of producing misleading chunks. -9. Normalize line endings by rewriting `\r\n` and `\r` to `\n`. -10. Chunk normalized text deterministically with rule `normalized_utf8_text_fixed_window_1000_chars_v1`. -11. Persist ordered `task_artifact_chunks` rows with `sequence_no`, `char_start`, `char_end_exclusive`, and `text`. -12. Update the parent artifact to `ingestion_status = ingested`. -13. 
If the artifact is already ingested, return the existing artifact and chunk summary without reinserting chunks. -14. `GET /v0/task-artifacts/{task_artifact_id}/chunks` returns visible chunk rows in deterministic `sequence_no ASC, id ASC` order plus stable summary metadata. +8. For RFC822 email, extract only the selected top-level headers plus extractable plain-text body content; nested `message/rfc822` content, HTML rendering, and attachment extraction remain out of scope. +9. Reject malformed or textless richer-document inputs deterministically instead of producing misleading chunks. +10. Normalize line endings by rewriting `\r\n` and `\r` to `\n`. +11. Chunk normalized text deterministically with rule `normalized_utf8_text_fixed_window_1000_chars_v1`. +12. Persist ordered `task_artifact_chunks` rows with `sequence_no`, `char_start`, `char_end_exclusive`, and `text`. +13. Update the parent artifact to `ingestion_status = ingested`. +14. If the artifact is already ingested, return the existing artifact and chunk summary without reinserting chunks. +15. `GET /v0/task-artifacts/{task_artifact_id}/chunks` returns visible chunk rows in deterministic `sequence_no ASC, id ASC` order plus stable summary metadata. ### Artifact Chunk Retrieval diff --git a/apps/api/src/alicebot_api/artifacts.py b/apps/api/src/alicebot_api/artifacts.py index 734f251..72bab4f 100644 --- a/apps/api/src/alicebot_api/artifacts.py +++ b/apps/api/src/alicebot_api/artifacts.py @@ -1,6 +1,10 @@ from __future__ import annotations import io +from email import policy +from email.errors import MessageDefect, MessageError +from email.message import EmailMessage +from email.parser import BytesParser import re import zlib from dataclasses import dataclass @@ -46,10 +50,12 @@ SUPPORTED_DOCX_ARTIFACT_MEDIA_TYPE = ( "application/vnd.openxmlformats-officedocument.wordprocessingml.document" ) +SUPPORTED_RFC822_ARTIFACT_MEDIA_TYPE = "message/rfc822" SUPPORTED_ARTIFACT_MEDIA_TYPES = ( *SUPPORTED_TEXT_ARTIFACT_MEDIA_TYPES, SUPPORTED_PDF_ARTIFACT_MEDIA_TYPE, SUPPORTED_DOCX_ARTIFACT_MEDIA_TYPE, + SUPPORTED_RFC822_ARTIFACT_MEDIA_TYPE, ) SUPPORTED_ARTIFACT_EXTENSIONS = { ".txt": "text/plain", @@ -58,6 +64,7 @@ ".markdown": "text/markdown", ".pdf": SUPPORTED_PDF_ARTIFACT_MEDIA_TYPE, ".docx": SUPPORTED_DOCX_ARTIFACT_MEDIA_TYPE, + ".eml": SUPPORTED_RFC822_ARTIFACT_MEDIA_TYPE, } TASK_ARTIFACT_CHUNK_MAX_CHARS = 1000 TASK_ARTIFACT_CHUNKING_RULE = "normalized_utf8_text_fixed_window_1000_chars_v1" @@ -148,6 +155,17 @@ _DOCX_BREAK_TAG = f"{{{_DOCX_WORDPROCESSING_NAMESPACE}}}br" _DOCX_CARRIAGE_RETURN_TAG = f"{{{_DOCX_WORDPROCESSING_NAMESPACE}}}cr" _DOCX_BODY_TAG = f"{{{_DOCX_WORDPROCESSING_NAMESPACE}}}body" +_RFC822_EMAIL_PARSE_POLICY = policy.default.clone(raise_on_defect=True) +_RFC822_EXTRACTED_HEADER_NAMES = ( + "From", + "To", + "Cc", + "Bcc", + "Reply-To", + "Subject", + "Date", + "Message-ID", +) @dataclass(frozen=True, slots=True) @@ -327,6 +345,90 @@ def _extract_text_from_docx_artifact_bytes(*, relative_path: str, payload: bytes return extracted_text +def _normalize_rfc822_header_value(value: str) -> str: + return re.sub(r"\s+", " ", value).strip() + + +def _parse_rfc822_email(*, relative_path: str, payload: bytes) -> EmailMessage: + try: + message = BytesParser(policy=_RFC822_EMAIL_PARSE_POLICY).parsebytes(payload) + except (MessageDefect, MessageError, ValueError, TypeError) as exc: + raise TaskArtifactValidationError( + f"artifact {relative_path} is not a valid RFC822 email" + ) from exc + return cast(EmailMessage, message) + + +def 
_extract_rfc822_header_lines(message: EmailMessage) -> list[str]: + header_lines: list[str] = [] + for header_name in _RFC822_EXTRACTED_HEADER_NAMES: + for header_value in message.get_all(header_name, failobj=[]): + normalized_value = _normalize_rfc822_header_value(str(header_value)) + if normalized_value != "": + header_lines.append(f"{header_name}: {normalized_value}") + return header_lines + + +def _is_extractable_rfc822_text_part(part: EmailMessage) -> bool: + if part.is_multipart(): + return False + if part.get_content_type() != "text/plain": + return False + if part.get_content_disposition() == "attachment": + return False + return part.get_filename() is None + + +def _extract_rfc822_part_text(*, relative_path: str, part: EmailMessage) -> str: + try: + payload = part.get_content() + except (MessageError, LookupError, UnicodeError, ValueError, TypeError) as exc: + raise TaskArtifactValidationError( + f"artifact {relative_path} is not a valid RFC822 email" + ) from exc + if not isinstance(payload, str): + raise TaskArtifactValidationError( + f"artifact {relative_path} is not a valid RFC822 email" + ) + return payload.strip() + + +def _iter_extractable_rfc822_text_parts(message: EmailMessage) -> list[EmailMessage]: + if _is_extractable_rfc822_text_part(message): + return [message] + if not message.is_multipart(): + return [] + + extractable_parts: list[EmailMessage] = [] + for child_part in message.iter_parts(): + child_email_part = cast(EmailMessage, child_part) + if child_email_part.get_content_maintype() == "message": + continue + extractable_parts.extend(_iter_extractable_rfc822_text_parts(child_email_part)) + return extractable_parts + + +def _extract_text_from_rfc822_artifact_bytes(*, relative_path: str, payload: bytes) -> str: + message = _parse_rfc822_email(relative_path=relative_path, payload=payload) + header_lines = _extract_rfc822_header_lines(message) + body_parts = [ + body_text + for part in _iter_extractable_rfc822_text_parts(message) + if (body_text := _extract_rfc822_part_text(relative_path=relative_path, part=part)) + != "" + ] + if not body_parts: + raise TaskArtifactValidationError( + f"artifact {relative_path} does not contain extractable RFC822 email text" + ) + + sections: list[str] = [] + if header_lines: + sections.append("\n".join(header_lines)) + sections.append("\n\n".join(body_parts)) + return "\n\n".join(sections) + + def _extract_pdf_name(dictionary: bytes, key: bytes) -> bytes | None: match = re.search(rb"/" + re.escape(key) + rb"\s*/([A-Za-z0-9_.#-]+)", dictionary) if match is None: @@ -806,6 +908,11 @@ def extract_artifact_text(*, row: TaskArtifactRow, artifact_path: Path, media_ty relative_path=row["relative_path"], payload=payload, ) + if media_type == SUPPORTED_RFC822_ARTIFACT_MEDIA_TYPE: + return _extract_text_from_rfc822_artifact_bytes( + relative_path=row["relative_path"], + payload=payload, + ) raise TaskArtifactValidationError( f"artifact {row['relative_path']} has unsupported media type {media_type}" ) diff --git a/apps/api/src/alicebot_api/semantic_retrieval.py b/apps/api/src/alicebot_api/semantic_retrieval.py index 5145062..16e9c85 100644 --- a/apps/api/src/alicebot_api/semantic_retrieval.py +++ b/apps/api/src/alicebot_api/semantic_retrieval.py @@ -34,6 +34,7 @@ ".md": "text/markdown", ".markdown": "text/markdown", ".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + ".eml": "message/rfc822", } diff --git a/tests/integration/test_task_artifacts_api.py b/tests/integration/test_task_artifacts_api.py index 
618bc24..2c61dd3 100644 --- a/tests/integration/test_task_artifacts_api.py +++ b/tests/integration/test_task_artifacts_api.py @@ -168,6 +168,116 @@ def _build_docx_bytes( return archive_buffer.getvalue() +def _build_rfc822_email_bytes( + *, + headers: list[tuple[str, str]] | None = None, + plain_body: str | None = None, + plain_parts: list[str] | None = None, + html_body: str | None = None, + attachment_text: str | None = None, + nested_message_bytes: bytes | None = None, + malformed_multipart: bool = False, +) -> bytes: + header_lines = [ + f"{name}: {value}" + for name, value in ( + headers + if headers is not None + else [ + ("From", "Alice "), + ("To", "Bob "), + ("Subject", "Sprint Update"), + ] + ) + ] + if malformed_multipart: + return ( + "\r\n".join( + [ + *header_lines, + "MIME-Version: 1.0", + "Content-Type: multipart/mixed", + "", + "--broken-boundary", + 'Content-Type: text/plain; charset="utf-8"', + "", + "broken", + "--broken-boundary--", + "", + ] + ).encode("utf-8") + ) + + if ( + plain_parts is None + and html_body is None + and attachment_text is None + and nested_message_bytes is None + ): + return ( + "\r\n".join( + [ + *header_lines, + 'Content-Type: text/plain; charset="utf-8"', + "Content-Transfer-Encoding: 8bit", + "", + plain_body or "", + ] + ).encode("utf-8") + ) + + boundary = "alicebot-boundary-001" + lines = [ + *header_lines, + "MIME-Version: 1.0", + f'Content-Type: multipart/mixed; boundary="{boundary}"', + "", + ] + for part_text in plain_parts or []: + lines.extend( + [ + f"--{boundary}", + 'Content-Type: text/plain; charset="utf-8"', + "Content-Transfer-Encoding: 8bit", + "", + part_text, + ] + ) + if html_body is not None: + lines.extend( + [ + f"--{boundary}", + 'Content-Type: text/html; charset="utf-8"', + "Content-Transfer-Encoding: 8bit", + "", + html_body, + ] + ) + if attachment_text is not None: + lines.extend( + [ + f"--{boundary}", + 'Content-Type: text/plain; charset="utf-8"', + 'Content-Disposition: attachment; filename="note.txt"', + "Content-Transfer-Encoding: 8bit", + "", + attachment_text, + ] + ) + if nested_message_bytes is not None: + lines.extend( + [ + f"--{boundary}", + "Content-Type: message/rfc822", + "Content-Transfer-Encoding: 8bit", + "", + nested_message_bytes.decode("utf-8"), + ] + ) + lines.extend([f"--{boundary}--", ""]) + return "\r\n".join(lines).encode("utf-8") + + def invoke_request( method: str, path: str, @@ -596,7 +706,8 @@ def test_task_artifact_ingestion_and_chunk_endpoints_are_deterministic_and_isola "detail": ( "artifact docs/manual.bin has unsupported media type application/octet-stream; " "supported types: text/plain, text/markdown, application/pdf, " - "application/vnd.openxmlformats-officedocument.wordprocessingml.document" + "application/vnd.openxmlformats-officedocument.wordprocessingml.document, " + "message/rfc822" ) } @@ -855,6 +966,256 @@ def test_task_artifact_docx_ingestion_and_chunk_endpoints_are_deterministic_and_ } +def test_task_artifact_rfc822_ingestion_and_chunk_endpoints_are_deterministic_and_isolated( + migrated_database_urls, + monkeypatch, + tmp_path, +) -> None: + owner = seed_task(migrated_database_urls["app"], email="owner@example.com") + intruder = seed_task(migrated_database_urls["app"], email="intruder@example.com") + workspace_root = tmp_path / "task-workspaces" + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings( + database_url=migrated_database_urls["app"], + task_workspace_root=str(workspace_root), + ), + ) + + workspace_status, workspace_payload = 
invoke_request( + "POST", + f"/v0/tasks/{owner['task_id']}/workspace", + payload={"user_id": str(owner["user_id"])}, + ) + assert workspace_status == 201 + + workspace_path = Path(workspace_payload["workspace"]["local_path"]) + email_file = workspace_path / "mail" / "update.eml" + email_file.parent.mkdir(parents=True) + email_file.write_bytes( + _build_rfc822_email_bytes( + plain_body=("A" * 916) + "\r\n" + ("B" * 5) + "\rC", + ) + ) + + register_status, register_payload = invoke_request( + "POST", + f"/v0/task-workspaces/{workspace_payload['workspace']['id']}/artifacts", + payload={ + "user_id": str(owner["user_id"]), + "local_path": str(email_file), + "media_type_hint": "message/rfc822", + }, + ) + assert register_status == 201 + + ingest_status, ingest_payload = invoke_request( + "POST", + f"/v0/task-artifacts/{register_payload['artifact']['id']}/ingest", + payload={"user_id": str(owner["user_id"])}, + ) + chunk_list_status, chunk_list_payload = invoke_request( + "GET", + f"/v0/task-artifacts/{register_payload['artifact']['id']}/chunks", + query_params={"user_id": str(owner["user_id"])}, + ) + isolated_chunk_list_status, isolated_chunk_list_payload = invoke_request( + "GET", + f"/v0/task-artifacts/{register_payload['artifact']['id']}/chunks", + query_params={"user_id": str(intruder["user_id"])}, + ) + isolated_ingest_status, isolated_ingest_payload = invoke_request( + "POST", + f"/v0/task-artifacts/{register_payload['artifact']['id']}/ingest", + payload={"user_id": str(intruder["user_id"])}, + ) + + header_block = ( + "From: Alice \n" + "To: Bob \n" + "Subject: Sprint Update\n\n" + ) + assert ingest_status == 200 + assert ingest_payload == { + "artifact": { + "id": register_payload["artifact"]["id"], + "task_id": str(owner["task_id"]), + "task_workspace_id": workspace_payload["workspace"]["id"], + "status": "registered", + "ingestion_status": "ingested", + "relative_path": "mail/update.eml", + "media_type_hint": "message/rfc822", + "created_at": register_payload["artifact"]["created_at"], + "updated_at": ingest_payload["artifact"]["updated_at"], + }, + "summary": { + "total_count": 2, + "total_characters": 1006, + "media_type": "message/rfc822", + "chunking_rule": "normalized_utf8_text_fixed_window_1000_chars_v1", + "order": ["sequence_no_asc", "id_asc"], + }, + } + + assert chunk_list_status == 200 + assert chunk_list_payload == { + "items": [ + { + "id": chunk_list_payload["items"][0]["id"], + "task_artifact_id": register_payload["artifact"]["id"], + "sequence_no": 1, + "char_start": 0, + "char_end_exclusive": 1000, + "text": header_block + ("A" * 916) + "\n" + "B", + "created_at": chunk_list_payload["items"][0]["created_at"], + "updated_at": chunk_list_payload["items"][0]["updated_at"], + }, + { + "id": chunk_list_payload["items"][1]["id"], + "task_artifact_id": register_payload["artifact"]["id"], + "sequence_no": 2, + "char_start": 1000, + "char_end_exclusive": 1006, + "text": "BBBB\nC", + "created_at": chunk_list_payload["items"][1]["created_at"], + "updated_at": chunk_list_payload["items"][1]["updated_at"], + }, + ], + "summary": { + "total_count": 2, + "total_characters": 1006, + "media_type": "message/rfc822", + "chunking_rule": "normalized_utf8_text_fixed_window_1000_chars_v1", + "order": ["sequence_no_asc", "id_asc"], + }, + } + + assert isolated_chunk_list_status == 404 + assert isolated_chunk_list_payload == { + "detail": f"task artifact {register_payload['artifact']['id']} was not found" + } + + assert isolated_ingest_status == 404 + assert isolated_ingest_payload == { + 
"detail": f"task artifact {register_payload['artifact']['id']} was not found" + } + + +def test_task_artifact_rfc822_ingestion_excludes_nested_email_bodies( + migrated_database_urls, + monkeypatch, + tmp_path, +) -> None: + owner = seed_task(migrated_database_urls["app"], email="owner@example.com") + workspace_root = tmp_path / "task-workspaces" + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings( + database_url=migrated_database_urls["app"], + task_workspace_root=str(workspace_root), + ), + ) + + workspace_status, workspace_payload = invoke_request( + "POST", + f"/v0/tasks/{owner['task_id']}/workspace", + payload={"user_id": str(owner["user_id"])}, + ) + assert workspace_status == 201 + + workspace_path = Path(workspace_payload["workspace"]["local_path"]) + email_file = workspace_path / "mail" / "forwarded.eml" + email_file.parent.mkdir(parents=True) + email_file.write_bytes( + _build_rfc822_email_bytes( + plain_parts=["Outer body"], + nested_message_bytes=_build_rfc822_email_bytes( + headers=[ + ("From", "Nested "), + ("To", "Team "), + ("Subject", "Nested"), + ], + plain_body="Inner body", + ), + ) + ) + + register_status, register_payload = invoke_request( + "POST", + f"/v0/task-workspaces/{workspace_payload['workspace']['id']}/artifacts", + payload={ + "user_id": str(owner["user_id"]), + "local_path": str(email_file), + "media_type_hint": "message/rfc822", + }, + ) + assert register_status == 201 + + ingest_status, ingest_payload = invoke_request( + "POST", + f"/v0/task-artifacts/{register_payload['artifact']['id']}/ingest", + payload={"user_id": str(owner["user_id"])}, + ) + chunk_list_status, chunk_list_payload = invoke_request( + "GET", + f"/v0/task-artifacts/{register_payload['artifact']['id']}/chunks", + query_params={"user_id": str(owner["user_id"])}, + ) + + expected_text = ( + "From: Alice \n" + "To: Bob \n" + "Subject: Sprint Update\n\n" + "Outer body" + ) + assert ingest_status == 200 + assert ingest_payload == { + "artifact": { + "id": register_payload["artifact"]["id"], + "task_id": str(owner["task_id"]), + "task_workspace_id": workspace_payload["workspace"]["id"], + "status": "registered", + "ingestion_status": "ingested", + "relative_path": "mail/forwarded.eml", + "media_type_hint": "message/rfc822", + "created_at": register_payload["artifact"]["created_at"], + "updated_at": ingest_payload["artifact"]["updated_at"], + }, + "summary": { + "total_count": 1, + "total_characters": len(expected_text), + "media_type": "message/rfc822", + "chunking_rule": "normalized_utf8_text_fixed_window_1000_chars_v1", + "order": ["sequence_no_asc", "id_asc"], + }, + } + + assert chunk_list_status == 200 + assert chunk_list_payload == { + "items": [ + { + "id": chunk_list_payload["items"][0]["id"], + "task_artifact_id": register_payload["artifact"]["id"], + "sequence_no": 1, + "char_start": 0, + "char_end_exclusive": len(expected_text), + "text": expected_text, + "created_at": chunk_list_payload["items"][0]["created_at"], + "updated_at": chunk_list_payload["items"][0]["updated_at"], + } + ], + "summary": { + "total_count": 1, + "total_characters": len(expected_text), + "media_type": "message/rfc822", + "chunking_rule": "normalized_utf8_text_fixed_window_1000_chars_v1", + "order": ["sequence_no_asc", "id_asc"], + }, + } + + def test_task_artifact_ingestion_supports_markdown_and_reingest_is_idempotent( migrated_database_urls, monkeypatch, @@ -1131,6 +1492,78 @@ def test_task_artifact_ingestion_rejects_textless_or_malformed_docx( } +def 
test_task_artifact_ingestion_rejects_textless_or_malformed_rfc822_email( + migrated_database_urls, + monkeypatch, + tmp_path, +) -> None: + owner = seed_task(migrated_database_urls["app"], email="owner@example.com") + workspace_root = tmp_path / "task-workspaces" + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings( + database_url=migrated_database_urls["app"], + task_workspace_root=str(workspace_root), + ), + ) + + workspace_status, workspace_payload = invoke_request( + "POST", + f"/v0/tasks/{owner['task_id']}/workspace", + payload={"user_id": str(owner["user_id"])}, + ) + assert workspace_status == 201 + + workspace_path = Path(workspace_payload["workspace"]["local_path"]) + textless_email = workspace_path / "mail" / "empty.eml" + textless_email.parent.mkdir(parents=True) + textless_email.write_bytes(_build_rfc822_email_bytes(html_body="
html only
")) + malformed_email = workspace_path / "mail" / "broken.eml" + malformed_email.write_bytes(_build_rfc822_email_bytes(malformed_multipart=True)) + + textless_register_status, textless_register_payload = invoke_request( + "POST", + f"/v0/task-workspaces/{workspace_payload['workspace']['id']}/artifacts", + payload={ + "user_id": str(owner["user_id"]), + "local_path": str(textless_email), + "media_type_hint": "message/rfc822", + }, + ) + malformed_register_status, malformed_register_payload = invoke_request( + "POST", + f"/v0/task-workspaces/{workspace_payload['workspace']['id']}/artifacts", + payload={ + "user_id": str(owner["user_id"]), + "local_path": str(malformed_email), + "media_type_hint": "message/rfc822", + }, + ) + assert textless_register_status == 201 + assert malformed_register_status == 201 + + textless_ingest_status, textless_ingest_payload = invoke_request( + "POST", + f"/v0/task-artifacts/{textless_register_payload['artifact']['id']}/ingest", + payload={"user_id": str(owner["user_id"])}, + ) + malformed_ingest_status, malformed_ingest_payload = invoke_request( + "POST", + f"/v0/task-artifacts/{malformed_register_payload['artifact']['id']}/ingest", + payload={"user_id": str(owner["user_id"])}, + ) + + assert textless_ingest_status == 400 + assert textless_ingest_payload == { + "detail": "artifact mail/empty.eml does not contain extractable RFC822 email text" + } + assert malformed_ingest_status == 400 + assert malformed_ingest_payload == { + "detail": "artifact mail/broken.eml is not a valid RFC822 email" + } + + def test_task_artifact_ingestion_enforces_rooted_workspace_paths( migrated_database_urls, monkeypatch, @@ -1261,6 +1694,71 @@ def test_task_artifact_docx_ingestion_enforces_rooted_workspace_paths( } +def test_task_artifact_rfc822_ingestion_enforces_rooted_workspace_paths( + migrated_database_urls, + monkeypatch, + tmp_path, +) -> None: + owner = seed_task(migrated_database_urls["app"], email="owner@example.com") + workspace_root = tmp_path / "task-workspaces" + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings( + database_url=migrated_database_urls["app"], + task_workspace_root=str(workspace_root), + ), + ) + + workspace_status, workspace_payload = invoke_request( + "POST", + f"/v0/tasks/{owner['task_id']}/workspace", + payload={"user_id": str(owner["user_id"])}, + ) + assert workspace_status == 201 + + workspace_path = Path(workspace_payload["workspace"]["local_path"]) + safe_file = workspace_path / "mail" / "update.eml" + safe_file.parent.mkdir(parents=True) + safe_file.write_bytes(_build_rfc822_email_bytes(plain_body="spec")) + outside_file = tmp_path / "escape.eml" + outside_file.write_bytes(_build_rfc822_email_bytes(plain_body="escape")) + + register_status, register_payload = invoke_request( + "POST", + f"/v0/task-workspaces/{workspace_payload['workspace']['id']}/artifacts", + payload={ + "user_id": str(owner["user_id"]), + "local_path": str(safe_file), + "media_type_hint": "message/rfc822", + }, + ) + assert register_status == 201 + + with psycopg.connect(migrated_database_urls["admin"]) as conn: + with conn.cursor() as cur: + cur.execute( + """ + UPDATE task_artifacts + SET relative_path = '../../../escape.eml' + WHERE id = %s + """, + (register_payload["artifact"]["id"],), + ) + conn.commit() + + ingest_status, ingest_payload = invoke_request( + "POST", + f"/v0/task-artifacts/{register_payload['artifact']['id']}/ingest", + payload={"user_id": str(owner["user_id"])}, + ) + + assert ingest_status == 400 + assert ingest_payload == { + 
"detail": f"artifact path {outside_file.resolve()} escapes workspace root {workspace_path.resolve()}" + } + + def test_task_artifact_chunk_retrieval_endpoints_are_scoped_deterministic_and_isolated( migrated_database_urls, monkeypatch, diff --git a/tests/unit/test_artifacts.py b/tests/unit/test_artifacts.py index 442a08d..df3de49 100644 --- a/tests/unit/test_artifacts.py +++ b/tests/unit/test_artifacts.py @@ -190,6 +190,116 @@ def _build_docx_bytes( return archive_buffer.getvalue() +def _build_rfc822_email_bytes( + *, + headers: list[tuple[str, str]] | None = None, + plain_body: str | None = None, + plain_parts: list[str] | None = None, + html_body: str | None = None, + attachment_text: str | None = None, + nested_message_bytes: bytes | None = None, + malformed_multipart: bool = False, +) -> bytes: + header_lines = [ + f"{name}: {value}" + for name, value in ( + headers + if headers is not None + else [ + ("From", "Alice "), + ("To", "Bob "), + ("Subject", "Sprint Update"), + ] + ) + ] + if malformed_multipart: + return ( + "\r\n".join( + [ + *header_lines, + "MIME-Version: 1.0", + "Content-Type: multipart/mixed", + "", + "--broken-boundary", + 'Content-Type: text/plain; charset="utf-8"', + "", + "broken", + "--broken-boundary--", + "", + ] + ).encode("utf-8") + ) + + if ( + plain_parts is None + and html_body is None + and attachment_text is None + and nested_message_bytes is None + ): + return ( + "\r\n".join( + [ + *header_lines, + 'Content-Type: text/plain; charset="utf-8"', + "Content-Transfer-Encoding: 8bit", + "", + plain_body or "", + ] + ).encode("utf-8") + ) + + boundary = "alicebot-boundary-001" + lines = [ + *header_lines, + "MIME-Version: 1.0", + f'Content-Type: multipart/mixed; boundary="{boundary}"', + "", + ] + for part_text in plain_parts or []: + lines.extend( + [ + f"--{boundary}", + 'Content-Type: text/plain; charset="utf-8"', + "Content-Transfer-Encoding: 8bit", + "", + part_text, + ] + ) + if html_body is not None: + lines.extend( + [ + f"--{boundary}", + 'Content-Type: text/html; charset="utf-8"', + "Content-Transfer-Encoding: 8bit", + "", + html_body, + ] + ) + if attachment_text is not None: + lines.extend( + [ + f"--{boundary}", + 'Content-Type: text/plain; charset="utf-8"', + 'Content-Disposition: attachment; filename="note.txt"', + "Content-Transfer-Encoding: 8bit", + "", + attachment_text, + ] + ) + if nested_message_bytes is not None: + lines.extend( + [ + f"--{boundary}", + "Content-Type: message/rfc822", + "Content-Transfer-Encoding: 8bit", + "", + nested_message_bytes.decode("utf-8"), + ] + ) + lines.extend([f"--{boundary}--", ""]) + return "\r\n".join(lines).encode("utf-8") + + class ArtifactStoreStub: def __init__(self) -> None: self.base_time = datetime(2026, 3, 13, 10, 0, tzinfo=UTC) @@ -785,6 +895,230 @@ def test_ingest_task_artifact_record_persists_deterministic_docx_chunks(tmp_path ] +def test_ingest_task_artifact_record_persists_deterministic_rfc822_chunks(tmp_path) -> None: + store = ArtifactStoreStub() + user_id = uuid4() + task_id = uuid4() + task_workspace_id = uuid4() + workspace_path = tmp_path / "workspaces" / str(user_id) / str(task_id) + workspace_path.mkdir(parents=True) + artifact_path = workspace_path / "mail" / "update.eml" + artifact_path.parent.mkdir(parents=True) + artifact_path.write_bytes( + _build_rfc822_email_bytes( + plain_body=("A" * 916) + "\r\n" + ("B" * 5) + "\rC", + ) + ) + store.create_task_workspace( + task_workspace_id=task_workspace_id, + task_id=task_id, + user_id=user_id, + local_path=str(workspace_path), + ) + 
artifact = store.create_task_artifact( + task_id=task_id, + task_workspace_id=task_workspace_id, + status="registered", + ingestion_status="pending", + relative_path="mail/update.eml", + media_type_hint="message/rfc822", + ) + + response = ingest_task_artifact_record( + store, + user_id=user_id, + request=TaskArtifactIngestInput(task_artifact_id=artifact["id"]), + ) + + header_block = ( + "From: Alice \n" + "To: Bob \n" + "Subject: Sprint Update\n\n" + ) + assert response == { + "artifact": { + "id": str(artifact["id"]), + "task_id": str(task_id), + "task_workspace_id": str(task_workspace_id), + "status": "registered", + "ingestion_status": "ingested", + "relative_path": "mail/update.eml", + "media_type_hint": "message/rfc822", + "created_at": "2026-03-13T10:00:00+00:00", + "updated_at": "2026-03-13T10:30:00+00:00", + }, + "summary": { + "total_count": 2, + "total_characters": 1006, + "media_type": "message/rfc822", + "chunking_rule": TASK_ARTIFACT_CHUNKING_RULE, + "order": ["sequence_no_asc", "id_asc"], + }, + } + assert store.locked_artifact_ids == [artifact["id"]] + assert store.list_task_artifact_chunks(artifact["id"]) == [ + { + "id": store.artifact_chunks[0]["id"], + "user_id": user_id, + "task_artifact_id": artifact["id"], + "sequence_no": 1, + "char_start": 0, + "char_end_exclusive": 1000, + "text": header_block + ("A" * 916) + "\n" + "B", + "created_at": datetime(2026, 3, 13, 10, 0, tzinfo=UTC), + "updated_at": datetime(2026, 3, 13, 10, 0, tzinfo=UTC), + }, + { + "id": store.artifact_chunks[1]["id"], + "user_id": user_id, + "task_artifact_id": artifact["id"], + "sequence_no": 2, + "char_start": 1000, + "char_end_exclusive": 1006, + "text": "BBBB\nC", + "created_at": datetime(2026, 3, 13, 10, 0, 1, tzinfo=UTC), + "updated_at": datetime(2026, 3, 13, 10, 0, 1, tzinfo=UTC), + }, + ] + + +def test_ingest_task_artifact_record_extracts_plain_text_parts_from_multipart_rfc822_email( + tmp_path, +) -> None: + store = ArtifactStoreStub() + user_id = uuid4() + task_id = uuid4() + task_workspace_id = uuid4() + workspace_path = tmp_path / "workspaces" / str(user_id) / str(task_id) + workspace_path.mkdir(parents=True) + artifact_path = workspace_path / "mail" / "multipart.eml" + artifact_path.parent.mkdir(parents=True) + artifact_path.write_bytes( + _build_rfc822_email_bytes( + plain_parts=["Alpha\r\nBeta", "Gamma"], + html_body="
ignored
", + attachment_text="ignored attachment", + ) + ) + store.create_task_workspace( + task_workspace_id=task_workspace_id, + task_id=task_id, + user_id=user_id, + local_path=str(workspace_path), + ) + artifact = store.create_task_artifact( + task_id=task_id, + task_workspace_id=task_workspace_id, + status="registered", + ingestion_status="pending", + relative_path="mail/multipart.eml", + media_type_hint="message/rfc822", + ) + + response = ingest_task_artifact_record( + store, + user_id=user_id, + request=TaskArtifactIngestInput(task_artifact_id=artifact["id"]), + ) + + assert response["summary"] == { + "total_count": 1, + "total_characters": 99, + "media_type": "message/rfc822", + "chunking_rule": TASK_ARTIFACT_CHUNKING_RULE, + "order": ["sequence_no_asc", "id_asc"], + } + assert store.list_task_artifact_chunks(artifact["id"]) == [ + { + "id": store.artifact_chunks[0]["id"], + "user_id": user_id, + "task_artifact_id": artifact["id"], + "sequence_no": 1, + "char_start": 0, + "char_end_exclusive": 99, + "text": ( + "From: Alice \n" + "To: Bob \n" + "Subject: Sprint Update\n\n" + "Alpha\nBeta\n\nGamma" + ), + "created_at": datetime(2026, 3, 13, 10, 0, tzinfo=UTC), + "updated_at": datetime(2026, 3, 13, 10, 0, tzinfo=UTC), + } + ] + + +def test_ingest_task_artifact_record_excludes_nested_rfc822_message_bodies(tmp_path) -> None: + store = ArtifactStoreStub() + user_id = uuid4() + task_id = uuid4() + task_workspace_id = uuid4() + workspace_path = tmp_path / "workspaces" / str(user_id) / str(task_id) + workspace_path.mkdir(parents=True) + artifact_path = workspace_path / "mail" / "forwarded.eml" + artifact_path.parent.mkdir(parents=True) + artifact_path.write_bytes( + _build_rfc822_email_bytes( + plain_parts=["Outer body"], + nested_message_bytes=_build_rfc822_email_bytes( + headers=[ + ("From", "Nested "), + ("To", "Team "), + ("Subject", "Nested"), + ], + plain_body="Inner body", + ), + ) + ) + store.create_task_workspace( + task_workspace_id=task_workspace_id, + task_id=task_id, + user_id=user_id, + local_path=str(workspace_path), + ) + artifact = store.create_task_artifact( + task_id=task_id, + task_workspace_id=task_workspace_id, + status="registered", + ingestion_status="pending", + relative_path="mail/forwarded.eml", + media_type_hint="message/rfc822", + ) + + response = ingest_task_artifact_record( + store, + user_id=user_id, + request=TaskArtifactIngestInput(task_artifact_id=artifact["id"]), + ) + + expected_text = ( + "From: Alice \n" + "To: Bob \n" + "Subject: Sprint Update\n\n" + "Outer body" + ) + assert response["summary"] == { + "total_count": 1, + "total_characters": len(expected_text), + "media_type": "message/rfc822", + "chunking_rule": TASK_ARTIFACT_CHUNKING_RULE, + "order": ["sequence_no_asc", "id_asc"], + } + assert store.list_task_artifact_chunks(artifact["id"]) == [ + { + "id": store.artifact_chunks[0]["id"], + "user_id": user_id, + "task_artifact_id": artifact["id"], + "sequence_no": 1, + "char_start": 0, + "char_end_exclusive": len(expected_text), + "text": expected_text, + "created_at": datetime(2026, 3, 13, 10, 0, tzinfo=UTC), + "updated_at": datetime(2026, 3, 13, 10, 0, tzinfo=UTC), + } + ] + + def test_ingest_task_artifact_record_is_idempotent_for_already_ingested_artifact() -> None: store = ArtifactStoreStub() user_id = uuid4() @@ -872,7 +1206,8 @@ def test_ingest_task_artifact_record_rejects_unsupported_media_type(tmp_path) -> match=( "artifact docs/spec.bin has unsupported media type application/octet-stream; " "supported types: text/plain, text/markdown, 
application/pdf, " - "application/vnd.openxmlformats-officedocument.wordprocessingml.document" + "application/vnd.openxmlformats-officedocument.wordprocessingml.document, " + "message/rfc822" ), ): ingest_task_artifact_record( @@ -990,6 +1325,78 @@ def test_ingest_task_artifact_record_rejects_malformed_docx(tmp_path) -> None: ) +def test_ingest_task_artifact_record_rejects_textless_rfc822_email(tmp_path) -> None: + store = ArtifactStoreStub() + user_id = uuid4() + task_id = uuid4() + task_workspace_id = uuid4() + workspace_path = tmp_path / "workspaces" / str(user_id) / str(task_id) + workspace_path.mkdir(parents=True) + artifact_path = workspace_path / "mail" / "empty.eml" + artifact_path.parent.mkdir(parents=True) + artifact_path.write_bytes(_build_rfc822_email_bytes(html_body="
html only
")) + store.create_task_workspace( + task_workspace_id=task_workspace_id, + task_id=task_id, + user_id=user_id, + local_path=str(workspace_path), + ) + artifact = store.create_task_artifact( + task_id=task_id, + task_workspace_id=task_workspace_id, + status="registered", + ingestion_status="pending", + relative_path="mail/empty.eml", + media_type_hint="message/rfc822", + ) + + with pytest.raises( + TaskArtifactValidationError, + match="artifact mail/empty.eml does not contain extractable RFC822 email text", + ): + ingest_task_artifact_record( + store, + user_id=user_id, + request=TaskArtifactIngestInput(task_artifact_id=artifact["id"]), + ) + + +def test_ingest_task_artifact_record_rejects_malformed_rfc822_email(tmp_path) -> None: + store = ArtifactStoreStub() + user_id = uuid4() + task_id = uuid4() + task_workspace_id = uuid4() + workspace_path = tmp_path / "workspaces" / str(user_id) / str(task_id) + workspace_path.mkdir(parents=True) + artifact_path = workspace_path / "mail" / "broken.eml" + artifact_path.parent.mkdir(parents=True) + artifact_path.write_bytes(_build_rfc822_email_bytes(malformed_multipart=True)) + store.create_task_workspace( + task_workspace_id=task_workspace_id, + task_id=task_id, + user_id=user_id, + local_path=str(workspace_path), + ) + artifact = store.create_task_artifact( + task_id=task_id, + task_workspace_id=task_workspace_id, + status="registered", + ingestion_status="pending", + relative_path="mail/broken.eml", + media_type_hint="message/rfc822", + ) + + with pytest.raises( + TaskArtifactValidationError, + match="artifact mail/broken.eml is not a valid RFC822 email", + ): + ingest_task_artifact_record( + store, + user_id=user_id, + request=TaskArtifactIngestInput(task_artifact_id=artifact["id"]), + ) + + def test_ingest_task_artifact_record_rejects_invalid_utf8_content(tmp_path) -> None: store = ArtifactStoreStub() user_id = uuid4() @@ -1090,6 +1497,38 @@ def test_ingest_task_artifact_record_rejects_docx_paths_outside_workspace(tmp_pa ) +def test_ingest_task_artifact_record_rejects_rfc822_paths_outside_workspace(tmp_path) -> None: + store = ArtifactStoreStub() + user_id = uuid4() + task_id = uuid4() + task_workspace_id = uuid4() + workspace_path = tmp_path / "workspaces" / str(user_id) / str(task_id) + workspace_path.mkdir(parents=True) + outside_path = tmp_path / "escape.eml" + outside_path.write_bytes(_build_rfc822_email_bytes(plain_body="escape")) + store.create_task_workspace( + task_workspace_id=task_workspace_id, + task_id=task_id, + user_id=user_id, + local_path=str(workspace_path), + ) + artifact = store.create_task_artifact( + task_id=task_id, + task_workspace_id=task_workspace_id, + status="registered", + ingestion_status="pending", + relative_path="../escape.eml", + media_type_hint="message/rfc822", + ) + + with pytest.raises(TaskArtifactValidationError, match="escapes workspace root"): + ingest_task_artifact_record( + store, + user_id=user_id, + request=TaskArtifactIngestInput(task_artifact_id=artifact["id"]), + ) + + def test_list_task_artifact_chunk_records_are_deterministic() -> None: store = ArtifactStoreStub() user_id = uuid4() diff --git a/tests/unit/test_artifacts_main.py b/tests/unit/test_artifacts_main.py index dac0244..439d47f 100644 --- a/tests/unit/test_artifacts_main.py +++ b/tests/unit/test_artifacts_main.py @@ -499,7 +499,8 @@ def fake_ingest_task_artifact_record(*_args, **_kwargs): raise TaskArtifactValidationError( "artifact docs/spec.bin has unsupported media type application/octet-stream; " "supported types: text/plain, 
text/markdown, application/pdf, " - "application/vnd.openxmlformats-officedocument.wordprocessingml.document" + "application/vnd.openxmlformats-officedocument.wordprocessingml.document, " + "message/rfc822" ) monkeypatch.setattr(main_module, "get_settings", lambda: settings) @@ -516,7 +517,8 @@ def fake_ingest_task_artifact_record(*_args, **_kwargs): "detail": ( "artifact docs/spec.bin has unsupported media type application/octet-stream; " "supported types: text/plain, text/markdown, application/pdf, " - "application/vnd.openxmlformats-officedocument.wordprocessingml.document" + "application/vnd.openxmlformats-officedocument.wordprocessingml.document, " + "message/rfc822" ) } diff --git a/tests/unit/test_semantic_retrieval.py b/tests/unit/test_semantic_retrieval.py index 4404e47..3a3058a 100644 --- a/tests/unit/test_semantic_retrieval.py +++ b/tests/unit/test_semantic_retrieval.py @@ -509,3 +509,51 @@ def test_retrieve_task_scoped_semantic_artifact_chunk_records_infers_docx_media_ "score": 0.9, } ] + + +def test_retrieve_task_scoped_semantic_artifact_chunk_records_infers_rfc822_media_type_without_hint() -> None: + store = SemanticRetrievalStoreStub() + config_id = seed_config(store, dimensions=3) + task_id = seed_task(store) + artifact_id = seed_artifact( + store, + task_id=task_id, + relative_path="mail/update.eml", + media_type_hint=None, + ) + email_row = semantic_artifact_row( + store, + task_id=task_id, + task_artifact_id=artifact_id, + relative_path="mail/update.eml", + score=0.85, + sequence_no=1, + ) + email_row["media_type_hint"] = None + store.task_artifact_retrieval_rows = [email_row] + + payload = retrieve_task_scoped_semantic_artifact_chunk_records( + store, # type: ignore[arg-type] + user_id=uuid4(), + request=TaskScopedSemanticArtifactChunkRetrievalInput( + task_id=task_id, + embedding_config_id=config_id, + query_vector=(1.0, 0.0, 0.0), + limit=1, + ), + ) + + assert payload["items"] == [ + { + "id": str(email_row["id"]), + "task_id": str(task_id), + "task_artifact_id": str(artifact_id), + "relative_path": "mail/update.eml", + "media_type": "message/rfc822", + "sequence_no": 1, + "char_start": 0, + "char_end_exclusive": 11, + "text": "mail/update.eml-chunk", + "score": 0.85, + } + ] From a65f00ebb49515ea67fa2bceda98c99bc4cddaa2 Mon Sep 17 00:00:00 2001 From: Sam Date: Mon, 16 Mar 2026 13:46:37 +0100 Subject: [PATCH 015/135] Sprint 5O: Gmail connection and single-message ingestion (#15) * Sprint 5O: Gmail connection and single-message ingestion packet * Sprint 5O: Gmail connection and single-message ingestion --------- Co-authored-by: Sami Rusani --- ARCHITECTURE.md | 27 +- .../versions/20260316_0026_gmail_accounts.py | 82 +++ apps/api/src/alicebot_api/artifacts.py | 21 +- apps/api/src/alicebot_api/contracts.py | 63 +++ apps/api/src/alicebot_api/gmail.py | 284 ++++++++++ apps/api/src/alicebot_api/main.py | 130 +++++ apps/api/src/alicebot_api/store.py | 122 ++++ tests/integration/test_gmail_accounts_api.py | 525 ++++++++++++++++++ .../unit/test_20260316_0026_gmail_accounts.py | 46 ++ tests/unit/test_gmail.py | 475 ++++++++++++++++ tests/unit/test_gmail_main.py | 192 +++++++ 11 files changed, 1954 insertions(+), 13 deletions(-) create mode 100644 apps/api/alembic/versions/20260316_0026_gmail_accounts.py create mode 100644 apps/api/src/alicebot_api/gmail.py create mode 100644 tests/integration/test_gmail_accounts_api.py create mode 100644 tests/unit/test_20260316_0026_gmail_accounts.py create mode 100644 tests/unit/test_gmail.py create mode 100644 tests/unit/test_gmail_main.py 
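For reference, the chunk boundaries asserted in the RFC822 tests above (a first chunk spanning characters 0 to 1000 and a second spanning 1000 to 1006) follow directly from the documented `normalized_utf8_text_fixed_window_1000_chars_v1` rule: `\r\n` and bare `\r` are rewritten to `\n`, and the normalized text is sliced into fixed 1000-character windows with `sequence_no`, `char_start`, and `char_end_exclusive` recorded per window. The sketch below only illustrates that rule as described; the function name `chunk_normalized_text` is illustrative and is not part of the shipped `artifacts.py` module.

```python
def chunk_normalized_text(text: str, window: int = 1000) -> list[dict[str, object]]:
    # Illustrative sketch of normalized_utf8_text_fixed_window_1000_chars_v1,
    # not the shipped implementation.
    normalized = text.replace("\r\n", "\n").replace("\r", "\n")
    chunks: list[dict[str, object]] = []
    for sequence_no, char_start in enumerate(range(0, len(normalized), window), start=1):
        char_end_exclusive = min(char_start + window, len(normalized))
        chunks.append(
            {
                "sequence_no": sequence_no,
                "char_start": char_start,
                "char_end_exclusive": char_end_exclusive,
                "text": normalized[char_start:char_end_exclusive],
            }
        )
    return chunks
```

Applied to a 1006-character normalized email, this yields exactly the two rows the integration and unit tests expect: `(1, 0, 1000)` and `(2, 1000, 1006)`.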
diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md index acb2a51..007de79 100644 --- a/ARCHITECTURE.md +++ b/ARCHITECTURE.md @@ -2,16 +2,17 @@ ## Current Implemented Slice -AliceBot now implements the accepted repo slice through Sprint 5N. The shipped backend includes: +AliceBot now implements the accepted repo slice through Sprint 5O. The shipped backend includes: - foundation continuity storage over `users`, `threads`, `sessions`, and append-only `events` - deterministic tracing and context compilation over durable continuity, memory, entity, and entity-edge records - governed memory admission, explicit-preference extraction, memory review labels, review queue reads, evaluation summary reads, explicit embedding config and memory-embedding storage, direct semantic retrieval, and deterministic hybrid compile-path memory merge - deterministic prompt assembly and one no-tools response path that persists assistant replies as immutable continuity events - user-scoped consents, policies, policy evaluation, tool registry, allowlist evaluation, tool routing, approval request persistence, approval resolution, approved-only proxy execution through the in-process `proxy.echo` handler, durable execution review, and execution-budget lifecycle plus enforcement +- a narrow read-only Gmail connector seam with user-scoped `gmail_accounts` persistence, deterministic account reads, and one explicit selected-message ingestion path that materializes one Gmail message as a rooted `.eml` task artifact and then reuses the existing RFC822 artifact ingestion pipeline - durable `tasks`, `task_steps`, `task_workspaces`, `task_artifacts`, `task_artifact_chunks`, and `task_artifact_chunk_embeddings`, deterministic task-step sequencing, explicit task-step transitions, explicit manual continuation with lineage through `parent_step_id`, `source_approval_id`, and `source_execution_id`, explicit `tool_executions.task_step_id` linkage for execution synchronization, deterministic rooted local task-workspace provisioning, explicit rooted local artifact registration, deterministic local plain-text, markdown, narrow PDF text, narrow DOCX text, and narrow RFC822 email text ingestion into durable chunk rows, deterministic lexical artifact-chunk retrieval over durable chunk rows, explicit user-scoped artifact-chunk embedding persistence tied to existing embedding configs, explicit task-scoped or artifact-scoped semantic artifact-chunk retrieval over those durable embeddings, and compile-path artifact retrieval that can include lexical results, semantic results, or one deterministic hybrid lexical-plus-semantic merged artifact section with per-chunk source provenance -The current multi-step boundary is narrow and explicit. Manual continuation is implemented and review-passed. Approval resolution and proxy execution now both use explicit task-step linkage rather than first-step inference. Task workspaces are now implemented only as deterministic rooted local boundaries, and task artifacts are now implemented only as explicit rooted local-file registrations, narrow deterministic artifact ingestion under those workspaces, lexical retrieval over persisted chunk rows, explicit artifact-chunk embedding storage tied to existing embedding configs, direct semantic retrieval over those durable artifact-chunk embeddings for one visible task or one visible artifact at a time, and compile-path artifact retrieval that deterministically merges lexical and semantic candidates into one artifact section when both are requested for the same scope. 
The live richer-document boundary is still intentionally narrow: plain text and markdown ingest directly, PDF support is limited to narrow local text extraction, DOCX support is limited to narrow local text extraction from `word/document.xml`, and RFC822 email support is limited to top-level selected headers plus extractable plain-text body content while excluding nested `message/rfc822` content; OCR, image extraction, layout reconstruction, connectors, reranking beyond the current lexical-first hybrid merge, and new side-effect surfaces are still planned later and must not be described as live behavior. +The current multi-step boundary is narrow and explicit. Manual continuation is implemented and review-passed. Approval resolution and proxy execution now both use explicit task-step linkage rather than first-step inference. Task workspaces are now implemented only as deterministic rooted local boundaries, and task artifacts are now implemented only as explicit rooted local-file registrations, narrow deterministic artifact ingestion under those workspaces, lexical retrieval over persisted chunk rows, explicit artifact-chunk embedding storage tied to existing embedding configs, direct semantic retrieval over those durable artifact-chunk embeddings for one visible task or one visible artifact at a time, and compile-path artifact retrieval that deterministically merges lexical and semantic candidates into one artifact section when both are requested for the same scope. The live richer-document boundary is still intentionally narrow: plain text and markdown ingest directly, PDF support is limited to narrow local text extraction, DOCX support is limited to narrow local text extraction from `word/document.xml`, RFC822 email support is limited to top-level selected headers plus extractable plain-text body content while excluding nested `message/rfc822` content, and the live connector boundary is limited to one read-only Gmail account seam plus one explicit selected-message ingestion path into the rooted RFC822 artifact pipeline. OCR, image extraction, layout reconstruction, Gmail search, mailbox sync, attachments, Calendar connectors, reranking beyond the current lexical-first hybrid merge, and new side-effect surfaces are still planned later and must not be described as live behavior. ## Implemented Now @@ -24,7 +25,7 @@ The current multi-step boundary is narrow and explicit. 
Manual continuation is i - memory and retrieval: `POST /v0/memories/admit`, `POST /v0/memories/extract-explicit-preferences`, `GET /v0/memories`, `GET /v0/memories/review-queue`, `GET /v0/memories/evaluation-summary`, `POST /v0/memories/semantic-retrieval`, `GET /v0/memories/{memory_id}`, `GET /v0/memories/{memory_id}/revisions`, `POST /v0/memories/{memory_id}/labels`, `GET /v0/memories/{memory_id}/labels` - embeddings and graph seams: `POST /v0/embedding-configs`, `GET /v0/embedding-configs`, `POST /v0/memory-embeddings`, `GET /v0/memories/{memory_id}/embeddings`, `GET /v0/memory-embeddings/{memory_embedding_id}`, `POST /v0/entities`, `GET /v0/entities`, `GET /v0/entities/{entity_id}`, `POST /v0/entity-edges`, `GET /v0/entities/{entity_id}/edges` - governance: `POST /v0/consents`, `GET /v0/consents`, `POST /v0/policies`, `GET /v0/policies`, `GET /v0/policies/{policy_id}`, `POST /v0/policies/evaluate`, `POST /v0/tools`, `GET /v0/tools`, `GET /v0/tools/{tool_id}`, `POST /v0/tools/allowlist/evaluate`, `POST /v0/tools/route`, `POST /v0/approvals/requests`, `GET /v0/approvals`, `GET /v0/approvals/{approval_id}`, `POST /v0/approvals/{approval_id}/approve`, `POST /v0/approvals/{approval_id}/reject`, `POST /v0/approvals/{approval_id}/execute` -- task and execution review: `GET /v0/tasks`, `GET /v0/tasks/{task_id}`, `POST /v0/tasks/{task_id}/workspace`, `GET /v0/task-workspaces`, `GET /v0/task-workspaces/{task_workspace_id}`, `POST /v0/task-workspaces/{task_workspace_id}/artifacts`, `GET /v0/task-artifacts`, `GET /v0/task-artifacts/{task_artifact_id}`, `POST /v0/task-artifacts/{task_artifact_id}/ingest`, `GET /v0/task-artifacts/{task_artifact_id}/chunks`, `POST /v0/tasks/{task_id}/artifact-chunks/retrieve`, `POST /v0/task-artifacts/{task_artifact_id}/chunks/retrieve`, `POST /v0/tasks/{task_id}/artifact-chunks/semantic-retrieval`, `POST /v0/task-artifacts/{task_artifact_id}/chunks/semantic-retrieval`, `POST /v0/task-artifact-chunk-embeddings`, `GET /v0/task-artifacts/{task_artifact_id}/chunk-embeddings`, `GET /v0/task-artifact-chunks/{task_artifact_chunk_id}/embeddings`, `GET /v0/task-artifact-chunk-embeddings/{task_artifact_chunk_embedding_id}`, `GET /v0/tasks/{task_id}/steps`, `GET /v0/task-steps/{task_step_id}`, `POST /v0/tasks/{task_id}/steps`, `POST /v0/task-steps/{task_step_id}/transition`, `POST /v0/execution-budgets`, `GET /v0/execution-budgets`, `GET /v0/execution-budgets/{execution_budget_id}`, `POST /v0/execution-budgets/{execution_budget_id}/deactivate`, `POST /v0/execution-budgets/{execution_budget_id}/supersede`, `GET /v0/tool-executions`, `GET /v0/tool-executions/{execution_id}` + - Gmail and task execution review: `POST /v0/gmail-accounts`, `GET /v0/gmail-accounts`, `GET /v0/gmail-accounts/{gmail_account_id}`, `POST /v0/gmail-accounts/{gmail_account_id}/messages/{provider_message_id}/ingest`, `GET /v0/tasks`, `GET /v0/tasks/{task_id}`, `POST /v0/tasks/{task_id}/workspace`, `GET /v0/task-workspaces`, `GET /v0/task-workspaces/{task_workspace_id}`, `POST /v0/task-workspaces/{task_workspace_id}/artifacts`, `GET /v0/task-artifacts`, `GET /v0/task-artifacts/{task_artifact_id}`, `POST /v0/task-artifacts/{task_artifact_id}/ingest`, `GET /v0/task-artifacts/{task_artifact_id}/chunks`, `POST /v0/tasks/{task_id}/artifact-chunks/retrieve`, `POST /v0/task-artifacts/{task_artifact_id}/chunks/retrieve`, `POST /v0/tasks/{task_id}/artifact-chunks/semantic-retrieval`, `POST /v0/task-artifacts/{task_artifact_id}/chunks/semantic-retrieval`, `POST /v0/task-artifact-chunk-embeddings`, `GET 
/v0/task-artifacts/{task_artifact_id}/chunk-embeddings`, `GET /v0/task-artifact-chunks/{task_artifact_chunk_id}/embeddings`, `GET /v0/task-artifact-chunk-embeddings/{task_artifact_chunk_embedding_id}`, `GET /v0/tasks/{task_id}/steps`, `GET /v0/task-steps/{task_step_id}`, `POST /v0/tasks/{task_id}/steps`, `POST /v0/task-steps/{task_step_id}/transition`, `POST /v0/execution-budgets`, `GET /v0/execution-budgets`, `GET /v0/execution-budgets/{execution_budget_id}`, `POST /v0/execution-budgets/{execution_budget_id}/deactivate`, `POST /v0/execution-budgets/{execution_budget_id}/supersede`, `GET /v0/tool-executions`, `GET /v0/tool-executions/{execution_id}` - `apps/web` and `workers` remain starter shells only. ### Data Foundation @@ -37,6 +38,7 @@ The current multi-step boundary is narrow and explicit. Manual continuation is i - memory and retrieval tables: `memories`, `memory_revisions`, `memory_review_labels`, `embedding_configs`, `memory_embeddings` - graph tables: `entities`, `entity_edges` - governance tables: `consents`, `policies`, `tools`, `approvals`, `tool_executions`, `execution_budgets` + - connector tables: `gmail_accounts` - task lifecycle tables: `tasks`, `task_steps`, `task_workspaces`, `task_artifacts`, `task_artifact_chunks`, `task_artifact_chunk_embeddings` - `events`, `trace_events`, and `memory_revisions` are append-only by application contract and database enforcement. - `memory_review_labels` are append-only by database enforcement. @@ -58,11 +60,11 @@ The current multi-step boundary is narrow and explicit. Manual continuation is i ### Repo Boundaries In This Slice -- `apps/api`: implemented API, store, contracts, service logic, and migrations for continuity, tracing, memory, embeddings, entities, policies, tools, approvals, proxy execution, execution budgets, tasks, task steps, task workspaces, task artifacts, artifact-chunk embeddings, deterministic lexical artifact chunk retrieval, deterministic semantic artifact chunk retrieval over durable embeddings, compile-path semantic artifact retrieval, and deterministic hybrid lexical-plus-semantic artifact merge in compile. +- `apps/api`: implemented API, store, contracts, service logic, and migrations for continuity, tracing, memory, embeddings, entities, policies, tools, approvals, proxy execution, execution budgets, a narrow read-only Gmail connector seam, tasks, task steps, task workspaces, task artifacts, artifact-chunk embeddings, deterministic lexical artifact chunk retrieval, deterministic semantic artifact chunk retrieval over durable embeddings, compile-path semantic artifact retrieval, and deterministic hybrid lexical-plus-semantic artifact merge in compile. - `apps/web`: minimal shell only; no shipped workflow UI. - `workers`: scaffold only; no background jobs or runner logic are implemented. - `infra`: local development bootstrap assets only. 
-- `tests`: unit and Postgres-backed integration coverage for the shipped seams above, including Sprint 4O task-step lineage/manual continuation, Sprint 4S step-linked execution synchronization, Sprint 5A task-workspace provisioning, Sprint 5C task-artifact registration, Sprint 5D local artifact ingestion plus chunk reads, Sprint 5E lexical artifact-chunk retrieval, Sprint 5F compile-path artifact chunk integration, Sprint 5G artifact-chunk embedding persistence and reads, Sprint 5H direct semantic artifact-chunk retrieval, Sprint 5I compile-path semantic artifact retrieval, Sprint 5J deterministic hybrid lexical-plus-semantic artifact merge in compile, Sprint 5L narrow PDF artifact ingestion, Sprint 5M narrow DOCX artifact ingestion, and Sprint 5N narrow RFC822 email artifact ingestion. +- `tests`: unit and Postgres-backed integration coverage for the shipped seams above, including Sprint 4O task-step lineage/manual continuation, Sprint 4S step-linked execution synchronization, Sprint 5A task-workspace provisioning, Sprint 5C task-artifact registration, Sprint 5D local artifact ingestion plus chunk reads, Sprint 5E lexical artifact-chunk retrieval, Sprint 5F compile-path artifact chunk integration, Sprint 5G artifact-chunk embedding persistence and reads, Sprint 5H direct semantic artifact-chunk retrieval, Sprint 5I compile-path semantic artifact retrieval, Sprint 5J deterministic hybrid lexical-plus-semantic artifact merge in compile, Sprint 5L narrow PDF artifact ingestion, Sprint 5M narrow DOCX artifact ingestion, Sprint 5N narrow RFC822 email artifact ingestion, and Sprint 5O read-only Gmail account plus single-message ingestion coverage. ## Core Flows Implemented Now @@ -90,6 +92,19 @@ The current multi-step boundary is narrow and explicit. Manual continuation is i 8. Order hybrid compile candidates deterministically by source precedence, lexical rank, semantic rank, `relative_path`, `sequence_no`, and `id`. 9. Return stable summary metadata covering scope, query terms, embedding config, query-vector dimensions, candidate counts, deduplication counts, inclusion counts, exclusion counts, and ordering rules. +### Narrow Gmail Connector + +1. Accept a user-scoped `POST /v0/gmail-accounts` request for one read-only Gmail account metadata record. +2. Persist exactly the narrow connector metadata currently required for reads later: `provider_account_id`, `email_address`, optional `display_name`, the fixed Gmail read-only scope, and one access token. +3. Expose deterministic user-scoped Gmail account list and detail reads. +4. Accept a user-scoped `POST /v0/gmail-accounts/{gmail_account_id}/messages/{provider_message_id}/ingest` request for one visible Gmail account and one visible task workspace. +5. Derive one deterministic workspace-relative artifact path as `gmail//.eml`. +6. Reject duplicate `(task_workspace_id, relative_path)` collisions before any Gmail fetch or file write. +7. Fetch exactly one selected Gmail message through the read-only Gmail API path `users/me/messages/{provider_message_id}?format=raw`. +8. Require Gmail to return RFC822 `raw` content, validate it against the existing narrow `message/rfc822` extraction rules, and reject unsupported content deterministically. +9. Materialize the message as one rooted `.eml` file inside the selected task workspace and then reuse the existing task-artifact registration plus artifact-ingestion seam. +10. 
Persist only the resulting `task_artifacts` and `task_artifact_chunks` rows; account-wide sync, search, attachments, Calendar, and write-capable actions remain out of scope. + ### Governed Memory And Retrieval 1. Accept explicit memory candidates through `POST /v0/memories/admit`. @@ -284,7 +299,7 @@ The following areas remain planned later and must not be described as implemente - runner-style orchestration and automatic multi-step progression beyond the current explicit manual continuation seam - artifact reranking, weighted fusion, or precedence changes beyond the current lexical-first hybrid compile merge and direct lexical/direct semantic ordering seams - rich document parsing beyond the current narrow UTF-8 text and markdown ingestion boundary -- read-only Gmail and Calendar connectors +- Gmail search, mailbox sync, attachment ingestion, write-capable Gmail actions, and Calendar connectors - broader tool proxying and real-world side effects beyond the current no-I/O `proxy.echo` handler - model-driven extraction, reranking, and broader memory review automation - production deployment automation beyond the local developer stack diff --git a/apps/api/alembic/versions/20260316_0026_gmail_accounts.py b/apps/api/alembic/versions/20260316_0026_gmail_accounts.py new file mode 100644 index 0000000..fce7b4c --- /dev/null +++ b/apps/api/alembic/versions/20260316_0026_gmail_accounts.py @@ -0,0 +1,82 @@ +"""Add user-scoped Gmail account records.""" + +from __future__ import annotations + +from alembic import op + + +revision = "20260316_0026" +down_revision = "20260314_0025" +branch_labels = None +depends_on = None + +GMAIL_READONLY_SCOPE = "https://www.googleapis.com/auth/gmail.readonly" + +_RLS_TABLES = ("gmail_accounts",) + +_UPGRADE_SCHEMA_STATEMENT = f""" + CREATE TABLE gmail_accounts ( + id uuid PRIMARY KEY DEFAULT gen_random_uuid(), + user_id uuid NOT NULL REFERENCES users(id) ON DELETE CASCADE, + provider_account_id text NOT NULL, + email_address text NOT NULL, + display_name text, + scope text NOT NULL, + access_token text NOT NULL, + created_at timestamptz NOT NULL DEFAULT now(), + updated_at timestamptz NOT NULL DEFAULT now(), + UNIQUE (id, user_id), + CONSTRAINT gmail_accounts_provider_account_id_nonempty_check + CHECK (length(provider_account_id) > 0), + CONSTRAINT gmail_accounts_email_address_nonempty_check + CHECK (length(email_address) > 0), + CONSTRAINT gmail_accounts_display_name_nonempty_check + CHECK (display_name IS NULL OR length(display_name) > 0), + CONSTRAINT gmail_accounts_scope_readonly_check + CHECK (scope = '{GMAIL_READONLY_SCOPE}'), + CONSTRAINT gmail_accounts_access_token_nonempty_check + CHECK (length(access_token) > 0) + ); + + CREATE INDEX gmail_accounts_user_created_idx + ON gmail_accounts (user_id, created_at, id); + + CREATE UNIQUE INDEX gmail_accounts_provider_account_idx + ON gmail_accounts (user_id, provider_account_id); + """ + +_UPGRADE_GRANT_STATEMENTS = ( + "GRANT SELECT, INSERT ON gmail_accounts TO alicebot_app", +) + +_UPGRADE_POLICY_STATEMENT = """ + CREATE POLICY gmail_accounts_is_owner ON gmail_accounts + USING (user_id = app.current_user_id()) + WITH CHECK (user_id = app.current_user_id()); + """ + +_DOWNGRADE_STATEMENTS = ( + "DROP TABLE IF EXISTS gmail_accounts", +) + + +def _execute_statements(statements: tuple[str, ...]) -> None: + for statement in statements: + op.execute(statement) + + +def _enable_row_level_security() -> None: + for table_name in _RLS_TABLES: + op.execute(f"ALTER TABLE {table_name} ENABLE ROW LEVEL SECURITY") + 
op.execute(f"ALTER TABLE {table_name} FORCE ROW LEVEL SECURITY") + + +def upgrade() -> None: + op.execute(_UPGRADE_SCHEMA_STATEMENT) + _execute_statements(_UPGRADE_GRANT_STATEMENTS) + _enable_row_level_security() + op.execute(_UPGRADE_POLICY_STATEMENT) + + +def downgrade() -> None: + _execute_statements(_DOWNGRADE_STATEMENTS) diff --git a/apps/api/src/alicebot_api/artifacts.py b/apps/api/src/alicebot_api/artifacts.py index 72bab4f..9653594 100644 --- a/apps/api/src/alicebot_api/artifacts.py +++ b/apps/api/src/alicebot_api/artifacts.py @@ -891,30 +891,37 @@ def _extract_text_from_pdf_artifact_bytes(*, relative_path: str, payload: bytes) return extracted_text -def extract_artifact_text(*, row: TaskArtifactRow, artifact_path: Path, media_type: str) -> str: - payload = artifact_path.read_bytes() +def extract_artifact_text_from_bytes(*, relative_path: str, payload: bytes, media_type: str) -> str: if media_type in SUPPORTED_TEXT_ARTIFACT_MEDIA_TYPES: return _extract_text_from_utf8_artifact_bytes( - relative_path=row["relative_path"], + relative_path=relative_path, payload=payload, ) if media_type == SUPPORTED_PDF_ARTIFACT_MEDIA_TYPE: return _extract_text_from_pdf_artifact_bytes( - relative_path=row["relative_path"], + relative_path=relative_path, payload=payload, ) if media_type == SUPPORTED_DOCX_ARTIFACT_MEDIA_TYPE: return _extract_text_from_docx_artifact_bytes( - relative_path=row["relative_path"], + relative_path=relative_path, payload=payload, ) if media_type == SUPPORTED_RFC822_ARTIFACT_MEDIA_TYPE: return _extract_text_from_rfc822_artifact_bytes( - relative_path=row["relative_path"], + relative_path=relative_path, payload=payload, ) raise TaskArtifactValidationError( - f"artifact {row['relative_path']} has unsupported media type {media_type}" + f"artifact {relative_path} has unsupported media type {media_type}" + ) + + +def extract_artifact_text(*, row: TaskArtifactRow, artifact_path: Path, media_type: str) -> str: + return extract_artifact_text_from_bytes( + relative_path=row["relative_path"], + payload=artifact_path.read_bytes(), + media_type=media_type, ) diff --git a/apps/api/src/alicebot_api/contracts.py b/apps/api/src/alicebot_api/contracts.py index 7362392..d614b16 100644 --- a/apps/api/src/alicebot_api/contracts.py +++ b/apps/api/src/alicebot_api/contracts.py @@ -136,6 +136,7 @@ APPROVAL_LIST_ORDER = ["created_at_asc", "id_asc"] TASK_LIST_ORDER = ["created_at_asc", "id_asc"] TASK_WORKSPACE_LIST_ORDER = ["created_at_asc", "id_asc"] +GMAIL_ACCOUNT_LIST_ORDER = ["created_at_asc", "id_asc"] TASK_ARTIFACT_LIST_ORDER = ["created_at_asc", "id_asc"] TASK_ARTIFACT_CHUNK_LIST_ORDER = ["sequence_no_asc", "id_asc"] TASK_ARTIFACT_CHUNK_EMBEDDING_LIST_ORDER = [ @@ -174,6 +175,9 @@ TRACE_KIND_APPROVAL_RESOLVE = TRACE_KIND_APPROVAL_RESOLUTION PROXY_EXECUTION_VERSION_V0 = "proxy_execution_v0" TRACE_KIND_PROXY_EXECUTE = "tool.proxy.execute" +GMAIL_PROVIDER = "gmail" +GMAIL_AUTH_KIND_OAUTH_ACCESS_TOKEN = "oauth_access_token" +GMAIL_READONLY_SCOPE = "https://www.googleapis.com/auth/gmail.readonly" TASK_STEP_SEQUENCE_VERSION_V0 = "task_step_sequence_v0" TRACE_KIND_TASK_STEP_SEQUENCE = "task.step.sequence" TASK_STEP_CONTINUATION_VERSION_V0 = "task_step_continuation_v0" @@ -1769,6 +1773,65 @@ class TaskDetailResponse(TypedDict): task: TaskRecord +@dataclass(frozen=True, slots=True) +class GmailAccountConnectInput: + provider_account_id: str + email_address: str + display_name: str | None + scope: str + access_token: str + + +@dataclass(frozen=True, slots=True) +class GmailMessageIngestInput: + 
gmail_account_id: UUID + task_workspace_id: UUID + provider_message_id: str + + +class GmailAccountRecord(TypedDict): + id: str + provider: str + auth_kind: str + provider_account_id: str + email_address: str + display_name: str | None + scope: str + created_at: str + updated_at: str + + +class GmailAccountConnectResponse(TypedDict): + account: GmailAccountRecord + + +class GmailAccountListSummary(TypedDict): + total_count: int + order: list[str] + + +class GmailAccountListResponse(TypedDict): + items: list[GmailAccountRecord] + summary: GmailAccountListSummary + + +class GmailAccountDetailResponse(TypedDict): + account: GmailAccountRecord + + +class GmailMessageIngestionRecord(TypedDict): + provider_message_id: str + artifact_relative_path: str + media_type: str + + +class GmailMessageIngestionResponse(TypedDict): + account: GmailAccountRecord + message: GmailMessageIngestionRecord + artifact: TaskArtifactRecord + summary: TaskArtifactChunkListSummary + + @dataclass(frozen=True, slots=True) class TaskWorkspaceCreateInput: task_id: UUID diff --git a/apps/api/src/alicebot_api/gmail.py b/apps/api/src/alicebot_api/gmail.py new file mode 100644 index 0000000..6243916 --- /dev/null +++ b/apps/api/src/alicebot_api/gmail.py @@ -0,0 +1,284 @@ +from __future__ import annotations + +import base64 +import json +import re +from pathlib import Path +from urllib.error import HTTPError, URLError +from urllib.parse import quote +from urllib.request import Request, urlopen +from uuid import UUID + +import psycopg + +from alicebot_api.artifacts import ( + SUPPORTED_RFC822_ARTIFACT_MEDIA_TYPE, + TaskArtifactAlreadyExistsError, + TaskArtifactValidationError, + ensure_artifact_path_is_rooted, + extract_artifact_text_from_bytes, + ingest_task_artifact_record, + register_task_artifact_record, +) +from alicebot_api.contracts import ( + GMAIL_ACCOUNT_LIST_ORDER, + GMAIL_AUTH_KIND_OAUTH_ACCESS_TOKEN, + GMAIL_PROVIDER, + GMAIL_READONLY_SCOPE, + GmailAccountConnectInput, + GmailAccountConnectResponse, + GmailAccountDetailResponse, + GmailAccountListResponse, + GmailAccountRecord, + GmailMessageIngestInput, + GmailMessageIngestionResponse, + TaskArtifactIngestInput, + TaskArtifactRegisterInput, +) +from alicebot_api.store import ContinuityStore, GmailAccountRow +from alicebot_api.workspaces import TaskWorkspaceNotFoundError + +GMAIL_MESSAGE_FETCH_TIMEOUT_SECONDS = 30 +GMAIL_MESSAGE_ARTIFACT_ROOT = "gmail" +_PATH_SEGMENT_PATTERN = re.compile(r"[^A-Za-z0-9._-]+") + + +class GmailAccountNotFoundError(LookupError): + """Raised when a Gmail account is not visible inside the current user scope.""" + + +class GmailAccountAlreadyExistsError(RuntimeError): + """Raised when the same provider account is connected twice for one user.""" + + +class GmailMessageNotFoundError(LookupError): + """Raised when a Gmail message cannot be found in the current account.""" + + +class GmailMessageUnsupportedError(ValueError): + """Raised when Gmail content cannot be converted into the RFC822 artifact seam.""" + + +class GmailMessageFetchError(RuntimeError): + """Raised when the Gmail API call fails for non-deterministic upstream reasons.""" + + +def serialize_gmail_account_row(row: GmailAccountRow) -> GmailAccountRecord: + return { + "id": str(row["id"]), + "provider": GMAIL_PROVIDER, + "auth_kind": GMAIL_AUTH_KIND_OAUTH_ACCESS_TOKEN, + "provider_account_id": row["provider_account_id"], + "email_address": row["email_address"], + "display_name": row["display_name"], + "scope": row["scope"], + "created_at": row["created_at"].isoformat(), + 
"updated_at": row["updated_at"].isoformat(), + } + + +def create_gmail_account_record( + store: ContinuityStore, + *, + user_id: UUID, + request: GmailAccountConnectInput, +) -> GmailAccountConnectResponse: + del user_id + + existing = store.get_gmail_account_by_provider_account_id_optional(request.provider_account_id) + if existing is not None: + raise GmailAccountAlreadyExistsError( + f"gmail account {request.provider_account_id} is already connected" + ) + + try: + row = store.create_gmail_account( + provider_account_id=request.provider_account_id, + email_address=request.email_address, + display_name=request.display_name, + scope=request.scope, + access_token=request.access_token, + ) + except psycopg.errors.UniqueViolation as exc: + raise GmailAccountAlreadyExistsError( + f"gmail account {request.provider_account_id} is already connected" + ) from exc + + return {"account": serialize_gmail_account_row(row)} + + +def list_gmail_account_records( + store: ContinuityStore, + *, + user_id: UUID, +) -> GmailAccountListResponse: + del user_id + + items = [serialize_gmail_account_row(row) for row in store.list_gmail_accounts()] + return { + "items": items, + "summary": { + "total_count": len(items), + "order": list(GMAIL_ACCOUNT_LIST_ORDER), + }, + } + + +def get_gmail_account_record( + store: ContinuityStore, + *, + user_id: UUID, + gmail_account_id: UUID, +) -> GmailAccountDetailResponse: + del user_id + + row = store.get_gmail_account_optional(gmail_account_id) + if row is None: + raise GmailAccountNotFoundError(f"gmail account {gmail_account_id} was not found") + return {"account": serialize_gmail_account_row(row)} + + +def _sanitize_path_segment(value: str) -> str: + sanitized = _PATH_SEGMENT_PATTERN.sub("_", value.strip()) + return sanitized.strip("._") or "message" + + +def build_gmail_message_artifact_relative_path( + *, + provider_account_id: str, + provider_message_id: str, +) -> str: + return ( + f"{GMAIL_MESSAGE_ARTIFACT_ROOT}/" + f"{_sanitize_path_segment(provider_account_id)}/" + f"{_sanitize_path_segment(provider_message_id)}.eml" + ) + + +def fetch_gmail_message_raw_bytes(*, access_token: str, provider_message_id: str) -> bytes: + request = Request( + ( + "https://gmail.googleapis.com/gmail/v1/users/me/messages/" + f"{quote(provider_message_id, safe='')}?format=raw" + ), + headers={ + "Authorization": f"Bearer {access_token}", + "Accept": "application/json", + }, + method="GET", + ) + + try: + with urlopen(request, timeout=GMAIL_MESSAGE_FETCH_TIMEOUT_SECONDS) as response: + payload = json.loads(response.read().decode("utf-8")) + except HTTPError as exc: + if exc.code == 404: + raise GmailMessageNotFoundError( + f"gmail message {provider_message_id} was not found" + ) from exc + raise GmailMessageFetchError( + f"gmail message {provider_message_id} could not be fetched" + ) from exc + except (OSError, URLError, UnicodeDecodeError, json.JSONDecodeError) as exc: + raise GmailMessageFetchError( + f"gmail message {provider_message_id} could not be fetched" + ) from exc + + raw_payload = payload.get("raw") + if not isinstance(raw_payload, str) or raw_payload == "": + raise GmailMessageUnsupportedError( + f"gmail message {provider_message_id} did not include RFC822 raw content" + ) + + padding = "=" * (-len(raw_payload) % 4) + try: + return base64.urlsafe_b64decode(raw_payload + padding) + except (ValueError, TypeError) as exc: + raise GmailMessageUnsupportedError( + f"gmail message {provider_message_id} did not include valid RFC822 raw content" + ) from exc + + +def 
ingest_gmail_message_record( + store: ContinuityStore, + *, + user_id: UUID, + request: GmailMessageIngestInput, +) -> GmailMessageIngestionResponse: + account = store.get_gmail_account_optional(request.gmail_account_id) + if account is None: + raise GmailAccountNotFoundError(f"gmail account {request.gmail_account_id} was not found") + + workspace = store.get_task_workspace_optional(request.task_workspace_id) + if workspace is None: + raise TaskWorkspaceNotFoundError( + f"task workspace {request.task_workspace_id} was not found" + ) + + store.lock_task_artifacts(workspace["id"]) + relative_path = build_gmail_message_artifact_relative_path( + provider_account_id=account["provider_account_id"], + provider_message_id=request.provider_message_id, + ) + existing_artifact = store.get_task_artifact_by_workspace_relative_path_optional( + task_workspace_id=request.task_workspace_id, + relative_path=relative_path, + ) + if existing_artifact is not None: + raise TaskArtifactAlreadyExistsError( + f"artifact {relative_path} is already registered for task workspace {request.task_workspace_id}" + ) + + raw_bytes = fetch_gmail_message_raw_bytes( + access_token=account["access_token"], + provider_message_id=request.provider_message_id, + ) + + try: + extract_artifact_text_from_bytes( + relative_path=relative_path, + payload=raw_bytes, + media_type=SUPPORTED_RFC822_ARTIFACT_MEDIA_TYPE, + ) + except TaskArtifactValidationError as exc: + raise GmailMessageUnsupportedError( + f"gmail message {request.provider_message_id} is not a supported RFC822 email" + ) from exc + + workspace_path = Path(workspace["local_path"]).expanduser().resolve() + artifact_path = (workspace_path / relative_path).resolve() + ensure_artifact_path_is_rooted( + workspace_path=workspace_path, + artifact_path=artifact_path, + ) + artifact_path.parent.mkdir(parents=True, exist_ok=True) + if artifact_path.exists(): + raise TaskArtifactValidationError( + f"artifact path {artifact_path} already exists before Gmail ingestion registration" + ) + artifact_path.write_bytes(raw_bytes) + + artifact_payload = register_task_artifact_record( + store, + user_id=user_id, + request=TaskArtifactRegisterInput( + task_workspace_id=request.task_workspace_id, + local_path=str(artifact_path), + media_type_hint=SUPPORTED_RFC822_ARTIFACT_MEDIA_TYPE, + ), + ) + ingestion_payload = ingest_task_artifact_record( + store, + user_id=user_id, + request=TaskArtifactIngestInput(task_artifact_id=UUID(artifact_payload["artifact"]["id"])), + ) + return { + "account": serialize_gmail_account_row(account), + "message": { + "provider_message_id": request.provider_message_id, + "artifact_relative_path": ingestion_payload["artifact"]["relative_path"], + "media_type": SUPPORTED_RFC822_ARTIFACT_MEDIA_TYPE, + }, + "artifact": ingestion_payload["artifact"], + "summary": ingestion_payload["summary"], + } diff --git a/apps/api/src/alicebot_api/main.py b/apps/api/src/alicebot_api/main.py index e530778..c5a224b 100644 --- a/apps/api/src/alicebot_api/main.py +++ b/apps/api/src/alicebot_api/main.py @@ -45,6 +45,9 @@ EntityCreateInput, EntityType, ExplicitPreferenceExtractionRequestInput, + GMAIL_READONLY_SCOPE, + GmailAccountConnectInput, + GmailMessageIngestInput, MemoryCandidateInput, MemoryEmbeddingUpsertInput, MemoryReviewLabelValue, @@ -135,6 +138,17 @@ list_execution_budget_records, supersede_execution_budget_record, ) +from alicebot_api.gmail import ( + GmailAccountAlreadyExistsError, + GmailAccountNotFoundError, + GmailMessageFetchError, + GmailMessageNotFoundError, + 
GmailMessageUnsupportedError, + create_gmail_account_record, + get_gmail_account_record, + ingest_gmail_message_record, + list_gmail_account_records, +) from alicebot_api.embedding import ( EmbeddingConfigValidationError, MemoryEmbeddingNotFoundError, @@ -518,6 +532,20 @@ class ExecuteApprovedProxyRequest(BaseModel): user_id: UUID +class ConnectGmailAccountRequest(BaseModel): + user_id: UUID + provider_account_id: str = Field(min_length=1, max_length=320) + email_address: str = Field(min_length=1, max_length=320) + display_name: str | None = Field(default=None, min_length=1, max_length=200) + scope: Literal["https://www.googleapis.com/auth/gmail.readonly"] = GMAIL_READONLY_SCOPE + access_token: str = Field(min_length=1, max_length=8000) + + +class IngestGmailMessageRequest(BaseModel): + user_id: UUID + task_workspace_id: UUID + + class CreateTaskWorkspaceRequest(BaseModel): user_id: UUID @@ -1256,6 +1284,108 @@ def get_task(task_id: UUID, user_id: UUID) -> JSONResponse: ) +@app.post("/v0/gmail-accounts") +def connect_gmail_account(request: ConnectGmailAccountRequest) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, request.user_id) as conn: + payload = create_gmail_account_record( + ContinuityStore(conn), + user_id=request.user_id, + request=GmailAccountConnectInput( + provider_account_id=request.provider_account_id, + email_address=request.email_address, + display_name=request.display_name, + scope=request.scope, + access_token=request.access_token, + ), + ) + except GmailAccountAlreadyExistsError as exc: + return JSONResponse(status_code=409, content={"detail": str(exc)}) + + return JSONResponse( + status_code=201, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/gmail-accounts") +def list_gmail_accounts(user_id: UUID) -> JSONResponse: + settings = get_settings() + + with user_connection(settings.database_url, user_id) as conn: + payload = list_gmail_account_records( + ContinuityStore(conn), + user_id=user_id, + ) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/gmail-accounts/{gmail_account_id}") +def get_gmail_account(gmail_account_id: UUID, user_id: UUID) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, user_id) as conn: + payload = get_gmail_account_record( + ContinuityStore(conn), + user_id=user_id, + gmail_account_id=gmail_account_id, + ) + except GmailAccountNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.post("/v0/gmail-accounts/{gmail_account_id}/messages/{provider_message_id}/ingest") +def ingest_gmail_message( + gmail_account_id: UUID, + provider_message_id: str, + request: IngestGmailMessageRequest, +) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, request.user_id) as conn: + payload = ingest_gmail_message_record( + ContinuityStore(conn), + user_id=request.user_id, + request=GmailMessageIngestInput( + gmail_account_id=gmail_account_id, + task_workspace_id=request.task_workspace_id, + provider_message_id=provider_message_id, + ), + ) + except GmailAccountNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + except TaskWorkspaceNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + except GmailMessageNotFoundError as exc: + return 
JSONResponse(status_code=404, content={"detail": str(exc)}) + except GmailMessageUnsupportedError as exc: + return JSONResponse(status_code=400, content={"detail": str(exc)}) + except TaskArtifactValidationError as exc: + return JSONResponse(status_code=400, content={"detail": str(exc)}) + except GmailMessageFetchError as exc: + return JSONResponse(status_code=502, content={"detail": str(exc)}) + except TaskArtifactAlreadyExistsError as exc: + return JSONResponse(status_code=409, content={"detail": str(exc)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + @app.post("/v0/tasks/{task_id}/workspace") def create_task_workspace(task_id: UUID, request: CreateTaskWorkspaceRequest) -> JSONResponse: settings = get_settings() diff --git a/apps/api/src/alicebot_api/store.py b/apps/api/src/alicebot_api/store.py index c2f0771..4098cda 100644 --- a/apps/api/src/alicebot_api/store.py +++ b/apps/api/src/alicebot_api/store.py @@ -245,6 +245,18 @@ class TaskWorkspaceRow(TypedDict): updated_at: datetime +class GmailAccountRow(TypedDict): + id: UUID + user_id: UUID + provider_account_id: str + email_address: str + display_name: str | None + scope: str + access_token: str + created_at: datetime + updated_at: datetime + + class TaskArtifactRow(TypedDict): id: UUID user_id: UUID @@ -1469,6 +1481,86 @@ class LabelCountRow(TypedDict): updated_at """ +INSERT_GMAIL_ACCOUNT_SQL = """ + INSERT INTO gmail_accounts ( + user_id, + provider_account_id, + email_address, + display_name, + scope, + access_token, + created_at, + updated_at + ) + VALUES ( + app.current_user_id(), + %s, + %s, + %s, + %s, + %s, + clock_timestamp(), + clock_timestamp() + ) + RETURNING + id, + user_id, + provider_account_id, + email_address, + display_name, + scope, + access_token, + created_at, + updated_at + """ + +GET_GMAIL_ACCOUNT_SQL = """ + SELECT + id, + user_id, + provider_account_id, + email_address, + display_name, + scope, + access_token, + created_at, + updated_at + FROM gmail_accounts + WHERE id = %s + """ + +GET_GMAIL_ACCOUNT_BY_PROVIDER_ACCOUNT_ID_SQL = """ + SELECT + id, + user_id, + provider_account_id, + email_address, + display_name, + scope, + access_token, + created_at, + updated_at + FROM gmail_accounts + WHERE provider_account_id = %s + ORDER BY created_at ASC, id ASC + LIMIT 1 + """ + +LIST_GMAIL_ACCOUNTS_SQL = """ + SELECT + id, + user_id, + provider_account_id, + email_address, + display_name, + scope, + access_token, + created_at, + updated_at + FROM gmail_accounts + ORDER BY created_at ASC, id ASC + """ + INSERT_TASK_WORKSPACE_SQL = """ INSERT INTO task_workspaces ( user_id, @@ -2987,6 +3079,36 @@ def update_task_status_optional( (status, latest_approval_id, latest_execution_id, task_id), ) + def create_gmail_account( + self, + *, + provider_account_id: str, + email_address: str, + display_name: str | None, + scope: str, + access_token: str, + ) -> GmailAccountRow: + return self._fetch_one( + "create_gmail_account", + INSERT_GMAIL_ACCOUNT_SQL, + (provider_account_id, email_address, display_name, scope, access_token), + ) + + def get_gmail_account_optional(self, gmail_account_id: UUID) -> GmailAccountRow | None: + return self._fetch_optional_one(GET_GMAIL_ACCOUNT_SQL, (gmail_account_id,)) + + def get_gmail_account_by_provider_account_id_optional( + self, + provider_account_id: str, + ) -> GmailAccountRow | None: + return self._fetch_optional_one( + GET_GMAIL_ACCOUNT_BY_PROVIDER_ACCOUNT_ID_SQL, + (provider_account_id,), + ) + + def list_gmail_accounts(self) -> 
list[GmailAccountRow]: + return self._fetch_all(LIST_GMAIL_ACCOUNTS_SQL) + def lock_task_workspaces(self, task_id: UUID) -> None: with self.conn.cursor() as cur: cur.execute(LOCK_TASK_WORKSPACES_SQL, (str(task_id),)) diff --git a/tests/integration/test_gmail_accounts_api.py b/tests/integration/test_gmail_accounts_api.py new file mode 100644 index 0000000..d9aa1df --- /dev/null +++ b/tests/integration/test_gmail_accounts_api.py @@ -0,0 +1,525 @@ +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any +from urllib.parse import urlencode +from uuid import UUID, uuid4 + +import anyio + +import apps.api.src.alicebot_api.main as main_module +from apps.api.src.alicebot_api.config import Settings +import alicebot_api.gmail as gmail_module +from alicebot_api.db import user_connection +from alicebot_api.store import ContinuityStore + + +def invoke_request( + method: str, + path: str, + *, + query_params: dict[str, str] | None = None, + payload: dict[str, Any] | None = None, +) -> tuple[int, dict[str, Any]]: + messages: list[dict[str, object]] = [] + encoded_body = b"" if payload is None else json.dumps(payload).encode() + request_received = False + + async def receive() -> dict[str, object]: + nonlocal request_received + if request_received: + return {"type": "http.disconnect"} + + request_received = True + return {"type": "http.request", "body": encoded_body, "more_body": False} + + async def send(message: dict[str, object]) -> None: + messages.append(message) + + query_string = urlencode(query_params or {}).encode() + scope = { + "type": "http", + "asgi": {"version": "3.0"}, + "http_version": "1.1", + "method": method, + "scheme": "http", + "path": path, + "raw_path": path.encode(), + "query_string": query_string, + "headers": [(b"content-type", b"application/json")], + "client": ("testclient", 50000), + "server": ("testserver", 80), + "root_path": "", + } + + anyio.run(main_module.app, scope, receive, send) + + start_message = next(message for message in messages if message["type"] == "http.response.start") + body = b"".join( + message.get("body", b"") + for message in messages + if message["type"] == "http.response.body" + ) + return start_message["status"], json.loads(body) + + +def _build_rfc822_email_bytes(*, subject: str, plain_body: str) -> bytes: + return ( + "\r\n".join( + [ + "From: Alice ", + "To: Bob ", + f"Subject: {subject}", + 'Content-Type: text/plain; charset="utf-8"', + "Content-Transfer-Encoding: 8bit", + "", + plain_body, + ] + ).encode("utf-8") + ) + + +def seed_user(database_url: str, *, email: str) -> dict[str, UUID]: + user_id = uuid4() + + with user_connection(database_url, user_id) as conn: + store = ContinuityStore(conn) + store.create_user(user_id, email, email.split("@", 1)[0].title()) + + return {"user_id": user_id} + + +def seed_task(database_url: str, *, email: str) -> dict[str, UUID]: + user_id = uuid4() + + with user_connection(database_url, user_id) as conn: + store = ContinuityStore(conn) + store.create_user(user_id, email, email.split("@", 1)[0].title()) + thread = store.create_thread("Gmail thread") + tool = store.create_tool( + tool_key="proxy.echo", + name="Proxy Echo", + description="Deterministic proxy handler.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["proxy"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=[], + risk_hints=[], + metadata={"transport": "proxy"}, + ) + task = store.create_task( + thread_id=thread["id"], + tool_id=tool["id"], + 
status="approved", + request={ + "thread_id": str(thread["id"]), + "tool_id": str(tool["id"]), + "action": "tool.run", + "scope": "workspace", + "domain_hint": None, + "risk_hint": None, + "attributes": {}, + }, + tool={ + "id": str(tool["id"]), + "tool_key": "proxy.echo", + "name": "Proxy Echo", + "description": "Deterministic proxy handler.", + "version": "1.0.0", + "metadata_version": "tool_metadata_v0", + "active": True, + "tags": ["proxy"], + "action_hints": ["tool.run"], + "scope_hints": ["workspace"], + "domain_hints": [], + "risk_hints": [], + "metadata": {"transport": "proxy"}, + "created_at": tool["created_at"].isoformat(), + }, + latest_approval_id=None, + latest_execution_id=None, + ) + + return { + "user_id": user_id, + "task_id": task["id"], + } + + +def _connect_gmail_account(*, user_id: UUID, provider_account_id: str, email_address: str) -> tuple[int, dict[str, Any]]: + return invoke_request( + "POST", + "/v0/gmail-accounts", + payload={ + "user_id": str(user_id), + "provider_account_id": provider_account_id, + "email_address": email_address, + "display_name": email_address.split("@", 1)[0].title(), + "scope": "https://www.googleapis.com/auth/gmail.readonly", + "access_token": f"token-for-{provider_account_id}", + }, + ) + + +def test_gmail_account_endpoints_connect_list_detail_and_isolate( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner@example.com") + intruder = seed_user(migrated_database_urls["app"], email="intruder@example.com") + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + create_status, create_payload = _connect_gmail_account( + user_id=owner["user_id"], + provider_account_id="acct-owner-001", + email_address="owner@gmail.example", + ) + list_status, list_payload = invoke_request( + "GET", + "/v0/gmail-accounts", + query_params={"user_id": str(owner["user_id"])}, + ) + detail_status, detail_payload = invoke_request( + "GET", + f"/v0/gmail-accounts/{create_payload['account']['id']}", + query_params={"user_id": str(owner["user_id"])}, + ) + duplicate_status, duplicate_payload = _connect_gmail_account( + user_id=owner["user_id"], + provider_account_id="acct-owner-001", + email_address="owner@gmail.example", + ) + isolated_list_status, isolated_list_payload = invoke_request( + "GET", + "/v0/gmail-accounts", + query_params={"user_id": str(intruder["user_id"])}, + ) + isolated_detail_status, isolated_detail_payload = invoke_request( + "GET", + f"/v0/gmail-accounts/{create_payload['account']['id']}", + query_params={"user_id": str(intruder["user_id"])}, + ) + + assert create_status == 201 + assert create_payload == { + "account": { + "id": create_payload["account"]["id"], + "provider": "gmail", + "auth_kind": "oauth_access_token", + "provider_account_id": "acct-owner-001", + "email_address": "owner@gmail.example", + "display_name": "Owner", + "scope": "https://www.googleapis.com/auth/gmail.readonly", + "created_at": create_payload["account"]["created_at"], + "updated_at": create_payload["account"]["updated_at"], + } + } + assert list_status == 200 + assert list_payload == { + "items": [create_payload["account"]], + "summary": {"total_count": 1, "order": ["created_at_asc", "id_asc"]}, + } + assert detail_status == 200 + assert detail_payload == {"account": create_payload["account"]} + assert duplicate_status == 409 + assert duplicate_payload == {"detail": "gmail account acct-owner-001 is already connected"} + assert 
isolated_list_status == 200 + assert isolated_list_payload == { + "items": [], + "summary": {"total_count": 0, "order": ["created_at_asc", "id_asc"]}, + } + assert isolated_detail_status == 404 + assert isolated_detail_payload == { + "detail": f"gmail account {create_payload['account']['id']} was not found" + } + + +def test_gmail_message_ingestion_endpoint_persists_artifact_and_chunks( + migrated_database_urls, + monkeypatch, + tmp_path, +) -> None: + owner = seed_task(migrated_database_urls["app"], email="owner@example.com") + workspace_root = tmp_path / "task-workspaces" + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings( + database_url=migrated_database_urls["app"], + task_workspace_root=str(workspace_root), + ), + ) + raw_bytes = _build_rfc822_email_bytes(subject="Inbox Update", plain_body="ingest this message") + monkeypatch.setattr( + gmail_module, + "fetch_gmail_message_raw_bytes", + lambda **_kwargs: raw_bytes, + ) + + account_status, account_payload = _connect_gmail_account( + user_id=owner["user_id"], + provider_account_id="acct-owner-001", + email_address="owner@gmail.example", + ) + workspace_status, workspace_payload = invoke_request( + "POST", + f"/v0/tasks/{owner['task_id']}/workspace", + payload={"user_id": str(owner["user_id"])}, + ) + ingest_status, ingest_payload = invoke_request( + "POST", + f"/v0/gmail-accounts/{account_payload['account']['id']}/messages/msg-001/ingest", + payload={ + "user_id": str(owner["user_id"]), + "task_workspace_id": workspace_payload["workspace"]["id"], + }, + ) + + assert account_status == 201 + assert workspace_status == 201 + assert ingest_status == 200 + assert ingest_payload == { + "account": account_payload["account"], + "message": { + "provider_message_id": "msg-001", + "artifact_relative_path": "gmail/acct-owner-001/msg-001.eml", + "media_type": "message/rfc822", + }, + "artifact": { + "id": ingest_payload["artifact"]["id"], + "task_id": str(owner["task_id"]), + "task_workspace_id": workspace_payload["workspace"]["id"], + "status": "registered", + "ingestion_status": "ingested", + "relative_path": "gmail/acct-owner-001/msg-001.eml", + "media_type_hint": "message/rfc822", + "created_at": ingest_payload["artifact"]["created_at"], + "updated_at": ingest_payload["artifact"]["updated_at"], + }, + "summary": { + "total_count": ingest_payload["summary"]["total_count"], + "total_characters": ingest_payload["summary"]["total_characters"], + "media_type": "message/rfc822", + "chunking_rule": "normalized_utf8_text_fixed_window_1000_chars_v1", + "order": ["sequence_no_asc", "id_asc"], + }, + } + assert ingest_payload["summary"]["total_count"] >= 1 + artifact_file = ( + Path(workspace_payload["workspace"]["local_path"]) / "gmail" / "acct-owner-001" / "msg-001.eml" + ) + assert artifact_file.read_bytes() == raw_bytes + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + artifact_rows = store.list_task_artifacts_for_task(owner["task_id"]) + assert len(artifact_rows) == 1 + assert artifact_rows[0]["relative_path"] == "gmail/acct-owner-001/msg-001.eml" + assert artifact_rows[0]["ingestion_status"] == "ingested" + chunk_rows = store.list_task_artifact_chunks(artifact_rows[0]["id"]) + assert len(chunk_rows) == ingest_payload["summary"]["total_count"] + assert chunk_rows[0]["text"].startswith("From: Alice ") + + +def test_gmail_message_ingestion_endpoint_rejects_cross_user_workspace_access( + migrated_database_urls, + monkeypatch, + tmp_path, +) -> None: + owner = 
seed_task(migrated_database_urls["app"], email="owner@example.com") + intruder = seed_task(migrated_database_urls["app"], email="intruder@example.com") + workspace_root = tmp_path / "task-workspaces" + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings( + database_url=migrated_database_urls["app"], + task_workspace_root=str(workspace_root), + ), + ) + monkeypatch.setattr( + gmail_module, + "fetch_gmail_message_raw_bytes", + lambda **_kwargs: _build_rfc822_email_bytes( + subject="Inbox Update", + plain_body="ingest this message", + ), + ) + + _, owner_workspace_payload = invoke_request( + "POST", + f"/v0/tasks/{owner['task_id']}/workspace", + payload={"user_id": str(owner["user_id"])}, + ) + _, intruder_account_payload = _connect_gmail_account( + user_id=intruder["user_id"], + provider_account_id="acct-intruder-001", + email_address="intruder@gmail.example", + ) + ingest_status, ingest_payload = invoke_request( + "POST", + f"/v0/gmail-accounts/{intruder_account_payload['account']['id']}/messages/msg-001/ingest", + payload={ + "user_id": str(intruder["user_id"]), + "task_workspace_id": owner_workspace_payload["workspace"]["id"], + }, + ) + + assert ingest_status == 404 + assert ingest_payload == { + "detail": f"task workspace {owner_workspace_payload['workspace']['id']} was not found" + } + + +def test_gmail_message_ingestion_endpoint_rejects_sanitized_path_collisions_without_overwrite( + migrated_database_urls, + monkeypatch, + tmp_path, +) -> None: + owner = seed_task(migrated_database_urls["app"], email="owner@example.com") + workspace_root = tmp_path / "task-workspaces" + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings( + database_url=migrated_database_urls["app"], + task_workspace_root=str(workspace_root), + ), + ) + first_bytes = _build_rfc822_email_bytes(subject="First", plain_body="first message body") + second_bytes = _build_rfc822_email_bytes(subject="Second", plain_body="second message body") + + def fake_fetch_gmail_message_raw_bytes(*, provider_message_id: str, **_kwargs) -> bytes: + if provider_message_id == "msg+001": + return first_bytes + if provider_message_id == "msg:001": + return second_bytes + raise AssertionError(f"unexpected provider_message_id: {provider_message_id}") + + monkeypatch.setattr( + gmail_module, + "fetch_gmail_message_raw_bytes", + fake_fetch_gmail_message_raw_bytes, + ) + + _, account_payload = _connect_gmail_account( + user_id=owner["user_id"], + provider_account_id="acct-owner-001", + email_address="owner@gmail.example", + ) + _, workspace_payload = invoke_request( + "POST", + f"/v0/tasks/{owner['task_id']}/workspace", + payload={"user_id": str(owner["user_id"])}, + ) + first_ingest_status, first_ingest_payload = invoke_request( + "POST", + f"/v0/gmail-accounts/{account_payload['account']['id']}/messages/msg+001/ingest", + payload={ + "user_id": str(owner["user_id"]), + "task_workspace_id": workspace_payload["workspace"]["id"], + }, + ) + second_ingest_status, second_ingest_payload = invoke_request( + "POST", + f"/v0/gmail-accounts/{account_payload['account']['id']}/messages/msg:001/ingest", + payload={ + "user_id": str(owner["user_id"]), + "task_workspace_id": workspace_payload["workspace"]["id"], + }, + ) + + artifact_file = ( + Path(workspace_payload["workspace"]["local_path"]) / "gmail" / "acct-owner-001" / "msg_001.eml" + ) + + assert first_ingest_status == 200 + assert second_ingest_status == 409 + assert second_ingest_payload == { + "detail": ( + "artifact gmail/acct-owner-001/msg_001.eml is 
already registered for task workspace " + f"{workspace_payload['workspace']['id']}" + ) + } + assert artifact_file.read_bytes() == first_bytes + assert first_ingest_payload["artifact"]["relative_path"] == "gmail/acct-owner-001/msg_001.eml" + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + artifact_rows = store.list_task_artifacts_for_task(owner["task_id"]) + assert len(artifact_rows) == 1 + assert artifact_rows[0]["relative_path"] == "gmail/acct-owner-001/msg_001.eml" + + +def test_gmail_message_ingestion_endpoint_rejects_missing_and_unsupported_messages( + migrated_database_urls, + monkeypatch, + tmp_path, +) -> None: + owner = seed_task(migrated_database_urls["app"], email="owner@example.com") + workspace_root = tmp_path / "task-workspaces" + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings( + database_url=migrated_database_urls["app"], + task_workspace_root=str(workspace_root), + ), + ) + + _, account_payload = _connect_gmail_account( + user_id=owner["user_id"], + provider_account_id="acct-owner-001", + email_address="owner@gmail.example", + ) + _, workspace_payload = invoke_request( + "POST", + f"/v0/tasks/{owner['task_id']}/workspace", + payload={"user_id": str(owner["user_id"])}, + ) + + def fake_missing(**_kwargs): + raise gmail_module.GmailMessageNotFoundError("gmail message msg-missing was not found") + + monkeypatch.setattr(gmail_module, "fetch_gmail_message_raw_bytes", fake_missing) + missing_status, missing_payload = invoke_request( + "POST", + f"/v0/gmail-accounts/{account_payload['account']['id']}/messages/msg-missing/ingest", + payload={ + "user_id": str(owner["user_id"]), + "task_workspace_id": workspace_payload["workspace"]["id"], + }, + ) + + monkeypatch.setattr( + gmail_module, + "fetch_gmail_message_raw_bytes", + lambda **_kwargs: b"not-a-valid-rfc822-email", + ) + unsupported_status, unsupported_payload = invoke_request( + "POST", + f"/v0/gmail-accounts/{account_payload['account']['id']}/messages/msg-unsupported/ingest", + payload={ + "user_id": str(owner["user_id"]), + "task_workspace_id": workspace_payload["workspace"]["id"], + }, + ) + + assert missing_status == 404 + assert missing_payload == {"detail": "gmail message msg-missing was not found"} + assert unsupported_status == 400 + assert unsupported_payload == { + "detail": "gmail message msg-unsupported is not a supported RFC822 email" + } + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + assert store.list_task_artifacts_for_task(owner["task_id"]) == [] diff --git a/tests/unit/test_20260316_0026_gmail_accounts.py b/tests/unit/test_20260316_0026_gmail_accounts.py new file mode 100644 index 0000000..cc0153a --- /dev/null +++ b/tests/unit/test_20260316_0026_gmail_accounts.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +import importlib + + +MODULE_NAME = "apps.api.alembic.versions.20260316_0026_gmail_accounts" + + +def load_migration_module(): + return importlib.import_module(MODULE_NAME) + + +def test_upgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.upgrade() + + assert executed == [ + module._UPGRADE_SCHEMA_STATEMENT, + *module._UPGRADE_GRANT_STATEMENTS, + "ALTER TABLE gmail_accounts ENABLE ROW LEVEL SECURITY", + "ALTER TABLE gmail_accounts FORCE ROW LEVEL SECURITY", + 
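# Expected order mirrors the migration's upgrade(): schema statement, runtime grants, RLS enable/force, then the policy statement.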
module._UPGRADE_POLICY_STATEMENT, + ] + + +def test_downgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.downgrade() + + assert executed == list(module._DOWNGRADE_STATEMENTS) + + +def test_gmail_account_privileges_allow_only_expected_runtime_writes() -> None: + module = load_migration_module() + + assert module._UPGRADE_GRANT_STATEMENTS == ( + "GRANT SELECT, INSERT ON gmail_accounts TO alicebot_app", + ) diff --git a/tests/unit/test_gmail.py b/tests/unit/test_gmail.py new file mode 100644 index 0000000..d602b5c --- /dev/null +++ b/tests/unit/test_gmail.py @@ -0,0 +1,475 @@ +from __future__ import annotations + +from datetime import UTC, datetime, timedelta +from pathlib import Path +from uuid import UUID, uuid4 + +import pytest + +from alicebot_api.artifacts import TaskArtifactAlreadyExistsError +from alicebot_api.contracts import GMAIL_READONLY_SCOPE, GmailAccountConnectInput, GmailMessageIngestInput +from alicebot_api.gmail import ( + GmailAccountAlreadyExistsError, + GmailAccountNotFoundError, + GmailMessageUnsupportedError, + build_gmail_message_artifact_relative_path, + create_gmail_account_record, + get_gmail_account_record, + ingest_gmail_message_record, + list_gmail_account_records, +) +from alicebot_api.workspaces import TaskWorkspaceNotFoundError + + +def _build_rfc822_email_bytes(*, plain_body: str) -> bytes: + return ( + "\r\n".join( + [ + "From: Alice ", + "To: Bob ", + "Subject: Sprint Update", + 'Content-Type: text/plain; charset="utf-8"', + "Content-Transfer-Encoding: 8bit", + "", + plain_body, + ] + ).encode("utf-8") + ) + + +class GmailStoreStub: + def __init__(self) -> None: + self.base_time = datetime(2026, 3, 16, 10, 0, tzinfo=UTC) + self.gmail_accounts: list[dict[str, object]] = [] + self.task_workspaces: list[dict[str, object]] = [] + self.task_artifacts: list[dict[str, object]] = [] + self.operations: list[tuple[str, object]] = [] + + def create_gmail_account( + self, + *, + provider_account_id: str, + email_address: str, + display_name: str | None, + scope: str, + access_token: str, + ) -> dict[str, object]: + row = { + "id": uuid4(), + "user_id": uuid4(), + "provider_account_id": provider_account_id, + "email_address": email_address, + "display_name": display_name, + "scope": scope, + "access_token": access_token, + "created_at": self.base_time + timedelta(minutes=len(self.gmail_accounts)), + "updated_at": self.base_time + timedelta(minutes=len(self.gmail_accounts)), + } + self.gmail_accounts.append(row) + return row + + def get_gmail_account_optional(self, gmail_account_id: UUID) -> dict[str, object] | None: + return next( + (row for row in self.gmail_accounts if row["id"] == gmail_account_id), + None, + ) + + def get_gmail_account_by_provider_account_id_optional( + self, + provider_account_id: str, + ) -> dict[str, object] | None: + return next( + ( + row + for row in self.gmail_accounts + if row["provider_account_id"] == provider_account_id + ), + None, + ) + + def list_gmail_accounts(self) -> list[dict[str, object]]: + return sorted( + self.gmail_accounts, + key=lambda row: (row["created_at"], row["id"]), + ) + + def create_task_workspace(self, *, task_workspace_id: UUID, local_path: str) -> dict[str, object]: + row = { + "id": task_workspace_id, + "user_id": uuid4(), + "task_id": uuid4(), + "status": "active", + "local_path": local_path, + "created_at": self.base_time, + "updated_at": self.base_time, + } + 
self.task_workspaces.append(row) + return row + + def get_task_workspace_optional(self, task_workspace_id: UUID) -> dict[str, object] | None: + return next( + (row for row in self.task_workspaces if row["id"] == task_workspace_id), + None, + ) + + def lock_task_artifacts(self, task_workspace_id: UUID) -> None: + self.operations.append(("lock_task_artifacts", task_workspace_id)) + + def create_task_artifact( + self, + *, + task_workspace_id: UUID, + relative_path: str, + ) -> dict[str, object]: + row = { + "id": uuid4(), + "user_id": uuid4(), + "task_id": uuid4(), + "task_workspace_id": task_workspace_id, + "status": "registered", + "ingestion_status": "ingested", + "relative_path": relative_path, + "media_type_hint": "message/rfc822", + "created_at": self.base_time, + "updated_at": self.base_time, + } + self.task_artifacts.append(row) + return row + + def get_task_artifact_by_workspace_relative_path_optional( + self, + *, + task_workspace_id: UUID, + relative_path: str, + ) -> dict[str, object] | None: + self.operations.append( + ("get_task_artifact_by_workspace_relative_path_optional", task_workspace_id) + ) + return next( + ( + row + for row in self.task_artifacts + if row["task_workspace_id"] == task_workspace_id + and row["relative_path"] == relative_path + ), + None, + ) + + +def test_create_list_and_get_gmail_account_records_are_deterministic() -> None: + store = GmailStoreStub() + user_id = uuid4() + + first = create_gmail_account_record( + store, + user_id=user_id, + request=GmailAccountConnectInput( + provider_account_id="acct-001", + email_address="owner@example.com", + display_name="Owner", + scope=GMAIL_READONLY_SCOPE, + access_token="token-1", + ), + ) + second = create_gmail_account_record( + store, + user_id=user_id, + request=GmailAccountConnectInput( + provider_account_id="acct-002", + email_address="owner+2@example.com", + display_name=None, + scope=GMAIL_READONLY_SCOPE, + access_token="token-2", + ), + ) + + assert list_gmail_account_records(store, user_id=user_id) == { + "items": [first["account"], second["account"]], + "summary": {"total_count": 2, "order": ["created_at_asc", "id_asc"]}, + } + assert get_gmail_account_record( + store, + user_id=user_id, + gmail_account_id=UUID(second["account"]["id"]), + ) == {"account": second["account"]} + + +def test_create_gmail_account_record_rejects_duplicate_provider_account_id() -> None: + store = GmailStoreStub() + user_id = uuid4() + request = GmailAccountConnectInput( + provider_account_id="acct-001", + email_address="owner@example.com", + display_name="Owner", + scope=GMAIL_READONLY_SCOPE, + access_token="token-1", + ) + create_gmail_account_record(store, user_id=user_id, request=request) + + with pytest.raises( + GmailAccountAlreadyExistsError, + match="gmail account acct-001 is already connected", + ): + create_gmail_account_record(store, user_id=user_id, request=request) + + +def test_get_gmail_account_record_raises_when_account_is_missing() -> None: + with pytest.raises(GmailAccountNotFoundError, match="was not found"): + get_gmail_account_record( + GmailStoreStub(), + user_id=uuid4(), + gmail_account_id=uuid4(), + ) + + +def test_ingest_gmail_message_record_writes_rfc822_artifact_and_reuses_artifact_seam( + monkeypatch, + tmp_path, +) -> None: + store = GmailStoreStub() + user_id = uuid4() + workspace_id = uuid4() + workspace = store.create_task_workspace( + task_workspace_id=workspace_id, + local_path=str((tmp_path / "workspace").resolve()), + ) + account = create_gmail_account_record( + store, + user_id=user_id, + 
request=GmailAccountConnectInput( + provider_account_id="acct-001", + email_address="owner@example.com", + display_name="Owner", + scope=GMAIL_READONLY_SCOPE, + access_token="token-1", + ), + )["account"] + raw_bytes = _build_rfc822_email_bytes(plain_body="hello from gmail") + calls: dict[str, object] = {} + + monkeypatch.setattr( + "alicebot_api.gmail.fetch_gmail_message_raw_bytes", + lambda **_kwargs: raw_bytes, + ) + + def fake_register(_store, *, user_id: UUID, request): + calls["register_user_id"] = user_id + calls["register_request"] = request + path = Path(request.local_path) + assert path.read_bytes() == raw_bytes + assert path.is_file() + return { + "artifact": { + "id": "00000000-0000-0000-0000-000000000123", + "task_id": str(workspace["task_id"]), + "task_workspace_id": str(workspace_id), + "status": "registered", + "ingestion_status": "pending", + "relative_path": path.relative_to(Path(workspace["local_path"])).as_posix(), + "media_type_hint": "message/rfc822", + "created_at": "2026-03-16T10:00:00+00:00", + "updated_at": "2026-03-16T10:00:00+00:00", + } + } + + def fake_ingest(_store, *, user_id: UUID, request): + calls["ingest_user_id"] = user_id + calls["ingest_request"] = request + return { + "artifact": { + "id": "00000000-0000-0000-0000-000000000123", + "task_id": str(workspace["task_id"]), + "task_workspace_id": str(workspace_id), + "status": "registered", + "ingestion_status": "ingested", + "relative_path": build_gmail_message_artifact_relative_path( + provider_account_id="acct-001", + provider_message_id="msg-001", + ), + "media_type_hint": "message/rfc822", + "created_at": "2026-03-16T10:00:00+00:00", + "updated_at": "2026-03-16T10:00:01+00:00", + }, + "summary": { + "total_count": 1, + "total_characters": 16, + "media_type": "message/rfc822", + "chunking_rule": "normalized_utf8_text_fixed_window_1000_chars_v1", + "order": ["sequence_no_asc", "id_asc"], + }, + } + + monkeypatch.setattr("alicebot_api.gmail.register_task_artifact_record", fake_register) + monkeypatch.setattr("alicebot_api.gmail.ingest_task_artifact_record", fake_ingest) + + response = ingest_gmail_message_record( + store, + user_id=user_id, + request=GmailMessageIngestInput( + gmail_account_id=UUID(account["id"]), + task_workspace_id=workspace_id, + provider_message_id="msg-001", + ), + ) + + assert response == { + "account": account, + "message": { + "provider_message_id": "msg-001", + "artifact_relative_path": "gmail/acct-001/msg-001.eml", + "media_type": "message/rfc822", + }, + "artifact": { + "id": "00000000-0000-0000-0000-000000000123", + "task_id": str(workspace["task_id"]), + "task_workspace_id": str(workspace_id), + "status": "registered", + "ingestion_status": "ingested", + "relative_path": "gmail/acct-001/msg-001.eml", + "media_type_hint": "message/rfc822", + "created_at": "2026-03-16T10:00:00+00:00", + "updated_at": "2026-03-16T10:00:01+00:00", + }, + "summary": { + "total_count": 1, + "total_characters": 16, + "media_type": "message/rfc822", + "chunking_rule": "normalized_utf8_text_fixed_window_1000_chars_v1", + "order": ["sequence_no_asc", "id_asc"], + }, + } + assert calls["register_user_id"] == user_id + assert calls["ingest_user_id"] == user_id + + +def test_ingest_gmail_message_record_rejects_unsupported_message(monkeypatch, tmp_path) -> None: + store = GmailStoreStub() + user_id = uuid4() + workspace_id = uuid4() + store.create_task_workspace( + task_workspace_id=workspace_id, + local_path=str((tmp_path / "workspace").resolve()), + ) + account = create_gmail_account_record( + store, + 
user_id=user_id, + request=GmailAccountConnectInput( + provider_account_id="acct-001", + email_address="owner@example.com", + display_name="Owner", + scope=GMAIL_READONLY_SCOPE, + access_token="token-1", + ), + )["account"] + + monkeypatch.setattr( + "alicebot_api.gmail.fetch_gmail_message_raw_bytes", + lambda **_kwargs: b"not-a-valid-rfc822-email", + ) + + with pytest.raises( + GmailMessageUnsupportedError, + match="gmail message msg-unsupported is not a supported RFC822 email", + ): + ingest_gmail_message_record( + store, + user_id=user_id, + request=GmailMessageIngestInput( + gmail_account_id=UUID(account["id"]), + task_workspace_id=workspace_id, + provider_message_id="msg-unsupported", + ), + ) + + +def test_ingest_gmail_message_record_rejects_duplicate_sanitized_path_before_fetch_or_write( + monkeypatch, + tmp_path, +) -> None: + store = GmailStoreStub() + user_id = uuid4() + workspace_id = uuid4() + workspace_path = (tmp_path / "workspace").resolve() + store.create_task_workspace( + task_workspace_id=workspace_id, + local_path=str(workspace_path), + ) + account = create_gmail_account_record( + store, + user_id=user_id, + request=GmailAccountConnectInput( + provider_account_id="acct-001", + email_address="owner@example.com", + display_name="Owner", + scope=GMAIL_READONLY_SCOPE, + access_token="token-1", + ), + )["account"] + relative_path = build_gmail_message_artifact_relative_path( + provider_account_id="acct-001", + provider_message_id="msg/001", + ) + existing_file = workspace_path / relative_path + existing_file.parent.mkdir(parents=True, exist_ok=True) + existing_file.write_bytes(b"original") + store.create_task_artifact( + task_workspace_id=workspace_id, + relative_path=relative_path, + ) + + def fail_fetch(**_kwargs): + raise AssertionError("fetch_gmail_message_raw_bytes should not be called") + + monkeypatch.setattr("alicebot_api.gmail.fetch_gmail_message_raw_bytes", fail_fetch) + + with pytest.raises( + TaskArtifactAlreadyExistsError, + match=f"artifact {relative_path} is already registered for task workspace {workspace_id}", + ): + ingest_gmail_message_record( + store, + user_id=user_id, + request=GmailMessageIngestInput( + gmail_account_id=UUID(account["id"]), + task_workspace_id=workspace_id, + provider_message_id="msg:001", + ), + ) + + assert existing_file.read_bytes() == b"original" + assert store.operations[:2] == [ + ("lock_task_artifacts", workspace_id), + ("get_task_artifact_by_workspace_relative_path_optional", workspace_id), + ] + + +def test_ingest_gmail_message_record_requires_visible_workspace(monkeypatch) -> None: + store = GmailStoreStub() + user_id = uuid4() + account = create_gmail_account_record( + store, + user_id=user_id, + request=GmailAccountConnectInput( + provider_account_id="acct-001", + email_address="owner@example.com", + display_name="Owner", + scope=GMAIL_READONLY_SCOPE, + access_token="token-1", + ), + )["account"] + + monkeypatch.setattr( + "alicebot_api.gmail.fetch_gmail_message_raw_bytes", + lambda **_kwargs: _build_rfc822_email_bytes(plain_body="hello"), + ) + + with pytest.raises(TaskWorkspaceNotFoundError, match="task workspace .* was not found"): + ingest_gmail_message_record( + store, + user_id=user_id, + request=GmailMessageIngestInput( + gmail_account_id=UUID(account["id"]), + task_workspace_id=uuid4(), + provider_message_id="msg-001", + ), + ) diff --git a/tests/unit/test_gmail_main.py b/tests/unit/test_gmail_main.py new file mode 100644 index 0000000..b711e65 --- /dev/null +++ b/tests/unit/test_gmail_main.py @@ -0,0 +1,192 @@ +from 
__future__ import annotations + +import json +from contextlib import contextmanager +from uuid import uuid4 + +import apps.api.src.alicebot_api.main as main_module +from apps.api.src.alicebot_api.config import Settings +from alicebot_api.gmail import ( + GmailAccountAlreadyExistsError, + GmailAccountNotFoundError, + GmailMessageFetchError, + GmailMessageNotFoundError, + GmailMessageUnsupportedError, +) +from alicebot_api.workspaces import TaskWorkspaceNotFoundError + + +def test_list_gmail_accounts_endpoint_returns_payload(monkeypatch) -> None: + user_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr( + main_module, + "list_gmail_account_records", + lambda *_args, **_kwargs: { + "items": [], + "summary": {"total_count": 0, "order": ["created_at_asc", "id_asc"]}, + }, + ) + + response = main_module.list_gmail_accounts(user_id) + + assert response.status_code == 200 + assert json.loads(response.body) == { + "items": [], + "summary": {"total_count": 0, "order": ["created_at_asc", "id_asc"]}, + } + + +def test_connect_gmail_account_endpoint_maps_duplicate_to_409(monkeypatch) -> None: + user_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_create_gmail_account_record(*_args, **_kwargs): + raise GmailAccountAlreadyExistsError("gmail account acct-001 is already connected") + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "create_gmail_account_record", fake_create_gmail_account_record) + + response = main_module.connect_gmail_account( + main_module.ConnectGmailAccountRequest( + user_id=user_id, + provider_account_id="acct-001", + email_address="owner@example.com", + display_name="Owner", + access_token="token-1", + ) + ) + + assert response.status_code == 409 + assert json.loads(response.body) == { + "detail": "gmail account acct-001 is already connected" + } + + +def test_get_gmail_account_endpoint_maps_not_found_to_404(monkeypatch) -> None: + user_id = uuid4() + gmail_account_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_get_gmail_account_record(*_args, **_kwargs): + raise GmailAccountNotFoundError(f"gmail account {gmail_account_id} was not found") + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "get_gmail_account_record", fake_get_gmail_account_record) + + response = main_module.get_gmail_account(gmail_account_id, user_id) + + assert response.status_code == 404 + assert json.loads(response.body) == {"detail": f"gmail account {gmail_account_id} was not found"} + + +def test_ingest_gmail_message_endpoint_maps_workspace_not_found_to_404(monkeypatch) -> None: + user_id = uuid4() + gmail_account_id = uuid4() + task_workspace_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_ingest_gmail_message_record(*_args, **_kwargs): + raise 
TaskWorkspaceNotFoundError(f"task workspace {task_workspace_id} was not found") + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "ingest_gmail_message_record", fake_ingest_gmail_message_record) + + response = main_module.ingest_gmail_message( + gmail_account_id, + "msg-001", + main_module.IngestGmailMessageRequest( + user_id=user_id, + task_workspace_id=task_workspace_id, + ), + ) + + assert response.status_code == 404 + assert json.loads(response.body) == { + "detail": f"task workspace {task_workspace_id} was not found" + } + + +def test_ingest_gmail_message_endpoint_maps_upstream_errors(monkeypatch) -> None: + user_id = uuid4() + gmail_account_id = uuid4() + task_workspace_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + + def fake_missing(*_args, **_kwargs): + raise GmailMessageNotFoundError("gmail message msg-001 was not found") + + monkeypatch.setattr(main_module, "ingest_gmail_message_record", fake_missing) + response = main_module.ingest_gmail_message( + gmail_account_id, + "msg-001", + main_module.IngestGmailMessageRequest( + user_id=user_id, + task_workspace_id=task_workspace_id, + ), + ) + assert response.status_code == 404 + assert json.loads(response.body) == {"detail": "gmail message msg-001 was not found"} + + def fake_unsupported(*_args, **_kwargs): + raise GmailMessageUnsupportedError("gmail message msg-001 is not a supported RFC822 email") + + monkeypatch.setattr(main_module, "ingest_gmail_message_record", fake_unsupported) + response = main_module.ingest_gmail_message( + gmail_account_id, + "msg-001", + main_module.IngestGmailMessageRequest( + user_id=user_id, + task_workspace_id=task_workspace_id, + ), + ) + assert response.status_code == 400 + assert json.loads(response.body) == { + "detail": "gmail message msg-001 is not a supported RFC822 email" + } + + def fake_fetch_error(*_args, **_kwargs): + raise GmailMessageFetchError("gmail message msg-001 could not be fetched") + + monkeypatch.setattr(main_module, "ingest_gmail_message_record", fake_fetch_error) + response = main_module.ingest_gmail_message( + gmail_account_id, + "msg-001", + main_module.IngestGmailMessageRequest( + user_id=user_id, + task_workspace_id=task_workspace_id, + ), + ) + assert response.status_code == 502 + assert json.loads(response.body) == { + "detail": "gmail message msg-001 could not be fetched" + } From 93637e2abfe8a6c5ca418a36f0610ae8f7d0b778 Mon Sep 17 00:00:00 2001 From: Sam Date: Mon, 16 Mar 2026 18:18:18 +0100 Subject: [PATCH 016/135] Sprint 5P: Gmail credential hardening (#16) * Sprint 5P: Gmail credential hardening packet * Sprint 5P: Gmail credential hardening --------- Co-authored-by: Sami Rusani --- ARCHITECTURE.md | 30 ++- RULES.md | 1 + ...20260316_0027_gmail_account_credentials.py | 128 +++++++++ apps/api/src/alicebot_api/contracts.py | 1 + apps/api/src/alicebot_api/gmail.py | 67 ++++- apps/api/src/alicebot_api/main.py | 4 + apps/api/src/alicebot_api/store.py | 76 +++++- tests/integration/test_gmail_accounts_api.py | 99 +++++++ tests/integration/test_migrations.py | 99 +++++++ ...20260316_0027_gmail_account_credentials.py | 52 ++++ tests/unit/test_gmail.py | 249 +++++++++++++++++- 
tests/unit/test_gmail_main.py | 40 +++ 12 files changed, 817 insertions(+), 29 deletions(-) create mode 100644 apps/api/alembic/versions/20260316_0027_gmail_account_credentials.py create mode 100644 tests/unit/test_20260316_0027_gmail_account_credentials.py diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md index 007de79..6a76ea1 100644 --- a/ARCHITECTURE.md +++ b/ARCHITECTURE.md @@ -2,17 +2,17 @@ ## Current Implemented Slice -AliceBot now implements the accepted repo slice through Sprint 5O. The shipped backend includes: +AliceBot now implements the accepted repo slice through Sprint 5P. The shipped backend includes: - foundation continuity storage over `users`, `threads`, `sessions`, and append-only `events` - deterministic tracing and context compilation over durable continuity, memory, entity, and entity-edge records - governed memory admission, explicit-preference extraction, memory review labels, review queue reads, evaluation summary reads, explicit embedding config and memory-embedding storage, direct semantic retrieval, and deterministic hybrid compile-path memory merge - deterministic prompt assembly and one no-tools response path that persists assistant replies as immutable continuity events - user-scoped consents, policies, policy evaluation, tool registry, allowlist evaluation, tool routing, approval request persistence, approval resolution, approved-only proxy execution through the in-process `proxy.echo` handler, durable execution review, and execution-budget lifecycle plus enforcement -- a narrow read-only Gmail connector seam with user-scoped `gmail_accounts` persistence, deterministic account reads, and one explicit selected-message ingestion path that materializes one Gmail message as a rooted `.eml` task artifact and then reuses the existing RFC822 artifact ingestion pipeline +- a narrow read-only Gmail connector seam with user-scoped `gmail_accounts` metadata persistence, separate user-scoped `gmail_account_credentials` protected credential storage, deterministic account reads without secret exposure, and one explicit selected-message ingestion path that materializes one Gmail message as a rooted `.eml` task artifact and then reuses the existing RFC822 artifact ingestion pipeline - durable `tasks`, `task_steps`, `task_workspaces`, `task_artifacts`, `task_artifact_chunks`, and `task_artifact_chunk_embeddings`, deterministic task-step sequencing, explicit task-step transitions, explicit manual continuation with lineage through `parent_step_id`, `source_approval_id`, and `source_execution_id`, explicit `tool_executions.task_step_id` linkage for execution synchronization, deterministic rooted local task-workspace provisioning, explicit rooted local artifact registration, deterministic local plain-text, markdown, narrow PDF text, narrow DOCX text, and narrow RFC822 email text ingestion into durable chunk rows, deterministic lexical artifact-chunk retrieval over durable chunk rows, explicit user-scoped artifact-chunk embedding persistence tied to existing embedding configs, explicit task-scoped or artifact-scoped semantic artifact-chunk retrieval over those durable embeddings, and compile-path artifact retrieval that can include lexical results, semantic results, or one deterministic hybrid lexical-plus-semantic merged artifact section with per-chunk source provenance -The current multi-step boundary is narrow and explicit. Manual continuation is implemented and review-passed. 
Approval resolution and proxy execution now both use explicit task-step linkage rather than first-step inference. Task workspaces are now implemented only as deterministic rooted local boundaries, and task artifacts are now implemented only as explicit rooted local-file registrations, narrow deterministic artifact ingestion under those workspaces, lexical retrieval over persisted chunk rows, explicit artifact-chunk embedding storage tied to existing embedding configs, direct semantic retrieval over those durable artifact-chunk embeddings for one visible task or one visible artifact at a time, and compile-path artifact retrieval that deterministically merges lexical and semantic candidates into one artifact section when both are requested for the same scope. The live richer-document boundary is still intentionally narrow: plain text and markdown ingest directly, PDF support is limited to narrow local text extraction, DOCX support is limited to narrow local text extraction from `word/document.xml`, RFC822 email support is limited to top-level selected headers plus extractable plain-text body content while excluding nested `message/rfc822` content, and the live connector boundary is limited to one read-only Gmail account seam plus one explicit selected-message ingestion path into the rooted RFC822 artifact pipeline. OCR, image extraction, layout reconstruction, Gmail search, mailbox sync, attachments, Calendar connectors, reranking beyond the current lexical-first hybrid merge, and new side-effect surfaces are still planned later and must not be described as live behavior. +The current multi-step boundary is narrow and explicit. Manual continuation is implemented and review-passed. Approval resolution and proxy execution now both use explicit task-step linkage rather than first-step inference. Task workspaces are now implemented only as deterministic rooted local boundaries, and task artifacts are now implemented only as explicit rooted local-file registrations, narrow deterministic artifact ingestion under those workspaces, lexical retrieval over persisted chunk rows, explicit artifact-chunk embedding storage tied to existing embedding configs, direct semantic retrieval over those durable artifact-chunk embeddings for one visible task or one visible artifact at a time, and compile-path artifact retrieval that deterministically merges lexical and semantic candidates into one artifact section when both are requested for the same scope. The live richer-document boundary is still intentionally narrow: plain text and markdown ingest directly, PDF support is limited to narrow local text extraction, DOCX support is limited to narrow local text extraction from `word/document.xml`, RFC822 email support is limited to top-level selected headers plus extractable plain-text body content while excluding nested `message/rfc822` content, and the live connector boundary is limited to one read-only Gmail account seam plus one explicit selected-message ingestion path into the rooted RFC822 artifact pipeline, with credentials resolved through a dedicated protected credential seam instead of the normal account metadata table surface. OCR, image extraction, layout reconstruction, Gmail search, mailbox sync, attachments, Calendar connectors, reranking beyond the current lexical-first hybrid merge, and new side-effect surfaces are still planned later and must not be described as live behavior. ## Implemented Now @@ -38,7 +38,7 @@ The current multi-step boundary is narrow and explicit. 
Manual continuation is i - memory and retrieval tables: `memories`, `memory_revisions`, `memory_review_labels`, `embedding_configs`, `memory_embeddings` - graph tables: `entities`, `entity_edges` - governance tables: `consents`, `policies`, `tools`, `approvals`, `tool_executions`, `execution_budgets` - - connector tables: `gmail_accounts` + - connector tables: `gmail_accounts`, `gmail_account_credentials` - task lifecycle tables: `tasks`, `task_steps`, `task_workspaces`, `task_artifacts`, `task_artifact_chunks`, `task_artifact_chunk_embeddings` - `events`, `trace_events`, and `memory_revisions` are append-only by application contract and database enforcement. - `memory_review_labels` are append-only by database enforcement. @@ -64,7 +64,7 @@ The current multi-step boundary is narrow and explicit. Manual continuation is i - `apps/web`: minimal shell only; no shipped workflow UI. - `workers`: scaffold only; no background jobs or runner logic are implemented. - `infra`: local development bootstrap assets only. -- `tests`: unit and Postgres-backed integration coverage for the shipped seams above, including Sprint 4O task-step lineage/manual continuation, Sprint 4S step-linked execution synchronization, Sprint 5A task-workspace provisioning, Sprint 5C task-artifact registration, Sprint 5D local artifact ingestion plus chunk reads, Sprint 5E lexical artifact-chunk retrieval, Sprint 5F compile-path artifact chunk integration, Sprint 5G artifact-chunk embedding persistence and reads, Sprint 5H direct semantic artifact-chunk retrieval, Sprint 5I compile-path semantic artifact retrieval, Sprint 5J deterministic hybrid lexical-plus-semantic artifact merge in compile, Sprint 5L narrow PDF artifact ingestion, Sprint 5M narrow DOCX artifact ingestion, Sprint 5N narrow RFC822 email artifact ingestion, and Sprint 5O read-only Gmail account plus single-message ingestion coverage. +- `tests`: unit and Postgres-backed integration coverage for the shipped seams above, including Sprint 4O task-step lineage/manual continuation, Sprint 4S step-linked execution synchronization, Sprint 5A task-workspace provisioning, Sprint 5C task-artifact registration, Sprint 5D local artifact ingestion plus chunk reads, Sprint 5E lexical artifact-chunk retrieval, Sprint 5F compile-path artifact chunk integration, Sprint 5G artifact-chunk embedding persistence and reads, Sprint 5H direct semantic artifact-chunk retrieval, Sprint 5I compile-path semantic artifact retrieval, Sprint 5J deterministic hybrid lexical-plus-semantic artifact merge in compile, Sprint 5L narrow PDF artifact ingestion, Sprint 5M narrow DOCX artifact ingestion, Sprint 5N narrow RFC822 email artifact ingestion, Sprint 5O read-only Gmail account plus single-message ingestion coverage, and Sprint 5P Gmail credential hardening coverage. ## Core Flows Implemented Now @@ -95,15 +95,17 @@ The current multi-step boundary is narrow and explicit. Manual continuation is i ### Narrow Gmail Connector 1. Accept a user-scoped `POST /v0/gmail-accounts` request for one read-only Gmail account metadata record. -2. Persist exactly the narrow connector metadata currently required for reads later: `provider_account_id`, `email_address`, optional `display_name`, the fixed Gmail read-only scope, and one access token. -3. Expose deterministic user-scoped Gmail account list and detail reads. -4. Accept a user-scoped `POST /v0/gmail-accounts/{gmail_account_id}/messages/{provider_message_id}/ingest` request for one visible Gmail account and one visible task workspace. -5. 
Derive one deterministic workspace-relative artifact path as `gmail/<provider_account_id>/<provider_message_id>.eml`. -6. Reject duplicate `(task_workspace_id, relative_path)` collisions before any Gmail fetch or file write. -7. Fetch exactly one selected Gmail message through the read-only Gmail API path `users/me/messages/{provider_message_id}?format=raw`. -8. Require Gmail to return RFC822 `raw` content, validate it against the existing narrow `message/rfc822` extraction rules, and reject unsupported content deterministically. -9. Materialize the message as one rooted `.eml` file inside the selected task workspace and then reuse the existing task-artifact registration plus artifact-ingestion seam. -10. Persist only the resulting `task_artifacts` and `task_artifact_chunks` rows; account-wide sync, search, attachments, Calendar, and write-capable actions remain out of scope. +2. Persist exactly the narrow connector metadata required for later reads on `gmail_accounts`: `provider_account_id`, `email_address`, optional `display_name`, and the fixed Gmail read-only scope. +3. Persist the Gmail access token only in the dedicated `gmail_account_credentials` protected credential seam bound to the same visible user/account ownership scope. +4. Expose deterministic user-scoped Gmail account list and detail reads without secret material. +5. Accept a user-scoped `POST /v0/gmail-accounts/{gmail_account_id}/messages/{provider_message_id}/ingest` request for one visible Gmail account and one visible task workspace. +6. Resolve the Gmail access token through the protected credential seam before any Gmail fetch, file write, or artifact registration. +7. Derive one deterministic workspace-relative artifact path as `gmail/<provider_account_id>/<provider_message_id>.eml`. +8. Reject duplicate `(task_workspace_id, relative_path)` collisions before any Gmail fetch or file write. +9. Fetch exactly one selected Gmail message through the read-only Gmail API path `users/me/messages/{provider_message_id}?format=raw`. +10. Require Gmail to return RFC822 `raw` content, validate it against the existing narrow `message/rfc822` extraction rules, and reject unsupported content deterministically. +11. Materialize the message as one rooted `.eml` file inside the selected task workspace and then reuse the existing task-artifact registration plus artifact-ingestion seam. +12. Persist only the resulting `task_artifacts` and `task_artifact_chunks` rows; account-wide sync, search, attachments, Calendar, and write-capable actions remain out of scope. ### Governed Memory And Retrieval diff --git a/RULES.md b/RULES.md index ad493d8..5956d87 100644 --- a/RULES.md +++ b/RULES.md @@ -24,6 +24,7 @@ - Treat Postgres as the v1 system of record unless measured constraints justify a change. - Task-step lineage and execution linkage must stay explicit; do not reconstruct them heuristically from broader task history. - Enforce row-level security on every user-owned table. +- Connector secrets must not be stored on normal metadata tables or exposed on read surfaces; they must use a dedicated protected storage seam. - Default memory admission to `NOOP`; promote only evidence-backed changes and preserve revision history for non-`NOOP` updates. - Apply domain and sensitivity filters before semantic retrieval.
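The connect/ingest flow and the new connector-secrets rule above amount to one split: account metadata without secrets on the normal read surface, plus a separate protected credential record that is resolved and validated before any Gmail fetch. A minimal illustrative sketch of that split follows; it is not the shipped `store`/`gmail` seam, and `InMemoryConnectorStore` and its methods are hypothetical stand-ins (only the `credential_kind` value mirrors the contract constant in this patch).

```python
# Illustrative sketch only; assumes a hypothetical in-memory store, not the shipped ContinuityStore.
from dataclasses import dataclass, field
from uuid import UUID, uuid4

GMAIL_PROTECTED_CREDENTIAL_KIND = "gmail_oauth_access_token_v1"


@dataclass
class InMemoryConnectorStore:
    accounts: dict[UUID, dict] = field(default_factory=dict)
    credentials: dict[UUID, dict] = field(default_factory=dict)

    def connect_account(self, *, provider_account_id: str, email_address: str, access_token: str) -> UUID:
        account_id = uuid4()
        # Metadata surface: no secret material is stored or returned here.
        self.accounts[account_id] = {
            "provider_account_id": provider_account_id,
            "email_address": email_address,
        }
        # Protected credential seam: the token lives only in this separate record.
        self.credentials[account_id] = {
            "credential_kind": GMAIL_PROTECTED_CREDENTIAL_KIND,
            "access_token": access_token,
        }
        return account_id

    def resolve_access_token(self, account_id: UUID) -> str:
        # Resolve and validate the protected credential before any fetch or file write.
        credential = self.credentials.get(account_id)
        if credential is None:
            raise LookupError(f"gmail account {account_id} is missing protected credentials")
        token = credential.get("access_token")
        if credential.get("credential_kind") != GMAIL_PROTECTED_CREDENTIAL_KIND or not token:
            raise ValueError(f"gmail account {account_id} has invalid protected credentials")
        return token


store = InMemoryConnectorStore()
account_id = store.connect_account(
    provider_account_id="acct-001",
    email_address="owner@example.com",
    access_token="token-1",
)
assert "access_token" not in store.accounts[account_id]
assert store.resolve_access_token(account_id) == "token-1"
```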
diff --git a/apps/api/alembic/versions/20260316_0027_gmail_account_credentials.py b/apps/api/alembic/versions/20260316_0027_gmail_account_credentials.py new file mode 100644 index 0000000..8f5dd87 --- /dev/null +++ b/apps/api/alembic/versions/20260316_0027_gmail_account_credentials.py @@ -0,0 +1,128 @@ +"""Move Gmail access tokens into a protected credential table.""" + +from __future__ import annotations + +from alembic import op + + +revision = "20260316_0027" +down_revision = "20260316_0026" +branch_labels = None +depends_on = None + +GMAIL_AUTH_KIND_OAUTH_ACCESS_TOKEN = "oauth_access_token" +GMAIL_PROTECTED_CREDENTIAL_KIND = "gmail_oauth_access_token_v1" + +_RLS_TABLES = ("gmail_account_credentials",) + +_UPGRADE_SCHEMA_STATEMENT = f""" + CREATE TABLE gmail_account_credentials ( + gmail_account_id uuid PRIMARY KEY REFERENCES gmail_accounts(id) ON DELETE CASCADE, + user_id uuid NOT NULL, + auth_kind text NOT NULL, + credential_blob jsonb NOT NULL, + created_at timestamptz NOT NULL DEFAULT now(), + updated_at timestamptz NOT NULL DEFAULT now(), + FOREIGN KEY (gmail_account_id, user_id) + REFERENCES gmail_accounts (id, user_id) + ON DELETE CASCADE, + CONSTRAINT gmail_account_credentials_auth_kind_check + CHECK (auth_kind = '{GMAIL_AUTH_KIND_OAUTH_ACCESS_TOKEN}'), + CONSTRAINT gmail_account_credentials_blob_shape_check + CHECK ( + jsonb_typeof(credential_blob) = 'object' + AND credential_blob ? 'credential_kind' + AND credential_blob ? 'access_token' + AND credential_blob ->> 'credential_kind' = '{GMAIL_PROTECTED_CREDENTIAL_KIND}' + AND jsonb_typeof(credential_blob -> 'access_token') = 'string' + AND length(credential_blob ->> 'access_token') > 0 + ) + ); + + CREATE INDEX gmail_account_credentials_user_created_idx + ON gmail_account_credentials (user_id, created_at, gmail_account_id); + """ + +_UPGRADE_BACKFILL_STATEMENT = f""" + INSERT INTO gmail_account_credentials ( + gmail_account_id, + user_id, + auth_kind, + credential_blob, + created_at, + updated_at + ) + SELECT + id, + user_id, + '{GMAIL_AUTH_KIND_OAUTH_ACCESS_TOKEN}', + jsonb_build_object( + 'credential_kind', '{GMAIL_PROTECTED_CREDENTIAL_KIND}', + 'access_token', access_token + ), + created_at, + updated_at + FROM gmail_accounts; + """ + +_UPGRADE_DROP_PLAINTEXT_STATEMENTS = ( + "ALTER TABLE gmail_accounts DROP CONSTRAINT gmail_accounts_access_token_nonempty_check", + "ALTER TABLE gmail_accounts DROP COLUMN access_token", +) + +_UPGRADE_GRANT_STATEMENTS = ( + "GRANT SELECT, INSERT ON gmail_account_credentials TO alicebot_app", +) + +_UPGRADE_POLICY_STATEMENT = """ + CREATE POLICY gmail_account_credentials_is_owner ON gmail_account_credentials + USING (user_id = app.current_user_id()) + WITH CHECK (user_id = app.current_user_id()); + """ + +_DOWNGRADE_ADD_PLAINTEXT_STATEMENTS = ( + "ALTER TABLE gmail_accounts ADD COLUMN access_token text", +) + +_DOWNGRADE_BACKFILL_STATEMENT = """ + UPDATE gmail_accounts AS accounts + SET access_token = credentials.credential_blob ->> 'access_token' + FROM gmail_account_credentials AS credentials + WHERE credentials.gmail_account_id = accounts.id + """ + +_DOWNGRADE_RESTORE_CONSTRAINT_STATEMENTS = ( + "ALTER TABLE gmail_accounts ALTER COLUMN access_token SET NOT NULL", + """ + ALTER TABLE gmail_accounts + ADD CONSTRAINT gmail_accounts_access_token_nonempty_check + CHECK (length(access_token) > 0) + """, + "DROP TABLE IF EXISTS gmail_account_credentials", +) + + +def _execute_statements(statements: tuple[str, ...]) -> None: + for statement in statements: + op.execute(statement) + + +def 
_enable_row_level_security() -> None: + for table_name in _RLS_TABLES: + op.execute(f"ALTER TABLE {table_name} ENABLE ROW LEVEL SECURITY") + op.execute(f"ALTER TABLE {table_name} FORCE ROW LEVEL SECURITY") + + +def upgrade() -> None: + op.execute(_UPGRADE_SCHEMA_STATEMENT) + op.execute(_UPGRADE_BACKFILL_STATEMENT) + _execute_statements(_UPGRADE_DROP_PLAINTEXT_STATEMENTS) + _execute_statements(_UPGRADE_GRANT_STATEMENTS) + _enable_row_level_security() + op.execute(_UPGRADE_POLICY_STATEMENT) + + +def downgrade() -> None: + _execute_statements(_DOWNGRADE_ADD_PLAINTEXT_STATEMENTS) + op.execute(_DOWNGRADE_BACKFILL_STATEMENT) + _execute_statements(_DOWNGRADE_RESTORE_CONSTRAINT_STATEMENTS) diff --git a/apps/api/src/alicebot_api/contracts.py b/apps/api/src/alicebot_api/contracts.py index d614b16..035273e 100644 --- a/apps/api/src/alicebot_api/contracts.py +++ b/apps/api/src/alicebot_api/contracts.py @@ -178,6 +178,7 @@ GMAIL_PROVIDER = "gmail" GMAIL_AUTH_KIND_OAUTH_ACCESS_TOKEN = "oauth_access_token" GMAIL_READONLY_SCOPE = "https://www.googleapis.com/auth/gmail.readonly" +GMAIL_PROTECTED_CREDENTIAL_KIND = "gmail_oauth_access_token_v1" TASK_STEP_SEQUENCE_VERSION_V0 = "task_step_sequence_v0" TRACE_KIND_TASK_STEP_SEQUENCE = "task.step.sequence" TASK_STEP_CONTINUATION_VERSION_V0 = "task_step_continuation_v0" diff --git a/apps/api/src/alicebot_api/gmail.py b/apps/api/src/alicebot_api/gmail.py index 6243916..9d908d1 100644 --- a/apps/api/src/alicebot_api/gmail.py +++ b/apps/api/src/alicebot_api/gmail.py @@ -23,6 +23,7 @@ from alicebot_api.contracts import ( GMAIL_ACCOUNT_LIST_ORDER, GMAIL_AUTH_KIND_OAUTH_ACCESS_TOKEN, + GMAIL_PROTECTED_CREDENTIAL_KIND, GMAIL_PROVIDER, GMAIL_READONLY_SCOPE, GmailAccountConnectInput, @@ -63,6 +64,14 @@ class GmailMessageFetchError(RuntimeError): """Raised when the Gmail API call fails for non-deterministic upstream reasons.""" +class GmailCredentialNotFoundError(RuntimeError): + """Raised when Gmail protected credentials are missing for a visible account.""" + + +class GmailCredentialInvalidError(RuntimeError): + """Raised when Gmail protected credentials are malformed for a visible account.""" + + def serialize_gmail_account_row(row: GmailAccountRow) -> GmailAccountRecord: return { "id": str(row["id"]), @@ -77,6 +86,49 @@ def serialize_gmail_account_row(row: GmailAccountRow) -> GmailAccountRecord: } +def build_gmail_protected_credential_blob(*, access_token: str) -> dict[str, str]: + return { + "credential_kind": GMAIL_PROTECTED_CREDENTIAL_KIND, + "access_token": access_token, + } + + +def resolve_gmail_access_token( + store: ContinuityStore, + *, + gmail_account_id: UUID, +) -> str: + credential = store.get_gmail_account_credential_optional(gmail_account_id) + if credential is None: + raise GmailCredentialNotFoundError( + f"gmail account {gmail_account_id} is missing protected credentials" + ) + + if credential["auth_kind"] != GMAIL_AUTH_KIND_OAUTH_ACCESS_TOKEN: + raise GmailCredentialInvalidError( + f"gmail account {gmail_account_id} has invalid protected credentials" + ) + + credential_blob = credential["credential_blob"] + if not isinstance(credential_blob, dict): + raise GmailCredentialInvalidError( + f"gmail account {gmail_account_id} has invalid protected credentials" + ) + + credential_kind = credential_blob.get("credential_kind") + access_token = credential_blob.get("access_token") + if ( + credential_kind != GMAIL_PROTECTED_CREDENTIAL_KIND + or not isinstance(access_token, str) + or access_token == "" + ): + raise GmailCredentialInvalidError( + f"gmail account 
{gmail_account_id} has invalid protected credentials" + ) + + return access_token + + def create_gmail_account_record( store: ContinuityStore, *, @@ -97,7 +149,13 @@ def create_gmail_account_record( email_address=request.email_address, display_name=request.display_name, scope=request.scope, - access_token=request.access_token, + ) + store.create_gmail_account_credential( + gmail_account_id=row["id"], + auth_kind=GMAIL_AUTH_KIND_OAUTH_ACCESS_TOKEN, + credential_blob=build_gmail_protected_credential_blob( + access_token=request.access_token, + ), ) except psycopg.errors.UniqueViolation as exc: raise GmailAccountAlreadyExistsError( @@ -215,6 +273,11 @@ def ingest_gmail_message_record( f"task workspace {request.task_workspace_id} was not found" ) + access_token = resolve_gmail_access_token( + store, + gmail_account_id=request.gmail_account_id, + ) + store.lock_task_artifacts(workspace["id"]) relative_path = build_gmail_message_artifact_relative_path( provider_account_id=account["provider_account_id"], @@ -230,7 +293,7 @@ def ingest_gmail_message_record( ) raw_bytes = fetch_gmail_message_raw_bytes( - access_token=account["access_token"], + access_token=access_token, provider_message_id=request.provider_message_id, ) diff --git a/apps/api/src/alicebot_api/main.py b/apps/api/src/alicebot_api/main.py index c5a224b..9fbe271 100644 --- a/apps/api/src/alicebot_api/main.py +++ b/apps/api/src/alicebot_api/main.py @@ -140,6 +140,8 @@ ) from alicebot_api.gmail import ( GmailAccountAlreadyExistsError, + GmailCredentialInvalidError, + GmailCredentialNotFoundError, GmailAccountNotFoundError, GmailMessageFetchError, GmailMessageNotFoundError, @@ -1373,6 +1375,8 @@ def ingest_gmail_message( return JSONResponse(status_code=404, content={"detail": str(exc)}) except GmailMessageUnsupportedError as exc: return JSONResponse(status_code=400, content={"detail": str(exc)}) + except (GmailCredentialNotFoundError, GmailCredentialInvalidError) as exc: + return JSONResponse(status_code=409, content={"detail": str(exc)}) except TaskArtifactValidationError as exc: return JSONResponse(status_code=400, content={"detail": str(exc)}) except GmailMessageFetchError as exc: diff --git a/apps/api/src/alicebot_api/store.py b/apps/api/src/alicebot_api/store.py index 4098cda..4822a44 100644 --- a/apps/api/src/alicebot_api/store.py +++ b/apps/api/src/alicebot_api/store.py @@ -252,7 +252,15 @@ class GmailAccountRow(TypedDict): email_address: str display_name: str | None scope: str - access_token: str + created_at: datetime + updated_at: datetime + + +class ProtectedGmailCredentialRow(TypedDict): + gmail_account_id: UUID + user_id: UUID + auth_kind: str + credential_blob: JsonObject created_at: datetime updated_at: datetime @@ -1488,7 +1496,6 @@ class LabelCountRow(TypedDict): email_address, display_name, scope, - access_token, created_at, updated_at ) @@ -1498,7 +1505,6 @@ class LabelCountRow(TypedDict): %s, %s, %s, - %s, clock_timestamp(), clock_timestamp() ) @@ -1509,7 +1515,32 @@ class LabelCountRow(TypedDict): email_address, display_name, scope, - access_token, + created_at, + updated_at + """ + +INSERT_GMAIL_ACCOUNT_CREDENTIAL_SQL = """ + INSERT INTO gmail_account_credentials ( + gmail_account_id, + user_id, + auth_kind, + credential_blob, + created_at, + updated_at + ) + VALUES ( + %s, + app.current_user_id(), + %s, + %s, + clock_timestamp(), + clock_timestamp() + ) + RETURNING + gmail_account_id, + user_id, + auth_kind, + credential_blob, created_at, updated_at """ @@ -1522,7 +1553,6 @@ class LabelCountRow(TypedDict): 
email_address, display_name, scope, - access_token, created_at, updated_at FROM gmail_accounts @@ -1537,7 +1567,6 @@ class LabelCountRow(TypedDict): email_address, display_name, scope, - access_token, created_at, updated_at FROM gmail_accounts @@ -1546,6 +1575,18 @@ class LabelCountRow(TypedDict): LIMIT 1 """ +GET_GMAIL_ACCOUNT_CREDENTIAL_SQL = """ + SELECT + gmail_account_id, + user_id, + auth_kind, + credential_blob, + created_at, + updated_at + FROM gmail_account_credentials + WHERE gmail_account_id = %s + """ + LIST_GMAIL_ACCOUNTS_SQL = """ SELECT id, @@ -1554,7 +1595,6 @@ class LabelCountRow(TypedDict): email_address, display_name, scope, - access_token, created_at, updated_at FROM gmail_accounts @@ -3086,17 +3126,35 @@ def create_gmail_account( email_address: str, display_name: str | None, scope: str, - access_token: str, ) -> GmailAccountRow: return self._fetch_one( "create_gmail_account", INSERT_GMAIL_ACCOUNT_SQL, - (provider_account_id, email_address, display_name, scope, access_token), + (provider_account_id, email_address, display_name, scope), + ) + + def create_gmail_account_credential( + self, + *, + gmail_account_id: UUID, + auth_kind: str, + credential_blob: JsonObject, + ) -> ProtectedGmailCredentialRow: + return self._fetch_one( + "create_gmail_account_credential", + INSERT_GMAIL_ACCOUNT_CREDENTIAL_SQL, + (gmail_account_id, auth_kind, Jsonb(credential_blob)), ) def get_gmail_account_optional(self, gmail_account_id: UUID) -> GmailAccountRow | None: return self._fetch_optional_one(GET_GMAIL_ACCOUNT_SQL, (gmail_account_id,)) + def get_gmail_account_credential_optional( + self, + gmail_account_id: UUID, + ) -> ProtectedGmailCredentialRow | None: + return self._fetch_optional_one(GET_GMAIL_ACCOUNT_CREDENTIAL_SQL, (gmail_account_id,)) + def get_gmail_account_by_provider_account_id_optional( self, provider_account_id: str, diff --git a/tests/integration/test_gmail_accounts_api.py b/tests/integration/test_gmail_accounts_api.py index d9aa1df..334cabd 100644 --- a/tests/integration/test_gmail_accounts_api.py +++ b/tests/integration/test_gmail_accounts_api.py @@ -7,6 +7,7 @@ from uuid import UUID, uuid4 import anyio +import psycopg import apps.api.src.alicebot_api.main as main_module from apps.api.src.alicebot_api.config import Settings @@ -240,6 +241,39 @@ def test_gmail_account_endpoints_connect_list_detail_and_isolate( assert isolated_detail_payload == { "detail": f"gmail account {create_payload['account']['id']} was not found" } + assert '"access_token":' not in json.dumps(create_payload) + assert '"access_token":' not in json.dumps(list_payload) + assert '"access_token":' not in json.dumps(detail_payload) + + with psycopg.connect(migrated_database_urls["admin"]) as conn: + with conn.cursor() as cur: + cur.execute( + """ + SELECT column_name + FROM information_schema.columns + WHERE table_schema = 'public' + AND table_name = 'gmail_accounts' + ORDER BY ordinal_position + """ + ) + gmail_account_columns = {row[0] for row in cur.fetchall()} + assert "access_token" not in gmail_account_columns + cur.execute( + """ + SELECT + auth_kind, + credential_blob ->> 'credential_kind', + credential_blob ->> 'access_token' + FROM gmail_account_credentials + WHERE gmail_account_id = %s + """, + (UUID(create_payload["account"]["id"]),), + ) + assert cur.fetchone() == ( + "oauth_access_token", + "gmail_oauth_access_token_v1", + "token-for-acct-owner-001", + ) def test_gmail_message_ingestion_endpoint_persists_artifact_and_chunks( @@ -379,6 +413,71 @@ def 
test_gmail_message_ingestion_endpoint_rejects_cross_user_workspace_access( } +def test_gmail_message_ingestion_endpoint_rejects_missing_protected_credentials_without_side_effects( + migrated_database_urls, + monkeypatch, + tmp_path, +) -> None: + owner = seed_task(migrated_database_urls["app"], email="owner@example.com") + workspace_root = tmp_path / "task-workspaces" + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings( + database_url=migrated_database_urls["app"], + task_workspace_root=str(workspace_root), + ), + ) + + def fail_fetch(**_kwargs): + raise AssertionError("fetch_gmail_message_raw_bytes should not be called") + + monkeypatch.setattr(gmail_module, "fetch_gmail_message_raw_bytes", fail_fetch) + + _, account_payload = _connect_gmail_account( + user_id=owner["user_id"], + provider_account_id="acct-owner-001", + email_address="owner@gmail.example", + ) + _, workspace_payload = invoke_request( + "POST", + f"/v0/tasks/{owner['task_id']}/workspace", + payload={"user_id": str(owner["user_id"])}, + ) + + with psycopg.connect(migrated_database_urls["admin"]) as conn: + with conn.cursor() as cur: + cur.execute( + "DELETE FROM gmail_account_credentials WHERE gmail_account_id = %s", + (UUID(account_payload["account"]["id"]),), + ) + conn.commit() + + ingest_status, ingest_payload = invoke_request( + "POST", + f"/v0/gmail-accounts/{account_payload['account']['id']}/messages/msg-001/ingest", + payload={ + "user_id": str(owner["user_id"]), + "task_workspace_id": workspace_payload["workspace"]["id"], + }, + ) + + assert ingest_status == 409 + assert ingest_payload == { + "detail": ( + f"gmail account {account_payload['account']['id']} is missing protected credentials" + ) + } + artifact_file = ( + Path(workspace_payload["workspace"]["local_path"]) / "gmail" / "acct-owner-001" / "msg-001.eml" + ) + assert not artifact_file.exists() + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + assert store.list_task_artifacts_for_task(owner["task_id"]) == [] + + def test_gmail_message_ingestion_endpoint_rejects_sanitized_path_collisions_without_overwrite( migrated_database_urls, monkeypatch, diff --git a/tests/integration/test_migrations.py b/tests/integration/test_migrations.py index e8e7011..7e4817d 100644 --- a/tests/integration/test_migrations.py +++ b/tests/integration/test_migrations.py @@ -248,6 +248,105 @@ def test_tool_execution_task_step_linkage_migration_backfills_existing_rows(data assert cur.fetchone() == ("NO",) +def test_gmail_account_credentials_migration_round_trip_preserves_tokens(database_urls): + config = make_alembic_config(database_urls["admin"]) + user_id = "00000000-0000-0000-0000-000000000101" + gmail_account_id = "00000000-0000-0000-0000-000000000102" + + command.upgrade(config, "20260316_0026") + + with psycopg.connect(database_urls["admin"]) as conn: + with conn.cursor() as cur: + cur.execute( + """ + INSERT INTO users (id, email, display_name) + VALUES (%s, 'gmail-migration@example.com', 'Gmail Migration User') + """, + (user_id,), + ) + cur.execute( + """ + INSERT INTO gmail_accounts ( + id, + user_id, + provider_account_id, + email_address, + display_name, + scope, + access_token + ) + VALUES ( + %s, + %s, + 'acct-migration-001', + 'owner@gmail.example', + 'Owner', + 'https://www.googleapis.com/auth/gmail.readonly', + 'token-before-hardening' + ) + """, + (gmail_account_id, user_id), + ) + conn.commit() + + command.upgrade(config, "20260316_0027") + + with 
psycopg.connect(database_urls["admin"]) as conn: + with conn.cursor() as cur: + cur.execute( + """ + SELECT column_name + FROM information_schema.columns + WHERE table_schema = 'public' + AND table_name = 'gmail_accounts' + AND column_name = 'access_token' + """ + ) + assert cur.fetchone() is None + cur.execute( + """ + SELECT + auth_kind, + credential_blob ->> 'credential_kind', + credential_blob ->> 'access_token' + FROM gmail_account_credentials + WHERE gmail_account_id = %s + """, + (gmail_account_id,), + ) + assert cur.fetchone() == ( + "oauth_access_token", + "gmail_oauth_access_token_v1", + "token-before-hardening", + ) + + command.downgrade(config, "20260316_0026") + + with psycopg.connect(database_urls["admin"]) as conn: + with conn.cursor() as cur: + cur.execute( + """ + SELECT column_name + FROM information_schema.columns + WHERE table_schema = 'public' + AND table_name = 'gmail_accounts' + AND column_name = 'access_token' + """ + ) + assert cur.fetchone() == ("access_token",) + cur.execute( + """ + SELECT access_token + FROM gmail_accounts + WHERE id = %s + """, + (gmail_account_id,), + ) + assert cur.fetchone() == ("token-before-hardening",) + cur.execute("SELECT to_regclass('public.gmail_account_credentials')") + assert cur.fetchone() == (None,) + + def test_migrations_upgrade_and_downgrade(database_urls): config = make_alembic_config(database_urls["admin"]) diff --git a/tests/unit/test_20260316_0027_gmail_account_credentials.py b/tests/unit/test_20260316_0027_gmail_account_credentials.py new file mode 100644 index 0000000..ef80e31 --- /dev/null +++ b/tests/unit/test_20260316_0027_gmail_account_credentials.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +import importlib + + +MODULE_NAME = "apps.api.alembic.versions.20260316_0027_gmail_account_credentials" + + +def load_migration_module(): + return importlib.import_module(MODULE_NAME) + + +def test_upgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.upgrade() + + assert executed == [ + module._UPGRADE_SCHEMA_STATEMENT, + module._UPGRADE_BACKFILL_STATEMENT, + *module._UPGRADE_DROP_PLAINTEXT_STATEMENTS, + *module._UPGRADE_GRANT_STATEMENTS, + "ALTER TABLE gmail_account_credentials ENABLE ROW LEVEL SECURITY", + "ALTER TABLE gmail_account_credentials FORCE ROW LEVEL SECURITY", + module._UPGRADE_POLICY_STATEMENT, + ] + + +def test_downgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.downgrade() + + assert executed == [ + *module._DOWNGRADE_ADD_PLAINTEXT_STATEMENTS, + module._DOWNGRADE_BACKFILL_STATEMENT, + *module._DOWNGRADE_RESTORE_CONSTRAINT_STATEMENTS, + ] + + +def test_gmail_account_credential_privileges_allow_only_expected_runtime_writes() -> None: + module = load_migration_module() + + assert module._UPGRADE_GRANT_STATEMENTS == ( + "GRANT SELECT, INSERT ON gmail_account_credentials TO alicebot_app", + ) diff --git a/tests/unit/test_gmail.py b/tests/unit/test_gmail.py index d602b5c..b84c945 100644 --- a/tests/unit/test_gmail.py +++ b/tests/unit/test_gmail.py @@ -7,16 +7,25 @@ import pytest from alicebot_api.artifacts import TaskArtifactAlreadyExistsError -from alicebot_api.contracts import GMAIL_READONLY_SCOPE, GmailAccountConnectInput, GmailMessageIngestInput +from alicebot_api.contracts import ( + 
GMAIL_PROTECTED_CREDENTIAL_KIND, + GMAIL_READONLY_SCOPE, + GmailAccountConnectInput, + GmailMessageIngestInput, +) from alicebot_api.gmail import ( GmailAccountAlreadyExistsError, GmailAccountNotFoundError, + GmailCredentialInvalidError, + GmailCredentialNotFoundError, GmailMessageUnsupportedError, build_gmail_message_artifact_relative_path, + build_gmail_protected_credential_blob, create_gmail_account_record, get_gmail_account_record, ingest_gmail_message_record, list_gmail_account_records, + resolve_gmail_access_token, ) from alicebot_api.workspaces import TaskWorkspaceNotFoundError @@ -41,6 +50,7 @@ class GmailStoreStub: def __init__(self) -> None: self.base_time = datetime(2026, 3, 16, 10, 0, tzinfo=UTC) self.gmail_accounts: list[dict[str, object]] = [] + self.gmail_account_credentials: dict[UUID, dict[str, object]] = {} self.task_workspaces: list[dict[str, object]] = [] self.task_artifacts: list[dict[str, object]] = [] self.operations: list[tuple[str, object]] = [] @@ -52,7 +62,6 @@ def create_gmail_account( email_address: str, display_name: str | None, scope: str, - access_token: str, ) -> dict[str, object]: row = { "id": uuid4(), @@ -61,19 +70,48 @@ def create_gmail_account( "email_address": email_address, "display_name": display_name, "scope": scope, - "access_token": access_token, "created_at": self.base_time + timedelta(minutes=len(self.gmail_accounts)), "updated_at": self.base_time + timedelta(minutes=len(self.gmail_accounts)), } self.gmail_accounts.append(row) return row + def create_gmail_account_credential( + self, + *, + gmail_account_id: UUID, + auth_kind: str, + credential_blob: dict[str, object], + ) -> dict[str, object]: + row = { + "gmail_account_id": gmail_account_id, + "user_id": next( + account["user_id"] + for account in self.gmail_accounts + if account["id"] == gmail_account_id + ), + "auth_kind": auth_kind, + "credential_blob": credential_blob, + "created_at": self.base_time + timedelta(minutes=len(self.gmail_account_credentials)), + "updated_at": self.base_time + timedelta(minutes=len(self.gmail_account_credentials)), + } + self.gmail_account_credentials[gmail_account_id] = row + self.operations.append(("create_gmail_account_credential", gmail_account_id)) + return row + def get_gmail_account_optional(self, gmail_account_id: UUID) -> dict[str, object] | None: return next( (row for row in self.gmail_accounts if row["id"] == gmail_account_id), None, ) + def get_gmail_account_credential_optional( + self, + gmail_account_id: UUID, + ) -> dict[str, object] | None: + self.operations.append(("get_gmail_account_credential_optional", gmail_account_id)) + return self.gmail_account_credentials.get(gmail_account_id) + def get_gmail_account_by_provider_account_id_optional( self, provider_account_id: str, @@ -192,6 +230,45 @@ def test_create_list_and_get_gmail_account_records_are_deterministic() -> None: user_id=user_id, gmail_account_id=UUID(second["account"]["id"]), ) == {"account": second["account"]} + assert "access_token" not in first["account"] + assert "access_token" not in second["account"] + + +def test_create_gmail_account_record_persists_protected_credential_and_hides_secret() -> None: + store = GmailStoreStub() + user_id = uuid4() + + response = create_gmail_account_record( + store, + user_id=user_id, + request=GmailAccountConnectInput( + provider_account_id="acct-001", + email_address="owner@example.com", + display_name="Owner", + scope=GMAIL_READONLY_SCOPE, + access_token="token-1", + ), + ) + + account_id = UUID(response["account"]["id"]) + assert response == { 
+ "account": { + "id": str(account_id), + "provider": "gmail", + "auth_kind": "oauth_access_token", + "provider_account_id": "acct-001", + "email_address": "owner@example.com", + "display_name": "Owner", + "scope": GMAIL_READONLY_SCOPE, + "created_at": response["account"]["created_at"], + "updated_at": response["account"]["updated_at"], + } + } + assert store.gmail_account_credentials[account_id]["credential_blob"] == { + "credential_kind": GMAIL_PROTECTED_CREDENTIAL_KIND, + "access_token": "token-1", + } + assert store.operations == [("create_gmail_account_credential", account_id)] def test_create_gmail_account_record_rejects_duplicate_provider_account_id() -> None: @@ -222,6 +299,63 @@ def test_get_gmail_account_record_raises_when_account_is_missing() -> None: ) +def test_resolve_gmail_access_token_reads_protected_credential() -> None: + store = GmailStoreStub() + account = create_gmail_account_record( + store, + user_id=uuid4(), + request=GmailAccountConnectInput( + provider_account_id="acct-001", + email_address="owner@example.com", + display_name="Owner", + scope=GMAIL_READONLY_SCOPE, + access_token="token-1", + ), + )["account"] + + assert resolve_gmail_access_token( + store, + gmail_account_id=UUID(account["id"]), + ) == "token-1" + + +def test_resolve_gmail_access_token_rejects_missing_and_invalid_protected_credentials() -> None: + store = GmailStoreStub() + account = create_gmail_account_record( + store, + user_id=uuid4(), + request=GmailAccountConnectInput( + provider_account_id="acct-001", + email_address="owner@example.com", + display_name="Owner", + scope=GMAIL_READONLY_SCOPE, + access_token="token-1", + ), + )["account"] + account_id = UUID(account["id"]) + + store.gmail_account_credentials.pop(account_id) + with pytest.raises( + GmailCredentialNotFoundError, + match=f"gmail account {account_id} is missing protected credentials", + ): + resolve_gmail_access_token(store, gmail_account_id=account_id) + + store.gmail_account_credentials[account_id] = { + "gmail_account_id": account_id, + "user_id": uuid4(), + "auth_kind": "oauth_access_token", + "credential_blob": {"credential_kind": GMAIL_PROTECTED_CREDENTIAL_KIND}, + "created_at": store.base_time, + "updated_at": store.base_time, + } + with pytest.raises( + GmailCredentialInvalidError, + match=f"gmail account {account_id} has invalid protected credentials", + ): + resolve_gmail_access_token(store, gmail_account_id=account_id) + + def test_ingest_gmail_message_record_writes_rfc822_artifact_and_reuses_artifact_seam( monkeypatch, tmp_path, @@ -340,6 +474,10 @@ def fake_ingest(_store, *, user_id: UUID, request): } assert calls["register_user_id"] == user_id assert calls["ingest_user_id"] == user_id + assert store.operations[:2] == [ + ("create_gmail_account_credential", UUID(account["id"])), + ("get_gmail_account_credential_optional", UUID(account["id"])), + ] def test_ingest_gmail_message_record_rejects_unsupported_message(monkeypatch, tmp_path) -> None: @@ -437,7 +575,8 @@ def fail_fetch(**_kwargs): ) assert existing_file.read_bytes() == b"original" - assert store.operations[:2] == [ + assert store.operations[-3:] == [ + ("get_gmail_account_credential_optional", UUID(account["id"])), ("lock_task_artifacts", workspace_id), ("get_task_artifact_by_workspace_relative_path_optional", workspace_id), ] @@ -473,3 +612,105 @@ def test_ingest_gmail_message_record_requires_visible_workspace(monkeypatch) -> provider_message_id="msg-001", ), ) + + +def 
test_ingest_gmail_message_record_rejects_missing_protected_credentials_before_artifact_work( + monkeypatch, + tmp_path, +) -> None: + store = GmailStoreStub() + user_id = uuid4() + workspace_id = uuid4() + workspace_path = (tmp_path / "workspace").resolve() + store.create_task_workspace( + task_workspace_id=workspace_id, + local_path=str(workspace_path), + ) + account = create_gmail_account_record( + store, + user_id=user_id, + request=GmailAccountConnectInput( + provider_account_id="acct-001", + email_address="owner@example.com", + display_name="Owner", + scope=GMAIL_READONLY_SCOPE, + access_token="token-1", + ), + )["account"] + account_id = UUID(account["id"]) + store.gmail_account_credentials.pop(account_id) + + def fail_fetch(**_kwargs): + raise AssertionError("fetch_gmail_message_raw_bytes should not be called") + + monkeypatch.setattr("alicebot_api.gmail.fetch_gmail_message_raw_bytes", fail_fetch) + + with pytest.raises( + GmailCredentialNotFoundError, + match=f"gmail account {account_id} is missing protected credentials", + ): + ingest_gmail_message_record( + store, + user_id=user_id, + request=GmailMessageIngestInput( + gmail_account_id=account_id, + task_workspace_id=workspace_id, + provider_message_id="msg-001", + ), + ) + + assert store.task_artifacts == [] + assert not workspace_path.exists() + assert ("lock_task_artifacts", workspace_id) not in store.operations + + +def test_ingest_gmail_message_record_rejects_invalid_protected_credentials_before_artifact_work( + monkeypatch, + tmp_path, +) -> None: + store = GmailStoreStub() + user_id = uuid4() + workspace_id = uuid4() + workspace_path = (tmp_path / "workspace").resolve() + store.create_task_workspace( + task_workspace_id=workspace_id, + local_path=str(workspace_path), + ) + account = create_gmail_account_record( + store, + user_id=user_id, + request=GmailAccountConnectInput( + provider_account_id="acct-001", + email_address="owner@example.com", + display_name="Owner", + scope=GMAIL_READONLY_SCOPE, + access_token="token-1", + ), + )["account"] + account_id = UUID(account["id"]) + store.gmail_account_credentials[account_id]["credential_blob"] = build_gmail_protected_credential_blob( + access_token="", + ) + + def fail_fetch(**_kwargs): + raise AssertionError("fetch_gmail_message_raw_bytes should not be called") + + monkeypatch.setattr("alicebot_api.gmail.fetch_gmail_message_raw_bytes", fail_fetch) + + with pytest.raises( + GmailCredentialInvalidError, + match=f"gmail account {account_id} has invalid protected credentials", + ): + ingest_gmail_message_record( + store, + user_id=user_id, + request=GmailMessageIngestInput( + gmail_account_id=account_id, + task_workspace_id=workspace_id, + provider_message_id="msg-001", + ), + ) + + assert store.task_artifacts == [] + assert not workspace_path.exists() + assert ("lock_task_artifacts", workspace_id) not in store.operations diff --git a/tests/unit/test_gmail_main.py b/tests/unit/test_gmail_main.py index b711e65..18fc375 100644 --- a/tests/unit/test_gmail_main.py +++ b/tests/unit/test_gmail_main.py @@ -9,6 +9,8 @@ from alicebot_api.gmail import ( GmailAccountAlreadyExistsError, GmailAccountNotFoundError, + GmailCredentialInvalidError, + GmailCredentialNotFoundError, GmailMessageFetchError, GmailMessageNotFoundError, GmailMessageUnsupportedError, @@ -174,6 +176,44 @@ def fake_unsupported(*_args, **_kwargs): "detail": "gmail message msg-001 is not a supported RFC822 email" } + def fake_missing_credentials(*_args, **_kwargs): + raise GmailCredentialNotFoundError( + f"gmail account 
{gmail_account_id} is missing protected credentials" + ) + + monkeypatch.setattr(main_module, "ingest_gmail_message_record", fake_missing_credentials) + response = main_module.ingest_gmail_message( + gmail_account_id, + "msg-001", + main_module.IngestGmailMessageRequest( + user_id=user_id, + task_workspace_id=task_workspace_id, + ), + ) + assert response.status_code == 409 + assert json.loads(response.body) == { + "detail": f"gmail account {gmail_account_id} is missing protected credentials" + } + + def fake_invalid_credentials(*_args, **_kwargs): + raise GmailCredentialInvalidError( + f"gmail account {gmail_account_id} has invalid protected credentials" + ) + + monkeypatch.setattr(main_module, "ingest_gmail_message_record", fake_invalid_credentials) + response = main_module.ingest_gmail_message( + gmail_account_id, + "msg-001", + main_module.IngestGmailMessageRequest( + user_id=user_id, + task_workspace_id=task_workspace_id, + ), + ) + assert response.status_code == 409 + assert json.loads(response.body) == { + "detail": f"gmail account {gmail_account_id} has invalid protected credentials" + } + def fake_fetch_error(*_args, **_kwargs): raise GmailMessageFetchError("gmail message msg-001 could not be fetched") From d3bf99dcb2475fd3b16080c1c61616c6a1b8edb3 Mon Sep 17 00:00:00 2001 From: Sam Date: Mon, 16 Mar 2026 18:50:24 +0100 Subject: [PATCH 017/135] Sprint 5Q: Gmail refresh token lifecycle (#17) * Sprint 5Q: Gmail refresh token lifecycle packet * Sprint 5Q: Gmail refresh token lifecycle --------- Co-authored-by: Sami Rusani --- ARCHITECTURE.md | 14 +- ...0316_0028_gmail_refresh_token_lifecycle.py | 89 ++++++ apps/api/src/alicebot_api/contracts.py | 5 + apps/api/src/alicebot_api/gmail.py | 247 +++++++++++++++-- apps/api/src/alicebot_api/main.py | 33 ++- apps/api/src/alicebot_api/store.py | 29 ++ tests/integration/test_gmail_accounts_api.py | 201 +++++++++++++- tests/integration/test_migrations.py | 120 ++++++++ ...0316_0028_gmail_refresh_token_lifecycle.py | 40 +++ tests/unit/test_gmail.py | 262 +++++++++++++++++- tests/unit/test_gmail_main.py | 71 +++++ tests/unit/test_gmail_refresh.py | 167 +++++++++++ 12 files changed, 1238 insertions(+), 40 deletions(-) create mode 100644 apps/api/alembic/versions/20260316_0028_gmail_refresh_token_lifecycle.py create mode 100644 tests/unit/test_20260316_0028_gmail_refresh_token_lifecycle.py create mode 100644 tests/unit/test_gmail_refresh.py diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md index 6a76ea1..3e7ad09 100644 --- a/ARCHITECTURE.md +++ b/ARCHITECTURE.md @@ -2,17 +2,17 @@ ## Current Implemented Slice -AliceBot now implements the accepted repo slice through Sprint 5P. The shipped backend includes: +AliceBot now implements the accepted repo slice through Sprint 5Q. 
The shipped backend includes: - foundation continuity storage over `users`, `threads`, `sessions`, and append-only `events` - deterministic tracing and context compilation over durable continuity, memory, entity, and entity-edge records - governed memory admission, explicit-preference extraction, memory review labels, review queue reads, evaluation summary reads, explicit embedding config and memory-embedding storage, direct semantic retrieval, and deterministic hybrid compile-path memory merge - deterministic prompt assembly and one no-tools response path that persists assistant replies as immutable continuity events - user-scoped consents, policies, policy evaluation, tool registry, allowlist evaluation, tool routing, approval request persistence, approval resolution, approved-only proxy execution through the in-process `proxy.echo` handler, durable execution review, and execution-budget lifecycle plus enforcement -- a narrow read-only Gmail connector seam with user-scoped `gmail_accounts` metadata persistence, separate user-scoped `gmail_account_credentials` protected credential storage, deterministic account reads without secret exposure, and one explicit selected-message ingestion path that materializes one Gmail message as a rooted `.eml` task artifact and then reuses the existing RFC822 artifact ingestion pipeline +- a narrow read-only Gmail connector seam with user-scoped `gmail_accounts` metadata persistence, separate user-scoped `gmail_account_credentials` protected credential storage, deterministic account reads without secret exposure, refresh-token-capable protected credential renewal for expired access tokens, and one explicit selected-message ingestion path that materializes one Gmail message as a rooted `.eml` task artifact and then reuses the existing RFC822 artifact ingestion pipeline - durable `tasks`, `task_steps`, `task_workspaces`, `task_artifacts`, `task_artifact_chunks`, and `task_artifact_chunk_embeddings`, deterministic task-step sequencing, explicit task-step transitions, explicit manual continuation with lineage through `parent_step_id`, `source_approval_id`, and `source_execution_id`, explicit `tool_executions.task_step_id` linkage for execution synchronization, deterministic rooted local task-workspace provisioning, explicit rooted local artifact registration, deterministic local plain-text, markdown, narrow PDF text, narrow DOCX text, and narrow RFC822 email text ingestion into durable chunk rows, deterministic lexical artifact-chunk retrieval over durable chunk rows, explicit user-scoped artifact-chunk embedding persistence tied to existing embedding configs, explicit task-scoped or artifact-scoped semantic artifact-chunk retrieval over those durable embeddings, and compile-path artifact retrieval that can include lexical results, semantic results, or one deterministic hybrid lexical-plus-semantic merged artifact section with per-chunk source provenance -The current multi-step boundary is narrow and explicit. Manual continuation is implemented and review-passed. Approval resolution and proxy execution now both use explicit task-step linkage rather than first-step inference. 
Task workspaces are now implemented only as deterministic rooted local boundaries, and task artifacts are now implemented only as explicit rooted local-file registrations, narrow deterministic artifact ingestion under those workspaces, lexical retrieval over persisted chunk rows, explicit artifact-chunk embedding storage tied to existing embedding configs, direct semantic retrieval over those durable artifact-chunk embeddings for one visible task or one visible artifact at a time, and compile-path artifact retrieval that deterministically merges lexical and semantic candidates into one artifact section when both are requested for the same scope. The live richer-document boundary is still intentionally narrow: plain text and markdown ingest directly, PDF support is limited to narrow local text extraction, DOCX support is limited to narrow local text extraction from `word/document.xml`, RFC822 email support is limited to top-level selected headers plus extractable plain-text body content while excluding nested `message/rfc822` content, and the live connector boundary is limited to one read-only Gmail account seam plus one explicit selected-message ingestion path into the rooted RFC822 artifact pipeline, with credentials resolved through a dedicated protected credential seam instead of the normal account metadata table surface. OCR, image extraction, layout reconstruction, Gmail search, mailbox sync, attachments, Calendar connectors, reranking beyond the current lexical-first hybrid merge, and new side-effect surfaces are still planned later and must not be described as live behavior. +The current multi-step boundary is narrow and explicit. Manual continuation is implemented and review-passed. Approval resolution and proxy execution now both use explicit task-step linkage rather than first-step inference. Task workspaces are now implemented only as deterministic rooted local boundaries, and task artifacts are now implemented only as explicit rooted local-file registrations, narrow deterministic artifact ingestion under those workspaces, lexical retrieval over persisted chunk rows, explicit artifact-chunk embedding storage tied to existing embedding configs, direct semantic retrieval over those durable artifact-chunk embeddings for one visible task or one visible artifact at a time, and compile-path artifact retrieval that deterministically merges lexical and semantic candidates into one artifact section when both are requested for the same scope. The live richer-document boundary is still intentionally narrow: plain text and markdown ingest directly, PDF support is limited to narrow local text extraction, DOCX support is limited to narrow local text extraction from `word/document.xml`, RFC822 email support is limited to top-level selected headers plus extractable plain-text body content while excluding nested `message/rfc822` content, and the live connector boundary is limited to one read-only Gmail account seam plus one explicit selected-message ingestion path into the rooted RFC822 artifact pipeline, with credentials resolved through a dedicated protected credential seam, renewed through one explicit refresh path when an expired refresh-capable credential is present, and never exposed on the normal account metadata table surface. OCR, image extraction, layout reconstruction, Gmail search, mailbox sync, attachments, Calendar connectors, reranking beyond the current lexical-first hybrid merge, and new side-effect surfaces are still planned later and must not be described as live behavior. 
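The refresh boundary described above renews a token only when the stored protected credential is refresh-capable and already expired. A minimal sketch of that decision follows, assuming a plain `now >= access_token_expires_at` comparison and a caller-supplied refresh callable; the shipped module's exact renewal, error handling, and persistence behaviour are not reproduced here.

```python
# Illustrative sketch only, not the shipped gmail module; the expiry rule and
# refresh callable signature are assumptions for the example.
from __future__ import annotations

from datetime import UTC, datetime
from typing import Callable


def choose_access_token(
    credential_blob: dict[str, str],
    refresh: Callable[[str], tuple[str, datetime]],
    *,
    now: datetime | None = None,
) -> str:
    """Return a usable access token, renewing it first when the credential is refresh-capable and expired."""
    now = now or datetime.now(UTC)
    expires_at_raw = credential_blob.get("access_token_expires_at")
    refresh_token = credential_blob.get("refresh_token")
    if expires_at_raw is None or refresh_token is None:
        # Access-token-only credential shape: use the stored token as-is.
        return credential_blob["access_token"]
    if now < datetime.fromisoformat(expires_at_raw):
        # Refresh-capable but still valid: no renewal needed.
        return credential_blob["access_token"]
    # Expired refresh-capable credential: renew once, then reuse the refreshed token.
    renewed_token, renewed_expires_at = refresh(refresh_token)
    credential_blob["access_token"] = renewed_token
    credential_blob["access_token_expires_at"] = renewed_expires_at.isoformat()
    return renewed_token


blob = {
    "credential_kind": "gmail_oauth_refresh_token_v2",
    "access_token": "stale-token",
    "refresh_token": "refresh-1",
    "access_token_expires_at": "2026-03-16T00:00:00+00:00",
}
print(
    choose_access_token(
        blob,
        lambda _refresh_token: ("fresh-token", datetime(2026, 3, 17, tzinfo=UTC)),
        now=datetime(2026, 3, 16, 12, 0, tzinfo=UTC),
    )
)  # -> "fresh-token"
```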
## Implemented Now @@ -60,11 +60,11 @@ The current multi-step boundary is narrow and explicit. Manual continuation is i ### Repo Boundaries In This Slice -- `apps/api`: implemented API, store, contracts, service logic, and migrations for continuity, tracing, memory, embeddings, entities, policies, tools, approvals, proxy execution, execution budgets, a narrow read-only Gmail connector seam, tasks, task steps, task workspaces, task artifacts, artifact-chunk embeddings, deterministic lexical artifact chunk retrieval, deterministic semantic artifact chunk retrieval over durable embeddings, compile-path semantic artifact retrieval, and deterministic hybrid lexical-plus-semantic artifact merge in compile. +- `apps/api`: implemented API, store, contracts, service logic, and migrations for continuity, tracing, memory, embeddings, entities, policies, tools, approvals, proxy execution, execution budgets, a narrow read-only Gmail connector seam with protected refresh-token lifecycle support, tasks, task steps, task workspaces, task artifacts, artifact-chunk embeddings, deterministic lexical artifact chunk retrieval, deterministic semantic artifact chunk retrieval over durable embeddings, compile-path semantic artifact retrieval, and deterministic hybrid lexical-plus-semantic artifact merge in compile. - `apps/web`: minimal shell only; no shipped workflow UI. - `workers`: scaffold only; no background jobs or runner logic are implemented. - `infra`: local development bootstrap assets only. -- `tests`: unit and Postgres-backed integration coverage for the shipped seams above, including Sprint 4O task-step lineage/manual continuation, Sprint 4S step-linked execution synchronization, Sprint 5A task-workspace provisioning, Sprint 5C task-artifact registration, Sprint 5D local artifact ingestion plus chunk reads, Sprint 5E lexical artifact-chunk retrieval, Sprint 5F compile-path artifact chunk integration, Sprint 5G artifact-chunk embedding persistence and reads, Sprint 5H direct semantic artifact-chunk retrieval, Sprint 5I compile-path semantic artifact retrieval, Sprint 5J deterministic hybrid lexical-plus-semantic artifact merge in compile, Sprint 5L narrow PDF artifact ingestion, Sprint 5M narrow DOCX artifact ingestion, Sprint 5N narrow RFC822 email artifact ingestion, Sprint 5O read-only Gmail account plus single-message ingestion coverage, and Sprint 5P Gmail credential hardening coverage. +- `tests`: unit and Postgres-backed integration coverage for the shipped seams above, including Sprint 4O task-step lineage/manual continuation, Sprint 4S step-linked execution synchronization, Sprint 5A task-workspace provisioning, Sprint 5C task-artifact registration, Sprint 5D local artifact ingestion plus chunk reads, Sprint 5E lexical artifact-chunk retrieval, Sprint 5F compile-path artifact chunk integration, Sprint 5G artifact-chunk embedding persistence and reads, Sprint 5H direct semantic artifact-chunk retrieval, Sprint 5I compile-path semantic artifact retrieval, Sprint 5J deterministic hybrid lexical-plus-semantic artifact merge in compile, Sprint 5L narrow PDF artifact ingestion, Sprint 5M narrow DOCX artifact ingestion, Sprint 5N narrow RFC822 email artifact ingestion, Sprint 5O read-only Gmail account plus single-message ingestion coverage, Sprint 5P Gmail credential hardening coverage, and Sprint 5Q Gmail refresh-token lifecycle coverage. ## Core Flows Implemented Now @@ -96,10 +96,10 @@ The current multi-step boundary is narrow and explicit. Manual continuation is i 1. 
Accept a user-scoped `POST /v0/gmail-accounts` request for one read-only Gmail account metadata record. 2. Persist exactly the narrow connector metadata required for later reads on `gmail_accounts`: `provider_account_id`, `email_address`, optional `display_name`, and the fixed Gmail read-only scope. -3. Persist the Gmail access token only in the dedicated `gmail_account_credentials` protected credential seam bound to the same visible user/account ownership scope. +3. Persist Gmail secrets only in the dedicated `gmail_account_credentials` protected credential seam bound to the same visible user/account ownership scope, using either a narrow access-token-only shape or a refresh-token-capable credential shape with expiry metadata. 4. Expose deterministic user-scoped Gmail account list and detail reads without secret material. 5. Accept a user-scoped `POST /v0/gmail-accounts/{gmail_account_id}/messages/{provider_message_id}/ingest` request for one visible Gmail account and one visible task workspace. -6. Resolve the Gmail access token through the protected credential seam before any Gmail fetch, file write, or artifact registration. +6. Resolve the Gmail access token through the protected credential seam before any Gmail fetch, file write, or artifact registration, and renew it first through one explicit refresh path when the visible protected credential is refresh-capable and expired. 7. Derive one deterministic workspace-relative artifact path as `gmail//.eml`. 8. Reject duplicate `(task_workspace_id, relative_path)` collisions before any Gmail fetch or file write. 9. Fetch exactly one selected Gmail message through the read-only Gmail API path `users/me/messages/{provider_message_id}?format=raw`. diff --git a/apps/api/alembic/versions/20260316_0028_gmail_refresh_token_lifecycle.py b/apps/api/alembic/versions/20260316_0028_gmail_refresh_token_lifecycle.py new file mode 100644 index 0000000..7d30c37 --- /dev/null +++ b/apps/api/alembic/versions/20260316_0028_gmail_refresh_token_lifecycle.py @@ -0,0 +1,89 @@ +"""Allow Gmail protected credentials to store refresh-token lifecycle data.""" + +from __future__ import annotations + +from alembic import op + + +revision = "20260316_0028" +down_revision = "20260316_0027" +branch_labels = None +depends_on = None + +GMAIL_PROTECTED_CREDENTIAL_KIND = "gmail_oauth_access_token_v1" +GMAIL_REFRESHABLE_PROTECTED_CREDENTIAL_KIND = "gmail_oauth_refresh_token_v2" + +_UPGRADE_STATEMENTS = ( + "ALTER TABLE gmail_account_credentials DROP CONSTRAINT gmail_account_credentials_blob_shape_check", + f""" + ALTER TABLE gmail_account_credentials + ADD CONSTRAINT gmail_account_credentials_blob_shape_check + CHECK ( + jsonb_typeof(credential_blob) = 'object' + AND credential_blob ? 'credential_kind' + AND credential_blob ? 'access_token' + AND jsonb_typeof(credential_blob -> 'access_token') = 'string' + AND length(credential_blob ->> 'access_token') > 0 + AND ( + ( + credential_blob ->> 'credential_kind' = '{GMAIL_PROTECTED_CREDENTIAL_KIND}' + ) + OR + ( + credential_blob ->> 'credential_kind' = '{GMAIL_REFRESHABLE_PROTECTED_CREDENTIAL_KIND}' + AND credential_blob ? 'refresh_token' + AND credential_blob ? 'client_id' + AND credential_blob ? 'client_secret' + AND credential_blob ? 
'access_token_expires_at' + AND jsonb_typeof(credential_blob -> 'refresh_token') = 'string' + AND jsonb_typeof(credential_blob -> 'client_id') = 'string' + AND jsonb_typeof(credential_blob -> 'client_secret') = 'string' + AND jsonb_typeof(credential_blob -> 'access_token_expires_at') = 'string' + AND length(credential_blob ->> 'refresh_token') > 0 + AND length(credential_blob ->> 'client_id') > 0 + AND length(credential_blob ->> 'client_secret') > 0 + AND length(credential_blob ->> 'access_token_expires_at') > 0 + ) + ) + ) + """, + "GRANT UPDATE ON gmail_account_credentials TO alicebot_app", +) + +_DOWNGRADE_STATEMENTS = ( + """ + UPDATE gmail_account_credentials + SET credential_blob = jsonb_build_object( + 'credential_kind', 'gmail_oauth_access_token_v1', + 'access_token', credential_blob ->> 'access_token' + ) + WHERE credential_blob ->> 'credential_kind' = 'gmail_oauth_refresh_token_v2' + """, + "REVOKE UPDATE ON gmail_account_credentials FROM alicebot_app", + "ALTER TABLE gmail_account_credentials DROP CONSTRAINT gmail_account_credentials_blob_shape_check", + f""" + ALTER TABLE gmail_account_credentials + ADD CONSTRAINT gmail_account_credentials_blob_shape_check + CHECK ( + jsonb_typeof(credential_blob) = 'object' + AND credential_blob ? 'credential_kind' + AND credential_blob ? 'access_token' + AND credential_blob ->> 'credential_kind' = '{GMAIL_PROTECTED_CREDENTIAL_KIND}' + AND jsonb_typeof(credential_blob -> 'access_token') = 'string' + AND length(credential_blob ->> 'access_token') > 0 + ) + """, +) + + +def _execute_statements(statements: tuple[str, ...]) -> None: + for statement in statements: + op.execute(statement) + + +def upgrade() -> None: + _execute_statements(_UPGRADE_STATEMENTS) + + +def downgrade() -> None: + _execute_statements(_DOWNGRADE_STATEMENTS) diff --git a/apps/api/src/alicebot_api/contracts.py b/apps/api/src/alicebot_api/contracts.py index 035273e..4453118 100644 --- a/apps/api/src/alicebot_api/contracts.py +++ b/apps/api/src/alicebot_api/contracts.py @@ -179,6 +179,7 @@ GMAIL_AUTH_KIND_OAUTH_ACCESS_TOKEN = "oauth_access_token" GMAIL_READONLY_SCOPE = "https://www.googleapis.com/auth/gmail.readonly" GMAIL_PROTECTED_CREDENTIAL_KIND = "gmail_oauth_access_token_v1" +GMAIL_REFRESHABLE_PROTECTED_CREDENTIAL_KIND = "gmail_oauth_refresh_token_v2" TASK_STEP_SEQUENCE_VERSION_V0 = "task_step_sequence_v0" TRACE_KIND_TASK_STEP_SEQUENCE = "task.step.sequence" TASK_STEP_CONTINUATION_VERSION_V0 = "task_step_continuation_v0" @@ -1781,6 +1782,10 @@ class GmailAccountConnectInput: display_name: str | None scope: str access_token: str + refresh_token: str | None = None + client_id: str | None = None + client_secret: str | None = None + access_token_expires_at: datetime | None = None @dataclass(frozen=True, slots=True) diff --git a/apps/api/src/alicebot_api/gmail.py b/apps/api/src/alicebot_api/gmail.py index 9d908d1..5d60a33 100644 --- a/apps/api/src/alicebot_api/gmail.py +++ b/apps/api/src/alicebot_api/gmail.py @@ -3,9 +3,11 @@ import base64 import json import re +from dataclasses import dataclass +from datetime import UTC, datetime, timedelta from pathlib import Path from urllib.error import HTTPError, URLError -from urllib.parse import quote +from urllib.parse import quote, urlencode from urllib.request import Request, urlopen from uuid import UUID @@ -25,6 +27,7 @@ GMAIL_AUTH_KIND_OAUTH_ACCESS_TOKEN, GMAIL_PROTECTED_CREDENTIAL_KIND, GMAIL_PROVIDER, + GMAIL_REFRESHABLE_PROTECTED_CREDENTIAL_KIND, GMAIL_READONLY_SCOPE, GmailAccountConnectInput, GmailAccountConnectResponse, @@ 
-40,6 +43,8 @@ from alicebot_api.workspaces import TaskWorkspaceNotFoundError GMAIL_MESSAGE_FETCH_TIMEOUT_SECONDS = 30 +GMAIL_TOKEN_REFRESH_TIMEOUT_SECONDS = 30 +GMAIL_TOKEN_REFRESH_URL = "https://oauth2.googleapis.com/token" GMAIL_MESSAGE_ARTIFACT_ROOT = "gmail" _PATH_SEGMENT_PATTERN = re.compile(r"[^A-Za-z0-9._-]+") @@ -72,6 +77,24 @@ class GmailCredentialInvalidError(RuntimeError): """Raised when Gmail protected credentials are malformed for a visible account.""" +class GmailCredentialRefreshError(RuntimeError): + """Raised when Gmail access-token renewal fails for non-deterministic reasons.""" + + +class GmailCredentialValidationError(ValueError): + """Raised when Gmail connect input contains an invalid credential combination.""" + + +@dataclass(frozen=True, slots=True) +class ParsedGmailCredential: + access_token: str + credential_kind: str + refresh_token: str | None = None + client_id: str | None = None + client_secret: str | None = None + access_token_expires_at: datetime | None = None + + def serialize_gmail_account_row(row: GmailAccountRow) -> GmailAccountRecord: return { "id": str(row["id"]), @@ -86,13 +109,184 @@ def serialize_gmail_account_row(row: GmailAccountRow) -> GmailAccountRecord: } -def build_gmail_protected_credential_blob(*, access_token: str) -> dict[str, str]: +def _coerce_nonempty_string(value: object) -> str | None: + if not isinstance(value, str): + return None + normalized = value.strip() + if normalized == "": + return None + return normalized + + +def _normalize_datetime(value: datetime) -> datetime: + if value.tzinfo is None: + return value.replace(tzinfo=UTC) + return value.astimezone(UTC) + + +def build_gmail_protected_credential_blob( + *, + access_token: str, + refresh_token: str | None = None, + client_id: str | None = None, + client_secret: str | None = None, + access_token_expires_at: datetime | None = None, +) -> dict[str, str]: + normalized_access_token = _coerce_nonempty_string(access_token) + if normalized_access_token is None: + raise GmailCredentialValidationError("gmail access token must be non-empty") + + refresh_bundle = ( + refresh_token, + client_id, + client_secret, + access_token_expires_at, + ) + if all(value is None for value in refresh_bundle): + return { + "credential_kind": GMAIL_PROTECTED_CREDENTIAL_KIND, + "access_token": normalized_access_token, + } + + normalized_refresh_token = _coerce_nonempty_string(refresh_token) + normalized_client_id = _coerce_nonempty_string(client_id) + normalized_client_secret = _coerce_nonempty_string(client_secret) + if ( + normalized_refresh_token is None + or normalized_client_id is None + or normalized_client_secret is None + or access_token_expires_at is None + ): + raise GmailCredentialValidationError( + "gmail refresh credentials must include refresh_token, client_id, client_secret, " + "and access_token_expires_at" + ) + + normalized_expires_at = _normalize_datetime(access_token_expires_at) return { - "credential_kind": GMAIL_PROTECTED_CREDENTIAL_KIND, - "access_token": access_token, + "credential_kind": GMAIL_REFRESHABLE_PROTECTED_CREDENTIAL_KIND, + "access_token": normalized_access_token, + "refresh_token": normalized_refresh_token, + "client_id": normalized_client_id, + "client_secret": normalized_client_secret, + "access_token_expires_at": normalized_expires_at.isoformat(), } +def _parse_gmail_credential( + *, + gmail_account_id: UUID, + credential_blob: object, +) -> ParsedGmailCredential: + if not isinstance(credential_blob, dict): + raise GmailCredentialInvalidError( + f"gmail account 
{gmail_account_id} has invalid protected credentials" + ) + + credential_kind = credential_blob.get("credential_kind") + access_token = _coerce_nonempty_string(credential_blob.get("access_token")) + if access_token is None: + raise GmailCredentialInvalidError( + f"gmail account {gmail_account_id} has invalid protected credentials" + ) + + if credential_kind == GMAIL_PROTECTED_CREDENTIAL_KIND: + return ParsedGmailCredential( + access_token=access_token, + credential_kind=credential_kind, + ) + + if credential_kind != GMAIL_REFRESHABLE_PROTECTED_CREDENTIAL_KIND: + raise GmailCredentialInvalidError( + f"gmail account {gmail_account_id} has invalid protected credentials" + ) + + refresh_token = _coerce_nonempty_string(credential_blob.get("refresh_token")) + client_id = _coerce_nonempty_string(credential_blob.get("client_id")) + client_secret = _coerce_nonempty_string(credential_blob.get("client_secret")) + access_token_expires_at_raw = _coerce_nonempty_string( + credential_blob.get("access_token_expires_at") + ) + if ( + refresh_token is None + or client_id is None + or client_secret is None + or access_token_expires_at_raw is None + ): + raise GmailCredentialInvalidError( + f"gmail account {gmail_account_id} has invalid protected credentials" + ) + + try: + access_token_expires_at = _normalize_datetime( + datetime.fromisoformat(access_token_expires_at_raw) + ) + except ValueError as exc: + raise GmailCredentialInvalidError( + f"gmail account {gmail_account_id} has invalid protected credentials" + ) from exc + + return ParsedGmailCredential( + access_token=access_token, + credential_kind=credential_kind, + refresh_token=refresh_token, + client_id=client_id, + client_secret=client_secret, + access_token_expires_at=access_token_expires_at, + ) + + +def refresh_gmail_access_token( + *, + gmail_account_id: UUID, + refresh_token: str, + client_id: str, + client_secret: str, +) -> tuple[str, datetime]: + request = Request( + GMAIL_TOKEN_REFRESH_URL, + data=urlencode( + { + "client_id": client_id, + "client_secret": client_secret, + "refresh_token": refresh_token, + "grant_type": "refresh_token", + } + ).encode("utf-8"), + headers={ + "Content-Type": "application/x-www-form-urlencoded", + "Accept": "application/json", + }, + method="POST", + ) + + try: + with urlopen(request, timeout=GMAIL_TOKEN_REFRESH_TIMEOUT_SECONDS) as response: + payload = json.loads(response.read().decode("utf-8")) + except HTTPError as exc: + if exc.code in {400, 401}: + raise GmailCredentialInvalidError( + f"gmail account {gmail_account_id} refresh credentials were rejected" + ) from exc + raise GmailCredentialRefreshError( + f"gmail account {gmail_account_id} access token could not be renewed" + ) from exc + except (OSError, URLError, UnicodeDecodeError, json.JSONDecodeError) as exc: + raise GmailCredentialRefreshError( + f"gmail account {gmail_account_id} access token could not be renewed" + ) from exc + + refreshed_access_token = _coerce_nonempty_string(payload.get("access_token")) + expires_in = payload.get("expires_in") + if refreshed_access_token is None or not isinstance(expires_in, (int, float)) or expires_in <= 0: + raise GmailCredentialRefreshError( + f"gmail account {gmail_account_id} access token could not be renewed" + ) + + refreshed_expires_at = datetime.now(UTC) + timedelta(seconds=float(expires_in)) + return refreshed_access_token, refreshed_expires_at + + def resolve_gmail_access_token( store: ContinuityStore, *, @@ -109,24 +303,35 @@ def resolve_gmail_access_token( f"gmail account {gmail_account_id} has 
invalid protected credentials" ) - credential_blob = credential["credential_blob"] - if not isinstance(credential_blob, dict): - raise GmailCredentialInvalidError( - f"gmail account {gmail_account_id} has invalid protected credentials" - ) - - credential_kind = credential_blob.get("credential_kind") - access_token = credential_blob.get("access_token") + parsed_credential = _parse_gmail_credential( + gmail_account_id=gmail_account_id, + credential_blob=credential["credential_blob"], + ) if ( - credential_kind != GMAIL_PROTECTED_CREDENTIAL_KIND - or not isinstance(access_token, str) - or access_token == "" + parsed_credential.credential_kind != GMAIL_REFRESHABLE_PROTECTED_CREDENTIAL_KIND + or parsed_credential.access_token_expires_at is None + or parsed_credential.access_token_expires_at > datetime.now(UTC) ): - raise GmailCredentialInvalidError( - f"gmail account {gmail_account_id} has invalid protected credentials" - ) + return parsed_credential.access_token - return access_token + refreshed_access_token, refreshed_expires_at = refresh_gmail_access_token( + gmail_account_id=gmail_account_id, + refresh_token=parsed_credential.refresh_token, + client_id=parsed_credential.client_id, + client_secret=parsed_credential.client_secret, + ) + store.update_gmail_account_credential( + gmail_account_id=gmail_account_id, + auth_kind=credential["auth_kind"], + credential_blob=build_gmail_protected_credential_blob( + access_token=refreshed_access_token, + refresh_token=parsed_credential.refresh_token, + client_id=parsed_credential.client_id, + client_secret=parsed_credential.client_secret, + access_token_expires_at=refreshed_expires_at, + ), + ) + return refreshed_access_token def create_gmail_account_record( @@ -155,6 +360,10 @@ def create_gmail_account_record( auth_kind=GMAIL_AUTH_KIND_OAUTH_ACCESS_TOKEN, credential_blob=build_gmail_protected_credential_blob( access_token=request.access_token, + refresh_token=request.refresh_token, + client_id=request.client_id, + client_secret=request.client_secret, + access_token_expires_at=request.access_token_expires_at, ), ) except psycopg.errors.UniqueViolation as exc: diff --git a/apps/api/src/alicebot_api/main.py b/apps/api/src/alicebot_api/main.py index 9fbe271..7382b06 100644 --- a/apps/api/src/alicebot_api/main.py +++ b/apps/api/src/alicebot_api/main.py @@ -5,7 +5,7 @@ from uuid import UUID from fastapi import FastAPI, Query from fastapi.encoders import jsonable_encoder -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, model_validator from fastapi.responses import JSONResponse from urllib.parse import urlsplit, urlunsplit @@ -142,6 +142,8 @@ GmailAccountAlreadyExistsError, GmailCredentialInvalidError, GmailCredentialNotFoundError, + GmailCredentialRefreshError, + GmailCredentialValidationError, GmailAccountNotFoundError, GmailMessageFetchError, GmailMessageNotFoundError, @@ -541,6 +543,27 @@ class ConnectGmailAccountRequest(BaseModel): display_name: str | None = Field(default=None, min_length=1, max_length=200) scope: Literal["https://www.googleapis.com/auth/gmail.readonly"] = GMAIL_READONLY_SCOPE access_token: str = Field(min_length=1, max_length=8000) + refresh_token: str | None = Field(default=None, min_length=1, max_length=8000) + client_id: str | None = Field(default=None, min_length=1, max_length=2000) + client_secret: str | None = Field(default=None, min_length=1, max_length=8000) + access_token_expires_at: datetime | None = None + + @model_validator(mode="after") + def validate_refresh_bundle(self) 
-> ConnectGmailAccountRequest: + refresh_bundle = ( + self.refresh_token, + self.client_id, + self.client_secret, + self.access_token_expires_at, + ) + if all(value is None for value in refresh_bundle): + return self + if any(value is None for value in refresh_bundle): + raise ValueError( + "gmail refresh credentials must include refresh_token, client_id, " + "client_secret, and access_token_expires_at" + ) + return self class IngestGmailMessageRequest(BaseModel): @@ -1301,8 +1324,14 @@ def connect_gmail_account(request: ConnectGmailAccountRequest) -> JSONResponse: display_name=request.display_name, scope=request.scope, access_token=request.access_token, + refresh_token=request.refresh_token, + client_id=request.client_id, + client_secret=request.client_secret, + access_token_expires_at=request.access_token_expires_at, ), ) + except GmailCredentialValidationError as exc: + return JSONResponse(status_code=400, content={"detail": str(exc)}) except GmailAccountAlreadyExistsError as exc: return JSONResponse(status_code=409, content={"detail": str(exc)}) @@ -1379,7 +1408,7 @@ def ingest_gmail_message( return JSONResponse(status_code=409, content={"detail": str(exc)}) except TaskArtifactValidationError as exc: return JSONResponse(status_code=400, content={"detail": str(exc)}) - except GmailMessageFetchError as exc: + except (GmailMessageFetchError, GmailCredentialRefreshError) as exc: return JSONResponse(status_code=502, content={"detail": str(exc)}) except TaskArtifactAlreadyExistsError as exc: return JSONResponse(status_code=409, content={"detail": str(exc)}) diff --git a/apps/api/src/alicebot_api/store.py b/apps/api/src/alicebot_api/store.py index 4822a44..61dcd53 100644 --- a/apps/api/src/alicebot_api/store.py +++ b/apps/api/src/alicebot_api/store.py @@ -1587,6 +1587,22 @@ class LabelCountRow(TypedDict): WHERE gmail_account_id = %s """ +UPDATE_GMAIL_ACCOUNT_CREDENTIAL_SQL = """ + UPDATE gmail_account_credentials + SET + auth_kind = %s, + credential_blob = %s, + updated_at = clock_timestamp() + WHERE gmail_account_id = %s + RETURNING + gmail_account_id, + user_id, + auth_kind, + credential_blob, + created_at, + updated_at + """ + LIST_GMAIL_ACCOUNTS_SQL = """ SELECT id, @@ -3155,6 +3171,19 @@ def get_gmail_account_credential_optional( ) -> ProtectedGmailCredentialRow | None: return self._fetch_optional_one(GET_GMAIL_ACCOUNT_CREDENTIAL_SQL, (gmail_account_id,)) + def update_gmail_account_credential( + self, + *, + gmail_account_id: UUID, + auth_kind: str, + credential_blob: JsonObject, + ) -> ProtectedGmailCredentialRow: + return self._fetch_one( + "update_gmail_account_credential", + UPDATE_GMAIL_ACCOUNT_CREDENTIAL_SQL, + (auth_kind, Jsonb(credential_blob), gmail_account_id), + ) + def get_gmail_account_by_provider_account_id_optional( self, provider_account_id: str, diff --git a/tests/integration/test_gmail_accounts_api.py b/tests/integration/test_gmail_accounts_api.py index 334cabd..8385474 100644 --- a/tests/integration/test_gmail_accounts_api.py +++ b/tests/integration/test_gmail_accounts_api.py @@ -1,6 +1,7 @@ from __future__ import annotations import json +from datetime import datetime from pathlib import Path from typing import Any from urllib.parse import urlencode @@ -151,18 +152,28 @@ def seed_task(database_url: str, *, email: str) -> dict[str, UUID]: } -def _connect_gmail_account(*, user_id: UUID, provider_account_id: str, email_address: str) -> tuple[int, dict[str, Any]]: +def _connect_gmail_account( + *, + user_id: UUID, + provider_account_id: str, + email_address: str, + 
credential_overrides: dict[str, Any] | None = None, +) -> tuple[int, dict[str, Any]]: + payload = { + "user_id": str(user_id), + "provider_account_id": provider_account_id, + "email_address": email_address, + "display_name": email_address.split("@", 1)[0].title(), + "scope": "https://www.googleapis.com/auth/gmail.readonly", + "access_token": f"token-for-{provider_account_id}", + } + if credential_overrides is not None: + payload.update(credential_overrides) + return invoke_request( "POST", "/v0/gmail-accounts", - payload={ - "user_id": str(user_id), - "provider_account_id": provider_account_id, - "email_address": email_address, - "display_name": email_address.split("@", 1)[0].title(), - "scope": "https://www.googleapis.com/auth/gmail.readonly", - "access_token": f"token-for-{provider_account_id}", - }, + payload=payload, ) @@ -244,6 +255,8 @@ def test_gmail_account_endpoints_connect_list_detail_and_isolate( assert '"access_token":' not in json.dumps(create_payload) assert '"access_token":' not in json.dumps(list_payload) assert '"access_token":' not in json.dumps(detail_payload) + assert '"refresh_token":' not in json.dumps(create_payload) + assert '"client_secret":' not in json.dumps(create_payload) with psycopg.connect(migrated_database_urls["admin"]) as conn: with conn.cursor() as cur: @@ -363,6 +376,104 @@ def test_gmail_message_ingestion_endpoint_persists_artifact_and_chunks( assert chunk_rows[0]["text"].startswith("From: Alice ") +def test_gmail_message_ingestion_endpoint_renews_expired_access_token( + migrated_database_urls, + monkeypatch, + tmp_path, +) -> None: + owner = seed_task(migrated_database_urls["app"], email="owner@example.com") + workspace_root = tmp_path / "task-workspaces" + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings( + database_url=migrated_database_urls["app"], + task_workspace_root=str(workspace_root), + ), + ) + raw_bytes = _build_rfc822_email_bytes(subject="Inbox Update", plain_body="renewed token path") + fetch_tokens: list[str] = [] + + monkeypatch.setattr( + gmail_module, + "refresh_gmail_access_token", + lambda **_kwargs: ("token-refreshed", datetime.fromisoformat("2030-01-01T00:05:00+00:00")), + ) + + def fake_fetch_gmail_message_raw_bytes(*, access_token: str, **_kwargs) -> bytes: + fetch_tokens.append(access_token) + return raw_bytes + + monkeypatch.setattr( + gmail_module, + "fetch_gmail_message_raw_bytes", + fake_fetch_gmail_message_raw_bytes, + ) + + account_status, account_payload = _connect_gmail_account( + user_id=owner["user_id"], + provider_account_id="acct-owner-refresh-001", + email_address="owner@gmail.example", + credential_overrides={ + "refresh_token": "refresh-owner-001", + "client_id": "client-owner-001", + "client_secret": "secret-owner-001", + "access_token_expires_at": "2020-01-01T00:00:00+00:00", + }, + ) + workspace_status, workspace_payload = invoke_request( + "POST", + f"/v0/tasks/{owner['task_id']}/workspace", + payload={"user_id": str(owner["user_id"])}, + ) + ingest_status, ingest_payload = invoke_request( + "POST", + f"/v0/gmail-accounts/{account_payload['account']['id']}/messages/msg-001/ingest", + payload={ + "user_id": str(owner["user_id"]), + "task_workspace_id": workspace_payload["workspace"]["id"], + }, + ) + + assert account_status == 201 + assert workspace_status == 201 + assert ingest_status == 200 + assert fetch_tokens == ["token-refreshed"] + assert ingest_payload["message"] == { + "provider_message_id": "msg-001", + "artifact_relative_path": "gmail/acct-owner-refresh-001/msg-001.eml", + 
"media_type": "message/rfc822", + } + assert '"refresh_token":' not in json.dumps(ingest_payload) + assert '"client_secret":' not in json.dumps(ingest_payload) + + with psycopg.connect(migrated_database_urls["admin"]) as conn: + with conn.cursor() as cur: + cur.execute( + """ + SELECT + credential_blob ->> 'credential_kind', + credential_blob ->> 'access_token', + credential_blob ->> 'refresh_token', + credential_blob ->> 'client_id', + credential_blob ->> 'client_secret', + credential_blob ->> 'access_token_expires_at' + FROM gmail_account_credentials + WHERE gmail_account_id = %s + """, + (UUID(account_payload["account"]["id"]),), + ) + credential_row = cur.fetchone() + + assert credential_row is not None + assert credential_row[0] == "gmail_oauth_refresh_token_v2" + assert credential_row[1] == "token-refreshed" + assert credential_row[2] == "refresh-owner-001" + assert credential_row[3] == "client-owner-001" + assert credential_row[4] == "secret-owner-001" + assert credential_row[5] == "2030-01-01T00:05:00+00:00" + + def test_gmail_message_ingestion_endpoint_rejects_cross_user_workspace_access( migrated_database_urls, monkeypatch, @@ -478,6 +589,78 @@ def fail_fetch(**_kwargs): assert store.list_task_artifacts_for_task(owner["task_id"]) == [] +def test_gmail_message_ingestion_endpoint_rejects_invalid_refresh_credentials_without_side_effects( + migrated_database_urls, + monkeypatch, + tmp_path, +) -> None: + owner = seed_task(migrated_database_urls["app"], email="owner@example.com") + workspace_root = tmp_path / "task-workspaces" + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings( + database_url=migrated_database_urls["app"], + task_workspace_root=str(workspace_root), + ), + ) + + _, account_payload = _connect_gmail_account( + user_id=owner["user_id"], + provider_account_id="acct-owner-refresh-001", + email_address="owner@gmail.example", + credential_overrides={ + "refresh_token": "refresh-owner-001", + "client_id": "client-owner-001", + "client_secret": "secret-owner-001", + "access_token_expires_at": "2020-01-01T00:00:00+00:00", + }, + ) + _, workspace_payload = invoke_request( + "POST", + f"/v0/tasks/{owner['task_id']}/workspace", + payload={"user_id": str(owner["user_id"])}, + ) + + def fail_refresh(**_kwargs): + raise gmail_module.GmailCredentialInvalidError( + f"gmail account {account_payload['account']['id']} refresh credentials were rejected" + ) + + monkeypatch.setattr(gmail_module, "refresh_gmail_access_token", fail_refresh) + monkeypatch.setattr( + gmail_module, + "fetch_gmail_message_raw_bytes", + lambda **_kwargs: (_ for _ in ()).throw( + AssertionError("fetch_gmail_message_raw_bytes should not be called") + ), + ) + + ingest_status, ingest_payload = invoke_request( + "POST", + f"/v0/gmail-accounts/{account_payload['account']['id']}/messages/msg-001/ingest", + payload={ + "user_id": str(owner["user_id"]), + "task_workspace_id": workspace_payload["workspace"]["id"], + }, + ) + + assert ingest_status == 409 + assert ingest_payload == { + "detail": ( + f"gmail account {account_payload['account']['id']} refresh credentials were rejected" + ) + } + artifact_file = ( + Path(workspace_payload["workspace"]["local_path"]) / "gmail" / "acct-owner-refresh-001" / "msg-001.eml" + ) + assert not artifact_file.exists() + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + assert store.list_task_artifacts_for_task(owner["task_id"]) == [] + + def 
test_gmail_message_ingestion_endpoint_rejects_sanitized_path_collisions_without_overwrite( migrated_database_urls, monkeypatch, diff --git a/tests/integration/test_migrations.py b/tests/integration/test_migrations.py index 7e4817d..aba3133 100644 --- a/tests/integration/test_migrations.py +++ b/tests/integration/test_migrations.py @@ -347,6 +347,126 @@ def test_gmail_account_credentials_migration_round_trip_preserves_tokens(databas assert cur.fetchone() == (None,) +def test_gmail_refresh_token_lifecycle_migration_round_trip_preserves_downgrade_compatibility( + database_urls, +): + config = make_alembic_config(database_urls["admin"]) + user_id = "00000000-0000-0000-0000-000000000201" + gmail_account_id = "00000000-0000-0000-0000-000000000202" + + command.upgrade(config, "20260316_0027") + + with psycopg.connect(database_urls["admin"]) as conn: + with conn.cursor() as cur: + cur.execute( + """ + INSERT INTO users (id, email, display_name) + VALUES (%s, 'gmail-refresh@example.com', 'Gmail Refresh User') + """, + (user_id,), + ) + cur.execute( + """ + INSERT INTO gmail_accounts ( + id, + user_id, + provider_account_id, + email_address, + display_name, + scope + ) + VALUES ( + %s, + %s, + 'acct-refresh-001', + 'owner@gmail.example', + 'Owner', + 'https://www.googleapis.com/auth/gmail.readonly' + ) + """, + (gmail_account_id, user_id), + ) + cur.execute( + """ + INSERT INTO gmail_account_credentials ( + gmail_account_id, + user_id, + auth_kind, + credential_blob + ) + VALUES ( + %s, + %s, + 'oauth_access_token', + jsonb_build_object( + 'credential_kind', 'gmail_oauth_access_token_v1', + 'access_token', 'token-before-refresh-lifecycle' + ) + ) + """, + (gmail_account_id, user_id), + ) + conn.commit() + + command.upgrade(config, "20260316_0028") + + with psycopg.connect(database_urls["admin"]) as conn: + with conn.cursor() as cur: + cur.execute( + """ + UPDATE gmail_account_credentials + SET credential_blob = jsonb_build_object( + 'credential_kind', 'gmail_oauth_refresh_token_v2', + 'access_token', 'token-after-refresh', + 'refresh_token', 'refresh-001', + 'client_id', 'client-001', + 'client_secret', 'secret-001', + 'access_token_expires_at', '2030-01-01T00:05:00+00:00' + ) + WHERE gmail_account_id = %s + """, + (gmail_account_id,), + ) + cur.execute( + """ + SELECT + credential_blob ->> 'credential_kind', + credential_blob ->> 'access_token', + credential_blob ->> 'refresh_token' + FROM gmail_account_credentials + WHERE gmail_account_id = %s + """, + (gmail_account_id,), + ) + assert cur.fetchone() == ( + "gmail_oauth_refresh_token_v2", + "token-after-refresh", + "refresh-001", + ) + conn.commit() + + command.downgrade(config, "20260316_0027") + + with psycopg.connect(database_urls["admin"]) as conn: + with conn.cursor() as cur: + cur.execute( + """ + SELECT + credential_blob ->> 'credential_kind', + credential_blob ->> 'access_token', + credential_blob ? 
'refresh_token' + FROM gmail_account_credentials + WHERE gmail_account_id = %s + """, + (gmail_account_id,), + ) + assert cur.fetchone() == ( + "gmail_oauth_access_token_v1", + "token-after-refresh", + False, + ) + + def test_migrations_upgrade_and_downgrade(database_urls): config = make_alembic_config(database_urls["admin"]) diff --git a/tests/unit/test_20260316_0028_gmail_refresh_token_lifecycle.py b/tests/unit/test_20260316_0028_gmail_refresh_token_lifecycle.py new file mode 100644 index 0000000..0c180c2 --- /dev/null +++ b/tests/unit/test_20260316_0028_gmail_refresh_token_lifecycle.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +import importlib + + +MODULE_NAME = "apps.api.alembic.versions.20260316_0028_gmail_refresh_token_lifecycle" + + +def load_migration_module(): + return importlib.import_module(MODULE_NAME) + + +def test_upgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.upgrade() + + assert executed == list(module._UPGRADE_STATEMENTS) + + +def test_downgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.downgrade() + + assert executed == list(module._DOWNGRADE_STATEMENTS) + + +def test_gmail_account_credential_privileges_allow_runtime_updates() -> None: + module = load_migration_module() + + assert module._UPGRADE_STATEMENTS[-1] == ( + "GRANT UPDATE ON gmail_account_credentials TO alicebot_app" + ) diff --git a/tests/unit/test_gmail.py b/tests/unit/test_gmail.py index b84c945..f482807 100644 --- a/tests/unit/test_gmail.py +++ b/tests/unit/test_gmail.py @@ -9,6 +9,7 @@ from alicebot_api.artifacts import TaskArtifactAlreadyExistsError from alicebot_api.contracts import ( GMAIL_PROTECTED_CREDENTIAL_KIND, + GMAIL_REFRESHABLE_PROTECTED_CREDENTIAL_KIND, GMAIL_READONLY_SCOPE, GmailAccountConnectInput, GmailMessageIngestInput, @@ -18,6 +19,7 @@ GmailAccountNotFoundError, GmailCredentialInvalidError, GmailCredentialNotFoundError, + GmailCredentialValidationError, GmailMessageUnsupportedError, build_gmail_message_artifact_relative_path, build_gmail_protected_credential_blob, @@ -112,6 +114,24 @@ def get_gmail_account_credential_optional( self.operations.append(("get_gmail_account_credential_optional", gmail_account_id)) return self.gmail_account_credentials.get(gmail_account_id) + def update_gmail_account_credential( + self, + *, + gmail_account_id: UUID, + auth_kind: str, + credential_blob: dict[str, object], + ) -> dict[str, object]: + existing = self.gmail_account_credentials[gmail_account_id] + updated = { + **existing, + "auth_kind": auth_kind, + "credential_blob": credential_blob, + "updated_at": self.base_time + timedelta(hours=1), + } + self.gmail_account_credentials[gmail_account_id] = updated + self.operations.append(("update_gmail_account_credential", gmail_account_id)) + return updated + def get_gmail_account_by_provider_account_id_optional( self, provider_account_id: str, @@ -271,6 +291,67 @@ def test_create_gmail_account_record_persists_protected_credential_and_hides_sec assert store.operations == [("create_gmail_account_credential", account_id)] +def test_create_gmail_account_record_persists_refreshable_protected_credential_and_hides_secret() -> None: + store = GmailStoreStub() + user_id = uuid4() + expires_at = datetime(2030, 1, 1, 0, 0, tzinfo=UTC) + + response = 
create_gmail_account_record( + store, + user_id=user_id, + request=GmailAccountConnectInput( + provider_account_id="acct-refresh-001", + email_address="owner@example.com", + display_name="Owner", + scope=GMAIL_READONLY_SCOPE, + access_token="token-1", + refresh_token="refresh-1", + client_id="client-1", + client_secret="secret-1", + access_token_expires_at=expires_at, + ), + ) + + account_id = UUID(response["account"]["id"]) + assert response == { + "account": { + "id": str(account_id), + "provider": "gmail", + "auth_kind": "oauth_access_token", + "provider_account_id": "acct-refresh-001", + "email_address": "owner@example.com", + "display_name": "Owner", + "scope": GMAIL_READONLY_SCOPE, + "created_at": response["account"]["created_at"], + "updated_at": response["account"]["updated_at"], + } + } + assert store.gmail_account_credentials[account_id]["credential_blob"] == { + "credential_kind": GMAIL_REFRESHABLE_PROTECTED_CREDENTIAL_KIND, + "access_token": "token-1", + "refresh_token": "refresh-1", + "client_id": "client-1", + "client_secret": "secret-1", + "access_token_expires_at": expires_at.isoformat(), + } + assert store.operations == [("create_gmail_account_credential", account_id)] + + +def test_build_gmail_protected_credential_blob_rejects_partial_refresh_bundle() -> None: + with pytest.raises( + GmailCredentialValidationError, + match=( + "gmail refresh credentials must include refresh_token, client_id, client_secret, " + "and access_token_expires_at" + ), + ): + build_gmail_protected_credential_blob( + access_token="token-1", + refresh_token="refresh-1", + client_id="client-1", + ) + + def test_create_gmail_account_record_rejects_duplicate_provider_account_id() -> None: store = GmailStoreStub() user_id = uuid4() @@ -356,6 +437,76 @@ def test_resolve_gmail_access_token_rejects_missing_and_invalid_protected_creden resolve_gmail_access_token(store, gmail_account_id=account_id) +def test_resolve_gmail_access_token_renews_expired_refreshable_credential(monkeypatch) -> None: + store = GmailStoreStub() + expired_at = datetime(2020, 1, 1, 0, 0, tzinfo=UTC) + refreshed_at = datetime(2030, 1, 1, 0, 5, tzinfo=UTC) + account = create_gmail_account_record( + store, + user_id=uuid4(), + request=GmailAccountConnectInput( + provider_account_id="acct-001", + email_address="owner@example.com", + display_name="Owner", + scope=GMAIL_READONLY_SCOPE, + access_token="token-1", + refresh_token="refresh-1", + client_id="client-1", + client_secret="secret-1", + access_token_expires_at=expired_at, + ), + )["account"] + account_id = UUID(account["id"]) + + monkeypatch.setattr( + "alicebot_api.gmail.refresh_gmail_access_token", + lambda **_kwargs: ("token-2", refreshed_at), + ) + + assert resolve_gmail_access_token(store, gmail_account_id=account_id) == "token-2" + assert store.gmail_account_credentials[account_id]["credential_blob"] == { + "credential_kind": GMAIL_REFRESHABLE_PROTECTED_CREDENTIAL_KIND, + "access_token": "token-2", + "refresh_token": "refresh-1", + "client_id": "client-1", + "client_secret": "secret-1", + "access_token_expires_at": refreshed_at.isoformat(), + } + assert store.operations[-2:] == [ + ("get_gmail_account_credential_optional", account_id), + ("update_gmail_account_credential", account_id), + ] + + +def test_resolve_gmail_access_token_rejects_invalid_refreshable_protected_credentials() -> None: + store = GmailStoreStub() + account = create_gmail_account_record( + store, + user_id=uuid4(), + request=GmailAccountConnectInput( + provider_account_id="acct-001", + 
email_address="owner@example.com", + display_name="Owner", + scope=GMAIL_READONLY_SCOPE, + access_token="token-1", + ), + )["account"] + account_id = UUID(account["id"]) + store.gmail_account_credentials[account_id]["credential_blob"] = { + "credential_kind": GMAIL_REFRESHABLE_PROTECTED_CREDENTIAL_KIND, + "access_token": "token-1", + "client_id": "client-1", + "client_secret": "secret-1", + "access_token_expires_at": "2020-01-01T00:00:00+00:00", + } + + with pytest.raises( + GmailCredentialInvalidError, + match=f"gmail account {account_id} has invalid protected credentials", + ): + resolve_gmail_access_token(store, gmail_account_id=account_id) + + def test_ingest_gmail_message_record_writes_rfc822_artifact_and_reuses_artifact_seam( monkeypatch, tmp_path, @@ -480,6 +631,110 @@ def fake_ingest(_store, *, user_id: UUID, request): ] +def test_ingest_gmail_message_record_renews_expired_access_token_before_fetch( + monkeypatch, + tmp_path, +) -> None: + store = GmailStoreStub() + user_id = uuid4() + workspace_id = uuid4() + workspace = store.create_task_workspace( + task_workspace_id=workspace_id, + local_path=str((tmp_path / "workspace").resolve()), + ) + account = create_gmail_account_record( + store, + user_id=user_id, + request=GmailAccountConnectInput( + provider_account_id="acct-001", + email_address="owner@example.com", + display_name="Owner", + scope=GMAIL_READONLY_SCOPE, + access_token="token-expired", + refresh_token="refresh-1", + client_id="client-1", + client_secret="secret-1", + access_token_expires_at=datetime(2020, 1, 1, 0, 0, tzinfo=UTC), + ), + )["account"] + raw_bytes = _build_rfc822_email_bytes(plain_body="hello from gmail") + calls: dict[str, object] = {} + + monkeypatch.setattr( + "alicebot_api.gmail.refresh_gmail_access_token", + lambda **_kwargs: ("token-refreshed", datetime(2030, 1, 1, 0, 5, tzinfo=UTC)), + ) + + def fake_fetch(**kwargs): + calls["fetch_access_token"] = kwargs["access_token"] + return raw_bytes + + monkeypatch.setattr("alicebot_api.gmail.fetch_gmail_message_raw_bytes", fake_fetch) + + monkeypatch.setattr( + "alicebot_api.gmail.register_task_artifact_record", + lambda _store, *, user_id, request: { + "artifact": { + "id": "00000000-0000-0000-0000-000000000123", + "task_id": str(workspace["task_id"]), + "task_workspace_id": str(workspace_id), + "status": "registered", + "ingestion_status": "pending", + "relative_path": Path(request.local_path) + .relative_to(Path(workspace["local_path"])) + .as_posix(), + "media_type_hint": "message/rfc822", + "created_at": "2026-03-16T10:00:00+00:00", + "updated_at": "2026-03-16T10:00:00+00:00", + } + }, + ) + monkeypatch.setattr( + "alicebot_api.gmail.ingest_task_artifact_record", + lambda _store, *, user_id, request: { + "artifact": { + "id": "00000000-0000-0000-0000-000000000123", + "task_id": str(workspace["task_id"]), + "task_workspace_id": str(workspace_id), + "status": "registered", + "ingestion_status": "ingested", + "relative_path": build_gmail_message_artifact_relative_path( + provider_account_id="acct-001", + provider_message_id="msg-001", + ), + "media_type_hint": "message/rfc822", + "created_at": "2026-03-16T10:00:00+00:00", + "updated_at": "2026-03-16T10:00:01+00:00", + }, + "summary": { + "total_count": 1, + "total_characters": 16, + "media_type": "message/rfc822", + "chunking_rule": "normalized_utf8_text_fixed_window_1000_chars_v1", + "order": ["sequence_no_asc", "id_asc"], + }, + }, + ) + + response = ingest_gmail_message_record( + store, + user_id=user_id, + request=GmailMessageIngestInput( + 
gmail_account_id=UUID(account["id"]), + task_workspace_id=workspace_id, + provider_message_id="msg-001", + ), + ) + + assert response["message"]["artifact_relative_path"] == "gmail/acct-001/msg-001.eml" + assert calls["fetch_access_token"] == "token-refreshed" + assert store.operations[:3] == [ + ("create_gmail_account_credential", UUID(account["id"])), + ("get_gmail_account_credential_optional", UUID(account["id"])), + ("update_gmail_account_credential", UUID(account["id"])), + ] + + def test_ingest_gmail_message_record_rejects_unsupported_message(monkeypatch, tmp_path) -> None: store = GmailStoreStub() user_id = uuid4() @@ -688,9 +943,10 @@ def test_ingest_gmail_message_record_rejects_invalid_protected_credentials_befor ), )["account"] account_id = UUID(account["id"]) - store.gmail_account_credentials[account_id]["credential_blob"] = build_gmail_protected_credential_blob( - access_token="", - ) + store.gmail_account_credentials[account_id]["credential_blob"] = { + "credential_kind": GMAIL_PROTECTED_CREDENTIAL_KIND, + "access_token": "", + } def fail_fetch(**_kwargs): raise AssertionError("fetch_gmail_message_raw_bytes should not be called") diff --git a/tests/unit/test_gmail_main.py b/tests/unit/test_gmail_main.py index 18fc375..24adb3a 100644 --- a/tests/unit/test_gmail_main.py +++ b/tests/unit/test_gmail_main.py @@ -4,6 +4,7 @@ from contextlib import contextmanager from uuid import uuid4 +import pytest import apps.api.src.alicebot_api.main as main_module from apps.api.src.alicebot_api.config import Settings from alicebot_api.gmail import ( @@ -11,6 +12,8 @@ GmailAccountNotFoundError, GmailCredentialInvalidError, GmailCredentialNotFoundError, + GmailCredentialRefreshError, + GmailCredentialValidationError, GmailMessageFetchError, GmailMessageNotFoundError, GmailMessageUnsupportedError, @@ -77,6 +80,55 @@ def fake_create_gmail_account_record(*_args, **_kwargs): } +def test_connect_gmail_account_request_requires_complete_refresh_bundle() -> None: + with pytest.raises(ValueError, match="gmail refresh credentials must include refresh_token"): + main_module.ConnectGmailAccountRequest( + user_id=uuid4(), + provider_account_id="acct-001", + email_address="owner@example.com", + display_name="Owner", + access_token="token-1", + refresh_token="refresh-1", + ) + + +def test_connect_gmail_account_endpoint_maps_invalid_refresh_bundle_to_400(monkeypatch) -> None: + user_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_create_gmail_account_record(*_args, **_kwargs): + raise GmailCredentialValidationError( + "gmail refresh credentials must include refresh_token, client_id, client_secret, " + "and access_token_expires_at" + ) + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "create_gmail_account_record", fake_create_gmail_account_record) + + response = main_module.connect_gmail_account( + main_module.ConnectGmailAccountRequest( + user_id=user_id, + provider_account_id="acct-001", + email_address="owner@example.com", + display_name="Owner", + access_token="token-1", + ) + ) + + assert response.status_code == 400 + assert json.loads(response.body) == { + "detail": ( + "gmail refresh credentials must include refresh_token, client_id, client_secret, " + "and access_token_expires_at" + ) + } + + def 
test_get_gmail_account_endpoint_maps_not_found_to_404(monkeypatch) -> None: user_id = uuid4() gmail_account_id = uuid4() @@ -230,3 +282,22 @@ def fake_fetch_error(*_args, **_kwargs): assert json.loads(response.body) == { "detail": "gmail message msg-001 could not be fetched" } + + def fake_refresh_error(*_args, **_kwargs): + raise GmailCredentialRefreshError( + f"gmail account {gmail_account_id} access token could not be renewed" + ) + + monkeypatch.setattr(main_module, "ingest_gmail_message_record", fake_refresh_error) + response = main_module.ingest_gmail_message( + gmail_account_id, + "msg-001", + main_module.IngestGmailMessageRequest( + user_id=user_id, + task_workspace_id=task_workspace_id, + ), + ) + assert response.status_code == 502 + assert json.loads(response.body) == { + "detail": f"gmail account {gmail_account_id} access token could not be renewed" + } diff --git a/tests/unit/test_gmail_refresh.py b/tests/unit/test_gmail_refresh.py new file mode 100644 index 0000000..779e454 --- /dev/null +++ b/tests/unit/test_gmail_refresh.py @@ -0,0 +1,167 @@ +from __future__ import annotations + +import json +from datetime import UTC, datetime, timedelta +from io import BytesIO +from urllib.error import HTTPError, URLError +from urllib.parse import parse_qs +from uuid import uuid4 + +import pytest + +from alicebot_api.gmail import ( + GMAIL_TOKEN_REFRESH_TIMEOUT_SECONDS, + GMAIL_TOKEN_REFRESH_URL, + GmailCredentialInvalidError, + GmailCredentialRefreshError, + refresh_gmail_access_token, +) + + +class _FakeHTTPResponse: + def __init__(self, payload: bytes) -> None: + self._payload = payload + + def __enter__(self) -> _FakeHTTPResponse: + return self + + def __exit__(self, exc_type, exc, tb) -> bool: + return False + + def read(self) -> bytes: + return self._payload + + +def _make_http_error(status_code: int) -> HTTPError: + return HTTPError( + GMAIL_TOKEN_REFRESH_URL, + status_code, + "upstream error", + hdrs=None, + fp=BytesIO(b'{"error":"invalid_grant"}'), + ) + + +def test_refresh_gmail_access_token_posts_expected_payload_and_returns_expiry(monkeypatch) -> None: + gmail_account_id = uuid4() + seen: dict[str, object] = {} + + def fake_urlopen(request, timeout: int): + seen["url"] = request.full_url + seen["timeout"] = timeout + seen["content_type"] = request.headers["Content-type"] + seen["accept"] = request.headers["Accept"] + seen["body"] = parse_qs(request.data.decode("utf-8")) + return _FakeHTTPResponse( + json.dumps({"access_token": "token-refreshed", "expires_in": 3600}).encode("utf-8") + ) + + monkeypatch.setattr("alicebot_api.gmail.urlopen", fake_urlopen) + + started_at = datetime.now(UTC) + access_token, expires_at = refresh_gmail_access_token( + gmail_account_id=gmail_account_id, + refresh_token="refresh-001", + client_id="client-001", + client_secret="secret-001", + ) + finished_at = datetime.now(UTC) + + assert access_token == "token-refreshed" + assert started_at + timedelta(seconds=3590) <= expires_at <= finished_at + timedelta(seconds=3610) + assert seen == { + "url": GMAIL_TOKEN_REFRESH_URL, + "timeout": GMAIL_TOKEN_REFRESH_TIMEOUT_SECONDS, + "content_type": "application/x-www-form-urlencoded", + "accept": "application/json", + "body": { + "client_id": ["client-001"], + "client_secret": ["secret-001"], + "refresh_token": ["refresh-001"], + "grant_type": ["refresh_token"], + }, + } + + +@pytest.mark.parametrize("status_code", [400, 401]) +def test_refresh_gmail_access_token_maps_invalid_refresh_rejections_to_invalid_error( + monkeypatch, + status_code: int, +) -> None: + 
gmail_account_id = uuid4() + + def fake_urlopen(_request, timeout: int): + assert timeout == GMAIL_TOKEN_REFRESH_TIMEOUT_SECONDS + raise _make_http_error(status_code) + + monkeypatch.setattr("alicebot_api.gmail.urlopen", fake_urlopen) + + with pytest.raises( + GmailCredentialInvalidError, + match=f"gmail account {gmail_account_id} refresh credentials were rejected", + ): + refresh_gmail_access_token( + gmail_account_id=gmail_account_id, + refresh_token="refresh-001", + client_id="client-001", + client_secret="secret-001", + ) + + +def test_refresh_gmail_access_token_maps_non_deterministic_http_failure_to_refresh_error( + monkeypatch, +) -> None: + gmail_account_id = uuid4() + + def fake_urlopen(_request, timeout: int): + assert timeout == GMAIL_TOKEN_REFRESH_TIMEOUT_SECONDS + raise _make_http_error(500) + + monkeypatch.setattr("alicebot_api.gmail.urlopen", fake_urlopen) + + with pytest.raises( + GmailCredentialRefreshError, + match=f"gmail account {gmail_account_id} access token could not be renewed", + ): + refresh_gmail_access_token( + gmail_account_id=gmail_account_id, + refresh_token="refresh-001", + client_id="client-001", + client_secret="secret-001", + ) + + +@pytest.mark.parametrize( + ("response_payload", "error"), + [ + (b"not-json", None), + (json.dumps({"expires_in": 3600}).encode("utf-8"), None), + (None, URLError("network down")), + ], +) +def test_refresh_gmail_access_token_maps_malformed_or_transport_failures_to_refresh_error( + monkeypatch, + response_payload: bytes | None, + error: Exception | None, +) -> None: + gmail_account_id = uuid4() + + def fake_urlopen(_request, timeout: int): + assert timeout == GMAIL_TOKEN_REFRESH_TIMEOUT_SECONDS + if error is not None: + raise error + assert response_payload is not None + return _FakeHTTPResponse(response_payload) + + monkeypatch.setattr("alicebot_api.gmail.urlopen", fake_urlopen) + + with pytest.raises( + GmailCredentialRefreshError, + match=f"gmail account {gmail_account_id} access token could not be renewed", + ): + refresh_gmail_access_token( + gmail_account_id=gmail_account_id, + refresh_token="refresh-001", + client_id="client-001", + client_secret="secret-001", + ) From 945a6b8adad7d3143d82c57efccdeaeb8bc72e42 Mon Sep 17 00:00:00 2001 From: Sam Date: Mon, 16 Mar 2026 19:44:00 +0100 Subject: [PATCH 018/135] Sprint 5R: Gmail refresh-token rotation handling (#18) Co-authored-by: Sami Rusani --- apps/api/src/alicebot_api/gmail.py | 69 ++++-- apps/api/src/alicebot_api/main.py | 7 +- tests/integration/test_gmail_accounts_api.py | 233 ++++++++++++++++++- tests/unit/test_gmail.py | 101 +++++++- tests/unit/test_gmail_main.py | 20 ++ tests/unit/test_gmail_refresh.py | 42 +++- 6 files changed, 452 insertions(+), 20 deletions(-) diff --git a/apps/api/src/alicebot_api/gmail.py b/apps/api/src/alicebot_api/gmail.py index 5d60a33..016c8c3 100644 --- a/apps/api/src/alicebot_api/gmail.py +++ b/apps/api/src/alicebot_api/gmail.py @@ -39,7 +39,7 @@ TaskArtifactIngestInput, TaskArtifactRegisterInput, ) -from alicebot_api.store import ContinuityStore, GmailAccountRow +from alicebot_api.store import ContinuityStore, ContinuityStoreInvariantError, GmailAccountRow from alicebot_api.workspaces import TaskWorkspaceNotFoundError GMAIL_MESSAGE_FETCH_TIMEOUT_SECONDS = 30 @@ -81,6 +81,10 @@ class GmailCredentialRefreshError(RuntimeError): """Raised when Gmail access-token renewal fails for non-deterministic reasons.""" +class GmailCredentialPersistenceError(RuntimeError): + """Raised when renewed Gmail protected credentials cannot be persisted.""" + + 
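# Context for the rotation handling introduced below: Google's OAuth token
# endpoint may renew only the access token, or it may also rotate the refresh
# token. A minimal sketch of the two response shapes involved, with illustrative
# values mirroring this commit's test fixtures (not an exhaustive payload):
#
#   non_rotating_response = {"access_token": "token-refreshed", "expires_in": 3600}
#   rotating_response = {
#       "access_token": "token-refreshed",
#       "expires_in": 3600,
#       "refresh_token": "refresh-rotated",
#   }
#
# When "refresh_token" is absent, the previously stored refresh token is kept;
# when present, the rotated value must replace the stored one, since later
# renewals could fail once the provider invalidates the old refresh token.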
class GmailCredentialValidationError(ValueError): """Raised when Gmail connect input contains an invalid credential combination.""" @@ -95,6 +99,13 @@ class ParsedGmailCredential: access_token_expires_at: datetime | None = None +@dataclass(frozen=True, slots=True) +class RefreshedGmailCredential: + access_token: str + access_token_expires_at: datetime + refresh_token: str | None = None + + def serialize_gmail_account_row(row: GmailAccountRow) -> GmailAccountRecord: return { "id": str(row["id"]), @@ -242,7 +253,7 @@ def refresh_gmail_access_token( refresh_token: str, client_id: str, client_secret: str, -) -> tuple[str, datetime]: +) -> RefreshedGmailCredential: request = Request( GMAIL_TOKEN_REFRESH_URL, data=urlencode( @@ -277,6 +288,7 @@ def refresh_gmail_access_token( ) from exc refreshed_access_token = _coerce_nonempty_string(payload.get("access_token")) + replacement_refresh_token = _coerce_nonempty_string(payload.get("refresh_token")) expires_in = payload.get("expires_in") if refreshed_access_token is None or not isinstance(expires_in, (int, float)) or expires_in <= 0: raise GmailCredentialRefreshError( @@ -284,7 +296,42 @@ def refresh_gmail_access_token( ) refreshed_expires_at = datetime.now(UTC) + timedelta(seconds=float(expires_in)) - return refreshed_access_token, refreshed_expires_at + return RefreshedGmailCredential( + access_token=refreshed_access_token, + access_token_expires_at=refreshed_expires_at, + refresh_token=replacement_refresh_token, + ) + + +def _persist_refreshed_gmail_credential( + store: ContinuityStore, + *, + gmail_account_id: UUID, + auth_kind: str, + existing_credential: ParsedGmailCredential, + refreshed_credential: RefreshedGmailCredential, +) -> None: + replacement_refresh_token = ( + refreshed_credential.refresh_token + if refreshed_credential.refresh_token is not None + else existing_credential.refresh_token + ) + try: + store.update_gmail_account_credential( + gmail_account_id=gmail_account_id, + auth_kind=auth_kind, + credential_blob=build_gmail_protected_credential_blob( + access_token=refreshed_credential.access_token, + refresh_token=replacement_refresh_token, + client_id=existing_credential.client_id, + client_secret=existing_credential.client_secret, + access_token_expires_at=refreshed_credential.access_token_expires_at, + ), + ) + except (ContinuityStoreInvariantError, psycopg.Error) as exc: + raise GmailCredentialPersistenceError( + f"gmail account {gmail_account_id} renewed protected credentials could not be persisted" + ) from exc def resolve_gmail_access_token( @@ -314,24 +361,20 @@ def resolve_gmail_access_token( ): return parsed_credential.access_token - refreshed_access_token, refreshed_expires_at = refresh_gmail_access_token( + refreshed_credential = refresh_gmail_access_token( gmail_account_id=gmail_account_id, refresh_token=parsed_credential.refresh_token, client_id=parsed_credential.client_id, client_secret=parsed_credential.client_secret, ) - store.update_gmail_account_credential( + _persist_refreshed_gmail_credential( + store, gmail_account_id=gmail_account_id, auth_kind=credential["auth_kind"], - credential_blob=build_gmail_protected_credential_blob( - access_token=refreshed_access_token, - refresh_token=parsed_credential.refresh_token, - client_id=parsed_credential.client_id, - client_secret=parsed_credential.client_secret, - access_token_expires_at=refreshed_expires_at, - ), + existing_credential=parsed_credential, + refreshed_credential=refreshed_credential, ) - return refreshed_access_token + return 
refreshed_credential.access_token def create_gmail_account_record( diff --git a/apps/api/src/alicebot_api/main.py b/apps/api/src/alicebot_api/main.py index 7382b06..f23b11e 100644 --- a/apps/api/src/alicebot_api/main.py +++ b/apps/api/src/alicebot_api/main.py @@ -142,6 +142,7 @@ GmailAccountAlreadyExistsError, GmailCredentialInvalidError, GmailCredentialNotFoundError, + GmailCredentialPersistenceError, GmailCredentialRefreshError, GmailCredentialValidationError, GmailAccountNotFoundError, @@ -1404,7 +1405,11 @@ def ingest_gmail_message( return JSONResponse(status_code=404, content={"detail": str(exc)}) except GmailMessageUnsupportedError as exc: return JSONResponse(status_code=400, content={"detail": str(exc)}) - except (GmailCredentialNotFoundError, GmailCredentialInvalidError) as exc: + except ( + GmailCredentialNotFoundError, + GmailCredentialInvalidError, + GmailCredentialPersistenceError, + ) as exc: return JSONResponse(status_code=409, content={"detail": str(exc)}) except TaskArtifactValidationError as exc: return JSONResponse(status_code=400, content={"detail": str(exc)}) diff --git a/tests/integration/test_gmail_accounts_api.py b/tests/integration/test_gmail_accounts_api.py index 8385474..ec4c7aa 100644 --- a/tests/integration/test_gmail_accounts_api.py +++ b/tests/integration/test_gmail_accounts_api.py @@ -397,7 +397,10 @@ def test_gmail_message_ingestion_endpoint_renews_expired_access_token( monkeypatch.setattr( gmail_module, "refresh_gmail_access_token", - lambda **_kwargs: ("token-refreshed", datetime.fromisoformat("2030-01-01T00:05:00+00:00")), + lambda **_kwargs: gmail_module.RefreshedGmailCredential( + access_token="token-refreshed", + access_token_expires_at=datetime.fromisoformat("2030-01-01T00:05:00+00:00"), + ), ) def fake_fetch_gmail_message_raw_bytes(*, access_token: str, **_kwargs) -> bytes: @@ -474,6 +477,234 @@ def fake_fetch_gmail_message_raw_bytes(*, access_token: str, **_kwargs) -> bytes assert credential_row[5] == "2030-01-01T00:05:00+00:00" +def test_gmail_message_ingestion_endpoint_persists_rotated_refresh_token( + migrated_database_urls, + monkeypatch, + tmp_path, +) -> None: + owner = seed_task(migrated_database_urls["app"], email="owner@example.com") + workspace_root = tmp_path / "task-workspaces" + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings( + database_url=migrated_database_urls["app"], + task_workspace_root=str(workspace_root), + ), + ) + raw_bytes = _build_rfc822_email_bytes(subject="Inbox Update", plain_body="rotated token path") + fetch_tokens: list[str] = [] + + monkeypatch.setattr( + gmail_module, + "refresh_gmail_access_token", + lambda **_kwargs: gmail_module.RefreshedGmailCredential( + access_token="token-refreshed-rotated", + access_token_expires_at=datetime.fromisoformat("2030-01-01T00:05:00+00:00"), + refresh_token="refresh-owner-rotated-002", + ), + ) + + def fake_fetch_gmail_message_raw_bytes(*, access_token: str, **_kwargs) -> bytes: + fetch_tokens.append(access_token) + return raw_bytes + + monkeypatch.setattr( + gmail_module, + "fetch_gmail_message_raw_bytes", + fake_fetch_gmail_message_raw_bytes, + ) + + account_status, account_payload = _connect_gmail_account( + user_id=owner["user_id"], + provider_account_id="acct-owner-rotated-001", + email_address="owner@gmail.example", + credential_overrides={ + "refresh_token": "refresh-owner-001", + "client_id": "client-owner-001", + "client_secret": "secret-owner-001", + "access_token_expires_at": "2020-01-01T00:00:00+00:00", + }, + ) + workspace_status, 
workspace_payload = invoke_request( + "POST", + f"/v0/tasks/{owner['task_id']}/workspace", + payload={"user_id": str(owner["user_id"])}, + ) + ingest_status, ingest_payload = invoke_request( + "POST", + f"/v0/gmail-accounts/{account_payload['account']['id']}/messages/msg-001/ingest", + payload={ + "user_id": str(owner["user_id"]), + "task_workspace_id": workspace_payload["workspace"]["id"], + }, + ) + + assert account_status == 201 + assert workspace_status == 201 + assert ingest_status == 200 + assert fetch_tokens == ["token-refreshed-rotated"] + assert ingest_payload == { + "account": { + **account_payload["account"], + }, + "message": { + "provider_message_id": "msg-001", + "artifact_relative_path": "gmail/acct-owner-rotated-001/msg-001.eml", + "media_type": "message/rfc822", + }, + "artifact": { + "id": ingest_payload["artifact"]["id"], + "task_id": str(owner["task_id"]), + "task_workspace_id": workspace_payload["workspace"]["id"], + "status": "registered", + "ingestion_status": "ingested", + "relative_path": "gmail/acct-owner-rotated-001/msg-001.eml", + "media_type_hint": "message/rfc822", + "created_at": ingest_payload["artifact"]["created_at"], + "updated_at": ingest_payload["artifact"]["updated_at"], + }, + "summary": { + "total_count": ingest_payload["summary"]["total_count"], + "total_characters": ingest_payload["summary"]["total_characters"], + "media_type": "message/rfc822", + "chunking_rule": "normalized_utf8_text_fixed_window_1000_chars_v1", + "order": ["sequence_no_asc", "id_asc"], + }, + } + assert '"refresh_token":' not in json.dumps(ingest_payload) + assert '"client_secret":' not in json.dumps(ingest_payload) + + with psycopg.connect(migrated_database_urls["admin"]) as conn: + with conn.cursor() as cur: + cur.execute( + """ + SELECT + credential_blob ->> 'credential_kind', + credential_blob ->> 'access_token', + credential_blob ->> 'refresh_token', + credential_blob ->> 'client_id', + credential_blob ->> 'client_secret', + credential_blob ->> 'access_token_expires_at' + FROM gmail_account_credentials + WHERE gmail_account_id = %s + """, + (UUID(account_payload["account"]["id"]),), + ) + credential_row = cur.fetchone() + + assert credential_row is not None + assert credential_row[0] == "gmail_oauth_refresh_token_v2" + assert credential_row[1] == "token-refreshed-rotated" + assert credential_row[2] == "refresh-owner-rotated-002" + assert credential_row[3] == "client-owner-001" + assert credential_row[4] == "secret-owner-001" + assert credential_row[5] == "2030-01-01T00:05:00+00:00" + + +def test_gmail_message_ingestion_endpoint_fails_deterministically_when_rotated_credentials_cannot_be_persisted( + migrated_database_urls, + monkeypatch, + tmp_path, +) -> None: + owner = seed_task(migrated_database_urls["app"], email="owner@example.com") + workspace_root = tmp_path / "task-workspaces" + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings( + database_url=migrated_database_urls["app"], + task_workspace_root=str(workspace_root), + ), + ) + + monkeypatch.setattr( + gmail_module, + "refresh_gmail_access_token", + lambda **_kwargs: gmail_module.RefreshedGmailCredential( + access_token="token-refreshed-rotated", + access_token_expires_at=datetime.fromisoformat("2030-01-01T00:05:00+00:00"), + refresh_token="refresh-owner-rotated-002", + ), + ) + + def fail_update_gmail_account_credential(self, **_kwargs): + raise psycopg.Error("simulated credential persistence failure") + + def fail_fetch(**_kwargs): + raise AssertionError("fetch_gmail_message_raw_bytes should 
not be called") + + monkeypatch.setattr( + ContinuityStore, + "update_gmail_account_credential", + fail_update_gmail_account_credential, + ) + monkeypatch.setattr(gmail_module, "fetch_gmail_message_raw_bytes", fail_fetch) + + _, account_payload = _connect_gmail_account( + user_id=owner["user_id"], + provider_account_id="acct-owner-rotated-001", + email_address="owner@gmail.example", + credential_overrides={ + "refresh_token": "refresh-owner-001", + "client_id": "client-owner-001", + "client_secret": "secret-owner-001", + "access_token_expires_at": "2020-01-01T00:00:00+00:00", + }, + ) + _, workspace_payload = invoke_request( + "POST", + f"/v0/tasks/{owner['task_id']}/workspace", + payload={"user_id": str(owner["user_id"])}, + ) + + ingest_status, ingest_payload = invoke_request( + "POST", + f"/v0/gmail-accounts/{account_payload['account']['id']}/messages/msg-001/ingest", + payload={ + "user_id": str(owner["user_id"]), + "task_workspace_id": workspace_payload["workspace"]["id"], + }, + ) + + assert ingest_status == 409 + assert ingest_payload == { + "detail": ( + f"gmail account {account_payload['account']['id']} renewed protected credentials " + "could not be persisted" + ) + } + artifact_file = ( + Path(workspace_payload["workspace"]["local_path"]) / "gmail" / "acct-owner-rotated-001" / "msg-001.eml" + ) + assert not artifact_file.exists() + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + assert store.list_task_artifacts_for_task(owner["task_id"]) == [] + + with psycopg.connect(migrated_database_urls["admin"]) as conn: + with conn.cursor() as cur: + cur.execute( + """ + SELECT + credential_blob ->> 'access_token', + credential_blob ->> 'refresh_token', + credential_blob ->> 'access_token_expires_at' + FROM gmail_account_credentials + WHERE gmail_account_id = %s + """, + (UUID(account_payload["account"]["id"]),), + ) + credential_row = cur.fetchone() + + assert credential_row == ( + f"token-for-{account_payload['account']['provider_account_id']}", + "refresh-owner-001", + "2020-01-01T00:00:00+00:00", + ) + + def test_gmail_message_ingestion_endpoint_rejects_cross_user_workspace_access( migrated_database_urls, monkeypatch, diff --git a/tests/unit/test_gmail.py b/tests/unit/test_gmail.py index f482807..3242532 100644 --- a/tests/unit/test_gmail.py +++ b/tests/unit/test_gmail.py @@ -19,8 +19,10 @@ GmailAccountNotFoundError, GmailCredentialInvalidError, GmailCredentialNotFoundError, + GmailCredentialPersistenceError, GmailCredentialValidationError, GmailMessageUnsupportedError, + RefreshedGmailCredential, build_gmail_message_artifact_relative_path, build_gmail_protected_credential_blob, create_gmail_account_record, @@ -29,6 +31,7 @@ list_gmail_account_records, resolve_gmail_access_token, ) +from alicebot_api.store import ContinuityStoreInvariantError from alicebot_api.workspaces import TaskWorkspaceNotFoundError @@ -460,7 +463,10 @@ def test_resolve_gmail_access_token_renews_expired_refreshable_credential(monkey monkeypatch.setattr( "alicebot_api.gmail.refresh_gmail_access_token", - lambda **_kwargs: ("token-2", refreshed_at), + lambda **_kwargs: RefreshedGmailCredential( + access_token="token-2", + access_token_expires_at=refreshed_at, + ), ) assert resolve_gmail_access_token(store, gmail_account_id=account_id) == "token-2" @@ -478,6 +484,94 @@ def test_resolve_gmail_access_token_renews_expired_refreshable_credential(monkey ] +def test_resolve_gmail_access_token_persists_rotated_refresh_token(monkeypatch) -> None: + store = 
GmailStoreStub() + expired_at = datetime(2020, 1, 1, 0, 0, tzinfo=UTC) + refreshed_at = datetime(2030, 1, 1, 0, 5, tzinfo=UTC) + account = create_gmail_account_record( + store, + user_id=uuid4(), + request=GmailAccountConnectInput( + provider_account_id="acct-001", + email_address="owner@example.com", + display_name="Owner", + scope=GMAIL_READONLY_SCOPE, + access_token="token-1", + refresh_token="refresh-1", + client_id="client-1", + client_secret="secret-1", + access_token_expires_at=expired_at, + ), + )["account"] + account_id = UUID(account["id"]) + + monkeypatch.setattr( + "alicebot_api.gmail.refresh_gmail_access_token", + lambda **_kwargs: RefreshedGmailCredential( + access_token="token-2", + access_token_expires_at=refreshed_at, + refresh_token="refresh-2", + ), + ) + + assert resolve_gmail_access_token(store, gmail_account_id=account_id) == "token-2" + assert store.gmail_account_credentials[account_id]["credential_blob"] == { + "credential_kind": GMAIL_REFRESHABLE_PROTECTED_CREDENTIAL_KIND, + "access_token": "token-2", + "refresh_token": "refresh-2", + "client_id": "client-1", + "client_secret": "secret-1", + "access_token_expires_at": refreshed_at.isoformat(), + } + + +def test_resolve_gmail_access_token_fails_deterministically_when_persisting_refresh_update_fails( + monkeypatch, +) -> None: + store = GmailStoreStub() + expired_at = datetime(2020, 1, 1, 0, 0, tzinfo=UTC) + refreshed_at = datetime(2030, 1, 1, 0, 5, tzinfo=UTC) + account = create_gmail_account_record( + store, + user_id=uuid4(), + request=GmailAccountConnectInput( + provider_account_id="acct-001", + email_address="owner@example.com", + display_name="Owner", + scope=GMAIL_READONLY_SCOPE, + access_token="token-1", + refresh_token="refresh-1", + client_id="client-1", + client_secret="secret-1", + access_token_expires_at=expired_at, + ), + )["account"] + account_id = UUID(account["id"]) + original_blob = dict(store.gmail_account_credentials[account_id]["credential_blob"]) + + monkeypatch.setattr( + "alicebot_api.gmail.refresh_gmail_access_token", + lambda **_kwargs: RefreshedGmailCredential( + access_token="token-2", + access_token_expires_at=refreshed_at, + refresh_token="refresh-2", + ), + ) + + def fail_update(**_kwargs): + raise ContinuityStoreInvariantError("update_gmail_account_credential did not return a row") + + monkeypatch.setattr(store, "update_gmail_account_credential", fail_update) + + with pytest.raises( + GmailCredentialPersistenceError, + match=f"gmail account {account_id} renewed protected credentials could not be persisted", + ): + resolve_gmail_access_token(store, gmail_account_id=account_id) + + assert store.gmail_account_credentials[account_id]["credential_blob"] == original_blob + + def test_resolve_gmail_access_token_rejects_invalid_refreshable_protected_credentials() -> None: store = GmailStoreStub() account = create_gmail_account_record( @@ -662,7 +756,10 @@ def test_ingest_gmail_message_record_renews_expired_access_token_before_fetch( monkeypatch.setattr( "alicebot_api.gmail.refresh_gmail_access_token", - lambda **_kwargs: ("token-refreshed", datetime(2030, 1, 1, 0, 5, tzinfo=UTC)), + lambda **_kwargs: RefreshedGmailCredential( + access_token="token-refreshed", + access_token_expires_at=datetime(2030, 1, 1, 0, 5, tzinfo=UTC), + ), ) def fake_fetch(**kwargs): diff --git a/tests/unit/test_gmail_main.py b/tests/unit/test_gmail_main.py index 24adb3a..2eaf011 100644 --- a/tests/unit/test_gmail_main.py +++ b/tests/unit/test_gmail_main.py @@ -12,6 +12,7 @@ GmailAccountNotFoundError, 
GmailCredentialInvalidError, GmailCredentialNotFoundError, + GmailCredentialPersistenceError, GmailCredentialRefreshError, GmailCredentialValidationError, GmailMessageFetchError, @@ -266,6 +267,25 @@ def fake_invalid_credentials(*_args, **_kwargs): "detail": f"gmail account {gmail_account_id} has invalid protected credentials" } + def fake_persistence_error(*_args, **_kwargs): + raise GmailCredentialPersistenceError( + f"gmail account {gmail_account_id} renewed protected credentials could not be persisted" + ) + + monkeypatch.setattr(main_module, "ingest_gmail_message_record", fake_persistence_error) + response = main_module.ingest_gmail_message( + gmail_account_id, + "msg-001", + main_module.IngestGmailMessageRequest( + user_id=user_id, + task_workspace_id=task_workspace_id, + ), + ) + assert response.status_code == 409 + assert json.loads(response.body) == { + "detail": f"gmail account {gmail_account_id} renewed protected credentials could not be persisted" + } + def fake_fetch_error(*_args, **_kwargs): raise GmailMessageFetchError("gmail message msg-001 could not be fetched") diff --git a/tests/unit/test_gmail_refresh.py b/tests/unit/test_gmail_refresh.py index 779e454..da7c9d7 100644 --- a/tests/unit/test_gmail_refresh.py +++ b/tests/unit/test_gmail_refresh.py @@ -14,6 +14,7 @@ GMAIL_TOKEN_REFRESH_URL, GmailCredentialInvalidError, GmailCredentialRefreshError, + RefreshedGmailCredential, refresh_gmail_access_token, ) @@ -59,7 +60,7 @@ def fake_urlopen(request, timeout: int): monkeypatch.setattr("alicebot_api.gmail.urlopen", fake_urlopen) started_at = datetime.now(UTC) - access_token, expires_at = refresh_gmail_access_token( + refreshed_credential = refresh_gmail_access_token( gmail_account_id=gmail_account_id, refresh_token="refresh-001", client_id="client-001", @@ -67,8 +68,14 @@ def fake_urlopen(request, timeout: int): ) finished_at = datetime.now(UTC) - assert access_token == "token-refreshed" - assert started_at + timedelta(seconds=3590) <= expires_at <= finished_at + timedelta(seconds=3610) + assert refreshed_credential == RefreshedGmailCredential( + access_token="token-refreshed", + access_token_expires_at=refreshed_credential.access_token_expires_at, + refresh_token=None, + ) + assert started_at + timedelta(seconds=3590) <= refreshed_credential.access_token_expires_at <= ( + finished_at + timedelta(seconds=3610) + ) assert seen == { "url": GMAIL_TOKEN_REFRESH_URL, "timeout": GMAIL_TOKEN_REFRESH_TIMEOUT_SECONDS, @@ -83,6 +90,35 @@ def fake_urlopen(request, timeout: int): } +def test_refresh_gmail_access_token_returns_rotated_refresh_token_when_provider_supplies_one( + monkeypatch, +) -> None: + gmail_account_id = uuid4() + + def fake_urlopen(_request, timeout: int): + assert timeout == GMAIL_TOKEN_REFRESH_TIMEOUT_SECONDS + return _FakeHTTPResponse( + json.dumps( + { + "access_token": "token-refreshed", + "expires_in": 3600, + "refresh_token": "refresh-rotated", + } + ).encode("utf-8") + ) + + monkeypatch.setattr("alicebot_api.gmail.urlopen", fake_urlopen) + + refreshed_credential = refresh_gmail_access_token( + gmail_account_id=gmail_account_id, + refresh_token="refresh-001", + client_id="client-001", + client_secret="secret-001", + ) + + assert refreshed_credential.refresh_token == "refresh-rotated" + + @pytest.mark.parametrize("status_code", [400, 401]) def test_refresh_gmail_access_token_maps_invalid_refresh_rejections_to_invalid_error( monkeypatch, From e28aa92ba336ef8e0012f48fc3e92f93463f12d0 Mon Sep 17 00:00:00 2001 From: Sam Date: Mon, 16 Mar 2026 22:55:50 +0100 Subject: 
[PATCH 019/135] Sprint 5S: project truth synchronization after Gmail auth hardening (#19) Co-authored-by: Sami Rusani --- ARCHITECTURE.md | 33 +++++++++++++++++++++------------ ROADMAP.md | 26 ++++++++++++++------------ 2 files changed, 35 insertions(+), 24 deletions(-) diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md index 3e7ad09..c9b66d1 100644 --- a/ARCHITECTURE.md +++ b/ARCHITECTURE.md @@ -2,17 +2,17 @@ ## Current Implemented Slice -AliceBot now implements the accepted repo slice through Sprint 5Q. The shipped backend includes: +AliceBot now implements the accepted repo slice through Sprint 5R. The shipped backend includes: - foundation continuity storage over `users`, `threads`, `sessions`, and append-only `events` - deterministic tracing and context compilation over durable continuity, memory, entity, and entity-edge records - governed memory admission, explicit-preference extraction, memory review labels, review queue reads, evaluation summary reads, explicit embedding config and memory-embedding storage, direct semantic retrieval, and deterministic hybrid compile-path memory merge - deterministic prompt assembly and one no-tools response path that persists assistant replies as immutable continuity events - user-scoped consents, policies, policy evaluation, tool registry, allowlist evaluation, tool routing, approval request persistence, approval resolution, approved-only proxy execution through the in-process `proxy.echo` handler, durable execution review, and execution-budget lifecycle plus enforcement -- a narrow read-only Gmail connector seam with user-scoped `gmail_accounts` metadata persistence, separate user-scoped `gmail_account_credentials` protected credential storage, deterministic account reads without secret exposure, refresh-token-capable protected credential renewal for expired access tokens, and one explicit selected-message ingestion path that materializes one Gmail message as a rooted `.eml` task artifact and then reuses the existing RFC822 artifact ingestion pipeline +- a narrow read-only Gmail connector seam with user-scoped `gmail_accounts` metadata persistence, separate user-scoped `gmail_account_credentials` protected credential storage, deterministic account reads without secret exposure, refresh-token-capable protected credential renewal for expired access tokens, rotated refresh-token persistence when the provider returns a replacement token, and one explicit selected-message ingestion path that materializes one Gmail message as a rooted `.eml` task artifact and then reuses the existing RFC822 artifact ingestion pipeline - durable `tasks`, `task_steps`, `task_workspaces`, `task_artifacts`, `task_artifact_chunks`, and `task_artifact_chunk_embeddings`, deterministic task-step sequencing, explicit task-step transitions, explicit manual continuation with lineage through `parent_step_id`, `source_approval_id`, and `source_execution_id`, explicit `tool_executions.task_step_id` linkage for execution synchronization, deterministic rooted local task-workspace provisioning, explicit rooted local artifact registration, deterministic local plain-text, markdown, narrow PDF text, narrow DOCX text, and narrow RFC822 email text ingestion into durable chunk rows, deterministic lexical artifact-chunk retrieval over durable chunk rows, explicit user-scoped artifact-chunk embedding persistence tied to existing embedding configs, explicit task-scoped or artifact-scoped semantic artifact-chunk retrieval over those durable embeddings, and compile-path artifact retrieval that can 
include lexical results, semantic results, or one deterministic hybrid lexical-plus-semantic merged artifact section with per-chunk source provenance -The current multi-step boundary is narrow and explicit. Manual continuation is implemented and review-passed. Approval resolution and proxy execution now both use explicit task-step linkage rather than first-step inference. Task workspaces are now implemented only as deterministic rooted local boundaries, and task artifacts are now implemented only as explicit rooted local-file registrations, narrow deterministic artifact ingestion under those workspaces, lexical retrieval over persisted chunk rows, explicit artifact-chunk embedding storage tied to existing embedding configs, direct semantic retrieval over those durable artifact-chunk embeddings for one visible task or one visible artifact at a time, and compile-path artifact retrieval that deterministically merges lexical and semantic candidates into one artifact section when both are requested for the same scope. The live richer-document boundary is still intentionally narrow: plain text and markdown ingest directly, PDF support is limited to narrow local text extraction, DOCX support is limited to narrow local text extraction from `word/document.xml`, RFC822 email support is limited to top-level selected headers plus extractable plain-text body content while excluding nested `message/rfc822` content, and the live connector boundary is limited to one read-only Gmail account seam plus one explicit selected-message ingestion path into the rooted RFC822 artifact pipeline, with credentials resolved through a dedicated protected credential seam, renewed through one explicit refresh path when an expired refresh-capable credential is present, and never exposed on the normal account metadata table surface. OCR, image extraction, layout reconstruction, Gmail search, mailbox sync, attachments, Calendar connectors, reranking beyond the current lexical-first hybrid merge, and new side-effect surfaces are still planned later and must not be described as live behavior. +The current multi-step boundary is narrow and explicit. Manual continuation is implemented and review-passed. Approval resolution and proxy execution now both use explicit task-step linkage rather than first-step inference. Task workspaces are now implemented only as deterministic rooted local boundaries, and task artifacts are now implemented only as explicit rooted local-file registrations, narrow deterministic artifact ingestion under those workspaces, lexical retrieval over persisted chunk rows, explicit artifact-chunk embedding storage tied to existing embedding configs, direct semantic retrieval over those durable artifact-chunk embeddings for one visible task or one visible artifact at a time, and compile-path artifact retrieval that deterministically merges lexical and semantic candidates into one artifact section when both are requested for the same scope. 
The live richer-document boundary is still intentionally narrow: plain text and markdown ingest directly, PDF support is limited to narrow local text extraction, DOCX support is limited to narrow local text extraction from `word/document.xml`, RFC822 email support is limited to top-level selected headers plus extractable plain-text body content while excluding nested `message/rfc822` content, and the live connector boundary is limited to one read-only Gmail account seam plus one explicit selected-message ingestion path into the rooted RFC822 artifact pipeline, with credentials resolved through a dedicated protected credential seam, renewed through one explicit refresh path when an expired refresh-capable credential is present, any provider-returned rotated refresh token persisted back through that same protected credential seam before Gmail fetches continue, and never exposed on the normal account metadata table surface. OCR, image extraction, layout reconstruction, Gmail search, mailbox sync, attachments, Calendar connectors, reranking beyond the current lexical-first hybrid merge, and new side-effect surfaces are still planned later and must not be described as live behavior. ## Implemented Now @@ -60,11 +60,11 @@ The current multi-step boundary is narrow and explicit. Manual continuation is i ### Repo Boundaries In This Slice -- `apps/api`: implemented API, store, contracts, service logic, and migrations for continuity, tracing, memory, embeddings, entities, policies, tools, approvals, proxy execution, execution budgets, a narrow read-only Gmail connector seam with protected refresh-token lifecycle support, tasks, task steps, task workspaces, task artifacts, artifact-chunk embeddings, deterministic lexical artifact chunk retrieval, deterministic semantic artifact chunk retrieval over durable embeddings, compile-path semantic artifact retrieval, and deterministic hybrid lexical-plus-semantic artifact merge in compile. +- `apps/api`: implemented API, store, contracts, service logic, and migrations for continuity, tracing, memory, embeddings, entities, policies, tools, approvals, proxy execution, execution budgets, a narrow read-only Gmail connector seam with protected refresh-token lifecycle and refresh-token rotation handling support, tasks, task steps, task workspaces, task artifacts, artifact-chunk embeddings, deterministic lexical artifact chunk retrieval, deterministic semantic artifact chunk retrieval over durable embeddings, compile-path semantic artifact retrieval, and deterministic hybrid lexical-plus-semantic artifact merge in compile. - `apps/web`: minimal shell only; no shipped workflow UI. - `workers`: scaffold only; no background jobs or runner logic are implemented. - `infra`: local development bootstrap assets only. 
-- `tests`: unit and Postgres-backed integration coverage for the shipped seams above, including Sprint 4O task-step lineage/manual continuation, Sprint 4S step-linked execution synchronization, Sprint 5A task-workspace provisioning, Sprint 5C task-artifact registration, Sprint 5D local artifact ingestion plus chunk reads, Sprint 5E lexical artifact-chunk retrieval, Sprint 5F compile-path artifact chunk integration, Sprint 5G artifact-chunk embedding persistence and reads, Sprint 5H direct semantic artifact-chunk retrieval, Sprint 5I compile-path semantic artifact retrieval, Sprint 5J deterministic hybrid lexical-plus-semantic artifact merge in compile, Sprint 5L narrow PDF artifact ingestion, Sprint 5M narrow DOCX artifact ingestion, Sprint 5N narrow RFC822 email artifact ingestion, Sprint 5O read-only Gmail account plus single-message ingestion coverage, Sprint 5P Gmail credential hardening coverage, and Sprint 5Q Gmail refresh-token lifecycle coverage. +- `tests`: unit and Postgres-backed integration coverage for the shipped seams above, including Sprint 4O task-step lineage/manual continuation, Sprint 4S step-linked execution synchronization, Sprint 5A task-workspace provisioning, Sprint 5C task-artifact registration, Sprint 5D local artifact ingestion plus chunk reads, Sprint 5E lexical artifact-chunk retrieval, Sprint 5F compile-path artifact chunk integration, Sprint 5G artifact-chunk embedding persistence and reads, Sprint 5H direct semantic artifact-chunk retrieval, Sprint 5I compile-path semantic artifact retrieval, Sprint 5J deterministic hybrid lexical-plus-semantic artifact merge in compile, Sprint 5L narrow PDF artifact ingestion, Sprint 5M narrow DOCX artifact ingestion, Sprint 5N narrow RFC822 email artifact ingestion, Sprint 5O read-only Gmail account plus single-message ingestion coverage, Sprint 5P Gmail credential hardening coverage, Sprint 5Q Gmail refresh-token lifecycle coverage, and Sprint 5R Gmail refresh-token rotation handling coverage. ## Core Flows Implemented Now @@ -100,12 +100,13 @@ The current multi-step boundary is narrow and explicit. Manual continuation is i 4. Expose deterministic user-scoped Gmail account list and detail reads without secret material. 5. Accept a user-scoped `POST /v0/gmail-accounts/{gmail_account_id}/messages/{provider_message_id}/ingest` request for one visible Gmail account and one visible task workspace. 6. Resolve the Gmail access token through the protected credential seam before any Gmail fetch, file write, or artifact registration, and renew it first through one explicit refresh path when the visible protected credential is refresh-capable and expired. -7. Derive one deterministic workspace-relative artifact path as `gmail//.eml`. -8. Reject duplicate `(task_workspace_id, relative_path)` collisions before any Gmail fetch or file write. -9. Fetch exactly one selected Gmail message through the read-only Gmail API path `users/me/messages/{provider_message_id}?format=raw`. -10. Require Gmail to return RFC822 `raw` content, validate it against the existing narrow `message/rfc822` extraction rules, and reject unsupported content deterministically. -11. Materialize the message as one rooted `.eml` file inside the selected task workspace and then reuse the existing task-artifact registration plus artifact-ingestion seam. -12. Persist only the resulting `task_artifacts` and `task_artifact_chunks` rows; account-wide sync, search, attachments, Calendar, and write-capable actions remain out of scope. +7. 
When the refresh provider returns a replacement refresh token, persist that rotated token back through the same protected credential seam before Gmail fetches continue. +8. Derive one deterministic workspace-relative artifact path as `gmail//.eml`. +9. Reject duplicate `(task_workspace_id, relative_path)` collisions before any Gmail fetch or file write. +10. Fetch exactly one selected Gmail message through the read-only Gmail API path `users/me/messages/{provider_message_id}?format=raw`. +11. Require Gmail to return RFC822 `raw` content, validate it against the existing narrow `message/rfc822` extraction rules, and reject unsupported content deterministically. +12. Materialize the message as one rooted `.eml` file inside the selected task workspace and then reuse the existing task-artifact registration plus artifact-ingestion seam. +13. Persist only the resulting `task_artifacts` and `task_artifact_chunks` rows; account-wide sync, search, attachments, Calendar, and write-capable actions remain out of scope. ### Governed Memory And Retrieval @@ -263,7 +264,7 @@ The current multi-step boundary is narrow and explicit. Manual continuation is i ## Testing Coverage Implemented Now - Unit and integration tests cover continuity, compiler, response generation, memory admission, review labels, review queue, embeddings, semantic retrieval, compile-path hybrid memory retrieval, artifact lexical retrieval, artifact semantic retrieval, compile-path semantic artifact retrieval, hybrid artifact compile merge, entities, policies, tools, approvals, proxy execution, execution budgets, and execution review. -- Sprints 4O through 5J added explicit task lifecycle and artifact retrieval coverage: +- Sprints 4O through 5R added explicit task lifecycle, artifact retrieval, richer-document ingestion, and narrow Gmail coverage: - migrations for `tasks`, `task_steps`, and task-step lineage - staged/backfilled migration coverage for `tool_executions.task_step_id` - task and task-step store contracts @@ -289,6 +290,14 @@ The current multi-step boundary is narrow and explicit. Manual continuation is i - direct semantic artifact-chunk retrieval by task and by artifact - compile-path semantic artifact retrieval including trace visibility, exclusion rules, and scope isolation - deterministic hybrid artifact compile merge with dual-source provenance, deduplication, lexical-first precedence, and shared limit enforcement + - narrow PDF ingestion success and failure paths + - narrow DOCX ingestion success and failure paths + - narrow RFC822 ingestion success and failure paths + - read-only Gmail account connect/list/detail coverage with secret-free reads + - selected Gmail message ingestion through the rooted RFC822 artifact path + - protected Gmail credential storage isolation in `gmail_account_credentials` + - refresh-token renewal for expired refresh-capable Gmail credentials + - rotated refresh-token persistence when the provider returns a replacement token - task-artifact and task-artifact-chunk per-user isolation - trace visibility for continuation and transition events - user isolation for task and task-step reads and mutations diff --git a/ROADMAP.md b/ROADMAP.md index a3d9e21..936fca2 100644 --- a/ROADMAP.md +++ b/ROADMAP.md @@ -2,27 +2,29 @@ ## Current Position -- The accepted repo state is current through Sprint 5J. 
-- Milestone 5 now ships the rooted local workspace and artifact baseline end to end: workspace provisioning, artifact registration, narrow text ingestion, durable chunk storage, lexical artifact retrieval, compile-path artifact inclusion, artifact-chunk embeddings, direct semantic artifact retrieval, compile-path semantic artifact retrieval, and deterministic hybrid lexical-plus-semantic artifact merge in compile. +- The accepted repo state is current through Sprint 5R. +- Milestone 5 now ships the rooted local workspace and artifact baseline end to end: workspace provisioning, artifact registration, narrow text ingestion, narrow PDF/DOCX/RFC822 ingestion, durable chunk storage, lexical artifact retrieval, compile-path artifact inclusion, artifact-chunk embeddings, direct semantic artifact retrieval, compile-path semantic artifact retrieval, and deterministic hybrid lexical-plus-semantic artifact merge in compile. +- The same milestone also now ships the narrow Gmail seam: read-only Gmail account persistence, secret-free account reads, protected credential storage in `gmail_account_credentials`, refresh-token renewal for expired access tokens, rotated refresh-token persistence when the provider returns a replacement token, and one explicit selected-message ingestion path that lands in the existing RFC822 artifact pipeline. - This roadmap is future-facing from that shipped baseline; historical sprint-by-sprint detail lives in accepted build and review artifacts, not here. ## Next Delivery Focus -### Open Richer Document Parsing On Top Of The Shipped Artifact Retrieval Baseline +### Open One More Narrow Gmail Auth Seam On Top Of The Shipped Baseline -- Extend ingestion beyond the current `text/plain` and `text/markdown` seam without changing the rooted `task_workspaces` and durable `task_artifact_chunks` contracts. -- Keep retrieval building on persisted chunk rows and persisted embeddings; new parsing work should feed the existing compile-path lexical/semantic/hybrid artifact retrieval seam rather than inventing a parallel context path. -- Keep the next sprint narrow: richer document parsing first, then reassess connectors only after the parsing seam is accepted. +- Keep the next sprint auth-adjacent and narrow, building on the shipped protected-credential-backed Gmail seam rather than widening connector breadth. +- The next best seam is external secret-manager integration for the existing `gmail_account_credentials` boundary, without changing the read-only account contract or the single-message ingestion contract. +- Do not combine that work with Gmail search, mailbox sync, attachment ingestion, Calendar scope, UI work, or broader connector orchestration. -### Preserve Current Compile, Governance, And Task Guarantees +### Preserve Current Document, Compile, Governance, And Task Guarantees +- Keep the shipped PDF, DOCX, and RFC822 ingestion seams narrow and deterministic; richer parsing, OCR, layout reconstruction, attachment handling, and broader email processing still need separate later seams. - Keep approvals, execution budgets, task/task-step state, and trace visibility deterministic as Milestone 5 continues. - Preserve the shipped compile contract of one merged artifact section with explicit source provenance, deterministic lexical-first precedence, and trace-visible inclusion and exclusion decisions. -- Do not widen the current no-external-I/O proxy surface or introduce runner, connector, or UI scope until those areas are explicitly opened. 
+- Do not widen the current no-external-I/O proxy surface or introduce broader connector, runner, or UI scope until those areas are explicitly opened. ## After The Next Narrow Sprint -- Open read-only connector work only after richer document parsing remains deterministic under the current artifact and governance seams. +- Reassess broader connector work only after the current Gmail protected-credential boundary remains stable under externalized secret storage and the truth artifacts stay synchronized. - Revisit workflow UI only after backend document and connector seams are accepted and the truth artifacts stay current. - Revisit broader task orchestration only after the current explicit task-step seams remain stable under workspace, artifact, document, and connector flows. - Continue to defer broader tool execution breadth and production auth/deployment hardening until the current governed surface remains stable. @@ -30,12 +32,12 @@ ## Dependencies - Live truth docs must stay synchronized with accepted repo state so sprint planning does not start from stale assumptions. -- Rich document parsing should build on the shipped rooted local workspace, durable artifact chunk, and hybrid compile retrieval contracts. -- Connector work should remain read-only, approval-aware, and downstream of the document parsing seam. +- Rich document parsing work should continue to build on the shipped rooted local workspace, durable artifact chunk, and hybrid compile retrieval contracts. +- Connector work should remain read-only, single-message-only, approval-aware, and protected-credential-backed until a later sprint explicitly opens broader scope. - Runner-style orchestration should stay deferred until the repo no longer depends on narrow current-step assumptions for safety and explainability. ## Ongoing Risks - Memory extraction and retrieval quality remain the largest product risk. - Auth beyond database user context is still missing. -- Milestone 5 can drift if richer document parsing, connectors, UI, and orchestration work are mixed into one sprint instead of landing as narrow seams on top of the shipped artifact retrieval baseline. +- Milestone 5 can drift if Gmail auth hardening, broader connector breadth, UI, richer parsing, and orchestration work are mixed into one sprint instead of landing as narrow seams on top of the shipped document-ingestion and Gmail baseline. 
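The rotation handling documented in the ARCHITECTURE and ROADMAP updates above reduces to a small ordering rule: renew only when the stored credential is expired, keep the existing refresh token unless the provider returns a replacement, and persist that replacement before any Gmail fetch or file write happens. The sketch below illustrates that rule under simplified assumptions; `RefreshedGmailCredential` mirrors the dataclass added in the patch, while `renew_gmail_access_token_for_fetch`, its `CredentialStore` protocol, and the flattened persistence signature are hypothetical stand-ins for the real `resolve_gmail_access_token` internals.

```python
# Illustrative sketch only: simplified shapes, not the shipped module.
from dataclasses import dataclass
from datetime import UTC, datetime
from typing import Callable, Protocol


@dataclass(frozen=True)
class RefreshedGmailCredential:
    # Mirrors the dataclass in the patch: refresh_token stays None unless the
    # provider rotated it during the refresh call.
    access_token: str
    access_token_expires_at: datetime
    refresh_token: str | None = None


class CredentialStore(Protocol):
    # Hypothetical, flattened persistence seam standing in for
    # ContinuityStore.update_gmail_account_credential.
    def update_credential(
        self, *, access_token: str, refresh_token: str, expires_at: datetime
    ) -> None: ...


def renew_gmail_access_token_for_fetch(
    store: CredentialStore,
    credential: dict,
    refresh: Callable[..., RefreshedGmailCredential],
) -> str:
    """Return a usable access token, persisting any rotated refresh token first."""
    if credential["access_token_expires_at"] > datetime.now(UTC):
        # Not expired: no refresh call, no persistence, the fetch proceeds as-is.
        return credential["access_token"]

    refreshed = refresh(
        refresh_token=credential["refresh_token"],
        client_id=credential["client_id"],
        client_secret=credential["client_secret"],
    )
    # Keep the existing refresh token unless the provider supplied a replacement.
    replacement = refreshed.refresh_token or credential["refresh_token"]
    # Persist before any Gmail fetch or file write; if this raises, the caller
    # surfaces a deterministic failure and nothing downstream happens.
    store.update_credential(
        access_token=refreshed.access_token,
        refresh_token=replacement,
        expires_at=refreshed.access_token_expires_at,
    )
    return refreshed.access_token
```

On a persistence failure the real endpoint maps the error to a deterministic 409 response and skips the Gmail fetch entirely, which is exactly what the integration test earlier in this series asserts.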
From e13dca3f68206d560b30340f6828768c279a67f6 Mon Sep 17 00:00:00 2001 From: Sam Date: Mon, 16 Mar 2026 23:57:04 +0100 Subject: [PATCH 020/135] Sprint 5T: externalize Gmail credential storage --- ARCHITECTURE.md | 17 +- ...0316_0029_gmail_external_secret_manager.py | 126 +++++++++ apps/api/src/alicebot_api/config.py | 7 + apps/api/src/alicebot_api/gmail.py | 263 ++++++++++++++++-- .../src/alicebot_api/gmail_secret_manager.py | 97 +++++++ apps/api/src/alicebot_api/main.py | 7 + apps/api/src/alicebot_api/store.py | 51 +++- tests/integration/test_gmail_accounts_api.py | 214 +++++++++++--- tests/integration/test_migrations.py | 110 ++++++++ ...0316_0029_gmail_external_secret_manager.py | 39 +++ tests/unit/test_config.py | 6 + tests/unit/test_gmail.py | 199 +++++++++++-- tests/unit/test_gmail_main.py | 50 +++- tests/unit/test_gmail_secret_manager.py | 44 +++ 14 files changed, 1125 insertions(+), 105 deletions(-) create mode 100644 apps/api/alembic/versions/20260316_0029_gmail_external_secret_manager.py create mode 100644 apps/api/src/alicebot_api/gmail_secret_manager.py create mode 100644 tests/unit/test_20260316_0029_gmail_external_secret_manager.py create mode 100644 tests/unit/test_gmail_secret_manager.py diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md index c9b66d1..1eacd8c 100644 --- a/ARCHITECTURE.md +++ b/ARCHITECTURE.md @@ -2,17 +2,17 @@ ## Current Implemented Slice -AliceBot now implements the accepted repo slice through Sprint 5R. The shipped backend includes: +AliceBot now implements the accepted repo slice through Sprint 5T. The shipped backend includes: - foundation continuity storage over `users`, `threads`, `sessions`, and append-only `events` - deterministic tracing and context compilation over durable continuity, memory, entity, and entity-edge records - governed memory admission, explicit-preference extraction, memory review labels, review queue reads, evaluation summary reads, explicit embedding config and memory-embedding storage, direct semantic retrieval, and deterministic hybrid compile-path memory merge - deterministic prompt assembly and one no-tools response path that persists assistant replies as immutable continuity events - user-scoped consents, policies, policy evaluation, tool registry, allowlist evaluation, tool routing, approval request persistence, approval resolution, approved-only proxy execution through the in-process `proxy.echo` handler, durable execution review, and execution-budget lifecycle plus enforcement -- a narrow read-only Gmail connector seam with user-scoped `gmail_accounts` metadata persistence, separate user-scoped `gmail_account_credentials` protected credential storage, deterministic account reads without secret exposure, refresh-token-capable protected credential renewal for expired access tokens, rotated refresh-token persistence when the provider returns a replacement token, and one explicit selected-message ingestion path that materializes one Gmail message as a rooted `.eml` task artifact and then reuses the existing RFC822 artifact ingestion pipeline +- a narrow read-only Gmail connector seam with user-scoped `gmail_accounts` metadata persistence, separate user-scoped `gmail_account_credentials` locator metadata for the primary credential path, one explicit Gmail secret-manager adapter seam for secret reads and writes, deterministic account reads without secret exposure, refresh-token-capable credential renewal for expired access tokens, rotated refresh-token persistence when the provider returns a replacement token, an explicit 
`legacy_db_v0` transition path externalized on first credential read for older rows, and one explicit selected-message ingestion path that materializes one Gmail message as a rooted `.eml` task artifact and then reuses the existing RFC822 artifact ingestion pipeline - durable `tasks`, `task_steps`, `task_workspaces`, `task_artifacts`, `task_artifact_chunks`, and `task_artifact_chunk_embeddings`, deterministic task-step sequencing, explicit task-step transitions, explicit manual continuation with lineage through `parent_step_id`, `source_approval_id`, and `source_execution_id`, explicit `tool_executions.task_step_id` linkage for execution synchronization, deterministic rooted local task-workspace provisioning, explicit rooted local artifact registration, deterministic local plain-text, markdown, narrow PDF text, narrow DOCX text, and narrow RFC822 email text ingestion into durable chunk rows, deterministic lexical artifact-chunk retrieval over durable chunk rows, explicit user-scoped artifact-chunk embedding persistence tied to existing embedding configs, explicit task-scoped or artifact-scoped semantic artifact-chunk retrieval over those durable embeddings, and compile-path artifact retrieval that can include lexical results, semantic results, or one deterministic hybrid lexical-plus-semantic merged artifact section with per-chunk source provenance -The current multi-step boundary is narrow and explicit. Manual continuation is implemented and review-passed. Approval resolution and proxy execution now both use explicit task-step linkage rather than first-step inference. Task workspaces are now implemented only as deterministic rooted local boundaries, and task artifacts are now implemented only as explicit rooted local-file registrations, narrow deterministic artifact ingestion under those workspaces, lexical retrieval over persisted chunk rows, explicit artifact-chunk embedding storage tied to existing embedding configs, direct semantic retrieval over those durable artifact-chunk embeddings for one visible task or one visible artifact at a time, and compile-path artifact retrieval that deterministically merges lexical and semantic candidates into one artifact section when both are requested for the same scope. The live richer-document boundary is still intentionally narrow: plain text and markdown ingest directly, PDF support is limited to narrow local text extraction, DOCX support is limited to narrow local text extraction from `word/document.xml`, RFC822 email support is limited to top-level selected headers plus extractable plain-text body content while excluding nested `message/rfc822` content, and the live connector boundary is limited to one read-only Gmail account seam plus one explicit selected-message ingestion path into the rooted RFC822 artifact pipeline, with credentials resolved through a dedicated protected credential seam, renewed through one explicit refresh path when an expired refresh-capable credential is present, any provider-returned rotated refresh token persisted back through that same protected credential seam before Gmail fetches continue, and never exposed on the normal account metadata table surface. OCR, image extraction, layout reconstruction, Gmail search, mailbox sync, attachments, Calendar connectors, reranking beyond the current lexical-first hybrid merge, and new side-effect surfaces are still planned later and must not be described as live behavior. +The current multi-step boundary is narrow and explicit. 
Manual continuation is implemented and review-passed. Approval resolution and proxy execution now both use explicit task-step linkage rather than first-step inference. Task workspaces are now implemented only as deterministic rooted local boundaries, and task artifacts are now implemented only as explicit rooted local-file registrations, narrow deterministic artifact ingestion under those workspaces, lexical retrieval over persisted chunk rows, explicit artifact-chunk embedding storage tied to existing embedding configs, direct semantic retrieval over those durable artifact-chunk embeddings for one visible task or one visible artifact at a time, and compile-path artifact retrieval that deterministically merges lexical and semantic candidates into one artifact section when both are requested for the same scope. The live richer-document boundary is still intentionally narrow: plain text and markdown ingest directly, PDF support is limited to narrow local text extraction, DOCX support is limited to narrow local text extraction from `word/document.xml`, RFC822 email support is limited to top-level selected headers plus extractable plain-text body content while excluding nested `message/rfc822` content, and the live connector boundary is limited to one read-only Gmail account seam plus one explicit selected-message ingestion path into the rooted RFC822 artifact pipeline, with the primary Gmail credential path now storing only locator metadata on `gmail_account_credentials`, resolving secrets through one explicit Gmail secret-manager adapter seam before fetches continue, renewing expired refresh-capable credentials through that same seam, persisting any provider-returned rotated refresh token back through that same seam, keeping a narrow `legacy_db_v0` first-read externalization path for older rows only, and never exposing secret material on the normal account metadata table surface. OCR, image extraction, layout reconstruction, Gmail search, mailbox sync, attachments, Calendar connectors, reranking beyond the current lexical-first hybrid merge, and new side-effect surfaces are still planned later and must not be described as live behavior. ## Implemented Now @@ -64,7 +64,7 @@ The current multi-step boundary is narrow and explicit. Manual continuation is i - `apps/web`: minimal shell only; no shipped workflow UI. - `workers`: scaffold only; no background jobs or runner logic are implemented. - `infra`: local development bootstrap assets only. -- `tests`: unit and Postgres-backed integration coverage for the shipped seams above, including Sprint 4O task-step lineage/manual continuation, Sprint 4S step-linked execution synchronization, Sprint 5A task-workspace provisioning, Sprint 5C task-artifact registration, Sprint 5D local artifact ingestion plus chunk reads, Sprint 5E lexical artifact-chunk retrieval, Sprint 5F compile-path artifact chunk integration, Sprint 5G artifact-chunk embedding persistence and reads, Sprint 5H direct semantic artifact-chunk retrieval, Sprint 5I compile-path semantic artifact retrieval, Sprint 5J deterministic hybrid lexical-plus-semantic artifact merge in compile, Sprint 5L narrow PDF artifact ingestion, Sprint 5M narrow DOCX artifact ingestion, Sprint 5N narrow RFC822 email artifact ingestion, Sprint 5O read-only Gmail account plus single-message ingestion coverage, Sprint 5P Gmail credential hardening coverage, Sprint 5Q Gmail refresh-token lifecycle coverage, and Sprint 5R Gmail refresh-token rotation handling coverage. 
+- `tests`: unit and Postgres-backed integration coverage for the shipped seams above, including Sprint 4O task-step lineage/manual continuation, Sprint 4S step-linked execution synchronization, Sprint 5A task-workspace provisioning, Sprint 5C task-artifact registration, Sprint 5D local artifact ingestion plus chunk reads, Sprint 5E lexical artifact-chunk retrieval, Sprint 5F compile-path artifact chunk integration, Sprint 5G artifact-chunk embedding persistence and reads, Sprint 5H direct semantic artifact-chunk retrieval, Sprint 5I compile-path semantic artifact retrieval, Sprint 5J deterministic hybrid lexical-plus-semantic artifact merge in compile, Sprint 5L narrow PDF artifact ingestion, Sprint 5M narrow DOCX artifact ingestion, Sprint 5N narrow RFC822 email artifact ingestion, Sprint 5O read-only Gmail account plus single-message ingestion coverage, Sprint 5P Gmail credential hardening coverage, Sprint 5Q Gmail refresh-token lifecycle coverage, Sprint 5R Gmail refresh-token rotation handling coverage, and Sprint 5T Gmail external secret-manager coverage. ## Core Flows Implemented Now @@ -96,11 +96,11 @@ The current multi-step boundary is narrow and explicit. Manual continuation is i 1. Accept a user-scoped `POST /v0/gmail-accounts` request for one read-only Gmail account metadata record. 2. Persist exactly the narrow connector metadata required for later reads on `gmail_accounts`: `provider_account_id`, `email_address`, optional `display_name`, and the fixed Gmail read-only scope. -3. Persist Gmail secrets only in the dedicated `gmail_account_credentials` protected credential seam bound to the same visible user/account ownership scope, using either a narrow access-token-only shape or a refresh-token-capable credential shape with expiry metadata. +3. Persist only Gmail credential locator metadata on `gmail_account_credentials` for the primary path, and store the secret payload itself through one explicit Gmail secret-manager adapter seam bound to the same visible user/account ownership scope. 4. Expose deterministic user-scoped Gmail account list and detail reads without secret material. 5. Accept a user-scoped `POST /v0/gmail-accounts/{gmail_account_id}/messages/{provider_message_id}/ingest` request for one visible Gmail account and one visible task workspace. -6. Resolve the Gmail access token through the protected credential seam before any Gmail fetch, file write, or artifact registration, and renew it first through one explicit refresh path when the visible protected credential is refresh-capable and expired. -7. When the refresh provider returns a replacement refresh token, persist that rotated token back through the same protected credential seam before Gmail fetches continue. +6. Resolve the Gmail access token through the Gmail secret-manager adapter seam before any Gmail fetch, file write, or artifact registration, and renew it first through one explicit refresh path when the visible refresh-capable credential is expired. +7. When the refresh provider returns a replacement refresh token, persist that rotated token back through the same Gmail secret-manager adapter seam before Gmail fetches continue. 8. Derive one deterministic workspace-relative artifact path as `gmail/<provider_account_id>/<provider_message_id>.eml`. 9. Reject duplicate `(task_workspace_id, relative_path)` collisions before any Gmail fetch or file write. 10. Fetch exactly one selected Gmail message through the read-only Gmail API path `users/me/messages/{provider_message_id}?format=raw`.
@@ -295,9 +295,10 @@ The current multi-step boundary is narrow and explicit. Manual continuation is i - narrow RFC822 ingestion success and failure paths - read-only Gmail account connect/list/detail coverage with secret-free reads - selected Gmail message ingestion through the rooted RFC822 artifact path - - protected Gmail credential storage isolation in `gmail_account_credentials` + - primary Gmail credential locator storage isolation in `gmail_account_credentials` - refresh-token renewal for expired refresh-capable Gmail credentials - rotated refresh-token persistence when the provider returns a replacement token + - external Gmail secret-manager reference persistence and resolution - task-artifact and task-artifact-chunk per-user isolation - trace visibility for continuation and transition events - user isolation for task and task-step reads and mutations diff --git a/apps/api/alembic/versions/20260316_0029_gmail_external_secret_manager.py b/apps/api/alembic/versions/20260316_0029_gmail_external_secret_manager.py new file mode 100644 index 0000000..ff436f0 --- /dev/null +++ b/apps/api/alembic/versions/20260316_0029_gmail_external_secret_manager.py @@ -0,0 +1,126 @@ +"""Add external secret-manager references for Gmail protected credentials.""" + +from __future__ import annotations + +from alembic import op + + +revision = "20260316_0029" +down_revision = "20260316_0028" +branch_labels = None +depends_on = None + +GMAIL_PROTECTED_CREDENTIAL_KIND = "gmail_oauth_access_token_v1" +GMAIL_REFRESHABLE_PROTECTED_CREDENTIAL_KIND = "gmail_oauth_refresh_token_v2" +GMAIL_SECRET_MANAGER_KIND_FILE_V1 = "file_v1" +GMAIL_SECRET_MANAGER_KIND_LEGACY_DB_V0 = "legacy_db_v0" + +_CREDENTIAL_BLOB_SHAPE_CHECK = f""" + ( + jsonb_typeof(credential_blob) = 'object' + AND credential_blob ? 'credential_kind' + AND credential_blob ? 'access_token' + AND jsonb_typeof(credential_blob -> 'access_token') = 'string' + AND length(credential_blob ->> 'access_token') > 0 + AND ( + ( + credential_blob ->> 'credential_kind' = '{GMAIL_PROTECTED_CREDENTIAL_KIND}' + ) + OR + ( + credential_blob ->> 'credential_kind' = '{GMAIL_REFRESHABLE_PROTECTED_CREDENTIAL_KIND}' + AND credential_blob ? 'refresh_token' + AND credential_blob ? 'client_id' + AND credential_blob ? 'client_secret' + AND credential_blob ? 
'access_token_expires_at' + AND jsonb_typeof(credential_blob -> 'refresh_token') = 'string' + AND jsonb_typeof(credential_blob -> 'client_id') = 'string' + AND jsonb_typeof(credential_blob -> 'client_secret') = 'string' + AND jsonb_typeof(credential_blob -> 'access_token_expires_at') = 'string' + AND length(credential_blob ->> 'refresh_token') > 0 + AND length(credential_blob ->> 'client_id') > 0 + AND length(credential_blob ->> 'client_secret') > 0 + AND length(credential_blob ->> 'access_token_expires_at') > 0 + ) + ) + ) +""" + +_UPGRADE_STATEMENTS = ( + "ALTER TABLE gmail_account_credentials DROP CONSTRAINT gmail_account_credentials_blob_shape_check", + "ALTER TABLE gmail_account_credentials ADD COLUMN credential_kind text", + "ALTER TABLE gmail_account_credentials ADD COLUMN secret_manager_kind text", + "ALTER TABLE gmail_account_credentials ADD COLUMN secret_ref text", + "ALTER TABLE gmail_account_credentials ALTER COLUMN credential_blob DROP NOT NULL", + """ + UPDATE gmail_account_credentials + SET credential_kind = credential_blob ->> 'credential_kind', + secret_manager_kind = 'legacy_db_v0' + """, + "ALTER TABLE gmail_account_credentials ALTER COLUMN credential_kind SET NOT NULL", + "ALTER TABLE gmail_account_credentials ALTER COLUMN secret_manager_kind SET NOT NULL", + f""" + ALTER TABLE gmail_account_credentials + ADD CONSTRAINT gmail_account_credentials_storage_shape_check + CHECK ( + credential_kind IN ( + '{GMAIL_PROTECTED_CREDENTIAL_KIND}', + '{GMAIL_REFRESHABLE_PROTECTED_CREDENTIAL_KIND}' + ) + AND ( + ( + secret_manager_kind = '{GMAIL_SECRET_MANAGER_KIND_LEGACY_DB_V0}' + AND secret_ref IS NULL + AND {_CREDENTIAL_BLOB_SHAPE_CHECK} + ) + OR + ( + secret_manager_kind = '{GMAIL_SECRET_MANAGER_KIND_FILE_V1}' + AND secret_ref IS NOT NULL + AND length(secret_ref) > 0 + AND credential_blob IS NULL + ) + ) + ) + """, +) + +_DOWNGRADE_STATEMENTS = ( + """ + DO $$ + BEGIN + IF EXISTS ( + SELECT 1 + FROM gmail_account_credentials + WHERE secret_manager_kind = 'file_v1' + ) THEN + RAISE EXCEPTION + 'cannot downgrade gmail_account_credentials while external Gmail secrets are present'; + END IF; + END + $$; + """, + "ALTER TABLE gmail_account_credentials DROP CONSTRAINT gmail_account_credentials_storage_shape_check", + "ALTER TABLE gmail_account_credentials ALTER COLUMN credential_blob SET NOT NULL", + "ALTER TABLE gmail_account_credentials DROP COLUMN secret_ref", + "ALTER TABLE gmail_account_credentials DROP COLUMN secret_manager_kind", + "ALTER TABLE gmail_account_credentials DROP COLUMN credential_kind", + f""" + ALTER TABLE gmail_account_credentials + ADD CONSTRAINT gmail_account_credentials_blob_shape_check + CHECK {_CREDENTIAL_BLOB_SHAPE_CHECK} + """, +) + + +def _execute_statements(statements: tuple[str, ...]) -> None: + for statement in statements: + op.execute(statement) + + +def upgrade() -> None: + _execute_statements(_UPGRADE_STATEMENTS) + + +def downgrade() -> None: + _execute_statements(_DOWNGRADE_STATEMENTS) diff --git a/apps/api/src/alicebot_api/config.py b/apps/api/src/alicebot_api/config.py index 3f41bd7..49c7425 100644 --- a/apps/api/src/alicebot_api/config.py +++ b/apps/api/src/alicebot_api/config.py @@ -31,6 +31,7 @@ DEFAULT_MODEL_API_KEY = "" DEFAULT_MODEL_TIMEOUT_SECONDS = 30 DEFAULT_TASK_WORKSPACE_ROOT = "/tmp/alicebot/task-workspaces" +DEFAULT_GMAIL_SECRET_MANAGER_URL = "" Environment = Mapping[str, str] @@ -69,6 +70,7 @@ class Settings: model_api_key: str = DEFAULT_MODEL_API_KEY model_timeout_seconds: int = DEFAULT_MODEL_TIMEOUT_SECONDS task_workspace_root: str 
= DEFAULT_TASK_WORKSPACE_ROOT + gmail_secret_manager_url: str = DEFAULT_GMAIL_SECRET_MANAGER_URL @classmethod def from_env(cls, env: Environment | None = None) -> "Settings": @@ -111,6 +113,11 @@ def from_env(cls, env: Environment | None = None) -> "Settings": "TASK_WORKSPACE_ROOT", cls.task_workspace_root, ), + gmail_secret_manager_url=_get_env_value( + current_env, + "GMAIL_SECRET_MANAGER_URL", + cls.gmail_secret_manager_url, + ), ) diff --git a/apps/api/src/alicebot_api/gmail.py b/apps/api/src/alicebot_api/gmail.py index 016c8c3..3d023a9 100644 --- a/apps/api/src/alicebot_api/gmail.py +++ b/apps/api/src/alicebot_api/gmail.py @@ -39,13 +39,19 @@ TaskArtifactIngestInput, TaskArtifactRegisterInput, ) -from alicebot_api.store import ContinuityStore, ContinuityStoreInvariantError, GmailAccountRow +from alicebot_api.gmail_secret_manager import ( + GMAIL_SECRET_MANAGER_KIND_FILE_V1, + GmailSecretManager, + GmailSecretManagerError, +) +from alicebot_api.store import ContinuityStore, ContinuityStoreInvariantError, GmailAccountRow, JsonObject from alicebot_api.workspaces import TaskWorkspaceNotFoundError GMAIL_MESSAGE_FETCH_TIMEOUT_SECONDS = 30 GMAIL_TOKEN_REFRESH_TIMEOUT_SECONDS = 30 GMAIL_TOKEN_REFRESH_URL = "https://oauth2.googleapis.com/token" GMAIL_MESSAGE_ARTIFACT_ROOT = "gmail" +GMAIL_SECRET_MANAGER_KIND_LEGACY_DB_V0 = "legacy_db_v0" _PATH_SEGMENT_PATTERN = re.compile(r"[^A-Za-z0-9._-]+") @@ -106,6 +112,15 @@ class RefreshedGmailCredential: refresh_token: str | None = None +@dataclass(frozen=True, slots=True) +class ResolvedGmailCredential: + parsed_credential: ParsedGmailCredential + credential_kind: str + secret_manager_kind: str + secret_ref: str | None + credential_blob: JsonObject | None + + def serialize_gmail_account_row(row: GmailAccountRow) -> GmailAccountRecord: return { "id": str(row["id"]), @@ -184,6 +199,152 @@ def build_gmail_protected_credential_blob( } +def build_gmail_secret_ref(*, user_id: UUID, gmail_account_id: UUID) -> str: + return f"users/{user_id}/gmail-account-credentials/{gmail_account_id}.json" + + +def _write_external_gmail_secret( + secret_manager: GmailSecretManager, + *, + gmail_account_id: UUID, + secret_ref: str, + credential_blob: JsonObject, +) -> None: + try: + secret_manager.write_secret(secret_ref=secret_ref, payload=credential_blob) + except GmailSecretManagerError as exc: + raise GmailCredentialPersistenceError( + f"gmail account {gmail_account_id} protected credentials could not be persisted" + ) from exc + + +def _load_external_gmail_secret( + secret_manager: GmailSecretManager, + *, + gmail_account_id: UUID, + secret_ref: str, +) -> JsonObject: + try: + return secret_manager.load_secret(secret_ref=secret_ref) + except GmailSecretManagerError as exc: + message = str(exc) + if message.endswith("was not found"): + raise GmailCredentialNotFoundError( + f"gmail account {gmail_account_id} is missing protected credentials" + ) from exc + raise GmailCredentialInvalidError( + f"gmail account {gmail_account_id} has invalid protected credentials" + ) from exc + + +def _persist_external_gmail_credential_metadata( + store: ContinuityStore, + *, + gmail_account_id: UUID, + auth_kind: str, + credential_kind: str, + secret_manager_kind: str, + secret_ref: str, +) -> None: + store.update_gmail_account_credential( + gmail_account_id=gmail_account_id, + auth_kind=auth_kind, + credential_kind=credential_kind, + secret_manager_kind=secret_manager_kind, + secret_ref=secret_ref, + credential_blob=None, + ) + + +def _resolve_gmail_credential( + store: ContinuityStore, + 
secret_manager: GmailSecretManager, + *, + gmail_account_id: UUID, +) -> ResolvedGmailCredential: + credential = store.get_gmail_account_credential_optional(gmail_account_id) + if credential is None: + raise GmailCredentialNotFoundError( + f"gmail account {gmail_account_id} is missing protected credentials" + ) + + if credential["auth_kind"] != GMAIL_AUTH_KIND_OAUTH_ACCESS_TOKEN: + raise GmailCredentialInvalidError( + f"gmail account {gmail_account_id} has invalid protected credentials" + ) + + if credential["secret_manager_kind"] == GMAIL_SECRET_MANAGER_KIND_FILE_V1: + secret_ref = _coerce_nonempty_string(credential["secret_ref"]) + if secret_ref is None: + raise GmailCredentialInvalidError( + f"gmail account {gmail_account_id} has invalid protected credentials" + ) + return ResolvedGmailCredential( + parsed_credential=_parse_gmail_credential( + gmail_account_id=gmail_account_id, + credential_blob=_load_external_gmail_secret( + secret_manager, + gmail_account_id=gmail_account_id, + secret_ref=secret_ref, + ), + ), + credential_kind=credential["credential_kind"], + secret_manager_kind=credential["secret_manager_kind"], + secret_ref=secret_ref, + credential_blob=None, + ) + + if credential["secret_manager_kind"] != GMAIL_SECRET_MANAGER_KIND_LEGACY_DB_V0: + raise GmailCredentialInvalidError( + f"gmail account {gmail_account_id} has invalid protected credentials" + ) + + if credential["credential_blob"] is None: + raise GmailCredentialInvalidError( + f"gmail account {gmail_account_id} has invalid protected credentials" + ) + + parsed_credential = _parse_gmail_credential( + gmail_account_id=gmail_account_id, + credential_blob=credential["credential_blob"], + ) + secret_ref = build_gmail_secret_ref( + user_id=credential["user_id"], + gmail_account_id=gmail_account_id, + ) + _write_external_gmail_secret( + secret_manager, + gmail_account_id=gmail_account_id, + secret_ref=secret_ref, + credential_blob=credential["credential_blob"], + ) + try: + _persist_external_gmail_credential_metadata( + store, + gmail_account_id=gmail_account_id, + auth_kind=credential["auth_kind"], + credential_kind=parsed_credential.credential_kind, + secret_manager_kind=secret_manager.kind, + secret_ref=secret_ref, + ) + except (ContinuityStoreInvariantError, psycopg.Error) as exc: + try: + secret_manager.delete_secret(secret_ref=secret_ref) + except GmailSecretManagerError: + pass + raise GmailCredentialPersistenceError( + f"gmail account {gmail_account_id} protected credentials could not be persisted" + ) from exc + + return ResolvedGmailCredential( + parsed_credential=parsed_credential, + credential_kind=parsed_credential.credential_kind, + secret_manager_kind=secret_manager.kind, + secret_ref=secret_ref, + credential_blob=None, + ) + + def _parse_gmail_credential( *, gmail_account_id: UUID, @@ -305,30 +466,48 @@ def refresh_gmail_access_token( def _persist_refreshed_gmail_credential( store: ContinuityStore, + secret_manager: GmailSecretManager, *, gmail_account_id: UUID, auth_kind: str, + secret_ref: str, existing_credential: ParsedGmailCredential, refreshed_credential: RefreshedGmailCredential, ) -> None: + original_credential_blob = build_gmail_protected_credential_blob( + access_token=existing_credential.access_token, + refresh_token=existing_credential.refresh_token, + client_id=existing_credential.client_id, + client_secret=existing_credential.client_secret, + access_token_expires_at=existing_credential.access_token_expires_at, + ) replacement_refresh_token = ( refreshed_credential.refresh_token if 
refreshed_credential.refresh_token is not None else existing_credential.refresh_token ) + replacement_credential_blob = build_gmail_protected_credential_blob( + access_token=refreshed_credential.access_token, + refresh_token=replacement_refresh_token, + client_id=existing_credential.client_id, + client_secret=existing_credential.client_secret, + access_token_expires_at=refreshed_credential.access_token_expires_at, + ) try: + secret_manager.write_secret(secret_ref=secret_ref, payload=replacement_credential_blob) store.update_gmail_account_credential( gmail_account_id=gmail_account_id, auth_kind=auth_kind, - credential_blob=build_gmail_protected_credential_blob( - access_token=refreshed_credential.access_token, - refresh_token=replacement_refresh_token, - client_id=existing_credential.client_id, - client_secret=existing_credential.client_secret, - access_token_expires_at=refreshed_credential.access_token_expires_at, - ), + credential_kind=replacement_credential_blob["credential_kind"], + secret_manager_kind=secret_manager.kind, + secret_ref=secret_ref, + credential_blob=None, ) - except (ContinuityStoreInvariantError, psycopg.Error) as exc: + except (GmailSecretManagerError, ContinuityStoreInvariantError, psycopg.Error) as exc: + try: + secret_manager.write_secret(secret_ref=secret_ref, payload=original_credential_blob) + except GmailSecretManagerError: + pass raise GmailCredentialPersistenceError( f"gmail account {gmail_account_id} renewed protected credentials could not be persisted" ) from exc @@ -336,24 +515,16 @@ def _persist_refreshed_gmail_credential( def resolve_gmail_access_token( store: ContinuityStore, + secret_manager: GmailSecretManager, *, gmail_account_id: UUID, ) -> str: - credential = store.get_gmail_account_credential_optional(gmail_account_id) - if credential is None: - raise GmailCredentialNotFoundError( - f"gmail account {gmail_account_id} is missing protected credentials" - ) - - if credential["auth_kind"] != GMAIL_AUTH_KIND_OAUTH_ACCESS_TOKEN: - raise GmailCredentialInvalidError( - f"gmail account {gmail_account_id} has invalid protected credentials" - ) - - parsed_credential = _parse_gmail_credential( + credential = _resolve_gmail_credential( + store, + secret_manager, gmail_account_id=gmail_account_id, - credential_blob=credential["credential_blob"], ) + parsed_credential = credential.parsed_credential if ( parsed_credential.credential_kind != GMAIL_REFRESHABLE_PROTECTED_CREDENTIAL_KIND or parsed_credential.access_token_expires_at is None @@ -369,8 +540,10 @@ def resolve_gmail_access_token( ) _persist_refreshed_gmail_credential( store, + secret_manager, gmail_account_id=gmail_account_id, - auth_kind=credential["auth_kind"], + auth_kind=GMAIL_AUTH_KIND_OAUTH_ACCESS_TOKEN, + secret_ref=credential.secret_ref, existing_credential=parsed_credential, refreshed_credential=refreshed_credential, ) @@ -379,6 +552,7 @@ def resolve_gmail_access_token( def create_gmail_account_record( store: ContinuityStore, + secret_manager: GmailSecretManager, *, user_id: UUID, request: GmailAccountConnectInput, @@ -391,6 +565,8 @@ def create_gmail_account_record( f"gmail account {request.provider_account_id} is already connected" ) + row: GmailAccountRow | None = None + secret_ref: str | None = None try: row = store.create_gmail_account( provider_account_id=request.provider_account_id, @@ -398,21 +574,46 @@ def create_gmail_account_record( display_name=request.display_name, scope=request.scope, ) + credential_blob = build_gmail_protected_credential_blob( + access_token=request.access_token, + 
refresh_token=request.refresh_token, + client_id=request.client_id, + client_secret=request.client_secret, + access_token_expires_at=request.access_token_expires_at, + ) + secret_ref = build_gmail_secret_ref( + user_id=row["user_id"], + gmail_account_id=row["id"], + ) + _write_external_gmail_secret( + secret_manager, + gmail_account_id=row["id"], + secret_ref=secret_ref, + credential_blob=credential_blob, + ) store.create_gmail_account_credential( gmail_account_id=row["id"], auth_kind=GMAIL_AUTH_KIND_OAUTH_ACCESS_TOKEN, - credential_blob=build_gmail_protected_credential_blob( - access_token=request.access_token, - refresh_token=request.refresh_token, - client_id=request.client_id, - client_secret=request.client_secret, - access_token_expires_at=request.access_token_expires_at, - ), + credential_kind=credential_blob["credential_kind"], + secret_manager_kind=secret_manager.kind, + secret_ref=secret_ref, + credential_blob=None, ) except psycopg.errors.UniqueViolation as exc: raise GmailAccountAlreadyExistsError( f"gmail account {request.provider_account_id} is already connected" ) from exc + except GmailCredentialPersistenceError: + raise + except (ContinuityStoreInvariantError, psycopg.Error) as exc: + if secret_ref is not None: + try: + secret_manager.delete_secret(secret_ref=secret_ref) + except GmailSecretManagerError: + pass + raise GmailCredentialPersistenceError( + "gmail protected credentials could not be persisted" + ) from exc return {"account": serialize_gmail_account_row(row)} @@ -511,6 +712,7 @@ def fetch_gmail_message_raw_bytes(*, access_token: str, provider_message_id: str def ingest_gmail_message_record( store: ContinuityStore, + secret_manager: GmailSecretManager, *, user_id: UUID, request: GmailMessageIngestInput, @@ -527,6 +729,7 @@ def ingest_gmail_message_record( access_token = resolve_gmail_access_token( store, + secret_manager, gmail_account_id=request.gmail_account_id, ) diff --git a/apps/api/src/alicebot_api/gmail_secret_manager.py b/apps/api/src/alicebot_api/gmail_secret_manager.py new file mode 100644 index 0000000..d931ae2 --- /dev/null +++ b/apps/api/src/alicebot_api/gmail_secret_manager.py @@ -0,0 +1,97 @@ +from __future__ import annotations + +import json +import os +from dataclasses import dataclass +from pathlib import Path +from urllib.parse import unquote, urlparse + +from alicebot_api.store import JsonObject + +GMAIL_SECRET_MANAGER_KIND_FILE_V1 = "file_v1" + + +class GmailSecretManagerError(RuntimeError): + """Raised when the configured Gmail secret manager cannot service a request.""" + + +@dataclass(frozen=True, slots=True) +class GmailSecretReference: + kind: str + ref: str + + +class GmailSecretManager: + kind: str + + def load_secret(self, *, secret_ref: str) -> JsonObject: + raise NotImplementedError + + def write_secret(self, *, secret_ref: str, payload: JsonObject) -> None: + raise NotImplementedError + + def delete_secret(self, *, secret_ref: str) -> None: + raise NotImplementedError + + +class FileGmailSecretManager(GmailSecretManager): + kind = GMAIL_SECRET_MANAGER_KIND_FILE_V1 + + def __init__(self, *, root: Path) -> None: + self._root = root.expanduser().resolve() + + def load_secret(self, *, secret_ref: str) -> JsonObject: + path = self._resolve_secret_path(secret_ref) + try: + payload = json.loads(path.read_text(encoding="utf-8")) + except FileNotFoundError as exc: + raise GmailSecretManagerError(f"gmail secret {secret_ref} was not found") from exc + except (OSError, UnicodeDecodeError, json.JSONDecodeError) as exc: + raise 
GmailSecretManagerError(f"gmail secret {secret_ref} could not be loaded") from exc + if not isinstance(payload, dict): + raise GmailSecretManagerError(f"gmail secret {secret_ref} could not be loaded") + return payload + + def write_secret(self, *, secret_ref: str, payload: JsonObject) -> None: + path = self._resolve_secret_path(secret_ref) + path.parent.mkdir(parents=True, exist_ok=True) + temp_path = path.with_name(f".{path.name}.{os.getpid()}.tmp") + try: + temp_path.write_text(json.dumps(payload, sort_keys=True), encoding="utf-8") + temp_path.replace(path) + except OSError as exc: + try: + temp_path.unlink(missing_ok=True) + except OSError: + pass + raise GmailSecretManagerError(f"gmail secret {secret_ref} could not be written") from exc + + def delete_secret(self, *, secret_ref: str) -> None: + path = self._resolve_secret_path(secret_ref) + try: + path.unlink(missing_ok=True) + except OSError as exc: + raise GmailSecretManagerError(f"gmail secret {secret_ref} could not be deleted") from exc + + def _resolve_secret_path(self, secret_ref: str) -> Path: + candidate = (self._root / secret_ref).resolve() + try: + candidate.relative_to(self._root) + except ValueError as exc: + raise GmailSecretManagerError(f"gmail secret {secret_ref} is outside the configured root") from exc + return candidate + + +def build_gmail_secret_manager(secret_manager_url: str) -> GmailSecretManager: + if secret_manager_url.strip() == "": + raise ValueError("GMAIL_SECRET_MANAGER_URL must be configured") + + parsed = urlparse(secret_manager_url) + if parsed.scheme != "file": + raise ValueError("GMAIL_SECRET_MANAGER_URL must use the file:// scheme") + + root_path = Path(unquote(parsed.path or "/")) + if parsed.netloc not in ("", "localhost"): + root_path = Path(f"/{parsed.netloc}{root_path.as_posix()}") + + return FileGmailSecretManager(root=root_path) diff --git a/apps/api/src/alicebot_api/main.py b/apps/api/src/alicebot_api/main.py index f23b11e..40e3327 100644 --- a/apps/api/src/alicebot_api/main.py +++ b/apps/api/src/alicebot_api/main.py @@ -154,6 +154,7 @@ ingest_gmail_message_record, list_gmail_account_records, ) +from alicebot_api.gmail_secret_manager import build_gmail_secret_manager from alicebot_api.embedding import ( EmbeddingConfigValidationError, MemoryEmbeddingNotFoundError, @@ -1313,11 +1314,13 @@ def get_task(task_id: UUID, user_id: UUID) -> JSONResponse: @app.post("/v0/gmail-accounts") def connect_gmail_account(request: ConnectGmailAccountRequest) -> JSONResponse: settings = get_settings() + secret_manager = build_gmail_secret_manager(settings.gmail_secret_manager_url) try: with user_connection(settings.database_url, request.user_id) as conn: payload = create_gmail_account_record( ContinuityStore(conn), + secret_manager, user_id=request.user_id, request=GmailAccountConnectInput( provider_account_id=request.provider_account_id, @@ -1333,6 +1336,8 @@ def connect_gmail_account(request: ConnectGmailAccountRequest) -> JSONResponse: ) except GmailCredentialValidationError as exc: return JSONResponse(status_code=400, content={"detail": str(exc)}) + except GmailCredentialPersistenceError as exc: + return JSONResponse(status_code=409, content={"detail": str(exc)}) except GmailAccountAlreadyExistsError as exc: return JSONResponse(status_code=409, content={"detail": str(exc)}) @@ -1385,11 +1390,13 @@ def ingest_gmail_message( request: IngestGmailMessageRequest, ) -> JSONResponse: settings = get_settings() + secret_manager = build_gmail_secret_manager(settings.gmail_secret_manager_url) try: with 
user_connection(settings.database_url, request.user_id) as conn: payload = ingest_gmail_message_record( ContinuityStore(conn), + secret_manager, user_id=request.user_id, request=GmailMessageIngestInput( gmail_account_id=gmail_account_id, diff --git a/apps/api/src/alicebot_api/store.py b/apps/api/src/alicebot_api/store.py index 61dcd53..c12d464 100644 --- a/apps/api/src/alicebot_api/store.py +++ b/apps/api/src/alicebot_api/store.py @@ -260,7 +260,10 @@ class ProtectedGmailCredentialRow(TypedDict): gmail_account_id: UUID user_id: UUID auth_kind: str - credential_blob: JsonObject + credential_kind: str + secret_manager_kind: str + secret_ref: str | None + credential_blob: JsonObject | None created_at: datetime updated_at: datetime @@ -1524,6 +1527,9 @@ class LabelCountRow(TypedDict): gmail_account_id, user_id, auth_kind, + credential_kind, + secret_manager_kind, + secret_ref, credential_blob, created_at, updated_at @@ -1533,6 +1539,9 @@ class LabelCountRow(TypedDict): app.current_user_id(), %s, %s, + %s, + %s, + %s, clock_timestamp(), clock_timestamp() ) @@ -1540,6 +1549,9 @@ class LabelCountRow(TypedDict): gmail_account_id, user_id, auth_kind, + credential_kind, + secret_manager_kind, + secret_ref, credential_blob, created_at, updated_at @@ -1580,6 +1592,9 @@ class LabelCountRow(TypedDict): gmail_account_id, user_id, auth_kind, + credential_kind, + secret_manager_kind, + secret_ref, credential_blob, created_at, updated_at @@ -1591,6 +1606,9 @@ class LabelCountRow(TypedDict): UPDATE gmail_account_credentials SET auth_kind = %s, + credential_kind = %s, + secret_manager_kind = %s, + secret_ref = %s, credential_blob = %s, updated_at = clock_timestamp() WHERE gmail_account_id = %s @@ -1598,6 +1616,9 @@ class LabelCountRow(TypedDict): gmail_account_id, user_id, auth_kind, + credential_kind, + secret_manager_kind, + secret_ref, credential_blob, created_at, updated_at @@ -3154,12 +3175,22 @@ def create_gmail_account_credential( *, gmail_account_id: UUID, auth_kind: str, - credential_blob: JsonObject, + credential_kind: str, + secret_manager_kind: str, + secret_ref: str | None, + credential_blob: JsonObject | None, ) -> ProtectedGmailCredentialRow: return self._fetch_one( "create_gmail_account_credential", INSERT_GMAIL_ACCOUNT_CREDENTIAL_SQL, - (gmail_account_id, auth_kind, Jsonb(credential_blob)), + ( + gmail_account_id, + auth_kind, + credential_kind, + secret_manager_kind, + secret_ref, + None if credential_blob is None else Jsonb(credential_blob), + ), ) def get_gmail_account_optional(self, gmail_account_id: UUID) -> GmailAccountRow | None: @@ -3176,12 +3207,22 @@ def update_gmail_account_credential( *, gmail_account_id: UUID, auth_kind: str, - credential_blob: JsonObject, + credential_kind: str, + secret_manager_kind: str, + secret_ref: str | None, + credential_blob: JsonObject | None, ) -> ProtectedGmailCredentialRow: return self._fetch_one( "update_gmail_account_credential", UPDATE_GMAIL_ACCOUNT_CREDENTIAL_SQL, - (auth_kind, Jsonb(credential_blob), gmail_account_id), + ( + auth_kind, + credential_kind, + secret_manager_kind, + secret_ref, + None if credential_blob is None else Jsonb(credential_blob), + gmail_account_id, + ), ) def get_gmail_account_by_provider_account_id_optional( diff --git a/tests/integration/test_gmail_accounts_api.py b/tests/integration/test_gmail_accounts_api.py index ec4c7aa..acfbae4 100644 --- a/tests/integration/test_gmail_accounts_api.py +++ b/tests/integration/test_gmail_accounts_api.py @@ -82,6 +82,10 @@ def _build_rfc822_email_bytes(*, subject: str, plain_body: 
str) -> bytes: ) +def _build_gmail_secret_manager_url(root: Path) -> str: + return root.resolve().as_uri() + + def seed_user(database_url: str, *, email: str) -> dict[str, UUID]: user_id = uuid4() @@ -180,13 +184,18 @@ def _connect_gmail_account( def test_gmail_account_endpoints_connect_list_detail_and_isolate( migrated_database_urls, monkeypatch, + tmp_path, ) -> None: owner = seed_user(migrated_database_urls["app"], email="owner@example.com") intruder = seed_user(migrated_database_urls["app"], email="intruder@example.com") + gmail_secret_root = tmp_path / "gmail-secrets" monkeypatch.setattr( main_module, "get_settings", - lambda: Settings(database_url=migrated_database_urls["app"]), + lambda: Settings( + database_url=migrated_database_urls["app"], + gmail_secret_manager_url=_build_gmail_secret_manager_url(gmail_secret_root), + ), ) create_status, create_payload = _connect_gmail_account( @@ -275,18 +284,28 @@ def test_gmail_account_endpoints_connect_list_detail_and_isolate( """ SELECT auth_kind, - credential_blob ->> 'credential_kind', - credential_blob ->> 'access_token' + credential_kind, + secret_manager_kind, + secret_ref, + credential_blob IS NULL FROM gmail_account_credentials WHERE gmail_account_id = %s """, (UUID(create_payload["account"]["id"]),), ) - assert cur.fetchone() == ( - "oauth_access_token", - "gmail_oauth_access_token_v1", - "token-for-acct-owner-001", - ) + credential_row = cur.fetchone() + + assert credential_row is not None + assert credential_row[0] == "oauth_access_token" + assert credential_row[1] == "gmail_oauth_access_token_v1" + assert credential_row[2] == "file_v1" + assert credential_row[4] is True + assert credential_row[3] is not None + secret_payload = json.loads((gmail_secret_root / credential_row[3]).read_text(encoding="utf-8")) + assert secret_payload == { + "credential_kind": "gmail_oauth_access_token_v1", + "access_token": "token-for-acct-owner-001", + } def test_gmail_message_ingestion_endpoint_persists_artifact_and_chunks( @@ -296,12 +315,14 @@ def test_gmail_message_ingestion_endpoint_persists_artifact_and_chunks( ) -> None: owner = seed_task(migrated_database_urls["app"], email="owner@example.com") workspace_root = tmp_path / "task-workspaces" + gmail_secret_root = tmp_path / "gmail-secrets" monkeypatch.setattr( main_module, "get_settings", lambda: Settings( database_url=migrated_database_urls["app"], task_workspace_root=str(workspace_root), + gmail_secret_manager_url=_build_gmail_secret_manager_url(gmail_secret_root), ), ) raw_bytes = _build_rfc822_email_bytes(subject="Inbox Update", plain_body="ingest this message") @@ -383,12 +404,14 @@ def test_gmail_message_ingestion_endpoint_renews_expired_access_token( ) -> None: owner = seed_task(migrated_database_urls["app"], email="owner@example.com") workspace_root = tmp_path / "task-workspaces" + gmail_secret_root = tmp_path / "gmail-secrets" monkeypatch.setattr( main_module, "get_settings", lambda: Settings( database_url=migrated_database_urls["app"], task_workspace_root=str(workspace_root), + gmail_secret_manager_url=_build_gmail_secret_manager_url(gmail_secret_root), ), ) raw_bytes = _build_rfc822_email_bytes(subject="Inbox Update", plain_body="renewed token path") @@ -455,12 +478,10 @@ def fake_fetch_gmail_message_raw_bytes(*, access_token: str, **_kwargs) -> bytes cur.execute( """ SELECT - credential_blob ->> 'credential_kind', - credential_blob ->> 'access_token', - credential_blob ->> 'refresh_token', - credential_blob ->> 'client_id', - credential_blob ->> 'client_secret', - credential_blob 
->> 'access_token_expires_at' + credential_kind, + secret_manager_kind, + secret_ref, + credential_blob IS NULL FROM gmail_account_credentials WHERE gmail_account_id = %s """, @@ -470,11 +491,18 @@ def fake_fetch_gmail_message_raw_bytes(*, access_token: str, **_kwargs) -> bytes assert credential_row is not None assert credential_row[0] == "gmail_oauth_refresh_token_v2" - assert credential_row[1] == "token-refreshed" - assert credential_row[2] == "refresh-owner-001" - assert credential_row[3] == "client-owner-001" - assert credential_row[4] == "secret-owner-001" - assert credential_row[5] == "2030-01-01T00:05:00+00:00" + assert credential_row[1] == "file_v1" + assert credential_row[3] is True + assert credential_row[2] is not None + secret_payload = json.loads((gmail_secret_root / credential_row[2]).read_text(encoding="utf-8")) + assert secret_payload == { + "credential_kind": "gmail_oauth_refresh_token_v2", + "access_token": "token-refreshed", + "refresh_token": "refresh-owner-001", + "client_id": "client-owner-001", + "client_secret": "secret-owner-001", + "access_token_expires_at": "2030-01-01T00:05:00+00:00", + } def test_gmail_message_ingestion_endpoint_persists_rotated_refresh_token( @@ -484,12 +512,14 @@ def test_gmail_message_ingestion_endpoint_persists_rotated_refresh_token( ) -> None: owner = seed_task(migrated_database_urls["app"], email="owner@example.com") workspace_root = tmp_path / "task-workspaces" + gmail_secret_root = tmp_path / "gmail-secrets" monkeypatch.setattr( main_module, "get_settings", lambda: Settings( database_url=migrated_database_urls["app"], task_workspace_root=str(workspace_root), + gmail_secret_manager_url=_build_gmail_secret_manager_url(gmail_secret_root), ), ) raw_bytes = _build_rfc822_email_bytes(subject="Inbox Update", plain_body="rotated token path") @@ -580,12 +610,10 @@ def fake_fetch_gmail_message_raw_bytes(*, access_token: str, **_kwargs) -> bytes cur.execute( """ SELECT - credential_blob ->> 'credential_kind', - credential_blob ->> 'access_token', - credential_blob ->> 'refresh_token', - credential_blob ->> 'client_id', - credential_blob ->> 'client_secret', - credential_blob ->> 'access_token_expires_at' + credential_kind, + secret_manager_kind, + secret_ref, + credential_blob IS NULL FROM gmail_account_credentials WHERE gmail_account_id = %s """, @@ -595,11 +623,18 @@ def fake_fetch_gmail_message_raw_bytes(*, access_token: str, **_kwargs) -> bytes assert credential_row is not None assert credential_row[0] == "gmail_oauth_refresh_token_v2" - assert credential_row[1] == "token-refreshed-rotated" - assert credential_row[2] == "refresh-owner-rotated-002" - assert credential_row[3] == "client-owner-001" - assert credential_row[4] == "secret-owner-001" - assert credential_row[5] == "2030-01-01T00:05:00+00:00" + assert credential_row[1] == "file_v1" + assert credential_row[3] is True + assert credential_row[2] is not None + secret_payload = json.loads((gmail_secret_root / credential_row[2]).read_text(encoding="utf-8")) + assert secret_payload == { + "credential_kind": "gmail_oauth_refresh_token_v2", + "access_token": "token-refreshed-rotated", + "refresh_token": "refresh-owner-rotated-002", + "client_id": "client-owner-001", + "client_secret": "secret-owner-001", + "access_token_expires_at": "2030-01-01T00:05:00+00:00", + } def test_gmail_message_ingestion_endpoint_fails_deterministically_when_rotated_credentials_cannot_be_persisted( @@ -609,12 +644,14 @@ def test_gmail_message_ingestion_endpoint_fails_deterministically_when_rotated_c ) -> None: owner = 
seed_task(migrated_database_urls["app"], email="owner@example.com") workspace_root = tmp_path / "task-workspaces" + gmail_secret_root = tmp_path / "gmail-secrets" monkeypatch.setattr( main_module, "get_settings", lambda: Settings( database_url=migrated_database_urls["app"], task_workspace_root=str(workspace_root), + gmail_secret_manager_url=_build_gmail_secret_manager_url(gmail_secret_root), ), ) @@ -688,9 +725,10 @@ def fail_fetch(**_kwargs): cur.execute( """ SELECT - credential_blob ->> 'access_token', - credential_blob ->> 'refresh_token', - credential_blob ->> 'access_token_expires_at' + credential_kind, + secret_manager_kind, + secret_ref, + credential_blob IS NULL FROM gmail_account_credentials WHERE gmail_account_id = %s """, @@ -698,11 +736,20 @@ def fail_fetch(**_kwargs): ) credential_row = cur.fetchone() - assert credential_row == ( - f"token-for-{account_payload['account']['provider_account_id']}", - "refresh-owner-001", - "2020-01-01T00:00:00+00:00", - ) + assert credential_row is not None + assert credential_row[0] == "gmail_oauth_refresh_token_v2" + assert credential_row[1] == "file_v1" + assert credential_row[3] is True + assert credential_row[2] is not None + secret_payload = json.loads((gmail_secret_root / credential_row[2]).read_text(encoding="utf-8")) + assert secret_payload == { + "credential_kind": "gmail_oauth_refresh_token_v2", + "access_token": f"token-for-{account_payload['account']['provider_account_id']}", + "refresh_token": "refresh-owner-001", + "client_id": "client-owner-001", + "client_secret": "secret-owner-001", + "access_token_expires_at": "2020-01-01T00:00:00+00:00", + } def test_gmail_message_ingestion_endpoint_rejects_cross_user_workspace_access( @@ -713,12 +760,14 @@ def test_gmail_message_ingestion_endpoint_rejects_cross_user_workspace_access( owner = seed_task(migrated_database_urls["app"], email="owner@example.com") intruder = seed_task(migrated_database_urls["app"], email="intruder@example.com") workspace_root = tmp_path / "task-workspaces" + gmail_secret_root = tmp_path / "gmail-secrets" monkeypatch.setattr( main_module, "get_settings", lambda: Settings( database_url=migrated_database_urls["app"], task_workspace_root=str(workspace_root), + gmail_secret_manager_url=_build_gmail_secret_manager_url(gmail_secret_root), ), ) monkeypatch.setattr( @@ -762,12 +811,14 @@ def test_gmail_message_ingestion_endpoint_rejects_missing_protected_credentials_ ) -> None: owner = seed_task(migrated_database_urls["app"], email="owner@example.com") workspace_root = tmp_path / "task-workspaces" + gmail_secret_root = tmp_path / "gmail-secrets" monkeypatch.setattr( main_module, "get_settings", lambda: Settings( database_url=migrated_database_urls["app"], task_workspace_root=str(workspace_root), + gmail_secret_manager_url=_build_gmail_secret_manager_url(gmail_secret_root), ), ) @@ -820,6 +871,87 @@ def fail_fetch(**_kwargs): assert store.list_task_artifacts_for_task(owner["task_id"]) == [] +def test_gmail_message_ingestion_endpoint_rejects_missing_external_secret_without_side_effects( + migrated_database_urls, + monkeypatch, + tmp_path, +) -> None: + owner = seed_task(migrated_database_urls["app"], email="owner@example.com") + workspace_root = tmp_path / "task-workspaces" + gmail_secret_root = tmp_path / "gmail-secrets" + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings( + database_url=migrated_database_urls["app"], + task_workspace_root=str(workspace_root), + gmail_secret_manager_url=_build_gmail_secret_manager_url(gmail_secret_root), + ), + ) + + 
monkeypatch.setattr( + gmail_module, + "fetch_gmail_message_raw_bytes", + lambda **_kwargs: (_ for _ in ()).throw( + AssertionError("fetch_gmail_message_raw_bytes should not be called") + ), + ) + + _, account_payload = _connect_gmail_account( + user_id=owner["user_id"], + provider_account_id="acct-owner-secret-missing-001", + email_address="owner@gmail.example", + ) + _, workspace_payload = invoke_request( + "POST", + f"/v0/tasks/{owner['task_id']}/workspace", + payload={"user_id": str(owner["user_id"])}, + ) + + with psycopg.connect(migrated_database_urls["admin"]) as conn: + with conn.cursor() as cur: + cur.execute( + """ + SELECT secret_ref + FROM gmail_account_credentials + WHERE gmail_account_id = %s + """, + (UUID(account_payload["account"]["id"]),), + ) + secret_ref_row = cur.fetchone() + + assert secret_ref_row is not None + secret_path = gmail_secret_root / secret_ref_row[0] + secret_path.unlink() + + ingest_status, ingest_payload = invoke_request( + "POST", + f"/v0/gmail-accounts/{account_payload['account']['id']}/messages/msg-001/ingest", + payload={ + "user_id": str(owner["user_id"]), + "task_workspace_id": workspace_payload["workspace"]["id"], + }, + ) + + assert ingest_status == 409 + assert ingest_payload == { + "detail": ( + f"gmail account {account_payload['account']['id']} is missing protected credentials" + ) + } + artifact_file = ( + Path(workspace_payload["workspace"]["local_path"]) + / "gmail" + / "acct-owner-secret-missing-001" + / "msg-001.eml" + ) + assert not artifact_file.exists() + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + assert store.list_task_artifacts_for_task(owner["task_id"]) == [] + + def test_gmail_message_ingestion_endpoint_rejects_invalid_refresh_credentials_without_side_effects( migrated_database_urls, monkeypatch, @@ -827,12 +959,14 @@ def test_gmail_message_ingestion_endpoint_rejects_invalid_refresh_credentials_wi ) -> None: owner = seed_task(migrated_database_urls["app"], email="owner@example.com") workspace_root = tmp_path / "task-workspaces" + gmail_secret_root = tmp_path / "gmail-secrets" monkeypatch.setattr( main_module, "get_settings", lambda: Settings( database_url=migrated_database_urls["app"], task_workspace_root=str(workspace_root), + gmail_secret_manager_url=_build_gmail_secret_manager_url(gmail_secret_root), ), ) @@ -899,12 +1033,14 @@ def test_gmail_message_ingestion_endpoint_rejects_sanitized_path_collisions_with ) -> None: owner = seed_task(migrated_database_urls["app"], email="owner@example.com") workspace_root = tmp_path / "task-workspaces" + gmail_secret_root = tmp_path / "gmail-secrets" monkeypatch.setattr( main_module, "get_settings", lambda: Settings( database_url=migrated_database_urls["app"], task_workspace_root=str(workspace_root), + gmail_secret_manager_url=_build_gmail_secret_manager_url(gmail_secret_root), ), ) first_bytes = _build_rfc822_email_bytes(subject="First", plain_body="first message body") @@ -979,12 +1115,14 @@ def test_gmail_message_ingestion_endpoint_rejects_missing_and_unsupported_messag ) -> None: owner = seed_task(migrated_database_urls["app"], email="owner@example.com") workspace_root = tmp_path / "task-workspaces" + gmail_secret_root = tmp_path / "gmail-secrets" monkeypatch.setattr( main_module, "get_settings", lambda: Settings( database_url=migrated_database_urls["app"], task_workspace_root=str(workspace_root), + gmail_secret_manager_url=_build_gmail_secret_manager_url(gmail_secret_root), ), ) diff --git 
a/tests/integration/test_migrations.py b/tests/integration/test_migrations.py index aba3133..8361b65 100644 --- a/tests/integration/test_migrations.py +++ b/tests/integration/test_migrations.py @@ -467,6 +467,116 @@ def test_gmail_refresh_token_lifecycle_migration_round_trip_preserves_downgrade_ ) +def test_gmail_external_secret_manager_migration_round_trip_preserves_legacy_transition_rows( + database_urls, +): + config = make_alembic_config(database_urls["admin"]) + user_id = "00000000-0000-0000-0000-000000000301" + gmail_account_id = "00000000-0000-0000-0000-000000000302" + + command.upgrade(config, "20260316_0028") + + with psycopg.connect(database_urls["admin"]) as conn: + with conn.cursor() as cur: + cur.execute( + """ + INSERT INTO users (id, email, display_name) + VALUES (%s, 'gmail-secret-manager@example.com', 'Gmail Secret Manager User') + """, + (user_id,), + ) + cur.execute( + """ + INSERT INTO gmail_accounts ( + id, + user_id, + provider_account_id, + email_address, + display_name, + scope + ) + VALUES ( + %s, + %s, + 'acct-secret-manager-001', + 'owner@gmail.example', + 'Owner', + 'https://www.googleapis.com/auth/gmail.readonly' + ) + """, + (gmail_account_id, user_id), + ) + cur.execute( + """ + INSERT INTO gmail_account_credentials ( + gmail_account_id, + user_id, + auth_kind, + credential_blob + ) + VALUES ( + %s, + %s, + 'oauth_access_token', + jsonb_build_object( + 'credential_kind', 'gmail_oauth_refresh_token_v2', + 'access_token', 'token-before-externalization', + 'refresh_token', 'refresh-001', + 'client_id', 'client-001', + 'client_secret', 'secret-001', + 'access_token_expires_at', '2030-01-01T00:05:00+00:00' + ) + ) + """, + (gmail_account_id, user_id), + ) + conn.commit() + + command.upgrade(config, "20260316_0029") + + with psycopg.connect(database_urls["admin"]) as conn: + with conn.cursor() as cur: + cur.execute( + """ + SELECT + credential_kind, + secret_manager_kind, + secret_ref, + credential_blob ->> 'access_token' + FROM gmail_account_credentials + WHERE gmail_account_id = %s + """, + (gmail_account_id,), + ) + assert cur.fetchone() == ( + "gmail_oauth_refresh_token_v2", + "legacy_db_v0", + None, + "token-before-externalization", + ) + + command.downgrade(config, "20260316_0028") + + with psycopg.connect(database_urls["admin"]) as conn: + with conn.cursor() as cur: + cur.execute( + """ + SELECT + credential_blob ->> 'credential_kind', + credential_blob ->> 'access_token', + credential_blob ->> 'refresh_token' + FROM gmail_account_credentials + WHERE gmail_account_id = %s + """, + (gmail_account_id,), + ) + assert cur.fetchone() == ( + "gmail_oauth_refresh_token_v2", + "token-before-externalization", + "refresh-001", + ) + + def test_migrations_upgrade_and_downgrade(database_urls): config = make_alembic_config(database_urls["admin"]) diff --git a/tests/unit/test_20260316_0029_gmail_external_secret_manager.py b/tests/unit/test_20260316_0029_gmail_external_secret_manager.py new file mode 100644 index 0000000..087c4cb --- /dev/null +++ b/tests/unit/test_20260316_0029_gmail_external_secret_manager.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +import importlib + + +MODULE_NAME = "apps.api.alembic.versions.20260316_0029_gmail_external_secret_manager" + + +def load_migration_module(): + return importlib.import_module(MODULE_NAME) + + +def test_upgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + 
module.upgrade() + + assert executed == list(module._UPGRADE_STATEMENTS) + + +def test_downgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.downgrade() + + assert executed == list(module._DOWNGRADE_STATEMENTS) + + +def test_upgrade_marks_external_secret_manager_and_legacy_transition_kinds_explicitly() -> None: + module = load_migration_module() + + assert module.GMAIL_SECRET_MANAGER_KIND_FILE_V1 == "file_v1" + assert module.GMAIL_SECRET_MANAGER_KIND_LEGACY_DB_V0 == "legacy_db_v0" diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index 6d10d22..f59ac87 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -24,6 +24,7 @@ def test_settings_defaults(monkeypatch): "MODEL_API_KEY", "MODEL_TIMEOUT_SECONDS", "TASK_WORKSPACE_ROOT", + "GMAIL_SECRET_MANAGER_URL", ): monkeypatch.delenv(key, raising=False) @@ -39,6 +40,7 @@ def test_settings_defaults(monkeypatch): assert settings.model_name == "gpt-5-mini" assert settings.model_timeout_seconds == 30 assert settings.task_workspace_root == "/tmp/alicebot/task-workspaces" + assert settings.gmail_secret_manager_url == "" def test_settings_honor_environment_overrides(monkeypatch): @@ -50,6 +52,7 @@ def test_settings_honor_environment_overrides(monkeypatch): monkeypatch.setenv("MODEL_NAME", "gpt-5") monkeypatch.setenv("MODEL_TIMEOUT_SECONDS", "45") monkeypatch.setenv("TASK_WORKSPACE_ROOT", "/tmp/custom-workspaces") + monkeypatch.setenv("GMAIL_SECRET_MANAGER_URL", "file:///tmp/custom-gmail-secrets") settings = Settings.from_env() @@ -61,6 +64,7 @@ def test_settings_honor_environment_overrides(monkeypatch): assert settings.model_name == "gpt-5" assert settings.model_timeout_seconds == 45 assert settings.task_workspace_root == "/tmp/custom-workspaces" + assert settings.gmail_secret_manager_url == "file:///tmp/custom-gmail-secrets" def test_settings_can_be_loaded_from_an_explicit_environment_mapping() -> None: @@ -72,6 +76,7 @@ def test_settings_can_be_loaded_from_an_explicit_environment_mapping() -> None: "MODEL_PROVIDER": "openai_responses", "MODEL_NAME": "gpt-5-mini", "TASK_WORKSPACE_ROOT": "/tmp/mapped-workspaces", + "GMAIL_SECRET_MANAGER_URL": "file:///tmp/mapped-gmail-secrets", } ) @@ -81,6 +86,7 @@ def test_settings_can_be_loaded_from_an_explicit_environment_mapping() -> None: assert settings.model_provider == "openai_responses" assert settings.model_name == "gpt-5-mini" assert settings.task_workspace_root == "/tmp/mapped-workspaces" + assert settings.gmail_secret_manager_url == "file:///tmp/mapped-gmail-secrets" def test_settings_raise_clear_error_for_invalid_integer_values() -> None: diff --git a/tests/unit/test_gmail.py b/tests/unit/test_gmail.py index 3242532..ebcd6c9 100644 --- a/tests/unit/test_gmail.py +++ b/tests/unit/test_gmail.py @@ -22,7 +22,10 @@ GmailCredentialPersistenceError, GmailCredentialValidationError, GmailMessageUnsupportedError, + GMAIL_SECRET_MANAGER_KIND_FILE_V1, + GMAIL_SECRET_MANAGER_KIND_LEGACY_DB_V0, RefreshedGmailCredential, + build_gmail_secret_ref, build_gmail_message_artifact_relative_path, build_gmail_protected_credential_blob, create_gmail_account_record, @@ -31,6 +34,7 @@ list_gmail_account_records, resolve_gmail_access_token, ) +from alicebot_api.gmail_secret_manager import GmailSecretManagerError from alicebot_api.store import ContinuityStoreInvariantError from alicebot_api.workspaces import TaskWorkspaceNotFoundError @@ -86,7 
+90,10 @@ def create_gmail_account_credential( *, gmail_account_id: UUID, auth_kind: str, - credential_blob: dict[str, object], + credential_kind: str, + secret_manager_kind: str, + secret_ref: str | None, + credential_blob: dict[str, object] | None, ) -> dict[str, object]: row = { "gmail_account_id": gmail_account_id, @@ -96,6 +103,9 @@ def create_gmail_account_credential( if account["id"] == gmail_account_id ), "auth_kind": auth_kind, + "credential_kind": credential_kind, + "secret_manager_kind": secret_manager_kind, + "secret_ref": secret_ref, "credential_blob": credential_blob, "created_at": self.base_time + timedelta(minutes=len(self.gmail_account_credentials)), "updated_at": self.base_time + timedelta(minutes=len(self.gmail_account_credentials)), @@ -122,12 +132,18 @@ def update_gmail_account_credential( *, gmail_account_id: UUID, auth_kind: str, - credential_blob: dict[str, object], + credential_kind: str, + secret_manager_kind: str, + secret_ref: str | None, + credential_blob: dict[str, object] | None, ) -> dict[str, object]: existing = self.gmail_account_credentials[gmail_account_id] updated = { **existing, "auth_kind": auth_kind, + "credential_kind": credential_kind, + "secret_manager_kind": secret_manager_kind, + "secret_ref": secret_ref, "credential_blob": credential_blob, "updated_at": self.base_time + timedelta(hours=1), } @@ -217,12 +233,39 @@ def get_task_artifact_by_workspace_relative_path_optional( ) +class GmailSecretManagerStub: + def __init__(self) -> None: + self.secrets: dict[str, dict[str, object]] = {} + self.operations: list[tuple[str, str]] = [] + + @property + def kind(self) -> str: + return GMAIL_SECRET_MANAGER_KIND_FILE_V1 + + def load_secret(self, *, secret_ref: str) -> dict[str, object]: + self.operations.append(("load_secret", secret_ref)) + try: + return dict(self.secrets[secret_ref]) + except KeyError as exc: + raise GmailSecretManagerError(f"gmail secret {secret_ref} was not found") from exc + + def write_secret(self, *, secret_ref: str, payload: dict[str, object]) -> None: + self.operations.append(("write_secret", secret_ref)) + self.secrets[secret_ref] = dict(payload) + + def delete_secret(self, *, secret_ref: str) -> None: + self.operations.append(("delete_secret", secret_ref)) + self.secrets.pop(secret_ref, None) + + def test_create_list_and_get_gmail_account_records_are_deterministic() -> None: store = GmailStoreStub() + secret_manager = GmailSecretManagerStub() user_id = uuid4() first = create_gmail_account_record( store, + secret_manager, user_id=user_id, request=GmailAccountConnectInput( provider_account_id="acct-001", @@ -234,6 +277,7 @@ def test_create_list_and_get_gmail_account_records_are_deterministic() -> None: ) second = create_gmail_account_record( store, + secret_manager, user_id=user_id, request=GmailAccountConnectInput( provider_account_id="acct-002", @@ -259,10 +303,12 @@ def test_create_list_and_get_gmail_account_records_are_deterministic() -> None: def test_create_gmail_account_record_persists_protected_credential_and_hides_secret() -> None: store = GmailStoreStub() + secret_manager = GmailSecretManagerStub() user_id = uuid4() response = create_gmail_account_record( store, + secret_manager, user_id=user_id, request=GmailAccountConnectInput( provider_account_id="acct-001", @@ -287,20 +333,38 @@ def test_create_gmail_account_record_persists_protected_credential_and_hides_sec "updated_at": response["account"]["updated_at"], } } - assert store.gmail_account_credentials[account_id]["credential_blob"] == { + secret_ref = 
build_gmail_secret_ref( + user_id=store.gmail_account_credentials[account_id]["user_id"], + gmail_account_id=account_id, + ) + assert store.gmail_account_credentials[account_id] == { + "gmail_account_id": account_id, + "user_id": store.gmail_account_credentials[account_id]["user_id"], + "auth_kind": "oauth_access_token", + "credential_kind": GMAIL_PROTECTED_CREDENTIAL_KIND, + "secret_manager_kind": GMAIL_SECRET_MANAGER_KIND_FILE_V1, + "secret_ref": secret_ref, + "credential_blob": None, + "created_at": store.gmail_account_credentials[account_id]["created_at"], + "updated_at": store.gmail_account_credentials[account_id]["updated_at"], + } + assert secret_manager.secrets[secret_ref] == { "credential_kind": GMAIL_PROTECTED_CREDENTIAL_KIND, "access_token": "token-1", } assert store.operations == [("create_gmail_account_credential", account_id)] + assert secret_manager.operations == [("write_secret", secret_ref)] def test_create_gmail_account_record_persists_refreshable_protected_credential_and_hides_secret() -> None: store = GmailStoreStub() + secret_manager = GmailSecretManagerStub() user_id = uuid4() expires_at = datetime(2030, 1, 1, 0, 0, tzinfo=UTC) response = create_gmail_account_record( store, + secret_manager, user_id=user_id, request=GmailAccountConnectInput( provider_account_id="acct-refresh-001", @@ -329,7 +393,19 @@ def test_create_gmail_account_record_persists_refreshable_protected_credential_a "updated_at": response["account"]["updated_at"], } } - assert store.gmail_account_credentials[account_id]["credential_blob"] == { + secret_ref = build_gmail_secret_ref( + user_id=store.gmail_account_credentials[account_id]["user_id"], + gmail_account_id=account_id, + ) + assert store.gmail_account_credentials[account_id]["credential_blob"] is None + assert store.gmail_account_credentials[account_id]["credential_kind"] == ( + GMAIL_REFRESHABLE_PROTECTED_CREDENTIAL_KIND + ) + assert store.gmail_account_credentials[account_id]["secret_manager_kind"] == ( + GMAIL_SECRET_MANAGER_KIND_FILE_V1 + ) + assert store.gmail_account_credentials[account_id]["secret_ref"] == secret_ref + assert secret_manager.secrets[secret_ref] == { "credential_kind": GMAIL_REFRESHABLE_PROTECTED_CREDENTIAL_KIND, "access_token": "token-1", "refresh_token": "refresh-1", @@ -357,6 +433,7 @@ def test_build_gmail_protected_credential_blob_rejects_partial_refresh_bundle() def test_create_gmail_account_record_rejects_duplicate_provider_account_id() -> None: store = GmailStoreStub() + secret_manager = GmailSecretManagerStub() user_id = uuid4() request = GmailAccountConnectInput( provider_account_id="acct-001", @@ -365,13 +442,13 @@ def test_create_gmail_account_record_rejects_duplicate_provider_account_id() -> scope=GMAIL_READONLY_SCOPE, access_token="token-1", ) - create_gmail_account_record(store, user_id=user_id, request=request) + create_gmail_account_record(store, secret_manager, user_id=user_id, request=request) with pytest.raises( GmailAccountAlreadyExistsError, match="gmail account acct-001 is already connected", ): - create_gmail_account_record(store, user_id=user_id, request=request) + create_gmail_account_record(store, secret_manager, user_id=user_id, request=request) def test_get_gmail_account_record_raises_when_account_is_missing() -> None: @@ -385,8 +462,10 @@ def test_get_gmail_account_record_raises_when_account_is_missing() -> None: def test_resolve_gmail_access_token_reads_protected_credential() -> None: store = GmailStoreStub() + secret_manager = GmailSecretManagerStub() account = create_gmail_account_record( 
store, + secret_manager, user_id=uuid4(), request=GmailAccountConnectInput( provider_account_id="acct-001", @@ -399,14 +478,17 @@ def test_resolve_gmail_access_token_reads_protected_credential() -> None: assert resolve_gmail_access_token( store, + secret_manager, gmail_account_id=UUID(account["id"]), ) == "token-1" def test_resolve_gmail_access_token_rejects_missing_and_invalid_protected_credentials() -> None: store = GmailStoreStub() + secret_manager = GmailSecretManagerStub() account = create_gmail_account_record( store, + secret_manager, user_id=uuid4(), request=GmailAccountConnectInput( provider_account_id="acct-001", @@ -423,12 +505,15 @@ def test_resolve_gmail_access_token_rejects_missing_and_invalid_protected_creden GmailCredentialNotFoundError, match=f"gmail account {account_id} is missing protected credentials", ): - resolve_gmail_access_token(store, gmail_account_id=account_id) + resolve_gmail_access_token(store, secret_manager, gmail_account_id=account_id) store.gmail_account_credentials[account_id] = { "gmail_account_id": account_id, "user_id": uuid4(), "auth_kind": "oauth_access_token", + "credential_kind": GMAIL_PROTECTED_CREDENTIAL_KIND, + "secret_manager_kind": GMAIL_SECRET_MANAGER_KIND_LEGACY_DB_V0, + "secret_ref": None, "credential_blob": {"credential_kind": GMAIL_PROTECTED_CREDENTIAL_KIND}, "created_at": store.base_time, "updated_at": store.base_time, @@ -437,15 +522,56 @@ def test_resolve_gmail_access_token_rejects_missing_and_invalid_protected_creden GmailCredentialInvalidError, match=f"gmail account {account_id} has invalid protected credentials", ): - resolve_gmail_access_token(store, gmail_account_id=account_id) + resolve_gmail_access_token(store, secret_manager, gmail_account_id=account_id) + + +def test_resolve_gmail_access_token_externalizes_legacy_db_credentials_on_first_read() -> None: + store = GmailStoreStub() + secret_manager = GmailSecretManagerStub() + account = create_gmail_account_record( + store, + secret_manager, + user_id=uuid4(), + request=GmailAccountConnectInput( + provider_account_id="acct-legacy-001", + email_address="owner@example.com", + display_name="Owner", + scope=GMAIL_READONLY_SCOPE, + access_token="token-1", + ), + )["account"] + account_id = UUID(account["id"]) + credential_row = store.gmail_account_credentials[account_id] + legacy_blob = { + "credential_kind": GMAIL_PROTECTED_CREDENTIAL_KIND, + "access_token": "token-legacy-001", + } + secret_ref = credential_row["secret_ref"] + assert secret_ref is not None + credential_row["secret_manager_kind"] = GMAIL_SECRET_MANAGER_KIND_LEGACY_DB_V0 + credential_row["secret_ref"] = None + credential_row["credential_blob"] = legacy_blob + secret_manager.secrets.pop(secret_ref) + + assert resolve_gmail_access_token(store, secret_manager, gmail_account_id=account_id) == ( + "token-legacy-001" + ) + assert store.gmail_account_credentials[account_id]["secret_manager_kind"] == ( + GMAIL_SECRET_MANAGER_KIND_FILE_V1 + ) + assert store.gmail_account_credentials[account_id]["secret_ref"] == secret_ref + assert store.gmail_account_credentials[account_id]["credential_blob"] is None + assert secret_manager.secrets[secret_ref] == legacy_blob def test_resolve_gmail_access_token_renews_expired_refreshable_credential(monkeypatch) -> None: store = GmailStoreStub() + secret_manager = GmailSecretManagerStub() expired_at = datetime(2020, 1, 1, 0, 0, tzinfo=UTC) refreshed_at = datetime(2030, 1, 1, 0, 5, tzinfo=UTC) account = create_gmail_account_record( store, + secret_manager, user_id=uuid4(), 
request=GmailAccountConnectInput( provider_account_id="acct-001", @@ -469,8 +595,10 @@ def test_resolve_gmail_access_token_renews_expired_refreshable_credential(monkey ), ) - assert resolve_gmail_access_token(store, gmail_account_id=account_id) == "token-2" - assert store.gmail_account_credentials[account_id]["credential_blob"] == { + secret_ref = store.gmail_account_credentials[account_id]["secret_ref"] + assert resolve_gmail_access_token(store, secret_manager, gmail_account_id=account_id) == "token-2" + assert secret_ref is not None + assert secret_manager.secrets[secret_ref] == { "credential_kind": GMAIL_REFRESHABLE_PROTECTED_CREDENTIAL_KIND, "access_token": "token-2", "refresh_token": "refresh-1", @@ -486,10 +614,12 @@ def test_resolve_gmail_access_token_renews_expired_refreshable_credential(monkey def test_resolve_gmail_access_token_persists_rotated_refresh_token(monkeypatch) -> None: store = GmailStoreStub() + secret_manager = GmailSecretManagerStub() expired_at = datetime(2020, 1, 1, 0, 0, tzinfo=UTC) refreshed_at = datetime(2030, 1, 1, 0, 5, tzinfo=UTC) account = create_gmail_account_record( store, + secret_manager, user_id=uuid4(), request=GmailAccountConnectInput( provider_account_id="acct-001", @@ -514,8 +644,10 @@ def test_resolve_gmail_access_token_persists_rotated_refresh_token(monkeypatch) ), ) - assert resolve_gmail_access_token(store, gmail_account_id=account_id) == "token-2" - assert store.gmail_account_credentials[account_id]["credential_blob"] == { + secret_ref = store.gmail_account_credentials[account_id]["secret_ref"] + assert resolve_gmail_access_token(store, secret_manager, gmail_account_id=account_id) == "token-2" + assert secret_ref is not None + assert secret_manager.secrets[secret_ref] == { "credential_kind": GMAIL_REFRESHABLE_PROTECTED_CREDENTIAL_KIND, "access_token": "token-2", "refresh_token": "refresh-2", @@ -529,10 +661,12 @@ def test_resolve_gmail_access_token_fails_deterministically_when_persisting_refr monkeypatch, ) -> None: store = GmailStoreStub() + secret_manager = GmailSecretManagerStub() expired_at = datetime(2020, 1, 1, 0, 0, tzinfo=UTC) refreshed_at = datetime(2030, 1, 1, 0, 5, tzinfo=UTC) account = create_gmail_account_record( store, + secret_manager, user_id=uuid4(), request=GmailAccountConnectInput( provider_account_id="acct-001", @@ -547,7 +681,9 @@ def test_resolve_gmail_access_token_fails_deterministically_when_persisting_refr ), )["account"] account_id = UUID(account["id"]) - original_blob = dict(store.gmail_account_credentials[account_id]["credential_blob"]) + secret_ref = store.gmail_account_credentials[account_id]["secret_ref"] + assert secret_ref is not None + original_blob = dict(secret_manager.secrets[secret_ref]) monkeypatch.setattr( "alicebot_api.gmail.refresh_gmail_access_token", @@ -567,15 +703,17 @@ def fail_update(**_kwargs): GmailCredentialPersistenceError, match=f"gmail account {account_id} renewed protected credentials could not be persisted", ): - resolve_gmail_access_token(store, gmail_account_id=account_id) + resolve_gmail_access_token(store, secret_manager, gmail_account_id=account_id) - assert store.gmail_account_credentials[account_id]["credential_blob"] == original_blob + assert secret_manager.secrets[secret_ref] == original_blob def test_resolve_gmail_access_token_rejects_invalid_refreshable_protected_credentials() -> None: store = GmailStoreStub() + secret_manager = GmailSecretManagerStub() account = create_gmail_account_record( store, + secret_manager, user_id=uuid4(), request=GmailAccountConnectInput( 
provider_account_id="acct-001", @@ -586,7 +724,9 @@ def test_resolve_gmail_access_token_rejects_invalid_refreshable_protected_creden ), )["account"] account_id = UUID(account["id"]) - store.gmail_account_credentials[account_id]["credential_blob"] = { + secret_ref = store.gmail_account_credentials[account_id]["secret_ref"] + assert secret_ref is not None + secret_manager.secrets[secret_ref] = { "credential_kind": GMAIL_REFRESHABLE_PROTECTED_CREDENTIAL_KIND, "access_token": "token-1", "client_id": "client-1", @@ -598,7 +738,7 @@ def test_resolve_gmail_access_token_rejects_invalid_refreshable_protected_creden GmailCredentialInvalidError, match=f"gmail account {account_id} has invalid protected credentials", ): - resolve_gmail_access_token(store, gmail_account_id=account_id) + resolve_gmail_access_token(store, secret_manager, gmail_account_id=account_id) def test_ingest_gmail_message_record_writes_rfc822_artifact_and_reuses_artifact_seam( @@ -606,6 +746,7 @@ def test_ingest_gmail_message_record_writes_rfc822_artifact_and_reuses_artifact_ tmp_path, ) -> None: store = GmailStoreStub() + secret_manager = GmailSecretManagerStub() user_id = uuid4() workspace_id = uuid4() workspace = store.create_task_workspace( @@ -614,6 +755,7 @@ def test_ingest_gmail_message_record_writes_rfc822_artifact_and_reuses_artifact_ ) account = create_gmail_account_record( store, + secret_manager, user_id=user_id, request=GmailAccountConnectInput( provider_account_id="acct-001", @@ -683,6 +825,7 @@ def fake_ingest(_store, *, user_id: UUID, request): response = ingest_gmail_message_record( store, + secret_manager, user_id=user_id, request=GmailMessageIngestInput( gmail_account_id=UUID(account["id"]), @@ -730,6 +873,7 @@ def test_ingest_gmail_message_record_renews_expired_access_token_before_fetch( tmp_path, ) -> None: store = GmailStoreStub() + secret_manager = GmailSecretManagerStub() user_id = uuid4() workspace_id = uuid4() workspace = store.create_task_workspace( @@ -738,6 +882,7 @@ def test_ingest_gmail_message_record_renews_expired_access_token_before_fetch( ) account = create_gmail_account_record( store, + secret_manager, user_id=user_id, request=GmailAccountConnectInput( provider_account_id="acct-001", @@ -815,6 +960,7 @@ def fake_fetch(**kwargs): response = ingest_gmail_message_record( store, + secret_manager, user_id=user_id, request=GmailMessageIngestInput( gmail_account_id=UUID(account["id"]), @@ -834,6 +980,7 @@ def fake_fetch(**kwargs): def test_ingest_gmail_message_record_rejects_unsupported_message(monkeypatch, tmp_path) -> None: store = GmailStoreStub() + secret_manager = GmailSecretManagerStub() user_id = uuid4() workspace_id = uuid4() store.create_task_workspace( @@ -842,6 +989,7 @@ def test_ingest_gmail_message_record_rejects_unsupported_message(monkeypatch, tm ) account = create_gmail_account_record( store, + secret_manager, user_id=user_id, request=GmailAccountConnectInput( provider_account_id="acct-001", @@ -863,6 +1011,7 @@ def test_ingest_gmail_message_record_rejects_unsupported_message(monkeypatch, tm ): ingest_gmail_message_record( store, + secret_manager, user_id=user_id, request=GmailMessageIngestInput( gmail_account_id=UUID(account["id"]), @@ -877,6 +1026,7 @@ def test_ingest_gmail_message_record_rejects_duplicate_sanitized_path_before_fet tmp_path, ) -> None: store = GmailStoreStub() + secret_manager = GmailSecretManagerStub() user_id = uuid4() workspace_id = uuid4() workspace_path = (tmp_path / "workspace").resolve() @@ -886,6 +1036,7 @@ def 
test_ingest_gmail_message_record_rejects_duplicate_sanitized_path_before_fet ) account = create_gmail_account_record( store, + secret_manager, user_id=user_id, request=GmailAccountConnectInput( provider_account_id="acct-001", @@ -918,6 +1069,7 @@ def fail_fetch(**_kwargs): ): ingest_gmail_message_record( store, + secret_manager, user_id=user_id, request=GmailMessageIngestInput( gmail_account_id=UUID(account["id"]), @@ -936,9 +1088,11 @@ def fail_fetch(**_kwargs): def test_ingest_gmail_message_record_requires_visible_workspace(monkeypatch) -> None: store = GmailStoreStub() + secret_manager = GmailSecretManagerStub() user_id = uuid4() account = create_gmail_account_record( store, + secret_manager, user_id=user_id, request=GmailAccountConnectInput( provider_account_id="acct-001", @@ -957,6 +1111,7 @@ def test_ingest_gmail_message_record_requires_visible_workspace(monkeypatch) -> with pytest.raises(TaskWorkspaceNotFoundError, match="task workspace .* was not found"): ingest_gmail_message_record( store, + secret_manager, user_id=user_id, request=GmailMessageIngestInput( gmail_account_id=UUID(account["id"]), @@ -971,6 +1126,7 @@ def test_ingest_gmail_message_record_rejects_missing_protected_credentials_befor tmp_path, ) -> None: store = GmailStoreStub() + secret_manager = GmailSecretManagerStub() user_id = uuid4() workspace_id = uuid4() workspace_path = (tmp_path / "workspace").resolve() @@ -980,6 +1136,7 @@ def test_ingest_gmail_message_record_rejects_missing_protected_credentials_befor ) account = create_gmail_account_record( store, + secret_manager, user_id=user_id, request=GmailAccountConnectInput( provider_account_id="acct-001", @@ -1003,6 +1160,7 @@ def fail_fetch(**_kwargs): ): ingest_gmail_message_record( store, + secret_manager, user_id=user_id, request=GmailMessageIngestInput( gmail_account_id=account_id, @@ -1021,6 +1179,7 @@ def test_ingest_gmail_message_record_rejects_invalid_protected_credentials_befor tmp_path, ) -> None: store = GmailStoreStub() + secret_manager = GmailSecretManagerStub() user_id = uuid4() workspace_id = uuid4() workspace_path = (tmp_path / "workspace").resolve() @@ -1030,6 +1189,7 @@ def test_ingest_gmail_message_record_rejects_invalid_protected_credentials_befor ) account = create_gmail_account_record( store, + secret_manager, user_id=user_id, request=GmailAccountConnectInput( provider_account_id="acct-001", @@ -1040,7 +1200,9 @@ def test_ingest_gmail_message_record_rejects_invalid_protected_credentials_befor ), )["account"] account_id = UUID(account["id"]) - store.gmail_account_credentials[account_id]["credential_blob"] = { + secret_ref = store.gmail_account_credentials[account_id]["secret_ref"] + assert secret_ref is not None + secret_manager.secrets[secret_ref] = { "credential_kind": GMAIL_PROTECTED_CREDENTIAL_KIND, "access_token": "", } @@ -1056,6 +1218,7 @@ def fail_fetch(**_kwargs): ): ingest_gmail_message_record( store, + secret_manager, user_id=user_id, request=GmailMessageIngestInput( gmail_account_id=account_id, diff --git a/tests/unit/test_gmail_main.py b/tests/unit/test_gmail_main.py index 2eaf011..1ee6229 100644 --- a/tests/unit/test_gmail_main.py +++ b/tests/unit/test_gmail_main.py @@ -22,9 +22,16 @@ from alicebot_api.workspaces import TaskWorkspaceNotFoundError +def _settings() -> Settings: + return Settings( + database_url="postgresql://app", + gmail_secret_manager_url="file:///tmp/test-gmail-secrets", + ) + + def test_list_gmail_accounts_endpoint_returns_payload(monkeypatch) -> None: user_id = uuid4() - settings = 
Settings(database_url="postgresql://app") + settings = _settings() @contextmanager def fake_user_connection(*_args, **_kwargs): @@ -52,7 +59,7 @@ def fake_user_connection(*_args, **_kwargs): def test_connect_gmail_account_endpoint_maps_duplicate_to_409(monkeypatch) -> None: user_id = uuid4() - settings = Settings(database_url="postgresql://app") + settings = _settings() @contextmanager def fake_user_connection(*_args, **_kwargs): @@ -95,7 +102,7 @@ def test_connect_gmail_account_request_requires_complete_refresh_bundle() -> Non def test_connect_gmail_account_endpoint_maps_invalid_refresh_bundle_to_400(monkeypatch) -> None: user_id = uuid4() - settings = Settings(database_url="postgresql://app") + settings = _settings() @contextmanager def fake_user_connection(*_args, **_kwargs): @@ -130,10 +137,41 @@ def fake_create_gmail_account_record(*_args, **_kwargs): } +def test_connect_gmail_account_endpoint_maps_secret_persistence_failure_to_409(monkeypatch) -> None: + user_id = uuid4() + settings = _settings() + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_create_gmail_account_record(*_args, **_kwargs): + raise GmailCredentialPersistenceError("gmail protected credentials could not be persisted") + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "create_gmail_account_record", fake_create_gmail_account_record) + + response = main_module.connect_gmail_account( + main_module.ConnectGmailAccountRequest( + user_id=user_id, + provider_account_id="acct-001", + email_address="owner@example.com", + display_name="Owner", + access_token="token-1", + ) + ) + + assert response.status_code == 409 + assert json.loads(response.body) == { + "detail": "gmail protected credentials could not be persisted" + } + + def test_get_gmail_account_endpoint_maps_not_found_to_404(monkeypatch) -> None: user_id = uuid4() gmail_account_id = uuid4() - settings = Settings(database_url="postgresql://app") + settings = _settings() @contextmanager def fake_user_connection(*_args, **_kwargs): @@ -156,7 +194,7 @@ def test_ingest_gmail_message_endpoint_maps_workspace_not_found_to_404(monkeypat user_id = uuid4() gmail_account_id = uuid4() task_workspace_id = uuid4() - settings = Settings(database_url="postgresql://app") + settings = _settings() @contextmanager def fake_user_connection(*_args, **_kwargs): @@ -188,7 +226,7 @@ def test_ingest_gmail_message_endpoint_maps_upstream_errors(monkeypatch) -> None user_id = uuid4() gmail_account_id = uuid4() task_workspace_id = uuid4() - settings = Settings(database_url="postgresql://app") + settings = _settings() @contextmanager def fake_user_connection(*_args, **_kwargs): diff --git a/tests/unit/test_gmail_secret_manager.py b/tests/unit/test_gmail_secret_manager.py new file mode 100644 index 0000000..387537c --- /dev/null +++ b/tests/unit/test_gmail_secret_manager.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +import pytest + +from alicebot_api.gmail_secret_manager import ( + GmailSecretManagerError, + build_gmail_secret_manager, +) + + +def test_build_gmail_secret_manager_rejects_non_file_schemes() -> None: + with pytest.raises(ValueError, match="GMAIL_SECRET_MANAGER_URL must use the file:// scheme"): + build_gmail_secret_manager("memory://gmail-secrets") + + +def test_build_gmail_secret_manager_requires_explicit_configuration() -> None: + with pytest.raises(ValueError, match="GMAIL_SECRET_MANAGER_URL must 
be configured"): + build_gmail_secret_manager("") + + +def test_file_gmail_secret_manager_round_trips_secret_payload(tmp_path) -> None: + manager = build_gmail_secret_manager(tmp_path.resolve().as_uri()) + secret_ref = "users/00000000-0000-0000-0000-000000000001/gmail-account-credentials/cred.json" + payload = { + "credential_kind": "gmail_oauth_access_token_v1", + "access_token": "token-001", + } + + manager.write_secret(secret_ref=secret_ref, payload=payload) + + assert manager.load_secret(secret_ref=secret_ref) == payload + + +def test_file_gmail_secret_manager_rejects_missing_or_escaped_refs(tmp_path) -> None: + manager = build_gmail_secret_manager(tmp_path.resolve().as_uri()) + + with pytest.raises(GmailSecretManagerError, match="was not found"): + manager.load_secret(secret_ref="users/u/gmail-account-credentials/missing.json") + + with pytest.raises(GmailSecretManagerError, match="outside the configured root"): + manager.write_secret( + secret_ref="../../escape.json", + payload={"credential_kind": "gmail_oauth_access_token_v1", "access_token": "token"}, + ) From ae918518c633f3514d587c53a4015042940feb28 Mon Sep 17 00:00:00 2001 From: Sam Date: Tue, 17 Mar 2026 00:16:19 +0100 Subject: [PATCH 021/135] Sprint 5U: project truth sync after Gmail secret externalization (#20) Co-authored-by: Sami Rusani --- ROADMAP.md | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/ROADMAP.md b/ROADMAP.md index 936fca2..e1a83f4 100644 --- a/ROADMAP.md +++ b/ROADMAP.md @@ -2,18 +2,18 @@ ## Current Position -- The accepted repo state is current through Sprint 5R. +- The accepted repo state is current through Sprint 5T. - Milestone 5 now ships the rooted local workspace and artifact baseline end to end: workspace provisioning, artifact registration, narrow text ingestion, narrow PDF/DOCX/RFC822 ingestion, durable chunk storage, lexical artifact retrieval, compile-path artifact inclusion, artifact-chunk embeddings, direct semantic artifact retrieval, compile-path semantic artifact retrieval, and deterministic hybrid lexical-plus-semantic artifact merge in compile. -- The same milestone also now ships the narrow Gmail seam: read-only Gmail account persistence, secret-free account reads, protected credential storage in `gmail_account_credentials`, refresh-token renewal for expired access tokens, rotated refresh-token persistence when the provider returns a replacement token, and one explicit selected-message ingestion path that lands in the existing RFC822 artifact pipeline. +- The same milestone also now ships the narrow Gmail seam: read-only Gmail account persistence, secret-free account reads, external-secret-backed primary credential storage with locator metadata in `gmail_account_credentials`, refresh-token renewal for expired access tokens through the externalized seam, rotated refresh-token persistence when the provider returns a replacement token, one narrow `legacy_db_v0` transition path for older rows, and one explicit selected-message ingestion path that lands in the existing RFC822 artifact pipeline. - This roadmap is future-facing from that shipped baseline; historical sprint-by-sprint detail lives in accepted build and review artifacts, not here. 
## Next Delivery Focus -### Open One More Narrow Gmail Auth Seam On Top Of The Shipped Baseline +### Remove The Remaining Narrow Gmail Transition Path Without Widening Scope -- Keep the next sprint auth-adjacent and narrow, building on the shipped protected-credential-backed Gmail seam rather than widening connector breadth. -- The next best seam is external secret-manager integration for the existing `gmail_account_credentials` boundary, without changing the read-only account contract or the single-message ingestion contract. -- Do not combine that work with Gmail search, mailbox sync, attachment ingestion, Calendar scope, UI work, or broader connector orchestration. +- Keep the next sprint auth-adjacent and narrow, building on the shipped external-secret-backed Gmail seam rather than widening connector breadth. +- The next best seam is deliberate cleanup of the remaining `legacy_db_v0` transition boundary for older `gmail_account_credentials` rows, without changing the read-only account contract or the single-message ingestion contract. +- Do not combine that cleanup with Gmail search, mailbox sync, attachment ingestion, Calendar scope, UI work, or broader connector orchestration. ### Preserve Current Document, Compile, Governance, And Task Guarantees @@ -24,7 +24,7 @@ ## After The Next Narrow Sprint -- Reassess broader connector work only after the current Gmail protected-credential boundary remains stable under externalized secret storage and the truth artifacts stay synchronized. +- Reassess broader connector work only after the current Gmail external-secret-backed credential boundary remains stable without the `legacy_db_v0` transition path and the truth artifacts stay synchronized. - Revisit workflow UI only after backend document and connector seams are accepted and the truth artifacts stay current. - Revisit broader task orchestration only after the current explicit task-step seams remain stable under workspace, artifact, document, and connector flows. - Continue to defer broader tool execution breadth and production auth/deployment hardening until the current governed surface remains stable. @@ -33,11 +33,11 @@ - Live truth docs must stay synchronized with accepted repo state so sprint planning does not start from stale assumptions. - Rich document parsing work should continue to build on the shipped rooted local workspace, durable artifact chunk, and hybrid compile retrieval contracts. -- Connector work should remain read-only, single-message-only, approval-aware, and protected-credential-backed until a later sprint explicitly opens broader scope. +- Connector work should remain read-only, single-message-only, approval-aware, and external-secret-backed until a later sprint explicitly opens broader scope. - Runner-style orchestration should stay deferred until the repo no longer depends on narrow current-step assumptions for safety and explainability. ## Ongoing Risks - Memory extraction and retrieval quality remain the largest product risk. - Auth beyond database user context is still missing. -- Milestone 5 can drift if Gmail auth hardening, broader connector breadth, UI, richer parsing, and orchestration work are mixed into one sprint instead of landing as narrow seams on top of the shipped document-ingestion and Gmail baseline. 
+- Milestone 5 can drift if `legacy_db_v0` cleanup, broader Gmail breadth, UI, richer parsing, and orchestration work are mixed into one sprint instead of landing as narrow seams on top of the shipped document-ingestion and Gmail externalization baseline. From 9c1da9cdf1591900aff79c33e77145c197f09880 Mon Sep 17 00:00:00 2001 From: Sam Date: Tue, 17 Mar 2026 08:31:39 +0100 Subject: [PATCH 022/135] Sprint 6A: MVP web shell and core operator views (#21) Co-authored-by: Sami Rusani --- apps/web/app/approvals/page.tsx | 193 +++++ apps/web/app/chat/page.tsx | 93 +++ apps/web/app/globals.css | 931 +++++++++++++++++++++++ apps/web/app/layout.tsx | 19 +- apps/web/app/page.tsx | 182 +++-- apps/web/app/tasks/page.tsx | 318 ++++++++ apps/web/app/traces/page.tsx | 173 +++++ apps/web/components/app-shell.tsx | 119 +++ apps/web/components/approval-list.tsx | 212 ++++++ apps/web/components/empty-state.tsx | 22 + apps/web/components/page-header.tsx | 21 + apps/web/components/request-composer.tsx | 230 ++++++ apps/web/components/section-card.tsx | 30 + apps/web/components/status-badge.tsx | 45 ++ apps/web/components/task-list.tsx | 113 +++ apps/web/components/task-step-list.tsx | 117 +++ apps/web/components/trace-list.tsx | 165 ++++ 17 files changed, 2936 insertions(+), 47 deletions(-) create mode 100644 apps/web/app/approvals/page.tsx create mode 100644 apps/web/app/chat/page.tsx create mode 100644 apps/web/app/globals.css create mode 100644 apps/web/app/tasks/page.tsx create mode 100644 apps/web/app/traces/page.tsx create mode 100644 apps/web/components/app-shell.tsx create mode 100644 apps/web/components/approval-list.tsx create mode 100644 apps/web/components/empty-state.tsx create mode 100644 apps/web/components/page-header.tsx create mode 100644 apps/web/components/request-composer.tsx create mode 100644 apps/web/components/section-card.tsx create mode 100644 apps/web/components/status-badge.tsx create mode 100644 apps/web/components/task-list.tsx create mode 100644 apps/web/components/task-step-list.tsx create mode 100644 apps/web/components/trace-list.tsx diff --git a/apps/web/app/approvals/page.tsx b/apps/web/app/approvals/page.tsx new file mode 100644 index 0000000..f6b7db6 --- /dev/null +++ b/apps/web/app/approvals/page.tsx @@ -0,0 +1,193 @@ +import { ApprovalList, type ApprovalItem } from "../../components/approval-list"; +import { PageHeader } from "../../components/page-header"; + +const approvalFixtures: ApprovalItem[] = [ + { + id: "approval-101", + thread_id: "thread-magnesium", + task_step_id: "step-21", + status: "pending", + request: { + thread_id: "thread-magnesium", + tool_id: "tool-purchase", + action: "place_order", + scope: "supplements", + domain_hint: "ecommerce", + risk_hint: "purchase", + attributes: { + merchant: "Thorne", + item: "Magnesium Bisglycinate", + quantity: "1", + budget_note: "Prefer previously approved merchant and package size.", + }, + }, + tool: { + id: "tool-purchase", + tool_key: "merchant_proxy", + name: "Merchant Proxy", + description: "Proxy for governed ecommerce actions.", + version: "0.1.0", + metadata_version: "tool_metadata_v0", + active: true, + tags: ["commerce", "approval"], + action_hints: ["place_order"], + scope_hints: ["supplements"], + domain_hints: ["ecommerce"], + risk_hints: ["purchase"], + metadata: {}, + created_at: "2026-03-15T08:00:00Z", + }, + routing: { + decision: "require_approval", + reasons: [ + { + code: "policy_effect_require_approval", + source: "policy", + message: "Purchases require explicit user approval before execution.", + 
tool_id: "tool-purchase", + policy_id: "policy-purchase-approval", + consent_key: null, + }, + { + code: "tool_metadata_matched", + source: "tool", + message: "Merchant proxy supports the requested purchase scope.", + tool_id: "tool-purchase", + policy_id: null, + consent_key: null, + }, + ], + trace: { + trace_id: "trace-approval-101", + trace_event_count: 6, + }, + }, + created_at: "2026-03-17T06:50:00Z", + resolution: null, + }, + { + id: "approval-100", + thread_id: "thread-vitamin-d", + task_step_id: "step-14", + status: "approved", + request: { + thread_id: "thread-vitamin-d", + tool_id: "tool-purchase", + action: "place_order", + scope: "supplements", + domain_hint: "ecommerce", + risk_hint: "purchase", + attributes: { + merchant: "Fullscript", + item: "Vitamin D3 + K2", + quantity: "1", + note: "Matched prior merchant and approved dosage plan.", + }, + }, + tool: { + id: "tool-purchase", + tool_key: "merchant_proxy", + name: "Merchant Proxy", + description: "Proxy for governed ecommerce actions.", + version: "0.1.0", + metadata_version: "tool_metadata_v0", + active: true, + tags: ["commerce", "approval"], + action_hints: ["place_order"], + scope_hints: ["supplements"], + domain_hints: ["ecommerce"], + risk_hints: ["purchase"], + metadata: {}, + created_at: "2026-03-14T09:15:00Z", + }, + routing: { + decision: "require_approval", + reasons: [ + { + code: "matched_policy", + source: "policy", + message: "Repeat supplement purchases remain approval-gated even when the merchant and dosage are known.", + tool_id: "tool-purchase", + policy_id: "policy-purchase-approval", + consent_key: null, + }, + ], + trace: { + trace_id: "trace-approval-100", + trace_event_count: 5, + }, + }, + created_at: "2026-03-16T14:10:00Z", + resolution: { + resolved_at: "2026-03-16T14:22:00Z", + resolved_by_user_id: "operator-1", + }, + }, +]; + +function getApiConfig() { + return { + apiBaseUrl: + process.env.NEXT_PUBLIC_ALICEBOT_API_BASE_URL ?? process.env.ALICEBOT_API_BASE_URL ?? "", + userId: process.env.NEXT_PUBLIC_ALICEBOT_USER_ID ?? process.env.ALICEBOT_USER_ID ?? "", + }; +} + +async function loadApprovals(): Promise<{ items: ApprovalItem[]; source: "live" | "fixture" }> { + const { apiBaseUrl, userId } = getApiConfig(); + if (!apiBaseUrl || !userId) { + return { items: approvalFixtures, source: "fixture" }; + } + + try { + const response = await fetch( + `${apiBaseUrl.replace(/\/$/, "")}/v0/approvals?user_id=${encodeURIComponent(userId)}`, + { cache: "no-store" }, + ); + + if (!response.ok) { + throw new Error("approval list request failed"); + } + + const payload = (await response.json()) as { items?: ApprovalItem[] }; + return { + items: payload.items ?? approvalFixtures, + source: "live", + }; + } catch { + return { items: approvalFixtures, source: "fixture" }; + } +} + +type SearchParams = Promise>; + +export default async function ApprovalsPage({ + searchParams, +}: { + searchParams?: SearchParams; +}) { + const params = (searchParams ? await searchParams : {}) as Record< + string, + string | string[] | undefined + >; + const selectedId = typeof params.approval === "string" ? params.approval : undefined; + const { items, source } = await loadApprovals(); + + return ( +
+ + {source === "live" ? "Live API" : "Fixture-backed"} + {items.length} items +
+ } + /> + + + + ); +} diff --git a/apps/web/app/chat/page.tsx b/apps/web/app/chat/page.tsx new file mode 100644 index 0000000..6dc68a6 --- /dev/null +++ b/apps/web/app/chat/page.tsx @@ -0,0 +1,93 @@ +import { PageHeader } from "../../components/page-header"; +import { RequestComposer, type RequestHistoryEntry } from "../../components/request-composer"; +import { SectionCard } from "../../components/section-card"; + +const initialEntries: RequestHistoryEntry[] = [ + { + id: "req-001", + request: "Summarize the open magnesium reorder task and tell me whether an approval is still required.", + response: + "The current task remains in a governed state. The latest task step is waiting on approval resolution before any execution can proceed, and the next operator action is to review the approval inbox rather than trigger another tool call.", + submittedAt: "2026-03-17T08:45:00Z", + source: "fixture", + trace: { + compileTraceId: "trace-ctx-401", + compileTraceEventCount: 9, + responseTraceId: "trace-resp-402", + responseTraceEventCount: 4, + }, + }, +]; + +function getApiConfig() { + return { + apiBaseUrl: + process.env.NEXT_PUBLIC_ALICEBOT_API_BASE_URL ?? process.env.ALICEBOT_API_BASE_URL ?? "", + userId: process.env.NEXT_PUBLIC_ALICEBOT_USER_ID ?? process.env.ALICEBOT_USER_ID ?? "", + threadId: process.env.NEXT_PUBLIC_ALICEBOT_THREAD_ID ?? process.env.ALICEBOT_THREAD_ID ?? "", + }; +} + +export default function ChatPage() { + const apiConfig = getApiConfig(); + const liveModeReady = Boolean(apiConfig.apiBaseUrl && apiConfig.userId && apiConfig.threadId); + + return ( +
+ + {liveModeReady ? "Live API mode" : "Fixture mode"} + Response traces visible +
+ } + /> + +
+ + +
+ +
    +
  • Requests are framed as operator instructions against existing governed seams.
  • +
  • Live mode posts to the shipped response endpoint only when API configuration is present.
  • +
  • Trace references stay attached to each recent response so explainability remains first-class.
  • +
+
+ + +
+
+
Sessions
+
Up to 8 recent sessions
+
+
+
Events
+
Up to 80 continuity events
+
+
+
Memories
+
Up to 20 admitted memories
+
+
+
Entities
+
Up to 12 entities and 20 edges
+
+
+
+
+
+ + ); +} diff --git a/apps/web/app/globals.css b/apps/web/app/globals.css new file mode 100644 index 0000000..2905c41 --- /dev/null +++ b/apps/web/app/globals.css @@ -0,0 +1,931 @@ +:root { + --font-sans: "Avenir Next", "Segoe UI", "Helvetica Neue", sans-serif; + --font-serif: "Iowan Old Style", "Palatino Linotype", "Book Antiqua", Georgia, serif; + --bg: #f3efe8; + --bg-accent: rgba(68, 88, 112, 0.12); + --surface: rgba(255, 252, 248, 0.88); + --surface-strong: rgba(255, 255, 255, 0.94); + --surface-muted: rgba(246, 240, 232, 0.82); + --border: rgba(42, 52, 66, 0.12); + --border-strong: rgba(42, 52, 66, 0.18); + --text: #18202a; + --text-soft: #566172; + --text-muted: #707988; + --accent: #274b63; + --accent-soft: rgba(39, 75, 99, 0.09); + --success: #2c6e62; + --success-soft: rgba(44, 110, 98, 0.1); + --warning: #8e6220; + --warning-soft: rgba(142, 98, 32, 0.11); + --danger: #8d4440; + --danger-soft: rgba(141, 68, 64, 0.1); + --info: #365d7c; + --info-soft: rgba(54, 93, 124, 0.11); + --shadow-lg: 0 22px 70px rgba(32, 43, 56, 0.09); + --shadow-md: 0 14px 36px rgba(32, 43, 56, 0.06); + --radius-xl: 28px; + --radius-lg: 22px; + --radius-md: 16px; + --radius-sm: 12px; + --content-width: 1360px; +} + +* { + box-sizing: border-box; + min-width: 0; +} + +html { + background: + radial-gradient(circle at top left, rgba(211, 221, 232, 0.5), transparent 28%), + radial-gradient(circle at top right, rgba(233, 220, 205, 0.55), transparent 24%), + var(--bg); + color: var(--text); +} + +body { + margin: 0; + font-family: var(--font-sans); + color: var(--text); + min-height: 100vh; +} + +a { + color: inherit; + text-decoration: none; +} + +button, +input, +textarea, +select { + font: inherit; +} + +button { + cursor: pointer; +} + +img, +svg { + display: block; + max-width: 100%; +} + +code, +.mono { + font-family: + "SFMono-Regular", "SF Mono", "JetBrains Mono", "Roboto Mono", "Menlo", monospace; +} + +.shell-chrome { + position: relative; + min-height: 100vh; +} + +.shell-chrome::before { + content: ""; + position: fixed; + inset: 0; + background: + linear-gradient(180deg, rgba(255, 255, 255, 0.35), transparent 32%), + radial-gradient(circle at 15% 20%, rgba(39, 75, 99, 0.08), transparent 24%); + pointer-events: none; +} + +.shell { + position: relative; + z-index: 1; + max-width: var(--content-width); + margin: 0 auto; + padding: 24px; + display: grid; + grid-template-columns: 280px minmax(0, 1fr); + gap: 24px; +} + +.shell-sidebar { + position: sticky; + top: 24px; + align-self: start; + display: grid; + gap: 18px; + padding: 22px; + background: rgba(255, 252, 248, 0.72); + border: 1px solid var(--border); + border-radius: 30px; + box-shadow: var(--shadow-md); + backdrop-filter: blur(16px); +} + +.brand-mark { + display: inline-grid; + place-items: center; + width: 42px; + height: 42px; + border-radius: 14px; + background: linear-gradient(180deg, rgba(39, 75, 99, 0.16), rgba(39, 75, 99, 0.08)); + border: 1px solid rgba(39, 75, 99, 0.16); + color: var(--accent); + font-weight: 700; + letter-spacing: 0.08em; +} + +.brand-copy { + display: grid; + gap: 6px; +} + +.eyebrow { + margin: 0; + font-size: 0.7rem; + letter-spacing: 0.16em; + text-transform: uppercase; + color: var(--text-muted); +} + +.brand-title { + margin: 0; + font-family: var(--font-serif); + font-size: 1.5rem; + font-weight: 600; + letter-spacing: -0.02em; +} + +.brand-description, +.muted-copy { + margin: 0; + color: var(--text-soft); + line-height: 1.6; +} + +.shell-nav, +.shell-nav--mobile { + display: grid; + gap: 10px; +} + 
+.shell-nav__item { + display: grid; + gap: 4px; + padding: 14px 16px; + border-radius: 16px; + border: 1px solid transparent; + transition: + border-color 140ms ease, + background-color 140ms ease, + transform 140ms ease; +} + +.shell-nav__item:hover, +.shell-nav__item:focus-visible { + border-color: var(--border); + background: rgba(255, 255, 255, 0.6); + transform: translateY(-1px); +} + +.shell-nav__item.is-active { + border-color: rgba(39, 75, 99, 0.18); + background: var(--accent-soft); +} + +.shell-nav__title { + font-size: 0.95rem; + font-weight: 600; +} + +.shell-nav__caption { + color: var(--text-soft); + font-size: 0.86rem; + line-height: 1.45; +} + +.shell-note { + padding: 16px; + border-radius: 18px; + background: rgba(244, 238, 230, 0.8); + border: 1px solid var(--border); +} + +.shell-note__title { + margin: 0 0 8px; + font-size: 0.88rem; + font-weight: 600; +} + +.shell-column { + display: grid; + gap: 22px; +} + +.shell-topbar { + display: grid; + gap: 18px; + padding: 22px 24px; + background: rgba(255, 252, 248, 0.72); + border: 1px solid var(--border); + border-radius: 30px; + box-shadow: var(--shadow-md); + backdrop-filter: blur(16px); +} + +.shell-topbar__row { + display: flex; + align-items: center; + justify-content: space-between; + gap: 16px; +} + +.shell-topbar__title { + margin: 0; + font-family: var(--font-serif); + font-size: clamp(1.6rem, 3vw, 2rem); + letter-spacing: -0.03em; +} + +.topbar-status { + display: flex; + flex-wrap: wrap; + gap: 10px; +} + +.subtle-chip { + display: inline-flex; + align-items: center; + gap: 8px; + padding: 9px 12px; + border-radius: 999px; + background: rgba(255, 255, 255, 0.74); + border: 1px solid var(--border); + color: var(--text-soft); + font-size: 0.82rem; + line-height: 1; + white-space: nowrap; +} + +.shell-main { + padding-bottom: 24px; +} + +.content-frame { + display: grid; + gap: 24px; +} + +.page-stack, +.stack { + display: grid; + gap: 24px; +} + +.page-header { + display: flex; + flex-wrap: wrap; + align-items: flex-end; + justify-content: space-between; + gap: 18px 24px; +} + +.page-header__copy { + display: grid; + gap: 10px; + max-width: 860px; +} + +.page-header h1 { + margin: 0; + font-family: var(--font-serif); + font-size: clamp(2rem, 4vw, 3rem); + line-height: 1.05; + letter-spacing: -0.04em; +} + +.page-header p { + margin: 0; + color: var(--text-soft); + line-height: 1.7; + max-width: 74ch; +} + +.header-meta { + display: flex; + flex-wrap: wrap; + gap: 10px; +} + +.content-grid, +.dashboard-grid { + display: grid; + gap: 24px; +} + +.content-grid--wide { + grid-template-columns: minmax(0, 1.55fr) minmax(300px, 0.9fr); + align-items: flex-start; +} + +.dashboard-grid--detail { + grid-template-columns: minmax(320px, 0.95fr) minmax(0, 1.25fr); + align-items: flex-start; +} + +.metric-grid, +.route-grid { + display: grid; + gap: 18px; +} + +.metric-grid { + grid-template-columns: repeat(4, minmax(0, 1fr)); +} + +.route-grid { + grid-template-columns: repeat(2, minmax(0, 1fr)); +} + +.section-card { + display: grid; + gap: 18px; + padding: 24px; + background: var(--surface); + border: 1px solid var(--border); + border-radius: var(--radius-xl); + box-shadow: var(--shadow-md); + backdrop-filter: blur(14px); +} + +.section-card--metric { + gap: 10px; + min-height: 100%; +} + +.section-card__header { + display: grid; + gap: 8px; +} + +.section-card__title { + margin: 0; + font-size: 1.1rem; + font-weight: 600; + letter-spacing: -0.02em; +} + +.section-card__description { + margin: 0; + color: var(--text-soft); + 
line-height: 1.6; +} + +.metric-value { + font-size: clamp(2rem, 5vw, 2.65rem); + font-weight: 600; + letter-spacing: -0.05em; +} + +.metric-label { + font-size: 0.94rem; + font-weight: 600; +} + +.metric-detail { + margin: 0; + color: var(--text-soft); + line-height: 1.6; +} + +.nav-card { + display: grid; + gap: 12px; + padding: 18px; + border-radius: 20px; + border: 1px solid var(--border); + background: rgba(255, 255, 255, 0.56); + transition: + transform 140ms ease, + border-color 140ms ease, + box-shadow 140ms ease; +} + +.nav-card:hover, +.nav-card:focus-visible { + transform: translateY(-1px); + border-color: var(--border-strong); + box-shadow: 0 14px 32px rgba(32, 43, 56, 0.05); +} + +.nav-card__topline { + display: flex; + justify-content: space-between; + gap: 12px; + align-items: flex-start; +} + +.nav-card__topline h3 { + margin: 0; + font-size: 1rem; + letter-spacing: -0.02em; +} + +.nav-card p { + margin: 0; + color: var(--text-soft); + line-height: 1.6; +} + +.nav-card__cta { + color: var(--accent); + font-size: 0.9rem; + font-weight: 600; +} + +.bullet-list { + margin: 0; + padding-left: 1.2rem; + color: var(--text-soft); + display: grid; + gap: 12px; + line-height: 1.6; +} + +.key-value-grid { + display: grid; + gap: 14px; + grid-template-columns: repeat(2, minmax(0, 1fr)); +} + +.key-value-grid div { + display: grid; + gap: 6px; + padding: 14px 16px; + border-radius: 16px; + background: rgba(255, 255, 255, 0.6); + border: 1px solid rgba(42, 52, 66, 0.08); +} + +.key-value-grid dt { + color: var(--text-muted); + font-size: 0.8rem; + text-transform: uppercase; + letter-spacing: 0.12em; +} + +.key-value-grid dd { + margin: 0; + line-height: 1.55; + overflow-wrap: anywhere; +} + +.status-badge { + display: inline-flex; + align-items: center; + justify-content: center; + padding: 8px 11px; + border-radius: 999px; + border: 1px solid transparent; + font-size: 0.78rem; + font-weight: 600; + letter-spacing: 0.04em; + line-height: 1; + text-transform: uppercase; + white-space: nowrap; +} + +.status-badge--success { + color: var(--success); + background: var(--success-soft); + border-color: rgba(44, 110, 98, 0.18); +} + +.status-badge--warning { + color: var(--warning); + background: var(--warning-soft); + border-color: rgba(142, 98, 32, 0.18); +} + +.status-badge--danger { + color: var(--danger); + background: var(--danger-soft); + border-color: rgba(141, 68, 64, 0.18); +} + +.status-badge--info { + color: var(--info); + background: var(--info-soft); + border-color: rgba(54, 93, 124, 0.18); +} + +.status-badge--neutral { + color: var(--text-soft); + background: rgba(255, 255, 255, 0.72); + border-color: var(--border); +} + +.empty-state { + display: grid; + gap: 12px; + justify-items: start; + padding: 28px; + border-radius: 22px; + background: rgba(255, 255, 255, 0.56); + border: 1px dashed var(--border-strong); +} + +.empty-state__title { + margin: 0; + font-size: 1rem; + font-weight: 600; +} + +.empty-state__description { + margin: 0; + color: var(--text-soft); + line-height: 1.6; +} + +.button, +.button-secondary { + display: inline-flex; + align-items: center; + justify-content: center; + padding: 12px 16px; + border-radius: 14px; + border: 1px solid transparent; + font-weight: 600; + line-height: 1; + transition: + transform 140ms ease, + background-color 140ms ease, + border-color 140ms ease; +} + +.button { + background: var(--accent); + color: #f7fafc; +} + +.button:hover, +.button:focus-visible, +.button-secondary:hover, +.button-secondary:focus-visible { + transform: 
translateY(-1px); +} + +.button:disabled { + opacity: 0.62; + cursor: not-allowed; + transform: none; +} + +.button-secondary { + background: rgba(255, 255, 255, 0.72); + border-color: var(--border); + color: var(--text); +} + +.composer-card { + display: grid; + gap: 24px; + padding: 24px; + background: var(--surface); + border: 1px solid var(--border); + border-radius: var(--radius-xl); + box-shadow: var(--shadow-md); + backdrop-filter: blur(14px); +} + +.composer-card__header, +.detail-stack, +.trace-panel, +.trace-panel__detail, +.list-panel { + display: grid; + gap: 16px; +} + +.governance-banner { + display: flex; + flex-wrap: wrap; + gap: 10px; + align-items: center; + padding: 14px 16px; + background: rgba(39, 75, 99, 0.06); + border: 1px solid rgba(39, 75, 99, 0.12); + border-radius: 18px; + color: var(--text-soft); +} + +.governance-banner strong { + color: var(--text); +} + +.form-field { + display: grid; + gap: 10px; +} + +.form-field label { + font-size: 0.9rem; + font-weight: 600; +} + +.form-field textarea, +.form-field input { + width: 100%; + padding: 16px 18px; + background: rgba(255, 255, 255, 0.74); + border: 1px solid var(--border); + border-radius: 18px; + color: var(--text); + resize: vertical; +} + +.form-field textarea { + min-height: 168px; + line-height: 1.6; +} + +.field-hint { + margin: 0; + color: var(--text-muted); + font-size: 0.86rem; + line-height: 1.5; +} + +.composer-actions { + display: flex; + flex-wrap: wrap; + gap: 12px; + align-items: center; + justify-content: space-between; +} + +.composer-status { + color: var(--text-soft); + font-size: 0.9rem; +} + +.history-list, +.list-rows, +.timeline-list, +.trace-events { + display: grid; + gap: 12px; +} + +.history-entry, +.list-row, +.timeline-item, +.trace-event { + padding: 16px 18px; + border-radius: 18px; + border: 1px solid rgba(42, 52, 66, 0.08); + background: rgba(255, 255, 255, 0.58); +} + +.history-entry { + display: grid; + gap: 14px; +} + +.history-entry__topline, +.list-row__topline, +.timeline-item__topline, +.trace-event__topline, +.detail-summary { + display: flex; + align-items: flex-start; + justify-content: space-between; + gap: 12px; +} + +.history-entry__label, +.list-row__eyebrow, +.detail-summary__label { + color: var(--text-muted); + font-size: 0.8rem; + letter-spacing: 0.12em; + text-transform: uppercase; +} + +.history-entry p, +.list-row p, +.timeline-item p, +.trace-event p { + margin: 0; + color: var(--text-soft); + line-height: 1.6; + overflow-wrap: anywhere; +} + +.history-entry__trace, +.cluster { + display: flex; + flex-wrap: wrap; + gap: 10px; +} + +.split-layout { + display: grid; + gap: 24px; + grid-template-columns: minmax(320px, 0.95fr) minmax(0, 1.25fr); +} + +.list-panel__header { + display: flex; + align-items: center; + justify-content: space-between; + gap: 14px; +} + +.list-panel__header h2, +.trace-panel h2 { + margin: 0; + font-size: 1.05rem; + letter-spacing: -0.02em; +} + +.list-panel__header p, +.trace-panel__detail > p { + margin: 0; + color: var(--text-soft); + line-height: 1.6; +} + +.list-row { + display: grid; + gap: 12px; + transition: + transform 140ms ease, + border-color 140ms ease, + background-color 140ms ease; +} + +.list-row:hover, +.list-row:focus-visible { + transform: translateY(-1px); + border-color: var(--border-strong); +} + +.list-row.is-selected { + border-color: rgba(39, 75, 99, 0.2); + background: var(--accent-soft); +} + +.list-row__title { + margin: 0; + font-size: 0.98rem; + font-weight: 600; + letter-spacing: -0.02em; +} + 
+.list-row__meta, +.attribute-list, +.evidence-list { + display: flex; + flex-wrap: wrap; + gap: 8px; +} + +.meta-pill { + display: inline-flex; + align-items: center; + gap: 6px; + padding: 8px 10px; + border-radius: 999px; + background: rgba(255, 255, 255, 0.68); + border: 1px solid rgba(42, 52, 66, 0.08); + color: var(--text-soft); + font-size: 0.82rem; +} + +.detail-grid { + display: grid; + gap: 18px; +} + +.detail-group { + display: grid; + gap: 12px; +} + +.detail-group h3 { + margin: 0; + font-size: 0.92rem; + letter-spacing: -0.01em; +} + +.attribute-item, +.evidence-chip { + padding: 9px 11px; + border-radius: 12px; + background: rgba(255, 255, 255, 0.7); + border: 1px solid rgba(42, 52, 66, 0.08); + color: var(--text-soft); + font-size: 0.84rem; + overflow-wrap: anywhere; +} + +.reason-list { + margin: 0; + padding-left: 1.2rem; + color: var(--text-soft); + display: grid; + gap: 10px; + line-height: 1.6; +} + +.timeline-item { + display: grid; + gap: 12px; +} + +.timeline-item__summary { + display: grid; + gap: 8px; +} + +.trace-events { + margin: 0; + padding: 0; + list-style: none; +} + +.trace-event h4 { + margin: 0; + font-size: 0.95rem; + letter-spacing: -0.01em; +} + +.trace-summary { + display: grid; + gap: 14px; +} + +.responsive-note { + color: var(--text-muted); + font-size: 0.82rem; +} + +@media (max-width: 1120px) { + .shell { + grid-template-columns: 1fr; + } + + .shell-sidebar { + display: none; + } + + .content-grid--wide, + .dashboard-grid--detail, + .split-layout, + .metric-grid, + .route-grid { + grid-template-columns: 1fr; + } +} + +@media (min-width: 1121px) { + .shell-nav--mobile { + display: none; + } +} + +@media (max-width: 740px) { + .shell { + padding: 14px; + gap: 16px; + } + + .shell-topbar, + .section-card, + .composer-card { + padding: 20px; + border-radius: 24px; + } + + .shell-topbar__row, + .page-header, + .composer-actions, + .list-panel__header, + .history-entry__topline, + .list-row__topline, + .timeline-item__topline, + .trace-event__topline, + .detail-summary, + .nav-card__topline { + flex-direction: column; + align-items: flex-start; + } + + .shell-nav--mobile { + grid-auto-flow: column; + grid-auto-columns: minmax(150px, 1fr); + overflow-x: auto; + padding-bottom: 2px; + } + + .key-value-grid { + grid-template-columns: 1fr; + } +} diff --git a/apps/web/app/layout.tsx b/apps/web/app/layout.tsx index ed6cafd..4fcf704 100644 --- a/apps/web/app/layout.tsx +++ b/apps/web/app/layout.tsx @@ -1,10 +1,23 @@ +import type { Metadata } from "next"; +import type { ReactNode } from "react"; + +import { AppShell } from "../components/app-shell"; + +import "./globals.css"; + +export const metadata: Metadata = { + title: "AliceBot Operator Shell", + description: "Governed operator interface for requests, approvals, tasks, and explainability.", +}; + export default function RootLayout({ children, -}: Readonly<{ children: React.ReactNode }>) { +}: Readonly<{ children: ReactNode }>) { return ( - {children} + + {children} + ); } - diff --git a/apps/web/app/page.tsx b/apps/web/app/page.tsx index 7a46a7a..1d99bb2 100644 --- a/apps/web/app/page.tsx +++ b/apps/web/app/page.tsx @@ -1,51 +1,145 @@ -const milestones = [ - "API foundation and migrations", - "Continuity event store", - "Web dashboard shell", - "Worker orchestration", +import Link from "next/link"; + +import { PageHeader } from "../components/page-header"; +import { SectionCard } from "../components/section-card"; +import { StatusBadge } from "../components/status-badge"; + +const summaryCards = [ + { + 
value: "5", + label: "Operator views", + detail: "Home, request composition, approvals, task inspection, and explainability are all exposed in one bounded shell.", + }, + { + value: "3", + label: "Governance seams", + detail: "Requests, approvals, and tool executions stay visible instead of being hidden behind a consumer chat wrapper.", + }, + { + value: "2", + label: "Data modes", + detail: "Pages can read live backend seams when configured and degrade to explicit fixtures when no API contract is present.", + }, + { + value: "100%", + label: "Scoped surface", + detail: "The shell stays within the sprint packet: no auth expansion, no connector breadth, and no backend contract changes.", + }, +]; + +const routeCards = [ + { + href: "/chat", + title: "Governed Requests", + description: "Compose bounded operator requests, review response history, and keep compilation and response traces visible.", + status: "active", + }, + { + href: "/approvals", + title: "Approval Inbox", + description: "Review pending approvals with tool, scope, routing, and rationale all contained in a stable inspector layout.", + status: "pending_approval", + }, + { + href: "/tasks", + title: "Task Inspection", + description: "Inspect task lifecycle state, related governed requests, and ordered task-step progress without leaving the shell.", + status: "approved", + }, + { + href: "/traces", + title: "Explain-Why Review", + description: "Trace context compilation and governed actions through a calm evidence-first review surface.", + status: "executed", + }, +]; + +const shellNotes = [ + "Stable navigation with obvious current location and restrained emphasis.", + "Cards and lists sized for readable density rather than dashboard clutter.", + "Responsive stacking that protects text containment on tablet and mobile widths.", ]; export default function HomePage() { return ( -
-
-

- AliceBot Foundation -

-

- Operational shell for the modular monolith -

-

- The web app is intentionally minimal in this sprint. It exists to prove repository - structure while continuity, migrations, and safety primitives land in the API layer. -

-
    - {milestones.map((item) => ( -
  • {item}
  • - ))} -
+
+ + Sprint 6A shell + Design-system aligned +
+ } + /> + +
+ {summaryCards.map((card) => ( + +
{card.value}
+
{card.label}
+

{card.detail}

+
+ ))}
-
+ +
+ +
+ {routeCards.map((route) => ( + +
+

{route.title}

+ +
+

{route.description}

+ Open view + + ))} +
+
+ +
+ +
    + {shellNotes.map((note) => ( +
  • {note}
  • + ))} +
+
+ + +
+
+
Request path
+
Explicitly labeled as governed and reviewable.
+
+
+
Consequential actions
+
Held behind approval and execution review states.
+
+
+
Explainability
+
Trace review sits beside operational work, not in a debug-only corner.
+
+
+
+
+
+ ); } - diff --git a/apps/web/app/tasks/page.tsx b/apps/web/app/tasks/page.tsx new file mode 100644 index 0000000..637aee5 --- /dev/null +++ b/apps/web/app/tasks/page.tsx @@ -0,0 +1,318 @@ +import { PageHeader } from "../../components/page-header"; +import { SectionCard } from "../../components/section-card"; +import { StatusBadge } from "../../components/status-badge"; +import { TaskList, type TaskItem } from "../../components/task-list"; +import { TaskStepList, type TaskStepItem } from "../../components/task-step-list"; + +const taskFixtures: TaskItem[] = [ + { + id: "task-201", + thread_id: "thread-magnesium", + tool_id: "tool-purchase", + status: "pending_approval", + request: { + thread_id: "thread-magnesium", + tool_id: "tool-purchase", + action: "place_order", + scope: "supplements", + domain_hint: "ecommerce", + risk_hint: "purchase", + attributes: { + merchant: "Thorne", + item: "Magnesium Bisglycinate", + }, + }, + tool: { + id: "tool-purchase", + tool_key: "merchant_proxy", + name: "Merchant Proxy", + description: "Proxy for governed ecommerce actions.", + version: "0.1.0", + metadata_version: "tool_metadata_v0", + active: true, + tags: ["commerce", "approval"], + action_hints: ["place_order"], + scope_hints: ["supplements"], + domain_hints: ["ecommerce"], + risk_hints: ["purchase"], + metadata: {}, + created_at: "2026-03-15T08:00:00Z", + }, + latest_approval_id: "approval-101", + latest_execution_id: null, + created_at: "2026-03-17T06:49:00Z", + updated_at: "2026-03-17T06:50:00Z", + }, + { + id: "task-182", + thread_id: "thread-vitamin-d", + tool_id: "tool-purchase", + status: "approved", + request: { + thread_id: "thread-vitamin-d", + tool_id: "tool-purchase", + action: "place_order", + scope: "supplements", + domain_hint: "ecommerce", + risk_hint: "purchase", + attributes: { + merchant: "Fullscript", + item: "Vitamin D3 + K2", + }, + }, + tool: { + id: "tool-purchase", + tool_key: "merchant_proxy", + name: "Merchant Proxy", + description: "Proxy for governed ecommerce actions.", + version: "0.1.0", + metadata_version: "tool_metadata_v0", + active: true, + tags: ["commerce", "approval"], + action_hints: ["place_order"], + scope_hints: ["supplements"], + domain_hints: ["ecommerce"], + risk_hints: ["purchase"], + metadata: {}, + created_at: "2026-03-14T09:15:00Z", + }, + latest_approval_id: "approval-100", + latest_execution_id: null, + created_at: "2026-03-16T14:00:00Z", + updated_at: "2026-03-16T14:22:00Z", + }, +]; + +const stepFixtures: Record = { + "task-201": [ + { + id: "step-20", + task_id: "task-201", + sequence_no: 1, + kind: "governed_request", + status: "created", + request: { + thread_id: "thread-magnesium", + tool_id: "tool-purchase", + action: "place_order", + scope: "supplements", + domain_hint: "ecommerce", + risk_hint: "purchase", + attributes: { + merchant: "Thorne", + item: "Magnesium Bisglycinate", + package: "90 capsules", + }, + }, + outcome: { + routing_decision: "require_approval", + approval_id: "approval-101", + approval_status: "pending", + execution_id: null, + execution_status: null, + blocked_reason: null, + }, + lineage: { + parent_step_id: null, + source_approval_id: null, + source_execution_id: null, + }, + trace: { + trace_id: "trace-step-20", + trace_kind: "approval_request", + }, + created_at: "2026-03-17T06:49:00Z", + updated_at: "2026-03-17T06:50:00Z", + }, + ], + "task-182": [ + { + id: "step-14", + task_id: "task-182", + sequence_no: 1, + kind: "governed_request", + status: "approved", + request: { + thread_id: "thread-vitamin-d", + 
tool_id: "tool-purchase", + action: "place_order", + scope: "supplements", + domain_hint: "ecommerce", + risk_hint: "purchase", + attributes: { + merchant: "Fullscript", + item: "Vitamin D3 + K2", + quantity: "1", + }, + }, + outcome: { + routing_decision: "require_approval", + approval_id: "approval-100", + approval_status: "approved", + execution_id: null, + execution_status: null, + blocked_reason: null, + }, + lineage: { + parent_step_id: null, + source_approval_id: null, + source_execution_id: null, + }, + trace: { + trace_id: "trace-step-14", + trace_kind: "approval_resolution", + }, + created_at: "2026-03-16T14:00:00Z", + updated_at: "2026-03-16T14:22:00Z", + }, + ], +}; + +function getApiConfig() { + return { + apiBaseUrl: + process.env.NEXT_PUBLIC_ALICEBOT_API_BASE_URL ?? process.env.ALICEBOT_API_BASE_URL ?? "", + userId: process.env.NEXT_PUBLIC_ALICEBOT_USER_ID ?? process.env.ALICEBOT_USER_ID ?? "", + }; +} + +async function loadTasks(): Promise<{ items: TaskItem[]; source: "live" | "fixture" }> { + const { apiBaseUrl, userId } = getApiConfig(); + if (!apiBaseUrl || !userId) { + return { items: taskFixtures, source: "fixture" }; + } + + try { + const response = await fetch( + `${apiBaseUrl.replace(/\/$/, "")}/v0/tasks?user_id=${encodeURIComponent(userId)}`, + { cache: "no-store" }, + ); + + if (!response.ok) { + throw new Error("task list request failed"); + } + + const payload = (await response.json()) as { items?: TaskItem[] }; + return { + items: payload.items ?? taskFixtures, + source: "live", + }; + } catch { + return { items: taskFixtures, source: "fixture" }; + } +} + +async function loadTaskSteps( + taskId: string, + source: "live" | "fixture", +): Promise<{ items: TaskStepItem[]; source: "live" | "fixture" }> { + if (source === "fixture") { + return { items: stepFixtures[taskId] ?? [], source: "fixture" }; + } + + const { apiBaseUrl, userId } = getApiConfig(); + try { + const response = await fetch( + `${apiBaseUrl.replace(/\/$/, "")}/v0/tasks/${taskId}/steps?user_id=${encodeURIComponent(userId)}`, + { cache: "no-store" }, + ); + + if (!response.ok) { + throw new Error("task step request failed"); + } + + const payload = (await response.json()) as { items?: TaskStepItem[] }; + return { + items: payload.items ?? stepFixtures[taskId] ?? [], + source: "live", + }; + } catch { + return { items: stepFixtures[taskId] ?? [], source: "fixture" }; + } +} + +type SearchParams = Promise>; + +export default async function TasksPage({ + searchParams, +}: { + searchParams?: SearchParams; +}) { + const params = (searchParams ? await searchParams : {}) as Record< + string, + string | string[] | undefined + >; + const requestedTaskId = typeof params.task === "string" ? params.task : undefined; + const { items, source } = await loadTasks(); + const selectedTask = items.find((item) => item.id === requestedTaskId) ?? items[0] ?? null; + const { items: steps, source: stepSource } = selectedTask + ? await loadTaskSteps(selectedTask.id, source) + : { items: [], source }; + + return ( +
+ + {source === "live" ? "Live API" : "Fixture-backed"} + {items.length} tasks +
+ } + /> + +
+ + +
+ + {selectedTask ? ( +
+
+ + + {selectedTask.request.action} / {selectedTask.request.scope} + +
+
+
+
Thread
+
{selectedTask.thread_id}
+
+
+
Latest approval
+
{selectedTask.latest_approval_id ?? "Not linked"}
+
+
+
Latest execution
+
{selectedTask.latest_execution_id ?? "Not executed"}
+
+
+
Data source
+
{stepSource === "live" ? "Live task-step API" : "Local fixture steps"}
+
+
+
+ ) : ( +

+ No task records are available in the current mode. +

+ )} +
+ + +
+
+ + ); +} diff --git a/apps/web/app/traces/page.tsx b/apps/web/app/traces/page.tsx new file mode 100644 index 0000000..5bf40fb --- /dev/null +++ b/apps/web/app/traces/page.tsx @@ -0,0 +1,173 @@ +import { PageHeader } from "../../components/page-header"; +import { SectionCard } from "../../components/section-card"; +import { TraceList, type TraceItem } from "../../components/trace-list"; + +const traceFixtures: TraceItem[] = [ + { + id: "trace-ctx-401", + kind: "context_compile", + status: "completed", + title: "Context compile for magnesium reorder guidance", + summary: "Compiled prior task state, admitted memories, and recent thread continuity before assistant response assembly.", + eventCount: 9, + createdAt: "2026-03-17T08:45:00Z", + source: "Context compiler", + scope: "thread-magnesium", + related: { + threadId: "thread-magnesium", + taskId: "task-201", + }, + evidence: [ + "Memory evidence admitted for supplement preference and merchant history.", + "Recent approval state included as part of the continuity pack.", + "Task-step lineage referenced before response generation.", + ], + events: [ + { + id: "event-1", + kind: "compiler.scope", + title: "Scope resolved", + detail: "Single-user thread scope and compile limits were established for the request.", + }, + { + id: "event-2", + kind: "memory.retrieve", + title: "Memory evidence attached", + detail: "Preference and purchase-history memories were ranked into the response context pack.", + }, + { + id: "event-3", + kind: "task.retrieve", + title: "Task lifecycle linked", + detail: "Open task and step state were included so the answer could acknowledge the approval dependency.", + }, + ], + }, + { + id: "trace-approval-101", + kind: "approval_request", + status: "requires_review", + title: "Approval request for supplement purchase", + summary: "Routing required user approval before the merchant proxy could execute the purchase request.", + eventCount: 6, + createdAt: "2026-03-17T06:50:00Z", + source: "Approval workflow", + scope: "supplements", + related: { + threadId: "thread-magnesium", + taskId: "task-201", + approvalId: "approval-101", + }, + evidence: [ + "Policy rule marked purchase actions as approval-gated.", + "Tool metadata matched the requested action and scope.", + "Task-step trace link points back to the original governed request.", + ], + events: [ + { + id: "event-4", + kind: "tool.route", + title: "Routing completed", + detail: "The merchant proxy was selected as the governing tool for the request.", + }, + { + id: "event-5", + kind: "approval.state", + title: "Approval opened", + detail: "Approval record persisted with pending resolution state and task-step linkage.", + }, + { + id: "event-6", + kind: "task.lifecycle", + title: "Task updated", + detail: "Task lifecycle moved into a pending approval state while retaining request provenance.", + }, + ], + }, + { + id: "trace-exec-311", + kind: "proxy_execution", + status: "completed", + title: "Governed execution for vitamin reorder", + summary: "Approved supplement purchase request executed through the proxy handler with task and trace linkage preserved.", + eventCount: 7, + createdAt: "2026-03-16T14:24:00Z", + source: "Proxy execution", + scope: "supplements", + related: { + threadId: "thread-vitamin-d", + taskId: "task-182", + approvalId: "approval-100", + executionId: "execution-311", + }, + evidence: [ + "Execution occurred only after approval resolution.", + "Handler output and trace references stayed attached to the governed action record.", + "Task and task-step 
lifecycle traces were appended alongside execution status.", + ], + events: [ + { + id: "event-7", + kind: "approval.check", + title: "Approval validated", + detail: "Execution preflight confirmed the approval was in an executable state.", + }, + { + id: "event-8", + kind: "budget.check", + title: "Budget check passed", + detail: "Execution budget constraints did not block the governed action.", + }, + { + id: "event-9", + kind: "execution.result", + title: "Handler completed", + detail: "Proxy output was recorded for the approved supplement reorder with a linked execution trace and task-step status update.", + }, + ], + }, +]; + +type SearchParams = Promise>; + +export default async function TracesPage({ + searchParams, +}: { + searchParams?: SearchParams; +}) { + const params = (searchParams ? await searchParams : {}) as Record< + string, + string | string[] | undefined + >; + const selectedId = typeof params.trace === "string" ? params.trace : undefined; + + return ( +
+ + Fixture-backed detail view + Existing backend concepts only +
+ } + /> + + + + +
    +
  • Which evidence types contributed to the outcome and whether they were appropriate.
  • +
  • How the lifecycle moved from request to approval or execution without losing provenance.
  • +
  • Whether the current trace surface needs deeper live-event wiring in a future sprint.
  • +
+
+ + ); +} diff --git a/apps/web/components/app-shell.tsx b/apps/web/components/app-shell.tsx new file mode 100644 index 0000000..f2e254f --- /dev/null +++ b/apps/web/components/app-shell.tsx @@ -0,0 +1,119 @@ +"use client"; + +import type { ReactNode } from "react"; + +import Link from "next/link"; +import { usePathname } from "next/navigation"; + +const navigation = [ + { + href: "/", + label: "Overview", + caption: "Shell landing and governed surface summary", + }, + { + href: "/chat", + label: "Requests", + caption: "Compose bounded operator requests", + }, + { + href: "/approvals", + label: "Approvals", + caption: "Review approval queue and inspector", + }, + { + href: "/tasks", + label: "Tasks", + caption: "Inspect lifecycle state and task steps", + }, + { + href: "/traces", + label: "Traces", + caption: "Explain-why and governed action review", + }, +]; + +function isActive(pathname: string, href: string) { + if (href === "/") { + return pathname === "/"; + } + + return pathname.startsWith(href); +} + +export function AppShell({ children }: { children: ReactNode }) { + const pathname = usePathname(); + + return ( +
+
+ + +
+
+
+
+

MVP Web Shell

+

Governed operator interface

+
+ +
+ Single-user v1 + Existing backend seams only +
+
+ + +
+ +
+
{children}
+
+
+
+
+ ); +} diff --git a/apps/web/components/approval-list.tsx b/apps/web/components/approval-list.tsx new file mode 100644 index 0000000..ec9eea9 --- /dev/null +++ b/apps/web/components/approval-list.tsx @@ -0,0 +1,212 @@ +import Link from "next/link"; + +import { EmptyState } from "./empty-state"; +import { SectionCard } from "./section-card"; +import { StatusBadge } from "./status-badge"; + +export type ApprovalItem = { + id: string; + thread_id: string; + task_step_id: string | null; + status: string; + request: { + thread_id: string; + tool_id: string; + action: string; + scope: string; + domain_hint: string | null; + risk_hint: string | null; + attributes: Record; + }; + tool: { + id: string; + tool_key: string; + name: string; + description: string; + version: string; + metadata_version: string; + active: boolean; + tags: string[]; + action_hints: string[]; + scope_hints: string[]; + domain_hints: string[]; + risk_hints: string[]; + metadata: Record; + created_at: string; + }; + routing: { + decision: string; + reasons: Array<{ + code: string; + source: string; + message: string; + tool_id: string | null; + policy_id: string | null; + consent_key: string | null; + }>; + trace: { + trace_id: string; + trace_event_count: number; + }; + }; + created_at: string; + resolution: { + resolved_at: string; + resolved_by_user_id: string; + } | null; +}; + +function formatDate(value: string) { + return new Intl.DateTimeFormat("en", { + month: "short", + day: "numeric", + hour: "2-digit", + minute: "2-digit", + }).format(new Date(value)); +} + +function formatAttributeValue(value: unknown) { + if (value == null) { + return "None"; + } + + if (typeof value === "string" || typeof value === "number" || typeof value === "boolean") { + return String(value); + } + + return JSON.stringify(value); +} + +export function ApprovalList({ + items, + selectedId, +}: { + items: ApprovalItem[]; + selectedId?: string; +}) { + if (items.length === 0) { + return ( + + + + ); + } + + const selected = items.find((item) => item.id === selectedId) ?? items[0]; + + return ( +
+ +
+
+

{items.length} total approvals

+
+
+ {items.map((item) => ( + +
+
+ {formatDate(item.created_at)} +

{item.tool.name}

+
+ +
+

+ {item.request.action} / {item.request.scope} +

+
+ Thread {item.thread_id} + {item.request.risk_hint ? Risk {item.request.risk_hint} : null} +
+ + ))} +
+
+
+ + +
+
+ + + {selected.request.action} / {selected.request.scope} + +
+ +
+
+
Thread
+
{selected.thread_id}
+
+
+
Task step
+
{selected.task_step_id ?? "Unlinked"}
+
+
+
Routing decision
+
{selected.routing.decision}
+
+
+
Trace
+
+ {selected.routing.trace.trace_id} · {selected.routing.trace.trace_event_count} events +
+
+
+ +
+

Request attributes

+
+ {Object.entries(selected.request.attributes).map(([key, value]) => ( + + {key}: {formatAttributeValue(value)} + + ))} +
+
+ +
+

Routing rationale

+
    + {selected.routing.reasons.map((reason) => ( +
  • + {reason.message} +
  • + ))} +
+
+ +
+

Resolution

+

+ {selected.resolution + ? `Resolved ${formatDate(selected.resolution.resolved_at)} by ${selected.resolution.resolved_by_user_id}.` + : "Still awaiting explicit operator resolution."} +

+
+
+
+
+ ); +} diff --git a/apps/web/components/empty-state.tsx b/apps/web/components/empty-state.tsx new file mode 100644 index 0000000..fc44acf --- /dev/null +++ b/apps/web/components/empty-state.tsx @@ -0,0 +1,22 @@ +import Link from "next/link"; + +type EmptyStateProps = { + title: string; + description: string; + actionHref?: string; + actionLabel?: string; +}; + +export function EmptyState({ title, description, actionHref, actionLabel }: EmptyStateProps) { + return ( +
+

{title}

+

{description}

+ {actionHref && actionLabel ? ( + + {actionLabel} + + ) : null} +
+ ); +} diff --git a/apps/web/components/page-header.tsx b/apps/web/components/page-header.tsx new file mode 100644 index 0000000..49fc524 --- /dev/null +++ b/apps/web/components/page-header.tsx @@ -0,0 +1,21 @@ +import type { ReactNode } from "react"; + +type PageHeaderProps = { + eyebrow?: string; + title: string; + description: string; + meta?: ReactNode; +}; + +export function PageHeader({ eyebrow, title, description, meta }: PageHeaderProps) { + return ( +
+
+ {eyebrow ?

{eyebrow}

: null} +

{title}

+

{description}

+
+ {meta ?
{meta}
: null} +
+ ); +} diff --git a/apps/web/components/request-composer.tsx b/apps/web/components/request-composer.tsx new file mode 100644 index 0000000..c1d0b97 --- /dev/null +++ b/apps/web/components/request-composer.tsx @@ -0,0 +1,230 @@ +"use client"; + +import type { FormEvent } from "react"; +import { useState, useTransition } from "react"; + +export type RequestHistoryEntry = { + id: string; + request: string; + response: string; + submittedAt: string; + source: "live" | "fixture"; + trace?: { + compileTraceId: string; + compileTraceEventCount: number; + responseTraceId: string; + responseTraceEventCount: number; + }; +}; + +type RequestComposerProps = { + initialEntries: RequestHistoryEntry[]; + apiBaseUrl?: string; + userId?: string; + threadId?: string; +}; + +type LiveResponsePayload = { + assistant: { + event_id: string; + text: string; + }; + trace: { + compile_trace_id: string; + compile_trace_event_count: number; + response_trace_id: string; + response_trace_event_count: number; + }; +}; + +function formatDate(value: string) { + return new Intl.DateTimeFormat("en", { + month: "short", + day: "numeric", + hour: "2-digit", + minute: "2-digit", + }).format(new Date(value)); +} + +function buildFixtureEntry(message: string): RequestHistoryEntry { + const excerpt = message.trim().slice(0, 120); + const requestLabel = excerpt.length > 0 ? excerpt : "Operator request"; + const nonce = Date.now().toString(36); + + return { + id: `fixture-${nonce}`, + request: requestLabel, + response: + `Prepared a governed response preview for "${requestLabel}". In live mode this surface returns assistant output together with compile and response trace references from the backend.`, + submittedAt: new Date().toISOString(), + source: "fixture", + trace: { + compileTraceId: `trace-ctx-${nonce}`, + compileTraceEventCount: 5, + responseTraceId: `trace-resp-${nonce}`, + responseTraceEventCount: 3, + }, + }; +} + +export function RequestComposer({ + initialEntries, + apiBaseUrl, + userId, + threadId, +}: RequestComposerProps) { + const [message, setMessage] = useState(""); + const [entries, setEntries] = useState(initialEntries); + const [statusText, setStatusText] = useState("Ready for a governed operator request."); + const [isPending, startTransition] = useTransition(); + + const liveModeReady = Boolean(apiBaseUrl && userId && threadId); + + async function handleSubmit(event: FormEvent) { + event.preventDefault(); + + const nextMessage = message.trim(); + if (!nextMessage) { + return; + } + + setStatusText(liveModeReady ? "Submitting request to the response endpoint..." : "Saving fixture-backed preview..."); + + if (!liveModeReady) { + const entry = buildFixtureEntry(nextMessage); + startTransition(() => { + setEntries((current) => [entry, ...current]); + setMessage(""); + setStatusText("Fixture response added. Configure the web API env vars to switch this view into live mode."); + }); + return; + } + + try { + const response = await fetch(`${apiBaseUrl?.replace(/\/$/, "")}/v0/responses`, { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ + user_id: userId, + thread_id: threadId, + message: nextMessage, + max_sessions: 8, + max_events: 80, + max_memories: 20, + max_entities: 12, + max_entity_edges: 20, + }), + }); + + const payload = (await response.json()) as LiveResponsePayload | { detail?: string }; + if (!response.ok || !("assistant" in payload)) { + throw new Error("detail" in payload && payload.detail ? 
payload.detail : "Request failed"); + } + + const entry: RequestHistoryEntry = { + id: payload.assistant.event_id, + request: nextMessage, + response: payload.assistant.text, + submittedAt: new Date().toISOString(), + source: "live", + trace: { + compileTraceId: payload.trace.compile_trace_id, + compileTraceEventCount: payload.trace.compile_trace_event_count, + responseTraceId: payload.trace.response_trace_id, + responseTraceEventCount: payload.trace.response_trace_event_count, + }, + }; + + startTransition(() => { + setEntries((current) => [entry, ...current]); + setMessage(""); + setStatusText("Live response received and trace references recorded."); + }); + } catch (error) { + const detail = error instanceof Error ? error.message : "Request failed"; + setStatusText(`Unable to submit live request: ${detail}`); + } + } + + return ( +
+
+
+ {liveModeReady ? "Live operator mode" : "Fixture operator mode"} + + Requests stay explicitly governed and recent trace references remain attached to each response. + +
+ +
+ +

+ Keep requests bounded to existing backend concepts. This surface is optimized for clarity + and review rather than casual back-and-forth. +

+
+
+ +
+
+