diff --git a/conda/recipes/libcuvs/recipe.yaml b/conda/recipes/libcuvs/recipe.yaml index b192f2af5f..aa7a37db44 100644 --- a/conda/recipes/libcuvs/recipe.yaml +++ b/conda/recipes/libcuvs/recipe.yaml @@ -399,13 +399,13 @@ outputs: - librmm =${{ minor_version }} - nccl ${{ nccl_version }} - cuda-cudart-dev + - cuda-nvrtc-dev - cuda-profiler-api - libcublas-dev - libcurand-dev - libcusolver-dev - libcusparse-dev - libnvjitlink-dev - - cuda-nvrtc-dev run: - ${{ pin_subpackage("libcuvs-headers", exact=True) }} - ${{ pin_subpackage("libcuvs", exact=True) }} diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index b33daa635a..41b537a5c3 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -315,43 +315,13 @@ if(NOT BUILD_CPU_ONLY) "${CMAKE_CURRENT_BINARY_DIR}/src/neighbors/detail/cagra/compute_distance.cu" @ONLY ) - add_library( - cuvs-cagra-search OBJECT - ${cagra_search_inst_files} - ${CMAKE_CURRENT_BINARY_DIR}/src/neighbors/detail/cagra/compute_distance.cu - ${cagra_compute_distance_standard_inst_files} - ${cagra_compute_distance_vpq_inst_files} - ${cagra_search_multi_cta_inst_files} - ${cagra_search_single_cta_inst_files} - ) - - set_source_files_properties( - ${cagra_compute_distance_standard_inst_files} ${cagra_compute_distance_vpq_inst_files} - PROPERTIES COMPILE_FLAGS -maxrregcount=64 - ) - - set_target_properties( - cuvs-cagra-search - PROPERTIES BUILD_RPATH "\$ORIGIN" - CXX_STANDARD 20 - CXX_STANDARD_REQUIRED ON - CUDA_STANDARD 20 - CUDA_STANDARD_REQUIRED ON - CUDA_SEPARABLE_COMPILATION ON - POSITION_INDEPENDENT_CODE ON - ) - target_link_libraries( - cuvs-cagra-search PRIVATE cuvs::cuvs_cpp_headers - $ - ) - target_compile_options( - cuvs-cagra-search - PRIVATE "$<$:${CUVS_CXX_FLAGS}>" - "$<$:${CUVS_CUDA_FLAGS}>" - "$<$,$>:${CUVS_DEBUG_CUDA_FLAGS}>" - ) - target_include_directories( - cuvs-cagra-search PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/src" "${CMAKE_CURRENT_BINARY_DIR}/src" + set(cuvs_cagra_search_cuda_inst_files + ${cagra_search_inst_files} + ${CMAKE_CURRENT_BINARY_DIR}/src/neighbors/detail/cagra/compute_distance.cu + ${cagra_compute_distance_standard_inst_files} + ${cagra_compute_distance_vpq_inst_files} + ${cagra_search_multi_cta_inst_files} + ${cagra_search_single_cta_inst_files} ) if(BUILD_MG_ALGOS) @@ -426,7 +396,7 @@ if(NOT BUILD_CPU_ONLY) KERNEL_INPUT_FILE "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/ivf_flat/detail/jit_lto_kernels/interleaved_scan_kernel.cu.in" FRAGMENT_TAG_FORMAT - "${ivf_flat_ns}::fragment_tag_interleaved_scan<${ivf_flat_ns}::tag_@data_abbrev@, ${ivf_flat_ns}::tag_acc_@acc_abbrev@, ${neighbors_ns}::tag_index_@index_abbrev@, @capacity@, @ascending_value@>" + "${ivf_flat_ns}::fragment_tag_interleaved_scan<${neighbors_ns}::tag_@data_abbrev@, ${ivf_flat_ns}::tag_acc_@acc_abbrev@, ${neighbors_ns}::tag_index_@index_abbrev@, @capacity@, @ascending_value@>" FRAGMENT_TAG_HEADER_FILES "" "" OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/ivf_flat/interleaved_scan" @@ -441,8 +411,9 @@ if(NOT BUILD_CPU_ONLY) KERNEL_INPUT_FILE "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/ivf_flat/detail/jit_lto_kernels/load_and_compute_dist_kernel.cu.in" FRAGMENT_TAG_FORMAT - "${ivf_flat_ns}::fragment_tag_load_and_compute_dist<${ivf_flat_ns}::tag_@data_abbrev@, ${ivf_flat_ns}::tag_acc_@acc_abbrev@, @compute_norm_value@, @veclen@>" + "${ivf_flat_ns}::fragment_tag_load_and_compute_dist<${neighbors_ns}::tag_@data_abbrev@, ${ivf_flat_ns}::tag_acc_@acc_abbrev@, @compute_norm_value@, @veclen@>" FRAGMENT_TAG_HEADER_FILES "" + "" OUTPUT_DIRECTORY 
"${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/ivf_flat/load_and_compute_dist" KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements ) @@ -454,7 +425,7 @@ if(NOT BUILD_CPU_ONLY) KERNEL_INPUT_FILE "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/ivf_flat/detail/jit_lto_kernels/metric_kernel.cu.in" FRAGMENT_TAG_FORMAT - "${ivf_flat_ns}::fragment_tag_metric<${ivf_flat_ns}::tag_@data_abbrev@, ${ivf_flat_ns}::tag_acc_@acc_abbrev@, ${ivf_flat_ns}::tag_metric_@metric_name@, @veclen@>" + "${ivf_flat_ns}::fragment_tag_metric<${neighbors_ns}::tag_@data_abbrev@, ${ivf_flat_ns}::tag_acc_@acc_abbrev@, ${ivf_flat_ns}::tag_metric_@metric_name@, @veclen@>" FRAGMENT_TAG_HEADER_FILES "" "" OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/ivf_flat/metric" @@ -633,6 +604,168 @@ if(NOT BUILD_CPU_ONLY) OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/ivf_pq/increment_score" KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements ) + set(cagra_ns "cuvs::neighbors::cagra::detail") + generate_jit_lto_kernels( + jit_lto_files + NAME_FORMAT + "cagra_setup_workspace@pq_prefix@_team_size_@team_size@_dataset_block_dim_@dataset_block_dim@_@pq_bits@pq_@pq_len@subd_data_@data_abbrev@_query_@query_abbrev@" + MATRIX_JSON_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_matrix.json" + KERNEL_INPUT_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_kernel.cu.in" + FRAGMENT_TAG_FORMAT + "${cagra_ns}::fragment_tag_setup_workspace<${neighbors_ns}::tag_@data_abbrev@, ${neighbors_ns}::tag_index_@index_abbrev@, ${cagra_ns}::tag_dist_@distance_abbrev@, ${neighbors_ns}::tag_@query_abbrev@, ${cagra_ns}::tag_codebook_@codebook_abbrev@, @team_size@, @dataset_block_dim@, @pq_bits@, @pq_len@>" + FRAGMENT_TAG_HEADER_FILES "" + "" + OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/cagra/setup_workspace" + KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements + ) + generate_jit_lto_kernels( + jit_lto_files + NAME_FORMAT + "cagra_compute_distance@pq_prefix@_team_size_@team_size@_dataset_block_dim_@dataset_block_dim@_@pq_bits@pq_@pq_len@subd_data_@data_abbrev@_query_@query_abbrev@" + MATRIX_JSON_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_matrix.json" + KERNEL_INPUT_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_kernel.cu.in" + FRAGMENT_TAG_FORMAT + "${cagra_ns}::fragment_tag_compute_distance<${neighbors_ns}::tag_@data_abbrev@, ${neighbors_ns}::tag_index_@index_abbrev@, ${cagra_ns}::tag_dist_@distance_abbrev@, ${neighbors_ns}::tag_@query_abbrev@, ${cagra_ns}::tag_codebook_@codebook_abbrev@, @team_size@, @dataset_block_dim@, @pq_bits@, @pq_len@>" + FRAGMENT_TAG_HEADER_FILES "" + "" + OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/cagra/compute_distance" + KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements + ) + generate_jit_lto_kernels( + jit_lto_files + NAME_FORMAT "cagra_dist_op_@metric_tag@" + MATRIX_JSON_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/dist_op_matrix.json" + KERNEL_INPUT_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/dist_op_kernel.cu.in" + FRAGMENT_TAG_FORMAT + "${cagra_ns}::fragment_tag_dist_op<${neighbors_ns}::tag_@query_abbrev@, ${cagra_ns}::tag_dist_@distance_abbrev@, ${cagra_ns}::@jit_metric_tag@>" + FRAGMENT_TAG_HEADER_FILES "" + "" + OUTPUT_DIRECTORY 
"${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/cagra/dist_op" + KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements + ) + generate_jit_lto_kernels( + jit_lto_files + NAME_FORMAT + "cagra_apply_normalization_standard_@norm_kind@_team_size_@team_size@_dataset_block_dim_@dataset_block_dim@_data_@data_abbrev@_query_@query_abbrev@" + MATRIX_JSON_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/apply_normalization_standard_matrix.json" + KERNEL_INPUT_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/apply_normalization_standard_kernel.cu.in" + FRAGMENT_TAG_FORMAT + "${cagra_ns}::fragment_tag_apply_normalization_standard<${neighbors_ns}::tag_@data_abbrev@, ${neighbors_ns}::tag_index_@index_abbrev@, ${cagra_ns}::tag_dist_@distance_abbrev@, ${neighbors_ns}::tag_@query_abbrev@, @team_size@, @dataset_block_dim@, ${cagra_ns}::tag_norm_@norm_kind@>" + FRAGMENT_TAG_HEADER_FILES "" + "" + OUTPUT_DIRECTORY + "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/cagra/apply_normalization_standard" + KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements + ) + generate_jit_lto_kernels( + jit_lto_files + NAME_FORMAT + "cagra_search_single_cta_@topk_by_bitonic_sort_str@_@bitonic_sort_and_merge_multi_warps_str@_data_@data_abbrev@" + MATRIX_JSON_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_matrix.json" + KERNEL_INPUT_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_kernel.cu.in" + FRAGMENT_TAG_FORMAT + "${cagra_ns}::fragment_tag_search_single_cta<${neighbors_ns}::tag_@data_abbrev@, ${neighbors_ns}::tag_index_@source_index_abbrev@, ${neighbors_ns}::tag_index_@index_abbrev@, ${cagra_ns}::tag_dist_@distance_abbrev@, @topk_by_bitonic_sort@, @bitonic_sort_and_merge_multi_warps@>" + FRAGMENT_TAG_HEADER_FILES "" + "" + OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/cagra/search_single_cta" + KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements + ) + generate_jit_lto_kernels( + jit_lto_files + NAME_FORMAT + "cagra_search_single_cta_p_@topk_by_bitonic_sort_str@_@bitonic_sort_and_merge_multi_warps_str@_data_@data_abbrev@" + MATRIX_JSON_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_p_matrix.json" + KERNEL_INPUT_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_p_kernel.cu.in" + FRAGMENT_TAG_FORMAT + "${cagra_ns}::fragment_tag_search_single_cta_p<${neighbors_ns}::tag_@data_abbrev@, ${neighbors_ns}::tag_index_@source_index_abbrev@, ${neighbors_ns}::tag_index_@index_abbrev@, ${cagra_ns}::tag_dist_@distance_abbrev@, @topk_by_bitonic_sort@, @bitonic_sort_and_merge_multi_warps@>" + FRAGMENT_TAG_HEADER_FILES "" + "" + OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/cagra/search_single_cta_p" + KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements + ) + generate_jit_lto_kernels( + jit_lto_files + NAME_FORMAT "cagra_search_multi_cta_data_@data_abbrev@" + MATRIX_JSON_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_matrix.json" + KERNEL_INPUT_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_kernel.cu.in" + FRAGMENT_TAG_FORMAT + "${cagra_ns}::fragment_tag_search_multi_cta<${neighbors_ns}::tag_@data_abbrev@, ${neighbors_ns}::tag_index_@source_index_abbrev@, ${neighbors_ns}::tag_index_@index_abbrev@, ${cagra_ns}::tag_dist_@distance_abbrev@>" + 
FRAGMENT_TAG_HEADER_FILES "" + "" + OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/cagra/search_multi_cta" + KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements + ) + generate_jit_lto_kernels( + jit_lto_files + NAME_FORMAT "cagra_random_pickup_data_@data_abbrev@" + MATRIX_JSON_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/random_pickup_matrix.json" + KERNEL_INPUT_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/random_pickup_kernel.cu.in" + FRAGMENT_TAG_FORMAT + "${cagra_ns}::fragment_tag_random_pickup<${neighbors_ns}::tag_@data_abbrev@, ${neighbors_ns}::tag_index_@index_abbrev@, ${cagra_ns}::tag_dist_@distance_abbrev@>" + FRAGMENT_TAG_HEADER_FILES "" + "" + OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/cagra/random_pickup" + KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements + ) + generate_jit_lto_kernels( + jit_lto_files + NAME_FORMAT "cagra_compute_distance_to_child_nodes_data_@data_abbrev@" + MATRIX_JSON_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_to_child_nodes_matrix.json" + KERNEL_INPUT_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_to_child_nodes_kernel.cu.in" + FRAGMENT_TAG_FORMAT + "${cagra_ns}::fragment_tag_compute_distance_to_child_nodes<${neighbors_ns}::tag_@data_abbrev@, ${neighbors_ns}::tag_index_@index_abbrev@, ${cagra_ns}::tag_dist_@distance_abbrev@, ${neighbors_ns}::tag_index_@source_index_abbrev@>" + FRAGMENT_TAG_HEADER_FILES "" + "" + OUTPUT_DIRECTORY + "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/cagra/compute_distance_to_child_nodes" + KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements + ) + generate_jit_lto_kernels( + jit_lto_files + NAME_FORMAT "cagra_apply_filter" + MATRIX_JSON_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/apply_filter_matrix.json" + KERNEL_INPUT_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/apply_filter_kernel.cu.in" + FRAGMENT_TAG_FORMAT + "${cagra_ns}::fragment_tag_apply_filter_kernel<${neighbors_ns}::tag_index_@index_abbrev@, ${cagra_ns}::tag_dist_@distance_abbrev@, ${neighbors_ns}::tag_index_@source_index_abbrev@>" + FRAGMENT_TAG_HEADER_FILES "" + "" + OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/cagra/apply_filter" + KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements + ) + generate_jit_lto_kernels( + jit_lto_files + NAME_FORMAT "sample_filter_@filter_name@_index_@source_index_abbrev@" + MATRIX_JSON_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_matrix.json" + KERNEL_INPUT_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_kernel.cu.in" + FRAGMENT_TAG_FORMAT + "${cagra_ns}::fragment_tag_sample_filter<${neighbors_ns}::tag_bitset_@bitset_abbrev@, ${neighbors_ns}::tag_index_@source_index_abbrev@, ${neighbors_ns}::tag_filter_@filter_name@>" + FRAGMENT_TAG_HEADER_FILES "" + "" + OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/cagra/filter" + KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements + ) endblock() # Note that this matrix contains an `arch_includes` placeholder, since we don't currently have a @@ -874,6 +1007,7 @@ if(NOT BUILD_CPU_ONLY) ${iface_flat_inst_files} ${iface_pq_inst_files} src/neighbors/detail/cagra/topk_for_cagra/topk.cu + ${cuvs_cagra_search_cuda_inst_files} src/neighbors/dynamic_batching.cu 
src/neighbors/composite/index.cu $<$:src/neighbors/cagra.cpp> @@ -985,7 +1119,6 @@ if(NOT BUILD_CPU_ONLY) target_link_libraries(cuvs_objs PUBLIC $) target_compile_definitions(cuvs_objs PUBLIC CUVS_BUILD_MG_ALGOS) - target_compile_definitions(cuvs-cagra-search PUBLIC CUVS_BUILD_MG_ALGOS) endif() set(CUVS_CUSOLVER_DEPENDENCY CUDA::cusolver${_ctk_static_suffix}) @@ -998,7 +1131,7 @@ if(NOT BUILD_CPU_ONLY) ) if(NOT cuvs_compile_mode STREQUAL "static_only") - add_library(cuvs SHARED $ $) + add_library(cuvs SHARED $) add_library(cuvs::cuvs ALIAS cuvs) set_target_properties( cuvs @@ -1059,7 +1192,7 @@ SECTIONS endif() if(NOT cuvs_compile_mode STREQUAL "shared_only") - add_library(cuvs_static STATIC $ $) + add_library(cuvs_static STATIC $) add_library(cuvs::cuvs_static ALIAS cuvs_static) set_target_properties( diff --git a/cpp/include/cuvs/detail/jit_lto/AlgorithmLauncher.hpp b/cpp/include/cuvs/detail/jit_lto/AlgorithmLauncher.hpp index ae975630c5..10418c430a 100644 --- a/cpp/include/cuvs/detail/jit_lto/AlgorithmLauncher.hpp +++ b/cpp/include/cuvs/detail/jit_lto/AlgorithmLauncher.hpp @@ -37,10 +37,24 @@ struct AlgorithmLauncher { this->call(stream, grid, block, shared_mem, kernel_args); } + template + void dispatch_cooperative( + cudaStream_t stream, dim3 grid, dim3 block, std::size_t shared_mem, Args&&... args) + { + static_assert( + std::is_same_v...)>, + "dispatch_cooperative() argument types do not match the kernel function signature FuncT"); + + void* kernel_args[] = {const_cast(static_cast(&args))...}; + this->call_cooperative(stream, grid, block, shared_mem, kernel_args); + } + cudaKernel_t get_kernel() { return this->kernel; } private: void call(cudaStream_t stream, dim3 grid, dim3 block, std::size_t shared_mem, void** args); + void call_cooperative( + cudaStream_t stream, dim3 grid, dim3 block, std::size_t shared_mem, void** args); cudaKernel_t kernel; cudaLibrary_t library; }; diff --git a/cpp/include/cuvs/detail/jit_lto/AlgorithmPlanner.hpp b/cpp/include/cuvs/detail/jit_lto/AlgorithmPlanner.hpp index ccd52f0e43..7f275b1285 100644 --- a/cpp/include/cuvs/detail/jit_lto/AlgorithmPlanner.hpp +++ b/cpp/include/cuvs/detail/jit_lto/AlgorithmPlanner.hpp @@ -44,6 +44,12 @@ struct AlgorithmPlanner { add_fragment(std::make_unique>()); } + protected: + /** Extra link-time option strings passed to nvJitLink. Base build() + * always passes "-lto" and "-arch=sm_XX" first; derived planners may append here in their + * constructor body. */ + std::vector linktime_extra_options; + private: std::string get_fragments_key() const; std::shared_ptr build(); diff --git a/cpp/include/cuvs/detail/jit_lto/cagra/cagra_fragments.hpp b/cpp/include/cuvs/detail/jit_lto/cagra/cagra_fragments.hpp new file mode 100644 index 0000000000..0b42d79379 --- /dev/null +++ b/cpp/include/cuvs/detail/jit_lto/cagra/cagra_fragments.hpp @@ -0,0 +1,93 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include + +namespace cuvs::neighbors::cagra::detail { + +struct tag_dist_f {}; +struct tag_metric_l2 {}; +struct tag_metric_inner_product {}; +struct tag_metric_cosine {}; +struct tag_metric_hamming {}; +struct tag_codebook_none {}; +struct tag_codebook_half {}; +struct tag_metric_l1 {}; +struct tag_norm_noop {}; +struct tag_norm_cosine {}; + +/// Multi-kernel planners that do not link `sample_filter` into the JIT link (e.g. +/// `random_pickup`). Real filters use `cuvs::neighbors::detail::tag_filter_*` on +/// `CagraPlannerBase`. 
+struct tag_cagra_jit_sample_filter_link_absent {}; + +template +struct fragment_tag_setup_workspace {}; + +template +struct fragment_tag_compute_distance {}; + +template +struct fragment_tag_dist_op {}; + +template +struct fragment_tag_apply_normalization_standard {}; + +template +struct fragment_tag_search_single_cta {}; + +template +struct fragment_tag_search_single_cta_p {}; + +template +struct fragment_tag_search_multi_cta {}; + +template +struct fragment_tag_random_pickup {}; + +template +struct fragment_tag_compute_distance_to_child_nodes {}; + +template +struct fragment_tag_apply_filter_kernel {}; + +template +struct fragment_tag_sample_filter {}; + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/include/cuvs/detail/jit_lto/common_fragments.hpp b/cpp/include/cuvs/detail/jit_lto/common_fragments.hpp index 6607ada812..cb33e4109b 100644 --- a/cpp/include/cuvs/detail/jit_lto/common_fragments.hpp +++ b/cpp/include/cuvs/detail/jit_lto/common_fragments.hpp @@ -7,11 +7,16 @@ namespace cuvs::neighbors::detail { +struct tag_f {}; +struct tag_h {}; +struct tag_i8 {}; +struct tag_u8 {}; struct tag_filter_none {}; struct tag_filter_bitset {}; struct tag_bitset_u32 {}; +struct tag_index_u32 {}; struct tag_index_i64 {}; template diff --git a/cpp/include/cuvs/detail/jit_lto/ivf_flat/interleaved_scan_fragments.hpp b/cpp/include/cuvs/detail/jit_lto/ivf_flat/interleaved_scan_fragments.hpp index 526d094d5c..6f65837741 100644 --- a/cpp/include/cuvs/detail/jit_lto/ivf_flat/interleaved_scan_fragments.hpp +++ b/cpp/include/cuvs/detail/jit_lto/ivf_flat/interleaved_scan_fragments.hpp @@ -7,19 +7,13 @@ namespace cuvs::neighbors::ivf_flat::detail { -// Tag types for data types -struct tag_f {}; -struct tag_h {}; -struct tag_i8 {}; -struct tag_u8 {}; - // Tag types for accumulator types struct tag_acc_f {}; struct tag_acc_h {}; struct tag_acc_i32 {}; struct tag_acc_u32 {}; -// Tag types for distance metrics with full template info +// Tag types for distance metrics struct tag_metric_euclidean {}; struct tag_metric_inner_product {}; struct tag_metric_custom_udf {}; diff --git a/cpp/src/detail/jit_lto/AlgorithmLauncher.cpp b/cpp/src/detail/jit_lto/AlgorithmLauncher.cpp index abe519e688..b300f4d367 100644 --- a/cpp/src/detail/jit_lto/AlgorithmLauncher.cpp +++ b/cpp/src/detail/jit_lto/AlgorithmLauncher.cpp @@ -24,7 +24,6 @@ AlgorithmLauncher::AlgorithmLauncher(AlgorithmLauncher&& other) noexcept AlgorithmLauncher& AlgorithmLauncher::operator=(AlgorithmLauncher&& other) noexcept { if (this != &other) { - // Unload current library if it exists if (library != nullptr) { cudaLibraryUnload(library); } kernel = other.kernel; library = other.library; @@ -47,3 +46,21 @@ void AlgorithmLauncher::call( RAFT_CUDA_TRY(cudaLaunchKernelExC(&config, kernel, kernel_args)); } + +void AlgorithmLauncher::call_cooperative( + cudaStream_t stream, dim3 grid, dim3 block, std::size_t shared_mem, void** kernel_args) +{ + cudaLaunchAttribute attribute[1]; + attribute[0].id = cudaLaunchAttributeCooperative; + attribute[0].val.cooperative = 1; + + cudaLaunchConfig_t config{}; + config.gridDim = grid; + config.blockDim = block; + config.stream = stream; + config.dynamicSmemBytes = shared_mem; + config.numAttrs = 1; + config.attrs = attribute; + + RAFT_CUDA_TRY(cudaLaunchKernelExC(&config, kernel, kernel_args)); +} diff --git a/cpp/src/detail/jit_lto/AlgorithmPlanner.cpp b/cpp/src/detail/jit_lto/AlgorithmPlanner.cpp index 0556199ade..7416ea396d 100644 --- a/cpp/src/detail/jit_lto/AlgorithmPlanner.cpp +++ 
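// [Illustrative sketch, not part of the patch] call_cooperative() above sets the
// cudaLaunchAttributeCooperative attribute so that JIT-loaded kernels which
// synchronize across the whole grid can be launched legally. A self-contained
// equivalent using a statically compiled kernel (the kernel name, grid/block sizes
// and the data argument are assumptions made up for this example) looks like:
#include <cooperative_groups.h>
#include <cuda_runtime.h>

__global__ void grid_sync_example_kernel(int* data)
{
  auto grid                = cooperative_groups::this_grid();
  data[grid.thread_rank()] = 1;
  grid.sync();  // whole-grid barrier; only valid under a cooperative launch
}

inline cudaError_t launch_cooperatively(cudaStream_t stream, int* data)
{
  cudaLaunchAttribute attr{};
  attr.id              = cudaLaunchAttributeCooperative;
  attr.val.cooperative = 1;

  cudaLaunchConfig_t config{};
  config.gridDim          = dim3(4);
  config.blockDim         = dim3(128);
  config.dynamicSmemBytes = 0;
  config.stream           = stream;
  config.attrs            = &attr;
  config.numAttrs         = 1;

  // Same launch path as AlgorithmLauncher::call_cooperative(), which goes through
  // cudaLaunchKernelExC with the library-loaded cudaKernel_t handle instead.
  return cudaLaunchKernelEx(&config, grid_sync_example_kernel, data);
}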
b/cpp/src/detail/jit_lto/AlgorithmPlanner.cpp @@ -73,8 +73,14 @@ std::shared_ptr AlgorithmPlanner::build() // Load the generated LTO IR and link them together nvJitLinkHandle handle; - const char* lopts[] = {"-lto", archs.c_str()}; - auto result = nvJitLinkCreate(&handle, 2, lopts); + std::vector lopts; + lopts.reserve(2 + linktime_extra_options.size()); + lopts.push_back("-lto"); + lopts.push_back(archs.c_str()); + for (auto const& opt : linktime_extra_options) { + lopts.push_back(opt.c_str()); + } + auto result = nvJitLinkCreate(&handle, static_cast(lopts.size()), lopts.data()); check_nvjitlink_result(handle, result); for (const auto& frag : this->fragments) { diff --git a/cpp/src/neighbors/detail/cagra/cagra_search.cuh b/cpp/src/neighbors/detail/cagra/cagra_search.cuh index f1650980e0..bca8d3314d 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_search.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_search.cuh @@ -32,6 +32,7 @@ #include #include +// All includes are done before opening namespace to avoid nested namespace issues namespace cuvs::neighbors::cagra::detail { template const base_type* - { - return setup_workspace_impl(this, smem_ptr, queries_ptr, query_id); - } - - RAFT_DEVICE_INLINE_FUNCTION auto compute_distance(INDEX_T dataset_index, bool valid) const - -> DISTANCE_T - { - auto per_thread_distances = valid ? compute_distance_impl(args.load(), dataset_index) : 0; - return device::team_sum(per_thread_distances, team_size_bitshift_from_smem()); - } }; /** @@ -227,6 +201,14 @@ struct dataset_descriptor_host { uint32_t smem_ws_size_in_bytes = 0; uint32_t team_size = 0; + // JIT LTO metadata - stored when descriptor is created + cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Expanded; + uint32_t dataset_block_dim = 0; + bool is_vpq = false; + uint32_t pq_bits = 0; + uint32_t pq_len = 0; + // Codebook type is determined by DataT for VPQ (always half for now) + struct state { using ready_t = std::tuple; using init_f = @@ -270,10 +252,21 @@ struct dataset_descriptor_host { }; template - dataset_descriptor_host(const DescriptorImpl& dd_host, InitF init) + dataset_descriptor_host(const DescriptorImpl& dd_host, + InitF init, + cuvs::distance::DistanceType metric_val, + uint32_t dataset_block_dim_val, + bool is_vpq_val = false, + uint32_t pq_bits_val = 0, + uint32_t pq_len_val = 0) : value_{std::make_shared(init, sizeof(DescriptorImpl))}, smem_ws_size_in_bytes{dd_host.smem_ws_size_in_bytes()}, - team_size{dd_host.team_size()} + team_size{dd_host.team_size()}, + metric{metric_val}, + dataset_block_dim{dataset_block_dim_val}, + is_vpq{is_vpq_val}, + pq_bits{pq_bits_val}, + pq_len{pq_len_val} { } diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_standard-impl.cuh b/cpp/src/neighbors/detail/cagra/compute_distance_standard-impl.cuh index 05adce20e9..cde42f849b 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance_standard-impl.cuh +++ b/cpp/src/neighbors/detail/cagra/compute_distance_standard-impl.cuh @@ -7,67 +7,28 @@ #include "compute_distance_standard.hpp" #include -#include #include #include namespace cuvs::neighbors::cagra::detail { -namespace { -template - requires(Metric == cuvs::distance::DistanceType::L2Expanded) -RAFT_DEVICE_INLINE_FUNCTION constexpr auto dist_op(DATA_T a, DATA_T b) -{ - DISTANCE_T diff = a - b; - return diff * diff; -} - -template - requires(Metric == cuvs::distance::DistanceType::InnerProduct || - Metric == cuvs::distance::DistanceType::CosineExpanded) -RAFT_DEVICE_INLINE_FUNCTION constexpr auto dist_op(DATA_T a, DATA_T b) 
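// [Illustrative note, not part of the patch] The option vector built in
// AlgorithmPlanner::build() above now starts with "-lto" and "-arch=sm_XX" and then
// appends whatever a derived planner placed in the new protected
// `linktime_extra_options` member declared in AlgorithmPlanner.hpp. Usage from a
// derived planner's constructor body would look roughly like this; the planner name
// and the particular nvJitLink flag are assumptions, not something this patch adds:
//
//   MyCagraPlanner::MyCagraPlanner(/* ... */)
//     : AlgorithmPlanner(/* ... */)
//   {
//     // e.g. cap register usage at link time, similar in spirit to the
//     // -maxrregcount=64 compile flag the removed cuvs-cagra-search target used.
//     linktime_extra_options.push_back("-maxrregcount=64");
//   }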
-{ - return -static_cast(a) * static_cast(b); -} -template - requires(Metric == cuvs::distance::DistanceType::BitwiseHamming && std::is_integral_v) -RAFT_DEVICE_INLINE_FUNCTION constexpr auto dist_op(DATA_T a, DATA_T b) -{ - // mask the result of xor for the integer promotion - const auto v = (a ^ b) & 0xffu; - return __popc(v); -} - -template - requires(Metric == cuvs::distance::DistanceType::L1) -RAFT_DEVICE_INLINE_FUNCTION constexpr auto dist_op(DATA_T a, DATA_T b) -{ - DISTANCE_T diff = a - b; - return raft::abs(diff); -} -} // namespace - -template + typename DistanceT, + typename QueryT> struct standard_dataset_descriptor_t : public dataset_descriptor_base_t { using base_type = dataset_descriptor_base_t; - using QUERY_T = typename std:: - conditional_t; + using QUERY_T = QueryT; using base_type::args; using base_type::smem_ws_size_in_bytes; using typename base_type::args_t; - using typename base_type::compute_distance_type; using typename base_type::DATA_T; using typename base_type::DISTANCE_T; using typename base_type::INDEX_T; using typename base_type::LOAD_T; - using typename base_type::setup_workspace_type; - constexpr static inline auto kMetric = Metric; constexpr static inline auto kTeamSize = TeamSize; constexpr static inline auto kDatasetBlockDim = DatasetBlockDim; @@ -101,19 +62,12 @@ struct standard_dataset_descriptor_t : public dataset_descriptor_base_t::Log2, - get_smem_ws_size_in_bytes(dim)) + : base_type(size, dim, raft::Pow2::Log2, get_smem_ws_size_in_bytes(dim)) { standard_dataset_descriptor_t::ptr(args) = ptr; standard_dataset_descriptor_t::ld(args) = ld; @@ -122,7 +76,7 @@ struct standard_dataset_descriptor_t : public dataset_descriptor_base_t uint32_t { return sizeof(standard_dataset_descriptor_t) + @@ -130,123 +84,6 @@ struct standard_dataset_descriptor_t : public dataset_descriptor_base_t -_RAFT_DEVICE __noinline__ auto setup_workspace_standard( - const DescriptorT* that, - void* smem_ptr, - const typename DescriptorT::DATA_T* queries_ptr, - uint32_t query_id) -> const DescriptorT* -{ - using DATA_T = typename DescriptorT::DATA_T; - using LOAD_T = typename DescriptorT::LOAD_T; - using base_type = typename DescriptorT::base_type; - using QUERY_T = typename DescriptorT::QUERY_T; - using word_type = uint32_t; - constexpr auto kTeamSize = DescriptorT::kTeamSize; - constexpr auto kDatasetBlockDim = DescriptorT::kDatasetBlockDim; - auto* r = reinterpret_cast(smem_ptr); - auto* buf = reinterpret_cast(r + 1); - if (r != that) { - constexpr uint32_t kCount = sizeof(DescriptorT) / sizeof(word_type); - using blob_type = word_type[kCount]; - auto& src = reinterpret_cast(*that); - auto& dst = reinterpret_cast(*r); - for (uint32_t i = threadIdx.x; i < kCount; i += blockDim.x) { - dst[i] = src[i]; - } - const auto smem_ptr_offset = - reinterpret_cast(&(r->args.smem_ws_ptr)) - reinterpret_cast(r); - if (threadIdx.x == uint32_t(smem_ptr_offset / sizeof(word_type))) { - r->args.smem_ws_ptr = uint32_t(__cvta_generic_to_shared(buf)); - } - __syncthreads(); - } - - uint32_t dim = r->args.dim; - auto buf_len = raft::round_up_safe(dim, kDatasetBlockDim); - constexpr auto vlen = device::get_vlen(); - queries_ptr += dim * query_id; - for (unsigned i = threadIdx.x; i < buf_len; i += blockDim.x) { - unsigned j = device::swizzling(i); - if (i < dim) { - buf[j] = cuvs::spatial::knn::detail::utils::mapping{}(queries_ptr[i]); - } else { - buf[j] = 0; - } - } - - return const_cast(r); -} - -template -RAFT_DEVICE_INLINE_FUNCTION auto compute_distance_standard_worker( - const typename 
DescriptorT::DATA_T* __restrict__ dataset_ptr, - uint32_t dim, - uint32_t query_smem_ptr) -> typename DescriptorT::DISTANCE_T -{ - using DATA_T = typename DescriptorT::DATA_T; - using DISTANCE_T = typename DescriptorT::DISTANCE_T; - using LOAD_T = typename DescriptorT::LOAD_T; - using QUERY_T = typename DescriptorT::QUERY_T; - constexpr auto kTeamSize = DescriptorT::kTeamSize; - constexpr auto kDatasetBlockDim = DescriptorT::kDatasetBlockDim; - constexpr auto vlen = device::get_vlen(); - constexpr auto reg_nelem = - raft::div_rounding_up_unsafe(kDatasetBlockDim, kTeamSize * vlen); - - DISTANCE_T r = 0; - for (uint32_t elem_offset = (threadIdx.x % kTeamSize) * vlen; elem_offset < dim; - elem_offset += kDatasetBlockDim) { - DATA_T data[reg_nelem][vlen]; -#pragma unroll - for (uint32_t e = 0; e < reg_nelem; e++) { - const uint32_t k = e * (kTeamSize * vlen) + elem_offset; - if (k >= dim) break; - device::ldg_cg(reinterpret_cast(data[e]), - reinterpret_cast(dataset_ptr + k)); - } -#pragma unroll - for (uint32_t e = 0; e < reg_nelem; e++) { - const uint32_t k = e * (kTeamSize * vlen) + elem_offset; - if (k >= dim) break; -#pragma unroll - for (uint32_t v = 0; v < vlen; v++) { - // Note this loop can go above the dataset_dim for padded arrays. This is not a problem - // because: - // - Above the last element (dataset_dim-1), the query array is filled with zeros. - // - The data buffer has to be also padded with zeros. - QUERY_T d; - device::lds( - d, - query_smem_ptr + - sizeof(QUERY_T) * device::swizzling(k + v)); - r += dist_op( - d, cuvs::spatial::knn::detail::utils::mapping{}(data[e][v])); - } - } - } - return r; -} - -template -_RAFT_DEVICE __noinline__ auto compute_distance_standard( - const typename DescriptorT::args_t args, const typename DescriptorT::INDEX_T dataset_index) -> - typename DescriptorT::DISTANCE_T -{ - auto distance = compute_distance_standard_worker( - DescriptorT::ptr(args) + (static_cast(DescriptorT::ld(args)) * dataset_index), - args.dim, - args.smem_ws_ptr); - - if constexpr (DescriptorT::kMetric == cuvs::distance::DistanceType::CosineExpanded) { - const auto* dataset_norms = DescriptorT::dataset_norms_ptr(args); - auto norm = dataset_norms[dataset_index]; - if (norm > 0) { distance = distance / norm; } - } - - return distance; -} - template ; using desc_type = - standard_dataset_descriptor_t; + standard_dataset_descriptor_t; using base_type = typename desc_type::base_type; - new (out) desc_type(reinterpret_cast( - &setup_workspace_standard), - reinterpret_cast( - &compute_distance_standard), - ptr, - size, - dim, - ld, - dataset_norms); + + new (out) desc_type(ptr, size, dim, ld, dataset_norms); } template ; using desc_type = - standard_dataset_descriptor_t; + standard_dataset_descriptor_t; using base_type = typename desc_type::base_type; RAFT_EXPECTS(Metric != cuvs::distance::DistanceType::CosineExpanded || dataset_norms != nullptr, "Dataset norms must be provided for CosineExpanded metric"); - desc_type dd_host{nullptr, nullptr, ptr, size, dim, ld, dataset_norms}; - return host_type{dd_host, + return host_type{desc_type{ptr, size, dim, ld, dataset_norms}, [=](dataset_descriptor_base_t* dev_ptr, rmm::cuda_stream_view stream) { standard_dataset_descriptor_init_kernel <<<1, 1, 0, stream>>>(dev_ptr, ptr, size, dim, ld, dataset_norms); RAFT_CUDA_TRY(cudaPeekAtLastError()); - }}; + }, + Metric, + DatasetBlockDim, + false, // is_vpq + 0, // pq_bits + 0}; // pq_len } } // namespace cuvs::neighbors::cagra::detail diff --git 
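// [Illustrative sketch, not part of the patch] The metric-specific dist_op()
// overloads deleted above are not gone: the same math now sits behind the
// fragment_tag_dist_op JIT-LTO fragments generated from dist_op_matrix.json in the
// CMake changes. For the L2 case the fragment body is essentially (the free-function
// name here is an assumption):
template <typename QueryT, typename DistanceT>
__device__ inline DistanceT dist_op_l2(QueryT a, QueryT b)
{
  const DistanceT diff = static_cast<DistanceT>(a) - static_cast<DistanceT>(b);
  return diff * diff;  // per-element squared difference, team-summed by the caller
}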
a/cpp/src/neighbors/detail/cagra/compute_distance_vpq-impl.cuh b/cpp/src/neighbors/detail/cagra/compute_distance_vpq-impl.cuh index cdafb173ed..6992ae979a 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance_vpq-impl.cuh +++ b/cpp/src/neighbors/detail/cagra/compute_distance_vpq-impl.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -8,36 +8,32 @@ #include "compute_distance_vpq.hpp" #include -#include #include #include namespace cuvs::neighbors::cagra::detail { -template + typename DistanceT, + typename QueryT> struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t { using base_type = dataset_descriptor_base_t; using CODE_BOOK_T = CodebookT; - using QUERY_T = half; + using QUERY_T = QueryT; using base_type::args; using base_type::extra_ptr3; using typename base_type::args_t; - using typename base_type::compute_distance_type; using typename base_type::DATA_T; using typename base_type::DISTANCE_T; using typename base_type::INDEX_T; using typename base_type::LOAD_T; - using typename base_type::setup_workspace_type; - constexpr static inline auto kMetric = Metric; constexpr static inline auto kTeamSize = TeamSize; constexpr static inline auto kDatasetBlockDim = DatasetBlockDim; constexpr static inline auto kPqBits = PQ_BITS; @@ -87,20 +83,13 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t(); - _RAFT_HOST_DEVICE cagra_q_dataset_descriptor_t(setup_workspace_type* setup_workspace_impl, - compute_distance_type* compute_distance_impl, - const std::uint8_t* encoded_dataset_ptr, + _RAFT_HOST_DEVICE cagra_q_dataset_descriptor_t(const std::uint8_t* encoded_dataset_ptr, std::uint32_t encoded_dataset_dim, const CODE_BOOK_T* vq_code_book_ptr, const CODE_BOOK_T* pq_code_book_ptr, IndexT size, std::uint32_t dim) - : base_type(setup_workspace_impl, - compute_distance_impl, - size, - dim, - raft::Pow2::Log2, - get_smem_ws_size_in_bytes(dim)) + : base_type(size, dim, raft::Pow2::Log2, get_smem_ws_size_in_bytes(dim)) { cagra_q_dataset_descriptor_t::encoded_dataset_ptr(args) = encoded_dataset_ptr; cagra_q_dataset_descriptor_t::vq_code_book_ptr(args) = vq_code_book_ptr; @@ -110,7 +99,7 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t uint32_t { /* SMEM workspace layout: @@ -121,231 +110,9 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t(dim, DatasetBlockDim) * sizeof(QUERY_T); } -}; - -template -RAFT_DEVICE_INLINE_FUNCTION constexpr auto transpose(T x) -> T -{ - auto i = x % Block; - auto j = x / Block; - auto k = i % Stride; - auto l = i / Stride; - return j * Block + k * (Block / Stride) + l; -} - -template -_RAFT_DEVICE __noinline__ auto setup_workspace_vpq(const DescriptorT* that, - void* smem_ptr, - const typename DescriptorT::DATA_T* queries_ptr, - uint32_t query_id) -> const DescriptorT* -{ - using QUERY_T = typename DescriptorT::QUERY_T; - using CODE_BOOK_T = typename DescriptorT::CODE_BOOK_T; - using word_type = uint32_t; - constexpr auto kDatasetBlockDim = DescriptorT::kDatasetBlockDim; - constexpr auto PQ_BITS = DescriptorT::kPqBits; - constexpr auto PQ_LEN = DescriptorT::kPqLen; - - auto* r = reinterpret_cast(smem_ptr); - - if (r != that) { - constexpr uint32_t kCount = sizeof(DescriptorT) / sizeof(word_type); - using blob_type = word_type[kCount]; - auto& src = reinterpret_cast(*that); - auto& dst = reinterpret_cast(*r); - for (uint32_t i = 
threadIdx.x; i < kCount; i += blockDim.x) { - dst[i] = src[i]; - } - - auto codebook_buf = uint32_t(__cvta_generic_to_shared(r + 1)); - const auto smem_ptr_offset = - reinterpret_cast(&(r->args.smem_ws_ptr)) - reinterpret_cast(r); - if (threadIdx.x == uint32_t(smem_ptr_offset / sizeof(word_type))) { - r->args.smem_ws_ptr = codebook_buf; - } - __syncthreads(); - - // Copy PQ table - for (unsigned i = threadIdx.x * 2; i < (1 << PQ_BITS) * PQ_LEN; i += blockDim.x * 2) { - half2 buf2; - buf2.x = r->pq_code_book_ptr()[i]; - buf2.y = r->pq_code_book_ptr()[i + 1]; - - // Change the order of PQ code book array to reduce the - // frequency of bank conflicts. - constexpr auto num_elements_per_bank = 4 / utils::size_of(); - constexpr auto num_banks_per_subspace = PQ_LEN / num_elements_per_bank; - const auto j = i / num_elements_per_bank; - const auto smem_index = - (j / num_banks_per_subspace) + (j % num_banks_per_subspace) * (1 << PQ_BITS); - - device::sts(codebook_buf + smem_index * sizeof(half2), buf2); - } - } - uint32_t dim = r->args.dim; - queries_ptr += dim * query_id; - - constexpr cuvs::spatial::knn::detail::utils::mapping mapping{}; - auto smem_query_ptr = - reinterpret_cast(reinterpret_cast(smem_ptr) + sizeof(DescriptorT) + - DescriptorT::kSMemCodeBookSizeInBytes); - for (unsigned i = threadIdx.x * 2; i < dim; i += blockDim.x * 2) { - half2 buf2{0, 0}; - if (i < dim) { buf2.x = mapping(queries_ptr[i]); } - if (i + 1 < dim) { buf2.y = mapping(queries_ptr[i + 1]); } - if constexpr ((PQ_BITS == 8) && (PQ_LEN % 2 == 0)) { - // Transpose the queries buffer to avoid bank conflicts in compute_distance. - constexpr uint32_t vlen = 4; // **** DO NOT CHANGE **** - constexpr auto kStride = vlen * PQ_LEN / 2; - reinterpret_cast(smem_query_ptr)[transpose(i / 2)] = - buf2; - } else { - (reinterpret_cast(smem_query_ptr + i))[0] = buf2; - } - } - - return const_cast(r); -} - -template -_RAFT_DEVICE RAFT_DEVICE_INLINE_FUNCTION auto compute_distance_vpq_worker( - const uint8_t* __restrict__ dataset_ptr, - const typename DescriptorT::CODE_BOOK_T* __restrict__ vq_code_book_ptr, - uint32_t dim, - uint32_t pq_codebook_ptr) -> typename DescriptorT::DISTANCE_T -{ - using DISTANCE_T = typename DescriptorT::DISTANCE_T; - using LOAD_T = typename DescriptorT::LOAD_T; - using QUERY_T = typename DescriptorT::QUERY_T; - using CODE_BOOK_T = typename DescriptorT::CODE_BOOK_T; - constexpr auto TeamSize = DescriptorT::kTeamSize; - constexpr auto DatasetBlockDim = DescriptorT::kDatasetBlockDim; - constexpr auto PQ_BITS = DescriptorT::kPqBits; - constexpr auto PQ_LEN = DescriptorT::kPqLen; - - const uint32_t query_ptr = pq_codebook_ptr + DescriptorT::kSMemCodeBookSizeInBytes; - static_assert(PQ_BITS == 8, "Only pq_bits == 8 is supported at the moment."); - constexpr uint32_t vlen = 4; // **** DO NOT CHANGE **** - constexpr uint32_t nelem = - raft::div_rounding_up_unsafe(DatasetBlockDim / PQ_LEN, TeamSize * vlen); - - constexpr auto kTeamMask = DescriptorT::kTeamSize - 1; - constexpr auto kTeamVLen = TeamSize * vlen; - - const auto n_subspace = raft::div_rounding_up_unsafe(dim, PQ_LEN); - const auto laneId = threadIdx.x & kTeamMask; - DISTANCE_T norm = 0; - for (uint32_t elem_offset = 0; elem_offset * PQ_LEN < dim; - elem_offset += DatasetBlockDim / PQ_LEN) { - // Loading PQ codes - uint32_t pq_codes[nelem]; -#pragma unroll - for (std::uint32_t e = 0; e < nelem; e++) { - const std::uint32_t k = e * kTeamVLen + elem_offset + laneId * vlen; - if (k >= n_subspace) break; - // Loading 4 x 8-bit PQ-codes using 32-bit load ops (from 
device memory) - device::ldg_cg(pq_codes[e], reinterpret_cast(dataset_ptr + 4 + k)); - } - // - if constexpr (PQ_LEN % 2 == 0) { - // **** Use half2 for distance computation **** -#pragma unroll - for (std::uint32_t e = 0; e < nelem; e++) { - const std::uint32_t k = e * kTeamVLen + elem_offset + laneId * vlen; - if (k >= n_subspace) break; - // Loading VQ code-book - half2 vq_vals[PQ_LEN][vlen / 2]; -#pragma unroll - for (std::uint32_t m = 0; m < PQ_LEN; m++) { - const uint32_t d = (vlen * m) + (PQ_LEN * k); - if (d >= dim) break; - device::ldg_ca(vq_vals[m], vq_code_book_ptr + d); - } - // Compute distance - std::uint32_t pq_code = pq_codes[e]; -#pragma unroll - for (std::uint32_t v = 0; v < vlen; v++) { - if (PQ_LEN * (v + k) >= dim) break; -#pragma unroll - for (std::uint32_t m = 0; m < PQ_LEN / 2; m++) { - constexpr auto kQueryBlock = DatasetBlockDim / (vlen * PQ_LEN); - const std::uint32_t d1 = m + (PQ_LEN / 2) * v; - const std::uint32_t d = - d1 * kQueryBlock + elem_offset * (PQ_LEN / 2) + e * TeamSize + laneId; - half2 q2, c2; - // Loading query vector from smem - device::lds(q2, query_ptr + sizeof(half2) * d); - // Loading PQ code book from smem - device::lds(c2, - pq_codebook_ptr + - sizeof(CODE_BOOK_T) * ((1 << PQ_BITS) * 2 * m + (2 * (pq_code & 0xff)))); - // L2 distance - auto dist = q2 - c2 - reinterpret_cast(vq_vals)[d1]; - dist = dist * dist; - norm += static_cast(dist.x + dist.y); - } - pq_code >>= 8; - } - } - } else { - // **** Use float for distance computation **** -#pragma unroll - for (std::uint32_t e = 0; e < nelem; e++) { - const std::uint32_t k = e * kTeamVLen + elem_offset + laneId * vlen; - if (k >= n_subspace) break; - // Loading VQ code-book - CODE_BOOK_T vq_vals[PQ_LEN][vlen]; -#pragma unroll - for (std::uint32_t m = 0; m < PQ_LEN; m++) { - const std::uint32_t d = (vlen * m) + (PQ_LEN * k); - if (d >= dim) break; - // Loading 4 x 8/16-bit VQ-values using 32/64-bit load ops (from L2$ or device memory) - device::ldg_ca(vq_vals[m], vq_code_book_ptr + d); - } - // Compute distance - std::uint32_t pq_code = pq_codes[e]; -#pragma unroll - for (std::uint32_t v = 0; v < vlen; v++) { - if (PQ_LEN * (v + k) >= dim) break; - CODE_BOOK_T pq_vals[PQ_LEN]; - device::lds(pq_vals, pq_codebook_ptr + sizeof(CODE_BOOK_T) * PQ_LEN * (pq_code & 0xff)); -#pragma unroll - for (std::uint32_t m = 0; m < PQ_LEN; m++) { - const std::uint32_t d1 = m + (PQ_LEN * v); - const std::uint32_t d = d1 + (PQ_LEN * k); - // if (d >= dataset_dim) break; - DISTANCE_T diff; - device::lds(diff, query_ptr + sizeof(QUERY_T) * d); - diff -= static_cast(pq_vals[m]); - diff -= - static_cast(reinterpret_cast(vq_vals)[d1]); - norm += diff * diff; - } - pq_code >>= 8; - } - } - } - } - return norm; -} - -template -_RAFT_DEVICE __noinline__ auto compute_distance_vpq( - const typename DescriptorT::args_t args, const typename DescriptorT::INDEX_T dataset_index) -> - typename DescriptorT::DISTANCE_T -{ - const auto* dataset_ptr = - DescriptorT::encoded_dataset_ptr(args) + - (static_cast(DescriptorT::encoded_dataset_dim(args)) * dataset_index); - uint32_t vq_code; - device::ldg_cg(vq_code, reinterpret_cast(dataset_ptr)); - return compute_distance_vpq_worker( - dataset_ptr /* advance dataset pointer by the size of vq_code */, - DescriptorT::vq_code_book_ptr(args) + args.dim * vq_code, - args.dim, - args.smem_ws_ptr); -} + private: +}; template ; - using base_type = typename desc_type::base_type; + DistanceT, + half>; new (out) desc_type( - reinterpret_cast(&setup_workspace_vpq), - 
reinterpret_cast(&compute_distance_vpq), - encoded_dataset_ptr, - encoded_dataset_dim, - vq_code_book_ptr, - pq_code_book_ptr, - size, - dim); + encoded_dataset_ptr, encoded_dataset_dim, vq_code_book_ptr, pq_code_book_ptr, size, dim); } template ; - using base_type = typename desc_type::base_type; - - desc_type dd_host{nullptr, - nullptr, - encoded_dataset_ptr, - encoded_dataset_dim, - vq_code_book_ptr, - pq_code_book_ptr, - size, - dim}; - return host_type{dd_host, - [=](dataset_descriptor_base_t* dev_ptr, - rmm::cuda_stream_view stream) { - vpq_dataset_descriptor_init_kernel - <<<1, 1, 0, stream>>>(dev_ptr, - encoded_dataset_ptr, - encoded_dataset_dim, - vq_code_book_ptr, - pq_code_book_ptr, - size, - dim); - RAFT_CUDA_TRY(cudaPeekAtLastError()); - }}; + DistanceT, + half>; + + return host_type{ + desc_type{ + encoded_dataset_ptr, encoded_dataset_dim, vq_code_book_ptr, pq_code_book_ptr, size, dim}, + [=](dataset_descriptor_base_t* dev_ptr, + rmm::cuda_stream_view stream) { + vpq_dataset_descriptor_init_kernel<<<1, 1, 0, stream>>>(dev_ptr, + encoded_dataset_ptr, + encoded_dataset_dim, + vq_code_book_ptr, + pq_code_book_ptr, + size, + dim); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + }, + Metric, + DatasetBlockDim, + true, // is_vpq + PqBits, // pq_bits + PqLen}; // pq_len } } // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/device_common.hpp b/cpp/src/neighbors/detail/cagra/device_common.hpp deleted file mode 100644 index 0b75de6bab..0000000000 --- a/cpp/src/neighbors/detail/cagra/device_common.hpp +++ /dev/null @@ -1,385 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. - * SPDX-License-Identifier: Apache-2.0 - */ -#pragma once - -#include "hashmap.hpp" -#include "utils.hpp" - -#include - -// TODO: This shouldn't be invoking anything in detail APIs outside of cuvs/neighbors -#include -#include -#include - -#include - -#include -#include - -namespace cuvs::neighbors::cagra::detail { -namespace device { - -// warpSize for compile time calculation -constexpr unsigned warp_size = 32; - -// using LOAD_256BIT_T = ulonglong4; -using LOAD_128BIT_T = uint4; -using LOAD_64BIT_T = uint64_t; - -template -RAFT_DEVICE_INLINE_FUNCTION constexpr unsigned get_vlen() -{ - return utils::size_of() / utils::size_of(); -} - -/** Xorshift rondem number generator. - * - * See https://en.wikipedia.org/wiki/Xorshift#xorshift for reference. - */ -_RAFT_HOST_DEVICE inline uint64_t xorshift64(uint64_t u) -{ - u ^= u >> 12; - u ^= u << 25; - u ^= u >> 27; - return u * 0x2545F4914F6CDD1DULL; -} - -template -RAFT_DEVICE_INLINE_FUNCTION constexpr auto swizzling(T x) -> T -{ - // Address swizzling reduces bank conflicts in shared memory, but increases - // the amount of operation instead. 
- // return x; - if constexpr (Stride <= 32) { - return x; - } else if constexpr (Dim <= 1024) { - return x ^ (x >> 5); - } else { - return x ^ ((x >> 5) & 0x1f); - } -} - -template -RAFT_DEVICE_INLINE_FUNCTION auto team_sum(T x) -> T -{ -#pragma unroll - for (uint32_t stride = TeamSize >> 1; stride > 0; stride >>= 1) { - x += raft::shfl_xor(x, stride, TeamSize); - } - return x; -} - -template -RAFT_DEVICE_INLINE_FUNCTION auto team_sum(T x, uint32_t team_size_bitshift) -> T -{ - switch (team_size_bitshift) { - case 5: x += raft::shfl_xor(x, 16); [[fallthrough]]; - case 4: x += raft::shfl_xor(x, 8); [[fallthrough]]; - case 3: x += raft::shfl_xor(x, 4); [[fallthrough]]; - case 2: x += raft::shfl_xor(x, 2); [[fallthrough]]; - case 1: x += raft::shfl_xor(x, 1); [[fallthrough]]; - default: return x; - } -} - -template -RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_random_nodes( - IndexT* __restrict__ result_indices_ptr, // [num_pickup] - DistanceT* __restrict__ result_distances_ptr, // [num_pickup] - const DATASET_DESCRIPTOR_T& dataset_desc, - const uint32_t num_pickup, - const uint32_t num_distilation, - const uint64_t rand_xor_mask, - const IndexT* __restrict__ seed_ptr, // [num_seeds] - const uint32_t num_seeds, - IndexT* __restrict__ visited_hash_ptr, - const uint32_t visited_hash_bitlen, - IndexT* __restrict__ traversed_hash_ptr, - const uint32_t traversed_hash_bitlen, - const uint32_t block_id = 0, - const uint32_t num_blocks = 1, - const IndexT graph_size = 0) -{ - const auto team_size_bits = dataset_desc.team_size_bitshift_from_smem(); - const auto max_i = raft::round_up_safe(num_pickup, warp_size >> team_size_bits); - const auto compute_distance = dataset_desc.compute_distance_impl; - const IndexT seed_index_limit = graph_size > 0 ? graph_size : dataset_desc.size; - - for (uint32_t i = threadIdx.x >> team_size_bits; i < max_i; i += (blockDim.x >> team_size_bits)) { - const bool valid_i = (i < num_pickup); - - IndexT best_index_team_local = raft::upper_bound(); - DistanceT best_norm2_team_local = raft::upper_bound(); - for (uint32_t j = 0; j < num_distilation; j++) { - // Select a node randomly and compute the distance to it - IndexT seed_index = 0; - if (valid_i) { - // uint32_t gid = i + (num_pickup * (j + (num_distilation * block_id))); - uint32_t gid = block_id + (num_blocks * (i + (num_pickup * j))); - if (seed_ptr && (gid < num_seeds)) { - seed_index = seed_ptr[gid]; - } else { - seed_index = device::xorshift64(gid ^ rand_xor_mask) % seed_index_limit; - } - } - - const auto norm2 = dataset_desc.compute_distance(seed_index, valid_i); - - if (valid_i && (norm2 < best_norm2_team_local)) { - best_norm2_team_local = norm2; - best_index_team_local = seed_index; - } - } - - const unsigned lane_id = threadIdx.x & ((1u << team_size_bits) - 1u); - if (valid_i && lane_id == 0) { - if (best_index_team_local != raft::upper_bound()) { - if (hashmap::insert(visited_hash_ptr, visited_hash_bitlen, best_index_team_local) == 0) { - // Deactivate this entry as insertion into visited hash table has failed. - best_norm2_team_local = raft::upper_bound(); - best_index_team_local = raft::upper_bound(); - } else if ((traversed_hash_ptr != nullptr) && - hashmap::search( - traversed_hash_ptr, traversed_hash_bitlen, best_index_team_local)) { - // Deactivate this entry as it has been already used by others. 
- best_norm2_team_local = raft::upper_bound(); - best_index_team_local = raft::upper_bound(); - } - } - result_distances_ptr[i] = best_norm2_team_local; - result_indices_ptr[i] = best_index_team_local; - } - } -} - -template -RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_child_nodes( - IndexT* __restrict__ result_child_indices_ptr, - DistanceT* __restrict__ result_child_distances_ptr, - // [dataset_dim, dataset_size] - const DATASET_DESCRIPTOR_T& dataset_desc, - // [knn_k, dataset_size] - const IndexT* __restrict__ knn_graph, - const uint32_t knn_k, - // hashmap - IndexT* __restrict__ visited_hashmap_ptr, - const uint32_t visited_hash_bitlen, - IndexT* __restrict__ traversed_hashmap_ptr, - const uint32_t traversed_hash_bitlen, - const IndexT* __restrict__ parent_indices, - const IndexT* __restrict__ internal_topk_list, - const uint32_t search_width, - int* __restrict__ result_position = nullptr, - const int max_result_position = 0) -{ - constexpr IndexT index_msb_1_mask = utils::gen_index_msb_1_mask::value; - constexpr IndexT invalid_index = ~static_cast(0); - - // Read child indices of parents from knn graph and check if the distance - // computaiton is necessary. - for (uint32_t i = threadIdx.x; i < knn_k * search_width; i += blockDim.x) { - const IndexT smem_parent_id = parent_indices[i / knn_k]; - IndexT child_id = invalid_index; - if (smem_parent_id != invalid_index) { - const auto parent_id = internal_topk_list[smem_parent_id] & ~index_msb_1_mask; - child_id = knn_graph[(i % knn_k) + (static_cast(knn_k) * parent_id)]; - } - if (child_id != invalid_index) { - if (hashmap::insert(visited_hashmap_ptr, visited_hash_bitlen, child_id) == 0) { - // Deactivate this entry as insertion into visited hash table has failed. - child_id = invalid_index; - } else if ((traversed_hashmap_ptr != nullptr) && - hashmap::search( - traversed_hashmap_ptr, traversed_hash_bitlen, child_id)) { - // Deactivate this entry as this has been already used by others. - child_id = invalid_index; - } - } - if (STATIC_RESULT_POSITION) { - result_child_indices_ptr[i] = child_id; - } else if (child_id != invalid_index) { - int j = atomicSub(result_position, 1) - 1; - result_child_indices_ptr[j] = child_id; - } - } - __syncthreads(); - - // Compute the distance to child nodes - const auto team_size_bits = dataset_desc.team_size_bitshift_from_smem(); - const auto num_k = knn_k * search_width; - const auto max_i = raft::round_up_safe(num_k, warp_size >> team_size_bits); - const auto compute_distance = dataset_desc.compute_distance_impl; - const auto args = dataset_desc.args.load(); - const bool lead_lane = (threadIdx.x & ((1u << team_size_bits) - 1u)) == 0; - const uint32_t ofst = STATIC_RESULT_POSITION ? 0 : result_position[0]; - for (uint32_t i = threadIdx.x >> team_size_bits; i < max_i; i += blockDim.x >> team_size_bits) { - const auto j = i + ofst; - const bool valid_i = STATIC_RESULT_POSITION ? (j < num_k) : (j < max_result_position); - const auto child_id = valid_i ? result_child_indices_ptr[j] : invalid_index; - - // We should be calling `dataset_desc.compute_distance(..)` here as follows: - // > const auto child_dist = dataset_desc.compute_distance(child_id, child_id != invalid_index); - // Instead, we manually inline this function for performance reasons. - // This allows us to move the fetching of the arguments from shared memory out of the loop. - const DistanceT child_dist = device::team_sum( - (child_id != invalid_index) ? compute_distance(args, child_id) - : (lead_lane ? 
raft::upper_bound() : 0), - team_size_bits); - __syncwarp(); - - // Store the distance - if (valid_i && lead_lane) { result_child_distances_ptr[j] = child_dist; } - } -} - -RAFT_DEVICE_INLINE_FUNCTION void lds(float& x, uint32_t addr) -{ - asm volatile("ld.shared.f32 {%0}, [%1];" : "=f"(x) : "r"(addr)); -} -RAFT_DEVICE_INLINE_FUNCTION void lds(half& x, uint32_t addr) -{ - asm volatile("ld.shared.u16 {%0}, [%1];" : "=h"(reinterpret_cast(x)) : "r"(addr)); -} -RAFT_DEVICE_INLINE_FUNCTION void lds(half2& x, uint32_t addr) -{ - asm volatile("ld.shared.u32 {%0}, [%1];" : "=r"(reinterpret_cast(x)) : "r"(addr)); -} -RAFT_DEVICE_INLINE_FUNCTION void lds(half (&x)[1], uint32_t addr) -{ - asm volatile("ld.shared.u16 {%0}, [%1];" : "=h"(*reinterpret_cast(x)) : "r"(addr)); -} -RAFT_DEVICE_INLINE_FUNCTION void lds(half (&x)[2], uint32_t addr) -{ - asm volatile("ld.shared.v2.u16 {%0, %1}, [%2];" - : "=h"(*reinterpret_cast(x)), "=h"(*reinterpret_cast(x + 1)) - : "r"(addr)); -} -RAFT_DEVICE_INLINE_FUNCTION void lds(half (&x)[4], uint32_t addr) -{ - asm volatile("ld.shared.v4.u16 {%0, %1, %2, %3}, [%4];" - : "=h"(*reinterpret_cast(x)), - "=h"(*reinterpret_cast(x + 1)), - "=h"(*reinterpret_cast(x + 2)), - "=h"(*reinterpret_cast(x + 3)) - : "r"(addr)); -} - -RAFT_DEVICE_INLINE_FUNCTION void lds(uint8_t& x, uint32_t addr) -{ - uint32_t res; - asm volatile("ld.shared.u8 {%0}, [%1];" : "=r"(res) : "r"(addr)); - x = static_cast(res); -} - -RAFT_DEVICE_INLINE_FUNCTION void lds(uint32_t& x, uint32_t addr) -{ - asm volatile("ld.shared.u32 {%0}, [%1];" : "=r"(x) : "r"(addr)); -} - -RAFT_DEVICE_INLINE_FUNCTION void lds(uint32_t& x, const uint32_t* addr) -{ - lds(x, uint32_t(__cvta_generic_to_shared(addr))); -} - -RAFT_DEVICE_INLINE_FUNCTION void lds(uint4& x, uint32_t addr) -{ - asm volatile("ld.shared.v4.u32 {%0, %1, %2, %3}, [%4];" - : "=r"(x.x), "=r"(x.y), "=r"(x.z), "=r"(x.w) - : "r"(addr)); -} - -RAFT_DEVICE_INLINE_FUNCTION void lds(uint4& x, const uint4* addr) -{ - lds(x, uint32_t(__cvta_generic_to_shared(addr))); -} - -RAFT_DEVICE_INLINE_FUNCTION void sts(uint32_t addr, const half2& x) -{ - asm volatile("st.shared.v2.u16 [%0], {%1, %2};" - : - : "r"(addr), - "h"(reinterpret_cast(x.x)), - "h"(reinterpret_cast(x.y))); -} - -RAFT_DEVICE_INLINE_FUNCTION void ldg_cg(uint4& x, const uint4* addr) -{ - asm volatile("ld.global.cg.v4.u32 {%0, %1, %2, %3}, [%4];" - : "=r"(x.x), "=r"(x.y), "=r"(x.z), "=r"(x.w) - : "l"(addr)); -} - -RAFT_DEVICE_INLINE_FUNCTION void ldg_ca(uint4& x, const uint4* addr) -{ - asm volatile("ld.global.ca.v4.u32 {%0, %1, %2, %3}, [%4];" - : "=r"(x.x), "=r"(x.y), "=r"(x.z), "=r"(x.w) - : "l"(addr)); -} - -RAFT_DEVICE_INLINE_FUNCTION void ldg_ca(uint32_t& x, const uint32_t* addr) -{ - asm volatile("ld.global.ca.u32 %0, [%1];" : "=r"(x) : "l"(addr)); -} - -RAFT_DEVICE_INLINE_FUNCTION void ldg_cg(uint32_t& x, const uint32_t* addr) -{ - asm volatile("ld.global.cg.u32 %0, [%1];" : "=r"(x) : "l"(addr)); -} - -RAFT_DEVICE_INLINE_FUNCTION void ldg_ca(half& x, const half* addr) -{ - asm volatile("ld.global.ca.u16 {%0}, [%1];" - : "=h"(reinterpret_cast(x)) - : "l"(reinterpret_cast(addr))); -} -RAFT_DEVICE_INLINE_FUNCTION void ldg_ca(half (&x)[1], const half* addr) -{ - asm volatile("ld.global.ca.u16 {%0}, [%1];" - : "=h"(*reinterpret_cast(x)) - : "l"(reinterpret_cast(addr))); -} -RAFT_DEVICE_INLINE_FUNCTION void ldg_ca(half (&x)[2], const half* addr) -{ - asm volatile("ld.global.ca.v2.u16 {%0, %1}, [%2];" - : "=h"(*reinterpret_cast(x)), "=h"(*reinterpret_cast(x + 1)) - : "l"(reinterpret_cast(addr))); -} 
-RAFT_DEVICE_INLINE_FUNCTION void ldg_ca(half (&x)[4], const half* addr) -{ - asm volatile("ld.global.ca.v4.u16 {%0, %1, %2, %3}, [%4];" - : "=h"(*reinterpret_cast(x)), - "=h"(*reinterpret_cast(x + 1)), - "=h"(*reinterpret_cast(x + 2)), - "=h"(*reinterpret_cast(x + 3)) - : "l"(reinterpret_cast(addr))); -} - -RAFT_DEVICE_INLINE_FUNCTION void ldg_ca(half2& x, const half* addr) -{ - asm volatile("ld.global.ca.u32 %0, [%1];" - : "=r"(reinterpret_cast(x)) - : "l"(reinterpret_cast(addr))); -} -RAFT_DEVICE_INLINE_FUNCTION void ldg_ca(half2 (&x)[1], const half* addr) -{ - asm volatile("ld.global.ca.u32 %0, [%1];" - : "=r"(*reinterpret_cast(x)) - : "l"(reinterpret_cast(addr))); -} -RAFT_DEVICE_INLINE_FUNCTION void ldg_ca(half2 (&x)[2], const half* addr) -{ - asm volatile("ld.global.ca.v2.u32 {%0, %1}, [%2];" - : "=r"(*reinterpret_cast(x)), "=r"(*reinterpret_cast(x + 1)) - : "l"(reinterpret_cast(addr))); -} - -} // namespace device -} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/device_memory_ops.hpp b/cpp/src/neighbors/detail/cagra/device_memory_ops.hpp new file mode 100644 index 0000000000..cc164994ea --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/device_memory_ops.hpp @@ -0,0 +1,154 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ +#pragma once + +#include + +#include + +namespace cuvs::neighbors::cagra::detail::device { + +RAFT_DEVICE_INLINE_FUNCTION void lds(float& x, uint32_t addr) +{ + asm volatile("ld.shared.f32 {%0}, [%1];" : "=f"(x) : "r"(addr)); +} +RAFT_DEVICE_INLINE_FUNCTION void lds(half& x, uint32_t addr) +{ + asm volatile("ld.shared.u16 {%0}, [%1];" : "=h"(reinterpret_cast(x)) : "r"(addr)); +} +RAFT_DEVICE_INLINE_FUNCTION void lds(half2& x, uint32_t addr) +{ + asm volatile("ld.shared.u32 {%0}, [%1];" : "=r"(reinterpret_cast(x)) : "r"(addr)); +} +RAFT_DEVICE_INLINE_FUNCTION void lds(half (&x)[1], uint32_t addr) +{ + asm volatile("ld.shared.u16 {%0}, [%1];" : "=h"(*reinterpret_cast(x)) : "r"(addr)); +} +RAFT_DEVICE_INLINE_FUNCTION void lds(half (&x)[2], uint32_t addr) +{ + asm volatile("ld.shared.v2.u16 {%0, %1}, [%2];" + : "=h"(*reinterpret_cast(x)), "=h"(*reinterpret_cast(x + 1)) + : "r"(addr)); +} +RAFT_DEVICE_INLINE_FUNCTION void lds(half (&x)[4], uint32_t addr) +{ + asm volatile("ld.shared.v4.u16 {%0, %1, %2, %3}, [%4];" + : "=h"(*reinterpret_cast(x)), + "=h"(*reinterpret_cast(x + 1)), + "=h"(*reinterpret_cast(x + 2)), + "=h"(*reinterpret_cast(x + 3)) + : "r"(addr)); +} + +RAFT_DEVICE_INLINE_FUNCTION void lds(uint8_t& x, uint32_t addr) +{ + uint32_t res; + asm volatile("ld.shared.u8 {%0}, [%1];" : "=r"(res) : "r"(addr)); + x = static_cast(res); +} + +RAFT_DEVICE_INLINE_FUNCTION void lds(uint32_t& x, uint32_t addr) +{ + asm volatile("ld.shared.u32 {%0}, [%1];" : "=r"(x) : "r"(addr)); +} + +RAFT_DEVICE_INLINE_FUNCTION void lds(uint32_t& x, const uint32_t* addr) +{ + lds(x, uint32_t(__cvta_generic_to_shared(addr))); +} + +RAFT_DEVICE_INLINE_FUNCTION void lds(uint4& x, uint32_t addr) +{ + asm volatile("ld.shared.v4.u32 {%0, %1, %2, %3}, [%4];" + : "=r"(x.x), "=r"(x.y), "=r"(x.z), "=r"(x.w) + : "r"(addr)); +} + +RAFT_DEVICE_INLINE_FUNCTION void lds(uint4& x, const uint4* addr) +{ + lds(x, uint32_t(__cvta_generic_to_shared(addr))); +} + +RAFT_DEVICE_INLINE_FUNCTION void sts(uint32_t addr, const half2& x) +{ + asm volatile("st.shared.v2.u16 [%0], {%1, %2};" + : + : "r"(addr), + "h"(reinterpret_cast(x.x)), + "h"(reinterpret_cast(x.y))); +} + 
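// [Illustrative sketch, not part of the patch] The lds()/sts() helpers in this
// header take a 32-bit shared-memory address rather than a generic pointer, so
// callers convert with __cvta_generic_to_shared() first (as the pointer-taking
// overloads above do). Minimal usage, assuming this header is included and
// blockDim.x == 32:
__global__ void lds_usage_example(float* out)
{
  __shared__ float buf[32];
  buf[threadIdx.x] = static_cast<float>(threadIdx.x);
  __syncthreads();

  const auto smem_addr =
    static_cast<uint32_t>(__cvta_generic_to_shared(buf + threadIdx.x));

  float v;
  cuvs::neighbors::cagra::detail::device::lds(v, smem_addr);  // emits ld.shared.f32
  out[threadIdx.x] = v;
}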
+RAFT_DEVICE_INLINE_FUNCTION void ldg_cg(uint4& x, const uint4* addr) +{ + asm volatile("ld.global.cg.v4.u32 {%0, %1, %2, %3}, [%4];" + : "=r"(x.x), "=r"(x.y), "=r"(x.z), "=r"(x.w) + : "l"(addr)); +} + +RAFT_DEVICE_INLINE_FUNCTION void ldg_ca(uint4& x, const uint4* addr) +{ + asm volatile("ld.global.ca.v4.u32 {%0, %1, %2, %3}, [%4];" + : "=r"(x.x), "=r"(x.y), "=r"(x.z), "=r"(x.w) + : "l"(addr)); +} + +RAFT_DEVICE_INLINE_FUNCTION void ldg_ca(uint32_t& x, const uint32_t* addr) +{ + asm volatile("ld.global.ca.u32 %0, [%1];" : "=r"(x) : "l"(addr)); +} + +RAFT_DEVICE_INLINE_FUNCTION void ldg_cg(uint32_t& x, const uint32_t* addr) +{ + asm volatile("ld.global.cg.u32 %0, [%1];" : "=r"(x) : "l"(addr)); +} + +RAFT_DEVICE_INLINE_FUNCTION void ldg_ca(half& x, const half* addr) +{ + asm volatile("ld.global.ca.u16 {%0}, [%1];" + : "=h"(reinterpret_cast(x)) + : "l"(reinterpret_cast(addr))); +} +RAFT_DEVICE_INLINE_FUNCTION void ldg_ca(half (&x)[1], const half* addr) +{ + asm volatile("ld.global.ca.u16 {%0}, [%1];" + : "=h"(*reinterpret_cast(x)) + : "l"(reinterpret_cast(addr))); +} +RAFT_DEVICE_INLINE_FUNCTION void ldg_ca(half (&x)[2], const half* addr) +{ + asm volatile("ld.global.ca.v2.u16 {%0, %1}, [%2];" + : "=h"(*reinterpret_cast(x)), "=h"(*reinterpret_cast(x + 1)) + : "l"(reinterpret_cast(addr))); +} +RAFT_DEVICE_INLINE_FUNCTION void ldg_ca(half (&x)[4], const half* addr) +{ + asm volatile("ld.global.ca.v4.u16 {%0, %1, %2, %3}, [%4];" + : "=h"(*reinterpret_cast(x)), + "=h"(*reinterpret_cast(x + 1)), + "=h"(*reinterpret_cast(x + 2)), + "=h"(*reinterpret_cast(x + 3)) + : "l"(reinterpret_cast(addr))); +} + +RAFT_DEVICE_INLINE_FUNCTION void ldg_ca(half2& x, const half* addr) +{ + asm volatile("ld.global.ca.u32 %0, [%1];" + : "=r"(reinterpret_cast(x)) + : "l"(reinterpret_cast(addr))); +} +RAFT_DEVICE_INLINE_FUNCTION void ldg_ca(half2 (&x)[1], const half* addr) +{ + asm volatile("ld.global.ca.u32 %0, [%1];" + : "=r"(*reinterpret_cast(x)) + : "l"(reinterpret_cast(addr))); +} +RAFT_DEVICE_INLINE_FUNCTION void ldg_ca(half2 (&x)[2], const half* addr) +{ + asm volatile("ld.global.ca.v2.u32 {%0, %1}, [%2];" + : "=r"(*reinterpret_cast(x)), "=r"(*reinterpret_cast(x + 1)) + : "l"(reinterpret_cast(addr))); +} + +} // namespace cuvs::neighbors::cagra::detail::device diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_filter_kernel.cu.in b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_filter_kernel.cu.in new file mode 100644 index 0000000000..49d5d2fa07 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_filter_kernel.cu.in @@ -0,0 +1,42 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include + +namespace { + +using index_t = @index_type@; +using distance_t = @distance_type@; +using source_index_t = @source_index_type@; +using cagra_bitset_t = cuvs::neighbors::cagra::detail::cagra_bitset; + +} // namespace + +namespace cuvs::neighbors::cagra::detail::multi_kernel_search { + +extern "C" __global__ void apply_filter_kernel(const source_index_t* const source_indices_ptr, + index_t* const result_indices_ptr, + distance_t* const result_distances_ptr, + const std::size_t lds, + const std::uint32_t result_buffer_size, + const std::uint32_t num_queries, + const std::uint32_t query_id_offset, + cagra_bitset_t bitset) +{ + apply_filter_kernel_jit(source_indices_ptr, + result_indices_ptr, + result_distances_ptr, + lds, + result_buffer_size, + num_queries, + query_id_offset, + bitset); +} + +static_assert(std::is_same_v>); + +} // namespace cuvs::neighbors::cagra::detail::multi_kernel_search diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_filter_matrix.json b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_filter_matrix.json new file mode 100644 index 0000000000..450eb5dacd --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_filter_matrix.json @@ -0,0 +1,20 @@ +{ + "_source_index": [ + { + "source_index_type": "uint32_t", + "source_index_abbrev": "u32" + } + ], + "_index": [ + { + "index_type": "uint32_t", + "index_abbrev": "u32" + } + ], + "_distance": [ + { + "distance_type": "float", + "distance_abbrev": "f" + } + ] +} diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_normalization_standard_impl.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_normalization_standard_impl.cuh new file mode 100644 index 0000000000..c2a182daf7 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_normalization_standard_impl.cuh @@ -0,0 +1,48 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "../compute_distance_standard-impl.cuh" +#include "extern_device_functions.cuh" + +namespace cuvs::neighbors::cagra::detail { + +template +__device__ DistanceT apply_normalization_standard_noop_impl( + DistanceT distance, + const typename dataset_descriptor_base_t::args_t args, + IndexT dataset_index) +{ + (void)args; + (void)dataset_index; + return distance; +} + +template +__device__ DistanceT apply_normalization_standard_cosine_impl( + DistanceT distance, + const typename dataset_descriptor_base_t::args_t args, + IndexT dataset_index) +{ + const auto* dataset_norms = + standard_dataset_descriptor_t:: + dataset_norms_ptr(args); + auto norm = dataset_norms[dataset_index]; + if (norm > 0) { distance = distance / norm; } + return distance; +} + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_normalization_standard_kernel.cu.in b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_normalization_standard_kernel.cu.in new file mode 100644 index 0000000000..13eb1c745b --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_normalization_standard_kernel.cu.in @@ -0,0 +1,42 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#include + +namespace { + +constexpr uint32_t k_team_size = @team_size@u; +constexpr uint32_t k_dataset_block_dim = @dataset_block_dim@u; + +using data_t = @data_type@; +using index_t = @index_type@; +using distance_t = @distance_type@; +using query_t = @query_type@; + +} // namespace + +namespace cuvs::neighbors::cagra::detail { + +using args_t = typename dataset_descriptor_base_t::args_t; + +template <> +__device__ distance_t apply_normalization_standard(distance_t distance, + const args_t args, + index_t dataset_index) +{ + return apply_normalization_standard_@norm_kind@_impl(distance, args, dataset_index); +} + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_normalization_standard_matrix.json b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_normalization_standard_matrix.json new file mode 100644 index 0000000000..d11988fe60 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_normalization_standard_matrix.json @@ -0,0 +1,66 @@ +{ + "_data": [ + { + "data_type": "float", + "data_abbrev": "f" + }, + { + "data_type": "__half", + "data_abbrev": "h" + }, + { + "data_type": "uint8_t", + "data_abbrev": "u8" + }, + { + "data_type": "int8_t", + "data_abbrev": "i8" + } + ], + "_index": [ + { + "index_type": "uint32_t", + "index_abbrev": "u32" + } + ], + "_distance": [ + { + "distance_type": "float", + "distance_abbrev": "f" + } + ], + "team_size": [ + "8", + "16", + "32" + ], + "dataset_block_dim": [ + "128", + "256", + "512" + ], + "_normalization": [ + { + "norm_kind": "noop", + "_query": [ + { + "query_type": "float", + "query_abbrev": "f" + }, + { + "query_type": "uint8_t", + "query_abbrev": "u8" + } + ] + }, + { + "norm_kind": "cosine", + "_query": [ + { + "query_type": "float", + "query_abbrev": "f" + } + ] + } + ] +} diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_bitset.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_bitset.cuh new file mode 100644 index 0000000000..ea01c2ce78 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_bitset.cuh @@ -0,0 +1,53 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: Apache-2.0 + */ +#pragma once + +#include "../../../sample_filter.cuh" // bitset_filter, none_sample_filter +#include "../../sample_filter_data.cuh" + +#include +#include + +namespace cuvs::neighbors::cagra::detail { + +template +using cagra_bitset = cuvs::neighbors::detail::bitset_filter_data_t; + +/// Host: bitset payload for kernels plus query offset for wrapped CAGRA filters. +template +struct cagra_sample_filter { + cagra_bitset bitset{}; + std::uint32_t query_id_offset{0}; +}; + +template +struct is_bitset_filter : std::false_type {}; + +template +struct is_bitset_filter> + : std::true_type {}; + +/// Host: fill @ref cagra_sample_filter from a CAGRA filter object (used by JIT LTO launchers). 
+template +cagra_sample_filter extract_cagra_sample_filter(const SampleFilterT& sample_filter) +{ + cagra_sample_filter out; + if constexpr (requires { + sample_filter.filter; + sample_filter.offset; + }) { + out.query_id_offset = sample_filter.offset; + using InnerFilter = decltype(sample_filter.filter); + if constexpr (is_bitset_filter::value) { + const auto bitset_view = sample_filter.filter.view(); + out.bitset.bitset_ptr = const_cast(bitset_view.data()); + out.bitset.bitset_len = static_cast(bitset_view.size()); + out.bitset.original_nbits = static_cast(bitset_view.get_original_nbits()); + } + } + return out; +} + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_jit_launcher_factory.hpp b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_jit_launcher_factory.hpp new file mode 100644 index 0000000000..717b6c7dfd --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_jit_launcher_factory.hpp @@ -0,0 +1,360 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "../compute_distance.hpp" +#include "../shared_launcher_jit.hpp" +#include "search_multi_cta_planner.hpp" +#include "search_multi_kernel_planner.hpp" +#include "search_single_cta_planner.hpp" + +#include +#include + +#include + +namespace cuvs::neighbors::cagra::detail { + +namespace cagra_jit_launcher_factory_detail { + +template +std::shared_ptr build_single_cta_launcher( + const dataset_descriptor_host& dataset_desc, + bool topk_by_bitonic_sort, + bool bitonic_sort_and_merge_multi_warps, + bool persistent) +{ + single_cta_search::CagraSingleCtaSearchPlanner + planner(dataset_desc.metric, + topk_by_bitonic_sort, + bitonic_sort_and_merge_multi_warps, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len, + persistent); + + planner.add_setup_workspace_device_function(dataset_desc.metric, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + planner.add_compute_distance_device_function(dataset_desc.metric, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + planner.add_search_kernel_fragment( + topk_by_bitonic_sort, bitonic_sort_and_merge_multi_warps, persistent); + planner.add_sample_filter_device_function(); + return planner.get_launcher(); +} + +template +std::shared_ptr build_multi_cta_launcher( + const dataset_descriptor_host& dataset_desc) +{ + multi_cta_search::CagraMultiCtaSearchPlanner + planner(dataset_desc.metric, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + + planner.add_setup_workspace_device_function(dataset_desc.metric, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + planner.add_compute_distance_device_function(dataset_desc.metric, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + planner.add_search_multi_cta_kernel_fragment(); + planner.add_sample_filter_device_function(); + return planner.get_launcher(); +} + +template +std::shared_ptr build_multi_kernel_launcher( + const dataset_descriptor_host& dataset_desc, + const char* 
linked_kernel_name) +{ + multi_kernel_search::CagraMultiKernelSearchPlanner + planner(dataset_desc.metric, + linked_kernel_name, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + planner.add_setup_workspace_device_function(dataset_desc.metric, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + planner.add_compute_distance_device_function(dataset_desc.metric, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + planner.add_sample_filter_device_function(); + planner.add_linked_kernel(linked_kernel_name); + return planner.get_launcher(); +} + +} // namespace cagra_jit_launcher_factory_detail + +/// Build a JIT AlgorithmLauncher for single-CTA CAGRA search (runtime VPQ / metric → tag +/// dispatch). `SampleFilterJitTag` is `cuvs::neighbors::detail::tag_filter_none`, +/// `tag_filter_bitset`, or use `sample_filter_jit_tag_t`. +template +std::shared_ptr make_cagra_single_cta_jit_launcher( + const dataset_descriptor_host& dataset_desc, + bool topk_by_bitonic_sort, + bool bitonic_sort_and_merge_multi_warps, + bool persistent) +{ + using DataTag = decltype(get_data_type_tag()); + using IndexTag = decltype(get_index_type_tag()); + using DistTag = decltype(get_distance_type_tag()); + using SourceTag = decltype(get_source_index_type_tag()); + + if (dataset_desc.is_vpq) { + using QueryTag = query_type_tag_vpq_t; + using CodebookTag = codebook_tag_vpq_t; + return cagra_jit_launcher_factory_detail::build_single_cta_launcher( + dataset_desc, topk_by_bitonic_sort, bitonic_sort_and_merge_multi_warps, persistent); + } + using CodebookTag = codebook_tag_standard_t; + if (dataset_desc.metric == cuvs::distance::DistanceType::BitwiseHamming) { + using QueryTag = + query_type_tag_standard_t; + return cagra_jit_launcher_factory_detail::build_single_cta_launcher( + dataset_desc, topk_by_bitonic_sort, bitonic_sort_and_merge_multi_warps, persistent); + } + using QueryTag = query_type_tag_standard_t; + return cagra_jit_launcher_factory_detail::build_single_cta_launcher( + dataset_desc, topk_by_bitonic_sort, bitonic_sort_and_merge_multi_warps, persistent); +} + +/// Build a JIT AlgorithmLauncher for multi-CTA CAGRA search. +template +std::shared_ptr make_cagra_multi_cta_jit_launcher( + const dataset_descriptor_host& dataset_desc) +{ + using DataTag = decltype(get_data_type_tag()); + using IndexTag = decltype(get_index_type_tag()); + using DistTag = decltype(get_distance_type_tag()); + using SourceTag = decltype(get_source_index_type_tag()); + + if (dataset_desc.is_vpq) { + using QueryTag = query_type_tag_vpq_t; + using CodebookTag = codebook_tag_vpq_t; + return cagra_jit_launcher_factory_detail::build_multi_cta_launcher(dataset_desc); + } + using CodebookTag = codebook_tag_standard_t; + if (dataset_desc.metric == cuvs::distance::DistanceType::BitwiseHamming) { + using QueryTag = + query_type_tag_standard_t; + return cagra_jit_launcher_factory_detail::build_multi_cta_launcher(dataset_desc); + } + using QueryTag = query_type_tag_standard_t; + return cagra_jit_launcher_factory_detail::build_multi_cta_launcher(dataset_desc); +} + +/// Build a JIT AlgorithmLauncher for multi-kernel CAGRA helpers (random_pickup, compute_distance, +/// …). 
Use `SampleFilterJitTag = tag_cagra_jit_sample_filter_link_absent` (default) when the kernel +/// does not link `sample_filter`; otherwise `sample_filter_jit_tag_t` or a +/// `tag_filter_*` from `common_fragments.hpp`. +template +std::shared_ptr make_cagra_multi_kernel_jit_launcher( + const dataset_descriptor_host& dataset_desc, + const char* linked_kernel_name) +{ + using DataTag = decltype(get_data_type_tag()); + using IndexTag = decltype(get_index_type_tag()); + using DistTag = decltype(get_distance_type_tag()); + using SourceTag = decltype(get_source_index_type_tag()); + + if (dataset_desc.is_vpq) { + using QueryTag = query_type_tag_vpq_t; + using CodebookTag = codebook_tag_vpq_t; + return cagra_jit_launcher_factory_detail::build_multi_kernel_launcher( + dataset_desc, linked_kernel_name); + } + using CodebookTag = codebook_tag_standard_t; + if (dataset_desc.metric == cuvs::distance::DistanceType::BitwiseHamming) { + using QueryTag = + query_type_tag_standard_t; + return cagra_jit_launcher_factory_detail::build_multi_kernel_launcher( + dataset_desc, linked_kernel_name); + } + using QueryTag = query_type_tag_standard_t; + return cagra_jit_launcher_factory_detail::build_multi_kernel_launcher( + dataset_desc, linked_kernel_name); +} + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_planner_base.hpp b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_planner_base.hpp new file mode 100644 index 0000000000..98a4030f80 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_planner_base.hpp @@ -0,0 +1,249 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace cuvs::neighbors::cagra::detail { + +template +struct CagraPlannerBase : AlgorithmPlanner { + using DataTag = DataTag_; + using IndexTag = IndexTag_; + using DistanceTag = DistanceTag_; + using QueryTag = QueryTag_; + using CodebookTag = CodebookTag_; + using SampleFilterJitTag = SampleFilterJitTag_; + + explicit CagraPlannerBase(std::string entrypoint, LauncherJitCache& jit_cache) + : AlgorithmPlanner(std::move(entrypoint), jit_cache) + { + linktime_extra_options.push_back("-maxrregcount=64"); + } + + void add_setup_workspace_device_function(cuvs::distance::DistanceType metric, + uint32_t team_size, + uint32_t dataset_block_dim, + bool is_vpq, + uint32_t pq_bits, + uint32_t pq_len) + { + (void)metric; + (void)is_vpq; + (void)pq_bits; + auto add = [&]() { + this->add_static_fragment>(); + }; + if constexpr (std::is_same_v) { + if (pq_bits != 0 || pq_len != 0) { + RAFT_FAIL("CAGRA JIT standard path expects pq_bits==0 and pq_len==0"); + } + dispatch_cagra_team_dim( + team_size, dataset_block_dim, [&add]() { + add.template operator()(); + }); + } else { + if (pq_bits != 8 || (pq_len != 2 && pq_len != 4)) { + RAFT_FAIL("CAGRA JIT VPQ path expects pq_bits==8 and pq_len in {2,4}"); + } + dispatch_cagra_team_dim( + team_size, dataset_block_dim, [&add, pq_len]() { + if (pq_len == 2) { + add.template operator()(); + } else { + add.template operator()(); + } + }); + } + } + + void add_compute_distance_device_function(cuvs::distance::DistanceType metric, + uint32_t team_size, + uint32_t dataset_block_dim, + bool is_vpq, + uint32_t pq_bits, + uint32_t pq_len) + { + (void)is_vpq; + // Dist/normalization apply only to standard codebook; constexpr avoids instantiating them + // with VPQ's 
QueryTag=tag_h (runtime !is_vpq would still instantiate those templates). + if constexpr (std::is_same_v) { + add_dist_op_device_function(metric); + add_normalization_device_function(metric, team_size, dataset_block_dim); + } + auto add = [&]() { + this->add_static_fragment>(); + }; + if constexpr (std::is_same_v) { + if (pq_bits != 0 || pq_len != 0) { + RAFT_FAIL("CAGRA JIT standard path expects pq_bits==0 and pq_len==0"); + } + dispatch_cagra_team_dim( + team_size, dataset_block_dim, [&add]() { + add.template operator()(); + }); + } else { + if (pq_bits != 8 || (pq_len != 2 && pq_len != 4)) { + RAFT_FAIL("CAGRA JIT VPQ path expects pq_bits==8 and pq_len in {2,4}"); + } + dispatch_cagra_team_dim( + team_size, dataset_block_dim, [&add, pq_len]() { + if (pq_len == 2) { + add.template operator()(); + } else { + add.template operator()(); + } + }); + } + } + + private: + void add_dist_op_device_function(cuvs::distance::DistanceType metric) + { + // dist_op_matrix.json pairs tag_metric_hamming with uint8 query (tag_u8) only; L2/IP/L1 use + // float query (tag_f). A single switch over metric would still instantiate every case for each + // QueryTag, pulling in fragment types that have no fatbin (e.g. tag_u8 + L2). + if constexpr (std::is_same_v) { + if (metric != cuvs::distance::DistanceType::BitwiseHamming) { + RAFT_FAIL( + "CAGRA JIT uint8 query layout (tag_u8) only supports BitwiseHamming for dist_op " + "fragments"); + } + this->add_static_fragment>(); + } else { + switch (metric) { + case cuvs::distance::DistanceType::L2Expanded: + case cuvs::distance::DistanceType::L2Unexpanded: + this->add_static_fragment>(); + break; + case cuvs::distance::DistanceType::InnerProduct: + this->add_static_fragment< + fragment_tag_dist_op>(); + break; + case cuvs::distance::DistanceType::CosineExpanded: + this->add_static_fragment< + fragment_tag_dist_op>(); + break; + case cuvs::distance::DistanceType::BitwiseHamming: + // Matrix only emits hamming dist_op for tag_u8; float-query layout is not built. + RAFT_FAIL( + "CAGRA JIT BitwiseHamming dist_op is only registered for uint8_t data / tag_u8 query " + "layout"); + break; + case cuvs::distance::DistanceType::L1: + this->add_static_fragment>(); + break; + default: RAFT_FAIL("Unsupported metric for CAGRA JIT dist_op"); + } + } + } + + void add_normalization_device_function(cuvs::distance::DistanceType metric, + uint32_t team_size, + uint32_t dataset_block_dim) + { + auto go = [&]() { + dispatch_cagra_team_dim(team_size, dataset_block_dim, [&]() { + this->add_static_fragment>(); + }); + }; + // tag_u8 is only used for BitwiseHamming query layout; cosine norm fragments are built for + // float query tag. Use if constexpr so we do not instantiate tag_norm_cosine with tag_u8 + // (a runtime metric check would still pull in those template specializations). + if constexpr (std::is_same_v) { + go.template operator()(); + } else if (metric == cuvs::distance::DistanceType::CosineExpanded) { + go.template operator()(); + } else { + go.template operator()(); + } + } + + public: + // Maps runtime dataset layout (same grid as the JIT matrix) to uint32_t team / block-dim + // template parameters; CAGRA reads team_size / dataset_block_dim from the host descriptor at + // planning time. 
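Before the dispatcher itself, a minimal usage sketch (toy lambda, not from the patch): the runtime (team_size, dataset_block_dim) pair read from the host descriptor selects one compile-time pair, which a C++20 templated lambda receives as non-type template parameters, the same shape the add_*_device_function call sites above rely on.

// Sketch of a call as it would appear inside one of the add_* members above;
// the real lambdas register fragment tags instead of asserting.
uint32_t team_size = 16, dataset_block_dim = 256;  // e.g. taken from the host dataset descriptor
dispatch_cagra_team_dim(team_size, dataset_block_dim,
  []<uint32_t TeamSize, uint32_t DatasetBlockDim>() {
    static_assert(TeamSize <= 32 && DatasetBlockDim >= 128);  // both are constexpr here
  });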
+ template + static void dispatch_cagra_team_dim(uint32_t team_size, uint32_t dataset_block_dim, Lambda&& l) + { + switch (team_size) { + case 8: + switch (dataset_block_dim) { + case 128: std::forward(l).template operator()<8u, 128u>(); return; + case 256: std::forward(l).template operator()<8u, 256u>(); return; + case 512: std::forward(l).template operator()<8u, 512u>(); return; + default: break; + } + break; + case 16: + switch (dataset_block_dim) { + case 128: std::forward(l).template operator()<16u, 128u>(); return; + case 256: std::forward(l).template operator()<16u, 256u>(); return; + case 512: std::forward(l).template operator()<16u, 512u>(); return; + default: break; + } + break; + case 32: + switch (dataset_block_dim) { + case 128: std::forward(l).template operator()<32u, 128u>(); return; + case 256: std::forward(l).template operator()<32u, 256u>(); return; + case 512: std::forward(l).template operator()<32u, 512u>(); return; + default: break; + } + break; + default: break; + } + RAFT_FAIL("Unsupported team_size / dataset_block_dim for CAGRA JIT: team=%u dim=%u", + static_cast(team_size), + static_cast(dataset_block_dim)); + } + + void add_sample_filter_device_function() + { + if constexpr (!std::is_same_v) { + this->add_static_fragment>(); + } + } +}; + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_impl.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_impl.cuh new file mode 100644 index 0000000000..92a014bd2f --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_impl.cuh @@ -0,0 +1,251 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "extern_device_functions.cuh" + +#include "../../neighbors_device_intrinsics.cuh" +#include "../compute_distance_standard-impl.cuh" +#include "../compute_distance_standard.hpp" +#include "../compute_distance_vpq-impl.cuh" +#include "../compute_distance_vpq.hpp" +#include "../device_memory_ops.hpp" + +#include +#include + +namespace cuvs::neighbors::cagra::detail { + +template +RAFT_DEVICE_INLINE_FUNCTION auto compute_distance_standard_worker_impl( + const typename DescriptorT::DATA_T* __restrict__ dataset_ptr, + uint32_t dim, + uint32_t query_smem_ptr) -> typename DescriptorT::DISTANCE_T +{ + using DATA_T = typename DescriptorT::DATA_T; + using DISTANCE_T = typename DescriptorT::DISTANCE_T; + using LOAD_T = typename DescriptorT::LOAD_T; + using QUERY_T = typename DescriptorT::QUERY_T; + constexpr auto kTeamSize = DescriptorT::kTeamSize; + constexpr auto kDatasetBlockDim = DescriptorT::kDatasetBlockDim; + constexpr auto vlen = device::get_vlen(); + constexpr auto reg_nelem = + raft::div_rounding_up_unsafe(kDatasetBlockDim, kTeamSize * vlen); + + DISTANCE_T r = 0; + for (uint32_t elem_offset = (threadIdx.x % kTeamSize) * vlen; elem_offset < dim; + elem_offset += kDatasetBlockDim) { + DATA_T data[reg_nelem][vlen]; +#pragma unroll + for (uint32_t e = 0; e < reg_nelem; e++) { + const uint32_t k = e * (kTeamSize * vlen) + elem_offset; + if (k >= dim) break; + device::ldg_cg(reinterpret_cast(data[e]), + reinterpret_cast(dataset_ptr + k)); + } +#pragma unroll + for (uint32_t e = 0; e < reg_nelem; e++) { + const uint32_t k = e * (kTeamSize * vlen) + elem_offset; + if (k >= dim) break; +#pragma unroll + for (uint32_t v = 0; v < vlen; v++) { + QUERY_T d; + device::lds( + d, + query_smem_ptr + + sizeof(QUERY_T) * device::swizzling(k 
+ v)); + r += dist_op( + d, cuvs::spatial::knn::detail::utils::mapping{}(data[e][v])); + } + } + } + return r; +} + +template +_RAFT_DEVICE __noinline__ auto compute_distance_standard_impl( + const typename DescriptorT::args_t args, const typename DescriptorT::INDEX_T dataset_index) -> + typename DescriptorT::DISTANCE_T +{ + auto distance = compute_distance_standard_worker_impl( + DescriptorT::ptr(args) + (static_cast(DescriptorT::ld(args)) * dataset_index), + args.dim, + args.smem_ws_ptr); + + distance = + apply_normalization_standard(distance, args, dataset_index); + + return distance; +} + +template +_RAFT_DEVICE RAFT_DEVICE_INLINE_FUNCTION auto compute_distance_vpq_worker_impl( + const uint8_t* __restrict__ dataset_ptr, + const typename DescriptorT::CODE_BOOK_T* __restrict__ vq_code_book_ptr, + uint32_t dim, + uint32_t pq_codebook_ptr) -> typename DescriptorT::DISTANCE_T +{ + using DISTANCE_T = typename DescriptorT::DISTANCE_T; + using QUERY_T = typename DescriptorT::QUERY_T; + using CODE_BOOK_T = typename DescriptorT::CODE_BOOK_T; + constexpr auto TeamSize = DescriptorT::kTeamSize; + constexpr auto DatasetBlockDim = DescriptorT::kDatasetBlockDim; + constexpr auto PQ_BITS = DescriptorT::kPqBits; + constexpr auto PQ_LEN = DescriptorT::kPqLen; + + const uint32_t query_ptr = pq_codebook_ptr + DescriptorT::kSMemCodeBookSizeInBytes; + static_assert(PQ_BITS == 8, "Only pq_bits == 8 is supported at the moment."); + constexpr uint32_t vlen = 4; // **** DO NOT CHANGE **** + constexpr uint32_t nelem = + raft::div_rounding_up_unsafe(DatasetBlockDim / PQ_LEN, TeamSize * vlen); + + constexpr auto kTeamMask = DescriptorT::kTeamSize - 1; + constexpr auto kTeamVLen = TeamSize * vlen; + + const auto n_subspace = raft::div_rounding_up_unsafe(dim, PQ_LEN); + const auto laneId = threadIdx.x & kTeamMask; + DISTANCE_T norm = 0; + for (uint32_t elem_offset = 0; elem_offset * PQ_LEN < dim; + elem_offset += DatasetBlockDim / PQ_LEN) { + uint32_t pq_codes[nelem]; +#pragma unroll + for (std::uint32_t e = 0; e < nelem; e++) { + const std::uint32_t k = e * kTeamVLen + elem_offset + laneId * vlen; + if (k >= n_subspace) break; + device::ldg_cg(pq_codes[e], reinterpret_cast(dataset_ptr + 4 + k)); + } + // + if constexpr (PQ_LEN % 2 == 0) { +#pragma unroll + for (std::uint32_t e = 0; e < nelem; e++) { + const std::uint32_t k = e * kTeamVLen + elem_offset + laneId * vlen; + if (k >= n_subspace) break; + half2 vq_vals[PQ_LEN][vlen / 2]; +#pragma unroll + for (std::uint32_t m = 0; m < PQ_LEN; m++) { + const uint32_t d = (vlen * m) + (PQ_LEN * k); + if (d >= dim) break; + device::ldg_ca(vq_vals[m], vq_code_book_ptr + d); + } + std::uint32_t pq_code = pq_codes[e]; +#pragma unroll + for (std::uint32_t v = 0; v < vlen; v++) { + if (PQ_LEN * (v + k) >= dim) break; +#pragma unroll + for (std::uint32_t m = 0; m < PQ_LEN / 2; m++) { + constexpr auto kQueryBlock = DatasetBlockDim / (vlen * PQ_LEN); + const std::uint32_t d1 = m + (PQ_LEN / 2) * v; + const std::uint32_t d = + d1 * kQueryBlock + elem_offset * (PQ_LEN / 2) + e * TeamSize + laneId; + half2 q2, c2; + device::lds(q2, query_ptr + sizeof(half2) * d); + device::lds(c2, + pq_codebook_ptr + + sizeof(CODE_BOOK_T) * ((1 << PQ_BITS) * 2 * m + (2 * (pq_code & 0xff)))); + auto dist = q2 - c2 - reinterpret_cast(vq_vals)[d1]; + dist = dist * dist; + norm += static_cast(dist.x + dist.y); + } + pq_code >>= 8; + } + } + } else { +#pragma unroll + for (std::uint32_t e = 0; e < nelem; e++) { + const std::uint32_t k = e * kTeamVLen + elem_offset + laneId * vlen; + if (k >= n_subspace) 
break; + CODE_BOOK_T vq_vals[PQ_LEN][vlen]; +#pragma unroll + for (std::uint32_t m = 0; m < PQ_LEN; m++) { + const std::uint32_t d = (vlen * m) + (PQ_LEN * k); + if (d >= dim) break; + device::ldg_ca(vq_vals[m], vq_code_book_ptr + d); + } + std::uint32_t pq_code = pq_codes[e]; +#pragma unroll + for (std::uint32_t v = 0; v < vlen; v++) { + if (PQ_LEN * (v + k) >= dim) break; + CODE_BOOK_T pq_vals[PQ_LEN]; + device::lds(pq_vals, pq_codebook_ptr + sizeof(CODE_BOOK_T) * PQ_LEN * (pq_code & 0xff)); +#pragma unroll + for (std::uint32_t m = 0; m < PQ_LEN; m++) { + const std::uint32_t d1 = m + (PQ_LEN * v); + const std::uint32_t d = d1 + (PQ_LEN * k); + DISTANCE_T diff; + device::lds(diff, query_ptr + sizeof(QUERY_T) * d); + diff -= static_cast(pq_vals[m]); + diff -= + static_cast(reinterpret_cast(vq_vals)[d1]); + norm += diff * diff; + } + pq_code >>= 8; + } + } + } + } + return norm; +} + +template +_RAFT_DEVICE __noinline__ auto compute_distance_vpq_impl( + const typename DescriptorT::args_t args, const typename DescriptorT::INDEX_T dataset_index) -> + typename DescriptorT::DISTANCE_T +{ + const auto* dataset_ptr = + DescriptorT::encoded_dataset_ptr(args) + + (static_cast(DescriptorT::encoded_dataset_dim(args)) * dataset_index); + uint32_t vq_code; + device::ldg_cg(vq_code, reinterpret_cast(dataset_ptr)); + return compute_distance_vpq_worker_impl( + dataset_ptr, + DescriptorT::vq_code_book_ptr(args) + args.dim * vq_code, + args.dim, + args.smem_ws_ptr); +} + +template +__device__ DistanceT compute_distance_impl( + const typename dataset_descriptor_base_t::args_t args, + IndexT dataset_index) +{ + if constexpr (PQ_BITS == 0 && PQ_LEN == 0 && std::is_same_v) { + using desc_t = + standard_dataset_descriptor_t; + return compute_distance_standard_impl(args, dataset_index); + } else if constexpr (PQ_BITS > 0 && PQ_LEN > 0 && std::is_same_v && + std::is_same_v) { + using desc_t = cagra_q_dataset_descriptor_t; + return compute_distance_vpq_impl(args, dataset_index); + } else { + static_assert(sizeof(TeamSize) == 0, + "compute_distance_impl: unsupported PQ_BITS/PQ_LEN/CodebookT/QueryT for CAGRA " + "JIT descriptor"); + return DistanceT{}; + } +} + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_kernel.cu.in b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_kernel.cu.in new file mode 100644 index 0000000000..13cd022918 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_kernel.cu.in @@ -0,0 +1,61 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include + +namespace { + +constexpr uint32_t k_team_size = @team_size@u; +constexpr uint32_t k_dataset_block_dim = @dataset_block_dim@u; +constexpr uint32_t k_pq_bits = @pq_bits@u; +constexpr uint32_t k_pq_len = @pq_len@u; + +using data_t = @data_type@; +using index_t = @index_type@; +using distance_t = @distance_type@; +using query_t = @query_type@; +using codebook_t = @codebook_type@; + +} // namespace + +namespace cuvs::neighbors::cagra::detail { + +using args_t = typename dataset_descriptor_base_t::args_t; + +template <> +__device__ distance_t compute_distance(const args_t args, + index_t dataset_index, + bool valid, + uint32_t team_size_bits) +{ + auto per_thread = valid ? 
compute_distance_impl(args, dataset_index) + : distance_t{}; + return device::team_sum(per_thread, team_size_bits); +} + +template <> +__device__ distance_t +compute_distance_per_thread(const args_t args, index_t dataset_index) +{ + return compute_distance_impl(args, dataset_index); +} + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_matrix.json b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_matrix.json new file mode 100644 index 0000000000..82b8dbdf4e --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_matrix.json @@ -0,0 +1,154 @@ +[ + { + "_data": [ + { + "data_type": "float", + "data_abbrev": "f", + "_query": [ + { + "query_type": "float", + "query_abbrev": "f" + } + ] + }, + { + "data_type": "__half", + "data_abbrev": "h", + "_query": [ + { + "query_type": "float", + "query_abbrev": "f" + } + ] + }, + { + "data_type": "uint8_t", + "data_abbrev": "u8", + "_query": [ + { + "query_type": "float", + "query_abbrev": "f" + }, + { + "query_type": "uint8_t", + "query_abbrev": "u8" + } + ] + }, + { + "data_type": "int8_t", + "data_abbrev": "i8", + "_query": [ + { + "query_type": "float", + "query_abbrev": "f" + } + ] + } + ], + "_index": [ + { + "index_type": "uint32_t", + "index_abbrev": "u32" + } + ], + "_distance": [ + { + "distance_type": "float", + "distance_abbrev": "f" + } + ], + "team_size": [ + "8", + "16", + "32" + ], + "dataset_block_dim": [ + "128", + "256", + "512" + ], + "_pq": [ + { + "pq_len": "0", + "pq_bits": "0", + "pq_prefix": "_standard", + "pq_suffix": "" + } + ], + "_codebook": [ + { + "codebook_type": "void", + "codebook_abbrev": "none" + } + ] + }, + { + "_data": [ + { + "data_type": "float", + "data_abbrev": "f" + }, + { + "data_type": "__half", + "data_abbrev": "h" + }, + { + "data_type": "uint8_t", + "data_abbrev": "u8" + }, + { + "data_type": "int8_t", + "data_abbrev": "i8" + } + ], + "_query": [ + { + "query_type": "half", + "query_abbrev": "h" + } + ], + "_index": [ + { + "index_type": "uint32_t", + "index_abbrev": "u32" + } + ], + "_distance": [ + { + "distance_type": "float", + "distance_abbrev": "f" + } + ], + "team_size": [ + "8", + "16", + "32" + ], + "dataset_block_dim": [ + "128", + "256", + "512" + ], + "_pq": [ + { + "pq_len": "2", + "pq_bits": "8", + "pq_prefix": "_vpq", + "pq_suffix": "_8pq_2subd" + }, + { + "pq_len": "4", + "pq_bits": "8", + "pq_prefix": "_vpq", + "pq_suffix": "_8pq_4subd" + } + ], + "_codebook": [ + { + "codebook_type": "half", + "codebook_abbrev": "half" + } + ] + } +] diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_to_child_nodes_kernel.cu.in b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_to_child_nodes_kernel.cu.in new file mode 100644 index 0000000000..0035ae1da0 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_to_child_nodes_kernel.cu.in @@ -0,0 +1,64 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include + +namespace { + +using data_t = @data_type@; +using index_t = @index_type@; +using distance_t = @distance_type@; +using source_index_t = @source_index_type@; +using dataset_desc_base = + cuvs::neighbors::cagra::detail::dataset_descriptor_base_t; +using cagra_bitset_t = cuvs::neighbors::cagra::detail::cagra_bitset; + +} // namespace + +namespace cuvs::neighbors::cagra::detail::multi_kernel_search { + +extern "C" __global__ void compute_distance_to_child_nodes(const index_t* const parent_node_list, + index_t* const parent_candidates_ptr, + distance_t* const parent_distance_ptr, + const std::size_t lds, + const std::uint32_t search_width, + const dataset_desc_base* dataset_desc, + const index_t* const neighbor_graph_ptr, + const std::uint32_t graph_degree, + const source_index_t* source_indices_ptr, + const data_t* query_ptr, + index_t* const visited_hashmap_ptr, + const std::uint32_t hash_bitlen, + index_t* const result_indices_ptr, + distance_t* const result_distances_ptr, + const std::uint32_t ldd, + cagra_bitset_t bitset) +{ + compute_distance_to_child_nodes_kernel_jit( + parent_node_list, + parent_candidates_ptr, + parent_distance_ptr, + lds, + search_width, + dataset_desc, + neighbor_graph_ptr, + graph_degree, + source_indices_ptr, + query_ptr, + visited_hashmap_ptr, + hash_bitlen, + result_indices_ptr, + result_distances_ptr, + ldd, + bitset); +} + +static_assert( + std::is_same_v< + decltype(compute_distance_to_child_nodes), + compute_distance_to_child_nodes_kernel_func_t>); + +} // namespace cuvs::neighbors::cagra::detail::multi_kernel_search diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_to_child_nodes_matrix.json b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_to_child_nodes_matrix.json new file mode 100644 index 0000000000..adfdf1e78b --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_to_child_nodes_matrix.json @@ -0,0 +1,13 @@ +[ + { + "_data": [ + {"data_type": "float", "data_abbrev": "f"}, + {"data_type": "__half", "data_abbrev": "h"}, + {"data_type": "uint8_t", "data_abbrev": "u8"}, + {"data_type": "int8_t", "data_abbrev": "i8"} + ], + "_source_index": [{"source_index_type": "uint32_t", "source_index_abbrev": "u32"}], + "_index": [{"index_type": "uint32_t", "index_abbrev": "u32"}], + "_distance": [{"distance_type": "float", "distance_abbrev": "f"}] + } +] diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/device_common_jit.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/device_common_jit.cuh new file mode 100644 index 0000000000..cf62c9ba4a --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/device_common_jit.cuh @@ -0,0 +1,181 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "../../neighbors_device_intrinsics.cuh" +#include "../hashmap.hpp" +#include "../utils.hpp" +#include "extern_device_functions.cuh" + +#include +#include +#include +#include + +namespace cuvs::neighbors::cagra::detail::device { + +// Helper to check if DescriptorT has kPqBits (VPQ descriptor) +template +struct has_kpq_bits { + template + static auto test(int) -> decltype(U::kPqBits, std::true_type{}); + template + static std::false_type test(...); + static constexpr bool value = decltype(test(0))::value; +}; + +template +inline constexpr bool has_kpq_bits_v = has_kpq_bits::value; + +// JIT version of compute_distance_to_random_nodes - uses const dataset_descriptor_base_t* (smem) +// Shared between single_cta and multi_cta JIT kernels +template +RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_random_nodes_jit( + IndexT* __restrict__ result_indices_ptr, // [num_pickup] + DistanceT* __restrict__ result_distances_ptr, // [num_pickup] + const dataset_descriptor_base_t* smem_desc, + const uint32_t num_pickup, + const uint32_t num_distilation, + const uint64_t rand_xor_mask, + const IndexT* __restrict__ seed_ptr, // [num_seeds] + const uint32_t num_seeds, + IndexT* __restrict__ visited_hash_ptr, + const uint32_t visited_hash_bitlen, + IndexT* __restrict__ traversed_hash_ptr, + const uint32_t traversed_hash_bitlen, + const uint32_t block_id = 0, + const uint32_t num_blocks = 1, + const IndexT graph_size = 0) +{ + uint32_t team_size_bits = smem_desc->team_size_bitshift_from_smem(); + IndexT dataset_size = smem_desc->size; + const auto args_load = smem_desc->args.load(); + + const auto max_i = raft::round_up_safe(num_pickup, device::warp_size >> team_size_bits); + const IndexT seed_index_limit = graph_size > 0 ? graph_size : dataset_size; + + for (uint32_t i = threadIdx.x >> team_size_bits; i < max_i; i += (blockDim.x >> team_size_bits)) { + const bool valid_i = (i < num_pickup); + + IndexT best_index_team_local = raft::upper_bound(); + DistanceT best_norm2_team_local = raft::upper_bound(); + for (uint32_t j = 0; j < num_distilation; j++) { + IndexT seed_index = 0; + if (valid_i) { + uint32_t gid = block_id + (num_blocks * (i + (num_pickup * j))); + if (seed_ptr && (gid < num_seeds)) { + seed_index = seed_ptr[gid]; + } else { + seed_index = device::xorshift64(gid ^ rand_xor_mask) % seed_index_limit; + } + } + + const auto norm2 = cuvs::neighbors::cagra::detail::compute_distance( + args_load, seed_index, valid_i, team_size_bits); + + if (valid_i && (norm2 < best_norm2_team_local)) { + best_norm2_team_local = norm2; + best_index_team_local = seed_index; + } + } + + const unsigned lane_id = threadIdx.x & ((1u << team_size_bits) - 1u); + if (valid_i && lane_id == 0) { + if (best_index_team_local != raft::upper_bound()) { + if (hashmap::insert(visited_hash_ptr, visited_hash_bitlen, best_index_team_local) == 0) { + // Deactivate this entry as insertion into visited hash table has failed. + best_norm2_team_local = raft::upper_bound(); + best_index_team_local = raft::upper_bound(); + } else if ((traversed_hash_ptr != nullptr) && + hashmap::search( + traversed_hash_ptr, traversed_hash_bitlen, best_index_team_local)) { + // Deactivate this entry as it has been already used by others. 
+ best_norm2_team_local = raft::upper_bound(); + best_index_team_local = raft::upper_bound(); + } + } + result_distances_ptr[i] = best_norm2_team_local; + result_indices_ptr[i] = best_index_team_local; + } + } +} + +// JIT version of compute_distance_to_child_nodes - uses const dataset_descriptor_base_t* (smem) +// Shared between single_cta and multi_cta JIT kernels +template +RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_child_nodes_jit( + IndexT* __restrict__ result_child_indices_ptr, + DistanceT* __restrict__ result_child_distances_ptr, + const dataset_descriptor_base_t* smem_desc, + const IndexT* __restrict__ knn_graph, + const uint32_t knn_k, + IndexT* __restrict__ visited_hashmap_ptr, + const uint32_t visited_hash_bitlen, + IndexT* __restrict__ traversed_hashmap_ptr, + const uint32_t traversed_hash_bitlen, + const IndexT* __restrict__ parent_indices, + const IndexT* __restrict__ internal_topk_list, + const uint32_t search_width, + int* __restrict__ result_position = nullptr, + const int max_result_position = 0) +{ + constexpr IndexT index_msb_1_mask = utils::gen_index_msb_1_mask::value; + constexpr IndexT invalid_index = ~static_cast(0); + + // Read child indices of parents from knn graph and check if the distance computation is + // necessary. + for (uint32_t i = threadIdx.x; i < knn_k * search_width; i += blockDim.x) { + const IndexT smem_parent_id = parent_indices[i / knn_k]; + IndexT child_id = invalid_index; + if (smem_parent_id != invalid_index) { + const auto parent_id = internal_topk_list[smem_parent_id] & ~index_msb_1_mask; + child_id = knn_graph[(i % knn_k) + (static_cast(knn_k) * parent_id)]; + } + if (child_id != invalid_index) { + if (hashmap::insert(visited_hashmap_ptr, visited_hash_bitlen, child_id) == 0) { + child_id = invalid_index; + } else if ((traversed_hashmap_ptr != nullptr) && + hashmap::search( + traversed_hashmap_ptr, traversed_hash_bitlen, child_id)) { + child_id = invalid_index; + } + } + if (STATIC_RESULT_POSITION) { + result_child_indices_ptr[i] = child_id; + } else if (child_id != invalid_index) { + int j = atomicSub(result_position, 1) - 1; + result_child_indices_ptr[j] = child_id; + } + } + __syncthreads(); + + // Same inline distance pattern as search_single_cta_jit.cuh / device helpers + const auto team_size_bits = smem_desc->team_size_bitshift_from_smem(); + const auto num_k = knn_k * search_width; + const auto max_i = raft::round_up_safe(num_k, device::warp_size >> team_size_bits); + const auto args = smem_desc->args.load(); + const bool lead_lane = (threadIdx.x & ((1u << team_size_bits) - 1u)) == 0; + const uint32_t ofst = STATIC_RESULT_POSITION ? 0 : result_position[0]; + + for (uint32_t i = threadIdx.x >> team_size_bits; i < max_i; i += blockDim.x >> team_size_bits) { + const auto j = i + ofst; + const bool valid_i = STATIC_RESULT_POSITION ? (j < num_k) : (j < max_result_position); + const auto child_id = valid_i ? result_child_indices_ptr[j] : invalid_index; + + const auto per_thread = + (child_id != invalid_index) + ? cuvs::neighbors::cagra::detail::compute_distance_per_thread( + args, child_id) + : (lead_lane ? 
raft::upper_bound() : 0); + const DistanceT child_dist = device::team_sum(per_thread, team_size_bits); + __syncwarp(); + + // Store the distance + if (valid_i && lead_lane) { result_child_distances_ptr[j] = child_dist; } + } +} + +} // namespace cuvs::neighbors::cagra::detail::device diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/dist_op_impl.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/dist_op_impl.cuh new file mode 100644 index 0000000000..86ba2ce3b2 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/dist_op_impl.cuh @@ -0,0 +1,41 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "extern_device_functions.cuh" + +#include + +namespace cuvs::neighbors::cagra::detail { + +template +__device__ DISTANCE_T dist_op_l2_impl(QUERY_T a, QUERY_T b) +{ + DISTANCE_T diff = a - b; + return diff * diff; +} + +template +__device__ DISTANCE_T dist_op_inner_product_impl(QUERY_T a, QUERY_T b) +{ + return -static_cast(a) * static_cast(b); +} + +template +__device__ DISTANCE_T dist_op_hamming_impl(QUERY_T a, QUERY_T b) +{ + const auto v = (a ^ b) & 0xffu; + return __popc(v); +} + +template +__device__ DISTANCE_T dist_op_l1_impl(QUERY_T a, QUERY_T b) +{ + DISTANCE_T diff = a - b; + return raft::abs(diff); +} + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/dist_op_kernel.cu.in b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/dist_op_kernel.cu.in new file mode 100644 index 0000000000..17cd39b980 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/dist_op_kernel.cu.in @@ -0,0 +1,25 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include + +#include + +namespace { + +using query_t = @query_type@; +using distance_t = @distance_type@; + +} // namespace + +namespace cuvs::neighbors::cagra::detail { + +template <> +__device__ distance_t dist_op(query_t a, query_t b) +{ + return dist_op_@metric_tag@_impl(a, b); +} + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/dist_op_matrix.json b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/dist_op_matrix.json new file mode 100644 index 0000000000..63d35dd81b --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/dist_op_matrix.json @@ -0,0 +1,34 @@ +{ + "_metric": [ + { + "metric_tag": "l2", + "jit_metric_tag": "tag_metric_l2", + "query_type": "float", + "query_abbrev": "f" + }, + { + "metric_tag": "inner_product", + "jit_metric_tag": "tag_metric_inner_product", + "query_type": "float", + "query_abbrev": "f" + }, + { + "metric_tag": "hamming", + "jit_metric_tag": "tag_metric_hamming", + "query_type": "uint8_t", + "query_abbrev": "u8" + }, + { + "metric_tag": "l1", + "jit_metric_tag": "tag_metric_l1", + "query_type": "float", + "query_abbrev": "f" + } + ], + "_distance": [ + { + "distance_type": "float", + "distance_abbrev": "f" + } + ] +} diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/extern_device_functions.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/extern_device_functions.cuh new file mode 100644 index 0000000000..664fe9a05a --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/extern_device_functions.cuh @@ -0,0 +1,50 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include + +#include "../compute_distance.hpp" +#include + +namespace cuvs::neighbors::cagra::detail { + +template +extern __device__ DISTANCE_T dist_op(QUERY_T a, QUERY_T b); + +template +extern __device__ DistanceT apply_normalization_standard( + DistanceT distance, + const typename dataset_descriptor_base_t::args_t args, + IndexT dataset_index); + +template +extern __device__ const dataset_descriptor_base_t* setup_workspace( + const dataset_descriptor_base_t*, void*, const DataT*, uint32_t); + +template +extern __device__ DistanceT +compute_distance(const typename dataset_descriptor_base_t::args_t args, + IndexT dataset_index, + bool valid, + uint32_t team_size_bits); + +template +extern __device__ DistanceT compute_distance_per_thread( + const typename dataset_descriptor_base_t::args_t, IndexT); +} // namespace cuvs::neighbors::cagra::detail + +namespace cuvs::neighbors::detail { + +template +extern __device__ bool sample_filter(uint32_t query_id, SourceIndexT node_id, void* filter_data); + +} // namespace cuvs::neighbors::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/kernel_def.hpp b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/kernel_def.hpp new file mode 100644 index 0000000000..370dbd33d8 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/kernel_def.hpp @@ -0,0 +1,160 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include +#include +#include + +#include + +#include "../compute_distance.hpp" // dataset_descriptor_base_t +#include "cagra_bitset.cuh" +#include "search_single_cta_device_helpers.cuh" + +namespace cuvs::neighbors::cagra::detail { + +// Function types for extern "C" __global__ JIT entry points — must match cudaLibraryGetKernel / +// AlgorithmLauncher::dispatch signatures exactly (see static_assert in each *_kernel.cu). 
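The static_assert discipline this comment refers to, reduced to a self-contained sketch (names hypothetical, not from the patch): each generated *_kernel.cu defines its extern "C" entry point and immediately asserts its type against the matching *_kernel_func_t alias from this header, so any drift between the two fails at compile time instead of surfacing at cudaLibraryGetKernel or launch time.

// Hypothetical example of the pattern; the real aliases follow below.
using toy_kernel_func_t = void(const float*, float*, std::uint32_t);

extern "C" __global__ void toy_kernel(const float* in, float* out, std::uint32_t n)
{
  if (n > 0) { out[0] = in[0]; }
}

static_assert(std::is_same_v<decltype(toy_kernel), toy_kernel_func_t>,
              "JIT entry point signature drifted from its host-side alias");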
+ +template +using search_single_cta_kernel_func_t = + void(uintptr_t, + DistanceT* const, + const std::uint32_t, + const DataT* const, + const IndexT* const, + const std::uint32_t, + const SourceIndexT*, + const unsigned, + const uint64_t, + const IndexT*, + const uint32_t, + IndexT* const, + const std::uint32_t, + const std::uint32_t, + const std::uint32_t, + const std::uint32_t, + const std::uint32_t, + const std::uint32_t, + std::uint32_t* const, + const std::uint32_t, + const std::uint32_t, + const std::uint32_t, + const std::uint32_t, + const dataset_descriptor_base_t*, + const IndexT, + cagra_bitset); + +namespace single_cta_search { + +template +using search_single_cta_p_kernel_func_t = + void(worker_handle_t*, + job_desc_t>*, + uint32_t*, + const IndexT* const, + const std::uint32_t, + const SourceIndexT*, + const unsigned, + const uint64_t, + const IndexT*, + const uint32_t, + IndexT* const, + const std::uint32_t, + const std::uint32_t, + const std::uint32_t, + const std::uint32_t, + const std::uint32_t, + const std::uint32_t, + std::uint32_t* const, + const std::uint32_t, + const std::uint32_t, + const std::uint32_t, + const std::uint32_t, + const dataset_descriptor_base_t*, + cagra_bitset); + +} // namespace single_cta_search + +namespace multi_cta_search { + +template +using search_multi_cta_kernel_func_t = + void(IndexT* const, + DistanceT* const, + const dataset_descriptor_base_t*, + const DataT* const, + const IndexT* const, + const std::uint32_t, + const std::uint32_t, + const SourceIndexT*, + const unsigned, + const uint64_t, + const IndexT*, + const std::uint32_t, + const std::uint32_t, + IndexT* const, + const std::uint32_t, + const std::uint32_t, + const std::uint32_t, + const std::uint32_t, + std::uint32_t* const, + const IndexT, + const std::uint32_t, + cagra_bitset); + +} // namespace multi_cta_search + +namespace multi_kernel_search { + +template +using random_pickup_kernel_func_t = void(const dataset_descriptor_base_t*, + const DataT* const, + const std::size_t, + const unsigned, + const uint64_t, + const IndexT*, + const std::uint32_t, + IndexT* const, + DistanceT* const, + const std::uint32_t, + IndexT* const, + const std::uint32_t, + const IndexT); + +template +using compute_distance_to_child_nodes_kernel_func_t = + void(const IndexT* const, + IndexT* const, + DistanceT* const, + const std::size_t, + const std::uint32_t, + const dataset_descriptor_base_t*, + const IndexT* const, + const std::uint32_t, + const SourceIndexT*, + const DataT*, + IndexT* const, + const std::uint32_t, + IndexT* const, + DistanceT* const, + const std::uint32_t, + cagra_bitset); + +template +using apply_filter_kernel_func_t = void(const SourceIndexT* const, + IndexT* const, + DistanceT* const, + const std::size_t, + const std::uint32_t, + const std::uint32_t, + const std::uint32_t, + cagra_bitset); + +} // namespace multi_kernel_search + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/random_pickup_kernel.cu.in b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/random_pickup_kernel.cu.in new file mode 100644 index 0000000000..94e19256d7 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/random_pickup_kernel.cu.in @@ -0,0 +1,53 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include + +namespace { + +using data_t = @data_type@; +using index_t = @index_type@; +using distance_t = @distance_type@; +using dataset_desc_base = + cuvs::neighbors::cagra::detail::dataset_descriptor_base_t; + +} // namespace + +namespace cuvs::neighbors::cagra::detail::multi_kernel_search { + +extern "C" __global__ void random_pickup(const dataset_desc_base* dataset_desc, + const data_t* const queries_ptr, + const std::size_t num_pickup, + const unsigned num_distilation, + const uint64_t rand_xor_mask, + const index_t* seed_ptr, + const std::uint32_t num_seeds, + index_t* const result_indices_ptr, + distance_t* const result_distances_ptr, + const std::uint32_t ldr, + index_t* const visited_hashmap_ptr, + const std::uint32_t hash_bitlen, + const index_t graph_size) +{ + random_pickup_kernel_jit(dataset_desc, + queries_ptr, + num_pickup, + num_distilation, + rand_xor_mask, + seed_ptr, + num_seeds, + result_indices_ptr, + result_distances_ptr, + ldr, + visited_hashmap_ptr, + hash_bitlen, + graph_size); +} + +static_assert(std::is_same_v>); + +} // namespace cuvs::neighbors::cagra::detail::multi_kernel_search diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/random_pickup_matrix.json b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/random_pickup_matrix.json new file mode 100644 index 0000000000..adfdf1e78b --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/random_pickup_matrix.json @@ -0,0 +1,13 @@ +[ + { + "_data": [ + {"data_type": "float", "data_abbrev": "f"}, + {"data_type": "__half", "data_abbrev": "h"}, + {"data_type": "uint8_t", "data_abbrev": "u8"}, + {"data_type": "int8_t", "data_abbrev": "i8"} + ], + "_source_index": [{"source_index_type": "uint32_t", "source_index_abbrev": "u32"}], + "_index": [{"index_type": "uint32_t", "index_abbrev": "u32"}], + "_distance": [{"distance_type": "float", "distance_abbrev": "f"}] + } +] diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_impl.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_impl.cuh new file mode 100644 index 0000000000..d01f58166d --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_impl.cuh @@ -0,0 +1,41 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "extern_device_functions.cuh" + +#include "../../sample_filter_data.cuh" + +#include + +#include + +namespace cuvs::neighbors::detail { + +template +__device__ bool sample_filter_none_impl(uint32_t /*query_id*/, + SourceIndexT /*node_id*/, + void* /*filter_data*/) +{ + return true; +} + +template +__device__ bool sample_filter_bitset_impl(uint32_t /*query_id*/, + SourceIndexT node_id, + void* filter_data) +{ + if (filter_data == nullptr) { return true; } + + auto* data = static_cast*>(filter_data); + if (data->bitset_ptr == nullptr) { return true; } + + raft::core::bitset_view const view{ + data->bitset_ptr, data->bitset_len, data->original_nbits}; + return view.test(node_id); +} + +} // namespace cuvs::neighbors::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_kernel.cu.in b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_kernel.cu.in new file mode 100644 index 0000000000..d7c1e54124 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_kernel.cu.in @@ -0,0 +1,30 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +// Explicit specialization for the unified 3-arg sample_filter used by CAGRA JIT +// (extern in extern_device_functions.cuh). Implementations in sample_filter_impl.cuh, +// aligned with filtering::{none_sample_filter, bitset_filter} in neighbors/sample_filter.cuh. + +#include + +#include + +namespace { + +using source_index_t = @source_index_type@; + +} // namespace + +namespace cuvs::neighbors::detail { + +template <> +__device__ bool sample_filter(uint32_t query_id, + source_index_t node_id, + void* filter_data) +{ + return sample_filter_@filter_name@_impl(query_id, node_id, filter_data); +} + +} // namespace cuvs::neighbors::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_matrix.json b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_matrix.json new file mode 100644 index 0000000000..0136587b48 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_matrix.json @@ -0,0 +1,15 @@ +{ + "filter_name": ["none", "bitset"], + "_bitset": [ + { + "bitset_type": "uint32_t", + "bitset_abbrev": "u32" + } + ], + "_source_index": [ + { + "source_index_type": "uint32_t", + "source_index_abbrev": "u32" + } + ] +} diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_jit.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_jit.cuh new file mode 100644 index 0000000000..492195f359 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_jit.cuh @@ -0,0 +1,365 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "../hashmap.hpp" +#include "../utils.hpp" + +#include + +#include + +#include +#include +#include + +#ifdef _CLK_BREAKDOWN +#include +#endif + +#include "cagra_bitset.cuh" +#include "device_common_jit.cuh" +#include "extern_device_functions.cuh" + +namespace cuvs::neighbors::cagra::detail::multi_cta_search { + +using cuvs::neighbors::cagra::detail::device::compute_distance_to_child_nodes_jit; +using cuvs::neighbors::cagra::detail::device::compute_distance_to_random_nodes_jit; +using cuvs::neighbors::detail::sample_filter; +template +__device__ void search_kernel_jit( + IndexT* const result_indices_ptr, // [num_queries, num_cta_per_query, itopk_size] + DistanceT* const result_distances_ptr, // [num_queries, num_cta_per_query, itopk_size] + const dataset_descriptor_base_t* dataset_desc, + const DataT* const queries_ptr, // [num_queries, dataset_dim] + const IndexT* const knn_graph, // [dataset_size, graph_degree] + const uint32_t max_elements, + const uint32_t graph_degree, + const SourceIndexT* source_indices_ptr, // [num_queries, search_width] + const unsigned num_distilation, + const uint64_t rand_xor_mask, + const IndexT* seed_ptr, // [num_queries, num_seeds] + const uint32_t num_seeds, + const uint32_t visited_hash_bitlen, + IndexT* const traversed_hashmap_ptr, // [num_queries, 1 << traversed_hash_bitlen] + const uint32_t traversed_hash_bitlen, + const uint32_t itopk_size, + const uint32_t min_iteration, + const uint32_t max_iteration, + uint32_t* const num_executed_iterations, /* stats */ + const IndexT graph_size, + const uint32_t query_id_offset, // Offset to add to query_id when calling filter + cagra_bitset bitset) +{ + using DATA_T = DataT; + using INDEX_T = IndexT; + using DISTANCE_T = DistanceT; + + auto to_source_index = [source_indices_ptr](INDEX_T x) { + return source_indices_ptr == nullptr ? 
static_cast(x) : source_indices_ptr[x]; + }; + + const auto num_queries = gridDim.y; + const auto query_id = blockIdx.y; + const auto num_cta_per_query = gridDim.x; + const auto cta_id = blockIdx.x; // local CTA ID + +#ifdef _CLK_BREAKDOWN + uint64_t clk_init = 0; + uint64_t clk_compute_1st_distance = 0; + uint64_t clk_topk = 0; + uint64_t clk_pickup_parents = 0; + uint64_t clk_compute_distance = 0; + uint64_t clk_start; +#define _CLK_START() clk_start = clock64() +#define _CLK_REC(V) V += clock64() - clk_start; +#else +#define _CLK_START() +#define _CLK_REC(V) +#endif + _CLK_START(); + + extern __shared__ uint8_t smem[]; + + // Layout of result_buffer + // +----------------+---------+---------------------------+ + // | internal_top_k | padding | neighbors of parent nodes | + // | | upto 32 | | + // +----------------+---------+---------------------------+ + // |<--- result_buffer_size_32 --->| + const auto result_buffer_size = itopk_size + graph_degree; + const auto result_buffer_size_32 = raft::round_up_safe(result_buffer_size, 32); + assert(result_buffer_size_32 <= max_elements); + + // Get dim and smem_ws_size_in_bytes directly from base descriptor + uint32_t dim = dataset_desc->args.dim; + uint32_t smem_ws_size_in_bytes = dataset_desc->smem_ws_size_in_bytes(); + + auto smem_desc = + setup_workspace(dataset_desc, smem, queries_ptr, query_id); + + auto* __restrict__ result_indices_buffer = + reinterpret_cast(smem + smem_ws_size_in_bytes); + auto* __restrict__ result_distances_buffer = + reinterpret_cast(result_indices_buffer + result_buffer_size_32); + auto* __restrict__ local_visited_hashmap_ptr = + reinterpret_cast(result_distances_buffer + result_buffer_size_32); + auto* __restrict__ parent_indices_buffer = + reinterpret_cast(local_visited_hashmap_ptr + hashmap::get_size(visited_hash_bitlen)); + auto* __restrict__ result_position = reinterpret_cast(parent_indices_buffer + 1); + + INDEX_T* const local_traversed_hashmap_ptr = + traversed_hashmap_ptr + (hashmap::get_size(traversed_hash_bitlen) * query_id); + + constexpr INDEX_T invalid_index = ~static_cast(0); + constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; + + for (unsigned i = threadIdx.x; i < result_buffer_size_32; i += blockDim.x) { + result_indices_buffer[i] = invalid_index; + result_distances_buffer[i] = utils::get_max_value(); + } + hashmap::init(local_visited_hashmap_ptr, visited_hash_bitlen); + __syncthreads(); + _CLK_REC(clk_init); + + // compute distance to randomly selecting nodes using JIT version + _CLK_START(); + const INDEX_T* const local_seed_ptr = seed_ptr ? 
seed_ptr + (num_seeds * query_id) : nullptr; + uint32_t block_id = cta_id + (num_cta_per_query * query_id); + uint32_t num_blocks = num_cta_per_query * num_queries; + + compute_distance_to_random_nodes_jit(result_indices_buffer, + result_distances_buffer, + smem_desc, + graph_degree, + num_distilation, + rand_xor_mask, + local_seed_ptr, + num_seeds, + local_visited_hashmap_ptr, + visited_hash_bitlen, + local_traversed_hashmap_ptr, + traversed_hash_bitlen, + block_id, + num_blocks, + graph_size); + __syncthreads(); + _CLK_REC(clk_compute_1st_distance); + + uint32_t iter = 0; + while (1) { + _CLK_START(); + if (threadIdx.x < 32) { + // [1st warp] Topk with bitonic sort + if constexpr (std::is_same_v) { + // use a non-template wrapper function to avoid pre-inlining the topk_by_bitonic_sort + // function (vs post-inlining, this impacts register pressure) + if (max_elements <= 64) { + topk_by_bitonic_sort_wrapper_64( + result_distances_buffer, result_indices_buffer, result_buffer_size_32); + } else if (max_elements <= 128) { + topk_by_bitonic_sort_wrapper_128( + result_distances_buffer, result_indices_buffer, result_buffer_size_32); + } else { + assert(max_elements <= 256); + topk_by_bitonic_sort_wrapper_256( + result_distances_buffer, result_indices_buffer, result_buffer_size_32); + } + } else { + if (max_elements <= 64) { + topk_by_bitonic_sort<64, INDEX_T>( + result_distances_buffer, result_indices_buffer, result_buffer_size_32); + } else if (max_elements <= 128) { + topk_by_bitonic_sort<128, INDEX_T>( + result_distances_buffer, result_indices_buffer, result_buffer_size_32); + } else { + assert(max_elements <= 256); + topk_by_bitonic_sort<256, INDEX_T>( + result_distances_buffer, result_indices_buffer, result_buffer_size_32); + } + } + } + __syncthreads(); + _CLK_REC(clk_topk); + + if (iter + 1 >= max_iteration) { break; } + + _CLK_START(); + if (threadIdx.x < 32) { + // [1st warp] Pick up a next parent + pickup_next_parent(parent_indices_buffer, + result_indices_buffer, + result_distances_buffer, + local_traversed_hashmap_ptr, + traversed_hash_bitlen); + } else { + // [Other warps] Reset visited hashmap + hashmap::init(local_visited_hashmap_ptr, visited_hash_bitlen, 32); + } + __syncthreads(); + _CLK_REC(clk_pickup_parents); + + if ((parent_indices_buffer[0] == invalid_index) && (iter >= min_iteration)) { break; } + + _CLK_START(); + for (unsigned i = threadIdx.x; i < result_buffer_size_32; i += blockDim.x) { + INDEX_T index = result_indices_buffer[i]; + if (index == invalid_index) { continue; } + if ((i >= itopk_size) && (index & index_msb_1_mask)) { + // Remove nodes kicked out of the itopk list from the traversed hash table. + hashmap::remove( + local_traversed_hashmap_ptr, traversed_hash_bitlen, index & ~index_msb_1_mask); + result_indices_buffer[i] = invalid_index; + result_distances_buffer[i] = utils::get_max_value(); + } else { + // Restore visited hashmap by putting nodes on result buffer in it. + index &= ~index_msb_1_mask; + hashmap::insert(local_visited_hashmap_ptr, visited_hash_bitlen, index); + } + } + // Initialize buffer for compute_distance_to_child_nodes. 
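+    // result_position starts at result_buffer_size_32, is handed to compute_distance_to_child_nodes_jit
+    // below, and is re-read afterwards as the loop bound over buffer entries that the child-distance
+    // pass left untouched.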
+ if (threadIdx.x == blockDim.x - 1) { result_position[0] = result_buffer_size_32; } + __syncthreads(); + + // Compute the norms between child nodes and query node using JIT version + compute_distance_to_child_nodes_jit(result_indices_buffer, + result_distances_buffer, + smem_desc, + knn_graph, + graph_degree, + local_visited_hashmap_ptr, + visited_hash_bitlen, + local_traversed_hashmap_ptr, + traversed_hash_bitlen, + parent_indices_buffer, + result_indices_buffer, + 1, + result_position, + result_buffer_size_32); + __syncthreads(); + + // Check the state of the nodes in the result buffer which were not updated + // by the compute_distance_to_child_nodes above, and if it cannot be used as + // a parent node, it is deactivated. + for (uint32_t i = threadIdx.x; i < result_position[0]; i += blockDim.x) { + INDEX_T index = result_indices_buffer[i]; + if (index == invalid_index || index & index_msb_1_mask) { continue; } + if (hashmap::search(local_traversed_hashmap_ptr, traversed_hash_bitlen, index)) { + result_indices_buffer[i] = invalid_index; + result_distances_buffer[i] = utils::get_max_value(); + } + } + __syncthreads(); + _CLK_REC(clk_compute_distance); + + // Filtering - use extern sample_filter function (linked via JIT LTO) + for (unsigned p = threadIdx.x; p < 1; p += blockDim.x) { + if (parent_indices_buffer[p] != invalid_index) { + const auto parent_id = result_indices_buffer[parent_indices_buffer[p]] & ~index_msb_1_mask; + if (!sample_filter(query_id + query_id_offset, + to_source_index(parent_id), + bitset.bitset_ptr != nullptr ? &bitset : nullptr)) { + // If the parent must not be in the resulting top-k list, remove from the parent list + result_distances_buffer[parent_indices_buffer[p]] = utils::get_max_value(); + result_indices_buffer[parent_indices_buffer[p]] = invalid_index; + } + } + } + __syncthreads(); + + iter++; + } + + // Filtering - use extern sample_filter function (linked via JIT LTO) + for (uint32_t i = threadIdx.x; i < result_buffer_size_32; i += blockDim.x) { + INDEX_T index = result_indices_buffer[i]; + if (index == invalid_index) { continue; } + index &= ~index_msb_1_mask; + if (!sample_filter(query_id + query_id_offset, + to_source_index(index), + bitset.bitset_ptr != nullptr ? &bitset : nullptr)) { + result_indices_buffer[i] = invalid_index; + result_distances_buffer[i] = utils::get_max_value(); + } + } + __syncthreads(); + + // Output search results (1st warp only). + if (threadIdx.x < 32) { + uint32_t offset = 0; + for (uint32_t i = threadIdx.x; i < result_buffer_size_32; i += 32) { + INDEX_T index = result_indices_buffer[i]; + bool is_valid = false; + if (index != invalid_index) { + if (index & index_msb_1_mask) { + is_valid = true; + index &= ~index_msb_1_mask; + } else if ((offset < itopk_size) && + hashmap::insert( + local_traversed_hashmap_ptr, traversed_hash_bitlen, index)) { + // If a node that is not used as a parent can be inserted into + // the traversed hash table, it is considered a valid result. 
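+          // (Insertion fails when the node is already present, i.e. it was already traversed or
+          // recorded by one of the CTAs sharing this per-query hash table, so the table also acts
+          // as a de-duplication filter on the output.)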
+ is_valid = true; + } + } + const auto mask = __ballot_sync(0xffffffff, is_valid); + if (is_valid) { + const auto j = offset + __popc(mask & ((1 << threadIdx.x) - 1)); + if (j < itopk_size) { + uint32_t k = j + (itopk_size * (cta_id + (num_cta_per_query * query_id))); + result_indices_ptr[k] = index & ~index_msb_1_mask; + if (result_distances_ptr != nullptr) { + DISTANCE_T dist = result_distances_buffer[i]; + result_distances_ptr[k] = dist; + } + } else { + // If it is valid and registered in the traversed hash table but is + // not output as a result, it is removed from the hash table. + hashmap::remove(local_traversed_hashmap_ptr, traversed_hash_bitlen, index); + } + } + offset += __popc(mask); + } + // If the number of outputs is insufficient, fill in with invalid results. + for (uint32_t i = offset + threadIdx.x; i < itopk_size; i += 32) { + uint32_t k = i + (itopk_size * (cta_id + (num_cta_per_query * query_id))); + result_indices_ptr[k] = invalid_index; + if (result_distances_ptr != nullptr) { + result_distances_ptr[k] = utils::get_max_value(); + } + } + } + + if (threadIdx.x == 0 && cta_id == 0 && num_executed_iterations != nullptr) { + num_executed_iterations[query_id] = iter + 1; + } + +#ifdef _CLK_BREAKDOWN + if ((threadIdx.x == 0 || threadIdx.x == blockDim.x - 1) && (blockIdx.x == 0) && + ((query_id * 3) % gridDim.y < 3)) { + printf( + "%s:%d " + "query, %d, thread, %d" + ", init, %lu" + ", 1st_distance, %lu" + ", topk, %lu" + ", pickup_parents, %lu" + ", distance, %lu" + "\n", + __FILE__, + __LINE__, + query_id, + threadIdx.x, + clk_init, + clk_compute_1st_distance, + clk_topk, + clk_pickup_parents, + clk_compute_distance); + } +#endif +} + +} // namespace cuvs::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_kernel.cu.in b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_kernel.cu.in new file mode 100644 index 0000000000..9acd73687b --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_kernel.cu.in @@ -0,0 +1,75 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include + +namespace { + +using data_t = @data_type@; +using index_t = @index_type@; +using distance_t = @distance_type@; +using source_index_t = @source_index_type@; +using dataset_desc_base = + cuvs::neighbors::cagra::detail::dataset_descriptor_base_t; +using cagra_bitset_t = cuvs::neighbors::cagra::detail::cagra_bitset; + +} // namespace + +namespace cuvs::neighbors::cagra::detail::multi_cta_search { + +extern "C" __global__ __launch_bounds__(1024, 1) void search_multi_cta( + index_t* const result_indices_ptr, + distance_t* const result_distances_ptr, + const dataset_desc_base* dataset_desc, + const data_t* const queries_ptr, + const index_t* const knn_graph, + const std::uint32_t max_elements, + const std::uint32_t graph_degree, + const source_index_t* source_indices_ptr, + const unsigned num_distilation, + const uint64_t rand_xor_mask, + const index_t* seed_ptr, + const std::uint32_t num_seeds, + const std::uint32_t visited_hash_bitlen, + index_t* const traversed_hashmap_ptr, + const std::uint32_t traversed_hash_bitlen, + const std::uint32_t itopk_size, + const std::uint32_t min_iteration, + const std::uint32_t max_iteration, + std::uint32_t* const num_executed_iterations, + const index_t graph_size, + const std::uint32_t query_id_offset, + cagra_bitset_t bitset) +{ + search_kernel_jit(result_indices_ptr, + result_distances_ptr, + dataset_desc, + queries_ptr, + knn_graph, + max_elements, + graph_degree, + source_indices_ptr, + num_distilation, + rand_xor_mask, + seed_ptr, + num_seeds, + visited_hash_bitlen, + traversed_hashmap_ptr, + traversed_hash_bitlen, + itopk_size, + min_iteration, + max_iteration, + num_executed_iterations, + graph_size, + query_id_offset, + bitset); +} + +static_assert( + std::is_same_v>); + +} // namespace cuvs::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_matrix.json b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_matrix.json new file mode 100644 index 0000000000..adfdf1e78b --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_matrix.json @@ -0,0 +1,13 @@ +[ + { + "_data": [ + {"data_type": "float", "data_abbrev": "f"}, + {"data_type": "__half", "data_abbrev": "h"}, + {"data_type": "uint8_t", "data_abbrev": "u8"}, + {"data_type": "int8_t", "data_abbrev": "i8"} + ], + "_source_index": [{"source_index_type": "uint32_t", "source_index_abbrev": "u32"}], + "_index": [{"index_type": "uint32_t", "index_abbrev": "u32"}], + "_distance": [{"distance_type": "float", "distance_abbrev": "f"}] + } +] diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_planner.hpp b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_planner.hpp new file mode 100644 index 0000000000..5e6ea43130 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_planner.hpp @@ -0,0 +1,43 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "cagra_planner_base.hpp" +#include +#include + +namespace cuvs::neighbors::cagra::detail::multi_cta_search { + +template +struct CagraMultiCtaSearchPlanner + : CagraPlannerBase { + static inline LauncherJitCache launcher_jit_cache{}; + + CagraMultiCtaSearchPlanner(cuvs::distance::DistanceType /*metric*/, + uint32_t /*team_size*/, + uint32_t /*dataset_block_dim*/, + bool /*is_vpq*/, + uint32_t /*pq_bits*/, + uint32_t /*pq_len*/) + : CagraPlannerBase( + "search_multi_cta", launcher_jit_cache) + { + } + + void add_search_multi_cta_kernel_fragment() + { + this->template add_static_fragment< + fragment_tag_search_multi_cta>(); + } +}; + +} // namespace cuvs::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_jit.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_jit.cuh new file mode 100644 index 0000000000..a7ab078343 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_jit.cuh @@ -0,0 +1,212 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "../../neighbors_device_intrinsics.cuh" +#include "../hashmap.hpp" +#include "../utils.hpp" +#include "cagra_bitset.cuh" + +#include +#include + +#include "extern_device_functions.cuh" + +namespace cuvs::neighbors::cagra::detail::multi_kernel_search { + +template +__device__ void random_pickup_kernel_jit( + const dataset_descriptor_base_t* dataset_desc, + const DataT* const queries_ptr, // [num_queries, dataset_dim] + const std::size_t num_pickup, + const unsigned num_distilation, + const uint64_t rand_xor_mask, + const IndexT* seed_ptr, // [num_queries, num_seeds] + const uint32_t num_seeds, + IndexT* const result_indices_ptr, // [num_queries, ldr] + DistanceT* const result_distances_ptr, // [num_queries, ldr] + const std::uint32_t ldr, // (*) ldr >= num_pickup + IndexT* const visited_hashmap_ptr, // [num_queries, 1 << bitlen] + const std::uint32_t hash_bitlen, + const IndexT graph_size) +{ + using DATA_T = DataT; + using INDEX_T = IndexT; + using DISTANCE_T = DistanceT; + + // Get team_size_bits directly from base descriptor + uint32_t team_size_bits = dataset_desc->team_size_bitshift(); + + const auto ldb = hashmap::get_size(hash_bitlen); + const auto global_team_index = (blockIdx.x * blockDim.x + threadIdx.x) >> team_size_bits; + const uint32_t query_id = blockIdx.y; + if (global_team_index >= num_pickup) { return; } + extern __shared__ uint8_t smem[]; + + auto smem_desc = + setup_workspace(dataset_desc, smem, queries_ptr, query_id); + __syncthreads(); + + IndexT dataset_size = smem_desc->size; + const INDEX_T seed_index_limit = graph_size > 0 ? 
graph_size : dataset_size; + const auto args_load = smem_desc->args.load(); + + INDEX_T best_index_team_local; + DISTANCE_T best_norm2_team_local = utils::get_max_value(); + for (unsigned i = 0; i < num_distilation; i++) { + INDEX_T seed_index; + if (seed_ptr && (global_team_index < num_seeds)) { + seed_index = seed_ptr[global_team_index + (num_seeds * query_id)]; + } else { + seed_index = + device::xorshift64((global_team_index ^ rand_xor_mask) * (i + 1)) % seed_index_limit; + } + + const auto norm2 = + compute_distance(args_load, seed_index, true, team_size_bits); + + if (norm2 < best_norm2_team_local) { + best_norm2_team_local = norm2; + best_index_team_local = seed_index; + } + } + + const auto store_gmem_index = global_team_index + (ldr * query_id); + if ((threadIdx.x & ((1u << team_size_bits) - 1u)) == 0) { + if (hashmap::insert( + visited_hashmap_ptr + (ldb * query_id), hash_bitlen, best_index_team_local)) { + result_distances_ptr[store_gmem_index] = best_norm2_team_local; + result_indices_ptr[store_gmem_index] = best_index_team_local; + } else { + result_distances_ptr[store_gmem_index] = utils::get_max_value(); + result_indices_ptr[store_gmem_index] = utils::get_max_value(); + } + } +} + +using cuvs::neighbors::detail::sample_filter; +template +__device__ void compute_distance_to_child_nodes_kernel_jit( + const IndexT* const parent_node_list, // [num_queries, search_width] + IndexT* const parent_candidates_ptr, // [num_queries, search_width] + DistanceT* const parent_distance_ptr, // [num_queries, search_width] + const std::size_t lds, + const std::uint32_t search_width, + const dataset_descriptor_base_t* dataset_desc, + const IndexT* const neighbor_graph_ptr, // [dataset_size, graph_degree] + const std::uint32_t graph_degree, + const SourceIndexT* source_indices_ptr, + const DataT* query_ptr, // [num_queries, data_dim] + IndexT* const visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] + const std::uint32_t hash_bitlen, + IndexT* const result_indices_ptr, // [num_queries, ldd] + DistanceT* const result_distances_ptr, // [num_queries, ldd] + const std::uint32_t ldd, // (*) ldd >= search_width * graph_degree + cagra_bitset bitset) +{ + using INDEX_T = IndexT; + using DISTANCE_T = DistanceT; + + // Get team_size_bits directly from base descriptor + uint32_t team_size_bits = dataset_desc->team_size_bitshift(); + + const auto team_size = 1u << team_size_bits; + const uint32_t ldb = hashmap::get_size(hash_bitlen); + const auto tid = threadIdx.x + blockDim.x * blockIdx.x; + const auto global_team_id = tid >> team_size_bits; + const auto query_id = blockIdx.y; + + extern __shared__ uint8_t smem[]; + auto smem_desc = + setup_workspace(dataset_desc, smem, query_ptr, query_id); + + __syncthreads(); + if (global_team_id >= search_width * graph_degree) { return; } + + const std::size_t parent_list_index = + parent_node_list[global_team_id / graph_degree + (search_width * blockIdx.y)]; + + if (parent_list_index == utils::get_max_value()) { return; } + + constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; + const auto raw_parent_index = parent_candidates_ptr[parent_list_index + (lds * query_id)]; + + if (raw_parent_index == utils::get_max_value()) { + result_distances_ptr[ldd * blockIdx.y + global_team_id] = utils::get_max_value(); + return; + } + const auto parent_index = raw_parent_index & ~index_msb_1_mask; + + const auto neighbor_list_head_ptr = neighbor_graph_ptr + (graph_degree * parent_index); + + const std::size_t child_id = neighbor_list_head_ptr[global_team_id 
% graph_degree]; + + const auto compute_distance_flag = hashmap::insert( + team_size, visited_hashmap_ptr + (ldb * blockIdx.y), hash_bitlen, child_id); + + const auto args = smem_desc->args.load(); + DISTANCE_T norm2 = compute_distance( + args, static_cast(child_id), compute_distance_flag, team_size_bits); + + if (compute_distance_flag) { + if ((threadIdx.x & (team_size - 1)) == 0) { + result_indices_ptr[ldd * blockIdx.y + global_team_id] = child_id; + result_distances_ptr[ldd * blockIdx.y + global_team_id] = norm2; + } + } else { + if ((threadIdx.x & (team_size - 1)) == 0) { + result_distances_ptr[ldd * blockIdx.y + global_team_id] = utils::get_max_value(); + } + } + + if (bitset.bitset_ptr != nullptr) { + const SourceIndexT node_id = source_indices_ptr == nullptr + ? static_cast(parent_index) + : static_cast(source_indices_ptr[parent_index]); + if (!sample_filter(query_id, node_id, &bitset)) { + parent_candidates_ptr[parent_list_index + (lds * query_id)] = utils::get_max_value(); + parent_distance_ptr[parent_list_index + (lds * query_id)] = + utils::get_max_value(); + } + } +} + +template +__device__ void apply_filter_kernel_jit( + const SourceIndexT* source_indices_ptr, // [num_queries, search_width] + IndexT* const result_indices_ptr, + DistanceT* const result_distances_ptr, + const std::size_t lds, + const std::uint32_t result_buffer_size, + const std::uint32_t num_queries, + const std::uint32_t query_id_offset, + cagra_bitset bitset) +{ + constexpr IndexT index_msb_1_mask = utils::gen_index_msb_1_mask::value; + const auto tid = threadIdx.x + blockIdx.x * blockDim.x; + if (tid >= result_buffer_size * num_queries) { return; } + const auto i = tid % result_buffer_size; + const auto j = tid / result_buffer_size; + const auto index = i + j * lds; + + if (result_indices_ptr[index] != ~index_msb_1_mask) { + // Use extern sample_filter function with 3 params: query_id, node_id, filter_data + // Third argument is &bitset (layout matches bitset_filter_data_t) or nullptr for none_filter + SourceIndexT node_id = source_indices_ptr == nullptr + ? static_cast(result_indices_ptr[index]) + : source_indices_ptr[result_indices_ptr[index]]; + + if (!sample_filter( + query_id_offset + j, node_id, bitset.bitset_ptr != nullptr ? &bitset : nullptr)) { + result_indices_ptr[index] = utils::get_max_value(); + result_distances_ptr[index] = utils::get_max_value(); + } + } +} + +} // namespace cuvs::neighbors::cagra::detail::multi_kernel_search diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_kernel_planner.hpp b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_kernel_planner.hpp new file mode 100644 index 0000000000..e74134d0b2 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_kernel_planner.hpp @@ -0,0 +1,58 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "cagra_planner_base.hpp" +#include +#include +#include + +namespace cuvs::neighbors::cagra::detail::multi_kernel_search { + +template +struct CagraMultiKernelSearchPlanner + : CagraPlannerBase { + static inline LauncherJitCache launcher_jit_cache{}; + + CagraMultiKernelSearchPlanner(cuvs::distance::DistanceType /*metric*/, + const std::string& kernel_name, + uint32_t /*team_size*/, + uint32_t /*dataset_block_dim*/, + bool /*is_vpq*/, + uint32_t /*pq_bits*/, + uint32_t /*pq_len*/) + : CagraPlannerBase( + kernel_name, launcher_jit_cache) + { + } + + void add_linked_kernel(std::string const& kernel_name) + { + if (kernel_name == "random_pickup") { + this->template add_static_fragment< + fragment_tag_random_pickup>(); + } else if (kernel_name == "compute_distance_to_child_nodes") { + this->template add_static_fragment< + fragment_tag_compute_distance_to_child_nodes>(); + } else if (kernel_name == "apply_filter_kernel") { + this->template add_static_fragment< + fragment_tag_apply_filter_kernel>(); + } else { + RAFT_FAIL("Unknown CAGRA multi-kernel JIT kernel: %s", kernel_name.c_str()); + } + } +}; + +} // namespace cuvs::neighbors::cagra::detail::multi_kernel_search diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_device_helpers.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_device_helpers.cuh new file mode 100644 index 0000000000..cce818d462 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_device_helpers.cuh @@ -0,0 +1,680 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +// Device-only includes - no host-side headers +#include "../../neighbors_device_intrinsics.cuh" +#include "../bitonic.hpp" +#include "../device_memory_ops.hpp" +#include "../hashmap.hpp" +#include "../utils.hpp" + +#include +#include + +#include + +#include +#include +#include +#include + +#include +#include // For uint4 + +namespace cuvs::neighbors::cagra::detail::single_cta_search { + +// Type bundle for `job_desc_t` (DATA_T / INDEX_T / DISTANCE_T for persistent single-CTA kernels). 
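+// job_desc_t below only needs the DATA_T / INDEX_T / DISTANCE_T member aliases of its template
+// argument, so this lightweight bundle can stand in for a full dataset descriptor type. For
+// illustration (assuming float data, uint32_t indices and float distances), the persistent
+// launcher's descriptor type would be spelled roughly as
+//   job_desc_t<job_desc_traits<float, uint32_t, float>>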
+template +struct job_desc_traits { + using DATA_T = DataT; + using INDEX_T = IndexT; + using DISTANCE_T = DistanceT; +}; + +// Constants for persistent kernels (shared by JIT device code and host launcher) +constexpr size_t kCacheLineBytes = 64; +constexpr uint32_t kMaxJobsNum = 8192; +constexpr uint32_t kMaxWorkersNum = 4096; +constexpr uint32_t kMaxWorkersPerThread = 256; +constexpr uint32_t kSoftMaxWorkersPerThread = 16; + +// Worker handle for persistent kernels +struct alignas(kCacheLineBytes) worker_handle_t { + using handle_t = uint64_t; + struct value_t { + uint32_t desc_id; + uint32_t query_id; + }; + union data_t { + handle_t handle; + value_t value; + }; + cuda::atomic data; +}; +static_assert(sizeof(worker_handle_t::value_t) == sizeof(worker_handle_t::handle_t)); +static_assert( + cuda::atomic::is_always_lock_free); + +constexpr worker_handle_t::handle_t kWaitForWork = std::numeric_limits::max(); +constexpr worker_handle_t::handle_t kNoMoreWork = kWaitForWork - 1; + +constexpr auto is_worker_busy(worker_handle_t::handle_t h) -> bool +{ + return (h != kWaitForWork) && (h != kNoMoreWork); +} + +// Job descriptor for persistent kernels +template +struct alignas(kCacheLineBytes) job_desc_t { + using index_type = typename DATASET_DESCRIPTOR_T::INDEX_T; + using distance_type = typename DATASET_DESCRIPTOR_T::DISTANCE_T; + using data_type = typename DATASET_DESCRIPTOR_T::DATA_T; + // The algorithm input parameters + struct value_t { + uintptr_t result_indices_ptr; // [num_queries, top_k] + distance_type* result_distances_ptr; // [num_queries, top_k] + const data_type* queries_ptr; // [num_queries, dataset_dim] + uint32_t top_k; + uint32_t n_queries; + }; + using blob_elem_type = uint4; + constexpr static inline size_t kBlobSize = + raft::div_rounding_up_safe(sizeof(value_t), sizeof(blob_elem_type)); + // Union facilitates loading the input by a warp in a single request + union input_t { + blob_elem_type blob[kBlobSize]; // NOLINT + value_t value; + } input; + // Last thread triggers this flag. 
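+  // Read by the side that submitted the job to detect that all results for this job have been
+  // written.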
+ cuda::atomic completion_flag; +}; + +// Pick up next parent nodes from the internal topk list +template +RAFT_DEVICE_INLINE_FUNCTION void pickup_next_parents(std::uint32_t* const terminate_flag, + INDEX_T* const next_parent_indices, + INDEX_T* const internal_topk_indices, + const std::size_t internal_topk_size, + const std::uint32_t search_width) +{ + constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; + + for (std::uint32_t i = threadIdx.x; i < search_width; i += 32) { + next_parent_indices[i] = utils::get_max_value(); + } + std::uint32_t itopk_max = internal_topk_size; + if (itopk_max % 32) { itopk_max += 32 - (itopk_max % 32); } + std::uint32_t num_new_parents = 0; + for (std::uint32_t j = threadIdx.x; j < itopk_max; j += 32) { + std::uint32_t jj = j; + if (TOPK_BY_BITONIC_SORT) { jj = device::swizzling(j); } + INDEX_T index; + int new_parent = 0; + if (j < internal_topk_size) { + index = internal_topk_indices[jj]; + if ((index & index_msb_1_mask) == 0) { // check if most significant bit is set + new_parent = 1; + } + } + const std::uint32_t ballot_mask = __ballot_sync(0xffffffff, new_parent); + if (new_parent) { + const auto i = __popc(ballot_mask & ((1 << threadIdx.x) - 1)) + num_new_parents; + if (i < search_width) { + next_parent_indices[i] = jj; + // set most significant bit as used node + internal_topk_indices[jj] |= index_msb_1_mask; + } + } + num_new_parents += __popc(ballot_mask); + if (num_new_parents >= search_width) { break; } + } + if (threadIdx.x == 0 && (num_new_parents == 0)) { *terminate_flag = 1; } +} + +// Helper function for bitonic sort and full +template +RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_full( + float* candidate_distances, // [num_candidates] + IdxT* candidate_indices, // [num_candidates] + const std::uint32_t num_candidates, + const std::uint32_t num_itopk) +{ + const unsigned lane_id = threadIdx.x % raft::warp_size(); + const unsigned warp_id = threadIdx.x / raft::warp_size(); + static_assert(MAX_CANDIDATES <= 256); + if constexpr (!MULTI_WARPS) { + if (warp_id > 0) { return; } + constexpr unsigned N = (MAX_CANDIDATES + (raft::warp_size() - 1)) / raft::warp_size(); + float key[N]; + IdxT val[N]; + /* Candidates -> Reg */ + for (unsigned i = 0; i < N; i++) { + unsigned j = lane_id + (raft::warp_size() * i); + if (j < num_candidates) { + key[i] = candidate_distances[j]; + val[i] = candidate_indices[j]; + } else { + key[i] = utils::get_max_value(); + val[i] = utils::get_max_value(); + } + } + /* Sort */ + bitonic::warp_sort(key, val); + /* Reg -> Temp_itopk */ + for (unsigned i = 0; i < N; i++) { + unsigned j = (N * lane_id) + i; + if (j < num_candidates && j < num_itopk) { + candidate_distances[device::swizzling(j)] = key[i]; + candidate_indices[device::swizzling(j)] = val[i]; + } + } + } else { + assert(blockDim.x >= 64); + // Use two warps (64 threads) + constexpr unsigned max_candidates_per_warp = (MAX_CANDIDATES + 1) / 2; + static_assert(max_candidates_per_warp <= 128); + constexpr unsigned N = (max_candidates_per_warp + (raft::warp_size() - 1)) / raft::warp_size(); + float key[N]; + IdxT val[N]; + if (warp_id < 2) { + /* Candidates -> Reg */ + for (unsigned i = 0; i < N; i++) { + unsigned jl = lane_id + (raft::warp_size() * i); + unsigned j = jl + (max_candidates_per_warp * warp_id); + if (j < num_candidates) { + key[i] = candidate_distances[j]; + val[i] = candidate_indices[j]; + } else { + key[i] = utils::get_max_value(); + val[i] = utils::get_max_value(); + } + } + /* Sort */ + bitonic::warp_sort(key, val); + 
/* Reg -> Temp_candidates */ + for (unsigned i = 0; i < N; i++) { + unsigned jl = (N * lane_id) + i; + unsigned j = jl + (max_candidates_per_warp * warp_id); + if (j < num_candidates && jl < num_itopk) { + candidate_distances[device::swizzling(j)] = key[i]; + candidate_indices[device::swizzling(j)] = val[i]; + } + } + } + __syncthreads(); + + unsigned num_warps_used = (num_itopk + max_candidates_per_warp - 1) / max_candidates_per_warp; + if (warp_id < num_warps_used) { + /* Temp_candidates -> Reg */ + for (unsigned i = 0; i < N; i++) { + unsigned jl = (N * lane_id) + i; + unsigned kl = max_candidates_per_warp - 1 - jl; + unsigned j = jl + (max_candidates_per_warp * warp_id); + unsigned k = MAX_CANDIDATES - 1 - j; + if (j >= num_candidates || k >= num_candidates || kl >= num_itopk) continue; + float temp_key = candidate_distances[device::swizzling(k)]; + if (key[i] == temp_key) continue; + if ((warp_id == 0) == (key[i] > temp_key)) { + key[i] = temp_key; + val[i] = candidate_indices[device::swizzling(k)]; + } + } + } + if (num_warps_used > 1) { __syncthreads(); } + if (warp_id < num_warps_used) { + /* Merge */ + bitonic::warp_merge(key, val, raft::warp_size()); + /* Reg -> Temp_itopk */ + for (unsigned i = 0; i < N; i++) { + unsigned jl = (N * lane_id) + i; + unsigned j = jl + (max_candidates_per_warp * warp_id); + if (j < num_candidates && j < num_itopk) { + candidate_distances[device::swizzling(j)] = key[i]; + candidate_indices[device::swizzling(j)] = val[i]; + } + } + } + if (num_warps_used > 1) { __syncthreads(); } + } +} + +// Wrapper functions to avoid pre-inlining (impacts register pressure) +RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_full_wrapper_64_false( + float* candidate_distances, // [num_candidates] + std::uint32_t* candidate_indices, // [num_candidates] + const std::uint32_t num_candidates, + const std::uint32_t num_itopk) +{ + topk_by_bitonic_sort_and_full<64, false, uint32_t>( + candidate_distances, candidate_indices, num_candidates, num_itopk); +} + +RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_full_wrapper_128_false( + float* candidate_distances, // [num_candidates] + std::uint32_t* candidate_indices, // [num_candidates] + const std::uint32_t num_candidates, + const std::uint32_t num_itopk) +{ + topk_by_bitonic_sort_and_full<128, false, uint32_t>( + candidate_distances, candidate_indices, num_candidates, num_itopk); +} + +RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_full_wrapper_256_false( + float* candidate_distances, // [num_candidates] + std::uint32_t* candidate_indices, // [num_candidates] + const std::uint32_t num_candidates, + const std::uint32_t num_itopk) +{ + topk_by_bitonic_sort_and_full<256, false, uint32_t>( + candidate_distances, candidate_indices, num_candidates, num_itopk); +} + +// TopK by bitonic sort and merge (template version with MAX_ITOPK) +template +RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_merge( + float* itopk_distances, // [num_itopk] + IdxT* itopk_indices, // [num_itopk] + const std::uint32_t num_itopk, + float* candidate_distances, // [num_candidates] + IdxT* candidate_indices, // [num_candidates] + const std::uint32_t num_candidates, + std::uint32_t* work_buf, + const bool first) +{ + const unsigned lane_id = threadIdx.x % raft::warp_size(); + const unsigned warp_id = threadIdx.x / raft::warp_size(); + + static_assert(MAX_ITOPK <= 512); + if constexpr (!MULTI_WARPS) { + static_assert(MAX_ITOPK <= 256); + if (warp_id > 0) { return; } + constexpr unsigned N = (MAX_ITOPK + 
(raft::warp_size() - 1)) / raft::warp_size(); + float key[N]; + IdxT val[N]; + if (first) { + /* Load itopk results */ + for (unsigned i = 0; i < N; i++) { + unsigned j = lane_id + (raft::warp_size() * i); + if (j < num_itopk) { + key[i] = itopk_distances[j]; + val[i] = itopk_indices[j]; + } else { + key[i] = utils::get_max_value(); + val[i] = utils::get_max_value(); + } + } + /* Warp Sort */ + bitonic::warp_sort(key, val); + } else { + /* Load itopk results */ + for (unsigned i = 0; i < N; i++) { + unsigned j = (N * lane_id) + i; + if (j < num_itopk) { + key[i] = itopk_distances[device::swizzling(j)]; + val[i] = itopk_indices[device::swizzling(j)]; + } else { + key[i] = utils::get_max_value(); + val[i] = utils::get_max_value(); + } + } + } + /* Merge candidates */ + for (unsigned i = 0; i < N; i++) { + unsigned j = (N * lane_id) + i; // [0:max_itopk-1] + unsigned k = MAX_ITOPK - 1 - j; + if (k >= num_itopk || k >= num_candidates) continue; + float candidate_key = candidate_distances[device::swizzling(k)]; + if (key[i] > candidate_key) { + key[i] = candidate_key; + val[i] = candidate_indices[device::swizzling(k)]; + } + } + /* Warp Merge */ + bitonic::warp_merge(key, val, raft::warp_size()); + /* Store new itopk results */ + for (unsigned i = 0; i < N; i++) { + unsigned j = (N * lane_id) + i; + if (j < num_itopk) { + itopk_distances[device::swizzling(j)] = key[i]; + itopk_indices[device::swizzling(j)] = val[i]; + } + } + } else { + static_assert(MAX_ITOPK == 512); + assert(blockDim.x >= 64); + // Use two warps (64 threads) or more + constexpr unsigned max_itopk_per_warp = (MAX_ITOPK + 1) / 2; + constexpr unsigned N = (max_itopk_per_warp + (raft::warp_size() - 1)) / raft::warp_size(); + float key[N]; + IdxT val[N]; + if (first) { + /* Load itop results (not sorted) */ + if (warp_id < 2) { + for (unsigned i = 0; i < N; i++) { + unsigned j = lane_id + (raft::warp_size() * i) + (max_itopk_per_warp * warp_id); + if (j < num_itopk) { + key[i] = itopk_distances[j]; + val[i] = itopk_indices[j]; + } else { + key[i] = utils::get_max_value(); + val[i] = utils::get_max_value(); + } + } + /* Warp Sort */ + bitonic::warp_sort(key, val); + /* Store intermedidate results */ + for (unsigned i = 0; i < N; i++) { + unsigned j = (N * threadIdx.x) + i; + if (j >= num_itopk) continue; + itopk_distances[device::swizzling(j)] = key[i]; + itopk_indices[device::swizzling(j)] = val[i]; + } + } + __syncthreads(); + if (warp_id < 2) { + /* Load intermedidate results */ + for (unsigned i = 0; i < N; i++) { + unsigned j = (N * threadIdx.x) + i; + unsigned k = MAX_ITOPK - 1 - j; + if (k >= num_itopk) continue; + float temp_key = itopk_distances[device::swizzling(k)]; + if (key[i] == temp_key) continue; + if ((warp_id == 0) == (key[i] > temp_key)) { + key[i] = temp_key; + val[i] = itopk_indices[device::swizzling(k)]; + } + } + /* Warp Merge */ + bitonic::warp_merge(key, val, raft::warp_size()); + } + __syncthreads(); + /* Store itopk results (sorted) */ + if (warp_id < 2) { + for (unsigned i = 0; i < N; i++) { + unsigned j = (N * threadIdx.x) + i; + if (j >= num_itopk) continue; + itopk_distances[device::swizzling(j)] = key[i]; + itopk_indices[device::swizzling(j)] = val[i]; + } + } + } + const uint32_t num_itopk_div2 = num_itopk / 2; + if (threadIdx.x < 3) { + // work_buf is used to obtain turning points in 1st and 2nd half of itopk afer merge. + work_buf[threadIdx.x] = num_itopk_div2; + } + __syncthreads(); + + // Merge candidates (using whole threads) + for (unsigned k = threadIdx.x; k < (num_candidates < num_itopk ? 
num_candidates : num_itopk); + k += blockDim.x) { + const unsigned j = num_itopk - 1 - k; + const float itopk_key = itopk_distances[device::swizzling(j)]; + const float candidate_key = candidate_distances[device::swizzling(k)]; + if (itopk_key > candidate_key) { + itopk_distances[device::swizzling(j)] = candidate_key; + itopk_indices[device::swizzling(j)] = candidate_indices[device::swizzling(k)]; + if (j < num_itopk_div2) { + atomicMin(work_buf + 2, j); + } else { + atomicMin(work_buf + 1, j - num_itopk_div2); + } + } + } + __syncthreads(); + + // Merge 1st and 2nd half of itopk (using whole threads) + for (unsigned j = threadIdx.x; j < num_itopk_div2; j += blockDim.x) { + const unsigned k = j + num_itopk_div2; + float key_0 = itopk_distances[device::swizzling(j)]; + float key_1 = itopk_distances[device::swizzling(k)]; + if (key_0 > key_1) { + itopk_distances[device::swizzling(j)] = key_1; + itopk_distances[device::swizzling(k)] = key_0; + IdxT val_0 = itopk_indices[device::swizzling(j)]; + IdxT val_1 = itopk_indices[device::swizzling(k)]; + itopk_indices[device::swizzling(j)] = val_1; + itopk_indices[device::swizzling(k)] = val_0; + atomicMin(work_buf + 0, j); + } + } + if (threadIdx.x == blockDim.x - 1) { + if (work_buf[2] < num_itopk_div2) { work_buf[1] = work_buf[2]; } + } + __syncthreads(); + // Warp-0 merges 1st half of itopk, warp-1 does 2nd half. + if (warp_id < 2) { + // Load intermedidate itopk results + const uint32_t turning_point = work_buf[warp_id]; // turning_point <= num_itopk_div2 + for (unsigned i = 0; i < N; i++) { + unsigned k = num_itopk; + unsigned j = (N * lane_id) + i; + if (j < turning_point) { + k = j + (num_itopk_div2 * warp_id); + } else if (j >= (MAX_ITOPK / 2 - num_itopk_div2)) { + j -= (MAX_ITOPK / 2 - num_itopk_div2); + if ((turning_point <= j) && (j < num_itopk_div2)) { k = j + (num_itopk_div2 * warp_id); } + } + if (k < num_itopk) { + key[i] = itopk_distances[device::swizzling(k)]; + val[i] = itopk_indices[device::swizzling(k)]; + } else { + key[i] = utils::get_max_value(); + val[i] = utils::get_max_value(); + } + } + /* Warp Merge */ + bitonic::warp_merge(key, val, raft::warp_size()); + /* Store new itopk results */ + for (unsigned i = 0; i < N; i++) { + const unsigned j = (N * lane_id) + i; + if (j < num_itopk_div2) { + unsigned k = j + (num_itopk_div2 * warp_id); + itopk_distances[device::swizzling(k)] = key[i]; + itopk_indices[device::swizzling(k)] = val[i]; + } + } + } + } +} + +// Wrapper functions to avoid pre-inlining +RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_merge_wrapper_64_false( + float* itopk_distances, // [num_itopk] + uint32_t* itopk_indices, // [num_itopk] + const std::uint32_t num_itopk, + float* candidate_distances, // [num_candidates] + uint32_t* candidate_indices, // [num_candidates] + const std::uint32_t num_candidates, + std::uint32_t* work_buf, + const bool first) +{ + topk_by_bitonic_sort_and_merge<64, false, uint32_t>(itopk_distances, + itopk_indices, + num_itopk, + candidate_distances, + candidate_indices, + num_candidates, + work_buf, + first); +} + +RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_merge_wrapper_128_false( + float* itopk_distances, // [num_itopk] + uint32_t* itopk_indices, // [num_itopk] + const std::uint32_t num_itopk, + float* candidate_distances, // [num_candidates] + uint32_t* candidate_indices, // [num_candidates] + const std::uint32_t num_candidates, + std::uint32_t* work_buf, + const bool first) +{ + topk_by_bitonic_sort_and_merge<128, false, uint32_t>(itopk_distances, + 
itopk_indices, + num_itopk, + candidate_distances, + candidate_indices, + num_candidates, + work_buf, + first); +} + +RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_merge_wrapper_256_false( + float* itopk_distances, // [num_itopk] + uint32_t* itopk_indices, // [num_itopk] + const std::uint32_t num_itopk, + float* candidate_distances, // [num_candidates] + uint32_t* candidate_indices, // [num_candidates] + const std::uint32_t num_candidates, + std::uint32_t* work_buf, + const bool first) +{ + topk_by_bitonic_sort_and_merge<256, false, uint32_t>(itopk_distances, + itopk_indices, + num_itopk, + candidate_distances, + candidate_indices, + num_candidates, + work_buf, + first); +} + +// TopK by bitonic sort and merge (runtime version) +template +RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_merge( + float* itopk_distances, // [num_itopk] + IdxT* itopk_indices, // [num_itopk] + const std::uint32_t max_itopk, + const std::uint32_t num_itopk, + float* candidate_distances, // [num_candidates] + IdxT* candidate_indices, // [num_candidates] + const std::uint32_t max_candidates, + const std::uint32_t num_candidates, + std::uint32_t* work_buf, + const bool first) +{ + static_assert(std::is_same_v); + assert(max_itopk <= 512); + assert(max_candidates <= 256); + assert(!MULTI_WARPS || blockDim.x >= 64); + + // use a non-template wrapper function to avoid pre-inlining the topk_by_bitonic_sort_and_full + // function (vs post-inlining, this impacts register pressure) + if (max_candidates <= 64) { + topk_by_bitonic_sort_and_full_wrapper_64_false( + candidate_distances, candidate_indices, num_candidates, num_itopk); + } else if (max_candidates <= 128) { + topk_by_bitonic_sort_and_full_wrapper_128_false( + candidate_distances, candidate_indices, num_candidates, num_itopk); + } else { + topk_by_bitonic_sort_and_full_wrapper_256_false( + candidate_distances, candidate_indices, num_candidates, num_itopk); + } + + if constexpr (!MULTI_WARPS) { + assert(max_itopk <= 256); + // use a non-template wrapper function to avoid pre-inlining the topk_by_bitonic_sort_and_merge + // function (vs post-inlining, this impacts register pressure) + if (max_itopk <= 64) { + topk_by_bitonic_sort_and_merge_wrapper_64_false(itopk_distances, + itopk_indices, + num_itopk, + candidate_distances, + candidate_indices, + num_candidates, + work_buf, + first); + } else if (max_itopk <= 128) { + topk_by_bitonic_sort_and_merge_wrapper_128_false(itopk_distances, + itopk_indices, + num_itopk, + candidate_distances, + candidate_indices, + num_candidates, + work_buf, + first); + } else { + topk_by_bitonic_sort_and_merge_wrapper_256_false(itopk_distances, + itopk_indices, + num_itopk, + candidate_distances, + candidate_indices, + num_candidates, + work_buf, + first); + } + } else { + assert(max_itopk > 256); + topk_by_bitonic_sort_and_merge<512, MULTI_WARPS, uint32_t>(itopk_distances, + itopk_indices, + num_itopk, + candidate_distances, + candidate_indices, + num_candidates, + work_buf, + first); + } +} + +// This function move the invalid index element to the end of the itopk list. +// Require : array_length % 32 == 0 && The invalid entry is only one. 
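+// A single warp scans the list: raft::ballot() locates the lane holding the invalid entry, every
+// element after it is shifted one slot to the left, and lane 0 finally writes the invalid sentinel
+// (with a max distance) into the last slot.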
+template +RAFT_DEVICE_INLINE_FUNCTION void move_invalid_to_end_of_list(IdxT* const index_array, + float* const distance_array, + const std::uint32_t array_length) +{ + constexpr std::uint32_t warp_size = 32; + constexpr std::uint32_t invalid_index = utils::get_max_value(); + const std::uint32_t lane_id = threadIdx.x % warp_size; + + if (threadIdx.x >= warp_size) { return; } + + bool found_invalid = false; + if (array_length % warp_size == 0) { + for (std::uint32_t i = lane_id; i < array_length; i += warp_size) { + const auto index = index_array[i]; + const auto distance = distance_array[i]; + + if (found_invalid) { + index_array[i - 1] = index; + distance_array[i - 1] = distance; + } else { + // Check if the index is invalid + const auto I_found_invalid = (index == invalid_index); + const auto who_has_invalid = raft::ballot(I_found_invalid); + // if a value that is loaded by a smaller lane id thread, shift the array + if (who_has_invalid << (warp_size - lane_id)) { + index_array[i - 1] = index; + distance_array[i - 1] = distance; + } + + found_invalid = who_has_invalid; + } + } + } + if (lane_id == 0) { + index_array[array_length - 1] = invalid_index; + distance_array[array_length - 1] = utils::get_max_value(); + } +} + +template +RAFT_DEVICE_INLINE_FUNCTION void hashmap_restore(INDEX_T* const hashmap_ptr, + const size_t hashmap_bitlen, + const INDEX_T* itopk_indices, + const uint32_t itopk_size, + const uint32_t first_tid = 0) +{ + constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; + if (threadIdx.x < first_tid) return; + for (unsigned i = threadIdx.x - first_tid; i < itopk_size; i += blockDim.x - first_tid) { + auto key = itopk_indices[i] & ~index_msb_1_mask; // clear most significant bit + hashmap::insert(hashmap_ptr, hashmap_bitlen, key); + } +} + +} // namespace cuvs::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_jit.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_jit.cuh new file mode 100644 index 0000000000..282490559a --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_jit.cuh @@ -0,0 +1,645 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +// Device-only helpers - split out to avoid pulling host launcher code into JIT translation units +#include "search_single_cta_device_helpers.cuh" + +// neighbors_device_intrinsics / memory_ops come via search_single_cta_device_helpers.cuh +#include "../hashmap.hpp" +#include "../topk_by_radix.cuh" +#include "../utils.hpp" + +#include // For raft::shfl_xor +#include // For raft::round_up_safe +#include + +#include + +#include +#include + +#include // For assert() + +#ifdef _CLK_BREAKDOWN +#include // For printf() in debug mode +#endif + +// Include extern function declarations before namespace so they're available to kernel definitions +#include "cagra_bitset.cuh" +#include "extern_device_functions.cuh" +// Include shared JIT device functions +#include "device_common_jit.cuh" + +namespace cuvs::neighbors::cagra::detail::single_cta_search { + +// Sample filter extern function +// sample_filter is declared in extern_device_functions.cuh +using cuvs::neighbors::detail::sample_filter; + +// JIT versions of compute_distance_to_random_nodes and compute_distance_to_child_nodes +// are now shared in device_common_jit.cuh - use fully qualified names +using cuvs::neighbors::cagra::detail::device::compute_distance_to_child_nodes_jit; +using cuvs::neighbors::cagra::detail::device::compute_distance_to_random_nodes_jit; + +// JIT search_core - setup_workspace/compute_distance via function pointers +template +RAFT_DEVICE_INLINE_FUNCTION void search_core( + uintptr_t result_indices_ptr, + DistanceT* const result_distances_ptr, + const std::uint32_t top_k, + const DataT* const queries_ptr, + const IndexT* const knn_graph, + const std::uint32_t graph_degree, + const SourceIndexT* source_indices_ptr, + const unsigned num_distilation, + const uint64_t rand_xor_mask, + const IndexT* seed_ptr, + const uint32_t num_seeds, + IndexT* const visited_hashmap_ptr, + const std::uint32_t max_candidates, + const std::uint32_t max_itopk, + const std::uint32_t internal_topk, + const std::uint32_t search_width, + const std::uint32_t min_iteration, + const std::uint32_t max_iteration, + std::uint32_t* const num_executed_iterations, + const std::uint32_t hash_bitlen, + const std::uint32_t small_hash_bitlen, + const std::uint32_t small_hash_reset_interval, + const std::uint32_t query_id, + const std::uint32_t query_id_offset, // Offset to add to query_id when calling filter + const dataset_descriptor_base_t* dataset_desc, + cagra_bitset bitset, + const IndexT graph_size = 0) // Original number of bits +{ + using LOAD_T = device::LOAD_128BIT_T; + + auto to_source_index = [source_indices_ptr](IndexT x) { + return source_indices_ptr == nullptr ? 
static_cast(x) : source_indices_ptr[x]; + }; + +#ifdef _CLK_BREAKDOWN + std::uint64_t clk_init = 0; + std::uint64_t clk_compute_1st_distance = 0; + std::uint64_t clk_topk = 0; + std::uint64_t clk_reset_hash = 0; + std::uint64_t clk_pickup_parents = 0; + std::uint64_t clk_restore_hash = 0; + std::uint64_t clk_compute_distance = 0; + std::uint64_t clk_start; +#define _CLK_START() clk_start = clock64() +#define _CLK_REC(V) V += clock64() - clk_start; +#else +#define _CLK_START() +#define _CLK_REC(V) +#endif + _CLK_START(); + + extern __shared__ uint8_t smem[]; + + // Layout of result_buffer + const auto result_buffer_size = internal_topk + (search_width * graph_degree); + const auto result_buffer_size_32 = raft::round_up_safe(result_buffer_size, 32); + const auto small_hash_size = hashmap::get_size(small_hash_bitlen); + + // Get dim and smem_ws_size directly from base descriptor + uint32_t dim = dataset_desc->args.dim; + uint32_t smem_ws_size_in_bytes = dataset_desc->smem_ws_size_in_bytes(); + + auto smem_desc = + setup_workspace(dataset_desc, smem, queries_ptr, query_id); + + auto* __restrict__ result_indices_buffer = + reinterpret_cast(smem + smem_ws_size_in_bytes); + auto* __restrict__ result_distances_buffer = + reinterpret_cast(result_indices_buffer + result_buffer_size_32); + auto* __restrict__ visited_hash_buffer = + reinterpret_cast(result_distances_buffer + result_buffer_size_32); + auto* __restrict__ parent_list_buffer = + reinterpret_cast(visited_hash_buffer + small_hash_size); + auto* __restrict__ topk_ws = reinterpret_cast(parent_list_buffer + search_width); + auto* terminate_flag = reinterpret_cast(topk_ws + 3); + auto* __restrict__ smem_work_ptr = reinterpret_cast(terminate_flag + 1); + + // A flag for filtering. + auto filter_flag = terminate_flag; + + if (threadIdx.x == 0) { + terminate_flag[0] = 0; + topk_ws[0] = ~0u; + } + + // Init hashmap + IndexT* local_visited_hashmap_ptr; + if (small_hash_bitlen) { + local_visited_hashmap_ptr = visited_hash_buffer; + } else { + local_visited_hashmap_ptr = visited_hashmap_ptr + (hashmap::get_size(hash_bitlen) * blockIdx.y); + } + hashmap::init(local_visited_hashmap_ptr, hash_bitlen, 0); + __syncthreads(); + _CLK_REC(clk_init); + + // compute distance to randomly selecting nodes using JIT version + _CLK_START(); + const IndexT* const local_seed_ptr = seed_ptr ? seed_ptr + (num_seeds * query_id) : nullptr; + // Get dataset_size directly from base descriptor + IndexT dataset_size = smem_desc->size; + compute_distance_to_random_nodes_jit(result_indices_buffer, + result_distances_buffer, + smem_desc, + result_buffer_size, + num_distilation, + rand_xor_mask, + local_seed_ptr, + num_seeds, + local_visited_hashmap_ptr, + hash_bitlen, + (IndexT*)nullptr, + 0, + 0, + 1, + graph_size); + __syncthreads(); + _CLK_REC(clk_compute_1st_distance); + + std::uint32_t iter = 0; + while (1) { + // sort + if constexpr (TOPK_BY_BITONIC_SORT) { + assert(blockDim.x >= 64); + const bool bitonic_sort_and_full_multi_warps = (max_candidates > 128) ? true : false; + + // reset small-hash table. 
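+      // hash_start_tid below is chosen so that, when the block has spare warps, the warp(s) needed
+      // by the bitonic top-k (one warp, or two when max_candidates > 128 or the merge is
+      // multi-warp) are excluded from the reinitialization.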
+ if ((iter + 1) % small_hash_reset_interval == 0) { + _CLK_START(); + unsigned hash_start_tid; + if (blockDim.x == 32) { + hash_start_tid = 0; + } else if (blockDim.x == 64) { + if (bitonic_sort_and_full_multi_warps || BITONIC_SORT_AND_MERGE_MULTI_WARPS) { + hash_start_tid = 0; + } else { + hash_start_tid = 32; + } + } else { + if (bitonic_sort_and_full_multi_warps || BITONIC_SORT_AND_MERGE_MULTI_WARPS) { + hash_start_tid = 64; + } else { + hash_start_tid = 32; + } + } + hashmap::init(local_visited_hashmap_ptr, hash_bitlen, hash_start_tid); + _CLK_REC(clk_reset_hash); + } + + // topk with bitonic sort + _CLK_START(); + // For JIT version, we always check filter_flag at runtime since sample_filter is extern + if (*filter_flag != 0) { + // Move the filtered out index to the end of the itopk list + for (unsigned i = 0; i < search_width; i++) { + move_invalid_to_end_of_list( + result_indices_buffer, result_distances_buffer, internal_topk); + } + if (threadIdx.x == 0) { *terminate_flag = 0; } + } + topk_by_bitonic_sort_and_merge( + result_distances_buffer, + result_indices_buffer, + max_itopk, + internal_topk, + result_distances_buffer + internal_topk, + result_indices_buffer + internal_topk, + max_candidates, + search_width * graph_degree, + topk_ws, + (iter == 0)); + __syncthreads(); + _CLK_REC(clk_topk); + } else { + _CLK_START(); + // topk with radix block sort + topk_by_radix_sort{}(max_itopk, + internal_topk, + result_buffer_size, + reinterpret_cast(result_distances_buffer), + result_indices_buffer, + reinterpret_cast(result_distances_buffer), + result_indices_buffer, + nullptr, + topk_ws, + true, + smem_work_ptr); + _CLK_REC(clk_topk); + + // reset small-hash table + if ((iter + 1) % small_hash_reset_interval == 0) { + _CLK_START(); + hashmap::init(local_visited_hashmap_ptr, hash_bitlen); + _CLK_REC(clk_reset_hash); + } + } + __syncthreads(); + + if (iter + 1 == max_iteration) { break; } + + // pick up next parents + if (threadIdx.x < 32) { + _CLK_START(); + pickup_next_parents( + terminate_flag, parent_list_buffer, result_indices_buffer, internal_topk, search_width); + _CLK_REC(clk_pickup_parents); + } + + // restore small-hash table by putting internal-topk indices in it + if ((iter + 1) % small_hash_reset_interval == 0) { + const unsigned first_tid = ((blockDim.x <= 32) ? 
0 : 32); + _CLK_START(); + hashmap_restore( + local_visited_hashmap_ptr, hash_bitlen, result_indices_buffer, internal_topk, first_tid); + _CLK_REC(clk_restore_hash); + } + __syncthreads(); + + if (*terminate_flag && iter >= min_iteration) { break; } + + __syncthreads(); + // compute the norms between child nodes and query node using JIT version + _CLK_START(); + compute_distance_to_child_nodes_jit( + result_indices_buffer + internal_topk, + result_distances_buffer + internal_topk, + smem_desc, + knn_graph, + graph_degree, + local_visited_hashmap_ptr, + hash_bitlen, + (IndexT*)nullptr, + 0u, + parent_list_buffer, + result_indices_buffer, + search_width); + // Critical: __syncthreads() must be reached by ALL threads + // If any thread is stuck in compute_distance_to_child_nodes_jit, this will hang + __syncthreads(); + _CLK_REC(clk_compute_distance); + + // Filtering - use extern sample_filter function + if (threadIdx.x == 0) { *filter_flag = 0; } + __syncthreads(); + + constexpr IndexT index_msb_1_mask = utils::gen_index_msb_1_mask::value; + const IndexT invalid_index = utils::get_max_value(); + + for (unsigned p = threadIdx.x; p < search_width; p += blockDim.x) { + if (parent_list_buffer[p] != invalid_index) { + const auto parent_id = result_indices_buffer[parent_list_buffer[p]] & ~index_msb_1_mask; + if (!sample_filter(query_id + query_id_offset, + to_source_index(parent_id), + bitset.bitset_ptr != nullptr ? &bitset : nullptr)) { + result_distances_buffer[parent_list_buffer[p]] = utils::get_max_value(); + result_indices_buffer[parent_list_buffer[p]] = invalid_index; + *filter_flag = 1; + } + } + } + __syncthreads(); + + iter++; + } + + // Post process for filtering - use extern sample_filter function + constexpr IndexT index_msb_1_mask = utils::gen_index_msb_1_mask::value; + const IndexT invalid_index = utils::get_max_value(); + + for (unsigned i = threadIdx.x; i < internal_topk + search_width * graph_degree; i += blockDim.x) { + const auto node_id = result_indices_buffer[i] & ~index_msb_1_mask; + if (node_id != (invalid_index & ~index_msb_1_mask) && + !sample_filter(query_id + query_id_offset, + to_source_index(node_id), + bitset.bitset_ptr != nullptr ? &bitset : nullptr)) { + result_distances_buffer[i] = utils::get_max_value(); + result_indices_buffer[i] = invalid_index; + } + } + + __syncthreads(); + // Move invalid index items to the end of the buffer without sorting the entire buffer + using scan_op_t = cub::WarpScan; + auto& temp_storage = *reinterpret_cast(smem_work_ptr); + + constexpr std::uint32_t warp_size = 32; + if (threadIdx.x < warp_size) { + std::uint32_t num_found_valid = 0; + for (std::uint32_t buffer_offset = 0; buffer_offset < internal_topk; + buffer_offset += warp_size) { + const auto src_position = buffer_offset + threadIdx.x; + const std::uint32_t is_valid_index = + (result_indices_buffer[src_position] & (~index_msb_1_mask)) == invalid_index ? 
0 : 1; + std::uint32_t new_position; + scan_op_t(temp_storage).InclusiveSum(is_valid_index, new_position); + if (is_valid_index) { + const auto dst_position = num_found_valid + (new_position - 1); + result_indices_buffer[dst_position] = result_indices_buffer[src_position]; + result_distances_buffer[dst_position] = result_distances_buffer[src_position]; + } + + num_found_valid += new_position; + for (std::uint32_t offset = (warp_size >> 1); offset > 0; offset >>= 1) { + const auto v = raft::shfl_xor(num_found_valid, offset); + if ((threadIdx.x & offset) == 0) { num_found_valid = v; } + } + + if (num_found_valid >= top_k) { break; } + } + + if (num_found_valid < top_k) { + for (std::uint32_t i = num_found_valid + threadIdx.x; i < internal_topk; i += warp_size) { + result_indices_buffer[i] = invalid_index; + result_distances_buffer[i] = utils::get_max_value(); + } + } + } + + // If the sufficient number of valid indexes are not in the internal topk, pick up from the + // candidate list. + if (top_k > internal_topk || result_indices_buffer[top_k - 1] == invalid_index) { + __syncthreads(); + topk_by_bitonic_sort_and_merge( + result_distances_buffer, + result_indices_buffer, + max_itopk, + internal_topk, + result_distances_buffer + internal_topk, + result_indices_buffer + internal_topk, + max_candidates, + search_width * graph_degree, + topk_ws, + (iter == 0)); + } + __syncthreads(); + + // NB: The indices pointer is tagged with its element size. + const uint32_t index_element_tag = result_indices_ptr & 0x3; + result_indices_ptr ^= index_element_tag; + auto write_indices = + index_element_tag == 3 + ? [](uintptr_t ptr, + uint32_t i, + SourceIndexT x) { reinterpret_cast(ptr)[i] = static_cast(x); } + : index_element_tag == 2 + ? [](uintptr_t ptr, + uint32_t i, + SourceIndexT x) { reinterpret_cast(ptr)[i] = static_cast(x); } + : index_element_tag == 1 + ? 
[](uintptr_t ptr, + uint32_t i, + SourceIndexT x) { reinterpret_cast(ptr)[i] = static_cast(x); } + : [](uintptr_t ptr, uint32_t i, SourceIndexT x) { + reinterpret_cast(ptr)[i] = static_cast(x); + }; + for (std::uint32_t i = threadIdx.x; i < top_k; i += blockDim.x) { + unsigned j = i + (top_k * query_id); + unsigned ii = i; + if constexpr (TOPK_BY_BITONIC_SORT) { ii = device::swizzling(i); } + if (result_distances_ptr != nullptr) { result_distances_ptr[j] = result_distances_buffer[ii]; } + + auto internal_index = + result_indices_buffer[ii] & ~index_msb_1_mask; // clear most significant bit + auto source_index = to_source_index(internal_index); + write_indices(result_indices_ptr, j, source_index); + } + if (threadIdx.x == 0 && num_executed_iterations != nullptr) { + num_executed_iterations[query_id] = iter + 1; + } +#ifdef _CLK_BREAKDOWN + if ((threadIdx.x == 0 || threadIdx.x == blockDim.x - 1) && ((query_id * 3) % gridDim.y < 3)) { + printf( + "%s:%d " + "query, %d, thread, %d" + ", init, %lu" + ", 1st_distance, %lu" + ", topk, %lu" + ", reset_hash, %lu" + ", pickup_parents, %lu" + ", restore_hash, %lu" + ", distance, %lu" + "\n", + __FILE__, + __LINE__, + query_id, + threadIdx.x, + clk_init, + clk_compute_1st_distance, + clk_topk, + clk_reset_hash, + clk_pickup_parents, + clk_restore_hash, + clk_compute_distance); + } +#endif +} + +// JIT device implementation - called from extern "C" __global__ entry in generated .cu +template +__device__ void search_kernel_jit( + uintptr_t result_indices_ptr, + DistanceT* const result_distances_ptr, + const std::uint32_t top_k, + const DataT* const queries_ptr, + const IndexT* const knn_graph, + const std::uint32_t graph_degree, + const SourceIndexT* source_indices_ptr, + const unsigned num_distilation, + const uint64_t rand_xor_mask, + const IndexT* seed_ptr, + const uint32_t num_seeds, + IndexT* const visited_hashmap_ptr, + const std::uint32_t max_candidates, + const std::uint32_t max_itopk, + const std::uint32_t internal_topk, + const std::uint32_t search_width, + const std::uint32_t min_iteration, + const std::uint32_t max_iteration, + std::uint32_t* const num_executed_iterations, + const std::uint32_t hash_bitlen, + const std::uint32_t small_hash_bitlen, + const std::uint32_t small_hash_reset_interval, + const std::uint32_t query_id_offset, // Offset to add to query_id when calling filter + const dataset_descriptor_base_t* dataset_desc, + const IndexT graph_size, + cagra_bitset bitset) +{ + const auto query_id = blockIdx.y; + search_core(result_indices_ptr, + result_distances_ptr, + top_k, + queries_ptr, + knn_graph, + graph_degree, + source_indices_ptr, + num_distilation, + rand_xor_mask, + seed_ptr, + num_seeds, + visited_hashmap_ptr, + max_candidates, + max_itopk, + internal_topk, + search_width, + min_iteration, + max_iteration, + num_executed_iterations, + hash_bitlen, + small_hash_bitlen, + small_hash_reset_interval, + query_id, + query_id_offset, + dataset_desc, + bitset, + graph_size); +} + +// JIT persistent device implementation - called from extern "C" __global__ entry in generated .cu +template +__device__ void search_single_cta_p_impl( + worker_handle_t* worker_handles, + job_desc_t>* job_descriptors, + uint32_t* completion_counters, + const IndexT* const knn_graph, // [dataset_size, graph_degree] + const std::uint32_t graph_degree, + const SourceIndexT* source_indices_ptr, + const unsigned num_distilation, + const uint64_t rand_xor_mask, + const IndexT* seed_ptr, // [num_queries, num_seeds] + const uint32_t num_seeds, + IndexT* const 
visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] + const std::uint32_t max_candidates, + const std::uint32_t max_itopk, + const std::uint32_t internal_topk, + const std::uint32_t search_width, + const std::uint32_t min_iteration, + const std::uint32_t max_iteration, + std::uint32_t* const num_executed_iterations, // [num_queries] + const std::uint32_t hash_bitlen, + const std::uint32_t small_hash_bitlen, + const std::uint32_t small_hash_reset_interval, + const std::uint32_t query_id_offset, // Offset to add to query_id when calling filter + const dataset_descriptor_base_t* dataset_desc, + cagra_bitset bitset) +{ + using job_desc_type = job_desc_t>; + __shared__ typename job_desc_type::input_t job_descriptor; + __shared__ worker_handle_t::data_t worker_data; + + auto& worker_handle = worker_handles[blockIdx.y].data; + uint32_t job_ix; + + while (true) { + // wait the writing phase + if (threadIdx.x == 0) { + worker_handle_t::data_t worker_data_local; + do { + worker_data_local = worker_handle.load(cuda::memory_order_relaxed); + } while (worker_data_local.handle == kWaitForWork); + if (worker_data_local.handle != kNoMoreWork) { + worker_handle.store({kWaitForWork}, cuda::memory_order_relaxed); + } + job_ix = worker_data_local.value.desc_id; + cuda::atomic_thread_fence(cuda::memory_order_acquire, cuda::thread_scope_system); + worker_data = worker_data_local; + } + if (threadIdx.x < raft::WarpSize) { + // Sync one warp and copy descriptor data + static_assert(job_desc_type::kBlobSize <= raft::WarpSize); + constexpr uint32_t kMaxJobsNum = 8192; + job_ix = raft::shfl(job_ix, 0); + if (threadIdx.x < job_desc_type::kBlobSize && job_ix < kMaxJobsNum) { + job_descriptor.blob[threadIdx.x] = job_descriptors[job_ix].input.blob[threadIdx.x]; + } + } + __syncthreads(); + if (worker_data.handle == kNoMoreWork) { break; } + + // reading phase + auto result_indices_ptr = job_descriptor.value.result_indices_ptr; + auto* result_distances_ptr = job_descriptor.value.result_distances_ptr; + auto* queries_ptr = job_descriptor.value.queries_ptr; + auto top_k = job_descriptor.value.top_k; + auto n_queries = job_descriptor.value.n_queries; + auto query_id = worker_data.value.query_id; + + // work phase - use JIT search_core + search_core(result_indices_ptr, + result_distances_ptr, + top_k, + queries_ptr, + knn_graph, + graph_degree, + source_indices_ptr, + num_distilation, + rand_xor_mask, + seed_ptr, + num_seeds, + visited_hashmap_ptr, + max_candidates, + max_itopk, + internal_topk, + search_width, + min_iteration, + max_iteration, + num_executed_iterations, + hash_bitlen, + small_hash_bitlen, + small_hash_reset_interval, + query_id, + query_id_offset, + dataset_desc, + bitset); + + // make sure all writes are visible even for the host + // (e.g. 
when result buffers are in pinned memory) + cuda::atomic_thread_fence(cuda::memory_order_release, cuda::thread_scope_system); + + // arrive to mark the end of the work phase + __syncthreads(); + if (threadIdx.x == 0) { + auto completed_count = atomicInc(completion_counters + job_ix, n_queries - 1) + 1; + if (completed_count >= n_queries) { + job_descriptors[job_ix].completion_flag.store(true, cuda::memory_order_relaxed); + } + } + } +} + +} // namespace cuvs::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_kernel.cu.in b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_kernel.cu.in new file mode 100644 index 0000000000..26b363a90a --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_kernel.cu.in @@ -0,0 +1,91 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include + +namespace { + +constexpr bool k_topk_by_bitonic_sort = @topk_by_bitonic_sort@; +constexpr bool k_bitonic_sort_and_merge_multi_warps = @bitonic_sort_and_merge_multi_warps@; + +using data_t = @data_type@; +using index_t = @index_type@; +using distance_t = @distance_type@; +using source_index_t = @source_index_type@; +using dataset_desc_base = + cuvs::neighbors::cagra::detail::dataset_descriptor_base_t; +using cagra_bitset_t = cuvs::neighbors::cagra::detail::cagra_bitset; + +} // namespace + +namespace cuvs::neighbors::cagra::detail { + +extern "C" __global__ __launch_bounds__(1024, 1) void search_single_cta( + uintptr_t topk_indices_ptr, + distance_t* const topk_distances_ptr, + const std::uint32_t topk, + const data_t* const queries_ptr, + const index_t* const knn_graph, + const std::uint32_t graph_degree, + const source_index_t* source_indices_ptr, + const unsigned num_distilation, + const uint64_t rand_xor_mask, + const index_t* seed_ptr, + const uint32_t num_seeds, + index_t* const visited_hashmap_ptr, + const std::uint32_t max_candidates, + const std::uint32_t max_itopk, + const std::uint32_t internal_topk, + const std::uint32_t search_width, + const std::uint32_t min_iteration, + const std::uint32_t max_iteration, + std::uint32_t* const num_executed_iterations, + const std::uint32_t hash_bitlen, + const std::uint32_t small_hash_bitlen, + const std::uint32_t small_hash_reset_interval, + const std::uint32_t query_id_offset, + const dataset_desc_base* dataset_desc, + const index_t graph_size, + cagra_bitset_t bitset) +{ + single_cta_search::search_kernel_jit(topk_indices_ptr, + topk_distances_ptr, + topk, + queries_ptr, + knn_graph, + graph_degree, + source_indices_ptr, + num_distilation, + rand_xor_mask, + seed_ptr, + num_seeds, + visited_hashmap_ptr, + max_candidates, + max_itopk, + internal_topk, + search_width, + min_iteration, + max_iteration, + num_executed_iterations, + hash_bitlen, + small_hash_bitlen, + small_hash_reset_interval, + query_id_offset, + dataset_desc, + graph_size, + bitset); +} + +static_assert( + std::is_same_v>); + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_matrix.json b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_matrix.json new file mode 100644 index 0000000000..ba1334f11d --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_matrix.json @@ -0,0 +1,21 @@ +[ + { + "_data": [ + {"data_type": "float", "data_abbrev": "f"}, + 
{"data_type": "__half", "data_abbrev": "h"}, + {"data_type": "uint8_t", "data_abbrev": "u8"}, + {"data_type": "int8_t", "data_abbrev": "i8"} + ], + "_source_index": [{"source_index_type": "uint32_t", "source_index_abbrev": "u32"}], + "_index": [{"index_type": "uint32_t", "index_abbrev": "u32"}], + "_distance": [{"distance_type": "float", "distance_abbrev": "f"}], + "_topk_by_bitonic": [ + {"topk_by_bitonic_sort": "true", "topk_by_bitonic_sort_str": "topk_by_bitonic_sort"}, + {"topk_by_bitonic_sort": "false", "topk_by_bitonic_sort_str": "no_topk_by_bitonic_sort"} + ], + "_bitonic_sort_and_merge_multi_warps": [ + {"bitonic_sort_and_merge_multi_warps": "true", "bitonic_sort_and_merge_multi_warps_str": "bitonic_sort_and_merge_multi_warps"}, + {"bitonic_sort_and_merge_multi_warps": "false", "bitonic_sort_and_merge_multi_warps_str": "no_bitonic_sort_and_merge_multi_warps"} + ] + } +] diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_p_kernel.cu.in b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_p_kernel.cu.in new file mode 100644 index 0000000000..3667e16f70 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_p_kernel.cu.in @@ -0,0 +1,91 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include + +namespace scta_jit = cuvs::neighbors::cagra::detail::single_cta_search; + +namespace { + +constexpr bool k_topk_by_bitonic_sort = @topk_by_bitonic_sort@; +constexpr bool k_bitonic_sort_and_merge_multi_warps = @bitonic_sort_and_merge_multi_warps@; + +using data_t = @data_type@; +using index_t = @index_type@; +using distance_t = @distance_type@; +using source_index_t = @source_index_type@; +using dataset_desc_base = + cuvs::neighbors::cagra::detail::dataset_descriptor_base_t; +using job_descriptor_batch = + scta_jit::job_desc_t>; +using cagra_bitset_t = cuvs::neighbors::cagra::detail::cagra_bitset; + +} // namespace + +namespace cuvs::neighbors::cagra::detail::single_cta_search { + +extern "C" __global__ __launch_bounds__(1024, 1) void search_single_cta_p( + worker_handle_t* worker_handles, + job_descriptor_batch* job_descriptors, + uint32_t* completion_counters, + const index_t* const knn_graph, + const std::uint32_t graph_degree, + const source_index_t* source_indices_ptr, + const unsigned num_distilation, + const uint64_t rand_xor_mask, + const index_t* seed_ptr, + const uint32_t num_seeds, + index_t* const visited_hashmap_ptr, + const std::uint32_t max_candidates, + const std::uint32_t max_itopk, + const std::uint32_t internal_topk, + const std::uint32_t search_width, + const std::uint32_t min_iteration, + const std::uint32_t max_iteration, + std::uint32_t* const num_executed_iterations, + const std::uint32_t hash_bitlen, + const std::uint32_t small_hash_bitlen, + const std::uint32_t small_hash_reset_interval, + const std::uint32_t query_id_offset, + const dataset_desc_base* dataset_desc, + cagra_bitset_t bitset) +{ + search_single_cta_p_impl(worker_handles, + job_descriptors, + completion_counters, + knn_graph, + graph_degree, + source_indices_ptr, + num_distilation, + rand_xor_mask, + seed_ptr, + num_seeds, + visited_hashmap_ptr, + max_candidates, + max_itopk, + internal_topk, + search_width, + min_iteration, + max_iteration, + num_executed_iterations, + hash_bitlen, + small_hash_bitlen, + small_hash_reset_interval, + query_id_offset, + dataset_desc, + bitset); +} + +static_assert( + std::is_same_v>); + 
+} // namespace cuvs::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_p_matrix.json b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_p_matrix.json new file mode 100644 index 0000000000..ba1334f11d --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_p_matrix.json @@ -0,0 +1,21 @@ +[ + { + "_data": [ + {"data_type": "float", "data_abbrev": "f"}, + {"data_type": "__half", "data_abbrev": "h"}, + {"data_type": "uint8_t", "data_abbrev": "u8"}, + {"data_type": "int8_t", "data_abbrev": "i8"} + ], + "_source_index": [{"source_index_type": "uint32_t", "source_index_abbrev": "u32"}], + "_index": [{"index_type": "uint32_t", "index_abbrev": "u32"}], + "_distance": [{"distance_type": "float", "distance_abbrev": "f"}], + "_topk_by_bitonic": [ + {"topk_by_bitonic_sort": "true", "topk_by_bitonic_sort_str": "topk_by_bitonic_sort"}, + {"topk_by_bitonic_sort": "false", "topk_by_bitonic_sort_str": "no_topk_by_bitonic_sort"} + ], + "_bitonic_sort_and_merge_multi_warps": [ + {"bitonic_sort_and_merge_multi_warps": "true", "bitonic_sort_and_merge_multi_warps_str": "bitonic_sort_and_merge_multi_warps"}, + {"bitonic_sort_and_merge_multi_warps": "false", "bitonic_sort_and_merge_multi_warps_str": "no_bitonic_sort_and_merge_multi_warps"} + ] + } +] diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_planner.hpp b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_planner.hpp new file mode 100644 index 0000000000..a1cbe39dfe --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_planner.hpp @@ -0,0 +1,108 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "cagra_planner_base.hpp" +#include +#include +#include + +namespace cuvs::neighbors::cagra::detail::single_cta_search { + +template +struct CagraSingleCtaSearchPlanner + : CagraPlannerBase { + static inline LauncherJitCache launcher_jit_cache{}; + + CagraSingleCtaSearchPlanner(cuvs::distance::DistanceType /*metric*/, + bool /*topk_by_bitonic_sort*/, + bool /*bitonic_sort_and_merge_multi_warps*/, + uint32_t /*team_size*/, + uint32_t /*dataset_block_dim*/, + bool /*is_vpq*/, + uint32_t /*pq_bits*/, + uint32_t /*pq_len*/, + bool persistent = false) + : CagraPlannerBase( + persistent ? 
"search_single_cta_p" : "search_single_cta", launcher_jit_cache) + { + } + + void add_search_kernel_fragment(bool topk_by_bitonic_sort, + bool bitonic_sort_and_merge_multi_warps, + bool persistent) + { + if (persistent) { + if (topk_by_bitonic_sort && bitonic_sort_and_merge_multi_warps) { + this->template add_static_fragment>(); + } else if (topk_by_bitonic_sort && !bitonic_sort_and_merge_multi_warps) { + this->template add_static_fragment>(); + } else if (!topk_by_bitonic_sort && bitonic_sort_and_merge_multi_warps) { + this->template add_static_fragment>(); + } else { + this->template add_static_fragment>(); + } + } else { + if (topk_by_bitonic_sort && bitonic_sort_and_merge_multi_warps) { + this->template add_static_fragment>(); + } else if (topk_by_bitonic_sort && !bitonic_sort_and_merge_multi_warps) { + this->template add_static_fragment>(); + } else if (!topk_by_bitonic_sort && bitonic_sort_and_merge_multi_warps) { + this->template add_static_fragment>(); + } else { + this->template add_static_fragment>(); + } + } + } +}; + +} // namespace cuvs::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_impl.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_impl.cuh new file mode 100644 index 0000000000..8cdd7febd5 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_impl.cuh @@ -0,0 +1,192 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "../../neighbors_device_intrinsics.cuh" +#include "../compute_distance_standard-impl.cuh" +#include "../compute_distance_standard.hpp" +#include "../compute_distance_vpq-impl.cuh" +#include "../compute_distance_vpq.hpp" +#include "../device_memory_ops.hpp" + +#include + +namespace cuvs::neighbors::cagra::detail { + +template +_RAFT_DEVICE __noinline__ auto setup_workspace_standard_impl( + const DescriptorT* that, + void* smem_ptr, + const typename DescriptorT::DATA_T* queries_ptr, + uint32_t query_id) -> const DescriptorT* +{ + using DATA_T = typename DescriptorT::DATA_T; + using LOAD_T = typename DescriptorT::LOAD_T; + using QUERY_T = typename DescriptorT::QUERY_T; + using word_type = uint32_t; + constexpr auto kTeamSize = DescriptorT::kTeamSize; + constexpr auto kDatasetBlockDim = DescriptorT::kDatasetBlockDim; + auto* r = reinterpret_cast(smem_ptr); + auto* buf = reinterpret_cast(r + 1); + if (r != that) { + constexpr uint32_t kCount = sizeof(DescriptorT) / sizeof(word_type); + using blob_type = word_type[kCount]; + auto& src = reinterpret_cast(*that); + auto& dst = reinterpret_cast(*r); + for (uint32_t i = threadIdx.x; i < kCount; i += blockDim.x) { + dst[i] = src[i]; + } + const auto smem_ptr_offset = + reinterpret_cast(&(r->args.smem_ws_ptr)) - reinterpret_cast(r); + if (threadIdx.x == uint32_t(smem_ptr_offset / sizeof(word_type))) { + r->args.smem_ws_ptr = uint32_t(__cvta_generic_to_shared(buf)); + } + __syncthreads(); + } + + uint32_t dim = r->args.dim; + auto buf_len = raft::round_up_safe(dim, kDatasetBlockDim); + constexpr auto vlen = device::get_vlen(); + queries_ptr += dim * query_id; + for (unsigned i = threadIdx.x; i < buf_len; i += blockDim.x) { + unsigned j = device::swizzling(i); + if (i < dim) { + buf[j] = cuvs::spatial::knn::detail::utils::mapping{}(queries_ptr[i]); + } else { + buf[j] = 0; + } + } + return r; +} + +template +RAFT_DEVICE_INLINE_FUNCTION constexpr auto transpose(T x) -> T +{ + auto i = x % Block; + auto j 
= x / Block; + auto k = i % Stride; + auto l = i / Stride; + return j * Block + k * (Block / Stride) + l; +} + +template +_RAFT_DEVICE __noinline__ auto setup_workspace_vpq_impl( + const DescriptorT* that, + void* smem_ptr, + const typename DescriptorT::DATA_T* queries_ptr, + uint32_t query_id) -> const DescriptorT* +{ + using QUERY_T = typename DescriptorT::QUERY_T; + using CODE_BOOK_T = typename DescriptorT::CODE_BOOK_T; + using word_type = uint32_t; + constexpr auto kDatasetBlockDim = DescriptorT::kDatasetBlockDim; + constexpr auto PQ_BITS = DescriptorT::kPqBits; + constexpr auto PQ_LEN = DescriptorT::kPqLen; + + auto* r = reinterpret_cast(smem_ptr); + + if (r != that) { + constexpr uint32_t kCount = sizeof(DescriptorT) / sizeof(word_type); + using blob_type = word_type[kCount]; + auto& src = reinterpret_cast(*that); + auto& dst = reinterpret_cast(*r); + for (uint32_t i = threadIdx.x; i < kCount; i += blockDim.x) { + dst[i] = src[i]; + } + + auto codebook_buf = uint32_t(__cvta_generic_to_shared(r + 1)); + const auto smem_ptr_offset = + reinterpret_cast(&(r->args.smem_ws_ptr)) - reinterpret_cast(r); + if (threadIdx.x == uint32_t(smem_ptr_offset / sizeof(word_type))) { + r->args.smem_ws_ptr = codebook_buf; + } + __syncthreads(); + + for (unsigned i = threadIdx.x * 2; i < (1 << PQ_BITS) * PQ_LEN; i += blockDim.x * 2) { + half2 buf2; + buf2.x = r->pq_code_book_ptr()[i]; + buf2.y = r->pq_code_book_ptr()[i + 1]; + + constexpr auto num_elements_per_bank = 4 / utils::size_of(); + constexpr auto num_banks_per_subspace = PQ_LEN / num_elements_per_bank; + const auto j = i / num_elements_per_bank; + const auto smem_index = + (j / num_banks_per_subspace) + (j % num_banks_per_subspace) * (1 << PQ_BITS); + + device::sts(codebook_buf + smem_index * sizeof(half2), buf2); + } + } + + uint32_t dim = r->args.dim; + queries_ptr += dim * query_id; + + constexpr cuvs::spatial::knn::detail::utils::mapping mapping{}; + auto smem_query_ptr = + reinterpret_cast(reinterpret_cast(smem_ptr) + sizeof(DescriptorT) + + DescriptorT::kSMemCodeBookSizeInBytes); + for (unsigned i = threadIdx.x * 2; i < dim; i += blockDim.x * 2) { + half2 buf2{0, 0}; + if (i < dim) { buf2.x = mapping(queries_ptr[i]); } + if (i + 1 < dim) { buf2.y = mapping(queries_ptr[i + 1]); } + if constexpr ((PQ_BITS == 8) && (PQ_LEN % 2 == 0)) { + constexpr uint32_t vlen = 4; // **** DO NOT CHANGE **** + constexpr auto kStride = vlen * PQ_LEN / 2; + reinterpret_cast(smem_query_ptr)[transpose(i / 2)] = + buf2; + } else { + (reinterpret_cast(smem_query_ptr + i))[0] = buf2; + } + } + + return r; +} + +template +__device__ const dataset_descriptor_base_t* setup_workspace_impl( + const dataset_descriptor_base_t* desc_ptr, + void* smem, + const DataT* queries, + uint32_t query_id) +{ + if constexpr (PQ_BITS == 0 && PQ_LEN == 0 && std::is_same_v) { + using desc_t = + standard_dataset_descriptor_t; + const desc_t* desc = static_cast(desc_ptr); + + const desc_t* result = setup_workspace_standard_impl(desc, smem, queries, query_id); + return static_cast*>(result); + } else if constexpr (PQ_BITS > 0 && PQ_LEN > 0 && std::is_same_v && + std::is_same_v) { + using desc_t = cagra_q_dataset_descriptor_t; + const desc_t* desc = static_cast(desc_ptr); + + const desc_t* result = setup_workspace_vpq_impl(desc, smem, queries, query_id); + return static_cast*>(result); + } else { + static_assert( + sizeof(DataT) == 0, + "setup_workspace_impl: unsupported PQ_BITS/PQ_LEN/CodebookT/QueryT for CAGRA JIT"); + return nullptr; + } +} + +} // namespace cuvs::neighbors::cagra::detail 
diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_kernel.cu.in b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_kernel.cu.in new file mode 100644 index 0000000000..fa17705250 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_kernel.cu.in @@ -0,0 +1,45 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include + +namespace { + +constexpr uint32_t k_team_size = @team_size@u; +constexpr uint32_t k_dataset_block_dim = @dataset_block_dim@u; +constexpr uint32_t k_pq_bits = @pq_bits@u; +constexpr uint32_t k_pq_len = @pq_len@u; + +using data_t = @data_type@; +using index_t = @index_type@; +using distance_t = @distance_type@; +using query_t = @query_type@; +using codebook_t = @codebook_type@; + +} // namespace + +namespace cuvs::neighbors::cagra::detail { + +template <> +__device__ const dataset_descriptor_base_t* +setup_workspace( + const dataset_descriptor_base_t* desc, + void* smem, + const data_t* queries, + uint32_t query_id) +{ + return setup_workspace_impl(desc, smem, queries, query_id); +} + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_matrix.json b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_matrix.json new file mode 100644 index 0000000000..83aa8764bc --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_matrix.json @@ -0,0 +1,154 @@ +[ + { + "_data": [ + { + "data_type": "float", + "data_abbrev": "f", + "_query": [ + { + "query_type": "float", + "query_abbrev": "f" + } + ] + }, + { + "data_type": "__half", + "data_abbrev": "h", + "_query": [ + { + "query_type": "float", + "query_abbrev": "f" + } + ] + }, + { + "data_type": "uint8_t", + "data_abbrev": "u8", + "_query": [ + { + "query_type": "float", + "query_abbrev": "f" + }, + { + "query_type": "uint8_t", + "query_abbrev": "u8" + } + ] + }, + { + "data_type": "int8_t", + "data_abbrev": "i8", + "_query": [ + { + "query_type": "float", + "query_abbrev": "f" + } + ] + } + ], + "_index": [ + { + "index_type": "uint32_t", + "index_abbrev": "u32" + } + ], + "_distance": [ + { + "distance_type": "float", + "distance_abbrev": "f" + } + ], + "team_size": [ + "8", + "16", + "32" + ], + "dataset_block_dim": [ + "128", + "256", + "512" + ], + "_pq": [ + { + "pq_prefix": "_standard", + "pq_suffix": "", + "pq_bits": "0", + "pq_len": "0" + } + ], + "_codebook": [ + { + "codebook_type": "void", + "codebook_abbrev": "none" + } + ] + }, + { + "_data": [ + { + "data_type": "float", + "data_abbrev": "f" + }, + { + "data_type": "__half", + "data_abbrev": "h" + }, + { + "data_type": "uint8_t", + "data_abbrev": "u8" + }, + { + "data_type": "int8_t", + "data_abbrev": "i8" + } + ], + "_query": [ + { + "query_type": "half", + "query_abbrev": "h" + } + ], + "_index": [ + { + "index_type": "uint32_t", + "index_abbrev": "u32" + } + ], + "_distance": [ + { + "distance_type": "float", + "distance_abbrev": "f" + } + ], + "team_size": [ + "8", + "16", + "32" + ], + "dataset_block_dim": [ + "128", + "256", + "512" + ], + "_pq": [ + { + "pq_len": "2", + "pq_bits": "8", + "pq_prefix": "_vpq", + "pq_suffix": "_8pq_2subd" + }, + { + "pq_len": "4", + "pq_bits": "8", + "pq_prefix": "_vpq", + "pq_suffix": "_8pq_4subd" + } + ], + "_codebook": [ + { + "codebook_type": "half", + "codebook_abbrev": "half" + } + ] + } +] diff --git 
a/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh b/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh index 4d4ddb9b80..baf9336e6d 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh @@ -4,8 +4,9 @@ */ #pragma once +#include "../neighbors_device_intrinsics.cuh" #include "bitonic.hpp" -#include "device_common.hpp" +#include "device_memory_ops.hpp" #include "hashmap.hpp" #include "search_multi_cta_kernel.cuh" #include "search_plan.cuh" @@ -30,6 +31,8 @@ #include #include // RAFT_CUDA_TRY_NOT_THROW is used TODO(tfeher): consider moving this to cuda_rt_essentials.hpp +#include + #include #include #include @@ -91,10 +94,10 @@ struct search constexpr static bool kNeedIndexCopy = sizeof(INDEX_T) != sizeof(OutputIndexT); uint32_t num_cta_per_query; - lightweight_uvector intermediate_indices; - lightweight_uvector intermediate_distances; + rmm::device_uvector intermediate_indices; + rmm::device_uvector intermediate_distances; size_t topk_workspace_size; - lightweight_uvector topk_workspace; + rmm::device_uvector topk_workspace; search(raft::resources const& res, search_params params, @@ -104,9 +107,9 @@ struct search int64_t graph_degree, uint32_t topk) : base_type(res, params, dataset_desc, dim, dataset_size, graph_degree, topk), - intermediate_indices(res), - intermediate_distances(res), - topk_workspace(res) + intermediate_indices(0, raft::resource::get_cuda_stream(res)), + intermediate_distances(0, raft::resource::get_cuda_stream(res)), + topk_workspace(0, raft::resource::get_cuda_stream(res)) { set_params(res, params); } diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_helpers.cuh b/cpp/src/neighbors/detail/cagra/search_multi_cta_helpers.cuh new file mode 100644 index 0000000000..fe985f7275 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_helpers.cuh @@ -0,0 +1,138 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include +#include +#include +#include + +namespace cuvs::neighbors::cagra::detail::multi_cta_search { + +template +RAFT_DEVICE_INLINE_FUNCTION void pickup_next_parent( + INDEX_T* const next_parent_indices, + INDEX_T* const itopk_indices, // [itopk_size * 2] + DISTANCE_T* const itopk_distances, // [itopk_size * 2] + INDEX_T* const hash_ptr, + const uint32_t hash_bitlen) +{ + constexpr uint32_t itopk_size = 32; + constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; + constexpr INDEX_T invalid_index = ~static_cast(0); + + const unsigned warp_id = threadIdx.x / 32; + if (warp_id > 0) { return; } + if (threadIdx.x == 0) { next_parent_indices[0] = invalid_index; } + __syncwarp(); + + int j = -1; + for (unsigned i = threadIdx.x; i < itopk_size * 2; i += 32) { + INDEX_T index = itopk_indices[i]; + int is_invalid = 0; + int is_candidate = 0; + if (index == invalid_index) { + is_invalid = 1; + } else if (index & index_msb_1_mask) { + } else { + is_candidate = 1; + } + + const auto ballot_mask = __ballot_sync(0xffffffff, is_candidate); + const auto candidate_id = __popc(ballot_mask & ((1 << threadIdx.x) - 1)); + for (int k = 0; k < __popc(ballot_mask); k++) { + int flag_done = 0; + if (is_candidate && candidate_id == k) { + is_candidate = 0; + if (hashmap::insert(hash_ptr, hash_bitlen, index)) { + // Use this candidate as next parent + index |= index_msb_1_mask; // set most significant bit as used node + if (i < itopk_size) { + next_parent_indices[0] = i; + itopk_indices[i] = index; + } else { + next_parent_indices[0] = j; + // Move the next parent node from i-th position to j-th position + itopk_indices[j] = index; + itopk_distances[j] = itopk_distances[i]; + itopk_indices[i] = invalid_index; + itopk_distances[i] = utils::get_max_value(); + } + flag_done = 1; + } else { + // Deactivate the node since it has been used by other CTA. 
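+          // (hashmap::insert returned false: the node is already in the per-query traversed hash
+          // table, which is shared by all CTAs working on the same query)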
+ itopk_indices[i] = invalid_index; + itopk_distances[i] = utils::get_max_value(); + is_invalid = 1; + } + } + if (__any_sync(0xffffffff, (flag_done > 0))) { return; } + } + if (i < itopk_size) { + j = 31 - __clz(__ballot_sync(0xffffffff, is_invalid)); + if (j < 0) { return; } + } + } +} + +template +RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort(float* distances, // [num_elements] + INDEX_T* indices, // [num_elements] + const uint32_t num_elements) +{ + const unsigned warp_id = threadIdx.x / raft::warp_size(); + if (warp_id > 0) { return; } + const unsigned lane_id = threadIdx.x % raft::warp_size(); + constexpr unsigned N = (MAX_ELEMENTS + (raft::warp_size() - 1)) / raft::warp_size(); + float key[N]; + INDEX_T val[N]; + for (unsigned i = 0; i < N; i++) { + unsigned j = lane_id + (raft::warp_size() * i); + if (j < num_elements) { + key[i] = distances[j]; + val[i] = indices[j]; + } else { + key[i] = utils::get_max_value(); + val[i] = ~static_cast(0); + } + } + /* Warp Sort */ + bitonic::warp_sort(key, val); + /* Store sorted results */ + for (unsigned i = 0; i < N; i++) { + unsigned j = (N * lane_id) + i; + if (j < num_elements) { + distances[j] = key[i]; + indices[j] = val[i]; + } + } +} + +RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_wrapper_64( + float* distances, // [num_elements] + uint32_t* indices, // [num_elements] + const uint32_t num_elements) +{ + topk_by_bitonic_sort<64, uint32_t>(distances, indices, num_elements); +} + +RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_wrapper_128( + float* distances, // [num_elements] + uint32_t* indices, // [num_elements] + const uint32_t num_elements) +{ + topk_by_bitonic_sort<128, uint32_t>(distances, indices, num_elements); +} + +RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_wrapper_256( + float* distances, // [num_elements] + uint32_t* indices, // [num_elements] + const uint32_t num_elements) +{ + topk_by_bitonic_sort<256, uint32_t>(distances, indices, num_elements); +} + +} // namespace cuvs::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_inst.cuh b/cpp/src/neighbors/detail/cagra/search_multi_cta_inst.cuh index bd4d25d8f3..72c40c6973 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_inst.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_inst.cuh @@ -1,37 +1,40 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. 
* SPDX-License-Identifier: Apache-2.0 */ #pragma once -#include "search_multi_cta_kernel-inl.cuh" +#include "../../sample_filter.cuh" +#include "sample_filter_utils.cuh" +#include "search_multi_cta_kernel.cuh" +#include "search_multi_cta_kernel_launcher_jit.cuh" #include namespace cuvs::neighbors::cagra::detail::multi_cta_search { -#define instantiate_kernel_selection(DataT, IndexT, DistanceT, SampleFilterT) \ - template void select_and_run( \ - const dataset_descriptor_host& dataset_desc, \ - raft::device_matrix_view graph, \ - const IndexT* source_indices_ptr, \ - uint32_t* topk_indices_ptr, \ - DistanceT* topk_distances_ptr, \ - const DataT* queries_ptr, \ - uint32_t num_queries, \ - const uint32_t* dev_seed_ptr, \ - uint32_t* num_executed_iterations, \ - const search_params& ps, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - uint32_t small_hash_bitlen, \ - int64_t hash_bitlen, \ - uint32_t* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_seeds, \ - SampleFilterT sample_filter, \ +#define instantiate_kernel_selection(DataT, IndexT, DistanceT, SampleFilterT) \ + template void select_and_run( \ + const dataset_descriptor_host& dataset_desc, \ + raft::device_matrix_view graph, \ + const IndexT* source_indices_ptr, \ + IndexT* topk_indices_ptr, \ + DistanceT* topk_distances_ptr, \ + const DataT* queries_ptr, \ + uint32_t num_queries, \ + const IndexT* dev_seed_ptr, \ + uint32_t* num_executed_iterations, \ + const search_params& ps, \ + uint32_t topk, \ + uint32_t block_size, \ + uint32_t result_buffer_size, \ + uint32_t smem_size, \ + uint32_t small_hash_bitlen, \ + int64_t hash_bitlen, \ + IndexT* hashmap_ptr, \ + uint32_t num_cta_per_query, \ + uint32_t num_seeds, \ + SampleFilterT sample_filter, \ cudaStream_t stream); } // namespace cuvs::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh b/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh deleted file mode 100644 index 4ac0020c5c..0000000000 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh +++ /dev/null @@ -1,636 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. 
- * SPDX-License-Identifier: Apache-2.0 - */ -#pragma once - -#include "search_multi_cta_kernel.cuh" - -#include "bitonic.hpp" -#include "device_common.hpp" -#include "hashmap.hpp" -#include "search_plan.cuh" -#include "topk_for_cagra/topk.h" // TODO replace with raft topk if possible -#include "utils.hpp" -#include - -#include -#include -#include -#include -#include - -#include - -#include - -// TODO: This shouldn't be invoking anything from spatial/knn -#include "../ann_utils.cuh" -#include "../smem_utils.cuh" - -#include -#include // RAFT_CUDA_TRY_NOT_THROW is used TODO(tfeher): consider moving this to cuda_rt_essentials.hpp - -#include -#include -#include -#include -#include -#include -#include - -namespace cuvs::neighbors::cagra::detail { -namespace multi_cta_search { - -// #define _CLK_BREAKDOWN - -template -RAFT_DEVICE_INLINE_FUNCTION void pickup_next_parent( - INDEX_T* const next_parent_indices, - INDEX_T* const itopk_indices, // [itopk_size * 2] - DISTANCE_T* const itopk_distances, // [itopk_size * 2] - INDEX_T* const hash_ptr, - const uint32_t hash_bitlen) -{ - constexpr uint32_t itopk_size = 32; - constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; - constexpr INDEX_T invalid_index = ~static_cast(0); - - const unsigned warp_id = threadIdx.x / 32; - if (warp_id > 0) { return; } - if (threadIdx.x == 0) { next_parent_indices[0] = invalid_index; } - __syncwarp(); - - int j = -1; - for (unsigned i = threadIdx.x; i < itopk_size * 2; i += 32) { - INDEX_T index = itopk_indices[i]; - int is_invalid = 0; - int is_candidate = 0; - if (index == invalid_index) { - is_invalid = 1; - } else if (index & index_msb_1_mask) { - } else { - is_candidate = 1; - } - - const auto ballot_mask = __ballot_sync(0xffffffff, is_candidate); - const auto candidate_id = __popc(ballot_mask & ((1 << threadIdx.x) - 1)); - for (int k = 0; k < __popc(ballot_mask); k++) { - int flag_done = 0; - if (is_candidate && candidate_id == k) { - is_candidate = 0; - if (hashmap::insert(hash_ptr, hash_bitlen, index)) { - // Use this candidate as next parent - index |= index_msb_1_mask; // set most significant bit as used node - if (i < itopk_size) { - next_parent_indices[0] = i; - itopk_indices[i] = index; - } else { - next_parent_indices[0] = j; - // Move the next parent node from i-th position to j-th position - itopk_indices[j] = index; - itopk_distances[j] = itopk_distances[i]; - itopk_indices[i] = invalid_index; - itopk_distances[i] = utils::get_max_value(); - } - flag_done = 1; - } else { - // Deactivate the node since it has been used by other CTA. 
- itopk_indices[i] = invalid_index; - itopk_distances[i] = utils::get_max_value(); - is_invalid = 1; - } - } - if (__any_sync(0xffffffff, (flag_done > 0))) { return; } - } - if (i < itopk_size) { - j = 31 - __clz(__ballot_sync(0xffffffff, is_invalid)); - if (j < 0) { return; } - } - } -} - -template -RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort(float* distances, // [num_elements] - INDEX_T* indices, // [num_elements] - const uint32_t num_elements) -{ - const unsigned warp_id = threadIdx.x / raft::warp_size(); - if (warp_id > 0) { return; } - const unsigned lane_id = threadIdx.x % raft::warp_size(); - constexpr unsigned N = (MAX_ELEMENTS + (raft::warp_size() - 1)) / raft::warp_size(); - float key[N]; - INDEX_T val[N]; - for (unsigned i = 0; i < N; i++) { - unsigned j = lane_id + (raft::warp_size() * i); - if (j < num_elements) { - key[i] = distances[j]; - val[i] = indices[j]; - } else { - key[i] = utils::get_max_value(); - val[i] = ~static_cast(0); - } - } - /* Warp Sort */ - bitonic::warp_sort(key, val); - /* Store sorted results */ - for (unsigned i = 0; i < N; i++) { - unsigned j = (N * lane_id) + i; - if (j < num_elements) { - distances[j] = key[i]; - indices[j] = val[i]; - } - } -} - -RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_wrapper_64( - float* distances, // [num_elements] - uint32_t* indices, // [num_elements] - const uint32_t num_elements) -{ - topk_by_bitonic_sort<64, uint32_t>(distances, indices, num_elements); -} - -RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_wrapper_128( - float* distances, // [num_elements] - uint32_t* indices, // [num_elements] - const uint32_t num_elements) -{ - topk_by_bitonic_sort<128, uint32_t>(distances, indices, num_elements); -} - -RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_wrapper_256( - float* distances, // [num_elements] - uint32_t* indices, // [num_elements] - const uint32_t num_elements) -{ - topk_by_bitonic_sort<256, uint32_t>(distances, indices, num_elements); -} - -// -// multiple CTAs per single query -// -template -RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( - typename DATASET_DESCRIPTOR_T::INDEX_T* const - result_indices_ptr, // [num_queries, num_cta_per_query, itopk_size] - typename DATASET_DESCRIPTOR_T::DISTANCE_T* const - result_distances_ptr, // [num_queries, num_cta_per_query, itopk_size] - const DATASET_DESCRIPTOR_T* dataset_desc, - const typename DATASET_DESCRIPTOR_T::DATA_T* const queries_ptr, // [num_queries, dataset_dim] - const typename DATASET_DESCRIPTOR_T::INDEX_T* const knn_graph, // [dataset_size, graph_degree] - const uint32_t max_elements, - const uint32_t graph_degree, - const SourceIndexT* source_indices_ptr, // [num_queries, search_width] - const unsigned num_distilation, - const uint64_t rand_xor_mask, - const typename DATASET_DESCRIPTOR_T::INDEX_T* seed_ptr, // [num_queries, num_seeds] - const uint32_t num_seeds, - const uint32_t visited_hash_bitlen, - typename DATASET_DESCRIPTOR_T::INDEX_T* const - traversed_hashmap_ptr, // [num_queries, 1 << traversed_hash_bitlen] - const uint32_t traversed_hash_bitlen, - const uint32_t itopk_size, - const uint32_t min_iteration, - const uint32_t max_iteration, - uint32_t* const num_executed_iterations, /* stats */ - SAMPLE_FILTER_T sample_filter, - const typename DATASET_DESCRIPTOR_T::INDEX_T graph_size = 0) -{ - using DATA_T = typename DATASET_DESCRIPTOR_T::DATA_T; - using INDEX_T = typename DATASET_DESCRIPTOR_T::INDEX_T; - using DISTANCE_T = typename DATASET_DESCRIPTOR_T::DISTANCE_T; - - auto to_source_index = 
[source_indices_ptr](INDEX_T x) { - return source_indices_ptr == nullptr ? static_cast(x) : source_indices_ptr[x]; - }; - - const auto num_queries = gridDim.y; - const auto query_id = blockIdx.y; - const auto num_cta_per_query = gridDim.x; - const auto cta_id = blockIdx.x; // local CTA ID - -#ifdef _CLK_BREAKDOWN - uint64_t clk_init = 0; - uint64_t clk_compute_1st_distance = 0; - uint64_t clk_topk = 0; - uint64_t clk_pickup_parents = 0; - uint64_t clk_compute_distance = 0; - uint64_t clk_start; -#define _CLK_START() clk_start = clock64() -#define _CLK_REC(V) V += clock64() - clk_start; -#else -#define _CLK_START() -#define _CLK_REC(V) -#endif - _CLK_START(); - - extern __shared__ uint8_t smem[]; - - // Layout of result_buffer - // +----------------+---------+---------------------------+ - // | internal_top_k | padding | neighbors of parent nodes | - // | | upto 32 | | - // +----------------+---------+---------------------------+ - // |<--- result_buffer_size_32 --->| - const auto result_buffer_size = itopk_size + graph_degree; - const auto result_buffer_size_32 = raft::round_up_safe(result_buffer_size, 32); - assert(result_buffer_size_32 <= max_elements); - - // Set smem working buffer for the distance calculation - dataset_desc = dataset_desc->setup_workspace(smem, queries_ptr, query_id); - - auto* __restrict__ result_indices_buffer = - reinterpret_cast(smem + dataset_desc->smem_ws_size_in_bytes()); - auto* __restrict__ result_distances_buffer = - reinterpret_cast(result_indices_buffer + result_buffer_size_32); - auto* __restrict__ local_visited_hashmap_ptr = - reinterpret_cast(result_distances_buffer + result_buffer_size_32); - auto* __restrict__ parent_indices_buffer = - reinterpret_cast(local_visited_hashmap_ptr + hashmap::get_size(visited_hash_bitlen)); - auto* __restrict__ result_position = reinterpret_cast(parent_indices_buffer + 1); - - INDEX_T* const local_traversed_hashmap_ptr = - traversed_hashmap_ptr + (hashmap::get_size(traversed_hash_bitlen) * query_id); - - constexpr INDEX_T invalid_index = ~static_cast(0); - constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; - - for (unsigned i = threadIdx.x; i < result_buffer_size_32; i += blockDim.x) { - result_indices_buffer[i] = invalid_index; - result_distances_buffer[i] = utils::get_max_value(); - } - hashmap::init(local_visited_hashmap_ptr, visited_hash_bitlen); - __syncthreads(); - _CLK_REC(clk_init); - - // compute distance to randomly selecting nodes - _CLK_START(); - const INDEX_T* const local_seed_ptr = seed_ptr ? 
seed_ptr + (num_seeds * query_id) : nullptr; - uint32_t block_id = cta_id + (num_cta_per_query * query_id); - uint32_t num_blocks = num_cta_per_query * num_queries; - - device::compute_distance_to_random_nodes(result_indices_buffer, - result_distances_buffer, - *dataset_desc, - graph_degree, - num_distilation, - rand_xor_mask, - local_seed_ptr, - num_seeds, - local_visited_hashmap_ptr, - visited_hash_bitlen, - local_traversed_hashmap_ptr, - traversed_hash_bitlen, - block_id, - num_blocks, - graph_size); - __syncthreads(); - _CLK_REC(clk_compute_1st_distance); - - uint32_t iter = 0; - while (1) { - _CLK_START(); - if (threadIdx.x < 32) { - // [1st warp] Topk with bitonic sort - if constexpr (std::is_same_v) { - // use a non-template wrapper function to avoid pre-inlining the topk_by_bitonic_sort - // function (vs post-inlining, this impacts register pressure) - if (max_elements <= 64) { - topk_by_bitonic_sort_wrapper_64( - result_distances_buffer, result_indices_buffer, result_buffer_size_32); - } else if (max_elements <= 128) { - topk_by_bitonic_sort_wrapper_128( - result_distances_buffer, result_indices_buffer, result_buffer_size_32); - } else { - assert(max_elements <= 256); - topk_by_bitonic_sort_wrapper_256( - result_distances_buffer, result_indices_buffer, result_buffer_size_32); - } - } else { - if (max_elements <= 64) { - topk_by_bitonic_sort<64, INDEX_T>( - result_distances_buffer, result_indices_buffer, result_buffer_size_32); - } else if (max_elements <= 128) { - topk_by_bitonic_sort<128, INDEX_T>( - result_distances_buffer, result_indices_buffer, result_buffer_size_32); - } else { - assert(max_elements <= 256); - topk_by_bitonic_sort<256, INDEX_T>( - result_distances_buffer, result_indices_buffer, result_buffer_size_32); - } - } - } - __syncthreads(); - _CLK_REC(clk_topk); - - if (iter + 1 >= max_iteration) { break; } - - _CLK_START(); - if (threadIdx.x < 32) { - // [1st warp] Pick up a next parent - pickup_next_parent(parent_indices_buffer, - result_indices_buffer, - result_distances_buffer, - local_traversed_hashmap_ptr, - traversed_hash_bitlen); - } else { - // [Other warps] Reset visited hashmap - hashmap::init(local_visited_hashmap_ptr, visited_hash_bitlen, 32); - } - __syncthreads(); - _CLK_REC(clk_pickup_parents); - - if ((parent_indices_buffer[0] == invalid_index) && (iter >= min_iteration)) { break; } - - _CLK_START(); - for (unsigned i = threadIdx.x; i < result_buffer_size_32; i += blockDim.x) { - INDEX_T index = result_indices_buffer[i]; - if (index == invalid_index) { continue; } - if ((i >= itopk_size) && (index & index_msb_1_mask)) { - // Remove nodes kicked out of the itopk list from the traversed hash table. - hashmap::remove( - local_traversed_hashmap_ptr, traversed_hash_bitlen, index & ~index_msb_1_mask); - result_indices_buffer[i] = invalid_index; - result_distances_buffer[i] = utils::get_max_value(); - } else { - // Restore visited hashmap by putting nodes on result buffer in it. - index &= ~index_msb_1_mask; - hashmap::insert(local_visited_hashmap_ptr, visited_hash_bitlen, index); - } - } - // Initialize buffer for compute_distance_to_child_nodes. 
- if (threadIdx.x == blockDim.x - 1) { result_position[0] = result_buffer_size_32; } - __syncthreads(); - - // Compute the norms between child nodes and query node - device::compute_distance_to_child_nodes( - result_indices_buffer, - result_distances_buffer, - *dataset_desc, - knn_graph, - graph_degree, - local_visited_hashmap_ptr, - visited_hash_bitlen, - local_traversed_hashmap_ptr, - traversed_hash_bitlen, - parent_indices_buffer, - result_indices_buffer, - 1, - result_position, - result_buffer_size_32); - // __syncthreads(); - - // Check the state of the nodes in the result buffer which were not updated - // by the compute_distance_to_child_nodes above, and if it cannot be used as - // a parent node, it is deactivated. - for (uint32_t i = threadIdx.x; i < result_position[0]; i += blockDim.x) { - INDEX_T index = result_indices_buffer[i]; - if (index == invalid_index || index & index_msb_1_mask) { continue; } - if (hashmap::search(local_traversed_hashmap_ptr, traversed_hash_bitlen, index)) { - result_indices_buffer[i] = invalid_index; - result_distances_buffer[i] = utils::get_max_value(); - } - } - __syncthreads(); - _CLK_REC(clk_compute_distance); - - // Filtering - if constexpr (!std::is_same::value) { - for (unsigned p = threadIdx.x; p < 1; p += blockDim.x) { - if (parent_indices_buffer[p] != invalid_index) { - const auto parent_id = - result_indices_buffer[parent_indices_buffer[p]] & ~index_msb_1_mask; - if (!sample_filter(query_id, to_source_index(parent_id))) { - // If the parent must not be in the resulting top-k list, remove from the parent list - result_distances_buffer[parent_indices_buffer[p]] = utils::get_max_value(); - result_indices_buffer[parent_indices_buffer[p]] = invalid_index; - } - } - } - __syncthreads(); - } - - iter++; - } - - // Filtering - if constexpr (!std::is_same::value) { - for (uint32_t i = threadIdx.x; i < result_buffer_size_32; i += blockDim.x) { - INDEX_T index = result_indices_buffer[i]; - if (index == invalid_index) { continue; } - index &= ~index_msb_1_mask; - if (!sample_filter(query_id, to_source_index(index))) { - result_indices_buffer[i] = invalid_index; - result_distances_buffer[i] = utils::get_max_value(); - } - } - __syncthreads(); - } - - // Output search results (1st warp only). - if (threadIdx.x < 32) { - uint32_t offset = 0; - for (uint32_t i = threadIdx.x; i < result_buffer_size_32; i += 32) { - INDEX_T index = result_indices_buffer[i]; - bool is_valid = false; - if (index != invalid_index) { - if (index & index_msb_1_mask) { - is_valid = true; - index &= ~index_msb_1_mask; - } else if ((offset < itopk_size) && - hashmap::insert( - local_traversed_hashmap_ptr, traversed_hash_bitlen, index)) { - // If a node that is not used as a parent can be inserted into - // the traversed hash table, it is considered a valid result. - is_valid = true; - } - } - const auto mask = __ballot_sync(0xffffffff, is_valid); - if (is_valid) { - const auto j = offset + __popc(mask & ((1 << threadIdx.x) - 1)); - if (j < itopk_size) { - uint32_t k = j + (itopk_size * (cta_id + (num_cta_per_query * query_id))); - result_indices_ptr[k] = index & ~index_msb_1_mask; - if (result_distances_ptr != nullptr) { - result_distances_ptr[k] = result_distances_buffer[i]; - } - } else { - // If it is valid and registered in the traversed hash table but is - // not output as a result, it is removed from the hash table. 
- hashmap::remove(local_traversed_hashmap_ptr, traversed_hash_bitlen, index); - } - } - offset += __popc(mask); - } - // If the number of outputs is insufficient, fill in with invalid results. - for (uint32_t i = offset + threadIdx.x; i < itopk_size; i += 32) { - uint32_t k = i + (itopk_size * (cta_id + (num_cta_per_query * query_id))); - result_indices_ptr[k] = invalid_index; - if (result_distances_ptr != nullptr) { - result_distances_ptr[k] = utils::get_max_value(); - } - } - } - - if (threadIdx.x == 0 && cta_id == 0 && num_executed_iterations != nullptr) { - num_executed_iterations[query_id] = iter + 1; - } - -#ifdef _CLK_BREAKDOWN - if ((threadIdx.x == 0 || threadIdx.x == blockDim.x - 1) && (blockIdx.x == 0) && - ((query_id * 3) % gridDim.y < 3)) { - printf( - "%s:%d " - "query, %d, thread, %d" - ", init, %lu" - ", 1st_distance, %lu" - ", topk, %lu" - ", pickup_parents, %lu" - ", distance, %lu" - "\n", - __FILE__, - __LINE__, - query_id, - threadIdx.x, - clk_init, - clk_compute_1st_distance, - clk_topk, - clk_pickup_parents, - clk_compute_distance); - } -#endif -} - -template -RAFT_KERNEL set_value_batch_kernel(T* const dev_ptr, - const std::size_t ld, - const T val, - const std::size_t count, - const std::size_t batch_size) -{ - const auto tid = threadIdx.x + blockIdx.x * blockDim.x; - if (tid >= count * batch_size) { return; } - const auto batch_id = tid / count; - const auto elem_id = tid % count; - dev_ptr[elem_id + ld * batch_id] = val; -} - -template -void set_value_batch(T* const dev_ptr, - const std::size_t ld, - const T val, - const std::size_t count, - const std::size_t batch_size, - cudaStream_t cuda_stream) -{ - constexpr std::uint32_t block_size = 256; - const auto grid_size = (count * batch_size + block_size - 1) / block_size; - set_value_batch_kernel - <<>>(dev_ptr, ld, val, count, batch_size); -} - -template -struct search_kernel_config { - // Search kernel function type. Note that the actual values for the template value - // parameters do not matter, because they are not part of the function signature. The - // second to fourth value parameters will be selected by the choose_* functions below. 
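Both the choose_buffer_size selector that follows and the max_elements logic in select_and_run round a runtime buffer size up to one of a few compile-time capacities and pick the matching instantiation. A self-contained sketch of that dispatch pattern, using a hypothetical capacity_kernel rather than the actual search_kernel signature:

#include <cstdint>
#include <cstdio>

// Capacity is a compile-time upper bound on a per-block buffer, chosen from a
// small set of pre-built instantiations at runtime.
template <uint32_t Capacity>
__global__ void capacity_kernel(uint32_t n)
{
  __shared__ float buf[Capacity];
  if (threadIdx.x < n && threadIdx.x < Capacity) { buf[threadIdx.x] = static_cast<float>(threadIdx.x); }
  __syncthreads();
  if (threadIdx.x == 0) { printf("capacity %u, n %u, buf[0] = %f\n", Capacity, n, buf[0]); }
}

using kernel_t = void (*)(uint32_t);

// Round the runtime buffer size up to the nearest supported capacity,
// mirroring the 64 / 128 / 256 selection and the THROW on overflow.
kernel_t choose_kernel(uint32_t result_buffer_size)
{
  if (result_buffer_size <= 64) { return capacity_kernel<64>; }
  if (result_buffer_size <= 128) { return capacity_kernel<128>; }
  if (result_buffer_size <= 256) { return capacity_kernel<256>; }
  return nullptr;  // caller raises an error
}

int main()
{
  const uint32_t n = 100;  // needs the 128-capacity instantiation
  kernel_t k = choose_kernel(n);
  if (k != nullptr) {
    k<<<1, 256>>>(n);
    cudaDeviceSynchronize();
  }
  return 0;
}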
- using kernel_t = decltype(&search_kernel); - - static auto choose_buffer_size(unsigned result_buffer_size, unsigned block_size) -> kernel_t - { - if (result_buffer_size <= 256) { - return search_kernel; - } - THROW("Result buffer size %u larger than max buffer size %u", result_buffer_size, 256); - } -}; - -template -void select_and_run(const dataset_descriptor_host& dataset_desc, - raft::device_matrix_view graph, - const SourceIndexT* source_indices_ptr, - IndexT* topk_indices_ptr, // [num_queries, topk] - DistanceT* topk_distances_ptr, // [num_queries, topk] - const DataT* queries_ptr, // [num_queries, dataset_dim] - uint32_t num_queries, - const IndexT* dev_seed_ptr, // [num_queries, num_seeds] - uint32_t* num_executed_iterations, // [num_queries,] - const search_params& ps, - uint32_t topk, - // multi_cta_search (params struct) - uint32_t block_size, // - uint32_t result_buffer_size, - uint32_t smem_size, - uint32_t visited_hash_bitlen, - int64_t traversed_hash_bitlen, - IndexT* traversed_hashmap_ptr, - uint32_t num_cta_per_query, - uint32_t num_seeds, - SampleFilterT sample_filter, - cudaStream_t stream) -{ - auto kernel = - search_kernel_config, - SourceIndexT, - SampleFilterT>::choose_buffer_size(result_buffer_size, block_size); - - uint32_t max_elements{}; - if (result_buffer_size <= 64) { - max_elements = 64; - } else if (result_buffer_size <= 128) { - max_elements = 128; - } else if (result_buffer_size <= 256) { - max_elements = 256; - } else { - THROW("Result buffer size %u larger than max buffer size %u", result_buffer_size, 256); - } - - // Initialize hash table - const uint32_t traversed_hash_size = hashmap::get_size(traversed_hash_bitlen); - set_value_batch(traversed_hashmap_ptr, - traversed_hash_size, - ~static_cast(0), - traversed_hash_size, - num_queries, - stream); - - dim3 block_dims(block_size, 1, 1); - dim3 grid_dims(num_cta_per_query, num_queries, 1); - RAFT_LOG_DEBUG("Launching kernel with %u threads, (%u, %u) blocks %u smem", - block_size, - num_cta_per_query, - num_queries, - smem_size); - - kernel<<>>(topk_indices_ptr, - topk_distances_ptr, - dataset_desc.dev_ptr(stream), - queries_ptr, - graph.data_handle(), - max_elements, - graph.extent(1), - source_indices_ptr, - ps.num_random_samplings, - ps.rand_xor_mask, - dev_seed_ptr, - num_seeds, - visited_hash_bitlen, - traversed_hashmap_ptr, - traversed_hash_bitlen, - ps.itopk_size, - ps.min_iterations, - ps.max_iterations, - num_executed_iterations, - sample_filter, - static_cast(graph.extent(0))); -} - -} // namespace multi_cta_search -} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel_launcher_jit.cuh b/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel_launcher_jit.cuh new file mode 100644 index 0000000000..663f8a4559 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel_launcher_jit.cuh @@ -0,0 +1,154 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "../smem_utils.cuh" + +// Include tags header before any other includes that might open namespaces +#include + +#include "compute_distance.hpp" // For dataset_descriptor_host +#include "jit_lto_kernels/cagra_jit_launcher_factory.hpp" +#include "jit_lto_kernels/kernel_def.hpp" +#include "sample_filter_utils.cuh" // For CagraSampleFilterWithQueryIdOffset +#include "search_plan.cuh" // For search_params +#include "set_value_batch.cuh" // For set_value_batch +#include "shared_launcher_jit.hpp" // For shared JIT helper functions +#include +#include +#include +#include + +#include +#include +#include + +namespace cuvs::neighbors::cagra::detail::multi_cta_search { + +template +void select_and_run(const dataset_descriptor_host& dataset_desc, + raft::device_matrix_view graph, + const SourceIndexT* source_indices_ptr, + IndexT* topk_indices_ptr, // [num_queries, num_cta_per_query, itopk_size] + DistanceT* topk_distances_ptr, // [num_queries, num_cta_per_query, itopk_size] + const DataT* queries_ptr, // [num_queries, dataset_dim] + uint32_t num_queries, + const IndexT* dev_seed_ptr, // [num_queries, num_seeds] + uint32_t* num_executed_iterations, // [num_queries,] + const search_params& ps, + uint32_t topk, + // multi_cta_search (params struct) + uint32_t block_size, // + uint32_t result_buffer_size, + uint32_t smem_size, + uint32_t visited_hash_bitlen, + int64_t traversed_hash_bitlen, + IndexT* traversed_hashmap_ptr, + uint32_t num_cta_per_query, + uint32_t num_seeds, + SampleFilterT sample_filter, + cudaStream_t stream) +{ + const auto bf = extract_cagra_sample_filter(sample_filter); + const uint32_t query_id_offset = bf.query_id_offset; + + std::shared_ptr launcher = + make_cagra_multi_cta_jit_launcher>(dataset_desc); + + if (!launcher) { RAFT_FAIL("Failed to get JIT launcher"); } + + uint32_t max_elements{}; + if (result_buffer_size <= 64) { + max_elements = 64; + } else if (result_buffer_size <= 128) { + max_elements = 128; + } else if (result_buffer_size <= 256) { + max_elements = 256; + } else { + THROW("Result buffer size %u larger than max buffer size %u", result_buffer_size, 256); + } + + // Initialize hash table + const uint32_t traversed_hash_size = hashmap::get_size(traversed_hash_bitlen); + set_value_batch(traversed_hashmap_ptr, + traversed_hash_size, + ~static_cast(0), + traversed_hash_size, + num_queries, + stream); + + dim3 block_dims(block_size, 1, 1); + dim3 grid_dims(num_cta_per_query, num_queries, 1); + + // Get the device descriptor pointer + const dataset_descriptor_base_t* dev_desc_base = + dataset_desc.dev_ptr(stream); + const auto* dev_desc = dev_desc_base; + + // Note: dataset_desc is passed by const reference, so it stays alive for the duration of this + // function The descriptor's state is managed by a shared_ptr internally, so no need to explicitly + // keep it alive + + // Cast size_t/int64_t parameters to match kernel signature exactly + // The dispatch mechanism uses void* pointers, so parameter sizes must match exactly + // graph.extent(1) returns int64_t but kernel expects uint32_t + // traversed_hash_bitlen is int64_t but kernel expects uint32_t + // ps.itopk_size, ps.min_iterations, ps.max_iterations are size_t (8 bytes) but kernel expects + // uint32_t (4 bytes) ps.num_random_samplings is uint32_t but kernel expects unsigned - cast for + // consistency + const uint32_t graph_degree_u32 = static_cast(graph.extent(1)); + const uint32_t traversed_hash_bitlen_u32 = 
static_cast(traversed_hash_bitlen); + const uint32_t itopk_size_u32 = static_cast(ps.itopk_size); + const uint32_t min_iterations_u32 = static_cast(ps.min_iterations); + const uint32_t max_iterations_u32 = static_cast(ps.max_iterations); + const unsigned num_random_samplings_u = static_cast(ps.num_random_samplings); + + auto kernel_launcher = [&]() -> void { + launcher->dispatch< + multi_cta_search::search_multi_cta_kernel_func_t>( + stream, + grid_dims, + block_dims, + smem_size, + topk_indices_ptr, + topk_distances_ptr, + dev_desc, + queries_ptr, + graph.data_handle(), + max_elements, + graph_degree_u32, + source_indices_ptr, + num_random_samplings_u, + ps.rand_xor_mask, + dev_seed_ptr, + num_seeds, + visited_hash_bitlen, + traversed_hashmap_ptr, + traversed_hash_bitlen_u32, + itopk_size_u32, + min_iterations_u32, + max_iterations_u32, + num_executed_iterations, + static_cast(graph.extent(0)), + query_id_offset, + bf.bitset); + }; + cuvs::neighbors::detail::safely_launch_kernel_with_smem_size< + multi_cta_search::search_multi_cta_kernel_func_t>( + smem_size, kernel_launcher, launcher->get_kernel()); + + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} + +} // namespace cuvs::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_multi_kernel.cuh b/cpp/src/neighbors/detail/cagra/search_multi_kernel.cuh index 7aab9e241b..f8ec47ab7b 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_kernel.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_kernel.cuh @@ -4,7 +4,11 @@ */ #pragma once -#include "device_common.hpp" +#include "search_multi_kernel_launcher_jit.cuh" +#include + +#include "set_value_batch.cuh" + #include "hashmap.hpp" #include "search_plan.cuh" #include "topk_for_cagra/topk.h" //todo replace with raft kernel @@ -90,109 +94,6 @@ auto get_value(const T* const dev_ptr, cudaStream_t stream) -> T return value; } -// MAX_DATASET_DIM : must equal to or greater than dataset_dim -template -RAFT_KERNEL random_pickup_kernel( - const DATASET_DESCRIPTOR_T* dataset_desc, - const typename DATASET_DESCRIPTOR_T::DATA_T* const queries_ptr, // [num_queries, dataset_dim] - const std::size_t num_pickup, - const unsigned num_distilation, - const uint64_t rand_xor_mask, - const typename DATASET_DESCRIPTOR_T::INDEX_T* seed_ptr, // [num_queries, num_seeds] - const uint32_t num_seeds, - typename DATASET_DESCRIPTOR_T::INDEX_T* const result_indices_ptr, // [num_queries, ldr] - typename DATASET_DESCRIPTOR_T::DISTANCE_T* const result_distances_ptr, // [num_queries, ldr] - const std::uint32_t ldr, // (*) ldr >= num_pickup - typename DATASET_DESCRIPTOR_T::INDEX_T* const visited_hashmap_ptr, // [num_queries, 1 << bitlen] - const std::uint32_t hash_bitlen, - const typename DATASET_DESCRIPTOR_T::INDEX_T graph_size = 0) -{ - using DATA_T = typename DATASET_DESCRIPTOR_T::DATA_T; - using INDEX_T = typename DATASET_DESCRIPTOR_T::INDEX_T; - using DISTANCE_T = typename DATASET_DESCRIPTOR_T::DISTANCE_T; - - const auto team_size_bits = dataset_desc->team_size_bitshift(); - const auto ldb = hashmap::get_size(hash_bitlen); - const auto global_team_index = (blockIdx.x * blockDim.x + threadIdx.x) >> team_size_bits; - const uint32_t query_id = blockIdx.y; - if (global_team_index >= num_pickup) { return; } - extern __shared__ uint8_t smem[]; - dataset_desc = dataset_desc->setup_workspace(smem, queries_ptr, query_id); - __syncthreads(); - - const INDEX_T seed_index_limit = graph_size > 0 ? 
graph_size : dataset_desc->size; - - INDEX_T best_index_team_local; - DISTANCE_T best_norm2_team_local = utils::get_max_value(); - for (unsigned i = 0; i < num_distilation; i++) { - INDEX_T seed_index; - if (seed_ptr && (global_team_index < num_seeds)) { - seed_index = seed_ptr[global_team_index + (num_seeds * query_id)]; - } else { - // Chose a seed node randomly - seed_index = - device::xorshift64((global_team_index ^ rand_xor_mask) * (i + 1)) % seed_index_limit; - } - - DISTANCE_T norm2 = dataset_desc->compute_distance(seed_index, true); - if (norm2 < best_norm2_team_local) { - best_norm2_team_local = norm2; - best_index_team_local = seed_index; - } - } - - const auto store_gmem_index = global_team_index + (ldr * query_id); - if ((threadIdx.x & ((1u << team_size_bits) - 1u)) == 0) { - if (hashmap::insert( - visited_hashmap_ptr + (ldb * query_id), hash_bitlen, best_index_team_local)) { - result_distances_ptr[store_gmem_index] = best_norm2_team_local; - result_indices_ptr[store_gmem_index] = best_index_team_local; - } else { - result_distances_ptr[store_gmem_index] = utils::get_max_value(); - result_indices_ptr[store_gmem_index] = utils::get_max_value(); - } - } -} - -// MAX_DATASET_DIM : must be equal to or greater than dataset_dim -template -void random_pickup(const dataset_descriptor_host& dataset_desc, - const DataT* queries_ptr, // [num_queries, dataset_dim] - std::size_t num_queries, - std::size_t num_pickup, - unsigned num_distilation, - uint64_t rand_xor_mask, - const IndexT* seed_ptr, // [num_queries, num_seeds] - uint32_t num_seeds, - IndexT* result_indices_ptr, // [num_queries, ldr] - DistanceT* result_distances_ptr, // [num_queries, ldr] - std::size_t ldr, // (*) ldr >= num_pickup - IndexT* visited_hashmap_ptr, // [num_queries, 1 << bitlen] - std::uint32_t hash_bitlen, - cudaStream_t cuda_stream, - IndexT graph_size = 0) -{ - const auto block_size = 256u; - const auto num_teams_per_threadblock = block_size / dataset_desc.team_size; - const dim3 grid_size((num_pickup + num_teams_per_threadblock - 1) / num_teams_per_threadblock, - num_queries); - - random_pickup_kernel<<>>( - dataset_desc.dev_ptr(cuda_stream), - queries_ptr, - num_pickup, - num_distilation, - rand_xor_mask, - seed_ptr, - num_seeds, - result_indices_ptr, - result_distances_ptr, - ldr, - visited_hashmap_ptr, - hash_bitlen, - graph_size); -} - template RAFT_KERNEL pickup_next_parents_kernel( INDEX_T* const parent_candidates_ptr, // [num_queries, lds] @@ -293,146 +194,6 @@ void pickup_next_parents(INDEX_T* const parent_candidates_ptr, // [num_queries, terminate_flag); } -template -RAFT_KERNEL compute_distance_to_child_nodes_kernel( - const typename DATASET_DESCRIPTOR_T::INDEX_T* const - parent_node_list, // [num_queries, search_width] - typename DATASET_DESCRIPTOR_T::INDEX_T* const - parent_candidates_ptr, // [num_queries, search_width] - typename DATASET_DESCRIPTOR_T::DISTANCE_T* const - parent_distance_ptr, // [num_queries, search_width] - const std::size_t lds, - const std::uint32_t search_width, - const DATASET_DESCRIPTOR_T* dataset_desc, - const typename DATASET_DESCRIPTOR_T::INDEX_T* const - neighbor_graph_ptr, // [dataset_size, graph_degree] - const std::uint32_t graph_degree, - const SourceIndexT* source_indices_ptr, - const typename DATASET_DESCRIPTOR_T::DATA_T* query_ptr, // [num_queries, data_dim] - typename DATASET_DESCRIPTOR_T::INDEX_T* const - visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] - const std::uint32_t hash_bitlen, - typename DATASET_DESCRIPTOR_T::INDEX_T* const result_indices_ptr, // 
[num_queries, ldd] - typename DATASET_DESCRIPTOR_T::DISTANCE_T* const result_distances_ptr, // [num_queries, ldd] - const std::uint32_t ldd, // (*) ldd >= search_width * graph_degree - SAMPLE_FILTER_T sample_filter) -{ - using INDEX_T = typename DATASET_DESCRIPTOR_T::INDEX_T; - using DISTANCE_T = typename DATASET_DESCRIPTOR_T::DISTANCE_T; - - const auto team_size_bits = dataset_desc->team_size_bitshift(); - const auto team_size = 1u << team_size_bits; - const uint32_t ldb = hashmap::get_size(hash_bitlen); - const auto tid = threadIdx.x + blockDim.x * blockIdx.x; - const auto global_team_id = tid >> team_size_bits; - const auto query_id = blockIdx.y; - - extern __shared__ uint8_t smem[]; - // Load a query - dataset_desc = dataset_desc->setup_workspace(smem, query_ptr, query_id); - - __syncthreads(); - if (global_team_id >= search_width * graph_degree) { return; } - - const std::size_t parent_list_index = - parent_node_list[global_team_id / graph_degree + (search_width * blockIdx.y)]; - - if (parent_list_index == utils::get_max_value()) { return; } - - constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; - const auto raw_parent_index = parent_candidates_ptr[parent_list_index + (lds * query_id)]; - - if (raw_parent_index == utils::get_max_value()) { - result_distances_ptr[ldd * blockIdx.y + global_team_id] = utils::get_max_value(); - return; - } - const auto parent_index = raw_parent_index & ~index_msb_1_mask; - - const auto neighbor_list_head_ptr = neighbor_graph_ptr + (graph_degree * parent_index); - - const std::size_t child_id = neighbor_list_head_ptr[global_team_id % graph_degree]; - - const auto compute_distance_flag = hashmap::insert( - team_size, visited_hashmap_ptr + (ldb * blockIdx.y), hash_bitlen, child_id); - - DISTANCE_T norm2 = dataset_desc->compute_distance(child_id, compute_distance_flag); - - if (compute_distance_flag) { - if ((threadIdx.x & (team_size - 1)) == 0) { - result_indices_ptr[ldd * blockIdx.y + global_team_id] = child_id; - result_distances_ptr[ldd * blockIdx.y + global_team_id] = norm2; - } - } else { - if ((threadIdx.x & (team_size - 1)) == 0) { - result_distances_ptr[ldd * blockIdx.y + global_team_id] = utils::get_max_value(); - } - } - - if constexpr (!std::is_same::value) { - if (!sample_filter( - query_id, - source_indices_ptr == nullptr ? 
parent_index : source_indices_ptr[parent_index])) { - parent_candidates_ptr[parent_list_index + (lds * query_id)] = utils::get_max_value(); - parent_distance_ptr[parent_list_index + (lds * query_id)] = - utils::get_max_value(); - } - } -} - -template -void compute_distance_to_child_nodes( - const IndexT* parent_node_list, // [num_queries, search_width] - IndexT* const parent_candidates_ptr, // [num_queries, search_width] - DistanceT* const parent_distance_ptr, // [num_queries, search_width] - std::size_t lds, - uint32_t search_width, - const dataset_descriptor_host& dataset_desc, - const IndexT* neighbor_graph_ptr, // [dataset_size, graph_degree] - std::uint32_t graph_degree, - const SourceIndexT* source_indices_ptr, - const DataT* query_ptr, // [num_queries, data_dim] - std::uint32_t num_queries, - IndexT* visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] - std::uint32_t hash_bitlen, - IndexT* result_indices_ptr, // [num_queries, ldd] - DistanceT* result_distances_ptr, // [num_queries, ldd] - std::uint32_t ldd, // (*) ldd >= search_width * graph_degree - SAMPLE_FILTER_T sample_filter, - cudaStream_t cuda_stream) -{ - const auto block_size = 128; - const auto teams_per_block = block_size / dataset_desc.team_size; - const dim3 grid_size((search_width * graph_degree + teams_per_block - 1) / teams_per_block, - num_queries); - - compute_distance_to_child_nodes_kernel<<>>(parent_node_list, - parent_candidates_ptr, - parent_distance_ptr, - lds, - search_width, - dataset_desc.dev_ptr(cuda_stream), - neighbor_graph_ptr, - graph_degree, - source_indices_ptr, - query_ptr, - visited_hashmap_ptr, - hash_bitlen, - result_indices_ptr, - result_distances_ptr, - ldd, - sample_filter); -} - template RAFT_KERNEL remove_parent_bit_kernel(const std::uint32_t num_queries, const std::uint32_t num_topk, @@ -462,59 +223,6 @@ void remove_parent_bit(const std::uint32_t num_queries, num_queries, num_topk, topk_indices_ptr, ld); } -// This function called after the `remove_parent_bit` function -template -RAFT_KERNEL apply_filter_kernel( - const SourceIndexT* source_indices_ptr, // [num_queries, search_width] - INDEX_T* const result_indices_ptr, - DISTANCE_T* const result_distances_ptr, - const std::size_t lds, - const std::uint32_t result_buffer_size, - const std::uint32_t num_queries, - const INDEX_T query_id_offset, - SAMPLE_FILTER_T sample_filter) -{ - constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; - const auto tid = threadIdx.x + blockIdx.x * blockDim.x; - if (tid >= result_buffer_size * num_queries) { return; } - const auto i = tid % result_buffer_size; - const auto j = tid / result_buffer_size; - const auto index = i + j * lds; - - if (result_indices_ptr[index] != ~index_msb_1_mask && - !sample_filter(query_id_offset + j, - source_indices_ptr == nullptr - ? 
result_indices_ptr[index] - : source_indices_ptr[result_indices_ptr[index]])) { - result_indices_ptr[index] = utils::get_max_value(); - result_distances_ptr[index] = utils::get_max_value(); - } -} - -template -void apply_filter(const SourceIndexT* source_indices_ptr, - INDEX_T* const result_indices_ptr, - DISTANCE_T* const result_distances_ptr, - const std::size_t lds, - const std::uint32_t result_buffer_size, - const std::uint32_t num_queries, - const INDEX_T query_id_offset, - SAMPLE_FILTER_T sample_filter, - cudaStream_t cuda_stream) -{ - const std::uint32_t block_size = 256; - const std::uint32_t grid_size = raft::ceildiv(num_queries * result_buffer_size, block_size); - - apply_filter_kernel<<>>(source_indices_ptr, - result_indices_ptr, - result_distances_ptr, - lds, - result_buffer_size, - num_queries, - query_id_offset, - sample_filter); -} - template RAFT_KERNEL batched_memcpy_kernel(T* const dst, // [batch_size, ld_dst] const uint64_t ld_dst, @@ -547,34 +255,6 @@ void batched_memcpy(T* const dst, // [batch_size, ld_dst] <<>>(dst, ld_dst, src, ld_src, count, batch_size); } -template -RAFT_KERNEL set_value_batch_kernel(T* const dev_ptr, - const std::size_t ld, - const T val, - const std::size_t count, - const std::size_t batch_size) -{ - const auto tid = threadIdx.x + blockIdx.x * blockDim.x; - if (tid >= count * batch_size) { return; } - const auto batch_id = tid / count; - const auto elem_id = tid % count; - dev_ptr[elem_id + ld * batch_id] = val; -} - -template -void set_value_batch(T* const dev_ptr, - const std::size_t ld, - const T val, - const std::size_t count, - const std::size_t batch_size, - cudaStream_t cuda_stream) -{ - constexpr std::uint32_t block_size = 256; - const auto grid_size = (count * batch_size + block_size - 1) / block_size; - set_value_batch_kernel - <<>>(dev_ptr, ld, val, count, batch_size); -} - // result_buffer (work buffer) for "multi-kernel" // +--------------------+------------------------------+-------------------+ // | internal_top_k (A) | neighbors of internal_top_k | internal_topk (B) | @@ -634,18 +314,18 @@ struct search using base_type::num_seeds; size_t result_buffer_allocation_size; - lightweight_uvector result_indices; // results_indices_buffer - lightweight_uvector result_distances; // result_distances_buffer - lightweight_uvector parent_node_list; - lightweight_uvector topk_hint; - lightweight_uvector terminate_flag; // dev_terminate_flag, host_terminate_flag.; - lightweight_uvector topk_workspace; + rmm::device_uvector result_indices; // results_indices_buffer + rmm::device_uvector result_distances; // result_distances_buffer + rmm::device_uvector parent_node_list; + rmm::device_uvector topk_hint; + rmm::device_uvector terminate_flag; // dev_terminate_flag, host_terminate_flag.; + rmm::device_uvector topk_workspace; // temporary storage for _find_topk - lightweight_uvector input_keys_storage; - lightweight_uvector output_keys_storage; - lightweight_uvector input_values_storage; - lightweight_uvector output_values_storage; + rmm::device_uvector input_keys_storage; + rmm::device_uvector output_keys_storage; + rmm::device_uvector input_values_storage; + rmm::device_uvector output_values_storage; search(raft::resources const& res, search_params params, @@ -655,16 +335,16 @@ struct search int64_t graph_degree, uint32_t topk) : base_type(res, params, dataset_desc, dim, dataset_size, graph_degree, topk), - result_indices(res), - result_distances(res), - parent_node_list(res), - topk_hint(res), - topk_workspace(res), - terminate_flag(res), - 
input_keys_storage(res), - output_keys_storage(res), - input_values_storage(res), - output_values_storage(res) + result_indices(0, raft::resource::get_cuda_stream(res)), + result_distances(0, raft::resource::get_cuda_stream(res)), + parent_node_list(0, raft::resource::get_cuda_stream(res)), + topk_hint(0, raft::resource::get_cuda_stream(res)), + topk_workspace(0, raft::resource::get_cuda_stream(res)), + terminate_flag(0, raft::resource::get_cuda_stream(res)), + input_keys_storage(0, raft::resource::get_cuda_stream(res)), + output_keys_storage(0, raft::resource::get_cuda_stream(res)), + input_values_storage(0, raft::resource::get_cuda_stream(res)), + output_values_storage(0, raft::resource::get_cuda_stream(res)) { set_params(res); } @@ -818,21 +498,29 @@ struct search } // Choose initial entry point candidates at random - random_pickup(dataset_desc, - queries_ptr, - num_queries, - result_buffer_size, - num_random_samplings, - rand_xor_mask, - dev_seed_ptr, - num_seeds, - result_indices.data(), - result_distances.data(), - result_buffer_allocation_size, - hashmap.data(), - hash_bitlen, - stream, - static_cast(this->dataset_size)); + random_pickup_jit(dataset_desc, + queries_ptr, + num_queries, + result_buffer_size, + num_random_samplings, + rand_xor_mask, + dev_seed_ptr, + num_seeds, + result_indices.data(), + result_distances.data(), + result_buffer_allocation_size, + hashmap.data(), + hash_bitlen, + stream, + static_cast(this->dataset_size)); + + std::shared_ptr compute_distance_to_child_nodes_launcher = + make_cagra_multi_kernel_jit_launcher>( + dataset_desc, "compute_distance_to_child_nodes"); unsigned iter = 0; while (1) { @@ -864,6 +552,7 @@ struct search // pickup parent nodes uint32_t _small_hash_bitlen = 0; if ((iter + 1) % small_hash_reset_interval == 0) { _small_hash_bitlen = small_hash_bitlen; } + pickup_next_parents(result_indices.data() + (1 - (iter & 0x1)) * result_buffer_size, result_buffer_allocation_size, itopk_size, @@ -878,13 +567,15 @@ struct search stream); // termination (2) - if (iter + 1 >= min_iterations && get_value(terminate_flag.data(), stream)) { - iter++; - break; + if (iter + 1 >= min_iterations) { + if (get_value(terminate_flag.data(), stream)) { + iter++; + break; + } } // Compute distance to child nodes that are adjacent to the parent node - compute_distance_to_child_nodes( + compute_distance_to_child_nodes_jit( parent_node_list.data(), result_indices.data() + (1 - (iter & 0x1)) * result_buffer_size, result_distances.data() + (1 - (iter & 0x1)) * result_buffer_size, @@ -902,7 +593,8 @@ struct search result_distances.data() + itopk_size, result_buffer_allocation_size, sample_filter, - stream); + stream, + compute_distance_to_child_nodes_launcher); iter++; } // while ( 1 ) @@ -918,7 +610,28 @@ struct search result_buffer_allocation_size, stream); - apply_filter( + using apply_filter_data_tag = decltype(get_data_type_tag()); + using apply_filter_index_tag = decltype(get_index_type_tag()); + using apply_filter_dist_tag = decltype(get_distance_type_tag()); + using apply_filter_source_tag = decltype(get_source_index_type_tag()); + using apply_filter_query_tag = + query_type_tag_standard_t; + using apply_filter_codebook_tag = tag_codebook_none; + CagraMultiKernelSearchPlanner> + apply_filter_planner( + cuvs::distance::DistanceType::L2Expanded, "apply_filter_kernel", 8, 128, false, 0, 0); + apply_filter_planner.add_sample_filter_device_function(); + apply_filter_planner.add_linked_kernel("apply_filter_kernel"); + std::shared_ptr apply_filter_launcher = + 
apply_filter_planner.get_launcher(); + + apply_filter_jit( source_indices_ptr, result_indices.data() + (iter & 0x1) * itopk_size, result_distances.data() + (iter & 0x1) * itopk_size, @@ -927,7 +640,8 @@ struct search num_queries, 0, sample_filter, - stream); + stream, + apply_filter_launcher); result_indices_ptr = result_indices.data() + (1 - (iter & 0x1)) * result_buffer_size; result_distances_ptr = result_distances.data() + (1 - (iter & 0x1)) * result_buffer_size; @@ -988,7 +702,6 @@ struct search num_executed_iterations[i] = iter; } } - RAFT_CUDA_TRY(cudaPeekAtLastError()); } }; diff --git a/cpp/src/neighbors/detail/cagra/search_multi_kernel_launcher_jit.cuh b/cpp/src/neighbors/detail/cagra/search_multi_kernel_launcher_jit.cuh new file mode 100644 index 0000000000..d0e3db2c99 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/search_multi_kernel_launcher_jit.cuh @@ -0,0 +1,190 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +// Tags header should be included before this header (at file scope, not inside functions) +// to avoid namespace definition errors when this header is included inside function bodies + +#include "compute_distance.hpp" // For dataset_descriptor_host +#include "jit_lto_kernels/cagra_jit_launcher_factory.hpp" +#include "jit_lto_kernels/kernel_def.hpp" +#include "jit_lto_kernels/search_multi_kernel_planner.hpp" +#include "search_plan.cuh" // For search_params +#include "shared_launcher_jit.hpp" // cagra_bitset / cagra_sample_filter, sample_filter_jit_tag_t, tags +#include +#include +#include +#include + +#include +#include +#include +// - The launcher doesn't need the kernel function definitions +// - The kernel is dispatched via the JIT LTO launcher system +// - Including it would pull in impl files that cause namespace issues + +namespace cuvs::neighbors::cagra::detail::multi_kernel_search { + +// JIT version of random_pickup +template +void random_pickup_jit(const dataset_descriptor_host& dataset_desc, + const DataT* queries_ptr, // [num_queries, dataset_dim] + std::size_t num_queries, + std::size_t num_pickup, + unsigned num_distilation, + uint64_t rand_xor_mask, + const IndexT* seed_ptr, // [num_queries, num_seeds] + uint32_t num_seeds, + IndexT* result_indices_ptr, // [num_queries, ldr] + DistanceT* result_distances_ptr, // [num_queries, ldr] + std::size_t ldr, // (*) ldr >= num_pickup + IndexT* visited_hashmap_ptr, // [num_queries, 1 << bitlen] + std::uint32_t hash_bitlen, + cudaStream_t cuda_stream, + IndexT graph_size) +{ + std::shared_ptr launcher = + make_cagra_multi_kernel_jit_launcher(dataset_desc, + "random_pickup"); + + const auto block_size = 256u; + const auto num_teams_per_threadblock = block_size / dataset_desc.team_size; + const dim3 grid_size((num_pickup + num_teams_per_threadblock - 1) / num_teams_per_threadblock, + num_queries); + + // Get the device descriptor pointer + const auto* dev_desc = dataset_desc.dev_ptr(cuda_stream); + + // Cast size_t parameters to match kernel signature exactly + // The dispatch mechanism uses void* pointers, so parameter sizes must match exactly + const uint32_t ldr_u32 = static_cast(ldr); + + launcher->dispatch>( + cuda_stream, + grid_size, + dim3(block_size, 1, 1), + dataset_desc.smem_ws_size_in_bytes, + dev_desc, + queries_ptr, + num_pickup, + num_distilation, + rand_xor_mask, + seed_ptr, + num_seeds, + result_indices_ptr, + result_distances_ptr, + ldr_u32, + visited_hashmap_ptr, + hash_bitlen, + graph_size); + + 
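The casts above exist because the JIT launcher forwards arguments through an array of type-erased pointers (cuLaunchKernel-style), so each argument object must already have exactly the width the compiled kernel parameter expects. A small host-side sketch of the failure mode this avoids, using a hypothetical takes_u32 kernel rather than the launcher's actual dispatch path:

#include <cstdint>
#include <cstdio>

__global__ void takes_u32(uint32_t n, uint32_t* out) { *out = n; }

int main()
{
  uint32_t* d_out = nullptr;
  cudaMalloc(&d_out, sizeof(uint32_t));

  size_t   n_wide   = 42;                             // 8 bytes on the host side
  uint32_t n_narrow = static_cast<uint32_t>(n_wide);  // 4 bytes, matches the kernel parameter

  // cudaLaunchKernel copies, for each parameter, exactly sizeof(parameter type)
  // bytes from args[i]. If the object behind args[i] does not have that exact
  // size (for example passing &n_wide for a uint32_t parameter), the wrong bytes
  // can be read, so every value is narrowed to the declared parameter type first.
  void* args[] = { &n_narrow, &d_out };
  cudaLaunchKernel((void*)takes_u32, dim3(1), dim3(1), args, 0, nullptr);
  cudaDeviceSynchronize();

  uint32_t h_out = 0;
  cudaMemcpy(&h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
  printf("kernel saw %u\n", h_out);
  cudaFree(d_out);
  return 0;
}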
RAFT_CUDA_TRY(cudaPeekAtLastError()); +} + +// JIT version of compute_distance_to_child_nodes +template +void compute_distance_to_child_nodes_jit( + const IndexT* parent_node_list, // [num_queries, search_width] + IndexT* const parent_candidates_ptr, // [num_queries, search_width] + DistanceT* const parent_distance_ptr, // [num_queries, search_width] + std::size_t lds, + uint32_t search_width, + const dataset_descriptor_host& dataset_desc, + const IndexT* neighbor_graph_ptr, // [dataset_size, graph_degree] + std::uint32_t graph_degree, + const SourceIndexT* source_indices_ptr, + const DataT* query_ptr, // [num_queries, data_dim] + std::uint32_t num_queries, + IndexT* visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] + std::uint32_t hash_bitlen, + IndexT* result_indices_ptr, // [num_queries, ldd] + DistanceT* result_distances_ptr, // [num_queries, ldd] + std::uint32_t ldd, // (*) ldd >= search_width * graph_degree + SAMPLE_FILTER_T sample_filter, + cudaStream_t cuda_stream, + std::shared_ptr const& launcher) +{ + const auto bf = extract_cagra_sample_filter(sample_filter); + + const auto block_size = 128; + const auto teams_per_block = block_size / dataset_desc.team_size; + const dim3 grid_size((search_width * graph_degree + teams_per_block - 1) / teams_per_block, + num_queries); + + // Get the device descriptor pointer + const auto* dev_desc = dataset_desc.dev_ptr(cuda_stream); + + launcher->dispatch< + compute_distance_to_child_nodes_kernel_func_t>( + cuda_stream, + grid_size, + dim3(block_size, 1, 1), + dataset_desc.smem_ws_size_in_bytes, + parent_node_list, + parent_candidates_ptr, + parent_distance_ptr, + lds, + search_width, + dev_desc, + neighbor_graph_ptr, + graph_degree, + source_indices_ptr, + query_ptr, + visited_hashmap_ptr, + hash_bitlen, + result_indices_ptr, + result_distances_ptr, + ldd, + bf.bitset); + + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} + +// JIT version of apply_filter +template +void apply_filter_jit(const SourceIndexT* source_indices_ptr, + INDEX_T* const result_indices_ptr, + DISTANCE_T* const result_distances_ptr, + const std::size_t lds, + const std::uint32_t result_buffer_size, + const std::uint32_t num_queries, + const std::uint32_t query_id_offset, + SAMPLE_FILTER_T sample_filter, + cudaStream_t cuda_stream, + std::shared_ptr const& launcher) +{ + // Note: query_id for the linked filter is the function's `query_id_offset` + query index, not + // the wrapper's offset; we only need bitset pointers (same as other JIT launchers). + const auto bf = extract_cagra_sample_filter(sample_filter); + + const std::uint32_t block_size = 256; + const std::uint32_t grid_size = raft::ceildiv(num_queries * result_buffer_size, block_size); + + // Alias avoids nested `dispatch< alias_template<...>>` which NVCC can misparse as + // comparison/shift. 
+ using apply_filter_kernel_func_t = apply_filter_kernel_func_t; + // `template` required: in template code, `->dispatch<...>` is otherwise parsed as `dispatch <` … + launcher->template dispatch(cuda_stream, + dim3(grid_size, 1, 1), + dim3(block_size, 1, 1), + 0, + source_indices_ptr, + result_indices_ptr, + result_distances_ptr, + lds, + result_buffer_size, + num_queries, + query_id_offset, + bf.bitset); + + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} + +} // namespace cuvs::neighbors::cagra::detail::multi_kernel_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta.cuh b/cpp/src/neighbors/detail/cagra/search_single_cta.cuh index 02bf1ff697..74e34e0a14 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta.cuh +++ b/cpp/src/neighbors/detail/cagra/search_single_cta.cuh @@ -4,8 +4,9 @@ */ #pragma once +#include "../neighbors_device_intrinsics.cuh" #include "bitonic.hpp" -#include "device_common.hpp" +#include "device_memory_ops.hpp" #include "hashmap.hpp" #include "search_plan.cuh" #include "search_single_cta_kernel.cuh" @@ -34,6 +35,7 @@ #include #include +// All includes are done before opening namespace to avoid nested namespace issues namespace cuvs::neighbors::cagra::detail { namespace single_cta_search { diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_inst.cuh b/cpp/src/neighbors/detail/cagra/search_single_cta_inst.cuh index 11b468cfca..d242e13b95 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_inst.cuh +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_inst.cuh @@ -1,13 +1,16 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once -#include "search_single_cta_kernel-inl.cuh" #include +// Include explicit instantiations before namespace (launcher includes JIT LTO headers with +// namespace definitions) +#include "search_single_cta_kernel_explicit_inst.cuh" + namespace cuvs::neighbors::cagra::detail::single_cta_search { #define instantiate_kernel_selection(DataT, IndexT, DistanceT, SampleFilterT) \ diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh deleted file mode 100644 index 48553611bf..0000000000 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh +++ /dev/null @@ -1,2355 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. 
- * SPDX-License-Identifier: Apache-2.0 - */ -#pragma once - -#include "search_single_cta_kernel.cuh" - -#include "bitonic.hpp" -#include "device_common.hpp" -#include "hashmap.hpp" -#include "search_plan.cuh" -#include "topk_by_radix.cuh" -#include "topk_for_cagra/topk.h" // TODO replace with raft topk -#include "utils.hpp" -#include - -#include -#include -#include -#include -#include -#include - -#include - -// TODO: This shouldn't be invoking anything from spatial/knn -#include "../ann_utils.cuh" -#include "../smem_utils.cuh" - -#include -#include -#include - -#include -#include -#include -#include - -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace cuvs::neighbors::cagra::detail { -namespace single_cta_search { - -// #define _CLK_BREAKDOWN - -template -RAFT_DEVICE_INLINE_FUNCTION void pickup_next_parents(std::uint32_t* const terminate_flag, - INDEX_T* const next_parent_indices, - INDEX_T* const internal_topk_indices, - const std::size_t internal_topk_size, - const std::uint32_t search_width) -{ - constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; - // if (threadIdx.x >= 32) return; - - for (std::uint32_t i = threadIdx.x; i < search_width; i += 32) { - next_parent_indices[i] = utils::get_max_value(); - } - std::uint32_t itopk_max = internal_topk_size; - if (itopk_max % 32) { itopk_max += 32 - (itopk_max % 32); } - std::uint32_t num_new_parents = 0; - for (std::uint32_t j = threadIdx.x; j < itopk_max; j += 32) { - std::uint32_t jj = j; - if (TOPK_BY_BITONIC_SORT) { jj = device::swizzling(j); } - INDEX_T index; - int new_parent = 0; - if (j < internal_topk_size) { - index = internal_topk_indices[jj]; - if ((index & index_msb_1_mask) == 0) { // check if most significant bit is set - new_parent = 1; - } - } - const std::uint32_t ballot_mask = __ballot_sync(0xffffffff, new_parent); - if (new_parent) { - const auto i = __popc(ballot_mask & ((1 << threadIdx.x) - 1)) + num_new_parents; - if (i < search_width) { - next_parent_indices[i] = jj; - // set most significant bit as used node - internal_topk_indices[jj] |= index_msb_1_mask; - } - } - num_new_parents += __popc(ballot_mask); - if (num_new_parents >= search_width) { break; } - } - if (threadIdx.x == 0 && (num_new_parents == 0)) { *terminate_flag = 1; } -} - -template -RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_full( - float* candidate_distances, // [num_candidates] - IdxT* candidate_indices, // [num_candidates] - const std::uint32_t num_candidates, - const std::uint32_t num_itopk) -{ - const unsigned lane_id = threadIdx.x % raft::warp_size(); - const unsigned warp_id = threadIdx.x / raft::warp_size(); - static_assert(MAX_CANDIDATES <= 256); - if constexpr (!MULTI_WARPS) { - if (warp_id > 0) { return; } - constexpr unsigned N = (MAX_CANDIDATES + (raft::warp_size() - 1)) / raft::warp_size(); - float key[N]; - IdxT val[N]; - /* Candidates -> Reg */ - for (unsigned i = 0; i < N; i++) { - unsigned j = lane_id + (raft::warp_size() * i); - if (j < num_candidates) { - key[i] = candidate_distances[j]; - val[i] = candidate_indices[j]; - } else { - key[i] = utils::get_max_value(); - val[i] = utils::get_max_value(); - } - } - /* Sort */ - bitonic::warp_sort(key, val); - /* Reg -> Temp_itopk */ - for (unsigned i = 0; i < N; i++) { - unsigned j = (N * lane_id) + i; - if (j < num_candidates && j < num_itopk) { - candidate_distances[device::swizzling(j)] = key[i]; - candidate_indices[device::swizzling(j)] = val[i]; 
- } - } - } else { - assert(blockDim.x >= 64); - // Use two warps (64 threads) - constexpr unsigned max_candidates_per_warp = (MAX_CANDIDATES + 1) / 2; - static_assert(max_candidates_per_warp <= 128); - constexpr unsigned N = (max_candidates_per_warp + (raft::warp_size() - 1)) / raft::warp_size(); - float key[N]; - IdxT val[N]; - if (warp_id < 2) { - /* Candidates -> Reg */ - for (unsigned i = 0; i < N; i++) { - unsigned jl = lane_id + (raft::warp_size() * i); - unsigned j = jl + (max_candidates_per_warp * warp_id); - if (j < num_candidates) { - key[i] = candidate_distances[j]; - val[i] = candidate_indices[j]; - } else { - key[i] = utils::get_max_value(); - val[i] = utils::get_max_value(); - } - } - /* Sort */ - bitonic::warp_sort(key, val); - /* Reg -> Temp_candidates */ - for (unsigned i = 0; i < N; i++) { - unsigned jl = (N * lane_id) + i; - unsigned j = jl + (max_candidates_per_warp * warp_id); - if (j < num_candidates && jl < num_itopk) { - candidate_distances[device::swizzling(j)] = key[i]; - candidate_indices[device::swizzling(j)] = val[i]; - } - } - } - __syncthreads(); - - unsigned num_warps_used = (num_itopk + max_candidates_per_warp - 1) / max_candidates_per_warp; - if (warp_id < num_warps_used) { - /* Temp_candidates -> Reg */ - for (unsigned i = 0; i < N; i++) { - unsigned jl = (N * lane_id) + i; - unsigned kl = max_candidates_per_warp - 1 - jl; - unsigned j = jl + (max_candidates_per_warp * warp_id); - unsigned k = MAX_CANDIDATES - 1 - j; - if (j >= num_candidates || k >= num_candidates || kl >= num_itopk) continue; - float temp_key = candidate_distances[device::swizzling(k)]; - if (key[i] == temp_key) continue; - if ((warp_id == 0) == (key[i] > temp_key)) { - key[i] = temp_key; - val[i] = candidate_indices[device::swizzling(k)]; - } - } - } - if (num_warps_used > 1) { __syncthreads(); } - if (warp_id < num_warps_used) { - /* Merge */ - bitonic::warp_merge(key, val, raft::warp_size()); - /* Reg -> Temp_itopk */ - for (unsigned i = 0; i < N; i++) { - unsigned jl = (N * lane_id) + i; - unsigned j = jl + (max_candidates_per_warp * warp_id); - if (j < num_candidates && j < num_itopk) { - candidate_distances[device::swizzling(j)] = key[i]; - candidate_indices[device::swizzling(j)] = val[i]; - } - } - } - if (num_warps_used > 1) { __syncthreads(); } - } -} - -template -RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_merge( - float* itopk_distances, // [num_itopk] - IdxT* itopk_indices, // [num_itopk] - const std::uint32_t num_itopk, - float* candidate_distances, // [num_candidates] - IdxT* candidate_indices, // [num_candidates] - const std::uint32_t num_candidates, - std::uint32_t* work_buf, - const bool first) -{ - const unsigned lane_id = threadIdx.x % raft::warp_size(); - const unsigned warp_id = threadIdx.x / raft::warp_size(); - - static_assert(MAX_ITOPK <= 512); - if constexpr (!MULTI_WARPS) { - static_assert(MAX_ITOPK <= 256); - if (warp_id > 0) { return; } - constexpr unsigned N = (MAX_ITOPK + (raft::warp_size() - 1)) / raft::warp_size(); - float key[N]; - IdxT val[N]; - if (first) { - /* Load itopk results */ - for (unsigned i = 0; i < N; i++) { - unsigned j = lane_id + (raft::warp_size() * i); - if (j < num_itopk) { - key[i] = itopk_distances[j]; - val[i] = itopk_indices[j]; - } else { - key[i] = utils::get_max_value(); - val[i] = utils::get_max_value(); - } - } - /* Warp Sort */ - bitonic::warp_sort(key, val); - } else { - /* Load itopk results */ - for (unsigned i = 0; i < N; i++) { - unsigned j = (N * lane_id) + i; - if (j < num_itopk) { - key[i] = 
itopk_distances[device::swizzling(j)]; - val[i] = itopk_indices[device::swizzling(j)]; - } else { - key[i] = utils::get_max_value(); - val[i] = utils::get_max_value(); - } - } - } - /* Merge candidates */ - for (unsigned i = 0; i < N; i++) { - unsigned j = (N * lane_id) + i; // [0:max_itopk-1] - unsigned k = MAX_ITOPK - 1 - j; - if (k >= num_itopk || k >= num_candidates) continue; - float candidate_key = candidate_distances[device::swizzling(k)]; - if (key[i] > candidate_key) { - key[i] = candidate_key; - val[i] = candidate_indices[device::swizzling(k)]; - } - } - /* Warp Merge */ - bitonic::warp_merge(key, val, raft::warp_size()); - /* Store new itopk results */ - for (unsigned i = 0; i < N; i++) { - unsigned j = (N * lane_id) + i; - if (j < num_itopk) { - itopk_distances[device::swizzling(j)] = key[i]; - itopk_indices[device::swizzling(j)] = val[i]; - } - } - } else { - static_assert(MAX_ITOPK == 512); - assert(blockDim.x >= 64); - // Use two warps (64 threads) or more - constexpr unsigned max_itopk_per_warp = (MAX_ITOPK + 1) / 2; - constexpr unsigned N = (max_itopk_per_warp + (raft::warp_size() - 1)) / raft::warp_size(); - float key[N]; - IdxT val[N]; - if (first) { - /* Load itop results (not sorted) */ - if (warp_id < 2) { - for (unsigned i = 0; i < N; i++) { - unsigned j = lane_id + (raft::warp_size() * i) + (max_itopk_per_warp * warp_id); - if (j < num_itopk) { - key[i] = itopk_distances[j]; - val[i] = itopk_indices[j]; - } else { - key[i] = utils::get_max_value(); - val[i] = utils::get_max_value(); - } - } - /* Warp Sort */ - bitonic::warp_sort(key, val); - /* Store intermedidate results */ - for (unsigned i = 0; i < N; i++) { - unsigned j = (N * threadIdx.x) + i; - if (j >= num_itopk) continue; - itopk_distances[device::swizzling(j)] = key[i]; - itopk_indices[device::swizzling(j)] = val[i]; - } - } - __syncthreads(); - if (warp_id < 2) { - /* Load intermedidate results */ - for (unsigned i = 0; i < N; i++) { - unsigned j = (N * threadIdx.x) + i; - unsigned k = MAX_ITOPK - 1 - j; - if (k >= num_itopk) continue; - float temp_key = itopk_distances[device::swizzling(k)]; - if (key[i] == temp_key) continue; - if ((warp_id == 0) == (key[i] > temp_key)) { - key[i] = temp_key; - val[i] = itopk_indices[device::swizzling(k)]; - } - } - /* Warp Merge */ - bitonic::warp_merge(key, val, raft::warp_size()); - } - __syncthreads(); - /* Store itopk results (sorted) */ - if (warp_id < 2) { - for (unsigned i = 0; i < N; i++) { - unsigned j = (N * threadIdx.x) + i; - if (j >= num_itopk) continue; - itopk_distances[device::swizzling(j)] = key[i]; - itopk_indices[device::swizzling(j)] = val[i]; - } - } - } - const uint32_t num_itopk_div2 = num_itopk / 2; - if (threadIdx.x < 3) { - // work_buf is used to obtain turning points in 1st and 2nd half of itopk afer merge. 
- work_buf[threadIdx.x] = num_itopk_div2; - } - __syncthreads(); - - // Merge candidates (using whole threads) - for (unsigned k = threadIdx.x; k < min(num_candidates, num_itopk); k += blockDim.x) { - const unsigned j = num_itopk - 1 - k; - const float itopk_key = itopk_distances[device::swizzling(j)]; - const float candidate_key = candidate_distances[device::swizzling(k)]; - if (itopk_key > candidate_key) { - itopk_distances[device::swizzling(j)] = candidate_key; - itopk_indices[device::swizzling(j)] = candidate_indices[device::swizzling(k)]; - if (j < num_itopk_div2) { - atomicMin(work_buf + 2, j); - } else { - atomicMin(work_buf + 1, j - num_itopk_div2); - } - } - } - __syncthreads(); - - // Merge 1st and 2nd half of itopk (using whole threads) - for (unsigned j = threadIdx.x; j < num_itopk_div2; j += blockDim.x) { - const unsigned k = j + num_itopk_div2; - float key_0 = itopk_distances[device::swizzling(j)]; - float key_1 = itopk_distances[device::swizzling(k)]; - if (key_0 > key_1) { - itopk_distances[device::swizzling(j)] = key_1; - itopk_distances[device::swizzling(k)] = key_0; - IdxT val_0 = itopk_indices[device::swizzling(j)]; - IdxT val_1 = itopk_indices[device::swizzling(k)]; - itopk_indices[device::swizzling(j)] = val_1; - itopk_indices[device::swizzling(k)] = val_0; - atomicMin(work_buf + 0, j); - } - } - if (threadIdx.x == blockDim.x - 1) { - if (work_buf[2] < num_itopk_div2) { work_buf[1] = work_buf[2]; } - } - __syncthreads(); - // if ((blockIdx.x == 0) && (threadIdx.x == 0)) { - // RAFT_LOG_DEBUG( "work_buf: %u, %u, %u\n", work_buf[0], work_buf[1], work_buf[2] ); - // } - - // Warp-0 merges 1st half of itopk, warp-1 does 2nd half. - if (warp_id < 2) { - // Load intermedidate itopk results - const uint32_t turning_point = work_buf[warp_id]; // turning_point <= num_itopk_div2 - for (unsigned i = 0; i < N; i++) { - unsigned k = num_itopk; - unsigned j = (N * lane_id) + i; - if (j < turning_point) { - k = j + (num_itopk_div2 * warp_id); - } else if (j >= (MAX_ITOPK / 2 - num_itopk_div2)) { - j -= (MAX_ITOPK / 2 - num_itopk_div2); - if ((turning_point <= j) && (j < num_itopk_div2)) { k = j + (num_itopk_div2 * warp_id); } - } - if (k < num_itopk) { - key[i] = itopk_distances[device::swizzling(k)]; - val[i] = itopk_indices[device::swizzling(k)]; - } else { - key[i] = utils::get_max_value(); - val[i] = utils::get_max_value(); - } - } - /* Warp Merge */ - bitonic::warp_merge(key, val, raft::warp_size()); - /* Store new itopk results */ - for (unsigned i = 0; i < N; i++) { - const unsigned j = (N * lane_id) + i; - if (j < num_itopk_div2) { - unsigned k = j + (num_itopk_div2 * warp_id); - itopk_distances[device::swizzling(k)] = key[i]; - itopk_indices[device::swizzling(k)] = val[i]; - } - } - } - } -} - -RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_full_wrapper_64_false( - float* candidate_distances, // [num_candidates] - std::uint32_t* candidate_indices, // [num_candidates] - const std::uint32_t num_candidates, - const std::uint32_t num_itopk) -{ - topk_by_bitonic_sort_and_full<64, false, uint32_t>( - candidate_distances, candidate_indices, num_candidates, num_itopk); -} - -RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_full_wrapper_128_false( - float* candidate_distances, // [num_candidates] - std::uint32_t* candidate_indices, // [num_candidates] - const std::uint32_t num_candidates, - const std::uint32_t num_itopk) -{ - topk_by_bitonic_sort_and_full<128, false, uint32_t>( - candidate_distances, candidate_indices, num_candidates, num_itopk); -} - 
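The _64/_128/_256 wrapper functions in this deleted file, like those in the multi-CTA kernel earlier, each pin one fixed instantiation; the original comments note this avoids pre-inlining the large sort body into the search kernel and keeps register pressure in check. A minimal sketch of that non-template-wrapper pattern, with hypothetical names and a trivial insertion sort standing in for the real warp-level bitonic top-k:

// Capacity-templated implementation (stand-in body, not the real bitonic sort).
template <unsigned MaxElems>
__device__ void sort_impl(float* keys, unsigned n)
{
  const unsigned m = n < MaxElems ? n : MaxElems;
  for (unsigned i = 1; i < m; ++i) {
    const float v = keys[i];
    unsigned j   = i;
    for (; j > 0 && keys[j - 1] > v; --j) { keys[j] = keys[j - 1]; }
    keys[j] = v;
  }
}

// Non-template wrappers expose one fixed instantiation each, giving callers a
// plain call target; per the comment in the original code, this avoids
// pre-inlining the sort (post-inlining remains possible) and reduces register
// pressure in the enclosing search kernel.
__device__ void sort_wrapper_64(float* keys, unsigned n) { sort_impl<64>(keys, n); }
__device__ void sort_wrapper_128(float* keys, unsigned n) { sort_impl<128>(keys, n); }
__device__ void sort_wrapper_256(float* keys, unsigned n) { sort_impl<256>(keys, n); }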
-RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_full_wrapper_256_false( - float* candidate_distances, // [num_candidates] - std::uint32_t* candidate_indices, // [num_candidates] - const std::uint32_t num_candidates, - const std::uint32_t num_itopk) -{ - topk_by_bitonic_sort_and_full<256, false, uint32_t>( - candidate_distances, candidate_indices, num_candidates, num_itopk); -} - -RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_merge_wrapper_64_false( - float* itopk_distances, // [num_itopk] - uint32_t* itopk_indices, // [num_itopk] - const std::uint32_t num_itopk, - float* candidate_distances, // [num_candidates] - uint32_t* candidate_indices, // [num_candidates] - const std::uint32_t num_candidates, - std::uint32_t* work_buf, - const bool first) -{ - topk_by_bitonic_sort_and_merge<64, false, uint32_t>(itopk_distances, - itopk_indices, - num_itopk, - candidate_distances, - candidate_indices, - num_candidates, - work_buf, - first); -} - -RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_merge_wrapper_128_false( - float* itopk_distances, // [num_itopk] - uint32_t* itopk_indices, // [num_itopk] - const std::uint32_t num_itopk, - float* candidate_distances, // [num_candidates] - uint32_t* candidate_indices, // [num_candidates] - const std::uint32_t num_candidates, - std::uint32_t* work_buf, - const bool first) -{ - topk_by_bitonic_sort_and_merge<128, false, uint32_t>(itopk_distances, - itopk_indices, - num_itopk, - candidate_distances, - candidate_indices, - num_candidates, - work_buf, - first); -} - -RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_merge_wrapper_256_false( - float* itopk_distances, // [num_itopk] - uint32_t* itopk_indices, // [num_itopk] - const std::uint32_t num_itopk, - float* candidate_distances, // [num_candidates] - uint32_t* candidate_indices, // [num_candidates] - const std::uint32_t num_candidates, - std::uint32_t* work_buf, - const bool first) -{ - topk_by_bitonic_sort_and_merge<256, false, uint32_t>(itopk_distances, - itopk_indices, - num_itopk, - candidate_distances, - candidate_indices, - num_candidates, - work_buf, - first); -} - -template -RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_merge( - float* itopk_distances, // [num_itopk] - IdxT* itopk_indices, // [num_itopk] - const std::uint32_t max_itopk, - const std::uint32_t num_itopk, - float* candidate_distances, // [num_candidates] - IdxT* candidate_indices, // [num_candidates] - const std::uint32_t max_candidates, - const std::uint32_t num_candidates, - std::uint32_t* work_buf, - const bool first) -{ - static_assert(std::is_same_v); - assert(max_itopk <= 512); - assert(max_candidates <= 256); - assert(!MULTI_WARPS || blockDim.x >= 64); - - // use a non-template wrapper function to avoid pre-inlining the topk_by_bitonic_sort_and_full - // function (vs post-inlining, this impacts register pressure) - if (max_candidates <= 64) { - topk_by_bitonic_sort_and_full_wrapper_64_false( - candidate_distances, candidate_indices, num_candidates, num_itopk); - } else if (max_candidates <= 128) { - topk_by_bitonic_sort_and_full_wrapper_128_false( - candidate_distances, candidate_indices, num_candidates, num_itopk); - } else { - topk_by_bitonic_sort_and_full_wrapper_256_false( - candidate_distances, candidate_indices, num_candidates, num_itopk); - } - - if constexpr (!MULTI_WARPS) { - assert(max_itopk <= 256); - // use a non-template wrapper function to avoid pre-inlining the topk_by_bitonic_sort_and_merge - // function (vs post-inlining, this impacts register pressure) - if 
(max_itopk <= 64) { - topk_by_bitonic_sort_and_merge_wrapper_64_false(itopk_distances, - itopk_indices, - num_itopk, - candidate_distances, - candidate_indices, - num_candidates, - work_buf, - first); - } else if (max_itopk <= 128) { - topk_by_bitonic_sort_and_merge_wrapper_128_false(itopk_distances, - itopk_indices, - num_itopk, - candidate_distances, - candidate_indices, - num_candidates, - work_buf, - first); - } else { - topk_by_bitonic_sort_and_merge_wrapper_256_false(itopk_distances, - itopk_indices, - num_itopk, - candidate_distances, - candidate_indices, - num_candidates, - work_buf, - first); - } - } else { - assert(max_itopk > 256); - topk_by_bitonic_sort_and_merge<512, MULTI_WARPS, uint32_t>(itopk_distances, - itopk_indices, - num_itopk, - candidate_distances, - candidate_indices, - num_candidates, - work_buf, - first); - } -} - -// This function move the invalid index element to the end of the itopk list. -// Require : array_length % 32 == 0 && The invalid entry is only one. -template -RAFT_DEVICE_INLINE_FUNCTION void move_invalid_to_end_of_list(IdxT* const index_array, - float* const distance_array, - const std::uint32_t array_length) -{ - constexpr std::uint32_t warp_size = 32; - constexpr std::uint32_t invalid_index = utils::get_max_value(); - const std::uint32_t lane_id = threadIdx.x % warp_size; - - if (threadIdx.x >= warp_size) { return; } - - bool found_invalid = false; - if (array_length % warp_size == 0) { - for (std::uint32_t i = lane_id; i < array_length; i += warp_size) { - const auto index = index_array[i]; - const auto distance = distance_array[i]; - - if (found_invalid) { - index_array[i - 1] = index; - distance_array[i - 1] = distance; - } else { - // Check if the index is invalid - const auto I_found_invalid = (index == invalid_index); - const auto who_has_invalid = raft::ballot(I_found_invalid); - // if a value that is loaded by a smaller lane id thread, shift the array - if (who_has_invalid << (warp_size - lane_id)) { - index_array[i - 1] = index; - distance_array[i - 1] = distance; - } - - found_invalid = who_has_invalid; - } - } - } - if (lane_id == 0) { - index_array[array_length - 1] = invalid_index; - distance_array[array_length - 1] = utils::get_max_value(); - } -} - -template -RAFT_DEVICE_INLINE_FUNCTION void hashmap_restore(INDEX_T* const hashmap_ptr, - const size_t hashmap_bitlen, - const INDEX_T* itopk_indices, - const uint32_t itopk_size, - const uint32_t first_tid = 0) -{ - constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; - if (threadIdx.x < first_tid) return; - for (unsigned i = threadIdx.x - first_tid; i < itopk_size; i += blockDim.x - first_tid) { - auto key = itopk_indices[i] & ~index_msb_1_mask; // clear most significant bit - hashmap::insert(hashmap_ptr, hashmap_bitlen, key); - } -} - -/** - * @brief Search operation for a single query using a single thread block. - * * - * @tparam TOPK_BY_BITONIC_SORT - * @tparam DATASET_DESCRIPTOR_T - * @tparam SAMPLE_FILTER_T - * - * @param result_indices_ptr - * Tagged pointer to the result neighbors [num_queries, top_k]; the tag is the two lower bits to - * identify the index element type (see the code below). - * @param result_distances_ptr Pointer to the result distances buffer [num_queries, top_k]. - * @param top_k Number of top-k results to retrieve. - * @param dataset_desc Pointer to the dataset descriptor. - * @param queries_ptr Pointer to the queries [num_queries, dataset_dim]. - * @param knn_graph Pointer to the k-nearest neighbors graph [dataset_size, graph_degree]. 
- * @param graph_degree Degree of the graph. - * @param num_distilation Number of distillation steps. - * @param rand_xor_mask Random XOR mask for randomization. - * @param seed_ptr Pointer to the seed indices [num_queries, num_seeds]. - * @param num_seeds Number of seeds. - * @param visited_hashmap_ptr - * Pointer to the hashmap of visited nodes [num_queries, 1 << hash_bitlen]. - * @param internal_topk Internal top-k size. - * @param search_width Width of the search. - * @param min_iteration Minimum number of iterations. - * @param max_iteration Maximum number of iterations. - * @param num_executed_iterations Pointer to the number of executed iterations [num_queries]. - * @param hash_bitlen Bit length of the hash. - * @param small_hash_bitlen Bit length of the small hash. - * @param small_hash_reset_interval Interval for resetting the small hash. - * @param query_id sequential id of the query in the batch - */ -template -RAFT_DEVICE_INLINE_FUNCTION void search_core( - uintptr_t result_indices_ptr, // [num_queries, top_k] - typename DATASET_DESCRIPTOR_T::DISTANCE_T* const result_distances_ptr, // [num_queries, top_k] - const std::uint32_t top_k, - const DATASET_DESCRIPTOR_T* dataset_desc, - const typename DATASET_DESCRIPTOR_T::DATA_T* const queries_ptr, // [num_queries, dataset_dim] - const typename DATASET_DESCRIPTOR_T::INDEX_T* const knn_graph, // [dataset_size, graph_degree] - const std::uint32_t graph_degree, - const SourceIndexT* source_indices_ptr, - const unsigned num_distilation, - const uint64_t rand_xor_mask, - const typename DATASET_DESCRIPTOR_T::INDEX_T* seed_ptr, // [num_queries, num_seeds] - const uint32_t num_seeds, - typename DATASET_DESCRIPTOR_T::INDEX_T* const - visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] - const std::uint32_t max_candidates, - const std::uint32_t max_itopk, - const std::uint32_t internal_topk, - const std::uint32_t search_width, - const std::uint32_t min_iteration, - const std::uint32_t max_iteration, - std::uint32_t* const num_executed_iterations, // [num_queries] - const std::uint32_t hash_bitlen, - const std::uint32_t small_hash_bitlen, - const std::uint32_t small_hash_reset_interval, - const std::uint32_t query_id, - SAMPLE_FILTER_T sample_filter, - const typename DATASET_DESCRIPTOR_T::INDEX_T graph_size = 0) -{ - using LOAD_T = device::LOAD_128BIT_T; - - using DATA_T = typename DATASET_DESCRIPTOR_T::DATA_T; - using INDEX_T = typename DATASET_DESCRIPTOR_T::INDEX_T; - using DISTANCE_T = typename DATASET_DESCRIPTOR_T::DISTANCE_T; - - auto to_source_index = [source_indices_ptr](INDEX_T x) { - return source_indices_ptr == nullptr ? 
static_cast(x) : source_indices_ptr[x]; - }; - -#ifdef _CLK_BREAKDOWN - std::uint64_t clk_init = 0; - std::uint64_t clk_compute_1st_distance = 0; - std::uint64_t clk_topk = 0; - std::uint64_t clk_reset_hash = 0; - std::uint64_t clk_pickup_parents = 0; - std::uint64_t clk_restore_hash = 0; - std::uint64_t clk_compute_distance = 0; - std::uint64_t clk_start; -#define _CLK_START() clk_start = clock64() -#define _CLK_REC(V) V += clock64() - clk_start; -#else -#define _CLK_START() -#define _CLK_REC(V) -#endif - _CLK_START(); - - extern __shared__ uint8_t smem[]; - - // Layout of result_buffer - // +----------------------+------------------------------+---------+ - // | internal_top_k | neighbors of internal_top_k | padding | - // | | | upto 32 | - // +----------------------+------------------------------+---------+ - // |<--- result_buffer_size --->| - const auto result_buffer_size = internal_topk + (search_width * graph_degree); - const auto result_buffer_size_32 = raft::round_up_safe(result_buffer_size, 32); - const auto small_hash_size = hashmap::get_size(small_hash_bitlen); - - // Set smem working buffer for the distance calculation - dataset_desc = dataset_desc->setup_workspace(smem, queries_ptr, query_id); - - auto* __restrict__ result_indices_buffer = - reinterpret_cast(smem + dataset_desc->smem_ws_size_in_bytes()); - auto* __restrict__ result_distances_buffer = - reinterpret_cast(result_indices_buffer + result_buffer_size_32); - auto* __restrict__ visited_hash_buffer = - reinterpret_cast(result_distances_buffer + result_buffer_size_32); - auto* __restrict__ parent_list_buffer = - reinterpret_cast(visited_hash_buffer + small_hash_size); - auto* __restrict__ topk_ws = reinterpret_cast(parent_list_buffer + search_width); - auto* terminate_flag = reinterpret_cast(topk_ws + 3); - auto* __restrict__ smem_work_ptr = reinterpret_cast(terminate_flag + 1); - - // A flag for filtering. - auto filter_flag = terminate_flag; - - if (threadIdx.x == 0) { - terminate_flag[0] = 0; - topk_ws[0] = ~0u; - } - - // Init hashmap - INDEX_T* local_visited_hashmap_ptr; - if (small_hash_bitlen) { - local_visited_hashmap_ptr = visited_hash_buffer; - } else { - local_visited_hashmap_ptr = visited_hashmap_ptr + (hashmap::get_size(hash_bitlen) * blockIdx.y); - } - hashmap::init(local_visited_hashmap_ptr, hash_bitlen, 0); - __syncthreads(); - _CLK_REC(clk_init); - - // compute distance to randomly selecting nodes - _CLK_START(); - const INDEX_T* const local_seed_ptr = seed_ptr ? seed_ptr + (num_seeds * query_id) : nullptr; - device::compute_distance_to_random_nodes(result_indices_buffer, - result_distances_buffer, - *dataset_desc, - result_buffer_size, - num_distilation, - rand_xor_mask, - local_seed_ptr, - num_seeds, - local_visited_hashmap_ptr, - hash_bitlen, - (INDEX_T*)nullptr, - 0, - 0, - 1, - graph_size); - __syncthreads(); - _CLK_REC(clk_compute_1st_distance); - - std::uint32_t iter = 0; - while (1) { - // sort - if constexpr (TOPK_BY_BITONIC_SORT) { - // [Notice] - // It is good to use multiple warps in topk_by_bitonic_sort_and_merge() when - // batch size is small (short-latency), but it might not be always good - // when batch size is large (high-throughput). - // topk_by_bitonic_sort_and_merge() consists of two operations: - // if max_candidates is greater than 128, the first operation uses two warps; - // if max_itopk is greater than 256, the second operation used two warps. - assert(blockDim.x >= 64); - const bool bitonic_sort_and_full_multi_warps = (max_candidates > 128) ? 
true : false; - - // reset small-hash table. - if ((iter + 1) % small_hash_reset_interval == 0) { - // Depending on the block size and the number of warps used in - // topk_by_bitonic_sort_and_merge(), determine which warps are used to reset - // the small hash and whether they are performed in overlap with - // topk_by_bitonic_sort_and_merge(). - _CLK_START(); - unsigned hash_start_tid; - if (blockDim.x == 32) { - hash_start_tid = 0; - } else if (blockDim.x == 64) { - if (bitonic_sort_and_full_multi_warps || BITONIC_SORT_AND_MERGE_MULTI_WARPS) { - hash_start_tid = 0; - } else { - hash_start_tid = 32; - } - } else { - if (bitonic_sort_and_full_multi_warps || BITONIC_SORT_AND_MERGE_MULTI_WARPS) { - hash_start_tid = 64; - } else { - hash_start_tid = 32; - } - } - hashmap::init(local_visited_hashmap_ptr, hash_bitlen, hash_start_tid); - _CLK_REC(clk_reset_hash); - } - - // topk with bitonic sort - _CLK_START(); - if (!(std::is_same::value || - *filter_flag == 0)) { - // Move the filtered out index to the end of the itopk list - for (unsigned i = 0; i < search_width; i++) { - move_invalid_to_end_of_list( - result_indices_buffer, result_distances_buffer, internal_topk); - } - - if (threadIdx.x == 0) { *terminate_flag = 0; } - } - topk_by_bitonic_sort_and_merge( - result_distances_buffer, - result_indices_buffer, - max_itopk, - internal_topk, - result_distances_buffer + internal_topk, - result_indices_buffer + internal_topk, - max_candidates, - search_width * graph_degree, - topk_ws, - (iter == 0)); - __syncthreads(); - _CLK_REC(clk_topk); - } else { - _CLK_START(); - // topk with radix block sort - topk_by_radix_sort{}(max_itopk, - internal_topk, - result_buffer_size, - reinterpret_cast(result_distances_buffer), - result_indices_buffer, - reinterpret_cast(result_distances_buffer), - result_indices_buffer, - nullptr, - topk_ws, - true, - smem_work_ptr); - _CLK_REC(clk_topk); - - // reset small-hash table - if ((iter + 1) % small_hash_reset_interval == 0) { - _CLK_START(); - hashmap::init(local_visited_hashmap_ptr, hash_bitlen); - _CLK_REC(clk_reset_hash); - } - } - __syncthreads(); - - if (iter + 1 == max_iteration) { break; } - - // pick up next parents - if (threadIdx.x < 32) { - _CLK_START(); - pickup_next_parents( - terminate_flag, parent_list_buffer, result_indices_buffer, internal_topk, search_width); - _CLK_REC(clk_pickup_parents); - } - - // restore small-hash table by putting internal-topk indices in it - if ((iter + 1) % small_hash_reset_interval == 0) { - const unsigned first_tid = ((blockDim.x <= 32) ? 
0 : 32); - _CLK_START(); - hashmap_restore( - local_visited_hashmap_ptr, hash_bitlen, result_indices_buffer, internal_topk, first_tid); - _CLK_REC(clk_restore_hash); - } - __syncthreads(); - - if (*terminate_flag && iter >= min_iteration) { break; } - - // compute the norms between child nodes and query node - _CLK_START(); - device::compute_distance_to_child_nodes(result_indices_buffer + internal_topk, - result_distances_buffer + internal_topk, - *dataset_desc, - knn_graph, - graph_degree, - local_visited_hashmap_ptr, - hash_bitlen, - (INDEX_T*)nullptr, - 0, - parent_list_buffer, - result_indices_buffer, - search_width); - __syncthreads(); - _CLK_REC(clk_compute_distance); - - // Filtering - if constexpr (!std::is_same::value) { - if (threadIdx.x == 0) { *filter_flag = 0; } - __syncthreads(); - - constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; - const INDEX_T invalid_index = utils::get_max_value(); - - for (unsigned p = threadIdx.x; p < search_width; p += blockDim.x) { - if (parent_list_buffer[p] != invalid_index) { - const auto parent_id = result_indices_buffer[parent_list_buffer[p]] & ~index_msb_1_mask; - if (!sample_filter(query_id, to_source_index(parent_id))) { - // If the parent must not be in the resulting top-k list, remove from the parent list - result_distances_buffer[parent_list_buffer[p]] = utils::get_max_value(); - result_indices_buffer[parent_list_buffer[p]] = invalid_index; - *filter_flag = 1; - } - } - } - __syncthreads(); - } - - iter++; - } - - // Post process for filtering - if constexpr (!std::is_same::value) { - constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; - const INDEX_T invalid_index = utils::get_max_value(); - - for (unsigned i = threadIdx.x; i < internal_topk + search_width * graph_degree; - i += blockDim.x) { - const auto node_id = result_indices_buffer[i] & ~index_msb_1_mask; - if (node_id != (invalid_index & ~index_msb_1_mask) && - !sample_filter(query_id, to_source_index(node_id))) { - result_distances_buffer[i] = utils::get_max_value(); - result_indices_buffer[i] = invalid_index; - } - } - - __syncthreads(); - // Move invalid index items to the end of the buffer without sorting the entire buffer - using scan_op_t = cub::WarpScan; - auto& temp_storage = *reinterpret_cast(smem_work_ptr); - - constexpr std::uint32_t warp_size = 32; - if (threadIdx.x < warp_size) { - std::uint32_t num_found_valid = 0; - for (std::uint32_t buffer_offset = 0; buffer_offset < internal_topk; - buffer_offset += warp_size) { - // Calculate the new buffer index - const auto src_position = buffer_offset + threadIdx.x; - const std::uint32_t is_valid_index = - (result_indices_buffer[src_position] & (~index_msb_1_mask)) == invalid_index ? 
0 : 1; - std::uint32_t new_position; - scan_op_t(temp_storage).InclusiveSum(is_valid_index, new_position); - if (is_valid_index) { - const auto dst_position = num_found_valid + (new_position - 1); - result_indices_buffer[dst_position] = result_indices_buffer[src_position]; - result_distances_buffer[dst_position] = result_distances_buffer[src_position]; - } - - // Calculate the largest valid position within a warp and bcast it for the next iteration - num_found_valid += new_position; - for (std::uint32_t offset = (warp_size >> 1); offset > 0; offset >>= 1) { - const auto v = raft::shfl_xor(num_found_valid, offset); - if ((threadIdx.x & offset) == 0) { num_found_valid = v; } - } - - // If the enough number of items are found, do early termination - if (num_found_valid >= top_k) { break; } - } - - if (num_found_valid < top_k) { - // Fill the remaining buffer with invalid values so that `topk_by_bitonic_sort_and_merge` is - // usable in the next step - for (std::uint32_t i = num_found_valid + threadIdx.x; i < internal_topk; i += warp_size) { - result_indices_buffer[i] = invalid_index; - result_distances_buffer[i] = utils::get_max_value(); - } - } - } - - // If the sufficient number of valid indexes are not in the internal topk, pick up from the - // candidate list. - if (top_k > internal_topk || result_indices_buffer[top_k - 1] == invalid_index) { - __syncthreads(); - topk_by_bitonic_sort_and_merge( - result_distances_buffer, - result_indices_buffer, - max_itopk, - internal_topk, - result_distances_buffer + internal_topk, - result_indices_buffer + internal_topk, - max_candidates, - search_width * graph_degree, - topk_ws, - (iter == 0)); - } - __syncthreads(); - } - - // NB: The indices pointer is tagged with its element size. - // Here we select the correct conversion operator at runtime. - // This allows us to avoid multiplying kernel instantiations - // and any costs for extra registers in the kernel signature. - const uint32_t index_element_tag = result_indices_ptr & 0x3; - result_indices_ptr ^= index_element_tag; - auto write_indices = - index_element_tag == 3 - ? [](uintptr_t ptr, - uint32_t i, - SourceIndexT x) { reinterpret_cast(ptr)[i] = static_cast(x); } - : index_element_tag == 2 - ? [](uintptr_t ptr, - uint32_t i, - SourceIndexT x) { reinterpret_cast(ptr)[i] = static_cast(x); } - : index_element_tag == 1 - ? 
[](uintptr_t ptr, - uint32_t i, - SourceIndexT x) { reinterpret_cast(ptr)[i] = static_cast(x); } - : [](uintptr_t ptr, uint32_t i, SourceIndexT x) { - reinterpret_cast(ptr)[i] = static_cast(x); - }; - for (std::uint32_t i = threadIdx.x; i < top_k; i += blockDim.x) { - unsigned j = i + (top_k * query_id); - unsigned ii = i; - if constexpr (TOPK_BY_BITONIC_SORT) { ii = device::swizzling(i); } - if (result_distances_ptr != nullptr) { result_distances_ptr[j] = result_distances_buffer[ii]; } - constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; - - auto internal_index = - result_indices_buffer[ii] & ~index_msb_1_mask; // clear most significant bit - auto source_index = to_source_index(internal_index); - write_indices(result_indices_ptr, j, source_index); - } - if (threadIdx.x == 0 && num_executed_iterations != nullptr) { - num_executed_iterations[query_id] = iter + 1; - } -#ifdef _CLK_BREAKDOWN - if ((threadIdx.x == 0 || threadIdx.x == blockDim.x - 1) && ((query_id * 3) % gridDim.y < 3)) { - printf( - "%s:%d " - "query, %d, thread, %d" - ", init, %lu" - ", 1st_distance, %lu" - ", topk, %lu" - ", reset_hash, %lu" - ", pickup_parents, %lu" - ", restore_hash, %lu" - ", distance, %lu" - "\n", - __FILE__, - __LINE__, - query_id, - threadIdx.x, - clk_init, - clk_compute_1st_distance, - clk_topk, - clk_reset_hash, - clk_pickup_parents, - clk_restore_hash, - clk_compute_distance); - } -#endif -} - -template -RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( - uintptr_t result_indices_ptr, // [num_queries, top_k] - typename DATASET_DESCRIPTOR_T::DISTANCE_T* const result_distances_ptr, // [num_queries, top_k] - const std::uint32_t top_k, - const DATASET_DESCRIPTOR_T* dataset_desc, - const typename DATASET_DESCRIPTOR_T::DATA_T* const queries_ptr, // [num_queries, dataset_dim] - const typename DATASET_DESCRIPTOR_T::INDEX_T* const knn_graph, // [dataset_size, graph_degree] - const std::uint32_t graph_degree, - const SourceIndexT* source_indices_ptr, - const unsigned num_distilation, - const uint64_t rand_xor_mask, - const typename DATASET_DESCRIPTOR_T::INDEX_T* seed_ptr, // [num_queries, num_seeds] - const uint32_t num_seeds, - typename DATASET_DESCRIPTOR_T::INDEX_T* const - visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] - const std::uint32_t max_candidates, - const std::uint32_t max_itopk, - const std::uint32_t internal_topk, - const std::uint32_t search_width, - const std::uint32_t min_iteration, - const std::uint32_t max_iteration, - std::uint32_t* const num_executed_iterations, // [num_queries] - const std::uint32_t hash_bitlen, - const std::uint32_t small_hash_bitlen, - const std::uint32_t small_hash_reset_interval, - SAMPLE_FILTER_T sample_filter, - const typename DATASET_DESCRIPTOR_T::INDEX_T graph_size = 0) -{ - const auto query_id = blockIdx.y; - search_core(result_indices_ptr, - result_distances_ptr, - top_k, - dataset_desc, - queries_ptr, - knn_graph, - graph_degree, - source_indices_ptr, - num_distilation, - rand_xor_mask, - seed_ptr, - num_seeds, - visited_hashmap_ptr, - max_candidates, - max_itopk, - internal_topk, - search_width, - min_iteration, - max_iteration, - num_executed_iterations, - hash_bitlen, - small_hash_bitlen, - small_hash_reset_interval, - query_id, - sample_filter, - graph_size); -} - -// To make sure we avoid false sharing on both CPU and GPU, we enforce cache line size to the -// maximum of the two. -// This makes sync atomic significantly faster. 
-constexpr size_t kCacheLineBytes = 64; - -constexpr uint32_t kMaxJobsNum = 8192; -constexpr uint32_t kMaxWorkersNum = 4096; -constexpr uint32_t kMaxWorkersPerThread = 256; -constexpr uint32_t kSoftMaxWorkersPerThread = 16; - -template -struct alignas(kCacheLineBytes) job_desc_t { - using index_type = typename DATASET_DESCRIPTOR_T::INDEX_T; - using distance_type = typename DATASET_DESCRIPTOR_T::DISTANCE_T; - using data_type = typename DATASET_DESCRIPTOR_T::DATA_T; - // The algorithm input parameters - struct value_t { - uintptr_t result_indices_ptr; // [num_queries, top_k] - distance_type* result_distances_ptr; // [num_queries, top_k] - const data_type* queries_ptr; // [num_queries, dataset_dim] - uint32_t top_k; - uint32_t n_queries; - }; - using blob_elem_type = uint4; - constexpr static inline size_t kBlobSize = - raft::div_rounding_up_safe(sizeof(value_t), sizeof(blob_elem_type)); - // Union facilitates loading the input by a warp in a single request - union input_t { - blob_elem_type blob[kBlobSize]; // NOLINT - value_t value; - } input; - // Last thread triggers this flag. - cuda::atomic completion_flag; -}; - -struct alignas(kCacheLineBytes) worker_handle_t { - using handle_t = uint64_t; - struct value_t { - uint32_t desc_id; - uint32_t query_id; - }; - union data_t { - handle_t handle; - value_t value; - }; - cuda::atomic data; -}; -static_assert(sizeof(worker_handle_t::value_t) == sizeof(worker_handle_t::handle_t)); -static_assert( - cuda::atomic::is_always_lock_free); - -constexpr worker_handle_t::handle_t kWaitForWork = std::numeric_limits::max(); -constexpr worker_handle_t::handle_t kNoMoreWork = kWaitForWork - 1; - -constexpr auto is_worker_busy(worker_handle_t::handle_t h) -> bool -{ - return (h != kWaitForWork) && (h != kNoMoreWork); -} - -template -RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel_p( - const DATASET_DESCRIPTOR_T* dataset_desc, - worker_handle_t* worker_handles, - job_desc_t* job_descriptors, - uint32_t* completion_counters, - const typename DATASET_DESCRIPTOR_T::INDEX_T* const knn_graph, // [dataset_size, graph_degree] - const std::uint32_t graph_degree, - const SourceIndexT* source_indices_ptr, - const unsigned num_distilation, - const uint64_t rand_xor_mask, - const typename DATASET_DESCRIPTOR_T::INDEX_T* seed_ptr, // [num_queries, num_seeds] - const uint32_t num_seeds, - typename DATASET_DESCRIPTOR_T::INDEX_T* const - visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] - const std::uint32_t max_candidates, - const std::uint32_t max_itopk, - const std::uint32_t internal_topk, - const std::uint32_t search_width, - const std::uint32_t min_iteration, - const std::uint32_t max_iteration, - std::uint32_t* const num_executed_iterations, // [num_queries] - const std::uint32_t hash_bitlen, - const std::uint32_t small_hash_bitlen, - const std::uint32_t small_hash_reset_interval, - SAMPLE_FILTER_T sample_filter) -{ - using job_desc_type = job_desc_t; - __shared__ typename job_desc_type::input_t job_descriptor; - __shared__ worker_handle_t::data_t worker_data; - - auto& worker_handle = worker_handles[blockIdx.y].data; - uint32_t job_ix; - - while (true) { - // wait the writing phase - if (threadIdx.x == 0) { - worker_handle_t::data_t worker_data_local; - do { - worker_data_local = worker_handle.load(cuda::memory_order_relaxed); - } while (worker_data_local.handle == kWaitForWork); - if (worker_data_local.handle != kNoMoreWork) { - worker_handle.store({kWaitForWork}, cuda::memory_order_relaxed); - } - job_ix = worker_data_local.value.desc_id; - 
cuda::atomic_thread_fence(cuda::memory_order_acquire, cuda::thread_scope_system); - worker_data = worker_data_local; - } - if (threadIdx.x < raft::WarpSize) { - // Sync one warp and copy descriptor data - static_assert(job_desc_type::kBlobSize <= raft::WarpSize); - job_ix = raft::shfl(job_ix, 0); - if (threadIdx.x < job_desc_type::kBlobSize && job_ix < kMaxJobsNum) { - job_descriptor.blob[threadIdx.x] = job_descriptors[job_ix].input.blob[threadIdx.x]; - } - } - __syncthreads(); - if (worker_data.handle == kNoMoreWork) { break; } - - // reading phase - auto result_indices_ptr = job_descriptor.value.result_indices_ptr; - auto* result_distances_ptr = job_descriptor.value.result_distances_ptr; - auto* queries_ptr = job_descriptor.value.queries_ptr; - auto top_k = job_descriptor.value.top_k; - auto n_queries = job_descriptor.value.n_queries; - auto query_id = worker_data.value.query_id; - - // work phase - search_core(result_indices_ptr, - result_distances_ptr, - top_k, - dataset_desc, - queries_ptr, - knn_graph, - graph_degree, - source_indices_ptr, - num_distilation, - rand_xor_mask, - seed_ptr, - num_seeds, - visited_hashmap_ptr, - max_candidates, - max_itopk, - internal_topk, - search_width, - min_iteration, - max_iteration, - num_executed_iterations, - hash_bitlen, - small_hash_bitlen, - small_hash_reset_interval, - query_id, - sample_filter); - - // make sure all writes are visible even for the host - // (e.g. when result buffers are in pinned memory) - cuda::atomic_thread_fence(cuda::memory_order_release, cuda::thread_scope_system); - - // arrive to mark the end of the work phase - __syncthreads(); - if (threadIdx.x == 0) { - auto completed_count = atomicInc(completion_counters + job_ix, n_queries - 1) + 1; - if (completed_count >= n_queries) { - job_descriptors[job_ix].completion_flag.store(true, cuda::memory_order_relaxed); - } - } - } -} - -template -auto dispatch_kernel = []() { - static_assert(TOPK_BY_BITONIC_SORT || !BITONIC_SORT_AND_MERGE_MULTI_WARPS); - if constexpr (Persistent) { - return search_kernel_p; - } else { - return search_kernel; - } -}(); - -template -struct search_kernel_config { - using kernel_t = decltype(dispatch_kernel); - - static auto choose_itopk_and_mx_candidates(unsigned itopk_size, - unsigned num_itopk_candidates, - unsigned block_size) -> kernel_t - { - assert(itopk_size <= 512); - if (num_itopk_candidates <= 256) { - if (itopk_size <= 256) { - return dispatch_kernel; - } else { - assert(block_size >= 64); - return dispatch_kernel; - } - } else { - // Radix-based topk is used - return dispatch_kernel; - } - } -}; - -/** - * @brief Resource queue - * - * @tparam T the element type - * @tparam Size the maximum capacity of the queue (power-of-two) - * @tparam Empty a special element value designating an empty queue slot. NB: storing `Empty` is UB. - * - * A shared atomic ring buffer based queue optimized for throughput when bottlenecked on `pop` - * operation. 
- * - * @code{.cpp} - * // allocate the queue - * resource_queue_t resource_ids; - * - * // store couple values - * resource_ids.push(42); - * resource_ids.push(7); - * - * // wait to get the value from the queue - * auto id_x = resource_ids.pop().wait(); - * - * // stand in line to get the value from the queue, but don't wait - * auto ticket_y = resource_ids.pop(); - * // do other stuff and check if the value is available - * int32_t id_y; - * while (!ticket_y.test(id_y)) { - * do_some_important_business(...); - * std::this_thread::sleep_for(std::chrono::microseconds(10); - * } - * // `id_y` is set by now and `ticket_y.wait()` won't block anymore - * assert(ticket_y.wait() == id_y); - * @endcode - */ -template ::max()> -struct alignas(kCacheLineBytes) resource_queue_t { - using value_type = T; - static constexpr uint32_t kSize = Size; - static constexpr value_type kEmpty = Empty; - static_assert(cuda::std::atomic::is_always_lock_free, - "The value type must be lock-free."); - static_assert(raft::is_a_power_of_two(kSize), "The size must be a power-of-two for efficiency."); - static constexpr uint32_t kElemsPerCacheLine = - raft::div_rounding_up_safe(kCacheLineBytes, sizeof(value_type)); - /* [Note: cache-friendly indexing] - To avoid false sharing, the queue pushes and pops values not sequentially, but with an - increment that is larger than the cache line size. - Hence we introduce the `kCounterIncrement > kCacheLineBytes`. - However, to make sure all indices are used, we choose the increment to be coprime with the - buffer size. We also require that the buffer size is a power-of-two for two reasons: - 1) Fast modulus operation - reduces to binary `and` (with `kCounterLocMask`). - 2) Easy to ensure GCD(kCounterIncrement, kSize) == 1 by construction - (see the definition below). - */ - static constexpr uint32_t kCounterIncrement = raft::bound_by_power_of_two(kElemsPerCacheLine) + 1; - static constexpr uint32_t kCounterLocMask = kSize - 1; - // These props hold by design, but we add them here as a documentation and a sanity check. - static_assert( - kCounterIncrement * sizeof(value_type) >= kCacheLineBytes, - "The counter increment should be larger than the cache line size to avoid false sharing."); - static_assert( - std::gcd(kCounterIncrement, kSize) == 1, - "The counter increment and the size must be coprime to allow using all of the queue slots."); - - static constexpr auto kMemOrder = cuda::std::memory_order_relaxed; - - explicit resource_queue_t(uint32_t capacity = Size) noexcept : capacity_{capacity} - { - head_.store(0, kMemOrder); - tail_.store(0, kMemOrder); - for (uint32_t i = 0; i < kSize; i++) { - buf_[i].store(kEmpty, kMemOrder); - } - } - - /** Nominal capacity of the queue. */ - [[nodiscard]] auto capacity() const { return capacity_; } - - /** This does not affect the queue behavior, but merely declares a nominal capacity. */ - void set_capacity(uint32_t capacity) { capacity_ = capacity; } - - /** - * A slot in the queue to take the value from. - * Once it's obtained, the corresponding value in the queue is lost for other users. 
- */ - struct promise_t { - explicit promise_t(cuda::std::atomic& loc) : loc_{loc}, val_{Empty} {} - ~promise_t() noexcept { wait(); } - - auto test() noexcept -> bool - { - if (val_ != Empty) { return true; } - val_ = loc_.exchange(kEmpty, kMemOrder); - return val_ != Empty; - } - - auto test(value_type& e) noexcept -> bool - { - if (test()) { - e = val_; - return true; - } - return false; - } - - auto wait() noexcept -> value_type - { - if (val_ == Empty) { - // [HOT SPOT] - // Optimize for the case of contention: expect the loc is empty. - do { - loc_.wait(kEmpty, kMemOrder); - val_ = loc_.exchange(kEmpty, kMemOrder); - } while (val_ == kEmpty); - } - return val_; - } - - private: - cuda::std::atomic& loc_; - value_type val_; - }; - - void push(value_type x) noexcept - { - auto& loc = buf_[head_.fetch_add(kCounterIncrement, kMemOrder) & kCounterLocMask]; - /* [NOT A HOT SPOT] - We expect there's always enough place in the queue to push the item, - but also we expect a few pop waiters - notify them the data is available. - */ - value_type e = kEmpty; - while (!loc.compare_exchange_weak(e, x, kMemOrder, kMemOrder)) { - e = kEmpty; - } - loc.notify_one(); - } - - auto pop() noexcept -> promise_t - { - auto& loc = buf_[tail_.fetch_add(kCounterIncrement, kMemOrder) & kCounterLocMask]; - return promise_t{loc}; - } - - private: - alignas(kCacheLineBytes) cuda::std::atomic head_{}; - alignas(kCacheLineBytes) cuda::std::atomic tail_{}; - alignas(kCacheLineBytes) std::array, kSize> buf_{}; - alignas(kCacheLineBytes) uint32_t capacity_; -}; - -/** Primitive fixed-size deque for single-threaded use. */ -template -struct local_deque_t { - explicit local_deque_t(uint32_t size) : store_(size) {} - - [[nodiscard]] auto capacity() const -> uint32_t { return store_.size(); } - [[nodiscard]] auto size() const -> uint32_t { return end_ - start_; } - - void push_back(T x) { store_[end_++ % capacity()] = x; } - - void push_front(T x) - { - if (start_ == 0) { - start_ += capacity(); - end_ += capacity(); - } - store_[--start_ % capacity()] = x; - } - - // NB: unsafe functions - do not check if the queue is full/empty. - auto pop_back() -> T { return store_[--end_ % capacity()]; } - auto pop_front() -> T { return store_[start_++ % capacity()]; } - - auto try_push_back(T x) -> bool - { - if (size() >= capacity()) { return false; } - push_back(x); - return true; - } - - auto try_push_front(T x) -> bool - { - if (size() >= capacity()) { return false; } - push_front(x); - return true; - } - - auto try_pop_back(T& x) -> bool - { - if (start_ >= end_) { return false; } - x = pop_back(); - return true; - } - - auto try_pop_front(T& x) -> bool - { - if (start_ >= end_) { return false; } - x = pop_front(); - return true; - } - - private: - std::vector store_; - uint32_t start_{0}; - uint32_t end_{0}; -}; - -struct persistent_runner_base_t { - using job_queue_type = resource_queue_t; - using worker_queue_type = resource_queue_t; - rmm::mr::pinned_host_memory_resource worker_handles_mr; - rmm::mr::pinned_host_memory_resource job_descriptor_mr; - rmm::mr::cuda_memory_resource device_mr; - cudaStream_t stream{}; - job_queue_type job_queue{}; - worker_queue_type worker_queue{}; - // This should be large enough to make the runner live through restarts of the benchmark cases. - // Otherwise, the benchmarks slowdown significantly. 
- std::chrono::milliseconds lifetime; - - persistent_runner_base_t(float persistent_lifetime) - : lifetime(size_t(persistent_lifetime * 1000)), job_queue(), worker_queue() - { - cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking); - } - virtual ~persistent_runner_base_t() noexcept { cudaStreamDestroy(stream); }; -}; - -struct alignas(kCacheLineBytes) launcher_t { - using job_queue_type = persistent_runner_base_t::job_queue_type; - using worker_queue_type = persistent_runner_base_t::worker_queue_type; - using pending_reads_queue_type = local_deque_t; - using completion_flag_type = cuda::atomic; - - pending_reads_queue_type pending_reads; - job_queue_type& job_ids; - worker_queue_type& idle_worker_ids; - worker_handle_t* worker_handles; - uint32_t job_id; - completion_flag_type* completion_flag; - bool all_done = false; - - /* [Note: sleeping] - When the number of threads is greater than the number of cores, the threads start to fight for - the CPU time, which reduces the throughput. - To ease the competition, we track the expected GPU latency and let a thread sleep for some - time, and only start to spin when it's about a time to get the result. - */ - static inline constexpr auto kDefaultLatency = std::chrono::nanoseconds(50000); - /* This is the base for computing maximum time a thread is allowed to sleep. */ - static inline constexpr auto kMaxExpectedLatency = - kDefaultLatency * std::max(10, kMaxJobsNum / 128); - static inline thread_local auto expected_latency = kDefaultLatency; - const std::chrono::time_point start; - std::chrono::time_point now; - const int64_t pause_factor; - int pause_count = 0; - /** - * Beyond this threshold, the launcher (calling thread) does not wait for the results anymore and - * throws an exception. - */ - std::chrono::time_point deadline; - - template - launcher_t(job_queue_type& job_ids, - worker_queue_type& idle_worker_ids, - worker_handle_t* worker_handles, - uint32_t n_queries, - std::chrono::milliseconds max_wait_time, - RecordWork record_work) - : pending_reads{std::min(n_queries, kMaxWorkersPerThread)}, - job_ids{job_ids}, - idle_worker_ids{idle_worker_ids}, - worker_handles{worker_handles}, - job_id{job_ids.pop().wait()}, - completion_flag{record_work(job_id)}, - start{std::chrono::system_clock::now()}, - pause_factor{calc_pause_factor(n_queries)}, - now{start}, - deadline{start + max_wait_time + expected_latency} - { - // Wait for the first worker and submit the query immediately. - submit_query(idle_worker_ids.pop().wait(), 0); - // Submit the rest of the queries in the batch - for (uint32_t i = 1; i < n_queries; i++) { - auto promised_worker = idle_worker_ids.pop(); - uint32_t worker_id; - while (!promised_worker.test(worker_id)) { - if (pending_reads.try_pop_front(worker_id)) { - bool returned_some = false; - for (bool keep_returning = true; keep_returning;) { - if (try_return_worker(worker_id)) { - keep_returning = pending_reads.try_pop_front(worker_id); - returned_some = true; - } else { - pending_reads.push_front(worker_id); - keep_returning = false; - } - } - if (!returned_some) { pause(); } - } else { - // Calmly wait for the promised worker instead of spinning. 
- worker_id = promised_worker.wait(); - break; - } - } - pause_count = 0; // reset the pause behavior - submit_query(worker_id, i); - // Try to not hold too many workers in one thread - if (i >= kSoftMaxWorkersPerThread && pending_reads.try_pop_front(worker_id)) { - if (!try_return_worker(worker_id)) { pending_reads.push_front(worker_id); } - } - } - } - - inline ~launcher_t() noexcept // NOLINT - { - // bookkeeping: update the expected latency to wait more efficiently later - constexpr size_t kWindow = 100; // moving average memory - expected_latency = std::min( - ((kWindow - 1) * expected_latency + now - start) / kWindow, kMaxExpectedLatency); - - // Try to gracefully cleanup the queue resources if the launcher is being destructed after an - // exception. - if (job_id != job_queue_type::kEmpty) { job_ids.push(job_id); } - uint32_t worker_id; - while (pending_reads.try_pop_front(worker_id)) { - idle_worker_ids.push(worker_id); - } - } - - inline void submit_query(uint32_t worker_id, uint32_t query_id) - { - worker_handles[worker_id].data.store(worker_handle_t::data_t{.value = {job_id, query_id}}, - cuda::memory_order_relaxed); - - while (!pending_reads.try_push_back(worker_id)) { - // The only reason pending_reads cannot push is that the queue is full. - // It's local, so we must pop and wait for the returned worker to finish its work. - auto pending_worker_id = pending_reads.pop_front(); - while (!try_return_worker(pending_worker_id)) { - pause(); - } - } - pause_count = 0; // reset the pause behavior - } - - /** Check if the worker has finished the work; if so, return it to the shared pool. */ - inline auto try_return_worker(uint32_t worker_id) -> bool - { - // Use the cached `all_done` - makes sense when called from the `wait()` routine. - if (all_done || - !is_worker_busy(worker_handles[worker_id].data.load(cuda::memory_order_relaxed).handle)) { - idle_worker_ids.push(worker_id); - return true; - } else { - return false; - } - } - - /** Check if all workers finished their work. */ - inline auto is_all_done() - { - // Cache the result of the check to avoid doing unnecessary atomic loads. - if (all_done) { return true; } - all_done = completion_flag->load(cuda::memory_order_relaxed); - return all_done; - } - - /** The launcher shouldn't attempt to wait past the returned time. */ - [[nodiscard]] inline auto sleep_limit() const - { - constexpr auto kMinWakeTime = std::chrono::nanoseconds(10000); - constexpr double kSleepLimit = 0.6; - return start + expected_latency * kSleepLimit - kMinWakeTime; - } - - /** - * When the latency is much larger than expected, it's a sign that there is a thread contention. - * Then we switch to sleeping instead of waiting to give the cpu cycles to other threads. - */ - [[nodiscard]] inline auto overtime_threshold() const - { - constexpr auto kOvertimeFactor = 3; - return start + expected_latency * kOvertimeFactor; - } - - /** - * Calculate the fraction of time can be spent sleeping in a single call to `pause()`. - * Naturally it depends on the number of queries in a batch and the number of parallel workers. - */ - [[nodiscard]] inline auto calc_pause_factor(uint32_t n_queries) const -> uint32_t - { - constexpr uint32_t kMultiplier = 10; - return kMultiplier * raft::div_rounding_up_safe(n_queries, idle_worker_ids.capacity()); - } - - /** Wait a little bit (called in a loop). 
*/ - inline void pause() - { - // Don't sleep this many times hoping for smoother run - constexpr auto kSpinLimit = 3; - // It doesn't make much sense to sleep less than this - constexpr auto kPauseTimeMin = std::chrono::nanoseconds(1000); - // Bound sleeping time - constexpr auto kPauseTimeMax = std::chrono::nanoseconds(50000); - if (pause_count++ < kSpinLimit) { - std::this_thread::yield(); - return; - } - now = std::chrono::system_clock::now(); - auto pause_time_base = std::max(now - start, expected_latency); - auto pause_time = std::clamp(pause_time_base / pause_factor, kPauseTimeMin, kPauseTimeMax); - if (now + pause_time < sleep_limit()) { - // It's too early: sleep for a bit - std::this_thread::sleep_for(pause_time); - } else if (now <= overtime_threshold()) { - // It's about time to check the results, don't sleep - std::this_thread::yield(); - } else if (now <= deadline) { - // Too late; perhaps the system is too busy - sleep again - std::this_thread::sleep_for(pause_time); - } else { - // Missed the deadline: throw an exception - throw raft::exception( - "The calling thread didn't receive the results from the persistent CAGRA kernel within the " - "expected kernel lifetime. Here are possible reasons of this failure:\n" - " (1) `persistent_lifetime` search parameter is too small - increase it;\n" - " (2) there is other work being executed on the same device and the kernel failed to " - "progress - decreasing `persistent_device_usage` may help (but not guaranteed);\n" - " (3) there is a bug in the implementation - please report it to cuVS team."); - } - } - - /** Wait for all work to finish and don't forget to return the workers to the shared pool. */ - inline void wait() - { - uint32_t worker_id; - while (pending_reads.try_pop_front(worker_id)) { - while (!try_return_worker(worker_id)) { - if (!is_all_done()) { pause(); } - } - } - pause_count = 0; // reset the pause behavior - // terminal state, should be engaged only after the `pending_reads` is empty - // and `queries_submitted == n_queries` - now = std::chrono::system_clock::now(); - while (!is_all_done()) { - auto till_time = sleep_limit(); - if (now < till_time) { - std::this_thread::sleep_until(till_time); - now = std::chrono::system_clock::now(); - } else { - pause(); - } - } - - // Return the job descriptor - job_ids.push(job_id); - job_id = job_queue_type::kEmpty; - } -}; - -template -struct alignas(kCacheLineBytes) persistent_runner_t : public persistent_runner_base_t { - using descriptor_base_type = dataset_descriptor_base_t; - using index_type = IndexT; - using distance_type = DistanceT; - using data_type = DataT; - using kernel_config_type = - search_kernel_config; - using kernel_type = typename kernel_config_type::kernel_t; - using job_desc_type = job_desc_t; - kernel_type kernel; - uint32_t block_size; - dataset_descriptor_host dd_host; - rmm::device_uvector worker_handles; - rmm::device_uvector job_descriptors; - rmm::device_uvector completion_counters; - rmm::device_uvector hashmap; - std::atomic> last_touch; - uint64_t param_hash; - - /** - * Calculate the hash of the parameters to detect if they've changed across the calls. - * NB: this must have the same argument types as the constructor. 
- */ - static inline auto calculate_parameter_hash( - std::reference_wrapper> dataset_desc, - raft::device_matrix_view graph, - const SourceIndexT* source_indices_ptr, - uint32_t max_candidates, - uint32_t num_itopk_candidates, - uint32_t block_size, // - uint32_t smem_size, - int64_t hash_bitlen, - size_t small_hash_bitlen, - size_t small_hash_reset_interval, - uint32_t num_random_samplings, - uint64_t rand_xor_mask, - uint32_t num_seeds, - uint32_t max_itopk, - size_t itopk_size, - size_t search_width, - size_t min_iterations, - size_t max_iterations, - SampleFilterT sample_filter, - float persistent_lifetime, - float persistent_device_usage) -> uint64_t - { - return uint64_t(graph.data_handle()) ^ uint64_t(source_indices_ptr) ^ - dataset_desc.get().team_size ^ num_itopk_candidates ^ block_size ^ smem_size ^ - hash_bitlen ^ small_hash_reset_interval ^ num_random_samplings ^ rand_xor_mask ^ - num_seeds ^ itopk_size ^ search_width ^ min_iterations ^ max_iterations ^ - uint64_t(persistent_lifetime * 1000) ^ uint64_t(persistent_device_usage * 1000); - } - - persistent_runner_t( - std::reference_wrapper> dataset_desc, - raft::device_matrix_view graph, - const SourceIndexT* source_indices_ptr, - uint32_t max_candidates, - uint32_t num_itopk_candidates, - uint32_t block_size, // - uint32_t smem_size, - int64_t hash_bitlen, - size_t small_hash_bitlen, - size_t small_hash_reset_interval, - uint32_t num_random_samplings, - uint64_t rand_xor_mask, - uint32_t num_seeds, - uint32_t max_itopk, - size_t itopk_size, - size_t search_width, - size_t min_iterations, - size_t max_iterations, - SampleFilterT sample_filter, - float persistent_lifetime, - float persistent_device_usage) - : persistent_runner_base_t{persistent_lifetime}, - kernel{kernel_config_type::choose_itopk_and_mx_candidates( - itopk_size, num_itopk_candidates, block_size)}, - block_size{block_size}, - worker_handles(0, stream, worker_handles_mr), - job_descriptors(kMaxJobsNum, stream, job_descriptor_mr), - completion_counters(kMaxJobsNum, stream, device_mr), - hashmap(0, stream, device_mr), - dd_host{dataset_desc.get()}, - param_hash(calculate_parameter_hash(dd_host, - graph, - source_indices_ptr, - max_candidates, - num_itopk_candidates, - block_size, - smem_size, - hash_bitlen, - small_hash_bitlen, - small_hash_reset_interval, - num_random_samplings, - rand_xor_mask, - num_seeds, - max_itopk, - itopk_size, - search_width, - min_iterations, - max_iterations, - sample_filter, - persistent_lifetime, - persistent_device_usage)) - { - // initialize the dataset/distance descriptor - auto* dd_dev_ptr = dd_host.dev_ptr(stream); - - // set kernel attributes same as in normal kernel - RAFT_CUDA_TRY( - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - - // set kernel launch parameters - dim3 gs = calc_coop_grid_size(block_size, smem_size, persistent_device_usage); - dim3 bs(block_size, 1, 1); - RAFT_LOG_DEBUG( - "Launching persistent kernel with %u threads, %u block %u smem", bs.x, gs.y, smem_size); - - // initialize the job queue - auto* completion_counters_ptr = completion_counters.data(); - auto* job_descriptors_ptr = job_descriptors.data(); - for (uint32_t i = 0; i < kMaxJobsNum; i++) { - auto& jd = job_descriptors_ptr[i].input.value; - jd.result_indices_ptr = 0; - jd.result_distances_ptr = nullptr; - jd.queries_ptr = nullptr; - jd.top_k = 0; - jd.n_queries = 0; - job_descriptors_ptr[i].completion_flag.store(false); - job_queue.push(i); - } - - // initialize the worker queue - 
worker_queue.set_capacity(gs.y); - worker_handles.resize(gs.y, stream); - auto* worker_handles_ptr = worker_handles.data(); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); - for (uint32_t i = 0; i < gs.y; i++) { - worker_handles_ptr[i].data.store({kWaitForWork}); - worker_queue.push(i); - } - - index_type* hashmap_ptr = nullptr; - if (small_hash_bitlen == 0) { - hashmap.resize(gs.y * hashmap::get_size(hash_bitlen), stream); - hashmap_ptr = hashmap.data(); - } - - // launch the kernel - auto* graph_ptr = graph.data_handle(); - uint32_t graph_degree = graph.extent(1); - uint32_t* num_executed_iterations = nullptr; // optional arg [num_queries] - const index_type* dev_seed_ptr = nullptr; // optional arg [num_queries, num_seeds] - - void* args[] = // NOLINT - {&dd_dev_ptr, - &worker_handles_ptr, - &job_descriptors_ptr, - &completion_counters_ptr, - &graph_ptr, // [dataset_size, graph_degree] - &graph_degree, - &source_indices_ptr, - &num_random_samplings, - &rand_xor_mask, - &dev_seed_ptr, - &num_seeds, - &hashmap_ptr, // visited_hashmap_ptr: [num_queries, 1 << hash_bitlen] - &max_candidates, - &max_itopk, - &itopk_size, - &search_width, - &min_iterations, - &max_iterations, - &num_executed_iterations, - &hash_bitlen, - &small_hash_bitlen, - &small_hash_reset_interval, - &sample_filter}; - cuda::atomic_thread_fence(cuda::memory_order_seq_cst, cuda::thread_scope_system); - RAFT_CUDA_TRY(cudaLaunchCooperativeKernel>( - kernel, gs, bs, args, smem_size, stream)); - RAFT_LOG_INFO( - "Initialized the kernel %p in stream %zd; job_queue size = %u; worker_queue size = %u", - reinterpret_cast(kernel), - int64_t((cudaStream_t)stream), - job_queue.capacity(), - worker_queue.capacity()); - last_touch.store(std::chrono::system_clock::now(), std::memory_order_relaxed); - } - - ~persistent_runner_t() noexcept override - { - auto whs = worker_handles.data(); - for (auto i = worker_handles.size(); i > 0; i--) { - whs[worker_queue.pop().wait()].data.store({kNoMoreWork}, cuda::memory_order_relaxed); - } - RAFT_CUDA_TRY_NO_THROW(cudaStreamSynchronize(stream)); - RAFT_LOG_INFO("Destroyed the persistent runner."); - } - - void launch(uintptr_t result_indices_ptr, // [num_queries, top_k] - distance_type* result_distances_ptr, // [num_queries, top_k] - const data_type* queries_ptr, // [num_queries, dataset_dim] - uint32_t num_queries, - uint32_t top_k) - { - // submit all queries - launcher_t launcher{job_queue, - worker_queue, - worker_handles.data(), - num_queries, - this->lifetime, - [&job_descriptors = this->job_descriptors, - result_indices_ptr, - result_distances_ptr, - queries_ptr, - top_k, - num_queries](uint32_t job_ix) { - auto& jd = job_descriptors.data()[job_ix].input.value; - auto* cflag = &job_descriptors.data()[job_ix].completion_flag; - jd.result_indices_ptr = result_indices_ptr; - jd.result_distances_ptr = result_distances_ptr; - jd.queries_ptr = queries_ptr; - jd.top_k = top_k; - jd.n_queries = num_queries; - cflag->store(false, cuda::memory_order_relaxed); - cuda::atomic_thread_fence(cuda::memory_order_release, - cuda::thread_scope_system); - return cflag; - }}; - - // Update the state of the keep-alive atomic in the meanwhile - auto prev_touch = last_touch.load(std::memory_order_relaxed); - if (prev_touch + lifetime / 10 < launcher.now) { - // to avoid congestion at this atomic, we only update it if a significant fraction of the live - // interval has passed. 
- last_touch.store(launcher.now, std::memory_order_relaxed); - } - // wait for the results to arrive - launcher.wait(); - } - - auto calc_coop_grid_size(uint32_t block_size, uint32_t smem_size, float persistent_device_usage) - -> dim3 - { - // determine the grid size - int ctas_per_sm = 1; - cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, block_size, smem_size); - int num_sm = raft::getMultiProcessorCount(); - auto n_blocks = static_cast(persistent_device_usage * (ctas_per_sm * num_sm)); - if (n_blocks > kMaxWorkersNum) { - RAFT_LOG_WARN("Limiting the grid size limit due to the size of the queue: %u -> %u", - n_blocks, - kMaxWorkersNum); - n_blocks = kMaxWorkersNum; - } - - return {1, n_blocks, 1}; - } -}; - -struct alignas(kCacheLineBytes) persistent_state { - std::shared_ptr runner{nullptr}; - std::mutex lock; -}; - -inline persistent_state persistent{}; - -template -auto create_runner(Args... args) -> std::shared_ptr // it's ok.. pass everything by values -{ - std::lock_guard guard(persistent.lock); - // Check if the runner has already been created - std::shared_ptr runner_outer = std::dynamic_pointer_cast(persistent.runner); - if (runner_outer) { - if (runner_outer->param_hash == RunnerT::calculate_parameter_hash(args...)) { - return runner_outer; - } else { - runner_outer.reset(); - } - } - // Runner has not yet been created (or it's incompatible): - // create it in another thread and only then release the lock. - // Free the resources (if any) in advance - persistent.runner.reset(); - - cuda::std::atomic_flag ready{}; - ready.clear(cuda::std::memory_order_relaxed); - std::thread( - [&runner_outer, &ready](Args... thread_args) { // pass everything by values - // create the runner (the lock is acquired in the parent thread). - runner_outer = std::make_shared(thread_args...); - auto lifetime = runner_outer->lifetime; - persistent.runner = std::static_pointer_cast(runner_outer); - std::weak_ptr runner_weak = runner_outer; - ready.test_and_set(cuda::std::memory_order_release); - ready.notify_one(); - // NB: runner_outer is passed by reference and may be dead by this time. - - while (true) { - std::this_thread::sleep_for(lifetime); - auto runner = runner_weak.lock(); // runner_weak is local - thread-safe - if (!runner) { - return; // dead already - } - if (runner->last_touch.load(std::memory_order_relaxed) + lifetime < - std::chrono::system_clock::now()) { - std::lock_guard guard(persistent.lock); - if (runner == persistent.runner) { persistent.runner.reset(); } - return; - } - } - }, - args...) - .detach(); - ready.wait(false, cuda::std::memory_order_acquire); - return runner_outer; -} - -template -auto get_runner(Args... args) -> std::shared_ptr -{ - // Using a thread-local weak pointer allows us to avoid using locks/atomics, - // since the control block of weak/shared pointers is thread-safe. - static thread_local std::weak_ptr weak; - auto runner = weak.lock(); - if (runner) { - if (runner->param_hash == RunnerT::calculate_parameter_hash(args...)) { - return runner; - } else { - weak.reset(); - runner.reset(); - } - } - // Thread-local variable expected_latency makes sense only for a current RunnerT configuration. - // If `weak` is not alive, it's a hint the configuration has changed and we should reset our - // estimate of the expected launch latency. 
- launcher_t::expected_latency = launcher_t::kDefaultLatency; - runner = create_runner(args...); - weak = runner; - return runner; -} - -template -void select_and_run( - const dataset_descriptor_host& dataset_desc, - raft::device_matrix_view graph, - std::optional> source_indices, - uintptr_t topk_indices_ptr, // [num_queries, topk] - DistanceT* topk_distances_ptr, // [num_queries, topk] - const DataT* queries_ptr, // [num_queries, dataset_dim] - uint32_t num_queries, - const IndexT* dev_seed_ptr, // [num_queries, num_seeds] - uint32_t* num_executed_iterations, // [num_queries,] - const search_params& ps, - uint32_t topk, - uint32_t num_itopk_candidates, - uint32_t block_size, // - uint32_t smem_size, - int64_t hash_bitlen, - IndexT* hashmap_ptr, - size_t small_hash_bitlen, - size_t small_hash_reset_interval, - uint32_t num_seeds, - SampleFilterT sample_filter, - cudaStream_t stream) -{ - const SourceIndexT* source_indices_ptr = - source_indices.has_value() ? source_indices->data_handle() : nullptr; - - uint32_t max_candidates{}; - if (num_itopk_candidates <= 64) { - max_candidates = 64; - } else if (num_itopk_candidates <= 128) { - max_candidates = 128; - } else if (num_itopk_candidates <= 256) { - max_candidates = 256; - } else { - max_candidates = - 32; // irrelevant, radix based topk is used (see choose_itopk_and_max_candidates) - } - - uint32_t max_itopk{}; - assert(ps.itopk_size <= 512); - if (num_itopk_candidates <= 256) { // bitonic sort - if (ps.itopk_size <= 64) { - max_itopk = 64; - } else if (ps.itopk_size <= 128) { - max_itopk = 128; - } else if (ps.itopk_size <= 256) { - max_itopk = 256; - } else { - max_itopk = 512; - } - } else { // radix sort - if (ps.itopk_size <= 256) { - max_itopk = 256; - } else { - max_itopk = 512; - } - } - - if (ps.persistent) { - using runner_type = persistent_runner_t; - - get_runner(/* -Note, we're passing the descriptor by reference here, and this reference is going to be passed to a -new spawned thread, which is dangerous. However, the descriptor is copied in that thread before the -control is returned in this thread (in persistent_runner_t constructor), so we're safe. 
-*/ - std::cref(dataset_desc), - graph, - source_indices_ptr, - max_candidates, - num_itopk_candidates, - block_size, - smem_size, - hash_bitlen, - small_hash_bitlen, - small_hash_reset_interval, - ps.num_random_samplings, - ps.rand_xor_mask, - num_seeds, - max_itopk, - ps.itopk_size, - ps.search_width, - ps.min_iterations, - ps.max_iterations, - sample_filter, - ps.persistent_lifetime, - ps.persistent_device_usage) - ->launch(topk_indices_ptr, topk_distances_ptr, queries_ptr, num_queries, topk); - } else { - using descriptor_base_type = dataset_descriptor_base_t; - auto kernel = search_kernel_config:: - choose_itopk_and_mx_candidates(ps.itopk_size, num_itopk_candidates, block_size); - dim3 thread_dims(block_size, 1, 1); - dim3 block_dims(1, num_queries, 1); - RAFT_LOG_DEBUG( - "Launching kernel with %u threads, %u block %u smem", block_size, num_queries, smem_size); - kernel<<>>(topk_indices_ptr, - topk_distances_ptr, - topk, - dataset_desc.dev_ptr(stream), - queries_ptr, - graph.data_handle(), - graph.extent(1), - source_indices_ptr, - ps.num_random_samplings, - ps.rand_xor_mask, - dev_seed_ptr, - num_seeds, - hashmap_ptr, - max_candidates, - max_itopk, - ps.itopk_size, - ps.search_width, - ps.min_iterations, - ps.max_iterations, - num_executed_iterations, - hash_bitlen, - small_hash_bitlen, - small_hash_reset_interval, - sample_filter, - static_cast(graph.extent(0))); - RAFT_CUDA_TRY(cudaPeekAtLastError()); - } -} -} // namespace single_cta_search -} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_explicit_inst.cuh b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_explicit_inst.cuh new file mode 100644 index 0000000000..30d0adca6b --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_explicit_inst.cuh @@ -0,0 +1,9 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "search_single_cta_kernel.cuh" +#include "search_single_cta_kernel_launcher_jit.cuh" diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_common.cuh b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_common.cuh new file mode 100644 index 0000000000..b1e2191fec --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_common.cuh @@ -0,0 +1,63 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include + +namespace cuvs::neighbors::cagra::detail::single_cta_search { + +// Common logic for computing max_candidates and max_itopk +struct LaunchConfig { + uint32_t max_candidates; + uint32_t max_itopk; + bool topk_by_bitonic_sort; + bool bitonic_sort_and_merge_multi_warps; +}; + +inline LaunchConfig compute_launch_config(uint32_t num_itopk_candidates, + uint32_t itopk_size, + uint32_t block_size) +{ + LaunchConfig config{}; + + // Compute max_candidates + if (num_itopk_candidates <= 64) { + config.max_candidates = 64; + } else if (num_itopk_candidates <= 128) { + config.max_candidates = 128; + } else if (num_itopk_candidates <= 256) { + config.max_candidates = 256; + } else { + config.max_candidates = 32; // irrelevant, radix based topk is used + } + + // Compute max_itopk and sort flags + config.topk_by_bitonic_sort = (num_itopk_candidates <= 256); + config.bitonic_sort_and_merge_multi_warps = false; + + if (config.topk_by_bitonic_sort) { + if (itopk_size <= 64) { + config.max_itopk = 64; + } else if (itopk_size <= 128) { + config.max_itopk = 128; + } else if (itopk_size <= 256) { + config.max_itopk = 256; + } else { + config.max_itopk = 512; + config.bitonic_sort_and_merge_multi_warps = (block_size >= 64); + } + } else { + if (itopk_size <= 256) { + config.max_itopk = 256; + } else { + config.max_itopk = 512; + } + } + + return config; +} + +} // namespace cuvs::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuh b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuh new file mode 100644 index 0000000000..250abcb113 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuh @@ -0,0 +1,936 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "../smem_utils.cuh" + +// Include tags header before any other includes that might open namespaces +#include + +#include "compute_distance.hpp" // For dataset_descriptor_host +#include "hashmap.hpp" +#include "jit_lto_kernels/cagra_jit_launcher_factory.hpp" +#include "jit_lto_kernels/kernel_def.hpp" +#include "sample_filter_utils.cuh" // For CagraSampleFilterWithQueryIdOffset +#include "search_plan.cuh" // For search_params +#include "search_single_cta_kernel_launcher_common.cuh" +#include "shared_launcher_jit.hpp" // For shared JIT helper functions + +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace cuvs::neighbors::cagra::detail::single_cta_search { + +// Persistent queues / runner (host). worker_handle_t, job_desc_t, kCacheLineBytes, k* job limits: +// `jit_lto_kernels/search_single_cta_device_helpers.cuh` via `kernel_def.hpp`. 
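+//
+// Usage sketch for the queue declared below, adapted from the resource_queue_t documentation in the
+// pre-existing (non-JIT) launcher; the element type and capacity shown here are illustrative only.
+//
+// @code{.cpp}
+//   // allocate the queue (assuming a uint32_t payload and kMaxJobsNum slots)
+//   resource_queue_t<uint32_t, kMaxJobsNum> resource_ids;
+//
+//   // store a couple of values
+//   resource_ids.push(42);
+//   resource_ids.push(7);
+//
+//   // block until a value is available
+//   auto id_x = resource_ids.pop().wait();
+//
+//   // or take a ticket and poll without blocking
+//   auto ticket_y = resource_ids.pop();
+//   uint32_t id_y;
+//   while (!ticket_y.test(id_y)) {
+//     std::this_thread::sleep_for(std::chrono::microseconds(10));
+//   }
+//   // `id_y` is set by now and `ticket_y.wait()` won't block anymore
+//   assert(ticket_y.wait() == id_y);
+// @endcode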
+ +template ::max()> +struct alignas(kCacheLineBytes) resource_queue_t { + using value_type = T; + static constexpr uint32_t kSize = Size; + static constexpr value_type kEmpty = Empty; + static_assert(cuda::std::atomic::is_always_lock_free, + "The value type must be lock-free."); + static_assert(raft::is_a_power_of_two(kSize), "The size must be a power-of-two for efficiency."); + static constexpr uint32_t kElemsPerCacheLine = + raft::div_rounding_up_safe(kCacheLineBytes, sizeof(value_type)); + static constexpr uint32_t kCounterIncrement = raft::bound_by_power_of_two(kElemsPerCacheLine) + 1; + static constexpr uint32_t kCounterLocMask = kSize - 1; + static_assert( + kCounterIncrement * sizeof(value_type) >= kCacheLineBytes, + "The counter increment should be larger than the cache line size to avoid false sharing."); + static_assert( + std::gcd(kCounterIncrement, kSize) == 1, + "The counter increment and the size must be coprime to allow using all of the queue slots."); + + static constexpr auto kMemOrder = cuda::std::memory_order_relaxed; + + explicit resource_queue_t(uint32_t capacity = Size) noexcept : capacity_{capacity} + { + head_.store(0, kMemOrder); + tail_.store(0, kMemOrder); + for (uint32_t i = 0; i < kSize; i++) { + buf_[i].store(kEmpty, kMemOrder); + } + } + + [[nodiscard]] auto capacity() const { return capacity_; } + + void set_capacity(uint32_t capacity) { capacity_ = capacity; } + + struct promise_t { + explicit promise_t(cuda::std::atomic& loc) : loc_{loc}, val_{Empty} {} + ~promise_t() noexcept { wait(); } + + auto test() noexcept -> bool + { + if (val_ != Empty) { return true; } + val_ = loc_.exchange(kEmpty, kMemOrder); + return val_ != Empty; + } + + auto test(value_type& e) noexcept -> bool + { + if (test()) { + e = val_; + return true; + } + return false; + } + + auto wait() noexcept -> value_type + { + if (val_ == Empty) { + do { + loc_.wait(kEmpty, kMemOrder); + val_ = loc_.exchange(kEmpty, kMemOrder); + } while (val_ == Empty); + } + return val_; + } + + private: + cuda::std::atomic& loc_; + value_type val_; + }; + + void push(value_type x) noexcept + { + auto& loc = buf_[head_.fetch_add(kCounterIncrement, kMemOrder) & kCounterLocMask]; + value_type e = kEmpty; + while (!loc.compare_exchange_weak(e, x, kMemOrder, kMemOrder)) { + e = kEmpty; + } + loc.notify_one(); + } + + auto pop() noexcept -> promise_t + { + auto& loc = buf_[tail_.fetch_add(kCounterIncrement, kMemOrder) & kCounterLocMask]; + return promise_t{loc}; + } + + private: + alignas(kCacheLineBytes) cuda::std::atomic head_{}; + alignas(kCacheLineBytes) cuda::std::atomic tail_{}; + alignas(kCacheLineBytes) std::array, kSize> buf_{}; + alignas(kCacheLineBytes) uint32_t capacity_; +}; + +template +struct local_deque_t { + explicit local_deque_t(uint32_t size) : store_(size) {} + + [[nodiscard]] auto capacity() const -> uint32_t { return store_.size(); } + [[nodiscard]] auto size() const -> uint32_t { return end_ - start_; } + + void push_back(T x) { store_[end_++ % capacity()] = x; } + + void push_front(T x) + { + if (start_ == 0) { + start_ += capacity(); + end_ += capacity(); + } + store_[--start_ % capacity()] = x; + } + + auto pop_back() -> T { return store_[--end_ % capacity()]; } + auto pop_front() -> T { return store_[start_++ % capacity()]; } + + auto try_push_back(T x) -> bool + { + if (size() >= capacity()) { return false; } + push_back(x); + return true; + } + + auto try_push_front(T x) -> bool + { + if (size() >= capacity()) { return false; } + push_front(x); + return true; + } + + auto 
try_pop_back(T& x) -> bool + { + if (start_ >= end_) { return false; } + x = pop_back(); + return true; + } + + auto try_pop_front(T& x) -> bool + { + if (start_ >= end_) { return false; } + x = pop_front(); + return true; + } + + private: + std::vector store_; + uint32_t start_{0}; + uint32_t end_{0}; +}; + +struct persistent_runner_base_t { + using job_queue_type = resource_queue_t; + using worker_queue_type = resource_queue_t; + rmm::mr::pinned_host_memory_resource worker_handles_mr; + rmm::mr::pinned_host_memory_resource job_descriptor_mr; + rmm::mr::cuda_memory_resource device_mr; + cudaStream_t stream{}; + job_queue_type job_queue{}; + worker_queue_type worker_queue{}; + std::chrono::milliseconds lifetime; + + persistent_runner_base_t(float persistent_lifetime) + : lifetime(size_t(persistent_lifetime * 1000)), job_queue(), worker_queue() + { + cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking); + } + virtual ~persistent_runner_base_t() noexcept { cudaStreamDestroy(stream); }; +}; + +struct alignas(kCacheLineBytes) persistent_state { + std::shared_ptr runner{nullptr}; + std::mutex lock; +}; + +inline persistent_state persistent{}; + +// Forward declarations +template +auto get_runner_jit(Args... args) -> std::shared_ptr; + +template +auto create_runner_jit(Args... args) -> std::shared_ptr; + +// Helper functions are now in shared_launcher_jit.hpp + +// JIT-compatible launcher_t that works with worker_handle_t (same as non-JIT version) +struct alignas(kCacheLineBytes) launcher_jit_t { + using job_queue_type = resource_queue_t; + using worker_queue_type = resource_queue_t; + using pending_reads_queue_type = local_deque_t; + using completion_flag_type = cuda::atomic; + + pending_reads_queue_type pending_reads; + job_queue_type& job_ids; + worker_queue_type& idle_worker_ids; + worker_handle_t* worker_handles; + uint32_t job_id; + completion_flag_type* completion_flag; + bool all_done = false; + + static inline constexpr auto kDefaultLatency = std::chrono::nanoseconds(50000); + static inline constexpr auto kMaxExpectedLatency = + kDefaultLatency * std::max(10, kMaxJobsNum / 128); + static inline thread_local auto expected_latency = kDefaultLatency; + const std::chrono::time_point start; + std::chrono::time_point now; + const int64_t pause_factor; + int pause_count = 0; + std::chrono::time_point deadline; + + template + launcher_jit_t(job_queue_type& job_ids, + worker_queue_type& idle_worker_ids, + worker_handle_t* worker_handles, + uint32_t n_queries, + std::chrono::milliseconds max_wait_time, + RecordWork record_work) + : pending_reads{std::min(n_queries, kMaxWorkersPerThread)}, + job_ids{job_ids}, + idle_worker_ids{idle_worker_ids}, + worker_handles{worker_handles}, + job_id{job_ids.pop().wait()}, + completion_flag{record_work(job_id)}, + start{std::chrono::system_clock::now()}, + pause_factor{calc_pause_factor(n_queries)}, + now{start}, + deadline{start + max_wait_time + expected_latency} + { + submit_query(idle_worker_ids.pop().wait(), 0); + for (uint32_t i = 1; i < n_queries; i++) { + auto promised_worker = idle_worker_ids.pop(); + uint32_t worker_id; + while (!promised_worker.test(worker_id)) { + if (pending_reads.try_pop_front(worker_id)) { + bool returned_some = false; + for (bool keep_returning = true; keep_returning;) { + if (try_return_worker(worker_id)) { + keep_returning = pending_reads.try_pop_front(worker_id); + returned_some = true; + } else { + pending_reads.push_front(worker_id); + keep_returning = false; + } + } + if (!returned_some) { pause(); } + } else { + 
worker_id = promised_worker.wait(); + break; + } + } + pause_count = 0; + submit_query(worker_id, i); + if (i >= kSoftMaxWorkersPerThread && pending_reads.try_pop_front(worker_id)) { + if (!try_return_worker(worker_id)) { pending_reads.push_front(worker_id); } + } + } + } + + inline ~launcher_jit_t() noexcept + { + constexpr size_t kWindow = 100; + expected_latency = std::min( + ((kWindow - 1) * expected_latency + now - start) / kWindow, kMaxExpectedLatency); + if (job_id != job_queue_type::kEmpty) { job_ids.push(job_id); } + uint32_t worker_id; + while (pending_reads.try_pop_front(worker_id)) { + idle_worker_ids.push(worker_id); + } + } + + inline void submit_query(uint32_t worker_id, uint32_t query_id) + { + worker_handles[worker_id].data.store(worker_handle_t::data_t{.value = {job_id, query_id}}, + cuda::memory_order_relaxed); + while (!pending_reads.try_push_back(worker_id)) { + auto pending_worker_id = pending_reads.pop_front(); + while (!try_return_worker(pending_worker_id)) { + pause(); + } + } + pause_count = 0; + } + + inline auto try_return_worker(uint32_t worker_id) -> bool + { + if (all_done || + !is_worker_busy(worker_handles[worker_id].data.load(cuda::memory_order_relaxed).handle)) { + idle_worker_ids.push(worker_id); + return true; + } else { + return false; + } + } + + inline auto is_all_done() + { + if (all_done) { return true; } + all_done = completion_flag->load(cuda::memory_order_relaxed); + return all_done; + } + + [[nodiscard]] inline auto sleep_limit() const + { + constexpr auto kMinWakeTime = std::chrono::nanoseconds(10000); + constexpr double kSleepLimit = 0.6; + return start + expected_latency * kSleepLimit - kMinWakeTime; + } + + [[nodiscard]] inline auto overtime_threshold() const + { + constexpr auto kOvertimeFactor = 3; + return start + expected_latency * kOvertimeFactor; + } + + [[nodiscard]] inline auto calc_pause_factor(uint32_t n_queries) const -> uint32_t + { + constexpr uint32_t kMultiplier = 10; + return kMultiplier * raft::div_rounding_up_safe(n_queries, idle_worker_ids.capacity()); + } + + inline void pause() + { + constexpr auto kSpinLimit = 3; + constexpr auto kPauseTimeMin = std::chrono::nanoseconds(1000); + constexpr auto kPauseTimeMax = std::chrono::nanoseconds(50000); + if (pause_count++ < kSpinLimit) { + std::this_thread::yield(); + return; + } + now = std::chrono::system_clock::now(); + auto pause_time_base = std::max(now - start, expected_latency); + auto pause_time = std::clamp(pause_time_base / pause_factor, kPauseTimeMin, kPauseTimeMax); + if (now + pause_time < sleep_limit()) { + std::this_thread::sleep_for(pause_time); + } else if (now <= overtime_threshold()) { + std::this_thread::yield(); + } else if (now <= deadline) { + std::this_thread::sleep_for(pause_time); + } else { + throw raft::exception( + "The calling thread didn't receive the results from the persistent CAGRA kernel within the " + "expected kernel lifetime. 
Here are possible reasons of this failure:\n" + " (1) `persistent_lifetime` search parameter is too small - increase it;\n" + " (2) there is other work being executed on the same device and the kernel failed to " + "progress - decreasing `persistent_device_usage` may help (but not guaranteed);\n" + " (3) there is a bug in the implementation - please report it to cuVS team."); + } + } + + inline void wait() + { + uint32_t worker_id; + while (pending_reads.try_pop_front(worker_id)) { + while (!try_return_worker(worker_id)) { + if (!is_all_done()) { pause(); } + } + } + pause_count = 0; + now = std::chrono::system_clock::now(); + while (!is_all_done()) { + auto till_time = sleep_limit(); + if (now < till_time) { + std::this_thread::sleep_until(till_time); + now = std::chrono::system_clock::now(); + } else { + pause(); + } + } + job_ids.push(job_id); + job_id = job_queue_type::kEmpty; + } +}; + +// JIT persistent runner - uses AlgorithmLauncher instead of kernel function pointer +template +struct alignas(kCacheLineBytes) persistent_runner_jit_t : public persistent_runner_base_t { + using index_type = IndexT; + using distance_type = DistanceT; + using data_type = DataT; + // Must match job_desc_t> in kernel_def.hpp / persistent kernel. + using job_desc_type = job_desc_t>; + + std::shared_ptr launcher; + uint32_t block_size; + dataset_descriptor_host dd_host; + rmm::device_uvector worker_handles; + rmm::device_uvector job_descriptors; + rmm::device_uvector completion_counters; + rmm::device_uvector hashmap; + std::atomic> last_touch; + uint64_t param_hash; + cagra_bitset bitset; + + static inline auto calculate_parameter_hash( + std::reference_wrapper> dataset_desc, + raft::device_matrix_view graph, + const SourceIndexT* source_indices_ptr, + uint32_t max_candidates, + uint32_t num_itopk_candidates, + uint32_t block_size, + uint32_t smem_size, + int64_t hash_bitlen, + size_t small_hash_bitlen, + size_t small_hash_reset_interval, + uint32_t num_random_samplings, + uint64_t rand_xor_mask, + uint32_t num_seeds, + uint32_t max_itopk, + size_t itopk_size, + size_t search_width, + size_t min_iterations, + size_t max_iterations, + SampleFilterT sample_filter, + float persistent_lifetime, + float persistent_device_usage, + std::shared_ptr /* launcher_ptr - not part of hash */, + const void* /* dataset_desc - not part of hash */) -> uint64_t + { + return uint64_t(graph.data_handle()) ^ uint64_t(source_indices_ptr) ^ + dataset_desc.get().team_size ^ num_itopk_candidates ^ block_size ^ smem_size ^ + hash_bitlen ^ small_hash_reset_interval ^ num_random_samplings ^ rand_xor_mask ^ + num_seeds ^ itopk_size ^ search_width ^ min_iterations ^ max_iterations ^ + uint64_t(persistent_lifetime * 1000) ^ uint64_t(persistent_device_usage * 1000); + } + + persistent_runner_jit_t( + std::reference_wrapper> dataset_desc, + raft::device_matrix_view graph, + const SourceIndexT* source_indices_ptr, + uint32_t max_candidates, + uint32_t num_itopk_candidates, + uint32_t block_size, + uint32_t smem_size, + int64_t hash_bitlen, + size_t small_hash_bitlen, + size_t small_hash_reset_interval, + uint32_t num_random_samplings, + uint64_t rand_xor_mask, + uint32_t num_seeds, + uint32_t max_itopk, + size_t itopk_size, + size_t search_width, + size_t min_iterations, + size_t max_iterations, + SampleFilterT sample_filter, + float persistent_lifetime, + float persistent_device_usage, + std::shared_ptr launcher_ptr, + const void* /* dataset_desc - descriptor contains all needed info */) + : persistent_runner_base_t{persistent_lifetime}, 
+ launcher{launcher_ptr}, + block_size{block_size}, + worker_handles(0, stream, worker_handles_mr), + job_descriptors(kMaxJobsNum, stream, job_descriptor_mr), + completion_counters(kMaxJobsNum, stream, device_mr), + hashmap(0, stream, device_mr), + dd_host{dataset_desc.get()}, + param_hash(calculate_parameter_hash(dd_host, + graph, + source_indices_ptr, + max_candidates, + num_itopk_candidates, + block_size, + smem_size, + hash_bitlen, + small_hash_bitlen, + small_hash_reset_interval, + num_random_samplings, + rand_xor_mask, + num_seeds, + max_itopk, + itopk_size, + search_width, + min_iterations, + max_iterations, + sample_filter, + persistent_lifetime, + persistent_device_usage, + launcher_ptr, + nullptr)) // descriptor not needed in hash + { + const auto bf = extract_cagra_sample_filter(sample_filter); + this->bitset = bf.bitset; + const uint32_t query_id_offset = bf.query_id_offset; + + // set kernel launch parameters + dim3 gs = calc_coop_grid_size(block_size, smem_size, persistent_device_usage); + dim3 bs(block_size, 1, 1); + RAFT_LOG_DEBUG( + "Launching JIT persistent kernel with %u threads, %u block %u smem", bs.x, gs.y, smem_size); + + // initialize the job queue + auto* completion_counters_ptr = completion_counters.data(); + auto* job_descriptors_ptr = job_descriptors.data(); + for (uint32_t i = 0; i < kMaxJobsNum; i++) { + auto& jd = job_descriptors_ptr[i].input.value; + jd.result_indices_ptr = 0; + jd.result_distances_ptr = nullptr; + jd.queries_ptr = nullptr; + jd.top_k = 0; + jd.n_queries = 0; + job_descriptors_ptr[i].completion_flag.store(false); + job_queue.push(i); + } + + // initialize the worker queue + worker_queue.set_capacity(gs.y); + worker_handles.resize(gs.y, stream); + auto* worker_handles_ptr = worker_handles.data(); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + for (uint32_t i = 0; i < gs.y; i++) { + worker_handles_ptr[i].data.store({kWaitForWork}); + worker_queue.push(i); + } + + index_type* hashmap_ptr = nullptr; + if (small_hash_bitlen == 0) { + hashmap.resize(gs.y * hashmap::get_size(hash_bitlen), stream); + hashmap_ptr = hashmap.data(); + } + + // Prepare kernel arguments + // Get the device descriptor pointer - kernel will use the concrete type from template + const auto* dev_desc = dataset_desc.get().dev_ptr(stream); + + // Cast size_t/int64_t parameters to match kernel signature exactly + // The dispatch mechanism uses void* pointers, so parameter sizes must match exactly + const uint32_t graph_degree_u32 = static_cast(graph.extent(1)); + const uint32_t hash_bitlen_u32 = static_cast(hash_bitlen); + const uint32_t small_hash_bitlen_u32 = static_cast(small_hash_bitlen); + const uint32_t small_hash_reset_interval_u32 = static_cast(small_hash_reset_interval); + const uint32_t itopk_size_u32 = static_cast(itopk_size); + const uint32_t search_width_u32 = static_cast(search_width); + const uint32_t min_iterations_u32 = static_cast(min_iterations); + const uint32_t max_iterations_u32 = static_cast(max_iterations); + const unsigned num_random_samplings_u = static_cast(num_random_samplings); + + const IndexT* seed_ptr_arg = nullptr; + uint32_t* num_executed_iterations_arg = nullptr; + // Launch the persistent kernel via AlgorithmLauncher + // The persistent kernel now takes the descriptor pointer directly + launcher->dispatch_cooperative< + single_cta_search::search_single_cta_p_kernel_func_t>( + stream, + gs, + bs, + static_cast(smem_size), + worker_handles_ptr, + job_descriptors_ptr, + completion_counters_ptr, + graph.data_handle(), + graph_degree_u32, // 
Cast int64_t to uint32_t + source_indices_ptr, + num_random_samplings_u, // Cast uint32_t to unsigned for consistency + rand_xor_mask, // uint64_t matches kernel (8 bytes) + seed_ptr_arg, + num_seeds, + hashmap_ptr, + max_candidates, + max_itopk, + itopk_size_u32, // Cast size_t to uint32_t + search_width_u32, // Cast size_t to uint32_t + min_iterations_u32, // Cast size_t to uint32_t + max_iterations_u32, // Cast size_t to uint32_t + num_executed_iterations_arg, + hash_bitlen_u32, // Cast int64_t to uint32_t + small_hash_bitlen_u32, // Cast size_t to uint32_t + small_hash_reset_interval_u32, // Cast size_t to uint32_t + query_id_offset, // Offset to add to query_id when calling filter + dev_desc, + bitset); + + last_touch.store(std::chrono::system_clock::now(), std::memory_order_relaxed); + } + + ~persistent_runner_jit_t() noexcept override + { + auto whs = worker_handles.data(); + for (auto i = worker_handles.size(); i > 0; i--) { + whs[worker_queue.pop().wait()].data.store({kNoMoreWork}, cuda::memory_order_relaxed); + } + RAFT_CUDA_TRY_NO_THROW(cudaStreamSynchronize(stream)); + } + + void launch(uintptr_t result_indices_ptr, + distance_type* result_distances_ptr, + const data_type* queries_ptr, + uint32_t num_queries, + uint32_t top_k) + { + launcher_jit_t launcher{job_queue, + worker_queue, + worker_handles.data(), + num_queries, + this->lifetime, + [&job_descriptors = this->job_descriptors, + result_indices_ptr, + result_distances_ptr, + queries_ptr, + top_k, + num_queries](uint32_t job_ix) { + auto& jd = job_descriptors.data()[job_ix].input.value; + auto* cflag = &job_descriptors.data()[job_ix].completion_flag; + jd.result_indices_ptr = result_indices_ptr; + jd.result_distances_ptr = result_distances_ptr; + jd.queries_ptr = queries_ptr; + jd.top_k = top_k; + jd.n_queries = num_queries; + cflag->store(false, cuda::memory_order_relaxed); + cuda::atomic_thread_fence(cuda::memory_order_release, + cuda::thread_scope_system); + return cflag; + }}; + + auto prev_touch = last_touch.load(std::memory_order_relaxed); + if (prev_touch + lifetime / 10 < launcher.now) { + last_touch.store(launcher.now, std::memory_order_relaxed); + } + launcher.wait(); + } + + auto calc_coop_grid_size(uint32_t block_size, uint32_t smem_size, float persistent_device_usage) + -> dim3 + { + int ctas_per_sm = 1; + cudaKernel_t kernel_handle = launcher->get_kernel(); + RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel_handle, block_size, smem_size)); + int num_sm = raft::getMultiProcessorCount(); + auto n_blocks = static_cast(persistent_device_usage * (ctas_per_sm * num_sm)); + if (n_blocks > kMaxWorkersNum) { + RAFT_LOG_WARN("Limiting the grid size limit due to the size of the queue: %u -> %u", + n_blocks, + kMaxWorkersNum); + n_blocks = kMaxWorkersNum; + } + return {1, n_blocks, 1}; + } +}; + +template +void select_and_run( + const dataset_descriptor_host& dataset_desc, + raft::device_matrix_view graph, + std::optional> source_indices, + uintptr_t topk_indices_ptr, // [num_queries, topk] + DistanceT* topk_distances_ptr, // [num_queries, topk] + const DataT* queries_ptr, // [num_queries, dataset_dim] + uint32_t num_queries, + const IndexT* dev_seed_ptr, // [num_queries, num_seeds] + uint32_t* num_executed_iterations, // [num_queries,] + const search_params& ps, + uint32_t topk, + uint32_t num_itopk_candidates, + uint32_t block_size, // + uint32_t smem_size, + int64_t hash_bitlen, + IndexT* hashmap_ptr, + size_t small_hash_bitlen, + size_t small_hash_reset_interval, + uint32_t 
num_seeds, + SampleFilterT sample_filter, + cudaStream_t stream) +{ + const SourceIndexT* source_indices_ptr = + source_indices.has_value() ? source_indices->data_handle() : nullptr; + + const auto bf = extract_cagra_sample_filter(sample_filter); + const cagra_bitset bitset = bf.bitset; + const uint32_t query_id_offset = bf.query_id_offset; + + // Use common logic to compute launch config + auto config = compute_launch_config(num_itopk_candidates, ps.itopk_size, block_size); + uint32_t max_candidates = config.max_candidates; + uint32_t max_itopk = config.max_itopk; + bool topk_by_bitonic_sort = config.topk_by_bitonic_sort; + bool bitonic_sort_and_merge_multi_warps = config.bitonic_sort_and_merge_multi_warps; + + // Handle persistent kernels + if (ps.persistent) { + // Use persistent runner for JIT kernels + using runner_type = + persistent_runner_jit_t; + + std::shared_ptr launcher = + make_cagra_single_cta_jit_launcher>( + dataset_desc, + topk_by_bitonic_sort, + bitonic_sort_and_merge_multi_warps, + true /* persistent */); + if (!launcher) { RAFT_FAIL("Failed to get JIT launcher for CAGRA persistent search kernel"); } + + // Use get_runner pattern similar to non-JIT version + const auto* dev_desc_persistent = dataset_desc.dev_ptr(stream); + get_runner_jit(std::cref(dataset_desc), + graph, + source_indices_ptr, + max_candidates, + num_itopk_candidates, + block_size, + smem_size, + hash_bitlen, + small_hash_bitlen, + small_hash_reset_interval, + ps.num_random_samplings, + ps.rand_xor_mask, + num_seeds, + max_itopk, + ps.itopk_size, + ps.search_width, + ps.min_iterations, + ps.max_iterations, + sample_filter, + ps.persistent_lifetime, + ps.persistent_device_usage, + launcher, + dev_desc_persistent) // Pass descriptor pointer + ->launch(topk_indices_ptr, topk_distances_ptr, queries_ptr, num_queries, topk); + return; + } else { + std::shared_ptr launcher = + make_cagra_single_cta_jit_launcher>( + dataset_desc, + topk_by_bitonic_sort, + bitonic_sort_and_merge_multi_warps, + false /* persistent */); + if (!launcher) { RAFT_FAIL("Failed to get JIT launcher for CAGRA search kernel"); } + + // Get the device descriptor pointer - dev_ptr() initializes it if needed + const auto* dev_desc = dataset_desc.dev_ptr(stream); + + // Cast size_t/int64_t parameters to match kernel signature exactly + // The dispatch mechanism uses void* pointers, so parameter sizes must match exactly + const uint32_t graph_degree_u32 = static_cast(graph.extent(1)); + const uint32_t hash_bitlen_u32 = static_cast(hash_bitlen); + const uint32_t small_hash_bitlen_u32 = static_cast(small_hash_bitlen); + const uint32_t small_hash_reset_interval_u32 = static_cast(small_hash_reset_interval); + const uint32_t itopk_size_u32 = static_cast(ps.itopk_size); + const uint32_t search_width_u32 = static_cast(ps.search_width); + const uint32_t min_iterations_u32 = static_cast(ps.min_iterations); + const uint32_t max_iterations_u32 = static_cast(ps.max_iterations); + const unsigned num_random_samplings_u = static_cast(ps.num_random_samplings); + + dim3 grid(1, num_queries, 1); + dim3 block(block_size, 1, 1); + + RAFT_LOG_DEBUG("Launching JIT kernel with %u threads, %u blocks, %u smem", + block_size, + num_queries, + smem_size); + + // Dispatch kernel via launcher + auto kernel_launcher = [&]() -> void { + launcher->dispatch>( + stream, + grid, + block, + static_cast(smem_size), + topk_indices_ptr, + topk_distances_ptr, + topk, + queries_ptr, + graph.data_handle(), + graph_degree_u32, // Cast int64_t to uint32_t + source_indices_ptr, + 
num_random_samplings_u, // Cast uint32_t to unsigned for consistency + ps.rand_xor_mask, // uint64_t matches kernel (8 bytes) + dev_seed_ptr, + num_seeds, + hashmap_ptr, + max_candidates, + max_itopk, + itopk_size_u32, // Cast size_t to uint32_t + search_width_u32, // Cast size_t to uint32_t + min_iterations_u32, // Cast size_t to uint32_t + max_iterations_u32, // Cast size_t to uint32_t + num_executed_iterations, + hash_bitlen_u32, // Cast int64_t to uint32_t + small_hash_bitlen_u32, // Cast size_t to uint32_t + small_hash_reset_interval_u32, // Cast size_t to uint32_t + query_id_offset, // Offset to add to query_id when calling filter + dev_desc, + static_cast(graph.extent(0)), + bitset); + }; + + cuvs::neighbors::detail::safely_launch_kernel_with_smem_size< + search_single_cta_kernel_func_t>( + smem_size, kernel_launcher, launcher->get_kernel()); + + RAFT_CUDA_TRY(cudaPeekAtLastError()); + } +} + +// get_runner for JIT persistent runners (similar to non-JIT version) +template +auto get_runner_jit(Args... args) -> std::shared_ptr +{ + static thread_local std::weak_ptr weak; + auto runner = weak.lock(); + if (runner) { + if (runner->param_hash == RunnerT::calculate_parameter_hash(args...)) { + return runner; + } else { + weak.reset(); + runner.reset(); + } + } + launcher_jit_t::expected_latency = launcher_jit_t::kDefaultLatency; + runner = create_runner_jit(args...); + weak = runner; + return runner; +} + +template +auto create_runner_jit(Args... args) -> std::shared_ptr +{ + std::lock_guard guard(persistent.lock); + std::shared_ptr runner_outer = std::dynamic_pointer_cast(persistent.runner); + if (runner_outer) { + // calculate_parameter_hash needs all args to match constructor signature + // but only uses a subset for the actual hash + if (runner_outer->param_hash == RunnerT::calculate_parameter_hash(args...)) { + return runner_outer; + } else { + runner_outer.reset(); + } + } + persistent.runner.reset(); + + cuda::std::atomic_flag ready{}; + ready.clear(cuda::std::memory_order_relaxed); + std::thread( + [&runner_outer, &ready](Args... thread_args) { + runner_outer = std::make_shared(thread_args...); + auto lifetime = runner_outer->lifetime; + persistent.runner = std::static_pointer_cast(runner_outer); + std::weak_ptr runner_weak = runner_outer; + ready.test_and_set(cuda::std::memory_order_release); + ready.notify_one(); + + while (true) { + std::this_thread::sleep_for(lifetime); + auto runner = runner_weak.lock(); + if (!runner) { return; } + if (runner->last_touch.load(std::memory_order_relaxed) + lifetime < + std::chrono::system_clock::now()) { + std::lock_guard guard(persistent.lock); + if (runner == persistent.runner) { persistent.runner.reset(); } + return; + } + } + }, + args...) + .detach(); + ready.wait(false, cuda::std::memory_order_acquire); + return runner_outer; +} + +} // namespace cuvs::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/set_value_batch.cuh b/cpp/src/neighbors/detail/cagra/set_value_batch.cuh new file mode 100644 index 0000000000..a4433005a7 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/set_value_batch.cuh @@ -0,0 +1,40 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ +#pragma once + +#include +#include + +namespace cuvs::neighbors::cagra::detail { + +template +__global__ void set_value_batch_kernel(T* const dev_ptr, + const std::size_t ld, + const T val, + const std::size_t count, + const std::size_t batch_size) +{ + const auto tid = threadIdx.x + blockIdx.x * blockDim.x; + if (tid >= count * batch_size) { return; } + const auto batch_id = tid / count; + const auto elem_id = tid % count; + dev_ptr[elem_id + ld * batch_id] = val; +} + +template +void set_value_batch(T* const dev_ptr, + const std::size_t ld, + const T val, + const std::size_t count, + const std::size_t batch_size, + cudaStream_t cuda_stream) +{ + constexpr std::uint32_t block_size = 256; + const auto grid_size = (count * batch_size + block_size - 1) / block_size; + set_value_batch_kernel + <<>>(dev_ptr, ld, val, count, batch_size); +} + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/shared_launcher_jit.hpp b/cpp/src/neighbors/detail/cagra/shared_launcher_jit.hpp new file mode 100644 index 0000000000..1259bc6ff9 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/shared_launcher_jit.hpp @@ -0,0 +1,134 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include +#include + +#include "../../sample_filter.cuh" // For none_sample_filter, bitset_filter +#include "jit_lto_kernels/cagra_bitset.cuh" // is_bitset_filter, cagra_bitset, cagra_sample_filter, extract + +#include +#include +#include +#include +#include + +namespace cuvs::neighbors::cagra::detail { + +// Helper functions to get tags for JIT LTO +template +constexpr auto get_data_type_tag() +{ + if constexpr (std::is_same_v) { return cuvs::neighbors::detail::tag_f{}; } + if constexpr (std::is_same_v) { return cuvs::neighbors::detail::tag_h{}; } + if constexpr (std::is_same_v) { return cuvs::neighbors::detail::tag_i8{}; } + if constexpr (std::is_same_v) { return cuvs::neighbors::detail::tag_u8{}; } +} + +template +constexpr auto get_index_type_tag() +{ + if constexpr (std::is_same_v) { return cuvs::neighbors::detail::tag_index_u32{}; } +} + +template +constexpr auto get_distance_type_tag() +{ + if constexpr (std::is_same_v) { return cuvs::neighbors::cagra::detail::tag_dist_f{}; } +} + +template +constexpr auto get_source_index_type_tag() +{ + if constexpr (std::is_same_v) { return cuvs::neighbors::detail::tag_index_u32{}; } + if constexpr (std::is_same_v) { return cuvs::neighbors::detail::tag_index_i64{}; } +} + +template +struct query_type_tag_standard { + using type = std::conditional_t, + cuvs::neighbors::detail::tag_u8, + cuvs::neighbors::detail::tag_f>; +}; + +template +using query_type_tag_standard_t = typename query_type_tag_standard::type; + +template +using query_type_tag_vpq_t = cuvs::neighbors::detail::tag_h; + +template +using query_type_tag_standard_l2_t = + query_type_tag_standard_t; +template +using query_type_tag_standard_inner_product_t = + query_type_tag_standard_t; +template +using query_type_tag_standard_cosine_t = + query_type_tag_standard_t; +template +using query_type_tag_standard_hamming_t = + query_type_tag_standard_t; + +using codebook_tag_vpq_t = cuvs::neighbors::cagra::detail::tag_codebook_half; +using codebook_tag_standard_t = cuvs::neighbors::cagra::detail::tag_codebook_none; + +// Dependent false for static_assert in sample_filter_jit_tag (CAGRA JIT). 
+template +inline constexpr bool cagra_jit_sample_filter_tag_type_always_false = false; + +/// Maps a host sample-filter type to the JIT fragment tag (`cuvs::neighbors::detail::tag_filter_*`) +/// for `CagraPlannerBase<..., JitTag>` / `add_sample_filter_device_function()`. +template +struct sample_filter_jit_tag { + private: + using DecayedFilter = std::decay_t; + + struct dispatch { + template + static constexpr auto f() + { + using namespace cuvs::neighbors::filtering; + if constexpr (std::is_same_v) { + return cuvs::neighbors::detail::tag_filter_none{}; + } else if constexpr (requires { std::declval().filter; }) { + using InnerFilter = decltype(std::declval().filter); + if constexpr (is_bitset_filter::value || + std::is_same_v> || + std::is_same_v>) { + return cuvs::neighbors::detail::tag_filter_bitset{}; + } else { + static_assert( + cagra_jit_sample_filter_tag_type_always_false, + "CAGRA JIT: sample_filter_jit_tag does not know how to link this filter. " + "CagraSampleFilterWithQueryIdOffset requires Inner of type " + "bitset_filter (see cagra_bitset.cuh is_bitset_filter and sample_filter_utils.cuh). " + "For a new filter kind, add a sample_filter_jit_tag branch. " + "(SAMPLE_FILTER_T in error; check InnerFilter in compiler output.)"); + } + } else { + static_assert( + cagra_jit_sample_filter_tag_type_always_false, + "CAGRA JIT: sample_filter_jit_tag: SAMPLE_FILTER_T must be cuvs::neighbors::filtering::" + "none_sample_filter, or " + "cuvs::neighbors::cagra::detail::CagraSampleFilterWithQueryIdOffset<" + "bitset_filter>. Unknown wrapper type. " + "(SAMPLE_FILTER_T in error; add a branch in sample_filter_jit_tag.)"); + } + } + }; + + public: + using type = decltype(dispatch::template f()); +}; + +template +using sample_filter_jit_tag_t = typename sample_filter_jit_tag::type; + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/neighbors_device_intrinsics.cuh b/cpp/src/neighbors/detail/neighbors_device_intrinsics.cuh new file mode 100644 index 0000000000..89ac0e8271 --- /dev/null +++ b/cpp/src/neighbors/detail/neighbors_device_intrinsics.cuh @@ -0,0 +1,88 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include +#include +#include + +#include + +#include +#include + +namespace cuvs::neighbors::detail::device { + +// warpSize for compile time calculation +constexpr unsigned warp_size = 32; + +// using LOAD_256BIT_T = ulonglong4; +using LOAD_128BIT_T = uint4; +using LOAD_64BIT_T = uint64_t; + +template +RAFT_DEVICE_INLINE_FUNCTION constexpr unsigned get_vlen() +{ + static_assert(sizeof(DATA_T) > 0, "get_vlen: DATA_T must have positive size"); + return static_cast(sizeof(LOAD_T) / sizeof(DATA_T)); +} + +/** Xorshift rondem number generator. + * + * See https://en.wikipedia.org/wiki/Xorshift#xorshift for reference. 
+ */ +_RAFT_HOST_DEVICE inline uint64_t xorshift64(uint64_t u) +{ + u ^= u >> 12; + u ^= u << 25; + u ^= u >> 27; + return u * 0x2545F4914F6CDD1DULL; +} + +template +RAFT_DEVICE_INLINE_FUNCTION auto team_sum(T x) -> T +{ +#pragma unroll + for (uint32_t stride = TeamSize >> 1; stride > 0; stride >>= 1) { + x += raft::shfl_xor(x, stride, TeamSize); + } + return x; +} + +template +RAFT_DEVICE_INLINE_FUNCTION auto team_sum(T x, uint32_t team_size_bitshift) -> T +{ + switch (team_size_bitshift) { + case 5: x += raft::shfl_xor(x, 16); [[fallthrough]]; + case 4: x += raft::shfl_xor(x, 8); [[fallthrough]]; + case 3: x += raft::shfl_xor(x, 4); [[fallthrough]]; + case 2: x += raft::shfl_xor(x, 2); [[fallthrough]]; + case 1: x += raft::shfl_xor(x, 1); [[fallthrough]]; + default: return x; + } +} + +template +RAFT_DEVICE_INLINE_FUNCTION constexpr auto swizzling(T x) -> T +{ + // Address swizzling reduces bank conflicts in shared memory, but increases + // the amount of operation instead. + if constexpr (Stride <= 32) { + return x; + } else if constexpr (Dim <= 1024) { + return x ^ (x >> 5); + } else { + return x ^ ((x >> 5) & 0x1f); + } +} + +} // namespace cuvs::neighbors::detail::device + +// CAGRA JIT kernels extend `cuvs::neighbors::cagra::detail::device` in other headers; re-export +// the shared helpers there under the historical nested name. +namespace cuvs::neighbors::cagra::detail::device { +using namespace cuvs::neighbors::detail::device; +} // namespace cuvs::neighbors::cagra::detail::device diff --git a/cpp/src/neighbors/detail/nn_descent.cuh b/cpp/src/neighbors/detail/nn_descent.cuh index 2e50aefc15..19a86dba56 100644 --- a/cpp/src/neighbors/detail/nn_descent.cuh +++ b/cpp/src/neighbors/detail/nn_descent.cuh @@ -6,7 +6,7 @@ #pragma once #include "ann_utils.cuh" -#include "cagra/device_common.hpp" +#include "neighbors_device_intrinsics.cuh" #include "nn_descent_gnnd.hpp" #include "../../core/omp_wrapper.hpp" @@ -1665,7 +1665,7 @@ void GNND::build(Data_t* data, graph_shrink_buffer[i * build_config_.node_degree + j] = id; } else { graph_shrink_buffer[i * build_config_.node_degree + j] = - cuvs::neighbors::cagra::detail::device::xorshift64(idx) % nrow_; + cuvs::neighbors::detail::device::xorshift64(idx) % nrow_; } } } diff --git a/cpp/src/neighbors/detail/sample_filter_data.cuh b/cpp/src/neighbors/detail/sample_filter_data.cuh new file mode 100644 index 0000000000..4c99ca1e3a --- /dev/null +++ b/cpp/src/neighbors/detail/sample_filter_data.cuh @@ -0,0 +1,24 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include +#include + +namespace cuvs::neighbors::detail { + +/// Bitset (and length metadata) for linked @c sample_filter in JIT LTO; passed by value to +/// @c __global__ entry points. `bitset_ptr == nullptr` means no bitset (none filter, or not a +/// bitset at runtime on host). Also passed as @c void* to the unified JIT @c sample_filter (see +/// @c sample_filter.cuh / @c bitset_filter). 
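+///
+/// Illustrative host-side population (the field names are those of this struct; the values and
+/// their sources are placeholders, not a specific cuVS API):
+///
+///   bitset_filter_data_t<int64_t> fd{};
+///   fd.bitset_ptr     = filter_words;  // device pointer to the packed 32-bit words, or nullptr
+///   fd.bitset_len     = filter_len;    // length metadata of the linked bitset
+///   fd.original_nbits = nbits;         // bit width the bitset was originally created with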
+template +struct bitset_filter_data_t { + std::uint32_t* bitset_ptr{nullptr}; + SourceIndexT bitset_len{}; + SourceIndexT original_nbits{}; +}; + +} // namespace cuvs::neighbors::detail diff --git a/cpp/src/neighbors/detail/smem_utils.cuh b/cpp/src/neighbors/detail/smem_utils.cuh index 73bc8c578d..8ad1500570 100644 --- a/cpp/src/neighbors/detail/smem_utils.cuh +++ b/cpp/src/neighbors/detail/smem_utils.cuh @@ -6,68 +6,93 @@ #include #include +#include #include #include namespace cuvs::neighbors::detail { +/** Smem high-water + last CUDA kernel handle for one @p KernelT. Handle as uint64_t bits (not + * std::atomic) for portable lock-free atomics. */ +template +struct jit_kernel_smem_state { + std::atomic last_smem_size{0}; + std::atomic last_cuda_kernel_bits{0}; + std::mutex mutex; +}; + +/** One state object per @p KernelT (Meyers singleton). Avoids `static inline` data members in a + * class template, which can pull in 16-byte libatomic helpers under NVCC/host linking. */ +template +jit_kernel_smem_state& jit_kernel_smem_state_for() noexcept +{ + static jit_kernel_smem_state state; + return state; +} + /** * @brief (Thread-)Safely invoke a kernel with a maximum dynamic shared memory size. * * Maintains a monotonically growing high-water mark for - * `cudaFuncAttributeMaxDynamicSharedMemorySize`. When the kernel function pointer changes, the new - * kernel is brought up to the current high-water mark; when smem_size exceeds the high-water mark, - * it is grown for the current kernel. This guarantees every kernel's attribute is always >= + * `cudaFuncAttributeMaxDynamicSharedMemorySize`. When the kernel identity changes, the new kernel + * is brought up to the current high-water mark; when @p smem_size exceeds the high-water mark, it + * is grown for the current kernel. This guarantees every kernel's attribute is always >= @p * smem_size at the time of launch. * - * NB: cudaFuncSetAttribute is per kernel function pointer value, not per type. Multiple kernel - * template instantiations may share the same KernelT type (e.g. function pointers with the same - * signature), so we track the kernel identity alongside the smem high-water mark. + * This is required because the sequence `cudaFuncSetAttribute` + kernel launch is not executed + * atomically. Used this way, `cudaFuncAttributeMaxDynamicSharedMemorySize` can only grow and the + * kernel remains safe to launch. * - * @tparam KernelT The type of the kernel. - * @tparam KernelLauncherT The type of the launch function/lambda. - * @param kernel The kernel function address (for whom the smem-size is specified). - * @param smem_size The size of the dynamic shared memory to be set. - * @param launch The kernel launch function/lambda. + * NB: cudaFuncSetAttribute is per kernel handle value, not per C++ type. Multiple template + * instantiations may share the same @p KernelT (e.g. the same function signature), so we track the + * last @p cuda_kernel handle (as opaque bits) alongside the smem high-water mark. + * + * @tparam KernelT Kernel function type from kernel_def.hpp (keys static state per signature). + * @tparam KernelLauncherT Type of the launch callable (e.g. lambda calling launcher->dispatch). + * @param smem_size Dynamic shared memory required for this launch. + * @param launch Invoked after attributes are set; takes no arguments. + * @param cuda_kernel Handle passed to cudaFuncSetAttribute (e.g. launcher->get_kernel()). 
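+ *
+ * Typical call shape, mirroring the CAGRA JIT launch path added in this change (template
+ * arguments and kernel parameters elided):
+ *
+ *   safely_launch_kernel_with_smem_size<KernelFuncT>(
+ *     smem_size,
+ *     [&] { launcher->dispatch<KernelFuncT>(stream, grid, block, smem_size, ...); },
+ *     launcher->get_kernel());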
*/ template -void safely_launch_kernel_with_smem_size(KernelT const& kernel, - uint32_t smem_size, - KernelLauncherT const& launch) +void safely_launch_kernel_with_smem_size(std::uint32_t smem_size, + KernelLauncherT const& launch, + cudaKernel_t cuda_kernel) { - // last_smem_size is a monotonically growing high-water mark across all kernel pointers. - // last_kernel tracks which kernel pointer was last used. - static std::atomic last_smem_size{0}; - static std::atomic last_kernel{KernelT{}}; - static std::mutex mutex; - // Fast path: skip the lock when the kernel matches and the smem size is within bounds. - // Load order matters: last_smem_size (acquire) before last_kernel (relaxed). Inside the lock - // we store in the opposite order: last_kernel (relaxed) then last_smem_size (release). - // This way an acquire load of last_smem_size that sees a post-cudaFuncSetAttribute value is - // guaranteed to also see the corresponding last_kernel. - if (smem_size > last_smem_size.load(std::memory_order_acquire) || - kernel != last_kernel.load(std::memory_order_relaxed)) { - std::lock_guard guard(mutex); + auto& st = jit_kernel_smem_state_for(); + + std::uint64_t const current_bits = + static_cast(reinterpret_cast(cuda_kernel)); + + // Fast path: skip the lock when the kernel handle matches and smem is within bounds. + // Load order matters: last_smem_size (acquire) before last_cuda_kernel_bits (relaxed). Inside + // the lock we store in the opposite order: last_cuda_kernel_bits (relaxed) then last_smem_size + // (release). This way an acquire load of last_smem_size that sees a post-cudaFuncSetAttribute + // value is guaranteed to also see the corresponding handle bits. + std::uint32_t observed_smem = st.last_smem_size.load(std::memory_order_acquire); + std::uint64_t observed_bits = st.last_cuda_kernel_bits.load(std::memory_order_relaxed); + if (smem_size > observed_smem || current_bits != observed_bits) { + std::lock_guard guard(st.mutex); // Re-check under the lock: the outside decision can be stale. - uint32_t cur_smem_size = last_smem_size.load(std::memory_order_relaxed); - bool need_update = (kernel != last_kernel.load(std::memory_order_relaxed)); + std::uint32_t cur_smem_size = st.last_smem_size.load(std::memory_order_relaxed); + observed_bits = st.last_cuda_kernel_bits.load(std::memory_order_relaxed); + bool need_update = (current_bits != observed_bits); if (smem_size > cur_smem_size) { cur_smem_size = smem_size; need_update = true; } if (need_update) { - auto launch_status = - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, cur_smem_size); + auto launch_status = cudaFuncSetAttribute( + cuda_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, cur_smem_size); RAFT_EXPECTS(launch_status == cudaSuccess, "Failed to set max dynamic shared memory size to %u bytes", cur_smem_size); - // Store order matters: last_kernel before last_smem_size (release) so the fast-path - // acquire load of last_smem_size also publishes last_kernel. - last_kernel.store(kernel, std::memory_order_relaxed); - last_smem_size.store(cur_smem_size, std::memory_order_release); + // Store order matters: handle bits before last_smem_size (release) so the fast-path acquire + // load of last_smem_size also publishes the handle. 
+ st.last_cuda_kernel_bits.store(current_bits, std::memory_order_relaxed); + st.last_smem_size.store(cur_smem_size, std::memory_order_release); } } - return launch(kernel); + return launch(); } } // namespace cuvs::neighbors::detail diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_jit.cuh b/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_jit.cuh index c4113a83ce..a9748ef836 100644 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_jit.cuh +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_jit.cuh @@ -28,7 +28,6 @@ #include namespace cuvs::neighbors::ivf_flat::detail { - static constexpr int kThreadsPerBlock = 128; using namespace cuvs::spatial::knn::detail; // NOLINT @@ -37,10 +36,10 @@ using namespace cuvs::spatial::knn::detail; // NOLINT template constexpr auto get_data_type_tag() { - if constexpr (std::is_same_v) { return tag_f{}; } - if constexpr (std::is_same_v) { return tag_h{}; } - if constexpr (std::is_same_v) { return tag_i8{}; } - if constexpr (std::is_same_v) { return tag_u8{}; } + if constexpr (std::is_same_v) { return cuvs::neighbors::detail::tag_f{}; } + if constexpr (std::is_same_v) { return cuvs::neighbors::detail::tag_h{}; } + if constexpr (std::is_same_v) { return cuvs::neighbors::detail::tag_i8{}; } + if constexpr (std::is_same_v) { return cuvs::neighbors::detail::tag_u8{}; } } template @@ -203,6 +202,9 @@ void launch_kernel(const index& index, return; } + // Pass individual filter parameters like CAGRA does + // The kernel will construct filter_data struct internally when needed + for (uint32_t query_offset = 0; query_offset < num_queries; query_offset += kMaxGridY) { uint32_t grid_dim_y = std::min(kMaxGridY, num_queries - query_offset); dim3 grid_dim(grid_dim_x, grid_dim_y, 1); diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_compute_similarity_impl.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_compute_similarity_impl.cuh index c4b2031e54..8aa8ce3c30 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_compute_similarity_impl.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_compute_similarity_impl.cuh @@ -93,9 +93,9 @@ auto get_lut_type_tag() return tag_lut_f{}; } else if constexpr (std::is_same_v) { return tag_lut_h{}; - } else if constexpr (std::is_same_v>) { + } else if constexpr (std::is_same_v>) { return tag_lut_fp8_unsigned{}; - } else if constexpr (std::is_same_v>) { + } else if constexpr (std::is_same_v>) { return tag_lut_fp8_signed{}; } else { static_assert(sizeof(LutT) == 0, "Unsupported LutT type");