Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
245 commits
Select commit Hold shift + click to select a range
a024f61
jit lto interleaved scan
divyegala Oct 2, 2025
45da4aa
fix dependencies.yaml
divyegala Oct 2, 2025
a7c8621
generate files at build time, use tags to avoid compilation of types
divyegala Oct 4, 2025
eb2d74b
passing tests
divyegala Oct 5, 2025
d2318e8
update gitignore
divyegala Oct 6, 2025
5e6afcd
separate out distance function from main kernel
divyegala Oct 6, 2025
6eee4da
fix deps
divyegala Oct 6, 2025
1de8f28
add filters as jit device functions, rework caching logic
divyegala Oct 7, 2025
84c6020
lto post lambda, cleanup files, generate cmake in build dir
divyegala Oct 7, 2025
22680c8
don't read hardcoded kernels, use generator properly
divyegala Oct 8, 2025
37f1163
random cmake changes carried over from 25.10
divyegala Oct 8, 2025
0ae5383
cmake format
divyegala Oct 8, 2025
fe56aec
remove dep on kernel list
divyegala Oct 8, 2025
40c8fd6
attempt to solve overlinking problem
divyegala Oct 9, 2025
e87a8c7
reorder if-else in compiler check
divyegala Oct 9, 2025
179d733
Merge branch 'branch-25.12' into jit-lto-ivf-flat-interleaved
divyegala Oct 9, 2025
32a67bd
use cudart apis
divyegala Oct 9, 2025
c27612e
merge
divyegala Oct 9, 2025
a4b48b1
attempt to link cudart
divyegala Oct 9, 2025
d5d692e
revert cudart link, try all arch build of jit lto fatbin sources
divyegala Oct 9, 2025
1c6dd94
cmake format
divyegala Oct 9, 2025
30f5ab6
missing shared mem setting
divyegala Oct 10, 2025
9674969
separate cuda 12 and 13 compilation
divyegala Oct 22, 2025
24fc47d
merge upstream
divyegala Oct 22, 2025
db9a487
remove bench
divyegala Oct 22, 2025
aa9294f
c include directory
divyegala Oct 22, 2025
2eb77fe
style check
divyegala Oct 22, 2025
6c685fa
merge upstream
divyegala Oct 22, 2025
3e35b99
guard cuda calls and use shared_ptr
divyegala Oct 23, 2025
d0ff62c
add AlgorithmPlanner to main target
divyegala Oct 23, 2025
eb87577
merge upstream
divyegala Oct 23, 2025
445a6c4
remove nvjitlink as cuda 12 dep
divyegala Oct 23, 2025
92a27d4
address review
divyegala Oct 24, 2025
8549172
merge upstream
divyegala Oct 24, 2025
67579f4
add include guard
divyegala Oct 27, 2025
7ad8774
add and remove couple of comments
divyegala Oct 27, 2025
816a480
merge upstream
divyegala Oct 27, 2025
ab35ef3
delete readme
divyegala Oct 27, 2025
cdd4c85
increase warmup time
divyegala Oct 27, 2025
87334b2
merge upstream
divyegala Oct 27, 2025
c1eff9f
use new copyright
divyegala Oct 27, 2025
ece09b8
new copyright
divyegala Oct 27, 2025
4dacc6e
remove one more straggling comment
divyegala Oct 27, 2025
1fd95cd
use raft expects
divyegala Oct 27, 2025
64cde0d
Merge branch 'main' into jit-lto-ivf-flat-interleaved
divyegala Oct 27, 2025
5ac127b
merge upstream
divyegala Dec 12, 2025
78002c6
address review
divyegala Dec 12, 2025
9ad6a0b
pre-commit
divyegala Dec 12, 2025
bf4c4ad
address review
divyegala Dec 12, 2025
18b2af9
Generate kernel files in CMake instead of Python
KyleFromNVIDIA Dec 12, 2025
ece5cad
Merge remote-tracking branch 'refs/remotes/github/divyegala/jit-lto-i…
KyleFromNVIDIA Dec 12, 2025
8ce70c2
Style
KyleFromNVIDIA Dec 12, 2025
fdc4239
Style
KyleFromNVIDIA Dec 12, 2025
be3cf0d
Style
KyleFromNVIDIA Dec 12, 2025
7e644c3
Lint
KyleFromNVIDIA Dec 12, 2025
235938a
Style, lint
KyleFromNVIDIA Dec 12, 2025
e3b749d
Fix nvjitlink_checker
KyleFromNVIDIA Dec 15, 2025
f42ae3f
Style
KyleFromNVIDIA Dec 15, 2025
b606df9
Merge branch 'main' into jit-lto-ivf-flat-interleaved
KyleFromNVIDIA Dec 15, 2025
5ce7aab
Refactor JIT LTO kernel compilation
KyleFromNVIDIA Dec 15, 2025
eaad347
Style
KyleFromNVIDIA Dec 15, 2025
eb3b468
pic
KyleFromNVIDIA Dec 15, 2025
912279c
style
KyleFromNVIDIA Dec 15, 2025
19f1af3
Verbose build
KyleFromNVIDIA Dec 15, 2025
087b943
static
KyleFromNVIDIA Dec 15, 2025
c16e109
style
KyleFromNVIDIA Dec 15, 2025
323b79f
TARGET_OBJECTS
KyleFromNVIDIA Dec 15, 2025
9f13e73
Disable sccache
KyleFromNVIDIA Dec 16, 2025
eaf9d39
Recache
KyleFromNVIDIA Dec 16, 2025
ce40c51
Revert CI debugging
KyleFromNVIDIA Dec 16, 2025
0d0abb9
Install and link object library
KyleFromNVIDIA Dec 17, 2025
84bfa92
Style
KyleFromNVIDIA Dec 17, 2025
21241eb
Alias
KyleFromNVIDIA Dec 17, 2025
7c0ac13
Make cuvs_jit_lto_kernels a static library
KyleFromNVIDIA Dec 17, 2025
880dbf2
Style
KyleFromNVIDIA Dec 17, 2025
d04d7c1
rapids_cuda_init_architectures() for C tests
KyleFromNVIDIA Dec 17, 2025
19581f9
Be more specific about where we search for libclang
KyleFromNVIDIA Dec 17, 2025
a61f019
More libclang updates
KyleFromNVIDIA Dec 17, 2025
2eeb913
Revert "Fix libclang download for Rust, CUDA initialization for C tests"
KyleFromNVIDIA Dec 17, 2025
55ec26c
Merge branch 'main' into jit-lto-ivf-flat-interleaved
KyleFromNVIDIA Dec 18, 2025
10228c5
Merge branch 'main' into jit-lto-ivf-flat-interleaved
KyleFromNVIDIA Dec 18, 2025
031ce21
Merge branch 'main' into jit-lto-ivf-flat-interleaved
KyleFromNVIDIA Jan 14, 2026
088c21e
Copyright
KyleFromNVIDIA Jan 14, 2026
8ca1062
Apply suggestions from code review
divyegala Jan 22, 2026
d5ab5bf
merge upstream
divyegala Jan 22, 2026
b8c0d42
address some review comments
divyegala Jan 22, 2026
17d34ae
remove too many underscores
divyegala Jan 22, 2026
45a5146
FEA Add initial commit of prototype/pseudo-code for proposed UDF APIs…
dantegd Jan 26, 2026
447532e
stitch together
divyegala Jan 30, 2026
e1627d1
add udf to cmakelists
divyegala Jan 30, 2026
f7ea581
udfs working e2e
divyegala Jan 30, 2026
8b2775c
run benchmarks
divyegala Feb 3, 2026
e9c77d9
working through
divyegala Feb 3, 2026
adcfb8f
fixed overhead
divyegala Feb 4, 2026
282b376
Simplify
KyleFromNVIDIA Feb 4, 2026
609a4d6
Merge branch 'main' into jit-lto-ivf-flat-interleaved
KyleFromNVIDIA Feb 4, 2026
3115d07
address reviews
divyegala Feb 4, 2026
bb524ae
Merge remote-tracking branch 'origin/main' into jit-lto-ivf-flat-inte…
divyegala Feb 4, 2026
30a8a9f
Merge branch 'jit-lto-ivf-flat-interleaved' of github.com:divyegala/c…
divyegala Feb 4, 2026
72ddb36
Merge branch 'main' into jit-lto-ivf-flat-interleaved
divyegala Feb 5, 2026
4bd2102
add to docs and log about jit
divyegala Feb 10, 2026
fb722f0
Merge branch 'jit-lto-ivf-flat-interleaved' of github.com:divyegala/c…
divyegala Feb 10, 2026
3523b96
Merge remote-tracking branch 'origin/main' into jit-lto-ivf-flat-inte…
divyegala Feb 10, 2026
ba758a2
address review
divyegala Feb 10, 2026
42b78ae
rename inner_product to inner_prod
divyegala Feb 10, 2026
2e3a471
Merge remote-tracking branch 'origin/main' into jit-lto-ivf-flat-inte…
divyegala Feb 10, 2026
bfc6c09
fix merge conflict
divyegala Feb 10, 2026
f6377fa
include header and form better log
divyegala Feb 10, 2026
26abc7b
Merge branch 'jit-lto-ivf-flat-interleaved' into ivf-flat-search-udf
divyegala Feb 10, 2026
fb7f105
merge
divyegala Feb 10, 2026
432bb32
working through
divyegala Feb 11, 2026
533b770
address review and move
divyegala Feb 11, 2026
af23585
Merge remote-tracking branch 'origin/main' into jit-lto-ivf-flat-inte…
divyegala Feb 11, 2026
78c59d9
one more fix
divyegala Feb 11, 2026
9274868
Merge branch 'jit-lto-ivf-flat-interleaved' into cagra-search-jit-lto
divyegala Feb 11, 2026
7f8802b
correct path
divyegala Feb 11, 2026
f432aad
Merge branch 'jit-lto-ivf-flat-interleaved' into cagra-search-jit-lto
divyegala Feb 11, 2026
39ce9e3
in the middle of stuff
divyegala Feb 13, 2026
27acbb6
merge upstream
divyegala Feb 13, 2026
d11edfd
Merge branch 'jit-lto-ivf-flat-interleaved' into ivf-flat-search-udf
divyegala Feb 13, 2026
dd23671
multi-cta still failing
divyegala Feb 13, 2026
4f287c1
attempting to solve 2 kernel issue
divyegala Feb 15, 2026
64f6ad8
merge upstream
divyegala Feb 15, 2026
f1888a2
more cleaning
divyegala Feb 15, 2026
b596e79
merge cleanly
divyegala Feb 15, 2026
9c4980f
add nvrtc as a dependency
divyegala Feb 15, 2026
f27eeb2
fix build errors
divyegala Feb 15, 2026
bc5c90e
guard udf use
divyegala Feb 15, 2026
09dc56c
analyzing cubins
divyegala Feb 15, 2026
55c32f4
compiler definition on headers
divyegala Feb 15, 2026
1866475
guard udf test
divyegala Feb 15, 2026
c419173
remove
divyegala Feb 15, 2026
04cc166
missing include
divyegala Feb 15, 2026
1113afc
cleaning up
divyegala Feb 15, 2026
e372917
merge upstream
divyegala Feb 15, 2026
d8341ac
Merge remote-tracking branch 'divye/unneeded-cccl-includes' into cagr…
divyegala Feb 15, 2026
6feecce
most errors resolved
divyegala Feb 17, 2026
3e9f5f3
Merge branch 'main' into ivf-flat-search-udf
divyegala Feb 17, 2026
52e05c2
debug filter fragment
divyegala Feb 17, 2026
caf8d03
Merge branch 'main' into ivf-flat-search-udf
divyegala Feb 18, 2026
b65f599
occasional failure on dgx spark
divyegala Feb 18, 2026
5239a1a
fix compile
divyegala Feb 18, 2026
736dc75
Ignore cache-host run exports
bdice Feb 18, 2026
f83f595
Merge branch 'main' into ivf-flat-search-udf
divyegala Feb 18, 2026
a7a4ef7
pull out metric
divyegala Feb 19, 2026
5390c4c
Merge branch 'cagra-search-jit-lto' of github.com:divyegala/cuvs into…
divyegala Feb 19, 2026
07a158c
use void* for desc and create more fragments
divyegala Feb 19, 2026
0e201e8
attempt to fix cuda 12 builds
divyegala Feb 19, 2026
88a4b6e
respond to reviews
divyegala Feb 19, 2026
101c5ee
Merge remote-tracking branch 'origin/main' into ivf-flat-search-udf
divyegala Feb 19, 2026
5d3a9df
Merge branch 'ivf-flat-search-udf' of github.com:divyegala/cuvs into …
divyegala Feb 19, 2026
63c7300
pin cupy to <14.0 for cuda 12 wheels
divyegala Feb 19, 2026
0c0b6b5
fix cuda 12
divyegala Feb 19, 2026
faa9339
add includes
divyegala Feb 19, 2026
73e8fa0
fix logging
divyegala Feb 19, 2026
fef68d3
fix macro
divyegala Feb 19, 2026
05cc149
major refactor to reduce # of fragments
divyegala Feb 20, 2026
b6c9031
merge upstream udf pr
divyegala Feb 20, 2026
995f998
Merge branch 'main' into ivf-flat-search-udf
divyegala Feb 20, 2026
75e2616
Account for different QueryT
divyegala Feb 20, 2026
387d9ea
Merge remote-tracking branch 'origin/main' into cagra-search-jit-lto
divyegala Feb 20, 2026
1ccb01c
cleanup some stuff
divyegala Feb 20, 2026
3256a8e
attempt to fix devcontainer error
divyegala Feb 20, 2026
32a5d9f
Merge remote-tracking branch 'origin/main' into ivf-flat-search-udf
divyegala Feb 20, 2026
592af70
Merge branch 'ivf-flat-search-udf' of github.com:divyegala/cuvs into …
divyegala Feb 20, 2026
43501b7
address review comments
divyegala Feb 20, 2026
b5342d6
Merge branch 'ivf-flat-search-udf' into cagra-search-jit-lto
divyegala Feb 20, 2026
b85f16b
Add matrix JSON files
KyleFromNVIDIA Feb 24, 2026
e79de08
Merge branch 'main' into cagra-search-jit-lto
KyleFromNVIDIA Feb 24, 2026
de0a2b5
Fix
KyleFromNVIDIA Feb 24, 2026
c7909c3
more refactors and fix stream serialization bug
divyegala Feb 24, 2026
bbbfb25
launch correctly
divyegala Feb 25, 2026
22c40fd
Use new kernel matrix system
KyleFromNVIDIA Feb 25, 2026
d404869
remove debug prints
divyegala Feb 25, 2026
53ce0aa
Merge remote-tracking branch 'github/divyegala/cagra-search-jit-lto' …
KyleFromNVIDIA Feb 25, 2026
9fc9185
Merge remote-tracking branch 'github/divyegala/cagra-search-jit-lto' …
KyleFromNVIDIA Feb 25, 2026
1eef8c5
Remove preprocessor branch
KyleFromNVIDIA Feb 25, 2026
0af09e2
Merge branch 'main' into cagra-search-jit-lto
KyleFromNVIDIA Feb 25, 2026
b2e418b
reconcile pr 1807 and add nvjitlink/nvrtc to jit target
divyegala Feb 26, 2026
53195ef
Fix ivf flat
KyleFromNVIDIA Feb 26, 2026
f589b26
Fix kernel names and matrices
KyleFromNVIDIA Feb 26, 2026
6b8d175
Fix query
KyleFromNVIDIA Feb 26, 2026
426625e
Fix another query
KyleFromNVIDIA Feb 26, 2026
97dfa18
More
KyleFromNVIDIA Feb 26, 2026
29881c8
Make naming and matrices more consistent
KyleFromNVIDIA Feb 26, 2026
bb01ec6
add func specialization for smem launcher
divyegala Feb 26, 2026
6b32331
Merge branch 'cagra-search-jit-lto' of github.com:divyegala/cuvs into…
divyegala Feb 26, 2026
6516f78
fix ivf flat udf key
divyegala Feb 27, 2026
d737706
remove debug
divyegala Feb 27, 2026
a809041
Remove comments and debug statement, fix query, copyright
KyleFromNVIDIA Feb 27, 2026
0d48be2
Merge remote-tracking branch 'github/divyegala/cagra-search-jit-lto' …
KyleFromNVIDIA Feb 27, 2026
49f999f
missing query tag
divyegala Feb 27, 2026
d66edf0
Merge branch 'cagra-search-jit-lto' of github.com:divyegala/cuvs into…
divyegala Feb 27, 2026
b52f8c2
Refactor and make thread-safe
KyleFromNVIDIA Feb 27, 2026
0349746
remove prints
divyegala Feb 27, 2026
6e07abb
remove unnecessary includes
divyegala Feb 27, 2026
e9e2ff0
Don't build fatbins with debug symbols
KyleFromNVIDIA Mar 4, 2026
9bd6100
Merge branch 'main' into cagra-search-jit-lto
KyleFromNVIDIA Mar 5, 2026
34ed3e2
Merge branch 'main' into cagra-search-jit-lto
divyegala Mar 5, 2026
e6f06fc
Merge branch 'main' into cagra-search-jit-lto
KyleFromNVIDIA Mar 5, 2026
5552b2f
Merge remote-tracking branch 'github/divyegala/cagra-search-jit-lto' …
KyleFromNVIDIA Mar 5, 2026
582d6a0
unpin raft
divyegala Mar 6, 2026
fb13ea5
Merge remote-tracking branch 'origin/main' into cagra-search-jit-lto
divyegala Mar 6, 2026
98a1dce
Update cpp/cmake/thirdparty/get_raft.cmake
divyegala Mar 6, 2026
a39c150
Update cpp/cmake/thirdparty/get_raft.cmake
divyegala Mar 6, 2026
c3a8d73
Merge branch 'main' into cagra-search-jit-lto
KyleFromNVIDIA Mar 9, 2026
8dfb354
Merge branch 'main' into cagra-search-jit-lto
KyleFromNVIDIA Mar 9, 2026
33e1bc5
Add L1 dist op
KyleFromNVIDIA Mar 9, 2026
f050b77
Fix L1 distance
KyleFromNVIDIA Mar 10, 2026
d6eec0a
Explicitly install cudart
KyleFromNVIDIA Mar 11, 2026
1f3b75b
use function ptr indirection
divyegala Mar 12, 2026
832eaf2
Merge remote-tracking branch 'origin/main' into cagra-search-jit-lto
divyegala Mar 12, 2026
9243390
const
KyleFromNVIDIA Mar 12, 2026
dca579a
extern
KyleFromNVIDIA Mar 12, 2026
f11daf5
Merge branch 'main' into cagra-search-jit-lto
KyleFromNVIDIA Mar 12, 2026
1c2da37
Re-run CI
KyleFromNVIDIA Mar 12, 2026
ff3527b
fix bug and simplify json
divyegala Mar 12, 2026
671e8a7
Merge branch 'cagra-search-jit-lto' of github.com:divyegala/cuvs into…
divyegala Mar 12, 2026
e14a119
simplify function ptr usage
divyegala Mar 13, 2026
39e67f3
call functions directly
divyegala Mar 13, 2026
59f8911
Merge branch 'main' into cagra-search-jit-lto
KyleFromNVIDIA Mar 16, 2026
be21da4
Merge branch 'main' into cagra-search-jit-lto
KyleFromNVIDIA Mar 18, 2026
fb5cf1e
Merge branch 'main' into cagra-search-jit-lto
KyleFromNVIDIA Apr 15, 2026
8e6797c
merge upstream, make tests pass
divyegala Apr 17, 2026
ca67dea
delete extra files
divyegala Apr 17, 2026
3fb3df0
reconcile jit and non jit paths
divyegala Apr 18, 2026
c01ec16
merge 1780
divyegala Apr 19, 2026
a6e31c9
remove unneeded wrappers
divyegala Apr 19, 2026
bcf0470
specialize jit cache to reduce contention
divyegala Apr 19, 2026
a0636f3
rework standard/impl functions and fix recipe
divyegala Apr 21, 2026
04769c7
keep the fragments separate
divyegala Apr 21, 2026
9c981e1
more review
divyegala Apr 22, 2026
5eadc1e
Merge remote-tracking branch 'upstream/main' into cagra-search-jit-lto
divyegala Apr 22, 2026
84a93d4
fix recipe
divyegala Apr 22, 2026
f52d8ce
code simplification and ai reviews
divyegala Apr 22, 2026
d7fa7f3
use whole compilation for cagra TUs; ai reviews
divyegala Apr 22, 2026
0053171
address reviews
divyegala Apr 24, 2026
44d0ea2
ai review
divyegala Apr 24, 2026
11711c0
attempt to fix smem launch
divyegala May 1, 2026
1d58136
attempt to fix smem launch
divyegala May 4, 2026
dc29e56
dante review
divyegala May 4, 2026
f329a82
kyle review step 1
divyegala May 4, 2026
7f2fa39
fix ci error
divyegala May 5, 2026
6598f62
Merge remote-tracking branch 'upstream/main' into cagra-search-jit-lto
divyegala May 5, 2026
2f93c6e
kyle review step 2
divyegala May 5, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion conda/recipes/libcuvs/recipe.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -399,13 +399,13 @@ outputs:
- librmm =${{ minor_version }}
- nccl ${{ nccl_version }}
- cuda-cudart-dev
- cuda-nvrtc-dev
Comment thread
KyleFromNVIDIA marked this conversation as resolved.
- cuda-profiler-api
- libcublas-dev
- libcurand-dev
- libcusolver-dev
- libcusparse-dev
- libnvjitlink-dev
- cuda-nvrtc-dev
run:
- ${{ pin_subpackage("libcuvs-headers", exact=True) }}
- ${{ pin_subpackage("libcuvs", exact=True) }}
Expand Down
219 changes: 176 additions & 43 deletions cpp/CMakeLists.txt

Large diffs are not rendered by default.

14 changes: 14 additions & 0 deletions cpp/include/cuvs/detail/jit_lto/AlgorithmLauncher.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,24 @@ struct AlgorithmLauncher {
this->call(stream, grid, block, shared_mem, kernel_args);
}

/**
 * Type-checked launch of the wrapped kernel as a *cooperative* kernel (all blocks
 * resident, grid-wide synchronization permitted inside the kernel).
 *
 * @tparam FuncT exact function type of the target kernel, e.g. `void(int, float*)`;
 *               used only for the compile-time signature check below
 * @param stream     CUDA stream to launch on
 * @param grid       grid dimensions
 * @param block      block dimensions
 * @param shared_mem dynamic shared memory size in bytes
 * @param args       kernel arguments; their reference-stripped types must match FuncT exactly
 */
template <typename FuncT, typename... Args>
void dispatch_cooperative(
  cudaStream_t stream, dim3 grid, dim3 block, std::size_t shared_mem, Args&&... args)
{
  // Reject mismatched argument lists at compile time: the type-erased pointer pack below is
  // only valid if each argument's value type matches the kernel parameter type exactly.
  static_assert(
    std::is_same_v<FuncT, void(std::remove_reference_t<Args>...)>,
    "dispatch_cooperative() argument types do not match the kernel function signature FuncT");

  // The CUDA launch API takes an array of type-erased argument pointers; the pointed-to
  // arguments live on this stack frame, which outlives the launch call.
  void* kernel_args[] = {const_cast<void*>(static_cast<void const*>(&args))...};
  this->call_cooperative(stream, grid, block, shared_mem, kernel_args);
}

cudaKernel_t get_kernel() { return this->kernel; }

private:
void call(cudaStream_t stream, dim3 grid, dim3 block, std::size_t shared_mem, void** args);
void call_cooperative(
cudaStream_t stream, dim3 grid, dim3 block, std::size_t shared_mem, void** args);
cudaKernel_t kernel;
cudaLibrary_t library;
};
6 changes: 6 additions & 0 deletions cpp/include/cuvs/detail/jit_lto/AlgorithmPlanner.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ struct AlgorithmPlanner {
add_fragment(std::make_unique<StaticFatbinFragmentEntry<FragmentTag>>());
}

protected:
/** Extra link-time option strings passed to nvJitLink. Base build()
* always passes "-lto" and "-arch=sm_XX" first; derived planners may append here in their
* constructor body. */
std::vector<std::string> linktime_extra_options;

private:
std::string get_fragments_key() const;
std::shared_ptr<AlgorithmLauncher> build();
Expand Down
93 changes: 93 additions & 0 deletions cpp/include/cuvs/detail/jit_lto/cagra/cagra_fragments.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
 * SPDX-License-Identifier: Apache-2.0
 */

#pragma once

#include <cstdint>

namespace cuvs::neighbors::cagra::detail {

// Empty tag types used as compile-time keys for JIT-LTO fragment lookup; they carry no data
// and are never instantiated at runtime.

// Distance value type tag (float).
struct tag_dist_f {};
// Distance metric tags.
struct tag_metric_l2 {};
struct tag_metric_inner_product {};
struct tag_metric_cosine {};
struct tag_metric_hamming {};
// VPQ codebook tags: no codebook vs. half-precision codebook.
struct tag_codebook_none {};
struct tag_codebook_half {};
struct tag_metric_l1 {};
// Query normalization tags: identity vs. cosine normalization.
struct tag_norm_noop {};
struct tag_norm_cosine {};

/// Multi-kernel planners that do not link `sample_filter` into the JIT link (e.g.
/// `random_pickup`). Real filters use `cuvs::neighbors::detail::tag_filter_*` on
/// `CagraPlannerBase`.
struct tag_cagra_jit_sample_filter_link_absent {};

// Fragment key for the workspace-setup device function, parameterized over the full set of
// compile-time knobs (element types plus team/block/PQ configuration).
template <typename DataTag,
          typename IndexTag,
          typename DistanceTag,
          typename QueryTag,
          typename CodebookTag,
          uint32_t TeamSize,
          uint32_t DatasetBlockDim,
          uint32_t PqBits,
          uint32_t PqLen>
struct fragment_tag_setup_workspace {};

// Fragment key for the per-vector distance computation device function; same parameter set as
// fragment_tag_setup_workspace so the two fragments are always selected consistently.
template <typename DataTag,
          typename IndexTag,
          typename DistanceTag,
          typename QueryTag,
          typename CodebookTag,
          uint32_t TeamSize,
          uint32_t DatasetBlockDim,
          uint32_t PqBits,
          uint32_t PqLen>
struct fragment_tag_compute_distance {};

// Fragment key for the scalar distance operator (metric-specific accumulation step).
template <typename QueryTag, typename DistanceTag, typename MetricTag>
struct fragment_tag_dist_op {};

// Fragment key for the post-distance normalization step (NormTag selects noop vs. cosine).
template <typename DataTag,
          typename IndexTag,
          typename DistanceTag,
          typename QueryTag,
          uint32_t TeamSize,
          uint32_t DatasetBlockDim,
          typename NormTag>
struct fragment_tag_apply_normalization_standard {};

// Fragment keys for the single-CTA search kernel and its persistent ("_p") variant; the two
// bools select the top-k strategy baked into the kernel.
template <typename DataTag,
          typename SourceIndexTag,
          typename IndexTag,
          typename DistanceTag,
          bool TopkByBitonicSort,
          bool BitonicSortAndMergeMultiWarps>
struct fragment_tag_search_single_cta {};

template <typename DataTag,
          typename SourceIndexTag,
          typename IndexTag,
          typename DistanceTag,
          bool TopkByBitonicSort,
          bool BitonicSortAndMergeMultiWarps>
struct fragment_tag_search_single_cta_p {};

// Fragment key for the multi-CTA search kernel.
template <typename DataTag, typename SourceIndexTag, typename IndexTag, typename DistanceTag>
struct fragment_tag_search_multi_cta {};

// Fragment keys for the auxiliary kernels used by the multi-kernel search path.
template <typename DataTag, typename IndexTag, typename DistanceTag>
struct fragment_tag_random_pickup {};

template <typename DataTag, typename IndexTag, typename DistanceTag, typename SourceIndexTag>
struct fragment_tag_compute_distance_to_child_nodes {};

template <typename IndexTag, typename DistanceTag, typename SourceIndexTag>
struct fragment_tag_apply_filter_kernel {};

// Fragment key for the sample-filter device function linked into search kernels.
template <typename BitsetTag, typename SourceIndexTag, typename FilterTag>
struct fragment_tag_sample_filter {};

}  // namespace cuvs::neighbors::cagra::detail
5 changes: 5 additions & 0 deletions cpp/include/cuvs/detail/jit_lto/common_fragments.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,16 @@

namespace cuvs::neighbors::detail {

struct tag_f {};
struct tag_h {};
struct tag_i8 {};
struct tag_u8 {};
struct tag_filter_none {};
struct tag_filter_bitset {};

struct tag_bitset_u32 {};

struct tag_index_u32 {};
struct tag_index_i64 {};

template <typename BitsetTag, typename IndexTag, typename FilterTag>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,13 @@

namespace cuvs::neighbors::ivf_flat::detail {

// Tag types for data types
struct tag_f {};
struct tag_h {};
struct tag_i8 {};
struct tag_u8 {};

// Tag types for accumulator types
struct tag_acc_f {};
struct tag_acc_h {};
struct tag_acc_i32 {};
struct tag_acc_u32 {};

// Tag types for distance metrics with full template info
// Tag types for distance metrics
struct tag_metric_euclidean {};
struct tag_metric_inner_product {};
struct tag_metric_custom_udf {};
Expand Down
19 changes: 18 additions & 1 deletion cpp/src/detail/jit_lto/AlgorithmLauncher.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ AlgorithmLauncher::AlgorithmLauncher(AlgorithmLauncher&& other) noexcept
AlgorithmLauncher& AlgorithmLauncher::operator=(AlgorithmLauncher&& other) noexcept
{
if (this != &other) {
// Unload current library if it exists
if (library != nullptr) { cudaLibraryUnload(library); }
kernel = other.kernel;
library = other.library;
Expand All @@ -47,3 +46,21 @@ void AlgorithmLauncher::call(

RAFT_CUDA_TRY(cudaLaunchKernelExC(&config, kernel, kernel_args));
}

/**
 * Launch the wrapped JIT-linked kernel with the cooperative-launch attribute set
 * (cudaLaunchAttributeCooperative), which permits grid-wide synchronization inside the
 * kernel. The caller is responsible for sizing the grid so that all blocks can be
 * resident simultaneously, as cooperative launch requires.
 *
 * @param stream      CUDA stream to launch on
 * @param grid        grid dimensions
 * @param block       block dimensions
 * @param shared_mem  dynamic shared memory size in bytes
 * @param kernel_args type-erased pointers to the kernel arguments
 */
void AlgorithmLauncher::call_cooperative(
  cudaStream_t stream, dim3 grid, dim3 block, std::size_t shared_mem, void** kernel_args)
{
  // Value-initialize so no indeterminate attribute bytes are handed to the runtime
  // (mirrors the brace-initialized cudaLaunchConfig_t below).
  cudaLaunchAttribute attribute[1]{};
  attribute[0].id              = cudaLaunchAttributeCooperative;
  attribute[0].val.cooperative = 1;

  cudaLaunchConfig_t config{};
  config.gridDim          = grid;
  config.blockDim         = block;
  config.stream           = stream;
  config.dynamicSmemBytes = shared_mem;
  config.numAttrs         = 1;
  config.attrs            = attribute;

  RAFT_CUDA_TRY(cudaLaunchKernelExC(&config, kernel, kernel_args));
}
10 changes: 8 additions & 2 deletions cpp/src/detail/jit_lto/AlgorithmPlanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,14 @@ std::shared_ptr<AlgorithmLauncher> AlgorithmPlanner::build()

// Load the generated LTO IR and link them together
nvJitLinkHandle handle;
const char* lopts[] = {"-lto", archs.c_str()};
auto result = nvJitLinkCreate(&handle, 2, lopts);
std::vector<const char*> lopts;
lopts.reserve(2 + linktime_extra_options.size());
lopts.push_back("-lto");
lopts.push_back(archs.c_str());
for (auto const& opt : linktime_extra_options) {
lopts.push_back(opt.c_str());
}
auto result = nvJitLinkCreate(&handle, static_cast<unsigned int>(lopts.size()), lopts.data());
check_nvjitlink_result(handle, result);

for (const auto& frag : this->fragments) {
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/neighbors/cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#include "detail/cagra/cagra_build.cuh"
#include "detail/cagra/cagra_merge.cuh"
#include "detail/cagra/cagra_search.cuh"
#include "detail/cagra/graph_core.cuh"
#include "detail/cagra/jit_lto_kernels/graph_core.cuh"

#include "detail/ann_utils.cuh"
#include <raft/core/device_mdspan.hpp>
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/neighbors/detail/cagra/cagra_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

#include "../../../core/nvtx.hpp"
#include "../../../preprocessing/quantize/vpq_build-ext.cuh"
#include "graph_core.cuh"
#include "jit_lto_kernels/graph_core.cuh"

#include <raft/core/copy.cuh>
#include <raft/core/device_mdarray.hpp>
Expand Down
1 change: 1 addition & 0 deletions cpp/src/neighbors/detail/cagra/cagra_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include <raft/linalg/norm.cuh>
#include <raft/linalg/reduce.cuh>

// All includes are done before opening namespace to avoid nested namespace issues
namespace cuvs::neighbors::cagra::detail {

template <typename DataT,
Expand Down
61 changes: 27 additions & 34 deletions cpp/src/neighbors/detail/cagra/compute_distance.hpp
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/
#pragma once

#include "device_common.hpp"
#include "hashmap.hpp"
#include "../neighbors_device_intrinsics.cuh"
#include "jit_lto_kernels/device_memory_ops.hpp"
#include "jit_lto_kernels/hashmap.hpp"
#include "utils.hpp"

#include <cuvs/distance/distance.hpp>
Expand Down Expand Up @@ -137,33 +138,20 @@ struct alignas(device::LOAD_128BIT_T) dataset_descriptor_base_t {
};
static_assert(sizeof(smem_and_team_size_t) == sizeof(uint32_t));

using setup_workspace_type = const base_type*(const base_type*, void*, const DATA_T*, uint32_t);
using compute_distance_type = DISTANCE_T(const args_t, const INDEX_T);

args_t args;

/** Copy the descriptor and the query into shared memory and do any other work, such as
* initializing the codebook. */
setup_workspace_type* setup_workspace_impl;
/** Compute the distance from the query vector (stored in the smem_workspace) and a dataset vector
* given by the dataset_index. */
compute_distance_type* compute_distance_impl;
/** A placeholder for an implementation-specific pointer. */
void* extra_ptr3;
smem_and_team_size_t smem_and_team_size;

/** Number of records in the database. */
INDEX_T size;

RAFT_INLINE_FUNCTION dataset_descriptor_base_t(setup_workspace_type* setup_workspace_impl,
compute_distance_type* compute_distance_impl,
INDEX_T size,
RAFT_INLINE_FUNCTION dataset_descriptor_base_t(INDEX_T size,
uint32_t dim,
uint32_t team_size_bitshift,
uint32_t smem_ws_size_in_bytes)
: setup_workspace_impl(setup_workspace_impl),
compute_distance_impl(compute_distance_impl),
size(size),
: size(size),
smem_and_team_size(smem_ws_size_in_bytes, team_size_bitshift),
args{nullptr, nullptr, 0, dim, 0, 0}
{
Expand Down Expand Up @@ -191,20 +179,6 @@ struct alignas(device::LOAD_128BIT_T) dataset_descriptor_base_t {
{
return smem_and_team_size.team_size();
}

RAFT_DEVICE_INLINE_FUNCTION auto setup_workspace(void* smem_ptr,
const DATA_T* queries_ptr,
uint32_t query_id) const -> const base_type*
{
return setup_workspace_impl(this, smem_ptr, queries_ptr, query_id);
}

RAFT_DEVICE_INLINE_FUNCTION auto compute_distance(INDEX_T dataset_index, bool valid) const
-> DISTANCE_T
{
auto per_thread_distances = valid ? compute_distance_impl(args.load(), dataset_index) : 0;
return device::team_sum(per_thread_distances, team_size_bitshift_from_smem());
}
};

/**
Expand All @@ -227,6 +201,14 @@ struct dataset_descriptor_host {
uint32_t smem_ws_size_in_bytes = 0;
uint32_t team_size = 0;

// JIT LTO metadata - stored when descriptor is created
cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Expanded;
uint32_t dataset_block_dim = 0;
bool is_vpq = false;
uint32_t pq_bits = 0;
uint32_t pq_len = 0;
// Codebook type is determined by DataT for VPQ (always half for now)

struct state {
using ready_t = std::tuple<dev_descriptor_t*, rmm::cuda_stream_view>;
using init_f =
Expand Down Expand Up @@ -270,10 +252,21 @@ struct dataset_descriptor_host {
};

template <typename DescriptorImpl, typename InitF>
dataset_descriptor_host(const DescriptorImpl& dd_host, InitF init)
dataset_descriptor_host(const DescriptorImpl& dd_host,
InitF init,
cuvs::distance::DistanceType metric_val,
uint32_t dataset_block_dim_val,
bool is_vpq_val = false,
uint32_t pq_bits_val = 0,
uint32_t pq_len_val = 0)
: value_{std::make_shared<state>(init, sizeof(DescriptorImpl))},
smem_ws_size_in_bytes{dd_host.smem_ws_size_in_bytes()},
team_size{dd_host.team_size()}
team_size{dd_host.team_size()},
metric{metric_val},
dataset_block_dim{dataset_block_dim_val},
is_vpq{is_vpq_val},
pq_bits{pq_bits_val},
pq_len{pq_len_val}
{
}

Expand Down
Loading
Loading