diff --git a/service/token_estimator.go b/service/token_estimator.go index 9e27269ce3d..426f29b538c 100644 --- a/service/token_estimator.go +++ b/service/token_estimator.go @@ -3,7 +3,6 @@ package service import ( "math" "strings" - "sync" "unicode" ) @@ -32,37 +31,45 @@ type multipliers struct { BasePad int // 基础起步消耗 (Start/End tokens) } -var ( - multipliersMap = map[Provider]multipliers{ - Gemini: { - Word: 1.15, Number: 2.8, CJK: 0.68, Symbol: 0.38, MathSymbol: 1.05, URLDelim: 1.2, AtSign: 2.5, Emoji: 1.08, Newline: 1.15, Space: 0.2, BasePad: 0, - }, - Claude: { - Word: 1.13, Number: 1.63, CJK: 1.21, Symbol: 0.4, MathSymbol: 4.52, URLDelim: 1.26, AtSign: 2.82, Emoji: 2.6, Newline: 0.89, Space: 0.39, BasePad: 0, - }, - OpenAI: { - Word: 1.02, Number: 1.55, CJK: 0.85, Symbol: 0.4, MathSymbol: 2.68, URLDelim: 1.0, AtSign: 2.0, Emoji: 2.12, Newline: 0.5, Space: 0.42, BasePad: 0, - }, +// 直接用常量 map,无需锁保护(只读) +var multipliersMap = map[Provider]multipliers{ + Gemini: { + Word: 1.15, Number: 2.8, CJK: 0.68, Symbol: 0.38, MathSymbol: 1.05, URLDelim: 1.2, AtSign: 2.5, Emoji: 1.08, Newline: 1.15, Space: 0.2, BasePad: 0, + }, + Claude: { + Word: 1.13, Number: 1.63, CJK: 1.21, Symbol: 0.4, MathSymbol: 4.52, URLDelim: 1.26, AtSign: 2.82, Emoji: 2.6, Newline: 0.89, Space: 0.39, BasePad: 0, + }, + OpenAI: { + Word: 1.02, Number: 1.55, CJK: 0.85, Symbol: 0.4, MathSymbol: 2.68, URLDelim: 1.0, AtSign: 2.0, Emoji: 2.12, Newline: 0.5, Space: 0.42, BasePad: 0, + }, +} + +// mathSymbolSet 用 map 做 O(1) 查找,替代每次线性扫描字符串 +var mathSymbolSet = func() map[rune]struct{} { + s := "∑∫∂√∞≤≥≠≈±×÷∈∉∋∌⊂⊃⊆⊇∪∩∧∨¬∀∃∄∅∆∇∝∟∠∡∢°′″‴⁺⁻⁼⁽⁾ⁿ₀₁₂₃₄₅₆₇₈₉₊₋₌₍₎²³¹⁴⁵⁶⁷⁸⁹⁰" + m := make(map[rune]struct{}, 64) + for _, r := range s { + m[r] = struct{}{} } - multipliersLock sync.RWMutex -) + return m +}() + +// urlDelimSet 用 [128]bool 做 ASCII 快速查找 +var urlDelimSet = func() [128]bool { + var t [128]bool + for _, r := range "/:?&=;#%" { + t[r] = true + } + return t +}() // getMultipliers 根据厂商获取权重配置 func getMultipliers(p Provider) multipliers { - multipliersLock.RLock() - defer multipliersLock.RUnlock() - - switch p { - case Gemini: - return multipliersMap[Gemini] - case Claude: - return multipliersMap[Claude] - case OpenAI: - return multipliersMap[OpenAI] - default: - // 默认兜底 (按 OpenAI 的算) - return multipliersMap[OpenAI] + if m, ok := multipliersMap[p]; ok { + return m } + // 默认兜底 (按 OpenAI 的算) + return multipliersMap[OpenAI] } // EstimateToken 计算 Token 数量 @@ -71,73 +78,99 @@ func EstimateToken(provider Provider, text string) int { var count float64 // 状态机变量 - type WordType int const ( - None WordType = iota - Latin - Number + stNone = 0 + stLatin = 1 + stNumber = 2 ) - currentWordType := None + state := stNone for _, r := range text { - // 1. 处理空格和换行符 - if unicode.IsSpace(r) { - currentWordType = None - // 换行符和制表符使用Newline权重 + // 快速路径:ASCII 字符 (覆盖绝大多数英文文本) + if r < 128 { + if r == ' ' { + state = stNone + count += m.Space + continue + } if r == '\n' || r == '\t' { + state = stNone count += m.Newline - } else { - // 普通空格使用Space权重 + continue + } + // a-z, A-Z + if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') { + if state != stLatin { + count += m.Word + state = stLatin + } + continue + } + // 0-9 + if r >= '0' && r <= '9' { + if state != stNumber { + count += m.Number + state = stNumber + } + continue + } + // 其他 ASCII + state = stNone + if r == '@' { + count += m.AtSign + } else if urlDelimSet[r] { + count += m.URLDelim + } else if r == '\r' || r == '\f' || r == '\v' { count += m.Space + } else { + count += m.Symbol } continue } - // 2. 处理 CJK (中日韩) - 按字符计费 + // 非 ASCII 路径 + + // CJK (中日韩) - 按字符计费 if isCJK(r) { - currentWordType = None + state = stNone count += m.CJK continue } - // 3. 处理Emoji - 使用专门的Emoji权重 + // Emoji if isEmoji(r) { - currentWordType = None + state = stNone count += m.Emoji continue } - // 4. 处理拉丁字母/数字 (英文单词) - if isLatinOrNumber(r) { - isNum := unicode.IsNumber(r) - newType := Latin - if isNum { - newType = Number + // 非 ASCII 字母/数字 (如带重音的拉丁字母、西里尔字母等) + if unicode.IsLetter(r) { + if state != stLatin { + count += m.Word + state = stLatin } - - // 如果之前不在单词中,或者类型发生变化(字母<->数字),则视为新token - // 注意:对于OpenAI,通常"version 3.5"会切分,"abc123xyz"有时也会切分 - // 这里简单起见,字母和数字切换时增加权重 - if currentWordType == None || currentWordType != newType { - if newType == Number { - count += m.Number - } else { - count += m.Word - } - currentWordType = newType + continue + } + if unicode.IsNumber(r) { + if state != stNumber { + count += m.Number + state = stNumber } - // 单词中间的字符不额外计费 continue } - // 5. 处理标点符号/特殊字符 - 按类型使用不同权重 - currentWordType = None + // 空白字符 (非 ASCII 空白,如全角空格) + if unicode.IsSpace(r) { + state = stNone + count += m.Space + continue + } + + // 标点符号/特殊字符 + state = stNone if isMathSymbol(r) { count += m.MathSymbol - } else if r == '@' { - count += m.AtSign - } else if isURLDelim(r) { - count += m.URLDelim } else { count += m.Symbol } @@ -147,74 +180,53 @@ func EstimateToken(provider Provider, text string) int { return int(math.Ceil(count)) + m.BasePad } -// 辅助:判断是否为 CJK 字符 +// isCJK 判断是否为CJK(中日韩)字符,基本区直接范围判断,扩展区回退到unicode.Han func isCJK(r rune) bool { - return unicode.Is(unicode.Han, r) || - (r >= 0x3040 && r <= 0x30FF) || // 日文 - (r >= 0xAC00 && r <= 0xD7A3) // 韩文 -} - -// 辅助:判断是否为单词主体 (字母或数字) -func isLatinOrNumber(r rune) bool { - return unicode.IsLetter(r) || unicode.IsNumber(r) + // CJK统一汉字基本区 (最常见,快速路径) + if r >= 0x4E00 && r <= 0x9FFF { + return true + } + // 日文平假名+片假名 + if r >= 0x3040 && r <= 0x30FF { + return true + } + // 韩文音节 + if r >= 0xAC00 && r <= 0xD7A3 { + return true + } + // CJK扩展区 (较少见,用 unicode.Han 兜底) + return unicode.Is(unicode.Han, r) } -// 辅助:判断是否为Emoji字符 +// isEmoji 判断是否为Emoji字符,覆盖常见的Emoji Unicode区块 func isEmoji(r rune) bool { - // Emoji的Unicode范围 - // 基本范围:0x1F300-0x1F9FF (Emoticons, Symbols, Pictographs) - // 补充范围:0x2600-0x26FF (Misc Symbols), 0x2700-0x27BF (Dingbats) - // 表情符号:0x1F600-0x1F64F (Emoticons) - // 其他:0x1F900-0x1F9FF (Supplemental Symbols and Pictographs) return (r >= 0x1F300 && r <= 0x1F9FF) || (r >= 0x2600 && r <= 0x26FF) || (r >= 0x2700 && r <= 0x27BF) || (r >= 0x1F600 && r <= 0x1F64F) || (r >= 0x1F900 && r <= 0x1F9FF) || - (r >= 0x1FA00 && r <= 0x1FAFF) // Symbols and Pictographs Extended-A + (r >= 0x1FA00 && r <= 0x1FAFF) } -// 辅助:判断是否为数学符号 +// isMathSymbol 判断是否为数学符号,优先检查Unicode数学区块,再查散列符号集合 func isMathSymbol(r rune) bool { - // 数学运算符和符号 - // 基本数学符号:∑ ∫ ∂ √ ∞ ≤ ≥ ≠ ≈ ± × ÷ - // 上下标数字:² ³ ¹ ⁴ ⁵ ⁶ ⁷ ⁸ ⁹ ⁰ - // 希腊字母等也常用于数学 - mathSymbols := "∑∫∂√∞≤≥≠≈±×÷∈∉∋∌⊂⊃⊆⊇∪∩∧∨¬∀∃∄∅∆∇∝∟∠∡∢°′″‴⁺⁻⁼⁽⁾ⁿ₀₁₂₃₄₅₆₇₈₉₊₋₌₍₎²³¹⁴⁵⁶⁷⁸⁹⁰" - for _, m := range mathSymbols { - if r == m { - return true - } - } - // Mathematical Operators (U+2200–U+22FF) - if r >= 0x2200 && r <= 0x22FF { + // 范围检查优先(覆盖大部分数学符号区块) + if r >= 0x2200 && r <= 0x22FF { // Mathematical Operators return true } - // Supplemental Mathematical Operators (U+2A00–U+2AFF) - if r >= 0x2A00 && r <= 0x2AFF { + if r >= 0x2A00 && r <= 0x2AFF { // Supplemental Mathematical Operators return true } - // Mathematical Alphanumeric Symbols (U+1D400–U+1D7FF) - if r >= 0x1D400 && r <= 0x1D7FF { + if r >= 0x1D400 && r <= 0x1D7FF { // Mathematical Alphanumeric Symbols return true } - return false -} - -// 辅助:判断是否为URL分隔符(tokenizer对这些优化较好) -func isURLDelim(r rune) bool { - // URL中常见的分隔符,tokenizer通常优化处理 - urlDelims := "/:?&=;#%" - for _, d := range urlDelims { - if r == d { - return true - } - } - return false + // 散落的单个数学符号 (°, ±, ×, ÷, 上下标等) + _, ok := mathSymbolSet[r] + return ok } +// EstimateTokenByModel 根据模型名称自动识别厂商并估算文本的token数量 func EstimateTokenByModel(model, text string) int { - // strings.Contains(model, "gpt-4o") if text == "" { return 0 } diff --git a/service/token_estimator_test.go b/service/token_estimator_test.go new file mode 100644 index 00000000000..4b638590092 --- /dev/null +++ b/service/token_estimator_test.go @@ -0,0 +1,131 @@ +package service + +import ( + "strings" + "testing" +) + +// Golden tests: 精确锁定每个 provider 的 token 估算结果,防止优化引入行为变更 +func TestEstimateToken_Golden(t *testing.T) { + tests := []struct { + name string + provider Provider + text string + want int + }{ + // 英文 + {"english/openai", OpenAI, "Hello world, this is a test sentence with some numbers 12345.", 17}, + {"english/claude", Claude, "Hello world, this is a test sentence with some numbers 12345.", 18}, + {"english/gemini", Gemini, "Hello world, this is a test sentence with some numbers 12345.", 18}, + // 中文 + {"chinese/openai", OpenAI, "你好世界,这是一段测试文本。", 11}, + {"chinese/claude", Claude, "你好世界,这是一段测试文本。", 16}, + {"chinese/gemini", Gemini, "你好世界,这是一段测试文本。", 9}, + // 混合文本(含 @, URL) + {"mixed/openai", OpenAI, "Hello 你好 world 世界 123 test@email.com https://example.com/path?q=1&a=2", 33}, + {"mixed/claude", Claude, "Hello 你好 world 世界 123 test@email.com https://example.com/path?q=1&a=2", 39}, + {"mixed/gemini", Gemini, "Hello 你好 world 世界 123 test@email.com https://example.com/path?q=1&a=2", 38}, + // 数学符号 + {"math/openai", OpenAI, "∑∫∂√∞ x² + y³ = z⁴", 25}, + {"math/claude", Claude, "∑∫∂√∞ x² + y³ = z⁴", 35}, + {"math/gemini", Gemini, "∑∫∂√∞ x² + y³ = z⁴", 20}, + // Emoji + {"emoji/openai", OpenAI, "Hello 😀🎉🚀 World", 10}, + {"emoji/claude", Claude, "Hello 😀🎉🚀 World", 11}, + {"emoji/gemini", Gemini, "Hello 😀🎉🚀 World", 6}, + // 空格和换行 + {"spaces_newlines/openai", OpenAI, "line1\nline2\tindented double", 10}, + {"spaces_newlines/claude", Claude, "line1\nline2\tindented double", 11}, + {"spaces_newlines/gemini", Gemini, "line1\nline2\tindented double", 13}, + // \r \f \v — 走 Space 权重(和 unicode.IsSpace 旧行为一致) + {"cr_ff_vt/openai", OpenAI, "a\rb\fc\vd", 6}, + {"cr_ff_vt/claude", Claude, "a\rb\fc\vd", 6}, + {"cr_ff_vt/gemini", Gemini, "a\rb\fc\vd", 6}, + // URL 密集 + {"url_heavy/openai", OpenAI, "https://example.com/path/to/resource?key=value&foo=bar#section", 23}, + {"url_heavy/claude", Claude, "https://example.com/path/to/resource?key=value&foo=bar#section", 27}, + {"url_heavy/gemini", Gemini, "https://example.com/path/to/resource?key=value&foo=bar#section", 27}, + // @ 符号 + {"at_sign/openai", OpenAI, "user@example.com @mention", 9}, + {"at_sign/claude", Claude, "user@example.com @mention", 11}, + {"at_sign/gemini", Gemini, "user@example.com @mention", 11}, + // 空字符串 + {"empty/openai", OpenAI, "", 0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := EstimateToken(tt.provider, tt.text) + if got != tt.want { + t.Errorf("EstimateToken(%s, %q) = %d, want %d", tt.provider, tt.text, got, tt.want) + } + }) + } +} + +func TestEstimateTokenByModel(t *testing.T) { + tests := []struct { + model string + text string + want int + }{ + {"gpt-4o", "Hello world 你好", 5}, + {"gemini-pro", "Hello world 你好", 5}, + {"claude-3-sonnet", "Hello world 你好", 6}, + {"gpt-4o", "", 0}, + } + for _, tt := range tests { + t.Run(tt.model, func(t *testing.T) { + got := EstimateTokenByModel(tt.model, tt.text) + if got != tt.want { + t.Errorf("EstimateTokenByModel(%q, %q) = %d, want %d", tt.model, tt.text, got, tt.want) + } + }) + } +} + +// --- Benchmarks --- + +var benchText = strings.Repeat("Hello world, this is a benchmark test. ", 100) + + strings.Repeat("你好世界,这是性能测试。", 50) + + strings.Repeat("https://example.com/path?q=1&a=2#frag ", 20) + + strings.Repeat("∑∫∂√∞ x²+y³=z⁴ ", 10) + + strings.Repeat("😀🎉🚀 ", 10) + +func BenchmarkEstimateToken_OpenAI(b *testing.B) { + for b.Loop() { + EstimateToken(OpenAI, benchText) + } +} + +func BenchmarkEstimateToken_Claude(b *testing.B) { + for b.Loop() { + EstimateToken(Claude, benchText) + } +} + +func BenchmarkEstimateToken_Gemini(b *testing.B) { + for b.Loop() { + EstimateToken(Gemini, benchText) + } +} + +func BenchmarkEstimateToken_PureEnglish(b *testing.B) { + text := strings.Repeat("The quick brown fox jumps over the lazy dog. ", 200) + for b.Loop() { + EstimateToken(OpenAI, text) + } +} + +func BenchmarkEstimateToken_PureChinese(b *testing.B) { + text := strings.Repeat("人工智能技术正在快速发展和广泛应用。", 200) + for b.Loop() { + EstimateToken(OpenAI, text) + } +} + +func BenchmarkEstimateTokenByModel(b *testing.B) { + for b.Loop() { + EstimateTokenByModel("gpt-4o-mini", benchText) + } +}