diff --git a/go.mod b/go.mod index e53016d0..0e916bca 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,9 @@ module github.com/1024XEngineer/anyclaw -go 1.25.0 +go 1.25.1 require ( + cloud.google.com/go/speech v1.33.0 github.com/charmbracelet/bubbles v1.0.0 github.com/charmbracelet/bubbletea v1.3.10 github.com/charmbracelet/lipgloss v1.1.0 @@ -10,15 +11,25 @@ require ( github.com/chromedp/cdproto v0.0.0-20240801214329-3f85d328b335 github.com/chromedp/chromedp v0.10.0 github.com/clipperhouse/uax29/v2 v2.5.0 + github.com/godeps/webrtcvad-go v0.1.0 github.com/gorilla/websocket v1.5.3 github.com/philippgille/chromem-go v0.7.0 golang.org/x/sys v0.42.0 - golang.org/x/text v0.22.0 + golang.org/x/text v0.35.0 + google.golang.org/api v0.275.0 + google.golang.org/grpc v1.80.0 + google.golang.org/protobuf v1.36.11 modernc.org/sqlite v1.48.1 ) require ( + cloud.google.com/go v0.123.0 // indirect + cloud.google.com/go/auth v0.20.0 // indirect + cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect + cloud.google.com/go/compute/metadata v0.9.0 // indirect + cloud.google.com/go/longrunning v0.9.0 // indirect github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/charmbracelet/colorprofile v0.4.1 // indirect github.com/charmbracelet/x/ansi v0.11.6 // indirect github.com/charmbracelet/x/cellbuf v0.0.15 // indirect @@ -27,10 +38,16 @@ require ( github.com/clipperhouse/stringish v0.1.1 // indirect github.com/dustin/go-humanize v1.0.1 // indirect github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect + github.com/felixge/httpsnoop v1.0.4 // indirect + github.com/go-logr/logr v1.4.3 // indirect + github.com/go-logr/stdr v1.2.2 // indirect github.com/gobwas/httphead v0.1.0 // indirect github.com/gobwas/pool v0.2.1 // indirect github.com/gobwas/ws v1.4.0 // indirect + github.com/google/s2a-go v0.1.9 // indirect github.com/google/uuid v1.6.0 // indirect + github.com/googleapis/enterprise-certificate-proxy v0.3.14 // indirect + github.com/googleapis/gax-go/v2 v2.21.0 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/lucasb-eyer/go-colorful v1.3.0 // indirect github.com/mailru/easyjson v0.7.7 // indirect @@ -44,6 +61,19 @@ require ( github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect + go.opentelemetry.io/auto/sdk v1.2.1 // indirect + go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0 // indirect + go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0 // indirect + go.opentelemetry.io/otel v1.43.0 // indirect + go.opentelemetry.io/otel/metric v1.43.0 // indirect + go.opentelemetry.io/otel/trace v1.43.0 // indirect + golang.org/x/crypto v0.49.0 // indirect + golang.org/x/net v0.52.0 // indirect + golang.org/x/oauth2 v0.36.0 // indirect + golang.org/x/sync v0.20.0 // indirect + golang.org/x/time v0.15.0 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 // indirect modernc.org/libc v1.70.0 // indirect modernc.org/mathutil v1.7.1 // indirect modernc.org/memory v1.11.0 // indirect diff --git a/go.sum b/go.sum index 336affad..bd4ff23b 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,19 @@ +cloud.google.com/go v0.123.0 h1:2NAUJwPR47q+E35uaJeYoNhuNEM9kM8SjgRgdeOJUSE= +cloud.google.com/go v0.123.0/go.mod h1:xBoMV08QcqUGuPW65Qfm1o9Y4zKZBpGS+7bImXLTAZU= +cloud.google.com/go/auth v0.20.0 h1:kXTssoVb4azsVDoUiF8KvxAqrsQcQtB53DcSgta74CA= +cloud.google.com/go/auth v0.20.0/go.mod h1:942/yi/itH1SsmpyrbnTMDgGfdy2BUqIKyd0cyYLc5Q= +cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIiLpZnkHRbnc= +cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c= +cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs= +cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10= +cloud.google.com/go/longrunning v0.9.0 h1:0EzbDEGsAvOZNbqXopgniY0w0a1phvu5IdUFq8grmqY= +cloud.google.com/go/longrunning v0.9.0/go.mod h1:pkTz846W7bF4o2SzdWJ40Hu0Re+UoNT6Q5t+igIcb8E= +cloud.google.com/go/speech v1.33.0 h1:555yroj4HCS7SPgfHuDU8zX+E5KrhccVWG96HNyBUAk= +cloud.google.com/go/speech v1.33.0/go.mod h1:shnf33sZbGnQQZyek1fdLOR5rRKV6D3jsNqpqyijvj8= github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k= github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/charmbracelet/bubbles v1.0.0 h1:12J8/ak/uCZEMQ6KU7pcfwceyjLlWsDLAxB5fXonfvc= github.com/charmbracelet/bubbles v1.0.0/go.mod h1:9d/Zd5GdnauMI5ivUIVisuEm3ave1XwXtD1ckyV6r3E= github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw= @@ -26,20 +40,48 @@ github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfa github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA= github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U= github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g= +github.com/cncf/xds/go v0.0.0-20251210132809-ee656c7534f5 h1:6xNmx7iTtyBRev0+D/Tv1FZd4SCg8axKApyNyRsAt/w= +github.com/cncf/xds/go v0.0.0-20251210132809-ee656c7534f5/go.mod h1:KdCmV+x/BuvyMxRnYBlmVaq4OLiKW6iRQfvC62cvdkI= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/envoyproxy/go-control-plane v0.14.0 h1:hbG2kr4RuFj222B6+7T83thSPqLjwBIfQawTkC++2HA= +github.com/envoyproxy/go-control-plane/envoy v1.36.0 h1:yg/JjO5E7ubRyKX3m07GF3reDNEnfOboJ0QySbH736g= +github.com/envoyproxy/go-control-plane/envoy v1.36.0/go.mod h1:ty89S1YCCVruQAm9OtKeEkQLTb+Lkz0k8v9W0Oxsv98= +github.com/envoyproxy/protoc-gen-validate v1.3.0 h1:TvGH1wof4H33rezVKWSpqKz5NXWg5VPuZ0uONDT6eb4= +github.com/envoyproxy/protoc-gen-validate v1.3.0/go.mod h1:HvYl7zwPa5mffgyeTUHA9zHIH36nmrm7oCbo4YKoSWA= github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4= github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM= +github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= +github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= +github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= +github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/gobwas/httphead v0.1.0 h1:exrUm0f4YX0L7EBwZHuCF4GDp8aJfVeBrlLQrs6NqWU= github.com/gobwas/httphead v0.1.0/go.mod h1:O/RXo79gxV8G+RqlR/otEwx4Q36zl9rqC5u12GKvMCM= github.com/gobwas/pool v0.2.1 h1:xfeeEhW7pwmX8nuLVlqbzVc7udMDrwetjEv+TZIz1og= github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= github.com/gobwas/ws v1.4.0 h1:CTaoG1tojrh4ucGPcoJFiAQUAsEWekEWvLy7GsVNqGs= github.com/gobwas/ws v1.4.0/go.mod h1:G3gNqMNtPppf5XUz7O4shetPpcZ1VJ7zt18dlUeakrc= +github.com/godeps/webrtcvad-go v0.1.0 h1:JpVfJHSzND9p/iuO7xqko1UlB/UJjKxskEWEbzKKjrQ= +github.com/godeps/webrtcvad-go v0.1.0/go.mod h1:487THSHEZrYU29LRm4AKYCm/Y8PPq3pIJSuz1KX3MwU= +github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= +github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= +github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0= +github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/googleapis/enterprise-certificate-proxy v0.3.14 h1:yh8ncqsbUY4shRD5dA6RlzjJaT4hi3kII+zYw8wmLb8= +github.com/googleapis/enterprise-certificate-proxy v0.3.14/go.mod h1:vqVt9yG9480NtzREnTlmGSBmFrA+bzb0yl0TxoBQXOg= +github.com/googleapis/gax-go/v2 v2.21.0 h1:h45NjjzEO3faG9Lg/cFrBh2PgegVVgzqKzuZl/wMbiI= +github.com/googleapis/gax-go/v2 v2.21.0/go.mod h1:But/NJU6TnZsrLai/xBAQLLz+Hc7fHZJt/hsCz3Fih4= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= @@ -70,27 +112,73 @@ github.com/orisano/pixelmatch v0.0.0-20220722002657-fb0b55479cde h1:x0TT0RDC7UhA github.com/orisano/pixelmatch v0.0.0-20220722002657-fb0b55479cde/go.mod h1:nZgzbfBr3hhjoZnS66nKrHmduYNpc34ny7RK4z5/HM0= github.com/philippgille/chromem-go v0.7.0 h1:4jfvfyKymjKNfGxBUhHUcj1kp7B17NL/I1P+vGh1RvY= github.com/philippgille/chromem-go v0.7.0/go.mod h1:hTd+wGEm/fFPQl7ilfCwQXkgEUxceYh86iIdoKMolPo= +github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 h1:GFCKgmp0tecUJ0sJuv4pzYCqS9+RGSn52M3FUwPs+uo= +github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10/go.mod h1:t/avpk3KcrXxUnYOhZhMXJlSEyie6gQbtLq5NM3loB8= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= +go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= +go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0 h1:yI1/OhfEPy7J9eoa6Sj051C7n5dvpj0QX8g4sRchg04= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0/go.mod h1:NoUCKYWK+3ecatC4HjkRktREheMeEtrXoQxrqYFeHSc= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0 h1:OyrsyzuttWTSur2qN/Lm0m2a8yqyIjUVBZcxFPuXq2o= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0/go.mod h1:C2NGBr+kAB4bk3xtMXfZ94gqFDtg/GkI7e9zqGh5Beg= +go.opentelemetry.io/otel v1.43.0 h1:mYIM03dnh5zfN7HautFE4ieIig9amkNANT+xcVxAj9I= +go.opentelemetry.io/otel v1.43.0/go.mod h1:JuG+u74mvjvcm8vj8pI5XiHy1zDeoCS2LB1spIq7Ay0= +go.opentelemetry.io/otel/metric v1.43.0 h1:d7638QeInOnuwOONPp4JAOGfbCEpYb+K6DVWvdxGzgM= +go.opentelemetry.io/otel/metric v1.43.0/go.mod h1:RDnPtIxvqlgO8GRW18W6Z/4P462ldprJtfxHxyKd2PY= +go.opentelemetry.io/otel/sdk v1.43.0 h1:pi5mE86i5rTeLXqoF/hhiBtUNcrAGHLKQdhg4h4V9Dg= +go.opentelemetry.io/otel/sdk v1.43.0/go.mod h1:P+IkVU3iWukmiit/Yf9AWvpyRDlUeBaRg6Y+C58QHzg= +go.opentelemetry.io/otel/sdk/metric v1.42.0 h1:D/1QR46Clz6ajyZ3G8SgNlTJKBdGp84q9RKCAZ3YGuA= +go.opentelemetry.io/otel/sdk/metric v1.42.0/go.mod h1:Ua6AAlDKdZ7tdvaQKfSmnFTdHx37+J4ba8MwVCYM5hc= +go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09nk+3A= +go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0= +golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4= +golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA= golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY= golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70= golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8= golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w= -golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= -golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= +golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw= +golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs= +golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= -golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM= -golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY= +golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8= +golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA= +golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U= +golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno= golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k= golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0= +gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4= +gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E= +google.golang.org/api v0.275.0 h1:vfY5d9vFVJeWEZT65QDd9hbndr7FyZ2+6mIzGAh71NI= +google.golang.org/api v0.275.0/go.mod h1:Fnag/EWUPIcJXuIkP1pjoTgS5vdxlk3eeemL7Do6bvw= +google.golang.org/genproto v0.0.0-20260319201613-d00831a3d3e7 h1:XzmzkmB14QhVhgnawEVsOn6OFsnpyxNPRY9QV01dNB0= +google.golang.org/genproto v0.0.0-20260319201613-d00831a3d3e7/go.mod h1:L43LFes82YgSonw6iTXTxXUX1OlULt4AQtkik4ULL/I= +google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9 h1:VPWxll4HlMw1Vs/qXtN7BvhZqsS9cdAittCNvVENElA= +google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:7QBABkRtR8z+TEnmXTqIqwJLlzrZKVfAUm7tY3yGv0M= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 h1:m8qni9SQFH0tJc1X0vmnpw/0t+AImlSvp30sEupozUg= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8= +google.golang.org/grpc v1.80.0 h1:Xr6m2WmWZLETvUNvIUmeD5OAagMw3FiKmMlTdViWsHM= +google.golang.org/grpc v1.80.0/go.mod h1:ho/dLnxwi3EDJA4Zghp7k2Ec1+c2jqup0bFkw07bwF4= +google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= +google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis= modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0= modernc.org/ccgo/v4 v4.32.0 h1:hjG66bI/kqIPX1b2yT6fr/jt+QedtP2fqojG2VrFuVw= diff --git a/pkg/speech/_migration_backup/gateway_speech_stt.go b/pkg/speech/_migration_backup/gateway_speech_stt.go new file mode 100644 index 00000000..79bb630f --- /dev/null +++ b/pkg/speech/_migration_backup/gateway_speech_stt.go @@ -0,0 +1,87 @@ +package gateway + +import ( + "time" + + "github.com/1024XEngineer/anyclaw/pkg/speech" +) + +func (s *Server) initSTT() { + sttCfg := s.mainRuntime.Config.Speech.STT + if !sttCfg.Enabled { + return + } + + s.sttManager = speech.NewSTTManager() + + if sttCfg.Provider != "" && sttCfg.APIKey != "" { + providerType := speech.STTProviderType(sttCfg.Provider) + sttProviderCfg := speech.STTConfig{ + Type: providerType, + APIKey: sttCfg.APIKey, + BaseURL: sttCfg.BaseURL, + Model: sttCfg.Model, + Language: sttCfg.DefaultLang, + Timeout: time.Duration(sttCfg.TimeoutSec) * time.Second, + } + if sttCfg.TimeoutSec <= 0 { + sttProviderCfg.Timeout = 120 * time.Second + } + + provider, err := speech.NewSTTProvider(sttProviderCfg) + if err != nil { + s.appendEvent("stt.init.error", "", map[string]any{"error": err.Error(), "provider": sttCfg.Provider}) + return + } + + if err := s.sttManager.Register(sttCfg.Provider, provider); err != nil { + s.appendEvent("stt.init.error", "", map[string]any{"error": err.Error(), "provider": sttCfg.Provider}) + return + } + } + + pipelineCfg := speech.STTPipelineConfig{ + Provider: sttCfg.Provider, + DefaultLang: sttCfg.DefaultLang, + AutoDetect: sttCfg.DefaultLang == "auto", + MaxDuration: time.Duration(sttCfg.MaxDurationSec) * time.Second, + MinConfidence: sttCfg.MinConfidence, + Timeout: time.Duration(sttCfg.TimeoutSec) * time.Second, + } + if sttCfg.MaxDurationSec <= 0 { + pipelineCfg.MaxDuration = 10 * time.Minute + } + if sttCfg.TimeoutSec <= 0 { + pipelineCfg.Timeout = 120 * time.Second + } + + s.sttPipeline = speech.NewSTTPipeline(s.sttManager, pipelineCfg) + + integrationCfg := speech.STTIntegrationConfig{ + Enabled: sttCfg.Enabled, + AutoSTT: sttCfg.AutoSTT, + TriggerPrefix: sttCfg.TriggerPrefix, + Provider: sttCfg.Provider, + DefaultLang: sttCfg.DefaultLang, + MaxDuration: pipelineCfg.MaxDuration, + MinConfidence: sttCfg.MinConfidence, + Timeout: pipelineCfg.Timeout, + Channels: sttCfg.Channels, + ExcludeChannels: sttCfg.ExcludeChannels, + FallbackToVoice: sttCfg.FallbackToVoice, + AppendTranscript: sttCfg.AppendTranscript, + } + if integrationCfg.TriggerPrefix == "" { + integrationCfg.TriggerPrefix = "/transcribe" + } + + s.sttIntegration = speech.NewSTTIntegration(s.sttPipeline, integrationCfg) + + s.appendEvent("stt.init.ok", "", map[string]any{ + "provider": sttCfg.Provider, + "auto_stt": sttCfg.AutoSTT, + "language": sttCfg.DefaultLang, + "channels": len(sttCfg.Channels), + "excluded": len(sttCfg.ExcludeChannels), + }) +} diff --git a/pkg/speech/_migration_backup/stt_provider.go b/pkg/speech/_migration_backup/stt_provider.go new file mode 100644 index 00000000..5f6c9ba6 --- /dev/null +++ b/pkg/speech/_migration_backup/stt_provider.go @@ -0,0 +1,245 @@ +package speech + +import ( + "context" + "fmt" + "time" +) + +type STTProviderType string + +const ( + STTProviderOpenAI STTProviderType = "openai" + STTProviderAzure STTProviderType = "azure" + STTProviderGoogle STTProviderType = "google" + STTProviderDeepgram STTProviderType = "deepgram" + STTProviderAssemblyAI STTProviderType = "assemblyai" + STTProviderWhisperCPP STTProviderType = "whisper.cpp" + STTProviderVosk STTProviderType = "vosk" + STTProviderFasterWhisper STTProviderType = "faster-whisper" + STTProviderCustom STTProviderType = "custom" +) + +type AudioInputFormat string + +const ( + InputMP3 AudioInputFormat = "mp3" + InputWAV AudioInputFormat = "wav" + InputOGG AudioInputFormat = "ogg" + InputFLAC AudioInputFormat = "flac" + InputPCM AudioInputFormat = "pcm" + InputM4A AudioInputFormat = "m4a" + InputMP4 AudioInputFormat = "mp4" + InputMPEG AudioInputFormat = "mpeg" + InputMPGA AudioInputFormat = "mpga" + InputWEBM AudioInputFormat = "webm" +) + +type STTProvider interface { + Name() string + Type() STTProviderType + Transcribe(ctx context.Context, audio []byte, opts ...TranscribeOption) (*TranscriptResult, error) + ListLanguages(ctx context.Context) ([]string, error) +} + +type STTConfig struct { + Type STTProviderType + APIKey string + BaseURL string + Model string + Language string + SampleRate int + Timeout time.Duration +} + +func NewSTTProvider(cfg STTConfig) (STTProvider, error) { + switch cfg.Type { + case STTProviderOpenAI: + opts := []WhisperOption{} + if cfg.BaseURL != "" { + opts = append(opts, WithWhisperBaseURL(cfg.BaseURL)) + } + if cfg.Model != "" { + opts = append(opts, WithWhisperModel(WhisperModel(cfg.Model))) + } + if cfg.Language != "" { + opts = append(opts, WithWhisperLanguage(cfg.Language)) + } + if cfg.Timeout > 0 { + opts = append(opts, WithWhisperTimeout(cfg.Timeout)) + } + return NewWhisperProvider(cfg.APIKey, opts...) + case STTProviderGoogle: + opts := []GoogleOption{} + if cfg.BaseURL != "" { + opts = append(opts, WithGoogleBaseURL(cfg.BaseURL)) + } + if cfg.Language != "" { + opts = append(opts, WithGoogleLanguageCode(cfg.Language)) + } + if cfg.Timeout > 0 { + opts = append(opts, WithGoogleTimeout(cfg.Timeout)) + } + return NewGoogleProvider(cfg.APIKey, opts...) + case STTProviderWhisperCPP: + opts := []WhisperCPPOption{} + if cfg.Model != "" { + opts = append(opts, WithWhisperCPPModelPath(cfg.Model)) + } + if cfg.Language != "" { + opts = append(opts, WithWhisperCPPLanguage(cfg.Language)) + } + if cfg.Timeout > 0 { + opts = append(opts, WithWhisperCPPTimeout(cfg.Timeout)) + } + return NewWhisperCPPProvider(opts...) + default: + return nil, NewSTTError(ErrProviderNotSupported, "unknown STT provider: "+string(cfg.Type)) + } +} + +type TranscribeMode string + +const ( + ModeTranscription TranscribeMode = "transcription" + ModeTranslation TranscribeMode = "translation" +) + +type TranscribeOptions struct { + Language string + Model string + Prompt string + Temperature float64 + Mode TranscribeMode + InputFormat AudioInputFormat + SampleRate int + WordTimestamps bool + SpeakerLabels bool + MaxAlternatives int +} + +type TranscribeOption func(*TranscribeOptions) + +func WithSTTLanguage(lang string) TranscribeOption { + return func(o *TranscribeOptions) { + o.Language = lang + } +} + +func WithSTTModel(model string) TranscribeOption { + return func(o *TranscribeOptions) { + o.Model = model + } +} + +func WithSTTPrompt(prompt string) TranscribeOption { + return func(o *TranscribeOptions) { + o.Prompt = prompt + } +} + +func WithSTTTemperature(temp float64) TranscribeOption { + return func(o *TranscribeOptions) { + o.Temperature = temp + } +} + +func WithSTTMode(mode TranscribeMode) TranscribeOption { + return func(o *TranscribeOptions) { + o.Mode = mode + } +} + +func WithSTTInputFormat(format AudioInputFormat) TranscribeOption { + return func(o *TranscribeOptions) { + o.InputFormat = format + } +} + +func WithSTTSampleRate(rate int) TranscribeOption { + return func(o *TranscribeOptions) { + o.SampleRate = rate + } +} + +func WithSTTWordTimestamps(enabled bool) TranscribeOption { + return func(o *TranscribeOptions) { + o.WordTimestamps = enabled + } +} + +func WithSTTSpeakerLabels(enabled bool) TranscribeOption { + return func(o *TranscribeOptions) { + o.SpeakerLabels = enabled + } +} + +func WithSTTMaxAlternatives(n int) TranscribeOption { + return func(o *TranscribeOptions) { + o.MaxAlternatives = n + } +} + +type WordInfo struct { + Word string + StartTime time.Duration + EndTime time.Duration + Confidence float64 +} + +type SegmentInfo struct { + ID int + Text string + StartTime time.Duration + EndTime time.Duration + Confidence float64 + Speaker string + Words []WordInfo +} + +type TranscriptResult struct { + Text string + Language string + Duration time.Duration + Confidence float64 + Segments []SegmentInfo + Words []WordInfo + Alternatives []string +} + +type STTErrorCode string + +const ( + ErrProviderNotSupported STTErrorCode = "provider_not_supported" + ErrAudioFormatInvalid STTErrorCode = "audio_format_invalid" + ErrTranscriptionFailed STTErrorCode = "transcription_failed" + ErrAudioTooLong STTErrorCode = "audio_too_long" + ErrAudioTooLarge STTErrorCode = "audio_too_large" + ErrRateLimited STTErrorCode = "rate_limited" + ErrAuthentication STTErrorCode = "authentication_failed" +) + +type STTError struct { + Code STTErrorCode + Message string + Err error +} + +func NewSTTError(code STTErrorCode, message string) *STTError { + return &STTError{Code: code, Message: message} +} + +func NewSTTErrorf(code STTErrorCode, format string, args ...interface{}) *STTError { + return &STTError{Code: code, Message: fmt.Sprintf(format, args...)} +} + +func (e *STTError) Error() string { + if e.Err != nil { + return string(e.Code) + ": " + e.Message + ": " + e.Err.Error() + } + return string(e.Code) + ": " + e.Message +} + +func (e *STTError) Unwrap() error { + return e.Err +} diff --git a/pkg/speech/_migration_backup/stt_whisper.go b/pkg/speech/_migration_backup/stt_whisper.go new file mode 100644 index 00000000..9ac04a86 --- /dev/null +++ b/pkg/speech/_migration_backup/stt_whisper.go @@ -0,0 +1,644 @@ +package speech + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "mime/multipart" + "net/http" + "os" + "path/filepath" + "strings" + "time" +) + +type WhisperModel string + +const ( + WhisperModelV1 WhisperModel = "whisper-1" +) + +var validWhisperModels = map[WhisperModel]bool{ + WhisperModelV1: true, +} + +var validInputFormats = map[AudioInputFormat]bool{ + InputMP3: true, + InputWAV: true, + InputOGG: true, + InputFLAC: true, + InputM4A: true, + InputMP4: true, + InputMPEG: true, + InputMPGA: true, + InputWEBM: true, +} + +type WhisperProvider struct { + apiKey string + baseURL string + model WhisperModel + language string + timeout time.Duration + retries int + client *http.Client + httpTransport *http.Transport +} + +type WhisperOption func(*WhisperProvider) + +func WithWhisperBaseURL(url string) WhisperOption { + return func(p *WhisperProvider) { + p.baseURL = strings.TrimRight(url, "/") + } +} + +func WithWhisperModel(model WhisperModel) WhisperOption { + return func(p *WhisperProvider) { + p.model = model + } +} + +func WithWhisperLanguage(lang string) WhisperOption { + return func(p *WhisperProvider) { + p.language = lang + } +} + +func WithWhisperTimeout(timeout time.Duration) WhisperOption { + return func(p *WhisperProvider) { + p.timeout = timeout + } +} + +func WithWhisperRetries(retries int) WhisperOption { + return func(p *WhisperProvider) { + p.retries = retries + } +} + +func WithWhisperHTTPTransport(transport *http.Transport) WhisperOption { + return func(p *WhisperProvider) { + p.httpTransport = transport + } +} + +func NewWhisperProvider(apiKey string, opts ...WhisperOption) (*WhisperProvider, error) { + if apiKey == "" { + return nil, NewSTTError(ErrAuthentication, "openai: API key is required") + } + + p := &WhisperProvider{ + apiKey: apiKey, + baseURL: "https://api.openai.com", + model: WhisperModelV1, + timeout: 120 * time.Second, + retries: 2, + client: &http.Client{Timeout: 120 * time.Second}, + } + + for _, opt := range opts { + opt(p) + } + + if p.httpTransport != nil { + p.client.Transport = p.httpTransport + } + p.client.Timeout = p.timeout + + if !validWhisperModels[p.model] { + return nil, NewSTTErrorf(ErrProviderNotSupported, "openai: invalid whisper model: %s", p.model) + } + + return p, nil +} + +func (p *WhisperProvider) Name() string { + return "openai-whisper" +} + +func (p *WhisperProvider) Type() STTProviderType { + return STTProviderOpenAI +} + +func (p *WhisperProvider) Transcribe(ctx context.Context, audio []byte, opts ...TranscribeOption) (*TranscriptResult, error) { + options := TranscribeOptions{ + Model: string(p.model), + Language: p.language, + Temperature: 0, + Mode: ModeTranscription, + InputFormat: InputMP3, + } + for _, opt := range opts { + opt(&options) + } + + if err := p.validateTranscribeOptions(options); err != nil { + return nil, err + } + + if len(audio) == 0 { + return nil, NewSTTError(ErrAudioFormatInvalid, "openai-whisper: audio data is empty") + } + + const maxAudioSize = 25 * 1024 * 1024 + if len(audio) > maxAudioSize { + return nil, NewSTTErrorf(ErrAudioTooLarge, "openai-whisper: audio exceeds 25MB limit (%d bytes)", len(audio)) + } + + if !validInputFormats[options.InputFormat] { + return nil, NewSTTErrorf(ErrAudioFormatInvalid, "openai-whisper: unsupported input format: %s", options.InputFormat) + } + + var lastErr error + for attempt := 0; attempt <= p.retries; attempt++ { + if attempt > 0 { + backoff := time.Duration(attempt) * time.Second + select { + case <-ctx.Done(): + return nil, NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: context cancelled during retry: %v", ctx.Err()) + case <-time.After(backoff): + } + } + + result, err := p.doTranscribe(ctx, audio, options) + if err == nil { + return result, nil + } + + lastErr = err + + if sttErr, ok := err.(*STTError); ok { + if sttErr.Code == ErrAuthentication || sttErr.Code == ErrAudioFormatInvalid || sttErr.Code == ErrAudioTooLarge || sttErr.Code == ErrRateLimited { + return nil, err + } + } + } + + return nil, NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: all %d retries failed: %v", p.retries, lastErr) +} + +func (p *WhisperProvider) TranscribeFile(ctx context.Context, filePath string, opts ...TranscribeOption) (*TranscriptResult, error) { + if filePath == "" { + return nil, NewSTTError(ErrAudioFormatInvalid, "openai-whisper: file path is empty") + } + + info, err := os.Stat(filePath) + if err != nil { + if os.IsNotExist(err) { + return nil, NewSTTErrorf(ErrAudioFormatInvalid, "openai-whisper: file not found: %s", filePath) + } + return nil, NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: failed to stat file: %v", err) + } + + const maxAudioSize = 25 * 1024 * 1024 + if info.Size() > maxAudioSize { + return nil, NewSTTErrorf(ErrAudioTooLarge, "openai-whisper: file exceeds 25MB limit (%d bytes)", info.Size()) + } + + audio, err := os.ReadFile(filePath) + if err != nil { + return nil, NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: failed to read file: %v", err) + } + + if len(opts) == 0 || anyInputFormatNotSet(opts) { + ext := strings.TrimPrefix(strings.ToLower(filepath.Ext(filePath)), ".") + if ext != "" { + formatOpts := append([]TranscribeOption{WithSTTInputFormat(AudioInputFormat(ext))}, opts...) + return p.Transcribe(ctx, audio, formatOpts...) + } + } + + return p.Transcribe(ctx, audio, opts...) +} + +func anyInputFormatNotSet(opts []TranscribeOption) bool { + for _, opt := range opts { + o := &TranscribeOptions{} + opt(o) + if o.InputFormat != "" { + return false + } + } + return true +} + +func (p *WhisperProvider) TranscribeStream(ctx context.Context, reader io.Reader, opts ...TranscribeOption) (*TranscriptResult, error) { + if reader == nil { + return nil, NewSTTError(ErrAudioFormatInvalid, "openai-whisper: reader is nil") + } + + audio, err := io.ReadAll(reader) + if err != nil { + return nil, NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: failed to read stream: %v", err) + } + + return p.Transcribe(ctx, audio, opts...) +} + +func (p *WhisperProvider) validateTranscribeOptions(options TranscribeOptions) error { + if options.Temperature < 0 || options.Temperature > 1 { + return NewSTTErrorf(ErrAudioFormatInvalid, "openai-whisper: temperature must be between 0 and 1, got: %f", options.Temperature) + } + + if options.MaxAlternatives < 0 { + return NewSTTErrorf(ErrAudioFormatInvalid, "openai-whisper: maxAlternatives cannot be negative") + } + + if options.Model == "" { + return NewSTTError(ErrAudioFormatInvalid, "openai-whisper: model is required") + } + + return nil +} + +func (p *WhisperProvider) doTranscribe(ctx context.Context, audio []byte, options TranscribeOptions) (*TranscriptResult, error) { + var body bytes.Buffer + writer := multipart.NewWriter(&body) + + filename := "audio." + string(options.InputFormat) + part, err := writer.CreateFormFile("file", filename) + if err != nil { + return nil, NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: failed to create form file: %v", err) + } + + if _, err := part.Write(audio); err != nil { + return nil, NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: failed to write audio data: %v", err) + } + + if err := writer.WriteField("model", options.Model); err != nil { + return nil, NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: failed to write model field: %v", err) + } + + if options.Language != "" { + if err := writer.WriteField("language", options.Language); err != nil { + return nil, NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: failed to write language field: %v", err) + } + } + + if options.Prompt != "" { + if err := writer.WriteField("prompt", options.Prompt); err != nil { + return nil, NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: failed to write prompt field: %v", err) + } + } + + if options.Temperature > 0 { + if err := writer.WriteField("temperature", fmt.Sprintf("%.2f", options.Temperature)); err != nil { + return nil, NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: failed to write temperature field: %v", err) + } + } + + if options.MaxAlternatives > 0 { + if err := writer.WriteField("max_alternatives", fmt.Sprintf("%d", options.MaxAlternatives)); err != nil { + return nil, NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: failed to write max_alternatives field: %v", err) + } + } + + if options.WordTimestamps || options.SpeakerLabels { + if options.WordTimestamps { + if err := writer.WriteField("timestamp_granularities[]", "word"); err != nil { + return nil, NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: failed to write word timestamp_granularities: %v", err) + } + } + if options.SpeakerLabels { + if err := writer.WriteField("timestamp_granularities[]", "segment"); err != nil { + return nil, NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: failed to write segment timestamp_granularities: %v", err) + } + } + } + + responseType := "verbose_json" + if options.WordTimestamps || options.SpeakerLabels { + responseType = "verbose_json" + } + if err := writer.WriteField("response_format", responseType); err != nil { + return nil, NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: failed to write response_format field: %v", err) + } + + if err := writer.Close(); err != nil { + return nil, NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: failed to close multipart writer: %v", err) + } + + var endpoint string + switch options.Mode { + case ModeTranslation: + endpoint = "/v1/audio/translations" + default: + endpoint = "/v1/audio/transcriptions" + } + + url := p.baseURL + endpoint + + req, err := http.NewRequestWithContext(ctx, "POST", url, &body) + if err != nil { + return nil, NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: failed to create request: %v", err) + } + + req.Header.Set("Authorization", "Bearer "+p.apiKey) + req.Header.Set("Content-Type", writer.FormDataContentType()) + req.Header.Set("User-Agent", "anyclaw-stt/1.0") + + resp, err := p.client.Do(req) + if err != nil { + return nil, NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + respBody, _ := io.ReadAll(resp.Body) + return nil, p.handleErrorResponse(resp.StatusCode, respBody) + } + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: failed to read response: %v", err) + } + + return p.parseResponse(respBody, options) +} + +func (p *WhisperProvider) handleErrorResponse(statusCode int, body []byte) error { + var errResp whisperErrorResponse + if err := json.Unmarshal(body, &errResp); err == nil && errResp.Error.Message != "" { + msg := fmt.Sprintf("openai-whisper: API error: %s (type: %s, code: %s)", + errResp.Error.Message, errResp.Error.Type, errResp.Error.Code) + switch statusCode { + case http.StatusUnauthorized: + return NewSTTError(ErrAuthentication, msg) + case http.StatusTooManyRequests: + return NewSTTError(ErrRateLimited, msg) + case http.StatusBadRequest: + return NewSTTError(ErrAudioFormatInvalid, msg) + default: + return NewSTTError(ErrTranscriptionFailed, msg) + } + } + + switch statusCode { + case http.StatusUnauthorized: + return NewSTTError(ErrAuthentication, fmt.Sprintf("openai-whisper: authentication failed: %s", string(body))) + case http.StatusTooManyRequests: + return NewSTTError(ErrRateLimited, fmt.Sprintf("openai-whisper: rate limited: %s", string(body))) + case http.StatusBadRequest: + return NewSTTErrorf(ErrAudioFormatInvalid, "openai-whisper: invalid request: %s", string(body)) + case http.StatusServiceUnavailable: + return NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: service unavailable: %s", string(body)) + default: + return NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: unexpected status %d: %s", statusCode, string(body)) + } +} + +type whisperResponse struct { + Text string `json:"text"` + Language string `json:"language"` + Duration float64 `json:"duration,omitempty"` + Segments []struct { + ID int `json:"id"` + Seek int `json:"seek"` + Start float64 `json:"start"` + End float64 `json:"end"` + Text string `json:"text"` + Tokens []int `json:"tokens"` + Temperature float64 `json:"temperature"` + AvgLogProb float64 `json:"avg_logprob"` + Compression float64 `json:"compression_ratio"` + NoSpeechProb float64 `json:"no_speech_prob"` + Words []struct { + Word string `json:"word"` + Start float64 `json:"start"` + End float64 `json:"end"` + Confidence float64 `json:"probability"` + } `json:"words,omitempty"` + } `json:"segments,omitempty"` + LanguageProbability float64 `json:"language_probability,omitempty"` +} + +type whisperErrorResponse struct { + Error struct { + Message string `json:"message"` + Type string `json:"type"` + Code string `json:"code"` + } `json:"error"` +} + +func (p *WhisperProvider) parseResponse(body []byte, options TranscribeOptions) (*TranscriptResult, error) { + var resp whisperResponse + if err := json.Unmarshal(body, &resp); err != nil { + return nil, NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: failed to parse JSON response: %v", err) + } + + result := &TranscriptResult{ + Text: strings.TrimSpace(resp.Text), + Language: resp.Language, + Duration: time.Duration(resp.Duration * float64(time.Second)), + Confidence: resp.LanguageProbability, + } + + if len(resp.Segments) > 0 { + result.Segments = make([]SegmentInfo, 0, len(resp.Segments)) + for _, seg := range resp.Segments { + segment := SegmentInfo{ + ID: seg.ID, + Text: seg.Text, + StartTime: time.Duration(seg.Start * float64(time.Second)), + EndTime: time.Duration(seg.End * float64(time.Second)), + } + + if seg.AvgLogProb != 0 { + segment.Confidence = normalizeLogProb(seg.AvgLogProb) + } + + if len(seg.Words) > 0 { + segment.Words = make([]WordInfo, 0, len(seg.Words)) + for _, w := range seg.Words { + segment.Words = append(segment.Words, WordInfo{ + Word: w.Word, + StartTime: time.Duration(w.Start * float64(time.Second)), + EndTime: time.Duration(w.End * float64(time.Second)), + Confidence: w.Confidence, + }) + } + } + + result.Segments = append(result.Segments, segment) + } + + if len(result.Segments) > 0 && result.Confidence == 0 { + totalConfidence := 0.0 + for _, seg := range result.Segments { + totalConfidence += seg.Confidence + } + result.Confidence = totalConfidence / float64(len(result.Segments)) + } + } + + if options.WordTimestamps && len(result.Segments) > 0 { + words := make([]WordInfo, 0) + for _, seg := range result.Segments { + words = append(words, seg.Words...) + } + result.Words = words + } + + return result, nil +} + +func normalizeLogProb(logProb float64) float64 { + if logProb > 0 { + return 1.0 + } + prob := 1.0 / (1.0 + logProb*-1) + if prob < 0 { + return 0 + } + if prob > 1 { + return 1 + } + return prob +} + +func (p *WhisperProvider) TranscribeSSE(ctx context.Context, audio []byte, onChunk func(chunk *TranscriptResult), opts ...TranscribeOption) error { + options := TranscribeOptions{ + Model: string(p.model), + Language: p.language, + Temperature: 0, + Mode: ModeTranscription, + InputFormat: InputMP3, + } + for _, opt := range opts { + opt(&options) + } + + if err := p.validateTranscribeOptions(options); err != nil { + return err + } + + if len(audio) == 0 { + return NewSTTError(ErrAudioFormatInvalid, "openai-whisper: audio data is empty") + } + + var body bytes.Buffer + writer := multipart.NewWriter(&body) + + filename := "audio." + string(options.InputFormat) + part, err := writer.CreateFormFile("file", filename) + if err != nil { + return NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: failed to create form file: %v", err) + } + + if _, err := part.Write(audio); err != nil { + return NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: failed to write audio data: %v", err) + } + + if err := writer.WriteField("model", options.Model); err != nil { + return NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: failed to write model field: %v", err) + } + + if options.Language != "" { + if err := writer.WriteField("language", options.Language); err != nil { + return NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: failed to write language field: %v", err) + } + } + + if err := writer.WriteField("response_format", "json"); err != nil { + return NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: failed to write response_format field: %v", err) + } + + if err := writer.WriteField("stream", "true"); err != nil { + return NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: failed to write stream field: %v", err) + } + + if err := writer.Close(); err != nil { + return NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: failed to close multipart writer: %v", err) + } + + endpoint := "/v1/audio/transcriptions" + if options.Mode == ModeTranslation { + endpoint = "/v1/audio/translations" + } + + url := p.baseURL + endpoint + + req, err := http.NewRequestWithContext(ctx, "POST", url, &body) + if err != nil { + return NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: failed to create request: %v", err) + } + + req.Header.Set("Authorization", "Bearer "+p.apiKey) + req.Header.Set("Content-Type", writer.FormDataContentType()) + req.Header.Set("Accept", "text/event-stream") + + resp, err := p.client.Do(req) + if err != nil { + return NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: streaming request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + respBody, _ := io.ReadAll(resp.Body) + return p.handleErrorResponse(resp.StatusCode, respBody) + } + + return p.readSSEStream(resp.Body, onChunk) +} + +func (p *WhisperProvider) readSSEStream(reader io.Reader, onChunk func(chunk *TranscriptResult)) error { + scanner := bufio.NewScanner(reader) + scanner.Split(bufio.ScanLines) + + var currentText strings.Builder + var detectedLanguage string + + for scanner.Scan() { + line := scanner.Text() + + if strings.HasPrefix(line, "data: ") { + data := strings.TrimPrefix(line, "data: ") + if data == "[DONE]" { + break + } + + var chunk struct { + Text string `json:"text"` + Language string `json:"language"` + Done bool `json:"done"` + } + if err := json.Unmarshal([]byte(data), &chunk); err != nil { + continue + } + + if chunk.Text != "" { + currentText.WriteString(chunk.Text) + } + if chunk.Language != "" { + detectedLanguage = chunk.Language + } + + onChunk(&TranscriptResult{ + Text: currentText.String(), + Language: detectedLanguage, + }) + + if chunk.Done { + break + } + } + } + + return scanner.Err() +} + +func (p *WhisperProvider) ListLanguages(ctx context.Context) ([]string, error) { + return []string{ + "af", "am", "ar", "as", "az", "ba", "be", "bg", "bn", "bo", "br", "bs", "ca", "cs", "cy", "da", + "de", "el", "en", "es", "et", "eu", "fa", "fi", "fo", "fr", "gl", "gu", "ha", "haw", "he", "hi", + "hr", "ht", "hu", "hy", "id", "is", "it", "ja", "jw", "ka", "kk", "km", "kn", "ko", "la", "lb", + "ln", "lo", "lt", "lv", "mg", "mi", "mk", "ml", "mn", "mr", "ms", "mt", "my", "ne", "nl", "nn", + "no", "oc", "pa", "pl", "ps", "pt", "ro", "ru", "sa", "sd", "si", "sk", "sl", "sn", "so", "sq", + "sr", "su", "sv", "sw", "ta", "te", "tg", "th", "tk", "tl", "tr", "tt", "uk", "ur", "uz", "vi", + "yi", "yo", "zh", + }, nil +} diff --git a/pkg/speech/_migration_backup/vad.go b/pkg/speech/_migration_backup/vad.go new file mode 100644 index 00000000..fc36af30 --- /dev/null +++ b/pkg/speech/_migration_backup/vad.go @@ -0,0 +1,306 @@ +package speech + +import ( + "math" + "sync" +) + +type VADState string + +const ( + VADStateSilence VADState = "silence" + VADStateSpeech VADState = "speech" +) + +type VADConfig struct { + SampleRate int + FrameSize int + EnergyThreshold float64 + ZeroCrossThreshold int + SpeechMinFrames int + SilenceFrames int + HangoverFrames int +} + +func DefaultVADConfig() VADConfig { + return VADConfig{ + SampleRate: 16000, + FrameSize: 320, + EnergyThreshold: 0.01, + ZeroCrossThreshold: 50, + SpeechMinFrames: 3, + SilenceFrames: 30, + HangoverFrames: 10, + } +} + +type VAD struct { + mu sync.Mutex + cfg VADConfig + state VADState + consecutiveSpeech int + consecutiveSilence int + listeners []VADStateListener +} + +type VADStateListener func(state VADState, energy float64, zcr float64) + +func NewVAD(cfg VADConfig) *VAD { + if cfg.SampleRate == 0 { + cfg.SampleRate = 16000 + } + if cfg.FrameSize == 0 { + cfg.FrameSize = 320 + } + if cfg.EnergyThreshold == 0 { + cfg.EnergyThreshold = 0.01 + } + if cfg.ZeroCrossThreshold == 0 { + cfg.ZeroCrossThreshold = 50 + } + if cfg.SpeechMinFrames == 0 { + cfg.SpeechMinFrames = 3 + } + if cfg.SilenceFrames == 0 { + cfg.SilenceFrames = 30 + } + if cfg.HangoverFrames == 0 { + cfg.HangoverFrames = 10 + } + + return &VAD{ + cfg: cfg, + state: VADStateSilence, + } +} + +func (v *VAD) RegisterListener(listener VADStateListener) { + v.mu.Lock() + defer v.mu.Unlock() + v.listeners = append(v.listeners, listener) +} + +func (v *VAD) ProcessFrame(samples []int16) VADState { + v.mu.Lock() + defer v.mu.Unlock() + + energy := v.calculateRMS(samples) + zcr := v.calculateZeroCrossingRate(samples) + + isSpeech := v.isSpeechFrame(energy, zcr) + + if isSpeech { + v.consecutiveSpeech++ + v.consecutiveSilence = 0 + } else { + v.consecutiveSilence++ + v.consecutiveSpeech = 0 + } + + switch v.state { + case VADStateSilence: + if isSpeech { + if v.consecutiveSpeech >= v.cfg.SpeechMinFrames { + v.state = VADStateSpeech + v.notifyListeners(VADStateSpeech, energy, zcr) + } + } else { + v.consecutiveSpeech = 0 + } + + case VADStateSpeech: + if isSpeech { + v.consecutiveSilence = 0 + } else { + if v.consecutiveSilence >= v.cfg.HangoverFrames { + v.state = VADStateSilence + v.consecutiveSpeech = 0 + v.consecutiveSilence = 0 + v.notifyListeners(VADStateSilence, energy, zcr) + } + } + } + + return v.state +} + +func (v *VAD) ProcessFloatFrame(samples []float32) VADState { + intSamples := make([]int16, len(samples)) + for i, s := range samples { + clamped := s + if clamped > 1.0 { + clamped = 1.0 + } + if clamped < -1.0 { + clamped = -1.0 + } + intSamples[i] = int16(clamped * 32767.0) + } + return v.ProcessFrame(intSamples) +} + +func (v *VAD) isSpeechFrame(energy float64, zcr float64) bool { + return energy > v.cfg.EnergyThreshold || zcr > float64(v.cfg.ZeroCrossThreshold) +} + +func (v *VAD) calculateRMS(samples []int16) float64 { + if len(samples) == 0 { + return 0 + } + + var sumSquares float64 + for _, s := range samples { + normalized := float64(s) / 32768.0 + sumSquares += normalized * normalized + } + + return math.Sqrt(sumSquares / float64(len(samples))) +} + +func (v *VAD) calculateZeroCrossingRate(samples []int16) float64 { + if len(samples) < 2 { + return 0 + } + + var crossings int + for i := 1; i < len(samples); i++ { + if (samples[i] >= 0 && samples[i-1] < 0) || (samples[i] < 0 && samples[i-1] >= 0) { + crossings++ + } + } + + return float64(crossings) +} + +func (v *VAD) State() VADState { + v.mu.Lock() + defer v.mu.Unlock() + return v.state +} + +func (v *VAD) Reset() { + v.mu.Lock() + defer v.mu.Unlock() + v.state = VADStateSilence + v.consecutiveSpeech = 0 + v.consecutiveSilence = 0 +} + +func (v *VAD) notifyListeners(state VADState, energy float64, zcr float64) { + for _, listener := range v.listeners { + listener(state, energy, zcr) + } +} + +func (v *VAD) UpdateConfig(cfg VADConfig) { + v.mu.Lock() + defer v.mu.Unlock() + if cfg.EnergyThreshold > 0 { + v.cfg.EnergyThreshold = cfg.EnergyThreshold + } + if cfg.ZeroCrossThreshold > 0 { + v.cfg.ZeroCrossThreshold = cfg.ZeroCrossThreshold + } + if cfg.SpeechMinFrames > 0 { + v.cfg.SpeechMinFrames = cfg.SpeechMinFrames + } + if cfg.SilenceFrames > 0 { + v.cfg.SilenceFrames = cfg.SilenceFrames + } + if cfg.HangoverFrames > 0 { + v.cfg.HangoverFrames = cfg.HangoverFrames + } +} + +func (v *VAD) Config() VADConfig { + v.mu.Lock() + defer v.mu.Unlock() + return v.cfg +} + +func NormalizeAudio(samples []int16) []float64 { + result := make([]float64, len(samples)) + for i, s := range samples { + result[i] = float64(s) / 32768.0 + } + return result +} + +func Float32ToInt16(samples []float32) []int16 { + result := make([]int16, len(samples)) + for i, s := range samples { + clamped := s + if clamped > 1.0 { + clamped = 1.0 + } + if clamped < -1.0 { + clamped = -1.0 + } + result[i] = int16(clamped * 32767.0) + } + return result +} + +func Int16ToWAV(samples []int16, sampleRate int, channels int) []byte { + if len(samples) == 0 { + return nil + } + + bitsPerSample := 16 + byteRate := sampleRate * channels * bitsPerSample / 8 + blockAlign := channels * bitsPerSample / 8 + dataSize := len(samples) * 2 + fileSize := 36 + dataSize + + buf := make([]byte, 44+dataSize) + + copy(buf[0:4], []byte("RIFF")) + buf[4] = byte(fileSize) + buf[5] = byte(fileSize >> 8) + buf[6] = byte(fileSize >> 16) + buf[7] = byte(fileSize >> 24) + + copy(buf[8:12], []byte("WAVE")) + + copy(buf[12:16], []byte("fmt ")) + buf[16] = 16 + buf[17] = 0 + buf[18] = 0 + buf[19] = 0 + + buf[20] = 1 + buf[21] = 0 + + buf[22] = byte(channels) + buf[23] = 0 + + buf[24] = byte(sampleRate) + buf[25] = byte(sampleRate >> 8) + buf[26] = byte(sampleRate >> 16) + buf[27] = byte(sampleRate >> 24) + + buf[28] = byte(byteRate) + buf[29] = byte(byteRate >> 8) + buf[30] = byte(byteRate >> 16) + buf[31] = byte(byteRate >> 24) + + buf[32] = byte(blockAlign) + buf[33] = 0 + + buf[34] = byte(bitsPerSample) + buf[35] = 0 + + copy(buf[36:40], []byte("data")) + buf[40] = byte(dataSize) + buf[41] = byte(dataSize >> 8) + buf[42] = byte(dataSize >> 16) + buf[43] = byte(dataSize >> 24) + + for i, s := range samples { + offset := 44 + i*2 + buf[offset] = byte(s) + buf[offset+1] = byte(s >> 8) + } + + return buf +} diff --git a/pkg/speech/_migration_backup/voicewake.go b/pkg/speech/_migration_backup/voicewake.go new file mode 100644 index 00000000..5807bbf9 --- /dev/null +++ b/pkg/speech/_migration_backup/voicewake.go @@ -0,0 +1,620 @@ +package speech + +import ( + "context" + "fmt" + "log" + "sync" + "time" +) + +type VoiceWakeState string + +const ( + VoiceWakeStateIdle VoiceWakeState = "idle" + VoiceWakeStateListening VoiceWakeState = "listening" + VoiceWakeStateRecording VoiceWakeState = "recording" + VoiceWakeStateProcessing VoiceWakeState = "processing" + VoiceWakeStateTriggered VoiceWakeState = "triggered" +) + +type VoiceWakeEventType string + +const ( + VoiceWakeEventStateChanged VoiceWakeEventType = "state_changed" + VoiceWakeEventWakeDetected VoiceWakeEventType = "wake_detected" + VoiceWakeEventSpeechStart VoiceWakeEventType = "speech_start" + VoiceWakeEventSpeechEnd VoiceWakeEventType = "speech_end" + VoiceWakeEventError VoiceWakeEventType = "error" +) + +type VoiceWakeEvent struct { + Type VoiceWakeEventType + State VoiceWakeState + Timestamp time.Time + Data map[string]any +} + +type VoiceWakeListener func(event VoiceWakeEvent) + +type AudioSource interface { + Start(ctx context.Context) error + Stop() error + Read(samples []int16) (int, error) + SampleRate() int + Channels() int +} + +type VoiceWakeConfig struct { + VADConfig VADConfig + WakeWordConfig WakeWordConfig + EngineConfig WakeWordEngineConfig + SampleRate int + Channels int + FrameSize int + MaxRecordingTime time.Duration + CooldownTime time.Duration + AudioSource AudioSource + STTPipeline *STTPipeline + AutoTranscribe bool + WakeWordEngine WakeWordEngineType +} + +func DefaultVoiceWakeConfig() VoiceWakeConfig { + return VoiceWakeConfig{ + VADConfig: DefaultVADConfig(), + WakeWordConfig: DefaultWakeWordConfig(), + SampleRate: 16000, + Channels: 1, + FrameSize: 320, + MaxRecordingTime: 30 * time.Second, + CooldownTime: 2 * time.Second, + AutoTranscribe: true, + } +} + +type VoiceWake struct { + mu sync.Mutex + cfg VoiceWakeConfig + state VoiceWakeState + vad *VAD + wakeDetector *WakeWordDetector + engineRouter *WakeWordEngineRouter + engineAdapter *WakeWordEngineAdapter + listeners []VoiceWakeListener + audioBuffer []int16 + recordingBuffer []int16 + isRecording bool + recordingStart time.Time + cooldownUntil time.Time + ctx context.Context + cancel context.CancelFunc + wg sync.WaitGroup + transcriber *STTPipeline + lastTranscript string + lastWakeMatch string + lastConfidence float64 + lastEnergy float64 +} + +func NewVoiceWake(cfg VoiceWakeConfig) *VoiceWake { + if cfg.SampleRate == 0 { + cfg.SampleRate = 16000 + } + if cfg.Channels == 0 { + cfg.Channels = 1 + } + if cfg.FrameSize == 0 { + cfg.FrameSize = 320 + } + if cfg.MaxRecordingTime == 0 { + cfg.MaxRecordingTime = 30 * time.Second + } + if cfg.CooldownTime == 0 { + cfg.CooldownTime = 2 * time.Second + } + + cfg.VADConfig.SampleRate = cfg.SampleRate + cfg.VADConfig.FrameSize = cfg.FrameSize + + cfg.EngineConfig.SampleRate = cfg.SampleRate + cfg.EngineConfig.FrameSize = cfg.FrameSize + + vad := NewVAD(cfg.VADConfig) + wakeDetector := NewWakeWordDetector(cfg.WakeWordConfig) + + router := NewWakeWordEngineRouter(cfg.EngineConfig) + adapter := NewWakeWordEngineAdapter(router, wakeDetector) + + vw := &VoiceWake{ + cfg: cfg, + state: VoiceWakeStateIdle, + vad: vad, + wakeDetector: wakeDetector, + engineRouter: router, + engineAdapter: adapter, + transcriber: cfg.STTPipeline, + } + + vad.RegisterListener(vw.onVADStateChanged) + + return vw +} + +func (vw *VoiceWake) RegisterListener(listener VoiceWakeListener) { + vw.mu.Lock() + defer vw.mu.Unlock() + vw.listeners = append(vw.listeners, listener) +} + +func (vw *VoiceWake) Start(ctx context.Context) error { + vw.mu.Lock() + if vw.state != VoiceWakeStateIdle { + vw.mu.Unlock() + return fmt.Errorf("voicewake: already in state %s", vw.state) + } + vw.state = VoiceWakeStateListening + vw.mu.Unlock() + + vw.ctx, vw.cancel = context.WithCancel(ctx) + + if vw.cfg.AudioSource != nil { + if err := vw.cfg.AudioSource.Start(vw.ctx); err != nil { + vw.mu.Lock() + vw.state = VoiceWakeStateIdle + vw.mu.Unlock() + return fmt.Errorf("voicewake: failed to start audio source: %w", err) + } + } + + vw.wg.Add(1) + go vw.listenLoop() + + vw.notifyListeners(VoiceWakeEvent{ + Type: VoiceWakeEventStateChanged, + State: VoiceWakeStateListening, + Timestamp: time.Now(), + Data: map[string]any{"message": "Voice wake listener started"}, + }) + + return nil +} + +func (vw *VoiceWake) Stop() error { + vw.mu.Lock() + if vw.state == VoiceWakeStateIdle { + vw.mu.Unlock() + return nil + } + + if vw.cancel != nil { + vw.cancel() + } + vw.state = VoiceWakeStateIdle + vw.mu.Unlock() + + if vw.cfg.AudioSource != nil { + _ = vw.cfg.AudioSource.Stop() + } + + vw.wg.Wait() + + vw.notifyListeners(VoiceWakeEvent{ + Type: VoiceWakeEventStateChanged, + State: VoiceWakeStateIdle, + Timestamp: time.Now(), + Data: map[string]any{"message": "Voice wake listener stopped"}, + }) + + return nil +} + +func (vw *VoiceWake) listenLoop() { + defer vw.wg.Done() + + samples := make([]int16, vw.cfg.FrameSize) + + for { + select { + case <-vw.ctx.Done(): + return + default: + } + + var n int + var err error + + if vw.cfg.AudioSource != nil { + n, err = vw.cfg.AudioSource.Read(samples) + if err != nil { + log.Printf("voicewake: error reading audio: %v", err) + time.Sleep(10 * time.Millisecond) + continue + } + } else { + time.Sleep(time.Duration(vw.cfg.FrameSize) * time.Second / time.Duration(vw.cfg.SampleRate)) + continue + } + + if n == 0 { + continue + } + + vw.mu.Lock() + inCooldown := time.Now().Before(vw.cooldownUntil) + vw.mu.Unlock() + + if inCooldown { + continue + } + + if vw.engineAdapter != nil && vw.engineAdapter.UseEngine() { + result, detected := vw.engineAdapter.ProcessFrame(samples[:n]) + if detected && result != nil { + vw.mu.Lock() + vw.lastWakeMatch = result.Keyword + vw.lastConfidence = result.Confidence + vw.cooldownUntil = time.Now().Add(vw.cfg.CooldownTime) + vw.mu.Unlock() + + vw.notifyListeners(VoiceWakeEvent{ + Type: VoiceWakeEventWakeDetected, + State: VoiceWakeStateTriggered, + Timestamp: time.Now(), + Data: map[string]any{ + "phrase": result.Keyword, + "confidence": result.Confidence, + "engine": string(result.Engine), + "energy": 0.0, + }, + }) + + vw.mu.Lock() + vw.setState(VoiceWakeStateTriggered) + vw.mu.Unlock() + + time.Sleep(vw.cfg.CooldownTime) + + vw.mu.Lock() + vw.setState(VoiceWakeStateListening) + vw.mu.Unlock() + + continue + } + } + + vw.processAudio(samples[:n]) + } +} + +func (vw *VoiceWake) processAudio(samples []int16) { + vw.mu.Lock() + vw.audioBuffer = append(vw.audioBuffer, samples...) + vw.mu.Unlock() + + state := vw.vad.ProcessFrame(samples) + + switch state { + case VADStateSpeech: + vw.mu.Lock() + if !vw.isRecording { + vw.isRecording = true + vw.recordingStart = time.Now() + vw.recordingBuffer = make([]int16, 0, vw.cfg.SampleRate*int(vw.cfg.MaxRecordingTime.Seconds())) + vw.setState(VoiceWakeStateRecording) + } + vw.recordingBuffer = append(vw.recordingBuffer, samples...) + vw.mu.Unlock() + + case VADStateSilence: + vw.mu.Lock() + if vw.isRecording { + vw.isRecording = false + recording := make([]int16, len(vw.recordingBuffer)) + copy(recording, vw.recordingBuffer) + vw.recordingBuffer = nil + vw.mu.Unlock() + + vw.processRecording(recording) + } else { + vw.mu.Unlock() + } + } +} + +func (vw *VoiceWake) processRecording(samples []int16) { + if len(samples) == 0 { + return + } + + vw.mu.Lock() + vw.setState(VoiceWakeStateProcessing) + vw.mu.Unlock() + + if vw.cfg.AutoTranscribe && vw.transcriber != nil { + audioData := Int16ToWAV(samples, vw.cfg.SampleRate, vw.cfg.Channels) + + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + result, err := vw.transcriber.TranscribeDirect(ctx, audioData, WithSTTInputFormat(InputWAV)) + if err != nil { + log.Printf("voicewake: transcription error: %v", err) + vw.notifyListeners(VoiceWakeEvent{ + Type: VoiceWakeEventError, + State: VoiceWakeStateProcessing, + Timestamp: time.Now(), + Data: map[string]any{"error": err.Error()}, + }) + vw.mu.Lock() + vw.setState(VoiceWakeStateListening) + vw.mu.Unlock() + return + } + + vw.mu.Lock() + vw.lastTranscript = result.Text + vw.mu.Unlock() + + vw.checkWakeWord(result.Text) + }() + } else { + vw.mu.Lock() + vw.setState(VoiceWakeStateListening) + vw.mu.Unlock() + } +} + +func (vw *VoiceWake) checkWakeWord(transcript string) { + if transcript == "" { + vw.mu.Lock() + vw.setState(VoiceWakeStateListening) + vw.mu.Unlock() + return + } + + phrase, confidence, matched := vw.wakeDetector.Detect(transcript) + + vw.mu.Lock() + vw.lastTranscript = transcript + vw.lastWakeMatch = phrase + vw.lastConfidence = confidence + vw.mu.Unlock() + + if matched { + vw.mu.Lock() + vw.setState(VoiceWakeStateTriggered) + vw.cooldownUntil = time.Now().Add(vw.cfg.CooldownTime) + energy := vw.lastEnergy + vw.mu.Unlock() + + vw.notifyListeners(VoiceWakeEvent{ + Type: VoiceWakeEventWakeDetected, + State: VoiceWakeStateTriggered, + Timestamp: time.Now(), + Data: map[string]any{ + "phrase": phrase, + "confidence": confidence, + "transcript": transcript, + "energy": energy, + }, + }) + + time.Sleep(vw.cfg.CooldownTime) + + vw.mu.Lock() + vw.setState(VoiceWakeStateListening) + vw.mu.Unlock() + } else { + vw.mu.Lock() + vw.setState(VoiceWakeStateListening) + vw.mu.Unlock() + } +} + +func (vw *VoiceWake) onVADStateChanged(state VADState, energy float64, zcr float64) { + vw.mu.Lock() + vw.lastEnergy = energy + vw.mu.Unlock() + + switch state { + case VADStateSpeech: + vw.notifyListeners(VoiceWakeEvent{ + Type: VoiceWakeEventSpeechStart, + State: vw.State(), + Timestamp: time.Now(), + Data: map[string]any{ + "energy": energy, + "zcr": zcr, + }, + }) + + case VADStateSilence: + vw.notifyListeners(VoiceWakeEvent{ + Type: VoiceWakeEventSpeechEnd, + State: vw.State(), + Timestamp: time.Now(), + Data: map[string]any{ + "energy": energy, + "zcr": zcr, + }, + }) + } +} + +func (vw *VoiceWake) setState(state VoiceWakeState) { + oldState := vw.state + vw.state = state + + if oldState != state { + vw.notifyListeners(VoiceWakeEvent{ + Type: VoiceWakeEventStateChanged, + State: state, + Timestamp: time.Now(), + Data: map[string]any{ + "previous_state": oldState, + "new_state": state, + }, + }) + } +} + +func (vw *VoiceWake) State() VoiceWakeState { + vw.mu.Lock() + defer vw.mu.Unlock() + return vw.state +} + +func (vw *VoiceWake) notifyListeners(event VoiceWakeEvent) { + vw.mu.Lock() + listeners := make([]VoiceWakeListener, len(vw.listeners)) + copy(listeners, vw.listeners) + vw.mu.Unlock() + + for _, listener := range listeners { + listener(event) + } +} + +func (vw *VoiceWake) LastTranscript() string { + vw.mu.Lock() + defer vw.mu.Unlock() + return vw.lastTranscript +} + +func (vw *VoiceWake) LastWakeMatch() (string, float64) { + vw.mu.Lock() + defer vw.mu.Unlock() + return vw.lastWakeMatch, vw.lastConfidence +} + +func (vw *VoiceWake) VAD() *VAD { + return vw.vad +} + +func (vw *VoiceWake) WakeDetector() *WakeWordDetector { + return vw.wakeDetector +} + +func (vw *VoiceWake) UpdateConfig(cfg VoiceWakeConfig) { + vw.mu.Lock() + defer vw.mu.Unlock() + + if cfg.SampleRate > 0 { + vw.cfg.SampleRate = cfg.SampleRate + } + if cfg.Channels > 0 { + vw.cfg.Channels = cfg.Channels + } + if cfg.FrameSize > 0 { + vw.cfg.FrameSize = cfg.FrameSize + } + if cfg.MaxRecordingTime > 0 { + vw.cfg.MaxRecordingTime = cfg.MaxRecordingTime + } + if cfg.CooldownTime > 0 { + vw.cfg.CooldownTime = cfg.CooldownTime + } + + vw.cfg.AutoTranscribe = cfg.AutoTranscribe +} + +func (vw *VoiceWake) Config() VoiceWakeConfig { + vw.mu.Lock() + defer vw.mu.Unlock() + return vw.cfg +} + +func (vw *VoiceWake) SetTranscriber(pipeline *STTPipeline) { + vw.mu.Lock() + defer vw.mu.Unlock() + vw.transcriber = pipeline +} + +func (vw *VoiceWake) RegisterEngine(engineType WakeWordEngineType, cfg WakeWordEngineConfig) error { + vw.mu.Lock() + router := vw.engineRouter + vw.mu.Unlock() + + if router == nil { + return fmt.Errorf("voicewake: no engine router available") + } + + if err := router.CreateEngine(engineType, cfg); err != nil { + return err + } + + vw.engineAdapter.SetUseEngine(true) + return nil +} + +func (vw *VoiceWake) SetActiveEngine(name string) error { + vw.mu.Lock() + router := vw.engineRouter + vw.mu.Unlock() + + if router == nil { + return fmt.Errorf("voicewake: no engine router available") + } + + return router.SetActive(name) +} + +func (vw *VoiceWake) UseWakeWordEngine(use bool) { + vw.mu.Lock() + defer vw.mu.Unlock() + if vw.engineAdapter != nil { + vw.engineAdapter.SetUseEngine(use) + } +} + +func (vw *VoiceWake) IsUsingWakeWordEngine() bool { + vw.mu.Lock() + defer vw.mu.Unlock() + if vw.engineAdapter == nil { + return false + } + return vw.engineAdapter.UseEngine() +} + +func (vw *VoiceWake) AvailableEngines() []string { + vw.mu.Lock() + defer vw.mu.Unlock() + if vw.engineRouter == nil { + return nil + } + return vw.engineRouter.Engines() +} + +func (vw *VoiceWake) ActiveEngine() string { + vw.mu.Lock() + defer vw.mu.Unlock() + if vw.engineRouter == nil { + return "" + } + return vw.engineRouter.ActiveEngine() +} + +func (vw *VoiceWake) EngineRouter() *WakeWordEngineRouter { + return vw.engineRouter +} + +func (vw *VoiceWake) EngineAdapter() *WakeWordEngineAdapter { + return vw.engineAdapter +} + +func (vw *VoiceWake) Close() error { + if err := vw.Stop(); err != nil { + return err + } + + vw.mu.Lock() + router := vw.engineRouter + vw.mu.Unlock() + + if router != nil { + return router.Close() + } + return nil +} diff --git a/pkg/speech/stt_google.go b/pkg/speech/stt_google.go index be0d14a7..ebb7145b 100644 --- a/pkg/speech/stt_google.go +++ b/pkg/speech/stt_google.go @@ -1,15 +1,19 @@ package speech import ( - "bytes" "context" - "encoding/base64" - "encoding/json" + "errors" "fmt" "io" - "net/http" "strings" "time" + + speechpb "cloud.google.com/go/speech/apiv1/speechpb" + "google.golang.org/api/googleapi" + "google.golang.org/api/option" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/durationpb" ) type GoogleModel string @@ -71,7 +75,7 @@ type GoogleProvider struct { useEnhanced bool timeout time.Duration retries int - client *http.Client + client googleRecognizeAPI } type GoogleOption func(*GoogleProvider) @@ -118,11 +122,13 @@ func WithGoogleCredentialsJSON(credentialsJSON string) GoogleOption { } } -func NewGoogleProvider(apiKey string, opts ...GoogleOption) (*GoogleProvider, error) { - if apiKey == "" { - return nil, NewSTTError(ErrAuthentication, "google: API key is required") +func withGoogleRecognizeClient(client googleRecognizeAPI) GoogleOption { + return func(p *GoogleProvider) { + p.client = client } +} +func NewGoogleProvider(apiKey string, opts ...GoogleOption) (*GoogleProvider, error) { p := &GoogleProvider{ apiKey: apiKey, baseURL: "https://speech.googleapis.com", @@ -130,18 +136,40 @@ func NewGoogleProvider(apiKey string, opts ...GoogleOption) (*GoogleProvider, er model: GoogleModelDefault, timeout: 120 * time.Second, retries: 2, - client: &http.Client{Timeout: 120 * time.Second}, } for _, opt := range opts { opt(p) } - p.client.Timeout = p.timeout + if p.apiKey == "" && p.credentialsJSON == "" { + return nil, NewSTTError(ErrAuthentication, "google: API key or credentials JSON is required") + } + + if p.client == nil { + client, err := newGoogleRecognizeClient(context.Background(), p.clientOptions()...) + if err != nil { + return nil, NewSTTErrorf(ErrAuthentication, "google-speech: failed to initialize official client: %v", err) + } + p.client = client + } return p, nil } +func (p *GoogleProvider) clientOptions() []option.ClientOption { + opts := make([]option.ClientOption, 0, 2) + if p.credentialsJSON != "" { + opts = append(opts, option.WithCredentialsJSON([]byte(p.credentialsJSON))) + } else { + opts = append(opts, option.WithAPIKey(p.apiKey)) + } + if p.baseURL != "" && p.baseURL != "https://speech.googleapis.com" { + opts = append(opts, option.WithEndpoint(p.baseURL)) + } + return opts +} + func (p *GoogleProvider) Name() string { return "google-speech" } @@ -219,6 +247,24 @@ func (p *GoogleProvider) TranscribeStream(ctx context.Context, reader io.Reader, } func (p *GoogleProvider) doTranscribe(ctx context.Context, audio []byte, options TranscribeOptions) (*TranscriptResult, error) { + req := p.buildRecognizeRequest(audio, options) + + requestCtx := ctx + var cancel context.CancelFunc + if _, hasDeadline := ctx.Deadline(); !hasDeadline && p.timeout > 0 { + requestCtx, cancel = context.WithTimeout(ctx, p.timeout) + defer cancel() + } + + resp, err := p.client.Recognize(requestCtx, req) + if err != nil { + return nil, p.handleClientError(err) + } + + return p.parseRecognizeResponse(resp, options) +} + +func (p *GoogleProvider) buildRecognizeRequest(audio []byte, options TranscribeOptions) *speechpb.RecognizeRequest { encoding := p.mapInputFormatToEncoding(options.InputFormat) sampleRate := int32(options.SampleRate) @@ -226,9 +272,9 @@ func (p *GoogleProvider) doTranscribe(ctx context.Context, audio []byte, options sampleRate = p.guessSampleRate(options.InputFormat) } - reqBody := googleRecognizeRequest{ - Config: googleRecognitionConfigRequest{ - Encoding: string(encoding), + return &speechpb.RecognizeRequest{ + Config: &speechpb.RecognitionConfig{ + Encoding: p.toProtoRecognitionEncoding(encoding), SampleRateHertz: sampleRate, LanguageCode: options.Language, Model: string(p.model), @@ -238,43 +284,10 @@ func (p *GoogleProvider) doTranscribe(ctx context.Context, audio []byte, options EnableWordConfidence: true, EnableAutomaticPunctuation: true, }, - Audio: googleAudioRequest{ - Content: base64.StdEncoding.EncodeToString(audio), + Audio: &speechpb.RecognitionAudio{ + AudioSource: &speechpb.RecognitionAudio_Content{Content: audio}, }, } - - body, err := json.Marshal(reqBody) - if err != nil { - return nil, NewSTTErrorf(ErrTranscriptionFailed, "google-speech: failed to marshal request: %v", err) - } - - url := fmt.Sprintf("%s/v1/speech:recognize?key=%s", p.baseURL, p.apiKey) - - req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body)) - if err != nil { - return nil, NewSTTErrorf(ErrTranscriptionFailed, "google-speech: failed to create request: %v", err) - } - - req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", "anyclaw-stt/1.0") - - resp, err := p.client.Do(req) - if err != nil { - return nil, NewSTTErrorf(ErrTranscriptionFailed, "google-speech: request failed: %v", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - respBody, _ := io.ReadAll(resp.Body) - return nil, p.handleErrorResponse(resp.StatusCode, respBody) - } - - respBody, err := io.ReadAll(resp.Body) - if err != nil { - return nil, NewSTTErrorf(ErrTranscriptionFailed, "google-speech: failed to read response: %v", err) - } - - return p.parseResponse(respBody, options) } func (p *GoogleProvider) mapInputFormatToEncoding(format AudioInputFormat) RecognitionEncoding { @@ -285,7 +298,9 @@ func (p *GoogleProvider) mapInputFormatToEncoding(format AudioInputFormat) Recog return EncodingFLAC case InputMP3: return EncodingMP3 - case InputOGG, InputWEBM: + case InputOGG: + return EncodingOGGOpus + case InputWEBM: return EncodingWEBMOpus case InputM4A, InputMP4: return EncodingWEBMOpus @@ -296,6 +311,31 @@ func (p *GoogleProvider) mapInputFormatToEncoding(format AudioInputFormat) Recog } } +func (p *GoogleProvider) toProtoRecognitionEncoding(encoding RecognitionEncoding) speechpb.RecognitionConfig_AudioEncoding { + switch encoding { + case EncodingLinear16: + return speechpb.RecognitionConfig_LINEAR16 + case EncodingFLAC: + return speechpb.RecognitionConfig_FLAC + case EncodingMULAW: + return speechpb.RecognitionConfig_MULAW + case EncodingAMR: + return speechpb.RecognitionConfig_AMR + case EncodingAMRWB: + return speechpb.RecognitionConfig_AMR_WB + case EncodingOGGOpus: + return speechpb.RecognitionConfig_OGG_OPUS + case EncodingSpeexWithHeaderByte: + return speechpb.RecognitionConfig_SPEEX_WITH_HEADER_BYTE + case EncodingWEBMOpus: + return speechpb.RecognitionConfig_WEBM_OPUS + case EncodingMP3: + return speechpb.RecognitionConfig_MP3 + default: + return speechpb.RecognitionConfig_ENCODING_UNSPECIFIED + } +} + func (p *GoogleProvider) guessSampleRate(format AudioInputFormat) int32 { switch format { case InputWAV, InputPCM: @@ -313,104 +353,41 @@ func (p *GoogleProvider) guessSampleRate(format AudioInputFormat) int32 { } } -func (p *GoogleProvider) handleErrorResponse(statusCode int, body []byte) error { - var errResp googleErrorResponse - if err := json.Unmarshal(body, &errResp); err == nil && errResp.Error.Message != "" { - msg := fmt.Sprintf("google-speech: API error: %s (status: %s)", errResp.Error.Message, errResp.Error.Status) - switch statusCode { - case http.StatusUnauthorized, http.StatusForbidden: +func (p *GoogleProvider) handleClientError(err error) error { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return NewSTTErrorf(ErrTranscriptionFailed, "google-speech: request context error: %v", err) + } + + var apiErr *googleapi.Error + if errors.As(err, &apiErr) { + msg := fmt.Sprintf("google-speech: API error: %s", apiErr.Message) + switch apiErr.Code { + case 400: + return NewSTTError(ErrAudioFormatInvalid, msg) + case 401, 403: return NewSTTError(ErrAuthentication, msg) - case http.StatusTooManyRequests: + case 429: return NewSTTError(ErrRateLimited, msg) - case http.StatusBadRequest: - return NewSTTError(ErrAudioFormatInvalid, msg) default: return NewSTTError(ErrTranscriptionFailed, msg) } } - switch statusCode { - case http.StatusUnauthorized, http.StatusForbidden: - return NewSTTError(ErrAuthentication, fmt.Sprintf("google-speech: authentication failed: %s", string(body))) - case http.StatusTooManyRequests: - return NewSTTError(ErrRateLimited, fmt.Sprintf("google-speech: rate limited: %s", string(body))) - case http.StatusBadRequest: - return NewSTTError(ErrAudioFormatInvalid, fmt.Sprintf("google-speech: invalid request: %s", string(body))) - case http.StatusServiceUnavailable: - return NewSTTError(ErrTranscriptionFailed, fmt.Sprintf("google-speech: service unavailable: %s", string(body))) + switch status.Code(err) { + case codes.InvalidArgument: + return NewSTTError(ErrAudioFormatInvalid, "google-speech: invalid recognition request") + case codes.Unauthenticated, codes.PermissionDenied: + return NewSTTError(ErrAuthentication, "google-speech: authentication failed") + case codes.ResourceExhausted: + return NewSTTError(ErrRateLimited, "google-speech: rate limited") default: - return NewSTTErrorf(ErrTranscriptionFailed, "google-speech: unexpected status %d: %s", statusCode, string(body)) + return NewSTTErrorf(ErrTranscriptionFailed, "google-speech: request failed: %v", err) } } -type googleRecognizeRequest struct { - Config googleRecognitionConfigRequest `json:"config"` - Audio googleAudioRequest `json:"audio"` -} - -type googleRecognitionConfigRequest struct { - Encoding string `json:"encoding"` - SampleRateHertz int32 `json:"sampleRateHertz"` - LanguageCode string `json:"languageCode"` - Model string `json:"model,omitempty"` - UseEnhanced bool `json:"useEnhanced,omitempty"` - MaxAlternatives int32 `json:"maxAlternatives,omitempty"` - EnableWordTimeOffsets bool `json:"enableWordTimeOffsets,omitempty"` - EnableWordConfidence bool `json:"enableWordConfidence,omitempty"` - EnableAutomaticPunctuation bool `json:"enableAutomaticPunctuation,omitempty"` - EnableSpokenPunctuation bool `json:"enableSpokenPunctuation,omitempty"` -} - -type googleAudioRequest struct { - Content string `json:"content"` -} - -type googleResponse struct { - Results []googleResult `json:"results"` -} - -type googleResult struct { - Alternatives []googleAlternative `json:"alternatives"` - LanguageCode string `json:"languageCode"` - ResultEndTime struct { - Seconds string `json:"seconds"` - Nanos int `json:"nanos"` - } `json:"resultEndTime"` -} - -type googleAlternative struct { - Transcript string `json:"transcript"` - Confidence float64 `json:"confidence"` - Words []googleWordInfo `json:"words"` -} - -type googleWordInfo struct { - StartTime googleDuration `json:"startTime"` - EndTime googleDuration `json:"endTime"` - Word string `json:"word"` - Confidence float64 `json:"confidence"` -} - -type googleDuration struct { - Seconds string `json:"seconds"` - Nanos int `json:"nanos"` -} - -type googleErrorResponse struct { - Error struct { - Code int `json:"code"` - Message string `json:"message"` - Status string `json:"status"` - } `json:"error"` -} - -func (p *GoogleProvider) parseResponse(body []byte, options TranscribeOptions) (*TranscriptResult, error) { - var resp googleResponse - if err := json.Unmarshal(body, &resp); err != nil { - return nil, NewSTTErrorf(ErrTranscriptionFailed, "google-speech: failed to parse JSON response: %v", err) - } - - if len(resp.Results) == 0 { +func (p *GoogleProvider) parseRecognizeResponse(resp *speechpb.RecognizeResponse, options TranscribeOptions) (*TranscriptResult, error) { + results := resp.GetResults() + if len(results) == 0 { return &TranscriptResult{ Text: "", Language: options.Language, @@ -420,35 +397,40 @@ func (p *GoogleProvider) parseResponse(body []byte, options TranscribeOptions) ( result := &TranscriptResult{} var totalConfidence float64 var confidenceCount int + var lastEnd time.Duration - for i, res := range resp.Results { - if len(res.Alternatives) == 0 { + for i, res := range results { + if len(res.GetAlternatives()) == 0 { continue } - primary := res.Alternatives[0] - + primary := res.GetAlternatives()[0] segment := SegmentInfo{ ID: i, - Text: primary.Transcript, + Text: primary.GetTranscript(), } - if primary.Confidence > 0 { - segment.Confidence = primary.Confidence - totalConfidence += primary.Confidence + if confidence := primary.GetConfidence(); confidence > 0 { + segment.Confidence = float64(confidence) + totalConfidence += segment.Confidence confidenceCount++ } - if len(primary.Words) > 0 { - segment.Words = make([]WordInfo, 0, len(primary.Words)) - for _, w := range primary.Words { - segment.Words = append(segment.Words, WordInfo{ - Word: w.Word, - StartTime: parseGoogleDuration(w.StartTime), - EndTime: parseGoogleDuration(w.EndTime), - Confidence: w.Confidence, - }) + if len(primary.GetWords()) > 0 { + segment.Words = make([]WordInfo, 0, len(primary.GetWords())) + for _, word := range primary.GetWords() { + wordInfo := WordInfo{ + Word: word.GetWord(), + StartTime: parseProtoDuration(word.GetStartTime()), + EndTime: parseProtoDuration(word.GetEndTime()), + Confidence: float64(word.GetConfidence()), + } + segment.Words = append(segment.Words, wordInfo) } + segment.StartTime = segment.Words[0].StartTime + segment.EndTime = segment.Words[len(segment.Words)-1].EndTime + } else { + segment.EndTime = parseProtoDuration(res.GetResultEndTime()) } if options.WordTimestamps && len(segment.Words) > 0 { @@ -458,40 +440,47 @@ func (p *GoogleProvider) parseResponse(body []byte, options TranscribeOptions) ( result.Segments = append(result.Segments, segment) if i == 0 { - result.Text = primary.Transcript - result.Language = res.LanguageCode + result.Text = primary.GetTranscript() + if lang := res.GetLanguageCode(); lang != "" { + result.Language = lang + } } else { - result.Text += " " + primary.Transcript + result.Text += " " + primary.GetTranscript() } - if options.MaxAlternatives > 1 && len(res.Alternatives) > 1 { - for _, alt := range res.Alternatives[1:] { - result.Alternatives = append(result.Alternatives, alt.Transcript) + if options.MaxAlternatives > 1 && len(res.GetAlternatives()) > 1 { + for _, alt := range res.GetAlternatives()[1:] { + result.Alternatives = append(result.Alternatives, alt.GetTranscript()) } } + + if segment.EndTime > lastEnd { + lastEnd = segment.EndTime + } + } + + if result.Language == "" { + result.Language = options.Language } if confidenceCount > 0 { result.Confidence = totalConfidence / float64(confidenceCount) } - if len(resp.Results) > 0 { - endTime := resp.Results[len(resp.Results)-1].ResultEndTime - result.Duration = parseGoogleDuration(googleDuration{ - Seconds: endTime.Seconds, - Nanos: endTime.Nanos, - }) + if lastEnd > 0 { + result.Duration = lastEnd + } else { + result.Duration = parseProtoDuration(resp.GetTotalBilledTime()) } return result, nil } -func parseGoogleDuration(d googleDuration) time.Duration { - var seconds int64 - if d.Seconds != "" { - fmt.Sscanf(d.Seconds, "%d", &seconds) +func parseProtoDuration(d *durationpb.Duration) time.Duration { + if d == nil { + return 0 } - return time.Duration(seconds)*time.Second + time.Duration(d.Nanos)*time.Nanosecond + return d.AsDuration() } func (p *GoogleProvider) ListLanguages(ctx context.Context) ([]string, error) { diff --git a/pkg/speech/stt_google_client.go b/pkg/speech/stt_google_client.go new file mode 100644 index 00000000..aed01cb6 --- /dev/null +++ b/pkg/speech/stt_google_client.go @@ -0,0 +1,34 @@ +package speech + +import ( + "context" + + speechapi "cloud.google.com/go/speech/apiv1" + speechpb "cloud.google.com/go/speech/apiv1/speechpb" + "google.golang.org/api/option" +) + +type googleRecognizeAPI interface { + Recognize(context.Context, *speechpb.RecognizeRequest) (*speechpb.RecognizeResponse, error) + Close() error +} + +type googleRecognizeClient struct { + client *speechapi.Client +} + +func newGoogleRecognizeClient(ctx context.Context, clientOpts ...option.ClientOption) (googleRecognizeAPI, error) { + client, err := speechapi.NewRESTClient(ctx, clientOpts...) + if err != nil { + return nil, err + } + return &googleRecognizeClient{client: client}, nil +} + +func (c *googleRecognizeClient) Recognize(ctx context.Context, req *speechpb.RecognizeRequest) (*speechpb.RecognizeResponse, error) { + return c.client.Recognize(ctx, req) +} + +func (c *googleRecognizeClient) Close() error { + return c.client.Close() +} diff --git a/pkg/speech/stt_google_test.go b/pkg/speech/stt_google_test.go index b9fea4a2..cb7fbdac 100644 --- a/pkg/speech/stt_google_test.go +++ b/pkg/speech/stt_google_test.go @@ -2,19 +2,41 @@ package speech import ( "context" - "encoding/json" - "net/http" - "net/http/httptest" + "errors" + "math" "strings" "testing" "time" + + speechpb "cloud.google.com/go/speech/apiv1/speechpb" + "google.golang.org/api/googleapi" + "google.golang.org/protobuf/types/known/durationpb" ) +type fakeGoogleRecognizeClient struct { + calls int + lastRequest *speechpb.RecognizeRequest + recognizeFn func(context.Context, *speechpb.RecognizeRequest) (*speechpb.RecognizeResponse, error) +} + +func (f *fakeGoogleRecognizeClient) Recognize(ctx context.Context, req *speechpb.RecognizeRequest) (*speechpb.RecognizeResponse, error) { + f.calls++ + f.lastRequest = req + if f.recognizeFn != nil { + return f.recognizeFn(ctx, req) + } + return &speechpb.RecognizeResponse{}, nil +} + +func (f *fakeGoogleRecognizeClient) Close() error { + return nil +} + func TestNewGoogleProvider(t *testing.T) { - t.Run("requires API key", func(t *testing.T) { - _, err := NewGoogleProvider("") + t.Run("requires API key or credentials JSON", func(t *testing.T) { + _, err := NewGoogleProvider("", withGoogleRecognizeClient(&fakeGoogleRecognizeClient{})) if err == nil { - t.Fatal("expected error when API key is empty") + t.Fatal("expected error when auth config is empty") } sttErr, ok := err.(*STTError) if !ok { @@ -26,7 +48,8 @@ func TestNewGoogleProvider(t *testing.T) { }) t.Run("creates provider with defaults", func(t *testing.T) { - p, err := NewGoogleProvider("test-key") + fake := &fakeGoogleRecognizeClient{} + p, err := NewGoogleProvider("test-key", withGoogleRecognizeClient(fake)) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -48,16 +71,21 @@ func TestNewGoogleProvider(t *testing.T) { if p.retries != 2 { t.Errorf("expected 2 retries, got %d", p.retries) } + if p.client != fake { + t.Fatal("expected injected fake client to be used") + } }) t.Run("applies options", func(t *testing.T) { p, err := NewGoogleProvider("test-key", - WithGoogleBaseURL("https://custom.speech.api.com"), + withGoogleRecognizeClient(&fakeGoogleRecognizeClient{}), + WithGoogleBaseURL("https://custom.speech.api.com/"), WithGoogleLanguageCode("zh-CN"), WithGoogleModel(GoogleModelLatestLong), WithGoogleEnhanced(true), WithGoogleTimeout(30*time.Second), WithGoogleRetries(5), + WithGoogleCredentialsJSON(`{"type":"service_account"}`), ) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -80,25 +108,15 @@ func TestNewGoogleProvider(t *testing.T) { if p.retries != 5 { t.Errorf("expected 5 retries, got %d", p.retries) } - }) - - t.Run("trims trailing slash from baseURL", func(t *testing.T) { - p, err := NewGoogleProvider("test-key", WithGoogleBaseURL("https://api.example.com/")) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if strings.HasSuffix(p.baseURL, "/") { - t.Errorf("baseURL should not have trailing slash: %s", p.baseURL) + if p.credentialsJSON == "" { + t.Error("expected credentials JSON to be stored") } }) } func TestGoogleProviderTranscribe(t *testing.T) { t.Run("rejects empty audio", func(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) - defer server.Close() - - p, _ := NewGoogleProvider("test-key", WithGoogleBaseURL(server.URL)) + p, _ := NewGoogleProvider("test-key", withGoogleRecognizeClient(&fakeGoogleRecognizeClient{})) _, err := p.Transcribe(context.Background(), nil) if err == nil { t.Fatal("expected error for empty audio") @@ -106,10 +124,7 @@ func TestGoogleProviderTranscribe(t *testing.T) { }) t.Run("rejects audio too large", func(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) - defer server.Close() - - p, _ := NewGoogleProvider("test-key", WithGoogleBaseURL(server.URL)) + p, _ := NewGoogleProvider("test-key", withGoogleRecognizeClient(&fakeGoogleRecognizeClient{})) largeAudio := make([]byte, 101*1024*1024) _, err := p.Transcribe(context.Background(), largeAudio) if err == nil { @@ -124,49 +139,46 @@ func TestGoogleProviderTranscribe(t *testing.T) { } }) - t.Run("successful transcription", func(t *testing.T) { - response := googleResponse{ - Results: []googleResult{ - { - Alternatives: []googleAlternative{ + t.Run("successful transcription and request mapping", func(t *testing.T) { + fake := &fakeGoogleRecognizeClient{ + recognizeFn: func(ctx context.Context, req *speechpb.RecognizeRequest) (*speechpb.RecognizeResponse, error) { + return &speechpb.RecognizeResponse{ + Results: []*speechpb.SpeechRecognitionResult{ { - Transcript: "Hello world", - Confidence: 0.95, - Words: []googleWordInfo{ - {Word: "Hello", Confidence: 0.96, StartTime: googleDuration{Seconds: "0", Nanos: 0}, EndTime: googleDuration{Seconds: "0", Nanos: 500000000}}, - {Word: "world", Confidence: 0.94, StartTime: googleDuration{Seconds: "0", Nanos: 600000000}, EndTime: googleDuration{Seconds: "1", Nanos: 0}}, + Alternatives: []*speechpb.SpeechRecognitionAlternative{ + { + Transcript: "Hello world", + Confidence: 0.95, + Words: []*speechpb.WordInfo{ + { + Word: "Hello", + Confidence: 0.96, + StartTime: durationpb.New(0), + EndTime: durationpb.New(500 * time.Millisecond), + }, + { + Word: "world", + Confidence: 0.94, + StartTime: durationpb.New(600 * time.Millisecond), + EndTime: durationpb.New(time.Second), + }, + }, + }, }, + LanguageCode: "en-US", + ResultEndTime: durationpb.New(2500 * time.Millisecond), }, }, - LanguageCode: "en-US", - ResultEndTime: struct { - Seconds string `json:"seconds"` - Nanos int `json:"nanos"` - }{Seconds: "2", Nanos: 500000000}, - }, + }, nil }, } - respBody, _ := json.Marshal(response) - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != "POST" { - t.Errorf("expected POST, got %s", r.Method) - } - if !strings.Contains(r.URL.Query().Get("key"), "test-key") { - t.Error("missing API key in query") - } - if r.Header.Get("Content-Type") != "application/json" { - t.Errorf("expected application/json, got %s", r.Header.Get("Content-Type")) - } - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - w.Write(respBody) - })) - defer server.Close() - - p, _ := NewGoogleProvider("test-key", WithGoogleBaseURL(server.URL)) - result, err := p.Transcribe(context.Background(), []byte("fake-audio-data")) + p, _ := NewGoogleProvider("test-key", withGoogleRecognizeClient(fake)) + result, err := p.Transcribe(context.Background(), []byte("fake-audio-data"), + WithSTTLanguage("zh-CN"), + WithSTTWordTimestamps(true), + WithSTTMaxAlternatives(3), + ) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -177,8 +189,8 @@ func TestGoogleProviderTranscribe(t *testing.T) { if result.Language != "en-US" { t.Errorf("expected language 'en-US', got '%s'", result.Language) } - if result.Duration != 2500*time.Millisecond { - t.Errorf("expected duration 2.5s, got %v", result.Duration) + if result.Duration != time.Second { + t.Errorf("expected duration 1s from word timestamps, got %v", result.Duration) } if len(result.Segments) != 1 { t.Fatalf("expected 1 segment, got %d", len(result.Segments)) @@ -186,243 +198,151 @@ func TestGoogleProviderTranscribe(t *testing.T) { if len(result.Segments[0].Words) != 2 { t.Fatalf("expected 2 words, got %d", len(result.Segments[0].Words)) } - if result.Confidence != 0.95 { + if math.Abs(result.Confidence-0.95) > 0.0001 { t.Errorf("expected confidence 0.95, got %f", result.Confidence) } - }) - t.Run("multiple segments", func(t *testing.T) { - response := googleResponse{ - Results: []googleResult{ - {Alternatives: []googleAlternative{{Transcript: "First segment"}}, LanguageCode: "en-US"}, - {Alternatives: []googleAlternative{{Transcript: "Second segment"}}, LanguageCode: "en-US"}, - }, + req := fake.lastRequest + if req == nil { + t.Fatal("expected request to be captured") } - - respBody, _ := json.Marshal(response) - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.Write(respBody) - })) - defer server.Close() - - p, _ := NewGoogleProvider("test-key", WithGoogleBaseURL(server.URL)) - result, err := p.Transcribe(context.Background(), []byte("fake-audio")) - if err != nil { - t.Fatalf("unexpected error: %v", err) + if req.GetConfig().GetLanguageCode() != "zh-CN" { + t.Errorf("expected request language zh-CN, got %s", req.GetConfig().GetLanguageCode()) } - - expectedText := "First segment Second segment" - if result.Text != expectedText { - t.Errorf("expected '%s', got '%s'", expectedText, result.Text) + if !req.GetConfig().GetEnableWordTimeOffsets() { + t.Error("expected EnableWordTimeOffsets to be true") } - if len(result.Segments) != 2 { - t.Fatalf("expected 2 segments, got %d", len(result.Segments)) + if req.GetConfig().GetMaxAlternatives() != 3 { + t.Errorf("expected max alternatives 3, got %d", req.GetConfig().GetMaxAlternatives()) + } + if req.GetConfig().GetEncoding() != speechpb.RecognitionConfig_MP3 { + t.Errorf("expected MP3 encoding, got %v", req.GetConfig().GetEncoding()) + } + if len(req.GetAudio().GetContent()) == 0 { + t.Error("expected inline audio content to be populated") } }) - t.Run("alternatives", func(t *testing.T) { - response := googleResponse{ - Results: []googleResult{ - { - Alternatives: []googleAlternative{ - {Transcript: "Hello world", Confidence: 0.95}, - {Transcript: "Hello word", Confidence: 0.80}, - {Transcript: "Halo world", Confidence: 0.70}, + t.Run("multiple segments and alternatives", func(t *testing.T) { + fake := &fakeGoogleRecognizeClient{ + recognizeFn: func(ctx context.Context, req *speechpb.RecognizeRequest) (*speechpb.RecognizeResponse, error) { + return &speechpb.RecognizeResponse{ + Results: []*speechpb.SpeechRecognitionResult{ + { + Alternatives: []*speechpb.SpeechRecognitionAlternative{ + {Transcript: "First segment", Confidence: 0.9}, + {Transcript: "First segments", Confidence: 0.7}, + }, + LanguageCode: "en-US", + ResultEndTime: durationpb.New(time.Second), + }, + { + Alternatives: []*speechpb.SpeechRecognitionAlternative{ + {Transcript: "Second segment", Confidence: 0.8}, + {Transcript: "Second segments", Confidence: 0.6}, + }, + LanguageCode: "en-US", + ResultEndTime: durationpb.New(2 * time.Second), + }, }, - LanguageCode: "en-US", - }, + }, nil }, } - respBody, _ := json.Marshal(response) - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.Write(respBody) - })) - defer server.Close() - - p, _ := NewGoogleProvider("test-key", WithGoogleBaseURL(server.URL)) - result, err := p.Transcribe(context.Background(), []byte("fake-audio"), - WithSTTMaxAlternatives(3)) + p, _ := NewGoogleProvider("test-key", withGoogleRecognizeClient(fake)) + result, err := p.Transcribe(context.Background(), []byte("fake-audio"), WithSTTMaxAlternatives(3)) if err != nil { t.Fatalf("unexpected error: %v", err) } + if result.Text != "First segment Second segment" { + t.Errorf("unexpected combined text: %s", result.Text) + } + if len(result.Segments) != 2 { + t.Fatalf("expected 2 segments, got %d", len(result.Segments)) + } if len(result.Alternatives) != 2 { t.Fatalf("expected 2 alternatives, got %d", len(result.Alternatives)) } - if result.Alternatives[0] != "Hello word" { - t.Errorf("expected first alternative 'Hello word', got '%s'", result.Alternatives[0]) + if result.Duration != 2*time.Second { + t.Errorf("expected duration 2s, got %v", result.Duration) } }) t.Run("empty results", func(t *testing.T) { - response := googleResponse{Results: []googleResult{}} - respBody, _ := json.Marshal(response) - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.Write(respBody) - })) - defer server.Close() - - p, _ := NewGoogleProvider("test-key", WithGoogleBaseURL(server.URL)) + p, _ := NewGoogleProvider("test-key", withGoogleRecognizeClient(&fakeGoogleRecognizeClient{})) result, err := p.Transcribe(context.Background(), []byte("fake-audio")) if err != nil { t.Fatalf("unexpected error: %v", err) } if result.Text != "" { - t.Errorf("expected empty text, got '%s'", result.Text) + t.Errorf("expected empty text, got %q", result.Text) } }) - t.Run("handles authentication error", func(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusUnauthorized) - w.Write([]byte(`{"error":{"code":401,"message":"API key not valid. Please pass a valid API key.","status":"UNAUTHENTICATED"}}`)) - })) - defer server.Close() - - p, _ := NewGoogleProvider("bad-key", WithGoogleBaseURL(server.URL), WithGoogleRetries(0)) - _, err := p.Transcribe(context.Background(), []byte("fake-audio")) - if err == nil { - t.Fatal("expected authentication error") - } - sttErr, ok := err.(*STTError) - if !ok { - t.Fatalf("expected *STTError, got %T", err) - } - if sttErr.Code != ErrAuthentication { - t.Errorf("expected ErrAuthentication, got %s", sttErr.Code) + t.Run("does not retry auth errors", func(t *testing.T) { + fake := &fakeGoogleRecognizeClient{ + recognizeFn: func(ctx context.Context, req *speechpb.RecognizeRequest) (*speechpb.RecognizeResponse, error) { + return nil, &googleapi.Error{Code: 401, Message: "invalid API key"} + }, } - }) - - t.Run("handles forbidden error", func(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusForbidden) - w.Write([]byte(`{"error":{"code":403,"message":"API key expired.","status":"PERMISSION_DENIED"}}`)) - })) - defer server.Close() - - p, _ := NewGoogleProvider("expired-key", WithGoogleBaseURL(server.URL), WithGoogleRetries(0)) + p, _ := NewGoogleProvider("bad-key", withGoogleRecognizeClient(fake), WithGoogleRetries(3)) _, err := p.Transcribe(context.Background(), []byte("fake-audio")) if err == nil { t.Fatal("expected authentication error") } - sttErr, ok := err.(*STTError) - if !ok { - t.Fatalf("expected *STTError, got %T", err) - } - if sttErr.Code != ErrAuthentication { - t.Errorf("expected ErrAuthentication, got %s", sttErr.Code) - } - }) - - t.Run("handles rate limit error", func(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusTooManyRequests) - w.Write([]byte(`{"error":{"code":429,"message":"Quota exceeded.","status":"RESOURCE_EXHAUSTED"}}`)) - })) - defer server.Close() - - p, _ := NewGoogleProvider("test-key", WithGoogleBaseURL(server.URL), WithGoogleRetries(0)) - _, err := p.Transcribe(context.Background(), []byte("fake-audio")) - if err == nil { - t.Fatal("expected rate limit error") - } - sttErr, ok := err.(*STTError) - if !ok { - t.Fatalf("expected *STTError, got %T", err) - } - if sttErr.Code != ErrRateLimited { - t.Errorf("expected ErrRateLimited, got %s", sttErr.Code) + if fake.calls != 1 { + t.Errorf("expected 1 call, got %d", fake.calls) } }) - t.Run("context cancellation", func(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusServiceUnavailable) - })) - defer server.Close() - - p, _ := NewGoogleProvider("test-key", WithGoogleBaseURL(server.URL), WithGoogleRetries(1)) - - ctx, cancel := context.WithCancel(context.Background()) - cancel() - - _, err := p.Transcribe(ctx, []byte("fake-audio")) - if err == nil { - t.Fatal("expected context cancellation error") + t.Run("retries transient errors then succeeds", func(t *testing.T) { + fake := &fakeGoogleRecognizeClient{} + fake.recognizeFn = func(ctx context.Context, req *speechpb.RecognizeRequest) (*speechpb.RecognizeResponse, error) { + if fake.calls == 1 { + return nil, &googleapi.Error{Code: 503, Message: "service unavailable"} + } + return &speechpb.RecognizeResponse{ + Results: []*speechpb.SpeechRecognitionResult{ + { + Alternatives: []*speechpb.SpeechRecognitionAlternative{{Transcript: "Success after retry"}}, + LanguageCode: "en-US", + ResultEndTime: durationpb.New(time.Second), + }, + }, + }, nil } - }) - t.Run("uses correct URL with API key", func(t *testing.T) { - var receivedURL string - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - receivedURL = r.URL.String() - w.Header().Set("Content-Type", "application/json") - w.Write([]byte(`{"results":[{"alternatives":[{"transcript":"test"}],"languageCode":"en-US"}]}`)) - })) - defer server.Close() - - p, _ := NewGoogleProvider("my-api-key", WithGoogleBaseURL(server.URL)) - _, err := p.Transcribe(context.Background(), []byte("fake-audio")) + p, _ := NewGoogleProvider("test-key", withGoogleRecognizeClient(fake), WithGoogleRetries(2)) + result, err := p.Transcribe(context.Background(), []byte("fake-audio")) if err != nil { t.Fatalf("unexpected error: %v", err) } - if !strings.Contains(receivedURL, "key=my-api-key") { - t.Errorf("expected URL to contain 'key=my-api-key', got %s", receivedURL) + if result.Text != "Success after retry" { + t.Errorf("expected 'Success after retry', got '%s'", result.Text) } - if !strings.Contains(receivedURL, "/v1/speech:recognize") { - t.Errorf("expected URL to contain '/v1/speech:recognize', got %s", receivedURL) + if fake.calls != 2 { + t.Errorf("expected 2 calls, got %d", fake.calls) } }) - t.Run("sends correct request body", func(t *testing.T) { - var receivedBody googleRecognizeRequest - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - json.NewDecoder(r.Body).Decode(&receivedBody) - w.Header().Set("Content-Type", "application/json") - w.Write([]byte(`{"results":[{"alternatives":[{"transcript":"test"}],"languageCode":"en-US"}]}`)) - })) - defer server.Close() - - p, _ := NewGoogleProvider("test-key", WithGoogleBaseURL(server.URL)) - _, err := p.Transcribe(context.Background(), []byte("fake-audio"), - WithSTTLanguage("zh-CN"), - WithSTTWordTimestamps(true), - WithSTTMaxAlternatives(3)) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if receivedBody.Config.LanguageCode != "zh-CN" { - t.Errorf("expected language zh-CN, got %s", receivedBody.Config.LanguageCode) - } - if !receivedBody.Config.EnableWordTimeOffsets { - t.Error("expected EnableWordTimeOffsets to be true") - } - if receivedBody.Config.MaxAlternatives != 3 { - t.Errorf("expected maxAlternatives 3, got %d", receivedBody.Config.MaxAlternatives) + t.Run("context cancellation", func(t *testing.T) { + fake := &fakeGoogleRecognizeClient{ + recognizeFn: func(ctx context.Context, req *speechpb.RecognizeRequest) (*speechpb.RecognizeResponse, error) { + return nil, context.Canceled + }, } - if receivedBody.Config.EnableAutomaticPunctuation != true { - t.Error("expected EnableAutomaticPunctuation to be true") + p, _ := NewGoogleProvider("test-key", withGoogleRecognizeClient(fake), WithGoogleRetries(0)) + _, err := p.Transcribe(context.Background(), []byte("fake-audio")) + if err == nil { + t.Fatal("expected context cancellation error") } }) } func TestGoogleProviderTranscribeStream(t *testing.T) { t.Run("rejects nil reader", func(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) - defer server.Close() - - p, _ := NewGoogleProvider("test-key", WithGoogleBaseURL(server.URL)) + p, _ := NewGoogleProvider("test-key", withGoogleRecognizeClient(&fakeGoogleRecognizeClient{})) _, err := p.TranscribeStream(context.Background(), nil) if err == nil { t.Fatal("expected error for nil reader") @@ -430,13 +350,19 @@ func TestGoogleProviderTranscribeStream(t *testing.T) { }) t.Run("successful stream transcription", func(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.Write([]byte(`{"results":[{"alternatives":[{"transcript":"Stream content","confidence":0.9}],"languageCode":"en-US"}]}`)) - })) - defer server.Close() - - p, _ := NewGoogleProvider("test-key", WithGoogleBaseURL(server.URL)) + fake := &fakeGoogleRecognizeClient{ + recognizeFn: func(ctx context.Context, req *speechpb.RecognizeRequest) (*speechpb.RecognizeResponse, error) { + return &speechpb.RecognizeResponse{ + Results: []*speechpb.SpeechRecognitionResult{ + { + Alternatives: []*speechpb.SpeechRecognitionAlternative{{Transcript: "Stream content", Confidence: 0.9}}, + LanguageCode: "en-US", + }, + }, + }, nil + }, + } + p, _ := NewGoogleProvider("test-key", withGoogleRecognizeClient(fake)) reader := strings.NewReader("stream-audio-data") result, err := p.TranscribeStream(context.Background(), reader) if err != nil { @@ -449,27 +375,22 @@ func TestGoogleProviderTranscribeStream(t *testing.T) { } func TestGoogleProviderTranscribeFile(t *testing.T) { - t.Run("returns not supported error", func(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) - defer server.Close() - - p, _ := NewGoogleProvider("test-key", WithGoogleBaseURL(server.URL)) - _, err := p.TranscribeFile(context.Background(), "/some/file.mp3") - if err == nil { - t.Fatal("expected error for file transcription") - } - sttErr, ok := err.(*STTError) - if !ok { - t.Fatalf("expected *STTError, got %T", err) - } - if sttErr.Code != ErrProviderNotSupported { - t.Errorf("expected ErrProviderNotSupported, got %s", sttErr.Code) - } - }) + p, _ := NewGoogleProvider("test-key", withGoogleRecognizeClient(&fakeGoogleRecognizeClient{})) + _, err := p.TranscribeFile(context.Background(), "/some/file.mp3") + if err == nil { + t.Fatal("expected error for file transcription") + } + sttErr, ok := err.(*STTError) + if !ok { + t.Fatalf("expected *STTError, got %T", err) + } + if sttErr.Code != ErrProviderNotSupported { + t.Errorf("expected ErrProviderNotSupported, got %s", sttErr.Code) + } } func TestGoogleProviderListLanguages(t *testing.T) { - p, _ := NewGoogleProvider("test-key") + p, _ := NewGoogleProvider("test-key", withGoogleRecognizeClient(&fakeGoogleRecognizeClient{})) langs, err := p.ListLanguages(context.Background()) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -477,32 +398,10 @@ func TestGoogleProviderListLanguages(t *testing.T) { if len(langs) == 0 { t.Fatal("expected non-empty language list") } - - found := false - for _, lang := range langs { - if lang == "en-US" { - found = true - break - } - } - if !found { - t.Error("expected 'en-US' in language list") - } - - found = false - for _, lang := range langs { - if lang == "zh-CN" { - found = true - break - } - } - if !found { - t.Error("expected 'zh-CN' in language list") - } } func TestGoogleProviderEncodingMapping(t *testing.T) { - p, _ := NewGoogleProvider("test-key") + p, _ := NewGoogleProvider("test-key", withGoogleRecognizeClient(&fakeGoogleRecognizeClient{})) tests := []struct { format AudioInputFormat @@ -512,7 +411,7 @@ func TestGoogleProviderEncodingMapping(t *testing.T) { {InputPCM, EncodingLinear16}, {InputFLAC, EncodingFLAC}, {InputMP3, EncodingMP3}, - {InputOGG, EncodingWEBMOpus}, + {InputOGG, EncodingOGGOpus}, {InputWEBM, EncodingWEBMOpus}, {InputM4A, EncodingWEBMOpus}, {InputMP4, EncodingWEBMOpus}, @@ -531,7 +430,7 @@ func TestGoogleProviderEncodingMapping(t *testing.T) { } func TestGoogleProviderSampleRateGuessing(t *testing.T) { - p, _ := NewGoogleProvider("test-key") + p, _ := NewGoogleProvider("test-key", withGoogleRecognizeClient(&fakeGoogleRecognizeClient{})) tests := []struct { format AudioInputFormat @@ -557,137 +456,89 @@ func TestGoogleProviderSampleRateGuessing(t *testing.T) { } } -func TestGoogleProviderRetries(t *testing.T) { - t.Run("retries on server error then succeeds", func(t *testing.T) { - callCount := 0 - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - callCount++ - if callCount < 2 { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusServiceUnavailable) - w.Write([]byte(`{"error":{"code":503,"message":"Service unavailable","status":"UNAVAILABLE"}}`)) - return - } - w.Header().Set("Content-Type", "application/json") - w.Write([]byte(`{"results":[{"alternatives":[{"transcript":"Success after retry"}],"languageCode":"en-US"}]}`)) - })) - defer server.Close() - - p, _ := NewGoogleProvider("test-key", WithGoogleBaseURL(server.URL), WithGoogleRetries(2)) - result, err := p.Transcribe(context.Background(), []byte("fake-audio")) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if result.Text != "Success after retry" { - t.Errorf("expected 'Success after retry', got '%s'", result.Text) - } - if callCount != 2 { - t.Errorf("expected 2 calls, got %d", callCount) - } - }) - - t.Run("does not retry on auth error", func(t *testing.T) { - callCount := 0 - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - callCount++ - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusUnauthorized) - w.Write([]byte(`{"error":{"code":401,"message":"Invalid API key","status":"UNAUTHENTICATED"}}`)) - })) - defer server.Close() - - p, _ := NewGoogleProvider("bad-key", WithGoogleBaseURL(server.URL), WithGoogleRetries(3)) - _, err := p.Transcribe(context.Background(), []byte("fake-audio")) - if err == nil { - t.Fatal("expected error") - } - if callCount != 1 { - t.Errorf("expected 1 call (no retry on auth error), got %d", callCount) - } - }) -} - -func TestParseGoogleDuration(t *testing.T) { +func TestParseProtoDuration(t *testing.T) { tests := []struct { name string - d googleDuration + d *durationpb.Duration want time.Duration }{ - {"zero", googleDuration{Seconds: "0", Nanos: 0}, 0}, - {"one second", googleDuration{Seconds: "1", Nanos: 0}, time.Second}, - {"500ms", googleDuration{Seconds: "0", Nanos: 500000000}, 500 * time.Millisecond}, - {"2.5s", googleDuration{Seconds: "2", Nanos: 500000000}, 2500 * time.Millisecond}, - {"1.234s", googleDuration{Seconds: "1", Nanos: 234000000}, time.Second + 234*time.Millisecond}, + {"nil", nil, 0}, + {"zero", durationpb.New(0), 0}, + {"one second", durationpb.New(time.Second), time.Second}, + {"500ms", durationpb.New(500 * time.Millisecond), 500 * time.Millisecond}, + {"2.5s", durationpb.New(2500 * time.Millisecond), 2500 * time.Millisecond}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := parseGoogleDuration(tt.d) + got := parseProtoDuration(tt.d) if got != tt.want { - t.Errorf("parseGoogleDuration(%v) = %v, want %v", tt.d, got, tt.want) + t.Errorf("parseProtoDuration(%v) = %v, want %v", tt.d, got, tt.want) } }) } } func TestNewSTTProviderGoogle(t *testing.T) { - t.Run("creates Google provider", func(t *testing.T) { - p, err := NewSTTProvider(STTConfig{ - Type: STTProviderGoogle, - APIKey: "test-key", - Timeout: 30 * time.Second, - }) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if p.Type() != STTProviderGoogle { - t.Errorf("expected STTProviderGoogle, got %s", p.Type()) - } - if p.Name() != "google-speech" { - t.Errorf("expected name 'google-speech', got %s", p.Name()) - } - }) - - t.Run("creates Google provider with language", func(t *testing.T) { - p, err := NewSTTProvider(STTConfig{ - Type: STTProviderGoogle, - APIKey: "test-key", - Language: "zh-CN", - }) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - gp, ok := p.(*GoogleProvider) - if !ok { - t.Fatalf("expected *GoogleProvider, got %T", p) - } - if gp.languageCode != "zh-CN" { - t.Errorf("expected language zh-CN, got %s", gp.languageCode) - } + p, err := NewSTTProvider(STTConfig{ + Type: STTProviderGoogle, + APIKey: "test-key", + Timeout: 30 * time.Second, }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if p.Type() != STTProviderGoogle { + t.Errorf("expected STTProviderGoogle, got %s", p.Type()) + } + if p.Name() != "google-speech" { + t.Errorf("expected name 'google-speech', got %s", p.Name()) + } } func TestGoogleSTTManager(t *testing.T) { - t.Run("register and use Google provider", func(t *testing.T) { - m := NewSTTManager() - p, _ := NewGoogleProvider("test-key") + m := NewSTTManager() + p, _ := NewGoogleProvider("test-key", withGoogleRecognizeClient(&fakeGoogleRecognizeClient{})) - err := m.Register("google", p) - if err != nil { - t.Fatalf("failed to register provider: %v", err) - } + err := m.Register("google", p) + if err != nil { + t.Fatalf("failed to register provider: %v", err) + } - providers := m.ListProviders() - if len(providers) != 1 { - t.Fatalf("expected 1 provider, got %d", len(providers)) - } + got, err := m.Get("google") + if err != nil { + t.Fatalf("failed to get provider: %v", err) + } + if got.Type() != STTProviderGoogle { + t.Errorf("expected STTProviderGoogle, got %s", got.Type()) + } +} - got, err := m.Get("google") - if err != nil { - t.Fatalf("failed to get provider: %v", err) - } - if got.Type() != STTProviderGoogle { - t.Errorf("expected STTProviderGoogle, got %s", got.Type()) - } - }) +func TestGoogleProviderHandleClientError(t *testing.T) { + p, _ := NewGoogleProvider("test-key", withGoogleRecognizeClient(&fakeGoogleRecognizeClient{})) + + tests := []struct { + name string + err error + want STTErrorCode + }{ + {"bad request", &googleapi.Error{Code: 400, Message: "bad request"}, ErrAudioFormatInvalid}, + {"unauthorized", &googleapi.Error{Code: 401, Message: "unauthorized"}, ErrAuthentication}, + {"forbidden", &googleapi.Error{Code: 403, Message: "forbidden"}, ErrAuthentication}, + {"rate limited", &googleapi.Error{Code: 429, Message: "quota"}, ErrRateLimited}, + {"generic", errors.New("boom"), ErrTranscriptionFailed}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := p.handleClientError(tt.err) + sttErr, ok := err.(*STTError) + if !ok { + t.Fatalf("expected *STTError, got %T", err) + } + if sttErr.Code != tt.want { + t.Errorf("expected %s, got %s", tt.want, sttErr.Code) + } + }) + } } diff --git a/pkg/speech/stt_openai_client.go b/pkg/speech/stt_openai_client.go new file mode 100644 index 00000000..a03cef4c --- /dev/null +++ b/pkg/speech/stt_openai_client.go @@ -0,0 +1,128 @@ +package speech + +import ( + "bytes" + "context" + "fmt" + "mime/multipart" + "net/http" +) + +type openAIAudioAPIClient struct { + apiKey string + baseURL string + client *http.Client +} + +func newOpenAIAudioAPIClient(apiKey, baseURL string, client *http.Client) *openAIAudioAPIClient { + return &openAIAudioAPIClient{ + apiKey: apiKey, + baseURL: baseURL, + client: client, + } +} + +func (c *openAIAudioAPIClient) DoTranscriptionRequest(ctx context.Context, endpoint string, audio []byte, options TranscribeOptions, stream bool) (*http.Response, error) { + body, contentType, err := c.buildMultipartBody(audio, options, stream) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, "POST", c.baseURL+endpoint, body) + if err != nil { + return nil, NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: failed to create request: %v", err) + } + + req.Header.Set("Authorization", "Bearer "+c.apiKey) + req.Header.Set("Content-Type", contentType) + req.Header.Set("User-Agent", "anyclaw-stt/1.0") + if stream { + req.Header.Set("Accept", "text/event-stream") + } + + resp, err := c.client.Do(req) + if err != nil { + return nil, NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: request failed: %v", err) + } + + return resp, nil +} + +func (c *openAIAudioAPIClient) buildMultipartBody(audio []byte, options TranscribeOptions, stream bool) (*bytes.Buffer, string, error) { + var body bytes.Buffer + writer := multipart.NewWriter(&body) + + filename := "audio." + string(options.InputFormat) + part, err := writer.CreateFormFile("file", filename) + if err != nil { + return nil, "", NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: failed to create form file: %v", err) + } + + if _, err := part.Write(audio); err != nil { + return nil, "", NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: failed to write audio data: %v", err) + } + + if err := writer.WriteField("model", options.Model); err != nil { + return nil, "", NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: failed to write model field: %v", err) + } + + if options.Language != "" { + if err := writer.WriteField("language", options.Language); err != nil { + return nil, "", NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: failed to write language field: %v", err) + } + } + + if options.Prompt != "" { + if err := writer.WriteField("prompt", options.Prompt); err != nil { + return nil, "", NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: failed to write prompt field: %v", err) + } + } + + if options.Temperature > 0 { + if err := writer.WriteField("temperature", fmt.Sprintf("%.2f", options.Temperature)); err != nil { + return nil, "", NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: failed to write temperature field: %v", err) + } + } + + if options.MaxAlternatives > 0 { + if err := writer.WriteField("max_alternatives", fmt.Sprintf("%d", options.MaxAlternatives)); err != nil { + return nil, "", NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: failed to write max_alternatives field: %v", err) + } + } + + // Streaming requests use response_format=json and should not send + // verbose-only timestamp granularities. + if !stream && (options.WordTimestamps || options.SpeakerLabels) { + if options.WordTimestamps { + if err := writer.WriteField("timestamp_granularities[]", "word"); err != nil { + return nil, "", NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: failed to write word timestamp_granularities: %v", err) + } + } + if options.SpeakerLabels { + if err := writer.WriteField("timestamp_granularities[]", "segment"); err != nil { + return nil, "", NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: failed to write segment timestamp_granularities: %v", err) + } + } + } + + responseType := "verbose_json" + if stream { + responseType = "json" + } + if err := writer.WriteField("response_format", responseType); err != nil { + return nil, "", NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: failed to write response_format field: %v", err) + } + + if stream { + if err := writer.WriteField("stream", "true"); err != nil { + return nil, "", NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: failed to write stream field: %v", err) + } + } + + if err := writer.Close(); err != nil { + return nil, "", NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: failed to close multipart writer: %v", err) + } + + return &body, writer.FormDataContentType(), nil +} + diff --git a/pkg/speech/stt_provider.go b/pkg/speech/stt_provider.go index 5f6c9ba6..10f6bd5a 100644 --- a/pkg/speech/stt_provider.go +++ b/pkg/speech/stt_provider.go @@ -43,13 +43,14 @@ type STTProvider interface { } type STTConfig struct { - Type STTProviderType - APIKey string - BaseURL string - Model string - Language string - SampleRate int - Timeout time.Duration + Type STTProviderType + APIKey string + CredentialsJSON string + BaseURL string + Model string + Language string + SampleRate int + Timeout time.Duration } func NewSTTProvider(cfg STTConfig) (STTProvider, error) { @@ -71,6 +72,9 @@ func NewSTTProvider(cfg STTConfig) (STTProvider, error) { return NewWhisperProvider(cfg.APIKey, opts...) case STTProviderGoogle: opts := []GoogleOption{} + if cfg.CredentialsJSON != "" { + opts = append(opts, WithGoogleCredentialsJSON(cfg.CredentialsJSON)) + } if cfg.BaseURL != "" { opts = append(opts, WithGoogleBaseURL(cfg.BaseURL)) } diff --git a/pkg/speech/stt_whisper.go b/pkg/speech/stt_whisper.go index 9ac04a86..e94b451d 100644 --- a/pkg/speech/stt_whisper.go +++ b/pkg/speech/stt_whisper.go @@ -2,12 +2,10 @@ package speech import ( "bufio" - "bytes" "context" "encoding/json" "fmt" "io" - "mime/multipart" "net/http" "os" "path/filepath" @@ -45,6 +43,7 @@ type WhisperProvider struct { timeout time.Duration retries int client *http.Client + apiClient *openAIAudioAPIClient httpTransport *http.Transport } @@ -108,6 +107,7 @@ func NewWhisperProvider(apiKey string, opts ...WhisperOption) (*WhisperProvider, p.client.Transport = p.httpTransport } p.client.Timeout = p.timeout + p.apiClient = newOpenAIAudioAPIClient(p.apiKey, p.baseURL, p.client) if !validWhisperModels[p.model] { return nil, NewSTTErrorf(ErrProviderNotSupported, "openai: invalid whisper model: %s", p.model) @@ -256,72 +256,6 @@ func (p *WhisperProvider) validateTranscribeOptions(options TranscribeOptions) e } func (p *WhisperProvider) doTranscribe(ctx context.Context, audio []byte, options TranscribeOptions) (*TranscriptResult, error) { - var body bytes.Buffer - writer := multipart.NewWriter(&body) - - filename := "audio." + string(options.InputFormat) - part, err := writer.CreateFormFile("file", filename) - if err != nil { - return nil, NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: failed to create form file: %v", err) - } - - if _, err := part.Write(audio); err != nil { - return nil, NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: failed to write audio data: %v", err) - } - - if err := writer.WriteField("model", options.Model); err != nil { - return nil, NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: failed to write model field: %v", err) - } - - if options.Language != "" { - if err := writer.WriteField("language", options.Language); err != nil { - return nil, NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: failed to write language field: %v", err) - } - } - - if options.Prompt != "" { - if err := writer.WriteField("prompt", options.Prompt); err != nil { - return nil, NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: failed to write prompt field: %v", err) - } - } - - if options.Temperature > 0 { - if err := writer.WriteField("temperature", fmt.Sprintf("%.2f", options.Temperature)); err != nil { - return nil, NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: failed to write temperature field: %v", err) - } - } - - if options.MaxAlternatives > 0 { - if err := writer.WriteField("max_alternatives", fmt.Sprintf("%d", options.MaxAlternatives)); err != nil { - return nil, NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: failed to write max_alternatives field: %v", err) - } - } - - if options.WordTimestamps || options.SpeakerLabels { - if options.WordTimestamps { - if err := writer.WriteField("timestamp_granularities[]", "word"); err != nil { - return nil, NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: failed to write word timestamp_granularities: %v", err) - } - } - if options.SpeakerLabels { - if err := writer.WriteField("timestamp_granularities[]", "segment"); err != nil { - return nil, NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: failed to write segment timestamp_granularities: %v", err) - } - } - } - - responseType := "verbose_json" - if options.WordTimestamps || options.SpeakerLabels { - responseType = "verbose_json" - } - if err := writer.WriteField("response_format", responseType); err != nil { - return nil, NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: failed to write response_format field: %v", err) - } - - if err := writer.Close(); err != nil { - return nil, NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: failed to close multipart writer: %v", err) - } - var endpoint string switch options.Mode { case ModeTranslation: @@ -330,20 +264,9 @@ func (p *WhisperProvider) doTranscribe(ctx context.Context, audio []byte, option endpoint = "/v1/audio/transcriptions" } - url := p.baseURL + endpoint - - req, err := http.NewRequestWithContext(ctx, "POST", url, &body) - if err != nil { - return nil, NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: failed to create request: %v", err) - } - - req.Header.Set("Authorization", "Bearer "+p.apiKey) - req.Header.Set("Content-Type", writer.FormDataContentType()) - req.Header.Set("User-Agent", "anyclaw-stt/1.0") - - resp, err := p.client.Do(req) + resp, err := p.apiClient.DoTranscriptionRequest(ctx, endpoint, audio, options, false) if err != nil { - return nil, NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: request failed: %v", err) + return nil, err } defer resp.Body.Close() @@ -520,60 +443,14 @@ func (p *WhisperProvider) TranscribeSSE(ctx context.Context, audio []byte, onChu return NewSTTError(ErrAudioFormatInvalid, "openai-whisper: audio data is empty") } - var body bytes.Buffer - writer := multipart.NewWriter(&body) - - filename := "audio." + string(options.InputFormat) - part, err := writer.CreateFormFile("file", filename) - if err != nil { - return NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: failed to create form file: %v", err) - } - - if _, err := part.Write(audio); err != nil { - return NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: failed to write audio data: %v", err) - } - - if err := writer.WriteField("model", options.Model); err != nil { - return NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: failed to write model field: %v", err) - } - - if options.Language != "" { - if err := writer.WriteField("language", options.Language); err != nil { - return NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: failed to write language field: %v", err) - } - } - - if err := writer.WriteField("response_format", "json"); err != nil { - return NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: failed to write response_format field: %v", err) - } - - if err := writer.WriteField("stream", "true"); err != nil { - return NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: failed to write stream field: %v", err) - } - - if err := writer.Close(); err != nil { - return NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: failed to close multipart writer: %v", err) - } - endpoint := "/v1/audio/transcriptions" if options.Mode == ModeTranslation { endpoint = "/v1/audio/translations" } - url := p.baseURL + endpoint - - req, err := http.NewRequestWithContext(ctx, "POST", url, &body) + resp, err := p.apiClient.DoTranscriptionRequest(ctx, endpoint, audio, options, true) if err != nil { - return NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: failed to create request: %v", err) - } - - req.Header.Set("Authorization", "Bearer "+p.apiKey) - req.Header.Set("Content-Type", writer.FormDataContentType()) - req.Header.Set("Accept", "text/event-stream") - - resp, err := p.client.Do(req) - if err != nil { - return NewSTTErrorf(ErrTranscriptionFailed, "openai-whisper: streaming request failed: %v", err) + return err } defer resp.Body.Close() diff --git a/pkg/speech/vad.go b/pkg/speech/vad.go index fc36af30..045f0f9f 100644 --- a/pkg/speech/vad.go +++ b/pkg/speech/vad.go @@ -15,6 +15,7 @@ const ( type VADConfig struct { SampleRate int FrameSize int + Aggressiveness int EnergyThreshold float64 ZeroCrossThreshold int SpeechMinFrames int @@ -26,6 +27,7 @@ func DefaultVADConfig() VADConfig { return VADConfig{ SampleRate: 16000, FrameSize: 320, + Aggressiveness: 2, EnergyThreshold: 0.01, ZeroCrossThreshold: 50, SpeechMinFrames: 3, @@ -45,6 +47,14 @@ type VAD struct { type VADStateListener func(state VADState, energy float64, zcr float64) +func (v *VAD) Name() string { + return "heuristic-vad" +} + +func (v *VAD) Type() VADProviderType { + return VADProviderHeuristic +} + func NewVAD(cfg VADConfig) *VAD { if cfg.SampleRate == 0 { cfg.SampleRate = 16000 @@ -52,6 +62,9 @@ func NewVAD(cfg VADConfig) *VAD { if cfg.FrameSize == 0 { cfg.FrameSize = 320 } + if cfg.Aggressiveness < 0 || cfg.Aggressiveness > 3 { + cfg.Aggressiveness = 2 + } if cfg.EnergyThreshold == 0 { cfg.EnergyThreshold = 0.01 } diff --git a/pkg/speech/vad_provider.go b/pkg/speech/vad_provider.go new file mode 100644 index 00000000..ce0de4e2 --- /dev/null +++ b/pkg/speech/vad_provider.go @@ -0,0 +1,58 @@ +package speech + +import "fmt" + +type VADProviderType string + +const ( + VADProviderHeuristic VADProviderType = "heuristic" + VADProviderWebRTC VADProviderType = "webrtc" +) + +type VADProcessor interface { + Name() string + Type() VADProviderType + ProcessFrame(samples []int16) VADState + ProcessFloatFrame(samples []float32) VADState + RegisterListener(listener VADStateListener) + State() VADState + Reset() + UpdateConfig(cfg VADConfig) + Config() VADConfig +} + +type VADProviderFactory func(cfg VADConfig) (VADProcessor, error) + +type VADManager struct { + factories map[VADProviderType]VADProviderFactory +} + +func NewVADManager() *VADManager { + m := &VADManager{ + factories: map[VADProviderType]VADProviderFactory{}, + } + m.Register(VADProviderHeuristic, func(cfg VADConfig) (VADProcessor, error) { + return NewVAD(cfg), nil + }) + m.Register(VADProviderWebRTC, func(cfg VADConfig) (VADProcessor, error) { + return NewWebRTCVAD(cfg) + }) + return m +} + +func (m *VADManager) Register(providerType VADProviderType, factory VADProviderFactory) { + m.factories[providerType] = factory +} + +func (m *VADManager) New(cfg VADConfig, providerType VADProviderType) (VADProcessor, error) { + if providerType == "" { + providerType = VADProviderHeuristic + } + + factory, ok := m.factories[providerType] + if !ok { + return nil, fmt.Errorf("vad: unsupported provider %q", providerType) + } + + return factory(cfg) +} diff --git a/pkg/speech/vad_webrtc.go b/pkg/speech/vad_webrtc.go new file mode 100644 index 00000000..9c5b41f2 --- /dev/null +++ b/pkg/speech/vad_webrtc.go @@ -0,0 +1,151 @@ +package speech + +import ( + "encoding/binary" + "fmt" + + webrtcvad "github.com/godeps/webrtcvad-go" +) + +type WebRTCVAD struct { + inner *VAD + detector *webrtcvad.VAD + mode int + sampleRate int + frameSize int + scratch []byte +} + +func NewWebRTCVAD(cfg VADConfig) (*WebRTCVAD, error) { + if cfg.SampleRate == 0 { + cfg.SampleRate = 16000 + } + if cfg.FrameSize == 0 { + cfg.FrameSize = 320 + } + if cfg.Aggressiveness < 0 || cfg.Aggressiveness > 3 { + cfg.Aggressiveness = 2 + } + + if !webrtcvad.ValidRateAndFrameLength(cfg.SampleRate, cfg.FrameSize) { + return nil, fmt.Errorf("vad: invalid WebRTC sampleRate/frameSize combination: %d/%d", cfg.SampleRate, cfg.FrameSize) + } + + detector, err := webrtcvad.New(cfg.Aggressiveness) + if err != nil { + return nil, fmt.Errorf("vad: failed to create WebRTC VAD: %w", err) + } + + return &WebRTCVAD{ + inner: NewVAD(cfg), + detector: detector, + mode: cfg.Aggressiveness, + sampleRate: cfg.SampleRate, + frameSize: cfg.FrameSize, + scratch: make([]byte, cfg.FrameSize*2), + }, nil +} + +func (v *WebRTCVAD) Name() string { + return "webrtc-vad" +} + +func (v *WebRTCVAD) Type() VADProviderType { + return VADProviderWebRTC +} + +func (v *WebRTCVAD) ProcessFrame(samples []int16) VADState { + if len(samples) == 0 { + return v.inner.ProcessFrame(samples) + } + + v.inner.mu.Lock() + audio := v.frameBytes(samples) + isSpeech, err := v.detector.IsSpeech(audio, v.sampleRate) + if err != nil { + v.inner.mu.Unlock() + return v.inner.ProcessFrame(samples) + } + + defer v.inner.mu.Unlock() + + energy := v.inner.calculateRMS(samples) + zcr := v.inner.calculateZeroCrossingRate(samples) + + if isSpeech { + v.inner.consecutiveSpeech++ + v.inner.consecutiveSilence = 0 + } else { + v.inner.consecutiveSilence++ + v.inner.consecutiveSpeech = 0 + } + + switch v.inner.state { + case VADStateSilence: + if isSpeech { + if v.inner.consecutiveSpeech >= v.inner.cfg.SpeechMinFrames { + v.inner.state = VADStateSpeech + v.inner.notifyListeners(VADStateSpeech, energy, zcr) + } + } else { + v.inner.consecutiveSpeech = 0 + } + + case VADStateSpeech: + if isSpeech { + v.inner.consecutiveSilence = 0 + } else { + if v.inner.consecutiveSilence >= v.inner.cfg.HangoverFrames { + v.inner.state = VADStateSilence + v.inner.consecutiveSpeech = 0 + v.inner.consecutiveSilence = 0 + v.inner.notifyListeners(VADStateSilence, energy, zcr) + } + } + } + + return v.inner.state +} + +func (v *WebRTCVAD) ProcessFloatFrame(samples []float32) VADState { + return v.ProcessFrame(Float32ToInt16(samples)) +} + +func (v *WebRTCVAD) RegisterListener(listener VADStateListener) { + v.inner.RegisterListener(listener) +} + +func (v *WebRTCVAD) State() VADState { + return v.inner.State() +} + +func (v *WebRTCVAD) Reset() { + v.inner.Reset() +} + +func (v *WebRTCVAD) UpdateConfig(cfg VADConfig) { + v.inner.UpdateConfig(cfg) + if cfg.Aggressiveness >= 0 && cfg.Aggressiveness <= 3 { + _ = v.detector.SetMode(cfg.Aggressiveness) + v.mode = cfg.Aggressiveness + } +} + +func (v *WebRTCVAD) Config() VADConfig { + cfg := v.inner.Config() + cfg.Aggressiveness = v.mode + return cfg +} + +func (v *WebRTCVAD) frameBytes(samples []int16) []byte { + size := len(samples) * 2 + if cap(v.scratch) < size { + v.scratch = make([]byte, size) + } + out := v.scratch[:size] + for i, s := range samples { + binary.LittleEndian.PutUint16(out[i*2:], uint16(s)) + } + return out +} + diff --git a/pkg/speech/voicewake.go b/pkg/speech/voicewake.go index 5807bbf9..61ae30a9 100644 --- a/pkg/speech/voicewake.go +++ b/pkg/speech/voicewake.go @@ -47,6 +47,7 @@ type AudioSource interface { type VoiceWakeConfig struct { VADConfig VADConfig + VADProvider VADProviderType WakeWordConfig WakeWordConfig EngineConfig WakeWordEngineConfig SampleRate int @@ -63,6 +64,7 @@ type VoiceWakeConfig struct { func DefaultVoiceWakeConfig() VoiceWakeConfig { return VoiceWakeConfig{ VADConfig: DefaultVADConfig(), + VADProvider: VADProviderWebRTC, WakeWordConfig: DefaultWakeWordConfig(), SampleRate: 16000, Channels: 1, @@ -77,7 +79,7 @@ type VoiceWake struct { mu sync.Mutex cfg VoiceWakeConfig state VoiceWakeState - vad *VAD + vad VADProcessor wakeDetector *WakeWordDetector engineRouter *WakeWordEngineRouter engineAdapter *WakeWordEngineAdapter @@ -120,7 +122,17 @@ func NewVoiceWake(cfg VoiceWakeConfig) *VoiceWake { cfg.EngineConfig.SampleRate = cfg.SampleRate cfg.EngineConfig.FrameSize = cfg.FrameSize - vad := NewVAD(cfg.VADConfig) + if cfg.VADProvider == "" { + cfg.VADProvider = VADProviderHeuristic + } + + vadManager := NewVADManager() + vad, err := vadManager.New(cfg.VADConfig, cfg.VADProvider) + if err != nil { + log.Printf("voicewake: failed to create VAD provider %q, fallback to heuristic: %v", cfg.VADProvider, err) + cfg.VADProvider = VADProviderHeuristic + vad = NewVAD(cfg.VADConfig) + } wakeDetector := NewWakeWordDetector(cfg.WakeWordConfig) router := NewWakeWordEngineRouter(cfg.EngineConfig) @@ -489,7 +501,7 @@ func (vw *VoiceWake) LastWakeMatch() (string, float64) { return vw.lastWakeMatch, vw.lastConfidence } -func (vw *VoiceWake) VAD() *VAD { +func (vw *VoiceWake) VAD() VADProcessor { return vw.vad }