Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 44 additions & 35 deletions model/channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,25 @@ func (c ChannelInfo) Value() (driver.Value, error) {

// Scan implements sql.Scanner interface
func (c *ChannelInfo) Scan(value interface{}) error {
bytesValue, _ := value.([]byte)
return common.Unmarshal(bytesValue, c)
switch typedValue := value.(type) {
case nil:
*c = ChannelInfo{}
return nil
case []byte:
if len(typedValue) == 0 {
*c = ChannelInfo{}
return nil
}
return common.Unmarshal(typedValue, c)
case string:
if typedValue == "" {
*c = ChannelInfo{}
return nil
}
return common.Unmarshal([]byte(typedValue), c)
default:
return fmt.Errorf("unsupported channel info type: %T", value)
}
}

func (channel *Channel) GetKeys() []string {
Expand Down Expand Up @@ -608,42 +625,13 @@ func handlerMultiKeyUpdate(channel *Channel, usingKey string, status int, reason
}
}

// UpdateChannelStatus updates channel state and its ability visibility atomically.
func UpdateChannelStatus(channelId int, usingKey string, status int, reason string) bool {
if common.MemoryCacheEnabled {
channelStatusLock.Lock()
defer channelStatusLock.Unlock()

channelCache, _ := CacheGetChannel(channelId)
if channelCache == nil {
return false
}
if channelCache.ChannelInfo.IsMultiKey {
// Use per-channel lock to prevent concurrent map read/write with GetNextEnabledKey
pollingLock := GetChannelPollingLock(channelId)
pollingLock.Lock()
// 如果是多Key模式,更新缓存中的状态
handlerMultiKeyUpdate(channelCache, usingKey, status, reason)
pollingLock.Unlock()
//CacheUpdateChannel(channelCache)
//return true
} else {
// 如果缓存渠道存在,且状态已是目标状态,直接返回
if channelCache.Status == status {
return false
}
CacheUpdateChannelStatus(channelId, status)
}
}

shouldUpdateAbilities := false
defer func() {
if shouldUpdateAbilities {
err := UpdateAbilityStatus(channelId, status == common.ChannelStatusEnabled)
if err != nil {
common.SysLog(fmt.Sprintf("failed to update ability status: channel_id=%d, error=%v", channelId, err))
}
}
}()
channel, err := GetChannelById(channelId, true)
if err != nil {
return false
Expand All @@ -670,11 +658,32 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri
channel.Status = status
shouldUpdateAbilities = true
}
err = channel.SaveWithoutKey()
if err != nil {
}
err = DB.Transaction(func(tx *gorm.DB) error {
if err := tx.Omit("key").Save(channel).Error; err != nil {
return err
}
if shouldUpdateAbilities {
err := tx.Model(&Ability{}).
Where("channel_id = ?", channelId).
Select("enabled").
Update("enabled", status == common.ChannelStatusEnabled).Error
if err != nil {
return err
}
}
return nil
})
if err != nil {
if shouldUpdateAbilities {
common.SysLog(fmt.Sprintf("failed to update channel or ability status atomically: channel_id=%d, status=%d, error=%v", channelId, status, err))
} else {
common.SysLog(fmt.Sprintf("failed to update channel status: channel_id=%d, status=%d, error=%v", channel.Id, status, err))
return false
}
return false
}
if common.MemoryCacheEnabled {
InitChannelCache()
}
return true
}
Expand Down
166 changes: 121 additions & 45 deletions model/channel_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ import (
"fmt"
"math/rand"
"sort"
"strings"
"sync"
"sync/atomic"
"time"

"github.com/QuantumNous/new-api/common"
Expand All @@ -17,50 +17,71 @@ import (
var group2model2channels map[string]map[string][]int // enabled channel
var channelsIDM map[int]*Channel // all channels include disabled
var channelSyncLock sync.RWMutex
var channelCacheRefreshInFlight atomic.Bool
var channelCacheRefreshPending atomic.Bool

// InitChannelCache rebuilds the in-memory channel cache from database state.
func InitChannelCache() {
if !common.MemoryCacheEnabled {
return
}
channelCacheRefreshPending.Store(true)
if channelCacheRefreshInFlight.CompareAndSwap(false, true) {
runChannelCacheRefreshLoop()
return
}
for channelCacheRefreshInFlight.Load() {
time.Sleep(10 * time.Millisecond)
}
}

func buildChannelCacheSnapshot() error {
newChannelId2channel := make(map[int]*Channel)
var channels []*Channel
DB.Find(&channels)
if err := DB.Find(&channels).Error; err != nil {
return fmt.Errorf("failed to sync channels from database: %w", err)
}
for _, channel := range channels {
newChannelId2channel[channel.Id] = channel
}
var abilities []*Ability
DB.Find(&abilities)
groups := make(map[string]bool)
for _, ability := range abilities {
groups[ability.Group] = true
if err := DB.Find(&abilities).Error; err != nil {
return fmt.Errorf("failed to sync abilities from database: %w", err)
}
newGroup2model2channels := make(map[string]map[string][]int)
for group := range groups {
newGroup2model2channels[group] = make(map[string][]int)
}
for _, channel := range channels {
if channel.Status != common.ChannelStatusEnabled {
continue // skip disabled channels
for _, ability := range abilities {
if !ability.Enabled {
continue
}
groups := strings.Split(channel.Group, ",")
for _, group := range groups {
models := strings.Split(channel.Models, ",")
for _, model := range models {
if _, ok := newGroup2model2channels[group][model]; !ok {
newGroup2model2channels[group][model] = make([]int, 0)
}
newGroup2model2channels[group][model] = append(newGroup2model2channels[group][model], channel.Id)
}
channel, ok := newChannelId2channel[ability.ChannelId]
if !ok || channel.Status != common.ChannelStatusEnabled {
continue
}
if _, ok := newGroup2model2channels[ability.Group]; !ok {
newGroup2model2channels[ability.Group] = make(map[string][]int)
}
newGroup2model2channels[ability.Group][ability.Model] = append(
newGroup2model2channels[ability.Group][ability.Model],
ability.ChannelId,
)
}

// sort by priority
// dedupe and sort by priority
for group, model2channels := range newGroup2model2channels {
for model, channels := range model2channels {
sort.Slice(channels, func(i, j int) bool {
return newChannelId2channel[channels[i]].GetPriority() > newChannelId2channel[channels[j]].GetPriority()
seen := make(map[int]struct{}, len(channels))
deduped := make([]int, 0, len(channels))
for _, channelId := range channels {
if _, ok := seen[channelId]; ok {
continue
}
seen[channelId] = struct{}{}
deduped = append(deduped, channelId)
}
sort.Slice(deduped, func(i, j int) bool {
return newChannelId2channel[deduped[i]].GetPriority() > newChannelId2channel[deduped[j]].GetPriority()
})
newGroup2model2channels[group][model] = channels
newGroup2model2channels[group][model] = deduped
}
}

Expand All @@ -83,8 +104,23 @@ func InitChannelCache() {
channelsIDM = newChannelId2channel
channelSyncLock.Unlock()
common.SysLog("channels synced from database")
return nil
}

func runChannelCacheRefreshLoop() {
defer channelCacheRefreshInFlight.Store(false)
for {
channelCacheRefreshPending.Store(false)
if err := buildChannelCacheSnapshot(); err != nil {
common.SysError(err.Error())
}
if !channelCacheRefreshPending.Load() {
return
}
}
}

// SyncChannelCache periodically refreshes the in-memory channel cache.
func SyncChannelCache(frequency int) {
for {
time.Sleep(time.Duration(frequency) * time.Second)
Expand All @@ -93,15 +129,25 @@ func SyncChannelCache(frequency int) {
}
}

func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
// if memory cache is disabled, get channel directly from database
func requestChannelCacheRefreshAsync() {
if !common.MemoryCacheEnabled {
return GetChannel(group, model, retry)
return
}
channelCacheRefreshPending.Store(true)
if !channelCacheRefreshInFlight.CompareAndSwap(false, true) {
return
}
go func() {
defer func() {
if r := recover(); r != nil {
common.SysLog(fmt.Sprintf("InitChannelCache panic: %v", r))
}
}()
runChannelCacheRefreshLoop()
}()
}

channelSyncLock.RLock()
defer channelSyncLock.RUnlock()

func getRandomSatisfiedChannelFromCache(group string, model string, retry int) (*Channel, error, bool) {
// First, try to find channels with the exact model name.
channels := group2model2channels[group][model]

Expand All @@ -112,22 +158,22 @@ func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel,
}

if len(channels) == 0 {
return nil, nil
return nil, nil, false
}

if len(channels) == 1 {
if channel, ok := channelsIDM[channels[0]]; ok {
return channel, nil
return channel, nil, true
}
return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channels[0])
return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channels[0]), true
}

uniquePriorities := make(map[int]bool)
for _, channelId := range channels {
if channel, ok := channelsIDM[channelId]; ok {
uniquePriorities[int(channel.GetPriority())] = true
} else {
return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channelId)
return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channelId), true
}
}
var sortedUniquePriorities []int
Expand All @@ -151,12 +197,12 @@ func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel,
targetChannels = append(targetChannels, channel)
}
} else {
return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channelId)
return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channelId), true
}
}

if len(targetChannels) == 0 {
return nil, errors.New(fmt.Sprintf("no channel found, group: %s, model: %s, priority: %d", group, model, targetPriority))
return nil, fmt.Errorf("no channel found, group: %s, model: %s, priority: %d", group, model, targetPriority), true
}

// smoothing factor and adjustment
Expand All @@ -183,13 +229,45 @@ func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel,
for _, channel := range targetChannels {
randomWeight -= channel.GetWeight()*smoothingFactor + smoothingAdjustment
if randomWeight < 0 {
return channel, nil
return channel, nil, true
}
}
return nil, errors.New("channel not found"), true
}

// GetRandomSatisfiedChannel returns a channel for the requested group/model pair.
func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
// if memory cache is disabled, get channel directly from database
if !common.MemoryCacheEnabled {
return GetChannel(group, model, retry)
}

channelSyncLock.RLock()
channel, cacheErr, cacheHit := getRandomSatisfiedChannelFromCache(group, model, retry)
channelSyncLock.RUnlock()
if channel != nil || (cacheHit && cacheErr == nil) {
return channel, cacheErr
}

fallbackChannel, fallbackErr := GetChannel(group, model, retry)
if fallbackErr != nil {
if cacheErr != nil {
return nil, cacheErr
}
return nil, fallbackErr
}
if fallbackChannel != nil && fallbackChannel.Status == common.ChannelStatusEnabled {
requestChannelCacheRefreshAsync()
return fallbackChannel, nil
}
// return null if no channel is not found
return nil, errors.New("channel not found")
if cacheErr != nil {
requestChannelCacheRefreshAsync()
return nil, cacheErr
}
return nil, nil
}

// CacheGetChannel returns a channel from the in-memory cache when available.
func CacheGetChannel(id int) (*Channel, error) {
if !common.MemoryCacheEnabled {
return GetChannelById(id, true)
Expand All @@ -204,6 +282,7 @@ func CacheGetChannel(id int) (*Channel, error) {
return c, nil
}

// CacheGetChannelInfo returns cached channel info when available.
func CacheGetChannelInfo(id int) (*ChannelInfo, error) {
if !common.MemoryCacheEnabled {
channel, err := GetChannelById(id, true)
Expand All @@ -222,6 +301,7 @@ func CacheGetChannelInfo(id int) (*ChannelInfo, error) {
return &c.ChannelInfo, nil
}

// CacheUpdateChannelStatus mutates a cached channel status in place.
func CacheUpdateChannelStatus(id int, status int) {
if !common.MemoryCacheEnabled {
return
Expand All @@ -247,6 +327,7 @@ func CacheUpdateChannelStatus(id int, status int) {
}
}

// CacheUpdateChannel updates a cached channel entry in place.
func CacheUpdateChannel(channel *Channel) {
if !common.MemoryCacheEnabled {
return
Expand All @@ -256,10 +337,5 @@ func CacheUpdateChannel(channel *Channel) {
if channel == nil {
return
}

println("CacheUpdateChannel:", channel.Id, channel.Name, channel.Status, channel.ChannelInfo.MultiKeyPollingIndex)

println("before:", channelsIDM[channel.Id].ChannelInfo.MultiKeyPollingIndex)
channelsIDM[channel.Id] = channel
println("after :", channelsIDM[channel.Id].ChannelInfo.MultiKeyPollingIndex)
}
Loading