Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions model/ai_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
Expand Down Expand Up @@ -56,6 +57,13 @@ func (s *sseAIService) QueryCommandStream(
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
body, readErr := io.ReadAll(resp.Body)
if readErr == nil {
var errResp errorResponse
if jsonErr := json.Unmarshal(body, &errResp); jsonErr == nil && errResp.ErrorMessage != "" {
return fmt.Errorf("%s", errResp.ErrorMessage)
}
}
Comment on lines +60 to +66
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The errorResponse struct is used here but not defined in ai_service.go. It should be defined in model/ai_service.go or imported from model/api.base.go if it's a common type. Defining it directly in ai_service.go would make this file self-contained for its error handling logic.

body, readErr := io.ReadAll(resp.Body)
if readErr == nil {
	type errorResponse struct {
		ErrorCode    int    `json:"errorCode"`
		ErrorMessage string `json:"errorMessage"`
	}
	var errResp errorResponse
	if jsonErr := json.Unmarshal(body, &errResp); jsonErr == nil && errResp.ErrorMessage != "" {
		return fmt.Errorf("%s", errResp.ErrorMessage)
	}
}

return fmt.Errorf("server returned status %d", resp.StatusCode)
}

Expand Down
92 changes: 92 additions & 0 deletions model/ai_service_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package model

import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"
)

func TestQueryCommandStream_ErrorResponseBody(t *testing.T) {
tests := []struct {
name string
statusCode int
responseBody interface{}
expectedErrMsg string
}{
{
name: "quota exceeded returns error message from body",
statusCode: http.StatusTooManyRequests,
responseBody: errorResponse{
ErrorCode: http.StatusTooManyRequests,
ErrorMessage: "monthly AI credit quota exceeded",
},
expectedErrMsg: "monthly AI credit quota exceeded",
},
{
name: "unauthorized returns error message from body",
statusCode: http.StatusUnauthorized,
responseBody: errorResponse{
ErrorCode: http.StatusUnauthorized,
ErrorMessage: "unauthorized",
},
expectedErrMsg: "unauthorized",
},
{
name: "service unavailable returns error message from body",
statusCode: http.StatusServiceUnavailable,
responseBody: errorResponse{
ErrorCode: http.StatusServiceUnavailable,
ErrorMessage: "AI service is not available",
},
expectedErrMsg: "AI service is not available",
},
{
name: "non-JSON response falls back to status code",
statusCode: http.StatusInternalServerError,
responseBody: "not json",
expectedErrMsg: fmt.Sprintf("server returned status %d", http.StatusInternalServerError),
},
{
name: "empty error message falls back to status code",
statusCode: http.StatusBadRequest,
responseBody: errorResponse{
ErrorCode: http.StatusBadRequest,
ErrorMessage: "",
},
expectedErrMsg: fmt.Sprintf("server returned status %d", http.StatusBadRequest),
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(tt.statusCode)
if s, ok := tt.responseBody.(string); ok {
w.Write([]byte(s))
} else {
json.NewEncoder(w).Encode(tt.responseBody)
}
}))
defer server.Close()

svc := NewAIService()
err := svc.QueryCommandStream(
context.Background(),
CommandSuggestVariables{Shell: "bash", Os: "linux", Query: "test"},
Endpoint{APIEndpoint: server.URL, Token: "test-token"},
func(token string) {},
)

if err == nil {
t.Fatal("expected error, got nil")
}
if err.Error() != tt.expectedErrMsg {
t.Errorf("expected error %q, got %q", tt.expectedErrMsg, err.Error())
}
})
}
}