From 2c3164afde412f5b93fc25b39a47f1903ca50ac4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Peter=20Groenewegen=20=E2=98=81?=
Date: Mon, 17 Feb 2025 20:53:55 +0100
Subject: [PATCH 1/2] Add token counting to ITokenizer and TikTokenizer

Updated the ITokenizer interface to include two new methods for counting
tokens in a string, with options for handling special tokens. Implemented
these methods in the TikTokenizer class, adding logic for counting based on
special tokens and a maximum limit. Added unit tests in TokenizerTest to
verify the new counting functionality, including cases for special tokens
and empty strings.

Because only the count is needed, no token list is allocated. When a max
value is supplied, the method stops counting once the maximum number of
tokens is reached. This prevents needless scanning when extremely large
content is provided.
---
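A minimal usage sketch of the new API, assuming the library's
Microsoft.DeepDev namespace and an already constructed ITokenizer; the cap
of 4096 and the "<|endoftext|>" literal are illustrative values only, not
part of this change:

    using System.Collections.Generic;
    using Microsoft.DeepDev;

    public static class CountExamples
    {
        public static void Demo(ITokenizer tokenizer, string text)
        {
            // Count using the special tokens supplied to the constructor.
            int total = tokenizer.Count(text);

            // Capped count: returns at most 4096 and stops tokenizing there,
            // so a very large input is not scanned to the end.
            int capped = tokenizer.Count(text, applySpecialTokens: true, max: 4096);

            // Count with an explicit set of special tokens kept in one piece.
            int n = tokenizer.Count(text, new HashSet<string> { "<|endoftext|>" });
        }
    }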
 Tokenizer_C#/TokenizerLib/ITokenizer.cs       | 10 ++
 Tokenizer_C#/TokenizerLib/TikTokenizer.cs     | 95 ++++++++++++++++++-
 .../TokenizerTest/TikTokenizerUnitTest.cs     | 22 +++++
 3 files changed, 126 insertions(+), 1 deletion(-)

diff --git a/Tokenizer_C#/TokenizerLib/ITokenizer.cs b/Tokenizer_C#/TokenizerLib/ITokenizer.cs
index 08ff714..d4faa02 100644
--- a/Tokenizer_C#/TokenizerLib/ITokenizer.cs
+++ b/Tokenizer_C#/TokenizerLib/ITokenizer.cs
@@ -43,5 +43,15 @@ public interface ITokenizer
         /// Decode an array of integer token ids
         /// </summary>
         public string Decode(int[] tokens);
+
+        /// <summary>
+        /// Count the tokens in a string, with or without the special tokens set through the constructor.
+        /// </summary>
+        public int Count(string text, bool applySpecialTokens = true, int max = int.MaxValue);
+
+        /// <summary>
+        /// Count the tokens in a string, with a set of allowed special tokens that are not broken apart.
+        /// </summary>
+        public int Count(string text, IReadOnlyCollection<string> allowedSpecial, int max = int.MaxValue);
     }
 }
diff --git a/Tokenizer_C#/TokenizerLib/TikTokenizer.cs b/Tokenizer_C#/TokenizerLib/TikTokenizer.cs
index c0bba10..c9107db 100644
--- a/Tokenizer_C#/TokenizerLib/TikTokenizer.cs
+++ b/Tokenizer_C#/TokenizerLib/TikTokenizer.cs
@@ -602,6 +602,99 @@ public string Decode(int[] tokens)
 
             return Encoding.UTF8.GetString(decoded.ToArray());
         }
-    }
+
+        public int Count(string text, bool applySpecialTokens = true, int max = int.MaxValue)
+        {
+            if (applySpecialTokens && SpecialTokens.Count > 0)
+            {
+                return CountInternal(text, SpecialTokens, max);
+            }
+
+            return CountTokens(text, max);
+        }
+
+        public int Count(string text, IReadOnlyCollection<string> allowedSpecial, int max = int.MaxValue)
+        {
+            if (allowedSpecial is null || allowedSpecial.Count == 0)
+            {
+                return CountTokens(text, max);
+            }
+
+            return CountInternal(text, allowedSpecial, max);
+        }
+
+        private int CountInternal(string text, IReadOnlyCollection<string> allowedSpecial, int max)
+        {
+            int tokenCount = 0;
+            int start = 0;
+            while (true)
+            {
+                Match nextSpecial;
+                int end;
+                FindNextSpecialToken(text, allowedSpecial, start, out nextSpecial, out end);
+                if (end > start)
+                {
+                    tokenCount += CountTokens(text[start..end], max - tokenCount);
+                    if (tokenCount >= max)
+                    {
+                        return max;
+                    }
+                }
+
+                if (nextSpecial.Success)
+                {
+                    tokenCount++;
+                    if (tokenCount >= max)
+                    {
+                        return max;
+                    }
+                    start = nextSpecial.Index + nextSpecial.Length;
+                    if (start >= text.Length)
+                    {
+                        break;
+                    }
+                }
+                else
+                {
+                    break;
+                }
+            }
+
+            return tokenCount;
+        }
+
+        private int CountTokens(string text, int max)
+        {
+            int tokenCount = 0;
+            foreach (Match match in Regex.Matches(text))
+            {
+                var piece = match.Value;
+                if (this.Cache.Lookup(piece, out int[] tokens))
+                {
+                    tokenCount += tokens.Length;
+                }
+                else
+                {
+                    var bytes = Encoding.UTF8.GetBytes(match.Value);
+                    if (Encoder.TryGetValue(bytes, out int token))
+                    {
+                        tokenCount++;
+                    }
+                    else
+                    {
+                        var encodedTokens = BytePairEncoder.BytePairEncode(bytes, Encoder);
+                        this.Cache.Add(piece, encodedTokens.ToArray());
+                        tokenCount += encodedTokens.Count;
+                    }
+                }
+
+                if (tokenCount >= max)
+                {
+                    return max;
+                }
+            }
+
+            return tokenCount;
+        }
+    }
 }
diff --git a/Tokenizer_C#/TokenizerTest/TikTokenizerUnitTest.cs b/Tokenizer_C#/TokenizerTest/TikTokenizerUnitTest.cs
index 43b7d2e..332a4c5 100644
--- a/Tokenizer_C#/TokenizerTest/TikTokenizerUnitTest.cs
+++ b/Tokenizer_C#/TokenizerTest/TikTokenizerUnitTest.cs
@@ -304,5 +304,27 @@ public void TestEncodeR50kbase()
 
             Assert.AreEqual(text, decoded);
         }
+
+        [TestMethod]
+        public void TestCountR50kbase()
+        {
+            var text = File.ReadAllText("./testData/lib.rs.txt");
+            var count = Tokenizer_r50k_base.Count(text, new HashSet<string>());
+            Assert.AreEqual(11378, count);
+        }
+
+        [TestMethod]
+        public void TestCountR50kbaseSetMaxTokens()
+        {
+            var text = File.ReadAllText("./testData/lib.rs.txt");
+            var count = Tokenizer_r50k_base.Count(text, new HashSet<string>(), 10000);
+            Assert.AreEqual(10000, count);
+        }
+
+        [TestMethod]
+        public void TestCount0Tokens()
+        {
+            var count = Tokenizer_r50k_base.Count("", new HashSet<string>());
+            Assert.AreEqual(0, count);
+        }
     }
 }

From 26ab79b15d97b17bd87f0c1836cbb881d35e5e60 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Peter=20Groenewegen=20=E2=98=81?=
Date: Tue, 18 Feb 2025 10:10:30 +0100
Subject: [PATCH 2/2] Do not allocate unneeded arrays for the token sequences
 in the cache.

---
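In short, BytePairEncoder.BytePairEncode already produces a list, so the
cache can store that list directly instead of copying it into a fresh array
on every cache miss. A sketch of the change on the cache-miss path
(illustrative excerpt mirroring the hunks below):

    // Before: one extra O(n) array allocation per cache miss.
    this.Cache.Add(piece, encodedTokens.ToArray());

    // After: the existing List<int> is cached as-is; lookups then read
    // tokens.Count instead of tokens.Length.
    this.Cache.Add(piece, encodedTokens);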
 Tokenizer_C#/TokenizerLib/TikTokenizer.cs | 28 +++++++++++------------
 1 file changed, 14 insertions(+), 14 deletions(-)

diff --git a/Tokenizer_C#/TokenizerLib/TikTokenizer.cs b/Tokenizer_C#/TokenizerLib/TikTokenizer.cs
index c9107db..d52c7bf 100644
--- a/Tokenizer_C#/TokenizerLib/TikTokenizer.cs
+++ b/Tokenizer_C#/TokenizerLib/TikTokenizer.cs
@@ -33,7 +33,7 @@ public class TikTokenizer : ITokenizer
         /// </summary>
         public const int DefaultCacheSize = 4096;
 
-        private readonly LruCache<string, int[]> Cache;
+        private readonly LruCache<string, List<int>> Cache;
 
         public int NumOfCacheEntries => this.Cache.Count;
 
@@ -47,7 +47,7 @@ public class TikTokenizer : ITokenizer
         /// <param name="pattern">Regex pattern to break a string to be encoded</param>
        public TikTokenizer(IReadOnlyDictionary<byte[], int> encoder, IReadOnlyDictionary<string, int> specialTokensEncoder, string pattern, int cacheSize = DefaultCacheSize)
         {
-            Cache = new LruCache<string, int[]>(cacheSize);
+            Cache = new LruCache<string, List<int>>(cacheSize);
             Init(encoder, specialTokensEncoder, pattern);
         }
 
@@ -59,7 +59,7 @@ public TikTokenizer(IReadOnlyDictionary<byte[], int> encoder, IReadOnlyDictionar
         /// <param name="pattern">Regex pattern to break a string to be encoded</param>
         public TikTokenizer(Stream tikTokenBpeFileStream, IReadOnlyDictionary<string, int> specialTokensEncoder, string pattern, int cacheSize = DefaultCacheSize)
         {
-            Cache = new LruCache<string, int[]>(cacheSize);
+            Cache = new LruCache<string, List<int>>(cacheSize);
             var encoder = LoadTikTokenBpe(tikTokenBpeFileStream);
             Init(encoder, specialTokensEncoder, pattern);
         }
@@ -251,7 +251,7 @@ private void Encode(string text, List<int> tokenIds, int start, int end)
             foreach (Match match in Regex.Matches(text[start..end]))
             {
-                if (this.Cache.Lookup(match.Value, out int[] tokens))
+                if (this.Cache.Lookup(match.Value, out List<int> tokens))
                 {
                     tokenIds.AddRange(tokens);
                 }
 
@@ -267,7 +267,7 @@ private void Encode(string text, List<int> tokenIds, int start, int end)
                     {
                         var encodedTokens = BytePairEncoder.BytePairEncode(bytes, Encoder);
                         tokenIds.AddRange(encodedTokens);
-                        this.Cache.Add(match.Value, encodedTokens.ToArray());
+                        this.Cache.Add(match.Value, encodedTokens);
                     }
                 }
             }
@@ -290,9 +290,9 @@ private void Encode(string text, List<int> tokenIds, int start, int end)
             foreach (Match match in Regex.Matches(text[start..end]))
             {
                 var piece = match.Value;
-                if (this.Cache.Lookup(piece, out int[] tokens))
+                if (this.Cache.Lookup(piece, out List<int> tokens))
                 {
-                    tokenCount += tokens.Length;
+                    tokenCount += tokens.Count;
                     if (tokenCount <= maxTokenCount)
                     {
                         encodeLength += piece.Length;
@@ -323,7 +323,7 @@ private void Encode(string text, List<int> tokenIds, int start, int end)
                 else
                 {
                     var encodedTokens = BytePairEncoder.BytePairEncode(bytes, Encoder);
-                    this.Cache.Add(piece, encodedTokens.ToArray());
+                    this.Cache.Add(piece, encodedTokens);
                     tokenCount += encodedTokens.Count;
                     if (tokenCount <= maxTokenCount)
                     {
@@ -505,9 +505,9 @@ private void Encode(string text, List<int> tokenIds, int start, ref int tokenCou
             {
                 var piece = match.Value;
 
-                if (this.Cache.Lookup(match.Value, out int[] tokens))
+                if (this.Cache.Lookup(match.Value, out List<int> tokens))
                 {
-                    tokenCount += tokens.Length;
+                    tokenCount += tokens.Count;
                     encodeLength += piece.Length;
                     tokenIds.AddRange(tokens);
                     tokenCountMap[tokenCount] = encodeLength;
@@ -526,7 +526,7 @@ private void Encode(string text, List<int> tokenIds, int start, ref int tokenCou
                 else
                 {
                     var encodedTokens = BytePairEncoder.BytePairEncode(bytes, Encoder);
-                    this.Cache.Add(piece, encodedTokens.ToArray());
+                    this.Cache.Add(piece, encodedTokens);
                     tokenCount += encodedTokens.Count;
                     encodeLength += piece.Length;
                     tokenIds.AddRange(encodedTokens);
@@ -669,9 +669,9 @@ private int CountTokens(string text, int max)
             foreach (Match match in Regex.Matches(text))
             {
                 var piece = match.Value;
-                if (this.Cache.Lookup(piece, out int[] tokens))
+                if (this.Cache.Lookup(piece, out List<int> tokens))
                 {
-                    tokenCount += tokens.Length;
+                    tokenCount += tokens.Count;
                 }
                 else
                 {
@@ -683,7 +683,7 @@ private int CountTokens(string text, int max)
                     else
                     {
                         var encodedTokens = BytePairEncoder.BytePairEncode(bytes, Encoder);
-                        this.Cache.Add(piece, encodedTokens.ToArray());
+                        this.Cache.Add(piece, encodedTokens);
                         tokenCount += encodedTokens.Count;
                     }
                 }
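
A closing note on the counting semantics with allowed special tokens (sketch
only; `tokenizer` is an assumed ITokenizer instance and the "<|endoftext|>"
literal is illustrative): each matched special token contributes exactly one
token, and the spans between special tokens are byte-pair counted
independently.

    var allowed = new HashSet<string> { "<|endoftext|>" };
    // Counted as tokens("Hello world") + 1 + tokens("Goodbye"):
    int n = tokenizer.Count("Hello world<|endoftext|>Goodbye", allowed);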