diff --git a/.github/workflows/rust_ci.yml b/.github/workflows/rust_ci.yml index dbca2dc..364c2c7 100644 --- a/.github/workflows/rust_ci.yml +++ b/.github/workflows/rust_ci.yml @@ -21,9 +21,7 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - uses: dtolnay/rust-toolchain@master - with: - toolchain: stable + - uses: actions-rust-lang/setup-rust-toolchain@v1 - uses: Swatinem/rust-cache@v2 - name: Build diff --git a/.github/workflows/rust_release.yml b/.github/workflows/rust_release.yml index b883615..41cd400 100644 --- a/.github/workflows/rust_release.yml +++ b/.github/workflows/rust_release.yml @@ -12,7 +12,7 @@ jobs: steps: - name: Checkout uses: actions/checkout@v4 - - uses: dtolnay/rust-toolchain@stable + - uses: actions-rust-lang/setup-rust-toolchain@v1 - name: Publish to crates.io env: CRATES_IO_TOKEN: ${{secrets.CRATES_IO_TOKEN}} diff --git a/.gitignore b/.gitignore index 6985cf1..dc51894 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,6 @@ Cargo.lock # MSVC Windows builds of rustc generate these, which store debugging information *.pdb + +.idea +.vscode \ No newline at end of file diff --git a/.vscode/setting.json b/.vscode/setting.json deleted file mode 100644 index bbf3889..0000000 --- a/.vscode/setting.json +++ /dev/null @@ -1,5 +0,0 @@ -{ - "rust-analyzer.cargo.features": "all", - - "rust-analyzer.linkedProjects": ["./bindings/python/Cargo.toml"] -} diff --git a/Cargo.toml b/Cargo.toml index 7d4673c..4ebe386 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,16 +25,15 @@ categories.workspace = true exclude = ["/.github", "/.vscode", "/bindings/**"] [dependencies] -tiktoken-rs = { version = "0.5.9", optional = true } -tokenizers = { version = "0.19.1", features = ["http"], optional = true } -tree-sitter = "0.22" -openssl = { version = "0.10", features = ["vendored"] } +tiktoken-rs = { version = "0.6.0", optional = true } +tokenizers = { version = "0.21.1", features = ["http"], optional = true } +tree-sitter = "0.25.3" 
[dev-dependencies] -tree-sitter-go = "0.21" -tree-sitter-md = "0.2" -tree-sitter-python = "0.21" -tree-sitter-rust = "0.21" +tree-sitter-go = "0.23.4" +tree-sitter-md = "0.3.2" +tree-sitter-python = "0.23.6" +tree-sitter-rust = "0.24.0" [features] tiktoken-rs = ["dep:tiktoken-rs"] diff --git a/bindings/python/Cargo.toml b/bindings/python/Cargo.toml index c644128..7bb5906 100644 --- a/bindings/python/Cargo.toml +++ b/bindings/python/Cargo.toml @@ -17,11 +17,11 @@ crate-type = ["cdylib"] [dependencies] code-splitter = { path = "../..", features = ["tiktoken-rs", "tokenizers"] } -pyo3 = { version = "0.22.2", features = ["extension-module"] } -tiktoken-rs = "0.5.9" -tokenizers = { version = "0.19.1", features = ["http"] } -tree-sitter = "0.22" -tree-sitter-go = "0.21" -tree-sitter-md = "0.2" -tree-sitter-python = "0.21" -tree-sitter-rust = "0.21" +pyo3 = { version = "0.24.1", features = ["extension-module"] } +tiktoken-rs = "0.6.0" +tokenizers = { version = "0.21.1", features = ["http"] } +tree-sitter = "0.25.3" +tree-sitter-go = "0.23.4" +tree-sitter-md = "0.3.2" +tree-sitter-python = "0.23.6" +tree-sitter-rust = "0.24.0" diff --git a/bindings/python/src/language.rs b/bindings/python/src/language.rs index 327492e..0e6dddd 100644 --- a/bindings/python/src/language.rs +++ b/bindings/python/src/language.rs @@ -1,8 +1,11 @@ use pyo3::prelude::*; +use std::collections::HashMap; +use std::hash::Hash; +use std::sync::OnceLock; use tree_sitter::Language as TreeSitterLanguage; #[pyclass(eq)] -#[derive(Clone, Copy, Debug, PartialEq)] +#[derive(Clone, Copy, Debug, Hash, PartialEq)] #[non_exhaustive] pub enum Language { Golang, @@ -11,13 +14,34 @@ pub enum Language { Rust, } +impl Eq for Language {} + +static LANGUAGES: OnceLock<HashMap<Language, TreeSitterLanguage>> = OnceLock::new(); + impl Language { + fn init_languages() -> HashMap<Language, TreeSitterLanguage> { + HashMap::from([ + ( + Language::Golang, + tree_sitter::Language::new(tree_sitter_go::LANGUAGE), + ), + ( + Language::Markdown, 
tree_sitter::Language::new(tree_sitter_md::LANGUAGE), + ), + ( + Language::Python, + tree_sitter::Language::new(tree_sitter_python::LANGUAGE), + ), + ( + Language::Rust, + tree_sitter::Language::new(tree_sitter_rust::LANGUAGE), + ), + ]) + } + pub fn as_tree_sitter_language(&self) -> TreeSitterLanguage { - match self { - Language::Golang => tree_sitter_go::language(), - Language::Markdown => tree_sitter_md::language(), - Language::Python => tree_sitter_python::language(), - Language::Rust => tree_sitter_rust::language(), - } + let languages = LANGUAGES.get_or_init(Self::init_languages); + languages.get(self).unwrap().clone() } } diff --git a/src/sizer/chars.rs b/src/sizer/chars.rs index 41d9520..06bdce8 100644 --- a/src/sizer/chars.rs +++ b/src/sizer/chars.rs @@ -4,10 +4,10 @@ use crate::sizer::Sizer; /// A marker struct for counting characters in code chunks. /// /// ``` +/// use tree_sitter::Language; /// use code_splitter::{CharCounter, Splitter}; /// -/// let lang = tree_sitter_md::language(); -/// let splitter = Splitter::new(lang, CharCounter).unwrap(); +/// let splitter = Splitter::new(Language::new(tree_sitter_md::LANGUAGE), CharCounter).unwrap(); /// /// let code = b"hello, world!"; /// let chunks = splitter.split(code).unwrap(); diff --git a/src/sizer/words.rs b/src/sizer/words.rs index af18f7a..3cceacd 100644 --- a/src/sizer/words.rs +++ b/src/sizer/words.rs @@ -4,10 +4,10 @@ use crate::sizer::Sizer; /// A marker struct for counting words in code chunks. 
/// /// ``` +/// use tree_sitter::Language; /// use code_splitter::{Splitter, WordCounter}; /// -/// let lang = tree_sitter_md::language(); -/// let splitter = Splitter::new(lang, WordCounter).unwrap(); +/// let splitter = Splitter::new(Language::new(tree_sitter_md::LANGUAGE), WordCounter).unwrap(); /// /// let code = b"hello, world!"; /// let chunks = splitter.split(code).unwrap(); diff --git a/src/splitter.rs b/src/splitter.rs index 8b29504..bb4b7f2 100644 --- a/src/splitter.rs +++ b/src/splitter.rs @@ -26,19 +26,19 @@ where /// /// # Example: split by characters /// ``` + /// use tree_sitter::Language; /// use code_splitter::{CharCounter, Splitter}; /// - /// let lang = tree_sitter_md::language(); - /// let splitter = Splitter::new(lang, CharCounter).unwrap(); + /// let splitter = Splitter::new(Language::new(tree_sitter_md::LANGUAGE), CharCounter).unwrap(); /// let chunks = splitter.split(b"hello, world!").unwrap(); /// ``` /// /// # Example: split by words /// ``` + /// use tree_sitter::Language; /// use code_splitter::{Splitter, WordCounter}; /// - /// let lang = tree_sitter_md::language(); - /// let splitter = Splitter::new(lang, WordCounter).unwrap(); + /// let splitter = Splitter::new(Language::new(tree_sitter_md::LANGUAGE), WordCounter).unwrap(); /// let chunks = splitter.split(b"hello, world!").unwrap(); /// ``` /// @@ -46,12 +46,12 @@ where /// ``` /// # #[cfg(feature = "tokenizers")] /// # { + /// use tree_sitter::Language; /// use code_splitter::Splitter; /// use tokenizers::Tokenizer; /// - /// let lang = tree_sitter_md::language(); /// let tokenizer = Tokenizer::from_pretrained("bert-base-cased", None).unwrap(); - /// let splitter = Splitter::new(lang, tokenizer).unwrap(); + /// let splitter = Splitter::new(Language::new(tree_sitter_md::LANGUAGE), tokenizer).unwrap(); /// let chunks = splitter.split(b"hello, world!").unwrap(); /// # } /// ``` @@ -60,12 +60,12 @@ where /// ``` /// # #[cfg(feature = "tiktoken-rs")] /// # { + /// use 
tree_sitter::Language; /// use code_splitter::Splitter; /// use tiktoken_rs::cl100k_base; /// - /// let lang = tree_sitter_md::language(); /// let bpe = cl100k_base().unwrap(); - /// let splitter = Splitter::new(lang, bpe).unwrap(); + /// let splitter = Splitter::new(Language::new(tree_sitter_md::LANGUAGE), bpe).unwrap(); /// let chunks = splitter.split(b"hello, world!").unwrap(); /// # } /// ``` @@ -84,10 +84,10 @@ where /// /// # Example: set the maximum size to 256 /// ``` + /// use tree_sitter::Language; /// use code_splitter::{CharCounter, Splitter}; /// - /// let lang = tree_sitter_md::language(); - /// let splitter = Splitter::new(lang, CharCounter) + /// let splitter = Splitter::new(Language::new(tree_sitter_md::LANGUAGE), CharCounter) /// .unwrap() /// .with_max_size(256); /// let chunks = splitter.split(b"hello, world!").unwrap(); diff --git a/tests/golang.rs b/tests/golang.rs index 00a1bd3..69fc459 100644 --- a/tests/golang.rs +++ b/tests/golang.rs @@ -1,5 +1,6 @@ use code_splitter::{CharCounter, Sizer, Splitter, WordCounter}; use std::fs; +use tree_sitter::Language; const TEST_FILE: &str = "tests/testdata/rectangle.go"; @@ -12,9 +13,8 @@ where T: Sizer, { let code = read_test_file(); - let lang = tree_sitter_go::language(); - let splitter = Splitter::new(lang, sizer) + let splitter = Splitter::new(Language::new(tree_sitter_go::LANGUAGE), sizer) .expect("Failed to create golang splitter") .with_max_size(max_size); let chunks = splitter.split(&code).expect("Failed to split golang code"); diff --git a/tests/markdown.rs b/tests/markdown.rs index 2361577..863521e 100644 --- a/tests/markdown.rs +++ b/tests/markdown.rs @@ -1,5 +1,6 @@ use code_splitter::{CharCounter, Sizer, Splitter, WordCounter}; use std::fs; +use tree_sitter::Language; const TEST_FILE: &str = "tests/testdata/markdown.md"; @@ -12,9 +13,8 @@ where T: Sizer, { let code = read_test_file(); - let lang = tree_sitter_md::language(); - let splitter = Splitter::new(lang, sizer) + let splitter = 
Splitter::new(Language::new(tree_sitter_md::LANGUAGE), sizer) .expect("Failed to create markdown splitter") .with_max_size(max_size); let chunks = splitter diff --git a/tests/python.rs b/tests/python.rs index 5058e50..02d5044 100644 --- a/tests/python.rs +++ b/tests/python.rs @@ -1,5 +1,6 @@ use code_splitter::{CharCounter, Sizer, Splitter, WordCounter}; use std::fs; +use tree_sitter::Language; const TEST_FILE: &str = "tests/testdata/rectangle.py"; @@ -12,9 +13,8 @@ where T: Sizer, { let code = read_test_file(); - let lang = tree_sitter_python::language(); - let splitter = Splitter::new(lang, sizer) + let splitter = Splitter::new(Language::new(tree_sitter_python::LANGUAGE), sizer) .expect("Failed to create python splitter") .with_max_size(max_size); let chunks = splitter.split(&code).expect("Failed to split python code"); diff --git a/tests/rust.rs b/tests/rust.rs index d893ed9..b5b0082 100644 --- a/tests/rust.rs +++ b/tests/rust.rs @@ -1,5 +1,6 @@ use code_splitter::{CharCounter, Sizer, Splitter, WordCounter}; use std::fs; +use tree_sitter::Language; const TEST_FILE: &str = "tests/testdata/rectangle.rs"; @@ -12,9 +13,8 @@ where T: Sizer, { let code = read_test_file(); - let lang = tree_sitter_rust::language(); - let splitter = Splitter::new(lang, sizer) + let splitter = Splitter::new(Language::new(tree_sitter_rust::LANGUAGE), sizer) .expect("Failed to create rust splitter") .with_max_size(max_size); let chunks = splitter.split(&code).expect("Failed to split rust code");