From fa98c37313741f10f0c69e96e7100b66d33da1e4 Mon Sep 17 00:00:00 2001 From: Harshaneel Gokhale Date: Mon, 18 May 2026 10:41:42 -0700 Subject: [PATCH 1/3] feat: Add Models list/get, CountTokens, and legacy completions routes Closes the most-used SDK surface gaps for client init and tokenizer checks. Six new routes, all backed by existing llama.cpp endpoints: - POST /v1/completions (OpenAI legacy passthrough) - GET /v1/models (OpenAI passthrough) - GET /v1/models/{id} (OpenAI passthrough) - GET /v1beta/models (Gemini, translated from /v1/models) - GET /v1beta/models/{name} (Gemini, translated) - POST /v1beta/models/{m}:countTokens (Gemini, calls llama.cpp /tokenize) Adds SDK contract tests exercising genai's Models.List / Models.Get / Models.CountTokens and openai-go's Models.List / Models.Get / Completions.New against the new routes. Also includes an unrelated gofmt drive-by on examples/go/gemini-structured/main.go which was breaking make lint on main. ComputeTokens is intentionally not implemented: the genai SDK gates it behind BackendVertexAI and never reaches a Gemini-Developer-API proxy. Co-Authored-By: Claude Opus 4.7 (1M context) --- README.md | 13 +- examples/go/gemini-structured/main.go | 4 +- integration/sdk_meta_test.go | 262 ++++++++++++++++++++++++++ internal/protocol/gemini/types.go | 23 +++ internal/protocol/openai/types.go | 12 ++ internal/server/gemini_meta.go | 101 ++++++++++ internal/server/meta_test.go | 206 ++++++++++++++++++++ internal/server/openai.go | 51 ++++- internal/server/server.go | 40 ++-- internal/translate/models.go | 50 +++++ internal/translate/models_test.go | 88 +++++++++ 11 files changed, 832 insertions(+), 18 deletions(-) create mode 100644 integration/sdk_meta_test.go create mode 100644 internal/server/gemini_meta.go create mode 100644 internal/server/meta_test.go create mode 100644 internal/translate/models.go create mode 100644 internal/translate/models_test.go diff --git a/README.md b/README.md index 84a9595..e38a87d 100644 --- a/README.md +++ b/README.md @@ -165,7 +165,13 @@ services: | --------------------------------------------------- | ------------------------------ | --------------------------------------- | | `POST /v1beta/models/{model}:generateContent` | Gemini `GenerateContent` | Translated to upstream chat completions | | `POST /v1beta/models/{model}:streamGenerateContent` | Gemini `GenerateContentStream` | Gemini-style SSE (typically `?alt=sse`) | +| `POST /v1beta/models/{model}:countTokens` | Gemini `CountTokens` | Translated to upstream `/tokenize` | +| `GET /v1beta/models` | Gemini `Models.List` | Translated from upstream `/v1/models` | +| `GET /v1beta/models/{model}` | Gemini `Models.Get` | Translated from upstream `/v1/models` | | `POST /v1/chat/completions` | OpenAI chat completions | Forwarded to upstream | +| `POST /v1/completions` | OpenAI legacy completions | Forwarded to upstream | +| `GET /v1/models` | OpenAI `Models.List` | Forwarded to upstream | +| `GET /v1/models/{model}` | OpenAI `Models.Retrieve` | Forwarded to upstream | | `GET /health` | Health checks | Custom route | @@ -227,13 +233,14 @@ jobs: **Not supported:** -- SDK methods outside `GenerateContent` / `GenerateContentStream` +- SDK methods outside `GenerateContent` / `GenerateContentStream` / `CountTokens` / `Models.List` / `Models.Get` - Non-function tools (Google Search, Maps, URL context, code execution) -- Embeddings, token counting, cached content, live/bidi sessions, uploads +- Embeddings, cached content, live/bidi sessions, uploads +- `ComputeTokens` is Vertex-only in the SDK and not exposed on `BackendGeminiAPI` ## OpenAI compatibility -**Supported:** text chat completions, structured output, vision inputs, tool-related fields (all passed through to upstream). +**Supported:** text chat completions, legacy `/v1/completions`, `Models.List` / `Models.Retrieve`, structured output, vision inputs, tool-related fields (all passed through to upstream). **Not supported:** Responses API, Assistants, Embeddings, Images, Audio, Files, Vector stores. diff --git a/examples/go/gemini-structured/main.go b/examples/go/gemini-structured/main.go index 4f55950..df6cd92 100644 --- a/examples/go/gemini-structured/main.go +++ b/examples/go/gemini-structured/main.go @@ -32,8 +32,8 @@ func main() { Items: &genai.Schema{ Type: genai.TypeObject, Properties: map[string]*genai.Schema{ - "name": {Type: genai.TypeString}, - "year": {Type: genai.TypeInteger}, + "name": {Type: genai.TypeString}, + "year": {Type: genai.TypeInteger}, "use_case": {Type: genai.TypeString}, }, Required: []string{"name", "year", "use_case"}, diff --git a/integration/sdk_meta_test.go b/integration/sdk_meta_test.go new file mode 100644 index 0000000..f6a7c64 --- /dev/null +++ b/integration/sdk_meta_test.go @@ -0,0 +1,262 @@ +package integration + +import ( + "context" + "encoding/json" + "net/http" + "testing" + + "github.com/harshaneel/localaik/internal/pdf" + openaip "github.com/harshaneel/localaik/internal/protocol/openai" + "github.com/harshaneel/localaik/internal/server" + + openaisdk "github.com/openai/openai-go/v3" + "github.com/openai/openai-go/v3/option" + genaisdk "google.golang.org/genai" +) + +func TestSDKGenAIModelsList(t *testing.T) { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/models" { + t.Fatalf("upstream path = %q, want /v1/models", r.URL.Path) + } + writeJSON(w, http.StatusOK, openaip.ModelList{ + Object: "list", + Data: []openaip.Model{ + {ID: "gemma-3-4b", Object: "model"}, + {ID: "gemma-3-12b", Object: "model"}, + }, + }) + }) + + proxy := newCapturedProxyHandlerForUpstream(t, upstream) + + client, err := genaisdk.NewClient(context.Background(), &genaisdk.ClientConfig{ + APIKey: "test", + Backend: genaisdk.BackendGeminiAPI, + HTTPClient: &http.Client{ + Transport: newHandlerTransport(proxy), + }, + HTTPOptions: genaisdk.HTTPOptions{ + BaseURL: "http://localaik.test", + }, + }) + if err != nil { + t.Fatalf("genai.NewClient returned error: %v", err) + } + + page, err := client.Models.List(context.Background(), nil) + if err != nil { + t.Fatalf("Models.List returned error: %v", err) + } + + if len(page.Items) != 2 { + t.Fatalf("page.Items = %#v, want 2 models", page.Items) + } + if page.Items[0].Name != "models/gemma-3-4b" { + t.Fatalf("page.Items[0].Name = %q, want models/gemma-3-4b", page.Items[0].Name) + } +} + +func TestSDKGenAIModelsGet(t *testing.T) { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/models/gemma-3-4b" { + t.Fatalf("upstream path = %q, want /v1/models/gemma-3-4b", r.URL.Path) + } + writeJSON(w, http.StatusOK, openaip.Model{ID: "gemma-3-4b", Object: "model"}) + }) + + proxy := newCapturedProxyHandlerForUpstream(t, upstream) + + client, err := genaisdk.NewClient(context.Background(), &genaisdk.ClientConfig{ + APIKey: "test", + Backend: genaisdk.BackendGeminiAPI, + HTTPClient: &http.Client{ + Transport: newHandlerTransport(proxy), + }, + HTTPOptions: genaisdk.HTTPOptions{ + BaseURL: "http://localaik.test", + }, + }) + if err != nil { + t.Fatalf("genai.NewClient returned error: %v", err) + } + + model, err := client.Models.Get(context.Background(), "gemma-3-4b", nil) + if err != nil { + t.Fatalf("Models.Get returned error: %v", err) + } + + if model.Name != "models/gemma-3-4b" { + t.Fatalf("model.Name = %q, want models/gemma-3-4b", model.Name) + } +} + +func TestSDKGenAICountTokens(t *testing.T) { + var upstreamBody map[string]any + + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/tokenize" { + t.Fatalf("upstream path = %q, want /tokenize", r.URL.Path) + } + if err := json.NewDecoder(r.Body).Decode(&upstreamBody); err != nil { + t.Fatalf("decode upstream body: %v", err) + } + writeJSON(w, http.StatusOK, map[string]any{"tokens": []int{1, 2, 3, 4, 5, 6, 7}}) + }) + + proxy := newCapturedProxyHandlerForUpstream(t, upstream) + + client, err := genaisdk.NewClient(context.Background(), &genaisdk.ClientConfig{ + APIKey: "test", + Backend: genaisdk.BackendGeminiAPI, + HTTPClient: &http.Client{ + Transport: newHandlerTransport(proxy), + }, + HTTPOptions: genaisdk.HTTPOptions{ + BaseURL: "http://localaik.test", + }, + }) + if err != nil { + t.Fatalf("genai.NewClient returned error: %v", err) + } + + resp, err := client.Models.CountTokens(context.Background(), "gemma-3-4b", genaisdk.Text("hello world"), nil) + if err != nil { + t.Fatalf("Models.CountTokens returned error: %v", err) + } + + if resp.TotalTokens != 7 { + t.Fatalf("TotalTokens = %d, want 7", resp.TotalTokens) + } + if upstreamBody["content"] != "hello world" { + t.Fatalf("upstream content = %#v, want hello world", upstreamBody["content"]) + } +} + +func TestSDKOpenAIModelsList(t *testing.T) { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/models" { + t.Fatalf("upstream path = %q, want /v1/models", r.URL.Path) + } + writeJSON(w, http.StatusOK, openaip.ModelList{ + Object: "list", + Data: []openaip.Model{ + {ID: "gemma-3-4b", Object: "model"}, + }, + }) + }) + + proxy := newCapturedProxyHandlerForUpstream(t, upstream) + + client := openaisdk.NewClient( + option.WithBaseURL("http://localaik.test/v1/"), + option.WithAPIKey("test"), + option.WithHTTPClient(&http.Client{Transport: newHandlerTransport(proxy)}), + ) + + page, err := client.Models.List(context.Background()) + if err != nil { + t.Fatalf("Models.List returned error: %v", err) + } + if len(page.Data) != 1 || page.Data[0].ID != "gemma-3-4b" { + t.Fatalf("page.Data = %#v", page.Data) + } +} + +func TestSDKOpenAIModelsGet(t *testing.T) { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/models/gemma-3-4b" { + t.Fatalf("upstream path = %q, want /v1/models/gemma-3-4b", r.URL.Path) + } + writeJSON(w, http.StatusOK, openaip.Model{ID: "gemma-3-4b", Object: "model"}) + }) + + proxy := newCapturedProxyHandlerForUpstream(t, upstream) + + client := openaisdk.NewClient( + option.WithBaseURL("http://localaik.test/v1/"), + option.WithAPIKey("test"), + option.WithHTTPClient(&http.Client{Transport: newHandlerTransport(proxy)}), + ) + + model, err := client.Models.Get(context.Background(), "gemma-3-4b") + if err != nil { + t.Fatalf("Models.Get returned error: %v", err) + } + if model.ID != "gemma-3-4b" { + t.Fatalf("model.ID = %q, want gemma-3-4b", model.ID) + } +} + +func TestSDKOpenAILegacyCompletions(t *testing.T) { + var upstreamBody map[string]any + + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/completions" { + t.Fatalf("upstream path = %q, want /v1/completions", r.URL.Path) + } + if err := json.NewDecoder(r.Body).Decode(&upstreamBody); err != nil { + t.Fatalf("decode upstream body: %v", err) + } + writeJSON(w, http.StatusOK, map[string]any{ + "id": "cmpl_1", + "object": "text_completion", + "created": 1, + "model": "localaik", + "choices": []map[string]any{ + {"index": 0, "text": "pong", "finish_reason": "stop"}, + }, + }) + }) + + proxy := newCapturedProxyHandlerForUpstream(t, upstream) + + client := openaisdk.NewClient( + option.WithBaseURL("http://localaik.test/v1/"), + option.WithAPIKey("test"), + option.WithHTTPClient(&http.Client{Transport: newHandlerTransport(proxy)}), + ) + + resp, err := client.Completions.New(context.Background(), openaisdk.CompletionNewParams{ + Model: "localaik", + Prompt: openaisdk.CompletionNewParamsPromptUnion{OfString: openaisdk.String("ping")}, + }) + if err != nil { + t.Fatalf("Completions.New returned error: %v", err) + } + if resp.Choices[0].Text != "pong" { + t.Fatalf("response = %#v", resp.Choices) + } + if upstreamBody["prompt"] != "ping" { + t.Fatalf("upstream prompt = %#v, want ping", upstreamBody["prompt"]) + } +} + +// newCapturedProxyHandlerForUpstream wires up the localaik proxy with an +// arbitrary upstream handler so individual tests can stub /v1/models, +// /tokenize, /v1/completions, etc. without sharing routing logic with the +// chat-completions tests in sdk_test.go. +func newCapturedProxyHandlerForUpstream(t *testing.T, upstream http.Handler) http.Handler { + t.Helper() + + mux := http.NewServeMux() + mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + // Route everything else to the test-supplied upstream. + mux.Handle("/", upstream) + + proxyServer, err := server.New(server.Config{ + UpstreamBaseURL: "http://upstream.test/v1", + HTTPClient: &http.Client{ + Transport: newHandlerTransport(mux), + }, + PDFRenderer: pdf.RendererFunc(func(_ context.Context, _ []byte) ([][]byte, error) { return nil, nil }), + }) + if err != nil { + t.Fatalf("server.New returned error: %v", err) + } + + return proxyServer +} diff --git a/internal/protocol/gemini/types.go b/internal/protocol/gemini/types.go index a4f1b1a..c8b2867 100644 --- a/internal/protocol/gemini/types.go +++ b/internal/protocol/gemini/types.go @@ -133,6 +133,29 @@ type UsageMetadata struct { TotalTokenCount int `json:"totalTokenCount,omitempty"` } +type Model struct { + Name string `json:"name,omitempty"` + DisplayName string `json:"displayName,omitempty"` + Description string `json:"description,omitempty"` + Version string `json:"version,omitempty"` + InputTokenLimit int `json:"inputTokenLimit,omitempty"` + OutputTokenLimit int `json:"outputTokenLimit,omitempty"` + SupportedGenerationMethods []string `json:"supportedGenerationMethods,omitempty"` +} + +type ListModelsResponse struct { + Models []Model `json:"models"` + NextPageToken string `json:"nextPageToken,omitempty"` +} + +type CountTokensRequest struct { + Contents []Content `json:"contents,omitempty"` +} + +type CountTokensResponse struct { + TotalTokens int `json:"totalTokens"` +} + type ErrorResponse struct { Error Error `json:"error"` } diff --git a/internal/protocol/openai/types.go b/internal/protocol/openai/types.go index f7dcbe4..3c6f1f4 100644 --- a/internal/protocol/openai/types.go +++ b/internal/protocol/openai/types.go @@ -119,6 +119,18 @@ type ToolCallDelta struct { Function *ToolCallFunction `json:"function,omitempty"` } +type Model struct { + ID string `json:"id"` + Object string `json:"object,omitempty"` + Created int64 `json:"created,omitempty"` + OwnedBy string `json:"owned_by,omitempty"` +} + +type ModelList struct { + Object string `json:"object,omitempty"` + Data []Model `json:"data"` +} + type ErrorResponse struct { Error Error `json:"error"` } diff --git a/internal/server/gemini_meta.go b/internal/server/gemini_meta.go new file mode 100644 index 0000000..b38bf24 --- /dev/null +++ b/internal/server/gemini_meta.go @@ -0,0 +1,101 @@ +package server + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + + "github.com/harshaneel/localaik/internal/protocol/gemini" + openaip "github.com/harshaneel/localaik/internal/protocol/openai" + "github.com/harshaneel/localaik/internal/translate" +) + +func (s *Server) handleGeminiModelsList(w http.ResponseWriter, r *http.Request) { + var upstream openaip.ModelList + status, body, err := s.fetchUpstreamJSON(r, s.upstreamModelsURL, &upstream) + if err != nil { + gemini.WriteError(w, http.StatusBadGateway, fmt.Sprintf("failed to reach upstream: %v", err)) + return + } + if status >= http.StatusBadRequest { + writeJSON(w, status, translate.OpenAIErrorToGemini(status, body)) + return + } + writeJSON(w, http.StatusOK, translate.OpenAIModelListToGemini(upstream)) +} + +func (s *Server) handleGeminiModelGet(w http.ResponseWriter, r *http.Request) { + modelName := strings.TrimPrefix(r.URL.Path, "/v1beta/models/") + if modelName == "" || strings.ContainsAny(modelName, "/:") { + gemini.WriteError(w, http.StatusNotFound, "route not found") + return + } + + var upstream openaip.Model + status, body, err := s.fetchUpstreamJSON(r, s.upstreamModelsURL+"/"+modelName, &upstream) + if err != nil { + gemini.WriteError(w, http.StatusBadGateway, fmt.Sprintf("failed to reach upstream: %v", err)) + return + } + if status >= http.StatusBadRequest { + writeJSON(w, status, translate.OpenAIErrorToGemini(status, body)) + return + } + writeJSON(w, http.StatusOK, translate.OpenAIModelToGemini(upstream)) +} + +func (s *Server) handleGeminiCountTokens(w http.ResponseWriter, r *http.Request) { + defer r.Body.Close() + + var req gemini.CountTokensRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + gemini.WriteError(w, http.StatusBadRequest, fmt.Sprintf("invalid countTokens request: %v", err)) + return + } + + upstreamPayload, err := json.Marshal(map[string]any{ + "content": translate.CountTokensTextFromGemini(req.Contents), + "add_special": false, + }) + if err != nil { + gemini.WriteError(w, http.StatusInternalServerError, "failed to serialize upstream request") + return + } + + upstreamReq, err := http.NewRequestWithContext(r.Context(), http.MethodPost, s.upstreamTokenizeURL, bytes.NewReader(upstreamPayload)) + if err != nil { + gemini.WriteError(w, http.StatusInternalServerError, "failed to create upstream request") + return + } + upstreamReq.Header.Set("Content-Type", "application/json") + + resp, err := s.client.Do(upstreamReq) + if err != nil { + gemini.WriteError(w, http.StatusBadGateway, fmt.Sprintf("failed to reach upstream: %v", err)) + return + } + defer resp.Body.Close() + + body, readErr := io.ReadAll(resp.Body) + if readErr != nil { + gemini.WriteError(w, http.StatusBadGateway, fmt.Sprintf("failed to read upstream response: %v", readErr)) + return + } + if resp.StatusCode >= http.StatusBadRequest { + writeJSON(w, resp.StatusCode, translate.OpenAIErrorToGemini(resp.StatusCode, body)) + return + } + + var upstreamResp struct { + Tokens []any `json:"tokens"` + } + if err := json.Unmarshal(body, &upstreamResp); err != nil { + gemini.WriteError(w, http.StatusBadGateway, fmt.Sprintf("failed to parse upstream response: %v", err)) + return + } + + writeJSON(w, http.StatusOK, gemini.CountTokensResponse{TotalTokens: len(upstreamResp.Tokens)}) +} diff --git a/internal/server/meta_test.go b/internal/server/meta_test.go new file mode 100644 index 0000000..723194a --- /dev/null +++ b/internal/server/meta_test.go @@ -0,0 +1,206 @@ +package server + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/harshaneel/localaik/internal/pdf" + "github.com/harshaneel/localaik/internal/protocol/gemini" + openaip "github.com/harshaneel/localaik/internal/protocol/openai" +) + +func newTestServer(t *testing.T, upstream http.Handler) *Server { + t.Helper() + srv, err := New(Config{ + UpstreamBaseURL: "http://upstream.test/v1", + HTTPClient: &http.Client{ + Transport: roundTripHandler{handler: upstream}, + }, + PDFRenderer: pdf.RendererFunc(func(_ context.Context, _ []byte) ([][]byte, error) { return nil, nil }), + }) + if err != nil { + t.Fatalf("New returned error: %v", err) + } + return srv +} + +func TestServerOpenAILegacyCompletionsPassthrough(t *testing.T) { + var seenPath, seenMethod string + + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + seenPath, seenMethod = r.URL.Path, r.Method + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"id":"cmpl_1","object":"text_completion"}`)) + }) + + srv := newTestServer(t, upstream) + + req := httptest.NewRequest(http.MethodPost, "/v1/completions", bytes.NewBufferString(`{"prompt":"hi"}`)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + srv.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200; body=%s", rec.Code, rec.Body.String()) + } + if seenPath != "/v1/completions" || seenMethod != http.MethodPost { + t.Fatalf("upstream got %s %s, want POST /v1/completions", seenMethod, seenPath) + } +} + +func TestServerOpenAIModelsList(t *testing.T) { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/models" || r.Method != http.MethodGet { + t.Fatalf("upstream got %s %s, want GET /v1/models", r.Method, r.URL.Path) + } + writeJSON(w, http.StatusOK, openaip.ModelList{ + Object: "list", + Data: []openaip.Model{{ID: "gemma-3-4b", Object: "model"}}, + }) + }) + + srv := newTestServer(t, upstream) + + req := httptest.NewRequest(http.MethodGet, "/v1/models", nil) + rec := httptest.NewRecorder() + srv.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rec.Code) + } + var got openaip.ModelList + if err := json.Unmarshal(rec.Body.Bytes(), &got); err != nil { + t.Fatalf("decode response: %v", err) + } + if len(got.Data) != 1 || got.Data[0].ID != "gemma-3-4b" { + t.Fatalf("response = %#v", got) + } +} + +func TestServerOpenAIModelGet(t *testing.T) { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/models/gemma-3-4b" || r.Method != http.MethodGet { + t.Fatalf("upstream got %s %s, want GET /v1/models/gemma-3-4b", r.Method, r.URL.Path) + } + writeJSON(w, http.StatusOK, openaip.Model{ID: "gemma-3-4b", Object: "model"}) + }) + + srv := newTestServer(t, upstream) + + req := httptest.NewRequest(http.MethodGet, "/v1/models/gemma-3-4b", nil) + rec := httptest.NewRecorder() + srv.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rec.Code) + } +} + +func TestServerGeminiModelsList(t *testing.T) { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/models" || r.Method != http.MethodGet { + t.Fatalf("upstream got %s %s, want GET /v1/models", r.Method, r.URL.Path) + } + writeJSON(w, http.StatusOK, openaip.ModelList{ + Object: "list", + Data: []openaip.Model{{ID: "gemma-3-4b", Object: "model"}}, + }) + }) + + srv := newTestServer(t, upstream) + + req := httptest.NewRequest(http.MethodGet, "/v1beta/models", nil) + rec := httptest.NewRecorder() + srv.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200; body=%s", rec.Code, rec.Body.String()) + } + var got gemini.ListModelsResponse + if err := json.Unmarshal(rec.Body.Bytes(), &got); err != nil { + t.Fatalf("decode response: %v", err) + } + if len(got.Models) != 1 || got.Models[0].Name != "models/gemma-3-4b" { + t.Fatalf("response = %#v", got) + } +} + +func TestServerGeminiModelGet(t *testing.T) { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/models/gemma-3-4b" { + t.Fatalf("upstream path = %q, want /v1/models/gemma-3-4b", r.URL.Path) + } + writeJSON(w, http.StatusOK, openaip.Model{ID: "gemma-3-4b"}) + }) + + srv := newTestServer(t, upstream) + + req := httptest.NewRequest(http.MethodGet, "/v1beta/models/gemma-3-4b", nil) + rec := httptest.NewRecorder() + srv.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200; body=%s", rec.Code, rec.Body.String()) + } + var got gemini.Model + if err := json.Unmarshal(rec.Body.Bytes(), &got); err != nil { + t.Fatalf("decode response: %v", err) + } + if got.Name != "models/gemma-3-4b" { + t.Fatalf("response = %#v", got) + } +} + +func TestServerGeminiCountTokens(t *testing.T) { + var upstreamBody map[string]any + + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/tokenize" || r.Method != http.MethodPost { + t.Fatalf("upstream got %s %s, want POST /tokenize", r.Method, r.URL.Path) + } + if err := json.NewDecoder(r.Body).Decode(&upstreamBody); err != nil { + t.Fatalf("decode upstream body: %v", err) + } + writeJSON(w, http.StatusOK, map[string]any{"tokens": []int{1, 2, 3, 4, 5}}) + }) + + srv := newTestServer(t, upstream) + + body := `{"contents":[{"role":"user","parts":[{"text":"hello"},{"text":"world"}]}]}` + req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemma-3-4b:countTokens", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + srv.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200; body=%s", rec.Code, rec.Body.String()) + } + if upstreamBody["content"] != "hello\nworld" { + t.Fatalf("upstream content = %#v, want hello\\nworld", upstreamBody["content"]) + } + var got gemini.CountTokensResponse + if err := json.Unmarshal(rec.Body.Bytes(), &got); err != nil { + t.Fatalf("decode response: %v", err) + } + if got.TotalTokens != 5 { + t.Fatalf("totalTokens = %d, want 5", got.TotalTokens) + } +} + +func TestServerGeminiModelGetMalformedPath(t *testing.T) { + srv := newTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Fatalf("upstream should not be reached, got %s", r.URL.Path) + })) + + req := httptest.NewRequest(http.MethodGet, "/v1beta/models/foo/bar", nil) + rec := httptest.NewRecorder() + srv.ServeHTTP(rec, req) + + if rec.Code != http.StatusNotFound { + t.Fatalf("status = %d, want 404", rec.Code) + } +} diff --git a/internal/server/openai.go b/internal/server/openai.go index cbc7d1d..ed386ca 100644 --- a/internal/server/openai.go +++ b/internal/server/openai.go @@ -1,15 +1,17 @@ package server import ( + "encoding/json" "fmt" "io" "net/http" + "strings" openaip "github.com/harshaneel/localaik/internal/protocol/openai" ) -func (s *Server) handleOpenAIChatCompletions(w http.ResponseWriter, r *http.Request) { - req, err := http.NewRequestWithContext(r.Context(), http.MethodPost, s.upstreamChatURL, r.Body) +func (s *Server) handleOpenAIPassthrough(w http.ResponseWriter, r *http.Request, upstreamURL string) { + req, err := http.NewRequestWithContext(r.Context(), r.Method, upstreamURL, r.Body) if err != nil { openaip.WriteError(w, http.StatusInternalServerError, "failed to create upstream request", "server_error") return @@ -34,6 +36,19 @@ func (s *Server) handleOpenAIChatCompletions(w http.ResponseWriter, r *http.Requ _, _ = io.Copy(w, resp.Body) } +func (s *Server) handleOpenAIModelsList(w http.ResponseWriter, r *http.Request) { + s.handleOpenAIPassthrough(w, r, s.upstreamModelsURL) +} + +func (s *Server) handleOpenAIModelGet(w http.ResponseWriter, r *http.Request) { + modelID := strings.TrimPrefix(r.URL.Path, "/v1/models/") + if modelID == "" || strings.Contains(modelID, "/") { + openaip.WriteError(w, http.StatusNotFound, "route not found", "invalid_request_error") + return + } + s.handleOpenAIPassthrough(w, r, s.upstreamModelsURL+"/"+modelID) +} + type flushWriter struct { Writer io.Writer Flusher http.Flusher @@ -46,3 +61,35 @@ func (w flushWriter) Write(p []byte) (int, error) { } return n, err } + +// fetchUpstreamJSON issues a GET to upstreamURL and decodes the JSON response +// into dst. On non-2xx status it returns (status, body, nil) so the caller can +// translate the upstream error. On decode failure it returns (status, body, err) +// — body is the raw upstream payload but is malformed JSON, so callers should +// surface the error rather than the body. No request headers are forwarded; +// this is intentional so that client-side auth (Authorization, X-Goog-Api-Key) +// does not leak to upstream. +func (s *Server) fetchUpstreamJSON(r *http.Request, upstreamURL string, dst any) (int, []byte, error) { + req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, upstreamURL, nil) + if err != nil { + return 0, nil, fmt.Errorf("create upstream request: %w", err) + } + + resp, err := s.client.Do(req) + if err != nil { + return 0, nil, fmt.Errorf("reach upstream: %w", err) + } + defer resp.Body.Close() + + body, readErr := io.ReadAll(resp.Body) + if readErr != nil { + return resp.StatusCode, nil, fmt.Errorf("read upstream body: %w", readErr) + } + if resp.StatusCode >= http.StatusBadRequest { + return resp.StatusCode, body, nil + } + if err := json.Unmarshal(body, dst); err != nil { + return resp.StatusCode, body, fmt.Errorf("decode upstream body: %w", err) + } + return resp.StatusCode, body, nil +} diff --git a/internal/server/server.go b/internal/server/server.go index f6339a8..0407570 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -21,10 +21,13 @@ type Config struct { } type Server struct { - client *http.Client - pdfRenderer pdf.Renderer - upstreamChatURL string - upstreamHealthURL string + client *http.Client + pdfRenderer pdf.Renderer + upstreamChatURL string + upstreamCompletions string + upstreamModelsURL string + upstreamTokenizeURL string + upstreamHealthURL string } func New(cfg Config) (*Server, error) { @@ -51,10 +54,13 @@ func New(cfg Config) (*Server, error) { } return &Server{ - client: client, - pdfRenderer: renderer, - upstreamChatURL: resolveURLPath(parsed, "chat/completions"), - upstreamHealthURL: deriveHealthURL(parsed), + client: client, + pdfRenderer: renderer, + upstreamChatURL: resolveURLPath(parsed, "chat/completions"), + upstreamCompletions: resolveURLPath(parsed, "completions"), + upstreamModelsURL: resolveURLPath(parsed, "models"), + upstreamTokenizeURL: deriveRootURL(parsed, "tokenize"), + upstreamHealthURL: deriveRootURL(parsed, "health"), }, nil } @@ -63,11 +69,23 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { case r.Method == http.MethodGet && r.URL.Path == "/health": s.handleHealth(w, r) case r.Method == http.MethodPost && r.URL.Path == "/v1/chat/completions": - s.handleOpenAIChatCompletions(w, r) + s.handleOpenAIPassthrough(w, r, s.upstreamChatURL) + case r.Method == http.MethodPost && r.URL.Path == "/v1/completions": + s.handleOpenAIPassthrough(w, r, s.upstreamCompletions) + case r.Method == http.MethodGet && r.URL.Path == "/v1/models": + s.handleOpenAIModelsList(w, r) + case r.Method == http.MethodGet && strings.HasPrefix(r.URL.Path, "/v1/models/"): + s.handleOpenAIModelGet(w, r) + case r.Method == http.MethodGet && r.URL.Path == "/v1beta/models": + s.handleGeminiModelsList(w, r) case r.Method == http.MethodPost && strings.HasPrefix(r.URL.Path, "/v1beta/models/") && strings.HasSuffix(r.URL.Path, ":generateContent"): s.handleGeminiGenerateContent(w, r, false) case r.Method == http.MethodPost && strings.HasPrefix(r.URL.Path, "/v1beta/models/") && strings.HasSuffix(r.URL.Path, ":streamGenerateContent"): s.handleGeminiGenerateContent(w, r, true) + case r.Method == http.MethodPost && strings.HasPrefix(r.URL.Path, "/v1beta/models/") && strings.HasSuffix(r.URL.Path, ":countTokens"): + s.handleGeminiCountTokens(w, r) + case r.Method == http.MethodGet && strings.HasPrefix(r.URL.Path, "/v1beta/models/"): + s.handleGeminiModelGet(w, r) default: s.handleNotFound(w, r) } @@ -113,11 +131,11 @@ func resolveURLPath(base *url.URL, extra string) string { return clone.String() } -func deriveHealthURL(base *url.URL) string { +func deriveRootURL(base *url.URL, extra string) string { clone := *base basePath := strings.TrimSuffix(clone.Path, "/") basePath = strings.TrimSuffix(basePath, "/v1") - clone.Path = joinURLPath(basePath, "health") + clone.Path = joinURLPath(basePath, extra) return clone.String() } diff --git a/internal/translate/models.go b/internal/translate/models.go new file mode 100644 index 0000000..7eac682 --- /dev/null +++ b/internal/translate/models.go @@ -0,0 +1,50 @@ +package translate + +import ( + "strings" + + "github.com/harshaneel/localaik/internal/protocol/gemini" + openaip "github.com/harshaneel/localaik/internal/protocol/openai" +) + +var defaultGenerationMethods = []string{ + "generateContent", + "streamGenerateContent", + "countTokens", +} + +func OpenAIModelListToGemini(list openaip.ModelList) gemini.ListModelsResponse { + out := gemini.ListModelsResponse{Models: make([]gemini.Model, 0, len(list.Data))} + for _, m := range list.Data { + out.Models = append(out.Models, OpenAIModelToGemini(m)) + } + return out +} + +func OpenAIModelToGemini(m openaip.Model) gemini.Model { + return gemini.Model{ + Name: "models/" + m.ID, + DisplayName: m.ID, + SupportedGenerationMethods: defaultGenerationMethods, + } +} + +// CountTokensTextFromGemini flattens a Gemini countTokens request body into a +// single text payload suitable for llama.cpp's /tokenize endpoint. Non-text +// parts (inline blobs, file refs, function calls/responses) are skipped — they +// are not tokenizable by the text tokenizer. +func CountTokensTextFromGemini(contents []gemini.Content) string { + var b strings.Builder + for _, content := range contents { + for _, part := range content.Parts { + if part.Text == "" { + continue + } + if b.Len() > 0 { + b.WriteByte('\n') + } + b.WriteString(part.Text) + } + } + return b.String() +} diff --git a/internal/translate/models_test.go b/internal/translate/models_test.go new file mode 100644 index 0000000..3b1140e --- /dev/null +++ b/internal/translate/models_test.go @@ -0,0 +1,88 @@ +package translate + +import ( + "testing" + + "github.com/harshaneel/localaik/internal/protocol/gemini" + openaip "github.com/harshaneel/localaik/internal/protocol/openai" +) + +func TestOpenAIModelListToGemini(t *testing.T) { + in := openaip.ModelList{ + Object: "list", + Data: []openaip.Model{ + {ID: "gemma-3-4b", Object: "model", OwnedBy: "llama.cpp"}, + {ID: "gemma-3-12b", Object: "model"}, + }, + } + + got := OpenAIModelListToGemini(in) + + if len(got.Models) != 2 { + t.Fatalf("models = %#v, want 2 entries", got.Models) + } + if got.Models[0].Name != "models/gemma-3-4b" { + t.Fatalf("model[0].Name = %q, want models/gemma-3-4b", got.Models[0].Name) + } + if got.Models[0].DisplayName != "gemma-3-4b" { + t.Fatalf("model[0].DisplayName = %q, want gemma-3-4b", got.Models[0].DisplayName) + } + if len(got.Models[0].SupportedGenerationMethods) != 3 { + t.Fatalf("model[0].SupportedGenerationMethods = %#v", got.Models[0].SupportedGenerationMethods) + } +} + +func TestOpenAIModelToGeminiSingle(t *testing.T) { + got := OpenAIModelToGemini(openaip.Model{ID: "gemma-3-4b"}) + if got.Name != "models/gemma-3-4b" || got.DisplayName != "gemma-3-4b" { + t.Fatalf("unexpected mapping: %#v", got) + } +} + +func TestCountTokensTextFromGemini(t *testing.T) { + tests := []struct { + name string + in []gemini.Content + want string + }{ + { + name: "single text part", + in: []gemini.Content{ + {Parts: []gemini.Part{{Text: "hello"}}}, + }, + want: "hello", + }, + { + name: "multiple parts joined with newline", + in: []gemini.Content{ + {Parts: []gemini.Part{{Text: "hello"}}}, + {Parts: []gemini.Part{{Text: "world"}}}, + }, + want: "hello\nworld", + }, + { + name: "skips non-text parts", + in: []gemini.Content{ + {Parts: []gemini.Part{ + {Text: "describe"}, + {InlineData: &gemini.Blob{MimeType: "image/png", Data: "AAA"}}, + {Text: "this image"}, + }}, + }, + want: "describe\nthis image", + }, + { + name: "empty input", + in: nil, + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := CountTokensTextFromGemini(tt.in); got != tt.want { + t.Fatalf("got %q, want %q", got, tt.want) + } + }) + } +} From be9abfd9dd739fbaa73a0bfcb850b174d5c80d0b Mon Sep 17 00:00:00 2001 From: Harshaneel Gokhale Date: Mon, 18 May 2026 11:49:58 -0700 Subject: [PATCH 2/3] test: Cover error paths, method confusion, and auth-header stripping MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Closes five test gaps identified in review: - Upstream-error → Gemini-error translation for /v1beta/models and :countTokens (TestServer*UpstreamError) - Method confusion: POST /v1/models, PUT /v1beta/models, GET on POST routes all return 404 instead of being silently misrouted - OpenAI model-get path-traversal guard (was Gemini-only) - Gemini action-verb collision: GET /v1beta/models/foo:bar must 404, not route to handleGeminiModelGet - Auth header (Authorization, X-Goog-Api-Key) stripping on the new passthrough routes (legacy completions, models list, model get) Co-Authored-By: Claude Opus 4.7 (1M context) --- internal/server/meta_test.go | 144 ++++++++++++++++++++++++++++++++++- 1 file changed, 143 insertions(+), 1 deletion(-) diff --git a/internal/server/meta_test.go b/internal/server/meta_test.go index 723194a..6598691 100644 --- a/internal/server/meta_test.go +++ b/internal/server/meta_test.go @@ -196,7 +196,29 @@ func TestServerGeminiModelGetMalformedPath(t *testing.T) { t.Fatalf("upstream should not be reached, got %s", r.URL.Path) })) - req := httptest.NewRequest(http.MethodGet, "/v1beta/models/foo/bar", nil) + cases := []string{ + "/v1beta/models/foo/bar", // path traversal + "/v1beta/models/foo:bar", // action-verb collision + "/v1beta/models/foo:list", // action-verb collision + } + for _, path := range cases { + t.Run(path, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, path, nil) + rec := httptest.NewRecorder() + srv.ServeHTTP(rec, req) + if rec.Code != http.StatusNotFound { + t.Fatalf("status = %d, want 404", rec.Code) + } + }) + } +} + +func TestServerOpenAIModelGetMalformedPath(t *testing.T) { + srv := newTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Fatalf("upstream should not be reached, got %s", r.URL.Path) + })) + + req := httptest.NewRequest(http.MethodGet, "/v1/models/foo/bar", nil) rec := httptest.NewRecorder() srv.ServeHTTP(rec, req) @@ -204,3 +226,123 @@ func TestServerGeminiModelGetMalformedPath(t *testing.T) { t.Fatalf("status = %d, want 404", rec.Code) } } + +func TestServerMethodConfusion(t *testing.T) { + srv := newTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Fatalf("upstream should not be reached, got %s %s", r.Method, r.URL.Path) + })) + + cases := []struct { + method string + path string + }{ + {http.MethodPost, "/v1/models"}, + {http.MethodPost, "/v1/models/gemma-3-4b"}, + {http.MethodPut, "/v1beta/models"}, + {http.MethodPut, "/v1beta/models/gemma-3-4b"}, + {http.MethodGet, "/v1/chat/completions"}, + {http.MethodGet, "/v1/completions"}, + {http.MethodGet, "/v1beta/models/gemma-3-4b:countTokens"}, + } + for _, tc := range cases { + t.Run(tc.method+" "+tc.path, func(t *testing.T) { + req := httptest.NewRequest(tc.method, tc.path, nil) + rec := httptest.NewRecorder() + srv.ServeHTTP(rec, req) + if rec.Code != http.StatusNotFound { + t.Fatalf("status = %d, want 404", rec.Code) + } + }) + } +} + +func TestServerGeminiModelsListUpstreamError(t *testing.T) { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(`{"error":{"message":"upstream blew up","type":"server_error"}}`)) + }) + + srv := newTestServer(t, upstream) + + req := httptest.NewRequest(http.MethodGet, "/v1beta/models", nil) + rec := httptest.NewRecorder() + srv.ServeHTTP(rec, req) + + if rec.Code != http.StatusInternalServerError { + t.Fatalf("status = %d, want 500", rec.Code) + } + var errResp gemini.ErrorResponse + if err := json.Unmarshal(rec.Body.Bytes(), &errResp); err != nil { + t.Fatalf("decode response: %v; body=%s", err, rec.Body.String()) + } + if errResp.Error.Message == "" { + t.Fatalf("error message missing; body=%s", rec.Body.String()) + } +} + +func TestServerGeminiCountTokensUpstreamError(t *testing.T) { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"error":{"message":"bad tokens","type":"invalid_request_error"}}`)) + }) + + srv := newTestServer(t, upstream) + + body := `{"contents":[{"parts":[{"text":"hi"}]}]}` + req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemma-3-4b:countTokens", bytes.NewBufferString(body)) + rec := httptest.NewRecorder() + srv.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want 400; body=%s", rec.Code, rec.Body.String()) + } + var errResp gemini.ErrorResponse + if err := json.Unmarshal(rec.Body.Bytes(), &errResp); err != nil { + t.Fatalf("decode response: %v", err) + } + if errResp.Error.Message == "" { + t.Fatalf("error message missing; body=%s", rec.Body.String()) + } +} + +func TestServerPassthroughStripsAuthHeaders(t *testing.T) { + var seenAuth, seenGoogKey string + + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + seenAuth = r.Header.Get("Authorization") + seenGoogKey = r.Header.Get("X-Goog-Api-Key") + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{}`)) + }) + + cases := []struct { + name string + method string + path string + }{ + {"legacy_completions", http.MethodPost, "/v1/completions"}, + {"openai_models_list", http.MethodGet, "/v1/models"}, + {"openai_model_get", http.MethodGet, "/v1/models/gemma-3-4b"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + seenAuth, seenGoogKey = "", "" + srv := newTestServer(t, upstream) + + req := httptest.NewRequest(tc.method, tc.path, bytes.NewBufferString(`{}`)) + req.Header.Set("Authorization", "Bearer secret") + req.Header.Set("X-Goog-Api-Key", "secret-key") + rec := httptest.NewRecorder() + srv.ServeHTTP(rec, req) + + if seenAuth != "" { + t.Fatalf("Authorization leaked upstream: %q", seenAuth) + } + if seenGoogKey != "" { + t.Fatalf("X-Goog-Api-Key leaked upstream: %q", seenGoogKey) + } + }) + } +} From 337971068b18daa27cc4dda3d27ba5f3da153766 Mon Sep 17 00:00:00 2001 From: Harshaneel Gokhale Date: Mon, 18 May 2026 11:52:52 -0700 Subject: [PATCH 3/3] test: Address review findings on test-gap commit MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Move closure vars (seenAuth, seenGoogKey, upstream) inside the TestServerPassthroughStripsAuthHeaders sub-test so each case is fully isolated and the test is safe under future t.Parallel() - Use nil body on GET sub-cases (was bytes.NewBufferString("{}")) to match the convention used elsewhere in the file - Add POST /v1beta/models to TestServerMethodConfusion — covers the prefix-route regression target alongside the existing PUT case Co-Authored-By: Claude Opus 4.7 (1M context) --- internal/server/meta_test.go | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/internal/server/meta_test.go b/internal/server/meta_test.go index 6598691..060c453 100644 --- a/internal/server/meta_test.go +++ b/internal/server/meta_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "io" "net/http" "net/http/httptest" "testing" @@ -238,6 +239,7 @@ func TestServerMethodConfusion(t *testing.T) { }{ {http.MethodPost, "/v1/models"}, {http.MethodPost, "/v1/models/gemma-3-4b"}, + {http.MethodPost, "/v1beta/models"}, {http.MethodPut, "/v1beta/models"}, {http.MethodPut, "/v1beta/models/gemma-3-4b"}, {http.MethodGet, "/v1/chat/completions"}, @@ -308,15 +310,6 @@ func TestServerGeminiCountTokensUpstreamError(t *testing.T) { } func TestServerPassthroughStripsAuthHeaders(t *testing.T) { - var seenAuth, seenGoogKey string - - upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - seenAuth = r.Header.Get("Authorization") - seenGoogKey = r.Header.Get("X-Goog-Api-Key") - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{}`)) - }) - cases := []struct { name string method string @@ -328,10 +321,20 @@ func TestServerPassthroughStripsAuthHeaders(t *testing.T) { } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - seenAuth, seenGoogKey = "", "" + var seenAuth, seenGoogKey string + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + seenAuth = r.Header.Get("Authorization") + seenGoogKey = r.Header.Get("X-Goog-Api-Key") + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{}`)) + }) srv := newTestServer(t, upstream) - req := httptest.NewRequest(tc.method, tc.path, bytes.NewBufferString(`{}`)) + var body io.Reader + if tc.method == http.MethodPost { + body = bytes.NewBufferString(`{}`) + } + req := httptest.NewRequest(tc.method, tc.path, body) req.Header.Set("Authorization", "Bearer secret") req.Header.Set("X-Goog-Api-Key", "secret-key") rec := httptest.NewRecorder()