Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 51 additions & 23 deletions middleware/distributor.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,46 @@ type ModelRequest struct {
Group string `json:"group,omitempty"`
}

func selectPreferredAffinityChannel(
c *gin.Context,
preferred *model.Channel,
modelName string,
usingGroup string,
getAutoGroups func(string) []string,
isChannelEnabledForGroupModel func(string, string, int) bool,
) (*model.Channel, string) {
if preferred == nil {
service.ClearCurrentChannelAffinity(c, "preferred affinity channel missing")
return nil, ""
}
if preferred.Status != common.ChannelStatusEnabled {
service.ClearCurrentChannelAffinity(c, fmt.Sprintf("preferred affinity channel %d is disabled", preferred.Id))
return nil, ""
}

if usingGroup == "auto" {
userGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup)
autoGroups := getAutoGroups(userGroup)
for _, g := range autoGroups {
if isChannelEnabledForGroupModel(g, modelName, preferred.Id) {
common.SetContextKey(c, constant.ContextKeyAutoGroup, g)
service.MarkChannelAffinityUsed(c, g, preferred.Id)
return preferred, g
}
}
service.ClearCurrentChannelAffinity(c, fmt.Sprintf("preferred affinity channel %d no longer matches auto groups for model %s", preferred.Id, modelName))
return nil, ""
}

if isChannelEnabledForGroupModel(usingGroup, modelName, preferred.Id) {
service.MarkChannelAffinityUsed(c, usingGroup, preferred.Id)
return preferred, usingGroup
}

service.ClearCurrentChannelAffinity(c, fmt.Sprintf("preferred affinity channel %d no longer matches group %s for model %s", preferred.Id, usingGroup, modelName))
return nil, ""
}

func Distribute() func(c *gin.Context) {
return func(c *gin.Context) {
var channel *model.Channel
Expand Down Expand Up @@ -101,29 +141,17 @@ func Distribute() func(c *gin.Context) {

if preferredChannelID, found := service.GetPreferredChannelByAffinity(c, modelRequest.Model, usingGroup); found {
preferred, err := model.CacheGetChannel(preferredChannelID)
if err == nil && preferred != nil {
if preferred.Status != common.ChannelStatusEnabled {
if service.ShouldSkipRetryAfterChannelAffinityFailure(c) {
abortWithOpenAiMessage(c, http.StatusForbidden, i18n.T(c, i18n.MsgDistributorChannelDisabled))
return
}
} else if usingGroup == "auto" {
userGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup)
autoGroups := service.GetUserAutoGroup(userGroup)
for _, g := range autoGroups {
if model.IsChannelEnabledForGroupModel(g, modelRequest.Model, preferred.Id) {
selectGroup = g
common.SetContextKey(c, constant.ContextKeyAutoGroup, g)
channel = preferred
service.MarkChannelAffinityUsed(c, g, preferred.Id)
break
}
}
} else if model.IsChannelEnabledForGroupModel(usingGroup, modelRequest.Model, preferred.Id) {
channel = preferred
selectGroup = usingGroup
service.MarkChannelAffinityUsed(c, usingGroup, preferred.Id)
}
if err != nil {
service.ClearCurrentChannelAffinity(c, fmt.Sprintf("preferred affinity channel %d lookup failed: %v", preferredChannelID, err))
} else {
channel, selectGroup = selectPreferredAffinityChannel(
c,
preferred,
modelRequest.Model,
usingGroup,
service.GetUserAutoGroup,
model.IsChannelEnabledForGroupModel,
)
}
}

Expand Down
88 changes: 88 additions & 0 deletions middleware/distributor_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package middleware

import (
"net/http/httptest"
"testing"

"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/model"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)

func newDistributorTestContext() *gin.Context {
rec := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(rec)
ctx.Set("channel_affinity_cache_key", "new-api:channel_affinity:v1:codex cli trace:default:test-affinity-key")
ctx.Set("channel_affinity_ttl_seconds", 60)
return ctx
}

func TestSelectPreferredAffinityChannelClearsDisabledChannel(t *testing.T) {
ctx := newDistributorTestContext()

selected, group := selectPreferredAffinityChannel(
ctx,
&model.Channel{Id: 42, Status: common.ChannelStatusAutoDisabled},
"gpt-5",
"default",
func(string) []string { return []string{"default"} },
func(string, string, int) bool { return true },
)

require.Nil(t, selected)
require.Empty(t, group)

anyInfo, ok := ctx.Get("channel_affinity_log_info")
require.True(t, ok)
info, ok := anyInfo.(map[string]interface{})
require.True(t, ok)
require.Equal(t, true, info["stale_affinity_cleared"])
require.Equal(t, "preferred affinity channel 42 is disabled", info["stale_affinity_reason"])
}

func TestSelectPreferredAffinityChannelFallsBackForAutoGroupMismatch(t *testing.T) {
ctx := newDistributorTestContext()
common.SetContextKey(ctx, constant.ContextKeyUserGroup, "default")

selected, group := selectPreferredAffinityChannel(
ctx,
&model.Channel{Id: 42, Status: common.ChannelStatusEnabled},
"gpt-5",
"auto",
func(string) []string { return []string{"group-a", "group-b"} },
func(string, string, int) bool { return false },
)

require.Nil(t, selected)
require.Empty(t, group)

anyInfo, ok := ctx.Get("channel_affinity_log_info")
require.True(t, ok)
info, ok := anyInfo.(map[string]interface{})
require.True(t, ok)
require.Equal(t, true, info["stale_affinity_cleared"])
require.Equal(t, "preferred affinity channel 42 no longer matches auto groups for model gpt-5", info["stale_affinity_reason"])
}

func TestSelectPreferredAffinityChannelUsesMatchingAutoGroup(t *testing.T) {
ctx := newDistributorTestContext()
common.SetContextKey(ctx, constant.ContextKeyUserGroup, "default")

selected, group := selectPreferredAffinityChannel(
ctx,
&model.Channel{Id: 43, Status: common.ChannelStatusEnabled},
"gpt-5",
"auto",
func(string) []string { return []string{"group-a", "group-b"} },
func(group, modelName string, channelID int) bool {
return group == "group-b" && modelName == "gpt-5" && channelID == 43
},
)

require.NotNil(t, selected)
require.Equal(t, 43, selected.Id)
require.Equal(t, "group-b", group)
require.Equal(t, "group-b", common.GetContextKeyString(ctx, constant.ContextKeyAutoGroup))
}
59 changes: 59 additions & 0 deletions service/channel_affinity.go
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,43 @@ func appendChannelAffinityTemplateAdminInfo(c *gin.Context, meta channelAffinity
})
}

func appendChannelAffinityClearedAdminInfo(c *gin.Context, reason string) {
if c == nil {
return
}
reason = strings.TrimSpace(reason)
if reason == "" {
reason = "stale affinity cleared"
}

if anyInfo, ok := c.Get(ginKeyChannelAffinityLogInfo); ok {
if info, ok := anyInfo.(map[string]interface{}); ok {
info["stale_affinity_cleared"] = true
info["stale_affinity_reason"] = reason
c.Set(ginKeyChannelAffinityLogInfo, info)
return
}
}

info := map[string]interface{}{
"stale_affinity_cleared": true,
"stale_affinity_reason": reason,
}
if meta, ok := getChannelAffinityMeta(c); ok {
info["reason"] = meta.RuleName
info["rule_name"] = meta.RuleName
info["using_group"] = meta.UsingGroup
info["model"] = meta.ModelName
info["request_path"] = meta.RequestPath
info["key_source"] = meta.KeySourceType
info["key_key"] = meta.KeySourceKey
info["key_path"] = meta.KeySourcePath
info["key_hint"] = meta.KeyHint
info["key_fp"] = meta.KeyFingerprint
}
c.Set(ginKeyChannelAffinityLogInfo, info)
}

// ApplyChannelAffinityOverrideTemplate merges per-rule channel override templates onto the selected channel override config.
func ApplyChannelAffinityOverrideTemplate(c *gin.Context, paramOverride map[string]interface{}) (map[string]interface{}, bool) {
if c == nil {
Expand Down Expand Up @@ -623,6 +660,28 @@ func ShouldSkipRetryAfterChannelAffinityFailure(c *gin.Context) bool {
return meta.SkipRetry
}

func ClearCurrentChannelAffinity(c *gin.Context, reason string) bool {
if c == nil {
return false
}
cacheKey, _, ok := getChannelAffinityContext(c)
if !ok || cacheKey == "" {
return false
}
cache := getChannelAffinityCache()
if _, err := cache.DeleteMany([]string{cacheKey}); err != nil {
common.SysError(fmt.Sprintf("channel affinity cache delete failed: key=%s, err=%v", cacheKey, err))
return false
}
c.Set(ginKeyChannelAffinitySkipRetry, false)
if meta, ok := getChannelAffinityMeta(c); ok {
meta.SkipRetry = false
c.Set(ginKeyChannelAffinityMeta, meta)
}
appendChannelAffinityClearedAdminInfo(c, reason)
return true
}

func MarkChannelAffinityUsed(c *gin.Context, selectedGroup string, channelID int) {
if c == nil || channelID <= 0 {
return
Expand Down
54 changes: 54 additions & 0 deletions service/channel_affinity_cache_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package service

import (
"testing"
"time"

"github.com/QuantumNous/new-api/common"
"github.com/stretchr/testify/require"
)

func TestClearCurrentChannelAffinity(t *testing.T) {
originalRedisEnabled := common.RedisEnabled
common.RedisEnabled = false
t.Cleanup(func() {
common.RedisEnabled = originalRedisEnabled
_ = getChannelAffinityCache().Purge()
})

cache := getChannelAffinityCache()
_ = cache.Purge()

cacheKey := "new-api:channel_affinity:v1:codex cli trace:default:test-affinity-key"
require.NoError(t, cache.SetWithTTL(cacheKey, 42, time.Minute))

ctx := buildChannelAffinityTemplateContextForTest(channelAffinityMeta{
CacheKey: cacheKey,
TTLSeconds: 60,
RuleName: "codex cli trace",
SkipRetry: true,
UsingGroup: "default",
ModelName: "gpt-5",
RequestPath: "/v1/responses",
KeySourceType: "gjson",
KeySourcePath: "prompt_cache_key",
KeyHint: "test...key",
KeyFingerprint: "abcd1234",
})

require.True(t, ClearCurrentChannelAffinity(ctx, "preferred affinity channel 42 is disabled"))

_, found, err := cache.Get(cacheKey)
require.NoError(t, err)
require.False(t, found)
require.False(t, ShouldSkipRetryAfterChannelAffinityFailure(ctx))

anyInfo, ok := ctx.Get(ginKeyChannelAffinityLogInfo)
require.True(t, ok)
info, ok := anyInfo.(map[string]interface{})
require.True(t, ok)
require.Equal(t, true, info["stale_affinity_cleared"])
require.Equal(t, "preferred affinity channel 42 is disabled", info["stale_affinity_reason"])
require.Equal(t, "codex cli trace", info["rule_name"])
require.Equal(t, "default", info["using_group"])
}