Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions .github/workflows/rust_ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: dtolnay/rust-toolchain@master
with:
toolchain: stable
- uses: actions-rust-lang/setup-rust-toolchain@v1
- uses: Swatinem/rust-cache@v2

- name: Build
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/rust_release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jobs:
steps:
- name: Checkout
uses: actions/checkout@v4
- uses: dtolnay/rust-toolchain@stable
- uses: actions-rust-lang/setup-rust-toolchain@v1
- name: Publish to crates.io
env:
CRATES_IO_TOKEN: ${{secrets.CRATES_IO_TOKEN}}
Expand Down
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,6 @@ Cargo.lock

# MSVC Windows builds of rustc generate these, which store debugging information
*.pdb

.idea
.vscode
5 changes: 0 additions & 5 deletions .vscode/setting.json

This file was deleted.

15 changes: 7 additions & 8 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,15 @@ categories.workspace = true
exclude = ["/.github", "/.vscode", "/bindings/**"]

[dependencies]
tiktoken-rs = { version = "0.5.9", optional = true }
tokenizers = { version = "0.19.1", features = ["http"], optional = true }
tree-sitter = "0.22"
openssl = { version = "0.10", features = ["vendored"] }
tiktoken-rs = { version = "0.6.0", optional = true }
tokenizers = { version = "0.21.1", features = ["http"], optional = true }
tree-sitter = "0.25.3"

[dev-dependencies]
tree-sitter-go = "0.21"
tree-sitter-md = "0.2"
tree-sitter-python = "0.21"
tree-sitter-rust = "0.21"
tree-sitter-go = "0.23.4"
tree-sitter-md = "0.3.2"
tree-sitter-python = "0.23.6"
tree-sitter-rust = "0.24.0"

[features]
tiktoken-rs = ["dep:tiktoken-rs"]
Expand Down
16 changes: 8 additions & 8 deletions bindings/python/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@ crate-type = ["cdylib"]

[dependencies]
code-splitter = { path = "../..", features = ["tiktoken-rs", "tokenizers"] }
pyo3 = { version = "0.22.2", features = ["extension-module"] }
tiktoken-rs = "0.5.9"
tokenizers = { version = "0.19.1", features = ["http"] }
tree-sitter = "0.22"
tree-sitter-go = "0.21"
tree-sitter-md = "0.2"
tree-sitter-python = "0.21"
tree-sitter-rust = "0.21"
pyo3 = { version = "0.24.1", features = ["extension-module"] }
tiktoken-rs = "0.6.0"
tokenizers = { version = "0.21.1", features = ["http"] }
tree-sitter = "0.25.3"
tree-sitter-go = "0.23.4"
tree-sitter-md = "0.3.2"
tree-sitter-python = "0.23.6"
tree-sitter-rust = "0.24.0"
38 changes: 31 additions & 7 deletions bindings/python/src/language.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
use pyo3::prelude::*;
use std::collections::HashMap;
use std::hash::Hash;
use std::sync::OnceLock;
use tree_sitter::Language as TreeSitterLanguage;

#[pyclass(eq)]
#[derive(Clone, Copy, Debug, PartialEq)]
#[derive(Clone, Copy, Debug, Hash, PartialEq)]
#[non_exhaustive]
pub enum Language {
Golang,
Expand All @@ -11,13 +14,34 @@ pub enum Language {
Rust,
}

// `Eq` is a marker on top of the derived `PartialEq`. Together with the derived
// `Hash`, it is what allows `Language` to be used as a `HashMap` key.
impl Eq for Language {}

/// Shared table from each `Language` variant to its tree-sitter grammar.
///
/// Starts empty; it is populated on first access via `get_or_init` with the
/// map built by `Language::init_languages`.
static LANGUAGES: OnceLock<HashMap<Language, tree_sitter::Language>> = OnceLock::new();

impl Language {
    /// Builds the table pairing each supported `Language` variant with the
    /// tree-sitter grammar constructed from the `LANGUAGE` constant exported
    /// by the corresponding `tree_sitter_*` crate.
    ///
    /// Used as the initializer for the `LANGUAGES` static, so it is only
    /// expected to run once.
    fn init_languages() -> HashMap<Language, tree_sitter::Language> {
        HashMap::from([
            (
                Language::Golang,
                tree_sitter::Language::new(tree_sitter_go::LANGUAGE),
            ),
            (
                Language::Markdown,
                tree_sitter::Language::new(tree_sitter_md::LANGUAGE),
            ),
            (
                Language::Python,
                tree_sitter::Language::new(tree_sitter_python::LANGUAGE),
            ),
            (
                Language::Rust,
                tree_sitter::Language::new(tree_sitter_rust::LANGUAGE),
            ),
        ])
    }

    /// Returns the tree-sitter grammar for this language.
    ///
    /// The grammar is looked up in the lazily initialized `LANGUAGES` table
    /// and a clone of the stored value is handed back to the caller.
    ///
    /// # Panics
    ///
    /// Panics if this variant has no entry in the table built by
    /// `init_languages`. That can only happen when a variant is added to the
    /// enum without registering its grammar there, i.e. it is a bug in this
    /// module rather than a condition callers can trigger.
    pub fn as_tree_sitter_language(&self) -> TreeSitterLanguage {
        LANGUAGES
            .get_or_init(Self::init_languages)
            .get(self)
            .cloned()
            // Invariant: `init_languages` registers every `Language` variant.
            .expect("every Language variant must be registered in Language::init_languages")
    }
}
4 changes: 2 additions & 2 deletions src/sizer/chars.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ use crate::sizer::Sizer;
/// A marker struct for counting characters in code chunks.
///
/// ```
/// use tree_sitter::Language;
/// use code_splitter::{CharCounter, Splitter};
///
/// let lang = tree_sitter_md::language();
/// let splitter = Splitter::new(lang, CharCounter).unwrap();
/// let splitter = Splitter::new(Language::new(tree_sitter_md::LANGUAGE), CharCounter).unwrap();
///
/// let code = b"hello, world!";
/// let chunks = splitter.split(code).unwrap();
Expand Down
4 changes: 2 additions & 2 deletions src/sizer/words.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ use crate::sizer::Sizer;
/// A marker struct for counting words in code chunks.
///
/// ```
/// use tree_sitter::Language;
/// use code_splitter::{Splitter, WordCounter};
///
/// let lang = tree_sitter_md::language();
/// let splitter = Splitter::new(lang, WordCounter).unwrap();
/// let splitter = Splitter::new(Language::new(tree_sitter_md::LANGUAGE), WordCounter).unwrap();
///
/// let code = b"hello, world!";
/// let chunks = splitter.split(code).unwrap();
Expand Down
20 changes: 10 additions & 10 deletions src/splitter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,32 +26,32 @@ where
///
/// # Example: split by characters
/// ```
/// use tree_sitter::Language;
/// use code_splitter::{CharCounter, Splitter};
///
/// let lang = tree_sitter_md::language();
/// let splitter = Splitter::new(lang, CharCounter).unwrap();
/// let splitter = Splitter::new(Language::new(tree_sitter_md::LANGUAGE), CharCounter).unwrap();
/// let chunks = splitter.split(b"hello, world!").unwrap();
/// ```
///
/// # Example: split by words
/// ```
/// use tree_sitter::Language;
/// use code_splitter::{Splitter, WordCounter};
///
/// let lang = tree_sitter_md::language();
/// let splitter = Splitter::new(lang, WordCounter).unwrap();
/// let splitter = Splitter::new(Language::new(tree_sitter_md::LANGUAGE), WordCounter).unwrap();
/// let chunks = splitter.split(b"hello, world!").unwrap();
/// ```
///
/// # Example: split by tokens with huggingface tokenizer
/// ```
/// # #[cfg(feature = "tokenizers")]
/// # {
/// use tree_sitter::Language;
/// use code_splitter::Splitter;
/// use tokenizers::Tokenizer;
///
/// let lang = tree_sitter_md::language();
/// let tokenizer = Tokenizer::from_pretrained("bert-base-cased", None).unwrap();
/// let splitter = Splitter::new(lang, tokenizer).unwrap();
/// let splitter = Splitter::new(Language::new(tree_sitter_md::LANGUAGE), tokenizer).unwrap();
/// let chunks = splitter.split(b"hello, world!").unwrap();
/// # }
/// ```
Expand All @@ -60,12 +60,12 @@ where
/// ```
/// # #[cfg(feature = "tiktoken-rs")]
/// # {
/// use tree_sitter::Language;
/// use code_splitter::Splitter;
/// use tiktoken_rs::cl100k_base;
///
/// let lang = tree_sitter_md::language();
/// let bpe = cl100k_base().unwrap();
/// let splitter = Splitter::new(lang, bpe).unwrap();
/// let splitter = Splitter::new(Language::new(tree_sitter_md::LANGUAGE), bpe).unwrap();
/// let chunks = splitter.split(b"hello, world!").unwrap();
/// # }
/// ```
Expand All @@ -84,10 +84,10 @@ where
///
/// # Example: set the maximum size to 256
/// ```
/// use tree_sitter::Language;
/// use code_splitter::{CharCounter, Splitter};
///
/// let lang = tree_sitter_md::language();
/// let splitter = Splitter::new(lang, CharCounter)
/// let splitter = Splitter::new(Language::new(tree_sitter_md::LANGUAGE), CharCounter)
/// .unwrap()
/// .with_max_size(256);
/// let chunks = splitter.split(b"hello, world!").unwrap();
Expand Down
4 changes: 2 additions & 2 deletions tests/golang.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use code_splitter::{CharCounter, Sizer, Splitter, WordCounter};
use std::fs;
use tree_sitter::Language;

const TEST_FILE: &str = "tests/testdata/rectangle.go";

Expand All @@ -12,9 +13,8 @@ where
T: Sizer,
{
let code = read_test_file();
let lang = tree_sitter_go::language();

let splitter = Splitter::new(lang, sizer)
let splitter = Splitter::new(Language::new(tree_sitter_go::LANGUAGE), sizer)
.expect("Failed to create golang splitter")
.with_max_size(max_size);
let chunks = splitter.split(&code).expect("Failed to split golang code");
Expand Down
4 changes: 2 additions & 2 deletions tests/markdown.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use code_splitter::{CharCounter, Sizer, Splitter, WordCounter};
use std::fs;
use tree_sitter::Language;

const TEST_FILE: &str = "tests/testdata/markdown.md";

Expand All @@ -12,9 +13,8 @@ where
T: Sizer,
{
let code = read_test_file();
let lang = tree_sitter_md::language();

let splitter = Splitter::new(lang, sizer)
let splitter = Splitter::new(Language::new(tree_sitter_md::LANGUAGE), sizer)
.expect("Failed to create markdown splitter")
.with_max_size(max_size);
let chunks = splitter
Expand Down
4 changes: 2 additions & 2 deletions tests/python.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use code_splitter::{CharCounter, Sizer, Splitter, WordCounter};
use std::fs;
use tree_sitter::Language;

const TEST_FILE: &str = "tests/testdata/rectangle.py";

Expand All @@ -12,9 +13,8 @@ where
T: Sizer,
{
let code = read_test_file();
let lang = tree_sitter_python::language();

let splitter = Splitter::new(lang, sizer)
let splitter = Splitter::new(Language::new(tree_sitter_python::LANGUAGE), sizer)
.expect("Failed to create python splitter")
.with_max_size(max_size);
let chunks = splitter.split(&code).expect("Failed to split python code");
Expand Down
4 changes: 2 additions & 2 deletions tests/rust.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use code_splitter::{CharCounter, Sizer, Splitter, WordCounter};
use std::fs;
use tree_sitter::Language;

const TEST_FILE: &str = "tests/testdata/rectangle.rs";

Expand All @@ -12,9 +13,8 @@ where
T: Sizer,
{
let code = read_test_file();
let lang = tree_sitter_rust::language();

let splitter = Splitter::new(lang, sizer)
let splitter = Splitter::new(Language::new(tree_sitter_rust::LANGUAGE), sizer)
.expect("Failed to create rust splitter")
.with_max_size(max_size);
let chunks = splitter.split(&code).expect("Failed to split rust code");
Expand Down