Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 38 additions & 1 deletion commands/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,11 @@ func commandQuery(c *cli.Context) error {
// Print newline after streaming
fmt.Println()

newCommand := strings.TrimSpace(result.String())
newCommand := sanitizeSuggestedCommand(result.String())
if newCommand == "" {
color.Red.Println("AI returned an empty response. Try rephrasing your query.")
return fmt.Errorf("empty AI response")
}

// Check auto-run configuration
if cfg.AI != nil && (cfg.AI.Agent.View || cfg.AI.Agent.Edit || cfg.AI.Agent.Delete) {
Expand Down Expand Up @@ -188,6 +192,39 @@ func executeCommand(ctx context.Context, command string) error {
return nil
}

// sanitizeSuggestedCommand normalizes raw AI output into an executable command.
// It strips triple-backtick fences (with optional language tag like bash, sh,
// zsh, fish, pwsh, powershell), strips surrounding single backticks when the
// result is a single line, and trims whitespace. Responses that start with `#`
// are treated as refusal comments and preserved verbatim so the caller can
// surface them to the user without attempting execution.
func sanitizeSuggestedCommand(raw string) string {
s := strings.TrimSpace(raw)
if s == "" {
return ""
}

if strings.HasPrefix(s, "```") {
s = strings.TrimPrefix(s, "```")
if nl := strings.IndexByte(s, '\n'); nl >= 0 {
switch strings.ToLower(strings.TrimSpace(s[:nl])) {
case "", "bash", "sh", "shell", "zsh", "fish", "pwsh", "powershell":
s = s[nl+1:]
}
}
s = strings.TrimRight(s, " \t\n")
s = strings.TrimSuffix(s, "```")
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Truncate fenced output at first closing backticks

When the AI returns a fenced command followed by explanatory prose (for example bash\nls -la\n```\nThis lists files), sanitizeSuggestedCommand only removes a closing fence if it is the final suffix, so it leaves \n```\nThis lists files in the command. Because classification uses the first token, this can now be auto-run as a view/edit command while still containing non-command prose, causing unintended shell execution and failures whenever auto-run is enabled. The sanitizer should extract content up to the first closing fence and ignore trailing text before classification/execution.

Useful? React with 👍 / 👎.

s = strings.TrimSpace(s)
}

if !strings.ContainsRune(s, '\n') && len(s) >= 2 &&
strings.HasPrefix(s, "`") && strings.HasSuffix(s, "`") {
s = strings.TrimSpace(s[1 : len(s)-1])
}

return s
}

func getSystemContext(query string) (model.CommandSuggestVariables, error) {
// Get shell information
shell := os.Getenv("SHELL")
Expand Down
37 changes: 36 additions & 1 deletion commands/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,42 @@ func (s *queryTestSuite) TestQueryCommandEmptyAIResponse() {
}

err := s.app.Run(command)
assert.Nil(s.T(), err)
assert.NotNil(s.T(), err)
assert.Contains(s.T(), err.Error(), "empty AI response")
}

func (s *queryTestSuite) TestSanitizeSuggestedCommand() {
tests := []struct {
name string
in string
want string
}{
{"plain", "ls -la", "ls -la"},
{"trims whitespace", " ls -la \n\t", "ls -la"},
{"fence with bash tag", "```bash\necho hi\n```", "echo hi"},
{"fence with sh tag", "```sh\necho hi\n```", "echo hi"},
{"fence with zsh tag", "```zsh\necho hi\n```", "echo hi"},
{"fence with shell tag", "```shell\necho hi\n```", "echo hi"},
{"fence with fish tag", "```fish\nset -x FOO bar\n```", "set -x FOO bar"},
{"fence with powershell tag", "```powershell\nGet-Process\n```", "Get-Process"},
{"fence with pwsh tag", "```pwsh\nGet-Process\n```", "Get-Process"},
{"fence no language tag", "```\necho hi\n```", "echo hi"},
{"fence with trailing newline before closing", "```bash\nls -la\n\n```", "ls -la"},
{"single backticks around single-line", "`ls -la`", "ls -la"},
{"single backticks with surrounding space", " `ls -la` ", "ls -la"},
{"only whitespace", " \n\t ", ""},
{"empty", "", ""},
{"comment passthrough preserved", "# refusing: unsafe request", "# refusing: unsafe request"},
{"multiline without fences kept", "ls\ncat foo", "ls\ncat foo"},
}
for _, tt := range tests {
s.T().Run(tt.name, func(t *testing.T) {
got := sanitizeSuggestedCommand(tt.in)
if got != tt.want {
t.Errorf("sanitizeSuggestedCommand(%q) = %q, want %q", tt.in, got, tt.want)
}
})
}
}

func (s *queryTestSuite) TestQueryCommandDescription() {
Expand Down
36 changes: 26 additions & 10 deletions model/ai_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,24 +72,26 @@ func (s *sseAIService) QueryCommandStream(
for scanner.Scan() {
line := scanner.Text()

if line == "event: error" {
isError = true
if line == "" {
isError = false
continue
}

if strings.HasPrefix(line, "data:") {
data := line[len("data:"):]
if v, ok := stripSSEField(line, "event:"); ok {
if v == "error" {
isError = true
}
Comment on lines +81 to +83
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The isError state should be explicitly set based on whether the event type is exactly "error". Currently, if an event is marked as "error" and then a subsequent event in the same stream (before a blank line) has a different event type, isError will remain true. While SSE events are typically separated by blank lines, it's safer to update the state for every event: field encountered.

Suggested change
if v == "error" {
isError = true
}
if v, ok := stripSSEField(line, "event:"); ok {
isError = (v == "error")
continue
}

continue
}

if v, ok := stripSSEField(line, "data:"); ok {
if isError {
return fmt.Errorf("server error: %s", data)
return fmt.Errorf("server error: %s", v)
}
Comment on lines 88 to 90
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

If the server sends a multi-line error message (multiple data: lines for a single event: error), this implementation will return an error containing only the first line. Consider buffering the error message until the event is fully received (at the blank line) or concatenating it if multi-line errors are expected.


if data == "[DONE]" {
if v == "[DONE]" {
return nil
}

onToken(data)
isError = false
onToken(v)
}
}

Expand All @@ -99,3 +101,17 @@ func (s *sseAIService) QueryCommandStream(

return nil
}

// stripSSEField returns the value after prefix, stripping one optional leading
// space per the SSE specification (§9.2 "If value starts with a U+0020 SPACE
// character, remove it from value").
func stripSSEField(line, prefix string) (string, bool) {
if !strings.HasPrefix(line, prefix) {
return "", false
}
v := line[len(prefix):]
if strings.HasPrefix(v, " ") {
v = v[1:]
}
return v, true
}
122 changes: 122 additions & 0 deletions model/ai_service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
)

Expand Down Expand Up @@ -90,3 +91,124 @@ func TestQueryCommandStream_ErrorResponseBody(t *testing.T) {
})
}
}

func TestQueryCommandStream_SSEParsing(t *testing.T) {
tests := []struct {
name string
body string
wantErr bool
wantErrSubstr string
wantTokens []string
}{
{
name: "data with space and [DONE] terminates cleanly",
body: "data: [DONE]\n\n",
wantTokens: nil,
},
{
name: "data without space and [DONE] terminates cleanly",
body: "data:[DONE]\n\n",
wantTokens: nil,
},
{
name: "single data token with leading space is stripped",
body: "data: hello\n\ndata: [DONE]\n\n",
wantTokens: []string{"hello"},
},
{
name: "single data token without leading space passes through",
body: "data:hello\n\ndata:[DONE]\n\n",
wantTokens: []string{"hello"},
},
{
name: "multi-token stream concatenates without spurious spaces",
body: "data: ls\n\ndata: -la\n\ndata: [DONE]\n\n",
wantTokens: []string{"ls", " -la"},
},
{
name: "event error with space",
body: "event: error\ndata: boom\n\n",
wantErr: true,
wantErrSubstr: "boom",
},
{
name: "event error without space",
body: "event:error\ndata:boom\n\n",
wantErr: true,
wantErrSubstr: "boom",
},
{
name: "blank line resets error state between events",
body: "event: error\n\ndata: hello\n\ndata: [DONE]\n\n",
wantTokens: []string{"hello"},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(tt.body))
}))
defer server.Close()

var got []string
svc := NewAIService()
err := svc.QueryCommandStream(
context.Background(),
CommandSuggestVariables{Shell: "bash", Os: "linux", Query: "test"},
Endpoint{APIEndpoint: server.URL, Token: "test-token"},
func(token string) { got = append(got, token) },
)

if tt.wantErr {
if err == nil {
t.Fatalf("expected error, got nil (tokens=%v)", got)
}
if !strings.Contains(err.Error(), tt.wantErrSubstr) {
t.Fatalf("expected error to contain %q, got %q", tt.wantErrSubstr, err.Error())
}
return
}

if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(got) != len(tt.wantTokens) {
t.Fatalf("token count mismatch: want %d %v, got %d %v", len(tt.wantTokens), tt.wantTokens, len(got), got)
}
for i, tok := range tt.wantTokens {
if got[i] != tok {
t.Errorf("token[%d] = %q, want %q", i, got[i], tok)
}
}
})
}
}

func TestStripSSEField(t *testing.T) {
tests := []struct {
name string
line string
prefix string
wantVal string
wantOk bool
}{
{"no match", "foo:bar", "data:", "", false},
{"match no space", "data:hello", "data:", "hello", true},
{"match one space stripped", "data: hello", "data:", "hello", true},
{"match two spaces preserves second", "data: hello", "data:", " hello", true},
{"empty value no space", "data:", "data:", "", true},
{"empty value one space", "data: ", "data:", "", true},
{"event error with space", "event: error", "event:", "error", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
v, ok := stripSSEField(tt.line, tt.prefix)
if ok != tt.wantOk || v != tt.wantVal {
t.Errorf("stripSSEField(%q, %q) = (%q, %v), want (%q, %v)", tt.line, tt.prefix, v, ok, tt.wantVal, tt.wantOk)
}
})
}
}
Loading