diff --git a/model/channel.go b/model/channel.go index f256b54ce35..e0f846bc4da 100644 --- a/model/channel.go +++ b/model/channel.go @@ -74,8 +74,25 @@ func (c ChannelInfo) Value() (driver.Value, error) { // Scan implements sql.Scanner interface func (c *ChannelInfo) Scan(value interface{}) error { - bytesValue, _ := value.([]byte) - return common.Unmarshal(bytesValue, c) + switch typedValue := value.(type) { + case nil: + *c = ChannelInfo{} + return nil + case []byte: + if len(typedValue) == 0 { + *c = ChannelInfo{} + return nil + } + return common.Unmarshal(typedValue, c) + case string: + if typedValue == "" { + *c = ChannelInfo{} + return nil + } + return common.Unmarshal([]byte(typedValue), c) + default: + return fmt.Errorf("unsupported channel info type: %T", value) + } } func (channel *Channel) GetKeys() []string { @@ -608,42 +625,13 @@ func handlerMultiKeyUpdate(channel *Channel, usingKey string, status int, reason } } +// UpdateChannelStatus updates channel state and its ability visibility atomically. func UpdateChannelStatus(channelId int, usingKey string, status int, reason string) bool { if common.MemoryCacheEnabled { channelStatusLock.Lock() defer channelStatusLock.Unlock() - - channelCache, _ := CacheGetChannel(channelId) - if channelCache == nil { - return false - } - if channelCache.ChannelInfo.IsMultiKey { - // Use per-channel lock to prevent concurrent map read/write with GetNextEnabledKey - pollingLock := GetChannelPollingLock(channelId) - pollingLock.Lock() - // 如果是多Key模式,更新缓存中的状态 - handlerMultiKeyUpdate(channelCache, usingKey, status, reason) - pollingLock.Unlock() - //CacheUpdateChannel(channelCache) - //return true - } else { - // 如果缓存渠道存在,且状态已是目标状态,直接返回 - if channelCache.Status == status { - return false - } - CacheUpdateChannelStatus(channelId, status) - } } - shouldUpdateAbilities := false - defer func() { - if shouldUpdateAbilities { - err := UpdateAbilityStatus(channelId, status == common.ChannelStatusEnabled) - if err != nil { - common.SysLog(fmt.Sprintf("failed to update ability status: channel_id=%d, error=%v", channelId, err)) - } - } - }() channel, err := GetChannelById(channelId, true) if err != nil { return false @@ -670,11 +658,32 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri channel.Status = status shouldUpdateAbilities = true } - err = channel.SaveWithoutKey() - if err != nil { + } + err = DB.Transaction(func(tx *gorm.DB) error { + if err := tx.Omit("key").Save(channel).Error; err != nil { + return err + } + if shouldUpdateAbilities { + err := tx.Model(&Ability{}). + Where("channel_id = ?", channelId). + Select("enabled"). + Update("enabled", status == common.ChannelStatusEnabled).Error + if err != nil { + return err + } + } + return nil + }) + if err != nil { + if shouldUpdateAbilities { + common.SysLog(fmt.Sprintf("failed to update channel or ability status atomically: channel_id=%d, status=%d, error=%v", channelId, status, err)) + } else { common.SysLog(fmt.Sprintf("failed to update channel status: channel_id=%d, status=%d, error=%v", channel.Id, status, err)) - return false } + return false + } + if common.MemoryCacheEnabled { + InitChannelCache() } return true } diff --git a/model/channel_cache.go b/model/channel_cache.go index c9c50357603..8f8a30e9bdb 100644 --- a/model/channel_cache.go +++ b/model/channel_cache.go @@ -5,8 +5,8 @@ import ( "fmt" "math/rand" "sort" - "strings" "sync" + "sync/atomic" "time" "github.com/QuantumNous/new-api/common" @@ -17,50 +17,71 @@ import ( var group2model2channels map[string]map[string][]int // enabled channel var channelsIDM map[int]*Channel // all channels include disabled var channelSyncLock sync.RWMutex +var channelCacheRefreshInFlight atomic.Bool +var channelCacheRefreshPending atomic.Bool +// InitChannelCache rebuilds the in-memory channel cache from database state. func InitChannelCache() { if !common.MemoryCacheEnabled { return } + channelCacheRefreshPending.Store(true) + if channelCacheRefreshInFlight.CompareAndSwap(false, true) { + runChannelCacheRefreshLoop() + return + } + for channelCacheRefreshInFlight.Load() { + time.Sleep(10 * time.Millisecond) + } +} + +func buildChannelCacheSnapshot() error { newChannelId2channel := make(map[int]*Channel) var channels []*Channel - DB.Find(&channels) + if err := DB.Find(&channels).Error; err != nil { + return fmt.Errorf("failed to sync channels from database: %w", err) + } for _, channel := range channels { newChannelId2channel[channel.Id] = channel } var abilities []*Ability - DB.Find(&abilities) - groups := make(map[string]bool) - for _, ability := range abilities { - groups[ability.Group] = true + if err := DB.Find(&abilities).Error; err != nil { + return fmt.Errorf("failed to sync abilities from database: %w", err) } newGroup2model2channels := make(map[string]map[string][]int) - for group := range groups { - newGroup2model2channels[group] = make(map[string][]int) - } - for _, channel := range channels { - if channel.Status != common.ChannelStatusEnabled { - continue // skip disabled channels + for _, ability := range abilities { + if !ability.Enabled { + continue } - groups := strings.Split(channel.Group, ",") - for _, group := range groups { - models := strings.Split(channel.Models, ",") - for _, model := range models { - if _, ok := newGroup2model2channels[group][model]; !ok { - newGroup2model2channels[group][model] = make([]int, 0) - } - newGroup2model2channels[group][model] = append(newGroup2model2channels[group][model], channel.Id) - } + channel, ok := newChannelId2channel[ability.ChannelId] + if !ok || channel.Status != common.ChannelStatusEnabled { + continue + } + if _, ok := newGroup2model2channels[ability.Group]; !ok { + newGroup2model2channels[ability.Group] = make(map[string][]int) } + newGroup2model2channels[ability.Group][ability.Model] = append( + newGroup2model2channels[ability.Group][ability.Model], + ability.ChannelId, + ) } - // sort by priority + // dedupe and sort by priority for group, model2channels := range newGroup2model2channels { for model, channels := range model2channels { - sort.Slice(channels, func(i, j int) bool { - return newChannelId2channel[channels[i]].GetPriority() > newChannelId2channel[channels[j]].GetPriority() + seen := make(map[int]struct{}, len(channels)) + deduped := make([]int, 0, len(channels)) + for _, channelId := range channels { + if _, ok := seen[channelId]; ok { + continue + } + seen[channelId] = struct{}{} + deduped = append(deduped, channelId) + } + sort.Slice(deduped, func(i, j int) bool { + return newChannelId2channel[deduped[i]].GetPriority() > newChannelId2channel[deduped[j]].GetPriority() }) - newGroup2model2channels[group][model] = channels + newGroup2model2channels[group][model] = deduped } } @@ -83,8 +104,23 @@ func InitChannelCache() { channelsIDM = newChannelId2channel channelSyncLock.Unlock() common.SysLog("channels synced from database") + return nil +} + +func runChannelCacheRefreshLoop() { + defer channelCacheRefreshInFlight.Store(false) + for { + channelCacheRefreshPending.Store(false) + if err := buildChannelCacheSnapshot(); err != nil { + common.SysError(err.Error()) + } + if !channelCacheRefreshPending.Load() { + return + } + } } +// SyncChannelCache periodically refreshes the in-memory channel cache. func SyncChannelCache(frequency int) { for { time.Sleep(time.Duration(frequency) * time.Second) @@ -93,15 +129,25 @@ func SyncChannelCache(frequency int) { } } -func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) { - // if memory cache is disabled, get channel directly from database +func requestChannelCacheRefreshAsync() { if !common.MemoryCacheEnabled { - return GetChannel(group, model, retry) + return } + channelCacheRefreshPending.Store(true) + if !channelCacheRefreshInFlight.CompareAndSwap(false, true) { + return + } + go func() { + defer func() { + if r := recover(); r != nil { + common.SysLog(fmt.Sprintf("InitChannelCache panic: %v", r)) + } + }() + runChannelCacheRefreshLoop() + }() +} - channelSyncLock.RLock() - defer channelSyncLock.RUnlock() - +func getRandomSatisfiedChannelFromCache(group string, model string, retry int) (*Channel, error, bool) { // First, try to find channels with the exact model name. channels := group2model2channels[group][model] @@ -112,14 +158,14 @@ func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, } if len(channels) == 0 { - return nil, nil + return nil, nil, false } if len(channels) == 1 { if channel, ok := channelsIDM[channels[0]]; ok { - return channel, nil + return channel, nil, true } - return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channels[0]) + return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channels[0]), true } uniquePriorities := make(map[int]bool) @@ -127,7 +173,7 @@ func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, if channel, ok := channelsIDM[channelId]; ok { uniquePriorities[int(channel.GetPriority())] = true } else { - return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channelId) + return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channelId), true } } var sortedUniquePriorities []int @@ -151,12 +197,12 @@ func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, targetChannels = append(targetChannels, channel) } } else { - return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channelId) + return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channelId), true } } if len(targetChannels) == 0 { - return nil, errors.New(fmt.Sprintf("no channel found, group: %s, model: %s, priority: %d", group, model, targetPriority)) + return nil, fmt.Errorf("no channel found, group: %s, model: %s, priority: %d", group, model, targetPriority), true } // smoothing factor and adjustment @@ -183,13 +229,45 @@ func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, for _, channel := range targetChannels { randomWeight -= channel.GetWeight()*smoothingFactor + smoothingAdjustment if randomWeight < 0 { - return channel, nil + return channel, nil, true + } + } + return nil, errors.New("channel not found"), true +} + +// GetRandomSatisfiedChannel returns a channel for the requested group/model pair. +func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) { + // if memory cache is disabled, get channel directly from database + if !common.MemoryCacheEnabled { + return GetChannel(group, model, retry) + } + + channelSyncLock.RLock() + channel, cacheErr, cacheHit := getRandomSatisfiedChannelFromCache(group, model, retry) + channelSyncLock.RUnlock() + if channel != nil || (cacheHit && cacheErr == nil) { + return channel, cacheErr + } + + fallbackChannel, fallbackErr := GetChannel(group, model, retry) + if fallbackErr != nil { + if cacheErr != nil { + return nil, cacheErr } + return nil, fallbackErr + } + if fallbackChannel != nil && fallbackChannel.Status == common.ChannelStatusEnabled { + requestChannelCacheRefreshAsync() + return fallbackChannel, nil } - // return null if no channel is not found - return nil, errors.New("channel not found") + if cacheErr != nil { + requestChannelCacheRefreshAsync() + return nil, cacheErr + } + return nil, nil } +// CacheGetChannel returns a channel from the in-memory cache when available. func CacheGetChannel(id int) (*Channel, error) { if !common.MemoryCacheEnabled { return GetChannelById(id, true) @@ -204,6 +282,7 @@ func CacheGetChannel(id int) (*Channel, error) { return c, nil } +// CacheGetChannelInfo returns cached channel info when available. func CacheGetChannelInfo(id int) (*ChannelInfo, error) { if !common.MemoryCacheEnabled { channel, err := GetChannelById(id, true) @@ -222,6 +301,7 @@ func CacheGetChannelInfo(id int) (*ChannelInfo, error) { return &c.ChannelInfo, nil } +// CacheUpdateChannelStatus mutates a cached channel status in place. func CacheUpdateChannelStatus(id int, status int) { if !common.MemoryCacheEnabled { return @@ -247,6 +327,7 @@ func CacheUpdateChannelStatus(id int, status int) { } } +// CacheUpdateChannel updates a cached channel entry in place. func CacheUpdateChannel(channel *Channel) { if !common.MemoryCacheEnabled { return @@ -256,10 +337,5 @@ func CacheUpdateChannel(channel *Channel) { if channel == nil { return } - - println("CacheUpdateChannel:", channel.Id, channel.Name, channel.Status, channel.ChannelInfo.MultiKeyPollingIndex) - - println("before:", channelsIDM[channel.Id].ChannelInfo.MultiKeyPollingIndex) channelsIDM[channel.Id] = channel - println("after :", channelsIDM[channel.Id].ChannelInfo.MultiKeyPollingIndex) } diff --git a/model/channel_cache_test.go b/model/channel_cache_test.go new file mode 100644 index 00000000000..6601324e407 --- /dev/null +++ b/model/channel_cache_test.go @@ -0,0 +1,191 @@ +package model + +import ( + "fmt" + "testing" + "time" + + "github.com/QuantumNous/new-api/common" + "github.com/stretchr/testify/require" +) + +func prepareChannelCacheTest(t *testing.T) { + t.Helper() + initCol() + require.NoError(t, DB.AutoMigrate(&Ability{})) + DB.Exec("DELETE FROM abilities") + DB.Exec("DELETE FROM channels") + + channelSyncLock.Lock() + group2model2channels = nil + channelsIDM = nil + channelSyncLock.Unlock() + channelCacheRefreshInFlight.Store(false) + channelCacheRefreshPending.Store(false) +} + +func TestGetRandomSatisfiedChannelFallsBackToDatabaseOnCacheMiss(t *testing.T) { + prepareChannelCacheTest(t) + + prevMemoryCacheEnabled := common.MemoryCacheEnabled + common.MemoryCacheEnabled = true + t.Cleanup(func() { + common.MemoryCacheEnabled = prevMemoryCacheEnabled + }) + + channel := &Channel{ + Id: 101, + Name: "fallback-channel", + Status: common.ChannelStatusEnabled, + Group: "default", + Models: "other-model", + } + require.NoError(t, DB.Create(channel).Error) + require.NoError(t, DB.Create(&Ability{ + Group: "default", + Model: "gpt-5.4", + ChannelId: channel.Id, + Enabled: true, + }).Error) + + got, err := GetRandomSatisfiedChannel("default", "gpt-5.4", 0) + require.NoError(t, err) + require.NotNil(t, got) + require.Equal(t, channel.Id, got.Id) + + require.Eventually(t, func() bool { + channelSyncLock.RLock() + defer channelSyncLock.RUnlock() + return isChannelIDInList(group2model2channels["default"]["gpt-5.4"], channel.Id) + }, time.Second, 20*time.Millisecond) +} + +func TestUpdateChannelStatusRefreshesMemoryCacheAfterEnable(t *testing.T) { + prepareChannelCacheTest(t) + + prevMemoryCacheEnabled := common.MemoryCacheEnabled + common.MemoryCacheEnabled = true + t.Cleanup(func() { + common.MemoryCacheEnabled = prevMemoryCacheEnabled + }) + + channel := &Channel{ + Id: 102, + Name: "auto-disabled-channel", + Status: common.ChannelStatusAutoDisabled, + Group: "default", + Models: "gpt-5.4", + } + require.NoError(t, DB.Create(channel).Error) + require.NoError(t, DB.Create(&Ability{ + Group: "default", + Model: "gpt-5.4", + ChannelId: channel.Id, + Enabled: false, + }).Error) + + InitChannelCache() + + got, err := GetRandomSatisfiedChannel("default", "gpt-5.4", 0) + require.NoError(t, err) + require.Nil(t, got) + + require.True(t, UpdateChannelStatus(channel.Id, "", common.ChannelStatusEnabled, "")) + + got, err = GetRandomSatisfiedChannel("default", "gpt-5.4", 0) + require.NoError(t, err) + require.NotNil(t, got) + require.Equal(t, channel.Id, got.Id) + + channelSyncLock.RLock() + defer channelSyncLock.RUnlock() + require.True(t, isChannelIDInList(group2model2channels["default"]["gpt-5.4"], channel.Id)) +} + +func TestIsChannelEnabledForGroupModelFallsBackToDatabaseOnCacheMiss(t *testing.T) { + prepareChannelCacheTest(t) + + prevMemoryCacheEnabled := common.MemoryCacheEnabled + common.MemoryCacheEnabled = true + t.Cleanup(func() { + common.MemoryCacheEnabled = prevMemoryCacheEnabled + }) + + channel := &Channel{ + Id: 103, + Name: "satisfy-fallback-channel", + Status: common.ChannelStatusEnabled, + Group: "default", + Models: "other-model", + } + require.NoError(t, DB.Create(channel).Error) + require.NoError(t, DB.Create(&Ability{ + Group: "default", + Model: "gpt-5.4-mini", + ChannelId: channel.Id, + Enabled: true, + }).Error) + + require.True(t, IsChannelEnabledForGroupModel("default", "gpt-5.4-mini", channel.Id)) +} + +func TestInitChannelCacheKeepsPreviousSnapshotOnScanError(t *testing.T) { + prepareChannelCacheTest(t) + + prevMemoryCacheEnabled := common.MemoryCacheEnabled + common.MemoryCacheEnabled = true + t.Cleanup(func() { + common.MemoryCacheEnabled = prevMemoryCacheEnabled + }) + + channel := &Channel{ + Id: 104, + Name: "stable-cache-channel", + Status: common.ChannelStatusEnabled, + Group: "default", + Models: "gpt-5.4", + } + require.NoError(t, DB.Create(channel).Error) + require.NoError(t, DB.Create(&Ability{ + Group: "default", + Model: "gpt-5.4", + ChannelId: channel.Id, + Enabled: true, + }).Error) + + InitChannelCache() + + require.NoError(t, DB.Exec( + fmt.Sprintf( + "INSERT INTO channels (id, type, %s, status, name, models, %s, channel_info, settings) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", + commonKeyCol, + commonGroupCol, + ), + 999, + 1, + "broken-key", + common.ChannelStatusEnabled, + "broken-channel", + "broken-model", + "default", + `{invalid`, + "", + ).Error) + + InitChannelCache() + + channelSyncLock.RLock() + defer channelSyncLock.RUnlock() + require.True(t, isChannelIDInList(group2model2channels["default"]["gpt-5.4"], channel.Id)) + require.Nil(t, channelsIDM[999]) +} + +func TestChannelInfoScanSupportsStringValue(t *testing.T) { + var info ChannelInfo + err := info.Scan(`{"is_multi_key":false,"multi_key_size":0,"multi_key_status_list":{},"multi_key_disabled_reason":{},"multi_key_disabled_time":{},"multi_key_polling_index":0,"multi_key_mode":"random"}`) + require.NoError(t, err) + require.False(t, info.IsMultiKey) + require.Equal(t, 0, info.MultiKeySize) + require.Equal(t, 0, info.MultiKeyPollingIndex) + require.Equal(t, "random", string(info.MultiKeyMode)) +} diff --git a/model/channel_satisfy.go b/model/channel_satisfy.go index 681f1e69bb6..35c12d91aac 100644 --- a/model/channel_satisfy.go +++ b/model/channel_satisfy.go @@ -5,6 +5,7 @@ import ( "github.com/QuantumNous/new-api/setting/ratio_setting" ) +// IsChannelEnabledForGroupModel reports whether a channel is enabled for a group/model pair. func IsChannelEnabledForGroupModel(group string, modelName string, channelID int) bool { if group == "" || modelName == "" || channelID <= 0 { return false @@ -14,22 +15,27 @@ func IsChannelEnabledForGroupModel(group string, modelName string, channelID int } channelSyncLock.RLock() - defer channelSyncLock.RUnlock() - if group2model2channels == nil { - return false + channelSyncLock.RUnlock() + return isChannelEnabledForGroupModelDB(group, modelName, channelID) } if isChannelIDInList(group2model2channels[group][modelName], channelID) { + channelSyncLock.RUnlock() return true } normalized := ratio_setting.FormatMatchingModelName(modelName) if normalized != "" && normalized != modelName { - return isChannelIDInList(group2model2channels[group][normalized], channelID) + if isChannelIDInList(group2model2channels[group][normalized], channelID) { + channelSyncLock.RUnlock() + return true + } } - return false + channelSyncLock.RUnlock() + return isChannelEnabledForGroupModelDB(group, modelName, channelID) } +// IsChannelEnabledForAnyGroupModel reports whether a channel is enabled for any group/model pair. func IsChannelEnabledForAnyGroupModel(groups []string, modelName string, channelID int) bool { if len(groups) == 0 { return false @@ -44,8 +50,10 @@ func IsChannelEnabledForAnyGroupModel(groups []string, modelName string, channel func isChannelEnabledForGroupModelDB(group string, modelName string, channelID int) bool { var count int64 + groupColumn := "abilities." + commonGroupCol err := DB.Model(&Ability{}). - Where(commonGroupCol+" = ? and model = ? and channel_id = ? and enabled = ?", group, modelName, channelID, true). + Joins("JOIN channels ON channels.id = abilities.channel_id"). + Where(groupColumn+" = ? and abilities.model = ? and abilities.channel_id = ? and abilities.enabled = ? and channels.status = ?", group, modelName, channelID, true, common.ChannelStatusEnabled). Count(&count).Error if err == nil && count > 0 { return true @@ -56,7 +64,8 @@ func isChannelEnabledForGroupModelDB(group string, modelName string, channelID i } count = 0 err = DB.Model(&Ability{}). - Where(commonGroupCol+" = ? and model = ? and channel_id = ? and enabled = ?", group, normalized, channelID, true). + Joins("JOIN channels ON channels.id = abilities.channel_id"). + Where(groupColumn+" = ? and abilities.model = ? and abilities.channel_id = ? and abilities.enabled = ? and channels.status = ?", group, normalized, channelID, true, common.ChannelStatusEnabled). Count(&count).Error return err == nil && count > 0 }