From 086339a8a8b839d95759ad24ce39d059084b8ea0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Thu, 28 May 2026 19:55:30 +0200 Subject: [PATCH 1/7] Use aistream approval types and adapt tests Switch AG-UI approval tooling to use aistream types and interrupt-based outcomes. Replaced agui.ToolApproval/ToolApprovalResponse usages with aistream equivalents across runner, plans, types, and client code, and call writer.Interrupt() when approval is requested. Updated tests to expect TOOL_CALL_RESULT and RUN_FINISHED interrupts (and to check interrupt metadata/choices) instead of custom approval events; added helpers (firstInterrupts, eventInterrupts, approvalChoicesFromMetadata, stringFromAny) to normalize interrupt/choice payloads. Adjusted annotateApprovalEventIDs to annotate interrupt outcomes, and updated buildAIApprovalContinuationRunWithApprovals call sites. Also updated go.mod/go.sum (dependency bumps and local replace for ai-bridge). --- go.mod | 14 +- go.sum | 64 ++----- pkg/connector/ai_plans.go | 3 +- pkg/connector/ai_runner.go | 4 +- pkg/connector/ai_runtime_test.go | 289 +++++++++++++++++++++---------- pkg/connector/ai_types.go | 3 +- pkg/connector/client.go | 139 +++++++++++---- 7 files changed, 323 insertions(+), 193 deletions(-) diff --git a/go.mod b/go.mod index 02c9550..0a5a442 100644 --- a/go.mod +++ b/go.mod @@ -7,10 +7,12 @@ toolchain go1.25.6 require ( github.com/beeper/ai-bridge v0.0.0-20260524021151-5c8086351a72 github.com/rs/zerolog v1.35.1 - go.mau.fi/util v0.9.9-0.20260511124621-9241e81bdf25 + go.mau.fi/util v0.9.9 maunium.net/go/mautrix v0.27.1-0.20260513120123-5fba7e3afae4 ) +replace github.com/beeper/ai-bridge => ../ai-bridge + require ( filippo.io/edwards25519 v1.2.0 // indirect github.com/coder/websocket v1.8.14 // indirect @@ -28,12 +30,12 @@ require ( github.com/tidwall/sjson v1.2.5 // indirect github.com/yuin/goldmark v1.8.2 // indirect go.mau.fi/zeroconfig v0.2.0 // indirect - golang.org/x/crypto v0.50.0 // indirect - golang.org/x/exp v0.0.0-20260410095643-746e56fc9e2f // indirect - golang.org/x/net v0.53.0 // indirect + golang.org/x/crypto v0.51.0 // indirect + golang.org/x/exp v0.0.0-20260508232706-74f9aab9d74a // indirect + golang.org/x/net v0.54.0 // indirect golang.org/x/sync v0.20.0 // indirect - golang.org/x/sys v0.43.0 // indirect - golang.org/x/text v0.36.0 // indirect + golang.org/x/sys v0.44.0 // indirect + golang.org/x/text v0.37.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect maunium.net/go/mauflag v1.0.0 // indirect diff --git a/go.sum b/go.sum index 7ec830b..22abad4 100644 --- a/go.sum +++ b/go.sum @@ -1,47 +1,27 @@ -filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= -filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= filippo.io/edwards25519 v1.2.0 h1:crnVqOiS4jqYleHd9vaKZ+HKtHfllngJIiOpNpoJsjo= filippo.io/edwards25519 v1.2.0/go.mod h1:xzAOLCNug/yB62zG1bQ8uziwrIqIuxhctzJT18Q77mc= github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= -github.com/beeper/ai-bridge v0.0.0-20260524021151-5c8086351a72 h1:Pw2qyz5mizv/UL4JTKiK1sbYfUl6o8dk/KcNyFlSFG0= -github.com/beeper/ai-bridge v0.0.0-20260524021151-5c8086351a72/go.mod h1:Uf2M1ogzy7VGB6uUzzHjZL2eaYt79DK0Py8I6xZl3r0= github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g= github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg= -github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= -github.com/coreos/go-systemd/v22 v22.6.0 h1:aGVa/v8B7hpb0TKl0MWoAavPDmHvobFe5R5zn0bCJWo= -github.com/coreos/go-systemd/v22 v22.6.0/go.mod h1:iG+pp635Fo7ZmV/j14KUcmEyWF+0X7Lua8rrTWzYgWU= github.com/coreos/go-systemd/v22 v22.7.0 h1:LAEzFkke61DFROc7zNLX/WA2i5J8gYqe0rSj9KI28KA= github.com/coreos/go-systemd/v22 v22.7.0/go.mod h1:xNUYtjHu2EDXbsxz1i41wouACIwT7Ybq9o0BQhMwD0w= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= -github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= -github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lib/pq v1.12.3 h1:tTWxr2YLKwIvK90ZXEw8GP7UFHtcbTtty8zsI+YjrfQ= github.com/lib/pq v1.12.3/go.mod h1:/p+8NSbOcwzAEI7wiMXFlgydTwcgTr3OSKMsD2BitpA= -github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= -github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= -github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-sqlite3 v1.14.33 h1:A5blZ5ulQo2AtayQ9/limgHEkFreKj1Dv226a1K73s0= -github.com/mattn/go-sqlite3 v1.14.33/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/mattn/go-sqlite3 v1.14.44 h1:3VSe+xafpbzsLbdr2AWlAZk9yRHiBhTBakioXaCKTF8= github.com/mattn/go-sqlite3 v1.14.44/go.mod h1:pjEuOr8IwzLJP2MfGeTb0A35jauH+C2kbHKBr7yXKVQ= -github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741 h1:KPpdlQLZcHfTMQRi6bFQ7ogNO0ltFT4PmtwTLW4W+14= -github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= github.com/petermattis/goid v0.0.0-20260330135022-df67b199bc81 h1:WDsQxOJDy0N1VRAjXLpi8sCEZRSGarLWQevDxpTBRrM= github.com/petermattis/goid v0.0.0-20260330135022-df67b199bc81/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= -github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU= github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= -github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= -github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= github.com/rs/zerolog v1.35.1 h1:m7xQeoiLIiV0BCEY4Hs+j2NG4Gp2o2KPKmhnnLiazKI= github.com/rs/zerolog v1.35.1/go.mod h1:EjML9kdfa/RMA7h/6z6pYmq1ykOuA8/mjWaEvGI+jcw= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0= @@ -58,43 +38,25 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= -github.com/yuin/goldmark v1.7.16 h1:n+CJdUxaFMiDUNnWC3dMWCIQJSkxH4uz3ZwQBkAlVNE= -github.com/yuin/goldmark v1.7.16/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= github.com/yuin/goldmark v1.8.2 h1:kEGpgqJXdgbkhcOgBxkC0X0PmoPG1ZyoZ117rDVp4zE= github.com/yuin/goldmark v1.8.2/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= -go.mau.fi/util v0.9.5 h1:7AoWPCIZJGv4jvtFEuCe3GhAbI7uF9ckIooaXvwlIR4= -go.mau.fi/util v0.9.5/go.mod h1:g1uvZ03VQhtTt2BgaRGVytS/Zj67NV0YNIECch0sQCQ= -go.mau.fi/util v0.9.9-0.20260511124621-9241e81bdf25 h1:YPEmc+li7TF6C9AdRTcSLMb6yCHdF27/wNT7kFLIVNg= -go.mau.fi/util v0.9.9-0.20260511124621-9241e81bdf25/go.mod h1:jE9FfhbgEgAwxei6lomO9v8zdCIATcquONUu4vjRwSs= +go.mau.fi/util v0.9.9 h1:ujDeXCo07HBor5oQLyO1tHklupmqVmPgasc53d7q/NE= +go.mau.fi/util v0.9.9/go.mod h1:pqt4Vcrt+5gcH/CgrHZg11qSx+b34o6mknGzOEA6waY= go.mau.fi/zeroconfig v0.2.0 h1:e/OGEERqVRRKlgaro7E6bh8xXiKFSXB3eNNIud7FUjU= go.mau.fi/zeroconfig v0.2.0/go.mod h1:J0Vn0prHNOm493oZoQ84kq83ZaNCYZnq+noI1b1eN8w= -golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8= -golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A= -golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI= -golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q= -golang.org/x/exp v0.0.0-20260112195511-716be5621a96 h1:Z/6YuSHTLOHfNFdb8zVZomZr7cqNgTJvA8+Qz75D8gU= -golang.org/x/exp v0.0.0-20260112195511-716be5621a96/go.mod h1:nzimsREAkjBCIEFtHiYkrJyT+2uy9YZJB7H1k68CXZU= -golang.org/x/exp v0.0.0-20260410095643-746e56fc9e2f h1:W3F4c+6OLc6H2lb//N1q4WpJkhzJCK5J6kUi1NTVXfM= -golang.org/x/exp v0.0.0-20260410095643-746e56fc9e2f/go.mod h1:J1xhfL/vlindoeF/aINzNzt2Bket5bjo9sdOYzOsU80= -golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o= -golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8= -golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA= -golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs= -golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= -golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/crypto v0.51.0 h1:IBPXwPfKxY7cWQZ38ZCIRPI50YLeevDLlLnyC5wRGTI= +golang.org/x/crypto v0.51.0/go.mod h1:8AdwkbraGNABw2kOX6YFPs3WM22XqI4EXEd8g+x7Oc8= +golang.org/x/exp v0.0.0-20260508232706-74f9aab9d74a h1:+3jdDGGB8NGb1Zktc737jlt3/A5f6UlwSzmvqUuufxw= +golang.org/x/exp v0.0.0-20260508232706-74f9aab9d74a/go.mod h1:d2fgXJLVs4dYDHUk5lwMIfzRzSrWCfGZb0ZqeLa/Vcw= +golang.org/x/net v0.54.0 h1:2zJIZAxAHV/OHCDTCOHAYehQzLfSXuf/5SoL/Dv6w/w= +golang.org/x/net v0.54.0/go.mod h1:Sj4oj8jK6XmHpBZU/zWHw3BV3abl4Kvi+Ut7cQcY+cQ= golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= -golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= -golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI= -golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= -golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE= -golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8= -golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg= -golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164= +golang.org/x/sys v0.44.0 h1:ildZl3J4uzeKP07r2F++Op7E9B29JRUy+a27EibtBTQ= +golang.org/x/sys v0.44.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc= +golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= @@ -103,7 +65,5 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M= maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA= -maunium.net/go/mautrix v0.26.3-0.20260119125818-e28f7170bc4b h1:OaZ5Y1l4XACFlgy4BmZcCLdYPJZzgZWqZJnpdSITmoM= -maunium.net/go/mautrix v0.26.3-0.20260119125818-e28f7170bc4b/go.mod h1:CUxSZcjPtQNxsZLRQqETAxg2hiz7bjWT+L1HCYoMMKo= maunium.net/go/mautrix v0.27.1-0.20260513120123-5fba7e3afae4 h1:zNC9eVAhw8FhKpM3AxNAh/iy75UEYX91uJUvqqAYlvo= maunium.net/go/mautrix v0.27.1-0.20260513120123-5fba7e3afae4/go.mod h1:3sOGhXi3P1V6/NruTA0gujkvTypXVUraWktCuTGyDuM= diff --git a/pkg/connector/ai_plans.go b/pkg/connector/ai_plans.go index b96987b..4d7d736 100644 --- a/pkg/connector/ai_plans.go +++ b/pkg/connector/ai_plans.go @@ -67,7 +67,7 @@ func hasSeedFlag(input string) bool { return false } -func buildAIRunFromCommandWithApprovals(ctx context.Context, runID, threadID string, now time.Time, cmd *parsedCommand, agentID, agentName string, approvals map[string]agui.ToolApprovalResponse) (*aistream.Run, error) { +func buildAIRunFromCommandWithApprovals(ctx context.Context, runID, threadID string, now time.Time, cmd *parsedCommand, agentID, agentName string, approvals map[string]aistream.ToolApprovalResponse) (*aistream.Run, error) { runtime := virtualAIRuntime(now) run := aistream.NewRun(runID, threadID, aistream.DefaultModel, agentID, agentName, now) writer := aistream.NewWriter(run, runtime.now) @@ -87,6 +87,7 @@ func buildAIRunFromCommandWithApprovals(ctx context.Context, runID, threadID str err = runner.runRandom(ctx, writer, *cmd.Random) } if errors.Is(err, errApprovalRequested) { + writer.Interrupt() err = nil } if err != nil { diff --git a/pkg/connector/ai_runner.go b/pkg/connector/ai_runner.go index 6fc497c..0c0e2e7 100644 --- a/pkg/connector/ai_runner.go +++ b/pkg/connector/ai_runner.go @@ -204,9 +204,9 @@ func (r aiRunner) runToolSpec(ctx context.Context, w *aistream.Writer, spec tool toolCallID := fmt.Sprintf("dummy-tool-%d-%s", spec.SequenceIndex, sanitizeToolName(spec.Name)) input := toolRequestInput(spec) approvalID := approvalIDForRun(w.Run.RunID, toolCallID) - var approval *agui.ToolApproval + var approval *aistream.ToolApproval if spec.Approval { - approval = &agui.ToolApproval{ID: approvalID, NeedsApproval: true} + approval = &aistream.ToolApproval{ID: approvalID, NeedsApproval: true} } displayMetadata := toolDisplayMetadata(spec.Name) w.ToolStartWithMetadata(toolCallID, spec.Name, spec.SequenceIndex-1, approval, displayMetadata) diff --git a/pkg/connector/ai_runtime_test.go b/pkg/connector/ai_runtime_test.go index a94d520..690e1b2 100644 --- a/pkg/connector/ai_runtime_test.go +++ b/pkg/connector/ai_runtime_test.go @@ -117,18 +117,15 @@ func TestBuildAIRunToolsApprovalUsesAGUIApprovalAndPrompt(t *testing.T) { t.Fatalf("approval prompt ID = %q, want run-scoped ID", run.Prompts[0].ID) } foundToolStart := false - seenApprovalStateBeforeCustom := false + seenToolCallEndBeforeInterrupt := false + seenInterrupt := false for _, evt := range run.Events { if evt["type"] == agui.EventToolCallStart { - if evt["state"] != agui.ToolStateApprovalRequested { - t.Fatalf("expected approval-requested tool state, got %#v", evt) + if evt["state"] != agui.ToolStateAwaitingInput { + t.Fatalf("tool start should stay a normal AG-UI tool call, got %#v", evt) } - approval, ok := evt["approval"].(*agui.ToolApproval) - if !ok { - t.Fatalf("expected tool start approval metadata, got %#v", evt["approval"]) - } - if approval.ID != "approval-run-1-dummy-tool-1-shell" || !approval.NeedsApproval { - t.Fatalf("bad approval metadata: %#v", approval) + if _, ok := evt["approval"]; ok { + t.Fatalf("tool start must not carry Beeper approval metadata: %#v", evt) } metadata, ok := evt["metadata"].(map[string]any) if !ok || metadata["displayName"] != "Run Command" { @@ -137,50 +134,51 @@ func TestBuildAIRunToolsApprovalUsesAGUIApprovalAndPrompt(t *testing.T) { foundToolStart = true } if evt["type"] == agui.EventToolCallEnd { - if evt["state"] == agui.ToolStateInputComplete { - t.Fatalf("approval tool must not downgrade to input-complete: %#v", evt) + if evt["state"] != agui.ToolStateInputComplete { + t.Fatalf("tool call should finish normally before AG-UI interrupt: %#v", evt) } - if evt["state"] == agui.ToolStateApprovalRequested { - if evt["input"] != nil { - t.Fatalf("approval input-complete event should omit placeholder input: %#v", evt) - } - seenApprovalStateBeforeCustom = true + if evt["input"] != nil { + t.Fatalf("approval input-complete event should omit placeholder input: %#v", evt) } + seenToolCallEndBeforeInterrupt = true + } + if evt["type"] == agui.EventCustom { + t.Fatalf("approval must use AG-UI interrupt outcome, not custom event: %#v", evt) } - if evt["type"] == agui.EventCustom && evt["name"] == agui.ApprovalCustomRequested { - if !seenApprovalStateBeforeCustom { - t.Fatalf("approval custom event should be emitted after approval state update: %#v", run.Events) + if evt["type"] == agui.EventRunFinished { + if !seenToolCallEndBeforeInterrupt { + t.Fatalf("approval interrupt should be emitted after approval state update: %#v", run.Events) } - value := evt["value"].(map[string]any) - if _, hasOptions := value["options"]; hasOptions { - t.Fatalf("AG-UI approval event must not embed Matrix reaction options: %#v", value) + interrupts := eventInterrupts(t, evt) + if len(interrupts) != 1 { + t.Fatalf("approval run should finish with one interrupt: %#v", evt) } - if value["approvalMessageId"] != "approval-run-1-dummy-tool-1-shell" { - t.Fatalf("approval event should name the Matrix reaction target: %#v", value) + interrupt := interrupts[0] + if interrupt.ID != "approval-run-1-dummy-tool-1-shell" || interrupt.Reason != agui.InterruptReasonToolCall || interrupt.ToolCallID != "dummy-tool-1-shell" { + t.Fatalf("bad approval interrupt: %#v", interrupt) } - metadata, ok := value["metadata"].(map[string]any) + metadata, ok := interrupt.Metadata["metadata"].(map[string]any) if !ok || metadata["displayName"] != "Run Command" { - t.Fatalf("approval event should carry tool display metadata: %#v", value["metadata"]) + t.Fatalf("approval interrupt should carry tool display metadata: %#v", interrupt.Metadata) } - choices, ok := value["choices"].([]aistream.ApprovalChoice) - if !ok || len(choices) == 0 || choices[0].Key != aistream.ApprovalChoiceApprove { - t.Fatalf("approval event should duplicate renderer choices: %#v", value["choices"]) + choices := approvalChoicesFromMetadata(t, interrupt.Metadata) + if len(choices) == 0 || choices[0].Key != aistream.ApprovalChoiceApprove { + t.Fatalf("approval interrupt should duplicate renderer choices: %#v", interrupt.Metadata["choices"]) } - if value["input"] != nil { - t.Fatalf("approval event should omit placeholder tool input: %#v", value) + if interrupt.Metadata["input"] != nil { + t.Fatalf("approval interrupt should omit placeholder tool input: %#v", interrupt.Metadata) } + seenInterrupt = true } } if !foundToolStart { t.Fatal("missing tool start event") } - if run.Status.State != "streaming" { - t.Fatalf("approval request should pause the run without terminal status, got %#v", run.Status) + if run.Status.State != "interrupted" { + t.Fatalf("approval request should interrupt the run, got %#v", run.Status) } - for _, evt := range run.Events { - if evt["type"] == agui.EventRunFinished { - t.Fatalf("approval request should not finish the run before response: %#v", run.Events) - } + if !seenInterrupt { + t.Fatalf("approval request missing AG-UI interrupt outcome: %#v", run.Events) } } @@ -228,7 +226,7 @@ func TestApprovalPromptSeqStartsAtNextPackedCarrierSeq(t *testing.T) { AgentName: run.AgentName, SeqStart: prompt.SeqStart, } - continuation, err := buildAIApprovalContinuationRunWithApprovals(context.Background(), approvalCtx, map[string]agui.ToolApprovalResponse{approvalCtx.ID: { + continuation, err := buildAIApprovalContinuationRunWithApprovals(context.Background(), approvalCtx, map[string]aistream.ToolApprovalResponse{approvalCtx.ID: { ID: prompt.ID, Approved: true, }}, time.Unix(20, 0)) @@ -293,35 +291,39 @@ func TestApprovalLifecycleCarriesNoticeTargetAndContinuation(t *testing.T) { if err != nil { t.Fatal(err) } - var annotatedValue map[string]any + var annotatedInterrupt *agui.Interrupt for _, carrier := range annotatedCarriers { for _, env := range carrier.Envelopes { - if env.Part["type"] != agui.EventCustom || env.Part["name"] != agui.ApprovalCustomRequested { + if env.Part["type"] != agui.EventRunFinished { continue } - annotatedValue, _ = env.Part["value"].(map[string]any) + interrupts := eventInterrupts(t, env.Part) + if len(interrupts) > 0 { + interrupt := interrupts[0] + annotatedInterrupt = &interrupt + } } } - if annotatedValue == nil || annotatedValue["approvalMessageId"] != prompt.ID { - t.Fatalf("approval-requested stream event missing approval message id: %#v", annotatedValue) + if annotatedInterrupt == nil || annotatedInterrupt.Metadata["approvalMessageId"] != prompt.ID { + t.Fatalf("approval interrupt missing approval message id: %#v", annotatedInterrupt) } - if annotatedValue["approvalEventId"] != "$approval" { - t.Fatalf("approval-requested stream event missing Matrix event target: %#v", annotatedValue) + if annotatedInterrupt.Metadata["approvalEventId"] != "$approval" { + t.Fatalf("approval interrupt missing Matrix event target: %#v", annotatedInterrupt) } annotatedCarriers = splitCarriersForTimedEmission(annotatedCarriers) if annotatedNextSeq := aistream.NextSeq(annotatedCarriers); annotatedNextSeq != nextSeq { t.Fatalf("approval event target changed stream sequence: initial=%d annotated=%d", nextSeq, annotatedNextSeq) } - choices, ok := annotatedValue["choices"].([]any) - if !ok || len(choices) != len(aistream.DefaultApprovalChoices()) { - t.Fatalf("approval-requested stream event missing choices: %#v", annotatedValue["choices"]) + choices := approvalChoicesFromMetadata(t, annotatedInterrupt.Metadata) + if len(choices) != len(aistream.DefaultApprovalChoices()) { + t.Fatalf("approval interrupt missing choices: %#v", annotatedInterrupt.Metadata["choices"]) } - firstChoice, ok := choices[0].(map[string]any) - if !ok || firstChoice["key"] != aistream.ApprovalChoiceApprove || firstChoice["label"] != "Allow once" { + firstChoice := choices[0] + if firstChoice.Key != aistream.ApprovalChoiceApprove || firstChoice.Label != "Allow once" { t.Fatalf("approval-requested stream event has bad choice shape: %#v", choices[0]) } - continuation, err := buildAIApprovalContinuationRunWithApprovals(context.Background(), approvalCtx, map[string]agui.ToolApprovalResponse{approvalCtx.ID: { + continuation, err := buildAIApprovalContinuationRunWithApprovals(context.Background(), approvalCtx, map[string]aistream.ToolApprovalResponse{approvalCtx.ID: { ID: prompt.ID, Approved: true, }}, time.Unix(20, 0)) @@ -341,8 +343,8 @@ func TestApprovalLifecycleCarriesNoticeTargetAndContinuation(t *testing.T) { if len(continuationCarriers) == 0 || len(continuationCarriers[0].Envelopes) == 0 || continuationCarriers[0].Envelopes[0].Seq != nextSeq { t.Fatalf("continuation should resume at seq %d, got %#v", nextSeq, continuationCarriers) } - if continuation.Events[0]["type"] != agui.EventCustom || continuation.Events[0]["name"] != agui.ApprovalCustomResponded { - t.Fatalf("continuation must start by acknowledging approval: %#v", continuation.Events) + if continuation.Events[0]["type"] != agui.EventToolCallResult || toolResultApprovalID(continuation.Events[0]) != prompt.ID { + t.Fatalf("continuation must start with approval tool result: %#v", continuation.Events) } } @@ -361,7 +363,7 @@ func TestApprovalContinuationResumesOriginalRunAfterApprovedTool(t *testing.T) { AgentName: "AI", SeqStart: 12, } - run, err := buildAIApprovalContinuationRunWithApprovals(context.Background(), approvalCtx, map[string]agui.ToolApprovalResponse{approvalCtx.ID: { + run, err := buildAIApprovalContinuationRunWithApprovals(context.Background(), approvalCtx, map[string]aistream.ToolApprovalResponse{approvalCtx.ID: { ID: approvalCtx.ID, Approved: true, }}, time.Unix(20, 0)) @@ -371,19 +373,16 @@ func TestApprovalContinuationResumesOriginalRunAfterApprovedTool(t *testing.T) { if len(run.Events) == 0 { t.Fatal("expected continuation events") } - if run.Events[0]["type"] != agui.EventCustom || run.Events[0]["name"] != agui.ApprovalCustomResponded { - t.Fatalf("first continuation event should acknowledge approval, got %#v", run.Events[0]) + if run.Events[0]["type"] != agui.EventToolCallResult || toolResultApprovalID(run.Events[0]) != approvalCtx.ID { + t.Fatalf("first continuation event should be approval tool result, got %#v", run.Events[0]) } seenApprovedTool := false seenLaterTool := false seenFinished := false for _, evt := range run.Events { - if evt["type"] == agui.EventToolCallEnd && evt["toolCallId"] == approvalCtx.ToolCallID { - if evt["state"] == agui.ToolStateApprovalResponded { - result := jsonResultMap(t, evt["result"]) - if result["approved"] != true { - t.Fatalf("approved result missing approval state: %#v", result) - } + if evt["type"] == agui.EventToolCallResult && evt["toolCallId"] == approvalCtx.ToolCallID { + result := jsonResultMap(t, evt["content"]) + if result["approved"] == true { seenApprovedTool = true } } @@ -420,7 +419,7 @@ func TestApprovalContinuationStopsOriginalRunAfterDeniedTool(t *testing.T) { AgentName: "AI", SeqStart: 12, } - run, err := buildAIApprovalContinuationRunWithApprovals(context.Background(), approvalCtx, map[string]agui.ToolApprovalResponse{approvalCtx.ID: { + run, err := buildAIApprovalContinuationRunWithApprovals(context.Background(), approvalCtx, map[string]aistream.ToolApprovalResponse{approvalCtx.ID: { ID: approvalCtx.ID, Approved: false, Reason: "denied", @@ -433,8 +432,8 @@ func TestApprovalContinuationStopsOriginalRunAfterDeniedTool(t *testing.T) { if evt["type"] == agui.EventToolCallStart && evt["toolCallId"] == "dummy-tool-2-fetch" { t.Fatalf("denied approval must not continue later tools: %#v", run.Events) } - if evt["type"] == agui.EventToolCallEnd && evt["toolCallId"] == approvalCtx.ToolCallID && evt["state"] == agui.ToolStateApprovalResponded { - result := jsonResultMap(t, evt["result"]) + if evt["type"] == agui.EventToolCallResult && evt["toolCallId"] == approvalCtx.ToolCallID { + result := jsonResultMap(t, evt["content"]) if result["state"] != agui.ToolResultStateError || result["reason"] != "denied" { t.Fatalf("bad denied result: %#v", result) } @@ -455,10 +454,10 @@ func TestBuildAIRunToolsDenyProducesStructuredDeniedResult(t *testing.T) { t.Fatal(err) } for _, evt := range run.Events { - if evt["type"] != agui.EventToolCallEnd { + if evt["type"] != agui.EventToolCallResult { continue } - result := jsonResultMap(t, evt["result"]) + result := jsonResultMap(t, evt["content"]) if result["state"] == agui.ToolResultStateError && result["reason"] == "denied" { return } @@ -471,6 +470,8 @@ func TestBuildAIRunToolsOmitPlaceholderArgsAndEmitTerminalResult(t *testing.T) { if err != nil { t.Fatal(err) } + seenEnd := false + seenResult := false for _, evt := range run.Events { if evt["type"] == agui.EventToolCallArgs { t.Fatalf("plain demo tool should not emit placeholder args: %#v", evt) @@ -479,14 +480,22 @@ func TestBuildAIRunToolsOmitPlaceholderArgsAndEmitTerminalResult(t *testing.T) { if evt["input"] != nil { t.Fatalf("plain demo tool should omit placeholder input: %#v", evt) } - result := jsonResultMap(t, evt["result"]) + if _, hasResult := evt["result"]; hasResult { + t.Fatalf("TOOL_CALL_END must not carry result: %#v", evt) + } + seenEnd = true + } + if evt["type"] == agui.EventToolCallResult { + result := jsonResultMap(t, evt["content"]) if result["state"] != agui.ToolResultStateComplete || result["status"] != "success" { t.Fatalf("plain demo tool should emit terminal success result: %#v", evt) } - return + seenResult = true } } - t.Fatal("missing TOOL_CALL_END event") + if !seenEnd || !seenResult { + t.Fatalf("missing TOOL_CALL_END/TOOL_CALL_RESULT events: %#v", run.Events) + } } func TestBuildAIRunToolsPrelimUsesAGUIToolResult(t *testing.T) { @@ -511,7 +520,7 @@ func TestBuildAIRunFinalSnapshotPreservesToolParts(t *testing.T) { if err != nil { t.Fatal(err) } - var snapshot []agui.UIMessage + var snapshot []agui.Message seenRunFinished := false for _, evt := range run.Events { switch evt["type"] { @@ -520,7 +529,7 @@ func TestBuildAIRunFinalSnapshotPreservesToolParts(t *testing.T) { t.Fatal("final snapshot must be emitted before RUN_FINISHED") } var ok bool - snapshot, ok = evt["messages"].([]agui.UIMessage) + snapshot, ok = evt["messages"].([]agui.Message) if !ok { t.Fatalf("bad snapshot payload: %#v", evt["messages"]) } @@ -528,21 +537,21 @@ func TestBuildAIRunFinalSnapshotPreservesToolParts(t *testing.T) { seenRunFinished = true } } - if len(snapshot) != 1 { - t.Fatalf("expected one final UI message snapshot, got %#v", snapshot) + if len(snapshot) == 0 { + t.Fatalf("expected final message snapshot, got %#v", snapshot) } seenToolCall := false seenToolResult := false - for _, part := range snapshot[0].Parts { - switch part["type"] { - case "tool-call": + for _, message := range snapshot { + if message.Role == agui.RoleAssistant && len(message.ToolCalls) > 0 { seenToolCall = true - case "tool-result": + } + if message.Role == agui.RoleTool && message.ToolCallID != "" { seenToolResult = true } } if !seenToolCall || !seenToolResult { - t.Fatalf("final snapshot lost tool parts: %#v", snapshot[0].Parts) + t.Fatalf("final snapshot lost tool messages: %#v", snapshot) } } @@ -554,22 +563,22 @@ func TestBuildAIRunToolsFailureDeltaAndInputError(t *testing.T) { seenFailure := false seenInputError := false for _, evt := range run.Events { - if evt["type"] != agui.EventToolCallEnd && evt["type"] != agui.EventToolCallArgs { + if evt["type"] != agui.EventToolCallResult && evt["type"] != agui.EventToolCallArgs { continue } toolCallID, _ := evt["toolCallId"].(string) if evt["type"] == agui.EventToolCallArgs && strings.Contains(toolCallID, "fetch") { t.Fatalf("delta tool without real input should not emit placeholder args: %#v", evt) } - if evt["type"] == agui.EventToolCallEnd { + if evt["type"] == agui.EventToolCallResult { if strings.Contains(toolCallID, "shell") { - result := jsonResultMap(t, evt["result"]) + result := jsonResultMap(t, evt["content"]) if result["state"] == agui.ToolResultStateError { seenFailure = true } } if strings.Contains(toolCallID, "parser") { - result := jsonResultMap(t, evt["result"]) + result := jsonResultMap(t, evt["content"]) if result["reason"] == "input-error" { seenInputError = true } @@ -720,13 +729,11 @@ func TestRandomModeApprovalPause(t *testing.T) { if run.ApprovalID == "" { continue } - for _, evt := range run.Events { - if evt["type"] == agui.EventRunFinished { - t.Fatalf("approval run emitted RUN_FINISHED with seed %d", seed) - } + if run.Status.State != "interrupted" { + t.Fatalf("expected approval run to interrupt, got %q", run.Status.State) } - if run.Status.State != "streaming" { - t.Fatalf("expected approval run to remain streaming, got %q", run.Status.State) + if len(firstInterrupts(t, run.Events)) == 0 { + t.Fatalf("approval run missing interrupt outcome with seed %d", seed) } return } @@ -916,7 +923,7 @@ func TestMultiApprovalContinuationKeepsLaterPrompts(t *testing.T) { AgentName: "AI", SeqStart: 12, } - run, err := buildAIApprovalContinuationRunWithApprovals(context.Background(), approvalCtx, map[string]agui.ToolApprovalResponse{approvalCtx.ID: { + run, err := buildAIApprovalContinuationRunWithApprovals(context.Background(), approvalCtx, map[string]aistream.ToolApprovalResponse{approvalCtx.ID: { ID: approvalCtx.ID, Approved: true, }}, time.Unix(20, 0)) @@ -929,8 +936,11 @@ func TestMultiApprovalContinuationKeepsLaterPrompts(t *testing.T) { if run.Prompts[0].ToolName != "fetch" { t.Fatalf("expected preserved prompt to belong to fetch, got %#v", run.Prompts[0]) } - if run.Status.State != "streaming" { - t.Fatalf("expected continuation with pending approval to remain streaming, got %#v", run.Status) + if run.Status.State != "interrupted" { + t.Fatalf("expected continuation with pending approval to interrupt, got %#v", run.Status) + } + if len(firstInterrupts(t, run.Events)) != 1 { + t.Fatalf("expected continuation with pending approval to finish with one interrupt: %#v", run.Events) } secondCtx := aistream.ApprovalContext{ @@ -946,7 +956,7 @@ func TestMultiApprovalContinuationKeepsLaterPrompts(t *testing.T) { AgentName: approvalCtx.AgentName, SeqStart: 100, } - finished, err := buildAIApprovalContinuationRunWithApprovals(context.Background(), secondCtx, map[string]agui.ToolApprovalResponse{ + finished, err := buildAIApprovalContinuationRunWithApprovals(context.Background(), secondCtx, map[string]aistream.ToolApprovalResponse{ approvalCtx.ID: { ID: approvalCtx.ID, Approved: true, @@ -1000,7 +1010,7 @@ func TestApprovalContinuationReplaysRandomRunWithImplicitSeed(t *testing.T) { AgentName: "AI", SeqStart: 50, } - continuation, err := buildAIApprovalContinuationRunWithApprovals(context.Background(), approvalCtx, map[string]agui.ToolApprovalResponse{approvalCtx.ID: { + continuation, err := buildAIApprovalContinuationRunWithApprovals(context.Background(), approvalCtx, map[string]aistream.ToolApprovalResponse{approvalCtx.ID: { ID: approvalCtx.ID, Approved: true, }}, now.Add(time.Hour)) @@ -1010,8 +1020,8 @@ func TestApprovalContinuationReplaysRandomRunWithImplicitSeed(t *testing.T) { if len(continuation.Events) == 0 { t.Fatalf("expected continuation events for random run, got none") } - if continuation.Events[0]["type"] != agui.EventCustom || continuation.Events[0]["name"] != agui.ApprovalCustomResponded { - t.Fatalf("first continuation event should acknowledge approval, got %#v", continuation.Events[0]) + if continuation.Events[0]["type"] != agui.EventToolCallResult || toolResultApprovalID(continuation.Events[0]) != approvalCtx.ID { + t.Fatalf("first continuation event should be approval tool result, got %#v", continuation.Events[0]) } return } @@ -1055,3 +1065,90 @@ func jsonResultMap(t *testing.T, value any) map[string]any { } return out } + +func firstInterrupts(t *testing.T, events []agui.Event) []agui.Interrupt { + t.Helper() + for _, evt := range events { + if evt["type"] != agui.EventRunFinished { + continue + } + interrupts := eventInterrupts(t, evt) + if len(interrupts) > 0 { + return interrupts + } + } + return nil +} + +func eventInterrupts(t *testing.T, evt agui.Event) []agui.Interrupt { + t.Helper() + switch outcome := evt["outcome"].(type) { + case agui.RunFinishedOutcome: + if outcome.Type != agui.OutcomeInterrupt { + return nil + } + return outcome.Interrupts + case map[string]any: + if outcome["type"] != agui.OutcomeInterrupt { + return nil + } + rawInterrupts, ok := outcome["interrupts"].([]any) + if !ok { + t.Fatalf("bad interrupt payload: %#v", outcome["interrupts"]) + } + interrupts := make([]agui.Interrupt, 0, len(rawInterrupts)) + for _, raw := range rawInterrupts { + value, ok := raw.(map[string]any) + if !ok { + t.Fatalf("bad interrupt value: %#v", raw) + } + metadata, _ := value["metadata"].(map[string]any) + responseSchema, _ := value["responseSchema"].(map[string]any) + interrupts = append(interrupts, agui.Interrupt{ + ID: stringFromAny(value["id"]), + Reason: stringFromAny(value["reason"]), + Message: stringFromAny(value["message"]), + ToolCallID: stringFromAny(value["toolCallId"]), + ExpiresAt: stringFromAny(value["expiresAt"]), + ResponseSchema: responseSchema, + Metadata: metadata, + }) + } + return interrupts + default: + t.Fatalf("bad outcome payload: %#v", evt["outcome"]) + return nil + } +} + +func approvalChoicesFromMetadata(t *testing.T, metadata map[string]any) []aistream.ApprovalChoice { + t.Helper() + switch raw := metadata["choices"].(type) { + case []aistream.ApprovalChoice: + return raw + case []any: + choices := make([]aistream.ApprovalChoice, 0, len(raw)) + for _, item := range raw { + value, ok := item.(map[string]any) + if !ok { + t.Fatalf("bad approval choice: %#v", item) + } + choices = append(choices, aistream.ApprovalChoice{ + Key: stringFromAny(value["key"]), + Label: stringFromAny(value["label"]), + Alias: stringFromAny(value["alias"]), + Style: stringFromAny(value["style"]), + Shortcut: stringFromAny(value["shortcut"]), + }) + } + return choices + default: + t.Fatalf("bad approval choices: %#v", raw) + return nil + } +} + +func stringFromAny(value any) string { + text, _ := value.(string) + return text +} diff --git a/pkg/connector/ai_types.go b/pkg/connector/ai_types.go index 80c3ee2..c6dbaaa 100644 --- a/pkg/connector/ai_types.go +++ b/pkg/connector/ai_types.go @@ -5,7 +5,6 @@ import ( "errors" "time" - "github.com/beeper/ai-bridge/pkg/ag-ui" "github.com/beeper/ai-bridge/pkg/ai-stream" ) @@ -142,7 +141,7 @@ type aiRuntime struct { type aiRunner struct { runtime aiRuntime - approvals map[string]agui.ToolApprovalResponse + approvals map[string]aistream.ToolApprovalResponse } type aiRunPlan struct { diff --git a/pkg/connector/client.go b/pkg/connector/client.go index 378b3d4..7a1940d 100644 --- a/pkg/connector/client.go +++ b/pkg/connector/client.go @@ -42,7 +42,7 @@ type DummyClient struct { } type aiRunSession struct { - Decisions map[string]agui.ToolApprovalResponse + Decisions map[string]aistream.ToolApprovalResponse } var _ bridgev2.NetworkAPI = (*DummyClient)(nil) @@ -846,20 +846,60 @@ func annotateApprovalEventIDs(run *aistream.Run, eventIDs map[string]id.EventID) if run == nil || len(eventIDs) == 0 { return } - for _, evt := range run.Events { - if evt["type"] != agui.EventCustom || evt["name"] != agui.ApprovalCustomRequested { + for i := range run.Interrupts { + eventID := eventIDs[run.Interrupts[i].ID] + if eventID == "" { continue } - value, _ := evt["value"].(map[string]any) - if value == nil { + aistream.SetApprovalInterruptEventID(&run.Interrupts[i], string(eventID)) + } + for _, evt := range run.Events { + if evt["type"] != agui.EventRunFinished { continue } - approvalID := aistream.ApprovalIDFromRequestedValue(value) - eventID := eventIDs[approvalID] - if eventID == "" { - continue + annotateApprovalOutcomeEventIDs(evt, eventIDs) + } +} + +func annotateApprovalOutcomeEventIDs(evt agui.Event, eventIDs map[string]id.EventID) { + switch outcome := evt["outcome"].(type) { + case agui.RunFinishedOutcome: + for i := range outcome.Interrupts { + eventID := eventIDs[outcome.Interrupts[i].ID] + if eventID == "" { + continue + } + aistream.SetApprovalInterruptEventID(&outcome.Interrupts[i], string(eventID)) + } + evt["outcome"] = outcome + case *agui.RunFinishedOutcome: + if outcome == nil { + return + } + for i := range outcome.Interrupts { + eventID := eventIDs[outcome.Interrupts[i].ID] + if eventID == "" { + continue + } + aistream.SetApprovalInterruptEventID(&outcome.Interrupts[i], string(eventID)) + } + case map[string]any: + interrupts, _ := outcome["interrupts"].([]any) + for _, raw := range interrupts { + interrupt, _ := raw.(map[string]any) + approvalID, _ := interrupt["id"].(string) + eventID := eventIDs[approvalID] + if interrupt == nil || eventID == "" { + continue + } + metadata, _ := interrupt["metadata"].(map[string]any) + if metadata == nil { + metadata = map[string]any{} + interrupt["metadata"] = metadata + } + metadata["approvalMessageId"] = approvalID + metadata["approvalEventId"] = string(eventID) } - aistream.SetApprovalRequestedEventID(value, string(eventID)) } } @@ -877,7 +917,7 @@ func approvalEventIDPlaceholders(prompts []aistream.ApprovalPrompt) map[string]i return placeholders } -func (dc *DummyClient) queueAIApprovalResponse(ctx context.Context, portal *bridgev2.Portal, approvalMessage *database.Message, response agui.ToolApprovalResponse) { +func (dc *DummyClient) queueAIApprovalResponse(ctx context.Context, portal *bridgev2.Portal, approvalMessage *database.Message, response aistream.ToolApprovalResponse) { approvalCtx, ok := dc.approvalContextForMessage(ctx, portal, approvalMessage) if !ok { log.Warn().Str("approval_id", messageIDString(approvalMessage)).Msg("Missing AI approval metadata") @@ -916,7 +956,7 @@ func (dc *DummyClient) queueAIApprovalResponse(ctx context.Context, portal *brid Msg("Queued AI approval continuation") } -func buildAIApprovalContinuationRunWithApprovals(ctx context.Context, approvalCtx aistream.ApprovalContext, approvals map[string]agui.ToolApprovalResponse, now time.Time) (aistream.Run, error) { +func buildAIApprovalContinuationRunWithApprovals(ctx context.Context, approvalCtx aistream.ApprovalContext, approvals map[string]aistream.ToolApprovalResponse, now time.Time) (aistream.Run, error) { cmd, err := parseCommand(approvalCtx.Command) if err != nil { return aistream.Run{}, err @@ -950,15 +990,9 @@ func filterPendingPrompts(prompts []aistream.ApprovalPrompt, resolvedID string, if len(prompts) == 0 { return nil } - requested := make(map[string]bool, len(events)) - for _, evt := range events { - if evt["type"] != agui.EventCustom || evt["name"] != agui.ApprovalCustomRequested { - continue - } - value, _ := evt["value"].(map[string]any) - if id := aistream.ApprovalIDFromRequestedValue(value); id != "" { - requested[id] = true - } + requested := approvalInterruptIDsFromEvents(events) + if len(requested) == 0 { + return nil } out := prompts[:0] for _, prompt := range prompts { @@ -973,25 +1007,62 @@ func filterPendingPrompts(prompts []aistream.ApprovalPrompt, resolvedID string, return out } +func approvalInterruptIDsFromEvents(events []agui.Event) map[string]bool { + requested := map[string]bool{} + for _, evt := range events { + if evt["type"] != agui.EventRunFinished { + continue + } + switch outcome := evt["outcome"].(type) { + case agui.RunFinishedOutcome: + if outcome.Type != agui.OutcomeInterrupt { + continue + } + for _, interrupt := range outcome.Interrupts { + if interrupt.ID != "" { + requested[interrupt.ID] = true + } + } + case map[string]any: + if outcome["type"] != agui.OutcomeInterrupt { + continue + } + interrupts, _ := outcome["interrupts"].([]any) + for _, raw := range interrupts { + interrupt, _ := raw.(map[string]any) + if id, _ := interrupt["id"].(string); id != "" { + requested[id] = true + } + } + } + } + return requested +} + func approvalContinuationStart(events []agui.Event, approvalID string) int { for i, evt := range events { - if evt["type"] != agui.EventCustom || evt["name"] != agui.ApprovalCustomResponded { + if evt["type"] != agui.EventToolCallResult { continue } - value, _ := evt["value"].(map[string]any) - approval, _ := value["approval"].(agui.ToolApprovalResponse) - if approval.ID == approvalID { + if toolResultApprovalID(evt) == approvalID { return i } - if raw, ok := value["approval"].(map[string]any); ok { - if idValue, _ := raw["id"].(string); idValue == approvalID { - return i - } - } } return -1 } +func toolResultApprovalID(evt agui.Event) string { + content, _ := evt["content"].(string) + if content == "" { + return "" + } + result, ok := aistream.ParseApprovalToolResult(content) + if !ok { + return "" + } + return result.ApprovalID +} + func (dc *DummyClient) approvalContextForMessage(ctx context.Context, portal *bridgev2.Portal, message *database.Message) (aistream.ApprovalContext, bool) { var fetch func(context.Context, networkid.MessageID) (*database.Message, error) if dc != nil && dc.UserLogin != nil && dc.UserLogin.Bridge != nil && dc.UserLogin.Bridge.DB != nil && portal != nil { @@ -1105,12 +1176,12 @@ func (dc *DummyClient) ensureAIRunSession(runID string) { dc.aiRunSessions = make(map[string]*aiRunSession) } if dc.aiRunSessions[runID] == nil { - dc.aiRunSessions[runID] = &aiRunSession{Decisions: make(map[string]agui.ToolApprovalResponse)} + dc.aiRunSessions[runID] = &aiRunSession{Decisions: make(map[string]aistream.ToolApprovalResponse)} } } -func (dc *DummyClient) recordAIApprovalDecision(runID string, response agui.ToolApprovalResponse) map[string]agui.ToolApprovalResponse { - decisions := make(map[string]agui.ToolApprovalResponse) +func (dc *DummyClient) recordAIApprovalDecision(runID string, response aistream.ToolApprovalResponse) map[string]aistream.ToolApprovalResponse { + decisions := make(map[string]aistream.ToolApprovalResponse) if response.ID == "" { return decisions } @@ -1125,7 +1196,7 @@ func (dc *DummyClient) recordAIApprovalDecision(runID string, response agui.Tool } session := dc.aiRunSessions[runID] if session == nil { - session = &aiRunSession{Decisions: make(map[string]agui.ToolApprovalResponse)} + session = &aiRunSession{Decisions: make(map[string]aistream.ToolApprovalResponse)} dc.aiRunSessions[runID] = session } session.Decisions[response.ID] = response From 1fa6509f5526c3a2aee17a7b26f5d44570c599fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Thu, 28 May 2026 23:50:41 +0200 Subject: [PATCH 2/7] wip --- pkg/connector/ai_commands.go | 7 +- pkg/connector/ai_runner.go | 4 +- pkg/connector/ai_runtime_test.go | 152 ++++++++++++----------- pkg/connector/client.go | 201 +++++++++++++++++++++---------- pkg/connector/client_test.go | 14 +-- 5 files changed, 235 insertions(+), 143 deletions(-) diff --git a/pkg/connector/ai_commands.go b/pkg/connector/ai_commands.go index 26fd54f..5ce5c8b 100644 --- a/pkg/connector/ai_commands.go +++ b/pkg/connector/ai_commands.go @@ -36,7 +36,7 @@ func helpText() string { return strings.Join([]string{ "DummyBridge demo commands:", "help", - "stream [seconds] [--runs=N] [--profile=balanced|tools|errors|artifacts] [--seed=N] [--chars=N] [--terminal=stop|length|abort|error] [--delay-ms=min:max] [--stagger-ms=min:max] [--actions=N] [--no-approval] [--allow-abort] [--allow-error]", + "stream [seconds] [--runs=N] [--profile=balanced|tools|errors|artifacts] [--seed=N] [--chars=N] [--terminal=stop|length|tool_calls|content_filter|other|abort|error] [--delay-ms=min:max] [--stagger-ms=min:max] [--actions=N] [--no-approval] [--allow-abort] [--allow-error]", "stream-tools ... [common options]", "Notes: stream enables approval requests by default; approval-tagged tools emit a separate Matrix approval event with reaction options.", }, "\n") @@ -183,7 +183,7 @@ func parseStreamLikeCommand(tokens []string, cmd *randomCommand, deriveActions b cmd.Terminal = "finish" case "abort", "error": cmd.Terminal = strings.ToLower(value) - case "length", "tool-calls", "content-filter", "other": + case "length", "tool_calls", "content_filter", "other": cmd.Terminal = agui.NormalizeFinishReason(value) default: return nil, fmt.Errorf("unknown terminal %q", value) @@ -348,6 +348,9 @@ func parseCommonOptions(tokens []string) (commonCommandOptions, error) { return opts, fmt.Errorf("%s requires a value", token) } opts.FinishReason = agui.NormalizeFinishReason(value) + if !agui.ValidFinishReason(opts.FinishReason) { + return opts, fmt.Errorf("unknown finish reason %q", value) + } case "abort": opts.Abort = true case "error": diff --git a/pkg/connector/ai_runner.go b/pkg/connector/ai_runner.go index 0c0e2e7..6f8b38c 100644 --- a/pkg/connector/ai_runner.go +++ b/pkg/connector/ai_runner.go @@ -330,13 +330,13 @@ func annotateProviderRawEvent(w *aistream.Writer, spec toolSpec, stage string) { if !spec.Provider || w == nil || w.Run == nil || len(w.Run.Events) == 0 { return } - w.Run.Events[len(w.Run.Events)-1]["rawEvent"] = map[string]any{ + w.Run.Events[len(w.Run.Events)-1].Set("rawEvent", map[string]any{ "provider": "dummybridge", "stage": stage, "tool": spec.Name, "sequence": spec.SequenceIndex, "tags": spec.Tags, - } + }) } func jsonToolInput(input any) string { diff --git a/pkg/connector/ai_runtime_test.go b/pkg/connector/ai_runtime_test.go index 690e1b2..4a5038a 100644 --- a/pkg/connector/ai_runtime_test.go +++ b/pkg/connector/ai_runtime_test.go @@ -71,19 +71,19 @@ func TestBuildAIRunLoremIncludesArtifactsStateAndMetadata(t *testing.T) { } seen := map[string]bool{} for _, evt := range run.Events { - switch evt["type"] { + switch evt.Type() { case agui.EventTextMessageContent, agui.EventStepStarted, agui.EventStepFinished: - seen[evt["type"].(string)] = true + seen[evt.Type()] = true case agui.EventStateDelta: - seen[evt["type"].(string)] = true - if _, ok := evt["delta"].([]map[string]any); !ok { - t.Fatalf("STATE_DELTA should use JSON Patch array, got %#v", evt["delta"]) + seen[evt.Type()] = true + if _, ok := evt.Get("delta").([]map[string]any); !ok { + t.Fatalf("STATE_DELTA should use JSON Patch array, got %#v", evt.Get("delta")) } case agui.EventCustom: - name, _ := evt["name"].(string) + name, _ := evt.Get("name").(string) seen[name] = true if name == "com.beeper.data" { - value := evt["value"].(map[string]any) + value := evt.Get("value").(map[string]any) if value["name"] == "temp" { t.Fatal("transient data must not persist as metadata") } @@ -120,32 +120,32 @@ func TestBuildAIRunToolsApprovalUsesAGUIApprovalAndPrompt(t *testing.T) { seenToolCallEndBeforeInterrupt := false seenInterrupt := false for _, evt := range run.Events { - if evt["type"] == agui.EventToolCallStart { - if evt["state"] != agui.ToolStateAwaitingInput { + if evt.Type() == agui.EventToolCallStart { + if evt.Get("state") != agui.ToolStateAwaitingInput { t.Fatalf("tool start should stay a normal AG-UI tool call, got %#v", evt) } - if _, ok := evt["approval"]; ok { + if evt.Has("approval") { t.Fatalf("tool start must not carry Beeper approval metadata: %#v", evt) } - metadata, ok := evt["metadata"].(map[string]any) + metadata, ok := evt.Get("metadata").(map[string]any) if !ok || metadata["displayName"] != "Run Command" { - t.Fatalf("bad tool display metadata: %#v", evt["metadata"]) + t.Fatalf("bad tool display metadata: %#v", evt.Get("metadata")) } foundToolStart = true } - if evt["type"] == agui.EventToolCallEnd { - if evt["state"] != agui.ToolStateInputComplete { + if evt.Type() == agui.EventToolCallEnd { + if evt.Get("state") != agui.ToolStateInputComplete { t.Fatalf("tool call should finish normally before AG-UI interrupt: %#v", evt) } - if evt["input"] != nil { + if evt.Get("input") != nil { t.Fatalf("approval input-complete event should omit placeholder input: %#v", evt) } seenToolCallEndBeforeInterrupt = true } - if evt["type"] == agui.EventCustom { + if evt.Type() == agui.EventCustom { t.Fatalf("approval must use AG-UI interrupt outcome, not custom event: %#v", evt) } - if evt["type"] == agui.EventRunFinished { + if evt.Type() == agui.EventRunFinished { if !seenToolCallEndBeforeInterrupt { t.Fatalf("approval interrupt should be emitted after approval state update: %#v", run.Events) } @@ -202,11 +202,11 @@ func TestApprovalPromptSeqStartsAtNextPackedCarrierSeq(t *testing.T) { if err != nil { t.Fatal(err) } - carriers, err := aistream.PackRunFromSeq(*run, "$anchor", aistream.CarrierBudgetBytes, 1) + carriers, err := aistream.PackRunFromSeq(*run, aistream.CarrierBudgetBytes, 1) if err != nil { t.Fatal(err) } - nextSeq := aistream.NextSeq(splitCarriersForTimedEmission(carriers)) + nextSeq := aistream.NextSeq(carriers) if nextSeq <= 1 { t.Fatalf("expected initial stream to consume carrier sequence numbers, got %d", nextSeq) } @@ -233,7 +233,7 @@ func TestApprovalPromptSeqStartsAtNextPackedCarrierSeq(t *testing.T) { if err != nil { t.Fatal(err) } - continuationCarriers, err := aistream.PackRunFromSeq(continuation, "$anchor", aistream.CarrierBudgetBytes, approvalCtx.SeqStart) + continuationCarriers, err := aistream.PackRunFromSeq(continuation, aistream.CarrierBudgetBytes, approvalCtx.SeqStart) if err != nil { t.Fatal(err) } @@ -256,11 +256,11 @@ func TestApprovalLifecycleCarriesNoticeTargetAndContinuation(t *testing.T) { } sizingRun := *run annotateApprovalEventIDs(&sizingRun, approvalEventIDPlaceholders(sizingRun.Prompts)) - initialCarriers, err := aistream.PackRunFromSeq(sizingRun, "$anchor", aistream.CarrierBudgetBytes, 1) + initialCarriers, err := aistream.PackRunFromSeq(sizingRun, aistream.CarrierBudgetBytes, 1) if err != nil { t.Fatal(err) } - initialCarriers = splitCarriersForTimedEmission(initialCarriers) + initialCarriers = initialCarriers nextSeq := aistream.NextSeq(initialCarriers) if nextSeq <= 1 { t.Fatalf("expected initial carriers to advance sequence, got %d", nextSeq) @@ -287,17 +287,17 @@ func TestApprovalLifecycleCarriesNoticeTargetAndContinuation(t *testing.T) { } annotateApprovalEventIDs(run, map[string]id.EventID{prompt.ID: "$approval"}) - annotatedCarriers, err := aistream.PackRunFromSeq(*run, "$anchor", aistream.CarrierBudgetBytes, 1) + annotatedCarriers, err := aistream.PackRunFromSeq(*run, aistream.CarrierBudgetBytes, 1) if err != nil { t.Fatal(err) } var annotatedInterrupt *agui.Interrupt for _, carrier := range annotatedCarriers { for _, env := range carrier.Envelopes { - if env.Part["type"] != agui.EventRunFinished { + if env.Event.Type() != agui.EventRunFinished { continue } - interrupts := eventInterrupts(t, env.Part) + interrupts := eventInterrupts(t, env.Event) if len(interrupts) > 0 { interrupt := interrupts[0] annotatedInterrupt = &interrupt @@ -310,7 +310,7 @@ func TestApprovalLifecycleCarriesNoticeTargetAndContinuation(t *testing.T) { if annotatedInterrupt.Metadata["approvalEventId"] != "$approval" { t.Fatalf("approval interrupt missing Matrix event target: %#v", annotatedInterrupt) } - annotatedCarriers = splitCarriersForTimedEmission(annotatedCarriers) + annotatedCarriers = annotatedCarriers if annotatedNextSeq := aistream.NextSeq(annotatedCarriers); annotatedNextSeq != nextSeq { t.Fatalf("approval event target changed stream sequence: initial=%d annotated=%d", nextSeq, annotatedNextSeq) } @@ -333,17 +333,20 @@ func TestApprovalLifecycleCarriesNoticeTargetAndContinuation(t *testing.T) { if len(continuation.Prompts) != 0 { t.Fatalf("continuation must not request approval again: %#v", continuation.Prompts) } + if len(continuation.Interrupts) != 0 || continuation.ApprovalID != "" || continuation.ToolCallID != "" { + t.Fatalf("finished continuation kept pending approval state: interrupts=%#v approval=%q tool=%q", continuation.Interrupts, continuation.ApprovalID, continuation.ToolCallID) + } if continuation.Status.State != "complete" { t.Fatalf("approved continuation should finish the run, got %#v", continuation.Status) } - continuationCarriers, err := aistream.PackRunFromSeq(continuation, "$anchor", aistream.CarrierBudgetBytes, approvalCtx.SeqStart) + continuationCarriers, err := aistream.PackRunFromSeq(continuation, aistream.CarrierBudgetBytes, approvalCtx.SeqStart) if err != nil { t.Fatal(err) } if len(continuationCarriers) == 0 || len(continuationCarriers[0].Envelopes) == 0 || continuationCarriers[0].Envelopes[0].Seq != nextSeq { t.Fatalf("continuation should resume at seq %d, got %#v", nextSeq, continuationCarriers) } - if continuation.Events[0]["type"] != agui.EventToolCallResult || toolResultApprovalID(continuation.Events[0]) != prompt.ID { + if continuation.Events[0].Type() != agui.EventToolCallResult || toolResultApprovalID(continuation.Events[0]) != prompt.ID { t.Fatalf("continuation must start with approval tool result: %#v", continuation.Events) } } @@ -373,23 +376,23 @@ func TestApprovalContinuationResumesOriginalRunAfterApprovedTool(t *testing.T) { if len(run.Events) == 0 { t.Fatal("expected continuation events") } - if run.Events[0]["type"] != agui.EventToolCallResult || toolResultApprovalID(run.Events[0]) != approvalCtx.ID { + if run.Events[0].Type() != agui.EventToolCallResult || toolResultApprovalID(run.Events[0]) != approvalCtx.ID { t.Fatalf("first continuation event should be approval tool result, got %#v", run.Events[0]) } seenApprovedTool := false seenLaterTool := false seenFinished := false for _, evt := range run.Events { - if evt["type"] == agui.EventToolCallResult && evt["toolCallId"] == approvalCtx.ToolCallID { - result := jsonResultMap(t, evt["content"]) + if evt.Type() == agui.EventToolCallResult && evt.Get("toolCallId") == approvalCtx.ToolCallID { + result := jsonResultMap(t, evt.Get("content")) if result["approved"] == true { seenApprovedTool = true } } - if evt["type"] == agui.EventToolCallStart && evt["toolCallId"] == "dummy-tool-2-fetch" { + if evt.Type() == agui.EventToolCallStart && evt.Get("toolCallId") == "dummy-tool-2-fetch" { seenLaterTool = true } - if evt["type"] == agui.EventRunFinished { + if evt.Type() == agui.EventRunFinished { seenFinished = true } } @@ -402,6 +405,9 @@ func TestApprovalContinuationResumesOriginalRunAfterApprovedTool(t *testing.T) { if len(run.Prompts) != 0 { t.Fatalf("finished continuation should not keep pending prompts: %#v", run.Prompts) } + if len(run.Interrupts) != 0 || run.ApprovalID != "" || run.ToolCallID != "" { + t.Fatalf("finished continuation kept pending approval state: interrupts=%#v approval=%q tool=%q", run.Interrupts, run.ApprovalID, run.ToolCallID) + } } func TestApprovalContinuationStopsOriginalRunAfterDeniedTool(t *testing.T) { @@ -429,11 +435,11 @@ func TestApprovalContinuationStopsOriginalRunAfterDeniedTool(t *testing.T) { } seenDeniedTool := false for _, evt := range run.Events { - if evt["type"] == agui.EventToolCallStart && evt["toolCallId"] == "dummy-tool-2-fetch" { + if evt.Type() == agui.EventToolCallStart && evt.Get("toolCallId") == "dummy-tool-2-fetch" { t.Fatalf("denied approval must not continue later tools: %#v", run.Events) } - if evt["type"] == agui.EventToolCallResult && evt["toolCallId"] == approvalCtx.ToolCallID { - result := jsonResultMap(t, evt["content"]) + if evt.Type() == agui.EventToolCallResult && evt.Get("toolCallId") == approvalCtx.ToolCallID { + result := jsonResultMap(t, evt.Get("content")) if result["state"] != agui.ToolResultStateError || result["reason"] != "denied" { t.Fatalf("bad denied result: %#v", result) } @@ -446,6 +452,9 @@ func TestApprovalContinuationStopsOriginalRunAfterDeniedTool(t *testing.T) { if run.Status.State != "error" { t.Fatalf("denied continuation status = %#v", run.Status) } + if len(run.Prompts) != 0 || len(run.Interrupts) != 0 || run.ApprovalID != "" || run.ToolCallID != "" { + t.Fatalf("denied continuation kept pending approval state: prompts=%#v interrupts=%#v approval=%q tool=%q", run.Prompts, run.Interrupts, run.ApprovalID, run.ToolCallID) + } } func TestBuildAIRunToolsDenyProducesStructuredDeniedResult(t *testing.T) { @@ -454,10 +463,10 @@ func TestBuildAIRunToolsDenyProducesStructuredDeniedResult(t *testing.T) { t.Fatal(err) } for _, evt := range run.Events { - if evt["type"] != agui.EventToolCallResult { + if evt.Type() != agui.EventToolCallResult { continue } - result := jsonResultMap(t, evt["content"]) + result := jsonResultMap(t, evt.Get("content")) if result["state"] == agui.ToolResultStateError && result["reason"] == "denied" { return } @@ -473,20 +482,20 @@ func TestBuildAIRunToolsOmitPlaceholderArgsAndEmitTerminalResult(t *testing.T) { seenEnd := false seenResult := false for _, evt := range run.Events { - if evt["type"] == agui.EventToolCallArgs { + if evt.Type() == agui.EventToolCallArgs { t.Fatalf("plain demo tool should not emit placeholder args: %#v", evt) } - if evt["type"] == agui.EventToolCallEnd { - if evt["input"] != nil { + if evt.Type() == agui.EventToolCallEnd { + if evt.Get("input") != nil { t.Fatalf("plain demo tool should omit placeholder input: %#v", evt) } - if _, hasResult := evt["result"]; hasResult { + if evt.Has("result") { t.Fatalf("TOOL_CALL_END must not carry result: %#v", evt) } seenEnd = true } - if evt["type"] == agui.EventToolCallResult { - result := jsonResultMap(t, evt["content"]) + if evt.Type() == agui.EventToolCallResult { + result := jsonResultMap(t, evt.Get("content")) if result["state"] != agui.ToolResultStateComplete || result["status"] != "success" { t.Fatalf("plain demo tool should emit terminal success result: %#v", evt) } @@ -504,10 +513,10 @@ func TestBuildAIRunToolsPrelimUsesAGUIToolResult(t *testing.T) { t.Fatal(err) } for _, evt := range run.Events { - if evt["type"] != agui.EventToolCallResult { + if evt.Type() != agui.EventToolCallResult { continue } - if evt["state"] != agui.ToolResultStateStreaming || evt["toolCallId"] == "" || evt["content"] == "" { + if evt.Get("state") != agui.ToolResultStateStreaming || evt.Get("toolCallId") == "" || evt.Get("content") == "" { t.Fatalf("bad TOOL_CALL_RESULT event: %#v", evt) } return @@ -523,15 +532,15 @@ func TestBuildAIRunFinalSnapshotPreservesToolParts(t *testing.T) { var snapshot []agui.Message seenRunFinished := false for _, evt := range run.Events { - switch evt["type"] { + switch evt.Type() { case agui.EventMessagesSnapshot: if seenRunFinished { t.Fatal("final snapshot must be emitted before RUN_FINISHED") } var ok bool - snapshot, ok = evt["messages"].([]agui.Message) + snapshot, ok = evt.Get("messages").([]agui.Message) if !ok { - t.Fatalf("bad snapshot payload: %#v", evt["messages"]) + t.Fatalf("bad snapshot payload: %#v", evt.Get("messages")) } case agui.EventRunFinished: seenRunFinished = true @@ -563,22 +572,22 @@ func TestBuildAIRunToolsFailureDeltaAndInputError(t *testing.T) { seenFailure := false seenInputError := false for _, evt := range run.Events { - if evt["type"] != agui.EventToolCallResult && evt["type"] != agui.EventToolCallArgs { + if evt.Type() != agui.EventToolCallResult && evt.Type() != agui.EventToolCallArgs { continue } - toolCallID, _ := evt["toolCallId"].(string) - if evt["type"] == agui.EventToolCallArgs && strings.Contains(toolCallID, "fetch") { + toolCallID, _ := evt.Get("toolCallId").(string) + if evt.Type() == agui.EventToolCallArgs && strings.Contains(toolCallID, "fetch") { t.Fatalf("delta tool without real input should not emit placeholder args: %#v", evt) } - if evt["type"] == agui.EventToolCallResult { + if evt.Type() == agui.EventToolCallResult { if strings.Contains(toolCallID, "shell") { - result := jsonResultMap(t, evt["content"]) + result := jsonResultMap(t, evt.Get("content")) if result["state"] == agui.ToolResultStateError { seenFailure = true } } if strings.Contains(toolCallID, "parser") { - result := jsonResultMap(t, evt["content"]) + result := jsonResultMap(t, evt.Get("content")) if result["reason"] == "input-error" { seenInputError = true } @@ -596,14 +605,14 @@ func TestBuildAIRunToolsProviderTagAddsRawEventPassthrough(t *testing.T) { t.Fatal(err) } for _, evt := range run.Events { - raw, ok := evt["rawEvent"].(map[string]any) + raw, ok := evt.Get("rawEvent").(map[string]any) if !ok { continue } if raw["provider"] != "dummybridge" || raw["tool"] != "shell" { t.Fatalf("bad raw provider event: %#v", raw) } - carriers, err := aistream.PackRun(*run, "$anchor", aistream.CarrierBudgetBytes) + carriers, err := aistream.PackRun(*run, aistream.CarrierBudgetBytes) if err != nil { t.Fatal(err) } @@ -637,7 +646,7 @@ func TestBuildAIRunOver64KBPacksTo58KCarriers(t *testing.T) { if err != nil { t.Fatal(err) } - carriers, err := aistream.PackRun(*run, "$anchor", aistream.CarrierBudgetBytes) + carriers, err := aistream.PackRun(*run, aistream.CarrierBudgetBytes) if err != nil { t.Fatal(err) } @@ -645,16 +654,16 @@ func TestBuildAIRunOver64KBPacksTo58KCarriers(t *testing.T) { t.Fatalf("expected split carriers, got %d", len(carriers)) } for i, carrier := range carriers { - if size := aistream.JSONSize(aistream.CarrierContent(carrier.Envelopes)); size > aistream.CarrierBudgetBytes { + if size := aistream.JSONSize(aistream.CarrierContent(*run, carrier.Envelopes)); size > aistream.CarrierBudgetBytes { t.Fatalf("carrier %d size = %d", i, size) } } for _, carrier := range carriers { for _, envelope := range carrier.Envelopes { - if envelope.Part["type"] != agui.EventMessagesSnapshot { + if envelope.Event.Type() != agui.EventMessagesSnapshot { continue } - raw, err := json.Marshal(envelope.Part) + raw, err := json.Marshal(envelope.Event) if err != nil { t.Fatal(err) } @@ -701,9 +710,9 @@ func TestBuildAIRunRandomHonorsVirtualDelays(t *testing.T) { } var first, last int64 for _, evt := range run.Events { - ts, _ := evt["timestamp"].(int64) + ts, _ := evt.Get("timestamp").(int64) if ts == 0 { - if n, ok := evt["timestamp"].(int); ok { + if n, ok := evt.Get("timestamp").(int); ok { ts = int64(n) } } @@ -942,6 +951,12 @@ func TestMultiApprovalContinuationKeepsLaterPrompts(t *testing.T) { if len(firstInterrupts(t, run.Events)) != 1 { t.Fatalf("expected continuation with pending approval to finish with one interrupt: %#v", run.Events) } + if len(run.Interrupts) != 1 || run.Interrupts[0].ID != run.Prompts[0].ID { + t.Fatalf("pending continuation should expose only the new interrupt: prompts=%#v interrupts=%#v", run.Prompts, run.Interrupts) + } + if run.ApprovalID != run.Prompts[0].ID || run.ToolCallID != run.Prompts[0].ToolCallID { + t.Fatalf("pending continuation should target the new approval: prompts=%#v approval=%q tool=%q", run.Prompts, run.ApprovalID, run.ToolCallID) + } secondCtx := aistream.ApprovalContext{ ID: run.Prompts[0].ID, @@ -975,6 +990,9 @@ func TestMultiApprovalContinuationKeepsLaterPrompts(t *testing.T) { if len(finished.Prompts) != 0 { t.Fatalf("finished continuation should not keep prompts: %#v", finished.Prompts) } + if len(finished.Interrupts) != 0 || finished.ApprovalID != "" || finished.ToolCallID != "" { + t.Fatalf("finished continuation kept pending approval state: interrupts=%#v approval=%q tool=%q", finished.Interrupts, finished.ApprovalID, finished.ToolCallID) + } } func TestApprovalContinuationReplaysRandomRunWithImplicitSeed(t *testing.T) { @@ -1020,7 +1038,7 @@ func TestApprovalContinuationReplaysRandomRunWithImplicitSeed(t *testing.T) { if len(continuation.Events) == 0 { t.Fatalf("expected continuation events for random run, got none") } - if continuation.Events[0]["type"] != agui.EventToolCallResult || toolResultApprovalID(continuation.Events[0]) != approvalCtx.ID { + if continuation.Events[0].Type() != agui.EventToolCallResult || toolResultApprovalID(continuation.Events[0]) != approvalCtx.ID { t.Fatalf("first continuation event should be approval tool result, got %#v", continuation.Events[0]) } return @@ -1069,7 +1087,7 @@ func jsonResultMap(t *testing.T, value any) map[string]any { func firstInterrupts(t *testing.T, events []agui.Event) []agui.Interrupt { t.Helper() for _, evt := range events { - if evt["type"] != agui.EventRunFinished { + if evt.Type() != agui.EventRunFinished { continue } interrupts := eventInterrupts(t, evt) @@ -1082,7 +1100,7 @@ func firstInterrupts(t *testing.T, events []agui.Event) []agui.Interrupt { func eventInterrupts(t *testing.T, evt agui.Event) []agui.Interrupt { t.Helper() - switch outcome := evt["outcome"].(type) { + switch outcome := evt.Get("outcome").(type) { case agui.RunFinishedOutcome: if outcome.Type != agui.OutcomeInterrupt { return nil @@ -1116,7 +1134,7 @@ func eventInterrupts(t *testing.T, evt agui.Event) []agui.Interrupt { } return interrupts default: - t.Fatalf("bad outcome payload: %#v", evt["outcome"]) + t.Fatalf("bad outcome payload: %#v", evt.Get("outcome")) return nil } } diff --git a/pkg/connector/client.go b/pkg/connector/client.go index 7a1940d..2a9fb66 100644 --- a/pkg/connector/client.go +++ b/pkg/connector/client.go @@ -53,7 +53,8 @@ var _ bridgev2.MessageRequestAcceptingNetworkAPI = (*DummyClient)(nil) var _ bridgev2.ReactionHandlingNetworkAPI = (*DummyClient)(nil) const ( - dummyAIAgentName string = "Dummy" + dummyAIAgentName string = "Dummy" + defaultAIApprovalTimeout = 5 * time.Minute ) var delayedRemoteEchoPattern = regexp.MustCompile(`(?i)^remote-echo\s+delay\s+([0-9]+(?:ms|s|m|h))$`) @@ -262,9 +263,6 @@ func (dc *DummyClient) HandleMatrixReaction(ctx context.Context, msg *bridgev2.M if dc == nil || dc.UserLogin == nil || msg == nil || msg.TargetMessage == nil || msg.Content == nil || msg.Portal == nil { return &database.Reaction{}, nil } - if isApprovalOptionReaction(msg) { - return &database.Reaction{}, nil - } approvalID := string(msg.TargetMessage.ID) if !strings.HasPrefix(approvalID, "approval-") { return &database.Reaction{}, nil @@ -307,14 +305,6 @@ func (dc *DummyClient) HandleMatrixReaction(ctx context.Context, msg *bridgev2.M return &database.Reaction{}, nil } -func isApprovalOptionReaction(msg *bridgev2.MatrixReaction) bool { - if msg == nil || msg.Event == nil { - return false - } - _, ok := msg.Event.Content.Raw["com.beeper.ai.approval_option"] - return ok -} - func (dc *DummyClient) resolveApprovalOnce(approvalID, selectedKey string) (string, bool) { dc.approvalSelectionsOnce.Do(func() { dc.approvalSelections = exsync.NewMap[string, string]() @@ -591,25 +581,27 @@ func (dc *DummyClient) queueAIRunStreamAndMetadata(portal *bridgev2.Portal, send func (dc *DummyClient) emitAIRunStream(portal *bridgev2.Portal, sender networkid.UserID, messageID networkid.MessageID, targetEventID id.EventID, run aistream.Run, command string, startSeq int, anchorAt time.Time) { sizingRun := run annotateApprovalEventIDs(&sizingRun, approvalEventIDPlaceholders(sizingRun.Prompts)) - carriers, err := aistream.PackRunFromSeq(sizingRun, string(targetEventID), aistream.CarrierBudgetBytes, startSeq) + carriers, err := aistream.PackRunFromSeq(sizingRun, aistream.CarrierBudgetBytes, startSeq) if err != nil { log.Warn().Err(err).Str("run_id", run.RunID).Msg("Failed to pack AI stream") return } - carriers = splitCarriersForTimedEmission(carriers) nextSeq := aistream.NextSeq(carriers) - queuedPrompts := run.Prompts - if len(queuedPrompts) > 1 { + approvalQueue := aistream.NewApprovalQueue(aistream.ApprovalTimeout{After: defaultAIApprovalTimeout}) + approvalQueue.AddAll(run.Prompts) + activePrompt, hasActivePrompt := approvalQueue.Active() + if pending := approvalQueue.Pending(); len(pending) > 0 { log.Warn(). Str("run_id", run.RunID). - Int("approval_prompts", len(queuedPrompts)). - Msg("AI run produced multiple simultaneous approval prompts; queueing the first prompt only") - queuedPrompts = queuedPrompts[:1] + Int("pending_approval_prompts", len(pending)). + Msg("AI run produced multiple approval prompts; keeping one active interrupt and queueing the rest") } - approvalEventIDs := make(map[string]id.EventID, len(queuedPrompts)) - for _, prompt := range queuedPrompts { + approvalEventIDs := make(map[string]id.EventID, 1) + if hasActivePrompt { + prompt := activePrompt prompt.SeqStart = nextSeq ctx := dc.queueAIApprovalPrompt(portal, sender, run, prompt, targetEventID, command, time.Now()) + dc.scheduleAIApprovalTimeout(portal, networkid.MessageID(ctx.ID), approvalQueue.Timeout()) if approvalEventID := dc.waitForMessageMXID(portal, networkid.MessageID(ctx.ID), 10*time.Second); approvalEventID != "" { approvalEventIDs[ctx.ID] = approvalEventID log.Info(). @@ -628,12 +620,11 @@ func (dc *DummyClient) emitAIRunStream(portal *bridgev2.Portal, sender networkid } if len(approvalEventIDs) > 0 { annotateApprovalEventIDs(&run, approvalEventIDs) - carriers, err = aistream.PackRunFromSeq(run, string(targetEventID), aistream.CarrierBudgetBytes, startSeq) + carriers, err = aistream.PackRunFromSeq(run, aistream.CarrierBudgetBytes, startSeq) if err != nil { log.Warn().Err(err).Str("run_id", run.RunID).Msg("Failed to repack AI stream with approval event IDs") return } - carriers = splitCarriersForTimedEmission(carriers) if actualNextSeq := aistream.NextSeq(carriers); actualNextSeq != nextSeq { log.Warn(). Str("run_id", run.RunID). @@ -642,10 +633,9 @@ func (dc *DummyClient) emitAIRunStream(portal *bridgev2.Portal, sender networkid Msg("AI approval event ID repack changed stream sequence count") return } - } else if len(queuedPrompts) > 0 { + } else if hasActivePrompt { log.Info(). Str("run_id", run.RunID). - Int("approval_prompts", len(queuedPrompts)). Msg("Sending approval stream without approval event IDs") } dc.queuePackedAICarriers(portal, sender, targetEventID, run, carriers, startSeq, anchorAt) @@ -661,6 +651,35 @@ func (dc *DummyClient) emitAIRunStream(portal *bridgev2.Portal, sender networkid } } +func (dc *DummyClient) scheduleAIApprovalTimeout(portal *bridgev2.Portal, approvalMessageID networkid.MessageID, timeout aistream.ApprovalTimeout) { + if dc == nil || portal == nil || approvalMessageID == "" || timeout.After <= 0 { + return + } + dc.wg.Add(1) + go func() { + defer dc.wg.Done() + timer := time.NewTimer(timeout.After) + defer timer.Stop() + select { + case <-dc.clientContext().Done(): + return + case <-timer.C: + } + approvalID := string(approvalMessageID) + if _, firstResolution := dc.resolveApprovalOnce(approvalID, timeout.Reason); !firstResolution { + return + } + ctx := dc.clientContext() + approvalMessage, err := dc.lookupMessage(ctx, portal.Receiver, approvalMessageID) + if err != nil || approvalMessage == nil { + log.Warn().Err(err).Str("approval_id", approvalID).Msg("Timed-out AI approval message was not found") + return + } + response := aistream.TimedOutApprovalResponse(approvalID) + dc.queueAIApprovalResponse(ctx, portal, approvalMessage, response) + }() +} + func (dc *DummyClient) queuePackedAICarriers(portal *bridgev2.Portal, sender networkid.UserID, targetEventID id.EventID, run aistream.Run, carriers []aistream.Carrier, startSeq int, anchorAt time.Time) { streamStart := time.Now() // minCarrierTimestamp guarantees every carrier lands strictly after the @@ -681,20 +700,6 @@ func (dc *DummyClient) queuePackedAICarriers(portal *bridgev2.Portal, sender net } } -func splitCarriersForTimedEmission(carriers []aistream.Carrier) []aistream.Carrier { - out := make([]aistream.Carrier, 0, len(carriers)) - for _, carrier := range carriers { - if len(carrier.Envelopes) <= 1 { - out = append(out, carrier) - continue - } - for _, env := range carrier.Envelopes { - out = append(out, aistream.Carrier{Envelopes: []aistream.Envelope{env}}) - } - } - return out -} - func (dc *DummyClient) sleepUntilCarrierTime(run aistream.Run, carrier aistream.Carrier, streamStart time.Time) { target := carrierTimestamp(run, carrier, streamStart) if target.IsZero() { @@ -719,7 +724,7 @@ func carrierTimestamp(run aistream.Run, carrier aistream.Carrier, streamStart ti } var latest time.Time for _, env := range carrier.Envelopes { - eventTime := eventTimestamp(env.Part) + eventTime := eventTimestamp(env.Event) if eventTime.IsZero() { continue } @@ -743,8 +748,8 @@ func runStartTimestamp(run aistream.Run) time.Time { } func eventTimestamp(evt agui.Event) time.Time { - raw, ok := evt["timestamp"] - if !ok { + raw := evt.Get("timestamp") + if !evt.Has("timestamp") { return time.Time{} } var millis int64 @@ -808,13 +813,20 @@ func (dc *DummyClient) waitForMessageMXID( } func (dc *DummyClient) lookupMessageMXID(ctx context.Context, receiver networkid.UserLoginID, messageID networkid.MessageID) id.EventID { - message, err := dc.UserLogin.Bridge.DB.Message.GetFirstPartByID(ctx, receiver, messageID) + message, err := dc.lookupMessage(ctx, receiver, messageID) if err != nil || message == nil { return "" } return message.MXID } +func (dc *DummyClient) lookupMessage(ctx context.Context, receiver networkid.UserLoginID, messageID networkid.MessageID) (*database.Message, error) { + if dc == nil || dc.UserLogin == nil || dc.UserLogin.Bridge == nil || dc.UserLogin.Bridge.DB == nil { + return nil, nil + } + return dc.UserLogin.Bridge.DB.Message.GetFirstPartByID(ctx, receiver, messageID) +} + func (dc *DummyClient) queueAIApprovalPrompt(portal *bridgev2.Portal, sender networkid.UserID, run aistream.Run, prompt aistream.ApprovalPrompt, targetEventID id.EventID, command string, timestamp time.Time) aistream.ApprovalContext { choices := aistream.DefaultApprovalChoices() approvalCtx := aistream.ApprovalContext{ @@ -854,7 +866,7 @@ func annotateApprovalEventIDs(run *aistream.Run, eventIDs map[string]id.EventID) aistream.SetApprovalInterruptEventID(&run.Interrupts[i], string(eventID)) } for _, evt := range run.Events { - if evt["type"] != agui.EventRunFinished { + if evt.Type() != agui.EventRunFinished { continue } annotateApprovalOutcomeEventIDs(evt, eventIDs) @@ -862,7 +874,7 @@ func annotateApprovalEventIDs(run *aistream.Run, eventIDs map[string]id.EventID) } func annotateApprovalOutcomeEventIDs(evt agui.Event, eventIDs map[string]id.EventID) { - switch outcome := evt["outcome"].(type) { + switch outcome := evt.Get("outcome").(type) { case agui.RunFinishedOutcome: for i := range outcome.Interrupts { eventID := eventIDs[outcome.Interrupts[i].ID] @@ -871,7 +883,7 @@ func annotateApprovalOutcomeEventIDs(evt agui.Event, eventIDs map[string]id.Even } aistream.SetApprovalInterruptEventID(&outcome.Interrupts[i], string(eventID)) } - evt["outcome"] = outcome + evt.Set("outcome", outcome) case *agui.RunFinishedOutcome: if outcome == nil { return @@ -976,13 +988,19 @@ func buildAIApprovalContinuationRunWithApprovals(ctx context.Context, approvalCt run.RunID = approvalCtx.RunID run.ThreadID = approvalCtx.ThreadID run.MessageID = approvalCtx.MessageID - run.ToolCallID = approvalCtx.ToolCallID - run.ApprovalID = approvalCtx.ID // Keep only prompts that the continuation segment newly emitted (i.e. // approvals raised by tools that ran AFTER the resolved one). The // already-resolved approval has been removed from the event range above // and must not be queued again. run.Prompts = filterPendingPrompts(run.Prompts, approvalCtx.ID, run.Events) + run.Interrupts = filterPendingInterrupts(run.Interrupts, run.Prompts, run.Events) + if len(run.Prompts) > 0 { + run.ApprovalID = run.Prompts[0].ID + run.ToolCallID = run.Prompts[0].ToolCallID + } else { + run.ApprovalID = "" + run.ToolCallID = "" + } return *run, nil } @@ -1007,41 +1025,96 @@ func filterPendingPrompts(prompts []aistream.ApprovalPrompt, resolvedID string, return out } +func filterPendingInterrupts(interrupts []agui.Interrupt, prompts []aistream.ApprovalPrompt, events []agui.Event) []agui.Interrupt { + if len(prompts) == 0 { + return nil + } + pending := make(map[string]bool, len(prompts)) + for _, prompt := range prompts { + pending[prompt.ID] = true + } + var out []agui.Interrupt + for _, interrupt := range approvalInterruptsFromEvents(events) { + if pending[interrupt.ID] { + out = append(out, interrupt) + } + } + if len(out) > 0 { + return out + } + for _, interrupt := range interrupts { + if pending[interrupt.ID] { + out = append(out, interrupt) + } + } + return out +} + func approvalInterruptIDsFromEvents(events []agui.Event) map[string]bool { requested := map[string]bool{} + for _, interrupt := range approvalInterruptsFromEvents(events) { + if interrupt.ID != "" { + requested[interrupt.ID] = true + } + } + return requested +} + +func approvalInterruptsFromEvents(events []agui.Event) []agui.Interrupt { + var interrupts []agui.Interrupt for _, evt := range events { - if evt["type"] != agui.EventRunFinished { + if evt.Type() != agui.EventRunFinished { continue } - switch outcome := evt["outcome"].(type) { + switch outcome := evt.Get("outcome").(type) { case agui.RunFinishedOutcome: if outcome.Type != agui.OutcomeInterrupt { continue } - for _, interrupt := range outcome.Interrupts { - if interrupt.ID != "" { - requested[interrupt.ID] = true - } - } + interrupts = append(interrupts, outcome.Interrupts...) case map[string]any: if outcome["type"] != agui.OutcomeInterrupt { continue } - interrupts, _ := outcome["interrupts"].([]any) - for _, raw := range interrupts { + rawInterrupts, _ := outcome["interrupts"].([]any) + for _, raw := range rawInterrupts { interrupt, _ := raw.(map[string]any) - if id, _ := interrupt["id"].(string); id != "" { - requested[id] = true + if len(interrupt) == 0 { + continue } + metadata, _ := interrupt["metadata"].(map[string]any) + responseSchema, _ := interrupt["responseSchema"].(map[string]any) + interrupts = append(interrupts, agui.Interrupt{ + ID: eventStringValue(interrupt["id"]), + Reason: eventStringValue(interrupt["reason"]), + Message: eventStringValue(interrupt["message"]), + ToolCallID: eventStringValue(interrupt["toolCallId"]), + ResponseSchema: responseSchema, + ExpiresAt: eventStringValue(interrupt["expiresAt"]), + Metadata: metadata, + }) } } } - return requested + return interrupts +} + +func eventStringValue(value any) string { + switch typed := value.(type) { + case string: + return typed + case fmt.Stringer: + return typed.String() + case nil: + return "" + default: + return fmt.Sprint(typed) + } } func approvalContinuationStart(events []agui.Event, approvalID string) int { for i, evt := range events { - if evt["type"] != agui.EventToolCallResult { + if evt.Type() != agui.EventToolCallResult { continue } if toolResultApprovalID(evt) == approvalID { @@ -1052,7 +1125,7 @@ func approvalContinuationStart(events []agui.Event, approvalID string) int { } func toolResultApprovalID(evt agui.Event) string { - content, _ := evt["content"].(string) + content, _ := evt.Get("content").(string) if content == "" { return "" } @@ -1163,6 +1236,10 @@ func validApprovalContext(ctx aistream.ApprovalContext) (aistream.ApprovalContex } func (dc *DummyClient) queueAIRunFinalMetadata(portal *bridgev2.Portal, sender networkid.UserID, messageID networkid.MessageID, run aistream.Run) { + targetEventID := dc.waitForMessageMXID(portal, messageID, 30*time.Second) + for _, segment := range aibridgev2.FinalSegments(portal.PortalKey, sender, run, targetEventID, time.Now()) { + dc.UserLogin.QueueRemoteEvent(segment) + } dc.UserLogin.QueueRemoteEvent(aibridgev2.FinalMetadataEdit(portal.PortalKey, sender, messageID, run, time.Now())) } diff --git a/pkg/connector/client_test.go b/pkg/connector/client_test.go index 1c12db6..07747ca 100644 --- a/pkg/connector/client_test.go +++ b/pkg/connector/client_test.go @@ -42,20 +42,14 @@ func TestGetRemoteEchoBehavior(t *testing.T) { func TestSleepUntilCarrierTimeWithoutConnectedContext(t *testing.T) { base := time.Now() + builder := agui.NewEventBuilder("dummybridge/test", func() time.Time { return base }) run := aistream.Run{ - Events: []agui.Event{{ - "type": agui.EventRunStarted, - "timestamp": base.UnixMilli(), - "threadId": "thread-1", - }}, + Events: []agui.Event{builder.RunStarted("thread-1", "run-1")}, } + builder = agui.NewEventBuilder("dummybridge/test", func() time.Time { return base.Add(time.Millisecond) }) carrier := aistream.Carrier{ Envelopes: []aistream.Envelope{{ - Part: agui.Event{ - "type": agui.EventTextMessageContent, - "timestamp": base.Add(time.Millisecond).UnixMilli(), - "messageId": "message-1", - }, + Event: builder.TextMessageContent("message-1", "hello"), }}, } From e097699dc6179210fb6bea80bb66c944e9d1d970 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Thu, 28 May 2026 23:58:12 +0200 Subject: [PATCH 3/7] wip --- pkg/connector/ai_runtime_test.go | 74 ++++++------ pkg/connector/client.go | 198 +++++++++---------------------- 2 files changed, 96 insertions(+), 176 deletions(-) diff --git a/pkg/connector/ai_runtime_test.go b/pkg/connector/ai_runtime_test.go index 4a5038a..2b85ff1 100644 --- a/pkg/connector/ai_runtime_test.go +++ b/pkg/connector/ai_runtime_test.go @@ -95,13 +95,12 @@ func TestBuildAIRunLoremIncludesArtifactsStateAndMetadata(t *testing.T) { t.Fatalf("missing %s in events", key) } } - metadata := run.Metadata() - if metadata["model"] == "" || metadata["threadId"] != "thread-1" || metadata["runId"] != "run-1" { - t.Fatalf("bad metadata: %#v", metadata) + payload := run.AI(aistream.AIKindFinal) + if payload.Model == "" || payload.ThreadID != "thread-1" || payload.RunID != "run-1" { + t.Fatalf("bad AI payload: %#v", payload) } - data := metadata["data"].(map[string]any) - if _, ok := data["temp"]; ok { - t.Fatalf("transient data leaked into final metadata: %#v", data) + if _, ok := payload.Data["temp"]; ok { + t.Fatalf("transient data leaked into final AI payload: %#v", payload.Data) } } @@ -202,7 +201,7 @@ func TestApprovalPromptSeqStartsAtNextPackedCarrierSeq(t *testing.T) { if err != nil { t.Fatal(err) } - carriers, err := aistream.PackRunFromSeq(*run, aistream.CarrierBudgetBytes, 1) + carriers, err := aistream.PackRunByTimeFromSeq(*run, 1, demoStreamCarrierMaxSpan) if err != nil { t.Fatal(err) } @@ -233,7 +232,7 @@ func TestApprovalPromptSeqStartsAtNextPackedCarrierSeq(t *testing.T) { if err != nil { t.Fatal(err) } - continuationCarriers, err := aistream.PackRunFromSeq(continuation, aistream.CarrierBudgetBytes, approvalCtx.SeqStart) + continuationCarriers, err := aistream.PackRunByTimeFromSeq(continuation, approvalCtx.SeqStart, demoStreamCarrierMaxSpan) if err != nil { t.Fatal(err) } @@ -256,7 +255,7 @@ func TestApprovalLifecycleCarriesNoticeTargetAndContinuation(t *testing.T) { } sizingRun := *run annotateApprovalEventIDs(&sizingRun, approvalEventIDPlaceholders(sizingRun.Prompts)) - initialCarriers, err := aistream.PackRunFromSeq(sizingRun, aistream.CarrierBudgetBytes, 1) + initialCarriers, err := aistream.PackRunByTimeFromSeq(sizingRun, 1, demoStreamCarrierMaxSpan) if err != nil { t.Fatal(err) } @@ -287,7 +286,7 @@ func TestApprovalLifecycleCarriesNoticeTargetAndContinuation(t *testing.T) { } annotateApprovalEventIDs(run, map[string]id.EventID{prompt.ID: "$approval"}) - annotatedCarriers, err := aistream.PackRunFromSeq(*run, aistream.CarrierBudgetBytes, 1) + annotatedCarriers, err := aistream.PackRunByTimeFromSeq(*run, 1, demoStreamCarrierMaxSpan) if err != nil { t.Fatal(err) } @@ -339,7 +338,7 @@ func TestApprovalLifecycleCarriesNoticeTargetAndContinuation(t *testing.T) { if continuation.Status.State != "complete" { t.Fatalf("approved continuation should finish the run, got %#v", continuation.Status) } - continuationCarriers, err := aistream.PackRunFromSeq(continuation, aistream.CarrierBudgetBytes, approvalCtx.SeqStart) + continuationCarriers, err := aistream.PackRunByTimeFromSeq(continuation, approvalCtx.SeqStart, demoStreamCarrierMaxSpan) if err != nil { t.Fatal(err) } @@ -612,7 +611,7 @@ func TestBuildAIRunToolsProviderTagAddsRawEventPassthrough(t *testing.T) { if raw["provider"] != "dummybridge" || raw["tool"] != "shell" { t.Fatalf("bad raw provider event: %#v", raw) } - carriers, err := aistream.PackRun(*run, aistream.CarrierBudgetBytes) + carriers, err := aistream.PackRun(*run) if err != nil { t.Fatal(err) } @@ -641,40 +640,45 @@ func TestBuildAIRunTerminalErrorAndAbortStates(t *testing.T) { } } -func TestBuildAIRunOver64KBPacksTo58KCarriers(t *testing.T) { +func TestBuildAIRunOver64KBStreamsWithoutCarrierSizeSplit(t *testing.T) { run, err := buildAIRun(context.Background(), "run-1", "thread-1", "stream 1 --chars=70000 --actions=1 --seed=7", time.Unix(10, 0)) if err != nil { t.Fatal(err) } - carriers, err := aistream.PackRun(*run, aistream.CarrierBudgetBytes) + carriers, err := aistream.PackRun(*run) if err != nil { t.Fatal(err) } - if len(carriers) < 2 { - t.Fatalf("expected split carriers, got %d", len(carriers)) - } - for i, carrier := range carriers { - if size := aistream.JSONSize(aistream.CarrierContent(*run, carrier.Envelopes)); size > aistream.CarrierBudgetBytes { - t.Fatalf("carrier %d size = %d", i, size) - } - } - for _, carrier := range carriers { - for _, envelope := range carrier.Envelopes { - if envelope.Event.Type() != agui.EventMessagesSnapshot { - continue - } - raw, err := json.Marshal(envelope.Event) - if err != nil { - t.Fatal(err) - } - if strings.Contains(string(raw), strings.Repeat("a", 60*1024)) { - t.Fatal("final snapshot should not repeat full streamed text") - } - } + if len(carriers) != 1 { + t.Fatalf("stream packing must not split by size, got %d carriers", len(carriers)) } if len(aistream.ReconstructText(carriers)) < 60*1024 { t.Fatalf("expected large reconstructed output, got %d", len(aistream.ReconstructText(carriers))) } + _, segments := aistream.FinalUIMessageContent(*run, aistream.FinalMessageBudgetBytes) + if len(segments) == 0 { + t.Fatal("large final UIMessage should be segmented during finalization") + } +} + +func TestBuildAIRunStream50PacksByCadence(t *testing.T) { + run, err := buildAIRun(context.Background(), "run-1", "thread-1", "stream 50 --seed=7 --no-approval", time.Unix(10, 0)) + if err != nil { + t.Fatal(err) + } + carriers, err := aistream.PackRunByTimeFromSeq(*run, 1, demoStreamCarrierMaxSpan) + if err != nil { + t.Fatal(err) + } + if len(carriers) < 10 { + t.Fatalf("stream 50 should produce incremental carriers, got %d", len(carriers)) + } + start := time.Unix(100, 0) + first := aistream.CarrierTimestamp(*run, carriers[0], start) + last := aistream.CarrierTimestamp(*run, carriers[len(carriers)-1], start) + if first.IsZero() || last.IsZero() || !last.After(first) { + t.Fatalf("carrier timestamps should preserve stream cadence, first=%s last=%s", first, last) + } } func TestBuildAIRunPlansChaosCreatesMultipleRuns(t *testing.T) { diff --git a/pkg/connector/client.go b/pkg/connector/client.go index 2a9fb66..996eac9 100644 --- a/pkg/connector/client.go +++ b/pkg/connector/client.go @@ -55,6 +55,7 @@ var _ bridgev2.ReactionHandlingNetworkAPI = (*DummyClient)(nil) const ( dummyAIAgentName string = "Dummy" defaultAIApprovalTimeout = 5 * time.Minute + demoStreamCarrierMaxSpan = 750 * time.Millisecond ) var delayedRemoteEchoPattern = regexp.MustCompile(`(?i)^remote-echo\s+delay\s+([0-9]+(?:ms|s|m|h))$`) @@ -581,7 +582,7 @@ func (dc *DummyClient) queueAIRunStreamAndMetadata(portal *bridgev2.Portal, send func (dc *DummyClient) emitAIRunStream(portal *bridgev2.Portal, sender networkid.UserID, messageID networkid.MessageID, targetEventID id.EventID, run aistream.Run, command string, startSeq int, anchorAt time.Time) { sizingRun := run annotateApprovalEventIDs(&sizingRun, approvalEventIDPlaceholders(sizingRun.Prompts)) - carriers, err := aistream.PackRunFromSeq(sizingRun, aistream.CarrierBudgetBytes, startSeq) + carriers, err := aistream.PackRunByTimeFromSeq(sizingRun, startSeq, demoStreamCarrierMaxSpan) if err != nil { log.Warn().Err(err).Str("run_id", run.RunID).Msg("Failed to pack AI stream") return @@ -620,7 +621,7 @@ func (dc *DummyClient) emitAIRunStream(portal *bridgev2.Portal, sender networkid } if len(approvalEventIDs) > 0 { annotateApprovalEventIDs(&run, approvalEventIDs) - carriers, err = aistream.PackRunFromSeq(run, aistream.CarrierBudgetBytes, startSeq) + carriers, err = aistream.PackRunByTimeFromSeq(run, startSeq, demoStreamCarrierMaxSpan) if err != nil { log.Warn().Err(err).Str("run_id", run.RunID).Msg("Failed to repack AI stream with approval event IDs") return @@ -701,7 +702,7 @@ func (dc *DummyClient) queuePackedAICarriers(portal *bridgev2.Portal, sender net } func (dc *DummyClient) sleepUntilCarrierTime(run aistream.Run, carrier aistream.Carrier, streamStart time.Time) { - target := carrierTimestamp(run, carrier, streamStart) + target := aistream.CarrierTimestamp(run, carrier, streamStart) if target.IsZero() { return } @@ -717,66 +718,6 @@ func (dc *DummyClient) sleepUntilCarrierTime(run aistream.Run, carrier aistream. } } -func carrierTimestamp(run aistream.Run, carrier aistream.Carrier, streamStart time.Time) time.Time { - base := runStartTimestamp(run) - if base.IsZero() { - return time.Time{} - } - var latest time.Time - for _, env := range carrier.Envelopes { - eventTime := eventTimestamp(env.Event) - if eventTime.IsZero() { - continue - } - if latest.IsZero() || eventTime.After(latest) { - latest = eventTime - } - } - if latest.IsZero() { - return time.Time{} - } - return streamStart.Add(latest.Sub(base)) -} - -func runStartTimestamp(run aistream.Run) time.Time { - for _, evt := range run.Events { - if ts := eventTimestamp(evt); !ts.IsZero() { - return ts - } - } - return time.Time{} -} - -func eventTimestamp(evt agui.Event) time.Time { - raw := evt.Get("timestamp") - if !evt.Has("timestamp") { - return time.Time{} - } - var millis int64 - switch value := raw.(type) { - case int64: - millis = value - case int: - millis = int64(value) - case int32: - millis = int64(value) - case float64: - millis = int64(value) - case json.Number: - parsed, err := value.Int64() - if err != nil { - return time.Time{} - } - millis = parsed - default: - return time.Time{} - } - if millis <= 0 { - return time.Time{} - } - return time.UnixMilli(millis) -} - func (dc *DummyClient) waitForMessageMXID( portal *bridgev2.Portal, messageID networkid.MessageID, @@ -828,7 +769,6 @@ func (dc *DummyClient) lookupMessage(ctx context.Context, receiver networkid.Use } func (dc *DummyClient) queueAIApprovalPrompt(portal *bridgev2.Portal, sender networkid.UserID, run aistream.Run, prompt aistream.ApprovalPrompt, targetEventID id.EventID, command string, timestamp time.Time) aistream.ApprovalContext { - choices := aistream.DefaultApprovalChoices() approvalCtx := aistream.ApprovalContext{ ID: prompt.ID, ThreadID: run.ThreadID, @@ -846,11 +786,6 @@ func (dc *DummyClient) queueAIApprovalPrompt(portal *bridgev2.Portal, sender net PreviewTruncated: run.Preview.Truncated, } dc.UserLogin.QueueRemoteEvent(aibridgev2.ApprovalPrompt(portal.PortalKey, sender, approvalCtx, timestamp)) - - for i, choice := range choices { - choice := choice - dc.UserLogin.QueueRemoteEvent(aibridgev2.ApprovalOptionReaction(portal.PortalKey, sender, approvalCtx, choice, timestamp.Add(time.Duration(i+1)*time.Millisecond))) - } return approvalCtx } @@ -895,23 +830,6 @@ func annotateApprovalOutcomeEventIDs(evt agui.Event, eventIDs map[string]id.Even } aistream.SetApprovalInterruptEventID(&outcome.Interrupts[i], string(eventID)) } - case map[string]any: - interrupts, _ := outcome["interrupts"].([]any) - for _, raw := range interrupts { - interrupt, _ := raw.(map[string]any) - approvalID, _ := interrupt["id"].(string) - eventID := eventIDs[approvalID] - if interrupt == nil || eventID == "" { - continue - } - metadata, _ := interrupt["metadata"].(map[string]any) - if metadata == nil { - metadata = map[string]any{} - interrupt["metadata"] = metadata - } - metadata["approvalMessageId"] = approvalID - metadata["approvalEventId"] = string(eventID) - } } } @@ -1072,46 +990,16 @@ func approvalInterruptsFromEvents(events []agui.Event) []agui.Interrupt { continue } interrupts = append(interrupts, outcome.Interrupts...) - case map[string]any: - if outcome["type"] != agui.OutcomeInterrupt { + case *agui.RunFinishedOutcome: + if outcome == nil || outcome.Type != agui.OutcomeInterrupt { continue } - rawInterrupts, _ := outcome["interrupts"].([]any) - for _, raw := range rawInterrupts { - interrupt, _ := raw.(map[string]any) - if len(interrupt) == 0 { - continue - } - metadata, _ := interrupt["metadata"].(map[string]any) - responseSchema, _ := interrupt["responseSchema"].(map[string]any) - interrupts = append(interrupts, agui.Interrupt{ - ID: eventStringValue(interrupt["id"]), - Reason: eventStringValue(interrupt["reason"]), - Message: eventStringValue(interrupt["message"]), - ToolCallID: eventStringValue(interrupt["toolCallId"]), - ResponseSchema: responseSchema, - ExpiresAt: eventStringValue(interrupt["expiresAt"]), - Metadata: metadata, - }) - } + interrupts = append(interrupts, outcome.Interrupts...) } } return interrupts } -func eventStringValue(value any) string { - switch typed := value.(type) { - case string: - return typed - case fmt.Stringer: - return typed.String() - case nil: - return "" - default: - return fmt.Sprint(typed) - } -} - func approvalContinuationStart(events []agui.Event, approvalID string) int { for i, evt := range events { if evt.Type() != agui.EventToolCallResult { @@ -1178,44 +1066,72 @@ func approvalContextFromAny(value any) (aistream.ApprovalContext, bool) { } return validApprovalContext(*typed) case map[string]any: - if nested, ok := typed["com.beeper.ai.approval"]; ok { + if nested, ok := typed[aistream.BeeperAIApprovalKey]; ok { return approvalContextFromAny(nested) } - case *map[string]any: - if typed == nil { - return aistream.ApprovalContext{}, false - } - return approvalContextFromAny(*typed) + return validApprovalContext(approvalContextFromMap(typed)) case json.RawMessage: return approvalContextFromJSON(typed) case []byte: return approvalContextFromJSON(typed) - case string: - return approvalContextFromJSON([]byte(typed)) - } - var ctx aistream.ApprovalContext - raw, err := json.Marshal(value) - if err != nil { - return aistream.ApprovalContext{}, false } - if err = json.Unmarshal(raw, &ctx); err != nil { - return aistream.ApprovalContext{}, false + return aistream.ApprovalContext{}, false +} + +func approvalContextFromMap(raw map[string]any) aistream.ApprovalContext { + return aistream.ApprovalContext{ + ID: stringField(raw, "id"), + ThreadID: stringField(raw, "threadId"), + RunID: stringField(raw, "runId"), + MessageID: stringField(raw, "messageId"), + Command: stringField(raw, "command"), + ToolCallID: stringField(raw, "toolCallId"), + ToolName: stringField(raw, "toolName"), + TargetEvent: stringField(raw, "targetEvent"), + AgentID: stringField(raw, "agentId"), + AgentName: stringField(raw, "agentName"), + Model: stringField(raw, "model"), + SeqStart: intField(raw, "seqStart"), + PreviewText: stringField(raw, "previewText"), + PreviewTruncated: boolField(raw, "previewTruncated"), } - return validApprovalContext(ctx) } func approvalContextFromJSON(raw []byte) (aistream.ApprovalContext, bool) { - var decoded any - if err := json.Unmarshal(raw, &decoded); err == nil { - if approvalCtx, ok := approvalContextFromAny(decoded); ok { + var ctx aistream.ApprovalContext + if err := json.Unmarshal(raw, &ctx); err == nil { + if approvalCtx, ok := validApprovalContext(ctx); ok { return approvalCtx, true } } - var ctx aistream.ApprovalContext - if err := json.Unmarshal(raw, &ctx); err != nil { + var wrapper map[string]any + if err := json.Unmarshal(raw, &wrapper); err != nil { return aistream.ApprovalContext{}, false } - return validApprovalContext(ctx) + return approvalContextFromAny(wrapper) +} + +func stringField(raw map[string]any, key string) string { + value, _ := raw[key].(string) + return value +} + +func intField(raw map[string]any, key string) int { + switch value := raw[key].(type) { + case int: + return value + case int64: + return int(value) + case float64: + return int(value) + default: + return 0 + } +} + +func boolField(raw map[string]any, key string) bool { + value, _ := raw[key].(bool) + return value } func messageIDString(message *database.Message) string { From da36e6601b79cbba931939aeaf0c159e6303d0d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Fri, 29 May 2026 00:16:29 +0200 Subject: [PATCH 4/7] wip --- pkg/connector/ai_parse_helpers.go | 87 ++++++++++++++++++++++++++++--- pkg/connector/ai_runtime_test.go | 54 +++++++++++++++++++ pkg/connector/ai_text.go | 27 ++++++---- 3 files changed, 152 insertions(+), 16 deletions(-) diff --git a/pkg/connector/ai_parse_helpers.go b/pkg/connector/ai_parse_helpers.go index 9b06d7d..a942bef 100644 --- a/pkg/connector/ai_parse_helpers.go +++ b/pkg/connector/ai_parse_helpers.go @@ -146,16 +146,89 @@ func sliceByStep(text string, parts, index int) string { if parts <= 1 || text == "" { return text } - start := 0 - for i := 0; i < index; i++ { - start += splitCount(len(text), parts, i) + units := naturalTextUnits(text) + if len(units) == 0 { + return "" + } + if parts >= len(units) { + if index >= 0 && index < len(units) { + return units[index] + } + return "" + } + + cumulative := make([]int, len(units)+1) + for i, unit := range units { + cumulative[i+1] = cumulative[i] + len(unit) + 2 + } + boundary := func(step int) int { + if step <= 0 { + return 0 + } + if step >= parts { + return len(units) + } + target := cumulative[len(units)] * step / parts + out := 0 + for out < len(units) && cumulative[out] < target { + out++ + } + if out < step { + out = step + } + maxBoundary := len(units) - (parts - step) + if out > maxBoundary { + out = maxBoundary + } + return out } - length := splitCount(len(text), parts, index) - if start >= len(text) || length <= 0 { + start := boundary(index) + end := boundary(index + 1) + if start >= end || start < 0 || end > len(units) { return "" } - end := min(start+length, len(text)) - return text[start:end] + return strings.Join(units[start:end], "\n\n") +} + +func naturalTextUnits(text string) []string { + var units []string + for _, block := range strings.Split(text, "\n\n") { + block = strings.TrimSpace(block) + if block == "" { + continue + } + if isMarkdownSensitiveBlock(block) { + units = append(units, block) + continue + } + units = append(units, splitSentences(block)...) + } + return units +} + +func splitSentences(text string) []string { + var sentences []string + start := 0 + for i := 0; i < len(text); i++ { + switch text[i] { + case '.', '!', '?': + if i+1 < len(text) && text[i+1] != ' ' && text[i+1] != '\n' { + continue + } + sentence := strings.TrimSpace(text[start : i+1]) + if sentence != "" { + sentences = append(sentences, sentence) + } + start = i + 1 + for start < len(text) && (text[start] == ' ' || text[start] == '\n') { + start++ + } + } + } + if tail := strings.TrimSpace(text[start:]); tail != "" { + sentences = append(sentences, tail) + } + return sentences } func sanitizeToolName(name string) string { diff --git a/pkg/connector/ai_runtime_test.go b/pkg/connector/ai_runtime_test.go index 2b85ff1..3c23520 100644 --- a/pkg/connector/ai_runtime_test.go +++ b/pkg/connector/ai_runtime_test.go @@ -563,6 +563,26 @@ func TestBuildAIRunFinalSnapshotPreservesToolParts(t *testing.T) { } } +func TestBuildAIRunFinalUIMessagePreservesTextToolTextOrder(t *testing.T) { + run, err := buildAIRun(context.Background(), "run-1", "thread-1", "stream-tools 420 fetch search --seed=7 --chunk-chars=32:32", time.Unix(10, 0)) + if err != nil { + t.Fatal(err) + } + message := run.FinalBeeperAIMessage(0, true) + var order []string + for _, part := range message.Parts { + switch part["type"] { + case "text": + order = append(order, "text") + case "tool-call": + order = append(order, "tool-call") + } + } + if strings.Join(order, "|") != "text|tool-call|text|tool-call|text" { + t.Fatalf("final UIMessage did not preserve text/tool order: %v\nparts: %#v", order, message.Parts) + } +} + func TestBuildAIRunToolsFailureDeltaAndInputError(t *testing.T) { run, err := buildAIRun(context.Background(), "run-1", "thread-1", "stream-tools 120 shell#fail fetch#delta parser#inputerror --seed=7 --chunk-chars=8:8", time.Unix(10, 0)) if err != nil { @@ -874,6 +894,40 @@ func TestBuildDemoVisibleTextDoesNotCutMarkdownSyntax(t *testing.T) { if strings.Contains(text, "https://dummybridge.") && !strings.Contains(text, "https://dummybridge.local/") { t.Fatalf("cut markdown URL for chars=%d seed=%d: %q", chars, seed, text) } + if joinedMarkdownBlockRE.MatchString(text) { + t.Fatalf("markdown block joined to incomplete text for chars=%d seed=%d: %q", chars, seed, text) + } + } + } +} + +func TestSliceByStepKeepsNaturalTextUnits(t *testing.T) { + text := strings.Join([]string{ + "First complete sentence. Second complete sentence.", + "Review the [release notes](https://dummybridge.local/docs/streaming) entry for **review-ready** output.", + "Third complete sentence. Fourth complete sentence.", + }, "\n\n") + + parts := []string{ + sliceByStep(text, 3, 0), + sliceByStep(text, 3, 1), + sliceByStep(text, 3, 2), + } + joined := strings.Join(parts, "\n\n") + for _, expected := range []string{ + "First complete sentence.", + "Second complete sentence.", + "Review the [release notes](https://dummybridge.local/docs/streaming) entry for **review-ready** output.", + "Third complete sentence.", + "Fourth complete sentence.", + } { + if !strings.Contains(joined, expected) { + t.Fatalf("sliced text lost %q:\n%#v", expected, parts) + } + } + for _, part := range parts { + if strings.HasSuffix(part, "complete") || strings.HasSuffix(part, "Review the") { + t.Fatalf("slice ended with a cut-off unit: %#v", parts) } } } diff --git a/pkg/connector/ai_text.go b/pkg/connector/ai_text.go index 40a6945..689a2f7 100644 --- a/pkg/connector/ai_text.go +++ b/pkg/connector/ai_text.go @@ -84,6 +84,10 @@ func buildLoremText(chars int, rng *rand.Rand) string { return trimText(sb.String(), chars) } +func buildCompleteLoremText(chars int, rng *rand.Rand) string { + return trimCompleteText(buildLoremText(chars+128, rng), chars) +} + func buildDemoVisibleText(chars int, rng *rand.Rand) string { if chars <= 0 { return "" @@ -96,8 +100,8 @@ func buildDemoVisibleText(chars int, rng *rand.Rand) string { return buildLoremText(max(48, min(168, remaining+48)), rand.New(rand.NewSource(rng.Int63()))) }}, {weight: 4, minLen: 96, build: func(rng *rand.Rand, _ int) string { - return fmt.Sprintf("%s Review the [%s](%s) entry for **%s** output and _staged_ formatting transitions.", - buildLoremText(72+rng.Intn(48), rand.New(rand.NewSource(rng.Int63()))), + return fmt.Sprintf("%s\n\nReview the [%s](%s) entry for **%s** output and _staged_ formatting transitions.", + buildCompleteLoremText(72+rng.Intn(48), rand.New(rand.NewSource(rng.Int63()))), demoMarkdownLabels[rng.Intn(len(demoMarkdownLabels))], demoMarkdownURLs[rng.Intn(len(demoMarkdownURLs))], demoMarkdownEmphasis[rng.Intn(len(demoMarkdownEmphasis))]) @@ -183,7 +187,7 @@ func trimVisibleText(text string, limit int) string { if isMarkdownSensitiveBlock(block) { kept = append(kept, trimMarkdownBlock(block, limit)) } else { - kept = append(kept, trimText(block, limit)) + kept = append(kept, trimCompleteText(block, limit)) } } break @@ -194,7 +198,7 @@ func trimVisibleText(text string, limit int) string { if len(kept) > 0 { return strings.Join(kept, "\n\n") } - return trimText(text, limit) + return trimCompleteText(text, limit) } func isMarkdownSensitiveBlock(block string) bool { @@ -224,12 +228,16 @@ func trimMarkdownBlock(block string, limit int) string { } } if trimmed == "" { - return trimText(block, limit) + return trimCompleteText(block, limit) } return trimmed } func trimText(text string, limit int) string { + return trimCompleteText(text, limit) +} + +func trimCompleteText(text string, limit int) string { text = strings.TrimSpace(text) if limit <= 0 || len(text) <= limit { return text @@ -241,10 +249,11 @@ func trimText(text string, limit int) string { return strings.TrimSpace(text[:i]) } } - for i := min(limit, len(text)); i >= minCutoff; i-- { - if text[i-1] == ' ' { - return strings.Trim(strings.TrimSpace(text[:i]), ".,;:") + for i := min(limit+128, len(text)); i > limit; i++ { + switch text[i-1] { + case '.', '!', '?': + return strings.TrimSpace(text[:i]) } } - return strings.Trim(strings.TrimSpace(text[:limit]), ".,;:") + return text } From a17771e2c4391818dec5bd4c1b120877cb4ec41e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 31 May 2026 22:19:12 +0200 Subject: [PATCH 5/7] wip --- pkg/connector/ai_runtime_test.go | 10 +- pkg/connector/client.go | 242 ++++++++++++++++++++++++++----- pkg/connector/commands.go | 10 +- 3 files changed, 217 insertions(+), 45 deletions(-) diff --git a/pkg/connector/ai_runtime_test.go b/pkg/connector/ai_runtime_test.go index 3c23520..51b0752 100644 --- a/pkg/connector/ai_runtime_test.go +++ b/pkg/connector/ai_runtime_test.go @@ -12,6 +12,7 @@ import ( "github.com/beeper/ai-bridge/pkg/ag-ui" "github.com/beeper/ai-bridge/pkg/ai-stream" + aimatrix "github.com/beeper/ai-bridge/pkg/ai-stream/matrix" "maunium.net/go/mautrix/id" ) @@ -675,9 +676,12 @@ func TestBuildAIRunOver64KBStreamsWithoutCarrierSizeSplit(t *testing.T) { if len(aistream.ReconstructText(carriers)) < 60*1024 { t.Fatalf("expected large reconstructed output, got %d", len(aistream.ReconstructText(carriers))) } - _, segments := aistream.FinalUIMessageContent(*run, aistream.FinalMessageBudgetBytes) - if len(segments) == 0 { - t.Fatal("large final UIMessage should be segmented during finalization") + projection := aimatrix.ProjectFinal(*run, nil) + if !projection.NeedsAttachment { + t.Fatal("large final UIMessage should use final-parts attachment projection") + } + if len(projection.Message.Parts) == 0 { + t.Fatal("large final projection should preserve full UIMessage parts for attachment upload") } } diff --git a/pkg/connector/client.go b/pkg/connector/client.go index 996eac9..19d21e9 100644 --- a/pkg/connector/client.go +++ b/pkg/connector/client.go @@ -13,6 +13,7 @@ import ( "github.com/beeper/ai-bridge/pkg/ag-ui" "github.com/beeper/ai-bridge/pkg/ai-stream" aibridgev2 "github.com/beeper/ai-bridge/pkg/ai-stream/bridgev2" + aimatrix "github.com/beeper/ai-bridge/pkg/ai-stream/matrix" "github.com/rs/zerolog/log" "go.mau.fi/util/exsync" "go.mau.fi/util/jsontime" @@ -507,20 +508,48 @@ func cloneMessageContent(content *event.MessageEventContent) *event.MessageEvent return &cloned } +type aiRunTarget struct { + portal *bridgev2.Portal + bot bridgev2.MatrixAPI + roomID id.RoomID + threadID string + sender networkid.UserID + agentName string +} + func (dc *DummyClient) queueAIResponse(ctx context.Context, portal *bridgev2.Portal, inbound *event.MessageEventContent) { if portal == nil { return } + dc.queueAIResponseToTarget(ctx, aiRunTarget{ + portal: portal, + threadID: string(portal.ID), + sender: dummyAISenderForPortal(portal), + agentName: dummyAIAgentNameForPortal(portal), + }, inbound) +} + +func (dc *DummyClient) queueAIResponseInRoom(ctx context.Context, bot bridgev2.MatrixAPI, roomID id.RoomID, inbound *event.MessageEventContent) { + if bot == nil || roomID == "" { + return + } + dc.queueAIResponseToTarget(ctx, aiRunTarget{ + bot: bot, + roomID: roomID, + threadID: string(roomID), + sender: networkid.UserID(dummyAIAgentName), + agentName: dummyAIAgentName, + }, inbound) +} +func (dc *DummyClient) queueAIResponseToTarget(ctx context.Context, target aiRunTarget, inbound *event.MessageEventContent) { now := time.Now() runID := "run-" + string(randomMessageID()) - sender := dummyAISenderForPortal(portal) - agentName := dummyAIAgentNameForPortal(portal) var body string if inbound != nil { body = inbound.Body } - plans, err := buildAIRunPlans(ctx, runID, string(portal.ID), body, now, string(sender), agentName) + plans, err := buildAIRunPlans(ctx, runID, target.threadID, body, now, string(target.sender), target.agentName) if err != nil { log.Warn().Err(err).Msg("Failed to build AI runs") return @@ -536,7 +565,7 @@ func (dc *DummyClient) queueAIResponse(ctx context.Context, portal *bridgev2.Por } dc.wg.Add(1) - go func(portal *bridgev2.Portal, sender networkid.UserID, messageID networkid.MessageID, run aistream.Run, command string, delay time.Duration) { + go func(target aiRunTarget, messageID networkid.MessageID, run aistream.Run, command string, delay time.Duration) { defer dc.wg.Done() if delay > 0 { timer := time.NewTimer(delay) @@ -547,11 +576,14 @@ func (dc *DummyClient) queueAIResponse(ctx context.Context, portal *bridgev2.Por return } } - dc.ensureAISenderInvited(portal, sender) anchorAt := time.Now() - dc.UserLogin.QueueRemoteEvent(aibridgev2.Anchor(portal.PortalKey, sender, initialAIAnchorRun(run), anchorAt)) - dc.queueAIRunStreamAndMetadata(portal, sender, messageID, run, command, anchorAt) - }(portal, sender, placeholderID, *plan.Run, effectiveCommand, plan.Delay) + targetEventID, err := target.sendAnchor(dc, initialAIAnchorRun(run), messageID, anchorAt) + if err != nil { + log.Warn().Err(err).Str("run_id", run.RunID).Msg("Failed to send AI anchor") + return + } + dc.emitAIRunStream(target, messageID, targetEventID, run, command, 1, anchorAt) + }(target, placeholderID, *plan.Run, effectiveCommand, plan.Delay) } } @@ -562,24 +594,12 @@ func initialAIAnchorRun(run aistream.Run) aistream.Run { return run } -func (dc *DummyClient) queueAIRunStreamAndMetadata(portal *bridgev2.Portal, sender networkid.UserID, messageID networkid.MessageID, run aistream.Run, command string, anchorAt time.Time) { - targetEventID := dc.waitForMessageMXID(portal, messageID, 30*time.Second) - if targetEventID == "" { - log.Warn(). - Str("run_id", run.RunID). - Str("message_id", string(messageID)). - Msg("Timed out waiting for AI anchor Matrix event") - return - } - dc.emitAIRunStream(portal, sender, messageID, targetEventID, run, command, 1, anchorAt) -} - // emitAIRunStream packs and emits one segment of an AI run — used both for // the initial run and for any approval continuation. It queues approval // prompts produced by the segment, repacks once approval event IDs are // known, and finally emits the carriers and (if the run terminated) the // final metadata edit. -func (dc *DummyClient) emitAIRunStream(portal *bridgev2.Portal, sender networkid.UserID, messageID networkid.MessageID, targetEventID id.EventID, run aistream.Run, command string, startSeq int, anchorAt time.Time) { +func (dc *DummyClient) emitAIRunStream(target aiRunTarget, messageID networkid.MessageID, targetEventID id.EventID, run aistream.Run, command string, startSeq int, anchorAt time.Time) { sizingRun := run annotateApprovalEventIDs(&sizingRun, approvalEventIDPlaceholders(sizingRun.Prompts)) carriers, err := aistream.PackRunByTimeFromSeq(sizingRun, startSeq, demoStreamCarrierMaxSpan) @@ -601,21 +621,22 @@ func (dc *DummyClient) emitAIRunStream(portal *bridgev2.Portal, sender networkid if hasActivePrompt { prompt := activePrompt prompt.SeqStart = nextSeq - ctx := dc.queueAIApprovalPrompt(portal, sender, run, prompt, targetEventID, command, time.Now()) - dc.scheduleAIApprovalTimeout(portal, networkid.MessageID(ctx.ID), approvalQueue.Timeout()) - if approvalEventID := dc.waitForMessageMXID(portal, networkid.MessageID(ctx.ID), 10*time.Second); approvalEventID != "" { - approvalEventIDs[ctx.ID] = approvalEventID + approvalCtx, approvalEventID, err := target.sendApprovalPrompt(dc, run, prompt, targetEventID, command, time.Now()) + if err == nil && approvalEventID != "" { + target.scheduleApprovalTimeout(dc, approvalCtx, approvalQueue.Timeout()) + approvalEventIDs[approvalCtx.ID] = approvalEventID log.Info(). Str("run_id", run.RunID). - Str("approval_id", ctx.ID). + Str("approval_id", approvalCtx.ID). Stringer("approval_event_id", approvalEventID). - Int("approval_seq_start", ctx.SeqStart). + Int("approval_seq_start", approvalCtx.SeqStart). Msg("AI approval notice ready for reaction") } else { log.Warn(). + Err(err). Str("run_id", run.RunID). - Str("approval_id", ctx.ID). - Int("approval_seq_start", ctx.SeqStart). + Str("approval_id", approvalCtx.ID). + Int("approval_seq_start", approvalCtx.SeqStart). Msg("Timed out waiting for AI approval notice Matrix event") } } @@ -639,7 +660,7 @@ func (dc *DummyClient) emitAIRunStream(portal *bridgev2.Portal, sender networkid Str("run_id", run.RunID). Msg("Sending approval stream without approval event IDs") } - dc.queuePackedAICarriers(portal, sender, targetEventID, run, carriers, startSeq, anchorAt) + target.sendCarriers(dc, targetEventID, run, carriers, startSeq, anchorAt) if len(run.Prompts) > 0 && run.Status.State == "streaming" { log.Info(). Str("run_id", run.RunID). @@ -648,7 +669,7 @@ func (dc *DummyClient) emitAIRunStream(portal *bridgev2.Portal, sender networkid Msg("AI run paused for approval") } if run.Status.State != "streaming" { - dc.queueAIRunFinalMetadata(portal, sender, messageID, run) + target.sendFinal(dc, messageID, targetEventID, run, time.Now()) } } @@ -701,6 +722,153 @@ func (dc *DummyClient) queuePackedAICarriers(portal *bridgev2.Portal, sender net } } +func (target aiRunTarget) sendAnchor(dc *DummyClient, run aistream.Run, messageID networkid.MessageID, timestamp time.Time) (id.EventID, error) { + if target.portal != nil { + dc.ensureAISenderInvited(target.portal, target.sender) + dc.UserLogin.QueueRemoteEvent(aibridgev2.Anchor(target.portal.PortalKey, target.sender, run, timestamp)) + eventID := dc.waitForMessageMXID(target.portal, messageID, 30*time.Second) + if eventID == "" { + return "", fmt.Errorf("timed out waiting for AI anchor Matrix event") + } + return eventID, nil + } + content, extra := aimatrix.AnchorContent(run) + return dc.sendAIMessageToRoom(target.bot, target.roomID, content, extra, timestamp) +} + +func (target aiRunTarget) sendApprovalPrompt(dc *DummyClient, run aistream.Run, prompt aistream.ApprovalPrompt, targetEventID id.EventID, command string, timestamp time.Time) (aistream.ApprovalContext, id.EventID, error) { + approvalCtx := approvalContextForPrompt(run, prompt, targetEventID, command) + if target.portal != nil { + approvalCtx = dc.queueAIApprovalPrompt(target.portal, target.sender, run, prompt, targetEventID, command, timestamp) + eventID := dc.waitForMessageMXID(target.portal, networkid.MessageID(approvalCtx.ID), 10*time.Second) + if eventID == "" { + return approvalCtx, "", fmt.Errorf("timed out waiting for AI approval notice Matrix event") + } + return approvalCtx, eventID, nil + } + eventID, err := dc.sendAIApprovalPromptToRoom(target.bot, target.roomID, approvalCtx, timestamp) + return approvalCtx, eventID, err +} + +func (target aiRunTarget) scheduleApprovalTimeout(dc *DummyClient, approvalCtx aistream.ApprovalContext, timeout aistream.ApprovalTimeout) { + if target.portal != nil { + dc.scheduleAIApprovalTimeout(target.portal, networkid.MessageID(approvalCtx.ID), timeout) + return + } + if dc == nil || target.bot == nil || target.roomID == "" || approvalCtx.ID == "" || timeout.After <= 0 { + return + } + dc.wg.Add(1) + go func() { + defer dc.wg.Done() + timer := time.NewTimer(timeout.After) + defer timer.Stop() + select { + case <-dc.clientContext().Done(): + return + case <-timer.C: + } + if _, firstResolution := dc.resolveApprovalOnce(approvalCtx.ID, timeout.Reason); !firstResolution { + return + } + response := aistream.TimedOutApprovalResponse(approvalCtx.ID) + approvals := dc.recordAIApprovalDecision(approvalCtx.RunID, response) + run, err := buildAIApprovalContinuationRunWithApprovals(dc.clientContext(), approvalCtx, approvals, time.Now()) + if err != nil { + log.Warn().Err(err).Str("approval_id", approvalCtx.ID).Msg("Failed to build timed-out AI approval continuation") + return + } + dc.emitAIRunStream(target, networkid.MessageID(approvalCtx.MessageID), id.EventID(approvalCtx.TargetEvent), run, approvalCtx.Command, approvalCtx.SeqStart, time.Now()) + }() +} + +func (target aiRunTarget) sendCarriers(dc *DummyClient, targetEventID id.EventID, run aistream.Run, carriers []aistream.Carrier, startSeq int, anchorAt time.Time) { + if target.portal != nil { + dc.queuePackedAICarriers(target.portal, target.sender, targetEventID, run, carriers, startSeq, anchorAt) + return + } + dc.sendPackedAICarriersToRoom(target.bot, target.roomID, targetEventID, run, carriers, startSeq, anchorAt) +} + +func (target aiRunTarget) sendFinal(dc *DummyClient, messageID networkid.MessageID, targetEventID id.EventID, run aistream.Run, timestamp time.Time) { + if target.portal != nil { + dc.queueAIRunFinalMetadata(target.portal, target.sender, messageID, run) + return + } + dc.sendAIRunFinalToRoom(target.bot, target.roomID, targetEventID, run, timestamp) +} + +func approvalContextForPrompt(run aistream.Run, prompt aistream.ApprovalPrompt, targetEventID id.EventID, command string) aistream.ApprovalContext { + return aistream.ApprovalContext{ + ID: prompt.ID, + ThreadID: run.ThreadID, + RunID: run.RunID, + MessageID: run.MessageID, + Command: command, + ToolCallID: prompt.ToolCallID, + ToolName: prompt.ToolName, + TargetEvent: string(targetEventID), + AgentID: run.AgentID, + AgentName: run.AgentName, + Model: run.Model, + SeqStart: prompt.SeqStart, + PreviewText: run.Preview.Text, + PreviewTruncated: run.Preview.Truncated, + } +} + +func (dc *DummyClient) sendAIApprovalPromptToRoom(bot bridgev2.MatrixAPI, roomID id.RoomID, approvalCtx aistream.ApprovalContext, timestamp time.Time) (id.EventID, error) { + content, extra := aimatrix.ApprovalContent(approvalCtx, aistream.DefaultApprovalChoices()) + return dc.sendAIMessageToRoom(bot, roomID, content, extra, timestamp) +} + +func (dc *DummyClient) sendPackedAICarriersToRoom(bot bridgev2.MatrixAPI, roomID id.RoomID, targetEventID id.EventID, run aistream.Run, carriers []aistream.Carrier, startSeq int, anchorAt time.Time) { + streamStart := time.Now() + minCarrierTimestamp := anchorAt.Add(time.Millisecond) + if streamStart.Before(minCarrierTimestamp) { + streamStart = minCarrierTimestamp + } + for i, carrier := range carriers { + dc.sleepUntilCarrierTime(run, carrier, streamStart) + now := time.Now() + if now.Before(minCarrierTimestamp) { + now = minCarrierTimestamp + } + minCarrierTimestamp = now.Add(time.Nanosecond) + content, extra := aimatrix.CarrierContent(run, carrier, targetEventID) + if _, err := dc.sendAIMessageToRoom(bot, roomID, content, extra, now); err != nil { + log.Warn().Err(err).Str("run_id", run.RunID).Int("carrier_index", startSeq+i).Msg("Failed to send AI stream carrier to Matrix room") + return + } + } +} + +func (dc *DummyClient) sendAIRunFinalToRoom(bot bridgev2.MatrixAPI, roomID id.RoomID, targetEventID id.EventID, run aistream.Run, timestamp time.Time) { + content, extra := aimatrix.FinalContent(run) + content.SetEdit(targetEventID) + raw := map[string]any{ + "m.new_content": extra, + "com.beeper.dont_render_edited": true, + } + if _, err := dc.sendAIMessageToRoom(bot, roomID, content, raw, timestamp); err != nil { + log.Warn().Err(err).Str("run_id", run.RunID).Msg("Failed to send AI final edit to Matrix room") + } +} + +func (dc *DummyClient) sendAIMessageToRoom(bot bridgev2.MatrixAPI, roomID id.RoomID, content *event.MessageEventContent, extra map[string]any, timestamp time.Time) (id.EventID, error) { + resp, err := bot.SendMessage(dc.clientContext(), roomID, event.EventMessage, &event.Content{ + Parsed: content, + Raw: extra, + }, &bridgev2.MatrixSendExtra{Timestamp: timestamp}) + if err != nil { + return "", err + } + if resp == nil { + return "", nil + } + return resp.EventID, nil +} + func (dc *DummyClient) sleepUntilCarrierTime(run aistream.Run, carrier aistream.Carrier, streamStart time.Time) { target := aistream.CarrierTimestamp(run, carrier, streamStart) if target.IsZero() { @@ -872,8 +1040,12 @@ func (dc *DummyClient) queueAIApprovalResponse(ctx context.Context, portal *brid if sender == "" { sender = dummyAISenderForPortal(portal) } - dc.ensureAISenderInvited(portal, sender) - dc.emitAIRunStream(portal, sender, networkid.MessageID(approvalCtx.MessageID), targetEventID, run, approvalCtx.Command, approvalCtx.SeqStart, now) + dc.emitAIRunStream(aiRunTarget{ + portal: portal, + threadID: approvalCtx.ThreadID, + sender: sender, + agentName: approvalCtx.AgentName, + }, networkid.MessageID(approvalCtx.MessageID), targetEventID, run, approvalCtx.Command, approvalCtx.SeqStart, now) log.Info(). Str("run_id", approvalCtx.RunID). Str("approval_id", approvalCtx.ID). @@ -1152,10 +1324,6 @@ func validApprovalContext(ctx aistream.ApprovalContext) (aistream.ApprovalContex } func (dc *DummyClient) queueAIRunFinalMetadata(portal *bridgev2.Portal, sender networkid.UserID, messageID networkid.MessageID, run aistream.Run) { - targetEventID := dc.waitForMessageMXID(portal, messageID, 30*time.Second) - for _, segment := range aibridgev2.FinalSegments(portal.PortalKey, sender, run, targetEventID, time.Now()) { - dc.UserLogin.QueueRemoteEvent(segment) - } dc.UserLogin.QueueRemoteEvent(aibridgev2.FinalMetadataEdit(portal.PortalKey, sender, messageID, run, time.Now())) } diff --git a/pkg/connector/commands.go b/pkg/connector/commands.go index e3934e1..1212972 100644 --- a/pkg/connector/commands.go +++ b/pkg/connector/commands.go @@ -268,10 +268,6 @@ var FileCommand = &commands.FullHandler{ } func runStreamCommand(e *commands.Event, name string) { - if e.Portal == nil { - e.Reply("Can only stream within a portal") - return - } login := e.User.GetDefaultLogin() if login == nil { e.Reply("No login") @@ -287,7 +283,11 @@ func runStreamCommand(e *commands.Event, name string) { e.Reply(err.Error()) return } - client.queueAIResponse(e.Ctx, e.Portal, &event.MessageEventContent{Body: body}) + if e.Portal != nil { + client.queueAIResponse(e.Ctx, e.Portal, &event.MessageEventContent{Body: body}) + } else { + client.queueAIResponseInRoom(e.Ctx, e.Bot, e.RoomID, &event.MessageEventContent{Body: body}) + } e.Reply("Started %s", name) } From 7e81e074c3299e7d14f100e89bd598f3a843553c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 31 May 2026 22:27:55 +0200 Subject: [PATCH 6/7] Update ai-bridge usage and approval context Bump github.com/beeper/ai-bridge and adapt code to its updated APIs. Replace ToolApprovalRequestedWithMetadata with ToolApprovalRequestedWithRequest in ai_runner, constructing an ApprovalRequest. Expand approval context to include title/description/plan/expires/choices/metadata and add parsing helpers (mapField, approvalChoicesField). Refactor DummyClient message waiting to use aibridgev2.WaitForMessageEventID with per-receiver timeouts, remove the old lookupMessageMXID usage, and centralize approval context creation via approvalContextForPrompt. These changes align the connector with the newer ai-bridge contract and provide richer approval metadata and more robust message lookup. --- go.mod | 2 +- pkg/connector/ai_runner.go | 9 +++- pkg/connector/client.go | 107 +++++++++++++++++++++++-------------- 3 files changed, 76 insertions(+), 42 deletions(-) diff --git a/go.mod b/go.mod index 0a5a442..c355634 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.25.0 toolchain go1.25.6 require ( - github.com/beeper/ai-bridge v0.0.0-20260524021151-5c8086351a72 + github.com/beeper/ai-bridge v0.0.0-20260531201429-3d0bf92ccf00 github.com/rs/zerolog v1.35.1 go.mau.fi/util v0.9.9 maunium.net/go/mautrix v0.27.1-0.20260513120123-5fba7e3afae4 diff --git a/pkg/connector/ai_runner.go b/pkg/connector/ai_runner.go index 6f8b38c..aa202e7 100644 --- a/pkg/connector/ai_runner.go +++ b/pkg/connector/ai_runner.go @@ -255,7 +255,14 @@ func (r aiRunner) runToolSpec(ctx context.Context, w *aistream.Writer, spec tool } w.ToolApprovalInputComplete(toolCallID, spec.Name, input) annotateProviderRawEvent(w, spec, "tool_call_input_complete") - w.ToolApprovalRequestedWithMetadata(toolCallID, spec.Name, input, *approval, displayMetadata) + w.ToolApprovalRequestedWithRequest(aistream.ApprovalRequest{ + ID: approvalID, + ToolCallID: toolCallID, + ToolName: spec.Name, + Input: input, + Approval: *approval, + Metadata: displayMetadata, + }) annotateProviderRawEvent(w, spec, "approval_requested") return errApprovalRequested case spec.Deny: diff --git a/pkg/connector/client.go b/pkg/connector/client.go index 19d21e9..a196907 100644 --- a/pkg/connector/client.go +++ b/pkg/connector/client.go @@ -807,6 +807,11 @@ func approvalContextForPrompt(run aistream.Run, prompt aistream.ApprovalPrompt, Command: command, ToolCallID: prompt.ToolCallID, ToolName: prompt.ToolName, + Title: prompt.Title, + Description: prompt.Description, + PlanText: prompt.PlanText, + ExpiresAt: prompt.ExpiresAt, + Choices: aistream.DefaultApprovalChoices(), TargetEvent: string(targetEventID), AgentID: run.AgentID, AgentName: run.AgentName, @@ -814,6 +819,7 @@ func approvalContextForPrompt(run aistream.Run, prompt aistream.ApprovalPrompt, SeqStart: prompt.SeqStart, PreviewText: run.Preview.Text, PreviewTruncated: run.Preview.Truncated, + Metadata: prompt.Metadata, } } @@ -894,41 +900,37 @@ func (dc *DummyClient) waitForMessageMXID( if dc == nil || dc.UserLogin == nil || dc.UserLogin.Bridge == nil || dc.UserLogin.Bridge.DB == nil || portal == nil { return "" } - parent := dc.clientContext() - ctx, cancel := context.WithTimeout(parent, timeout) - defer cancel() receivers := []networkid.UserLoginID{portal.Receiver} if dc.UserLogin.ID != "" && dc.UserLogin.ID != portal.Receiver { receivers = append(receivers, dc.UserLogin.ID) } - ticker := time.NewTicker(100 * time.Millisecond) - defer ticker.Stop() - - for ctx.Err() == nil { - select { - case <-ctx.Done(): - return "" - case <-ticker.C: + perReceiverTimeout := timeout + if perReceiverTimeout <= 0 { + perReceiverTimeout = 5 * time.Second + } + if len(receivers) > 1 { + perReceiverTimeout /= time.Duration(len(receivers)) + if perReceiverTimeout < time.Second { + perReceiverTimeout = time.Second } - for _, receiver := range receivers { - mxid := dc.lookupMessageMXID(ctx, receiver, messageID) - if mxid != "" { - return mxid - } + } + for _, receiver := range receivers { + eventID, err := aibridgev2.WaitForMessageEventID( + dc.clientContext(), + dc.UserLogin.Bridge, + receiver, + messageID, + networkid.PartID("0"), + perReceiverTimeout, + ) + if err == nil && eventID != "" { + return eventID } } return "" } -func (dc *DummyClient) lookupMessageMXID(ctx context.Context, receiver networkid.UserLoginID, messageID networkid.MessageID) id.EventID { - message, err := dc.lookupMessage(ctx, receiver, messageID) - if err != nil || message == nil { - return "" - } - return message.MXID -} - func (dc *DummyClient) lookupMessage(ctx context.Context, receiver networkid.UserLoginID, messageID networkid.MessageID) (*database.Message, error) { if dc == nil || dc.UserLogin == nil || dc.UserLogin.Bridge == nil || dc.UserLogin.Bridge.DB == nil { return nil, nil @@ -937,22 +939,7 @@ func (dc *DummyClient) lookupMessage(ctx context.Context, receiver networkid.Use } func (dc *DummyClient) queueAIApprovalPrompt(portal *bridgev2.Portal, sender networkid.UserID, run aistream.Run, prompt aistream.ApprovalPrompt, targetEventID id.EventID, command string, timestamp time.Time) aistream.ApprovalContext { - approvalCtx := aistream.ApprovalContext{ - ID: prompt.ID, - ThreadID: run.ThreadID, - RunID: run.RunID, - MessageID: run.MessageID, - Command: command, - ToolCallID: prompt.ToolCallID, - ToolName: prompt.ToolName, - TargetEvent: string(targetEventID), - AgentID: run.AgentID, - AgentName: run.AgentName, - Model: run.Model, - SeqStart: prompt.SeqStart, - PreviewText: run.Preview.Text, - PreviewTruncated: run.Preview.Truncated, - } + approvalCtx := approvalContextForPrompt(run, prompt, targetEventID, command) dc.UserLogin.QueueRemoteEvent(aibridgev2.ApprovalPrompt(portal.PortalKey, sender, approvalCtx, timestamp)) return approvalCtx } @@ -1259,6 +1246,11 @@ func approvalContextFromMap(raw map[string]any) aistream.ApprovalContext { Command: stringField(raw, "command"), ToolCallID: stringField(raw, "toolCallId"), ToolName: stringField(raw, "toolName"), + Title: stringField(raw, "title"), + Description: stringField(raw, "description"), + PlanText: stringField(raw, "planText"), + ExpiresAt: stringField(raw, "expiresAt"), + Choices: approvalChoicesField(raw, "choices"), TargetEvent: stringField(raw, "targetEvent"), AgentID: stringField(raw, "agentId"), AgentName: stringField(raw, "agentName"), @@ -1266,6 +1258,7 @@ func approvalContextFromMap(raw map[string]any) aistream.ApprovalContext { SeqStart: intField(raw, "seqStart"), PreviewText: stringField(raw, "previewText"), PreviewTruncated: boolField(raw, "previewTruncated"), + Metadata: mapField(raw, "metadata"), } } @@ -1306,6 +1299,40 @@ func boolField(raw map[string]any, key string) bool { return value } +func mapField(raw map[string]any, key string) map[string]any { + switch value := raw[key].(type) { + case map[string]any: + return value + default: + return nil + } +} + +func approvalChoicesField(raw map[string]any, key string) []aistream.ApprovalChoice { + switch value := raw[key].(type) { + case []aistream.ApprovalChoice: + return value + case []any: + choices := make([]aistream.ApprovalChoice, 0, len(value)) + for _, item := range value { + rawChoice, ok := item.(map[string]any) + if !ok { + return nil + } + choices = append(choices, aistream.ApprovalChoice{ + Key: stringField(rawChoice, "key"), + Label: stringField(rawChoice, "label"), + Alias: stringField(rawChoice, "alias"), + Style: stringField(rawChoice, "style"), + Shortcut: stringField(rawChoice, "shortcut"), + }) + } + return choices + default: + return nil + } +} + func messageIDString(message *database.Message) string { if message == nil { return "" From ae266ba8d8c8ac872d5c6758736ec58195d6d7c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 31 May 2026 22:32:12 +0200 Subject: [PATCH 7/7] Remove local replace for ai-bridge Remove the replace directive that pointed github.com/beeper/ai-bridge to a local ../ai-bridge path in go.mod so the module is resolved from its published version. go.sum is updated with checksum entries for the resolved ai-bridge pseudo-version. --- go.mod | 2 -- go.sum | 2 ++ 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/go.mod b/go.mod index c355634..7455170 100644 --- a/go.mod +++ b/go.mod @@ -11,8 +11,6 @@ require ( maunium.net/go/mautrix v0.27.1-0.20260513120123-5fba7e3afae4 ) -replace github.com/beeper/ai-bridge => ../ai-bridge - require ( filippo.io/edwards25519 v1.2.0 // indirect github.com/coder/websocket v1.8.14 // indirect diff --git a/go.sum b/go.sum index 22abad4..6463929 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,8 @@ filippo.io/edwards25519 v1.2.0 h1:crnVqOiS4jqYleHd9vaKZ+HKtHfllngJIiOpNpoJsjo= filippo.io/edwards25519 v1.2.0/go.mod h1:xzAOLCNug/yB62zG1bQ8uziwrIqIuxhctzJT18Q77mc= github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= +github.com/beeper/ai-bridge v0.0.0-20260531201429-3d0bf92ccf00 h1:RIdSWhnzWxhNpt9evjb5kmCNjfgj6Hrl+Kd75yut43c= +github.com/beeper/ai-bridge v0.0.0-20260531201429-3d0bf92ccf00/go.mod h1:+icZV4D9wnp0NTP8bsfS/WXrf/8plzmnp/3bhQEnL3E= github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g= github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg= github.com/coreos/go-systemd/v22 v22.7.0 h1:LAEzFkke61DFROc7zNLX/WA2i5J8gYqe0rSj9KI28KA=