diff --git a/build.sh b/build.sh new file mode 100755 index 0000000..8c13e8b --- /dev/null +++ b/build.sh @@ -0,0 +1,4 @@ +#!/usr/bin/env sh +set -eu + +go build -o dummybridge ./cmd/dummybridge diff --git a/cmd/dummybridge/Dockerfile b/cmd/dummybridge/Dockerfile index 65bcc74..ce4ab2d 100644 --- a/cmd/dummybridge/Dockerfile +++ b/cmd/dummybridge/Dockerfile @@ -1,4 +1,4 @@ -FROM golang:1-alpine3.23 AS builder +FROM golang:1.25-alpine3.23 AS builder RUN apk add --no-cache build-base olm-dev @@ -7,8 +7,11 @@ WORKDIR /build ENV CGO_ENABLED=1 RUN go build -o /usr/bin/dummybridge ./cmd/dummybridge -FROM alpine:3.20 +FROM alpine:3.23 -RUN apk add --no-cache ca-certificates olm su-exec bash jq yq curl -COPY --from=builder /usr/bin/dummybridge /usr/bin/dummybridge +RUN apk add --no-cache ca-certificates olm su-exec bash jq yq curl \ + && addgroup -S dummybridge \ + && adduser -S -G dummybridge dummybridge +COPY --from=builder --chown=dummybridge:dummybridge /usr/bin/dummybridge /usr/bin/dummybridge +USER dummybridge:dummybridge CMD ["/usr/bin/dummybridge"] diff --git a/cmd/loginhelper/Dockerfile b/cmd/loginhelper/Dockerfile index 856ff03..76cbcde 100644 --- a/cmd/loginhelper/Dockerfile +++ b/cmd/loginhelper/Dockerfile @@ -1,4 +1,4 @@ -FROM golang:1-alpine3.20 AS builder +FROM golang:1.25-alpine3.23 AS builder RUN apk add --no-cache build-base @@ -7,8 +7,11 @@ WORKDIR /build ENV CGO_ENABLED=1 RUN go build -o /usr/bin/loginhelper ./cmd/loginhelper -FROM alpine:3.20 +FROM alpine:3.23 -RUN apk add --no-cache ca-certificates -COPY --from=builder /usr/bin/loginhelper /usr/bin/loginhelper +RUN apk add --no-cache ca-certificates \ + && addgroup -S loginhelper \ + && adduser -S -G loginhelper loginhelper +COPY --from=builder --chown=loginhelper:loginhelper /usr/bin/loginhelper /usr/bin/loginhelper +USER loginhelper:loginhelper CMD ["/usr/bin/loginhelper"] diff --git a/go.mod b/go.mod index 88ab339..02c9550 100644 --- a/go.mod +++ b/go.mod @@ -1,38 +1,39 @@ module 
github.com/beeper/dummybridge -go 1.24.0 +go 1.25.0 toolchain go1.25.6 require ( - github.com/rs/zerolog v1.34.0 - go.mau.fi/util v0.9.5 - maunium.net/go/mautrix v0.26.3-0.20260119125818-e28f7170bc4b + github.com/beeper/ai-bridge v0.0.0-20260524021151-5c8086351a72 + github.com/rs/zerolog v1.35.1 + go.mau.fi/util v0.9.9-0.20260511124621-9241e81bdf25 + maunium.net/go/mautrix v0.27.1-0.20260513120123-5fba7e3afae4 ) require ( - filippo.io/edwards25519 v1.1.0 // indirect + filippo.io/edwards25519 v1.2.0 // indirect github.com/coder/websocket v1.8.14 // indirect - github.com/coreos/go-systemd/v22 v22.6.0 // indirect - github.com/lib/pq v1.10.9 // indirect + github.com/coreos/go-systemd/v22 v22.7.0 // indirect + github.com/lib/pq v1.12.3 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect - github.com/mattn/go-sqlite3 v1.14.33 // indirect - github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741 // indirect + github.com/mattn/go-sqlite3 v1.14.44 // indirect + github.com/petermattis/goid v0.0.0-20260330135022-df67b199bc81 // indirect github.com/rs/xid v1.6.0 // indirect github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e // indirect github.com/tidwall/gjson v1.18.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect github.com/tidwall/sjson v1.2.5 // indirect - github.com/yuin/goldmark v1.7.16 // indirect + github.com/yuin/goldmark v1.8.2 // indirect go.mau.fi/zeroconfig v0.2.0 // indirect - golang.org/x/crypto v0.47.0 // indirect - golang.org/x/exp v0.0.0-20260112195511-716be5621a96 // indirect - golang.org/x/net v0.49.0 // indirect - golang.org/x/sync v0.19.0 // indirect - golang.org/x/sys v0.40.0 // indirect - golang.org/x/text v0.33.0 // indirect + golang.org/x/crypto v0.50.0 // indirect + golang.org/x/exp v0.0.0-20260410095643-746e56fc9e2f // indirect + golang.org/x/net v0.53.0 // indirect + golang.org/x/sync v0.20.0 // indirect + golang.org/x/sys 
v0.43.0 // indirect + golang.org/x/text v0.36.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect maunium.net/go/mauflag v1.0.0 // indirect diff --git a/go.sum b/go.sum index ad7799f..7ec830b 100644 --- a/go.sum +++ b/go.sum @@ -1,17 +1,25 @@ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= +filippo.io/edwards25519 v1.2.0 h1:crnVqOiS4jqYleHd9vaKZ+HKtHfllngJIiOpNpoJsjo= +filippo.io/edwards25519 v1.2.0/go.mod h1:xzAOLCNug/yB62zG1bQ8uziwrIqIuxhctzJT18Q77mc= github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= +github.com/beeper/ai-bridge v0.0.0-20260524021151-5c8086351a72 h1:Pw2qyz5mizv/UL4JTKiK1sbYfUl6o8dk/KcNyFlSFG0= +github.com/beeper/ai-bridge v0.0.0-20260524021151-5c8086351a72/go.mod h1:Uf2M1ogzy7VGB6uUzzHjZL2eaYt79DK0Py8I6xZl3r0= github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g= github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/coreos/go-systemd/v22 v22.6.0 h1:aGVa/v8B7hpb0TKl0MWoAavPDmHvobFe5R5zn0bCJWo= github.com/coreos/go-systemd/v22 v22.6.0/go.mod h1:iG+pp635Fo7ZmV/j14KUcmEyWF+0X7Lua8rrTWzYgWU= +github.com/coreos/go-systemd/v22 v22.7.0 h1:LAEzFkke61DFROc7zNLX/WA2i5J8gYqe0rSj9KI28KA= +github.com/coreos/go-systemd/v22 v22.7.0/go.mod h1:xNUYtjHu2EDXbsxz1i41wouACIwT7Ybq9o0BQhMwD0w= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/lib/pq v1.10.9 
h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/lib/pq v1.12.3 h1:tTWxr2YLKwIvK90ZXEw8GP7UFHtcbTtty8zsI+YjrfQ= +github.com/lib/pq v1.12.3/go.mod h1:/p+8NSbOcwzAEI7wiMXFlgydTwcgTr3OSKMsD2BitpA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= @@ -21,8 +29,12 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.33 h1:A5blZ5ulQo2AtayQ9/limgHEkFreKj1Dv226a1K73s0= github.com/mattn/go-sqlite3 v1.14.33/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/mattn/go-sqlite3 v1.14.44 h1:3VSe+xafpbzsLbdr2AWlAZk9yRHiBhTBakioXaCKTF8= +github.com/mattn/go-sqlite3 v1.14.44/go.mod h1:pjEuOr8IwzLJP2MfGeTb0A35jauH+C2kbHKBr7yXKVQ= github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741 h1:KPpdlQLZcHfTMQRi6bFQ7ogNO0ltFT4PmtwTLW4W+14= github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= +github.com/petermattis/goid v0.0.0-20260330135022-df67b199bc81 h1:WDsQxOJDy0N1VRAjXLpi8sCEZRSGarLWQevDxpTBRrM= +github.com/petermattis/goid v0.0.0-20260330135022-df67b199bc81/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -30,6 +42,8 @@ github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU= github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= 
github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= +github.com/rs/zerolog v1.35.1 h1:m7xQeoiLIiV0BCEY4Hs+j2NG4Gp2o2KPKmhnnLiazKI= +github.com/rs/zerolog v1.35.1/go.mod h1:EjML9kdfa/RMA7h/6z6pYmq1ykOuA8/mjWaEvGI+jcw= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= @@ -46,25 +60,41 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.16 h1:n+CJdUxaFMiDUNnWC3dMWCIQJSkxH4uz3ZwQBkAlVNE= github.com/yuin/goldmark v1.7.16/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= +github.com/yuin/goldmark v1.8.2 h1:kEGpgqJXdgbkhcOgBxkC0X0PmoPG1ZyoZ117rDVp4zE= +github.com/yuin/goldmark v1.8.2/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= go.mau.fi/util v0.9.5 h1:7AoWPCIZJGv4jvtFEuCe3GhAbI7uF9ckIooaXvwlIR4= go.mau.fi/util v0.9.5/go.mod h1:g1uvZ03VQhtTt2BgaRGVytS/Zj67NV0YNIECch0sQCQ= +go.mau.fi/util v0.9.9-0.20260511124621-9241e81bdf25 h1:YPEmc+li7TF6C9AdRTcSLMb6yCHdF27/wNT7kFLIVNg= +go.mau.fi/util v0.9.9-0.20260511124621-9241e81bdf25/go.mod h1:jE9FfhbgEgAwxei6lomO9v8zdCIATcquONUu4vjRwSs= go.mau.fi/zeroconfig v0.2.0 h1:e/OGEERqVRRKlgaro7E6bh8xXiKFSXB3eNNIud7FUjU= go.mau.fi/zeroconfig v0.2.0/go.mod h1:J0Vn0prHNOm493oZoQ84kq83ZaNCYZnq+noI1b1eN8w= golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8= golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A= +golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI= +golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q= 
golang.org/x/exp v0.0.0-20260112195511-716be5621a96 h1:Z/6YuSHTLOHfNFdb8zVZomZr7cqNgTJvA8+Qz75D8gU= golang.org/x/exp v0.0.0-20260112195511-716be5621a96/go.mod h1:nzimsREAkjBCIEFtHiYkrJyT+2uy9YZJB7H1k68CXZU= +golang.org/x/exp v0.0.0-20260410095643-746e56fc9e2f h1:W3F4c+6OLc6H2lb//N1q4WpJkhzJCK5J6kUi1NTVXfM= +golang.org/x/exp v0.0.0-20260410095643-746e56fc9e2f/go.mod h1:J1xhfL/vlindoeF/aINzNzt2Bket5bjo9sdOYzOsU80= golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o= golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8= +golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA= +golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs= golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI= +golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE= golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8= +golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg= +golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 
h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= @@ -75,3 +105,5 @@ maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M= maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA= maunium.net/go/mautrix v0.26.3-0.20260119125818-e28f7170bc4b h1:OaZ5Y1l4XACFlgy4BmZcCLdYPJZzgZWqZJnpdSITmoM= maunium.net/go/mautrix v0.26.3-0.20260119125818-e28f7170bc4b/go.mod h1:CUxSZcjPtQNxsZLRQqETAxg2hiz7bjWT+L1HCYoMMKo= +maunium.net/go/mautrix v0.27.1-0.20260513120123-5fba7e3afae4 h1:zNC9eVAhw8FhKpM3AxNAh/iy75UEYX91uJUvqqAYlvo= +maunium.net/go/mautrix v0.27.1-0.20260513120123-5fba7e3afae4/go.mod h1:3sOGhXi3P1V6/NruTA0gujkvTypXVUraWktCuTGyDuM= diff --git a/pkg/connector/ai_commands.go b/pkg/connector/ai_commands.go new file mode 100644 index 0000000..26fd54f --- /dev/null +++ b/pkg/connector/ai_commands.go @@ -0,0 +1,440 @@ +package connector + +import ( + "fmt" + "strconv" + "strings" + "time" + + "github.com/beeper/ai-bridge/pkg/ag-ui" + "go.mau.fi/util/shlex" +) + +func parseCommand(input string) (*parsedCommand, error) { + tokens, err := shlex.Split(input) + if err != nil { + return nil, fmt.Errorf("invalid command syntax: %w", err) + } + if len(tokens) == 0 { + return &parsedCommand{Name: "help"}, nil + } + switch strings.ToLower(tokens[0]) { + case "help", "/help", "!help", "dummybridge": + return &parsedCommand{Name: "help"}, nil + case "stream-tools": + cmd, err := parseToolsCommand(tokens[1:]) + return &parsedCommand{Name: "stream-tools", Tools: cmd}, err + case "stream": + cmd, err := parseStreamCommand(tokens[1:]) + return &parsedCommand{Name: "stream", Random: cmd}, err + default: + return nil, fmt.Errorf("unknown AI demo command %q", tokens[0]) + } +} + +func helpText() string { + return strings.Join([]string{ + "DummyBridge demo 
commands:", + "help", + "stream [seconds] [--runs=N] [--profile=balanced|tools|errors|artifacts] [--seed=N] [--chars=N] [--terminal=stop|length|abort|error] [--delay-ms=min:max] [--stagger-ms=min:max] [--actions=N] [--no-approval] [--allow-abort] [--allow-error]", + "stream-tools ... [common options]", + "Notes: stream enables approval requests by default; approval-tagged tools emit a separate Matrix approval event with reaction options.", + }, "\n") +} + +func defaultCommonOptions() commonCommandOptions { + return commonCommandOptions{ + DelayMin: 30 * time.Millisecond, + DelayMax: 150 * time.Millisecond, + ChunkMin: defaultChunkMin, + ChunkMax: defaultChunkMax, + FinishReason: agui.FinishReasonStop, + } +} + +func parseLoremCommand(tokens []string) (*loremCommand, error) { + if len(tokens) == 0 { + return nil, fmt.Errorf("text stream requires a character count") + } + count, err := parsePositiveInt(tokens[0], "character count") + if err != nil { + return nil, err + } + if err := validateMaxIntValue(count, maxDemoChars, "character count"); err != nil { + return nil, err + } + opts, err := parseCommonOptions(tokens[1:]) + if err != nil { + return nil, err + } + return &loremCommand{Chars: count, Options: opts}, nil +} + +func parseToolsCommand(tokens []string) (*toolsCommand, error) { + if len(tokens) < 2 { + return nil, fmt.Errorf("stream-tools requires a character count and at least one tool") + } + count, err := parsePositiveInt(tokens[0], "character count") + if err != nil { + return nil, err + } + if err := validateMaxIntValue(count, maxDemoChars, "character count"); err != nil { + return nil, err + } + var toolTokens, optTokens []string + for _, token := range tokens[1:] { + if strings.HasPrefix(token, "--") { + optTokens = append(optTokens, token) + } else { + toolTokens = append(toolTokens, token) + } + } + if len(toolTokens) == 0 { + return nil, fmt.Errorf("stream-tools requires at least one tool spec") + } + if err := validateMaxIntValue(len(toolTokens), 
maxDemoToolSpecs, "tool spec count"); err != nil { + return nil, err + } + opts, err := parseCommonOptions(optTokens) + if err != nil { + return nil, err + } + tools := make([]toolSpec, 0, len(toolTokens)) + for idx, token := range toolTokens { + spec, err := parseToolSpec(token, idx) + if err != nil { + return nil, err + } + tools = append(tools, spec) + } + return &toolsCommand{Chars: count, Tools: tools, Options: opts}, nil +} + +func parseRandomCommand(tokens []string) (*randomCommand, error) { + cmd := &randomCommand{ + Duration: 20 * time.Second, + Actions: 20, + DelayMin: 350 * time.Millisecond, + DelayMax: 1150 * time.Millisecond, + Runs: 1, + StaggerMin: 150 * time.Millisecond, + StaggerMax: 900 * time.Millisecond, + sharedStreamOptions: sharedStreamOptions{Profile: "balanced"}, + } + return parseStreamLikeCommand(tokens, cmd, false) +} + +func parseStreamCommand(tokens []string) (*randomCommand, error) { + cmd := &randomCommand{ + Duration: 20 * time.Second, + DelayMin: 350 * time.Millisecond, + DelayMax: 1150 * time.Millisecond, + Runs: 1, + StaggerMin: 150 * time.Millisecond, + StaggerMax: 900 * time.Millisecond, + sharedStreamOptions: sharedStreamOptions{Profile: "balanced", AllowApproval: true}, + } + return parseStreamLikeCommand(tokens, cmd, true) +} + +func parseStreamLikeCommand(tokens []string, cmd *randomCommand, deriveActions bool) (*randomCommand, error) { + rest := tokens + if len(rest) > 0 && !strings.HasPrefix(rest[0], "--") { + seconds, err := parsePositiveInt(rest[0], "duration") + if err != nil { + return nil, err + } + if err := validateMaxIntValue(seconds, int(maxDemoDuration/time.Second), "duration seconds"); err != nil { + return nil, err + } + cmd.Duration = time.Duration(seconds) * time.Second + rest = rest[1:] + } + if deriveActions && cmd.Actions == 0 { + cmd.Actions = max(3, min(maxDemoRandomActions, int(cmd.Duration/time.Second)*2)) + } + for _, token := range rest { + key, value, hasValue := parseOptionToken(token) + switch 
key { + case "actions": + n, err := parseValidatedInt(value, hasValue, token, "actions", maxDemoRandomActions, false) + if err != nil { + return nil, err + } + cmd.Actions = n + case "chars": + n, err := parseValidatedInt(value, hasValue, token, "character count", maxDemoChars, false) + if err != nil { + return nil, err + } + cmd.Chars = n + case "delay-ms": + minDelay, maxDelay, err := parseDurationRange(value, hasValue, token, "delay-ms", maxDemoDelay) + if err != nil { + return nil, err + } + cmd.DelayMin, cmd.DelayMax = minDelay, maxDelay + case "terminal": + if !hasValue { + return nil, fmt.Errorf("%s requires a value", token) + } + switch strings.ToLower(value) { + case "stop", "finish": + cmd.Terminal = "finish" + case "abort", "error": + cmd.Terminal = strings.ToLower(value) + case "length", "tool-calls", "content-filter", "other": + cmd.Terminal = agui.NormalizeFinishReason(value) + default: + return nil, fmt.Errorf("unknown terminal %q", value) + } + case "runs": + n, err := parseValidatedInt(value, hasValue, token, "run count", maxDemoChaosRuns, false) + if err != nil { + return nil, err + } + cmd.Runs = n + case "stagger-ms": + minDelay, maxDelay, err := parseDurationRange(value, hasValue, token, "stagger-ms", maxDemoStagger) + if err != nil { + return nil, err + } + cmd.StaggerMin, cmd.StaggerMax = minDelay, maxDelay + case "no-approval": + cmd.AllowApproval = false + default: + handled, err := parseSharedStreamOption(key, value, hasValue, token, &cmd.sharedStreamOptions) + if err != nil || !handled { + if err != nil { + return nil, err + } + return nil, fmt.Errorf("unknown stream option %q", token) + } + } + } + return cmd, nil +} + +func parseChaosCommand(tokens []string) (*chaosCommand, error) { + cmd := &chaosCommand{ + Runs: 3, + Duration: 10 * time.Second, + StaggerMin: 150 * time.Millisecond, + StaggerMax: 900 * time.Millisecond, + MaxActions: 10, + sharedStreamOptions: sharedStreamOptions{Profile: "balanced"}, + } + rest := tokens + if 
len(rest) > 0 && !strings.HasPrefix(rest[0], "--") { + n, err := parsePositiveInt(rest[0], "run count") + if err != nil { + return nil, err + } + if err := validateMaxIntValue(n, maxDemoChaosRuns, "run count"); err != nil { + return nil, err + } + cmd.Runs = n + rest = rest[1:] + } + if len(rest) > 0 && !strings.HasPrefix(rest[0], "--") { + seconds, err := parsePositiveInt(rest[0], "duration") + if err != nil { + return nil, err + } + if err := validateMaxIntValue(seconds, int(maxDemoDuration/time.Second), "duration seconds"); err != nil { + return nil, err + } + cmd.Duration = time.Duration(seconds) * time.Second + rest = rest[1:] + } + for _, token := range rest { + key, value, hasValue := parseOptionToken(token) + switch key { + case "stagger-ms": + minDelay, maxDelay, err := parseDurationRange(value, hasValue, token, "stagger-ms", maxDemoStagger) + if err != nil { + return nil, err + } + cmd.StaggerMin, cmd.StaggerMax = minDelay, maxDelay + case "max-actions": + n, err := parseValidatedInt(value, hasValue, token, "max-actions", maxDemoChaosActions, false) + if err != nil { + return nil, err + } + cmd.MaxActions = n + default: + handled, err := parseSharedStreamOption(key, value, hasValue, token, &cmd.sharedStreamOptions) + if err != nil || !handled { + if err != nil { + return nil, err + } + return nil, fmt.Errorf("unknown chaos option %q", token) + } + } + } + return cmd, nil +} + +func parseCommonOptions(tokens []string) (commonCommandOptions, error) { + opts := defaultCommonOptions() + for _, token := range tokens { + key, value, hasValue := parseOptionToken(token) + switch key { + case "reasoning": + n, err := parseValidatedInt(value, hasValue, token, "reasoning", maxDemoReasoningChars, true) + if err != nil { + return opts, err + } + opts.ReasoningChars = n + case "steps": + n, err := parseValidatedInt(value, hasValue, token, "steps", maxDemoSteps, false) + if err != nil { + return opts, err + } + opts.Steps = n + case "sources": + n, err := 
parseValidatedInt(value, hasValue, token, "sources", maxDemoCollections, true) + if err != nil { + return opts, err + } + opts.Sources = n + case "documents": + n, err := parseValidatedInt(value, hasValue, token, "documents", maxDemoCollections, true) + if err != nil { + return opts, err + } + opts.Documents = n + case "files": + n, err := parseValidatedInt(value, hasValue, token, "files", maxDemoCollections, true) + if err != nil { + return opts, err + } + opts.Files = n + case "meta": + opts.Meta = true + case "data": + if !hasValue { + return opts, fmt.Errorf("%s requires a value", token) + } + opts.DataName = value + case "data-transient": + if !hasValue { + return opts, fmt.Errorf("%s requires a value", token) + } + opts.DataTransientName = value + case "delay-ms": + minDelay, maxDelay, err := parseDurationRange(value, hasValue, token, "delay-ms", maxDemoDelay) + if err != nil { + return opts, err + } + opts.DelayMin, opts.DelayMax = minDelay, maxDelay + case "chunk-chars": + minChunk, maxChunk, err := parseIntRangeOption(value, hasValue, token, "chunk-chars", maxDemoChunkChars) + if err != nil { + return opts, err + } + opts.ChunkMin, opts.ChunkMax = minChunk, maxChunk + case "seed": + if !hasValue { + return opts, fmt.Errorf("%s requires a value", token) + } + seed, err := strconv.ParseInt(value, 10, 64) + if err != nil { + return opts, fmt.Errorf("invalid seed %q", value) + } + opts.Seed, opts.SeedSet = seed, true + case "finish": + if !hasValue { + return opts, fmt.Errorf("%s requires a value", token) + } + opts.FinishReason = agui.NormalizeFinishReason(value) + case "abort": + opts.Abort = true + case "error": + opts.Error = true + default: + return opts, fmt.Errorf("unknown option %q", token) + } + } + if opts.Abort && opts.Error { + return opts, fmt.Errorf("--abort and --error cannot be combined") + } + if (opts.Abort || opts.Error) && opts.FinishReason != agui.FinishReasonStop { + return opts, fmt.Errorf("--finish cannot be combined with --abort or 
--error") + } + return opts, nil +} + +func parseSharedStreamOption(key, value string, hasValue bool, token string, opts *sharedStreamOptions) (bool, error) { + switch key { + case "profile": + if !hasValue { + return false, fmt.Errorf("%s requires a value", token) + } + switch strings.ToLower(value) { + case "balanced", "tools", "errors", "artifacts": + opts.Profile = strings.ToLower(value) + default: + return false, fmt.Errorf("unknown profile %q", value) + } + case "seed": + if !hasValue { + return false, fmt.Errorf("%s requires a value", token) + } + seed, err := strconv.ParseInt(value, 10, 64) + if err != nil { + return false, fmt.Errorf("invalid seed %q", value) + } + opts.Seed, opts.SeedSet = seed, true + case "allow-abort": + opts.AllowAbort = true + case "allow-error": + opts.AllowError = true + default: + return false, nil + } + return true, nil +} + +func parseToolSpec(raw string, idx int) (toolSpec, error) { + parts := strings.Split(raw, "#") + spec := toolSpec{Name: strings.TrimSpace(parts[0]), SequenceIndex: idx + 1} + if spec.Name == "" { + return spec, fmt.Errorf("tool spec %q is missing a tool name", raw) + } + for _, tag := range parts[1:] { + tag = strings.TrimSpace(strings.ToLower(tag)) + if tag == "" { + continue + } + spec.Tags = append(spec.Tags, tag) + switch tag { + case "fail": + spec.Fail = true + case "approval": + spec.Approval = true + case "deny": + spec.Deny = true + case "delta": + spec.Delta = true + case "inputerror": + spec.InputError = true + case "prelim": + spec.Preliminary = true + case "provider": + spec.Provider = true + default: + return spec, fmt.Errorf("unknown tool tag %q in %q", tag, raw) + } + } + finalStates := 0 + for _, enabled := range []bool{spec.Fail, spec.Approval, spec.Deny} { + if enabled { + finalStates++ + } + } + if finalStates > 1 { + return spec, fmt.Errorf("tool spec %q has conflicting final state tags", raw) + } + return spec, nil +} diff --git a/pkg/connector/ai_parse_helpers.go 
b/pkg/connector/ai_parse_helpers.go new file mode 100644 index 0000000..9b06d7d --- /dev/null +++ b/pkg/connector/ai_parse_helpers.go @@ -0,0 +1,173 @@ +package connector + +import ( + "fmt" + "math/rand" + "strconv" + "strings" + "time" + + "github.com/beeper/ai-bridge/pkg/ai-stream" +) + +func parseOptionToken(token string) (string, string, bool) { + trimmed := strings.TrimPrefix(strings.TrimSpace(token), "--") + key, value, ok := strings.Cut(trimmed, "=") + return strings.ToLower(strings.TrimSpace(key)), strings.TrimSpace(value), ok +} + +func parseValidatedInt(value string, hasValue bool, token, label string, maxValue int, allowZero bool) (int, error) { + if !hasValue { + return 0, fmt.Errorf("%s requires a value", token) + } + var n int + var err error + if allowZero { + n, err = parseNonNegativeInt(value, label) + } else { + n, err = parsePositiveInt(value, label) + } + if err != nil { + return 0, err + } + return n, validateMaxIntValue(n, maxValue, label) +} + +func parsePositiveInt(raw, label string) (int, error) { + n, err := strconv.Atoi(strings.TrimSpace(raw)) + if err != nil || n <= 0 { + return 0, fmt.Errorf("invalid %s %q", label, raw) + } + return n, nil +} + +func parseNonNegativeInt(raw, label string) (int, error) { + n, err := strconv.Atoi(strings.TrimSpace(raw)) + if err != nil || n < 0 { + return 0, fmt.Errorf("invalid %s %q", label, raw) + } + return n, nil +} + +func parseDurationRange(value string, hasValue bool, token, label string, maxValue time.Duration) (time.Duration, time.Duration, error) { + minValue, maxRange, err := parseIntRangeOption(value, hasValue, token, label, int(maxValue/time.Millisecond)) + if err != nil { + return 0, 0, err + } + return time.Duration(minValue) * time.Millisecond, time.Duration(maxRange) * time.Millisecond, nil +} + +func parseIntRangeOption(value string, hasValue bool, token, label string, maxValue int) (int, int, error) { + if !hasValue { + return 0, 0, fmt.Errorf("%s requires a value", token) + } + 
minValue, maxRange, ok := strings.Cut(value, ":") + if !ok { + n, err := parseNonNegativeInt(value, label) + if err != nil { + return 0, 0, err + } + if err := validateMaxIntValue(n, maxValue, label); err != nil { + return 0, 0, err + } + return n, n, nil + } + minInt, err := parseNonNegativeInt(minValue, label) + if err != nil { + return 0, 0, err + } + maxInt, err := parseNonNegativeInt(maxRange, label) + if err != nil { + return 0, 0, err + } + if maxInt < minInt { + return 0, 0, fmt.Errorf("invalid %s range %q", label, value) + } + if err := validateMaxIntValue(maxInt, maxValue, label); err != nil { + return 0, 0, err + } + return minInt, maxInt, nil +} + +func validateMaxIntValue(value, maxValue int, label string) error { + if value > maxValue { + return fmt.Errorf("%s %d exceeds the maximum of %d", label, value, maxValue) + } + return nil +} + +func rngForOptions(seedSet bool, seed, fallback int64) *rand.Rand { + if !seedSet { + seed = fallback + } + return rand.New(rand.NewSource(seed)) +} + +func chunkText(text string, rng *rand.Rand, minChunk, maxChunk int) []string { + if strings.TrimSpace(text) == "" { + return nil + } + if minChunk <= 0 { + minChunk = defaultChunkMin + } + if maxChunk < minChunk { + maxChunk = minChunk + } + var chunks []string + for len(text) > 0 { + size := minChunk + if maxChunk > minChunk { + size += rng.Intn(maxChunk - minChunk + 1) + } + if size > len(text) { + size = len(text) + } + parts := aistream.SplitTextUTF8(text, size) + chunk := parts[0] + chunks = append(chunks, chunk) + text = text[len(chunk):] + } + return chunks +} + +func splitCount(total, parts, index int) int { + if total <= 0 || parts <= 0 || index < 0 || index >= parts { + return 0 + } + base := total / parts + remainder := total % parts + if index < remainder { + return base + 1 + } + return base +} + +func sliceByStep(text string, parts, index int) string { + if parts <= 1 || text == "" { + return text + } + start := 0 + for i := 0; i < index; i++ { + start += 
splitCount(len(text), parts, i) + } + length := splitCount(len(text), parts, index) + if start >= len(text) || length <= 0 { + return "" + } + end := min(start+length, len(text)) + return text[start:end] +} + +func sanitizeToolName(name string) string { + name = strings.ToLower(strings.TrimSpace(name)) + var out strings.Builder + for _, r := range name { + if r >= 'a' && r <= 'z' || r >= '0' && r <= '9' || r == '_' || r == '-' { + out.WriteRune(r) + } + } + if out.Len() == 0 { + return "tool" + } + return out.String() +} diff --git a/pkg/connector/ai_plans.go b/pkg/connector/ai_plans.go new file mode 100644 index 0000000..b96987b --- /dev/null +++ b/pkg/connector/ai_plans.go @@ -0,0 +1,199 @@ +package connector + +import ( + "context" + "errors" + "fmt" + "math/rand" + "strconv" + "strings" + "time" + + "github.com/beeper/ai-bridge/pkg/ag-ui" + "github.com/beeper/ai-bridge/pkg/ai-stream" +) + +// resolveCommandSeed fills in an implicit seed for commands that derive their +// random behavior from the current time, so the continuation can replay the +// exact same sequence. +func resolveCommandSeed(cmd *parsedCommand, now time.Time) { + if cmd == nil { + return + } + switch { + case cmd.Lorem != nil && !cmd.Lorem.Options.SeedSet: + cmd.Lorem.Options.Seed = now.UnixNano() + cmd.Lorem.Options.SeedSet = true + case cmd.Tools != nil && !cmd.Tools.Options.SeedSet: + cmd.Tools.Options.Seed = now.UnixNano() + cmd.Tools.Options.SeedSet = true + case cmd.Random != nil && !cmd.Random.SeedSet: + cmd.Random.Seed = now.UnixNano() + cmd.Random.SeedSet = true + } +} + +// canonicalCommand returns a command string that, when re-parsed, reproduces +// the same run as cmd. If the original input already encoded all randomness +// inputs (e.g. an explicit --seed), it is returned as-is. 
+func canonicalCommand(input string, cmd *parsedCommand) string { + if cmd == nil { + return input + } + switch { + case cmd.Lorem != nil: + return ensureSeedFlag(input, cmd.Lorem.Options.Seed, cmd.Lorem.Options.SeedSet) + case cmd.Tools != nil: + return ensureSeedFlag(input, cmd.Tools.Options.Seed, cmd.Tools.Options.SeedSet) + case cmd.Random != nil: + return ensureSeedFlag(input, cmd.Random.Seed, cmd.Random.SeedSet) + } + return input +} + +func ensureSeedFlag(input string, seed int64, seedSet bool) string { + if !seedSet || hasSeedFlag(input) { + return input + } + return strings.TrimRight(input, " ") + " --seed=" + strconv.FormatInt(seed, 10) +} + +func hasSeedFlag(input string) bool { + for _, token := range strings.Fields(input) { + if strings.HasPrefix(token, "--seed=") || token == "--seed" { + return true + } + } + return false +} + +func buildAIRunFromCommandWithApprovals(ctx context.Context, runID, threadID string, now time.Time, cmd *parsedCommand, agentID, agentName string, approvals map[string]agui.ToolApprovalResponse) (*aistream.Run, error) { + runtime := virtualAIRuntime(now) + run := aistream.NewRun(runID, threadID, aistream.DefaultModel, agentID, agentName, now) + writer := aistream.NewWriter(run, runtime.now) + writer.Start() + + runner := aiRunner{runtime: runtime, approvals: approvals} + var err error + switch { + case cmd == nil || cmd.Name == "help": + writer.Text(helpText()) + writer.Finish(agui.FinishReasonStop) + case cmd.Lorem != nil: + err = runner.runLorem(ctx, writer, *cmd.Lorem) + case cmd.Tools != nil: + err = runner.runTools(ctx, writer, *cmd.Tools) + case cmd.Random != nil: + err = runner.runRandom(ctx, writer, *cmd.Random) + } + if errors.Is(err, errApprovalRequested) { + err = nil + } + if err != nil { + writer.Error(err.Error()) + } else if err = agui.ValidateEventSequence(run.Events); err != nil { + writer.Error(err.Error()) + } + return run, nil +} + +func buildAIChaosRunPlans(ctx context.Context, baseRunID, threadID string, 
now time.Time, cmd chaosCommand, agentID, agentName string) ([]aiRunPlan, error) { + seed := cmd.Seed + if !cmd.SeedSet { + seed = now.UnixNano() + } + rng := rand.New(rand.NewSource(seed)) + runner := aiRunner{runtime: virtualAIRuntime(now)} + actions := max(3, min(cmd.MaxActions, int(cmd.Duration/time.Second))) + plans := make([]aiRunPlan, 0, cmd.Runs) + var delay time.Duration + for i := range cmd.Runs { + if i > 0 { + delay += runner.sampleDelay(rng, cmd.StaggerMin, cmd.StaggerMax) + } + runID := fmt.Sprintf("%s-%d", baseRunID, i+1) + randomCmd := randomCommand{ + Duration: cmd.Duration, + Actions: actions, + DelayMin: 180 * time.Millisecond, + DelayMax: 900 * time.Millisecond, + sharedStreamOptions: sharedStreamOptions{ + Profile: cmd.Profile, + Seed: seed + int64(i+1)*97, + SeedSet: true, + AllowAbort: cmd.AllowAbort, + AllowError: cmd.AllowError, + AllowApproval: cmd.AllowApproval, + }, + } + parsed := &parsedCommand{Name: "stream", Random: &randomCmd} + run, err := buildAIRunFromCommandWithApprovals(ctx, runID, threadID, now.Add(delay), parsed, agentID, agentName, nil) + if err != nil { + return nil, err + } + plans = append(plans, aiRunPlan{ + Run: run, + Delay: delay, + EffectiveCommand: streamSubRunCommand(randomCmd), + }) + } + return plans, nil +} + +func buildAIStreamRunPlans(ctx context.Context, baseRunID, threadID string, now time.Time, cmd randomCommand, agentID, agentName string) ([]aiRunPlan, error) { + seed := cmd.Seed + if !cmd.SeedSet { + seed = now.UnixNano() + } + rng := rand.New(rand.NewSource(seed)) + plans := make([]aiRunPlan, 0, cmd.Runs) + runner := aiRunner{runtime: virtualAIRuntime(now)} + var delay time.Duration + for i := range cmd.Runs { + if i > 0 { + delay += runner.sampleDelay(rng, cmd.StaggerMin, cmd.StaggerMax) + } + child := cmd + child.Runs = 1 + child.Seed = seed + int64(i+1)*97 + child.SeedSet = true + parsed := &parsedCommand{Name: "stream", Random: &child} + run, err := buildAIRunFromCommandWithApprovals(ctx, 
fmt.Sprintf("%s-%d", baseRunID, i+1), threadID, now.Add(delay), parsed, agentID, agentName, nil) + if err != nil { + return nil, err + } + plans = append(plans, aiRunPlan{ + Run: run, + Delay: delay, + EffectiveCommand: streamSubRunCommand(child), + }) + } + return plans, nil +} + +func streamSubRunCommand(cmd randomCommand) string { + parts := []string{ + "stream", + strconv.Itoa(int(cmd.Duration / time.Second)), + "--actions=" + strconv.Itoa(cmd.Actions), + "--delay-ms=" + strconv.Itoa(int(cmd.DelayMin/time.Millisecond)) + ":" + strconv.Itoa(int(cmd.DelayMax/time.Millisecond)), + "--profile=" + cmd.Profile, + "--seed=" + strconv.FormatInt(cmd.Seed, 10), + } + if cmd.Chars > 0 { + parts = append(parts, "--chars="+strconv.Itoa(cmd.Chars)) + } + if cmd.Terminal != "" { + parts = append(parts, "--terminal="+cmd.Terminal) + } + if !cmd.AllowApproval { + parts = append(parts, "--no-approval") + } + if cmd.AllowAbort { + parts = append(parts, "--allow-abort") + } + if cmd.AllowError { + parts = append(parts, "--allow-error") + } + return strings.Join(parts, " ") +} diff --git a/pkg/connector/ai_runner.go b/pkg/connector/ai_runner.go new file mode 100644 index 0000000..6fc497c --- /dev/null +++ b/pkg/connector/ai_runner.go @@ -0,0 +1,415 @@ +package connector + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "math/rand" + "sort" + "strings" + "time" + + "github.com/beeper/ai-bridge/pkg/ag-ui" + "github.com/beeper/ai-bridge/pkg/ai-stream" +) + +func (r aiRunner) runLorem(ctx context.Context, w *aistream.Writer, cmd loremCommand) error { + opts := cmd.Options + rng := rngForOptions(opts.SeedSet, opts.Seed, r.runtime.now().UnixNano()) + steps := max(opts.Steps, 1) + text := buildDemoVisibleText(cmd.Chars, rand.New(rand.NewSource(rng.Int63()))) + reasoning := buildLoremText(opts.ReasoningChars, rand.New(rand.NewSource(rng.Int63()))) + for step := range steps { + if opts.Steps > 0 { + w.StepStart(fmt.Sprintf("step-%d", step+1)) + } + emitDecorations(w, opts, 
cmd.Chars, step, steps) + if reasoning != "" { + w.Thinking(sliceByStep(reasoning, steps, step)) + } + for _, chunk := range chunkText(sliceByStep(text, steps, step), rng, opts.ChunkMin, opts.ChunkMax) { + w.Text(chunk) + if err := r.runtime.sleep(ctx, r.sampleDelay(rng, opts.DelayMin, opts.DelayMax)); err != nil { + return err + } + } + if opts.Steps > 0 { + w.StepFinish(fmt.Sprintf("step-%d", step+1)) + } + } + finishWriter(w, opts) + return nil +} + +func (r aiRunner) runTools(ctx context.Context, w *aistream.Writer, cmd toolsCommand) error { + opts := cmd.Options + rng := rngForOptions(opts.SeedSet, opts.Seed, r.runtime.now().UnixNano()) + phaseCount := max(len(cmd.Tools)+1, max(opts.Steps, 1)) + text := buildDemoVisibleText(cmd.Chars, rand.New(rand.NewSource(rng.Int63()))) + reasoning := buildLoremText(opts.ReasoningChars, rand.New(rand.NewSource(rng.Int63()))) + for phase := range phaseCount { + w.StepStart(fmt.Sprintf("phase-%d", phase+1)) + emitDecorations(w, opts, cmd.Chars, phase, phaseCount) + if reasoning != "" { + w.Thinking(sliceByStep(reasoning, phaseCount, phase)) + } + for _, chunk := range chunkText(sliceByStep(text, phaseCount, phase), rng, opts.ChunkMin, opts.ChunkMax) { + w.Text(chunk) + } + if phase < len(cmd.Tools) { + if err := r.runToolSpec(ctx, w, cmd.Tools[phase], rng, opts); err != nil { + if errors.Is(err, errApprovalRequested) { + w.StepFinish(fmt.Sprintf("phase-%d", phase+1)) + } + return err + } + } + w.StepFinish(fmt.Sprintf("phase-%d", phase+1)) + } + finishWriter(w, opts) + return nil +} + +func (r aiRunner) runRandom(ctx context.Context, w *aistream.Writer, cmd randomCommand) error { + seed := cmd.Seed + if !cmd.SeedSet { + seed = r.runtime.now().UnixNano() + } + rng := rand.New(rand.NewSource(seed)) + started := r.runtime.now() + var deadline time.Time + if cmd.Duration > 0 { + deadline = started.Add(cmd.Duration) + } + stepOpen := false + stepName := "" + actionOptions, actionWeightTotal := buildRandomActionOptions(cmd) + if 
cmd.Chars > 0 { + text := buildDemoVisibleText(cmd.Chars, rand.New(rand.NewSource(rng.Int63()))) + for _, chunk := range chunkText(text, rng, defaultChunkMin, defaultChunkMax) { + w.Text(chunk) + if err := r.runtime.sleep(ctx, r.sampleDelay(rng, cmd.DelayMin, cmd.DelayMax)); err != nil { + return err + } + } + } + approvalRequested := false + handleTool := func(spec toolSpec) error { + if err := r.runToolSpec(ctx, w, spec, rng, defaultCommonOptions()); err != nil { + if spec.Approval { + approvalRequested = true + } + if errors.Is(err, errApprovalRequested) && stepOpen { + w.StepFinish(stepName) + stepOpen = false + stepName = "" + } + return err + } + if spec.Approval { + approvalRequested = true + } + return nil + } + for action := range cmd.Actions { + if !deadline.IsZero() && !r.runtime.now().Before(deadline) { + break + } + if action > 0 { + delay := r.sampleDelay(rng, cmd.DelayMin, cmd.DelayMax) + if !deadline.IsZero() && r.runtime.now().Add(delay).After(deadline) { + delay = deadline.Sub(r.runtime.now()) + } + if err := r.runtime.sleep(ctx, delay); err != nil { + return err + } + if !deadline.IsZero() && !r.runtime.now().Before(deadline) { + break + } + } + switch pickWeighted(actionOptions, actionWeightTotal, rng) { + case randomActionText: + text := "\n\n" + buildDemoVisibleText(40+rng.Intn(160), rand.New(rand.NewSource(rng.Int63()))) + for _, chunk := range chunkText(text, rng, defaultChunkMin, defaultChunkMax) { + w.Text(chunk) + } + case randomActionThinking: + w.Thinking(buildLoremText(30+rng.Intn(120), rand.New(rand.NewSource(rng.Int63())))) + case randomActionStep: + if stepOpen { + w.StepFinish(stepName) + stepOpen = false + stepName = "" + } else { + stepName = fmt.Sprintf("random-step-%d", action+1) + w.StepStart(stepName) + stepOpen = true + } + case randomActionTool: + if cmd.AllowApproval && cmd.Profile == "balanced" && action >= 10 && !approvalRequested { + if err := handleTool(toolSpec{Name: randomToolName(rng), Approval: true, SequenceIndex: 
action + 1}); err != nil { + return err + } + continue + } + if err := handleTool(toolSpec{Name: randomToolName(rng), SequenceIndex: action + 1}); err != nil { + return err + } + case randomActionToolFail: + if err := handleTool(toolSpec{Name: randomToolName(rng), Fail: true, SequenceIndex: action + 1}); err != nil { + return err + } + case randomActionToolDeny: + if err := handleTool(toolSpec{Name: randomToolName(rng), Deny: true, SequenceIndex: action + 1}); err != nil { + return err + } + case randomActionToolApproval: + if err := handleTool(toolSpec{Name: randomToolName(rng), Approval: true, SequenceIndex: action + 1}); err != nil { + return err + } + case randomActionSource: + sourceID := fmt.Sprintf("random-source-%d", action+1) + w.Custom("com.beeper.source", map[string]any{"sourceId": sourceID, "url": fmt.Sprintf("https://dummybridge.local/random/source/%d", action+1), "title": fmt.Sprintf("Random Source %d", action+1)}) + case randomActionDocument: + w.Custom("com.beeper.document", map[string]any{"id": fmt.Sprintf("random-doc-%d", action+1), "title": fmt.Sprintf("Random Document %d", action+1), "mediaType": "text/plain"}) + case randomActionFile: + w.Custom("com.beeper.file", map[string]any{"url": fmt.Sprintf("mxc://dummybridge/random-file-%d", action+1), "mediaType": "application/octet-stream"}) + case randomActionMetadata: + w.StateDelta(statePatch(map[string]any{"command": "stream", "seed": seed, "action": action + 1, "profile": cmd.Profile})) + case randomActionData: + w.Custom("com.beeper.data", map[string]any{"name": "random", "value": map[string]any{"action": action + 1, "seed": seed}}) + case randomActionDataTransient: + w.Custom("com.beeper.data.transient", map[string]any{"name": "random", "value": map[string]any{"action": action + 1, "seed": seed}}) + } + } + if stepOpen { + w.StepFinish(stepName) + } + terminal := chooseRandomTerminal(cmd, rng) + switch terminal { + case "abort": + w.Abort("DummyBridge random mode aborted") + case "error": + 
w.Error("DummyBridge random mode failed") + case agui.FinishReasonLength, agui.FinishReasonToolCalls, agui.FinishReasonContentFilter, agui.FinishReasonOther: + w.Finish(terminal) + default: + w.Finish(agui.FinishReasonStop) + } + return nil +} + +func (r aiRunner) runToolSpec(ctx context.Context, w *aistream.Writer, spec toolSpec, rng *rand.Rand, opts commonCommandOptions) error { + toolCallID := fmt.Sprintf("dummy-tool-%d-%s", spec.SequenceIndex, sanitizeToolName(spec.Name)) + input := toolRequestInput(spec) + approvalID := approvalIDForRun(w.Run.RunID, toolCallID) + var approval *agui.ToolApproval + if spec.Approval { + approval = &agui.ToolApproval{ID: approvalID, NeedsApproval: true} + } + displayMetadata := toolDisplayMetadata(spec.Name) + w.ToolStartWithMetadata(toolCallID, spec.Name, spec.SequenceIndex-1, approval, displayMetadata) + annotateProviderRawEvent(w, spec, "tool_call_start") + if spec.InputError { + if encodedInput := jsonToolInput(input); encodedInput != "" { + w.ToolArgs(toolCallID, encodedInput, nil) + annotateProviderRawEvent(w, spec, "tool_call_args") + } + w.ToolError(toolCallID, spec.Name, input, "input-error") + annotateProviderRawEvent(w, spec, "tool_call_error") + return nil + } + if spec.Delta { + if encodedInput := jsonToolInput(input); encodedInput != "" { + for _, chunk := range chunkText(encodedInput, rng, opts.ChunkMin, opts.ChunkMax) { + w.ToolArgs(toolCallID, chunk, nil) + annotateProviderRawEvent(w, spec, "tool_call_args") + if err := r.runtime.sleep(ctx, r.sampleDelay(rng, opts.DelayMin, opts.DelayMax)); err != nil { + return err + } + } + } + } else { + if encodedInput := jsonToolInput(input); encodedInput != "" { + w.ToolArgs(toolCallID, encodedInput, encodedInput) + annotateProviderRawEvent(w, spec, "tool_call_args") + } + } + if spec.Preliminary { + w.ToolResult(toolCallID, fmt.Sprintf(`{"state":%q}`, agui.ToolResultStateStreaming), agui.ToolResultStateStreaming) + annotateProviderRawEvent(w, spec, "tool_call_result") + } + 
// toolDisplayMetadata returns optional presentation metadata for a tool as a
// generic JSON-shaped map, or nil when the tool name is not recognised.
//
// Matching is case-insensitive. Known names get a "displayName"; the calendar
// and Linear tools additionally get a "provider" object with "id" and
// "displayName". Fields that are left empty are omitted from the map.
func toolDisplayMetadata(name string) map[string]any {
	type providerInfo struct {
		ID          string `json:"id,omitempty"`
		DisplayName string `json:"displayName,omitempty"`
		IconURL     string `json:"iconUrl,omitempty"`
	}
	type displayInfo struct {
		DisplayName string        `json:"displayName,omitempty"`
		Description string        `json:"description,omitempty"`
		IconURL     string        `json:"iconUrl,omitempty"`
		Provider    *providerInfo `json:"provider,omitempty"`
	}

	var info displayInfo
	switch strings.ToLower(name) {
	case "shell":
		info.DisplayName = "Run Command"
	case "fetch":
		info.DisplayName = "Fetch Web"
	case "calendar.get_events", "google_calendar.get_events", "google-calendar.get-events":
		info = displayInfo{
			DisplayName: "List Calendar Events",
			Provider:    &providerInfo{ID: "google-calendar", DisplayName: "Google Calendar"},
		}
	case "linear.list_issues", "linear.list-issues", "list_issues", "list-issues":
		info = displayInfo{
			DisplayName: "List Issues",
			Provider:    &providerInfo{ID: "linear", DisplayName: "Linear"},
		}
	}
	return compactJSONMap(info)
}

// compactJSONMap round-trips value through encoding/json and returns the
// resulting object as a map. It returns nil when value cannot be marshalled,
// when the JSON is not an object, or when the object has no members.
func compactJSONMap(value any) map[string]any {
	encoded, err := json.Marshal(value)
	if err != nil {
		return nil
	}
	var decoded map[string]any
	if json.Unmarshal(encoded, &decoded) != nil {
		return nil
	}
	if len(decoded) == 0 {
		return nil
	}
	return decoded
}
// statePatch converts a flat set of values into a JSON Patch document: one
// "add" operation per entry, targeting the top-level member named by the key.
//
// Operations are ordered by key so the output is deterministic regardless of
// map iteration order. Keys are escaped as JSON Pointer reference tokens
// ("~" becomes "~0", "/" becomes "~1") so that a key containing either
// character still addresses a single top-level member instead of being read as
// a nested path. Keys without those characters produce the same path as before.
// The result is always non-nil, and empty when values is empty.
func statePatch(values map[string]any) []map[string]any {
	keys := make([]string, 0, len(values))
	for key := range values {
		keys = append(keys, key)
	}
	sort.Strings(keys)
	// A Replacer does not rescan its own output, so "~" and "/" are each
	// escaped exactly once.
	escape := strings.NewReplacer("~", "~0", "/", "~1")
	patch := make([]map[string]any, 0, len(keys))
	for _, key := range keys {
		patch = append(patch, map[string]any{
			"op":    "add",
			"path":  "/" + escape.Replace(key),
			"value": values[key],
		})
	}
	return patch
}
return aiRuntime{ + now: func() time.Time { + return current + }, + sleep: func(ctx context.Context, delay time.Duration) error { + if err := ctx.Err(); err != nil { + return err + } + if delay > 0 { + current = current.Add(delay) + } + return nil + }, + } +} + +func buildAIRun(ctx context.Context, runID, threadID, input string, now time.Time) (*aistream.Run, error) { + plans, err := buildAIRunPlans(ctx, runID, threadID, input, now, "ai", "AI") + if err != nil { + return nil, err + } + if len(plans) == 0 { + return nil, fmt.Errorf("no AI runs built") + } + return plans[0].Run, nil +} + +func buildAIRunPlans(ctx context.Context, runID, threadID, input string, now time.Time, agentID, agentName string) ([]aiRunPlan, error) { + cmd, err := parseCommand(input) + if err != nil { + run := aistream.NewRun(runID, threadID, aistream.DefaultModel, agentID, agentName, now) + writer := aistream.NewWriter(run, func() time.Time { return now }) + writer.Start() + writer.Text(err.Error() + "\n\n" + helpText()) + writer.Finish(agui.FinishReasonStop) + return []aiRunPlan{{Run: run, EffectiveCommand: input}}, nil + } + if cmd != nil && cmd.Chaos != nil { + return buildAIChaosRunPlans(ctx, runID, threadID, now, *cmd.Chaos, agentID, agentName) + } + resolveCommandSeed(cmd, now) + if cmd != nil && cmd.Random != nil && cmd.Random.Runs > 1 { + return buildAIStreamRunPlans(ctx, runID, threadID, now, *cmd.Random, agentID, agentName) + } + run, err := buildAIRunFromCommandWithApprovals(ctx, runID, threadID, now, cmd, agentID, agentName, nil) + if err != nil { + return nil, err + } + return []aiRunPlan{{Run: run, EffectiveCommand: canonicalCommand(input, cmd)}}, nil +} diff --git a/pkg/connector/ai_runtime_test.go b/pkg/connector/ai_runtime_test.go new file mode 100644 index 0000000..a94d520 --- /dev/null +++ b/pkg/connector/ai_runtime_test.go @@ -0,0 +1,1057 @@ +package connector + +import ( + "context" + "encoding/json" + "math/rand" + "regexp" + "strconv" + "strings" + "testing" + "time" + 
+ "github.com/beeper/ai-bridge/pkg/ag-ui" + "github.com/beeper/ai-bridge/pkg/ai-stream" + "maunium.net/go/mautrix/id" +) + +func TestParseCommandRecognizesHelpAliases(t *testing.T) { + for _, input := range []string{"help", "/help", "!help", "dummybridge help"} { + cmd, err := parseCommand(input) + if err != nil { + t.Fatalf("parseCommand(%q) returned error: %v", input, err) + } + if cmd == nil || cmd.Name != "help" { + t.Fatalf("expected help command for %q, got %#v", input, cmd) + } + } +} + +func TestParseCommandRejectsConflictingToolTags(t *testing.T) { + _, err := parseCommand("stream-tools 100 shell#fail#approval") + if err == nil { + t.Fatal("expected parse error for conflicting tool tags") + } +} + +func TestParseCommandRejectsInvalidProfilesAndOversizedOptions(t *testing.T) { + tests := []string{ + "stream --profile=unknown", + "stream --terminal=unknown", + "stream --chars=1000000", + "stream-tools 100 shell --chunk-chars=1:9999", + } + for _, input := range tests { + if _, err := parseCommand(input); err == nil { + t.Fatalf("expected parse error for %q", input) + } + } +} + +func TestHelpTextMentionsCommandsOptionsAndToolTags(t *testing.T) { + guide := helpText() + for _, expected := range []string{ + "stream-tools", + "stream", + "--profile=balanced|tools|errors|artifacts", + "--no-approval", + "#provider", + "#inputerror", + } { + if !strings.Contains(guide, expected) { + t.Fatalf("help text missing %q:\n%s", expected, guide) + } + } +} + +func TestBuildAIRunLoremIncludesArtifactsStateAndMetadata(t *testing.T) { + run, err := buildAIRun(context.Background(), "run-1", "thread-1", "stream-tools 400 search --reasoning=80 --steps=2 --sources=1 --documents=1 --files=1 --meta --data=demo --data-transient=temp --seed=7 --chunk-chars=32:32", time.Unix(10, 0)) + if err != nil { + t.Fatal(err) + } + seen := map[string]bool{} + for _, evt := range run.Events { + switch evt["type"] { + case agui.EventTextMessageContent, agui.EventStepStarted, 
agui.EventStepFinished: + seen[evt["type"].(string)] = true + case agui.EventStateDelta: + seen[evt["type"].(string)] = true + if _, ok := evt["delta"].([]map[string]any); !ok { + t.Fatalf("STATE_DELTA should use JSON Patch array, got %#v", evt["delta"]) + } + case agui.EventCustom: + name, _ := evt["name"].(string) + seen[name] = true + if name == "com.beeper.data" { + value := evt["value"].(map[string]any) + if value["name"] == "temp" { + t.Fatal("transient data must not persist as metadata") + } + } + } + } + for _, key := range []string{agui.EventTextMessageContent, agui.EventStepStarted, agui.EventStepFinished, agui.EventStateDelta, "com.beeper.source", "com.beeper.document", "com.beeper.file", "com.beeper.data", "com.beeper.data.transient"} { + if !seen[key] { + t.Fatalf("missing %s in events", key) + } + } + metadata := run.Metadata() + if metadata["model"] == "" || metadata["threadId"] != "thread-1" || metadata["runId"] != "run-1" { + t.Fatalf("bad metadata: %#v", metadata) + } + data := metadata["data"].(map[string]any) + if _, ok := data["temp"]; ok { + t.Fatalf("transient data leaked into final metadata: %#v", data) + } +} + +func TestBuildAIRunToolsApprovalUsesAGUIApprovalAndPrompt(t *testing.T) { + run, err := buildAIRun(context.Background(), "run-1", "thread-1", "stream-tools 120 shell#approval --seed=7 --chunk-chars=32:32", time.Unix(10, 0)) + if err != nil { + t.Fatal(err) + } + if len(run.Prompts) != 1 { + t.Fatalf("expected one approval prompt, got %#v", run.Prompts) + } + if run.Prompts[0].ID != "approval-run-1-dummy-tool-1-shell" { + t.Fatalf("approval prompt ID = %q, want run-scoped ID", run.Prompts[0].ID) + } + foundToolStart := false + seenApprovalStateBeforeCustom := false + for _, evt := range run.Events { + if evt["type"] == agui.EventToolCallStart { + if evt["state"] != agui.ToolStateApprovalRequested { + t.Fatalf("expected approval-requested tool state, got %#v", evt) + } + approval, ok := evt["approval"].(*agui.ToolApproval) + if !ok { 
+ t.Fatalf("expected tool start approval metadata, got %#v", evt["approval"]) + } + if approval.ID != "approval-run-1-dummy-tool-1-shell" || !approval.NeedsApproval { + t.Fatalf("bad approval metadata: %#v", approval) + } + metadata, ok := evt["metadata"].(map[string]any) + if !ok || metadata["displayName"] != "Run Command" { + t.Fatalf("bad tool display metadata: %#v", evt["metadata"]) + } + foundToolStart = true + } + if evt["type"] == agui.EventToolCallEnd { + if evt["state"] == agui.ToolStateInputComplete { + t.Fatalf("approval tool must not downgrade to input-complete: %#v", evt) + } + if evt["state"] == agui.ToolStateApprovalRequested { + if evt["input"] != nil { + t.Fatalf("approval input-complete event should omit placeholder input: %#v", evt) + } + seenApprovalStateBeforeCustom = true + } + } + if evt["type"] == agui.EventCustom && evt["name"] == agui.ApprovalCustomRequested { + if !seenApprovalStateBeforeCustom { + t.Fatalf("approval custom event should be emitted after approval state update: %#v", run.Events) + } + value := evt["value"].(map[string]any) + if _, hasOptions := value["options"]; hasOptions { + t.Fatalf("AG-UI approval event must not embed Matrix reaction options: %#v", value) + } + if value["approvalMessageId"] != "approval-run-1-dummy-tool-1-shell" { + t.Fatalf("approval event should name the Matrix reaction target: %#v", value) + } + metadata, ok := value["metadata"].(map[string]any) + if !ok || metadata["displayName"] != "Run Command" { + t.Fatalf("approval event should carry tool display metadata: %#v", value["metadata"]) + } + choices, ok := value["choices"].([]aistream.ApprovalChoice) + if !ok || len(choices) == 0 || choices[0].Key != aistream.ApprovalChoiceApprove { + t.Fatalf("approval event should duplicate renderer choices: %#v", value["choices"]) + } + if value["input"] != nil { + t.Fatalf("approval event should omit placeholder tool input: %#v", value) + } + } + } + if !foundToolStart { + t.Fatal("missing tool start event") + } 
+ if run.Status.State != "streaming" { + t.Fatalf("approval request should pause the run without terminal status, got %#v", run.Status) + } + for _, evt := range run.Events { + if evt["type"] == agui.EventRunFinished { + t.Fatalf("approval request should not finish the run before response: %#v", run.Events) + } + } +} + +func TestToolDisplayMetadataIsOptional(t *testing.T) { + if metadata := toolDisplayMetadata("unknown_tool"); metadata != nil { + t.Fatalf("unknown tools should not invent display metadata: %#v", metadata) + } + + metadata := toolDisplayMetadata("linear.list_issues") + provider, _ := metadata["provider"].(map[string]any) + if metadata["displayName"] != "List Issues" || provider["displayName"] != "Linear" { + t.Fatalf("bad known tool metadata: %#v", metadata) + } + if _, ok := metadata["iconId"]; ok { + t.Fatalf("metadata must not use iconId: %#v", metadata) + } +} + +func TestApprovalPromptSeqStartsAtNextPackedCarrierSeq(t *testing.T) { + run, err := buildAIRun(context.Background(), "run-1", "thread-1", "stream-tools 120 shell#approval --seed=7 --chunk-chars=32:32", time.Unix(10, 0)) + if err != nil { + t.Fatal(err) + } + carriers, err := aistream.PackRunFromSeq(*run, "$anchor", aistream.CarrierBudgetBytes, 1) + if err != nil { + t.Fatal(err) + } + nextSeq := aistream.NextSeq(splitCarriersForTimedEmission(carriers)) + if nextSeq <= 1 { + t.Fatalf("expected initial stream to consume carrier sequence numbers, got %d", nextSeq) + } + + prompt := run.Prompts[0] + prompt.SeqStart = nextSeq + approvalCtx := aistream.ApprovalContext{ + ID: prompt.ID, + ThreadID: run.ThreadID, + RunID: run.RunID, + MessageID: run.MessageID, + Command: "stream-tools 120 shell#approval --seed=7 --chunk-chars=32:32", + ToolCallID: prompt.ToolCallID, + ToolName: prompt.ToolName, + TargetEvent: "$anchor", + AgentID: run.AgentID, + AgentName: run.AgentName, + SeqStart: prompt.SeqStart, + } + continuation, err := buildAIApprovalContinuationRunWithApprovals(context.Background(), 
approvalCtx, map[string]agui.ToolApprovalResponse{approvalCtx.ID: { + ID: prompt.ID, + Approved: true, + }}, time.Unix(20, 0)) + if err != nil { + t.Fatal(err) + } + continuationCarriers, err := aistream.PackRunFromSeq(continuation, "$anchor", aistream.CarrierBudgetBytes, approvalCtx.SeqStart) + if err != nil { + t.Fatal(err) + } + if len(continuationCarriers) == 0 || len(continuationCarriers[0].Envelopes) == 0 || continuationCarriers[0].Envelopes[0].Seq != nextSeq { + t.Fatalf("continuation should start at next carrier seq %d, got %#v", nextSeq, continuationCarriers) + } + if continuationCarriers[0].Envelopes[0].Seq >= 100000 { + t.Fatalf("continuation sequence has legacy large gap: %#v", continuationCarriers[0]) + } +} + +func TestApprovalLifecycleCarriesNoticeTargetAndContinuation(t *testing.T) { + command := "stream-tools 120 shell#approval fetch --seed=7 --chunk-chars=32:32" + run, err := buildAIRun(context.Background(), "run-1", "thread-1", command, time.Unix(10, 0)) + if err != nil { + t.Fatal(err) + } + if len(run.Prompts) != 1 { + t.Fatalf("expected one approval prompt, got %#v", run.Prompts) + } + sizingRun := *run + annotateApprovalEventIDs(&sizingRun, approvalEventIDPlaceholders(sizingRun.Prompts)) + initialCarriers, err := aistream.PackRunFromSeq(sizingRun, "$anchor", aistream.CarrierBudgetBytes, 1) + if err != nil { + t.Fatal(err) + } + initialCarriers = splitCarriersForTimedEmission(initialCarriers) + nextSeq := aistream.NextSeq(initialCarriers) + if nextSeq <= 1 { + t.Fatalf("expected initial carriers to advance sequence, got %d", nextSeq) + } + + prompt := run.Prompts[0] + prompt.SeqStart = nextSeq + approvalCtx := aistream.ApprovalContext{ + ID: prompt.ID, + ThreadID: run.ThreadID, + RunID: run.RunID, + MessageID: run.MessageID, + Command: command, + ToolCallID: prompt.ToolCallID, + ToolName: prompt.ToolName, + TargetEvent: "$anchor", + AgentID: run.AgentID, + AgentName: run.AgentName, + SeqStart: prompt.SeqStart, + } + notice := 
aistream.NewApprovalNotice(approvalCtx, aistream.DefaultApprovalChoices()).Map() + if notice["id"] != prompt.ID || notice["messageId"] != run.MessageID || notice["state"] != "requested" { + t.Fatalf("approval notice does not target the paused run: %#v", notice) + } + + annotateApprovalEventIDs(run, map[string]id.EventID{prompt.ID: "$approval"}) + annotatedCarriers, err := aistream.PackRunFromSeq(*run, "$anchor", aistream.CarrierBudgetBytes, 1) + if err != nil { + t.Fatal(err) + } + var annotatedValue map[string]any + for _, carrier := range annotatedCarriers { + for _, env := range carrier.Envelopes { + if env.Part["type"] != agui.EventCustom || env.Part["name"] != agui.ApprovalCustomRequested { + continue + } + annotatedValue, _ = env.Part["value"].(map[string]any) + } + } + if annotatedValue == nil || annotatedValue["approvalMessageId"] != prompt.ID { + t.Fatalf("approval-requested stream event missing approval message id: %#v", annotatedValue) + } + if annotatedValue["approvalEventId"] != "$approval" { + t.Fatalf("approval-requested stream event missing Matrix event target: %#v", annotatedValue) + } + annotatedCarriers = splitCarriersForTimedEmission(annotatedCarriers) + if annotatedNextSeq := aistream.NextSeq(annotatedCarriers); annotatedNextSeq != nextSeq { + t.Fatalf("approval event target changed stream sequence: initial=%d annotated=%d", nextSeq, annotatedNextSeq) + } + choices, ok := annotatedValue["choices"].([]any) + if !ok || len(choices) != len(aistream.DefaultApprovalChoices()) { + t.Fatalf("approval-requested stream event missing choices: %#v", annotatedValue["choices"]) + } + firstChoice, ok := choices[0].(map[string]any) + if !ok || firstChoice["key"] != aistream.ApprovalChoiceApprove || firstChoice["label"] != "Allow once" { + t.Fatalf("approval-requested stream event has bad choice shape: %#v", choices[0]) + } + + continuation, err := buildAIApprovalContinuationRunWithApprovals(context.Background(), approvalCtx, 
map[string]agui.ToolApprovalResponse{approvalCtx.ID: { + ID: prompt.ID, + Approved: true, + }}, time.Unix(20, 0)) + if err != nil { + t.Fatal(err) + } + if len(continuation.Prompts) != 0 { + t.Fatalf("continuation must not request approval again: %#v", continuation.Prompts) + } + if continuation.Status.State != "complete" { + t.Fatalf("approved continuation should finish the run, got %#v", continuation.Status) + } + continuationCarriers, err := aistream.PackRunFromSeq(continuation, "$anchor", aistream.CarrierBudgetBytes, approvalCtx.SeqStart) + if err != nil { + t.Fatal(err) + } + if len(continuationCarriers) == 0 || len(continuationCarriers[0].Envelopes) == 0 || continuationCarriers[0].Envelopes[0].Seq != nextSeq { + t.Fatalf("continuation should resume at seq %d, got %#v", nextSeq, continuationCarriers) + } + if continuation.Events[0]["type"] != agui.EventCustom || continuation.Events[0]["name"] != agui.ApprovalCustomResponded { + t.Fatalf("continuation must start by acknowledging approval: %#v", continuation.Events) + } +} + +func TestApprovalContinuationResumesOriginalRunAfterApprovedTool(t *testing.T) { + command := "stream-tools 240 shell#approval fetch --seed=7 --chunk-chars=32:32" + approvalCtx := aistream.ApprovalContext{ + ID: "approval-run-1-dummy-tool-1-shell", + ThreadID: "thread-1", + RunID: "run-1", + MessageID: "msg-run-1", + Command: command, + ToolCallID: "dummy-tool-1-shell", + ToolName: "shell", + TargetEvent: "$anchor", + AgentID: "ai", + AgentName: "AI", + SeqStart: 12, + } + run, err := buildAIApprovalContinuationRunWithApprovals(context.Background(), approvalCtx, map[string]agui.ToolApprovalResponse{approvalCtx.ID: { + ID: approvalCtx.ID, + Approved: true, + }}, time.Unix(20, 0)) + if err != nil { + t.Fatal(err) + } + if len(run.Events) == 0 { + t.Fatal("expected continuation events") + } + if run.Events[0]["type"] != agui.EventCustom || run.Events[0]["name"] != agui.ApprovalCustomResponded { + t.Fatalf("first continuation event should 
acknowledge approval, got %#v", run.Events[0]) + } + seenApprovedTool := false + seenLaterTool := false + seenFinished := false + for _, evt := range run.Events { + if evt["type"] == agui.EventToolCallEnd && evt["toolCallId"] == approvalCtx.ToolCallID { + if evt["state"] == agui.ToolStateApprovalResponded { + result := jsonResultMap(t, evt["result"]) + if result["approved"] != true { + t.Fatalf("approved result missing approval state: %#v", result) + } + seenApprovedTool = true + } + } + if evt["type"] == agui.EventToolCallStart && evt["toolCallId"] == "dummy-tool-2-fetch" { + seenLaterTool = true + } + if evt["type"] == agui.EventRunFinished { + seenFinished = true + } + } + if !seenApprovedTool || !seenLaterTool || !seenFinished { + t.Fatalf("continuation did not resume fully: approved=%v laterTool=%v finished=%v events=%#v", seenApprovedTool, seenLaterTool, seenFinished, run.Events) + } + if run.Status.State != "complete" { + t.Fatalf("approved continuation status = %#v", run.Status) + } + if len(run.Prompts) != 0 { + t.Fatalf("finished continuation should not keep pending prompts: %#v", run.Prompts) + } +} + +func TestApprovalContinuationStopsOriginalRunAfterDeniedTool(t *testing.T) { + command := "stream-tools 240 shell#approval fetch --seed=7 --chunk-chars=32:32" + approvalCtx := aistream.ApprovalContext{ + ID: "approval-run-1-dummy-tool-1-shell", + ThreadID: "thread-1", + RunID: "run-1", + MessageID: "msg-run-1", + Command: command, + ToolCallID: "dummy-tool-1-shell", + ToolName: "shell", + TargetEvent: "$anchor", + AgentID: "ai", + AgentName: "AI", + SeqStart: 12, + } + run, err := buildAIApprovalContinuationRunWithApprovals(context.Background(), approvalCtx, map[string]agui.ToolApprovalResponse{approvalCtx.ID: { + ID: approvalCtx.ID, + Approved: false, + Reason: "denied", + }}, time.Unix(20, 0)) + if err != nil { + t.Fatal(err) + } + seenDeniedTool := false + for _, evt := range run.Events { + if evt["type"] == agui.EventToolCallStart && evt["toolCallId"] 
== "dummy-tool-2-fetch" { + t.Fatalf("denied approval must not continue later tools: %#v", run.Events) + } + if evt["type"] == agui.EventToolCallEnd && evt["toolCallId"] == approvalCtx.ToolCallID && evt["state"] == agui.ToolStateApprovalResponded { + result := jsonResultMap(t, evt["result"]) + if result["state"] != agui.ToolResultStateError || result["reason"] != "denied" { + t.Fatalf("bad denied result: %#v", result) + } + seenDeniedTool = true + } + } + if !seenDeniedTool { + t.Fatalf("missing denied approval result: %#v", run.Events) + } + if run.Status.State != "error" { + t.Fatalf("denied continuation status = %#v", run.Status) + } +} + +func TestBuildAIRunToolsDenyProducesStructuredDeniedResult(t *testing.T) { + run, err := buildAIRun(context.Background(), "run-1", "thread-1", "stream-tools 120 shell#deny --seed=7 --chunk-chars=32:32", time.Unix(10, 0)) + if err != nil { + t.Fatal(err) + } + for _, evt := range run.Events { + if evt["type"] != agui.EventToolCallEnd { + continue + } + result := jsonResultMap(t, evt["result"]) + if result["state"] == agui.ToolResultStateError && result["reason"] == "denied" { + return + } + } + t.Fatalf("missing structured denied tool result: %#v", run.Events) +} + +func TestBuildAIRunToolsOmitPlaceholderArgsAndEmitTerminalResult(t *testing.T) { + run, err := buildAIRun(context.Background(), "run-1", "thread-1", "stream-tools 120 shell --seed=7 --chunk-chars=32:32", time.Unix(10, 0)) + if err != nil { + t.Fatal(err) + } + for _, evt := range run.Events { + if evt["type"] == agui.EventToolCallArgs { + t.Fatalf("plain demo tool should not emit placeholder args: %#v", evt) + } + if evt["type"] == agui.EventToolCallEnd { + if evt["input"] != nil { + t.Fatalf("plain demo tool should omit placeholder input: %#v", evt) + } + result := jsonResultMap(t, evt["result"]) + if result["state"] != agui.ToolResultStateComplete || result["status"] != "success" { + t.Fatalf("plain demo tool should emit terminal success result: %#v", evt) + } + 
return + } + } + t.Fatal("missing TOOL_CALL_END event") +} + +func TestBuildAIRunToolsPrelimUsesAGUIToolResult(t *testing.T) { + run, err := buildAIRun(context.Background(), "run-1", "thread-1", "stream-tools 120 shell#prelim --seed=7 --chunk-chars=32:32", time.Unix(10, 0)) + if err != nil { + t.Fatal(err) + } + for _, evt := range run.Events { + if evt["type"] != agui.EventToolCallResult { + continue + } + if evt["state"] != agui.ToolResultStateStreaming || evt["toolCallId"] == "" || evt["content"] == "" { + t.Fatalf("bad TOOL_CALL_RESULT event: %#v", evt) + } + return + } + t.Fatal("missing TOOL_CALL_RESULT event") +} + +func TestBuildAIRunFinalSnapshotPreservesToolParts(t *testing.T) { + run, err := buildAIRun(context.Background(), "run-1", "thread-1", "stream-tools 120 shell#prelim fetch#fail --seed=7 --chunk-chars=32:32", time.Unix(10, 0)) + if err != nil { + t.Fatal(err) + } + var snapshot []agui.UIMessage + seenRunFinished := false + for _, evt := range run.Events { + switch evt["type"] { + case agui.EventMessagesSnapshot: + if seenRunFinished { + t.Fatal("final snapshot must be emitted before RUN_FINISHED") + } + var ok bool + snapshot, ok = evt["messages"].([]agui.UIMessage) + if !ok { + t.Fatalf("bad snapshot payload: %#v", evt["messages"]) + } + case agui.EventRunFinished: + seenRunFinished = true + } + } + if len(snapshot) != 1 { + t.Fatalf("expected one final UI message snapshot, got %#v", snapshot) + } + seenToolCall := false + seenToolResult := false + for _, part := range snapshot[0].Parts { + switch part["type"] { + case "tool-call": + seenToolCall = true + case "tool-result": + seenToolResult = true + } + } + if !seenToolCall || !seenToolResult { + t.Fatalf("final snapshot lost tool parts: %#v", snapshot[0].Parts) + } +} + +func TestBuildAIRunToolsFailureDeltaAndInputError(t *testing.T) { + run, err := buildAIRun(context.Background(), "run-1", "thread-1", "stream-tools 120 shell#fail fetch#delta parser#inputerror --seed=7 --chunk-chars=8:8", 
time.Unix(10, 0)) + if err != nil { + t.Fatal(err) + } + seenFailure := false + seenInputError := false + for _, evt := range run.Events { + if evt["type"] != agui.EventToolCallEnd && evt["type"] != agui.EventToolCallArgs { + continue + } + toolCallID, _ := evt["toolCallId"].(string) + if evt["type"] == agui.EventToolCallArgs && strings.Contains(toolCallID, "fetch") { + t.Fatalf("delta tool without real input should not emit placeholder args: %#v", evt) + } + if evt["type"] == agui.EventToolCallEnd { + if strings.Contains(toolCallID, "shell") { + result := jsonResultMap(t, evt["result"]) + if result["state"] == agui.ToolResultStateError { + seenFailure = true + } + } + if strings.Contains(toolCallID, "parser") { + result := jsonResultMap(t, evt["result"]) + if result["reason"] == "input-error" { + seenInputError = true + } + } + } + } + if !seenFailure || !seenInputError { + t.Fatalf("missing tool tag coverage: failure=%v inputError=%v", seenFailure, seenInputError) + } +} + +func TestBuildAIRunToolsProviderTagAddsRawEventPassthrough(t *testing.T) { + run, err := buildAIRun(context.Background(), "run-1", "thread-1", "stream-tools 120 shell#provider --seed=7 --chunk-chars=32:32", time.Unix(10, 0)) + if err != nil { + t.Fatal(err) + } + for _, evt := range run.Events { + raw, ok := evt["rawEvent"].(map[string]any) + if !ok { + continue + } + if raw["provider"] != "dummybridge" || raw["tool"] != "shell" { + t.Fatalf("bad raw provider event: %#v", raw) + } + carriers, err := aistream.PackRun(*run, "$anchor", aistream.CarrierBudgetBytes) + if err != nil { + t.Fatal(err) + } + if len(carriers) == 0 { + t.Fatal("expected packed carriers") + } + return + } + t.Fatal("missing rawEvent for provider-tagged tool") +} + +func TestBuildAIRunTerminalErrorAndAbortStates(t *testing.T) { + errorRun, err := buildAIRun(context.Background(), "run-error", "thread-1", "stream 1 --terminal=error --seed=7 --no-approval", time.Unix(10, 0)) + if err != nil { + t.Fatal(err) + } + if 
errorRun.Status.State != "error" { + t.Fatalf("expected error status, got %#v", errorRun.Status) + } + abortRun, err := buildAIRun(context.Background(), "run-abort", "thread-1", "stream 1 --terminal=abort --seed=7 --no-approval", time.Unix(10, 0)) + if err != nil { + t.Fatal(err) + } + if abortRun.Status.State != "aborted" { + t.Fatalf("expected aborted status, got %#v", abortRun.Status) + } +} + +func TestBuildAIRunOver64KBPacksTo58KCarriers(t *testing.T) { + run, err := buildAIRun(context.Background(), "run-1", "thread-1", "stream 1 --chars=70000 --actions=1 --seed=7", time.Unix(10, 0)) + if err != nil { + t.Fatal(err) + } + carriers, err := aistream.PackRun(*run, "$anchor", aistream.CarrierBudgetBytes) + if err != nil { + t.Fatal(err) + } + if len(carriers) < 2 { + t.Fatalf("expected split carriers, got %d", len(carriers)) + } + for i, carrier := range carriers { + if size := aistream.JSONSize(aistream.CarrierContent(carrier.Envelopes)); size > aistream.CarrierBudgetBytes { + t.Fatalf("carrier %d size = %d", i, size) + } + } + for _, carrier := range carriers { + for _, envelope := range carrier.Envelopes { + if envelope.Part["type"] != agui.EventMessagesSnapshot { + continue + } + raw, err := json.Marshal(envelope.Part) + if err != nil { + t.Fatal(err) + } + if strings.Contains(string(raw), strings.Repeat("a", 60*1024)) { + t.Fatal("final snapshot should not repeat full streamed text") + } + } + } + if len(aistream.ReconstructText(carriers)) < 60*1024 { + t.Fatalf("expected large reconstructed output, got %d", len(aistream.ReconstructText(carriers))) + } +} + +func TestBuildAIRunPlansChaosCreatesMultipleRuns(t *testing.T) { + plans, err := buildAIRunPlans(context.Background(), "run-chaos", "thread-1", "stream 1 --runs=3 --actions=3 --seed=7 --stagger-ms=1:1", time.Unix(10, 0), "ai", "AI") + if err != nil { + t.Fatal(err) + } + if len(plans) != 3 { + t.Fatalf("expected three chaos runs, got %d", len(plans)) + } + seen := map[string]bool{} + for i, plan := range 
plans { + if plan.Run == nil { + t.Fatalf("nil run at %d", i) + } + if seen[plan.Run.RunID] { + t.Fatalf("duplicate run ID %q", plan.Run.RunID) + } + seen[plan.Run.RunID] = true + if plan.Run.ThreadID != "thread-1" { + t.Fatalf("bad thread ID: %q", plan.Run.ThreadID) + } + if i > 0 && plan.Delay <= 0 { + t.Fatalf("expected nonzero child stagger delay, got %#v", plans) + } + } +} + +func TestBuildAIRunRandomHonorsVirtualDelays(t *testing.T) { + run, err := buildAIRun(context.Background(), "run-1", "thread-1", "stream 3 --actions=4 --seed=7 --delay-ms=100:100", time.Unix(10, 0)) + if err != nil { + t.Fatal(err) + } + var first, last int64 + for _, evt := range run.Events { + ts, _ := evt["timestamp"].(int64) + if ts == 0 { + if n, ok := evt["timestamp"].(int); ok { + ts = int64(n) + } + } + if ts == 0 { + continue + } + if first == 0 { + first = ts + } + last = ts + } + if first == 0 || last-first < 300 { + t.Fatalf("expected random run timestamps to reflect action delays, first=%d last=%d", first, last) + } +} + +func TestRandomModeApprovalPause(t *testing.T) { + for seed := int64(1); seed <= 200; seed++ { + run, err := buildAIRun(context.Background(), "run-approval", "thread-approval", "stream 1 --profile=tools --seed="+strconv.FormatInt(seed, 10), time.Unix(10, 0)) + if err != nil { + t.Fatal(err) + } + if run.ApprovalID == "" { + continue + } + for _, evt := range run.Events { + if evt["type"] == agui.EventRunFinished { + t.Fatalf("approval run emitted RUN_FINISHED with seed %d", seed) + } + } + if run.Status.State != "streaming" { + t.Fatalf("expected approval run to remain streaming, got %q", run.Status.State) + } + return + } + t.Fatal("no approval action selected for tested random seeds") +} + +func TestBalancedStream50UsuallyPausesForApproval(t *testing.T) { + for seed := int64(1); seed <= 20; seed++ { + run, err := buildAIRun(context.Background(), "run-approval", "thread-approval", "stream 50 --seed="+strconv.FormatInt(seed, 10), time.Unix(10, 0)) + if err 
!= nil { + t.Fatal(err) + } + if run.ApprovalID != "" { + return + } + } + t.Fatal("balanced stream 50 did not request approval for any sampled seed") +} + +func TestBalancedStream50DoesNotPauseImmediatelyForApproval(t *testing.T) { + run, err := buildAIRun(context.Background(), "run-approval", "thread-approval", "stream 50 --seed=1", time.Unix(10, 0)) + if err != nil { + t.Fatal(err) + } + if run.ApprovalID == "approval-run-approval-dummy-tool-1-calendar" { + t.Fatalf("balanced stream paused immediately for approval: %q", run.ApprovalID) + } + if strings.Contains(run.ApprovalID, "dummy-tool-1-") { + t.Fatalf("balanced stream paused on first action approval: %q", run.ApprovalID) + } +} + +func TestRandomProfilesCoverToolsArtifactsAndTransientData(t *testing.T) { + balanced := randomCommand{ + sharedStreamOptions: sharedStreamOptions{ + Profile: "balanced", + AllowApproval: true, + }, + } + seen := map[string]bool{} + rng := rand.New(rand.NewSource(2)) + for range 400 { + options, total := buildRandomActionOptions(balanced) + seen[pickWeighted(options, total, rng)] = true + } + if seen[randomActionToolApproval] { + t.Fatalf("balanced profile should keep approvals rare via tool-call promotion, seen=%#v", seen) + } + + cmd := randomCommand{ + sharedStreamOptions: sharedStreamOptions{ + Profile: "tools", + AllowApproval: true, + }, + } + seen = map[string]bool{} + rng = rand.New(rand.NewSource(4)) + for range 400 { + options, total := buildRandomActionOptions(cmd) + seen[pickWeighted(options, total, rng)] = true + } + for _, action := range []string{randomActionTool, randomActionToolFail, randomActionToolDeny, randomActionToolApproval} { + if !seen[action] { + t.Fatalf("tools profile never selected %s; seen=%#v", action, seen) + } + } + + cmd.Profile = "artifacts" + seen = map[string]bool{} + rng = rand.New(rand.NewSource(8)) + for range 400 { + options, total := buildRandomActionOptions(cmd) + seen[pickWeighted(options, total, rng)] = true + } + for _, action := range 
[]string{randomActionSource, randomActionDocument, randomActionFile, randomActionMetadata, randomActionData, randomActionDataTransient} { + if !seen[action] { + t.Fatalf("artifacts profile never selected %s; seen=%#v", action, seen) + } + } +} + +func TestRandomTerminalUsesAllowedOutcomes(t *testing.T) { + cmd := randomCommand{sharedStreamOptions: sharedStreamOptions{AllowAbort: true, AllowError: true}} + seen := map[string]bool{} + rng := rand.New(rand.NewSource(10)) + for range 80 { + seen[chooseRandomTerminal(cmd, rng)] = true + } + for _, terminal := range []string{"finish", "abort", "error"} { + if !seen[terminal] { + t.Fatalf("terminal %s was never selected; seen=%#v", terminal, seen) + } + } + + if terminal := chooseRandomTerminal(randomCommand{}, rand.New(rand.NewSource(1))); terminal != "finish" { + t.Fatalf("unexpected terminal without flags: %q", terminal) + } +} + +func TestBuildDemoVisibleTextIsMarkdownRichAndDeterministic(t *testing.T) { + first := buildDemoVisibleText(420, rand.New(rand.NewSource(7))) + second := buildDemoVisibleText(420, rand.New(rand.NewSource(7))) + if first != second { + t.Fatalf("expected deterministic output for seed") + } + for _, signal := range []string{"[", "](", "**", "\n- ", "\n> ", "```", "\n| "} { + if strings.Contains(first, signal) { + return + } + } + t.Fatalf("expected markdown-rich text, got %q", first) +} + +func TestBuildDemoVisibleTextDoesNotCutMarkdownSyntax(t *testing.T) { + for _, chars := range []int{24, 40, 60, 80, 96, 120, 180, 260, 420} { + for seed := int64(1); seed <= 80; seed++ { + text := buildDemoVisibleText(chars, rand.New(rand.NewSource(seed))) + if strings.Count(text, "[") != strings.Count(text, "]") { + t.Fatalf("unbalanced brackets for chars=%d seed=%d: %q", chars, seed, text) + } + assertCompleteMarkdownLinks(t, chars, seed, text) + if strings.Count(text, "```")%2 != 0 { + t.Fatalf("unbalanced code fence for chars=%d seed=%d: %q", chars, seed, text) + } + if strings.Contains(text, 
// assertCompleteMarkdownLinks fails the test if any "](" in text is not
// followed by a closing ")" or if the link target between them contains a
// space, newline or tab. chars and seed are only used to label the failure
// message.
func assertCompleteMarkdownLinks(t *testing.T, chars int, seed int64, text string) {
	t.Helper()
	remaining := text
	for {
		_, afterOpen, hasLink := strings.Cut(remaining, "](")
		if !hasLink {
			return
		}
		target, afterClose, closed := strings.Cut(afterOpen, ")")
		if !closed {
			t.Fatalf("unclosed markdown link for chars=%d seed=%d: %q", chars, seed, text)
		}
		if strings.ContainsAny(target, " \n\t") {
			t.Fatalf("cut markdown link for chars=%d seed=%d: %q", chars, seed, text)
		}
		// Resume scanning right after the closing parenthesis.
		remaining = afterClose
	}
}
TargetEvent: "$anchor", + AgentID: "ai", + AgentName: "AI", + SeqStart: 12, + } + run, err := buildAIApprovalContinuationRunWithApprovals(context.Background(), approvalCtx, map[string]agui.ToolApprovalResponse{approvalCtx.ID: { + ID: approvalCtx.ID, + Approved: true, + }}, time.Unix(20, 0)) + if err != nil { + t.Fatal(err) + } + if len(run.Prompts) != 1 { + t.Fatalf("expected second approval prompt to be preserved, got %#v", run.Prompts) + } + if run.Prompts[0].ToolName != "fetch" { + t.Fatalf("expected preserved prompt to belong to fetch, got %#v", run.Prompts[0]) + } + if run.Status.State != "streaming" { + t.Fatalf("expected continuation with pending approval to remain streaming, got %#v", run.Status) + } + + secondCtx := aistream.ApprovalContext{ + ID: run.Prompts[0].ID, + ThreadID: approvalCtx.ThreadID, + RunID: approvalCtx.RunID, + MessageID: approvalCtx.MessageID, + Command: command, + ToolCallID: run.Prompts[0].ToolCallID, + ToolName: run.Prompts[0].ToolName, + TargetEvent: approvalCtx.TargetEvent, + AgentID: approvalCtx.AgentID, + AgentName: approvalCtx.AgentName, + SeqStart: 100, + } + finished, err := buildAIApprovalContinuationRunWithApprovals(context.Background(), secondCtx, map[string]agui.ToolApprovalResponse{ + approvalCtx.ID: { + ID: approvalCtx.ID, + Approved: true, + }, + secondCtx.ID: { + ID: secondCtx.ID, + Approved: true, + }, + }, time.Unix(30, 0)) + if err != nil { + t.Fatal(err) + } + if finished.Status.State != "complete" { + t.Fatalf("second approval continuation should finish, got %#v", finished.Status) + } + if len(finished.Prompts) != 0 { + t.Fatalf("finished continuation should not keep prompts: %#v", finished.Prompts) + } +} + +func TestApprovalContinuationReplaysRandomRunWithImplicitSeed(t *testing.T) { + // Iterate clocks until the random-action profile produces an approval + // request — the seed is implicit (resolved from now()), and the bug being + // guarded against is that the continuation would otherwise pick a fresh + // 
seed and lose the original toolCallID. + for tick := int64(1); tick <= 500; tick++ { + now := time.Unix(tick, 0) + plans, err := buildAIRunPlans(context.Background(), "run-rand", "thread-rand", "stream 1 --profile=tools", now, "ai", "AI") + if err != nil { + t.Fatal(err) + } + if len(plans) != 1 || plans[0].Run == nil { + t.Fatalf("expected one random plan, got %#v", plans) + } + originalRun := plans[0].Run + if originalRun.ApprovalID == "" { + continue + } + if !strings.Contains(plans[0].EffectiveCommand, "--seed=") { + t.Fatalf("effective command must include resolved seed: %q", plans[0].EffectiveCommand) + } + approvalCtx := aistream.ApprovalContext{ + ID: originalRun.ApprovalID, + ThreadID: originalRun.ThreadID, + RunID: originalRun.RunID, + MessageID: originalRun.MessageID, + Command: plans[0].EffectiveCommand, + ToolCallID: originalRun.ToolCallID, + TargetEvent: "$anchor", + AgentID: "ai", + AgentName: "AI", + SeqStart: 50, + } + continuation, err := buildAIApprovalContinuationRunWithApprovals(context.Background(), approvalCtx, map[string]agui.ToolApprovalResponse{approvalCtx.ID: { + ID: approvalCtx.ID, + Approved: true, + }}, now.Add(time.Hour)) + if err != nil { + t.Fatalf("continuation failed: %v", err) + } + if len(continuation.Events) == 0 { + t.Fatalf("expected continuation events for random run, got none") + } + if continuation.Events[0]["type"] != agui.EventCustom || continuation.Events[0]["name"] != agui.ApprovalCustomResponded { + t.Fatalf("first continuation event should acknowledge approval, got %#v", continuation.Events[0]) + } + return + } + t.Fatal("no implicit-seed random run produced an approval prompt in the tested range") +} + +func TestChaosSubRunCommandIsParseable(t *testing.T) { + plans, err := buildAIRunPlans(context.Background(), "run-chaos", "thread-chaos", "stream 1 --runs=2 --seed=11", time.Unix(0, 0), "ai", "AI") + if err != nil { + t.Fatal(err) + } + if len(plans) != 2 { + t.Fatalf("expected two chaos sub-runs, got %d", 
// jsonResultMap asserts that value is a string holding a JSON object and
// returns the decoded object. The test fails immediately if value is not a
// string or the JSON cannot be decoded into a map.
func jsonResultMap(t *testing.T, value any) map[string]any {
	t.Helper()
	raw, isString := value.(string)
	if !isString {
		t.Fatalf("expected JSON string result, got %#v", value)
	}
	var decoded map[string]any
	err := json.Unmarshal([]byte(raw), &decoded)
	if err != nil {
		t.Fatalf("failed to parse result %q: %v", raw, err)
	}
	return decoded
}
cmd.AllowApproval { + options = append(options, randomActionOption{randomActionToolApproval, 4}) + } + case "artifacts": + options = append(options, + randomActionOption{randomActionSource, 4}, + randomActionOption{randomActionDocument, 4}, + randomActionOption{randomActionFile, 4}, + randomActionOption{randomActionMetadata, 3}, + randomActionOption{randomActionData, 3}, + randomActionOption{randomActionDataTransient, 3}, + ) + case "errors": + options = append(options, + randomActionOption{randomActionToolFail, 7}, + randomActionOption{randomActionToolDeny, 5}, + randomActionOption{randomActionTool, 2}, + ) + if cmd.AllowApproval { + options = append(options, randomActionOption{randomActionToolApproval, 4}) + } + } + total := 0 + for _, option := range options { + total += option.weight + } + return options, total +} + +func pickWeighted(options []randomActionOption, total int, rng *rand.Rand) string { + if total <= 0 || len(options) == 0 { + return randomActionText + } + pick := rng.Intn(total) + for _, option := range options { + if pick < option.weight { + return option.name + } + pick -= option.weight + } + return randomActionText +} + +func chooseRandomTerminal(cmd randomCommand, rng *rand.Rand) string { + if cmd.Terminal != "" { + return cmd.Terminal + } + options := []string{"finish"} + if cmd.AllowAbort { + options = append(options, "abort") + } + if cmd.AllowError { + options = append(options, "error") + } + return options[rng.Intn(len(options))] +} + +func randomToolName(rng *rand.Rand) string { + names := []string{"search", "fetch", "summarize", "calendar", "shell", "files", "preview"} + return names[rng.Intn(len(names))] +} diff --git a/pkg/connector/ai_text.go b/pkg/connector/ai_text.go new file mode 100644 index 0000000..40a6945 --- /dev/null +++ b/pkg/connector/ai_text.go @@ -0,0 +1,250 @@ +package connector + +import ( + "fmt" + "math/rand" + "strings" +) + +var loremSentenceCorpus = []string{ + "Lorem ipsum dolor sit amet, consectetur adipiscing 
elit.", + "Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.", + "Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.", + "Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.", + "Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.", + "Integer nec odio praesent libero sed cursus ante dapibus diam.", + "Nulla quis sem at nibh elementum imperdiet duis sagittis ipsum.", + "Praesent mauris fusce nec tellus sed augue semper porta.", + "Mauris massa vestibulum lacinia arcu eget nulla.", + "Class aptent taciti sociosqu ad litora torquent per conubia nostra.", + "In consectetur orci eu erat varius, vitae facilisis lorem blandit.", + "Curabitur ullamcorper ultricies nisi nam eget dui etiam rhoncus.", +} + +var demoMarkdownLabels = []string{"release notes", "ops runbook", "incident log", "design memo", "qa checklist", "support brief"} +var demoMarkdownURLs = []string{ + "https://dummybridge.local/docs/streaming", + "https://dummybridge.local/docs/markdown", + "https://dummybridge.local/runbooks/runs", + "https://dummybridge.local/notes/demo-output", +} +var demoMarkdownEmphasis = []string{"high-signal", "operator-visible", "tool-safe", "incremental", "review-ready"} +var demoMarkdownListItems = []string{ + "Confirm the seeded output changes shape between runs.", + "Surface enough formatting to stress the renderer.", + "Keep deltas readable while chunks arrive out of phase.", + "Preserve stable output for deterministic test fixtures.", + "Expose links, tables, and code blocks without extra flags.", +} +var demoMarkdownQuoteCorpus = []string{ + "Streaming output should feel alive, not like the same paragraph repeated forever.", + "Richer markdown gives the client something realistic to render while the run is still open.", +} +var demoMarkdownCodeSnippets = []string{ + "const preview = 
// demoSegmentSpec describes one kind of demo text segment that
// buildDemoVisibleText can emit.
type demoSegmentSpec struct {
	// weight is the relative chance of this segment being chosen.
	weight int
	// minLen is the smallest remaining character budget for which
	// chooseDemoSegment still offers this segment.
	minLen int
	// build renders the segment; the int argument is the remaining budget.
	build func(*rand.Rand, int) string
}

// buildLoremText returns roughly chars bytes of filler built from
// loremSentenceCorpus, joined with single spaces and cut by trimText.
//
// It returns "" when chars <= 0. A nil rng is replaced by one seeded with
// chars, so the output is deterministic for a given length.
func buildLoremText(chars int, rng *rand.Rand) string {
	if chars <= 0 {
		return ""
	}
	if rng == nil {
		rng = rand.New(rand.NewSource(int64(chars)))
	}
	var sb strings.Builder
	sb.Grow(chars + 128)
	lastIndex := -1
	// Overshoot by 64 bytes so trimText has room to cut at a sentence or
	// word boundary.
	for sb.Len() < chars+64 {
		index := rng.Intn(len(loremSentenceCorpus))
		// Never repeat the previous sentence: shift by a random non-zero
		// offset, which always lands on a different index.
		if len(loremSentenceCorpus) > 1 && index == lastIndex {
			index = (index + 1 + rng.Intn(len(loremSentenceCorpus)-1)) % len(loremSentenceCorpus)
		}
		if sb.Len() > 0 {
			sb.WriteByte(' ')
		}
		sb.WriteString(loremSentenceCorpus[index])
		lastIndex = index
	}
	return trimText(sb.String(), chars)
}

// buildDemoVisibleText returns markdown-flavoured demo text of roughly chars
// bytes. Segments (plain filler, a sentence with a link and emphasis, a
// bullet list, a block quote, an inline-code line plus fenced code block, and
// a table) are drawn by weight through chooseDemoSegment, joined with blank
// lines, and cut to size by trimVisibleText.
//
// It returns "" when chars <= 0. A nil rng is replaced by one seeded with
// chars. The output depends on the exact order of draws from rng, so
// reordering any rng call below changes the text produced for a given seed.
func buildDemoVisibleText(chars int, rng *rand.Rand) string {
	if chars <= 0 {
		return ""
	}
	if rng == nil {
		rng = rand.New(rand.NewSource(int64(chars)))
	}
	segments := []demoSegmentSpec{
		// Plain filler, 48 to 168 bytes, generated from a child RNG seeded
		// off the parent.
		{weight: 5, minLen: 48, build: func(rng *rand.Rand, remaining int) string {
			return buildLoremText(max(48, min(168, remaining+48)), rand.New(rand.NewSource(rng.Int63())))
		}},
		// Filler followed by a sentence with a markdown link, bold and
		// italic text.
		{weight: 4, minLen: 96, build: func(rng *rand.Rand, _ int) string {
			return fmt.Sprintf("%s Review the [%s](%s) entry for **%s** output and _staged_ formatting transitions.",
				buildLoremText(72+rng.Intn(48), rand.New(rand.NewSource(rng.Int63()))),
				demoMarkdownLabels[rng.Intn(len(demoMarkdownLabels))],
				demoMarkdownURLs[rng.Intn(len(demoMarkdownURLs))],
				demoMarkdownEmphasis[rng.Intn(len(demoMarkdownEmphasis))])
		}},
		// Bullet list; about one item in four is rendered as a checked task.
		{weight: 3, minLen: 96, build: func(rng *rand.Rand, _ int) string {
			var lines []string
			// NOTE(review): the loop bound calls rng.Intn on every
			// iteration, so the list length is re-rolled at each check
			// rather than fixed up front. Changing this would alter seeded
			// output.
			for i := 0; i < 2+rng.Intn(3); i++ {
				prefix := "-"
				if rng.Intn(4) == 0 {
					prefix = "- [x]"
				}
				lines = append(lines, fmt.Sprintf("%s %s", prefix, demoMarkdownListItems[(rng.Intn(len(demoMarkdownListItems))+i)%len(demoMarkdownListItems)]))
			}
			return strings.Join(lines, "\n")
		}},
		// Two-paragraph block quote: a fixed quote followed by filler.
		{weight: 2, minLen: 72, build: func(rng *rand.Rand, _ int) string {
			return fmt.Sprintf("> %s\n>\n> %s", demoMarkdownQuoteCorpus[rng.Intn(len(demoMarkdownQuoteCorpus))], buildLoremText(48+rng.Intn(36), rand.New(rand.NewSource(rng.Int63()))))
		}},
		// Inline code followed by a fenced js code block. The blank line
		// inside this segment means trimVisibleText sees it as two blocks.
		{weight: 2, minLen: 72, build: func(rng *rand.Rand, _ int) string {
			return fmt.Sprintf("Use `%s` for incremental patches.\n\n```js\n%s\n```", sanitizeToolName(demoMarkdownLabels[rng.Intn(len(demoMarkdownLabels))]), demoMarkdownCodeSnippets[rng.Intn(len(demoMarkdownCodeSnippets))])
		}},
		// Three-column table with a header, a separator and data rows.
		{weight: 2, minLen: 180, build: func(rng *rand.Rand, _ int) string {
			header := demoMarkdownTableHeaders[rng.Intn(len(demoMarkdownTableHeaders))]
			lines := []string{fmt.Sprintf("| %s |", strings.Join(header, " | ")), "| --- | --- | --- |"}
			// NOTE(review): as with the list above, the row-count bound is
			// re-rolled on every iteration.
			for i := 0; i < 2+rng.Intn(2); i++ {
				lines = append(lines, fmt.Sprintf("| %s |", strings.Join(demoMarkdownTableRows[(rng.Intn(len(demoMarkdownTableRows))+i)%len(demoMarkdownTableRows)], " | ")))
			}
			return strings.Join(lines, "\n")
		}},
	}
	var blocks []string
	total := 0
	// Keep adding segments until the budget is covered; +2 accounts for the
	// blank-line separator added by the join below.
	for total < chars {
		block := chooseDemoSegment(segments, rng, chars-total)
		blocks = append(blocks, block)
		total += len(block) + 2
	}
	return trimVisibleText(strings.Join(blocks, "\n\n"), chars)
}
// trimVisibleText shortens text so that, for a positive limit, the result is
// at most limit bytes long, preferring to drop whole blank-line-separated
// blocks over cutting inside one.
//
// Blocks are kept in order while they fit. If even the first block is too
// long it is cut on its own: with trimMarkdownBlock when it looks like
// markdown, otherwise with trimText. Text that already fits is returned with
// surrounding whitespace removed.
func trimVisibleText(text string, limit int) string {
	text = strings.TrimSpace(text)
	if len(text) <= limit {
		return text
	}
	blocks := strings.Split(text, "\n\n")
	var kept []string
	total := 0
	for _, block := range blocks {
		block = strings.TrimSpace(block)
		if block == "" {
			continue
		}
		next := total + len(block)
		if len(kept) > 0 {
			next += 2 // the "\n\n" separator added by the final join
		}
		if next > limit {
			// Only the very first block is ever cut internally; later
			// blocks are dropped whole so markdown is never left half-open.
			if len(kept) == 0 {
				if isMarkdownSensitiveBlock(block) {
					kept = append(kept, trimMarkdownBlock(block, limit))
				} else {
					kept = append(kept, trimText(block, limit))
				}
			}
			break
		}
		kept = append(kept, block)
		total = next
	}
	if len(kept) > 0 {
		return strings.Join(kept, "\n\n")
	}
	return trimText(text, limit)
}

// isMarkdownSensitiveBlock reports whether block contains markdown syntax
// that a blind cut could break: a link, a code fence, a table row, a block
// quote or a list item.
func isMarkdownSensitiveBlock(block string) bool {
	return strings.Contains(block, "](") ||
		strings.Contains(block, "```") ||
		strings.Contains(block, "\n|") ||
		strings.HasPrefix(block, "|") ||
		strings.HasPrefix(block, ">") ||
		strings.HasPrefix(block, "-")
}

// trimMarkdownBlock cuts block with trimText and then removes any trailing
// markdown construct the cut left open: an unbalanced "[", a link whose
// "(...)" target is unfinished, or an unclosed code fence. If nothing is
// left after that clean-up it falls back to the plain trimText result.
func trimMarkdownBlock(block string, limit int) string {
	trimmed := trimText(block, limit)
	if strings.Count(trimmed, "[") != strings.Count(trimmed, "]") {
		if idx := strings.LastIndex(trimmed, "["); idx >= 0 {
			trimmed = strings.TrimSpace(trimmed[:idx])
		}
	}
	// A link label that closed but whose URL was cut: drop the whole link.
	if strings.Contains(trimmed, "](") && strings.Count(trimmed, "(") != strings.Count(trimmed, ")") {
		if idx := strings.LastIndex(trimmed, "["); idx >= 0 {
			trimmed = strings.TrimSpace(trimmed[:idx])
		}
	}
	if strings.Count(trimmed, "```")%2 != 0 {
		if idx := strings.LastIndex(trimmed, "```"); idx >= 0 {
			trimmed = strings.TrimSpace(trimmed[:idx])
		}
	}
	if trimmed == "" {
		return trimText(block, limit)
	}
	return trimmed
}

// trimText shortens text to at most limit bytes, cutting at the most natural
// boundary found in the last quarter of the allowed length: first after a
// sentence terminator ('.', '!', '?'), then at a space, and only as a last
// resort at the hard limit. Cuts at a space or at the hard limit also strip
// trailing ".,;:" punctuation.
//
// Whitespace around text is always removed. A non-positive limit disables
// trimming. The hard cut never splits a multi-byte UTF-8 character: it backs
// up to the start of the rune that straddles the limit.
func trimText(text string, limit int) string {
	text = strings.TrimSpace(text)
	if limit <= 0 || len(text) <= limit {
		return text
	}
	// From here on len(text) > limit, so every index up to text[limit] is
	// valid.
	minCutoff := max(1, (limit*3)/4)
	for i := limit; i >= minCutoff; i-- {
		switch text[i-1] {
		case '.', '!', '?':
			return strings.TrimSpace(text[:i])
		}
	}
	for i := limit; i >= minCutoff; i-- {
		if text[i-1] == ' ' {
			return strings.Trim(strings.TrimSpace(text[:i]), ".,;:")
		}
	}
	// Hard cut. UTF-8 continuation bytes have the form 10xxxxxx; if the byte
	// right after the cut is one, the cut would land inside a rune, so step
	// back to that rune's first byte.
	cut := limit
	for cut > 0 && text[cut]&0xC0 == 0x80 {
		cut--
	}
	return strings.Trim(strings.TrimSpace(text[:cut]), ".,;:")
}
// loremCommand is the payload stored in parsedCommand.Lorem: a character
// count plus the option set it shares with toolsCommand.
type loremCommand struct {
	Chars   int
	Options commonCommandOptions
}

// toolSpec describes one entry of toolsCommand.Tools.
// NOTE(review): the boolean flags presumably select per-tool demo behaviors
// (failure, approval, denial, streamed input, ...); how each is interpreted
// is decided by the parser/runner, which is not visible here — confirm there.
type toolSpec struct {
	Name          string
	Tags          []string
	Fail          bool
	Approval      bool
	Deny          bool
	Delta         bool
	InputError    bool
	Preliminary   bool
	Provider      bool
	SequenceIndex int
}

// toolsCommand is the payload stored in parsedCommand.Tools: a character
// count, the list of tool specs, and the common option set.
type toolsCommand struct {
	Chars   int
	Tools   []toolSpec
	Options commonCommandOptions
}

// sharedStreamOptions holds the fields that randomCommand and chaosCommand
// both embed.
type sharedStreamOptions struct {
	Profile       string
	Seed          int64
	SeedSet       bool
	AllowAbort    bool
	AllowError    bool
	AllowApproval bool
}

// randomCommand is the payload stored in parsedCommand.Random. It embeds
// sharedStreamOptions, so those fields are promoted onto it.
type randomCommand struct {
	Duration   time.Duration
	Actions    int
	Chars      int
	DelayMin   time.Duration
	DelayMax   time.Duration
	Terminal   string
	Runs       int
	StaggerMin time.Duration
	StaggerMax time.Duration
	sharedStreamOptions
}

// randomActionOption pairs an action name with an integer weight.
// NOTE(review): presumably used for weighted selection among the
// randomAction* constants — confirm against the picker.
type randomActionOption struct {
	name   string
	weight int
}

// chaosCommand is the payload stored in parsedCommand.Chaos. It embeds
// sharedStreamOptions, so those fields are promoted onto it.
type chaosCommand struct {
	Runs       int
	Duration   time.Duration
	StaggerMin time.Duration
	StaggerMax time.Duration
	MaxActions int
	sharedStreamOptions
}

// parsedCommand carries the command name together with one optional payload
// pointer per command kind.
// NOTE(review): presumably exactly one pointer is non-nil, matching Name —
// verify in parseCommand.
type parsedCommand struct {
	Name   string
	Lorem  *loremCommand
	Tools  *toolsCommand
	Random *randomCommand
	Chaos  *chaosCommand
}

// aiRuntime bundles the clock and sleep functions used when building runs.
// NOTE(review): keeping them as fields presumably lets tests substitute a
// fake clock — confirm.
type aiRuntime struct {
	now   func() time.Time
	sleep func(context.Context, time.Duration) error
}

// aiRunner pairs an aiRuntime with a map of tool approval responses keyed by
// string.
// NOTE(review): the key is presumably the approval ID — confirm at the call
// sites.
type aiRunner struct {
	runtime   aiRuntime
	approvals map[string]agui.ToolApprovalResponse
}
+ EffectiveCommand string +} diff --git a/pkg/connector/client.go b/pkg/connector/client.go index 38f3d05..6ff8c0b 100644 --- a/pkg/connector/client.go +++ b/pkg/connector/client.go @@ -2,6 +2,7 @@ package connector import ( "context" + "encoding/json" "errors" "fmt" "regexp" @@ -9,7 +10,11 @@ import ( "sync" "time" + "github.com/beeper/ai-bridge/pkg/ag-ui" + "github.com/beeper/ai-bridge/pkg/ai-stream" + aibridgev2 "github.com/beeper/ai-bridge/pkg/ai-stream/bridgev2" "github.com/rs/zerolog/log" + "go.mau.fi/util/exsync" "go.mau.fi/util/jsontime" "go.mau.fi/util/ptr" @@ -19,6 +24,7 @@ import ( "maunium.net/go/mautrix/bridgev2/simplevent" "maunium.net/go/mautrix/bridgev2/status" "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" ) type DummyClient struct { @@ -28,6 +34,15 @@ type DummyClient struct { UserLogin *bridgev2.UserLogin Connector *DummyConnector + + approvalSelectionsOnce sync.Once + approvalSelections *exsync.Map[string, string] + aiRunSessionsMu sync.Mutex + aiRunSessions map[string]*aiRunSession +} + +type aiRunSession struct { + Decisions map[string]agui.ToolApprovalResponse } var _ bridgev2.NetworkAPI = (*DummyClient)(nil) @@ -35,6 +50,11 @@ var _ bridgev2.IdentifierResolvingNetworkAPI = (*DummyClient)(nil) var _ bridgev2.BackfillingNetworkAPI = (*DummyClient)(nil) var _ bridgev2.DeleteChatHandlingNetworkAPI = (*DummyClient)(nil) var _ bridgev2.MessageRequestAcceptingNetworkAPI = (*DummyClient)(nil) +var _ bridgev2.ReactionHandlingNetworkAPI = (*DummyClient)(nil) + +const ( + dummyAIAgentName string = "Dummy" +) var delayedRemoteEchoPattern = regexp.MustCompile(`(?i)^remote-echo\s+delay\s+([0-9]+(?:ms|s|m|h))$`) @@ -198,16 +218,169 @@ func (dc *DummyClient) HandleMatrixMessage(ctx context.Context, msg *bridgev2.Ma messageID = networkid.MessageID(msg.Event.Unsigned.TransactionID) } - return &bridgev2.MatrixMessageResponse{ + resp := &bridgev2.MatrixMessageResponse{ DB: &database.Message{ ID: messageID, SenderID: 
networkid.UserID(dc.UserLogin.ID), Timestamp: timestamp, }, StreamOrder: time.Now().UnixNano(), + } + + return resp, nil +} + +func (dc *DummyClient) PreHandleMatrixReaction(_ context.Context, msg *bridgev2.MatrixReaction) (bridgev2.MatrixReactionPreResponse, error) { + if msg == nil || msg.Content == nil { + return bridgev2.MatrixReactionPreResponse{}, nil + } + senderID := networkid.UserID("") + if dc != nil && dc.UserLogin != nil { + senderID = networkid.UserID(dc.UserLogin.ID) + } + key := aistream.NormalizeReaction(msg.Content.RelatesTo.Key) + return bridgev2.MatrixReactionPreResponse{ + SenderID: senderID, + EmojiID: networkid.EmojiID(key), + Emoji: key, + MaxReactions: 1, }, nil } +func (dc *DummyClient) HandleMatrixReaction(ctx context.Context, msg *bridgev2.MatrixReaction) (*database.Reaction, error) { + if dc == nil || dc.UserLogin == nil || msg == nil || msg.TargetMessage == nil || msg.Content == nil || msg.Portal == nil { + return &database.Reaction{}, nil + } + if isApprovalOptionReaction(msg) { + return &database.Reaction{}, nil + } + approvalID := string(msg.TargetMessage.ID) + if !strings.HasPrefix(approvalID, "approval-") { + return &database.Reaction{}, nil + } + reaction := aistream.NormalizeReaction(msg.Content.RelatesTo.Key) + selected, ok := aistream.ResolveApprovalChoice(aistream.DefaultApprovalChoices(), reaction) + if !ok { + return &database.Reaction{}, nil + } + response := aistream.ApprovalResponseForChoice(approvalID, selected) + + selectedKey, firstResolution := dc.resolveApprovalOnce(approvalID, reaction) + dc.cleanupApprovalReactions(ctx, msg.Portal, networkid.MessageID(approvalID), selectedKey, reaction, msg) + if !firstResolution { + log.Info(). + Str("approval_id", approvalID). + Str("reaction", reaction). + Str("selected_reaction", selectedKey). 
+ Msg("Ignoring duplicate dummy AI approval reaction") + return &database.Reaction{}, nil + } + portal := msg.Portal + target := msg.TargetMessage + dc.wg.Add(1) + go func() { + defer dc.wg.Done() + dc.queueAIApprovalResponse(dc.ctx, portal, target, response) + }() + + logger := log.Info(). + Str("approval_id", approvalID). + Str("reaction", reaction). + Str("choice", selected.Key). + Bool("approved", response.Approved) + if msg.Event != nil { + logger = logger.Stringer("sender", msg.Event.Sender) + } + logger.Msg("Resolved dummy AI approval from Matrix reaction") + + return &database.Reaction{}, nil +} + +func isApprovalOptionReaction(msg *bridgev2.MatrixReaction) bool { + if msg == nil || msg.Event == nil { + return false + } + _, ok := msg.Event.Content.Raw["com.beeper.ai.approval_option"] + return ok +} + +func (dc *DummyClient) resolveApprovalOnce(approvalID, selectedKey string) (string, bool) { + dc.approvalSelectionsOnce.Do(func() { + dc.approvalSelections = exsync.NewMap[string, string]() + }) + selected, alreadyResolved := dc.approvalSelections.GetOrSet(approvalID, selectedKey) + return selected, !alreadyResolved +} + +func (dc *DummyClient) cleanupApprovalReactions(ctx context.Context, portal *bridgev2.Portal, approvalMessageID networkid.MessageID, selectedKey, reactionKey string, msg *bridgev2.MatrixReaction) { + if dc == nil || dc.UserLogin == nil || dc.UserLogin.Bridge == nil || dc.UserLogin.Bridge.DB == nil || portal == nil { + return + } + reactions, err := dc.UserLogin.Bridge.DB.Reaction.GetAllToMessage(ctx, portal.Receiver, approvalMessageID) + if err != nil { + log.Warn().Err(err).Str("approval_id", string(approvalMessageID)).Msg("Failed to load approval reactions") + return + } + events := make([]aistream.ReactionEvent, 0, len(reactions)+1) + reactionByMXID := make(map[string]*database.Reaction, len(reactions)) + for _, reaction := range reactions { + if reaction == nil || reaction.MXID == "" { + continue + } + eventID := string(reaction.MXID) + 
reactionByMXID[eventID] = reaction + events = append(events, aistream.ReactionEvent{ + EventID: eventID, + Sender: string(reaction.SenderID), + Key: reaction.Emoji, + Bridge: reaction.SenderID == dummyAISenderForPortal(portal), + }) + } + if msg != nil && msg.Event != nil && msg.Event.ID != "" { + senderID := string(msg.Event.Sender) + if msg.PreHandleResp != nil && msg.PreHandleResp.SenderID != "" { + senderID = string(msg.PreHandleResp.SenderID) + } + events = append(events, aistream.ReactionEvent{ + EventID: string(msg.Event.ID), + Sender: senderID, + Key: reactionKey, + }) + } + sender := dummyAISenderForPortal(portal) + cleanup := aistream.CleanupApprovalReactions(aistream.DefaultApprovalChoices(), selectedKey, events, string(sender)) + intent, ok := portal.GetIntentFor(ctx, bridgev2.EventSender{Sender: sender}, dc.UserLogin, bridgev2.RemoteEventMessageRemove) + if !ok || intent == nil { + log.Warn().Str("approval_id", string(approvalMessageID)).Msg("Failed to resolve AI sender intent for approval reaction cleanup") + return + } + for _, reactionEventID := range cleanup.RedactReactionEvents { + reactionMXID := id.EventID(reactionEventID) + _, err := intent.SendMessage(ctx, portal.MXID, event.EventRedaction, &event.Content{ + Parsed: &event.RedactionEventContent{Redacts: reactionMXID}, + }, nil) + if err != nil { + log.Warn().Err(err).Stringer("reaction_mxid", reactionMXID).Msg("Failed to redact approval reaction") + continue + } + if reaction := reactionByMXID[reactionEventID]; reaction != nil { + if err := dc.UserLogin.Bridge.DB.Reaction.Delete(ctx, reaction); err != nil { + log.Warn().Err(err).Stringer("reaction_mxid", reaction.MXID).Msg("Failed to delete approval reaction") + } + } + } +} + +func (dc *DummyClient) HandleMatrixReactionRemove(ctx context.Context, msg *bridgev2.MatrixReactionRemove) error { + if dc == nil || dc.UserLogin == nil || dc.UserLogin.Bridge == nil || dc.UserLogin.Bridge.DB == nil || msg == nil || msg.TargetReaction == nil { + return 
nil + } + if err := dc.UserLogin.Bridge.DB.Reaction.Delete(ctx, msg.TargetReaction); err != nil { + return fmt.Errorf("failed to delete reaction on remove: %w", err) + } + return nil +} + func getTransactionID(msg *bridgev2.MatrixMessage) networkid.TransactionID { if msg.Event != nil && msg.Event.Unsigned.TransactionID != "" { return networkid.TransactionID(msg.Event.Unsigned.TransactionID) @@ -242,6 +415,48 @@ func getRemoteEchoBehavior(content *event.MessageEventContent) remoteEchoBehavio return remoteEchoBehavior{pending: true, delay: delay} } +// ensureAISenderInvited queues a ChatInfoChange that adds the AI sender ghost +// to the given portal. The bridge's default portal generator can create +// portals with members=0, in which case the per-portal AI sender chosen by +// dummyAISenderForPortal is not actually a room member — sending the anchor +// from a non-member ghost would fail. Re-asserting an existing membership is +// a no-op for bridgev2, so it is safe to call for every AI run. 
+func (dc *DummyClient) ensureAISenderInvited(portal *bridgev2.Portal, sender networkid.UserID) { + if dc == nil || dc.UserLogin == nil || portal == nil || sender == "" { + return + } + changes := &bridgev2.ChatMemberList{MemberMap: bridgev2.ChatMemberMap{}} + changes.MemberMap.Set(bridgev2.ChatMember{ + EventSender: bridgev2.EventSender{Sender: sender}, + Membership: event.MembershipJoin, + MemberEventExtra: map[string]any{ + "displayname": dummyAIAgentNameForPortal(portal), + }, + }) + now := time.Now() + dc.UserLogin.QueueRemoteEvent(&simplevent.ChatInfoChange{ + EventMeta: simplevent.EventMeta{ + Type: bridgev2.RemoteEventChatInfoChange, + PortalKey: portal.PortalKey, + Sender: bridgev2.EventSender{Sender: sender}, + Timestamp: now, + StreamOrder: now.UnixNano(), + }, + ChatInfoChange: &bridgev2.ChatInfoChange{MemberChanges: changes}, + }) +} + +func dummyAISenderForPortal(portal *bridgev2.Portal) networkid.UserID { + if portal == nil { + return networkid.UserID(dummyAIAgentName) + } + return stablePortalUserIDByIndex(portal.ID, 0) +} + +func dummyAIAgentNameForPortal(portal *bridgev2.Portal) string { + return dummyAIAgentName +} + func (dc *DummyClient) queueRemoteEcho(msg *bridgev2.MatrixMessage, transactionID networkid.TransactionID, timestamp time.Time, delay time.Duration) { if delay <= 0 || msg.Portal == nil { return @@ -290,6 +505,628 @@ func cloneMessageContent(content *event.MessageEventContent) *event.MessageEvent return &cloned } +func (dc *DummyClient) queueAIResponse(ctx context.Context, portal *bridgev2.Portal, inbound *event.MessageEventContent) { + if portal == nil { + return + } + + now := time.Now() + runID := "run-" + string(randomMessageID()) + sender := dummyAISenderForPortal(portal) + agentName := dummyAIAgentNameForPortal(portal) + var body string + if inbound != nil { + body = inbound.Body + } + plans, err := buildAIRunPlans(ctx, runID, string(portal.ID), body, now, string(sender), agentName) + if err != nil { + 
log.Warn().Err(err).Msg("Failed to build AI runs") + return + } + for _, plan := range plans { + if plan.Run == nil { + continue + } + placeholderID := networkid.MessageID(plan.Run.MessageID) + effectiveCommand := plan.EffectiveCommand + if effectiveCommand == "" { + effectiveCommand = body + } + + dc.wg.Add(1) + go func(portal *bridgev2.Portal, sender networkid.UserID, messageID networkid.MessageID, run aistream.Run, command string, delay time.Duration) { + defer dc.wg.Done() + if delay > 0 { + timer := time.NewTimer(delay) + select { + case <-timer.C: + case <-dc.ctx.Done(): + timer.Stop() + return + } + } + dc.ensureAISenderInvited(portal, sender) + anchorAt := time.Now() + dc.UserLogin.QueueRemoteEvent(aibridgev2.Anchor(portal.PortalKey, sender, initialAIAnchorRun(run), anchorAt)) + dc.queueAIRunStreamAndMetadata(portal, sender, messageID, run, command, anchorAt) + }(portal, sender, placeholderID, *plan.Run, effectiveCommand, plan.Delay) + } +} + +func initialAIAnchorRun(run aistream.Run) aistream.Run { + run.Status = aistream.Status{State: "streaming"} + run.Usage = agui.Usage{} + run.Preview = aistream.Preview{} + return run +} + +func (dc *DummyClient) queueAIRunStreamAndMetadata(portal *bridgev2.Portal, sender networkid.UserID, messageID networkid.MessageID, run aistream.Run, command string, anchorAt time.Time) { + targetEventID := dc.waitForMessageMXID(portal, messageID, 30*time.Second) + if targetEventID == "" { + log.Warn(). + Str("run_id", run.RunID). + Str("message_id", string(messageID)). + Msg("Timed out waiting for AI anchor Matrix event") + return + } + dc.emitAIRunStream(portal, sender, messageID, targetEventID, run, command, 1, anchorAt) +} + +// emitAIRunStream packs and emits one segment of an AI run — used both for +// the initial run and for any approval continuation. 
// emitAIRunStream packs and emits one segment of an AI run. It is used both
// for the initial run and for any approval continuation.
//
// Steps, in order:
//  1. Pack a copy of the run annotated with placeholder approval event IDs to
//     learn how many sequence numbers the segment occupies (nextSeq).
//  2. Queue the approval prompt (only the first one if the run produced
//     several) with SeqStart = nextSeq and wait up to 10 seconds for its
//     Matrix event ID.
//  3. If an approval event ID was obtained, annotate the run with it and
//     repack; abort if the repack changes the sequence count.
//  4. Emit the carriers, then queue the final metadata edit unless the run is
//     still in the "streaming" state (i.e. paused for approval).
//
// startSeq is the sequence number of the first carrier in this segment and
// anchorAt is the timestamp the carriers must be ordered after.
func (dc *DummyClient) emitAIRunStream(portal *bridgev2.Portal, sender networkid.UserID, messageID networkid.MessageID, targetEventID id.EventID, run aistream.Run, command string, startSeq int, anchorAt time.Time) {
	// Sizing pass with placeholder event IDs.
	// NOTE(review): sizingRun is a shallow struct copy, so it shares the
	// Events slice (and the per-event "value" maps) with run. If
	// SetApprovalRequestedEventID writes into those maps, the placeholder is
	// also visible through run and would be what gets sent when no real
	// approval event ID is obtained below — confirm whether that is intended.
	sizingRun := run
	annotateApprovalEventIDs(&sizingRun, approvalEventIDPlaceholders(sizingRun.Prompts))
	carriers, err := aistream.PackRunFromSeq(sizingRun, string(targetEventID), aistream.CarrierBudgetBytes, startSeq)
	if err != nil {
		log.Warn().Err(err).Str("run_id", run.RunID).Msg("Failed to pack AI stream")
		return
	}
	carriers = splitCarriersForTimedEmission(carriers)
	// nextSeq is where a continuation after this segment would start; it is
	// stored on the approval prompt so the continuation can resume there.
	nextSeq := aistream.NextSeq(carriers)
	queuedPrompts := run.Prompts
	if len(queuedPrompts) > 1 {
		log.Warn().
			Str("run_id", run.RunID).
			Int("approval_prompts", len(queuedPrompts)).
			Msg("AI run produced multiple simultaneous approval prompts; queueing the first prompt only")
		queuedPrompts = queuedPrompts[:1]
	}
	approvalEventIDs := make(map[string]id.EventID, len(queuedPrompts))
	for _, prompt := range queuedPrompts {
		// prompt is the loop's copy; run.Prompts itself is not modified.
		prompt.SeqStart = nextSeq
		ctx := dc.queueAIApprovalPrompt(portal, sender, run, prompt, targetEventID, command, time.Now())
		if approvalEventID := dc.waitForMessageMXID(portal, networkid.MessageID(ctx.ID), 10*time.Second); approvalEventID != "" {
			approvalEventIDs[ctx.ID] = approvalEventID
			log.Info().
				Str("run_id", run.RunID).
				Str("approval_id", ctx.ID).
				Stringer("approval_event_id", approvalEventID).
				Int("approval_seq_start", ctx.SeqStart).
				Msg("AI approval notice ready for reaction")
		} else {
			log.Warn().
				Str("run_id", run.RunID).
				Str("approval_id", ctx.ID).
				Int("approval_seq_start", ctx.SeqStart).
				Msg("Timed out waiting for AI approval notice Matrix event")
		}
	}
	if len(approvalEventIDs) > 0 {
		// Repack with the real event IDs. The approval prompt already
		// carries nextSeq, so the repack must occupy exactly the same number
		// of sequence slots as the sizing pass.
		annotateApprovalEventIDs(&run, approvalEventIDs)
		carriers, err = aistream.PackRunFromSeq(run, string(targetEventID), aistream.CarrierBudgetBytes, startSeq)
		if err != nil {
			log.Warn().Err(err).Str("run_id", run.RunID).Msg("Failed to repack AI stream with approval event IDs")
			return
		}
		carriers = splitCarriersForTimedEmission(carriers)
		if actualNextSeq := aistream.NextSeq(carriers); actualNextSeq != nextSeq {
			log.Warn().
				Str("run_id", run.RunID).
				Int("expected_next_seq", nextSeq).
				Int("actual_next_seq", actualNextSeq).
				Msg("AI approval event ID repack changed stream sequence count")
			return
		}
	} else if len(queuedPrompts) > 0 {
		// The prompt was queued but its Matrix event never appeared; fall
		// back to the carriers produced by the sizing pass.
		log.Info().
			Str("run_id", run.RunID).
			Int("approval_prompts", len(queuedPrompts)).
			Msg("Sending approval stream without approval event IDs")
	}
	dc.queuePackedAICarriers(portal, sender, targetEventID, run, carriers, startSeq, anchorAt)
	if len(run.Prompts) > 0 && run.Status.State == "streaming" {
		log.Info().
			Str("run_id", run.RunID).
			Str("message_id", string(messageID)).
			Int("approval_prompts", len(run.Prompts)).
			Msg("AI run paused for approval")
	}
	// A run still "streaming" is waiting on an approval; the final metadata
	// edit is only queued once the run has reached another state.
	if run.Status.State != "streaming" {
		dc.queueAIRunFinalMetadata(portal, sender, messageID, run)
	}
}
+ minCarrierTimestamp := anchorAt.Add(time.Millisecond) + if streamStart.Before(minCarrierTimestamp) { + streamStart = minCarrierTimestamp + } + for i, carrier := range carriers { + dc.sleepUntilCarrierTime(run, carrier, streamStart) + now := time.Now() + if now.Before(minCarrierTimestamp) { + now = minCarrierTimestamp + } + minCarrierTimestamp = now.Add(time.Nanosecond) + dc.UserLogin.QueueRemoteEvent(aibridgev2.Carrier(portal.PortalKey, sender, run, carrier, targetEventID, startSeq+i, now)) + } +} + +func splitCarriersForTimedEmission(carriers []aistream.Carrier) []aistream.Carrier { + out := make([]aistream.Carrier, 0, len(carriers)) + for _, carrier := range carriers { + if len(carrier.Envelopes) <= 1 { + out = append(out, carrier) + continue + } + for _, env := range carrier.Envelopes { + out = append(out, aistream.Carrier{Envelopes: []aistream.Envelope{env}}) + } + } + return out +} + +func (dc *DummyClient) sleepUntilCarrierTime(run aistream.Run, carrier aistream.Carrier, streamStart time.Time) { + target := carrierTimestamp(run, carrier, streamStart) + if target.IsZero() { + return + } + delay := time.Until(target) + if delay <= 0 { + return + } + timer := time.NewTimer(delay) + select { + case <-timer.C: + case <-dc.ctx.Done(): + timer.Stop() + } +} + +func carrierTimestamp(run aistream.Run, carrier aistream.Carrier, streamStart time.Time) time.Time { + base := runStartTimestamp(run) + if base.IsZero() { + return time.Time{} + } + var latest time.Time + for _, env := range carrier.Envelopes { + eventTime := eventTimestamp(env.Part) + if eventTime.IsZero() { + continue + } + if latest.IsZero() || eventTime.After(latest) { + latest = eventTime + } + } + if latest.IsZero() { + return time.Time{} + } + return streamStart.Add(latest.Sub(base)) +} + +func runStartTimestamp(run aistream.Run) time.Time { + for _, evt := range run.Events { + if ts := eventTimestamp(evt); !ts.IsZero() { + return ts + } + } + return time.Time{} +} + +func eventTimestamp(evt 
agui.Event) time.Time { + raw, ok := evt["timestamp"] + if !ok { + return time.Time{} + } + var millis int64 + switch value := raw.(type) { + case int64: + millis = value + case int: + millis = int64(value) + case int32: + millis = int64(value) + case float64: + millis = int64(value) + case json.Number: + parsed, err := value.Int64() + if err != nil { + return time.Time{} + } + millis = parsed + default: + return time.Time{} + } + if millis <= 0 { + return time.Time{} + } + return time.UnixMilli(millis) +} + +func (dc *DummyClient) waitForMessageMXID( + portal *bridgev2.Portal, + messageID networkid.MessageID, + timeout time.Duration, +) id.EventID { + if dc == nil || dc.UserLogin == nil || dc.UserLogin.Bridge == nil || dc.UserLogin.Bridge.DB == nil || portal == nil { + return "" + } + parent := dc.ctx + if parent == nil { + parent = context.Background() + } + ctx, cancel := context.WithTimeout(parent, timeout) + defer cancel() + + receivers := []networkid.UserLoginID{portal.Receiver} + if dc.UserLogin.ID != "" && dc.UserLogin.ID != portal.Receiver { + receivers = append(receivers, dc.UserLogin.ID) + } + ticker := time.NewTicker(100 * time.Millisecond) + defer ticker.Stop() + + for ctx.Err() == nil { + select { + case <-ctx.Done(): + return "" + case <-ticker.C: + } + for _, receiver := range receivers { + mxid := dc.lookupMessageMXID(ctx, receiver, messageID) + if mxid != "" { + return mxid + } + } + } + return "" +} + +func (dc *DummyClient) lookupMessageMXID(ctx context.Context, receiver networkid.UserLoginID, messageID networkid.MessageID) id.EventID { + message, err := dc.UserLogin.Bridge.DB.Message.GetFirstPartByID(ctx, receiver, messageID) + if err != nil || message == nil { + return "" + } + return message.MXID +} + +func (dc *DummyClient) queueAIApprovalPrompt(portal *bridgev2.Portal, sender networkid.UserID, run aistream.Run, prompt aistream.ApprovalPrompt, targetEventID id.EventID, command string, timestamp time.Time) aistream.ApprovalContext { + choices 
// queueAIApprovalResponse resumes an AI run after the user answered an
// approval prompt.
//
// It recovers the approval context from approvalMessage, records the decision
// for the run, rebuilds the continuation run from the stored command with all
// decisions recorded so far, and emits that continuation as a new stream
// segment starting at approvalCtx.SeqStart. Every failure is logged and the
// function returns without emitting anything.
func (dc *DummyClient) queueAIApprovalResponse(ctx context.Context, portal *bridgev2.Portal, approvalMessage *database.Message, response agui.ToolApprovalResponse) {
	approvalCtx, ok := dc.approvalContextForMessage(ctx, portal, approvalMessage)
	if !ok {
		log.Warn().Str("approval_id", messageIDString(approvalMessage)).Msg("Missing AI approval metadata")
		return
	}
	// A response without an ID is taken to answer this message's approval.
	if response.ID == "" {
		response.ID = approvalCtx.ID
	}
	// now is captured once and used both as the continuation's build time
	// and as the anchor time for carrier ordering.
	now := time.Now()
	// The decision is stored before the continuation is built so that the
	// rebuilt run sees this answer together with any earlier ones.
	approvals := dc.recordAIApprovalDecision(approvalCtx.RunID, response)
	run, err := buildAIApprovalContinuationRunWithApprovals(ctx, approvalCtx, approvals, now)
	if err != nil {
		log.Warn().Err(err).Str("approval_id", approvalCtx.ID).Msg("Failed to build AI approval continuation")
		return
	}
	targetEventID := id.EventID(approvalCtx.TargetEvent)
	// NOTE(review): validApprovalContext already rejects an empty
	// TargetEvent, so this branch looks unreachable for contexts obtained
	// through approvalContextForMessage — kept as a defensive check.
	if targetEventID == "" {
		log.Warn().Str("approval_id", approvalCtx.ID).Msg("Missing AI approval target event")
		return
	}
	// Prefer the agent recorded with the approval; fall back to the portal's
	// default AI sender when none was stored.
	sender := networkid.UserID(approvalCtx.AgentID)
	if sender == "" {
		sender = dummyAISenderForPortal(portal)
	}
	dc.ensureAISenderInvited(portal, sender)
	dc.emitAIRunStream(portal, sender, networkid.MessageID(approvalCtx.MessageID), targetEventID, run, approvalCtx.Command, approvalCtx.SeqStart, now)
	log.Info().
		Str("run_id", approvalCtx.RunID).
		Str("approval_id", approvalCtx.ID).
		Str("tool_call_id", approvalCtx.ToolCallID).
		Bool("approved", response.Approved).
		Bool("always", response.Always).
		Int("seq_start", approvalCtx.SeqStart).
		Str("state", run.Status.State).
		Int("pending_prompts", len(run.Prompts)).
		Msg("Queued AI approval continuation")
}
+ run.Prompts = filterPendingPrompts(run.Prompts, approvalCtx.ID, run.Events) + return *run, nil +} + +func filterPendingPrompts(prompts []aistream.ApprovalPrompt, resolvedID string, events []agui.Event) []aistream.ApprovalPrompt { + if len(prompts) == 0 { + return nil + } + requested := make(map[string]bool, len(events)) + for _, evt := range events { + if evt["type"] != agui.EventCustom || evt["name"] != agui.ApprovalCustomRequested { + continue + } + value, _ := evt["value"].(map[string]any) + if id := aistream.ApprovalIDFromRequestedValue(value); id != "" { + requested[id] = true + } + } + out := prompts[:0] + for _, prompt := range prompts { + if prompt.ID == resolvedID { + continue + } + if !requested[prompt.ID] { + continue + } + out = append(out, prompt) + } + return out +} + +func approvalContinuationStart(events []agui.Event, approvalID string) int { + for i, evt := range events { + if evt["type"] != agui.EventCustom || evt["name"] != agui.ApprovalCustomResponded { + continue + } + value, _ := evt["value"].(map[string]any) + approval, _ := value["approval"].(agui.ToolApprovalResponse) + if approval.ID == approvalID { + return i + } + if raw, ok := value["approval"].(map[string]any); ok { + if idValue, _ := raw["id"].(string); idValue == approvalID { + return i + } + } + } + return -1 +} + +func (dc *DummyClient) approvalContextForMessage(ctx context.Context, portal *bridgev2.Portal, message *database.Message) (aistream.ApprovalContext, bool) { + var fetch func(context.Context, networkid.MessageID) (*database.Message, error) + if dc != nil && dc.UserLogin != nil && dc.UserLogin.Bridge != nil && dc.UserLogin.Bridge.DB != nil && portal != nil { + fetch = func(ctx context.Context, messageID networkid.MessageID) (*database.Message, error) { + return dc.UserLogin.Bridge.DB.Message.GetFirstPartByID(ctx, portal.Receiver, messageID) + } + } + return approvalContextForMessage(ctx, message, fetch) +} + +func approvalContextForMessage(ctx context.Context, message 
*database.Message, fetch func(context.Context, networkid.MessageID) (*database.Message, error)) (aistream.ApprovalContext, bool) { + if approvalCtx, ok := approvalContextFromMetadata(message); ok { + return approvalCtx, true + } + if message == nil || message.ID == "" || fetch == nil { + return aistream.ApprovalContext{}, false + } + fetched, err := fetch(ctx, message.ID) + if err != nil { + log.Warn().Err(err).Str("approval_id", string(message.ID)).Msg("Failed to reload AI approval message") + return aistream.ApprovalContext{}, false + } + return approvalContextFromMetadata(fetched) +} + +func approvalContextFromMetadata(message *database.Message) (aistream.ApprovalContext, bool) { + if message == nil { + return aistream.ApprovalContext{}, false + } + return approvalContextFromAny(message.Metadata) +} + +func approvalContextFromAny(value any) (aistream.ApprovalContext, bool) { + switch typed := value.(type) { + case aistream.ApprovalContext: + return validApprovalContext(typed) + case *aistream.ApprovalContext: + if typed == nil { + return aistream.ApprovalContext{}, false + } + return validApprovalContext(*typed) + case map[string]any: + if nested, ok := typed["com.beeper.ai.approval"]; ok { + return approvalContextFromAny(nested) + } + case *map[string]any: + if typed == nil { + return aistream.ApprovalContext{}, false + } + return approvalContextFromAny(*typed) + case json.RawMessage: + return approvalContextFromJSON(typed) + case []byte: + return approvalContextFromJSON(typed) + case string: + return approvalContextFromJSON([]byte(typed)) + } + var ctx aistream.ApprovalContext + raw, err := json.Marshal(value) + if err != nil { + return aistream.ApprovalContext{}, false + } + if err = json.Unmarshal(raw, &ctx); err != nil { + return aistream.ApprovalContext{}, false + } + return validApprovalContext(ctx) +} + +func approvalContextFromJSON(raw []byte) (aistream.ApprovalContext, bool) { + var decoded any + if err := json.Unmarshal(raw, &decoded); err == nil { + 
if approvalCtx, ok := approvalContextFromAny(decoded); ok { + return approvalCtx, true + } + } + var ctx aistream.ApprovalContext + if err := json.Unmarshal(raw, &ctx); err != nil { + return aistream.ApprovalContext{}, false + } + return validApprovalContext(ctx) +} + +func messageIDString(message *database.Message) string { + if message == nil { + return "" + } + return string(message.ID) +} + +func validApprovalContext(ctx aistream.ApprovalContext) (aistream.ApprovalContext, bool) { + if ctx.ID == "" || ctx.ThreadID == "" || ctx.RunID == "" || ctx.MessageID == "" || ctx.Command == "" || ctx.ToolCallID == "" || ctx.TargetEvent == "" { + return aistream.ApprovalContext{}, false + } + if ctx.SeqStart <= 0 { + ctx.SeqStart = 1 + } + return ctx, true +} + +func (dc *DummyClient) queueAIRunFinalMetadata(portal *bridgev2.Portal, sender networkid.UserID, messageID networkid.MessageID, run aistream.Run) { + dc.UserLogin.QueueRemoteEvent(aibridgev2.FinalMetadataEdit(portal.PortalKey, sender, messageID, run, time.Now())) +} + +func (dc *DummyClient) ensureAIRunSession(runID string) { + if dc == nil || runID == "" { + return + } + dc.aiRunSessionsMu.Lock() + defer dc.aiRunSessionsMu.Unlock() + if dc.aiRunSessions == nil { + dc.aiRunSessions = make(map[string]*aiRunSession) + } + if dc.aiRunSessions[runID] == nil { + dc.aiRunSessions[runID] = &aiRunSession{Decisions: make(map[string]agui.ToolApprovalResponse)} + } +} + +func (dc *DummyClient) recordAIApprovalDecision(runID string, response agui.ToolApprovalResponse) map[string]agui.ToolApprovalResponse { + decisions := make(map[string]agui.ToolApprovalResponse) + if response.ID == "" { + return decisions + } + if dc == nil || runID == "" { + decisions[response.ID] = response + return decisions + } + dc.aiRunSessionsMu.Lock() + defer dc.aiRunSessionsMu.Unlock() + if dc.aiRunSessions == nil { + dc.aiRunSessions = make(map[string]*aiRunSession) + } + session := dc.aiRunSessions[runID] + if session == nil { + session = 
&aiRunSession{Decisions: make(map[string]agui.ToolApprovalResponse)} + dc.aiRunSessions[runID] = session + } + session.Decisions[response.ID] = response + for id, decision := range session.Decisions { + decisions[id] = decision + } + return decisions +} + func (dc *DummyClient) HandleMatrixDeleteChat(ctx context.Context, msg *bridgev2.MatrixDeleteChat) error { // bridgev2 will delete the portal + Matrix room after this returns nil. // For dummybridge, there's no separate remote-side deletion to do. @@ -327,8 +1164,8 @@ func (dc *DummyClient) ResolveIdentifier(ctx context.Context, identifier string, ghostInfo, _ := dc.GetUserInfo(ctx, ghost) portalInfo, _ := dc.GetChatInfo(ctx, portal) portalInfo.Members = &bridgev2.ChatMemberList{ - Members: []bridgev2.ChatMember{ - { + MemberMap: bridgev2.ChatMemberMap{ + networkid.UserID(dc.UserLogin.ID): { EventSender: bridgev2.EventSender{ IsFromMe: true, Sender: networkid.UserID(dc.UserLogin.ID), @@ -336,7 +1173,7 @@ func (dc *DummyClient) ResolveIdentifier(ctx context.Context, identifier string, Membership: event.MembershipJoin, PowerLevel: ptr.Ptr(50), }, - { + userID: { EventSender: bridgev2.EventSender{ Sender: userID, }, diff --git a/pkg/connector/commands.go b/pkg/connector/commands.go index 9722482..e3934e1 100644 --- a/pkg/connector/commands.go +++ b/pkg/connector/commands.go @@ -29,6 +29,9 @@ var AllCommands = []commands.CommandHandler{ MessagesCommand, KickMeCommand, FileCommand, + StreamCommand, + StreamToolsCommand, + StreamHelpCommand, CatCommand, CatAvatarCommand, } @@ -226,8 +229,7 @@ var FileCommand = &commands.FullHandler{ Func: func(e *commands.Event) { e.Reply("Generating file event in this room") - var mediaData []byte - mediaData = []byte("Test text file") + mediaData := []byte("Test text file") mediaName := "test.txt" mediaMime := "text/plain" @@ -265,6 +267,67 @@ var FileCommand = &commands.FullHandler{ }, } +func runStreamCommand(e *commands.Event, name string) { + if e.Portal == nil { + e.Reply("Can 
only stream within a portal") + return + } + login := e.User.GetDefaultLogin() + if login == nil { + e.Reply("No login") + return + } + client, ok := login.Client.(*DummyClient) + if !ok || client == nil { + e.Reply("Default login is not a dummybridge login") + return + } + body := strings.TrimSpace(name + " " + e.RawArgs) + if _, err := parseCommand(body); err != nil { + e.Reply(err.Error()) + return + } + client.queueAIResponse(e.Ctx, e.Portal, &event.MessageEventContent{Body: body}) + e.Reply("Started %s", name) +} + +var StreamCommand = &commands.FullHandler{ + Func: func(e *commands.Event) { + runStreamCommand(e, "stream") + }, + Name: "stream", + Help: commands.HelpMeta{ + Description: "Generate a random streamed AI event sequence", + Args: "[seconds] [--runs=N] [--profile=balanced|tools|errors|artifacts] [--seed=N] [--chars=N] [--terminal=stop|length|abort|error] [--delay-ms=min:max] [--stagger-ms=min:max] [--actions=N] [--no-approval] [--allow-abort] [--allow-error]", + Section: DummyHelpsection, + }, + RequiresLogin: true, +} + +var StreamToolsCommand = &commands.FullHandler{ + Func: func(e *commands.Event) { + runStreamCommand(e, "stream-tools") + }, + Name: "stream-tools", + Help: commands.HelpMeta{ + Description: "Generate a streamed AI event sequence with explicit tool calls", + Args: " ... 
[common options]", + Section: DummyHelpsection, + }, + RequiresLogin: true, +} + +var StreamHelpCommand = &commands.FullHandler{ + Func: func(e *commands.Event) { + e.Reply(helpText()) + }, + Name: "stream-help", + Help: commands.HelpMeta{ + Description: "Show stream command examples", + Section: DummyHelpsection, + }, +} + var catpions []string = []string{ "You’ve cat to be kitten me!", "I’m feline fine!", diff --git a/pkg/connector/connector.go b/pkg/connector/connector.go index 91ea6b1..b65c6ec 100644 --- a/pkg/connector/connector.go +++ b/pkg/connector/connector.go @@ -44,11 +44,18 @@ func (dc *DummyConnector) Start(ctx context.Context) error { } func (dc *DummyConnector) GetCapabilities() *bridgev2.NetworkGeneralCapabilities { - return &bridgev2.NetworkGeneralCapabilities{} + return &bridgev2.NetworkGeneralCapabilities{ + Provisioning: bridgev2.ProvisioningCapabilities{ + ResolveIdentifier: bridgev2.ResolveIdentifierCapabilities{ + CreateDM: true, + LookupUsername: true, + }, + }, + } } func (dc *DummyConnector) GetBridgeInfoVersion() (info, caps int) { - return 0, 0 + return 0, 1 } func (dc *DummyConnector) GetName() bridgev2.BridgeName { @@ -62,7 +69,10 @@ func (dc *DummyConnector) GetName() bridgev2.BridgeName { } func (dc *DummyConnector) GetDBMetaTypes() database.MetaTypes { - return database.MetaTypes{} + return database.MetaTypes{ + Message: func() any { return &map[string]any{} }, + Reaction: func() any { return &map[string]any{} }, + } } //go:embed example-config.yaml diff --git a/pkg/connector/generators.go b/pkg/connector/generators.go index 5c812c2..a3ebf56 100644 --- a/pkg/connector/generators.go +++ b/pkg/connector/generators.go @@ -58,8 +58,8 @@ func generatePortal(ctx context.Context, br *bridgev2.Bridge, login *bridgev2.Us Type: ptr.Ptr(roomType), CanBackfill: true, Members: &bridgev2.ChatMemberList{ - Members: []bridgev2.ChatMember{ - { + MemberMap: bridgev2.ChatMemberMap{ + networkid.UserID(login.ID): { EventSender: bridgev2.EventSender{ 
IsFromMe: true, Sender: networkid.UserID(login.ID), @@ -78,7 +78,7 @@ func generatePortal(ctx context.Context, br *bridgev2.Bridge, login *bridgev2.Us return nil, fmt.Errorf("failed to get ghost by id: %w", err) } - chatInfo.Members.Members = append(chatInfo.Members.Members, bridgev2.ChatMember{ + chatInfo.Members.MemberMap.Set(bridgev2.ChatMember{ EventSender: bridgev2.EventSender{ Sender: userID, }, diff --git a/pkg/connector/message_requests.go b/pkg/connector/message_requests.go index a1b3b21..512964c 100644 --- a/pkg/connector/message_requests.go +++ b/pkg/connector/message_requests.go @@ -113,11 +113,13 @@ func createMessageRequestPortal( Type: ptr.Ptr(roomType), MessageRequest: &isMessageRequest, CanBackfill: true, - Members: &bridgev2.ChatMemberList{Members: []bridgev2.ChatMember{{ - EventSender: bridgev2.EventSender{IsFromMe: true, Sender: networkid.UserID(login.ID)}, - Membership: event.MembershipJoin, - PowerLevel: ptr.Ptr(100), - }}}, + Members: &bridgev2.ChatMemberList{MemberMap: bridgev2.ChatMemberMap{ + networkid.UserID(login.ID): { + EventSender: bridgev2.EventSender{IsFromMe: true, Sender: networkid.UserID(login.ID)}, + Membership: event.MembershipJoin, + PowerLevel: ptr.Ptr(100), + }, + }}, } firstGhost := stablePortalUserIDByIndex(portalID, 0) @@ -128,7 +130,7 @@ func createMessageRequestPortal( return nil, fmt.Errorf("failed to get ghost by id: %w", err) } ghost.UpdateName(ctx, fmt.Sprintf("Dummy User %d", i+1)) - chatInfo.Members.Members = append(chatInfo.Members.Members, bridgev2.ChatMember{ + chatInfo.Members.MemberMap.Set(bridgev2.ChatMember{ EventSender: bridgev2.EventSender{Sender: userID}, Membership: event.MembershipJoin, })