diff --git a/README.md b/README.md index 53acdd5178..77b8667b2f 100644 --- a/README.md +++ b/README.md @@ -42,6 +42,12 @@ Get 10% OFF GLM CODING PLAN:https://z.ai/subscribe?ic=8JVLJQFSKB PoixeAI Thanks to Poixe AI for sponsoring this project! Poixe AI provides reliable LLM API services. You can leverage the platform's API endpoints to seamlessly build AI-powered products. Additionally, you can become a vendor by providing AI API resources to the platform and earn revenue. Register through the exclusive CLIProxyAPI referral link and receive a bonus of $5 USD on your first top-up. + +VisionCoder +Thanks to VisionCoder for supporting this project. VisionCoder Developer Platform is a reliable and efficient API relay service provider, offering access to mainstream AI models such as Claude Code, Codex, and Gemini. It helps developers and teams integrate AI capabilities more easily and improve productivity. +

+VisionCoder is also offering our users a limited-time Token Plan promotion: buy 1 month and get 1 month free. + diff --git a/README_CN.md b/README_CN.md index 86ea954209..75d50e7ac1 100644 --- a/README_CN.md +++ b/README_CN.md @@ -24,23 +24,29 @@ GLM CODING PLAN 是专为AI编码打造的订阅套餐,每月最低仅需20元 PackyCode -感谢 PackyCode 对本项目的赞助!PackyCode 是一家可靠高效的 API 中转服务商,提供 Claude Code、Codex、Gemini 等多种服务的中转。PackyCode 为本软件用户提供了特别优惠:使用此链接注册,并在充值时输入 "cliproxyapi" 优惠码即可享受九折优惠。 +感谢 PackyCode 对本项目的赞助!PackyCode 是一家可靠高效的 API 中转服务商,提供 Claude Code、Codex、Gemini 等多种服务的中转。PackyCode 为本软件用户提供了特别优惠:使用此链接注册,并在充值时输入 "cliproxyapi" 优惠码即可享受九折优惠。 AICodeMirror -感谢 AICodeMirror 赞助了本项目!AICodeMirror 提供 Claude Code / Codex / Gemini CLI 官方高稳定中转服务,支持企业级高并发、极速开票、7×24 专属技术支持。 Claude Code / Codex / Gemini 官方渠道低至 3.8 / 0.2 / 0.9 折,充值更有折上折!AICodeMirror 为 CLIProxyAPI 的用户提供了特别福利,通过此链接注册的用户,可享受首充8折,企业客户最高可享 7.5 折! +感谢 AICodeMirror 赞助了本项目!AICodeMirror 提供 Claude Code / Codex / Gemini CLI 官方高稳定中转服务,支持企业级高并发、极速开票、7×24 专属技术支持。 Claude Code / Codex / Gemini 官方渠道低至 3.8 / 0.2 / 0.9 折,充值更有折上折!AICodeMirror 为 CLIProxyAPI 的用户提供了特别福利,通过此链接注册的用户,可享受首充8折,企业客户最高可享 7.5 折! BmoPlus -感谢 BmoPlus 赞助了本项目!BmoPlus 是一家专为AI订阅重度用户打造的可靠 AI 账号代充服务商,提供稳定的 ChatGPT Plus / ChatGPT Pro(全程质保) / Claude Pro / Super Grok / Gemini Pro 的官方代充&成品账号。 通过BmoPlus AI成品号专卖/代充注册下单的用户,可享GPT 官网订阅一折 的震撼价格! +感谢 BmoPlus 赞助了本项目!BmoPlus 是一家专为AI订阅重度用户打造的可靠 AI 账号代充服务商,提供稳定的 ChatGPT Plus / ChatGPT Pro(全程质保) / Claude Pro / Super Grok / Gemini Pro 的官方代充&成品账号。 通过BmoPlus AI成品号专卖/代充注册下单的用户,可享GPT 官网订阅一折 的震撼价格! 
LingtrueAPI -感谢 LingtrueAPI 对本项目的赞助!LingtrueAPI 是一家全球大模型API中转服务平台,提供Claude Code、Codex、Gemini 等多种顶级模型API调用服务,致力于让用户以低成本、高稳定性链接全球AI能力。LingtrueAPI为本软件用户提供了特别优惠:使用此链接注册,并在首次充值时输入 "LingtrueAPI" 优惠码即可享受9折优惠。 +感谢 LingtrueAPI 对本项目的赞助!LingtrueAPI 是一家全球大模型API中转服务平台,提供Claude Code、Codex、Gemini 等多种顶级模型API调用服务,致力于让用户以低成本、高稳定性链接全球AI能力。LingtrueAPI为本软件用户提供了特别优惠:使用此链接注册,并在首次充值时输入 "LingtrueAPI" 优惠码即可享受9折优惠。 PoixeAI -感谢 Poixe AI 对本项目的赞助!Poixe AI 提供可靠的 AI 模型接口服务,您可以使用平台提供的 LLM API 接口轻松构建 AI 产品,同时也可以成为供应商,为平台提供大模型资源以赚取收益。通过 CLIProxyAPI 专属链接注册,充值额外赠送 $5 美金 +感谢 Poixe AI 对本项目的赞助!Poixe AI 提供可靠的 AI 模型接口服务,您可以使用平台提供的 LLM API 接口轻松构建 AI 产品,同时也可以成为供应商,为平台提供大模型资源以赚取收益。通过 CLIProxyAPI 专属链接注册,充值额外赠送 $5 美金 + + +VisionCoder +感谢 VisionCoder 对本项目的支持。VisionCoder 开发平台 是一个可靠高效的 API 中继服务提供商,提供 Claude Code、Codex、Gemini 等主流 AI 模型,帮助开发者和团队更轻松地集成 AI 功能,提升工作效率。 +

+VisionCoder 还为我们的用户提供 Token Plan 限时活动:购买 1 个月,赠送 1 个月。 diff --git a/README_JA.md b/README_JA.md index 8c34325b49..cf8a0f77d8 100644 --- a/README_JA.md +++ b/README_JA.md @@ -42,6 +42,10 @@ GLM CODING PLANを10%割引で取得:https://z.ai/subscribe?ic=8JVLJQFSKB PoixeAI Poixe AIのスポンサーシップに感謝します!Poixe AIは信頼できるAIモデルAPIサービスを提供しており、プラットフォームが提供するLLM APIを使って簡単にAI製品を構築できます。また、サプライヤーとしてプラットフォームに大規模モデルのリソースを提供し、収益を得ることも可能です。CLIProxyAPIの専用リンクから登録すると、チャージ時に追加で$5が付与されます。 + +VisionCoder +VisionCoderのご支援に感謝します!VisionCoder 開発プラットフォーム は、信頼性が高く効率的なAPIリレーサービスプロバイダーで、Claude Code、Codex、Geminiなどの主要AIモデルを提供し、開発者やチームがより簡単にAI機能を統合して生産性を向上できるよう支援します。さらに、VisionCoderはユーザー向けに Token Plan の期間限定キャンペーン(1か月購入で1か月分プレゼント)も提供しています。 + diff --git a/assets/visioncoder.png b/assets/visioncoder.png new file mode 100644 index 0000000000..24b1760ce5 Binary files /dev/null and b/assets/visioncoder.png differ diff --git a/config.example.yaml b/config.example.yaml index 734dd7d522..13042b78d3 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -98,7 +98,7 @@ disable-cooling: false quota-exceeded: switch-project: true # Whether to automatically switch to another project when a quota is exceeded switch-preview-model: true # Whether to automatically switch to a preview model when a quota is exceeded - antigravity-credits: true # Whether to retry Antigravity quota_exhausted 429s once with enabledCreditTypes=["GOOGLE_ONE_AI"] + antigravity-credits: true # Whether to use credits as last-resort fallback when all free-tier auths are exhausted for Claude models # Routing strategy for selecting credentials when multiple match. 
routing: diff --git a/internal/api/handlers/management/config_auth_index.go b/internal/api/handlers/management/config_auth_index.go new file mode 100644 index 0000000000..ed0b3ec42d --- /dev/null +++ b/internal/api/handlers/management/config_auth_index.go @@ -0,0 +1,241 @@ +package management + +import ( + "fmt" + "strings" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/synthesizer" +) + +type geminiKeyWithAuthIndex struct { + config.GeminiKey + AuthIndex string `json:"auth-index,omitempty"` +} + +type claudeKeyWithAuthIndex struct { + config.ClaudeKey + AuthIndex string `json:"auth-index,omitempty"` +} + +type codexKeyWithAuthIndex struct { + config.CodexKey + AuthIndex string `json:"auth-index,omitempty"` +} + +type vertexCompatKeyWithAuthIndex struct { + config.VertexCompatKey + AuthIndex string `json:"auth-index,omitempty"` +} + +type openAICompatibilityAPIKeyWithAuthIndex struct { + config.OpenAICompatibilityAPIKey + AuthIndex string `json:"auth-index,omitempty"` +} + +type openAICompatibilityWithAuthIndex struct { + Name string `json:"name"` + Priority int `json:"priority,omitempty"` + Prefix string `json:"prefix,omitempty"` + BaseURL string `json:"base-url"` + APIKeyEntries []openAICompatibilityAPIKeyWithAuthIndex `json:"api-key-entries,omitempty"` + Models []config.OpenAICompatibilityModel `json:"models,omitempty"` + Headers map[string]string `json:"headers,omitempty"` + AuthIndex string `json:"auth-index,omitempty"` +} + +func (h *Handler) liveAuthIndexByID() map[string]string { + out := map[string]string{} + if h == nil { + return out + } + h.mu.Lock() + manager := h.authManager + h.mu.Unlock() + if manager == nil { + return out + } + // authManager.List() returns clones, so EnsureIndex only affects these copies. 
+ for _, auth := range manager.List() { + if auth == nil { + continue + } + id := strings.TrimSpace(auth.ID) + if id == "" { + continue + } + idx := strings.TrimSpace(auth.Index) + if idx == "" { + idx = auth.EnsureIndex() + } + if idx == "" { + continue + } + out[id] = idx + } + return out +} + +func (h *Handler) geminiKeysWithAuthIndex() []geminiKeyWithAuthIndex { + if h == nil { + return nil + } + liveIndexByID := h.liveAuthIndexByID() + + h.mu.Lock() + defer h.mu.Unlock() + if h.cfg == nil { + return nil + } + + idGen := synthesizer.NewStableIDGenerator() + out := make([]geminiKeyWithAuthIndex, len(h.cfg.GeminiKey)) + for i := range h.cfg.GeminiKey { + entry := h.cfg.GeminiKey[i] + authIndex := "" + if key := strings.TrimSpace(entry.APIKey); key != "" { + id, _ := idGen.Next("gemini:apikey", key, entry.BaseURL) + authIndex = liveIndexByID[id] + } + out[i] = geminiKeyWithAuthIndex{ + GeminiKey: entry, + AuthIndex: authIndex, + } + } + return out +} + +func (h *Handler) claudeKeysWithAuthIndex() []claudeKeyWithAuthIndex { + if h == nil { + return nil + } + liveIndexByID := h.liveAuthIndexByID() + + h.mu.Lock() + defer h.mu.Unlock() + if h.cfg == nil { + return nil + } + + idGen := synthesizer.NewStableIDGenerator() + out := make([]claudeKeyWithAuthIndex, len(h.cfg.ClaudeKey)) + for i := range h.cfg.ClaudeKey { + entry := h.cfg.ClaudeKey[i] + authIndex := "" + if key := strings.TrimSpace(entry.APIKey); key != "" { + id, _ := idGen.Next("claude:apikey", key, entry.BaseURL) + authIndex = liveIndexByID[id] + } + out[i] = claudeKeyWithAuthIndex{ + ClaudeKey: entry, + AuthIndex: authIndex, + } + } + return out +} + +func (h *Handler) codexKeysWithAuthIndex() []codexKeyWithAuthIndex { + if h == nil { + return nil + } + liveIndexByID := h.liveAuthIndexByID() + + h.mu.Lock() + defer h.mu.Unlock() + if h.cfg == nil { + return nil + } + + idGen := synthesizer.NewStableIDGenerator() + out := make([]codexKeyWithAuthIndex, len(h.cfg.CodexKey)) + for i := range h.cfg.CodexKey { 
+ entry := h.cfg.CodexKey[i] + authIndex := "" + if key := strings.TrimSpace(entry.APIKey); key != "" { + id, _ := idGen.Next("codex:apikey", key, entry.BaseURL) + authIndex = liveIndexByID[id] + } + out[i] = codexKeyWithAuthIndex{ + CodexKey: entry, + AuthIndex: authIndex, + } + } + return out +} + +func (h *Handler) vertexCompatKeysWithAuthIndex() []vertexCompatKeyWithAuthIndex { + if h == nil { + return nil + } + liveIndexByID := h.liveAuthIndexByID() + + h.mu.Lock() + defer h.mu.Unlock() + if h.cfg == nil { + return nil + } + + idGen := synthesizer.NewStableIDGenerator() + out := make([]vertexCompatKeyWithAuthIndex, len(h.cfg.VertexCompatAPIKey)) + for i := range h.cfg.VertexCompatAPIKey { + entry := h.cfg.VertexCompatAPIKey[i] + id, _ := idGen.Next("vertex:apikey", entry.APIKey, entry.BaseURL, entry.ProxyURL) + authIndex := liveIndexByID[id] + out[i] = vertexCompatKeyWithAuthIndex{ + VertexCompatKey: entry, + AuthIndex: authIndex, + } + } + return out +} + +func (h *Handler) openAICompatibilityWithAuthIndex() []openAICompatibilityWithAuthIndex { + if h == nil { + return nil + } + liveIndexByID := h.liveAuthIndexByID() + + h.mu.Lock() + defer h.mu.Unlock() + if h.cfg == nil { + return nil + } + + normalized := normalizedOpenAICompatibilityEntries(h.cfg.OpenAICompatibility) + out := make([]openAICompatibilityWithAuthIndex, len(normalized)) + idGen := synthesizer.NewStableIDGenerator() + for i := range normalized { + entry := normalized[i] + providerName := strings.ToLower(strings.TrimSpace(entry.Name)) + if providerName == "" { + providerName = "openai-compatibility" + } + idKind := fmt.Sprintf("openai-compatibility:%s", providerName) + + response := openAICompatibilityWithAuthIndex{ + Name: entry.Name, + Priority: entry.Priority, + Prefix: entry.Prefix, + BaseURL: entry.BaseURL, + Models: entry.Models, + Headers: entry.Headers, + AuthIndex: "", + } + if len(entry.APIKeyEntries) == 0 { + id, _ := idGen.Next(idKind, entry.BaseURL) + response.AuthIndex = 
liveIndexByID[id] + } else { + response.APIKeyEntries = make([]openAICompatibilityAPIKeyWithAuthIndex, len(entry.APIKeyEntries)) + for j := range entry.APIKeyEntries { + apiKeyEntry := entry.APIKeyEntries[j] + id, _ := idGen.Next(idKind, apiKeyEntry.APIKey, entry.BaseURL, apiKeyEntry.ProxyURL) + response.APIKeyEntries[j] = openAICompatibilityAPIKeyWithAuthIndex{ + OpenAICompatibilityAPIKey: apiKeyEntry, + AuthIndex: liveIndexByID[id], + } + } + } + out[i] = response + } + return out +} diff --git a/internal/api/handlers/management/config_lists.go b/internal/api/handlers/management/config_lists.go index fbaad956e0..ee3a4714b8 100644 --- a/internal/api/handlers/management/config_lists.go +++ b/internal/api/handlers/management/config_lists.go @@ -120,7 +120,7 @@ func (h *Handler) DeleteAPIKeys(c *gin.Context) { // gemini-api-key: []GeminiKey func (h *Handler) GetGeminiKeys(c *gin.Context) { - c.JSON(200, gin.H{"gemini-api-key": h.cfg.GeminiKey}) + c.JSON(200, gin.H{"gemini-api-key": h.geminiKeysWithAuthIndex()}) } func (h *Handler) PutGeminiKeys(c *gin.Context) { data, err := c.GetRawData() @@ -139,9 +139,11 @@ func (h *Handler) PutGeminiKeys(c *gin.Context) { } arr = obj.Items } + h.mu.Lock() + defer h.mu.Unlock() h.cfg.GeminiKey = append([]config.GeminiKey(nil), arr...) h.cfg.SanitizeGeminiKeys() - h.persist(c) + h.persistLocked(c) } func (h *Handler) PatchGeminiKey(c *gin.Context) { type geminiKeyPatch struct { @@ -161,6 +163,9 @@ func (h *Handler) PatchGeminiKey(c *gin.Context) { c.JSON(400, gin.H{"error": "invalid body"}) return } + + h.mu.Lock() + defer h.mu.Unlock() targetIndex := -1 if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.GeminiKey) { targetIndex = *body.Index @@ -187,7 +192,7 @@ func (h *Handler) PatchGeminiKey(c *gin.Context) { if trimmed == "" { h.cfg.GeminiKey = append(h.cfg.GeminiKey[:targetIndex], h.cfg.GeminiKey[targetIndex+1:]...) 
h.cfg.SanitizeGeminiKeys() - h.persist(c) + h.persistLocked(c) return } entry.APIKey = trimmed @@ -209,10 +214,12 @@ func (h *Handler) PatchGeminiKey(c *gin.Context) { } h.cfg.GeminiKey[targetIndex] = entry h.cfg.SanitizeGeminiKeys() - h.persist(c) + h.persistLocked(c) } func (h *Handler) DeleteGeminiKey(c *gin.Context) { + h.mu.Lock() + defer h.mu.Unlock() if val := strings.TrimSpace(c.Query("api-key")); val != "" { if baseRaw, okBase := c.GetQuery("base-url"); okBase { base := strings.TrimSpace(baseRaw) @@ -226,7 +233,7 @@ func (h *Handler) DeleteGeminiKey(c *gin.Context) { if len(out) != len(h.cfg.GeminiKey) { h.cfg.GeminiKey = out h.cfg.SanitizeGeminiKeys() - h.persist(c) + h.persistLocked(c) } else { c.JSON(404, gin.H{"error": "item not found"}) } @@ -253,7 +260,7 @@ func (h *Handler) DeleteGeminiKey(c *gin.Context) { } h.cfg.GeminiKey = append(h.cfg.GeminiKey[:matchIndex], h.cfg.GeminiKey[matchIndex+1:]...) h.cfg.SanitizeGeminiKeys() - h.persist(c) + h.persistLocked(c) return } if idxStr := c.Query("index"); idxStr != "" { @@ -261,7 +268,7 @@ func (h *Handler) DeleteGeminiKey(c *gin.Context) { if _, err := fmt.Sscanf(idxStr, "%d", &idx); err == nil && idx >= 0 && idx < len(h.cfg.GeminiKey) { h.cfg.GeminiKey = append(h.cfg.GeminiKey[:idx], h.cfg.GeminiKey[idx+1:]...) 
h.cfg.SanitizeGeminiKeys() - h.persist(c) + h.persistLocked(c) return } } @@ -270,7 +277,7 @@ func (h *Handler) DeleteGeminiKey(c *gin.Context) { // claude-api-key: []ClaudeKey func (h *Handler) GetClaudeKeys(c *gin.Context) { - c.JSON(200, gin.H{"claude-api-key": h.cfg.ClaudeKey}) + c.JSON(200, gin.H{"claude-api-key": h.claudeKeysWithAuthIndex()}) } func (h *Handler) PutClaudeKeys(c *gin.Context) { data, err := c.GetRawData() @@ -292,9 +299,11 @@ func (h *Handler) PutClaudeKeys(c *gin.Context) { for i := range arr { normalizeClaudeKey(&arr[i]) } + h.mu.Lock() + defer h.mu.Unlock() h.cfg.ClaudeKey = arr h.cfg.SanitizeClaudeKeys() - h.persist(c) + h.persistLocked(c) } func (h *Handler) PatchClaudeKey(c *gin.Context) { type claudeKeyPatch struct { @@ -315,6 +324,9 @@ func (h *Handler) PatchClaudeKey(c *gin.Context) { c.JSON(400, gin.H{"error": "invalid body"}) return } + + h.mu.Lock() + defer h.mu.Unlock() targetIndex := -1 if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.ClaudeKey) { targetIndex = *body.Index @@ -358,10 +370,12 @@ func (h *Handler) PatchClaudeKey(c *gin.Context) { normalizeClaudeKey(&entry) h.cfg.ClaudeKey[targetIndex] = entry h.cfg.SanitizeClaudeKeys() - h.persist(c) + h.persistLocked(c) } func (h *Handler) DeleteClaudeKey(c *gin.Context) { + h.mu.Lock() + defer h.mu.Unlock() if val := strings.TrimSpace(c.Query("api-key")); val != "" { if baseRaw, okBase := c.GetQuery("base-url"); okBase { base := strings.TrimSpace(baseRaw) @@ -374,7 +388,7 @@ func (h *Handler) DeleteClaudeKey(c *gin.Context) { } h.cfg.ClaudeKey = out h.cfg.SanitizeClaudeKeys() - h.persist(c) + h.persistLocked(c) return } @@ -396,7 +410,7 @@ func (h *Handler) DeleteClaudeKey(c *gin.Context) { h.cfg.ClaudeKey = append(h.cfg.ClaudeKey[:matchIndex], h.cfg.ClaudeKey[matchIndex+1:]...) 
} h.cfg.SanitizeClaudeKeys() - h.persist(c) + h.persistLocked(c) return } if idxStr := c.Query("index"); idxStr != "" { @@ -405,7 +419,7 @@ func (h *Handler) DeleteClaudeKey(c *gin.Context) { if err == nil && idx >= 0 && idx < len(h.cfg.ClaudeKey) { h.cfg.ClaudeKey = append(h.cfg.ClaudeKey[:idx], h.cfg.ClaudeKey[idx+1:]...) h.cfg.SanitizeClaudeKeys() - h.persist(c) + h.persistLocked(c) return } } @@ -414,7 +428,7 @@ func (h *Handler) DeleteClaudeKey(c *gin.Context) { // openai-compatibility: []OpenAICompatibility func (h *Handler) GetOpenAICompat(c *gin.Context) { - c.JSON(200, gin.H{"openai-compatibility": normalizedOpenAICompatibilityEntries(h.cfg.OpenAICompatibility)}) + c.JSON(200, gin.H{"openai-compatibility": h.openAICompatibilityWithAuthIndex()}) } func (h *Handler) PutOpenAICompat(c *gin.Context) { data, err := c.GetRawData() @@ -440,9 +454,11 @@ func (h *Handler) PutOpenAICompat(c *gin.Context) { filtered = append(filtered, arr[i]) } } + h.mu.Lock() + defer h.mu.Unlock() h.cfg.OpenAICompatibility = filtered h.cfg.SanitizeOpenAICompatibility() - h.persist(c) + h.persistLocked(c) } func (h *Handler) PatchOpenAICompat(c *gin.Context) { type openAICompatPatch struct { @@ -462,6 +478,9 @@ func (h *Handler) PatchOpenAICompat(c *gin.Context) { c.JSON(400, gin.H{"error": "invalid body"}) return } + + h.mu.Lock() + defer h.mu.Unlock() targetIndex := -1 if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.OpenAICompatibility) { targetIndex = *body.Index @@ -492,7 +511,7 @@ func (h *Handler) PatchOpenAICompat(c *gin.Context) { if trimmed == "" { h.cfg.OpenAICompatibility = append(h.cfg.OpenAICompatibility[:targetIndex], h.cfg.OpenAICompatibility[targetIndex+1:]...) 
h.cfg.SanitizeOpenAICompatibility() - h.persist(c) + h.persistLocked(c) return } entry.BaseURL = trimmed @@ -509,10 +528,12 @@ func (h *Handler) PatchOpenAICompat(c *gin.Context) { normalizeOpenAICompatibilityEntry(&entry) h.cfg.OpenAICompatibility[targetIndex] = entry h.cfg.SanitizeOpenAICompatibility() - h.persist(c) + h.persistLocked(c) } func (h *Handler) DeleteOpenAICompat(c *gin.Context) { + h.mu.Lock() + defer h.mu.Unlock() if name := c.Query("name"); name != "" { out := make([]config.OpenAICompatibility, 0, len(h.cfg.OpenAICompatibility)) for _, v := range h.cfg.OpenAICompatibility { @@ -522,7 +543,7 @@ func (h *Handler) DeleteOpenAICompat(c *gin.Context) { } h.cfg.OpenAICompatibility = out h.cfg.SanitizeOpenAICompatibility() - h.persist(c) + h.persistLocked(c) return } if idxStr := c.Query("index"); idxStr != "" { @@ -531,7 +552,7 @@ func (h *Handler) DeleteOpenAICompat(c *gin.Context) { if err == nil && idx >= 0 && idx < len(h.cfg.OpenAICompatibility) { h.cfg.OpenAICompatibility = append(h.cfg.OpenAICompatibility[:idx], h.cfg.OpenAICompatibility[idx+1:]...) h.cfg.SanitizeOpenAICompatibility() - h.persist(c) + h.persistLocked(c) return } } @@ -540,7 +561,7 @@ func (h *Handler) DeleteOpenAICompat(c *gin.Context) { // vertex-api-key: []VertexCompatKey func (h *Handler) GetVertexCompatKeys(c *gin.Context) { - c.JSON(200, gin.H{"vertex-api-key": h.cfg.VertexCompatAPIKey}) + c.JSON(200, gin.H{"vertex-api-key": h.vertexCompatKeysWithAuthIndex()}) } func (h *Handler) PutVertexCompatKeys(c *gin.Context) { data, err := c.GetRawData() @@ -566,9 +587,11 @@ func (h *Handler) PutVertexCompatKeys(c *gin.Context) { return } } + h.mu.Lock() + defer h.mu.Unlock() h.cfg.VertexCompatAPIKey = append([]config.VertexCompatKey(nil), arr...) 
h.cfg.SanitizeVertexCompatKeys() - h.persist(c) + h.persistLocked(c) } func (h *Handler) PatchVertexCompatKey(c *gin.Context) { type vertexCompatPatch struct { @@ -589,6 +612,9 @@ func (h *Handler) PatchVertexCompatKey(c *gin.Context) { c.JSON(400, gin.H{"error": "invalid body"}) return } + + h.mu.Lock() + defer h.mu.Unlock() targetIndex := -1 if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.VertexCompatAPIKey) { targetIndex = *body.Index @@ -615,7 +641,7 @@ func (h *Handler) PatchVertexCompatKey(c *gin.Context) { if trimmed == "" { h.cfg.VertexCompatAPIKey = append(h.cfg.VertexCompatAPIKey[:targetIndex], h.cfg.VertexCompatAPIKey[targetIndex+1:]...) h.cfg.SanitizeVertexCompatKeys() - h.persist(c) + h.persistLocked(c) return } entry.APIKey = trimmed @@ -628,7 +654,7 @@ func (h *Handler) PatchVertexCompatKey(c *gin.Context) { if trimmed == "" { h.cfg.VertexCompatAPIKey = append(h.cfg.VertexCompatAPIKey[:targetIndex], h.cfg.VertexCompatAPIKey[targetIndex+1:]...) h.cfg.SanitizeVertexCompatKeys() - h.persist(c) + h.persistLocked(c) return } entry.BaseURL = trimmed @@ -648,10 +674,12 @@ func (h *Handler) PatchVertexCompatKey(c *gin.Context) { normalizeVertexCompatKey(&entry) h.cfg.VertexCompatAPIKey[targetIndex] = entry h.cfg.SanitizeVertexCompatKeys() - h.persist(c) + h.persistLocked(c) } func (h *Handler) DeleteVertexCompatKey(c *gin.Context) { + h.mu.Lock() + defer h.mu.Unlock() if val := strings.TrimSpace(c.Query("api-key")); val != "" { if baseRaw, okBase := c.GetQuery("base-url"); okBase { base := strings.TrimSpace(baseRaw) @@ -664,7 +692,7 @@ func (h *Handler) DeleteVertexCompatKey(c *gin.Context) { } h.cfg.VertexCompatAPIKey = out h.cfg.SanitizeVertexCompatKeys() - h.persist(c) + h.persistLocked(c) return } @@ -686,7 +714,7 @@ func (h *Handler) DeleteVertexCompatKey(c *gin.Context) { h.cfg.VertexCompatAPIKey = append(h.cfg.VertexCompatAPIKey[:matchIndex], h.cfg.VertexCompatAPIKey[matchIndex+1:]...) 
} h.cfg.SanitizeVertexCompatKeys() - h.persist(c) + h.persistLocked(c) return } if idxStr := c.Query("index"); idxStr != "" { @@ -695,7 +723,7 @@ func (h *Handler) DeleteVertexCompatKey(c *gin.Context) { if errScan == nil && idx >= 0 && idx < len(h.cfg.VertexCompatAPIKey) { h.cfg.VertexCompatAPIKey = append(h.cfg.VertexCompatAPIKey[:idx], h.cfg.VertexCompatAPIKey[idx+1:]...) h.cfg.SanitizeVertexCompatKeys() - h.persist(c) + h.persistLocked(c) return } } @@ -886,7 +914,7 @@ func (h *Handler) DeleteOAuthModelAlias(c *gin.Context) { // codex-api-key: []CodexKey func (h *Handler) GetCodexKeys(c *gin.Context) { - c.JSON(200, gin.H{"codex-api-key": h.cfg.CodexKey}) + c.JSON(200, gin.H{"codex-api-key": h.codexKeysWithAuthIndex()}) } func (h *Handler) PutCodexKeys(c *gin.Context) { data, err := c.GetRawData() @@ -915,9 +943,11 @@ func (h *Handler) PutCodexKeys(c *gin.Context) { } filtered = append(filtered, entry) } + h.mu.Lock() + defer h.mu.Unlock() h.cfg.CodexKey = filtered h.cfg.SanitizeCodexKeys() - h.persist(c) + h.persistLocked(c) } func (h *Handler) PatchCodexKey(c *gin.Context) { type codexKeyPatch struct { @@ -938,6 +968,9 @@ func (h *Handler) PatchCodexKey(c *gin.Context) { c.JSON(400, gin.H{"error": "invalid body"}) return } + + h.mu.Lock() + defer h.mu.Unlock() targetIndex := -1 if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.CodexKey) { targetIndex = *body.Index @@ -968,7 +1001,7 @@ func (h *Handler) PatchCodexKey(c *gin.Context) { if trimmed == "" { h.cfg.CodexKey = append(h.cfg.CodexKey[:targetIndex], h.cfg.CodexKey[targetIndex+1:]...) 
h.cfg.SanitizeCodexKeys() - h.persist(c) + h.persistLocked(c) return } entry.BaseURL = trimmed @@ -988,10 +1021,12 @@ func (h *Handler) PatchCodexKey(c *gin.Context) { normalizeCodexKey(&entry) h.cfg.CodexKey[targetIndex] = entry h.cfg.SanitizeCodexKeys() - h.persist(c) + h.persistLocked(c) } func (h *Handler) DeleteCodexKey(c *gin.Context) { + h.mu.Lock() + defer h.mu.Unlock() if val := strings.TrimSpace(c.Query("api-key")); val != "" { if baseRaw, okBase := c.GetQuery("base-url"); okBase { base := strings.TrimSpace(baseRaw) @@ -1004,7 +1039,7 @@ func (h *Handler) DeleteCodexKey(c *gin.Context) { } h.cfg.CodexKey = out h.cfg.SanitizeCodexKeys() - h.persist(c) + h.persistLocked(c) return } @@ -1026,7 +1061,7 @@ func (h *Handler) DeleteCodexKey(c *gin.Context) { h.cfg.CodexKey = append(h.cfg.CodexKey[:matchIndex], h.cfg.CodexKey[matchIndex+1:]...) } h.cfg.SanitizeCodexKeys() - h.persist(c) + h.persistLocked(c) return } if idxStr := c.Query("index"); idxStr != "" { @@ -1035,7 +1070,7 @@ func (h *Handler) DeleteCodexKey(c *gin.Context) { if err == nil && idx >= 0 && idx < len(h.cfg.CodexKey) { h.cfg.CodexKey = append(h.cfg.CodexKey[:idx], h.cfg.CodexKey[idx+1:]...) h.cfg.SanitizeCodexKeys() - h.persist(c) + h.persistLocked(c) return } } diff --git a/internal/api/handlers/management/handler.go b/internal/api/handlers/management/handler.go index 45786b9d3e..30cc973817 100644 --- a/internal/api/handlers/management/handler.go +++ b/internal/api/handlers/management/handler.go @@ -105,10 +105,24 @@ func NewHandlerWithoutConfigFilePath(cfg *config.Config, manager *coreauth.Manag } // SetConfig updates the in-memory config reference when the server hot-reloads. -func (h *Handler) SetConfig(cfg *config.Config) { h.cfg = cfg } +func (h *Handler) SetConfig(cfg *config.Config) { + if h == nil { + return + } + h.mu.Lock() + h.cfg = cfg + h.mu.Unlock() +} // SetAuthManager updates the auth manager reference used by management endpoints. 
-func (h *Handler) SetAuthManager(manager *coreauth.Manager) { h.authManager = manager } +func (h *Handler) SetAuthManager(manager *coreauth.Manager) { + if h == nil { + return + } + h.mu.Lock() + h.authManager = manager + h.mu.Unlock() +} // SetUsageStatistics allows replacing the usage statistics reference. func (h *Handler) SetUsageStatistics(stats *usage.RequestStatistics) { h.usageStats = stats } @@ -276,6 +290,12 @@ func (h *Handler) Middleware() gin.HandlerFunc { func (h *Handler) persist(c *gin.Context) bool { h.mu.Lock() defer h.mu.Unlock() + return h.persistLocked(c) +} + +// persistLocked saves the current in-memory config to disk. +// It expects the caller to hold h.mu. +func (h *Handler) persistLocked(c *gin.Context) bool { // Preserve comments when writing if err := config.SaveConfigPreserveComments(h.configFilePath, h.cfg); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to save config: %v", err)}) diff --git a/internal/api/server.go b/internal/api/server.go index 075455ba83..32ae3164fd 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -319,9 +319,16 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk // setupRoutes configures the API routes for the server. // It defines the endpoints and associates them with their respective handlers. 
func (s *Server) setupRoutes() { - s.engine.GET("/healthz", func(c *gin.Context) { + healthzHandler := func(c *gin.Context) { + if c.Request.Method == http.MethodHead { + c.Status(http.StatusOK) + return + } + c.JSON(http.StatusOK, gin.H{"status": "ok"}) - }) + } + s.engine.GET("/healthz", healthzHandler) + s.engine.HEAD("/healthz", healthzHandler) s.engine.GET("/management.html", s.serveManagementControlPanel) openaiHandlers := openai.NewOpenAIAPIHandler(s.handlers) @@ -337,6 +344,8 @@ func (s *Server) setupRoutes() { v1.GET("/models", s.unifiedModelsHandler(openaiHandlers, claudeCodeHandlers)) v1.POST("/chat/completions", openaiHandlers.ChatCompletions) v1.POST("/completions", openaiHandlers.Completions) + v1.POST("/images/generations", openaiHandlers.ImagesGenerations) + v1.POST("/images/edits", openaiHandlers.ImagesEdits) v1.POST("/messages", claudeCodeHandlers.ClaudeMessages) v1.POST("/messages/count_tokens", claudeCodeHandlers.ClaudeCountTokens) v1.GET("/responses", openaiResponsesHandlers.ResponsesWebsocket) @@ -344,6 +353,15 @@ func (s *Server) setupRoutes() { v1.POST("/responses/compact", openaiResponsesHandlers.Compact) } + // Codex CLI direct route aliases (chatgpt_base_url compatible) + codexDirect := s.engine.Group("/backend-api/codex") + codexDirect.Use(AuthMiddleware(s.accessManager)) + { + codexDirect.GET("/responses", openaiResponsesHandlers.ResponsesWebsocket) + codexDirect.POST("/responses", openaiResponsesHandlers.Responses) + codexDirect.POST("/responses/compact", openaiResponsesHandlers.Compact) + } + // Gemini compatible API routes v1beta := s.engine.Group("/v1beta") v1beta.Use(AuthMiddleware(s.accessManager)) diff --git a/internal/api/server_test.go b/internal/api/server_test.go index dbc2cd5a83..db1ef27d17 100644 --- a/internal/api/server_test.go +++ b/internal/api/server_test.go @@ -50,23 +50,38 @@ func newTestServer(t *testing.T) *Server { func TestHealthz(t *testing.T) { server := newTestServer(t) - req := 
httptest.NewRequest(http.MethodGet, "/healthz", nil) - rr := httptest.NewRecorder() - server.engine.ServeHTTP(rr, req) + t.Run("GET", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/healthz", nil) + rr := httptest.NewRecorder() + server.engine.ServeHTTP(rr, req) - if rr.Code != http.StatusOK { - t.Fatalf("unexpected status code: got %d want %d; body=%s", rr.Code, http.StatusOK, rr.Body.String()) - } + if rr.Code != http.StatusOK { + t.Fatalf("unexpected status code: got %d want %d; body=%s", rr.Code, http.StatusOK, rr.Body.String()) + } - var resp struct { - Status string `json:"status"` - } - if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil { - t.Fatalf("failed to parse response JSON: %v; body=%s", err, rr.Body.String()) - } - if resp.Status != "ok" { - t.Fatalf("unexpected response status: got %q want %q", resp.Status, "ok") - } + var resp struct { + Status string `json:"status"` + } + if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to parse response JSON: %v; body=%s", err, rr.Body.String()) + } + if resp.Status != "ok" { + t.Fatalf("unexpected response status: got %q want %q", resp.Status, "ok") + } + }) + + t.Run("HEAD", func(t *testing.T) { + req := httptest.NewRequest(http.MethodHead, "/healthz", nil) + rr := httptest.NewRecorder() + server.engine.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("unexpected status code: got %d want %d; body=%s", rr.Code, http.StatusOK, rr.Body.String()) + } + if rr.Body.Len() != 0 { + t.Fatalf("expected empty body for HEAD request, got %q", rr.Body.String()) + } + }) } func TestAmpProviderModelRoutes(t *testing.T) { diff --git a/internal/config/config.go b/internal/config/config.go index 760d43ec4a..1ebbb460c0 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -206,8 +206,9 @@ type QuotaExceeded struct { // SwitchPreviewModel indicates whether to automatically switch to a preview model when a quota is exceeded. 
SwitchPreviewModel bool `yaml:"switch-preview-model" json:"switch-preview-model"` - // AntigravityCredits indicates whether to retry Antigravity quota_exhausted 429s once - // on the same credential with enabledCreditTypes=["GOOGLE_ONE_AI"]. + // AntigravityCredits enables credits-based last-resort fallback for Claude models. + // When all free-tier auths are exhausted (429/503), the conductor retries with + // an auth that has available Google One AI credits. AntigravityCredits bool `yaml:"antigravity-credits" json:"antigravity-credits"` } diff --git a/internal/logging/gin_logger.go b/internal/logging/gin_logger.go index b94d7afe6d..d92ae985e5 100644 --- a/internal/logging/gin_logger.go +++ b/internal/logging/gin_logger.go @@ -20,6 +20,7 @@ import ( var aiAPIPrefixes = []string{ "/v1/chat/completions", "/v1/completions", + "/v1/images", "/v1/messages", "/v1/responses", "/v1beta/models/", diff --git a/internal/logging/gin_logger_test.go b/internal/logging/gin_logger_test.go index 7de1833865..9bd3ddfba6 100644 --- a/internal/logging/gin_logger_test.go +++ b/internal/logging/gin_logger_test.go @@ -58,3 +58,12 @@ func TestGinLogrusRecoveryHandlesRegularPanic(t *testing.T) { t.Fatalf("expected 500, got %d", recorder.Code) } } + +func TestIsAIAPIPathIncludesImages(t *testing.T) { + if !isAIAPIPath("/v1/images/generations") { + t.Fatalf("expected /v1/images/generations to be treated as AI API path") + } + if !isAIAPIPath("/v1/images/edits") { + t.Fatalf("expected /v1/images/edits to be treated as AI API path") + } +} diff --git a/internal/registry/model_definitions.go b/internal/registry/model_definitions.go index ab7258f845..7ac6b469ac 100644 --- a/internal/registry/model_definitions.go +++ b/internal/registry/model_definitions.go @@ -6,6 +6,8 @@ import ( "strings" ) +const codexBuiltinImageModelID = "gpt-image-2" + // staticModelsJSON mirrors the top-level structure of models.json. 
type staticModelsJSON struct { Claude []*ModelInfo `json:"claude"` @@ -48,22 +50,22 @@ func GetAIStudioModels() []*ModelInfo { // GetCodexFreeModels returns model definitions for the Codex free plan tier. func GetCodexFreeModels() []*ModelInfo { - return cloneModelInfos(getModels().CodexFree) + return WithCodexBuiltins(cloneModelInfos(getModels().CodexFree)) } // GetCodexTeamModels returns model definitions for the Codex team plan tier. func GetCodexTeamModels() []*ModelInfo { - return cloneModelInfos(getModels().CodexTeam) + return WithCodexBuiltins(cloneModelInfos(getModels().CodexTeam)) } // GetCodexPlusModels returns model definitions for the Codex plus plan tier. func GetCodexPlusModels() []*ModelInfo { - return cloneModelInfos(getModels().CodexPlus) + return WithCodexBuiltins(cloneModelInfos(getModels().CodexPlus)) } // GetCodexProModels returns model definitions for the Codex pro plan tier. func GetCodexProModels() []*ModelInfo { - return cloneModelInfos(getModels().CodexPro) + return WithCodexBuiltins(cloneModelInfos(getModels().CodexPro)) } // GetKimiModels returns the standard Kimi (Moonshot AI) model definitions. @@ -76,6 +78,71 @@ func GetAntigravityModels() []*ModelInfo { return cloneModelInfos(getModels().Antigravity) } +// WithCodexBuiltins injects hard-coded Codex-only model definitions that should +// not depend on remote models.json updates. Built-ins replace any matching IDs +// already present in the provided slice. 
+func WithCodexBuiltins(models []*ModelInfo) []*ModelInfo { + return upsertModelInfos(models, codexBuiltinImageModelInfo()) +} + +func codexBuiltinImageModelInfo() *ModelInfo { + return &ModelInfo{ + ID: codexBuiltinImageModelID, + Object: "model", + Created: 1704067200, // 2024-01-01 + OwnedBy: "openai", + Type: "openai", + DisplayName: "GPT Image 2", + Version: codexBuiltinImageModelID, + } +} + +func upsertModelInfos(models []*ModelInfo, extras ...*ModelInfo) []*ModelInfo { + if len(extras) == 0 { + return models + } + + extraIDs := make(map[string]struct{}, len(extras)) + extraList := make([]*ModelInfo, 0, len(extras)) + for _, extra := range extras { + if extra == nil { + continue + } + id := strings.TrimSpace(extra.ID) + if id == "" { + continue + } + key := strings.ToLower(id) + if _, exists := extraIDs[key]; exists { + continue + } + extraIDs[key] = struct{}{} + extraList = append(extraList, cloneModelInfo(extra)) + } + + if len(extraList) == 0 { + return models + } + + filtered := make([]*ModelInfo, 0, len(models)+len(extraList)) + for _, model := range models { + if model == nil { + continue + } + id := strings.TrimSpace(model.ID) + if id == "" { + continue + } + if _, exists := extraIDs[strings.ToLower(id)]; exists { + continue + } + filtered = append(filtered, model) + } + + filtered = append(filtered, extraList...) + return filtered +} + // cloneModelInfos returns a shallow copy of the slice with each element deep-cloned. 
func cloneModelInfos(models []*ModelInfo) []*ModelInfo { if len(models) == 0 { diff --git a/internal/registry/model_definitions_test.go b/internal/registry/model_definitions_test.go new file mode 100644 index 0000000000..7a0630c28d --- /dev/null +++ b/internal/registry/model_definitions_test.go @@ -0,0 +1,88 @@ +package registry + +import "testing" + +func TestCodexStaticModelsIncludeGPT55(t *testing.T) { + tierModels := map[string][]*ModelInfo{ + "free": GetCodexFreeModels(), + "team": GetCodexTeamModels(), + "plus": GetCodexPlusModels(), + "pro": GetCodexProModels(), + } + + for tier, models := range tierModels { + t.Run(tier, func(t *testing.T) { + model := findModelInfo(models, "gpt-5.5") + if model == nil { + t.Fatalf("expected codex %s tier to include gpt-5.5", tier) + } + assertGPT55ModelInfo(t, tier, model) + }) + } + + model := LookupStaticModelInfo("gpt-5.5") + if model == nil { + t.Fatal("expected LookupStaticModelInfo to find gpt-5.5") + } + assertGPT55ModelInfo(t, "lookup", model) +} + +func findModelInfo(models []*ModelInfo, id string) *ModelInfo { + for _, model := range models { + if model != nil && model.ID == id { + return model + } + } + return nil +} + +func assertGPT55ModelInfo(t *testing.T, source string, model *ModelInfo) { + t.Helper() + + if model.ID != "gpt-5.5" { + t.Fatalf("%s id mismatch: got %q", source, model.ID) + } + if model.Object != "model" { + t.Fatalf("%s object mismatch: got %q", source, model.Object) + } + if model.Created != 1776902400 { + t.Fatalf("%s created timestamp mismatch: got %d", source, model.Created) + } + if model.OwnedBy != "openai" { + t.Fatalf("%s owned_by mismatch: got %q", source, model.OwnedBy) + } + if model.Type != "openai" { + t.Fatalf("%s type mismatch: got %q", source, model.Type) + } + if model.DisplayName != "GPT 5.5" { + t.Fatalf("%s display name mismatch: got %q", source, model.DisplayName) + } + if model.Version != "gpt-5.5" { + t.Fatalf("%s version mismatch: got %q", source, model.Version) + } + 
if model.Description != "Frontier model for complex coding, research, and real-world work." { + t.Fatalf("%s description mismatch: got %q", source, model.Description) + } + if model.ContextLength != 272000 { + t.Fatalf("%s context length mismatch: got %d", source, model.ContextLength) + } + if model.MaxCompletionTokens != 128000 { + t.Fatalf("%s max completion tokens mismatch: got %d", source, model.MaxCompletionTokens) + } + if len(model.SupportedParameters) != 1 || model.SupportedParameters[0] != "tools" { + t.Fatalf("%s supported parameters mismatch: got %v", source, model.SupportedParameters) + } + if model.Thinking == nil { + t.Fatalf("%s missing thinking support", source) + } + + want := []string{"low", "medium", "high", "xhigh"} + if len(model.Thinking.Levels) != len(want) { + t.Fatalf("%s thinking level count mismatch: got %d, want %d", source, len(model.Thinking.Levels), len(want)) + } + for i, level := range want { + if model.Thinking.Levels[i] != level { + t.Fatalf("%s thinking level %d mismatch: got %q, want %q", source, i, model.Thinking.Levels[i], level) + } + } +} diff --git a/internal/registry/model_registry_safety_test.go b/internal/registry/model_registry_safety_test.go index 5f4f65d298..be5bf7908c 100644 --- a/internal/registry/model_registry_safety_test.go +++ b/internal/registry/model_registry_safety_test.go @@ -136,13 +136,13 @@ func TestGetAvailableModelsReturnsClonedSupportedParameters(t *testing.T) { } func TestLookupModelInfoReturnsCloneForStaticDefinitions(t *testing.T) { - first := LookupModelInfo("glm-4.6") + first := LookupModelInfo("claude-sonnet-4-6") if first == nil || first.Thinking == nil || len(first.Thinking.Levels) == 0 { t.Fatalf("expected static model with thinking levels, got %+v", first) } first.Thinking.Levels[0] = "mutated" - second := LookupModelInfo("glm-4.6") + second := LookupModelInfo("claude-sonnet-4-6") if second == nil || second.Thinking == nil || len(second.Thinking.Levels) == 0 || second.Thinking.Levels[0] == 
"mutated" { t.Fatalf("expected static lookup clone, got %+v", second) } diff --git a/internal/registry/models/models.json b/internal/registry/models/models.json index 65d8325169..a1abb5a381 100644 --- a/internal/registry/models/models.json +++ b/internal/registry/models/models.json @@ -1387,6 +1387,29 @@ "xhigh" ] } + }, + { + "id": "gpt-5.5", + "object": "model", + "created": 1776902400, + "owned_by": "openai", + "type": "openai", + "display_name": "GPT 5.5", + "version": "gpt-5.5", + "description": "Frontier model for complex coding, research, and real-world work.", + "context_length": 272000, + "max_completion_tokens": 128000, + "supported_parameters": [ + "tools" + ], + "thinking": { + "levels": [ + "low", + "medium", + "high", + "xhigh" + ] + } } ], "codex-plus": [ @@ -1505,6 +1528,29 @@ "xhigh" ] } + }, + { + "id": "gpt-5.5", + "object": "model", + "created": 1776902400, + "owned_by": "openai", + "type": "openai", + "display_name": "GPT 5.5", + "version": "gpt-5.5", + "description": "Frontier model for complex coding, research, and real-world work.", + "context_length": 272000, + "max_completion_tokens": 128000, + "supported_parameters": [ + "tools" + ], + "thinking": { + "levels": [ + "low", + "medium", + "high", + "xhigh" + ] + } } ], "codex-pro": [ @@ -1623,6 +1669,29 @@ "xhigh" ] } + }, + { + "id": "gpt-5.5", + "object": "model", + "created": 1776902400, + "owned_by": "openai", + "type": "openai", + "display_name": "GPT 5.5", + "version": "gpt-5.5", + "description": "Frontier model for complex coding, research, and real-world work.", + "context_length": 272000, + "max_completion_tokens": 128000, + "supported_parameters": [ + "tools" + ], + "thinking": { + "levels": [ + "low", + "medium", + "high", + "xhigh" + ] + } } ], "kimi": [ @@ -1670,6 +1739,23 @@ "zero_allowed": true, "dynamic_allowed": true } + }, + { + "id": "kimi-k2.6", + "object": "model", + "created": 1776729600, + "owned_by": "moonshot", + "type": "kimi", + "display_name": "Kimi K2.6", + 
"description": "Kimi K2.6 - Latest Moonshot AI coding model with improved capabilities", + "context_length": 262144, + "max_completion_tokens": 65536, + "thinking": { + "min": 1024, + "max": 32000, + "zero_allowed": true, + "dynamic_allowed": true + } } ], "antigravity": [ diff --git a/internal/runtime/executor/antigravity_executor.go b/internal/runtime/executor/antigravity_executor.go index 163b2d9279..6983bface5 100644 --- a/internal/runtime/executor/antigravity_executor.go +++ b/internal/runtime/executor/antigravity_executor.go @@ -52,8 +52,8 @@ const ( defaultAntigravityAgent = "antigravity/1.21.9 darwin/arm64" // fallback only; overridden at runtime by misc.AntigravityUserAgent() antigravityAuthType = "antigravity" refreshSkew = 3000 * time.Second - antigravityCreditsRetryTTL = 5 * time.Hour - antigravityCreditsAutoDisableDuration = 5 * time.Hour + antigravityCreditsHintRefreshInterval = 10 * time.Minute + antigravityCreditsHintRefreshTimeout = 5 * time.Second antigravityShortQuotaCooldownThreshold = 5 * time.Minute antigravityInstantRetryThreshold = 3 * time.Second // systemInstruction = "You are Antigravity, a powerful agentic AI coding assistant designed by the Google Deepmind team working on Advanced Agentic Coding.You are pair programming with a USER to solve their coding task. 
The task may require creating a new codebase, modifying or debugging an existing codebase, or simply answering a question.**Absolute paths only****Proactiveness**" @@ -62,8 +62,6 @@ const ( type antigravity429Category string type antigravityCreditsFailureState struct { - Count int - DisabledUntil time.Time PermanentlyDisabled bool ExplicitBalanceExhausted bool } @@ -91,28 +89,85 @@ var ( randSource = rand.New(rand.NewSource(time.Now().UnixNano())) randSourceMutex sync.Mutex antigravityCreditsFailureByAuth sync.Map - antigravityPreferCreditsByModel sync.Map antigravityShortCooldownByAuth sync.Map + antigravityCreditsBalanceByAuth sync.Map // auth.ID → antigravityCreditsBalance + antigravityCreditsHintRefreshByID sync.Map // auth.ID → *antigravityCreditsHintRefreshState antigravityQuotaExhaustedKeywords = []string{ "quota_exhausted", "quota exhausted", } - antigravityCreditsExhaustedKeywords = []string{ - "google_one_ai", - "insufficient credit", - "insufficient credits", - "not enough credit", - "not enough credits", - "credit exhausted", - "credits exhausted", - "credit balance", - "minimumcreditamountforusage", - "minimum credit amount for usage", - "minimum credit", - "resource has been exhausted", - } ) +type antigravityCreditsBalance struct { + CreditAmount float64 + MinCreditAmount float64 + PaidTierID string + Known bool +} + +type antigravityCreditsHintRefreshState struct { + mu sync.Mutex + lastAttempt time.Time +} + +func antigravityAuthHasCredits(auth *cliproxyauth.Auth) bool { + if auth == nil || strings.TrimSpace(auth.ID) == "" { + return false + } + if hint, ok := cliproxyauth.GetAntigravityCreditsHint(auth.ID); ok && hint.Known { + return hint.Available + } + val, ok := antigravityCreditsBalanceByAuth.Load(strings.TrimSpace(auth.ID)) + if !ok { + return true // optimistic: assume credits available when balance unknown + } + bal, valid := val.(antigravityCreditsBalance) + if !valid { + antigravityCreditsBalanceByAuth.Delete(strings.TrimSpace(auth.ID)) 
+ return false + } + if !bal.Known { + return false + } + available := bal.CreditAmount >= bal.MinCreditAmount + cliproxyauth.SetAntigravityCreditsHint(strings.TrimSpace(auth.ID), cliproxyauth.AntigravityCreditsHint{ + Known: true, + Available: available, + CreditAmount: bal.CreditAmount, + MinCreditAmount: bal.MinCreditAmount, + PaidTierID: bal.PaidTierID, + UpdatedAt: time.Now(), + }) + return available +} + +// parseMetaFloat extracts a float64 from auth.Metadata (handles string and numeric types). +func parseMetaFloat(metadata map[string]any, key string) (float64, bool) { + v, ok := metadata[key] + if !ok { + return 0, false + } + switch typed := v.(type) { + case float64: + return typed, true + case int: + return float64(typed), true + case int64: + return float64(typed), true + case uint64: + return float64(typed), true + case json.Number: + if f, err := typed.Float64(); err == nil { + return f, true + } + case string: + if f, err := strconv.ParseFloat(strings.TrimSpace(typed), 64); err == nil { + return f, true + } + } + return 0, false +} + // AntigravityExecutor proxies requests to the antigravity upstream. type AntigravityExecutor struct { cfg *config.Config @@ -189,7 +244,7 @@ func validateAntigravityRequestSignatures(from sdktranslator.Format, rawJSON []b if from.String() != "claude" { return rawJSON, nil } - // Always strip thinking blocks with empty signatures (proxy-generated). + // Always strip thinking blocks with invalid signatures (empty or non-Claude-format). 
rawJSON = antigravityclaude.StripEmptySignatureThinkingBlocks(rawJSON) if cache.SignatureCacheEnabled() { return rawJSON, nil @@ -298,49 +353,46 @@ func decideAntigravity429(body []byte) antigravity429Decision { decision.retryAfter = retryAfter } - lowerBody := strings.ToLower(string(body)) - for _, keyword := range antigravityQuotaExhaustedKeywords { - if strings.Contains(lowerBody, keyword) { - decision.kind = antigravity429DecisionFullQuotaExhausted - decision.reason = "quota_exhausted" - return decision - } - } - status := strings.TrimSpace(gjson.GetBytes(body, "error.status").String()) if !strings.EqualFold(status, "RESOURCE_EXHAUSTED") { return decision } details := gjson.GetBytes(body, "error.details") - if !details.Exists() || !details.IsArray() { - decision.kind = antigravity429DecisionSoftRetry - return decision - } - - for _, detail := range details.Array() { - if detail.Get("@type").String() != "type.googleapis.com/google.rpc.ErrorInfo" { - continue - } - reason := strings.TrimSpace(detail.Get("reason").String()) - decision.reason = reason - switch { - case strings.EqualFold(reason, "QUOTA_EXHAUSTED"): - decision.kind = antigravity429DecisionFullQuotaExhausted - return decision - case strings.EqualFold(reason, "RATE_LIMIT_EXCEEDED"): - if decision.retryAfter == nil { - decision.kind = antigravity429DecisionSoftRetry - return decision + if details.Exists() && details.IsArray() { + for _, detail := range details.Array() { + if detail.Get("@type").String() != "type.googleapis.com/google.rpc.ErrorInfo" { + continue } + reason := strings.TrimSpace(detail.Get("reason").String()) + decision.reason = reason switch { - case *decision.retryAfter < antigravityInstantRetryThreshold: - decision.kind = antigravity429DecisionInstantRetrySameAuth - case *decision.retryAfter < antigravityShortQuotaCooldownThreshold: - decision.kind = antigravity429DecisionShortCooldownSwitchAuth - default: + case strings.EqualFold(reason, "QUOTA_EXHAUSTED"): decision.kind = 
antigravity429DecisionFullQuotaExhausted + return decision + case strings.EqualFold(reason, "RATE_LIMIT_EXCEEDED"): + if decision.retryAfter == nil { + decision.kind = antigravity429DecisionSoftRetry + return decision + } + switch { + case *decision.retryAfter < antigravityInstantRetryThreshold: + decision.kind = antigravity429DecisionInstantRetrySameAuth + case *decision.retryAfter < antigravityShortQuotaCooldownThreshold: + decision.kind = antigravity429DecisionShortCooldownSwitchAuth + default: + decision.kind = antigravity429DecisionFullQuotaExhausted + } + return decision } + } + } + + lowerBody := strings.ToLower(string(body)) + for _, keyword := range antigravityQuotaExhaustedKeywords { + if strings.Contains(lowerBody, keyword) { + decision.kind = antigravity429DecisionFullQuotaExhausted + decision.reason = "quota_exhausted" return decision } } @@ -349,81 +401,10 @@ func decideAntigravity429(body []byte) antigravity429Decision { return decision } -func antigravityHasQuotaResetDelayOrModelInfo(body []byte) bool { - if len(body) == 0 { - return false - } - details := gjson.GetBytes(body, "error.details") - if !details.Exists() || !details.IsArray() { - return false - } - for _, detail := range details.Array() { - if detail.Get("@type").String() != "type.googleapis.com/google.rpc.ErrorInfo" { - continue - } - if strings.TrimSpace(detail.Get("metadata.quotaResetDelay").String()) != "" { - return true - } - if strings.TrimSpace(detail.Get("metadata.model").String()) != "" { - return true - } - } - return false -} - func antigravityCreditsRetryEnabled(cfg *config.Config) bool { return cfg != nil && cfg.QuotaExceeded.AntigravityCredits } -func antigravityCreditsFailureStateForAuth(auth *cliproxyauth.Auth) (string, antigravityCreditsFailureState, bool) { - if auth == nil || strings.TrimSpace(auth.ID) == "" { - return "", antigravityCreditsFailureState{}, false - } - authID := strings.TrimSpace(auth.ID) - value, ok := antigravityCreditsFailureByAuth.Load(authID) - if 
!ok { - return authID, antigravityCreditsFailureState{}, true - } - state, ok := value.(antigravityCreditsFailureState) - if !ok { - antigravityCreditsFailureByAuth.Delete(authID) - return authID, antigravityCreditsFailureState{}, true - } - return authID, state, true -} - -func antigravityCreditsDisabled(auth *cliproxyauth.Auth, now time.Time) bool { - authID, state, ok := antigravityCreditsFailureStateForAuth(auth) - if !ok { - return false - } - if state.PermanentlyDisabled { - return true - } - if state.DisabledUntil.IsZero() { - return false - } - if state.DisabledUntil.After(now) { - return true - } - antigravityCreditsFailureByAuth.Delete(authID) - return false -} - -func recordAntigravityCreditsFailure(auth *cliproxyauth.Auth, now time.Time) { - authID, state, ok := antigravityCreditsFailureStateForAuth(auth) - if !ok { - return - } - if state.PermanentlyDisabled { - antigravityCreditsFailureByAuth.Store(authID, state) - return - } - state.Count++ - state.DisabledUntil = now.Add(antigravityCreditsAutoDisableDuration) - antigravityCreditsFailureByAuth.Store(authID, state) -} - func clearAntigravityCreditsFailureState(auth *cliproxyauth.Auth) { if auth == nil || strings.TrimSpace(auth.ID) == "" { return @@ -440,6 +421,25 @@ func markAntigravityCreditsPermanentlyDisabled(auth *cliproxyauth.Auth) { ExplicitBalanceExhausted: true, } antigravityCreditsFailureByAuth.Store(authID, state) + antigravityCreditsBalanceByAuth.Store(authID, antigravityCreditsBalance{ + CreditAmount: 0, + MinCreditAmount: 1, + Known: true, + }) + cliproxyauth.SetAntigravityCreditsHint(authID, cliproxyauth.AntigravityCreditsHint{ + Known: true, + Available: false, + CreditAmount: 0, + MinCreditAmount: 1, + UpdatedAt: time.Now(), + }) +} + +func clearAntigravityCreditsPermanentlyDisabled(auth *cliproxyauth.Auth) { + if auth == nil || strings.TrimSpace(auth.ID) == "" { + return + } + antigravityCreditsFailureByAuth.Delete(strings.TrimSpace(auth.ID)) } func 
antigravityHasExplicitCreditsBalanceExhaustedReason(body []byte) bool { @@ -462,81 +462,6 @@ func antigravityHasExplicitCreditsBalanceExhaustedReason(body []byte) bool { return false } -func antigravityPreferCreditsKey(auth *cliproxyauth.Auth, modelName string) string { - if auth == nil { - return "" - } - authID := strings.TrimSpace(auth.ID) - modelName = strings.TrimSpace(modelName) - if authID == "" || modelName == "" { - return "" - } - return authID + "|" + modelName -} - -func antigravityShouldPreferCredits(auth *cliproxyauth.Auth, modelName string, now time.Time) bool { - key := antigravityPreferCreditsKey(auth, modelName) - if key == "" { - return false - } - value, ok := antigravityPreferCreditsByModel.Load(key) - if !ok { - return false - } - until, ok := value.(time.Time) - if !ok || until.IsZero() { - antigravityPreferCreditsByModel.Delete(key) - return false - } - if !until.After(now) { - antigravityPreferCreditsByModel.Delete(key) - return false - } - return true -} - -func markAntigravityPreferCredits(auth *cliproxyauth.Auth, modelName string, now time.Time, retryAfter *time.Duration) { - key := antigravityPreferCreditsKey(auth, modelName) - if key == "" { - return - } - until := now.Add(antigravityCreditsRetryTTL) - if retryAfter != nil && *retryAfter > 0 { - until = now.Add(*retryAfter) - } - antigravityPreferCreditsByModel.Store(key, until) -} - -func clearAntigravityPreferCredits(auth *cliproxyauth.Auth, modelName string) { - key := antigravityPreferCreditsKey(auth, modelName) - if key == "" { - return - } - antigravityPreferCreditsByModel.Delete(key) -} - -func shouldMarkAntigravityCreditsExhausted(statusCode int, body []byte, reqErr error) bool { - if reqErr != nil || statusCode == 0 { - return false - } - if statusCode >= http.StatusInternalServerError || statusCode == http.StatusRequestTimeout { - return false - } - lowerBody := strings.ToLower(string(body)) - for _, keyword := range antigravityCreditsExhaustedKeywords { - if 
strings.Contains(lowerBody, keyword) { - if keyword == "resource has been exhausted" && - statusCode == http.StatusTooManyRequests && - decideAntigravity429(body).kind == antigravity429DecisionSoftRetry && - !antigravityHasQuotaResetDelayOrModelInfo(body) { - return false - } - return true - } - } - return false -} - func newAntigravityStatusErr(statusCode int, body []byte) statusErr { err := statusErr{code: statusCode, msg: string(body)} if statusCode == http.StatusTooManyRequests { @@ -547,129 +472,6 @@ func newAntigravityStatusErr(statusCode int, body []byte) statusErr { return err } -func (e *AntigravityExecutor) attemptCreditsFallback( - ctx context.Context, - auth *cliproxyauth.Auth, - httpClient *http.Client, - token string, - modelName string, - payload []byte, - stream bool, - alt string, - baseURL string, - originalBody []byte, -) (*http.Response, bool) { - if !antigravityCreditsRetryEnabled(e.cfg) { - return nil, false - } - if decideAntigravity429(originalBody).kind != antigravity429DecisionFullQuotaExhausted { - return nil, false - } - now := time.Now() - if shouldForcePermanentDisableCredits(originalBody) { - clearAntigravityPreferCredits(auth, modelName) - markAntigravityCreditsPermanentlyDisabled(auth) - return nil, false - } - - if antigravityHasExplicitCreditsBalanceExhaustedReason(originalBody) { - clearAntigravityPreferCredits(auth, modelName) - markAntigravityCreditsPermanentlyDisabled(auth) - return nil, false - } - - if antigravityCreditsDisabled(auth, now) { - return nil, false - } - creditsPayload := injectEnabledCreditTypes(payload) - if len(creditsPayload) == 0 { - return nil, false - } - - httpReq, errReq := e.buildRequest(ctx, auth, token, modelName, creditsPayload, stream, alt, baseURL) - if errReq != nil { - helps.RecordAPIResponseError(ctx, e.cfg, errReq) - clearAntigravityPreferCredits(auth, modelName) - recordAntigravityCreditsFailure(auth, now) - return nil, true - } - httpResp, errDo := httpClient.Do(httpReq) - if errDo != nil { 
- helps.RecordAPIResponseError(ctx, e.cfg, errDo) - clearAntigravityPreferCredits(auth, modelName) - recordAntigravityCreditsFailure(auth, now) - return nil, true - } - if httpResp.StatusCode >= http.StatusOK && httpResp.StatusCode < http.StatusMultipleChoices { - retryAfter, _ := parseRetryDelay(originalBody) - markAntigravityPreferCredits(auth, modelName, now, retryAfter) - clearAntigravityCreditsFailureState(auth) - return httpResp, true - } - - helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) - bodyBytes, errRead := io.ReadAll(httpResp.Body) - if errClose := httpResp.Body.Close(); errClose != nil { - log.Errorf("antigravity executor: close credits fallback response body error: %v", errClose) - } - if errRead != nil { - helps.RecordAPIResponseError(ctx, e.cfg, errRead) - clearAntigravityPreferCredits(auth, modelName) - recordAntigravityCreditsFailure(auth, now) - return nil, true - } - helps.AppendAPIResponseChunk(ctx, e.cfg, bodyBytes) - if shouldForcePermanentDisableCredits(bodyBytes) { - clearAntigravityPreferCredits(auth, modelName) - markAntigravityCreditsPermanentlyDisabled(auth) - return nil, true - } - - if antigravityHasExplicitCreditsBalanceExhaustedReason(bodyBytes) { - clearAntigravityPreferCredits(auth, modelName) - markAntigravityCreditsPermanentlyDisabled(auth) - return nil, true - } - - clearAntigravityPreferCredits(auth, modelName) - recordAntigravityCreditsFailure(auth, now) - return nil, true -} - -func (e *AntigravityExecutor) handleDirectCreditsFailure(ctx context.Context, auth *cliproxyauth.Auth, modelName string, reqErr error) { - if reqErr != nil { - if shouldForcePermanentDisableCredits(reqErrBody(reqErr)) { - clearAntigravityPreferCredits(auth, modelName) - markAntigravityCreditsPermanentlyDisabled(auth) - return - } - - if antigravityHasExplicitCreditsBalanceExhaustedReason(reqErrBody(reqErr)) { - clearAntigravityPreferCredits(auth, modelName) - markAntigravityCreditsPermanentlyDisabled(auth) - 
return - } - - helps.RecordAPIResponseError(ctx, e.cfg, reqErr) - } - clearAntigravityPreferCredits(auth, modelName) - recordAntigravityCreditsFailure(auth, time.Now()) -} -func reqErrBody(reqErr error) []byte { - if reqErr == nil { - return nil - } - msg := reqErr.Error() - if strings.TrimSpace(msg) == "" { - return nil - } - return []byte(msg) -} - -func shouldForcePermanentDisableCredits(body []byte) bool { - return antigravityHasExplicitCreditsBalanceExhaustedReason(body) -} - // Execute performs a non-streaming request to the Antigravity API. func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) { if opts.Alt == "responses/compact" { @@ -721,6 +523,8 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au requestedModel := helps.PayloadRequestedModel(opts, req.Model) translated = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel) + useCredits := cliproxyauth.AntigravityCreditsRequested(ctx) && antigravityCreditsRetryEnabled(e.cfg) + baseURLs := antigravityBaseURLFallbackOrder(auth) httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0) attempts := antigravityRetryAttempts(auth, e.cfg) @@ -733,11 +537,10 @@ attemptLoop: for idx, baseURL := range baseURLs { requestPayload := translated - usedCreditsDirect := false - if antigravityCreditsRetryEnabled(e.cfg) && antigravityShouldPreferCredits(auth, baseModel, time.Now()) { - if creditsPayload := injectEnabledCreditTypes(translated); len(creditsPayload) > 0 { - requestPayload = creditsPayload - usedCreditsDirect = true + if useCredits { + if cp := injectEnabledCreditTypes(translated); len(cp) > 0 { + requestPayload = cp + helps.MarkCreditsUsed(ctx) } } @@ -785,7 +588,6 @@ attemptLoop: wait := antigravityInstantRetryDelay(*decision.retryAfter) log.Debugf("antigravity executor: 
instant retry for model %s, waiting %s", baseModel, wait) if errWait := antigravityWait(ctx, wait); errWait != nil { - return resp, errWait } } @@ -794,34 +596,13 @@ attemptLoop: case antigravity429DecisionShortCooldownSwitchAuth: if decision.retryAfter != nil && *decision.retryAfter > 0 { markAntigravityShortCooldown(auth, baseModel, time.Now(), *decision.retryAfter) - log.Debugf("antigravity executor: short quota cooldown (%s) for model %s, recorded cooldown and skipping credits fallback", *decision.retryAfter, baseModel) + log.Debugf("antigravity executor: short quota cooldown (%s) for model %s, recorded cooldown", *decision.retryAfter, baseModel) } case antigravity429DecisionFullQuotaExhausted: - if usedCreditsDirect { - clearAntigravityPreferCredits(auth, baseModel) - recordAntigravityCreditsFailure(auth, time.Now()) - } else { - creditsResp, _ := e.attemptCreditsFallback(ctx, auth, httpClient, token, baseModel, translated, false, opts.Alt, baseURL, bodyBytes) - if creditsResp != nil { - helps.RecordAPIResponseMetadata(ctx, e.cfg, creditsResp.StatusCode, creditsResp.Header.Clone()) - creditsBody, errCreditsRead := io.ReadAll(creditsResp.Body) - if errClose := creditsResp.Body.Close(); errClose != nil { - log.Errorf("antigravity executor: close credits success response body error: %v", errClose) - } - if errCreditsRead != nil { - helps.RecordAPIResponseError(ctx, e.cfg, errCreditsRead) - err = errCreditsRead - return resp, err - } - helps.AppendAPIResponseChunk(ctx, e.cfg, creditsBody) - reporter.Publish(ctx, helps.ParseAntigravityUsage(creditsBody)) - var param any - converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, creditsBody, ¶m) - resp = cliproxyexecutor.Response{Payload: converted, Headers: creditsResp.Header.Clone()} - reporter.EnsurePublished(ctx) - return resp, nil - } + if useCredits && antigravityHasExplicitCreditsBalanceExhaustedReason(bodyBytes) { + 
markAntigravityCreditsPermanentlyDisabled(auth) } + // No credits logic - just fall through to error return below } } @@ -870,6 +651,10 @@ attemptLoop: return resp, err } + // Success + if useCredits { + clearAntigravityCreditsFailureState(auth) + } reporter.Publish(ctx, helps.ParseAntigravityUsage(bodyBytes)) var param any converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bodyBytes, ¶m) @@ -935,6 +720,8 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth * requestedModel := helps.PayloadRequestedModel(opts, req.Model) translated = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel) + useCredits := cliproxyauth.AntigravityCreditsRequested(ctx) && antigravityCreditsRetryEnabled(e.cfg) + baseURLs := antigravityBaseURLFallbackOrder(auth) httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0) @@ -948,11 +735,10 @@ attemptLoop: for idx, baseURL := range baseURLs { requestPayload := translated - usedCreditsDirect := false - if antigravityCreditsRetryEnabled(e.cfg) && antigravityShouldPreferCredits(auth, baseModel, time.Now()) { - if creditsPayload := injectEnabledCreditTypes(translated); len(creditsPayload) > 0 { - requestPayload = creditsPayload - usedCreditsDirect = true + if useCredits { + if cp := injectEnabledCreditTypes(translated); len(cp) > 0 { + requestPayload = cp + helps.MarkCreditsUsed(ctx) } } httpReq, errReq := e.buildRequest(ctx, auth, token, baseModel, requestPayload, true, opts.Alt, baseURL) @@ -1014,7 +800,6 @@ attemptLoop: wait := antigravityInstantRetryDelay(*decision.retryAfter) log.Debugf("antigravity executor: instant retry for model %s, waiting %s", baseModel, wait) if errWait := antigravityWait(ctx, wait); errWait != nil { - return resp, errWait } } @@ -1023,25 +808,16 @@ attemptLoop: case antigravity429DecisionShortCooldownSwitchAuth: if decision.retryAfter != nil && 
*decision.retryAfter > 0 { markAntigravityShortCooldown(auth, baseModel, time.Now(), *decision.retryAfter) - log.Debugf("antigravity executor: short quota cooldown (%s) for model %s, recorded cooldown and skipping credits fallback", *decision.retryAfter, baseModel) + log.Debugf("antigravity executor: short quota cooldown (%s) for model %s, recorded cooldown", *decision.retryAfter, baseModel) } case antigravity429DecisionFullQuotaExhausted: - if usedCreditsDirect { - clearAntigravityPreferCredits(auth, baseModel) - recordAntigravityCreditsFailure(auth, time.Now()) - } else { - creditsResp, _ := e.attemptCreditsFallback(ctx, auth, httpClient, token, baseModel, translated, true, opts.Alt, baseURL, bodyBytes) - if creditsResp != nil { - httpResp = creditsResp - helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) - } + if useCredits && antigravityHasExplicitCreditsBalanceExhaustedReason(bodyBytes) { + markAntigravityCreditsPermanentlyDisabled(auth) } + // No credits logic - just fall through to error return below } } - if httpResp.StatusCode >= http.StatusOK && httpResp.StatusCode < http.StatusMultipleChoices { - goto streamSuccessClaudeNonStream - } lastStatus = httpResp.StatusCode lastBody = append([]byte(nil), bodyBytes...) 
lastErr = nil @@ -1085,7 +861,10 @@ attemptLoop: return resp, err } - streamSuccessClaudeNonStream: + // Stream success + if useCredits { + clearAntigravityCreditsFailureState(auth) + } out := make(chan cliproxyexecutor.StreamChunk) go func(resp *http.Response) { defer close(out) @@ -1389,6 +1168,7 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya if updatedAuth != nil { auth = updatedAuth } + originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true) translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true) @@ -1400,6 +1180,8 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya requestedModel := helps.PayloadRequestedModel(opts, req.Model) translated = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel) + useCredits := cliproxyauth.AntigravityCreditsRequested(ctx) && antigravityCreditsRetryEnabled(e.cfg) + baseURLs := antigravityBaseURLFallbackOrder(auth) httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0) @@ -1413,11 +1195,10 @@ attemptLoop: for idx, baseURL := range baseURLs { requestPayload := translated - usedCreditsDirect := false - if antigravityCreditsRetryEnabled(e.cfg) && antigravityShouldPreferCredits(auth, baseModel, time.Now()) { - if creditsPayload := injectEnabledCreditTypes(translated); len(creditsPayload) > 0 { - requestPayload = creditsPayload - usedCreditsDirect = true + if useCredits { + if cp := injectEnabledCreditTypes(translated); len(cp) > 0 { + requestPayload = cp + helps.MarkCreditsUsed(ctx) } } httpReq, errReq := e.buildRequest(ctx, auth, token, baseModel, requestPayload, true, opts.Alt, baseURL) @@ -1478,7 +1259,6 @@ attemptLoop: wait := antigravityInstantRetryDelay(*decision.retryAfter) log.Debugf("antigravity executor: instant retry for model %s, waiting %s", baseModel, wait) if errWait := antigravityWait(ctx, wait); errWait != 
nil { - return nil, errWait } } @@ -1487,25 +1267,16 @@ attemptLoop: case antigravity429DecisionShortCooldownSwitchAuth: if decision.retryAfter != nil && *decision.retryAfter > 0 { markAntigravityShortCooldown(auth, baseModel, time.Now(), *decision.retryAfter) - log.Debugf("antigravity executor: short quota cooldown (%s) for model %s, recorded cooldown and skipping credits fallback", *decision.retryAfter, baseModel) + log.Debugf("antigravity executor: short quota cooldown (%s) for model %s recorded", *decision.retryAfter, baseModel) } case antigravity429DecisionFullQuotaExhausted: - if usedCreditsDirect { - clearAntigravityPreferCredits(auth, baseModel) - recordAntigravityCreditsFailure(auth, time.Now()) - } else { - creditsResp, _ := e.attemptCreditsFallback(ctx, auth, httpClient, token, baseModel, translated, true, opts.Alt, baseURL, bodyBytes) - if creditsResp != nil { - httpResp = creditsResp - helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) - } + if useCredits && antigravityHasExplicitCreditsBalanceExhaustedReason(bodyBytes) { + markAntigravityCreditsPermanentlyDisabled(auth) } + // No credits logic - just fall through to error return below } } - if httpResp.StatusCode >= http.StatusOK && httpResp.StatusCode < http.StatusMultipleChoices { - goto streamSuccessExecuteStream - } lastStatus = httpResp.StatusCode lastBody = append([]byte(nil), bodyBytes...) 
lastErr = nil @@ -1549,7 +1320,10 @@ attemptLoop: return nil, err } - streamSuccessExecuteStream: + // Stream success + if useCredits { + clearAntigravityCreditsFailureState(auth) + } out := make(chan cliproxyexecutor.StreamChunk) go func(resp *http.Response) { defer close(out) @@ -1792,6 +1566,7 @@ func (e *AntigravityExecutor) ensureAccessToken(ctx context.Context, auth *clipr accessToken := metaStringValue(auth.Metadata, "access_token") expiry := tokenExpiry(auth.Metadata) if accessToken != "" && expiry.After(time.Now().Add(refreshSkew)) { + e.maybeRefreshAntigravityCreditsHint(ctx, auth, accessToken) return accessToken, nil, nil } refreshCtx := context.Background() @@ -1807,6 +1582,63 @@ func (e *AntigravityExecutor) ensureAccessToken(ctx context.Context, auth *clipr return metaStringValue(updated.Metadata, "access_token"), updated, nil } +func (e *AntigravityExecutor) maybeRefreshAntigravityCreditsHint(ctx context.Context, auth *cliproxyauth.Auth, accessToken string) { + if e == nil || auth == nil || !antigravityCreditsRetryEnabled(e.cfg) { + return + } + if ctx != nil && ctx.Err() != nil { + return + } + authID := strings.TrimSpace(auth.ID) + if authID == "" { + return + } + if hint, ok := cliproxyauth.GetAntigravityCreditsHint(authID); ok && hint.Known { + return + } + if strings.TrimSpace(accessToken) == "" { + accessToken = metaStringValue(auth.Metadata, "access_token") + } + if strings.TrimSpace(accessToken) == "" { + return + } + + state := &antigravityCreditsHintRefreshState{} + if existing, loaded := antigravityCreditsHintRefreshByID.LoadOrStore(authID, state); loaded { + if cast, ok := existing.(*antigravityCreditsHintRefreshState); ok && cast != nil { + state = cast + } else { + antigravityCreditsHintRefreshByID.Delete(authID) + antigravityCreditsHintRefreshByID.Store(authID, state) + } + } + + now := time.Now() + if !state.mu.TryLock() { + return + } + if !state.lastAttempt.IsZero() && now.Sub(state.lastAttempt) < 
antigravityCreditsHintRefreshInterval { + state.mu.Unlock() + return + } + state.lastAttempt = now + + refreshCtx := context.Background() + if ctx != nil { + if rt, ok := ctx.Value("cliproxy.roundtripper").(http.RoundTripper); ok && rt != nil { + refreshCtx = context.WithValue(refreshCtx, "cliproxy.roundtripper", rt) + } + } + refreshCtx, cancel := context.WithTimeout(refreshCtx, antigravityCreditsHintRefreshTimeout) + authCopy := auth.Clone() + + go func(state *antigravityCreditsHintRefreshState, auth *cliproxyauth.Auth, token string) { + defer cancel() + defer state.mu.Unlock() + e.updateAntigravityCreditsBalance(refreshCtx, auth, token) + }(state, authCopy, accessToken) +} + func (e *AntigravityExecutor) refreshToken(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { if auth == nil { return nil, statusErr{code: http.StatusUnauthorized, msg: "missing auth"} @@ -1882,6 +1714,7 @@ func (e *AntigravityExecutor) refreshToken(ctx context.Context, auth *cliproxyau if errProject := e.ensureAntigravityProjectID(ctx, auth, tokenResp.AccessToken); errProject != nil { log.Warnf("antigravity executor: ensure project id failed: %v", errProject) } + e.updateAntigravityCreditsBalance(ctx, auth, tokenResp.AccessToken) return auth, nil } @@ -1918,6 +1751,94 @@ func (e *AntigravityExecutor) ensureAntigravityProjectID(ctx context.Context, au return nil } +func (e *AntigravityExecutor) updateAntigravityCreditsBalance(ctx context.Context, auth *cliproxyauth.Auth, accessToken string) { + if auth == nil || strings.TrimSpace(auth.ID) == "" { + return + } + token := strings.TrimSpace(accessToken) + if token == "" { + token = metaStringValue(auth.Metadata, "access_token") + } + if token == "" { + return + } + + loadReqBody := `{"metadata":{"ideType":"ANTIGRAVITY","platform":"PLATFORM_UNSPECIFIED","pluginType":"GEMINI"}}` + endpointURL := "https://cloudcode-pa.googleapis.com/v1internal:loadCodeAssist" + httpReq, errReq := http.NewRequestWithContext(ctx, 
http.MethodPost, endpointURL, strings.NewReader(loadReqBody)) + if errReq != nil { + log.Debugf("antigravity executor: create loadCodeAssist request error: %v", errReq) + return + } + httpReq.Header.Set("Authorization", "Bearer "+token) + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("User-Agent", "google-api-nodejs-client/9.15.1") + + httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0) + httpResp, errDo := httpClient.Do(httpReq) + if errDo != nil { + log.Debugf("antigravity executor: loadCodeAssist request error: %v", errDo) + return + } + defer func() { + if errClose := httpResp.Body.Close(); errClose != nil { + log.Errorf("antigravity executor: close loadCodeAssist response body error: %v", errClose) + } + }() + + bodyBytes, errRead := io.ReadAll(httpResp.Body) + if errRead != nil || httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices { + log.Debugf("antigravity executor: loadCodeAssist returned status %d, err=%v", httpResp.StatusCode, errRead) + return + } + + authID := strings.TrimSpace(auth.ID) + paidTierID := strings.TrimSpace(gjson.GetBytes(bodyBytes, "paidTier.id").String()) + + credits := gjson.GetBytes(bodyBytes, "paidTier.availableCredits") + if !credits.IsArray() { + cliproxyauth.SetAntigravityCreditsHint(authID, cliproxyauth.AntigravityCreditsHint{ + Known: true, + Available: false, + PaidTierID: paidTierID, + UpdatedAt: time.Now(), + }) + return + } + for _, credit := range credits.Array() { + if !strings.EqualFold(credit.Get("creditType").String(), "GOOGLE_ONE_AI") { + continue + } + creditAmount, errCA := strconv.ParseFloat(strings.TrimSpace(credit.Get("creditAmount").String()), 64) + if errCA != nil { + continue + } + minAmount, errMA := strconv.ParseFloat(strings.TrimSpace(credit.Get("minimumCreditAmountForUsage").String()), 64) + if errMA != nil { + continue + } + bal := antigravityCreditsBalance{ + CreditAmount: creditAmount, + MinCreditAmount: minAmount, + PaidTierID: 
paidTierID, + Known: true, + } + antigravityCreditsBalanceByAuth.Store(authID, bal) + cliproxyauth.SetAntigravityCreditsHint(authID, cliproxyauth.AntigravityCreditsHint{ + Known: true, + Available: creditAmount >= minAmount, + CreditAmount: creditAmount, + MinCreditAmount: minAmount, + PaidTierID: paidTierID, + UpdatedAt: time.Now(), + }) + if creditAmount >= minAmount { + clearAntigravityCreditsPermanentlyDisabled(auth) + } + return + } +} + func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyauth.Auth, token, modelName string, payload []byte, stream bool, alt, baseURL string) (*http.Request, error) { if token == "" { return nil, statusErr{code: http.StatusUnauthorized, msg: "missing access token"} diff --git a/internal/runtime/executor/antigravity_executor_credits_test.go b/internal/runtime/executor/antigravity_executor_credits_test.go index cf968ac794..6e38223e50 100644 --- a/internal/runtime/executor/antigravity_executor_credits_test.go +++ b/internal/runtime/executor/antigravity_executor_credits_test.go @@ -18,8 +18,9 @@ import ( func resetAntigravityCreditsRetryState() { antigravityCreditsFailureByAuth = sync.Map{} - antigravityPreferCreditsByModel = sync.Map{} antigravityShortCooldownByAuth = sync.Map{} + antigravityCreditsBalanceByAuth = sync.Map{} + antigravityCreditsHintRefreshByID = sync.Map{} } func TestClassifyAntigravity429(t *testing.T) { @@ -30,6 +31,43 @@ func TestClassifyAntigravity429(t *testing.T) { } }) + t.Run("standard antigravity rate limit with ui message stays rate limited", func(t *testing.T) { + body := []byte(`{ + "error": { + "code": 429, + "message": "You have exhausted your capacity on this model. 
Your quota will reset after 0s.", + "status": "RESOURCE_EXHAUSTED", + "details": [ + { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": "RATE_LIMIT_EXCEEDED", + "domain": "cloudcode-pa.googleapis.com", + "metadata": { + "model": "claude-opus-4-6-thinking", + "quotaResetDelay": "479.417207ms", + "quotaResetTimeStamp": "2026-04-20T09:19:49Z", + "uiMessage": "true" + } + }, + { + "@type": "type.googleapis.com/google.rpc.RetryInfo", + "retryDelay": "0.479417207s" + } + ] + } + }`) + if got := classifyAntigravity429(body); got != antigravity429RateLimited { + t.Fatalf("classifyAntigravity429() = %q, want %q", got, antigravity429RateLimited) + } + decision := decideAntigravity429(body) + if decision.kind != antigravity429DecisionInstantRetrySameAuth { + t.Fatalf("decideAntigravity429().kind = %q, want %q", decision.kind, antigravity429DecisionInstantRetrySameAuth) + } + if decision.retryAfter == nil { + t.Fatal("decideAntigravity429().retryAfter = nil") + } + }) + t.Run("structured rate limit", func(t *testing.T) { body := []byte(`{ "error": { @@ -67,8 +105,31 @@ func TestClassifyAntigravity429(t *testing.T) { }) } +func TestAntigravityShouldRetryNoCapacity_Standard503(t *testing.T) { + body := []byte(`{ + "error": { + "code": 503, + "message": "No capacity available for model gemini-3.1-flash-image on the server", + "status": "UNAVAILABLE", + "details": [ + { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": "MODEL_CAPACITY_EXHAUSTED", + "domain": "cloudcode-pa.googleapis.com", + "metadata": { + "model": "gemini-3.1-flash-image" + } + } + ] + } + }`) + if !antigravityShouldRetryNoCapacity(http.StatusServiceUnavailable, body) { + t.Fatal("antigravityShouldRetryNoCapacity() = false, want true") + } +} + func TestInjectEnabledCreditTypes(t *testing.T) { - body := []byte(`{"model":"gemini-2.5-flash","request":{}}`) + body := []byte(`{"model":"claude-sonnet-4-6","request":{}}`) got := injectEnabledCreditTypes(body) if got == nil { 
t.Fatal("injectEnabledCreditTypes() returned nil") @@ -82,34 +143,18 @@ func TestInjectEnabledCreditTypes(t *testing.T) { } } -func TestShouldMarkAntigravityCreditsExhausted(t *testing.T) { - t.Run("credit errors are marked", func(t *testing.T) { - for _, body := range [][]byte{ - []byte(`{"error":{"message":"Insufficient GOOGLE_ONE_AI credits"}}`), - []byte(`{"error":{"message":"minimumCreditAmountForUsage requirement not met"}}`), - } { - if !shouldMarkAntigravityCreditsExhausted(http.StatusForbidden, body, nil) { - t.Fatalf("shouldMarkAntigravityCreditsExhausted(%s) = false, want true", string(body)) - } - } - }) - - t.Run("transient 429 resource exhausted is not marked", func(t *testing.T) { - body := []byte(`{"error":{"code":429,"message":"Resource has been exhausted (e.g. check quota).","status":"RESOURCE_EXHAUSTED"}}`) - if shouldMarkAntigravityCreditsExhausted(http.StatusTooManyRequests, body, nil) { - t.Fatalf("shouldMarkAntigravityCreditsExhausted(%s) = true, want false", string(body)) - } - }) - - t.Run("resource exhausted with quota metadata is still marked", func(t *testing.T) { - body := []byte(`{"error":{"code":429,"message":"Resource has been exhausted","status":"RESOURCE_EXHAUSTED","details":[{"@type":"type.googleapis.com/google.rpc.ErrorInfo","metadata":{"quotaResetDelay":"1h","model":"claude-sonnet-4-6"}}]}}`) - if !shouldMarkAntigravityCreditsExhausted(http.StatusTooManyRequests, body, nil) { - t.Fatalf("shouldMarkAntigravityCreditsExhausted(%s) = false, want true", string(body)) - } - }) - - if shouldMarkAntigravityCreditsExhausted(http.StatusServiceUnavailable, []byte(`{"error":{"message":"credits exhausted"}}`), nil) { - t.Fatal("shouldMarkAntigravityCreditsExhausted() = true for 5xx, want false") +func TestParseRetryDelay_HumanReadableDuration(t *testing.T) { + body := []byte(`{"error":{"message":"You have exhausted your capacity on this model. 
Your quota will reset after 1h43m56s."}}`) + retryAfter, err := parseRetryDelay(body) + if err != nil { + t.Fatalf("parseRetryDelay() error = %v", err) + } + if retryAfter == nil { + t.Fatal("parseRetryDelay() returned nil") + } + want := time.Hour + 43*time.Minute + 56*time.Second + if *retryAfter != want { + t.Fatalf("parseRetryDelay() = %v, want %v", *retryAfter, want) } } @@ -147,7 +192,7 @@ func TestAntigravityExecute_RetriesTransient429ResourceExhausted(t *testing.T) { } resp, err := exec.Execute(context.Background(), auth, cliproxyexecutor.Request{ - Model: "gemini-2.5-flash", + Model: "claude-sonnet-4-6", Payload: []byte(`{"request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`), }, cliproxyexecutor.Options{ SourceFormat: sdktranslator.FormatAntigravity, @@ -163,32 +208,18 @@ func TestAntigravityExecute_RetriesTransient429ResourceExhausted(t *testing.T) { } } -func TestAntigravityExecute_RetriesQuotaExhaustedWithCredits(t *testing.T) { +func TestAntigravityExecute_CreditsInjectedWhenConductorRequests(t *testing.T) { resetAntigravityCreditsRetryState() t.Cleanup(resetAntigravityCreditsRetryState) - var ( - mu sync.Mutex - requestBodies []string - ) - + var requestBodies []string server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { body, _ := io.ReadAll(r.Body) _ = r.Body.Close() - - mu.Lock() requestBodies = append(requestBodies, string(body)) - reqNum := len(requestBodies) - mu.Unlock() - - if reqNum == 1 { - w.WriteHeader(http.StatusTooManyRequests) - _, _ = w.Write([]byte(`{"error":{"status":"RESOURCE_EXHAUSTED","message":"QUOTA_EXHAUSTED"}}`)) - return - } if !strings.Contains(string(body), `"enabledCreditTypes":["GOOGLE_ONE_AI"]`) { - t.Fatalf("second request body missing enabledCreditTypes: %s", string(body)) + t.Fatalf("request body missing enabledCreditTypes: %s", string(body)) } w.Header().Set("Content-Type", "application/json") _, _ = 
w.Write([]byte(`{"response":{"candidates":[{"content":{"role":"model","parts":[{"text":"ok"}]}}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2}}}`)) @@ -199,7 +230,7 @@ func TestAntigravityExecute_RetriesQuotaExhaustedWithCredits(t *testing.T) { QuotaExceeded: config.QuotaExceeded{AntigravityCredits: true}, }) auth := &cliproxyauth.Auth{ - ID: "auth-credits-ok", + ID: "auth-credits-conductor", Attributes: map[string]string{ "base_url": server.URL, }, @@ -210,8 +241,11 @@ func TestAntigravityExecute_RetriesQuotaExhaustedWithCredits(t *testing.T) { }, } - resp, err := exec.Execute(context.Background(), auth, cliproxyexecutor.Request{ - Model: "gemini-2.5-flash", + // Simulate conductor setting credits requested flag in context + ctx := cliproxyauth.WithAntigravityCredits(context.Background()) + + resp, err := exec.Execute(ctx, auth, cliproxyexecutor.Request{ + Model: "claude-sonnet-4-6", Payload: []byte(`{"request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`), }, cliproxyexecutor.Options{ SourceFormat: sdktranslator.FormatAntigravity, @@ -222,21 +256,20 @@ func TestAntigravityExecute_RetriesQuotaExhaustedWithCredits(t *testing.T) { if len(resp.Payload) == 0 { t.Fatal("Execute() returned empty payload") } - - mu.Lock() - defer mu.Unlock() - if len(requestBodies) != 2 { - t.Fatalf("request count = %d, want 2", len(requestBodies)) + if len(requestBodies) != 1 { + t.Fatalf("request count = %d, want 1", len(requestBodies)) } } -func TestAntigravityExecute_SkipsCreditsRetryWhenAlreadyExhausted(t *testing.T) { +func TestAntigravityExecute_NoCreditsWithoutConductorFlag(t *testing.T) { resetAntigravityCreditsRetryState() t.Cleanup(resetAntigravityCreditsRetryState) - var requestCount int + var requestBodies []string server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - requestCount++ + body, _ := io.ReadAll(r.Body) + _ = r.Body.Close() + requestBodies = append(requestBodies, 
string(body)) w.WriteHeader(http.StatusTooManyRequests) _, _ = w.Write([]byte(`{"error":{"status":"RESOURCE_EXHAUSTED","message":"QUOTA_EXHAUSTED"}}`)) })) @@ -246,7 +279,7 @@ func TestAntigravityExecute_SkipsCreditsRetryWhenAlreadyExhausted(t *testing.T) QuotaExceeded: config.QuotaExceeded{AntigravityCredits: true}, }) auth := &cliproxyauth.Auth{ - ID: "auth-credits-exhausted", + ID: "auth-no-conductor-flag", Attributes: map[string]string{ "base_url": server.URL, }, @@ -256,10 +289,10 @@ func TestAntigravityExecute_SkipsCreditsRetryWhenAlreadyExhausted(t *testing.T) "expired": time.Now().Add(1 * time.Hour).Format(time.RFC3339), }, } - recordAntigravityCreditsFailure(auth, time.Now()) + // No conductor credits flag set in context _, err := exec.Execute(context.Background(), auth, cliproxyexecutor.Request{ - Model: "gemini-2.5-flash", + Model: "claude-sonnet-4-6", Payload: []byte(`{"request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`), }, cliproxyexecutor.Options{ SourceFormat: sdktranslator.FormatAntigravity, @@ -267,224 +300,159 @@ func TestAntigravityExecute_SkipsCreditsRetryWhenAlreadyExhausted(t *testing.T) if err == nil { t.Fatal("Execute() error = nil, want 429") } - sErr, ok := err.(statusErr) - if !ok { - t.Fatalf("Execute() error type = %T, want statusErr", err) - } - if got := sErr.StatusCode(); got != http.StatusTooManyRequests { - t.Fatalf("Execute() status code = %d, want %d", got, http.StatusTooManyRequests) + if len(requestBodies) != 1 { + t.Fatalf("request count = %d, want 1", len(requestBodies)) } - if requestCount != 1 { - t.Fatalf("request count = %d, want 1", requestCount) + // Should NOT contain credits since conductor didn't request them + if strings.Contains(requestBodies[0], `"enabledCreditTypes"`) { + t.Fatalf("request should not contain enabledCreditTypes without conductor flag: %s", requestBodies[0]) } } -func TestAntigravityExecute_PrefersCreditsAfterSuccessfulFallback(t *testing.T) { - resetAntigravityCreditsRetryState() - 
t.Cleanup(resetAntigravityCreditsRetryState) - - var ( - mu sync.Mutex - requestBodies []string - ) +func TestAntigravityAuthHasCredits(t *testing.T) { + t.Run("sufficient balance", func(t *testing.T) { + resetAntigravityCreditsRetryState() + auth := &cliproxyauth.Auth{ID: "test-sufficient"} + antigravityCreditsBalanceByAuth.Store("test-sufficient", antigravityCreditsBalance{ + CreditAmount: 25000, + MinCreditAmount: 50, + Known: true, + }) + if !antigravityAuthHasCredits(auth) { + t.Fatal("antigravityAuthHasCredits() = false, want true") + } + }) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - body, _ := io.ReadAll(r.Body) - _ = r.Body.Close() + t.Run("insufficient balance", func(t *testing.T) { + resetAntigravityCreditsRetryState() + auth := &cliproxyauth.Auth{ID: "test-insufficient"} + antigravityCreditsBalanceByAuth.Store("test-insufficient", antigravityCreditsBalance{ + CreditAmount: 30, + MinCreditAmount: 50, + Known: true, + }) + if antigravityAuthHasCredits(auth) { + t.Fatal("antigravityAuthHasCredits() = true, want false") + } + }) - mu.Lock() - requestBodies = append(requestBodies, string(body)) - reqNum := len(requestBodies) - mu.Unlock() + t.Run("no balance stored returns true (optimistic)", func(t *testing.T) { + resetAntigravityCreditsRetryState() + auth := &cliproxyauth.Auth{ID: "test-no-balance"} + if !antigravityAuthHasCredits(auth) { + t.Fatal("antigravityAuthHasCredits() = false with no balance stored, want true (optimistic default)") + } + }) - switch reqNum { - case 1: - w.WriteHeader(http.StatusTooManyRequests) - _, _ = w.Write([]byte(`{"error":{"status":"RESOURCE_EXHAUSTED","details":[{"@type":"type.googleapis.com/google.rpc.ErrorInfo","reason":"QUOTA_EXHAUSTED"},{"@type":"type.googleapis.com/google.rpc.RetryInfo","retryDelay":"10s"}]}}`)) - case 2, 3: - if !strings.Contains(string(body), `"enabledCreditTypes":["GOOGLE_ONE_AI"]`) { - t.Fatalf("request %d body missing enabledCreditTypes: %s", 
reqNum, string(body)) - } - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"response":{"candidates":[{"content":{"role":"model","parts":[{"text":"OK"}]}}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2}}}`)) - default: - t.Fatalf("unexpected request count %d", reqNum) + t.Run("nil auth returns false", func(t *testing.T) { + if antigravityAuthHasCredits(nil) { + t.Fatal("antigravityAuthHasCredits(nil) = true, want false") } - })) - defer server.Close() + }) - exec := NewAntigravityExecutor(&config.Config{ - QuotaExceeded: config.QuotaExceeded{AntigravityCredits: true}, + t.Run("empty ID returns false", func(t *testing.T) { + auth := &cliproxyauth.Auth{} + if antigravityAuthHasCredits(auth) { + t.Fatal("antigravityAuthHasCredits(empty ID) = true, want false") + } }) - auth := &cliproxyauth.Auth{ - ID: "auth-prefer-credits", - Attributes: map[string]string{ - "base_url": server.URL, - }, - Metadata: map[string]any{ - "access_token": "token", - "project_id": "project-1", - "expired": time.Now().Add(1 * time.Hour).Format(time.RFC3339), - }, - } - request := cliproxyexecutor.Request{ - Model: "gemini-2.5-flash", - Payload: []byte(`{"request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`), - } - opts := cliproxyexecutor.Options{SourceFormat: sdktranslator.FormatAntigravity} + t.Run("unknown balance returns false", func(t *testing.T) { + resetAntigravityCreditsRetryState() + auth := &cliproxyauth.Auth{ID: "test-unknown"} + antigravityCreditsBalanceByAuth.Store("test-unknown", antigravityCreditsBalance{ + Known: false, + }) + if antigravityAuthHasCredits(auth) { + t.Fatal("antigravityAuthHasCredits() = true for unknown balance, want false") + } + }) +} - if _, err := exec.Execute(context.Background(), auth, request, opts); err != nil { - t.Fatalf("first Execute() error = %v", err) - } - if _, err := exec.Execute(context.Background(), auth, request, opts); err != nil { - t.Fatalf("second Execute() 
error = %v", err) - } +type roundTripperFunc func(*http.Request) (*http.Response, error) - mu.Lock() - defer mu.Unlock() - if len(requestBodies) != 3 { - t.Fatalf("request count = %d, want 3", len(requestBodies)) - } - if strings.Contains(requestBodies[0], `"enabledCreditTypes":["GOOGLE_ONE_AI"]`) { - t.Fatalf("first request unexpectedly used credits: %s", requestBodies[0]) - } - if !strings.Contains(requestBodies[1], `"enabledCreditTypes":["GOOGLE_ONE_AI"]`) { - t.Fatalf("fallback request missing credits: %s", requestBodies[1]) - } - if !strings.Contains(requestBodies[2], `"enabledCreditTypes":["GOOGLE_ONE_AI"]`) { - t.Fatalf("preferred request missing credits: %s", requestBodies[2]) - } +func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) } -func TestAntigravityExecute_PreservesBaseURLFallbackAfterCreditsRetryFailure(t *testing.T) { +func TestEnsureAccessToken_WarmTokenLoadsCreditsHint(t *testing.T) { resetAntigravityCreditsRetryState() t.Cleanup(resetAntigravityCreditsRetryState) - var ( - mu sync.Mutex - firstCount int - secondCount int - ) - - firstServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - body, _ := io.ReadAll(r.Body) - _ = r.Body.Close() - - mu.Lock() - firstCount++ - reqNum := firstCount - mu.Unlock() - - switch reqNum { - case 1: - w.WriteHeader(http.StatusTooManyRequests) - _, _ = w.Write([]byte(`{"error":{"status":"RESOURCE_EXHAUSTED","details":[{"@type":"type.googleapis.com/google.rpc.ErrorInfo","reason":"QUOTA_EXHAUSTED"}]}}`)) - case 2: - if !strings.Contains(string(body), `"enabledCreditTypes":["GOOGLE_ONE_AI"]`) { - t.Fatalf("credits retry missing enabledCreditTypes: %s", string(body)) - } - w.WriteHeader(http.StatusForbidden) - _, _ = w.Write([]byte(`{"error":{"message":"permission denied"}}`)) - default: - t.Fatalf("unexpected first server request count %d", reqNum) - } - })) - defer firstServer.Close() - - secondServer := 
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - mu.Lock() - secondCount++ - mu.Unlock() - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"response":{"candidates":[{"content":{"role":"model","parts":[{"text":"ok"}]}}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2}}}`)) - })) - defer secondServer.Close() - exec := NewAntigravityExecutor(&config.Config{ QuotaExceeded: config.QuotaExceeded{AntigravityCredits: true}, }) auth := &cliproxyauth.Auth{ - ID: "auth-baseurl-fallback", - Attributes: map[string]string{ - "base_url": firstServer.URL, - }, + ID: "auth-warm-token-credits", Metadata: map[string]any{ "access_token": "token", - "project_id": "project-1", "expired": time.Now().Add(1 * time.Hour).Format(time.RFC3339), }, } + ctx := context.WithValue(context.Background(), "cliproxy.roundtripper", roundTripperFunc(func(req *http.Request) (*http.Response, error) { + if req.URL.String() != "https://cloudcode-pa.googleapis.com/v1internal:loadCodeAssist" { + t.Fatalf("unexpected request url %s", req.URL.String()) + } + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader(`{"paidTier":{"id":"tier-1","availableCredits":[{"creditType":"GOOGLE_ONE_AI","creditAmount":"25000","minimumCreditAmountForUsage":"50"}]}}`)), + }, nil + })) - originalOrder := antigravityBaseURLFallbackOrder - defer func() { antigravityBaseURLFallbackOrder = originalOrder }() - antigravityBaseURLFallbackOrder = func(auth *cliproxyauth.Auth) []string { - return []string{firstServer.URL, secondServer.URL} - } - - resp, err := exec.Execute(context.Background(), auth, cliproxyexecutor.Request{ - Model: "gemini-2.5-flash", - Payload: []byte(`{"request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`), - }, cliproxyexecutor.Options{ - SourceFormat: sdktranslator.FormatAntigravity, - }) + token, updatedAuth, err := 
exec.ensureAccessToken(ctx, auth) if err != nil { - t.Fatalf("Execute() error = %v", err) + t.Fatalf("ensureAccessToken() error = %v", err) } - if len(resp.Payload) == 0 { - t.Fatal("Execute() returned empty payload") + if token != "token" { + t.Fatalf("ensureAccessToken() token = %q, want %q", token, "token") } - if firstCount != 2 { - t.Fatalf("first server request count = %d, want 2", firstCount) + if updatedAuth != nil { + t.Fatalf("ensureAccessToken() updatedAuth = %v, want nil", updatedAuth) } - if secondCount != 1 { - t.Fatalf("second server request count = %d, want 1", secondCount) + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) && !cliproxyauth.HasKnownAntigravityCreditsHint(auth.ID) { + time.Sleep(10 * time.Millisecond) } -} - -func TestAntigravityExecute_DoesNotDirectInjectCreditsWhenFlagDisabled(t *testing.T) { - resetAntigravityCreditsRetryState() - t.Cleanup(resetAntigravityCreditsRetryState) - - var requestBodies []string - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - body, _ := io.ReadAll(r.Body) - _ = r.Body.Close() - requestBodies = append(requestBodies, string(body)) - w.WriteHeader(http.StatusTooManyRequests) - _, _ = w.Write([]byte(`{"error":{"status":"RESOURCE_EXHAUSTED","message":"QUOTA_EXHAUSTED"}}`)) - })) - defer server.Close() - - exec := NewAntigravityExecutor(&config.Config{ - QuotaExceeded: config.QuotaExceeded{AntigravityCredits: false}, - }) - auth := &cliproxyauth.Auth{ - ID: "auth-flag-disabled", - Attributes: map[string]string{ - "base_url": server.URL, - }, - Metadata: map[string]any{ - "access_token": "token", - "project_id": "project-1", - "expired": time.Now().Add(1 * time.Hour).Format(time.RFC3339), - }, + if !cliproxyauth.HasKnownAntigravityCreditsHint(auth.ID) { + t.Fatal("expected credits hint to be populated for warm token auth") } - markAntigravityPreferCredits(auth, "gemini-2.5-flash", time.Now(), nil) - - _, err := 
exec.Execute(context.Background(), auth, cliproxyexecutor.Request{ - Model: "gemini-2.5-flash", - Payload: []byte(`{"request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`), - }, cliproxyexecutor.Options{ - SourceFormat: sdktranslator.FormatAntigravity, - }) - if err == nil { - t.Fatal("Execute() error = nil, want 429") + hint, ok := cliproxyauth.GetAntigravityCreditsHint(auth.ID) + if !ok { + t.Fatal("expected credits hint lookup to succeed") } - if len(requestBodies) != 1 { - t.Fatalf("request count = %d, want 1", len(requestBodies)) + if !hint.Available { + t.Fatalf("hint.Available = %v, want true", hint.Available) } - if strings.Contains(requestBodies[0], `"enabledCreditTypes":["GOOGLE_ONE_AI"]`) { - t.Fatalf("request unexpectedly used enabledCreditTypes with flag disabled: %s", requestBodies[0]) + if hint.CreditAmount != 25000 || hint.MinCreditAmount != 50 { + t.Fatalf("hint amounts = (%v, %v), want (25000, 50)", hint.CreditAmount, hint.MinCreditAmount) + } +} + +func TestParseMetaFloat(t *testing.T) { + tests := []struct { + name string + value any + wantVal float64 + wantOK bool + }{ + {"string", "25000", 25000, true}, + {"float64", float64(100), 100, true}, + {"int", int(50), 50, true}, + {"int64", int64(75), 75, true}, + {"empty string", "", 0, false}, + {"invalid string", "abc", 0, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + meta := map[string]any{"key": tt.value} + got, ok := parseMetaFloat(meta, "key") + if ok != tt.wantOK { + t.Fatalf("parseMetaFloat() ok = %v, want %v", ok, tt.wantOK) + } + if ok && got != tt.wantVal { + t.Fatalf("parseMetaFloat() = %f, want %f", got, tt.wantVal) + } + }) } } diff --git a/internal/runtime/executor/claude_executor.go b/internal/runtime/executor/claude_executor.go index 0311827bae..235db1f3b2 100644 --- a/internal/runtime/executor/claude_executor.go +++ b/internal/runtime/executor/claude_executor.go @@ -11,7 +11,6 @@ import ( "fmt" "io" "net/http" - "net/textproto" 
"strings" "time" @@ -911,15 +910,8 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string, baseBetas += ",interleaved-thinking-2025-05-14" } - hasClaude1MHeader := false - if ginHeaders != nil { - if _, ok := ginHeaders[textproto.CanonicalMIMEHeaderKey("X-CPA-CLAUDE-1M")]; ok { - hasClaude1MHeader = true - } - } - // Merge extra betas from request body and request flags. - if len(extraBetas) > 0 || hasClaude1MHeader { + if len(extraBetas) > 0 { existingSet := make(map[string]bool) for _, b := range strings.Split(baseBetas, ",") { betaName := strings.TrimSpace(b) @@ -934,9 +926,6 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string, existingSet[beta] = true } } - if hasClaude1MHeader && !existingSet["context-1m-2025-08-07"] { - baseBetas += ",context-1m-2025-08-07" - } } r.Header.Set("Anthropic-Beta", baseBetas) diff --git a/internal/runtime/executor/claude_executor_test.go b/internal/runtime/executor/claude_executor_test.go index f456064dc6..c1ce8fc088 100644 --- a/internal/runtime/executor/claude_executor_test.go +++ b/internal/runtime/executor/claude_executor_test.go @@ -1714,7 +1714,27 @@ func TestClaudeExecutor_ExecuteStream_AcceptEncodingOverrideCannotBypassIdentity } } -// Test case 1: String system prompt is preserved and converted to a content block +func expectedClaudeCodeStaticPrompt() string { + return strings.Join([]string{ + helps.ClaudeCodeIntro, + helps.ClaudeCodeSystem, + helps.ClaudeCodeDoingTasks, + helps.ClaudeCodeToneAndStyle, + helps.ClaudeCodeOutputEfficiency, + }, "\n\n") +} + +func expectedForwardedSystemReminder(text string) string { + return fmt.Sprintf(` +As you answer the user's questions, you can use the following context from the system: +%s + +IMPORTANT: this context may or may not be relevant to your tasks. You should not respond to this context unless it is highly relevant to your task. 
+ +`, text) +} + +// Test case 1: String system prompt is preserved by forwarding it to the first user message func TestCheckSystemInstructionsWithMode_StringSystemPreserved(t *testing.T) { payload := []byte(`{"system":"You are a helpful assistant.","messages":[{"role":"user","content":"hi"}]}`) @@ -1733,42 +1753,52 @@ func TestCheckSystemInstructionsWithMode_StringSystemPreserved(t *testing.T) { if !strings.HasPrefix(blocks[0].Get("text").String(), "x-anthropic-billing-header:") { t.Fatalf("blocks[0] should be billing header, got %q", blocks[0].Get("text").String()) } - if blocks[1].Get("text").String() != "You are a Claude agent, built on Anthropic's Claude Agent SDK." { + if blocks[1].Get("text").String() != "You are Claude Code, Anthropic's official CLI for Claude." { t.Fatalf("blocks[1] should be agent block, got %q", blocks[1].Get("text").String()) } - if blocks[2].Get("text").String() != "You are a helpful assistant." { - t.Fatalf("blocks[2] should be user system prompt, got %q", blocks[2].Get("text").String()) + if blocks[2].Get("text").String() != expectedClaudeCodeStaticPrompt() { + t.Fatalf("blocks[2] should be static Claude Code prompt, got %q", blocks[2].Get("text").String()) + } + if blocks[2].Get("cache_control").Exists() { + t.Fatalf("blocks[2] should not have cache_control, got %s", blocks[2].Get("cache_control").Raw) } - if blocks[2].Get("cache_control.type").String() != "ephemeral" { - t.Fatalf("blocks[2] should have cache_control.type=ephemeral") + + if got := gjson.GetBytes(out, "messages.0.content").String(); got != expectedForwardedSystemReminder("You are a helpful assistant.")+"hi" { + t.Fatalf("messages[0].content should include forwarded system prompt, got %q", got) } } -// Test case 2: Strict mode drops the string system prompt +// Test case 2: Strict mode keeps only the injected Claude Code system blocks func TestCheckSystemInstructionsWithMode_StringSystemStrict(t *testing.T) { payload := []byte(`{"system":"You are a helpful 
assistant.","messages":[{"role":"user","content":"hi"}]}`) out := checkSystemInstructionsWithMode(payload, true) blocks := gjson.GetBytes(out, "system").Array() - if len(blocks) != 2 { - t.Fatalf("strict mode should produce 2 blocks, got %d", len(blocks)) + if len(blocks) != 3 { + t.Fatalf("strict mode should produce 3 injected blocks, got %d", len(blocks)) + } + if got := gjson.GetBytes(out, "messages.0.content").String(); got != "hi" { + t.Fatalf("strict mode should not forward system prompt into messages, got %q", got) } } -// Test case 3: Empty string system prompt does not produce a spurious block +// Test case 3: Empty string system prompt does not alter the first user message func TestCheckSystemInstructionsWithMode_EmptyStringSystemIgnored(t *testing.T) { payload := []byte(`{"system":"","messages":[{"role":"user","content":"hi"}]}`) out := checkSystemInstructionsWithMode(payload, false) blocks := gjson.GetBytes(out, "system").Array() - if len(blocks) != 2 { - t.Fatalf("empty string system should produce 2 blocks, got %d", len(blocks)) + if len(blocks) != 3 { + t.Fatalf("empty string system should still produce 3 injected blocks, got %d", len(blocks)) + } + if got := gjson.GetBytes(out, "messages.0.content").String(); got != "hi" { + t.Fatalf("empty string system should not alter messages, got %q", got) } } -// Test case 4: Array system prompt is unaffected by the string handling +// Test case 4: Array system prompt is forwarded to the first user message func TestCheckSystemInstructionsWithMode_ArraySystemStillWorks(t *testing.T) { payload := []byte(`{"system":[{"type":"text","text":"Be concise."}],"messages":[{"role":"user","content":"hi"}]}`) @@ -1778,12 +1808,15 @@ func TestCheckSystemInstructionsWithMode_ArraySystemStillWorks(t *testing.T) { if len(blocks) != 3 { t.Fatalf("expected 3 system blocks, got %d", len(blocks)) } - if blocks[2].Get("text").String() != "Be concise." 
{ - t.Fatalf("blocks[2] should be user system prompt, got %q", blocks[2].Get("text").String()) + if blocks[2].Get("text").String() != expectedClaudeCodeStaticPrompt() { + t.Fatalf("blocks[2] should be static Claude Code prompt, got %q", blocks[2].Get("text").String()) + } + if got := gjson.GetBytes(out, "messages.0.content").String(); got != expectedForwardedSystemReminder("Be concise.")+"hi" { + t.Fatalf("messages[0].content should include forwarded array system prompt, got %q", got) } } -// Test case 5: Special characters in string system prompt survive conversion +// Test case 5: Special characters in string system prompt survive forwarding func TestCheckSystemInstructionsWithMode_StringWithSpecialChars(t *testing.T) { payload := []byte(`{"system":"Use tags & \"quotes\" in output.","messages":[{"role":"user","content":"hi"}]}`) @@ -1793,8 +1826,8 @@ func TestCheckSystemInstructionsWithMode_StringWithSpecialChars(t *testing.T) { if len(blocks) != 3 { t.Fatalf("expected 3 system blocks, got %d", len(blocks)) } - if blocks[2].Get("text").String() != `Use tags & "quotes" in output.` { - t.Fatalf("blocks[2] text mangled, got %q", blocks[2].Get("text").String()) + if got := gjson.GetBytes(out, "messages.0.content").String(); got != expectedForwardedSystemReminder(`Use tags & "quotes" in output.`)+"hi" { + t.Fatalf("forwarded system prompt text mangled, got %q", got) } } @@ -1902,8 +1935,11 @@ func TestApplyCloaking_PreservesConfiguredStrictModeAndSensitiveWordsWhenModeOmi out := applyCloaking(context.Background(), cfg, auth, payload, "claude-3-5-sonnet-20241022", "key-123") blocks := gjson.GetBytes(out, "system").Array() - if len(blocks) != 2 { - t.Fatalf("expected strict mode to keep only injected system blocks, got %d", len(blocks)) + if len(blocks) != 3 { + t.Fatalf("expected strict mode to keep the 3 injected Claude Code system blocks, got %d", len(blocks)) + } + if got := gjson.GetBytes(out, "messages.0.content.#").Int(); got != 1 { + t.Fatalf("strict mode should 
not prepend a forwarded system reminder block, got %d content blocks", got) } if got := gjson.GetBytes(out, "messages.0.content.0.text").String(); !strings.Contains(got, "\u200B") { t.Fatalf("expected configured sensitive word obfuscation to apply, got %q", got) diff --git a/internal/runtime/executor/codex_executor.go b/internal/runtime/executor/codex_executor.go index 41b1c32527..38667231aa 100644 --- a/internal/runtime/executor/codex_executor.go +++ b/internal/runtime/executor/codex_executor.go @@ -36,6 +36,69 @@ const ( var dataTag = []byte("data:") +// Streamed Codex responses may emit response.output_item.done events while leaving +// response.completed.response.output empty. Keep the stream path aligned with the +// already-patched non-stream path by reconstructing response.output from those items. +func collectCodexOutputItemDone(eventData []byte, outputItemsByIndex map[int64][]byte, outputItemsFallback *[][]byte) { + itemResult := gjson.GetBytes(eventData, "item") + if !itemResult.Exists() || itemResult.Type != gjson.JSON { + return + } + outputIndexResult := gjson.GetBytes(eventData, "output_index") + if outputIndexResult.Exists() { + outputItemsByIndex[outputIndexResult.Int()] = []byte(itemResult.Raw) + return + } + *outputItemsFallback = append(*outputItemsFallback, []byte(itemResult.Raw)) +} + +func patchCodexCompletedOutput(eventData []byte, outputItemsByIndex map[int64][]byte, outputItemsFallback [][]byte) []byte { + outputResult := gjson.GetBytes(eventData, "response.output") + shouldPatchOutput := (!outputResult.Exists() || !outputResult.IsArray() || len(outputResult.Array()) == 0) && (len(outputItemsByIndex) > 0 || len(outputItemsFallback) > 0) + if !shouldPatchOutput { + return eventData + } + + indexes := make([]int64, 0, len(outputItemsByIndex)) + for idx := range outputItemsByIndex { + indexes = append(indexes, idx) + } + sort.Slice(indexes, func(i, j int) bool { + return indexes[i] < indexes[j] + }) + + items := make([][]byte, 0, 
len(outputItemsByIndex)+len(outputItemsFallback)) + for _, idx := range indexes { + items = append(items, outputItemsByIndex[idx]) + } + items = append(items, outputItemsFallback...) + + outputArray := []byte("[]") + if len(items) > 0 { + var buf bytes.Buffer + totalLen := 2 + for _, item := range items { + totalLen += len(item) + } + if len(items) > 1 { + totalLen += len(items) - 1 + } + buf.Grow(totalLen) + buf.WriteByte('[') + for i, item := range items { + if i > 0 { + buf.WriteByte(',') + } + buf.Write(item) + } + buf.WriteByte(']') + outputArray = buf.Bytes() + } + + completedDataPatched, _ := sjson.SetRawBytes(eventData, "response.output", outputArray) + return completedDataPatched +} + // CodexExecutor is a stateless executor for Codex (OpenAI Responses API entrypoint). // If api_key is unavailable on auth, it falls back to legacy via ClientAdapter. type CodexExecutor struct { @@ -117,6 +180,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re body, _ = sjson.DeleteBytes(body, "safety_identifier") body, _ = sjson.DeleteBytes(body, "stream_options") body = normalizeCodexInstructions(body) + body = ensureImageGenerationTool(body, baseModel) url := strings.TrimSuffix(baseURL, "/") + "/responses" httpReq, err := e.cacheHelper(ctx, from, url, req, body) @@ -263,6 +327,7 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A body, _ = sjson.SetBytes(body, "model", baseModel) body, _ = sjson.DeleteBytes(body, "stream") body = normalizeCodexInstructions(body) + body = ensureImageGenerationTool(body, baseModel) url := strings.TrimSuffix(baseURL, "/") + "/responses/compact" httpReq, err := e.cacheHelper(ctx, from, url, req, body) @@ -357,6 +422,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au body, _ = sjson.DeleteBytes(body, "stream_options") body, _ = sjson.SetBytes(body, "model", baseModel) body = normalizeCodexInstructions(body) + body = 
ensureImageGenerationTool(body, baseModel) url := strings.TrimSuffix(baseURL, "/") + "/responses" httpReq, err := e.cacheHelper(ctx, from, url, req, body) @@ -414,20 +480,28 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au scanner := bufio.NewScanner(httpResp.Body) scanner.Buffer(nil, 52_428_800) // 50MB var param any + outputItemsByIndex := make(map[int64][]byte) + var outputItemsFallback [][]byte for scanner.Scan() { line := scanner.Bytes() helps.AppendAPIResponseChunk(ctx, e.cfg, line) + translatedLine := bytes.Clone(line) if bytes.HasPrefix(line, dataTag) { data := bytes.TrimSpace(line[5:]) - if gjson.GetBytes(data, "type").String() == "response.completed" { + switch gjson.GetBytes(data, "type").String() { + case "response.output_item.done": + collectCodexOutputItemDone(data, outputItemsByIndex, &outputItemsFallback) + case "response.completed": if detail, ok := helps.ParseCodexUsage(data); ok { reporter.Publish(ctx, detail) } + data = patchCodexCompletedOutput(data, outputItemsByIndex, outputItemsFallback) + translatedLine = append([]byte("data: "), data...) 
} } - chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, originalPayload, body, bytes.Clone(line), &param) + chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, originalPayload, body, translatedLine, &param) for i := range chunks { out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]} } @@ -750,6 +824,28 @@ func normalizeCodexInstructions(body []byte) []byte { return body } +var imageGenToolJSON = []byte(`{"type":"image_generation","output_format":"png"}`) +var imageGenToolArrayJSON = []byte(`[{"type":"image_generation","output_format":"png"}]`) + +func ensureImageGenerationTool(body []byte, baseModel string) []byte { + if strings.HasSuffix(baseModel, "spark") { + return body + } + + tools := gjson.GetBytes(body, "tools") + if !tools.Exists() || !tools.IsArray() { + body, _ = sjson.SetRawBytes(body, "tools", imageGenToolArrayJSON) + return body + } + for _, t := range tools.Array() { + if t.Get("type").String() == "image_generation" { + return body + } + } + body, _ = sjson.SetRawBytes(body, "tools.-1", imageGenToolJSON) + return body +} + func isCodexModelCapacityError(errorBody []byte) bool { if len(errorBody) == 0 { return false diff --git a/internal/runtime/executor/codex_executor_imagegen_test.go b/internal/runtime/executor/codex_executor_imagegen_test.go new file mode 100644 index 0000000000..5e67c598a4 --- /dev/null +++ b/internal/runtime/executor/codex_executor_imagegen_test.go @@ -0,0 +1,101 @@ +package executor + +import ( + "testing" + + "github.com/tidwall/gjson" +) + +func TestEnsureImageGenerationTool_NoTools(t *testing.T) { + body := []byte(`{"model":"gpt-5.4","input":"draw a cat"}`) + result := ensureImageGenerationTool(body, "gpt-5.4") + + tools := gjson.GetBytes(result, "tools") + if !tools.IsArray() { + t.Fatalf("expected tools array, got %v", tools.Type) + } + arr := tools.Array() + if len(arr) != 1 { + t.Fatalf("expected 1 tool, got %d", len(arr)) + } + if arr[0].Get("type").String() != "image_generation" { + 
t.Fatalf("expected type=image_generation, got %s", arr[0].Get("type").String()) + } + if arr[0].Get("output_format").String() != "png" { + t.Fatalf("expected output_format=png, got %s", arr[0].Get("output_format").String()) + } +} + +func TestEnsureImageGenerationTool_ExistingToolsWithoutImageGen(t *testing.T) { + body := []byte(`{"model":"gpt-5.4","tools":[{"type":"function","name":"get_weather","parameters":{}}]}`) + result := ensureImageGenerationTool(body, "gpt-5.4") + + tools := gjson.GetBytes(result, "tools") + arr := tools.Array() + if len(arr) != 2 { + t.Fatalf("expected 2 tools, got %d", len(arr)) + } + if arr[0].Get("type").String() != "function" { + t.Fatalf("expected first tool type=function, got %s", arr[0].Get("type").String()) + } + if arr[1].Get("type").String() != "image_generation" { + t.Fatalf("expected second tool type=image_generation, got %s", arr[1].Get("type").String()) + } +} + +func TestEnsureImageGenerationTool_AlreadyPresent(t *testing.T) { + body := []byte(`{"model":"gpt-5.4","tools":[{"type":"image_generation","output_format":"webp"},{"type":"function","name":"f1"}]}`) + result := ensureImageGenerationTool(body, "gpt-5.4") + + tools := gjson.GetBytes(result, "tools") + arr := tools.Array() + if len(arr) != 2 { + t.Fatalf("expected 2 tools (no duplicate), got %d", len(arr)) + } + if arr[0].Get("output_format").String() != "webp" { + t.Fatalf("expected original output_format=webp preserved, got %s", arr[0].Get("output_format").String()) + } +} + +func TestEnsureImageGenerationTool_EmptyToolsArray(t *testing.T) { + body := []byte(`{"model":"gpt-5.4","tools":[]}`) + result := ensureImageGenerationTool(body, "gpt-5.4") + + tools := gjson.GetBytes(result, "tools") + arr := tools.Array() + if len(arr) != 1 { + t.Fatalf("expected 1 tool, got %d", len(arr)) + } + if arr[0].Get("type").String() != "image_generation" { + t.Fatalf("expected type=image_generation, got %s", arr[0].Get("type").String()) + } +} + +func 
TestEnsureImageGenerationTool_WebSearchAndImageGen(t *testing.T) { + body := []byte(`{"model":"gpt-5.4","tools":[{"type":"web_search"}]}`) + result := ensureImageGenerationTool(body, "gpt-5.4") + + tools := gjson.GetBytes(result, "tools") + arr := tools.Array() + if len(arr) != 2 { + t.Fatalf("expected 2 tools, got %d", len(arr)) + } + if arr[0].Get("type").String() != "web_search" { + t.Fatalf("expected first tool type=web_search, got %s", arr[0].Get("type").String()) + } + if arr[1].Get("type").String() != "image_generation" { + t.Fatalf("expected second tool type=image_generation, got %s", arr[1].Get("type").String()) + } +} + +func TestEnsureImageGenerationTool_GPT53CodexSparkDoesNotInjectTool(t *testing.T) { + body := []byte(`{"model":"gpt-5.3-codex-spark","input":"draw a cat"}`) + result := ensureImageGenerationTool(body, "gpt-5.3-codex-spark") + + if string(result) != string(body) { + t.Fatalf("expected body to be unchanged, got %s", string(result)) + } + if gjson.GetBytes(result, "tools").Exists() { + t.Fatalf("expected no tools for gpt-5.3-codex-spark, got %s", gjson.GetBytes(result, "tools").Raw) + } +} diff --git a/internal/runtime/executor/codex_executor_stream_output_test.go b/internal/runtime/executor/codex_executor_stream_output_test.go index 91d9b0761c..a2da45e199 100644 --- a/internal/runtime/executor/codex_executor_stream_output_test.go +++ b/internal/runtime/executor/codex_executor_stream_output_test.go @@ -1,6 +1,7 @@ package executor import ( + "bytes" "context" "net/http" "net/http/httptest" @@ -44,3 +45,53 @@ func TestCodexExecutorExecute_EmptyStreamCompletionOutputUsesOutputItemDone(t *t t.Fatalf("choices.0.message.content = %q, want %q; payload=%s", gotContent, "ok", string(resp.Payload)) } } + +func TestCodexExecutorExecuteStream_EmptyStreamCompletionOutputUsesOutputItemDone(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + 
_, _ = w.Write([]byte("data: {\"type\":\"response.output_item.done\",\"item\":{\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"output_text\",\"text\":\"ok\"}]},\"output_index\":0}\n")) + _, _ = w.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"object\":\"response\",\"created_at\":1775555723,\"status\":\"completed\",\"model\":\"gpt-5.4-mini-2026-03-17\",\"output\":[],\"usage\":{\"input_tokens\":8,\"output_tokens\":28,\"total_tokens\":36}}}\n\n")) + })) + defer server.Close() + + executor := NewCodexExecutor(&config.Config{}) + auth := &cliproxyauth.Auth{Attributes: map[string]string{ + "base_url": server.URL, + "api_key": "test", + }} + + result, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{ + Model: "gpt-5.4-mini", + Payload: []byte(`{"model":"gpt-5.4-mini","input":"Say ok"}`), + }, cliproxyexecutor.Options{ + SourceFormat: sdktranslator.FromString("openai-response"), + Stream: true, + }) + if err != nil { + t.Fatalf("ExecuteStream error: %v", err) + } + + var completed []byte + for chunk := range result.Chunks { + if chunk.Err != nil { + t.Fatalf("stream chunk error: %v", chunk.Err) + } + payload := bytes.TrimSpace(chunk.Payload) + if !bytes.HasPrefix(payload, []byte("data:")) { + continue + } + data := bytes.TrimSpace(payload[5:]) + if gjson.GetBytes(data, "type").String() == "response.completed" { + completed = append([]byte(nil), data...) 
+ } + } + + if len(completed) == 0 { + t.Fatal("missing response.completed chunk") + } + + gotContent := gjson.GetBytes(completed, "response.output.0.content.0.text").String() + if gotContent != "ok" { + t.Fatalf("response.output[0].content[0].text = %q, want %q; completed=%s", gotContent, "ok", string(completed)) + } +} diff --git a/internal/runtime/executor/gemini_cli_executor.go b/internal/runtime/executor/gemini_cli_executor.go index d2df610966..a18f824a62 100644 --- a/internal/runtime/executor/gemini_cli_executor.go +++ b/internal/runtime/executor/gemini_cli_executor.go @@ -898,7 +898,14 @@ func parseRetryDelay(errorBody []byte) (*time.Duration, error) { if matches := re.FindStringSubmatch(message); len(matches) > 1 { seconds, err := strconv.Atoi(matches[1]) if err == nil { - return new(time.Duration(seconds) * time.Second), nil + duration := time.Duration(seconds) * time.Second + return &duration, nil + } + } + reHuman := regexp.MustCompile(`after\s+((?:\d+h)?(?:\d+m)?(?:\d+s)?)\.?`) + if matches := reHuman.FindStringSubmatch(strings.ToLower(message)); len(matches) > 1 { + if duration, err := time.ParseDuration(matches[1]); err == nil && duration > 0 { + return &duration, nil } } } diff --git a/internal/runtime/executor/helps/logging_helpers.go b/internal/runtime/executor/helps/logging_helpers.go index 767c882016..a0b30f7099 100644 --- a/internal/runtime/executor/helps/logging_helpers.go +++ b/internal/runtime/executor/helps/logging_helpers.go @@ -24,6 +24,7 @@ const ( apiRequestKey = "API_REQUEST" apiResponseKey = "API_RESPONSE" apiWebsocketTimelineKey = "API_WEBSOCKET_TIMELINE" + creditsUsedKey = "__antigravity_credits_used__" ) // UpstreamRequestLog captures the outbound upstream request details for logging. @@ -568,3 +569,24 @@ func LogWithRequestID(ctx context.Context) *log.Entry { } return log.WithField("request_id", requestID) } + +// MarkCreditsUsed flags the request as having used AI credits for billing. 
+func MarkCreditsUsed(ctx context.Context) { + ginCtx := ginContextFrom(ctx) + if ginCtx != nil { + ginCtx.Set(creditsUsedKey, true) + } +} + +// CreditsUsed returns true if the request used AI credits. +func CreditsUsed(ctx context.Context) bool { + ginCtx := ginContextFrom(ctx) + if ginCtx != nil { + if val, exists := ginCtx.Get(creditsUsedKey); exists { + if b, ok := val.(bool); ok { + return b + } + } + } + return false +} diff --git a/internal/translator/codex/gemini/codex_gemini_response.go b/internal/translator/codex/gemini/codex_gemini_response.go index f6ef87710a..a2e4e20ea2 100644 --- a/internal/translator/codex/gemini/codex_gemini_response.go +++ b/internal/translator/codex/gemini/codex_gemini_response.go @@ -7,6 +7,8 @@ package gemini import ( "bytes" "context" + "crypto/sha256" + "strings" "time" translatorcommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/common" @@ -25,6 +27,7 @@ type ConvertCodexResponseToGeminiParams struct { ResponseID string LastStorageOutput []byte HasOutputTextDelta bool + LastImageHashByID map[string][32]byte } // ConvertCodexResponseToGemini converts Codex streaming response format to Gemini format. 
@@ -48,6 +51,7 @@ func ConvertCodexResponseToGemini(_ context.Context, modelName string, originalR ResponseID: "", LastStorageOutput: nil, HasOutputTextDelta: false, + LastImageHashByID: make(map[string][32]byte), } } @@ -74,10 +78,63 @@ func ConvertCodexResponseToGemini(_ context.Context, modelName string, originalR template, _ = sjson.SetBytes(template, "responseId", params.ResponseID) } + if typeStr == "response.image_generation_call.partial_image" { + itemID := rootResult.Get("item_id").String() + b64 := rootResult.Get("partial_image_b64").String() + if b64 == "" { + return [][]byte{} + } + if itemID != "" { + if params.LastImageHashByID == nil { + params.LastImageHashByID = make(map[string][32]byte) + } + hash := sha256.Sum256([]byte(b64)) + if last, ok := params.LastImageHashByID[itemID]; ok && last == hash { + return [][]byte{} + } + params.LastImageHashByID[itemID] = hash + } + + outputFormat := rootResult.Get("output_format").String() + mimeType := mimeTypeFromCodexOutputFormat(outputFormat) + + part := []byte(`{"inlineData":{"data":"","mimeType":""}}`) + part, _ = sjson.SetBytes(part, "inlineData.data", b64) + part, _ = sjson.SetBytes(part, "inlineData.mimeType", mimeType) + template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", part) + return [][]byte{template} + } + // Handle function call completion if typeStr == "response.output_item.done" { itemResult := rootResult.Get("item") itemType := itemResult.Get("type").String() + if itemType == "image_generation_call" { + itemID := itemResult.Get("id").String() + b64 := itemResult.Get("result").String() + if b64 == "" { + return [][]byte{} + } + if itemID != "" { + if params.LastImageHashByID == nil { + params.LastImageHashByID = make(map[string][32]byte) + } + hash := sha256.Sum256([]byte(b64)) + if last, ok := params.LastImageHashByID[itemID]; ok && last == hash { + return [][]byte{} + } + params.LastImageHashByID[itemID] = hash + } + + outputFormat := 
itemResult.Get("output_format").String() + mimeType := mimeTypeFromCodexOutputFormat(outputFormat) + + part := []byte(`{"inlineData":{"data":"","mimeType":""}}`) + part, _ = sjson.SetBytes(part, "inlineData.data", b64) + part, _ = sjson.SetBytes(part, "inlineData.mimeType", mimeType) + template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", part) + return [][]byte{template} + } if itemType == "function_call" { // Create function call part functionCall := []byte(`{"functionCall":{"name":"","args":{}}}`) @@ -270,6 +327,20 @@ func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string, }) } + case "image_generation_call": + flushPendingFunctionCalls() + b64 := value.Get("result").String() + if b64 == "" { + break + } + outputFormat := value.Get("output_format").String() + mimeType := mimeTypeFromCodexOutputFormat(outputFormat) + + part := []byte(`{"inlineData":{"data":"","mimeType":""}}`) + part, _ = sjson.SetBytes(part, "inlineData.data", b64) + part, _ = sjson.SetBytes(part, "inlineData.mimeType", mimeType) + template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", part) + case "function_call": // Collect function call for potential merging with consecutive ones hasToolCall = true @@ -342,3 +413,24 @@ func buildReverseMapFromGeminiOriginal(original []byte) map[string]string { func GeminiTokenCount(ctx context.Context, count int64) []byte { return translatorcommon.GeminiTokenCountJSON(count) } + +func mimeTypeFromCodexOutputFormat(outputFormat string) string { + if outputFormat == "" { + return "image/png" + } + if strings.Contains(outputFormat, "/") { + return outputFormat + } + switch strings.ToLower(outputFormat) { + case "png": + return "image/png" + case "jpg", "jpeg": + return "image/jpeg" + case "webp": + return "image/webp" + case "gif": + return "image/gif" + default: + return "image/png" + } +} diff --git a/internal/translator/codex/gemini/codex_gemini_response_test.go 
b/internal/translator/codex/gemini/codex_gemini_response_test.go index b8f227beb5..547ee84715 100644 --- a/internal/translator/codex/gemini/codex_gemini_response_test.go +++ b/internal/translator/codex/gemini/codex_gemini_response_test.go @@ -33,3 +33,79 @@ func TestConvertCodexResponseToGemini_StreamEmptyOutputUsesOutputItemDoneMessage t.Fatalf("expected fallback content from response.output_item.done message; outputs=%q", outputs) } } + +func TestConvertCodexResponseToGemini_StreamPartialImageEmitsInlineData(t *testing.T) { + ctx := context.Background() + originalRequest := []byte(`{"tools":[]}`) + var param any + + chunk := []byte(`data: {"type":"response.image_generation_call.partial_image","item_id":"ig_123","output_format":"png","partial_image_b64":"aGVsbG8=","partial_image_index":0}`) + out := ConvertCodexResponseToGemini(ctx, "gemini-2.5-pro", originalRequest, nil, chunk, &param) + if len(out) != 1 { + t.Fatalf("expected 1 chunk, got %d", len(out)) + } + + got := gjson.GetBytes(out[0], "candidates.0.content.parts.0.inlineData.data").String() + if got != "aGVsbG8=" { + t.Fatalf("expected inlineData.data %q, got %q; chunk=%s", "aGVsbG8=", got, string(out[0])) + } + + gotMime := gjson.GetBytes(out[0], "candidates.0.content.parts.0.inlineData.mimeType").String() + if gotMime != "image/png" { + t.Fatalf("expected inlineData.mimeType %q, got %q; chunk=%s", "image/png", gotMime, string(out[0])) + } + + out = ConvertCodexResponseToGemini(ctx, "gemini-2.5-pro", originalRequest, nil, chunk, &param) + if len(out) != 0 { + t.Fatalf("expected duplicate image chunk to be suppressed, got %d", len(out)) + } +} + +func TestConvertCodexResponseToGemini_StreamImageGenerationCallDoneEmitsInlineData(t *testing.T) { + ctx := context.Background() + originalRequest := []byte(`{"tools":[]}`) + var param any + + out := ConvertCodexResponseToGemini(ctx, "gemini-2.5-pro", originalRequest, nil, []byte(`data: 
{"type":"response.image_generation_call.partial_image","item_id":"ig_123","output_format":"png","partial_image_b64":"aGVsbG8=","partial_image_index":0}`), &param) + if len(out) != 1 { + t.Fatalf("expected 1 chunk, got %d", len(out)) + } + + out = ConvertCodexResponseToGemini(ctx, "gemini-2.5-pro", originalRequest, nil, []byte(`data: {"type":"response.output_item.done","item":{"id":"ig_123","type":"image_generation_call","output_format":"png","result":"aGVsbG8="}}`), &param) + if len(out) != 0 { + t.Fatalf("expected output_item.done to be suppressed when identical to last partial image, got %d", len(out)) + } + + out = ConvertCodexResponseToGemini(ctx, "gemini-2.5-pro", originalRequest, nil, []byte(`data: {"type":"response.output_item.done","item":{"id":"ig_123","type":"image_generation_call","output_format":"jpeg","result":"Ymll"}}`), &param) + if len(out) != 1 { + t.Fatalf("expected 1 chunk, got %d", len(out)) + } + + got := gjson.GetBytes(out[0], "candidates.0.content.parts.0.inlineData.data").String() + if got != "Ymll" { + t.Fatalf("expected inlineData.data %q, got %q; chunk=%s", "Ymll", got, string(out[0])) + } + + gotMime := gjson.GetBytes(out[0], "candidates.0.content.parts.0.inlineData.mimeType").String() + if gotMime != "image/jpeg" { + t.Fatalf("expected inlineData.mimeType %q, got %q; chunk=%s", "image/jpeg", gotMime, string(out[0])) + } +} + +func TestConvertCodexResponseToGemini_NonStreamImageGenerationCallAddsInlineDataPart(t *testing.T) { + ctx := context.Background() + originalRequest := []byte(`{"tools":[]}`) + + raw := []byte(`{"type":"response.completed","response":{"id":"resp_123","created_at":1700000000,"usage":{"input_tokens":1,"output_tokens":1},"output":[{"type":"message","content":[{"type":"output_text","text":"ok"}]},{"type":"image_generation_call","output_format":"png","result":"aGVsbG8="}]}}`) + out := ConvertCodexResponseToGeminiNonStream(ctx, "gemini-2.5-pro", originalRequest, nil, raw, nil) + + got := gjson.GetBytes(out, 
"candidates.0.content.parts.1.inlineData.data").String() + if got != "aGVsbG8=" { + t.Fatalf("expected inlineData.data %q, got %q; chunk=%s", "aGVsbG8=", got, string(out)) + } + + gotMime := gjson.GetBytes(out, "candidates.0.content.parts.1.inlineData.mimeType").String() + if gotMime != "image/png" { + t.Fatalf("expected inlineData.mimeType %q, got %q; chunk=%s", "image/png", gotMime, string(out)) + } +} diff --git a/internal/translator/codex/openai/chat-completions/codex_openai_response.go b/internal/translator/codex/openai/chat-completions/codex_openai_response.go index afae35d48d..75b5b848b3 100644 --- a/internal/translator/codex/openai/chat-completions/codex_openai_response.go +++ b/internal/translator/codex/openai/chat-completions/codex_openai_response.go @@ -8,6 +8,8 @@ package chat_completions import ( "bytes" "context" + "crypto/sha256" + "strings" "time" "github.com/tidwall/gjson" @@ -26,6 +28,7 @@ type ConvertCliToOpenAIParams struct { FunctionCallIndex int HasReceivedArgumentsDelta bool HasToolCallAnnounced bool + LastImageHashByItemID map[string][32]byte } // ConvertCodexResponseToOpenAI translates a single chunk of a streaming response from the @@ -51,6 +54,7 @@ func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalR FunctionCallIndex: -1, HasReceivedArgumentsDelta: false, HasToolCallAnnounced: false, + LastImageHashByItemID: make(map[string][32]byte), } } @@ -70,6 +74,9 @@ func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalR (*param).(*ConvertCliToOpenAIParams).ResponseID = rootResult.Get("response.id").String() (*param).(*ConvertCliToOpenAIParams).CreatedAt = rootResult.Get("response.created_at").Int() (*param).(*ConvertCliToOpenAIParams).Model = rootResult.Get("response.model").String() + if (*param).(*ConvertCliToOpenAIParams).LastImageHashByItemID == nil { + (*param).(*ConvertCliToOpenAIParams).LastImageHashByItemID = make(map[string][32]byte) + } return [][]byte{} } @@ -120,6 +127,39 @@ func 
ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalR template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant") template, _ = sjson.SetBytes(template, "choices.0.delta.content", deltaResult.String()) } + } else if dataType == "response.image_generation_call.partial_image" { + itemID := rootResult.Get("item_id").String() + b64 := rootResult.Get("partial_image_b64").String() + if b64 == "" { + return [][]byte{} + } + if itemID != "" { + p := (*param).(*ConvertCliToOpenAIParams) + if p.LastImageHashByItemID == nil { + p.LastImageHashByItemID = make(map[string][32]byte) + } + hash := sha256.Sum256([]byte(b64)) + if last, ok := p.LastImageHashByItemID[itemID]; ok && last == hash { + return [][]byte{} + } + p.LastImageHashByItemID[itemID] = hash + } + + outputFormat := rootResult.Get("output_format").String() + mimeType := mimeTypeFromCodexOutputFormat(outputFormat) + imageURL := "data:" + mimeType + ";base64," + b64 + + imagesResult := gjson.GetBytes(template, "choices.0.delta.images") + if !imagesResult.Exists() || !imagesResult.IsArray() { + template, _ = sjson.SetRawBytes(template, "choices.0.delta.images", []byte(`[]`)) + } + imageIndex := len(gjson.GetBytes(template, "choices.0.delta.images").Array()) + imagePayload := []byte(`{"type":"image_url","image_url":{"url":""}}`) + imagePayload, _ = sjson.SetBytes(imagePayload, "index", imageIndex) + imagePayload, _ = sjson.SetBytes(imagePayload, "image_url.url", imageURL) + + template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant") + template, _ = sjson.SetRawBytes(template, "choices.0.delta.images.-1", imagePayload) } else if dataType == "response.completed" { finishReason := "stop" if (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex != -1 { @@ -183,7 +223,46 @@ func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalR } else if dataType == "response.output_item.done" { itemResult := rootResult.Get("item") - if !itemResult.Exists() || 
itemResult.Get("type").String() != "function_call" { + if !itemResult.Exists() { + return [][]byte{} + } + itemType := itemResult.Get("type").String() + if itemType == "image_generation_call" { + itemID := itemResult.Get("id").String() + b64 := itemResult.Get("result").String() + if b64 == "" { + return [][]byte{} + } + if itemID != "" { + p := (*param).(*ConvertCliToOpenAIParams) + if p.LastImageHashByItemID == nil { + p.LastImageHashByItemID = make(map[string][32]byte) + } + hash := sha256.Sum256([]byte(b64)) + if last, ok := p.LastImageHashByItemID[itemID]; ok && last == hash { + return [][]byte{} + } + p.LastImageHashByItemID[itemID] = hash + } + + outputFormat := itemResult.Get("output_format").String() + mimeType := mimeTypeFromCodexOutputFormat(outputFormat) + imageURL := "data:" + mimeType + ";base64," + b64 + + imagesResult := gjson.GetBytes(template, "choices.0.delta.images") + if !imagesResult.Exists() || !imagesResult.IsArray() { + template, _ = sjson.SetRawBytes(template, "choices.0.delta.images", []byte(`[]`)) + } + imageIndex := len(gjson.GetBytes(template, "choices.0.delta.images").Array()) + imagePayload := []byte(`{"type":"image_url","image_url":{"url":""}}`) + imagePayload, _ = sjson.SetBytes(imagePayload, "index", imageIndex) + imagePayload, _ = sjson.SetBytes(imagePayload, "image_url.url", imageURL) + + template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant") + template, _ = sjson.SetRawBytes(template, "choices.0.delta.images.-1", imagePayload) + return [][]byte{template} + } + if itemType != "function_call" { return [][]byte{} } @@ -285,6 +364,7 @@ func ConvertCodexResponseToOpenAINonStream(_ context.Context, _ string, original // Process the output array for content and function calls var toolCalls [][]byte + var images [][]byte outputResult := responseResult.Get("output") if outputResult.IsArray() { outputArray := outputResult.Array() @@ -339,6 +419,19 @@ func ConvertCodexResponseToOpenAINonStream(_ context.Context, _ 
string, original } toolCalls = append(toolCalls, functionCallTemplate) + case "image_generation_call": + b64 := outputItem.Get("result").String() + if b64 == "" { + break + } + outputFormat := outputItem.Get("output_format").String() + mimeType := mimeTypeFromCodexOutputFormat(outputFormat) + imageURL := "data:" + mimeType + ";base64," + b64 + + imagePayload := []byte(`{"type":"image_url","image_url":{"url":""}}`) + imagePayload, _ = sjson.SetBytes(imagePayload, "index", len(images)) + imagePayload, _ = sjson.SetBytes(imagePayload, "image_url.url", imageURL) + images = append(images, imagePayload) } } @@ -361,6 +454,15 @@ func ConvertCodexResponseToOpenAINonStream(_ context.Context, _ string, original } template, _ = sjson.SetBytes(template, "choices.0.message.role", "assistant") } + + // Add images if any + if len(images) > 0 { + template, _ = sjson.SetRawBytes(template, "choices.0.message.images", []byte(`[]`)) + for _, image := range images { + template, _ = sjson.SetRawBytes(template, "choices.0.message.images.-1", image) + } + template, _ = sjson.SetBytes(template, "choices.0.message.role", "assistant") + } } // Extract and set the finish reason based on status @@ -409,3 +511,24 @@ func buildReverseMapFromOriginalOpenAI(original []byte) map[string]string { } return rev } + +func mimeTypeFromCodexOutputFormat(outputFormat string) string { + if outputFormat == "" { + return "image/png" + } + if strings.Contains(outputFormat, "/") { + return outputFormat + } + switch strings.ToLower(outputFormat) { + case "png": + return "image/png" + case "jpg", "jpeg": + return "image/jpeg" + case "webp": + return "image/webp" + case "gif": + return "image/gif" + default: + return "image/png" + } +} diff --git a/internal/translator/codex/openai/chat-completions/codex_openai_response_test.go b/internal/translator/codex/openai/chat-completions/codex_openai_response_test.go index 534884c229..a6bb486fdf 100644 --- 
a/internal/translator/codex/openai/chat-completions/codex_openai_response_test.go +++ b/internal/translator/codex/openai/chat-completions/codex_openai_response_test.go @@ -90,3 +90,62 @@ func TestConvertCodexResponseToOpenAI_ToolCallArgumentsDeltaOmitsNullContentFiel t.Fatalf("expected tool call arguments delta to exist, got %s", string(out[0])) } } + +func TestConvertCodexResponseToOpenAI_StreamPartialImageEmitsDeltaImages(t *testing.T) { + ctx := context.Background() + var param any + + chunk := []byte(`data: {"type":"response.image_generation_call.partial_image","item_id":"ig_123","output_format":"png","partial_image_b64":"aGVsbG8=","partial_image_index":0}`) + + out := ConvertCodexResponseToOpenAI(ctx, "gpt-5.4", nil, nil, chunk, ¶m) + if len(out) != 1 { + t.Fatalf("expected 1 chunk, got %d", len(out)) + } + + gotURL := gjson.GetBytes(out[0], "choices.0.delta.images.0.image_url.url").String() + if gotURL != "data:image/png;base64,aGVsbG8=" { + t.Fatalf("expected image url %q, got %q; chunk=%s", "data:image/png;base64,aGVsbG8=", gotURL, string(out[0])) + } + + out = ConvertCodexResponseToOpenAI(ctx, "gpt-5.4", nil, nil, chunk, ¶m) + if len(out) != 0 { + t.Fatalf("expected duplicate image chunk to be suppressed, got %d", len(out)) + } +} + +func TestConvertCodexResponseToOpenAI_StreamImageGenerationCallDoneEmitsDeltaImages(t *testing.T) { + ctx := context.Background() + var param any + + out := ConvertCodexResponseToOpenAI(ctx, "gpt-5.4", nil, nil, []byte(`data: {"type":"response.image_generation_call.partial_image","item_id":"ig_123","output_format":"png","partial_image_b64":"aGVsbG8=","partial_image_index":0}`), ¶m) + if len(out) != 1 { + t.Fatalf("expected 1 chunk, got %d", len(out)) + } + + out = ConvertCodexResponseToOpenAI(ctx, "gpt-5.4", nil, nil, []byte(`data: {"type":"response.output_item.done","item":{"id":"ig_123","type":"image_generation_call","output_format":"png","result":"aGVsbG8="}}`), ¶m) + if len(out) != 0 { + t.Fatalf("expected output_item.done 
to be suppressed when identical to last partial image, got %d", len(out)) + } + + out = ConvertCodexResponseToOpenAI(ctx, "gpt-5.4", nil, nil, []byte(`data: {"type":"response.output_item.done","item":{"id":"ig_123","type":"image_generation_call","output_format":"jpeg","result":"Ymll"}}`), ¶m) + if len(out) != 1 { + t.Fatalf("expected 1 chunk, got %d", len(out)) + } + + gotURL := gjson.GetBytes(out[0], "choices.0.delta.images.0.image_url.url").String() + if gotURL != "data:image/jpeg;base64,Ymll" { + t.Fatalf("expected image url %q, got %q; chunk=%s", "data:image/jpeg;base64,Ymll", gotURL, string(out[0])) + } +} + +func TestConvertCodexResponseToOpenAI_NonStreamImageGenerationCallAddsMessageImages(t *testing.T) { + ctx := context.Background() + + raw := []byte(`{"type":"response.completed","response":{"id":"resp_123","created_at":1700000000,"model":"gpt-5.4","status":"completed","usage":{"input_tokens":1,"output_tokens":1,"total_tokens":2},"output":[{"type":"message","content":[{"type":"output_text","text":"ok"}]},{"type":"image_generation_call","output_format":"png","result":"aGVsbG8="}]}}`) + out := ConvertCodexResponseToOpenAINonStream(ctx, "gpt-5.4", nil, nil, raw, nil) + + gotURL := gjson.GetBytes(out, "choices.0.message.images.0.image_url.url").String() + if gotURL != "data:image/png;base64,aGVsbG8=" { + t.Fatalf("expected image url %q, got %q; chunk=%s", "data:image/png;base64,aGVsbG8=", gotURL, string(out)) + } +} diff --git a/internal/util/header_helpers.go b/internal/util/header_helpers.go index c53c291f10..0b8d72bcb4 100644 --- a/internal/util/header_helpers.go +++ b/internal/util/header_helpers.go @@ -47,6 +47,14 @@ func applyCustomHeaders(r *http.Request, headers map[string]string) { if k == "" || v == "" { continue } + // net/http reads Host from req.Host (not req.Header) when writing + // a real request, so we must mirror it there. Some callers pass + // synthetic requests (e.g. 
&http.Request{Header: ...}) and only + // consume r.Header afterwards, so keep the value in the header + // map too. + if http.CanonicalHeaderKey(k) == "Host" { + r.Host = v + } r.Header.Set(k, v) } } diff --git a/sdk/api/handlers/handlers.go b/sdk/api/handlers/handlers.go index 49e73d4637..1fda8f49f0 100644 --- a/sdk/api/handlers/handlers.go +++ b/sdk/api/handlers/handlers.go @@ -795,6 +795,13 @@ func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string parsed := thinking.ParseSuffix(resolvedModelName) baseModel := strings.TrimSpace(parsed.ModelName) + if strings.EqualFold(baseModel, "gpt-image-2") { + return nil, "", &interfaces.ErrorMessage{ + StatusCode: http.StatusServiceUnavailable, + Error: fmt.Errorf("model %s is only supported on /v1/images/generations and /v1/images/edits", baseModel), + } + } + providers = util.GetProviderName(baseModel) // Fallback: if baseModel has no provider but differs from resolvedModelName, // try using the full model name. This handles edge cases where custom models diff --git a/sdk/api/handlers/handlers_request_details_test.go b/sdk/api/handlers/handlers_request_details_test.go index b0f6b13262..c98580f224 100644 --- a/sdk/api/handlers/handlers_request_details_test.go +++ b/sdk/api/handlers/handlers_request_details_test.go @@ -1,7 +1,9 @@ package handlers import ( + "net/http" "reflect" + "strings" "testing" "time" @@ -116,3 +118,22 @@ func TestGetRequestDetails_PreservesSuffix(t *testing.T) { }) } } + +func TestGetRequestDetails_ImageModelReturns503(t *testing.T) { + handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, coreauth.NewManager(nil, nil, nil)) + + _, _, errMsg := handler.getRequestDetails("gpt-image-2") + if errMsg == nil { + t.Fatalf("expected error for gpt-image-2, got nil") + } + if errMsg.StatusCode != http.StatusServiceUnavailable { + t.Fatalf("unexpected status code: got %d want %d", errMsg.StatusCode, http.StatusServiceUnavailable) + } + if errMsg.Error == nil { + t.Fatalf("expected 
error message, got nil") + } + msg := errMsg.Error.Error() + if !strings.Contains(msg, "/v1/images/generations") || !strings.Contains(msg, "/v1/images/edits") { + t.Fatalf("unexpected error message: %q", msg) + } +} diff --git a/sdk/api/handlers/openai/openai_images_handlers.go b/sdk/api/handlers/openai/openai_images_handlers.go new file mode 100644 index 0000000000..93d45460d0 --- /dev/null +++ b/sdk/api/handlers/openai/openai_images_handlers.go @@ -0,0 +1,896 @@ +package openai + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "mime/multipart" + "net/http" + "strconv" + "strings" + "time" + + "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers" + log "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +const ( + defaultImagesMainModel = "gpt-5.4-mini" + defaultImagesToolModel = "gpt-image-2" +) + +type imageCallResult struct { + Result string + RevisedPrompt string + OutputFormat string + Size string + Background string + Quality string +} + +type sseFrameAccumulator struct { + pending []byte +} + +func (a *sseFrameAccumulator) AddChunk(chunk []byte) [][]byte { + if len(chunk) == 0 { + return nil + } + + if responsesSSENeedsLineBreak(a.pending, chunk) { + a.pending = append(a.pending, '\n') + } + a.pending = append(a.pending, chunk...) 
+ + var frames [][]byte + for { + frameLen := responsesSSEFrameLen(a.pending) + if frameLen == 0 { + break + } + frames = append(frames, a.pending[:frameLen]) + copy(a.pending, a.pending[frameLen:]) + a.pending = a.pending[:len(a.pending)-frameLen] + } + + if len(bytes.TrimSpace(a.pending)) == 0 { + a.pending = a.pending[:0] + return frames + } + if len(a.pending) == 0 || !responsesSSECanEmitWithoutDelimiter(a.pending) { + return frames + } + frames = append(frames, a.pending) + a.pending = a.pending[:0] + return frames +} + +func (a *sseFrameAccumulator) Flush() [][]byte { + if len(a.pending) == 0 { + return nil + } + + var frames [][]byte + for { + frameLen := responsesSSEFrameLen(a.pending) + if frameLen == 0 { + break + } + frames = append(frames, a.pending[:frameLen]) + copy(a.pending, a.pending[frameLen:]) + a.pending = a.pending[:len(a.pending)-frameLen] + } + + if len(bytes.TrimSpace(a.pending)) == 0 { + a.pending = nil + return frames + } + if responsesSSECanEmitWithoutDelimiter(a.pending) { + frames = append(frames, a.pending) + } + a.pending = nil + return frames +} + +func mimeTypeFromOutputFormat(outputFormat string) string { + if outputFormat == "" { + return "image/png" + } + if strings.Contains(outputFormat, "/") { + return outputFormat + } + switch strings.ToLower(strings.TrimSpace(outputFormat)) { + case "png": + return "image/png" + case "jpg", "jpeg": + return "image/jpeg" + case "webp": + return "image/webp" + default: + return "image/png" + } +} + +func multipartFileToDataURL(fileHeader *multipart.FileHeader) (string, error) { + if fileHeader == nil { + return "", fmt.Errorf("upload file is nil") + } + f, err := fileHeader.Open() + if err != nil { + return "", fmt.Errorf("open upload file failed: %w", err) + } + defer func() { + if errClose := f.Close(); errClose != nil { + log.Errorf("openai images: close upload file error: %v", errClose) + } + }() + + data, err := io.ReadAll(f) + if err != nil { + return "", fmt.Errorf("read upload file 
failed: %w", err) + } + + mediaType := strings.TrimSpace(fileHeader.Header.Get("Content-Type")) + if mediaType == "" { + mediaType = http.DetectContentType(data) + } + + b64 := base64.StdEncoding.EncodeToString(data) + return "data:" + mediaType + ";base64," + b64, nil +} + +func parseIntField(raw string, fallback int64) int64 { + raw = strings.TrimSpace(raw) + if raw == "" { + return fallback + } + v, err := strconv.ParseInt(raw, 10, 64) + if err != nil { + return fallback + } + return v +} + +func parseBoolField(raw string, fallback bool) bool { + raw = strings.TrimSpace(strings.ToLower(raw)) + if raw == "" { + return fallback + } + switch raw { + case "1", "true", "yes", "on": + return true + case "0", "false", "no", "off": + return false + default: + return fallback + } +} + +func (h *OpenAIAPIHandler) ImagesGenerations(c *gin.Context) { + rawJSON, err := c.GetRawData() + if err != nil { + c.JSON(http.StatusBadRequest, handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ + Message: fmt.Sprintf("Invalid request: %v", err), + Type: "invalid_request_error", + }, + }) + return + } + if !json.Valid(rawJSON) { + c.JSON(http.StatusBadRequest, handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ + Message: "Invalid request: body must be valid JSON", + Type: "invalid_request_error", + }, + }) + return + } + + prompt := strings.TrimSpace(gjson.GetBytes(rawJSON, "prompt").String()) + if prompt == "" { + c.JSON(http.StatusBadRequest, handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ + Message: "Invalid request: prompt is required", + Type: "invalid_request_error", + }, + }) + return + } + + imageModel := strings.TrimSpace(gjson.GetBytes(rawJSON, "model").String()) + if imageModel == "" { + imageModel = defaultImagesToolModel + } + responseFormat := strings.TrimSpace(gjson.GetBytes(rawJSON, "response_format").String()) + if responseFormat == "" { + responseFormat = "b64_json" + } + stream := gjson.GetBytes(rawJSON, "stream").Bool() + + tool := 
[]byte(`{"type":"image_generation","action":"generate"}`) + tool, _ = sjson.SetBytes(tool, "model", imageModel) + + if v := strings.TrimSpace(gjson.GetBytes(rawJSON, "size").String()); v != "" { + tool, _ = sjson.SetBytes(tool, "size", v) + } + if v := strings.TrimSpace(gjson.GetBytes(rawJSON, "quality").String()); v != "" { + tool, _ = sjson.SetBytes(tool, "quality", v) + } + if v := strings.TrimSpace(gjson.GetBytes(rawJSON, "background").String()); v != "" { + tool, _ = sjson.SetBytes(tool, "background", v) + } + if v := strings.TrimSpace(gjson.GetBytes(rawJSON, "output_format").String()); v != "" { + tool, _ = sjson.SetBytes(tool, "output_format", v) + } + if v := gjson.GetBytes(rawJSON, "output_compression"); v.Exists() { + if v.Type == gjson.Number { + tool, _ = sjson.SetBytes(tool, "output_compression", v.Int()) + } + } + if v := gjson.GetBytes(rawJSON, "partial_images"); v.Exists() { + if v.Type == gjson.Number { + tool, _ = sjson.SetBytes(tool, "partial_images", v.Int()) + } + } + if v := strings.TrimSpace(gjson.GetBytes(rawJSON, "moderation").String()); v != "" { + tool, _ = sjson.SetBytes(tool, "moderation", v) + } + + responsesReq := buildImagesResponsesRequest(prompt, nil, tool) + if stream { + h.streamImagesFromResponses(c, responsesReq, responseFormat, "image_generation") + return + } + h.collectImagesFromResponses(c, responsesReq, responseFormat) +} + +func (h *OpenAIAPIHandler) ImagesEdits(c *gin.Context) { + contentType := strings.ToLower(strings.TrimSpace(c.GetHeader("Content-Type"))) + if strings.HasPrefix(contentType, "application/json") { + h.imagesEditsFromJSON(c) + return + } + if strings.HasPrefix(contentType, "multipart/form-data") || contentType == "" { + h.imagesEditsFromMultipart(c) + return + } + + c.JSON(http.StatusBadRequest, handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ + Message: fmt.Sprintf("Invalid request: unsupported Content-Type %q", contentType), + Type: "invalid_request_error", + }, + }) +} + +func (h 
*OpenAIAPIHandler) imagesEditsFromMultipart(c *gin.Context) { + form, err := c.MultipartForm() + if err != nil { + c.JSON(http.StatusBadRequest, handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ + Message: fmt.Sprintf("Invalid request: %v", err), + Type: "invalid_request_error", + }, + }) + return + } + + prompt := strings.TrimSpace(c.PostForm("prompt")) + if prompt == "" { + c.JSON(http.StatusBadRequest, handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ + Message: "Invalid request: prompt is required", + Type: "invalid_request_error", + }, + }) + return + } + + var imageFiles []*multipart.FileHeader + if files := form.File["image[]"]; len(files) > 0 { + imageFiles = files + } else if files := form.File["image"]; len(files) > 0 { + imageFiles = files + } + if len(imageFiles) == 0 { + c.JSON(http.StatusBadRequest, handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ + Message: "Invalid request: image is required", + Type: "invalid_request_error", + }, + }) + return + } + + images := make([]string, 0, len(imageFiles)) + for _, fh := range imageFiles { + dataURL, err := multipartFileToDataURL(fh) + if err != nil { + c.JSON(http.StatusBadRequest, handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ + Message: fmt.Sprintf("Invalid request: %v", err), + Type: "invalid_request_error", + }, + }) + return + } + images = append(images, dataURL) + } + + var maskDataURL *string + if maskFiles := form.File["mask"]; len(maskFiles) > 0 && maskFiles[0] != nil { + dataURL, err := multipartFileToDataURL(maskFiles[0]) + if err != nil { + c.JSON(http.StatusBadRequest, handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ + Message: fmt.Sprintf("Invalid request: %v", err), + Type: "invalid_request_error", + }, + }) + return + } + maskDataURL = &dataURL + } + + imageModel := strings.TrimSpace(c.PostForm("model")) + if imageModel == "" { + imageModel = defaultImagesToolModel + } + responseFormat := strings.TrimSpace(c.PostForm("response_format")) + if responseFormat == "" 
{ + responseFormat = "b64_json" + } + stream := parseBoolField(c.PostForm("stream"), false) + + tool := []byte(`{"type":"image_generation","action":"edit"}`) + tool, _ = sjson.SetBytes(tool, "model", imageModel) + + if v := strings.TrimSpace(c.PostForm("size")); v != "" { + tool, _ = sjson.SetBytes(tool, "size", v) + } + if v := strings.TrimSpace(c.PostForm("quality")); v != "" { + tool, _ = sjson.SetBytes(tool, "quality", v) + } + if v := strings.TrimSpace(c.PostForm("background")); v != "" { + tool, _ = sjson.SetBytes(tool, "background", v) + } + if v := strings.TrimSpace(c.PostForm("output_format")); v != "" { + tool, _ = sjson.SetBytes(tool, "output_format", v) + } + if v := strings.TrimSpace(c.PostForm("input_fidelity")); v != "" { + tool, _ = sjson.SetBytes(tool, "input_fidelity", v) + } + if v := strings.TrimSpace(c.PostForm("moderation")); v != "" { + tool, _ = sjson.SetBytes(tool, "moderation", v) + } + + if v := strings.TrimSpace(c.PostForm("output_compression")); v != "" { + tool, _ = sjson.SetBytes(tool, "output_compression", parseIntField(v, 0)) + } + if v := strings.TrimSpace(c.PostForm("partial_images")); v != "" { + tool, _ = sjson.SetBytes(tool, "partial_images", parseIntField(v, 0)) + } + + if maskDataURL != nil && strings.TrimSpace(*maskDataURL) != "" { + tool, _ = sjson.SetBytes(tool, "input_image_mask.image_url", strings.TrimSpace(*maskDataURL)) + } + + responsesReq := buildImagesResponsesRequest(prompt, images, tool) + if stream { + h.streamImagesFromResponses(c, responsesReq, responseFormat, "image_edit") + return + } + h.collectImagesFromResponses(c, responsesReq, responseFormat) +} + +func (h *OpenAIAPIHandler) imagesEditsFromJSON(c *gin.Context) { + rawJSON, err := c.GetRawData() + if err != nil { + c.JSON(http.StatusBadRequest, handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ + Message: fmt.Sprintf("Invalid request: %v", err), + Type: "invalid_request_error", + }, + }) + return + } + if !json.Valid(rawJSON) { + 
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ + Message: "Invalid request: body must be valid JSON", + Type: "invalid_request_error", + }, + }) + return + } + + prompt := strings.TrimSpace(gjson.GetBytes(rawJSON, "prompt").String()) + if prompt == "" { + c.JSON(http.StatusBadRequest, handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ + Message: "Invalid request: prompt is required", + Type: "invalid_request_error", + }, + }) + return + } + + var images []string + imagesResult := gjson.GetBytes(rawJSON, "images") + if imagesResult.IsArray() { + for _, img := range imagesResult.Array() { + url := strings.TrimSpace(img.Get("image_url").String()) + if url == "" { + continue + } + images = append(images, url) + } + } + if len(images) == 0 { + c.JSON(http.StatusBadRequest, handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ + Message: "Invalid request: images[].image_url is required (file_id is not supported)", + Type: "invalid_request_error", + }, + }) + return + } + + var maskDataURL *string + if mask := gjson.GetBytes(rawJSON, "mask.image_url"); mask.Exists() { + url := strings.TrimSpace(mask.String()) + if url != "" { + maskDataURL = &url + } + } else if mask := gjson.GetBytes(rawJSON, "mask.file_id"); mask.Exists() { + c.JSON(http.StatusBadRequest, handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ + Message: "Invalid request: mask.file_id is not supported (use mask.image_url instead)", + Type: "invalid_request_error", + }, + }) + return + } + + imageModel := strings.TrimSpace(gjson.GetBytes(rawJSON, "model").String()) + if imageModel == "" { + imageModel = defaultImagesToolModel + } + responseFormat := strings.TrimSpace(gjson.GetBytes(rawJSON, "response_format").String()) + if responseFormat == "" { + responseFormat = "b64_json" + } + stream := gjson.GetBytes(rawJSON, "stream").Bool() + + tool := []byte(`{"type":"image_generation","action":"edit"}`) + tool, _ = sjson.SetBytes(tool, "model", imageModel) + + for _, 
field := range []string{"size", "quality", "background", "output_format", "input_fidelity", "moderation"} { + if v := strings.TrimSpace(gjson.GetBytes(rawJSON, field).String()); v != "" { + tool, _ = sjson.SetBytes(tool, field, v) + } + } + + for _, field := range []string{"output_compression", "partial_images"} { + if v := gjson.GetBytes(rawJSON, field); v.Exists() && v.Type == gjson.Number { + tool, _ = sjson.SetBytes(tool, field, v.Int()) + } + } + + if maskDataURL != nil && strings.TrimSpace(*maskDataURL) != "" { + tool, _ = sjson.SetBytes(tool, "input_image_mask.image_url", strings.TrimSpace(*maskDataURL)) + } + + responsesReq := buildImagesResponsesRequest(prompt, images, tool) + if stream { + h.streamImagesFromResponses(c, responsesReq, responseFormat, "image_edit") + return + } + h.collectImagesFromResponses(c, responsesReq, responseFormat) +} + +func buildImagesResponsesRequest(prompt string, images []string, toolJSON []byte) []byte { + req := []byte(`{"instructions":"","stream":true,"reasoning":{"effort":"medium","summary":"auto"},"parallel_tool_calls":true,"include":["reasoning.encrypted_content"],"model":"","store":false,"tool_choice":{"type":"image_generation"}}`) + req, _ = sjson.SetBytes(req, "model", defaultImagesMainModel) + + input := []byte(`[{"type":"message","role":"user","content":[{"type":"input_text","text":""}]}]`) + input, _ = sjson.SetBytes(input, "0.content.0.text", prompt) + contentIndex := 1 + for _, img := range images { + if strings.TrimSpace(img) == "" { + continue + } + part := []byte(`{"type":"input_image","image_url":""}`) + part, _ = sjson.SetBytes(part, "image_url", img) + path := fmt.Sprintf("0.content.%d", contentIndex) + input, _ = sjson.SetRawBytes(input, path, part) + contentIndex++ + } + req, _ = sjson.SetRawBytes(req, "input", input) + + req, _ = sjson.SetRawBytes(req, "tools", []byte(`[]`)) + if len(toolJSON) > 0 && json.Valid(toolJSON) { + req, _ = sjson.SetRawBytes(req, "tools.-1", toolJSON) + } + return req +} + 
+func (h *OpenAIAPIHandler) collectImagesFromResponses(c *gin.Context, responsesReq []byte, responseFormat string) { + c.Header("Content-Type", "application/json") + + cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) + stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx) + + dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, "openai-response", defaultImagesMainModel, responsesReq, "") + + out, errMsg := collectImagesFromResponsesStream(cliCtx, dataChan, errChan, responseFormat) + stopKeepAlive() + if errMsg != nil { + h.WriteErrorResponse(c, errMsg) + if errMsg.Error != nil { + cliCancel(errMsg.Error) + } else { + cliCancel(nil) + } + return + } + handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) + _, _ = c.Writer.Write(out) + cliCancel() +} + +func collectImagesFromResponsesStream(ctx context.Context, data <-chan []byte, errs <-chan *interfaces.ErrorMessage, responseFormat string) ([]byte, *interfaces.ErrorMessage) { + acc := &sseFrameAccumulator{} + + processFrame := func(frame []byte) ([]byte, bool, *interfaces.ErrorMessage) { + for _, line := range bytes.Split(frame, []byte("\n")) { + trimmed := bytes.TrimSpace(bytes.TrimRight(line, "\r")) + if len(trimmed) == 0 { + continue + } + if !bytes.HasPrefix(trimmed, []byte("data:")) { + continue + } + payload := bytes.TrimSpace(trimmed[len("data:"):]) + if len(payload) == 0 || bytes.Equal(payload, []byte("[DONE]")) { + continue + } + if !json.Valid(payload) { + return nil, false, &interfaces.ErrorMessage{StatusCode: http.StatusBadGateway, Error: fmt.Errorf("invalid SSE data JSON")} + } + + if gjson.GetBytes(payload, "type").String() != "response.completed" { + continue + } + + results, createdAt, usageRaw, firstMeta, err := extractImagesFromResponsesCompleted(payload) + if err != nil { + return nil, false, &interfaces.ErrorMessage{StatusCode: http.StatusBadGateway, Error: err} + } + if len(results) == 0 { + return nil, false, 
&interfaces.ErrorMessage{StatusCode: http.StatusBadGateway, Error: fmt.Errorf("upstream did not return image output")} + } + out, err := buildImagesAPIResponse(results, createdAt, usageRaw, firstMeta, responseFormat) + if err != nil { + return nil, false, &interfaces.ErrorMessage{StatusCode: http.StatusInternalServerError, Error: err} + } + return out, true, nil + } + return nil, false, nil + } + + for { + select { + case <-ctx.Done(): + return nil, &interfaces.ErrorMessage{StatusCode: http.StatusRequestTimeout, Error: ctx.Err()} + case errMsg, ok := <-errs: + if ok && errMsg != nil { + return nil, errMsg + } + errs = nil + case chunk, ok := <-data: + if !ok { + for _, frame := range acc.Flush() { + if out, done, errMsg := processFrame(frame); errMsg != nil { + return nil, errMsg + } else if done { + return out, nil + } + } + return nil, &interfaces.ErrorMessage{StatusCode: http.StatusBadGateway, Error: fmt.Errorf("stream disconnected before completion")} + } + for _, frame := range acc.AddChunk(chunk) { + if out, done, errMsg := processFrame(frame); errMsg != nil { + return nil, errMsg + } else if done { + return out, nil + } + } + } + } +} + +func extractImagesFromResponsesCompleted(payload []byte) (results []imageCallResult, createdAt int64, usageRaw []byte, firstMeta imageCallResult, err error) { + if gjson.GetBytes(payload, "type").String() != "response.completed" { + return nil, 0, nil, imageCallResult{}, fmt.Errorf("unexpected event type") + } + + createdAt = gjson.GetBytes(payload, "response.created_at").Int() + if createdAt <= 0 { + createdAt = time.Now().Unix() + } + + output := gjson.GetBytes(payload, "response.output") + if output.IsArray() { + for _, item := range output.Array() { + if item.Get("type").String() != "image_generation_call" { + continue + } + res := strings.TrimSpace(item.Get("result").String()) + if res == "" { + continue + } + entry := imageCallResult{ + Result: res, + RevisedPrompt: 
strings.TrimSpace(item.Get("revised_prompt").String()), + OutputFormat: strings.TrimSpace(item.Get("output_format").String()), + Size: strings.TrimSpace(item.Get("size").String()), + Background: strings.TrimSpace(item.Get("background").String()), + Quality: strings.TrimSpace(item.Get("quality").String()), + } + if len(results) == 0 { + firstMeta = entry + } + results = append(results, entry) + } + } + + if usage := gjson.GetBytes(payload, "response.tool_usage.image_gen"); usage.Exists() && usage.IsObject() { + usageRaw = []byte(usage.Raw) + } + + return results, createdAt, usageRaw, firstMeta, nil +} + +func buildImagesAPIResponse(results []imageCallResult, createdAt int64, usageRaw []byte, firstMeta imageCallResult, responseFormat string) ([]byte, error) { + out := []byte(`{"created":0,"data":[]}`) + out, _ = sjson.SetBytes(out, "created", createdAt) + + responseFormat = strings.ToLower(strings.TrimSpace(responseFormat)) + if responseFormat == "" { + responseFormat = "b64_json" + } + + for _, img := range results { + item := []byte(`{}`) + if responseFormat == "url" { + mt := mimeTypeFromOutputFormat(img.OutputFormat) + item, _ = sjson.SetBytes(item, "url", "data:"+mt+";base64,"+img.Result) + } else { + item, _ = sjson.SetBytes(item, "b64_json", img.Result) + } + if img.RevisedPrompt != "" { + item, _ = sjson.SetBytes(item, "revised_prompt", img.RevisedPrompt) + } + out, _ = sjson.SetRawBytes(out, "data.-1", item) + } + + if firstMeta.Background != "" { + out, _ = sjson.SetBytes(out, "background", firstMeta.Background) + } + if firstMeta.OutputFormat != "" { + out, _ = sjson.SetBytes(out, "output_format", firstMeta.OutputFormat) + } + if firstMeta.Quality != "" { + out, _ = sjson.SetBytes(out, "quality", firstMeta.Quality) + } + if firstMeta.Size != "" { + out, _ = sjson.SetBytes(out, "size", firstMeta.Size) + } + + if len(usageRaw) > 0 && json.Valid(usageRaw) { + out, _ = sjson.SetRawBytes(out, "usage", usageRaw) + } + + return out, nil +} + +func (h 
*OpenAIAPIHandler) streamImagesFromResponses(c *gin.Context, responsesReq []byte, responseFormat string, streamPrefix string) { + flusher, ok := c.Writer.(http.Flusher) + if !ok { + c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ + Message: "Streaming not supported", + Type: "server_error", + }, + }) + return + } + + cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) + dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, "openai-response", defaultImagesMainModel, responsesReq, "") + + setSSEHeaders := func() { + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("Access-Control-Allow-Origin", "*") + } + + writeEvent := func(eventName string, dataJSON []byte) { + if strings.TrimSpace(eventName) != "" { + _, _ = fmt.Fprintf(c.Writer, "event: %s\n", eventName) + } + _, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(dataJSON)) + flusher.Flush() + } + + // Peek for first chunk/error so we can still return a JSON error body. 
+ for { + select { + case <-c.Request.Context().Done(): + cliCancel(c.Request.Context().Err()) + return + case errMsg, ok := <-errChan: + if !ok { + errChan = nil + continue + } + h.WriteErrorResponse(c, errMsg) + if errMsg != nil { + cliCancel(errMsg.Error) + } else { + cliCancel(nil) + } + return + case chunk, ok := <-dataChan: + if !ok { + setSSEHeaders() + handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) + _, _ = c.Writer.Write([]byte("\n")) + flusher.Flush() + cliCancel(nil) + return + } + + setSSEHeaders() + handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) + + h.forwardImagesStream(cliCtx, c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan, chunk, responseFormat, streamPrefix, writeEvent) + return + } + } +} + +func (h *OpenAIAPIHandler) forwardImagesStream(ctx context.Context, c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage, firstChunk []byte, responseFormat string, streamPrefix string, writeEvent func(string, []byte)) { + acc := &sseFrameAccumulator{} + + responseFormat = strings.ToLower(strings.TrimSpace(responseFormat)) + if responseFormat == "" { + responseFormat = "b64_json" + } + + emitError := func(errMsg *interfaces.ErrorMessage) { + if errMsg == nil { + return + } + status := http.StatusInternalServerError + if errMsg.StatusCode > 0 { + status = errMsg.StatusCode + } + errText := http.StatusText(status) + if errMsg.Error != nil && strings.TrimSpace(errMsg.Error.Error()) != "" { + errText = errMsg.Error.Error() + } + body := handlers.BuildErrorResponseBody(status, errText) + writeEvent("error", body) + } + + processFrame := func(frame []byte) (done bool) { + for _, line := range bytes.Split(frame, []byte("\n")) { + trimmed := bytes.TrimSpace(bytes.TrimRight(line, "\r")) + if len(trimmed) == 0 || !bytes.HasPrefix(trimmed, []byte("data:")) { + continue + } + payload := bytes.TrimSpace(trimmed[len("data:"):]) + if len(payload) == 0 || 
bytes.Equal(payload, []byte("[DONE]")) || !json.Valid(payload) { + continue + } + + switch gjson.GetBytes(payload, "type").String() { + case "response.image_generation_call.partial_image": + b64 := strings.TrimSpace(gjson.GetBytes(payload, "partial_image_b64").String()) + if b64 == "" { + continue + } + outputFormat := strings.TrimSpace(gjson.GetBytes(payload, "output_format").String()) + index := gjson.GetBytes(payload, "partial_image_index").Int() + eventName := streamPrefix + ".partial_image" + data := []byte(`{"type":"","partial_image_index":0}`) + data, _ = sjson.SetBytes(data, "type", eventName) + data, _ = sjson.SetBytes(data, "partial_image_index", index) + if responseFormat == "url" { + mt := mimeTypeFromOutputFormat(outputFormat) + data, _ = sjson.SetBytes(data, "url", "data:"+mt+";base64,"+b64) + } else { + data, _ = sjson.SetBytes(data, "b64_json", b64) + } + writeEvent(eventName, data) + case "response.completed": + results, _, usageRaw, _, err := extractImagesFromResponsesCompleted(payload) + if err != nil { + emitError(&interfaces.ErrorMessage{StatusCode: http.StatusBadGateway, Error: err}) + return true + } + if len(results) == 0 { + emitError(&interfaces.ErrorMessage{StatusCode: http.StatusBadGateway, Error: fmt.Errorf("upstream did not return image output")}) + return true + } + eventName := streamPrefix + ".completed" + for _, img := range results { + data := []byte(`{"type":""}`) + data, _ = sjson.SetBytes(data, "type", eventName) + if responseFormat == "url" { + mt := mimeTypeFromOutputFormat(img.OutputFormat) + data, _ = sjson.SetBytes(data, "url", "data:"+mt+";base64,"+img.Result) + } else { + data, _ = sjson.SetBytes(data, "b64_json", img.Result) + } + if len(usageRaw) > 0 && json.Valid(usageRaw) { + data, _ = sjson.SetRawBytes(data, "usage", usageRaw) + } + writeEvent(eventName, data) + } + return true + } + } + return false + } + + for _, frame := range acc.AddChunk(firstChunk) { + if processFrame(frame) { + cancel(nil) + return + } + } + 
+ for { + select { + case <-c.Request.Context().Done(): + cancel(c.Request.Context().Err()) + return + case errMsg, ok := <-errs: + if ok && errMsg != nil { + emitError(errMsg) + cancel(errMsg.Error) + return + } + errs = nil + case chunk, ok := <-data: + if !ok { + for _, frame := range acc.Flush() { + if processFrame(frame) { + cancel(nil) + return + } + } + cancel(nil) + return + } + for _, frame := range acc.AddChunk(chunk) { + if processFrame(frame) { + cancel(nil) + return + } + } + } + } +} diff --git a/sdk/cliproxy/auth/antigravity_credits.go b/sdk/cliproxy/auth/antigravity_credits.go new file mode 100644 index 0000000000..77b03bfd3e --- /dev/null +++ b/sdk/cliproxy/auth/antigravity_credits.go @@ -0,0 +1,90 @@ +package auth + +import ( + "context" + "strings" + "sync" + "time" +) + +type antigravityUseCreditsContextKey struct{} + +// WithAntigravityCredits returns a child context that signals the executor to +// inject enabledCreditTypes into the request payload. +func WithAntigravityCredits(ctx context.Context) context.Context { + return context.WithValue(ctx, antigravityUseCreditsContextKey{}, true) +} + +// AntigravityCreditsRequested reports whether the context carries the credits flag. +func AntigravityCreditsRequested(ctx context.Context) bool { + if ctx == nil { + return false + } + v, _ := ctx.Value(antigravityUseCreditsContextKey{}).(bool) + return v +} + +// AntigravityCreditsHint stores the latest known AI credits state for one auth. +type AntigravityCreditsHint struct { + Known bool + Available bool + CreditAmount float64 + MinCreditAmount float64 + PaidTierID string + UpdatedAt time.Time +} + +var antigravityCreditsHintByAuth sync.Map + +// SetAntigravityCreditsHint updates the latest known AI credits state for an auth. 
+func SetAntigravityCreditsHint(authID string, hint AntigravityCreditsHint) { + authID = strings.TrimSpace(authID) + if authID == "" { + return + } + if hint.UpdatedAt.IsZero() { + hint.UpdatedAt = time.Now() + } + antigravityCreditsHintByAuth.Store(authID, hint) +} + +// GetAntigravityCreditsHint returns the latest known AI credits state for an auth. +func GetAntigravityCreditsHint(authID string) (AntigravityCreditsHint, bool) { + authID = strings.TrimSpace(authID) + if authID == "" { + return AntigravityCreditsHint{}, false + } + value, ok := antigravityCreditsHintByAuth.Load(authID) + if !ok { + return AntigravityCreditsHint{}, false + } + hint, ok := value.(AntigravityCreditsHint) + if !ok { + antigravityCreditsHintByAuth.Delete(authID) + return AntigravityCreditsHint{}, false + } + return hint, true +} + +// HasKnownAntigravityCreditsHint reports whether credits state has been discovered for an auth. +func HasKnownAntigravityCreditsHint(authID string) bool { + hint, ok := GetAntigravityCreditsHint(authID) + return ok && hint.Known +} + +func antigravityCreditsAvailableForModel(auth *Auth, model string) bool { + if auth == nil { + return false + } + if !strings.EqualFold(strings.TrimSpace(auth.Provider), "antigravity") { + return false + } + if !strings.Contains(strings.ToLower(strings.TrimSpace(model)), "claude") { + return false + } + hint, ok := GetAntigravityCreditsHint(auth.ID) + if !ok || !hint.Known { + return false + } + return hint.Available +} diff --git a/sdk/cliproxy/auth/antigravity_credits_test.go b/sdk/cliproxy/auth/antigravity_credits_test.go new file mode 100644 index 0000000000..8f59b4c78f --- /dev/null +++ b/sdk/cliproxy/auth/antigravity_credits_test.go @@ -0,0 +1,62 @@ +package auth + +import ( + "testing" + "time" +) + +func TestIsAuthBlockedForModel_ClaudeWithCreditsStillBlockedDuringCooldown(t *testing.T) { + auth := &Auth{ + ID: "ag-1", + Provider: "antigravity", + ModelStates: map[string]*ModelState{ + "claude-sonnet-4-6": { + 
Unavailable: true, + NextRetryAfter: time.Now().Add(10 * time.Minute), + Quota: QuotaState{ + Exceeded: true, + NextRecoverAt: time.Now().Add(10 * time.Minute), + }, + }, + }, + } + + SetAntigravityCreditsHint(auth.ID, AntigravityCreditsHint{ + Known: true, + Available: true, + UpdatedAt: time.Now(), + }) + + blocked, reason, _ := isAuthBlockedForModel(auth, "claude-sonnet-4-6", time.Now()) + if !blocked || reason != blockReasonCooldown { + t.Fatalf("expected auth to be blocked during cooldown even with credits, got blocked=%v reason=%v", blocked, reason) + } +} + +func TestIsAuthBlockedForModel_KeepsGeminiBlockedWithoutCreditsBypass(t *testing.T) { + auth := &Auth{ + ID: "ag-2", + Provider: "antigravity", + ModelStates: map[string]*ModelState{ + "gemini-3-flash": { + Unavailable: true, + NextRetryAfter: time.Now().Add(10 * time.Minute), + Quota: QuotaState{ + Exceeded: true, + NextRecoverAt: time.Now().Add(10 * time.Minute), + }, + }, + }, + } + + SetAntigravityCreditsHint(auth.ID, AntigravityCreditsHint{ + Known: true, + Available: true, + UpdatedAt: time.Now(), + }) + + blocked, reason, _ := isAuthBlockedForModel(auth, "gemini-3-flash", time.Now()) + if !blocked || reason != blockReasonCooldown { + t.Fatalf("expected gemini auth to remain blocked, got blocked=%v reason=%v", blocked, reason) + } +} diff --git a/sdk/cliproxy/auth/conductor.go b/sdk/cliproxy/auth/conductor.go index f58722039c..4d37581a61 100644 --- a/sdk/cliproxy/auth/conductor.go +++ b/sdk/cliproxy/auth/conductor.go @@ -64,8 +64,13 @@ const ( refreshMaxConcurrency = 16 refreshPendingBackoff = time.Minute refreshFailureBackoff = 5 * time.Minute - quotaBackoffBase = time.Second - quotaBackoffMax = 30 * time.Minute + // refreshIneffectiveBackoff throttles refresh attempts when an executor returns + // success but the auth still evaluates as needing refresh (e.g. token expiry + // wasn't updated). Without this guard, the auto-refresh loop can tight-loop and + // burn CPU at idle. 
+ refreshIneffectiveBackoff = 30 * time.Second + quotaBackoffBase = time.Second + quotaBackoffMax = 30 * time.Minute ) var quotaCooldownDisabled atomic.Bool @@ -1197,12 +1202,16 @@ func (m *Manager) Execute(ctx context.Context, providers []string, req cliproxye } } if lastErr != nil { + if shouldAttemptAntigravityCreditsFallback(m, lastErr, normalized) { + if resp, ok := m.tryAntigravityCreditsExecute(ctx, req, opts); ok { + return resp, nil + } + } return cliproxyexecutor.Response{}, lastErr } return cliproxyexecutor.Response{}, &Error{Code: "auth_not_found", Message: "no auth available"} } -// ExecuteCount performs a non-streaming execution using the configured selector and executor. // It supports multiple providers for the same model and round-robins the starting provider per model. func (m *Manager) ExecuteCount(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { normalized := m.normalizeProviders(providers) @@ -1259,6 +1268,11 @@ func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cli } } if lastErr != nil { + if shouldAttemptAntigravityCreditsFallback(m, lastErr, normalized) { + if result, ok := m.tryAntigravityCreditsExecuteStream(ctx, req, opts); ok { + return result, nil + } + } return nil, lastErr } return nil, &Error{Code: "auth_not_found", Message: "no auth available"} @@ -2314,7 +2328,8 @@ func retryAfterFromError(err error) *time.Duration { if retryAfter == nil { return nil } - return new(*retryAfter) + value := *retryAfter + return &value } func statusCodeFromResult(err *Error) int { @@ -2404,11 +2419,18 @@ func isRequestInvalidError(err error) bool { status := statusCodeFromError(err) switch status { case http.StatusBadRequest: - return strings.Contains(err.Error(), "invalid_request_error") + msg := err.Error() + return strings.Contains(msg, "invalid_request_error") || + strings.Contains(msg, "INVALID_ARGUMENT") || + strings.Contains(msg, 
"FAILED_PRECONDITION") case http.StatusNotFound: return isRequestScopedNotFoundMessage(err.Error()) case http.StatusUnprocessableEntity: return true + case http.StatusInternalServerError: + msg := err.Error() + return strings.Contains(msg, "\"status\":\"UNKNOWN\"") || + strings.Contains(msg, "\"status\": \"UNKNOWN\"") default: return false } @@ -2881,6 +2903,175 @@ func (m *Manager) pickNextMixed(ctx context.Context, providers []string, model s return authCopy, executor, providerKey, nil } +func (m *Manager) findAllAntigravityCreditsCandidateAuths(routeModel string, opts cliproxyexecutor.Options) []creditsCandidateEntry { + if m == nil { + return nil + } + pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata) + m.mu.RLock() + defer m.mu.RUnlock() + var known []creditsCandidateEntry + var unknown []creditsCandidateEntry + for _, auth := range m.auths { + if auth == nil || auth.Disabled || auth.Status == StatusDisabled { + continue + } + if pinnedAuthID != "" && auth.ID != pinnedAuthID { + continue + } + if !strings.EqualFold(strings.TrimSpace(auth.Provider), "antigravity") { + continue + } + if !strings.Contains(strings.ToLower(strings.TrimSpace(routeModel)), "claude") { + continue + } + providerKey := strings.TrimSpace(strings.ToLower(auth.Provider)) + executor, ok := m.executors[providerKey] + if !ok { + continue + } + + hint, okHint := GetAntigravityCreditsHint(auth.ID) + if okHint && hint.Known { + if !hint.Available { + continue + } + known = append(known, creditsCandidateEntry{ + auth: auth.Clone(), + executor: executor, + provider: providerKey, + }) + continue + } + unknown = append(unknown, creditsCandidateEntry{ + auth: auth.Clone(), + executor: executor, + provider: providerKey, + }) + } + sort.Slice(known, func(i, j int) bool { + return known[i].auth.ID < known[j].auth.ID + }) + sort.Slice(unknown, func(i, j int) bool { + return unknown[i].auth.ID < unknown[j].auth.ID + }) + return append(known, unknown...) 
+} + +type creditsCandidateEntry struct { + auth *Auth + executor ProviderExecutor + provider string +} + +func shouldAttemptAntigravityCreditsFallback(m *Manager, lastErr error, providers []string) bool { + if m == nil || lastErr == nil { + return false + } + if len(providers) > 0 { + hasAntigravity := false + for _, p := range providers { + if strings.EqualFold(strings.TrimSpace(p), "antigravity") { + hasAntigravity = true + break + } + } + if !hasAntigravity { + return false + } + } + cfg, _ := m.runtimeConfig.Load().(*internalconfig.Config) + if cfg == nil || !cfg.QuotaExceeded.AntigravityCredits { + return false + } + status := statusCodeFromError(lastErr) + switch status { + case http.StatusTooManyRequests, http.StatusServiceUnavailable: + return true + case 0: + var authErr *Error + if errors.As(lastErr, &authErr) && authErr != nil { + return authErr.Code == "auth_not_found" || authErr.Code == "auth_unavailable" || authErr.Code == "model_cooldown" + } + var cooldownErr *modelCooldownError + if errors.As(lastErr, &cooldownErr) { + return true + } + return false + default: + return false + } +} + +func (m *Manager) tryAntigravityCreditsExecute(ctx context.Context, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, bool) { + routeModel := req.Model + candidates := m.findAllAntigravityCreditsCandidateAuths(routeModel, opts) + for _, c := range candidates { + if ctx.Err() != nil { + return cliproxyexecutor.Response{}, false + } + creditsCtx := WithAntigravityCredits(ctx) + if rt := m.roundTripperFor(c.auth); rt != nil { + creditsCtx = context.WithValue(creditsCtx, roundTripperContextKey{}, rt) + creditsCtx = context.WithValue(creditsCtx, "cliproxy.roundtripper", rt) + } + creditsOpts := ensureRequestedModelMetadata(opts, routeModel) + publishSelectedAuthMetadata(creditsOpts.Metadata, c.auth.ID) + models := m.executionModelCandidates(c.auth, routeModel) + if len(models) == 0 { + continue + } + for _, upstreamModel := range 
models { + resultModel := m.stateModelForExecution(c.auth, routeModel, upstreamModel, len(models) > 1) + execReq := req + execReq.Model = upstreamModel + resp, errExec := c.executor.Execute(creditsCtx, c.auth, execReq, creditsOpts) + result := Result{AuthID: c.auth.ID, Provider: c.provider, Model: resultModel, Success: errExec == nil} + if errExec != nil { + result.Error = &Error{Message: errExec.Error()} + if se, ok := errors.AsType[cliproxyexecutor.StatusError](errExec); ok && se != nil { + result.Error.HTTPStatus = se.StatusCode() + } + if ra := retryAfterFromError(errExec); ra != nil { + result.RetryAfter = ra + } + m.MarkResult(creditsCtx, result) + continue + } + m.MarkResult(creditsCtx, result) + return resp, true + } + } + return cliproxyexecutor.Response{}, false +} + +func (m *Manager) tryAntigravityCreditsExecuteStream(ctx context.Context, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, bool) { + routeModel := req.Model + candidates := m.findAllAntigravityCreditsCandidateAuths(routeModel, opts) + for _, c := range candidates { + if ctx.Err() != nil { + return nil, false + } + creditsCtx := WithAntigravityCredits(ctx) + if rt := m.roundTripperFor(c.auth); rt != nil { + creditsCtx = context.WithValue(creditsCtx, roundTripperContextKey{}, rt) + creditsCtx = context.WithValue(creditsCtx, "cliproxy.roundtripper", rt) + } + creditsOpts := ensureRequestedModelMetadata(opts, routeModel) + publishSelectedAuthMetadata(creditsOpts.Metadata, c.auth.ID) + models := m.executionModelCandidates(c.auth, routeModel) + if len(models) == 0 { + continue + } + result, errStream := m.executeStreamWithModelPool(creditsCtx, c.executor, c.auth, c.provider, req, creditsOpts, routeModel, models, len(models) > 1) + if errStream != nil { + continue + } + return result, true + } + return nil, false +} + func (m *Manager) persist(ctx context.Context, auth *Auth) error { if m.store == nil || auth == nil { return nil @@ -3195,14 +3386,15 @@ 
func (m *Manager) refreshAuth(ctx context.Context, id string) { m.mu.RLock() auth := m.auths[id] var exec ProviderExecutor + var cloned *Auth if auth != nil { exec = m.executors[auth.Provider] + cloned = auth.Clone() } m.mu.RUnlock() if auth == nil || exec == nil { return } - cloned := auth.Clone() updated, err := exec.Refresh(ctx, cloned) if err != nil && errors.Is(err, context.Canceled) { log.Debugf("refresh canceled for %s, %s", auth.Provider, auth.ID) @@ -3240,6 +3432,9 @@ func (m *Manager) refreshAuth(ctx context.Context, id string) { updated.NextRefreshAfter = time.Time{} updated.LastError = nil updated.UpdatedAt = now + if m.shouldRefresh(updated, now) { + updated.NextRefreshAfter = now.Add(refreshIneffectiveBackoff) + } _, _ = m.Update(ctx, updated) } diff --git a/sdk/cliproxy/auth/conductor_credits_candidates_test.go b/sdk/cliproxy/auth/conductor_credits_candidates_test.go new file mode 100644 index 0000000000..e66798acf6 --- /dev/null +++ b/sdk/cliproxy/auth/conductor_credits_candidates_test.go @@ -0,0 +1,61 @@ +package auth + +import ( + "testing" + "time" + + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" +) + +func TestFindAllAntigravityCreditsCandidateAuths_PrefersKnownCreditsThenUnknown(t *testing.T) { + m := &Manager{ + auths: map[string]*Auth{ + "zz-credits": {ID: "zz-credits", Provider: "antigravity"}, + "aa-unknown": {ID: "aa-unknown", Provider: "antigravity"}, + "mm-no": {ID: "mm-no", Provider: "antigravity"}, + }, + executors: map[string]ProviderExecutor{ + "antigravity": schedulerTestExecutor{}, + }, + } + + SetAntigravityCreditsHint("zz-credits", AntigravityCreditsHint{ + Known: true, + Available: true, + UpdatedAt: time.Now(), + }) + SetAntigravityCreditsHint("mm-no", AntigravityCreditsHint{ + Known: true, + Available: false, + UpdatedAt: time.Now(), + }) + + opts := cliproxyexecutor.Options{} + + candidates := m.findAllAntigravityCreditsCandidateAuths("claude-sonnet-4-6", opts) + if len(candidates) != 2 { + 
t.Fatalf("candidates len = %d, want 2", len(candidates)) + } + if candidates[0].auth.ID != "zz-credits" { + t.Fatalf("candidates[0].auth.ID = %q, want %q", candidates[0].auth.ID, "zz-credits") + } + if candidates[1].auth.ID != "aa-unknown" { + t.Fatalf("candidates[1].auth.ID = %q, want %q", candidates[1].auth.ID, "aa-unknown") + } + + nonClaude := m.findAllAntigravityCreditsCandidateAuths("gemini-3-flash", opts) + if len(nonClaude) != 0 { + t.Fatalf("nonClaude len = %d, want 0", len(nonClaude)) + } + + pinnedOpts := cliproxyexecutor.Options{ + Metadata: map[string]any{cliproxyexecutor.PinnedAuthMetadataKey: "aa-unknown"}, + } + pinned := m.findAllAntigravityCreditsCandidateAuths("claude-sonnet-4-6", pinnedOpts) + if len(pinned) != 1 { + t.Fatalf("pinned len = %d, want 1", len(pinned)) + } + if pinned[0].auth.ID != "aa-unknown" { + t.Fatalf("pinned[0].auth.ID = %q, want %q", pinned[0].auth.ID, "aa-unknown") + } +} diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index 5e873d370b..fa0d8a0aa7 100644 --- a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -1410,7 +1410,7 @@ func buildCodexConfigModels(entry *config.CodexKey) []*ModelInfo { if entry == nil { return nil } - return buildConfigModels(entry.Models, "openai", "openai") + return registry.WithCodexBuiltins(buildConfigModels(entry.Models, "openai", "openai")) } func rewriteModelInfoName(name, oldID, newID string) string { diff --git a/test/thinking_conversion_test.go b/test/thinking_conversion_test.go index c6ade7b2a6..51671a9c5f 100644 --- a/test/thinking_conversion_test.go +++ b/test/thinking_conversion_test.go @@ -2,7 +2,6 @@ package test import ( "fmt" - "strings" "testing" "time" @@ -1066,12 +1065,12 @@ func TestThinkingE2EMatrix_Suffix(t *testing.T) { expectErr: false, }, - // Gemini Family Cross-Channel Consistency (Cases 106-114) + // Gemini Family Cross-Channel Consistency (Cases 90-95) // Tests that gemini/gemini-cli/antigravity as same API family should have consistent validation 
behavior - // Case 106: Gemini to Antigravity, budget 64000 (suffix) → clamped to Max + // Case 90: Gemini to Antigravity, budget 64000 (suffix) → clamped to Max { - name: "106", + name: "90", from: "gemini", to: "antigravity", model: "gemini-budget-model(64000)", @@ -1081,9 +1080,9 @@ func TestThinkingE2EMatrix_Suffix(t *testing.T) { includeThoughts: "true", expectErr: false, }, - // Case 107: Gemini to Gemini-CLI, budget 64000 (suffix) → clamped to Max + // Case 91: Gemini to Gemini-CLI, budget 64000 (suffix) → clamped to Max { - name: "107", + name: "91", from: "gemini", to: "gemini-cli", model: "gemini-budget-model(64000)", @@ -1093,9 +1092,9 @@ func TestThinkingE2EMatrix_Suffix(t *testing.T) { includeThoughts: "true", expectErr: false, }, - // Case 108: Gemini-CLI to Antigravity, budget 64000 (suffix) → clamped to Max + // Case 92: Gemini-CLI to Antigravity, budget 64000 (suffix) → clamped to Max { - name: "108", + name: "92", from: "gemini-cli", to: "antigravity", model: "gemini-budget-model(64000)", @@ -1105,9 +1104,9 @@ func TestThinkingE2EMatrix_Suffix(t *testing.T) { includeThoughts: "true", expectErr: false, }, - // Case 109: Gemini-CLI to Gemini, budget 64000 (suffix) → clamped to Max + // Case 93: Gemini-CLI to Gemini, budget 64000 (suffix) → clamped to Max { - name: "109", + name: "93", from: "gemini-cli", to: "gemini", model: "gemini-budget-model(64000)", @@ -1117,9 +1116,9 @@ func TestThinkingE2EMatrix_Suffix(t *testing.T) { includeThoughts: "true", expectErr: false, }, - // Case 110: Gemini to Antigravity, budget 8192 → passthrough (normal value) + // Case 94: Gemini to Antigravity, budget 8192 → passthrough (normal value) { - name: "110", + name: "94", from: "gemini", to: "antigravity", model: "gemini-budget-model(8192)", @@ -1129,9 +1128,9 @@ func TestThinkingE2EMatrix_Suffix(t *testing.T) { includeThoughts: "true", expectErr: false, }, - // Case 111: Gemini-CLI to Antigravity, budget 8192 → passthrough (normal value) + // Case 95: Gemini-CLI to 
Antigravity, budget 8192 → passthrough (normal value) { - name: "111", + name: "95", from: "gemini-cli", to: "antigravity", model: "gemini-budget-model(8192)", @@ -2167,12 +2166,12 @@ func TestThinkingE2EMatrix_Body(t *testing.T) { expectErr: true, }, - // Gemini Family Cross-Channel Consistency (Cases 106-114) + // Gemini Family Cross-Channel Consistency (Cases 90-95) // Tests that gemini/gemini-cli/antigravity as same API family should have consistent validation behavior - // Case 106: Gemini to Antigravity, thinkingBudget=64000 → exceeds Max error (same family strict validation) + // Case 90: Gemini to Antigravity, thinkingBudget=64000 → exceeds Max error (same family strict validation) { - name: "106", + name: "90", from: "gemini", to: "antigravity", model: "gemini-budget-model", @@ -2180,9 +2179,9 @@ func TestThinkingE2EMatrix_Body(t *testing.T) { expectField: "", expectErr: true, }, - // Case 107: Gemini to Gemini-CLI, thinkingBudget=64000 → exceeds Max error (same family strict validation) + // Case 91: Gemini to Gemini-CLI, thinkingBudget=64000 → exceeds Max error (same family strict validation) { - name: "107", + name: "91", from: "gemini", to: "gemini-cli", model: "gemini-budget-model", @@ -2190,9 +2189,9 @@ func TestThinkingE2EMatrix_Body(t *testing.T) { expectField: "", expectErr: true, }, - // Case 108: Gemini-CLI to Antigravity, thinkingBudget=64000 → exceeds Max error (same family strict validation) + // Case 92: Gemini-CLI to Antigravity, thinkingBudget=64000 → exceeds Max error (same family strict validation) { - name: "108", + name: "92", from: "gemini-cli", to: "antigravity", model: "gemini-budget-model", @@ -2200,9 +2199,9 @@ func TestThinkingE2EMatrix_Body(t *testing.T) { expectField: "", expectErr: true, }, - // Case 109: Gemini-CLI to Gemini, thinkingBudget=64000 → exceeds Max error (same family strict validation) + // Case 93: Gemini-CLI to Gemini, thinkingBudget=64000 → exceeds Max error (same family strict validation) { - name: "109", + 
name: "93", from: "gemini-cli", to: "gemini", model: "gemini-budget-model", @@ -2210,9 +2209,9 @@ func TestThinkingE2EMatrix_Body(t *testing.T) { expectField: "", expectErr: true, }, - // Case 110: Gemini to Antigravity, thinkingBudget=8192 → passthrough (normal value) + // Case 94: Gemini to Antigravity, thinkingBudget=8192 → passthrough (normal value) { - name: "110", + name: "94", from: "gemini", to: "antigravity", model: "gemini-budget-model", @@ -2222,9 +2221,9 @@ func TestThinkingE2EMatrix_Body(t *testing.T) { includeThoughts: "true", expectErr: false, }, - // Case 111: Gemini-CLI to Antigravity, thinkingBudget=8192 → passthrough (normal value) + // Case 95: Gemini-CLI to Antigravity, thinkingBudget=8192 → passthrough (normal value) { - name: "111", + name: "95", from: "gemini-cli", to: "antigravity", model: "gemini-budget-model",