From d31fe5cabbd797f347de68e475187e2f24b7636c Mon Sep 17 00:00:00 2001
From: Farhan Syah
Date: Fri, 27 Feb 2026 14:46:14 +0800
Subject: [PATCH 1/5] feat: add SentencePiece unigram tokenizer
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Implement a SentencePiece unigram tokenizer alongside the existing BPE
tokenizer, enabling support for Mistral V1/V2 and other models that use
SentencePiece rather than tiktoken-style BPE.

Core implementation in src/core/sentencepiece.rs:

- Greedy longest-match encoding with score-based tie-breaking
- BOS prepending and ▁ word boundary handling
- Lossless and lossy decode paths with BOS/EOS skipping
- SentencePieceError type for out-of-range token IDs

Public API surface:

- SentencePieceTokenizer and SentencePieceError re-exported from crate root
- PySentencePieceTokenizer PyO3 wrapper with encode, decode, decode_lossy
- bos_token_id_by_name added to pretrained API for symmetry with
  eos_token_id_by_name
- SentencePieceTokenizer registered in the Python _core module and __init__.py

Documentation updated to cover both BPE and SentencePiece APIs in the Rust
and Python reference sections of docs/api_guide.md and README.md.
---
 README.md                  |  26 +--
 docs/api_guide.md          | 101 ++++++++++-
 python/splintr/__init__.py |  19 +-
 src/core/mod.rs            |   6 +-
 src/core/pretrained.rs     |   5 +
 src/core/sentencepiece.rs  | 355 +++++++++++++++++++++++++++++++++++++
 src/lib.rs                 |  13 +-
 src/python/bindings.rs     | 107 +++++++++++
 src/python/mod.rs          |   3 +-
 9 files changed, 609 insertions(+), 26 deletions(-)
 create mode 100644 src/core/sentencepiece.rs

diff --git a/README.md b/README.md
index f58a479..06c12cf 100644
--- a/README.md
+++ b/README.md
@@ -2,7 +2,7 @@

 [![Crates.io](https://img.shields.io/crates/v/splintr.svg)](https://crates.io/crates/splintr) [![PyPI](https://img.shields.io/pypi/v/splintr-rs.svg)](https://pypi.org/project/splintr-rs/) [![License: MIT](https://img.shields.io/badge/License-MIT-blue.svg)](https://opensource.org/licenses/MIT)

-**A high-performance BPE tokenizer built with Rust with Python bindings, focused on speed, safety, and resource optimization.**
+**A high-performance tokenizer (BPE + SentencePiece) built in Rust with Python bindings, focused on speed, safety, and resource optimization.**

 ## The Problem

@@ -85,7 +85,7 @@ See the [API Guide](docs/api_guide.md) and [docs.rs](https://docs.rs/splintr) fo
 - **Compatible vocabularies** - Supports cl100k_base, o200k_base (OpenAI), Llama 3 family (Meta), DeepSeek V3 (DeepSeek), and Mistral V1/V2/V3 (Mistral AI)
 - **Streaming decoders** - Real-time LLM output display with proper UTF-8 handling ([guide](docs/api_guide.md#streaming-decoder))
 - **54 agent tokens** - Built-in support for chat, CoT reasoning, ReAct agents, tool calling, RAG citations ([docs](docs/special_tokens.md))
-- **Battle-tested algorithms** - Regexr with JIT (pure Rust), Aho-Corasick for special tokens, linked-list BPE
+- **Battle-tested algorithms** - Regexr with JIT (pure Rust), Aho-Corasick for special tokens, linked-list BPE, SentencePiece unigram

 **Cross-platform:**

@@ -219,15 +219,15 @@ See the [API Guide](docs/api_guide.md#streaming-decoder) for detailed usage, exa

 ## Supported Vocabularies

-| Vocabulary    | Used By                             | Vocabulary Size | Special Tokens  | Import Constant            |
-| ------------- | ----------------------------------- | --------------- | --------------- | -------------------------- |
-| `cl100k_base` | GPT-4, GPT-3.5-turbo                | ~100,000        | 5 + 54 agent    | `CL100K_BASE_PATTERN`      |
-| `o200k_base`  | GPT-4o                              | ~200,000        | 2 + 54 agent    | `O200K_BASE_PATTERN`       |
-| `llama3`      | Llama 3, 3.1, 3.2, 3.3 (Meta)       | ~128,000        | 11 + 54 agent   | `LLAMA3_PATTERN`           |
-| `deepseek_v3` | DeepSeek V3, DeepSeek R1            | ~128,000        | 17 + 54 agent   | `LLAMA3_PATTERN`           |
-| `mistral_v1`  | Mistral 7B v0.1/v0.2, Mixtral 8x7B  | ~32,000         | 3 + 54 agent    | `SENTENCEPIECE_PATTERN`    |
-| `mistral_v2`  | Mistral 7B v0.3, Codestral, 8x22B   | ~32,768         | 10 + 54 agent   | `SENTENCEPIECE_PATTERN`    |
-| `mistral_v3`  | Mistral NeMo, Large 2, Pixtral      | ~131,000        | 10 + 54 agent   | `MISTRAL_V3_PATTERN`       |
+| Vocabulary    | Used By                            | Vocabulary Size | Special Tokens | Import Constant         |
+| ------------- | ---------------------------------- | --------------- | -------------- | ----------------------- |
+| `cl100k_base` | GPT-4, GPT-3.5-turbo               | ~100,000        | 5 + 54 agent   | `CL100K_BASE_PATTERN`   |
+| `o200k_base`  | GPT-4o                             | ~200,000        | 2 + 54 agent   | `O200K_BASE_PATTERN`    |
+| `llama3`      | Llama 3, 3.1, 3.2, 3.3 (Meta)      | ~128,000        | 11 + 54 agent  | `LLAMA3_PATTERN`        |
+| `deepseek_v3` | DeepSeek V3, DeepSeek R1           | ~128,000        | 17 + 54 agent  | `LLAMA3_PATTERN`        |
+| `mistral_v1`  | Mistral 7B v0.1/v0.2, Mixtral 8x7B | ~32,000         | 3 + 54 agent   | `SENTENCEPIECE_PATTERN` |
+| `mistral_v2`  | Mistral 7B v0.3, Codestral, 8x22B  | ~32,768         | 10 + 54 agent  | `SENTENCEPIECE_PATTERN` |
+| `mistral_v3`  | Mistral NeMo, Large 2, Pixtral     | ~131,000        | 10 + 54 agent  | `MISTRAL_V3_PATTERN`    |

 **OpenAI standard tokens:**

@@ -279,6 +279,7 @@ Splintr implements several optimizations that make tokenization faster:

 - **Regexr with JIT compilation**: Pure Rust regex engine with SIMD acceleration
 - **Rayon parallelism**: Leverages multiple CPU cores for batch encoding
 - **Linked-list BPE algorithm**: Avoids O(N²) complexity on pathological inputs
+- **SentencePiece unigram**: Greedy longest-match with score-based tie-breaking for Mistral/Llama-style models
 - **FxHashMap**: Faster lookups than default SipHash for non-adversarial contexts
 - **Aho-Corasick for special tokens**: Fast multi-pattern matching without regex alternation
 - **LRU cache**: Avoids redundant BPE encoding of frequently seen chunks
@@ -357,6 +358,7 @@ The pre-commit hook automatically runs formatting, clippy, and tests before each

 Splintr builds upon concepts from:

 - [tiktoken](https://github.com/openai/tiktoken) - OpenAI's reference BPE tokenizer
+- [SentencePiece](https://github.com/google/sentencepiece) - Google's unsupervised text tokenizer
 - [tokenizers](https://github.com/huggingface/tokenizers) - Hugging Face's tokenization library

 The performance optimizations are informed by profiling real-world usage patterns in LLM applications.
@@ -368,7 +370,7 @@ If you use Splintr in your research, please cite:
 ```bibtex
 @software{splintr,
   author = {Farhan Syah},
-  title = {Splintr: High-Performance BPE Tokenizer},
+  title = {Splintr: High-Performance Tokenizer (BPE + SentencePiece)},
   year = {2025},
   url = {https://github.com/ml-rust/splintr}
 }
diff --git a/docs/api_guide.md b/docs/api_guide.md
index f0305e2..8754477 100644
--- a/docs/api_guide.md
+++ b/docs/api_guide.md
@@ -5,14 +5,17 @@ This guide provides comprehensive documentation for using Splintr's Python and R
 ## Table of Contents

 - [Python API Reference](#python-api-reference)
-  - [Tokenizer Class](#tokenizer-class)
+  - [Tokenizer Class](#tokenizer-class) (BPE)
     - [Encoding Methods](#encoding-methods)
     - [Decoding Methods](#decoding-methods)
     - [Cache Management](#cache-management)
+  - [SentencePiece Tokenizer Class](#sentencepiece-tokenizer-class) (Unigram)
   - [Streaming Decoder](#streaming-decoder)
     - [Regular Streaming Decoder](#regular-streaming-decoder)
     - [ByteLevel Streaming Decoder](#bytelevel-streaming-decoder)
 - [Rust API Reference](#rust-api-reference)
+  - [BPE Tokenizer](#bpe-tokenizer)
+  - [SentencePiece Tokenizer](#sentencepiece-tokenizer)
 - [Detailed Usage Examples](#detailed-usage-examples)
   - [Basic Encoding and Decoding](#basic-encoding-and-decoding)
   - [Batch Processing](#batch-processing)
@@ -156,6 +159,61 @@ Clear the LRU encoding cache. Useful if memory pressure is a concern.
 tokenizer.clear_cache()
 ```

+### SentencePiece Tokenizer Class
+
+The `SentencePieceTokenizer` class provides unigram tokenization for models using SentencePiece (e.g., loaded from GGUF files).
+
+#### Creating
+
+```python
+from splintr import SentencePieceTokenizer
+
+# Create from raw vocabulary data
+tokenizer = SentencePieceTokenizer(
+    tokens=["<unk>", "<s>", "</s>", "▁Hello", "▁world"],
+    scores=[0.0, 0.0, 0.0, -1.2, -1.5],
+    eos_token_id=2,
+    bos_token_id=1,  # optional
+)
+```
+
+#### `encode(text: str) -> list[int]`
+
+Encode text using greedy longest-match with score-based tie-breaking. Prepends BOS if configured.
+
+```python
+ids = tokenizer.encode("Hello world")
+# [1, 3, 4] (BOS + ▁Hello + ▁world)
+```
+
+#### `decode(ids: list[int]) -> str`
+
+Decode token IDs to text. Skips BOS/EOS tokens, converts ▁ back to spaces.
+
+```python
+text = tokenizer.decode([1, 3, 4])
+# "Hello world"
+```
+
+#### `decode_lossy(ids: list[int]) -> str`
+
+Decode token IDs, silently skipping any invalid (out-of-range) IDs.
+
+```python
+text = tokenizer.decode_lossy([1, 3, 999, 4])
+# "Hello world" (999 is skipped)
+```
+
+#### Properties
+
+- `vocab_size: int` — Total vocabulary size
+- `eos_token_id: int` — End-of-sequence token ID
+- `bos_token_id: int | None` — Beginning-of-sequence token ID (if configured)
+
+#### Methods
+
+- `is_eos(token_id: int) -> bool` — Check if a token is the EOS token
+
 ## Streaming Decoder

 Streaming decoders are essential for real-time LLM applications where tokens arrive one at a time. They handle the critical problem of BPE tokens not aligning with UTF-8 character boundaries.
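+
+To see the problem concretely, here is a plain-Python sketch (independent of any Splintr API): 'é' is the two UTF-8 bytes `0xC3 0xA9`, and when a token boundary falls between them, neither half decodes on its own.
+
+```python
+part1, part2 = b"caf\xc3", b"\xa9"  # a token boundary splits 'é' in half
+print(part1.decode("utf-8", errors="replace"))  # 'caf�' - broken mid-character
+print((part1 + part2).decode("utf-8"))          # 'café' - buffer first, then decode
+```
+
+The streaming decoders below do this buffering automatically, emitting text only once complete characters are available.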
@@ -286,7 +344,7 @@ Add Splintr to your `Cargo.toml`:
 splintr = "*" # or pin to a specific version
 ```

-### Basic Usage
+### BPE Tokenizer

 ```rust
 use splintr::{Tokenizer, CL100K_BASE_PATTERN};
@@ -306,19 +364,54 @@ let texts = vec!["Hello".to_string(), "World".to_string()];
 let batch_tokens = tokenizer.encode_batch(&texts);
 ```

-### Encoding Methods
+#### Encoding Methods

 - `encode(&self, text: &str) -> Vec<u32>`: Sequential encoding (optimal for texts <1MB)
 - `encode_with_special(&self, text: &str) -> Vec<u32>`: Encode with special token recognition
 - `encode_batch(&self, texts: &[String]) -> Vec<Vec<u32>>`: Parallel encoding across texts
 - `encode_rayon(&self, text: &str) -> Vec<u32>`: Parallel encoding within text (for texts >1MB)

-### Decoding Methods
+#### Decoding Methods

 - `decode(&self, tokens: &[u32]) -> Result<String, TokenizerError>`: Decode to UTF-8 string
 - `decode_bytes(&self, tokens: &[u32]) -> Vec<u8>`: Decode to raw bytes
 - `decode_lossy(&self, tokens: &[u32]) -> String`: Decode with replacement for invalid UTF-8

+### SentencePiece Tokenizer
+
+For models using SentencePiece unigram tokenization (e.g., Mistral V1/V2):
+
+```rust
+use splintr::SentencePieceTokenizer;
+
+// Create from raw vocabulary data
+let tokenizer = SentencePieceTokenizer::new(
+    tokens,  // Vec<String> — token strings indexed by ID
+    scores,  // Vec<f32> — scores for tie-breaking (empty for uniform)
+    Some(1), // Optional BOS token ID
+    2,       // EOS token ID
+)?;
+
+// Encode (prepends BOS if configured, uses ▁ word boundaries)
+let ids = tokenizer.encode("Hello world");
+
+// Decode (skips BOS/EOS, converts ▁ back to spaces)
+let text = tokenizer.decode(&ids)?;
+
+// Lossy decode (skips invalid token IDs instead of erroring)
+let text = tokenizer.decode_lossy(&ids);
+```
+
+#### Methods
+
+- `encode(&self, text: &str) -> Vec<u32>`: Greedy longest-match encoding with score-based tie-breaking
+- `decode(&self, ids: &[u32]) -> Result<String, SentencePieceError>`: Decode to UTF-8 string
+- `decode_lossy(&self, ids: &[u32]) -> String`: Decode, skipping invalid token IDs
+- `vocab_size(&self) -> usize`: Vocabulary size
+- `is_eos(&self, token_id: u32) -> bool`: Check if token is EOS
+- `eos_token_id(&self) -> u32`: Get EOS token ID
+- `bos_token_id(&self) -> Option<u32>`: Get BOS token ID
+
 ### Error Handling

 The Rust API uses `Result` types for operations that can fail:
diff --git a/python/splintr/__init__.py b/python/splintr/__init__.py
index ccc8243..dfb3b52 100644
--- a/python/splintr/__init__.py
+++ b/python/splintr/__init__.py
@@ -1,11 +1,12 @@
 """
-Splintr - Fast Rust BPE tokenizer with Python bindings
+Splintr - Fast Rust tokenizer (BPE + SentencePiece) with Python bindings

 A high-performance tokenizer featuring:
 - Regexr with JIT and SIMD (default, pure Rust)
 - Optional PCRE2 with JIT (requires pcre2 feature)
 - Rayon parallelism for multi-core encoding
 - Linked-list BPE algorithm (avoids O(N^2) on pathological inputs)
+- SentencePiece unigram with greedy longest-match and score-based tie-breaking
 - FxHashMap for fast lookups
 - Aho-Corasick for fast special token matching
 - LRU cache for frequently encoded chunks
@@ -61,6 +62,18 @@
     print(text, end="", flush=True)
     print(decoder.flush())

+SentencePiece Unigram (for GGUF models):
+    from splintr import SentencePieceTokenizer
+
+    tokenizer = SentencePieceTokenizer(
+        tokens=["<unk>", "<s>", "</s>", "▁Hello", "▁world"],
+        scores=[0.0, 0.0, 0.0, -1.2, -1.5],
+        eos_token_id=2,
+        bos_token_id=1,
+    )
+    ids = tokenizer.encode("Hello world")
+    text = tokenizer.decode(ids)
+
 Agent Tokens:
     from splintr import (
         Tokenizer,
@@ -109,6 +122,7 @@

 from ._core import (
     Tokenizer,
+    SentencePieceTokenizer,
     StreamingDecoder,
     ByteLevelStreamingDecoder,
     CL100K_BASE_PATTERN,
@@ -125,6 +139,7 @@

 __all__ = [
     "Tokenizer",
+    "SentencePieceTokenizer",
     "StreamingDecoder",
     "ByteLevelStreamingDecoder",
     "CL100K_BASE_PATTERN",
@@ -138,4 +153,4 @@
     "MISTRAL_V2_AGENT_TOKENS",
     "MISTRAL_V3_AGENT_TOKENS",
 ]
-__version__ = "0.8.0"
+__version__ = "0.9.0"
diff --git a/src/core/mod.rs b/src/core/mod.rs
index bbc9932..59dfc1a 100644
--- a/src/core/mod.rs
+++ b/src/core/mod.rs
@@ -30,6 +30,7 @@
 mod bpe;
 pub mod byte_level;
 pub mod pretrained;
+pub mod sentencepiece;
 mod streaming;
 mod tokenizer;
 mod vocab;
@@ -37,11 +38,12 @@
 pub use bpe::byte_pair_encode;
 pub use byte_level::{byte_level_decode, byte_level_decode_bytes, byte_level_encode};
 pub use pretrained::{
-    bos_token_id, cl100k_base_special_tokens, deepseek_v3_special_tokens, eos_token_id,
-    eos_token_id_by_name, from_pretrained, from_vocab, llama3_special_tokens,
+    bos_token_id, bos_token_id_by_name, cl100k_base_special_tokens, deepseek_v3_special_tokens,
+    eos_token_id, eos_token_id_by_name, from_pretrained, from_vocab, llama3_special_tokens,
     o200k_base_special_tokens, pad_token_id, pattern, special_tokens, uses_byte_level,
     PretrainedVocab,
 };
+pub use sentencepiece::{SentencePieceError, SentencePieceTokenizer};
 pub use streaming::{ByteLevelStreamingDecoder, StreamingDecoder};
 pub use tokenizer::{
     cl100k_agent_tokens, o200k_agent_tokens, Tokenizer, TokenizerError, CL100K_BASE_PATTERN,
diff --git a/src/core/pretrained.rs b/src/core/pretrained.rs
index 9044032..33d99b0 100644
--- a/src/core/pretrained.rs
+++ b/src/core/pretrained.rs
@@ -204,6 +204,11 @@ pub fn bos_token_id(vocab: PretrainedVocab) -> Option<u32> {
     }
 }

+/// Get the BOS token ID by vocabulary name string.
+pub fn bos_token_id_by_name(name: &str) -> Option<u32> {
+    PretrainedVocab::from_name(name).and_then(bos_token_id)
+}
+
 /// Get the PAD token ID for a vocabulary.
 pub fn pad_token_id(vocab: PretrainedVocab) -> Option<u32> {
     match vocab {
diff --git a/src/core/sentencepiece.rs b/src/core/sentencepiece.rs
new file mode 100644
index 0000000..cb2bf03
--- /dev/null
+++ b/src/core/sentencepiece.rs
@@ -0,0 +1,355 @@
+//! SentencePiece-compatible unigram tokenizer.
+//!
+//! This tokenizer implements greedy longest-match encoding with score-based
+//! tie-breaking, compatible with SentencePiece unigram models used by
+//! Llama, Mistral, and other models distributed in GGUF format.
+//!
+//! Unlike BPE (which merges byte pairs iteratively), unigram tokenization
+//! greedily selects the longest matching token at each position, using
+//! scores to break ties between equal-length matches.
+
+use std::collections::HashMap;
+use thiserror::Error;
+
+#[derive(Error, Debug)]
+pub enum SentencePieceError {
+    #[error("Empty vocabulary")]
+    EmptyVocab,
+    #[error("Scores length ({scores}) does not match tokens length ({tokens})")]
+    ScoreMismatch { scores: usize, tokens: usize },
+    #[error("Decoding error: token ID {0} out of range")]
+    InvalidTokenId(u32),
+}
+
+/// SentencePiece-compatible unigram tokenizer.
+///
+/// Accepts a raw vocabulary (token strings, scores, special token IDs) and
+/// performs greedy longest-match encoding with SentencePiece word boundary
+/// markers (▁ U+2581).
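+///
+/// Characters that match no vocabulary entry fall back to `<0xNN>` byte tokens
+/// when those byte tokens are present in the vocabulary.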
+///
+/// # Example
+///
+/// ```
+/// use splintr::SentencePieceTokenizer;
+///
+/// let tokens = vec!["▁Hello".to_string(), "▁world".to_string(), "H".to_string()];
+/// let scores = vec![0.0; 3];
+/// let tok = SentencePieceTokenizer::new(tokens, scores, None, 2).unwrap();
+/// ```
+pub struct SentencePieceTokenizer {
+    /// Token string -> ID mapping
+    token_to_id: HashMap<String, u32>,
+    /// ID -> Token string mapping
+    id_to_token: Vec<String>,
+    /// Scores for each token (used for tie-breaking)
+    scores: Vec<f32>,
+    /// BOS token ID
+    bos_token_id: Option<u32>,
+    /// EOS token ID
+    eos_token_id: u32,
+}
+
+impl SentencePieceTokenizer {
+    /// Create a tokenizer from raw vocabulary data.
+    ///
+    /// # Arguments
+    /// * `tokens` - Token strings, indexed by token ID
+    /// * `scores` - Score per token (for tie-breaking). If empty, defaults to all zeros.
+    /// * `bos_token_id` - Optional beginning-of-sequence token ID
+    /// * `eos_token_id` - End-of-sequence token ID
+    pub fn new(
+        tokens: Vec<String>,
+        scores: Vec<f32>,
+        bos_token_id: Option<u32>,
+        eos_token_id: u32,
+    ) -> Result<Self, SentencePieceError> {
+        if tokens.is_empty() {
+            return Err(SentencePieceError::EmptyVocab);
+        }
+
+        let scores = if scores.is_empty() {
+            vec![0.0; tokens.len()]
+        } else if scores.len() != tokens.len() {
+            return Err(SentencePieceError::ScoreMismatch {
+                scores: scores.len(),
+                tokens: tokens.len(),
+            });
+        } else {
+            scores
+        };
+
+        let mut token_to_id = HashMap::with_capacity(tokens.len());
+        for (id, token) in tokens.iter().enumerate() {
+            token_to_id.insert(token.clone(), id as u32);
+        }
+
+        Ok(Self {
+            token_to_id,
+            id_to_token: tokens,
+            scores,
+            bos_token_id,
+            eos_token_id,
+        })
+    }
+
+    /// Encode text to token IDs using greedy longest-match.
+    ///
+    /// Prepends BOS token if configured. Replaces spaces with ▁ (U+2581)
+    /// following the SentencePiece convention.
+    pub fn encode(&self, text: &str) -> Vec<u32> {
+        let mut tokens = Vec::new();
+
+        if let Some(bos_id) = self.bos_token_id {
+            tokens.push(bos_id);
+        }
+
+        // SentencePiece: prepend ▁ and replace spaces with ▁
+        let processed = format!("▁{}", text.replace(' ', "▁"));
+        let chars: Vec<char> = processed.chars().collect();
+        let mut pos = 0;
+        let mut substr_buf = String::with_capacity(256 * 4);
+
+        while pos < chars.len() {
+            let mut best_len = 0;
+            let mut best_id = None;
+            let mut best_score = f32::NEG_INFINITY;
+
+            substr_buf.clear();
+            for end in (pos + 1)..=chars.len().min(pos + 256) {
+                substr_buf.push(chars[end - 1]);
+                if let Some(&id) = self.token_to_id.get(&substr_buf) {
+                    let score = self.scores.get(id as usize).copied().unwrap_or(0.0);
+                    let len = end - pos;
+                    if len > best_len || (len == best_len && score > best_score) {
+                        best_len = len;
+                        best_id = Some(id);
+                        best_score = score;
+                    }
+                }
+            }
+
+            if let Some(id) = best_id {
+                tokens.push(id);
+                pos += best_len;
+            } else {
+                let c = chars[pos];
+                let byte_tokens = self.encode_char_as_bytes(c);
+                if !byte_tokens.is_empty() {
+                    tokens.extend(byte_tokens);
+                }
+                pos += 1;
+            }
+        }
+
+        tokens
+    }
+
+    /// Encode a character as individual byte tokens using `<0xNN>` format.
+    ///
+    /// Each UTF-8 byte of the character is looked up as a token (e.g., `<0xFF>`).
+    /// Bytes not present in the vocabulary are silently skipped.
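+    /// For example, 'é' (UTF-8 bytes 0xC3 0xA9) is emitted as `<0xC3>` `<0xA9>`
+    /// when those byte tokens exist in the vocabulary.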
+    fn encode_char_as_bytes(&self, c: char) -> Vec<u32> {
+        let mut result = Vec::new();
+        let mut buf = [0u8; 4];
+        let bytes = c.encode_utf8(&mut buf);
+
+        for b in bytes.as_bytes() {
+            let byte_token = format!("<0x{:02X}>", b);
+            if let Some(&id) = self.token_to_id.get(&byte_token) {
+                result.push(id);
+            }
+        }
+
+        result
+    }
+
+    /// Decode token IDs to text.
+    ///
+    /// Skips BOS/EOS tokens and converts ▁ back to spaces.
+    pub fn decode(&self, ids: &[u32]) -> Result<String, SentencePieceError> {
+        let mut result = String::new();
+
+        for &id in ids {
+            let token = self
+                .id_to_token
+                .get(id as usize)
+                .ok_or(SentencePieceError::InvalidTokenId(id))?;
+
+            if Some(id) == self.bos_token_id || id == self.eos_token_id {
+                continue;
+            }
+            result.push_str(&token.replace('▁', " "));
+        }
+
+        // Remove leading space (artifact of SentencePiece encoding)
+        if result.starts_with(' ') {
+            result.remove(0);
+        }
+
+        Ok(result)
+    }
+
+    /// Decode token IDs to text, skipping invalid IDs.
+    pub fn decode_lossy(&self, ids: &[u32]) -> String {
+        let mut result = String::new();
+
+        for &id in ids {
+            if let Some(token) = self.id_to_token.get(id as usize) {
+                if Some(id) == self.bos_token_id || id == self.eos_token_id {
+                    continue;
+                }
+                result.push_str(&token.replace('▁', " "));
+            }
+        }
+
+        if result.starts_with(' ') {
+            result.remove(0);
+        }
+
+        result
+    }
+
+    /// Check if a token is the EOS token.
+    pub fn is_eos(&self, token_id: u32) -> bool {
+        token_id == self.eos_token_id
+    }
+
+    /// Get vocabulary size.
+    pub fn vocab_size(&self) -> usize {
+        self.id_to_token.len()
+    }
+
+    /// Get EOS token ID.
+    pub fn eos_token_id(&self) -> u32 {
+        self.eos_token_id
+    }
+
+    /// Get BOS token ID.
+    pub fn bos_token_id(&self) -> Option<u32> {
+        self.bos_token_id
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    fn make_tokenizer() -> SentencePieceTokenizer {
+        // Minimal vocab: specials, ▁Hello, ▁world, ▁, and single-character tokens
+        let tokens = vec![
+            "<unk>".to_string(),  // 0
+            "<s>".to_string(),    // 1 (BOS)
+            "</s>".to_string(),   // 2 (EOS)
+            "▁Hello".to_string(), // 3
+            "▁world".to_string(), // 4
+            "▁".to_string(),      // 5
+            "H".to_string(),      // 6
+            "e".to_string(),      // 7
+            "l".to_string(),      // 8
+            "o".to_string(),      // 9
+        ];
+        let scores = vec![0.0; tokens.len()];
+        SentencePieceTokenizer::new(tokens, scores, Some(1), 2).unwrap()
+    }
+
+    #[test]
+    fn test_encode_basic() {
+        let tok = make_tokenizer();
+        let ids = tok.encode("Hello world");
+        // BOS(1), ▁Hello(3), ▁world(4)
+        assert_eq!(ids, vec![1, 3, 4]);
+    }
+
+    #[test]
+    fn test_decode_basic() {
+        let tok = make_tokenizer();
+        let text = tok.decode(&[1, 3, 4]).unwrap();
+        assert_eq!(text, "Hello world");
+    }
+
+    #[test]
+    fn test_decode_skips_bos_eos() {
+        let tok = make_tokenizer();
+        let text = tok.decode(&[1, 3, 2]).unwrap();
+        assert_eq!(text, "Hello");
+    }
+
+    #[test]
+    fn test_roundtrip() {
+        let tok = make_tokenizer();
+        let ids = tok.encode("Hello world");
+        let text = tok.decode(&ids).unwrap();
+        assert_eq!(text, "Hello world");
+    }
+
+    #[test]
+    fn test_vocab_size() {
+        let tok = make_tokenizer();
+        assert_eq!(tok.vocab_size(), 10);
+    }
+
+    #[test]
+    fn test_is_eos() {
+        let tok = make_tokenizer();
+        assert!(tok.is_eos(2));
+        assert!(!tok.is_eos(1));
+    }
+
+    #[test]
+    fn test_empty_scores_defaults() {
+        let tokens = vec!["▁a".to_string(), "▁b".to_string()];
+        let tok = SentencePieceTokenizer::new(tokens, vec![], None, 1).unwrap();
+        assert_eq!(tok.vocab_size(), 2);
+    }
+
+    #[test]
+    fn test_empty_vocab_errors() {
+        let result = SentencePieceTokenizer::new(vec![], vec![], None, 0);
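+        // Construction must fail with SentencePieceError::EmptyVocab.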
+        assert!(result.is_err());
+    }
+
+    #[test]
+    fn test_score_mismatch_errors() {
+        let tokens = vec!["a".to_string()];
+        let result = SentencePieceTokenizer::new(tokens, vec![1.0, 2.0], None, 0);
+        assert!(result.is_err());
+    }
+
+    #[test]
+    fn test_encode_empty_string() {
+        let tok = make_tokenizer();
+        let ids = tok.encode("");
+        // BOS + ▁ (SentencePiece always prepends ▁, which matches token 5)
+        assert_eq!(ids, vec![1, 5]);
+    }
+
+    #[test]
+    fn test_encode_empty_string_no_bos() {
+        let tokens = vec!["▁a".to_string(), "▁b".to_string()];
+        let tok = SentencePieceTokenizer::new(tokens, vec![], None, 1).unwrap();
+        let ids = tok.encode("");
+        assert!(ids.is_empty());
+    }
+
+    #[test]
+    fn test_decode_lossy_skips_invalid_tokens() {
+        let tok = make_tokenizer();
+        // 999 is out of range, should be skipped
+        let text = tok.decode_lossy(&[1, 3, 999, 4]);
+        assert_eq!(text, "Hello world");
+    }
+
+    #[test]
+    fn test_decode_lossy_all_invalid() {
+        let tok = make_tokenizer();
+        let text = tok.decode_lossy(&[999, 1000, 1001]);
+        assert_eq!(text, "");
+    }
+
+    #[test]
+    fn test_decode_invalid_token_id_errors() {
+        let tok = make_tokenizer();
+        let result = tok.decode(&[1, 999]);
+        assert!(result.is_err());
+    }
+}
diff --git a/src/lib.rs b/src/lib.rs
index ae06026..60026e9 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -3,26 +3,28 @@ pub mod core;
 mod python;

 pub use core::{
-    ByteLevelStreamingDecoder, StreamingDecoder, Tokenizer, TokenizerError, CL100K_BASE_PATTERN,
-    LLAMA3_PATTERN, O200K_BASE_PATTERN, SENTENCEPIECE_PATTERN,
+    ByteLevelStreamingDecoder, SentencePieceError, SentencePieceTokenizer, StreamingDecoder,
+    Tokenizer, TokenizerError, CL100K_BASE_PATTERN, LLAMA3_PATTERN, O200K_BASE_PATTERN,
+    SENTENCEPIECE_PATTERN,
 };

 // Re-export pretrained tokenizer API
 pub use core::pretrained;
 pub use core::{
-    bos_token_id, cl100k_base_special_tokens, deepseek_v3_special_tokens, eos_token_id,
-    eos_token_id_by_name, from_pretrained, from_vocab, llama3_special_tokens,
+    bos_token_id, bos_token_id_by_name, cl100k_base_special_tokens, deepseek_v3_special_tokens,
+    eos_token_id, eos_token_id_by_name, from_pretrained, from_vocab, llama3_special_tokens,
     o200k_base_special_tokens, pad_token_id, pattern, special_tokens, uses_byte_level,
     PretrainedVocab,
 };

-/// Splintr - Fast Rust BPE tokenizer with Python bindings
+/// Splintr - Fast Rust tokenizer (BPE + SentencePiece) with Python bindings
 ///
 /// A high-performance tokenizer featuring:
 /// - Regexr with JIT and SIMD (default, pure Rust)
 /// - Optional PCRE2 with JIT (requires `pcre2` feature)
 /// - Rayon parallelism for multi-core encoding
 /// - Linked-list BPE algorithm (avoids O(N²) on pathological inputs)
+/// - SentencePiece unigram with greedy longest-match and score-based tie-breaking
 /// - FxHashMap for fast lookups
 /// - Aho-Corasick for fast special token matching
 /// - LRU cache for frequently encoded chunks
@@ -35,6 +37,7 @@ use pyo3::prelude::*;

 #[pymodule]
 fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_class::<python::PyTokenizer>()?;
+    m.add_class::<python::PySentencePieceTokenizer>()?;
     m.add_class::<python::PyStreamingDecoder>()?;
     m.add_class::<python::PyByteLevelStreamingDecoder>()?;
     // Register all agent token classes (auto-generated from scripts/generate_agent_tokens.py)
diff --git a/src/python/bindings.rs b/src/python/bindings.rs
index 2ad111e..3e1723b 100644
--- a/src/python/bindings.rs
+++ b/src/python/bindings.rs
@@ -39,6 +39,8 @@ use pyo3::prelude::*;
 use pyo3::types::PyDict;
 use rustc_hash::FxHashMap;

+use crate::core::SentencePieceTokenizer;
+
 use crate::core::pretrained::{
     cl100k_base_special_tokens, deepseek_v3_special_tokens,
     llama3_special_tokens, mistral_v1_special_tokens, mistral_v2_special_tokens,
     mistral_v3_special_tokens,
@@ -445,6 +447,111 @@ impl PyTokenizer {
     }
 }

+/// Python wrapper for the SentencePiece unigram tokenizer.
+///
+/// For models using SentencePiece unigram tokenization (e.g., loaded from GGUF).
+/// Uses greedy longest-match encoding with score-based tie-breaking.
+#[pyclass(name = "SentencePieceTokenizer")]
+pub struct PySentencePieceTokenizer {
+    inner: SentencePieceTokenizer,
+}
+
+#[pymethods]
+impl PySentencePieceTokenizer {
+    /// Create a new SentencePiece unigram tokenizer.
+    ///
+    /// Args:
+    ///     tokens: List of token strings, indexed by token ID
+    ///     scores: Scores per token for tie-breaking (empty list defaults to all zeros)
+    ///     eos_token_id: End-of-sequence token ID
+    ///     bos_token_id: Optional beginning-of-sequence token ID
+    #[new]
+    #[pyo3(signature = (tokens, scores, eos_token_id, bos_token_id=None))]
+    fn new(
+        tokens: Vec<String>,
+        scores: Vec<f32>,
+        eos_token_id: u32,
+        bos_token_id: Option<u32>,
+    ) -> PyResult<Self> {
+        let inner = SentencePieceTokenizer::new(tokens, scores, bos_token_id, eos_token_id)
+            .map_err(|e| PyValueError::new_err(e.to_string()))?;
+        Ok(Self { inner })
+    }
+
+    /// Encode text to token IDs using greedy longest-match.
+    ///
+    /// Prepends BOS token if configured. Replaces spaces with ▁ (U+2581)
+    /// following the SentencePiece convention.
+    ///
+    /// Args:
+    ///     text: Input text to encode
+    ///
+    /// Returns:
+    ///     List of token IDs
+    fn encode(&self, text: &str) -> Vec<u32> {
+        self.inner.encode(text)
+    }
+
+    /// Decode token IDs to text.
+    ///
+    /// Skips BOS/EOS tokens and converts ▁ back to spaces.
+    ///
+    /// Args:
+    ///     ids: List of token IDs
+    ///
+    /// Returns:
+    ///     Decoded string
+    ///
+    /// Raises:
+    ///     ValueError: If a token ID is out of range
+    fn decode(&self, ids: Vec<u32>) -> PyResult<String> {
+        self.inner
+            .decode(&ids)
+            .map_err(|e| PyValueError::new_err(e.to_string()))
+    }
+
+    /// Decode token IDs to text, skipping invalid IDs.
+    ///
+    /// Args:
+    ///     ids: List of token IDs
+    ///
+    /// Returns:
+    ///     Decoded string (invalid token IDs are silently skipped)
+    fn decode_lossy(&self, ids: Vec<u32>) -> String {
+        self.inner.decode_lossy(&ids)
+    }
+
+    /// Get vocabulary size.
+    #[getter]
+    fn vocab_size(&self) -> usize {
+        self.inner.vocab_size()
+    }
+
+    /// Check if a token is the EOS token.
+    fn is_eos(&self, token_id: u32) -> bool {
+        self.inner.is_eos(token_id)
+    }
+
+    /// Get the EOS token ID.
+    #[getter]
+    fn eos_token_id(&self) -> u32 {
+        self.inner.eos_token_id()
+    }
+
+    /// Get the BOS token ID (if configured).
+    #[getter]
+    fn bos_token_id(&self) -> Option<u32> {
+        self.inner.bos_token_id()
+    }
+
+    fn __repr__(&self) -> String {
+        format!(
+            "SentencePieceTokenizer(vocab_size={})",
+            self.inner.vocab_size()
+        )
+    }
+}
+
 /// Parse special tokens from Python dict to FxHashMap.
 fn parse_special_tokens(
     special_tokens: Option<&Bound<'_, PyDict>>,
diff --git a/src/python/mod.rs b/src/python/mod.rs
index acc0c55..aba4ef1 100644
--- a/src/python/mod.rs
+++ b/src/python/mod.rs
@@ -1,5 +1,6 @@
 mod bindings;

 pub use bindings::{
-    register_agent_tokens, PyByteLevelStreamingDecoder, PyStreamingDecoder, PyTokenizer,
+    register_agent_tokens, PyByteLevelStreamingDecoder, PySentencePieceTokenizer,
+    PyStreamingDecoder, PyTokenizer,
 };

From 04e85cc44ecc9190b3d08814d62891dab1687c7f Mon Sep 17 00:00:00 2001
From: Farhan Syah
Date: Fri, 27 Feb 2026 14:46:36 +0800
Subject: [PATCH 2/5] chore: bump version to 0.9.0

Update version to 0.9.0 across Cargo.toml, pyproject.toml, .version, and
uv.lock. Also update the crate description and keywords to reflect the
addition of SentencePiece alongside BPE.
---
 .version       | 2 +-
 Cargo.toml     | 6 +++---
 pyproject.toml | 6 +++---
 uv.lock        | 2 +-
 4 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/.version b/.version
index a3df0a6..ac39a10 100644
--- a/.version
+++ b/.version
@@ -1 +1 @@
-0.8.0
+0.9.0
diff --git a/Cargo.toml b/Cargo.toml
index d32af9e..64ef451 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,13 +1,13 @@
 [package]
 name = "splintr"
-version = "0.8.0"
+version = "0.9.0"
 edition = "2021"
-description = "Fast Rust BPE tokenizer with Python bindings"
+description = "Fast Rust tokenizer (BPE + SentencePiece) with Python bindings"
 license = "MIT"
 repository = "https://github.com/ml-rust/splintr"
 homepage = "https://github.com/ml-rust/splintr"
 readme = "README.md"
-keywords = ["tokenizer", "bpe", "tiktoken", "gpt", "llm"]
+keywords = ["tokenizer", "bpe", "sentencepiece", "tiktoken", "llm"]
 categories = ["text-processing", "encoding"]

 [lib]
diff --git a/pyproject.toml b/pyproject.toml
index b153fa6..fbece2d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,12 +4,12 @@ build-backend = "maturin"

 [project]
 name = "splintr-rs"
-version = "0.8.0"
-description = "Fast Rust BPE tokenizer with Python bindings"
+version = "0.9.0"
+description = "Fast Rust tokenizer (BPE + SentencePiece) with Python bindings"
 readme = "README.md"
 license = { text = "MIT" }
 requires-python = ">=3.8"
-keywords = ["tokenizer", "bpe", "tiktoken", "gpt", "llm"]
+keywords = ["tokenizer", "bpe", "sentencepiece", "tiktoken", "llm"]
 authors = [{ name = "Farhan" }]
 classifiers = [
     "Development Status :: 4 - Beta",
diff --git a/uv.lock b/uv.lock
index 19327e8..8774932 100644
--- a/uv.lock
+++ b/uv.lock
@@ -3,5 +3,5 @@ requires-python = ">=3.8"

 [[package]]
 name = "splintr-rs"
-version = "0.8.0"
+version = "0.9.0"
 source = { editable = "." }

From 8a9872cee46808c2e89a7b59f51d7b8b57b351d0 Mon Sep 17 00:00:00 2001
From: Farhan Syah
Date: Fri, 27 Feb 2026 15:49:20 +0800
Subject: [PATCH 3/5] chore: remove pcre2 from default features

Make regexr the sole default backend. The pcre2 feature remains available
as an explicit opt-in for benchmarking purposes.
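
For example, benchmark comparisons can opt back in with:

    cargo bench --features pcre2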
---
 Cargo.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Cargo.toml b/Cargo.toml
index 64ef451..af2898b 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -15,7 +15,7 @@ name = "splintr"
 crate-type = ["cdylib", "rlib"]

 [features]
-default = ["pcre2"]
+default = []
 python = ["dep:pyo3"]
 pcre2 = ["dep:pcre2"]

From 5bd68c52a72aa3a6ad3c4ba5c7f29c5abebe0f38 Mon Sep 17 00:00:00 2001
From: Farhan Syah
Date: Thu, 12 Mar 2026 05:48:15 +0800
Subject: [PATCH 4/5] chore: make rayon and regexr JIT optional features

Introduce `rayon` and `regexr-jit` as named optional features, both enabled
by default, so that WASM and embedded targets can opt out via
`--no-default-features`. Add a `wasm` feature as a no-op marker that
documents the intended build profile.

Move regexr's jit/simd flags under the new `regexr-jit` feature rather than
hardcoding them unconditionally.
---
 Cargo.toml | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index af2898b..464c56c 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -15,15 +15,18 @@ name = "splintr"
 crate-type = ["cdylib", "rlib"]

 [features]
-default = []
+default = ["rayon", "regexr-jit"]
 python = ["dep:pyo3"]
 pcre2 = ["dep:pcre2"]
+rayon = ["dep:rayon"]
+regexr-jit = ["regexr/jit", "regexr/simd"]
+wasm = [] # no-op marker: build with --no-default-features (no rayon, scalar regex)

 [dependencies]
 # PCRE2 regex with JIT support (optional, for benchmarking)
 pcre2 = { version = "0.2", optional = true }
 # Rayon for internal parallelism
-rayon = "1.10"
+rayon = { version = "1.10", optional = true }
 # Fast hashing (FxHashMap)
 rustc-hash = "2.0"
 # Error handling
@@ -37,7 +40,7 @@
 aho-corasick = "1.1"
 # LRU cache for frequent token sequences
 lru = "0.16"
 # regexr regex engine (default backend)
-regexr = { version = "0.1.0-beta.5", features = ["jit", "simd"] }
+regexr = { version = "0.1.0-beta.5", default-features = false }

 [dev-dependencies]
 # PCRE2 for benchmarking comparisons

From 10c8cf51a1675a21ba039c4f3314b524f576e4d9 Mon Sep 17 00:00:00 2001
From: Farhan Syah
Date: Thu, 12 Mar 2026 05:49:44 +0800
Subject: [PATCH 5/5] fix: handle byte-fallback tokens in SentencePiece decode
 and gate rayon usage
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Decode `<0xNN>` byte-fallback tokens by accumulating raw bytes and
converting via `from_utf8_lossy`, which correctly reconstructs multi-byte
UTF-8 sequences (e.g. 'é' encoded as <0xC3><0xA9>). Previously these tokens
were passed through as literal strings, producing garbled output for
non-ASCII text.

Restrict the leading-space stripping to multi-token sequences so that
single-token streaming decode does not lose meaningful word boundaries.

Gate all `rayon` call sites behind `#[cfg(feature = "rayon")]` with
sequential fallbacks, enabling the tokenizer to build for WASM and other
single-threaded targets without changes to call sites.
---
 src/core/sentencepiece.rs | 93 ++++++++++++++++++++++++++++++++++-----
 src/core/tokenizer.rs     | 80 ++++++++++++++++++++++++++-------
 2 files changed, 146 insertions(+), 27 deletions(-)

diff --git a/src/core/sentencepiece.rs b/src/core/sentencepiece.rs
index cb2bf03..f731290 100644
--- a/src/core/sentencepiece.rs
+++ b/src/core/sentencepiece.rs
@@ -167,7 +167,7 @@ impl SentencePieceTokenizer {
     ///
     /// Skips BOS/EOS tokens and converts ▁ back to spaces.
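+    ///
+    /// Byte-fallback tokens such as `<0xC3>` contribute raw bytes; the bytes are
+    /// assembled and decoded once via `String::from_utf8_lossy`, so multi-byte
+    /// characters split across fallback tokens round-trip correctly.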
     pub fn decode(&self, ids: &[u32]) -> Result<String, SentencePieceError> {
-        let mut result = String::new();
+        let mut bytes = Vec::new();

         for &id in ids {
             let token = self
@@ -178,34 +178,52 @@
             if Some(id) == self.bos_token_id || id == self.eos_token_id {
                 continue;
             }
-            result.push_str(&token.replace('▁', " "));
-        }

-        // Remove leading space (artifact of SentencePiece encoding)
-        if result.starts_with(' ') {
-            result.remove(0);
+            if let Some(byte_val) = parse_byte_fallback(token) {
+                bytes.push(byte_val);
+            } else {
+                let decoded = token.replace('▁', " ");
+                bytes.extend_from_slice(decoded.as_bytes());
+            }
         }

+        let result = String::from_utf8_lossy(&bytes).into_owned();
+
+        // Remove leading space only when decoding a full sequence (the leading ▁
+        // is a SentencePiece artifact). For single-token decode (streaming), the
+        // leading space is meaningful word separation — don't strip it.
+        if ids.len() > 1 {
+            if let Some(stripped) = result.strip_prefix(' ') {
+                return Ok(stripped.to_string());
+            }
+        }
         Ok(result)
     }

     /// Decode token IDs to text, skipping invalid IDs.
     pub fn decode_lossy(&self, ids: &[u32]) -> String {
-        let mut result = String::new();
+        let mut bytes = Vec::new();

         for &id in ids {
             if let Some(token) = self.id_to_token.get(id as usize) {
                 if Some(id) == self.bos_token_id || id == self.eos_token_id {
                     continue;
                 }
-                result.push_str(&token.replace('▁', " "));
+                if let Some(byte_val) = parse_byte_fallback(token) {
+                    bytes.push(byte_val);
+                } else {
+                    let decoded = token.replace('▁', " ");
+                    bytes.extend_from_slice(decoded.as_bytes());
+                }
             }
         }

-        if result.starts_with(' ') {
-            result.remove(0);
+        let result = String::from_utf8_lossy(&bytes).into_owned();
+        if ids.len() > 1 {
+            if let Some(stripped) = result.strip_prefix(' ') {
+                return stripped.to_string();
+            }
         }
         result
     }
@@ -230,6 +248,16 @@
     }
 }

+/// Parse a byte-fallback token like `<0x0A>` into its byte value.
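+///
+/// Returns `None` unless the token is exactly `<0x` followed by two hex digits
+/// and a closing `>`.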
+fn parse_byte_fallback(token: &str) -> Option<u8> {
+    let inner = token.strip_prefix("<0x")?.strip_suffix('>')?;
+    if inner.len() == 2 {
+        u8::from_str_radix(inner, 16).ok()
+    } else {
+        None
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -352,4 +380,47 @@ mod tests {
         let result = tok.decode(&[1, 999]);
         assert!(result.is_err());
     }
+
+    #[test]
+    fn test_parse_byte_fallback_valid() {
+        assert_eq!(parse_byte_fallback("<0x0A>"), Some(0x0A));
+        assert_eq!(parse_byte_fallback("<0xFF>"), Some(0xFF));
+        assert_eq!(parse_byte_fallback("<0x00>"), Some(0x00));
+        assert_eq!(parse_byte_fallback("<0x7F>"), Some(0x7F));
+        // Lowercase hex
+        assert_eq!(parse_byte_fallback("<0xab>"), Some(0xAB));
+    }
+
+    #[test]
+    fn test_parse_byte_fallback_invalid() {
+        assert_eq!(parse_byte_fallback("<0xZZ>"), None);
+        assert_eq!(parse_byte_fallback("<0x1>"), None); // single hex digit
+        assert_eq!(parse_byte_fallback("<0x123>"), None); // three hex digits
+        assert_eq!(parse_byte_fallback("0x0A"), None); // missing angle brackets
+        assert_eq!(parse_byte_fallback("<0x0A"), None); // missing closing bracket
+        assert_eq!(parse_byte_fallback("0x0A>"), None); // missing opening prefix
+        assert_eq!(parse_byte_fallback(""), None);
+        assert_eq!(parse_byte_fallback("hello"), None);
+        assert_eq!(parse_byte_fallback("<>"), None);
+    }
+
+    #[test]
+    fn test_decode_byte_fallback_tokens() {
+        // Vocab with byte-fallback tokens for the UTF-8 encoding of 'é' (0xC3 0xA9)
+        let tokens = vec![
+            "<unk>".to_string(),  // 0
+            "<s>".to_string(),    // 1
+            "</s>".to_string(),   // 2
+            "<0xC3>".to_string(), // 3
+            "<0xA9>".to_string(), // 4
+            "▁hi".to_string(),    // 5
+        ];
+        let scores = vec![0.0; tokens.len()];
+        let tok = SentencePieceTokenizer::new(tokens, scores, Some(1), 2).unwrap();
+
+        // Decode: BOS + "▁hi" + byte(0xC3) + byte(0xA9) = "hié"
+        // Leading space from ▁ is stripped (multi-token sequence)
+        let text = tok.decode(&[1, 5, 3, 4]).unwrap();
+        assert_eq!(text, "hié");
+    }
 }
diff --git a/src/core/tokenizer.rs b/src/core/tokenizer.rs
index f05f87e..9dac75b 100644
--- a/src/core/tokenizer.rs
+++ b/src/core/tokenizer.rs
@@ -1,5 +1,6 @@
 use aho_corasick::AhoCorasick;
 use lru::LruCache;
+#[cfg(feature = "rayon")]
 use rayon::prelude::*;
 use regexr::{Regex as RegexrRegex, RegexBuilder};
 use rustc_hash::FxHashMap;
@@ -825,6 +826,7 @@ impl Tokenizer {
             return vec![];
         }

+        #[cfg(feature = "rayon")]
         let results: Vec<Vec<u32>> = chunks
             .par_iter()
             .map(|&(start, end)| {
                 let slice = &text_bytes[start..end];
                 self.encode_chunk_with_position(slice, start)
             })
             .collect();

+        #[cfg(not(feature = "rayon"))]
+        let results: Vec<Vec<u32>> = chunks
+            .iter()
+            .map(|&(start, end)| {
+                let slice = &text_bytes[start..end];
+                self.encode_chunk_with_position(slice, start)
+            })
+            .collect();
+
         results.into_iter().flatten().collect()
     }

@@ -928,33 +939,70 @@ impl Tokenizer {
         }
     }

-    /// Batch encode multiple texts in parallel.
+    /// Batch encode multiple texts (parallel when rayon is enabled).
     pub fn encode_batch(&self, texts: &[String]) -> Vec<Vec<u32>> {
-        texts.par_iter().map(|text| self.encode(text)).collect()
+        #[cfg(feature = "rayon")]
+        {
+            texts.par_iter().map(|text| self.encode(text)).collect()
+        }
+        #[cfg(not(feature = "rayon"))]
+        {
+            texts.iter().map(|text| self.encode(text)).collect()
+        }
     }

     /// Batch encode multiple texts with special token handling.
     pub fn encode_batch_with_special(&self, texts: &[String]) -> Vec<Vec<u32>> {
-        texts
-            .par_iter()
-            .map(|text| self.encode_with_special(text))
-            .collect()
+        #[cfg(feature = "rayon")]
+        {
+            texts
+                .par_iter()
+                .map(|text| self.encode_with_special(text))
+                .collect()
+        }
+        #[cfg(not(feature = "rayon"))]
+        {
+            texts
+                .iter()
+                .map(|text| self.encode_with_special(text))
+                .collect()
+        }
     }

-    /// Batch decode multiple token lists in parallel.
+    /// Batch decode multiple token lists.
     pub fn decode_batch(&self, token_lists: &[Vec<u32>]) -> Result<Vec<String>, TokenizerError> {
-        token_lists
-            .par_iter()
-            .map(|tokens| self.decode(tokens))
-            .collect()
+        #[cfg(feature = "rayon")]
+        {
+            token_lists
+                .par_iter()
+                .map(|tokens| self.decode(tokens))
+                .collect()
+        }
+        #[cfg(not(feature = "rayon"))]
+        {
+            token_lists
+                .iter()
+                .map(|tokens| self.decode(tokens))
+                .collect()
+        }
     }

-    /// Batch decode multiple token lists in parallel, replacing invalid UTF-8.
+    /// Batch decode multiple token lists, replacing invalid UTF-8.
     pub fn decode_batch_lossy(&self, token_lists: &[Vec<u32>]) -> Vec<String> {
-        token_lists
-            .par_iter()
-            .map(|tokens| self.decode_lossy(tokens))
-            .collect()
+        #[cfg(feature = "rayon")]
+        {
+            token_lists
+                .par_iter()
+                .map(|tokens| self.decode_lossy(tokens))
+                .collect()
+        }
+        #[cfg(not(feature = "rayon"))]
+        {
+            token_lists
+                .iter()
+                .map(|tokens| self.decode_lossy(tokens))
+                .collect()
+        }
     }

     /// Get the vocabulary size (number of tokens).