diff --git a/middleware/distributor.go b/middleware/distributor.go index d626941456..5ede75b0d9 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -27,6 +27,46 @@ type ModelRequest struct { Group string `json:"group,omitempty"` } +func selectPreferredAffinityChannel( + c *gin.Context, + preferred *model.Channel, + modelName string, + usingGroup string, + getAutoGroups func(string) []string, + isChannelEnabledForGroupModel func(string, string, int) bool, +) (*model.Channel, string) { + if preferred == nil { + service.ClearCurrentChannelAffinity(c, "preferred affinity channel missing") + return nil, "" + } + if preferred.Status != common.ChannelStatusEnabled { + service.ClearCurrentChannelAffinity(c, fmt.Sprintf("preferred affinity channel %d is disabled", preferred.Id)) + return nil, "" + } + + if usingGroup == "auto" { + userGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup) + autoGroups := getAutoGroups(userGroup) + for _, g := range autoGroups { + if isChannelEnabledForGroupModel(g, modelName, preferred.Id) { + common.SetContextKey(c, constant.ContextKeyAutoGroup, g) + service.MarkChannelAffinityUsed(c, g, preferred.Id) + return preferred, g + } + } + service.ClearCurrentChannelAffinity(c, fmt.Sprintf("preferred affinity channel %d no longer matches auto groups for model %s", preferred.Id, modelName)) + return nil, "" + } + + if isChannelEnabledForGroupModel(usingGroup, modelName, preferred.Id) { + service.MarkChannelAffinityUsed(c, usingGroup, preferred.Id) + return preferred, usingGroup + } + + service.ClearCurrentChannelAffinity(c, fmt.Sprintf("preferred affinity channel %d no longer matches group %s for model %s", preferred.Id, usingGroup, modelName)) + return nil, "" +} + func Distribute() func(c *gin.Context) { return func(c *gin.Context) { var channel *model.Channel @@ -101,29 +141,17 @@ func Distribute() func(c *gin.Context) { if preferredChannelID, found := service.GetPreferredChannelByAffinity(c, modelRequest.Model, usingGroup); found { preferred, err := model.CacheGetChannel(preferredChannelID) - if err == nil && preferred != nil { - if preferred.Status != common.ChannelStatusEnabled { - if service.ShouldSkipRetryAfterChannelAffinityFailure(c) { - abortWithOpenAiMessage(c, http.StatusForbidden, i18n.T(c, i18n.MsgDistributorChannelDisabled)) - return - } - } else if usingGroup == "auto" { - userGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup) - autoGroups := service.GetUserAutoGroup(userGroup) - for _, g := range autoGroups { - if model.IsChannelEnabledForGroupModel(g, modelRequest.Model, preferred.Id) { - selectGroup = g - common.SetContextKey(c, constant.ContextKeyAutoGroup, g) - channel = preferred - service.MarkChannelAffinityUsed(c, g, preferred.Id) - break - } - } - } else if model.IsChannelEnabledForGroupModel(usingGroup, modelRequest.Model, preferred.Id) { - channel = preferred - selectGroup = usingGroup - service.MarkChannelAffinityUsed(c, usingGroup, preferred.Id) - } + if err != nil { + service.ClearCurrentChannelAffinity(c, fmt.Sprintf("preferred affinity channel %d lookup failed: %v", preferredChannelID, err)) + } else { + channel, selectGroup = selectPreferredAffinityChannel( + c, + preferred, + modelRequest.Model, + usingGroup, + service.GetUserAutoGroup, + model.IsChannelEnabledForGroupModel, + ) } } diff --git a/middleware/distributor_test.go b/middleware/distributor_test.go new file mode 100644 index 0000000000..dffd35473f --- /dev/null +++ b/middleware/distributor_test.go @@ -0,0 +1,88 @@ +package middleware + +import ( + "net/http/httptest" + "testing" + + "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/constant" + "github.com/QuantumNous/new-api/model" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func newDistributorTestContext() *gin.Context { + rec := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(rec) + ctx.Set("channel_affinity_cache_key", "new-api:channel_affinity:v1:codex cli trace:default:test-affinity-key") + ctx.Set("channel_affinity_ttl_seconds", 60) + return ctx +} + +func TestSelectPreferredAffinityChannelClearsDisabledChannel(t *testing.T) { + ctx := newDistributorTestContext() + + selected, group := selectPreferredAffinityChannel( + ctx, + &model.Channel{Id: 42, Status: common.ChannelStatusAutoDisabled}, + "gpt-5", + "default", + func(string) []string { return []string{"default"} }, + func(string, string, int) bool { return true }, + ) + + require.Nil(t, selected) + require.Empty(t, group) + + anyInfo, ok := ctx.Get("channel_affinity_log_info") + require.True(t, ok) + info, ok := anyInfo.(map[string]interface{}) + require.True(t, ok) + require.Equal(t, true, info["stale_affinity_cleared"]) + require.Equal(t, "preferred affinity channel 42 is disabled", info["stale_affinity_reason"]) +} + +func TestSelectPreferredAffinityChannelFallsBackForAutoGroupMismatch(t *testing.T) { + ctx := newDistributorTestContext() + common.SetContextKey(ctx, constant.ContextKeyUserGroup, "default") + + selected, group := selectPreferredAffinityChannel( + ctx, + &model.Channel{Id: 42, Status: common.ChannelStatusEnabled}, + "gpt-5", + "auto", + func(string) []string { return []string{"group-a", "group-b"} }, + func(string, string, int) bool { return false }, + ) + + require.Nil(t, selected) + require.Empty(t, group) + + anyInfo, ok := ctx.Get("channel_affinity_log_info") + require.True(t, ok) + info, ok := anyInfo.(map[string]interface{}) + require.True(t, ok) + require.Equal(t, true, info["stale_affinity_cleared"]) + require.Equal(t, "preferred affinity channel 42 no longer matches auto groups for model gpt-5", info["stale_affinity_reason"]) +} + +func TestSelectPreferredAffinityChannelUsesMatchingAutoGroup(t *testing.T) { + ctx := newDistributorTestContext() + common.SetContextKey(ctx, constant.ContextKeyUserGroup, "default") + + selected, group := selectPreferredAffinityChannel( + ctx, + &model.Channel{Id: 43, Status: common.ChannelStatusEnabled}, + "gpt-5", + "auto", + func(string) []string { return []string{"group-a", "group-b"} }, + func(group, modelName string, channelID int) bool { + return group == "group-b" && modelName == "gpt-5" && channelID == 43 + }, + ) + + require.NotNil(t, selected) + require.Equal(t, 43, selected.Id) + require.Equal(t, "group-b", group) + require.Equal(t, "group-b", common.GetContextKeyString(ctx, constant.ContextKeyAutoGroup)) +} diff --git a/service/channel_affinity.go b/service/channel_affinity.go index 9f89585fac..f541a0861a 100644 --- a/service/channel_affinity.go +++ b/service/channel_affinity.go @@ -511,6 +511,43 @@ func appendChannelAffinityTemplateAdminInfo(c *gin.Context, meta channelAffinity }) } +func appendChannelAffinityClearedAdminInfo(c *gin.Context, reason string) { + if c == nil { + return + } + reason = strings.TrimSpace(reason) + if reason == "" { + reason = "stale affinity cleared" + } + + if anyInfo, ok := c.Get(ginKeyChannelAffinityLogInfo); ok { + if info, ok := anyInfo.(map[string]interface{}); ok { + info["stale_affinity_cleared"] = true + info["stale_affinity_reason"] = reason + c.Set(ginKeyChannelAffinityLogInfo, info) + return + } + } + + info := map[string]interface{}{ + "stale_affinity_cleared": true, + "stale_affinity_reason": reason, + } + if meta, ok := getChannelAffinityMeta(c); ok { + info["reason"] = meta.RuleName + info["rule_name"] = meta.RuleName + info["using_group"] = meta.UsingGroup + info["model"] = meta.ModelName + info["request_path"] = meta.RequestPath + info["key_source"] = meta.KeySourceType + info["key_key"] = meta.KeySourceKey + info["key_path"] = meta.KeySourcePath + info["key_hint"] = meta.KeyHint + info["key_fp"] = meta.KeyFingerprint + } + c.Set(ginKeyChannelAffinityLogInfo, info) +} + // ApplyChannelAffinityOverrideTemplate merges per-rule channel override templates onto the selected channel override config. func ApplyChannelAffinityOverrideTemplate(c *gin.Context, paramOverride map[string]interface{}) (map[string]interface{}, bool) { if c == nil { @@ -623,6 +660,28 @@ func ShouldSkipRetryAfterChannelAffinityFailure(c *gin.Context) bool { return meta.SkipRetry } +func ClearCurrentChannelAffinity(c *gin.Context, reason string) bool { + if c == nil { + return false + } + cacheKey, _, ok := getChannelAffinityContext(c) + if !ok || cacheKey == "" { + return false + } + cache := getChannelAffinityCache() + if _, err := cache.DeleteMany([]string{cacheKey}); err != nil { + common.SysError(fmt.Sprintf("channel affinity cache delete failed: key=%s, err=%v", cacheKey, err)) + return false + } + c.Set(ginKeyChannelAffinitySkipRetry, false) + if meta, ok := getChannelAffinityMeta(c); ok { + meta.SkipRetry = false + c.Set(ginKeyChannelAffinityMeta, meta) + } + appendChannelAffinityClearedAdminInfo(c, reason) + return true +} + func MarkChannelAffinityUsed(c *gin.Context, selectedGroup string, channelID int) { if c == nil || channelID <= 0 { return diff --git a/service/channel_affinity_cache_test.go b/service/channel_affinity_cache_test.go new file mode 100644 index 0000000000..4f295a4458 --- /dev/null +++ b/service/channel_affinity_cache_test.go @@ -0,0 +1,54 @@ +package service + +import ( + "testing" + "time" + + "github.com/QuantumNous/new-api/common" + "github.com/stretchr/testify/require" +) + +func TestClearCurrentChannelAffinity(t *testing.T) { + originalRedisEnabled := common.RedisEnabled + common.RedisEnabled = false + t.Cleanup(func() { + common.RedisEnabled = originalRedisEnabled + _ = getChannelAffinityCache().Purge() + }) + + cache := getChannelAffinityCache() + _ = cache.Purge() + + cacheKey := "new-api:channel_affinity:v1:codex cli trace:default:test-affinity-key" + require.NoError(t, cache.SetWithTTL(cacheKey, 42, time.Minute)) + + ctx := buildChannelAffinityTemplateContextForTest(channelAffinityMeta{ + CacheKey: cacheKey, + TTLSeconds: 60, + RuleName: "codex cli trace", + SkipRetry: true, + UsingGroup: "default", + ModelName: "gpt-5", + RequestPath: "/v1/responses", + KeySourceType: "gjson", + KeySourcePath: "prompt_cache_key", + KeyHint: "test...key", + KeyFingerprint: "abcd1234", + }) + + require.True(t, ClearCurrentChannelAffinity(ctx, "preferred affinity channel 42 is disabled")) + + _, found, err := cache.Get(cacheKey) + require.NoError(t, err) + require.False(t, found) + require.False(t, ShouldSkipRetryAfterChannelAffinityFailure(ctx)) + + anyInfo, ok := ctx.Get(ginKeyChannelAffinityLogInfo) + require.True(t, ok) + info, ok := anyInfo.(map[string]interface{}) + require.True(t, ok) + require.Equal(t, true, info["stale_affinity_cleared"]) + require.Equal(t, "preferred affinity channel 42 is disabled", info["stale_affinity_reason"]) + require.Equal(t, "codex cli trace", info["rule_name"]) + require.Equal(t, "default", info["using_group"]) +}