From 3fb35e0001dbda7681882738b33fb1e6c5bf91d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Tue, 19 May 2026 18:09:29 +0200 Subject: [PATCH 01/46] wip --- TANSTACK_AG_UI_PARITY_PLAN.md | 758 ++++++++++++++++ go.mod | 34 +- go.sum | 32 + pkg/ag-ui/events.go | 650 ++++++++++++++ pkg/ag-ui/events_test.go | 144 +++ pkg/ai-stream/approval.go | 194 ++++ pkg/ai-stream/bridgev2/events.go | 133 +++ pkg/ai-stream/bridgev2/events_test.go | 100 +++ pkg/ai-stream/matrix/content.go | 90 ++ pkg/ai-stream/matrix/content_test.go | 131 +++ pkg/ai-stream/pack.go | 236 +++++ pkg/ai-stream/run.go | 657 ++++++++++++++ pkg/ai-stream/stream_test.go | 236 +++++ pkg/connector/ai_runtime.go | 1185 +++++++++++++++++++++++++ pkg/connector/ai_runtime_test.go | 506 +++++++++++ pkg/connector/ai_text.go | 214 +++++ pkg/connector/client.go | 527 ++++++++++- pkg/connector/client_test.go | 76 ++ pkg/connector/connector.go | 15 +- 19 files changed, 5898 insertions(+), 20 deletions(-) create mode 100644 TANSTACK_AG_UI_PARITY_PLAN.md create mode 100644 pkg/ag-ui/events.go create mode 100644 pkg/ag-ui/events_test.go create mode 100644 pkg/ai-stream/approval.go create mode 100644 pkg/ai-stream/bridgev2/events.go create mode 100644 pkg/ai-stream/bridgev2/events_test.go create mode 100644 pkg/ai-stream/matrix/content.go create mode 100644 pkg/ai-stream/matrix/content_test.go create mode 100644 pkg/ai-stream/pack.go create mode 100644 pkg/ai-stream/run.go create mode 100644 pkg/ai-stream/stream_test.go create mode 100644 pkg/connector/ai_runtime.go create mode 100644 pkg/connector/ai_runtime_test.go create mode 100644 pkg/connector/ai_text.go diff --git a/TANSTACK_AG_UI_PARITY_PLAN.md b/TANSTACK_AG_UI_PARITY_PLAN.md new file mode 100644 index 0000000..6f600f5 --- /dev/null +++ b/TANSTACK_AG_UI_PARITY_PLAN.md @@ -0,0 +1,758 @@ +# TanStack/AG-UI Parity Implementation Brief + +## Summary + +Build dummybridge AI around current TanStack AG-UI primitives in both directions: + +- dummybridge emits AG-UI stream events and accepts TanStack-shaped approval responses. +- Desktop consumes multi-event encrypted AG-UI streams and hides carrier events from the normal timeline. +- Shared Go packages define the primitive contract instead of preserving old AI SDK or agentremote decisions. + +Do not install dependencies or modify lockfiles unless the user explicitly approves that dependency change. `@tanstack/ai-react-ui` is approved for the Desktop rendering work in this plan. + +## Current State + +The dummybridge repo currently has a provisional `pkg/aichats` package and AI handling in `pkg/connector/client.go`. + +Known current behavior: + +- AI DM resolution uses the `ai`/`AI` ghost and AI portals with the `ai-` prefix. +- The bridge sends one visible placeholder event, streams `com.beeper.llm.deltas`, then edits the placeholder with final content. +- Current deltas are AG-UI-like but incomplete. +- Current approval requests are separate Matrix events with `com.beeper.ai.approval` metadata and reaction options. +- Current approval reaction handling should keep the user's selected emoji and remove the bridge-posted placeholder/non-selected options, but this needs robust implementation and tests. +- Current text streaming has started moving away from full accumulated content on each delta, but the final design must enforce that. + +Desktop already has partial AI stream support in these areas: + +- `src/common/ai-common.ts`: `BeeperAIMessage`, `BeeperAGUIEvent`, approval constants, and type guards. +- `src/common/types/beeper.ts`: stream content types with `.deltas` and `updates`. +- `src/pas-server/beeper/EventSyncContext.ts`: maps `com.beeper.ai`, `com.beeper.stream`, per-message profile, edits, and hidden AI notices. +- `src/pas-server/beeper/BeeperClient.ts`: processes stream events into `STATE_SYNC message stream`. +- `src/renderer/stores/AIChatsStore.ts`: extracts `.deltas`, orders by `seq`, applies AG-UI events, tracks approvals, and merges stream state. +- `src/renderer/ai/ui-message.ts`: applies AG-UI events into current Desktop UI message parts. + +The implementation should update those Desktop paths instead of inventing a second client stream path. + +Implementation rules: + +- Keep the code simple, clean, and direct. +- Prefer less LOC, less indirection, and fewer abstractions. +- Fold or flatten abstractions that do not carry real behavior. +- Do not add fake layers, simple wrappers, barrel exports, duplicated logic, or duplicated types. +- Smaller files are fine only when they represent real concerns. +- Optimize for one coherent system per concern, not multiple parallel ways to do the same thing. +- Current AI code was generated and never released, so no backward compatibility or legacy compatibility is required. +- Delete provisional schemas, routes, event shapes, migrations, aliases, and helper layers if they only exist for history or compatibility. +- Prefer deleting code over preserving it. +- Prefer collapsing duplicate entrypoints over keeping aliases. +- Product intention matters more than the current code shape. +- If product intent is ambiguous, explicitly call out the question instead of encoding both options. + +Compatibility policy: + +- No compatibility policy is required for the provisional/current dummybridge AI event shapes. +- Desktop and dummybridge should converge on one new TanStack/AG-UI shape. +- Delete old reader/writer paths instead of accepting old names such as `REASONING_MESSAGE_*` or older tool-call fields unless they are required by current TanStack AG-UI docs. + +## Full Paths To Inspect + +Primary dummybridge checkout: + +- `/Users/batuhan/Projects/labs/dummybridge` +- `/Users/batuhan/Projects/labs/dummybridge/TANSTACK_AG_UI_PARITY_PLAN.md` +- `/Users/batuhan/Projects/labs/dummybridge/README.md` +- `/Users/batuhan/Projects/labs/dummybridge/go.mod` +- `/Users/batuhan/Projects/labs/dummybridge/go.sum` +- `/Users/batuhan/Projects/labs/dummybridge/config-agui.yaml` +- `/Users/batuhan/Projects/labs/dummybridge/config-qa-agui.yaml` + +Current dummybridge AI implementation to replace: + +- `/Users/batuhan/Projects/labs/dummybridge/pkg/aichats/agui.go` +- `/Users/batuhan/Projects/labs/dummybridge/pkg/aichats/matrix.go` +- `/Users/batuhan/Projects/labs/dummybridge/pkg/aichats/agui_test.go` +- `/Users/batuhan/Projects/labs/dummybridge/pkg/connector/client.go` +- `/Users/batuhan/Projects/labs/dummybridge/pkg/connector/connector.go` +- `/Users/batuhan/Projects/labs/dummybridge/pkg/connector/login.go` +- `/Users/batuhan/Projects/labs/dummybridge/pkg/connector/example-config.yaml` + +New dummybridge package targets: + +- `/Users/batuhan/Projects/labs/dummybridge/pkg/ag-ui` +- `/Users/batuhan/Projects/labs/dummybridge/pkg/ai-stream` +- `/Users/batuhan/Projects/labs/dummybridge/pkg/ai-stream/matrix` +- `/Users/batuhan/Projects/labs/dummybridge/pkg/ai-stream/bridgev2` + +Archived AI dummybridge reference: + +- `/Users/batuhan/Projects/labs/ai-bridge-archived` +- `/Users/batuhan/Projects/labs/ai-bridge-archived/bridges/dummybridge/runtime.go` +- `/Users/batuhan/Projects/labs/ai-bridge-archived/bridges/dummybridge/runtime_test.go` +- `/Users/batuhan/Projects/labs/ai-bridge-archived/sdk/writer.go` +- `/Users/batuhan/Projects/labs/ai-bridge-archived/approval_flow.go` + +Local TanStack AI reference checkout: + +- `/Users/batuhan/Projects/labs/upstream/tanstack-ai` +- `/Users/batuhan/Projects/labs/upstream/tanstack-ai/packages/typescript/ai-react-ui/src/text-part.tsx` +- `/Users/batuhan/Projects/labs/upstream/tanstack-ai/packages/typescript/ai-react-ui/src/chat-message.tsx` +- `/Users/batuhan/Projects/labs/upstream/tanstack-ai/packages/typescript/ai-react-ui/package.json` +- `/Users/batuhan/Projects/labs/upstream/tanstack-ai/packages/typescript/ai-client/src/types.ts` +- `/Users/batuhan/Projects/labs/upstream/tanstack-ai/packages/typescript/ai-client/src/chat-client.ts` +- `/Users/batuhan/Projects/labs/upstream/tanstack-ai/packages/typescript/ai-client/src/connection-adapters.ts` +- `/Users/batuhan/Projects/labs/upstream/tanstack-ai/packages/typescript/ai-event-client/src/index.ts` +- `/Users/batuhan/Projects/labs/upstream/tanstack-ai/packages/typescript/ai/src/types.ts` +- `/Users/batuhan/Projects/labs/upstream/tanstack-ai/packages/typescript/ai/src/utilities/chat-params.ts` + +Desktop checkout and AI consumer paths: + +- `/Users/batuhan/Projects/texts/beeper-workspace/beeper/beeper/desktop` +- `/Users/batuhan/Projects/texts/beeper-workspace/beeper/beeper/desktop/src/common/ai-common.ts` +- `/Users/batuhan/Projects/texts/beeper-workspace/beeper/beeper/desktop/src/common/ai-common.test.ts` +- `/Users/batuhan/Projects/texts/beeper-workspace/beeper/beeper/desktop/src/common/types/beeper.ts` +- `/Users/batuhan/Projects/texts/beeper-workspace/beeper/beeper/desktop/src/pas-server/beeper/EventSyncContext.ts` +- `/Users/batuhan/Projects/texts/beeper-workspace/beeper/beeper/desktop/src/pas-server/beeper/BeeperClient.ts` +- `/Users/batuhan/Projects/texts/beeper-workspace/beeper/beeper/desktop/src/pas-server/beeper/connect/ws-event-mapper.test.ts` +- `/Users/batuhan/Projects/texts/beeper-workspace/beeper/beeper/desktop/src/pas-server/beeper/connect/ws-events-server.test.ts` +- `/Users/batuhan/Projects/texts/beeper-workspace/beeper/beeper/desktop/src/renderer/stores/AIChatsStore.ts` +- `/Users/batuhan/Projects/texts/beeper-workspace/beeper/beeper/desktop/src/renderer/ai/ui-message.ts` +- `/Users/batuhan/Projects/texts/beeper-workspace/beeper/beeper/desktop/src/renderer/ai/ui-message.test.ts` +- `/Users/batuhan/Projects/texts/beeper-workspace/beeper/beeper/desktop/src/renderer/ai/ai-message-view.ts` + +Local tooling for live smoke tests: + +- `/Users/batuhan/Projects/texts/bridge-manager/bbctl` +- `/Users/batuhan/Projects/labs/desktop-api-cli/packages/cli` + +Local runtime artifacts that may be useful for debugging, but should not be treated as source: + +- `/Users/batuhan/Projects/labs/dummybridge/logs` +- `/Users/batuhan/Projects/labs/dummybridge/sh-dummybridge-agui.db` +- `/Users/batuhan/Projects/labs/dummybridge/sh-dummybridge-agui.db-shm` +- `/Users/batuhan/Projects/labs/dummybridge/sh-dummybridge-agui.db-wal` +- `/Users/batuhan/Projects/labs/dummybridge/sh-dummybridge-qa-agui.db` + +## TanStack/AG-UI Contract + +Use TanStack primitives as the source of truth: + +- `StreamChunk = AGUIEvent`; do not preserve legacy non-AG-UI chunk formats. +- Support every current AG-UI lifecycle event explicitly: + - `RUN_STARTED` + - `RUN_FINISHED` + - `RUN_ERROR` + - `TEXT_MESSAGE_START` + - `TEXT_MESSAGE_CONTENT` + - `TEXT_MESSAGE_END` + - `TOOL_CALL_START` + - `TOOL_CALL_ARGS` + - `TOOL_CALL_END` + - `TOOL_CALL_RESULT` + - `STEP_STARTED` + - `STEP_FINISHED` + - `STATE_SNAPSHOT` + - `STATE_DELTA` + - `MESSAGES_SNAPSHOT` + - `CUSTOM` +- Support bidirectional AG-UI run input: `threadId`, `runId`, `state`, `messages`, `tools`, `context`, `forwardedProps`, and legacy `data` mirror. +- Model `UIMessage` as `{ id, role, parts, createdAt? }`, preserving ordered parts. +- Use TanStack part shapes: + - Text part: `{ type: "text", content }` + - Thinking part: `{ type: "thinking", content }` + - Tool call part: `{ type: "tool-call", id, name, arguments, state, approval?, output? }` + - Tool result part: `{ type: "tool-result", toolCallId, content, state, error? }` +- Use TanStack tool states: + - `awaiting-input` + - `input-streaming` + - `input-complete` + - `approval-requested` + - `approval-responded` +- Use TanStack tool result states: + - `streaming` + - `complete` + - `error` +- Treat AG-UI `REASONING_START`, `REASONING_MESSAGE_START`, `REASONING_MESSAGE_CONTENT`, `REASONING_MESSAGE_END`, and `REASONING_END` as the canonical thinking/reasoning stream for new output. +- Keep `STEP_STARTED` / `STEP_FINISHED` as step lifecycle events using AG-UI `stepName`, not deprecated `stepId`, and not as a substitute for reasoning content. +- Fully support AG-UI `STATE_SNAPSHOT`, `STATE_DELTA`, and `MESSAGES_SNAPSHOT` events. +- Every emitted AG-UI event must include `timestamp`. +- Support optional AG-UI `rawEvent` on every event, with the bounded/truncation policy below. +- Support `TOOL_CALL_START.index` for parallel tool calls. +- Support partial JSON argument streaming through `TOOL_CALL_ARGS`; consumers should preserve partial input while parsing best-effort and finalize on `TOOL_CALL_END`. +- Support `TOOL_CALL_END` both with and without a result payload. +- Support AG-UI `TOOL_CALL_RESULT` for separate tool-result parts instead of Beeper custom tool-result events. +- Support multiple assistant `messageId`s per run. Do not assume a run has exactly one assistant text message. + +Relevant docs: + +- AG-UI event definitions: +- Streaming: +- Tool states and parts: +- UIMessage: +- Bidirectional AG-UI compliance: +- Local source tags currently include `@tanstack/ai@0.18.0`, `@tanstack/ai-client@0.10.0`, and `@tanstack/ai-event-client@0.3.2`; prefer the local checkout above for exact type names during implementation. + +## Package Layout + +Create `pkg/ag-ui/` with Go package name `agui`. + +Responsibilities: + +- Standalone AG-UI event and UI message types. +- `RunAgentInput` and bidirectional request types. +- Tool, tool result, approval, text, thinking, step, custom, run, and error event builders. +- Validation helpers that reject invalid event ordering, missing IDs, bad states, invalid tool approval shapes, and oversized individual deltas. +- No Matrix, bridgev2, Desktop, or dummybridge-specific dependencies. + +Create `pkg/ai-stream/` with Go package name `aistream`. + +Responsibilities: + +- Run writer for ordered AG-UI event emission. +- Accumulation used only for finalization, preview generation, and test reconstruction. +- Stream envelope and chunk packing helpers. +- Approval resolver primitives. +- Terminal/finalization helpers. +- Spec enforcement, but not transport ownership. + +Add adapter layers: + +- `pkg/ai-stream/matrix`: Matrix content helpers using mautrix event types, stream carrier content, approval prompt content, reaction option serialization. +- `pkg/ai-stream/bridgev2`: bridgev2 queue/send/redaction adapter. This layer may import bridgev2 and database types. + +Delete `pkg/aichats` once the new packages fully replace it. Do not keep it as an unused compatibility package. + +## Archived Dummybridge Parity + +Use `../ai-bridge-archived/bridges/dummybridge/runtime.go` and `runtime_test.go` as the feature checklist, not as an architecture to copy. + +Commands: + +- `help` +- `/help` +- `!help` +- `dummybridge help` +- `stream-lorem [common options]` +- `stream-tools ... [common options]` +- `stream-random [seconds] [--actions=N] [--profile=balanced|tools|artifacts|terminals] [--seed=N] [--delay-ms=min:max] [--allow-abort] [--allow-error] [--allow-approval]` +- `stream-chaos [runs] [seconds] [--profile=balanced|tools|artifacts|terminals] [--seed=N] [--stagger-ms=min:max] [--max-actions=N] [--allow-abort] [--allow-error] [--allow-approval]` + +The help aliases are intentional product/demo affordances and should remain unless there is a later product decision to reduce command aliases. + +Common options: + +- `--reasoning=N` +- `--steps=N` +- `--sources=N` +- `--documents=N` +- `--files=N` +- `--meta` +- `--data=name` +- `--data-transient=name` +- `--delay-ms=min:max` +- `--chunk-chars=min:max` +- `--seed=N` +- `--finish=stop|length|tool-calls|content-filter|other` +- `--abort` +- `--error` + +Tool tags: + +- `#fail` +- `#approval` +- `#deny` +- `#delta` +- `#inputerror` +- `#prelim` +- `#provider` + +Behavior to preserve: + +- `stream-lorem` emits markdown-rich visible text, optional thinking/reasoning, optional steps, optional sources/documents/files/data, and a final run state. +- `stream-tools` emits text, thinking, tool input streaming, input errors, approval requests, approval denials, tool output streaming, final output, and tool failures. +- `stream-random` emits weighted random actions with deterministic seed support and profiles. +- `stream-chaos` starts multiple staggered runs and runs random streams per run. +- Persistent data survives final snapshots; transient data does not. +- Markdown generation must include realistic links, lists, quotes, code blocks, and tables. +- Terminal states include normal finish, error, and abort. + +Limits should start from the archived limits, except the explicit over-64KB streaming tests require larger generated output support: + +- Archived default chunk range: 24 to 96 characters. +- Archived maximum chunk size option: 512 characters. +- Archived maximum random actions: 64. +- Archived maximum chaos runs: 16. +- Archived maximum chaos actions: 64. +- Archived maximum demo duration: 5 minutes. +- Archived maximum delay/stagger: 30 seconds. +- Increase text generation limits enough to test at least 70KiB output. The transport must handle this by splitting carrier events, not by sending oversized Matrix events. + +## Streaming Transport + +Every AI run starts with one visible Matrix anchor event. + +Anchor event requirements: + +- `msgtype: m.text` +- AI per-message profile for the AI ghost +- Minimal `com.beeper.ai` +- Stable AG-UI `threadId` +- Stable AG-UI `runId` +- Stable AG-UI `messageId` +- Useful preview text in `body` +- `com.beeper.stream` descriptor when using the Beeper stream publisher + +ID model: + +- Use AG-UI IDs for semantic identity. +- `threadId` is the conversation/thread identity. For dummybridge this should map to the Beeper thread/portal/room identity used by Desktop. +- `runId` is the assistant execution identity. Do not add a separate Beeper execution ID unless a future AG-UI version requires it. +- `messageId` is the AG-UI assistant UI message identity. It should map to the first visible/anchor message, not to every carrier. +- Matrix event IDs are transport identities. Use the anchor Matrix event ID as `target_event` / `m.relates_to.event_id` for carriers. +- The Beeper stream descriptor is identified by `(room_id, event_id, type)` and does not expose a separate stream ID. Do not invent `streamId`; use `target_event` plus `runId` for merging. + +Carrier events: + +- Are sent through bridgev2 remote events so E2EE works normally. +- Must never be raw Matrix sends. +- Are `m.room.message` events with `msgtype: m.text` for bridgev2 and client compatibility. +- Contain `com.beeper.llm.deltas`. +- Carry ordered AG-UI envelopes. +- Are hidden from normal chat rendering by Desktop after deltas are extracted. +- Use empty or minimal body text after the initial visible preview; they must not appear as chat bubbles in Desktop. + +Envelope shape: + +- `threadId` +- `runId` +- `messageId` +- `seq` +- `part` +- `target_event` or `m.relates_to.event_id` +- optional `agent_id` + +Ordering and merge key: + +- `seq` is strictly increasing per `{target_event, runId}`. +- If `target_event` is unavailable during early processing, temporarily key by `{threadId, runId}` and promote to `{target_event, runId}` when the anchor message is known. +- Desktop buffers out-of-order deltas within existing ordering limits. +- Duplicate or stale `seq` values are ignored or rejected consistently. + +Size budget: + +- Treat 64KB as the external ceiling. +- Use a hard carrier budget of 58KB for serialized Matrix content to leave buffer for encryption overhead, wrappers, event metadata, and implementation variance. +- The packer must measure serialized JSON byte size before adding an envelope to a carrier. +- If a single text delta would exceed the 58KB carrier budget, split it at UTF-8 rune boundaries. +- If a non-text event cannot fit inside the 58KB budget, return a validation error rather than sending it. +- `rawEvent` must be optional, bounded, and safe to omit. If including `rawEvent` would push a carrier over budget, truncate it or drop it before packing rather than bloating the event. +- Truncated raw provider data must be marked, e.g. `rawEventTruncated: true`, so debugging does not confuse partial raw data with complete provider payloads. + +Preview/body algorithm: + +- The first visible message is the canonical message for the run. +- Put as much useful early visible preview as practical into the first message while preserving required metadata and staying under the 58KB budget. +- All run-level metadata that should survive as the message identity, such as model, usage, thread/run/message IDs, terminal state, and approval summary, belongs on the first visible message or its compact final metadata. +- Later carrier messages should be hidden and merged by compatible clients into the first visible message. +- Later carrier bodies should be empty or minimal and put payload in `.deltas`. +- Compatible clients must reconstruct from ordered deltas and merge content/parts into the first message, not display carriers as separate runs. +- Do not rewrite full accumulated content on every delta. + +Finalization: + +- The run accumulator is only for finalization, preview generation, and tests. +- Finalization emits compact terminal metadata and a compact final UI state when needed. +- Do not require a final Matrix edit containing the full generated body for over-64KB runs. +- The client is responsible for merging the stream. + +Replay/backfill: + +- Desktop must be able to reconstruct a run from persisted anchor plus persisted carrier messages, not only from live stream events. +- Replay must use the same merge key and ordering rules as live streaming. +- Backfilled carrier events should remain hidden after extraction. + +Redaction/delete behavior: + +- If a carrier is deleted/redacted, Desktop should recompute the visible anchor from remaining carrier events when possible. +- If recomputation leaves a sequence gap or invalid stream, mark the anchor message incomplete/failed. +- Approval prompt deletion/redaction should not delete or corrupt the AI run; it only removes that visible prompt. + +Ordering gap timeout: + +- Do not buffer missing `seq` gaps forever. +- If a gap remains unresolved past the configured timeout, mark the first visible anchor message incomplete/failed and keep carrier messages hidden. +- Late arrivals after failure should not create separate visible carrier messages. + +AG-UI state events: + +- Fully support `STATE_SNAPSHOT`, `STATE_DELTA`, and `MESSAGES_SNAPSHOT` as first-class AG-UI events. +- `STATE_SNAPSHOT` replaces the current run/application state view for the AI run. +- `STATE_DELTA` applies an incremental patch/update to that state. +- `MESSAGES_SNAPSHOT` carries a complete AG-UI `UIMessage[]` snapshot. +- Desktop must preserve and expose this state for AI rendering/devtools instead of dropping it. +- State events are allowed to affect rendered state when the renderer intentionally consumes them. +- State events must still obey the 58KB carrier budget and multi-carrier splitting rules. +- Do not duplicate the normal streaming path: text should still prefer text events, tool calls should still prefer tool events, and state events should be used when AG-UI state synchronization is the right primitive. + +Run errors: + +- `RUN_ERROR` may be run-scoped or session/thread-scoped. +- If `RUN_ERROR.runId` is present, fail only that run. +- If `RUN_ERROR.runId` is absent, fail active runs in that thread/session. +- Desktop should surface the failure on the first visible anchor message for each affected run and keep carrier messages hidden. + +First message metadata schema: + +- The first visible message must contain enough metadata for a compatible client to render run chrome, status, model info, usage, approvals, and non-part attachments without reading carrier event metadata. +- Do not put streamed UI parts, text chunks, thinking chunks, tool argument chunks, or tool output chunks in this metadata. +- AG-UI/TanStack-mirrored fields must use TanStack naming and value shapes: `threadId`, `runId`, `messageId`, `finishReason`, `promptTokens`, `completionTokens`, and `totalTokens`. +- Beeper-only fields should be grouped clearly under Beeper-owned names instead of changing AG-UI concepts. +- Suggested `com.beeper.ai.metadata` shape: + +```json +{ + "schema": "com.beeper.ai.run.v1", + "protocol": "ag-ui", + "threadId": "thread-id", + "runId": "run-id", + "messageId": "message-id", + "agent": { + "id": "ai", + "displayName": "AI" + }, + "model": "dummybridge/ag-ui", + "usage": { + "promptTokens": 0, + "completionTokens": 0, + "totalTokens": 0 + }, + "usageDetails": { + "reasoningTokens": 0, + "cachedInputTokens": 0 + }, + "status": { + "state": "streaming", + "finishReason": "stop", + "terminal": null, + "error": null + }, + "approvals": [ + { + "id": "approval-id", + "toolCallId": "tool-call-id", + "state": "requested", + "always": false, + "reason": "" + } + ], + "artifacts": { + "sources": [], + "documents": [], + "files": [] + }, + "data": {}, + "preview": { + "text": "bounded visible preview", + "truncated": true + } +} +``` + +- `model` is the AG-UI model identifier string. Do not add `modelInfo`; display/provider details should be derived from the model registry, agent profile, or bridge/network metadata instead of duplicated on every message. +- `usage` mirrors AG-UI `RUN_FINISHED.usage`. Extra usage fields belong in `usageDetails`. +- `finishReason` should use TanStack/AG-UI values: `stop`, `length`, `content_filter`, `tool_calls`, or `null`. Command aliases may accept hyphenated input, but emitted metadata should use AG-UI values. +- `usage` is token/usage metadata only. Do not add dollar cost fields unless the product explicitly decides to expose pricing. +- `artifacts` and `data` are for descriptors needed to render run-level UI outside the streamed parts. If an item is naturally a UI part, it should stay in the stream instead of being duplicated here. +- Final compact metadata may update `status`, `usage`, `approvals`, `artifacts`, `data`, and `preview`, but still must not embed full chunks/parts. + +## Desktop Work + +Update Desktop as part of parity because the new transport deliberately splits one stream across multiple Matrix events. + +Dependency: + +- Add `@tanstack/ai-react-ui` to the Desktop app and use it for AI message rendering. +- Do not hand-roll a parallel markdown renderer when TanStack's UI package already provides one. +- `@tanstack/ai-react-ui` `TextPart` renders Markdown with `react-markdown`, GFM tables/strikethrough via `remark-gfm`, sanitized HTML via `rehype-sanitize`, and code highlighting via `rehype-highlight`. +- Keep Beeper-specific shell/layout/actions in Desktop, but delegate TanStack text/thinking/tool/result part rendering to TanStack UI components or thin render props around them. + +TanStack ownership in Desktop: + +- Desktop already depends on `@tanstack/ai` and `@tanstack/ai-client`; use those packages instead of duplicating their concepts. +- Import `UIMessage`, `MessagePart`, `TextPart`, `ThinkingPart`, `ToolCallPart`, `ToolResultPart`, `ToolCallState`, and `ToolResultState` from TanStack packages. +- Use TanStack `StreamProcessor`/stream utilities where practical for applying AG-UI chunks into UI messages instead of maintaining a parallel Desktop-only stream reducer. +- Use TanStack `parsePartialJSON`/partial JSON utilities for streaming tool args instead of maintaining a separate parser. +- Use `@tanstack/ai-react-ui` for `ChatMessage`/part rendering, with Beeper render props only for product-specific chrome, approvals, and bridge actions. +- Delete or collapse Desktop-only normalized AI types that duplicate TanStack structures, such as separate text/reasoning/tool-call models, once the TanStack path can feed the UI directly. +- Keep Desktop-local types only for Beeper transport/persistence: Matrix event IDs, `target_event`, carrier visibility, `com.beeper.stream`, `com.beeper.ai.metadata`, and approval prompt Matrix metadata. + +Intentional AG-UI boundaries: + +- AG-UI owns semantic events, UI message parts, tool states, run input, and stream processing. +- Beeper owns transport: encrypted Matrix events, carrier hiding, target event mapping, replay from persisted Matrix history, and approval reaction cleanup. +- `com.beeper.ai.metadata` is Beeper message metadata and must not become a second UI message schema. It may store non-part run metadata, but streamed parts/chunks remain AG-UI. +- `target_event` is a Beeper transport pointer, not an AG-UI field. Keep it in the carrier envelope and Desktop stream routing, not in TanStack `UIMessage`. + +PAS sync: + +- In `src/pas-server/beeper/EventSyncContext.ts`, detect decrypted `m.room.message` events that contain stream delta content keys ending in `.deltas` or batched `updates`. +- Extract stream deltas from encrypted carrier events after decryption. +- Emit `STATE_SYNC message stream` updates instead of normal message upserts for carrier-only events. +- Mark carrier timeline events hidden after extraction so they do not show as chat bubbles. +- Keep the visible anchor event and approval prompt event as normal messages. +- Preserve `com.beeper.ai`, `com.beeper.stream`, and per-message profile behavior for anchor/final messages. + +Beeper client stream routing: + +- In `src/pas-server/beeper/BeeperClient.ts`, keep using the existing stream event path, but ensure multi-carrier events preserve `room_id`, carrier `event_id`, `target_event`, `threadId`, `runId`, `messageId`, and `seq`. + +Common types: + +- In `src/common/types/beeper.ts`, extend stream types to include AG-UI `threadId`, `runId`, and `messageId`, plus Beeper transport `target_event`. +- Keep support for both single-update `.deltas` and batched replay `updates`. + +Renderer store: + +- In `src/renderer/stores/AIChatsStore.ts`, merge by `{target_event, runId}` rather than only target message/run. +- Map carrier target events back to the visible anchor message. +- Continue buffering out-of-order `seq`. +- Treat stream carriers as dirtying and extending the first visible AI message only, never as separate visible messages. +- Merge streamed content and ordered UI parts into the first visible message's renderer state. +- Hide carrier messages after extracting deltas. +- Track approval prompts by approval ID and target tool call. +- Support multiple assistant messages/parts per run by indexing on AG-UI `messageId`, not assuming one text part per run. +- Support parallel tool calls by distinct tool call IDs and optional `index`. +- Preserve streamed partial tool arguments while parsing partial JSON best-effort; replace with finalized arguments when `TOOL_CALL_END` arrives. +- Accept tool output either on `TOOL_CALL_END.result` or as AG-UI `TOOL_CALL_RESULT`. +- Apply run-scoped versus thread-scoped `RUN_ERROR` behavior as described above. +- Reconstruct runs from persisted anchor plus carrier history during replay/backfill using the same code path as live streaming. + +UI message application: + +- In `src/renderer/ai/ui-message.ts`, apply AG-UI events into TanStack-shaped parts. +- Preserve ordered parts instead of collapsing everything by type. +- Render the resulting TanStack `UIMessage` with `@tanstack/ai-react-ui` instead of converting it into a separate Beeper-only part model. +- Support compatibility input for current events while preferring new output shapes: + - text + - thinking/step + - tool-call + - tool-result + - state snapshot/state delta/messages snapshot + - source-url/source-document/file/custom data +- Map approval states to TanStack states: + - `approval-requested` + - `approval-responded` + - result `complete` + - result `error` + +Message types: + +- AI visible messages and approval prompts should use message types that render as bubbles where intended. +- Stream carrier events are `m.room.message`/`m.text` for compatibility, but should not render as bubbles after Desktop extraction. +- Avoid `m.notice` for visible AI chat content in AI-network rooms because Desktop hides AI `m.notice` events. + +## Approvals + +Approval requests remain separate visible Matrix events with reaction options. + +Generic reaction option shape: + +```go +type ReactionOption[T any] struct { + ID string + Label string + Values []string + Value T +} +``` + +`Values` is the complete set of strings that should match this option. Entries may be literal emoji (`👍`), symbolic reaction keys (`approval.allow_once`), short names (`allow`), or bridge-specific aliases. The helper owns normalization and matching; callers should pass strings and not branch on whether a value is an emoji or a key. + +Tool approval response shape: + +```go +type ToolApprovalResponse struct { + ID string + Approved bool + Always bool + Reason string + Fields map[string]any + Metadata map[string]any +} +``` + +`Always` supports allow-always style options without making that concept Matrix-specific. `Fields` is for flexible provider/bridge-specific approval data that should survive resolution but not force new top-level schema every time. + +Rules: + +- AG-UI stream emits a tool-call state transition to `approval-requested`. +- The tool-call part includes `approval: { id, needsApproval: true }`. +- Matrix reaction choices are transport metadata and must not be embedded into AG-UI events. +- Approval prompt events should relate to the first visible anchor message and include `threadId`, `runId`, `messageId`, `toolCallId`, and approval ID. +- Matrix approval event stores `com.beeper.ai.approval` with tool call ID, tool name, `threadId`, `runId`, `messageId`, expiration if any, and reaction options. +- On user reaction, the bridge resolves the option to a `ToolApprovalResponse`. +- After resolution, emit AG-UI state `approval-responded`. +- If approved, continue execution and emit tool result `complete` or `error`. +- If denied, emit a `tool-result` with `state: "error"` and structured reason `denied`; do not pretend the tool executed. +- Approval options should support flexible fields, including allow-once, allow-always, deny, reason, and provider/bridge-specific metadata. +- Keep the user's selected Matrix reaction event exactly as the visible user choice, regardless of whether it matched by emoji or symbolic key. +- Remove bridge-posted placeholder option reactions and non-selected option reactions. +- The cleanup helper should return the selected option, selected reaction event ID if known, and a list of bridge-posted reaction event IDs to remove. Actual Matrix redaction/deletion remains the bridge adapter's job. +- Programmatic approval and Matrix reaction approval must share the same resolver and produce the same stream events. + +Custom events: + +- Support AG-UI `CUSTOM` events. +- Use built-in/custom names from TanStack when they exist, such as `approval-requested`. +- Beeper-specific custom events must use a clear namespace such as `com.beeper.*`. +- Do not add random one-off custom names when an AG-UI lifecycle, tool, state, or message event already models the behavior. + +## Open Decisions With Recommended Defaults + +These are the remaining decisions that affect product behavior or implementation shape. Use the recommended default unless the answer to the question changes the product intent. + +1. Source of truth for AG-UI schemas + - Recommended: `pkg/ag-ui` is the only Go source of truth for AG-UI concepts. Other packages import it instead of redefining parallel event, message, tool, or approval types. + - Decision: Desktop must use TanStack types directly for AG-UI/UI message concepts. Local Desktop types should only describe Beeper transport envelopes and app-specific metadata. + +2. Final persisted state for long runs + - Recommended: never require a final full-text edit. The first visible message stores compact identity/terminal metadata and Desktop reconstructs long content from carriers. + - Decision: compact final metadata should include everything needed to render the run except streamed parts/chunks. Do not store full parts/chunks in final metadata for large runs. + +3. Metadata contract on the first visible message + - Decision: first message owns all non-part run metadata: IDs, model, usage, finish/terminal state, approval summary, source/file/data descriptors that are metadata, and any archived `aichats` metadata that is not the streamed UI parts/chunks themselves. Do not include dollar cost fields unless there is a separate product decision. + +4. Dropped or invalid carriers + - Recommended: Desktop marks the first visible AI message failed and keeps carriers hidden. Do not expose carrier messages as fallback bubbles. + - Direct question: Should Desktop show a recoverable "stream incomplete" state, or a hard failed generation state? + +5. Approval idempotency + - Recommended: first valid approval resolution wins. Later Matrix reactions or programmatic responses are ignored, do not re-run the tool, and may be cleaned up as stale choices. + - Direct question: Should a user be allowed to change approval before the tool starts executing, or is first valid reaction always final? + +6. Allow-always behavior + - Recommended: support it generically in approval fields and reaction options, but dummybridge should only persist/use it if there is a clear storage target. + - Direct question: Should dummybridge actually remember allow-always across runs, or only emit the field to prove the UI/transport supports it? + +7. Package boundaries + - Recommended: embrace mautrix in `pkg/ai-stream/matrix`; keep bridgev2-specific database/queue/redaction in `pkg/ai-stream/bridgev2`; keep `pkg/ag-ui` pure. + - Direct question: Should `pkg/ai-stream/matrix` return mautrix `event.MessageEventContent` directly everywhere, or expose a small content struct plus conversion helpers? + +8. TanStack docs freshness + - Recommended: before implementation starts, re-open current TanStack AI docs and update the contract section if state names or part shapes changed. + - Direct question: Should implementation pin to the docs current at implementation start, or should tests tolerate small TanStack naming changes? + +## Tests + +Dummybridge Go tests: + +- Port archived parser tests. +- Verify help aliases. +- Verify command guide includes all commands. +- Verify conflicting terminal options are rejected. +- Verify invalid random profile is rejected. +- Verify oversized option inputs are rejected. +- Verify markdown-rich text generation is deterministic by seed and varied across calls. +- Verify table/link/list/code/quote markdown signals. +- Verify `stream-lorem` emits thinking, steps, text, sources, documents, files, persistent data, and excludes transient data from final snapshot. +- Verify `stream-tools` covers success, failure, approval, denial, delta input, input error, preliminary output, and provider-executed tools. +- Verify random streams finish and respect duration. +- Verify chaos streams start multiple runs with stagger and max-actions. +- Verify error and abort terminal states. + +`pkg/ag-ui` tests: + +- Validate all current AG-UI lifecycle event builders: `RUN_STARTED`, `RUN_FINISHED`, `RUN_ERROR`, `TEXT_MESSAGE_START`, `TEXT_MESSAGE_CONTENT`, `TEXT_MESSAGE_END`, `TOOL_CALL_START`, `TOOL_CALL_ARGS`, `TOOL_CALL_END`, `TOOL_CALL_RESULT`, `STEP_STARTED`, `STEP_FINISHED`, `STATE_SNAPSHOT`, `STATE_DELTA`, `MESSAGES_SNAPSHOT`, and `CUSTOM`. +- Validate event builders and required IDs. +- Validate `RunAgentInput`. +- Validate `UIMessage` ordered part shape. +- Validate tool-call and tool-result states against TanStack values. +- Validate approval request and response shapes. +- Validate step/thinking events. +- Validate `STATE_SNAPSHOT`, `STATE_DELTA`, and `MESSAGES_SNAPSHOT` event shapes. +- Validate every emitted event has `timestamp`. +- Validate `rawEvent` is optional and bounded/truncated/omitted before exceeding carrier limits. +- Validate `TOOL_CALL_START.index`. +- Validate partial JSON `TOOL_CALL_ARGS`. +- Validate `TOOL_CALL_END` with and without result. +- Validate `TOOL_CALL_RESULT` creates/updates TanStack `tool-result` parts. +- Validate run-scoped and thread/session-scoped `RUN_ERROR`. +- Validate multiple assistant `messageId`s per run. +- Reject legacy/non-AG-UI chunk shapes. + +`pkg/ai-stream` tests: + +- Verify ordered run writer output. +- Verify no per-delta accumulated full text. +- Verify final accumulator is only used at finalization. +- Verify UTF-8 splitting. +- Verify carrier packer respects the 58KB serialized JSON budget. +- Verify stream reconstruction from carriers. +- Verify duplicate/stale/out-of-order `seq` behavior. +- Verify missing `seq` gap timeout marks the anchor incomplete/failed. +- Verify carrier delete/redaction recomputes or marks the anchor incomplete/failed. +- Verify approval reaction resolver keeps the selected value and identifies removals. + +Over-64KB tests: + +- Generate at least 70KiB of output. +- Assert every carrier's serialized content is at or below 58KB. +- Assert at least two carrier events are emitted. +- Assert later carriers have no preview body or only minimal body. +- Assert reconstruction from deltas exactly equals generated output. +- Assert no final full-body edit is required to display the complete stream. + +Desktop tests: + +- PAS extracts `.deltas` from decrypted carrier events. +- Carrier-only events are hidden and do not render as chat bubbles. +- Single-update and batched `updates` formats still work. +- Multi-carrier stream merges into the visible anchor message. +- Out-of-order `seq` buffering works. +- Duplicate/stale `seq` handling works. +- TanStack-shaped text/thinking/tool/result parts render through the AI message view. +- State snapshot, state delta, and messages snapshot events are preserved and exposed to rendering/devtools. +- Approval prompt indexing works from both visible prompt metadata and stream state. +- Approval response transitions resolve approval state. +- Parallel tool calls render/merge by distinct tool call IDs and optional indexes. +- Partial JSON tool args remain visible while streaming and finalize cleanly. +- `TOOL_CALL_END.result` renders as a completed tool result. +- `RUN_ERROR` with `runId` fails only that run; `RUN_ERROR` without `runId` fails active runs in the thread. +- Multiple assistant `messageId`s in one run render in order. +- Replay/backfill reconstructs the same visible run from persisted anchor plus carrier history as live streaming. +- Deleted/redacted carriers keep carrier bubbles hidden and mark/recompute the anchor correctly. +- Ordering gaps time out instead of buffering forever. +- Over-64KB carrier sequence reconstructs into one AI message. + +Commands to run: + +- In dummybridge: `go test -mod=readonly ./...` +- In Desktop after adding `@tanstack/ai-react-ui`: run the package manager install/update command explicitly approved for that dependency and commit the resulting manifest/lockfile changes with the Desktop implementation. +- In Desktop: run the existing focused test commands for touched files. At minimum cover `ai-common`, `ui-message`, `AIChatsStore`, `EventSyncContext`, and stream mapper tests. + +## Live Smoke Testing + +Use bridgev2 and Desktop API, not raw Matrix sends, for end-to-end checks. + +Recommended smoke cases: + +- Create/login a QA account using the established `qatest+@beeper.com` pattern and fixed OTP only if a fresh account is needed. +- Create or reuse an AI DM through bridge-manager/Desktop API. +- Send `help` and confirm the command guide appears as a normal AI bubble. +- Send `stream-lorem 70000 --chunk-chars=512 --seed=7` and confirm Desktop shows one streaming AI message, not many carrier bubbles. +- Send `stream-tools 200 shell#approval --seed=3` and confirm the approval prompt appears separately with reaction options. +- React approve and confirm the selected emoji remains while other bridge options disappear and the tool completes. +- React deny and confirm the tool is cancelled/denied and does not execute. +- Send `stream-random 5 --actions=8 --allow-approval --seed=9`. +- Send `stream-chaos 3 5 --max-actions=5 --seed=11`. + +Acceptance criteria: + +- All carrier events are encrypted in E2EE rooms. +- No plaintext raw Matrix sends are used. +- Visible AI output uses bubble-rendering message types. +- Carrier events do not show as separate bubbles. +- Streaming remains incremental. +- Over-64KB output reconstructs correctly. +- Approvals work from Matrix reactions and programmatic/TanStack-shaped responses. +- The selected approval emoji is kept and non-selected placeholder options are removed. diff --git a/go.mod b/go.mod index 88ab339..a2add9c 100644 --- a/go.mod +++ b/go.mod @@ -1,38 +1,38 @@ module github.com/beeper/dummybridge -go 1.24.0 +go 1.25.0 toolchain go1.25.6 require ( - github.com/rs/zerolog v1.34.0 - go.mau.fi/util v0.9.5 - maunium.net/go/mautrix v0.26.3-0.20260119125818-e28f7170bc4b + github.com/rs/zerolog v1.35.1 + go.mau.fi/util v0.9.9 + maunium.net/go/mautrix v0.28.0 ) require ( - filippo.io/edwards25519 v1.1.0 // indirect + filippo.io/edwards25519 v1.2.0 // indirect github.com/coder/websocket v1.8.14 // indirect - github.com/coreos/go-systemd/v22 v22.6.0 // indirect - github.com/lib/pq v1.10.9 // indirect + github.com/coreos/go-systemd/v22 v22.7.0 // indirect + github.com/lib/pq v1.12.3 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect - github.com/mattn/go-sqlite3 v1.14.33 // indirect - github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741 // indirect + github.com/mattn/go-sqlite3 v1.14.44 // indirect + github.com/petermattis/goid v0.0.0-20260330135022-df67b199bc81 // indirect github.com/rs/xid v1.6.0 // indirect github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e // indirect - github.com/tidwall/gjson v1.18.0 // indirect + github.com/tidwall/gjson v1.19.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect github.com/tidwall/sjson v1.2.5 // indirect - github.com/yuin/goldmark v1.7.16 // indirect + github.com/yuin/goldmark v1.8.2 // indirect go.mau.fi/zeroconfig v0.2.0 // indirect - golang.org/x/crypto v0.47.0 // indirect - golang.org/x/exp v0.0.0-20260112195511-716be5621a96 // indirect - golang.org/x/net v0.49.0 // indirect - golang.org/x/sync v0.19.0 // indirect - golang.org/x/sys v0.40.0 // indirect - golang.org/x/text v0.33.0 // indirect + golang.org/x/crypto v0.51.0 // indirect + golang.org/x/exp v0.0.0-20260508232706-74f9aab9d74a // indirect + golang.org/x/net v0.54.0 // indirect + golang.org/x/sync v0.20.0 // indirect + golang.org/x/sys v0.44.0 // indirect + golang.org/x/text v0.37.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect maunium.net/go/mauflag v1.0.0 // indirect diff --git a/go.sum b/go.sum index ad7799f..ecbcb12 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= +filippo.io/edwards25519 v1.2.0 h1:crnVqOiS4jqYleHd9vaKZ+HKtHfllngJIiOpNpoJsjo= +filippo.io/edwards25519 v1.2.0/go.mod h1:xzAOLCNug/yB62zG1bQ8uziwrIqIuxhctzJT18Q77mc= github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g= @@ -7,11 +9,15 @@ github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6p github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/coreos/go-systemd/v22 v22.6.0 h1:aGVa/v8B7hpb0TKl0MWoAavPDmHvobFe5R5zn0bCJWo= github.com/coreos/go-systemd/v22 v22.6.0/go.mod h1:iG+pp635Fo7ZmV/j14KUcmEyWF+0X7Lua8rrTWzYgWU= +github.com/coreos/go-systemd/v22 v22.7.0 h1:LAEzFkke61DFROc7zNLX/WA2i5J8gYqe0rSj9KI28KA= +github.com/coreos/go-systemd/v22 v22.7.0/go.mod h1:xNUYtjHu2EDXbsxz1i41wouACIwT7Ybq9o0BQhMwD0w= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/lib/pq v1.12.3 h1:tTWxr2YLKwIvK90ZXEw8GP7UFHtcbTtty8zsI+YjrfQ= +github.com/lib/pq v1.12.3/go.mod h1:/p+8NSbOcwzAEI7wiMXFlgydTwcgTr3OSKMsD2BitpA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= @@ -21,8 +27,12 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.33 h1:A5blZ5ulQo2AtayQ9/limgHEkFreKj1Dv226a1K73s0= github.com/mattn/go-sqlite3 v1.14.33/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/mattn/go-sqlite3 v1.14.44 h1:3VSe+xafpbzsLbdr2AWlAZk9yRHiBhTBakioXaCKTF8= +github.com/mattn/go-sqlite3 v1.14.44/go.mod h1:pjEuOr8IwzLJP2MfGeTb0A35jauH+C2kbHKBr7yXKVQ= github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741 h1:KPpdlQLZcHfTMQRi6bFQ7ogNO0ltFT4PmtwTLW4W+14= github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= +github.com/petermattis/goid v0.0.0-20260330135022-df67b199bc81 h1:WDsQxOJDy0N1VRAjXLpi8sCEZRSGarLWQevDxpTBRrM= +github.com/petermattis/goid v0.0.0-20260330135022-df67b199bc81/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -30,6 +40,8 @@ github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU= github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= +github.com/rs/zerolog v1.35.1 h1:m7xQeoiLIiV0BCEY4Hs+j2NG4Gp2o2KPKmhnnLiazKI= +github.com/rs/zerolog v1.35.1/go.mod h1:EjML9kdfa/RMA7h/6z6pYmq1ykOuA8/mjWaEvGI+jcw= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= @@ -37,6 +49,8 @@ github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.19.0 h1:xwxm7n691Uf3u5OFjzngavjGTh55KX5q/9w9xHW88JU= +github.com/tidwall/gjson v1.19.0/go.mod h1:V37/opeE/JbLUOfH0QTXiNez2l0RUjYUhpT4szFQAfc= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= @@ -46,25 +60,41 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.16 h1:n+CJdUxaFMiDUNnWC3dMWCIQJSkxH4uz3ZwQBkAlVNE= github.com/yuin/goldmark v1.7.16/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= +github.com/yuin/goldmark v1.8.2 h1:kEGpgqJXdgbkhcOgBxkC0X0PmoPG1ZyoZ117rDVp4zE= +github.com/yuin/goldmark v1.8.2/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= go.mau.fi/util v0.9.5 h1:7AoWPCIZJGv4jvtFEuCe3GhAbI7uF9ckIooaXvwlIR4= go.mau.fi/util v0.9.5/go.mod h1:g1uvZ03VQhtTt2BgaRGVytS/Zj67NV0YNIECch0sQCQ= +go.mau.fi/util v0.9.9 h1:ujDeXCo07HBor5oQLyO1tHklupmqVmPgasc53d7q/NE= +go.mau.fi/util v0.9.9/go.mod h1:pqt4Vcrt+5gcH/CgrHZg11qSx+b34o6mknGzOEA6waY= go.mau.fi/zeroconfig v0.2.0 h1:e/OGEERqVRRKlgaro7E6bh8xXiKFSXB3eNNIud7FUjU= go.mau.fi/zeroconfig v0.2.0/go.mod h1:J0Vn0prHNOm493oZoQ84kq83ZaNCYZnq+noI1b1eN8w= golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8= golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A= +golang.org/x/crypto v0.51.0 h1:IBPXwPfKxY7cWQZ38ZCIRPI50YLeevDLlLnyC5wRGTI= +golang.org/x/crypto v0.51.0/go.mod h1:8AdwkbraGNABw2kOX6YFPs3WM22XqI4EXEd8g+x7Oc8= golang.org/x/exp v0.0.0-20260112195511-716be5621a96 h1:Z/6YuSHTLOHfNFdb8zVZomZr7cqNgTJvA8+Qz75D8gU= golang.org/x/exp v0.0.0-20260112195511-716be5621a96/go.mod h1:nzimsREAkjBCIEFtHiYkrJyT+2uy9YZJB7H1k68CXZU= +golang.org/x/exp v0.0.0-20260508232706-74f9aab9d74a h1:+3jdDGGB8NGb1Zktc737jlt3/A5f6UlwSzmvqUuufxw= +golang.org/x/exp v0.0.0-20260508232706-74f9aab9d74a/go.mod h1:d2fgXJLVs4dYDHUk5lwMIfzRzSrWCfGZb0ZqeLa/Vcw= golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o= golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8= +golang.org/x/net v0.54.0 h1:2zJIZAxAHV/OHCDTCOHAYehQzLfSXuf/5SoL/Dv6w/w= +golang.org/x/net v0.54.0/go.mod h1:Sj4oj8jK6XmHpBZU/zWHw3BV3abl4Kvi+Ut7cQcY+cQ= golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/sys v0.44.0 h1:ildZl3J4uzeKP07r2F++Op7E9B29JRUy+a27EibtBTQ= +golang.org/x/sys v0.44.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE= golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8= +golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc= +golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= @@ -75,3 +105,5 @@ maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M= maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA= maunium.net/go/mautrix v0.26.3-0.20260119125818-e28f7170bc4b h1:OaZ5Y1l4XACFlgy4BmZcCLdYPJZzgZWqZJnpdSITmoM= maunium.net/go/mautrix v0.26.3-0.20260119125818-e28f7170bc4b/go.mod h1:CUxSZcjPtQNxsZLRQqETAxg2hiz7bjWT+L1HCYoMMKo= +maunium.net/go/mautrix v0.28.0 h1:vBakLzf8MAdfED3NzAKiMeKQbc3AQ4EAS03NC+TVMXQ= +maunium.net/go/mautrix v0.28.0/go.mod h1:/a9A7LGaqb9B3nho4tLd28n0EPcCdwpm2dxkxkLLgh0= diff --git a/pkg/ag-ui/events.go b/pkg/ag-ui/events.go new file mode 100644 index 0000000..d3ea264 --- /dev/null +++ b/pkg/ag-ui/events.go @@ -0,0 +1,650 @@ +package agui + +import ( + "encoding/json" + "fmt" + "strings" + "time" +) + +const ( + EventRunStarted = "RUN_STARTED" + EventRunFinished = "RUN_FINISHED" + EventRunError = "RUN_ERROR" + EventTextMessageStart = "TEXT_MESSAGE_START" + EventTextMessageContent = "TEXT_MESSAGE_CONTENT" + EventTextMessageEnd = "TEXT_MESSAGE_END" + EventToolCallStart = "TOOL_CALL_START" + EventToolCallArgs = "TOOL_CALL_ARGS" + EventToolCallEnd = "TOOL_CALL_END" + EventToolCallResult = "TOOL_CALL_RESULT" + EventStepStarted = "STEP_STARTED" + EventStepFinished = "STEP_FINISHED" + EventStateSnapshot = "STATE_SNAPSHOT" + EventStateDelta = "STATE_DELTA" + EventMessagesSnapshot = "MESSAGES_SNAPSHOT" + EventCustom = "CUSTOM" + EventReasoningStart = "REASONING_START" + EventReasoningEnd = "REASONING_END" + EventReasoningMsgStart = "REASONING_MESSAGE_START" + EventReasoningMsgCont = "REASONING_MESSAGE_CONTENT" + EventReasoningMsgEnd = "REASONING_MESSAGE_END" +) + +const ( + RoleAssistant = "assistant" + RoleUser = "user" + RoleSystem = "system" + RoleTool = "tool" +) + +const ( + ToolStateAwaitingInput = "awaiting-input" + ToolStateInputStreaming = "input-streaming" + ToolStateInputComplete = "input-complete" + ToolStateApprovalRequested = "approval-requested" + ToolStateApprovalResponded = "approval-responded" + ToolResultStateStreaming = "streaming" + ToolResultStateComplete = "complete" + ToolResultStateError = "error" + ApprovalCustomRequested = "approval-requested" + ApprovalCustomResponded = "approval-responded" + FinishReasonStop = "stop" + FinishReasonLength = "length" + FinishReasonContentFilter = "content_filter" + FinishReasonToolCalls = "tool_calls" + FinishReasonOther = "other" +) + +type Event map[string]any + +type UIMessage struct { + ID string `json:"id"` + Role string `json:"role"` + Parts []MessagePart `json:"parts"` + CreatedAt *time.Time `json:"createdAt,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +type MessagePart map[string]any + +type RunAgentInput struct { + ThreadID string `json:"threadId,omitempty"` + RunID string `json:"runId,omitempty"` + State map[string]any `json:"state,omitempty"` + Messages []UIMessage `json:"messages,omitempty"` + Tools []Tool `json:"tools,omitempty"` + Context []ContextItem `json:"context,omitempty"` + ForwardedProps map[string]any `json:"forwardedProps,omitempty"` + Data map[string]any `json:"data,omitempty"` +} + +type Tool struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + InputSchema map[string]any `json:"inputSchema,omitempty"` + OutputSchema map[string]any `json:"outputSchema,omitempty"` + NeedsApproval bool `json:"needsApproval,omitempty"` +} + +type ContextItem struct { + Type string `json:"type"` + Value any `json:"value,omitempty"` + Meta map[string]any `json:"meta,omitempty"` +} + +type ToolApproval struct { + ID string `json:"id"` + NeedsApproval bool `json:"needsApproval"` + Fields map[string]any `json:"fields,omitempty"` +} + +type ToolApprovalResponse struct { + ID string `json:"id"` + Approved bool `json:"approved"` + Always bool `json:"always,omitempty"` + Reason string `json:"reason,omitempty"` + Fields map[string]any `json:"fields,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +type Usage struct { + PromptTokens int `json:"promptTokens,omitempty"` + CompletionTokens int `json:"completionTokens,omitempty"` + TotalTokens int `json:"totalTokens,omitempty"` +} + +type EventBuilder struct { + now func() time.Time + model string +} + +func NewEventBuilder(model string, now func() time.Time) EventBuilder { + if now == nil { + now = time.Now + } + return EventBuilder{now: now, model: strings.TrimSpace(model)} +} + +func (b EventBuilder) base(eventType string) Event { + evt := Event{ + "type": eventType, + "timestamp": b.now().UnixMilli(), + } + if b.model != "" { + evt["model"] = b.model + } + return evt +} + +func (b EventBuilder) RunStarted(threadID, runID string) Event { + evt := b.base(EventRunStarted) + evt["threadId"] = threadID + evt["runId"] = runID + return evt +} + +func (b EventBuilder) RunFinished(threadID, runID, finishReason string, usage Usage) Event { + evt := b.base(EventRunFinished) + evt["threadId"] = threadID + evt["runId"] = runID + evt["finishReason"] = NormalizeFinishReason(finishReason) + evt["usage"] = usage + return evt +} + +func (b EventBuilder) RunError(threadID, runID, message string) Event { + evt := b.base(EventRunError) + evt["threadId"] = threadID + if strings.TrimSpace(runID) != "" { + evt["runId"] = runID + } + evt["message"] = message + evt["error"] = map[string]any{"message": message} + return evt +} + +func (b EventBuilder) TextMessageStart(messageID, role string) Event { + if role == "" { + role = RoleAssistant + } + evt := b.base(EventTextMessageStart) + evt["messageId"] = messageID + evt["role"] = role + return evt +} + +func (b EventBuilder) TextMessageContent(messageID, delta string) Event { + evt := b.base(EventTextMessageContent) + evt["messageId"] = messageID + evt["delta"] = delta + return evt +} + +func (b EventBuilder) TextMessageEnd(messageID string) Event { + evt := b.base(EventTextMessageEnd) + evt["messageId"] = messageID + return evt +} + +func (b EventBuilder) ReasoningStart(messageID string) Event { + evt := b.base(EventReasoningStart) + evt["messageId"] = messageID + return evt +} + +func (b EventBuilder) ReasoningEnd(messageID string) Event { + evt := b.base(EventReasoningEnd) + evt["messageId"] = messageID + return evt +} + +func (b EventBuilder) ReasoningMessageStart(messageID string) Event { + evt := b.base(EventReasoningMsgStart) + evt["messageId"] = messageID + return evt +} + +func (b EventBuilder) ReasoningMessageContent(messageID, delta string) Event { + evt := b.base(EventReasoningMsgCont) + evt["messageId"] = messageID + evt["delta"] = delta + return evt +} + +func (b EventBuilder) ReasoningMessageEnd(messageID string) Event { + evt := b.base(EventReasoningMsgEnd) + evt["messageId"] = messageID + return evt +} + +func (b EventBuilder) ToolCallStart(messageID, toolCallID, name string, index *int, approval *ToolApproval) Event { + evt := b.base(EventToolCallStart) + if messageID != "" { + evt["parentMessageId"] = messageID + } + evt["toolCallId"] = toolCallID + evt["toolCallName"] = name + evt["toolName"] = name + if index != nil { + evt["index"] = *index + } + if approval != nil { + evt["approval"] = approval + evt["state"] = ToolStateApprovalRequested + } else { + evt["state"] = ToolStateAwaitingInput + } + return evt +} + +func (b EventBuilder) ToolCallArgs(toolCallID, delta string, args any) Event { + evt := b.base(EventToolCallArgs) + evt["toolCallId"] = toolCallID + evt["delta"] = delta + evt["state"] = ToolStateInputStreaming + if args != nil { + evt["args"] = args + } + return evt +} + +func (b EventBuilder) ToolCallEnd(toolCallID, name string, input, result any, state string) Event { + evt := b.base(EventToolCallEnd) + evt["toolCallId"] = toolCallID + evt["toolCallName"] = name + evt["toolName"] = name + if input != nil { + evt["input"] = input + } + if result != nil { + evt["result"] = result + } + if state == "" { + state = ToolStateInputComplete + } + evt["state"] = state + return evt +} + +func (b EventBuilder) ToolCallResult(messageID, toolCallID, content, state, role string) Event { + if role == "" { + role = RoleTool + } + if state == "" { + state = ToolResultStateComplete + } + evt := b.base(EventToolCallResult) + evt["messageId"] = messageID + evt["toolCallId"] = toolCallID + evt["content"] = content + evt["state"] = state + evt["role"] = role + return evt +} + +func (b EventBuilder) StepStarted(messageID, stepName string) Event { + evt := b.base(EventStepStarted) + if messageID != "" { + evt["messageId"] = messageID + } + if stepName != "" { + evt["stepName"] = stepName + } + return evt +} + +func (b EventBuilder) StepFinished(messageID, stepName string) Event { + evt := b.base(EventStepFinished) + if messageID != "" { + evt["messageId"] = messageID + } + if stepName != "" { + evt["stepName"] = stepName + } + return evt +} + +func (b EventBuilder) StateSnapshot(state map[string]any) Event { + evt := b.base(EventStateSnapshot) + evt["snapshot"] = state + return evt +} + +func (b EventBuilder) StateDelta(delta any) Event { + evt := b.base(EventStateDelta) + evt["delta"] = delta + return evt +} + +func (b EventBuilder) MessagesSnapshot(messages []UIMessage) Event { + evt := b.base(EventMessagesSnapshot) + evt["messages"] = messages + return evt +} + +func (b EventBuilder) Custom(name string, value any) Event { + evt := b.base(EventCustom) + evt["name"] = name + evt["value"] = value + return evt +} + +func TextPart(content string) MessagePart { + return MessagePart{"type": "text", "content": content} +} + +func ThinkingPart(content string) MessagePart { + return MessagePart{"type": "thinking", "content": content} +} + +func ToolCallPart(id, name string, arguments any, state string, approval *ToolApproval, output any) MessagePart { + part := MessagePart{"type": "tool-call", "id": id, "name": name, "arguments": arguments, "state": state} + if approval != nil { + part["approval"] = approval + } + if output != nil { + part["output"] = output + } + return part +} + +func ToolResultPart(toolCallID string, content any, state string, err any) MessagePart { + part := MessagePart{"type": "tool-result", "toolCallId": toolCallID, "content": content, "state": state} + if err != nil { + part["error"] = err + } + return part +} + +func ValidateEvent(evt Event) error { + eventType, _ := evt["type"].(string) + if eventType == "" { + return fmt.Errorf("ag-ui event missing type") + } + if _, ok := evt["timestamp"]; !ok { + return fmt.Errorf("%s missing timestamp", eventType) + } + switch eventType { + case EventRunStarted: + return require(evt, "threadId", "runId") + case EventRunFinished: + return require(evt, "threadId", "runId", "finishReason") + case EventRunError: + return require(evt, "message") + case EventTextMessageStart: + return require(evt, "messageId", "role") + case EventTextMessageContent: + return require(evt, "messageId", "delta") + case EventTextMessageEnd: + return require(evt, "messageId") + case EventReasoningStart, EventReasoningEnd, EventReasoningMsgStart, EventReasoningMsgEnd: + return require(evt, "messageId") + case EventReasoningMsgCont: + return require(evt, "messageId", "delta") + case EventToolCallStart: + if err := require(evt, "toolCallId", "toolCallName"); err != nil { + return err + } + if approval, ok := evt["approval"]; ok { + if err := validateToolApproval(approval); err != nil { + return fmt.Errorf("%s has invalid approval: %w", evt["type"], err) + } + } + return validateStringSet(evt, "state", true, validToolStates) + case EventToolCallArgs: + if err := require(evt, "toolCallId", "delta"); err != nil { + return err + } + if err := validateStringSet(evt, "state", false, validToolStates); err != nil { + return err + } + if args, ok := evt["args"]; ok { + if _, ok := args.(string); !ok { + return fmt.Errorf("%s has invalid args %T", evt["type"], args) + } + } + return nil + case EventToolCallEnd: + if err := require(evt, "toolCallId"); err != nil { + return err + } + if result, ok := evt["result"]; ok { + if _, ok := result.(string); !ok { + return fmt.Errorf("%s has invalid result %T", evt["type"], result) + } + } + return validateStringSet(evt, "state", true, validToolStates) + case EventToolCallResult: + if err := require(evt, "messageId", "toolCallId", "content"); err != nil { + return err + } + return validateStringSet(evt, "state", false, validToolResultStates) + case EventStepStarted, EventStepFinished: + return require(evt, "stepName") + case EventStateSnapshot: + return require(evt, "snapshot") + case EventStateDelta: + return require(evt, "delta") + case EventMessagesSnapshot: + return require(evt, "messages") + case EventCustom: + return require(evt, "name") + default: + return fmt.Errorf("unsupported ag-ui event type %q", eventType) + } +} + +func validateToolApproval(value any) error { + switch approval := value.(type) { + case ToolApproval: + if strings.TrimSpace(approval.ID) == "" { + return fmt.Errorf("missing id") + } + if !approval.NeedsApproval { + return fmt.Errorf("needsApproval must be true") + } + return nil + case *ToolApproval: + if approval == nil { + return fmt.Errorf("missing approval") + } + return validateToolApproval(*approval) + case map[string]any: + id, _ := approval["id"].(string) + if strings.TrimSpace(id) == "" { + return fmt.Errorf("missing id") + } + if approval["needsApproval"] != true { + return fmt.Errorf("needsApproval must be true") + } + return nil + default: + return fmt.Errorf("unexpected %T", value) + } +} + +func ValidateEventSequence(events []Event) error { + seenRunStart := false + terminal := false + textOpen := map[string]bool{} + reasoningOpen := map[string]bool{} + toolStarted := map[string]bool{} + toolEnded := map[string]bool{} + + for i, evt := range events { + if err := ValidateEvent(evt); err != nil { + return fmt.Errorf("event %d: %w", i+1, err) + } + eventType, _ := evt["type"].(string) + if terminal { + return fmt.Errorf("event %d: %s after terminal run event", i+1, eventType) + } + + switch eventType { + case EventRunStarted: + if seenRunStart { + return fmt.Errorf("event %d: duplicate RUN_STARTED", i+1) + } + seenRunStart = true + case EventRunFinished: + if !seenRunStart { + return fmt.Errorf("event %d: RUN_FINISHED before RUN_STARTED", i+1) + } + terminal = true + case EventRunError: + terminal = true + case EventTextMessageStart: + messageID := stringField(evt, "messageId") + if textOpen[messageID] { + return fmt.Errorf("event %d: duplicate TEXT_MESSAGE_START for %s", i+1, messageID) + } + textOpen[messageID] = true + case EventTextMessageContent: + messageID := stringField(evt, "messageId") + if !textOpen[messageID] { + return fmt.Errorf("event %d: TEXT_MESSAGE_CONTENT before TEXT_MESSAGE_START for %s", i+1, messageID) + } + case EventTextMessageEnd: + messageID := stringField(evt, "messageId") + if !textOpen[messageID] { + return fmt.Errorf("event %d: TEXT_MESSAGE_END before TEXT_MESSAGE_START for %s", i+1, messageID) + } + delete(textOpen, messageID) + case EventReasoningMsgStart: + messageID := stringField(evt, "messageId") + if reasoningOpen[messageID] { + return fmt.Errorf("event %d: duplicate REASONING_MESSAGE_START for %s", i+1, messageID) + } + reasoningOpen[messageID] = true + case EventReasoningMsgCont: + messageID := stringField(evt, "messageId") + if !reasoningOpen[messageID] { + return fmt.Errorf("event %d: REASONING_MESSAGE_CONTENT before REASONING_MESSAGE_START for %s", i+1, messageID) + } + case EventReasoningMsgEnd: + messageID := stringField(evt, "messageId") + if !reasoningOpen[messageID] { + return fmt.Errorf("event %d: REASONING_MESSAGE_END before REASONING_MESSAGE_START for %s", i+1, messageID) + } + delete(reasoningOpen, messageID) + case EventToolCallStart: + toolCallID := stringField(evt, "toolCallId") + if toolStarted[toolCallID] { + return fmt.Errorf("event %d: duplicate TOOL_CALL_START for %s", i+1, toolCallID) + } + toolStarted[toolCallID] = true + case EventToolCallArgs: + toolCallID := stringField(evt, "toolCallId") + if !toolStarted[toolCallID] { + return fmt.Errorf("event %d: TOOL_CALL_ARGS before TOOL_CALL_START for %s", i+1, toolCallID) + } + case EventToolCallEnd: + toolCallID := stringField(evt, "toolCallId") + if !toolStarted[toolCallID] { + return fmt.Errorf("event %d: TOOL_CALL_END before TOOL_CALL_START for %s", i+1, toolCallID) + } + if toolEnded[toolCallID] { + return fmt.Errorf("event %d: duplicate TOOL_CALL_END for %s", i+1, toolCallID) + } + toolEnded[toolCallID] = true + case EventToolCallResult: + toolCallID := stringField(evt, "toolCallId") + if !toolStarted[toolCallID] { + return fmt.Errorf("event %d: TOOL_CALL_RESULT before TOOL_CALL_START for %s", i+1, toolCallID) + } + } + } + return nil +} + +var validToolStates = map[string]bool{ + ToolStateAwaitingInput: true, + ToolStateInputStreaming: true, + ToolStateInputComplete: true, + ToolStateApprovalRequested: true, + ToolStateApprovalResponded: true, +} + +func stringField(evt Event, key string) string { + value, _ := evt[key].(string) + return value +} + +var validToolResultStates = map[string]bool{ + ToolResultStateStreaming: true, + ToolResultStateComplete: true, + ToolResultStateError: true, +} + +func validateStringSet(evt Event, key string, required bool, allowed map[string]bool) error { + value, ok := evt[key] + if !ok || value == nil { + if required { + return fmt.Errorf("%s missing %s", evt["type"], key) + } + return nil + } + stringValue, ok := value.(string) + if !ok || !allowed[stringValue] { + return fmt.Errorf("%s has invalid %s %q", evt["type"], key, value) + } + return nil +} + +func NormalizeFinishReason(value string) string { + switch strings.TrimSpace(strings.ToLower(value)) { + case "", FinishReasonStop: + return FinishReasonStop + case FinishReasonLength: + return FinishReasonLength + case "content-filter", "contentfilter", FinishReasonContentFilter: + return FinishReasonContentFilter + case "tool-calls", "toolcalls", FinishReasonToolCalls: + return FinishReasonToolCalls + case FinishReasonOther: + return FinishReasonOther + default: + return FinishReasonStop + } +} + +func CloneEvent(evt Event) Event { + raw, err := json.Marshal(evt) + if err != nil { + cp := make(Event, len(evt)) + for k, v := range evt { + cp[k] = v + } + return cp + } + var cp Event + if err := json.Unmarshal(raw, &cp); err != nil { + cp = make(Event, len(evt)) + for k, v := range evt { + cp[k] = v + } + } + return cp +} + +func require(evt Event, keys ...string) error { + for _, key := range keys { + value, ok := evt[key] + if !ok || emptyValue(value) { + return fmt.Errorf("%s missing %s", evt["type"], key) + } + } + return nil +} + +func emptyValue(value any) bool { + switch v := value.(type) { + case string: + return strings.TrimSpace(v) == "" + case nil: + return true + default: + return false + } +} diff --git a/pkg/ag-ui/events_test.go b/pkg/ag-ui/events_test.go new file mode 100644 index 0000000..f05d019 --- /dev/null +++ b/pkg/ag-ui/events_test.go @@ -0,0 +1,144 @@ +package agui + +import ( + "testing" + "time" +) + +func TestBuildersCoverLifecycleEventsWithTimestamps(t *testing.T) { + now := func() time.Time { return time.Unix(10, 0) } + builder := NewEventBuilder("dummy/model", now) + idx := 1 + events := []Event{ + builder.RunStarted("thread", "run"), + builder.RunFinished("thread", "run", "tool-calls", Usage{PromptTokens: 1, CompletionTokens: 2, TotalTokens: 3}), + builder.RunError("thread", "run", "failed"), + builder.TextMessageStart("msg", RoleAssistant), + builder.TextMessageContent("msg", "hello"), + builder.TextMessageEnd("msg"), + builder.ReasoningStart("msg"), + builder.ReasoningMessageStart("msg"), + builder.ReasoningMessageContent("msg", "thinking"), + builder.ReasoningMessageEnd("msg"), + builder.ReasoningEnd("msg"), + builder.ToolCallStart("msg", "tool", "search", &idx, &ToolApproval{ID: "approval", NeedsApproval: true}), + builder.ToolCallArgs("tool", `{"q":"he`, nil), + builder.ToolCallEnd("tool", "search", map[string]any{"q": "hello"}, `{"ok":true}`, ToolStateInputComplete), + builder.ToolCallResult("msg", "tool", `{"ok":true}`, ToolResultStateComplete, RoleTool), + builder.StepStarted("msg", "step"), + builder.StepFinished("msg", "step"), + builder.StateSnapshot(map[string]any{"open": true}), + builder.StateDelta(map[string]any{"path": "/open", "value": false}), + builder.MessagesSnapshot([]UIMessage{{ID: "msg", Role: RoleAssistant, Parts: []MessagePart{TextPart("hello")}}}), + builder.Custom("com.beeper.test", map[string]any{"ok": true}), + } + for _, evt := range events { + if err := ValidateEvent(evt); err != nil { + t.Fatalf("ValidateEvent(%s) returned error: %v", evt["type"], err) + } + if evt["timestamp"] == nil { + t.Fatalf("event missing timestamp: %#v", evt) + } + } + if got := events[1]["finishReason"]; got != FinishReasonToolCalls { + t.Fatalf("finish reason = %q, want %q", got, FinishReasonToolCalls) + } + if got := events[2]["message"]; got != "failed" { + t.Fatalf("run error message = %#v, want failed", got) + } + toolStart := events[11] + if got := toolStart["index"]; got != 1 { + t.Fatalf("tool index = %#v, want 1", got) + } + if got := toolStart["parentMessageId"]; got != "msg" { + t.Fatalf("tool parentMessageId = %#v, want msg", got) + } + if _, hasMessageID := toolStart["messageId"]; hasMessageID { + t.Fatalf("tool start should not emit deprecated messageId: %#v", toolStart) + } + if _, hasSnapshot := events[17]["snapshot"]; !hasSnapshot { + t.Fatalf("state snapshot should emit snapshot field: %#v", events[17]) + } +} + +func TestValidateRejectsBadEvents(t *testing.T) { + tests := []Event{ + {}, + {"type": EventRunStarted, "timestamp": int64(1), "threadId": "thread"}, + {"type": EventRunError, "timestamp": int64(1), "threadId": "thread", "error": map[string]any{"message": "failed"}}, + {"type": EventTextMessageContent, "timestamp": int64(1), "messageId": "msg"}, + {"type": "REASONING_MESSAGE_CONTENT", "timestamp": int64(1)}, + {"type": EventToolCallStart, "timestamp": int64(1), "toolCallId": "tool", "toolCallName": "search", "state": "output-available"}, + {"type": EventToolCallStart, "timestamp": int64(1), "toolCallId": "tool", "toolCallName": "search", "state": ToolStateApprovalRequested, "approval": ToolApproval{ID: "", NeedsApproval: true}}, + {"type": EventToolCallStart, "timestamp": int64(1), "toolCallId": "tool", "toolCallName": "search", "state": ToolStateApprovalRequested, "approval": map[string]any{"id": "approval", "needsApproval": false}}, + {"type": EventToolCallArgs, "timestamp": int64(1), "toolCallId": "tool", "delta": "{}", "args": map[string]any{"bad": true}}, + {"type": EventToolCallEnd, "timestamp": int64(1), "toolCallId": "tool", "result": map[string]any{"bad": true}, "state": ToolStateInputComplete}, + {"type": EventToolCallResult, "timestamp": int64(1), "messageId": "msg", "toolCallId": "tool", "content": "{}", "state": "output-error"}, + {"type": EventStepStarted, "timestamp": int64(1), "stepId": "deprecated-only"}, + {"type": EventStateSnapshot, "timestamp": int64(1), "state": map[string]any{}}, + } + for _, evt := range tests { + if err := ValidateEvent(evt); err == nil { + t.Fatalf("expected validation error for %#v", evt) + } + } +} + +func TestValidateEventSequenceRejectsBadOrdering(t *testing.T) { + now := func() time.Time { return time.Unix(10, 0) } + builder := NewEventBuilder("dummy/model", now) + + valid := []Event{ + builder.RunStarted("thread", "run"), + builder.TextMessageStart("msg", RoleAssistant), + builder.TextMessageContent("msg", "hello"), + builder.TextMessageEnd("msg"), + builder.ToolCallStart("msg", "tool", "search", nil, nil), + builder.ToolCallArgs("tool", `{"q":"hello"}`, `{"q":"hello"}`), + builder.ToolCallEnd("tool", "search", map[string]any{"q": "hello"}, `{"ok":true}`, ToolStateInputComplete), + builder.RunFinished("thread", "run", FinishReasonStop, Usage{}), + } + if err := ValidateEventSequence(valid); err != nil { + t.Fatalf("ValidateEventSequence(valid) returned error: %v", err) + } + + tests := [][]Event{ + {builder.TextMessageContent("msg", "hello")}, + {builder.ReasoningMessageContent("msg", "thinking")}, + {builder.ToolCallArgs("tool", "{}", "{}")}, + {builder.ToolCallResult("msg", "tool", "{}", ToolResultStateComplete, RoleTool)}, + { + builder.RunStarted("thread", "run"), + builder.RunFinished("thread", "run", FinishReasonStop, Usage{}), + builder.TextMessageStart("msg", RoleAssistant), + }, + } + for _, events := range tests { + if err := ValidateEventSequence(events); err == nil { + t.Fatalf("expected ordering error for %#v", events) + } + } +} + +func TestRunAgentInputModelsBidirectionalShape(t *testing.T) { + input := RunAgentInput{ + ThreadID: "thread", + RunID: "run", + State: map[string]any{"open": true}, + Messages: []UIMessage{{ + ID: "msg", + Role: RoleUser, + Parts: []MessagePart{TextPart("hello")}, + }}, + Tools: []Tool{{Name: "send_email", NeedsApproval: true}}, + Context: []ContextItem{{ + Type: "beeper-room", + Value: "room", + }}, + ForwardedProps: map[string]any{"trace": "abc"}, + Data: map[string]any{"legacy": true}, + } + if input.ThreadID != "thread" || !input.Tools[0].NeedsApproval || input.ForwardedProps["trace"] != "abc" { + t.Fatalf("bad RunAgentInput shape: %#v", input) + } +} diff --git a/pkg/ai-stream/approval.go b/pkg/ai-stream/approval.go new file mode 100644 index 0000000..5821344 --- /dev/null +++ b/pkg/ai-stream/approval.go @@ -0,0 +1,194 @@ +package aistream + +import ( + "strings" + "time" + + "github.com/beeper/dummybridge/pkg/ag-ui" +) + +const ( + ApprovalReactionAllowOnce = "approval.allow_once" + ApprovalReactionAllowAlways = "approval.allow_always" + ApprovalReactionDeny = "approval.deny" +) + +type ReactionOption[T any] struct { + ID string `json:"id"` + Label string `json:"label"` + Values []string `json:"values"` + Value T `json:"value"` +} + +type ApprovalCleanup[T any] struct { + Selected ReactionOption[T] + SelectedReactionEvent string + RedactReactionEvents []string + Matched bool +} + +type ReactionEvent struct { + EventID string + Sender string + Key string + Bridge bool +} + +type ApprovalContext struct { + ID string `json:"id"` + ThreadID string `json:"threadId"` + RunID string `json:"runId"` + MessageID string `json:"messageId"` + ToolCallID string `json:"toolCallId"` + ToolName string `json:"toolName"` + TargetEvent string `json:"target_event"` + AgentID string `json:"agentId,omitempty"` + AgentName string `json:"agentName,omitempty"` + Model string `json:"model,omitempty"` + SeqStart int `json:"seqStart,omitempty"` +} + +func DefaultApprovalOptions(approvalID string) []ReactionOption[agui.ToolApprovalResponse] { + return []ReactionOption[agui.ToolApprovalResponse]{ + { + ID: ApprovalReactionAllowOnce, + Label: "Allow", + Values: []string{"👍", "approval.allow_once", "allow", "allow_once"}, + Value: agui.ToolApprovalResponse{ID: approvalID, Approved: true}, + }, + { + ID: ApprovalReactionAllowAlways, + Label: "Always allow", + Values: []string{"✅", "approval.allow_always", "always", "allow_always"}, + Value: agui.ToolApprovalResponse{ID: approvalID, Approved: true, Always: true}, + }, + { + ID: ApprovalReactionDeny, + Label: "Deny", + Values: []string{"👎", "approval.deny", "deny", "reject"}, + Value: agui.ToolApprovalResponse{ID: approvalID, Approved: false, Reason: "denied"}, + }, + } +} + +func ResolveReaction[T any](options []ReactionOption[T], raw string) (ReactionOption[T], bool) { + key := NormalizeReaction(raw) + for _, option := range options { + if NormalizeReaction(option.ID) == key { + return option, true + } + for _, value := range option.Values { + if NormalizeReaction(value) == key { + return option, true + } + } + } + var zero ReactionOption[T] + return zero, false +} + +func CleanupReactions[T any](options []ReactionOption[T], selectedKey string, events []ReactionEvent, bridgeSender string) ApprovalCleanup[T] { + selected, ok := ResolveReaction(options, selectedKey) + if !ok { + return ApprovalCleanup[T]{} + } + cleanup := ApprovalCleanup[T]{Selected: selected, Matched: true} + for _, evt := range events { + if evt.EventID == "" { + continue + } + option, matchesOption := ResolveReaction(options, evt.Key) + isSelected := matchesOption && option.ID == selected.ID + isBridge := evt.Bridge || (bridgeSender != "" && evt.Sender == bridgeSender) + if isSelected && !isBridge && cleanup.SelectedReactionEvent == "" { + cleanup.SelectedReactionEvent = evt.EventID + continue + } + if isBridge || (matchesOption && !isSelected) { + cleanup.RedactReactionEvents = append(cleanup.RedactReactionEvents, evt.EventID) + } + } + return cleanup +} + +func NormalizeReaction(reaction string) string { + reaction = strings.TrimSpace(reaction) + reaction = strings.ReplaceAll(reaction, "\ufe0f", "") + return strings.ToLower(reaction) +} + +func ApprovalResponseRun(ctx ApprovalContext, response agui.ToolApprovalResponse, now time.Time) Run { + if response.ID == "" { + response.ID = ctx.ID + } + agentID := ctx.AgentID + if agentID == "" { + agentID = "ai" + } + agentName := ctx.AgentName + if agentName == "" { + agentName = "AI" + } + model := ctx.Model + if model == "" { + model = DefaultModel + } + run := NewRun("approval-"+ctx.ID, ctx.ThreadID, model, agentID, agentName, now) + run.RunID = ctx.RunID + run.MessageID = ctx.MessageID + run.ToolCallID = ctx.ToolCallID + run.ApprovalID = ctx.ID + run.Status = Status{State: "complete"} + run.Approvals = []ApprovalSummary{{ + ID: ctx.ID, + ToolCallID: ctx.ToolCallID, + State: approvalSummaryState(response), + Always: response.Always, + Reason: response.Reason, + Fields: response.Fields, + Metadata: response.Metadata, + }} + builder := agui.NewEventBuilder(model, func() time.Time { return now }) + run.Events = append(run.Events, builder.Custom(agui.ApprovalCustomResponded, map[string]any{ + "threadId": ctx.ThreadID, + "runId": ctx.RunID, + "messageId": ctx.MessageID, + "toolCallId": ctx.ToolCallID, + "toolName": ctx.ToolName, + "approval": response, + })) + result := map[string]any{ + "approvalId": response.ID, + "always": response.Always, + } + if response.Fields != nil { + result["fields"] = response.Fields + } + if response.Metadata != nil { + result["metadata"] = response.Metadata + } + if response.Approved { + result["state"] = agui.ToolResultStateComplete + result["approved"] = true + } else { + reason := response.Reason + if reason == "" { + reason = "denied" + } + result["state"] = agui.ToolResultStateError + result["reason"] = reason + run.Status = Status{State: "error", Error: result} + } + run.Events = append(run.Events, builder.ToolCallEnd(ctx.ToolCallID, ctx.ToolName, nil, jsonString(result), agui.ToolStateApprovalResponded)) + return *run +} + +func approvalSummaryState(response agui.ToolApprovalResponse) string { + if response.Approved { + if response.Always { + return "approved-always" + } + return "approved" + } + return "denied" +} diff --git a/pkg/ai-stream/bridgev2/events.go b/pkg/ai-stream/bridgev2/events.go new file mode 100644 index 0000000..5f5c7b9 --- /dev/null +++ b/pkg/ai-stream/bridgev2/events.go @@ -0,0 +1,133 @@ +package aibridgev2 + +import ( + "context" + "time" + + aistream "github.com/beeper/dummybridge/pkg/ai-stream" + aimatrix "github.com/beeper/dummybridge/pkg/ai-stream/matrix" + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/database" + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/bridgev2/simplevent" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +func Anchor(portalKey networkid.PortalKey, sender networkid.UserID, run aistream.Run, timestamp time.Time) *simplevent.PreConvertedMessage { + content, extra := aimatrix.AnchorContent(run) + return &simplevent.PreConvertedMessage{ + EventMeta: simplevent.EventMeta{ + Type: bridgev2.RemoteEventMessage, + PortalKey: portalKey, + Sender: bridgev2.EventSender{Sender: sender}, + Timestamp: timestamp, + StreamOrder: timestamp.UnixNano(), + }, + Data: &bridgev2.ConvertedMessage{Parts: []*bridgev2.ConvertedMessagePart{{ + ID: networkid.PartID("0"), + Type: event.EventMessage, + Content: content, + Extra: extra, + }}}, + ID: networkid.MessageID(run.MessageID), + } +} + +func Carrier(portalKey networkid.PortalKey, sender networkid.UserID, run aistream.Run, carrier aistream.Carrier, targetEventID id.EventID, index int, timestamp time.Time) *simplevent.PreConvertedMessage { + content, extra := aimatrix.CarrierContent(carrier, targetEventID) + return &simplevent.PreConvertedMessage{ + EventMeta: simplevent.EventMeta{ + Type: bridgev2.RemoteEventMessage, + PortalKey: portalKey, + Sender: bridgev2.EventSender{Sender: sender}, + Timestamp: timestamp, + StreamOrder: timestamp.UnixNano(), + }, + Data: &bridgev2.ConvertedMessage{Parts: []*bridgev2.ConvertedMessagePart{{ + ID: networkid.PartID("0"), + Type: event.EventMessage, + Content: content, + Extra: extra, + }}}, + ID: networkid.MessageID(aistream.StreamTxnID(run.RunID, index)), + } +} + +func ApprovalPrompt(portalKey networkid.PortalKey, sender networkid.UserID, ctx aistream.ApprovalContext, timestamp time.Time) *simplevent.PreConvertedMessage { + content, extra := aimatrix.ApprovalContent(ctx, aistream.DefaultApprovalOptions(ctx.ID)) + return &simplevent.PreConvertedMessage{ + EventMeta: simplevent.EventMeta{ + Type: bridgev2.RemoteEventMessage, + PortalKey: portalKey, + Sender: bridgev2.EventSender{Sender: sender}, + Timestamp: timestamp, + StreamOrder: timestamp.UnixNano(), + }, + Data: &bridgev2.ConvertedMessage{Parts: []*bridgev2.ConvertedMessagePart{{ + ID: networkid.PartID("0"), + Type: event.EventMessage, + Content: content, + Extra: extra, + DBMetadata: map[string]any{ + "com.beeper.ai.approval": ctx, + }, + }}}, + ID: networkid.MessageID(ctx.ID), + } +} + +func ApprovalOptionReaction[T any](portalKey networkid.PortalKey, sender networkid.UserID, ctx aistream.ApprovalContext, option aistream.ReactionOption[T], timestamp time.Time) *simplevent.Reaction { + return &simplevent.Reaction{ + EventMeta: simplevent.EventMeta{ + Type: bridgev2.RemoteEventReaction, + PortalKey: portalKey, + Sender: bridgev2.EventSender{Sender: sender}, + Timestamp: timestamp, + StreamOrder: timestamp.UnixNano(), + }, + TargetMessage: networkid.MessageID(ctx.ID), + EmojiID: networkid.EmojiID(option.ID), + Emoji: option.Values[0], + ExtraContent: map[string]any{ + "com.beeper.ai.approval_option": map[string]any{ + "approvalId": ctx.ID, + "toolCallId": ctx.ToolCallID, + "optionId": option.ID, + "value": option.Value, + }, + }, + } +} + +func FinalMetadataEdit(portalKey networkid.PortalKey, sender networkid.UserID, messageID networkid.MessageID, run aistream.Run, timestamp time.Time) *simplevent.Message[*aistream.Run] { + finalContent, finalExtra := aimatrix.AnchorContent(run) + return &simplevent.Message[*aistream.Run]{ + EventMeta: simplevent.EventMeta{ + Type: bridgev2.RemoteEventEdit, + PortalKey: portalKey, + Sender: bridgev2.EventSender{Sender: sender}, + Timestamp: timestamp, + StreamOrder: timestamp.UnixNano(), + }, + Data: &run, + ID: messageID, + TargetMessage: messageID, + ConvertEditFunc: func(ctx context.Context, portal *bridgev2.Portal, intent bridgev2.MatrixAPI, existing []*database.Message, data *aistream.Run) (*bridgev2.ConvertedEdit, error) { + if len(existing) == 0 { + return nil, nil + } + return &bridgev2.ConvertedEdit{ + ModifiedParts: []*bridgev2.ConvertedEditPart{{ + Part: existing[0], + Type: event.EventMessage, + Content: finalContent, + Extra: finalExtra, + TopLevelExtra: map[string]any{ + "com.beeper.dont_render_edited": true, + }, + }}, + }, nil + }, + } +} diff --git a/pkg/ai-stream/bridgev2/events_test.go b/pkg/ai-stream/bridgev2/events_test.go new file mode 100644 index 0000000..2218aa3 --- /dev/null +++ b/pkg/ai-stream/bridgev2/events_test.go @@ -0,0 +1,100 @@ +package aibridgev2 + +import ( + "strings" + "testing" + "time" + + aistream "github.com/beeper/dummybridge/pkg/ai-stream" + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +func TestBridgeV2AIEvents(t *testing.T) { + now := time.Unix(10, 0) + run := aistream.NewRun("run-1", "thread-1", "", "ai", "AI", now) + run.Preview = aistream.Preview{Text: "visible preview"} + + anchor := Anchor( + networkid.PortalKey{ID: "portal-1"}, + networkid.UserID("ai"), + *run, + now, + ) + if anchor.Type != bridgev2.RemoteEventMessage { + t.Fatalf("anchor type = %v", anchor.Type) + } + part := anchor.Data.Parts[0] + if part.Type != event.EventMessage || part.Content.Body != "visible preview" { + t.Fatalf("unexpected anchor part: %#v", part) + } + if part.Extra[aistream.BeeperAIKey] == nil || part.Extra[aistream.BeeperAIMetadataKey] == nil { + t.Fatalf("anchor missing AI metadata: %#v", part.Extra) + } + stream, ok := part.Extra["com.beeper.stream"].(map[string]any) + if !ok || stream["type"] != aistream.BeeperAIStreamDeltas { + t.Fatalf("anchor missing stream descriptor: %#v", part.Extra) + } + + carrier := Carrier( + networkid.PortalKey{ID: "portal-1"}, + networkid.UserID("ai"), + *run, + aistream.Carrier{Envelopes: []aistream.Envelope{{ + Seq: 1, + RunID: run.RunID, + ThreadID: run.ThreadID, + TargetEvent: "$anchor", + }}}, + id.EventID("$anchor"), + 1, + now, + ) + carrierPart := carrier.Data.Parts[0] + if carrierPart.Content.MsgType != event.MsgText || carrierPart.Content.Body != "" { + t.Fatalf("carrier should be hidden text carrier: %#v", carrierPart.Content) + } + if carrierPart.Extra[aistream.BeeperAIStreamDeltas] == nil { + t.Fatalf("carrier missing deltas: %#v", carrierPart.Extra) + } + + approval := ApprovalPrompt(networkid.PortalKey{ID: "portal-1"}, networkid.UserID("ai"), aistream.ApprovalContext{ + ID: "approval-1", + ThreadID: run.ThreadID, + RunID: run.RunID, + MessageID: run.MessageID, + ToolCallID: "tool-1", + ToolName: "dummy_echo", + TargetEvent: "$anchor", + }, now) + approvalPart := approval.Data.Parts[0] + approvalMetadata, ok := approvalPart.DBMetadata.(map[string]any) + if !ok || approvalMetadata["com.beeper.ai.approval"] == nil { + t.Fatalf("approval missing DB metadata: %#v", approvalPart.DBMetadata) + } +} + +func TestFinalMetadataEditUsesCompactAnchorContent(t *testing.T) { + now := time.Unix(10, 0) + run := aistream.NewRun("run-1", "thread-1", "", "ai", "AI", now) + run.Preview = aistream.Preview{Text: strings.Repeat("a", aistream.PreviewBudgetBytes+1), Truncated: true} + + edit := FinalMetadataEdit( + networkid.PortalKey{ID: "portal-1"}, + networkid.UserID("ai"), + networkid.MessageID(run.MessageID), + *run, + now, + ) + if edit.Type != bridgev2.RemoteEventEdit { + t.Fatalf("final metadata event type = %v", edit.Type) + } + if edit.TargetMessage != networkid.MessageID(run.MessageID) { + t.Fatalf("final metadata target = %q", edit.TargetMessage) + } + if edit.Data.Text() != "" { + t.Fatalf("final metadata edit must not expose full accumulated text") + } +} diff --git a/pkg/ai-stream/matrix/content.go b/pkg/ai-stream/matrix/content.go new file mode 100644 index 0000000..e0a9e88 --- /dev/null +++ b/pkg/ai-stream/matrix/content.go @@ -0,0 +1,90 @@ +package matrix + +import ( + "fmt" + + "github.com/beeper/dummybridge/pkg/ag-ui" + "github.com/beeper/dummybridge/pkg/ai-stream" + "maunium.net/go/mautrix/format" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +func AnchorContent(run aistream.Run) (*event.MessageEventContent, map[string]any) { + body := run.Preview.Text + if body == "" { + body = "..." + } + rendered := format.RenderMarkdown(body, true, false) + content := &rendered + content.BeeperPerMessageProfile = &event.BeeperPerMessageProfile{ + ID: run.AgentID, + Displayname: run.AgentName, + } + extra := map[string]any{ + aistream.BeeperAIKey: run.InitialUIMessage(), + aistream.BeeperAIMetadataKey: run.Metadata(), + "com.beeper.stream": map[string]any{ + "type": aistream.BeeperAIStreamDeltas, + }, + } + return content, extra +} + +func CarrierContent(carrier aistream.Carrier, targetEventID id.EventID) (*event.MessageEventContent, map[string]any) { + content := &event.MessageEventContent{ + MsgType: event.MsgText, + Body: "", + Mentions: &event.Mentions{}, + RelatesTo: &event.RelatesTo{ + Type: event.RelReference, + EventID: targetEventID, + }, + } + return content, aistream.CarrierContent(carrier.Envelopes) +} + +func ApprovalContent(ctx aistream.ApprovalContext, options []aistream.ReactionOption[agui.ToolApprovalResponse]) (*event.MessageEventContent, map[string]any) { + toolName := ctx.ToolName + body := fmt.Sprintf("Approval required for %s", toolName) + if len(options) > 0 { + body += "\nReact with one of the listed options." + } + content := &event.MessageEventContent{ + MsgType: event.MsgText, + Body: body, + Mentions: &event.Mentions{}, + } + if ctx.TargetEvent != "" { + content.RelatesTo = &event.RelatesTo{Type: event.RelReference, EventID: id.EventID(ctx.TargetEvent)} + } + extra := map[string]any{ + "com.beeper.ai.approval": map[string]any{ + "id": ctx.ID, + "toolCallId": ctx.ToolCallID, + "toolName": toolName, + "threadId": ctx.ThreadID, + "runId": ctx.RunID, + "messageId": ctx.MessageID, + "approval": agui.ToolApproval{ + ID: ctx.ID, + NeedsApproval: true, + }, + "reactions": ReactionOptionsAsAny(options), + }, + } + return content, extra +} + +func ReactionOptionsAsAny(options []aistream.ReactionOption[agui.ToolApprovalResponse]) []any { + out := make([]any, 0, len(options)) + for _, option := range options { + out = append(out, map[string]any{ + "id": option.ID, + "label": option.Label, + "values": option.Values, + "value": option.Value, + }) + } + return out +} diff --git a/pkg/ai-stream/matrix/content_test.go b/pkg/ai-stream/matrix/content_test.go new file mode 100644 index 0000000..467c59a --- /dev/null +++ b/pkg/ai-stream/matrix/content_test.go @@ -0,0 +1,131 @@ +package matrix + +import ( + "strings" + "testing" + "time" + + "github.com/beeper/dummybridge/pkg/ag-ui" + "github.com/beeper/dummybridge/pkg/ai-stream" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +func TestAnchorContentUsesVisibleTextAndAIProfile(t *testing.T) { + run := aistream.NewRun("run-1", "thread-1", aistream.DefaultModel, "ai", "AI", time.Unix(10, 0)) + run.Preview = aistream.Preview{Text: "visible preview"} + + content, extra := AnchorContent(*run) + if content.MsgType != event.MsgText || content.Body != "visible preview" { + t.Fatalf("bad anchor content: %#v", content) + } + if content.Format != event.FormatHTML || content.FormattedBody == "" { + t.Fatalf("anchor preview should include Matrix HTML: %#v", content) + } + if content.BeeperPerMessageProfile == nil || content.BeeperPerMessageProfile.ID != "ai" || content.BeeperPerMessageProfile.Displayname != "AI" { + t.Fatalf("missing AI per-message profile: %#v", content.BeeperPerMessageProfile) + } + uiMessage, ok := extra[aistream.BeeperAIKey].(map[string]any) + if !ok || uiMessage["id"] == "" || uiMessage["metadata"] != nil { + t.Fatalf("bad compact AI message: %#v", extra[aistream.BeeperAIKey]) + } + if extra[aistream.BeeperAIMetadataKey] == nil { + t.Fatalf("missing AI metadata: %#v", extra) + } + stream, ok := extra["com.beeper.stream"].(map[string]any) + if !ok || stream["user_id"] != nil || stream["type"] != aistream.BeeperAIStreamDeltas { + t.Fatalf("missing stream descriptor: %#v", extra["com.beeper.stream"]) + } +} + +func TestAnchorContentKeepsLongRunsCompact(t *testing.T) { + run := aistream.NewRun("run-1", "thread-1", aistream.DefaultModel, "ai", "AI", time.Unix(10, 0)) + writer := aistream.NewWriter(run, func() time.Time { return time.Unix(10, 0) }) + writer.Start() + writer.Text(strings.Repeat("a", 70*1024)) + writer.Finish(agui.FinishReasonStop) + + content, extra := AnchorContent(*run) + if len(content.Body) > aistream.PreviewBudgetBytes { + t.Fatalf("anchor body length = %d, want <= %d", len(content.Body), aistream.PreviewBudgetBytes) + } + metadata := extra[aistream.BeeperAIMetadataKey].(map[string]any) + if _, hasParts := metadata["parts"]; hasParts { + t.Fatalf("metadata must not contain streamed parts: %#v", metadata) + } + if _, hasChunks := metadata["chunks"]; hasChunks { + t.Fatalf("metadata must not contain streamed chunks: %#v", metadata) + } + preview := metadata["preview"].(aistream.Preview) + if !preview.Truncated || len(preview.Text) > aistream.PreviewBudgetBytes { + t.Fatalf("bad bounded preview: %#v", preview) + } +} + +func TestAnchorContentRendersFinalPreviewAsMatrixHTML(t *testing.T) { + run := aistream.NewRun("run-1", "thread-1", aistream.DefaultModel, "ai", "AI", time.Unix(10, 0)) + run.Preview = aistream.Preview{Text: "Use **bold** and `code`"} + + content, _ := AnchorContent(*run) + if content.Format != event.FormatHTML { + t.Fatalf("format = %q, want Matrix HTML", content.Format) + } + if !strings.Contains(content.FormattedBody, "bold") || !strings.Contains(content.FormattedBody, "code") { + t.Fatalf("formatted body did not render markdown: %q", content.FormattedBody) + } +} + +func TestCarrierContentIsHiddenTextCarrierWithDeltas(t *testing.T) { + carrier := aistream.Carrier{Envelopes: []aistream.Envelope{{ + ThreadID: "thread-1", + RunID: "run-1", + MessageID: "msg-run-1", + Seq: 1, + TargetEvent: "$anchor", + }}} + + content, extra := CarrierContent(carrier, id.EventID("$anchor")) + if content.MsgType != event.MsgText || content.Body != "" { + t.Fatalf("carrier should be empty m.text, got %#v", content) + } + if content.RelatesTo == nil || content.RelatesTo.EventID != "$anchor" { + t.Fatalf("carrier should reference anchor, got %#v", content.RelatesTo) + } + deltas, ok := extra[aistream.BeeperAIStreamDeltas].([]aistream.Envelope) + if !ok || len(deltas) != 1 || deltas[0].Seq != 1 { + t.Fatalf("missing deltas: %#v", extra) + } +} + +func TestApprovalContentIncludesContextAndGenericReactionOptions(t *testing.T) { + ctx := aistream.ApprovalContext{ + ID: "approval-1", + ThreadID: "thread-1", + RunID: "run-1", + MessageID: "msg-run-1", + ToolCallID: "tool-1", + ToolName: "shell", + TargetEvent: "$anchor", + } + options := aistream.DefaultApprovalOptions(ctx.ID) + + content, extra := ApprovalContent(ctx, options) + if content.MsgType != event.MsgText || content.RelatesTo == nil || content.RelatesTo.EventID != "$anchor" { + t.Fatalf("bad approval content: %#v", content) + } + meta, ok := extra["com.beeper.ai.approval"].(map[string]any) + if !ok { + t.Fatalf("missing approval metadata: %#v", extra) + } + if meta["id"] != ctx.ID || meta["runId"] != ctx.RunID || meta["messageId"] != ctx.MessageID || meta["toolCallId"] != ctx.ToolCallID { + t.Fatalf("bad approval metadata: %#v", meta) + } + reactions, ok := meta["reactions"].([]any) + if !ok || len(reactions) != len(options) { + t.Fatalf("bad approval reactions: %#v", meta["reactions"]) + } + first := reactions[0].(map[string]any) + if first["id"] != aistream.ApprovalReactionAllowOnce { + t.Fatalf("bad first reaction option: %#v", first) + } +} diff --git a/pkg/ai-stream/pack.go b/pkg/ai-stream/pack.go new file mode 100644 index 0000000..011291d --- /dev/null +++ b/pkg/ai-stream/pack.go @@ -0,0 +1,236 @@ +package aistream + +import ( + "encoding/json" + "fmt" + "strings" + + "github.com/beeper/dummybridge/pkg/ag-ui" +) + +type Envelope struct { + ThreadID string `json:"threadId"` + RunID string `json:"runId"` + MessageID string `json:"messageId"` + Seq int `json:"seq"` + Part agui.Event `json:"part"` + TargetEvent string `json:"target_event,omitempty"` + RelatesTo Relation `json:"m.relates_to,omitempty"` + AgentID string `json:"agent_id,omitempty"` +} + +type Relation struct { + Type string `json:"rel_type"` + EventID string `json:"event_id"` +} + +type Carrier struct { + Envelopes []Envelope +} + +func BuildEnvelope(run Run, seq int, part agui.Event, targetEventID string) (Envelope, error) { + if seq <= 0 { + return Envelope{}, fmt.Errorf("stream envelope: seq must be > 0") + } + if err := agui.ValidateEvent(part); err != nil { + return Envelope{}, err + } + targetEventID = strings.TrimSpace(targetEventID) + if targetEventID == "" { + return Envelope{}, fmt.Errorf("stream envelope: missing target event id") + } + return Envelope{ + ThreadID: run.ThreadID, + RunID: run.RunID, + MessageID: run.MessageID, + Seq: seq, + Part: part, + TargetEvent: targetEventID, + RelatesTo: Relation{Type: "m.reference", EventID: targetEventID}, + AgentID: run.AgentID, + }, nil +} + +func PackRun(run Run, targetEventID string, budget int) ([]Carrier, error) { + return PackRunFromSeq(run, targetEventID, budget, 1) +} + +func PackRunFromSeq(run Run, targetEventID string, budget int, startSeq int) ([]Carrier, error) { + if budget <= 0 { + budget = CarrierBudgetBytes + } + if startSeq <= 0 { + startSeq = 1 + } + if err := run.Validate(); err != nil { + return nil, err + } + var carriers []Carrier + var current Carrier + seq := startSeq + for _, original := range run.Events { + for _, part := range splitEventForBudget(original, budget) { + env, err := BuildEnvelope(run, seq, part, targetEventID) + if err != nil { + return nil, err + } + single := CarrierContent([]Envelope{env}) + if JSONSize(single) > budget { + return nil, fmt.Errorf("stream envelope %d exceeds %d byte budget", seq, budget) + } + candidate := append(append([]Envelope{}, current.Envelopes...), env) + if len(current.Envelopes) > 0 && JSONSize(CarrierContent(candidate)) > budget { + carriers = append(carriers, current) + current = Carrier{} + } + current.Envelopes = append(current.Envelopes, env) + seq++ + } + } + if len(current.Envelopes) > 0 { + carriers = append(carriers, current) + } + return carriers, nil +} + +func eventTimestampMillis(evt agui.Event) int64 { + switch value := evt["timestamp"].(type) { + case int64: + return value + case int: + return int64(value) + case float64: + return int64(value) + case json.Number: + n, _ := value.Int64() + return n + default: + return 0 + } +} + +func NextSeq(carriers []Carrier) int { + next := 1 + for _, carrier := range carriers { + for _, env := range carrier.Envelopes { + if env.Seq >= next { + next = env.Seq + 1 + } + } + } + return next +} + +func CarrierContent(envelopes []Envelope) map[string]any { + return map[string]any{BeeperAIStreamDeltas: envelopes} +} + +func ReconstructText(carriers []Carrier) string { + var out strings.Builder + for _, carrier := range carriers { + for _, env := range carrier.Envelopes { + if env.Part["type"] == agui.EventTextMessageContent { + delta, _ := env.Part["delta"].(string) + out.WriteString(delta) + } + } + } + return out.String() +} + +func splitEventForBudget(evt agui.Event, budget int) []agui.Event { + if JSONSize(evt) <= budget { + return []agui.Event{sanitizeRawEvent(evt, budget)} + } + if evt["type"] == agui.EventMessagesSnapshot { + return splitMessagesSnapshotForBudget(evt, budget) + } + if evt["type"] != agui.EventTextMessageContent { + return []agui.Event{sanitizeRawEvent(evt, budget)} + } + delta, _ := evt["delta"].(string) + if delta == "" { + return []agui.Event{sanitizeRawEvent(evt, budget)} + } + maxDelta := budget / 2 + if maxDelta < 1024 { + maxDelta = 1024 + } + var out []agui.Event + for _, chunk := range SplitTextUTF8(delta, maxDelta) { + cp := agui.CloneEvent(evt) + cp["delta"] = chunk + out = append(out, sanitizeRawEvent(cp, budget)) + } + return out +} + +func splitMessagesSnapshotForBudget(evt agui.Event, budget int) []agui.Event { + rawMessages, ok := evt["messages"].([]agui.UIMessage) + if !ok || len(rawMessages) == 0 { + return []agui.Event{sanitizeRawEvent(evt, budget)} + } + var out []agui.Event + for _, message := range rawMessages { + base := agui.CloneEvent(evt) + messageWithoutParts := message + messageWithoutParts.Parts = nil + base["messages"] = []agui.UIMessage{messageWithoutParts} + var current []agui.MessagePart + flush := func() { + if len(current) == 0 { + return + } + cp := agui.CloneEvent(evt) + msg := message + msg.Parts = append([]agui.MessagePart{}, current...) + cp["messages"] = []agui.UIMessage{msg} + out = append(out, sanitizeRawEvent(cp, budget)) + current = nil + } + for _, part := range message.Parts { + candidate := append(append([]agui.MessagePart{}, current...), part) + cp := agui.CloneEvent(evt) + msg := message + msg.Parts = candidate + cp["messages"] = []agui.UIMessage{msg} + if len(current) > 0 && JSONSize(cp) > budget { + flush() + } + current = append(current, part) + } + flush() + } + if len(out) == 0 { + return []agui.Event{sanitizeRawEvent(evt, budget)} + } + return out +} + +func sanitizeRawEvent(evt agui.Event, budget int) agui.Event { + cp := agui.CloneEvent(evt) + if _, ok := cp["rawEvent"]; !ok { + return cp + } + if JSONSize(cp) <= budget { + return cp + } + raw, err := json.Marshal(cp["rawEvent"]) + if err != nil || len(raw) > 2048 { + cp["rawEvent"] = string(raw[:min(len(raw), 2048)]) + cp["rawEventTruncated"] = true + } + if JSONSize(cp) > budget { + delete(cp, "rawEvent") + cp["rawEventTruncated"] = true + } + return cp +} + +func StreamTxnID(runID string, seq int) string { + runID = strings.TrimSpace(runID) + if runID == "" { + return fmt.Sprintf("ai_stream_%d", seq) + } + return fmt.Sprintf("ai_stream_%s_%d", runID, seq) +} diff --git a/pkg/ai-stream/run.go b/pkg/ai-stream/run.go new file mode 100644 index 0000000..cbfd767 --- /dev/null +++ b/pkg/ai-stream/run.go @@ -0,0 +1,657 @@ +package aistream + +import ( + "encoding/json" + "fmt" + "strings" + "time" + "unicode/utf8" + + "github.com/beeper/dummybridge/pkg/ag-ui" +) + +const ( + BeeperAIKey = "com.beeper.ai" + BeeperAIMetadataKey = "com.beeper.ai.metadata" + BeeperAIStreamKey = "com.beeper.llm" + BeeperAIStreamDeltas = BeeperAIStreamKey + ".deltas" + DefaultModel = "dummybridge/ag-ui" + CarrierBudgetBytes = 58 * 1024 + PreviewBudgetBytes = 4096 + SnapshotTextBytes = 4096 +) + +type Run struct { + ThreadID string + RunID string + MessageID string + Model string + AgentID string + AgentName string + Events []agui.Event + Approvals []ApprovalSummary + Artifacts ArtifactSummary + Data map[string]any + Status Status + Usage agui.Usage + Preview Preview + ToolCallID string + ApprovalID string + Prompts []ApprovalPrompt +} + +type Status struct { + State string `json:"state"` + FinishReason string `json:"finishReason,omitempty"` + Terminal any `json:"terminal"` + Error any `json:"error"` +} + +type Preview struct { + Text string `json:"text"` + Truncated bool `json:"truncated"` +} + +type ApprovalSummary struct { + ID string `json:"id"` + ToolCallID string `json:"toolCallId"` + State string `json:"state"` + Always bool `json:"always"` + Reason string `json:"reason,omitempty"` + Fields map[string]any `json:"fields,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +type ApprovalPrompt struct { + ID string + ToolCallID string + ToolName string + SeqStart int +} + +type ArtifactSummary struct { + Sources []map[string]any `json:"sources"` + Documents []map[string]any `json:"documents"` + Files []map[string]any `json:"files"` +} + +type Writer struct { + Run *Run + builder agui.EventBuilder + reasoningOpen bool +} + +func NewRun(runID, threadID, model, agentID, agentName string, now time.Time) *Run { + runID = strings.TrimSpace(runID) + if runID == "" { + runID = fmt.Sprintf("run-%d", now.UnixNano()) + } + threadID = strings.TrimSpace(threadID) + if threadID == "" { + threadID = runID + } + model = strings.TrimSpace(model) + if model == "" { + model = DefaultModel + } + if agentID == "" { + agentID = "ai" + } + if agentName == "" { + agentName = "AI" + } + run := &Run{ + ThreadID: threadID, + RunID: runID, + MessageID: "msg-" + runID, + Model: model, + AgentID: agentID, + AgentName: agentName, + Data: map[string]any{}, + Status: Status{State: "streaming"}, + } + run.Preview = Preview{Text: BoundedPreview("", PreviewBudgetBytes)} + return run +} + +func NewWriter(run *Run, now func() time.Time) *Writer { + return &Writer{Run: run, builder: agui.NewEventBuilder(run.Model, now)} +} + +func (w *Writer) Add(evt agui.Event) { + if w == nil || w.Run == nil || len(evt) == 0 { + return + } + w.Run.Events = append(w.Run.Events, evt) + w.applySummary(evt) +} + +func (w *Writer) Start() { + w.Add(w.builder.RunStarted(w.Run.ThreadID, w.Run.RunID)) + w.Add(w.builder.TextMessageStart(w.Run.MessageID, agui.RoleAssistant)) +} + +func (w *Writer) Text(delta string) { + if delta == "" { + return + } + w.Add(w.builder.TextMessageContent(w.Run.MessageID, delta)) +} + +func (w *Writer) Thinking(delta string) { + if delta == "" { + return + } + if !w.reasoningOpen { + w.Add(w.builder.ReasoningStart(w.Run.MessageID)) + w.Add(w.builder.ReasoningMessageStart(w.Run.MessageID)) + w.reasoningOpen = true + } + w.Add(w.builder.ReasoningMessageContent(w.Run.MessageID, delta)) +} + +func (w *Writer) StepStart(stepID string) { + w.Add(w.builder.StepStarted(w.Run.MessageID, stepID)) +} + +func (w *Writer) StepFinish(stepID string) { + w.Add(w.builder.StepFinished(w.Run.MessageID, stepID)) +} + +func (w *Writer) ToolStart(toolCallID, name string, index int, approval *agui.ToolApproval) { + idx := index + w.Add(w.builder.ToolCallStart(w.Run.MessageID, toolCallID, name, &idx, approval)) + if approval != nil { + w.recordApprovalRequest(toolCallID, name, approval) + } +} + +func (w *Writer) ToolApprovalRequested(toolCallID, name string, input any, approval agui.ToolApproval) { + w.recordApprovalRequest(toolCallID, name, &approval) + w.Add(w.builder.Custom(agui.ApprovalCustomRequested, map[string]any{ + "threadId": w.Run.ThreadID, + "runId": w.Run.RunID, + "messageId": w.Run.MessageID, + "toolCallId": toolCallID, + "toolName": name, + "input": input, + "approval": approval, + })) +} + +func (w *Writer) recordApprovalRequest(toolCallID, name string, approval *agui.ToolApproval) { + if approval == nil || approval.ID == "" { + return + } + w.Run.ToolCallID = toolCallID + w.Run.ApprovalID = approval.ID + for _, existing := range w.Run.Approvals { + if existing.ID == approval.ID { + return + } + } + w.Run.Approvals = append(w.Run.Approvals, ApprovalSummary{ + ID: approval.ID, + ToolCallID: toolCallID, + State: "requested", + }) + w.Run.Prompts = append(w.Run.Prompts, ApprovalPrompt{ID: approval.ID, ToolCallID: toolCallID, ToolName: name}) +} + +func (w *Writer) ToolArgs(toolCallID, delta string, args any) { + w.Add(w.builder.ToolCallArgs(toolCallID, delta, args)) +} + +func (w *Writer) ToolEnd(toolCallID, name string, input, result any) { + w.Add(w.builder.ToolCallEnd(toolCallID, name, input, jsonString(result), agui.ToolStateInputComplete)) +} + +func (w *Writer) ToolApprovalInputComplete(toolCallID, name string, input any) { + w.Add(w.builder.ToolCallEnd(toolCallID, name, input, nil, agui.ToolStateApprovalRequested)) +} + +func (w *Writer) ToolResult(toolCallID, content, state string) { + w.Add(w.builder.ToolCallResult(w.Run.MessageID, toolCallID, content, state, agui.RoleTool)) +} + +func (w *Writer) ToolError(toolCallID, name string, input any, reason string) { + w.Add(w.builder.ToolCallEnd(toolCallID, name, input, jsonString(map[string]any{ + "state": agui.ToolResultStateError, + "reason": reason, + }), agui.ToolStateInputComplete)) +} + +func (w *Writer) ToolDenied(toolCallID, name string, input any, approvalID, reason string) { + if reason == "" { + reason = "denied" + } + for i := range w.Run.Approvals { + if w.Run.Approvals[i].ID == approvalID { + w.Run.Approvals[i].State = "denied" + w.Run.Approvals[i].Reason = reason + } + } + w.Add(w.builder.Custom(agui.ApprovalCustomResponded, map[string]any{ + "approval": agui.ToolApprovalResponse{ID: approvalID, Approved: false, Reason: reason}, + })) + w.Add(w.builder.ToolCallEnd(toolCallID, name, input, jsonString(map[string]any{ + "state": agui.ToolResultStateError, + "reason": "denied", + }), agui.ToolStateApprovalResponded)) +} + +func jsonString(value any) any { + if value == nil { + return nil + } + if text, ok := value.(string); ok { + return text + } + raw, err := json.Marshal(value) + if err != nil { + return fmt.Sprint(value) + } + return string(raw) +} + +func (w *Writer) StateSnapshot(state map[string]any) { + w.Add(w.builder.StateSnapshot(state)) +} + +func (w *Writer) StateDelta(delta any) { + w.Add(w.builder.StateDelta(delta)) +} + +func (w *Writer) MessagesSnapshot(messages []agui.UIMessage) { + w.Add(w.builder.MessagesSnapshot(messages)) +} + +func (w *Writer) Custom(name string, value any) { + w.Add(w.builder.Custom(name, value)) +} + +func (w *Writer) Finish(reason string) { + reason = agui.NormalizeFinishReason(reason) + text := w.Run.Text() + w.finishReasoning() + w.Run.Usage = agui.Usage{ + PromptTokens: 1, + CompletionTokens: utf8.RuneCountInString(text), + TotalTokens: utf8.RuneCountInString(text) + 1, + } + w.Run.Status = Status{State: "complete", FinishReason: reason} + w.Add(w.builder.TextMessageEnd(w.Run.MessageID)) + w.addFinalSnapshot() + w.Add(w.builder.RunFinished(w.Run.ThreadID, w.Run.RunID, reason, w.Run.Usage)) +} + +func (w *Writer) Error(message string) { + w.finishReasoning() + w.Run.Status = Status{State: "error", Error: map[string]any{"message": message}} + w.addFinalSnapshot() + w.Add(w.builder.RunError(w.Run.ThreadID, w.Run.RunID, message)) +} + +func (w *Writer) Abort(message string) { + w.finishReasoning() + w.Run.Status = Status{State: "aborted", Error: map[string]any{"message": message}} + w.addFinalSnapshot() + w.Add(w.builder.RunError(w.Run.ThreadID, w.Run.RunID, message)) +} + +func (w *Writer) addFinalSnapshot() { + if w == nil || w.Run == nil { + return + } + w.MessagesSnapshot([]agui.UIMessage{w.Run.FinalUIMessageSnapshot(SnapshotTextBytes)}) +} + +func (w *Writer) finishReasoning() { + if !w.reasoningOpen { + return + } + w.Add(w.builder.ReasoningMessageEnd(w.Run.MessageID)) + w.Add(w.builder.ReasoningEnd(w.Run.MessageID)) + w.reasoningOpen = false +} + +func (w *Writer) applySummary(evt agui.Event) { + switch evt["type"] { + case agui.EventTextMessageContent: + if delta, _ := evt["delta"].(string); delta != "" { + w.Run.Preview = PreviewFromText(w.Run.Text(), PreviewBudgetBytes) + } + case agui.EventCustom: + name, _ := evt["name"].(string) + value, _ := evt["value"].(map[string]any) + switch name { + case "com.beeper.source": + w.Run.Artifacts.Sources = append(w.Run.Artifacts.Sources, value) + case "com.beeper.document": + w.Run.Artifacts.Documents = append(w.Run.Artifacts.Documents, value) + case "com.beeper.file": + w.Run.Artifacts.Files = append(w.Run.Artifacts.Files, value) + case "com.beeper.data": + if key, _ := value["name"].(string); key != "" { + w.Run.Data[key] = value["value"] + } + } + } +} + +func (t Run) Text() string { + var out strings.Builder + for _, evt := range t.Events { + if evt["type"] == agui.EventTextMessageContent { + if delta, _ := evt["delta"].(string); delta != "" { + out.WriteString(delta) + } + } + } + return out.String() +} + +func (t Run) FinalUIMessageSnapshot(textBudget int) agui.UIMessage { + message := agui.UIMessage{ + ID: t.MessageID, + Role: agui.RoleAssistant, + Metadata: map[string]any{ + "threadId": t.ThreadID, + "runId": t.RunID, + "status": t.Status, + "usage": t.Usage, + }, + } + var textPart agui.MessagePart + var thinkingPart agui.MessagePart + toolParts := map[string]agui.MessagePart{} + toolResultParts := map[string]agui.MessagePart{} + approvalByID := map[string]any{} + appendPart := func(part agui.MessagePart) agui.MessagePart { + message.Parts = append(message.Parts, part) + return part + } + for _, evt := range t.Events { + switch evt["type"] { + case agui.EventTextMessageContent: + delta, _ := evt["delta"].(string) + if delta == "" { + continue + } + if textPart == nil { + textPart = appendPart(agui.MessagePart{"type": "text", "content": "", "state": "streaming"}) + } + textPart["content"] = asString(textPart["content"]) + delta + case agui.EventTextMessageEnd: + if textPart != nil { + textPart["state"] = "done" + } + case agui.EventReasoningMsgCont: + delta, _ := evt["delta"].(string) + if delta == "" { + continue + } + if thinkingPart == nil { + thinkingPart = appendPart(agui.MessagePart{"type": "thinking", "content": "", "state": "streaming"}) + } + thinkingPart["content"] = asString(thinkingPart["content"]) + delta + case agui.EventReasoningMsgEnd: + if thinkingPart != nil { + thinkingPart["state"] = "done" + } + case agui.EventToolCallStart: + toolCallID, _ := evt["toolCallId"].(string) + if toolCallID == "" { + continue + } + part := agui.MessagePart{ + "type": "tool-call", + "id": toolCallID, + "toolCallId": toolCallID, + "name": firstString(evt["toolName"], evt["toolCallName"]), + "arguments": "", + "state": firstString(evt["state"]), + } + if index, ok := evt["index"]; ok { + part["index"] = index + } + if approval, ok := evt["approval"]; ok { + part["approval"] = approval + } + toolParts[toolCallID] = appendPart(part) + case agui.EventToolCallArgs: + toolCallID, _ := evt["toolCallId"].(string) + part := toolParts[toolCallID] + if part == nil { + part = appendPart(agui.MessagePart{"type": "tool-call", "id": toolCallID, "toolCallId": toolCallID, "arguments": ""}) + toolParts[toolCallID] = part + } + part["state"] = firstString(evt["state"]) + if delta, _ := evt["delta"].(string); delta != "" { + part["arguments"] = asString(part["arguments"]) + delta + } + if args, ok := evt["args"]; ok { + part["input"] = args + } + case agui.EventToolCallEnd: + toolCallID, _ := evt["toolCallId"].(string) + part := toolParts[toolCallID] + if part == nil { + part = appendPart(agui.MessagePart{"type": "tool-call", "id": toolCallID, "toolCallId": toolCallID}) + toolParts[toolCallID] = part + } + part["name"] = firstString(part["name"], evt["toolName"], evt["toolCallName"]) + part["state"] = firstString(evt["state"]) + if input, ok := evt["input"]; ok { + part["input"] = input + } + if result, ok := evt["result"]; ok { + part["output"] = result + } + case agui.EventToolCallResult: + toolCallID, _ := evt["toolCallId"].(string) + if toolCallID == "" { + continue + } + part := toolResultParts[toolCallID] + if part == nil { + part = appendPart(agui.MessagePart{"type": "tool-result", "toolCallId": toolCallID, "content": "", "state": firstString(evt["state"])}) + toolResultParts[toolCallID] = part + } + part["state"] = firstString(evt["state"]) + part["content"] = asString(part["content"]) + asString(evt["content"]) + case agui.EventCustom: + name, _ := evt["name"].(string) + value, _ := evt["value"].(map[string]any) + switch name { + case agui.ApprovalCustomRequested: + if toolCallID, _ := value["toolCallId"].(string); toolCallID != "" { + if part := toolParts[toolCallID]; part != nil { + part["approval"] = value["approval"] + part["state"] = agui.ToolStateApprovalRequested + } + } + case agui.ApprovalCustomResponded: + if approval, ok := value["approval"]; ok { + approvalByID[approvalMapID(approval)] = approval + } + case "com.beeper.source": + message.Parts = append(message.Parts, agui.MessagePart{"type": "source-url", "source": value}) + case "com.beeper.document": + message.Parts = append(message.Parts, agui.MessagePart{"type": "file", "file": value}) + case "com.beeper.file": + message.Parts = append(message.Parts, agui.MessagePart{"type": "file", "file": value}) + case "com.beeper.data": + message.Parts = append(message.Parts, agui.MessagePart{"type": "data-com-beeper-data", "data": value}) + } + } + } + for _, part := range toolParts { + if approvalID := approvalMapID(part["approval"]); approvalID != "" { + if response := approvalByID[approvalID]; response != nil { + part["approvalResponse"] = response + part["state"] = agui.ToolStateApprovalResponded + } + } + } + compactTextPart(textPart, textBudget) + compactTextPart(thinkingPart, textBudget) + return message +} + +func (t Run) InitialUIMessage() map[string]any { + return map[string]any{ + "id": t.MessageID, + "role": agui.RoleAssistant, + "parts": []any{}, + } +} + +func compactTextPart(part agui.MessagePart, budget int) { + if part == nil { + return + } + content, _ := part["content"].(string) + preview := BoundedPreview(content, budget) + part["content"] = preview + if len(preview) < len(content) { + part["providerMetadata"] = map[string]any{"truncated": true} + } + if part["state"] == "" { + part["state"] = "done" + } +} + +func asString(value any) string { + switch typed := value.(type) { + case string: + return typed + case fmt.Stringer: + return typed.String() + case nil: + return "" + default: + return fmt.Sprint(typed) + } +} + +func firstString(values ...any) string { + for _, value := range values { + if text, ok := value.(string); ok && text != "" { + return text + } + } + return "" +} + +func approvalMapID(value any) string { + switch typed := value.(type) { + case agui.ToolApproval: + return typed.ID + case *agui.ToolApproval: + if typed != nil { + return typed.ID + } + case agui.ToolApprovalResponse: + return typed.ID + case *agui.ToolApprovalResponse: + if typed != nil { + return typed.ID + } + case map[string]any: + id, _ := typed["id"].(string) + return id + } + return "" +} + +func (t Run) Metadata() map[string]any { + return map[string]any{ + "schema": "com.beeper.ai.run.v1", + "protocol": "ag-ui", + "threadId": t.ThreadID, + "runId": t.RunID, + "messageId": t.MessageID, + "agent": map[string]any{ + "id": t.AgentID, + "displayName": t.AgentName, + }, + "model": t.Model, + "usage": map[string]any{ + "promptTokens": t.Usage.PromptTokens, + "completionTokens": t.Usage.CompletionTokens, + "totalTokens": t.Usage.TotalTokens, + }, + "usageDetails": map[string]any{}, + "status": t.Status, + "approvals": t.Approvals, + "artifacts": t.Artifacts, + "data": t.Data, + "preview": t.Preview, + } +} + +func (t Run) Validate() error { + for i, evt := range t.Events { + if err := agui.ValidateEvent(evt); err != nil { + return fmt.Errorf("event %d: %w", i+1, err) + } + } + return nil +} + +func PreviewFromText(text string, budget int) Preview { + preview := BoundedPreview(text, budget) + return Preview{Text: preview, Truncated: len(preview) < len(text)} +} + +func BoundedPreview(text string, budget int) string { + text = strings.TrimSpace(text) + if budget <= 0 || len(text) <= budget { + return text + } + end := budget + for end > 0 && !utf8.RuneStart(text[end]) { + end-- + } + if end <= 0 { + return "" + } + return strings.TrimSpace(text[:end]) +} + +func SplitTextUTF8(text string, maxBytes int) []string { + if maxBytes <= 0 { + return nil + } + if len(text) <= maxBytes { + return []string{text} + } + var chunks []string + start := 0 + for start < len(text) { + end := start + maxBytes + if end >= len(text) { + chunks = append(chunks, text[start:]) + break + } + for end > start && !utf8.RuneStart(text[end]) { + end-- + } + if end == start { + _, size := utf8.DecodeRuneInString(text[start:]) + end = start + size + } + chunks = append(chunks, text[start:end]) + start = end + } + return chunks +} + +func JSONSize(value any) int { + raw, err := json.Marshal(value) + if err != nil { + return CarrierBudgetBytes + 1 + } + return len(raw) +} diff --git a/pkg/ai-stream/stream_test.go b/pkg/ai-stream/stream_test.go new file mode 100644 index 0000000..a4c3830 --- /dev/null +++ b/pkg/ai-stream/stream_test.go @@ -0,0 +1,236 @@ +package aistream + +import ( + "encoding/json" + "strings" + "testing" + "time" + + "github.com/beeper/dummybridge/pkg/ag-ui" +) + +func TestPackRunSplitsOver64KBAndReconstructs(t *testing.T) { + run := NewRun("run-1", "thread-1", DefaultModel, "ai", "AI", time.Unix(10, 0)) + writer := NewWriter(run, func() time.Time { return time.Unix(10, 0) }) + writer.Start() + writer.Text(strings.Repeat("a", 70*1024)) + writer.Finish(agui.FinishReasonStop) + + carriers, err := PackRun(*run, "$anchor", CarrierBudgetBytes) + if err != nil { + t.Fatal(err) + } + if len(carriers) < 2 { + t.Fatalf("expected multiple carriers for over-64KB output, got %d", len(carriers)) + } + for i, carrier := range carriers { + if size := JSONSize(CarrierContent(carrier.Envelopes)); size > CarrierBudgetBytes { + t.Fatalf("carrier %d is %d bytes, budget %d", i, size, CarrierBudgetBytes) + } + for _, env := range carrier.Envelopes { + if env.SeqTotal <= 0 { + t.Fatalf("carrier envelope missing total count: %#v", env) + } + } + } + if got := ReconstructText(carriers); got != strings.Repeat("a", 70*1024) { + t.Fatalf("reconstructed text length = %d", len(got)) + } +} + +func TestPackRunUsesDeltaEventsInsteadOfAccumulatedText(t *testing.T) { + run := NewRun("run-1", "thread-1", DefaultModel, "ai", "AI", time.Unix(10, 0)) + tick := int64(10) + writer := NewWriter(run, func() time.Time { + tick++ + return time.Unix(tick, 0) + }) + writer.Start() + writer.Text("abc") + writer.Text("def") + writer.Finish(agui.FinishReasonStop) + + carriers, err := PackRun(*run, "$anchor", CarrierBudgetBytes) + if err != nil { + t.Fatal(err) + } + if len(carriers) != 1 { + t.Fatalf("under-budget run should be packed into one carrier, got %d", len(carriers)) + } + var deltas []string + for _, carrier := range carriers { + for _, env := range carrier.Envelopes { + if env.Part["type"] == agui.EventTextMessageContent { + deltas = append(deltas, env.Part["delta"].(string)) + } + } + } + if strings.Join(deltas, "|") != "abc|def" { + t.Fatalf("expected original deltas only, got %#v", deltas) + } +} + +func TestRawEventIsTruncatedBeforePacking(t *testing.T) { + run := NewRun("run-1", "thread-1", DefaultModel, "ai", "AI", time.Unix(10, 0)) + builder := agui.NewEventBuilder(DefaultModel, func() time.Time { return time.Unix(10, 0) }) + run.Events = append(run.Events, builder.Custom("com.beeper.debug", map[string]any{"ok": true})) + run.Events[0]["rawEvent"] = strings.Repeat("x", CarrierBudgetBytes) + + carriers, err := PackRun(*run, "$anchor", CarrierBudgetBytes) + if err != nil { + t.Fatal(err) + } + part := carriers[0].Envelopes[0].Part + if part["rawEventTruncated"] != true { + t.Fatalf("expected rawEventTruncated marker, got %#v", part) + } + if size := JSONSize(CarrierContent(carriers[0].Envelopes)); size > CarrierBudgetBytes { + t.Fatalf("carrier size = %d, budget %d", size, CarrierBudgetBytes) + } +} + +func TestPackRunRejectsOversizedNonTextEvent(t *testing.T) { + run := NewRun("run-1", "thread-1", DefaultModel, "ai", "AI", time.Unix(10, 0)) + builder := agui.NewEventBuilder(DefaultModel, func() time.Time { return time.Unix(10, 0) }) + run.Events = append(run.Events, builder.Custom("com.beeper.large", map[string]any{ + "value": strings.Repeat("x", CarrierBudgetBytes), + })) + + _, err := PackRun(*run, "$anchor", CarrierBudgetBytes) + if err == nil { + t.Fatal("expected oversized non-text event to fail packing") + } +} + +func TestValidateRejectsLegacyOrInvalidToolResultShape(t *testing.T) { + run := NewRun("run-1", "thread-1", DefaultModel, "ai", "AI", time.Unix(10, 0)) + builder := agui.NewEventBuilder(DefaultModel, func() time.Time { return time.Unix(10, 0) }) + run.Events = append(run.Events, + builder.RunStarted("thread-1", "run-1"), + builder.ToolCallStart("msg-run-1", "tool-1", "shell", nil, nil), + builder.ToolCallEnd("tool-1", "shell", nil, map[string]any{"ok": true}, agui.ToolStateInputComplete), + ) + if err := run.Validate(); err == nil { + t.Fatal("expected validation error for non-string TOOL_CALL_END.result") + } +} + +func TestApprovalResolverMatchesEmojiKeysAndAliases(t *testing.T) { + options := DefaultApprovalOptions("approval-1") + for _, key := range []string{"👍", "approval.allow_once", "allow"} { + option, ok := ResolveReaction(options, key) + if !ok || !option.Value.Approved || option.Value.Always { + t.Fatalf("expected allow-once for %q, got %#v ok=%v", key, option, ok) + } + } + option, ok := ResolveReaction(options, "always") + if !ok || !option.Value.Approved || !option.Value.Always { + t.Fatalf("expected allow-always, got %#v ok=%v", option, ok) + } + option, ok = ResolveReaction(options, "👎") + if !ok || option.Value.Approved || option.Value.Reason != "denied" { + t.Fatalf("expected denial, got %#v ok=%v", option, ok) + } +} + +func TestCleanupKeepsSelectedUserReactionAndRemovesBridgeOptions(t *testing.T) { + options := DefaultApprovalOptions("approval-1") + cleanup := CleanupReactions(options, "👍", []ReactionEvent{ + {EventID: "$bridge-allow", Sender: "ai", Key: "👍", Bridge: true}, + {EventID: "$bridge-deny", Sender: "ai", Key: "👎", Bridge: true}, + {EventID: "$user-allow", Sender: "@user:example", Key: "👍"}, + {EventID: "$user-deny", Sender: "@user:example", Key: "👎"}, + }, "ai") + if !cleanup.Matched || cleanup.SelectedReactionEvent != "$user-allow" { + t.Fatalf("bad selected reaction: %#v", cleanup) + } + got := strings.Join(cleanup.RedactReactionEvents, ",") + if !strings.Contains(got, "$bridge-allow") || !strings.Contains(got, "$bridge-deny") || !strings.Contains(got, "$user-deny") { + t.Fatalf("bad cleanup redactions: %#v", cleanup.RedactReactionEvents) + } +} + +func TestApprovalResponseRunEmitsRespondedStateAndToolResult(t *testing.T) { + run := ApprovalResponseRun(ApprovalContext{ + ID: "approval-1", + ThreadID: "thread-1", + RunID: "run-1", + MessageID: "msg-1", + ToolCallID: "tool-1", + ToolName: "shell", + TargetEvent: "$anchor", + SeqStart: 10, + }, agui.ToolApprovalResponse{ + Approved: false, + Reason: "denied", + Fields: map[string]any{"scope": "once"}, + Metadata: map[string]any{"source": "reaction"}, + }, time.Unix(10, 0)) + + if run.RunID != "run-1" || run.MessageID != "msg-1" { + t.Fatalf("approval response must continue the existing run/message, got %#v", run) + } + if len(run.Events) != 2 { + t.Fatalf("expected approval response and tool result events, got %#v", run.Events) + } + if run.Events[0]["type"] != agui.EventCustom || run.Events[0]["name"] != agui.ApprovalCustomResponded { + t.Fatalf("missing approval-responded event: %#v", run.Events[0]) + } + if run.Events[1]["type"] != agui.EventToolCallEnd || run.Events[1]["state"] != agui.ToolStateApprovalResponded { + t.Fatalf("missing approval-responded tool end: %#v", run.Events[1]) + } + result := jsonMap(t, run.Events[1]["result"]) + if result["state"] != agui.ToolResultStateError || result["reason"] != "denied" { + t.Fatalf("expected structured denied result, got %#v", result) + } + if result["fields"].(map[string]any)["scope"] != "once" || result["metadata"].(map[string]any)["source"] != "reaction" { + t.Fatalf("expected flexible approval fields to survive, got %#v", result) + } + if run.Approvals[0].Fields["scope"] != "once" || run.Approvals[0].Metadata["source"] != "reaction" { + t.Fatalf("expected approval summary fields to survive, got %#v", run.Approvals[0]) + } + + carriers, err := PackRunFromSeq(run, "$anchor", CarrierBudgetBytes, 10) + if err != nil { + t.Fatal(err) + } + if carriers[0].Envelopes[0].Seq != 10 { + t.Fatalf("expected continuation seq 10, got %#v", carriers[0].Envelopes[0]) + } +} + +func TestApprovalResponseRunPreservesApprovedAlways(t *testing.T) { + run := ApprovalResponseRun(ApprovalContext{ + ID: "approval-1", + ThreadID: "thread-1", + RunID: "run-1", + MessageID: "msg-1", + ToolCallID: "tool-1", + ToolName: "shell", + TargetEvent: "$anchor", + }, agui.ToolApprovalResponse{Approved: true, Always: true}, time.Unix(10, 0)) + + if run.Status.State != "complete" { + t.Fatalf("expected complete approval response run, got %#v", run.Status) + } + if len(run.Approvals) != 1 || run.Approvals[0].State != "approved-always" || !run.Approvals[0].Always { + t.Fatalf("bad approval summary: %#v", run.Approvals) + } + result := jsonMap(t, run.Events[1]["result"]) + if result["state"] != agui.ToolResultStateComplete || result["approved"] != true || result["always"] != true { + t.Fatalf("bad approval result: %#v", result) + } +} + +func jsonMap(t *testing.T, value any) map[string]any { + t.Helper() + text, ok := value.(string) + if !ok { + t.Fatalf("expected JSON string result, got %#v", value) + } + var out map[string]any + if err := json.Unmarshal([]byte(text), &out); err != nil { + t.Fatalf("failed to parse result %q: %v", text, err) + } + return out +} diff --git a/pkg/connector/ai_runtime.go b/pkg/connector/ai_runtime.go new file mode 100644 index 0000000..47e1f39 --- /dev/null +++ b/pkg/connector/ai_runtime.go @@ -0,0 +1,1185 @@ +package connector + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "math/rand" + "sort" + "strconv" + "strings" + "time" + + "github.com/beeper/dummybridge/pkg/ag-ui" + "github.com/beeper/dummybridge/pkg/ai-stream" +) + +var errApprovalRequested = errors.New("approval requested") + +const ( + defaultChunkMin = 24 + defaultChunkMax = 96 + maxDemoChars = 96 * 1024 + maxDemoReasoningChars = 8192 + maxDemoToolSpecs = 16 + maxDemoSteps = 32 + maxDemoCollections = 16 + maxDemoRandomActions = 64 + maxDemoChaosRuns = 16 + maxDemoChaosActions = 64 + maxDemoDuration = 5 * time.Minute + maxDemoDelay = 30 * time.Second + maxDemoChunkChars = 512 + maxDemoStagger = 30 * time.Second +) + +const ( + randomActionText = "text" + randomActionThinking = "thinking" + randomActionStep = "step" + randomActionTool = "tool" + randomActionToolFail = "tool_fail" + randomActionToolDeny = "tool_deny" + randomActionToolApproval = "tool_approval" + randomActionSource = "source" + randomActionDocument = "document" + randomActionFile = "file" + randomActionMetadata = "metadata" + randomActionData = "data" + randomActionDataTransient = "data_transient" +) + +type commonCommandOptions struct { + ReasoningChars int + Steps int + Sources int + Documents int + Files int + Meta bool + DataName string + DataTransientName string + DelayMin time.Duration + DelayMax time.Duration + ChunkMin int + ChunkMax int + FinishReason string + Abort bool + Error bool + Seed int64 + SeedSet bool +} + +type loremCommand struct { + Chars int + Options commonCommandOptions +} + +type toolSpec struct { + Name string + Tags []string + Fail bool + Approval bool + Deny bool + Delta bool + InputError bool + Preliminary bool + Provider bool + SequenceIndex int +} + +type toolsCommand struct { + Chars int + Tools []toolSpec + Options commonCommandOptions +} + +type sharedStreamOptions struct { + Profile string + Seed int64 + SeedSet bool + AllowAbort bool + AllowError bool + AllowApproval bool +} + +type randomCommand struct { + Duration time.Duration + Actions int + DelayMin time.Duration + DelayMax time.Duration + sharedStreamOptions +} + +type randomActionOption struct { + name string + weight int +} + +type chaosCommand struct { + Runs int + Duration time.Duration + StaggerMin time.Duration + StaggerMax time.Duration + MaxActions int + sharedStreamOptions +} + +type parsedCommand struct { + Name string + Lorem *loremCommand + Tools *toolsCommand + Random *randomCommand + Chaos *chaosCommand +} + +type aiRuntime struct { + now func() time.Time + sleep func(context.Context, time.Duration) error +} + +type aiRunPlan struct { + Run *aistream.Run + Delay time.Duration +} + +func defaultAIRuntime() aiRuntime { + return aiRuntime{ + now: time.Now, + sleep: func(ctx context.Context, delay time.Duration) error { + if delay <= 0 { + return nil + } + timer := time.NewTimer(delay) + defer timer.Stop() + select { + case <-timer.C: + return nil + case <-ctx.Done(): + return ctx.Err() + } + }, + } +} + +func virtualAIRuntime(now time.Time) aiRuntime { + current := now + return aiRuntime{ + now: func() time.Time { + return current + }, + sleep: func(ctx context.Context, delay time.Duration) error { + if err := ctx.Err(); err != nil { + return err + } + if delay > 0 { + current = current.Add(delay) + } + return nil + }, + } +} + +func buildAIRun(ctx context.Context, runID, threadID, input string, now time.Time) (*aistream.Run, error) { + plans, err := buildAIRunPlans(ctx, runID, threadID, input, now) + if err != nil { + return nil, err + } + if len(plans) == 0 { + return nil, fmt.Errorf("no AI runs built") + } + return plans[0].Run, nil +} + +func buildAIRunPlans(ctx context.Context, runID, threadID, input string, now time.Time) ([]aiRunPlan, error) { + cmd, err := parseCommand(input) + if err != nil { + run := aistream.NewRun(runID, threadID, aistream.DefaultModel, string(aiGhostID), aiGhostName, now) + writer := aistream.NewWriter(run, func() time.Time { return now }) + writer.Start() + writer.Text(err.Error() + "\n\n" + helpText()) + writer.Finish(agui.FinishReasonStop) + return []aiRunPlan{{Run: run}}, nil + } + if cmd != nil && cmd.Chaos != nil { + return buildAIChaosRunPlans(ctx, runID, threadID, now, *cmd.Chaos) + } + run, err := buildAIRunFromCommand(ctx, runID, threadID, now, cmd) + if err != nil { + return nil, err + } + return []aiRunPlan{{Run: run}}, nil +} + +func buildAIRunFromCommand(ctx context.Context, runID, threadID string, now time.Time, cmd *parsedCommand) (*aistream.Run, error) { + runtime := virtualAIRuntime(now) + run := aistream.NewRun(runID, threadID, aistream.DefaultModel, string(aiGhostID), aiGhostName, now) + writer := aistream.NewWriter(run, runtime.now) + writer.Start() + + runner := aiRunner{runtime: runtime} + var err error + switch { + case cmd == nil || cmd.Name == "help": + writer.Text(helpText()) + writer.Finish(agui.FinishReasonStop) + case cmd.Lorem != nil: + err = runner.runLorem(ctx, writer, *cmd.Lorem) + case cmd.Tools != nil: + err = runner.runTools(ctx, writer, *cmd.Tools) + case cmd.Random != nil: + err = runner.runRandom(ctx, writer, *cmd.Random) + } + if errors.Is(err, errApprovalRequested) { + err = nil + } + if err != nil { + writer.Error(err.Error()) + } else if err = agui.ValidateEventSequence(run.Events); err != nil { + writer.Error(err.Error()) + } + return run, nil +} + +func buildAIChaosRunPlans(ctx context.Context, baseRunID, threadID string, now time.Time, cmd chaosCommand) ([]aiRunPlan, error) { + seed := cmd.Seed + if !cmd.SeedSet { + seed = now.UnixNano() + } + rng := rand.New(rand.NewSource(seed)) + runner := aiRunner{runtime: virtualAIRuntime(now)} + actions := max(3, min(cmd.MaxActions, int(cmd.Duration/time.Second))) + plans := make([]aiRunPlan, 0, cmd.Runs) + for i := range cmd.Runs { + var delay time.Duration + if i > 0 { + delay = runner.sampleDelay(rng, cmd.StaggerMin, cmd.StaggerMax) + } + runID := fmt.Sprintf("%s-%d", baseRunID, i+1) + randomCmd := randomCommand{ + Duration: cmd.Duration, + Actions: actions, + DelayMin: 180 * time.Millisecond, + DelayMax: 900 * time.Millisecond, + sharedStreamOptions: sharedStreamOptions{ + Profile: cmd.Profile, + Seed: seed + int64(i+1)*97, + SeedSet: true, + AllowAbort: cmd.AllowAbort, + AllowError: cmd.AllowError, + AllowApproval: cmd.AllowApproval, + }, + } + run, err := buildAIRunFromCommand(ctx, runID, threadID, now.Add(delay), &parsedCommand{ + Name: "stream-random", + Random: &randomCmd, + }) + if err != nil { + return nil, err + } + plans = append(plans, aiRunPlan{Run: run, Delay: delay}) + } + return plans, nil +} + +func parseCommand(input string) (*parsedCommand, error) { + tokens := strings.Fields(strings.TrimSpace(input)) + if len(tokens) == 0 { + return &parsedCommand{Name: "help"}, nil + } + switch strings.ToLower(tokens[0]) { + case "help", "/help", "!help": + return &parsedCommand{Name: "help"}, nil + case "dummybridge": + if len(tokens) > 1 && strings.EqualFold(tokens[1], "help") { + return &parsedCommand{Name: "help"}, nil + } + return &parsedCommand{Name: "help"}, nil + case "stream-lorem": + cmd, err := parseLoremCommand(tokens[1:]) + return &parsedCommand{Name: "stream-lorem", Lorem: cmd}, err + case "stream-tools": + cmd, err := parseToolsCommand(tokens[1:]) + return &parsedCommand{Name: "stream-tools", Tools: cmd}, err + case "stream-random": + cmd, err := parseRandomCommand(tokens[1:]) + return &parsedCommand{Name: "stream-random", Random: cmd}, err + case "stream-chaos": + cmd, err := parseChaosCommand(tokens[1:]) + return &parsedCommand{Name: "stream-chaos", Chaos: cmd}, err + default: + return &parsedCommand{Name: "stream-lorem", Lorem: &loremCommand{ + Chars: min(max(len(input)*4, 120), 1200), + Options: defaultCommonOptions(), + }}, nil + } +} + +func helpText() string { + return strings.Join([]string{ + "DummyBridge demo commands:", + "help", + "stream-lorem [--reasoning=N] [--steps=N] [--sources=N] [--documents=N] [--files=N] [--meta] [--data=name] [--data-transient=name] [--delay-ms=min:max] [--chunk-chars=min:max] [--seed=N] [--finish=stop|length|tool-calls|content-filter|other] [--abort|--error]", + "stream-tools ... [common options]", + "stream-random [seconds] [--actions=N] [--profile=balanced|tools|artifacts|terminals] [--seed=N] [--delay-ms=min:max] [--allow-abort] [--allow-error] [--allow-approval]", + "stream-chaos [runs] [seconds] [--profile=balanced|tools|artifacts|terminals] [--seed=N] [--stagger-ms=min:max] [--max-actions=N] [--allow-abort] [--allow-error] [--allow-approval]", + "Notes: approval-tagged tools emit a separate Matrix approval event with reaction options.", + }, "\n") +} + +func defaultCommonOptions() commonCommandOptions { + return commonCommandOptions{ + DelayMin: 30 * time.Millisecond, + DelayMax: 150 * time.Millisecond, + ChunkMin: defaultChunkMin, + ChunkMax: defaultChunkMax, + FinishReason: agui.FinishReasonStop, + } +} + +func parseLoremCommand(tokens []string) (*loremCommand, error) { + if len(tokens) == 0 { + return nil, fmt.Errorf("stream-lorem requires a character count") + } + count, err := parsePositiveInt(tokens[0], "character count") + if err != nil { + return nil, err + } + if err := validateMaxIntValue(count, maxDemoChars, "character count"); err != nil { + return nil, err + } + opts, err := parseCommonOptions(tokens[1:]) + if err != nil { + return nil, err + } + return &loremCommand{Chars: count, Options: opts}, nil +} + +func parseToolsCommand(tokens []string) (*toolsCommand, error) { + if len(tokens) < 2 { + return nil, fmt.Errorf("stream-tools requires a character count and at least one tool") + } + count, err := parsePositiveInt(tokens[0], "character count") + if err != nil { + return nil, err + } + if err := validateMaxIntValue(count, maxDemoChars, "character count"); err != nil { + return nil, err + } + var toolTokens, optTokens []string + for _, token := range tokens[1:] { + if strings.HasPrefix(token, "--") { + optTokens = append(optTokens, token) + } else { + toolTokens = append(toolTokens, token) + } + } + if len(toolTokens) == 0 { + return nil, fmt.Errorf("stream-tools requires at least one tool spec") + } + if err := validateMaxIntValue(len(toolTokens), maxDemoToolSpecs, "tool spec count"); err != nil { + return nil, err + } + opts, err := parseCommonOptions(optTokens) + if err != nil { + return nil, err + } + tools := make([]toolSpec, 0, len(toolTokens)) + for idx, token := range toolTokens { + spec, err := parseToolSpec(token, idx) + if err != nil { + return nil, err + } + tools = append(tools, spec) + } + return &toolsCommand{Chars: count, Tools: tools, Options: opts}, nil +} + +func parseRandomCommand(tokens []string) (*randomCommand, error) { + cmd := &randomCommand{ + Duration: 20 * time.Second, + Actions: 20, + DelayMin: 350 * time.Millisecond, + DelayMax: 1150 * time.Millisecond, + sharedStreamOptions: sharedStreamOptions{Profile: "balanced"}, + } + rest := tokens + if len(rest) > 0 && !strings.HasPrefix(rest[0], "--") { + seconds, err := parsePositiveInt(rest[0], "duration") + if err != nil { + return nil, err + } + if err := validateMaxIntValue(seconds, int(maxDemoDuration/time.Second), "duration seconds"); err != nil { + return nil, err + } + cmd.Duration = time.Duration(seconds) * time.Second + rest = rest[1:] + } + for _, token := range rest { + key, value, hasValue := parseOptionToken(token) + switch key { + case "actions": + n, err := parseValidatedInt(value, hasValue, token, "actions", maxDemoRandomActions, false) + if err != nil { + return nil, err + } + cmd.Actions = n + case "delay-ms": + minDelay, maxDelay, err := parseDurationRangeMS(value, hasValue, token) + if err != nil { + return nil, err + } + cmd.DelayMin, cmd.DelayMax = minDelay, maxDelay + default: + handled, err := parseSharedStreamOption(key, value, hasValue, token, &cmd.sharedStreamOptions) + if err != nil || !handled { + if err != nil { + return nil, err + } + return nil, fmt.Errorf("unknown random option %q", token) + } + } + } + return cmd, nil +} + +func parseChaosCommand(tokens []string) (*chaosCommand, error) { + cmd := &chaosCommand{ + Runs: 3, + Duration: 10 * time.Second, + StaggerMin: 150 * time.Millisecond, + StaggerMax: 900 * time.Millisecond, + MaxActions: 10, + sharedStreamOptions: sharedStreamOptions{Profile: "balanced"}, + } + rest := tokens + if len(rest) > 0 && !strings.HasPrefix(rest[0], "--") { + n, err := parsePositiveInt(rest[0], "run count") + if err != nil { + return nil, err + } + if err := validateMaxIntValue(n, maxDemoChaosRuns, "run count"); err != nil { + return nil, err + } + cmd.Runs = n + rest = rest[1:] + } + if len(rest) > 0 && !strings.HasPrefix(rest[0], "--") { + seconds, err := parsePositiveInt(rest[0], "duration") + if err != nil { + return nil, err + } + if err := validateMaxIntValue(seconds, int(maxDemoDuration/time.Second), "duration seconds"); err != nil { + return nil, err + } + cmd.Duration = time.Duration(seconds) * time.Second + rest = rest[1:] + } + for _, token := range rest { + key, value, hasValue := parseOptionToken(token) + switch key { + case "stagger-ms": + minDelay, maxDelay, err := parseDurationRange(value, hasValue, token, "stagger-ms", maxDemoStagger) + if err != nil { + return nil, err + } + cmd.StaggerMin, cmd.StaggerMax = minDelay, maxDelay + case "max-actions": + n, err := parseValidatedInt(value, hasValue, token, "max-actions", maxDemoChaosActions, false) + if err != nil { + return nil, err + } + cmd.MaxActions = n + default: + handled, err := parseSharedStreamOption(key, value, hasValue, token, &cmd.sharedStreamOptions) + if err != nil || !handled { + if err != nil { + return nil, err + } + return nil, fmt.Errorf("unknown chaos option %q", token) + } + } + } + return cmd, nil +} + +func parseCommonOptions(tokens []string) (commonCommandOptions, error) { + opts := defaultCommonOptions() + for _, token := range tokens { + key, value, hasValue := parseOptionToken(token) + switch key { + case "reasoning": + n, err := parseValidatedInt(value, hasValue, token, "reasoning", maxDemoReasoningChars, true) + if err != nil { + return opts, err + } + opts.ReasoningChars = n + case "steps": + n, err := parseValidatedInt(value, hasValue, token, "steps", maxDemoSteps, false) + if err != nil { + return opts, err + } + opts.Steps = n + case "sources": + n, err := parseValidatedInt(value, hasValue, token, "sources", maxDemoCollections, true) + if err != nil { + return opts, err + } + opts.Sources = n + case "documents": + n, err := parseValidatedInt(value, hasValue, token, "documents", maxDemoCollections, true) + if err != nil { + return opts, err + } + opts.Documents = n + case "files": + n, err := parseValidatedInt(value, hasValue, token, "files", maxDemoCollections, true) + if err != nil { + return opts, err + } + opts.Files = n + case "meta": + opts.Meta = true + case "data": + if !hasValue { + return opts, fmt.Errorf("%s requires a value", token) + } + opts.DataName = value + case "data-transient": + if !hasValue { + return opts, fmt.Errorf("%s requires a value", token) + } + opts.DataTransientName = value + case "delay-ms": + minDelay, maxDelay, err := parseDurationRangeMS(value, hasValue, token) + if err != nil { + return opts, err + } + opts.DelayMin, opts.DelayMax = minDelay, maxDelay + case "chunk-chars": + minChunk, maxChunk, err := parseIntRangeOption(value, hasValue, token, "chunk-chars", maxDemoChunkChars) + if err != nil { + return opts, err + } + opts.ChunkMin, opts.ChunkMax = minChunk, maxChunk + case "seed": + if !hasValue { + return opts, fmt.Errorf("%s requires a value", token) + } + seed, err := strconv.ParseInt(value, 10, 64) + if err != nil { + return opts, fmt.Errorf("invalid seed %q", value) + } + opts.Seed, opts.SeedSet = seed, true + case "finish": + if !hasValue { + return opts, fmt.Errorf("%s requires a value", token) + } + opts.FinishReason = agui.NormalizeFinishReason(value) + case "abort": + opts.Abort = true + case "error": + opts.Error = true + default: + return opts, fmt.Errorf("unknown option %q", token) + } + } + if opts.Abort && opts.Error { + return opts, fmt.Errorf("--abort and --error cannot be combined") + } + if (opts.Abort || opts.Error) && opts.FinishReason != agui.FinishReasonStop { + return opts, fmt.Errorf("--finish cannot be combined with --abort or --error") + } + return opts, nil +} + +func parseSharedStreamOption(key, value string, hasValue bool, token string, opts *sharedStreamOptions) (bool, error) { + switch key { + case "profile": + if !hasValue { + return false, fmt.Errorf("%s requires a value", token) + } + switch strings.ToLower(value) { + case "balanced", "tools", "artifacts", "terminals": + opts.Profile = strings.ToLower(value) + default: + return false, fmt.Errorf("unknown profile %q", value) + } + case "seed": + if !hasValue { + return false, fmt.Errorf("%s requires a value", token) + } + seed, err := strconv.ParseInt(value, 10, 64) + if err != nil { + return false, fmt.Errorf("invalid seed %q", value) + } + opts.Seed, opts.SeedSet = seed, true + case "allow-abort": + opts.AllowAbort = true + case "allow-error": + opts.AllowError = true + case "allow-approval": + opts.AllowApproval = true + default: + return false, nil + } + return true, nil +} + +func parseToolSpec(raw string, idx int) (toolSpec, error) { + parts := strings.Split(raw, "#") + spec := toolSpec{Name: strings.TrimSpace(parts[0]), SequenceIndex: idx + 1} + if spec.Name == "" { + return spec, fmt.Errorf("tool spec %q is missing a tool name", raw) + } + for _, tag := range parts[1:] { + tag = strings.TrimSpace(strings.ToLower(tag)) + if tag == "" { + continue + } + spec.Tags = append(spec.Tags, tag) + switch tag { + case "fail": + spec.Fail = true + case "approval": + spec.Approval = true + case "deny": + spec.Deny = true + case "delta": + spec.Delta = true + case "inputerror": + spec.InputError = true + case "prelim": + spec.Preliminary = true + case "provider": + spec.Provider = true + default: + return spec, fmt.Errorf("unknown tool tag %q in %q", tag, raw) + } + } + finalStates := 0 + for _, enabled := range []bool{spec.Fail, spec.Approval, spec.Deny} { + if enabled { + finalStates++ + } + } + if finalStates > 1 { + return spec, fmt.Errorf("tool spec %q has conflicting final state tags", raw) + } + return spec, nil +} + +type aiRunner struct { + runtime aiRuntime +} + +func (r aiRunner) runLorem(ctx context.Context, w *aistream.Writer, cmd loremCommand) error { + opts := cmd.Options + rng := rngForOptions(opts.SeedSet, opts.Seed, r.runtime.now().UnixNano()) + steps := max(opts.Steps, 1) + text := buildDemoVisibleText(cmd.Chars, rand.New(rand.NewSource(rng.Int63()))) + reasoning := buildLoremText(opts.ReasoningChars, rand.New(rand.NewSource(rng.Int63()))) + for step := range steps { + if opts.Steps > 0 { + w.StepStart(fmt.Sprintf("step-%d", step+1)) + } + emitDecorations(w, opts, cmd.Chars, step, steps) + if reasoning != "" { + w.Thinking(sliceByStep(reasoning, steps, step)) + } + for _, chunk := range chunkText(sliceByStep(text, steps, step), rng, opts.ChunkMin, opts.ChunkMax) { + w.Text(chunk) + if err := r.runtime.sleep(ctx, r.sampleDelay(rng, opts.DelayMin, opts.DelayMax)); err != nil { + return err + } + } + if opts.Steps > 0 { + w.StepFinish(fmt.Sprintf("step-%d", step+1)) + } + } + finishWriter(w, opts) + return nil +} + +func (r aiRunner) runTools(ctx context.Context, w *aistream.Writer, cmd toolsCommand) error { + opts := cmd.Options + rng := rngForOptions(opts.SeedSet, opts.Seed, r.runtime.now().UnixNano()) + phaseCount := max(len(cmd.Tools)+1, max(opts.Steps, 1)) + text := buildDemoVisibleText(cmd.Chars, rand.New(rand.NewSource(rng.Int63()))) + reasoning := buildLoremText(opts.ReasoningChars, rand.New(rand.NewSource(rng.Int63()))) + for phase := range phaseCount { + w.StepStart(fmt.Sprintf("phase-%d", phase+1)) + emitDecorations(w, opts, cmd.Chars, phase, phaseCount) + if reasoning != "" { + w.Thinking(sliceByStep(reasoning, phaseCount, phase)) + } + for _, chunk := range chunkText(sliceByStep(text, phaseCount, phase), rng, opts.ChunkMin, opts.ChunkMax) { + w.Text(chunk) + } + if phase < len(cmd.Tools) { + if err := r.runToolSpec(ctx, w, cmd.Tools[phase], rng, opts); err != nil { + if errors.Is(err, errApprovalRequested) { + w.StepFinish(fmt.Sprintf("phase-%d", phase+1)) + } + return err + } + } + w.StepFinish(fmt.Sprintf("phase-%d", phase+1)) + } + finishWriter(w, opts) + return nil +} + +func (r aiRunner) runRandom(ctx context.Context, w *aistream.Writer, cmd randomCommand) error { + seed := cmd.Seed + if !cmd.SeedSet { + seed = r.runtime.now().UnixNano() + } + rng := rand.New(rand.NewSource(seed)) + started := r.runtime.now() + var deadline time.Time + if cmd.Duration > 0 { + deadline = started.Add(cmd.Duration) + } + stepOpen := false + stepName := "" + for action := range cmd.Actions { + if !deadline.IsZero() && !r.runtime.now().Before(deadline) { + break + } + if action > 0 { + delay := r.sampleDelay(rng, cmd.DelayMin, cmd.DelayMax) + if !deadline.IsZero() && r.runtime.now().Add(delay).After(deadline) { + delay = deadline.Sub(r.runtime.now()) + } + if err := r.runtime.sleep(ctx, delay); err != nil { + return err + } + if !deadline.IsZero() && !r.runtime.now().Before(deadline) { + break + } + } + switch chooseRandomAction(cmd, rng) { + case randomActionText: + for _, chunk := range chunkText(buildDemoVisibleText(40+rng.Intn(160), rand.New(rand.NewSource(rng.Int63()))), rng, defaultChunkMin, defaultChunkMax) { + w.Text(chunk) + } + case randomActionThinking: + w.Thinking(buildLoremText(30+rng.Intn(120), rand.New(rand.NewSource(rng.Int63())))) + case randomActionStep: + if stepOpen { + w.StepFinish(stepName) + stepOpen = false + stepName = "" + } else { + stepName = fmt.Sprintf("random-step-%d", action+1) + w.StepStart(stepName) + stepOpen = true + } + case randomActionTool: + _ = r.runToolSpec(ctx, w, toolSpec{Name: randomToolName(rng), SequenceIndex: action + 1}, rng, defaultCommonOptions()) + case randomActionToolFail: + _ = r.runToolSpec(ctx, w, toolSpec{Name: randomToolName(rng), Fail: true, SequenceIndex: action + 1}, rng, defaultCommonOptions()) + case randomActionToolDeny: + _ = r.runToolSpec(ctx, w, toolSpec{Name: randomToolName(rng), Deny: true, SequenceIndex: action + 1}, rng, defaultCommonOptions()) + case randomActionToolApproval: + _ = r.runToolSpec(ctx, w, toolSpec{Name: randomToolName(rng), Approval: true, SequenceIndex: action + 1}, rng, defaultCommonOptions()) + case randomActionSource: + w.Custom("com.beeper.source", map[string]any{"url": fmt.Sprintf("https://dummybridge.local/random/source/%d", action+1), "title": fmt.Sprintf("Random Source %d", action+1)}) + case randomActionDocument: + w.Custom("com.beeper.document", map[string]any{"id": fmt.Sprintf("random-doc-%d", action+1), "title": fmt.Sprintf("Random Document %d", action+1), "mediaType": "text/plain"}) + case randomActionFile: + w.Custom("com.beeper.file", map[string]any{"url": fmt.Sprintf("mxc://dummybridge/random-file-%d", action+1), "mediaType": "application/octet-stream"}) + case randomActionMetadata: + w.StateDelta(statePatch(map[string]any{"command": "stream-random", "seed": seed, "action": action + 1, "profile": cmd.Profile})) + case randomActionData: + w.Custom("com.beeper.data", map[string]any{"name": "random", "value": map[string]any{"action": action + 1, "seed": seed}}) + case randomActionDataTransient: + w.Custom("com.beeper.data.transient", map[string]any{"name": "random", "value": map[string]any{"action": action + 1, "seed": seed}}) + } + } + if stepOpen { + w.StepFinish(stepName) + } + switch chooseRandomTerminal(cmd, rng) { + case "abort": + w.Abort("DummyBridge random mode aborted") + case "error": + w.Error("DummyBridge random mode failed") + default: + w.Finish(agui.FinishReasonStop) + } + return nil +} + +func (r aiRunner) runToolSpec(ctx context.Context, w *aistream.Writer, spec toolSpec, rng *rand.Rand, opts commonCommandOptions) error { + toolCallID := fmt.Sprintf("dummy-tool-%d-%s", spec.SequenceIndex, sanitizeToolName(spec.Name)) + input := map[string]any{"tool": spec.Name, "sequence": spec.SequenceIndex, "tags": spec.Tags} + approvalID := approvalIDForRun(w.Run.RunID, toolCallID) + var approval *agui.ToolApproval + if spec.Approval { + approval = &agui.ToolApproval{ID: approvalID, NeedsApproval: true} + } + w.ToolStart(toolCallID, spec.Name, spec.SequenceIndex-1, approval) + annotateProviderRawEvent(w, spec, "tool_call_start") + if spec.InputError { + w.ToolArgs(toolCallID, jsonToolInput(input), nil) + annotateProviderRawEvent(w, spec, "tool_call_args") + w.ToolError(toolCallID, spec.Name, input, "input-error") + annotateProviderRawEvent(w, spec, "tool_call_error") + return nil + } + if spec.Delta { + for _, chunk := range chunkText(fmt.Sprintf("{\"tool\":%q,\"sequence\":%d}", spec.Name, spec.SequenceIndex), rng, opts.ChunkMin, opts.ChunkMax) { + w.ToolArgs(toolCallID, chunk, nil) + annotateProviderRawEvent(w, spec, "tool_call_args") + if err := r.runtime.sleep(ctx, r.sampleDelay(rng, opts.DelayMin, opts.DelayMax)); err != nil { + return err + } + } + } else { + encodedInput := jsonToolInput(input) + w.ToolArgs(toolCallID, encodedInput, encodedInput) + annotateProviderRawEvent(w, spec, "tool_call_args") + } + if spec.Preliminary { + w.ToolResult(toolCallID, fmt.Sprintf(`{"state":%q,"tool":%q}`, agui.ToolResultStateStreaming, spec.Name), agui.ToolResultStateStreaming) + annotateProviderRawEvent(w, spec, "tool_call_result") + } + switch { + case spec.Approval: + w.ToolApprovalInputComplete(toolCallID, spec.Name, input) + annotateProviderRawEvent(w, spec, "tool_call_input_complete") + w.ToolApprovalRequested(toolCallID, spec.Name, input, *approval) + annotateProviderRawEvent(w, spec, "approval_requested") + return errApprovalRequested + case spec.Deny: + w.ToolDenied(toolCallID, spec.Name, input, approvalID, "denied") + annotateProviderRawEvent(w, spec, "tool_call_denied") + case spec.Fail: + w.ToolError(toolCallID, spec.Name, input, "DummyBridge synthetic tool failure") + annotateProviderRawEvent(w, spec, "tool_call_error") + default: + w.ToolEnd(toolCallID, spec.Name, input, map[string]any{"status": "ok", "tool": spec.Name, "sequence": spec.SequenceIndex}) + annotateProviderRawEvent(w, spec, "tool_call_end") + } + return nil +} + +func approvalIDForRun(runID, toolCallID string) string { + return "approval-" + runID + "-" + toolCallID +} + +func annotateProviderRawEvent(w *aistream.Writer, spec toolSpec, stage string) { + if !spec.Provider || w == nil || w.Run == nil || len(w.Run.Events) == 0 { + return + } + w.Run.Events[len(w.Run.Events)-1]["rawEvent"] = map[string]any{ + "provider": "dummybridge", + "stage": stage, + "tool": spec.Name, + "sequence": spec.SequenceIndex, + "tags": spec.Tags, + } +} + +func jsonToolInput(input map[string]any) string { + raw, err := json.Marshal(input) + if err != nil { + return "{}" + } + return string(raw) +} + +func finishWriter(w *aistream.Writer, opts commonCommandOptions) { + switch { + case opts.Abort: + w.Abort("DummyBridge synthetic abort") + case opts.Error: + w.Error("DummyBridge synthetic error") + default: + w.Finish(opts.FinishReason) + } +} + +func emitDecorations(w *aistream.Writer, opts commonCommandOptions, chars, step, steps int) { + if opts.Meta { + seed := opts.Seed + if !opts.SeedSet { + seed = int64(chars) + } + w.StateDelta(statePatch(map[string]any{"command": "demo", "seed": seed, "step": step + 1})) + } + for i := range splitCount(opts.Sources, steps, step) { + w.Custom("com.beeper.source", map[string]any{"url": fmt.Sprintf("https://dummybridge.local/source/%d-%d", step+1, i+1), "title": fmt.Sprintf("Demo Source %d.%d", step+1, i+1)}) + } + for i := range splitCount(opts.Documents, steps, step) { + w.Custom("com.beeper.document", map[string]any{"id": fmt.Sprintf("demo-doc-%d-%d", step+1, i+1), "title": fmt.Sprintf("Demo Document %d.%d", step+1, i+1), "mediaType": "text/plain"}) + } + for i := range splitCount(opts.Files, steps, step) { + w.Custom("com.beeper.file", map[string]any{"url": fmt.Sprintf("mxc://dummybridge/demo-file-%d-%d", step+1, i+1), "mediaType": "application/octet-stream"}) + } + if step == 0 && opts.DataName != "" { + w.Custom("com.beeper.data", map[string]any{"name": opts.DataName, "value": map[string]any{"mode": "persistent", "stage": step + 1}}) + } + if step == 0 && opts.DataTransientName != "" { + w.Custom("com.beeper.data.transient", map[string]any{"name": opts.DataTransientName, "value": map[string]any{"mode": "transient", "stage": step + 1}}) + } +} + +func statePatch(values map[string]any) []map[string]any { + keys := make([]string, 0, len(values)) + for key := range values { + keys = append(keys, key) + } + sort.Strings(keys) + patch := make([]map[string]any, 0, len(keys)) + for _, key := range keys { + patch = append(patch, map[string]any{ + "op": "add", + "path": "/" + key, + "value": values[key], + }) + } + return patch +} + +func chooseRandomAction(cmd randomCommand, rng *rand.Rand) string { + options := []randomActionOption{ + {randomActionText, 6}, + {randomActionThinking, 4}, + {randomActionStep, 2}, + {randomActionTool, 3}, + {randomActionToolFail, 2}, + {randomActionSource, 2}, + {randomActionDocument, 2}, + {randomActionFile, 2}, + {randomActionMetadata, 2}, + {randomActionData, 1}, + {randomActionDataTransient, 1}, + } + if cmd.AllowApproval { + options = append(options, randomActionOption{randomActionToolApproval, 2}) + } + switch cmd.Profile { + case "tools": + options = append(options, + randomActionOption{randomActionTool, 6}, + randomActionOption{randomActionToolFail, 4}, + randomActionOption{randomActionToolDeny, 3}, + ) + if cmd.AllowApproval { + options = append(options, randomActionOption{randomActionToolApproval, 4}) + } + case "artifacts": + options = append(options, + randomActionOption{randomActionSource, 4}, + randomActionOption{randomActionDocument, 4}, + randomActionOption{randomActionFile, 4}, + randomActionOption{randomActionMetadata, 3}, + randomActionOption{randomActionData, 3}, + randomActionOption{randomActionDataTransient, 3}, + ) + case "terminals": + options = append(options, randomActionOption{randomActionStep, 5}) + } + total := 0 + for _, option := range options { + total += option.weight + } + pick := rng.Intn(total) + for _, option := range options { + if pick < option.weight { + return option.name + } + pick -= option.weight + } + return randomActionText +} + +func chooseRandomTerminal(cmd randomCommand, rng *rand.Rand) string { + options := []string{"finish"} + if cmd.AllowAbort { + options = append(options, "abort") + } + if cmd.AllowError { + options = append(options, "error") + } + return options[rng.Intn(len(options))] +} + +func randomToolName(rng *rand.Rand) string { + names := []string{"search", "fetch", "summarize", "calendar", "shell", "files", "preview"} + return names[rng.Intn(len(names))] +} + +func (r aiRunner) sampleDelay(rng *rand.Rand, minDelay, maxDelay time.Duration) time.Duration { + if maxDelay <= minDelay { + return minDelay + } + return minDelay + time.Duration(rng.Int63n(int64(maxDelay-minDelay)+1)) +} + +func parseOptionToken(token string) (string, string, bool) { + trimmed := strings.TrimPrefix(strings.TrimSpace(token), "--") + key, value, ok := strings.Cut(trimmed, "=") + return strings.ToLower(strings.TrimSpace(key)), strings.TrimSpace(value), ok +} + +func parseValidatedInt(value string, hasValue bool, token, label string, maxValue int, allowZero bool) (int, error) { + if !hasValue { + return 0, fmt.Errorf("%s requires a value", token) + } + var n int + var err error + if allowZero { + n, err = parseNonNegativeInt(value, label) + } else { + n, err = parsePositiveInt(value, label) + } + if err != nil { + return 0, err + } + return n, validateMaxIntValue(n, maxValue, label) +} + +func parsePositiveInt(raw, label string) (int, error) { + n, err := strconv.Atoi(strings.TrimSpace(raw)) + if err != nil || n <= 0 { + return 0, fmt.Errorf("invalid %s %q", label, raw) + } + return n, nil +} + +func parseNonNegativeInt(raw, label string) (int, error) { + n, err := strconv.Atoi(strings.TrimSpace(raw)) + if err != nil || n < 0 { + return 0, fmt.Errorf("invalid %s %q", label, raw) + } + return n, nil +} + +func parseDurationRangeMS(value string, hasValue bool, token string) (time.Duration, time.Duration, error) { + return parseDurationRange(value, hasValue, token, "delay-ms", maxDemoDelay) +} + +func parseDurationRange(value string, hasValue bool, token, label string, maxValue time.Duration) (time.Duration, time.Duration, error) { + minValue, maxRange, err := parseIntRangeOption(value, hasValue, token, label, int(maxValue/time.Millisecond)) + if err != nil { + return 0, 0, err + } + return time.Duration(minValue) * time.Millisecond, time.Duration(maxRange) * time.Millisecond, nil +} + +func parseIntRangeOption(value string, hasValue bool, token, label string, maxValue int) (int, int, error) { + if !hasValue { + return 0, 0, fmt.Errorf("%s requires a value", token) + } + minValue, maxRange, ok := strings.Cut(value, ":") + if !ok { + n, err := parseNonNegativeInt(value, label) + if err != nil { + return 0, 0, err + } + if err := validateMaxIntValue(n, maxValue, label); err != nil { + return 0, 0, err + } + return n, n, nil + } + minInt, err := parseNonNegativeInt(minValue, label) + if err != nil { + return 0, 0, err + } + maxInt, err := parseNonNegativeInt(maxRange, label) + if err != nil { + return 0, 0, err + } + if maxInt < minInt { + return 0, 0, fmt.Errorf("invalid %s range %q", label, value) + } + if err := validateMaxIntValue(maxInt, maxValue, label); err != nil { + return 0, 0, err + } + return minInt, maxInt, nil +} + +func validateMaxIntValue(value, maxValue int, label string) error { + if value > maxValue { + return fmt.Errorf("%s %d exceeds the maximum of %d", label, value, maxValue) + } + return nil +} + +func rngForOptions(seedSet bool, seed, fallback int64) *rand.Rand { + if !seedSet { + seed = fallback + } + return rand.New(rand.NewSource(seed)) +} + +func chunkText(text string, rng *rand.Rand, minChunk, maxChunk int) []string { + if strings.TrimSpace(text) == "" { + return nil + } + if minChunk <= 0 { + minChunk = defaultChunkMin + } + if maxChunk < minChunk { + maxChunk = minChunk + } + var chunks []string + for len(text) > 0 { + size := minChunk + if maxChunk > minChunk { + size += rng.Intn(maxChunk - minChunk + 1) + } + if size > len(text) { + size = len(text) + } + chunks = append(chunks, text[:size]) + text = text[size:] + } + return chunks +} + +func splitCount(total, parts, index int) int { + if total <= 0 || parts <= 0 || index < 0 || index >= parts { + return 0 + } + base := total / parts + remainder := total % parts + if index < remainder { + return base + 1 + } + return base +} + +func sliceByStep(text string, parts, index int) string { + if parts <= 1 || text == "" { + return text + } + start := 0 + for i := 0; i < index; i++ { + start += splitCount(len(text), parts, i) + } + length := splitCount(len(text), parts, index) + if start >= len(text) || length <= 0 { + return "" + } + end := min(start+length, len(text)) + return text[start:end] +} + +func sanitizeToolName(name string) string { + name = strings.ToLower(strings.TrimSpace(name)) + var out strings.Builder + for _, r := range name { + if r >= 'a' && r <= 'z' || r >= '0' && r <= '9' || r == '_' || r == '-' { + out.WriteRune(r) + } + } + if out.Len() == 0 { + return "tool" + } + return out.String() +} diff --git a/pkg/connector/ai_runtime_test.go b/pkg/connector/ai_runtime_test.go new file mode 100644 index 0000000..97feefe --- /dev/null +++ b/pkg/connector/ai_runtime_test.go @@ -0,0 +1,506 @@ +package connector + +import ( + "context" + "encoding/json" + "math/rand" + "strings" + "testing" + "time" + + "github.com/beeper/dummybridge/pkg/ag-ui" + "github.com/beeper/dummybridge/pkg/ai-stream" +) + +func TestParseCommandRecognizesHelpAliases(t *testing.T) { + for _, input := range []string{"help", "/help", "!help", "dummybridge help"} { + cmd, err := parseCommand(input) + if err != nil { + t.Fatalf("parseCommand(%q) returned error: %v", input, err) + } + if cmd == nil || cmd.Name != "help" { + t.Fatalf("expected help command for %q, got %#v", input, cmd) + } + } +} + +func TestParseCommandRejectsConflictingToolTags(t *testing.T) { + _, err := parseCommand("stream-tools 100 shell#fail#approval") + if err == nil { + t.Fatal("expected parse error for conflicting tool tags") + } +} + +func TestParseCommandRejectsInvalidProfilesAndOversizedOptions(t *testing.T) { + tests := []string{ + "stream-random --profile=unknown", + "stream-lorem 100 --abort --error", + "stream-lorem 100 --finish=length --abort", + "stream-lorem 1000000", + "stream-tools 100 shell --chunk-chars=1:9999", + } + for _, input := range tests { + if _, err := parseCommand(input); err == nil { + t.Fatalf("expected parse error for %q", input) + } + } +} + +func TestHelpTextMentionsCommandsOptionsAndToolTags(t *testing.T) { + guide := helpText() + for _, expected := range []string{ + "stream-lorem", + "stream-tools", + "stream-random", + "stream-chaos", + "--data-transient", + "--allow-approval", + "#provider", + "#inputerror", + } { + if !strings.Contains(guide, expected) { + t.Fatalf("help text missing %q:\n%s", expected, guide) + } + } +} + +func TestBuildAIRunLoremIncludesArtifactsStateAndMetadata(t *testing.T) { + run, err := buildAIRun(context.Background(), "run-1", "thread-1", "stream-lorem 400 --reasoning=80 --steps=2 --sources=1 --documents=1 --files=1 --meta --data=demo --data-transient=temp --seed=7 --chunk-chars=32:32", time.Unix(10, 0)) + if err != nil { + t.Fatal(err) + } + seen := map[string]bool{} + for _, evt := range run.Events { + switch evt["type"] { + case agui.EventTextMessageContent, agui.EventStepStarted, agui.EventStepFinished: + seen[evt["type"].(string)] = true + case agui.EventStateDelta: + seen[evt["type"].(string)] = true + if _, ok := evt["delta"].([]map[string]any); !ok { + t.Fatalf("STATE_DELTA should use JSON Patch array, got %#v", evt["delta"]) + } + case agui.EventCustom: + name, _ := evt["name"].(string) + seen[name] = true + if name == "com.beeper.data" { + value := evt["value"].(map[string]any) + if value["name"] == "temp" { + t.Fatal("transient data must not persist as metadata") + } + } + } + } + for _, key := range []string{agui.EventTextMessageContent, agui.EventStepStarted, agui.EventStepFinished, agui.EventStateDelta, "com.beeper.source", "com.beeper.document", "com.beeper.file", "com.beeper.data", "com.beeper.data.transient"} { + if !seen[key] { + t.Fatalf("missing %s in events", key) + } + } + metadata := run.Metadata() + if metadata["model"] == "" || metadata["threadId"] != "thread-1" || metadata["runId"] != "run-1" { + t.Fatalf("bad metadata: %#v", metadata) + } + data := metadata["data"].(map[string]any) + if _, ok := data["temp"]; ok { + t.Fatalf("transient data leaked into final metadata: %#v", data) + } +} + +func TestBuildAIRunToolsApprovalUsesAGUIApprovalAndPrompt(t *testing.T) { + run, err := buildAIRun(context.Background(), "run-1", "thread-1", "stream-tools 120 shell#approval --seed=7 --chunk-chars=32:32", time.Unix(10, 0)) + if err != nil { + t.Fatal(err) + } + if len(run.Prompts) != 1 { + t.Fatalf("expected one approval prompt, got %#v", run.Prompts) + } + if run.Prompts[0].ID != "approval-run-1-dummy-tool-1-shell" { + t.Fatalf("approval prompt ID = %q, want run-scoped ID", run.Prompts[0].ID) + } + foundToolStart := false + seenArgsBeforeApproval := false + seenApprovalStateBeforeCustom := false + for _, evt := range run.Events { + if evt["type"] == agui.EventToolCallArgs { + seenArgsBeforeApproval = true + } + if evt["type"] == agui.EventToolCallStart { + if evt["state"] != agui.ToolStateApprovalRequested { + t.Fatalf("expected approval-requested tool state, got %#v", evt) + } + approval, ok := evt["approval"].(*agui.ToolApproval) + if !ok { + t.Fatalf("expected tool start approval metadata, got %#v", evt["approval"]) + } + if approval.ID != "approval-run-1-dummy-tool-1-shell" || !approval.NeedsApproval { + t.Fatalf("bad approval metadata: %#v", approval) + } + foundToolStart = true + } + if evt["type"] == agui.EventToolCallEnd { + if evt["state"] == agui.ToolStateInputComplete { + t.Fatalf("approval tool must not downgrade to input-complete: %#v", evt) + } + if evt["state"] == agui.ToolStateApprovalRequested { + if evt["input"] == nil { + t.Fatalf("approval input-complete event should include final input: %#v", evt) + } + seenApprovalStateBeforeCustom = true + } + } + if evt["type"] == agui.EventCustom && evt["name"] == agui.ApprovalCustomRequested { + if !seenArgsBeforeApproval || !seenApprovalStateBeforeCustom { + t.Fatalf("approval custom event should be emitted after tool args and approval state update: %#v", run.Events) + } + value := evt["value"].(map[string]any) + if _, hasOptions := value["options"]; hasOptions { + t.Fatalf("AG-UI approval event must not embed Matrix reaction options: %#v", value) + } + if value["input"] == nil { + t.Fatalf("approval event should include final tool input: %#v", value) + } + } + } + if !foundToolStart { + t.Fatal("missing tool start event") + } + if run.Status.State != "streaming" { + t.Fatalf("approval request should pause the run without terminal status, got %#v", run.Status) + } + for _, evt := range run.Events { + if evt["type"] == agui.EventRunFinished { + t.Fatalf("approval request should not finish the run before response: %#v", run.Events) + } + } +} + +func TestBuildAIRunToolsDenyProducesStructuredDeniedResult(t *testing.T) { + run, err := buildAIRun(context.Background(), "run-1", "thread-1", "stream-tools 120 shell#deny --seed=7 --chunk-chars=32:32", time.Unix(10, 0)) + if err != nil { + t.Fatal(err) + } + for _, evt := range run.Events { + if evt["type"] != agui.EventToolCallEnd { + continue + } + result := jsonResultMap(t, evt["result"]) + if result["state"] == agui.ToolResultStateError && result["reason"] == "denied" { + return + } + } + t.Fatalf("missing structured denied tool result: %#v", run.Events) +} + +func TestBuildAIRunToolsArgsAreJSONStrings(t *testing.T) { + run, err := buildAIRun(context.Background(), "run-1", "thread-1", "stream-tools 120 shell --seed=7 --chunk-chars=32:32", time.Unix(10, 0)) + if err != nil { + t.Fatal(err) + } + for _, evt := range run.Events { + if evt["type"] != agui.EventToolCallArgs { + continue + } + args, ok := evt["args"].(string) + if !ok { + t.Fatalf("expected args to be JSON string, got %#v", evt["args"]) + } + if !strings.Contains(args, `"tool":"shell"`) { + t.Fatalf("expected JSON tool args, got %q", args) + } + return + } + t.Fatal("missing TOOL_CALL_ARGS event") +} + +func TestBuildAIRunToolsPrelimUsesAGUIToolResult(t *testing.T) { + run, err := buildAIRun(context.Background(), "run-1", "thread-1", "stream-tools 120 shell#prelim --seed=7 --chunk-chars=32:32", time.Unix(10, 0)) + if err != nil { + t.Fatal(err) + } + for _, evt := range run.Events { + if evt["type"] != agui.EventToolCallResult { + continue + } + if evt["state"] != agui.ToolResultStateStreaming || evt["toolCallId"] == "" || evt["content"] == "" { + t.Fatalf("bad TOOL_CALL_RESULT event: %#v", evt) + } + return + } + t.Fatal("missing TOOL_CALL_RESULT event") +} + +func TestBuildAIRunFinalSnapshotPreservesToolParts(t *testing.T) { + run, err := buildAIRun(context.Background(), "run-1", "thread-1", "stream-tools 120 shell#prelim fetch#fail --seed=7 --chunk-chars=32:32", time.Unix(10, 0)) + if err != nil { + t.Fatal(err) + } + var snapshot []agui.UIMessage + seenRunFinished := false + for _, evt := range run.Events { + switch evt["type"] { + case agui.EventMessagesSnapshot: + if seenRunFinished { + t.Fatal("final snapshot must be emitted before RUN_FINISHED") + } + var ok bool + snapshot, ok = evt["messages"].([]agui.UIMessage) + if !ok { + t.Fatalf("bad snapshot payload: %#v", evt["messages"]) + } + case agui.EventRunFinished: + seenRunFinished = true + } + } + if len(snapshot) != 1 { + t.Fatalf("expected one final UI message snapshot, got %#v", snapshot) + } + seenToolCall := false + seenToolResult := false + for _, part := range snapshot[0].Parts { + switch part["type"] { + case "tool-call": + seenToolCall = true + case "tool-result": + seenToolResult = true + } + } + if !seenToolCall || !seenToolResult { + t.Fatalf("final snapshot lost tool parts: %#v", snapshot[0].Parts) + } +} + +func TestBuildAIRunToolsFailureDeltaAndInputError(t *testing.T) { + run, err := buildAIRun(context.Background(), "run-1", "thread-1", "stream-tools 120 shell#fail fetch#delta parser#inputerror --seed=7 --chunk-chars=8:8", time.Unix(10, 0)) + if err != nil { + t.Fatal(err) + } + seenFailure := false + seenDelta := false + seenInputError := false + for _, evt := range run.Events { + if evt["type"] != agui.EventToolCallEnd && evt["type"] != agui.EventToolCallArgs { + continue + } + toolCallID, _ := evt["toolCallId"].(string) + if evt["type"] == agui.EventToolCallArgs && strings.Contains(toolCallID, "fetch") && evt["args"] == nil { + seenDelta = true + } + if evt["type"] == agui.EventToolCallEnd { + result := jsonResultMap(t, evt["result"]) + if strings.Contains(toolCallID, "shell") && result["state"] == agui.ToolResultStateError { + seenFailure = true + } + if strings.Contains(toolCallID, "parser") && result["reason"] == "input-error" { + seenInputError = true + } + } + } + if !seenFailure || !seenDelta || !seenInputError { + t.Fatalf("missing tool tag coverage: failure=%v delta=%v inputError=%v", seenFailure, seenDelta, seenInputError) + } +} + +func TestBuildAIRunToolsProviderTagAddsRawEventPassthrough(t *testing.T) { + run, err := buildAIRun(context.Background(), "run-1", "thread-1", "stream-tools 120 shell#provider --seed=7 --chunk-chars=32:32", time.Unix(10, 0)) + if err != nil { + t.Fatal(err) + } + for _, evt := range run.Events { + raw, ok := evt["rawEvent"].(map[string]any) + if !ok { + continue + } + if raw["provider"] != "dummybridge" || raw["tool"] != "shell" { + t.Fatalf("bad raw provider event: %#v", raw) + } + carriers, err := aistream.PackRun(*run, "$anchor", aistream.CarrierBudgetBytes) + if err != nil { + t.Fatal(err) + } + if len(carriers) == 0 { + t.Fatal("expected packed carriers") + } + return + } + t.Fatal("missing rawEvent for provider-tagged tool") +} + +func TestBuildAIRunTerminalErrorAndAbortStates(t *testing.T) { + errorRun, err := buildAIRun(context.Background(), "run-error", "thread-1", "stream-lorem 80 --error --seed=7", time.Unix(10, 0)) + if err != nil { + t.Fatal(err) + } + if errorRun.Status.State != "error" { + t.Fatalf("expected error status, got %#v", errorRun.Status) + } + abortRun, err := buildAIRun(context.Background(), "run-abort", "thread-1", "stream-lorem 80 --abort --seed=7", time.Unix(10, 0)) + if err != nil { + t.Fatal(err) + } + if abortRun.Status.State != "aborted" { + t.Fatalf("expected aborted status, got %#v", abortRun.Status) + } +} + +func TestBuildAIRunOver64KBPacksTo58KCarriers(t *testing.T) { + run, err := buildAIRun(context.Background(), "run-1", "thread-1", "stream-lorem 70000 --seed=7 --chunk-chars=512:512", time.Unix(10, 0)) + if err != nil { + t.Fatal(err) + } + carriers, err := aistream.PackRun(*run, "$anchor", aistream.CarrierBudgetBytes) + if err != nil { + t.Fatal(err) + } + if len(carriers) < 2 { + t.Fatalf("expected split carriers, got %d", len(carriers)) + } + for i, carrier := range carriers { + if size := aistream.JSONSize(aistream.CarrierContent(carrier.Envelopes)); size > aistream.CarrierBudgetBytes { + t.Fatalf("carrier %d size = %d", i, size) + } + } + for _, carrier := range carriers { + for _, envelope := range carrier.Envelopes { + if envelope.Part["type"] != agui.EventMessagesSnapshot { + continue + } + raw, err := json.Marshal(envelope.Part) + if err != nil { + t.Fatal(err) + } + if strings.Contains(string(raw), strings.Repeat("a", 60*1024)) { + t.Fatal("final snapshot should not repeat full streamed text") + } + } + } + if len(aistream.ReconstructText(carriers)) < 60*1024 { + t.Fatalf("expected large reconstructed output, got %d", len(aistream.ReconstructText(carriers))) + } +} + +func TestBuildAIRunPlansChaosCreatesMultipleRuns(t *testing.T) { + plans, err := buildAIRunPlans(context.Background(), "run-chaos", "thread-1", "stream-chaos 3 1 --max-actions=3 --seed=7 --stagger-ms=1:1", time.Unix(10, 0)) + if err != nil { + t.Fatal(err) + } + if len(plans) != 3 { + t.Fatalf("expected three chaos runs, got %d", len(plans)) + } + seen := map[string]bool{} + for i, plan := range plans { + if plan.Run == nil { + t.Fatalf("nil run at %d", i) + } + if seen[plan.Run.RunID] { + t.Fatalf("duplicate run ID %q", plan.Run.RunID) + } + seen[plan.Run.RunID] = true + if plan.Run.ThreadID != "thread-1" { + t.Fatalf("bad thread ID: %q", plan.Run.ThreadID) + } + if i > 0 && plan.Delay <= 0 { + t.Fatalf("expected nonzero child stagger delay, got %#v", plans) + } + } +} + +func TestBuildAIRunRandomHonorsVirtualDelays(t *testing.T) { + run, err := buildAIRun(context.Background(), "run-1", "thread-1", "stream-random 3 --actions=4 --seed=7 --delay-ms=100:100", time.Unix(10, 0)) + if err != nil { + t.Fatal(err) + } + var first, last int64 + for _, evt := range run.Events { + ts, _ := evt["timestamp"].(int64) + if ts == 0 { + if n, ok := evt["timestamp"].(int); ok { + ts = int64(n) + } + } + if ts == 0 { + continue + } + if first == 0 { + first = ts + } + last = ts + } + if first == 0 || last-first < 300 { + t.Fatalf("expected random run timestamps to reflect action delays, first=%d last=%d", first, last) + } +} + +func TestRandomProfilesCoverToolsArtifactsAndTransientData(t *testing.T) { + cmd := randomCommand{ + sharedStreamOptions: sharedStreamOptions{ + Profile: "tools", + AllowApproval: true, + }, + } + seen := map[string]bool{} + rng := rand.New(rand.NewSource(4)) + for range 400 { + seen[chooseRandomAction(cmd, rng)] = true + } + for _, action := range []string{randomActionTool, randomActionToolFail, randomActionToolDeny, randomActionToolApproval} { + if !seen[action] { + t.Fatalf("tools profile never selected %s; seen=%#v", action, seen) + } + } + + cmd.Profile = "artifacts" + seen = map[string]bool{} + rng = rand.New(rand.NewSource(8)) + for range 400 { + seen[chooseRandomAction(cmd, rng)] = true + } + for _, action := range []string{randomActionSource, randomActionDocument, randomActionFile, randomActionMetadata, randomActionData, randomActionDataTransient} { + if !seen[action] { + t.Fatalf("artifacts profile never selected %s; seen=%#v", action, seen) + } + } +} + +func TestRandomTerminalUsesAllowedOutcomes(t *testing.T) { + cmd := randomCommand{sharedStreamOptions: sharedStreamOptions{AllowAbort: true, AllowError: true}} + seen := map[string]bool{} + rng := rand.New(rand.NewSource(10)) + for range 80 { + seen[chooseRandomTerminal(cmd, rng)] = true + } + for _, terminal := range []string{"finish", "abort", "error"} { + if !seen[terminal] { + t.Fatalf("terminal %s was never selected; seen=%#v", terminal, seen) + } + } + + if terminal := chooseRandomTerminal(randomCommand{}, rand.New(rand.NewSource(1))); terminal != "finish" { + t.Fatalf("unexpected terminal without flags: %q", terminal) + } +} + +func TestBuildDemoVisibleTextIsMarkdownRichAndDeterministic(t *testing.T) { + first := buildDemoVisibleText(420, rand.New(rand.NewSource(7))) + second := buildDemoVisibleText(420, rand.New(rand.NewSource(7))) + if first != second { + t.Fatalf("expected deterministic output for seed") + } + for _, signal := range []string{"[", "](", "**", "\n- ", "\n> ", "```", "\n| "} { + if strings.Contains(first, signal) { + return + } + } + t.Fatalf("expected markdown-rich text, got %q", first) +} + +func jsonResultMap(t *testing.T, value any) map[string]any { + t.Helper() + text, ok := value.(string) + if !ok { + t.Fatalf("expected JSON string result, got %#v", value) + } + var out map[string]any + if err := json.Unmarshal([]byte(text), &out); err != nil { + t.Fatalf("failed to parse result %q: %v", text, err) + } + return out +} diff --git a/pkg/connector/ai_text.go b/pkg/connector/ai_text.go new file mode 100644 index 0000000..960fe8e --- /dev/null +++ b/pkg/connector/ai_text.go @@ -0,0 +1,214 @@ +package connector + +import ( + "fmt" + "math/rand" + "strings" +) + +var loremSentenceCorpus = []string{ + "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", + "Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.", + "Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.", + "Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.", + "Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.", + "Integer nec odio praesent libero sed cursus ante dapibus diam.", + "Nulla quis sem at nibh elementum imperdiet duis sagittis ipsum.", + "Praesent mauris fusce nec tellus sed augue semper porta.", + "Mauris massa vestibulum lacinia arcu eget nulla.", + "Class aptent taciti sociosqu ad litora torquent per conubia nostra.", + "In consectetur orci eu erat varius, vitae facilisis lorem blandit.", + "Curabitur ullamcorper ultricies nisi nam eget dui etiam rhoncus.", +} + +var demoMarkdownLabels = []string{"release notes", "ops runbook", "incident log", "design memo", "qa checklist", "support brief"} +var demoMarkdownURLs = []string{ + "https://dummybridge.local/docs/streaming", + "https://dummybridge.local/docs/markdown", + "https://dummybridge.local/runbooks/runs", + "https://dummybridge.local/notes/demo-output", +} +var demoMarkdownEmphasis = []string{"high-signal", "operator-visible", "tool-safe", "incremental", "review-ready"} +var demoMarkdownListItems = []string{ + "Confirm the seeded output changes shape between runs.", + "Surface enough formatting to stress the renderer.", + "Keep deltas readable while chunks arrive out of phase.", + "Preserve stable output for deterministic test fixtures.", + "Expose links, tables, and code blocks without extra flags.", +} +var demoMarkdownQuoteCorpus = []string{ + "Streaming output should feel alive, not like the same paragraph repeated forever.", + "Richer markdown gives the client something realistic to render while the run is still open.", +} +var demoMarkdownCodeSnippets = []string{ + "const preview = chunks.filter(Boolean).join(\"\");", + "writer.textDelta(\"| status | value |\\n| --- | --- |\\n\");", + "if (seeded) { return renderMarkdownBlocks(); }", +} +var demoMarkdownTableHeaders = [][]string{{"Metric", "Value", "Notes"}, {"Phase", "Owner", "Status"}, {"Artifact", "State", "Latency"}} +var demoMarkdownTableRows = [][]string{ + {"stream", "warming", "steady deltas"}, + {"renderer", "active", "accepts markdown"}, + {"tool call", "complete", "output persisted"}, + {"search step", "queued", "awaiting sources"}, + {"summary", "ready", "links attached"}, +} + +type demoSegmentSpec struct { + weight int + minLen int + build func(*rand.Rand, int) string +} + +func buildLoremText(chars int, rng *rand.Rand) string { + if chars <= 0 { + return "" + } + if rng == nil { + rng = rand.New(rand.NewSource(int64(chars))) + } + var sb strings.Builder + sb.Grow(chars + 128) + lastIndex := -1 + for sb.Len() < chars+64 { + index := rng.Intn(len(loremSentenceCorpus)) + if len(loremSentenceCorpus) > 1 && index == lastIndex { + index = (index + 1 + rng.Intn(len(loremSentenceCorpus)-1)) % len(loremSentenceCorpus) + } + if sb.Len() > 0 { + sb.WriteByte(' ') + } + sb.WriteString(loremSentenceCorpus[index]) + lastIndex = index + } + return trimText(sb.String(), chars) +} + +func buildDemoVisibleText(chars int, rng *rand.Rand) string { + if chars <= 0 { + return "" + } + if rng == nil { + rng = rand.New(rand.NewSource(int64(chars))) + } + segments := []demoSegmentSpec{ + {weight: 5, minLen: 48, build: func(rng *rand.Rand, remaining int) string { + return buildLoremText(max(48, min(168, remaining+48)), rand.New(rand.NewSource(rng.Int63()))) + }}, + {weight: 4, minLen: 96, build: func(rng *rand.Rand, _ int) string { + return fmt.Sprintf("%s Review the [%s](%s) entry for **%s** output and _staged_ formatting transitions.", + buildLoremText(72+rng.Intn(48), rand.New(rand.NewSource(rng.Int63()))), + demoMarkdownLabels[rng.Intn(len(demoMarkdownLabels))], + demoMarkdownURLs[rng.Intn(len(demoMarkdownURLs))], + demoMarkdownEmphasis[rng.Intn(len(demoMarkdownEmphasis))]) + }}, + {weight: 3, minLen: 96, build: func(rng *rand.Rand, _ int) string { + var lines []string + for i := 0; i < 2+rng.Intn(3); i++ { + prefix := "-" + if rng.Intn(4) == 0 { + prefix = "- [x]" + } + lines = append(lines, fmt.Sprintf("%s %s", prefix, demoMarkdownListItems[(rng.Intn(len(demoMarkdownListItems))+i)%len(demoMarkdownListItems)])) + } + return strings.Join(lines, "\n") + }}, + {weight: 2, minLen: 72, build: func(rng *rand.Rand, _ int) string { + return fmt.Sprintf("> %s\n>\n> %s", demoMarkdownQuoteCorpus[rng.Intn(len(demoMarkdownQuoteCorpus))], buildLoremText(48+rng.Intn(36), rand.New(rand.NewSource(rng.Int63())))) + }}, + {weight: 2, minLen: 72, build: func(rng *rand.Rand, _ int) string { + return fmt.Sprintf("Use `%s` for incremental patches.\n\n```js\n%s\n```", sanitizeToolName(demoMarkdownLabels[rng.Intn(len(demoMarkdownLabels))]), demoMarkdownCodeSnippets[rng.Intn(len(demoMarkdownCodeSnippets))]) + }}, + {weight: 2, minLen: 180, build: func(rng *rand.Rand, _ int) string { + header := demoMarkdownTableHeaders[rng.Intn(len(demoMarkdownTableHeaders))] + lines := []string{fmt.Sprintf("| %s |", strings.Join(header, " | ")), "| --- | --- | --- |"} + for i := 0; i < 2+rng.Intn(2); i++ { + lines = append(lines, fmt.Sprintf("| %s |", strings.Join(demoMarkdownTableRows[(rng.Intn(len(demoMarkdownTableRows))+i)%len(demoMarkdownTableRows)], " | "))) + } + return strings.Join(lines, "\n") + }}, + } + var blocks []string + total := 0 + for total < chars { + block := chooseDemoSegment(segments, rng, chars-total) + blocks = append(blocks, block) + total += len(block) + 2 + } + return trimVisibleText(strings.Join(blocks, "\n\n"), chars) +} + +func chooseDemoSegment(specs []demoSegmentSpec, rng *rand.Rand, remaining int) string { + var candidates []demoSegmentSpec + total := 0 + for _, spec := range specs { + if remaining > 0 && remaining < spec.minLen/2 { + continue + } + candidates = append(candidates, spec) + total += spec.weight + } + if len(candidates) == 0 { + candidates = specs + for _, spec := range candidates { + total += spec.weight + } + } + target := rng.Intn(total) + for _, spec := range candidates { + target -= spec.weight + if target < 0 { + return spec.build(rng, remaining) + } + } + return candidates[0].build(rng, remaining) +} + +func trimVisibleText(text string, limit int) string { + text = strings.TrimSpace(text) + if len(text) <= limit { + return text + } + blocks := strings.Split(text, "\n\n") + var kept []string + total := 0 + for _, block := range blocks { + block = strings.TrimSpace(block) + if block == "" { + continue + } + next := total + len(block) + if len(kept) > 0 { + next += 2 + } + if next > limit && len(kept) > 0 { + break + } + kept = append(kept, block) + total = next + } + if len(kept) > 0 { + return strings.Join(kept, "\n\n") + } + return trimText(text, limit) +} + +func trimText(text string, limit int) string { + text = strings.TrimSpace(text) + if limit <= 0 || len(text) <= limit { + return text + } + minCutoff := max(1, (limit*3)/4) + for i := min(limit, len(text)); i >= minCutoff; i-- { + switch text[i-1] { + case '.', '!', '?': + return strings.TrimSpace(text[:i]) + } + } + for i := min(limit, len(text)); i >= minCutoff; i-- { + if text[i-1] == ' ' { + return strings.Trim(strings.TrimSpace(text[:i]), ".,;:") + } + } + return strings.Trim(strings.TrimSpace(text[:limit]), ".,;:") +} diff --git a/pkg/connector/client.go b/pkg/connector/client.go index 38f3d05..566669a 100644 --- a/pkg/connector/client.go +++ b/pkg/connector/client.go @@ -2,6 +2,7 @@ package connector import ( "context" + "encoding/json" "errors" "fmt" "regexp" @@ -9,6 +10,9 @@ import ( "sync" "time" + "github.com/beeper/dummybridge/pkg/ag-ui" + "github.com/beeper/dummybridge/pkg/ai-stream" + aibridgev2 "github.com/beeper/dummybridge/pkg/ai-stream/bridgev2" "github.com/rs/zerolog/log" "go.mau.fi/util/jsontime" "go.mau.fi/util/ptr" @@ -19,6 +23,7 @@ import ( "maunium.net/go/mautrix/bridgev2/simplevent" "maunium.net/go/mautrix/bridgev2/status" "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" ) type DummyClient struct { @@ -28,13 +33,24 @@ type DummyClient struct { UserLogin *bridgev2.UserLogin Connector *DummyConnector + + approvalMu sync.Mutex + approvalSelections map[string]string } var _ bridgev2.NetworkAPI = (*DummyClient)(nil) var _ bridgev2.IdentifierResolvingNetworkAPI = (*DummyClient)(nil) +var _ bridgev2.ContactListingNetworkAPI = (*DummyClient)(nil) var _ bridgev2.BackfillingNetworkAPI = (*DummyClient)(nil) var _ bridgev2.DeleteChatHandlingNetworkAPI = (*DummyClient)(nil) var _ bridgev2.MessageRequestAcceptingNetworkAPI = (*DummyClient)(nil) +var _ bridgev2.ReactionHandlingNetworkAPI = (*DummyClient)(nil) + +const ( + aiGhostID networkid.UserID = "ai" + aiGhostName = "AI" + aiPortalIDPrefix = "ai-" +) var delayedRemoteEchoPattern = regexp.MustCompile(`(?i)^remote-echo\s+delay\s+([0-9]+(?:ms|s|m|h))$`) @@ -125,6 +141,14 @@ func (dc *DummyClient) IsThisUser(ctx context.Context, userID networkid.UserID) } func (dc *DummyClient) GetChatInfo(ctx context.Context, portal *bridgev2.Portal) (*bridgev2.ChatInfo, error) { + if isAIPortalID(portal.ID) { + roomType := database.RoomTypeDM + return &bridgev2.ChatInfo{ + Name: ptr.Ptr(aiGhostName), + Type: ptr.Ptr(roomType), + }, nil + } + portalIDPrefix := string(portal.ID) if len(portalIDPrefix) > 6 { portalIDPrefix = portalIDPrefix[:6] @@ -151,6 +175,17 @@ func (dc *DummyClient) GetChatInfo(ctx context.Context, portal *bridgev2.Portal) } func (tc *DummyClient) GetUserInfo(ctx context.Context, ghost *bridgev2.Ghost) (*bridgev2.UserInfo, error) { + if ghost.ID == aiGhostID { + name := aiGhostName + isBot := true + ghost.UpdateName(ctx, name) + return &bridgev2.UserInfo{ + Identifiers: []string{string(aiGhostID), "AI"}, + Name: &name, + IsBot: &isBot, + }, nil + } + name := ghost.Name if name == "" { name = string(ghost.ID) @@ -198,16 +233,153 @@ func (dc *DummyClient) HandleMatrixMessage(ctx context.Context, msg *bridgev2.Ma messageID = networkid.MessageID(msg.Event.Unsigned.TransactionID) } - return &bridgev2.MatrixMessageResponse{ + resp := &bridgev2.MatrixMessageResponse{ DB: &database.Message{ ID: messageID, SenderID: networkid.UserID(dc.UserLogin.ID), Timestamp: timestamp, }, StreamOrder: time.Now().UnixNano(), + } + + if msg.Portal != nil && isAIPortalID(msg.Portal.ID) { + dc.queueAIResponse(ctx, msg.Portal, msg.Content) + } + + return resp, nil +} + +func (dc *DummyClient) PreHandleMatrixReaction(_ context.Context, msg *bridgev2.MatrixReaction) (bridgev2.MatrixReactionPreResponse, error) { + if msg == nil || msg.Content == nil { + return bridgev2.MatrixReactionPreResponse{}, nil + } + senderID := networkid.UserID("") + if dc != nil && dc.UserLogin != nil { + senderID = networkid.UserID(dc.UserLogin.ID) + } + key := normalizeApprovalReaction(msg.Content.RelatesTo.Key) + return bridgev2.MatrixReactionPreResponse{ + SenderID: senderID, + EmojiID: networkid.EmojiID(key), + Emoji: key, + MaxReactions: 1, }, nil } +func (dc *DummyClient) HandleMatrixReaction(ctx context.Context, msg *bridgev2.MatrixReaction) (*database.Reaction, error) { + if dc == nil || dc.UserLogin == nil || msg == nil || msg.TargetMessage == nil || msg.Content == nil || msg.Portal == nil { + return &database.Reaction{}, nil + } + approvalID := string(msg.TargetMessage.ID) + if !strings.HasPrefix(approvalID, "approval-") { + return &database.Reaction{}, nil + } + reaction := normalizeApprovalReaction(msg.Content.RelatesTo.Key) + selected, ok := aistream.ResolveReaction(aistream.DefaultApprovalOptions(approvalID), reaction) + if !ok { + return &database.Reaction{}, nil + } + + selectedKey, firstResolution := dc.resolveApprovalOnce(approvalID, reaction) + dc.cleanupApprovalReactions(ctx, msg.Portal, networkid.MessageID(approvalID), selectedKey, msg) + if !firstResolution { + log.Info(). + Str("approval_id", approvalID). + Str("reaction", reaction). + Str("selected_reaction", selectedKey). + Msg("Ignoring duplicate dummy AI approval reaction") + return &database.Reaction{}, nil + } + dc.queueAIApprovalResponse(ctx, msg.Portal, msg.TargetMessage, selected.Value) + + log.Info(). + Str("approval_id", approvalID). + Str("reaction", reaction). + Bool("approved", selected.Value.Approved). + Stringer("sender", msg.Event.Sender). + Msg("Resolved dummy AI approval from Matrix reaction") + + return &database.Reaction{}, nil +} + +func (dc *DummyClient) resolveApprovalOnce(approvalID, selectedKey string) (string, bool) { + dc.approvalMu.Lock() + defer dc.approvalMu.Unlock() + if dc.approvalSelections == nil { + dc.approvalSelections = make(map[string]string) + } + if existing := dc.approvalSelections[approvalID]; existing != "" { + return existing, false + } + dc.approvalSelections[approvalID] = selectedKey + return selectedKey, true +} + +func (dc *DummyClient) cleanupApprovalReactions(ctx context.Context, portal *bridgev2.Portal, approvalMessageID networkid.MessageID, selectedKey string, msg *bridgev2.MatrixReaction) { + if dc == nil || dc.UserLogin == nil || dc.UserLogin.Bridge == nil || dc.UserLogin.Bridge.DB == nil || portal == nil { + return + } + reactions, err := dc.UserLogin.Bridge.DB.Reaction.GetAllToMessage(ctx, portal.Receiver, approvalMessageID) + if err != nil { + log.Warn().Err(err).Str("approval_id", string(approvalMessageID)).Msg("Failed to load approval reactions") + return + } + events := make([]aistream.ReactionEvent, 0, len(reactions)+1) + reactionByMXID := make(map[string]*database.Reaction, len(reactions)) + for _, reaction := range reactions { + if reaction == nil || reaction.MXID == "" { + continue + } + eventID := string(reaction.MXID) + reactionByMXID[eventID] = reaction + events = append(events, aistream.ReactionEvent{ + EventID: eventID, + Sender: string(reaction.SenderID), + Key: reaction.Emoji, + Bridge: reaction.SenderID == aiGhostID, + }) + } + if msg != nil && msg.Event != nil && msg.Event.ID != "" { + events = append(events, aistream.ReactionEvent{ + EventID: string(msg.Event.ID), + Sender: string(msg.Event.Sender), + Key: selectedKey, + }) + } + cleanup := aistream.CleanupReactions(aistream.DefaultApprovalOptions(string(approvalMessageID)), selectedKey, events, string(aiGhostID)) + intent, ok := portal.GetIntentFor(ctx, bridgev2.EventSender{Sender: aiGhostID}, dc.UserLogin, bridgev2.RemoteEventMessageRemove) + if !ok || intent == nil { + log.Warn().Str("approval_id", string(approvalMessageID)).Msg("Failed to resolve AI sender intent for approval reaction cleanup") + return + } + for _, reactionEventID := range cleanup.RedactReactionEvents { + reactionMXID := id.EventID(reactionEventID) + _, err := intent.SendMessage(ctx, portal.MXID, event.EventRedaction, &event.Content{ + Parsed: &event.RedactionEventContent{Redacts: reactionMXID}, + }, nil) + if err != nil { + log.Warn().Err(err).Stringer("reaction_mxid", reactionMXID).Msg("Failed to redact approval reaction") + continue + } + if reaction := reactionByMXID[reactionEventID]; reaction != nil { + if err := dc.UserLogin.Bridge.DB.Reaction.Delete(ctx, reaction); err != nil { + log.Warn().Err(err).Stringer("reaction_mxid", reaction.MXID).Msg("Failed to delete approval reaction") + } + } + } +} + +func (dc *DummyClient) HandleMatrixReactionRemove(ctx context.Context, msg *bridgev2.MatrixReactionRemove) error { + if dc != nil && dc.UserLogin != nil && dc.UserLogin.Bridge != nil && msg != nil && msg.TargetReaction != nil { + _ = dc.UserLogin.Bridge.DB.Reaction.Delete(ctx, msg.TargetReaction) + } + return nil +} + +func normalizeApprovalReaction(reaction string) string { + return strings.TrimSpace(strings.ReplaceAll(reaction, "\ufe0f", "")) +} + func getTransactionID(msg *bridgev2.MatrixMessage) networkid.TransactionID { if msg.Event != nil && msg.Event.Unsigned.TransactionID != "" { return networkid.TransactionID(msg.Event.Unsigned.TransactionID) @@ -290,6 +462,277 @@ func cloneMessageContent(content *event.MessageEventContent) *event.MessageEvent return &cloned } +func (dc *DummyClient) queueAIResponse(ctx context.Context, portal *bridgev2.Portal, inbound *event.MessageEventContent) { + if portal == nil { + return + } + + now := time.Now() + runID := "run-" + string(randomMessageID()) + plans, err := buildAIRunPlans(ctx, runID, string(portal.ID), inboundBody(inbound), now) + if err != nil { + log.Warn().Err(err).Msg("Failed to build AI runs") + return + } + for _, plan := range plans { + if plan.Run == nil { + continue + } + timestamp := now.Add(plan.Delay) + placeholderID := networkid.MessageID(plan.Run.MessageID) + dc.UserLogin.QueueRemoteEvent(aibridgev2.Anchor(portal.PortalKey, aiGhostID, initialAIAnchorRun(*plan.Run), timestamp)) + + go dc.queueAIRunStreamAndMetadata(portal, placeholderID, *plan.Run) + } +} + +func initialAIAnchorRun(run aistream.Run) aistream.Run { + run.Status = aistream.Status{State: "streaming"} + run.Usage = agui.Usage{} + return run +} + +func (dc *DummyClient) queueAIRunStreamAndMetadata(portal *bridgev2.Portal, messageID networkid.MessageID, run aistream.Run) { + targetEventID := dc.waitForMessageMXID(portal, messageID, 30*time.Second) + if targetEventID == "" { + log.Warn(). + Str("run_id", run.RunID). + Str("message_id", string(messageID)). + Msg("Timed out waiting for AI anchor Matrix event") + return + } + carriers, err := dc.queueAICarriers(portal, targetEventID, run, 1) + if err != nil { + log.Warn().Err(err).Str("run_id", run.RunID).Msg("Failed to pack AI stream") + return + } + nextSeq := aistream.NextSeq(carriers) + for i, prompt := range run.Prompts { + prompt.SeqStart = nextSeq + i*10 + dc.queueAIApprovalPrompt(portal, run, prompt, targetEventID, time.Now()) + } + dc.queueAIRunFinalMetadata(portal, messageID, run) +} + +func (dc *DummyClient) queueAICarriers(portal *bridgev2.Portal, targetEventID id.EventID, run aistream.Run, startSeq int) ([]aistream.Carrier, error) { + carriers, err := aistream.PackRunFromSeq(run, string(targetEventID), aistream.CarrierBudgetBytes, startSeq) + if err != nil { + return nil, err + } + for i, carrier := range carriers { + now := time.Now() + dc.UserLogin.QueueRemoteEvent(aibridgev2.Carrier(portal.PortalKey, aiGhostID, run, carrier, targetEventID, startSeq+i, now)) + } + return carriers, nil +} + +func (dc *DummyClient) waitForMessageMXID( + portal *bridgev2.Portal, + messageID networkid.MessageID, + timeout time.Duration, +) id.EventID { + if dc == nil || dc.UserLogin == nil || dc.UserLogin.Bridge == nil || dc.UserLogin.Bridge.DB == nil || portal == nil { + return "" + } + parent := dc.ctx + if parent == nil { + parent = context.Background() + } + ctx, cancel := context.WithTimeout(parent, timeout) + defer cancel() + + receivers := []networkid.UserLoginID{portal.Receiver} + if dc.UserLogin.ID != "" && dc.UserLogin.ID != portal.Receiver { + receivers = append(receivers, dc.UserLogin.ID) + } + ticker := time.NewTicker(50 * time.Millisecond) + defer ticker.Stop() + + for ctx.Err() == nil { + for _, receiver := range receivers { + mxid := dc.lookupMessageMXID(ctx, receiver, messageID) + if mxid != "" { + return mxid + } + } + select { + case <-ctx.Done(): + return "" + case <-ticker.C: + } + } + return "" +} + +func (dc *DummyClient) lookupMessageMXID(ctx context.Context, receiver networkid.UserLoginID, messageID networkid.MessageID) id.EventID { + var mxid id.EventID + err := dc.UserLogin.Bridge.DB.Message.GetDB().QueryRow( + ctx, + `SELECT mxid FROM message WHERE bridge_id=$1 AND (room_receiver=$2 OR room_receiver='') AND id=$3 ORDER BY part_id ASC LIMIT 1`, + dc.UserLogin.Bridge.DB.Message.BridgeID, + receiver, + messageID, + ).Scan(&mxid) + if err != nil { + return "" + } + return mxid +} + +func (dc *DummyClient) queueAIApprovalPrompt(portal *bridgev2.Portal, run aistream.Run, prompt aistream.ApprovalPrompt, targetEventID id.EventID, timestamp time.Time) { + reactions := aistream.DefaultApprovalOptions(prompt.ID) + approvalCtx := aistream.ApprovalContext{ + ID: prompt.ID, + ThreadID: run.ThreadID, + RunID: run.RunID, + MessageID: run.MessageID, + ToolCallID: prompt.ToolCallID, + ToolName: prompt.ToolName, + TargetEvent: string(targetEventID), + AgentID: run.AgentID, + AgentName: run.AgentName, + Model: run.Model, + SeqStart: prompt.SeqStart, + } + dc.UserLogin.QueueRemoteEvent(aibridgev2.ApprovalPrompt(portal.PortalKey, aiGhostID, approvalCtx, timestamp)) + + for i, reaction := range reactions { + reaction := reaction + dc.UserLogin.QueueRemoteEvent(aibridgev2.ApprovalOptionReaction(portal.PortalKey, aiGhostID, approvalCtx, reaction, timestamp.Add(time.Duration(i+1)*time.Millisecond))) + } +} + +func (dc *DummyClient) queueAIApprovalResponse(ctx context.Context, portal *bridgev2.Portal, approvalMessage *database.Message, response agui.ToolApprovalResponse) { + approvalCtx, ok := dc.approvalContextForMessage(ctx, portal, approvalMessage) + if !ok { + log.Warn().Str("approval_id", messageIDString(approvalMessage)).Msg("Missing AI approval metadata") + return + } + if response.ID == "" { + response.ID = approvalCtx.ID + } + now := time.Now() + run := aistream.ApprovalResponseRun(approvalCtx, response, now) + targetEventID := id.EventID(approvalCtx.TargetEvent) + if targetEventID == "" { + log.Warn().Str("approval_id", approvalCtx.ID).Msg("Missing AI approval target event") + return + } + if _, err := dc.queueAICarriers(portal, targetEventID, run, approvalCtx.SeqStart); err != nil { + log.Warn().Err(err).Str("approval_id", approvalCtx.ID).Msg("Failed to queue AI approval response") + } +} + +func (dc *DummyClient) approvalContextForMessage(ctx context.Context, portal *bridgev2.Portal, message *database.Message) (aistream.ApprovalContext, bool) { + var fetch func(context.Context, networkid.MessageID) (*database.Message, error) + if dc != nil && dc.UserLogin != nil && dc.UserLogin.Bridge != nil && dc.UserLogin.Bridge.DB != nil && portal != nil { + fetch = func(ctx context.Context, messageID networkid.MessageID) (*database.Message, error) { + return dc.UserLogin.Bridge.DB.Message.GetFirstPartByID(ctx, portal.Receiver, messageID) + } + } + return approvalContextForMessage(ctx, message, fetch) +} + +func approvalContextForMessage(ctx context.Context, message *database.Message, fetch func(context.Context, networkid.MessageID) (*database.Message, error)) (aistream.ApprovalContext, bool) { + if approvalCtx, ok := approvalContextFromMetadata(message); ok { + return approvalCtx, true + } + if message == nil || message.ID == "" || fetch == nil { + return aistream.ApprovalContext{}, false + } + fetched, err := fetch(ctx, message.ID) + if err != nil { + log.Warn().Err(err).Str("approval_id", string(message.ID)).Msg("Failed to reload AI approval message") + return aistream.ApprovalContext{}, false + } + return approvalContextFromMetadata(fetched) +} + +func approvalContextFromMetadata(message *database.Message) (aistream.ApprovalContext, bool) { + if message == nil { + return aistream.ApprovalContext{}, false + } + return approvalContextFromAny(message.Metadata) +} + +func approvalContextFromAny(value any) (aistream.ApprovalContext, bool) { + switch typed := value.(type) { + case aistream.ApprovalContext: + return validApprovalContext(typed) + case *aistream.ApprovalContext: + if typed == nil { + return aistream.ApprovalContext{}, false + } + return validApprovalContext(*typed) + case map[string]any: + if nested, ok := typed["com.beeper.ai.approval"]; ok { + return approvalContextFromAny(nested) + } + case *map[string]any: + if typed == nil { + return aistream.ApprovalContext{}, false + } + return approvalContextFromAny(*typed) + case json.RawMessage: + return approvalContextFromJSON(typed) + case []byte: + return approvalContextFromJSON(typed) + case string: + return approvalContextFromJSON([]byte(typed)) + } + var ctx aistream.ApprovalContext + raw, err := json.Marshal(value) + if err != nil { + return aistream.ApprovalContext{}, false + } + if err = json.Unmarshal(raw, &ctx); err != nil { + return aistream.ApprovalContext{}, false + } + return validApprovalContext(ctx) +} + +func approvalContextFromJSON(raw []byte) (aistream.ApprovalContext, bool) { + var decoded any + if err := json.Unmarshal(raw, &decoded); err == nil { + if approvalCtx, ok := approvalContextFromAny(decoded); ok { + return approvalCtx, true + } + } + var ctx aistream.ApprovalContext + if err := json.Unmarshal(raw, &ctx); err != nil { + return aistream.ApprovalContext{}, false + } + return validApprovalContext(ctx) +} + +func messageIDString(message *database.Message) string { + if message == nil { + return "" + } + return string(message.ID) +} + +func validApprovalContext(ctx aistream.ApprovalContext) (aistream.ApprovalContext, bool) { + if ctx.ID == "" || ctx.ThreadID == "" || ctx.RunID == "" || ctx.MessageID == "" || ctx.ToolCallID == "" || ctx.TargetEvent == "" { + return aistream.ApprovalContext{}, false + } + if ctx.SeqStart <= 0 { + ctx.SeqStart = 1 + } + return ctx, true +} + +func (dc *DummyClient) queueAIRunFinalMetadata(portal *bridgev2.Portal, messageID networkid.MessageID, run aistream.Run) { + dc.UserLogin.QueueRemoteEvent(aibridgev2.FinalMetadataEdit(portal.PortalKey, aiGhostID, messageID, run, time.Now())) +} + +func inboundBody(content *event.MessageEventContent) string { + if content == nil { + return "" + } + return content.Body +} + func (dc *DummyClient) HandleMatrixDeleteChat(ctx context.Context, msg *bridgev2.MatrixDeleteChat) error { // bridgev2 will delete the portal + Matrix room after this returns nil. // For dummybridge, there's no separate remote-side deletion to do. @@ -309,6 +752,10 @@ func (dc *DummyClient) HandleMatrixAcceptMessageRequest(ctx context.Context, msg } func (dc *DummyClient) ResolveIdentifier(ctx context.Context, identifier string, createChat bool) (*bridgev2.ResolveIdentifierResponse, error) { + if isAIIdentifier(identifier) { + return dc.resolveAIIdentifier(ctx, createChat) + } + userID := networkid.UserID(identifier) portalID := randomPortalID() portalKey := networkid.PortalKey{ @@ -357,3 +804,81 @@ func (dc *DummyClient) ResolveIdentifier(ctx context.Context, identifier string, }, nil } + +func (dc *DummyClient) GetContactList(ctx context.Context) ([]*bridgev2.ResolveIdentifierResponse, error) { + contact, err := dc.resolveAIIdentifier(ctx, false) + if err != nil { + return nil, err + } + return []*bridgev2.ResolveIdentifierResponse{contact}, nil +} + +func (dc *DummyClient) resolveAIIdentifier(ctx context.Context, createChat bool) (*bridgev2.ResolveIdentifierResponse, error) { + ghost, err := dc.UserLogin.Bridge.GetGhostByID(ctx, aiGhostID) + if err != nil { + return nil, fmt.Errorf("failed to get AI ghost: %w", err) + } + userInfo, _ := dc.GetUserInfo(ctx, ghost) + response := &bridgev2.ResolveIdentifierResponse{ + Ghost: ghost, + UserID: aiGhostID, + UserInfo: userInfo, + } + if !createChat { + return response, nil + } + + portalID := networkid.PortalID(aiPortalIDPrefix + string(randomPortalID())) + portalKey := networkid.PortalKey{ID: portalID, Receiver: dc.UserLogin.ID} + portal, err := dc.UserLogin.Bridge.GetPortalByKey(ctx, portalKey) + if err != nil { + return nil, fmt.Errorf("failed to get AI portal: %w", err) + } + roomType := database.RoomTypeDM + response.Chat = &bridgev2.CreateChatResponse{ + Portal: portal, + PortalKey: portalKey, + PortalInfo: &bridgev2.ChatInfo{ + Name: ptr.Ptr(aiGhostName), + Topic: ptr.Ptr("DummyBridge AI chat"), + Type: ptr.Ptr(roomType), + CanBackfill: true, + Members: &bridgev2.ChatMemberList{ + Members: []bridgev2.ChatMember{ + { + EventSender: bridgev2.EventSender{ + IsFromMe: true, + Sender: networkid.UserID(dc.UserLogin.ID), + }, + Membership: event.MembershipJoin, + PowerLevel: ptr.Ptr(100), + }, + { + EventSender: bridgev2.EventSender{ + Sender: aiGhostID, + }, + Membership: event.MembershipJoin, + PowerLevel: ptr.Ptr(50), + MemberEventExtra: map[string]any{ + "displayname": aiGhostName, + "com.beeper.ai.agent": string(aiGhostID), + "com.beeper.ai.model_id": aistream.DefaultModel, + "com.beeper.ai.protocol": "ag-ui", + "com.beeper.ai.static_ai": true, + }, + }, + }, + }, + }, + } + return response, nil +} + +func isAIIdentifier(identifier string) bool { + identifier = strings.TrimSpace(identifier) + return strings.EqualFold(identifier, string(aiGhostID)) || strings.EqualFold(identifier, aiGhostName) +} + +func isAIPortalID(portalID networkid.PortalID) bool { + return strings.HasPrefix(string(portalID), aiPortalIDPrefix) +} diff --git a/pkg/connector/client_test.go b/pkg/connector/client_test.go index ed2a445..da31658 100644 --- a/pkg/connector/client_test.go +++ b/pkg/connector/client_test.go @@ -1,9 +1,15 @@ package connector import ( + "context" + "encoding/json" "testing" "time" + "github.com/beeper/dummybridge/pkg/ag-ui" + "github.com/beeper/dummybridge/pkg/ai-stream" + "maunium.net/go/mautrix/bridgev2/database" + "maunium.net/go/mautrix/bridgev2/networkid" "maunium.net/go/mautrix/event" ) @@ -37,3 +43,73 @@ func TestGetRemoteEchoBehavior(t *testing.T) { }) } } + +func TestResolveApprovalOnceKeepsFirstSelection(t *testing.T) { + client := &DummyClient{} + selected, first := client.resolveApprovalOnce("approval-1", "allow") + if !first || selected != "allow" { + t.Fatalf("first selection = %q first=%v", selected, first) + } + selected, first = client.resolveApprovalOnce("approval-1", "deny") + if first || selected != "allow" { + t.Fatalf("second selection = %q first=%v", selected, first) + } +} + +func TestInitialAIAnchorRunKeepsPreviewButNotTerminalMetadata(t *testing.T) { + run := aistream.NewRun("run-1", "thread-1", aistream.DefaultModel, "ai", "AI", time.Unix(10, 0)) + writer := aistream.NewWriter(run, func() time.Time { return time.Unix(10, 0) }) + writer.Start() + writer.Text("visible preview") + writer.Finish(agui.FinishReasonStop) + + anchor := initialAIAnchorRun(*run) + if anchor.Preview.Text == "" { + t.Fatal("expected anchor to keep useful preview text") + } + if anchor.Status.State != "streaming" { + t.Fatalf("anchor status = %#v, want streaming", anchor.Status) + } + if anchor.Usage.TotalTokens != 0 || anchor.Usage.CompletionTokens != 0 || anchor.Usage.PromptTokens != 0 { + t.Fatalf("anchor leaked terminal usage: %#v", anchor.Usage) + } + if run.Status.State != "complete" || run.Usage.TotalTokens == 0 { + t.Fatalf("final run should keep terminal metadata: status=%#v usage=%#v", run.Status, run.Usage) + } +} + +func TestApprovalContextForMessageFallsBackToStoredMessage(t *testing.T) { + want := aistream.ApprovalContext{ + ID: "approval-1", + ThreadID: "thread-1", + RunID: "run-1", + MessageID: "msg-1", + ToolCallID: "tool-1", + TargetEvent: "$event", + SeqStart: 12, + } + stub := &database.Message{ID: "approval-1"} + rawMetadata, err := json.Marshal(map[string]any{"com.beeper.ai.approval": want}) + if err != nil { + t.Fatal(err) + } + fetched := &database.Message{ID: "approval-1", Metadata: rawMetadata} + called := false + + got, ok := approvalContextForMessage(context.Background(), stub, func(_ context.Context, messageID networkid.MessageID) (*database.Message, error) { + called = true + if messageID != stub.ID { + t.Fatalf("fetch message ID = %q, want %q", messageID, stub.ID) + } + return fetched, nil + }) + if !ok { + t.Fatal("expected approval context") + } + if !called { + t.Fatal("expected fallback fetch") + } + if got.ID != want.ID || got.RunID != want.RunID || got.TargetEvent != want.TargetEvent || got.SeqStart != want.SeqStart { + t.Fatalf("approval context = %#v, want %#v", got, want) + } +} diff --git a/pkg/connector/connector.go b/pkg/connector/connector.go index 91ea6b1..cfad894 100644 --- a/pkg/connector/connector.go +++ b/pkg/connector/connector.go @@ -44,7 +44,15 @@ func (dc *DummyConnector) Start(ctx context.Context) error { } func (dc *DummyConnector) GetCapabilities() *bridgev2.NetworkGeneralCapabilities { - return &bridgev2.NetworkGeneralCapabilities{} + return &bridgev2.NetworkGeneralCapabilities{ + Provisioning: bridgev2.ProvisioningCapabilities{ + ResolveIdentifier: bridgev2.ResolveIdentifierCapabilities{ + CreateDM: true, + LookupUsername: true, + ContactList: true, + }, + }, + } } func (dc *DummyConnector) GetBridgeInfoVersion() (info, caps int) { @@ -62,7 +70,10 @@ func (dc *DummyConnector) GetName() bridgev2.BridgeName { } func (dc *DummyConnector) GetDBMetaTypes() database.MetaTypes { - return database.MetaTypes{} + return database.MetaTypes{ + Message: func() any { return &map[string]any{} }, + Reaction: func() any { return &map[string]any{} }, + } } //go:embed example-config.yaml From 0008955d6fa0d61699730585dddd9e6c558838c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Tue, 19 May 2026 20:00:37 +0200 Subject: [PATCH 02/46] wip --- TANSTACK_AG_UI_PARITY_PLAN.md | 267 +++++++++++++++++++++++++------- pkg/ai-stream/approval.go | 25 +-- pkg/ai-stream/matrix/content.go | 6 +- pkg/ai-stream/pack.go | 149 ++++++++++++++---- pkg/ai-stream/run.go | 5 +- pkg/ai-stream/stream_test.go | 118 +++++++++++++- pkg/connector/client.go | 30 ++-- 7 files changed, 483 insertions(+), 117 deletions(-) diff --git a/TANSTACK_AG_UI_PARITY_PLAN.md b/TANSTACK_AG_UI_PARITY_PLAN.md index 6f600f5..ca3562c 100644 --- a/TANSTACK_AG_UI_PARITY_PLAN.md +++ b/TANSTACK_AG_UI_PARITY_PLAN.md @@ -8,20 +8,102 @@ Build dummybridge AI around current TanStack AG-UI primitives in both directions - Desktop consumes multi-event encrypted AG-UI streams and hides carrier events from the normal timeline. - Shared Go packages define the primitive contract instead of preserving old AI SDK or agentremote decisions. -Do not install dependencies or modify lockfiles unless the user explicitly approves that dependency change. `@tanstack/ai-react-ui` is approved for the Desktop rendering work in this plan. - -## Current State - -The dummybridge repo currently has a provisional `pkg/aichats` package and AI handling in `pkg/connector/client.go`. - -Known current behavior: +Do not install dependencies, update dependencies, or modify lockfiles unless the user explicitly approves that exact dependency change. Before adding `@tanstack/ai-react-ui` or changing the Desktop lockfile, verify whether it is already present in the current Desktop checkout; if it is not, ask first. + +## Plan Of Record + +This is the intended behavior to preserve while finishing the implementation: + +1. The normal AG-UI stream is an ordered delta log. Envelopes have `seq`, never `seqTotal`, carrier totals, or final event counts. +2. Every run has one visible Matrix anchor message. Normal stream carriers and finalization carriers are hidden transport events that merge into that anchor. +3. The final supported-client state is complete AG-UI state: text, thinking, tool calls/results, approval state, sources/files/data/state, terminal status, usage, model, and run metadata. +4. Finalization sends hidden carriers first, then the compact final Matrix edit last. The final edit stops streaming and carries Matrix-native preview HTML; it does not need to carry the full generated text or full parts array. +5. Over-budget final state uses a base `MESSAGES_SNAPSHOT` plus `CUSTOM name="com.beeper.ai.final-parts"` continuations. Continuations contain only relation data and omitted parts, not repeated full metadata. +6. PAS/Desktop must process hidden stream/finalization carriers before treating the final edit as the point where streaming stops. From the renderer's perspective, there is always one final AI message. +7. Approval prompts remain separate visible Matrix messages for actionability, while semantic approval state is also represented in AG-UI so supported clients can show it inline as a tool-call state. +8. Unsupported clients are not the primary target, but the final Matrix edit must still be a coherent bounded Matrix HTML preview for timeline, search, and notifications. + +Do not reintroduce these rejected approaches: + +- No `seqTotal`, carrier totals, or final event totals on normal streaming envelopes. +- No visible carrier bubbles as a fallback for unsupported or failed merge behavior. +- No final full-text Matrix edit for long runs. The final edit is a bounded preview plus compact metadata. +- No broad `as any` or whole-object assertions at the Desktop TanStack render boundary. +- No duplicate Beeper-only UI message model where TanStack types already describe the part contract. +- No package-manager install/update/lockfile change without explicit approval. + +## Implementation Status + +As of this checkout, the plan should be read as a completion/audit checklist rather than a blank design doc. + +Done in dummybridge: + +- New `pkg/ag-ui`, `pkg/ai-stream`, `pkg/ai-stream/matrix`, and `pkg/ai-stream/bridgev2` packages exist. +- The old provisional `pkg/aichats` package has been removed. +- Normal stream envelopes use ordered `seq` and do not carry `seqTotal`. +- `MESSAGES_SNAPSHOT` finalization can split into a metadata-preserving base snapshot plus `com.beeper.ai.final-parts` continuations. +- Large text/thinking final parts split at UTF-8 boundaries and are reassembled by the supported client model. +- Carrier replay for built runs is contiguous; synthetic timestamps no longer add random delays between already-built carrier sends. +- Final anchor edits use mautrix Markdown rendering for Matrix HTML preview content. +- Approval response carriers are queued before the final metadata edit for the anchor. + +Done in the related Desktop checkout: + +- Carrier-only encrypted events with non-empty `*.deltas` are routed as hidden stream updates instead of normal timeline upserts. +- `com.beeper.ai.final-parts` continuations merge into the existing TanStack-shaped UI message by `messageId`, `runId`, and `partOffset`. +- The AI renderer path uses a typed high-level adapter at the TanStack render boundary instead of asserting the whole message as `any`. +- `src/renderer/ai/ui-message.ts` uses typed builders/guards for TanStack and Beeper custom parts instead of broad `MutableUIPart`/record-level assertions. + +Completion status: + +- The non-visual plan gates are implemented and verified by the evidence below. +- Full Desktop typecheck is still red, but the failures are outside the touched AI/PAS files and are listed below. +- Visual testing remains explicitly excluded from the current completion target. + +Completion gates: + +- Unit/focused tests pass for dummybridge and touched Desktop AI/PAS paths. +- Full Desktop typecheck either passes or every failure is documented as unrelated to touched AI/PAS files. +- Live staging smoke proves over-64KB output produces one visible AI anchor and hidden carriers only. +- Live staging smoke proves approve and deny both finalize through hidden response carriers before the final anchor edit. +- Replay/backfill, redaction/delete, and missing-gap behavior have either automated tests or an explicit live/manual verification note. +- `rg` source scan proves runtime source does not emit `seqTotal`. + +Current verification snapshot, 2026-05-19: + +- `go test -mod=readonly ./...` in dummybridge passes. +- Desktop focused AI/PAS tests pass: + - `bun run test --run src/common/ai-common.test.ts src/renderer/ai/ui-message.test.ts src/renderer/ai/stream-ordering.test.ts src/renderer/stores/AIChatsStore.test.ts src/pas-server/beeper/EventSyncContext.test.ts src/pas-server/beeper/connect/ws-event-mapper.test.ts src/pas-server/beeper/connect/ws-events-server.test.ts` +- Full Desktop typecheck still fails in unrelated files outside the touched AI/PAS paths: + - `BrandLink.stories.tsx` + - `ComposeMessage/TextArea/TextArea.tsx` + - `DetachedAccountsOnboarding.tsx` + - `electron-ipc.ts` + - `measureInteractionNextPaint.ts` + - `QuickRepliesPrefsSubView.tsx` +- Runtime source scans find no total-count fields: + - `rg -n 'seqTotal|carrierTotal|finalEventTotal' pkg cmd --glob '!**/*_test.go'` + - `rg -n 'seqTotal|carrierTotal|finalEventTotal' src/common src/renderer/ai src/renderer/stores src/pas-server/beeper --glob '!**/*.test.ts' --glob '!**/*.test.tsx'` +- Desktop AI render/store source scan finds no broad assertions in the touched AI paths: + - `rg -n '\] as any|as MutableUIPart|MutableUIPart|ToolUIPartRecord|console\.log\(' src/renderer/ai src/renderer/stores/AIChatsStore.ts` +- Live staging over-64KB smoke passed after lowering the raw carrier budget to 40KB: one visible AI anchor, hidden carriers, no visible `com.beeper.ai.final-parts` leakage, and Matrix HTML preview on the final edit. +- Live staging approval approve and deny pass: prompt remains separate, selected user reaction remains, bridge option reactions are redacted, response carriers are queued, and final anchor edit preserves the existing preview instead of reverting to `...`. +- Live staging random/chaos smoke passes for hidden-carrier behavior: no carrier bubbles appeared in Desktop API output, approval prompts stayed separate, and final edits did not regress completed approval anchors. +- Replay/backfill has unit coverage through batched `updates` extraction and `AIChatsStore` replay into an existing anchor. +- Deleted/redacted carrier and missing-gap behavior have unit coverage: carrier deletion marks the anchor failed while keeping carrier events hidden, and unresolved sequence gaps now fail via timer without requiring another stream event. + +## Baseline And Existing Entry Points + +The original plan replaced a provisional `pkg/aichats` package plus AI handling in `pkg/connector/client.go`. In this checkout, `pkg/aichats` should stay deleted; the active implementation is the new `pkg/ag-ui` and `pkg/ai-stream` stack plus the connector integration points. + +Behavior baseline to preserve or improve: - AI DM resolution uses the `ai`/`AI` ghost and AI portals with the `ai-` prefix. -- The bridge sends one visible placeholder event, streams `com.beeper.llm.deltas`, then edits the placeholder with final content. -- Current deltas are AG-UI-like but incomplete. -- Current approval requests are separate Matrix events with `com.beeper.ai.approval` metadata and reaction options. -- Current approval reaction handling should keep the user's selected emoji and remove the bridge-posted placeholder/non-selected options, but this needs robust implementation and tests. -- Current text streaming has started moving away from full accumulated content on each delta, but the final design must enforce that. +- The bridge sends one visible anchor/placeholder event, streams `com.beeper.llm.deltas`, then edits the anchor with final compact metadata and Matrix preview content. +- Stream deltas must be real AG-UI events, not AG-UI-like compatibility shapes. +- Approval requests are separate Matrix events with `com.beeper.ai.approval` metadata and reaction options. +- Approval reaction handling keeps the user's selected emoji and removes bridge-posted placeholder/non-selected options. +- Text streaming must send incremental deltas and must not resend full accumulated text on every delta. Desktop already has partial AI stream support in these areas: @@ -67,7 +149,7 @@ Primary dummybridge checkout: - `/Users/batuhan/Projects/labs/dummybridge/config-agui.yaml` - `/Users/batuhan/Projects/labs/dummybridge/config-qa-agui.yaml` -Current dummybridge AI implementation to replace: +Legacy dummybridge AI implementation that should not be restored: - `/Users/batuhan/Projects/labs/dummybridge/pkg/aichats/agui.go` - `/Users/batuhan/Projects/labs/dummybridge/pkg/aichats/matrix.go` @@ -77,7 +159,7 @@ Current dummybridge AI implementation to replace: - `/Users/batuhan/Projects/labs/dummybridge/pkg/connector/login.go` - `/Users/batuhan/Projects/labs/dummybridge/pkg/connector/example-config.yaml` -New dummybridge package targets: +Active dummybridge package targets: - `/Users/batuhan/Projects/labs/dummybridge/pkg/ag-ui` - `/Users/batuhan/Projects/labs/dummybridge/pkg/ai-stream` @@ -334,6 +416,7 @@ Envelope shape: Ordering and merge key: - `seq` is strictly increasing per `{target_event, runId}`. +- Do not put total counts such as `seqTotal`, carrier count, or final event count on normal stream envelopes. The streaming layer is an ordered event stream, not a pre-counted file transfer. - If `target_event` is unavailable during early processing, temporarily key by `{threadId, runId}` and promote to `{target_event, runId}` when the anchor message is known. - Desktop buffers out-of-order deltas within existing ordering limits. - Duplicate or stale `seq` values are ignored or rejected consistently. @@ -341,17 +424,17 @@ Ordering and merge key: Size budget: - Treat 64KB as the external ceiling. -- Use a hard carrier budget of 58KB for serialized Matrix content to leave buffer for encryption overhead, wrappers, event metadata, and implementation variance. +- Use a hard carrier budget of 40KB for serialized Matrix content. Live staging E2EE sends showed that 58KB raw carrier content can become 66-79KB encrypted Matrix event content, so the budget must leave room for megolm/base64/wrapper overhead. - The packer must measure serialized JSON byte size before adding an envelope to a carrier. -- If a single text delta would exceed the 58KB carrier budget, split it at UTF-8 rune boundaries. -- If a non-text event cannot fit inside the 58KB budget, return a validation error rather than sending it. +- If a single text delta would exceed the carrier budget, split it at UTF-8 rune boundaries. +- If a non-text event cannot fit inside the carrier budget, return a validation error rather than sending it. - `rawEvent` must be optional, bounded, and safe to omit. If including `rawEvent` would push a carrier over budget, truncate it or drop it before packing rather than bloating the event. - Truncated raw provider data must be marked, e.g. `rawEventTruncated: true`, so debugging does not confuse partial raw data with complete provider payloads. Preview/body algorithm: - The first visible message is the canonical message for the run. -- Put as much useful early visible preview as practical into the first message while preserving required metadata and staying under the 58KB budget. +- Put as much useful early visible preview as practical into the first message while preserving required metadata and staying under the carrier budget. - All run-level metadata that should survive as the message identity, such as model, usage, thread/run/message IDs, terminal state, and approval summary, belongs on the first visible message or its compact final metadata. - Later carrier messages should be hidden and merged by compatible clients into the first visible message. - Later carrier bodies should be empty or minimal and put payload in `.deltas`. @@ -361,9 +444,61 @@ Preview/body algorithm: Finalization: - The run accumulator is only for finalization, preview generation, and tests. -- Finalization emits compact terminal metadata and a compact final UI state when needed. -- Do not require a final Matrix edit containing the full generated body for over-64KB runs. -- The client is responsible for merging the stream. +- Normal stream chunks remain unaware of final chunk totals. Completion is determined by ordered AG-UI terminal/finalization events plus the final edit ordering, not by `seqTotal`. +- Finalization must emit the complete final AG-UI UI state for supported clients, including text, thinking, tool calls, tool results, approval state, sources/files/data/state, terminal status, usage, and model/run metadata. +- Finalization state may be split across hidden carrier events to stay under the serialized Matrix carrier budget. +- The final Matrix edit is sent only after all normal stream carriers and finalization carriers have been queued. It marks the anchor finalized and carries compact metadata plus Matrix-native preview HTML, not the full parts array. +- Do not require a final Matrix edit containing the full generated body or full AG-UI parts for over-64KB runs. +- The client is responsible for merging the hidden stream/finalization carriers into the anchor message. + +Final snapshot splitting algorithm: + +- Build one final AG-UI `UIMessage` in render order. +- Compact adjacent same-kind text fragments before packing final state when doing so does not lose detail. For example, five adjacent text-only chunks should become one final text part. +- Preserve semantic boundaries. Do not merge text across thinking, tool-call, tool-result, approval, source/file/data, or state parts. +- Start with a base `MESSAGES_SNAPSHOT` event containing the message identity and metadata: + - `id` + - `role` + - `metadata` + - `parts` +- The base event should include as many user-visible parts as fit under budget, in display order. Prioritize visible content over bulky diagnostics. +- If the next part would exceed budget, omit it from the base event and move it to a continuation event instead of duplicating metadata. +- Continuation events use a Beeper-owned AG-UI custom event: `CUSTOM` with `name: "com.beeper.ai.final-parts"`. +- Continuation event payload contains only relation/merge data and parts: + +```json +{ + "messageId": "message-id", + "runId": "run-id", + "threadId": "thread-id", + "partOffset": 3, + "parts": [] +} +``` + +- `partOffset` is the zero-based part index in the final message and is used for deterministic append/validation. Continuations must not repeat full message metadata. +- Desktop merges by applying the base snapshot, then inserting/appending continuation `parts` at `partOffset`. If the continuation part has the same semantic part identity as the part at that offset and only extends a split `content` field, concatenate the content instead of creating a second visible part. +- Split at the highest semantic level possible: carrier -> AG-UI event -> UIMessage parts -> large string fields. +- If a single text or thinking part is too large, split only its `content` at UTF-8 rune boundaries and use the same `partOffset` for the continuation slices so they concatenate back into one part. +- Do not split tool call, tool result, source/file/data, approval, or structured state objects unless there is an explicit field-level reassembly schema. Drop or truncate raw/debug/provider metadata before considering structured splitting. +- If one non-splittable structured part cannot fit under budget after raw/debug/provider metadata is removed, fail packing with a validation error instead of emitting an unmergeable partial object. +- Finalization carriers are sent before the final Matrix edit. The final edit must not race ahead of the final-parts carriers. + +Final Matrix preview: + +- Finalized messages must have Matrix-native preview content on the anchor edit: + - `body`: bounded plain text preview + - `format`: `org.matrix.custom.html` + - `formatted_body`: Matrix HTML generated by mautrix's Markdown renderer +- Unsupported clients are not a primary target, but the final edit should still be a coherent Matrix message preview for timeline/search/notifications. +- The full supported-client AI state comes from hidden carriers, not from the final edit body. + +Transport ordering: + +- For built dummybridge runs, send carrier events contiguously once the anchor Matrix event ID is known. Do not sleep between carriers based on synthetic generation timestamps. +- Demo/random delays may affect when runs are started or what timestamps are embedded in AG-UI events, but they must not delay replaying an already-built carrier sequence before finalization. +- Queue order for one run must be: anchor -> hidden normal carriers -> visible approval prompts/reaction options when applicable -> hidden approval response carriers when resolved -> hidden finalization carriers -> final anchor edit. +- If finalization carriers and final edit arrive in the same sync batch, PAS/Desktop must process carrier stream entries before using the edit to stop streaming. Replay/backfill: @@ -391,7 +526,7 @@ AG-UI state events: - `MESSAGES_SNAPSHOT` carries a complete AG-UI `UIMessage[]` snapshot. - Desktop must preserve and expose this state for AI rendering/devtools instead of dropping it. - State events are allowed to affect rendered state when the renderer intentionally consumes them. -- State events must still obey the 58KB carrier budget and multi-carrier splitting rules. +- State events must still obey the carrier budget and multi-carrier splitting rules. - Do not duplicate the normal streaming path: text should still prefer text events, tool calls should still prefer tool events, and state events should be used when AG-UI state synchronization is the right primitive. Run errors: @@ -471,7 +606,8 @@ Update Desktop as part of parity because the new transport deliberately splits o Dependency: -- Add `@tanstack/ai-react-ui` to the Desktop app and use it for AI message rendering. +- Use `@tanstack/ai-react-ui` from the current Desktop checkout when it is already present. +- If it is absent, ask before adding it, running a package-manager install/update, or changing any manifest/lockfile. - Do not hand-roll a parallel markdown renderer when TanStack's UI package already provides one. - `@tanstack/ai-react-ui` `TextPart` renders Markdown with `react-markdown`, GFM tables/strikethrough via `remark-gfm`, sanitized HTML via `rehype-sanitize`, and code highlighting via `rehype-highlight`. - Keep Beeper-specific shell/layout/actions in Desktop, but delegate TanStack text/thinking/tool/result part rendering to TanStack UI components or thin render props around them. @@ -532,6 +668,9 @@ UI message application: - In `src/renderer/ai/ui-message.ts`, apply AG-UI events into TanStack-shaped parts. - Preserve ordered parts instead of collapsing everything by type. - Render the resulting TanStack `UIMessage` with `@tanstack/ai-react-ui` instead of converting it into a separate Beeper-only part model. +- Type the state at the highest correct level. The renderer should accept a TanStack-shaped `UIMessage`/renderable message type and should not require whole-message `as any` assertions. +- Use narrow builders/guards for Beeper custom part variants instead of broad `MutableUIPart` assertions. If a part is not expressible as a TanStack part, keep the extension isolated behind a typed Beeper custom-part union and convert at the render boundary. +- Do not use assertions to bypass missing required fields. If TanStack requires a field, either populate it from AG-UI state or keep the part out of the TanStack render path until it has a real representation. - Support compatibility input for current events while preferring new output shapes: - text - thinking/step @@ -587,9 +726,11 @@ Rules: - AG-UI stream emits a tool-call state transition to `approval-requested`. - The tool-call part includes `approval: { id, needsApproval: true }`. -- Matrix reaction choices are transport metadata and must not be embedded into AG-UI events. +- Matrix reaction choices are transport metadata and must not be embedded into AG-UI events as the source of truth for reactions. - Approval prompt events should relate to the first visible anchor message and include `threadId`, `runId`, `messageId`, `toolCallId`, and approval ID. - Matrix approval event stores `com.beeper.ai.approval` with tool call ID, tool name, `threadId`, `runId`, `messageId`, expiration if any, and reaction options. +- Approval prompts are separate visible Matrix messages for actionability and reaction handling. +- Supported clients may render the same approval inline as a tool-call variant on the anchor message. To support that, duplicate semantic approval state into the AG-UI stream while keeping Matrix prompt/reaction metadata on the prompt event. - On user reaction, the bridge resolves the option to a `ToolApprovalResponse`. - After resolution, emit AG-UI state `approval-responded`. - If approved, continue execution and emit tool result `complete` or `error`. @@ -607,40 +748,28 @@ Custom events: - Beeper-specific custom events must use a clear namespace such as `com.beeper.*`. - Do not add random one-off custom names when an AG-UI lifecycle, tool, state, or message event already models the behavior. -## Open Decisions With Recommended Defaults - -These are the remaining decisions that affect product behavior or implementation shape. Use the recommended default unless the answer to the question changes the product intent. - -1. Source of truth for AG-UI schemas - - Recommended: `pkg/ag-ui` is the only Go source of truth for AG-UI concepts. Other packages import it instead of redefining parallel event, message, tool, or approval types. - - Decision: Desktop must use TanStack types directly for AG-UI/UI message concepts. Local Desktop types should only describe Beeper transport envelopes and app-specific metadata. +## Decisions And Remaining Behavior -2. Final persisted state for long runs - - Recommended: never require a final full-text edit. The first visible message stores compact identity/terminal metadata and Desktop reconstructs long content from carriers. - - Decision: compact final metadata should include everything needed to render the run except streamed parts/chunks. Do not store full parts/chunks in final metadata for large runs. +Settled decisions: -3. Metadata contract on the first visible message - - Decision: first message owns all non-part run metadata: IDs, model, usage, finish/terminal state, approval summary, source/file/data descriptors that are metadata, and any archived `aichats` metadata that is not the streamed UI parts/chunks themselves. Do not include dollar cost fields unless there is a separate product decision. +- `pkg/ag-ui` is the Go source of truth for AG-UI concepts. Other Go packages import it instead of redefining parallel event, message, tool, or approval types. +- Desktop uses TanStack types directly for AG-UI/UI message concepts wherever possible. Desktop-local types describe Beeper transport and persistence, not a second AI message model. +- Long runs never require a final full-text Matrix edit. The final edit stores compact identity/terminal metadata and Matrix HTML preview; supported clients reconstruct complete UI state from hidden carriers. +- Final AG-UI state is complete and may be split into hidden finalization carriers. The final anchor edit remains compact and must not embed the full parts/chunks array. +- The final split format is base `MESSAGES_SNAPSHOT` plus `com.beeper.ai.final-parts` continuations with relation data and omitted parts only. +- Normal stream chunks do not include `seqTotal` or any total-count field. +- First visible message metadata owns non-part run metadata: IDs, model, usage, finish/terminal state, approval summary, and source/file/data descriptors that are metadata. It does not store streamed text chunks, thinking chunks, tool args, tool results, or full parts. +- Use mautrix in `pkg/ai-stream/matrix`. Keep bridgev2-specific queue/database/redaction behavior outside the pure AG-UI package. -4. Dropped or invalid carriers - - Recommended: Desktop marks the first visible AI message failed and keeps carriers hidden. Do not expose carrier messages as fallback bubbles. - - Direct question: Should Desktop show a recoverable "stream incomplete" state, or a hard failed generation state? +Behavior still requiring implementation or verification: -5. Approval idempotency - - Recommended: first valid approval resolution wins. Later Matrix reactions or programmatic responses are ignored, do not re-run the tool, and may be cleaned up as stale choices. - - Direct question: Should a user be allowed to change approval before the tool starts executing, or is first valid reaction always final? - -6. Allow-always behavior - - Recommended: support it generically in approval fields and reaction options, but dummybridge should only persist/use it if there is a clear storage target. - - Direct question: Should dummybridge actually remember allow-always across runs, or only emit the field to prove the UI/transport supports it? - -7. Package boundaries - - Recommended: embrace mautrix in `pkg/ai-stream/matrix`; keep bridgev2-specific database/queue/redaction in `pkg/ai-stream/bridgev2`; keep `pkg/ag-ui` pure. - - Direct question: Should `pkg/ai-stream/matrix` return mautrix `event.MessageEventContent` directly everywhere, or expose a small content struct plus conversion helpers? - -8. TanStack docs freshness - - Recommended: before implementation starts, re-open current TanStack AI docs and update the contract section if state names or part shapes changed. - - Direct question: Should implementation pin to the docs current at implementation start, or should tests tolerate small TanStack naming changes? +- Dropped or invalid carriers: Desktop should mark the anchor incomplete/failed and keep carriers hidden. Do not show carrier messages as fallback bubbles. +- Missing `seq` gaps: timeout must stop infinite buffering, fail or mark incomplete on the anchor, and keep later stray carrier events hidden. +- Carrier delete/redaction: recompute from remaining carriers when possible; otherwise mark incomplete/failed. +- Replay/backfill: reconstruct the same visible AI run from persisted anchor plus carriers as live streaming. +- Approval idempotency: first valid approval resolution should win. Later reactions/programmatic responses should not re-run the tool and may be cleaned up as stale. +- Allow-always: support the field generically in approval options/responses, but dummybridge should not persist cross-run allow-always state until there is a real product storage target. +- TanStack drift: before future dependency upgrades, re-open current TanStack docs/source and update this contract deliberately. Do not silently adapt by assertions. ## Tests @@ -683,11 +812,16 @@ Dummybridge Go tests: `pkg/ai-stream` tests: - Verify ordered run writer output. +- Verify normal stream envelopes do not contain finalization totals such as `seqTotal`. - Verify no per-delta accumulated full text. - Verify final accumulator is only used at finalization. - Verify UTF-8 splitting. -- Verify carrier packer respects the 58KB serialized JSON budget. +- Verify carrier packer respects the serialized JSON carrier budget. - Verify stream reconstruction from carriers. +- Verify finalization carriers split a complete final UI message into a base snapshot plus continuation parts without repeating metadata. +- Verify finalization continuations merge deterministically by `messageId`, `runId`, and `partOffset`. +- Verify oversized text/thinking final parts split at UTF-8 boundaries and reassemble exactly. +- Verify oversized raw/debug/provider metadata is truncated or omitted before splitting structured tool/data parts. - Verify duplicate/stale/out-of-order `seq` behavior. - Verify missing `seq` gap timeout marks the anchor incomplete/failed. - Verify carrier delete/redaction recomputes or marks the anchor incomplete/failed. @@ -696,11 +830,13 @@ Dummybridge Go tests: Over-64KB tests: - Generate at least 70KiB of output. -- Assert every carrier's serialized content is at or below 58KB. +- Assert every carrier's serialized content is at or below the carrier budget. - Assert at least two carrier events are emitted. - Assert later carriers have no preview body or only minimal body. - Assert reconstruction from deltas exactly equals generated output. - Assert no final full-body edit is required to display the complete stream. +- Assert final snapshot state is complete even when split across finalization carriers. +- Assert final edit contains Matrix `formatted_body` generated by mautrix Markdown rendering and does not contain the full parts array. Desktop tests: @@ -708,6 +844,9 @@ Desktop tests: - Carrier-only events are hidden and do not render as chat bubbles. - Single-update and batched `updates` formats still work. - Multi-carrier stream merges into the visible anchor message. +- Finalization base snapshot plus `com.beeper.ai.final-parts` continuations merge into one final `UIMessage`. +- Final edit arriving after carriers finalizes the existing anchor without creating a second message or flickering back to preview-only content. +- If final edit and stream/finalization carriers arrive in one sync batch, carriers are applied before streaming is stopped. - Out-of-order `seq` buffering works. - Duplicate/stale `seq` handling works. - TanStack-shaped text/thinking/tool/result parts render through the AI message view. @@ -727,8 +866,21 @@ Desktop tests: Commands to run: - In dummybridge: `go test -mod=readonly ./...` -- In Desktop after adding `@tanstack/ai-react-ui`: run the package manager install/update command explicitly approved for that dependency and commit the resulting manifest/lockfile changes with the Desktop implementation. +- In Desktop, if `@tanstack/ai-react-ui` is not already present: ask before running any package-manager install/update command or changing any lockfile. - In Desktop: run the existing focused test commands for touched files. At minimum cover `ai-common`, `ui-message`, `AIChatsStore`, `EventSyncContext`, and stream mapper tests. +- In Desktop: run typecheck after the focused tests. If the full repo typecheck is already failing for unrelated reasons, record the unrelated failures and separately prove touched AI files are type-clean. + +Verification status to track in the PR or completion note: + +- Dummybridge unit tests: command, date, result. +- Desktop focused tests: command, date, result. +- Desktop typecheck: command, date, result, and whether failures touch AI files. +- Source scan: prove no runtime source emits `seqTotal`. +- Over-64KB live smoke: one visible AI anchor, hidden carriers, final Matrix HTML preview, complete reconstructed supported-client state. +- Approval live smoke: visible approval prompt, selected reaction preserved, stale bridge option reactions removed, final anchor edit after response carriers. +- Random/chaos live smoke: no carrier bubbles, no stuck streaming state, no flicker to preview-only content after final edit. +- Replay/backfill smoke: persisted history reconstructs the same visible run after restart/reload. +- Redaction/gap smoke or unit coverage: carriers stay hidden and anchor becomes recomputed or incomplete/failed. ## Live Smoke Testing @@ -754,5 +906,6 @@ Acceptance criteria: - Carrier events do not show as separate bubbles. - Streaming remains incremental. - Over-64KB output reconstructs correctly. +- Finalized over-64KB runs still have one visible anchor message, complete supported-client AG-UI state, and bounded Matrix HTML preview on the final edit. - Approvals work from Matrix reactions and programmatic/TanStack-shaped responses. - The selected approval emoji is kept and non-selected placeholder options are removed. diff --git a/pkg/ai-stream/approval.go b/pkg/ai-stream/approval.go index 5821344..f647d01 100644 --- a/pkg/ai-stream/approval.go +++ b/pkg/ai-stream/approval.go @@ -35,17 +35,19 @@ type ReactionEvent struct { } type ApprovalContext struct { - ID string `json:"id"` - ThreadID string `json:"threadId"` - RunID string `json:"runId"` - MessageID string `json:"messageId"` - ToolCallID string `json:"toolCallId"` - ToolName string `json:"toolName"` - TargetEvent string `json:"target_event"` - AgentID string `json:"agentId,omitempty"` - AgentName string `json:"agentName,omitempty"` - Model string `json:"model,omitempty"` - SeqStart int `json:"seqStart,omitempty"` + ID string `json:"id"` + ThreadID string `json:"threadId"` + RunID string `json:"runId"` + MessageID string `json:"messageId"` + ToolCallID string `json:"toolCallId"` + ToolName string `json:"toolName"` + TargetEvent string `json:"target_event"` + AgentID string `json:"agentId,omitempty"` + AgentName string `json:"agentName,omitempty"` + Model string `json:"model,omitempty"` + SeqStart int `json:"seqStart,omitempty"` + PreviewText string `json:"previewText,omitempty"` + PreviewTruncated bool `json:"previewTruncated,omitempty"` } func DefaultApprovalOptions(approvalID string) []ReactionOption[agui.ToolApprovalResponse] { @@ -139,6 +141,7 @@ func ApprovalResponseRun(ctx ApprovalContext, response agui.ToolApprovalResponse run.ToolCallID = ctx.ToolCallID run.ApprovalID = ctx.ID run.Status = Status{State: "complete"} + run.Preview = Preview{Text: ctx.PreviewText, Truncated: ctx.PreviewTruncated} run.Approvals = []ApprovalSummary{{ ID: ctx.ID, ToolCallID: ctx.ToolCallID, diff --git a/pkg/ai-stream/matrix/content.go b/pkg/ai-stream/matrix/content.go index e0a9e88..bb86667 100644 --- a/pkg/ai-stream/matrix/content.go +++ b/pkg/ai-stream/matrix/content.go @@ -5,8 +5,8 @@ import ( "github.com/beeper/dummybridge/pkg/ag-ui" "github.com/beeper/dummybridge/pkg/ai-stream" - "maunium.net/go/mautrix/format" "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/format" "maunium.net/go/mautrix/id" ) @@ -16,6 +16,10 @@ func AnchorContent(run aistream.Run) (*event.MessageEventContent, map[string]any body = "..." } rendered := format.RenderMarkdown(body, true, false) + if rendered.Format != event.FormatHTML { + rendered.Format = event.FormatHTML + rendered.FormattedBody = rendered.Body + } content := &rendered content.BeeperPerMessageProfile = &event.BeeperPerMessageProfile{ ID: run.AgentID, diff --git a/pkg/ai-stream/pack.go b/pkg/ai-stream/pack.go index 011291d..d94168a 100644 --- a/pkg/ai-stream/pack.go +++ b/pkg/ai-stream/pack.go @@ -139,12 +139,12 @@ func ReconstructText(carriers []Carrier) string { } func splitEventForBudget(evt agui.Event, budget int) []agui.Event { - if JSONSize(evt) <= budget { - return []agui.Event{sanitizeRawEvent(evt, budget)} - } if evt["type"] == agui.EventMessagesSnapshot { return splitMessagesSnapshotForBudget(evt, budget) } + if JSONSize(evt) <= budget { + return []agui.Event{sanitizeRawEvent(evt, budget)} + } if evt["type"] != agui.EventTextMessageContent { return []agui.Event{sanitizeRawEvent(evt, budget)} } @@ -172,41 +172,130 @@ func splitMessagesSnapshotForBudget(evt agui.Event, budget int) []agui.Event { } var out []agui.Event for _, message := range rawMessages { - base := agui.CloneEvent(evt) - messageWithoutParts := message - messageWithoutParts.Parts = nil - base["messages"] = []agui.UIMessage{messageWithoutParts} - var current []agui.MessagePart - flush := func() { - if len(current) == 0 { - return - } - cp := agui.CloneEvent(evt) - msg := message - msg.Parts = append([]agui.MessagePart{}, current...) - cp["messages"] = []agui.UIMessage{msg} - out = append(out, sanitizeRawEvent(cp, budget)) - current = nil + out = append(out, splitFinalMessageSnapshot(evt, message, budget)...) + } + if len(out) == 0 { + return []agui.Event{sanitizeRawEvent(evt, budget)} + } + return out +} + +func splitFinalMessageSnapshot(evt agui.Event, message agui.UIMessage, budget int) []agui.Event { + base := agui.CloneEvent(evt) + baseMessage := message + baseMessage.Parts = nil + base["messages"] = []agui.UIMessage{baseMessage} + + var out []agui.Event + baseFlushed := false + flushBase := func() { + if baseFlushed { + return + } + out = append(out, sanitizeRawEvent(base, budget)) + baseFlushed = true + } + appendToBase := func(part agui.MessagePart) bool { + if baseFlushed { + return false + } + nextMessage := baseMessage + nextMessage.Parts = append(append([]agui.MessagePart{}, baseMessage.Parts...), part) + candidate := agui.CloneEvent(base) + candidate["messages"] = []agui.UIMessage{nextMessage} + if JSONSize(candidate) > budget { + return false + } + baseMessage = nextMessage + base["messages"] = []agui.UIMessage{baseMessage} + return true + } + + var continuationParts []agui.MessagePart + continuationOffset := 0 + flushContinuation := func() { + if len(continuationParts) == 0 { + return } - for _, part := range message.Parts { - candidate := append(append([]agui.MessagePart{}, current...), part) - cp := agui.CloneEvent(evt) - msg := message - msg.Parts = candidate - cp["messages"] = []agui.UIMessage{msg} - if len(current) > 0 && JSONSize(cp) > budget { - flush() + out = append(out, finalPartsEvent(evt, message.ID, message.Metadata, continuationOffset, continuationParts)) + continuationParts = nil + } + addContinuation := func(partOffset int, part agui.MessagePart) { + if len(continuationParts) > 0 && partOffset != continuationOffset+len(continuationParts) { + flushContinuation() + } + if len(continuationParts) == 0 { + continuationOffset = partOffset + } + candidateParts := append(append([]agui.MessagePart{}, continuationParts...), part) + candidate := finalPartsEvent(evt, message.ID, message.Metadata, continuationOffset, candidateParts) + if len(continuationParts) > 0 && JSONSize(candidate) > budget { + flushContinuation() + continuationOffset = partOffset + } + continuationParts = append(continuationParts, part) + } + + for partOffset, part := range message.Parts { + for pieceIndex, piece := range splitFinalPartForBudget(part, budget) { + if pieceIndex == 0 && appendToBase(piece) { + continue } - current = append(current, part) + flushBase() + addContinuation(partOffset, piece) } - flush() } - if len(out) == 0 { - return []agui.Event{sanitizeRawEvent(evt, budget)} + flushBase() + flushContinuation() + return out +} + +func finalPartsEvent(base agui.Event, messageID string, metadata map[string]any, partOffset int, parts []agui.MessagePart) agui.Event { + evt := agui.CloneEvent(base) + evt["type"] = agui.EventCustom + evt["name"] = FinalPartsCustomName + delete(evt, "messages") + runID, _ := metadata["runId"].(string) + evt["value"] = map[string]any{ + "messageId": messageID, + "runId": runID, + "partOffset": partOffset, + "parts": append([]agui.MessagePart{}, parts...), + } + return evt +} + +func splitFinalPartForBudget(part agui.MessagePart, budget int) []agui.MessagePart { + partType, _ := part["type"].(string) + if partType != "text" && partType != "thinking" { + return []agui.MessagePart{part} + } + content, _ := part["content"].(string) + if content == "" || JSONSize(part) <= budget/2 { + return []agui.MessagePart{part} + } + maxContent := budget / 3 + if maxContent < 1024 { + maxContent = 1024 + } + chunks := SplitTextUTF8(content, maxContent) + out := make([]agui.MessagePart, 0, len(chunks)) + for _, chunk := range chunks { + cp := cloneMessagePart(part) + cp["content"] = chunk + out = append(out, cp) } return out } +func cloneMessagePart(part agui.MessagePart) agui.MessagePart { + cp := make(agui.MessagePart, len(part)) + for key, value := range part { + cp[key] = value + } + return cp +} + func sanitizeRawEvent(evt agui.Event, budget int) agui.Event { cp := agui.CloneEvent(evt) if _, ok := cp["rawEvent"]; !ok { diff --git a/pkg/ai-stream/run.go b/pkg/ai-stream/run.go index cbfd767..0dbc351 100644 --- a/pkg/ai-stream/run.go +++ b/pkg/ai-stream/run.go @@ -15,8 +15,9 @@ const ( BeeperAIMetadataKey = "com.beeper.ai.metadata" BeeperAIStreamKey = "com.beeper.llm" BeeperAIStreamDeltas = BeeperAIStreamKey + ".deltas" + FinalPartsCustomName = "com.beeper.ai.final-parts" DefaultModel = "dummybridge/ag-ui" - CarrierBudgetBytes = 58 * 1024 + CarrierBudgetBytes = 40 * 1024 PreviewBudgetBytes = 4096 SnapshotTextBytes = 4096 ) @@ -303,7 +304,7 @@ func (w *Writer) addFinalSnapshot() { if w == nil || w.Run == nil { return } - w.MessagesSnapshot([]agui.UIMessage{w.Run.FinalUIMessageSnapshot(SnapshotTextBytes)}) + w.MessagesSnapshot([]agui.UIMessage{w.Run.FinalUIMessageSnapshot(0)}) } func (w *Writer) finishReasoning() { diff --git a/pkg/ai-stream/stream_test.go b/pkg/ai-stream/stream_test.go index a4c3830..c5e1ea8 100644 --- a/pkg/ai-stream/stream_test.go +++ b/pkg/ai-stream/stream_test.go @@ -23,18 +23,124 @@ func TestPackRunSplitsOver64KBAndReconstructs(t *testing.T) { if len(carriers) < 2 { t.Fatalf("expected multiple carriers for over-64KB output, got %d", len(carriers)) } + for i, carrier := range carriers { + if size := JSONSize(CarrierContent(carrier.Envelopes)); size > CarrierBudgetBytes { + t.Fatalf("carrier %d is %d bytes, budget %d", i, size, CarrierBudgetBytes) + } + } + if got := ReconstructText(carriers); got != strings.Repeat("a", 70*1024) { + t.Fatalf("reconstructed text length = %d", len(got)) + } +} + +func TestPackRunDoesNotPutFinalizationTotalsOnStreamEnvelopes(t *testing.T) { + run := NewRun("run-1", "thread-1", DefaultModel, "ai", "AI", time.Unix(10, 0)) + writer := NewWriter(run, func() time.Time { return time.Unix(10, 0) }) + writer.Start() + writer.Text("hello") + writer.Finish(agui.FinishReasonStop) + + carriers, err := PackRun(*run, "$anchor", CarrierBudgetBytes) + if err != nil { + t.Fatal(err) + } + raw, err := json.Marshal(CarrierContent(carriers[0].Envelopes)) + if err != nil { + t.Fatal(err) + } + if strings.Contains(string(raw), "seqTotal") { + t.Fatalf("stream envelopes must not contain finalization totals: %s", raw) + } +} + +func TestFinalSnapshotSplitsIntoBaseAndContinuationParts(t *testing.T) { + run := NewRun("run-1", "thread-1", DefaultModel, "ai", "AI", time.Unix(10, 0)) + writer := NewWriter(run, func() time.Time { return time.Unix(10, 0) }) + writer.Start() + writer.Thinking(strings.Repeat("t", 12*1024)) + writer.Text(strings.Repeat("a", 70*1024)) + writer.ToolStart("tool-1", "shell", 0, nil) + writer.ToolArgs("tool-1", `{"cmd":"pwd"}`, `{"cmd":"pwd"}`) + writer.ToolEnd("tool-1", "shell", `{"cmd":"pwd"}`, map[string]any{"ok": true}) + writer.Finish(agui.FinishReasonStop) + + carriers, err := PackRun(*run, "$anchor", CarrierBudgetBytes) + if err != nil { + t.Fatal(err) + } + var baseSnapshots, continuations int + var reconstructedText strings.Builder + var sawMetadata bool for i, carrier := range carriers { if size := JSONSize(CarrierContent(carrier.Envelopes)); size > CarrierBudgetBytes { t.Fatalf("carrier %d is %d bytes, budget %d", i, size, CarrierBudgetBytes) } for _, env := range carrier.Envelopes { - if env.SeqTotal <= 0 { - t.Fatalf("carrier envelope missing total count: %#v", env) + switch env.Part["type"] { + case agui.EventMessagesSnapshot: + baseSnapshots++ + messages, ok := env.Part["messages"].([]any) + if !ok || len(messages) != 1 { + t.Fatalf("bad final base snapshot: %#v", env.Part["messages"]) + } + message, ok := messages[0].(map[string]any) + if !ok { + t.Fatalf("bad final base snapshot message: %#v", messages[0]) + } + metadata, ok := message["metadata"].(map[string]any) + if ok && metadata["runId"] == "run-1" { + sawMetadata = true + } + case agui.EventCustom: + if env.Part["name"] != FinalPartsCustomName { + continue + } + continuations++ + value := env.Part["value"].(map[string]any) + if value["messageId"] != run.MessageID || value["runId"] != run.RunID { + t.Fatalf("bad continuation relation data: %#v", value) + } + if _, ok := value["metadata"]; ok { + t.Fatalf("continuation must not duplicate message metadata: %#v", value) + } + for _, part := range testFinalParts(t, value["parts"]) { + if part["type"] == "text" { + reconstructedText.WriteString(part["content"].(string)) + } + } } } } - if got := ReconstructText(carriers); got != strings.Repeat("a", 70*1024) { - t.Fatalf("reconstructed text length = %d", len(got)) + if baseSnapshots != 1 || continuations == 0 || !sawMetadata { + t.Fatalf("expected one metadata base snapshot and continuations, base=%d continuations=%d metadata=%v", baseSnapshots, continuations, sawMetadata) + } + if !strings.Contains(run.Text(), reconstructedText.String()) { + t.Fatalf("unexpected continuation text reconstruction length=%d", reconstructedText.Len()) + } +} + +func testFinalParts(t *testing.T, value any) []map[string]any { + t.Helper() + switch parts := value.(type) { + case []agui.MessagePart: + out := make([]map[string]any, 0, len(parts)) + for _, part := range parts { + out = append(out, map[string]any(part)) + } + return out + case []any: + out := make([]map[string]any, 0, len(parts)) + for _, rawPart := range parts { + part, ok := rawPart.(map[string]any) + if !ok { + t.Fatalf("bad final part: %#v", rawPart) + } + out = append(out, part) + } + return out + default: + t.Fatalf("bad final parts: %#v", value) + return nil } } @@ -160,6 +266,7 @@ func TestApprovalResponseRunEmitsRespondedStateAndToolResult(t *testing.T) { ToolName: "shell", TargetEvent: "$anchor", SeqStart: 10, + PreviewText: "Use supportbrief for incremental patches.", }, agui.ToolApprovalResponse{ Approved: false, Reason: "denied", @@ -170,6 +277,9 @@ func TestApprovalResponseRunEmitsRespondedStateAndToolResult(t *testing.T) { if run.RunID != "run-1" || run.MessageID != "msg-1" { t.Fatalf("approval response must continue the existing run/message, got %#v", run) } + if run.Preview.Text != "Use supportbrief for incremental patches." { + t.Fatalf("approval response must preserve anchor preview, got %#v", run.Preview) + } if len(run.Events) != 2 { t.Fatalf("expected approval response and tool result events, got %#v", run.Events) } diff --git a/pkg/connector/client.go b/pkg/connector/client.go index 566669a..7f0a41b 100644 --- a/pkg/connector/client.go +++ b/pkg/connector/client.go @@ -511,7 +511,9 @@ func (dc *DummyClient) queueAIRunStreamAndMetadata(portal *bridgev2.Portal, mess prompt.SeqStart = nextSeq + i*10 dc.queueAIApprovalPrompt(portal, run, prompt, targetEventID, time.Now()) } - dc.queueAIRunFinalMetadata(portal, messageID, run) + if run.Status.State != "streaming" { + dc.queueAIRunFinalMetadata(portal, messageID, run) + } } func (dc *DummyClient) queueAICarriers(portal *bridgev2.Portal, targetEventID id.EventID, run aistream.Run, startSeq int) ([]aistream.Carrier, error) { @@ -582,17 +584,19 @@ func (dc *DummyClient) lookupMessageMXID(ctx context.Context, receiver networkid func (dc *DummyClient) queueAIApprovalPrompt(portal *bridgev2.Portal, run aistream.Run, prompt aistream.ApprovalPrompt, targetEventID id.EventID, timestamp time.Time) { reactions := aistream.DefaultApprovalOptions(prompt.ID) approvalCtx := aistream.ApprovalContext{ - ID: prompt.ID, - ThreadID: run.ThreadID, - RunID: run.RunID, - MessageID: run.MessageID, - ToolCallID: prompt.ToolCallID, - ToolName: prompt.ToolName, - TargetEvent: string(targetEventID), - AgentID: run.AgentID, - AgentName: run.AgentName, - Model: run.Model, - SeqStart: prompt.SeqStart, + ID: prompt.ID, + ThreadID: run.ThreadID, + RunID: run.RunID, + MessageID: run.MessageID, + ToolCallID: prompt.ToolCallID, + ToolName: prompt.ToolName, + TargetEvent: string(targetEventID), + AgentID: run.AgentID, + AgentName: run.AgentName, + Model: run.Model, + SeqStart: prompt.SeqStart, + PreviewText: run.Preview.Text, + PreviewTruncated: run.Preview.Truncated, } dc.UserLogin.QueueRemoteEvent(aibridgev2.ApprovalPrompt(portal.PortalKey, aiGhostID, approvalCtx, timestamp)) @@ -620,7 +624,9 @@ func (dc *DummyClient) queueAIApprovalResponse(ctx context.Context, portal *brid } if _, err := dc.queueAICarriers(portal, targetEventID, run, approvalCtx.SeqStart); err != nil { log.Warn().Err(err).Str("approval_id", approvalCtx.ID).Msg("Failed to queue AI approval response") + return } + dc.queueAIRunFinalMetadata(portal, networkid.MessageID(approvalCtx.MessageID), run) } func (dc *DummyClient) approvalContextForMessage(ctx context.Context, portal *bridgev2.Portal, message *database.Message) (aistream.ApprovalContext, bool) { From b42db72eacb89baa64c62912f8dcbe578f6e6e17 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Tue, 19 May 2026 20:46:26 +0200 Subject: [PATCH 03/46] wip --- TANSTACK_AG_UI_PARITY_PLAN.md | 911 ------------------------------- pkg/ag-ui/events.go | 2 + pkg/ai-stream/bridgev2/events.go | 90 ++- pkg/ai-stream/pack.go | 37 +- pkg/ai-stream/run.go | 21 +- pkg/connector/ai_runtime.go | 38 +- pkg/connector/ai_runtime_test.go | 6 +- pkg/connector/client.go | 33 +- 8 files changed, 93 insertions(+), 1045 deletions(-) delete mode 100644 TANSTACK_AG_UI_PARITY_PLAN.md diff --git a/TANSTACK_AG_UI_PARITY_PLAN.md b/TANSTACK_AG_UI_PARITY_PLAN.md deleted file mode 100644 index ca3562c..0000000 --- a/TANSTACK_AG_UI_PARITY_PLAN.md +++ /dev/null @@ -1,911 +0,0 @@ -# TanStack/AG-UI Parity Implementation Brief - -## Summary - -Build dummybridge AI around current TanStack AG-UI primitives in both directions: - -- dummybridge emits AG-UI stream events and accepts TanStack-shaped approval responses. -- Desktop consumes multi-event encrypted AG-UI streams and hides carrier events from the normal timeline. -- Shared Go packages define the primitive contract instead of preserving old AI SDK or agentremote decisions. - -Do not install dependencies, update dependencies, or modify lockfiles unless the user explicitly approves that exact dependency change. Before adding `@tanstack/ai-react-ui` or changing the Desktop lockfile, verify whether it is already present in the current Desktop checkout; if it is not, ask first. - -## Plan Of Record - -This is the intended behavior to preserve while finishing the implementation: - -1. The normal AG-UI stream is an ordered delta log. Envelopes have `seq`, never `seqTotal`, carrier totals, or final event counts. -2. Every run has one visible Matrix anchor message. Normal stream carriers and finalization carriers are hidden transport events that merge into that anchor. -3. The final supported-client state is complete AG-UI state: text, thinking, tool calls/results, approval state, sources/files/data/state, terminal status, usage, model, and run metadata. -4. Finalization sends hidden carriers first, then the compact final Matrix edit last. The final edit stops streaming and carries Matrix-native preview HTML; it does not need to carry the full generated text or full parts array. -5. Over-budget final state uses a base `MESSAGES_SNAPSHOT` plus `CUSTOM name="com.beeper.ai.final-parts"` continuations. Continuations contain only relation data and omitted parts, not repeated full metadata. -6. PAS/Desktop must process hidden stream/finalization carriers before treating the final edit as the point where streaming stops. From the renderer's perspective, there is always one final AI message. -7. Approval prompts remain separate visible Matrix messages for actionability, while semantic approval state is also represented in AG-UI so supported clients can show it inline as a tool-call state. -8. Unsupported clients are not the primary target, but the final Matrix edit must still be a coherent bounded Matrix HTML preview for timeline, search, and notifications. - -Do not reintroduce these rejected approaches: - -- No `seqTotal`, carrier totals, or final event totals on normal streaming envelopes. -- No visible carrier bubbles as a fallback for unsupported or failed merge behavior. -- No final full-text Matrix edit for long runs. The final edit is a bounded preview plus compact metadata. -- No broad `as any` or whole-object assertions at the Desktop TanStack render boundary. -- No duplicate Beeper-only UI message model where TanStack types already describe the part contract. -- No package-manager install/update/lockfile change without explicit approval. - -## Implementation Status - -As of this checkout, the plan should be read as a completion/audit checklist rather than a blank design doc. - -Done in dummybridge: - -- New `pkg/ag-ui`, `pkg/ai-stream`, `pkg/ai-stream/matrix`, and `pkg/ai-stream/bridgev2` packages exist. -- The old provisional `pkg/aichats` package has been removed. -- Normal stream envelopes use ordered `seq` and do not carry `seqTotal`. -- `MESSAGES_SNAPSHOT` finalization can split into a metadata-preserving base snapshot plus `com.beeper.ai.final-parts` continuations. -- Large text/thinking final parts split at UTF-8 boundaries and are reassembled by the supported client model. -- Carrier replay for built runs is contiguous; synthetic timestamps no longer add random delays between already-built carrier sends. -- Final anchor edits use mautrix Markdown rendering for Matrix HTML preview content. -- Approval response carriers are queued before the final metadata edit for the anchor. - -Done in the related Desktop checkout: - -- Carrier-only encrypted events with non-empty `*.deltas` are routed as hidden stream updates instead of normal timeline upserts. -- `com.beeper.ai.final-parts` continuations merge into the existing TanStack-shaped UI message by `messageId`, `runId`, and `partOffset`. -- The AI renderer path uses a typed high-level adapter at the TanStack render boundary instead of asserting the whole message as `any`. -- `src/renderer/ai/ui-message.ts` uses typed builders/guards for TanStack and Beeper custom parts instead of broad `MutableUIPart`/record-level assertions. - -Completion status: - -- The non-visual plan gates are implemented and verified by the evidence below. -- Full Desktop typecheck is still red, but the failures are outside the touched AI/PAS files and are listed below. -- Visual testing remains explicitly excluded from the current completion target. - -Completion gates: - -- Unit/focused tests pass for dummybridge and touched Desktop AI/PAS paths. -- Full Desktop typecheck either passes or every failure is documented as unrelated to touched AI/PAS files. -- Live staging smoke proves over-64KB output produces one visible AI anchor and hidden carriers only. -- Live staging smoke proves approve and deny both finalize through hidden response carriers before the final anchor edit. -- Replay/backfill, redaction/delete, and missing-gap behavior have either automated tests or an explicit live/manual verification note. -- `rg` source scan proves runtime source does not emit `seqTotal`. - -Current verification snapshot, 2026-05-19: - -- `go test -mod=readonly ./...` in dummybridge passes. -- Desktop focused AI/PAS tests pass: - - `bun run test --run src/common/ai-common.test.ts src/renderer/ai/ui-message.test.ts src/renderer/ai/stream-ordering.test.ts src/renderer/stores/AIChatsStore.test.ts src/pas-server/beeper/EventSyncContext.test.ts src/pas-server/beeper/connect/ws-event-mapper.test.ts src/pas-server/beeper/connect/ws-events-server.test.ts` -- Full Desktop typecheck still fails in unrelated files outside the touched AI/PAS paths: - - `BrandLink.stories.tsx` - - `ComposeMessage/TextArea/TextArea.tsx` - - `DetachedAccountsOnboarding.tsx` - - `electron-ipc.ts` - - `measureInteractionNextPaint.ts` - - `QuickRepliesPrefsSubView.tsx` -- Runtime source scans find no total-count fields: - - `rg -n 'seqTotal|carrierTotal|finalEventTotal' pkg cmd --glob '!**/*_test.go'` - - `rg -n 'seqTotal|carrierTotal|finalEventTotal' src/common src/renderer/ai src/renderer/stores src/pas-server/beeper --glob '!**/*.test.ts' --glob '!**/*.test.tsx'` -- Desktop AI render/store source scan finds no broad assertions in the touched AI paths: - - `rg -n '\] as any|as MutableUIPart|MutableUIPart|ToolUIPartRecord|console\.log\(' src/renderer/ai src/renderer/stores/AIChatsStore.ts` -- Live staging over-64KB smoke passed after lowering the raw carrier budget to 40KB: one visible AI anchor, hidden carriers, no visible `com.beeper.ai.final-parts` leakage, and Matrix HTML preview on the final edit. -- Live staging approval approve and deny pass: prompt remains separate, selected user reaction remains, bridge option reactions are redacted, response carriers are queued, and final anchor edit preserves the existing preview instead of reverting to `...`. -- Live staging random/chaos smoke passes for hidden-carrier behavior: no carrier bubbles appeared in Desktop API output, approval prompts stayed separate, and final edits did not regress completed approval anchors. -- Replay/backfill has unit coverage through batched `updates` extraction and `AIChatsStore` replay into an existing anchor. -- Deleted/redacted carrier and missing-gap behavior have unit coverage: carrier deletion marks the anchor failed while keeping carrier events hidden, and unresolved sequence gaps now fail via timer without requiring another stream event. - -## Baseline And Existing Entry Points - -The original plan replaced a provisional `pkg/aichats` package plus AI handling in `pkg/connector/client.go`. In this checkout, `pkg/aichats` should stay deleted; the active implementation is the new `pkg/ag-ui` and `pkg/ai-stream` stack plus the connector integration points. - -Behavior baseline to preserve or improve: - -- AI DM resolution uses the `ai`/`AI` ghost and AI portals with the `ai-` prefix. -- The bridge sends one visible anchor/placeholder event, streams `com.beeper.llm.deltas`, then edits the anchor with final compact metadata and Matrix preview content. -- Stream deltas must be real AG-UI events, not AG-UI-like compatibility shapes. -- Approval requests are separate Matrix events with `com.beeper.ai.approval` metadata and reaction options. -- Approval reaction handling keeps the user's selected emoji and removes bridge-posted placeholder/non-selected options. -- Text streaming must send incremental deltas and must not resend full accumulated text on every delta. - -Desktop already has partial AI stream support in these areas: - -- `src/common/ai-common.ts`: `BeeperAIMessage`, `BeeperAGUIEvent`, approval constants, and type guards. -- `src/common/types/beeper.ts`: stream content types with `.deltas` and `updates`. -- `src/pas-server/beeper/EventSyncContext.ts`: maps `com.beeper.ai`, `com.beeper.stream`, per-message profile, edits, and hidden AI notices. -- `src/pas-server/beeper/BeeperClient.ts`: processes stream events into `STATE_SYNC message stream`. -- `src/renderer/stores/AIChatsStore.ts`: extracts `.deltas`, orders by `seq`, applies AG-UI events, tracks approvals, and merges stream state. -- `src/renderer/ai/ui-message.ts`: applies AG-UI events into current Desktop UI message parts. - -The implementation should update those Desktop paths instead of inventing a second client stream path. - -Implementation rules: - -- Keep the code simple, clean, and direct. -- Prefer less LOC, less indirection, and fewer abstractions. -- Fold or flatten abstractions that do not carry real behavior. -- Do not add fake layers, simple wrappers, barrel exports, duplicated logic, or duplicated types. -- Smaller files are fine only when they represent real concerns. -- Optimize for one coherent system per concern, not multiple parallel ways to do the same thing. -- Current AI code was generated and never released, so no backward compatibility or legacy compatibility is required. -- Delete provisional schemas, routes, event shapes, migrations, aliases, and helper layers if they only exist for history or compatibility. -- Prefer deleting code over preserving it. -- Prefer collapsing duplicate entrypoints over keeping aliases. -- Product intention matters more than the current code shape. -- If product intent is ambiguous, explicitly call out the question instead of encoding both options. - -Compatibility policy: - -- No compatibility policy is required for the provisional/current dummybridge AI event shapes. -- Desktop and dummybridge should converge on one new TanStack/AG-UI shape. -- Delete old reader/writer paths instead of accepting old names such as `REASONING_MESSAGE_*` or older tool-call fields unless they are required by current TanStack AG-UI docs. - -## Full Paths To Inspect - -Primary dummybridge checkout: - -- `/Users/batuhan/Projects/labs/dummybridge` -- `/Users/batuhan/Projects/labs/dummybridge/TANSTACK_AG_UI_PARITY_PLAN.md` -- `/Users/batuhan/Projects/labs/dummybridge/README.md` -- `/Users/batuhan/Projects/labs/dummybridge/go.mod` -- `/Users/batuhan/Projects/labs/dummybridge/go.sum` -- `/Users/batuhan/Projects/labs/dummybridge/config-agui.yaml` -- `/Users/batuhan/Projects/labs/dummybridge/config-qa-agui.yaml` - -Legacy dummybridge AI implementation that should not be restored: - -- `/Users/batuhan/Projects/labs/dummybridge/pkg/aichats/agui.go` -- `/Users/batuhan/Projects/labs/dummybridge/pkg/aichats/matrix.go` -- `/Users/batuhan/Projects/labs/dummybridge/pkg/aichats/agui_test.go` -- `/Users/batuhan/Projects/labs/dummybridge/pkg/connector/client.go` -- `/Users/batuhan/Projects/labs/dummybridge/pkg/connector/connector.go` -- `/Users/batuhan/Projects/labs/dummybridge/pkg/connector/login.go` -- `/Users/batuhan/Projects/labs/dummybridge/pkg/connector/example-config.yaml` - -Active dummybridge package targets: - -- `/Users/batuhan/Projects/labs/dummybridge/pkg/ag-ui` -- `/Users/batuhan/Projects/labs/dummybridge/pkg/ai-stream` -- `/Users/batuhan/Projects/labs/dummybridge/pkg/ai-stream/matrix` -- `/Users/batuhan/Projects/labs/dummybridge/pkg/ai-stream/bridgev2` - -Archived AI dummybridge reference: - -- `/Users/batuhan/Projects/labs/ai-bridge-archived` -- `/Users/batuhan/Projects/labs/ai-bridge-archived/bridges/dummybridge/runtime.go` -- `/Users/batuhan/Projects/labs/ai-bridge-archived/bridges/dummybridge/runtime_test.go` -- `/Users/batuhan/Projects/labs/ai-bridge-archived/sdk/writer.go` -- `/Users/batuhan/Projects/labs/ai-bridge-archived/approval_flow.go` - -Local TanStack AI reference checkout: - -- `/Users/batuhan/Projects/labs/upstream/tanstack-ai` -- `/Users/batuhan/Projects/labs/upstream/tanstack-ai/packages/typescript/ai-react-ui/src/text-part.tsx` -- `/Users/batuhan/Projects/labs/upstream/tanstack-ai/packages/typescript/ai-react-ui/src/chat-message.tsx` -- `/Users/batuhan/Projects/labs/upstream/tanstack-ai/packages/typescript/ai-react-ui/package.json` -- `/Users/batuhan/Projects/labs/upstream/tanstack-ai/packages/typescript/ai-client/src/types.ts` -- `/Users/batuhan/Projects/labs/upstream/tanstack-ai/packages/typescript/ai-client/src/chat-client.ts` -- `/Users/batuhan/Projects/labs/upstream/tanstack-ai/packages/typescript/ai-client/src/connection-adapters.ts` -- `/Users/batuhan/Projects/labs/upstream/tanstack-ai/packages/typescript/ai-event-client/src/index.ts` -- `/Users/batuhan/Projects/labs/upstream/tanstack-ai/packages/typescript/ai/src/types.ts` -- `/Users/batuhan/Projects/labs/upstream/tanstack-ai/packages/typescript/ai/src/utilities/chat-params.ts` - -Desktop checkout and AI consumer paths: - -- `/Users/batuhan/Projects/texts/beeper-workspace/beeper/beeper/desktop` -- `/Users/batuhan/Projects/texts/beeper-workspace/beeper/beeper/desktop/src/common/ai-common.ts` -- `/Users/batuhan/Projects/texts/beeper-workspace/beeper/beeper/desktop/src/common/ai-common.test.ts` -- `/Users/batuhan/Projects/texts/beeper-workspace/beeper/beeper/desktop/src/common/types/beeper.ts` -- `/Users/batuhan/Projects/texts/beeper-workspace/beeper/beeper/desktop/src/pas-server/beeper/EventSyncContext.ts` -- `/Users/batuhan/Projects/texts/beeper-workspace/beeper/beeper/desktop/src/pas-server/beeper/BeeperClient.ts` -- `/Users/batuhan/Projects/texts/beeper-workspace/beeper/beeper/desktop/src/pas-server/beeper/connect/ws-event-mapper.test.ts` -- `/Users/batuhan/Projects/texts/beeper-workspace/beeper/beeper/desktop/src/pas-server/beeper/connect/ws-events-server.test.ts` -- `/Users/batuhan/Projects/texts/beeper-workspace/beeper/beeper/desktop/src/renderer/stores/AIChatsStore.ts` -- `/Users/batuhan/Projects/texts/beeper-workspace/beeper/beeper/desktop/src/renderer/ai/ui-message.ts` -- `/Users/batuhan/Projects/texts/beeper-workspace/beeper/beeper/desktop/src/renderer/ai/ui-message.test.ts` -- `/Users/batuhan/Projects/texts/beeper-workspace/beeper/beeper/desktop/src/renderer/ai/ai-message-view.ts` - -Local tooling for live smoke tests: - -- `/Users/batuhan/Projects/texts/bridge-manager/bbctl` -- `/Users/batuhan/Projects/labs/desktop-api-cli/packages/cli` - -Local runtime artifacts that may be useful for debugging, but should not be treated as source: - -- `/Users/batuhan/Projects/labs/dummybridge/logs` -- `/Users/batuhan/Projects/labs/dummybridge/sh-dummybridge-agui.db` -- `/Users/batuhan/Projects/labs/dummybridge/sh-dummybridge-agui.db-shm` -- `/Users/batuhan/Projects/labs/dummybridge/sh-dummybridge-agui.db-wal` -- `/Users/batuhan/Projects/labs/dummybridge/sh-dummybridge-qa-agui.db` - -## TanStack/AG-UI Contract - -Use TanStack primitives as the source of truth: - -- `StreamChunk = AGUIEvent`; do not preserve legacy non-AG-UI chunk formats. -- Support every current AG-UI lifecycle event explicitly: - - `RUN_STARTED` - - `RUN_FINISHED` - - `RUN_ERROR` - - `TEXT_MESSAGE_START` - - `TEXT_MESSAGE_CONTENT` - - `TEXT_MESSAGE_END` - - `TOOL_CALL_START` - - `TOOL_CALL_ARGS` - - `TOOL_CALL_END` - - `TOOL_CALL_RESULT` - - `STEP_STARTED` - - `STEP_FINISHED` - - `STATE_SNAPSHOT` - - `STATE_DELTA` - - `MESSAGES_SNAPSHOT` - - `CUSTOM` -- Support bidirectional AG-UI run input: `threadId`, `runId`, `state`, `messages`, `tools`, `context`, `forwardedProps`, and legacy `data` mirror. -- Model `UIMessage` as `{ id, role, parts, createdAt? }`, preserving ordered parts. -- Use TanStack part shapes: - - Text part: `{ type: "text", content }` - - Thinking part: `{ type: "thinking", content }` - - Tool call part: `{ type: "tool-call", id, name, arguments, state, approval?, output? }` - - Tool result part: `{ type: "tool-result", toolCallId, content, state, error? }` -- Use TanStack tool states: - - `awaiting-input` - - `input-streaming` - - `input-complete` - - `approval-requested` - - `approval-responded` -- Use TanStack tool result states: - - `streaming` - - `complete` - - `error` -- Treat AG-UI `REASONING_START`, `REASONING_MESSAGE_START`, `REASONING_MESSAGE_CONTENT`, `REASONING_MESSAGE_END`, and `REASONING_END` as the canonical thinking/reasoning stream for new output. -- Keep `STEP_STARTED` / `STEP_FINISHED` as step lifecycle events using AG-UI `stepName`, not deprecated `stepId`, and not as a substitute for reasoning content. -- Fully support AG-UI `STATE_SNAPSHOT`, `STATE_DELTA`, and `MESSAGES_SNAPSHOT` events. -- Every emitted AG-UI event must include `timestamp`. -- Support optional AG-UI `rawEvent` on every event, with the bounded/truncation policy below. -- Support `TOOL_CALL_START.index` for parallel tool calls. -- Support partial JSON argument streaming through `TOOL_CALL_ARGS`; consumers should preserve partial input while parsing best-effort and finalize on `TOOL_CALL_END`. -- Support `TOOL_CALL_END` both with and without a result payload. -- Support AG-UI `TOOL_CALL_RESULT` for separate tool-result parts instead of Beeper custom tool-result events. -- Support multiple assistant `messageId`s per run. Do not assume a run has exactly one assistant text message. - -Relevant docs: - -- AG-UI event definitions: -- Streaming: -- Tool states and parts: -- UIMessage: -- Bidirectional AG-UI compliance: -- Local source tags currently include `@tanstack/ai@0.18.0`, `@tanstack/ai-client@0.10.0`, and `@tanstack/ai-event-client@0.3.2`; prefer the local checkout above for exact type names during implementation. - -## Package Layout - -Create `pkg/ag-ui/` with Go package name `agui`. - -Responsibilities: - -- Standalone AG-UI event and UI message types. -- `RunAgentInput` and bidirectional request types. -- Tool, tool result, approval, text, thinking, step, custom, run, and error event builders. -- Validation helpers that reject invalid event ordering, missing IDs, bad states, invalid tool approval shapes, and oversized individual deltas. -- No Matrix, bridgev2, Desktop, or dummybridge-specific dependencies. - -Create `pkg/ai-stream/` with Go package name `aistream`. - -Responsibilities: - -- Run writer for ordered AG-UI event emission. -- Accumulation used only for finalization, preview generation, and test reconstruction. -- Stream envelope and chunk packing helpers. -- Approval resolver primitives. -- Terminal/finalization helpers. -- Spec enforcement, but not transport ownership. - -Add adapter layers: - -- `pkg/ai-stream/matrix`: Matrix content helpers using mautrix event types, stream carrier content, approval prompt content, reaction option serialization. -- `pkg/ai-stream/bridgev2`: bridgev2 queue/send/redaction adapter. This layer may import bridgev2 and database types. - -Delete `pkg/aichats` once the new packages fully replace it. Do not keep it as an unused compatibility package. - -## Archived Dummybridge Parity - -Use `../ai-bridge-archived/bridges/dummybridge/runtime.go` and `runtime_test.go` as the feature checklist, not as an architecture to copy. - -Commands: - -- `help` -- `/help` -- `!help` -- `dummybridge help` -- `stream-lorem [common options]` -- `stream-tools ... [common options]` -- `stream-random [seconds] [--actions=N] [--profile=balanced|tools|artifacts|terminals] [--seed=N] [--delay-ms=min:max] [--allow-abort] [--allow-error] [--allow-approval]` -- `stream-chaos [runs] [seconds] [--profile=balanced|tools|artifacts|terminals] [--seed=N] [--stagger-ms=min:max] [--max-actions=N] [--allow-abort] [--allow-error] [--allow-approval]` - -The help aliases are intentional product/demo affordances and should remain unless there is a later product decision to reduce command aliases. - -Common options: - -- `--reasoning=N` -- `--steps=N` -- `--sources=N` -- `--documents=N` -- `--files=N` -- `--meta` -- `--data=name` -- `--data-transient=name` -- `--delay-ms=min:max` -- `--chunk-chars=min:max` -- `--seed=N` -- `--finish=stop|length|tool-calls|content-filter|other` -- `--abort` -- `--error` - -Tool tags: - -- `#fail` -- `#approval` -- `#deny` -- `#delta` -- `#inputerror` -- `#prelim` -- `#provider` - -Behavior to preserve: - -- `stream-lorem` emits markdown-rich visible text, optional thinking/reasoning, optional steps, optional sources/documents/files/data, and a final run state. -- `stream-tools` emits text, thinking, tool input streaming, input errors, approval requests, approval denials, tool output streaming, final output, and tool failures. -- `stream-random` emits weighted random actions with deterministic seed support and profiles. -- `stream-chaos` starts multiple staggered runs and runs random streams per run. -- Persistent data survives final snapshots; transient data does not. -- Markdown generation must include realistic links, lists, quotes, code blocks, and tables. -- Terminal states include normal finish, error, and abort. - -Limits should start from the archived limits, except the explicit over-64KB streaming tests require larger generated output support: - -- Archived default chunk range: 24 to 96 characters. -- Archived maximum chunk size option: 512 characters. -- Archived maximum random actions: 64. -- Archived maximum chaos runs: 16. -- Archived maximum chaos actions: 64. -- Archived maximum demo duration: 5 minutes. -- Archived maximum delay/stagger: 30 seconds. -- Increase text generation limits enough to test at least 70KiB output. The transport must handle this by splitting carrier events, not by sending oversized Matrix events. - -## Streaming Transport - -Every AI run starts with one visible Matrix anchor event. - -Anchor event requirements: - -- `msgtype: m.text` -- AI per-message profile for the AI ghost -- Minimal `com.beeper.ai` -- Stable AG-UI `threadId` -- Stable AG-UI `runId` -- Stable AG-UI `messageId` -- Useful preview text in `body` -- `com.beeper.stream` descriptor when using the Beeper stream publisher - -ID model: - -- Use AG-UI IDs for semantic identity. -- `threadId` is the conversation/thread identity. For dummybridge this should map to the Beeper thread/portal/room identity used by Desktop. -- `runId` is the assistant execution identity. Do not add a separate Beeper execution ID unless a future AG-UI version requires it. -- `messageId` is the AG-UI assistant UI message identity. It should map to the first visible/anchor message, not to every carrier. -- Matrix event IDs are transport identities. Use the anchor Matrix event ID as `target_event` / `m.relates_to.event_id` for carriers. -- The Beeper stream descriptor is identified by `(room_id, event_id, type)` and does not expose a separate stream ID. Do not invent `streamId`; use `target_event` plus `runId` for merging. - -Carrier events: - -- Are sent through bridgev2 remote events so E2EE works normally. -- Must never be raw Matrix sends. -- Are `m.room.message` events with `msgtype: m.text` for bridgev2 and client compatibility. -- Contain `com.beeper.llm.deltas`. -- Carry ordered AG-UI envelopes. -- Are hidden from normal chat rendering by Desktop after deltas are extracted. -- Use empty or minimal body text after the initial visible preview; they must not appear as chat bubbles in Desktop. - -Envelope shape: - -- `threadId` -- `runId` -- `messageId` -- `seq` -- `part` -- `target_event` or `m.relates_to.event_id` -- optional `agent_id` - -Ordering and merge key: - -- `seq` is strictly increasing per `{target_event, runId}`. -- Do not put total counts such as `seqTotal`, carrier count, or final event count on normal stream envelopes. The streaming layer is an ordered event stream, not a pre-counted file transfer. -- If `target_event` is unavailable during early processing, temporarily key by `{threadId, runId}` and promote to `{target_event, runId}` when the anchor message is known. -- Desktop buffers out-of-order deltas within existing ordering limits. -- Duplicate or stale `seq` values are ignored or rejected consistently. - -Size budget: - -- Treat 64KB as the external ceiling. -- Use a hard carrier budget of 40KB for serialized Matrix content. Live staging E2EE sends showed that 58KB raw carrier content can become 66-79KB encrypted Matrix event content, so the budget must leave room for megolm/base64/wrapper overhead. -- The packer must measure serialized JSON byte size before adding an envelope to a carrier. -- If a single text delta would exceed the carrier budget, split it at UTF-8 rune boundaries. -- If a non-text event cannot fit inside the carrier budget, return a validation error rather than sending it. -- `rawEvent` must be optional, bounded, and safe to omit. If including `rawEvent` would push a carrier over budget, truncate it or drop it before packing rather than bloating the event. -- Truncated raw provider data must be marked, e.g. `rawEventTruncated: true`, so debugging does not confuse partial raw data with complete provider payloads. - -Preview/body algorithm: - -- The first visible message is the canonical message for the run. -- Put as much useful early visible preview as practical into the first message while preserving required metadata and staying under the carrier budget. -- All run-level metadata that should survive as the message identity, such as model, usage, thread/run/message IDs, terminal state, and approval summary, belongs on the first visible message or its compact final metadata. -- Later carrier messages should be hidden and merged by compatible clients into the first visible message. -- Later carrier bodies should be empty or minimal and put payload in `.deltas`. -- Compatible clients must reconstruct from ordered deltas and merge content/parts into the first message, not display carriers as separate runs. -- Do not rewrite full accumulated content on every delta. - -Finalization: - -- The run accumulator is only for finalization, preview generation, and tests. -- Normal stream chunks remain unaware of final chunk totals. Completion is determined by ordered AG-UI terminal/finalization events plus the final edit ordering, not by `seqTotal`. -- Finalization must emit the complete final AG-UI UI state for supported clients, including text, thinking, tool calls, tool results, approval state, sources/files/data/state, terminal status, usage, and model/run metadata. -- Finalization state may be split across hidden carrier events to stay under the serialized Matrix carrier budget. -- The final Matrix edit is sent only after all normal stream carriers and finalization carriers have been queued. It marks the anchor finalized and carries compact metadata plus Matrix-native preview HTML, not the full parts array. -- Do not require a final Matrix edit containing the full generated body or full AG-UI parts for over-64KB runs. -- The client is responsible for merging the hidden stream/finalization carriers into the anchor message. - -Final snapshot splitting algorithm: - -- Build one final AG-UI `UIMessage` in render order. -- Compact adjacent same-kind text fragments before packing final state when doing so does not lose detail. For example, five adjacent text-only chunks should become one final text part. -- Preserve semantic boundaries. Do not merge text across thinking, tool-call, tool-result, approval, source/file/data, or state parts. -- Start with a base `MESSAGES_SNAPSHOT` event containing the message identity and metadata: - - `id` - - `role` - - `metadata` - - `parts` -- The base event should include as many user-visible parts as fit under budget, in display order. Prioritize visible content over bulky diagnostics. -- If the next part would exceed budget, omit it from the base event and move it to a continuation event instead of duplicating metadata. -- Continuation events use a Beeper-owned AG-UI custom event: `CUSTOM` with `name: "com.beeper.ai.final-parts"`. -- Continuation event payload contains only relation/merge data and parts: - -```json -{ - "messageId": "message-id", - "runId": "run-id", - "threadId": "thread-id", - "partOffset": 3, - "parts": [] -} -``` - -- `partOffset` is the zero-based part index in the final message and is used for deterministic append/validation. Continuations must not repeat full message metadata. -- Desktop merges by applying the base snapshot, then inserting/appending continuation `parts` at `partOffset`. If the continuation part has the same semantic part identity as the part at that offset and only extends a split `content` field, concatenate the content instead of creating a second visible part. -- Split at the highest semantic level possible: carrier -> AG-UI event -> UIMessage parts -> large string fields. -- If a single text or thinking part is too large, split only its `content` at UTF-8 rune boundaries and use the same `partOffset` for the continuation slices so they concatenate back into one part. -- Do not split tool call, tool result, source/file/data, approval, or structured state objects unless there is an explicit field-level reassembly schema. Drop or truncate raw/debug/provider metadata before considering structured splitting. -- If one non-splittable structured part cannot fit under budget after raw/debug/provider metadata is removed, fail packing with a validation error instead of emitting an unmergeable partial object. -- Finalization carriers are sent before the final Matrix edit. The final edit must not race ahead of the final-parts carriers. - -Final Matrix preview: - -- Finalized messages must have Matrix-native preview content on the anchor edit: - - `body`: bounded plain text preview - - `format`: `org.matrix.custom.html` - - `formatted_body`: Matrix HTML generated by mautrix's Markdown renderer -- Unsupported clients are not a primary target, but the final edit should still be a coherent Matrix message preview for timeline/search/notifications. -- The full supported-client AI state comes from hidden carriers, not from the final edit body. - -Transport ordering: - -- For built dummybridge runs, send carrier events contiguously once the anchor Matrix event ID is known. Do not sleep between carriers based on synthetic generation timestamps. -- Demo/random delays may affect when runs are started or what timestamps are embedded in AG-UI events, but they must not delay replaying an already-built carrier sequence before finalization. -- Queue order for one run must be: anchor -> hidden normal carriers -> visible approval prompts/reaction options when applicable -> hidden approval response carriers when resolved -> hidden finalization carriers -> final anchor edit. -- If finalization carriers and final edit arrive in the same sync batch, PAS/Desktop must process carrier stream entries before using the edit to stop streaming. - -Replay/backfill: - -- Desktop must be able to reconstruct a run from persisted anchor plus persisted carrier messages, not only from live stream events. -- Replay must use the same merge key and ordering rules as live streaming. -- Backfilled carrier events should remain hidden after extraction. - -Redaction/delete behavior: - -- If a carrier is deleted/redacted, Desktop should recompute the visible anchor from remaining carrier events when possible. -- If recomputation leaves a sequence gap or invalid stream, mark the anchor message incomplete/failed. -- Approval prompt deletion/redaction should not delete or corrupt the AI run; it only removes that visible prompt. - -Ordering gap timeout: - -- Do not buffer missing `seq` gaps forever. -- If a gap remains unresolved past the configured timeout, mark the first visible anchor message incomplete/failed and keep carrier messages hidden. -- Late arrivals after failure should not create separate visible carrier messages. - -AG-UI state events: - -- Fully support `STATE_SNAPSHOT`, `STATE_DELTA`, and `MESSAGES_SNAPSHOT` as first-class AG-UI events. -- `STATE_SNAPSHOT` replaces the current run/application state view for the AI run. -- `STATE_DELTA` applies an incremental patch/update to that state. -- `MESSAGES_SNAPSHOT` carries a complete AG-UI `UIMessage[]` snapshot. -- Desktop must preserve and expose this state for AI rendering/devtools instead of dropping it. -- State events are allowed to affect rendered state when the renderer intentionally consumes them. -- State events must still obey the carrier budget and multi-carrier splitting rules. -- Do not duplicate the normal streaming path: text should still prefer text events, tool calls should still prefer tool events, and state events should be used when AG-UI state synchronization is the right primitive. - -Run errors: - -- `RUN_ERROR` may be run-scoped or session/thread-scoped. -- If `RUN_ERROR.runId` is present, fail only that run. -- If `RUN_ERROR.runId` is absent, fail active runs in that thread/session. -- Desktop should surface the failure on the first visible anchor message for each affected run and keep carrier messages hidden. - -First message metadata schema: - -- The first visible message must contain enough metadata for a compatible client to render run chrome, status, model info, usage, approvals, and non-part attachments without reading carrier event metadata. -- Do not put streamed UI parts, text chunks, thinking chunks, tool argument chunks, or tool output chunks in this metadata. -- AG-UI/TanStack-mirrored fields must use TanStack naming and value shapes: `threadId`, `runId`, `messageId`, `finishReason`, `promptTokens`, `completionTokens`, and `totalTokens`. -- Beeper-only fields should be grouped clearly under Beeper-owned names instead of changing AG-UI concepts. -- Suggested `com.beeper.ai.metadata` shape: - -```json -{ - "schema": "com.beeper.ai.run.v1", - "protocol": "ag-ui", - "threadId": "thread-id", - "runId": "run-id", - "messageId": "message-id", - "agent": { - "id": "ai", - "displayName": "AI" - }, - "model": "dummybridge/ag-ui", - "usage": { - "promptTokens": 0, - "completionTokens": 0, - "totalTokens": 0 - }, - "usageDetails": { - "reasoningTokens": 0, - "cachedInputTokens": 0 - }, - "status": { - "state": "streaming", - "finishReason": "stop", - "terminal": null, - "error": null - }, - "approvals": [ - { - "id": "approval-id", - "toolCallId": "tool-call-id", - "state": "requested", - "always": false, - "reason": "" - } - ], - "artifacts": { - "sources": [], - "documents": [], - "files": [] - }, - "data": {}, - "preview": { - "text": "bounded visible preview", - "truncated": true - } -} -``` - -- `model` is the AG-UI model identifier string. Do not add `modelInfo`; display/provider details should be derived from the model registry, agent profile, or bridge/network metadata instead of duplicated on every message. -- `usage` mirrors AG-UI `RUN_FINISHED.usage`. Extra usage fields belong in `usageDetails`. -- `finishReason` should use TanStack/AG-UI values: `stop`, `length`, `content_filter`, `tool_calls`, or `null`. Command aliases may accept hyphenated input, but emitted metadata should use AG-UI values. -- `usage` is token/usage metadata only. Do not add dollar cost fields unless the product explicitly decides to expose pricing. -- `artifacts` and `data` are for descriptors needed to render run-level UI outside the streamed parts. If an item is naturally a UI part, it should stay in the stream instead of being duplicated here. -- Final compact metadata may update `status`, `usage`, `approvals`, `artifacts`, `data`, and `preview`, but still must not embed full chunks/parts. - -## Desktop Work - -Update Desktop as part of parity because the new transport deliberately splits one stream across multiple Matrix events. - -Dependency: - -- Use `@tanstack/ai-react-ui` from the current Desktop checkout when it is already present. -- If it is absent, ask before adding it, running a package-manager install/update, or changing any manifest/lockfile. -- Do not hand-roll a parallel markdown renderer when TanStack's UI package already provides one. -- `@tanstack/ai-react-ui` `TextPart` renders Markdown with `react-markdown`, GFM tables/strikethrough via `remark-gfm`, sanitized HTML via `rehype-sanitize`, and code highlighting via `rehype-highlight`. -- Keep Beeper-specific shell/layout/actions in Desktop, but delegate TanStack text/thinking/tool/result part rendering to TanStack UI components or thin render props around them. - -TanStack ownership in Desktop: - -- Desktop already depends on `@tanstack/ai` and `@tanstack/ai-client`; use those packages instead of duplicating their concepts. -- Import `UIMessage`, `MessagePart`, `TextPart`, `ThinkingPart`, `ToolCallPart`, `ToolResultPart`, `ToolCallState`, and `ToolResultState` from TanStack packages. -- Use TanStack `StreamProcessor`/stream utilities where practical for applying AG-UI chunks into UI messages instead of maintaining a parallel Desktop-only stream reducer. -- Use TanStack `parsePartialJSON`/partial JSON utilities for streaming tool args instead of maintaining a separate parser. -- Use `@tanstack/ai-react-ui` for `ChatMessage`/part rendering, with Beeper render props only for product-specific chrome, approvals, and bridge actions. -- Delete or collapse Desktop-only normalized AI types that duplicate TanStack structures, such as separate text/reasoning/tool-call models, once the TanStack path can feed the UI directly. -- Keep Desktop-local types only for Beeper transport/persistence: Matrix event IDs, `target_event`, carrier visibility, `com.beeper.stream`, `com.beeper.ai.metadata`, and approval prompt Matrix metadata. - -Intentional AG-UI boundaries: - -- AG-UI owns semantic events, UI message parts, tool states, run input, and stream processing. -- Beeper owns transport: encrypted Matrix events, carrier hiding, target event mapping, replay from persisted Matrix history, and approval reaction cleanup. -- `com.beeper.ai.metadata` is Beeper message metadata and must not become a second UI message schema. It may store non-part run metadata, but streamed parts/chunks remain AG-UI. -- `target_event` is a Beeper transport pointer, not an AG-UI field. Keep it in the carrier envelope and Desktop stream routing, not in TanStack `UIMessage`. - -PAS sync: - -- In `src/pas-server/beeper/EventSyncContext.ts`, detect decrypted `m.room.message` events that contain stream delta content keys ending in `.deltas` or batched `updates`. -- Extract stream deltas from encrypted carrier events after decryption. -- Emit `STATE_SYNC message stream` updates instead of normal message upserts for carrier-only events. -- Mark carrier timeline events hidden after extraction so they do not show as chat bubbles. -- Keep the visible anchor event and approval prompt event as normal messages. -- Preserve `com.beeper.ai`, `com.beeper.stream`, and per-message profile behavior for anchor/final messages. - -Beeper client stream routing: - -- In `src/pas-server/beeper/BeeperClient.ts`, keep using the existing stream event path, but ensure multi-carrier events preserve `room_id`, carrier `event_id`, `target_event`, `threadId`, `runId`, `messageId`, and `seq`. - -Common types: - -- In `src/common/types/beeper.ts`, extend stream types to include AG-UI `threadId`, `runId`, and `messageId`, plus Beeper transport `target_event`. -- Keep support for both single-update `.deltas` and batched replay `updates`. - -Renderer store: - -- In `src/renderer/stores/AIChatsStore.ts`, merge by `{target_event, runId}` rather than only target message/run. -- Map carrier target events back to the visible anchor message. -- Continue buffering out-of-order `seq`. -- Treat stream carriers as dirtying and extending the first visible AI message only, never as separate visible messages. -- Merge streamed content and ordered UI parts into the first visible message's renderer state. -- Hide carrier messages after extracting deltas. -- Track approval prompts by approval ID and target tool call. -- Support multiple assistant messages/parts per run by indexing on AG-UI `messageId`, not assuming one text part per run. -- Support parallel tool calls by distinct tool call IDs and optional `index`. -- Preserve streamed partial tool arguments while parsing partial JSON best-effort; replace with finalized arguments when `TOOL_CALL_END` arrives. -- Accept tool output either on `TOOL_CALL_END.result` or as AG-UI `TOOL_CALL_RESULT`. -- Apply run-scoped versus thread-scoped `RUN_ERROR` behavior as described above. -- Reconstruct runs from persisted anchor plus carrier history during replay/backfill using the same code path as live streaming. - -UI message application: - -- In `src/renderer/ai/ui-message.ts`, apply AG-UI events into TanStack-shaped parts. -- Preserve ordered parts instead of collapsing everything by type. -- Render the resulting TanStack `UIMessage` with `@tanstack/ai-react-ui` instead of converting it into a separate Beeper-only part model. -- Type the state at the highest correct level. The renderer should accept a TanStack-shaped `UIMessage`/renderable message type and should not require whole-message `as any` assertions. -- Use narrow builders/guards for Beeper custom part variants instead of broad `MutableUIPart` assertions. If a part is not expressible as a TanStack part, keep the extension isolated behind a typed Beeper custom-part union and convert at the render boundary. -- Do not use assertions to bypass missing required fields. If TanStack requires a field, either populate it from AG-UI state or keep the part out of the TanStack render path until it has a real representation. -- Support compatibility input for current events while preferring new output shapes: - - text - - thinking/step - - tool-call - - tool-result - - state snapshot/state delta/messages snapshot - - source-url/source-document/file/custom data -- Map approval states to TanStack states: - - `approval-requested` - - `approval-responded` - - result `complete` - - result `error` - -Message types: - -- AI visible messages and approval prompts should use message types that render as bubbles where intended. -- Stream carrier events are `m.room.message`/`m.text` for compatibility, but should not render as bubbles after Desktop extraction. -- Avoid `m.notice` for visible AI chat content in AI-network rooms because Desktop hides AI `m.notice` events. - -## Approvals - -Approval requests remain separate visible Matrix events with reaction options. - -Generic reaction option shape: - -```go -type ReactionOption[T any] struct { - ID string - Label string - Values []string - Value T -} -``` - -`Values` is the complete set of strings that should match this option. Entries may be literal emoji (`👍`), symbolic reaction keys (`approval.allow_once`), short names (`allow`), or bridge-specific aliases. The helper owns normalization and matching; callers should pass strings and not branch on whether a value is an emoji or a key. - -Tool approval response shape: - -```go -type ToolApprovalResponse struct { - ID string - Approved bool - Always bool - Reason string - Fields map[string]any - Metadata map[string]any -} -``` - -`Always` supports allow-always style options without making that concept Matrix-specific. `Fields` is for flexible provider/bridge-specific approval data that should survive resolution but not force new top-level schema every time. - -Rules: - -- AG-UI stream emits a tool-call state transition to `approval-requested`. -- The tool-call part includes `approval: { id, needsApproval: true }`. -- Matrix reaction choices are transport metadata and must not be embedded into AG-UI events as the source of truth for reactions. -- Approval prompt events should relate to the first visible anchor message and include `threadId`, `runId`, `messageId`, `toolCallId`, and approval ID. -- Matrix approval event stores `com.beeper.ai.approval` with tool call ID, tool name, `threadId`, `runId`, `messageId`, expiration if any, and reaction options. -- Approval prompts are separate visible Matrix messages for actionability and reaction handling. -- Supported clients may render the same approval inline as a tool-call variant on the anchor message. To support that, duplicate semantic approval state into the AG-UI stream while keeping Matrix prompt/reaction metadata on the prompt event. -- On user reaction, the bridge resolves the option to a `ToolApprovalResponse`. -- After resolution, emit AG-UI state `approval-responded`. -- If approved, continue execution and emit tool result `complete` or `error`. -- If denied, emit a `tool-result` with `state: "error"` and structured reason `denied`; do not pretend the tool executed. -- Approval options should support flexible fields, including allow-once, allow-always, deny, reason, and provider/bridge-specific metadata. -- Keep the user's selected Matrix reaction event exactly as the visible user choice, regardless of whether it matched by emoji or symbolic key. -- Remove bridge-posted placeholder option reactions and non-selected option reactions. -- The cleanup helper should return the selected option, selected reaction event ID if known, and a list of bridge-posted reaction event IDs to remove. Actual Matrix redaction/deletion remains the bridge adapter's job. -- Programmatic approval and Matrix reaction approval must share the same resolver and produce the same stream events. - -Custom events: - -- Support AG-UI `CUSTOM` events. -- Use built-in/custom names from TanStack when they exist, such as `approval-requested`. -- Beeper-specific custom events must use a clear namespace such as `com.beeper.*`. -- Do not add random one-off custom names when an AG-UI lifecycle, tool, state, or message event already models the behavior. - -## Decisions And Remaining Behavior - -Settled decisions: - -- `pkg/ag-ui` is the Go source of truth for AG-UI concepts. Other Go packages import it instead of redefining parallel event, message, tool, or approval types. -- Desktop uses TanStack types directly for AG-UI/UI message concepts wherever possible. Desktop-local types describe Beeper transport and persistence, not a second AI message model. -- Long runs never require a final full-text Matrix edit. The final edit stores compact identity/terminal metadata and Matrix HTML preview; supported clients reconstruct complete UI state from hidden carriers. -- Final AG-UI state is complete and may be split into hidden finalization carriers. The final anchor edit remains compact and must not embed the full parts/chunks array. -- The final split format is base `MESSAGES_SNAPSHOT` plus `com.beeper.ai.final-parts` continuations with relation data and omitted parts only. -- Normal stream chunks do not include `seqTotal` or any total-count field. -- First visible message metadata owns non-part run metadata: IDs, model, usage, finish/terminal state, approval summary, and source/file/data descriptors that are metadata. It does not store streamed text chunks, thinking chunks, tool args, tool results, or full parts. -- Use mautrix in `pkg/ai-stream/matrix`. Keep bridgev2-specific queue/database/redaction behavior outside the pure AG-UI package. - -Behavior still requiring implementation or verification: - -- Dropped or invalid carriers: Desktop should mark the anchor incomplete/failed and keep carriers hidden. Do not show carrier messages as fallback bubbles. -- Missing `seq` gaps: timeout must stop infinite buffering, fail or mark incomplete on the anchor, and keep later stray carrier events hidden. -- Carrier delete/redaction: recompute from remaining carriers when possible; otherwise mark incomplete/failed. -- Replay/backfill: reconstruct the same visible AI run from persisted anchor plus carriers as live streaming. -- Approval idempotency: first valid approval resolution should win. Later reactions/programmatic responses should not re-run the tool and may be cleaned up as stale. -- Allow-always: support the field generically in approval options/responses, but dummybridge should not persist cross-run allow-always state until there is a real product storage target. -- TanStack drift: before future dependency upgrades, re-open current TanStack docs/source and update this contract deliberately. Do not silently adapt by assertions. - -## Tests - -Dummybridge Go tests: - -- Port archived parser tests. -- Verify help aliases. -- Verify command guide includes all commands. -- Verify conflicting terminal options are rejected. -- Verify invalid random profile is rejected. -- Verify oversized option inputs are rejected. -- Verify markdown-rich text generation is deterministic by seed and varied across calls. -- Verify table/link/list/code/quote markdown signals. -- Verify `stream-lorem` emits thinking, steps, text, sources, documents, files, persistent data, and excludes transient data from final snapshot. -- Verify `stream-tools` covers success, failure, approval, denial, delta input, input error, preliminary output, and provider-executed tools. -- Verify random streams finish and respect duration. -- Verify chaos streams start multiple runs with stagger and max-actions. -- Verify error and abort terminal states. - -`pkg/ag-ui` tests: - -- Validate all current AG-UI lifecycle event builders: `RUN_STARTED`, `RUN_FINISHED`, `RUN_ERROR`, `TEXT_MESSAGE_START`, `TEXT_MESSAGE_CONTENT`, `TEXT_MESSAGE_END`, `TOOL_CALL_START`, `TOOL_CALL_ARGS`, `TOOL_CALL_END`, `TOOL_CALL_RESULT`, `STEP_STARTED`, `STEP_FINISHED`, `STATE_SNAPSHOT`, `STATE_DELTA`, `MESSAGES_SNAPSHOT`, and `CUSTOM`. -- Validate event builders and required IDs. -- Validate `RunAgentInput`. -- Validate `UIMessage` ordered part shape. -- Validate tool-call and tool-result states against TanStack values. -- Validate approval request and response shapes. -- Validate step/thinking events. -- Validate `STATE_SNAPSHOT`, `STATE_DELTA`, and `MESSAGES_SNAPSHOT` event shapes. -- Validate every emitted event has `timestamp`. -- Validate `rawEvent` is optional and bounded/truncated/omitted before exceeding carrier limits. -- Validate `TOOL_CALL_START.index`. -- Validate partial JSON `TOOL_CALL_ARGS`. -- Validate `TOOL_CALL_END` with and without result. -- Validate `TOOL_CALL_RESULT` creates/updates TanStack `tool-result` parts. -- Validate run-scoped and thread/session-scoped `RUN_ERROR`. -- Validate multiple assistant `messageId`s per run. -- Reject legacy/non-AG-UI chunk shapes. - -`pkg/ai-stream` tests: - -- Verify ordered run writer output. -- Verify normal stream envelopes do not contain finalization totals such as `seqTotal`. -- Verify no per-delta accumulated full text. -- Verify final accumulator is only used at finalization. -- Verify UTF-8 splitting. -- Verify carrier packer respects the serialized JSON carrier budget. -- Verify stream reconstruction from carriers. -- Verify finalization carriers split a complete final UI message into a base snapshot plus continuation parts without repeating metadata. -- Verify finalization continuations merge deterministically by `messageId`, `runId`, and `partOffset`. -- Verify oversized text/thinking final parts split at UTF-8 boundaries and reassemble exactly. -- Verify oversized raw/debug/provider metadata is truncated or omitted before splitting structured tool/data parts. -- Verify duplicate/stale/out-of-order `seq` behavior. -- Verify missing `seq` gap timeout marks the anchor incomplete/failed. -- Verify carrier delete/redaction recomputes or marks the anchor incomplete/failed. -- Verify approval reaction resolver keeps the selected value and identifies removals. - -Over-64KB tests: - -- Generate at least 70KiB of output. -- Assert every carrier's serialized content is at or below the carrier budget. -- Assert at least two carrier events are emitted. -- Assert later carriers have no preview body or only minimal body. -- Assert reconstruction from deltas exactly equals generated output. -- Assert no final full-body edit is required to display the complete stream. -- Assert final snapshot state is complete even when split across finalization carriers. -- Assert final edit contains Matrix `formatted_body` generated by mautrix Markdown rendering and does not contain the full parts array. - -Desktop tests: - -- PAS extracts `.deltas` from decrypted carrier events. -- Carrier-only events are hidden and do not render as chat bubbles. -- Single-update and batched `updates` formats still work. -- Multi-carrier stream merges into the visible anchor message. -- Finalization base snapshot plus `com.beeper.ai.final-parts` continuations merge into one final `UIMessage`. -- Final edit arriving after carriers finalizes the existing anchor without creating a second message or flickering back to preview-only content. -- If final edit and stream/finalization carriers arrive in one sync batch, carriers are applied before streaming is stopped. -- Out-of-order `seq` buffering works. -- Duplicate/stale `seq` handling works. -- TanStack-shaped text/thinking/tool/result parts render through the AI message view. -- State snapshot, state delta, and messages snapshot events are preserved and exposed to rendering/devtools. -- Approval prompt indexing works from both visible prompt metadata and stream state. -- Approval response transitions resolve approval state. -- Parallel tool calls render/merge by distinct tool call IDs and optional indexes. -- Partial JSON tool args remain visible while streaming and finalize cleanly. -- `TOOL_CALL_END.result` renders as a completed tool result. -- `RUN_ERROR` with `runId` fails only that run; `RUN_ERROR` without `runId` fails active runs in the thread. -- Multiple assistant `messageId`s in one run render in order. -- Replay/backfill reconstructs the same visible run from persisted anchor plus carrier history as live streaming. -- Deleted/redacted carriers keep carrier bubbles hidden and mark/recompute the anchor correctly. -- Ordering gaps time out instead of buffering forever. -- Over-64KB carrier sequence reconstructs into one AI message. - -Commands to run: - -- In dummybridge: `go test -mod=readonly ./...` -- In Desktop, if `@tanstack/ai-react-ui` is not already present: ask before running any package-manager install/update command or changing any lockfile. -- In Desktop: run the existing focused test commands for touched files. At minimum cover `ai-common`, `ui-message`, `AIChatsStore`, `EventSyncContext`, and stream mapper tests. -- In Desktop: run typecheck after the focused tests. If the full repo typecheck is already failing for unrelated reasons, record the unrelated failures and separately prove touched AI files are type-clean. - -Verification status to track in the PR or completion note: - -- Dummybridge unit tests: command, date, result. -- Desktop focused tests: command, date, result. -- Desktop typecheck: command, date, result, and whether failures touch AI files. -- Source scan: prove no runtime source emits `seqTotal`. -- Over-64KB live smoke: one visible AI anchor, hidden carriers, final Matrix HTML preview, complete reconstructed supported-client state. -- Approval live smoke: visible approval prompt, selected reaction preserved, stale bridge option reactions removed, final anchor edit after response carriers. -- Random/chaos live smoke: no carrier bubbles, no stuck streaming state, no flicker to preview-only content after final edit. -- Replay/backfill smoke: persisted history reconstructs the same visible run after restart/reload. -- Redaction/gap smoke or unit coverage: carriers stay hidden and anchor becomes recomputed or incomplete/failed. - -## Live Smoke Testing - -Use bridgev2 and Desktop API, not raw Matrix sends, for end-to-end checks. - -Recommended smoke cases: - -- Create/login a QA account using the established `qatest+@beeper.com` pattern and fixed OTP only if a fresh account is needed. -- Create or reuse an AI DM through bridge-manager/Desktop API. -- Send `help` and confirm the command guide appears as a normal AI bubble. -- Send `stream-lorem 70000 --chunk-chars=512 --seed=7` and confirm Desktop shows one streaming AI message, not many carrier bubbles. -- Send `stream-tools 200 shell#approval --seed=3` and confirm the approval prompt appears separately with reaction options. -- React approve and confirm the selected emoji remains while other bridge options disappear and the tool completes. -- React deny and confirm the tool is cancelled/denied and does not execute. -- Send `stream-random 5 --actions=8 --allow-approval --seed=9`. -- Send `stream-chaos 3 5 --max-actions=5 --seed=11`. - -Acceptance criteria: - -- All carrier events are encrypted in E2EE rooms. -- No plaintext raw Matrix sends are used. -- Visible AI output uses bubble-rendering message types. -- Carrier events do not show as separate bubbles. -- Streaming remains incremental. -- Over-64KB output reconstructs correctly. -- Finalized over-64KB runs still have one visible anchor message, complete supported-client AG-UI state, and bounded Matrix HTML preview on the final edit. -- Approvals work from Matrix reactions and programmatic/TanStack-shaped responses. -- The selected approval emoji is kept and non-selected placeholder options are removed. diff --git a/pkg/ag-ui/events.go b/pkg/ag-ui/events.go index d3ea264..5429877 100644 --- a/pkg/ag-ui/events.go +++ b/pkg/ag-ui/events.go @@ -47,6 +47,8 @@ const ( ToolResultStateStreaming = "streaming" ToolResultStateComplete = "complete" ToolResultStateError = "error" + PartStateStreaming = "streaming" + PartStateDone = "done" ApprovalCustomRequested = "approval-requested" ApprovalCustomResponded = "approval-responded" FinishReasonStop = "stop" diff --git a/pkg/ai-stream/bridgev2/events.go b/pkg/ai-stream/bridgev2/events.go index 5f5c7b9..311b874 100644 --- a/pkg/ai-stream/bridgev2/events.go +++ b/pkg/ai-stream/bridgev2/events.go @@ -14,78 +14,58 @@ import ( "maunium.net/go/mautrix/id" ) +func eventMeta(eventType bridgev2.RemoteEventType, portalKey networkid.PortalKey, sender networkid.UserID, timestamp time.Time) simplevent.EventMeta { + return simplevent.EventMeta{ + Type: eventType, + PortalKey: portalKey, + Sender: bridgev2.EventSender{Sender: sender}, + Timestamp: timestamp, + StreamOrder: timestamp.UnixNano(), + } +} + +func messagePart(content *event.MessageEventContent, extra map[string]any, dbMetadata map[string]any) *bridgev2.ConvertedMessagePart { + return &bridgev2.ConvertedMessagePart{ + ID: networkid.PartID("0"), + Type: event.EventMessage, + Content: content, + Extra: extra, + DBMetadata: dbMetadata, + } +} + func Anchor(portalKey networkid.PortalKey, sender networkid.UserID, run aistream.Run, timestamp time.Time) *simplevent.PreConvertedMessage { content, extra := aimatrix.AnchorContent(run) return &simplevent.PreConvertedMessage{ - EventMeta: simplevent.EventMeta{ - Type: bridgev2.RemoteEventMessage, - PortalKey: portalKey, - Sender: bridgev2.EventSender{Sender: sender}, - Timestamp: timestamp, - StreamOrder: timestamp.UnixNano(), - }, - Data: &bridgev2.ConvertedMessage{Parts: []*bridgev2.ConvertedMessagePart{{ - ID: networkid.PartID("0"), - Type: event.EventMessage, - Content: content, - Extra: extra, - }}}, - ID: networkid.MessageID(run.MessageID), + EventMeta: eventMeta(bridgev2.RemoteEventMessage, portalKey, sender, timestamp), + Data: &bridgev2.ConvertedMessage{Parts: []*bridgev2.ConvertedMessagePart{messagePart(content, extra, nil)}}, + ID: networkid.MessageID(run.MessageID), } } func Carrier(portalKey networkid.PortalKey, sender networkid.UserID, run aistream.Run, carrier aistream.Carrier, targetEventID id.EventID, index int, timestamp time.Time) *simplevent.PreConvertedMessage { content, extra := aimatrix.CarrierContent(carrier, targetEventID) return &simplevent.PreConvertedMessage{ - EventMeta: simplevent.EventMeta{ - Type: bridgev2.RemoteEventMessage, - PortalKey: portalKey, - Sender: bridgev2.EventSender{Sender: sender}, - Timestamp: timestamp, - StreamOrder: timestamp.UnixNano(), - }, - Data: &bridgev2.ConvertedMessage{Parts: []*bridgev2.ConvertedMessagePart{{ - ID: networkid.PartID("0"), - Type: event.EventMessage, - Content: content, - Extra: extra, - }}}, - ID: networkid.MessageID(aistream.StreamTxnID(run.RunID, index)), + EventMeta: eventMeta(bridgev2.RemoteEventMessage, portalKey, sender, timestamp), + Data: &bridgev2.ConvertedMessage{Parts: []*bridgev2.ConvertedMessagePart{messagePart(content, extra, nil)}}, + ID: networkid.MessageID(aistream.StreamTxnID(run.RunID, index)), } } func ApprovalPrompt(portalKey networkid.PortalKey, sender networkid.UserID, ctx aistream.ApprovalContext, timestamp time.Time) *simplevent.PreConvertedMessage { content, extra := aimatrix.ApprovalContent(ctx, aistream.DefaultApprovalOptions(ctx.ID)) return &simplevent.PreConvertedMessage{ - EventMeta: simplevent.EventMeta{ - Type: bridgev2.RemoteEventMessage, - PortalKey: portalKey, - Sender: bridgev2.EventSender{Sender: sender}, - Timestamp: timestamp, - StreamOrder: timestamp.UnixNano(), - }, - Data: &bridgev2.ConvertedMessage{Parts: []*bridgev2.ConvertedMessagePart{{ - ID: networkid.PartID("0"), - Type: event.EventMessage, - Content: content, - Extra: extra, - DBMetadata: map[string]any{ - "com.beeper.ai.approval": ctx, - }, - }}}, + EventMeta: eventMeta(bridgev2.RemoteEventMessage, portalKey, sender, timestamp), + Data: &bridgev2.ConvertedMessage{Parts: []*bridgev2.ConvertedMessagePart{ + messagePart(content, extra, map[string]any{"com.beeper.ai.approval": ctx}), + }}, ID: networkid.MessageID(ctx.ID), } } func ApprovalOptionReaction[T any](portalKey networkid.PortalKey, sender networkid.UserID, ctx aistream.ApprovalContext, option aistream.ReactionOption[T], timestamp time.Time) *simplevent.Reaction { return &simplevent.Reaction{ - EventMeta: simplevent.EventMeta{ - Type: bridgev2.RemoteEventReaction, - PortalKey: portalKey, - Sender: bridgev2.EventSender{Sender: sender}, - Timestamp: timestamp, - StreamOrder: timestamp.UnixNano(), - }, + EventMeta: eventMeta(bridgev2.RemoteEventReaction, portalKey, sender, timestamp), TargetMessage: networkid.MessageID(ctx.ID), EmojiID: networkid.EmojiID(option.ID), Emoji: option.Values[0], @@ -103,13 +83,7 @@ func ApprovalOptionReaction[T any](portalKey networkid.PortalKey, sender network func FinalMetadataEdit(portalKey networkid.PortalKey, sender networkid.UserID, messageID networkid.MessageID, run aistream.Run, timestamp time.Time) *simplevent.Message[*aistream.Run] { finalContent, finalExtra := aimatrix.AnchorContent(run) return &simplevent.Message[*aistream.Run]{ - EventMeta: simplevent.EventMeta{ - Type: bridgev2.RemoteEventEdit, - PortalKey: portalKey, - Sender: bridgev2.EventSender{Sender: sender}, - Timestamp: timestamp, - StreamOrder: timestamp.UnixNano(), - }, + EventMeta: eventMeta(bridgev2.RemoteEventEdit, portalKey, sender, timestamp), Data: &run, ID: messageID, TargetMessage: messageID, diff --git a/pkg/ai-stream/pack.go b/pkg/ai-stream/pack.go index d94168a..e54cb15 100644 --- a/pkg/ai-stream/pack.go +++ b/pkg/ai-stream/pack.go @@ -67,6 +67,8 @@ func PackRunFromSeq(run Run, targetEventID string, budget int, startSeq int) ([] } var carriers []Carrier var current Carrier + currentSize := 0 + emptyCarrierOverhead := JSONSize(CarrierContent([]Envelope{})) seq := startSeq for _, original := range run.Events { for _, part := range splitEventForBudget(original, budget) { @@ -74,14 +76,25 @@ func PackRunFromSeq(run Run, targetEventID string, budget int, startSeq int) ([] if err != nil { return nil, err } - single := CarrierContent([]Envelope{env}) - if JSONSize(single) > budget { + envSize := JSONSize(env) + if emptyCarrierOverhead+envSize > budget { return nil, fmt.Errorf("stream envelope %d exceeds %d byte budget", seq, budget) } - candidate := append(append([]Envelope{}, current.Envelopes...), env) - if len(current.Envelopes) > 0 && JSONSize(CarrierContent(candidate)) > budget { + // +1 for the comma separator between envelopes in the JSON array. + addedSize := envSize + if len(current.Envelopes) > 0 { + addedSize++ + } + if len(current.Envelopes) > 0 && currentSize+addedSize > budget { carriers = append(carriers, current) current = Carrier{} + currentSize = 0 + addedSize = envSize + } + if len(current.Envelopes) == 0 { + currentSize = emptyCarrierOverhead + envSize + } else { + currentSize += addedSize } current.Envelopes = append(current.Envelopes, env) seq++ @@ -93,22 +106,6 @@ func PackRunFromSeq(run Run, targetEventID string, budget int, startSeq int) ([] return carriers, nil } -func eventTimestampMillis(evt agui.Event) int64 { - switch value := evt["timestamp"].(type) { - case int64: - return value - case int: - return int64(value) - case float64: - return int64(value) - case json.Number: - n, _ := value.Int64() - return n - default: - return 0 - } -} - func NextSeq(carriers []Carrier) int { next := 1 for _, carrier := range carriers { diff --git a/pkg/ai-stream/run.go b/pkg/ai-stream/run.go index 0dbc351..1225260 100644 --- a/pkg/ai-stream/run.go +++ b/pkg/ai-stream/run.go @@ -365,6 +365,7 @@ func (t Run) FinalUIMessageSnapshot(textBudget int) agui.UIMessage { } var textPart agui.MessagePart var thinkingPart agui.MessagePart + var textContent, thinkingContent strings.Builder toolParts := map[string]agui.MessagePart{} toolResultParts := map[string]agui.MessagePart{} approvalByID := map[string]any{} @@ -380,12 +381,12 @@ func (t Run) FinalUIMessageSnapshot(textBudget int) agui.UIMessage { continue } if textPart == nil { - textPart = appendPart(agui.MessagePart{"type": "text", "content": "", "state": "streaming"}) + textPart = appendPart(agui.MessagePart{"type": "text", "content": "", "state": agui.PartStateStreaming}) } - textPart["content"] = asString(textPart["content"]) + delta + textContent.WriteString(delta) case agui.EventTextMessageEnd: if textPart != nil { - textPart["state"] = "done" + textPart["state"] = agui.PartStateDone } case agui.EventReasoningMsgCont: delta, _ := evt["delta"].(string) @@ -393,12 +394,12 @@ func (t Run) FinalUIMessageSnapshot(textBudget int) agui.UIMessage { continue } if thinkingPart == nil { - thinkingPart = appendPart(agui.MessagePart{"type": "thinking", "content": "", "state": "streaming"}) + thinkingPart = appendPart(agui.MessagePart{"type": "thinking", "content": "", "state": agui.PartStateStreaming}) } - thinkingPart["content"] = asString(thinkingPart["content"]) + delta + thinkingContent.WriteString(delta) case agui.EventReasoningMsgEnd: if thinkingPart != nil { - thinkingPart["state"] = "done" + thinkingPart["state"] = agui.PartStateDone } case agui.EventToolCallStart: toolCallID, _ := evt["toolCallId"].(string) @@ -495,6 +496,12 @@ func (t Run) FinalUIMessageSnapshot(textBudget int) agui.UIMessage { } } } + if textPart != nil { + textPart["content"] = textContent.String() + } + if thinkingPart != nil { + thinkingPart["content"] = thinkingContent.String() + } compactTextPart(textPart, textBudget) compactTextPart(thinkingPart, textBudget) return message @@ -519,7 +526,7 @@ func compactTextPart(part agui.MessagePart, budget int) { part["providerMetadata"] = map[string]any{"truncated": true} } if part["state"] == "" { - part["state"] = "done" + part["state"] = agui.PartStateDone } } diff --git a/pkg/connector/ai_runtime.go b/pkg/connector/ai_runtime.go index 47e1f39..03b3f7d 100644 --- a/pkg/connector/ai_runtime.go +++ b/pkg/connector/ai_runtime.go @@ -143,25 +143,6 @@ type aiRunPlan struct { Delay time.Duration } -func defaultAIRuntime() aiRuntime { - return aiRuntime{ - now: time.Now, - sleep: func(ctx context.Context, delay time.Duration) error { - if delay <= 0 { - return nil - } - timer := time.NewTimer(delay) - defer timer.Stop() - select { - case <-timer.C: - return nil - case <-ctx.Done(): - return ctx.Err() - } - }, - } -} - func virtualAIRuntime(now time.Time) aiRuntime { current := now return aiRuntime{ @@ -288,12 +269,7 @@ func parseCommand(input string) (*parsedCommand, error) { return &parsedCommand{Name: "help"}, nil } switch strings.ToLower(tokens[0]) { - case "help", "/help", "!help": - return &parsedCommand{Name: "help"}, nil - case "dummybridge": - if len(tokens) > 1 && strings.EqualFold(tokens[1], "help") { - return &parsedCommand{Name: "help"}, nil - } + case "help", "/help", "!help", "dummybridge": return &parsedCommand{Name: "help"}, nil case "stream-lorem": cmd, err := parseLoremCommand(tokens[1:]) @@ -743,6 +719,7 @@ func (r aiRunner) runRandom(ctx context.Context, w *aistream.Writer, cmd randomC } stepOpen := false stepName := "" + actionOptions, actionWeightTotal := buildRandomActionOptions(cmd) for action := range cmd.Actions { if !deadline.IsZero() && !r.runtime.now().Before(deadline) { break @@ -759,7 +736,7 @@ func (r aiRunner) runRandom(ctx context.Context, w *aistream.Writer, cmd randomC break } } - switch chooseRandomAction(cmd, rng) { + switch pickWeighted(actionOptions, actionWeightTotal, rng) { case randomActionText: for _, chunk := range chunkText(buildDemoVisibleText(40+rng.Intn(160), rand.New(rand.NewSource(rng.Int63()))), rng, defaultChunkMin, defaultChunkMax) { w.Text(chunk) @@ -944,7 +921,7 @@ func statePatch(values map[string]any) []map[string]any { return patch } -func chooseRandomAction(cmd randomCommand, rng *rand.Rand) string { +func buildRandomActionOptions(cmd randomCommand) ([]randomActionOption, int) { options := []randomActionOption{ {randomActionText, 6}, {randomActionThinking, 4}, @@ -987,6 +964,13 @@ func chooseRandomAction(cmd randomCommand, rng *rand.Rand) string { for _, option := range options { total += option.weight } + return options, total +} + +func pickWeighted(options []randomActionOption, total int, rng *rand.Rand) string { + if total <= 0 || len(options) == 0 { + return randomActionText + } pick := rng.Intn(total) for _, option := range options { if pick < option.weight { diff --git a/pkg/connector/ai_runtime_test.go b/pkg/connector/ai_runtime_test.go index 97feefe..6de6b66 100644 --- a/pkg/connector/ai_runtime_test.go +++ b/pkg/connector/ai_runtime_test.go @@ -439,7 +439,8 @@ func TestRandomProfilesCoverToolsArtifactsAndTransientData(t *testing.T) { seen := map[string]bool{} rng := rand.New(rand.NewSource(4)) for range 400 { - seen[chooseRandomAction(cmd, rng)] = true + options, total := buildRandomActionOptions(cmd) + seen[pickWeighted(options, total, rng)] = true } for _, action := range []string{randomActionTool, randomActionToolFail, randomActionToolDeny, randomActionToolApproval} { if !seen[action] { @@ -451,7 +452,8 @@ func TestRandomProfilesCoverToolsArtifactsAndTransientData(t *testing.T) { seen = map[string]bool{} rng = rand.New(rand.NewSource(8)) for range 400 { - seen[chooseRandomAction(cmd, rng)] = true + options, total := buildRandomActionOptions(cmd) + seen[pickWeighted(options, total, rng)] = true } for _, action := range []string{randomActionSource, randomActionDocument, randomActionFile, randomActionMetadata, randomActionData, randomActionDataTransient} { if !seen[action] { diff --git a/pkg/connector/client.go b/pkg/connector/client.go index 7f0a41b..e7933e6 100644 --- a/pkg/connector/client.go +++ b/pkg/connector/client.go @@ -257,7 +257,7 @@ func (dc *DummyClient) PreHandleMatrixReaction(_ context.Context, msg *bridgev2. if dc != nil && dc.UserLogin != nil { senderID = networkid.UserID(dc.UserLogin.ID) } - key := normalizeApprovalReaction(msg.Content.RelatesTo.Key) + key := aistream.NormalizeReaction(msg.Content.RelatesTo.Key) return bridgev2.MatrixReactionPreResponse{ SenderID: senderID, EmojiID: networkid.EmojiID(key), @@ -274,7 +274,7 @@ func (dc *DummyClient) HandleMatrixReaction(ctx context.Context, msg *bridgev2.M if !strings.HasPrefix(approvalID, "approval-") { return &database.Reaction{}, nil } - reaction := normalizeApprovalReaction(msg.Content.RelatesTo.Key) + reaction := aistream.NormalizeReaction(msg.Content.RelatesTo.Key) selected, ok := aistream.ResolveReaction(aistream.DefaultApprovalOptions(approvalID), reaction) if !ok { return &database.Reaction{}, nil @@ -376,10 +376,6 @@ func (dc *DummyClient) HandleMatrixReactionRemove(ctx context.Context, msg *brid return nil } -func normalizeApprovalReaction(reaction string) string { - return strings.TrimSpace(strings.ReplaceAll(reaction, "\ufe0f", "")) -} - func getTransactionID(msg *bridgev2.MatrixMessage) networkid.TransactionID { if msg.Event != nil && msg.Event.Unsigned.TransactionID != "" { return networkid.TransactionID(msg.Event.Unsigned.TransactionID) @@ -469,7 +465,11 @@ func (dc *DummyClient) queueAIResponse(ctx context.Context, portal *bridgev2.Por now := time.Now() runID := "run-" + string(randomMessageID()) - plans, err := buildAIRunPlans(ctx, runID, string(portal.ID), inboundBody(inbound), now) + var body string + if inbound != nil { + body = inbound.Body + } + plans, err := buildAIRunPlans(ctx, runID, string(portal.ID), body, now) if err != nil { log.Warn().Err(err).Msg("Failed to build AI runs") return @@ -547,21 +547,21 @@ func (dc *DummyClient) waitForMessageMXID( if dc.UserLogin.ID != "" && dc.UserLogin.ID != portal.Receiver { receivers = append(receivers, dc.UserLogin.ID) } - ticker := time.NewTicker(50 * time.Millisecond) + ticker := time.NewTicker(100 * time.Millisecond) defer ticker.Stop() for ctx.Err() == nil { + select { + case <-ctx.Done(): + return "" + case <-ticker.C: + } for _, receiver := range receivers { mxid := dc.lookupMessageMXID(ctx, receiver, messageID) if mxid != "" { return mxid } } - select { - case <-ctx.Done(): - return "" - case <-ticker.C: - } } return "" } @@ -732,13 +732,6 @@ func (dc *DummyClient) queueAIRunFinalMetadata(portal *bridgev2.Portal, messageI dc.UserLogin.QueueRemoteEvent(aibridgev2.FinalMetadataEdit(portal.PortalKey, aiGhostID, messageID, run, time.Now())) } -func inboundBody(content *event.MessageEventContent) string { - if content == nil { - return "" - } - return content.Body -} - func (dc *DummyClient) HandleMatrixDeleteChat(ctx context.Context, msg *bridgev2.MatrixDeleteChat) error { // bridgev2 will delete the portal + Matrix room after this returns nil. // For dummybridge, there's no separate remote-side deletion to do. From 9962a2c7d848deb2a179487a19953222c28a8f96 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Tue, 19 May 2026 22:15:11 +0200 Subject: [PATCH 04/46] Escape AI anchor fallback HTML --- pkg/ai-stream/matrix/content.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pkg/ai-stream/matrix/content.go b/pkg/ai-stream/matrix/content.go index bb86667..d211893 100644 --- a/pkg/ai-stream/matrix/content.go +++ b/pkg/ai-stream/matrix/content.go @@ -2,6 +2,7 @@ package matrix import ( "fmt" + "html" "github.com/beeper/dummybridge/pkg/ag-ui" "github.com/beeper/dummybridge/pkg/ai-stream" @@ -18,7 +19,7 @@ func AnchorContent(run aistream.Run) (*event.MessageEventContent, map[string]any rendered := format.RenderMarkdown(body, true, false) if rendered.Format != event.FormatHTML { rendered.Format = event.FormatHTML - rendered.FormattedBody = rendered.Body + rendered.FormattedBody = html.EscapeString(rendered.Body) } content := &rendered content.BeeperPerMessageProfile = &event.BeeperPerMessageProfile{ From 293ceefb42cad7b469f7b3b9b3ed6d90448c9b4a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Tue, 19 May 2026 22:15:20 +0200 Subject: [PATCH 05/46] Handle raw event marshal errors --- pkg/ai-stream/pack.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/pkg/ai-stream/pack.go b/pkg/ai-stream/pack.go index e54cb15..7b02c71 100644 --- a/pkg/ai-stream/pack.go +++ b/pkg/ai-stream/pack.go @@ -302,8 +302,11 @@ func sanitizeRawEvent(evt agui.Event, budget int) agui.Event { return cp } raw, err := json.Marshal(cp["rawEvent"]) - if err != nil || len(raw) > 2048 { - cp["rawEvent"] = string(raw[:min(len(raw), 2048)]) + if err != nil { + delete(cp, "rawEvent") + cp["rawEventTruncated"] = true + } else if len(raw) > 2048 { + cp["rawEvent"] = string(raw[:2048]) cp["rawEventTruncated"] = true } if JSONSize(cp) > budget { From 2e6431441b80c67ae806ce15cef7a9fe7f9ed171 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Tue, 19 May 2026 22:15:30 +0200 Subject: [PATCH 06/46] Guard empty approval reaction values --- pkg/ai-stream/bridgev2/events.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pkg/ai-stream/bridgev2/events.go b/pkg/ai-stream/bridgev2/events.go index 311b874..5381625 100644 --- a/pkg/ai-stream/bridgev2/events.go +++ b/pkg/ai-stream/bridgev2/events.go @@ -64,11 +64,15 @@ func ApprovalPrompt(portalKey networkid.PortalKey, sender networkid.UserID, ctx } func ApprovalOptionReaction[T any](portalKey networkid.PortalKey, sender networkid.UserID, ctx aistream.ApprovalContext, option aistream.ReactionOption[T], timestamp time.Time) *simplevent.Reaction { + emoji := option.ID + if len(option.Values) > 0 { + emoji = option.Values[0] + } return &simplevent.Reaction{ EventMeta: eventMeta(bridgev2.RemoteEventReaction, portalKey, sender, timestamp), TargetMessage: networkid.MessageID(ctx.ID), EmojiID: networkid.EmojiID(option.ID), - Emoji: option.Values[0], + Emoji: emoji, ExtraContent: map[string]any{ "com.beeper.ai.approval_option": map[string]any{ "approvalId": ctx.ID, From c42e52919cfd3a1772c476ab246814dc708eecf2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Tue, 19 May 2026 22:15:52 +0200 Subject: [PATCH 07/46] Propagate random tool errors --- pkg/connector/ai_runtime.go | 27 +++++++++++++++++++++++---- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/pkg/connector/ai_runtime.go b/pkg/connector/ai_runtime.go index 03b3f7d..ffa48b7 100644 --- a/pkg/connector/ai_runtime.go +++ b/pkg/connector/ai_runtime.go @@ -720,6 +720,17 @@ func (r aiRunner) runRandom(ctx context.Context, w *aistream.Writer, cmd randomC stepOpen := false stepName := "" actionOptions, actionWeightTotal := buildRandomActionOptions(cmd) + handleTool := func(spec toolSpec) error { + if err := r.runToolSpec(ctx, w, spec, rng, defaultCommonOptions()); err != nil { + if errors.Is(err, errApprovalRequested) && stepOpen { + w.StepFinish(stepName) + stepOpen = false + stepName = "" + } + return err + } + return nil + } for action := range cmd.Actions { if !deadline.IsZero() && !r.runtime.now().Before(deadline) { break @@ -754,13 +765,21 @@ func (r aiRunner) runRandom(ctx context.Context, w *aistream.Writer, cmd randomC stepOpen = true } case randomActionTool: - _ = r.runToolSpec(ctx, w, toolSpec{Name: randomToolName(rng), SequenceIndex: action + 1}, rng, defaultCommonOptions()) + if err := handleTool(toolSpec{Name: randomToolName(rng), SequenceIndex: action + 1}); err != nil { + return err + } case randomActionToolFail: - _ = r.runToolSpec(ctx, w, toolSpec{Name: randomToolName(rng), Fail: true, SequenceIndex: action + 1}, rng, defaultCommonOptions()) + if err := handleTool(toolSpec{Name: randomToolName(rng), Fail: true, SequenceIndex: action + 1}); err != nil { + return err + } case randomActionToolDeny: - _ = r.runToolSpec(ctx, w, toolSpec{Name: randomToolName(rng), Deny: true, SequenceIndex: action + 1}, rng, defaultCommonOptions()) + if err := handleTool(toolSpec{Name: randomToolName(rng), Deny: true, SequenceIndex: action + 1}); err != nil { + return err + } case randomActionToolApproval: - _ = r.runToolSpec(ctx, w, toolSpec{Name: randomToolName(rng), Approval: true, SequenceIndex: action + 1}, rng, defaultCommonOptions()) + if err := handleTool(toolSpec{Name: randomToolName(rng), Approval: true, SequenceIndex: action + 1}); err != nil { + return err + } case randomActionSource: w.Custom("com.beeper.source", map[string]any{"url": fmt.Sprintf("https://dummybridge.local/random/source/%d", action+1), "title": fmt.Sprintf("Random Source %d", action+1)}) case randomActionDocument: From 7317dd373d0d7f4a8b90771cc5d9090b8173dbf7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Tue, 19 May 2026 22:16:04 +0200 Subject: [PATCH 08/46] Trim oversized first markdown block --- pkg/connector/ai_text.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pkg/connector/ai_text.go b/pkg/connector/ai_text.go index 960fe8e..f9b083e 100644 --- a/pkg/connector/ai_text.go +++ b/pkg/connector/ai_text.go @@ -181,7 +181,10 @@ func trimVisibleText(text string, limit int) string { if len(kept) > 0 { next += 2 } - if next > limit && len(kept) > 0 { + if next > limit { + if len(kept) == 0 { + kept = append(kept, trimText(block, limit)) + } break } kept = append(kept, block) From ca455e4e5e6c8ef10e877b829b9dba91a38e5775 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Tue, 19 May 2026 22:16:14 +0200 Subject: [PATCH 09/46] Clean up duplicate approval reactions --- pkg/connector/client.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pkg/connector/client.go b/pkg/connector/client.go index e7933e6..138fa03 100644 --- a/pkg/connector/client.go +++ b/pkg/connector/client.go @@ -281,7 +281,7 @@ func (dc *DummyClient) HandleMatrixReaction(ctx context.Context, msg *bridgev2.M } selectedKey, firstResolution := dc.resolveApprovalOnce(approvalID, reaction) - dc.cleanupApprovalReactions(ctx, msg.Portal, networkid.MessageID(approvalID), selectedKey, msg) + dc.cleanupApprovalReactions(ctx, msg.Portal, networkid.MessageID(approvalID), selectedKey, reaction, msg) if !firstResolution { log.Info(). Str("approval_id", approvalID). @@ -315,7 +315,7 @@ func (dc *DummyClient) resolveApprovalOnce(approvalID, selectedKey string) (stri return selectedKey, true } -func (dc *DummyClient) cleanupApprovalReactions(ctx context.Context, portal *bridgev2.Portal, approvalMessageID networkid.MessageID, selectedKey string, msg *bridgev2.MatrixReaction) { +func (dc *DummyClient) cleanupApprovalReactions(ctx context.Context, portal *bridgev2.Portal, approvalMessageID networkid.MessageID, selectedKey, reactionKey string, msg *bridgev2.MatrixReaction) { if dc == nil || dc.UserLogin == nil || dc.UserLogin.Bridge == nil || dc.UserLogin.Bridge.DB == nil || portal == nil { return } @@ -343,7 +343,7 @@ func (dc *DummyClient) cleanupApprovalReactions(ctx context.Context, portal *bri events = append(events, aistream.ReactionEvent{ EventID: string(msg.Event.ID), Sender: string(msg.Event.Sender), - Key: selectedKey, + Key: reactionKey, }) } cleanup := aistream.CleanupReactions(aistream.DefaultApprovalOptions(string(approvalMessageID)), selectedKey, events, string(aiGhostID)) From 658ca92630080762f675d18c72b9ab0e22cd3bb7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Tue, 19 May 2026 22:16:23 +0200 Subject: [PATCH 10/46] Track AI stream goroutines --- pkg/connector/client.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pkg/connector/client.go b/pkg/connector/client.go index 138fa03..7e90cf6 100644 --- a/pkg/connector/client.go +++ b/pkg/connector/client.go @@ -482,7 +482,11 @@ func (dc *DummyClient) queueAIResponse(ctx context.Context, portal *bridgev2.Por placeholderID := networkid.MessageID(plan.Run.MessageID) dc.UserLogin.QueueRemoteEvent(aibridgev2.Anchor(portal.PortalKey, aiGhostID, initialAIAnchorRun(*plan.Run), timestamp)) - go dc.queueAIRunStreamAndMetadata(portal, placeholderID, *plan.Run) + dc.wg.Add(1) + go func(portal *bridgev2.Portal, messageID networkid.MessageID, run aistream.Run) { + defer dc.wg.Done() + dc.queueAIRunStreamAndMetadata(portal, messageID, run) + }(portal, placeholderID, *plan.Run) } } From 236bc45fef26a8c8bf79062db1c543ea6ef5f3cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Tue, 19 May 2026 22:16:33 +0200 Subject: [PATCH 11/46] Use MessageDB for AI anchor lookup --- pkg/connector/client.go | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/pkg/connector/client.go b/pkg/connector/client.go index 7e90cf6..14b3551 100644 --- a/pkg/connector/client.go +++ b/pkg/connector/client.go @@ -571,18 +571,11 @@ func (dc *DummyClient) waitForMessageMXID( } func (dc *DummyClient) lookupMessageMXID(ctx context.Context, receiver networkid.UserLoginID, messageID networkid.MessageID) id.EventID { - var mxid id.EventID - err := dc.UserLogin.Bridge.DB.Message.GetDB().QueryRow( - ctx, - `SELECT mxid FROM message WHERE bridge_id=$1 AND (room_receiver=$2 OR room_receiver='') AND id=$3 ORDER BY part_id ASC LIMIT 1`, - dc.UserLogin.Bridge.DB.Message.BridgeID, - receiver, - messageID, - ).Scan(&mxid) - if err != nil { + message, err := dc.UserLogin.Bridge.DB.Message.GetFirstPartByID(ctx, receiver, messageID) + if err != nil || message == nil { return "" } - return mxid + return message.MXID } func (dc *DummyClient) queueAIApprovalPrompt(portal *bridgev2.Portal, run aistream.Run, prompt aistream.ApprovalPrompt, targetEventID id.EventID, timestamp time.Time) { From d2d0231e0868cb873443ae1106830bac6c7b52c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Tue, 19 May 2026 22:16:43 +0200 Subject: [PATCH 12/46] Split AI chunks on UTF-8 boundaries --- pkg/connector/ai_runtime.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pkg/connector/ai_runtime.go b/pkg/connector/ai_runtime.go index ffa48b7..170182a 100644 --- a/pkg/connector/ai_runtime.go +++ b/pkg/connector/ai_runtime.go @@ -1139,8 +1139,10 @@ func chunkText(text string, rng *rand.Rand, minChunk, maxChunk int) []string { if size > len(text) { size = len(text) } - chunks = append(chunks, text[:size]) - text = text[size:] + parts := aistream.SplitTextUTF8(text, size) + chunk := parts[0] + chunks = append(chunks, chunk) + text = text[len(chunk):] } return chunks } From c5400be9af571874f408921acaf18490e6d550bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Tue, 19 May 2026 22:24:03 +0200 Subject: [PATCH 13/46] Fix lint issues --- pkg/connector/client.go | 4 ++-- pkg/connector/commands.go | 3 +-- pkg/connector/generators.go | 6 +++--- pkg/connector/message_requests.go | 14 ++++++++------ 4 files changed, 14 insertions(+), 13 deletions(-) diff --git a/pkg/connector/client.go b/pkg/connector/client.go index 14b3551..d991ad0 100644 --- a/pkg/connector/client.go +++ b/pkg/connector/client.go @@ -48,8 +48,8 @@ var _ bridgev2.ReactionHandlingNetworkAPI = (*DummyClient)(nil) const ( aiGhostID networkid.UserID = "ai" - aiGhostName = "AI" - aiPortalIDPrefix = "ai-" + aiGhostName string = "AI" + aiPortalIDPrefix string = "ai-" ) var delayedRemoteEchoPattern = regexp.MustCompile(`(?i)^remote-echo\s+delay\s+([0-9]+(?:ms|s|m|h))$`) diff --git a/pkg/connector/commands.go b/pkg/connector/commands.go index 9722482..630a004 100644 --- a/pkg/connector/commands.go +++ b/pkg/connector/commands.go @@ -226,8 +226,7 @@ var FileCommand = &commands.FullHandler{ Func: func(e *commands.Event) { e.Reply("Generating file event in this room") - var mediaData []byte - mediaData = []byte("Test text file") + mediaData := []byte("Test text file") mediaName := "test.txt" mediaMime := "text/plain" diff --git a/pkg/connector/generators.go b/pkg/connector/generators.go index 5c812c2..a3ebf56 100644 --- a/pkg/connector/generators.go +++ b/pkg/connector/generators.go @@ -58,8 +58,8 @@ func generatePortal(ctx context.Context, br *bridgev2.Bridge, login *bridgev2.Us Type: ptr.Ptr(roomType), CanBackfill: true, Members: &bridgev2.ChatMemberList{ - Members: []bridgev2.ChatMember{ - { + MemberMap: bridgev2.ChatMemberMap{ + networkid.UserID(login.ID): { EventSender: bridgev2.EventSender{ IsFromMe: true, Sender: networkid.UserID(login.ID), @@ -78,7 +78,7 @@ func generatePortal(ctx context.Context, br *bridgev2.Bridge, login *bridgev2.Us return nil, fmt.Errorf("failed to get ghost by id: %w", err) } - chatInfo.Members.Members = append(chatInfo.Members.Members, bridgev2.ChatMember{ + chatInfo.Members.MemberMap.Set(bridgev2.ChatMember{ EventSender: bridgev2.EventSender{ Sender: userID, }, diff --git a/pkg/connector/message_requests.go b/pkg/connector/message_requests.go index a1b3b21..512964c 100644 --- a/pkg/connector/message_requests.go +++ b/pkg/connector/message_requests.go @@ -113,11 +113,13 @@ func createMessageRequestPortal( Type: ptr.Ptr(roomType), MessageRequest: &isMessageRequest, CanBackfill: true, - Members: &bridgev2.ChatMemberList{Members: []bridgev2.ChatMember{{ - EventSender: bridgev2.EventSender{IsFromMe: true, Sender: networkid.UserID(login.ID)}, - Membership: event.MembershipJoin, - PowerLevel: ptr.Ptr(100), - }}}, + Members: &bridgev2.ChatMemberList{MemberMap: bridgev2.ChatMemberMap{ + networkid.UserID(login.ID): { + EventSender: bridgev2.EventSender{IsFromMe: true, Sender: networkid.UserID(login.ID)}, + Membership: event.MembershipJoin, + PowerLevel: ptr.Ptr(100), + }, + }}, } firstGhost := stablePortalUserIDByIndex(portalID, 0) @@ -128,7 +130,7 @@ func createMessageRequestPortal( return nil, fmt.Errorf("failed to get ghost by id: %w", err) } ghost.UpdateName(ctx, fmt.Sprintf("Dummy User %d", i+1)) - chatInfo.Members.Members = append(chatInfo.Members.Members, bridgev2.ChatMember{ + chatInfo.Members.MemberMap.Set(bridgev2.ChatMember{ EventSender: bridgev2.EventSender{Sender: userID}, Membership: event.MembershipJoin, }) From e297b00119f4803345a5a56de426d0ccccb912d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Tue, 19 May 2026 22:30:44 +0200 Subject: [PATCH 14/46] simplify --- pkg/ai-stream/matrix/content.go | 29 +++++++---------------------- pkg/connector/ai_runtime.go | 6 +++++- pkg/connector/client.go | 20 ++++++++------------ 3 files changed, 20 insertions(+), 35 deletions(-) diff --git a/pkg/ai-stream/matrix/content.go b/pkg/ai-stream/matrix/content.go index d211893..7a547d5 100644 --- a/pkg/ai-stream/matrix/content.go +++ b/pkg/ai-stream/matrix/content.go @@ -2,7 +2,6 @@ package matrix import ( "fmt" - "html" "github.com/beeper/dummybridge/pkg/ag-ui" "github.com/beeper/dummybridge/pkg/ai-stream" @@ -17,11 +16,8 @@ func AnchorContent(run aistream.Run) (*event.MessageEventContent, map[string]any body = "..." } rendered := format.RenderMarkdown(body, true, false) - if rendered.Format != event.FormatHTML { - rendered.Format = event.FormatHTML - rendered.FormattedBody = html.EscapeString(rendered.Body) - } content := &rendered + content.EnsureHasHTML() content.BeeperPerMessageProfile = &event.BeeperPerMessageProfile{ ID: run.AgentID, Displayname: run.AgentName, @@ -37,16 +33,9 @@ func AnchorContent(run aistream.Run) (*event.MessageEventContent, map[string]any } func CarrierContent(carrier aistream.Carrier, targetEventID id.EventID) (*event.MessageEventContent, map[string]any) { - content := &event.MessageEventContent{ - MsgType: event.MsgText, - Body: "", - Mentions: &event.Mentions{}, - RelatesTo: &event.RelatesTo{ - Type: event.RelReference, - EventID: targetEventID, - }, - } - return content, aistream.CarrierContent(carrier.Envelopes) + content := format.TextToContent("") + content.SetRelatesTo(&event.RelatesTo{Type: event.RelReference, EventID: targetEventID}) + return &content, aistream.CarrierContent(carrier.Envelopes) } func ApprovalContent(ctx aistream.ApprovalContext, options []aistream.ReactionOption[agui.ToolApprovalResponse]) (*event.MessageEventContent, map[string]any) { @@ -55,13 +44,9 @@ func ApprovalContent(ctx aistream.ApprovalContext, options []aistream.ReactionOp if len(options) > 0 { body += "\nReact with one of the listed options." } - content := &event.MessageEventContent{ - MsgType: event.MsgText, - Body: body, - Mentions: &event.Mentions{}, - } + content := format.TextToContent(body) if ctx.TargetEvent != "" { - content.RelatesTo = &event.RelatesTo{Type: event.RelReference, EventID: id.EventID(ctx.TargetEvent)} + content.SetRelatesTo(&event.RelatesTo{Type: event.RelReference, EventID: id.EventID(ctx.TargetEvent)}) } extra := map[string]any{ "com.beeper.ai.approval": map[string]any{ @@ -78,7 +63,7 @@ func ApprovalContent(ctx aistream.ApprovalContext, options []aistream.ReactionOp "reactions": ReactionOptionsAsAny(options), }, } - return content, extra + return &content, extra } func ReactionOptionsAsAny(options []aistream.ReactionOption[agui.ToolApprovalResponse]) []any { diff --git a/pkg/connector/ai_runtime.go b/pkg/connector/ai_runtime.go index 170182a..014fe95 100644 --- a/pkg/connector/ai_runtime.go +++ b/pkg/connector/ai_runtime.go @@ -13,6 +13,7 @@ import ( "github.com/beeper/dummybridge/pkg/ag-ui" "github.com/beeper/dummybridge/pkg/ai-stream" + "go.mau.fi/util/shlex" ) var errApprovalRequested = errors.New("approval requested") @@ -264,7 +265,10 @@ func buildAIChaosRunPlans(ctx context.Context, baseRunID, threadID string, now t } func parseCommand(input string) (*parsedCommand, error) { - tokens := strings.Fields(strings.TrimSpace(input)) + tokens, err := shlex.Split(input) + if err != nil { + return nil, fmt.Errorf("invalid command syntax: %w", err) + } if len(tokens) == 0 { return &parsedCommand{Name: "help"}, nil } diff --git a/pkg/connector/client.go b/pkg/connector/client.go index d991ad0..c99f9c2 100644 --- a/pkg/connector/client.go +++ b/pkg/connector/client.go @@ -14,6 +14,7 @@ import ( "github.com/beeper/dummybridge/pkg/ai-stream" aibridgev2 "github.com/beeper/dummybridge/pkg/ai-stream/bridgev2" "github.com/rs/zerolog/log" + "go.mau.fi/util/exsync" "go.mau.fi/util/jsontime" "go.mau.fi/util/ptr" @@ -34,8 +35,8 @@ type DummyClient struct { UserLogin *bridgev2.UserLogin Connector *DummyConnector - approvalMu sync.Mutex - approvalSelections map[string]string + approvalSelectionsOnce sync.Once + approvalSelections *exsync.Map[string, string] } var _ bridgev2.NetworkAPI = (*DummyClient)(nil) @@ -303,16 +304,11 @@ func (dc *DummyClient) HandleMatrixReaction(ctx context.Context, msg *bridgev2.M } func (dc *DummyClient) resolveApprovalOnce(approvalID, selectedKey string) (string, bool) { - dc.approvalMu.Lock() - defer dc.approvalMu.Unlock() - if dc.approvalSelections == nil { - dc.approvalSelections = make(map[string]string) - } - if existing := dc.approvalSelections[approvalID]; existing != "" { - return existing, false - } - dc.approvalSelections[approvalID] = selectedKey - return selectedKey, true + dc.approvalSelectionsOnce.Do(func() { + dc.approvalSelections = exsync.NewMap[string, string]() + }) + selected, alreadyResolved := dc.approvalSelections.GetOrSet(approvalID, selectedKey) + return selected, !alreadyResolved } func (dc *DummyClient) cleanupApprovalReactions(ctx context.Context, portal *bridgev2.Portal, approvalMessageID networkid.MessageID, selectedKey, reactionKey string, msg *bridgev2.MatrixReaction) { From 0ad5e7e97ac69970bcf535a6dd4c499f8dca1e0f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Tue, 19 May 2026 22:31:58 +0200 Subject: [PATCH 15/46] Use Go 1.25 in Docker builds --- cmd/dummybridge/Dockerfile | 2 +- cmd/loginhelper/Dockerfile | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cmd/dummybridge/Dockerfile b/cmd/dummybridge/Dockerfile index 65bcc74..fe2a7c3 100644 --- a/cmd/dummybridge/Dockerfile +++ b/cmd/dummybridge/Dockerfile @@ -1,4 +1,4 @@ -FROM golang:1-alpine3.23 AS builder +FROM golang:1.25-alpine3.23 AS builder RUN apk add --no-cache build-base olm-dev diff --git a/cmd/loginhelper/Dockerfile b/cmd/loginhelper/Dockerfile index 856ff03..c9f20b6 100644 --- a/cmd/loginhelper/Dockerfile +++ b/cmd/loginhelper/Dockerfile @@ -1,4 +1,4 @@ -FROM golang:1-alpine3.20 AS builder +FROM golang:1.25-alpine3.23 AS builder RUN apk add --no-cache build-base From abd50d759e3182205e3af910d061d033e5a602f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Tue, 19 May 2026 23:21:13 +0200 Subject: [PATCH 16/46] Guard approval logging for nil Matrix events --- pkg/connector/client.go | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/pkg/connector/client.go b/pkg/connector/client.go index c99f9c2..41a7c4b 100644 --- a/pkg/connector/client.go +++ b/pkg/connector/client.go @@ -293,12 +293,14 @@ func (dc *DummyClient) HandleMatrixReaction(ctx context.Context, msg *bridgev2.M } dc.queueAIApprovalResponse(ctx, msg.Portal, msg.TargetMessage, selected.Value) - log.Info(). + logger := log.Info(). Str("approval_id", approvalID). Str("reaction", reaction). - Bool("approved", selected.Value.Approved). - Stringer("sender", msg.Event.Sender). - Msg("Resolved dummy AI approval from Matrix reaction") + Bool("approved", selected.Value.Approved) + if msg.Event != nil { + logger = logger.Stringer("sender", msg.Event.Sender) + } + logger.Msg("Resolved dummy AI approval from Matrix reaction") return &database.Reaction{}, nil } From 8facf1e4ac0b86bcbeb613c7ab6c9f8c0aabc1a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Tue, 19 May 2026 23:21:18 +0200 Subject: [PATCH 17/46] Use normalized sender for approval cleanup reactions --- pkg/connector/client.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pkg/connector/client.go b/pkg/connector/client.go index 41a7c4b..c958959 100644 --- a/pkg/connector/client.go +++ b/pkg/connector/client.go @@ -338,9 +338,13 @@ func (dc *DummyClient) cleanupApprovalReactions(ctx context.Context, portal *bri }) } if msg != nil && msg.Event != nil && msg.Event.ID != "" { + senderID := string(msg.Event.Sender) + if msg.PreHandleResp != nil && msg.PreHandleResp.SenderID != "" { + senderID = string(msg.PreHandleResp.SenderID) + } events = append(events, aistream.ReactionEvent{ EventID: string(msg.Event.ID), - Sender: string(msg.Event.Sender), + Sender: senderID, Key: reactionKey, }) } From fff178bb297ebac15f4d016cdea36a651e2086ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Tue, 19 May 2026 23:21:21 +0200 Subject: [PATCH 18/46] Populate AI chat members with MemberMap --- pkg/connector/client.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pkg/connector/client.go b/pkg/connector/client.go index c958959..d8b052e 100644 --- a/pkg/connector/client.go +++ b/pkg/connector/client.go @@ -842,8 +842,8 @@ func (dc *DummyClient) resolveAIIdentifier(ctx context.Context, createChat bool) Type: ptr.Ptr(roomType), CanBackfill: true, Members: &bridgev2.ChatMemberList{ - Members: []bridgev2.ChatMember{ - { + MemberMap: bridgev2.ChatMemberMap{ + networkid.UserID(dc.UserLogin.ID): { EventSender: bridgev2.EventSender{ IsFromMe: true, Sender: networkid.UserID(dc.UserLogin.ID), @@ -851,7 +851,7 @@ func (dc *DummyClient) resolveAIIdentifier(ctx context.Context, createChat bool) Membership: event.MembershipJoin, PowerLevel: ptr.Ptr(100), }, - { + aiGhostID: { EventSender: bridgev2.EventSender{ Sender: aiGhostID, }, From 7b9d87c68b4c0ce0af01b27ab0499dccf374fb29 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Tue, 19 May 2026 23:21:27 +0200 Subject: [PATCH 19/46] Run dummybridge container as non-root --- cmd/dummybridge/Dockerfile | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/cmd/dummybridge/Dockerfile b/cmd/dummybridge/Dockerfile index fe2a7c3..41bdc19 100644 --- a/cmd/dummybridge/Dockerfile +++ b/cmd/dummybridge/Dockerfile @@ -9,6 +9,9 @@ RUN go build -o /usr/bin/dummybridge ./cmd/dummybridge FROM alpine:3.20 -RUN apk add --no-cache ca-certificates olm su-exec bash jq yq curl -COPY --from=builder /usr/bin/dummybridge /usr/bin/dummybridge +RUN apk add --no-cache ca-certificates olm su-exec bash jq yq curl \ + && addgroup -S dummybridge \ + && adduser -S -G dummybridge dummybridge +COPY --from=builder --chown=dummybridge:dummybridge /usr/bin/dummybridge /usr/bin/dummybridge +USER dummybridge:dummybridge CMD ["/usr/bin/dummybridge"] From 57f1e52d1ac376e6ddda304dd9017af48b60eb02 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Tue, 19 May 2026 23:21:30 +0200 Subject: [PATCH 20/46] Align dummybridge runtime Alpine with builder --- cmd/dummybridge/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/dummybridge/Dockerfile b/cmd/dummybridge/Dockerfile index 41bdc19..ce4ab2d 100644 --- a/cmd/dummybridge/Dockerfile +++ b/cmd/dummybridge/Dockerfile @@ -7,7 +7,7 @@ WORKDIR /build ENV CGO_ENABLED=1 RUN go build -o /usr/bin/dummybridge ./cmd/dummybridge -FROM alpine:3.20 +FROM alpine:3.23 RUN apk add --no-cache ca-certificates olm su-exec bash jq yq curl \ && addgroup -S dummybridge \ From bdf63125cc5f6693580b6f31c78bbdf00c0a6e14 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Tue, 19 May 2026 23:21:33 +0200 Subject: [PATCH 21/46] Run loginhelper container as non-root --- cmd/loginhelper/Dockerfile | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/cmd/loginhelper/Dockerfile b/cmd/loginhelper/Dockerfile index c9f20b6..fafb80d 100644 --- a/cmd/loginhelper/Dockerfile +++ b/cmd/loginhelper/Dockerfile @@ -9,6 +9,9 @@ RUN go build -o /usr/bin/loginhelper ./cmd/loginhelper FROM alpine:3.20 -RUN apk add --no-cache ca-certificates -COPY --from=builder /usr/bin/loginhelper /usr/bin/loginhelper +RUN apk add --no-cache ca-certificates \ + && addgroup -S loginhelper \ + && adduser -S -G loginhelper loginhelper +COPY --from=builder --chown=loginhelper:loginhelper /usr/bin/loginhelper /usr/bin/loginhelper +USER loginhelper:loginhelper CMD ["/usr/bin/loginhelper"] From 49a81fbd0f134e0327d75bf129be92f655a52bb6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Tue, 19 May 2026 23:21:37 +0200 Subject: [PATCH 22/46] Align loginhelper runtime Alpine with builder --- cmd/loginhelper/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/loginhelper/Dockerfile b/cmd/loginhelper/Dockerfile index fafb80d..76cbcde 100644 --- a/cmd/loginhelper/Dockerfile +++ b/cmd/loginhelper/Dockerfile @@ -7,7 +7,7 @@ WORKDIR /build ENV CGO_ENABLED=1 RUN go build -o /usr/bin/loginhelper ./cmd/loginhelper -FROM alpine:3.20 +FROM alpine:3.23 RUN apk add --no-cache ca-certificates \ && addgroup -S loginhelper \ From 016fb7e2692d897b5dca93539f84a13153374fcf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Tue, 19 May 2026 23:21:41 +0200 Subject: [PATCH 23/46] Require step names in AG-UI builder --- pkg/ag-ui/events.go | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/pkg/ag-ui/events.go b/pkg/ag-ui/events.go index 5429877..917a542 100644 --- a/pkg/ag-ui/events.go +++ b/pkg/ag-ui/events.go @@ -286,24 +286,26 @@ func (b EventBuilder) ToolCallResult(messageID, toolCallID, content, state, role } func (b EventBuilder) StepStarted(messageID, stepName string) Event { + if stepName == "" { + panic("ag-ui: stepName is required for STEP_STARTED") + } evt := b.base(EventStepStarted) if messageID != "" { evt["messageId"] = messageID } - if stepName != "" { - evt["stepName"] = stepName - } + evt["stepName"] = stepName return evt } func (b EventBuilder) StepFinished(messageID, stepName string) Event { + if stepName == "" { + panic("ag-ui: stepName is required for STEP_FINISHED") + } evt := b.base(EventStepFinished) if messageID != "" { evt["messageId"] = messageID } - if stepName != "" { - evt["stepName"] = stepName - } + evt["stepName"] = stepName return evt } From cf94cc179e0e74072ce141e66dee1eb015c45e41 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Tue, 19 May 2026 23:21:41 +0200 Subject: [PATCH 24/46] Add random approval pause regression test --- pkg/connector/ai_runtime_test.go | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/pkg/connector/ai_runtime_test.go b/pkg/connector/ai_runtime_test.go index 6de6b66..a3b724c 100644 --- a/pkg/connector/ai_runtime_test.go +++ b/pkg/connector/ai_runtime_test.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "math/rand" + "strconv" "strings" "testing" "time" @@ -429,6 +430,28 @@ func TestBuildAIRunRandomHonorsVirtualDelays(t *testing.T) { } } +func TestRandomModeApprovalPause(t *testing.T) { + for seed := int64(1); seed <= 200; seed++ { + run, err := buildAIRun(context.Background(), "run-approval", "thread-approval", "stream-random 1 --profile=tools --allow-approval --seed="+strconv.FormatInt(seed, 10), time.Unix(10, 0)) + if err != nil { + t.Fatal(err) + } + if run.ApprovalID == "" { + continue + } + for _, evt := range run.Events { + if evt["type"] == agui.EventRunFinished { + t.Fatalf("approval run emitted RUN_FINISHED with seed %d", seed) + } + } + if run.Status.State != "streaming" { + t.Fatalf("expected approval run to remain streaming, got %q", run.Status.State) + } + return + } + t.Fatal("no approval action selected for tested random seeds") +} + func TestRandomProfilesCoverToolsArtifactsAndTransientData(t *testing.T) { cmd := randomCommand{ sharedStreamOptions: sharedStreamOptions{ From beb243e87b15d349279e52f24779771499bb27e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sat, 23 May 2026 17:19:33 +0200 Subject: [PATCH 25/46] Support per-portal AI sender and agent name Parameterize AI agent ID/name across ai_runtime build functions so runs can be attributed to a configurable agent. Add helper funcs (isAIDemoCommandContent, dummyAISenderForPortal, dummyAIAgentNameForPortal) and use a portal-specific sender/agent when queuing AI anchors, carriers, approval prompts/responses, and final metadata in DummyClient. Preserve existing dedicated AI portal behavior (aiGhostID/AI) while allowing demo commands in normal rooms to appear from a stable per-portal sender. Update tests to cover demo command detection and sender selection and adjust ai_runtime tests to pass agent info. --- pkg/connector/ai_runtime.go | 18 +++---- pkg/connector/ai_runtime_test.go | 2 +- pkg/connector/client.go | 87 ++++++++++++++++++++++++-------- pkg/connector/client_test.go | 37 ++++++++++++++ 4 files changed, 112 insertions(+), 32 deletions(-) diff --git a/pkg/connector/ai_runtime.go b/pkg/connector/ai_runtime.go index 014fe95..f434017 100644 --- a/pkg/connector/ai_runtime.go +++ b/pkg/connector/ai_runtime.go @@ -163,7 +163,7 @@ func virtualAIRuntime(now time.Time) aiRuntime { } func buildAIRun(ctx context.Context, runID, threadID, input string, now time.Time) (*aistream.Run, error) { - plans, err := buildAIRunPlans(ctx, runID, threadID, input, now) + plans, err := buildAIRunPlans(ctx, runID, threadID, input, now, "ai", "AI") if err != nil { return nil, err } @@ -173,10 +173,10 @@ func buildAIRun(ctx context.Context, runID, threadID, input string, now time.Tim return plans[0].Run, nil } -func buildAIRunPlans(ctx context.Context, runID, threadID, input string, now time.Time) ([]aiRunPlan, error) { +func buildAIRunPlans(ctx context.Context, runID, threadID, input string, now time.Time, agentID, agentName string) ([]aiRunPlan, error) { cmd, err := parseCommand(input) if err != nil { - run := aistream.NewRun(runID, threadID, aistream.DefaultModel, string(aiGhostID), aiGhostName, now) + run := aistream.NewRun(runID, threadID, aistream.DefaultModel, agentID, agentName, now) writer := aistream.NewWriter(run, func() time.Time { return now }) writer.Start() writer.Text(err.Error() + "\n\n" + helpText()) @@ -184,18 +184,18 @@ func buildAIRunPlans(ctx context.Context, runID, threadID, input string, now tim return []aiRunPlan{{Run: run}}, nil } if cmd != nil && cmd.Chaos != nil { - return buildAIChaosRunPlans(ctx, runID, threadID, now, *cmd.Chaos) + return buildAIChaosRunPlans(ctx, runID, threadID, now, *cmd.Chaos, agentID, agentName) } - run, err := buildAIRunFromCommand(ctx, runID, threadID, now, cmd) + run, err := buildAIRunFromCommand(ctx, runID, threadID, now, cmd, agentID, agentName) if err != nil { return nil, err } return []aiRunPlan{{Run: run}}, nil } -func buildAIRunFromCommand(ctx context.Context, runID, threadID string, now time.Time, cmd *parsedCommand) (*aistream.Run, error) { +func buildAIRunFromCommand(ctx context.Context, runID, threadID string, now time.Time, cmd *parsedCommand, agentID, agentName string) (*aistream.Run, error) { runtime := virtualAIRuntime(now) - run := aistream.NewRun(runID, threadID, aistream.DefaultModel, string(aiGhostID), aiGhostName, now) + run := aistream.NewRun(runID, threadID, aistream.DefaultModel, agentID, agentName, now) writer := aistream.NewWriter(run, runtime.now) writer.Start() @@ -223,7 +223,7 @@ func buildAIRunFromCommand(ctx context.Context, runID, threadID string, now time return run, nil } -func buildAIChaosRunPlans(ctx context.Context, baseRunID, threadID string, now time.Time, cmd chaosCommand) ([]aiRunPlan, error) { +func buildAIChaosRunPlans(ctx context.Context, baseRunID, threadID string, now time.Time, cmd chaosCommand, agentID, agentName string) ([]aiRunPlan, error) { seed := cmd.Seed if !cmd.SeedSet { seed = now.UnixNano() @@ -255,7 +255,7 @@ func buildAIChaosRunPlans(ctx context.Context, baseRunID, threadID string, now t run, err := buildAIRunFromCommand(ctx, runID, threadID, now.Add(delay), &parsedCommand{ Name: "stream-random", Random: &randomCmd, - }) + }, agentID, agentName) if err != nil { return nil, err } diff --git a/pkg/connector/ai_runtime_test.go b/pkg/connector/ai_runtime_test.go index a3b724c..40bf5d8 100644 --- a/pkg/connector/ai_runtime_test.go +++ b/pkg/connector/ai_runtime_test.go @@ -379,7 +379,7 @@ func TestBuildAIRunOver64KBPacksTo58KCarriers(t *testing.T) { } func TestBuildAIRunPlansChaosCreatesMultipleRuns(t *testing.T) { - plans, err := buildAIRunPlans(context.Background(), "run-chaos", "thread-1", "stream-chaos 3 1 --max-actions=3 --seed=7 --stagger-ms=1:1", time.Unix(10, 0)) + plans, err := buildAIRunPlans(context.Background(), "run-chaos", "thread-1", "stream-chaos 3 1 --max-actions=3 --seed=7 --stagger-ms=1:1", time.Unix(10, 0), "ai", "AI") if err != nil { t.Fatal(err) } diff --git a/pkg/connector/client.go b/pkg/connector/client.go index d8b052e..6cc8f0c 100644 --- a/pkg/connector/client.go +++ b/pkg/connector/client.go @@ -51,6 +51,7 @@ const ( aiGhostID networkid.UserID = "ai" aiGhostName string = "AI" aiPortalIDPrefix string = "ai-" + dummyAIAgentName string = "Dummy" ) var delayedRemoteEchoPattern = regexp.MustCompile(`(?i)^remote-echo\s+delay\s+([0-9]+(?:ms|s|m|h))$`) @@ -243,7 +244,7 @@ func (dc *DummyClient) HandleMatrixMessage(ctx context.Context, msg *bridgev2.Ma StreamOrder: time.Now().UnixNano(), } - if msg.Portal != nil && isAIPortalID(msg.Portal.ID) { + if msg.Portal != nil && (isAIPortalID(msg.Portal.ID) || isAIDemoCommandContent(msg.Content)) { dc.queueAIResponse(ctx, msg.Portal, msg.Content) } @@ -334,7 +335,7 @@ func (dc *DummyClient) cleanupApprovalReactions(ctx context.Context, portal *bri EventID: eventID, Sender: string(reaction.SenderID), Key: reaction.Emoji, - Bridge: reaction.SenderID == aiGhostID, + Bridge: reaction.SenderID == dummyAISenderForPortal(portal), }) } if msg != nil && msg.Event != nil && msg.Event.ID != "" { @@ -348,8 +349,9 @@ func (dc *DummyClient) cleanupApprovalReactions(ctx context.Context, portal *bri Key: reactionKey, }) } - cleanup := aistream.CleanupReactions(aistream.DefaultApprovalOptions(string(approvalMessageID)), selectedKey, events, string(aiGhostID)) - intent, ok := portal.GetIntentFor(ctx, bridgev2.EventSender{Sender: aiGhostID}, dc.UserLogin, bridgev2.RemoteEventMessageRemove) + sender := dummyAISenderForPortal(portal) + cleanup := aistream.CleanupReactions(aistream.DefaultApprovalOptions(string(approvalMessageID)), selectedKey, events, string(sender)) + intent, ok := portal.GetIntentFor(ctx, bridgev2.EventSender{Sender: sender}, dc.UserLogin, bridgev2.RemoteEventMessageRemove) if !ok || intent == nil { log.Warn().Str("approval_id", string(approvalMessageID)).Msg("Failed to resolve AI sender intent for approval reaction cleanup") return @@ -412,6 +414,41 @@ func getRemoteEchoBehavior(content *event.MessageEventContent) remoteEchoBehavio return remoteEchoBehavior{pending: true, delay: delay} } +func isAIDemoCommandContent(content *event.MessageEventContent) bool { + if content == nil { + return false + } + tokens := strings.Fields(strings.TrimSpace(content.Body)) + if len(tokens) == 0 { + return false + } + switch strings.ToLower(tokens[0]) { + case "help", "/help", "!help", "stream-lorem", "stream-tools", "stream-random", "stream-chaos": + return true + case "dummybridge": + return len(tokens) > 1 && strings.EqualFold(tokens[1], "help") + default: + return false + } +} + +func dummyAISenderForPortal(portal *bridgev2.Portal) networkid.UserID { + if portal == nil { + return networkid.UserID(dummyAIAgentName) + } + if isAIPortalID(portal.ID) { + return aiGhostID + } + return stablePortalUserIDByIndex(portal.ID, 0) +} + +func dummyAIAgentNameForPortal(portal *bridgev2.Portal) string { + if portal != nil && isAIPortalID(portal.ID) { + return aiGhostName + } + return dummyAIAgentName +} + func (dc *DummyClient) queueRemoteEcho(msg *bridgev2.MatrixMessage, transactionID networkid.TransactionID, timestamp time.Time, delay time.Duration) { if delay <= 0 || msg.Portal == nil { return @@ -467,11 +504,13 @@ func (dc *DummyClient) queueAIResponse(ctx context.Context, portal *bridgev2.Por now := time.Now() runID := "run-" + string(randomMessageID()) + sender := dummyAISenderForPortal(portal) + agentName := dummyAIAgentNameForPortal(portal) var body string if inbound != nil { body = inbound.Body } - plans, err := buildAIRunPlans(ctx, runID, string(portal.ID), body, now) + plans, err := buildAIRunPlans(ctx, runID, string(portal.ID), body, now, string(sender), agentName) if err != nil { log.Warn().Err(err).Msg("Failed to build AI runs") return @@ -482,13 +521,13 @@ func (dc *DummyClient) queueAIResponse(ctx context.Context, portal *bridgev2.Por } timestamp := now.Add(plan.Delay) placeholderID := networkid.MessageID(plan.Run.MessageID) - dc.UserLogin.QueueRemoteEvent(aibridgev2.Anchor(portal.PortalKey, aiGhostID, initialAIAnchorRun(*plan.Run), timestamp)) + dc.UserLogin.QueueRemoteEvent(aibridgev2.Anchor(portal.PortalKey, sender, initialAIAnchorRun(*plan.Run), timestamp)) dc.wg.Add(1) - go func(portal *bridgev2.Portal, messageID networkid.MessageID, run aistream.Run) { + go func(portal *bridgev2.Portal, sender networkid.UserID, messageID networkid.MessageID, run aistream.Run) { defer dc.wg.Done() - dc.queueAIRunStreamAndMetadata(portal, messageID, run) - }(portal, placeholderID, *plan.Run) + dc.queueAIRunStreamAndMetadata(portal, sender, messageID, run) + }(portal, sender, placeholderID, *plan.Run) } } @@ -498,7 +537,7 @@ func initialAIAnchorRun(run aistream.Run) aistream.Run { return run } -func (dc *DummyClient) queueAIRunStreamAndMetadata(portal *bridgev2.Portal, messageID networkid.MessageID, run aistream.Run) { +func (dc *DummyClient) queueAIRunStreamAndMetadata(portal *bridgev2.Portal, sender networkid.UserID, messageID networkid.MessageID, run aistream.Run) { targetEventID := dc.waitForMessageMXID(portal, messageID, 30*time.Second) if targetEventID == "" { log.Warn(). @@ -507,7 +546,7 @@ func (dc *DummyClient) queueAIRunStreamAndMetadata(portal *bridgev2.Portal, mess Msg("Timed out waiting for AI anchor Matrix event") return } - carriers, err := dc.queueAICarriers(portal, targetEventID, run, 1) + carriers, err := dc.queueAICarriers(portal, sender, targetEventID, run, 1) if err != nil { log.Warn().Err(err).Str("run_id", run.RunID).Msg("Failed to pack AI stream") return @@ -515,21 +554,21 @@ func (dc *DummyClient) queueAIRunStreamAndMetadata(portal *bridgev2.Portal, mess nextSeq := aistream.NextSeq(carriers) for i, prompt := range run.Prompts { prompt.SeqStart = nextSeq + i*10 - dc.queueAIApprovalPrompt(portal, run, prompt, targetEventID, time.Now()) + dc.queueAIApprovalPrompt(portal, sender, run, prompt, targetEventID, time.Now()) } if run.Status.State != "streaming" { - dc.queueAIRunFinalMetadata(portal, messageID, run) + dc.queueAIRunFinalMetadata(portal, sender, messageID, run) } } -func (dc *DummyClient) queueAICarriers(portal *bridgev2.Portal, targetEventID id.EventID, run aistream.Run, startSeq int) ([]aistream.Carrier, error) { +func (dc *DummyClient) queueAICarriers(portal *bridgev2.Portal, sender networkid.UserID, targetEventID id.EventID, run aistream.Run, startSeq int) ([]aistream.Carrier, error) { carriers, err := aistream.PackRunFromSeq(run, string(targetEventID), aistream.CarrierBudgetBytes, startSeq) if err != nil { return nil, err } for i, carrier := range carriers { now := time.Now() - dc.UserLogin.QueueRemoteEvent(aibridgev2.Carrier(portal.PortalKey, aiGhostID, run, carrier, targetEventID, startSeq+i, now)) + dc.UserLogin.QueueRemoteEvent(aibridgev2.Carrier(portal.PortalKey, sender, run, carrier, targetEventID, startSeq+i, now)) } return carriers, nil } @@ -580,7 +619,7 @@ func (dc *DummyClient) lookupMessageMXID(ctx context.Context, receiver networkid return message.MXID } -func (dc *DummyClient) queueAIApprovalPrompt(portal *bridgev2.Portal, run aistream.Run, prompt aistream.ApprovalPrompt, targetEventID id.EventID, timestamp time.Time) { +func (dc *DummyClient) queueAIApprovalPrompt(portal *bridgev2.Portal, sender networkid.UserID, run aistream.Run, prompt aistream.ApprovalPrompt, targetEventID id.EventID, timestamp time.Time) { reactions := aistream.DefaultApprovalOptions(prompt.ID) approvalCtx := aistream.ApprovalContext{ ID: prompt.ID, @@ -597,11 +636,11 @@ func (dc *DummyClient) queueAIApprovalPrompt(portal *bridgev2.Portal, run aistre PreviewText: run.Preview.Text, PreviewTruncated: run.Preview.Truncated, } - dc.UserLogin.QueueRemoteEvent(aibridgev2.ApprovalPrompt(portal.PortalKey, aiGhostID, approvalCtx, timestamp)) + dc.UserLogin.QueueRemoteEvent(aibridgev2.ApprovalPrompt(portal.PortalKey, sender, approvalCtx, timestamp)) for i, reaction := range reactions { reaction := reaction - dc.UserLogin.QueueRemoteEvent(aibridgev2.ApprovalOptionReaction(portal.PortalKey, aiGhostID, approvalCtx, reaction, timestamp.Add(time.Duration(i+1)*time.Millisecond))) + dc.UserLogin.QueueRemoteEvent(aibridgev2.ApprovalOptionReaction(portal.PortalKey, sender, approvalCtx, reaction, timestamp.Add(time.Duration(i+1)*time.Millisecond))) } } @@ -621,11 +660,15 @@ func (dc *DummyClient) queueAIApprovalResponse(ctx context.Context, portal *brid log.Warn().Str("approval_id", approvalCtx.ID).Msg("Missing AI approval target event") return } - if _, err := dc.queueAICarriers(portal, targetEventID, run, approvalCtx.SeqStart); err != nil { + sender := networkid.UserID(approvalCtx.AgentID) + if sender == "" { + sender = dummyAISenderForPortal(portal) + } + if _, err := dc.queueAICarriers(portal, sender, targetEventID, run, approvalCtx.SeqStart); err != nil { log.Warn().Err(err).Str("approval_id", approvalCtx.ID).Msg("Failed to queue AI approval response") return } - dc.queueAIRunFinalMetadata(portal, networkid.MessageID(approvalCtx.MessageID), run) + dc.queueAIRunFinalMetadata(portal, sender, networkid.MessageID(approvalCtx.MessageID), run) } func (dc *DummyClient) approvalContextForMessage(ctx context.Context, portal *bridgev2.Portal, message *database.Message) (aistream.ApprovalContext, bool) { @@ -727,8 +770,8 @@ func validApprovalContext(ctx aistream.ApprovalContext) (aistream.ApprovalContex return ctx, true } -func (dc *DummyClient) queueAIRunFinalMetadata(portal *bridgev2.Portal, messageID networkid.MessageID, run aistream.Run) { - dc.UserLogin.QueueRemoteEvent(aibridgev2.FinalMetadataEdit(portal.PortalKey, aiGhostID, messageID, run, time.Now())) +func (dc *DummyClient) queueAIRunFinalMetadata(portal *bridgev2.Portal, sender networkid.UserID, messageID networkid.MessageID, run aistream.Run) { + dc.UserLogin.QueueRemoteEvent(aibridgev2.FinalMetadataEdit(portal.PortalKey, sender, messageID, run, time.Now())) } func (dc *DummyClient) HandleMatrixDeleteChat(ctx context.Context, msg *bridgev2.MatrixDeleteChat) error { diff --git a/pkg/connector/client_test.go b/pkg/connector/client_test.go index da31658..ba56d5a 100644 --- a/pkg/connector/client_test.go +++ b/pkg/connector/client_test.go @@ -8,6 +8,7 @@ import ( "github.com/beeper/dummybridge/pkg/ag-ui" "github.com/beeper/dummybridge/pkg/ai-stream" + "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/bridgev2/networkid" "maunium.net/go/mautrix/event" @@ -44,6 +45,42 @@ func TestGetRemoteEchoBehavior(t *testing.T) { } } +func TestAIDemoCommandContentOnlyMatchesExplicitDemoCommands(t *testing.T) { + for _, body := range []string{ + "help", + "/help", + "!help", + "dummybridge help", + "stream-lorem 100", + "stream-tools 100 shell", + "stream-random 1", + "stream-chaos 2 1", + } { + if !isAIDemoCommandContent(&event.MessageEventContent{Body: body}) { + t.Fatalf("expected AI demo command for %q", body) + } + } + for _, body := range []string{ + "", + "hello", + "dummybridge", + "remote-echo delay 1s", + } { + if isAIDemoCommandContent(&event.MessageEventContent{Body: body}) { + t.Fatalf("did not expect AI demo command for %q", body) + } + } +} + +func TestDummyAISenderForPortalSupportsDedicatedAndNormalRooms(t *testing.T) { + if got := dummyAISenderForPortal(&bridgev2.Portal{Portal: &database.Portal{PortalKey: networkid.PortalKey{ID: "ai-room"}}}); got != aiGhostID { + t.Fatalf("AI portal sender = %q, want %q", got, aiGhostID) + } + if got := dummyAISenderForPortal(&bridgev2.Portal{Portal: &database.Portal{PortalKey: networkid.PortalKey{ID: "normal-room"}}}); got != stablePortalUserIDByIndex("normal-room", 0) { + t.Fatalf("normal portal sender = %q", got) + } +} + func TestResolveApprovalOnceKeepsFirstSelection(t *testing.T) { client := &DummyClient{} selected, first := client.resolveApprovalOnce("approval-1", "allow") From f8212e8008645ec8fd368668372e845ef12fa25a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sat, 23 May 2026 21:10:32 +0200 Subject: [PATCH 26/46] wip --- pkg/ai-stream/approval.go | 277 +++++++++++++++----------- pkg/ai-stream/bridgev2/events.go | 17 +- pkg/ai-stream/matrix/content.go | 71 ++++--- pkg/ai-stream/matrix/content_test.go | 60 ++++-- pkg/ai-stream/run.go | 223 ++++++++++++++++----- pkg/ai-stream/stream_test.go | 198 ++++++++++--------- pkg/connector/ai_runtime.go | 31 ++- pkg/connector/ai_runtime_test.go | 249 ++++++++++++++++++++++++ pkg/connector/client.go | 279 +++++++++++++++++++++++++-- pkg/connector/client_test.go | 81 ++++++++ pkg/connector/connector.go | 2 +- 11 files changed, 1152 insertions(+), 336 deletions(-) diff --git a/pkg/ai-stream/approval.go b/pkg/ai-stream/approval.go index f647d01..70a20e8 100644 --- a/pkg/ai-stream/approval.go +++ b/pkg/ai-stream/approval.go @@ -2,26 +2,25 @@ package aistream import ( "strings" - "time" "github.com/beeper/dummybridge/pkg/ag-ui" ) const ( - ApprovalReactionAllowOnce = "approval.allow_once" - ApprovalReactionAllowAlways = "approval.allow_always" - ApprovalReactionDeny = "approval.deny" + ApprovalChoiceApprove = "approve" + ApprovalChoiceAlwaysApprove = "always_approve" + ApprovalChoiceDeny = "deny" ) -type ReactionOption[T any] struct { - ID string `json:"id"` - Label string `json:"label"` - Values []string `json:"values"` - Value T `json:"value"` +type ApprovalChoice struct { + Key string `json:"key"` + Label string `json:"label"` + Alias string `json:"alias"` + Style string `json:"style,omitempty"` } -type ApprovalCleanup[T any] struct { - Selected ReactionOption[T] +type ApprovalCleanup struct { + Selected ApprovalChoice SelectedReactionEvent string RedactReactionEvents []string Matched bool @@ -39,6 +38,7 @@ type ApprovalContext struct { ThreadID string `json:"threadId"` RunID string `json:"runId"` MessageID string `json:"messageId"` + Command string `json:"command"` ToolCallID string `json:"toolCallId"` ToolName string `json:"toolName"` TargetEvent string `json:"target_event"` @@ -50,63 +50,189 @@ type ApprovalContext struct { PreviewTruncated bool `json:"previewTruncated,omitempty"` } -func DefaultApprovalOptions(approvalID string) []ReactionOption[agui.ToolApprovalResponse] { - return []ReactionOption[agui.ToolApprovalResponse]{ +type ApprovalRequestedValue struct { + ThreadID string + RunID string + MessageID string + ToolCallID string + ToolName string + Input any + Approval agui.ToolApproval + ApprovalMessageID string + ApprovalEventID string + Choices []ApprovalChoice +} + +type ApprovalNotice struct { + Schema string + ID string + MessageID string + ToolCallID string + ToolName string + State string + Choices []ApprovalChoice +} + +func NewApprovalRequestedValue(run Run, toolCallID, toolName string, input any, approval agui.ToolApproval) ApprovalRequestedValue { + return ApprovalRequestedValue{ + ThreadID: run.ThreadID, + RunID: run.RunID, + MessageID: run.MessageID, + ToolCallID: toolCallID, + ToolName: toolName, + Input: input, + Approval: approval, + ApprovalMessageID: approval.ID, + Choices: DefaultApprovalChoices(), + } +} + +func NewApprovalNotice(ctx ApprovalContext, choices []ApprovalChoice) ApprovalNotice { + return ApprovalNotice{ + Schema: "com.beeper.ai.approval.v1", + ID: ctx.ID, + MessageID: ctx.MessageID, + ToolCallID: ctx.ToolCallID, + ToolName: ctx.ToolName, + State: "requested", + Choices: choices, + } +} + +func (v ApprovalRequestedValue) Map() map[string]any { + value := map[string]any{ + "threadId": v.ThreadID, + "runId": v.RunID, + "messageId": v.MessageID, + "toolCallId": v.ToolCallID, + "toolName": v.ToolName, + "input": v.Input, + "approval": v.Approval, + "approvalMessageId": v.ApprovalMessageID, + "choices": v.Choices, + } + if v.ApprovalEventID != "" { + value["approvalEventId"] = v.ApprovalEventID + } + return value +} + +func (n ApprovalNotice) Map() map[string]any { + return map[string]any{ + "schema": n.Schema, + "id": n.ID, + "messageId": n.MessageID, + "toolCallId": n.ToolCallID, + "toolName": n.ToolName, + "state": n.State, + "choices": ApprovalChoicesAsAny(n.Choices), + } +} + +func ApprovalChoicesAsAny(choices []ApprovalChoice) []any { + out := make([]any, 0, len(choices)) + for _, choice := range choices { + item := map[string]any{ + "key": choice.Key, + "label": choice.Label, + "alias": choice.Alias, + } + if choice.Style != "" { + item["style"] = choice.Style + } + out = append(out, item) + } + return out +} + +func ApprovalIDFromRequestedValue(value map[string]any) string { + approval, _ := value["approval"].(agui.ToolApproval) + if approval.ID != "" { + return approval.ID + } + if raw, ok := value["approval"].(map[string]any); ok { + approvalID, _ := raw["id"].(string) + return approvalID + } + return "" +} + +func SetApprovalRequestedEventID(value map[string]any, eventID string) bool { + if value == nil || eventID == "" { + return false + } + approvalID := ApprovalIDFromRequestedValue(value) + if approvalID == "" { + return false + } + value["approvalMessageId"] = approvalID + value["approvalEventId"] = eventID + return true +} + +func DefaultApprovalChoices() []ApprovalChoice { + return []ApprovalChoice{ { - ID: ApprovalReactionAllowOnce, - Label: "Allow", - Values: []string{"👍", "approval.allow_once", "allow", "allow_once"}, - Value: agui.ToolApprovalResponse{ID: approvalID, Approved: true}, + Key: ApprovalChoiceApprove, + Label: "Approve", + Alias: "✅", }, { - ID: ApprovalReactionAllowAlways, - Label: "Always allow", - Values: []string{"✅", "approval.allow_always", "always", "allow_always"}, - Value: agui.ToolApprovalResponse{ID: approvalID, Approved: true, Always: true}, + Key: ApprovalChoiceAlwaysApprove, + Label: "Always approve", + Alias: "☑️", }, { - ID: ApprovalReactionDeny, - Label: "Deny", - Values: []string{"👎", "approval.deny", "deny", "reject"}, - Value: agui.ToolApprovalResponse{ID: approvalID, Approved: false, Reason: "denied"}, + Key: ApprovalChoiceDeny, + Label: "Deny", + Alias: "❌", + Style: "danger", }, } } -func ResolveReaction[T any](options []ReactionOption[T], raw string) (ReactionOption[T], bool) { +func ResolveApprovalChoice(choices []ApprovalChoice, raw string) (ApprovalChoice, bool) { key := NormalizeReaction(raw) - for _, option := range options { - if NormalizeReaction(option.ID) == key { - return option, true - } - for _, value := range option.Values { - if NormalizeReaction(value) == key { - return option, true - } + for _, choice := range choices { + if NormalizeReaction(choice.Key) == key || NormalizeReaction(choice.Alias) == key { + return choice, true } } - var zero ReactionOption[T] + var zero ApprovalChoice return zero, false } -func CleanupReactions[T any](options []ReactionOption[T], selectedKey string, events []ReactionEvent, bridgeSender string) ApprovalCleanup[T] { - selected, ok := ResolveReaction(options, selectedKey) +func ApprovalResponseForChoice(approvalID string, choice ApprovalChoice) agui.ToolApprovalResponse { + switch choice.Key { + case ApprovalChoiceApprove: + return agui.ToolApprovalResponse{ID: approvalID, Approved: true} + case ApprovalChoiceAlwaysApprove: + return agui.ToolApprovalResponse{ID: approvalID, Approved: true, Always: true} + case ApprovalChoiceDeny: + return agui.ToolApprovalResponse{ID: approvalID, Approved: false, Reason: "denied"} + default: + return agui.ToolApprovalResponse{ID: approvalID, Approved: false, Reason: "invalid approval choice"} + } +} + +func CleanupApprovalReactions(choices []ApprovalChoice, selectedKey string, events []ReactionEvent, bridgeSender string) ApprovalCleanup { + selected, ok := ResolveApprovalChoice(choices, selectedKey) if !ok { - return ApprovalCleanup[T]{} + return ApprovalCleanup{} } - cleanup := ApprovalCleanup[T]{Selected: selected, Matched: true} + cleanup := ApprovalCleanup{Selected: selected, Matched: true} for _, evt := range events { if evt.EventID == "" { continue } - option, matchesOption := ResolveReaction(options, evt.Key) - isSelected := matchesOption && option.ID == selected.ID + choice, matchesChoice := ResolveApprovalChoice(choices, evt.Key) + isSelected := matchesChoice && choice.Key == selected.Key isBridge := evt.Bridge || (bridgeSender != "" && evt.Sender == bridgeSender) if isSelected && !isBridge && cleanup.SelectedReactionEvent == "" { cleanup.SelectedReactionEvent = evt.EventID continue } - if isBridge || (matchesOption && !isSelected) { + if isBridge || (matchesChoice && !isSelected) { cleanup.RedactReactionEvents = append(cleanup.RedactReactionEvents, evt.EventID) } } @@ -119,73 +245,6 @@ func NormalizeReaction(reaction string) string { return strings.ToLower(reaction) } -func ApprovalResponseRun(ctx ApprovalContext, response agui.ToolApprovalResponse, now time.Time) Run { - if response.ID == "" { - response.ID = ctx.ID - } - agentID := ctx.AgentID - if agentID == "" { - agentID = "ai" - } - agentName := ctx.AgentName - if agentName == "" { - agentName = "AI" - } - model := ctx.Model - if model == "" { - model = DefaultModel - } - run := NewRun("approval-"+ctx.ID, ctx.ThreadID, model, agentID, agentName, now) - run.RunID = ctx.RunID - run.MessageID = ctx.MessageID - run.ToolCallID = ctx.ToolCallID - run.ApprovalID = ctx.ID - run.Status = Status{State: "complete"} - run.Preview = Preview{Text: ctx.PreviewText, Truncated: ctx.PreviewTruncated} - run.Approvals = []ApprovalSummary{{ - ID: ctx.ID, - ToolCallID: ctx.ToolCallID, - State: approvalSummaryState(response), - Always: response.Always, - Reason: response.Reason, - Fields: response.Fields, - Metadata: response.Metadata, - }} - builder := agui.NewEventBuilder(model, func() time.Time { return now }) - run.Events = append(run.Events, builder.Custom(agui.ApprovalCustomResponded, map[string]any{ - "threadId": ctx.ThreadID, - "runId": ctx.RunID, - "messageId": ctx.MessageID, - "toolCallId": ctx.ToolCallID, - "toolName": ctx.ToolName, - "approval": response, - })) - result := map[string]any{ - "approvalId": response.ID, - "always": response.Always, - } - if response.Fields != nil { - result["fields"] = response.Fields - } - if response.Metadata != nil { - result["metadata"] = response.Metadata - } - if response.Approved { - result["state"] = agui.ToolResultStateComplete - result["approved"] = true - } else { - reason := response.Reason - if reason == "" { - reason = "denied" - } - result["state"] = agui.ToolResultStateError - result["reason"] = reason - run.Status = Status{State: "error", Error: result} - } - run.Events = append(run.Events, builder.ToolCallEnd(ctx.ToolCallID, ctx.ToolName, nil, jsonString(result), agui.ToolStateApprovalResponded)) - return *run -} - func approvalSummaryState(response agui.ToolApprovalResponse) string { if response.Approved { if response.Always { diff --git a/pkg/ai-stream/bridgev2/events.go b/pkg/ai-stream/bridgev2/events.go index 5381625..05cbaa5 100644 --- a/pkg/ai-stream/bridgev2/events.go +++ b/pkg/ai-stream/bridgev2/events.go @@ -53,7 +53,7 @@ func Carrier(portalKey networkid.PortalKey, sender networkid.UserID, run aistrea } func ApprovalPrompt(portalKey networkid.PortalKey, sender networkid.UserID, ctx aistream.ApprovalContext, timestamp time.Time) *simplevent.PreConvertedMessage { - content, extra := aimatrix.ApprovalContent(ctx, aistream.DefaultApprovalOptions(ctx.ID)) + content, extra := aimatrix.ApprovalContent(ctx, aistream.DefaultApprovalChoices()) return &simplevent.PreConvertedMessage{ EventMeta: eventMeta(bridgev2.RemoteEventMessage, portalKey, sender, timestamp), Data: &bridgev2.ConvertedMessage{Parts: []*bridgev2.ConvertedMessagePart{ @@ -63,29 +63,24 @@ func ApprovalPrompt(portalKey networkid.PortalKey, sender networkid.UserID, ctx } } -func ApprovalOptionReaction[T any](portalKey networkid.PortalKey, sender networkid.UserID, ctx aistream.ApprovalContext, option aistream.ReactionOption[T], timestamp time.Time) *simplevent.Reaction { - emoji := option.ID - if len(option.Values) > 0 { - emoji = option.Values[0] - } +func ApprovalOptionReaction(portalKey networkid.PortalKey, sender networkid.UserID, ctx aistream.ApprovalContext, choice aistream.ApprovalChoice, timestamp time.Time) *simplevent.Reaction { return &simplevent.Reaction{ EventMeta: eventMeta(bridgev2.RemoteEventReaction, portalKey, sender, timestamp), TargetMessage: networkid.MessageID(ctx.ID), - EmojiID: networkid.EmojiID(option.ID), - Emoji: emoji, + EmojiID: networkid.EmojiID(choice.Key), + Emoji: choice.Alias, ExtraContent: map[string]any{ "com.beeper.ai.approval_option": map[string]any{ "approvalId": ctx.ID, "toolCallId": ctx.ToolCallID, - "optionId": option.ID, - "value": option.Value, + "choice": choice.Key, }, }, } } func FinalMetadataEdit(portalKey networkid.PortalKey, sender networkid.UserID, messageID networkid.MessageID, run aistream.Run, timestamp time.Time) *simplevent.Message[*aistream.Run] { - finalContent, finalExtra := aimatrix.AnchorContent(run) + finalContent, finalExtra := aimatrix.FinalContent(run) return &simplevent.Message[*aistream.Run]{ EventMeta: eventMeta(bridgev2.RemoteEventEdit, portalKey, sender, timestamp), Data: &run, diff --git a/pkg/ai-stream/matrix/content.go b/pkg/ai-stream/matrix/content.go index 7a547d5..3a3c0e5 100644 --- a/pkg/ai-stream/matrix/content.go +++ b/pkg/ai-stream/matrix/content.go @@ -3,14 +3,39 @@ package matrix import ( "fmt" - "github.com/beeper/dummybridge/pkg/ag-ui" "github.com/beeper/dummybridge/pkg/ai-stream" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/format" "maunium.net/go/mautrix/id" ) +const ApprovalRelationType = event.RelationType("com.beeper.ai.approval") + func AnchorContent(run aistream.Run) (*event.MessageEventContent, map[string]any) { + content := previewContent(run) + extra := map[string]any{ + aistream.BeeperAIKey: run.InitialUIMessage(), + aistream.BeeperAIMetadataKey: run.Metadata(), + "com.beeper.stream": map[string]any{ + "type": aistream.BeeperAIStreamDeltas, + }, + } + return content, extra +} + +func FinalContent(run aistream.Run) (*event.MessageEventContent, map[string]any) { + content := previewContent(run) + extra := map[string]any{ + aistream.BeeperAIKey: run.FinalUIMessage(aistream.SnapshotTextBytes, true), + aistream.BeeperAIMetadataKey: run.Metadata(), + "com.beeper.stream": map[string]any{ + "type": aistream.BeeperAIStreamDeltas, + }, + } + return content, extra +} + +func previewContent(run aistream.Run) *event.MessageEventContent { body := run.Preview.Text if body == "" { body = "..." @@ -22,14 +47,7 @@ func AnchorContent(run aistream.Run) (*event.MessageEventContent, map[string]any ID: run.AgentID, Displayname: run.AgentName, } - extra := map[string]any{ - aistream.BeeperAIKey: run.InitialUIMessage(), - aistream.BeeperAIMetadataKey: run.Metadata(), - "com.beeper.stream": map[string]any{ - "type": aistream.BeeperAIStreamDeltas, - }, - } - return content, extra + return content } func CarrierContent(carrier aistream.Carrier, targetEventID id.EventID) (*event.MessageEventContent, map[string]any) { @@ -38,43 +56,22 @@ func CarrierContent(carrier aistream.Carrier, targetEventID id.EventID) (*event. return &content, aistream.CarrierContent(carrier.Envelopes) } -func ApprovalContent(ctx aistream.ApprovalContext, options []aistream.ReactionOption[agui.ToolApprovalResponse]) (*event.MessageEventContent, map[string]any) { +func ApprovalContent(ctx aistream.ApprovalContext, choices []aistream.ApprovalChoice) (*event.MessageEventContent, map[string]any) { toolName := ctx.ToolName body := fmt.Sprintf("Approval required for %s", toolName) - if len(options) > 0 { - body += "\nReact with one of the listed options." + if len(choices) > 0 { + body += "\nReact with one of the listed choices." } content := format.TextToContent(body) if ctx.TargetEvent != "" { - content.SetRelatesTo(&event.RelatesTo{Type: event.RelReference, EventID: id.EventID(ctx.TargetEvent)}) + content.SetRelatesTo(&event.RelatesTo{Type: ApprovalRelationType, EventID: id.EventID(ctx.TargetEvent)}) } extra := map[string]any{ - "com.beeper.ai.approval": map[string]any{ - "id": ctx.ID, - "toolCallId": ctx.ToolCallID, - "toolName": toolName, - "threadId": ctx.ThreadID, - "runId": ctx.RunID, - "messageId": ctx.MessageID, - "approval": agui.ToolApproval{ - ID: ctx.ID, - NeedsApproval: true, - }, - "reactions": ReactionOptionsAsAny(options), - }, + "com.beeper.ai.approval": aistream.NewApprovalNotice(ctx, choices).Map(), } return &content, extra } -func ReactionOptionsAsAny(options []aistream.ReactionOption[agui.ToolApprovalResponse]) []any { - out := make([]any, 0, len(options)) - for _, option := range options { - out = append(out, map[string]any{ - "id": option.ID, - "label": option.Label, - "values": option.Values, - "value": option.Value, - }) - } - return out +func ApprovalChoicesAsAny(choices []aistream.ApprovalChoice) []any { + return aistream.ApprovalChoicesAsAny(choices) } diff --git a/pkg/ai-stream/matrix/content_test.go b/pkg/ai-stream/matrix/content_test.go index 467c59a..ca656ec 100644 --- a/pkg/ai-stream/matrix/content_test.go +++ b/pkg/ai-stream/matrix/content_test.go @@ -25,10 +25,13 @@ func TestAnchorContentUsesVisibleTextAndAIProfile(t *testing.T) { if content.BeeperPerMessageProfile == nil || content.BeeperPerMessageProfile.ID != "ai" || content.BeeperPerMessageProfile.Displayname != "AI" { t.Fatalf("missing AI per-message profile: %#v", content.BeeperPerMessageProfile) } - uiMessage, ok := extra[aistream.BeeperAIKey].(map[string]any) - if !ok || uiMessage["id"] == "" || uiMessage["metadata"] != nil { + uiMessage, ok := extra[aistream.BeeperAIKey].(agui.UIMessage) + if !ok || uiMessage.ID == "" || uiMessage.Metadata == nil || len(uiMessage.Parts) != 1 { t.Fatalf("bad compact AI message: %#v", extra[aistream.BeeperAIKey]) } + if uiMessage.Parts[0]["type"] != "text" || uiMessage.Parts[0]["content"] != "visible preview" { + t.Fatalf("anchor AI message should include preview text part: %#v", uiMessage.Parts) + } if extra[aistream.BeeperAIMetadataKey] == nil { t.Fatalf("missing AI metadata: %#v", extra) } @@ -75,6 +78,34 @@ func TestAnchorContentRendersFinalPreviewAsMatrixHTML(t *testing.T) { } } +func TestFinalContentIncludesFinalUIParts(t *testing.T) { + run := aistream.NewRun("run-1", "thread-1", aistream.DefaultModel, "ai", "AI", time.Unix(10, 0)) + writer := aistream.NewWriter(run, func() time.Time { return time.Unix(10, 0) }) + writer.Start() + writer.Thinking("hidden reasoning") + writer.Text("final **preview**") + writer.Finish(agui.FinishReasonStop) + + content, extra := FinalContent(*run) + if content.Body != "final **preview**" || content.Format != event.FormatHTML { + t.Fatalf("bad final preview content: %#v", content) + } + uiMessage, ok := extra[aistream.BeeperAIKey].(agui.UIMessage) + if !ok || len(uiMessage.Parts) != 2 || uiMessage.Parts[0]["type"] != "thinking" || uiMessage.Parts[1]["type"] != "text" { + t.Fatalf("final edit must include concrete UI parts: %#v", extra[aistream.BeeperAIKey]) + } + if uiMessage.Parts[0]["content"] != "hidden reasoning" || uiMessage.Parts[1]["content"] == "" { + t.Fatalf("final edit must preserve reasoning and text parts: %#v", uiMessage.Parts) + } + if extra[aistream.BeeperAIMetadataKey] == nil { + t.Fatalf("missing final metadata: %#v", extra) + } + stream, ok := extra["com.beeper.stream"].(map[string]any) + if !ok || stream["type"] != aistream.BeeperAIStreamDeltas { + t.Fatalf("missing final stream descriptor: %#v", extra["com.beeper.stream"]) + } +} + func TestCarrierContentIsHiddenTextCarrierWithDeltas(t *testing.T) { carrier := aistream.Carrier{Envelopes: []aistream.Envelope{{ ThreadID: "thread-1", @@ -97,7 +128,7 @@ func TestCarrierContentIsHiddenTextCarrierWithDeltas(t *testing.T) { } } -func TestApprovalContentIncludesContextAndGenericReactionOptions(t *testing.T) { +func TestApprovalContentIncludesContextAndChoices(t *testing.T) { ctx := aistream.ApprovalContext{ ID: "approval-1", ThreadID: "thread-1", @@ -107,25 +138,28 @@ func TestApprovalContentIncludesContextAndGenericReactionOptions(t *testing.T) { ToolName: "shell", TargetEvent: "$anchor", } - options := aistream.DefaultApprovalOptions(ctx.ID) + choices := aistream.DefaultApprovalChoices() - content, extra := ApprovalContent(ctx, options) - if content.MsgType != event.MsgText || content.RelatesTo == nil || content.RelatesTo.EventID != "$anchor" { + content, extra := ApprovalContent(ctx, choices) + if content.MsgType != event.MsgText || content.RelatesTo == nil || content.RelatesTo.EventID != "$anchor" || content.RelatesTo.Type != ApprovalRelationType { t.Fatalf("bad approval content: %#v", content) } meta, ok := extra["com.beeper.ai.approval"].(map[string]any) if !ok { t.Fatalf("missing approval metadata: %#v", extra) } - if meta["id"] != ctx.ID || meta["runId"] != ctx.RunID || meta["messageId"] != ctx.MessageID || meta["toolCallId"] != ctx.ToolCallID { + if meta["schema"] != "com.beeper.ai.approval.v1" || meta["id"] != ctx.ID || meta["messageId"] != ctx.MessageID || meta["toolCallId"] != ctx.ToolCallID || meta["state"] != "requested" { t.Fatalf("bad approval metadata: %#v", meta) } - reactions, ok := meta["reactions"].([]any) - if !ok || len(reactions) != len(options) { - t.Fatalf("bad approval reactions: %#v", meta["reactions"]) + if _, ok := meta["runId"]; ok { + t.Fatalf("approval event should not duplicate run metadata: %#v", meta) + } + approvalChoices, ok := meta["choices"].([]any) + if !ok || len(approvalChoices) != len(choices) { + t.Fatalf("bad approval choices: %#v", meta["choices"]) } - first := reactions[0].(map[string]any) - if first["id"] != aistream.ApprovalReactionAllowOnce { - t.Fatalf("bad first reaction option: %#v", first) + first := approvalChoices[0].(map[string]any) + if first["key"] != aistream.ApprovalChoiceApprove || first["alias"] != "✅" { + t.Fatalf("bad first approval choice: %#v", first) } } diff --git a/pkg/ai-stream/run.go b/pkg/ai-stream/run.go index 1225260..3c77572 100644 --- a/pkg/ai-stream/run.go +++ b/pkg/ai-stream/run.go @@ -53,6 +53,68 @@ type Preview struct { Truncated bool `json:"truncated"` } +type UIMessageMetadata struct { + ThreadID string `json:"threadId"` + RunID string `json:"runId"` + Status Status `json:"status"` + Usage *agui.Usage `json:"usage,omitempty"` +} + +func (m UIMessageMetadata) Map() map[string]any { + out := map[string]any{ + "threadId": m.ThreadID, + "runId": m.RunID, + "status": m.Status, + } + if m.Usage != nil { + out["usage"] = *m.Usage + } + return out +} + +type RunMetadata struct { + Schema string + Protocol string + ThreadID string + RunID string + MessageID string + AgentID string + AgentName string + Model string + Usage agui.Usage + Status Status + Approvals []ApprovalSummary + Artifacts ArtifactSummary + Data map[string]any + Preview Preview +} + +func (m RunMetadata) Map() map[string]any { + return map[string]any{ + "schema": m.Schema, + "protocol": m.Protocol, + "threadId": m.ThreadID, + "runId": m.RunID, + "messageId": m.MessageID, + "agent": map[string]any{ + "id": m.AgentID, + "displayName": m.AgentName, + }, + "model": m.Model, + "usage": map[string]any{ + "promptTokens": m.Usage.PromptTokens, + "completionTokens": m.Usage.CompletionTokens, + "totalTokens": m.Usage.TotalTokens, + }, + "usageDetails": map[string]any{}, + "status": m.Status, + "approvals": m.Approvals, + "artifacts": m.Artifacts, + "data": m.Data, + "preview": m.Preview, + } +} + type ApprovalSummary struct { ID string `json:"id"` ToolCallID string `json:"toolCallId"` @@ -169,15 +231,10 @@ func (w *Writer) ToolStart(toolCallID, name string, index int, approval *agui.To func (w *Writer) ToolApprovalRequested(toolCallID, name string, input any, approval agui.ToolApproval) { w.recordApprovalRequest(toolCallID, name, &approval) - w.Add(w.builder.Custom(agui.ApprovalCustomRequested, map[string]any{ - "threadId": w.Run.ThreadID, - "runId": w.Run.RunID, - "messageId": w.Run.MessageID, - "toolCallId": toolCallID, - "toolName": name, - "input": input, - "approval": approval, - })) + w.Add(w.builder.Custom( + agui.ApprovalCustomRequested, + NewApprovalRequestedValue(*w.Run, toolCallID, name, input, approval).Map(), + )) } func (w *Writer) recordApprovalRequest(toolCallID, name string, approval *agui.ToolApproval) { @@ -211,6 +268,48 @@ func (w *Writer) ToolApprovalInputComplete(toolCallID, name string, input any) { w.Add(w.builder.ToolCallEnd(toolCallID, name, input, nil, agui.ToolStateApprovalRequested)) } +func (w *Writer) ToolApprovalResponded(toolCallID, name string, input any, response agui.ToolApprovalResponse) { + for i := range w.Run.Approvals { + if w.Run.Approvals[i].ID == response.ID { + w.Run.Approvals[i].State = approvalSummaryState(response) + w.Run.Approvals[i].Always = response.Always + w.Run.Approvals[i].Reason = response.Reason + w.Run.Approvals[i].Fields = response.Fields + w.Run.Approvals[i].Metadata = response.Metadata + } + } + w.Add(w.builder.Custom(agui.ApprovalCustomResponded, map[string]any{ + "threadId": w.Run.ThreadID, + "runId": w.Run.RunID, + "messageId": w.Run.MessageID, + "toolCallId": toolCallID, + "toolName": name, + "approval": response, + })) + result := map[string]any{ + "approvalId": response.ID, + "always": response.Always, + } + if response.Fields != nil { + result["fields"] = response.Fields + } + if response.Metadata != nil { + result["metadata"] = response.Metadata + } + if response.Approved { + result["state"] = agui.ToolResultStateComplete + result["approved"] = true + } else { + reason := response.Reason + if reason == "" { + reason = "denied" + } + result["state"] = agui.ToolResultStateError + result["reason"] = reason + } + w.Add(w.builder.ToolCallEnd(toolCallID, name, input, jsonString(result), agui.ToolStateApprovalResponded)) +} + func (w *Writer) ToolResult(toolCallID, content, state string) { w.Add(w.builder.ToolCallResult(w.Run.MessageID, toolCallID, content, state, agui.RoleTool)) } @@ -304,7 +403,7 @@ func (w *Writer) addFinalSnapshot() { if w == nil || w.Run == nil { return } - w.MessagesSnapshot([]agui.UIMessage{w.Run.FinalUIMessageSnapshot(0)}) + w.MessagesSnapshot([]agui.UIMessage{w.Run.FinalUIMessage(0, true)}) } func (w *Writer) finishReasoning() { @@ -352,16 +451,11 @@ func (t Run) Text() string { return out.String() } -func (t Run) FinalUIMessageSnapshot(textBudget int) agui.UIMessage { +func (t Run) FinalUIMessage(textBudget int, includeThinking bool) agui.UIMessage { message := agui.UIMessage{ - ID: t.MessageID, - Role: agui.RoleAssistant, - Metadata: map[string]any{ - "threadId": t.ThreadID, - "runId": t.RunID, - "status": t.Status, - "usage": t.Usage, - }, + ID: t.MessageID, + Role: agui.RoleAssistant, + Metadata: t.UIMessageMetadata(true).Map(), } var textPart agui.MessagePart var thinkingPart agui.MessagePart @@ -393,6 +487,9 @@ func (t Run) FinalUIMessageSnapshot(textBudget int) agui.UIMessage { if delta == "" { continue } + if !includeThinking { + continue + } if thinkingPart == nil { thinkingPart = appendPart(agui.MessagePart{"type": "thinking", "content": "", "state": agui.PartStateStreaming}) } @@ -504,15 +601,52 @@ func (t Run) FinalUIMessageSnapshot(textBudget int) agui.UIMessage { } compactTextPart(textPart, textBudget) compactTextPart(thinkingPart, textBudget) + if len(message.Parts) > 1 { + visible := make([]agui.MessagePart, 0, len(message.Parts)) + other := make([]agui.MessagePart, 0, len(message.Parts)) + for _, part := range message.Parts { + switch part["type"] { + case "text", "thinking": + visible = append(visible, part) + default: + other = append(other, part) + } + } + if len(visible) > 0 { + message.Parts = append(visible, other...) + } + } return message } -func (t Run) InitialUIMessage() map[string]any { - return map[string]any{ - "id": t.MessageID, - "role": agui.RoleAssistant, - "parts": []any{}, +func (t Run) InitialUIMessage() agui.UIMessage { + message := agui.UIMessage{ + ID: t.MessageID, + Role: agui.RoleAssistant, + Metadata: t.UIMessageMetadata(false).Map(), + } + if t.Preview.Text != "" { + message.Parts = []agui.MessagePart{{ + "type": "text", + "content": t.Preview.Text, + "state": agui.PartStateStreaming, + }} + } else { + message.Parts = []agui.MessagePart{} + } + return message +} + +func (t Run) UIMessageMetadata(includeUsage bool) UIMessageMetadata { + metadata := UIMessageMetadata{ + ThreadID: t.ThreadID, + RunID: t.RunID, + Status: t.Status, + } + if includeUsage { + metadata.Usage = &t.Usage } + return metadata } func compactTextPart(part agui.MessagePart, budget int) { @@ -574,28 +708,25 @@ func approvalMapID(value any) string { } func (t Run) Metadata() map[string]any { - return map[string]any{ - "schema": "com.beeper.ai.run.v1", - "protocol": "ag-ui", - "threadId": t.ThreadID, - "runId": t.RunID, - "messageId": t.MessageID, - "agent": map[string]any{ - "id": t.AgentID, - "displayName": t.AgentName, - }, - "model": t.Model, - "usage": map[string]any{ - "promptTokens": t.Usage.PromptTokens, - "completionTokens": t.Usage.CompletionTokens, - "totalTokens": t.Usage.TotalTokens, - }, - "usageDetails": map[string]any{}, - "status": t.Status, - "approvals": t.Approvals, - "artifacts": t.Artifacts, - "data": t.Data, - "preview": t.Preview, + return t.RunMetadata().Map() +} + +func (t Run) RunMetadata() RunMetadata { + return RunMetadata{ + Schema: "com.beeper.ai.run.v1", + Protocol: "ag-ui", + ThreadID: t.ThreadID, + RunID: t.RunID, + MessageID: t.MessageID, + AgentID: t.AgentID, + AgentName: t.AgentName, + Model: t.Model, + Usage: t.Usage, + Status: t.Status, + Approvals: t.Approvals, + Artifacts: t.Artifacts, + Data: t.Data, + Preview: t.Preview, } } diff --git a/pkg/ai-stream/stream_test.go b/pkg/ai-stream/stream_test.go index c5e1ea8..4e0be17 100644 --- a/pkg/ai-stream/stream_test.go +++ b/pkg/ai-stream/stream_test.go @@ -69,6 +69,7 @@ func TestFinalSnapshotSplitsIntoBaseAndContinuationParts(t *testing.T) { t.Fatal(err) } var baseSnapshots, continuations int + var baseText string var reconstructedText strings.Builder var sawMetadata bool for i, carrier := range carriers { @@ -91,6 +92,11 @@ func TestFinalSnapshotSplitsIntoBaseAndContinuationParts(t *testing.T) { if ok && metadata["runId"] == "run-1" { sawMetadata = true } + for _, part := range testFinalParts(t, message["parts"]) { + if part["type"] == "text" { + baseText += part["content"].(string) + } + } case agui.EventCustom: if env.Part["name"] != FinalPartsCustomName { continue @@ -114,6 +120,9 @@ func TestFinalSnapshotSplitsIntoBaseAndContinuationParts(t *testing.T) { if baseSnapshots != 1 || continuations == 0 || !sawMetadata { t.Fatalf("expected one metadata base snapshot and continuations, base=%d continuations=%d metadata=%v", baseSnapshots, continuations, sawMetadata) } + if baseText == "" { + t.Fatal("base final snapshot must keep visible text in the primary event") + } if !strings.Contains(run.Text(), reconstructedText.String()) { t.Fatalf("unexpected continuation text reconstruction length=%d", reconstructedText.Len()) } @@ -222,125 +231,128 @@ func TestValidateRejectsLegacyOrInvalidToolResultShape(t *testing.T) { } func TestApprovalResolverMatchesEmojiKeysAndAliases(t *testing.T) { - options := DefaultApprovalOptions("approval-1") - for _, key := range []string{"👍", "approval.allow_once", "allow"} { - option, ok := ResolveReaction(options, key) - if !ok || !option.Value.Approved || option.Value.Always { - t.Fatalf("expected allow-once for %q, got %#v ok=%v", key, option, ok) + choices := DefaultApprovalChoices() + for _, key := range []string{"✅", "approve"} { + choice, ok := ResolveApprovalChoice(choices, key) + response := ApprovalResponseForChoice("approval-1", choice) + if !ok || !response.Approved || response.Always { + t.Fatalf("expected approve for %q, got %#v ok=%v", key, choice, ok) } } - option, ok := ResolveReaction(options, "always") - if !ok || !option.Value.Approved || !option.Value.Always { - t.Fatalf("expected allow-always, got %#v ok=%v", option, ok) + choice, ok := ResolveApprovalChoice(choices, "☑️") + response := ApprovalResponseForChoice("approval-1", choice) + if !ok || !response.Approved || !response.Always { + t.Fatalf("expected always-approve, got %#v ok=%v", choice, ok) } - option, ok = ResolveReaction(options, "👎") - if !ok || option.Value.Approved || option.Value.Reason != "denied" { - t.Fatalf("expected denial, got %#v ok=%v", option, ok) + choice, ok = ResolveApprovalChoice(choices, "deny") + response = ApprovalResponseForChoice("approval-1", choice) + if !ok || response.Approved || response.Reason != "denied" { + t.Fatalf("expected denial, got %#v ok=%v", choice, ok) } } -func TestCleanupKeepsSelectedUserReactionAndRemovesBridgeOptions(t *testing.T) { - options := DefaultApprovalOptions("approval-1") - cleanup := CleanupReactions(options, "👍", []ReactionEvent{ - {EventID: "$bridge-allow", Sender: "ai", Key: "👍", Bridge: true}, - {EventID: "$bridge-deny", Sender: "ai", Key: "👎", Bridge: true}, - {EventID: "$user-allow", Sender: "@user:example", Key: "👍"}, - {EventID: "$user-deny", Sender: "@user:example", Key: "👎"}, - }, "ai") - if !cleanup.Matched || cleanup.SelectedReactionEvent != "$user-allow" { - t.Fatalf("bad selected reaction: %#v", cleanup) - } - got := strings.Join(cleanup.RedactReactionEvents, ",") - if !strings.Contains(got, "$bridge-allow") || !strings.Contains(got, "$bridge-deny") || !strings.Contains(got, "$user-deny") { - t.Fatalf("bad cleanup redactions: %#v", cleanup.RedactReactionEvents) - } -} +func TestApprovalRequestedValueOwnsStreamPayloadShape(t *testing.T) { + run := NewRun("run-1", "thread-1", DefaultModel, "ai", "AI", time.Unix(10, 0)) + run.MessageID = "msg-run-1" + approval := agui.ToolApproval{ID: "approval-1", NeedsApproval: true} -func TestApprovalResponseRunEmitsRespondedStateAndToolResult(t *testing.T) { - run := ApprovalResponseRun(ApprovalContext{ - ID: "approval-1", - ThreadID: "thread-1", - RunID: "run-1", - MessageID: "msg-1", - ToolCallID: "tool-1", - ToolName: "shell", - TargetEvent: "$anchor", - SeqStart: 10, - PreviewText: "Use supportbrief for incremental patches.", - }, agui.ToolApprovalResponse{ - Approved: false, - Reason: "denied", - Fields: map[string]any{"scope": "once"}, - Metadata: map[string]any{"source": "reaction"}, - }, time.Unix(10, 0)) + value := NewApprovalRequestedValue(*run, "tool-1", "shell", map[string]any{"command": "ls"}, approval).Map() - if run.RunID != "run-1" || run.MessageID != "msg-1" { - t.Fatalf("approval response must continue the existing run/message, got %#v", run) - } - if run.Preview.Text != "Use supportbrief for incremental patches." { - t.Fatalf("approval response must preserve anchor preview, got %#v", run.Preview) + if value["threadId"] != "thread-1" || value["runId"] != "run-1" || value["messageId"] != "msg-run-1" { + t.Fatalf("bad run identifiers: %#v", value) } - if len(run.Events) != 2 { - t.Fatalf("expected approval response and tool result events, got %#v", run.Events) + if value["toolCallId"] != "tool-1" || value["toolName"] != "shell" { + t.Fatalf("bad tool identifiers: %#v", value) } - if run.Events[0]["type"] != agui.EventCustom || run.Events[0]["name"] != agui.ApprovalCustomResponded { - t.Fatalf("missing approval-responded event: %#v", run.Events[0]) + if value["approvalMessageId"] != "approval-1" { + t.Fatalf("missing approval message id: %#v", value) } - if run.Events[1]["type"] != agui.EventToolCallEnd || run.Events[1]["state"] != agui.ToolStateApprovalResponded { - t.Fatalf("missing approval-responded tool end: %#v", run.Events[1]) + if _, ok := value["approvalEventId"]; ok { + t.Fatalf("approval event id should only be added after Matrix send: %#v", value) } - result := jsonMap(t, run.Events[1]["result"]) - if result["state"] != agui.ToolResultStateError || result["reason"] != "denied" { - t.Fatalf("expected structured denied result, got %#v", result) + choices, ok := value["choices"].([]ApprovalChoice) + if !ok || len(choices) != len(DefaultApprovalChoices()) || choices[0].Key != ApprovalChoiceApprove { + t.Fatalf("bad approval choices: %#v", value["choices"]) } - if result["fields"].(map[string]any)["scope"] != "once" || result["metadata"].(map[string]any)["source"] != "reaction" { - t.Fatalf("expected flexible approval fields to survive, got %#v", result) + if ApprovalIDFromRequestedValue(value) != "approval-1" { + t.Fatalf("approval id resolver failed for value: %#v", value) } - if run.Approvals[0].Fields["scope"] != "once" || run.Approvals[0].Metadata["source"] != "reaction" { - t.Fatalf("expected approval summary fields to survive, got %#v", run.Approvals[0]) + if !SetApprovalRequestedEventID(value, "$approval") || value["approvalEventId"] != "$approval" { + t.Fatalf("failed to annotate approval event id: %#v", value) } +} - carriers, err := PackRunFromSeq(run, "$anchor", CarrierBudgetBytes, 10) - if err != nil { - t.Fatal(err) +func TestRunMetadataOwnsMatrixPayloadShape(t *testing.T) { + run := NewRun("run-1", "thread-1", DefaultModel, "agent-1", "Agent", time.Unix(10, 0)) + run.MessageID = "msg-run-1" + run.Usage = agui.Usage{PromptTokens: 1, CompletionTokens: 2, TotalTokens: 3} + run.Preview = Preview{Text: "hello", Truncated: false} + + metadata := run.Metadata() + + if metadata["schema"] != "com.beeper.ai.run.v1" || metadata["protocol"] != "ag-ui" { + t.Fatalf("bad protocol metadata: %#v", metadata) + } + if metadata["threadId"] != "thread-1" || metadata["runId"] != "run-1" || metadata["messageId"] != "msg-run-1" { + t.Fatalf("bad run identifiers: %#v", metadata) + } + agent, ok := metadata["agent"].(map[string]any) + if !ok || agent["id"] != "agent-1" || agent["displayName"] != "Agent" { + t.Fatalf("bad agent metadata: %#v", metadata["agent"]) } - if carriers[0].Envelopes[0].Seq != 10 { - t.Fatalf("expected continuation seq 10, got %#v", carriers[0].Envelopes[0]) + usage, ok := metadata["usage"].(map[string]any) + if !ok || usage["promptTokens"] != 1 || usage["completionTokens"] != 2 || usage["totalTokens"] != 3 { + t.Fatalf("bad usage metadata: %#v", metadata["usage"]) + } + if _, ok := metadata["usageDetails"].(map[string]any); !ok { + t.Fatalf("usage details should always be present: %#v", metadata) } } -func TestApprovalResponseRunPreservesApprovedAlways(t *testing.T) { - run := ApprovalResponseRun(ApprovalContext{ - ID: "approval-1", - ThreadID: "thread-1", - RunID: "run-1", - MessageID: "msg-1", - ToolCallID: "tool-1", - ToolName: "shell", - TargetEvent: "$anchor", - }, agui.ToolApprovalResponse{Approved: true, Always: true}, time.Unix(10, 0)) +func TestApprovalNoticeOwnsHiddenMessagePayloadShape(t *testing.T) { + notice := NewApprovalNotice(ApprovalContext{ + ID: "approval-1", + MessageID: "msg-run-1", + ToolCallID: "tool-1", + ToolName: "shell", + }, DefaultApprovalChoices()).Map() - if run.Status.State != "complete" { - t.Fatalf("expected complete approval response run, got %#v", run.Status) + if notice["schema"] != "com.beeper.ai.approval.v1" || notice["state"] != "requested" { + t.Fatalf("bad approval notice metadata: %#v", notice) + } + if notice["id"] != "approval-1" || notice["messageId"] != "msg-run-1" || notice["toolCallId"] != "tool-1" || notice["toolName"] != "shell" { + t.Fatalf("bad approval notice identifiers: %#v", notice) + } + choices, ok := notice["choices"].([]any) + if !ok || len(choices) != 3 { + t.Fatalf("bad approval notice choices: %#v", notice["choices"]) + } + first, ok := choices[0].(map[string]any) + if !ok || first["key"] != ApprovalChoiceApprove || first["label"] != "Approve" || first["alias"] != "✅" { + t.Fatalf("bad first approval choice: %#v", choices[0]) } - if len(run.Approvals) != 1 || run.Approvals[0].State != "approved-always" || !run.Approvals[0].Always { - t.Fatalf("bad approval summary: %#v", run.Approvals) + if _, ok := first["style"]; ok { + t.Fatalf("empty style should be omitted from approval choices: %#v", first) } - result := jsonMap(t, run.Events[1]["result"]) - if result["state"] != agui.ToolResultStateComplete || result["approved"] != true || result["always"] != true { - t.Fatalf("bad approval result: %#v", result) + deny, ok := choices[2].(map[string]any) + if !ok || deny["style"] != "danger" { + t.Fatalf("deny choice should keep danger style: %#v", choices[2]) } } -func jsonMap(t *testing.T, value any) map[string]any { - t.Helper() - text, ok := value.(string) - if !ok { - t.Fatalf("expected JSON string result, got %#v", value) +func TestCleanupKeepsSelectedUserReactionAndRemovesBridgeOptions(t *testing.T) { + choices := DefaultApprovalChoices() + cleanup := CleanupApprovalReactions(choices, "✅", []ReactionEvent{ + {EventID: "$bridge-allow", Sender: "ai", Key: "✅", Bridge: true}, + {EventID: "$bridge-deny", Sender: "ai", Key: "❌", Bridge: true}, + {EventID: "$user-allow", Sender: "@user:example", Key: "✅"}, + {EventID: "$user-deny", Sender: "@user:example", Key: "❌"}, + }, "ai") + if !cleanup.Matched || cleanup.SelectedReactionEvent != "$user-allow" { + t.Fatalf("bad selected reaction: %#v", cleanup) } - var out map[string]any - if err := json.Unmarshal([]byte(text), &out); err != nil { - t.Fatalf("failed to parse result %q: %v", text, err) + got := strings.Join(cleanup.RedactReactionEvents, ",") + if !strings.Contains(got, "$bridge-allow") || !strings.Contains(got, "$bridge-deny") || !strings.Contains(got, "$user-deny") { + t.Fatalf("bad cleanup redactions: %#v", cleanup.RedactReactionEvents) } - return out } diff --git a/pkg/connector/ai_runtime.go b/pkg/connector/ai_runtime.go index f434017..b02694b 100644 --- a/pkg/connector/ai_runtime.go +++ b/pkg/connector/ai_runtime.go @@ -16,7 +16,10 @@ import ( "go.mau.fi/util/shlex" ) -var errApprovalRequested = errors.New("approval requested") +var ( + errApprovalRequested = errors.New("approval requested") + errApprovalDenied = errors.New("approval denied") +) const ( defaultChunkMin = 24 @@ -139,6 +142,11 @@ type aiRuntime struct { sleep func(context.Context, time.Duration) error } +type aiRunner struct { + runtime aiRuntime + approvals map[string]agui.ToolApprovalResponse +} + type aiRunPlan struct { Run *aistream.Run Delay time.Duration @@ -194,12 +202,16 @@ func buildAIRunPlans(ctx context.Context, runID, threadID, input string, now tim } func buildAIRunFromCommand(ctx context.Context, runID, threadID string, now time.Time, cmd *parsedCommand, agentID, agentName string) (*aistream.Run, error) { + return buildAIRunFromCommandWithApprovals(ctx, runID, threadID, now, cmd, agentID, agentName, nil) +} + +func buildAIRunFromCommandWithApprovals(ctx context.Context, runID, threadID string, now time.Time, cmd *parsedCommand, agentID, agentName string, approvals map[string]agui.ToolApprovalResponse) (*aistream.Run, error) { runtime := virtualAIRuntime(now) run := aistream.NewRun(runID, threadID, aistream.DefaultModel, agentID, agentName, now) writer := aistream.NewWriter(run, runtime.now) writer.Start() - runner := aiRunner{runtime: runtime} + runner := aiRunner{runtime: runtime, approvals: approvals} var err error switch { case cmd == nil || cmd.Name == "help": @@ -649,10 +661,6 @@ func parseToolSpec(raw string, idx int) (toolSpec, error) { return spec, nil } -type aiRunner struct { - runtime aiRuntime -} - func (r aiRunner) runLorem(ctx context.Context, w *aistream.Writer, cmd loremCommand) error { opts := cmd.Options rng := rngForOptions(opts.SeedSet, opts.Seed, r.runtime.now().UnixNano()) @@ -848,6 +856,17 @@ func (r aiRunner) runToolSpec(ctx context.Context, w *aistream.Writer, spec tool } switch { case spec.Approval: + if response, ok := r.approvals[approvalID]; ok { + if response.ID == "" { + response.ID = approvalID + } + w.ToolApprovalResponded(toolCallID, spec.Name, input, response) + annotateProviderRawEvent(w, spec, "approval_responded") + if !response.Approved { + return errApprovalDenied + } + return nil + } w.ToolApprovalInputComplete(toolCallID, spec.Name, input) annotateProviderRawEvent(w, spec, "tool_call_input_complete") w.ToolApprovalRequested(toolCallID, spec.Name, input, *approval) diff --git a/pkg/connector/ai_runtime_test.go b/pkg/connector/ai_runtime_test.go index 40bf5d8..1552491 100644 --- a/pkg/connector/ai_runtime_test.go +++ b/pkg/connector/ai_runtime_test.go @@ -11,6 +11,7 @@ import ( "github.com/beeper/dummybridge/pkg/ag-ui" "github.com/beeper/dummybridge/pkg/ai-stream" + "maunium.net/go/mautrix/id" ) func TestParseCommandRecognizesHelpAliases(t *testing.T) { @@ -156,6 +157,13 @@ func TestBuildAIRunToolsApprovalUsesAGUIApprovalAndPrompt(t *testing.T) { if _, hasOptions := value["options"]; hasOptions { t.Fatalf("AG-UI approval event must not embed Matrix reaction options: %#v", value) } + if value["approvalMessageId"] != "approval-run-1-dummy-tool-1-shell" { + t.Fatalf("approval event should name the Matrix reaction target: %#v", value) + } + choices, ok := value["choices"].([]aistream.ApprovalChoice) + if !ok || len(choices) == 0 || choices[0].Key != aistream.ApprovalChoiceApprove { + t.Fatalf("approval event should duplicate renderer choices: %#v", value["choices"]) + } if value["input"] == nil { t.Fatalf("approval event should include final tool input: %#v", value) } @@ -174,6 +182,247 @@ func TestBuildAIRunToolsApprovalUsesAGUIApprovalAndPrompt(t *testing.T) { } } +func TestApprovalPromptSeqStartsAtNextPackedCarrierSeq(t *testing.T) { + run, err := buildAIRun(context.Background(), "run-1", "thread-1", "stream-tools 120 shell#approval --seed=7 --chunk-chars=32:32", time.Unix(10, 0)) + if err != nil { + t.Fatal(err) + } + carriers, err := aistream.PackRunFromSeq(*run, "$anchor", aistream.CarrierBudgetBytes, 1) + if err != nil { + t.Fatal(err) + } + nextSeq := aistream.NextSeq(splitCarriersForTimedEmission(carriers)) + if nextSeq <= 1 { + t.Fatalf("expected initial stream to consume carrier sequence numbers, got %d", nextSeq) + } + + prompt := run.Prompts[0] + prompt.SeqStart = nextSeq + approvalCtx := aistream.ApprovalContext{ + ID: prompt.ID, + ThreadID: run.ThreadID, + RunID: run.RunID, + MessageID: run.MessageID, + Command: "stream-tools 120 shell#approval --seed=7 --chunk-chars=32:32", + ToolCallID: prompt.ToolCallID, + ToolName: prompt.ToolName, + TargetEvent: "$anchor", + AgentID: run.AgentID, + AgentName: run.AgentName, + SeqStart: prompt.SeqStart, + } + continuation, err := buildAIApprovalContinuationRun(context.Background(), approvalCtx, agui.ToolApprovalResponse{ + ID: prompt.ID, + Approved: true, + }, time.Unix(20, 0)) + if err != nil { + t.Fatal(err) + } + continuationCarriers, err := aistream.PackRunFromSeq(continuation, "$anchor", aistream.CarrierBudgetBytes, approvalCtx.SeqStart) + if err != nil { + t.Fatal(err) + } + if len(continuationCarriers) == 0 || len(continuationCarriers[0].Envelopes) == 0 || continuationCarriers[0].Envelopes[0].Seq != nextSeq { + t.Fatalf("continuation should start at next carrier seq %d, got %#v", nextSeq, continuationCarriers) + } + if continuationCarriers[0].Envelopes[0].Seq >= 100000 { + t.Fatalf("continuation sequence has legacy large gap: %#v", continuationCarriers[0]) + } +} + +func TestApprovalLifecycleCarriesNoticeTargetAndContinuation(t *testing.T) { + command := "stream-tools 120 shell#approval fetch --seed=7 --chunk-chars=32:32" + run, err := buildAIRun(context.Background(), "run-1", "thread-1", command, time.Unix(10, 0)) + if err != nil { + t.Fatal(err) + } + if len(run.Prompts) != 1 { + t.Fatalf("expected one approval prompt, got %#v", run.Prompts) + } + initialCarriers, err := aistream.PackRunFromSeq(*run, "$anchor", aistream.CarrierBudgetBytes, 1) + if err != nil { + t.Fatal(err) + } + initialCarriers = splitCarriersForTimedEmission(initialCarriers) + nextSeq := aistream.NextSeq(initialCarriers) + if nextSeq <= 1 { + t.Fatalf("expected initial carriers to advance sequence, got %d", nextSeq) + } + + prompt := run.Prompts[0] + prompt.SeqStart = nextSeq + approvalCtx := aistream.ApprovalContext{ + ID: prompt.ID, + ThreadID: run.ThreadID, + RunID: run.RunID, + MessageID: run.MessageID, + Command: command, + ToolCallID: prompt.ToolCallID, + ToolName: prompt.ToolName, + TargetEvent: "$anchor", + AgentID: run.AgentID, + AgentName: run.AgentName, + SeqStart: prompt.SeqStart, + } + notice := aistream.NewApprovalNotice(approvalCtx, aistream.DefaultApprovalChoices()).Map() + if notice["id"] != prompt.ID || notice["messageId"] != run.MessageID || notice["state"] != "requested" { + t.Fatalf("approval notice does not target the paused run: %#v", notice) + } + + annotateApprovalEventIDs(run, map[string]id.EventID{prompt.ID: "$approval"}) + annotatedCarriers, err := aistream.PackRunFromSeq(*run, "$anchor", aistream.CarrierBudgetBytes, 1) + if err != nil { + t.Fatal(err) + } + var annotatedValue map[string]any + for _, carrier := range annotatedCarriers { + for _, env := range carrier.Envelopes { + if env.Part["type"] != agui.EventCustom || env.Part["name"] != agui.ApprovalCustomRequested { + continue + } + annotatedValue, _ = env.Part["value"].(map[string]any) + } + } + if annotatedValue == nil || annotatedValue["approvalMessageId"] != prompt.ID || annotatedValue["approvalEventId"] != "$approval" { + t.Fatalf("approval-requested stream event missing Matrix target: %#v", annotatedValue) + } + choices, ok := annotatedValue["choices"].([]any) + if !ok || len(choices) != len(aistream.DefaultApprovalChoices()) { + t.Fatalf("approval-requested stream event missing choices: %#v", annotatedValue["choices"]) + } + firstChoice, ok := choices[0].(map[string]any) + if !ok || firstChoice["key"] != aistream.ApprovalChoiceApprove || firstChoice["label"] != "Approve" { + t.Fatalf("approval-requested stream event has bad choice shape: %#v", choices[0]) + } + + continuation, err := buildAIApprovalContinuationRun(context.Background(), approvalCtx, agui.ToolApprovalResponse{ + ID: prompt.ID, + Approved: true, + }, time.Unix(20, 0)) + if err != nil { + t.Fatal(err) + } + if len(continuation.Prompts) != 0 { + t.Fatalf("continuation must not request approval again: %#v", continuation.Prompts) + } + if continuation.Status.State != "complete" { + t.Fatalf("approved continuation should finish the run, got %#v", continuation.Status) + } + continuationCarriers, err := aistream.PackRunFromSeq(continuation, "$anchor", aistream.CarrierBudgetBytes, approvalCtx.SeqStart) + if err != nil { + t.Fatal(err) + } + if len(continuationCarriers) == 0 || len(continuationCarriers[0].Envelopes) == 0 || continuationCarriers[0].Envelopes[0].Seq != nextSeq { + t.Fatalf("continuation should resume at seq %d, got %#v", nextSeq, continuationCarriers) + } + if continuation.Events[0]["type"] != agui.EventCustom || continuation.Events[0]["name"] != agui.ApprovalCustomResponded { + t.Fatalf("continuation must start by acknowledging approval: %#v", continuation.Events) + } +} + +func TestApprovalContinuationResumesOriginalRunAfterApprovedTool(t *testing.T) { + command := "stream-tools 240 shell#approval fetch --seed=7 --chunk-chars=32:32" + approvalCtx := aistream.ApprovalContext{ + ID: "approval-run-1-dummy-tool-1-shell", + ThreadID: "thread-1", + RunID: "run-1", + MessageID: "msg-run-1", + Command: command, + ToolCallID: "dummy-tool-1-shell", + ToolName: "shell", + TargetEvent: "$anchor", + AgentID: "ai", + AgentName: "AI", + SeqStart: 12, + } + run, err := buildAIApprovalContinuationRun(context.Background(), approvalCtx, agui.ToolApprovalResponse{ + ID: approvalCtx.ID, + Approved: true, + }, time.Unix(20, 0)) + if err != nil { + t.Fatal(err) + } + if len(run.Events) == 0 { + t.Fatal("expected continuation events") + } + if run.Events[0]["type"] != agui.EventCustom || run.Events[0]["name"] != agui.ApprovalCustomResponded { + t.Fatalf("first continuation event should acknowledge approval, got %#v", run.Events[0]) + } + seenApprovedTool := false + seenLaterTool := false + seenFinished := false + for _, evt := range run.Events { + if evt["type"] == agui.EventToolCallEnd && evt["toolCallId"] == approvalCtx.ToolCallID { + if evt["state"] == agui.ToolStateApprovalResponded { + result := jsonResultMap(t, evt["result"]) + if result["approved"] != true { + t.Fatalf("approved result missing approval state: %#v", result) + } + seenApprovedTool = true + } + } + if evt["type"] == agui.EventToolCallStart && evt["toolCallId"] == "dummy-tool-2-fetch" { + seenLaterTool = true + } + if evt["type"] == agui.EventRunFinished { + seenFinished = true + } + } + if !seenApprovedTool || !seenLaterTool || !seenFinished { + t.Fatalf("continuation did not resume fully: approved=%v laterTool=%v finished=%v events=%#v", seenApprovedTool, seenLaterTool, seenFinished, run.Events) + } + if run.Status.State != "complete" { + t.Fatalf("approved continuation status = %#v", run.Status) + } + if len(run.Prompts) != 0 { + t.Fatalf("finished continuation should not keep pending prompts: %#v", run.Prompts) + } +} + +func TestApprovalContinuationStopsOriginalRunAfterDeniedTool(t *testing.T) { + command := "stream-tools 240 shell#approval fetch --seed=7 --chunk-chars=32:32" + approvalCtx := aistream.ApprovalContext{ + ID: "approval-run-1-dummy-tool-1-shell", + ThreadID: "thread-1", + RunID: "run-1", + MessageID: "msg-run-1", + Command: command, + ToolCallID: "dummy-tool-1-shell", + ToolName: "shell", + TargetEvent: "$anchor", + AgentID: "ai", + AgentName: "AI", + SeqStart: 12, + } + run, err := buildAIApprovalContinuationRun(context.Background(), approvalCtx, agui.ToolApprovalResponse{ + ID: approvalCtx.ID, + Approved: false, + Reason: "denied", + }, time.Unix(20, 0)) + if err != nil { + t.Fatal(err) + } + seenDeniedTool := false + for _, evt := range run.Events { + if evt["type"] == agui.EventToolCallStart && evt["toolCallId"] == "dummy-tool-2-fetch" { + t.Fatalf("denied approval must not continue later tools: %#v", run.Events) + } + if evt["type"] == agui.EventToolCallEnd && evt["toolCallId"] == approvalCtx.ToolCallID && evt["state"] == agui.ToolStateApprovalResponded { + result := jsonResultMap(t, evt["result"]) + if result["state"] != agui.ToolResultStateError || result["reason"] != "denied" { + t.Fatalf("bad denied result: %#v", result) + } + seenDeniedTool = true + } + } + if !seenDeniedTool { + t.Fatalf("missing denied approval result: %#v", run.Events) + } + if run.Status.State != "error" { + t.Fatalf("denied continuation status = %#v", run.Status) + } +} + func TestBuildAIRunToolsDenyProducesStructuredDeniedResult(t *testing.T) { run, err := buildAIRun(context.Background(), "run-1", "thread-1", "stream-tools 120 shell#deny --seed=7 --chunk-chars=32:32", time.Unix(10, 0)) if err != nil { diff --git a/pkg/connector/client.go b/pkg/connector/client.go index 6cc8f0c..73a1984 100644 --- a/pkg/connector/client.go +++ b/pkg/connector/client.go @@ -272,15 +272,19 @@ func (dc *DummyClient) HandleMatrixReaction(ctx context.Context, msg *bridgev2.M if dc == nil || dc.UserLogin == nil || msg == nil || msg.TargetMessage == nil || msg.Content == nil || msg.Portal == nil { return &database.Reaction{}, nil } + if isApprovalOptionReaction(msg) { + return &database.Reaction{}, nil + } approvalID := string(msg.TargetMessage.ID) if !strings.HasPrefix(approvalID, "approval-") { return &database.Reaction{}, nil } reaction := aistream.NormalizeReaction(msg.Content.RelatesTo.Key) - selected, ok := aistream.ResolveReaction(aistream.DefaultApprovalOptions(approvalID), reaction) + selected, ok := aistream.ResolveApprovalChoice(aistream.DefaultApprovalChoices(), reaction) if !ok { return &database.Reaction{}, nil } + response := aistream.ApprovalResponseForChoice(approvalID, selected) selectedKey, firstResolution := dc.resolveApprovalOnce(approvalID, reaction) dc.cleanupApprovalReactions(ctx, msg.Portal, networkid.MessageID(approvalID), selectedKey, reaction, msg) @@ -292,12 +296,13 @@ func (dc *DummyClient) HandleMatrixReaction(ctx context.Context, msg *bridgev2.M Msg("Ignoring duplicate dummy AI approval reaction") return &database.Reaction{}, nil } - dc.queueAIApprovalResponse(ctx, msg.Portal, msg.TargetMessage, selected.Value) + dc.queueAIApprovalResponse(ctx, msg.Portal, msg.TargetMessage, response) logger := log.Info(). Str("approval_id", approvalID). Str("reaction", reaction). - Bool("approved", selected.Value.Approved) + Str("choice", selected.Key). + Bool("approved", response.Approved) if msg.Event != nil { logger = logger.Stringer("sender", msg.Event.Sender) } @@ -306,6 +311,14 @@ func (dc *DummyClient) HandleMatrixReaction(ctx context.Context, msg *bridgev2.M return &database.Reaction{}, nil } +func isApprovalOptionReaction(msg *bridgev2.MatrixReaction) bool { + if msg == nil || msg.Event == nil { + return false + } + _, ok := msg.Event.Content.Raw["com.beeper.ai.approval_option"] + return ok +} + func (dc *DummyClient) resolveApprovalOnce(approvalID, selectedKey string) (string, bool) { dc.approvalSelectionsOnce.Do(func() { dc.approvalSelections = exsync.NewMap[string, string]() @@ -350,7 +363,7 @@ func (dc *DummyClient) cleanupApprovalReactions(ctx context.Context, portal *bri }) } sender := dummyAISenderForPortal(portal) - cleanup := aistream.CleanupReactions(aistream.DefaultApprovalOptions(string(approvalMessageID)), selectedKey, events, string(sender)) + cleanup := aistream.CleanupApprovalReactions(aistream.DefaultApprovalChoices(), selectedKey, events, string(sender)) intent, ok := portal.GetIntentFor(ctx, bridgev2.EventSender{Sender: sender}, dc.UserLogin, bridgev2.RemoteEventMessageRemove) if !ok || intent == nil { log.Warn().Str("approval_id", string(approvalMessageID)).Msg("Failed to resolve AI sender intent for approval reaction cleanup") @@ -519,15 +532,23 @@ func (dc *DummyClient) queueAIResponse(ctx context.Context, portal *bridgev2.Por if plan.Run == nil { continue } - timestamp := now.Add(plan.Delay) placeholderID := networkid.MessageID(plan.Run.MessageID) - dc.UserLogin.QueueRemoteEvent(aibridgev2.Anchor(portal.PortalKey, sender, initialAIAnchorRun(*plan.Run), timestamp)) dc.wg.Add(1) - go func(portal *bridgev2.Portal, sender networkid.UserID, messageID networkid.MessageID, run aistream.Run) { + go func(portal *bridgev2.Portal, sender networkid.UserID, messageID networkid.MessageID, run aistream.Run, command string, delay time.Duration) { defer dc.wg.Done() - dc.queueAIRunStreamAndMetadata(portal, sender, messageID, run) - }(portal, sender, placeholderID, *plan.Run) + if delay > 0 { + timer := time.NewTimer(delay) + select { + case <-timer.C: + case <-dc.ctx.Done(): + timer.Stop() + return + } + } + dc.UserLogin.QueueRemoteEvent(aibridgev2.Anchor(portal.PortalKey, sender, initialAIAnchorRun(run), time.Now())) + dc.queueAIRunStreamAndMetadata(portal, sender, messageID, run, command) + }(portal, sender, placeholderID, *plan.Run, body, plan.Delay) } } @@ -537,7 +558,7 @@ func initialAIAnchorRun(run aistream.Run) aistream.Run { return run } -func (dc *DummyClient) queueAIRunStreamAndMetadata(portal *bridgev2.Portal, sender networkid.UserID, messageID networkid.MessageID, run aistream.Run) { +func (dc *DummyClient) queueAIRunStreamAndMetadata(portal *bridgev2.Portal, sender networkid.UserID, messageID networkid.MessageID, run aistream.Run, command string) { targetEventID := dc.waitForMessageMXID(portal, messageID, 30*time.Second) if targetEventID == "" { log.Warn(). @@ -546,15 +567,49 @@ func (dc *DummyClient) queueAIRunStreamAndMetadata(portal *bridgev2.Portal, send Msg("Timed out waiting for AI anchor Matrix event") return } - carriers, err := dc.queueAICarriers(portal, sender, targetEventID, run, 1) + carriers, err := aistream.PackRunFromSeq(run, string(targetEventID), aistream.CarrierBudgetBytes, 1) if err != nil { log.Warn().Err(err).Str("run_id", run.RunID).Msg("Failed to pack AI stream") return } + carriers = splitCarriersForTimedEmission(carriers) nextSeq := aistream.NextSeq(carriers) + approvalEventIDs := make(map[string]id.EventID, len(run.Prompts)) for i, prompt := range run.Prompts { prompt.SeqStart = nextSeq + i*10 - dc.queueAIApprovalPrompt(portal, sender, run, prompt, targetEventID, time.Now()) + ctx := dc.queueAIApprovalPrompt(portal, sender, run, prompt, targetEventID, command, time.Now()) + if approvalEventID := dc.waitForMessageMXID(portal, networkid.MessageID(ctx.ID), 10*time.Second); approvalEventID != "" { + approvalEventIDs[ctx.ID] = approvalEventID + log.Info(). + Str("run_id", run.RunID). + Str("approval_id", ctx.ID). + Stringer("approval_event_id", approvalEventID). + Int("approval_seq_start", ctx.SeqStart). + Msg("AI approval notice ready for reaction") + } else { + log.Warn(). + Str("run_id", run.RunID). + Str("approval_id", ctx.ID). + Int("approval_seq_start", ctx.SeqStart). + Msg("Timed out waiting for AI approval notice Matrix event") + } + } + if len(approvalEventIDs) > 0 { + annotateApprovalEventIDs(&run, approvalEventIDs) + carriers, err = aistream.PackRunFromSeq(run, string(targetEventID), aistream.CarrierBudgetBytes, 1) + if err != nil { + log.Warn().Err(err).Str("run_id", run.RunID).Msg("Failed to repack AI stream with approval event IDs") + return + } + carriers = splitCarriersForTimedEmission(carriers) + } + dc.queuePackedAICarriers(portal, sender, targetEventID, run, carriers, 1) + if len(run.Prompts) > 0 && run.Status.State == "streaming" { + log.Info(). + Str("run_id", run.RunID). + Str("message_id", string(messageID)). + Int("approval_prompts", len(run.Prompts)). + Msg("AI run paused for approval") } if run.Status.State != "streaming" { dc.queueAIRunFinalMetadata(portal, sender, messageID, run) @@ -566,11 +621,109 @@ func (dc *DummyClient) queueAICarriers(portal *bridgev2.Portal, sender networkid if err != nil { return nil, err } + carriers = splitCarriersForTimedEmission(carriers) + dc.queuePackedAICarriers(portal, sender, targetEventID, run, carriers, startSeq) + return carriers, nil +} + +func (dc *DummyClient) queuePackedAICarriers(portal *bridgev2.Portal, sender networkid.UserID, targetEventID id.EventID, run aistream.Run, carriers []aistream.Carrier, startSeq int) { + streamStart := time.Now() for i, carrier := range carriers { + dc.sleepUntilCarrierTime(run, carrier, streamStart) now := time.Now() dc.UserLogin.QueueRemoteEvent(aibridgev2.Carrier(portal.PortalKey, sender, run, carrier, targetEventID, startSeq+i, now)) } - return carriers, nil +} + +func splitCarriersForTimedEmission(carriers []aistream.Carrier) []aistream.Carrier { + out := make([]aistream.Carrier, 0, len(carriers)) + for _, carrier := range carriers { + if len(carrier.Envelopes) <= 1 { + out = append(out, carrier) + continue + } + for _, env := range carrier.Envelopes { + out = append(out, aistream.Carrier{Envelopes: []aistream.Envelope{env}}) + } + } + return out +} + +func (dc *DummyClient) sleepUntilCarrierTime(run aistream.Run, carrier aistream.Carrier, streamStart time.Time) { + target := carrierTimestamp(run, carrier, streamStart) + if target.IsZero() { + return + } + delay := time.Until(target) + if delay <= 0 { + return + } + timer := time.NewTimer(delay) + select { + case <-timer.C: + case <-dc.ctx.Done(): + timer.Stop() + } +} + +func carrierTimestamp(run aistream.Run, carrier aistream.Carrier, streamStart time.Time) time.Time { + base := runStartTimestamp(run) + if base.IsZero() { + return time.Time{} + } + var latest time.Time + for _, env := range carrier.Envelopes { + eventTime := eventTimestamp(env.Part) + if eventTime.IsZero() { + continue + } + if latest.IsZero() || eventTime.After(latest) { + latest = eventTime + } + } + if latest.IsZero() { + return time.Time{} + } + return streamStart.Add(latest.Sub(base)) +} + +func runStartTimestamp(run aistream.Run) time.Time { + for _, evt := range run.Events { + if ts := eventTimestamp(evt); !ts.IsZero() { + return ts + } + } + return time.Time{} +} + +func eventTimestamp(evt agui.Event) time.Time { + raw, ok := evt["timestamp"] + if !ok { + return time.Time{} + } + var millis int64 + switch value := raw.(type) { + case int64: + millis = value + case int: + millis = int64(value) + case int32: + millis = int64(value) + case float64: + millis = int64(value) + case json.Number: + parsed, err := value.Int64() + if err != nil { + return time.Time{} + } + millis = parsed + default: + return time.Time{} + } + if millis <= 0 { + return time.Time{} + } + return time.UnixMilli(millis) } func (dc *DummyClient) waitForMessageMXID( @@ -619,13 +772,14 @@ func (dc *DummyClient) lookupMessageMXID(ctx context.Context, receiver networkid return message.MXID } -func (dc *DummyClient) queueAIApprovalPrompt(portal *bridgev2.Portal, sender networkid.UserID, run aistream.Run, prompt aistream.ApprovalPrompt, targetEventID id.EventID, timestamp time.Time) { - reactions := aistream.DefaultApprovalOptions(prompt.ID) +func (dc *DummyClient) queueAIApprovalPrompt(portal *bridgev2.Portal, sender networkid.UserID, run aistream.Run, prompt aistream.ApprovalPrompt, targetEventID id.EventID, command string, timestamp time.Time) aistream.ApprovalContext { + choices := aistream.DefaultApprovalChoices() approvalCtx := aistream.ApprovalContext{ ID: prompt.ID, ThreadID: run.ThreadID, RunID: run.RunID, MessageID: run.MessageID, + Command: command, ToolCallID: prompt.ToolCallID, ToolName: prompt.ToolName, TargetEvent: string(targetEventID), @@ -638,9 +792,31 @@ func (dc *DummyClient) queueAIApprovalPrompt(portal *bridgev2.Portal, sender net } dc.UserLogin.QueueRemoteEvent(aibridgev2.ApprovalPrompt(portal.PortalKey, sender, approvalCtx, timestamp)) - for i, reaction := range reactions { - reaction := reaction - dc.UserLogin.QueueRemoteEvent(aibridgev2.ApprovalOptionReaction(portal.PortalKey, sender, approvalCtx, reaction, timestamp.Add(time.Duration(i+1)*time.Millisecond))) + for i, choice := range choices { + choice := choice + dc.UserLogin.QueueRemoteEvent(aibridgev2.ApprovalOptionReaction(portal.PortalKey, sender, approvalCtx, choice, timestamp.Add(time.Duration(i+1)*time.Millisecond))) + } + return approvalCtx +} + +func annotateApprovalEventIDs(run *aistream.Run, eventIDs map[string]id.EventID) { + if run == nil || len(eventIDs) == 0 { + return + } + for _, evt := range run.Events { + if evt["type"] != agui.EventCustom || evt["name"] != agui.ApprovalCustomRequested { + continue + } + value, _ := evt["value"].(map[string]any) + if value == nil { + continue + } + approvalID := aistream.ApprovalIDFromRequestedValue(value) + eventID := eventIDs[approvalID] + if eventID == "" { + continue + } + aistream.SetApprovalRequestedEventID(value, string(eventID)) } } @@ -654,7 +830,11 @@ func (dc *DummyClient) queueAIApprovalResponse(ctx context.Context, portal *brid response.ID = approvalCtx.ID } now := time.Now() - run := aistream.ApprovalResponseRun(approvalCtx, response, now) + run, err := buildAIApprovalContinuationRun(ctx, approvalCtx, response, now) + if err != nil { + log.Warn().Err(err).Str("approval_id", approvalCtx.ID).Msg("Failed to build AI approval continuation") + return + } targetEventID := id.EventID(approvalCtx.TargetEvent) if targetEventID == "" { log.Warn().Str("approval_id", approvalCtx.ID).Msg("Missing AI approval target event") @@ -669,6 +849,65 @@ func (dc *DummyClient) queueAIApprovalResponse(ctx context.Context, portal *brid return } dc.queueAIRunFinalMetadata(portal, sender, networkid.MessageID(approvalCtx.MessageID), run) + log.Info(). + Str("run_id", approvalCtx.RunID). + Str("approval_id", approvalCtx.ID). + Str("tool_call_id", approvalCtx.ToolCallID). + Bool("approved", response.Approved). + Bool("always", response.Always). + Int("seq_start", approvalCtx.SeqStart). + Str("state", run.Status.State). + Msg("Queued AI approval continuation") +} + +func buildAIApprovalContinuationRun(ctx context.Context, approvalCtx aistream.ApprovalContext, response agui.ToolApprovalResponse, now time.Time) (aistream.Run, error) { + cmd, err := parseCommand(approvalCtx.Command) + if err != nil { + return aistream.Run{}, err + } + if response.ID == "" { + response.ID = approvalCtx.ID + } + run, err := buildAIRunFromCommandWithApprovals(ctx, approvalCtx.RunID, approvalCtx.ThreadID, now, cmd, approvalCtx.AgentID, approvalCtx.AgentName, map[string]agui.ToolApprovalResponse{ + approvalCtx.ID: response, + }) + if err != nil { + return aistream.Run{}, err + } + if run == nil { + return aistream.Run{}, fmt.Errorf("approval continuation produced no run") + } + start := approvalContinuationStart(run.Events, approvalCtx.ID) + if start < 0 { + return aistream.Run{}, fmt.Errorf("approval response event %q not found", approvalCtx.ID) + } + run.Events = append([]agui.Event(nil), run.Events[start:]...) + run.RunID = approvalCtx.RunID + run.ThreadID = approvalCtx.ThreadID + run.MessageID = approvalCtx.MessageID + run.ToolCallID = approvalCtx.ToolCallID + run.ApprovalID = approvalCtx.ID + run.Prompts = nil + return *run, nil +} + +func approvalContinuationStart(events []agui.Event, approvalID string) int { + for i, evt := range events { + if evt["type"] != agui.EventCustom || evt["name"] != agui.ApprovalCustomResponded { + continue + } + value, _ := evt["value"].(map[string]any) + approval, _ := value["approval"].(agui.ToolApprovalResponse) + if approval.ID == approvalID { + return i + } + if raw, ok := value["approval"].(map[string]any); ok { + if idValue, _ := raw["id"].(string); idValue == approvalID { + return i + } + } + } + return -1 } func (dc *DummyClient) approvalContextForMessage(ctx context.Context, portal *bridgev2.Portal, message *database.Message) (aistream.ApprovalContext, bool) { @@ -761,7 +1000,7 @@ func messageIDString(message *database.Message) string { } func validApprovalContext(ctx aistream.ApprovalContext) (aistream.ApprovalContext, bool) { - if ctx.ID == "" || ctx.ThreadID == "" || ctx.RunID == "" || ctx.MessageID == "" || ctx.ToolCallID == "" || ctx.TargetEvent == "" { + if ctx.ID == "" || ctx.ThreadID == "" || ctx.RunID == "" || ctx.MessageID == "" || ctx.Command == "" || ctx.ToolCallID == "" || ctx.TargetEvent == "" { return aistream.ApprovalContext{}, false } if ctx.SeqStart <= 0 { diff --git a/pkg/connector/client_test.go b/pkg/connector/client_test.go index ba56d5a..04ed5a9 100644 --- a/pkg/connector/client_test.go +++ b/pkg/connector/client_test.go @@ -12,6 +12,7 @@ import ( "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/bridgev2/networkid" "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" ) func TestGetRemoteEchoBehavior(t *testing.T) { @@ -104,6 +105,13 @@ func TestInitialAIAnchorRunKeepsPreviewButNotTerminalMetadata(t *testing.T) { if anchor.Preview.Text == "" { t.Fatal("expected anchor to keep useful preview text") } + uiMessage := anchor.InitialUIMessage() + if len(uiMessage.Parts) != 1 || uiMessage.Parts[0]["type"] != "text" || uiMessage.Parts[0]["content"] != "visible preview" { + t.Fatalf("anchor UI message should include visible preview text part: %#v", uiMessage.Parts) + } + if uiMessage.Metadata["runId"] != run.RunID { + t.Fatalf("anchor UI metadata missing run id: %#v", uiMessage.Metadata) + } if anchor.Status.State != "streaming" { t.Fatalf("anchor status = %#v, want streaming", anchor.Status) } @@ -115,12 +123,44 @@ func TestInitialAIAnchorRunKeepsPreviewButNotTerminalMetadata(t *testing.T) { } } +func TestCarrierTimestampUsesEventOffsetFromRunStart(t *testing.T) { + run := aistream.Run{ + Events: []agui.Event{ + {"timestamp": int64(10_000), "type": agui.EventRunStarted, "threadId": "thread-1"}, + {"timestamp": int64(13_500), "type": agui.EventTextMessageContent, "messageId": "msg-1", "delta": "later"}, + }, + } + streamStart := time.Unix(100, 0) + target := carrierTimestamp(run, aistream.Carrier{Envelopes: []aistream.Envelope{{ + Part: run.Events[1], + }}}, streamStart) + if want := streamStart.Add(3500 * time.Millisecond); !target.Equal(want) { + t.Fatalf("target = %s, want %s", target, want) + } +} + +func TestSplitCarriersForTimedEmissionKeepsOneEnvelopePerCarrier(t *testing.T) { + carriers := splitCarriersForTimedEmission([]aistream.Carrier{{ + Envelopes: []aistream.Envelope{ + {Seq: 1}, + {Seq: 2}, + }, + }}) + if len(carriers) != 2 { + t.Fatalf("carrier count = %d, want 2", len(carriers)) + } + if carriers[0].Envelopes[0].Seq != 1 || carriers[1].Envelopes[0].Seq != 2 { + t.Fatalf("bad split carriers: %#v", carriers) + } +} + func TestApprovalContextForMessageFallsBackToStoredMessage(t *testing.T) { want := aistream.ApprovalContext{ ID: "approval-1", ThreadID: "thread-1", RunID: "run-1", MessageID: "msg-1", + Command: "stream-tools 120 shell#approval", ToolCallID: "tool-1", TargetEvent: "$event", SeqStart: 12, @@ -150,3 +190,44 @@ func TestApprovalContextForMessageFallsBackToStoredMessage(t *testing.T) { t.Fatalf("approval context = %#v, want %#v", got, want) } } + +func TestApprovalOptionReactionIsBridgeManagedFallback(t *testing.T) { + msg := &bridgev2.MatrixReaction{ + MatrixEventBase: bridgev2.MatrixEventBase[*event.ReactionEventContent]{ + Event: &event.Event{Content: event.Content{Raw: map[string]any{ + "com.beeper.ai.approval_option": map[string]any{"choice": "approve"}, + }}}, + }, + } + if !isApprovalOptionReaction(msg) { + t.Fatal("expected managed approval option reaction") + } + if isApprovalOptionReaction(&bridgev2.MatrixReaction{MatrixEventBase: bridgev2.MatrixEventBase[*event.ReactionEventContent]{Event: &event.Event{Content: event.Content{Raw: map[string]any{}}}}}) { + t.Fatal("plain user reaction must not be treated as a managed approval option") + } +} + +func TestAnnotateApprovalEventIDsAddsReactionTargetEventToStreamPrompt(t *testing.T) { + run := aistream.NewRun("run-1", "thread-1", aistream.DefaultModel, "ai", "AI", time.Unix(10, 0)) + writer := aistream.NewWriter(run, func() time.Time { return time.Unix(10, 0) }) + writer.ToolApprovalRequested("tool-1", "shell", map[string]any{"command": "ls"}, agui.ToolApproval{ + ID: "approval-1", + NeedsApproval: true, + }) + + annotateApprovalEventIDs(run, map[string]id.EventID{ + "approval-1": "$approval", + }) + + for _, evt := range run.Events { + if evt["type"] != agui.EventCustom || evt["name"] != agui.ApprovalCustomRequested { + continue + } + value, _ := evt["value"].(map[string]any) + if value["approvalMessageId"] != "approval-1" || value["approvalEventId"] != "$approval" { + t.Fatalf("approval stream event missing target ids: %#v", value) + } + return + } + t.Fatal("missing approval-requested event") +} diff --git a/pkg/connector/connector.go b/pkg/connector/connector.go index cfad894..3adf8fa 100644 --- a/pkg/connector/connector.go +++ b/pkg/connector/connector.go @@ -56,7 +56,7 @@ func (dc *DummyConnector) GetCapabilities() *bridgev2.NetworkGeneralCapabilities } func (dc *DummyConnector) GetBridgeInfoVersion() (info, caps int) { - return 0, 0 + return 0, 1 } func (dc *DummyConnector) GetName() bridgev2.BridgeName { From a2a2645d907e52bd85665ca18b87ed3341a7edea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sat, 23 May 2026 21:19:58 +0200 Subject: [PATCH 27/46] wip --- pkg/ag-ui/events.go | 7 +++++ pkg/ai-stream/approval.go | 4 +-- pkg/ai-stream/run.go | 9 +++++- pkg/ai-stream/stream_test.go | 20 +++++++++++++- pkg/connector/ai_runtime.go | 47 +++++++++++++++++++++++++++++++- pkg/connector/ai_runtime_test.go | 6 +++- 6 files changed, 87 insertions(+), 6 deletions(-) diff --git a/pkg/ag-ui/events.go b/pkg/ag-ui/events.go index 917a542..05df815 100644 --- a/pkg/ag-ui/events.go +++ b/pkg/ag-ui/events.go @@ -221,6 +221,10 @@ func (b EventBuilder) ReasoningMessageEnd(messageID string) Event { } func (b EventBuilder) ToolCallStart(messageID, toolCallID, name string, index *int, approval *ToolApproval) Event { + return b.ToolCallStartWithMetadata(messageID, toolCallID, name, index, approval, nil) +} + +func (b EventBuilder) ToolCallStartWithMetadata(messageID, toolCallID, name string, index *int, approval *ToolApproval, metadata map[string]any) Event { evt := b.base(EventToolCallStart) if messageID != "" { evt["parentMessageId"] = messageID @@ -228,6 +232,9 @@ func (b EventBuilder) ToolCallStart(messageID, toolCallID, name string, index *i evt["toolCallId"] = toolCallID evt["toolCallName"] = name evt["toolName"] = name + if len(metadata) > 0 { + evt["metadata"] = metadata + } if index != nil { evt["index"] = *index } diff --git a/pkg/ai-stream/approval.go b/pkg/ai-stream/approval.go index 70a20e8..456cac8 100644 --- a/pkg/ai-stream/approval.go +++ b/pkg/ai-stream/approval.go @@ -174,12 +174,12 @@ func DefaultApprovalChoices() []ApprovalChoice { return []ApprovalChoice{ { Key: ApprovalChoiceApprove, - Label: "Approve", + Label: "Allow once", Alias: "✅", }, { Key: ApprovalChoiceAlwaysApprove, - Label: "Always approve", + Label: "Allow always", Alias: "☑️", }, { diff --git a/pkg/ai-stream/run.go b/pkg/ai-stream/run.go index 3c77572..fb519e3 100644 --- a/pkg/ai-stream/run.go +++ b/pkg/ai-stream/run.go @@ -222,8 +222,12 @@ func (w *Writer) StepFinish(stepID string) { } func (w *Writer) ToolStart(toolCallID, name string, index int, approval *agui.ToolApproval) { + w.ToolStartWithMetadata(toolCallID, name, index, approval, nil) +} + +func (w *Writer) ToolStartWithMetadata(toolCallID, name string, index int, approval *agui.ToolApproval, metadata map[string]any) { idx := index - w.Add(w.builder.ToolCallStart(w.Run.MessageID, toolCallID, name, &idx, approval)) + w.Add(w.builder.ToolCallStartWithMetadata(w.Run.MessageID, toolCallID, name, &idx, approval, metadata)) if approval != nil { w.recordApprovalRequest(toolCallID, name, approval) } @@ -517,6 +521,9 @@ func (t Run) FinalUIMessage(textBudget int, includeThinking bool) agui.UIMessage if approval, ok := evt["approval"]; ok { part["approval"] = approval } + if metadata, ok := evt["metadata"]; ok { + part["metadata"] = metadata + } toolParts[toolCallID] = appendPart(part) case agui.EventToolCallArgs: toolCallID, _ := evt["toolCallId"].(string) diff --git a/pkg/ai-stream/stream_test.go b/pkg/ai-stream/stream_test.go index 4e0be17..1e39c54 100644 --- a/pkg/ai-stream/stream_test.go +++ b/pkg/ai-stream/stream_test.go @@ -230,6 +230,24 @@ func TestValidateRejectsLegacyOrInvalidToolResultShape(t *testing.T) { } } +func TestFinalUIMessageCarriesToolCallMetadata(t *testing.T) { + run := NewRun("run-1", "thread-1", DefaultModel, "ai", "AI", time.Unix(10, 0)) + writer := NewWriter(run, func() time.Time { return time.Unix(10, 0) }) + writer.ToolStartWithMetadata("tool-1", "calendar.get_events", 0, nil, map[string]any{ + "displayName": "List Calendar Events", + "iconId": "3257-5951", + }) + + message := run.FinalUIMessage(0, true) + if len(message.Parts) != 1 { + t.Fatalf("expected one part, got %#v", message.Parts) + } + metadata, ok := message.Parts[0]["metadata"].(map[string]any) + if !ok || metadata["displayName"] != "List Calendar Events" || metadata["iconId"] != "3257-5951" { + t.Fatalf("bad tool metadata: %#v", message.Parts[0]) + } +} + func TestApprovalResolverMatchesEmojiKeysAndAliases(t *testing.T) { choices := DefaultApprovalChoices() for _, key := range []string{"✅", "approve"} { @@ -328,7 +346,7 @@ func TestApprovalNoticeOwnsHiddenMessagePayloadShape(t *testing.T) { t.Fatalf("bad approval notice choices: %#v", notice["choices"]) } first, ok := choices[0].(map[string]any) - if !ok || first["key"] != ApprovalChoiceApprove || first["label"] != "Approve" || first["alias"] != "✅" { + if !ok || first["key"] != ApprovalChoiceApprove || first["label"] != "Allow once" || first["alias"] != "✅" { t.Fatalf("bad first approval choice: %#v", choices[0]) } if _, ok := first["style"]; ok { diff --git a/pkg/connector/ai_runtime.go b/pkg/connector/ai_runtime.go index b02694b..f9bdda5 100644 --- a/pkg/connector/ai_runtime.go +++ b/pkg/connector/ai_runtime.go @@ -828,7 +828,7 @@ func (r aiRunner) runToolSpec(ctx context.Context, w *aistream.Writer, spec tool if spec.Approval { approval = &agui.ToolApproval{ID: approvalID, NeedsApproval: true} } - w.ToolStart(toolCallID, spec.Name, spec.SequenceIndex-1, approval) + w.ToolStartWithMetadata(toolCallID, spec.Name, spec.SequenceIndex-1, approval, toolDisplayMetadata(spec.Name)) annotateProviderRawEvent(w, spec, "tool_call_start") if spec.InputError { w.ToolArgs(toolCallID, jsonToolInput(input), nil) @@ -885,6 +885,51 @@ func (r aiRunner) runToolSpec(ctx context.Context, w *aistream.Writer, spec tool return nil } +func toolDisplayMetadata(name string) map[string]any { + displayName := titleToolName(name) + metadata := map[string]any{ + "displayName": displayName, + } + switch strings.ToLower(name) { + case "calendar.get_events", "google_calendar.get_events", "google-calendar.get-events": + metadata["displayName"] = "List Calendar Events" + metadata["iconId"] = "3257-5951" + metadata["provider"] = map[string]any{ + "id": "google-calendar", + "displayName": "Google Calendar", + "iconId": "3257-5951", + } + case "linear.list_issues", "linear.list-issues", "list_issues", "list-issues": + metadata["displayName"] = "List Issues" + metadata["iconId"] = "3257-5945" + metadata["provider"] = map[string]any{ + "id": "linear", + "displayName": "Linear", + "iconId": "3257-5945", + } + case "shell": + metadata["displayName"] = "Run Command" + metadata["iconId"] = "3255-2310" + case "fetch": + metadata["displayName"] = "Fetch Web" + metadata["iconId"] = "source-placeholder" + } + return metadata +} + +func titleToolName(name string) string { + parts := strings.FieldsFunc(name, func(r rune) bool { + return r == '_' || r == '-' || r == '.' + }) + for i, part := range parts { + if part == "" { + continue + } + parts[i] = strings.ToUpper(part[:1]) + part[1:] + } + return strings.Join(parts, " ") +} + func approvalIDForRun(runID, toolCallID string) string { return "approval-" + runID + "-" + toolCallID } diff --git a/pkg/connector/ai_runtime_test.go b/pkg/connector/ai_runtime_test.go index 1552491..752ae68 100644 --- a/pkg/connector/ai_runtime_test.go +++ b/pkg/connector/ai_runtime_test.go @@ -136,6 +136,10 @@ func TestBuildAIRunToolsApprovalUsesAGUIApprovalAndPrompt(t *testing.T) { if approval.ID != "approval-run-1-dummy-tool-1-shell" || !approval.NeedsApproval { t.Fatalf("bad approval metadata: %#v", approval) } + metadata, ok := evt["metadata"].(map[string]any) + if !ok || metadata["displayName"] != "Run Command" || metadata["iconId"] != "3255-2310" { + t.Fatalf("bad tool display metadata: %#v", evt["metadata"]) + } foundToolStart = true } if evt["type"] == agui.EventToolCallEnd { @@ -291,7 +295,7 @@ func TestApprovalLifecycleCarriesNoticeTargetAndContinuation(t *testing.T) { t.Fatalf("approval-requested stream event missing choices: %#v", annotatedValue["choices"]) } firstChoice, ok := choices[0].(map[string]any) - if !ok || firstChoice["key"] != aistream.ApprovalChoiceApprove || firstChoice["label"] != "Approve" { + if !ok || firstChoice["key"] != aistream.ApprovalChoiceApprove || firstChoice["label"] != "Allow once" { t.Fatalf("approval-requested stream event has bad choice shape: %#v", choices[0]) } From 7bead3505105da756f0fc4e1447ad5562c331a0e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sat, 23 May 2026 21:36:52 +0200 Subject: [PATCH 28/46] wip --- pkg/ag-ui/events.go | 33 +++++++- pkg/ai-stream/approval.go | 6 ++ pkg/ai-stream/pack.go | 14 ++- pkg/connector/ai_runtime.go | 100 ++++++++++++++++++++-- pkg/connector/ai_runtime_test.go | 109 ++++++++++++++++++++++++ pkg/connector/client.go | 141 +++++++++++++++++++++++++------ 6 files changed, 365 insertions(+), 38 deletions(-) diff --git a/pkg/ag-ui/events.go b/pkg/ag-ui/events.go index 05df815..dd5ec6b 100644 --- a/pkg/ag-ui/events.go +++ b/pkg/ag-ui/events.go @@ -386,13 +386,19 @@ func ValidateEvent(evt Event) error { case EventTextMessageStart: return require(evt, "messageId", "role") case EventTextMessageContent: - return require(evt, "messageId", "delta") + if err := require(evt, "messageId"); err != nil { + return err + } + return requireStringField(evt, "delta") case EventTextMessageEnd: return require(evt, "messageId") case EventReasoningStart, EventReasoningEnd, EventReasoningMsgStart, EventReasoningMsgEnd: return require(evt, "messageId") case EventReasoningMsgCont: - return require(evt, "messageId", "delta") + if err := require(evt, "messageId"); err != nil { + return err + } + return requireStringField(evt, "delta") case EventToolCallStart: if err := require(evt, "toolCallId", "toolCallName"); err != nil { return err @@ -404,7 +410,10 @@ func ValidateEvent(evt Event) error { } return validateStringSet(evt, "state", true, validToolStates) case EventToolCallArgs: - if err := require(evt, "toolCallId", "delta"); err != nil { + if err := require(evt, "toolCallId"); err != nil { + return err + } + if err := requireStringField(evt, "delta"); err != nil { return err } if err := validateStringSet(evt, "state", false, validToolStates); err != nil { @@ -649,6 +658,24 @@ func require(evt Event, keys ...string) error { return nil } +// requireStringField checks that the field is present and is a string. +// Unlike require, it accepts whitespace-only strings — streaming deltas can +// legitimately consist only of spaces or newlines between tokens. +func requireStringField(evt Event, key string) error { + value, ok := evt[key] + if !ok { + return fmt.Errorf("%s missing %s", evt["type"], key) + } + str, ok := value.(string) + if !ok { + return fmt.Errorf("%s has invalid %s %T", evt["type"], key, value) + } + if str == "" { + return fmt.Errorf("%s missing %s", evt["type"], key) + } + return nil +} + func emptyValue(value any) bool { switch v := value.(type) { case string: diff --git a/pkg/ai-stream/approval.go b/pkg/ai-stream/approval.go index 456cac8..4b4b0d1 100644 --- a/pkg/ai-stream/approval.go +++ b/pkg/ai-stream/approval.go @@ -50,6 +50,12 @@ type ApprovalContext struct { PreviewTruncated bool `json:"previewTruncated,omitempty"` } +// ApprovalSeqReservation is the size of the sequence-number window reserved +// for the continuation of a single approval prompt. Large enough to fit any +// realistic continuation stream without colliding with neighbouring prompts' +// reserved ranges or their (possibly nested) continuations. +const ApprovalSeqReservation = 10000 + type ApprovalRequestedValue struct { ThreadID string RunID string diff --git a/pkg/ai-stream/pack.go b/pkg/ai-stream/pack.go index 7b02c71..439de44 100644 --- a/pkg/ai-stream/pack.go +++ b/pkg/ai-stream/pack.go @@ -4,10 +4,22 @@ import ( "encoding/json" "fmt" "strings" + "unicode/utf8" "github.com/beeper/dummybridge/pkg/ag-ui" ) +func truncateUTF8(s string, maxBytes int) string { + if maxBytes <= 0 || len(s) <= maxBytes { + return s + } + end := maxBytes + for end > 0 && !utf8.RuneStart(s[end]) { + end-- + } + return s[:end] +} + type Envelope struct { ThreadID string `json:"threadId"` RunID string `json:"runId"` @@ -306,7 +318,7 @@ func sanitizeRawEvent(evt agui.Event, budget int) agui.Event { delete(cp, "rawEvent") cp["rawEventTruncated"] = true } else if len(raw) > 2048 { - cp["rawEvent"] = string(raw[:2048]) + cp["rawEvent"] = truncateUTF8(string(raw), 2048) cp["rawEventTruncated"] = true } if JSONSize(cp) > budget { diff --git a/pkg/connector/ai_runtime.go b/pkg/connector/ai_runtime.go index f9bdda5..cb726af 100644 --- a/pkg/connector/ai_runtime.go +++ b/pkg/connector/ai_runtime.go @@ -150,6 +150,11 @@ type aiRunner struct { type aiRunPlan struct { Run *aistream.Run Delay time.Duration + // EffectiveCommand is the canonical command form used to deterministically + // replay this run during approval continuation. For random/chaos sub-runs + // (where the seed was derived implicitly) this includes the resolved + // --seed=N so the continuation reproduces the same action sequence. + EffectiveCommand string } func virtualAIRuntime(now time.Time) aiRuntime { @@ -189,16 +194,71 @@ func buildAIRunPlans(ctx context.Context, runID, threadID, input string, now tim writer.Start() writer.Text(err.Error() + "\n\n" + helpText()) writer.Finish(agui.FinishReasonStop) - return []aiRunPlan{{Run: run}}, nil + return []aiRunPlan{{Run: run, EffectiveCommand: input}}, nil } if cmd != nil && cmd.Chaos != nil { return buildAIChaosRunPlans(ctx, runID, threadID, now, *cmd.Chaos, agentID, agentName) } + resolveCommandSeed(cmd, now) run, err := buildAIRunFromCommand(ctx, runID, threadID, now, cmd, agentID, agentName) if err != nil { return nil, err } - return []aiRunPlan{{Run: run}}, nil + return []aiRunPlan{{Run: run, EffectiveCommand: canonicalCommand(input, cmd)}}, nil +} + +// resolveCommandSeed fills in an implicit seed for commands that derive their +// random behavior from the current time, so the continuation can replay the +// exact same sequence. +func resolveCommandSeed(cmd *parsedCommand, now time.Time) { + if cmd == nil { + return + } + switch { + case cmd.Lorem != nil && !cmd.Lorem.Options.SeedSet: + cmd.Lorem.Options.Seed = now.UnixNano() + cmd.Lorem.Options.SeedSet = true + case cmd.Tools != nil && !cmd.Tools.Options.SeedSet: + cmd.Tools.Options.Seed = now.UnixNano() + cmd.Tools.Options.SeedSet = true + case cmd.Random != nil && !cmd.Random.SeedSet: + cmd.Random.Seed = now.UnixNano() + cmd.Random.SeedSet = true + } +} + +// canonicalCommand returns a command string that, when re-parsed, reproduces +// the same run as cmd. If the original input already encoded all randomness +// inputs (e.g. an explicit --seed), it is returned as-is. +func canonicalCommand(input string, cmd *parsedCommand) string { + if cmd == nil { + return input + } + switch { + case cmd.Lorem != nil: + return ensureSeedFlag(input, cmd.Lorem.Options.Seed, cmd.Lorem.Options.SeedSet) + case cmd.Tools != nil: + return ensureSeedFlag(input, cmd.Tools.Options.Seed, cmd.Tools.Options.SeedSet) + case cmd.Random != nil: + return ensureSeedFlag(input, cmd.Random.Seed, cmd.Random.SeedSet) + } + return input +} + +func ensureSeedFlag(input string, seed int64, seedSet bool) string { + if !seedSet || hasSeedFlag(input) { + return input + } + return strings.TrimRight(input, " ") + " --seed=" + strconv.FormatInt(seed, 10) +} + +func hasSeedFlag(input string) bool { + for _, token := range strings.Fields(input) { + if strings.HasPrefix(token, "--seed=") || token == "--seed" { + return true + } + } + return false } func buildAIRunFromCommand(ctx context.Context, runID, threadID string, now time.Time, cmd *parsedCommand, agentID, agentName string) (*aistream.Run, error) { @@ -264,18 +324,44 @@ func buildAIChaosRunPlans(ctx context.Context, baseRunID, threadID string, now t AllowApproval: cmd.AllowApproval, }, } - run, err := buildAIRunFromCommand(ctx, runID, threadID, now.Add(delay), &parsedCommand{ - Name: "stream-random", - Random: &randomCmd, - }, agentID, agentName) + parsed := &parsedCommand{Name: "stream-random", Random: &randomCmd} + run, err := buildAIRunFromCommand(ctx, runID, threadID, now.Add(delay), parsed, agentID, agentName) if err != nil { return nil, err } - plans = append(plans, aiRunPlan{Run: run, Delay: delay}) + plans = append(plans, aiRunPlan{ + Run: run, + Delay: delay, + EffectiveCommand: chaosSubRunCommand(randomCmd), + }) } return plans, nil } +// chaosSubRunCommand renders a stream-random command equivalent to the sub-run +// derived from a stream-chaos invocation. Used as the canonical command stored +// in the approval context so a chaos approval can be replayed deterministically. +func chaosSubRunCommand(cmd randomCommand) string { + parts := []string{ + "stream-random", + strconv.Itoa(int(cmd.Duration / time.Second)), + "--actions=" + strconv.Itoa(cmd.Actions), + "--delay-ms=" + strconv.Itoa(int(cmd.DelayMin/time.Millisecond)) + ":" + strconv.Itoa(int(cmd.DelayMax/time.Millisecond)), + "--profile=" + cmd.Profile, + "--seed=" + strconv.FormatInt(cmd.Seed, 10), + } + if cmd.AllowAbort { + parts = append(parts, "--allow-abort") + } + if cmd.AllowError { + parts = append(parts, "--allow-error") + } + if cmd.AllowApproval { + parts = append(parts, "--allow-approval") + } + return strings.Join(parts, " ") +} + func parseCommand(input string) (*parsedCommand, error) { tokens, err := shlex.Split(input) if err != nil { diff --git a/pkg/connector/ai_runtime_test.go b/pkg/connector/ai_runtime_test.go index 752ae68..1b8b8e0 100644 --- a/pkg/connector/ai_runtime_test.go +++ b/pkg/connector/ai_runtime_test.go @@ -770,6 +770,115 @@ func TestBuildDemoVisibleTextIsMarkdownRichAndDeterministic(t *testing.T) { t.Fatalf("expected markdown-rich text, got %q", first) } +func TestMultiApprovalContinuationKeepsLaterPrompts(t *testing.T) { + command := "stream-tools 240 shell#approval fetch#approval --seed=7 --chunk-chars=32:32" + approvalCtx := aistream.ApprovalContext{ + ID: "approval-run-1-dummy-tool-1-shell", + ThreadID: "thread-1", + RunID: "run-1", + MessageID: "msg-run-1", + Command: command, + ToolCallID: "dummy-tool-1-shell", + ToolName: "shell", + TargetEvent: "$anchor", + AgentID: "ai", + AgentName: "AI", + SeqStart: 12, + } + run, err := buildAIApprovalContinuationRun(context.Background(), approvalCtx, agui.ToolApprovalResponse{ + ID: approvalCtx.ID, + Approved: true, + }, time.Unix(20, 0)) + if err != nil { + t.Fatal(err) + } + if len(run.Prompts) != 1 { + t.Fatalf("expected second approval prompt to be preserved, got %#v", run.Prompts) + } + if run.Prompts[0].ToolName != "fetch" { + t.Fatalf("expected preserved prompt to belong to fetch, got %#v", run.Prompts[0]) + } + if run.Status.State != "streaming" { + t.Fatalf("expected continuation with pending approval to remain streaming, got %#v", run.Status) + } +} + +func TestApprovalContinuationReplaysRandomRunWithImplicitSeed(t *testing.T) { + // Iterate clocks until the random-action profile produces an approval + // request — the seed is implicit (resolved from now()), and the bug being + // guarded against is that the continuation would otherwise pick a fresh + // seed and lose the original toolCallID. + for tick := int64(1); tick <= 500; tick++ { + now := time.Unix(tick, 0) + plans, err := buildAIRunPlans(context.Background(), "run-rand", "thread-rand", "stream-random 1 --profile=tools --allow-approval", now, "ai", "AI") + if err != nil { + t.Fatal(err) + } + if len(plans) != 1 || plans[0].Run == nil { + t.Fatalf("expected one random plan, got %#v", plans) + } + originalRun := plans[0].Run + if originalRun.ApprovalID == "" { + continue + } + if !strings.Contains(plans[0].EffectiveCommand, "--seed=") { + t.Fatalf("effective command must include resolved seed: %q", plans[0].EffectiveCommand) + } + approvalCtx := aistream.ApprovalContext{ + ID: originalRun.ApprovalID, + ThreadID: originalRun.ThreadID, + RunID: originalRun.RunID, + MessageID: originalRun.MessageID, + Command: plans[0].EffectiveCommand, + ToolCallID: originalRun.ToolCallID, + TargetEvent: "$anchor", + AgentID: "ai", + AgentName: "AI", + SeqStart: 50, + } + continuation, err := buildAIApprovalContinuationRun(context.Background(), approvalCtx, agui.ToolApprovalResponse{ + ID: approvalCtx.ID, + Approved: true, + }, now.Add(time.Hour)) + if err != nil { + t.Fatalf("continuation failed: %v", err) + } + if len(continuation.Events) == 0 { + t.Fatalf("expected continuation events for random run, got none") + } + if continuation.Events[0]["type"] != agui.EventCustom || continuation.Events[0]["name"] != agui.ApprovalCustomResponded { + t.Fatalf("first continuation event should acknowledge approval, got %#v", continuation.Events[0]) + } + return + } + t.Fatal("no implicit-seed random run produced an approval prompt in the tested range") +} + +func TestChaosSubRunCommandIsParseable(t *testing.T) { + plans, err := buildAIRunPlans(context.Background(), "run-chaos", "thread-chaos", "stream-chaos 2 1 --allow-approval --seed=11", time.Unix(0, 0), "ai", "AI") + if err != nil { + t.Fatal(err) + } + if len(plans) != 2 { + t.Fatalf("expected two chaos sub-runs, got %d", len(plans)) + } + for i, plan := range plans { + if !strings.HasPrefix(plan.EffectiveCommand, "stream-random ") { + t.Fatalf("chaos plan %d must render as stream-random, got %q", i, plan.EffectiveCommand) + } + if !strings.Contains(plan.EffectiveCommand, "--seed=") { + t.Fatalf("chaos sub-run command must include explicit seed: %q", plan.EffectiveCommand) + } + cmd, err := parseCommand(plan.EffectiveCommand) + if err != nil { + t.Fatalf("chaos sub-run command did not re-parse: %v (%q)", err, plan.EffectiveCommand) + } + if cmd == nil || cmd.Random == nil || !cmd.Random.SeedSet { + t.Fatalf("re-parsed chaos sub-run lost seed: %#v", cmd) + } + } +} + func jsonResultMap(t *testing.T, value any) map[string]any { t.Helper() text, ok := value.(string) diff --git a/pkg/connector/client.go b/pkg/connector/client.go index 73a1984..298c7a6 100644 --- a/pkg/connector/client.go +++ b/pkg/connector/client.go @@ -296,7 +296,13 @@ func (dc *DummyClient) HandleMatrixReaction(ctx context.Context, msg *bridgev2.M Msg("Ignoring duplicate dummy AI approval reaction") return &database.Reaction{}, nil } - dc.queueAIApprovalResponse(ctx, msg.Portal, msg.TargetMessage, response) + portal := msg.Portal + target := msg.TargetMessage + dc.wg.Add(1) + go func() { + defer dc.wg.Done() + dc.queueAIApprovalResponse(dc.ctx, portal, target, response) + }() logger := log.Info(). Str("approval_id", approvalID). @@ -388,7 +394,9 @@ func (dc *DummyClient) cleanupApprovalReactions(ctx context.Context, portal *bri func (dc *DummyClient) HandleMatrixReactionRemove(ctx context.Context, msg *bridgev2.MatrixReactionRemove) error { if dc != nil && dc.UserLogin != nil && dc.UserLogin.Bridge != nil && msg != nil && msg.TargetReaction != nil { - _ = dc.UserLogin.Bridge.DB.Reaction.Delete(ctx, msg.TargetReaction) + if err := dc.UserLogin.Bridge.DB.Reaction.Delete(ctx, msg.TargetReaction); err != nil { + log.Warn().Err(err).Stringer("reaction_mxid", msg.TargetReaction.MXID).Msg("Failed to delete reaction on remove") + } } return nil } @@ -445,6 +453,40 @@ func isAIDemoCommandContent(content *event.MessageEventContent) bool { } } +// ensureAISenderInvited queues a ChatInfoChange that adds the AI sender ghost +// to the given portal. The bridge's default portal generator can create +// portals with members=0, in which case the per-portal AI sender chosen by +// dummyAISenderForPortal is not actually a room member — sending the anchor +// from a non-member ghost would fail. Re-asserting an existing membership is +// a no-op for bridgev2, so it is safe to call for every AI run. +func (dc *DummyClient) ensureAISenderInvited(portal *bridgev2.Portal, sender networkid.UserID) { + if dc == nil || dc.UserLogin == nil || portal == nil || sender == "" { + return + } + if isAIPortalID(portal.ID) { + return + } + changes := &bridgev2.ChatMemberList{MemberMap: bridgev2.ChatMemberMap{}} + changes.MemberMap.Set(bridgev2.ChatMember{ + EventSender: bridgev2.EventSender{Sender: sender}, + Membership: event.MembershipJoin, + MemberEventExtra: map[string]any{ + "displayname": dummyAIAgentNameForPortal(portal), + }, + }) + now := time.Now() + dc.UserLogin.QueueRemoteEvent(&simplevent.ChatInfoChange{ + EventMeta: simplevent.EventMeta{ + Type: bridgev2.RemoteEventChatInfoChange, + PortalKey: portal.PortalKey, + Sender: bridgev2.EventSender{Sender: sender}, + Timestamp: now, + StreamOrder: now.UnixNano(), + }, + ChatInfoChange: &bridgev2.ChatInfoChange{MemberChanges: changes}, + }) +} + func dummyAISenderForPortal(portal *bridgev2.Portal) networkid.UserID { if portal == nil { return networkid.UserID(dummyAIAgentName) @@ -533,6 +575,10 @@ func (dc *DummyClient) queueAIResponse(ctx context.Context, portal *bridgev2.Por continue } placeholderID := networkid.MessageID(plan.Run.MessageID) + effectiveCommand := plan.EffectiveCommand + if effectiveCommand == "" { + effectiveCommand = body + } dc.wg.Add(1) go func(portal *bridgev2.Portal, sender networkid.UserID, messageID networkid.MessageID, run aistream.Run, command string, delay time.Duration) { @@ -546,9 +592,11 @@ func (dc *DummyClient) queueAIResponse(ctx context.Context, portal *bridgev2.Por return } } - dc.UserLogin.QueueRemoteEvent(aibridgev2.Anchor(portal.PortalKey, sender, initialAIAnchorRun(run), time.Now())) - dc.queueAIRunStreamAndMetadata(portal, sender, messageID, run, command) - }(portal, sender, placeholderID, *plan.Run, body, plan.Delay) + dc.ensureAISenderInvited(portal, sender) + anchorAt := time.Now() + dc.UserLogin.QueueRemoteEvent(aibridgev2.Anchor(portal.PortalKey, sender, initialAIAnchorRun(run), anchorAt)) + dc.queueAIRunStreamAndMetadata(portal, sender, messageID, run, command, anchorAt) + }(portal, sender, placeholderID, *plan.Run, effectiveCommand, plan.Delay) } } @@ -558,7 +606,7 @@ func initialAIAnchorRun(run aistream.Run) aistream.Run { return run } -func (dc *DummyClient) queueAIRunStreamAndMetadata(portal *bridgev2.Portal, sender networkid.UserID, messageID networkid.MessageID, run aistream.Run, command string) { +func (dc *DummyClient) queueAIRunStreamAndMetadata(portal *bridgev2.Portal, sender networkid.UserID, messageID networkid.MessageID, run aistream.Run, command string, anchorAt time.Time) { targetEventID := dc.waitForMessageMXID(portal, messageID, 30*time.Second) if targetEventID == "" { log.Warn(). @@ -567,7 +615,16 @@ func (dc *DummyClient) queueAIRunStreamAndMetadata(portal *bridgev2.Portal, send Msg("Timed out waiting for AI anchor Matrix event") return } - carriers, err := aistream.PackRunFromSeq(run, string(targetEventID), aistream.CarrierBudgetBytes, 1) + dc.emitAIRunStream(portal, sender, messageID, targetEventID, run, command, 1, anchorAt) +} + +// emitAIRunStream packs and emits one segment of an AI run — used both for +// the initial run and for any approval continuation. It queues approval +// prompts produced by the segment, repacks once approval event IDs are +// known, and finally emits the carriers and (if the run terminated) the +// final metadata edit. +func (dc *DummyClient) emitAIRunStream(portal *bridgev2.Portal, sender networkid.UserID, messageID networkid.MessageID, targetEventID id.EventID, run aistream.Run, command string, startSeq int, anchorAt time.Time) { + carriers, err := aistream.PackRunFromSeq(run, string(targetEventID), aistream.CarrierBudgetBytes, startSeq) if err != nil { log.Warn().Err(err).Str("run_id", run.RunID).Msg("Failed to pack AI stream") return @@ -576,7 +633,7 @@ func (dc *DummyClient) queueAIRunStreamAndMetadata(portal *bridgev2.Portal, send nextSeq := aistream.NextSeq(carriers) approvalEventIDs := make(map[string]id.EventID, len(run.Prompts)) for i, prompt := range run.Prompts { - prompt.SeqStart = nextSeq + i*10 + prompt.SeqStart = nextSeq + i*aistream.ApprovalSeqReservation ctx := dc.queueAIApprovalPrompt(portal, sender, run, prompt, targetEventID, command, time.Now()) if approvalEventID := dc.waitForMessageMXID(portal, networkid.MessageID(ctx.ID), 10*time.Second); approvalEventID != "" { approvalEventIDs[ctx.ID] = approvalEventID @@ -596,14 +653,14 @@ func (dc *DummyClient) queueAIRunStreamAndMetadata(portal *bridgev2.Portal, send } if len(approvalEventIDs) > 0 { annotateApprovalEventIDs(&run, approvalEventIDs) - carriers, err = aistream.PackRunFromSeq(run, string(targetEventID), aistream.CarrierBudgetBytes, 1) + carriers, err = aistream.PackRunFromSeq(run, string(targetEventID), aistream.CarrierBudgetBytes, startSeq) if err != nil { log.Warn().Err(err).Str("run_id", run.RunID).Msg("Failed to repack AI stream with approval event IDs") return } carriers = splitCarriersForTimedEmission(carriers) } - dc.queuePackedAICarriers(portal, sender, targetEventID, run, carriers, 1) + dc.queuePackedAICarriers(portal, sender, targetEventID, run, carriers, startSeq, anchorAt) if len(run.Prompts) > 0 && run.Status.State == "streaming" { log.Info(). Str("run_id", run.RunID). @@ -616,21 +673,22 @@ func (dc *DummyClient) queueAIRunStreamAndMetadata(portal *bridgev2.Portal, send } } -func (dc *DummyClient) queueAICarriers(portal *bridgev2.Portal, sender networkid.UserID, targetEventID id.EventID, run aistream.Run, startSeq int) ([]aistream.Carrier, error) { - carriers, err := aistream.PackRunFromSeq(run, string(targetEventID), aistream.CarrierBudgetBytes, startSeq) - if err != nil { - return nil, err - } - carriers = splitCarriersForTimedEmission(carriers) - dc.queuePackedAICarriers(portal, sender, targetEventID, run, carriers, startSeq) - return carriers, nil -} - -func (dc *DummyClient) queuePackedAICarriers(portal *bridgev2.Portal, sender networkid.UserID, targetEventID id.EventID, run aistream.Run, carriers []aistream.Carrier, startSeq int) { +func (dc *DummyClient) queuePackedAICarriers(portal *bridgev2.Portal, sender networkid.UserID, targetEventID id.EventID, run aistream.Run, carriers []aistream.Carrier, startSeq int, anchorAt time.Time) { streamStart := time.Now() + // minCarrierTimestamp guarantees every carrier lands strictly after the + // anchor message timestamp so Matrix room ordering keeps the anchor first + // and downstream RelatesTo resolution can always find the parent event. + minCarrierTimestamp := anchorAt.Add(time.Millisecond) + if streamStart.Before(minCarrierTimestamp) { + streamStart = minCarrierTimestamp + } for i, carrier := range carriers { dc.sleepUntilCarrierTime(run, carrier, streamStart) now := time.Now() + if now.Before(minCarrierTimestamp) { + now = minCarrierTimestamp + } + minCarrierTimestamp = now.Add(time.Nanosecond) dc.UserLogin.QueueRemoteEvent(aibridgev2.Carrier(portal.PortalKey, sender, run, carrier, targetEventID, startSeq+i, now)) } } @@ -844,11 +902,8 @@ func (dc *DummyClient) queueAIApprovalResponse(ctx context.Context, portal *brid if sender == "" { sender = dummyAISenderForPortal(portal) } - if _, err := dc.queueAICarriers(portal, sender, targetEventID, run, approvalCtx.SeqStart); err != nil { - log.Warn().Err(err).Str("approval_id", approvalCtx.ID).Msg("Failed to queue AI approval response") - return - } - dc.queueAIRunFinalMetadata(portal, sender, networkid.MessageID(approvalCtx.MessageID), run) + dc.ensureAISenderInvited(portal, sender) + dc.emitAIRunStream(portal, sender, networkid.MessageID(approvalCtx.MessageID), targetEventID, run, approvalCtx.Command, approvalCtx.SeqStart, now) log.Info(). Str("run_id", approvalCtx.RunID). Str("approval_id", approvalCtx.ID). @@ -857,6 +912,7 @@ func (dc *DummyClient) queueAIApprovalResponse(ctx context.Context, portal *brid Bool("always", response.Always). Int("seq_start", approvalCtx.SeqStart). Str("state", run.Status.State). + Int("pending_prompts", len(run.Prompts)). Msg("Queued AI approval continuation") } @@ -887,10 +943,41 @@ func buildAIApprovalContinuationRun(ctx context.Context, approvalCtx aistream.Ap run.MessageID = approvalCtx.MessageID run.ToolCallID = approvalCtx.ToolCallID run.ApprovalID = approvalCtx.ID - run.Prompts = nil + // Keep only prompts that the continuation segment newly emitted (i.e. + // approvals raised by tools that ran AFTER the resolved one). The + // already-resolved approval has been removed from the event range above + // and must not be queued again. + run.Prompts = filterPendingPrompts(run.Prompts, approvalCtx.ID, run.Events) return *run, nil } +func filterPendingPrompts(prompts []aistream.ApprovalPrompt, resolvedID string, events []agui.Event) []aistream.ApprovalPrompt { + if len(prompts) == 0 { + return nil + } + requested := make(map[string]bool, len(events)) + for _, evt := range events { + if evt["type"] != agui.EventCustom || evt["name"] != agui.ApprovalCustomRequested { + continue + } + value, _ := evt["value"].(map[string]any) + if id := aistream.ApprovalIDFromRequestedValue(value); id != "" { + requested[id] = true + } + } + out := prompts[:0] + for _, prompt := range prompts { + if prompt.ID == resolvedID { + continue + } + if !requested[prompt.ID] { + continue + } + out = append(out, prompt) + } + return out +} + func approvalContinuationStart(events []agui.Event, approvalID string) int { for i, evt := range events { if evt["type"] != agui.EventCustom || evt["name"] != agui.ApprovalCustomResponded { From 5bd611bbe98c8c349eb43154919474cc21b65b02 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sat, 23 May 2026 21:45:28 +0200 Subject: [PATCH 29/46] wip --- pkg/ai-stream/matrix/content.go | 2 +- pkg/ai-stream/matrix/content_test.go | 37 ++++++++++++++++++++++++++++ pkg/ai-stream/run.go | 6 +++++ pkg/connector/client.go | 1 + pkg/connector/client_test.go | 10 ++++---- 5 files changed, 50 insertions(+), 6 deletions(-) diff --git a/pkg/ai-stream/matrix/content.go b/pkg/ai-stream/matrix/content.go index 3a3c0e5..4192a32 100644 --- a/pkg/ai-stream/matrix/content.go +++ b/pkg/ai-stream/matrix/content.go @@ -26,7 +26,7 @@ func AnchorContent(run aistream.Run) (*event.MessageEventContent, map[string]any func FinalContent(run aistream.Run) (*event.MessageEventContent, map[string]any) { content := previewContent(run) extra := map[string]any{ - aistream.BeeperAIKey: run.FinalUIMessage(aistream.SnapshotTextBytes, true), + aistream.BeeperAIKey: run.FinalUIMessage(0, true), aistream.BeeperAIMetadataKey: run.Metadata(), "com.beeper.stream": map[string]any{ "type": aistream.BeeperAIStreamDeltas, diff --git a/pkg/ai-stream/matrix/content_test.go b/pkg/ai-stream/matrix/content_test.go index ca656ec..58ab46f 100644 --- a/pkg/ai-stream/matrix/content_test.go +++ b/pkg/ai-stream/matrix/content_test.go @@ -65,6 +65,20 @@ func TestAnchorContentKeepsLongRunsCompact(t *testing.T) { } } +func TestStreamingAnchorDoesNotIncludePreviewPart(t *testing.T) { + run := aistream.NewRun("run-1", "thread-1", aistream.DefaultModel, "ai", "AI", time.Unix(10, 0)) + run.Preview = aistream.Preview{} + + content, extra := AnchorContent(*run) + if content.Body != "..." { + t.Fatalf("empty streaming anchor should use placeholder body, got %q", content.Body) + } + uiMessage, ok := extra[aistream.BeeperAIKey].(agui.UIMessage) + if !ok || len(uiMessage.Parts) != 0 { + t.Fatalf("streaming anchor should not include an initial text snapshot: %#v", extra[aistream.BeeperAIKey]) + } +} + func TestAnchorContentRendersFinalPreviewAsMatrixHTML(t *testing.T) { run := aistream.NewRun("run-1", "thread-1", aistream.DefaultModel, "ai", "AI", time.Unix(10, 0)) run.Preview = aistream.Preview{Text: "Use **bold** and `code`"} @@ -106,6 +120,29 @@ func TestFinalContentIncludesFinalUIParts(t *testing.T) { } } +func TestFinalContentDoesNotTruncateUIParts(t *testing.T) { + run := aistream.NewRun("run-1", "thread-1", aistream.DefaultModel, "ai", "AI", time.Unix(10, 0)) + writer := aistream.NewWriter(run, func() time.Time { return time.Unix(10, 0) }) + writer.Start() + full := strings.Repeat("| Artifact | State | Latency |\n| --- | --- | --- |\n| renderer | active | accepts markdown |\n\n", 100) + writer.Text(full) + writer.Finish(agui.FinishReasonStop) + expected := run.Text() + + _, extra := FinalContent(*run) + uiMessage, ok := extra[aistream.BeeperAIKey].(agui.UIMessage) + if !ok || len(uiMessage.Parts) == 0 { + t.Fatalf("missing final UI message: %#v", extra[aistream.BeeperAIKey]) + } + textPart := uiMessage.Parts[len(uiMessage.Parts)-1] + if textPart["content"] != expected { + t.Fatalf("final UI text was truncated: got %d bytes want %d", len(textPart["content"].(string)), len(expected)) + } + if metadata, ok := textPart["providerMetadata"]; ok { + t.Fatalf("final UI text should not be marked truncated: %#v", metadata) + } +} + func TestCarrierContentIsHiddenTextCarrierWithDeltas(t *testing.T) { carrier := aistream.Carrier{Envelopes: []aistream.Envelope{{ ThreadID: "thread-1", diff --git a/pkg/ai-stream/run.go b/pkg/ai-stream/run.go index fb519e3..50dd379 100644 --- a/pkg/ai-stream/run.go +++ b/pkg/ai-stream/run.go @@ -661,6 +661,12 @@ func compactTextPart(part agui.MessagePart, budget int) { return } content, _ := part["content"].(string) + if budget <= 0 { + if part["state"] == "" { + part["state"] = agui.PartStateDone + } + return + } preview := BoundedPreview(content, budget) part["content"] = preview if len(preview) < len(content) { diff --git a/pkg/connector/client.go b/pkg/connector/client.go index 298c7a6..9297abd 100644 --- a/pkg/connector/client.go +++ b/pkg/connector/client.go @@ -603,6 +603,7 @@ func (dc *DummyClient) queueAIResponse(ctx context.Context, portal *bridgev2.Por func initialAIAnchorRun(run aistream.Run) aistream.Run { run.Status = aistream.Status{State: "streaming"} run.Usage = agui.Usage{} + run.Preview = aistream.Preview{} return run } diff --git a/pkg/connector/client_test.go b/pkg/connector/client_test.go index 04ed5a9..673bf8f 100644 --- a/pkg/connector/client_test.go +++ b/pkg/connector/client_test.go @@ -94,7 +94,7 @@ func TestResolveApprovalOnceKeepsFirstSelection(t *testing.T) { } } -func TestInitialAIAnchorRunKeepsPreviewButNotTerminalMetadata(t *testing.T) { +func TestInitialAIAnchorRunOmitsPreviewAndTerminalMetadata(t *testing.T) { run := aistream.NewRun("run-1", "thread-1", aistream.DefaultModel, "ai", "AI", time.Unix(10, 0)) writer := aistream.NewWriter(run, func() time.Time { return time.Unix(10, 0) }) writer.Start() @@ -102,12 +102,12 @@ func TestInitialAIAnchorRunKeepsPreviewButNotTerminalMetadata(t *testing.T) { writer.Finish(agui.FinishReasonStop) anchor := initialAIAnchorRun(*run) - if anchor.Preview.Text == "" { - t.Fatal("expected anchor to keep useful preview text") + if anchor.Preview.Text != "" { + t.Fatalf("anchor should not include initial preview text: %#v", anchor.Preview) } uiMessage := anchor.InitialUIMessage() - if len(uiMessage.Parts) != 1 || uiMessage.Parts[0]["type"] != "text" || uiMessage.Parts[0]["content"] != "visible preview" { - t.Fatalf("anchor UI message should include visible preview text part: %#v", uiMessage.Parts) + if len(uiMessage.Parts) != 0 { + t.Fatalf("anchor UI message should wait for stream deltas: %#v", uiMessage.Parts) } if uiMessage.Metadata["runId"] != run.RunID { t.Fatalf("anchor UI metadata missing run id: %#v", uiMessage.Metadata) From ceef0f89f3ce03c97a6e363f98f3ce5b90096538 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sat, 23 May 2026 22:11:45 +0200 Subject: [PATCH 30/46] wip --- pkg/ai-stream/approval.go | 29 ++-- pkg/connector/ai_runtime.go | 249 ++++++++++++++++-------------- pkg/connector/ai_runtime_test.go | 68 +++++--- pkg/connector/ai_stream_random.go | 89 +++++++++++ pkg/connector/client.go | 83 +++++++++- pkg/connector/client_test.go | 5 +- 6 files changed, 370 insertions(+), 153 deletions(-) create mode 100644 pkg/connector/ai_stream_random.go diff --git a/pkg/ai-stream/approval.go b/pkg/ai-stream/approval.go index 4b4b0d1..d0dfd28 100644 --- a/pkg/ai-stream/approval.go +++ b/pkg/ai-stream/approval.go @@ -34,20 +34,21 @@ type ReactionEvent struct { } type ApprovalContext struct { - ID string `json:"id"` - ThreadID string `json:"threadId"` - RunID string `json:"runId"` - MessageID string `json:"messageId"` - Command string `json:"command"` - ToolCallID string `json:"toolCallId"` - ToolName string `json:"toolName"` - TargetEvent string `json:"target_event"` - AgentID string `json:"agentId,omitempty"` - AgentName string `json:"agentName,omitempty"` - Model string `json:"model,omitempty"` - SeqStart int `json:"seqStart,omitempty"` - PreviewText string `json:"previewText,omitempty"` - PreviewTruncated bool `json:"previewTruncated,omitempty"` + ID string `json:"id"` + ThreadID string `json:"threadId"` + RunID string `json:"runId"` + MessageID string `json:"messageId"` + Command string `json:"command"` + ToolCallID string `json:"toolCallId"` + ToolName string `json:"toolName"` + TargetEvent string `json:"target_event"` + AgentID string `json:"agentId,omitempty"` + AgentName string `json:"agentName,omitempty"` + Model string `json:"model,omitempty"` + SeqStart int `json:"seqStart,omitempty"` + PriorApprovals []agui.ToolApprovalResponse `json:"priorApprovals,omitempty"` + PreviewText string `json:"previewText,omitempty"` + PreviewTruncated bool `json:"previewTruncated,omitempty"` } // ApprovalSeqReservation is the size of the sequence-number window reserved diff --git a/pkg/connector/ai_runtime.go b/pkg/connector/ai_runtime.go index cb726af..3f2d9aa 100644 --- a/pkg/connector/ai_runtime.go +++ b/pkg/connector/ai_runtime.go @@ -108,10 +108,15 @@ type sharedStreamOptions struct { } type randomCommand struct { - Duration time.Duration - Actions int - DelayMin time.Duration - DelayMax time.Duration + Duration time.Duration + Actions int + Chars int + DelayMin time.Duration + DelayMax time.Duration + Terminal string + Runs int + StaggerMin time.Duration + StaggerMax time.Duration sharedStreamOptions } @@ -200,6 +205,9 @@ func buildAIRunPlans(ctx context.Context, runID, threadID, input string, now tim return buildAIChaosRunPlans(ctx, runID, threadID, now, *cmd.Chaos, agentID, agentName) } resolveCommandSeed(cmd, now) + if cmd != nil && cmd.Random != nil && cmd.Random.Runs > 1 { + return buildAIStreamRunPlans(ctx, runID, threadID, now, *cmd.Random, agentID, agentName) + } run, err := buildAIRunFromCommand(ctx, runID, threadID, now, cmd, agentID, agentName) if err != nil { return nil, err @@ -324,7 +332,7 @@ func buildAIChaosRunPlans(ctx context.Context, baseRunID, threadID string, now t AllowApproval: cmd.AllowApproval, }, } - parsed := &parsedCommand{Name: "stream-random", Random: &randomCmd} + parsed := &parsedCommand{Name: "stream", Random: &randomCmd} run, err := buildAIRunFromCommand(ctx, runID, threadID, now.Add(delay), parsed, agentID, agentName) if err != nil { return nil, err @@ -338,27 +346,64 @@ func buildAIChaosRunPlans(ctx context.Context, baseRunID, threadID string, now t return plans, nil } -// chaosSubRunCommand renders a stream-random command equivalent to the sub-run -// derived from a stream-chaos invocation. Used as the canonical command stored -// in the approval context so a chaos approval can be replayed deterministically. +func buildAIStreamRunPlans(ctx context.Context, baseRunID, threadID string, now time.Time, cmd randomCommand, agentID, agentName string) ([]aiRunPlan, error) { + seed := cmd.Seed + if !cmd.SeedSet { + seed = now.UnixNano() + } + rng := rand.New(rand.NewSource(seed)) + plans := make([]aiRunPlan, 0, cmd.Runs) + for i := range cmd.Runs { + var delay time.Duration + if i > 0 { + delay = aiRunner{runtime: virtualAIRuntime(now)}.sampleDelay(rng, cmd.StaggerMin, cmd.StaggerMax) + } + child := cmd + child.Runs = 1 + child.Seed = seed + int64(i+1)*97 + child.SeedSet = true + parsed := &parsedCommand{Name: "stream", Random: &child} + run, err := buildAIRunFromCommand(ctx, fmt.Sprintf("%s-%d", baseRunID, i+1), threadID, now.Add(delay), parsed, agentID, agentName) + if err != nil { + return nil, err + } + plans = append(plans, aiRunPlan{ + Run: run, + Delay: delay, + EffectiveCommand: streamSubRunCommand(child), + }) + } + return plans, nil +} + func chaosSubRunCommand(cmd randomCommand) string { + return streamSubRunCommand(cmd) +} + +func streamSubRunCommand(cmd randomCommand) string { parts := []string{ - "stream-random", + "stream", strconv.Itoa(int(cmd.Duration / time.Second)), "--actions=" + strconv.Itoa(cmd.Actions), "--delay-ms=" + strconv.Itoa(int(cmd.DelayMin/time.Millisecond)) + ":" + strconv.Itoa(int(cmd.DelayMax/time.Millisecond)), "--profile=" + cmd.Profile, "--seed=" + strconv.FormatInt(cmd.Seed, 10), } + if cmd.Chars > 0 { + parts = append(parts, "--chars="+strconv.Itoa(cmd.Chars)) + } + if cmd.Terminal != "" { + parts = append(parts, "--terminal="+cmd.Terminal) + } + if !cmd.AllowApproval { + parts = append(parts, "--no-approval") + } if cmd.AllowAbort { parts = append(parts, "--allow-abort") } if cmd.AllowError { parts = append(parts, "--allow-error") } - if cmd.AllowApproval { - parts = append(parts, "--allow-approval") - } return strings.Join(parts, " ") } @@ -373,23 +418,14 @@ func parseCommand(input string) (*parsedCommand, error) { switch strings.ToLower(tokens[0]) { case "help", "/help", "!help", "dummybridge": return &parsedCommand{Name: "help"}, nil - case "stream-lorem": - cmd, err := parseLoremCommand(tokens[1:]) - return &parsedCommand{Name: "stream-lorem", Lorem: cmd}, err case "stream-tools": cmd, err := parseToolsCommand(tokens[1:]) return &parsedCommand{Name: "stream-tools", Tools: cmd}, err - case "stream-random": - cmd, err := parseRandomCommand(tokens[1:]) - return &parsedCommand{Name: "stream-random", Random: cmd}, err - case "stream-chaos": - cmd, err := parseChaosCommand(tokens[1:]) - return &parsedCommand{Name: "stream-chaos", Chaos: cmd}, err + case "stream": + cmd, err := parseStreamCommand(tokens[1:]) + return &parsedCommand{Name: "stream", Random: cmd}, err default: - return &parsedCommand{Name: "stream-lorem", Lorem: &loremCommand{ - Chars: min(max(len(input)*4, 120), 1200), - Options: defaultCommonOptions(), - }}, nil + return nil, fmt.Errorf("unknown AI demo command %q", tokens[0]) } } @@ -397,11 +433,9 @@ func helpText() string { return strings.Join([]string{ "DummyBridge demo commands:", "help", - "stream-lorem [--reasoning=N] [--steps=N] [--sources=N] [--documents=N] [--files=N] [--meta] [--data=name] [--data-transient=name] [--delay-ms=min:max] [--chunk-chars=min:max] [--seed=N] [--finish=stop|length|tool-calls|content-filter|other] [--abort|--error]", + "stream [seconds] [--runs=N] [--profile=balanced|tools|errors|artifacts] [--seed=N] [--chars=N] [--terminal=stop|length|abort|error] [--delay-ms=min:max] [--stagger-ms=min:max] [--actions=N] [--no-approval] [--allow-abort] [--allow-error]", "stream-tools ... [common options]", - "stream-random [seconds] [--actions=N] [--profile=balanced|tools|artifacts|terminals] [--seed=N] [--delay-ms=min:max] [--allow-abort] [--allow-error] [--allow-approval]", - "stream-chaos [runs] [seconds] [--profile=balanced|tools|artifacts|terminals] [--seed=N] [--stagger-ms=min:max] [--max-actions=N] [--allow-abort] [--allow-error] [--allow-approval]", - "Notes: approval-tagged tools emit a separate Matrix approval event with reaction options.", + "Notes: stream enables approval requests by default; approval-tagged tools emit a separate Matrix approval event with reaction options.", }, "\n") } @@ -417,7 +451,7 @@ func defaultCommonOptions() commonCommandOptions { func parseLoremCommand(tokens []string) (*loremCommand, error) { if len(tokens) == 0 { - return nil, fmt.Errorf("stream-lorem requires a character count") + return nil, fmt.Errorf("text stream requires a character count") } count, err := parsePositiveInt(tokens[0], "character count") if err != nil { @@ -479,8 +513,28 @@ func parseRandomCommand(tokens []string) (*randomCommand, error) { Actions: 20, DelayMin: 350 * time.Millisecond, DelayMax: 1150 * time.Millisecond, + Runs: 1, + StaggerMin: 150 * time.Millisecond, + StaggerMax: 900 * time.Millisecond, sharedStreamOptions: sharedStreamOptions{Profile: "balanced"}, } + return parseStreamLikeCommand(tokens, cmd, false) +} + +func parseStreamCommand(tokens []string) (*randomCommand, error) { + cmd := &randomCommand{ + Duration: 20 * time.Second, + DelayMin: 350 * time.Millisecond, + DelayMax: 1150 * time.Millisecond, + Runs: 1, + StaggerMin: 150 * time.Millisecond, + StaggerMax: 900 * time.Millisecond, + sharedStreamOptions: sharedStreamOptions{Profile: "balanced", AllowApproval: true}, + } + return parseStreamLikeCommand(tokens, cmd, true) +} + +func parseStreamLikeCommand(tokens []string, cmd *randomCommand, deriveActions bool) (*randomCommand, error) { rest := tokens if len(rest) > 0 && !strings.HasPrefix(rest[0], "--") { seconds, err := parsePositiveInt(rest[0], "duration") @@ -493,6 +547,9 @@ func parseRandomCommand(tokens []string) (*randomCommand, error) { cmd.Duration = time.Duration(seconds) * time.Second rest = rest[1:] } + if deriveActions && cmd.Actions == 0 { + cmd.Actions = max(3, min(maxDemoRandomActions, int(cmd.Duration/time.Second)*2)) + } for _, token := range rest { key, value, hasValue := parseOptionToken(token) switch key { @@ -502,19 +559,53 @@ func parseRandomCommand(tokens []string) (*randomCommand, error) { return nil, err } cmd.Actions = n + case "chars": + n, err := parseValidatedInt(value, hasValue, token, "character count", maxDemoChars, false) + if err != nil { + return nil, err + } + cmd.Chars = n case "delay-ms": minDelay, maxDelay, err := parseDurationRangeMS(value, hasValue, token) if err != nil { return nil, err } cmd.DelayMin, cmd.DelayMax = minDelay, maxDelay + case "terminal": + if !hasValue { + return nil, fmt.Errorf("%s requires a value", token) + } + switch strings.ToLower(value) { + case "stop", "finish": + cmd.Terminal = "finish" + case "abort", "error": + cmd.Terminal = strings.ToLower(value) + case "length", "tool-calls", "content-filter", "other": + cmd.Terminal = agui.NormalizeFinishReason(value) + default: + return nil, fmt.Errorf("unknown terminal %q", value) + } + case "runs": + n, err := parseValidatedInt(value, hasValue, token, "run count", maxDemoChaosRuns, false) + if err != nil { + return nil, err + } + cmd.Runs = n + case "stagger-ms": + minDelay, maxDelay, err := parseDurationRange(value, hasValue, token, "stagger-ms", maxDemoStagger) + if err != nil { + return nil, err + } + cmd.StaggerMin, cmd.StaggerMax = minDelay, maxDelay + case "no-approval": + cmd.AllowApproval = false default: handled, err := parseSharedStreamOption(key, value, hasValue, token, &cmd.sharedStreamOptions) if err != nil || !handled { if err != nil { return nil, err } - return nil, fmt.Errorf("unknown random option %q", token) + return nil, fmt.Errorf("unknown stream option %q", token) } } } @@ -678,7 +769,7 @@ func parseSharedStreamOption(key, value string, hasValue bool, token string, opt return false, fmt.Errorf("%s requires a value", token) } switch strings.ToLower(value) { - case "balanced", "tools", "artifacts", "terminals": + case "balanced", "tools", "errors", "artifacts": opts.Profile = strings.ToLower(value) default: return false, fmt.Errorf("unknown profile %q", value) @@ -696,8 +787,6 @@ func parseSharedStreamOption(key, value string, hasValue bool, token string, opt opts.AllowAbort = true case "allow-error": opts.AllowError = true - case "allow-approval": - opts.AllowApproval = true default: return false, nil } @@ -818,6 +907,15 @@ func (r aiRunner) runRandom(ctx context.Context, w *aistream.Writer, cmd randomC stepOpen := false stepName := "" actionOptions, actionWeightTotal := buildRandomActionOptions(cmd) + if cmd.Chars > 0 { + text := buildDemoVisibleText(cmd.Chars, rand.New(rand.NewSource(rng.Int63()))) + for _, chunk := range chunkText(text, rng, defaultChunkMin, defaultChunkMax) { + w.Text(chunk) + if err := r.runtime.sleep(ctx, r.sampleDelay(rng, cmd.DelayMin, cmd.DelayMax)); err != nil { + return err + } + } + } handleTool := func(spec toolSpec) error { if err := r.runToolSpec(ctx, w, spec, rng, defaultCommonOptions()); err != nil { if errors.Is(err, errApprovalRequested) && stepOpen { @@ -885,7 +983,7 @@ func (r aiRunner) runRandom(ctx context.Context, w *aistream.Writer, cmd randomC case randomActionFile: w.Custom("com.beeper.file", map[string]any{"url": fmt.Sprintf("mxc://dummybridge/random-file-%d", action+1), "mediaType": "application/octet-stream"}) case randomActionMetadata: - w.StateDelta(statePatch(map[string]any{"command": "stream-random", "seed": seed, "action": action + 1, "profile": cmd.Profile})) + w.StateDelta(statePatch(map[string]any{"command": "stream", "seed": seed, "action": action + 1, "profile": cmd.Profile})) case randomActionData: w.Custom("com.beeper.data", map[string]any{"name": "random", "value": map[string]any{"action": action + 1, "seed": seed}}) case randomActionDataTransient: @@ -895,11 +993,14 @@ func (r aiRunner) runRandom(ctx context.Context, w *aistream.Writer, cmd randomC if stepOpen { w.StepFinish(stepName) } - switch chooseRandomTerminal(cmd, rng) { + terminal := chooseRandomTerminal(cmd, rng) + switch terminal { case "abort": w.Abort("DummyBridge random mode aborted") case "error": w.Error("DummyBridge random mode failed") + case agui.FinishReasonLength, agui.FinishReasonToolCalls, agui.FinishReasonContentFilter, agui.FinishReasonOther: + w.Finish(terminal) default: w.Finish(agui.FinishReasonStop) } @@ -1094,82 +1195,6 @@ func statePatch(values map[string]any) []map[string]any { return patch } -func buildRandomActionOptions(cmd randomCommand) ([]randomActionOption, int) { - options := []randomActionOption{ - {randomActionText, 6}, - {randomActionThinking, 4}, - {randomActionStep, 2}, - {randomActionTool, 3}, - {randomActionToolFail, 2}, - {randomActionSource, 2}, - {randomActionDocument, 2}, - {randomActionFile, 2}, - {randomActionMetadata, 2}, - {randomActionData, 1}, - {randomActionDataTransient, 1}, - } - if cmd.AllowApproval { - options = append(options, randomActionOption{randomActionToolApproval, 2}) - } - switch cmd.Profile { - case "tools": - options = append(options, - randomActionOption{randomActionTool, 6}, - randomActionOption{randomActionToolFail, 4}, - randomActionOption{randomActionToolDeny, 3}, - ) - if cmd.AllowApproval { - options = append(options, randomActionOption{randomActionToolApproval, 4}) - } - case "artifacts": - options = append(options, - randomActionOption{randomActionSource, 4}, - randomActionOption{randomActionDocument, 4}, - randomActionOption{randomActionFile, 4}, - randomActionOption{randomActionMetadata, 3}, - randomActionOption{randomActionData, 3}, - randomActionOption{randomActionDataTransient, 3}, - ) - case "terminals": - options = append(options, randomActionOption{randomActionStep, 5}) - } - total := 0 - for _, option := range options { - total += option.weight - } - return options, total -} - -func pickWeighted(options []randomActionOption, total int, rng *rand.Rand) string { - if total <= 0 || len(options) == 0 { - return randomActionText - } - pick := rng.Intn(total) - for _, option := range options { - if pick < option.weight { - return option.name - } - pick -= option.weight - } - return randomActionText -} - -func chooseRandomTerminal(cmd randomCommand, rng *rand.Rand) string { - options := []string{"finish"} - if cmd.AllowAbort { - options = append(options, "abort") - } - if cmd.AllowError { - options = append(options, "error") - } - return options[rng.Intn(len(options))] -} - -func randomToolName(rng *rand.Rand) string { - names := []string{"search", "fetch", "summarize", "calendar", "shell", "files", "preview"} - return names[rng.Intn(len(names))] -} - func (r aiRunner) sampleDelay(rng *rand.Rand, minDelay, maxDelay time.Duration) time.Duration { if maxDelay <= minDelay { return minDelay diff --git a/pkg/connector/ai_runtime_test.go b/pkg/connector/ai_runtime_test.go index 1b8b8e0..7c18cdc 100644 --- a/pkg/connector/ai_runtime_test.go +++ b/pkg/connector/ai_runtime_test.go @@ -35,10 +35,9 @@ func TestParseCommandRejectsConflictingToolTags(t *testing.T) { func TestParseCommandRejectsInvalidProfilesAndOversizedOptions(t *testing.T) { tests := []string{ - "stream-random --profile=unknown", - "stream-lorem 100 --abort --error", - "stream-lorem 100 --finish=length --abort", - "stream-lorem 1000000", + "stream --profile=unknown", + "stream --terminal=unknown", + "stream --chars=1000000", "stream-tools 100 shell --chunk-chars=1:9999", } for _, input := range tests { @@ -51,12 +50,10 @@ func TestParseCommandRejectsInvalidProfilesAndOversizedOptions(t *testing.T) { func TestHelpTextMentionsCommandsOptionsAndToolTags(t *testing.T) { guide := helpText() for _, expected := range []string{ - "stream-lorem", "stream-tools", - "stream-random", - "stream-chaos", - "--data-transient", - "--allow-approval", + "stream", + "--profile=balanced|tools|errors|artifacts", + "--no-approval", "#provider", "#inputerror", } { @@ -67,7 +64,7 @@ func TestHelpTextMentionsCommandsOptionsAndToolTags(t *testing.T) { } func TestBuildAIRunLoremIncludesArtifactsStateAndMetadata(t *testing.T) { - run, err := buildAIRun(context.Background(), "run-1", "thread-1", "stream-lorem 400 --reasoning=80 --steps=2 --sources=1 --documents=1 --files=1 --meta --data=demo --data-transient=temp --seed=7 --chunk-chars=32:32", time.Unix(10, 0)) + run, err := buildAIRun(context.Background(), "run-1", "thread-1", "stream-tools 400 search --reasoning=80 --steps=2 --sources=1 --documents=1 --files=1 --meta --data=demo --data-transient=temp --seed=7 --chunk-chars=32:32", time.Unix(10, 0)) if err != nil { t.Fatal(err) } @@ -579,14 +576,14 @@ func TestBuildAIRunToolsProviderTagAddsRawEventPassthrough(t *testing.T) { } func TestBuildAIRunTerminalErrorAndAbortStates(t *testing.T) { - errorRun, err := buildAIRun(context.Background(), "run-error", "thread-1", "stream-lorem 80 --error --seed=7", time.Unix(10, 0)) + errorRun, err := buildAIRun(context.Background(), "run-error", "thread-1", "stream 1 --terminal=error --seed=7", time.Unix(10, 0)) if err != nil { t.Fatal(err) } if errorRun.Status.State != "error" { t.Fatalf("expected error status, got %#v", errorRun.Status) } - abortRun, err := buildAIRun(context.Background(), "run-abort", "thread-1", "stream-lorem 80 --abort --seed=7", time.Unix(10, 0)) + abortRun, err := buildAIRun(context.Background(), "run-abort", "thread-1", "stream 1 --terminal=abort --seed=7", time.Unix(10, 0)) if err != nil { t.Fatal(err) } @@ -596,7 +593,7 @@ func TestBuildAIRunTerminalErrorAndAbortStates(t *testing.T) { } func TestBuildAIRunOver64KBPacksTo58KCarriers(t *testing.T) { - run, err := buildAIRun(context.Background(), "run-1", "thread-1", "stream-lorem 70000 --seed=7 --chunk-chars=512:512", time.Unix(10, 0)) + run, err := buildAIRun(context.Background(), "run-1", "thread-1", "stream 1 --chars=70000 --actions=1 --seed=7", time.Unix(10, 0)) if err != nil { t.Fatal(err) } @@ -632,7 +629,7 @@ func TestBuildAIRunOver64KBPacksTo58KCarriers(t *testing.T) { } func TestBuildAIRunPlansChaosCreatesMultipleRuns(t *testing.T) { - plans, err := buildAIRunPlans(context.Background(), "run-chaos", "thread-1", "stream-chaos 3 1 --max-actions=3 --seed=7 --stagger-ms=1:1", time.Unix(10, 0), "ai", "AI") + plans, err := buildAIRunPlans(context.Background(), "run-chaos", "thread-1", "stream 1 --runs=3 --actions=3 --seed=7 --stagger-ms=1:1", time.Unix(10, 0), "ai", "AI") if err != nil { t.Fatal(err) } @@ -658,7 +655,7 @@ func TestBuildAIRunPlansChaosCreatesMultipleRuns(t *testing.T) { } func TestBuildAIRunRandomHonorsVirtualDelays(t *testing.T) { - run, err := buildAIRun(context.Background(), "run-1", "thread-1", "stream-random 3 --actions=4 --seed=7 --delay-ms=100:100", time.Unix(10, 0)) + run, err := buildAIRun(context.Background(), "run-1", "thread-1", "stream 3 --actions=4 --seed=7 --delay-ms=100:100", time.Unix(10, 0)) if err != nil { t.Fatal(err) } @@ -685,7 +682,7 @@ func TestBuildAIRunRandomHonorsVirtualDelays(t *testing.T) { func TestRandomModeApprovalPause(t *testing.T) { for seed := int64(1); seed <= 200; seed++ { - run, err := buildAIRun(context.Background(), "run-approval", "thread-approval", "stream-random 1 --profile=tools --allow-approval --seed="+strconv.FormatInt(seed, 10), time.Unix(10, 0)) + run, err := buildAIRun(context.Background(), "run-approval", "thread-approval", "stream 1 --profile=tools --seed="+strconv.FormatInt(seed, 10), time.Unix(10, 0)) if err != nil { t.Fatal(err) } @@ -801,6 +798,37 @@ func TestMultiApprovalContinuationKeepsLaterPrompts(t *testing.T) { if run.Status.State != "streaming" { t.Fatalf("expected continuation with pending approval to remain streaming, got %#v", run.Status) } + + secondCtx := aistream.ApprovalContext{ + ID: run.Prompts[0].ID, + ThreadID: approvalCtx.ThreadID, + RunID: approvalCtx.RunID, + MessageID: approvalCtx.MessageID, + Command: command, + ToolCallID: run.Prompts[0].ToolCallID, + ToolName: run.Prompts[0].ToolName, + TargetEvent: approvalCtx.TargetEvent, + AgentID: approvalCtx.AgentID, + AgentName: approvalCtx.AgentName, + SeqStart: 100, + PriorApprovals: []agui.ToolApprovalResponse{{ + ID: approvalCtx.ID, + Approved: true, + }}, + } + finished, err := buildAIApprovalContinuationRun(context.Background(), secondCtx, agui.ToolApprovalResponse{ + ID: secondCtx.ID, + Approved: true, + }, time.Unix(30, 0)) + if err != nil { + t.Fatal(err) + } + if finished.Status.State != "complete" { + t.Fatalf("second approval continuation should finish, got %#v", finished.Status) + } + if len(finished.Prompts) != 0 { + t.Fatalf("finished continuation should not keep prompts: %#v", finished.Prompts) + } } func TestApprovalContinuationReplaysRandomRunWithImplicitSeed(t *testing.T) { @@ -810,7 +838,7 @@ func TestApprovalContinuationReplaysRandomRunWithImplicitSeed(t *testing.T) { // seed and lose the original toolCallID. for tick := int64(1); tick <= 500; tick++ { now := time.Unix(tick, 0) - plans, err := buildAIRunPlans(context.Background(), "run-rand", "thread-rand", "stream-random 1 --profile=tools --allow-approval", now, "ai", "AI") + plans, err := buildAIRunPlans(context.Background(), "run-rand", "thread-rand", "stream 1 --profile=tools", now, "ai", "AI") if err != nil { t.Fatal(err) } @@ -855,7 +883,7 @@ func TestApprovalContinuationReplaysRandomRunWithImplicitSeed(t *testing.T) { } func TestChaosSubRunCommandIsParseable(t *testing.T) { - plans, err := buildAIRunPlans(context.Background(), "run-chaos", "thread-chaos", "stream-chaos 2 1 --allow-approval --seed=11", time.Unix(0, 0), "ai", "AI") + plans, err := buildAIRunPlans(context.Background(), "run-chaos", "thread-chaos", "stream 1 --runs=2 --seed=11", time.Unix(0, 0), "ai", "AI") if err != nil { t.Fatal(err) } @@ -863,8 +891,8 @@ func TestChaosSubRunCommandIsParseable(t *testing.T) { t.Fatalf("expected two chaos sub-runs, got %d", len(plans)) } for i, plan := range plans { - if !strings.HasPrefix(plan.EffectiveCommand, "stream-random ") { - t.Fatalf("chaos plan %d must render as stream-random, got %q", i, plan.EffectiveCommand) + if !strings.HasPrefix(plan.EffectiveCommand, "stream ") { + t.Fatalf("chaos plan %d must render as stream, got %q", i, plan.EffectiveCommand) } if !strings.Contains(plan.EffectiveCommand, "--seed=") { t.Fatalf("chaos sub-run command must include explicit seed: %q", plan.EffectiveCommand) diff --git a/pkg/connector/ai_stream_random.go b/pkg/connector/ai_stream_random.go new file mode 100644 index 0000000..7953fe7 --- /dev/null +++ b/pkg/connector/ai_stream_random.go @@ -0,0 +1,89 @@ +package connector + +import "math/rand" + +func buildRandomActionOptions(cmd randomCommand) ([]randomActionOption, int) { + options := []randomActionOption{ + {randomActionText, 6}, + {randomActionThinking, 4}, + {randomActionStep, 2}, + {randomActionTool, 3}, + {randomActionToolFail, 2}, + {randomActionSource, 2}, + {randomActionDocument, 2}, + {randomActionFile, 2}, + {randomActionMetadata, 2}, + {randomActionData, 1}, + {randomActionDataTransient, 1}, + } + if cmd.AllowApproval { + options = append(options, randomActionOption{randomActionToolApproval, 2}) + } + switch cmd.Profile { + case "tools": + options = append(options, + randomActionOption{randomActionTool, 6}, + randomActionOption{randomActionToolFail, 4}, + randomActionOption{randomActionToolDeny, 3}, + ) + if cmd.AllowApproval { + options = append(options, randomActionOption{randomActionToolApproval, 4}) + } + case "artifacts": + options = append(options, + randomActionOption{randomActionSource, 4}, + randomActionOption{randomActionDocument, 4}, + randomActionOption{randomActionFile, 4}, + randomActionOption{randomActionMetadata, 3}, + randomActionOption{randomActionData, 3}, + randomActionOption{randomActionDataTransient, 3}, + ) + case "errors": + options = append(options, + randomActionOption{randomActionToolFail, 7}, + randomActionOption{randomActionToolDeny, 5}, + randomActionOption{randomActionTool, 2}, + ) + if cmd.AllowApproval { + options = append(options, randomActionOption{randomActionToolApproval, 4}) + } + } + total := 0 + for _, option := range options { + total += option.weight + } + return options, total +} + +func pickWeighted(options []randomActionOption, total int, rng *rand.Rand) string { + if total <= 0 || len(options) == 0 { + return randomActionText + } + pick := rng.Intn(total) + for _, option := range options { + if pick < option.weight { + return option.name + } + pick -= option.weight + } + return randomActionText +} + +func chooseRandomTerminal(cmd randomCommand, rng *rand.Rand) string { + if cmd.Terminal != "" { + return cmd.Terminal + } + options := []string{"finish"} + if cmd.AllowAbort { + options = append(options, "abort") + } + if cmd.AllowError { + options = append(options, "error") + } + return options[rng.Intn(len(options))] +} + +func randomToolName(rng *rand.Rand) string { + names := []string{"search", "fetch", "summarize", "calendar", "shell", "files", "preview"} + return names[rng.Intn(len(names))] +} diff --git a/pkg/connector/client.go b/pkg/connector/client.go index 9297abd..209ee43 100644 --- a/pkg/connector/client.go +++ b/pkg/connector/client.go @@ -444,7 +444,7 @@ func isAIDemoCommandContent(content *event.MessageEventContent) bool { return false } switch strings.ToLower(tokens[0]) { - case "help", "/help", "!help", "stream-lorem", "stream-tools", "stream-random", "stream-chaos": + case "help", "/help", "!help", "stream", "stream-tools": return true case "dummybridge": return len(tokens) > 1 && strings.EqualFold(tokens[1], "help") @@ -846,6 +846,7 @@ func (dc *DummyClient) queueAIApprovalPrompt(portal *bridgev2.Portal, sender net AgentName: run.AgentName, Model: run.Model, SeqStart: prompt.SeqStart, + PriorApprovals: approvalResponsesBeforePrompt(run.Events, prompt.ID), PreviewText: run.Preview.Text, PreviewTruncated: run.Preview.Truncated, } @@ -925,9 +926,14 @@ func buildAIApprovalContinuationRun(ctx context.Context, approvalCtx aistream.Ap if response.ID == "" { response.ID = approvalCtx.ID } - run, err := buildAIRunFromCommandWithApprovals(ctx, approvalCtx.RunID, approvalCtx.ThreadID, now, cmd, approvalCtx.AgentID, approvalCtx.AgentName, map[string]agui.ToolApprovalResponse{ - approvalCtx.ID: response, - }) + approvals := make(map[string]agui.ToolApprovalResponse, len(approvalCtx.PriorApprovals)+1) + for _, prior := range approvalCtx.PriorApprovals { + if prior.ID != "" { + approvals[prior.ID] = prior + } + } + approvals[approvalCtx.ID] = response + run, err := buildAIRunFromCommandWithApprovals(ctx, approvalCtx.RunID, approvalCtx.ThreadID, now, cmd, approvalCtx.AgentID, approvalCtx.AgentName, approvals) if err != nil { return aistream.Run{}, err } @@ -952,6 +958,75 @@ func buildAIApprovalContinuationRun(ctx context.Context, approvalCtx aistream.Ap return *run, nil } +func approvalResponsesBeforePrompt(events []agui.Event, promptID string) []agui.ToolApprovalResponse { + if promptID == "" { + return nil + } + var responses []agui.ToolApprovalResponse + for _, evt := range events { + if evt["type"] != agui.EventCustom { + continue + } + name, _ := evt["name"].(string) + value, _ := evt["value"].(map[string]any) + if value == nil { + continue + } + if name == agui.ApprovalCustomRequested && aistream.ApprovalIDFromRequestedValue(value) == promptID { + return responses + } + if name != agui.ApprovalCustomResponded { + continue + } + if response, ok := approvalResponseFromAny(value["approval"]); ok && response.ID != "" { + responses = append(responses, response) + } + } + return responses +} + +func approvalResponseFromAny(value any) (agui.ToolApprovalResponse, bool) { + switch typed := value.(type) { + case agui.ToolApprovalResponse: + return typed, typed.ID != "" + case *agui.ToolApprovalResponse: + if typed == nil { + return agui.ToolApprovalResponse{}, false + } + return *typed, typed.ID != "" + case map[string]any: + return approvalResponseFromMap(typed) + default: + raw, err := json.Marshal(value) + if err != nil { + return agui.ToolApprovalResponse{}, false + } + var response agui.ToolApprovalResponse + if err = json.Unmarshal(raw, &response); err != nil { + return agui.ToolApprovalResponse{}, false + } + return response, response.ID != "" + } +} + +func approvalResponseFromMap(value map[string]any) (agui.ToolApprovalResponse, bool) { + idValue, _ := value["id"].(string) + if idValue == "" { + return agui.ToolApprovalResponse{}, false + } + response := agui.ToolApprovalResponse{ID: idValue} + if approved, ok := value["approved"].(bool); ok { + response.Approved = approved + } + if always, ok := value["always"].(bool); ok { + response.Always = always + } + if reason, ok := value["reason"].(string); ok { + response.Reason = reason + } + return response, true +} + func filterPendingPrompts(prompts []aistream.ApprovalPrompt, resolvedID string, events []agui.Event) []aistream.ApprovalPrompt { if len(prompts) == 0 { return nil diff --git a/pkg/connector/client_test.go b/pkg/connector/client_test.go index 673bf8f..45b1663 100644 --- a/pkg/connector/client_test.go +++ b/pkg/connector/client_test.go @@ -52,10 +52,9 @@ func TestAIDemoCommandContentOnlyMatchesExplicitDemoCommands(t *testing.T) { "/help", "!help", "dummybridge help", - "stream-lorem 100", + "stream 20", "stream-tools 100 shell", - "stream-random 1", - "stream-chaos 2 1", + "stream 1 --runs=2", } { if !isAIDemoCommandContent(&event.MessageEventContent{Body: body}) { t.Fatalf("expected AI demo command for %q", body) From ed96d086301cf51d1e55d5840b9a54783a102efd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sat, 23 May 2026 22:22:16 +0200 Subject: [PATCH 31/46] wip --- pkg/ai-stream/approval.go | 29 ++++--- pkg/connector/ai_runtime_test.go | 14 ++-- pkg/connector/client.go | 139 +++++++++++++------------------ pkg/connector/client_test.go | 14 ++++ 4 files changed, 94 insertions(+), 102 deletions(-) diff --git a/pkg/ai-stream/approval.go b/pkg/ai-stream/approval.go index d0dfd28..4b4b0d1 100644 --- a/pkg/ai-stream/approval.go +++ b/pkg/ai-stream/approval.go @@ -34,21 +34,20 @@ type ReactionEvent struct { } type ApprovalContext struct { - ID string `json:"id"` - ThreadID string `json:"threadId"` - RunID string `json:"runId"` - MessageID string `json:"messageId"` - Command string `json:"command"` - ToolCallID string `json:"toolCallId"` - ToolName string `json:"toolName"` - TargetEvent string `json:"target_event"` - AgentID string `json:"agentId,omitempty"` - AgentName string `json:"agentName,omitempty"` - Model string `json:"model,omitempty"` - SeqStart int `json:"seqStart,omitempty"` - PriorApprovals []agui.ToolApprovalResponse `json:"priorApprovals,omitempty"` - PreviewText string `json:"previewText,omitempty"` - PreviewTruncated bool `json:"previewTruncated,omitempty"` + ID string `json:"id"` + ThreadID string `json:"threadId"` + RunID string `json:"runId"` + MessageID string `json:"messageId"` + Command string `json:"command"` + ToolCallID string `json:"toolCallId"` + ToolName string `json:"toolName"` + TargetEvent string `json:"target_event"` + AgentID string `json:"agentId,omitempty"` + AgentName string `json:"agentName,omitempty"` + Model string `json:"model,omitempty"` + SeqStart int `json:"seqStart,omitempty"` + PreviewText string `json:"previewText,omitempty"` + PreviewTruncated bool `json:"previewTruncated,omitempty"` } // ApprovalSeqReservation is the size of the sequence-number window reserved diff --git a/pkg/connector/ai_runtime_test.go b/pkg/connector/ai_runtime_test.go index 7c18cdc..0885fd4 100644 --- a/pkg/connector/ai_runtime_test.go +++ b/pkg/connector/ai_runtime_test.go @@ -811,14 +811,16 @@ func TestMultiApprovalContinuationKeepsLaterPrompts(t *testing.T) { AgentID: approvalCtx.AgentID, AgentName: approvalCtx.AgentName, SeqStart: 100, - PriorApprovals: []agui.ToolApprovalResponse{{ + } + finished, err := buildAIApprovalContinuationRunWithApprovals(context.Background(), secondCtx, map[string]agui.ToolApprovalResponse{ + approvalCtx.ID: { ID: approvalCtx.ID, Approved: true, - }}, - } - finished, err := buildAIApprovalContinuationRun(context.Background(), secondCtx, agui.ToolApprovalResponse{ - ID: secondCtx.ID, - Approved: true, + }, + secondCtx.ID: { + ID: secondCtx.ID, + Approved: true, + }, }, time.Unix(30, 0)) if err != nil { t.Fatal(err) diff --git a/pkg/connector/client.go b/pkg/connector/client.go index 209ee43..1eee3a0 100644 --- a/pkg/connector/client.go +++ b/pkg/connector/client.go @@ -37,6 +37,12 @@ type DummyClient struct { approvalSelectionsOnce sync.Once approvalSelections *exsync.Map[string, string] + aiRunSessionsMu sync.Mutex + aiRunSessions map[string]*aiRunSession +} + +type aiRunSession struct { + Decisions map[string]agui.ToolApprovalResponse } var _ bridgev2.NetworkAPI = (*DummyClient)(nil) @@ -625,6 +631,7 @@ func (dc *DummyClient) queueAIRunStreamAndMetadata(portal *bridgev2.Portal, send // known, and finally emits the carriers and (if the run terminated) the // final metadata edit. func (dc *DummyClient) emitAIRunStream(portal *bridgev2.Portal, sender networkid.UserID, messageID networkid.MessageID, targetEventID id.EventID, run aistream.Run, command string, startSeq int, anchorAt time.Time) { + dc.ensureAIRunSession(run.RunID) carriers, err := aistream.PackRunFromSeq(run, string(targetEventID), aistream.CarrierBudgetBytes, startSeq) if err != nil { log.Warn().Err(err).Str("run_id", run.RunID).Msg("Failed to pack AI stream") @@ -846,7 +853,6 @@ func (dc *DummyClient) queueAIApprovalPrompt(portal *bridgev2.Portal, sender net AgentName: run.AgentName, Model: run.Model, SeqStart: prompt.SeqStart, - PriorApprovals: approvalResponsesBeforePrompt(run.Events, prompt.ID), PreviewText: run.Preview.Text, PreviewTruncated: run.Preview.Truncated, } @@ -890,7 +896,8 @@ func (dc *DummyClient) queueAIApprovalResponse(ctx context.Context, portal *brid response.ID = approvalCtx.ID } now := time.Now() - run, err := buildAIApprovalContinuationRun(ctx, approvalCtx, response, now) + approvals := dc.recordAIApprovalDecision(approvalCtx.RunID, response) + run, err := buildAIApprovalContinuationRunWithApprovals(ctx, approvalCtx, approvals, now) if err != nil { log.Warn().Err(err).Str("approval_id", approvalCtx.ID).Msg("Failed to build AI approval continuation") return @@ -919,20 +926,19 @@ func (dc *DummyClient) queueAIApprovalResponse(ctx context.Context, portal *brid } func buildAIApprovalContinuationRun(ctx context.Context, approvalCtx aistream.ApprovalContext, response agui.ToolApprovalResponse, now time.Time) (aistream.Run, error) { - cmd, err := parseCommand(approvalCtx.Command) - if err != nil { - return aistream.Run{}, err - } if response.ID == "" { response.ID = approvalCtx.ID } - approvals := make(map[string]agui.ToolApprovalResponse, len(approvalCtx.PriorApprovals)+1) - for _, prior := range approvalCtx.PriorApprovals { - if prior.ID != "" { - approvals[prior.ID] = prior - } + return buildAIApprovalContinuationRunWithApprovals(ctx, approvalCtx, map[string]agui.ToolApprovalResponse{ + response.ID: response, + }, now) +} + +func buildAIApprovalContinuationRunWithApprovals(ctx context.Context, approvalCtx aistream.ApprovalContext, approvals map[string]agui.ToolApprovalResponse, now time.Time) (aistream.Run, error) { + cmd, err := parseCommand(approvalCtx.Command) + if err != nil { + return aistream.Run{}, err } - approvals[approvalCtx.ID] = response run, err := buildAIRunFromCommandWithApprovals(ctx, approvalCtx.RunID, approvalCtx.ThreadID, now, cmd, approvalCtx.AgentID, approvalCtx.AgentName, approvals) if err != nil { return aistream.Run{}, err @@ -958,75 +964,6 @@ func buildAIApprovalContinuationRun(ctx context.Context, approvalCtx aistream.Ap return *run, nil } -func approvalResponsesBeforePrompt(events []agui.Event, promptID string) []agui.ToolApprovalResponse { - if promptID == "" { - return nil - } - var responses []agui.ToolApprovalResponse - for _, evt := range events { - if evt["type"] != agui.EventCustom { - continue - } - name, _ := evt["name"].(string) - value, _ := evt["value"].(map[string]any) - if value == nil { - continue - } - if name == agui.ApprovalCustomRequested && aistream.ApprovalIDFromRequestedValue(value) == promptID { - return responses - } - if name != agui.ApprovalCustomResponded { - continue - } - if response, ok := approvalResponseFromAny(value["approval"]); ok && response.ID != "" { - responses = append(responses, response) - } - } - return responses -} - -func approvalResponseFromAny(value any) (agui.ToolApprovalResponse, bool) { - switch typed := value.(type) { - case agui.ToolApprovalResponse: - return typed, typed.ID != "" - case *agui.ToolApprovalResponse: - if typed == nil { - return agui.ToolApprovalResponse{}, false - } - return *typed, typed.ID != "" - case map[string]any: - return approvalResponseFromMap(typed) - default: - raw, err := json.Marshal(value) - if err != nil { - return agui.ToolApprovalResponse{}, false - } - var response agui.ToolApprovalResponse - if err = json.Unmarshal(raw, &response); err != nil { - return agui.ToolApprovalResponse{}, false - } - return response, response.ID != "" - } -} - -func approvalResponseFromMap(value map[string]any) (agui.ToolApprovalResponse, bool) { - idValue, _ := value["id"].(string) - if idValue == "" { - return agui.ToolApprovalResponse{}, false - } - response := agui.ToolApprovalResponse{ID: idValue} - if approved, ok := value["approved"].(bool); ok { - response.Approved = approved - } - if always, ok := value["always"].(bool); ok { - response.Always = always - } - if reason, ok := value["reason"].(string); ok { - response.Reason = reason - } - return response, true -} - func filterPendingPrompts(prompts []aistream.ApprovalPrompt, resolvedID string, events []agui.Event) []aistream.ApprovalPrompt { if len(prompts) == 0 { return nil @@ -1176,6 +1113,46 @@ func (dc *DummyClient) queueAIRunFinalMetadata(portal *bridgev2.Portal, sender n dc.UserLogin.QueueRemoteEvent(aibridgev2.FinalMetadataEdit(portal.PortalKey, sender, messageID, run, time.Now())) } +func (dc *DummyClient) ensureAIRunSession(runID string) { + if dc == nil || runID == "" { + return + } + dc.aiRunSessionsMu.Lock() + defer dc.aiRunSessionsMu.Unlock() + if dc.aiRunSessions == nil { + dc.aiRunSessions = make(map[string]*aiRunSession) + } + if dc.aiRunSessions[runID] == nil { + dc.aiRunSessions[runID] = &aiRunSession{Decisions: make(map[string]agui.ToolApprovalResponse)} + } +} + +func (dc *DummyClient) recordAIApprovalDecision(runID string, response agui.ToolApprovalResponse) map[string]agui.ToolApprovalResponse { + decisions := make(map[string]agui.ToolApprovalResponse) + if response.ID == "" { + return decisions + } + if dc == nil || runID == "" { + decisions[response.ID] = response + return decisions + } + dc.aiRunSessionsMu.Lock() + defer dc.aiRunSessionsMu.Unlock() + if dc.aiRunSessions == nil { + dc.aiRunSessions = make(map[string]*aiRunSession) + } + session := dc.aiRunSessions[runID] + if session == nil { + session = &aiRunSession{Decisions: make(map[string]agui.ToolApprovalResponse)} + dc.aiRunSessions[runID] = session + } + session.Decisions[response.ID] = response + for id, decision := range session.Decisions { + decisions[id] = decision + } + return decisions +} + func (dc *DummyClient) HandleMatrixDeleteChat(ctx context.Context, msg *bridgev2.MatrixDeleteChat) error { // bridgev2 will delete the portal + Matrix room after this returns nil. // For dummybridge, there's no separate remote-side deletion to do. diff --git a/pkg/connector/client_test.go b/pkg/connector/client_test.go index 45b1663..63e2239 100644 --- a/pkg/connector/client_test.go +++ b/pkg/connector/client_test.go @@ -230,3 +230,17 @@ func TestAnnotateApprovalEventIDsAddsReactionTargetEventToStreamPrompt(t *testin } t.Fatal("missing approval-requested event") } + +func TestApprovalDecisionsAreStoredInRunSession(t *testing.T) { + client := &DummyClient{} + first := agui.ToolApprovalResponse{ID: "approval-1", Approved: true} + decisions := client.recordAIApprovalDecision("run-1", first) + if len(decisions) != 1 || !decisions["approval-1"].Approved { + t.Fatalf("bad first decisions: %#v", decisions) + } + second := agui.ToolApprovalResponse{ID: "approval-2", Approved: false, Reason: "denied"} + decisions = client.recordAIApprovalDecision("run-1", second) + if len(decisions) != 2 || !decisions["approval-1"].Approved || decisions["approval-2"].Reason != "denied" { + t.Fatalf("bad accumulated decisions: %#v", decisions) + } +} From b995f7ceac7780890124988ab0c3f750249db3b5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sat, 23 May 2026 22:24:59 +0200 Subject: [PATCH 32/46] wip --- pkg/connector/ai_runtime.go | 6 ++++++ pkg/connector/ai_runtime_test.go | 20 ++++++++++++++++++-- pkg/connector/ai_stream_random.go | 2 +- 3 files changed, 25 insertions(+), 3 deletions(-) diff --git a/pkg/connector/ai_runtime.go b/pkg/connector/ai_runtime.go index 3f2d9aa..4a612e3 100644 --- a/pkg/connector/ai_runtime.go +++ b/pkg/connector/ai_runtime.go @@ -961,6 +961,12 @@ func (r aiRunner) runRandom(ctx context.Context, w *aistream.Writer, cmd randomC stepOpen = true } case randomActionTool: + if cmd.AllowApproval && cmd.Profile == "balanced" && rng.Intn(24) == 0 { + if err := handleTool(toolSpec{Name: randomToolName(rng), Approval: true, SequenceIndex: action + 1}); err != nil { + return err + } + continue + } if err := handleTool(toolSpec{Name: randomToolName(rng), SequenceIndex: action + 1}); err != nil { return err } diff --git a/pkg/connector/ai_runtime_test.go b/pkg/connector/ai_runtime_test.go index 0885fd4..72695c6 100644 --- a/pkg/connector/ai_runtime_test.go +++ b/pkg/connector/ai_runtime_test.go @@ -703,14 +703,30 @@ func TestRandomModeApprovalPause(t *testing.T) { } func TestRandomProfilesCoverToolsArtifactsAndTransientData(t *testing.T) { + balanced := randomCommand{ + sharedStreamOptions: sharedStreamOptions{ + Profile: "balanced", + AllowApproval: true, + }, + } + seen := map[string]bool{} + rng := rand.New(rand.NewSource(2)) + for range 400 { + options, total := buildRandomActionOptions(balanced) + seen[pickWeighted(options, total, rng)] = true + } + if seen[randomActionToolApproval] { + t.Fatalf("balanced profile should keep approvals rare via tool-call promotion, seen=%#v", seen) + } + cmd := randomCommand{ sharedStreamOptions: sharedStreamOptions{ Profile: "tools", AllowApproval: true, }, } - seen := map[string]bool{} - rng := rand.New(rand.NewSource(4)) + seen = map[string]bool{} + rng = rand.New(rand.NewSource(4)) for range 400 { options, total := buildRandomActionOptions(cmd) seen[pickWeighted(options, total, rng)] = true diff --git a/pkg/connector/ai_stream_random.go b/pkg/connector/ai_stream_random.go index 7953fe7..a55d251 100644 --- a/pkg/connector/ai_stream_random.go +++ b/pkg/connector/ai_stream_random.go @@ -16,7 +16,7 @@ func buildRandomActionOptions(cmd randomCommand) ([]randomActionOption, int) { {randomActionData, 1}, {randomActionDataTransient, 1}, } - if cmd.AllowApproval { + if cmd.AllowApproval && cmd.Profile != "balanced" { options = append(options, randomActionOption{randomActionToolApproval, 2}) } switch cmd.Profile { From 4cc9902d3836007fcbf696ef0abdbb79fb3b5cee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sat, 23 May 2026 22:56:27 +0200 Subject: [PATCH 33/46] wip --- pkg/ai-stream/approval.go | 16 ++++++-- pkg/ai-stream/run.go | 14 ++++++- pkg/ai-stream/stream_test.go | 4 +- pkg/connector/ai_runtime.go | 67 +++++++++++++++++--------------- pkg/connector/ai_runtime_test.go | 21 +++++++++- 5 files changed, 81 insertions(+), 41 deletions(-) diff --git a/pkg/ai-stream/approval.go b/pkg/ai-stream/approval.go index 4b4b0d1..64d9df6 100644 --- a/pkg/ai-stream/approval.go +++ b/pkg/ai-stream/approval.go @@ -13,10 +13,11 @@ const ( ) type ApprovalChoice struct { - Key string `json:"key"` - Label string `json:"label"` - Alias string `json:"alias"` - Style string `json:"style,omitempty"` + Key string `json:"key"` + Label string `json:"label"` + Alias string `json:"alias"` + Style string `json:"style,omitempty"` + Shortcut string `json:"shortcut,omitempty"` } type ApprovalCleanup struct { @@ -67,6 +68,7 @@ type ApprovalRequestedValue struct { ApprovalMessageID string ApprovalEventID string Choices []ApprovalChoice + Metadata map[string]any } type ApprovalNotice struct { @@ -120,6 +122,9 @@ func (v ApprovalRequestedValue) Map() map[string]any { if v.ApprovalEventID != "" { value["approvalEventId"] = v.ApprovalEventID } + if len(v.Metadata) > 0 { + value["metadata"] = v.Metadata + } return value } @@ -146,6 +151,9 @@ func ApprovalChoicesAsAny(choices []ApprovalChoice) []any { if choice.Style != "" { item["style"] = choice.Style } + if choice.Shortcut != "" { + item["shortcut"] = choice.Shortcut + } out = append(out, item) } return out diff --git a/pkg/ai-stream/run.go b/pkg/ai-stream/run.go index 50dd379..b927c05 100644 --- a/pkg/ai-stream/run.go +++ b/pkg/ai-stream/run.go @@ -234,10 +234,16 @@ func (w *Writer) ToolStartWithMetadata(toolCallID, name string, index int, appro } func (w *Writer) ToolApprovalRequested(toolCallID, name string, input any, approval agui.ToolApproval) { + w.ToolApprovalRequestedWithMetadata(toolCallID, name, input, approval, nil) +} + +func (w *Writer) ToolApprovalRequestedWithMetadata(toolCallID, name string, input any, approval agui.ToolApproval, metadata map[string]any) { w.recordApprovalRequest(toolCallID, name, &approval) + value := NewApprovalRequestedValue(*w.Run, toolCallID, name, input, approval) + value.Metadata = metadata w.Add(w.builder.Custom( agui.ApprovalCustomRequested, - NewApprovalRequestedValue(*w.Run, toolCallID, name, input, approval).Map(), + value.Map(), )) } @@ -302,6 +308,7 @@ func (w *Writer) ToolApprovalResponded(toolCallID, name string, input any, respo } if response.Approved { result["state"] = agui.ToolResultStateComplete + result["status"] = "success" result["approved"] = true } else { reason := response.Reason @@ -309,6 +316,7 @@ func (w *Writer) ToolApprovalResponded(toolCallID, name string, input any, respo reason = "denied" } result["state"] = agui.ToolResultStateError + result["status"] = "denied" result["reason"] = reason } w.Add(w.builder.ToolCallEnd(toolCallID, name, input, jsonString(result), agui.ToolStateApprovalResponded)) @@ -321,6 +329,7 @@ func (w *Writer) ToolResult(toolCallID, content, state string) { func (w *Writer) ToolError(toolCallID, name string, input any, reason string) { w.Add(w.builder.ToolCallEnd(toolCallID, name, input, jsonString(map[string]any{ "state": agui.ToolResultStateError, + "status": "failed", "reason": reason, }), agui.ToolStateInputComplete)) } @@ -340,7 +349,8 @@ func (w *Writer) ToolDenied(toolCallID, name string, input any, approvalID, reas })) w.Add(w.builder.ToolCallEnd(toolCallID, name, input, jsonString(map[string]any{ "state": agui.ToolResultStateError, - "reason": "denied", + "status": "denied", + "reason": reason, }), agui.ToolStateApprovalResponded)) } diff --git a/pkg/ai-stream/stream_test.go b/pkg/ai-stream/stream_test.go index 1e39c54..4a04cd1 100644 --- a/pkg/ai-stream/stream_test.go +++ b/pkg/ai-stream/stream_test.go @@ -235,7 +235,7 @@ func TestFinalUIMessageCarriesToolCallMetadata(t *testing.T) { writer := NewWriter(run, func() time.Time { return time.Unix(10, 0) }) writer.ToolStartWithMetadata("tool-1", "calendar.get_events", 0, nil, map[string]any{ "displayName": "List Calendar Events", - "iconId": "3257-5951", + "iconUrl": "mxc://beeper.com/calendar", }) message := run.FinalUIMessage(0, true) @@ -243,7 +243,7 @@ func TestFinalUIMessageCarriesToolCallMetadata(t *testing.T) { t.Fatalf("expected one part, got %#v", message.Parts) } metadata, ok := message.Parts[0]["metadata"].(map[string]any) - if !ok || metadata["displayName"] != "List Calendar Events" || metadata["iconId"] != "3257-5951" { + if !ok || metadata["displayName"] != "List Calendar Events" || metadata["iconUrl"] != "mxc://beeper.com/calendar" { t.Fatalf("bad tool metadata: %#v", message.Parts[0]) } } diff --git a/pkg/connector/ai_runtime.go b/pkg/connector/ai_runtime.go index 4a612e3..8b81833 100644 --- a/pkg/connector/ai_runtime.go +++ b/pkg/connector/ai_runtime.go @@ -1021,7 +1021,8 @@ func (r aiRunner) runToolSpec(ctx context.Context, w *aistream.Writer, spec tool if spec.Approval { approval = &agui.ToolApproval{ID: approvalID, NeedsApproval: true} } - w.ToolStartWithMetadata(toolCallID, spec.Name, spec.SequenceIndex-1, approval, toolDisplayMetadata(spec.Name)) + displayMetadata := toolDisplayMetadata(spec.Name) + w.ToolStartWithMetadata(toolCallID, spec.Name, spec.SequenceIndex-1, approval, displayMetadata) annotateProviderRawEvent(w, spec, "tool_call_start") if spec.InputError { w.ToolArgs(toolCallID, jsonToolInput(input), nil) @@ -1062,7 +1063,7 @@ func (r aiRunner) runToolSpec(ctx context.Context, w *aistream.Writer, spec tool } w.ToolApprovalInputComplete(toolCallID, spec.Name, input) annotateProviderRawEvent(w, spec, "tool_call_input_complete") - w.ToolApprovalRequested(toolCallID, spec.Name, input, *approval) + w.ToolApprovalRequestedWithMetadata(toolCallID, spec.Name, input, *approval, displayMetadata) annotateProviderRawEvent(w, spec, "approval_requested") return errApprovalRequested case spec.Deny: @@ -1079,48 +1080,50 @@ func (r aiRunner) runToolSpec(ctx context.Context, w *aistream.Writer, spec tool } func toolDisplayMetadata(name string) map[string]any { - displayName := titleToolName(name) - metadata := map[string]any{ - "displayName": displayName, + type ToolProviderMetadata struct { + ID string `json:"id,omitempty"` + DisplayName string `json:"displayName,omitempty"` + IconURL string `json:"iconUrl,omitempty"` } + type ToolDisplayMetadata struct { + DisplayName string `json:"displayName,omitempty"` + Description string `json:"description,omitempty"` + IconURL string `json:"iconUrl,omitempty"` + Provider *ToolProviderMetadata `json:"provider,omitempty"` + } + + metadata := ToolDisplayMetadata{} switch strings.ToLower(name) { case "calendar.get_events", "google_calendar.get_events", "google-calendar.get-events": - metadata["displayName"] = "List Calendar Events" - metadata["iconId"] = "3257-5951" - metadata["provider"] = map[string]any{ - "id": "google-calendar", - "displayName": "Google Calendar", - "iconId": "3257-5951", + metadata.DisplayName = "List Calendar Events" + metadata.Provider = &ToolProviderMetadata{ + ID: "google-calendar", + DisplayName: "Google Calendar", } case "linear.list_issues", "linear.list-issues", "list_issues", "list-issues": - metadata["displayName"] = "List Issues" - metadata["iconId"] = "3257-5945" - metadata["provider"] = map[string]any{ - "id": "linear", - "displayName": "Linear", - "iconId": "3257-5945", + metadata.DisplayName = "List Issues" + metadata.Provider = &ToolProviderMetadata{ + ID: "linear", + DisplayName: "Linear", } case "shell": - metadata["displayName"] = "Run Command" - metadata["iconId"] = "3255-2310" + metadata.DisplayName = "Run Command" case "fetch": - metadata["displayName"] = "Fetch Web" - metadata["iconId"] = "source-placeholder" + metadata.DisplayName = "Fetch Web" } - return metadata + return compactJSONMap(metadata) } -func titleToolName(name string) string { - parts := strings.FieldsFunc(name, func(r rune) bool { - return r == '_' || r == '-' || r == '.' - }) - for i, part := range parts { - if part == "" { - continue - } - parts[i] = strings.ToUpper(part[:1]) + part[1:] +func compactJSONMap(value any) map[string]any { + raw, err := json.Marshal(value) + if err != nil { + return nil } - return strings.Join(parts, " ") + var out map[string]any + if err := json.Unmarshal(raw, &out); err != nil || len(out) == 0 { + return nil + } + return out } func approvalIDForRun(runID, toolCallID string) string { diff --git a/pkg/connector/ai_runtime_test.go b/pkg/connector/ai_runtime_test.go index 72695c6..3bb5462 100644 --- a/pkg/connector/ai_runtime_test.go +++ b/pkg/connector/ai_runtime_test.go @@ -134,7 +134,7 @@ func TestBuildAIRunToolsApprovalUsesAGUIApprovalAndPrompt(t *testing.T) { t.Fatalf("bad approval metadata: %#v", approval) } metadata, ok := evt["metadata"].(map[string]any) - if !ok || metadata["displayName"] != "Run Command" || metadata["iconId"] != "3255-2310" { + if !ok || metadata["displayName"] != "Run Command" { t.Fatalf("bad tool display metadata: %#v", evt["metadata"]) } foundToolStart = true @@ -161,6 +161,10 @@ func TestBuildAIRunToolsApprovalUsesAGUIApprovalAndPrompt(t *testing.T) { if value["approvalMessageId"] != "approval-run-1-dummy-tool-1-shell" { t.Fatalf("approval event should name the Matrix reaction target: %#v", value) } + metadata, ok := value["metadata"].(map[string]any) + if !ok || metadata["displayName"] != "Run Command" { + t.Fatalf("approval event should carry tool display metadata: %#v", value["metadata"]) + } choices, ok := value["choices"].([]aistream.ApprovalChoice) if !ok || len(choices) == 0 || choices[0].Key != aistream.ApprovalChoiceApprove { t.Fatalf("approval event should duplicate renderer choices: %#v", value["choices"]) @@ -183,6 +187,21 @@ func TestBuildAIRunToolsApprovalUsesAGUIApprovalAndPrompt(t *testing.T) { } } +func TestToolDisplayMetadataIsOptional(t *testing.T) { + if metadata := toolDisplayMetadata("unknown_tool"); metadata != nil { + t.Fatalf("unknown tools should not invent display metadata: %#v", metadata) + } + + metadata := toolDisplayMetadata("linear.list_issues") + provider, _ := metadata["provider"].(map[string]any) + if metadata["displayName"] != "List Issues" || provider["displayName"] != "Linear" { + t.Fatalf("bad known tool metadata: %#v", metadata) + } + if _, ok := metadata["iconId"]; ok { + t.Fatalf("metadata must not use iconId: %#v", metadata) + } +} + func TestApprovalPromptSeqStartsAtNextPackedCarrierSeq(t *testing.T) { run, err := buildAIRun(context.Background(), "run-1", "thread-1", "stream-tools 120 shell#approval --seed=7 --chunk-chars=32:32", time.Unix(10, 0)) if err != nil { From 6ee13736544848d2e2a3311f5f2498c3a5758305 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 24 May 2026 02:30:56 +0200 Subject: [PATCH 34/46] wip --- pkg/ai-stream/approval.go | 6 -- pkg/ai-stream/run.go | 77 +++++++++++++- pkg/ai-stream/stream_test.go | 84 +++++++++++++++ pkg/connector/ai_runtime.go | 63 +++++++---- pkg/connector/ai_runtime_test.go | 171 +++++++++++++++++++++++------- pkg/connector/ai_stream_random.go | 2 +- pkg/connector/ai_text.go | 19 +++- pkg/connector/client.go | 41 ++++++- 8 files changed, 389 insertions(+), 74 deletions(-) diff --git a/pkg/ai-stream/approval.go b/pkg/ai-stream/approval.go index 64d9df6..04783bf 100644 --- a/pkg/ai-stream/approval.go +++ b/pkg/ai-stream/approval.go @@ -51,12 +51,6 @@ type ApprovalContext struct { PreviewTruncated bool `json:"previewTruncated,omitempty"` } -// ApprovalSeqReservation is the size of the sequence-number window reserved -// for the continuation of a single approval prompt. Large enough to fit any -// realistic continuation stream without colliding with neighbouring prompts' -// reserved ranges or their (possibly nested) continuations. -const ApprovalSeqReservation = 10000 - type ApprovalRequestedValue struct { ThreadID string RunID string diff --git a/pkg/ai-stream/run.go b/pkg/ai-stream/run.go index b927c05..186fc8f 100644 --- a/pkg/ai-stream/run.go +++ b/pkg/ai-stream/run.go @@ -271,6 +271,12 @@ func (w *Writer) ToolArgs(toolCallID, delta string, args any) { } func (w *Writer) ToolEnd(toolCallID, name string, input, result any) { + if result == nil { + result = map[string]any{ + "state": agui.ToolResultStateComplete, + "status": "success", + } + } w.Add(w.builder.ToolCallEnd(toolCallID, name, input, jsonString(result), agui.ToolStateInputComplete)) } @@ -368,6 +374,18 @@ func jsonString(value any) any { return string(raw) } +func jsonValue(value any) any { + text, ok := value.(string) + if !ok { + return value + } + var parsed any + if err := json.Unmarshal([]byte(text), &parsed); err != nil { + return value + } + return parsed +} + func (w *Writer) StateSnapshot(state map[string]any) { w.Add(w.builder.StateSnapshot(state)) } @@ -562,7 +580,7 @@ func (t Run) FinalUIMessage(textBudget int, includeThinking bool) agui.UIMessage part["input"] = input } if result, ok := evt["result"]; ok { - part["output"] = result + part["output"] = jsonValue(result) } case agui.EventToolCallResult: toolCallID, _ := evt["toolCallId"].(string) @@ -592,11 +610,23 @@ func (t Run) FinalUIMessage(textBudget int, includeThinking bool) agui.UIMessage approvalByID[approvalMapID(approval)] = approval } case "com.beeper.source": - message.Parts = append(message.Parts, agui.MessagePart{"type": "source-url", "source": value}) + part := cloneValueMap(value) + part["type"] = "source-url" + if asString(part["sourceId"]) == "" { + part["sourceId"] = firstString(part["url"], part["title"]) + } + message.Parts = append(message.Parts, part) case "com.beeper.document": - message.Parts = append(message.Parts, agui.MessagePart{"type": "file", "file": value}) + part := cloneValueMap(value) + part["type"] = "source-document" + if asString(part["sourceId"]) == "" { + part["sourceId"] = firstString(part["id"], part["title"]) + } + message.Parts = append(message.Parts, part) case "com.beeper.file": - message.Parts = append(message.Parts, agui.MessagePart{"type": "file", "file": value}) + part := cloneValueMap(value) + part["type"] = "file" + message.Parts = append(message.Parts, part) case "com.beeper.data": message.Parts = append(message.Parts, agui.MessagePart{"type": "data-com-beeper-data", "data": value}) } @@ -610,6 +640,11 @@ func (t Run) FinalUIMessage(textBudget int, includeThinking bool) agui.UIMessage } } } + if t.Status.State != "" && t.Status.State != "streaming" { + for _, part := range toolParts { + finalizeOpenToolPart(part, t.Status.State) + } + } if textPart != nil { textPart["content"] = textContent.String() } @@ -636,6 +671,32 @@ func (t Run) FinalUIMessage(textBudget int, includeThinking bool) agui.UIMessage return message } +func finalizeOpenToolPart(part agui.MessagePart, runState string) { + if part == nil { + return + } + if _, hasOutput := part["output"]; hasOutput { + return + } + state, _ := part["state"].(string) + switch state { + case agui.ToolStateApprovalResponded: + return + } + reason := "run finalized before tool completed" + if runState == "aborted" { + reason = "run aborted before tool completed" + } else if runState == "error" { + reason = "run failed before tool completed" + } + part["state"] = agui.ToolStateInputComplete + part["output"] = map[string]any{ + "state": agui.ToolResultStateError, + "status": "failed", + "reason": reason, + } +} + func (t Run) InitialUIMessage() agui.UIMessage { message := agui.UIMessage{ ID: t.MessageID, @@ -700,6 +761,14 @@ func asString(value any) string { } } +func cloneValueMap(value map[string]any) agui.MessagePart { + cp := make(agui.MessagePart, len(value)+1) + for key, item := range value { + cp[key] = item + } + return cp +} + func firstString(values ...any) string { for _, value := range values { if text, ok := value.(string); ok && text != "" { diff --git a/pkg/ai-stream/stream_test.go b/pkg/ai-stream/stream_test.go index 4a04cd1..bfa9ee8 100644 --- a/pkg/ai-stream/stream_test.go +++ b/pkg/ai-stream/stream_test.go @@ -248,6 +248,90 @@ func TestFinalUIMessageCarriesToolCallMetadata(t *testing.T) { } } +func TestFinalUIMessageCarriesParsedToolOutputs(t *testing.T) { + run := NewRun("run-1", "thread-1", DefaultModel, "ai", "AI", time.Unix(10, 0)) + writer := NewWriter(run, func() time.Time { return time.Unix(10, 0) }) + writer.ToolStart("tool-1", "shell", 0, nil) + writer.ToolArgs("tool-1", `{"cmd":"pwd"}`, `{"cmd":"pwd"}`) + writer.ToolEnd("tool-1", "shell", map[string]any{"cmd": "pwd"}, nil) + writer.ToolStart("tool-2", "files", 1, nil) + writer.ToolError("tool-2", "files", map[string]any{"path": "/tmp/nope"}, "missing") + + message := run.FinalUIMessage(0, true) + if len(message.Parts) != 2 { + t.Fatalf("expected two tool parts, got %#v", message.Parts) + } + success, ok := message.Parts[0]["output"].(map[string]any) + if !ok || success["state"] != agui.ToolResultStateComplete || success["status"] != "success" { + t.Fatalf("success tool without result should emit terminal success output: %#v", message.Parts[0]) + } + failure, ok := message.Parts[1]["output"].(map[string]any) + if !ok || failure["state"] != agui.ToolResultStateError || failure["status"] != "failed" || failure["reason"] != "missing" { + t.Fatalf("failed tool output should be parsed and terminal: %#v", message.Parts[1]) + } +} + +func TestFinalUIMessageFailsOpenToolsWhenRunFinalized(t *testing.T) { + run := NewRun("run-1", "thread-1", DefaultModel, "ai", "AI", time.Unix(10, 0)) + writer := NewWriter(run, func() time.Time { return time.Unix(10, 0) }) + writer.ToolStart("tool-1", "summarize", 0, nil) + writer.ToolStart("tool-2", "calendar", 1, nil) + writer.Finish(agui.FinishReasonStop) + + message := run.FinalUIMessage(0, true) + if len(message.Parts) != 2 { + t.Fatalf("expected two tool parts, got %#v", message.Parts) + } + for _, part := range message.Parts { + if part["state"] != agui.ToolStateInputComplete { + t.Fatalf("open tool should be finalized as input-complete: %#v", part) + } + output, ok := part["output"].(map[string]any) + if !ok || output["state"] != agui.ToolResultStateError || output["status"] != "failed" { + t.Fatalf("open tool should get terminal failed output: %#v", part) + } + } +} + +func TestFinalUIMessageCarriesTopLevelArtifactsWithStableIDs(t *testing.T) { + run := NewRun("run-1", "thread-1", DefaultModel, "ai", "AI", time.Unix(10, 0)) + writer := NewWriter(run, func() time.Time { return time.Unix(10, 0) }) + writer.Custom("com.beeper.source", map[string]any{ + "sourceId": "source-1", + "url": "https://example.com/source", + "title": "Example Source", + }) + writer.Custom("com.beeper.document", map[string]any{ + "id": "doc-1", + "title": "Example Doc", + "mediaType": "text/plain", + }) + writer.Custom("com.beeper.file", map[string]any{ + "url": "mxc://example/file", + "mediaType": "application/octet-stream", + }) + + message := run.FinalUIMessage(0, true) + if len(message.Parts) != 3 { + t.Fatalf("expected artifact parts, got %#v", message.Parts) + } + if message.Parts[0]["type"] != "source-url" || message.Parts[0]["sourceId"] != "source-1" || message.Parts[0]["url"] != "https://example.com/source" { + t.Fatalf("bad source part shape: %#v", message.Parts[0]) + } + if _, hasNestedSource := message.Parts[0]["source"]; hasNestedSource { + t.Fatalf("source part should not nest payload: %#v", message.Parts[0]) + } + if message.Parts[1]["type"] != "source-document" || message.Parts[1]["sourceId"] != "doc-1" || message.Parts[1]["id"] != "doc-1" { + t.Fatalf("bad document part shape: %#v", message.Parts[1]) + } + if message.Parts[2]["type"] != "file" || message.Parts[2]["url"] != "mxc://example/file" { + t.Fatalf("bad file part shape: %#v", message.Parts[2]) + } + if _, hasNestedFile := message.Parts[2]["file"]; hasNestedFile { + t.Fatalf("file part should not nest payload: %#v", message.Parts[2]) + } +} + func TestApprovalResolverMatchesEmojiKeysAndAliases(t *testing.T) { choices := DefaultApprovalChoices() for _, key := range []string{"✅", "approve"} { diff --git a/pkg/connector/ai_runtime.go b/pkg/connector/ai_runtime.go index 8b81833..0edbe05 100644 --- a/pkg/connector/ai_runtime.go +++ b/pkg/connector/ai_runtime.go @@ -916,8 +916,12 @@ func (r aiRunner) runRandom(ctx context.Context, w *aistream.Writer, cmd randomC } } } + approvalRequested := false handleTool := func(spec toolSpec) error { if err := r.runToolSpec(ctx, w, spec, rng, defaultCommonOptions()); err != nil { + if spec.Approval { + approvalRequested = true + } if errors.Is(err, errApprovalRequested) && stepOpen { w.StepFinish(stepName) stepOpen = false @@ -925,6 +929,9 @@ func (r aiRunner) runRandom(ctx context.Context, w *aistream.Writer, cmd randomC } return err } + if spec.Approval { + approvalRequested = true + } return nil } for action := range cmd.Actions { @@ -945,7 +952,8 @@ func (r aiRunner) runRandom(ctx context.Context, w *aistream.Writer, cmd randomC } switch pickWeighted(actionOptions, actionWeightTotal, rng) { case randomActionText: - for _, chunk := range chunkText(buildDemoVisibleText(40+rng.Intn(160), rand.New(rand.NewSource(rng.Int63()))), rng, defaultChunkMin, defaultChunkMax) { + text := "\n\n" + buildDemoVisibleText(40+rng.Intn(160), rand.New(rand.NewSource(rng.Int63()))) + for _, chunk := range chunkText(text, rng, defaultChunkMin, defaultChunkMax) { w.Text(chunk) } case randomActionThinking: @@ -961,7 +969,7 @@ func (r aiRunner) runRandom(ctx context.Context, w *aistream.Writer, cmd randomC stepOpen = true } case randomActionTool: - if cmd.AllowApproval && cmd.Profile == "balanced" && rng.Intn(24) == 0 { + if cmd.AllowApproval && cmd.Profile == "balanced" && action >= 10 && !approvalRequested { if err := handleTool(toolSpec{Name: randomToolName(rng), Approval: true, SequenceIndex: action + 1}); err != nil { return err } @@ -983,7 +991,8 @@ func (r aiRunner) runRandom(ctx context.Context, w *aistream.Writer, cmd randomC return err } case randomActionSource: - w.Custom("com.beeper.source", map[string]any{"url": fmt.Sprintf("https://dummybridge.local/random/source/%d", action+1), "title": fmt.Sprintf("Random Source %d", action+1)}) + sourceID := fmt.Sprintf("random-source-%d", action+1) + w.Custom("com.beeper.source", map[string]any{"sourceId": sourceID, "url": fmt.Sprintf("https://dummybridge.local/random/source/%d", action+1), "title": fmt.Sprintf("Random Source %d", action+1)}) case randomActionDocument: w.Custom("com.beeper.document", map[string]any{"id": fmt.Sprintf("random-doc-%d", action+1), "title": fmt.Sprintf("Random Document %d", action+1), "mediaType": "text/plain"}) case randomActionFile: @@ -1015,7 +1024,7 @@ func (r aiRunner) runRandom(ctx context.Context, w *aistream.Writer, cmd randomC func (r aiRunner) runToolSpec(ctx context.Context, w *aistream.Writer, spec toolSpec, rng *rand.Rand, opts commonCommandOptions) error { toolCallID := fmt.Sprintf("dummy-tool-%d-%s", spec.SequenceIndex, sanitizeToolName(spec.Name)) - input := map[string]any{"tool": spec.Name, "sequence": spec.SequenceIndex, "tags": spec.Tags} + input := toolRequestInput(spec) approvalID := approvalIDForRun(w.Run.RunID, toolCallID) var approval *agui.ToolApproval if spec.Approval { @@ -1025,27 +1034,32 @@ func (r aiRunner) runToolSpec(ctx context.Context, w *aistream.Writer, spec tool w.ToolStartWithMetadata(toolCallID, spec.Name, spec.SequenceIndex-1, approval, displayMetadata) annotateProviderRawEvent(w, spec, "tool_call_start") if spec.InputError { - w.ToolArgs(toolCallID, jsonToolInput(input), nil) - annotateProviderRawEvent(w, spec, "tool_call_args") + if encodedInput := jsonToolInput(input); encodedInput != "" { + w.ToolArgs(toolCallID, encodedInput, nil) + annotateProviderRawEvent(w, spec, "tool_call_args") + } w.ToolError(toolCallID, spec.Name, input, "input-error") annotateProviderRawEvent(w, spec, "tool_call_error") return nil } if spec.Delta { - for _, chunk := range chunkText(fmt.Sprintf("{\"tool\":%q,\"sequence\":%d}", spec.Name, spec.SequenceIndex), rng, opts.ChunkMin, opts.ChunkMax) { - w.ToolArgs(toolCallID, chunk, nil) - annotateProviderRawEvent(w, spec, "tool_call_args") - if err := r.runtime.sleep(ctx, r.sampleDelay(rng, opts.DelayMin, opts.DelayMax)); err != nil { - return err + if encodedInput := jsonToolInput(input); encodedInput != "" { + for _, chunk := range chunkText(encodedInput, rng, opts.ChunkMin, opts.ChunkMax) { + w.ToolArgs(toolCallID, chunk, nil) + annotateProviderRawEvent(w, spec, "tool_call_args") + if err := r.runtime.sleep(ctx, r.sampleDelay(rng, opts.DelayMin, opts.DelayMax)); err != nil { + return err + } } } } else { - encodedInput := jsonToolInput(input) - w.ToolArgs(toolCallID, encodedInput, encodedInput) - annotateProviderRawEvent(w, spec, "tool_call_args") + if encodedInput := jsonToolInput(input); encodedInput != "" { + w.ToolArgs(toolCallID, encodedInput, encodedInput) + annotateProviderRawEvent(w, spec, "tool_call_args") + } } if spec.Preliminary { - w.ToolResult(toolCallID, fmt.Sprintf(`{"state":%q,"tool":%q}`, agui.ToolResultStateStreaming, spec.Name), agui.ToolResultStateStreaming) + w.ToolResult(toolCallID, fmt.Sprintf(`{"state":%q}`, agui.ToolResultStateStreaming), agui.ToolResultStateStreaming) annotateProviderRawEvent(w, spec, "tool_call_result") } switch { @@ -1073,12 +1087,16 @@ func (r aiRunner) runToolSpec(ctx context.Context, w *aistream.Writer, spec tool w.ToolError(toolCallID, spec.Name, input, "DummyBridge synthetic tool failure") annotateProviderRawEvent(w, spec, "tool_call_error") default: - w.ToolEnd(toolCallID, spec.Name, input, map[string]any{"status": "ok", "tool": spec.Name, "sequence": spec.SequenceIndex}) + w.ToolEnd(toolCallID, spec.Name, input, nil) annotateProviderRawEvent(w, spec, "tool_call_end") } return nil } +func toolRequestInput(spec toolSpec) any { + return nil +} + func toolDisplayMetadata(name string) map[string]any { type ToolProviderMetadata struct { ID string `json:"id,omitempty"` @@ -1143,10 +1161,16 @@ func annotateProviderRawEvent(w *aistream.Writer, spec toolSpec, stage string) { } } -func jsonToolInput(input map[string]any) string { +func jsonToolInput(input any) string { + if input == nil { + return "" + } + if inputMap, ok := input.(map[string]any); ok && len(inputMap) == 0 { + return "" + } raw, err := json.Marshal(input) if err != nil { - return "{}" + return "" } return string(raw) } @@ -1171,7 +1195,8 @@ func emitDecorations(w *aistream.Writer, opts commonCommandOptions, chars, step, w.StateDelta(statePatch(map[string]any{"command": "demo", "seed": seed, "step": step + 1})) } for i := range splitCount(opts.Sources, steps, step) { - w.Custom("com.beeper.source", map[string]any{"url": fmt.Sprintf("https://dummybridge.local/source/%d-%d", step+1, i+1), "title": fmt.Sprintf("Demo Source %d.%d", step+1, i+1)}) + sourceID := fmt.Sprintf("demo-source-%d-%d", step+1, i+1) + w.Custom("com.beeper.source", map[string]any{"sourceId": sourceID, "url": fmt.Sprintf("https://dummybridge.local/source/%d-%d", step+1, i+1), "title": fmt.Sprintf("Demo Source %d.%d", step+1, i+1)}) } for i := range splitCount(opts.Documents, steps, step) { w.Custom("com.beeper.document", map[string]any{"id": fmt.Sprintf("demo-doc-%d-%d", step+1, i+1), "title": fmt.Sprintf("Demo Document %d.%d", step+1, i+1), "mediaType": "text/plain"}) diff --git a/pkg/connector/ai_runtime_test.go b/pkg/connector/ai_runtime_test.go index 3bb5462..f8e15d7 100644 --- a/pkg/connector/ai_runtime_test.go +++ b/pkg/connector/ai_runtime_test.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "math/rand" + "regexp" "strconv" "strings" "testing" @@ -116,12 +117,8 @@ func TestBuildAIRunToolsApprovalUsesAGUIApprovalAndPrompt(t *testing.T) { t.Fatalf("approval prompt ID = %q, want run-scoped ID", run.Prompts[0].ID) } foundToolStart := false - seenArgsBeforeApproval := false seenApprovalStateBeforeCustom := false for _, evt := range run.Events { - if evt["type"] == agui.EventToolCallArgs { - seenArgsBeforeApproval = true - } if evt["type"] == agui.EventToolCallStart { if evt["state"] != agui.ToolStateApprovalRequested { t.Fatalf("expected approval-requested tool state, got %#v", evt) @@ -144,15 +141,15 @@ func TestBuildAIRunToolsApprovalUsesAGUIApprovalAndPrompt(t *testing.T) { t.Fatalf("approval tool must not downgrade to input-complete: %#v", evt) } if evt["state"] == agui.ToolStateApprovalRequested { - if evt["input"] == nil { - t.Fatalf("approval input-complete event should include final input: %#v", evt) + if evt["input"] != nil { + t.Fatalf("approval input-complete event should omit placeholder input: %#v", evt) } seenApprovalStateBeforeCustom = true } } if evt["type"] == agui.EventCustom && evt["name"] == agui.ApprovalCustomRequested { - if !seenArgsBeforeApproval || !seenApprovalStateBeforeCustom { - t.Fatalf("approval custom event should be emitted after tool args and approval state update: %#v", run.Events) + if !seenApprovalStateBeforeCustom { + t.Fatalf("approval custom event should be emitted after approval state update: %#v", run.Events) } value := evt["value"].(map[string]any) if _, hasOptions := value["options"]; hasOptions { @@ -169,8 +166,8 @@ func TestBuildAIRunToolsApprovalUsesAGUIApprovalAndPrompt(t *testing.T) { if !ok || len(choices) == 0 || choices[0].Key != aistream.ApprovalChoiceApprove { t.Fatalf("approval event should duplicate renderer choices: %#v", value["choices"]) } - if value["input"] == nil { - t.Fatalf("approval event should include final tool input: %#v", value) + if value["input"] != nil { + t.Fatalf("approval event should omit placeholder tool input: %#v", value) } } } @@ -259,7 +256,9 @@ func TestApprovalLifecycleCarriesNoticeTargetAndContinuation(t *testing.T) { if len(run.Prompts) != 1 { t.Fatalf("expected one approval prompt, got %#v", run.Prompts) } - initialCarriers, err := aistream.PackRunFromSeq(*run, "$anchor", aistream.CarrierBudgetBytes, 1) + sizingRun := *run + annotateApprovalEventIDs(&sizingRun, approvalEventIDPlaceholders(sizingRun.Prompts)) + initialCarriers, err := aistream.PackRunFromSeq(sizingRun, "$anchor", aistream.CarrierBudgetBytes, 1) if err != nil { t.Fatal(err) } @@ -303,8 +302,15 @@ func TestApprovalLifecycleCarriesNoticeTargetAndContinuation(t *testing.T) { annotatedValue, _ = env.Part["value"].(map[string]any) } } - if annotatedValue == nil || annotatedValue["approvalMessageId"] != prompt.ID || annotatedValue["approvalEventId"] != "$approval" { - t.Fatalf("approval-requested stream event missing Matrix target: %#v", annotatedValue) + if annotatedValue == nil || annotatedValue["approvalMessageId"] != prompt.ID { + t.Fatalf("approval-requested stream event missing approval message id: %#v", annotatedValue) + } + if annotatedValue["approvalEventId"] != "$approval" { + t.Fatalf("approval-requested stream event missing Matrix event target: %#v", annotatedValue) + } + annotatedCarriers = splitCarriersForTimedEmission(annotatedCarriers) + if annotatedNextSeq := aistream.NextSeq(annotatedCarriers); annotatedNextSeq != nextSeq { + t.Fatalf("approval event target changed stream sequence: initial=%d annotated=%d", nextSeq, annotatedNextSeq) } choices, ok := annotatedValue["choices"].([]any) if !ok || len(choices) != len(aistream.DefaultApprovalChoices()) { @@ -460,25 +466,27 @@ func TestBuildAIRunToolsDenyProducesStructuredDeniedResult(t *testing.T) { t.Fatalf("missing structured denied tool result: %#v", run.Events) } -func TestBuildAIRunToolsArgsAreJSONStrings(t *testing.T) { +func TestBuildAIRunToolsOmitPlaceholderArgsAndEmitTerminalResult(t *testing.T) { run, err := buildAIRun(context.Background(), "run-1", "thread-1", "stream-tools 120 shell --seed=7 --chunk-chars=32:32", time.Unix(10, 0)) if err != nil { t.Fatal(err) } for _, evt := range run.Events { - if evt["type"] != agui.EventToolCallArgs { - continue - } - args, ok := evt["args"].(string) - if !ok { - t.Fatalf("expected args to be JSON string, got %#v", evt["args"]) + if evt["type"] == agui.EventToolCallArgs { + t.Fatalf("plain demo tool should not emit placeholder args: %#v", evt) } - if !strings.Contains(args, `"tool":"shell"`) { - t.Fatalf("expected JSON tool args, got %q", args) + if evt["type"] == agui.EventToolCallEnd { + if evt["input"] != nil { + t.Fatalf("plain demo tool should omit placeholder input: %#v", evt) + } + result := jsonResultMap(t, evt["result"]) + if result["state"] != agui.ToolResultStateComplete || result["status"] != "success" { + t.Fatalf("plain demo tool should emit terminal success result: %#v", evt) + } + return } - return } - t.Fatal("missing TOOL_CALL_ARGS event") + t.Fatal("missing TOOL_CALL_END event") } func TestBuildAIRunToolsPrelimUsesAGUIToolResult(t *testing.T) { @@ -544,28 +552,32 @@ func TestBuildAIRunToolsFailureDeltaAndInputError(t *testing.T) { t.Fatal(err) } seenFailure := false - seenDelta := false seenInputError := false for _, evt := range run.Events { if evt["type"] != agui.EventToolCallEnd && evt["type"] != agui.EventToolCallArgs { continue } toolCallID, _ := evt["toolCallId"].(string) - if evt["type"] == agui.EventToolCallArgs && strings.Contains(toolCallID, "fetch") && evt["args"] == nil { - seenDelta = true + if evt["type"] == agui.EventToolCallArgs && strings.Contains(toolCallID, "fetch") { + t.Fatalf("delta tool without real input should not emit placeholder args: %#v", evt) } if evt["type"] == agui.EventToolCallEnd { - result := jsonResultMap(t, evt["result"]) - if strings.Contains(toolCallID, "shell") && result["state"] == agui.ToolResultStateError { - seenFailure = true + if strings.Contains(toolCallID, "shell") { + result := jsonResultMap(t, evt["result"]) + if result["state"] == agui.ToolResultStateError { + seenFailure = true + } } - if strings.Contains(toolCallID, "parser") && result["reason"] == "input-error" { - seenInputError = true + if strings.Contains(toolCallID, "parser") { + result := jsonResultMap(t, evt["result"]) + if result["reason"] == "input-error" { + seenInputError = true + } } } } - if !seenFailure || !seenDelta || !seenInputError { - t.Fatalf("missing tool tag coverage: failure=%v delta=%v inputError=%v", seenFailure, seenDelta, seenInputError) + if !seenFailure || !seenInputError { + t.Fatalf("missing tool tag coverage: failure=%v inputError=%v", seenFailure, seenInputError) } } @@ -595,14 +607,14 @@ func TestBuildAIRunToolsProviderTagAddsRawEventPassthrough(t *testing.T) { } func TestBuildAIRunTerminalErrorAndAbortStates(t *testing.T) { - errorRun, err := buildAIRun(context.Background(), "run-error", "thread-1", "stream 1 --terminal=error --seed=7", time.Unix(10, 0)) + errorRun, err := buildAIRun(context.Background(), "run-error", "thread-1", "stream 1 --terminal=error --seed=7 --no-approval", time.Unix(10, 0)) if err != nil { t.Fatal(err) } if errorRun.Status.State != "error" { t.Fatalf("expected error status, got %#v", errorRun.Status) } - abortRun, err := buildAIRun(context.Background(), "run-abort", "thread-1", "stream 1 --terminal=abort --seed=7", time.Unix(10, 0)) + abortRun, err := buildAIRun(context.Background(), "run-abort", "thread-1", "stream 1 --terminal=abort --seed=7 --no-approval", time.Unix(10, 0)) if err != nil { t.Fatal(err) } @@ -721,6 +733,32 @@ func TestRandomModeApprovalPause(t *testing.T) { t.Fatal("no approval action selected for tested random seeds") } +func TestBalancedStream50UsuallyPausesForApproval(t *testing.T) { + for seed := int64(1); seed <= 20; seed++ { + run, err := buildAIRun(context.Background(), "run-approval", "thread-approval", "stream 50 --seed="+strconv.FormatInt(seed, 10), time.Unix(10, 0)) + if err != nil { + t.Fatal(err) + } + if run.ApprovalID != "" { + return + } + } + t.Fatal("balanced stream 50 did not request approval for any sampled seed") +} + +func TestBalancedStream50DoesNotPauseImmediatelyForApproval(t *testing.T) { + run, err := buildAIRun(context.Background(), "run-approval", "thread-approval", "stream 50 --seed=1", time.Unix(10, 0)) + if err != nil { + t.Fatal(err) + } + if run.ApprovalID == "approval-run-approval-dummy-tool-1-calendar" { + t.Fatalf("balanced stream paused immediately for approval: %q", run.ApprovalID) + } + if strings.Contains(run.ApprovalID, "dummy-tool-1-") { + t.Fatalf("balanced stream paused on first action approval: %q", run.ApprovalID) + } +} + func TestRandomProfilesCoverToolsArtifactsAndTransientData(t *testing.T) { balanced := randomCommand{ sharedStreamOptions: sharedStreamOptions{ @@ -802,6 +840,67 @@ func TestBuildDemoVisibleTextIsMarkdownRichAndDeterministic(t *testing.T) { t.Fatalf("expected markdown-rich text, got %q", first) } +func TestBuildDemoVisibleTextDoesNotCutMarkdownSyntax(t *testing.T) { + for _, chars := range []int{24, 40, 60, 80, 96, 120, 180, 260, 420} { + for seed := int64(1); seed <= 80; seed++ { + text := buildDemoVisibleText(chars, rand.New(rand.NewSource(seed))) + if strings.Count(text, "[") != strings.Count(text, "]") { + t.Fatalf("unbalanced brackets for chars=%d seed=%d: %q", chars, seed, text) + } + assertCompleteMarkdownLinks(t, chars, seed, text) + if strings.Count(text, "```")%2 != 0 { + t.Fatalf("unbalanced code fence for chars=%d seed=%d: %q", chars, seed, text) + } + if strings.Contains(text, "https://dummybridge.") && !strings.Contains(text, "https://dummybridge.local/") { + t.Fatalf("cut markdown URL for chars=%d seed=%d: %q", chars, seed, text) + } + } + } +} + +func TestRandomStreamTextBlocksKeepMarkdownBoundaries(t *testing.T) { + for seed := int64(1); seed <= 80; seed++ { + run, err := buildAIRun(context.Background(), "run-markdown", "thread-markdown", "stream 40 --seed="+strconv.FormatInt(seed, 10)+" --no-approval", time.Unix(10, 0)) + if err != nil { + t.Fatal(err) + } + text := run.Text() + assertCompleteMarkdownLinks(t, 40, seed, text) + if strings.Count(text, "[") != strings.Count(text, "]") { + t.Fatalf("unbalanced brackets for seed=%d: %q", seed, text) + } + if strings.Count(text, "```")%2 != 0 { + t.Fatalf("unbalanced code fence for seed=%d: %q", seed, text) + } + if joinedMarkdownBlockRE.MatchString(text) { + t.Fatalf("markdown block joined to previous text for seed=%d: %q", seed, text) + } + } +} + +var joinedMarkdownBlockRE = regexp.MustCompile(`[[:lower:]](Use |Review the \[|> )`) + +func assertCompleteMarkdownLinks(t *testing.T, chars int, seed int64, text string) { + t.Helper() + offset := 0 + for { + start := strings.Index(text[offset:], "](") + if start < 0 { + return + } + start += offset + close := strings.IndexByte(text[start+2:], ')') + if close < 0 { + t.Fatalf("unclosed markdown link for chars=%d seed=%d: %q", chars, seed, text) + } + linkTarget := text[start+2 : start+2+close] + if strings.ContainsAny(linkTarget, " \n\t") { + t.Fatalf("cut markdown link for chars=%d seed=%d: %q", chars, seed, text) + } + offset = start + 3 + close + } +} + func TestMultiApprovalContinuationKeepsLaterPrompts(t *testing.T) { command := "stream-tools 240 shell#approval fetch#approval --seed=7 --chunk-chars=32:32" approvalCtx := aistream.ApprovalContext{ diff --git a/pkg/connector/ai_stream_random.go b/pkg/connector/ai_stream_random.go index a55d251..606bcc2 100644 --- a/pkg/connector/ai_stream_random.go +++ b/pkg/connector/ai_stream_random.go @@ -8,7 +8,7 @@ func buildRandomActionOptions(cmd randomCommand) ([]randomActionOption, int) { {randomActionThinking, 4}, {randomActionStep, 2}, {randomActionTool, 3}, - {randomActionToolFail, 2}, + {randomActionToolFail, 1}, {randomActionSource, 2}, {randomActionDocument, 2}, {randomActionFile, 2}, diff --git a/pkg/connector/ai_text.go b/pkg/connector/ai_text.go index f9b083e..366bfbd 100644 --- a/pkg/connector/ai_text.go +++ b/pkg/connector/ai_text.go @@ -142,17 +142,14 @@ func chooseDemoSegment(specs []demoSegmentSpec, rng *rand.Rand, remaining int) s var candidates []demoSegmentSpec total := 0 for _, spec := range specs { - if remaining > 0 && remaining < spec.minLen/2 { + if remaining > 0 && remaining < spec.minLen { continue } candidates = append(candidates, spec) total += spec.weight } if len(candidates) == 0 { - candidates = specs - for _, spec := range candidates { - total += spec.weight - } + return specs[0].build(rng, remaining) } target := rng.Intn(total) for _, spec := range candidates { @@ -183,6 +180,9 @@ func trimVisibleText(text string, limit int) string { } if next > limit { if len(kept) == 0 { + if isMarkdownSensitiveBlock(block) { + return buildLoremText(limit, rand.New(rand.NewSource(int64(limit)))) + } kept = append(kept, trimText(block, limit)) } break @@ -196,6 +196,15 @@ func trimVisibleText(text string, limit int) string { return trimText(text, limit) } +func isMarkdownSensitiveBlock(block string) bool { + return strings.Contains(block, "](") || + strings.Contains(block, "```") || + strings.Contains(block, "\n|") || + strings.HasPrefix(block, "|") || + strings.HasPrefix(block, ">") || + strings.HasPrefix(block, "-") +} + func trimText(text string, limit int) string { text = strings.TrimSpace(text) if limit <= 0 || len(text) <= limit { diff --git a/pkg/connector/client.go b/pkg/connector/client.go index 1eee3a0..4a83a87 100644 --- a/pkg/connector/client.go +++ b/pkg/connector/client.go @@ -632,7 +632,9 @@ func (dc *DummyClient) queueAIRunStreamAndMetadata(portal *bridgev2.Portal, send // final metadata edit. func (dc *DummyClient) emitAIRunStream(portal *bridgev2.Portal, sender networkid.UserID, messageID networkid.MessageID, targetEventID id.EventID, run aistream.Run, command string, startSeq int, anchorAt time.Time) { dc.ensureAIRunSession(run.RunID) - carriers, err := aistream.PackRunFromSeq(run, string(targetEventID), aistream.CarrierBudgetBytes, startSeq) + sizingRun := run + annotateApprovalEventIDs(&sizingRun, approvalEventIDPlaceholders(sizingRun.Prompts)) + carriers, err := aistream.PackRunFromSeq(sizingRun, string(targetEventID), aistream.CarrierBudgetBytes, startSeq) if err != nil { log.Warn().Err(err).Str("run_id", run.RunID).Msg("Failed to pack AI stream") return @@ -640,8 +642,14 @@ func (dc *DummyClient) emitAIRunStream(portal *bridgev2.Portal, sender networkid carriers = splitCarriersForTimedEmission(carriers) nextSeq := aistream.NextSeq(carriers) approvalEventIDs := make(map[string]id.EventID, len(run.Prompts)) - for i, prompt := range run.Prompts { - prompt.SeqStart = nextSeq + i*aistream.ApprovalSeqReservation + if len(run.Prompts) > 1 { + log.Warn(). + Str("run_id", run.RunID). + Int("approval_prompts", len(run.Prompts)). + Msg("AI run produced multiple simultaneous approval prompts; using the same continuation sequence") + } + for _, prompt := range run.Prompts { + prompt.SeqStart = nextSeq ctx := dc.queueAIApprovalPrompt(portal, sender, run, prompt, targetEventID, command, time.Now()) if approvalEventID := dc.waitForMessageMXID(portal, networkid.MessageID(ctx.ID), 10*time.Second); approvalEventID != "" { approvalEventIDs[ctx.ID] = approvalEventID @@ -667,6 +675,19 @@ func (dc *DummyClient) emitAIRunStream(portal *bridgev2.Portal, sender networkid return } carriers = splitCarriersForTimedEmission(carriers) + if actualNextSeq := aistream.NextSeq(carriers); actualNextSeq != nextSeq { + log.Warn(). + Str("run_id", run.RunID). + Int("expected_next_seq", nextSeq). + Int("actual_next_seq", actualNextSeq). + Msg("AI approval event ID repack changed stream sequence count") + return + } + } else if len(run.Prompts) > 0 { + log.Info(). + Str("run_id", run.RunID). + Int("approval_prompts", len(run.Prompts)). + Msg("Sending approval stream without approval event IDs") } dc.queuePackedAICarriers(portal, sender, targetEventID, run, carriers, startSeq, anchorAt) if len(run.Prompts) > 0 && run.Status.State == "streaming" { @@ -886,6 +907,20 @@ func annotateApprovalEventIDs(run *aistream.Run, eventIDs map[string]id.EventID) } } +func approvalEventIDPlaceholders(prompts []aistream.ApprovalPrompt) map[string]id.EventID { + if len(prompts) == 0 { + return nil + } + placeholders := make(map[string]id.EventID, len(prompts)) + const placeholderEventID = "$approval_event_id_placeholder_padding_for_stable_ai_stream_sequence_000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000:beeper.local" + for _, prompt := range prompts { + if prompt.ID != "" { + placeholders[prompt.ID] = id.EventID(placeholderEventID) + } + } + return placeholders +} + func (dc *DummyClient) queueAIApprovalResponse(ctx context.Context, portal *bridgev2.Portal, approvalMessage *database.Message, response agui.ToolApprovalResponse) { approvalCtx, ok := dc.approvalContextForMessage(ctx, portal, approvalMessage) if !ok { From 21c24e09bd804df5b20f9aaa9a55808cb14ad86d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 24 May 2026 03:24:21 +0200 Subject: [PATCH 35/46] Handle reaction remove delete errors --- pkg/connector/client.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/pkg/connector/client.go b/pkg/connector/client.go index 4a83a87..284b4b2 100644 --- a/pkg/connector/client.go +++ b/pkg/connector/client.go @@ -399,10 +399,11 @@ func (dc *DummyClient) cleanupApprovalReactions(ctx context.Context, portal *bri } func (dc *DummyClient) HandleMatrixReactionRemove(ctx context.Context, msg *bridgev2.MatrixReactionRemove) error { - if dc != nil && dc.UserLogin != nil && dc.UserLogin.Bridge != nil && msg != nil && msg.TargetReaction != nil { - if err := dc.UserLogin.Bridge.DB.Reaction.Delete(ctx, msg.TargetReaction); err != nil { - log.Warn().Err(err).Stringer("reaction_mxid", msg.TargetReaction.MXID).Msg("Failed to delete reaction on remove") - } + if dc == nil || dc.UserLogin == nil || dc.UserLogin.Bridge == nil || dc.UserLogin.Bridge.DB == nil || msg == nil || msg.TargetReaction == nil { + return nil + } + if err := dc.UserLogin.Bridge.DB.Reaction.Delete(ctx, msg.TargetReaction); err != nil { + return fmt.Errorf("failed to delete reaction on remove: %w", err) } return nil } From bdf20e7880cf3a4f06af27aff061fe89ccc907bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 24 May 2026 03:24:41 +0200 Subject: [PATCH 36/46] Accumulate AI stream stagger delays --- pkg/connector/ai_runtime.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/pkg/connector/ai_runtime.go b/pkg/connector/ai_runtime.go index 0edbe05..bb79b75 100644 --- a/pkg/connector/ai_runtime.go +++ b/pkg/connector/ai_runtime.go @@ -312,10 +312,10 @@ func buildAIChaosRunPlans(ctx context.Context, baseRunID, threadID string, now t runner := aiRunner{runtime: virtualAIRuntime(now)} actions := max(3, min(cmd.MaxActions, int(cmd.Duration/time.Second))) plans := make([]aiRunPlan, 0, cmd.Runs) + var delay time.Duration for i := range cmd.Runs { - var delay time.Duration if i > 0 { - delay = runner.sampleDelay(rng, cmd.StaggerMin, cmd.StaggerMax) + delay += runner.sampleDelay(rng, cmd.StaggerMin, cmd.StaggerMax) } runID := fmt.Sprintf("%s-%d", baseRunID, i+1) randomCmd := randomCommand{ @@ -353,10 +353,11 @@ func buildAIStreamRunPlans(ctx context.Context, baseRunID, threadID string, now } rng := rand.New(rand.NewSource(seed)) plans := make([]aiRunPlan, 0, cmd.Runs) + runner := aiRunner{runtime: virtualAIRuntime(now)} + var delay time.Duration for i := range cmd.Runs { - var delay time.Duration if i > 0 { - delay = aiRunner{runtime: virtualAIRuntime(now)}.sampleDelay(rng, cmd.StaggerMin, cmd.StaggerMax) + delay += runner.sampleDelay(rng, cmd.StaggerMin, cmd.StaggerMax) } child := cmd child.Runs = 1 From 02620b9d093b46208ac1cca4bfe04f1586ecb477 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 24 May 2026 03:24:57 +0200 Subject: [PATCH 37/46] Preserve oversized markdown text --- pkg/connector/ai_text.go | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/pkg/connector/ai_text.go b/pkg/connector/ai_text.go index 366bfbd..40a6945 100644 --- a/pkg/connector/ai_text.go +++ b/pkg/connector/ai_text.go @@ -181,9 +181,10 @@ func trimVisibleText(text string, limit int) string { if next > limit { if len(kept) == 0 { if isMarkdownSensitiveBlock(block) { - return buildLoremText(limit, rand.New(rand.NewSource(int64(limit)))) + kept = append(kept, trimMarkdownBlock(block, limit)) + } else { + kept = append(kept, trimText(block, limit)) } - kept = append(kept, trimText(block, limit)) } break } @@ -205,6 +206,29 @@ func isMarkdownSensitiveBlock(block string) bool { strings.HasPrefix(block, "-") } +func trimMarkdownBlock(block string, limit int) string { + trimmed := trimText(block, limit) + if strings.Count(trimmed, "[") != strings.Count(trimmed, "]") { + if idx := strings.LastIndex(trimmed, "["); idx >= 0 { + trimmed = strings.TrimSpace(trimmed[:idx]) + } + } + if strings.Contains(trimmed, "](") && strings.Count(trimmed, "(") != strings.Count(trimmed, ")") { + if idx := strings.LastIndex(trimmed, "["); idx >= 0 { + trimmed = strings.TrimSpace(trimmed[:idx]) + } + } + if strings.Count(trimmed, "```")%2 != 0 { + if idx := strings.LastIndex(trimmed, "```"); idx >= 0 { + trimmed = strings.TrimSpace(trimmed[:idx]) + } + } + if trimmed == "" { + return trimText(block, limit) + } + return trimmed +} + func trimText(text string, limit int) string { text = strings.TrimSpace(text) if limit <= 0 || len(text) <= limit { From f081cb06ebf1f8213b7b006d0ff4611f22b55f96 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 24 May 2026 03:25:06 +0200 Subject: [PATCH 38/46] Avoid eager AI run sessions --- pkg/connector/client.go | 1 - 1 file changed, 1 deletion(-) diff --git a/pkg/connector/client.go b/pkg/connector/client.go index 284b4b2..3f1c2fc 100644 --- a/pkg/connector/client.go +++ b/pkg/connector/client.go @@ -632,7 +632,6 @@ func (dc *DummyClient) queueAIRunStreamAndMetadata(portal *bridgev2.Portal, send // known, and finally emits the carriers and (if the run terminated) the // final metadata edit. func (dc *DummyClient) emitAIRunStream(portal *bridgev2.Portal, sender networkid.UserID, messageID networkid.MessageID, targetEventID id.EventID, run aistream.Run, command string, startSeq int, anchorAt time.Time) { - dc.ensureAIRunSession(run.RunID) sizingRun := run annotateApprovalEventIDs(&sizingRun, approvalEventIDPlaceholders(sizingRun.Prompts)) carriers, err := aistream.PackRunFromSeq(sizingRun, string(targetEventID), aistream.CarrierBudgetBytes, startSeq) From 3154757f83259927eeb55bb014ac11217876e516 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 24 May 2026 03:25:20 +0200 Subject: [PATCH 39/46] Serialize simultaneous AI approvals --- pkg/connector/client.go | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/pkg/connector/client.go b/pkg/connector/client.go index 3f1c2fc..4b929fd 100644 --- a/pkg/connector/client.go +++ b/pkg/connector/client.go @@ -641,14 +641,16 @@ func (dc *DummyClient) emitAIRunStream(portal *bridgev2.Portal, sender networkid } carriers = splitCarriersForTimedEmission(carriers) nextSeq := aistream.NextSeq(carriers) - approvalEventIDs := make(map[string]id.EventID, len(run.Prompts)) - if len(run.Prompts) > 1 { + queuedPrompts := run.Prompts + if len(queuedPrompts) > 1 { log.Warn(). Str("run_id", run.RunID). - Int("approval_prompts", len(run.Prompts)). - Msg("AI run produced multiple simultaneous approval prompts; using the same continuation sequence") + Int("approval_prompts", len(queuedPrompts)). + Msg("AI run produced multiple simultaneous approval prompts; queueing the first prompt only") + queuedPrompts = queuedPrompts[:1] } - for _, prompt := range run.Prompts { + approvalEventIDs := make(map[string]id.EventID, len(queuedPrompts)) + for _, prompt := range queuedPrompts { prompt.SeqStart = nextSeq ctx := dc.queueAIApprovalPrompt(portal, sender, run, prompt, targetEventID, command, time.Now()) if approvalEventID := dc.waitForMessageMXID(portal, networkid.MessageID(ctx.ID), 10*time.Second); approvalEventID != "" { @@ -683,10 +685,10 @@ func (dc *DummyClient) emitAIRunStream(portal *bridgev2.Portal, sender networkid Msg("AI approval event ID repack changed stream sequence count") return } - } else if len(run.Prompts) > 0 { + } else if len(queuedPrompts) > 0 { log.Info(). Str("run_id", run.RunID). - Int("approval_prompts", len(run.Prompts)). + Int("approval_prompts", len(queuedPrompts)). Msg("Sending approval stream without approval event IDs") } dc.queuePackedAICarriers(portal, sender, targetEventID, run, carriers, startSeq, anchorAt) From 3862396788f4f1f94f1310bc9aa4131f095df771 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 24 May 2026 03:25:30 +0200 Subject: [PATCH 40/46] Require valid stream demo commands --- pkg/connector/client.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/pkg/connector/client.go b/pkg/connector/client.go index 4b929fd..e6a6ee8 100644 --- a/pkg/connector/client.go +++ b/pkg/connector/client.go @@ -446,13 +446,17 @@ func isAIDemoCommandContent(content *event.MessageEventContent) bool { if content == nil { return false } - tokens := strings.Fields(strings.TrimSpace(content.Body)) + body := strings.TrimSpace(content.Body) + tokens := strings.Fields(body) if len(tokens) == 0 { return false } switch strings.ToLower(tokens[0]) { - case "help", "/help", "!help", "stream", "stream-tools": + case "help", "/help", "!help": return true + case "stream", "stream-tools": + _, err := parseCommand(body) + return err == nil case "dummybridge": return len(tokens) > 1 && strings.EqualFold(tokens[1], "help") default: From 143a4c61fc886fb745a273df535c10695360f9b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 24 May 2026 03:26:00 +0200 Subject: [PATCH 41/46] Use member map in identifier resolution --- pkg/connector/client.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pkg/connector/client.go b/pkg/connector/client.go index e6a6ee8..5198eaa 100644 --- a/pkg/connector/client.go +++ b/pkg/connector/client.go @@ -1235,8 +1235,8 @@ func (dc *DummyClient) ResolveIdentifier(ctx context.Context, identifier string, ghostInfo, _ := dc.GetUserInfo(ctx, ghost) portalInfo, _ := dc.GetChatInfo(ctx, portal) portalInfo.Members = &bridgev2.ChatMemberList{ - Members: []bridgev2.ChatMember{ - { + MemberMap: bridgev2.ChatMemberMap{ + networkid.UserID(dc.UserLogin.ID): { EventSender: bridgev2.EventSender{ IsFromMe: true, Sender: networkid.UserID(dc.UserLogin.ID), @@ -1244,7 +1244,7 @@ func (dc *DummyClient) ResolveIdentifier(ctx context.Context, identifier string, Membership: event.MembershipJoin, PowerLevel: ptr.Ptr(50), }, - { + userID: { EventSender: bridgev2.EventSender{ Sender: userID, }, From c4893b742d41c159e578348954bd2b43ec40ea8a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 24 May 2026 04:09:49 +0200 Subject: [PATCH 42/46] move the pkgs --- go.mod | 1 + go.sum | 42 +- pkg/ag-ui/events.go | 688 -------------------- pkg/ag-ui/events_test.go | 144 ----- pkg/ai-stream/approval.go | 264 -------- pkg/ai-stream/bridgev2/events.go | 106 --- pkg/ai-stream/bridgev2/events_test.go | 100 --- pkg/ai-stream/matrix/content.go | 77 --- pkg/ai-stream/matrix/content_test.go | 202 ------ pkg/ai-stream/pack.go | 337 ---------- pkg/ai-stream/run.go | 888 -------------------------- pkg/ai-stream/stream_test.go | 460 ------------- pkg/connector/ai_runtime.go | 4 +- pkg/connector/ai_runtime_test.go | 4 +- pkg/connector/client.go | 6 +- pkg/connector/client_test.go | 4 +- 16 files changed, 12 insertions(+), 3315 deletions(-) delete mode 100644 pkg/ag-ui/events.go delete mode 100644 pkg/ag-ui/events_test.go delete mode 100644 pkg/ai-stream/approval.go delete mode 100644 pkg/ai-stream/bridgev2/events.go delete mode 100644 pkg/ai-stream/bridgev2/events_test.go delete mode 100644 pkg/ai-stream/matrix/content.go delete mode 100644 pkg/ai-stream/matrix/content_test.go delete mode 100644 pkg/ai-stream/pack.go delete mode 100644 pkg/ai-stream/run.go delete mode 100644 pkg/ai-stream/stream_test.go diff --git a/go.mod b/go.mod index a2add9c..d84f7a3 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( require ( filippo.io/edwards25519 v1.2.0 // indirect + github.com/beeper/ai-bridge v0.0.0-20260524020001-6f18f21c0e09 github.com/coder/websocket v1.8.14 // indirect github.com/coreos/go-systemd/v22 v22.7.0 // indirect github.com/lib/pq v1.12.3 // indirect diff --git a/go.sum b/go.sum index ecbcb12..5d90657 100644 --- a/go.sum +++ b/go.sum @@ -1,45 +1,29 @@ -filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= -filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= filippo.io/edwards25519 v1.2.0 h1:crnVqOiS4jqYleHd9vaKZ+HKtHfllngJIiOpNpoJsjo= filippo.io/edwards25519 v1.2.0/go.mod h1:xzAOLCNug/yB62zG1bQ8uziwrIqIuxhctzJT18Q77mc= github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= +github.com/beeper/ai-bridge v0.0.0-20260524020001-6f18f21c0e09 h1:NbH3OUYoEw2gGjg5VzBdPrT27J5HcKGUxj0/nYNFTqE= +github.com/beeper/ai-bridge v0.0.0-20260524020001-6f18f21c0e09/go.mod h1:0K/m+XXVLw1mX5gZ6gIIxDi5RDAgj09W++2eGREM8MI= github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g= github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg= -github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= -github.com/coreos/go-systemd/v22 v22.6.0 h1:aGVa/v8B7hpb0TKl0MWoAavPDmHvobFe5R5zn0bCJWo= -github.com/coreos/go-systemd/v22 v22.6.0/go.mod h1:iG+pp635Fo7ZmV/j14KUcmEyWF+0X7Lua8rrTWzYgWU= github.com/coreos/go-systemd/v22 v22.7.0 h1:LAEzFkke61DFROc7zNLX/WA2i5J8gYqe0rSj9KI28KA= github.com/coreos/go-systemd/v22 v22.7.0/go.mod h1:xNUYtjHu2EDXbsxz1i41wouACIwT7Ybq9o0BQhMwD0w= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= -github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= -github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lib/pq v1.12.3 h1:tTWxr2YLKwIvK90ZXEw8GP7UFHtcbTtty8zsI+YjrfQ= github.com/lib/pq v1.12.3/go.mod h1:/p+8NSbOcwzAEI7wiMXFlgydTwcgTr3OSKMsD2BitpA= -github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= -github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= -github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-sqlite3 v1.14.33 h1:A5blZ5ulQo2AtayQ9/limgHEkFreKj1Dv226a1K73s0= -github.com/mattn/go-sqlite3 v1.14.33/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/mattn/go-sqlite3 v1.14.44 h1:3VSe+xafpbzsLbdr2AWlAZk9yRHiBhTBakioXaCKTF8= github.com/mattn/go-sqlite3 v1.14.44/go.mod h1:pjEuOr8IwzLJP2MfGeTb0A35jauH+C2kbHKBr7yXKVQ= -github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741 h1:KPpdlQLZcHfTMQRi6bFQ7ogNO0ltFT4PmtwTLW4W+14= -github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= github.com/petermattis/goid v0.0.0-20260330135022-df67b199bc81 h1:WDsQxOJDy0N1VRAjXLpi8sCEZRSGarLWQevDxpTBRrM= github.com/petermattis/goid v0.0.0-20260330135022-df67b199bc81/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= -github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU= github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= -github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= -github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= github.com/rs/zerolog v1.35.1 h1:m7xQeoiLIiV0BCEY4Hs+j2NG4Gp2o2KPKmhnnLiazKI= github.com/rs/zerolog v1.35.1/go.mod h1:EjML9kdfa/RMA7h/6z6pYmq1ykOuA8/mjWaEvGI+jcw= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0= @@ -47,8 +31,6 @@ github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDq github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= -github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= -github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.19.0 h1:xwxm7n691Uf3u5OFjzngavjGTh55KX5q/9w9xHW88JU= github.com/tidwall/gjson v1.19.0/go.mod h1:V37/opeE/JbLUOfH0QTXiNez2l0RUjYUhpT4szFQAfc= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= @@ -58,41 +40,23 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= -github.com/yuin/goldmark v1.7.16 h1:n+CJdUxaFMiDUNnWC3dMWCIQJSkxH4uz3ZwQBkAlVNE= -github.com/yuin/goldmark v1.7.16/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= github.com/yuin/goldmark v1.8.2 h1:kEGpgqJXdgbkhcOgBxkC0X0PmoPG1ZyoZ117rDVp4zE= github.com/yuin/goldmark v1.8.2/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= -go.mau.fi/util v0.9.5 h1:7AoWPCIZJGv4jvtFEuCe3GhAbI7uF9ckIooaXvwlIR4= -go.mau.fi/util v0.9.5/go.mod h1:g1uvZ03VQhtTt2BgaRGVytS/Zj67NV0YNIECch0sQCQ= go.mau.fi/util v0.9.9 h1:ujDeXCo07HBor5oQLyO1tHklupmqVmPgasc53d7q/NE= go.mau.fi/util v0.9.9/go.mod h1:pqt4Vcrt+5gcH/CgrHZg11qSx+b34o6mknGzOEA6waY= go.mau.fi/zeroconfig v0.2.0 h1:e/OGEERqVRRKlgaro7E6bh8xXiKFSXB3eNNIud7FUjU= go.mau.fi/zeroconfig v0.2.0/go.mod h1:J0Vn0prHNOm493oZoQ84kq83ZaNCYZnq+noI1b1eN8w= -golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8= -golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A= golang.org/x/crypto v0.51.0 h1:IBPXwPfKxY7cWQZ38ZCIRPI50YLeevDLlLnyC5wRGTI= golang.org/x/crypto v0.51.0/go.mod h1:8AdwkbraGNABw2kOX6YFPs3WM22XqI4EXEd8g+x7Oc8= -golang.org/x/exp v0.0.0-20260112195511-716be5621a96 h1:Z/6YuSHTLOHfNFdb8zVZomZr7cqNgTJvA8+Qz75D8gU= -golang.org/x/exp v0.0.0-20260112195511-716be5621a96/go.mod h1:nzimsREAkjBCIEFtHiYkrJyT+2uy9YZJB7H1k68CXZU= golang.org/x/exp v0.0.0-20260508232706-74f9aab9d74a h1:+3jdDGGB8NGb1Zktc737jlt3/A5f6UlwSzmvqUuufxw= golang.org/x/exp v0.0.0-20260508232706-74f9aab9d74a/go.mod h1:d2fgXJLVs4dYDHUk5lwMIfzRzSrWCfGZb0ZqeLa/Vcw= -golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o= -golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8= golang.org/x/net v0.54.0 h1:2zJIZAxAHV/OHCDTCOHAYehQzLfSXuf/5SoL/Dv6w/w= golang.org/x/net v0.54.0/go.mod h1:Sj4oj8jK6XmHpBZU/zWHw3BV3abl4Kvi+Ut7cQcY+cQ= -golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= -golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= -golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= -golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/sys v0.44.0 h1:ildZl3J4uzeKP07r2F++Op7E9B29JRUy+a27EibtBTQ= golang.org/x/sys v0.44.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= -golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE= -golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8= golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc= golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= @@ -103,7 +67,5 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M= maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA= -maunium.net/go/mautrix v0.26.3-0.20260119125818-e28f7170bc4b h1:OaZ5Y1l4XACFlgy4BmZcCLdYPJZzgZWqZJnpdSITmoM= -maunium.net/go/mautrix v0.26.3-0.20260119125818-e28f7170bc4b/go.mod h1:CUxSZcjPtQNxsZLRQqETAxg2hiz7bjWT+L1HCYoMMKo= maunium.net/go/mautrix v0.28.0 h1:vBakLzf8MAdfED3NzAKiMeKQbc3AQ4EAS03NC+TVMXQ= maunium.net/go/mautrix v0.28.0/go.mod h1:/a9A7LGaqb9B3nho4tLd28n0EPcCdwpm2dxkxkLLgh0= diff --git a/pkg/ag-ui/events.go b/pkg/ag-ui/events.go deleted file mode 100644 index dd5ec6b..0000000 --- a/pkg/ag-ui/events.go +++ /dev/null @@ -1,688 +0,0 @@ -package agui - -import ( - "encoding/json" - "fmt" - "strings" - "time" -) - -const ( - EventRunStarted = "RUN_STARTED" - EventRunFinished = "RUN_FINISHED" - EventRunError = "RUN_ERROR" - EventTextMessageStart = "TEXT_MESSAGE_START" - EventTextMessageContent = "TEXT_MESSAGE_CONTENT" - EventTextMessageEnd = "TEXT_MESSAGE_END" - EventToolCallStart = "TOOL_CALL_START" - EventToolCallArgs = "TOOL_CALL_ARGS" - EventToolCallEnd = "TOOL_CALL_END" - EventToolCallResult = "TOOL_CALL_RESULT" - EventStepStarted = "STEP_STARTED" - EventStepFinished = "STEP_FINISHED" - EventStateSnapshot = "STATE_SNAPSHOT" - EventStateDelta = "STATE_DELTA" - EventMessagesSnapshot = "MESSAGES_SNAPSHOT" - EventCustom = "CUSTOM" - EventReasoningStart = "REASONING_START" - EventReasoningEnd = "REASONING_END" - EventReasoningMsgStart = "REASONING_MESSAGE_START" - EventReasoningMsgCont = "REASONING_MESSAGE_CONTENT" - EventReasoningMsgEnd = "REASONING_MESSAGE_END" -) - -const ( - RoleAssistant = "assistant" - RoleUser = "user" - RoleSystem = "system" - RoleTool = "tool" -) - -const ( - ToolStateAwaitingInput = "awaiting-input" - ToolStateInputStreaming = "input-streaming" - ToolStateInputComplete = "input-complete" - ToolStateApprovalRequested = "approval-requested" - ToolStateApprovalResponded = "approval-responded" - ToolResultStateStreaming = "streaming" - ToolResultStateComplete = "complete" - ToolResultStateError = "error" - PartStateStreaming = "streaming" - PartStateDone = "done" - ApprovalCustomRequested = "approval-requested" - ApprovalCustomResponded = "approval-responded" - FinishReasonStop = "stop" - FinishReasonLength = "length" - FinishReasonContentFilter = "content_filter" - FinishReasonToolCalls = "tool_calls" - FinishReasonOther = "other" -) - -type Event map[string]any - -type UIMessage struct { - ID string `json:"id"` - Role string `json:"role"` - Parts []MessagePart `json:"parts"` - CreatedAt *time.Time `json:"createdAt,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` -} - -type MessagePart map[string]any - -type RunAgentInput struct { - ThreadID string `json:"threadId,omitempty"` - RunID string `json:"runId,omitempty"` - State map[string]any `json:"state,omitempty"` - Messages []UIMessage `json:"messages,omitempty"` - Tools []Tool `json:"tools,omitempty"` - Context []ContextItem `json:"context,omitempty"` - ForwardedProps map[string]any `json:"forwardedProps,omitempty"` - Data map[string]any `json:"data,omitempty"` -} - -type Tool struct { - Name string `json:"name"` - Description string `json:"description,omitempty"` - InputSchema map[string]any `json:"inputSchema,omitempty"` - OutputSchema map[string]any `json:"outputSchema,omitempty"` - NeedsApproval bool `json:"needsApproval,omitempty"` -} - -type ContextItem struct { - Type string `json:"type"` - Value any `json:"value,omitempty"` - Meta map[string]any `json:"meta,omitempty"` -} - -type ToolApproval struct { - ID string `json:"id"` - NeedsApproval bool `json:"needsApproval"` - Fields map[string]any `json:"fields,omitempty"` -} - -type ToolApprovalResponse struct { - ID string `json:"id"` - Approved bool `json:"approved"` - Always bool `json:"always,omitempty"` - Reason string `json:"reason,omitempty"` - Fields map[string]any `json:"fields,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` -} - -type Usage struct { - PromptTokens int `json:"promptTokens,omitempty"` - CompletionTokens int `json:"completionTokens,omitempty"` - TotalTokens int `json:"totalTokens,omitempty"` -} - -type EventBuilder struct { - now func() time.Time - model string -} - -func NewEventBuilder(model string, now func() time.Time) EventBuilder { - if now == nil { - now = time.Now - } - return EventBuilder{now: now, model: strings.TrimSpace(model)} -} - -func (b EventBuilder) base(eventType string) Event { - evt := Event{ - "type": eventType, - "timestamp": b.now().UnixMilli(), - } - if b.model != "" { - evt["model"] = b.model - } - return evt -} - -func (b EventBuilder) RunStarted(threadID, runID string) Event { - evt := b.base(EventRunStarted) - evt["threadId"] = threadID - evt["runId"] = runID - return evt -} - -func (b EventBuilder) RunFinished(threadID, runID, finishReason string, usage Usage) Event { - evt := b.base(EventRunFinished) - evt["threadId"] = threadID - evt["runId"] = runID - evt["finishReason"] = NormalizeFinishReason(finishReason) - evt["usage"] = usage - return evt -} - -func (b EventBuilder) RunError(threadID, runID, message string) Event { - evt := b.base(EventRunError) - evt["threadId"] = threadID - if strings.TrimSpace(runID) != "" { - evt["runId"] = runID - } - evt["message"] = message - evt["error"] = map[string]any{"message": message} - return evt -} - -func (b EventBuilder) TextMessageStart(messageID, role string) Event { - if role == "" { - role = RoleAssistant - } - evt := b.base(EventTextMessageStart) - evt["messageId"] = messageID - evt["role"] = role - return evt -} - -func (b EventBuilder) TextMessageContent(messageID, delta string) Event { - evt := b.base(EventTextMessageContent) - evt["messageId"] = messageID - evt["delta"] = delta - return evt -} - -func (b EventBuilder) TextMessageEnd(messageID string) Event { - evt := b.base(EventTextMessageEnd) - evt["messageId"] = messageID - return evt -} - -func (b EventBuilder) ReasoningStart(messageID string) Event { - evt := b.base(EventReasoningStart) - evt["messageId"] = messageID - return evt -} - -func (b EventBuilder) ReasoningEnd(messageID string) Event { - evt := b.base(EventReasoningEnd) - evt["messageId"] = messageID - return evt -} - -func (b EventBuilder) ReasoningMessageStart(messageID string) Event { - evt := b.base(EventReasoningMsgStart) - evt["messageId"] = messageID - return evt -} - -func (b EventBuilder) ReasoningMessageContent(messageID, delta string) Event { - evt := b.base(EventReasoningMsgCont) - evt["messageId"] = messageID - evt["delta"] = delta - return evt -} - -func (b EventBuilder) ReasoningMessageEnd(messageID string) Event { - evt := b.base(EventReasoningMsgEnd) - evt["messageId"] = messageID - return evt -} - -func (b EventBuilder) ToolCallStart(messageID, toolCallID, name string, index *int, approval *ToolApproval) Event { - return b.ToolCallStartWithMetadata(messageID, toolCallID, name, index, approval, nil) -} - -func (b EventBuilder) ToolCallStartWithMetadata(messageID, toolCallID, name string, index *int, approval *ToolApproval, metadata map[string]any) Event { - evt := b.base(EventToolCallStart) - if messageID != "" { - evt["parentMessageId"] = messageID - } - evt["toolCallId"] = toolCallID - evt["toolCallName"] = name - evt["toolName"] = name - if len(metadata) > 0 { - evt["metadata"] = metadata - } - if index != nil { - evt["index"] = *index - } - if approval != nil { - evt["approval"] = approval - evt["state"] = ToolStateApprovalRequested - } else { - evt["state"] = ToolStateAwaitingInput - } - return evt -} - -func (b EventBuilder) ToolCallArgs(toolCallID, delta string, args any) Event { - evt := b.base(EventToolCallArgs) - evt["toolCallId"] = toolCallID - evt["delta"] = delta - evt["state"] = ToolStateInputStreaming - if args != nil { - evt["args"] = args - } - return evt -} - -func (b EventBuilder) ToolCallEnd(toolCallID, name string, input, result any, state string) Event { - evt := b.base(EventToolCallEnd) - evt["toolCallId"] = toolCallID - evt["toolCallName"] = name - evt["toolName"] = name - if input != nil { - evt["input"] = input - } - if result != nil { - evt["result"] = result - } - if state == "" { - state = ToolStateInputComplete - } - evt["state"] = state - return evt -} - -func (b EventBuilder) ToolCallResult(messageID, toolCallID, content, state, role string) Event { - if role == "" { - role = RoleTool - } - if state == "" { - state = ToolResultStateComplete - } - evt := b.base(EventToolCallResult) - evt["messageId"] = messageID - evt["toolCallId"] = toolCallID - evt["content"] = content - evt["state"] = state - evt["role"] = role - return evt -} - -func (b EventBuilder) StepStarted(messageID, stepName string) Event { - if stepName == "" { - panic("ag-ui: stepName is required for STEP_STARTED") - } - evt := b.base(EventStepStarted) - if messageID != "" { - evt["messageId"] = messageID - } - evt["stepName"] = stepName - return evt -} - -func (b EventBuilder) StepFinished(messageID, stepName string) Event { - if stepName == "" { - panic("ag-ui: stepName is required for STEP_FINISHED") - } - evt := b.base(EventStepFinished) - if messageID != "" { - evt["messageId"] = messageID - } - evt["stepName"] = stepName - return evt -} - -func (b EventBuilder) StateSnapshot(state map[string]any) Event { - evt := b.base(EventStateSnapshot) - evt["snapshot"] = state - return evt -} - -func (b EventBuilder) StateDelta(delta any) Event { - evt := b.base(EventStateDelta) - evt["delta"] = delta - return evt -} - -func (b EventBuilder) MessagesSnapshot(messages []UIMessage) Event { - evt := b.base(EventMessagesSnapshot) - evt["messages"] = messages - return evt -} - -func (b EventBuilder) Custom(name string, value any) Event { - evt := b.base(EventCustom) - evt["name"] = name - evt["value"] = value - return evt -} - -func TextPart(content string) MessagePart { - return MessagePart{"type": "text", "content": content} -} - -func ThinkingPart(content string) MessagePart { - return MessagePart{"type": "thinking", "content": content} -} - -func ToolCallPart(id, name string, arguments any, state string, approval *ToolApproval, output any) MessagePart { - part := MessagePart{"type": "tool-call", "id": id, "name": name, "arguments": arguments, "state": state} - if approval != nil { - part["approval"] = approval - } - if output != nil { - part["output"] = output - } - return part -} - -func ToolResultPart(toolCallID string, content any, state string, err any) MessagePart { - part := MessagePart{"type": "tool-result", "toolCallId": toolCallID, "content": content, "state": state} - if err != nil { - part["error"] = err - } - return part -} - -func ValidateEvent(evt Event) error { - eventType, _ := evt["type"].(string) - if eventType == "" { - return fmt.Errorf("ag-ui event missing type") - } - if _, ok := evt["timestamp"]; !ok { - return fmt.Errorf("%s missing timestamp", eventType) - } - switch eventType { - case EventRunStarted: - return require(evt, "threadId", "runId") - case EventRunFinished: - return require(evt, "threadId", "runId", "finishReason") - case EventRunError: - return require(evt, "message") - case EventTextMessageStart: - return require(evt, "messageId", "role") - case EventTextMessageContent: - if err := require(evt, "messageId"); err != nil { - return err - } - return requireStringField(evt, "delta") - case EventTextMessageEnd: - return require(evt, "messageId") - case EventReasoningStart, EventReasoningEnd, EventReasoningMsgStart, EventReasoningMsgEnd: - return require(evt, "messageId") - case EventReasoningMsgCont: - if err := require(evt, "messageId"); err != nil { - return err - } - return requireStringField(evt, "delta") - case EventToolCallStart: - if err := require(evt, "toolCallId", "toolCallName"); err != nil { - return err - } - if approval, ok := evt["approval"]; ok { - if err := validateToolApproval(approval); err != nil { - return fmt.Errorf("%s has invalid approval: %w", evt["type"], err) - } - } - return validateStringSet(evt, "state", true, validToolStates) - case EventToolCallArgs: - if err := require(evt, "toolCallId"); err != nil { - return err - } - if err := requireStringField(evt, "delta"); err != nil { - return err - } - if err := validateStringSet(evt, "state", false, validToolStates); err != nil { - return err - } - if args, ok := evt["args"]; ok { - if _, ok := args.(string); !ok { - return fmt.Errorf("%s has invalid args %T", evt["type"], args) - } - } - return nil - case EventToolCallEnd: - if err := require(evt, "toolCallId"); err != nil { - return err - } - if result, ok := evt["result"]; ok { - if _, ok := result.(string); !ok { - return fmt.Errorf("%s has invalid result %T", evt["type"], result) - } - } - return validateStringSet(evt, "state", true, validToolStates) - case EventToolCallResult: - if err := require(evt, "messageId", "toolCallId", "content"); err != nil { - return err - } - return validateStringSet(evt, "state", false, validToolResultStates) - case EventStepStarted, EventStepFinished: - return require(evt, "stepName") - case EventStateSnapshot: - return require(evt, "snapshot") - case EventStateDelta: - return require(evt, "delta") - case EventMessagesSnapshot: - return require(evt, "messages") - case EventCustom: - return require(evt, "name") - default: - return fmt.Errorf("unsupported ag-ui event type %q", eventType) - } -} - -func validateToolApproval(value any) error { - switch approval := value.(type) { - case ToolApproval: - if strings.TrimSpace(approval.ID) == "" { - return fmt.Errorf("missing id") - } - if !approval.NeedsApproval { - return fmt.Errorf("needsApproval must be true") - } - return nil - case *ToolApproval: - if approval == nil { - return fmt.Errorf("missing approval") - } - return validateToolApproval(*approval) - case map[string]any: - id, _ := approval["id"].(string) - if strings.TrimSpace(id) == "" { - return fmt.Errorf("missing id") - } - if approval["needsApproval"] != true { - return fmt.Errorf("needsApproval must be true") - } - return nil - default: - return fmt.Errorf("unexpected %T", value) - } -} - -func ValidateEventSequence(events []Event) error { - seenRunStart := false - terminal := false - textOpen := map[string]bool{} - reasoningOpen := map[string]bool{} - toolStarted := map[string]bool{} - toolEnded := map[string]bool{} - - for i, evt := range events { - if err := ValidateEvent(evt); err != nil { - return fmt.Errorf("event %d: %w", i+1, err) - } - eventType, _ := evt["type"].(string) - if terminal { - return fmt.Errorf("event %d: %s after terminal run event", i+1, eventType) - } - - switch eventType { - case EventRunStarted: - if seenRunStart { - return fmt.Errorf("event %d: duplicate RUN_STARTED", i+1) - } - seenRunStart = true - case EventRunFinished: - if !seenRunStart { - return fmt.Errorf("event %d: RUN_FINISHED before RUN_STARTED", i+1) - } - terminal = true - case EventRunError: - terminal = true - case EventTextMessageStart: - messageID := stringField(evt, "messageId") - if textOpen[messageID] { - return fmt.Errorf("event %d: duplicate TEXT_MESSAGE_START for %s", i+1, messageID) - } - textOpen[messageID] = true - case EventTextMessageContent: - messageID := stringField(evt, "messageId") - if !textOpen[messageID] { - return fmt.Errorf("event %d: TEXT_MESSAGE_CONTENT before TEXT_MESSAGE_START for %s", i+1, messageID) - } - case EventTextMessageEnd: - messageID := stringField(evt, "messageId") - if !textOpen[messageID] { - return fmt.Errorf("event %d: TEXT_MESSAGE_END before TEXT_MESSAGE_START for %s", i+1, messageID) - } - delete(textOpen, messageID) - case EventReasoningMsgStart: - messageID := stringField(evt, "messageId") - if reasoningOpen[messageID] { - return fmt.Errorf("event %d: duplicate REASONING_MESSAGE_START for %s", i+1, messageID) - } - reasoningOpen[messageID] = true - case EventReasoningMsgCont: - messageID := stringField(evt, "messageId") - if !reasoningOpen[messageID] { - return fmt.Errorf("event %d: REASONING_MESSAGE_CONTENT before REASONING_MESSAGE_START for %s", i+1, messageID) - } - case EventReasoningMsgEnd: - messageID := stringField(evt, "messageId") - if !reasoningOpen[messageID] { - return fmt.Errorf("event %d: REASONING_MESSAGE_END before REASONING_MESSAGE_START for %s", i+1, messageID) - } - delete(reasoningOpen, messageID) - case EventToolCallStart: - toolCallID := stringField(evt, "toolCallId") - if toolStarted[toolCallID] { - return fmt.Errorf("event %d: duplicate TOOL_CALL_START for %s", i+1, toolCallID) - } - toolStarted[toolCallID] = true - case EventToolCallArgs: - toolCallID := stringField(evt, "toolCallId") - if !toolStarted[toolCallID] { - return fmt.Errorf("event %d: TOOL_CALL_ARGS before TOOL_CALL_START for %s", i+1, toolCallID) - } - case EventToolCallEnd: - toolCallID := stringField(evt, "toolCallId") - if !toolStarted[toolCallID] { - return fmt.Errorf("event %d: TOOL_CALL_END before TOOL_CALL_START for %s", i+1, toolCallID) - } - if toolEnded[toolCallID] { - return fmt.Errorf("event %d: duplicate TOOL_CALL_END for %s", i+1, toolCallID) - } - toolEnded[toolCallID] = true - case EventToolCallResult: - toolCallID := stringField(evt, "toolCallId") - if !toolStarted[toolCallID] { - return fmt.Errorf("event %d: TOOL_CALL_RESULT before TOOL_CALL_START for %s", i+1, toolCallID) - } - } - } - return nil -} - -var validToolStates = map[string]bool{ - ToolStateAwaitingInput: true, - ToolStateInputStreaming: true, - ToolStateInputComplete: true, - ToolStateApprovalRequested: true, - ToolStateApprovalResponded: true, -} - -func stringField(evt Event, key string) string { - value, _ := evt[key].(string) - return value -} - -var validToolResultStates = map[string]bool{ - ToolResultStateStreaming: true, - ToolResultStateComplete: true, - ToolResultStateError: true, -} - -func validateStringSet(evt Event, key string, required bool, allowed map[string]bool) error { - value, ok := evt[key] - if !ok || value == nil { - if required { - return fmt.Errorf("%s missing %s", evt["type"], key) - } - return nil - } - stringValue, ok := value.(string) - if !ok || !allowed[stringValue] { - return fmt.Errorf("%s has invalid %s %q", evt["type"], key, value) - } - return nil -} - -func NormalizeFinishReason(value string) string { - switch strings.TrimSpace(strings.ToLower(value)) { - case "", FinishReasonStop: - return FinishReasonStop - case FinishReasonLength: - return FinishReasonLength - case "content-filter", "contentfilter", FinishReasonContentFilter: - return FinishReasonContentFilter - case "tool-calls", "toolcalls", FinishReasonToolCalls: - return FinishReasonToolCalls - case FinishReasonOther: - return FinishReasonOther - default: - return FinishReasonStop - } -} - -func CloneEvent(evt Event) Event { - raw, err := json.Marshal(evt) - if err != nil { - cp := make(Event, len(evt)) - for k, v := range evt { - cp[k] = v - } - return cp - } - var cp Event - if err := json.Unmarshal(raw, &cp); err != nil { - cp = make(Event, len(evt)) - for k, v := range evt { - cp[k] = v - } - } - return cp -} - -func require(evt Event, keys ...string) error { - for _, key := range keys { - value, ok := evt[key] - if !ok || emptyValue(value) { - return fmt.Errorf("%s missing %s", evt["type"], key) - } - } - return nil -} - -// requireStringField checks that the field is present and is a string. -// Unlike require, it accepts whitespace-only strings — streaming deltas can -// legitimately consist only of spaces or newlines between tokens. -func requireStringField(evt Event, key string) error { - value, ok := evt[key] - if !ok { - return fmt.Errorf("%s missing %s", evt["type"], key) - } - str, ok := value.(string) - if !ok { - return fmt.Errorf("%s has invalid %s %T", evt["type"], key, value) - } - if str == "" { - return fmt.Errorf("%s missing %s", evt["type"], key) - } - return nil -} - -func emptyValue(value any) bool { - switch v := value.(type) { - case string: - return strings.TrimSpace(v) == "" - case nil: - return true - default: - return false - } -} diff --git a/pkg/ag-ui/events_test.go b/pkg/ag-ui/events_test.go deleted file mode 100644 index f05d019..0000000 --- a/pkg/ag-ui/events_test.go +++ /dev/null @@ -1,144 +0,0 @@ -package agui - -import ( - "testing" - "time" -) - -func TestBuildersCoverLifecycleEventsWithTimestamps(t *testing.T) { - now := func() time.Time { return time.Unix(10, 0) } - builder := NewEventBuilder("dummy/model", now) - idx := 1 - events := []Event{ - builder.RunStarted("thread", "run"), - builder.RunFinished("thread", "run", "tool-calls", Usage{PromptTokens: 1, CompletionTokens: 2, TotalTokens: 3}), - builder.RunError("thread", "run", "failed"), - builder.TextMessageStart("msg", RoleAssistant), - builder.TextMessageContent("msg", "hello"), - builder.TextMessageEnd("msg"), - builder.ReasoningStart("msg"), - builder.ReasoningMessageStart("msg"), - builder.ReasoningMessageContent("msg", "thinking"), - builder.ReasoningMessageEnd("msg"), - builder.ReasoningEnd("msg"), - builder.ToolCallStart("msg", "tool", "search", &idx, &ToolApproval{ID: "approval", NeedsApproval: true}), - builder.ToolCallArgs("tool", `{"q":"he`, nil), - builder.ToolCallEnd("tool", "search", map[string]any{"q": "hello"}, `{"ok":true}`, ToolStateInputComplete), - builder.ToolCallResult("msg", "tool", `{"ok":true}`, ToolResultStateComplete, RoleTool), - builder.StepStarted("msg", "step"), - builder.StepFinished("msg", "step"), - builder.StateSnapshot(map[string]any{"open": true}), - builder.StateDelta(map[string]any{"path": "/open", "value": false}), - builder.MessagesSnapshot([]UIMessage{{ID: "msg", Role: RoleAssistant, Parts: []MessagePart{TextPart("hello")}}}), - builder.Custom("com.beeper.test", map[string]any{"ok": true}), - } - for _, evt := range events { - if err := ValidateEvent(evt); err != nil { - t.Fatalf("ValidateEvent(%s) returned error: %v", evt["type"], err) - } - if evt["timestamp"] == nil { - t.Fatalf("event missing timestamp: %#v", evt) - } - } - if got := events[1]["finishReason"]; got != FinishReasonToolCalls { - t.Fatalf("finish reason = %q, want %q", got, FinishReasonToolCalls) - } - if got := events[2]["message"]; got != "failed" { - t.Fatalf("run error message = %#v, want failed", got) - } - toolStart := events[11] - if got := toolStart["index"]; got != 1 { - t.Fatalf("tool index = %#v, want 1", got) - } - if got := toolStart["parentMessageId"]; got != "msg" { - t.Fatalf("tool parentMessageId = %#v, want msg", got) - } - if _, hasMessageID := toolStart["messageId"]; hasMessageID { - t.Fatalf("tool start should not emit deprecated messageId: %#v", toolStart) - } - if _, hasSnapshot := events[17]["snapshot"]; !hasSnapshot { - t.Fatalf("state snapshot should emit snapshot field: %#v", events[17]) - } -} - -func TestValidateRejectsBadEvents(t *testing.T) { - tests := []Event{ - {}, - {"type": EventRunStarted, "timestamp": int64(1), "threadId": "thread"}, - {"type": EventRunError, "timestamp": int64(1), "threadId": "thread", "error": map[string]any{"message": "failed"}}, - {"type": EventTextMessageContent, "timestamp": int64(1), "messageId": "msg"}, - {"type": "REASONING_MESSAGE_CONTENT", "timestamp": int64(1)}, - {"type": EventToolCallStart, "timestamp": int64(1), "toolCallId": "tool", "toolCallName": "search", "state": "output-available"}, - {"type": EventToolCallStart, "timestamp": int64(1), "toolCallId": "tool", "toolCallName": "search", "state": ToolStateApprovalRequested, "approval": ToolApproval{ID: "", NeedsApproval: true}}, - {"type": EventToolCallStart, "timestamp": int64(1), "toolCallId": "tool", "toolCallName": "search", "state": ToolStateApprovalRequested, "approval": map[string]any{"id": "approval", "needsApproval": false}}, - {"type": EventToolCallArgs, "timestamp": int64(1), "toolCallId": "tool", "delta": "{}", "args": map[string]any{"bad": true}}, - {"type": EventToolCallEnd, "timestamp": int64(1), "toolCallId": "tool", "result": map[string]any{"bad": true}, "state": ToolStateInputComplete}, - {"type": EventToolCallResult, "timestamp": int64(1), "messageId": "msg", "toolCallId": "tool", "content": "{}", "state": "output-error"}, - {"type": EventStepStarted, "timestamp": int64(1), "stepId": "deprecated-only"}, - {"type": EventStateSnapshot, "timestamp": int64(1), "state": map[string]any{}}, - } - for _, evt := range tests { - if err := ValidateEvent(evt); err == nil { - t.Fatalf("expected validation error for %#v", evt) - } - } -} - -func TestValidateEventSequenceRejectsBadOrdering(t *testing.T) { - now := func() time.Time { return time.Unix(10, 0) } - builder := NewEventBuilder("dummy/model", now) - - valid := []Event{ - builder.RunStarted("thread", "run"), - builder.TextMessageStart("msg", RoleAssistant), - builder.TextMessageContent("msg", "hello"), - builder.TextMessageEnd("msg"), - builder.ToolCallStart("msg", "tool", "search", nil, nil), - builder.ToolCallArgs("tool", `{"q":"hello"}`, `{"q":"hello"}`), - builder.ToolCallEnd("tool", "search", map[string]any{"q": "hello"}, `{"ok":true}`, ToolStateInputComplete), - builder.RunFinished("thread", "run", FinishReasonStop, Usage{}), - } - if err := ValidateEventSequence(valid); err != nil { - t.Fatalf("ValidateEventSequence(valid) returned error: %v", err) - } - - tests := [][]Event{ - {builder.TextMessageContent("msg", "hello")}, - {builder.ReasoningMessageContent("msg", "thinking")}, - {builder.ToolCallArgs("tool", "{}", "{}")}, - {builder.ToolCallResult("msg", "tool", "{}", ToolResultStateComplete, RoleTool)}, - { - builder.RunStarted("thread", "run"), - builder.RunFinished("thread", "run", FinishReasonStop, Usage{}), - builder.TextMessageStart("msg", RoleAssistant), - }, - } - for _, events := range tests { - if err := ValidateEventSequence(events); err == nil { - t.Fatalf("expected ordering error for %#v", events) - } - } -} - -func TestRunAgentInputModelsBidirectionalShape(t *testing.T) { - input := RunAgentInput{ - ThreadID: "thread", - RunID: "run", - State: map[string]any{"open": true}, - Messages: []UIMessage{{ - ID: "msg", - Role: RoleUser, - Parts: []MessagePart{TextPart("hello")}, - }}, - Tools: []Tool{{Name: "send_email", NeedsApproval: true}}, - Context: []ContextItem{{ - Type: "beeper-room", - Value: "room", - }}, - ForwardedProps: map[string]any{"trace": "abc"}, - Data: map[string]any{"legacy": true}, - } - if input.ThreadID != "thread" || !input.Tools[0].NeedsApproval || input.ForwardedProps["trace"] != "abc" { - t.Fatalf("bad RunAgentInput shape: %#v", input) - } -} diff --git a/pkg/ai-stream/approval.go b/pkg/ai-stream/approval.go deleted file mode 100644 index 04783bf..0000000 --- a/pkg/ai-stream/approval.go +++ /dev/null @@ -1,264 +0,0 @@ -package aistream - -import ( - "strings" - - "github.com/beeper/dummybridge/pkg/ag-ui" -) - -const ( - ApprovalChoiceApprove = "approve" - ApprovalChoiceAlwaysApprove = "always_approve" - ApprovalChoiceDeny = "deny" -) - -type ApprovalChoice struct { - Key string `json:"key"` - Label string `json:"label"` - Alias string `json:"alias"` - Style string `json:"style,omitempty"` - Shortcut string `json:"shortcut,omitempty"` -} - -type ApprovalCleanup struct { - Selected ApprovalChoice - SelectedReactionEvent string - RedactReactionEvents []string - Matched bool -} - -type ReactionEvent struct { - EventID string - Sender string - Key string - Bridge bool -} - -type ApprovalContext struct { - ID string `json:"id"` - ThreadID string `json:"threadId"` - RunID string `json:"runId"` - MessageID string `json:"messageId"` - Command string `json:"command"` - ToolCallID string `json:"toolCallId"` - ToolName string `json:"toolName"` - TargetEvent string `json:"target_event"` - AgentID string `json:"agentId,omitempty"` - AgentName string `json:"agentName,omitempty"` - Model string `json:"model,omitempty"` - SeqStart int `json:"seqStart,omitempty"` - PreviewText string `json:"previewText,omitempty"` - PreviewTruncated bool `json:"previewTruncated,omitempty"` -} - -type ApprovalRequestedValue struct { - ThreadID string - RunID string - MessageID string - ToolCallID string - ToolName string - Input any - Approval agui.ToolApproval - ApprovalMessageID string - ApprovalEventID string - Choices []ApprovalChoice - Metadata map[string]any -} - -type ApprovalNotice struct { - Schema string - ID string - MessageID string - ToolCallID string - ToolName string - State string - Choices []ApprovalChoice -} - -func NewApprovalRequestedValue(run Run, toolCallID, toolName string, input any, approval agui.ToolApproval) ApprovalRequestedValue { - return ApprovalRequestedValue{ - ThreadID: run.ThreadID, - RunID: run.RunID, - MessageID: run.MessageID, - ToolCallID: toolCallID, - ToolName: toolName, - Input: input, - Approval: approval, - ApprovalMessageID: approval.ID, - Choices: DefaultApprovalChoices(), - } -} - -func NewApprovalNotice(ctx ApprovalContext, choices []ApprovalChoice) ApprovalNotice { - return ApprovalNotice{ - Schema: "com.beeper.ai.approval.v1", - ID: ctx.ID, - MessageID: ctx.MessageID, - ToolCallID: ctx.ToolCallID, - ToolName: ctx.ToolName, - State: "requested", - Choices: choices, - } -} - -func (v ApprovalRequestedValue) Map() map[string]any { - value := map[string]any{ - "threadId": v.ThreadID, - "runId": v.RunID, - "messageId": v.MessageID, - "toolCallId": v.ToolCallID, - "toolName": v.ToolName, - "input": v.Input, - "approval": v.Approval, - "approvalMessageId": v.ApprovalMessageID, - "choices": v.Choices, - } - if v.ApprovalEventID != "" { - value["approvalEventId"] = v.ApprovalEventID - } - if len(v.Metadata) > 0 { - value["metadata"] = v.Metadata - } - return value -} - -func (n ApprovalNotice) Map() map[string]any { - return map[string]any{ - "schema": n.Schema, - "id": n.ID, - "messageId": n.MessageID, - "toolCallId": n.ToolCallID, - "toolName": n.ToolName, - "state": n.State, - "choices": ApprovalChoicesAsAny(n.Choices), - } -} - -func ApprovalChoicesAsAny(choices []ApprovalChoice) []any { - out := make([]any, 0, len(choices)) - for _, choice := range choices { - item := map[string]any{ - "key": choice.Key, - "label": choice.Label, - "alias": choice.Alias, - } - if choice.Style != "" { - item["style"] = choice.Style - } - if choice.Shortcut != "" { - item["shortcut"] = choice.Shortcut - } - out = append(out, item) - } - return out -} - -func ApprovalIDFromRequestedValue(value map[string]any) string { - approval, _ := value["approval"].(agui.ToolApproval) - if approval.ID != "" { - return approval.ID - } - if raw, ok := value["approval"].(map[string]any); ok { - approvalID, _ := raw["id"].(string) - return approvalID - } - return "" -} - -func SetApprovalRequestedEventID(value map[string]any, eventID string) bool { - if value == nil || eventID == "" { - return false - } - approvalID := ApprovalIDFromRequestedValue(value) - if approvalID == "" { - return false - } - value["approvalMessageId"] = approvalID - value["approvalEventId"] = eventID - return true -} - -func DefaultApprovalChoices() []ApprovalChoice { - return []ApprovalChoice{ - { - Key: ApprovalChoiceApprove, - Label: "Allow once", - Alias: "✅", - }, - { - Key: ApprovalChoiceAlwaysApprove, - Label: "Allow always", - Alias: "☑️", - }, - { - Key: ApprovalChoiceDeny, - Label: "Deny", - Alias: "❌", - Style: "danger", - }, - } -} - -func ResolveApprovalChoice(choices []ApprovalChoice, raw string) (ApprovalChoice, bool) { - key := NormalizeReaction(raw) - for _, choice := range choices { - if NormalizeReaction(choice.Key) == key || NormalizeReaction(choice.Alias) == key { - return choice, true - } - } - var zero ApprovalChoice - return zero, false -} - -func ApprovalResponseForChoice(approvalID string, choice ApprovalChoice) agui.ToolApprovalResponse { - switch choice.Key { - case ApprovalChoiceApprove: - return agui.ToolApprovalResponse{ID: approvalID, Approved: true} - case ApprovalChoiceAlwaysApprove: - return agui.ToolApprovalResponse{ID: approvalID, Approved: true, Always: true} - case ApprovalChoiceDeny: - return agui.ToolApprovalResponse{ID: approvalID, Approved: false, Reason: "denied"} - default: - return agui.ToolApprovalResponse{ID: approvalID, Approved: false, Reason: "invalid approval choice"} - } -} - -func CleanupApprovalReactions(choices []ApprovalChoice, selectedKey string, events []ReactionEvent, bridgeSender string) ApprovalCleanup { - selected, ok := ResolveApprovalChoice(choices, selectedKey) - if !ok { - return ApprovalCleanup{} - } - cleanup := ApprovalCleanup{Selected: selected, Matched: true} - for _, evt := range events { - if evt.EventID == "" { - continue - } - choice, matchesChoice := ResolveApprovalChoice(choices, evt.Key) - isSelected := matchesChoice && choice.Key == selected.Key - isBridge := evt.Bridge || (bridgeSender != "" && evt.Sender == bridgeSender) - if isSelected && !isBridge && cleanup.SelectedReactionEvent == "" { - cleanup.SelectedReactionEvent = evt.EventID - continue - } - if isBridge || (matchesChoice && !isSelected) { - cleanup.RedactReactionEvents = append(cleanup.RedactReactionEvents, evt.EventID) - } - } - return cleanup -} - -func NormalizeReaction(reaction string) string { - reaction = strings.TrimSpace(reaction) - reaction = strings.ReplaceAll(reaction, "\ufe0f", "") - return strings.ToLower(reaction) -} - -func approvalSummaryState(response agui.ToolApprovalResponse) string { - if response.Approved { - if response.Always { - return "approved-always" - } - return "approved" - } - return "denied" -} diff --git a/pkg/ai-stream/bridgev2/events.go b/pkg/ai-stream/bridgev2/events.go deleted file mode 100644 index 05cbaa5..0000000 --- a/pkg/ai-stream/bridgev2/events.go +++ /dev/null @@ -1,106 +0,0 @@ -package aibridgev2 - -import ( - "context" - "time" - - aistream "github.com/beeper/dummybridge/pkg/ai-stream" - aimatrix "github.com/beeper/dummybridge/pkg/ai-stream/matrix" - "maunium.net/go/mautrix/bridgev2" - "maunium.net/go/mautrix/bridgev2/database" - "maunium.net/go/mautrix/bridgev2/networkid" - "maunium.net/go/mautrix/bridgev2/simplevent" - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" -) - -func eventMeta(eventType bridgev2.RemoteEventType, portalKey networkid.PortalKey, sender networkid.UserID, timestamp time.Time) simplevent.EventMeta { - return simplevent.EventMeta{ - Type: eventType, - PortalKey: portalKey, - Sender: bridgev2.EventSender{Sender: sender}, - Timestamp: timestamp, - StreamOrder: timestamp.UnixNano(), - } -} - -func messagePart(content *event.MessageEventContent, extra map[string]any, dbMetadata map[string]any) *bridgev2.ConvertedMessagePart { - return &bridgev2.ConvertedMessagePart{ - ID: networkid.PartID("0"), - Type: event.EventMessage, - Content: content, - Extra: extra, - DBMetadata: dbMetadata, - } -} - -func Anchor(portalKey networkid.PortalKey, sender networkid.UserID, run aistream.Run, timestamp time.Time) *simplevent.PreConvertedMessage { - content, extra := aimatrix.AnchorContent(run) - return &simplevent.PreConvertedMessage{ - EventMeta: eventMeta(bridgev2.RemoteEventMessage, portalKey, sender, timestamp), - Data: &bridgev2.ConvertedMessage{Parts: []*bridgev2.ConvertedMessagePart{messagePart(content, extra, nil)}}, - ID: networkid.MessageID(run.MessageID), - } -} - -func Carrier(portalKey networkid.PortalKey, sender networkid.UserID, run aistream.Run, carrier aistream.Carrier, targetEventID id.EventID, index int, timestamp time.Time) *simplevent.PreConvertedMessage { - content, extra := aimatrix.CarrierContent(carrier, targetEventID) - return &simplevent.PreConvertedMessage{ - EventMeta: eventMeta(bridgev2.RemoteEventMessage, portalKey, sender, timestamp), - Data: &bridgev2.ConvertedMessage{Parts: []*bridgev2.ConvertedMessagePart{messagePart(content, extra, nil)}}, - ID: networkid.MessageID(aistream.StreamTxnID(run.RunID, index)), - } -} - -func ApprovalPrompt(portalKey networkid.PortalKey, sender networkid.UserID, ctx aistream.ApprovalContext, timestamp time.Time) *simplevent.PreConvertedMessage { - content, extra := aimatrix.ApprovalContent(ctx, aistream.DefaultApprovalChoices()) - return &simplevent.PreConvertedMessage{ - EventMeta: eventMeta(bridgev2.RemoteEventMessage, portalKey, sender, timestamp), - Data: &bridgev2.ConvertedMessage{Parts: []*bridgev2.ConvertedMessagePart{ - messagePart(content, extra, map[string]any{"com.beeper.ai.approval": ctx}), - }}, - ID: networkid.MessageID(ctx.ID), - } -} - -func ApprovalOptionReaction(portalKey networkid.PortalKey, sender networkid.UserID, ctx aistream.ApprovalContext, choice aistream.ApprovalChoice, timestamp time.Time) *simplevent.Reaction { - return &simplevent.Reaction{ - EventMeta: eventMeta(bridgev2.RemoteEventReaction, portalKey, sender, timestamp), - TargetMessage: networkid.MessageID(ctx.ID), - EmojiID: networkid.EmojiID(choice.Key), - Emoji: choice.Alias, - ExtraContent: map[string]any{ - "com.beeper.ai.approval_option": map[string]any{ - "approvalId": ctx.ID, - "toolCallId": ctx.ToolCallID, - "choice": choice.Key, - }, - }, - } -} - -func FinalMetadataEdit(portalKey networkid.PortalKey, sender networkid.UserID, messageID networkid.MessageID, run aistream.Run, timestamp time.Time) *simplevent.Message[*aistream.Run] { - finalContent, finalExtra := aimatrix.FinalContent(run) - return &simplevent.Message[*aistream.Run]{ - EventMeta: eventMeta(bridgev2.RemoteEventEdit, portalKey, sender, timestamp), - Data: &run, - ID: messageID, - TargetMessage: messageID, - ConvertEditFunc: func(ctx context.Context, portal *bridgev2.Portal, intent bridgev2.MatrixAPI, existing []*database.Message, data *aistream.Run) (*bridgev2.ConvertedEdit, error) { - if len(existing) == 0 { - return nil, nil - } - return &bridgev2.ConvertedEdit{ - ModifiedParts: []*bridgev2.ConvertedEditPart{{ - Part: existing[0], - Type: event.EventMessage, - Content: finalContent, - Extra: finalExtra, - TopLevelExtra: map[string]any{ - "com.beeper.dont_render_edited": true, - }, - }}, - }, nil - }, - } -} diff --git a/pkg/ai-stream/bridgev2/events_test.go b/pkg/ai-stream/bridgev2/events_test.go deleted file mode 100644 index 2218aa3..0000000 --- a/pkg/ai-stream/bridgev2/events_test.go +++ /dev/null @@ -1,100 +0,0 @@ -package aibridgev2 - -import ( - "strings" - "testing" - "time" - - aistream "github.com/beeper/dummybridge/pkg/ai-stream" - "maunium.net/go/mautrix/bridgev2" - "maunium.net/go/mautrix/bridgev2/networkid" - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" -) - -func TestBridgeV2AIEvents(t *testing.T) { - now := time.Unix(10, 0) - run := aistream.NewRun("run-1", "thread-1", "", "ai", "AI", now) - run.Preview = aistream.Preview{Text: "visible preview"} - - anchor := Anchor( - networkid.PortalKey{ID: "portal-1"}, - networkid.UserID("ai"), - *run, - now, - ) - if anchor.Type != bridgev2.RemoteEventMessage { - t.Fatalf("anchor type = %v", anchor.Type) - } - part := anchor.Data.Parts[0] - if part.Type != event.EventMessage || part.Content.Body != "visible preview" { - t.Fatalf("unexpected anchor part: %#v", part) - } - if part.Extra[aistream.BeeperAIKey] == nil || part.Extra[aistream.BeeperAIMetadataKey] == nil { - t.Fatalf("anchor missing AI metadata: %#v", part.Extra) - } - stream, ok := part.Extra["com.beeper.stream"].(map[string]any) - if !ok || stream["type"] != aistream.BeeperAIStreamDeltas { - t.Fatalf("anchor missing stream descriptor: %#v", part.Extra) - } - - carrier := Carrier( - networkid.PortalKey{ID: "portal-1"}, - networkid.UserID("ai"), - *run, - aistream.Carrier{Envelopes: []aistream.Envelope{{ - Seq: 1, - RunID: run.RunID, - ThreadID: run.ThreadID, - TargetEvent: "$anchor", - }}}, - id.EventID("$anchor"), - 1, - now, - ) - carrierPart := carrier.Data.Parts[0] - if carrierPart.Content.MsgType != event.MsgText || carrierPart.Content.Body != "" { - t.Fatalf("carrier should be hidden text carrier: %#v", carrierPart.Content) - } - if carrierPart.Extra[aistream.BeeperAIStreamDeltas] == nil { - t.Fatalf("carrier missing deltas: %#v", carrierPart.Extra) - } - - approval := ApprovalPrompt(networkid.PortalKey{ID: "portal-1"}, networkid.UserID("ai"), aistream.ApprovalContext{ - ID: "approval-1", - ThreadID: run.ThreadID, - RunID: run.RunID, - MessageID: run.MessageID, - ToolCallID: "tool-1", - ToolName: "dummy_echo", - TargetEvent: "$anchor", - }, now) - approvalPart := approval.Data.Parts[0] - approvalMetadata, ok := approvalPart.DBMetadata.(map[string]any) - if !ok || approvalMetadata["com.beeper.ai.approval"] == nil { - t.Fatalf("approval missing DB metadata: %#v", approvalPart.DBMetadata) - } -} - -func TestFinalMetadataEditUsesCompactAnchorContent(t *testing.T) { - now := time.Unix(10, 0) - run := aistream.NewRun("run-1", "thread-1", "", "ai", "AI", now) - run.Preview = aistream.Preview{Text: strings.Repeat("a", aistream.PreviewBudgetBytes+1), Truncated: true} - - edit := FinalMetadataEdit( - networkid.PortalKey{ID: "portal-1"}, - networkid.UserID("ai"), - networkid.MessageID(run.MessageID), - *run, - now, - ) - if edit.Type != bridgev2.RemoteEventEdit { - t.Fatalf("final metadata event type = %v", edit.Type) - } - if edit.TargetMessage != networkid.MessageID(run.MessageID) { - t.Fatalf("final metadata target = %q", edit.TargetMessage) - } - if edit.Data.Text() != "" { - t.Fatalf("final metadata edit must not expose full accumulated text") - } -} diff --git a/pkg/ai-stream/matrix/content.go b/pkg/ai-stream/matrix/content.go deleted file mode 100644 index 4192a32..0000000 --- a/pkg/ai-stream/matrix/content.go +++ /dev/null @@ -1,77 +0,0 @@ -package matrix - -import ( - "fmt" - - "github.com/beeper/dummybridge/pkg/ai-stream" - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/format" - "maunium.net/go/mautrix/id" -) - -const ApprovalRelationType = event.RelationType("com.beeper.ai.approval") - -func AnchorContent(run aistream.Run) (*event.MessageEventContent, map[string]any) { - content := previewContent(run) - extra := map[string]any{ - aistream.BeeperAIKey: run.InitialUIMessage(), - aistream.BeeperAIMetadataKey: run.Metadata(), - "com.beeper.stream": map[string]any{ - "type": aistream.BeeperAIStreamDeltas, - }, - } - return content, extra -} - -func FinalContent(run aistream.Run) (*event.MessageEventContent, map[string]any) { - content := previewContent(run) - extra := map[string]any{ - aistream.BeeperAIKey: run.FinalUIMessage(0, true), - aistream.BeeperAIMetadataKey: run.Metadata(), - "com.beeper.stream": map[string]any{ - "type": aistream.BeeperAIStreamDeltas, - }, - } - return content, extra -} - -func previewContent(run aistream.Run) *event.MessageEventContent { - body := run.Preview.Text - if body == "" { - body = "..." - } - rendered := format.RenderMarkdown(body, true, false) - content := &rendered - content.EnsureHasHTML() - content.BeeperPerMessageProfile = &event.BeeperPerMessageProfile{ - ID: run.AgentID, - Displayname: run.AgentName, - } - return content -} - -func CarrierContent(carrier aistream.Carrier, targetEventID id.EventID) (*event.MessageEventContent, map[string]any) { - content := format.TextToContent("") - content.SetRelatesTo(&event.RelatesTo{Type: event.RelReference, EventID: targetEventID}) - return &content, aistream.CarrierContent(carrier.Envelopes) -} - -func ApprovalContent(ctx aistream.ApprovalContext, choices []aistream.ApprovalChoice) (*event.MessageEventContent, map[string]any) { - toolName := ctx.ToolName - body := fmt.Sprintf("Approval required for %s", toolName) - if len(choices) > 0 { - body += "\nReact with one of the listed choices." - } - content := format.TextToContent(body) - if ctx.TargetEvent != "" { - content.SetRelatesTo(&event.RelatesTo{Type: ApprovalRelationType, EventID: id.EventID(ctx.TargetEvent)}) - } - extra := map[string]any{ - "com.beeper.ai.approval": aistream.NewApprovalNotice(ctx, choices).Map(), - } - return &content, extra -} - -func ApprovalChoicesAsAny(choices []aistream.ApprovalChoice) []any { - return aistream.ApprovalChoicesAsAny(choices) -} diff --git a/pkg/ai-stream/matrix/content_test.go b/pkg/ai-stream/matrix/content_test.go deleted file mode 100644 index 58ab46f..0000000 --- a/pkg/ai-stream/matrix/content_test.go +++ /dev/null @@ -1,202 +0,0 @@ -package matrix - -import ( - "strings" - "testing" - "time" - - "github.com/beeper/dummybridge/pkg/ag-ui" - "github.com/beeper/dummybridge/pkg/ai-stream" - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" -) - -func TestAnchorContentUsesVisibleTextAndAIProfile(t *testing.T) { - run := aistream.NewRun("run-1", "thread-1", aistream.DefaultModel, "ai", "AI", time.Unix(10, 0)) - run.Preview = aistream.Preview{Text: "visible preview"} - - content, extra := AnchorContent(*run) - if content.MsgType != event.MsgText || content.Body != "visible preview" { - t.Fatalf("bad anchor content: %#v", content) - } - if content.Format != event.FormatHTML || content.FormattedBody == "" { - t.Fatalf("anchor preview should include Matrix HTML: %#v", content) - } - if content.BeeperPerMessageProfile == nil || content.BeeperPerMessageProfile.ID != "ai" || content.BeeperPerMessageProfile.Displayname != "AI" { - t.Fatalf("missing AI per-message profile: %#v", content.BeeperPerMessageProfile) - } - uiMessage, ok := extra[aistream.BeeperAIKey].(agui.UIMessage) - if !ok || uiMessage.ID == "" || uiMessage.Metadata == nil || len(uiMessage.Parts) != 1 { - t.Fatalf("bad compact AI message: %#v", extra[aistream.BeeperAIKey]) - } - if uiMessage.Parts[0]["type"] != "text" || uiMessage.Parts[0]["content"] != "visible preview" { - t.Fatalf("anchor AI message should include preview text part: %#v", uiMessage.Parts) - } - if extra[aistream.BeeperAIMetadataKey] == nil { - t.Fatalf("missing AI metadata: %#v", extra) - } - stream, ok := extra["com.beeper.stream"].(map[string]any) - if !ok || stream["user_id"] != nil || stream["type"] != aistream.BeeperAIStreamDeltas { - t.Fatalf("missing stream descriptor: %#v", extra["com.beeper.stream"]) - } -} - -func TestAnchorContentKeepsLongRunsCompact(t *testing.T) { - run := aistream.NewRun("run-1", "thread-1", aistream.DefaultModel, "ai", "AI", time.Unix(10, 0)) - writer := aistream.NewWriter(run, func() time.Time { return time.Unix(10, 0) }) - writer.Start() - writer.Text(strings.Repeat("a", 70*1024)) - writer.Finish(agui.FinishReasonStop) - - content, extra := AnchorContent(*run) - if len(content.Body) > aistream.PreviewBudgetBytes { - t.Fatalf("anchor body length = %d, want <= %d", len(content.Body), aistream.PreviewBudgetBytes) - } - metadata := extra[aistream.BeeperAIMetadataKey].(map[string]any) - if _, hasParts := metadata["parts"]; hasParts { - t.Fatalf("metadata must not contain streamed parts: %#v", metadata) - } - if _, hasChunks := metadata["chunks"]; hasChunks { - t.Fatalf("metadata must not contain streamed chunks: %#v", metadata) - } - preview := metadata["preview"].(aistream.Preview) - if !preview.Truncated || len(preview.Text) > aistream.PreviewBudgetBytes { - t.Fatalf("bad bounded preview: %#v", preview) - } -} - -func TestStreamingAnchorDoesNotIncludePreviewPart(t *testing.T) { - run := aistream.NewRun("run-1", "thread-1", aistream.DefaultModel, "ai", "AI", time.Unix(10, 0)) - run.Preview = aistream.Preview{} - - content, extra := AnchorContent(*run) - if content.Body != "..." { - t.Fatalf("empty streaming anchor should use placeholder body, got %q", content.Body) - } - uiMessage, ok := extra[aistream.BeeperAIKey].(agui.UIMessage) - if !ok || len(uiMessage.Parts) != 0 { - t.Fatalf("streaming anchor should not include an initial text snapshot: %#v", extra[aistream.BeeperAIKey]) - } -} - -func TestAnchorContentRendersFinalPreviewAsMatrixHTML(t *testing.T) { - run := aistream.NewRun("run-1", "thread-1", aistream.DefaultModel, "ai", "AI", time.Unix(10, 0)) - run.Preview = aistream.Preview{Text: "Use **bold** and `code`"} - - content, _ := AnchorContent(*run) - if content.Format != event.FormatHTML { - t.Fatalf("format = %q, want Matrix HTML", content.Format) - } - if !strings.Contains(content.FormattedBody, "bold") || !strings.Contains(content.FormattedBody, "code") { - t.Fatalf("formatted body did not render markdown: %q", content.FormattedBody) - } -} - -func TestFinalContentIncludesFinalUIParts(t *testing.T) { - run := aistream.NewRun("run-1", "thread-1", aistream.DefaultModel, "ai", "AI", time.Unix(10, 0)) - writer := aistream.NewWriter(run, func() time.Time { return time.Unix(10, 0) }) - writer.Start() - writer.Thinking("hidden reasoning") - writer.Text("final **preview**") - writer.Finish(agui.FinishReasonStop) - - content, extra := FinalContent(*run) - if content.Body != "final **preview**" || content.Format != event.FormatHTML { - t.Fatalf("bad final preview content: %#v", content) - } - uiMessage, ok := extra[aistream.BeeperAIKey].(agui.UIMessage) - if !ok || len(uiMessage.Parts) != 2 || uiMessage.Parts[0]["type"] != "thinking" || uiMessage.Parts[1]["type"] != "text" { - t.Fatalf("final edit must include concrete UI parts: %#v", extra[aistream.BeeperAIKey]) - } - if uiMessage.Parts[0]["content"] != "hidden reasoning" || uiMessage.Parts[1]["content"] == "" { - t.Fatalf("final edit must preserve reasoning and text parts: %#v", uiMessage.Parts) - } - if extra[aistream.BeeperAIMetadataKey] == nil { - t.Fatalf("missing final metadata: %#v", extra) - } - stream, ok := extra["com.beeper.stream"].(map[string]any) - if !ok || stream["type"] != aistream.BeeperAIStreamDeltas { - t.Fatalf("missing final stream descriptor: %#v", extra["com.beeper.stream"]) - } -} - -func TestFinalContentDoesNotTruncateUIParts(t *testing.T) { - run := aistream.NewRun("run-1", "thread-1", aistream.DefaultModel, "ai", "AI", time.Unix(10, 0)) - writer := aistream.NewWriter(run, func() time.Time { return time.Unix(10, 0) }) - writer.Start() - full := strings.Repeat("| Artifact | State | Latency |\n| --- | --- | --- |\n| renderer | active | accepts markdown |\n\n", 100) - writer.Text(full) - writer.Finish(agui.FinishReasonStop) - expected := run.Text() - - _, extra := FinalContent(*run) - uiMessage, ok := extra[aistream.BeeperAIKey].(agui.UIMessage) - if !ok || len(uiMessage.Parts) == 0 { - t.Fatalf("missing final UI message: %#v", extra[aistream.BeeperAIKey]) - } - textPart := uiMessage.Parts[len(uiMessage.Parts)-1] - if textPart["content"] != expected { - t.Fatalf("final UI text was truncated: got %d bytes want %d", len(textPart["content"].(string)), len(expected)) - } - if metadata, ok := textPart["providerMetadata"]; ok { - t.Fatalf("final UI text should not be marked truncated: %#v", metadata) - } -} - -func TestCarrierContentIsHiddenTextCarrierWithDeltas(t *testing.T) { - carrier := aistream.Carrier{Envelopes: []aistream.Envelope{{ - ThreadID: "thread-1", - RunID: "run-1", - MessageID: "msg-run-1", - Seq: 1, - TargetEvent: "$anchor", - }}} - - content, extra := CarrierContent(carrier, id.EventID("$anchor")) - if content.MsgType != event.MsgText || content.Body != "" { - t.Fatalf("carrier should be empty m.text, got %#v", content) - } - if content.RelatesTo == nil || content.RelatesTo.EventID != "$anchor" { - t.Fatalf("carrier should reference anchor, got %#v", content.RelatesTo) - } - deltas, ok := extra[aistream.BeeperAIStreamDeltas].([]aistream.Envelope) - if !ok || len(deltas) != 1 || deltas[0].Seq != 1 { - t.Fatalf("missing deltas: %#v", extra) - } -} - -func TestApprovalContentIncludesContextAndChoices(t *testing.T) { - ctx := aistream.ApprovalContext{ - ID: "approval-1", - ThreadID: "thread-1", - RunID: "run-1", - MessageID: "msg-run-1", - ToolCallID: "tool-1", - ToolName: "shell", - TargetEvent: "$anchor", - } - choices := aistream.DefaultApprovalChoices() - - content, extra := ApprovalContent(ctx, choices) - if content.MsgType != event.MsgText || content.RelatesTo == nil || content.RelatesTo.EventID != "$anchor" || content.RelatesTo.Type != ApprovalRelationType { - t.Fatalf("bad approval content: %#v", content) - } - meta, ok := extra["com.beeper.ai.approval"].(map[string]any) - if !ok { - t.Fatalf("missing approval metadata: %#v", extra) - } - if meta["schema"] != "com.beeper.ai.approval.v1" || meta["id"] != ctx.ID || meta["messageId"] != ctx.MessageID || meta["toolCallId"] != ctx.ToolCallID || meta["state"] != "requested" { - t.Fatalf("bad approval metadata: %#v", meta) - } - if _, ok := meta["runId"]; ok { - t.Fatalf("approval event should not duplicate run metadata: %#v", meta) - } - approvalChoices, ok := meta["choices"].([]any) - if !ok || len(approvalChoices) != len(choices) { - t.Fatalf("bad approval choices: %#v", meta["choices"]) - } - first := approvalChoices[0].(map[string]any) - if first["key"] != aistream.ApprovalChoiceApprove || first["alias"] != "✅" { - t.Fatalf("bad first approval choice: %#v", first) - } -} diff --git a/pkg/ai-stream/pack.go b/pkg/ai-stream/pack.go deleted file mode 100644 index 439de44..0000000 --- a/pkg/ai-stream/pack.go +++ /dev/null @@ -1,337 +0,0 @@ -package aistream - -import ( - "encoding/json" - "fmt" - "strings" - "unicode/utf8" - - "github.com/beeper/dummybridge/pkg/ag-ui" -) - -func truncateUTF8(s string, maxBytes int) string { - if maxBytes <= 0 || len(s) <= maxBytes { - return s - } - end := maxBytes - for end > 0 && !utf8.RuneStart(s[end]) { - end-- - } - return s[:end] -} - -type Envelope struct { - ThreadID string `json:"threadId"` - RunID string `json:"runId"` - MessageID string `json:"messageId"` - Seq int `json:"seq"` - Part agui.Event `json:"part"` - TargetEvent string `json:"target_event,omitempty"` - RelatesTo Relation `json:"m.relates_to,omitempty"` - AgentID string `json:"agent_id,omitempty"` -} - -type Relation struct { - Type string `json:"rel_type"` - EventID string `json:"event_id"` -} - -type Carrier struct { - Envelopes []Envelope -} - -func BuildEnvelope(run Run, seq int, part agui.Event, targetEventID string) (Envelope, error) { - if seq <= 0 { - return Envelope{}, fmt.Errorf("stream envelope: seq must be > 0") - } - if err := agui.ValidateEvent(part); err != nil { - return Envelope{}, err - } - targetEventID = strings.TrimSpace(targetEventID) - if targetEventID == "" { - return Envelope{}, fmt.Errorf("stream envelope: missing target event id") - } - return Envelope{ - ThreadID: run.ThreadID, - RunID: run.RunID, - MessageID: run.MessageID, - Seq: seq, - Part: part, - TargetEvent: targetEventID, - RelatesTo: Relation{Type: "m.reference", EventID: targetEventID}, - AgentID: run.AgentID, - }, nil -} - -func PackRun(run Run, targetEventID string, budget int) ([]Carrier, error) { - return PackRunFromSeq(run, targetEventID, budget, 1) -} - -func PackRunFromSeq(run Run, targetEventID string, budget int, startSeq int) ([]Carrier, error) { - if budget <= 0 { - budget = CarrierBudgetBytes - } - if startSeq <= 0 { - startSeq = 1 - } - if err := run.Validate(); err != nil { - return nil, err - } - var carriers []Carrier - var current Carrier - currentSize := 0 - emptyCarrierOverhead := JSONSize(CarrierContent([]Envelope{})) - seq := startSeq - for _, original := range run.Events { - for _, part := range splitEventForBudget(original, budget) { - env, err := BuildEnvelope(run, seq, part, targetEventID) - if err != nil { - return nil, err - } - envSize := JSONSize(env) - if emptyCarrierOverhead+envSize > budget { - return nil, fmt.Errorf("stream envelope %d exceeds %d byte budget", seq, budget) - } - // +1 for the comma separator between envelopes in the JSON array. - addedSize := envSize - if len(current.Envelopes) > 0 { - addedSize++ - } - if len(current.Envelopes) > 0 && currentSize+addedSize > budget { - carriers = append(carriers, current) - current = Carrier{} - currentSize = 0 - addedSize = envSize - } - if len(current.Envelopes) == 0 { - currentSize = emptyCarrierOverhead + envSize - } else { - currentSize += addedSize - } - current.Envelopes = append(current.Envelopes, env) - seq++ - } - } - if len(current.Envelopes) > 0 { - carriers = append(carriers, current) - } - return carriers, nil -} - -func NextSeq(carriers []Carrier) int { - next := 1 - for _, carrier := range carriers { - for _, env := range carrier.Envelopes { - if env.Seq >= next { - next = env.Seq + 1 - } - } - } - return next -} - -func CarrierContent(envelopes []Envelope) map[string]any { - return map[string]any{BeeperAIStreamDeltas: envelopes} -} - -func ReconstructText(carriers []Carrier) string { - var out strings.Builder - for _, carrier := range carriers { - for _, env := range carrier.Envelopes { - if env.Part["type"] == agui.EventTextMessageContent { - delta, _ := env.Part["delta"].(string) - out.WriteString(delta) - } - } - } - return out.String() -} - -func splitEventForBudget(evt agui.Event, budget int) []agui.Event { - if evt["type"] == agui.EventMessagesSnapshot { - return splitMessagesSnapshotForBudget(evt, budget) - } - if JSONSize(evt) <= budget { - return []agui.Event{sanitizeRawEvent(evt, budget)} - } - if evt["type"] != agui.EventTextMessageContent { - return []agui.Event{sanitizeRawEvent(evt, budget)} - } - delta, _ := evt["delta"].(string) - if delta == "" { - return []agui.Event{sanitizeRawEvent(evt, budget)} - } - maxDelta := budget / 2 - if maxDelta < 1024 { - maxDelta = 1024 - } - var out []agui.Event - for _, chunk := range SplitTextUTF8(delta, maxDelta) { - cp := agui.CloneEvent(evt) - cp["delta"] = chunk - out = append(out, sanitizeRawEvent(cp, budget)) - } - return out -} - -func splitMessagesSnapshotForBudget(evt agui.Event, budget int) []agui.Event { - rawMessages, ok := evt["messages"].([]agui.UIMessage) - if !ok || len(rawMessages) == 0 { - return []agui.Event{sanitizeRawEvent(evt, budget)} - } - var out []agui.Event - for _, message := range rawMessages { - out = append(out, splitFinalMessageSnapshot(evt, message, budget)...) - } - if len(out) == 0 { - return []agui.Event{sanitizeRawEvent(evt, budget)} - } - return out -} - -func splitFinalMessageSnapshot(evt agui.Event, message agui.UIMessage, budget int) []agui.Event { - base := agui.CloneEvent(evt) - baseMessage := message - baseMessage.Parts = nil - base["messages"] = []agui.UIMessage{baseMessage} - - var out []agui.Event - baseFlushed := false - flushBase := func() { - if baseFlushed { - return - } - out = append(out, sanitizeRawEvent(base, budget)) - baseFlushed = true - } - appendToBase := func(part agui.MessagePart) bool { - if baseFlushed { - return false - } - nextMessage := baseMessage - nextMessage.Parts = append(append([]agui.MessagePart{}, baseMessage.Parts...), part) - candidate := agui.CloneEvent(base) - candidate["messages"] = []agui.UIMessage{nextMessage} - if JSONSize(candidate) > budget { - return false - } - baseMessage = nextMessage - base["messages"] = []agui.UIMessage{baseMessage} - return true - } - - var continuationParts []agui.MessagePart - continuationOffset := 0 - flushContinuation := func() { - if len(continuationParts) == 0 { - return - } - out = append(out, finalPartsEvent(evt, message.ID, message.Metadata, continuationOffset, continuationParts)) - continuationParts = nil - } - addContinuation := func(partOffset int, part agui.MessagePart) { - if len(continuationParts) > 0 && partOffset != continuationOffset+len(continuationParts) { - flushContinuation() - } - if len(continuationParts) == 0 { - continuationOffset = partOffset - } - candidateParts := append(append([]agui.MessagePart{}, continuationParts...), part) - candidate := finalPartsEvent(evt, message.ID, message.Metadata, continuationOffset, candidateParts) - if len(continuationParts) > 0 && JSONSize(candidate) > budget { - flushContinuation() - continuationOffset = partOffset - } - continuationParts = append(continuationParts, part) - } - - for partOffset, part := range message.Parts { - for pieceIndex, piece := range splitFinalPartForBudget(part, budget) { - if pieceIndex == 0 && appendToBase(piece) { - continue - } - flushBase() - addContinuation(partOffset, piece) - } - } - flushBase() - flushContinuation() - return out -} - -func finalPartsEvent(base agui.Event, messageID string, metadata map[string]any, partOffset int, parts []agui.MessagePart) agui.Event { - evt := agui.CloneEvent(base) - evt["type"] = agui.EventCustom - evt["name"] = FinalPartsCustomName - delete(evt, "messages") - runID, _ := metadata["runId"].(string) - evt["value"] = map[string]any{ - "messageId": messageID, - "runId": runID, - "partOffset": partOffset, - "parts": append([]agui.MessagePart{}, parts...), - } - return evt -} - -func splitFinalPartForBudget(part agui.MessagePart, budget int) []agui.MessagePart { - partType, _ := part["type"].(string) - if partType != "text" && partType != "thinking" { - return []agui.MessagePart{part} - } - content, _ := part["content"].(string) - if content == "" || JSONSize(part) <= budget/2 { - return []agui.MessagePart{part} - } - maxContent := budget / 3 - if maxContent < 1024 { - maxContent = 1024 - } - chunks := SplitTextUTF8(content, maxContent) - out := make([]agui.MessagePart, 0, len(chunks)) - for _, chunk := range chunks { - cp := cloneMessagePart(part) - cp["content"] = chunk - out = append(out, cp) - } - return out -} - -func cloneMessagePart(part agui.MessagePart) agui.MessagePart { - cp := make(agui.MessagePart, len(part)) - for key, value := range part { - cp[key] = value - } - return cp -} - -func sanitizeRawEvent(evt agui.Event, budget int) agui.Event { - cp := agui.CloneEvent(evt) - if _, ok := cp["rawEvent"]; !ok { - return cp - } - if JSONSize(cp) <= budget { - return cp - } - raw, err := json.Marshal(cp["rawEvent"]) - if err != nil { - delete(cp, "rawEvent") - cp["rawEventTruncated"] = true - } else if len(raw) > 2048 { - cp["rawEvent"] = truncateUTF8(string(raw), 2048) - cp["rawEventTruncated"] = true - } - if JSONSize(cp) > budget { - delete(cp, "rawEvent") - cp["rawEventTruncated"] = true - } - return cp -} - -func StreamTxnID(runID string, seq int) string { - runID = strings.TrimSpace(runID) - if runID == "" { - return fmt.Sprintf("ai_stream_%d", seq) - } - return fmt.Sprintf("ai_stream_%s_%d", runID, seq) -} diff --git a/pkg/ai-stream/run.go b/pkg/ai-stream/run.go deleted file mode 100644 index 186fc8f..0000000 --- a/pkg/ai-stream/run.go +++ /dev/null @@ -1,888 +0,0 @@ -package aistream - -import ( - "encoding/json" - "fmt" - "strings" - "time" - "unicode/utf8" - - "github.com/beeper/dummybridge/pkg/ag-ui" -) - -const ( - BeeperAIKey = "com.beeper.ai" - BeeperAIMetadataKey = "com.beeper.ai.metadata" - BeeperAIStreamKey = "com.beeper.llm" - BeeperAIStreamDeltas = BeeperAIStreamKey + ".deltas" - FinalPartsCustomName = "com.beeper.ai.final-parts" - DefaultModel = "dummybridge/ag-ui" - CarrierBudgetBytes = 40 * 1024 - PreviewBudgetBytes = 4096 - SnapshotTextBytes = 4096 -) - -type Run struct { - ThreadID string - RunID string - MessageID string - Model string - AgentID string - AgentName string - Events []agui.Event - Approvals []ApprovalSummary - Artifacts ArtifactSummary - Data map[string]any - Status Status - Usage agui.Usage - Preview Preview - ToolCallID string - ApprovalID string - Prompts []ApprovalPrompt -} - -type Status struct { - State string `json:"state"` - FinishReason string `json:"finishReason,omitempty"` - Terminal any `json:"terminal"` - Error any `json:"error"` -} - -type Preview struct { - Text string `json:"text"` - Truncated bool `json:"truncated"` -} - -type UIMessageMetadata struct { - ThreadID string `json:"threadId"` - RunID string `json:"runId"` - Status Status `json:"status"` - Usage *agui.Usage `json:"usage,omitempty"` -} - -func (m UIMessageMetadata) Map() map[string]any { - out := map[string]any{ - "threadId": m.ThreadID, - "runId": m.RunID, - "status": m.Status, - } - if m.Usage != nil { - out["usage"] = *m.Usage - } - return out -} - -type RunMetadata struct { - Schema string - Protocol string - ThreadID string - RunID string - MessageID string - AgentID string - AgentName string - Model string - Usage agui.Usage - Status Status - Approvals []ApprovalSummary - Artifacts ArtifactSummary - Data map[string]any - Preview Preview -} - -func (m RunMetadata) Map() map[string]any { - return map[string]any{ - "schema": m.Schema, - "protocol": m.Protocol, - "threadId": m.ThreadID, - "runId": m.RunID, - "messageId": m.MessageID, - "agent": map[string]any{ - "id": m.AgentID, - "displayName": m.AgentName, - }, - "model": m.Model, - "usage": map[string]any{ - "promptTokens": m.Usage.PromptTokens, - "completionTokens": m.Usage.CompletionTokens, - "totalTokens": m.Usage.TotalTokens, - }, - "usageDetails": map[string]any{}, - "status": m.Status, - "approvals": m.Approvals, - "artifacts": m.Artifacts, - "data": m.Data, - "preview": m.Preview, - } -} - -type ApprovalSummary struct { - ID string `json:"id"` - ToolCallID string `json:"toolCallId"` - State string `json:"state"` - Always bool `json:"always"` - Reason string `json:"reason,omitempty"` - Fields map[string]any `json:"fields,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` -} - -type ApprovalPrompt struct { - ID string - ToolCallID string - ToolName string - SeqStart int -} - -type ArtifactSummary struct { - Sources []map[string]any `json:"sources"` - Documents []map[string]any `json:"documents"` - Files []map[string]any `json:"files"` -} - -type Writer struct { - Run *Run - builder agui.EventBuilder - reasoningOpen bool -} - -func NewRun(runID, threadID, model, agentID, agentName string, now time.Time) *Run { - runID = strings.TrimSpace(runID) - if runID == "" { - runID = fmt.Sprintf("run-%d", now.UnixNano()) - } - threadID = strings.TrimSpace(threadID) - if threadID == "" { - threadID = runID - } - model = strings.TrimSpace(model) - if model == "" { - model = DefaultModel - } - if agentID == "" { - agentID = "ai" - } - if agentName == "" { - agentName = "AI" - } - run := &Run{ - ThreadID: threadID, - RunID: runID, - MessageID: "msg-" + runID, - Model: model, - AgentID: agentID, - AgentName: agentName, - Data: map[string]any{}, - Status: Status{State: "streaming"}, - } - run.Preview = Preview{Text: BoundedPreview("", PreviewBudgetBytes)} - return run -} - -func NewWriter(run *Run, now func() time.Time) *Writer { - return &Writer{Run: run, builder: agui.NewEventBuilder(run.Model, now)} -} - -func (w *Writer) Add(evt agui.Event) { - if w == nil || w.Run == nil || len(evt) == 0 { - return - } - w.Run.Events = append(w.Run.Events, evt) - w.applySummary(evt) -} - -func (w *Writer) Start() { - w.Add(w.builder.RunStarted(w.Run.ThreadID, w.Run.RunID)) - w.Add(w.builder.TextMessageStart(w.Run.MessageID, agui.RoleAssistant)) -} - -func (w *Writer) Text(delta string) { - if delta == "" { - return - } - w.Add(w.builder.TextMessageContent(w.Run.MessageID, delta)) -} - -func (w *Writer) Thinking(delta string) { - if delta == "" { - return - } - if !w.reasoningOpen { - w.Add(w.builder.ReasoningStart(w.Run.MessageID)) - w.Add(w.builder.ReasoningMessageStart(w.Run.MessageID)) - w.reasoningOpen = true - } - w.Add(w.builder.ReasoningMessageContent(w.Run.MessageID, delta)) -} - -func (w *Writer) StepStart(stepID string) { - w.Add(w.builder.StepStarted(w.Run.MessageID, stepID)) -} - -func (w *Writer) StepFinish(stepID string) { - w.Add(w.builder.StepFinished(w.Run.MessageID, stepID)) -} - -func (w *Writer) ToolStart(toolCallID, name string, index int, approval *agui.ToolApproval) { - w.ToolStartWithMetadata(toolCallID, name, index, approval, nil) -} - -func (w *Writer) ToolStartWithMetadata(toolCallID, name string, index int, approval *agui.ToolApproval, metadata map[string]any) { - idx := index - w.Add(w.builder.ToolCallStartWithMetadata(w.Run.MessageID, toolCallID, name, &idx, approval, metadata)) - if approval != nil { - w.recordApprovalRequest(toolCallID, name, approval) - } -} - -func (w *Writer) ToolApprovalRequested(toolCallID, name string, input any, approval agui.ToolApproval) { - w.ToolApprovalRequestedWithMetadata(toolCallID, name, input, approval, nil) -} - -func (w *Writer) ToolApprovalRequestedWithMetadata(toolCallID, name string, input any, approval agui.ToolApproval, metadata map[string]any) { - w.recordApprovalRequest(toolCallID, name, &approval) - value := NewApprovalRequestedValue(*w.Run, toolCallID, name, input, approval) - value.Metadata = metadata - w.Add(w.builder.Custom( - agui.ApprovalCustomRequested, - value.Map(), - )) -} - -func (w *Writer) recordApprovalRequest(toolCallID, name string, approval *agui.ToolApproval) { - if approval == nil || approval.ID == "" { - return - } - w.Run.ToolCallID = toolCallID - w.Run.ApprovalID = approval.ID - for _, existing := range w.Run.Approvals { - if existing.ID == approval.ID { - return - } - } - w.Run.Approvals = append(w.Run.Approvals, ApprovalSummary{ - ID: approval.ID, - ToolCallID: toolCallID, - State: "requested", - }) - w.Run.Prompts = append(w.Run.Prompts, ApprovalPrompt{ID: approval.ID, ToolCallID: toolCallID, ToolName: name}) -} - -func (w *Writer) ToolArgs(toolCallID, delta string, args any) { - w.Add(w.builder.ToolCallArgs(toolCallID, delta, args)) -} - -func (w *Writer) ToolEnd(toolCallID, name string, input, result any) { - if result == nil { - result = map[string]any{ - "state": agui.ToolResultStateComplete, - "status": "success", - } - } - w.Add(w.builder.ToolCallEnd(toolCallID, name, input, jsonString(result), agui.ToolStateInputComplete)) -} - -func (w *Writer) ToolApprovalInputComplete(toolCallID, name string, input any) { - w.Add(w.builder.ToolCallEnd(toolCallID, name, input, nil, agui.ToolStateApprovalRequested)) -} - -func (w *Writer) ToolApprovalResponded(toolCallID, name string, input any, response agui.ToolApprovalResponse) { - for i := range w.Run.Approvals { - if w.Run.Approvals[i].ID == response.ID { - w.Run.Approvals[i].State = approvalSummaryState(response) - w.Run.Approvals[i].Always = response.Always - w.Run.Approvals[i].Reason = response.Reason - w.Run.Approvals[i].Fields = response.Fields - w.Run.Approvals[i].Metadata = response.Metadata - } - } - w.Add(w.builder.Custom(agui.ApprovalCustomResponded, map[string]any{ - "threadId": w.Run.ThreadID, - "runId": w.Run.RunID, - "messageId": w.Run.MessageID, - "toolCallId": toolCallID, - "toolName": name, - "approval": response, - })) - result := map[string]any{ - "approvalId": response.ID, - "always": response.Always, - } - if response.Fields != nil { - result["fields"] = response.Fields - } - if response.Metadata != nil { - result["metadata"] = response.Metadata - } - if response.Approved { - result["state"] = agui.ToolResultStateComplete - result["status"] = "success" - result["approved"] = true - } else { - reason := response.Reason - if reason == "" { - reason = "denied" - } - result["state"] = agui.ToolResultStateError - result["status"] = "denied" - result["reason"] = reason - } - w.Add(w.builder.ToolCallEnd(toolCallID, name, input, jsonString(result), agui.ToolStateApprovalResponded)) -} - -func (w *Writer) ToolResult(toolCallID, content, state string) { - w.Add(w.builder.ToolCallResult(w.Run.MessageID, toolCallID, content, state, agui.RoleTool)) -} - -func (w *Writer) ToolError(toolCallID, name string, input any, reason string) { - w.Add(w.builder.ToolCallEnd(toolCallID, name, input, jsonString(map[string]any{ - "state": agui.ToolResultStateError, - "status": "failed", - "reason": reason, - }), agui.ToolStateInputComplete)) -} - -func (w *Writer) ToolDenied(toolCallID, name string, input any, approvalID, reason string) { - if reason == "" { - reason = "denied" - } - for i := range w.Run.Approvals { - if w.Run.Approvals[i].ID == approvalID { - w.Run.Approvals[i].State = "denied" - w.Run.Approvals[i].Reason = reason - } - } - w.Add(w.builder.Custom(agui.ApprovalCustomResponded, map[string]any{ - "approval": agui.ToolApprovalResponse{ID: approvalID, Approved: false, Reason: reason}, - })) - w.Add(w.builder.ToolCallEnd(toolCallID, name, input, jsonString(map[string]any{ - "state": agui.ToolResultStateError, - "status": "denied", - "reason": reason, - }), agui.ToolStateApprovalResponded)) -} - -func jsonString(value any) any { - if value == nil { - return nil - } - if text, ok := value.(string); ok { - return text - } - raw, err := json.Marshal(value) - if err != nil { - return fmt.Sprint(value) - } - return string(raw) -} - -func jsonValue(value any) any { - text, ok := value.(string) - if !ok { - return value - } - var parsed any - if err := json.Unmarshal([]byte(text), &parsed); err != nil { - return value - } - return parsed -} - -func (w *Writer) StateSnapshot(state map[string]any) { - w.Add(w.builder.StateSnapshot(state)) -} - -func (w *Writer) StateDelta(delta any) { - w.Add(w.builder.StateDelta(delta)) -} - -func (w *Writer) MessagesSnapshot(messages []agui.UIMessage) { - w.Add(w.builder.MessagesSnapshot(messages)) -} - -func (w *Writer) Custom(name string, value any) { - w.Add(w.builder.Custom(name, value)) -} - -func (w *Writer) Finish(reason string) { - reason = agui.NormalizeFinishReason(reason) - text := w.Run.Text() - w.finishReasoning() - w.Run.Usage = agui.Usage{ - PromptTokens: 1, - CompletionTokens: utf8.RuneCountInString(text), - TotalTokens: utf8.RuneCountInString(text) + 1, - } - w.Run.Status = Status{State: "complete", FinishReason: reason} - w.Add(w.builder.TextMessageEnd(w.Run.MessageID)) - w.addFinalSnapshot() - w.Add(w.builder.RunFinished(w.Run.ThreadID, w.Run.RunID, reason, w.Run.Usage)) -} - -func (w *Writer) Error(message string) { - w.finishReasoning() - w.Run.Status = Status{State: "error", Error: map[string]any{"message": message}} - w.addFinalSnapshot() - w.Add(w.builder.RunError(w.Run.ThreadID, w.Run.RunID, message)) -} - -func (w *Writer) Abort(message string) { - w.finishReasoning() - w.Run.Status = Status{State: "aborted", Error: map[string]any{"message": message}} - w.addFinalSnapshot() - w.Add(w.builder.RunError(w.Run.ThreadID, w.Run.RunID, message)) -} - -func (w *Writer) addFinalSnapshot() { - if w == nil || w.Run == nil { - return - } - w.MessagesSnapshot([]agui.UIMessage{w.Run.FinalUIMessage(0, true)}) -} - -func (w *Writer) finishReasoning() { - if !w.reasoningOpen { - return - } - w.Add(w.builder.ReasoningMessageEnd(w.Run.MessageID)) - w.Add(w.builder.ReasoningEnd(w.Run.MessageID)) - w.reasoningOpen = false -} - -func (w *Writer) applySummary(evt agui.Event) { - switch evt["type"] { - case agui.EventTextMessageContent: - if delta, _ := evt["delta"].(string); delta != "" { - w.Run.Preview = PreviewFromText(w.Run.Text(), PreviewBudgetBytes) - } - case agui.EventCustom: - name, _ := evt["name"].(string) - value, _ := evt["value"].(map[string]any) - switch name { - case "com.beeper.source": - w.Run.Artifacts.Sources = append(w.Run.Artifacts.Sources, value) - case "com.beeper.document": - w.Run.Artifacts.Documents = append(w.Run.Artifacts.Documents, value) - case "com.beeper.file": - w.Run.Artifacts.Files = append(w.Run.Artifacts.Files, value) - case "com.beeper.data": - if key, _ := value["name"].(string); key != "" { - w.Run.Data[key] = value["value"] - } - } - } -} - -func (t Run) Text() string { - var out strings.Builder - for _, evt := range t.Events { - if evt["type"] == agui.EventTextMessageContent { - if delta, _ := evt["delta"].(string); delta != "" { - out.WriteString(delta) - } - } - } - return out.String() -} - -func (t Run) FinalUIMessage(textBudget int, includeThinking bool) agui.UIMessage { - message := agui.UIMessage{ - ID: t.MessageID, - Role: agui.RoleAssistant, - Metadata: t.UIMessageMetadata(true).Map(), - } - var textPart agui.MessagePart - var thinkingPart agui.MessagePart - var textContent, thinkingContent strings.Builder - toolParts := map[string]agui.MessagePart{} - toolResultParts := map[string]agui.MessagePart{} - approvalByID := map[string]any{} - appendPart := func(part agui.MessagePart) agui.MessagePart { - message.Parts = append(message.Parts, part) - return part - } - for _, evt := range t.Events { - switch evt["type"] { - case agui.EventTextMessageContent: - delta, _ := evt["delta"].(string) - if delta == "" { - continue - } - if textPart == nil { - textPart = appendPart(agui.MessagePart{"type": "text", "content": "", "state": agui.PartStateStreaming}) - } - textContent.WriteString(delta) - case agui.EventTextMessageEnd: - if textPart != nil { - textPart["state"] = agui.PartStateDone - } - case agui.EventReasoningMsgCont: - delta, _ := evt["delta"].(string) - if delta == "" { - continue - } - if !includeThinking { - continue - } - if thinkingPart == nil { - thinkingPart = appendPart(agui.MessagePart{"type": "thinking", "content": "", "state": agui.PartStateStreaming}) - } - thinkingContent.WriteString(delta) - case agui.EventReasoningMsgEnd: - if thinkingPart != nil { - thinkingPart["state"] = agui.PartStateDone - } - case agui.EventToolCallStart: - toolCallID, _ := evt["toolCallId"].(string) - if toolCallID == "" { - continue - } - part := agui.MessagePart{ - "type": "tool-call", - "id": toolCallID, - "toolCallId": toolCallID, - "name": firstString(evt["toolName"], evt["toolCallName"]), - "arguments": "", - "state": firstString(evt["state"]), - } - if index, ok := evt["index"]; ok { - part["index"] = index - } - if approval, ok := evt["approval"]; ok { - part["approval"] = approval - } - if metadata, ok := evt["metadata"]; ok { - part["metadata"] = metadata - } - toolParts[toolCallID] = appendPart(part) - case agui.EventToolCallArgs: - toolCallID, _ := evt["toolCallId"].(string) - part := toolParts[toolCallID] - if part == nil { - part = appendPart(agui.MessagePart{"type": "tool-call", "id": toolCallID, "toolCallId": toolCallID, "arguments": ""}) - toolParts[toolCallID] = part - } - part["state"] = firstString(evt["state"]) - if delta, _ := evt["delta"].(string); delta != "" { - part["arguments"] = asString(part["arguments"]) + delta - } - if args, ok := evt["args"]; ok { - part["input"] = args - } - case agui.EventToolCallEnd: - toolCallID, _ := evt["toolCallId"].(string) - part := toolParts[toolCallID] - if part == nil { - part = appendPart(agui.MessagePart{"type": "tool-call", "id": toolCallID, "toolCallId": toolCallID}) - toolParts[toolCallID] = part - } - part["name"] = firstString(part["name"], evt["toolName"], evt["toolCallName"]) - part["state"] = firstString(evt["state"]) - if input, ok := evt["input"]; ok { - part["input"] = input - } - if result, ok := evt["result"]; ok { - part["output"] = jsonValue(result) - } - case agui.EventToolCallResult: - toolCallID, _ := evt["toolCallId"].(string) - if toolCallID == "" { - continue - } - part := toolResultParts[toolCallID] - if part == nil { - part = appendPart(agui.MessagePart{"type": "tool-result", "toolCallId": toolCallID, "content": "", "state": firstString(evt["state"])}) - toolResultParts[toolCallID] = part - } - part["state"] = firstString(evt["state"]) - part["content"] = asString(part["content"]) + asString(evt["content"]) - case agui.EventCustom: - name, _ := evt["name"].(string) - value, _ := evt["value"].(map[string]any) - switch name { - case agui.ApprovalCustomRequested: - if toolCallID, _ := value["toolCallId"].(string); toolCallID != "" { - if part := toolParts[toolCallID]; part != nil { - part["approval"] = value["approval"] - part["state"] = agui.ToolStateApprovalRequested - } - } - case agui.ApprovalCustomResponded: - if approval, ok := value["approval"]; ok { - approvalByID[approvalMapID(approval)] = approval - } - case "com.beeper.source": - part := cloneValueMap(value) - part["type"] = "source-url" - if asString(part["sourceId"]) == "" { - part["sourceId"] = firstString(part["url"], part["title"]) - } - message.Parts = append(message.Parts, part) - case "com.beeper.document": - part := cloneValueMap(value) - part["type"] = "source-document" - if asString(part["sourceId"]) == "" { - part["sourceId"] = firstString(part["id"], part["title"]) - } - message.Parts = append(message.Parts, part) - case "com.beeper.file": - part := cloneValueMap(value) - part["type"] = "file" - message.Parts = append(message.Parts, part) - case "com.beeper.data": - message.Parts = append(message.Parts, agui.MessagePart{"type": "data-com-beeper-data", "data": value}) - } - } - } - for _, part := range toolParts { - if approvalID := approvalMapID(part["approval"]); approvalID != "" { - if response := approvalByID[approvalID]; response != nil { - part["approvalResponse"] = response - part["state"] = agui.ToolStateApprovalResponded - } - } - } - if t.Status.State != "" && t.Status.State != "streaming" { - for _, part := range toolParts { - finalizeOpenToolPart(part, t.Status.State) - } - } - if textPart != nil { - textPart["content"] = textContent.String() - } - if thinkingPart != nil { - thinkingPart["content"] = thinkingContent.String() - } - compactTextPart(textPart, textBudget) - compactTextPart(thinkingPart, textBudget) - if len(message.Parts) > 1 { - visible := make([]agui.MessagePart, 0, len(message.Parts)) - other := make([]agui.MessagePart, 0, len(message.Parts)) - for _, part := range message.Parts { - switch part["type"] { - case "text", "thinking": - visible = append(visible, part) - default: - other = append(other, part) - } - } - if len(visible) > 0 { - message.Parts = append(visible, other...) - } - } - return message -} - -func finalizeOpenToolPart(part agui.MessagePart, runState string) { - if part == nil { - return - } - if _, hasOutput := part["output"]; hasOutput { - return - } - state, _ := part["state"].(string) - switch state { - case agui.ToolStateApprovalResponded: - return - } - reason := "run finalized before tool completed" - if runState == "aborted" { - reason = "run aborted before tool completed" - } else if runState == "error" { - reason = "run failed before tool completed" - } - part["state"] = agui.ToolStateInputComplete - part["output"] = map[string]any{ - "state": agui.ToolResultStateError, - "status": "failed", - "reason": reason, - } -} - -func (t Run) InitialUIMessage() agui.UIMessage { - message := agui.UIMessage{ - ID: t.MessageID, - Role: agui.RoleAssistant, - Metadata: t.UIMessageMetadata(false).Map(), - } - if t.Preview.Text != "" { - message.Parts = []agui.MessagePart{{ - "type": "text", - "content": t.Preview.Text, - "state": agui.PartStateStreaming, - }} - } else { - message.Parts = []agui.MessagePart{} - } - return message -} - -func (t Run) UIMessageMetadata(includeUsage bool) UIMessageMetadata { - metadata := UIMessageMetadata{ - ThreadID: t.ThreadID, - RunID: t.RunID, - Status: t.Status, - } - if includeUsage { - metadata.Usage = &t.Usage - } - return metadata -} - -func compactTextPart(part agui.MessagePart, budget int) { - if part == nil { - return - } - content, _ := part["content"].(string) - if budget <= 0 { - if part["state"] == "" { - part["state"] = agui.PartStateDone - } - return - } - preview := BoundedPreview(content, budget) - part["content"] = preview - if len(preview) < len(content) { - part["providerMetadata"] = map[string]any{"truncated": true} - } - if part["state"] == "" { - part["state"] = agui.PartStateDone - } -} - -func asString(value any) string { - switch typed := value.(type) { - case string: - return typed - case fmt.Stringer: - return typed.String() - case nil: - return "" - default: - return fmt.Sprint(typed) - } -} - -func cloneValueMap(value map[string]any) agui.MessagePart { - cp := make(agui.MessagePart, len(value)+1) - for key, item := range value { - cp[key] = item - } - return cp -} - -func firstString(values ...any) string { - for _, value := range values { - if text, ok := value.(string); ok && text != "" { - return text - } - } - return "" -} - -func approvalMapID(value any) string { - switch typed := value.(type) { - case agui.ToolApproval: - return typed.ID - case *agui.ToolApproval: - if typed != nil { - return typed.ID - } - case agui.ToolApprovalResponse: - return typed.ID - case *agui.ToolApprovalResponse: - if typed != nil { - return typed.ID - } - case map[string]any: - id, _ := typed["id"].(string) - return id - } - return "" -} - -func (t Run) Metadata() map[string]any { - return t.RunMetadata().Map() -} - -func (t Run) RunMetadata() RunMetadata { - return RunMetadata{ - Schema: "com.beeper.ai.run.v1", - Protocol: "ag-ui", - ThreadID: t.ThreadID, - RunID: t.RunID, - MessageID: t.MessageID, - AgentID: t.AgentID, - AgentName: t.AgentName, - Model: t.Model, - Usage: t.Usage, - Status: t.Status, - Approvals: t.Approvals, - Artifacts: t.Artifacts, - Data: t.Data, - Preview: t.Preview, - } -} - -func (t Run) Validate() error { - for i, evt := range t.Events { - if err := agui.ValidateEvent(evt); err != nil { - return fmt.Errorf("event %d: %w", i+1, err) - } - } - return nil -} - -func PreviewFromText(text string, budget int) Preview { - preview := BoundedPreview(text, budget) - return Preview{Text: preview, Truncated: len(preview) < len(text)} -} - -func BoundedPreview(text string, budget int) string { - text = strings.TrimSpace(text) - if budget <= 0 || len(text) <= budget { - return text - } - end := budget - for end > 0 && !utf8.RuneStart(text[end]) { - end-- - } - if end <= 0 { - return "" - } - return strings.TrimSpace(text[:end]) -} - -func SplitTextUTF8(text string, maxBytes int) []string { - if maxBytes <= 0 { - return nil - } - if len(text) <= maxBytes { - return []string{text} - } - var chunks []string - start := 0 - for start < len(text) { - end := start + maxBytes - if end >= len(text) { - chunks = append(chunks, text[start:]) - break - } - for end > start && !utf8.RuneStart(text[end]) { - end-- - } - if end == start { - _, size := utf8.DecodeRuneInString(text[start:]) - end = start + size - } - chunks = append(chunks, text[start:end]) - start = end - } - return chunks -} - -func JSONSize(value any) int { - raw, err := json.Marshal(value) - if err != nil { - return CarrierBudgetBytes + 1 - } - return len(raw) -} diff --git a/pkg/ai-stream/stream_test.go b/pkg/ai-stream/stream_test.go deleted file mode 100644 index bfa9ee8..0000000 --- a/pkg/ai-stream/stream_test.go +++ /dev/null @@ -1,460 +0,0 @@ -package aistream - -import ( - "encoding/json" - "strings" - "testing" - "time" - - "github.com/beeper/dummybridge/pkg/ag-ui" -) - -func TestPackRunSplitsOver64KBAndReconstructs(t *testing.T) { - run := NewRun("run-1", "thread-1", DefaultModel, "ai", "AI", time.Unix(10, 0)) - writer := NewWriter(run, func() time.Time { return time.Unix(10, 0) }) - writer.Start() - writer.Text(strings.Repeat("a", 70*1024)) - writer.Finish(agui.FinishReasonStop) - - carriers, err := PackRun(*run, "$anchor", CarrierBudgetBytes) - if err != nil { - t.Fatal(err) - } - if len(carriers) < 2 { - t.Fatalf("expected multiple carriers for over-64KB output, got %d", len(carriers)) - } - for i, carrier := range carriers { - if size := JSONSize(CarrierContent(carrier.Envelopes)); size > CarrierBudgetBytes { - t.Fatalf("carrier %d is %d bytes, budget %d", i, size, CarrierBudgetBytes) - } - } - if got := ReconstructText(carriers); got != strings.Repeat("a", 70*1024) { - t.Fatalf("reconstructed text length = %d", len(got)) - } -} - -func TestPackRunDoesNotPutFinalizationTotalsOnStreamEnvelopes(t *testing.T) { - run := NewRun("run-1", "thread-1", DefaultModel, "ai", "AI", time.Unix(10, 0)) - writer := NewWriter(run, func() time.Time { return time.Unix(10, 0) }) - writer.Start() - writer.Text("hello") - writer.Finish(agui.FinishReasonStop) - - carriers, err := PackRun(*run, "$anchor", CarrierBudgetBytes) - if err != nil { - t.Fatal(err) - } - raw, err := json.Marshal(CarrierContent(carriers[0].Envelopes)) - if err != nil { - t.Fatal(err) - } - if strings.Contains(string(raw), "seqTotal") { - t.Fatalf("stream envelopes must not contain finalization totals: %s", raw) - } -} - -func TestFinalSnapshotSplitsIntoBaseAndContinuationParts(t *testing.T) { - run := NewRun("run-1", "thread-1", DefaultModel, "ai", "AI", time.Unix(10, 0)) - writer := NewWriter(run, func() time.Time { return time.Unix(10, 0) }) - writer.Start() - writer.Thinking(strings.Repeat("t", 12*1024)) - writer.Text(strings.Repeat("a", 70*1024)) - writer.ToolStart("tool-1", "shell", 0, nil) - writer.ToolArgs("tool-1", `{"cmd":"pwd"}`, `{"cmd":"pwd"}`) - writer.ToolEnd("tool-1", "shell", `{"cmd":"pwd"}`, map[string]any{"ok": true}) - writer.Finish(agui.FinishReasonStop) - - carriers, err := PackRun(*run, "$anchor", CarrierBudgetBytes) - if err != nil { - t.Fatal(err) - } - var baseSnapshots, continuations int - var baseText string - var reconstructedText strings.Builder - var sawMetadata bool - for i, carrier := range carriers { - if size := JSONSize(CarrierContent(carrier.Envelopes)); size > CarrierBudgetBytes { - t.Fatalf("carrier %d is %d bytes, budget %d", i, size, CarrierBudgetBytes) - } - for _, env := range carrier.Envelopes { - switch env.Part["type"] { - case agui.EventMessagesSnapshot: - baseSnapshots++ - messages, ok := env.Part["messages"].([]any) - if !ok || len(messages) != 1 { - t.Fatalf("bad final base snapshot: %#v", env.Part["messages"]) - } - message, ok := messages[0].(map[string]any) - if !ok { - t.Fatalf("bad final base snapshot message: %#v", messages[0]) - } - metadata, ok := message["metadata"].(map[string]any) - if ok && metadata["runId"] == "run-1" { - sawMetadata = true - } - for _, part := range testFinalParts(t, message["parts"]) { - if part["type"] == "text" { - baseText += part["content"].(string) - } - } - case agui.EventCustom: - if env.Part["name"] != FinalPartsCustomName { - continue - } - continuations++ - value := env.Part["value"].(map[string]any) - if value["messageId"] != run.MessageID || value["runId"] != run.RunID { - t.Fatalf("bad continuation relation data: %#v", value) - } - if _, ok := value["metadata"]; ok { - t.Fatalf("continuation must not duplicate message metadata: %#v", value) - } - for _, part := range testFinalParts(t, value["parts"]) { - if part["type"] == "text" { - reconstructedText.WriteString(part["content"].(string)) - } - } - } - } - } - if baseSnapshots != 1 || continuations == 0 || !sawMetadata { - t.Fatalf("expected one metadata base snapshot and continuations, base=%d continuations=%d metadata=%v", baseSnapshots, continuations, sawMetadata) - } - if baseText == "" { - t.Fatal("base final snapshot must keep visible text in the primary event") - } - if !strings.Contains(run.Text(), reconstructedText.String()) { - t.Fatalf("unexpected continuation text reconstruction length=%d", reconstructedText.Len()) - } -} - -func testFinalParts(t *testing.T, value any) []map[string]any { - t.Helper() - switch parts := value.(type) { - case []agui.MessagePart: - out := make([]map[string]any, 0, len(parts)) - for _, part := range parts { - out = append(out, map[string]any(part)) - } - return out - case []any: - out := make([]map[string]any, 0, len(parts)) - for _, rawPart := range parts { - part, ok := rawPart.(map[string]any) - if !ok { - t.Fatalf("bad final part: %#v", rawPart) - } - out = append(out, part) - } - return out - default: - t.Fatalf("bad final parts: %#v", value) - return nil - } -} - -func TestPackRunUsesDeltaEventsInsteadOfAccumulatedText(t *testing.T) { - run := NewRun("run-1", "thread-1", DefaultModel, "ai", "AI", time.Unix(10, 0)) - tick := int64(10) - writer := NewWriter(run, func() time.Time { - tick++ - return time.Unix(tick, 0) - }) - writer.Start() - writer.Text("abc") - writer.Text("def") - writer.Finish(agui.FinishReasonStop) - - carriers, err := PackRun(*run, "$anchor", CarrierBudgetBytes) - if err != nil { - t.Fatal(err) - } - if len(carriers) != 1 { - t.Fatalf("under-budget run should be packed into one carrier, got %d", len(carriers)) - } - var deltas []string - for _, carrier := range carriers { - for _, env := range carrier.Envelopes { - if env.Part["type"] == agui.EventTextMessageContent { - deltas = append(deltas, env.Part["delta"].(string)) - } - } - } - if strings.Join(deltas, "|") != "abc|def" { - t.Fatalf("expected original deltas only, got %#v", deltas) - } -} - -func TestRawEventIsTruncatedBeforePacking(t *testing.T) { - run := NewRun("run-1", "thread-1", DefaultModel, "ai", "AI", time.Unix(10, 0)) - builder := agui.NewEventBuilder(DefaultModel, func() time.Time { return time.Unix(10, 0) }) - run.Events = append(run.Events, builder.Custom("com.beeper.debug", map[string]any{"ok": true})) - run.Events[0]["rawEvent"] = strings.Repeat("x", CarrierBudgetBytes) - - carriers, err := PackRun(*run, "$anchor", CarrierBudgetBytes) - if err != nil { - t.Fatal(err) - } - part := carriers[0].Envelopes[0].Part - if part["rawEventTruncated"] != true { - t.Fatalf("expected rawEventTruncated marker, got %#v", part) - } - if size := JSONSize(CarrierContent(carriers[0].Envelopes)); size > CarrierBudgetBytes { - t.Fatalf("carrier size = %d, budget %d", size, CarrierBudgetBytes) - } -} - -func TestPackRunRejectsOversizedNonTextEvent(t *testing.T) { - run := NewRun("run-1", "thread-1", DefaultModel, "ai", "AI", time.Unix(10, 0)) - builder := agui.NewEventBuilder(DefaultModel, func() time.Time { return time.Unix(10, 0) }) - run.Events = append(run.Events, builder.Custom("com.beeper.large", map[string]any{ - "value": strings.Repeat("x", CarrierBudgetBytes), - })) - - _, err := PackRun(*run, "$anchor", CarrierBudgetBytes) - if err == nil { - t.Fatal("expected oversized non-text event to fail packing") - } -} - -func TestValidateRejectsLegacyOrInvalidToolResultShape(t *testing.T) { - run := NewRun("run-1", "thread-1", DefaultModel, "ai", "AI", time.Unix(10, 0)) - builder := agui.NewEventBuilder(DefaultModel, func() time.Time { return time.Unix(10, 0) }) - run.Events = append(run.Events, - builder.RunStarted("thread-1", "run-1"), - builder.ToolCallStart("msg-run-1", "tool-1", "shell", nil, nil), - builder.ToolCallEnd("tool-1", "shell", nil, map[string]any{"ok": true}, agui.ToolStateInputComplete), - ) - if err := run.Validate(); err == nil { - t.Fatal("expected validation error for non-string TOOL_CALL_END.result") - } -} - -func TestFinalUIMessageCarriesToolCallMetadata(t *testing.T) { - run := NewRun("run-1", "thread-1", DefaultModel, "ai", "AI", time.Unix(10, 0)) - writer := NewWriter(run, func() time.Time { return time.Unix(10, 0) }) - writer.ToolStartWithMetadata("tool-1", "calendar.get_events", 0, nil, map[string]any{ - "displayName": "List Calendar Events", - "iconUrl": "mxc://beeper.com/calendar", - }) - - message := run.FinalUIMessage(0, true) - if len(message.Parts) != 1 { - t.Fatalf("expected one part, got %#v", message.Parts) - } - metadata, ok := message.Parts[0]["metadata"].(map[string]any) - if !ok || metadata["displayName"] != "List Calendar Events" || metadata["iconUrl"] != "mxc://beeper.com/calendar" { - t.Fatalf("bad tool metadata: %#v", message.Parts[0]) - } -} - -func TestFinalUIMessageCarriesParsedToolOutputs(t *testing.T) { - run := NewRun("run-1", "thread-1", DefaultModel, "ai", "AI", time.Unix(10, 0)) - writer := NewWriter(run, func() time.Time { return time.Unix(10, 0) }) - writer.ToolStart("tool-1", "shell", 0, nil) - writer.ToolArgs("tool-1", `{"cmd":"pwd"}`, `{"cmd":"pwd"}`) - writer.ToolEnd("tool-1", "shell", map[string]any{"cmd": "pwd"}, nil) - writer.ToolStart("tool-2", "files", 1, nil) - writer.ToolError("tool-2", "files", map[string]any{"path": "/tmp/nope"}, "missing") - - message := run.FinalUIMessage(0, true) - if len(message.Parts) != 2 { - t.Fatalf("expected two tool parts, got %#v", message.Parts) - } - success, ok := message.Parts[0]["output"].(map[string]any) - if !ok || success["state"] != agui.ToolResultStateComplete || success["status"] != "success" { - t.Fatalf("success tool without result should emit terminal success output: %#v", message.Parts[0]) - } - failure, ok := message.Parts[1]["output"].(map[string]any) - if !ok || failure["state"] != agui.ToolResultStateError || failure["status"] != "failed" || failure["reason"] != "missing" { - t.Fatalf("failed tool output should be parsed and terminal: %#v", message.Parts[1]) - } -} - -func TestFinalUIMessageFailsOpenToolsWhenRunFinalized(t *testing.T) { - run := NewRun("run-1", "thread-1", DefaultModel, "ai", "AI", time.Unix(10, 0)) - writer := NewWriter(run, func() time.Time { return time.Unix(10, 0) }) - writer.ToolStart("tool-1", "summarize", 0, nil) - writer.ToolStart("tool-2", "calendar", 1, nil) - writer.Finish(agui.FinishReasonStop) - - message := run.FinalUIMessage(0, true) - if len(message.Parts) != 2 { - t.Fatalf("expected two tool parts, got %#v", message.Parts) - } - for _, part := range message.Parts { - if part["state"] != agui.ToolStateInputComplete { - t.Fatalf("open tool should be finalized as input-complete: %#v", part) - } - output, ok := part["output"].(map[string]any) - if !ok || output["state"] != agui.ToolResultStateError || output["status"] != "failed" { - t.Fatalf("open tool should get terminal failed output: %#v", part) - } - } -} - -func TestFinalUIMessageCarriesTopLevelArtifactsWithStableIDs(t *testing.T) { - run := NewRun("run-1", "thread-1", DefaultModel, "ai", "AI", time.Unix(10, 0)) - writer := NewWriter(run, func() time.Time { return time.Unix(10, 0) }) - writer.Custom("com.beeper.source", map[string]any{ - "sourceId": "source-1", - "url": "https://example.com/source", - "title": "Example Source", - }) - writer.Custom("com.beeper.document", map[string]any{ - "id": "doc-1", - "title": "Example Doc", - "mediaType": "text/plain", - }) - writer.Custom("com.beeper.file", map[string]any{ - "url": "mxc://example/file", - "mediaType": "application/octet-stream", - }) - - message := run.FinalUIMessage(0, true) - if len(message.Parts) != 3 { - t.Fatalf("expected artifact parts, got %#v", message.Parts) - } - if message.Parts[0]["type"] != "source-url" || message.Parts[0]["sourceId"] != "source-1" || message.Parts[0]["url"] != "https://example.com/source" { - t.Fatalf("bad source part shape: %#v", message.Parts[0]) - } - if _, hasNestedSource := message.Parts[0]["source"]; hasNestedSource { - t.Fatalf("source part should not nest payload: %#v", message.Parts[0]) - } - if message.Parts[1]["type"] != "source-document" || message.Parts[1]["sourceId"] != "doc-1" || message.Parts[1]["id"] != "doc-1" { - t.Fatalf("bad document part shape: %#v", message.Parts[1]) - } - if message.Parts[2]["type"] != "file" || message.Parts[2]["url"] != "mxc://example/file" { - t.Fatalf("bad file part shape: %#v", message.Parts[2]) - } - if _, hasNestedFile := message.Parts[2]["file"]; hasNestedFile { - t.Fatalf("file part should not nest payload: %#v", message.Parts[2]) - } -} - -func TestApprovalResolverMatchesEmojiKeysAndAliases(t *testing.T) { - choices := DefaultApprovalChoices() - for _, key := range []string{"✅", "approve"} { - choice, ok := ResolveApprovalChoice(choices, key) - response := ApprovalResponseForChoice("approval-1", choice) - if !ok || !response.Approved || response.Always { - t.Fatalf("expected approve for %q, got %#v ok=%v", key, choice, ok) - } - } - choice, ok := ResolveApprovalChoice(choices, "☑️") - response := ApprovalResponseForChoice("approval-1", choice) - if !ok || !response.Approved || !response.Always { - t.Fatalf("expected always-approve, got %#v ok=%v", choice, ok) - } - choice, ok = ResolveApprovalChoice(choices, "deny") - response = ApprovalResponseForChoice("approval-1", choice) - if !ok || response.Approved || response.Reason != "denied" { - t.Fatalf("expected denial, got %#v ok=%v", choice, ok) - } -} - -func TestApprovalRequestedValueOwnsStreamPayloadShape(t *testing.T) { - run := NewRun("run-1", "thread-1", DefaultModel, "ai", "AI", time.Unix(10, 0)) - run.MessageID = "msg-run-1" - approval := agui.ToolApproval{ID: "approval-1", NeedsApproval: true} - - value := NewApprovalRequestedValue(*run, "tool-1", "shell", map[string]any{"command": "ls"}, approval).Map() - - if value["threadId"] != "thread-1" || value["runId"] != "run-1" || value["messageId"] != "msg-run-1" { - t.Fatalf("bad run identifiers: %#v", value) - } - if value["toolCallId"] != "tool-1" || value["toolName"] != "shell" { - t.Fatalf("bad tool identifiers: %#v", value) - } - if value["approvalMessageId"] != "approval-1" { - t.Fatalf("missing approval message id: %#v", value) - } - if _, ok := value["approvalEventId"]; ok { - t.Fatalf("approval event id should only be added after Matrix send: %#v", value) - } - choices, ok := value["choices"].([]ApprovalChoice) - if !ok || len(choices) != len(DefaultApprovalChoices()) || choices[0].Key != ApprovalChoiceApprove { - t.Fatalf("bad approval choices: %#v", value["choices"]) - } - if ApprovalIDFromRequestedValue(value) != "approval-1" { - t.Fatalf("approval id resolver failed for value: %#v", value) - } - if !SetApprovalRequestedEventID(value, "$approval") || value["approvalEventId"] != "$approval" { - t.Fatalf("failed to annotate approval event id: %#v", value) - } -} - -func TestRunMetadataOwnsMatrixPayloadShape(t *testing.T) { - run := NewRun("run-1", "thread-1", DefaultModel, "agent-1", "Agent", time.Unix(10, 0)) - run.MessageID = "msg-run-1" - run.Usage = agui.Usage{PromptTokens: 1, CompletionTokens: 2, TotalTokens: 3} - run.Preview = Preview{Text: "hello", Truncated: false} - - metadata := run.Metadata() - - if metadata["schema"] != "com.beeper.ai.run.v1" || metadata["protocol"] != "ag-ui" { - t.Fatalf("bad protocol metadata: %#v", metadata) - } - if metadata["threadId"] != "thread-1" || metadata["runId"] != "run-1" || metadata["messageId"] != "msg-run-1" { - t.Fatalf("bad run identifiers: %#v", metadata) - } - agent, ok := metadata["agent"].(map[string]any) - if !ok || agent["id"] != "agent-1" || agent["displayName"] != "Agent" { - t.Fatalf("bad agent metadata: %#v", metadata["agent"]) - } - usage, ok := metadata["usage"].(map[string]any) - if !ok || usage["promptTokens"] != 1 || usage["completionTokens"] != 2 || usage["totalTokens"] != 3 { - t.Fatalf("bad usage metadata: %#v", metadata["usage"]) - } - if _, ok := metadata["usageDetails"].(map[string]any); !ok { - t.Fatalf("usage details should always be present: %#v", metadata) - } -} - -func TestApprovalNoticeOwnsHiddenMessagePayloadShape(t *testing.T) { - notice := NewApprovalNotice(ApprovalContext{ - ID: "approval-1", - MessageID: "msg-run-1", - ToolCallID: "tool-1", - ToolName: "shell", - }, DefaultApprovalChoices()).Map() - - if notice["schema"] != "com.beeper.ai.approval.v1" || notice["state"] != "requested" { - t.Fatalf("bad approval notice metadata: %#v", notice) - } - if notice["id"] != "approval-1" || notice["messageId"] != "msg-run-1" || notice["toolCallId"] != "tool-1" || notice["toolName"] != "shell" { - t.Fatalf("bad approval notice identifiers: %#v", notice) - } - choices, ok := notice["choices"].([]any) - if !ok || len(choices) != 3 { - t.Fatalf("bad approval notice choices: %#v", notice["choices"]) - } - first, ok := choices[0].(map[string]any) - if !ok || first["key"] != ApprovalChoiceApprove || first["label"] != "Allow once" || first["alias"] != "✅" { - t.Fatalf("bad first approval choice: %#v", choices[0]) - } - if _, ok := first["style"]; ok { - t.Fatalf("empty style should be omitted from approval choices: %#v", first) - } - deny, ok := choices[2].(map[string]any) - if !ok || deny["style"] != "danger" { - t.Fatalf("deny choice should keep danger style: %#v", choices[2]) - } -} - -func TestCleanupKeepsSelectedUserReactionAndRemovesBridgeOptions(t *testing.T) { - choices := DefaultApprovalChoices() - cleanup := CleanupApprovalReactions(choices, "✅", []ReactionEvent{ - {EventID: "$bridge-allow", Sender: "ai", Key: "✅", Bridge: true}, - {EventID: "$bridge-deny", Sender: "ai", Key: "❌", Bridge: true}, - {EventID: "$user-allow", Sender: "@user:example", Key: "✅"}, - {EventID: "$user-deny", Sender: "@user:example", Key: "❌"}, - }, "ai") - if !cleanup.Matched || cleanup.SelectedReactionEvent != "$user-allow" { - t.Fatalf("bad selected reaction: %#v", cleanup) - } - got := strings.Join(cleanup.RedactReactionEvents, ",") - if !strings.Contains(got, "$bridge-allow") || !strings.Contains(got, "$bridge-deny") || !strings.Contains(got, "$user-deny") { - t.Fatalf("bad cleanup redactions: %#v", cleanup.RedactReactionEvents) - } -} diff --git a/pkg/connector/ai_runtime.go b/pkg/connector/ai_runtime.go index bb79b75..e84aae9 100644 --- a/pkg/connector/ai_runtime.go +++ b/pkg/connector/ai_runtime.go @@ -11,8 +11,8 @@ import ( "strings" "time" - "github.com/beeper/dummybridge/pkg/ag-ui" - "github.com/beeper/dummybridge/pkg/ai-stream" + "github.com/beeper/ai-bridge/pkg/ag-ui" + "github.com/beeper/ai-bridge/pkg/ai-stream" "go.mau.fi/util/shlex" ) diff --git a/pkg/connector/ai_runtime_test.go b/pkg/connector/ai_runtime_test.go index f8e15d7..b387866 100644 --- a/pkg/connector/ai_runtime_test.go +++ b/pkg/connector/ai_runtime_test.go @@ -10,8 +10,8 @@ import ( "testing" "time" - "github.com/beeper/dummybridge/pkg/ag-ui" - "github.com/beeper/dummybridge/pkg/ai-stream" + "github.com/beeper/ai-bridge/pkg/ag-ui" + "github.com/beeper/ai-bridge/pkg/ai-stream" "maunium.net/go/mautrix/id" ) diff --git a/pkg/connector/client.go b/pkg/connector/client.go index 5198eaa..2898b7b 100644 --- a/pkg/connector/client.go +++ b/pkg/connector/client.go @@ -10,9 +10,9 @@ import ( "sync" "time" - "github.com/beeper/dummybridge/pkg/ag-ui" - "github.com/beeper/dummybridge/pkg/ai-stream" - aibridgev2 "github.com/beeper/dummybridge/pkg/ai-stream/bridgev2" + "github.com/beeper/ai-bridge/pkg/ag-ui" + "github.com/beeper/ai-bridge/pkg/ai-stream" + aibridgev2 "github.com/beeper/ai-bridge/pkg/ai-stream/bridgev2" "github.com/rs/zerolog/log" "go.mau.fi/util/exsync" "go.mau.fi/util/jsontime" diff --git a/pkg/connector/client_test.go b/pkg/connector/client_test.go index 63e2239..3290326 100644 --- a/pkg/connector/client_test.go +++ b/pkg/connector/client_test.go @@ -6,8 +6,8 @@ import ( "testing" "time" - "github.com/beeper/dummybridge/pkg/ag-ui" - "github.com/beeper/dummybridge/pkg/ai-stream" + "github.com/beeper/ai-bridge/pkg/ag-ui" + "github.com/beeper/ai-bridge/pkg/ai-stream" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/bridgev2/networkid" From 538d2f68e2461966b567909fbda08e03ce84e22b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 24 May 2026 04:13:20 +0200 Subject: [PATCH 43/46] split --- go.mod | 36 +- go.sum | 76 +- pkg/connector/ai_commands.go | 440 ++++++++++ pkg/connector/ai_parse_helpers.go | 177 ++++ pkg/connector/ai_plans.go | 207 +++++ pkg/connector/ai_runner.go | 415 +++++++++ pkg/connector/ai_runtime.go | 1341 ----------------------------- pkg/connector/ai_types.go | 156 ++++ 8 files changed, 1455 insertions(+), 1393 deletions(-) create mode 100644 pkg/connector/ai_commands.go create mode 100644 pkg/connector/ai_parse_helpers.go create mode 100644 pkg/connector/ai_plans.go create mode 100644 pkg/connector/ai_runner.go create mode 100644 pkg/connector/ai_types.go diff --git a/go.mod b/go.mod index d84f7a3..bddb6ab 100644 --- a/go.mod +++ b/go.mod @@ -1,39 +1,39 @@ module github.com/beeper/dummybridge -go 1.25.0 +go 1.24.0 toolchain go1.25.6 require ( - github.com/rs/zerolog v1.35.1 - go.mau.fi/util v0.9.9 - maunium.net/go/mautrix v0.28.0 + github.com/beeper/ai-bridge v0.0.0-20260524021151-5c8086351a72 + github.com/rs/zerolog v1.34.0 + go.mau.fi/util v0.9.5 + maunium.net/go/mautrix v0.26.3-0.20260119125818-e28f7170bc4b ) require ( - filippo.io/edwards25519 v1.2.0 // indirect - github.com/beeper/ai-bridge v0.0.0-20260524020001-6f18f21c0e09 + filippo.io/edwards25519 v1.1.0 // indirect github.com/coder/websocket v1.8.14 // indirect - github.com/coreos/go-systemd/v22 v22.7.0 // indirect - github.com/lib/pq v1.12.3 // indirect + github.com/coreos/go-systemd/v22 v22.6.0 // indirect + github.com/lib/pq v1.10.9 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect - github.com/mattn/go-sqlite3 v1.14.44 // indirect - github.com/petermattis/goid v0.0.0-20260330135022-df67b199bc81 // indirect + github.com/mattn/go-sqlite3 v1.14.33 // indirect + github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741 // indirect github.com/rs/xid v1.6.0 // indirect github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e // indirect - github.com/tidwall/gjson v1.19.0 // indirect + github.com/tidwall/gjson v1.18.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect github.com/tidwall/sjson v1.2.5 // indirect - github.com/yuin/goldmark v1.8.2 // indirect + github.com/yuin/goldmark v1.7.16 // indirect go.mau.fi/zeroconfig v0.2.0 // indirect - golang.org/x/crypto v0.51.0 // indirect - golang.org/x/exp v0.0.0-20260508232706-74f9aab9d74a // indirect - golang.org/x/net v0.54.0 // indirect - golang.org/x/sync v0.20.0 // indirect - golang.org/x/sys v0.44.0 // indirect - golang.org/x/text v0.37.0 // indirect + golang.org/x/crypto v0.47.0 // indirect + golang.org/x/exp v0.0.0-20260112195511-716be5621a96 // indirect + golang.org/x/net v0.49.0 // indirect + golang.org/x/sync v0.19.0 // indirect + golang.org/x/sys v0.40.0 // indirect + golang.org/x/text v0.33.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect maunium.net/go/mauflag v1.0.0 // indirect diff --git a/go.sum b/go.sum index 5d90657..d9630eb 100644 --- a/go.sum +++ b/go.sum @@ -1,38 +1,44 @@ -filippo.io/edwards25519 v1.2.0 h1:crnVqOiS4jqYleHd9vaKZ+HKtHfllngJIiOpNpoJsjo= -filippo.io/edwards25519 v1.2.0/go.mod h1:xzAOLCNug/yB62zG1bQ8uziwrIqIuxhctzJT18Q77mc= +filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= +filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= -github.com/beeper/ai-bridge v0.0.0-20260524020001-6f18f21c0e09 h1:NbH3OUYoEw2gGjg5VzBdPrT27J5HcKGUxj0/nYNFTqE= -github.com/beeper/ai-bridge v0.0.0-20260524020001-6f18f21c0e09/go.mod h1:0K/m+XXVLw1mX5gZ6gIIxDi5RDAgj09W++2eGREM8MI= +github.com/beeper/ai-bridge v0.0.0-20260524021151-5c8086351a72 h1:Pw2qyz5mizv/UL4JTKiK1sbYfUl6o8dk/KcNyFlSFG0= +github.com/beeper/ai-bridge v0.0.0-20260524021151-5c8086351a72/go.mod h1:Uf2M1ogzy7VGB6uUzzHjZL2eaYt79DK0Py8I6xZl3r0= github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g= github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg= -github.com/coreos/go-systemd/v22 v22.7.0 h1:LAEzFkke61DFROc7zNLX/WA2i5J8gYqe0rSj9KI28KA= -github.com/coreos/go-systemd/v22 v22.7.0/go.mod h1:xNUYtjHu2EDXbsxz1i41wouACIwT7Ybq9o0BQhMwD0w= +github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/coreos/go-systemd/v22 v22.6.0 h1:aGVa/v8B7hpb0TKl0MWoAavPDmHvobFe5R5zn0bCJWo= +github.com/coreos/go-systemd/v22 v22.6.0/go.mod h1:iG+pp635Fo7ZmV/j14KUcmEyWF+0X7Lua8rrTWzYgWU= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/lib/pq v1.12.3 h1:tTWxr2YLKwIvK90ZXEw8GP7UFHtcbTtty8zsI+YjrfQ= -github.com/lib/pq v1.12.3/go.mod h1:/p+8NSbOcwzAEI7wiMXFlgydTwcgTr3OSKMsD2BitpA= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-sqlite3 v1.14.44 h1:3VSe+xafpbzsLbdr2AWlAZk9yRHiBhTBakioXaCKTF8= -github.com/mattn/go-sqlite3 v1.14.44/go.mod h1:pjEuOr8IwzLJP2MfGeTb0A35jauH+C2kbHKBr7yXKVQ= -github.com/petermattis/goid v0.0.0-20260330135022-df67b199bc81 h1:WDsQxOJDy0N1VRAjXLpi8sCEZRSGarLWQevDxpTBRrM= -github.com/petermattis/goid v0.0.0-20260330135022-df67b199bc81/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= +github.com/mattn/go-sqlite3 v1.14.33 h1:A5blZ5ulQo2AtayQ9/limgHEkFreKj1Dv226a1K73s0= +github.com/mattn/go-sqlite3 v1.14.33/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741 h1:KPpdlQLZcHfTMQRi6bFQ7ogNO0ltFT4PmtwTLW4W+14= +github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU= github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= -github.com/rs/zerolog v1.35.1 h1:m7xQeoiLIiV0BCEY4Hs+j2NG4Gp2o2KPKmhnnLiazKI= -github.com/rs/zerolog v1.35.1/go.mod h1:EjML9kdfa/RMA7h/6z6pYmq1ykOuA8/mjWaEvGI+jcw= +github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= +github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= -github.com/tidwall/gjson v1.19.0 h1:xwxm7n691Uf3u5OFjzngavjGTh55KX5q/9w9xHW88JU= -github.com/tidwall/gjson v1.19.0/go.mod h1:V37/opeE/JbLUOfH0QTXiNez2l0RUjYUhpT4szFQAfc= +github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= +github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= @@ -40,25 +46,27 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= -github.com/yuin/goldmark v1.8.2 h1:kEGpgqJXdgbkhcOgBxkC0X0PmoPG1ZyoZ117rDVp4zE= -github.com/yuin/goldmark v1.8.2/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= -go.mau.fi/util v0.9.9 h1:ujDeXCo07HBor5oQLyO1tHklupmqVmPgasc53d7q/NE= -go.mau.fi/util v0.9.9/go.mod h1:pqt4Vcrt+5gcH/CgrHZg11qSx+b34o6mknGzOEA6waY= +github.com/yuin/goldmark v1.7.16 h1:n+CJdUxaFMiDUNnWC3dMWCIQJSkxH4uz3ZwQBkAlVNE= +github.com/yuin/goldmark v1.7.16/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= +go.mau.fi/util v0.9.5 h1:7AoWPCIZJGv4jvtFEuCe3GhAbI7uF9ckIooaXvwlIR4= +go.mau.fi/util v0.9.5/go.mod h1:g1uvZ03VQhtTt2BgaRGVytS/Zj67NV0YNIECch0sQCQ= go.mau.fi/zeroconfig v0.2.0 h1:e/OGEERqVRRKlgaro7E6bh8xXiKFSXB3eNNIud7FUjU= go.mau.fi/zeroconfig v0.2.0/go.mod h1:J0Vn0prHNOm493oZoQ84kq83ZaNCYZnq+noI1b1eN8w= -golang.org/x/crypto v0.51.0 h1:IBPXwPfKxY7cWQZ38ZCIRPI50YLeevDLlLnyC5wRGTI= -golang.org/x/crypto v0.51.0/go.mod h1:8AdwkbraGNABw2kOX6YFPs3WM22XqI4EXEd8g+x7Oc8= -golang.org/x/exp v0.0.0-20260508232706-74f9aab9d74a h1:+3jdDGGB8NGb1Zktc737jlt3/A5f6UlwSzmvqUuufxw= -golang.org/x/exp v0.0.0-20260508232706-74f9aab9d74a/go.mod h1:d2fgXJLVs4dYDHUk5lwMIfzRzSrWCfGZb0ZqeLa/Vcw= -golang.org/x/net v0.54.0 h1:2zJIZAxAHV/OHCDTCOHAYehQzLfSXuf/5SoL/Dv6w/w= -golang.org/x/net v0.54.0/go.mod h1:Sj4oj8jK6XmHpBZU/zWHw3BV3abl4Kvi+Ut7cQcY+cQ= -golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= -golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= +golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8= +golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A= +golang.org/x/exp v0.0.0-20260112195511-716be5621a96 h1:Z/6YuSHTLOHfNFdb8zVZomZr7cqNgTJvA8+Qz75D8gU= +golang.org/x/exp v0.0.0-20260112195511-716be5621a96/go.mod h1:nzimsREAkjBCIEFtHiYkrJyT+2uy9YZJB7H1k68CXZU= +golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o= +golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= +golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.44.0 h1:ildZl3J4uzeKP07r2F++Op7E9B29JRUy+a27EibtBTQ= -golang.org/x/sys v0.44.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= -golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc= -golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= +golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE= +golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= @@ -67,5 +75,5 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M= maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA= -maunium.net/go/mautrix v0.28.0 h1:vBakLzf8MAdfED3NzAKiMeKQbc3AQ4EAS03NC+TVMXQ= -maunium.net/go/mautrix v0.28.0/go.mod h1:/a9A7LGaqb9B3nho4tLd28n0EPcCdwpm2dxkxkLLgh0= +maunium.net/go/mautrix v0.26.3-0.20260119125818-e28f7170bc4b h1:OaZ5Y1l4XACFlgy4BmZcCLdYPJZzgZWqZJnpdSITmoM= +maunium.net/go/mautrix v0.26.3-0.20260119125818-e28f7170bc4b/go.mod h1:CUxSZcjPtQNxsZLRQqETAxg2hiz7bjWT+L1HCYoMMKo= diff --git a/pkg/connector/ai_commands.go b/pkg/connector/ai_commands.go new file mode 100644 index 0000000..1ca9945 --- /dev/null +++ b/pkg/connector/ai_commands.go @@ -0,0 +1,440 @@ +package connector + +import ( + "fmt" + "strconv" + "strings" + "time" + + "github.com/beeper/ai-bridge/pkg/ag-ui" + "go.mau.fi/util/shlex" +) + +func parseCommand(input string) (*parsedCommand, error) { + tokens, err := shlex.Split(input) + if err != nil { + return nil, fmt.Errorf("invalid command syntax: %w", err) + } + if len(tokens) == 0 { + return &parsedCommand{Name: "help"}, nil + } + switch strings.ToLower(tokens[0]) { + case "help", "/help", "!help", "dummybridge": + return &parsedCommand{Name: "help"}, nil + case "stream-tools": + cmd, err := parseToolsCommand(tokens[1:]) + return &parsedCommand{Name: "stream-tools", Tools: cmd}, err + case "stream": + cmd, err := parseStreamCommand(tokens[1:]) + return &parsedCommand{Name: "stream", Random: cmd}, err + default: + return nil, fmt.Errorf("unknown AI demo command %q", tokens[0]) + } +} + +func helpText() string { + return strings.Join([]string{ + "DummyBridge demo commands:", + "help", + "stream [seconds] [--runs=N] [--profile=balanced|tools|errors|artifacts] [--seed=N] [--chars=N] [--terminal=stop|length|abort|error] [--delay-ms=min:max] [--stagger-ms=min:max] [--actions=N] [--no-approval] [--allow-abort] [--allow-error]", + "stream-tools ... [common options]", + "Notes: stream enables approval requests by default; approval-tagged tools emit a separate Matrix approval event with reaction options.", + }, "\n") +} + +func defaultCommonOptions() commonCommandOptions { + return commonCommandOptions{ + DelayMin: 30 * time.Millisecond, + DelayMax: 150 * time.Millisecond, + ChunkMin: defaultChunkMin, + ChunkMax: defaultChunkMax, + FinishReason: agui.FinishReasonStop, + } +} + +func parseLoremCommand(tokens []string) (*loremCommand, error) { + if len(tokens) == 0 { + return nil, fmt.Errorf("text stream requires a character count") + } + count, err := parsePositiveInt(tokens[0], "character count") + if err != nil { + return nil, err + } + if err := validateMaxIntValue(count, maxDemoChars, "character count"); err != nil { + return nil, err + } + opts, err := parseCommonOptions(tokens[1:]) + if err != nil { + return nil, err + } + return &loremCommand{Chars: count, Options: opts}, nil +} + +func parseToolsCommand(tokens []string) (*toolsCommand, error) { + if len(tokens) < 2 { + return nil, fmt.Errorf("stream-tools requires a character count and at least one tool") + } + count, err := parsePositiveInt(tokens[0], "character count") + if err != nil { + return nil, err + } + if err := validateMaxIntValue(count, maxDemoChars, "character count"); err != nil { + return nil, err + } + var toolTokens, optTokens []string + for _, token := range tokens[1:] { + if strings.HasPrefix(token, "--") { + optTokens = append(optTokens, token) + } else { + toolTokens = append(toolTokens, token) + } + } + if len(toolTokens) == 0 { + return nil, fmt.Errorf("stream-tools requires at least one tool spec") + } + if err := validateMaxIntValue(len(toolTokens), maxDemoToolSpecs, "tool spec count"); err != nil { + return nil, err + } + opts, err := parseCommonOptions(optTokens) + if err != nil { + return nil, err + } + tools := make([]toolSpec, 0, len(toolTokens)) + for idx, token := range toolTokens { + spec, err := parseToolSpec(token, idx) + if err != nil { + return nil, err + } + tools = append(tools, spec) + } + return &toolsCommand{Chars: count, Tools: tools, Options: opts}, nil +} + +func parseRandomCommand(tokens []string) (*randomCommand, error) { + cmd := &randomCommand{ + Duration: 20 * time.Second, + Actions: 20, + DelayMin: 350 * time.Millisecond, + DelayMax: 1150 * time.Millisecond, + Runs: 1, + StaggerMin: 150 * time.Millisecond, + StaggerMax: 900 * time.Millisecond, + sharedStreamOptions: sharedStreamOptions{Profile: "balanced"}, + } + return parseStreamLikeCommand(tokens, cmd, false) +} + +func parseStreamCommand(tokens []string) (*randomCommand, error) { + cmd := &randomCommand{ + Duration: 20 * time.Second, + DelayMin: 350 * time.Millisecond, + DelayMax: 1150 * time.Millisecond, + Runs: 1, + StaggerMin: 150 * time.Millisecond, + StaggerMax: 900 * time.Millisecond, + sharedStreamOptions: sharedStreamOptions{Profile: "balanced", AllowApproval: true}, + } + return parseStreamLikeCommand(tokens, cmd, true) +} + +func parseStreamLikeCommand(tokens []string, cmd *randomCommand, deriveActions bool) (*randomCommand, error) { + rest := tokens + if len(rest) > 0 && !strings.HasPrefix(rest[0], "--") { + seconds, err := parsePositiveInt(rest[0], "duration") + if err != nil { + return nil, err + } + if err := validateMaxIntValue(seconds, int(maxDemoDuration/time.Second), "duration seconds"); err != nil { + return nil, err + } + cmd.Duration = time.Duration(seconds) * time.Second + rest = rest[1:] + } + if deriveActions && cmd.Actions == 0 { + cmd.Actions = max(3, min(maxDemoRandomActions, int(cmd.Duration/time.Second)*2)) + } + for _, token := range rest { + key, value, hasValue := parseOptionToken(token) + switch key { + case "actions": + n, err := parseValidatedInt(value, hasValue, token, "actions", maxDemoRandomActions, false) + if err != nil { + return nil, err + } + cmd.Actions = n + case "chars": + n, err := parseValidatedInt(value, hasValue, token, "character count", maxDemoChars, false) + if err != nil { + return nil, err + } + cmd.Chars = n + case "delay-ms": + minDelay, maxDelay, err := parseDurationRangeMS(value, hasValue, token) + if err != nil { + return nil, err + } + cmd.DelayMin, cmd.DelayMax = minDelay, maxDelay + case "terminal": + if !hasValue { + return nil, fmt.Errorf("%s requires a value", token) + } + switch strings.ToLower(value) { + case "stop", "finish": + cmd.Terminal = "finish" + case "abort", "error": + cmd.Terminal = strings.ToLower(value) + case "length", "tool-calls", "content-filter", "other": + cmd.Terminal = agui.NormalizeFinishReason(value) + default: + return nil, fmt.Errorf("unknown terminal %q", value) + } + case "runs": + n, err := parseValidatedInt(value, hasValue, token, "run count", maxDemoChaosRuns, false) + if err != nil { + return nil, err + } + cmd.Runs = n + case "stagger-ms": + minDelay, maxDelay, err := parseDurationRange(value, hasValue, token, "stagger-ms", maxDemoStagger) + if err != nil { + return nil, err + } + cmd.StaggerMin, cmd.StaggerMax = minDelay, maxDelay + case "no-approval": + cmd.AllowApproval = false + default: + handled, err := parseSharedStreamOption(key, value, hasValue, token, &cmd.sharedStreamOptions) + if err != nil || !handled { + if err != nil { + return nil, err + } + return nil, fmt.Errorf("unknown stream option %q", token) + } + } + } + return cmd, nil +} + +func parseChaosCommand(tokens []string) (*chaosCommand, error) { + cmd := &chaosCommand{ + Runs: 3, + Duration: 10 * time.Second, + StaggerMin: 150 * time.Millisecond, + StaggerMax: 900 * time.Millisecond, + MaxActions: 10, + sharedStreamOptions: sharedStreamOptions{Profile: "balanced"}, + } + rest := tokens + if len(rest) > 0 && !strings.HasPrefix(rest[0], "--") { + n, err := parsePositiveInt(rest[0], "run count") + if err != nil { + return nil, err + } + if err := validateMaxIntValue(n, maxDemoChaosRuns, "run count"); err != nil { + return nil, err + } + cmd.Runs = n + rest = rest[1:] + } + if len(rest) > 0 && !strings.HasPrefix(rest[0], "--") { + seconds, err := parsePositiveInt(rest[0], "duration") + if err != nil { + return nil, err + } + if err := validateMaxIntValue(seconds, int(maxDemoDuration/time.Second), "duration seconds"); err != nil { + return nil, err + } + cmd.Duration = time.Duration(seconds) * time.Second + rest = rest[1:] + } + for _, token := range rest { + key, value, hasValue := parseOptionToken(token) + switch key { + case "stagger-ms": + minDelay, maxDelay, err := parseDurationRange(value, hasValue, token, "stagger-ms", maxDemoStagger) + if err != nil { + return nil, err + } + cmd.StaggerMin, cmd.StaggerMax = minDelay, maxDelay + case "max-actions": + n, err := parseValidatedInt(value, hasValue, token, "max-actions", maxDemoChaosActions, false) + if err != nil { + return nil, err + } + cmd.MaxActions = n + default: + handled, err := parseSharedStreamOption(key, value, hasValue, token, &cmd.sharedStreamOptions) + if err != nil || !handled { + if err != nil { + return nil, err + } + return nil, fmt.Errorf("unknown chaos option %q", token) + } + } + } + return cmd, nil +} + +func parseCommonOptions(tokens []string) (commonCommandOptions, error) { + opts := defaultCommonOptions() + for _, token := range tokens { + key, value, hasValue := parseOptionToken(token) + switch key { + case "reasoning": + n, err := parseValidatedInt(value, hasValue, token, "reasoning", maxDemoReasoningChars, true) + if err != nil { + return opts, err + } + opts.ReasoningChars = n + case "steps": + n, err := parseValidatedInt(value, hasValue, token, "steps", maxDemoSteps, false) + if err != nil { + return opts, err + } + opts.Steps = n + case "sources": + n, err := parseValidatedInt(value, hasValue, token, "sources", maxDemoCollections, true) + if err != nil { + return opts, err + } + opts.Sources = n + case "documents": + n, err := parseValidatedInt(value, hasValue, token, "documents", maxDemoCollections, true) + if err != nil { + return opts, err + } + opts.Documents = n + case "files": + n, err := parseValidatedInt(value, hasValue, token, "files", maxDemoCollections, true) + if err != nil { + return opts, err + } + opts.Files = n + case "meta": + opts.Meta = true + case "data": + if !hasValue { + return opts, fmt.Errorf("%s requires a value", token) + } + opts.DataName = value + case "data-transient": + if !hasValue { + return opts, fmt.Errorf("%s requires a value", token) + } + opts.DataTransientName = value + case "delay-ms": + minDelay, maxDelay, err := parseDurationRangeMS(value, hasValue, token) + if err != nil { + return opts, err + } + opts.DelayMin, opts.DelayMax = minDelay, maxDelay + case "chunk-chars": + minChunk, maxChunk, err := parseIntRangeOption(value, hasValue, token, "chunk-chars", maxDemoChunkChars) + if err != nil { + return opts, err + } + opts.ChunkMin, opts.ChunkMax = minChunk, maxChunk + case "seed": + if !hasValue { + return opts, fmt.Errorf("%s requires a value", token) + } + seed, err := strconv.ParseInt(value, 10, 64) + if err != nil { + return opts, fmt.Errorf("invalid seed %q", value) + } + opts.Seed, opts.SeedSet = seed, true + case "finish": + if !hasValue { + return opts, fmt.Errorf("%s requires a value", token) + } + opts.FinishReason = agui.NormalizeFinishReason(value) + case "abort": + opts.Abort = true + case "error": + opts.Error = true + default: + return opts, fmt.Errorf("unknown option %q", token) + } + } + if opts.Abort && opts.Error { + return opts, fmt.Errorf("--abort and --error cannot be combined") + } + if (opts.Abort || opts.Error) && opts.FinishReason != agui.FinishReasonStop { + return opts, fmt.Errorf("--finish cannot be combined with --abort or --error") + } + return opts, nil +} + +func parseSharedStreamOption(key, value string, hasValue bool, token string, opts *sharedStreamOptions) (bool, error) { + switch key { + case "profile": + if !hasValue { + return false, fmt.Errorf("%s requires a value", token) + } + switch strings.ToLower(value) { + case "balanced", "tools", "errors", "artifacts": + opts.Profile = strings.ToLower(value) + default: + return false, fmt.Errorf("unknown profile %q", value) + } + case "seed": + if !hasValue { + return false, fmt.Errorf("%s requires a value", token) + } + seed, err := strconv.ParseInt(value, 10, 64) + if err != nil { + return false, fmt.Errorf("invalid seed %q", value) + } + opts.Seed, opts.SeedSet = seed, true + case "allow-abort": + opts.AllowAbort = true + case "allow-error": + opts.AllowError = true + default: + return false, nil + } + return true, nil +} + +func parseToolSpec(raw string, idx int) (toolSpec, error) { + parts := strings.Split(raw, "#") + spec := toolSpec{Name: strings.TrimSpace(parts[0]), SequenceIndex: idx + 1} + if spec.Name == "" { + return spec, fmt.Errorf("tool spec %q is missing a tool name", raw) + } + for _, tag := range parts[1:] { + tag = strings.TrimSpace(strings.ToLower(tag)) + if tag == "" { + continue + } + spec.Tags = append(spec.Tags, tag) + switch tag { + case "fail": + spec.Fail = true + case "approval": + spec.Approval = true + case "deny": + spec.Deny = true + case "delta": + spec.Delta = true + case "inputerror": + spec.InputError = true + case "prelim": + spec.Preliminary = true + case "provider": + spec.Provider = true + default: + return spec, fmt.Errorf("unknown tool tag %q in %q", tag, raw) + } + } + finalStates := 0 + for _, enabled := range []bool{spec.Fail, spec.Approval, spec.Deny} { + if enabled { + finalStates++ + } + } + if finalStates > 1 { + return spec, fmt.Errorf("tool spec %q has conflicting final state tags", raw) + } + return spec, nil +} diff --git a/pkg/connector/ai_parse_helpers.go b/pkg/connector/ai_parse_helpers.go new file mode 100644 index 0000000..8ce97dd --- /dev/null +++ b/pkg/connector/ai_parse_helpers.go @@ -0,0 +1,177 @@ +package connector + +import ( + "fmt" + "math/rand" + "strconv" + "strings" + "time" + + "github.com/beeper/ai-bridge/pkg/ai-stream" +) + +func parseOptionToken(token string) (string, string, bool) { + trimmed := strings.TrimPrefix(strings.TrimSpace(token), "--") + key, value, ok := strings.Cut(trimmed, "=") + return strings.ToLower(strings.TrimSpace(key)), strings.TrimSpace(value), ok +} + +func parseValidatedInt(value string, hasValue bool, token, label string, maxValue int, allowZero bool) (int, error) { + if !hasValue { + return 0, fmt.Errorf("%s requires a value", token) + } + var n int + var err error + if allowZero { + n, err = parseNonNegativeInt(value, label) + } else { + n, err = parsePositiveInt(value, label) + } + if err != nil { + return 0, err + } + return n, validateMaxIntValue(n, maxValue, label) +} + +func parsePositiveInt(raw, label string) (int, error) { + n, err := strconv.Atoi(strings.TrimSpace(raw)) + if err != nil || n <= 0 { + return 0, fmt.Errorf("invalid %s %q", label, raw) + } + return n, nil +} + +func parseNonNegativeInt(raw, label string) (int, error) { + n, err := strconv.Atoi(strings.TrimSpace(raw)) + if err != nil || n < 0 { + return 0, fmt.Errorf("invalid %s %q", label, raw) + } + return n, nil +} + +func parseDurationRangeMS(value string, hasValue bool, token string) (time.Duration, time.Duration, error) { + return parseDurationRange(value, hasValue, token, "delay-ms", maxDemoDelay) +} + +func parseDurationRange(value string, hasValue bool, token, label string, maxValue time.Duration) (time.Duration, time.Duration, error) { + minValue, maxRange, err := parseIntRangeOption(value, hasValue, token, label, int(maxValue/time.Millisecond)) + if err != nil { + return 0, 0, err + } + return time.Duration(minValue) * time.Millisecond, time.Duration(maxRange) * time.Millisecond, nil +} + +func parseIntRangeOption(value string, hasValue bool, token, label string, maxValue int) (int, int, error) { + if !hasValue { + return 0, 0, fmt.Errorf("%s requires a value", token) + } + minValue, maxRange, ok := strings.Cut(value, ":") + if !ok { + n, err := parseNonNegativeInt(value, label) + if err != nil { + return 0, 0, err + } + if err := validateMaxIntValue(n, maxValue, label); err != nil { + return 0, 0, err + } + return n, n, nil + } + minInt, err := parseNonNegativeInt(minValue, label) + if err != nil { + return 0, 0, err + } + maxInt, err := parseNonNegativeInt(maxRange, label) + if err != nil { + return 0, 0, err + } + if maxInt < minInt { + return 0, 0, fmt.Errorf("invalid %s range %q", label, value) + } + if err := validateMaxIntValue(maxInt, maxValue, label); err != nil { + return 0, 0, err + } + return minInt, maxInt, nil +} + +func validateMaxIntValue(value, maxValue int, label string) error { + if value > maxValue { + return fmt.Errorf("%s %d exceeds the maximum of %d", label, value, maxValue) + } + return nil +} + +func rngForOptions(seedSet bool, seed, fallback int64) *rand.Rand { + if !seedSet { + seed = fallback + } + return rand.New(rand.NewSource(seed)) +} + +func chunkText(text string, rng *rand.Rand, minChunk, maxChunk int) []string { + if strings.TrimSpace(text) == "" { + return nil + } + if minChunk <= 0 { + minChunk = defaultChunkMin + } + if maxChunk < minChunk { + maxChunk = minChunk + } + var chunks []string + for len(text) > 0 { + size := minChunk + if maxChunk > minChunk { + size += rng.Intn(maxChunk - minChunk + 1) + } + if size > len(text) { + size = len(text) + } + parts := aistream.SplitTextUTF8(text, size) + chunk := parts[0] + chunks = append(chunks, chunk) + text = text[len(chunk):] + } + return chunks +} + +func splitCount(total, parts, index int) int { + if total <= 0 || parts <= 0 || index < 0 || index >= parts { + return 0 + } + base := total / parts + remainder := total % parts + if index < remainder { + return base + 1 + } + return base +} + +func sliceByStep(text string, parts, index int) string { + if parts <= 1 || text == "" { + return text + } + start := 0 + for i := 0; i < index; i++ { + start += splitCount(len(text), parts, i) + } + length := splitCount(len(text), parts, index) + if start >= len(text) || length <= 0 { + return "" + } + end := min(start+length, len(text)) + return text[start:end] +} + +func sanitizeToolName(name string) string { + name = strings.ToLower(strings.TrimSpace(name)) + var out strings.Builder + for _, r := range name { + if r >= 'a' && r <= 'z' || r >= '0' && r <= '9' || r == '_' || r == '-' { + out.WriteRune(r) + } + } + if out.Len() == 0 { + return "tool" + } + return out.String() +} diff --git a/pkg/connector/ai_plans.go b/pkg/connector/ai_plans.go new file mode 100644 index 0000000..ace50f4 --- /dev/null +++ b/pkg/connector/ai_plans.go @@ -0,0 +1,207 @@ +package connector + +import ( + "context" + "errors" + "fmt" + "math/rand" + "strconv" + "strings" + "time" + + "github.com/beeper/ai-bridge/pkg/ag-ui" + "github.com/beeper/ai-bridge/pkg/ai-stream" +) + +// resolveCommandSeed fills in an implicit seed for commands that derive their +// random behavior from the current time, so the continuation can replay the +// exact same sequence. +func resolveCommandSeed(cmd *parsedCommand, now time.Time) { + if cmd == nil { + return + } + switch { + case cmd.Lorem != nil && !cmd.Lorem.Options.SeedSet: + cmd.Lorem.Options.Seed = now.UnixNano() + cmd.Lorem.Options.SeedSet = true + case cmd.Tools != nil && !cmd.Tools.Options.SeedSet: + cmd.Tools.Options.Seed = now.UnixNano() + cmd.Tools.Options.SeedSet = true + case cmd.Random != nil && !cmd.Random.SeedSet: + cmd.Random.Seed = now.UnixNano() + cmd.Random.SeedSet = true + } +} + +// canonicalCommand returns a command string that, when re-parsed, reproduces +// the same run as cmd. If the original input already encoded all randomness +// inputs (e.g. an explicit --seed), it is returned as-is. +func canonicalCommand(input string, cmd *parsedCommand) string { + if cmd == nil { + return input + } + switch { + case cmd.Lorem != nil: + return ensureSeedFlag(input, cmd.Lorem.Options.Seed, cmd.Lorem.Options.SeedSet) + case cmd.Tools != nil: + return ensureSeedFlag(input, cmd.Tools.Options.Seed, cmd.Tools.Options.SeedSet) + case cmd.Random != nil: + return ensureSeedFlag(input, cmd.Random.Seed, cmd.Random.SeedSet) + } + return input +} + +func ensureSeedFlag(input string, seed int64, seedSet bool) string { + if !seedSet || hasSeedFlag(input) { + return input + } + return strings.TrimRight(input, " ") + " --seed=" + strconv.FormatInt(seed, 10) +} + +func hasSeedFlag(input string) bool { + for _, token := range strings.Fields(input) { + if strings.HasPrefix(token, "--seed=") || token == "--seed" { + return true + } + } + return false +} + +func buildAIRunFromCommand(ctx context.Context, runID, threadID string, now time.Time, cmd *parsedCommand, agentID, agentName string) (*aistream.Run, error) { + return buildAIRunFromCommandWithApprovals(ctx, runID, threadID, now, cmd, agentID, agentName, nil) +} + +func buildAIRunFromCommandWithApprovals(ctx context.Context, runID, threadID string, now time.Time, cmd *parsedCommand, agentID, agentName string, approvals map[string]agui.ToolApprovalResponse) (*aistream.Run, error) { + runtime := virtualAIRuntime(now) + run := aistream.NewRun(runID, threadID, aistream.DefaultModel, agentID, agentName, now) + writer := aistream.NewWriter(run, runtime.now) + writer.Start() + + runner := aiRunner{runtime: runtime, approvals: approvals} + var err error + switch { + case cmd == nil || cmd.Name == "help": + writer.Text(helpText()) + writer.Finish(agui.FinishReasonStop) + case cmd.Lorem != nil: + err = runner.runLorem(ctx, writer, *cmd.Lorem) + case cmd.Tools != nil: + err = runner.runTools(ctx, writer, *cmd.Tools) + case cmd.Random != nil: + err = runner.runRandom(ctx, writer, *cmd.Random) + } + if errors.Is(err, errApprovalRequested) { + err = nil + } + if err != nil { + writer.Error(err.Error()) + } else if err = agui.ValidateEventSequence(run.Events); err != nil { + writer.Error(err.Error()) + } + return run, nil +} + +func buildAIChaosRunPlans(ctx context.Context, baseRunID, threadID string, now time.Time, cmd chaosCommand, agentID, agentName string) ([]aiRunPlan, error) { + seed := cmd.Seed + if !cmd.SeedSet { + seed = now.UnixNano() + } + rng := rand.New(rand.NewSource(seed)) + runner := aiRunner{runtime: virtualAIRuntime(now)} + actions := max(3, min(cmd.MaxActions, int(cmd.Duration/time.Second))) + plans := make([]aiRunPlan, 0, cmd.Runs) + var delay time.Duration + for i := range cmd.Runs { + if i > 0 { + delay += runner.sampleDelay(rng, cmd.StaggerMin, cmd.StaggerMax) + } + runID := fmt.Sprintf("%s-%d", baseRunID, i+1) + randomCmd := randomCommand{ + Duration: cmd.Duration, + Actions: actions, + DelayMin: 180 * time.Millisecond, + DelayMax: 900 * time.Millisecond, + sharedStreamOptions: sharedStreamOptions{ + Profile: cmd.Profile, + Seed: seed + int64(i+1)*97, + SeedSet: true, + AllowAbort: cmd.AllowAbort, + AllowError: cmd.AllowError, + AllowApproval: cmd.AllowApproval, + }, + } + parsed := &parsedCommand{Name: "stream", Random: &randomCmd} + run, err := buildAIRunFromCommand(ctx, runID, threadID, now.Add(delay), parsed, agentID, agentName) + if err != nil { + return nil, err + } + plans = append(plans, aiRunPlan{ + Run: run, + Delay: delay, + EffectiveCommand: chaosSubRunCommand(randomCmd), + }) + } + return plans, nil +} + +func buildAIStreamRunPlans(ctx context.Context, baseRunID, threadID string, now time.Time, cmd randomCommand, agentID, agentName string) ([]aiRunPlan, error) { + seed := cmd.Seed + if !cmd.SeedSet { + seed = now.UnixNano() + } + rng := rand.New(rand.NewSource(seed)) + plans := make([]aiRunPlan, 0, cmd.Runs) + runner := aiRunner{runtime: virtualAIRuntime(now)} + var delay time.Duration + for i := range cmd.Runs { + if i > 0 { + delay += runner.sampleDelay(rng, cmd.StaggerMin, cmd.StaggerMax) + } + child := cmd + child.Runs = 1 + child.Seed = seed + int64(i+1)*97 + child.SeedSet = true + parsed := &parsedCommand{Name: "stream", Random: &child} + run, err := buildAIRunFromCommand(ctx, fmt.Sprintf("%s-%d", baseRunID, i+1), threadID, now.Add(delay), parsed, agentID, agentName) + if err != nil { + return nil, err + } + plans = append(plans, aiRunPlan{ + Run: run, + Delay: delay, + EffectiveCommand: streamSubRunCommand(child), + }) + } + return plans, nil +} + +func chaosSubRunCommand(cmd randomCommand) string { + return streamSubRunCommand(cmd) +} + +func streamSubRunCommand(cmd randomCommand) string { + parts := []string{ + "stream", + strconv.Itoa(int(cmd.Duration / time.Second)), + "--actions=" + strconv.Itoa(cmd.Actions), + "--delay-ms=" + strconv.Itoa(int(cmd.DelayMin/time.Millisecond)) + ":" + strconv.Itoa(int(cmd.DelayMax/time.Millisecond)), + "--profile=" + cmd.Profile, + "--seed=" + strconv.FormatInt(cmd.Seed, 10), + } + if cmd.Chars > 0 { + parts = append(parts, "--chars="+strconv.Itoa(cmd.Chars)) + } + if cmd.Terminal != "" { + parts = append(parts, "--terminal="+cmd.Terminal) + } + if !cmd.AllowApproval { + parts = append(parts, "--no-approval") + } + if cmd.AllowAbort { + parts = append(parts, "--allow-abort") + } + if cmd.AllowError { + parts = append(parts, "--allow-error") + } + return strings.Join(parts, " ") +} diff --git a/pkg/connector/ai_runner.go b/pkg/connector/ai_runner.go new file mode 100644 index 0000000..6fc497c --- /dev/null +++ b/pkg/connector/ai_runner.go @@ -0,0 +1,415 @@ +package connector + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "math/rand" + "sort" + "strings" + "time" + + "github.com/beeper/ai-bridge/pkg/ag-ui" + "github.com/beeper/ai-bridge/pkg/ai-stream" +) + +func (r aiRunner) runLorem(ctx context.Context, w *aistream.Writer, cmd loremCommand) error { + opts := cmd.Options + rng := rngForOptions(opts.SeedSet, opts.Seed, r.runtime.now().UnixNano()) + steps := max(opts.Steps, 1) + text := buildDemoVisibleText(cmd.Chars, rand.New(rand.NewSource(rng.Int63()))) + reasoning := buildLoremText(opts.ReasoningChars, rand.New(rand.NewSource(rng.Int63()))) + for step := range steps { + if opts.Steps > 0 { + w.StepStart(fmt.Sprintf("step-%d", step+1)) + } + emitDecorations(w, opts, cmd.Chars, step, steps) + if reasoning != "" { + w.Thinking(sliceByStep(reasoning, steps, step)) + } + for _, chunk := range chunkText(sliceByStep(text, steps, step), rng, opts.ChunkMin, opts.ChunkMax) { + w.Text(chunk) + if err := r.runtime.sleep(ctx, r.sampleDelay(rng, opts.DelayMin, opts.DelayMax)); err != nil { + return err + } + } + if opts.Steps > 0 { + w.StepFinish(fmt.Sprintf("step-%d", step+1)) + } + } + finishWriter(w, opts) + return nil +} + +func (r aiRunner) runTools(ctx context.Context, w *aistream.Writer, cmd toolsCommand) error { + opts := cmd.Options + rng := rngForOptions(opts.SeedSet, opts.Seed, r.runtime.now().UnixNano()) + phaseCount := max(len(cmd.Tools)+1, max(opts.Steps, 1)) + text := buildDemoVisibleText(cmd.Chars, rand.New(rand.NewSource(rng.Int63()))) + reasoning := buildLoremText(opts.ReasoningChars, rand.New(rand.NewSource(rng.Int63()))) + for phase := range phaseCount { + w.StepStart(fmt.Sprintf("phase-%d", phase+1)) + emitDecorations(w, opts, cmd.Chars, phase, phaseCount) + if reasoning != "" { + w.Thinking(sliceByStep(reasoning, phaseCount, phase)) + } + for _, chunk := range chunkText(sliceByStep(text, phaseCount, phase), rng, opts.ChunkMin, opts.ChunkMax) { + w.Text(chunk) + } + if phase < len(cmd.Tools) { + if err := r.runToolSpec(ctx, w, cmd.Tools[phase], rng, opts); err != nil { + if errors.Is(err, errApprovalRequested) { + w.StepFinish(fmt.Sprintf("phase-%d", phase+1)) + } + return err + } + } + w.StepFinish(fmt.Sprintf("phase-%d", phase+1)) + } + finishWriter(w, opts) + return nil +} + +func (r aiRunner) runRandom(ctx context.Context, w *aistream.Writer, cmd randomCommand) error { + seed := cmd.Seed + if !cmd.SeedSet { + seed = r.runtime.now().UnixNano() + } + rng := rand.New(rand.NewSource(seed)) + started := r.runtime.now() + var deadline time.Time + if cmd.Duration > 0 { + deadline = started.Add(cmd.Duration) + } + stepOpen := false + stepName := "" + actionOptions, actionWeightTotal := buildRandomActionOptions(cmd) + if cmd.Chars > 0 { + text := buildDemoVisibleText(cmd.Chars, rand.New(rand.NewSource(rng.Int63()))) + for _, chunk := range chunkText(text, rng, defaultChunkMin, defaultChunkMax) { + w.Text(chunk) + if err := r.runtime.sleep(ctx, r.sampleDelay(rng, cmd.DelayMin, cmd.DelayMax)); err != nil { + return err + } + } + } + approvalRequested := false + handleTool := func(spec toolSpec) error { + if err := r.runToolSpec(ctx, w, spec, rng, defaultCommonOptions()); err != nil { + if spec.Approval { + approvalRequested = true + } + if errors.Is(err, errApprovalRequested) && stepOpen { + w.StepFinish(stepName) + stepOpen = false + stepName = "" + } + return err + } + if spec.Approval { + approvalRequested = true + } + return nil + } + for action := range cmd.Actions { + if !deadline.IsZero() && !r.runtime.now().Before(deadline) { + break + } + if action > 0 { + delay := r.sampleDelay(rng, cmd.DelayMin, cmd.DelayMax) + if !deadline.IsZero() && r.runtime.now().Add(delay).After(deadline) { + delay = deadline.Sub(r.runtime.now()) + } + if err := r.runtime.sleep(ctx, delay); err != nil { + return err + } + if !deadline.IsZero() && !r.runtime.now().Before(deadline) { + break + } + } + switch pickWeighted(actionOptions, actionWeightTotal, rng) { + case randomActionText: + text := "\n\n" + buildDemoVisibleText(40+rng.Intn(160), rand.New(rand.NewSource(rng.Int63()))) + for _, chunk := range chunkText(text, rng, defaultChunkMin, defaultChunkMax) { + w.Text(chunk) + } + case randomActionThinking: + w.Thinking(buildLoremText(30+rng.Intn(120), rand.New(rand.NewSource(rng.Int63())))) + case randomActionStep: + if stepOpen { + w.StepFinish(stepName) + stepOpen = false + stepName = "" + } else { + stepName = fmt.Sprintf("random-step-%d", action+1) + w.StepStart(stepName) + stepOpen = true + } + case randomActionTool: + if cmd.AllowApproval && cmd.Profile == "balanced" && action >= 10 && !approvalRequested { + if err := handleTool(toolSpec{Name: randomToolName(rng), Approval: true, SequenceIndex: action + 1}); err != nil { + return err + } + continue + } + if err := handleTool(toolSpec{Name: randomToolName(rng), SequenceIndex: action + 1}); err != nil { + return err + } + case randomActionToolFail: + if err := handleTool(toolSpec{Name: randomToolName(rng), Fail: true, SequenceIndex: action + 1}); err != nil { + return err + } + case randomActionToolDeny: + if err := handleTool(toolSpec{Name: randomToolName(rng), Deny: true, SequenceIndex: action + 1}); err != nil { + return err + } + case randomActionToolApproval: + if err := handleTool(toolSpec{Name: randomToolName(rng), Approval: true, SequenceIndex: action + 1}); err != nil { + return err + } + case randomActionSource: + sourceID := fmt.Sprintf("random-source-%d", action+1) + w.Custom("com.beeper.source", map[string]any{"sourceId": sourceID, "url": fmt.Sprintf("https://dummybridge.local/random/source/%d", action+1), "title": fmt.Sprintf("Random Source %d", action+1)}) + case randomActionDocument: + w.Custom("com.beeper.document", map[string]any{"id": fmt.Sprintf("random-doc-%d", action+1), "title": fmt.Sprintf("Random Document %d", action+1), "mediaType": "text/plain"}) + case randomActionFile: + w.Custom("com.beeper.file", map[string]any{"url": fmt.Sprintf("mxc://dummybridge/random-file-%d", action+1), "mediaType": "application/octet-stream"}) + case randomActionMetadata: + w.StateDelta(statePatch(map[string]any{"command": "stream", "seed": seed, "action": action + 1, "profile": cmd.Profile})) + case randomActionData: + w.Custom("com.beeper.data", map[string]any{"name": "random", "value": map[string]any{"action": action + 1, "seed": seed}}) + case randomActionDataTransient: + w.Custom("com.beeper.data.transient", map[string]any{"name": "random", "value": map[string]any{"action": action + 1, "seed": seed}}) + } + } + if stepOpen { + w.StepFinish(stepName) + } + terminal := chooseRandomTerminal(cmd, rng) + switch terminal { + case "abort": + w.Abort("DummyBridge random mode aborted") + case "error": + w.Error("DummyBridge random mode failed") + case agui.FinishReasonLength, agui.FinishReasonToolCalls, agui.FinishReasonContentFilter, agui.FinishReasonOther: + w.Finish(terminal) + default: + w.Finish(agui.FinishReasonStop) + } + return nil +} + +func (r aiRunner) runToolSpec(ctx context.Context, w *aistream.Writer, spec toolSpec, rng *rand.Rand, opts commonCommandOptions) error { + toolCallID := fmt.Sprintf("dummy-tool-%d-%s", spec.SequenceIndex, sanitizeToolName(spec.Name)) + input := toolRequestInput(spec) + approvalID := approvalIDForRun(w.Run.RunID, toolCallID) + var approval *agui.ToolApproval + if spec.Approval { + approval = &agui.ToolApproval{ID: approvalID, NeedsApproval: true} + } + displayMetadata := toolDisplayMetadata(spec.Name) + w.ToolStartWithMetadata(toolCallID, spec.Name, spec.SequenceIndex-1, approval, displayMetadata) + annotateProviderRawEvent(w, spec, "tool_call_start") + if spec.InputError { + if encodedInput := jsonToolInput(input); encodedInput != "" { + w.ToolArgs(toolCallID, encodedInput, nil) + annotateProviderRawEvent(w, spec, "tool_call_args") + } + w.ToolError(toolCallID, spec.Name, input, "input-error") + annotateProviderRawEvent(w, spec, "tool_call_error") + return nil + } + if spec.Delta { + if encodedInput := jsonToolInput(input); encodedInput != "" { + for _, chunk := range chunkText(encodedInput, rng, opts.ChunkMin, opts.ChunkMax) { + w.ToolArgs(toolCallID, chunk, nil) + annotateProviderRawEvent(w, spec, "tool_call_args") + if err := r.runtime.sleep(ctx, r.sampleDelay(rng, opts.DelayMin, opts.DelayMax)); err != nil { + return err + } + } + } + } else { + if encodedInput := jsonToolInput(input); encodedInput != "" { + w.ToolArgs(toolCallID, encodedInput, encodedInput) + annotateProviderRawEvent(w, spec, "tool_call_args") + } + } + if spec.Preliminary { + w.ToolResult(toolCallID, fmt.Sprintf(`{"state":%q}`, agui.ToolResultStateStreaming), agui.ToolResultStateStreaming) + annotateProviderRawEvent(w, spec, "tool_call_result") + } + switch { + case spec.Approval: + if response, ok := r.approvals[approvalID]; ok { + if response.ID == "" { + response.ID = approvalID + } + w.ToolApprovalResponded(toolCallID, spec.Name, input, response) + annotateProviderRawEvent(w, spec, "approval_responded") + if !response.Approved { + return errApprovalDenied + } + return nil + } + w.ToolApprovalInputComplete(toolCallID, spec.Name, input) + annotateProviderRawEvent(w, spec, "tool_call_input_complete") + w.ToolApprovalRequestedWithMetadata(toolCallID, spec.Name, input, *approval, displayMetadata) + annotateProviderRawEvent(w, spec, "approval_requested") + return errApprovalRequested + case spec.Deny: + w.ToolDenied(toolCallID, spec.Name, input, approvalID, "denied") + annotateProviderRawEvent(w, spec, "tool_call_denied") + case spec.Fail: + w.ToolError(toolCallID, spec.Name, input, "DummyBridge synthetic tool failure") + annotateProviderRawEvent(w, spec, "tool_call_error") + default: + w.ToolEnd(toolCallID, spec.Name, input, nil) + annotateProviderRawEvent(w, spec, "tool_call_end") + } + return nil +} + +func toolRequestInput(spec toolSpec) any { + return nil +} + +func toolDisplayMetadata(name string) map[string]any { + type ToolProviderMetadata struct { + ID string `json:"id,omitempty"` + DisplayName string `json:"displayName,omitempty"` + IconURL string `json:"iconUrl,omitempty"` + } + type ToolDisplayMetadata struct { + DisplayName string `json:"displayName,omitempty"` + Description string `json:"description,omitempty"` + IconURL string `json:"iconUrl,omitempty"` + Provider *ToolProviderMetadata `json:"provider,omitempty"` + } + + metadata := ToolDisplayMetadata{} + switch strings.ToLower(name) { + case "calendar.get_events", "google_calendar.get_events", "google-calendar.get-events": + metadata.DisplayName = "List Calendar Events" + metadata.Provider = &ToolProviderMetadata{ + ID: "google-calendar", + DisplayName: "Google Calendar", + } + case "linear.list_issues", "linear.list-issues", "list_issues", "list-issues": + metadata.DisplayName = "List Issues" + metadata.Provider = &ToolProviderMetadata{ + ID: "linear", + DisplayName: "Linear", + } + case "shell": + metadata.DisplayName = "Run Command" + case "fetch": + metadata.DisplayName = "Fetch Web" + } + return compactJSONMap(metadata) +} + +func compactJSONMap(value any) map[string]any { + raw, err := json.Marshal(value) + if err != nil { + return nil + } + var out map[string]any + if err := json.Unmarshal(raw, &out); err != nil || len(out) == 0 { + return nil + } + return out +} + +func approvalIDForRun(runID, toolCallID string) string { + return "approval-" + runID + "-" + toolCallID +} + +func annotateProviderRawEvent(w *aistream.Writer, spec toolSpec, stage string) { + if !spec.Provider || w == nil || w.Run == nil || len(w.Run.Events) == 0 { + return + } + w.Run.Events[len(w.Run.Events)-1]["rawEvent"] = map[string]any{ + "provider": "dummybridge", + "stage": stage, + "tool": spec.Name, + "sequence": spec.SequenceIndex, + "tags": spec.Tags, + } +} + +func jsonToolInput(input any) string { + if input == nil { + return "" + } + if inputMap, ok := input.(map[string]any); ok && len(inputMap) == 0 { + return "" + } + raw, err := json.Marshal(input) + if err != nil { + return "" + } + return string(raw) +} + +func finishWriter(w *aistream.Writer, opts commonCommandOptions) { + switch { + case opts.Abort: + w.Abort("DummyBridge synthetic abort") + case opts.Error: + w.Error("DummyBridge synthetic error") + default: + w.Finish(opts.FinishReason) + } +} + +func emitDecorations(w *aistream.Writer, opts commonCommandOptions, chars, step, steps int) { + if opts.Meta { + seed := opts.Seed + if !opts.SeedSet { + seed = int64(chars) + } + w.StateDelta(statePatch(map[string]any{"command": "demo", "seed": seed, "step": step + 1})) + } + for i := range splitCount(opts.Sources, steps, step) { + sourceID := fmt.Sprintf("demo-source-%d-%d", step+1, i+1) + w.Custom("com.beeper.source", map[string]any{"sourceId": sourceID, "url": fmt.Sprintf("https://dummybridge.local/source/%d-%d", step+1, i+1), "title": fmt.Sprintf("Demo Source %d.%d", step+1, i+1)}) + } + for i := range splitCount(opts.Documents, steps, step) { + w.Custom("com.beeper.document", map[string]any{"id": fmt.Sprintf("demo-doc-%d-%d", step+1, i+1), "title": fmt.Sprintf("Demo Document %d.%d", step+1, i+1), "mediaType": "text/plain"}) + } + for i := range splitCount(opts.Files, steps, step) { + w.Custom("com.beeper.file", map[string]any{"url": fmt.Sprintf("mxc://dummybridge/demo-file-%d-%d", step+1, i+1), "mediaType": "application/octet-stream"}) + } + if step == 0 && opts.DataName != "" { + w.Custom("com.beeper.data", map[string]any{"name": opts.DataName, "value": map[string]any{"mode": "persistent", "stage": step + 1}}) + } + if step == 0 && opts.DataTransientName != "" { + w.Custom("com.beeper.data.transient", map[string]any{"name": opts.DataTransientName, "value": map[string]any{"mode": "transient", "stage": step + 1}}) + } +} + +func statePatch(values map[string]any) []map[string]any { + keys := make([]string, 0, len(values)) + for key := range values { + keys = append(keys, key) + } + sort.Strings(keys) + patch := make([]map[string]any, 0, len(keys)) + for _, key := range keys { + patch = append(patch, map[string]any{ + "op": "add", + "path": "/" + key, + "value": values[key], + }) + } + return patch +} + +func (r aiRunner) sampleDelay(rng *rand.Rand, minDelay, maxDelay time.Duration) time.Duration { + if maxDelay <= minDelay { + return minDelay + } + return minDelay + time.Duration(rng.Int63n(int64(maxDelay-minDelay)+1)) +} diff --git a/pkg/connector/ai_runtime.go b/pkg/connector/ai_runtime.go index e84aae9..749667a 100644 --- a/pkg/connector/ai_runtime.go +++ b/pkg/connector/ai_runtime.go @@ -2,166 +2,13 @@ package connector import ( "context" - "encoding/json" - "errors" "fmt" - "math/rand" - "sort" - "strconv" - "strings" "time" "github.com/beeper/ai-bridge/pkg/ag-ui" "github.com/beeper/ai-bridge/pkg/ai-stream" - "go.mau.fi/util/shlex" ) -var ( - errApprovalRequested = errors.New("approval requested") - errApprovalDenied = errors.New("approval denied") -) - -const ( - defaultChunkMin = 24 - defaultChunkMax = 96 - maxDemoChars = 96 * 1024 - maxDemoReasoningChars = 8192 - maxDemoToolSpecs = 16 - maxDemoSteps = 32 - maxDemoCollections = 16 - maxDemoRandomActions = 64 - maxDemoChaosRuns = 16 - maxDemoChaosActions = 64 - maxDemoDuration = 5 * time.Minute - maxDemoDelay = 30 * time.Second - maxDemoChunkChars = 512 - maxDemoStagger = 30 * time.Second -) - -const ( - randomActionText = "text" - randomActionThinking = "thinking" - randomActionStep = "step" - randomActionTool = "tool" - randomActionToolFail = "tool_fail" - randomActionToolDeny = "tool_deny" - randomActionToolApproval = "tool_approval" - randomActionSource = "source" - randomActionDocument = "document" - randomActionFile = "file" - randomActionMetadata = "metadata" - randomActionData = "data" - randomActionDataTransient = "data_transient" -) - -type commonCommandOptions struct { - ReasoningChars int - Steps int - Sources int - Documents int - Files int - Meta bool - DataName string - DataTransientName string - DelayMin time.Duration - DelayMax time.Duration - ChunkMin int - ChunkMax int - FinishReason string - Abort bool - Error bool - Seed int64 - SeedSet bool -} - -type loremCommand struct { - Chars int - Options commonCommandOptions -} - -type toolSpec struct { - Name string - Tags []string - Fail bool - Approval bool - Deny bool - Delta bool - InputError bool - Preliminary bool - Provider bool - SequenceIndex int -} - -type toolsCommand struct { - Chars int - Tools []toolSpec - Options commonCommandOptions -} - -type sharedStreamOptions struct { - Profile string - Seed int64 - SeedSet bool - AllowAbort bool - AllowError bool - AllowApproval bool -} - -type randomCommand struct { - Duration time.Duration - Actions int - Chars int - DelayMin time.Duration - DelayMax time.Duration - Terminal string - Runs int - StaggerMin time.Duration - StaggerMax time.Duration - sharedStreamOptions -} - -type randomActionOption struct { - name string - weight int -} - -type chaosCommand struct { - Runs int - Duration time.Duration - StaggerMin time.Duration - StaggerMax time.Duration - MaxActions int - sharedStreamOptions -} - -type parsedCommand struct { - Name string - Lorem *loremCommand - Tools *toolsCommand - Random *randomCommand - Chaos *chaosCommand -} - -type aiRuntime struct { - now func() time.Time - sleep func(context.Context, time.Duration) error -} - -type aiRunner struct { - runtime aiRuntime - approvals map[string]agui.ToolApprovalResponse -} - -type aiRunPlan struct { - Run *aistream.Run - Delay time.Duration - // EffectiveCommand is the canonical command form used to deterministically - // replay this run during approval continuation. For random/chaos sub-runs - // (where the seed was derived implicitly) this includes the resolved - // --seed=N so the continuation reproduces the same action sequence. - EffectiveCommand string -} - func virtualAIRuntime(now time.Time) aiRuntime { current := now return aiRuntime{ @@ -214,1191 +61,3 @@ func buildAIRunPlans(ctx context.Context, runID, threadID, input string, now tim } return []aiRunPlan{{Run: run, EffectiveCommand: canonicalCommand(input, cmd)}}, nil } - -// resolveCommandSeed fills in an implicit seed for commands that derive their -// random behavior from the current time, so the continuation can replay the -// exact same sequence. -func resolveCommandSeed(cmd *parsedCommand, now time.Time) { - if cmd == nil { - return - } - switch { - case cmd.Lorem != nil && !cmd.Lorem.Options.SeedSet: - cmd.Lorem.Options.Seed = now.UnixNano() - cmd.Lorem.Options.SeedSet = true - case cmd.Tools != nil && !cmd.Tools.Options.SeedSet: - cmd.Tools.Options.Seed = now.UnixNano() - cmd.Tools.Options.SeedSet = true - case cmd.Random != nil && !cmd.Random.SeedSet: - cmd.Random.Seed = now.UnixNano() - cmd.Random.SeedSet = true - } -} - -// canonicalCommand returns a command string that, when re-parsed, reproduces -// the same run as cmd. If the original input already encoded all randomness -// inputs (e.g. an explicit --seed), it is returned as-is. -func canonicalCommand(input string, cmd *parsedCommand) string { - if cmd == nil { - return input - } - switch { - case cmd.Lorem != nil: - return ensureSeedFlag(input, cmd.Lorem.Options.Seed, cmd.Lorem.Options.SeedSet) - case cmd.Tools != nil: - return ensureSeedFlag(input, cmd.Tools.Options.Seed, cmd.Tools.Options.SeedSet) - case cmd.Random != nil: - return ensureSeedFlag(input, cmd.Random.Seed, cmd.Random.SeedSet) - } - return input -} - -func ensureSeedFlag(input string, seed int64, seedSet bool) string { - if !seedSet || hasSeedFlag(input) { - return input - } - return strings.TrimRight(input, " ") + " --seed=" + strconv.FormatInt(seed, 10) -} - -func hasSeedFlag(input string) bool { - for _, token := range strings.Fields(input) { - if strings.HasPrefix(token, "--seed=") || token == "--seed" { - return true - } - } - return false -} - -func buildAIRunFromCommand(ctx context.Context, runID, threadID string, now time.Time, cmd *parsedCommand, agentID, agentName string) (*aistream.Run, error) { - return buildAIRunFromCommandWithApprovals(ctx, runID, threadID, now, cmd, agentID, agentName, nil) -} - -func buildAIRunFromCommandWithApprovals(ctx context.Context, runID, threadID string, now time.Time, cmd *parsedCommand, agentID, agentName string, approvals map[string]agui.ToolApprovalResponse) (*aistream.Run, error) { - runtime := virtualAIRuntime(now) - run := aistream.NewRun(runID, threadID, aistream.DefaultModel, agentID, agentName, now) - writer := aistream.NewWriter(run, runtime.now) - writer.Start() - - runner := aiRunner{runtime: runtime, approvals: approvals} - var err error - switch { - case cmd == nil || cmd.Name == "help": - writer.Text(helpText()) - writer.Finish(agui.FinishReasonStop) - case cmd.Lorem != nil: - err = runner.runLorem(ctx, writer, *cmd.Lorem) - case cmd.Tools != nil: - err = runner.runTools(ctx, writer, *cmd.Tools) - case cmd.Random != nil: - err = runner.runRandom(ctx, writer, *cmd.Random) - } - if errors.Is(err, errApprovalRequested) { - err = nil - } - if err != nil { - writer.Error(err.Error()) - } else if err = agui.ValidateEventSequence(run.Events); err != nil { - writer.Error(err.Error()) - } - return run, nil -} - -func buildAIChaosRunPlans(ctx context.Context, baseRunID, threadID string, now time.Time, cmd chaosCommand, agentID, agentName string) ([]aiRunPlan, error) { - seed := cmd.Seed - if !cmd.SeedSet { - seed = now.UnixNano() - } - rng := rand.New(rand.NewSource(seed)) - runner := aiRunner{runtime: virtualAIRuntime(now)} - actions := max(3, min(cmd.MaxActions, int(cmd.Duration/time.Second))) - plans := make([]aiRunPlan, 0, cmd.Runs) - var delay time.Duration - for i := range cmd.Runs { - if i > 0 { - delay += runner.sampleDelay(rng, cmd.StaggerMin, cmd.StaggerMax) - } - runID := fmt.Sprintf("%s-%d", baseRunID, i+1) - randomCmd := randomCommand{ - Duration: cmd.Duration, - Actions: actions, - DelayMin: 180 * time.Millisecond, - DelayMax: 900 * time.Millisecond, - sharedStreamOptions: sharedStreamOptions{ - Profile: cmd.Profile, - Seed: seed + int64(i+1)*97, - SeedSet: true, - AllowAbort: cmd.AllowAbort, - AllowError: cmd.AllowError, - AllowApproval: cmd.AllowApproval, - }, - } - parsed := &parsedCommand{Name: "stream", Random: &randomCmd} - run, err := buildAIRunFromCommand(ctx, runID, threadID, now.Add(delay), parsed, agentID, agentName) - if err != nil { - return nil, err - } - plans = append(plans, aiRunPlan{ - Run: run, - Delay: delay, - EffectiveCommand: chaosSubRunCommand(randomCmd), - }) - } - return plans, nil -} - -func buildAIStreamRunPlans(ctx context.Context, baseRunID, threadID string, now time.Time, cmd randomCommand, agentID, agentName string) ([]aiRunPlan, error) { - seed := cmd.Seed - if !cmd.SeedSet { - seed = now.UnixNano() - } - rng := rand.New(rand.NewSource(seed)) - plans := make([]aiRunPlan, 0, cmd.Runs) - runner := aiRunner{runtime: virtualAIRuntime(now)} - var delay time.Duration - for i := range cmd.Runs { - if i > 0 { - delay += runner.sampleDelay(rng, cmd.StaggerMin, cmd.StaggerMax) - } - child := cmd - child.Runs = 1 - child.Seed = seed + int64(i+1)*97 - child.SeedSet = true - parsed := &parsedCommand{Name: "stream", Random: &child} - run, err := buildAIRunFromCommand(ctx, fmt.Sprintf("%s-%d", baseRunID, i+1), threadID, now.Add(delay), parsed, agentID, agentName) - if err != nil { - return nil, err - } - plans = append(plans, aiRunPlan{ - Run: run, - Delay: delay, - EffectiveCommand: streamSubRunCommand(child), - }) - } - return plans, nil -} - -func chaosSubRunCommand(cmd randomCommand) string { - return streamSubRunCommand(cmd) -} - -func streamSubRunCommand(cmd randomCommand) string { - parts := []string{ - "stream", - strconv.Itoa(int(cmd.Duration / time.Second)), - "--actions=" + strconv.Itoa(cmd.Actions), - "--delay-ms=" + strconv.Itoa(int(cmd.DelayMin/time.Millisecond)) + ":" + strconv.Itoa(int(cmd.DelayMax/time.Millisecond)), - "--profile=" + cmd.Profile, - "--seed=" + strconv.FormatInt(cmd.Seed, 10), - } - if cmd.Chars > 0 { - parts = append(parts, "--chars="+strconv.Itoa(cmd.Chars)) - } - if cmd.Terminal != "" { - parts = append(parts, "--terminal="+cmd.Terminal) - } - if !cmd.AllowApproval { - parts = append(parts, "--no-approval") - } - if cmd.AllowAbort { - parts = append(parts, "--allow-abort") - } - if cmd.AllowError { - parts = append(parts, "--allow-error") - } - return strings.Join(parts, " ") -} - -func parseCommand(input string) (*parsedCommand, error) { - tokens, err := shlex.Split(input) - if err != nil { - return nil, fmt.Errorf("invalid command syntax: %w", err) - } - if len(tokens) == 0 { - return &parsedCommand{Name: "help"}, nil - } - switch strings.ToLower(tokens[0]) { - case "help", "/help", "!help", "dummybridge": - return &parsedCommand{Name: "help"}, nil - case "stream-tools": - cmd, err := parseToolsCommand(tokens[1:]) - return &parsedCommand{Name: "stream-tools", Tools: cmd}, err - case "stream": - cmd, err := parseStreamCommand(tokens[1:]) - return &parsedCommand{Name: "stream", Random: cmd}, err - default: - return nil, fmt.Errorf("unknown AI demo command %q", tokens[0]) - } -} - -func helpText() string { - return strings.Join([]string{ - "DummyBridge demo commands:", - "help", - "stream [seconds] [--runs=N] [--profile=balanced|tools|errors|artifacts] [--seed=N] [--chars=N] [--terminal=stop|length|abort|error] [--delay-ms=min:max] [--stagger-ms=min:max] [--actions=N] [--no-approval] [--allow-abort] [--allow-error]", - "stream-tools ... [common options]", - "Notes: stream enables approval requests by default; approval-tagged tools emit a separate Matrix approval event with reaction options.", - }, "\n") -} - -func defaultCommonOptions() commonCommandOptions { - return commonCommandOptions{ - DelayMin: 30 * time.Millisecond, - DelayMax: 150 * time.Millisecond, - ChunkMin: defaultChunkMin, - ChunkMax: defaultChunkMax, - FinishReason: agui.FinishReasonStop, - } -} - -func parseLoremCommand(tokens []string) (*loremCommand, error) { - if len(tokens) == 0 { - return nil, fmt.Errorf("text stream requires a character count") - } - count, err := parsePositiveInt(tokens[0], "character count") - if err != nil { - return nil, err - } - if err := validateMaxIntValue(count, maxDemoChars, "character count"); err != nil { - return nil, err - } - opts, err := parseCommonOptions(tokens[1:]) - if err != nil { - return nil, err - } - return &loremCommand{Chars: count, Options: opts}, nil -} - -func parseToolsCommand(tokens []string) (*toolsCommand, error) { - if len(tokens) < 2 { - return nil, fmt.Errorf("stream-tools requires a character count and at least one tool") - } - count, err := parsePositiveInt(tokens[0], "character count") - if err != nil { - return nil, err - } - if err := validateMaxIntValue(count, maxDemoChars, "character count"); err != nil { - return nil, err - } - var toolTokens, optTokens []string - for _, token := range tokens[1:] { - if strings.HasPrefix(token, "--") { - optTokens = append(optTokens, token) - } else { - toolTokens = append(toolTokens, token) - } - } - if len(toolTokens) == 0 { - return nil, fmt.Errorf("stream-tools requires at least one tool spec") - } - if err := validateMaxIntValue(len(toolTokens), maxDemoToolSpecs, "tool spec count"); err != nil { - return nil, err - } - opts, err := parseCommonOptions(optTokens) - if err != nil { - return nil, err - } - tools := make([]toolSpec, 0, len(toolTokens)) - for idx, token := range toolTokens { - spec, err := parseToolSpec(token, idx) - if err != nil { - return nil, err - } - tools = append(tools, spec) - } - return &toolsCommand{Chars: count, Tools: tools, Options: opts}, nil -} - -func parseRandomCommand(tokens []string) (*randomCommand, error) { - cmd := &randomCommand{ - Duration: 20 * time.Second, - Actions: 20, - DelayMin: 350 * time.Millisecond, - DelayMax: 1150 * time.Millisecond, - Runs: 1, - StaggerMin: 150 * time.Millisecond, - StaggerMax: 900 * time.Millisecond, - sharedStreamOptions: sharedStreamOptions{Profile: "balanced"}, - } - return parseStreamLikeCommand(tokens, cmd, false) -} - -func parseStreamCommand(tokens []string) (*randomCommand, error) { - cmd := &randomCommand{ - Duration: 20 * time.Second, - DelayMin: 350 * time.Millisecond, - DelayMax: 1150 * time.Millisecond, - Runs: 1, - StaggerMin: 150 * time.Millisecond, - StaggerMax: 900 * time.Millisecond, - sharedStreamOptions: sharedStreamOptions{Profile: "balanced", AllowApproval: true}, - } - return parseStreamLikeCommand(tokens, cmd, true) -} - -func parseStreamLikeCommand(tokens []string, cmd *randomCommand, deriveActions bool) (*randomCommand, error) { - rest := tokens - if len(rest) > 0 && !strings.HasPrefix(rest[0], "--") { - seconds, err := parsePositiveInt(rest[0], "duration") - if err != nil { - return nil, err - } - if err := validateMaxIntValue(seconds, int(maxDemoDuration/time.Second), "duration seconds"); err != nil { - return nil, err - } - cmd.Duration = time.Duration(seconds) * time.Second - rest = rest[1:] - } - if deriveActions && cmd.Actions == 0 { - cmd.Actions = max(3, min(maxDemoRandomActions, int(cmd.Duration/time.Second)*2)) - } - for _, token := range rest { - key, value, hasValue := parseOptionToken(token) - switch key { - case "actions": - n, err := parseValidatedInt(value, hasValue, token, "actions", maxDemoRandomActions, false) - if err != nil { - return nil, err - } - cmd.Actions = n - case "chars": - n, err := parseValidatedInt(value, hasValue, token, "character count", maxDemoChars, false) - if err != nil { - return nil, err - } - cmd.Chars = n - case "delay-ms": - minDelay, maxDelay, err := parseDurationRangeMS(value, hasValue, token) - if err != nil { - return nil, err - } - cmd.DelayMin, cmd.DelayMax = minDelay, maxDelay - case "terminal": - if !hasValue { - return nil, fmt.Errorf("%s requires a value", token) - } - switch strings.ToLower(value) { - case "stop", "finish": - cmd.Terminal = "finish" - case "abort", "error": - cmd.Terminal = strings.ToLower(value) - case "length", "tool-calls", "content-filter", "other": - cmd.Terminal = agui.NormalizeFinishReason(value) - default: - return nil, fmt.Errorf("unknown terminal %q", value) - } - case "runs": - n, err := parseValidatedInt(value, hasValue, token, "run count", maxDemoChaosRuns, false) - if err != nil { - return nil, err - } - cmd.Runs = n - case "stagger-ms": - minDelay, maxDelay, err := parseDurationRange(value, hasValue, token, "stagger-ms", maxDemoStagger) - if err != nil { - return nil, err - } - cmd.StaggerMin, cmd.StaggerMax = minDelay, maxDelay - case "no-approval": - cmd.AllowApproval = false - default: - handled, err := parseSharedStreamOption(key, value, hasValue, token, &cmd.sharedStreamOptions) - if err != nil || !handled { - if err != nil { - return nil, err - } - return nil, fmt.Errorf("unknown stream option %q", token) - } - } - } - return cmd, nil -} - -func parseChaosCommand(tokens []string) (*chaosCommand, error) { - cmd := &chaosCommand{ - Runs: 3, - Duration: 10 * time.Second, - StaggerMin: 150 * time.Millisecond, - StaggerMax: 900 * time.Millisecond, - MaxActions: 10, - sharedStreamOptions: sharedStreamOptions{Profile: "balanced"}, - } - rest := tokens - if len(rest) > 0 && !strings.HasPrefix(rest[0], "--") { - n, err := parsePositiveInt(rest[0], "run count") - if err != nil { - return nil, err - } - if err := validateMaxIntValue(n, maxDemoChaosRuns, "run count"); err != nil { - return nil, err - } - cmd.Runs = n - rest = rest[1:] - } - if len(rest) > 0 && !strings.HasPrefix(rest[0], "--") { - seconds, err := parsePositiveInt(rest[0], "duration") - if err != nil { - return nil, err - } - if err := validateMaxIntValue(seconds, int(maxDemoDuration/time.Second), "duration seconds"); err != nil { - return nil, err - } - cmd.Duration = time.Duration(seconds) * time.Second - rest = rest[1:] - } - for _, token := range rest { - key, value, hasValue := parseOptionToken(token) - switch key { - case "stagger-ms": - minDelay, maxDelay, err := parseDurationRange(value, hasValue, token, "stagger-ms", maxDemoStagger) - if err != nil { - return nil, err - } - cmd.StaggerMin, cmd.StaggerMax = minDelay, maxDelay - case "max-actions": - n, err := parseValidatedInt(value, hasValue, token, "max-actions", maxDemoChaosActions, false) - if err != nil { - return nil, err - } - cmd.MaxActions = n - default: - handled, err := parseSharedStreamOption(key, value, hasValue, token, &cmd.sharedStreamOptions) - if err != nil || !handled { - if err != nil { - return nil, err - } - return nil, fmt.Errorf("unknown chaos option %q", token) - } - } - } - return cmd, nil -} - -func parseCommonOptions(tokens []string) (commonCommandOptions, error) { - opts := defaultCommonOptions() - for _, token := range tokens { - key, value, hasValue := parseOptionToken(token) - switch key { - case "reasoning": - n, err := parseValidatedInt(value, hasValue, token, "reasoning", maxDemoReasoningChars, true) - if err != nil { - return opts, err - } - opts.ReasoningChars = n - case "steps": - n, err := parseValidatedInt(value, hasValue, token, "steps", maxDemoSteps, false) - if err != nil { - return opts, err - } - opts.Steps = n - case "sources": - n, err := parseValidatedInt(value, hasValue, token, "sources", maxDemoCollections, true) - if err != nil { - return opts, err - } - opts.Sources = n - case "documents": - n, err := parseValidatedInt(value, hasValue, token, "documents", maxDemoCollections, true) - if err != nil { - return opts, err - } - opts.Documents = n - case "files": - n, err := parseValidatedInt(value, hasValue, token, "files", maxDemoCollections, true) - if err != nil { - return opts, err - } - opts.Files = n - case "meta": - opts.Meta = true - case "data": - if !hasValue { - return opts, fmt.Errorf("%s requires a value", token) - } - opts.DataName = value - case "data-transient": - if !hasValue { - return opts, fmt.Errorf("%s requires a value", token) - } - opts.DataTransientName = value - case "delay-ms": - minDelay, maxDelay, err := parseDurationRangeMS(value, hasValue, token) - if err != nil { - return opts, err - } - opts.DelayMin, opts.DelayMax = minDelay, maxDelay - case "chunk-chars": - minChunk, maxChunk, err := parseIntRangeOption(value, hasValue, token, "chunk-chars", maxDemoChunkChars) - if err != nil { - return opts, err - } - opts.ChunkMin, opts.ChunkMax = minChunk, maxChunk - case "seed": - if !hasValue { - return opts, fmt.Errorf("%s requires a value", token) - } - seed, err := strconv.ParseInt(value, 10, 64) - if err != nil { - return opts, fmt.Errorf("invalid seed %q", value) - } - opts.Seed, opts.SeedSet = seed, true - case "finish": - if !hasValue { - return opts, fmt.Errorf("%s requires a value", token) - } - opts.FinishReason = agui.NormalizeFinishReason(value) - case "abort": - opts.Abort = true - case "error": - opts.Error = true - default: - return opts, fmt.Errorf("unknown option %q", token) - } - } - if opts.Abort && opts.Error { - return opts, fmt.Errorf("--abort and --error cannot be combined") - } - if (opts.Abort || opts.Error) && opts.FinishReason != agui.FinishReasonStop { - return opts, fmt.Errorf("--finish cannot be combined with --abort or --error") - } - return opts, nil -} - -func parseSharedStreamOption(key, value string, hasValue bool, token string, opts *sharedStreamOptions) (bool, error) { - switch key { - case "profile": - if !hasValue { - return false, fmt.Errorf("%s requires a value", token) - } - switch strings.ToLower(value) { - case "balanced", "tools", "errors", "artifacts": - opts.Profile = strings.ToLower(value) - default: - return false, fmt.Errorf("unknown profile %q", value) - } - case "seed": - if !hasValue { - return false, fmt.Errorf("%s requires a value", token) - } - seed, err := strconv.ParseInt(value, 10, 64) - if err != nil { - return false, fmt.Errorf("invalid seed %q", value) - } - opts.Seed, opts.SeedSet = seed, true - case "allow-abort": - opts.AllowAbort = true - case "allow-error": - opts.AllowError = true - default: - return false, nil - } - return true, nil -} - -func parseToolSpec(raw string, idx int) (toolSpec, error) { - parts := strings.Split(raw, "#") - spec := toolSpec{Name: strings.TrimSpace(parts[0]), SequenceIndex: idx + 1} - if spec.Name == "" { - return spec, fmt.Errorf("tool spec %q is missing a tool name", raw) - } - for _, tag := range parts[1:] { - tag = strings.TrimSpace(strings.ToLower(tag)) - if tag == "" { - continue - } - spec.Tags = append(spec.Tags, tag) - switch tag { - case "fail": - spec.Fail = true - case "approval": - spec.Approval = true - case "deny": - spec.Deny = true - case "delta": - spec.Delta = true - case "inputerror": - spec.InputError = true - case "prelim": - spec.Preliminary = true - case "provider": - spec.Provider = true - default: - return spec, fmt.Errorf("unknown tool tag %q in %q", tag, raw) - } - } - finalStates := 0 - for _, enabled := range []bool{spec.Fail, spec.Approval, spec.Deny} { - if enabled { - finalStates++ - } - } - if finalStates > 1 { - return spec, fmt.Errorf("tool spec %q has conflicting final state tags", raw) - } - return spec, nil -} - -func (r aiRunner) runLorem(ctx context.Context, w *aistream.Writer, cmd loremCommand) error { - opts := cmd.Options - rng := rngForOptions(opts.SeedSet, opts.Seed, r.runtime.now().UnixNano()) - steps := max(opts.Steps, 1) - text := buildDemoVisibleText(cmd.Chars, rand.New(rand.NewSource(rng.Int63()))) - reasoning := buildLoremText(opts.ReasoningChars, rand.New(rand.NewSource(rng.Int63()))) - for step := range steps { - if opts.Steps > 0 { - w.StepStart(fmt.Sprintf("step-%d", step+1)) - } - emitDecorations(w, opts, cmd.Chars, step, steps) - if reasoning != "" { - w.Thinking(sliceByStep(reasoning, steps, step)) - } - for _, chunk := range chunkText(sliceByStep(text, steps, step), rng, opts.ChunkMin, opts.ChunkMax) { - w.Text(chunk) - if err := r.runtime.sleep(ctx, r.sampleDelay(rng, opts.DelayMin, opts.DelayMax)); err != nil { - return err - } - } - if opts.Steps > 0 { - w.StepFinish(fmt.Sprintf("step-%d", step+1)) - } - } - finishWriter(w, opts) - return nil -} - -func (r aiRunner) runTools(ctx context.Context, w *aistream.Writer, cmd toolsCommand) error { - opts := cmd.Options - rng := rngForOptions(opts.SeedSet, opts.Seed, r.runtime.now().UnixNano()) - phaseCount := max(len(cmd.Tools)+1, max(opts.Steps, 1)) - text := buildDemoVisibleText(cmd.Chars, rand.New(rand.NewSource(rng.Int63()))) - reasoning := buildLoremText(opts.ReasoningChars, rand.New(rand.NewSource(rng.Int63()))) - for phase := range phaseCount { - w.StepStart(fmt.Sprintf("phase-%d", phase+1)) - emitDecorations(w, opts, cmd.Chars, phase, phaseCount) - if reasoning != "" { - w.Thinking(sliceByStep(reasoning, phaseCount, phase)) - } - for _, chunk := range chunkText(sliceByStep(text, phaseCount, phase), rng, opts.ChunkMin, opts.ChunkMax) { - w.Text(chunk) - } - if phase < len(cmd.Tools) { - if err := r.runToolSpec(ctx, w, cmd.Tools[phase], rng, opts); err != nil { - if errors.Is(err, errApprovalRequested) { - w.StepFinish(fmt.Sprintf("phase-%d", phase+1)) - } - return err - } - } - w.StepFinish(fmt.Sprintf("phase-%d", phase+1)) - } - finishWriter(w, opts) - return nil -} - -func (r aiRunner) runRandom(ctx context.Context, w *aistream.Writer, cmd randomCommand) error { - seed := cmd.Seed - if !cmd.SeedSet { - seed = r.runtime.now().UnixNano() - } - rng := rand.New(rand.NewSource(seed)) - started := r.runtime.now() - var deadline time.Time - if cmd.Duration > 0 { - deadline = started.Add(cmd.Duration) - } - stepOpen := false - stepName := "" - actionOptions, actionWeightTotal := buildRandomActionOptions(cmd) - if cmd.Chars > 0 { - text := buildDemoVisibleText(cmd.Chars, rand.New(rand.NewSource(rng.Int63()))) - for _, chunk := range chunkText(text, rng, defaultChunkMin, defaultChunkMax) { - w.Text(chunk) - if err := r.runtime.sleep(ctx, r.sampleDelay(rng, cmd.DelayMin, cmd.DelayMax)); err != nil { - return err - } - } - } - approvalRequested := false - handleTool := func(spec toolSpec) error { - if err := r.runToolSpec(ctx, w, spec, rng, defaultCommonOptions()); err != nil { - if spec.Approval { - approvalRequested = true - } - if errors.Is(err, errApprovalRequested) && stepOpen { - w.StepFinish(stepName) - stepOpen = false - stepName = "" - } - return err - } - if spec.Approval { - approvalRequested = true - } - return nil - } - for action := range cmd.Actions { - if !deadline.IsZero() && !r.runtime.now().Before(deadline) { - break - } - if action > 0 { - delay := r.sampleDelay(rng, cmd.DelayMin, cmd.DelayMax) - if !deadline.IsZero() && r.runtime.now().Add(delay).After(deadline) { - delay = deadline.Sub(r.runtime.now()) - } - if err := r.runtime.sleep(ctx, delay); err != nil { - return err - } - if !deadline.IsZero() && !r.runtime.now().Before(deadline) { - break - } - } - switch pickWeighted(actionOptions, actionWeightTotal, rng) { - case randomActionText: - text := "\n\n" + buildDemoVisibleText(40+rng.Intn(160), rand.New(rand.NewSource(rng.Int63()))) - for _, chunk := range chunkText(text, rng, defaultChunkMin, defaultChunkMax) { - w.Text(chunk) - } - case randomActionThinking: - w.Thinking(buildLoremText(30+rng.Intn(120), rand.New(rand.NewSource(rng.Int63())))) - case randomActionStep: - if stepOpen { - w.StepFinish(stepName) - stepOpen = false - stepName = "" - } else { - stepName = fmt.Sprintf("random-step-%d", action+1) - w.StepStart(stepName) - stepOpen = true - } - case randomActionTool: - if cmd.AllowApproval && cmd.Profile == "balanced" && action >= 10 && !approvalRequested { - if err := handleTool(toolSpec{Name: randomToolName(rng), Approval: true, SequenceIndex: action + 1}); err != nil { - return err - } - continue - } - if err := handleTool(toolSpec{Name: randomToolName(rng), SequenceIndex: action + 1}); err != nil { - return err - } - case randomActionToolFail: - if err := handleTool(toolSpec{Name: randomToolName(rng), Fail: true, SequenceIndex: action + 1}); err != nil { - return err - } - case randomActionToolDeny: - if err := handleTool(toolSpec{Name: randomToolName(rng), Deny: true, SequenceIndex: action + 1}); err != nil { - return err - } - case randomActionToolApproval: - if err := handleTool(toolSpec{Name: randomToolName(rng), Approval: true, SequenceIndex: action + 1}); err != nil { - return err - } - case randomActionSource: - sourceID := fmt.Sprintf("random-source-%d", action+1) - w.Custom("com.beeper.source", map[string]any{"sourceId": sourceID, "url": fmt.Sprintf("https://dummybridge.local/random/source/%d", action+1), "title": fmt.Sprintf("Random Source %d", action+1)}) - case randomActionDocument: - w.Custom("com.beeper.document", map[string]any{"id": fmt.Sprintf("random-doc-%d", action+1), "title": fmt.Sprintf("Random Document %d", action+1), "mediaType": "text/plain"}) - case randomActionFile: - w.Custom("com.beeper.file", map[string]any{"url": fmt.Sprintf("mxc://dummybridge/random-file-%d", action+1), "mediaType": "application/octet-stream"}) - case randomActionMetadata: - w.StateDelta(statePatch(map[string]any{"command": "stream", "seed": seed, "action": action + 1, "profile": cmd.Profile})) - case randomActionData: - w.Custom("com.beeper.data", map[string]any{"name": "random", "value": map[string]any{"action": action + 1, "seed": seed}}) - case randomActionDataTransient: - w.Custom("com.beeper.data.transient", map[string]any{"name": "random", "value": map[string]any{"action": action + 1, "seed": seed}}) - } - } - if stepOpen { - w.StepFinish(stepName) - } - terminal := chooseRandomTerminal(cmd, rng) - switch terminal { - case "abort": - w.Abort("DummyBridge random mode aborted") - case "error": - w.Error("DummyBridge random mode failed") - case agui.FinishReasonLength, agui.FinishReasonToolCalls, agui.FinishReasonContentFilter, agui.FinishReasonOther: - w.Finish(terminal) - default: - w.Finish(agui.FinishReasonStop) - } - return nil -} - -func (r aiRunner) runToolSpec(ctx context.Context, w *aistream.Writer, spec toolSpec, rng *rand.Rand, opts commonCommandOptions) error { - toolCallID := fmt.Sprintf("dummy-tool-%d-%s", spec.SequenceIndex, sanitizeToolName(spec.Name)) - input := toolRequestInput(spec) - approvalID := approvalIDForRun(w.Run.RunID, toolCallID) - var approval *agui.ToolApproval - if spec.Approval { - approval = &agui.ToolApproval{ID: approvalID, NeedsApproval: true} - } - displayMetadata := toolDisplayMetadata(spec.Name) - w.ToolStartWithMetadata(toolCallID, spec.Name, spec.SequenceIndex-1, approval, displayMetadata) - annotateProviderRawEvent(w, spec, "tool_call_start") - if spec.InputError { - if encodedInput := jsonToolInput(input); encodedInput != "" { - w.ToolArgs(toolCallID, encodedInput, nil) - annotateProviderRawEvent(w, spec, "tool_call_args") - } - w.ToolError(toolCallID, spec.Name, input, "input-error") - annotateProviderRawEvent(w, spec, "tool_call_error") - return nil - } - if spec.Delta { - if encodedInput := jsonToolInput(input); encodedInput != "" { - for _, chunk := range chunkText(encodedInput, rng, opts.ChunkMin, opts.ChunkMax) { - w.ToolArgs(toolCallID, chunk, nil) - annotateProviderRawEvent(w, spec, "tool_call_args") - if err := r.runtime.sleep(ctx, r.sampleDelay(rng, opts.DelayMin, opts.DelayMax)); err != nil { - return err - } - } - } - } else { - if encodedInput := jsonToolInput(input); encodedInput != "" { - w.ToolArgs(toolCallID, encodedInput, encodedInput) - annotateProviderRawEvent(w, spec, "tool_call_args") - } - } - if spec.Preliminary { - w.ToolResult(toolCallID, fmt.Sprintf(`{"state":%q}`, agui.ToolResultStateStreaming), agui.ToolResultStateStreaming) - annotateProviderRawEvent(w, spec, "tool_call_result") - } - switch { - case spec.Approval: - if response, ok := r.approvals[approvalID]; ok { - if response.ID == "" { - response.ID = approvalID - } - w.ToolApprovalResponded(toolCallID, spec.Name, input, response) - annotateProviderRawEvent(w, spec, "approval_responded") - if !response.Approved { - return errApprovalDenied - } - return nil - } - w.ToolApprovalInputComplete(toolCallID, spec.Name, input) - annotateProviderRawEvent(w, spec, "tool_call_input_complete") - w.ToolApprovalRequestedWithMetadata(toolCallID, spec.Name, input, *approval, displayMetadata) - annotateProviderRawEvent(w, spec, "approval_requested") - return errApprovalRequested - case spec.Deny: - w.ToolDenied(toolCallID, spec.Name, input, approvalID, "denied") - annotateProviderRawEvent(w, spec, "tool_call_denied") - case spec.Fail: - w.ToolError(toolCallID, spec.Name, input, "DummyBridge synthetic tool failure") - annotateProviderRawEvent(w, spec, "tool_call_error") - default: - w.ToolEnd(toolCallID, spec.Name, input, nil) - annotateProviderRawEvent(w, spec, "tool_call_end") - } - return nil -} - -func toolRequestInput(spec toolSpec) any { - return nil -} - -func toolDisplayMetadata(name string) map[string]any { - type ToolProviderMetadata struct { - ID string `json:"id,omitempty"` - DisplayName string `json:"displayName,omitempty"` - IconURL string `json:"iconUrl,omitempty"` - } - type ToolDisplayMetadata struct { - DisplayName string `json:"displayName,omitempty"` - Description string `json:"description,omitempty"` - IconURL string `json:"iconUrl,omitempty"` - Provider *ToolProviderMetadata `json:"provider,omitempty"` - } - - metadata := ToolDisplayMetadata{} - switch strings.ToLower(name) { - case "calendar.get_events", "google_calendar.get_events", "google-calendar.get-events": - metadata.DisplayName = "List Calendar Events" - metadata.Provider = &ToolProviderMetadata{ - ID: "google-calendar", - DisplayName: "Google Calendar", - } - case "linear.list_issues", "linear.list-issues", "list_issues", "list-issues": - metadata.DisplayName = "List Issues" - metadata.Provider = &ToolProviderMetadata{ - ID: "linear", - DisplayName: "Linear", - } - case "shell": - metadata.DisplayName = "Run Command" - case "fetch": - metadata.DisplayName = "Fetch Web" - } - return compactJSONMap(metadata) -} - -func compactJSONMap(value any) map[string]any { - raw, err := json.Marshal(value) - if err != nil { - return nil - } - var out map[string]any - if err := json.Unmarshal(raw, &out); err != nil || len(out) == 0 { - return nil - } - return out -} - -func approvalIDForRun(runID, toolCallID string) string { - return "approval-" + runID + "-" + toolCallID -} - -func annotateProviderRawEvent(w *aistream.Writer, spec toolSpec, stage string) { - if !spec.Provider || w == nil || w.Run == nil || len(w.Run.Events) == 0 { - return - } - w.Run.Events[len(w.Run.Events)-1]["rawEvent"] = map[string]any{ - "provider": "dummybridge", - "stage": stage, - "tool": spec.Name, - "sequence": spec.SequenceIndex, - "tags": spec.Tags, - } -} - -func jsonToolInput(input any) string { - if input == nil { - return "" - } - if inputMap, ok := input.(map[string]any); ok && len(inputMap) == 0 { - return "" - } - raw, err := json.Marshal(input) - if err != nil { - return "" - } - return string(raw) -} - -func finishWriter(w *aistream.Writer, opts commonCommandOptions) { - switch { - case opts.Abort: - w.Abort("DummyBridge synthetic abort") - case opts.Error: - w.Error("DummyBridge synthetic error") - default: - w.Finish(opts.FinishReason) - } -} - -func emitDecorations(w *aistream.Writer, opts commonCommandOptions, chars, step, steps int) { - if opts.Meta { - seed := opts.Seed - if !opts.SeedSet { - seed = int64(chars) - } - w.StateDelta(statePatch(map[string]any{"command": "demo", "seed": seed, "step": step + 1})) - } - for i := range splitCount(opts.Sources, steps, step) { - sourceID := fmt.Sprintf("demo-source-%d-%d", step+1, i+1) - w.Custom("com.beeper.source", map[string]any{"sourceId": sourceID, "url": fmt.Sprintf("https://dummybridge.local/source/%d-%d", step+1, i+1), "title": fmt.Sprintf("Demo Source %d.%d", step+1, i+1)}) - } - for i := range splitCount(opts.Documents, steps, step) { - w.Custom("com.beeper.document", map[string]any{"id": fmt.Sprintf("demo-doc-%d-%d", step+1, i+1), "title": fmt.Sprintf("Demo Document %d.%d", step+1, i+1), "mediaType": "text/plain"}) - } - for i := range splitCount(opts.Files, steps, step) { - w.Custom("com.beeper.file", map[string]any{"url": fmt.Sprintf("mxc://dummybridge/demo-file-%d-%d", step+1, i+1), "mediaType": "application/octet-stream"}) - } - if step == 0 && opts.DataName != "" { - w.Custom("com.beeper.data", map[string]any{"name": opts.DataName, "value": map[string]any{"mode": "persistent", "stage": step + 1}}) - } - if step == 0 && opts.DataTransientName != "" { - w.Custom("com.beeper.data.transient", map[string]any{"name": opts.DataTransientName, "value": map[string]any{"mode": "transient", "stage": step + 1}}) - } -} - -func statePatch(values map[string]any) []map[string]any { - keys := make([]string, 0, len(values)) - for key := range values { - keys = append(keys, key) - } - sort.Strings(keys) - patch := make([]map[string]any, 0, len(keys)) - for _, key := range keys { - patch = append(patch, map[string]any{ - "op": "add", - "path": "/" + key, - "value": values[key], - }) - } - return patch -} - -func (r aiRunner) sampleDelay(rng *rand.Rand, minDelay, maxDelay time.Duration) time.Duration { - if maxDelay <= minDelay { - return minDelay - } - return minDelay + time.Duration(rng.Int63n(int64(maxDelay-minDelay)+1)) -} - -func parseOptionToken(token string) (string, string, bool) { - trimmed := strings.TrimPrefix(strings.TrimSpace(token), "--") - key, value, ok := strings.Cut(trimmed, "=") - return strings.ToLower(strings.TrimSpace(key)), strings.TrimSpace(value), ok -} - -func parseValidatedInt(value string, hasValue bool, token, label string, maxValue int, allowZero bool) (int, error) { - if !hasValue { - return 0, fmt.Errorf("%s requires a value", token) - } - var n int - var err error - if allowZero { - n, err = parseNonNegativeInt(value, label) - } else { - n, err = parsePositiveInt(value, label) - } - if err != nil { - return 0, err - } - return n, validateMaxIntValue(n, maxValue, label) -} - -func parsePositiveInt(raw, label string) (int, error) { - n, err := strconv.Atoi(strings.TrimSpace(raw)) - if err != nil || n <= 0 { - return 0, fmt.Errorf("invalid %s %q", label, raw) - } - return n, nil -} - -func parseNonNegativeInt(raw, label string) (int, error) { - n, err := strconv.Atoi(strings.TrimSpace(raw)) - if err != nil || n < 0 { - return 0, fmt.Errorf("invalid %s %q", label, raw) - } - return n, nil -} - -func parseDurationRangeMS(value string, hasValue bool, token string) (time.Duration, time.Duration, error) { - return parseDurationRange(value, hasValue, token, "delay-ms", maxDemoDelay) -} - -func parseDurationRange(value string, hasValue bool, token, label string, maxValue time.Duration) (time.Duration, time.Duration, error) { - minValue, maxRange, err := parseIntRangeOption(value, hasValue, token, label, int(maxValue/time.Millisecond)) - if err != nil { - return 0, 0, err - } - return time.Duration(minValue) * time.Millisecond, time.Duration(maxRange) * time.Millisecond, nil -} - -func parseIntRangeOption(value string, hasValue bool, token, label string, maxValue int) (int, int, error) { - if !hasValue { - return 0, 0, fmt.Errorf("%s requires a value", token) - } - minValue, maxRange, ok := strings.Cut(value, ":") - if !ok { - n, err := parseNonNegativeInt(value, label) - if err != nil { - return 0, 0, err - } - if err := validateMaxIntValue(n, maxValue, label); err != nil { - return 0, 0, err - } - return n, n, nil - } - minInt, err := parseNonNegativeInt(minValue, label) - if err != nil { - return 0, 0, err - } - maxInt, err := parseNonNegativeInt(maxRange, label) - if err != nil { - return 0, 0, err - } - if maxInt < minInt { - return 0, 0, fmt.Errorf("invalid %s range %q", label, value) - } - if err := validateMaxIntValue(maxInt, maxValue, label); err != nil { - return 0, 0, err - } - return minInt, maxInt, nil -} - -func validateMaxIntValue(value, maxValue int, label string) error { - if value > maxValue { - return fmt.Errorf("%s %d exceeds the maximum of %d", label, value, maxValue) - } - return nil -} - -func rngForOptions(seedSet bool, seed, fallback int64) *rand.Rand { - if !seedSet { - seed = fallback - } - return rand.New(rand.NewSource(seed)) -} - -func chunkText(text string, rng *rand.Rand, minChunk, maxChunk int) []string { - if strings.TrimSpace(text) == "" { - return nil - } - if minChunk <= 0 { - minChunk = defaultChunkMin - } - if maxChunk < minChunk { - maxChunk = minChunk - } - var chunks []string - for len(text) > 0 { - size := minChunk - if maxChunk > minChunk { - size += rng.Intn(maxChunk - minChunk + 1) - } - if size > len(text) { - size = len(text) - } - parts := aistream.SplitTextUTF8(text, size) - chunk := parts[0] - chunks = append(chunks, chunk) - text = text[len(chunk):] - } - return chunks -} - -func splitCount(total, parts, index int) int { - if total <= 0 || parts <= 0 || index < 0 || index >= parts { - return 0 - } - base := total / parts - remainder := total % parts - if index < remainder { - return base + 1 - } - return base -} - -func sliceByStep(text string, parts, index int) string { - if parts <= 1 || text == "" { - return text - } - start := 0 - for i := 0; i < index; i++ { - start += splitCount(len(text), parts, i) - } - length := splitCount(len(text), parts, index) - if start >= len(text) || length <= 0 { - return "" - } - end := min(start+length, len(text)) - return text[start:end] -} - -func sanitizeToolName(name string) string { - name = strings.ToLower(strings.TrimSpace(name)) - var out strings.Builder - for _, r := range name { - if r >= 'a' && r <= 'z' || r >= '0' && r <= '9' || r == '_' || r == '-' { - out.WriteRune(r) - } - } - if out.Len() == 0 { - return "tool" - } - return out.String() -} diff --git a/pkg/connector/ai_types.go b/pkg/connector/ai_types.go new file mode 100644 index 0000000..80c3ee2 --- /dev/null +++ b/pkg/connector/ai_types.go @@ -0,0 +1,156 @@ +package connector + +import ( + "context" + "errors" + "time" + + "github.com/beeper/ai-bridge/pkg/ag-ui" + "github.com/beeper/ai-bridge/pkg/ai-stream" +) + +var ( + errApprovalRequested = errors.New("approval requested") + errApprovalDenied = errors.New("approval denied") +) + +const ( + defaultChunkMin = 24 + defaultChunkMax = 96 + maxDemoChars = 96 * 1024 + maxDemoReasoningChars = 8192 + maxDemoToolSpecs = 16 + maxDemoSteps = 32 + maxDemoCollections = 16 + maxDemoRandomActions = 64 + maxDemoChaosRuns = 16 + maxDemoChaosActions = 64 + maxDemoDuration = 5 * time.Minute + maxDemoDelay = 30 * time.Second + maxDemoChunkChars = 512 + maxDemoStagger = 30 * time.Second +) + +const ( + randomActionText = "text" + randomActionThinking = "thinking" + randomActionStep = "step" + randomActionTool = "tool" + randomActionToolFail = "tool_fail" + randomActionToolDeny = "tool_deny" + randomActionToolApproval = "tool_approval" + randomActionSource = "source" + randomActionDocument = "document" + randomActionFile = "file" + randomActionMetadata = "metadata" + randomActionData = "data" + randomActionDataTransient = "data_transient" +) + +type commonCommandOptions struct { + ReasoningChars int + Steps int + Sources int + Documents int + Files int + Meta bool + DataName string + DataTransientName string + DelayMin time.Duration + DelayMax time.Duration + ChunkMin int + ChunkMax int + FinishReason string + Abort bool + Error bool + Seed int64 + SeedSet bool +} + +type loremCommand struct { + Chars int + Options commonCommandOptions +} + +type toolSpec struct { + Name string + Tags []string + Fail bool + Approval bool + Deny bool + Delta bool + InputError bool + Preliminary bool + Provider bool + SequenceIndex int +} + +type toolsCommand struct { + Chars int + Tools []toolSpec + Options commonCommandOptions +} + +type sharedStreamOptions struct { + Profile string + Seed int64 + SeedSet bool + AllowAbort bool + AllowError bool + AllowApproval bool +} + +type randomCommand struct { + Duration time.Duration + Actions int + Chars int + DelayMin time.Duration + DelayMax time.Duration + Terminal string + Runs int + StaggerMin time.Duration + StaggerMax time.Duration + sharedStreamOptions +} + +type randomActionOption struct { + name string + weight int +} + +type chaosCommand struct { + Runs int + Duration time.Duration + StaggerMin time.Duration + StaggerMax time.Duration + MaxActions int + sharedStreamOptions +} + +type parsedCommand struct { + Name string + Lorem *loremCommand + Tools *toolsCommand + Random *randomCommand + Chaos *chaosCommand +} + +type aiRuntime struct { + now func() time.Time + sleep func(context.Context, time.Duration) error +} + +type aiRunner struct { + runtime aiRuntime + approvals map[string]agui.ToolApprovalResponse +} + +type aiRunPlan struct { + Run *aistream.Run + Delay time.Duration + // EffectiveCommand is the canonical command form used to deterministically + // replay this run during approval continuation. For random/chaos sub-runs + // (where the seed was derived implicitly) this includes the resolved + // --seed=N so the continuation reproduces the same action sequence. + EffectiveCommand string +} From 084b8ef5d53abae2117e6032b91d8d733403f4b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 24 May 2026 04:24:12 +0200 Subject: [PATCH 44/46] wip --- pkg/connector/ai_commands.go | 4 ++-- pkg/connector/ai_parse_helpers.go | 4 ---- pkg/connector/ai_plans.go | 14 +++----------- pkg/connector/ai_runtime.go | 2 +- pkg/connector/ai_runtime_test.go | 24 ++++++++++++------------ pkg/connector/client.go | 9 --------- 6 files changed, 18 insertions(+), 39 deletions(-) diff --git a/pkg/connector/ai_commands.go b/pkg/connector/ai_commands.go index 1ca9945..26fd54f 100644 --- a/pkg/connector/ai_commands.go +++ b/pkg/connector/ai_commands.go @@ -169,7 +169,7 @@ func parseStreamLikeCommand(tokens []string, cmd *randomCommand, deriveActions b } cmd.Chars = n case "delay-ms": - minDelay, maxDelay, err := parseDurationRangeMS(value, hasValue, token) + minDelay, maxDelay, err := parseDurationRange(value, hasValue, token, "delay-ms", maxDemoDelay) if err != nil { return nil, err } @@ -323,7 +323,7 @@ func parseCommonOptions(tokens []string) (commonCommandOptions, error) { } opts.DataTransientName = value case "delay-ms": - minDelay, maxDelay, err := parseDurationRangeMS(value, hasValue, token) + minDelay, maxDelay, err := parseDurationRange(value, hasValue, token, "delay-ms", maxDemoDelay) if err != nil { return opts, err } diff --git a/pkg/connector/ai_parse_helpers.go b/pkg/connector/ai_parse_helpers.go index 8ce97dd..9b06d7d 100644 --- a/pkg/connector/ai_parse_helpers.go +++ b/pkg/connector/ai_parse_helpers.go @@ -49,10 +49,6 @@ func parseNonNegativeInt(raw, label string) (int, error) { return n, nil } -func parseDurationRangeMS(value string, hasValue bool, token string) (time.Duration, time.Duration, error) { - return parseDurationRange(value, hasValue, token, "delay-ms", maxDemoDelay) -} - func parseDurationRange(value string, hasValue bool, token, label string, maxValue time.Duration) (time.Duration, time.Duration, error) { minValue, maxRange, err := parseIntRangeOption(value, hasValue, token, label, int(maxValue/time.Millisecond)) if err != nil { diff --git a/pkg/connector/ai_plans.go b/pkg/connector/ai_plans.go index ace50f4..b96987b 100644 --- a/pkg/connector/ai_plans.go +++ b/pkg/connector/ai_plans.go @@ -67,10 +67,6 @@ func hasSeedFlag(input string) bool { return false } -func buildAIRunFromCommand(ctx context.Context, runID, threadID string, now time.Time, cmd *parsedCommand, agentID, agentName string) (*aistream.Run, error) { - return buildAIRunFromCommandWithApprovals(ctx, runID, threadID, now, cmd, agentID, agentName, nil) -} - func buildAIRunFromCommandWithApprovals(ctx context.Context, runID, threadID string, now time.Time, cmd *parsedCommand, agentID, agentName string, approvals map[string]agui.ToolApprovalResponse) (*aistream.Run, error) { runtime := virtualAIRuntime(now) run := aistream.NewRun(runID, threadID, aistream.DefaultModel, agentID, agentName, now) @@ -131,14 +127,14 @@ func buildAIChaosRunPlans(ctx context.Context, baseRunID, threadID string, now t }, } parsed := &parsedCommand{Name: "stream", Random: &randomCmd} - run, err := buildAIRunFromCommand(ctx, runID, threadID, now.Add(delay), parsed, agentID, agentName) + run, err := buildAIRunFromCommandWithApprovals(ctx, runID, threadID, now.Add(delay), parsed, agentID, agentName, nil) if err != nil { return nil, err } plans = append(plans, aiRunPlan{ Run: run, Delay: delay, - EffectiveCommand: chaosSubRunCommand(randomCmd), + EffectiveCommand: streamSubRunCommand(randomCmd), }) } return plans, nil @@ -162,7 +158,7 @@ func buildAIStreamRunPlans(ctx context.Context, baseRunID, threadID string, now child.Seed = seed + int64(i+1)*97 child.SeedSet = true parsed := &parsedCommand{Name: "stream", Random: &child} - run, err := buildAIRunFromCommand(ctx, fmt.Sprintf("%s-%d", baseRunID, i+1), threadID, now.Add(delay), parsed, agentID, agentName) + run, err := buildAIRunFromCommandWithApprovals(ctx, fmt.Sprintf("%s-%d", baseRunID, i+1), threadID, now.Add(delay), parsed, agentID, agentName, nil) if err != nil { return nil, err } @@ -175,10 +171,6 @@ func buildAIStreamRunPlans(ctx context.Context, baseRunID, threadID string, now return plans, nil } -func chaosSubRunCommand(cmd randomCommand) string { - return streamSubRunCommand(cmd) -} - func streamSubRunCommand(cmd randomCommand) string { parts := []string{ "stream", diff --git a/pkg/connector/ai_runtime.go b/pkg/connector/ai_runtime.go index 749667a..36f3ddb 100644 --- a/pkg/connector/ai_runtime.go +++ b/pkg/connector/ai_runtime.go @@ -55,7 +55,7 @@ func buildAIRunPlans(ctx context.Context, runID, threadID, input string, now tim if cmd != nil && cmd.Random != nil && cmd.Random.Runs > 1 { return buildAIStreamRunPlans(ctx, runID, threadID, now, *cmd.Random, agentID, agentName) } - run, err := buildAIRunFromCommand(ctx, runID, threadID, now, cmd, agentID, agentName) + run, err := buildAIRunFromCommandWithApprovals(ctx, runID, threadID, now, cmd, agentID, agentName, nil) if err != nil { return nil, err } diff --git a/pkg/connector/ai_runtime_test.go b/pkg/connector/ai_runtime_test.go index b387866..a94d520 100644 --- a/pkg/connector/ai_runtime_test.go +++ b/pkg/connector/ai_runtime_test.go @@ -228,10 +228,10 @@ func TestApprovalPromptSeqStartsAtNextPackedCarrierSeq(t *testing.T) { AgentName: run.AgentName, SeqStart: prompt.SeqStart, } - continuation, err := buildAIApprovalContinuationRun(context.Background(), approvalCtx, agui.ToolApprovalResponse{ + continuation, err := buildAIApprovalContinuationRunWithApprovals(context.Background(), approvalCtx, map[string]agui.ToolApprovalResponse{approvalCtx.ID: { ID: prompt.ID, Approved: true, - }, time.Unix(20, 0)) + }}, time.Unix(20, 0)) if err != nil { t.Fatal(err) } @@ -321,10 +321,10 @@ func TestApprovalLifecycleCarriesNoticeTargetAndContinuation(t *testing.T) { t.Fatalf("approval-requested stream event has bad choice shape: %#v", choices[0]) } - continuation, err := buildAIApprovalContinuationRun(context.Background(), approvalCtx, agui.ToolApprovalResponse{ + continuation, err := buildAIApprovalContinuationRunWithApprovals(context.Background(), approvalCtx, map[string]agui.ToolApprovalResponse{approvalCtx.ID: { ID: prompt.ID, Approved: true, - }, time.Unix(20, 0)) + }}, time.Unix(20, 0)) if err != nil { t.Fatal(err) } @@ -361,10 +361,10 @@ func TestApprovalContinuationResumesOriginalRunAfterApprovedTool(t *testing.T) { AgentName: "AI", SeqStart: 12, } - run, err := buildAIApprovalContinuationRun(context.Background(), approvalCtx, agui.ToolApprovalResponse{ + run, err := buildAIApprovalContinuationRunWithApprovals(context.Background(), approvalCtx, map[string]agui.ToolApprovalResponse{approvalCtx.ID: { ID: approvalCtx.ID, Approved: true, - }, time.Unix(20, 0)) + }}, time.Unix(20, 0)) if err != nil { t.Fatal(err) } @@ -420,11 +420,11 @@ func TestApprovalContinuationStopsOriginalRunAfterDeniedTool(t *testing.T) { AgentName: "AI", SeqStart: 12, } - run, err := buildAIApprovalContinuationRun(context.Background(), approvalCtx, agui.ToolApprovalResponse{ + run, err := buildAIApprovalContinuationRunWithApprovals(context.Background(), approvalCtx, map[string]agui.ToolApprovalResponse{approvalCtx.ID: { ID: approvalCtx.ID, Approved: false, Reason: "denied", - }, time.Unix(20, 0)) + }}, time.Unix(20, 0)) if err != nil { t.Fatal(err) } @@ -916,10 +916,10 @@ func TestMultiApprovalContinuationKeepsLaterPrompts(t *testing.T) { AgentName: "AI", SeqStart: 12, } - run, err := buildAIApprovalContinuationRun(context.Background(), approvalCtx, agui.ToolApprovalResponse{ + run, err := buildAIApprovalContinuationRunWithApprovals(context.Background(), approvalCtx, map[string]agui.ToolApprovalResponse{approvalCtx.ID: { ID: approvalCtx.ID, Approved: true, - }, time.Unix(20, 0)) + }}, time.Unix(20, 0)) if err != nil { t.Fatal(err) } @@ -1000,10 +1000,10 @@ func TestApprovalContinuationReplaysRandomRunWithImplicitSeed(t *testing.T) { AgentName: "AI", SeqStart: 50, } - continuation, err := buildAIApprovalContinuationRun(context.Background(), approvalCtx, agui.ToolApprovalResponse{ + continuation, err := buildAIApprovalContinuationRunWithApprovals(context.Background(), approvalCtx, map[string]agui.ToolApprovalResponse{approvalCtx.ID: { ID: approvalCtx.ID, Approved: true, - }, now.Add(time.Hour)) + }}, now.Add(time.Hour)) if err != nil { t.Fatalf("continuation failed: %v", err) } diff --git a/pkg/connector/client.go b/pkg/connector/client.go index 2898b7b..92c9b16 100644 --- a/pkg/connector/client.go +++ b/pkg/connector/client.go @@ -966,15 +966,6 @@ func (dc *DummyClient) queueAIApprovalResponse(ctx context.Context, portal *brid Msg("Queued AI approval continuation") } -func buildAIApprovalContinuationRun(ctx context.Context, approvalCtx aistream.ApprovalContext, response agui.ToolApprovalResponse, now time.Time) (aistream.Run, error) { - if response.ID == "" { - response.ID = approvalCtx.ID - } - return buildAIApprovalContinuationRunWithApprovals(ctx, approvalCtx, map[string]agui.ToolApprovalResponse{ - response.ID: response, - }, now) -} - func buildAIApprovalContinuationRunWithApprovals(ctx context.Context, approvalCtx aistream.ApprovalContext, approvals map[string]agui.ToolApprovalResponse, now time.Time) (aistream.Run, error) { cmd, err := parseCommand(approvalCtx.Command) if err != nil { From e51341e001ac78dae72fe6c875b8955a1fd5feaf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 24 May 2026 18:30:20 +0200 Subject: [PATCH 45/46] wip --- build.sh | 4 ++++ go.mod | 32 ++++++++++++++++---------------- go.sum | 30 ++++++++++++++++++++++++++++++ 3 files changed, 50 insertions(+), 16 deletions(-) create mode 100755 build.sh diff --git a/build.sh b/build.sh new file mode 100755 index 0000000..8c13e8b --- /dev/null +++ b/build.sh @@ -0,0 +1,4 @@ +#!/usr/bin/env sh +set -eu + +go build -o dummybridge ./cmd/dummybridge diff --git a/go.mod b/go.mod index bddb6ab..02c9550 100644 --- a/go.mod +++ b/go.mod @@ -1,39 +1,39 @@ module github.com/beeper/dummybridge -go 1.24.0 +go 1.25.0 toolchain go1.25.6 require ( github.com/beeper/ai-bridge v0.0.0-20260524021151-5c8086351a72 - github.com/rs/zerolog v1.34.0 - go.mau.fi/util v0.9.5 - maunium.net/go/mautrix v0.26.3-0.20260119125818-e28f7170bc4b + github.com/rs/zerolog v1.35.1 + go.mau.fi/util v0.9.9-0.20260511124621-9241e81bdf25 + maunium.net/go/mautrix v0.27.1-0.20260513120123-5fba7e3afae4 ) require ( - filippo.io/edwards25519 v1.1.0 // indirect + filippo.io/edwards25519 v1.2.0 // indirect github.com/coder/websocket v1.8.14 // indirect - github.com/coreos/go-systemd/v22 v22.6.0 // indirect - github.com/lib/pq v1.10.9 // indirect + github.com/coreos/go-systemd/v22 v22.7.0 // indirect + github.com/lib/pq v1.12.3 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect - github.com/mattn/go-sqlite3 v1.14.33 // indirect - github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741 // indirect + github.com/mattn/go-sqlite3 v1.14.44 // indirect + github.com/petermattis/goid v0.0.0-20260330135022-df67b199bc81 // indirect github.com/rs/xid v1.6.0 // indirect github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e // indirect github.com/tidwall/gjson v1.18.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect github.com/tidwall/sjson v1.2.5 // indirect - github.com/yuin/goldmark v1.7.16 // indirect + github.com/yuin/goldmark v1.8.2 // indirect go.mau.fi/zeroconfig v0.2.0 // indirect - golang.org/x/crypto v0.47.0 // indirect - golang.org/x/exp v0.0.0-20260112195511-716be5621a96 // indirect - golang.org/x/net v0.49.0 // indirect - golang.org/x/sync v0.19.0 // indirect - golang.org/x/sys v0.40.0 // indirect - golang.org/x/text v0.33.0 // indirect + golang.org/x/crypto v0.50.0 // indirect + golang.org/x/exp v0.0.0-20260410095643-746e56fc9e2f // indirect + golang.org/x/net v0.53.0 // indirect + golang.org/x/sync v0.20.0 // indirect + golang.org/x/sys v0.43.0 // indirect + golang.org/x/text v0.36.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect maunium.net/go/mauflag v1.0.0 // indirect diff --git a/go.sum b/go.sum index d9630eb..7ec830b 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= +filippo.io/edwards25519 v1.2.0 h1:crnVqOiS4jqYleHd9vaKZ+HKtHfllngJIiOpNpoJsjo= +filippo.io/edwards25519 v1.2.0/go.mod h1:xzAOLCNug/yB62zG1bQ8uziwrIqIuxhctzJT18Q77mc= github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= github.com/beeper/ai-bridge v0.0.0-20260524021151-5c8086351a72 h1:Pw2qyz5mizv/UL4JTKiK1sbYfUl6o8dk/KcNyFlSFG0= @@ -9,11 +11,15 @@ github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6p github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/coreos/go-systemd/v22 v22.6.0 h1:aGVa/v8B7hpb0TKl0MWoAavPDmHvobFe5R5zn0bCJWo= github.com/coreos/go-systemd/v22 v22.6.0/go.mod h1:iG+pp635Fo7ZmV/j14KUcmEyWF+0X7Lua8rrTWzYgWU= +github.com/coreos/go-systemd/v22 v22.7.0 h1:LAEzFkke61DFROc7zNLX/WA2i5J8gYqe0rSj9KI28KA= +github.com/coreos/go-systemd/v22 v22.7.0/go.mod h1:xNUYtjHu2EDXbsxz1i41wouACIwT7Ybq9o0BQhMwD0w= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/lib/pq v1.12.3 h1:tTWxr2YLKwIvK90ZXEw8GP7UFHtcbTtty8zsI+YjrfQ= +github.com/lib/pq v1.12.3/go.mod h1:/p+8NSbOcwzAEI7wiMXFlgydTwcgTr3OSKMsD2BitpA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= @@ -23,8 +29,12 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.33 h1:A5blZ5ulQo2AtayQ9/limgHEkFreKj1Dv226a1K73s0= github.com/mattn/go-sqlite3 v1.14.33/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/mattn/go-sqlite3 v1.14.44 h1:3VSe+xafpbzsLbdr2AWlAZk9yRHiBhTBakioXaCKTF8= +github.com/mattn/go-sqlite3 v1.14.44/go.mod h1:pjEuOr8IwzLJP2MfGeTb0A35jauH+C2kbHKBr7yXKVQ= github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741 h1:KPpdlQLZcHfTMQRi6bFQ7ogNO0ltFT4PmtwTLW4W+14= github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= +github.com/petermattis/goid v0.0.0-20260330135022-df67b199bc81 h1:WDsQxOJDy0N1VRAjXLpi8sCEZRSGarLWQevDxpTBRrM= +github.com/petermattis/goid v0.0.0-20260330135022-df67b199bc81/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -32,6 +42,8 @@ github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU= github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= +github.com/rs/zerolog v1.35.1 h1:m7xQeoiLIiV0BCEY4Hs+j2NG4Gp2o2KPKmhnnLiazKI= +github.com/rs/zerolog v1.35.1/go.mod h1:EjML9kdfa/RMA7h/6z6pYmq1ykOuA8/mjWaEvGI+jcw= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= @@ -48,25 +60,41 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.16 h1:n+CJdUxaFMiDUNnWC3dMWCIQJSkxH4uz3ZwQBkAlVNE= github.com/yuin/goldmark v1.7.16/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= +github.com/yuin/goldmark v1.8.2 h1:kEGpgqJXdgbkhcOgBxkC0X0PmoPG1ZyoZ117rDVp4zE= +github.com/yuin/goldmark v1.8.2/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= go.mau.fi/util v0.9.5 h1:7AoWPCIZJGv4jvtFEuCe3GhAbI7uF9ckIooaXvwlIR4= go.mau.fi/util v0.9.5/go.mod h1:g1uvZ03VQhtTt2BgaRGVytS/Zj67NV0YNIECch0sQCQ= +go.mau.fi/util v0.9.9-0.20260511124621-9241e81bdf25 h1:YPEmc+li7TF6C9AdRTcSLMb6yCHdF27/wNT7kFLIVNg= +go.mau.fi/util v0.9.9-0.20260511124621-9241e81bdf25/go.mod h1:jE9FfhbgEgAwxei6lomO9v8zdCIATcquONUu4vjRwSs= go.mau.fi/zeroconfig v0.2.0 h1:e/OGEERqVRRKlgaro7E6bh8xXiKFSXB3eNNIud7FUjU= go.mau.fi/zeroconfig v0.2.0/go.mod h1:J0Vn0prHNOm493oZoQ84kq83ZaNCYZnq+noI1b1eN8w= golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8= golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A= +golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI= +golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q= golang.org/x/exp v0.0.0-20260112195511-716be5621a96 h1:Z/6YuSHTLOHfNFdb8zVZomZr7cqNgTJvA8+Qz75D8gU= golang.org/x/exp v0.0.0-20260112195511-716be5621a96/go.mod h1:nzimsREAkjBCIEFtHiYkrJyT+2uy9YZJB7H1k68CXZU= +golang.org/x/exp v0.0.0-20260410095643-746e56fc9e2f h1:W3F4c+6OLc6H2lb//N1q4WpJkhzJCK5J6kUi1NTVXfM= +golang.org/x/exp v0.0.0-20260410095643-746e56fc9e2f/go.mod h1:J1xhfL/vlindoeF/aINzNzt2Bket5bjo9sdOYzOsU80= golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o= golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8= +golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA= +golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs= golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI= +golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE= golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8= +golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg= +golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= @@ -77,3 +105,5 @@ maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M= maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA= maunium.net/go/mautrix v0.26.3-0.20260119125818-e28f7170bc4b h1:OaZ5Y1l4XACFlgy4BmZcCLdYPJZzgZWqZJnpdSITmoM= maunium.net/go/mautrix v0.26.3-0.20260119125818-e28f7170bc4b/go.mod h1:CUxSZcjPtQNxsZLRQqETAxg2hiz7bjWT+L1HCYoMMKo= +maunium.net/go/mautrix v0.27.1-0.20260513120123-5fba7e3afae4 h1:zNC9eVAhw8FhKpM3AxNAh/iy75UEYX91uJUvqqAYlvo= +maunium.net/go/mautrix v0.27.1-0.20260513120123-5fba7e3afae4/go.mod h1:3sOGhXi3P1V6/NruTA0gujkvTypXVUraWktCuTGyDuM= From 903f484bf810d82aa89df12bd04e5785af13569d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 24 May 2026 19:56:16 +0200 Subject: [PATCH 46/46] wip --- pkg/connector/client.go | 142 +----------------------- pkg/connector/client_test.go | 207 ----------------------------------- pkg/connector/commands.go | 64 +++++++++++ pkg/connector/connector.go | 1 - 4 files changed, 65 insertions(+), 349 deletions(-) diff --git a/pkg/connector/client.go b/pkg/connector/client.go index 92c9b16..6ff8c0b 100644 --- a/pkg/connector/client.go +++ b/pkg/connector/client.go @@ -47,17 +47,13 @@ type aiRunSession struct { var _ bridgev2.NetworkAPI = (*DummyClient)(nil) var _ bridgev2.IdentifierResolvingNetworkAPI = (*DummyClient)(nil) -var _ bridgev2.ContactListingNetworkAPI = (*DummyClient)(nil) var _ bridgev2.BackfillingNetworkAPI = (*DummyClient)(nil) var _ bridgev2.DeleteChatHandlingNetworkAPI = (*DummyClient)(nil) var _ bridgev2.MessageRequestAcceptingNetworkAPI = (*DummyClient)(nil) var _ bridgev2.ReactionHandlingNetworkAPI = (*DummyClient)(nil) const ( - aiGhostID networkid.UserID = "ai" - aiGhostName string = "AI" - aiPortalIDPrefix string = "ai-" - dummyAIAgentName string = "Dummy" + dummyAIAgentName string = "Dummy" ) var delayedRemoteEchoPattern = regexp.MustCompile(`(?i)^remote-echo\s+delay\s+([0-9]+(?:ms|s|m|h))$`) @@ -149,14 +145,6 @@ func (dc *DummyClient) IsThisUser(ctx context.Context, userID networkid.UserID) } func (dc *DummyClient) GetChatInfo(ctx context.Context, portal *bridgev2.Portal) (*bridgev2.ChatInfo, error) { - if isAIPortalID(portal.ID) { - roomType := database.RoomTypeDM - return &bridgev2.ChatInfo{ - Name: ptr.Ptr(aiGhostName), - Type: ptr.Ptr(roomType), - }, nil - } - portalIDPrefix := string(portal.ID) if len(portalIDPrefix) > 6 { portalIDPrefix = portalIDPrefix[:6] @@ -183,17 +171,6 @@ func (dc *DummyClient) GetChatInfo(ctx context.Context, portal *bridgev2.Portal) } func (tc *DummyClient) GetUserInfo(ctx context.Context, ghost *bridgev2.Ghost) (*bridgev2.UserInfo, error) { - if ghost.ID == aiGhostID { - name := aiGhostName - isBot := true - ghost.UpdateName(ctx, name) - return &bridgev2.UserInfo{ - Identifiers: []string{string(aiGhostID), "AI"}, - Name: &name, - IsBot: &isBot, - }, nil - } - name := ghost.Name if name == "" { name = string(ghost.ID) @@ -250,10 +227,6 @@ func (dc *DummyClient) HandleMatrixMessage(ctx context.Context, msg *bridgev2.Ma StreamOrder: time.Now().UnixNano(), } - if msg.Portal != nil && (isAIPortalID(msg.Portal.ID) || isAIDemoCommandContent(msg.Content)) { - dc.queueAIResponse(ctx, msg.Portal, msg.Content) - } - return resp, nil } @@ -442,28 +415,6 @@ func getRemoteEchoBehavior(content *event.MessageEventContent) remoteEchoBehavio return remoteEchoBehavior{pending: true, delay: delay} } -func isAIDemoCommandContent(content *event.MessageEventContent) bool { - if content == nil { - return false - } - body := strings.TrimSpace(content.Body) - tokens := strings.Fields(body) - if len(tokens) == 0 { - return false - } - switch strings.ToLower(tokens[0]) { - case "help", "/help", "!help": - return true - case "stream", "stream-tools": - _, err := parseCommand(body) - return err == nil - case "dummybridge": - return len(tokens) > 1 && strings.EqualFold(tokens[1], "help") - default: - return false - } -} - // ensureAISenderInvited queues a ChatInfoChange that adds the AI sender ghost // to the given portal. The bridge's default portal generator can create // portals with members=0, in which case the per-portal AI sender chosen by @@ -474,9 +425,6 @@ func (dc *DummyClient) ensureAISenderInvited(portal *bridgev2.Portal, sender net if dc == nil || dc.UserLogin == nil || portal == nil || sender == "" { return } - if isAIPortalID(portal.ID) { - return - } changes := &bridgev2.ChatMemberList{MemberMap: bridgev2.ChatMemberMap{}} changes.MemberMap.Set(bridgev2.ChatMember{ EventSender: bridgev2.EventSender{Sender: sender}, @@ -502,16 +450,10 @@ func dummyAISenderForPortal(portal *bridgev2.Portal) networkid.UserID { if portal == nil { return networkid.UserID(dummyAIAgentName) } - if isAIPortalID(portal.ID) { - return aiGhostID - } return stablePortalUserIDByIndex(portal.ID, 0) } func dummyAIAgentNameForPortal(portal *bridgev2.Portal) string { - if portal != nil && isAIPortalID(portal.ID) { - return aiGhostName - } return dummyAIAgentName } @@ -1204,10 +1146,6 @@ func (dc *DummyClient) HandleMatrixAcceptMessageRequest(ctx context.Context, msg } func (dc *DummyClient) ResolveIdentifier(ctx context.Context, identifier string, createChat bool) (*bridgev2.ResolveIdentifierResponse, error) { - if isAIIdentifier(identifier) { - return dc.resolveAIIdentifier(ctx, createChat) - } - userID := networkid.UserID(identifier) portalID := randomPortalID() portalKey := networkid.PortalKey{ @@ -1256,81 +1194,3 @@ func (dc *DummyClient) ResolveIdentifier(ctx context.Context, identifier string, }, nil } - -func (dc *DummyClient) GetContactList(ctx context.Context) ([]*bridgev2.ResolveIdentifierResponse, error) { - contact, err := dc.resolveAIIdentifier(ctx, false) - if err != nil { - return nil, err - } - return []*bridgev2.ResolveIdentifierResponse{contact}, nil -} - -func (dc *DummyClient) resolveAIIdentifier(ctx context.Context, createChat bool) (*bridgev2.ResolveIdentifierResponse, error) { - ghost, err := dc.UserLogin.Bridge.GetGhostByID(ctx, aiGhostID) - if err != nil { - return nil, fmt.Errorf("failed to get AI ghost: %w", err) - } - userInfo, _ := dc.GetUserInfo(ctx, ghost) - response := &bridgev2.ResolveIdentifierResponse{ - Ghost: ghost, - UserID: aiGhostID, - UserInfo: userInfo, - } - if !createChat { - return response, nil - } - - portalID := networkid.PortalID(aiPortalIDPrefix + string(randomPortalID())) - portalKey := networkid.PortalKey{ID: portalID, Receiver: dc.UserLogin.ID} - portal, err := dc.UserLogin.Bridge.GetPortalByKey(ctx, portalKey) - if err != nil { - return nil, fmt.Errorf("failed to get AI portal: %w", err) - } - roomType := database.RoomTypeDM - response.Chat = &bridgev2.CreateChatResponse{ - Portal: portal, - PortalKey: portalKey, - PortalInfo: &bridgev2.ChatInfo{ - Name: ptr.Ptr(aiGhostName), - Topic: ptr.Ptr("DummyBridge AI chat"), - Type: ptr.Ptr(roomType), - CanBackfill: true, - Members: &bridgev2.ChatMemberList{ - MemberMap: bridgev2.ChatMemberMap{ - networkid.UserID(dc.UserLogin.ID): { - EventSender: bridgev2.EventSender{ - IsFromMe: true, - Sender: networkid.UserID(dc.UserLogin.ID), - }, - Membership: event.MembershipJoin, - PowerLevel: ptr.Ptr(100), - }, - aiGhostID: { - EventSender: bridgev2.EventSender{ - Sender: aiGhostID, - }, - Membership: event.MembershipJoin, - PowerLevel: ptr.Ptr(50), - MemberEventExtra: map[string]any{ - "displayname": aiGhostName, - "com.beeper.ai.agent": string(aiGhostID), - "com.beeper.ai.model_id": aistream.DefaultModel, - "com.beeper.ai.protocol": "ag-ui", - "com.beeper.ai.static_ai": true, - }, - }, - }, - }, - }, - } - return response, nil -} - -func isAIIdentifier(identifier string) bool { - identifier = strings.TrimSpace(identifier) - return strings.EqualFold(identifier, string(aiGhostID)) || strings.EqualFold(identifier, aiGhostName) -} - -func isAIPortalID(portalID networkid.PortalID) bool { - return strings.HasPrefix(string(portalID), aiPortalIDPrefix) -} diff --git a/pkg/connector/client_test.go b/pkg/connector/client_test.go index 3290326..ed2a445 100644 --- a/pkg/connector/client_test.go +++ b/pkg/connector/client_test.go @@ -1,18 +1,10 @@ package connector import ( - "context" - "encoding/json" "testing" "time" - "github.com/beeper/ai-bridge/pkg/ag-ui" - "github.com/beeper/ai-bridge/pkg/ai-stream" - "maunium.net/go/mautrix/bridgev2" - "maunium.net/go/mautrix/bridgev2/database" - "maunium.net/go/mautrix/bridgev2/networkid" "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" ) func TestGetRemoteEchoBehavior(t *testing.T) { @@ -45,202 +37,3 @@ func TestGetRemoteEchoBehavior(t *testing.T) { }) } } - -func TestAIDemoCommandContentOnlyMatchesExplicitDemoCommands(t *testing.T) { - for _, body := range []string{ - "help", - "/help", - "!help", - "dummybridge help", - "stream 20", - "stream-tools 100 shell", - "stream 1 --runs=2", - } { - if !isAIDemoCommandContent(&event.MessageEventContent{Body: body}) { - t.Fatalf("expected AI demo command for %q", body) - } - } - for _, body := range []string{ - "", - "hello", - "dummybridge", - "remote-echo delay 1s", - } { - if isAIDemoCommandContent(&event.MessageEventContent{Body: body}) { - t.Fatalf("did not expect AI demo command for %q", body) - } - } -} - -func TestDummyAISenderForPortalSupportsDedicatedAndNormalRooms(t *testing.T) { - if got := dummyAISenderForPortal(&bridgev2.Portal{Portal: &database.Portal{PortalKey: networkid.PortalKey{ID: "ai-room"}}}); got != aiGhostID { - t.Fatalf("AI portal sender = %q, want %q", got, aiGhostID) - } - if got := dummyAISenderForPortal(&bridgev2.Portal{Portal: &database.Portal{PortalKey: networkid.PortalKey{ID: "normal-room"}}}); got != stablePortalUserIDByIndex("normal-room", 0) { - t.Fatalf("normal portal sender = %q", got) - } -} - -func TestResolveApprovalOnceKeepsFirstSelection(t *testing.T) { - client := &DummyClient{} - selected, first := client.resolveApprovalOnce("approval-1", "allow") - if !first || selected != "allow" { - t.Fatalf("first selection = %q first=%v", selected, first) - } - selected, first = client.resolveApprovalOnce("approval-1", "deny") - if first || selected != "allow" { - t.Fatalf("second selection = %q first=%v", selected, first) - } -} - -func TestInitialAIAnchorRunOmitsPreviewAndTerminalMetadata(t *testing.T) { - run := aistream.NewRun("run-1", "thread-1", aistream.DefaultModel, "ai", "AI", time.Unix(10, 0)) - writer := aistream.NewWriter(run, func() time.Time { return time.Unix(10, 0) }) - writer.Start() - writer.Text("visible preview") - writer.Finish(agui.FinishReasonStop) - - anchor := initialAIAnchorRun(*run) - if anchor.Preview.Text != "" { - t.Fatalf("anchor should not include initial preview text: %#v", anchor.Preview) - } - uiMessage := anchor.InitialUIMessage() - if len(uiMessage.Parts) != 0 { - t.Fatalf("anchor UI message should wait for stream deltas: %#v", uiMessage.Parts) - } - if uiMessage.Metadata["runId"] != run.RunID { - t.Fatalf("anchor UI metadata missing run id: %#v", uiMessage.Metadata) - } - if anchor.Status.State != "streaming" { - t.Fatalf("anchor status = %#v, want streaming", anchor.Status) - } - if anchor.Usage.TotalTokens != 0 || anchor.Usage.CompletionTokens != 0 || anchor.Usage.PromptTokens != 0 { - t.Fatalf("anchor leaked terminal usage: %#v", anchor.Usage) - } - if run.Status.State != "complete" || run.Usage.TotalTokens == 0 { - t.Fatalf("final run should keep terminal metadata: status=%#v usage=%#v", run.Status, run.Usage) - } -} - -func TestCarrierTimestampUsesEventOffsetFromRunStart(t *testing.T) { - run := aistream.Run{ - Events: []agui.Event{ - {"timestamp": int64(10_000), "type": agui.EventRunStarted, "threadId": "thread-1"}, - {"timestamp": int64(13_500), "type": agui.EventTextMessageContent, "messageId": "msg-1", "delta": "later"}, - }, - } - streamStart := time.Unix(100, 0) - target := carrierTimestamp(run, aistream.Carrier{Envelopes: []aistream.Envelope{{ - Part: run.Events[1], - }}}, streamStart) - if want := streamStart.Add(3500 * time.Millisecond); !target.Equal(want) { - t.Fatalf("target = %s, want %s", target, want) - } -} - -func TestSplitCarriersForTimedEmissionKeepsOneEnvelopePerCarrier(t *testing.T) { - carriers := splitCarriersForTimedEmission([]aistream.Carrier{{ - Envelopes: []aistream.Envelope{ - {Seq: 1}, - {Seq: 2}, - }, - }}) - if len(carriers) != 2 { - t.Fatalf("carrier count = %d, want 2", len(carriers)) - } - if carriers[0].Envelopes[0].Seq != 1 || carriers[1].Envelopes[0].Seq != 2 { - t.Fatalf("bad split carriers: %#v", carriers) - } -} - -func TestApprovalContextForMessageFallsBackToStoredMessage(t *testing.T) { - want := aistream.ApprovalContext{ - ID: "approval-1", - ThreadID: "thread-1", - RunID: "run-1", - MessageID: "msg-1", - Command: "stream-tools 120 shell#approval", - ToolCallID: "tool-1", - TargetEvent: "$event", - SeqStart: 12, - } - stub := &database.Message{ID: "approval-1"} - rawMetadata, err := json.Marshal(map[string]any{"com.beeper.ai.approval": want}) - if err != nil { - t.Fatal(err) - } - fetched := &database.Message{ID: "approval-1", Metadata: rawMetadata} - called := false - - got, ok := approvalContextForMessage(context.Background(), stub, func(_ context.Context, messageID networkid.MessageID) (*database.Message, error) { - called = true - if messageID != stub.ID { - t.Fatalf("fetch message ID = %q, want %q", messageID, stub.ID) - } - return fetched, nil - }) - if !ok { - t.Fatal("expected approval context") - } - if !called { - t.Fatal("expected fallback fetch") - } - if got.ID != want.ID || got.RunID != want.RunID || got.TargetEvent != want.TargetEvent || got.SeqStart != want.SeqStart { - t.Fatalf("approval context = %#v, want %#v", got, want) - } -} - -func TestApprovalOptionReactionIsBridgeManagedFallback(t *testing.T) { - msg := &bridgev2.MatrixReaction{ - MatrixEventBase: bridgev2.MatrixEventBase[*event.ReactionEventContent]{ - Event: &event.Event{Content: event.Content{Raw: map[string]any{ - "com.beeper.ai.approval_option": map[string]any{"choice": "approve"}, - }}}, - }, - } - if !isApprovalOptionReaction(msg) { - t.Fatal("expected managed approval option reaction") - } - if isApprovalOptionReaction(&bridgev2.MatrixReaction{MatrixEventBase: bridgev2.MatrixEventBase[*event.ReactionEventContent]{Event: &event.Event{Content: event.Content{Raw: map[string]any{}}}}}) { - t.Fatal("plain user reaction must not be treated as a managed approval option") - } -} - -func TestAnnotateApprovalEventIDsAddsReactionTargetEventToStreamPrompt(t *testing.T) { - run := aistream.NewRun("run-1", "thread-1", aistream.DefaultModel, "ai", "AI", time.Unix(10, 0)) - writer := aistream.NewWriter(run, func() time.Time { return time.Unix(10, 0) }) - writer.ToolApprovalRequested("tool-1", "shell", map[string]any{"command": "ls"}, agui.ToolApproval{ - ID: "approval-1", - NeedsApproval: true, - }) - - annotateApprovalEventIDs(run, map[string]id.EventID{ - "approval-1": "$approval", - }) - - for _, evt := range run.Events { - if evt["type"] != agui.EventCustom || evt["name"] != agui.ApprovalCustomRequested { - continue - } - value, _ := evt["value"].(map[string]any) - if value["approvalMessageId"] != "approval-1" || value["approvalEventId"] != "$approval" { - t.Fatalf("approval stream event missing target ids: %#v", value) - } - return - } - t.Fatal("missing approval-requested event") -} - -func TestApprovalDecisionsAreStoredInRunSession(t *testing.T) { - client := &DummyClient{} - first := agui.ToolApprovalResponse{ID: "approval-1", Approved: true} - decisions := client.recordAIApprovalDecision("run-1", first) - if len(decisions) != 1 || !decisions["approval-1"].Approved { - t.Fatalf("bad first decisions: %#v", decisions) - } - second := agui.ToolApprovalResponse{ID: "approval-2", Approved: false, Reason: "denied"} - decisions = client.recordAIApprovalDecision("run-1", second) - if len(decisions) != 2 || !decisions["approval-1"].Approved || decisions["approval-2"].Reason != "denied" { - t.Fatalf("bad accumulated decisions: %#v", decisions) - } -} diff --git a/pkg/connector/commands.go b/pkg/connector/commands.go index 630a004..e3934e1 100644 --- a/pkg/connector/commands.go +++ b/pkg/connector/commands.go @@ -29,6 +29,9 @@ var AllCommands = []commands.CommandHandler{ MessagesCommand, KickMeCommand, FileCommand, + StreamCommand, + StreamToolsCommand, + StreamHelpCommand, CatCommand, CatAvatarCommand, } @@ -264,6 +267,67 @@ var FileCommand = &commands.FullHandler{ }, } +func runStreamCommand(e *commands.Event, name string) { + if e.Portal == nil { + e.Reply("Can only stream within a portal") + return + } + login := e.User.GetDefaultLogin() + if login == nil { + e.Reply("No login") + return + } + client, ok := login.Client.(*DummyClient) + if !ok || client == nil { + e.Reply("Default login is not a dummybridge login") + return + } + body := strings.TrimSpace(name + " " + e.RawArgs) + if _, err := parseCommand(body); err != nil { + e.Reply(err.Error()) + return + } + client.queueAIResponse(e.Ctx, e.Portal, &event.MessageEventContent{Body: body}) + e.Reply("Started %s", name) +} + +var StreamCommand = &commands.FullHandler{ + Func: func(e *commands.Event) { + runStreamCommand(e, "stream") + }, + Name: "stream", + Help: commands.HelpMeta{ + Description: "Generate a random streamed AI event sequence", + Args: "[seconds] [--runs=N] [--profile=balanced|tools|errors|artifacts] [--seed=N] [--chars=N] [--terminal=stop|length|abort|error] [--delay-ms=min:max] [--stagger-ms=min:max] [--actions=N] [--no-approval] [--allow-abort] [--allow-error]", + Section: DummyHelpsection, + }, + RequiresLogin: true, +} + +var StreamToolsCommand = &commands.FullHandler{ + Func: func(e *commands.Event) { + runStreamCommand(e, "stream-tools") + }, + Name: "stream-tools", + Help: commands.HelpMeta{ + Description: "Generate a streamed AI event sequence with explicit tool calls", + Args: " ... [common options]", + Section: DummyHelpsection, + }, + RequiresLogin: true, +} + +var StreamHelpCommand = &commands.FullHandler{ + Func: func(e *commands.Event) { + e.Reply(helpText()) + }, + Name: "stream-help", + Help: commands.HelpMeta{ + Description: "Show stream command examples", + Section: DummyHelpsection, + }, +} + var catpions []string = []string{ "You’ve cat to be kitten me!", "I’m feline fine!", diff --git a/pkg/connector/connector.go b/pkg/connector/connector.go index 3adf8fa..b65c6ec 100644 --- a/pkg/connector/connector.go +++ b/pkg/connector/connector.go @@ -49,7 +49,6 @@ func (dc *DummyConnector) GetCapabilities() *bridgev2.NetworkGeneralCapabilities ResolveIdentifier: bridgev2.ResolveIdentifierCapabilities{ CreateDM: true, LookupUsername: true, - ContactList: true, }, }, }