diff --git a/Cargo.lock b/Cargo.lock index 5e86c23..d3cb511 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -46,6 +46,16 @@ dependencies = [ "memchr", ] +[[package]] +name = "aligned-vec" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc890384c8602f339876ded803c97ad529f3842aba97f6392b3dba0dd171769b" +dependencies = [ + "equator", + "serde", +] + [[package]] name = "alloc-no-stdlib" version = "2.0.4" @@ -848,6 +858,15 @@ dependencies = [ "serde", ] +[[package]] +name = "bincode" +version = "1.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad" +dependencies = [ + "serde", +] + [[package]] name = "bindgen" version = "0.69.5" @@ -2250,6 +2269,26 @@ dependencies = [ "log", ] +[[package]] +name = "equator" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4711b213838dfee0117e3be6ac926007d7f433d7bbe33595975d4190cb07e6fc" +dependencies = [ + "equator-macro", +] + +[[package]] +name = "equator-macro" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44f23cf4b44bfce11a86ace86f8a73ffdec849c9fd00a386a53d278bd9e81fb3" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "equivalent" version = "1.0.1" @@ -4422,6 +4461,7 @@ dependencies = [ "rust-lapper", "strum", "strum_macros", + "superintervals", "tokio", ] @@ -4666,6 +4706,19 @@ version = "2.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" +[[package]] +name = "superintervals" +version = "0.3.6" +dependencies = [ + "aligned-vec", + "bincode", + "clap", + "fnv", + "libc", + "rand 0.8.5", + "serde", +] + [[package]] name = "syn" version = "2.0.106" diff --git a/flamegraph.svg b/flamegraph.svg new file mode 100644 index 0000000..b295dd9 --- /dev/null +++ b/flamegraph.svg @@ -0,0 +1,491 @@ +Flame Graph Reset ZoomSearch databio_benchmark-aa444f60d7f9d923`<criterion::Criterion as core::default::Default>::default (1 samples, 50.00%)databio_benchmark-aa444f60d7f9d923`<criterion::Criterion as core::default::Default..databio_benchmark-aa444f60d7f9d923`once_cell::imp::OnceCell<T>::initialize (1 samples, 50.00%)databio_benchmark-aa444f60d7f9d923`once_cell::imp::OnceCell<T>::initializedatabio_benchmark-aa444f60d7f9d923`once_cell::imp::initialize_or_wait (1 samples, 50.00%)databio_benchmark-aa444f60d7f9d923`once_cell::imp::initialize_or_waitdatabio_benchmark-aa444f60d7f9d923`once_cell::imp::OnceCell<T>::initialize::_{{closure}} (1 samples, 50.00%)databio_benchmark-aa444f60d7f9d923`once_cell::imp::OnceCell<T>::initialize::_{{clo..databio_benchmark-aa444f60d7f9d923`core::ops::function::FnOnce::call_once (1 samples, 50.00%)databio_benchmark-aa444f60d7f9d923`core::ops::function::FnOnce::call_oncedatabio_benchmark-aa444f60d7f9d923`once_cell::imp::OnceCell<T>::initialize (1 samples, 50.00%)databio_benchmark-aa444f60d7f9d923`once_cell::imp::OnceCell<T>::initializedatabio_benchmark-aa444f60d7f9d923`once_cell::imp::initialize_or_wait (1 samples, 50.00%)databio_benchmark-aa444f60d7f9d923`once_cell::imp::initialize_or_waitdatabio_benchmark-aa444f60d7f9d923`once_cell::imp::OnceCell<T>::initialize::_{{closure}} (1 samples, 50.00%)databio_benchmark-aa444f60d7f9d923`once_cell::imp::OnceCell<T>::initialize::_{{clo..databio_benchmark-aa444f60d7f9d923`criterion_plot::version (1 samples, 50.00%)databio_benchmark-aa444f60d7f9d923`criterion_plot::versiondatabio_benchmark-aa444f60d7f9d923`std::process::Command::output (1 samples, 50.00%)databio_benchmark-aa444f60d7f9d923`std::process::Command::outputdatabio_benchmark-aa444f60d7f9d923`std::sys::pal::unix::process::process_inner::_<impl std::sys::pal::unix::process::process_common::Command>::spawn (1 samples, 50.00%)databio_benchmark-aa444f60d7f9d923`std::sys::pal::unix::process::process_inner::_<..libsystem_c.dylib`posix_spawnp (1 samples, 50.00%)libsystem_c.dylib`posix_spawnplibsystem_kernel.dylib`__posix_spawn (1 samples, 50.00%)libsystem_kernel.dylib`__posix_spawnall (2 samples, 100%)dyld`start (2 samples, 100.00%)dyld`startdatabio_benchmark-aa444f60d7f9d923`main (2 samples, 100.00%)databio_benchmark-aa444f60d7f9d923`maindatabio_benchmark-aa444f60d7f9d923`std::rt::lang_start_internal (2 samples, 100.00%)databio_benchmark-aa444f60d7f9d923`std::rt::lang_start_internaldatabio_benchmark-aa444f60d7f9d923`std::rt::lang_start::_{{closure}} (2 samples, 100.00%)databio_benchmark-aa444f60d7f9d923`std::rt::lang_start::_{{closure}}databio_benchmark-aa444f60d7f9d923`std::sys_common::backtrace::__rust_begin_short_backtrace (2 samples, 100.00%)databio_benchmark-aa444f60d7f9d923`std::sys_common::backtrace::__rust_begin_short_backtracedatabio_benchmark-aa444f60d7f9d923`databio_benchmark::main (2 samples, 100.00%)databio_benchmark-aa444f60d7f9d923`databio_benchmark::maindatabio_benchmark-aa444f60d7f9d923`criterion::benchmark_group::BenchmarkGroup<M>::bench_function (1 samples, 50.00%)databio_benchmark-aa444f60d7f9d923`criterion::benchmark_group::BenchmarkGroup<M>::..databio_benchmark-aa444f60d7f9d923`criterion::analysis::common (1 samples, 50.00%)databio_benchmark-aa444f60d7f9d923`criterion::analysis::commondatabio_benchmark-aa444f60d7f9d923`criterion::routine::Routine::sample (1 samples, 50.00%)databio_benchmark-aa444f60d7f9d923`criterion::routine::Routine::sampledatabio_benchmark-aa444f60d7f9d923`<criterion::routine::Function<M,F,T> as criterion::routine::Routine<M,T>>::warm_up (1 samples, 50.00%)databio_benchmark-aa444f60d7f9d923`<criterion::routine::Function<M,F,T> as criteri..databio_benchmark-aa444f60d7f9d923`criterion::bencher::AsyncBencher<A,M>::iter (1 samples, 50.00%)databio_benchmark-aa444f60d7f9d923`criterion::bencher::AsyncBencher<A,M>::iterdatabio_benchmark-aa444f60d7f9d923`<&tokio::runtime::runtime::Runtime as criterion::async_executor::AsyncExecutor>::block_on (1 samples, 50.00%)databio_benchmark-aa444f60d7f9d923`<&tokio::runtime::runtime::Runtime as criterion..databio_benchmark-aa444f60d7f9d923`tokio::runtime::context::runtime::enter_runtime (1 samples, 50.00%)databio_benchmark-aa444f60d7f9d923`tokio::runtime::context::runtime::enter_runtimedatabio_benchmark-aa444f60d7f9d923`tokio::runtime::scheduler::current_thread::CoreGuard::block_on (1 samples, 50.00%)databio_benchmark-aa444f60d7f9d923`tokio::runtime::scheduler::current_thread::Core..databio_benchmark-aa444f60d7f9d923`tokio::runtime::context::set_scheduler (1 samples, 50.00%)databio_benchmark-aa444f60d7f9d923`tokio::runtime::context::set_schedulerdatabio_benchmark-aa444f60d7f9d923`tokio::runtime::scheduler::current_thread::Context::enter (1 samples, 50.00%)databio_benchmark-aa444f60d7f9d923`tokio::runtime::scheduler::current_thread::Cont..databio_benchmark-aa444f60d7f9d923`criterion::bencher::AsyncBencher<A,M>::iter::_{{closure}} (1 samples, 50.00%)databio_benchmark-aa444f60d7f9d923`criterion::bencher::AsyncBencher<A,M>::iter::_{..databio_benchmark-aa444f60d7f9d923`databio_benchmark::databio_benchmark::_{{closure}}::_{{closure}}::_{{closure}} (1 samples, 50.00%)databio_benchmark-aa444f60d7f9d923`databio_benchmark::databio_benchmark::_{{closur..databio_benchmark-aa444f60d7f9d923`datafusion::execution::context::parquet::_<impl datafusion::execution::context::SessionContext>::register_parquet::_{{closure}} (1 samples, 50.00%)databio_benchmark-aa444f60d7f9d923`datafusion::execution::context::parquet::_<impl..databio_benchmark-aa444f60d7f9d923`alloc::collections::btree::map::IntoIter<K,V,A>::dying_next (1 samples, 50.00%)databio_benchmark-aa444f60d7f9d923`alloc::collections::btree::map::IntoIter<K,V,A>.. \ No newline at end of file diff --git a/output.txt b/output.txt new file mode 100644 index 0000000..7796154 --- /dev/null +++ b/output.txt @@ -0,0 +1 @@ +test Coitrees-Small-Medium-2-7-7-8 ... bench: 8488532322 ns/iter (+/- 417491090) diff --git a/sequila/sequila-core/Cargo.toml b/sequila/sequila-core/Cargo.toml index e1ceacc..e7e0a17 100644 --- a/sequila/sequila-core/Cargo.toml +++ b/sequila/sequila-core/Cargo.toml @@ -24,6 +24,7 @@ fnv = "1.0.7" bio = "2.0.1" rand = "0.8.5" rust-lapper = "1.1.0" +superintervals = { path = "superintervals" } [dev-dependencies] rstest = "0.22.0" diff --git a/sequila/sequila-core/src/physical_planner/joins/interval_join.rs b/sequila/sequila-core/src/physical_planner/joins/interval_join.rs index 46aa4c2..6438949 100644 --- a/sequila/sequila-core/src/physical_planner/joins/interval_join.rs +++ b/sequila/sequila-core/src/physical_planner/joins/interval_join.rs @@ -686,6 +686,7 @@ enum IntervalJoinAlgorithm { IntervalTree(FnvHashMap>), ArrayIntervalTree(FnvHashMap>), Lapper(FnvHashMap>), + SuperIntervals(FnvHashMap>), CoitreesNearest( FnvHashMap< u64, @@ -719,6 +720,9 @@ impl Debug for IntervalJoinAlgorithm { f.debug_struct("ArrayIntervalTree").field("0", m).finish() } IntervalJoinAlgorithm::Lapper(m) => f.debug_struct("Lapper").field("0", m).finish(), + IntervalJoinAlgorithm::SuperIntervals(m) => { + f.debug_struct("SuperIntervals").field("0", m).finish() + } } } } @@ -815,6 +819,20 @@ impl IntervalJoinAlgorithm { IntervalJoinAlgorithm::Lapper(hashmap) } + Algorithm::SuperIntervals => { + let hashmap = hash_map + .into_iter() + .map(|(k, v)| { + let mut map = superintervals::IntervalMap::new(); + for s in v { + map.add(s.start, s.end, s.position); + } + map.build(); + (k, map) + }) + .collect::>>(); + IntervalJoinAlgorithm::SuperIntervals(hashmap) + } } } @@ -954,6 +972,13 @@ impl IntervalJoinAlgorithm { } } } + IntervalJoinAlgorithm::SuperIntervals(hashmap) => { + if let Some(intervals) = hashmap.get(&k) { + for val in intervals.search_values_iter(start, end) { + f(val); + } + } + } } } } @@ -1344,6 +1369,7 @@ mod tests { Some(Algorithm::Coitrees), Some(Algorithm::IntervalTree), Some(Algorithm::ArrayIntervalTree), + Some(Algorithm::SuperIntervals), ]; for alg in algs { @@ -1439,6 +1465,7 @@ mod tests { Some(Algorithm::IntervalTree), Some(Algorithm::ArrayIntervalTree), Some(Algorithm::Lapper), + Some(Algorithm::SuperIntervals), ]; let schema = &schema(); diff --git a/sequila/sequila-core/src/session_context.rs b/sequila/sequila-core/src/session_context.rs index abcb24d..9dab7d1 100644 --- a/sequila/sequila-core/src/session_context.rs +++ b/sequila/sequila-core/src/session_context.rs @@ -65,6 +65,7 @@ pub enum Algorithm { IntervalTree, ArrayIntervalTree, Lapper, + SuperIntervals, CoitreesNearest, CoitreesCountOverlaps, } @@ -90,6 +91,7 @@ impl FromStr for Algorithm { "intervaltree" => Ok(Algorithm::IntervalTree), "arrayintervaltree" => Ok(Algorithm::ArrayIntervalTree), "lapper" => Ok(Algorithm::Lapper), + "superintervals" => Ok(Algorithm::SuperIntervals), "coitreesnearest" => Ok(Algorithm::CoitreesNearest), "coitreescountoverlaps" => Ok(Algorithm::CoitreesCountOverlaps), _ => Err(ParseAlgorithmError(format!( @@ -107,6 +109,7 @@ impl std::fmt::Display for Algorithm { Algorithm::IntervalTree => "IntervalTree", Algorithm::ArrayIntervalTree => "ArrayIntervalTree", Algorithm::Lapper => "Lapper", + Algorithm::SuperIntervals => "SuperIntervals", Algorithm::CoitreesNearest => "CoitreesNearest", Algorithm::CoitreesCountOverlaps => "CoitreesCountOverlaps", }; diff --git a/sequila/sequila-core/superintervals/.gitattributes b/sequila/sequila-core/superintervals/.gitattributes new file mode 100644 index 0000000..f9198e0 --- /dev/null +++ b/sequila/sequila-core/superintervals/.gitattributes @@ -0,0 +1,10 @@ +*.py linguist-detectable=true +*.rs linguist-detectable=true +*.cpp linguist-detectable=true +*.cc linguist-detectable=true +*.h linguist-detectable=true +*.hpp linguist-detectable=true + +test/* linguist-detectable=false +examples/* linguist-detectable=false +*.md linguist-detectable=false \ No newline at end of file diff --git a/sequila/sequila-core/superintervals/.github/workflows/python-package.yml b/sequila/sequila-core/superintervals/.github/workflows/python-package.yml new file mode 100644 index 0000000..364e5be --- /dev/null +++ b/sequila/sequila-core/superintervals/.github/workflows/python-package.yml @@ -0,0 +1,28 @@ +name: Python package + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + +jobs: + build: + + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] + + steps: + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + - name: Install + run: | + python -m pip install --upgrade pip + python -m pip install . + diff --git a/sequila/sequila-core/superintervals/.gitignore b/sequila/sequila-core/superintervals/.gitignore new file mode 100644 index 0000000..375c8ea --- /dev/null +++ b/sequila/sequila-core/superintervals/.gitignore @@ -0,0 +1,45 @@ +/test/run-cpp-libs +/test/run-cpp-libs.dSYM/ +/test/run_tests.dSYM/ +/test/run-tests +/test/*.bed +/build/ +/test/bench_data/ +/test/3rd-party/coitrees/target/ +/test/intervaldb.o +/superintervals.egg-info/ +/superintervals/intervalset.cpp +/superintervals/intervalset.*.so +/test/bed-intersect +/superintervals/superintervals.cpp + + +# Added by cargo + +/target +/test/bed-intersect-si +/src/intervalset.cpython-311-x86_64-linux-gnu.so +/src/superintervals/intervalset.cpp +/src/superintervals.cpp +/src/superintervals/intervalset.cpython-310-x86_64-linux-gnu.so +/src/superintervals/intervalset.cpp +/src/superintervals/intervalset.cpython-310-x86_64-linux-gnu.so +/test/benchmark.png +/test/cgranges.o +/test/chr1.genome +/old/ +/src/R/src/superintervals.so +/src/R/src/RcppExports.o +/src/R/src/intervalmap_r.o +/test/bed-intersect-rl +/src/superintervals/intervalmap.cpp +/src/R/..Rcheck/ +/src/R/superintervals_0.99.0.tar.gz +/src/R/superintervals.Rcheck/ +/src/superintervals.egg-info/ +/R/..Rcheck/ +/R/superintervals.Rcheck/ +/R/superintervals_0.99.0.tar.gz +/test/3rd-party/rust-lapper_kc/target/ +/test/3rd-party/rust-lapper_kc/benches/ +/test/3rd-party/rust-lapper_kc/images/ diff --git a/sequila/sequila-core/superintervals/Cargo.toml b/sequila/sequila-core/superintervals/Cargo.toml new file mode 100644 index 0000000..27ec561 --- /dev/null +++ b/sequila/sequila-core/superintervals/Cargo.toml @@ -0,0 +1,59 @@ +[package] +name = "superintervals" +description = "Interval overlap library" +version = "0.3.6" +authors = ["Kez Cleal. "] +edition = "2021" +repository = "https://github.com/kcleal/superintervals" +homepage = "https://github.com/kcleal/superintervals" +documentation = "https://github.com/kcleal/superintervals" +readme = "README.md" +license-file = "LICENSE" +exclude = [ + "test/*", + "src/superintervals/*", + "src/R/*", + "src/superintervals.egg*", + ".idea", + "dist", + "py*", "setup.py", "*.h", "*.hpp", "MANIFEST.in", ".gitignore" +] + +[lib] +path = "src/superintervals.rs" + +[dependencies] +serde = { version = "1.0", features = ["derive"] } +aligned-vec = { version = "0.6.4", features = ["serde"] } + +[profile.release] +debug = 0 +strip = "symbols" +lto = true +opt-level = 3 +codegen-units = 1 + +[features] +nosimd = [] + +[profile.dev.package."*"] +opt-level = 3 # Compile dependencies with optimizations on even in debug mode. + +[profile.no-opt] +inherits = "dev" +opt-level = 0 + +[profile.profiling] +inherits = "release" +debug = true +strip = false + +[dev-dependencies] +rand = "0.8" +fnv = "1.0.7" +libc = "0.2" +clap = { version = "4.3.7", features = ["derive"] } +bincode = "1.3.3" + +[[example]] +name = "bed-intersect-si" \ No newline at end of file diff --git a/sequila/sequila-core/superintervals/LICENSE b/sequila/sequila-core/superintervals/LICENSE new file mode 100644 index 0000000..4dfb928 --- /dev/null +++ b/sequila/sequila-core/superintervals/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 Kez Cleal + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/sequila/sequila-core/superintervals/README.md b/sequila/sequila-core/superintervals/README.md new file mode 100644 index 0000000..b038bc3 --- /dev/null +++ b/sequila/sequila-core/superintervals/README.md @@ -0,0 +1,197 @@ +SuperIntervals +============== + +A fast, memory-efficient data structure for interval intersection queries. +SuperIntervals uses a novel superset-index approach that maintains +intervals in position-sorted order, enabling cache-friendly searches and SIMD-optimized counting. + +Available for [C++](#cpp), [Rust](#rust), [Python](#python). +The R package is hosted at https://github.com/kcleal/superintervalsr. + +### Features: + +- Linear-time index construction from sorted intervals +- Cache-friendly querying +- SIMD acceleration (AVX2/Neon) for counting operations +- Small memory overhead (one size_t per interval) +- Optional Eytzinger memory layout for slightly faster queries (C++ only) +- No dependencies, header only + +### Notes: + +- Intervals are considered end-inclusive +- The build() function must be called before any queries +- Found intervals are returned in **reverse** position-sorted order + +## Python + +Install using `pip install superintervals` + +```python +from superintervals import IntervalMap + +imap = IntervalMap() +imap.add(10, 20, 'A') +imap.build() +results = imap.search_values(8, 20) # ['A'] +``` + +Python API documentation can be found here: +https://github.com/kcleal/superintervals/blob/main/src/superintervals/README.md + + +## Cpp + +Header only implementation, copy to your include directory. + +```cpp +#include "SuperIntervals.hpp" + +si::IntervalMap imap; +imap.add(1, 5, "A"); +imap.build(); +std::vector results; +imap.search_values(4, 9, results); +``` + +C++ API documentation can be found here: +https://github.com/kcleal/superintervals/blob/main/src/README_cpp.md + + +## Rust + +Add to your project using `cargo add superintervals` + +```rust +use superintervals::IntervalMap; + +let mut imap = IntervalMap::new(); +imap.add(1, 5, "A"); +imap.build(); +let mut results = Vec::new(); +imap.search_values(4, 11, &mut results); +``` + +Rust API documentation can be found here: +https://github.com/kcleal/superintervals/blob/main/src/README_rust.md + + + +## Test programs +Test programs expect plain text BED files and only assess chr1 records - other chromosomes are ignored. + +C++ program compares SuperIntervals, ImplicitIntervalTree, IntervalTree and NCLS: +``` +cd test; make +./run-tests +./run-cpp-libs a.bed b.bed +``` + +Rust program: +``` +RUSTFLAGS="-Ctarget-cpu=native" cargo run --release --example bed-intersect-si +cargo run --release --example bed-intersect-si a.bed b.bed +``` + +Python program: +``` +python test/run-py-libs.py a.bed b.bed +``` + +R program: +``` +Rscript src/R/benchmark.R +``` + +## Benchmark + +SuperIntervals (SI) was compared with: + +- Coitrees (Rust: https://github.com/dcjones/coitrees) +- Implicit Interval Tree (C++: https://github.com/lh3/cgranges) +- Interval Tree (C++: https://github.com/ekg/intervaltree) +- Nested Containment List (C: https://github.com/pyranges/ncls/tree/master/ncls/src) + +Main results: + +- Finding interval intersections is on average ~1.5-3x faster than other libraries (Coitrees for Rust, Implicit Interval Tree for C++), with some +exceptions. Coitrees-s was faster for one test (ONT reads, sorted DB53 reads). +- The SIMD counting performance of coitrees and superintervals was similar. + +Datasets https://github.com/kcleal/superintervals/releases/download/v0.2.0/data.tar.gz: + +- `rna / anno` RNA-seq reads and annotations from cgranges repository +- `ONT reads` nanopore alignments from sample PAO33946 chr1, converted to bed format +- `DB53 reads` paired-end reads from sample DB53, NCBI BioProject PRJNA417592, chr1, converted to bed format +- `mito-b, mito-a` paired-end reads from sample DB53 chrM, converted to bed format (mito-b and mito-a are the same) +- `genes` UCSC genes from hg19 + +Test programs use internal timers and print data to stdout, measuring the index time, and time to find all intersections. Other steps such as file IO are ignored. Test programs also only assess chr1 bed records - other chromosomes are ignored. For 'chrM' records, the M was replaced with 1 using sed. Data were assessed in position sorted and random order. Datasets can be found on the Releases page, and the test/run_tools.sh script has instructions for how to repeat the benchmark. + +Timings were in microseconds using an i9-11900K, 64 GB, 2TB NVMe machine. + +## Finding interval intersections + +- Coitrees-s uses the SortedQuerent version of coitrees +- SI = superintervals. Eytz refers to the eytzinger layout. -rs is the Rust implementation. + +### Intervals in sorted order + +| | Coitrees | Coitrees-s | SuperIntervals-rs | SuperIntervalsEytz-rs | ImplicitITree-C++ | IntervalTree-C++ | NCLS-C | SuperIntervals-C++ | SuperIntervalsEytz-C++ | +| --------------------- | -------- | ---------- |-------------------| --------------------- | ----------------- | ---------------- | ------ | ------------------ | ---------------------- | +| DB53 reads, ONT reads | 1668 | 3179 | **757** | **757** | 3831 | 44404 | 10642 | **1315** | 1358 | +| DB53 reads, genes | 55 | 84 | **21** | **21** | 122 | 109 | 291 | 42 | **40** | +| ONT reads, DB53 reads | 6504 | **3354** | 3859 | 3854 | 17949 | 12280 | 30772 | 5290 | **4462** | +| anno, rna | 50 | 35 | **18** | **18** | 127 | 90 | 208 | 29 | **22** | +| genes, DB53 reads | 1171 | 1018 | 301 | **296** | 3129 | 1315 | 1780 | 442 | **323** | +| mito-b, mito-a | 34769 | 34594 | 16971 | **16952** | 93900 | 107660 | 251707 | 33177 | **32985** | +| rna, anno | 31 | 23 | 21 | **20** | 70 | 55 | 233 | 28 | **27** | + +### Intervals in random order + +| | Coitrees | Coitrees-s | SuperIntervals-rs | SuperIntervalsEytz-rs | ImplicitITree-C++ | IntervalTree-C++ | NCLS-C | SuperIntervals-C++ | SuperIntervalsEytz-C++ | +| --------------------- | -------- | ---------- | ----------------- | --------------------- | ----------------- | ---------------- | ------ | ------------------ | ---------------------- | +| DB53 reads, ONT reads | 2943 | 4663 | 1356 | **1355** | 6505 | 46743 | 11947 | 2491 | **2169** | +| DB53 reads, genes | 78 | 130 | 27 | **26** | 170 | 125 | 305 | 58 | **51** | +| ONT reads, DB53 reads | 16650 | 18931 | 16116 | **16037** | 38677 | 27832 | 53452 | **23003** | 23232 | +| anno, rna | 89 | 105 | **54** | **54** | 188 | 143 | 294 | **58** | 60 | +| genes, DB53 reads | 2222 | 2424 | 1693 | **1684** | 4490 | 2701 | 3605 | **1251** | 1749 | +| mito-b, mito-a | 38030 | 86309 | **18326** | 18368 | 125336 | 118321 | 256293 | 42195 | **41695** | +| rna, anno | 53 | 73 | **45** | **45** | 137 | 83 | 311 | **52** | **52** | + +## Counting interval intersections + +### Intervals in sorted order + +| | Coitrees | SuperIntervals-rs | SuperIntervalsEytz-rs | SuperIntervals-C++ | SuperIntervalsEytz-C++ | +| --------------------- | -------- | ----------------- | --------------------- | ------------------ | ---------------------- | +| DB53 reads, ONT reads | 551 | 370 | 371 | **241** | 263 | +| DB53 reads, genes | 28 | 12 | 12 | 8 | **7** | +| ONT reads, DB53 reads | 2478 | 1909 | 1890 | 2209 | **1312** | +| anno, rna | 26 | 14 | 14 | 22 | **11** | +| genes, DB53 reads | 747 | 321 | 336 | 446 | **290** | +| mito-b, mito-a | 6894 | 6727 | 6746 | 3088 | **2966** | +| rna, anno | **9** | 13 | 13 | 12 | 10 | + +### Intervals in random order + +| | Coitrees | SuperIntervals-rs | SuperIntervalsEytz-rs | SuperIntervals-C++ | SuperIntervalsEytz-C++ | +| --------------------- | -------- | ----------------- | --------------------- | ------------------ | ---------------------- | +| DB53 reads, ONT reads | 1988 | 972 | 969 | 1016 | **778** | +| DB53 reads, genes | 53 | 20 | 20 | 16 | **13** | +| ONT reads, DB53 reads | 6692 | 8864 | 8733 | **8182** | 9523 | +| anno, rna | 52 | 49 | 48 | **47** | 50 | +| genes, DB53 reads | 1503 | 1628 | 1592 | **1120** | 1623 | +| mito-b, mito-a | 14354 | 7579 | 7600 | 4442 | **4383** | +| rna, anno | 22 | 30 | 29 | **25** | **25** | + + +## Acknowledgements + +- The rust test program borrows heavily from the coitrees package +- The superset-index implemented here exploits a similar interval ordering as described in +Schmidt 2009 "Interval Stabbing Problems in Small Integer Ranges". However, the superset-index has several advantages including + 1. An implicit memory layout + 1. General purpose implementation (not just small integer ranges) + 1. SIMD counting algorithm +- The Eytzinger layout was adapted from Sergey Slotin, Algorithmica diff --git a/sequila/sequila-core/superintervals/examples/bed-intersect-si.rs b/sequila/sequila-core/superintervals/examples/bed-intersect-si.rs new file mode 100644 index 0000000..c3bffac --- /dev/null +++ b/sequila/sequila-core/superintervals/examples/bed-intersect-si.rs @@ -0,0 +1,295 @@ +use std::error::Error; +use std::fs::File; +use std::io::{BufRead, BufReader, Write}; +use std::str; +use std::time::Instant; +extern crate fnv; +use clap::Parser; +use fnv::FnvHashMap; +extern crate libc; + +use superintervals::IntervalMap; +// use IntervalMap::IntervalMapEytz; + +use bincode; + +// Define a trait that all IntervalMap subclasses implement +// pub trait IntervalCollection { +// fn new() -> Self; +// fn add(&mut self, start: i32, end: i32, value: T); +// fn index(&mut self); +// } +// +// // Implement the trait for IntervalMap +// impl IntervalCollection for IntervalMap { +// fn new() -> Self { +// IntervalMap::new() +// } +// fn add(&mut self, start: i32, end: i32, value: T) { +// self.add(start, end, value); +// } +// fn index(&mut self) { +// self.index(); +// } +// } + +// Implement the trait for IntervalMapEytz +// impl IntervalCollection for IntervalMapEytz { +// fn new() -> Self { +// IntervalMapEytz::new() +// } +// fn add(&mut self, start: i32, end: i32, value: T) { +// self.add(start, end, value); +// } +// fn index(&mut self) { +// self.index(); +// } +// } + +type GenericError = Box; + +// Parse a i32 with no checking whatsoever. (e.g. non-number characters will just) +fn i32_from_bytes_uncheckd(s: &[u8]) -> i32 { + if s.is_empty() { + 0 + } else if s[0] == b'-' { + -s[1..].iter().fold(0, |a, b| a * 10 + (b & 0x0f) as i32) + } else { + s.iter().fold(0, |a, b| a * 10 + (b & 0x0f) as i32) + } +} + +fn parse_bed_line(line: &[u8]) -> (&str, i32, i32) { + let n = line.len() - 1; + let mut p = 0; + for c in &line[p..n] { + if *c == b'\t' { + break; + } + p += 1; + } + let seqname = unsafe { str::from_utf8_unchecked(&line[..p]) }; + p += 1; + let p0 = p; + + for c in &line[p..n] { + if *c == b'\t' { + break; + } + p += 1; + } + let first = i32_from_bytes_uncheckd(&line[p0..p]); + p += 1; + let p0 = p; + + for c in &line[p..n] { + if *c == b'\t' { + break; + } + p += 1; + } + let last = i32_from_bytes_uncheckd(&line[p0..p]) - 1; + + (seqname, first, last) +} + +// Read a bed file into a IntervalMap +// fn read_bed_file>(path: &str) -> Result, GenericError> { +fn read_bed_file(path: &str) -> Result>, GenericError> { + let mut nodes = FnvHashMap::default(); + let file = File::open(path)?; + let mut rdr = BufReader::new(file); + let mut line = Vec::new(); + while rdr.read_until(b'\n', &mut line).unwrap() > 0 { + let (seqname, first, last) = parse_bed_line(&line); + if seqname != "chr1" || last < 0 || last < first { + line.clear(); + continue; + } + // let intervals = nodes.entry(seqname.to_string()).or_insert_with(I::new); + let intervals = nodes + .entry(seqname.to_string()) + .or_insert_with(IntervalMap::new); + intervals.add(first, last, ()); + line.clear(); + } + let now = Instant::now(); + for intervals in nodes.values_mut() { + intervals.build(); + } + eprint!("{},", now.elapsed().as_micros()); + std::io::stderr().flush().unwrap(); + Ok(nodes) +} + +fn query_bed_files(filename_a: &str, filename_b: &str) -> Result<(), GenericError> { + let file = File::open(filename_b)?; + let mut rdr = BufReader::new(file); + let mut ranges: Vec<(i32, i32)> = Vec::new(); + let mut line = Vec::new(); + while rdr.read_until(b'\n', &mut line).unwrap() > 0 { + let (chrom, first, last) = parse_bed_line(&line); + if chrom != "chr1" || last < 0 || last < first { + line.clear(); + continue; + } + ranges.push((first, last)); + line.clear(); + } + // + eprint!("SuperIntervals-rs,"); + // let mut trees: FnvHashMap> = read_bed_file::>(filename_a)?; + let mut trees: FnvHashMap> = read_bed_file(filename_a)?; + let intervals: &mut IntervalMap<()> = trees + .get_mut("chr1") + .ok_or("Chromosome intervals not found")?; + + // Verify the deserialized data + let serialized_size = bincode::serialized_size(&intervals).unwrap(); + assert_ne!(serialized_size, 0); + + // Find overlaps (collecting results) + let mut total_found = 0; + let mut results = Vec::new(); + results.reserve(10000); + let mut now = Instant::now(); + for &(first, last) in &ranges { + intervals.search_values(first, last, &mut results); + total_found += results.len(); + results.clear(); + // for _value in intervals.search_keys_iter(first, last) { + // total_found += 1; + // } + } + eprint!("{},{},", now.elapsed().as_micros(), total_found); + std::io::stderr().flush().unwrap(); + + // Find overlaps2 + total_found = 0; + now = Instant::now(); + for &(first, last) in &ranges { + intervals.search_values_large(first, last, &mut results); + total_found += results.len(); + results.clear(); + } + eprint!("{},{},", now.elapsed().as_micros(), total_found); + std::io::stderr().flush().unwrap(); + + // Count overlaps + let mut n_overlaps = 0; + now = Instant::now(); + for &(first, last) in &ranges { + n_overlaps += intervals.count(first, last); + } + eprint!("{},{},", now.elapsed().as_micros(), n_overlaps); + std::io::stderr().flush().unwrap(); + + // Count overlaps exponential search + let mut n_overlaps = 0; + now = Instant::now(); + for &(first, last) in &ranges { + n_overlaps += intervals.count_large(first, last); + } + eprint!("{},{}\n", now.elapsed().as_micros(), n_overlaps); + std::io::stderr().flush().unwrap(); + + // + // eprint!("IntervalMapEytz-rs,"); + // let mut trees2: FnvHashMap> = read_bed_file::>(filename_a)?; + // let intervals2: &mut IntervalMapEytz<()> = trees2.get_mut("chr1").ok_or("Chromosome intervals not found")?; + // + // // Find overlaps (collecting results) + // total_found = 0; + // results.clear(); + // results.reserve(10000); + // now = Instant::now(); + // for &(first, last) in &ranges { + // intervals2.find_overlaps(first, last, &mut results); + // total_found += results.len(); + // results.clear(); + // } + // eprint!("{},{},", now.elapsed().as_micros(), total_found); + // std::io::stderr().flush().unwrap(); + // + // // Count overlaps + // n_overlaps = 0; + // now = Instant::now(); + // for &(first, last) in &ranges { + // n_overlaps += intervals2.count(first, last); + // } + // eprint!("{},{}\n", now.elapsed().as_micros(), n_overlaps); + // std::io::stderr().flush().unwrap(); + // + Ok(()) +} + +#[derive(Parser, Debug)] +#[command(about = " Find overlaps between two groups of intervals ")] +struct Args { + /// intervals to index + #[arg(value_name = "intervals.bed")] + input1: String, + /// query intervals + #[arg(value_name = "queries.bed")] + input2: String, + + /// compute proportion of queries covered + #[arg(short = 'c', long)] + coverage: bool, +} + +fn query_bed_files_coverage(filename_a: &str, filename_b: &str) -> Result<(), GenericError> { + // let mut trees: FnvHashMap> = read_bed_file::>(filename_a)?; + let mut trees: FnvHashMap> = read_bed_file(filename_a)?; + + let file = File::open(filename_b)?; + let mut rdr = BufReader::new(file); + let mut line = Vec::new(); + + let mut total_count: usize = 0; + let now = Instant::now(); + + while rdr.read_until(b'\n', &mut line).unwrap() > 0 { + let (seqname, first, last) = parse_bed_line(&line); + let mut count: usize = 0; + let mut cov: i32 = 0; + if let Some(seqname_tree) = trees.get_mut(seqname) { + let countcov = seqname_tree.coverage(first, last); + count = countcov.0; + cov = countcov.1; + } + + unsafe { + let linelen = line.len(); + line[linelen - 1] = b'\0'; + libc::printf( + b"%s\t%u\t%u\n\0".as_ptr() as *const libc::c_char, + line.as_ptr() as *const libc::c_char, + count as u32, + cov, + ); + } + total_count += count; + line.clear(); + } + + eprintln!("overlap: {}s", now.elapsed().as_millis() as f64 / 1000.0); + eprintln!("With coverage func total overlaps: {}", total_count); + + Ok(()) +} + +fn main() { + let matches = Args::parse(); + let input1 = matches.input1.as_str(); + let input2 = matches.input2.as_str(); + let result; + if matches.coverage { + result = query_bed_files_coverage(input1, input2); + } else { + result = query_bed_files(input1, input2); + } + if let Err(err) = result { + println!("error: {}", err) + } +} diff --git a/sequila/sequila-core/superintervals/pyproject.toml b/sequila/sequila-core/superintervals/pyproject.toml new file mode 100644 index 0000000..1b82656 --- /dev/null +++ b/sequila/sequila-core/superintervals/pyproject.toml @@ -0,0 +1,14 @@ +[build-system] +requires = [ + "setuptools>=42", + "wheel", + "Cython" +] +build-backend = "setuptools.build_meta" + +[project] +name = "superintervals" +version = "0.3.6" +description = "Rapid interval intersections" +dependencies = ['Cython'] +authors = [{name = "Kez Cleal", email = "clealk@cardiff.ac.uk"}] \ No newline at end of file diff --git a/sequila/sequila-core/superintervals/setup.py b/sequila/sequila-core/superintervals/setup.py new file mode 100644 index 0000000..5d9fef8 --- /dev/null +++ b/sequila/sequila-core/superintervals/setup.py @@ -0,0 +1,22 @@ +from setuptools import setup, find_packages, Extension +from Cython.Build import cythonize + +ext_modules = [ + Extension("superintervals.intervalmap", + ["src/superintervals/intervalmap.pyx"], + include_dirs=["src"], + language="c++", + extra_compile_args=["-std=c++17"]) +] + +print('PAKCAGES', find_packages(where='src')) # Add this line for debugging + +setup( + name='superintervals', + description="Rapid interval intersections", + author="Kez Cleal", + author_email="clealk@cardiff.ac.uk", + packages=find_packages(where='src'), + package_dir={"": "src"}, + ext_modules=cythonize(ext_modules), +) \ No newline at end of file diff --git a/sequila/sequila-core/superintervals/src/README_cpp.md b/sequila/sequila-core/superintervals/src/README_cpp.md new file mode 100644 index 0000000..b10bcfc --- /dev/null +++ b/sequila/sequila-core/superintervals/src/README_cpp.md @@ -0,0 +1,104 @@ +# Superintervals C++ API + + +Header only implementation, copy to your include directory. + +```cpp +#include "SuperIntervals.hpp" + +si::IntervalMap imap; +imap.add(1, 5, "A"); +imap.build(); + +// Collect results into a vector +std::vector results; +imap.search_values(4, 9, results); + +// Or use lazy iterator interfaces +for (const auto &value : imap.search_values_iter(query_start, query_end)) { + std::cout << "Found: " << value << std::endl; +} +``` + +### API Reference + +**IntervalMap** class (also **IntervalMapEytz** for Eytzinger layout): + +- `void add(S start, S end, const T& value)` + Add interval with associated value + + +- `void build()` + Build index (required before queries) + + +- `void clear()` + Remove all intervals + + +- `void reserve(size_t n)` + Reserve space for n intervals + + +- `size_t size()` + Get number of intervals + + +- `const Interval& at(size_t index)` + Get Interval at index + + +- `void at(size_t index, Interval& interval)` + Fill provided Interval object + + +- `bool has_overlaps(S start, S end)` + Check if any intervals overlap range + + +- `size_t count(S start, S end)` + Count overlapping intervals (SIMD optimized) + + +- `size_t count_linear(S start, S end)` + Count overlapping intervals (linear) + + +- `size_t count_large(S start, S end)` + Count optimized for large ranges + + +- `void search_values(S start, S end, std::vector& found)` + Fill vector with values of overlapping intervals + + +- `void search_values_large(S start, S end, std::vector& found)` + Search optimized for large ranges + + +- `void search_idxs(S start, S end, std::vector& found)` + Fill vector with indices of overlapping intervals + + +- `void search_keys(S start, S end, std::vector>& found)` + Fill vector with (start,end) pairs + + +- `void search_items(S start, S end, std::vector>& found)` + Fill vector with Interval objects + + +- `void search_point(S point, std::vector& found)` + Find intervals containing single point + + +- `void coverage(S start, S end, std::pair& result)` + Get pair(count, total_coverage) for range + + +- `IndexRange search_idxs(S start, S end)` + Returns IndexRange for range-based loops over indices + + +- `ItemRange search_items(S start, S end)` + Returns ItemRange for range-based loops over intervals \ No newline at end of file diff --git a/sequila/sequila-core/superintervals/src/README_rust.md b/sequila/sequila-core/superintervals/src/README_rust.md new file mode 100644 index 0000000..ac32a56 --- /dev/null +++ b/sequila/sequila-core/superintervals/src/README_rust.md @@ -0,0 +1,109 @@ +# Superintervals rust API + +Add to your project using `cargo add superintervals` + +```rust +use superintervals::IntervalMap; + +let mut imap = IntervalMap::new(); +imap.add(1, 5, "A"); +imap.build(); + +// Collect results into a Vec +let mut results = Vec::new(); +imap.search_values(4, 11, &mut results); + +// Or use lazy iterator interfaces +for value in imap.search_values_iter(query_start, query_end) { + println!("Found: {}", value); +} +``` + +### API Reference + +**IntervalMap** struct (also **IntervalMapEytz** for Eytzinger layout): + +- `fn new() -> Self` + Create new IntervalMap + + +- `fn add(&mut self, start: i32, end: i32, value: T)` + Add interval with associated value + + +- `fn build(&mut self)` + Build index (required before queries) + + +- `fn clear(&mut self)` + Remove all intervals + + +- `fn reserve(&mut self, n: usize)` + Reserve space for n intervals + + +- `fn size(&self) -> usize` + Get number of intervals + + +- `fn at(&self, index: usize) -> Interval` + Get Interval at index + + +- `fn has_overlaps(&mut self, start: i32, end: i32) -> bool` + Check if any intervals overlap range + + +- `fn count(&mut self, start: i32, end: i32) -> usize` + Count overlapping intervals (SIMD optimized) + + +- `fn count_linear(&mut self, start: i32, end: i32) -> usize` + Count overlapping intervals (linear) + + +- `fn count_large(&mut self, start: i32, end: i32) -> usize` + Count optimized for large ranges + + +- `fn search_values(&mut self, start: i32, end: i32, found: &mut Vec)` + Fill vector with values of overlapping intervals + + +- `fn search_values_large(&mut self, start: i32, end: i32, found: &mut Vec)` + Search optimized for large ranges + + +- `fn search_idxs(&mut self, start: i32, end: i32, found: &mut Vec)` + Fill vector with indices of overlapping intervals + + +- `fn search_keys(&mut self, start: i32, end: i32, found: &mut Vec<(i32, i32)>)` + Fill vector with (start,end) pairs + + +- `fn search_items(&mut self, start: i32, end: i32, found: &mut Vec>)` + Fill vector with Interval objects + + +- `fn search_stabbed(&mut self, point: i32, found: &mut Vec)` + Find intervals containing single point + + +- `fn coverage(&mut self, start: i32, end: i32) -> (usize, i32)` + Get (count, total_coverage) for range + + +- `fn search_idxs_iter(&mut self, start: i32, end: i32) -> IndexIterator` + Iterator over indices + + +- `fn search_items_iter(&mut self, start: i32, end: i32) -> ItemIterator` + Iterator over intervals + +- `fn search_keys_iter(&mut self, start: i32, end: i32) -> KeyIterator` + Iterator over keys + +- `fn search_values_iter(&mut self, start: i32, end: i32) -> ValueIterator` + Iterator over values \ No newline at end of file diff --git a/sequila/sequila-core/superintervals/src/c_superintervals.h b/sequila/sequila-core/superintervals/src/c_superintervals.h new file mode 100644 index 0000000..2ce03bc --- /dev/null +++ b/sequila/sequila-core/superintervals/src/c_superintervals.h @@ -0,0 +1,315 @@ +#ifdef __cplusplus +extern "C" { +#endif + +#ifndef SUPERINTERVALS_HEADER_INCLUDED +#define SUPERINTERVALS_HEADER_INCLUDED 1 +#endif +#include +#include +#include +#include + +#ifdef __AVX2__ + #include +#elif defined __ARM_NEON + #include +#endif + +// Define the Interval struct +typedef struct { + int32_t start; + int32_t end; + int32_t data; +} Interval; + +// Define the SuperIntervals struct +typedef struct { + int32_t* starts; + int32_t* ends; + int32_t* data; + size_t* branch; + int32_t* extent; + size_t size; + size_t capacity; + size_t idx; + bool startSorted; + bool endSorted; +} cSuperIntervals; + +// Function prototypes +cSuperIntervals* createSuperIntervals(); +void destroySuperIntervals(cSuperIntervals* si); +void clearSuperIntervals(cSuperIntervals* si); +void reserveSuperIntervals(cSuperIntervals* si, size_t n); +void addInterval(cSuperIntervals* si, int32_t start, int32_t end, int32_t value); +void sortIntervals(cSuperIntervals* si); +void indexSuperIntervals(cSuperIntervals* si, bool use_linear); +void upperBound(cSuperIntervals* si, int32_t value); +bool anyOverlaps(cSuperIntervals* si, int32_t start, int32_t end); +void findOverlaps(cSuperIntervals* si, int32_t start, int32_t end, int32_t* found, size_t* found_size); +size_t countOverlaps(cSuperIntervals* si, int32_t start, int32_t end); + +// Helper function prototypes +static void sortBlock(cSuperIntervals* si, size_t start_i, size_t end_i, int (*compare)(const void*, const void*)); +static size_t eytzinger_helper(int32_t* arr, size_t n, size_t i, size_t k, int32_t* eytz, size_t* eytz_index); +static int eytzinger(int32_t* arr, size_t n, int32_t* eytz, size_t* eytz_index); + +cSuperIntervals* createSuperIntervals() { + cSuperIntervals* si = (cSuperIntervals*)malloc(sizeof(cSuperIntervals)); + si->starts = NULL; + si->ends = NULL; + si->data = NULL; + si->branch = NULL; + si->extent = NULL; + si->size = 0; + si->capacity = 0; + si->idx = 0; + si->startSorted = true; + si->endSorted = true; + return si; +} + +void destroySuperIntervals(cSuperIntervals* si) { + free(si->starts); + free(si->ends); + free(si->data); + free(si->branch); + free(si->extent); + free(si); +} + +void clearSuperIntervals(cSuperIntervals* si) { + si->size = 0; + si->idx = 0; +} + +void reserveSuperIntervals(cSuperIntervals* si, size_t n) { + if (n > si->capacity) { + si->capacity = n; + si->starts = (int32_t*)realloc(si->starts, n * sizeof(int32_t)); + si->ends = (int32_t*)realloc(si->ends, n * sizeof(int32_t)); + si->data = (int32_t*)realloc(si->data, n * sizeof(int32_t)); + } +} + +void addInterval(cSuperIntervals* si, int32_t start, int32_t end, int32_t value) { + if (si->size >= si->capacity) { + size_t new_capacity = si->capacity == 0 ? 1 : si->capacity * 2; + reserveSuperIntervals(si, new_capacity); + } + + if (si->startSorted && si->size > 0) { + si->startSorted = (start < si->starts[si->size - 1]) ? false : true; + if (si->startSorted && start == si->starts[si->size - 1] && end > si->ends[si->size - 1]) { + si->endSorted = false; + } + } + + si->starts[si->size] = start; + si->ends[si->size] = end; + si->data[si->size] = value; + si->size++; +} + +static int compareIntervalsStart(const void* a, const void* b) { + const Interval* ia = (const Interval*)a; + const Interval* ib = (const Interval*)b; + if (ia->start != ib->start) { + return ia->start - ib->start; + } + return ib->end - ia->end; // Sort start, and end in descending order +} + +static int compareIntervalsEnd(const void* a, const void* b) { + const Interval* ia = (const Interval*)a; + const Interval* ib = (const Interval*)b; + return ib->end - ia->end; // Sort by end in descending order +} + +static void sortBlock(cSuperIntervals* si, size_t start_i, size_t end_i, + int (*compare)(const void*, const void*)) { + size_t range_size = end_i - start_i; + Interval* tmp = (Interval*)malloc(range_size * sizeof(Interval)); + + for (size_t i = 0; i < range_size; ++i) { + tmp[i].start = si->starts[start_i + i]; + tmp[i].end = si->ends[start_i + i]; + tmp[i].data = si->data[start_i + i]; + } + + qsort(tmp, range_size, sizeof(Interval), compare); + + for (size_t i = 0; i < range_size; ++i) { + si->starts[start_i + i] = tmp[i].start; + si->ends[start_i + i] = tmp[i].end; + si->data[start_i + i] = tmp[i].data; + } + + free(tmp); +} + +void sortIntervals(cSuperIntervals* si) { + if (!si->startSorted) { + sortBlock(si, 0, si->size, compareIntervalsStart); + si->startSorted = true; + si->endSorted = true; + } else if (!si->endSorted) { + size_t it_start = 0; + while (it_start < si->size) { + size_t block_end = it_start + 1; + bool needs_sort = false; + while (block_end < si->size && si->starts[block_end] == si->starts[it_start]) { + if (block_end > it_start && si->ends[block_end] > si->ends[block_end - 1]) { + needs_sort = true; + } + ++block_end; + } + if (needs_sort) { + sortBlock(si, it_start, block_end, compareIntervalsEnd); + } + it_start = block_end; + } + si->endSorted = true; + } +} + +static size_t eytzinger_helper(int32_t* arr, size_t n, size_t i, size_t k, int32_t* eytz, size_t* eytz_index) { + if (k < n) { + i = eytzinger_helper(arr, n, i, 2*k+1, eytz, eytz_index); + eytz[k] = arr[i]; + eytz_index[k] = i; + ++i; + i = eytzinger_helper(arr, n, i, 2*k + 2, eytz, eytz_index); + } + return i; +} + +static int eytzinger(int32_t* arr, size_t n, int32_t* eytz, size_t* eytz_index) { + return eytzinger_helper(arr, n, 0, 0, eytz, eytz_index); +} + +void indexSuperIntervals(cSuperIntervals* si, bool use_linear) { + if (si->size == 0) { + return; + } + + sortIntervals(si); + + int32_t* eytz = (int32_t*)malloc((si->size + 1) * sizeof(int32_t)); + size_t* eytz_index = (size_t*)malloc((si->size + 1) * sizeof(size_t)); + eytzinger(si->starts, si->size, eytz, eytz_index); + + si->branch = (size_t*)realloc(si->branch, si->size * sizeof(size_t)); + memset(si->branch, -1, si->size * sizeof(size_t)); + + if (!use_linear) { + si->extent = (int32_t*)malloc(si->size * sizeof(int32_t)); + memcpy(si->extent, si->ends, si->size * sizeof(int32_t)); + + for (size_t i = 0; i < si->size - 1; ++i) { + int32_t e = si->ends[i]; + for (size_t j = i + 1; j < si->size; ++j) { + if (si->ends[j] >= si->ends[i]) { + break; + } + si->branch[j] = i; + if (e > si->extent[j]) { + si->extent[j] = e; + } + } + } + } else { + // Linear branch implementation + size_t* br = (size_t*)malloc(si->size * sizeof(size_t)); + int32_t* br_ends = (int32_t*)malloc(si->size * sizeof(int32_t)); + size_t br_size = 0; + + br[br_size] = 0; + br_ends[br_size] = si->ends[0]; + br_size++; + + for (size_t i = 1; i < si->size; ++i) { + while (br_size > 0 && br_ends[br_size - 1] < si->ends[i]) { + br_size--; + } + if (br_size > 0) { + si->branch[i] = br[br_size - 1]; + } + br[br_size] = i; + br_ends[br_size] = si->ends[i]; + br_size++; + } + + free(br); + free(br_ends); + } + + free(eytz); + free(eytz_index); + si->idx = 0; +} + + +void upperBound(cSuperIntervals* si, int32_t value) { + size_t length = si->size - 1; + si->idx = 0; + const int entries_per_256KB = 256 * 1024 / sizeof(int32_t); + const int num_per_cache_line = 64 / sizeof(int32_t) > 1 ? 64 / sizeof(int32_t) : 1; + + if (length >= entries_per_256KB) { + while (length >= 3 * num_per_cache_line) { + size_t half = length / 2; +// __builtin_prefetch(&si->starts[si->idx + half / 2]); + size_t first_half1 = si->idx + (length - half); +// __builtin_prefetch(&si->starts[first_half1 + half / 2]); + si->idx += (si->starts[si->idx + half] <= value) * (length - half); + length = half; + } + } + + while (length > 0) { + size_t half = length / 2; + si->idx += (si->starts[si->idx + half] <= value) * (length - half); + length = half; + } + + if (si->idx > 0 && (si->idx == si->size - 1 || si->starts[si->idx] > value)) { + --si->idx; + } +} + +bool anyOverlaps(cSuperIntervals* si, int32_t start, int32_t end) { + upperBound(si, end); + return start <= si->ends[si->idx]; +} + +void findOverlaps(cSuperIntervals* si, int32_t start, int32_t end, int32_t* found, size_t* found_size) { + if (si->size == 0) { + *found_size = 0; + return; + } + upperBound(si, end); + size_t i = si->idx; + *found_size = 0; + while (i > 0) { + if (start <= si->ends[i]) { + found[(*found_size)++] = si->data[i]; + --i; + + } else { + if (si->branch[i] >= i) { // segfaults here + break; + } + i = si->branch[i]; + } + } + if (i == 0 && start <= si->ends[0] && si->starts[0] <= end) { + found[(*found_size)++] = si->data[0]; + } +} + + +#ifdef __cplusplus +} // extern C +#endif \ No newline at end of file diff --git a/sequila/sequila-core/superintervals/src/lib.rs b/sequila/sequila-core/superintervals/src/lib.rs new file mode 100644 index 0000000..5983007 --- /dev/null +++ b/sequila/sequila-core/superintervals/src/lib.rs @@ -0,0 +1,4 @@ +mod superintervals; + +// Public API - everything available at root level +pub use superintervals::*; diff --git a/sequila/sequila-core/superintervals/src/superintervals.hpp b/sequila/sequila-core/superintervals/src/superintervals.hpp new file mode 100644 index 0000000..5203921 --- /dev/null +++ b/sequila/sequila-core/superintervals/src/superintervals.hpp @@ -0,0 +1,1009 @@ +// Version 0.3.6 +#pragma once + +#include +#include +#include +#include +#include +#include +#ifndef SI_NOSIMD + #if defined(__AVX2__) + #include + #elif defined(__ARM_NEON__) || defined(__aarch64__) + #include + #else + #define SI_NOSIMD + #endif +#endif + +namespace si { + +// Core data structure +template +struct Interval { + S start, end; + T data; + Interval() = default; + Interval(S s, S e, T d) : start(s), end(e), data(d) {} +}; + +/** + * @file SuperIntervals.hpp + * @brief IntervalMap is an associative data structure for finding interval intersections + * + * IntervalMap is a template class that provides efficient interval intersection operations. + * It supports adding intervals, indexing them for fast queries, and performing various + * intersection operations using an implicit superinterval tree algorithm. + * + * @note Intervals are considered end-inclusive + * @note The build() function must be called before any queries. If more intervals are added, call build() again. + * + * @tparam S The scalar type for interval start and end points (e.g., int, float) + * @tparam T The data type associated with each interval + */ +template + +class IntervalMap { + public: + std::vector starts; +// std::vector ends; + + #ifdef __AVX2__ + alignas(32) std::vector ends; + #elif defined(__ARM_NEON) + alignas(16) std::vector ends; + #else + std::vector ends; + #endif + std::vector branch; + std::vector data; + bool start_sorted, end_sorted; + + IntervalMap() : start_sorted(true), end_sorted(true) {} + virtual ~IntervalMap() = default; + + /** + * @brief Clears all intervals and resets the data structure + */ + void clear() noexcept { + data.clear(); starts.clear(); ends.clear(); branch.clear(); // idx = 0; + } + + /** + * @brief Reserves memory for a specified number of intervals + * @param n Number of intervals to reserve space for + */ + void reserve(size_t n) { + data.reserve(n); starts.reserve(n); ends.reserve(n); + } + + /** + * @brief Returns the number of intervals in the data structure + * @return Number of intervals + */ + size_t size() { + return starts.size(); + } + + /** + * @brief Adds a new interval to the data structure + * @param start Start point of the interval + * @param end End point of the interval + * @param value Data associated with the interval + */ + void add(S start, S end, const T& value) { + if (start_sorted && !starts.empty()) { + start_sorted = (start < starts.back()) ? false : true; + if (start_sorted && start == starts.back() && end > ends.back()) { + end_sorted = false; + } + } + starts.push_back(start); + ends.push_back(end); + data.emplace_back(value); + } + + /** + * @brief Build the index. + * This function must be called after adding intervals and before performing any queries. + * If more intervals are added after indexing, this function should be called again. + */ + virtual void build() { + if (starts.size() == 0) { + return; + } + sort_intervals(); + branch.resize(starts.size(), SIZE_MAX); + std::vector> br; + br.reserve((starts.size() / 10) + 1); + br.emplace_back() = {ends[0], 0}; + for (size_t i=1; i < ends.size(); ++i) { + while (!br.empty() && br.back().first < ends[i]) { + br.pop_back(); + } + if (!br.empty()) { + branch[i] = br.back().second; + } + br.emplace_back() = {ends[i], i}; + } + } + + /** + * @brief Retrieves an interval at a specific index + * @param index The index of the interval to retrieve + * @return The Interval at the specified index + */ + const Interval& at(size_t index) const { + return Interval{starts[index], ends[index], data[index]}; + } + /** + * @brief Retrieves an interval at a specific index and writes it into the provided Interval object. + * @param index The index of the interval to retrieve. + * @param itv Reference to an Interval object to populate with the retrieved interval. + */ + void at(size_t index, Interval& itv) { + itv.start = starts[index]; + itv.end = ends[index]; + itv.data = data[index]; + } + + // Iterator interfaces + class IndexIterator { + public: + IndexIterator(const IntervalMap* parent, size_t pos, S query_start) + : parent_(parent), current_pos_(pos), query_start_(query_start), + has_value_(false), value_(0) { + next(); + } + + size_t operator*() const { return value_; } + + IndexIterator& operator++() { + next(); + return *this; + } + + bool operator!=(const IndexIterator& other) const { + return has_value_ != other.has_value_; + } + + bool operator==(const IndexIterator& other) const { + return has_value_ == other.has_value_; + } + + private: + void next() const { + if (current_pos_ == SIZE_MAX) { + has_value_ = false; + return; + } + if (query_start_ <= parent_->ends[current_pos_]) { + value_ = parent_->data[current_pos_]; + current_pos_ -= 1; + has_value_ = true; + return; + } + while (true) { + current_pos_ =parent_->branch[current_pos_]; + if (current_pos_ == SIZE_MAX) { + break; + } + if (query_start_ <= parent_->ends[current_pos_]) { + value_ = parent_->data[current_pos_]; + current_pos_ -= 1; + has_value_ = true; + return; + } + } + has_value_ = false; + return; + } + + const IntervalMap* parent_; + mutable size_t current_pos_; + S query_start_; + mutable bool has_value_; + mutable size_t value_; + }; + + class IndexRange { + public: + IndexRange(const IntervalMap* parent, S start, S end) + : parent_(parent), start_(start), end_(end) {} + IndexIterator begin() const { + const size_t pos = parent_->starts.empty() ? SIZE_MAX : parent_->upper_bound(end_); + return IndexIterator(parent_, pos, start_); + } + IndexIterator end() const { + return IndexIterator(parent_, SIZE_MAX, start_); + } + private: + const IntervalMap* parent_; + S start_, end_; + }; + + /** + * @brief Creates a range object for iterating over interval indices that intersect [start, end] + * @param start Start point of the search range + * @param end End point of the search range + * @return IndexRange object that can be used with range-based for loops + */ + IndexRange search_idxs(S start, S end) const noexcept { + return IndexRange(this, start, end); + } + + class ItemIterator { + public: + ItemIterator(const IntervalMap* parent, size_t pos, S query_start) + : parent_(parent), current_pos_(pos), query_start_(query_start), + has_value_(false), value_{} { + next(); + } + + size_t operator*() const { return value_; } + + ItemIterator& operator++() { + next(); + return *this; + } + + bool operator!=(const ItemIterator& other) const { + return has_value_ != other.has_value_; + } + + bool operator==(const ItemIterator& other) const { + return has_value_ == other.has_value_; + } + + private: + void next() const { + if (current_pos_ == SIZE_MAX) { + has_value_ = false; + return; + } + if (query_start_ <= parent_->ends[current_pos_]) { + value_.start = parent_->starts[current_pos_]; + value_.end = parent_->ends[current_pos_]; + value_.data = parent_->data[current_pos_]; + current_pos_ -= 1; + has_value_ = true; + return; + } + while (true) { + current_pos_ =parent_->branch[current_pos_]; + if (current_pos_ == SIZE_MAX) { + break; + } + if (query_start_ <= parent_->ends[current_pos_]) { + value_.start = parent_->starts[current_pos_]; + value_.end = parent_->ends[current_pos_]; + value_.data = parent_->data[current_pos_]; + current_pos_ -= 1; + has_value_ = true; + return; + } + } + has_value_ = false; + return; + } + + const IntervalMap* parent_; + mutable size_t current_pos_; + S query_start_; + mutable bool has_value_; + mutable Interval value_; + }; + + class ItemRange { + public: + ItemRange(const IntervalMap* parent, S start, S end) + : parent_(parent), start_(start), end_(end) {} + ItemIterator begin() const { + const size_t pos = parent_->starts.empty() ? SIZE_MAX : parent_->upper_bound(end_); + return ItemIterator(parent_, pos, start_); + } + ItemIterator end() const { + return ItemIterator(parent_, SIZE_MAX, start_); + } + private: + const IntervalMap* parent_; + S start_, end_; + }; + + /** + * @brief Creates a range object for iterating over interval items that intersect [start, end] + * @param start Start point of the search range + * @param end End point of the search range + * @return ItemRange object that can be used with range-based for loops + */ + ItemRange search_items(S start, S end) const noexcept { + return ItemRange(this, start, end); + } + + /** + * @brief Finds the largest index such that starts[index] ≤ value. + * @param value The upper bound value to search for. + * @note After calling this, idx will be set to that index (or SIZE_MAX if none). + */ + virtual inline size_t upper_bound(const S value) const noexcept { + size_t length = starts.size(); + size_t idx = 0; + while (length > 1) { + const size_t half = length / 2; + idx += (starts[idx + half] <= value) * (length - half); + length = half; + } + if (starts[idx] > value) { + --idx; // Might underflow to SIZE_MAX + } + return idx; + +// idx = std::distance(starts.begin(), +// std::upper_bound(starts.begin(), starts.end(), value)) - 1; + } + + /** + * @brief Narrows a [left…right) range so that starts[left] is the first element ≥ value. + * @param value The lower‐bound value to search for. + * @param left On entry, the lower end of the search range; on exit, the found index or SIZE_MAX. + * @param right The upper end of the search range (exclusive). + */ + inline void upper_bound_range(const S value, size_t& left, const size_t right) const noexcept { + // First do exponential search if we have room + size_t search_right = right; + size_t bound = 1; + while (left > 0 && value < starts[left]) { + search_right = left; + left = (bound <= left) ? left - bound : 0; + bound *= 2; + } + // Now do binary search in the range + size_t length = search_right - left; + while (length > 1) { + const size_t half = length / 2; + left += (starts[left + half] < value) * (length - half); + length = half; + } + if (left == 0 && starts[left] >= value) { + left = SIZE_MAX; + } + } + + /** + * @brief Collects all data values whose intervals intersect [start,end]. + * @param start The start of the query interval. + * @param end The end of the query interval. + * @param found Output vector that will be filled with matching data values. + */ + void search_values(const S start, const S end, std::vector& found) const { + if (starts.empty()) { + return; + } + const size_t idx = upper_bound(end); + if (idx == SIZE_MAX) { + return; + } + size_t i = idx; + while (i != SIZE_MAX && start <= ends[i]) { + --i; + } + if (i == SIZE_MAX) { + found.insert(found.end(), data.rend() - idx - 1, data.rend()); + return; + } + found.insert(found.end(), data.rend() - idx - 1, data.rend() - i - 1); + // This is slightly faster, but would lose the sort order: + // found.insert(found.end(), data.begin() + i, data.begin() + idx); + i = branch[i]; + while (i != SIZE_MAX) { + if (start <= ends[i]) { + found.push_back(data[i]); + --i; + } else { + i = branch[i]; + } + } + } + + /** + * @brief Like search_intervals, but optimized for large query ranges. + * @note Uses exponential search - best when query range is large relative to stored intervals + * @param start The start of the query interval. + * @param end The end of the query interval. + * @param found Output vector that will be filled with matching data values. + */ + void search_values_large(const S start, const S end, std::vector& found) const { + if (starts.empty()) { + return; + } + const size_t idx = upper_bound(end); + if (idx == SIZE_MAX) { + return; + } + size_t i = idx; + upper_bound_range(start, i, idx); + while (i != SIZE_MAX && start <= ends[i]) { + --i; + } + if (i == SIZE_MAX) { + found.insert(found.end(), data.rend() - idx - 1, data.rend()); + return; + } + found.insert(found.end(), data.rend() - idx - 1, data.rend() - i - 1); + i = branch[i]; + while (i != SIZE_MAX) { + if (start <= ends[i]) { + found.push_back(data[i]); + --i; + } else { + i = branch[i]; + } + } + } + + /** + * @brief Counts how many intervals intersect [start,end]. + * @param start The start of the query interval. + * @param end The end of the query interval. + * @return Number of intervals overlapping the query. + */ + size_t count_linear(const S start, const S end) const noexcept { + if (starts.empty()) { + return 0; + } + size_t idx = upper_bound(end); + if (idx == SIZE_MAX) { + return 0; + } + size_t starting_i = idx; + while (idx != SIZE_MAX && start <= ends[idx]) { + --idx; + } + size_t count = starting_i - idx; + if (idx == SIZE_MAX) { + return count; + } + idx = branch[idx]; + while (idx != SIZE_MAX) { + if (start <= ends[idx]) { + count += 1; + --idx; + } else { + idx = branch[idx]; + } + } + return count; + } + + size_t count(const S start, const S end) const noexcept { + if (starts.empty()) { + return 0; + } + size_t i = upper_bound(end); + if (i == SIZE_MAX) { + return 0; + } + size_t found = 0; + + #if defined(SI_NOSIMD) + + constexpr size_t block = 16; + constexpr simd_kind active_simd = simd_kind::none; + + #elif defined(__AVX2__) + + __m256i start_vec = _mm256_set1_epi32(start); + constexpr size_t simd_width = 256 / (sizeof(S) * 8); + constexpr size_t block = simd_width * 4; // 2 cache lines + constexpr simd_kind active_simd = simd_kind::avx2; + + #elif defined(__ARM_NEON__) || defined(__aarch64__) + + int32x4_t start_vec = vdupq_n_s32(start); + constexpr size_t simd_width = 128 / (sizeof(S) * 8); + uint32x4_t ones = vdupq_n_u32(1); + constexpr size_t block = simd_width * 8; // 2 cache lines + constexpr simd_kind active_simd = simd_kind::neon; + + #endif + + while (i > 0) { + if (start <= ends[i]) { + ++found; + --i; + // Types with width !=4 will use the no-simd path here + if constexpr (active_simd == simd_kind::none || sizeof(S) != 4) { + while (i > block) { + size_t count = 0; + for (size_t j = i; j > i - block; --j) { + count += (start <= ends[j]) ? 1 : 0; + } + found += count; + i -= block; + if (count < block && start > ends[i + 1]) { // check for a branch + break; + } + } + } + #if defined(__AVX2__) + else if constexpr (active_simd == simd_kind::avx2) { +// while (i > block) { +// size_t count = 0; +// for (size_t j = i - block + 1; j < i; j += simd_width) { +// __m256i ends_vec = _mm256_loadu_si256((__m256i*)(&ends[j - simd_width + 1])); +// __m256i cmp_mask = _mm256_cmpgt_epi32(start_vec, ends_vec); +// int mask = _mm256_movemask_ps(_mm256_castsi256_ps(cmp_mask)); +// count += 8 - _mm_popcnt_u32(mask); +// } +// found += count; +// i -= block; +// if (count < block) { +// break; +// } +// } + + while (i > block) { + size_t j = i - block + 1; + + // Load all 4 vectors + __m256i ends_vec0 = _mm256_load_si256((__m256i*)(&ends[j])); + __m256i ends_vec1 = _mm256_load_si256((__m256i*)(&ends[j + simd_width])); + __m256i ends_vec2 = _mm256_load_si256((__m256i*)(&ends[j + 2 * simd_width])); + __m256i ends_vec3 = _mm256_load_si256((__m256i*)(&ends[j + 3 * simd_width])); + + // Compare all vectors + __m256i cmp_mask0 = _mm256_cmpgt_epi32(start_vec, ends_vec0); + __m256i cmp_mask1 = _mm256_cmpgt_epi32(start_vec, ends_vec1); + __m256i cmp_mask2 = _mm256_cmpgt_epi32(start_vec, ends_vec2); + __m256i cmp_mask3 = _mm256_cmpgt_epi32(start_vec, ends_vec3); + + // Extract masks + int mask0 = _mm256_movemask_ps(_mm256_castsi256_ps(cmp_mask0)); + int mask1 = _mm256_movemask_ps(_mm256_castsi256_ps(cmp_mask1)); + int mask2 = _mm256_movemask_ps(_mm256_castsi256_ps(cmp_mask2)); + int mask3 = _mm256_movemask_ps(_mm256_castsi256_ps(cmp_mask3)); + + // Count and accumulate + size_t count = (8 - _mm_popcnt_u32(mask0)) + (8 - _mm_popcnt_u32(mask1)) + + (8 - _mm_popcnt_u32(mask2)) + (8 - _mm_popcnt_u32(mask3)); + + found += count; + i -= block; + if (count < block) { + break; + } + } + } + #elif defined(__ARM_NEON__) || defined(__aarch64__) + else { // NEON +// while (i > block) { +// size_t count = 0; +// uint32x4_t mask, bool_mask; +// for (size_t j = i - block + 1; j < i; j += simd_width) { // Neon 4 int32 at a time +// int32x4_t ends_vec = vld1q_s32(&ends[j]); +// mask = vcgtq_s32(start_vec, ends_vec); // start > ends[j] +// bool_mask = vaddq_u32(mask, ones); +// count += vaddvq_u32(bool_mask); // Sum all lanes +// } +// found += count; +// i -= block; +// if (count < block) { // check for overlap again, before checking for branch? +// break; +// } +// } + while (i > block) { + size_t j = i - block + 1; + + // Load all 8 vectors + int32x4_t ends_vec0 = vld1q_s32(&ends[j]); + int32x4_t ends_vec1 = vld1q_s32(&ends[j + simd_width]); + int32x4_t ends_vec2 = vld1q_s32(&ends[j + 2 * simd_width]); + int32x4_t ends_vec3 = vld1q_s32(&ends[j + 3 * simd_width]); + int32x4_t ends_vec4 = vld1q_s32(&ends[j + 4 * simd_width]); + int32x4_t ends_vec5 = vld1q_s32(&ends[j + 5 * simd_width]); + int32x4_t ends_vec6 = vld1q_s32(&ends[j + 6 * simd_width]); + int32x4_t ends_vec7 = vld1q_s32(&ends[j + 7 * simd_width]); + + // Compare all vectors + uint32x4_t mask0 = vcgtq_s32(start_vec, ends_vec0); + uint32x4_t mask1 = vcgtq_s32(start_vec, ends_vec1); + uint32x4_t mask2 = vcgtq_s32(start_vec, ends_vec2); + uint32x4_t mask3 = vcgtq_s32(start_vec, ends_vec3); + uint32x4_t mask4 = vcgtq_s32(start_vec, ends_vec4); + uint32x4_t mask5 = vcgtq_s32(start_vec, ends_vec5); + uint32x4_t mask6 = vcgtq_s32(start_vec, ends_vec6); + uint32x4_t mask7 = vcgtq_s32(start_vec, ends_vec7); + + // Convert to boolean masks + uint32x4_t bool_mask0 = vaddq_u32(mask0, ones); + uint32x4_t bool_mask1 = vaddq_u32(mask1, ones); + uint32x4_t bool_mask2 = vaddq_u32(mask2, ones); + uint32x4_t bool_mask3 = vaddq_u32(mask3, ones); + uint32x4_t bool_mask4 = vaddq_u32(mask4, ones); + uint32x4_t bool_mask5 = vaddq_u32(mask5, ones); + uint32x4_t bool_mask6 = vaddq_u32(mask6, ones); + uint32x4_t bool_mask7 = vaddq_u32(mask7, ones); + + // Sum all lanes and accumulate + size_t count = vaddvq_u32(bool_mask0) + vaddvq_u32(bool_mask1) + + vaddvq_u32(bool_mask2) + vaddvq_u32(bool_mask3) + + vaddvq_u32(bool_mask4) + vaddvq_u32(bool_mask5) + + vaddvq_u32(bool_mask6) + vaddvq_u32(bool_mask7); + + found += count; + i -= block; + if (count < block) { + break; + } + } + } + #endif + } else { + if (branch[i] == SIZE_MAX) { + return found; + } + i = branch[i]; + } + } + if (i==0 && start <= ends[0] && starts[0] <= end) { + ++found; + } + return found; + } + + /** + * @brief Like count, but optimized for large query ranges. + * @note Uses exponential search - best when query range is large relative to stored intervals + * @param start The start of the query interval. + * @param end The end of the query interval. + * @return Number of intervals overlapping the query. + */ + size_t count_large(const S start, const S end) const noexcept { + if (starts.empty()) { + return 0; + } + const size_t idx = upper_bound(end); + if (idx == SIZE_MAX) { + return 0; + } + size_t i = idx; + upper_bound_range(start, i, idx); + while (i != SIZE_MAX && start <= ends[i]) { + --i; + } + size_t count = idx - i; + while (i != SIZE_MAX) { + if (start <= ends[i]) { + count += 1; + --i; + } else { + i = branch[i]; + } + } + return count; + } + + /** + * @brief Tests whether any interval overlaps the point range [start,end]. + * @param start The start of the query interval. + * @param end The end of the query interval. + * @return true if at least one overlap exists, false otherwise. + */ + bool has_overlaps(const S start, const S end) const noexcept { + if (starts.empty()) { + return false; + } + const size_t idx = upper_bound(end); + return idx != SIZE_MAX && start <= ends[idx]; + } + + /** + * @brief Collects the indices of all intervals intersecting [start,end]. + * @param start The start of the query interval. + * @param end The end of the query interval. + * @param found Output vector that will be filled with matching interval indices. + */ + void search_idxs(const S start, const S end, std::vector& found) const { + if (starts.empty()) { + return; + } + const size_t idx = upper_bound(end); + if (idx == SIZE_MAX) { + return; + } + size_t i = idx; + while (i != SIZE_MAX && start <= ends[i]) { + --i; + } + if (i == SIZE_MAX) { + found.insert(found.end(), CountingIterator(0), CountingIterator(idx + 1)); + return; + } + found.insert(found.end(), CountingIterator(i + 1), CountingIterator(idx + 1)); + i = branch[i]; + while (i != SIZE_MAX) { + if (start <= ends[i]) { + found.push_back(i); + --i; + } else { + i = branch[i]; + } + } + } + + /** + * @brief Collects the (start,end) pairs of all intervals intersecting [start,end]. + * @param start The start of the query interval. + * @param end The end of the query interval. + * @param found Output vector that will be filled with matching interval key pairs. + */ + void search_keys(const S start, const S end, std::vector>& found) const { + if (starts.empty()) { + return; + } + const size_t idx = upper_bound(end); + if (idx == SIZE_MAX) { + return; + } + size_t i = idx; + while (i != SIZE_MAX && start <= ends[i]) { + found.emplace_back() = {starts[i], ends[i]}; + --i; + } + if (i == SIZE_MAX) { + return; + } + i = branch[i]; + while (i != SIZE_MAX) { + if (start <= ends[i]) { + found.emplace_back() = {starts[i], ends[i]}; + --i; + } else { + i = branch[i]; + } + } + } + + /** + * @brief Collects the full Interval objects intersecting [start,end]. + * @param start The start of the query interval. + * @param end The end of the query interval. + * @param found Output vector that will be filled with matching Interval instances. + */ + void search_items(const S start, const S end, std::vector>& found) const { + if (starts.empty()) { + return; + } + const size_t idx = upper_bound(end); + if (idx == SIZE_MAX) { + return; + } + size_t i = idx; + while (i != SIZE_MAX && start <= ends[i]) { + found.emplace_back() = {starts[i], ends[i], data[i]}; + --i; + } + if (i == SIZE_MAX) { + return; + } + i = branch[i]; + while (i != SIZE_MAX) { + if (start <= ends[i]) { + found.emplace_back() = {starts[i], ends[i], data[i]}; + --i; + } else { + i = branch[i]; + } + } + } + + /** + * @brief Computes how many intervals overlap [start,end] and the total covered length. + * @param start The start of the query interval. + * @param end The end of the query interval. + * @param cov_result Pair where first = count of overlaps, second = sum of overlapping lengths. + */ + void coverage(const S start, const S end, std::pair &cov_result) const { + if (starts.empty()) { + return; + } + size_t i = upper_bound(end); + if (i == SIZE_MAX) { + return; + } +// size_t i = idx; + while (i != SIZE_MAX && start <= ends[i]) { + ++cov_result.first; + cov_result.second += std::min(ends[i], end) - std::max(starts[i], start); + --i; + } + if (i == SIZE_MAX) { + return; + } + i = branch[i]; + while (i != SIZE_MAX) { + if (start <= ends[i]) { + ++cov_result.first; + cov_result.second += std::min(ends[i], end) - std::max(starts[i], start); + --i; + } else { + i = branch[i]; + } + } + } + + /** + * @brief Finds all data values of intervals “stabbed” by a single point. + * @param point The point to test. + * @param found Output vector that will be filled with data values of intervals containing point. + */ + void search_point(const S point, std::vector& found) const { + if (starts.empty()) { + return; + } + size_t i = upper_bound(point); +// size_t i = idx; + while (i != SIZE_MAX && point <= ends[i]) { + found.push_back(data[i]); + --i; + } + i = branch[i]; + while (i != SIZE_MAX) { + if (point <= ends[i]) { + found.push_back(data[i]); + --i; + } else { + i = branch[i]; + } + } + } + + protected: + + enum class simd_kind { none, avx2, neon }; + + std::vector> tmp; + + template + void sort_block(size_t start_i, size_t end_i, CompareFunc compare) { + size_t range_size = end_i - start_i; + tmp.resize(range_size); + for (size_t i = 0; i < range_size; ++i) { + tmp[i] = Interval(starts[start_i + i], ends[start_i + i], data[start_i + i]); + } + std::sort(tmp.begin(), tmp.end(), compare); + for (size_t i = 0; i < range_size; ++i) { + starts[start_i + i] = tmp[i].start; + ends[start_i + i] = tmp[i].end; + data[start_i + i] = tmp[i].data; + } + } + + /** + * @brief Ensures the global interval list is properly sorted by start (and end within ties). + * If only ends need resorting among equal starts, does a more targeted pass. + */ + void sort_intervals() { + if (!start_sorted) { + sort_block(0, starts.size(), + [](const Interval& a, const Interval& b) { return (a.start < b.start || (a.start == b.start && a.end > b.end)); }); + start_sorted = true; + end_sorted = true; + } else if (!end_sorted) { // only sort parts that need sorting - ends in descending order + size_t it_start = 0; + while (it_start < starts.size()) { + size_t block_end = it_start + 1; + bool needs_sort = false; + while (block_end < starts.size() && starts[block_end] == starts[it_start]) { + if (block_end > it_start && ends[block_end] > ends[block_end - 1]) { + needs_sort = true; + } + ++block_end; + } + if (needs_sort) { + sort_block(it_start, block_end, [](const Interval& a, const Interval& b) { return a.end > b.end; }); + } + it_start = block_end; + } + end_sorted = true; + } + } + + struct CountingIterator { // Just for using insert indexes + size_t value; + using iterator_category = std::forward_iterator_tag; + using value_type = size_t; + using difference_type = std::ptrdiff_t; + using pointer = const size_t*; + using reference = size_t; + + explicit CountingIterator(size_t v) : value(v) {} + size_t operator*() const { return value; } + CountingIterator& operator++() { ++value; return *this; } + bool operator!=(const CountingIterator& other) const { return value != other.value; } + }; +}; + + +template +class IntervalMapEytz : public IntervalMap { +public: + /** + * @brief Builds the Eytzinger‐layout index and branch pointers for fast searching. + * @note Overrides the base build() to use the cache‐friendly layout. + */ + void build() override { + if (this->starts.size() == 0) { + return; + } + this->starts.shrink_to_fit(); + this->ends.shrink_to_fit(); + this->data.shrink_to_fit(); + this->sort_intervals(); + + eytz.resize(this->starts.size() + 1); + eytz_index.resize(this->starts.size() + 1); + eytzinger(&this->starts[0], this->starts.size()); + + this->branch.resize(this->starts.size(), SIZE_MAX); + std::vector> br; + br.reserve(1000); + br.emplace_back() = {this->ends[0], 0}; + for (size_t i=1; i < this->ends.size(); ++i) { + while (!br.empty() && br.back().first < this->ends[i]) { + br.pop_back(); + } + if (!br.empty()) { + this->branch[i] = br.back().second; + } + br.emplace_back() = {this->ends[i], i}; + } +// this->idx = 0; + } + + /** + * @brief Uses the Eytzinger‐layout tree to locate the largest start ≤ x. + * @param x The search value. + * @note Overrides the base upper_bound to navigate the implicit binary tree. + */ + inline size_t upper_bound(const S x) const noexcept override { + size_t i = 0; + const size_t n_intervals = this->starts.size(); + while (i < n_intervals) { + if (eytz[i] > x) { + i = 2 * i + 1; + } else { + i = 2 * i + 2; + } + } + int shift = __builtin_ffs(~(i + 1)); + size_t best_idx = (i >> shift) - ((shift > 1) ? 1 : 0); + i = (best_idx < n_intervals) ? eytz_index[best_idx] : n_intervals - 1; + if (i > 0 && this->starts[i] > x) { + --i; + } + return i; + } + +private: + std::vector eytz; + std::vector eytz_index; + + size_t eytzinger_helper(S* arr, size_t n, size_t i, size_t k) { + if (k < n) { + i = eytzinger_helper(arr, n, i, 2*k+1); + eytz[k] = this->starts[i]; + eytz_index[k] = i; + ++i; + i = eytzinger_helper(arr, n, i, 2*k + 2); + } + return i; + } + + int eytzinger(S* arr, size_t n) { + return eytzinger_helper(arr, n, 0, 0); + } +}; + +} // si diff --git a/sequila/sequila-core/superintervals/src/superintervals.rs b/sequila/sequila-core/superintervals/src/superintervals.rs new file mode 100644 index 0000000..bcbc9cf --- /dev/null +++ b/sequila/sequila-core/superintervals/src/superintervals.rs @@ -0,0 +1,1063 @@ +//! This module provides an associative data structure for performing interval intersection queries. + +use aligned_vec::{AVec, ConstAlign}; +use serde::{Deserialize, Serialize}; +use std::cmp::Ordering; +use std::cmp::{max, min}; + +/// Represents an interval with associated data. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Interval { + pub start: i32, + pub end: i32, + pub data: T, +} + +#[cfg(target_feature = "avx2")] +type AlignedEnds = AVec>; + +#[cfg(all(target_feature = "neon", not(target_feature = "avx2")))] +type AlignedEnds = AVec>; + +#[cfg(not(any(target_feature = "avx2", target_feature = "neon")))] +type AlignedEnds = Vec; + +/// A static data structure for finding interval intersections +/// +/// IntervalMap is a template class that provides efficient interval intersection operations. +/// It supports adding intervals, indexing them for fast queries, and performing various +/// intersection operations. +/// +/// Intervals are considered end-inclusive +/// The build() function must be called before any queries. If more intervals are added, call build() again. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct IntervalMap { + pub starts: Vec, + // pub ends: Vec, + pub ends: AlignedEnds, + pub branch: Vec, + pub data: Vec, + pub start_sorted: bool, + pub end_sorted: bool, +} + +impl IntervalMap { + pub fn new() -> Self { + IntervalMap { + starts: Vec::new(), + // ends: Vec::new(), + ends: { + #[cfg(target_feature = "avx2")] + { + AVec::new(32) + } + + #[cfg(all(target_feature = "neon", not(target_feature = "avx2")))] + { + AVec::new(16) + } + + #[cfg(not(any(target_feature = "avx2", target_feature = "neon")))] + { + Vec::new() + } + }, + + branch: Vec::new(), + data: Vec::new(), + start_sorted: true, + end_sorted: true, + } + } + /// Clears all intervals from the structure. + pub fn clear(&mut self) { + self.starts.clear(); + self.ends.clear(); + self.branch.clear(); + self.data.clear(); + } + + pub fn reserve(&mut self, n: usize) { + self.starts.reserve(n); + self.ends.reserve(n); + self.data.reserve(n); + } + + pub fn size(&self) -> usize { + self.starts.len() + } + + /// Adds a new interval to the structure. + /// + /// # Arguments + /// + /// * `start` - The start point of the interval. + /// * `end` - The end point of the interval. + /// * `value` - The data associated with the interval. + pub fn add(&mut self, start: i32, end: i32, value: T) { + if self.start_sorted && !self.starts.is_empty() { + self.start_sorted = start >= *self.starts.last().unwrap(); + if self.start_sorted + && start == *self.starts.last().unwrap() + && end > *self.ends.last().unwrap() + { + self.end_sorted = false; + } + } + self.starts.push(start); + self.ends.push(end); + self.data.push(value); + } + + pub fn at(&self, index: usize) -> Interval { + Interval { + start: self.starts[index], + end: self.ends[index], + data: self.data[index].clone(), + } + } + + /// Sorts the intervals by start and -end. + pub fn sort_intervals(&mut self) { + if !self.start_sorted { + self.sort_block(0, self.starts.len(), |a, b| { + if a.start < b.start { + Ordering::Less + } else if a.start == b.start { + if a.end > b.end { + Ordering::Less + } else { + Ordering::Greater + } + } else { + Ordering::Greater + } + }); + self.start_sorted = true; + self.end_sorted = true; + } else if !self.end_sorted { + let mut it_start = 0; + while it_start < self.starts.len() { + let mut block_end = it_start + 1; + let mut needs_sort = false; + while block_end < self.starts.len() + && self.starts[block_end] == self.starts[it_start] + { + if block_end > it_start && self.ends[block_end] > self.ends[block_end - 1] { + needs_sort = true; + } + block_end += 1; + } + if needs_sort { + self.sort_block(it_start, block_end, |a, b| a.end.cmp(&b.end).reverse()); + } + it_start = block_end; + } + self.end_sorted = true; + } + } + + /// Build the index. Must be called before queries are performed. + pub fn build(&mut self) { + if self.starts.is_empty() { + return; + } + self.sort_intervals(); + self.branch.resize(self.starts.len(), usize::MAX); + let mut br: Vec<(i32, usize)> = Vec::with_capacity((self.starts.len() / 10) + 1); + unsafe { + br.push((*self.ends.get_unchecked(0), 0)); + for i in 1..self.ends.len() { + while !br.is_empty() && br.last().unwrap().0 < *self.ends.get_unchecked(i) { + br.pop(); + } + if !br.is_empty() { + *self.branch.get_unchecked_mut(i) = br.last().unwrap().1; + } + br.push((*self.ends.get_unchecked(i), i)); + } + } + } + + #[inline(always)] + pub fn upper_bound(&self, value: i32) -> usize { + let mut idx = 0; + let mut length = self.starts.len(); + unsafe { + while length > 1 { + let half = length / 2; + idx += (*self.starts.get_unchecked(idx + half) <= value) as usize * (length - half); + length = half; + } + // idx = idx.wrapping_sub((*self.starts.get_unchecked(idx) > value) as usize); + if *self.starts.get_unchecked(idx) > value { + idx = idx.wrapping_sub(1); + } + } + return idx; + } + + #[inline(always)] + pub fn upper_bound_range(&self, value: i32, right: usize) -> usize { + let mut left = right; + let mut search_right = right; + let mut bound = 1; + // Exponential search to find a smaller range + unsafe { + while left > 0 && value < *self.starts.get_unchecked(left) { + search_right = left; + left = if bound <= left { left - bound } else { 0 }; + bound *= 2; + } + // Binary search in the found range + let mut length = search_right - left; + while length > 1 { + let half = length / 2; + let condition = *self.starts.get_unchecked(left + half) < value; + left += if condition { length - half } else { 0 }; + length = half; + } + if left == 0 && *self.starts.get_unchecked(left) >= value { + left = usize::MAX; + } + } + left + } + + pub fn has_overlaps(&self, start: i32, end: i32) -> bool { + if self.starts.is_empty() { + return false; + } + let idx = self.upper_bound(end); + unsafe { idx != usize::MAX && start <= *self.ends.get_unchecked(idx) } + } + + // This is the simple algorithm, but suffers performance wise due to the branching + // pub fn find_overlaps(&mut self, start: i32, end: i32, found: &mut Vec) { + // if self.starts.is_empty() { + // return; + // } + // let mut i = self.upper_bound(end); + // + // unsafe { + // while i != usize::MAX { + // if start <= *self.ends.get_unchecked(i) { + // found.push(self.data.get_unchecked(i).clone()); + // i = i.wrapping_sub(1); + // } else { + // i = *self.branch.get_unchecked(i); + // } + // } + // } + // } + + /// Finds all intervals that overlap with the given range. + /// + /// # Arguments + /// + /// * `start` - The start of the range to check for overlaps. + /// * `end` - The end of the range to check for overlaps. + /// * `found` - A mutable vector to store the overlapping intervals' data. + pub fn search_values(&self, start: i32, end: i32, found: &mut Vec) { + if self.starts.is_empty() { + return; + } + let idx = self.upper_bound(end); + if idx == usize::MAX { + return; + } + let mut i = idx; + unsafe { + while i != usize::MAX && start <= *self.ends.get_unchecked(i) { + i = i.wrapping_sub(1); + } + if i == usize::MAX { + found.reserve(idx + 1); + found.extend((0..=idx).rev().map(|j| self.data.get_unchecked(j).clone())); + return; + } + let count = idx - i; + found.reserve(count); + found.extend( + ((i + 1)..=idx) + .rev() + .map(|j| self.data.get_unchecked(j).clone()), + ); + i = *self.branch.get_unchecked(i); + while i != usize::MAX { + if start <= *self.ends.get_unchecked(i) { + found.push(self.data.get_unchecked(i).clone()); + i = i.wrapping_sub(1); + } else { + i = *self.branch.get_unchecked(i); + } + } + } + } + + /// Finds all intervals that overlap with the given range. Works best when + /// query intervals are large compared to stored database intervals. Uses + /// an exponential search to find overlaps + /// + /// # Arguments + /// + /// * `start` - The start of the range to check for overlaps. + /// * `end` - The end of the range to check for overlaps. + /// * `found` - A mutable vector to store the overlapping intervals' data. + pub fn search_values_large(&self, start: i32, end: i32, found: &mut Vec) { + if self.starts.is_empty() { + return; + } + let idx = self.upper_bound(end); + if idx == usize::MAX { + return; + } + let mut i = self.upper_bound_range(start, idx); + unsafe { + while i != usize::MAX && start <= *self.ends.get_unchecked(i) { + i = i.wrapping_sub(1); + } + if i == usize::MAX { + found.reserve(idx + 1); + found.extend((0..=idx).rev().map(|j| self.data.get_unchecked(j).clone())); + return; + } + found.reserve(idx - i); + found.extend( + ((i + 1)..=idx) + .rev() + .map(|j| self.data.get_unchecked(j).clone()), + ); + i = *self.branch.get_unchecked(i); + while i != usize::MAX { + if start <= *self.ends.get_unchecked(i) { + found.push(self.data.get_unchecked(i).clone()); + i = i.wrapping_sub(1); + } else { + i = *self.branch.get_unchecked(i); + } + } + } + } + + /// Counts all intervals that overlap with the given range. + /// + /// # Arguments + /// + /// * `start` - The start of the range to check for overlaps. + /// * `end` - The end of the range to check for overlaps. + /// + /// # Returns + /// + /// The number of overlapping intervals. + pub fn count_linear(&self, start: i32, end: i32) -> usize { + if self.starts.is_empty() { + return 0; + } + let idx = self.upper_bound(end); + if idx == usize::MAX { + return 0; + } + let mut i = idx; + let mut count: usize = 0; + unsafe { + while i != usize::MAX && start <= *self.ends.get_unchecked(i) { + i = i.wrapping_sub(1); + } + if i == usize::MAX { + return idx + 1; + } + count += idx - i; + i = *self.branch.get_unchecked(i); + while i != usize::MAX { + if start <= *self.ends.get_unchecked(i) { + count += 1; + i = i.wrapping_sub(1); + } else { + i = *self.branch.get_unchecked(i); + } + } + } + count + } + + /// Counts the number of intervals that overlap with the given range. + /// + /// # Arguments + /// + /// * `start` - The start of the range to check for overlaps. + /// * `end` - The end of the range to check for overlaps. + /// + /// # Returns + /// + /// The number of overlapping intervals. + /// Counts the number of intervals that overlap with the given range. + pub fn count(&self, start: i32, end: i32) -> usize { + if self.starts.is_empty() { + return 0; + } + let idx = self.upper_bound(end); + if idx == usize::MAX { + return 0; + } + let mut i = idx; + let mut found = 0; + + unsafe { + #[cfg(target_arch = "x86_64")] + { + use std::arch::x86_64::*; + let start_vec = _mm256_set1_epi32(start); + const SIMD_WIDTH: usize = 8; // 256 bits / 32 bits = 8 elements + const BLOCK: usize = SIMD_WIDTH * 4; // 32 elements per block, 2 cache lines + + while i > 0 { + if start <= *self.ends.get_unchecked(i) { + found += 1; + i -= 1; + + // Process blocks of data with SIMD + // while i > BLOCK { + // let mut count = 0; + // let mut j = i; + // // Process SIMD_WIDTH elements at a time + // while j > i - BLOCK { + // let end_idx = if j >= SIMD_WIDTH { j - SIMD_WIDTH + 1 } else { 0 }; + // let ends_vec = _mm256_loadu_si256(self.ends.as_ptr().add(end_idx) as *const __m256i); + // let cmp_mask = _mm256_cmpgt_epi32(start_vec, ends_vec); + // let mask = _mm256_movemask_ps(_mm256_castsi256_ps(cmp_mask)); + // count += 8 - (mask).count_ones() as usize; + // j = j.saturating_sub(SIMD_WIDTH); + // } + // + // found += count; + // i -= BLOCK; + // + // // Early exit if we didn't find full block worth of matches + // if count < BLOCK && start > *self.ends.get_unchecked(i + 1) { + // break; + // } + // } + while i > BLOCK { + let j = i - BLOCK + 1; + + // Load all 4 vectors + let ends_vec0 = + _mm256_loadu_si256(self.ends.as_ptr().add(j) as *const __m256i); + let ends_vec1 = _mm256_loadu_si256( + self.ends.as_ptr().add(j + SIMD_WIDTH) as *const __m256i, + ); + let ends_vec2 = _mm256_loadu_si256( + self.ends.as_ptr().add(j + 2 * SIMD_WIDTH) as *const __m256i, + ); + let ends_vec3 = _mm256_loadu_si256( + self.ends.as_ptr().add(j + 3 * SIMD_WIDTH) as *const __m256i, + ); + + // Compare all vectors + let cmp_mask0 = _mm256_cmpgt_epi32(start_vec, ends_vec0); + let cmp_mask1 = _mm256_cmpgt_epi32(start_vec, ends_vec1); + let cmp_mask2 = _mm256_cmpgt_epi32(start_vec, ends_vec2); + let cmp_mask3 = _mm256_cmpgt_epi32(start_vec, ends_vec3); + + // Extract masks + let mask0 = _mm256_movemask_ps(_mm256_castsi256_ps(cmp_mask0)); + let mask1 = _mm256_movemask_ps(_mm256_castsi256_ps(cmp_mask1)); + let mask2 = _mm256_movemask_ps(_mm256_castsi256_ps(cmp_mask2)); + let mask3 = _mm256_movemask_ps(_mm256_castsi256_ps(cmp_mask3)); + + // Count and accumulate + let count = (8 - (mask0).count_ones() as usize) + + (8 - (mask1).count_ones() as usize) + + (8 - (mask2).count_ones() as usize) + + (8 - (mask3).count_ones() as usize); + + found += count; + i -= BLOCK; + + // Early exit if we didn't find full block worth of matches + if count < BLOCK && start > *self.ends.get_unchecked(i + 1) { + break; + } + } + } else { + if *self.branch.get_unchecked(i) == usize::MAX { + return found; + } + i = *self.branch.get_unchecked(i); + } + } + } + + #[cfg(target_arch = "aarch64")] + { + use std::arch::aarch64::*; + let start_vec = vdupq_n_s32(start); + const SIMD_WIDTH: usize = 4; // 128 bits / 32 bits = 4 elements + const BLOCK: usize = SIMD_WIDTH * 8; // 2 cache lines + let ones = vdupq_n_u32(1); + while i > 0 { + if start <= *self.ends.get_unchecked(i) { + found += 1; + i -= 1; + // while i > BLOCK { + // let mut count = 0; + // let mut j = i; + // while j > i - BLOCK { + // let end_idx = if j >= SIMD_WIDTH { j - SIMD_WIDTH + 1 } else { 0 }; + // let ends_vec = vld1q_s32(self.ends.as_ptr().add(end_idx) as *const i32); + // let mask = vcgtq_s32(start_vec, ends_vec); + // let bool_mask = vaddq_u32(mask, ones); + // count += vaddvq_u32(bool_mask) as usize; + // j = j.saturating_sub(SIMD_WIDTH); + // } + // found += count; + // i -= BLOCK; + // if count < BLOCK { + // break; + // } + // } + + while i > BLOCK { + let j = i - BLOCK + 1; + + // Load all 8 vectors + let ends_vec0 = vld1q_s32(self.ends.as_ptr().add(j) as *const i32); + let ends_vec1 = + vld1q_s32(self.ends.as_ptr().add(j + SIMD_WIDTH) as *const i32); + let ends_vec2 = + vld1q_s32(self.ends.as_ptr().add(j + 2 * SIMD_WIDTH) as *const i32); + let ends_vec3 = + vld1q_s32(self.ends.as_ptr().add(j + 3 * SIMD_WIDTH) as *const i32); + let ends_vec4 = + vld1q_s32(self.ends.as_ptr().add(j + 4 * SIMD_WIDTH) as *const i32); + let ends_vec5 = + vld1q_s32(self.ends.as_ptr().add(j + 5 * SIMD_WIDTH) as *const i32); + let ends_vec6 = + vld1q_s32(self.ends.as_ptr().add(j + 6 * SIMD_WIDTH) as *const i32); + let ends_vec7 = + vld1q_s32(self.ends.as_ptr().add(j + 7 * SIMD_WIDTH) as *const i32); + + // Compare all vectors + let mask0 = vcgtq_s32(start_vec, ends_vec0); + let mask1 = vcgtq_s32(start_vec, ends_vec1); + let mask2 = vcgtq_s32(start_vec, ends_vec2); + let mask3 = vcgtq_s32(start_vec, ends_vec3); + let mask4 = vcgtq_s32(start_vec, ends_vec4); + let mask5 = vcgtq_s32(start_vec, ends_vec5); + let mask6 = vcgtq_s32(start_vec, ends_vec6); + let mask7 = vcgtq_s32(start_vec, ends_vec7); + + // Convert to boolean masks + let bool_mask0 = vaddq_u32(mask0, ones); + let bool_mask1 = vaddq_u32(mask1, ones); + let bool_mask2 = vaddq_u32(mask2, ones); + let bool_mask3 = vaddq_u32(mask3, ones); + let bool_mask4 = vaddq_u32(mask4, ones); + let bool_mask5 = vaddq_u32(mask5, ones); + let bool_mask6 = vaddq_u32(mask6, ones); + let bool_mask7 = vaddq_u32(mask7, ones); + + // Sum all lanes and accumulate + let count = vaddvq_u32(bool_mask0) as usize + + vaddvq_u32(bool_mask1) as usize + + vaddvq_u32(bool_mask2) as usize + + vaddvq_u32(bool_mask3) as usize + + vaddvq_u32(bool_mask4) as usize + + vaddvq_u32(bool_mask5) as usize + + vaddvq_u32(bool_mask6) as usize + + vaddvq_u32(bool_mask7) as usize; + + found += count; + i -= BLOCK; + if count < BLOCK { + break; + } + } + } else { + if *self.branch.get_unchecked(i) == usize::MAX { + return found; + } + i = *self.branch.get_unchecked(i); + } + } + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + { + const BLOCK: usize = 16; + + while i > 0 { + if start <= *self.ends.get_unchecked(i) { + found += 1; + i -= 1; + while i > BLOCK { + let mut count = 0; + for j in (i - BLOCK + 1..=i).rev() { + if start <= *self.ends.get_unchecked(j) { + count += 1; + } + } + found += count; + i -= BLOCK; + if count < BLOCK && start > *self.ends.get_unchecked(i + 1) { + break; + } + } + } else { + if *self.branch.get_unchecked(i) == usize::MAX { + return found; + } + i = *self.branch.get_unchecked(i); + } + } + } + // Final check for element at index 0 + if i == 0 + && start <= *self.ends.get_unchecked(0) + && *self.starts.get_unchecked(0) <= end + { + found += 1; + } + } + + found + } + + /// Counts all intervals that overlap with the given range. Works best when + /// query intervals are large compared to stored database intervals. Uses + /// an exponential search to find overlaps + /// + /// # Arguments + /// + /// * `start` - The start of the range to check for overlaps. + /// * `end` - The end of the range to check for overlaps. + /// + /// # Returns + /// + /// The number of overlapping intervals. + pub fn count_large(&self, start: i32, end: i32) -> usize { + if self.starts.is_empty() { + return 0; + } + let idx = self.upper_bound(end); + if idx == usize::MAX { + return 0; + } + let mut i = self.upper_bound_range(start, idx); + let mut count: usize = 0; + unsafe { + while i != usize::MAX && start <= *self.ends.get_unchecked(i) { + i = i.wrapping_sub(1); + } + if i == usize::MAX { + return idx + 1; + } + count += idx - i; + i = *self.branch.get_unchecked(i); + while i != usize::MAX { + if start <= *self.ends.get_unchecked(i) { + count += 1; + i = i.wrapping_sub(1); + } else { + i = *self.branch.get_unchecked(i); + } + } + } + count + } + + pub fn search_idxs(&self, start: i32, end: i32, found: &mut Vec) { + if self.starts.is_empty() { + return; + } + let idx = self.upper_bound(end); + if idx == usize::MAX { + return; + } + let mut i = idx; + unsafe { + while i != usize::MAX && start <= *self.ends.get_unchecked(i) { + i = i.wrapping_sub(1); + } + if i == usize::MAX { + found.extend((0..=idx).rev()); + return; + } + found.extend(((i + 1)..=idx).rev()); + i = *self.branch.get_unchecked(i); + while i != usize::MAX { + if start <= *self.ends.get_unchecked(i) { + found.push(i); + i = i.wrapping_sub(1); + } else { + i = *self.branch.get_unchecked(i); + } + } + } + } + + pub fn search_keys(&self, start: i32, end: i32, found: &mut Vec<(i32, i32)>) { + if self.starts.is_empty() { + return; + } + let idx = self.upper_bound(end); + if idx == usize::MAX { + return; + } + let mut i = idx; + unsafe { + while i != usize::MAX && start <= *self.ends.get_unchecked(i) { + found.push((*self.starts.get_unchecked(i), *self.ends.get_unchecked(i))); + i = i.wrapping_sub(1); + } + if i != usize::MAX { + i = *self.branch.get_unchecked(i); + while i != usize::MAX { + if start <= *self.ends.get_unchecked(i) { + found.push((*self.starts.get_unchecked(i), *self.ends.get_unchecked(i))); + i = i.wrapping_sub(1); + } else { + i = *self.branch.get_unchecked(i); + } + } + } + } + } + + pub fn search_items(&self, start: i32, end: i32, found: &mut Vec>) { + if self.starts.is_empty() { + return; + } + let idx = self.upper_bound(end); + if idx == usize::MAX { + return; + } + let mut i = idx; + unsafe { + while i != usize::MAX && start <= *self.ends.get_unchecked(i) { + found.push(Interval { + start: *self.starts.get_unchecked(i), + end: *self.ends.get_unchecked(i), + data: self.data.get_unchecked(i).clone(), + }); + i = i.wrapping_sub(1); + } + if i != usize::MAX { + i = *self.branch.get_unchecked(i); + while i != usize::MAX { + if start <= *self.ends.get_unchecked(i) { + found.push(Interval { + start: *self.starts.get_unchecked(i), + end: *self.ends.get_unchecked(i), + data: self.data.get_unchecked(i).clone(), + }); + i = i.wrapping_sub(1); + } else { + i = *self.branch.get_unchecked(i); + } + } + } + } + } + + pub fn search_stabbed(&self, point: i32, found: &mut Vec) { + if self.starts.is_empty() { + return; + } + let idx = self.upper_bound(point); + if idx == usize::MAX { + return; + } + let mut i = idx; + unsafe { + while i != usize::MAX && point <= *self.ends.get_unchecked(i) { + found.push(self.data.get_unchecked(i).clone()); + i = i.wrapping_sub(1); + } + if i != usize::MAX { + i = *self.branch.get_unchecked(i); + while i != usize::MAX { + if point <= *self.ends.get_unchecked(i) { + found.push(self.data.get_unchecked(i).clone()); + i = i.wrapping_sub(1); + } else { + i = *self.branch.get_unchecked(i); + } + } + } + } + } + + /// Counts the total coverage over the query interval. + /// + /// # Arguments + /// + /// * `start` - The start of the range to check for overlaps. + /// * `end` - The end of the range to check for overlaps. + /// + /// # Returns + /// + /// The number of overlapping intervals, plus the coverage. + pub fn coverage(&self, start: i32, end: i32) -> (usize, i32) { + if self.starts.is_empty() { + return (0, 0); + } + let mut i = self.upper_bound(end); + let mut count: usize = 0; + let mut coverage: i32 = 0; + unsafe { + while i != usize::MAX { + if start <= *self.ends.get_unchecked(i) { + count += 1; + coverage += min(*self.ends.get_unchecked(i), end) + - max(*self.starts.get_unchecked(i), start); + i -= 1; + } else { + i = *self.branch.get_unchecked(i); + } + } + } + (count, coverage) + } + + fn sort_block(&mut self, start_i: usize, end_i: usize, compare: F) + where + F: Fn(&Interval, &Interval) -> Ordering, + { + unsafe { + let range_size = end_i - start_i; + let mut tmp: Vec> = Vec::with_capacity(range_size); + for i in 0..range_size { + tmp.push(Interval { + start: *self.starts.get_unchecked(start_i + i), + end: *self.ends.get_unchecked(start_i + i), + data: (*self.data.get_unchecked(start_i + i)).clone(), + }); + } + tmp.sort_by(compare); + for i in 0..range_size { + self.starts[start_i + i] = tmp.get_unchecked(i).start; + self.ends[start_i + i] = tmp.get_unchecked(i).end; + self.data[start_i + i] = tmp.get_unchecked(i).data.clone(); + } + } + } +} + +// Iterator interfaces + +pub struct IndexIterator<'a, T> { + tree: &'a IntervalMap, + current_idx: usize, + query_start: i32, +} + +impl<'a, T: Clone> Iterator for IndexIterator<'a, T> { + type Item = usize; + + fn next(&mut self) -> Option { + if self.current_idx == usize::MAX { + return None; + } + unsafe { + // Linear scan backwards + if self.query_start <= *self.tree.ends.get_unchecked(self.current_idx) { + let value = self.current_idx; + self.current_idx = self.current_idx.wrapping_sub(1); + return Some(value); + } + // Traverse branch array + loop { + self.current_idx = *self.tree.branch.get_unchecked(self.current_idx); + if self.current_idx == usize::MAX { + break; + } + if self.query_start <= *self.tree.ends.get_unchecked(self.current_idx) { + let value = self.current_idx; + self.current_idx = self.current_idx.wrapping_sub(1); + return Some(value); + } + } + } + None + } +} + +pub struct ItemIterator<'a, T> { + tree: &'a IntervalMap, + current_idx: usize, + query_start: i32, +} + +impl<'a, T: Clone> Iterator for ItemIterator<'a, T> { + type Item = Interval; + + fn next(&mut self) -> Option { + if self.current_idx == usize::MAX { + return None; + } + unsafe { + if self.query_start <= *self.tree.ends.get_unchecked(self.current_idx) { + let value = Interval { + start: *self.tree.starts.get_unchecked(self.current_idx), + end: *self.tree.ends.get_unchecked(self.current_idx), + data: self.tree.data.get_unchecked(self.current_idx).clone(), + }; + self.current_idx = self.current_idx.wrapping_sub(1); + return Some(value); + } + loop { + self.current_idx = *self.tree.branch.get_unchecked(self.current_idx); + if self.current_idx == usize::MAX { + break; + } + if self.query_start <= *self.tree.ends.get_unchecked(self.current_idx) { + let value = Interval { + start: *self.tree.starts.get_unchecked(self.current_idx), + end: *self.tree.ends.get_unchecked(self.current_idx), + data: self.tree.data.get_unchecked(self.current_idx).clone(), + }; + self.current_idx = self.current_idx.wrapping_sub(1); + return Some(value); + } + } + } + None + } +} + +pub struct KeyIterator<'a, T> { + tree: &'a IntervalMap, + current_idx: usize, + query_start: i32, +} + +impl<'a, T: Clone> Iterator for KeyIterator<'a, T> { + type Item = (i32, i32); + + fn next(&mut self) -> Option { + if self.current_idx == usize::MAX { + return None; + } + unsafe { + if self.query_start <= *self.tree.ends.get_unchecked(self.current_idx) { + let value = ( + *self.tree.starts.get_unchecked(self.current_idx), + *self.tree.ends.get_unchecked(self.current_idx), + ); + self.current_idx = self.current_idx.wrapping_sub(1); + return Some(value); + } + loop { + self.current_idx = *self.tree.branch.get_unchecked(self.current_idx); + if self.current_idx == usize::MAX { + break; + } + if self.query_start <= *self.tree.ends.get_unchecked(self.current_idx) { + let value = ( + *self.tree.starts.get_unchecked(self.current_idx), + *self.tree.ends.get_unchecked(self.current_idx), + ); + self.current_idx = self.current_idx.wrapping_sub(1); + return Some(value); + } + } + } + None + } +} + +pub struct ValueIterator<'a, T> { + tree: &'a IntervalMap, + current_idx: usize, + query_start: i32, +} + +impl<'a, T: Clone> Iterator for ValueIterator<'a, T> { + type Item = T; + + fn next(&mut self) -> Option { + if self.current_idx == usize::MAX { + return None; + } + unsafe { + if self.query_start <= *self.tree.ends.get_unchecked(self.current_idx) { + let value = self.tree.data.get_unchecked(self.current_idx).clone(); + self.current_idx = self.current_idx.wrapping_sub(1); + return Some(value); + } + loop { + self.current_idx = *self.tree.branch.get_unchecked(self.current_idx); + if self.current_idx == usize::MAX { + break; + } + if self.query_start <= *self.tree.ends.get_unchecked(self.current_idx) { + let value = self.tree.data.get_unchecked(self.current_idx).clone(); + self.current_idx = self.current_idx.wrapping_sub(1); + return Some(value); + } + } + } + None + } +} + +// Updated iterator creation methods for IntervalMap +impl IntervalMap { + /// Returns an iterator over indices of intervals that intersect [start, end] + pub fn search_idxs_iter(&self, start: i32, end: i32) -> IndexIterator { + let current_idx = if self.starts.is_empty() { + usize::MAX + } else { + self.upper_bound(end) + }; + IndexIterator { + tree: self, + current_idx, + query_start: start, + } + } + + /// Returns an iterator over items (intervals with data) that intersect [start, end] + pub fn search_items_iter(&self, start: i32, end: i32) -> ItemIterator { + let current_idx = if self.starts.is_empty() { + usize::MAX + } else { + self.upper_bound(end) + }; + ItemIterator { + tree: self, + current_idx, + query_start: start, + } + } + + /// Returns an iterator over keys (interval start/end pairs) that intersect [start, end] + pub fn search_keys_iter(&self, start: i32, end: i32) -> KeyIterator { + let current_idx = if self.starts.is_empty() { + usize::MAX + } else { + self.upper_bound(end) + }; + KeyIterator { + tree: self, + current_idx, + query_start: start, + } + } + + /// Returns an iterator over values (data) of intervals that intersect [start, end] + pub fn search_values_iter(&self, start: i32, end: i32) -> ValueIterator { + let current_idx = if self.starts.is_empty() { + usize::MAX + } else { + self.upper_bound(end) + }; + ValueIterator { + tree: self, + current_idx, + query_start: start, + } + } +} diff --git a/sequila/sequila-core/superintervals/src/superintervals/README.md b/sequila/sequila-core/superintervals/src/superintervals/README.md new file mode 100644 index 0000000..f9c1c70 --- /dev/null +++ b/sequila/sequila-core/superintervals/src/superintervals/README.md @@ -0,0 +1,119 @@ +# Superintervals Python API + +Install using `pip install superintervals` + +```python +from superintervals import IntervalMap +from array import array + +# Method 1: Manual construction +imap = IntervalMap() +imap.add(10, 20, 'A') +imap.add(15, 25, 'B') +imap.build() +results = imap.search_values(8, 20) # ['A', 'B'] + +# Method 2: Efficient construction from arrays (no build() needed) +starts = array('i', [10, 15, 30]) +ends = array('i', [20, 25, 40]) +values = ['A', 'B', 'C'] +imap = IntervalMap.from_arrays(starts, ends, values) +results = imap.search_values(8, 20) # ['A', 'B'] + +# Batch operations for high performance +query_starts = array('i', [5, 18, 35]) +query_ends = array('i', [12, 22, 45]) +counts = imap.count_batch(query_starts, query_ends) # [1, 2, 1] +indices = imap.search_idxs_batch(query_starts, query_ends) # [[0], [0, 1], [2]] +``` + +### API Reference + +**IntervalMap** class: + +#### Construction Methods +- `IntervalMap()` + Create empty interval map + +- `IntervalMap.from_arrays(starts, ends, values=None)` + Create from arrays (array.array, numpy arrays, or lists) + Returns ready-to-use IntervalMap (no build() needed) + +#### Adding Intervals +- `add(start, end, value=None)` + Add interval with associated value + +- `build()` + Build index (required before queries when using add()) + +#### Memory Management +- `clear()` + Remove all intervals + +- `reserve(n)` + Reserve space for n intervals + +- `size()` + Get number of intervals + +#### Access Methods +- `at(index)` + Get interval at index as (start, end, value) + +- `starts_at(index)` + Get start position at index + +- `ends_at(index)` + Get end position at index + +- `data_at(index)` + Get value at index + +#### Single Query Methods +- `has_overlaps(start, end)` + Check if any intervals overlap range + +- `count(start, end)` + Count overlapping intervals + +- `search_values(start, end)` + Get values of overlapping intervals + +- `search_idxs(start, end)` + Get indices of overlapping intervals + +- `search_keys(start, end)` + Get (start, end) pairs of overlapping intervals + +- `search_items(start, end)` + Get (start, end, value) tuples of overlapping intervals + +- `coverage(start, end)` + Get (count, total_coverage) for range + +#### Batch Query Methods (High Performance) +- `count_batch(starts, ends)` + Count overlaps for multiple ranges + Args: Memory views (array.array, numpy arrays) + Returns: List of counts + +- `search_idxs_batch(starts, ends)` + Get indices for multiple ranges + Args: Memory views (array.array, numpy arrays) + Returns: List of lists containing indices + +- `search_values_batch(starts, ends)` + Get values for multiple ranges + Args: Memory views (array.array, numpy arrays) + Returns: List of lists containing values + +### Performance Tips + +- Use `IntervalMap.from_arrays()` for best construction performance +- Use batch methods for multiple queries (often 5-10x faster) +- Convert lists to arrays for batch operations: + ```python + from array import array + starts = array('i', [1, 5, 10]) # For batch methods + ``` +- Reserve space with `reserve(n)` when adding many intervals manually \ No newline at end of file diff --git a/sequila/sequila-core/superintervals/src/superintervals/__init__.py b/sequila/sequila-core/superintervals/src/superintervals/__init__.py new file mode 100644 index 0000000..d29a7f8 --- /dev/null +++ b/sequila/sequila-core/superintervals/src/superintervals/__init__.py @@ -0,0 +1,2 @@ +from .intervalmap import ( + IntervalMap) \ No newline at end of file diff --git a/sequila/sequila-core/superintervals/src/superintervals/intervalmap.pxd b/sequila/sequila-core/superintervals/src/superintervals/intervalmap.pxd new file mode 100644 index 0000000..31ea0d1 --- /dev/null +++ b/sequila/sequila-core/superintervals/src/superintervals/intervalmap.pxd @@ -0,0 +1,106 @@ +# distutils: language = c++ +from libcpp.vector cimport vector +from libcpp.pair cimport pair + +cdef extern from "Python.h": + void Py_INCREF(object) + void Py_DECREF(object) + + +cdef extern from "superintervals.hpp" namespace "si": + + cdef cppclass Interval[S, T]: + S start, end + T data + Interval() + Interval(S s, S e, T d) + + + cdef cppclass CppIntervalMap "si::IntervalMap"[S, T]: + IntervalMap() except + + + vector[S] starts, ends + vector[size_t] branch + vector[T] data + size_t idx + + void clear() + void reserve(size_t n) + size_t size() + void add(S start, S end, const T& value) + void build() + void at(size_t index, Interval[S, T]& itv) + const Interval[S, T]& at(size_t index) const + + # Search methods + void upper_bound(const S value) const + bint has_overlaps(const S start, const S end) + size_t count_linear(const S start, const S end) + size_t count(const S start, const S end) + size_t count_large(const S start, const S end) + + void search_values(const S start, const S end, vector[T]& found) + void search_values_large(const S start, const S end, vector[T]& found) + void search_idxs(const S start, const S end, vector[size_t]& found) + void search_keys(const S start, const S end, vector[pair[S, S]]& found) + void search_items(const S start, const S end, vector[Interval[S, T]]& found) + void search_point(const S point, vector[T]& found) + void coverage(const S start, const S end, pair[size_t, S]& cov_result) + + # Iterator classes not yet implemented! + # cppclass IndexIterator: + # IndexIterator(const IntervalMap * parent, size_t pos) + # size_t operator *() const + # IndexIterator& operator++() + # bint operator !=(const IndexIterator& other) const + # bint operator ==(const IndexIterator& other) const + # + # cppclass ItemIterator: + # ItemIterator(const IntervalMap * parent, size_t pos) + # Interval[S, T] operator *() const + # ItemIterator& operator++() + # bint operator !=(const ItemIterator& other) const + # bint operator ==(const ItemIterator& other) const + # + # cppclass IndexRange: + # IndexRange(const IntervalMap * parent, S start, S end) + # IndexIterator begin() const + # IndexIterator end() const + # + # cppclass ItemRange: + # ItemRange(const IntervalMap * parent, S start, S end) + # ItemIterator begin() const + # ItemIterator end() const + # + # IndexRange search_idxs(S start, S end) const + # ItemRange search_items(S start, S end) const + + + +# Type alias for Python object pointer +ctypedef void * PyObjectPtr + +cdef class IntervalMap: + cdef CppIntervalMap[int, PyObjectPtr] * thisptr + cdef vector[PyObjectPtr] found_values + cdef vector[size_t] found_indexes + + cpdef add(self, int start, int end, object value= *) + cpdef build(self) + cpdef at(self, int index) + cpdef starts_at(self, int index) + cpdef ends_at(self, int index) + cpdef data_at(self, int index) + cpdef clear(self) + cpdef reserve(self, size_t n) + cpdef size(self) + cpdef has_overlaps(self, int start, int end) + cpdef count(self, int start, int end) + cpdef search_values(self, int start, int end) + cpdef search_idxs(self, int start, int end) + cpdef search_keys(self, int start, int end) + cpdef search_items(self, int start, int end) + cpdef coverage(self, int start, int end) + cpdef count_batch(self, int[:] starts, int[:] ends) + cpdef search_idxs_batch(self, int[:] starts, int[:] ends) + cpdef search_values_batch(self, int[:] starts, int[:] ends) diff --git a/sequila/sequila-core/superintervals/src/superintervals/intervalmap.pyx b/sequila/sequila-core/superintervals/src/superintervals/intervalmap.pyx new file mode 100644 index 0000000..f80a2ca --- /dev/null +++ b/sequila/sequila-core/superintervals/src/superintervals/intervalmap.pyx @@ -0,0 +1,479 @@ + +from libcpp.pair cimport pair + +__all__ = ["IntervalMap"] + + +cdef class IntervalMap: + """ + SuperIntervals interval map to manage a collection of intervals with associated Python objects, + supporting operations such as adding intervals, checking overlaps, and querying stored data. + """ + + def __cinit__(self): + """ + Initialize the IntervalMap. + """ + self.thisptr = new CppIntervalMap[int, PyObjectPtr]() + + def __dealloc__(self): + cdef PyObjectPtr obj_ptr + if self.thisptr: + for obj_ptr in self.thisptr.data: + if obj_ptr != NULL: + Py_DECREF( obj_ptr) + del self.thisptr + + def __len__(self): + return self.size() + + def __getitem__(self, int index): + return self.at(index) + + cpdef add(self, int start, int end, object value=None): + """ + Add an interval with an associated Python object. + + Args: + start (int): The start of the interval (inclusive). + end (int): The end of the interval (inclusive). + value (object): The Python object to associate with this interval. + + Updates: + - Adds the interval to the underlying data structure. + - Stores a reference to the Python object directly in C++. + """ + cdef PyObjectPtr obj_ptr = NULL + if value is not None: + obj_ptr = value + Py_INCREF(value) # Increment reference count + self.thisptr.add(start, end, obj_ptr) + + @classmethod + def from_arrays(cls, starts, ends, values=None): + """ + Create an IntervalMap from arrays of starts, ends, and optional values. + + This is the most efficient way to create an IntervalMap when you have + existing data in array format. The resulting IntervalMap is ready to use + (no need to call build()). + + Args: + starts: Array-like of start positions (array.array, numpy array, list, etc.) + ends: Array-like of end positions (array.array, numpy array, list, etc.) + values: Optional iterable of values to associate with each interval. + If None, no values are stored. + + Returns: + IntervalMap: A new IntervalMap ready for queries + + Examples: + >>> from array import array + >>> import numpy as np + >>> + >>> # Using array.array + >>> starts = array('i', [1, 10, 20]) + >>> ends = array('i', [5, 15, 25]) + >>> values = ["gene1", "gene2", "gene3"] + >>> im = IntervalMap.from_arrays(starts, ends, values) + >>> + >>> # Using numpy arrays + >>> starts = np.array([1, 10, 20], dtype=np.int32) + >>> ends = np.array([5, 15, 25], dtype=np.int32) + >>> im = IntervalMap.from_arrays(starts, ends) + >>> + >>> # Using lists (will be converted internally) + >>> im = IntervalMap.from_arrays([1, 10, 20], [5, 15, 25], ["A", "B", "C"]) + """ + cdef IntervalMap instance = cls() + cdef int[:] start_view + cdef int[:] end_view + if hasattr(starts, 'shape'): # numpy array or already array-like + start_view = starts + end_view = ends + else: # lists or other iterables + from array import array + start_array = array('i', starts) + end_array = array('i', ends) + start_view = start_array + end_view = end_array + + if start_view.shape[0] != end_view.shape[0]: + raise ValueError("starts and ends must have the same length") + + cdef size_t n = start_view.shape[0] + cdef size_t i + cdef object value + + instance.reserve(n) + + if values is None: + for i in range(n): + instance.add(start_view[i], end_view[i]) + else: + if len(values) != n: + raise ValueError("values length must match starts/ends length") + for i, value in enumerate(values): + instance.add(start_view[i], end_view[i], value) + + instance.build() + return instance + + cpdef build(self): + """ + Builds the superintervals index, must be called before queries are made. + """ + self.thisptr.build() + + cpdef at(self, int index): + """ + Fetches the interval and data at the given index. Negative indexing is not supported. + + Args: + index (int): The index of a stored interval. + + Raises: + IndexError: If the index is out of range. + + Returns: + tuple: (start, end, data) + """ + if self.size() == 0 or index < 0 or index >= self.size(): + raise IndexError('Index out of range') + if self.thisptr.data[index] != NULL: + return self.thisptr.starts[index], self.thisptr.ends[index], self.thisptr.data[index] + else: + return self.thisptr.starts[index], self.thisptr.ends[index], None + + cpdef starts_at(self, int index): + """ + Fetches the start position at the given index. Negative indexing is not supported. + + Args: + index (int): The index of a stored interval. + + Raises: + IndexError: If the index is out of range. + + Returns: + tuple: start + """ + if self.size() == 0 or index < 0 or index >= self.size(): + raise IndexError('Index out of range') + return self.thisptr.starts[index] + + cpdef ends_at(self, int index): + """ + Fetches the end position at the given index. Negative indexing is not supported. + + Args: + index (int): The index of a stored interval. + + Raises: + IndexError: If the index is out of range. + + Returns: + tuple: start + """ + if self.size() == 0 or index < 0 or index >= self.size(): + raise IndexError('Index out of range') + return self.thisptr.ends[index] + + cpdef data_at(self, int index): + """ + Fetches the stored data at the given index. Negative indexing is not supported. + + Args: + index (int): The index of a stored interval. + + Raises: + IndexError: If the index is out of range. + + Returns: + tuple: start + """ + if self.size() == 0 or index < 0 or index >= self.size(): + raise IndexError('Index out of range') + if self.thisptr.data[index] == NULL: + return None + else: + return self.thisptr.starts[index] + + cpdef clear(self): + """ + Clear all intervals and associated data. + """ + cdef PyObjectPtr obj_ptr + for obj_ptr in self.thisptr.data: + if obj_ptr != NULL: + Py_DECREF( obj_ptr) + self.thisptr.clear() + + cpdef reserve(self, size_t n): + """ + Reserve space for a specified number of intervals. + + Args: + n (size_t): The number of intervals to reserve space for. + """ + self.thisptr.reserve(n) + + cpdef size(self): + """ + Get the number of intervals in the map. + + Returns: + int: The number of intervals. + """ + return self.thisptr.size() + + cpdef has_overlaps(self, int start, int end): + """ + Check if any intervals overlap with a given range. + + Args: + start (int): The start of the range (inclusive). + end (int): The end of the range (inclusive). + + Returns: + bool: True if any intervals overlap with the given range, False otherwise. + """ + return self.thisptr.has_overlaps(start, end) + + cpdef count(self, int start, int end): + """ + Count the number of intervals that overlap with a given range. + + Args: + start (int): The start of the range (inclusive). + end (int): The end of the range (inclusive). + + Returns: + int: The count of overlapping intervals. + """ + return self.thisptr.count(start, end) + + cpdef search_values(self, int start, int end): + """ + Find all Python objects associated with intervals that overlap the given range. + + Args: + start (int): The start of the range (inclusive). + end (int): The end of the range (inclusive). + + Returns: + list: A list of Python objects associated with overlapping intervals. + """ + self.found_values.clear() + self.thisptr.search_values(start, end, self.found_values) + cdef list result = [None] * self.found_values.size() + cdef size_t i + for i in range(self.found_values.size()): + if self.found_values[i] != NULL: + result[i] = self.found_values[i] + return result + + cpdef search_idxs(self, int start, int end): + """ + Find indices of all intervals that overlap with a given range. + + Args: + start (int): The start of the range (inclusive). + end (int): The end of the range (inclusive). + + Returns: + list: A list of indices of overlapping intervals. + """ + self.found_indexes.clear() + self.thisptr.search_idxs(start, end, self.found_indexes) + return list(self.found_indexes) + + cpdef search_keys(self, int start, int end): + """ + Find interval keys (start, end pairs) that overlap with a given range. + + Args: + start (int): The start of the range (inclusive). + end (int): The end of the range (inclusive). + + Returns: + list: A list of (start, end) tuples for overlapping intervals. + """ + self.found_indexes.clear() + self.thisptr.search_idxs(start, end, self.found_indexes) + cdef list result = [None] * self.found_indexes.size() + cdef size_t i + for i in range(self.found_indexes.size()): + result[i] = (self.thisptr.starts[i], self.thisptr.ends[i]) + return result + + cpdef search_items(self, int start, int end): + """ + Find complete interval items (start, end, data) that overlap with a given range. + + Args: + start (int): The start of the range (inclusive). + end (int): The end of the range (inclusive). + + Returns: + list: A list of (start, end, data) tuples for overlapping intervals. + """ + self.found_indexes.clear() + self.thisptr.search_idxs(start, end, self.found_indexes) + cdef list result = [None] * self.found_indexes.size() + cdef size_t i + for i in range(self.found_indexes.size()): + if self.thisptr.data[i] != NULL: + result[i] = (self.thisptr.starts[i], self.thisptr.ends[i], self.thisptr.data[i]) + else: + result[i] = (self.thisptr.starts[i], self.thisptr.ends[i], None) + return result + + cpdef coverage(self, int start, int end): + """ + Compute coverage statistics for the given range. + + Args: + start (int): The start of the range (inclusive). + end (int): The end of the range (inclusive). + + Returns: + tuple: (count, total_coverage) where count is number of overlapping intervals + and total_coverage is the sum of overlapping lengths. + """ + cdef pair[size_t, int] cov_result = pair[size_t, int](0, 0) + self.thisptr.coverage(start, end, cov_result) + return cov_result.first, cov_result.second + + cpdef count_batch(self, int[:] starts, int[:] ends): + """ + Count overlapping intervals for multiple query ranges. + + Args: + starts: Memory view of start positions (array.array, numpy array, etc.) + ends: Memory view of end positions (array.array, numpy array, etc.) + + Returns: + List: Count of overlapping intervals for each query range + + Example: + >>> from array import array + >>> import numpy as np + >>> im = IntervalMap() + >>> im.add(1, 10, "A") + >>> im.add(5, 15, "B") + >>> im.build() + >>> + >>> # Works with array.array + >>> starts = array('i', [1, 8, 15]) + >>> ends = array('i', [5, 12, 20]) + >>> counts = im.count_batch(starts, ends) + >>> + >>> # Also works with numpy arrays + >>> starts_np = np.array([1, 8, 15], dtype=np.int32) + >>> ends_np = np.array([5, 12, 20], dtype=np.int32) + >>> counts = im.count_batch(starts_np, ends_np) + """ + if starts.shape[0] != ends.shape[0]: + raise ValueError("starts and ends must have the same length") + cdef size_t n = starts.shape[0] + cdef list result_array = [0] * n + cdef size_t i + for i in range(n): + result_array[i] = self.thisptr.count(starts[i], ends[i]) + + return result_array + + cpdef search_idxs_batch(self, int[:] starts, int[:] ends): + """ + Find indices of overlapping intervals for multiple query ranges. + + Args: + starts: Memory view of start positions + ends: Memory view of end positions + + Returns: + list: List of lists, where each sublist contains indices of + overlapping intervals for the corresponding query range + + Example: + >>> from array import array + >>> import numpy as np + >>> im = IntervalMap() + >>> im.add(1, 10, "A") + >>> im.add(5, 15, "B") + >>> im.build() + >>> + >>> # Works with array.array + >>> starts = array('i', [1, 8]) + >>> ends = array('i', [5, 12]) + >>> results = im.search_idxs_batch(starts, ends) + >>> + >>> # Works with numpy arrays + >>> starts_np = np.array([1, 8], dtype=np.int32) + >>> ends_np = np.array([5, 12], dtype=np.int32) + >>> results = im.search_idxs_batch(starts_np, ends_np) + """ + if starts.shape[0] != ends.shape[0]: + raise ValueError("starts and ends must have the same length") + + cdef size_t n = starts.shape[0] + cdef list results = [[] for _ in range(n)] + cdef size_t i, j + for i in range(n): + self.found_indexes.clear() + self.thisptr.search_idxs(starts[i], ends[i], self.found_indexes) + query_result = [0] * self.found_indexes.size() + for j in range(self.found_indexes.size()): + query_result[j] = self.found_indexes[j] + results[i] = query_result + + return results + + cpdef search_values_batch(self, int[:] starts, int[:] ends): + """ + Find values of overlapping intervals for multiple query ranges. + + Args: + starts: Memory view of start positions + ends: Memory view of end positions + + Returns: + list: List of lists, where each inner list contains values of + overlapping intervals for the corresponding query range + + Example: + >>> from array import array + >>> import numpy as np + >>> im = IntervalMap() + >>> im.add(1, 10, "A") + >>> im.add(5, 15, "B") + >>> im.build() + >>> + >>> # Works with array.array + >>> starts = array('i', [1, 8]) + >>> ends = array('i', [5, 12]) + >>> results = im.search_values_batch(starts, ends) + >>> + >>> # Works with numpy arrays + >>> starts_np = np.array([1, 8], dtype=np.int32) + >>> ends_np = np.array([5, 12], dtype=np.int32) + >>> results = im.search_values_batch(starts_np, ends_np) + """ + if starts.shape[0] != ends.shape[0]: + raise ValueError("starts and ends must have the same length") + + cdef size_t n = starts.shape[0] + cdef list results = [[] for _ in range(n)] + cdef size_t i, j + cdef list query_result + for i in range(n): + self.found_values.clear() + self.thisptr.search_values(starts[i], ends[i], self.found_values) + query_result = [None] * self.found_values.size() + for j in range(self.found_values.size()): + if self.found_values[j] != NULL: + query_result[j] = self.found_values[j] + results[i] = query_result + + return results diff --git a/sequila/sequila-core/superintervals/src/variants/superintervals_var.hpp b/sequila/sequila-core/superintervals/src/variants/superintervals_var.hpp new file mode 100644 index 0000000..8ad97ba --- /dev/null +++ b/sequila/sequila-core/superintervals/src/variants/superintervals_var.hpp @@ -0,0 +1,826 @@ + +#pragma once + +#include +#include +#include +#include +#include +#include +#ifndef SI_NOSIMD + #if defined(__AVX2__) + #include + #elif defined(__ARM_NEON__) || defined(__aarch64__) + #include + #else + #define SI_NOSIMD + #endif +#endif + +/** + * @file SuperIntervals.hpp + * @brief A static data structure for finding interval intersections + * + * SuperIntervals is a template class that provides efficient interval intersection operations. + * It supports adding intervals, indexing them for fast queries, and performing various + * intersection operations. + * + * @note Intervals are considered end-inclusive + * @note The index() function must be called before any queries. If more intervals are added, call index() again. + * + * @tparam S The scalar type for interval start and end points (e.g., int, float) + * @tparam T The data type associated with each interval + */ +template +class SuperIntervals { + public: + + struct Interval { + S start, end; + T data; + Interval() = default; + Interval(S s, S e, T d) : start(s), end(e), data(d) {} + }; + + alignas(alignof(std::vector)) std::vector starts; + alignas(alignof(std::vector)) std::vector ends; + alignas(alignof(size_t)) std::vector branch; + std::vector data; + size_t idx; + bool startSorted, endSorted; + + SuperIntervals() + : idx(0) + , startSorted(true) + , endSorted(true) + , it_low(0) + , it_high(0) + {} + + virtual ~SuperIntervals() = default; + + /** + * @brief Clears all intervals and resets the data structure + */ + void clear() noexcept { + data.clear(); starts.clear(); ends.clear(); branch.clear(); idx = 0; + } + + /** + * @brief Reserves memory for a specified number of intervals + * @param n Number of intervals to reserve space for + */ + void reserve(size_t n) { + data.reserve(n); starts.reserve(n); ends.reserve(n); + } + + /** + * @brief Returns the number of intervals in the data structure + * @return Number of intervals + */ + size_t size() { + return starts.size(); + } + + /** + * @brief Adds a new interval to the data structure + * @param start Start point of the interval + * @param end End point of the interval + * @param value Data associated with the interval + */ + void add(S start, S end, const T& value) { + if (startSorted && !starts.empty()) { + startSorted = (start < starts.back()) ? false : true; + if (startSorted && start == starts.back() && end > ends.back()) { + endSorted = false; + } + } + starts.push_back(start); + ends.push_back(end); + data.emplace_back(value); + } + + /** + * @brief Indexes the intervals. + * + * This function must be called after adding intervals and before performing any queries. + * If more intervals are added after indexing, this function should be called again. + */ + virtual void index() { + if (starts.size() == 0) { + return; + } + sortIntervals(); + branch.resize(starts.size(), SIZE_MAX); + std::vector> br; + br.reserve((starts.size() / 10) + 1); + br.emplace_back() = {ends[0], 0}; + for (size_t i=1; i < ends.size(); ++i) { + while (!br.empty() && br.back().first < ends[i]) { + br.pop_back(); + } + if (!br.empty()) { + branch[i] = br.back().second; + } + br.emplace_back() = {ends[i], i}; + } + idx = 0; + } + + /** + * @brief Retrieves an interval at a specific index + * @param index The index of the interval to retrieve + * @return The Interval at the specified index + */ + const Interval& at(size_t index) const { + return Interval{starts[index], ends[index], data[index]}; + } + + void at(size_t index, Interval& itv) { + itv.start = starts[index]; + itv.end = ends[index]; + itv.data = data[index]; + } + + class Iterator { + public: + size_t it_index; + Iterator(const SuperIntervals* list, size_t index) : super(list) { + _start = list->it_low; + _end = list->it_high; + it_index = index; + } + typename SuperIntervals::Interval operator*() const { + return typename SuperIntervals::Interval{super->starts[it_index], super->ends[it_index], super->data[it_index]}; + } + Iterator& operator++() { + if (it_index == 0) { + it_index = SIZE_MAX; + return *this; + } + if (it_index > 0) { + if (_start <= super->ends[it_index]) { + --it_index; + } else { + if (super->branch[it_index] >= it_index) { + it_index = SIZE_MAX; + return *this; + } + it_index = super->branch[it_index]; + if (_start <= super->ends[it_index]) { + --it_index; + } else { + it_index = SIZE_MAX; + return *this; + } + } + } + return *this; + } + bool operator!=(const Iterator& other) const { + return it_index != other.it_index; + } + bool operator==(const Iterator& other) const { + return it_index == other.it_index; + } + Iterator begin() const { return Iterator(super, super->idx); } + Iterator end() const { return Iterator(super, SIZE_MAX); } + private: + S _start, _end; + const SuperIntervals* super; + + }; + + Iterator begin() const { return Iterator(this, idx); } + Iterator end() const { return Iterator(this, SIZE_MAX); } + + /** + * @brief Sets the search interval. Must be called before using the iterator. + * @param start Start point of the search range + * @param end End point of the search range + */ + void searchInterval(const S start, const S end) noexcept { + if (starts.empty()) { + return; + } + it_low = start; it_high = end; + upperBound(end); + if (start > ends[idx] || starts[0] > end) { + idx = SIZE_MAX; + } + } + + virtual inline void upperBound(const S value) noexcept { // https://github.com/mh-dm/sb_lower_bound/blob/master/sbpm_lower_bound.h + // less branchy + size_t length = starts.size(); + idx = 0; + while (length > 1) { + size_t half = length / 2; + idx += (starts[idx + half] <= value) * (length - half); + length = half; + } + if (starts[idx] > value) { + --idx; // Might set idx to SIZE_MAX + } + +// idx = std::distance(starts.begin(), +// std::upper_bound(starts.begin(), starts.end(), value)) - 1; + } + + void findOverlaps(const S start, const S end, std::vector& found) { + if (starts.empty()) { + return; + } + upperBound(end); + + size_t i = idx; + + if (idx == SIZE_MAX) { + return; + } + while (i != SIZE_MAX && start <= ends[i]) { + --i; + } + if (i == SIZE_MAX) { + found.insert(found.end(), data.rend() - idx - 1, data.rend()); + return; + } + found.insert(found.end(), data.rend() - idx - 1, data.rend() - i - 1); + + i = branch[i]; + while (i != SIZE_MAX) { + if (start <= ends[i]) { + found.push_back(data[i]); + --i; + } else { + i = branch[i]; + } + } + + + // v0 imporved +// while (i != SIZE_MAX) { +// if (start <= ends[i]) { +// found.push_back(data[i]); +// --i; +// } else { +// if (branch[i] == SIZE_MAX) { +// return; +// } +// i = branch[i]; +// } +// } + + + // v0 +// while (i > 0) { +// if (start <= ends[i]) { +// found.push_back(data[i]); +// --i; +// } else { +// if (branch[i] == SIZE_MAX) { +// return; +// } +// i = branch[i]; +// } +// } +// if (i == 0 && start <= ends[0] && starts[0] <= end) { +// found.push_back(data[0]); +// } + + } + + +// inline void upperBoundRange(const S value, size_t& left, const size_t right) noexcept { +// size_t length = right - left; +// while (length > 1) { +// size_t half = length / 2; +// left += (starts[left + half] < value) * (length - half); +// length = half; +// } +// if (starts[left] >= value) { +// --left; // underflows to SIZE_MAX +// } +// } +// +// void findOverlaps2(const S start, const S end, std::vector& found) { +// if (starts.empty()) { +// return; +// } +// upperBound(end); +// if (idx == SIZE_MAX) { +// return; +// } +// +// size_t starting_idx = idx; +// size_t right = idx; +// size_t bound = 1; +// +// // Exponential search to find range where start < starts[idx] stops being true +// while (idx > 0 && start < starts[idx]) { +// right = idx; +// idx = (bound <= idx) ? idx - bound : 0; +// bound *= 2; +// } +// +// // Binary search in range [idx, right] to find exact boundary +// upperBoundRange(start, idx, right); +// if (idx == SIZE_MAX) { +// found.insert(found.end(), data.rend() - starting_idx - 1, data.rend()); +// return; +// } +// // Bulk insert from starting_idx down to idx +// found.insert(found.end(), data.rend() - starting_idx - 1, data.rend() - idx - 1); +// +// // Continue with branching logic +// while (idx != SIZE_MAX) { +// if (start <= ends[idx]) { +// found.push_back(data[idx]); +// --idx; +// } else { +// idx = branch[idx]; +// } +// } +// } + + inline void upperBoundRange(const S value, size_t& left, const size_t right) noexcept { + // First do exponential search if we have room + size_t search_right = right; + size_t bound = 1; + // Exponential search to find a smaller range + while (left > 0 && value < starts[left]) { + search_right = left; + left = (bound <= left) ? left - bound : 0; + bound *= 2; + } + // Now do binary search in the reduced range [left, search_right] + size_t length = search_right - left; + while (length > 1) { + size_t half = length / 2; + left += (starts[left + half] < value) * (length - half); + length = half; + } + if (left == 0 && starts[left] >= value) { + left = SIZE_MAX; + } + } + + void findOverlaps2(const S start, const S end, std::vector& found) { + if (starts.empty()) { + return; + } + upperBound(end); + if (idx == SIZE_MAX) { + return; + } + size_t starting_idx = idx; + upperBoundRange(start, idx, idx); + if (idx == SIZE_MAX) { + found.insert(found.end(), data.rend() - starting_idx - 1, data.rend()); + return; + } + found.insert(found.end(), data.rend() - starting_idx - 1, data.rend() - idx - 1); + while (idx != SIZE_MAX) { + if (start <= ends[idx]) { + found.push_back(data[idx]); + --idx; + } else { + idx = branch[idx]; + } + } + } + + size_t countOverlaps2(const S start, const S end) noexcept { + if (starts.empty()) { + return 0; + } + upperBound(end); + if (idx == SIZE_MAX) { + return 0; + } + size_t starting_i = idx; + upperBoundRange(start, idx, idx); + size_t count = starting_i - idx; + + // Linear scan or branch jump + while (idx != SIZE_MAX) { + if (start <= ends[idx]) { + count += 1; + --idx; + } else { + idx = branch[idx]; + } + } + return count; + } + + static constexpr auto findValues = &SuperIntervals::findOverlaps; + + bool anyOverlaps(const S start, const S end) noexcept { + if (starts.empty()) { + return false; + } + upperBound(end); + size_t i = idx; + while (i != SIZE_MAX) { + if (start <= ends[i]) { + return true; + } else { + if (branch[i] == SIZE_MAX) { + return false; + } + i = branch[i]; + } + } + return false; + } + + void findIndexes(const S start, const S end, std::vector& found) { + if (starts.empty()) { + return; + } + upperBound(end); + size_t i = idx; + while (i != SIZE_MAX) { + if (start <= ends[i]) { + found.push_back(i); + --i; + } else { + if (branch[i] == SIZE_MAX) { + return; + } + i = branch[i]; + } + } + } + + void findKeys(const S start, const S end, std::vector>& found) { + if (starts.empty()) { + return; + } + upperBound(end); + size_t i = idx; + while (i != SIZE_MAX) { + if (start <= ends[i]) { + found.push_back({starts[i], ends[i]}); + --i; + } else { + if (branch[i] == SIZE_MAX) { + return; + } + i = branch[i]; + } + } + } + + void findItems(const S start, const S end, std::vector& found) { + if (starts.empty()) { + return; + } + upperBound(end); + size_t i = idx; + while (i != SIZE_MAX) { + if (start <= ends[i]) { + found.emplace_back() = {starts[i], ends[i], data[i]}; + --i; + } else { + if (branch[i] == SIZE_MAX) { + return; + } + i = branch[i]; + } + } + } + + size_t countOverlaps(const S start, const S end) noexcept { + if (starts.empty()) { + return 0; + } + upperBound(end); + size_t i = idx; + if (i == SIZE_MAX) { + return 0; + } + size_t found = 0; + +#ifdef SI_NOSIMD + constexpr size_t block = 16; +#elif defined(__AVX2__) + __m256i start_vec = _mm256_set1_epi32(start); + constexpr size_t simd_width = 256 / (sizeof(S) * 8); + constexpr size_t block = simd_width * 4; +#elif defined(__ARM_NEON__) || defined(__aarch64__) + int32x4_t start_vec = vdupq_n_s32(start); + constexpr size_t simd_width = 128 / (sizeof(S) * 8); + uint32x4_t ones = vdupq_n_u32(1); + constexpr size_t block = simd_width * 4; +#endif + + while (i > 0) { + if (start <= ends[i]) { + ++found; + --i; +#ifdef SI_NOSIMD + while (i > block) { // Rely on compiler auto vectorize + size_t count = 0; + for (size_t j = i; j > i - block; --j) { + count += (start <= ends[j]) ? 1 : 0; + } + found += count; + i -= block; + if (count < block && start > ends[i + 1]) { // check for a branch + break; + } + } + +#elif defined(__AVX2__) + while (i > block) { + size_t count = 0; + for (size_t j = i; j > i - block; j -= simd_width) { + __m256i ends_vec = _mm256_load_si256((__m256i*)(&ends[j - simd_width + 1])); + __m256i cmp_mask = _mm256_cmpgt_epi32(start_vec, ends_vec); + int mask = _mm256_movemask_epi8(~cmp_mask); + count += _mm_popcnt_u32(mask); + } + found += count / 4; // Each comparison result is 4 bits + i -= block; + if (count < block) { + break; + } + } +#elif defined(__ARM_NEON__) || defined(__aarch64__) + while (i > block) { + size_t count = 0; + uint32x4_t mask, bool_mask; + for (size_t j = i; j > i - block; j -= simd_width) { // Neon processes 4 int32 at a time + int32x4_t ends_vec = vld1q_s32(&ends[j - simd_width + 1]); + mask = vcleq_s32(start_vec, ends_vec); // True (0xFFFFFFFF) for elements where start_vec <= ends_vec + bool_mask = vandq_u32(mask, ones); + count += vaddvq_u32(bool_mask); + } + found += count; + i -= block; +// if (count < block && vgetq_lane_u32(mask, 0) == 0) { // check for overlap again, before checking for branch? + if (count < block) { // check for overlap again, before checking for branch? + break; + } + } +#endif + } else { + if (branch[i] == SIZE_MAX) { + return found; + } + i = branch[i]; + } + } + if (i==0 && start <= ends[0] && starts[0] <= end) { + ++found; + } + return found; + } + + void coverage(const S start, const S end, std::pair &cov_result) { + if (starts.empty()) { + cov_result.first = 0; + cov_result.second = 0; + return; + } + upperBound(end); + size_t i = idx; + size_t cnt = 0; + S cov = 0; + while (i != SIZE_MAX) { + if (start <= ends[i]) { + ++cnt; + cov += std::min(ends[i], end) - std::max(starts[i], start); + --i; + } else { + if (branch[i] == SIZE_MAX) { + cov_result.first = cnt; + cov_result.second = cov; + return; + } + i = branch[i]; + } + } + cov_result.first = cnt; + cov_result.second = cov; + } + + void findStabbed(const S point, std::vector& found) { + if (starts.empty()) { + return; + } + upperBound(point); + size_t i = idx; + while (i != SIZE_MAX) { + if (point <= ends[i]) { + found.push_back(data[i]); + --i; + } else { + i = branch[i]; + } + } + } + + size_t countStabbed(const S point) noexcept { + if (starts.empty()) { + return 0; + } + upperBound(point); + size_t found = 0; + size_t i = idx; + +#ifdef SI_NOSIMD + constexpr size_t block = 16; +#elif defined(__AVX2__) + __m256i start_vec = _mm256_set1_epi32(point); + constexpr size_t simd_width = 256 / (sizeof(S) * 8); + constexpr size_t block = simd_width * 4; +#elif defined(__ARM_NEON__) || defined(__aarch64__) + int32x4_t start_vec = vdupq_n_s32(point); + constexpr size_t simd_width = 128 / (sizeof(S) * 8); + uint32x4_t ones = vdupq_n_u32(1); + constexpr size_t block = simd_width * 4; +#endif + + while (i > 0) { + if (point <= ends[i]) { + ++found; + --i; + +#ifdef SI_NOSIMD + while (i > block) { + size_t count = 0; + for (size_t j = i; j > i - block; --j) { + count += (point <= ends[j]) ? 1 : 0; + } + found += count; + i -= block; + if (count < block) { // go check for a branch + break; + } + } +#elif defined(__AVX2__) + while (i > block) { + size_t count = 0; + for (size_t j = i; j > i - block; j -= simd_width) { + __m256i ends_vec = _mm256_load_si256((__m256i*)(&ends[j - simd_width + 1])); + __m256i cmp_mask = _mm256_cmpgt_epi32(start_vec, ends_vec); + int mask = _mm256_movemask_epi8(~cmp_mask); + count += _mm_popcnt_u32(mask); + } + found += count / 4; // Each comparison result is 4 bits + i -= block; + if (count < block) { + break; + } + } +#elif defined(__ARM_NEON__) || defined(__aarch64__) + while (i > block) { + size_t count = 0; + uint32x4_t bool_mask; + for (size_t j = i; j > i - block; j -= simd_width) { // Neon processes 4 int32 at a time + int32x4_t ends_vec = vld1q_s32(&ends[j - simd_width + 1]); + uint32x4_t mask = vcleq_s32(start_vec, ends_vec); + bool_mask = vandq_u32(mask, ones); // Convert -1 to 1 for true elements + count += vaddvq_u32(bool_mask); + } + found += count; + i -= block; + if (count < block) { // go check for a branch + break; + } + } +#endif + } else { + if (branch[i] >= i) { + break; + } + i = branch[i]; + } + } + if (i==0 && point <= ends[0] && starts[0] <= point) { + ++found; + } + return found; + } + + protected: + + S it_low, it_high; + std::vector tmp; + + template + void sortBlock(size_t start_i, size_t end_i, CompareFunc compare) { + size_t range_size = end_i - start_i; + tmp.resize(range_size); + for (size_t i = 0; i < range_size; ++i) { + tmp[i] = Interval(starts[start_i + i], ends[start_i + i], data[start_i + i]); + } + std::sort(tmp.begin(), tmp.end(), compare); + for (size_t i = 0; i < range_size; ++i) { + starts[start_i + i] = tmp[i].start; + ends[start_i + i] = tmp[i].end; + data[start_i + i] = tmp[i].data; + } + } + + void sortIntervals() { + if (!startSorted) { + sortBlock(0, starts.size(), + [](const Interval& a, const Interval& b) { return (a.start < b.start || (a.start == b.start && a.end > b.end)); }); + startSorted = true; + endSorted = true; + } else if (!endSorted) { // only sort parts that need sorting - ends in descending order + size_t it_start = 0; + while (it_start < starts.size()) { + size_t block_end = it_start + 1; + bool needs_sort = false; + while (block_end < starts.size() && starts[block_end] == starts[it_start]) { + if (block_end > it_start && ends[block_end] > ends[block_end - 1]) { + needs_sort = true; + } + ++block_end; + } + if (needs_sort) { + sortBlock(it_start, block_end, [](const Interval& a, const Interval& b) { return a.end > b.end; }); + } + it_start = block_end; + } + endSorted = true; + } + } +}; + + +template +class SuperIntervalsEytz : public SuperIntervals { +public: + + void index() override { + if (this->starts.size() == 0) { + return; + } + this->starts.shrink_to_fit(); + this->ends.shrink_to_fit(); + this->data.shrink_to_fit(); + this->sortIntervals(); + + eytz.resize(this->starts.size() + 1); + eytz_index.resize(this->starts.size() + 1); + eytzinger(&this->starts[0], this->starts.size()); + + this->branch.resize(this->starts.size(), SIZE_MAX); + std::vector> br; + br.reserve(1000); + br.emplace_back() = {this->ends[0], 0}; + for (size_t i=1; i < this->ends.size(); ++i) { + while (!br.empty() && br.back().first < this->ends[i]) { + br.pop_back(); + } + if (!br.empty()) { + this->branch[i] = br.back().second; + } + br.emplace_back() = {this->ends[i], i}; + } + this->idx = 0; + } + + inline void upperBound(const S x) noexcept override { + size_t i = 0; + const size_t n_intervals = this->starts.size(); + while (i < n_intervals) { + if (eytz[i] > x) { + i = 2 * i + 1; + } else { + i = 2 * i + 2; + } + } + int shift = __builtin_ffs(~(i + 1)); + size_t best_idx = (i >> shift) - ((shift > 1) ? 1 : 0); + this->idx = (best_idx < n_intervals) ? eytz_index[best_idx] : n_intervals - 1; + if (this->idx > 0 && this->starts[this->idx] > x) { + --this->idx; + } + } + +private: + std::vector eytz; + std::vector eytz_index; + + size_t eytzinger_helper(S* arr, size_t n, size_t i, size_t k) { + if (k < n) { + i = eytzinger_helper(arr, n, i, 2*k+1); + eytz[k] = this->starts[i]; + eytz_index[k] = i; + ++i; + i = eytzinger_helper(arr, n, i, 2*k + 2); + } + return i; + } + + int eytzinger(S* arr, size_t n) { + return eytzinger_helper(arr, n, 0, 0); + } +}; diff --git a/sequila/sequila-core/superintervals/src/variants/superintervals_var.rs b/sequila/sequila-core/superintervals/src/variants/superintervals_var.rs new file mode 100644 index 0000000..26b224a --- /dev/null +++ b/sequila/sequila-core/superintervals/src/variants/superintervals_var.rs @@ -0,0 +1,688 @@ +//! This module provides efficient data structures for managing and querying intervals. +//! It includes implementations for standard and Eytzinger layout-based interval storage. + +use serde::{Deserialize, Serialize}; +use std::cmp::Ordering; +use std::cmp::{max, min}; + +/// Represents an interval with associated data. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Interval { + start: i32, + end: i32, + data: T, +} + +/// A static data structure for finding interval intersections +/// +/// SuperIntervals is a template class that provides efficient interval intersection operations. +/// It supports adding intervals, indexing them for fast queries, and performing various +/// intersection operations. +/// +/// Intervals are considered end-inclusive +/// The index() function must be called before any queries. If more intervals are added, call index() again. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SuperIntervals { + starts: Vec, + ends: Vec, + branch: Vec, + data: Vec, + idx: usize, + start_sorted: bool, + end_sorted: bool, +} + +impl SuperIntervals { + pub fn new() -> Self { + SuperIntervals { + starts: Vec::new(), + ends: Vec::new(), + branch: Vec::new(), + data: Vec::new(), + idx: 0, + start_sorted: true, + end_sorted: true, + } + } + /// Clears all intervals from the structure. + pub fn clear(&mut self) { + self.starts.clear(); + self.ends.clear(); + self.branch.clear(); + self.data.clear(); + self.idx = 0; + } + /// Adds a new interval to the structure. + /// + /// # Arguments + /// + /// * `start` - The start point of the interval. + /// * `end` - The end point of the interval. + /// * `value` - The data associated with the interval. + pub fn add(&mut self, start: i32, end: i32, value: T) { + if self.start_sorted && !self.starts.is_empty() { + self.start_sorted = start >= *self.starts.last().unwrap(); + if self.start_sorted + && start == *self.starts.last().unwrap() + && end > *self.ends.last().unwrap() + { + self.end_sorted = false; + } + } + self.starts.push(start); + self.ends.push(end); + self.data.push(value); + } + /// Sorts the intervals by start and -end. + pub fn sort_intervals(&mut self) { + if !self.start_sorted { + self.sort_block(0, self.starts.len(), |a, b| { + if a.start < b.start { + Ordering::Less + } else if a.start == b.start { + if a.end > b.end { + Ordering::Less + } else { + Ordering::Greater + } + } else { + Ordering::Greater + } + }); + self.start_sorted = true; + self.end_sorted = true; + } else if !self.end_sorted { + let mut it_start = 0; + while it_start < self.starts.len() { + let mut block_end = it_start + 1; + let mut needs_sort = false; + while block_end < self.starts.len() + && self.starts[block_end] == self.starts[it_start] + { + if block_end > it_start && self.ends[block_end] > self.ends[block_end - 1] { + needs_sort = true; + } + block_end += 1; + } + if needs_sort { + self.sort_block(it_start, block_end, |a, b| a.end.cmp(&b.end).reverse()); + } + it_start = block_end; + } + self.end_sorted = true; + } + } + /// Indexes the intervals. Must be called before queries are performed. + pub fn index(&mut self) { + if self.starts.is_empty() { + return; + } + self.sort_intervals(); + self.branch.resize(self.starts.len(), usize::MAX); + let mut br: Vec<(i32, usize)> = Vec::with_capacity((self.starts.len() / 10) + 1); + unsafe { + br.push((*self.ends.get_unchecked(0), 0)); + for i in 1..self.ends.len() { + while !br.is_empty() && br.last().unwrap().0 < *self.ends.get_unchecked(i) { + br.pop(); + } + if !br.is_empty() { + *self.branch.get_unchecked_mut(i) = br.last().unwrap().1; + } + br.push((*self.ends.get_unchecked(i), i)); + } + } + self.idx = 0; + } + + // #[inline(always)] + // pub fn upper_bound(&mut self, value: i32) { + // let mut len = self.starts.len(); + // unsafe { + // self.idx = 0; + // while len > 0 { + // let half = len / 2; + // let mid = self.idx + half; + // let cond = (*self.starts.get_unchecked(mid) <= value) as usize; + // // self.idx += cond * (len - half); + // self.idx += cond * (half + 1); + // len = half; + // } + // self.idx = self.idx.wrapping_sub(1); // Might underflow to usize::MAX + // + // // if *self.starts.get_unchecked(self.idx) > value { + // // self.idx = self.idx.wrapping_sub(1); + // // } + // } + // } + + #[inline(always)] + pub fn upper_bound(&mut self, value: i32) { + let mut length = self.starts.len(); + unsafe { + self.idx = 0; //usize::MAX; + while length > 1 { + let half = length / 2; + self.idx += (*self.starts.get_unchecked(self.idx + half) <= value) as usize + * (length - half); + length = half; + } + // self.idx = self.idx.wrapping_sub((*self.starts.get_unchecked(self.idx) > value) as usize); + if *self.starts.get_unchecked(self.idx) > value { + self.idx = self.idx.wrapping_sub(1); + } + } + } + + // pub fn find_overlaps(&mut self, start: i32, end: i32, found: &mut Vec) { + // if self.starts.is_empty() { + // return; + // } + // self.upper_bound(end); + // let mut i = self.idx; + // + // unsafe { + // while i != usize::MAX { + // if start <= *self.ends.get_unchecked(i) { + // found.push(self.data.get_unchecked(i).clone()); + // i = i.wrapping_sub(1); + // } else { + // i = *self.branch.get_unchecked(i); + // } + // } + // } + // } + + /// Finds all intervals that overlap with the given range. + /// + /// # Arguments + /// + /// * `start` - The start of the range to check for overlaps. + /// * `end` - The end of the range to check for overlaps. + /// * `found` - A mutable vector to store the overlapping intervals' data. + pub fn find_overlaps(&mut self, start: i32, end: i32, found: &mut Vec) { + if self.starts.is_empty() { + return; + } + self.upper_bound(end); + if self.idx == usize::MAX { + return; + } + let mut i = self.idx; + + unsafe { + while i != usize::MAX && start <= *self.ends.get_unchecked(i) { + i = i.wrapping_sub(1); + } + + if i == usize::MAX { + found.reserve(self.idx + 1); + found.extend( + (0..=self.idx) + .rev() + .map(|j| self.data.get_unchecked(j).clone()), + ); + return; + } + + let count = self.idx - i; + found.reserve(count); + found.extend( + ((i + 1)..=self.idx) + .rev() + .map(|j| self.data.get_unchecked(j).clone()), + ); + + i = *self.branch.get_unchecked(i); + while i != usize::MAX { + if start <= *self.ends.get_unchecked(i) { + found.push(self.data.get_unchecked(i).clone()); + i = i.wrapping_sub(1); + } else { + i = *self.branch.get_unchecked(i); + } + } + } + } + + #[inline(always)] + pub fn upper_bound_range(&mut self, value: i32, right: usize) -> usize { + let mut left = right; + let mut search_right = right; + let mut bound = 1; + // Exponential search to find a smaller range + unsafe { + while left > 0 && value < *self.starts.get_unchecked(left) { + search_right = left; + left = if bound <= left { left - bound } else { 0 }; + bound *= 2; + } + // Binary search in the reduced range [search_left, search_right] + let mut length = search_right - left; + while length > 1 { + let half = length / 2; + let condition = *self.starts.get_unchecked(left + half) < value; + left += if condition { length - half } else { 0 }; + length = half; + } + if left == 0 && *self.starts.get_unchecked(left) >= value { + left = usize::MAX; + } + } + left + } + + /// Finds all intervals that overlap with the given range (version 2). + /// + /// # Arguments + /// + /// * `start` - The start of the range to check for overlaps. + /// * `end` - The end of the range to check for overlaps. + /// * `found` - A mutable vector to store the overlapping intervals' data. + pub fn find_overlaps2(&mut self, start: i32, end: i32, found: &mut Vec) { + if self.starts.is_empty() { + return; + } + self.upper_bound(end); + if self.idx == usize::MAX { + return; + } + let mut i = self.upper_bound_range(start, self.idx); + + unsafe { + while i != usize::MAX && start <= *self.ends.get_unchecked(i) { + i = i.wrapping_sub(1); + } + if i == usize::MAX { + found.reserve(self.idx + 1); + found.extend( + (0..=self.idx) + .rev() + .map(|j| self.data.get_unchecked(j).clone()), + ); + return; + } + found.reserve(self.idx - i); + found.extend( + ((i + 1)..=self.idx) + .rev() + .map(|j| self.data.get_unchecked(j).clone()), + ); + + i = *self.branch.get_unchecked(i); + + while i != usize::MAX { + if start <= *self.ends.get_unchecked(i) { + found.push(self.data.get_unchecked(i).clone()); + i = i.wrapping_sub(1); + } else { + i = *self.branch.get_unchecked(i); + } + } + } + } + + pub fn count_overlaps(&mut self, start: i32, end: i32) -> usize { + if self.starts.is_empty() { + return 0; + } + self.upper_bound(end); + if self.idx == usize::MAX { + return 0; + } + let mut i = self.idx; + let mut count: usize = 0; + unsafe { + while i != usize::MAX && start <= *self.ends.get_unchecked(i) { + i = i.wrapping_sub(1); + } + if i == usize::MAX { + return self.idx + 1; + } + count += self.idx - i; + i = *self.branch.get_unchecked(i); + while i != usize::MAX { + if start <= *self.ends.get_unchecked(i) { + count += 1; + i = i.wrapping_sub(1); + } else { + i = *self.branch.get_unchecked(i); + } + } + } + count + } + + /// Counts all intervals that overlap with the given range (version 2). + /// + /// # Arguments + /// + /// * `start` - The start of the range to check for overlaps. + /// * `end` - The end of the range to check for overlaps. + /// + /// # Returns + /// + /// The number of overlapping intervals. + pub fn count_overlaps2(&mut self, start: i32, end: i32) -> usize { + if self.starts.is_empty() { + return 0; + } + self.upper_bound(end); + if self.idx == usize::MAX { + return 0; + } + let mut i = self.upper_bound_range(start, self.idx); + let mut count: usize = 0; + unsafe { + while i != usize::MAX && start <= *self.ends.get_unchecked(i) { + i = i.wrapping_sub(1); + } + // i = self.rewind_past_overlaps(i, start); + if i == usize::MAX { + return self.idx + 1; + } + count += self.idx - i; + i = *self.branch.get_unchecked(i); + while i != usize::MAX { + if start <= *self.ends.get_unchecked(i) { + count += 1; + i = i.wrapping_sub(1); + } else { + i = *self.branch.get_unchecked(i); + } + } + count + } + } + + /// Counts the number of intervals that overlap with the given range. + /// + /// # Arguments + /// + /// * `start` - The start of the range to check for overlaps. + /// * `end` - The end of the range to check for overlaps. + /// + /// # Returns + /// + /// The number of overlapping intervals. + // pub fn count_overlaps(&mut self, start: i32, end: i32) -> usize { + // if self.starts.is_empty() { + // return 0; + // } + // self.upper_bound(end); + // if self.idx == usize::MAX { + // return 0; + // } + // let mut found: usize = 0; + // let mut i = self.idx; + // + // unsafe { + // #[cfg(target_arch = "x86_64")] + // { + // use std::arch::x86_64::*; + // let start_vec = _mm256_set1_epi32(start); + // let ones: __m256i = _mm256_set1_epi32(1); + // const SIMD_WIDTH: usize = 8; + // const BLOCK: usize = 32; + // while i > 0 { + // if start <= *self.ends.get_unchecked(i) { + // found += 1; + // i -= 1; + // while i > BLOCK { + // let mut count = 0; + // for j in (i - BLOCK + 1..=i).rev().step_by(SIMD_WIDTH) { + // let ends_vec = _mm256_loadu_si256(self.ends.as_ptr().add(j - SIMD_WIDTH + 1) as *const __m256i); + // // Add one to convert the greater than to greater or equal than + // let adj_ends_vec = _mm256_add_epi32(ends_vec, ones); + // let cmp_mask = _mm256_cmpgt_epi32(adj_ends_vec, start_vec); + // let mask = _mm256_movemask_epi8(cmp_mask); + // count += mask.count_ones() as usize; + // } + // found += count / 4; // Each comparison result is 4 bits + // i -= BLOCK; + // if count < BLOCK { + // break; + // } + // } + // } else { + // if *self.branch.get_unchecked(i) == usize::MAX { + // return found; + // } + // i = *self.branch.get_unchecked(i); + // } + // } + // } + // + // #[cfg(target_arch = "aarch64")] + // { + // use std::arch::aarch64::*; + // let start_vec = vdupq_n_s32(start); + // const SIMD_WIDTH: usize = 4; //128 / (core::mem::size_of::() * 8); + // const BLOCK: usize = 32; // SIMD_WIDTH * 4; + // let ones = vdupq_n_u32(1); + // while i > 0 { + // if start <= *self.ends.get_unchecked(i) { + // found += 1; + // i -= 1; + // while i > BLOCK { + // let mut count = 0; + // for j in (i - BLOCK + 1..=i).rev().step_by(SIMD_WIDTH) { + // let ends_vec = vld1q_s32(self.ends.as_ptr().add(j - SIMD_WIDTH + 1) as *const i32); + // let mask = vcleq_s32(start_vec, ends_vec); + // let bool_mask = vandq_u32(mask, ones); + // count += vaddvq_u32(bool_mask) as usize; + // } + // found += count; + // i -= BLOCK; + // if count < BLOCK { + // break; + // } + // } + // } else { + // if *self.branch.get_unchecked(i) == usize::MAX { + // return found; + // } + // i = *self.branch.get_unchecked(i); + // } + // } + // } + // + // #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + // { + // const BLOCK: usize = 16; + // while i > 0 { + // if start <= *self.ends.get_unchecked(i) { + // found += 1; + // i -= 1; + // while i > BLOCK { + // let mut count = 0; + // for j in (i - BLOCK + 1..=i).rev() { + // if start <= *self.ends.get_unchecked(j) { + // count += 1; + // } + // } + // found += count; + // i -= BLOCK; + // if count < BLOCK { + // break; + // } + // } + // } else { + // if *self.branch.get_unchecked(i) == usize::MAX { + // return found; + // } + // i = *self.branch.get_unchecked(i); + // } + // } + // } + // if i == 0 && start <= *self.ends.get_unchecked(0) && *self.starts.get_unchecked(0) <= end { + // found += 1; + // } + // } + // found + // } + /// Counts the total coverage over the query interval. + /// + /// # Arguments + /// + /// * `start` - The start of the range to check for overlaps. + /// * `end` - The end of the range to check for overlaps. + /// + /// # Returns + /// + /// The number of overlapping intervals, plus the coverage. + pub fn coverage(&mut self, start: i32, end: i32) -> (usize, i32) { + if self.starts.is_empty() { + return (0, 0); + } + self.upper_bound(end); + let mut i = self.idx; + let mut cnt = 0; + let mut cov = 0; + unsafe { + while i > 0 { + if start <= *self.ends.get_unchecked(i) { + cnt += 1; + cov += min(*self.ends.get_unchecked(i), end) + - max(*self.starts.get_unchecked(i), start); + i -= 1; + } else { + if *self.branch.get_unchecked(i) == usize::MAX { + return (cnt, cov); + } + i = *self.branch.get_unchecked(i); + } + } + if i == 0 + && start <= *self.ends.get_unchecked(0) + && *self.starts.get_unchecked(0) <= end + { + cov += min(*self.ends.get_unchecked(i), end) + - max(*self.starts.get_unchecked(i), start); + cnt += 1; + } + } + (cnt, cov) + } + + fn sort_block(&mut self, start_i: usize, end_i: usize, compare: F) + where + F: Fn(&Interval, &Interval) -> Ordering, + { + unsafe { + let range_size = end_i - start_i; + let mut tmp: Vec> = Vec::with_capacity(range_size); + for i in 0..range_size { + tmp.push(Interval { + start: *self.starts.get_unchecked(start_i + i), + end: *self.ends.get_unchecked(start_i + i), + data: (*self.data.get_unchecked(start_i + i)).clone(), + }); + } + tmp.sort_by(compare); + for i in 0..range_size { + self.starts[start_i + i] = tmp.get_unchecked(i).start; + self.ends[start_i + i] = tmp.get_unchecked(i).end; + self.data[start_i + i] = tmp.get_unchecked(i).data.clone(); + } + } + } +} + +/// A variant of `SuperIntervals` that uses the Eytzinger layout for faster searching. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SuperIntervalsEytz { + inner: SuperIntervals, + eytz: Vec, + eytz_index: Vec, +} + +#[inline(always)] +pub fn ffs(x: u32) -> u32 { + if x == 0 { + 0 + } else { + x.trailing_zeros() + 1 + } +} + +impl SuperIntervalsEytz { + pub fn new() -> Self { + SuperIntervalsEytz { + inner: SuperIntervals::new(), + eytz: Vec::new(), + eytz_index: Vec::new(), + } + } + + pub fn clear(&mut self) { + self.inner.clear(); + self.eytz.clear(); + self.eytz_index.clear(); + } + + pub fn add(&mut self, start: i32, end: i32, value: T) { + self.inner.add(start, end, value); + } + + pub fn sort_intervals(&mut self) { + self.inner.sort_intervals(); + } + + fn eytzinger_helper(&mut self, mut i: usize, k: usize, n: usize) -> usize { + unsafe { + if k < n { + i = self.eytzinger_helper(i, 2 * k + 1, n); + self.eytz[k] = *self.inner.starts.get_unchecked(i); + self.eytz_index[k] = i; + i += 1; + i = self.eytzinger_helper(i, 2 * k + 2, n); + } + } + i + } + + pub fn index(&mut self) { + if self.inner.starts.is_empty() { + return; + } + self.sort_intervals(); + self.eytz.resize(self.inner.starts.len() + 1, 0); + self.eytz_index.resize(self.inner.starts.len() + 1, 0); + self.eytzinger_helper(0, 0, self.inner.starts.len()); + self.inner.index(); + } + + // todo this is not actually used! + pub fn upper_bound(&mut self, x: i32) { + unsafe { + let mut i: usize = 0; + let n_intervals: usize = self.inner.starts.len(); + while i < n_intervals { + if *self.eytz.get_unchecked(i) > x { + i = 2 * i + 1; + } else { + i = 2 * i + 2; + } + } + let shift: u32 = ffs(!(i as u32 + 1)); + let best_idx: usize = (i >> shift) - (if shift > 1 { 1 } else { 0 }); + self.inner.idx = if best_idx < n_intervals { + *self.eytz_index.get_unchecked(best_idx) + } else { + n_intervals - 1 + }; + if self.inner.idx > 0 && *self.inner.starts.get_unchecked(self.inner.idx) > x { + self.inner.idx -= 1; + } + } + } + + pub fn find_overlaps(&mut self, start: i32, end: i32, found: &mut Vec) { + self.inner.find_overlaps(start, end, found) + } + + pub fn count_overlaps(&mut self, start: i32, end: i32) -> usize { + self.inner.count_overlaps(start, end) + } +} diff --git a/sequila/sequila-core/superintervals/test/3rd-party/IITree.hpp b/sequila/sequila-core/superintervals/test/3rd-party/IITree.hpp new file mode 100644 index 0000000..3544022 --- /dev/null +++ b/sequila/sequila-core/superintervals/test/3rd-party/IITree.hpp @@ -0,0 +1,110 @@ +#pragma once + +#include +#include + +/* Suppose there are N=2^(K+1)-1 sorted numbers in an array a[]. They + * implicitly form a complete binary tree of height K+1. We consider leaves to + * be at level 0. The binary tree has the following properties: + * + * 1. The lowest k-1 bits of nodes at level k are all 1. The k-th bit is 0. + * The first node at level k is indexed by 2^k-1. The root of the tree is + * indexed by 2^K-1. + * + * 2. For a node x at level k, its left child is x-2^(k-1) and the right child + * is x+2^(k-1). + * + * 3. For a node x at level k, it is a left child if its (k+1)-th bit is 0. Its + * parent node is x+2^k. Similarly, if the (k+1)-th bit is 1, x is a right + * child and its parent is x-2^k. + * + * 4. For a node x at level k, there are 2^(k+1)-1 nodes in the subtree + * descending from x, including x. The left-most leaf is x&~(2^k-1) (masking + * the lowest k bits to 0). + * + * When numbers can't fill a complete binary tree, the parent of a node may not + * be present in the array. The implementation here still mimics a complete + * tree, though getting the special casing right is a little complex. There may + * be alternative solutions. + * + * As a sorted array can be considered as a binary search tree, we can + * implement an interval tree on top of the idea. We only need to record, for + * each node, the maximum value in the subtree descending from the node. + */ +template // "S" is a scalar type; "T" is the type of data associated with each interval +class IITree { + struct StackCell { + size_t x; // node + int k, w; // k: level; w: 0 if left child hasn't been processed + StackCell() {}; + StackCell(int k_, size_t x_, int w_) : x(x_), k(k_), w(w_) {}; + }; + struct Interval { + S st, en, max; + T data; + Interval(const S &s, const S &e, const T &d) : st(s), en(e), max(e), data(d) {}; + }; + struct IntervalLess { + bool operator()(const Interval &a, const Interval &b) const { return a.st < b.st; } + }; + std::vector a; + int max_level; + int index_core(std::vector &a) { + size_t i, last_i; // last_i points to the rightmost node in the tree + S last; // last is the max value at node last_i + int k; + if (a.size() == 0) return -1; + for (i = 0; i < a.size(); i += 2) last_i = i, last = a[i].max = a[i].en; // leaves (i.e. at level 0) + for (k = 1; 1LL< el? e : el; + e = e > er? e : er; + a[i].max = e; // set the max value for node i + } + last_i = last_i>>k&1? last_i - x : last_i + x; // last_i now points to the parent of the original last_i + if (last_i < a.size() && a[last_i].max > last) // update last accordingly + last = a[last_i].max; + } + return k - 1; + } +public: + void add(const S &s, const S &e, const T &d) { a.push_back(Interval(s, e, d)); } + void index(void) { + std::sort(a.begin(), a.end(), IntervalLess()); + max_level = index_core(a); + } + bool overlap(const S &st, const S &en, std::vector &out) const { + int t = 0; + StackCell stack[64]; + out.clear(); + if (max_level < 0) return false; + stack[t++] = StackCell(max_level, (1LL<> z.k << z.k, i1 = i0 + (1LL<<(z.k+1)) - 1; + if (i1 >= a.size()) i1 = a.size(); + for (i = i0; i < i1 && a[i].st < en; ++i) + if (st < a[i].en) // if overlap, append to out[] + out.push_back(i); + } else if (z.w == 0) { // if left child not processed + size_t y = z.x - (1LL<<(z.k-1)); // the left child of z.x; NB: y may be out of range (i.e. y>=a.size()) + stack[t++] = StackCell(z.k, z.x, 1); // re-add node z.x, but mark the left child having been processed + if (y >= a.size() || a[y].max > st) // push the left child if y is out of range or may overlap with the query + stack[t++] = StackCell(z.k - 1, y, 0); + } else if (z.x < a.size() && a[z.x].st < en) { // need to push the right child + if (st < a[z.x].en) out.push_back(z.x); // test if z.x overlaps the query; if yes, append to out[] + stack[t++] = StackCell(z.k - 1, z.x + (1LL<<(z.k-1)), 0); // push the right child + } + } + return out.size() > 0? true : false; + } + size_t size(void) const { return a.size(); } + const S &start(size_t i) const { return a[i].st; } + const S &end(size_t i) const { return a[i].en; } + const T &data(size_t i) const { return a[i].data; } +}; \ No newline at end of file diff --git a/sequila/sequila-core/superintervals/test/3rd-party/IntervalTree.h b/sequila/sequila-core/superintervals/test/3rd-party/IntervalTree.h new file mode 100644 index 0000000..39ca5cc --- /dev/null +++ b/sequila/sequila-core/superintervals/test/3rd-party/IntervalTree.h @@ -0,0 +1,345 @@ +// https://raw.githubusercontent.com/ekg/intervaltree/master/IntervalTree.h +#ifndef __INTERVAL_TREE_H +#define __INTERVAL_TREE_H + +#include +#include +#include +#include +#include +#include + +//#ifdef USE_INTERVAL_TREE_NAMESPACE +namespace interval_tree { +//#endif +template +class Interval { +public: + Scalar start; + Scalar stop; + Value value; + Interval(const Scalar& s, const Scalar& e, const Value& v) + : start(std::min(s, e)) + , stop(std::max(s, e)) + , value(v) + {} +}; + +template +Value intervalStart(const Interval& i) { + return i.start; +} + +template +Value intervalStop(const Interval& i) { + return i.stop; +} + +template +std::ostream& operator<<(std::ostream& out, const Interval& i) { + out << "Interval(" << i.start << ", " << i.stop << "): " << i.value; + return out; +} + +template +class IntervalTree { +public: + typedef Interval interval; + typedef std::vector interval_vector; + + + struct IntervalStartCmp { + bool operator()(const interval& a, const interval& b) { + return a.start < b.start; + } + }; + + struct IntervalStopCmp { + bool operator()(const interval& a, const interval& b) { + return a.stop < b.stop; + } + }; + + IntervalTree() + : left(nullptr) + , right(nullptr) + , center(0) + {} + + ~IntervalTree() = default; + + std::unique_ptr clone() const { + return std::unique_ptr(new IntervalTree(*this)); + } + + IntervalTree(const IntervalTree& other) + : intervals(other.intervals), + left(other.left ? other.left->clone() : nullptr), + right(other.right ? other.right->clone() : nullptr), + center(other.center) + {} + + IntervalTree& operator=(IntervalTree&&) = default; + IntervalTree(IntervalTree&&) = default; + + IntervalTree& operator=(const IntervalTree& other) { + center = other.center; + intervals = other.intervals; + left = other.left ? other.left->clone() : nullptr; + right = other.right ? other.right->clone() : nullptr; + return *this; + } + + IntervalTree( + interval_vector&& ivals, + std::size_t depth = 16, + std::size_t minbucket = 64, + std::size_t maxbucket = 512, + Scalar leftextent = 0, + Scalar rightextent = 0) + : left(nullptr) + , right(nullptr) + { + --depth; + const auto minmaxStop = std::minmax_element(ivals.begin(), ivals.end(), + IntervalStopCmp()); + const auto minmaxStart = std::minmax_element(ivals.begin(), ivals.end(), + IntervalStartCmp()); + if (!ivals.empty()) { + center = (minmaxStart.first->start + minmaxStop.second->stop) / 2; + } + if (leftextent == 0 && rightextent == 0) { + // sort intervals by start + std::sort(ivals.begin(), ivals.end(), IntervalStartCmp()); + } else { + assert(std::is_sorted(ivals.begin(), ivals.end(), IntervalStartCmp())); + } + if (depth == 0 || (ivals.size() < minbucket && ivals.size() < maxbucket)) { + std::sort(ivals.begin(), ivals.end(), IntervalStartCmp()); + intervals = std::move(ivals); + assert(is_valid().first); + return; + } else { + Scalar leftp = 0; + Scalar rightp = 0; + + if (leftextent || rightextent) { + leftp = leftextent; + rightp = rightextent; + } else { + leftp = ivals.front().start; + rightp = std::max_element(ivals.begin(), ivals.end(), + IntervalStopCmp())->stop; + } + + interval_vector lefts; + interval_vector rights; + + for (typename interval_vector::const_iterator i = ivals.begin(); + i != ivals.end(); ++i) { + const interval& interval = *i; + if (interval.stop < center) { + lefts.push_back(interval); + } else if (interval.start > center) { + rights.push_back(interval); + } else { + assert(interval.start <= center); + assert(center <= interval.stop); + intervals.push_back(interval); + } + } + + if (!lefts.empty()) { + left.reset(new IntervalTree(std::move(lefts), + depth, minbucket, maxbucket, + leftp, center)); + } + if (!rights.empty()) { + right.reset(new IntervalTree(std::move(rights), + depth, minbucket, maxbucket, + center, rightp)); + } + } + assert(is_valid().first); + } + + // Call f on all intervals near the range [start, stop]: + template + void visit_near(const Scalar& start, const Scalar& stop, UnaryFunction f) const { + if (!intervals.empty() && ! (stop < intervals.front().start)) { + for (auto & i : intervals) { + f(i); + } + } + if (left && start <= center) { + left->visit_near(start, stop, f); + } + if (right && stop >= center) { + right->visit_near(start, stop, f); + } + } + + // Call f on all intervals crossing pos + template + void visit_overlapping(const Scalar& pos, UnaryFunction f) const { + visit_overlapping(pos, pos, f); + } + + // Call f on all intervals overlapping [start, stop] + template + void visit_overlapping(const Scalar& start, const Scalar& stop, UnaryFunction f) const { + auto filterF = [&](const interval& interval) { + if (interval.stop >= start && interval.start <= stop) { + // Only apply f if overlapping + f(interval); + } + }; + visit_near(start, stop, filterF); + } + + // Call f on all intervals contained within [start, stop] + template + void visit_contained(const Scalar& start, const Scalar& stop, UnaryFunction f) const { + auto filterF = [&](const interval& interval) { + if (start <= interval.start && interval.stop <= stop) { + f(interval); + } + }; + visit_near(start, stop, filterF); + } + + interval_vector findOverlapping(const Scalar& start, const Scalar& stop) const { + interval_vector result; + visit_overlapping(start, stop, + [&](const interval& interval) { + result.emplace_back(interval); + }); + return result; + } + + interval_vector findContained(const Scalar& start, const Scalar& stop) const { + interval_vector result; + visit_contained(start, stop, + [&](const interval& interval) { + result.push_back(interval); + }); + return result; + } + bool empty() const { + if (left && !left->empty()) { + return false; + } + if (!intervals.empty()) { + return false; + } + if (right && !right->empty()) { + return false; + } + return true; + } + + template + void visit_all(UnaryFunction f) const { + if (left) { + left->visit_all(f); + } + std::for_each(intervals.begin(), intervals.end(), f); + if (right) { + right->visit_all(f); + } + } + + std::pair extentBruitForce() const { + struct Extent { + std::pair x = {std::numeric_limits::max(), + std::numeric_limits::min() }; + void operator()(const interval & interval) { + x.first = std::min(x.first, interval.start); + x.second = std::max(x.second, interval.stop); + } + }; + Extent extent; + + visit_all([&](const interval & interval) { extent(interval); }); + return extent.x; + } + + // Check all constraints. + // If first is false, second is invalid. + std::pair> is_valid() const { + const auto minmaxStop = std::minmax_element(intervals.begin(), intervals.end(), + IntervalStopCmp()); + const auto minmaxStart = std::minmax_element(intervals.begin(), intervals.end(), + IntervalStartCmp()); + + std::pair> result = {true, { std::numeric_limits::max(), + std::numeric_limits::min() }}; + if (!intervals.empty()) { + result.second.first = std::min(result.second.first, minmaxStart.first->start); + result.second.second = std::min(result.second.second, minmaxStop.second->stop); + } + if (left) { + auto valid = left->is_valid(); + result.first &= valid.first; + result.second.first = std::min(result.second.first, valid.second.first); + result.second.second = std::min(result.second.second, valid.second.second); + if (!result.first) { return result; } + if (valid.second.second >= center) { + result.first = false; + return result; + } + } + if (right) { + auto valid = right->is_valid(); + result.first &= valid.first; + result.second.first = std::min(result.second.first, valid.second.first); + result.second.second = std::min(result.second.second, valid.second.second); + if (!result.first) { return result; } + if (valid.second.first <= center) { + result.first = false; + return result; + } + } + if (!std::is_sorted(intervals.begin(), intervals.end(), IntervalStartCmp())) { + result.first = false; + } + return result; + } + + friend std::ostream& operator<<(std::ostream& os, const IntervalTree& itree) { + return writeOut(os, itree); + } + + friend std::ostream& writeOut(std::ostream& os, const IntervalTree& itree, + std::size_t depth = 0) { + auto pad = [&]() { for (std::size_t i = 0; i != depth; ++i) { os << ' '; } }; + pad(); os << "center: " << itree.center << '\n'; + for (const interval & inter : itree.intervals) { + pad(); os << inter << '\n'; + } + if (itree.left) { + pad(); os << "left:\n"; + writeOut(os, *itree.left, depth + 1); + } else { + pad(); os << "left: nullptr\n"; + } + if (itree.right) { + pad(); os << "right:\n"; + writeOut(os, *itree.right, depth + 1); + } else { + pad(); os << "right: nullptr\n"; + } + return os; + } + +private: + interval_vector intervals; + std::unique_ptr left; + std::unique_ptr right; + Scalar center; +}; +//#ifdef USE_INTERVAL_TREE_NAMESPACE +} +//#endif + +#endif diff --git a/sequila/sequila-core/superintervals/test/3rd-party/__init__.py b/sequila/sequila-core/superintervals/test/3rd-party/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sequila/sequila-core/superintervals/test/3rd-party/cgranges.c b/sequila/sequila-core/superintervals/test/3rd-party/cgranges.c new file mode 100644 index 0000000..3b4fe04 --- /dev/null +++ b/sequila/sequila-core/superintervals/test/3rd-party/cgranges.c @@ -0,0 +1,330 @@ +#include +#include +#include "cgranges.h" +#include "khash.h" + +/************** + * Radix sort * + **************/ + +#define RS_MIN_SIZE 64 +#define RS_MAX_BITS 8 + +#define KRADIX_SORT_INIT(name, rstype_t, rskey, sizeof_key) \ + typedef struct { \ + rstype_t *b, *e; \ + } rsbucket_##name##_t; \ + void rs_insertsort_##name(rstype_t *beg, rstype_t *end) \ + { \ + rstype_t *i; \ + for (i = beg + 1; i < end; ++i) \ + if (rskey(*i) < rskey(*(i - 1))) { \ + rstype_t *j, tmp = *i; \ + for (j = i; j > beg && rskey(tmp) < rskey(*(j-1)); --j) \ + *j = *(j - 1); \ + *j = tmp; \ + } \ + } \ + void rs_sort_##name(rstype_t *beg, rstype_t *end, int n_bits, int s) \ + { \ + rstype_t *i; \ + int size = 1<b = k->e = beg; \ + for (i = beg; i != end; ++i) ++b[rskey(*i)>>s&m].e; \ + for (k = b + 1; k != be; ++k) \ + k->e += (k-1)->e - beg, k->b = (k-1)->e; \ + for (k = b; k != be;) { \ + if (k->b != k->e) { \ + rsbucket_##name##_t *l; \ + if ((l = b + (rskey(*k->b)>>s&m)) != k) { \ + rstype_t tmp = *k->b, swap; \ + do { \ + swap = tmp; tmp = *l->b; *l->b++ = swap; \ + l = b + (rskey(tmp)>>s&m); \ + } while (l != k); \ + *k->b++ = tmp; \ + } else ++k->b; \ + } else ++k; \ + } \ + for (b->b = beg, k = b + 1; k != be; ++k) k->b = (k-1)->e; \ + if (s) { \ + s = s > n_bits? s - n_bits : 0; \ + for (k = b; k != be; ++k) \ + if (k->e - k->b > RS_MIN_SIZE) rs_sort_##name(k->b, k->e, n_bits, s); \ + else if (k->e - k->b > 1) rs_insertsort_##name(k->b, k->e); \ + } \ + } \ + void radix_sort_##name(rstype_t *beg, rstype_t *end) \ + { \ + if (end - beg <= RS_MIN_SIZE) rs_insertsort_##name(beg, end); \ + else rs_sort_##name(beg, end, RS_MAX_BITS, (sizeof_key - 1) * RS_MAX_BITS); \ + } + +/********************* + * Convenient macros * + *********************/ + +#ifndef kroundup32 +#define kroundup32(x) (--(x), (x)|=(x)>>1, (x)|=(x)>>2, (x)|=(x)>>4, (x)|=(x)>>8, (x)|=(x)>>16, ++(x)) +#endif + +#define CALLOC(type, len) ((type*)calloc((len), sizeof(type))) +#define REALLOC(ptr, len) ((ptr) = (__typeof__(ptr))realloc((ptr), (len) * sizeof(*(ptr)))) + +#define EXPAND(a, m) do { \ + (m) = (m)? (m) + ((m)>>1) : 16; \ + REALLOC((a), (m)); \ + } while (0) + +/******************** + * Basic operations * + ********************/ + +#define cr_intv_key(r) ((r).x) +KRADIX_SORT_INIT(cr_intv, cr_intv_t, cr_intv_key, 8) + +KHASH_MAP_INIT_STR(str, int32_t) +typedef khash_t(str) strhash_t; + +cgranges_t *cr_init(void) +{ + cgranges_t *cr; + cr = CALLOC(cgranges_t, 1); + cr->hc = kh_init(str); + return cr; +} + +void cr_destroy(cgranges_t *cr) +{ + int32_t i; + if (cr == 0) return; + for (i = 0; i < cr->n_ctg; ++i) + free(cr->ctg[i].name); + free(cr->ctg); + kh_destroy(str, (strhash_t*)cr->hc); + free(cr); +} + +int32_t cr_add_ctg(cgranges_t *cr, const char *ctg, int32_t len) +{ + int absent; + khint_t k; + strhash_t *h = (strhash_t*)cr->hc; + k = kh_put(str, h, ctg, &absent); + if (absent) { + cr_ctg_t *p; + if (cr->n_ctg == cr->m_ctg) + EXPAND(cr->ctg, cr->m_ctg); + kh_val(h, k) = cr->n_ctg; + p = &cr->ctg[cr->n_ctg++]; + p->name = strdup(ctg); + kh_key(h, k) = p->name; + p->len = len; + p->n = 0, p->off = -1; + } + if (len > cr->ctg[kh_val(h, k)].len) + cr->ctg[kh_val(h, k)].len = len; + return kh_val(h, k); +} + +int32_t cr_get_ctg(const cgranges_t *cr, const char *ctg) +{ + khint_t k; + strhash_t *h = (strhash_t*)cr->hc; + k = kh_get(str, h, ctg); + return k == kh_end(h)? -1 : kh_val(h, k); +} + +cr_intv_t *cr_add(cgranges_t *cr, const char *ctg, int32_t st, int32_t en, int32_t label_int) +{ + cr_intv_t *p; + int32_t k; + if (st > en) return 0; + k = cr_add_ctg(cr, ctg, 0); + if (cr->n_r == cr->m_r) + EXPAND(cr->r, cr->m_r); + p = &cr->r[cr->n_r++]; + p->x = (uint64_t)k << 32 | st; + p->y = en; + p->label = label_int; + if (cr->ctg[k].len < en) + cr->ctg[k].len = en; + return p; +} + +void cr_sort(cgranges_t *cr) +{ + if (cr->n_ctg == 0 || cr->n_r == 0) return; + radix_sort_cr_intv(cr->r, cr->r + cr->n_r); +} + +int32_t cr_is_sorted(const cgranges_t *cr) +{ + uint64_t i; + for (i = 1; i < cr->n_r; ++i) + if (cr->r[i-1].x > cr->r[i].x) + break; + return (i == cr->n_r); +} + +/************ + * Indexing * + ************/ + +void cr_index_prepare(cgranges_t *cr) +{ + int64_t i, st; + if (!cr_is_sorted(cr)) cr_sort(cr); + for (st = 0, i = 1; i <= cr->n_r; ++i) { + if (i == cr->n_r || cr->r[i].x>>32 != cr->r[st].x>>32) { + int32_t ctg = cr->r[st].x>>32; + cr->ctg[ctg].off = st; + cr->ctg[ctg].n = i - st; + st = i; + } + } + for (i = 0; i < cr->n_r; ++i) { + cr_intv_t *r = &cr->r[i]; + r->x = r->x<<32 | r->y; + r->y = 0; + } +} + +int32_t cr_index1(cr_intv_t *a, int64_t n) +{ + int64_t i, last_i; + int32_t last, k; + if (n <= 0) return -1; + for (i = 0; i < n; i += 2) last_i = i, last = a[i].y = (int32_t)a[i].x; + for (k = 1; 1LL< el? e : el; + e = e > er? e : er; + a[i].y = e; + } + last_i = last_i>>k&1? last_i - x : last_i + x; + if (last_i < n && a[last_i].y > last) + last = a[last_i].y; + } + return k - 1; +} + +void cr_index(cgranges_t *cr) +{ + int32_t i; + cr_index_prepare(cr); + for (i = 0; i < cr->n_ctg; ++i) + cr->ctg[i].root_k = cr_index1(&cr->r[cr->ctg[i].off], cr->ctg[i].n); +} + +/********* + * Query * + *********/ + +int64_t cr_min_start_int(const cgranges_t *cr, int32_t ctg_id, int32_t st) // find the smallest i such that cr_st(&r[i]) >= st +{ + int64_t left, right; + const cr_ctg_t *c; + const cr_intv_t *r; + + if (ctg_id < 0 || ctg_id >= cr->n_ctg) return -1; + c = &cr->ctg[ctg_id]; + r = &cr->r[c->off]; + if (c->n == 0) return -1; + left = 0, right = c->n; + while (right > left) { + int64_t mid = left + ((right - left) >> 1); + if (cr_st(&r[mid]) >= st) right = mid; + else left = mid + 1; + } + assert(left == right); + return left == c->n? -1 : c->off + left; +} + +typedef struct { + int64_t x; + int32_t k, w; +} istack_t; + +int64_t cr_overlap_int(const cgranges_t *cr, int32_t ctg_id, int32_t st, int32_t en, int64_t **b_, int64_t *m_b_) +{ + int32_t t = 0; + const cr_ctg_t *c; + const cr_intv_t *r; + int64_t *b = *b_, m_b = *m_b_, n = 0; + istack_t stack[64], *p; + + if (ctg_id < 0 || ctg_id >= cr->n_ctg) return 0; + c = &cr->ctg[ctg_id]; + r = &cr->r[c->off]; + p = &stack[t++]; + p->k = c->root_k, p->x = (1LL<k) - 1, p->w = 0; // push the root into the stack + while (t) { // stack is not empyt + istack_t z = stack[--t]; + if (z.k <= 3) { // the subtree is no larger than (1<<(z.k+1))-1; do a linear scan + int64_t i, i0 = z.x >> z.k << z.k, i1 = i0 + (1LL<<(z.k+1)) - 1; + if (i1 >= c->n) i1 = c->n; + for (i = i0; i < i1 && cr_st(&r[i]) < en; ++i) + if (st < cr_en(&r[i])) { + if (n == m_b) EXPAND(b, m_b); + b[n++] = c->off + i; + } + } else if (z.w == 0) { // if left child not processed + int64_t y = z.x - (1LL<<(z.k-1)); + p = &stack[t++]; + p->k = z.k, p->x = z.x, p->w = 1; + if (y >= c->n || r[y].y > st) { + p = &stack[t++]; + p->k = z.k - 1, p->x = y, p->w = 0; // push the left child to the stack + } + } else if (z.x < c->n && cr_st(&r[z.x]) < en) { + if (st < cr_en(&r[z.x])) { // then z.x overlaps the query; write to the output array + if (n == m_b) EXPAND(b, m_b); + b[n++] = c->off + z.x; + } + p = &stack[t++]; + p->k = z.k - 1, p->x = z.x + (1LL<<(z.k-1)), p->w = 0; // push the right child + } + } + *b_ = b, *m_b_ = m_b; + return n; +} + +int64_t cr_contain_int(const cgranges_t *cr, int32_t ctg_id, int32_t st, int32_t en, int64_t **b_, int64_t *m_b_) +{ + int64_t n = 0, i, s, e, *b = *b_, m_b = *m_b_; + s = cr_min_start_int(cr, ctg_id, st); + if (s < 0) return 0; + e = cr->ctg[ctg_id].off + cr->ctg[ctg_id].n; + for (i = s; i < e; ++i) { + const cr_intv_t *r = &cr->r[i]; + if (cr_st(r) >= en) break; + if (cr_st(r) >= st && cr_en(r) <= en) { + if (n == m_b) EXPAND(b, m_b); + b[n++] = i; + } + } + *b_ = b, *m_b_ = m_b; + return n; +} + +int64_t cr_min_start(const cgranges_t *cr, const char *ctg, int32_t st) +{ + return cr_min_start_int(cr, cr_get_ctg(cr, ctg), st); +} + +int64_t cr_overlap(const cgranges_t *cr, const char *ctg, int32_t st, int32_t en, int64_t **b_, int64_t *m_b_) +{ + return cr_overlap_int(cr, cr_get_ctg(cr, ctg), st, en, b_, m_b_); +} + +int64_t cr_contain(const cgranges_t *cr, const char *ctg, int32_t st, int32_t en, int64_t **b_, int64_t *m_b_) +{ + return cr_contain_int(cr, cr_get_ctg(cr, ctg), st, en, b_, m_b_); +} diff --git a/sequila/sequila-core/superintervals/test/3rd-party/cgranges.h b/sequila/sequila-core/superintervals/test/3rd-party/cgranges.h new file mode 100644 index 0000000..16bbc63 --- /dev/null +++ b/sequila/sequila-core/superintervals/test/3rd-party/cgranges.h @@ -0,0 +1,87 @@ +/* The MIT License + + Copyright (c) 2019 Dana-Farber Cancer Institute + + Permission is hereby granted, free of charge, to any person obtaining + a copy of this software and associated documentation files (the + "Software"), to deal in the Software without restriction, including + without limitation the rights to use, copy, modify, merge, publish, + distribute, sublicense, and/or sell copies of the Software, and to + permit persons to whom the Software is furnished to do so, subject to + the following conditions: + + The above copyright notice and this permission notice shall be + included in all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS + BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN + ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. +*/ +#ifndef CRANGES_H +#define CRANGES_H + +#include + +typedef struct { // a contig + char *name; // name of the contig + int32_t len; // max length seen in data + int32_t root_k; + int64_t n, off; // sum of lengths of previous contigs +} cr_ctg_t; + +typedef struct { // an interval + uint64_t x; // prior to cr_index(), x = ctg_id<<32|start_pos; after: x = start_pos<<32|end_pos + uint32_t y:31, rev:1; + int32_t label; // NOT used +} cr_intv_t; + +typedef struct { + int64_t n_r, m_r; // number and max number of intervals + cr_intv_t *r; // list of intervals (of size _n_r_) + int32_t n_ctg, m_ctg; // number and max number of contigs + cr_ctg_t *ctg; // list of contigs (of size _n_ctg_) + void *hc; // dictionary for converting contig names to integers +} cgranges_t; + +#ifdef __cplusplus +extern "C" { +#endif + +// retrieve start and end positions from a cr_intv_t object +static inline int32_t cr_st(const cr_intv_t *r) { return (int32_t)(r->x>>32); } +static inline int32_t cr_en(const cr_intv_t *r) { return (int32_t)r->x; } +static inline int32_t cr_start(const cgranges_t *cr, int64_t i) { return cr_st(&cr->r[i]); } +static inline int32_t cr_end(const cgranges_t *cr, int64_t i) { return cr_en(&cr->r[i]); } +static inline int32_t cr_label(const cgranges_t *cr, int64_t i) { return cr->r[i].label; } + +// Initialize +cgranges_t *cr_init(void); + +// Deallocate +void cr_destroy(cgranges_t *cr); + +// Add an interval +cr_intv_t *cr_add(cgranges_t *cr, const char *ctg, int32_t st, int32_t en, int32_t label_int); + +// Sort and index intervals +void cr_index(cgranges_t *cr); + +int64_t cr_overlap(const cgranges_t *cr, const char *ctg, int32_t st, int32_t en, int64_t **b_, int64_t *m_b_); +int64_t cr_contain(const cgranges_t *cr, const char *ctg, int32_t st, int32_t en, int64_t **b_, int64_t *m_b_); + +// Add a contig and length. Call this for desired contig ordering. _len_ can be 0. +int32_t cr_add_ctg(cgranges_t *cr, const char *ctg, int32_t len); + +// Get the contig ID given its name +int32_t cr_get_ctg(const cgranges_t *cr, const char *ctg); + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/sequila/sequila-core/superintervals/test/3rd-party/coitrees/Cargo.lock b/sequila/sequila-core/superintervals/test/3rd-party/coitrees/Cargo.lock new file mode 100644 index 0000000..4902da3 --- /dev/null +++ b/sequila/sequila-core/superintervals/test/3rd-party/coitrees/Cargo.lock @@ -0,0 +1,311 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "anstream" +version = "0.6.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "418c75fa768af9c03be99d17643f93f79bbba589895012a80e3452a19ddda15b" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "is_terminal_polyfill", + "utf8parse", +] + +[[package]] +name = "anstyle" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "038dfcf04a5feb68e9c60b21c9625a54c2c0616e79b72b0fd87075a056ae1d1b" + +[[package]] +name = "anstyle-parse" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c03a11a9034d92058ceb6ee011ce58af4a9bf61491aa7e1e59ecd24bd40d22d4" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-query" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad186efb764318d35165f1758e7dcef3b10628e26d41a44bc5550652e6804391" +dependencies = [ + "windows-sys", +] + +[[package]] +name = "anstyle-wincon" +version = "3.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61a38449feb7068f52bb06c12759005cf459ee52bb4adc1d5a7c4322d716fb19" +dependencies = [ + "anstyle", + "windows-sys", +] + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "clap" +version = "4.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90bc066a67923782aa8515dbaea16946c5bcc5addbd668bb80af688e53e548a0" +dependencies = [ + "clap_builder", + "clap_derive", +] + +[[package]] +name = "clap_builder" +version = "4.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae129e2e766ae0ec03484e609954119f123cc1fe650337e155d03b022f24f7b4" +dependencies = [ + "anstream", + "anstyle", + "clap_lex", + "strsim", +] + +[[package]] +name = "clap_derive" +version = "4.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "528131438037fd55894f62d6e9f068b8f45ac57ffa77517819645d10aed04f64" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "clap_lex" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "98cc8fbded0c607b7ba9dd60cd98df59af97e84d24e49c8557331cfc26d301ce" + +[[package]] +name = "coitrees" +version = "0.4.0" +dependencies = [ + "clap", + "fnv", + "libc", + "rand", +] + +[[package]] +name = "colorchoice" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b6a852b24ab71dffc585bcb46eaf7959d175cb865a7152e35b348d1b2960422" + +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + +[[package]] +name = "getrandom" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "is_terminal_polyfill" +version = "1.70.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8478577c03552c21db0e2724ffb8986a5ce7af88107e6be5d2ee6e158c12800" + +[[package]] +name = "libc" +version = "0.2.155" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c" + +[[package]] +name = "ppv-lite86" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" + +[[package]] +name = "proc-macro2" +version = "1.0.85" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22244ce15aa966053a896d1accb3a6e68469b97c7f33f284b99f0d576879fc23" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] + +[[package]] +name = "strsim" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" + +[[package]] +name = "syn" +version = "2.0.66" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c42f3f41a2de00b01c0aaad383c5a45241efc8b2d1eda5661812fda5f3cdcff5" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "unicode-ident" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" + +[[package]] +name = "utf8parse" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" + +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-targets" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_gnullvm", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0" diff --git a/sequila/sequila-core/superintervals/test/3rd-party/coitrees/Cargo.toml b/sequila/sequila-core/superintervals/test/3rd-party/coitrees/Cargo.toml new file mode 100644 index 0000000..604fa5a --- /dev/null +++ b/sequila/sequila-core/superintervals/test/3rd-party/coitrees/Cargo.toml @@ -0,0 +1,41 @@ +[package] +name = "coitrees" +description = "A very fast data structure for overlap queries on sets of intervals." +version = "0.4.0" +authors = ["Daniel C. Jones "] +edition = "2021" +repository = "https://github.com/dcjones/coitrees" +documentation = "https://docs.rs/coitrees" +readme = "README.md" +license-file = "LICENSE" + +[profile.release] +debug = 0 +strip = "symbols" +lto = true +opt-level = 3 +codegen-units = 1 + +[features] +nosimd = [] + +[profile.dev.package."*"] +opt-level = 3 # Compile dependencies with optimizations on even in debug mode. + +[profile.no-opt] +inherits = "dev" +opt-level = 0 + +[profile.profiling] +inherits = "release" +debug = true +strip = false + +[dev-dependencies] +rand = "0.8" +fnv = "1.0.7" +libc = "0.2" +clap = { version = "4.3.7", features = ["derive"] } + +[[example]] +name = "bed-intersect" diff --git a/sequila/sequila-core/superintervals/test/3rd-party/coitrees/LICENSE b/sequila/sequila-core/superintervals/test/3rd-party/coitrees/LICENSE new file mode 100644 index 0000000..ea25c2f --- /dev/null +++ b/sequila/sequila-core/superintervals/test/3rd-party/coitrees/LICENSE @@ -0,0 +1,19 @@ +Copyright © 2019 Daniel C. Jones (dcjones@cs.washington.edu) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/sequila/sequila-core/superintervals/test/3rd-party/coitrees/README.md b/sequila/sequila-core/superintervals/test/3rd-party/coitrees/README.md new file mode 100644 index 0000000..11205a0 --- /dev/null +++ b/sequila/sequila-core/superintervals/test/3rd-party/coitrees/README.md @@ -0,0 +1,110 @@ + +# COITrees: Cache Oblivious Interval Trees + +COITrees implements a data structure for very fast overlap queries of a +static set of integer intervals, with genomic intervals in mind. + +Borrowing from [cgranges](https://github.com/lh3/cgranges), this data +structure stores intervals in contiguous memory, but improves query +performance by storing the nodes in in-order [van Emde Boas +layout](http://erikdemaine.org/papers/FOCS2000b/paper.pdf). Computing +the layout requires some extra time and memory, but improves average cache +locality for queries of the tree. If the interval set is relatively large, +and a sufficiently large number of queries are performed, it tends to out-perform +other data structures. + +The `SortedQuerent` type implements an alternative query strategy that keeps track +of the results of the previous query. When a query overlaps the previous one, +the results from that previous query can be reused to dramatically accelerate +the current one. (In the benchmarks, this is the `--sorted` option.) + +Some operations can further be sped up using SIMD instructions. Two COITree +variants are implemented to exploit AVX2 instructions on x86-64 cpus +(`AVXCOITree`), and Neon instructions on ARM cpus (`NeonCOITree`). The `COITree` +type is oppurtunistically defined to one of these types if the right instruction +set is detected. Typically it's necessary to compile with the environment +variable `RUSTFLAGS="-Ctarget-cpu=native"` set for this to work. The fallback +implemntation (`BasicCOITree`) supports any platform rust compiles to and +remains highly efficient. + +# Trying Out + +This is primary a library for use in other programs, but for benchmarking +purposes it includes a program for intersecting BED files. + +To try out, just clone this repo and run: +```shell +cargo run --release --example bed-intersect -- test1.bed test2.bed > intersections.bed +``` + +# Benchmarks + +`A` is 2,755,864 intervals from Ensembl's human genome annotations, `B` is +62,159,484 intervals from some RNA-Seq alignments, and `B'` is the first 2 +million lines of `B`. + +## Intervals in sorted order + +| | A vs B | B vs A | A vs A | B' vs B' | +| ----------------------------------- | ---------: | ---------: | -------: | ---------: | +| coitrees AVX | 11.8s | **3.7s** | 0.7 | 5.3s | +| coitrees AVX (`--sorted`) | 6.4s | 4.2s | **0.6s** | **0.5s** | +| coitrees | 11.4s | 5.2s | 0.8s | 8.3s | +| coitrees (`--sorted`) | **5.8s** | 5.4s | **0.6s** | **0.5s** | +| cgranges (`bedcov-cr -c`) | 35.4s | 6.6s | 2.0s | 17.6s | +| AIList | 13.8s | 10.1s | 1.1s | 18.4s | +| CITree | 20.1s | 13.5s | 1.6s | 45.7s | +| NCList | 22.5s | 16.8s | 1.9s | 39.8s | +| AITree | 23.8s | 26.3s | 2.1s | 63.4s | +| `bedtools coverage -counts -sorted` | 257.5s | 295.6s | 71.6s | 2130.9s | +| `bedtools coverage -counts` | 322.4s | 378.5s | 75.0s | 3595.9s | + +### With coverage + +| | A vs B | B vs A | A vs A | B' vs B' | +| ----------------------------------- | ---------: | ---------: | -------: | ---------: | +| coitrees AVX | 18.2s | **4.8s** | 1.1s | 16.0s | +| coitrees | **14.6s** | 5.7s | **1.0s** | **12.0s** | +| cgranges | 38.4s | 8.1s | 2.2s | 31.0s | +| CITree | 23.2s | 25.6s | 2.0s | 160.4s | + +## Intervals in randomized order + +| | A vs B | B vs A | A vs A | B' vs B' | +| ----------------------------------- | ---------: | ---------: | -------: | --------: | +| coitrees AVX | **23.9s** | **7.2s** | **1.6s** | **6.1s** | +| coitrees | 24.2s | 8.9s | 1.9s | 9.4s | +| cgranges (`bedcov-cr -c`) | 55.7s | 11.1s | 3.3s | 19.6s | +| AIList | 31.2s | 18.2s | 2.3s | 19.3s | +| CITree | 39.4s | 19.0s | 2.9s | 47.1s | +| NCList | 42.7s | 23.8s | 3.4s | 44.0s | +| AITree | 225.3s | 134.8s | 14.7s | 921.6s | +| `bedtools coverage -counts` | 1160.4s | 849.6s | 104.5s | 9254.6s | + +### With coverage + +| | A vs B | B vs A | A vs A | B' vs B' | +| ----------------------------------- | ---------: | ---------: | -------: | ---------: | +| coitrees AVX | 34.3s | **8.8s** | **2.2s** | 16.3s | +| coitrees | **29.6s** | 9.7s | 2.3s | **13.1s** | +| cgranges | 57.6s | 12.5s | 3.6s | 32.6s | +| CITree | 50.0s | 32.5s | 3.8s | 170.4s | + + +All benchmarks run on a ryzen 5950x. + +# Discussion + +These benchmarks are somewhat realistic in that they use real data, but are +not entirely apples-to-apples because they all involve parsing and writing +BED files. Most of the programs (including the one implemented in coitrees) +have incomplete BED parsers, and some use other shortcuts like assuming a +fixed set of chromosomes with specific naming schemes. + +`bedtools` carries the disadvantage of being an actually useful tool, rather +than implemented being implemented entirely for the purpose of winning benchmark +games. It seems clear it could be a lot faster, but there no doubt some cost can +be chalked up to featurefulness, completeness, and safety. + +If you have a BED intersection program you suspect may be faster (or just +interesting), please let me know and I'll try to benchmark it. diff --git a/sequila/sequila-core/superintervals/test/3rd-party/coitrees/examples/bed-intersect.rs b/sequila/sequila-core/superintervals/test/3rd-party/coitrees/examples/bed-intersect.rs new file mode 100644 index 0000000..93d1e37 --- /dev/null +++ b/sequila/sequila-core/superintervals/test/3rd-party/coitrees/examples/bed-intersect.rs @@ -0,0 +1,218 @@ +use coitrees::*; + +use std::error::Error; +use std::fs::File; +use std::io::{BufRead, BufReader, Write}; +use std::str; +use std::time::Instant; + +extern crate fnv; +use fnv::FnvHashMap; + +use clap::Parser; + +extern crate libc; + +type GenericError = Box; + +// Parse a i32 with no checking whatsoever. (e.g. non-number characters will just) +fn i32_from_bytes_uncheckd(s: &[u8]) -> i32 { + if s.is_empty() { + 0 + } else if s[0] == b'-' { + -s[1..].iter().fold(0, |a, b| a * 10 + (b & 0x0f) as i32) + } else { + s.iter().fold(0, |a, b| a * 10 + (b & 0x0f) as i32) + } +} + +fn parse_bed_line(line: &[u8]) -> (&str, i32, i32) { + let n = line.len() - 1; + let mut p = 0; + for c in &line[p..n] { + if *c == b'\t' { + break; + } + p += 1; + } + let seqname = unsafe { str::from_utf8_unchecked(&line[..p]) }; + p += 1; + let p0 = p; + + for c in &line[p..n] { + if *c == b'\t' { + break; + } + p += 1; + } + let first = i32_from_bytes_uncheckd(&line[p0..p]); + p += 1; + let p0 = p; + + for c in &line[p..n] { + if *c == b'\t' { + break; + } + p += 1; + } + let last = i32_from_bytes_uncheckd(&line[p0..p]) - 1; + + (seqname, first, last) +} + +type IntervalHashMap = FnvHashMap>>; + +// Read a bed file into a COITree +fn read_bed_file( + path: &str, + name: &str, +) -> Result>, GenericError> { + let mut nodes = IntervalHashMap::default(); + let file = File::open(path)?; + let mut rdr = BufReader::new(file); + let mut line = Vec::new(); + while rdr.read_until(b'\n', &mut line).unwrap() > 0 { + let (seqname, first, last) = parse_bed_line(&line); + if seqname != "chr1" || last < 0 || last < first { + line.clear(); + continue; + } + let node_arr = if let Some(node_arr) = nodes.get_mut(seqname) { + node_arr + } else { + nodes.entry(seqname.to_string()).or_insert(Vec::new()) + }; + node_arr.push(Interval::new(first, last, ())); + line.clear(); + } + let now = Instant::now(); + let mut trees = FnvHashMap::>::default(); + for (seqname, seqname_nodes) in nodes { + trees.insert(seqname, COITree::new(&seqname_nodes)); + } + eprint!("{},{},", name, now.elapsed().as_micros()); + std::io::stderr().flush().unwrap(); + Ok(trees) +} + +fn query_bed_files(filename_a: &str, filename_b: &str) -> Result<(), GenericError> { + let tree = read_bed_file(filename_a, "Coitrees")?; + let file = File::open(filename_b)?; + let mut rdr = BufReader::new(file); + let mut ranges: Vec<(i32, i32)> = Vec::new(); + let mut line = Vec::new(); + while rdr.read_until(b'\n', &mut line).unwrap() > 0 { + let (chrom, first, last) = parse_bed_line(&line); + if chrom != "chr1" || last < 0 || last < first { + line.clear(); + continue; + } + ranges.push((first, last)); + line.clear(); + } + let seqname_tree = tree.get("chr1").ok_or("Chromosome tree not found")?; + + // Find overlaps (collecting results) + let mut total_found = 0; + let mut results = Vec::new(); + results.reserve(10000); + let mut now = Instant::now(); + for &(first, last) in &ranges { + seqname_tree.query(first, last, |node| { + results.push(node.metadata.clone()); + }); + total_found += results.len(); + results.clear(); + } + eprint!("{},{},", now.elapsed().as_micros(), total_found); + std::io::stderr().flush().unwrap(); + + // Count overlaps + now = Instant::now(); + let total_count: usize = ranges + .iter() + .map(|&(first, last)| seqname_tree.query_count(first, last)) + .sum(); + eprint!("{},{}\n", now.elapsed().as_micros(), total_count); + std::io::stderr().flush().unwrap(); + + Ok(()) +} + +fn query_bed_files_with_sorted_querent( + filename_a: &str, + filename_b: &str, +) -> Result<(), GenericError> { + let trees = read_bed_file(filename_a, "Coitrees-s")?; + + let mut querents = FnvHashMap::>::default(); + for (seqname, tree) in &trees { + querents.insert(seqname.clone(), COITreeSortedQuerent::new(tree)); + } + + let file = File::open(filename_b)?; + let mut rdr = BufReader::new(file); + let mut ranges: Vec<(i32, i32)> = Vec::new(); + let mut line = Vec::new(); + while rdr.read_until(b'\n', &mut line).unwrap() > 0 { + let (_, first, last) = parse_bed_line(&line); + ranges.push((first, last)); + line.clear(); + } + + // Use `get_mut` to get a mutable reference + let seqname_tree = querents + .get_mut("chr1") + .ok_or("Chromosome tree not found")?; + + // Find overlaps (collecting results) + let mut total_found = 0; + let mut results = Vec::new(); + results.reserve(10000); + let now = Instant::now(); + for &(first, last) in &ranges { + seqname_tree.query(first, last, |node| { + results.push(node.metadata); + }); + total_found += results.len(); + results.clear(); + } + eprint!("{},{}\n", now.elapsed().as_micros(), total_found); + std::io::stderr().flush().unwrap(); + + Ok(()) +} + +#[derive(Parser, Debug)] +#[command(about = " Find overlaps between two groups of intervals ")] +struct Args { + /// intervals to index + #[arg(value_name = "intervals.bed")] + input1: String, + + /// query intervals + #[arg(value_name = "queries.bed")] + input2: String, + + /// use alternative search strategy that's faster if queries are sorted and tend to overlap + #[arg(short = 's', long = "sorted")] + use_sorted_querent: bool, +} + +fn main() { + let matches = Args::parse(); + + let input1 = matches.input1.as_str(); + let input2 = matches.input2.as_str(); + + let result; + + if matches.use_sorted_querent { + result = query_bed_files_with_sorted_querent(input1, input2); + } else { + result = query_bed_files(input1, input2); + } + if let Err(err) = result { + println!("error: {}", err) + } +} diff --git a/sequila/sequila-core/superintervals/test/3rd-party/coitrees/src/avx.rs b/sequila/sequila-core/superintervals/test/3rd-party/coitrees/src/avx.rs new file mode 100644 index 0000000..07eb7ae --- /dev/null +++ b/sequila/sequila-core/superintervals/test/3rd-party/coitrees/src/avx.rs @@ -0,0 +1,1506 @@ +use std::arch::x86_64::{ + __m256i, _mm256_and_si256, _mm256_cmpgt_epi32, _mm256_extract_epi32, _mm256_movemask_epi8, + _mm256_set1_epi32, _mm256_set_epi32, +}; + +use super::interval::{GenericInterval, IntWithMax, Interval, IntervalTree, SortedQuerent}; +use std::cmp::{max, Ordering}; +use std::fmt::Debug; +use std::marker::Copy; +use std::mem::transmute; + +#[allow(non_camel_case_types)] +type i32x8 = __m256i; + +// Small subtrees at the bottom of the tree are stored in sorted order +// This gives the upper bound on the size of such subtrees. Performance isn't +// super sensitive, but is worse with a very small or very large number. +const SIMPLE_SUBTREE_CUTOFF: usize = 8; + +// Very dense subtrees in which we probably intersect most of the intervals +// are more efficient to query linearly. When the expected proportion of hits +// is a above this number it becomes a simple subtree. +const SIMPLE_SUBTREE_DENSITY_CUTOFF: f32 = 0.2; + +/// Node in the interval tree. Each node holds a chunk of 8 intervals. +#[derive(Clone)] +struct IntervalNode +where + T: Clone, + I: IntWithMax, +{ + // subtree interval + subtree_last: i32, + + firsts: i32x8, + lasts: i32x8, + metadata: [T; 8], + + // when this is the root of a simple subtree, left == right is the size + // of the subtree, otherwise they are left, right child pointers. + left: I, + right: I, +} + +impl IntervalNode +where + T: Clone, + I: IntWithMax, +{ + fn new(intervals: &[Interval; 8]) -> Self + where + T: Copy, + { + assert!(!intervals.is_empty() && intervals.len() <= 8); + + unsafe { + let mut max_last = intervals[0].last; + for interval in &intervals[1..] { + max_last = max(max_last, interval.last); + } + + // the firsts and lasts are adjust by 1 here because AVX has an + // instruction for > but not >=. + let firsts = _mm256_set_epi32( + intervals[7].first - 1, + intervals[6].first - 1, + intervals[5].first - 1, + intervals[4].first - 1, + intervals[3].first - 1, + intervals[2].first - 1, + intervals[1].first - 1, + intervals[0].first - 1, + ); + + let lasts = _mm256_set_epi32( + intervals[7].last + 1, + intervals[6].last + 1, + intervals[5].last + 1, + intervals[4].last + 1, + intervals[3].last + 1, + intervals[2].last + 1, + intervals[1].last + 1, + intervals[0].last + 1, + ); + + let metadata: [T; 8] = [ + intervals[0].metadata, + intervals[1].metadata, + intervals[2].metadata, + intervals[3].metadata, + intervals[4].metadata, + intervals[5].metadata, + intervals[6].metadata, + intervals[7].metadata, + ]; + + Self { + subtree_last: max_last, + firsts, + lasts, + metadata, + left: I::MAX, + right: I::MAX, + } + } + } + + fn first(&self, i: usize) -> i32 { + assert!(i < 8); + (unsafe { transmute::<&i32x8, &[i32; 8]>(&self.firsts) }[i]) + } + + fn last(&self, i: usize) -> i32 { + assert!(i < 8); + (unsafe { transmute::<&i32x8, &[i32; 8]>(&self.lasts) }[i]) + } + + fn min_first(&self) -> i32 { + unsafe { _mm256_extract_epi32(self.firsts, 0) + 1 } + } + + fn max_last(&self) -> i32 { + // faster ways to do this, but doesn't really matter + unsafe { + max( + max( + max( + _mm256_extract_epi32(self.lasts, 0), + _mm256_extract_epi32(self.lasts, 1), + ), + max( + _mm256_extract_epi32(self.lasts, 2), + _mm256_extract_epi32(self.lasts, 3), + ), + ), + max( + max( + _mm256_extract_epi32(self.lasts, 4), + _mm256_extract_epi32(self.lasts, 5), + ), + max( + _mm256_extract_epi32(self.lasts, 6), + _mm256_extract_epi32(self.lasts, 7), + ), + ), + ) - 1 + } + } + + #[inline(always)] + fn query_count_chunk(&self, query_first: i32x8, query_last: i32x8) -> usize { + unsafe { + let cmp1 = _mm256_cmpgt_epi32(query_last, self.firsts); + let cmp2 = _mm256_cmpgt_epi32(self.lasts, query_first); + let cmp = _mm256_movemask_epi8(_mm256_and_si256(cmp1, cmp2)); + let count = cmp.count_ones() / 4; + + count as usize + } + } + + fn query_chunk_firsts<'a, F>(&'a self, query_first: i32x8, query_last: i32x8, mut visit: F) + where + F: FnMut(i32, i32, &'a T), + { + let cmp: u32 = unsafe { + let cmp1 = _mm256_cmpgt_epi32(query_last, self.firsts); + let cmp2 = _mm256_cmpgt_epi32(self.firsts, query_first); + let cmp = _mm256_movemask_epi8(_mm256_and_si256(cmp1, cmp2)); + transmute(cmp) + }; + + let masks = [ + 0xf, 0xf0, 0xf00, 0xf000, 0xf0000, 0xf00000, 0xf000000, 0xf0000000, + ]; + if cmp & masks[0] != 0 { + visit( + unsafe { _mm256_extract_epi32(self.firsts, 0) + 1 }, + unsafe { _mm256_extract_epi32(self.lasts, 0) - 1 }, + &self.metadata[0], + ); + } + if cmp & masks[1] != 0 { + visit( + unsafe { _mm256_extract_epi32(self.firsts, 1) + 1 }, + unsafe { _mm256_extract_epi32(self.lasts, 1) - 1 }, + &self.metadata[1], + ); + } + if cmp & masks[2] != 0 { + visit( + unsafe { _mm256_extract_epi32(self.firsts, 2) + 1 }, + unsafe { _mm256_extract_epi32(self.lasts, 2) - 1 }, + &self.metadata[2], + ); + } + if cmp & masks[3] != 0 { + visit( + unsafe { _mm256_extract_epi32(self.firsts, 3) + 1 }, + unsafe { _mm256_extract_epi32(self.lasts, 3) - 1 }, + &self.metadata[3], + ); + } + if cmp & masks[4] != 0 { + visit( + unsafe { _mm256_extract_epi32(self.firsts, 4) + 1 }, + unsafe { _mm256_extract_epi32(self.lasts, 4) - 1 }, + &self.metadata[4], + ); + } + if cmp & masks[5] != 0 { + visit( + unsafe { _mm256_extract_epi32(self.firsts, 5) + 1 }, + unsafe { _mm256_extract_epi32(self.lasts, 5) - 1 }, + &self.metadata[5], + ); + } + if cmp & masks[6] != 0 { + visit( + unsafe { _mm256_extract_epi32(self.firsts, 6) + 1 }, + unsafe { _mm256_extract_epi32(self.lasts, 6) - 1 }, + &self.metadata[6], + ); + } + if cmp & masks[7] != 0 { + visit( + unsafe { _mm256_extract_epi32(self.firsts, 7) + 1 }, + unsafe { _mm256_extract_epi32(self.lasts, 7) - 1 }, + &self.metadata[7], + ); + } + } + + fn query_chunk_metadata<'a, F>(&'a self, query_first: i32x8, query_last: i32x8, mut visit: F) + where + F: FnMut(i32, i32, &'a T), + { + let cmp: u32 = unsafe { + let cmp1 = _mm256_cmpgt_epi32(query_last, self.firsts); + let cmp2 = _mm256_cmpgt_epi32(self.lasts, query_first); + let cmp = _mm256_movemask_epi8(_mm256_and_si256(cmp1, cmp2)); + transmute(cmp) + }; + + // could be made nicer with a macro perhaps? + let masks = [ + 0xf, 0xf0, 0xf00, 0xf000, 0xf0000, 0xf00000, 0xf000000, 0xf0000000, + ]; + if cmp & masks[0] != 0 { + visit( + unsafe { _mm256_extract_epi32(self.firsts, 0) + 1 }, + unsafe { _mm256_extract_epi32(self.lasts, 0) - 1 }, + &self.metadata[0], + ); + } + if cmp & masks[1] != 0 { + visit( + unsafe { _mm256_extract_epi32(self.firsts, 1) + 1 }, + unsafe { _mm256_extract_epi32(self.lasts, 1) - 1 }, + &self.metadata[1], + ); + } + if cmp & masks[2] != 0 { + visit( + unsafe { _mm256_extract_epi32(self.firsts, 2) + 1 }, + unsafe { _mm256_extract_epi32(self.lasts, 2) - 1 }, + &self.metadata[2], + ); + } + if cmp & masks[3] != 0 { + visit( + unsafe { _mm256_extract_epi32(self.firsts, 3) + 1 }, + unsafe { _mm256_extract_epi32(self.lasts, 3) - 1 }, + &self.metadata[3], + ); + } + if cmp & masks[4] != 0 { + visit( + unsafe { _mm256_extract_epi32(self.firsts, 4) + 1 }, + unsafe { _mm256_extract_epi32(self.lasts, 4) - 1 }, + &self.metadata[4], + ); + } + if cmp & masks[5] != 0 { + visit( + unsafe { _mm256_extract_epi32(self.firsts, 5) + 1 }, + unsafe { _mm256_extract_epi32(self.lasts, 5) - 1 }, + &self.metadata[5], + ); + } + if cmp & masks[6] != 0 { + visit( + unsafe { _mm256_extract_epi32(self.firsts, 6) + 1 }, + unsafe { _mm256_extract_epi32(self.lasts, 6) - 1 }, + &self.metadata[6], + ); + } + if cmp & masks[7] != 0 { + visit( + unsafe { _mm256_extract_epi32(self.firsts, 7) + 1 }, + unsafe { _mm256_extract_epi32(self.lasts, 7) - 1 }, + &self.metadata[7], + ); + } + } + + fn query_chunk(&self, query_first: i32x8, query_last: i32x8, mut visit: F) + where + F: FnMut(i32, i32), + { + let cmp: u32 = unsafe { + let cmp1 = _mm256_cmpgt_epi32(query_last, self.firsts); + let cmp2 = _mm256_cmpgt_epi32(self.lasts, query_first); + let cmp = _mm256_movemask_epi8(_mm256_and_si256(cmp1, cmp2)); + transmute(cmp) + }; + + let masks = [ + 0xf, 0xf0, 0xf00, 0xf000, 0xf0000, 0xf00000, 0xf000000, 0xf0000000, + ]; + + // this is a bit slower + // for (i, mask) in masks.iter().enumerate() { + // if cmp & mask != 0 { + // visit(self.first(i) + 1, self.last(i) - 1) + // } + // } + + if cmp & masks[0] != 0 { + visit( + unsafe { _mm256_extract_epi32(self.firsts, 0) + 1 }, + unsafe { _mm256_extract_epi32(self.lasts, 0) - 1 }, + ); + } + if cmp & masks[1] != 0 { + visit( + unsafe { _mm256_extract_epi32(self.firsts, 1) + 1 }, + unsafe { _mm256_extract_epi32(self.lasts, 1) - 1 }, + ); + } + if cmp & masks[2] != 0 { + visit( + unsafe { _mm256_extract_epi32(self.firsts, 2) + 1 }, + unsafe { _mm256_extract_epi32(self.lasts, 2) - 1 }, + ); + } + if cmp & masks[3] != 0 { + visit( + unsafe { _mm256_extract_epi32(self.firsts, 3) + 1 }, + unsafe { _mm256_extract_epi32(self.lasts, 3) - 1 }, + ); + } + if cmp & masks[4] != 0 { + visit( + unsafe { _mm256_extract_epi32(self.firsts, 4) + 1 }, + unsafe { _mm256_extract_epi32(self.lasts, 4) - 1 }, + ); + } + if cmp & masks[5] != 0 { + visit( + unsafe { _mm256_extract_epi32(self.firsts, 5) + 1 }, + unsafe { _mm256_extract_epi32(self.lasts, 5) - 1 }, + ); + } + if cmp & masks[6] != 0 { + visit( + unsafe { _mm256_extract_epi32(self.firsts, 6) + 1 }, + unsafe { _mm256_extract_epi32(self.lasts, 6) - 1 }, + ); + } + if cmp & masks[7] != 0 { + visit( + unsafe { _mm256_extract_epi32(self.firsts, 7) + 1 }, + unsafe { _mm256_extract_epi32(self.lasts, 7) - 1 }, + ); + } + } +} + +/// COITree data structure. A representation of a static set of intervals with +/// associated metadata, enabling fast overlap and coverage queries. +/// +/// The index type `I` is a typically `usize`, but can be `u32` or `u16`. +/// It's slightly more efficient to use a smalled index type, assuming there are +/// fewer than I::MAX-1 intervals to store. +#[derive(Clone)] +pub struct AVXCOITree +where + T: Clone, + I: IntWithMax, +{ + nodes: Vec>, + len: usize, + root_idx: usize, + height: usize, +} + +impl AVXCOITree +where + T: Default + Copy + Clone, + I: IntWithMax, +{ + fn chunk_intervals(intervals: Vec>) -> Vec> { + let n = intervals.len(); + let num_chunks = (n / 8) + (n % 8 != 0) as usize; + let mut nodes: Vec> = Vec::with_capacity(num_chunks); + + // chunk initializer + let mut chunk_init: [Interval; 8] = [Interval { + first: i32::MAX, + last: i32::MIN, + metadata: T::default(), + }; 8]; + + for chunk in intervals.chunks(8) { + for (j, interval) in chunk.iter().enumerate() { + chunk_init[j] = *interval; + } + + chunk_init.iter_mut().skip(chunk.len()).for_each(|item| { + *item = Interval { + first: i32::MAX, + last: i32::MIN, + metadata: T::default(), + }; + }); + + nodes.push(IntervalNode::new(&chunk_init)); + } + + nodes + } +} + +impl<'a, T, I> IntervalTree<'a> for AVXCOITree +where + T: Default + Copy + Clone + 'a, + I: IntWithMax + 'a, +{ + type Metadata = T; + type Index = I; + type Item = Interval<&'a T>; + type Iter = AVXCOITreeIterator<'a, T, I>; + + fn new<'b, U, V>(intervals: U) -> AVXCOITree + where + U: IntoIterator, + V: GenericInterval + 'b, + { + let mut intervals: Vec<_> = intervals + .into_iter() + .map(|interval| { + Interval::new( + interval.first(), + interval.last(), + interval.metadata().clone(), + ) + }) + .collect(); + + if intervals.len() >= (I::MAX).to_usize() { + panic!("COITree construction failed: more intervals than index type can enumerate") + } + + let n = intervals.len(); + intervals.sort_unstable_by_key(|interval| (interval.first, interval.last)); + let nodes = Self::chunk_intervals(intervals); + + let (nodes, root_idx, height) = veb_order(nodes); + AVXCOITree { + nodes, + len: n, + root_idx, + height, + } + } + + /// Number of intervals in the set. + fn len(&self) -> usize { + self.len + } + + /// True iff the set is empty. + fn is_empty(&self) -> bool { + self.nodes.is_empty() + } + + // /// Find intervals in the set overlaping the query `[first, last]` and call `visit` on every overlapping node + fn query(&'a self, first: i32, last: i32, mut visit: F) + where + F: FnMut(&Interval<&'a T>), + { + let (firstv, lastv) = unsafe { (_mm256_set1_epi32(first), _mm256_set1_epi32(last)) }; + + if !self.is_empty() { + query_recursion( + &self.nodes, + self.root_idx, + first, + last, + firstv, + lastv, + &mut visit, + ); + } + } + + /// Count the number of intervals in the set overlapping the query `[first, last]`. + fn query_count(&self, first: i32, last: i32) -> usize { + let (firstv, lastv) = unsafe { (_mm256_set1_epi32(first), _mm256_set1_epi32(last)) }; + + if !self.is_empty() { + query_recursion_count(&self.nodes, self.root_idx, first, last, firstv, lastv) + } else { + 0 + } + } + + /// Return a pair `(count, cov)`, where `count` gives the number of intervals + /// in the set overlapping the query, and `cov` the number of positions in the query + /// interval covered by at least one interval in the set. + fn coverage(&self, first: i32, last: i32) -> (usize, usize) { + assert!(last >= first); + + if self.is_empty() { + return (0, 0); + } + + let (firstv, lastv) = unsafe { (_mm256_set1_epi32(first), _mm256_set1_epi32(last)) }; + + let (mut uncov_len, last_cov, count) = coverage_recursion( + &self.nodes, + self.root_idx, + first, + last, + firstv, + lastv, + first - 1, + ); + + if last_cov < last { + uncov_len += last - last_cov; + } + + let cov = ((last - first + 1) as usize) - (uncov_len as usize); + + (count, cov) + } + + /// Iterate through the interval set in sorted order by interval start position. + fn iter(&self) -> AVXCOITreeIterator { + let mut i = self.root_idx; + let mut stack: Vec = Vec::with_capacity(self.height); + while i < self.nodes.len() + && self.nodes[i].left != I::MAX + && self.nodes[i].left != self.nodes[i].right + { + stack.push(i); + i = self.nodes[i].left.to_usize(); + } + + AVXCOITreeIterator { + nodes: &self.nodes, + len: self.len, + i, + j: 0, + count: 0, + stack, + } + } +} + +impl<'a, T, I> IntoIterator for &'a AVXCOITree +where + T: Default + Copy + Clone, + I: IntWithMax, +{ + type Item = Interval<&'a T>; + type IntoIter = AVXCOITreeIterator<'a, T, I>; + + fn into_iter(self) -> AVXCOITreeIterator<'a, T, I> { + return self.iter(); + } +} + +/// Iterate through nodes in a tree in sorted order by interval start position. +pub struct AVXCOITreeIterator<'a, T, I> +where + T: Clone, + I: IntWithMax, +{ + nodes: &'a Vec>, + len: usize, // total number of intervals + i: usize, // current node + j: usize, // offset into the current chunk + count: usize, // number generated so far + stack: Vec, +} + +impl<'a, T, I> Iterator for AVXCOITreeIterator<'a, T, I> +where + T: Clone, + I: IntWithMax, +{ + type Item = Interval<&'a T>; + + fn next(&mut self) -> Option { + if self.i * 8 + self.j >= self.len { + return None; + } + + let node = &self.nodes[self.i]; + if self.j < 8 { + let ret = Some(Interval { + first: node.first(self.j), + last: node.last(self.j), + metadata: &node.metadata[self.j], + }); + self.count += 1; + self.j += 1; + return ret; + } + + // find the next node + self.j = 0; + if node.left == node.right { + // simple node + if node.left.to_usize() > 1 { + self.i += 1; + } else if let Some(i) = self.stack.pop() { + self.i = i; + } else { + self.i = usize::MAX; + } + } else if node.right == I::MAX { + if let Some(i) = self.stack.pop() { + self.i = i; + } else { + self.i = usize::MAX; + } + } else { + let mut i = node.right.to_usize(); + + while self.nodes[i].left != I::MAX && self.nodes[i].left != self.nodes[i].right { + self.stack.push(i); + i = self.nodes[i].left.to_usize(); + } + + self.i = i; + } + + let node = &self.nodes[self.i]; + self.count += 1; + Some(Interval { + first: node.first(self.j), + last: node.last(self.j), + metadata: &node.metadata[self.j], + }) + } + + fn size_hint(&self) -> (usize, Option) { + let len_remaining = self.len - self.count; + (len_remaining, Some(len_remaining)) + } +} + +impl<'a, T, I> ExactSizeIterator for AVXCOITreeIterator<'a, T, I> +where + T: Clone, + I: IntWithMax, +{ + fn len(&self) -> usize { + self.len - self.count + } +} + +// // Recursively count overlaps between the tree specified by `nodes` and a +// // query interval specified by `first`, `last`. +fn query_recursion<'a, T, I, F>( + nodes: &'a [IntervalNode], + root_idx: usize, + first: i32, + last: i32, + firstv: i32x8, + lastv: i32x8, + visit: &mut F, +) where + T: Clone, + I: IntWithMax, + F: FnMut(&Interval<&'a T>), +{ + let node = &nodes[root_idx]; + + if node.left == node.right { + // simple subtree + for node in &nodes[root_idx..root_idx + node.right.to_usize()] { + if last < node.min_first() { + break; + } + + node.query_chunk_metadata(firstv, lastv, |first_hit, last_hit, metadata| { + visit(&Interval { + first: first_hit, + last: last_hit, + metadata, + }); + }); + } + } else { + node.query_chunk_metadata(firstv, lastv, |first_hit, last_hit, metadata| { + visit(&Interval { + first: first_hit, + last: last_hit, + metadata, + }); + }); + + if node.left < I::MAX { + let left = node.left.to_usize(); + if nodes[left].subtree_last >= first { + query_recursion(nodes, left, first, last, firstv, lastv, visit); + } + } + + if node.right < I::MAX { + let right = node.right.to_usize(); + if overlaps(node.min_first(), nodes[right].subtree_last, first, last) { + query_recursion(nodes, right, first, last, firstv, lastv, visit); + } + } + } +} + +// query_recursion but just count number of overlaps +fn query_recursion_count( + nodes: &[IntervalNode], + root_idx: usize, + first: i32, + last: i32, + firstv: i32x8, + lastv: i32x8, +) -> usize +where + T: Clone, + I: IntWithMax, +{ + let node = &nodes[root_idx]; + + if node.left == node.right { + // simple subtree + let mut count = 0; + for node in &nodes[root_idx..root_idx + node.right.to_usize()] { + if last < node.min_first() { + break; + } else { + count += node.query_count_chunk(firstv, lastv); + } + } + count + } else { + let mut count = 0; + count += node.query_count_chunk(firstv, lastv); + + if node.left < I::MAX { + let left = node.left.to_usize(); + if nodes[left].subtree_last >= first { + count += query_recursion_count(nodes, left, first, last, firstv, lastv); + } + } + + if node.right < I::MAX { + let right = node.right.to_usize(); + if overlaps(node.min_first(), nodes[right].subtree_last, first, last) { + count += query_recursion_count(nodes, right, first, last, firstv, lastv); + } + } + + count + } +} + +fn coverage_recursion( + nodes: &[IntervalNode], + root_idx: usize, + first: i32, + last: i32, + firstv: i32x8, + lastv: i32x8, + mut last_cov: i32, +) -> (i32, i32, usize) +where + T: Clone, + I: IntWithMax, +{ + let node = &nodes[root_idx]; + + if node.left == node.right { + // simple subtree + let mut count = 0; + let mut uncov_len = 0; + for node in &nodes[root_idx..root_idx + node.right.to_usize()] { + if last < node.min_first() { + break; + } + + node.query_chunk(firstv, lastv, |first_hit, last_hit| { + if first_hit > last_cov { + uncov_len += first_hit - (last_cov + 1); + } + last_cov = last_cov.max(last_hit); + count += 1; + }); + } + (uncov_len, last_cov, count) + } else { + let mut uncov_len = 0; + let mut count = 0; + + if node.left < I::MAX { + let left = node.left.to_usize(); + if nodes[left].subtree_last >= first { + let (left_uncov_len, left_last_cov, left_count) = + coverage_recursion(nodes, left, first, last, firstv, lastv, last_cov); + last_cov = left_last_cov; + uncov_len += left_uncov_len; + count += left_count; + } + } + + node.query_chunk(firstv, lastv, |first_hit, last_hit| { + if first_hit > last_cov { + uncov_len += first_hit - (last_cov + 1); + } + last_cov = last_cov.max(last_hit); + count += 1; + }); + + if node.right < I::MAX { + let right = node.right.to_usize(); + if overlaps(node.min_first(), nodes[right].subtree_last, first, last) { + let (right_uncov_len, right_last_cov, right_count) = + coverage_recursion(nodes, right, first, last, firstv, lastv, last_cov); + last_cov = right_last_cov; + uncov_len += right_uncov_len; + count += right_count; + } + } + + (uncov_len, last_cov, count) + } +} + +/// An alternative query strategy that can be much faster when queries are performed +/// in sorted order and overlap. +/// +/// Unilke `COITree::query`, some state is retained between queries. +/// `SortedQuerent` tracks that state. If queries are not sorted or don't +/// overlap, this strategy still works, but is slightly slower than +/// `COITree::query`. +pub struct AVXSortedQuerent<'a, T, I> +where + T: Default + Clone, + I: IntWithMax, +{ + tree: &'a AVXCOITree, + prev_first: i32, + prev_last: i32, + overlapping_intervals: Vec>, +} + +impl<'a, T, I> SortedQuerent<'a> for AVXSortedQuerent<'a, T, I> +where + T: Default + Clone + Copy, + I: IntWithMax, +{ + type Metadata = T; + type Index = I; + type Item = Interval<&'a T>; + type Iter = AVXCOITreeIterator<'a, T, I>; + type Tree = AVXCOITree; + + /// Construct a new `SortedQuerent` to perform a sequence. queries. + fn new(tree: &'a AVXCOITree) -> AVXSortedQuerent<'a, T, I> { + let overlapping_intervals: Vec> = Vec::new(); + AVXSortedQuerent { + tree, + prev_first: -1, + prev_last: -1, + overlapping_intervals, + } + } + + /// Find intervals in the underlying `COITree` that overlap the query + /// `[first, last]` and call `visit` on each. Works equivalently to + /// `COITrees::query` but queries that overlap prior queries will potentially + /// be faster. + fn query(&mut self, first: i32, last: i32, mut visit: F) + where + F: FnMut(&Interval<&'a T>), + { + if self.tree.is_empty() { + return; + } + + // not overlaping or preceding + if first < self.prev_first || first > self.prev_last { + // no overlap with previous query. have to resort to regular query strategy + self.overlapping_intervals.clear(); + self.tree + .query(first, last, |node| self.overlapping_intervals.push(*node)); + } else { + // successor query, exploit the overlap + + // delete previously overlapping intervals with end in [prev_first, first-1] + if self.prev_first < first { + let mut i = 0; + while i < self.overlapping_intervals.len() { + if self.overlapping_intervals[i].last < first { + self.overlapping_intervals.swap_remove(i); + } else { + i += 1; + } + } + } + + // delete previously overlapping intervals with start in [last+1, prev_end] + if self.prev_last > last { + let mut i = 0; + while i < self.overlapping_intervals.len() { + if self.overlapping_intervals[i].first > last { + self.overlapping_intervals.swap_remove(i); + } else { + i += 1; + } + } + } + + // add any interval that start in [prev_last+1, last] + if self.prev_last < last { + let qa = self.prev_last + 1 - 2; // -2 accounts for the adjustment made in the chunk + let qb = last; + + let (qav, qbv) = unsafe { (_mm256_set1_epi32(qa), _mm256_set1_epi32(qb)) }; + + sorted_querent_query_firsts( + &self.tree.nodes, + self.tree.root_idx, + qa, + qb, + qav, + qbv, + &mut self.overlapping_intervals, + ); + } + } + + // call visit on everything + for overlapping_interval in &self.overlapping_intervals { + visit(overlapping_interval); + } + + self.prev_first = first; + self.prev_last = last; + } +} + +// find any intervals in the tree with their first value in [first, last] +fn sorted_querent_query_firsts<'a, T, I>( + nodes: &'a [IntervalNode], + root_idx: usize, + first: i32, + last: i32, + firstv: i32x8, + lastv: i32x8, + overlapping_intervals: &mut Vec>, +) where + T: Clone, + I: IntWithMax, +{ + let node = &nodes[root_idx]; + + if node.left == node.right { + // simple subtree + for node in &nodes[root_idx..root_idx + node.right.to_usize()] { + if last < node.min_first() { + break; + } + + node.query_chunk_firsts(firstv, lastv, |first_hit, last_hit, metadata| { + overlapping_intervals.push(Interval { + first: first_hit, + last: last_hit, + metadata, + }); + }) + } + } else { + node.query_chunk_firsts(firstv, lastv, |first_hit, last_hit, metadata| { + overlapping_intervals.push(Interval { + first: first_hit, + last: last_hit, + metadata, + }); + }); + + if node.left < I::MAX && first <= node.min_first() { + let left = node.left.to_usize(); + sorted_querent_query_firsts( + nodes, + left, + first, + last, + firstv, + lastv, + overlapping_intervals, + ); + } + + if node.right < I::MAX && last >= node.min_first() { + let right = node.right.to_usize(); + sorted_querent_query_firsts( + nodes, + right, + first, + last, + firstv, + lastv, + overlapping_intervals, + ); + } + } +} + +// True iff the two intervals overlap. +#[inline(always)] +fn overlaps(first_a: i32, last_a: i32, first_b: i32, last_b: i32) -> bool { + first_a <= last_b && last_a >= first_b +} + +// Used by `traverse` to keep record tree metadata. +#[derive(Clone, Debug)] +struct TraversalInfo +where + I: IntWithMax, +{ + depth: u32, + inorder: I, // in-order visit number + preorder: I, // pre-order visit number + subtree_size: I, + parent: I, + expected_hit_proportion: f32, +} + +impl Default for TraversalInfo +where + I: IntWithMax, +{ + fn default() -> Self { + TraversalInfo { + depth: 0, + inorder: I::default(), + preorder: I::default(), + subtree_size: I::one(), + parent: I::MAX, + expected_hit_proportion: 0.0, + } + } +} + +// dfs traversal of an implicit bst computing dfs number, node depth, subtree +// size, and left and right pointers. +fn traverse(nodes: &mut [IntervalNode]) -> Vec> +where + T: Clone, + I: IntWithMax, +{ + let n = nodes.len(); + let mut info = vec![TraversalInfo::default(); n]; + let mut inorder = 0; + let mut preorder = 0; + traverse_recursion( + nodes, + &mut info, + 0, + n, + 0, + I::MAX, + &mut inorder, + &mut preorder, + ); + + info +} + +// The recursive part of the `traverse` function. +fn traverse_recursion( + nodes: &mut [IntervalNode], + info: &mut [TraversalInfo], + start: usize, + end: usize, + depth: u32, + parent: I, + inorder: &mut usize, + preorder: &mut usize, +) -> (I, i32, f32) +where + T: Clone, + I: IntWithMax, +{ + if start >= end { + return (I::MAX, i32::MAX, f32::NAN); + } + + let root_idx = start + (end - start) / 2; + let subtree_size = end - start; + + info[root_idx].depth = depth; + info[root_idx].preorder = I::from_usize(*preorder); + info[root_idx].parent = parent; + *preorder += 1; + + let mut subtree_first = nodes[root_idx].min_first(); + + let mut left_expected_hits = 0.0; + let mut left_subtree_span = 0; + + if root_idx > start { + let (left, left_subtree_first, left_expected_hits_) = traverse_recursion( + nodes, + info, + start, + root_idx, + depth + 1, + I::from_usize(root_idx), + inorder, + preorder, + ); + + left_expected_hits = left_expected_hits_; + left_subtree_span = nodes[left.to_usize()].subtree_last - left_subtree_first + 1; + + subtree_first = left_subtree_first; + if nodes[left.to_usize()].subtree_last > nodes[root_idx].subtree_last { + nodes[root_idx].subtree_last = nodes[left.to_usize()].subtree_last; + } + nodes[root_idx].left = left; + } + + info[root_idx].inorder = I::from_usize(*inorder); + *inorder += 1; + + let mut right_expected_hits = 0.0; + let mut right_subtree_span = 0; + + if root_idx + 1 < end { + let (right, right_subtree_first, right_expected_hits_) = traverse_recursion( + nodes, + info, + root_idx + 1, + end, + depth + 1, + I::from_usize(root_idx), + inorder, + preorder, + ); + + right_expected_hits = right_expected_hits_; + right_subtree_span = nodes[right.to_usize()].subtree_last - right_subtree_first + 1; + + if nodes[right.to_usize()].subtree_last > nodes[root_idx].subtree_last { + nodes[root_idx].subtree_last = nodes[right.to_usize()].subtree_last; + } + nodes[root_idx].right = right; + } + + info[root_idx].subtree_size = I::from_usize(subtree_size); + let subtree_span = nodes[root_idx].subtree_last - subtree_first + 1; + + debug_assert!(left_subtree_span <= subtree_span); + debug_assert!(right_subtree_span <= subtree_span); + + let expected_hits = ((nodes[root_idx].max_last() - nodes[root_idx].min_first() + 1) as f32 + + (left_subtree_span as f32) * left_expected_hits + + (right_subtree_span as f32) * right_expected_hits) + / subtree_span as f32; + + info[root_idx].expected_hit_proportion = expected_hits / subtree_size as f32; + + (I::from_usize(root_idx), subtree_first, expected_hits) +} + +// norder partition by depth on pivot into three parts, like so +// [ bottom left ][ top ][ bottom right ] +// where bottom left and right are the bottom subtrees with positioned to +// the left and right of the root node +fn stable_ternary_tree_partition( + input: &[I], + output: &mut [I], + partition: &mut [i8], + info: &[TraversalInfo], + pivot_depth: u32, + pivot_dfs: I, + start: usize, + end: usize, +) -> (usize, usize) +where + I: IntWithMax, +{ + let n = end - start; + + // determine which partition each index goes in + let mut bottom_left_size = 0; + let mut top_size = 0; + let mut bottom_right_size = 0; + + for (i, p) in input[start..end].iter().zip(&mut partition[start..end]) { + let info_j = &info[i.to_usize()]; + if info_j.depth <= pivot_depth { + *p = 0; + top_size += 1; + } else if info_j.inorder < pivot_dfs { + *p = -1; + bottom_left_size += 1; + } else { + *p = 1; + bottom_right_size += 1; + } + } + + debug_assert!(bottom_left_size + top_size + bottom_right_size == n); + + // do the partition + let mut bl = start; + let mut t = bl + bottom_left_size; + let mut br = t + top_size; + for (i, p) in input[start..end].iter().zip(&partition[start..end]) { + match p.cmp(&0) { + Ordering::Less => { + output[bl] = *i; + bl += 1; + } + Ordering::Equal => { + output[t] = *i; + t += 1; + } + Ordering::Greater => { + output[br] = *i; + br += 1; + } + } + } + debug_assert!(br == end); + + (bl, t) +} + +// put nodes in van Emde Boas order +fn veb_order(mut nodes: Vec>) -> (Vec>, usize, usize) +where + T: Clone, + I: IntWithMax, +{ + // let now = Instant::now(); + let mut veb_nodes = nodes.clone(); + let n = veb_nodes.len(); + + if veb_nodes.is_empty() { + return (veb_nodes, 0, 0); + } + + // let now = Instant::now(); + let info = traverse(&mut nodes); + // eprintln!("traversing: {}s", now.elapsed().as_millis() as f64 / 1000.0); + + let mut max_depth = 0; + for info_i in &info { + if info_i.depth > max_depth { + max_depth = info_i.depth; + } + } + + let idxs: &mut [I] = &mut vec![I::default(); n]; + (0..n).for_each(|i| idxs[i] = I::from_usize(i)); + + let tmp: &mut [I] = &mut vec![I::default(); n]; + + // put in dfs order + for i in &*idxs { + tmp[info[i.to_usize()].preorder.to_usize()] = *i; + } + let (idxs, tmp) = (tmp, idxs); + + // space used to by stable_ternary_tree_partition + let partition: &mut [i8] = &mut vec![0; n]; + + // let now = Instant::now(); + let root_idx = veb_order_recursion( + idxs, tmp, partition, &info, &mut nodes, 0, n, false, 0, max_depth, + ); + // eprintln!("computing order: {}s", now.elapsed().as_millis() as f64 / 1000.0); + + // let now = Instant::now(); + // idxs is now a vEB -> sorted order map. Build the reverse here. + let revidx = tmp; + for (i, j) in idxs.iter().enumerate() { + revidx[j.to_usize()] = I::from_usize(i); + } + + // put nodes in vEB order + for (i_, mut node) in revidx.iter().zip(nodes) { + let i = i_.to_usize(); + if node.left != node.right { + if node.left < I::MAX { + node.left = revidx[node.left.to_usize()]; + } + + if node.right < I::MAX { + node.right = revidx[node.right.to_usize()]; + } + } + veb_nodes[i.to_usize()] = node; + } + + let root_idx = revidx[root_idx.to_usize()].to_usize(); + + // eprintln!("ordering: {}s", now.elapsed().as_millis() as f64 / 1000.0); + debug_assert!(compute_tree_size(&veb_nodes, root_idx) == n); + + (veb_nodes, root_idx, max_depth as usize) +} + +// Traverse the tree and return the size, used as a basic sanity check. +fn compute_tree_size(nodes: &[IntervalNode], root_idx: usize) -> usize +where + T: Clone, + I: IntWithMax, +{ + let mut subtree_size = 1; + + let node = &nodes[root_idx]; + if node.left == node.right { + subtree_size = nodes[root_idx].right.to_usize(); + } else { + if node.left < I::MAX { + let left = node.left.to_usize(); + subtree_size += compute_tree_size(nodes, left); + } + + if node.right < I::MAX { + let right = node.right.to_usize(); + subtree_size += compute_tree_size(nodes, right); + } + } + + subtree_size +} + +// recursively reorder indexes to put it in vEB order. Called by `veb_order` +// idxs: current permutation +// tmp: temporary space of equal length to idxs +// partition: space used to assist `stable_ternary_tree_partition`. +// nodes: the interval nodes (in sorted order) +// start, end: slice within idxs to be reordered +// childless: true if this slice is a proper subtree and has no children below it +// parity: true if idxs and tmp are swapped and need to be copied back, +// min_depth, max_depth: depth extreme of the start..end slice +// +fn veb_order_recursion( + idxs: &mut [I], + tmp: &mut [I], + partition: &mut [i8], + info: &[TraversalInfo], + nodes: &mut [IntervalNode], + start: usize, + end: usize, + parity: bool, + min_depth: u32, + max_depth: u32, +) -> I +where + T: Clone, + I: IntWithMax, +{ + let n = (start..end).len(); + + assert!(n > 0); + + let childless = info[idxs[start].to_usize()].subtree_size.to_usize() == n; + + // small subtrees are put into sorted order and just searched through + // linearly. There is a little trickiness to this because we have to + // update the parent's child pointers and some other fields. + if childless + && ((info[idxs[start].to_usize()].subtree_size.to_usize() <= SIMPLE_SUBTREE_CUTOFF) + || (info[idxs[start].to_usize()].expected_hit_proportion + >= SIMPLE_SUBTREE_DENSITY_CUTOFF)) + { + debug_assert!(n == info[idxs[start].to_usize()].subtree_size.to_usize()); + + let old_root = idxs[start]; + + idxs[start..end].sort_unstable_by_key(|i| info[i.to_usize()].inorder); + let subtree_size = info[old_root.to_usize()].subtree_size; + nodes[idxs[start].to_usize()].subtree_last = nodes[old_root.to_usize()].subtree_last; + + // all children nodes record the size of the remaining list + // let mut subtree_i_size = subtree_size - i; + let mut subtree_i_size = subtree_size; + for idx in &idxs[start..end] { + nodes[idx.to_usize()].left = subtree_i_size; + nodes[idx.to_usize()].right = subtree_i_size; + subtree_i_size -= I::one(); + } + + let parent = info[old_root.to_usize()].parent; + if parent < I::MAX { + if nodes[parent.to_usize()].left == old_root { + nodes[parent.to_usize()].left = idxs[start]; + } else { + debug_assert!(nodes[parent.to_usize()].right == old_root); + nodes[parent.to_usize()].right = idxs[start]; + } + } + + if parity { + tmp[start..end].copy_from_slice(&idxs[start..end]); + } + return idxs[start]; + } + + // very small trees are already in order + if n == 1 { + if parity { + tmp[start] = idxs[start]; + } + return idxs[start]; + } + + let pivot_depth = min_depth + (max_depth - min_depth) / 2; + let pivot_dfs = info[idxs[start].to_usize()].inorder; + + let (top_start, bottom_right_start) = stable_ternary_tree_partition( + idxs, + tmp, + partition, + info, + pivot_depth, + pivot_dfs, + start, + end, + ); + + // tmp is not partitioned, so swap pointers + let (tmp, idxs) = (idxs, tmp); + + // recurse on top subtree + let top_root_idx = veb_order_recursion( + idxs, + tmp, + partition, + info, + nodes, + top_start, + bottom_right_start, + !parity, + min_depth, + pivot_depth, + ); + + // find on recurse on subtrees in the bottom left partition and bottom right partition + for (part_start, part_end) in &[(start, top_start), (bottom_right_start, end)] { + let bottom_subtree_depth = pivot_depth + 1; + let mut i = *part_start; + while i < *part_end { + debug_assert!(info[idxs[i].to_usize()].depth == bottom_subtree_depth); + + let mut subtree_max_depth = info[idxs[i].to_usize()].depth; + let mut j = *part_end; + for (u, v) in (i + 1..*part_end).zip(&idxs[i + 1..*part_end]) { + let depth = info[v.to_usize()].depth; + if depth == bottom_subtree_depth { + j = u; + break; + } else if depth > subtree_max_depth { + subtree_max_depth = depth; + } + } + + veb_order_recursion( + idxs, + tmp, + partition, + info, + nodes, + i, + j, + !parity, + bottom_subtree_depth, + subtree_max_depth, + ); + i = j; + } + } + + top_root_idx +} diff --git a/sequila/sequila-core/superintervals/test/3rd-party/coitrees/src/interval.rs b/sequila/sequila-core/superintervals/test/3rd-party/coitrees/src/interval.rs new file mode 100644 index 0000000..e731b16 --- /dev/null +++ b/sequila/sequila-core/superintervals/test/3rd-party/coitrees/src/interval.rs @@ -0,0 +1,208 @@ +use std::iter::IntoIterator; +use std::ops::{AddAssign, SubAssign}; + +pub trait GenericInterval +where + T: Clone, +{ + fn first(&self) -> i32; + fn last(&self) -> i32; + fn metadata(&self) -> &T; + + fn len(&self) -> i32 { + 0.max(self.last() - self.first() + 1) + } + + fn is_empty(&self) -> bool { + self.len() == 0 + } +} + +/// An interval with associated metadata. +/// +/// Intervals in `COITree` are treated as end-inclusive. +/// +/// Metadata can be an arbitrary type `T`, but because nodes are stored in contiguous +/// memory, it may be better to store large metadata outside the node and +/// use a pointer or reference for the metadata. +/// +/// # Examples +/// ``` +/// use coitrees::Interval; +/// use coitrees::GenericInterval; +/// +/// #[derive(Clone)] +/// struct MyMetadata { +/// chrom: String, +/// posstrand: bool +/// } +/// +/// let some_interval = Interval{ +/// first: 10, last: 24000, +/// metadata: MyMetadata{chrom: String::from("chr1"), posstrand: false}}; +/// +/// assert_eq!(some_interval.len(), 23991); +/// ``` +#[derive(Clone, Copy, Debug)] +pub struct Interval +where + T: Clone, +{ + pub first: i32, + pub last: i32, + pub metadata: T, +} + +impl Interval +where + T: Clone, +{ + pub fn new(first: i32, last: i32, metadata: T) -> Interval { + Self { + first, + last, + metadata, + } + } +} + +impl GenericInterval for Interval +where + T: Clone, +{ + fn first(&self) -> i32 { + self.first + } + + fn last(&self) -> i32 { + self.last + } + + fn metadata(&self) -> &T { + &self.metadata + } +} + +impl<'a, T> GenericInterval for Interval<&'a T> +where + T: Clone, +{ + fn first(&self) -> i32 { + self.first + } + + fn last(&self) -> i32 { + self.last + } + + fn metadata(&self) -> &T { + self.metadata + } +} + +#[test] +fn test_interval_len() { + fn make_interval(first: i32, last: i32) -> Interval<()> { + Interval { + first, + last, + metadata: (), + } + } + + assert_eq!(make_interval(1, -1).len(), 0); + assert_eq!(make_interval(1, 0).len(), 0); + assert_eq!(make_interval(1, 1).len(), 1); + assert_eq!(make_interval(1, 2).len(), 2); +} + +/// A trait facilitating COITree index types. +pub trait IntWithMax: + TryInto + TryFrom + Copy + Default + PartialEq + Ord + AddAssign + SubAssign +{ + const MAX: Self; + + // typically the branch here should be optimized out, because we are + // converting, e.g. a u32 to a usize on 64-bit system. + #[inline(always)] + fn to_usize(self) -> usize { + match self.try_into() { + Ok(x) => x, + Err(_) => panic!("index conversion to usize failed"), + } + } + + #[inline(always)] + fn from_usize(x: usize) -> Self { + match x.try_into() { + Ok(y) => y, + Err(_) => panic!("index conversion from usize failed"), + } + } + + fn one() -> Self { + Self::from_usize(1) + } +} + +impl IntWithMax for usize { + const MAX: usize = usize::MAX; +} + +impl IntWithMax for u64 { + const MAX: u64 = u64::MAX; +} + +impl IntWithMax for u32 { + const MAX: u32 = u32::MAX; +} + +impl IntWithMax for u16 { + const MAX: u16 = u16::MAX; +} + +/// Basic interval tree interface supported by each COITree implementation. +pub trait IntervalTree<'a> { + type Metadata: Clone + 'a; + type Index: IntWithMax; + type Item: GenericInterval + 'a; + type Iter: Iterator>; + + fn new<'b, U, V>(intervals: U) -> Self + where + U: IntoIterator, + V: GenericInterval + 'b; + + fn len(&self) -> usize; + + fn is_empty(&self) -> bool; + + fn query(&'a self, first: i32, last: i32, visit: F) + where + F: FnMut(&Self::Item); + + fn query_count(&self, first: i32, last: i32) -> usize; + fn coverage(&self, first: i32, last: i32) -> (usize, usize); + + fn iter(&'a self) -> Self::Iter; +} + +pub trait SortedQuerent<'a> { + type Metadata: Clone + 'a; + type Index: IntWithMax; + type Item: GenericInterval + 'a; + type Iter: Iterator>; + type Tree: IntervalTree< + 'a, + Metadata = Self::Metadata, + Index = Self::Index, + Item = Self::Item, + Iter = Self::Iter, + > + 'a; + + fn new(tree: &'a Self::Tree) -> Self; + + fn query(&mut self, first: i32, last: i32, visit: F) + where + F: FnMut(&Self::Item); +} diff --git a/sequila/sequila-core/superintervals/test/3rd-party/coitrees/src/lib.rs b/sequila/sequila-core/superintervals/test/3rd-party/coitrees/src/lib.rs new file mode 100644 index 0000000..14d3c59 --- /dev/null +++ b/sequila/sequila-core/superintervals/test/3rd-party/coitrees/src/lib.rs @@ -0,0 +1,53 @@ +//! # COITrees +//! `coitrees` implements a fast static interval tree data structure with genomic +//! data in mind. +//! +//! The data structure used a fairly standard interval tree, but with nodes stored +//! in van Emde Boas layout, which improves average cache locality, and thus +//! query performance. The downside is that building the tree is more expensive +//! so a relatively large number of queries needs to made for it to break even. +//! +//! The data structure `COITree` is constructed with an array of `IntervalNode` +//! structs which store integer, end-inclusive intervals along with associated +//! metadata. The tree can be queried directly for coverage or overlaps, or +//! through the intermediary `SortedQuerenty` which keeps track of some state +//! to accelerate overlaping queries. + +mod interval; +pub use interval::*; + +mod nosimd; +pub use nosimd::*; + +#[cfg(all(target_feature = "avx2", not(feature = "nosimd")))] +mod avx; +#[cfg(all(target_feature = "avx2", not(feature = "nosimd")))] +pub use avx::*; + +#[cfg(all(target_feature = "neon", not(feature = "nosimd")))] +mod neon; +#[cfg(all(target_feature = "neon", not(feature = "nosimd")))] +pub use neon::*; + +// These are necessary mutually exclusive +#[cfg(all( + not(feature = "nosimd"), + not(target_feature = "avx2"), + not(target_feature = "neon") +))] +pub type COITree = BasicCOITree; +#[cfg(all(target_feature = "avx2", not(feature = "nosimd")))] +pub type COITree = AVXCOITree; +#[cfg(all(target_feature = "neon", not(feature = "nosimd")))] +pub type COITree = NeonCOITree; + +#[cfg(all( + not(feature = "nosimd"), + not(target_feature = "avx2"), + not(target_feature = "neon") +))] +pub type COITreeSortedQuerent<'a, T, I> = BasicSortedQuerent<'a, T, I>; +#[cfg(all(target_feature = "avx2", not(feature = "nosimd")))] +pub type COITreeSortedQuerent<'a, T, I> = AVXSortedQuerent<'a, T, I>; +#[cfg(all(target_feature = "neon", not(feature = "nosimd")))] +pub type COITreeSortedQuerent<'a, T, I> = NeonSortedQuerent<'a, T, I>; diff --git a/sequila/sequila-core/superintervals/test/3rd-party/coitrees/src/neon.rs b/sequila/sequila-core/superintervals/test/3rd-party/coitrees/src/neon.rs new file mode 100644 index 0000000..304ee58 --- /dev/null +++ b/sequila/sequila-core/superintervals/test/3rd-party/coitrees/src/neon.rs @@ -0,0 +1,1422 @@ +use std::arch::aarch64::*; + +use super::interval::{GenericInterval, IntWithMax, Interval, IntervalTree, SortedQuerent}; +use std::cmp::{max, Ordering}; +use std::fmt::Debug; +use std::marker::Copy; +use std::mem::transmute; + +#[allow(non_camel_case_types)] +type i32x4 = int32x4_t; + +const LANE_SIZE: usize = 4; + +// Small subtrees at the bottom of the tree are stored in sorted order +// This gives the upper bound on the size of such subtrees. Performance isn't +// super sensitive, but is worse with a very small or very large number. +const SIMPLE_SUBTREE_CUTOFF: usize = 8; + +// Very dense subtrees in which we probably intersect most of the intervals +// are more efficient to query linearly. When the expected proportion of hits +// is a above this number it becomes a simple subtree. +const SIMPLE_SUBTREE_DENSITY_CUTOFF: f32 = 0.2; + +/// Node in the interval tree. Each node holds a chunk of 8 intervals. +#[derive(Clone)] +struct IntervalNode +where + T: Clone, + I: IntWithMax, +{ + // subtree interval + subtree_last: i32, + + firsts: i32x4, + lasts: i32x4, + metadata: [T; LANE_SIZE], + + // when this is the root of a simple subtree, left == right is the size + // of the subtree, otherwise they are left, right child pointers. + left: I, + right: I, +} + +impl IntervalNode +where + T: Clone, + I: IntWithMax, +{ + fn new(intervals: &[Interval; LANE_SIZE]) -> Self + where + T: Copy, + { + assert!(!intervals.is_empty() && intervals.len() <= LANE_SIZE); + + unsafe { + let max_last = intervals + .iter() + .max_by_key(|x| x.last) + .map(|x| x.last) + .unwrap(); + + // the firsts and lasts are adjust by 1 here because AVX has an + // instruction for > but not >=. + // WARN: may not need to minus 1 <03-01-23, Yangyang Li yangyang.li@northwestern.edu> + let firsts = vld1q_s32( + [ + intervals[0].first - 1, + intervals[1].first - 1, + intervals[2].first - 1, + intervals[3].first - 1, + ] + .as_ptr(), + ); + + let lasts = vld1q_s32( + [ + intervals[0].last + 1, + intervals[1].last + 1, + intervals[2].last + 1, + intervals[3].last + 1, + ] + .as_ptr(), + ); + + // NOTE: the order is reversed from first and lasts data + let metadata: [T; LANE_SIZE] = [ + intervals[0].metadata, + intervals[1].metadata, + intervals[2].metadata, + intervals[3].metadata, + ]; + + Self { + subtree_last: max_last, + firsts, + lasts, + metadata, + left: I::MAX, + right: I::MAX, + } + } + } + + fn first(&self, i: usize) -> i32 { + assert!(i < LANE_SIZE); + (unsafe { transmute::<&i32x4, &[i32; LANE_SIZE]>(&self.firsts) }[i]) + } + + fn last(&self, i: usize) -> i32 { + assert!(i < LANE_SIZE); + (unsafe { transmute::<&i32x4, &[i32; LANE_SIZE]>(&self.lasts) }[i]) + } + + fn min_first(&self) -> i32 { + unsafe { vgetq_lane_s32(self.firsts, 0) + 1 } + } + + fn max_last(&self) -> i32 { + // faster ways to do this, but doesn't really matter + unsafe { + max( + max(vgetq_lane_s32(self.lasts, 0), vgetq_lane_s32(self.lasts, 1)), + max(vgetq_lane_s32(self.lasts, 2), vgetq_lane_s32(self.lasts, 3)), + ) - 1 + } + } + + #[inline(always)] + fn query_count_chunk(&self, query_first: i32x4, query_last: i32x4) -> usize { + unsafe { + let cmp1 = vcgtq_s32(query_last, self.firsts); + let cmp2 = vcgtq_s32(self.lasts, query_first); + let cmp = vandq_u32(cmp1, cmp2); + count_bits(cmp) / 32 + } + } + + fn query_chunk_firsts<'a, F>(&'a self, query_first: i32x4, query_last: i32x4, mut visit: F) + where + F: FnMut(i32, i32, &'a T), + { + let cmp: u64 = unsafe { + let cmp1 = vcgtq_s32(query_last, self.firsts); + let cmp2 = vcgtq_s32(self.firsts, query_first); + let cmp = vmovn_u32(vandq_u32(cmp1, cmp2)); + transmute(cmp) + }; + + let masks = [0xffff, 0xffff << 16, 0xffff << 32, 0xffff << 48]; + + if cmp & masks[0] != 0 { + visit( + unsafe { vgetq_lane_s32(self.firsts, 0) + 1 }, + unsafe { vgetq_lane_s32(self.lasts, 0) - 1 }, + &self.metadata[0], + ); + } + + if cmp & masks[1] != 0 { + visit( + unsafe { vgetq_lane_s32(self.firsts, 1) + 1 }, + unsafe { vgetq_lane_s32(self.lasts, 1) - 1 }, + &self.metadata[1], + ); + } + + if cmp & masks[2] != 0 { + visit( + unsafe { vgetq_lane_s32(self.firsts, 2) + 1 }, + unsafe { vgetq_lane_s32(self.lasts, 2) - 1 }, + &self.metadata[2], + ); + } + + if cmp & masks[3] != 0 { + visit( + unsafe { vgetq_lane_s32(self.firsts, 3) + 1 }, + unsafe { vgetq_lane_s32(self.lasts, 3) - 1 }, + &self.metadata[3], + ); + } + } + + fn query_chunk_metadata<'a, F>(&'a self, query_first: i32x4, query_last: i32x4, mut visit: F) + where + F: FnMut(i32, i32, &'a T), + { + let cmp: u64 = unsafe { + let cmp1 = vcgtq_s32(query_last, self.firsts); + let cmp2 = vcgtq_s32(self.lasts, query_first); + let cmp = vmovn_u32(vandq_u32(cmp1, cmp2)); + transmute(cmp) + }; + + let masks = [0xffff, 0xffff << 16, 0xffff << 32, 0xffff << 48]; + + if cmp & masks[0] != 0 { + visit( + unsafe { vgetq_lane_s32(self.firsts, 0) + 1 }, + unsafe { vgetq_lane_s32(self.lasts, 0) - 1 }, + &self.metadata[0], + ); + } + if cmp & masks[1] != 0 { + visit( + unsafe { vgetq_lane_s32(self.firsts, 1) + 1 }, + unsafe { vgetq_lane_s32(self.lasts, 1) - 1 }, + &self.metadata[1], + ); + } + if cmp & masks[2] != 0 { + visit( + unsafe { vgetq_lane_s32(self.firsts, 2) + 1 }, + unsafe { vgetq_lane_s32(self.lasts, 2) - 1 }, + &self.metadata[2], + ); + } + if cmp & masks[3] != 0 { + visit( + unsafe { vgetq_lane_s32(self.firsts, 3) + 1 }, + unsafe { vgetq_lane_s32(self.lasts, 3) - 1 }, + &self.metadata[3], + ); + } + } + + fn query_chunk(&self, query_first: i32x4, query_last: i32x4, mut visit: F) + where + F: FnMut(i32, i32), + { + let cmp: u64 = unsafe { + let cmp1 = vcgtq_s32(query_last, self.firsts); + let cmp2 = vcgtq_s32(self.lasts, query_first); + let cmp = vmovn_u32(vandq_u32(cmp1, cmp2)); + transmute(cmp) + }; + + let masks = [0xffff, 0xffff << 16, 0xffff << 32, 0xffff << 48]; + + if cmp & masks[0] != 0 { + visit(unsafe { vgetq_lane_s32(self.firsts, 0) + 1 }, unsafe { + vgetq_lane_s32(self.lasts, 0) - 1 + }); + } + if cmp & masks[1] != 0 { + visit(unsafe { vgetq_lane_s32(self.firsts, 1) + 1 }, unsafe { + vgetq_lane_s32(self.lasts, 1) - 1 + }); + } + if cmp & masks[2] != 0 { + visit(unsafe { vgetq_lane_s32(self.firsts, 2) + 1 }, unsafe { + vgetq_lane_s32(self.lasts, 2) - 1 + }); + } + if cmp & masks[3] != 0 { + visit(unsafe { vgetq_lane_s32(self.firsts, 3) + 1 }, unsafe { + vgetq_lane_s32(self.lasts, 3) - 1 + }); + } + } +} + +fn count_bits(bits: uint32x4_t) -> usize { + unsafe { + let t2 = vreinterpretq_u8_u32(bits); + let t3 = vcntq_u8(t2); + let sum = vpaddlq_u32(vpaddlq_u16(vpaddlq_u8(t3))); + vgetq_lane_u64(sum, 0) as usize + vgetq_lane_u64(sum, 1) as usize + } +} + +/// NeonCOITree data structure. A representation of a static set of intervals with +/// associated metadata, enabling fast overlap and coverage queries. +/// +/// The index type `I` is a typically `usize`, but can be `u32` or `u16`. +/// It's slightly more efficient to use a smalled index type, assuming there are +/// fewer than I::MAX-1 intervals to store. +#[derive(Clone)] +pub struct NeonCOITree +where + T: Clone, + I: IntWithMax, +{ + nodes: Vec>, + len: usize, + root_idx: usize, + height: usize, +} + +impl NeonCOITree +where + T: Default + Copy + Clone, + I: IntWithMax, +{ + fn chunk_intervals(intervals: Vec>) -> Vec> + where + T: Copy, + { + let n = intervals.len(); + + let num_chunks = (n / LANE_SIZE) + (n % LANE_SIZE != 0) as usize; + + let mut nodes: Vec> = Vec::with_capacity(num_chunks); + + // chunk initializer + let mut chunk_init: [Interval; LANE_SIZE] = [Interval { + first: i32::MAX, + last: i32::MIN, + metadata: T::default(), + }; LANE_SIZE]; + + for chunk in intervals.chunks(LANE_SIZE) { + for (j, interval) in chunk.iter().enumerate() { + chunk_init[j] = *interval; + } + + chunk_init.iter_mut().skip(chunk.len()).for_each(|item| { + *item = Interval { + first: i32::MAX, + last: i32::MIN, + metadata: T::default(), + }; + }); + + nodes.push(IntervalNode::new(&chunk_init)); + } + + nodes + } +} + +impl<'a, T, I> IntervalTree<'a> for NeonCOITree +where + T: Default + Copy + Clone + 'a, + I: IntWithMax + 'a, +{ + type Metadata = T; + type Index = I; + type Item = Interval<&'a T>; + type Iter = NeonCOITreeIterator<'a, T, I>; + + fn new<'b, U, V>(intervals: U) -> NeonCOITree + where + U: IntoIterator, + V: GenericInterval + 'b, + { + let mut intervals: Vec<_> = intervals + .into_iter() + .map(|interval| { + Interval::new( + interval.first(), + interval.last(), + interval.metadata().clone(), + ) + }) + .collect(); + + if intervals.len() >= (I::MAX).to_usize() { + panic!("NeonCOITree construction failed: more intervals than index type can enumerate") + } + + let n = intervals.len(); + intervals.sort_unstable_by_key(|interval| (interval.first, interval.last)); + let nodes = Self::chunk_intervals(intervals); + + let (nodes, root_idx, height) = veb_order(nodes); + NeonCOITree { + nodes, + len: n, + root_idx, + height, + } + } + + /// Number of intervals in the set. + fn len(&self) -> usize { + self.len + } + + /// True iff the set is empty. + fn is_empty(&self) -> bool { + self.nodes.is_empty() + } + + // /// Find intervals in the set overlaping the query `[first, last]` and call `visit` on every overlapping node + fn query(&'a self, first: i32, last: i32, mut visit: F) + where + F: FnMut(&Interval<&'a T>), + { + let (firstv, lastv) = unsafe { (vdupq_n_s32(first), vdupq_n_s32(last)) }; + + if !self.is_empty() { + query_recursion( + &self.nodes, + self.root_idx, + first, + last, + firstv, + lastv, + &mut visit, + ); + } + } + + /// Count the number of intervals in the set overlapping the query `[first, last]`. + fn query_count(&self, first: i32, last: i32) -> usize { + let (firstv, lastv) = unsafe { (vdupq_n_s32(first), vdupq_n_s32(last)) }; + + if !self.is_empty() { + query_recursion_count(&self.nodes, self.root_idx, first, last, firstv, lastv) + } else { + 0 + } + } + + /// Return a pair `(count, cov)`, where `count` gives the number of intervals + /// in the set overlapping the query, and `cov` the number of positions in the query + /// interval covered by at least one interval in the set. + fn coverage(&self, first: i32, last: i32) -> (usize, usize) { + assert!(last >= first); + + if self.is_empty() { + return (0, 0); + } + + let (firstv, lastv) = unsafe { (vdupq_n_s32(first), vdupq_n_s32(last)) }; + + let (mut uncov_len, last_cov, count) = coverage_recursion( + &self.nodes, + self.root_idx, + first, + last, + firstv, + lastv, + first - 1, + ); + + if last_cov < last { + uncov_len += last - last_cov; + } + + let cov = ((last - first + 1) as usize) - (uncov_len as usize); + + (count, cov) + } + + /// Iterate through the interval set in sorted order by interval start position. + fn iter(&self) -> NeonCOITreeIterator { + let mut i = self.root_idx; + let mut stack: Vec = Vec::with_capacity(self.height); + while i < self.nodes.len() + && self.nodes[i].left != I::MAX + && self.nodes[i].left != self.nodes[i].right + { + stack.push(i); + i = self.nodes[i].left.to_usize(); + } + + NeonCOITreeIterator { + nodes: &self.nodes, + len: self.len, + i, + j: 0, + count: 0, + stack, + } + } +} + +impl<'a, T, I> IntoIterator for &'a NeonCOITree +where + T: Default + Copy + Clone, + I: IntWithMax, +{ + type Item = Interval<&'a T>; + type IntoIter = NeonCOITreeIterator<'a, T, I>; + + fn into_iter(self) -> NeonCOITreeIterator<'a, T, I> { + return self.iter(); + } +} + +/// Iterate through nodes in a tree in sorted order by interval start position. +pub struct NeonCOITreeIterator<'a, T, I> +where + T: Clone, + I: IntWithMax, +{ + nodes: &'a Vec>, + len: usize, // total number of intervals + i: usize, // current node + j: usize, // offset into the current chunk + count: usize, // number generated so far + stack: Vec, +} + +impl<'a, T, I> Iterator for NeonCOITreeIterator<'a, T, I> +where + T: Clone, + I: IntWithMax, +{ + type Item = Interval<&'a T>; + + fn next(&mut self) -> Option { + if self.i * LANE_SIZE + self.j >= self.len { + return None; + } + + let node = &self.nodes[self.i]; + if self.j < LANE_SIZE { + let ret = Some(Interval { + first: node.first(self.j), + last: node.last(self.j), + metadata: &node.metadata[self.j], + }); + self.count += 1; + self.j += 1; + return ret; + } + + // find the next node + self.j = 0; + if node.left == node.right { + // simple node + if node.left.to_usize() > 1 { + self.i += 1; + } else if let Some(i) = self.stack.pop() { + self.i = i; + } else { + self.i = usize::MAX; + } + } else if node.right == I::MAX { + if let Some(i) = self.stack.pop() { + self.i = i; + } else { + self.i = usize::MAX; + } + } else { + let mut i = node.right.to_usize(); + + while self.nodes[i].left != I::MAX && self.nodes[i].left != self.nodes[i].right { + self.stack.push(i); + i = self.nodes[i].left.to_usize(); + } + + self.i = i; + } + + let node = &self.nodes[self.i]; + self.count += 1; + Some(Interval { + first: node.first(self.j), + last: node.last(self.j), + metadata: &node.metadata[self.j], + }) + } + + fn size_hint(&self) -> (usize, Option) { + let len_remaining = self.len - self.count; + (len_remaining, Some(len_remaining)) + } +} + +impl<'a, T, I> ExactSizeIterator for NeonCOITreeIterator<'a, T, I> +where + T: Clone, + I: IntWithMax, +{ + fn len(&self) -> usize { + self.len - self.count + } +} + +// // Recursively count overlaps between the tree specified by `nodes` and a +// // query interval specified by `first`, `last`. +fn query_recursion<'a, T, I, F>( + nodes: &'a [IntervalNode], + root_idx: usize, + first: i32, + last: i32, + firstv: i32x4, + lastv: i32x4, + visit: &mut F, +) where + T: Clone, + I: IntWithMax, + F: FnMut(&Interval<&'a T>), +{ + let node = &nodes[root_idx]; + + if node.left == node.right { + // simple subtree + for node in &nodes[root_idx..root_idx + node.right.to_usize()] { + if last < node.min_first() { + break; + } + + node.query_chunk_metadata(firstv, lastv, |first_hit, last_hit, metadata| { + visit(&Interval { + first: first_hit, + last: last_hit, + metadata, + }); + }); + } + } else { + node.query_chunk_metadata(firstv, lastv, |first_hit, last_hit, metadata| { + visit(&Interval { + first: first_hit, + last: last_hit, + metadata, + }); + }); + + if node.left < I::MAX { + let left = node.left.to_usize(); + if nodes[left].subtree_last >= first { + query_recursion(nodes, left, first, last, firstv, lastv, visit); + } + } + + if node.right < I::MAX { + let right = node.right.to_usize(); + if overlaps(node.min_first(), nodes[right].subtree_last, first, last) { + query_recursion(nodes, right, first, last, firstv, lastv, visit); + } + } + } +} + +// query_recursion but just count number of overlaps +fn query_recursion_count( + nodes: &[IntervalNode], + root_idx: usize, + first: i32, + last: i32, + firstv: i32x4, + lastv: i32x4, +) -> usize +where + T: Clone, + I: IntWithMax, +{ + let node = &nodes[root_idx]; + + if node.left == node.right { + // simple subtree + let mut count = 0; + for node in &nodes[root_idx..root_idx + node.right.to_usize()] { + if last < node.min_first() { + break; + } else { + count += node.query_count_chunk(firstv, lastv); + } + } + count + } else { + let mut count = 0; + count += node.query_count_chunk(firstv, lastv); + + if node.left < I::MAX { + let left = node.left.to_usize(); + if nodes[left].subtree_last >= first { + count += query_recursion_count(nodes, left, first, last, firstv, lastv); + } + } + + if node.right < I::MAX { + let right = node.right.to_usize(); + if overlaps(node.min_first(), nodes[right].subtree_last, first, last) { + count += query_recursion_count(nodes, right, first, last, firstv, lastv); + } + } + + count + } +} + +fn coverage_recursion( + nodes: &[IntervalNode], + root_idx: usize, + first: i32, + last: i32, + firstv: i32x4, + lastv: i32x4, + mut last_cov: i32, +) -> (i32, i32, usize) +where + T: Clone, + I: IntWithMax, +{ + let node = &nodes[root_idx]; + + if node.left == node.right { + // simple subtree + let mut count = 0; + let mut uncov_len = 0; + for node in &nodes[root_idx..root_idx + node.right.to_usize()] { + if last < node.min_first() { + break; + } + + node.query_chunk(firstv, lastv, |first_hit, last_hit| { + if first_hit > last_cov { + uncov_len += first_hit - (last_cov + 1); + } + last_cov = last_cov.max(last_hit); + count += 1; + }); + } + (uncov_len, last_cov, count) + } else { + let mut uncov_len = 0; + let mut count = 0; + + if node.left < I::MAX { + let left = node.left.to_usize(); + if nodes[left].subtree_last >= first { + let (left_uncov_len, left_last_cov, left_count) = + coverage_recursion(nodes, left, first, last, firstv, lastv, last_cov); + last_cov = left_last_cov; + uncov_len += left_uncov_len; + count += left_count; + } + } + + node.query_chunk(firstv, lastv, |first_hit, last_hit| { + if first_hit > last_cov { + uncov_len += first_hit - (last_cov + 1); + } + last_cov = last_cov.max(last_hit); + count += 1; + }); + + if node.right < I::MAX { + let right = node.right.to_usize(); + if overlaps(node.min_first(), nodes[right].subtree_last, first, last) { + let (right_uncov_len, right_last_cov, right_count) = + coverage_recursion(nodes, right, first, last, firstv, lastv, last_cov); + last_cov = right_last_cov; + uncov_len += right_uncov_len; + count += right_count; + } + } + + (uncov_len, last_cov, count) + } +} + +/// An alternative query strategy that can be much faster when queries are performed +/// in sorted order and overlap. +/// +/// Unilke `COITree::query`, some state is retained between queries. +/// `SortedQuerent` tracks that state. If queries are not sorted or don't +/// overlap, this strategy still works, but is slightly slower than +/// `COITree::query`. +pub struct NeonSortedQuerent<'a, T, I> +where + T: Default + Clone, + I: IntWithMax, +{ + tree: &'a NeonCOITree, + prev_first: i32, + prev_last: i32, + overlapping_intervals: Vec>, +} + +impl<'a, T, I> SortedQuerent<'a> for NeonSortedQuerent<'a, T, I> +where + T: Default + Copy + Clone, + I: IntWithMax, +{ + type Metadata = T; + type Index = I; + type Item = Interval<&'a T>; + type Iter = NeonCOITreeIterator<'a, T, I>; + type Tree = NeonCOITree; + + /// Construct a new `SortedQuerent` to perform a sequence. queries. + fn new(tree: &'a NeonCOITree) -> NeonSortedQuerent<'a, T, I> { + let overlapping_intervals: Vec> = Vec::new(); + NeonSortedQuerent { + tree, + prev_first: -1, + prev_last: -1, + overlapping_intervals, + } + } + + /// Find intervals in the underlying `COITree` that overlap the query + /// `[first, last]` and call `visit` on each. Works equivalently to + /// `COITrees::query` but queries that overlap prior queries will potentially + /// be faster. + fn query(&mut self, first: i32, last: i32, mut visit: F) + where + F: FnMut(&Interval<&'a T>), + { + if self.tree.is_empty() { + return; + } + + // not overlaping or preceding + if first < self.prev_first || first > self.prev_last { + // no overlap with previous query. have to resort to regular query strategy + self.overlapping_intervals.clear(); + self.tree + .query(first, last, |node| self.overlapping_intervals.push(*node)); + } else { + // successor query, exploit the overlap + + // delete previously overlapping intervals with end in [prev_first, first-1] + if self.prev_first < first { + let mut i = 0; + while i < self.overlapping_intervals.len() { + if self.overlapping_intervals[i].last < first { + self.overlapping_intervals.swap_remove(i); + } else { + i += 1; + } + } + } + + // delete previously overlapping intervals with start in [last+1, prev_end] + if self.prev_last > last { + let mut i = 0; + while i < self.overlapping_intervals.len() { + if self.overlapping_intervals[i].first > last { + self.overlapping_intervals.swap_remove(i); + } else { + i += 1; + } + } + } + + // add any interval that start in [prev_last+1, last] + if self.prev_last < last { + let qa = self.prev_last + 1 - 2; // -2 accounts for the adjustment made in the chunk + let qb = last; + + let (qav, qbv) = unsafe { (vdupq_n_s32(qa), vdupq_n_s32(qb)) }; + + sorted_querent_query_firsts( + &self.tree.nodes, + self.tree.root_idx, + qa, + qb, + qav, + qbv, + &mut self.overlapping_intervals, + ); + } + } + + // call visit on everything + for overlapping_interval in &self.overlapping_intervals { + visit(overlapping_interval); + } + + self.prev_first = first; + self.prev_last = last; + } +} + +// find any intervals in the tree with their first value in [first, last] +fn sorted_querent_query_firsts<'a, T, I>( + nodes: &'a [IntervalNode], + root_idx: usize, + first: i32, + last: i32, + firstv: i32x4, + lastv: i32x4, + overlapping_intervals: &mut Vec>, +) where + T: Clone, + I: IntWithMax, +{ + let node = &nodes[root_idx]; + + if node.left == node.right { + // simple subtree + for node in &nodes[root_idx..root_idx + node.right.to_usize()] { + if last < node.min_first() { + break; + } + + node.query_chunk_firsts(firstv, lastv, |first_hit, last_hit, metadata| { + overlapping_intervals.push(Interval { + first: first_hit, + last: last_hit, + metadata, + }); + }) + } + } else { + node.query_chunk_firsts(firstv, lastv, |first_hit, last_hit, metadata| { + overlapping_intervals.push(Interval { + first: first_hit, + last: last_hit, + metadata, + }); + }); + + if node.left < I::MAX && first <= node.min_first() { + let left = node.left.to_usize(); + sorted_querent_query_firsts( + nodes, + left, + first, + last, + firstv, + lastv, + overlapping_intervals, + ); + } + + if node.right < I::MAX && last >= node.min_first() { + let right = node.right.to_usize(); + sorted_querent_query_firsts( + nodes, + right, + first, + last, + firstv, + lastv, + overlapping_intervals, + ); + } + } +} + +// True iff the two intervals overlap. +#[inline(always)] +fn overlaps(first_a: i32, last_a: i32, first_b: i32, last_b: i32) -> bool { + first_a <= last_b && last_a >= first_b +} + +// Used by `traverse` to keep record tree metadata. +#[derive(Clone, Debug)] +struct TraversalInfo +where + I: IntWithMax, +{ + depth: u32, + inorder: I, // in-order visit number + preorder: I, // pre-order visit number + subtree_size: I, + parent: I, + expected_hit_proportion: f32, +} + +impl Default for TraversalInfo +where + I: IntWithMax, +{ + fn default() -> Self { + TraversalInfo { + depth: 0, + inorder: I::default(), + preorder: I::default(), + subtree_size: I::one(), + parent: I::MAX, + expected_hit_proportion: 0.0, + } + } +} + +// dfs traversal of an implicit bst computing dfs number, node depth, subtree +// size, and left and right pointers. +fn traverse(nodes: &mut [IntervalNode]) -> Vec> +where + T: Clone, + I: IntWithMax, +{ + let n = nodes.len(); + let mut info = vec![TraversalInfo::default(); n]; + let mut inorder = 0; + let mut preorder = 0; + traverse_recursion( + nodes, + &mut info, + 0, + n, + 0, + I::MAX, + &mut inorder, + &mut preorder, + ); + + info +} + +// The recursive part of the `traverse` function. +fn traverse_recursion( + nodes: &mut [IntervalNode], + info: &mut [TraversalInfo], + start: usize, + end: usize, + depth: u32, + parent: I, + inorder: &mut usize, + preorder: &mut usize, +) -> (I, i32, f32) +where + T: Clone, + I: IntWithMax, +{ + if start >= end { + return (I::MAX, i32::MAX, f32::NAN); + } + + let root_idx = start + (end - start) / 2; + let subtree_size = end - start; + + info[root_idx].depth = depth; + info[root_idx].preorder = I::from_usize(*preorder); + info[root_idx].parent = parent; + *preorder += 1; + + let mut subtree_first = nodes[root_idx].min_first(); + + let mut left_expected_hits = 0.0; + let mut left_subtree_span = 0; + + if root_idx > start { + let (left, left_subtree_first, left_expected_hits_) = traverse_recursion( + nodes, + info, + start, + root_idx, + depth + 1, + I::from_usize(root_idx), + inorder, + preorder, + ); + + left_expected_hits = left_expected_hits_; + left_subtree_span = nodes[left.to_usize()].subtree_last - left_subtree_first + 1; + + subtree_first = left_subtree_first; + if nodes[left.to_usize()].subtree_last > nodes[root_idx].subtree_last { + nodes[root_idx].subtree_last = nodes[left.to_usize()].subtree_last; + } + nodes[root_idx].left = left; + } + + info[root_idx].inorder = I::from_usize(*inorder); + *inorder += 1; + + let mut right_expected_hits = 0.0; + let mut right_subtree_span = 0; + + if root_idx + 1 < end { + let (right, right_subtree_first, right_expected_hits_) = traverse_recursion( + nodes, + info, + root_idx + 1, + end, + depth + 1, + I::from_usize(root_idx), + inorder, + preorder, + ); + + right_expected_hits = right_expected_hits_; + right_subtree_span = nodes[right.to_usize()].subtree_last - right_subtree_first + 1; + + if nodes[right.to_usize()].subtree_last > nodes[root_idx].subtree_last { + nodes[root_idx].subtree_last = nodes[right.to_usize()].subtree_last; + } + nodes[root_idx].right = right; + } + + info[root_idx].subtree_size = I::from_usize(subtree_size); + let subtree_span = nodes[root_idx].subtree_last - subtree_first + 1; + + debug_assert!(left_subtree_span <= subtree_span); + debug_assert!(right_subtree_span <= subtree_span); + + let expected_hits = ((nodes[root_idx].max_last() - nodes[root_idx].min_first() + 1) as f32 + + (left_subtree_span as f32) * left_expected_hits + + (right_subtree_span as f32) * right_expected_hits) + / subtree_span as f32; + + info[root_idx].expected_hit_proportion = expected_hits / subtree_size as f32; + + (I::from_usize(root_idx), subtree_first, expected_hits) +} + +// norder partition by depth on pivot into three parts, like so +// [ bottom left ][ top ][ bottom right ] +// where bottom left and right are the bottom subtrees with positioned to +// the left and right of the root node +fn stable_ternary_tree_partition( + input: &[I], + output: &mut [I], + partition: &mut [i8], + info: &[TraversalInfo], + pivot_depth: u32, + pivot_dfs: I, + start: usize, + end: usize, +) -> (usize, usize) +where + I: IntWithMax, +{ + let n = end - start; + + // determine which partition each index goes in + let mut bottom_left_size = 0; + let mut top_size = 0; + let mut bottom_right_size = 0; + + for (i, p) in input[start..end].iter().zip(&mut partition[start..end]) { + let info_j = &info[i.to_usize()]; + if info_j.depth <= pivot_depth { + *p = 0; + top_size += 1; + } else if info_j.inorder < pivot_dfs { + *p = -1; + bottom_left_size += 1; + } else { + *p = 1; + bottom_right_size += 1; + } + } + + debug_assert!(bottom_left_size + top_size + bottom_right_size == n); + + // do the partition + let mut bl = start; + let mut t = bl + bottom_left_size; + let mut br = t + top_size; + for (i, p) in input[start..end].iter().zip(&partition[start..end]) { + match p.cmp(&0) { + Ordering::Less => { + output[bl] = *i; + bl += 1; + } + Ordering::Equal => { + output[t] = *i; + t += 1; + } + Ordering::Greater => { + output[br] = *i; + br += 1; + } + } + } + debug_assert!(br == end); + + (bl, t) +} + +// put nodes in van Emde Boas order +fn veb_order(mut nodes: Vec>) -> (Vec>, usize, usize) +where + T: Clone, + I: IntWithMax, +{ + // let now = Instant::now(); + let mut veb_nodes = nodes.clone(); + let n = veb_nodes.len(); + + if veb_nodes.is_empty() { + return (veb_nodes, 0, 0); + } + + // let now = Instant::now(); + let info = traverse(&mut nodes); + // eprintln!("traversing: {}s", now.elapsed().as_millis() as f64 / 1000.0); + + let mut max_depth = 0; + for info_i in &info { + if info_i.depth > max_depth { + max_depth = info_i.depth; + } + } + + let idxs: &mut [I] = &mut vec![I::default(); n]; + (0..n).for_each(|i| idxs[i] = I::from_usize(i)); + + let tmp: &mut [I] = &mut vec![I::default(); n]; + + // put in dfs order + for i in &*idxs { + tmp[info[i.to_usize()].preorder.to_usize()] = *i; + } + let (idxs, tmp) = (tmp, idxs); + + // space used to by stable_ternary_tree_partition + let partition: &mut [i8] = &mut vec![0; n]; + + // let now = Instant::now(); + let root_idx = veb_order_recursion( + idxs, tmp, partition, &info, &mut nodes, 0, n, false, 0, max_depth, + ); + // eprintln!("computing order: {}s", now.elapsed().as_millis() as f64 / 1000.0); + + // let now = Instant::now(); + // idxs is now a vEB -> sorted order map. Build the reverse here. + let revidx = tmp; + for (i, j) in idxs.iter().enumerate() { + revidx[j.to_usize()] = I::from_usize(i); + } + + // put nodes in vEB order + for (i_, mut node) in revidx.iter().zip(nodes) { + let i = i_.to_usize(); + if node.left != node.right { + if node.left < I::MAX { + node.left = revidx[node.left.to_usize()]; + } + + if node.right < I::MAX { + node.right = revidx[node.right.to_usize()]; + } + } + veb_nodes[i.to_usize()] = node; + } + + let root_idx = revidx[root_idx.to_usize()].to_usize(); + + // eprintln!("ordering: {}s", now.elapsed().as_millis() as f64 / 1000.0); + debug_assert!(compute_tree_size(&veb_nodes, root_idx) == n); + + (veb_nodes, root_idx, max_depth as usize) +} + +// Traverse the tree and return the size, used as a basic sanity check. +fn compute_tree_size(nodes: &[IntervalNode], root_idx: usize) -> usize +where + T: Clone, + I: IntWithMax, +{ + let mut subtree_size = 1; + + let node = &nodes[root_idx]; + if node.left == node.right { + subtree_size = nodes[root_idx].right.to_usize(); + } else { + if node.left < I::MAX { + let left = node.left.to_usize(); + subtree_size += compute_tree_size(nodes, left); + } + + if node.right < I::MAX { + let right = node.right.to_usize(); + subtree_size += compute_tree_size(nodes, right); + } + } + + subtree_size +} + +// recursively reorder indexes to put it in vEB order. Called by `veb_order` +// idxs: current permutation +// tmp: temporary space of equal length to idxs +// partition: space used to assist `stable_ternary_tree_partition`. +// nodes: the interval nodes (in sorted order) +// start, end: slice within idxs to be reordered +// childless: true if this slice is a proper subtree and has no children below it +// parity: true if idxs and tmp are swapped and need to be copied back, +// min_depth, max_depth: depth extreme of the start..end slice +// +fn veb_order_recursion( + idxs: &mut [I], + tmp: &mut [I], + partition: &mut [i8], + info: &[TraversalInfo], + nodes: &mut [IntervalNode], + start: usize, + end: usize, + parity: bool, + min_depth: u32, + max_depth: u32, +) -> I +where + T: Clone, + I: IntWithMax, +{ + let n = (start..end).len(); + + assert!(n > 0); + + let childless = info[idxs[start].to_usize()].subtree_size.to_usize() == n; + + // small subtrees are put into sorted order and just searched through + // linearly. There is a little trickiness to this because we have to + // update the parent's child pointers and some other fields. + if childless + && ((info[idxs[start].to_usize()].subtree_size.to_usize() <= SIMPLE_SUBTREE_CUTOFF) + || (info[idxs[start].to_usize()].expected_hit_proportion + >= SIMPLE_SUBTREE_DENSITY_CUTOFF)) + { + debug_assert!(n == info[idxs[start].to_usize()].subtree_size.to_usize()); + + let old_root = idxs[start]; + + idxs[start..end].sort_unstable_by_key(|i| info[i.to_usize()].inorder); + let subtree_size = info[old_root.to_usize()].subtree_size; + nodes[idxs[start].to_usize()].subtree_last = nodes[old_root.to_usize()].subtree_last; + + // all children nodes record the size of the remaining list + // let mut subtree_i_size = subtree_size - i; + let mut subtree_i_size = subtree_size; + for idx in &idxs[start..end] { + nodes[idx.to_usize()].left = subtree_i_size; + nodes[idx.to_usize()].right = subtree_i_size; + subtree_i_size -= I::one(); + } + + let parent = info[old_root.to_usize()].parent; + if parent < I::MAX { + if nodes[parent.to_usize()].left == old_root { + nodes[parent.to_usize()].left = idxs[start]; + } else { + debug_assert!(nodes[parent.to_usize()].right == old_root); + nodes[parent.to_usize()].right = idxs[start]; + } + } + + if parity { + tmp[start..end].copy_from_slice(&idxs[start..end]); + } + return idxs[start]; + } + + // very small trees are already in order + if n == 1 { + if parity { + tmp[start] = idxs[start]; + } + return idxs[start]; + } + + let pivot_depth = min_depth + (max_depth - min_depth) / 2; + let pivot_dfs = info[idxs[start].to_usize()].inorder; + + let (top_start, bottom_right_start) = stable_ternary_tree_partition( + idxs, + tmp, + partition, + info, + pivot_depth, + pivot_dfs, + start, + end, + ); + + // tmp is not partitioned, so swap pointers + let (tmp, idxs) = (idxs, tmp); + + // recurse on top subtree + let top_root_idx = veb_order_recursion( + idxs, + tmp, + partition, + info, + nodes, + top_start, + bottom_right_start, + !parity, + min_depth, + pivot_depth, + ); + + // find on recurse on subtrees in the bottom left partition and bottom right partition + for (part_start, part_end) in &[(start, top_start), (bottom_right_start, end)] { + let bottom_subtree_depth = pivot_depth + 1; + let mut i = *part_start; + while i < *part_end { + debug_assert!(info[idxs[i].to_usize()].depth == bottom_subtree_depth); + + let mut subtree_max_depth = info[idxs[i].to_usize()].depth; + let mut j = *part_end; + for (u, v) in (i + 1..*part_end).zip(&idxs[i + 1..*part_end]) { + let depth = info[v.to_usize()].depth; + if depth == bottom_subtree_depth { + j = u; + break; + } else if depth > subtree_max_depth { + subtree_max_depth = depth; + } + } + + veb_order_recursion( + idxs, + tmp, + partition, + info, + nodes, + i, + j, + !parity, + bottom_subtree_depth, + subtree_max_depth, + ); + i = j; + } + } + + top_root_idx +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_metadata() { + let intervals = (0..10) + .map(|i| Interval::new(i, i + 1, i as usize)) + .collect::>>(); + + let tree = NeonCOITree::::new(&intervals); + + let mut metadata = Vec::new(); + tree.query(0, 3, |node| metadata.push(node.clone())); + + println!("{:?}", metadata); + } +} diff --git a/sequila/sequila-core/superintervals/test/3rd-party/coitrees/src/nosimd.rs b/sequila/sequila-core/superintervals/test/3rd-party/coitrees/src/nosimd.rs new file mode 100644 index 0000000..48099d2 --- /dev/null +++ b/sequila/sequila-core/superintervals/test/3rd-party/coitrees/src/nosimd.rs @@ -0,0 +1,1123 @@ +//! # COITrees +//! `coitrees` implements a fast static interval tree data structure with genomic +//! data in mind. +//! +//! The data structure used a fairly standard interval tree, but with nodes stored +//! in van Emde Boas layout, which improves average cache locality, and thus +//! query performance. The downside is that building the tree is more expensive +//! so a relatively large number of queries needs to made for it to break even. +//! +//! The data structure `COITree` is constructed with an array of `IntervalNode` +//! structs which store integer, end-inclusive intervals along with associated +//! metadata. The tree can be queried directly for coverage or overlaps, or +//! through the intermediary `SortedQuerent` which keeps track of some state +//! to accelerate overlaping queries. + +use super::interval::{GenericInterval, IntWithMax, Interval, IntervalTree, SortedQuerent}; +use std::cmp::Ordering; +use std::fmt::Debug; +use std::iter::IntoIterator; + +// Small subtrees at the bottom of the tree are stored in sorted order +// This gives the upper bound on the size of such subtrees. Performance isn't +// super sensitive, but is worse with a very small or very large number. +const SIMPLE_SUBTREE_CUTOFF: usize = 64; + +// Very dense subtrees in which we probably intersect most of the intervals +// are more efficient to query linearly. When the expected proportion of hits +// is a above this number it becomes a simple subtree. +const SIMPLE_SUBTREE_DENSITY_CUTOFF: f32 = 0.2; + +/// Internal interval node representation used by BasicCOITree +#[derive(Clone)] +pub struct IntervalNode +where + T: Clone, + I: IntWithMax, +{ + // subtree interval + subtree_last: i32, + + // interval + pub first: i32, + pub last: i32, + + // when this is the root of a simple subtree, left == right is the size + // of the subtree, otherwise they are left, right child pointers. + left: I, + right: I, + + pub metadata: T, +} + +impl IntervalNode +where + T: Clone, + I: IntWithMax, +{ + pub fn new(first: i32, last: i32, metadata: T) -> IntervalNode { + Self { + subtree_last: last, + first, + last, + left: I::MAX, + right: I::MAX, + metadata, + } + } + + fn from_interval(interval: &V) -> IntervalNode + where + V: GenericInterval, + { + return Self::new( + interval.first(), + interval.last(), + interval.metadata().clone(), + ); + } + + /// Length spanned by the interval. (Interval are end-inclusive.) + #[allow(clippy::len_without_is_empty)] + pub fn len(&self) -> i32 { + (self.last - self.first + 1).max(0) + } +} + +// IntervalNodes are themselves a type of annotated interval +impl GenericInterval for IntervalNode +where + T: Clone, + I: IntWithMax, +{ + fn first(&self) -> i32 { + self.first + } + + fn last(&self) -> i32 { + self.last + } + + fn metadata(&self) -> &T { + &self.metadata + } +} + +#[test] +fn test_interval_len() { + fn make_interval(first: i32, last: i32) -> IntervalNode<(), u32> { + IntervalNode::new(first, last, ()) + } + + assert_eq!(make_interval(1, -1).len(), 0); + assert_eq!(make_interval(1, 0).len(), 0); + assert_eq!(make_interval(1, 1).len(), 1); + assert_eq!(make_interval(1, 2).len(), 2); +} + +/// COITree data structure. A representation of a static set of intervals with +/// associated metadata, enabling fast overlap and coverage queries. +/// +/// The index type `I` is a typically `usize`, but can be `u32` or `u16`. +/// It's slightly more efficient to use a smalled index type, assuming there are +/// fewer than I::MAX-1 intervals to store. +#[derive(Clone)] +pub struct BasicCOITree +where + T: Clone, + I: IntWithMax, +{ + nodes: Vec>, + root_idx: usize, + height: usize, +} + +impl<'a, T, I> BasicCOITree +where + T: Clone, + I: IntWithMax, +{ + // Exactly the same as IntervalTree::query, but provides a reference with the + // lifetime of the tree (not something we can guarantee with other implementations). + fn query_static(&'a self, first: i32, last: i32, mut visit: F) + where + F: FnMut(&'a IntervalNode), + { + if !self.is_empty() { + query_recursion(&self.nodes, self.root_idx, first, last, &mut visit); + } + } +} + +impl<'a, T, I> IntervalTree<'a> for BasicCOITree +where + T: Clone + 'a, + I: IntWithMax + 'a, +{ + type Metadata = T; + type Index = I; + type Item = IntervalNode; + type Iter = BasicCOITreeIterator<'a, T, I>; + + fn new<'c, U, V>(intervals: U) -> BasicCOITree + where + U: IntoIterator, + V: GenericInterval + 'c, + { + let nodes: Vec> = intervals + .into_iter() + .map(|node| IntervalNode::from_interval(node)) + .collect(); + if nodes.len() >= (I::MAX).to_usize() { + panic!("BasicCOITree construction failed: more intervals than index type can enumerate") + } + + let (nodes, root_idx, height) = veb_order(nodes); + + BasicCOITree { + nodes, + root_idx, + height, + } + } + + /// Number of intervals in the set. + fn len(&self) -> usize { + self.nodes.len() + } + + /// True iff the set is empty. + fn is_empty(&self) -> bool { + self.nodes.is_empty() + } + + /// Find intervals in the set overlaping the query `[first, last]` and call `visit` on every overlapping node + fn query(&'a self, first: i32, last: i32, mut visit: F) + where + F: FnMut(&IntervalNode), + { + if !self.is_empty() { + query_recursion(&self.nodes, self.root_idx, first, last, &mut visit); + } + } + + /// Count the number of intervals in the set overlapping the query `[first, last]`. + fn query_count(&self, first: i32, last: i32) -> usize { + if !self.is_empty() { + query_recursion_count(&self.nodes, self.root_idx, first, last) + } else { + 0 + } + } + + /// Return a pair `(count, cov)`, where `count` gives the number of intervals + /// in the set overlapping the query, and `cov` the number of positions in the query + /// interval covered by at least one interval in the set. + fn coverage(&self, first: i32, last: i32) -> (usize, usize) { + assert!(last >= first); + + if self.is_empty() { + return (0, 0); + } + + let (mut uncov_len, last_cov, count) = + coverage_recursion(&self.nodes, self.root_idx, first, last, first - 1); + + if last_cov < last { + uncov_len += last - last_cov; + } + + let cov = ((last - first + 1) as usize) - (uncov_len as usize); + + (count, cov) + } + + /// Iterate through the interval set in sorted order by interval start position. + fn iter(&'a self) -> BasicCOITreeIterator<'a, T, I> { + let mut i = self.root_idx; + let mut stack: Vec = Vec::with_capacity(self.height); + while i < self.nodes.len() + && self.nodes[i].left != I::MAX + && self.nodes[i].left != self.nodes[i].right + { + stack.push(i); + i = self.nodes[i].left.to_usize(); + } + + BasicCOITreeIterator { + nodes: &self.nodes, + i, + count: 0, + stack, + } + } +} + +impl<'a, T, I> IntoIterator for &'a BasicCOITree +where + T: Clone, + I: IntWithMax, +{ + type Item = Interval<&'a T>; + type IntoIter = BasicCOITreeIterator<'a, T, I>; + + fn into_iter(self) -> BasicCOITreeIterator<'a, T, I> { + self.iter() + } +} + +/// Iterate through nodes in a tree in sorted order by interval start position. +pub struct BasicCOITreeIterator<'a, T, I> +where + T: Clone, + I: IntWithMax, +{ + nodes: &'a Vec>, + i: usize, // current node + count: usize, // number generated so far + stack: Vec, +} + +impl<'a, T, I> Iterator for BasicCOITreeIterator<'a, T, I> +where + T: Clone, + I: IntWithMax, +{ + type Item = Interval<&'a T>; + + fn next(&mut self) -> Option { + if self.i >= self.nodes.len() { + return None; + } + + let node = &self.nodes[self.i]; + + if node.left == node.right { + // simple node + if node.left.to_usize() > 1 { + self.i += 1; + } else if let Some(i) = self.stack.pop() { + self.i = i; + } else { + self.i = usize::MAX; + } + } else if node.right == I::MAX { + if let Some(i) = self.stack.pop() { + self.i = i; + } else { + self.i = usize::MAX; + } + } else { + let mut i = node.right.to_usize(); + + while self.nodes[i].left != I::MAX && self.nodes[i].left != self.nodes[i].right { + self.stack.push(i); + i = self.nodes[i].left.to_usize(); + } + + self.i = i; + } + + self.count += 1; + Some(Interval::new(node.first, node.last, &node.metadata)) + } + + fn size_hint(&self) -> (usize, Option) { + let len = self.nodes.len() - self.count; + (len, Some(len)) + } +} + +impl<'a, T, I> ExactSizeIterator for BasicCOITreeIterator<'a, T, I> +where + T: Clone, + I: IntWithMax, +{ + fn len(&self) -> usize { + self.nodes.len() - self.count + } +} + +// Recursively count overlaps between the tree specified by `nodes` and a +// query interval specified by `first`, `last`. +fn query_recursion<'a, T, I, F>( + nodes: &'a [IntervalNode], + root_idx: usize, + first: i32, + last: i32, + visit: &mut F, +) where + T: Clone, + I: IntWithMax, + F: FnMut(&'a IntervalNode), +{ + let node = &nodes[root_idx]; + + if node.left == node.right { + // simple subtree + for node in &nodes[root_idx..root_idx + node.right.to_usize()] { + if last < node.first { + break; + } else if first <= node.last { + visit(node); + } + } + } else { + if overlaps(node.first, node.last, first, last) { + visit(node); + } + + if node.left < I::MAX { + let left = node.left.to_usize(); + if nodes[left].subtree_last >= first { + query_recursion(nodes, left, first, last, visit); + } + } + + if node.right < I::MAX { + let right = node.right.to_usize(); + if overlaps(node.first, nodes[right].subtree_last, first, last) { + query_recursion(nodes, right, first, last, visit); + } + } + } +} + +// query_recursion but just count number of overlaps +fn query_recursion_count( + nodes: &[IntervalNode], + root_idx: usize, + first: i32, + last: i32, +) -> usize +where + T: Clone, + I: IntWithMax, +{ + let node = &nodes[root_idx]; + + if node.left == node.right { + // simple subtree + let mut count = 0; + for node in &nodes[root_idx..root_idx + node.right.to_usize()] { + if last < node.first { + break; + } else if first <= node.last { + count += 1; + } + } + count + } else { + let mut count = 0; + if overlaps(node.first, node.last, first, last) { + count += 1; + } + + if node.left < I::MAX { + let left = node.left.to_usize(); + if nodes[left].subtree_last >= first { + count += query_recursion_count(nodes, left, first, last); + } + } + + if node.right < I::MAX { + let right = node.right.to_usize(); + if overlaps(node.first, nodes[right].subtree_last, first, last) { + count += query_recursion_count(nodes, right, first, last); + } + } + + count + } +} + +fn coverage_recursion( + nodes: &[IntervalNode], + root_idx: usize, + first: i32, + last: i32, + mut last_cov: i32, +) -> (i32, i32, usize) +where + T: Clone, + I: IntWithMax, +{ + let node = &nodes[root_idx]; + + if node.left == node.right { + // simple subtree + let mut count = 0; + let mut uncov_len = 0; + for node in &nodes[root_idx..root_idx + node.right.to_usize()] { + if overlaps(node.first, node.last, first, last) { + if node.first > last_cov { + uncov_len += node.first - (last_cov + 1); + } + last_cov = last_cov.max(node.last); + count += 1; + } + } + (uncov_len, last_cov, count) + } else { + let mut uncov_len = 0; + let mut count = 0; + + if node.left < I::MAX { + let left = node.left.to_usize(); + if nodes[left].subtree_last >= first { + let (left_uncov_len, left_last_cov, left_count) = + coverage_recursion(nodes, left, first, last, last_cov); + last_cov = left_last_cov; + uncov_len += left_uncov_len; + count += left_count; + } + } + + if overlaps(node.first, node.last, first, last) { + if node.first > last_cov { + uncov_len += node.first - (last_cov + 1); + } + last_cov = last_cov.max(node.last); + count += 1; + } + + if node.right < I::MAX { + let right = node.right.to_usize(); + if overlaps(node.first, nodes[right].subtree_last, first, last) { + let (right_uncov_len, right_last_cov, right_count) = + coverage_recursion(nodes, right, first, last, last_cov); + last_cov = right_last_cov; + uncov_len += right_uncov_len; + count += right_count; + } + } + + (uncov_len, last_cov, count) + } +} + +/// An alternative query strategy that can be much faster when queries are performed +/// in sorted order and overlap. +/// +/// Unilke `COITree::query`, some state is retained between queries. +/// `SortedQuerent` tracks that state. If queries are not sorted or don't +/// overlap, this strategy still works, but is slightly slower than +/// `COITree::query`. +pub struct BasicSortedQuerent<'a, T, I> +where + T: Clone, + I: IntWithMax, +{ + tree: &'a BasicCOITree, + prev_first: i32, + prev_last: i32, + overlapping_intervals: Vec<&'a IntervalNode>, +} + +impl<'a, T, I> SortedQuerent<'a> for BasicSortedQuerent<'a, T, I> +where + T: Clone, + I: IntWithMax, +{ + type Metadata = T; + type Index = I; + type Item = IntervalNode; + type Iter = BasicCOITreeIterator<'a, T, I>; + type Tree = BasicCOITree; + + /// Construct a new `SortedQuerent` to perform a sequence. queries. + fn new(tree: &'a BasicCOITree) -> BasicSortedQuerent<'a, T, I> { + let overlapping_intervals: Vec<&IntervalNode> = Vec::new(); + BasicSortedQuerent { + tree, + prev_first: -1, + prev_last: -1, + overlapping_intervals, + } + } + + /// Find intervals in the underlying `COITree` that overlap the query + /// `[first, last]` and call `visit` on each. Works equivalently to + /// `COITrees::query` but queries that overlap prior queries will potentially + /// be faster. + fn query(&mut self, first: i32, last: i32, mut visit: F) + where + F: FnMut(&'a IntervalNode), + { + if self.tree.is_empty() { + return; + } + + // not overlaping or preceding + if first < self.prev_first || first > self.prev_last { + // no overlap with previous query. have to resort to regular query strategy + self.overlapping_intervals.clear(); + self.tree + .query_static(first, last, |node| self.overlapping_intervals.push(node)); + } else { + // successor query, exploit the overlap + + // delete previously overlapping intervals with end in [prev_first, first-1] + if self.prev_first < first { + let mut i = 0; + while i < self.overlapping_intervals.len() { + if self.overlapping_intervals[i].last < first { + self.overlapping_intervals.swap_remove(i); + } else { + i += 1; + } + } + } + + // delete previously overlapping intervals with start in [last+1, prev_end] + if self.prev_last > last { + let mut i = 0; + while i < self.overlapping_intervals.len() { + if self.overlapping_intervals[i].first > last { + self.overlapping_intervals.swap_remove(i); + } else { + i += 1; + } + } + } + + // add any interval that start in [prev_last+1, last] + if self.prev_last < last { + sorted_querent_query_firsts( + &self.tree.nodes, + self.tree.root_idx, + self.prev_last + 1, + last, + &mut self.overlapping_intervals, + ); + } + } + + // call visit on everything + for overlapping_interval in &self.overlapping_intervals { + visit(overlapping_interval); + } + + self.prev_first = first; + self.prev_last = last; + } +} + +// find any intervals in the tree with their first value in [first, last] +fn sorted_querent_query_firsts<'a, T, I>( + nodes: &'a [IntervalNode], + root_idx: usize, + first: i32, + last: i32, + overlapping_intervals: &mut Vec<&'a IntervalNode>, +) where + T: Clone, + I: IntWithMax, +{ + let node = &nodes[root_idx]; + + if node.left == node.right { + // simple subtree + for node in &nodes[root_idx..root_idx + node.right.to_usize()] { + if last < node.first { + break; + } else if first <= node.first { + overlapping_intervals.push(node); + } + } + } else { + if first <= node.first && node.first <= last { + overlapping_intervals.push(node); + } + + if node.left < I::MAX && first <= node.first { + let left = node.left.to_usize(); + sorted_querent_query_firsts(nodes, left, first, last, overlapping_intervals); + } + + if node.right < I::MAX && last >= node.first { + let right = node.right.to_usize(); + sorted_querent_query_firsts(nodes, right, first, last, overlapping_intervals); + } + } +} + +// True iff the two intervals overlap. +#[inline(always)] +fn overlaps(first_a: i32, last_a: i32, first_b: i32, last_b: i32) -> bool { + first_a <= last_b && last_a >= first_b +} + +// Used by `traverse` to keep record tree metadata. +#[derive(Clone, Debug)] +struct TraversalInfo +where + I: IntWithMax, +{ + depth: u32, + inorder: I, // in-order visit number + preorder: I, // pre-order visit number + subtree_size: I, + parent: I, + expected_hit_proportion: f32, +} + +impl Default for TraversalInfo +where + I: IntWithMax, +{ + fn default() -> Self { + TraversalInfo { + depth: 0, + inorder: I::default(), + preorder: I::default(), + subtree_size: I::one(), + parent: I::MAX, + expected_hit_proportion: 0.0, + } + } +} + +// dfs traversal of an implicit bst computing dfs number, node depth, subtree +// size, and left and right pointers. +fn traverse(nodes: &mut [IntervalNode]) -> Vec> +where + T: Clone, + I: IntWithMax, +{ + let n = nodes.len(); + let mut info = vec![TraversalInfo::default(); n]; + let mut inorder = 0; + let mut preorder = 0; + traverse_recursion( + nodes, + &mut info, + 0, + n, + 0, + I::MAX, + &mut inorder, + &mut preorder, + ); + + info +} + +// The recursive part of the `traverse` function. +fn traverse_recursion( + nodes: &mut [IntervalNode], + info: &mut [TraversalInfo], + start: usize, + end: usize, + depth: u32, + parent: I, + inorder: &mut usize, + preorder: &mut usize, +) -> (I, i32, f32) +where + T: Clone, + I: IntWithMax, +{ + if start >= end { + return (I::MAX, i32::MAX, f32::NAN); + } + + let root_idx = start + (end - start) / 2; + let subtree_size = end - start; + + info[root_idx].depth = depth; + info[root_idx].preorder = I::from_usize(*preorder); + info[root_idx].parent = parent; + *preorder += 1; + + let mut subtree_first = nodes[root_idx].first; + + let mut left_expected_hits = 0.0; + let mut left_subtree_span = 0; + + if root_idx > start { + let (left, left_subtree_first, left_expected_hits_) = traverse_recursion( + nodes, + info, + start, + root_idx, + depth + 1, + I::from_usize(root_idx), + inorder, + preorder, + ); + + left_expected_hits = left_expected_hits_; + left_subtree_span = nodes[left.to_usize()].subtree_last - left_subtree_first + 1; + + subtree_first = left_subtree_first; + if nodes[left.to_usize()].subtree_last > nodes[root_idx].subtree_last { + nodes[root_idx].subtree_last = nodes[left.to_usize()].subtree_last; + } + nodes[root_idx].left = left; + } + + info[root_idx].inorder = I::from_usize(*inorder); + *inorder += 1; + + let mut right_expected_hits = 0.0; + let mut right_subtree_span = 0; + + if root_idx + 1 < end { + let (right, right_subtree_first, right_expected_hits_) = traverse_recursion( + nodes, + info, + root_idx + 1, + end, + depth + 1, + I::from_usize(root_idx), + inorder, + preorder, + ); + + right_expected_hits = right_expected_hits_; + right_subtree_span = nodes[right.to_usize()].subtree_last - right_subtree_first + 1; + + if nodes[right.to_usize()].subtree_last > nodes[root_idx].subtree_last { + nodes[root_idx].subtree_last = nodes[right.to_usize()].subtree_last; + } + nodes[root_idx].right = right; + } + + info[root_idx].subtree_size = I::from_usize(subtree_size); + let subtree_span = nodes[root_idx].subtree_last - subtree_first + 1; + + debug_assert!(left_subtree_span <= subtree_span); + debug_assert!(right_subtree_span <= subtree_span); + + let expected_hits = ((nodes[root_idx].last - nodes[root_idx].first + 1) as f32 + + (left_subtree_span as f32) * left_expected_hits + + (right_subtree_span as f32) * right_expected_hits) + / subtree_span as f32; + + info[root_idx].expected_hit_proportion = expected_hits / subtree_size as f32; + + (I::from_usize(root_idx), subtree_first, expected_hits) +} + +// norder partition by depth on pivot into three parts, like so +// [ bottom left ][ top ][ bottom right ] +// where bottom left and right are the bottom subtrees with positioned to +// the left and right of the root node +fn stable_ternary_tree_partition( + input: &[I], + output: &mut [I], + partition: &mut [i8], + info: &[TraversalInfo], + pivot_depth: u32, + pivot_dfs: I, + start: usize, + end: usize, +) -> (usize, usize) +where + I: IntWithMax, +{ + let n = end - start; + + // determine which partition each index goes in + let mut bottom_left_size = 0; + let mut top_size = 0; + let mut bottom_right_size = 0; + + for (i, p) in input[start..end].iter().zip(&mut partition[start..end]) { + let info_j = &info[i.to_usize()]; + if info_j.depth <= pivot_depth { + *p = 0; + top_size += 1; + } else if info_j.inorder < pivot_dfs { + *p = -1; + bottom_left_size += 1; + } else { + *p = 1; + bottom_right_size += 1; + } + } + + debug_assert!(bottom_left_size + top_size + bottom_right_size == n); + + // do the partition + let mut bl = start; + let mut t = bl + bottom_left_size; + let mut br = t + top_size; + for (i, p) in input[start..end].iter().zip(&partition[start..end]) { + match p.cmp(&0) { + Ordering::Less => { + output[bl] = *i; + bl += 1; + } + Ordering::Equal => { + output[t] = *i; + t += 1; + } + Ordering::Greater => { + output[br] = *i; + br += 1; + } + } + } + debug_assert!(br == end); + + (bl, t) +} + +// put nodes in van Emde Boas order +fn veb_order(mut nodes: Vec>) -> (Vec>, usize, usize) +where + T: Clone, + I: IntWithMax, +{ + // let now = Instant::now(); + let mut veb_nodes = nodes.clone(); + let n = veb_nodes.len(); + + if veb_nodes.is_empty() { + return (veb_nodes, 0, 0); + } + + nodes.sort_unstable_by_key(|node| (node.first, node.last)); + // eprintln!("sorting nodes: {}s", now.elapsed().as_millis() as f64 / 1000.0); + + // let now = Instant::now(); + let info = traverse(&mut nodes); + // eprintln!("traversing: {}s", now.elapsed().as_millis() as f64 / 1000.0); + + let mut max_depth = 0; + for info_i in &info { + if info_i.depth > max_depth { + max_depth = info_i.depth; + } + } + + let idxs: &mut [I] = &mut vec![I::default(); n]; + + idxs.iter_mut() + .enumerate() + .for_each(|(i, idx)| *idx = I::from_usize(i)); + + let tmp: &mut [I] = &mut vec![I::default(); n]; + + // put in dfs order + for i in &*idxs { + tmp[info[i.to_usize()].preorder.to_usize()] = *i; + } + let (idxs, tmp) = (tmp, idxs); + + // space used to by stable_ternary_tree_partition + let partition: &mut [i8] = &mut vec![0; n]; + + // let now = Instant::now(); + let root_idx = veb_order_recursion( + idxs, tmp, partition, &info, &mut nodes, 0, n, false, 0, max_depth, + ); + // eprintln!("computing order: {}s", now.elapsed().as_millis() as f64 / 1000.0); + + // let now = Instant::now(); + // idxs is now a vEB -> sorted order map. Build the reverse here. + let revidx = tmp; + for (i, j) in idxs.iter().enumerate() { + revidx[j.to_usize()] = I::from_usize(i); + } + + // put nodes in vEB order + for (i_, mut node) in revidx.iter().zip(nodes) { + let i = i_.to_usize(); + if node.left != node.right { + if node.left < I::MAX { + node.left = revidx[node.left.to_usize()]; + } + + if node.right < I::MAX { + node.right = revidx[node.right.to_usize()]; + } + } + veb_nodes[i.to_usize()] = node; + } + + let root_idx = revidx[root_idx.to_usize()].to_usize(); + + // eprintln!("ordering: {}s", now.elapsed().as_millis() as f64 / 1000.0); + debug_assert!(compute_tree_size(&veb_nodes, root_idx) == n); + + (veb_nodes, root_idx, max_depth as usize) +} + +// Traverse the tree and return the size, used as a basic sanity check. +fn compute_tree_size(nodes: &[IntervalNode], root_idx: usize) -> usize +where + T: Clone, + I: IntWithMax, +{ + let mut subtree_size = 1; + + let node = &nodes[root_idx]; + if node.left == node.right { + subtree_size = nodes[root_idx].right.to_usize(); + } else { + if node.left < I::MAX { + let left = node.left.to_usize(); + subtree_size += compute_tree_size(nodes, left); + } + + if node.right < I::MAX { + let right = node.right.to_usize(); + subtree_size += compute_tree_size(nodes, right); + } + } + + subtree_size +} + +// recursively reorder indexes to put it in vEB order. Called by `veb_order` +// idxs: current permutation +// tmp: temporary space of equal length to idxs +// partition: space used to assist `stable_ternary_tree_partition`. +// nodes: the interval nodes (in sorted order) +// start, end: slice within idxs to be reordered +// childless: true if this slice is a proper subtree and has no children below it +// parity: true if idxs and tmp are swapped and need to be copied back, +// min_depth, max_depth: depth extreme of the start..end slice +// +fn veb_order_recursion( + idxs: &mut [I], + tmp: &mut [I], + partition: &mut [i8], + info: &[TraversalInfo], + nodes: &mut [IntervalNode], + start: usize, + end: usize, + parity: bool, + min_depth: u32, + max_depth: u32, +) -> I +where + T: Clone, + I: IntWithMax, +{ + let n = (start..end).len(); + + assert!(n > 0); + + let childless = info[idxs[start].to_usize()].subtree_size.to_usize() == n; + + // small subtrees are put into sorted order and just searched through + // linearly. There is a little trickiness to this because we have to + // update the parent's child pointers and some other fields. + if childless + && ((info[idxs[start].to_usize()].subtree_size.to_usize() <= SIMPLE_SUBTREE_CUTOFF) + || (info[idxs[start].to_usize()].expected_hit_proportion + >= SIMPLE_SUBTREE_DENSITY_CUTOFF)) + { + debug_assert!(n == info[idxs[start].to_usize()].subtree_size.to_usize()); + + let old_root = idxs[start]; + + idxs[start..end].sort_unstable_by_key(|i| info[i.to_usize()].inorder); + let subtree_size = info[old_root.to_usize()].subtree_size; + nodes[idxs[start].to_usize()].subtree_last = nodes[old_root.to_usize()].subtree_last; + + // all children nodes record the size of the remaining list + // let mut subtree_i_size = subtree_size - i; + let mut subtree_i_size = subtree_size; + for idx in &idxs[start..end] { + nodes[idx.to_usize()].left = subtree_i_size; + nodes[idx.to_usize()].right = subtree_i_size; + subtree_i_size -= I::one(); + } + + let parent = info[old_root.to_usize()].parent; + if parent < I::MAX { + if nodes[parent.to_usize()].left == old_root { + nodes[parent.to_usize()].left = idxs[start]; + } else { + debug_assert!(nodes[parent.to_usize()].right == old_root); + nodes[parent.to_usize()].right = idxs[start]; + } + } + + if parity { + tmp[start..end].copy_from_slice(&idxs[start..end]); + } + return idxs[start]; + } + + // very small trees are already in order + if n == 1 { + if parity { + tmp[start] = idxs[start]; + } + return idxs[start]; + } + + let pivot_depth = min_depth + (max_depth - min_depth) / 2; + let pivot_dfs = info[idxs[start].to_usize()].inorder; + + let (top_start, bottom_right_start) = stable_ternary_tree_partition( + idxs, + tmp, + partition, + info, + pivot_depth, + pivot_dfs, + start, + end, + ); + + // tmp is not partitioned, so swap pointers + let (tmp, idxs) = (idxs, tmp); + + // recurse on top subtree + let top_root_idx = veb_order_recursion( + idxs, + tmp, + partition, + info, + nodes, + top_start, + bottom_right_start, + !parity, + min_depth, + pivot_depth, + ); + + // find on recurse on subtrees in the bottom left partition and bottom right partition + for (part_start, part_end) in &[(start, top_start), (bottom_right_start, end)] { + let bottom_subtree_depth = pivot_depth + 1; + let mut i = *part_start; + while i < *part_end { + debug_assert!(info[idxs[i].to_usize()].depth == bottom_subtree_depth); + + let mut subtree_max_depth = info[idxs[i].to_usize()].depth; + let mut j = *part_end; + for (u, v) in (i + 1..*part_end).zip(&idxs[i + 1..*part_end]) { + let depth = info[v.to_usize()].depth; + if depth == bottom_subtree_depth { + j = u; + break; + } else if depth > subtree_max_depth { + subtree_max_depth = depth; + } + } + + veb_order_recursion( + idxs, + tmp, + partition, + info, + nodes, + i, + j, + !parity, + bottom_subtree_depth, + subtree_max_depth, + ); + i = j; + } + } + + top_root_idx +} diff --git a/sequila/sequila-core/superintervals/test/3rd-party/intervaldb.c b/sequila/sequila-core/superintervals/test/3rd-party/intervaldb.c new file mode 100644 index 0000000..1dafbdb --- /dev/null +++ b/sequila/sequila-core/superintervals/test/3rd-party/intervaldb.c @@ -0,0 +1,658 @@ +#include + +#include "intervaldb.h" + +int C_int_max=INT_MAX; /* KLUDGE TO LET PYREX CODE ACCESS VALUE OF INT_MAX MACRO */ + +unsigned long gstart[25] = {0, 15940, 31520, 44280, 56520, 68190, 79180, 89440, + 98810, 107670, 116280, 124990, 133570, 140870, 147710, 154220, 160050, + 165420, 170580, 174400, 178550, 181530, 184770, 194650, 198160}; + +int imstart_qsort_cmp(const void *void_a,const void *void_b) +{ /* STRAIGHTFORWARD COMPARISON OF SIGNED start VALUES, LONGER INTERVALS 1ST */ + IntervalMap *a=(IntervalMap *)void_a,*b=(IntervalMap *)void_b; + if (a->startstart) + return -1; + else if (a->start>b->start) + return 1; + else if (a->end>b->end) /* SAME START: PUT LONGER INTERVAL 1ST */ + return -1; + else if (a->endend) /* CONTAINED INTERVAL SHOULD FOLLOW LARGER INTERVAL*/ + return 1; + else + return 0; +} + +#ifdef MERGE_INTERVAL_ORIENTATIONS +int im_qsort_cmp(const void *void_a,const void *void_b) +{ /* MERGE FORWARD AND REVERSE INTERVALS AS IF THEY WERE ALL IN FORWARD ORI */ + int a_start,a_end,b_start,b_end; + IntervalMap *a=(IntervalMap *)void_a,*b=(IntervalMap *)void_b; + SET_INTERVAL_POSITIVE(*a,a_start,a_end); + SET_INTERVAL_POSITIVE(*b,b_start,b_end); + if (a_startb_start) + return 1; + else if (a_end>b_end) /* SAME START: PUT LONGER INTERVAL 1ST */ + return -1; + else if (a_endsublistsublist) + return -1; + else if (a->sublist>b->sublist) + return 1; + else if (START_POSITIVE(*a) < START_POSITIVE(*b)) + return -1; + else if (START_POSITIVE(*a) > START_POSITIVE(*b)) + return 1; + else + return 0; +} + +SublistHeader *build_nested_list_inplace(IntervalMap im[],int n, + int *p_n,int *p_nlists) +{ + int i=0,parent,nlists=1,isublist=0,total=0,temp=0; + SublistHeader *subheader=NULL; + +#ifdef ALL_POSITIVE_ORIENTATION + reorient_intervals(n,im,1); /* FORCE ALL INTERVALS INTO POSITIVE ORI */ +#endif +#ifdef MERGE_INTERVAL_ORIENTATIONS + qsort(im,n,sizeof(IntervalMap),im_qsort_cmp); /* SORT BY start, CONTAINMENT */ +#else + qsort(im,n,sizeof(IntervalMap),imstart_qsort_cmp); /* SORT BY start, CONTAINMENT */ +#endif + nlists=1; + for(i=1;iEND_POSITIVE(im[i-1]) /* i NOT CONTAINED */ + || (END_POSITIVE(im[i])==END_POSITIVE(im[i-1]) /* SAME INTERVAL! */ + && START_POSITIVE(im[i])==START_POSITIVE(im[i-1])))){ + nlists++; +/* printf("%d (%d,%d) -> (%d,%d) %d\n", nlists, im[i-1].start, */ +/* im[i-1].end, im[i].start,im[i].end,i); */ + } + } + +/* printf("%d lists?!\n", nlists); */ + *p_nlists=nlists-1; + + if(nlists==1){ + *p_n=n; + //subheader = calloc(1,sizeof(SublistHeader)); + CALLOC(subheader,1,SublistHeader); /* RETURN A DUMMY ARRAY, SINCE NULL RETURN IS ERROR CODE */ + return subheader; + } + //subheader = calloc(nlists+1,sizeof(SublistHeader)); + CALLOC(subheader,nlists+1,SublistHeader); /* SUBLIST HEADER INDEX */ + + im[0].sublist=0; + subheader[0].start= -1; + subheader[0].len=1; + parent=0; + nlists=1; + isublist=1; + for(i=1;iEND_POSITIVE(im[parent]) /* i NOT CONTAINED */ + || (END_POSITIVE(im[i])==END_POSITIVE(im[parent]) /* SAME INTERVAL! */ + && START_POSITIVE(im[i])==START_POSITIVE(im[parent])))){ + subheader[isublist].start=subheader[im[parent].sublist].len-1; /* RECORD PARENT RELATIVE POSITION */ + isublist=im[parent].sublist; + parent=subheader[im[parent].sublist].start; + } + else{ + if(subheader[isublist].len==0){ + nlists++; + } + subheader[isublist].len++; + im[i].sublist=isublist; + parent=i; + isublist=nlists; + subheader[isublist].start=parent; + i++; + } + } + + while(isublist>0){ /* pop remaining stack */ + subheader[isublist].start=subheader[im[parent].sublist].len-1; /* RECORD PARENT RELATIVE POSITION */ + isublist=im[parent].sublist; + parent=subheader[im[parent].sublist].start; + } + + *p_n=subheader[0].len; + + total=0; + for(i=0;iim[i-1].sublist){ + subheader[im[i].sublist].start+=subheader[im[i-1].sublist].len; + } + } + + /* SUBHEADER.START IS NOW ABS POSITION OF PARENT */ + + qsort(im,n,sizeof(IntervalMap),sublist_qsort_cmp); + /* AT THIS POINT SUBLISTS ARE GROUPED TOGETHER, READY TO PACK */ + + isublist=0; + subheader[0].start=0; + subheader[0].len=0; + for(i=0;iisublist){ +/* printf("Entering sublist %d (%d,%d)\n", im[i].sublist, im[i].start,im[i].end); */ + isublist=im[i].sublist; + parent=subheader[isublist].start; +/* printf("Parent (%d,%d) is at %d, list start is at %d\n", */ +/* im[parent].start, im[parent].end, subheader[isublist].start,i); */ + im[parent].sublist=isublist-1; + subheader[isublist].len=0; + subheader[isublist].start=i; + } + subheader[isublist].len++; + im[i].sublist= -1; + } + + nlists--; + memmove(subheader,subheader+1,nlists*sizeof(SublistHeader)); + + return subheader; + handle_malloc_failure: + /* FREE ANY MALLOCS WE PERFORMED*/ + FREE(subheader); + return NULL; +} + + + +SublistHeader *build_nested_list(IntervalMap im[],int n, + int *p_n,int *p_nlists) +{ + int i=0,j,k,parent,nsub=0,nlists=0; + IntervalMap *imsub=NULL; + SublistHeader *subheader=NULL; + +#ifdef ALL_POSITIVE_ORIENTATION + reorient_intervals(n,im,1); /* FORCE ALL INTERVALS INTO POSITIVE ORI */ +#endif +#ifdef MERGE_INTERVAL_ORIENTATIONS + qsort(im,n,sizeof(IntervalMap),im_qsort_cmp); /* SORT BY start, CONTAINMENT */ +#else + qsort(im,n,sizeof(IntervalMap),imstart_qsort_cmp); /* SORT BY start, CONTAINMENT */ +#endif + while (i=0) { /* RECURSIVE ALGORITHM OF ALEX ALEKSEYENKO */ + if (END_POSITIVE(im[i])>END_POSITIVE(im[parent]) /* i NOT CONTAINED */ + || (END_POSITIVE(im[i])==END_POSITIVE(im[parent]) /* SAME INTERVAL! */ + && START_POSITIVE(im[i])==START_POSITIVE(im[parent]))) + parent=im[parent].sublist; /* POP RECURSIVE STACK*/ + else { /* i CONTAINED IN parent*/ + im[i].sublist=parent; /* MARK AS CONTAINED IN parent */ + nsub++; /* COUNT TOTAL #SUBLIST ENTRIES */ + parent=i; /* AND PUSH ONTO RECURSIVE STACK */ + i++; /* ADVANCE TO NEXT INTERVAL */ + } + } + } /* AT THIS POINT sublist IS EITHER -1 IF NOT IN SUBLIST, OR INDICATES parent*/ + + if (nsub>0) { /* WE HAVE SUBLISTS TO PROCESS */ + CALLOC(imsub,nsub,IntervalMap); /* TEMPORARY ARRAY FOR REPACKING SUBLISTS */ + for (i=j=0;i=0) {/* IN A SUBLIST */ + imsub[j].start=i; + imsub[j].sublist=parent; + j++; + if (im[parent].sublist<0) /* A NEW PARENT! SET HIS SUBLIST HEADER INDEX */ + im[parent].sublist=nlists++; + } + im[i].sublist= -1; /* RESET TO DEFAULT VALUE: NO SUBLIST */ + } + qsort(imsub,nsub,sizeof(IntervalMap),sublist_qsort_cmp); + /* AT THIS POINT SUBLISTS ARE GROUPED TOGETHER, READY TO PACK */ + + CALLOC(subheader,nlists,SublistHeader); /* SUBLIST HEADER INDEX */ + for (i=0;i=0) { + i=find_overlap_start(start,end,im+subheader[isub].start,subheader[isub].len); + if (i>=0) + return i+subheader[isub].start; + } + return -1; +} + + +IntervalIterator *interval_iterator_alloc(void) +{ + IntervalIterator *it=NULL; + CALLOC(it,1,IntervalIterator); + return it; + handle_malloc_failure: + return NULL; +} + +int free_interval_iterator(IntervalIterator *it) +{ + IntervalIterator *it2,*it_next; + if (!it) + return 0; + FREE_ITERATOR_STACK(it,it2,it_next); + return 0; +} + + +IntervalIterator *reset_interval_iterator(IntervalIterator *it) +{ + ITERATOR_STACK_TOP(it); + it->n=0; + return it; +} + + +void reorient_intervals(int n,IntervalMap im[],int ori_sign) +{ + int i,tmp; + for (i=0;i=0 ? 1:-1)!=ori_sign) { /* ORIENTATION MISMATCH */ + tmp=im[i].start; /* SO REVERSE THIS INTERVAL MAPPING */ + im[i].start= -im[i].end; + im[i].end = -tmp; + /* tmp=im[i].target_start; */ + /* im[i].target_start= -im[i].target_end; */ + /* im[i].target_end = -tmp; */ + } + } +} + +int find_intervals(IntervalIterator *it0,int start,int end, + IntervalMap im[],int n, + SublistHeader subheader[],int nlists, + IntervalMap buf[],int nbuf, + int *p_nreturn,IntervalIterator **it_return) +{ + IntervalIterator *it=NULL,*it2=NULL; + int ibuf=0,j,k,ori_sign=1; + if (!it0) { /* ALLOCATE AN ITERATOR IF NOT SUPPLIED*/ + CALLOC(it,1,IntervalIterator); + } + else + it=it0; + +#if defined(ALL_POSITIVE_ORIENTATION) || defined(MERGE_INTERVAL_ORIENTATIONS) + if (start<0) { /* NEED TO CONVERT TO POSITIVE ORIENTATION */ + j=start; + start= -end; + end= -j; + ori_sign = -1; + } +#endif + if (it->n == 0) { /* DEFAULT: SEARCH THE TOP NESTED LIST */ + j = subheader[0].start; + it->n=n; + it->i=find_overlap_start(start,end,im,n); + } + do { + while (it->i>=0 && it->in && HAS_OVERLAP_POSITIVE(im[it->i],start,end)) { + memcpy(buf+ibuf,im + it->i,sizeof(IntervalMap)); /*SAVE THIS HIT TO BUFFER */ + ibuf++; + k=im[it->i].sublist; /* GET SUBLIST OF i IF ANY */ + it->i++; /* ADVANCE TO NEXT INTERVAL */ + if (k>=0 && (j=find_suboverlap_start(start,end,k,im,subheader,nlists))>=0) { + PUSH_ITERATOR_STACK(it,it2,IntervalIterator); /* RECURSE TO SUBLIST */ + it2->i = j; /* START OF OVERLAPPING HITS IN THIS SUBLIST */ + it2->n = subheader[k].start+subheader[k].len; /* END OF SUBLIST */ + it=it2; /* PUSH THE ITERATOR STACK */ + } + if (ibuf>=nbuf){ /* FILLED THE BUFFER, RETURN THE RESULTS SO FAR */ + goto finally_return_result; + } + } + } while (POP_ITERATOR_STACK(it)); /* IF STACK EXHAUSTED, EXIT */ + if (!it0) /* FREE THE ITERATOR WE CREATED. NO NEED TO RETURN IT TO USER */ + free_interval_iterator(it); + it=NULL; /* ITERATOR IS EXHAUSTED */ + + finally_return_result: +#if defined(ALL_POSITIVE_ORIENTATION) || defined(MERGE_INTERVAL_ORIENTATIONS) + reorient_intervals(ibuf,buf,ori_sign); /* REORIENT INTERVALS TO MATCH QUERY ORI */ +#endif + + *p_nreturn=ibuf; /* #INTERVALS FOUND IN THIS PASS */ + *it_return=it; /* HAND BACK ITERATOR FOR CONTINUING THE SEARCH, IF ANY */ + return 0; /* SIGNAL THAT NO ERROR OCCURRED */ + handle_malloc_failure: + return -1; +} + + +int repack_subheaders(IntervalMap im[],int n,int div, + SublistHeader *subheader,int nlists) +{ + int i,j,*sub_map=NULL; + SublistHeader *sub_pack=NULL; + + CALLOC(sub_map,nlists,int); + CALLOC(sub_pack,nlists,SublistHeader); + for (i=j=0;idiv AT FRONT */ + if (subheader[i].len>div) { + memcpy(sub_pack+j,subheader+i,sizeof(SublistHeader)); + sub_map[i]=j; + j++; + } + } + for (i=0;i=0) + im[i].sublist=sub_map[im[i].sublist]; + memcpy(subheader,sub_pack,nlists*sizeof(SublistHeader)); /* SAVE REORDERED LIST*/ + + FREE(sub_map); + FREE(sub_pack); + return 0; + handle_malloc_failure: + return -1; +} + +int free_interval_dbfile(IntervalDBFile *db_file) +{ + if (db_file->ifile_idb) + fclose(db_file->ifile_idb); +#ifdef ON_DEMAND_SUBLIST_HEADER + if (db_file->subheader_file.ifile) + fclose(db_file->subheader_file.ifile); +#endif + FREE(db_file->ii); + FREE(db_file->subheader); + free(db_file); + return 0; +} + + +IntervalMap *read_intervals(int n,FILE *ifile) +{ + int i=0; + IntervalMap *im=NULL; + CALLOC(im,n,IntervalMap); // ALLOCATE THE WHOLE ARRAY + while (i 5 || lens < 4) + ichr = -1; + else if(strcmp(s1, "chrX")==0) + ichr = 22; + else if(strcmp(s1, "chrY")==0) + ichr = 23; + else if (strcmp(s1, "chrM")==0) + ichr = -1; + else{ + ichr = (int)(atoi(&s1[3])-1); + } + if(ichr>=0) + nD[ichr]++; + } + fseek(fd, 0, SEEK_SET); + //------------------------------------------------------------------------- + IntervalMap** im = malloc(24*sizeof(IntervalMap*)); + for(i=0;i<24;i++){ + im[i] = NULL; + if(nD[i]>0) + CALLOC(im[i], nD[i], IntervalMap); + nD[i]=0; + } + while (fgets(buf, 1024, fd)) { + s1 = strtok(buf, "\t"); + s2 = strtok(NULL, "\t"); + s3 = strtok(NULL, "\t"); + lens = strlen(s1); + if(lens > 5 || lens < 4) + ichr = -1; + else if(strcmp(s1, "chrX")==0) + ichr = 22; + else if(strcmp(s1, "chrY")==0) + ichr = 23; + else if (strcmp(s1, "chrM")==0) + ichr = -1; + else{ + ichr = (int)(atoi(&s1[3])-1); + } + if(ichr>=0){ + k = nD[ichr]; + im[ichr][k].start = atol(s2); + im[ichr][k].end = atol(s3); + im[ichr][k].target_id = atol(s2); + im[ichr][k].sublist = -1; + nD[ichr]++; + } + } + fclose(fd); + return im; +} + +#define LINE_LEN 1024 +//int main(int argc, char **argv) { +// if(argc!=3){ +// printf("input: data file, query file \n"); +// return 0; +// } +// clock_t start1, end1, end2; +// start1 = clock(); +// int i, j, ichr; +// char *qfile = argv[1]; +// char *dfile = argv[2]; +// char *s1, *s2, *s3; +// FILE *fp; +// char line[LINE_LEN]; +// int start, end; +// //------------------------------------------------------------------------- +// int interval_map_size = 1024; +// int* nD = calloc(24, sizeof(int)); +// IntervalMap **im = openBed24(dfile, nD); +// SublistHeader **sh = malloc(24*sizeof(SublistHeader*)); +// int **p_n = malloc(24*sizeof(int*)); +// int **p_nlists = malloc(24*sizeof(int*)); +// int **nhits = malloc(24*sizeof(int*)); +// uint64_t Total=0; +// for(i=0;i<24;i++){ +// p_n[i] = malloc(1*sizeof(int)); +// p_nlists[i] = malloc(1*sizeof(int)); +// nhits[i] = malloc(1*sizeof(int)); +// if(nD[i]>0){ +// sh[i] = build_nested_list(im[i], nD[i], p_n[i], p_nlists[i]); +// //printf("%i:\t%i\t%i\t%i\n", i, nD[i], *p_n[i], *p_nlists[i]); +// //for(j=0;j5 || strlen(s1)<4 || strcmp(s1, "chrM")==0) +// ichr = -1; +// else if(strcmp(s1, "chrX")==0) +// ichr = 22; +// else if(strcmp(s1, "chrY")==0) +// ichr = 23; +// else{ +// ichr = (int)(atoi(&s1[3])-1); +// } +// if(ichr>=0){ +// qhits = 0; +// start = atol(s2); +// end = atol(s3);//the definition is different! +// //----------------------------------------------------------------- +// it_alloc = interval_iterator_alloc(); +// it = it_alloc; +// while(it){ +// find_intervals(it, start, end, im[ichr], *p_n[ichr], sh[ichr], *p_nlists[ichr], im_buf, 1024, nhits[ichr], &it); +// //printf("nhits %d\n", *nhits[ichr]); +// qhits += *nhits[ichr]; +// //for (i = 0; i < *nhits[ichr]; i++){ +// // printf("\t%d\t%d\n", im_buf[i].start, im_buf[i].end); +// //} +// } +// free_interval_iterator(it_alloc); +// Total += qhits; +// printf("%s\t%ld\t%ld\t%i\n", s1, atol(s2), atol(s3), qhits); +// } +// } +// fclose(fp); +// end2 = clock(); +// +// //printf("Total: %lld\n", (long long)Total); +// for(i=0;i<24;i++){ +// free(p_n[i]); +// free(p_nlists[i]); +// free(nhits[i]); +// } +// free(p_n); +// free(p_nlists); +// free(nhits); +// free(nD); +//} \ No newline at end of file diff --git a/sequila/sequila-core/superintervals/test/3rd-party/intervaldb.h b/sequila/sequila-core/superintervals/test/3rd-party/intervaldb.h new file mode 100644 index 0000000..f5370d8 --- /dev/null +++ b/sequila/sequila-core/superintervals/test/3rd-party/intervaldb.h @@ -0,0 +1,196 @@ +#ifdef __cplusplus +extern "C" { +#endif + +#ifndef INTERVALDB_HEADER_INCLUDED +#define INTERVALDB_HEADER_INCLUDED 1 +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define _FILE_OFFSET_BITS 64 +#define PYGR_OFF_T off_t +#define MALLOC_FAILURE_ACTION goto handle_malloc_failure + +#define CALLOC(memptr,N,ATYPE) \ + (memptr)=(ATYPE *)calloc((size_t)(N),sizeof(ATYPE)) +#define FREE(P) if (P) {free(P);(P)=NULL;} +//#include "default.h" +#include + +extern int C_int_max; + +typedef struct { + int start; + int end; + int target_id; + /* int target_start; */ + /* int target_end; */ + int sublist; +} IntervalMap; + + +typedef struct { + int start; + int end; +} IntervalIndex; + +typedef struct { + int start; + int len; +} SublistHeader; + +typedef struct { + int n; + int ntop; + int nlists; + IntervalMap *im; + SublistHeader *subheader; +} IntervalDB; + +typedef struct { /* FOR REAL-TIME DISK ACCESS TO SUBLIST HEADER FILE*/ + SublistHeader *subheader; + int nblock; + int start; + FILE *ifile; +} SubheaderFile; + +typedef struct { + int n; + int ntop; + int nlists; + int div; + int nii; + IntervalIndex *ii; + SublistHeader *subheader; + SubheaderFile subheader_file; + FILE *ifile_idb; +} IntervalDBFile; + +typedef struct IntervalIterator_S { + int i; + int n; + int nii; + int ntop; + int i_div; + IntervalMap *im; + struct IntervalIterator_S *up; + struct IntervalIterator_S *down; +} IntervalIterator; + + +typedef struct { + FILE *ifile; + int left; + int right; + int ihead; + char *filename; +} FilePtrRecord; + +extern int imstart_qsort_cmp(const void *void_a,const void *void_b); +extern int target_qsort_cmp(const void *void_a,const void *void_b); +extern IntervalMap *read_intervals(int n,FILE *ifile); +extern SublistHeader *build_nested_list(IntervalMap im[],int n, + int *p_n,int *p_nlists); +extern SublistHeader *build_nested_list_inplace(IntervalMap im[],int n, + int *p_n,int *p_nlists); +extern IntervalMap *interval_map_alloc(int n); +extern IntervalDB *build_interval_db(IntervalMap im[],int n); +extern IntervalIterator *interval_iterator_alloc(void); +extern int free_interval_iterator(IntervalIterator *it); +extern IntervalIterator *reset_interval_iterator(IntervalIterator *it); +extern int find_intervals(IntervalIterator *it0,int start,int end,IntervalMap im[],int n,SublistHeader subheader[],int nlists,IntervalMap buf[],int nbuf,int *p_nreturn,IntervalIterator **it_return); +extern int read_imdiv(FILE *ifile,IntervalMap imdiv[],int div,int i_div,int ntop); +extern IntervalMap *read_sublist(FILE *ifile,SublistHeader *subheader,IntervalMap *im); +extern int find_file_intervals(IntervalIterator *it0,int start,int end, + IntervalIndex ii[],int nii, + SublistHeader subheader[],int nlists, + SubheaderFile *subheader_file, + int ntop,int div,FILE *ifile, + IntervalMap buf[],int nbuf, + int *p_nreturn,IntervalIterator **it_return); +extern int write_padded_binary(IntervalMap im[],int n,int div,FILE *ifile); +extern char *write_binary_files(IntervalMap im[],int n,int ntop,int div, + SublistHeader *subheader,int nlists,char filestem[]); +extern IntervalDBFile *read_binary_files(char filestem[],char err_msg[], + int subheader_nblock); +extern int free_interval_dbfile(IntervalDBFile *db_file); + +extern int save_text_file(char filestem[],char err_msg[], + char basestem[],FILE *ofile); +extern int text_file_to_binaries(FILE *infile,char buildpath[],char err_msg[]); +extern void reorient_intervals(int n,IntervalMap im[],int ori_sign); + +#define FIND_FILE_MALLOC_ERR -2 + +#define ITERATOR_STACK_TOP(it) while (it->up) it=it->up; +#define FREE_ITERATOR_STACK(it,it2,it_next) \ + for (it2=it->down;it2;it2=it_next) { \ + it_next=it2->down; \ + if (it2->im) \ + free(it2->im); \ + free(it2); \ + } \ + for (it2=it;it2;it2=it_next) { \ + it_next=it2->up; \ + if (it2->im) \ + free(it2->im); \ + free(it2); \ + } + +#define PUSH_ITERATOR_STACK(it,it2,TYPE) \ + if (it->down) \ + it2=it->down; \ + else { \ + CALLOC(it2,1,TYPE); \ + it2->up = it; \ + it->down= it2; \ + } +#define POP_ITERATOR_STACK_DONE(it) (it->up==NULL || (it=it->up)==NULL) + +#define POP_ITERATOR_STACK(it) (it->up && (it=it->up)) + + +#ifdef MERGE_INTERVAL_ORIENTATIONS +/* MACROS FOR MERGING POSITIVE AND NEGATIVE ORIENTATIONS */ +#define START_POSITIVE(IM) (((IM).start>=0) ? ((IM).start) : -((IM).end)) +#define END_POSITIVE(IM) (((IM).start>=0) ? ((IM).end) : -((IM).start)) +#define SET_INTERVAL_POSITIVE(IM,START,END) if ((IM).start>=0) {\ + START= (IM).start; \ + END= (IM).end; \ +} else { \ + START= -((IM).end); \ + END= -((IM).start); \ +} + +#define HAS_OVERLAP_POSITIVE(IM,START,END) (((IM).start>=0) ? \ + ((IM).start<(END) && (START)<(IM).end) \ + : (-((IM).end)<(END) && (START) < -((IM).start))) + /* ????? MERGE_INTERVAL_ORIENTATIONS ??????? */ + +#else +/* STANDARD MACROS */ +#define START_POSITIVE(IM) ((IM).start) +#define END_POSITIVE(IM) ((IM).end) +#define HAS_OVERLAP_POSITIVE(IM,START,END) ((IM).start<(END) && (START)<(IM).end) + +#endif + +/* STORE ALL INTERVALS IN POSITIVE SOURCE ORIENTATION */ +#define ALL_POSITIVE_ORIENTATION 1 +/* ONLY LOAD SUBLISTS INDIVIDUALLY WHEN NEEDED */ +#define ON_DEMAND_SUBLIST_HEADER 1 + +#endif + +#ifdef __cplusplus +} +#endif \ No newline at end of file diff --git a/sequila/sequila-core/superintervals/test/3rd-party/intervalstab.hpp b/sequila/sequila-core/superintervals/test/3rd-party/intervalstab.hpp new file mode 100644 index 0000000..4ad603c --- /dev/null +++ b/sequila/sequila-core/superintervals/test/3rd-party/intervalstab.hpp @@ -0,0 +1,246 @@ +/************************************************************ +Copyright (C) 2009 Jens M. Schmidt + +This program is free software; you can redistribute it and/or +modify it under the terms of the GNU General Public License +as published by the Free Software Foundation; either version 2 +or 3 of the License. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License +along with this program; if not, write to the Free Software +Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. +************************************************************/ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace intervalstab { + +// an interval +struct interval { + uint64_t l = 0; + uint64_t r = 0; + interval* leftsibling = nullptr; + interval* rightchild = nullptr; + interval* parent = nullptr; + interval* smaller = nullptr; + std::list::iterator pIt; // = (std::list::iterator)nullptr; + bool stabbed = false; + interval(void) { } + interval(const uint64_t& a, + const uint64_t& b) + : l(a), + r(b) { } + ~interval(void) { } + }; + +// lexicographic order +inline bool operator<(const interval& x,const interval& y) { + return (x.l < y.l || x.l == y.l && x.r > y.r); +} + +// equality +inline bool operator==(const interval& x,const interval& y) { + return (x.l == y.l && x.r == y.r); +} + +// lexicographic order +inline bool operator>(const interval& x,const interval& y) { + return y < x; +} + +// output stream for intervals +inline std::ostream& operator<<(std::ostream& os, const interval& a) { + os << &a << "\t" << a.l << "\t" << a.r << "\tP " << a.parent << " L " << a.leftsibling + << " C " << a.rightchild << " Sm " << a.smaller; + return os; +} +inline std::ostream& operator<<(std::ostream& os, const std::vector& a) { + for (unsigned int i = 0; i < a.size(); ++i) { + os << *a[i]; + } + return os; +} + +// fast stabbing +//template // TODO +class faststabbing +{ +private: + std::vector& a; // array of intervals [0,n-1] + uint64_t n,bigN; + std::vector > eventlist; // sweepline + std::vector stop; + interval dummy; + + void preprocessing(void) { + // sort the array + std::sort(a.begin(), a.end()); //, IntervalComp()); + // if we want to work on unique data + //a.erase(std::unique(a.begin(), a.end()), a.end()); + //((Interval*)buffer.data)+data_len, + //IntervalLess()); + // create smaller lists and event lists + uint64_t i,l,starting=-1; + for (i=0; i= a[i].r); + a[i-1].smaller = &a[i]; + } + starting = l; + } + + // sweep line + std::list L; // status list + interval* temp; + interval* last; + for (i=1; i<=bigN; ++i) { + // interval with starting point i + if (!eventlist[i].empty()) { + temp = eventlist[i].back(); + if (temp->l == i) { + L.push_back(temp); + temp->pIt = std::prev(L.end()); + eventlist[i].pop_back(); + } + } + /* + std::cerr << "sweeep " << i << ": " << eventlist[i]; + for (auto& l : L) std::cerr << " " << l; + std::cerr << std::endl; + */ + //assert(!L.empty() || eventlist[i].empty()); + if (!L.empty()) { + // compute stop[i] + stop[i] = L.back(); + // intervals with end points i + for (auto it = eventlist[i].rbegin(); it != eventlist[i].rend(); ++it) { + temp = *it; + //std::cerr << "Temp " << temp->l << " " << temp->r << std::endl; + if (temp->pIt != L.begin()) { + //std::cerr << "setting last " << *temp << std::endl; + last = *std::prev(temp->pIt); + } else last = &dummy; + //std::cerr << "\n\t\t" << last << "\t\t" << temp << std::endl; + temp->parent = last; + temp->leftsibling = last->rightchild; + last->rightchild = temp; + //temp->pIt = + //std::cerr << "L size " << L.size() << std::endl; + //if (temp->pIt != L.end()) + L.erase(temp->pIt); + //temp->pIt = std::prev(L.end()); + last = temp; + } + } + } +//#ifdef INTERVALSTAB_DEBUG + //std::cerr << "\nDummy\t\t" << &dummy << "\n" << a.size() << std::endl; +//#endi + } + +//#ifdef INTERVALSTAB_DEBUG + bool verify(std::vector output, const uint64_t& q) { +// cout << "\nQuery q=" << q << ":\n" << output; + interval* temp; + interval* last = nullptr; + while (!output.empty()) { + temp = output.back(); + output.pop_back(); + if (last && (*temp < *last)) { + std::cerr << "\nerror: interval " << temp << " not in order (not after " << last << ")\n"; + return 1; + } + temp->stabbed = true; + last = temp; + } + bool stabs; + for (uint64_t i=0; i& intervals, + const uint64_t& numberIntervals, + const uint64_t& numberDomain) + : a(intervals), eventlist(numberDomain+1), stop(numberDomain+1) { + n = numberIntervals; + bigN = numberDomain; + dummy.parent = nullptr; + dummy.leftsibling = nullptr; + dummy.rightchild = nullptr; + preprocessing(); + }; + + std::vector query(const uint64_t& q) { //, uint64_t& numComparisons) { + //assert(q >= 1 && q <= bigN+1); + std::vector output; + if (stop[q] == nullptr) return output; // no stabbed intervals + interval* i; + interval* temp; + std::deque process; + for (temp = stop[q]; temp->parent != nullptr; temp = temp->parent) { + process.push_front(temp); + } + + // traverse + while (!process.empty()) { + i = process.back(); + process.pop_back(); + //process.pop_back(); + output.push_back(i); + + temp = i->smaller; + while (temp != nullptr) { + //++numComparisons; + if (q > temp->r) break; + output.push_back(temp); +//#ifdef INTERVALSTAB_DEBUG +// cout << "\tSmaller " << (*temp); +//#endif + temp = temp->smaller; + } + + // go along rightmost path of pa + temp = i->leftsibling; + while (temp) { + //++numComparisons; + if (temp->r < q) break; + process.push_back(temp); + temp = temp->rightchild; + } + } + //assert(verify(output,q) == 0); + return output; + } +}; + + +} diff --git a/sequila/sequila-core/superintervals/test/3rd-party/khash.h b/sequila/sequila-core/superintervals/test/3rd-party/khash.h new file mode 100644 index 0000000..f75f347 --- /dev/null +++ b/sequila/sequila-core/superintervals/test/3rd-party/khash.h @@ -0,0 +1,627 @@ +/* The MIT License + + Copyright (c) 2008, 2009, 2011 by Attractive Chaos + + Permission is hereby granted, free of charge, to any person obtaining + a copy of this software and associated documentation files (the + "Software"), to deal in the Software without restriction, including + without limitation the rights to use, copy, modify, merge, publish, + distribute, sublicense, and/or sell copies of the Software, and to + permit persons to whom the Software is furnished to do so, subject to + the following conditions: + + The above copyright notice and this permission notice shall be + included in all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS + BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN + ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. +*/ + +/* + An example: + +#include "khash.h" +KHASH_MAP_INIT_INT(32, char) +int main() { + int ret, is_missing; + khiter_t k; + khash_t(32) *h = kh_init(32); + k = kh_put(32, h, 5, &ret); + kh_value(h, k) = 10; + k = kh_get(32, h, 10); + is_missing = (k == kh_end(h)); + k = kh_get(32, h, 5); + kh_del(32, h, k); + for (k = kh_begin(h); k != kh_end(h); ++k) + if (kh_exist(h, k)) kh_value(h, k) = 1; + kh_destroy(32, h); + return 0; +} +*/ + +/* + 2013-05-02 (0.2.8): + + * Use quadratic probing. When the capacity is power of 2, stepping function + i*(i+1)/2 guarantees to traverse each bucket. It is better than double + hashing on cache performance and is more robust than linear probing. + + In theory, double hashing should be more robust than quadratic probing. + However, my implementation is probably not for large hash tables, because + the second hash function is closely tied to the first hash function, + which reduce the effectiveness of double hashing. + + Reference: http://research.cs.vt.edu/AVresearch/hashing/quadratic.php + + 2011-12-29 (0.2.7): + + * Minor code clean up; no actual effect. + + 2011-09-16 (0.2.6): + + * The capacity is a power of 2. This seems to dramatically improve the + speed for simple keys. Thank Zilong Tan for the suggestion. Reference: + + - http://code.google.com/p/ulib/ + - http://nothings.org/computer/judy/ + + * Allow to optionally use linear probing which usually has better + performance for random input. Double hashing is still the default as it + is more robust to certain non-random input. + + * Added Wang's integer hash function (not used by default). This hash + function is more robust to certain non-random input. + + 2011-02-14 (0.2.5): + + * Allow to declare global functions. + + 2009-09-26 (0.2.4): + + * Improve portability + + 2008-09-19 (0.2.3): + + * Corrected the example + * Improved interfaces + + 2008-09-11 (0.2.2): + + * Improved speed a little in kh_put() + + 2008-09-10 (0.2.1): + + * Added kh_clear() + * Fixed a compiling error + + 2008-09-02 (0.2.0): + + * Changed to token concatenation which increases flexibility. + + 2008-08-31 (0.1.2): + + * Fixed a bug in kh_get(), which has not been tested previously. + + 2008-08-31 (0.1.1): + + * Added destructor +*/ + + +#ifndef __AC_KHASH_H +#define __AC_KHASH_H + +/*! + @header + + Generic hash table library. + */ + +#define AC_VERSION_KHASH_H "0.2.8" + +#include +#include +#include + +/* compiler specific configuration */ + +#if UINT_MAX == 0xffffffffu +typedef unsigned int khint32_t; +#elif ULONG_MAX == 0xffffffffu +typedef unsigned long khint32_t; +#endif + +#if ULONG_MAX == ULLONG_MAX +typedef unsigned long khint64_t; +#else +typedef unsigned long long khint64_t; +#endif + +#ifndef kh_inline +#ifdef _MSC_VER +#define kh_inline __inline +#else +#define kh_inline inline +#endif +#endif /* kh_inline */ + +#ifndef klib_unused +#if (defined __clang__ && __clang_major__ >= 3) || (defined __GNUC__ && __GNUC__ >= 3) +#define klib_unused __attribute__ ((__unused__)) +#else +#define klib_unused +#endif +#endif /* klib_unused */ + +typedef khint32_t khint_t; +typedef khint_t khiter_t; + +#define __ac_isempty(flag, i) ((flag[i>>4]>>((i&0xfU)<<1))&2) +#define __ac_isdel(flag, i) ((flag[i>>4]>>((i&0xfU)<<1))&1) +#define __ac_iseither(flag, i) ((flag[i>>4]>>((i&0xfU)<<1))&3) +#define __ac_set_isdel_false(flag, i) (flag[i>>4]&=~(1ul<<((i&0xfU)<<1))) +#define __ac_set_isempty_false(flag, i) (flag[i>>4]&=~(2ul<<((i&0xfU)<<1))) +#define __ac_set_isboth_false(flag, i) (flag[i>>4]&=~(3ul<<((i&0xfU)<<1))) +#define __ac_set_isdel_true(flag, i) (flag[i>>4]|=1ul<<((i&0xfU)<<1)) + +#define __ac_fsize(m) ((m) < 16? 1 : (m)>>4) + +#ifndef kroundup32 +#define kroundup32(x) (--(x), (x)|=(x)>>1, (x)|=(x)>>2, (x)|=(x)>>4, (x)|=(x)>>8, (x)|=(x)>>16, ++(x)) +#endif + +#ifndef kcalloc +#define kcalloc(N,Z) calloc(N,Z) +#endif +#ifndef kmalloc +#define kmalloc(Z) malloc(Z) +#endif +#ifndef krealloc +#define krealloc(P,Z) realloc(P,Z) +#endif +#ifndef kfree +#define kfree(P) free(P) +#endif + +static const double __ac_HASH_UPPER = 0.77; + +#define __KHASH_TYPE(name, khkey_t, khval_t) \ + typedef struct kh_##name##_s { \ + khint_t n_buckets, size, n_occupied, upper_bound; \ + khint32_t *flags; \ + khkey_t *keys; \ + khval_t *vals; \ + } kh_##name##_t; + +#define __KHASH_PROTOTYPES(name, khkey_t, khval_t) \ + extern kh_##name##_t *kh_init_##name(void); \ + extern void kh_destroy_##name(kh_##name##_t *h); \ + extern void kh_clear_##name(kh_##name##_t *h); \ + extern khint_t kh_get_##name(const kh_##name##_t *h, khkey_t key); \ + extern int kh_resize_##name(kh_##name##_t *h, khint_t new_n_buckets); \ + extern khint_t kh_put_##name(kh_##name##_t *h, khkey_t key, int *ret); \ + extern void kh_del_##name(kh_##name##_t *h, khint_t x); + +#define __KHASH_IMPL(name, SCOPE, khkey_t, khval_t, kh_is_map, __hash_func, __hash_equal) \ + SCOPE kh_##name##_t *kh_init_##name(void) { \ + return (kh_##name##_t*)kcalloc(1, sizeof(kh_##name##_t)); \ + } \ + SCOPE void kh_destroy_##name(kh_##name##_t *h) \ + { \ + if (h) { \ + kfree((void *)h->keys); kfree(h->flags); \ + kfree((void *)h->vals); \ + kfree(h); \ + } \ + } \ + SCOPE void kh_clear_##name(kh_##name##_t *h) \ + { \ + if (h && h->flags) { \ + memset(h->flags, 0xaa, __ac_fsize(h->n_buckets) * sizeof(khint32_t)); \ + h->size = h->n_occupied = 0; \ + } \ + } \ + SCOPE khint_t kh_get_##name(const kh_##name##_t *h, khkey_t key) \ + { \ + if (h->n_buckets) { \ + khint_t k, i, last, mask, step = 0; \ + mask = h->n_buckets - 1; \ + k = __hash_func(key); i = k & mask; \ + last = i; \ + while (!__ac_isempty(h->flags, i) && (__ac_isdel(h->flags, i) || !__hash_equal(h->keys[i], key))) { \ + i = (i + (++step)) & mask; \ + if (i == last) return h->n_buckets; \ + } \ + return __ac_iseither(h->flags, i)? h->n_buckets : i; \ + } else return 0; \ + } \ + SCOPE int kh_resize_##name(kh_##name##_t *h, khint_t new_n_buckets) \ + { /* This function uses 0.25*n_buckets bytes of working space instead of [sizeof(key_t+val_t)+.25]*n_buckets. */ \ + khint32_t *new_flags = 0; \ + khint_t j = 1; \ + { \ + kroundup32(new_n_buckets); \ + if (new_n_buckets < 4) new_n_buckets = 4; \ + if (h->size >= (khint_t)(new_n_buckets * __ac_HASH_UPPER + 0.5)) j = 0; /* requested size is too small */ \ + else { /* hash table size to be changed (shrink or expand); rehash */ \ + new_flags = (khint32_t*)kmalloc(__ac_fsize(new_n_buckets) * sizeof(khint32_t)); \ + if (!new_flags) return -1; \ + memset(new_flags, 0xaa, __ac_fsize(new_n_buckets) * sizeof(khint32_t)); \ + if (h->n_buckets < new_n_buckets) { /* expand */ \ + khkey_t *new_keys = (khkey_t*)krealloc((void *)h->keys, new_n_buckets * sizeof(khkey_t)); \ + if (!new_keys) { kfree(new_flags); return -1; } \ + h->keys = new_keys; \ + if (kh_is_map) { \ + khval_t *new_vals = (khval_t*)krealloc((void *)h->vals, new_n_buckets * sizeof(khval_t)); \ + if (!new_vals) { kfree(new_flags); return -1; } \ + h->vals = new_vals; \ + } \ + } /* otherwise shrink */ \ + } \ + } \ + if (j) { /* rehashing is needed */ \ + for (j = 0; j != h->n_buckets; ++j) { \ + if (__ac_iseither(h->flags, j) == 0) { \ + khkey_t key = h->keys[j]; \ + khval_t val; \ + khint_t new_mask; \ + new_mask = new_n_buckets - 1; \ + if (kh_is_map) val = h->vals[j]; \ + __ac_set_isdel_true(h->flags, j); \ + while (1) { /* kick-out process; sort of like in Cuckoo hashing */ \ + khint_t k, i, step = 0; \ + k = __hash_func(key); \ + i = k & new_mask; \ + while (!__ac_isempty(new_flags, i)) i = (i + (++step)) & new_mask; \ + __ac_set_isempty_false(new_flags, i); \ + if (i < h->n_buckets && __ac_iseither(h->flags, i) == 0) { /* kick out the existing element */ \ + { khkey_t tmp = h->keys[i]; h->keys[i] = key; key = tmp; } \ + if (kh_is_map) { khval_t tmp = h->vals[i]; h->vals[i] = val; val = tmp; } \ + __ac_set_isdel_true(h->flags, i); /* mark it as deleted in the old hash table */ \ + } else { /* write the element and jump out of the loop */ \ + h->keys[i] = key; \ + if (kh_is_map) h->vals[i] = val; \ + break; \ + } \ + } \ + } \ + } \ + if (h->n_buckets > new_n_buckets) { /* shrink the hash table */ \ + h->keys = (khkey_t*)krealloc((void *)h->keys, new_n_buckets * sizeof(khkey_t)); \ + if (kh_is_map) h->vals = (khval_t*)krealloc((void *)h->vals, new_n_buckets * sizeof(khval_t)); \ + } \ + kfree(h->flags); /* free the working space */ \ + h->flags = new_flags; \ + h->n_buckets = new_n_buckets; \ + h->n_occupied = h->size; \ + h->upper_bound = (khint_t)(h->n_buckets * __ac_HASH_UPPER + 0.5); \ + } \ + return 0; \ + } \ + SCOPE khint_t kh_put_##name(kh_##name##_t *h, khkey_t key, int *ret) \ + { \ + khint_t x; \ + if (h->n_occupied >= h->upper_bound) { /* update the hash table */ \ + if (h->n_buckets > (h->size<<1)) { \ + if (kh_resize_##name(h, h->n_buckets - 1) < 0) { /* clear "deleted" elements */ \ + *ret = -1; return h->n_buckets; \ + } \ + } else if (kh_resize_##name(h, h->n_buckets + 1) < 0) { /* expand the hash table */ \ + *ret = -1; return h->n_buckets; \ + } \ + } /* TODO: to implement automatically shrinking; resize() already support shrinking */ \ + { \ + khint_t k, i, site, last, mask = h->n_buckets - 1, step = 0; \ + x = site = h->n_buckets; k = __hash_func(key); i = k & mask; \ + if (__ac_isempty(h->flags, i)) x = i; /* for speed up */ \ + else { \ + last = i; \ + while (!__ac_isempty(h->flags, i) && (__ac_isdel(h->flags, i) || !__hash_equal(h->keys[i], key))) { \ + if (__ac_isdel(h->flags, i)) site = i; \ + i = (i + (++step)) & mask; \ + if (i == last) { x = site; break; } \ + } \ + if (x == h->n_buckets) { \ + if (__ac_isempty(h->flags, i) && site != h->n_buckets) x = site; \ + else x = i; \ + } \ + } \ + } \ + if (__ac_isempty(h->flags, x)) { /* not present at all */ \ + h->keys[x] = key; \ + __ac_set_isboth_false(h->flags, x); \ + ++h->size; ++h->n_occupied; \ + *ret = 1; \ + } else if (__ac_isdel(h->flags, x)) { /* deleted */ \ + h->keys[x] = key; \ + __ac_set_isboth_false(h->flags, x); \ + ++h->size; \ + *ret = 2; \ + } else *ret = 0; /* Don't touch h->keys[x] if present and not deleted */ \ + return x; \ + } \ + SCOPE void kh_del_##name(kh_##name##_t *h, khint_t x) \ + { \ + if (x != h->n_buckets && !__ac_iseither(h->flags, x)) { \ + __ac_set_isdel_true(h->flags, x); \ + --h->size; \ + } \ + } + +#define KHASH_DECLARE(name, khkey_t, khval_t) \ + __KHASH_TYPE(name, khkey_t, khval_t) \ + __KHASH_PROTOTYPES(name, khkey_t, khval_t) + +#define KHASH_INIT2(name, SCOPE, khkey_t, khval_t, kh_is_map, __hash_func, __hash_equal) \ + __KHASH_TYPE(name, khkey_t, khval_t) \ + __KHASH_IMPL(name, SCOPE, khkey_t, khval_t, kh_is_map, __hash_func, __hash_equal) + +#define KHASH_INIT(name, khkey_t, khval_t, kh_is_map, __hash_func, __hash_equal) \ + KHASH_INIT2(name, static kh_inline klib_unused, khkey_t, khval_t, kh_is_map, __hash_func, __hash_equal) + +/* --- BEGIN OF HASH FUNCTIONS --- */ + +/*! @function + @abstract Integer hash function + @param key The integer [khint32_t] + @return The hash value [khint_t] + */ +#define kh_int_hash_func(key) (khint32_t)(key) +/*! @function + @abstract Integer comparison function + */ +#define kh_int_hash_equal(a, b) ((a) == (b)) +/*! @function + @abstract 64-bit integer hash function + @param key The integer [khint64_t] + @return The hash value [khint_t] + */ +#define kh_int64_hash_func(key) (khint32_t)((key)>>33^(key)^(key)<<11) +/*! @function + @abstract 64-bit integer comparison function + */ +#define kh_int64_hash_equal(a, b) ((a) == (b)) +/*! @function + @abstract const char* hash function + @param s Pointer to a null terminated string + @return The hash value + */ +static kh_inline khint_t __ac_X31_hash_string(const char *s) +{ + khint_t h = (khint_t)*s; + if (h) for (++s ; *s; ++s) h = (h << 5) - h + (khint_t)*s; + return h; +} +/*! @function + @abstract Another interface to const char* hash function + @param key Pointer to a null terminated string [const char*] + @return The hash value [khint_t] + */ +#define kh_str_hash_func(key) __ac_X31_hash_string(key) +/*! @function + @abstract Const char* comparison function + */ +#define kh_str_hash_equal(a, b) (strcmp(a, b) == 0) + +static kh_inline khint_t __ac_Wang_hash(khint_t key) +{ + key += ~(key << 15); + key ^= (key >> 10); + key += (key << 3); + key ^= (key >> 6); + key += ~(key << 11); + key ^= (key >> 16); + return key; +} +#define kh_int_hash_func2(key) __ac_Wang_hash((khint_t)key) + +/* --- END OF HASH FUNCTIONS --- */ + +/* Other convenient macros... */ + +/*! + @abstract Type of the hash table. + @param name Name of the hash table [symbol] + */ +#define khash_t(name) kh_##name##_t + +/*! @function + @abstract Initiate a hash table. + @param name Name of the hash table [symbol] + @return Pointer to the hash table [khash_t(name)*] + */ +#define kh_init(name) kh_init_##name() + +/*! @function + @abstract Destroy a hash table. + @param name Name of the hash table [symbol] + @param h Pointer to the hash table [khash_t(name)*] + */ +#define kh_destroy(name, h) kh_destroy_##name(h) + +/*! @function + @abstract Reset a hash table without deallocating memory. + @param name Name of the hash table [symbol] + @param h Pointer to the hash table [khash_t(name)*] + */ +#define kh_clear(name, h) kh_clear_##name(h) + +/*! @function + @abstract Resize a hash table. + @param name Name of the hash table [symbol] + @param h Pointer to the hash table [khash_t(name)*] + @param s New size [khint_t] + */ +#define kh_resize(name, h, s) kh_resize_##name(h, s) + +/*! @function + @abstract Insert a key to the hash table. + @param name Name of the hash table [symbol] + @param h Pointer to the hash table [khash_t(name)*] + @param k Key [type of keys] + @param r Extra return code: -1 if the operation failed; + 0 if the key is present in the hash table; + 1 if the bucket is empty (never used); 2 if the element in + the bucket has been deleted [int*] + @return Iterator to the inserted element [khint_t] + */ +#define kh_put(name, h, k, r) kh_put_##name(h, k, r) + +/*! @function + @abstract Retrieve a key from the hash table. + @param name Name of the hash table [symbol] + @param h Pointer to the hash table [khash_t(name)*] + @param k Key [type of keys] + @return Iterator to the found element, or kh_end(h) if the element is absent [khint_t] + */ +#define kh_get(name, h, k) kh_get_##name(h, k) + +/*! @function + @abstract Remove a key from the hash table. + @param name Name of the hash table [symbol] + @param h Pointer to the hash table [khash_t(name)*] + @param k Iterator to the element to be deleted [khint_t] + */ +#define kh_del(name, h, k) kh_del_##name(h, k) + +/*! @function + @abstract Test whether a bucket contains data. + @param h Pointer to the hash table [khash_t(name)*] + @param x Iterator to the bucket [khint_t] + @return 1 if containing data; 0 otherwise [int] + */ +#define kh_exist(h, x) (!__ac_iseither((h)->flags, (x))) + +/*! @function + @abstract Get key given an iterator + @param h Pointer to the hash table [khash_t(name)*] + @param x Iterator to the bucket [khint_t] + @return Key [type of keys] + */ +#define kh_key(h, x) ((h)->keys[x]) + +/*! @function + @abstract Get value given an iterator + @param h Pointer to the hash table [khash_t(name)*] + @param x Iterator to the bucket [khint_t] + @return Value [type of values] + @discussion For hash sets, calling this results in segfault. + */ +#define kh_val(h, x) ((h)->vals[x]) + +/*! @function + @abstract Alias of kh_val() + */ +#define kh_value(h, x) ((h)->vals[x]) + +/*! @function + @abstract Get the start iterator + @param h Pointer to the hash table [khash_t(name)*] + @return The start iterator [khint_t] + */ +#define kh_begin(h) (khint_t)(0) + +/*! @function + @abstract Get the end iterator + @param h Pointer to the hash table [khash_t(name)*] + @return The end iterator [khint_t] + */ +#define kh_end(h) ((h)->n_buckets) + +/*! @function + @abstract Get the number of elements in the hash table + @param h Pointer to the hash table [khash_t(name)*] + @return Number of elements in the hash table [khint_t] + */ +#define kh_size(h) ((h)->size) + +/*! @function + @abstract Get the number of buckets in the hash table + @param h Pointer to the hash table [khash_t(name)*] + @return Number of buckets in the hash table [khint_t] + */ +#define kh_n_buckets(h) ((h)->n_buckets) + +/*! @function + @abstract Iterate over the entries in the hash table + @param h Pointer to the hash table [khash_t(name)*] + @param kvar Variable to which key will be assigned + @param vvar Variable to which value will be assigned + @param code Block of code to execute + */ +#define kh_foreach(h, kvar, vvar, code) { khint_t __i; \ + for (__i = kh_begin(h); __i != kh_end(h); ++__i) { \ + if (!kh_exist(h,__i)) continue; \ + (kvar) = kh_key(h,__i); \ + (vvar) = kh_val(h,__i); \ + code; \ + } } + +/*! @function + @abstract Iterate over the values in the hash table + @param h Pointer to the hash table [khash_t(name)*] + @param vvar Variable to which value will be assigned + @param code Block of code to execute + */ +#define kh_foreach_value(h, vvar, code) { khint_t __i; \ + for (__i = kh_begin(h); __i != kh_end(h); ++__i) { \ + if (!kh_exist(h,__i)) continue; \ + (vvar) = kh_val(h,__i); \ + code; \ + } } + +/* More convenient interfaces */ + +/*! @function + @abstract Instantiate a hash set containing integer keys + @param name Name of the hash table [symbol] + */ +#define KHASH_SET_INIT_INT(name) \ + KHASH_INIT(name, khint32_t, char, 0, kh_int_hash_func, kh_int_hash_equal) + +/*! @function + @abstract Instantiate a hash map containing integer keys + @param name Name of the hash table [symbol] + @param khval_t Type of values [type] + */ +#define KHASH_MAP_INIT_INT(name, khval_t) \ + KHASH_INIT(name, khint32_t, khval_t, 1, kh_int_hash_func, kh_int_hash_equal) + +/*! @function + @abstract Instantiate a hash set containing 64-bit integer keys + @param name Name of the hash table [symbol] + */ +#define KHASH_SET_INIT_INT64(name) \ + KHASH_INIT(name, khint64_t, char, 0, kh_int64_hash_func, kh_int64_hash_equal) + +/*! @function + @abstract Instantiate a hash map containing 64-bit integer keys + @param name Name of the hash table [symbol] + @param khval_t Type of values [type] + */ +#define KHASH_MAP_INIT_INT64(name, khval_t) \ + KHASH_INIT(name, khint64_t, khval_t, 1, kh_int64_hash_func, kh_int64_hash_equal) + +typedef const char *kh_cstr_t; +/*! @function + @abstract Instantiate a hash map containing const char* keys + @param name Name of the hash table [symbol] + */ +#define KHASH_SET_INIT_STR(name) \ + KHASH_INIT(name, kh_cstr_t, char, 0, kh_str_hash_func, kh_str_hash_equal) + +/*! @function + @abstract Instantiate a hash map containing const char* keys + @param name Name of the hash table [symbol] + @param khval_t Type of values [type] + */ +#define KHASH_MAP_INIT_STR(name, khval_t) \ + KHASH_INIT(name, kh_cstr_t, khval_t, 1, kh_str_hash_func, kh_str_hash_equal) + +#endif /* __AC_KHASH_H */ diff --git a/sequila/sequila-core/superintervals/test/Makefile b/sequila/sequila-core/superintervals/test/Makefile new file mode 100644 index 0000000..1e4cc7e --- /dev/null +++ b/sequila/sequila-core/superintervals/test/Makefile @@ -0,0 +1,30 @@ +CXXFLAGS+= -O3 -std=c++17 -march=native -mtune=native -fomit-frame-pointer -fno-exceptions -fno-rtti -flto -fno-stack-protector +LIBS= +EXE= run-tests run-cpp-libs +INCLUDE= -I../src -I./3rd-party +# GCC use -ftree-vectorize -fopt-info-vec-optimized + +# to build coitrees use: +# cd 3rd-party/coitrees; RUSTFLAGS="-Ctarget-cpu=native" cargo run --release --example bed-intersect + + +ifneq ($(asan),) + CFLAGS+=-fsanitize=address + CXXFLAGS+=-fsanitize=address + LIBS+=-fsanitize=address +endif + +all:$(EXE) + +build_clibs: + $(CC) $(CFLAGS) -c ./3rd-party/intervaldb.c -o intervaldb.o + $(CC) $(CFLAGS) -c ./3rd-party/cgranges.c -o cgranges.o + +run-tests: tests.cpp + $(CXX) $(CXXFLAGS) $(INCLUDE) $< $(LIBS) -o run-tests + +run-cpp-libs: bench.cpp build_clibs + $(CXX) $(CXXFLAGS) $(INCLUDE) $< cgranges.o intervaldb.o $(LIBS) -o $@ + +clean: + rm -fr *.o a.out *.dSYM $(EXE) diff --git a/sequila/sequila-core/superintervals/test/__init__.py b/sequila/sequila-core/superintervals/test/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sequila/sequila-core/superintervals/test/bench.cpp b/sequila/sequila-core/superintervals/test/bench.cpp new file mode 100644 index 0000000..76ac7b8 --- /dev/null +++ b/sequila/sequila-core/superintervals/test/bench.cpp @@ -0,0 +1,287 @@ +#include "intervalstab.hpp" +#include "IITree.hpp" +#include "IntervalTree.h" +extern "C" { + #include "cgranges.h" + #include "intervaldb.h" +} +#include "superintervals.hpp" + + +#include +#include +#include +#include +#include +#include +#include +#include + +using std::chrono::high_resolution_clock; +using std::chrono::duration_cast; +using std::chrono::duration; +using std::chrono::microseconds; +using std::chrono::milliseconds; + +namespace Bench { + +size_t found, index; +high_resolution_clock::time_point t0, t1; + +// Defines what is stored alongside a test interval +typedef int DType; + + +struct BedInterval { + int start; + int end; +}; + +size_t uSec(high_resolution_clock::time_point& t0) { + return duration_cast(high_resolution_clock::now() - t0).count(); +} + +void branch_factor(std::vector& ends) { + double avg = 0; + double max_count = 0; + std::vector counts(ends.size(), 0); + for (size_t i=0; i < ends.size() - 1; ++i) { + for (size_t j=i + 1; j < ends.size(); ++j) { + if (ends[j] >= ends[i]) { + break; + } + counts[j] += 1; + } + } + double sum_count = 0; + for (const auto &v : counts) { + if (v > max_count) { + max_count = (double)v; + } + sum_count += (double)v; + } + avg = sum_count / ((double)counts.size() + 0.0000001); + std::cout << "Avg branching: " << avg << " Max branching: " << max_count << std::endl; +} + +void load_intervals(const std::string& intervals_file, + const std::string& queries_file, + std::vector& intervals, + std::vector& queries) { + intervals.clear(); + queries.clear(); + std::ifstream intervals_stream(intervals_file); + std::ifstream queries_stream(queries_file); + if (!intervals_stream || !queries_stream) { + std::cerr << "Failed to open input files\n"; + std::exit(-1); + } + std::string line; + while (std::getline(intervals_stream, line)) { + std::istringstream iss(line); + std::string token; + std::getline(iss, token, '\t'); + if (token != "chr1") continue; + std::getline(iss, token, '\t'); + int start = std::stoi(token); + std::getline(iss, token, '\t'); + int end = std::stoi(token); + intervals.emplace_back(BedInterval{std::min(start, end), std::max(start, end)}); + } + while (std::getline(queries_stream, line)) { + std::istringstream iss(line); + std::string token; + std::getline(iss, token, '\t'); + if (token != "chr1") continue; + std::getline(iss, token, '\t'); + int start = std::stoi(token); + std::getline(iss, token, '\t'); + int end = std::stoi(token); + queries.emplace_back(BedInterval{std::min(start, end), std::max(start, end)}); + } +} + + +void run_IITree(std::vector& intervals, std::vector& queries) { + alignas(16) std::vector a, b; + a.reserve(10000); b.reserve(10000); + std::cout << "ImplicitITree-C++,"; + t0 = high_resolution_clock::now(); + IITree tree; + index = 0; + for (const auto& item : intervals) { + tree.add(item.start, item.end, index); + index += 1; + } + tree.index(); + std::cerr << uSec(t0) << ","; + + found = 0; + std::vector c; + t1 = high_resolution_clock::now(); + for (const auto& item : queries) { + tree.overlap(item.start, item.end, c); +// found += c.size(); + // IITree only returns indexes, need to reference actual data for a fair test + for (const auto &v: c) { + a.push_back(tree.data(v)); + } + found += a.size(); + a.clear(); + } + std::cout << uSec(t1) << "," << found << std::endl; +} + +void run_ITree(std::vector& intervals, std::vector& queries) { + alignas(16) std::vector a, b; + a.reserve(10000); b.reserve(10000); + std::cout << "IntervalTree-C++,"; + std::vector> intervals2; + index = 0; + for (const auto& item : intervals) { + intervals2.push_back(interval_tree::Interval(item.start, item.end - 1, index)); + index += 1; + } + t0 = high_resolution_clock::now(); + interval_tree::IntervalTree tree2(std::move(intervals2)); + std::cerr << uSec(t0) << ","; + + found = 0; + t1 = high_resolution_clock::now(); + for (const auto& item : queries) { + std::vector> result = tree2.findOverlapping(item.start, item.end - 1); + found += result.size(); + } + std::cerr << uSec(t1) << "," << found << std::endl; +} + +void run_NCLS(std::vector& intervals, std::vector& queries) { + alignas(16) std::vector a, b; + a.reserve(10000); b.reserve(10000); + std::cout << "NCLS-C,"; + t0 = high_resolution_clock::now(); + + int nD = (int)intervals.size(); + int* p_n = (int*)malloc(sizeof(int)); + int* p_nlists = (int*)malloc(sizeof(int)); + int* nhits = (int*)malloc(sizeof(int)); + IntervalMap* im = (IntervalMap*)calloc(1, sizeof(IntervalMap)); + CALLOC(im, nD, IntervalMap); + index = 0; + for (const auto& item : intervals) { + im[index].start = (int)item.start; + im[index].end = (int)item.end; + im[index].target_id = (int)index; + im[index].sublist = -1; + index += 1; + } + SublistHeader* sh = build_nested_list(im, nD, p_n, p_nlists); + std::cerr << uSec(t0) << ","; + + t1 = high_resolution_clock::now(); + IntervalIterator *it; +// IntervalIterator *it_alloc; + IntervalIterator *it_alloc = interval_iterator_alloc(); + IntervalMap im_buf[50000]; + found = 0; + for (const auto& item : queries) { +// it_alloc = interval_iterator_alloc(); + reset_interval_iterator(it_alloc); + it = it_alloc; + while(it){ + find_intervals(it, item.start, item.end, im, *p_n, sh, *p_nlists, im_buf, 50000, nhits, &it); + found += *nhits; + } +// free_interval_iterator(it_alloc); + } + std::cerr << uSec(t1) << "," << found << std::endl; +} + +void run_SuperIntervals(std::vector& intervals, std::vector& queries, + si::IntervalMap &itv, std::string name + ) { + alignas(16) std::vector a, b; + a.reserve(10000); b.reserve(10000); + + std::cout << name << ","; + index = 0; + t0 = high_resolution_clock::now(); + for (const auto& item : intervals) { + itv.add(item.start, item.end - 1, index); + index += 1; + } + itv.build(); +// branch_factor(itv.ends); + std::cout << uSec(t0) << ","; // construct + + found = 0; + t1 = high_resolution_clock::now(); + for (const auto& item : queries) { + itv.search_values(item.start, item.end - 1, a); + found += a.size(); + a.clear(); +// for (size_t idx: itv.search_idxs(item.start, item.end - 1)) { +// found += 1; +// } + } + std::cerr << uSec(t1) << "," << found << ","; // find all overlapping + + found = 0; + t1 = high_resolution_clock::now(); + for (const auto& item : queries) { + itv.search_values_large(item.start, item.end - 1, a); + found += a.size(); + a.clear(); + } + std::cerr << uSec(t1) << "," << found << ","; // find all overlapping large + + found = 0; + t1 = high_resolution_clock::now(); + for (const auto& item : queries) { + found += itv.count(item.start, item.end - 1); + } + std::cerr << uSec(t1) << "," << found << ","; // count all overlapping + + found = 0; + t1 = high_resolution_clock::now(); + for (const auto& item : queries) { + found += itv.count_large(item.start, item.end - 1); + } + std::cerr << uSec(t1) << "," << found << "\n"; // count all overlapping large + +} + +void run_tools(std::vector& intervals, std::vector& queries) { + + + + run_ITree(intervals, queries); + + run_IITree(intervals, queries); + + run_NCLS(intervals, queries); + + //auto itv = SuperIntervals(); + auto itv = si::IntervalMap(); + run_SuperIntervals(intervals, queries, itv, "SuperIntervals-C++"); + +// auto itv2 = si::IntervalMapEytz(); +// run_SuperIntervals(intervals, queries, itv2, "SuperIntervalsEytz-C++"); + +} + +} // namespace Bench + +int main(int argc, char *argv[]) { + if (argc < 3) { + printf("Usage: run-cpp-libs \n"); + return -1; + } + std::vector intervals; + std::vector queries; + + Bench::load_intervals(std::string(argv[1]), argv[2], intervals, queries); + Bench::run_tools(intervals, queries); + + return 0; +} diff --git a/sequila/sequila-core/superintervals/test/generate_test_intervals.py b/sequila/sequila-core/superintervals/test/generate_test_intervals.py new file mode 100644 index 0000000..83cea10 --- /dev/null +++ b/sequila/sequila-core/superintervals/test/generate_test_intervals.py @@ -0,0 +1,81 @@ +""" +This script required bedtools to be available on your PATH. +Run this script from within this test directory +""" + +from subprocess import run +import numpy as np +from scipy import stats +import random +import os + + +g = "genome.txt" +out = "bench_data" +seed = 0 + +if not os.path.exists(out): + os.mkdir(out) + + +def generate_lognormal(mean, variance, size): + # Calculate mu and sigma + mu = np.log(mean ** 2 / np.sqrt(variance + mean ** 2)) + sigma = np.sqrt(np.log(1 + variance / mean ** 2)) + + # Generate log-normal distributed random variables + return stats.lognorm(sigma, loc=mean, scale=np.exp(mu)).rvs(size=size) + + +def existing_overlaps(target_queries): + """All queries overlap at least one reference interval""" + rand_sizes_q = generate_lognormal(1000, 5000, target_queries).astype(int) + rand_sizes_r = generate_lognormal(1000, 5000, target_queries).astype(int) + rand_pos = [random.randint(1, 250_000_000) for _ in range(target_queries)] + with open(f"{out}/REF_query_overlaps_all_ref.{target_queries}.bed", "w") as out_ref, \ + open(f"{out}/QUERY_query_overlaps_all_ref.{target_queries}.bed", "w") as out_query: + for sq, sr, pos in zip(rand_sizes_q, rand_sizes_r, rand_pos): + end = pos + sq + out_query.write(f"chr1\t{pos}\t{end}\n") + pos_ref = random.randint(max(1, pos - sr - 1), end - 1) + out_ref.write(f"chr1\t{pos_ref}\t{pos_ref + sr}\n") + + run(f"bedtools sort -i {out}/REF_query_overlaps_all_ref.{target_queries}.bed > {out}/REF_query_overlaps_all_ref.{target_queries}.srt.bed", shell=True) + run(f"bedtools sort -i {out}/QUERY_query_overlaps_all_ref.{target_queries}.bed > {out}/QUERY_query_overlaps_all_ref.{target_queries}.srt.bed", shell=True) + + print(f"Target queries {target_queries}, unsorted, then sorted:") + run(f"./run-cpp-libs {out}/REF_query_overlaps_all_ref.{target_queries}.bed {out}/QUERY_query_overlaps_all_ref.{target_queries}.bed", shell=True) + run(f"./run-cpp-libs {out}/REF_query_overlaps_all_ref.{target_queries}.srt.bed {out}/QUERY_query_overlaps_all_ref.{target_queries}.srt.bed", shell=True) + + +existing_overlaps(1_000_000) + + +def non_existing_overlaps(target_queries): + """All queries overlap zero reference intervals""" + rand_sizes_q = [1] * target_queries*2 + rand_sizes_r = [1] * target_queries*2 + rand_pos = [random.randint(1, 250_000_000) for _ in range(target_queries)] + with open(f"{out}/REF_query_overlaps_0_ref.{target_queries}.tmp.bed", "w") as out_ref, \ + open(f"{out}/QUERY_query_overlaps_0_ref.{target_queries}.bed", "w") as out_query: + for sq, sr, pos in zip(rand_sizes_q, rand_sizes_r, rand_pos): + end = pos + sq + out_query.write(f"chr1\t{pos}\t{end}\n") + pos_ref = end + 1 + out_ref.write(f"chr1\t{pos_ref}\t{pos_ref + sr}\n") + + run(f"bedtools subtract -a {out}/REF_query_overlaps_0_ref.{target_queries}.tmp.bed -b {out}/QUERY_query_overlaps_0_ref.{target_queries}.bed > {out}/REF_query_overlaps_0_ref.{target_queries}.bed", shell=True) + run(f"rm {out}/REF_query_overlaps_0_ref.{target_queries}.tmp.bed", shell=True) + + run(f"bedtools sort -i {out}/REF_query_overlaps_0_ref.{target_queries}.bed > {out}/REF_query_overlaps_0_ref.{target_queries}.srt.bed", + shell=True) + run(f"bedtools sort -i {out}/QUERY_query_overlaps_0_ref.{target_queries}.bed > {out}/QUERY_query_overlaps_0_ref.{target_queries}.srt.bed", + shell=True) + + print(f"Target queries {target_queries}, unsorted, then sorted:") + run(f"./run-cpp-libs {out}/REF_query_overlaps_0_ref.{target_queries}.bed {out}/QUERY_query_overlaps_0_ref.{target_queries}.bed", + shell=True) + run(f"./run-cpp-libs {out}/REF_query_overlaps_0_ref.{target_queries}.srt.bed {out}/QUERY_query_overlaps_0_ref.{target_queries}.srt.bed", + shell=True) + +# non_existing_overlaps(1_000_000) \ No newline at end of file diff --git a/sequila/sequila-core/superintervals/test/genome.txt b/sequila/sequila-core/superintervals/test/genome.txt new file mode 100644 index 0000000..34fa9ff --- /dev/null +++ b/sequila/sequila-core/superintervals/test/genome.txt @@ -0,0 +1 @@ +chr1 250000000 diff --git a/sequila/sequila-core/superintervals/test/run-py-libs.py b/sequila/sequila-core/superintervals/test/run-py-libs.py new file mode 100644 index 0000000..7662ecc --- /dev/null +++ b/sequila/sequila-core/superintervals/test/run-py-libs.py @@ -0,0 +1,168 @@ +import time +import quicksect +from quicksect import Interval +from ncls import NCLS +import cgranges as cr + +from superintervals import IntervalMap + +import numpy as np +import pandas as pd +import sys + + +def load_intervals(intervals_path, queries_path): + queries = [] + intervals = [] + with open(intervals_path, "r") as f: + for line in f: + l = line.split("\t") + intervals.append((int(l[1]), int(l[2]))) + with open(queries_path, "r") as f: + for line in f: + l = line.split("\t") + queries.append((int(l[1]), int(l[2]))) + return np.array(intervals), np.array(queries) + + +def to_micro(t0): + return int((time.time() - t0) * 1000000) + + +def run_tools(intervals, queries): + # superintervals - Updated for new API + t0 = time.time() + sitv = IntervalMap() + for i, (s, e) in enumerate(intervals): + sitv.add(s, e, i) # Using index as data + sitv.build() + build_time = to_micro(t0) + print(f"SuperIntervals-py,{build_time},", end='') + + # Test search_values (returns data objects) + t0 = time.time() + v = 0 + for start, end in queries: + a = sitv.search_values(start, end) + v += len(a) + search_time = to_micro(t0) + print(f'{search_time},{v},', end='') + + # Test count method + t0 = time.time() + v = 0 + for start, end in queries: + v += sitv.count(start, end) + count_time = to_micro(t0) + print(f'{count_time},{v}') + + # quicksect + t0 = time.time() + tree = quicksect.IntervalTree() + for s, e in intervals: + tree.add(s, e) + quicksect_build = to_micro(t0) + print(f"Quicksect,{quicksect_build},", end='') + + t0 = time.time() + v = 0 + for start, end in queries: + a = tree.find(Interval(start, end)) + v += len(a) + quicksect_search = to_micro(t0) + print(f'{quicksect_search},{v}') + + # cgranges + t0 = time.time() + cg = cr.cgranges() + for s, e in intervals: + cg.add("1", s, e + 1, 0) + cg.index() + cgranges_build = to_micro(t0) + print(f"Cgranges-py,{cgranges_build},", end='') + + t0 = time.time() + v = 0 + for start, end in queries: + a = list(cg.overlap("1", start, end + 1)) + v += len(a) + cgranges_search = to_micro(t0) + print(f'{cgranges_search},{v}') + + # ncls + t0 = time.time() + starts = pd.Series(intervals[:, 0]) + ends = pd.Series(intervals[:, 1]) + treencls = NCLS(starts, ends, starts) + ncls_build = to_micro(t0) + print(f"NCLS-py,{ncls_build},", end='') + + t0 = time.time() + v = 0 + for start, end in queries: + a = list(treencls.find_overlap(start - 1, end + 1)) + v += len(a) + ncls_search = to_micro(t0) + print(f'{ncls_search},{v}') + + +def run_superintervals_detailed_benchmark(intervals, queries): + """ + Run a more detailed benchmark of superintervals methods + """ + print("\n=== SuperIntervals Detailed Benchmark ===") + + # Setup + sitv = IntervalMap() + for i, (s, e) in enumerate(intervals): + sitv.add(s, e, f"data_{i}") + + t0 = time.time() + sitv.build() + print(f"Index time: {to_micro(t0)} microseconds") + + # Test different search methods + methods = [ + ("search_values", lambda start, end: len(sitv.search_values(start, end))), + ("search_items", lambda start, end: len(sitv.search_items(start, end))), + ("search_keys", lambda start, end: len(sitv.search_keys(start, end))), + ("search_idxs", lambda start, end: len(sitv.search_idxs(start, end))), + ("count", lambda start, end: sitv.count(start, end)), + ("has_overlaps", lambda start, end: int(sitv.has_overlaps(start, end))), + ] + + for method_name, method_func in methods: + t0 = time.time() + total_results = 0 + for start, end in queries: + total_results += method_func(start, end) + elapsed = to_micro(t0) + print(f"{method_name}: {elapsed} microseconds, {total_results} total results") + + +if __name__ == "__main__": + if len(sys.argv) != 3: + print("Usage: python benchmark.py ") + print("Both files should be in BED format (tab-separated: chr start end ...)") + sys.exit(1) + + intervals_file = sys.argv[1] + queries_file = sys.argv[2] + + print("Tool,BuildTime(μs),SearchTime(μs),ResultCount") + + try: + intervals, queries = load_intervals(intervals_file, queries_file) + print(f"# Testing with {len(intervals)} intervals and {len(queries)} queries", file=sys.stderr) + + run_tools(intervals, queries) + + # Run detailed superintervals benchmark + run_superintervals_detailed_benchmark(intervals, queries) + + except FileNotFoundError as e: + print(f"Error: Could not find file - {e}", file=sys.stderr) + sys.exit(1) + except Exception as e: + print(f"Error: {e}", file=sys.stderr) + sys.exit(1) \ No newline at end of file diff --git a/sequila/sequila-core/superintervals/test/run_benchmark_all.sh b/sequila/sequila-core/superintervals/test/run_benchmark_all.sh new file mode 100755 index 0000000..da90227 --- /dev/null +++ b/sequila/sequila-core/superintervals/test/run_benchmark_all.sh @@ -0,0 +1,104 @@ +#!/bin/bash +# +# Overview +# --------- +# SuperIntervals (SI) was compared with: +# 1. Coitrees (Rust: https://github.com/dcjones/coitrees) +# 2. Implicit Interval Tree (C++: https://github.com/lh3/cgranges) +# 3. Interval Tree (C++: https://github.com/ekg/intervaltree) +# 4. Nested Containment List (C: https://github.com/pyranges/ncls/tree/master/ncls/src) +# +# Datasets: +# 1. Random regions generated using bedtools +# 2. RNA-seq reads and annotations from cgranges repository +# 3. ONT reads from sample PAO33946 (chr1 + chrM) +# 4. Paired-end reads from sample DB53, NCBI BioProject PRJNA417592, (chr1 + chrM) +# 5. ucsc genes for hg19 +# +# Note, programs only assess chr1 bed records - other chromosomes are ignored. For 'chrM' records, +# the M was replaced with 1 using sed. +# +# Data were assessed in position sorted and random order. Datasets can be found on the Releases page. +# +# For Coitrees, a benchmark program is available at: +# https://github.com/kcleal/superintervals/tree/main/test/3rd-party/coitrees/examples +# !!! This will need to be build before hand + +# Running: +# -------- +# Run this script from the top-level superintervals repo +# git clone https://github.com/kcleal/superintervals +# cd superintervals +# +# Fetch data using: +# wget https://github.com/kcleal/superintervals/releases/download/v0.2.0/data.tar.gz +# tar -xvf data.tar.gz +# +# Run this script once: +# bash test/run_tools.sh +# +# To run a few times and store results: +# mkdir -p results +# for r in {1..5}; do bash run_tools.sh 2>&1 | tee results/run${r}.txt done + + +run_libs=test/run-cpp-libs +rust_coitrees=test/3rd-party/coitrees/target/release/examples/bed-intersect +rust_superintervals=target/release/examples/bed-intersect-si + + +run_benchmarks() { + local ref="$1" + local query="$2" + local label1="$3" + local label2="$4" + + echo "$label1, $label2" + $run_libs "$ref" "$query" + $rust_superintervals "$ref" "$query" + $rust_coitrees "$ref" "$query" + $rust_coitrees -s "$ref" "$query" + echo +} + +# Prepare and run benchmark sets +prepare_and_run() { + local ref="$1" + local query="$2" + local name1="$3" + local name2="$4" + + # SORTED + grep -P '^chr1\t' "$ref" | sort -k1,1 -k2,2n > ref + grep -P '^chr1\t' "$query" | sort -k1,1 -k2,2n > query + echo "SORTED" + run_benchmarks ref query "$name1" "$name2" + + # SHUFFLED + grep -P '^chr1\t' "$ref" | shuf > ref + grep -P '^chr1\t' "$query" | shuf > query + echo "SHUFFLED" + run_benchmarks ref query "$name1" "$name2" + + echo "----------------------------------" +} + +# random regions from bedtools random, baseline +# prepare_and_run data/l1000_n1000.b.bed data/l1000_n1000.b.bed "rand-b" "rand-a" + +# chrM reads from DB53 (renamed as chr1) +prepare_and_run data/DB53.chrM_reads_as_chr1.bed data/DB53.chrM_reads_as_chr1.bed "mito-b" "mito-a" +#prepare_and_run data/PAO33946.chrM_reads_as_chr1.bed data/PAO33946.chrM_reads_as_chr1.bed "mito-lr-b" "mito-lr-a" + +# RNA anno from cgranges releases page +prepare_and_run data/ex-rna.bed data/ex-anno.bed "rna" "anno" +prepare_and_run data/ex-anno.bed data/ex-rna.bed "anno" "rna" + +# genes vs reads +prepare_and_run data/ucsc.hg19.genes.bed data/chr1.DB53.bed "genes" "DB53 reads" +prepare_and_run data/chr1.DB53.bed data/ucsc.hg19.genes.bed "DB53 reads" "genes" + +# long reads vs short reads +prepare_and_run data/PAO33946.chr1.bed data/chr1.DB53.bed "ONT reads" "DB53 reads" +prepare_and_run data/chr1.DB53.bed data/PAO33946.chr1.bed "DB53 reads" "ONT reads" + diff --git a/sequila/sequila-core/superintervals/test/tests.cpp b/sequila/sequila-core/superintervals/test/tests.cpp new file mode 100644 index 0000000..6be53c9 --- /dev/null +++ b/sequila/sequila-core/superintervals/test/tests.cpp @@ -0,0 +1,267 @@ + +#include "superintervals.hpp" +#include +#include +#include +#include + + +void superTests(si::IntervalMap &itv, std::string name) { + std::vector a; + std::vector a_idxs; + std::vector> a_keys; + std::pair cov_res; + + assert(!itv.has_overlaps(0, 1)); + + std::cout << "\n" << name << " tests \n"; + + std::cout << "0, "; + itv.add(10, 20, 0); + itv.add(11, 12, -1); + itv.add(13, 14, -1); + itv.add(15, 16, -1); + itv.add(25, 29, 4); + itv.build(); + assert(itv.has_overlaps(17, 30)); + assert(!itv.has_overlaps(1, 3)); + itv.search_values(17, 30, a); + assert (a[0] == 4); assert (a[1] == 0); + itv.coverage(10, 29, cov_res); + assert (cov_res.second == 17); + itv.search_idxs(17, 30, a_idxs); + assert (a_idxs[0] == 4); assert (a_idxs[1] == 0); + itv.search_keys(17, 30, a_keys); + assert (a_keys[0].first == 25); assert (a_keys[1].second == 20); + itv.clear(); a.clear(); cov_res = {0, 0}; a_idxs.clear(); a_keys.clear(); + + std::cout << "1, "; + itv.add(1, 2, 0); + itv.add(3, 8, -1); + itv.add(5, 7, -1); + itv.add(7, 20, 3); // i=3 + itv.add(9, 10, -1); + itv.add(13, 15, -1); + itv.add(15, 16, -1); + itv.add(19, 30, 7); // i=7 + itv.add(22, 24, -1); + itv.add(24, 25, -1); + itv.add(26, 28, -1); /// i=10 + itv.add(32, 39, -1); + itv.add(34, 36, -1); + itv.add(38, 40, -1); + itv.build(); + itv.search_values(17, 21, a); + assert (a[0] == 7); assert (a[1] == 3); + itv.coverage(17, 21, cov_res); + assert (cov_res.second == 5); + itv.clear(); a.clear(); cov_res = {0, 0}; + + std::cout << "2, "; + itv.add(0, 250000, 0); // + itv.add(55, 1055, -1); + itv.add(115, 1115, -1); + itv.add(130, 1130, -1); + itv.add(281, 1281, -1); + itv.add(639, 1639, -1); // - 5 + itv.add(842, 1842, -1); + itv.add(999, 1999, -1); + itv.add(1094, 2094, -1); + itv.add(1157, 2157, -1); + itv.add(1161, 2161, -1); + itv.add(1265, 2265, -1); // 11 + itv.add(1532, 2532, -1); + itv.add(1590, 2590, -1); + itv.add(1665, 2665, -1); + itv.add(1945, 2945, -1); // ^ 15 + itv.add(2384, 3384, -1); + itv.add(2515, 3515, -1); + itv.build(); + itv.search_values(1377, 2377, a); + + assert (a.back() == 0 && a.size() == 12); + itv.clear(); a.clear(); + + std::cout << "3, "; + itv.add(0, 400, 0); + itv.add(2, 10, 1); + itv.add(4, 6, 2); // i=2 + itv.add(6, 7, 3); + itv.add(9, 20, 11); + itv.add(15, 70, 22); + itv.add(19, 30, 33); + itv.add(29, 40, 44); + itv.add(39, 50, 55); + itv.add(49, 60, 66); + itv.add(58, 59, 77); + itv.build(); + itv.search_values(1, 5, a); + assert (a.back() == 0 && a.size() == 3); + itv.clear(); a.clear(); + + std::cout << "4, "; + itv.add(1, 6100000, 0); + itv.add(4, 5, 6); + itv.add(6, 7, 7); + itv.add(9, 10, 7); + itv.add(11, 12, 7); + itv.build(); + itv.search_values(2, 25, a); + assert (a.back() == 0 && a.size() == 5); + itv.clear(); a.clear(); + + std::cout << "5, "; + itv.add(1, 100, 0); + itv.add(30, 200, 7); + itv.add(40, 50, 6); + itv.add(60, 70, 7); + itv.build(); + itv.search_values(55, 65, a); + assert (a.back() == 0 && a.size() == 3); + itv.coverage(55, 65, cov_res); + assert (cov_res.second == 25); + itv.clear(); a.clear(); cov_res = {0, 0}; + + std::cout << "6, "; + itv.add(10, 1001, 0); + itv.add(30, 400, 1); + itv.add(60, 700, 2); + itv.add(65, 80, 3); + itv.build(); + itv.search_values(100, 200, a); + assert (a.back() == 0 && a.size() == 3); + itv.clear(); a.clear(); + + std::cout << "7, "; + itv.add(3, 40, 0); + itv.add(10, 30, 5); + itv.add(20, 25, 5); + itv.add(22, 24, 6); + itv.build(); + itv.search_values(31, 32, a); + assert (a.back() == 0 && a.size() == 1); + itv.clear(); a.clear(); + + std::cout << "8, "; + itv.add(3, 40, 0); + itv.add(4, 5, 4); + itv.add(6, 7, 4); + itv.add(10, 31, 5); + itv.add(31, 32, 5); + itv.build(); + itv.search_values(31, 32, a); assert (a.back() == 0 && a.size() == 3); a.clear(); + itv.search_values(10, 11, a); assert (a.back() == 0 && a.size() == 2); a.clear(); + itv.search_values(8, 40, a); assert (a.back() == 0 && a.size() == 3); a.clear(); + itv.search_values(4, 7, a); assert (a.back() == 0 && a.size() == 3); a.clear(); + itv.clear(); + + std::cout << "9, "; + itv.add(3, 40, 0); + itv.add(3, 40, 4); + itv.add(3, 40, 4); + itv.add(3, 4, 4); + itv.add(35, 50, 4); + itv.add(40, 400, 5); + itv.add(40, 400, 4); + itv.build(); + itv.search_values(38, 41, a); assert (a.back() == 0 && a.size() == 6); a.clear(); + itv.search_values(41, 42, a); assert (a.back() == 4 && a.size() == 3); a.clear(); + itv.search_values(339, 410, a); assert (a.back() == 5 && a.size() == 2); a.clear(); + itv.clear(); + + std::cout << "10, "; + itv.add(0, 100, 0); + itv.add(10, 110, 2); + itv.add(10, 20, 1); + itv.add(30, 40, 3); + itv.add(35, 135, 4); + itv.add(45, 55, 5); + itv.add(60, 160, 6); + itv.add(70, 80, 7); + itv.add(90, 190, 8); + itv.add(110, 120, 9); + itv.add(130, 140, 10); + itv.add(150, 250, 11); + itv.build(); + itv.search_values(95, 105, a); assert (a.front() == 8 && a.size() == 5); + itv.clear(); a.clear(); + + std::cout << "11, "; + itv.add(1, 100, 7); + itv.add(10, 110, 8); + itv.add(15, 16, 9); + itv.add(30, 130, 11); + itv.add(50, 60, 4); + itv.add(100, 200, 11); + itv.build(); + itv.search_values(20, 90, a); + + assert (a.front() == 4 && a.size() == 4); + itv.clear(); a.clear(); + + std::cout << "12, "; + itv.add(1, 10, 1); + itv.build(); + size_t count = itv.count(1, 5); + assert (count == 1); + itv.clear(); a.clear(); + + std::cout << "13, "; + itv.build(); + count = itv.count(1, 5); + assert (count == 0); + itv.clear(); a.clear(); + + std::cout << "14, "; + itv.add(1, 10, 0); + itv.build(); + auto iter = itv.search_idxs(5, 11); + count = 0; + int last_d{}; + for (const auto& i : iter) { + last_d = itv.data[i]; + count += 1; + } + assert (count == 1); assert (last_d == 0); + itv.clear(); a.clear(); + +// std::cout << "15, "; +// itv.add(10, 11, -1); +// itv.add(1, 100, 1); +// itv.add(1, 1000, 2); +// itv.build(); +// count = 0; +// for (const auto& i : itv.search_idxs(5, 11)) { +// last_d = itv.data[i]; +// count += 1; +// } +// assert (count == 3); assert (last_d == 2); +// a.clear(); +// bool any = false; +// for (auto i : itv.search_items(0, 0)) { any = true; } +// assert (!any); +// for (auto i : itv.search_items(2000, 2001)) { any = true; } +// assert (!any); +// itv.clear(); a.clear(); +// +// std::cout << "16\n"; +// itv.add(1, 10, 0); +// itv.build(); +// count = 0; +// for (const auto& i : itv.search_idxs(11, 12)) { +// last_d = itv.data[i]; +// count += 1; +// } +// assert (count == 0); +// itv.clear(); a.clear(); + + std::cout << "All tests passed for " << name << "\n"; + +} + + +int main(int argc, char *argv[]) { + auto itv = si::IntervalMap(); + superTests(itv, "SuperIntervals"); + return 0; +} diff --git a/sequila/sequila-core/tests/integration_test.rs b/sequila/sequila-core/tests/integration_test.rs index e0244e5..81e07c1 100644 --- a/sequila/sequila-core/tests/integration_test.rs +++ b/sequila/sequila-core/tests/integration_test.rs @@ -71,6 +71,7 @@ fn expected_equi() -> [&'static str; 20] { #[case::interval_join_interval_tree(Some(Algorithm::IntervalTree))] #[case::interval_join_array_interval_tree(Some(Algorithm::ArrayIntervalTree))] #[case::interval_join_lapper(Some(Algorithm::Lapper))] +#[case::interval_join_superintervals(Some(Algorithm::SuperIntervals))] async fn test_equi_and_range_condition( #[case] algorithm: Option, ctx: SessionContext, @@ -166,6 +167,7 @@ fn expected_range() -> [&'static str; 36] { #[case::interval_join_interval_tree(Some(Algorithm::IntervalTree))] #[case::interval_join_array_interval_tree(Some(Algorithm::ArrayIntervalTree))] #[case::interval_join_lapper(Some(Algorithm::Lapper))] +#[case::interval_join_superintervals(Some(Algorithm::SuperIntervals))] async fn test_range_condition( #[case] algorithm: Option, ctx: SessionContext,