From 75ceb6f66f411acb0edbf5b5fcc41ae616ad3c47 Mon Sep 17 00:00:00 2001 From: vic Date: Fri, 27 Feb 2026 18:10:46 +0100 Subject: [PATCH 01/31] IVF-SQ --- cpp/CMakeLists.txt | 6 + cpp/bench/ann/CMakeLists.txt | 11 +- .../src/cuvs/cuvs_ann_bench_param_parser.h | 25 + cpp/bench/ann/src/cuvs/cuvs_benchmark.cu | 20 +- cpp/bench/ann/src/cuvs/cuvs_ivf_sq.cu | 10 + cpp/bench/ann/src/cuvs/cuvs_ivf_sq_wrapper.h | 141 ++++ cpp/include/cuvs/neighbors/ivf_sq.hpp | 336 +++++++++ cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh | 664 ++++++++++++++++++ ...f_sq_build_extend_float_uint8_t_int64_t.cu | 89 +++ ...vf_sq_build_extend_half_uint8_t_int64_t.cu | 89 +++ cpp/src/neighbors/ivf_sq/ivf_sq_search.cuh | 549 +++++++++++++++ .../ivf_sq_search_float_uint8_t_int64_t.cu | 29 + .../ivf_sq_search_half_uint8_t_int64_t.cu | 29 + cpp/src/neighbors/ivf_sq/ivf_sq_serialize.cuh | 161 +++++ .../ivf_sq/ivf_sq_serialize_uint8_t.cu | 16 + cpp/src/neighbors/ivf_sq_index.cpp | 236 +++++++ cpp/tests/CMakeLists.txt | 7 + cpp/tests/neighbors/ann_ivf_sq.cuh | 457 ++++++++++++ .../ann_ivf_sq/test_float_uint8_t.cu | 21 + .../cuvs_bench/config/algorithms.yaml | 3 + .../cuvs_bench/config/algos/cuvs_ivf_sq.yaml | 16 + 21 files changed, 2913 insertions(+), 2 deletions(-) create mode 100644 cpp/bench/ann/src/cuvs/cuvs_ivf_sq.cu create mode 100644 cpp/bench/ann/src/cuvs/cuvs_ivf_sq_wrapper.h create mode 100644 cpp/include/cuvs/neighbors/ivf_sq.hpp create mode 100644 cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh create mode 100644 cpp/src/neighbors/ivf_sq/ivf_sq_build_extend_float_uint8_t_int64_t.cu create mode 100644 cpp/src/neighbors/ivf_sq/ivf_sq_build_extend_half_uint8_t_int64_t.cu create mode 100644 cpp/src/neighbors/ivf_sq/ivf_sq_search.cuh create mode 100644 cpp/src/neighbors/ivf_sq/ivf_sq_search_float_uint8_t_int64_t.cu create mode 100644 cpp/src/neighbors/ivf_sq/ivf_sq_search_half_uint8_t_int64_t.cu create mode 100644 cpp/src/neighbors/ivf_sq/ivf_sq_serialize.cuh create mode 100644 cpp/src/neighbors/ivf_sq/ivf_sq_serialize_uint8_t.cu create mode 100644 cpp/src/neighbors/ivf_sq_index.cpp create mode 100644 cpp/tests/neighbors/ann_ivf_sq.cuh create mode 100644 cpp/tests/neighbors/ann_ivf_sq/test_float_uint8_t.cu create mode 100644 python/cuvs_bench/cuvs_bench/config/algos/cuvs_ivf_sq.yaml diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index d90579812a..610b9eff3f 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -621,6 +621,12 @@ if(NOT BUILD_CPU_ONLY) src/neighbors/ivf_pq/detail/ivf_pq_transform_half_int64_t.cu src/neighbors/ivf_pq/detail/ivf_pq_transform_int8_t_int64_t.cu src/neighbors/ivf_pq/detail/ivf_pq_transform_uint8_t_int64_t.cu + src/neighbors/ivf_sq_index.cpp + src/neighbors/ivf_sq/ivf_sq_build_extend_float_uint8_t_int64_t.cu + src/neighbors/ivf_sq/ivf_sq_build_extend_half_uint8_t_int64_t.cu + src/neighbors/ivf_sq/ivf_sq_search_float_uint8_t_int64_t.cu + src/neighbors/ivf_sq/ivf_sq_search_half_uint8_t_int64_t.cu + src/neighbors/ivf_sq/ivf_sq_serialize_uint8_t.cu src/neighbors/knn_merge_parts.cu src/neighbors/nn_descent.cu src/neighbors/nn_descent_float.cu diff --git a/cpp/bench/ann/CMakeLists.txt b/cpp/bench/ann/CMakeLists.txt index 8d254c0933..ae42abeb35 100644 --- a/cpp/bench/ann/CMakeLists.txt +++ b/cpp/bench/ann/CMakeLists.txt @@ -1,6 +1,6 @@ # ============================================================================= # cmake-format: off -# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. # SPDX-License-Identifier: Apache-2.0 # cmake-format: on # ============================================================================= @@ -24,6 +24,7 @@ option(CUVS_ANN_BENCH_USE_FAISS_CPU_IVF_FLAT "Include faiss' cpu ivf flat algori option(CUVS_ANN_BENCH_USE_FAISS_CPU_IVF_PQ "Include faiss' cpu ivf pq algorithm in benchmark" ON) option(CUVS_ANN_BENCH_USE_FAISS_CPU_HNSW_FLAT "Include faiss' hnsw algorithm in benchmark" ON) option(CUVS_ANN_BENCH_USE_CUVS_IVF_FLAT "Include cuVS ivf flat algorithm in benchmark" ON) +option(CUVS_ANN_BENCH_USE_CUVS_IVF_SQ "Include cuVS ivf sq algorithm in benchmark" ON) option(CUVS_ANN_BENCH_USE_CUVS_IVF_PQ "Include cuVS ivf pq algorithm in benchmark" ON) option(CUVS_ANN_BENCH_USE_CUVS_CAGRA "Include cuVS CAGRA in benchmark" ON) option(CUVS_ANN_BENCH_USE_CUVS_BRUTE_FORCE "Include cuVS brute force knn in benchmark" ON) @@ -80,6 +81,7 @@ set(CUVS_USE_FAISS_STATIC ON) if(BUILD_CPU_ONLY) set(CUVS_FAISS_ENABLE_GPU OFF) set(CUVS_ANN_BENCH_USE_CUVS_IVF_FLAT OFF) + set(CUVS_ANN_BENCH_USE_CUVS_IVF_SQ OFF) set(CUVS_ANN_BENCH_USE_CUVS_IVF_PQ OFF) set(CUVS_ANN_BENCH_USE_CUVS_CAGRA OFF) set(CUVS_ANN_BENCH_USE_CUVS_BRUTE_FORCE OFF) @@ -97,6 +99,7 @@ set(CUVS_ANN_BENCH_USE_CUVS OFF) if(CUVS_ANN_BENCH_USE_CUVS_IVF_PQ OR CUVS_ANN_BENCH_USE_CUVS_BRUTE_FORCE OR CUVS_ANN_BENCH_USE_CUVS_IVF_FLAT + OR CUVS_ANN_BENCH_USE_CUVS_IVF_SQ OR CUVS_ANN_BENCH_USE_CUVS_CAGRA OR CUVS_ANN_BENCH_USE_CUVS_CAGRA_HNSWLIB OR CUVS_KNN_BENCH_USE_CUVS_BRUTE_FORCE @@ -242,6 +245,12 @@ if(CUVS_ANN_BENCH_USE_CUVS_IVF_FLAT) ) endif() +if(CUVS_ANN_BENCH_USE_CUVS_IVF_SQ) + ConfigureAnnBench( + NAME CUVS_IVF_SQ PATH src/cuvs/cuvs_benchmark.cu src/cuvs/cuvs_ivf_sq.cu LINKS cuvs + ) +endif() + if(CUVS_ANN_BENCH_USE_CUVS_BRUTE_FORCE) ConfigureAnnBench(NAME CUVS_BRUTE_FORCE PATH src/cuvs/cuvs_benchmark.cu LINKS cuvs) endif() diff --git a/cpp/bench/ann/src/cuvs/cuvs_ann_bench_param_parser.h b/cpp/bench/ann/src/cuvs/cuvs_ann_bench_param_parser.h index faa3345d1f..4bc7505dc4 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_ann_bench_param_parser.h +++ b/cpp/bench/ann/src/cuvs/cuvs_ann_bench_param_parser.h @@ -35,6 +35,11 @@ extern template class cuvs::bench::cuvs_cagra; extern template class cuvs::bench::cuvs_cagra; #endif +#ifdef CUVS_ANN_BENCH_USE_CUVS_IVF_SQ +#include "cuvs_ivf_sq_wrapper.h" +extern template class cuvs::bench::cuvs_ivf_sq; +extern template class cuvs::bench::cuvs_ivf_sq; +#endif #ifdef CUVS_ANN_BENCH_USE_CUVS_MG #include "cuvs_ivf_flat_wrapper.h" #include "cuvs_mg_ivf_flat_wrapper.h" @@ -86,6 +91,26 @@ void parse_search_param(const nlohmann::json& conf, } #endif +#ifdef CUVS_ANN_BENCH_USE_CUVS_IVF_SQ +template +void parse_build_param(const nlohmann::json& conf, + typename cuvs::bench::cuvs_ivf_sq::build_param& param) +{ + param.n_lists = conf.at("nlist"); + if (conf.contains("niter")) { param.kmeans_n_iters = conf.at("niter"); } + if (conf.contains("ratio")) { + param.kmeans_trainset_fraction = 1.0 / static_cast(conf.at("ratio")); + } +} + +template +void parse_search_param(const nlohmann::json& conf, + typename cuvs::bench::cuvs_ivf_sq::search_param& param) +{ + param.ivf_sq_params.n_probes = conf.at("nprobe"); +} +#endif + #if defined(CUVS_ANN_BENCH_USE_CUVS_IVF_PQ) || defined(CUVS_ANN_BENCH_USE_CUVS_CAGRA) || \ defined(CUVS_ANN_BENCH_USE_CUVS_CAGRA_HNSWLIB) || defined(CUVS_ANN_BENCH_USE_CUVS_MG) || \ defined(CUVS_ANN_BENCH_USE_CUVS_CAGRA_DISKANN) diff --git a/cpp/bench/ann/src/cuvs/cuvs_benchmark.cu b/cpp/bench/ann/src/cuvs/cuvs_benchmark.cu index aebac654c2..22aeb31c38 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_benchmark.cu +++ b/cpp/bench/ann/src/cuvs/cuvs_benchmark.cu @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -84,6 +84,15 @@ auto create_algo(const std::string& algo_name, } } #endif +#ifdef CUVS_ANN_BENCH_USE_CUVS_IVF_SQ + if constexpr (std::is_same_v || std::is_same_v) { + if (algo_name == "cuvs_ivf_sq") { + typename cuvs::bench::cuvs_ivf_sq::build_param param; + parse_build_param(conf, param); + a = std::make_unique>(metric, dim, param); + } + } +#endif #ifdef CUVS_ANN_BENCH_USE_CUVS_IVF_PQ if (algo_name == "raft_ivf_pq" || algo_name == "cuvs_ivf_pq") { typename cuvs::bench::cuvs_ivf_pq::build_param param; @@ -151,6 +160,15 @@ auto create_search_param(const std::string& algo_name, const nlohmann::json& con } } #endif +#ifdef CUVS_ANN_BENCH_USE_CUVS_IVF_SQ + if constexpr (std::is_same_v || std::is_same_v) { + if (algo_name == "cuvs_ivf_sq") { + auto param = std::make_unique::search_param>(); + parse_search_param(conf, *param); + return param; + } + } +#endif #ifdef CUVS_ANN_BENCH_USE_CUVS_IVF_PQ if (algo_name == "raft_ivf_pq" || algo_name == "cuvs_ivf_pq") { auto param = std::make_unique::search_param>(); diff --git a/cpp/bench/ann/src/cuvs/cuvs_ivf_sq.cu b/cpp/bench/ann/src/cuvs/cuvs_ivf_sq.cu new file mode 100644 index 0000000000..ec41324c8d --- /dev/null +++ b/cpp/bench/ann/src/cuvs/cuvs_ivf_sq.cu @@ -0,0 +1,10 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ +#include "cuvs_ivf_sq_wrapper.h" + +namespace cuvs::bench { +template class cuvs_ivf_sq; +template class cuvs_ivf_sq; +} // namespace cuvs::bench diff --git a/cpp/bench/ann/src/cuvs/cuvs_ivf_sq_wrapper.h b/cpp/bench/ann/src/cuvs/cuvs_ivf_sq_wrapper.h new file mode 100644 index 0000000000..1503e6bb84 --- /dev/null +++ b/cpp/bench/ann/src/cuvs/cuvs_ivf_sq_wrapper.h @@ -0,0 +1,141 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ +#pragma once + +#include "../common/ann_types.hpp" +#include "cuvs_ann_bench_utils.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace cuvs::bench { + +template +class cuvs_ivf_sq : public algo, public algo_gpu { + public: + using search_param_base = typename algo::search_param; + + struct search_param : public search_param_base { + cuvs::neighbors::ivf_sq::search_params ivf_sq_params; + }; + + using build_param = cuvs::neighbors::ivf_sq::index_params; + + cuvs_ivf_sq(Metric metric, int dim, const build_param& param) + : algo(metric, dim), index_params_(param), dimension_(dim) + { + index_params_.metric = parse_metric_type(metric); + index_params_.conservative_memory_allocation = true; + RAFT_CUDA_TRY(cudaGetDevice(&device_)); + } + + void build(const T* dataset, size_t nrow) final; + + void set_search_param(const search_param_base& param, const void* filter_bitset) override; + + void search(const T* queries, + int batch_size, + int k, + algo_base::index_type* neighbors, + float* distances) const override; + + [[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override + { + return handle_.get_sync_stream(); + } + + [[nodiscard]] auto get_preference() const -> algo_property override + { + algo_property property; + property.dataset_memory_type = MemoryType::kHostMmap; + property.query_memory_type = MemoryType::kDevice; + return property; + } + + void save(const std::string& file) const override; + void load(const std::string&) override; + std::unique_ptr> copy() override; + + private: + configured_raft_resources handle_{}; + build_param index_params_; + cuvs::neighbors::ivf_sq::search_params search_params_; + std::shared_ptr> index_; + int device_; + int dimension_; + + std::shared_ptr filter_; +}; + +template +void cuvs_ivf_sq::build(const T* dataset, size_t nrow) +{ + size_t n_streams = 1; + raft::resource::set_cuda_stream_pool(handle_, std::make_shared(n_streams)); + index_ = std::make_shared>( + std::move(cuvs::neighbors::ivf_sq::build( + handle_, + index_params_, + raft::make_host_matrix_view(dataset, nrow, dimension_)))); +} + +template +void cuvs_ivf_sq::set_search_param(const search_param_base& param, const void* filter_bitset) +{ + filter_ = make_cuvs_filter(filter_bitset, index_->size()); + auto sp = dynamic_cast(param); + search_params_ = sp.ivf_sq_params; + assert(search_params_.n_probes <= index_params_.n_lists); +} + +template +void cuvs_ivf_sq::save(const std::string& file) const +{ + cuvs::neighbors::ivf_sq::serialize(handle_, file, *index_); +} + +template +void cuvs_ivf_sq::load(const std::string& file) +{ + index_ = + std::make_shared>(handle_, index_params_, this->dim_); + cuvs::neighbors::ivf_sq::deserialize(handle_, file, index_.get()); +} + +template +std::unique_ptr> cuvs_ivf_sq::copy() +{ + return std::make_unique>(*this); +} + +template +void cuvs_ivf_sq::search( + const T* queries, int batch_size, int k, algo_base::index_type* neighbors, float* distances) const +{ + static_assert(sizeof(algo_base::index_type) == sizeof(int64_t)); + + cuvs::neighbors::ivf_sq::search( + handle_, + search_params_, + *index_, + raft::make_device_matrix_view(queries, batch_size, index_->dim()), + raft::make_device_matrix_view( + reinterpret_cast(neighbors), batch_size, k), + raft::make_device_matrix_view(distances, batch_size, k), + *filter_); +} + +} // namespace cuvs::bench diff --git a/cpp/include/cuvs/neighbors/ivf_sq.hpp b/cpp/include/cuvs/neighbors/ivf_sq.hpp new file mode 100644 index 0000000000..2f09751e95 --- /dev/null +++ b/cpp/include/cuvs/neighbors/ivf_sq.hpp @@ -0,0 +1,336 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "common.hpp" +#include +#include +#include +#include +#include +#include +#include + +namespace cuvs::neighbors::ivf_sq { + +/** + * @defgroup ivf_sq_cpp_index_params IVF-SQ index build parameters + * @{ + */ + +constexpr static uint32_t kIndexGroupSize = 32; + +struct index_params : cuvs::neighbors::index_params { + uint32_t n_lists = 1024; + uint32_t kmeans_n_iters = 20; + double kmeans_trainset_fraction = 0.5; + bool adaptive_centers = false; + bool conservative_memory_allocation = false; + bool add_data_on_build = true; +}; + +struct search_params : cuvs::neighbors::search_params { + uint32_t n_probes = 20; +}; + +static_assert(std::is_aggregate_v); +static_assert(std::is_aggregate_v); + +/** + * @} + */ + +/** + * @defgroup ivf_sq_cpp_list_spec IVF-SQ list storage spec + * @{ + */ + +template +struct list_spec { + static_assert(std::is_same_v, "IVF-SQ code type IdxT must be uint8_t"); + + using value_type = IdxT; + using list_extents = raft::matrix_extent; + using index_type = ExtT; + + SizeT align_max; + SizeT align_min; + uint32_t dim; + + constexpr list_spec(uint32_t dim, bool conservative_memory_allocation) + : dim(dim), + align_min(kIndexGroupSize), + align_max(conservative_memory_allocation ? kIndexGroupSize : 1024) + { + } + + template + constexpr explicit list_spec(const list_spec& other_spec) + : dim{other_spec.dim}, align_min{other_spec.align_min}, align_max{other_spec.align_max} + { + } + + static constexpr uint32_t kVecLen = 16; + + constexpr auto make_list_extents(SizeT n_rows) const -> list_extents + { + uint32_t padded = ((dim + kVecLen - 1) / kVecLen) * kVecLen; + return raft::make_extents(n_rows, padded); + } +}; + +template +using list_data = ivf::list; + +/** + * @} + */ + +/** + * @defgroup ivf_sq_cpp_index IVF-SQ index + * @{ + */ + +/** + * @brief IVF-SQ index. + * + * @tparam IdxT SQ code type. Only uint8_t (8-bit, codes in [0,255]) for now. + * + * No member depends on the raw data type T (float, half). T appears only + * in the free-function signatures (build, search, extend) where input data + * is consumed, following the IVF-PQ pattern. + */ +template +struct index : cuvs::neighbors::index { + static_assert(std::is_same_v, "IVF-SQ code type IdxT must be uint8_t for now."); + + using index_params_type = ivf_sq::index_params; + using search_params_type = ivf_sq::search_params; + using code_type = IdxT; + + static constexpr uint32_t sq_bits = sizeof(IdxT) * 8; + + public: + index(const index&) = delete; + index(index&&) = default; + index& operator=(const index&) = delete; + index& operator=(index&&) = default; + ~index() = default; + + index(raft::resources const& res); + index(raft::resources const& res, const index_params& params, uint32_t dim); + index(raft::resources const& res, + cuvs::distance::DistanceType metric, + uint32_t n_lists, + uint32_t dim, + bool adaptive_centers, + bool conservative_memory_allocation); + + cuvs::distance::DistanceType metric() const noexcept; + bool adaptive_centers() const noexcept; + int64_t size() const noexcept; + uint32_t dim() const noexcept; + uint32_t n_lists() const noexcept; + bool conservative_memory_allocation() const noexcept; + + raft::device_vector_view list_sizes() noexcept; + raft::device_vector_view list_sizes() const noexcept; + + raft::device_matrix_view centers() noexcept; + raft::device_matrix_view centers() const noexcept; + + std::optional> center_norms() noexcept; + std::optional> center_norms() const noexcept; + void allocate_center_norms(raft::resources const& res); + + raft::device_vector_view sq_vmin() noexcept; + raft::device_vector_view sq_vmin() const noexcept; + + raft::device_vector_view sq_delta() noexcept; + raft::device_vector_view sq_delta() const noexcept; + + raft::host_vector_view accum_sorted_sizes() noexcept; + [[nodiscard]] raft::host_vector_view accum_sorted_sizes() const noexcept; + + raft::device_vector_view data_ptrs() noexcept; + raft::device_vector_view data_ptrs() const noexcept; + + raft::device_vector_view inds_ptrs() noexcept; + raft::device_vector_view inds_ptrs() const noexcept; + + std::vector>>& lists() noexcept; + const std::vector>>& lists() const noexcept; + + void check_consistency(); + + private: + cuvs::distance::DistanceType metric_; + bool adaptive_centers_; + bool conservative_memory_allocation_; + + std::vector>> lists_; + raft::device_vector list_sizes_; + raft::device_matrix centers_; + std::optional> center_norms_; + raft::device_vector sq_vmin_; + raft::device_vector sq_delta_; + + raft::device_vector data_ptrs_; + raft::device_vector inds_ptrs_; + raft::host_vector accum_sorted_sizes_; +}; + +/** + * @} + */ + +/** + * @defgroup ivf_sq_cpp_index_build IVF-SQ index build + * @{ + */ + +auto build(raft::resources const& handle, + const cuvs::neighbors::ivf_sq::index_params& index_params, + raft::device_matrix_view dataset) + -> cuvs::neighbors::ivf_sq::index; + +void build(raft::resources const& handle, + const cuvs::neighbors::ivf_sq::index_params& index_params, + raft::device_matrix_view dataset, + cuvs::neighbors::ivf_sq::index& idx); + +auto build(raft::resources const& handle, + const cuvs::neighbors::ivf_sq::index_params& index_params, + raft::device_matrix_view dataset) + -> cuvs::neighbors::ivf_sq::index; + +void build(raft::resources const& handle, + const cuvs::neighbors::ivf_sq::index_params& index_params, + raft::device_matrix_view dataset, + cuvs::neighbors::ivf_sq::index& idx); + +auto build(raft::resources const& handle, + const cuvs::neighbors::ivf_sq::index_params& index_params, + raft::host_matrix_view dataset) + -> cuvs::neighbors::ivf_sq::index; + +void build(raft::resources const& handle, + const cuvs::neighbors::ivf_sq::index_params& index_params, + raft::host_matrix_view dataset, + cuvs::neighbors::ivf_sq::index& idx); + +auto build(raft::resources const& handle, + const cuvs::neighbors::ivf_sq::index_params& index_params, + raft::host_matrix_view dataset) + -> cuvs::neighbors::ivf_sq::index; + +void build(raft::resources const& handle, + const cuvs::neighbors::ivf_sq::index_params& index_params, + raft::host_matrix_view dataset, + cuvs::neighbors::ivf_sq::index& idx); + +/** + * @} + */ + +/** + * @defgroup ivf_sq_cpp_index_extend IVF-SQ index extend + * @{ + */ + +auto extend(raft::resources const& handle, + raft::device_matrix_view new_vectors, + std::optional> new_indices, + const cuvs::neighbors::ivf_sq::index& orig_index) + -> cuvs::neighbors::ivf_sq::index; + +void extend(raft::resources const& handle, + raft::device_matrix_view new_vectors, + std::optional> new_indices, + cuvs::neighbors::ivf_sq::index* idx); + +auto extend(raft::resources const& handle, + raft::device_matrix_view new_vectors, + std::optional> new_indices, + const cuvs::neighbors::ivf_sq::index& orig_index) + -> cuvs::neighbors::ivf_sq::index; + +void extend(raft::resources const& handle, + raft::device_matrix_view new_vectors, + std::optional> new_indices, + cuvs::neighbors::ivf_sq::index* idx); + +auto extend(raft::resources const& handle, + raft::host_matrix_view new_vectors, + std::optional> new_indices, + const cuvs::neighbors::ivf_sq::index& orig_index) + -> cuvs::neighbors::ivf_sq::index; + +void extend(raft::resources const& handle, + raft::host_matrix_view new_vectors, + std::optional> new_indices, + cuvs::neighbors::ivf_sq::index* idx); + +auto extend(raft::resources const& handle, + raft::host_matrix_view new_vectors, + std::optional> new_indices, + const cuvs::neighbors::ivf_sq::index& orig_index) + -> cuvs::neighbors::ivf_sq::index; + +void extend(raft::resources const& handle, + raft::host_matrix_view new_vectors, + std::optional> new_indices, + cuvs::neighbors::ivf_sq::index* idx); + +/** + * @} + */ + +/** + * @defgroup ivf_sq_cpp_index_search IVF-SQ index search + * @{ + */ + +void search(raft::resources const& handle, + const cuvs::neighbors::ivf_sq::search_params& params, + const cuvs::neighbors::ivf_sq::index& index, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances, + const cuvs::neighbors::filtering::base_filter& sample_filter = + cuvs::neighbors::filtering::none_sample_filter{}); + +void search(raft::resources const& handle, + const cuvs::neighbors::ivf_sq::search_params& params, + const cuvs::neighbors::ivf_sq::index& index, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances, + const cuvs::neighbors::filtering::base_filter& sample_filter = + cuvs::neighbors::filtering::none_sample_filter{}); + +/** + * @} + */ + +/** + * @defgroup ivf_sq_cpp_index_serialize IVF-SQ index serialize + * @{ + */ + +void serialize(raft::resources const& handle, + const std::string& filename, + const cuvs::neighbors::ivf_sq::index& index); + +void deserialize(raft::resources const& handle, + const std::string& filename, + cuvs::neighbors::ivf_sq::index* index); + +/** + * @} + */ + +} // namespace cuvs::neighbors::ivf_sq diff --git a/cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh b/cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh new file mode 100644 index 0000000000..6c46a20e65 --- /dev/null +++ b/cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh @@ -0,0 +1,664 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "../../core/nvtx.hpp" +#include "../ivf_common.cuh" +#include "../ivf_list.cuh" + +#include +#include +#include + +#include "../../cluster/kmeans_balanced.cuh" +#include "../detail/ann_utils.cuh" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include + +#include +#include + +namespace cuvs::neighbors::ivf_sq { +using namespace cuvs::spatial::knn::detail; // NOLINT + +namespace detail { + +struct ColMinMaxPair { + float min_val; + float max_val; +}; + +struct ColMinMaxOp { + __device__ __forceinline__ ColMinMaxPair operator()(const ColMinMaxPair& a, + const ColMinMaxPair& b) const + { + return {fminf(a.min_val, b.min_val), fmaxf(a.max_val, b.max_val)}; + } +}; + +/** + * Fused per-column min+max in a single pass (2x less DRAM traffic than two + * separate reductions). One thread block per column; threads stride over + * rows and feed CUB BlockReduce with a combined min/max pair. + * + * Row-loop is manually 4x-unrolled so the compiler can overlap four + * independent __ldg requests in the memory pipeline. + */ +template +__launch_bounds__(BlockSize) RAFT_KERNEL fused_column_minmax_kernel(const float* __restrict__ data, + float* __restrict__ col_min, + float* __restrict__ col_max, + int64_t n_rows, + uint32_t dim) +{ + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + + const uint32_t col = blockIdx.x; + if (col >= dim) return; + + ColMinMaxPair agg = {std::numeric_limits::max(), std::numeric_limits::lowest()}; + + const int64_t stride = static_cast(BlockSize); + int64_t row = static_cast(threadIdx.x); + + for (; row + 3 * stride < n_rows; row += 4 * stride) { + float v0 = __ldg(&data[row * dim + col]); + float v1 = __ldg(&data[(row + stride) * dim + col]); + float v2 = __ldg(&data[(row + 2 * stride) * dim + col]); + float v3 = __ldg(&data[(row + 3 * stride) * dim + col]); + agg.min_val = fminf(agg.min_val, fminf(fminf(v0, v1), fminf(v2, v3))); + agg.max_val = fmaxf(agg.max_val, fmaxf(fmaxf(v0, v1), fmaxf(v2, v3))); + } + for (; row < n_rows; row += stride) { + float val = __ldg(&data[row * dim + col]); + agg.min_val = fminf(agg.min_val, val); + agg.max_val = fmaxf(agg.max_val, val); + } + + agg = BlockReduce(temp_storage).Reduce(agg, ColMinMaxOp()); + + if (threadIdx.x == 0) { + col_min[col] = agg.min_val; + col_max[col] = agg.max_val; + } +} + +template +auto clone(const raft::resources& res, const index& source) -> index +{ + auto stream = raft::resource::get_cuda_stream(res); + + index target(res, + source.metric(), + source.n_lists(), + source.dim(), + source.adaptive_centers(), + source.conservative_memory_allocation()); + + raft::copy(target.list_sizes().data_handle(), + source.list_sizes().data_handle(), + source.list_sizes().size(), + stream); + raft::copy(target.centers().data_handle(), + source.centers().data_handle(), + source.centers().size(), + stream); + if (source.center_norms().has_value()) { + target.allocate_center_norms(res); + raft::copy(target.center_norms()->data_handle(), + source.center_norms()->data_handle(), + source.center_norms()->size(), + stream); + } + raft::copy(target.sq_vmin().data_handle(), + source.sq_vmin().data_handle(), + source.sq_vmin().size(), + stream); + raft::copy(target.sq_delta().data_handle(), + source.sq_delta().data_handle(), + source.sq_delta().size(), + stream); + target.lists() = source.lists(); + ivf::detail::recompute_internal_state(res, target); + return target; +} + +/** + * Kernel to encode float residuals to uint8_t SQ codes and write them interleaved. + * + * Uses warp-per-vector parallelism: each warp cooperatively encodes one vector + * so that reads from residuals/vmin/delta are coalesced across the 32 lanes. + * Lane 0 handles the atomic position assignment and the index write. + */ +template +__launch_bounds__(BlockSize) RAFT_KERNEL encode_and_fill_kernel(const uint32_t* labels, + const float* residuals, + const int64_t* source_ixs, + uint8_t** list_data_ptrs, + int64_t** list_index_ptrs, + uint32_t* list_sizes_ptr, + const float* vmin, + const float* delta, + int64_t n_rows, + uint32_t dim, + int64_t batch_offset) +{ + constexpr uint32_t kWarpSize = kIndexGroupSize; + constexpr uint32_t kWarpsPerBlock = BlockSize / kWarpSize; + + const uint32_t lane_id = threadIdx.x % kWarpSize; + const int64_t row_id = + int64_t(threadIdx.x / kWarpSize) + int64_t(blockIdx.x) * int64_t(kWarpsPerBlock); + if (row_id >= n_rows) return; + + uint32_t list_id = 0; + uint32_t inlist_id = 0; + if (lane_id == 0) { + auto source_ix = source_ixs == nullptr ? row_id + batch_offset : source_ixs[row_id]; + list_id = labels[row_id]; + inlist_id = atomicAdd(list_sizes_ptr + list_id, 1); + list_index_ptrs[list_id][inlist_id] = source_ix; + } + list_id = __shfl_sync(0xFFFFFFFF, list_id, 0); + inlist_id = __shfl_sync(0xFFFFFFFF, inlist_id, 0); + + using interleaved_group = raft::Pow2; + auto group_offset = interleaved_group::roundDown(inlist_id); + auto ingroup_id = interleaved_group::mod(inlist_id); + + constexpr uint32_t veclen = list_spec::kVecLen; + uint32_t padded_dim = ((dim + veclen - 1) / veclen) * veclen; + auto* list_dat = list_data_ptrs[list_id] + static_cast(group_offset) * padded_dim; + const float* src = residuals + row_id * dim; + + for (uint32_t d = lane_id; d < padded_dim; d += kWarpSize) { + uint8_t out; + if (d < dim) { + float val = src[d]; + float dv = delta[d]; + float v = vmin[d]; + float code = (dv > 0.0f) ? roundf((val - v) / dv) : 0.0f; + out = static_cast(fminf(fmaxf(code, 0.0f), 255.0f)); + } else { + out = 0; + } + uint32_t l = (d / veclen) * veclen; + uint32_t j = d % veclen; + list_dat[l * kIndexGroupSize + ingroup_id * veclen + j] = out; + } +} + +/** + * Compute residuals: residual[i] = cast(x_i) - centers[labels[i]] + */ +template +RAFT_KERNEL compute_residuals_kernel(const T* dataset, + const float* centers, + const uint32_t* labels, + float* residuals, + int64_t n_rows, + uint32_t dim) +{ + int64_t i = int64_t(blockIdx.x) * blockDim.x + threadIdx.x; + uint32_t j = blockIdx.y * blockDim.y + threadIdx.y; + if (i >= n_rows || j >= dim) return; + + float val = utils::mapping{}(dataset[i * dim + j]); + uint32_t c = labels[i]; + residuals[i * dim + j] = val - centers[c * dim + j]; +} + +template +void extend(raft::resources const& handle, + index* index, + const T* new_vectors, + const int64_t* new_indices, + int64_t n_rows) +{ + using LabelT = uint32_t; + RAFT_EXPECTS(index != nullptr, "index cannot be empty."); + if (n_rows == 0) return; + + auto stream = raft::resource::get_cuda_stream(handle); + auto n_lists = index->n_lists(); + auto dim = index->dim(); + list_spec list_device_spec{index->dim(), + index->conservative_memory_allocation()}; + cuvs::common::nvtx::range fun_scope( + "ivf_sq::extend(%zu, %u)", size_t(n_rows), dim); + + RAFT_EXPECTS(new_indices != nullptr || index->size() == 0, + "You must pass data indices when the index is non-empty."); + + auto new_labels = + raft::make_device_mdarray(handle, + raft::resource::get_large_workspace_resource(handle), + raft::make_extents(n_rows)); + cuvs::cluster::kmeans::balanced_params kmeans_params; + kmeans_params.metric = index->metric(); + auto orig_centroids_view = raft::make_device_matrix_view( + index->centers().data_handle(), n_lists, dim); + + constexpr size_t kReasonableMaxBatchSize = 65536; + size_t max_batch_size = std::min(n_rows, kReasonableMaxBatchSize); + + auto copy_stream = raft::resource::get_cuda_stream(handle); + bool enable_prefetch = false; + if (handle.has_resource_factory(raft::resource::resource_type::CUDA_STREAM_POOL)) { + if (raft::resource::get_stream_pool_size(handle) >= 1) { + enable_prefetch = true; + copy_stream = raft::resource::get_stream_from_stream_pool(handle); + } + } + + utils::batch_load_iterator vec_batches(new_vectors, + n_rows, + index->dim(), + max_batch_size, + copy_stream, + raft::resource::get_workspace_resource(handle), + enable_prefetch); + vec_batches.prefetch_next_batch(); + + for (const auto& batch : vec_batches) { + auto batch_data_view = + raft::make_device_matrix_view(batch.data(), batch.size(), index->dim()); + auto batch_labels_view = raft::make_device_vector_view( + new_labels.data_handle() + batch.offset(), batch.size()); + cuvs::cluster::kmeans::predict( + handle, kmeans_params, batch_data_view, orig_centroids_view, batch_labels_view); + vec_batches.prefetch_next_batch(); + raft::resource::sync_stream(handle); + } + + auto* list_sizes_ptr = index->list_sizes().data_handle(); + auto old_list_sizes_dev = raft::make_device_mdarray( + handle, raft::resource::get_workspace_resource(handle), raft::make_extents(n_lists)); + raft::copy(old_list_sizes_dev.data_handle(), list_sizes_ptr, n_lists, stream); + + if (index->adaptive_centers()) { + auto centroids_view = raft::make_device_matrix_view( + index->centers().data_handle(), index->centers().extent(0), index->centers().extent(1)); + auto list_sizes_view = + raft::make_device_vector_view, int64_t>( + list_sizes_ptr, n_lists); + for (const auto& batch : vec_batches) { + auto batch_data_view = + raft::make_device_matrix_view(batch.data(), batch.size(), index->dim()); + auto batch_labels_view = raft::make_device_vector_view( + new_labels.data_handle() + batch.offset(), batch.size()); + cuvs::cluster::kmeans_balanced::helpers::calc_centers_and_sizes(handle, + batch_data_view, + batch_labels_view, + centroids_view, + list_sizes_view, + false, + utils::mapping{}); + } + } else { + raft::stats::histogram(raft::stats::HistTypeAuto, + reinterpret_cast(list_sizes_ptr), + int64_t(n_lists), + new_labels.data_handle(), + n_rows, + 1, + stream); + raft::linalg::add( + list_sizes_ptr, list_sizes_ptr, old_list_sizes_dev.data_handle(), n_lists, stream); + } + + std::vector new_list_sizes(n_lists); + std::vector old_list_sizes(n_lists); + { + raft::copy(old_list_sizes.data(), old_list_sizes_dev.data_handle(), n_lists, stream); + raft::copy(new_list_sizes.data(), list_sizes_ptr, n_lists, stream); + raft::resource::sync_stream(handle); + auto& lists = index->lists(); + for (uint32_t label = 0; label < n_lists; label++) { + ivf::resize_list(handle, + lists[label], + list_device_spec, + new_list_sizes[label], + raft::Pow2::roundUp(old_list_sizes[label])); + } + } + ivf::detail::recompute_internal_state(handle, *index); + raft::copy(list_sizes_ptr, old_list_sizes_dev.data_handle(), n_lists, stream); + + utils::batch_load_iterator vec_indices( + new_indices, n_rows, 1, max_batch_size, stream, raft::resource::get_workspace_resource(handle)); + vec_batches.reset(); + vec_batches.prefetch_next_batch(); + utils::batch_load_iterator idx_batch = vec_indices.begin(); + + auto residuals_buf = raft::make_device_vector(handle, max_batch_size * dim); + + size_t next_report_offset = 0; + size_t d_report_offset = n_rows * 5 / 100; + + for (const auto& batch : vec_batches) { + int64_t bs = batch.size(); + + { + dim3 threads(32, 8); + dim3 blocks(raft::ceildiv(bs, threads.x), raft::ceildiv(dim, threads.y)); + compute_residuals_kernel + <<>>(batch.data(), + index->centers().data_handle(), + new_labels.data_handle() + batch.offset(), + residuals_buf.data_handle(), + bs, + dim); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + } + + { + constexpr int kEncodeBlockSize = 256; + constexpr int kEncodeWarpsPerBlk = kEncodeBlockSize / kIndexGroupSize; + const dim3 block_dim(kEncodeBlockSize); + const dim3 grid_dim(raft::ceildiv(bs, int64_t(kEncodeWarpsPerBlk))); + encode_and_fill_kernel + <<>>(new_labels.data_handle() + batch.offset(), + residuals_buf.data_handle(), + idx_batch->data(), + index->data_ptrs().data_handle(), + index->inds_ptrs().data_handle(), + list_sizes_ptr, + index->sq_vmin().data_handle(), + index->sq_delta().data_handle(), + bs, + dim, + batch.offset()); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + } + + vec_batches.prefetch_next_batch(); + raft::resource::sync_stream(handle); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + + if (batch.offset() > next_report_offset) { + float progress = batch.offset() * 100.0f / n_rows; + RAFT_LOG_DEBUG("ivf_sq::extend added vectors %zu, %6.1f%% complete", + static_cast(batch.offset()), + progress); + next_report_offset += d_report_offset; + } + ++idx_batch; + } + + auto compute_center_norms = [&]() { + if (index->metric() == cuvs::distance::DistanceType::CosineExpanded) { + raft::linalg::rowNorm(index->center_norms()->data_handle(), + index->centers().data_handle(), + dim, + n_lists, + stream, + raft::sqrt_op{}); + } else { + raft::linalg::rowNorm( + index->center_norms()->data_handle(), index->centers().data_handle(), dim, n_lists, stream); + } + }; + + if (!index->center_norms().has_value()) { + index->allocate_center_norms(handle); + if (index->center_norms().has_value()) { compute_center_norms(); } + } else if (index->adaptive_centers()) { + compute_center_norms(); + } +} + +template +auto extend(raft::resources const& handle, + const index& orig_index, + const T* new_vectors, + const int64_t* new_indices, + int64_t n_rows) -> index +{ + auto ext_index = clone(handle, orig_index); + detail::extend(handle, &ext_index, new_vectors, new_indices, n_rows); + return ext_index; +} + +template +inline auto build(raft::resources const& handle, + const index_params& params, + const T* dataset, + int64_t n_rows, + uint32_t dim) -> index +{ + auto stream = raft::resource::get_cuda_stream(handle); + cuvs::common::nvtx::range fun_scope( + "ivf_sq::build(%zu, %u)", size_t(n_rows), dim); + static_assert(std::is_same_v || std::is_same_v, "unsupported data type"); + RAFT_EXPECTS(n_rows > 0 && dim > 0, "empty dataset"); + RAFT_EXPECTS(n_rows >= params.n_lists, "number of rows can't be less than n_lists"); + RAFT_EXPECTS(params.metric != cuvs::distance::DistanceType::CosineExpanded || dim > 1, + "Cosine metric requires more than one dim"); + + index idx(handle, params, dim); + utils::memzero(idx.accum_sorted_sizes().data_handle(), idx.accum_sorted_sizes().size(), stream); + utils::memzero(idx.list_sizes().data_handle(), idx.list_sizes().size(), stream); + utils::memzero(idx.data_ptrs().data_handle(), idx.data_ptrs().size(), stream); + utils::memzero(idx.inds_ptrs().data_handle(), idx.inds_ptrs().size(), stream); + + // Train k-means centroids and SQ parameters on the same training subset. + // This mirrors IVF-PQ, which also trains its codebook on a subset of the data. + { + auto trainset_ratio = std::max( + 1, n_rows / std::max(params.kmeans_trainset_fraction * n_rows, idx.n_lists())); + auto n_rows_train = n_rows / trainset_ratio; + rmm::device_uvector trainset( + n_rows_train * idx.dim(), stream, raft::resource::get_large_workspace_resource(handle)); + RAFT_CUDA_TRY(cudaMemcpy2DAsync(trainset.data(), + sizeof(T) * idx.dim(), + dataset, + sizeof(T) * idx.dim() * trainset_ratio, + sizeof(T) * idx.dim(), + n_rows_train, + cudaMemcpyDefault, + stream)); + auto trainset_const_view = + raft::make_device_matrix_view(trainset.data(), n_rows_train, idx.dim()); + auto centers_view = raft::make_device_matrix_view( + idx.centers().data_handle(), idx.n_lists(), idx.dim()); + cuvs::cluster::kmeans::balanced_params kmeans_params; + kmeans_params.n_iters = params.kmeans_n_iters; + kmeans_params.metric = idx.metric(); + cuvs::cluster::kmeans::fit(handle, kmeans_params, trainset_const_view, centers_view); + raft::resource::sync_stream(handle); + + // Train SQ: predict labels for the training subset, compute its residuals, + // and derive per-dimension vmin/delta from them. + auto train_labels = raft::make_device_vector(handle, n_rows_train); + { + cuvs::cluster::kmeans::balanced_params pred_params; + pred_params.metric = idx.metric(); + auto centers_const_view = raft::make_device_matrix_view( + idx.centers().data_handle(), idx.n_lists(), dim); + cuvs::cluster::kmeans::predict( + handle, pred_params, trainset_const_view, centers_const_view, train_labels.view()); + raft::resource::sync_stream(handle); + } + + rmm::device_uvector residuals( + n_rows_train * dim, stream, raft::resource::get_large_workspace_resource(handle)); + { + dim3 threads(32, 8); + dim3 blocks(raft::ceildiv(n_rows_train, threads.x), + raft::ceildiv(dim, threads.y)); + compute_residuals_kernel<<>>(trainset.data(), + idx.centers().data_handle(), + train_labels.data_handle(), + residuals.data(), + n_rows_train, + dim); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + } + + { + auto vmax_buf = raft::make_device_vector(handle, dim); + auto* vmin_ptr = idx.sq_vmin().data_handle(); + auto* vmax_ptr = vmax_buf.data_handle(); + + constexpr int kMinMaxBlockSize = 256; + fused_column_minmax_kernel<<>>( + residuals.data(), vmin_ptr, vmax_ptr, n_rows_train, dim); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + + // Expand the observed range by a small margin to reduce clipping on unseen data, + // since the SQ parameters are trained on a subset rather than the full dataset. + constexpr float kMargin = 0.05f; + auto* delta_ptr = idx.sq_delta().data_handle(); + raft::linalg::map_offset( + handle, idx.sq_vmin(), [vmin_ptr, vmax_ptr, delta_ptr, kMargin] __device__(uint32_t j) { + float range = vmax_ptr[j] - vmin_ptr[j]; + float margin = range * kMargin; + delta_ptr[j] = (range > 0.0f) ? (range + 2.0f * margin) / 255.0f : 1.0f; + return vmin_ptr[j] - margin; + }); + } + } + + if (params.add_data_on_build) { detail::extend(handle, &idx, dataset, nullptr, n_rows); } + + return idx; +} + +template +auto build(raft::resources const& handle, + const index_params& params, + raft::device_matrix_view dataset) -> index +{ + int64_t n_rows = dataset.extent(0); + uint32_t dim = dataset.extent(1); + return build(handle, params, dataset.data_handle(), n_rows, dim); +} + +template +auto build(raft::resources const& handle, + const index_params& params, + raft::host_matrix_view dataset) -> index +{ + int64_t n_rows = dataset.extent(0); + uint32_t dim = dataset.extent(1); + return build(handle, params, dataset.data_handle(), n_rows, dim); +} + +template +void build(raft::resources const& handle, + const index_params& params, + raft::device_matrix_view dataset, + index& idx) +{ + idx = build(handle, params, dataset); +} + +template +void build(raft::resources const& handle, + const index_params& params, + raft::host_matrix_view dataset, + index& idx) +{ + idx = build(handle, params, dataset); +} + +template +auto extend(raft::resources const& handle, + raft::device_matrix_view new_vectors, + std::optional> new_indices, + const index& orig_index) -> index +{ + RAFT_EXPECTS(new_vectors.extent(1) == orig_index.dim(), + "new_vectors should have the same dimension as the index"); + if (new_indices.has_value()) { + RAFT_EXPECTS(new_indices.value().extent(0) == new_vectors.extent(0), + "new_vectors and new_indices have different number of rows"); + } + int64_t n_rows = new_vectors.extent(0); + return extend(handle, + orig_index, + new_vectors.data_handle(), + new_indices.has_value() ? new_indices.value().data_handle() : nullptr, + n_rows); +} + +template +auto extend(raft::resources const& handle, + raft::host_matrix_view new_vectors, + std::optional> new_indices, + const index& orig_index) -> index +{ + RAFT_EXPECTS(new_vectors.extent(1) == orig_index.dim(), + "new_vectors should have the same dimension as the index"); + if (new_indices.has_value()) { + RAFT_EXPECTS(new_indices.value().extent(0) == new_vectors.extent(0), + "new_vectors and new_indices have different number of rows"); + } + int64_t n_rows = new_vectors.extent(0); + return extend(handle, + orig_index, + new_vectors.data_handle(), + new_indices.has_value() ? new_indices.value().data_handle() : nullptr, + n_rows); +} + +template +void extend(raft::resources const& handle, + raft::device_matrix_view new_vectors, + std::optional> new_indices, + index* idx) +{ + RAFT_EXPECTS(new_vectors.extent(1) == idx->dim(), + "new_vectors should have the same dimension as the index"); + if (new_indices.has_value()) { + RAFT_EXPECTS(new_indices.value().extent(0) == new_vectors.extent(0), + "new_vectors and new_indices have different number of rows"); + } + detail::extend(handle, + idx, + new_vectors.data_handle(), + new_indices.has_value() ? new_indices.value().data_handle() : nullptr, + new_vectors.extent(0)); +} + +template +void extend(raft::resources const& handle, + raft::host_matrix_view new_vectors, + std::optional> new_indices, + index* idx) +{ + RAFT_EXPECTS(new_vectors.extent(1) == idx->dim(), + "new_vectors should have the same dimension as the index"); + if (new_indices.has_value()) { + RAFT_EXPECTS(new_indices.value().extent(0) == new_vectors.extent(0), + "new_vectors and new_indices have different number of rows"); + } + detail::extend(handle, + idx, + new_vectors.data_handle(), + new_indices.has_value() ? new_indices.value().data_handle() : nullptr, + new_vectors.extent(0)); +} + +} // namespace detail +} // namespace cuvs::neighbors::ivf_sq diff --git a/cpp/src/neighbors/ivf_sq/ivf_sq_build_extend_float_uint8_t_int64_t.cu b/cpp/src/neighbors/ivf_sq/ivf_sq_build_extend_float_uint8_t_int64_t.cu new file mode 100644 index 0000000000..a97aebb11c --- /dev/null +++ b/cpp/src/neighbors/ivf_sq/ivf_sq_build_extend_float_uint8_t_int64_t.cu @@ -0,0 +1,89 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include + +#include "ivf_sq_build.cuh" + +namespace cuvs::neighbors::ivf_sq { + +#define CUVS_INST_IVF_SQ_BUILD_EXTEND(T, IdxT) \ + auto build(raft::resources const& handle, \ + const cuvs::neighbors::ivf_sq::index_params& params, \ + raft::device_matrix_view dataset) \ + -> cuvs::neighbors::ivf_sq::index \ + { \ + return cuvs::neighbors::ivf_sq::index( \ + std::move(cuvs::neighbors::ivf_sq::detail::build(handle, params, dataset))); \ + } \ + \ + void build(raft::resources const& handle, \ + const cuvs::neighbors::ivf_sq::index_params& params, \ + raft::device_matrix_view dataset, \ + cuvs::neighbors::ivf_sq::index& idx) \ + { \ + cuvs::neighbors::ivf_sq::detail::build(handle, params, dataset, idx); \ + } \ + \ + auto build(raft::resources const& handle, \ + const cuvs::neighbors::ivf_sq::index_params& params, \ + raft::host_matrix_view dataset) \ + -> cuvs::neighbors::ivf_sq::index \ + { \ + return cuvs::neighbors::ivf_sq::index( \ + std::move(cuvs::neighbors::ivf_sq::detail::build(handle, params, dataset))); \ + } \ + \ + void build(raft::resources const& handle, \ + const cuvs::neighbors::ivf_sq::index_params& params, \ + raft::host_matrix_view dataset, \ + cuvs::neighbors::ivf_sq::index& idx) \ + { \ + cuvs::neighbors::ivf_sq::detail::build(handle, params, dataset, idx); \ + } \ + \ + auto extend(raft::resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + const cuvs::neighbors::ivf_sq::index& orig_index) \ + -> cuvs::neighbors::ivf_sq::index \ + { \ + return cuvs::neighbors::ivf_sq::index( \ + std::move(cuvs::neighbors::ivf_sq::detail::extend( \ + handle, new_vectors, new_indices, orig_index))); \ + } \ + \ + void extend(raft::resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + cuvs::neighbors::ivf_sq::index* idx) \ + { \ + cuvs::neighbors::ivf_sq::detail::extend(handle, new_vectors, new_indices, idx); \ + } \ + \ + auto extend(raft::resources const& handle, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices, \ + const cuvs::neighbors::ivf_sq::index& orig_index) \ + -> cuvs::neighbors::ivf_sq::index \ + { \ + return cuvs::neighbors::ivf_sq::index( \ + std::move(cuvs::neighbors::ivf_sq::detail::extend( \ + handle, new_vectors, new_indices, orig_index))); \ + } \ + \ + void extend(raft::resources const& handle, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices, \ + cuvs::neighbors::ivf_sq::index* idx) \ + { \ + cuvs::neighbors::ivf_sq::detail::extend(handle, new_vectors, new_indices, idx); \ + } + +CUVS_INST_IVF_SQ_BUILD_EXTEND(float, uint8_t); + +#undef CUVS_INST_IVF_SQ_BUILD_EXTEND + +} // namespace cuvs::neighbors::ivf_sq diff --git a/cpp/src/neighbors/ivf_sq/ivf_sq_build_extend_half_uint8_t_int64_t.cu b/cpp/src/neighbors/ivf_sq/ivf_sq_build_extend_half_uint8_t_int64_t.cu new file mode 100644 index 0000000000..9148e5c328 --- /dev/null +++ b/cpp/src/neighbors/ivf_sq/ivf_sq_build_extend_half_uint8_t_int64_t.cu @@ -0,0 +1,89 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include + +#include "ivf_sq_build.cuh" + +namespace cuvs::neighbors::ivf_sq { + +#define CUVS_INST_IVF_SQ_BUILD_EXTEND(T, IdxT) \ + auto build(raft::resources const& handle, \ + const cuvs::neighbors::ivf_sq::index_params& params, \ + raft::device_matrix_view dataset) \ + -> cuvs::neighbors::ivf_sq::index \ + { \ + return cuvs::neighbors::ivf_sq::index( \ + std::move(cuvs::neighbors::ivf_sq::detail::build(handle, params, dataset))); \ + } \ + \ + void build(raft::resources const& handle, \ + const cuvs::neighbors::ivf_sq::index_params& params, \ + raft::device_matrix_view dataset, \ + cuvs::neighbors::ivf_sq::index& idx) \ + { \ + cuvs::neighbors::ivf_sq::detail::build(handle, params, dataset, idx); \ + } \ + \ + auto build(raft::resources const& handle, \ + const cuvs::neighbors::ivf_sq::index_params& params, \ + raft::host_matrix_view dataset) \ + -> cuvs::neighbors::ivf_sq::index \ + { \ + return cuvs::neighbors::ivf_sq::index( \ + std::move(cuvs::neighbors::ivf_sq::detail::build(handle, params, dataset))); \ + } \ + \ + void build(raft::resources const& handle, \ + const cuvs::neighbors::ivf_sq::index_params& params, \ + raft::host_matrix_view dataset, \ + cuvs::neighbors::ivf_sq::index& idx) \ + { \ + cuvs::neighbors::ivf_sq::detail::build(handle, params, dataset, idx); \ + } \ + \ + auto extend(raft::resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + const cuvs::neighbors::ivf_sq::index& orig_index) \ + -> cuvs::neighbors::ivf_sq::index \ + { \ + return cuvs::neighbors::ivf_sq::index( \ + std::move(cuvs::neighbors::ivf_sq::detail::extend( \ + handle, new_vectors, new_indices, orig_index))); \ + } \ + \ + void extend(raft::resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + cuvs::neighbors::ivf_sq::index* idx) \ + { \ + cuvs::neighbors::ivf_sq::detail::extend(handle, new_vectors, new_indices, idx); \ + } \ + \ + auto extend(raft::resources const& handle, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices, \ + const cuvs::neighbors::ivf_sq::index& orig_index) \ + -> cuvs::neighbors::ivf_sq::index \ + { \ + return cuvs::neighbors::ivf_sq::index( \ + std::move(cuvs::neighbors::ivf_sq::detail::extend( \ + handle, new_vectors, new_indices, orig_index))); \ + } \ + \ + void extend(raft::resources const& handle, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices, \ + cuvs::neighbors::ivf_sq::index* idx) \ + { \ + cuvs::neighbors::ivf_sq::detail::extend(handle, new_vectors, new_indices, idx); \ + } + +CUVS_INST_IVF_SQ_BUILD_EXTEND(half, uint8_t); + +#undef CUVS_INST_IVF_SQ_BUILD_EXTEND + +} // namespace cuvs::neighbors::ivf_sq diff --git a/cpp/src/neighbors/ivf_sq/ivf_sq_search.cuh b/cpp/src/neighbors/ivf_sq/ivf_sq_search.cuh new file mode 100644 index 0000000000..39c653b048 --- /dev/null +++ b/cpp/src/neighbors/ivf_sq/ivf_sq_search.cuh @@ -0,0 +1,549 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "../../core/nvtx.hpp" +#include "../detail/ann_utils.cuh" +#include "../ivf_common.cuh" +#include "../sample_filter.cuh" +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include + +namespace cuvs::neighbors::ivf_sq::detail { + +using namespace cuvs::spatial::knn::detail; // NOLINT + +enum class SqScanMetric { kL2, kIP, kCosine }; + +/** + * Per-probe scan kernel for IVF-SQ search. + * + * Grid: (n_queries, n_probes). Each block handles one (query, probe) pair. + * Within a block, each warp processes one interleaved group of kIndexGroupSize + * (=32) vectors at a time, with each lane responsible for one vector. + * Dimension blocks of veclen=16 bytes are loaded as coalesced uint4 reads + * across the warp (32 lanes x 16 bytes = 512 bytes = 4 cache lines), giving + * full memory-bandwidth utilisation. + * + * Per-dimension constants that are invariant across rows are precomputed into + * shared memory so the hot loop only reads from smem + one uint4 per dim-block: + * + * L2 / L2Sqrt: + * s_query_term[d] = query[d] - centroid[d] - sq_vmin[d] + * dist += (s_query_term[d] - code * s_sq_scale[d])^2 + * + * InnerProduct / Cosine: + * s_query_term[d] = query[d] + * s_recon_base[d] = centroid[d] + sq_vmin[d] + * v_d = s_recon_base[d] + code * s_sq_scale[d] + * dist += s_query_term[d] * v_d + * + * Shared-memory layout adapts to the metric to avoid waste: + * L2 / L2Sqrt : [s_query_term | s_sq_scale] (2 * dim floats) + * InnerProduct/Cosine: [s_query_term | s_recon_base | s_sq_scale] (3 * dim floats) + */ +template +__launch_bounds__(BlockDim) RAFT_KERNEL ivf_sq_scan_kernel(const uint8_t* const* data_ptrs, + const uint32_t* list_sizes, + const uint32_t* coarse_indices, + const float* queries_float, + const float* centers, + const float* sq_vmin, + const float* sq_delta, + const float* query_norms, + uint32_t n_probes, + uint32_t dim, + uint32_t max_samples, + const uint32_t* chunk_indices, + float* out_distances, + uint32_t* out_indices, + IvfSampleFilterT sample_filter) +{ + static_assert(kIndexGroupSize == raft::WarpSize, + "Warp-coalesced scan requires kIndexGroupSize == WarpSize"); + + extern __shared__ float smem[]; + + constexpr bool kIsL2 = (Metric == SqScanMetric::kL2); + constexpr bool kIsCosine = (Metric == SqScanMetric::kCosine); + + float* s_query_term = smem; + float* s_recon_base = smem + dim; + float* s_sq_scale = kIsL2 ? (smem + dim) : (smem + 2 * dim); + + const uint32_t query_ix = blockIdx.x; + const uint32_t probe_ix = blockIdx.y; + + const uint32_t* my_coarse = coarse_indices + query_ix * n_probes; + const uint32_t cluster_id = my_coarse[probe_ix]; + const uint32_t cluster_sz = list_sizes[cluster_id]; + if (cluster_sz == 0) return; + + const uint8_t* codes = data_ptrs[cluster_id]; + const float* query = queries_float + query_ix * dim; + const float* centroid = centers + cluster_id * dim; + + for (uint32_t d = threadIdx.x; d < dim; d += BlockDim) { + float vmin_d = sq_vmin[d]; + s_sq_scale[d] = sq_delta[d]; + if constexpr (kIsL2) { + s_query_term[d] = query[d] - centroid[d] - vmin_d; + } else { + s_query_term[d] = query[d]; + s_recon_base[d] = centroid[d] + vmin_d; + } + } + __syncthreads(); + + const uint32_t* my_chunk = chunk_indices + query_ix * n_probes; + uint32_t out_base = (probe_ix > 0) ? my_chunk[probe_ix - 1] : 0; + + constexpr uint32_t veclen = 16; + constexpr uint32_t kWarpsPerBlock = BlockDim / raft::WarpSize; + const uint32_t warp_id = threadIdx.x / raft::WarpSize; + const uint32_t lane_id = threadIdx.x % raft::WarpSize; + + uint32_t padded_dim = ((dim + veclen - 1) / veclen) * veclen; + uint32_t n_dim_blocks = padded_dim / veclen; + + for (uint32_t group = warp_id * kIndexGroupSize; group < cluster_sz; + group += kWarpsPerBlock * kIndexGroupSize) { + const uint32_t row = group + lane_id; + const bool valid = (row < cluster_sz) && sample_filter(query_ix, cluster_id, row); + + float dist = 0.0f; + float v_norm_sq = 0.0f; + + const uint8_t* group_data = codes + size_t(group) * padded_dim; + + for (uint32_t bl = 0; bl < n_dim_blocks; bl++) { + uint8_t codes_local[veclen]; + *reinterpret_cast(codes_local) = *reinterpret_cast( + group_data + bl * (veclen * kIndexGroupSize) + lane_id * veclen); + + const uint32_t l = bl * veclen; +#pragma unroll + for (uint32_t j = 0; j < veclen; j++) { + if (l + j < dim) { + float recon = float(codes_local[j]) * s_sq_scale[l + j]; + + if constexpr (kIsL2) { + float diff = s_query_term[l + j] - recon; + dist += diff * diff; + } else { + float v_d = s_recon_base[l + j] + recon; + dist += s_query_term[l + j] * v_d; + if constexpr (kIsCosine) { v_norm_sq += v_d * v_d; } + } + } + } + } + + if constexpr (kIsCosine) { + float denom = query_norms[query_ix] * sqrtf(v_norm_sq); + dist = (denom > 0.0f) ? 1.0f - dist / denom : 0.0f; + } + + if (valid) { + uint32_t out_idx = query_ix * max_samples + out_base + row; + out_distances[out_idx] = dist; + out_indices[out_idx] = out_base + row; + } + } +} + +template +void ivf_sq_scan(raft::resources const& handle, + const index& idx, + const float* queries_float, + const float* query_norms, + uint32_t n_queries, + uint32_t n_probes, + uint32_t max_samples, + const uint32_t* coarse_indices, + const uint32_t* chunk_indices, + float* out_distances, + uint32_t* out_indices, + IvfSampleFilterT sample_filter, + rmm::cuda_stream_view stream) +{ + constexpr int kThreads = 256; + dim3 grid(n_queries, n_probes); + dim3 block(kThreads); + uint32_t dim = idx.dim(); + + auto do_launch = [&](auto kernel_ptr, size_t smem) { + RAFT_CUDA_TRY( + cudaFuncSetAttribute(kernel_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, smem)); + kernel_ptr<<>>(idx.data_ptrs().data_handle(), + idx.list_sizes().data_handle(), + coarse_indices, + queries_float, + idx.centers().data_handle(), + idx.sq_vmin().data_handle(), + idx.sq_delta().data_handle(), + query_norms, + n_probes, + dim, + max_samples, + chunk_indices, + out_distances, + out_indices, + sample_filter); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + }; + + switch (idx.metric()) { + case cuvs::distance::DistanceType::L2Expanded: + case cuvs::distance::DistanceType::L2SqrtExpanded: + do_launch(ivf_sq_scan_kernel, + 2 * dim * sizeof(float)); + break; + case cuvs::distance::DistanceType::InnerProduct: + do_launch(ivf_sq_scan_kernel, + 3 * dim * sizeof(float)); + break; + case cuvs::distance::DistanceType::CosineExpanded: + do_launch(ivf_sq_scan_kernel, + 3 * dim * sizeof(float)); + break; + default: RAFT_FAIL("Unsupported metric type for IVF-SQ scan."); + } +} + +template +void search_impl(raft::resources const& handle, + const index& index, + const T* queries, + uint32_t n_queries, + uint32_t k, + uint32_t n_probes, + bool select_min, + int64_t* neighbors, + float* distances, + rmm::device_async_resource_ref search_mr, + IvfSampleFilterT sample_filter) +{ + auto stream = raft::resource::get_cuda_stream(handle); + auto dim = index.dim(); + + std::size_t n_queries_probes = std::size_t(n_queries) * std::size_t(n_probes); + + rmm::device_uvector query_norm_dev(n_queries, stream, search_mr); + rmm::device_uvector distance_buffer_dev(n_queries * index.n_lists(), stream, search_mr); + rmm::device_uvector coarse_distances_dev(n_queries_probes, stream, search_mr); + rmm::device_uvector coarse_indices_dev(n_queries_probes, stream, search_mr); + + size_t float_query_size; + if constexpr (std::is_same_v) { + float_query_size = 0; + } else { + float_query_size = n_queries * dim; + } + rmm::device_uvector converted_queries_dev(float_query_size, stream, search_mr); + float* converted_queries_ptr = converted_queries_dev.data(); + + if constexpr (std::is_same_v) { + converted_queries_ptr = const_cast(queries); + } else { + raft::linalg::unaryOp( + converted_queries_ptr, queries, n_queries * dim, utils::mapping{}, stream); + } + + auto distance_buffer_dev_view = raft::make_device_matrix_view( + distance_buffer_dev.data(), n_queries, index.n_lists()); + + RAFT_EXPECTS(index.metric() == cuvs::distance::DistanceType::InnerProduct || + index.center_norms().has_value(), + "Center norms are required for search with L2 or Cosine metric. " + "Rebuild the index with add_data_on_build=true or call extend() first."); + + float alpha = 1.0f; + float beta = 0.0f; + switch (index.metric()) { + case cuvs::distance::DistanceType::L2Expanded: + case cuvs::distance::DistanceType::L2SqrtExpanded: { + alpha = -2.0f; + beta = 1.0f; + raft::linalg::rowNorm(query_norm_dev.data(), + converted_queries_ptr, + static_cast(dim), + static_cast(n_queries), + stream); + utils::outer_add(query_norm_dev.data(), + (int64_t)n_queries, + index.center_norms()->data_handle(), + (int64_t)index.n_lists(), + distance_buffer_dev.data(), + stream); + break; + } + case cuvs::distance::DistanceType::CosineExpanded: { + raft::linalg::rowNorm(query_norm_dev.data(), + converted_queries_ptr, + static_cast(dim), + static_cast(n_queries), + stream, + raft::sqrt_op{}); + alpha = -1.0f; + beta = 0.0f; + break; + } + case cuvs::distance::DistanceType::InnerProduct: { + alpha = 1.0f; + beta = 0.0f; + break; + } + default: RAFT_FAIL("Unsupported metric type for IVF-SQ search."); + } + + raft::linalg::gemm(handle, + true, + false, + index.n_lists(), + n_queries, + dim, + &alpha, + index.centers().data_handle(), + dim, + converted_queries_ptr, + dim, + &beta, + distance_buffer_dev.data(), + index.n_lists(), + stream); + + if (index.metric() == cuvs::distance::DistanceType::CosineExpanded) { + auto n_lists_local = index.n_lists(); + const auto* q_norm_ptr = query_norm_dev.data(); + const auto* center_norm_ptr = index.center_norms()->data_handle(); + raft::linalg::map_offset( + handle, + distance_buffer_dev_view, + [=] __device__(const uint32_t idx, const float dist) { + const auto query = idx / n_lists_local; + const auto cluster = idx % n_lists_local; + float denom = q_norm_ptr[query] * center_norm_ptr[cluster]; + return (denom > 0.0f) ? dist / denom : 0.0f; + }, + raft::make_const_mdspan(distance_buffer_dev_view)); + } + + cuvs::selection::select_k( + handle, + raft::make_const_mdspan(distance_buffer_dev_view), + std::nullopt, + raft::make_device_matrix_view(coarse_distances_dev.data(), n_queries, n_probes), + raft::make_device_matrix_view( + coarse_indices_dev.data(), n_queries, n_probes), + select_min); + + rmm::device_uvector num_samples(n_queries, stream, search_mr); + rmm::device_uvector chunk_index(n_queries_probes, stream, search_mr); + + ivf::detail::calc_chunk_indices::configure(n_probes, n_queries)(index.list_sizes().data_handle(), + coarse_indices_dev.data(), + chunk_index.data(), + num_samples.data(), + stream); + + uint32_t max_samples = + std::max(static_cast(index.accum_sorted_sizes()(n_probes)), k); + + rmm::device_uvector all_distances(std::size_t(n_queries) * max_samples, stream, search_mr); + rmm::device_uvector all_indices( + std::size_t(n_queries) * max_samples, stream, search_mr); + + float init_val = + select_min ? std::numeric_limits::max() : std::numeric_limits::lowest(); + thrust::fill_n(raft::resource::get_thrust_policy(handle), + all_distances.data(), + std::size_t(n_queries) * max_samples, + init_val); + thrust::fill_n(raft::resource::get_thrust_policy(handle), + all_indices.data(), + std::size_t(n_queries) * max_samples, + uint32_t(0xFFFFFFFF)); + + auto filter_adapter = cuvs::neighbors::filtering::ivf_to_sample_filter( + index.inds_ptrs().data_handle(), sample_filter); + + ivf_sq_scan(handle, + index, + converted_queries_ptr, + query_norm_dev.data(), + n_queries, + n_probes, + max_samples, + coarse_indices_dev.data(), + chunk_index.data(), + all_distances.data(), + all_indices.data(), + filter_adapter, + stream); + + rmm::device_uvector neighbors_uint32(0, stream, search_mr); + uint32_t* neighbors_uint32_ptr = nullptr; + if constexpr (sizeof(int64_t) == sizeof(uint32_t)) { + neighbors_uint32_ptr = reinterpret_cast(neighbors); + } else { + neighbors_uint32.resize(std::size_t(n_queries) * k, stream); + neighbors_uint32_ptr = neighbors_uint32.data(); + } + + auto num_samples_view = + raft::make_device_vector_view(num_samples.data(), n_queries); + + cuvs::selection::select_k( + handle, + raft::make_device_matrix_view( + all_distances.data(), n_queries, max_samples), + raft::make_device_matrix_view( + all_indices.data(), n_queries, max_samples), + raft::make_device_matrix_view(distances, n_queries, k), + raft::make_device_matrix_view(neighbors_uint32_ptr, n_queries, k), + select_min, + false, + cuvs::selection::SelectAlgo::kAuto, + num_samples_view); + + ivf::detail::postprocess_distances( + distances, distances, index.metric(), n_queries, k, 1.0, false, stream); + + ivf::detail::postprocess_neighbors(neighbors, + neighbors_uint32_ptr, + index.inds_ptrs().data_handle(), + coarse_indices_dev.data(), + chunk_index.data(), + n_queries, + n_probes, + k, + stream); +} + +template +inline void search_with_filtering(raft::resources const& handle, + const search_params& params, + const index& index, + const T* queries, + uint32_t n_queries, + uint32_t k, + int64_t* neighbors, + float* distances, + IvfSampleFilterT sample_filter = IvfSampleFilterT()) +{ + cuvs::common::nvtx::range fun_scope( + "ivf_sq::search(k = %u, n_queries = %u, dim = %zu)", k, n_queries, index.dim()); + + RAFT_EXPECTS(params.n_probes > 0, + "n_probes (number of clusters to probe in the search) must be positive."); + auto n_probes = std::min(params.n_probes, index.n_lists()); + + uint32_t max_samples = + std::max(static_cast(index.accum_sorted_sizes()(n_probes)), k); + + constexpr uint64_t kExpectedWsSize = 1024ull * 1024 * 1024; + uint64_t max_ws_size = + std::min(raft::resource::get_workspace_free_bytes(handle), kExpectedWsSize); + + uint64_t converted_query_floats = std::is_same_v ? 0 : index.dim(); + uint64_t ws_per_query = sizeof(float) * (uint64_t(index.n_lists()) + n_probes + 1 + max_samples + + converted_query_floats) + + sizeof(uint32_t) * (uint64_t(n_probes) * 2 + 1 + max_samples + k); + + const uint32_t max_queries = + std::min(n_queries, std::max(1, max_ws_size / ws_per_query)); + + for (uint32_t offset_q = 0; offset_q < n_queries; offset_q += max_queries) { + uint32_t queries_batch = std::min(max_queries, n_queries - offset_q); + + search_impl(handle, + index, + queries + std::size_t(offset_q) * index.dim(), + queries_batch, + k, + n_probes, + cuvs::distance::is_min_close(index.metric()), + neighbors + std::size_t(offset_q) * k, + distances + std::size_t(offset_q) * k, + raft::resource::get_workspace_resource(handle), + sample_filter); + } +} + +template +void search_with_filtering(raft::resources const& handle, + const search_params& params, + const index& index, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances, + IvfSampleFilterT sample_filter = IvfSampleFilterT()) +{ + RAFT_EXPECTS( + queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0), + "Number of rows in output neighbors and distances matrices must equal the number of queries."); + RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), + "Number of columns in output neighbors and distances matrices must be equal"); + RAFT_EXPECTS(queries.extent(1) == index.dim(), + "Number of query dimensions should equal number of dimensions in the index."); + + search_with_filtering(handle, + params, + index, + queries.data_handle(), + static_cast(queries.extent(0)), + static_cast(neighbors.extent(1)), + neighbors.data_handle(), + distances.data_handle(), + sample_filter); +} + +template +void search(raft::resources const& handle, + const search_params& params, + const index& idx, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances, + const cuvs::neighbors::filtering::base_filter& sample_filter_ref) +{ + try { + auto& sample_filter = + dynamic_cast(sample_filter_ref); + return search_with_filtering(handle, params, idx, queries, neighbors, distances, sample_filter); + } catch (const std::bad_cast&) { + } + + try { + auto& sample_filter = + dynamic_cast&>( + sample_filter_ref); + return search_with_filtering(handle, params, idx, queries, neighbors, distances, sample_filter); + } catch (const std::bad_cast&) { + RAFT_FAIL("Unsupported sample filter type"); + } +} + +} // namespace cuvs::neighbors::ivf_sq::detail diff --git a/cpp/src/neighbors/ivf_sq/ivf_sq_search_float_uint8_t_int64_t.cu b/cpp/src/neighbors/ivf_sq/ivf_sq_search_float_uint8_t_int64_t.cu new file mode 100644 index 0000000000..60d95a153f --- /dev/null +++ b/cpp/src/neighbors/ivf_sq/ivf_sq_search_float_uint8_t_int64_t.cu @@ -0,0 +1,29 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include + +#include "ivf_sq_search.cuh" + +namespace cuvs::neighbors::ivf_sq { + +#define CUVS_INST_IVF_SQ_SEARCH(T, IdxT) \ + void search(raft::resources const& handle, \ + const cuvs::neighbors::ivf_sq::search_params& params, \ + const cuvs::neighbors::ivf_sq::index& index, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances, \ + const cuvs::neighbors::filtering::base_filter& sample_filter) \ + { \ + cuvs::neighbors::ivf_sq::detail::search( \ + handle, params, index, queries, neighbors, distances, sample_filter); \ + } + +CUVS_INST_IVF_SQ_SEARCH(float, uint8_t); + +#undef CUVS_INST_IVF_SQ_SEARCH + +} // namespace cuvs::neighbors::ivf_sq diff --git a/cpp/src/neighbors/ivf_sq/ivf_sq_search_half_uint8_t_int64_t.cu b/cpp/src/neighbors/ivf_sq/ivf_sq_search_half_uint8_t_int64_t.cu new file mode 100644 index 0000000000..fbed3fd432 --- /dev/null +++ b/cpp/src/neighbors/ivf_sq/ivf_sq_search_half_uint8_t_int64_t.cu @@ -0,0 +1,29 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include + +#include "ivf_sq_search.cuh" + +namespace cuvs::neighbors::ivf_sq { + +#define CUVS_INST_IVF_SQ_SEARCH(T, IdxT) \ + void search(raft::resources const& handle, \ + const cuvs::neighbors::ivf_sq::search_params& params, \ + const cuvs::neighbors::ivf_sq::index& index, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances, \ + const cuvs::neighbors::filtering::base_filter& sample_filter) \ + { \ + cuvs::neighbors::ivf_sq::detail::search( \ + handle, params, index, queries, neighbors, distances, sample_filter); \ + } + +CUVS_INST_IVF_SQ_SEARCH(half, uint8_t); + +#undef CUVS_INST_IVF_SQ_SEARCH + +} // namespace cuvs::neighbors::ivf_sq diff --git a/cpp/src/neighbors/ivf_sq/ivf_sq_serialize.cuh b/cpp/src/neighbors/ivf_sq/ivf_sq_serialize.cuh new file mode 100644 index 0000000000..b95e63ee33 --- /dev/null +++ b/cpp/src/neighbors/ivf_sq/ivf_sq_serialize.cuh @@ -0,0 +1,161 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "../ivf_common.cuh" +#include "../ivf_list.cuh" +#include +#include + +#include +#include +#include +#include +#include + +#include + +namespace cuvs::neighbors::ivf_sq::detail { + +constexpr int serialization_version = 1; + +template +void serialize(raft::resources const& handle, std::ostream& os, const index& index_) +{ + RAFT_LOG_DEBUG( + "Saving IVF-SQ index, size %zu, dim %u", static_cast(index_.size()), index_.dim()); + + std::string dtype_string = raft::detail::numpy_serializer::get_numpy_dtype().to_string(); + dtype_string.resize(4); + os << dtype_string; + + serialize_scalar(handle, os, serialization_version); + serialize_scalar(handle, os, index_.size()); + serialize_scalar(handle, os, index_.dim()); + serialize_scalar(handle, os, index_.n_lists()); + serialize_scalar(handle, os, index_.metric()); + serialize_scalar(handle, os, index_.adaptive_centers()); + serialize_scalar(handle, os, index_.conservative_memory_allocation()); + serialize_mdspan(handle, os, index_.centers()); + + if (index_.center_norms()) { + bool has_norms = true; + serialize_scalar(handle, os, has_norms); + serialize_mdspan(handle, os, *index_.center_norms()); + } else { + bool has_norms = false; + serialize_scalar(handle, os, has_norms); + } + + serialize_mdspan(handle, os, index_.sq_vmin()); + serialize_mdspan(handle, os, index_.sq_delta()); + + auto sizes_host = raft::make_host_vector(index_.list_sizes().extent(0)); + raft::copy(sizes_host.data_handle(), + index_.list_sizes().data_handle(), + sizes_host.size(), + raft::resource::get_cuda_stream(handle)); + raft::resource::sync_stream(handle); + serialize_mdspan(handle, os, sizes_host.view()); + + list_spec list_store_spec{index_.dim(), true}; + for (uint32_t label = 0; label < index_.n_lists(); label++) { + ivf::serialize_list(handle, + os, + index_.lists()[label], + list_store_spec, + raft::Pow2::roundUp(sizes_host(label))); + } + raft::resource::sync_stream(handle); +} + +template +void serialize(raft::resources const& handle, + const std::string& filename, + const index& index_) +{ + std::ofstream of(filename, std::ios::out | std::ios::binary); + if (!of) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } + detail::serialize(handle, of, index_); + of.close(); + if (!of) { RAFT_FAIL("Error writing output %s", filename.c_str()); } +} + +template +auto deserialize(raft::resources const& handle, std::istream& is) -> index +{ + char dtype_string[4]; + is.read(dtype_string, 4); + + auto ver = raft::deserialize_scalar(handle, is); + if (ver != serialization_version) { + RAFT_FAIL("serialization version mismatch, expected %d, got %d ", serialization_version, ver); + } + auto n_rows = raft::deserialize_scalar(handle, is); + auto dim = raft::deserialize_scalar(handle, is); + auto n_lists = raft::deserialize_scalar(handle, is); + auto metric = raft::deserialize_scalar(handle, is); + bool adaptive_centers = raft::deserialize_scalar(handle, is); + bool cma = raft::deserialize_scalar(handle, is); + + index index_ = index(handle, metric, n_lists, dim, adaptive_centers, cma); + + deserialize_mdspan(handle, is, index_.centers()); + + bool has_norms = raft::deserialize_scalar(handle, is); + if (has_norms) { + index_.allocate_center_norms(handle); + if (!index_.center_norms()) { + RAFT_FAIL("Error inconsistent center norms"); + } else { + auto center_norms = index_.center_norms().value(); + deserialize_mdspan(handle, is, center_norms); + } + } + + deserialize_mdspan(handle, is, index_.sq_vmin()); + deserialize_mdspan(handle, is, index_.sq_delta()); + + deserialize_mdspan(handle, is, index_.list_sizes()); + + list_spec list_device_spec{index_.dim(), cma}; + list_spec list_store_spec{index_.dim(), true}; + for (uint32_t label = 0; label < index_.n_lists(); label++) { + ivf::deserialize_list(handle, is, index_.lists()[label], list_store_spec, list_device_spec); + } + raft::resource::sync_stream(handle); + + ivf::detail::recompute_internal_state(handle, index_); + + return index_; +} + +template +auto deserialize(raft::resources const& handle, const std::string& filename) -> index +{ + std::ifstream is(filename, std::ios::in | std::ios::binary); + if (!is) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } + auto index = detail::deserialize(handle, is); + is.close(); + return index; +} + +} // namespace cuvs::neighbors::ivf_sq::detail + +#define CUVS_INST_IVF_SQ_SERIALIZE(IdxT) \ + void serialize(raft::resources const& handle, \ + const std::string& filename, \ + const cuvs::neighbors::ivf_sq::index& index) \ + { \ + cuvs::neighbors::ivf_sq::detail::serialize(handle, filename, index); \ + } \ + \ + void deserialize(raft::resources const& handle, \ + const std::string& filename, \ + cuvs::neighbors::ivf_sq::index* index) \ + { \ + *index = cuvs::neighbors::ivf_sq::detail::deserialize(handle, filename); \ + } diff --git a/cpp/src/neighbors/ivf_sq/ivf_sq_serialize_uint8_t.cu b/cpp/src/neighbors/ivf_sq/ivf_sq_serialize_uint8_t.cu new file mode 100644 index 0000000000..c2351ed8c3 --- /dev/null +++ b/cpp/src/neighbors/ivf_sq/ivf_sq_serialize_uint8_t.cu @@ -0,0 +1,16 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include + +#include "ivf_sq_serialize.cuh" + +namespace cuvs::neighbors::ivf_sq { + +CUVS_INST_IVF_SQ_SERIALIZE(uint8_t); + +#undef CUVS_INST_IVF_SQ_SERIALIZE + +} // namespace cuvs::neighbors::ivf_sq diff --git a/cpp/src/neighbors/ivf_sq_index.cpp b/cpp/src/neighbors/ivf_sq_index.cpp new file mode 100644 index 0000000000..d97ace7dcb --- /dev/null +++ b/cpp/src/neighbors/ivf_sq_index.cpp @@ -0,0 +1,236 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include + +namespace cuvs::neighbors::ivf_sq { + +template +index::index(raft::resources const& res) + : index(res, cuvs::distance::DistanceType::L2Expanded, 0, 0, false, false) +{ +} + +template +index::index(raft::resources const& res, const index_params& params, uint32_t dim) + : index(res, + params.metric, + params.n_lists, + dim, + params.adaptive_centers, + params.conservative_memory_allocation) +{ +} + +template +index::index(raft::resources const& res, + cuvs::distance::DistanceType metric, + uint32_t n_lists, + uint32_t dim, + bool adaptive_centers, + bool conservative_memory_allocation) + : cuvs::neighbors::index(), + metric_(metric), + adaptive_centers_(adaptive_centers), + conservative_memory_allocation_(conservative_memory_allocation), + lists_{n_lists}, + list_sizes_{raft::make_device_vector(res, n_lists)}, + centers_(raft::make_device_matrix(res, n_lists, dim)), + center_norms_(std::nullopt), + sq_vmin_{raft::make_device_vector(res, dim)}, + sq_delta_{raft::make_device_vector(res, dim)}, + data_ptrs_{raft::make_device_vector(res, n_lists)}, + inds_ptrs_{raft::make_device_vector(res, n_lists)}, + accum_sorted_sizes_{raft::make_host_vector(n_lists + 1)} +{ + check_consistency(); + accum_sorted_sizes_(n_lists) = 0; +} + +template +cuvs::distance::DistanceType index::metric() const noexcept +{ + return metric_; +} + +template +bool index::adaptive_centers() const noexcept +{ + return adaptive_centers_; +} + +template +int64_t index::size() const noexcept +{ + return accum_sorted_sizes()(n_lists()); +} + +template +uint32_t index::dim() const noexcept +{ + return centers_.extent(1); +} + +template +uint32_t index::n_lists() const noexcept +{ + return lists_.size(); +} + +template +bool index::conservative_memory_allocation() const noexcept +{ + return conservative_memory_allocation_; +} + +template +raft::device_vector_view index::list_sizes() noexcept +{ + return list_sizes_.view(); +} + +template +raft::device_vector_view index::list_sizes() const noexcept +{ + return list_sizes_.view(); +} + +template +raft::device_matrix_view index::centers() noexcept +{ + return centers_.view(); +} + +template +raft::device_matrix_view index::centers() + const noexcept +{ + return centers_.view(); +} + +template +std::optional> index::center_norms() noexcept +{ + if (center_norms_.has_value()) { + return std::make_optional>(center_norms_->view()); + } else { + return std::nullopt; + } +} + +template +std::optional> index::center_norms() + const noexcept +{ + if (center_norms_.has_value()) { + return std::make_optional>( + center_norms_->view()); + } else { + return std::nullopt; + } +} + +template +void index::allocate_center_norms(raft::resources const& res) +{ + switch (metric_) { + case cuvs::distance::DistanceType::L2Expanded: + case cuvs::distance::DistanceType::L2SqrtExpanded: + case cuvs::distance::DistanceType::L2Unexpanded: + case cuvs::distance::DistanceType::L2SqrtUnexpanded: + case cuvs::distance::DistanceType::CosineExpanded: + center_norms_ = raft::make_device_vector(res, n_lists()); + break; + default: center_norms_ = std::nullopt; + } +} + +template +raft::device_vector_view index::sq_vmin() noexcept +{ + return sq_vmin_.view(); +} + +template +raft::device_vector_view index::sq_vmin() const noexcept +{ + return sq_vmin_.view(); +} + +template +raft::device_vector_view index::sq_delta() noexcept +{ + return sq_delta_.view(); +} + +template +raft::device_vector_view index::sq_delta() const noexcept +{ + return sq_delta_.view(); +} + +template +raft::host_vector_view index::accum_sorted_sizes() noexcept +{ + return accum_sorted_sizes_.view(); +} + +template +raft::host_vector_view index::accum_sorted_sizes() const noexcept +{ + return accum_sorted_sizes_.view(); +} + +template +raft::device_vector_view index::data_ptrs() noexcept +{ + return data_ptrs_.view(); +} + +template +raft::device_vector_view index::data_ptrs() const noexcept +{ + return data_ptrs_.view(); +} + +template +raft::device_vector_view index::inds_ptrs() noexcept +{ + return inds_ptrs_.view(); +} + +template +raft::device_vector_view index::inds_ptrs() const noexcept +{ + return inds_ptrs_.view(); +} + +template +std::vector>>& index::lists() noexcept +{ + return lists_; +} + +template +const std::vector>>& index::lists() const noexcept +{ + return lists_; +} + +template +void index::check_consistency() +{ + auto n_lists = lists_.size(); + RAFT_EXPECTS(list_sizes_.extent(0) == n_lists, "inconsistent list size"); + RAFT_EXPECTS(data_ptrs_.extent(0) == n_lists, "inconsistent list size"); + RAFT_EXPECTS(inds_ptrs_.extent(0) == n_lists, "inconsistent list size"); + RAFT_EXPECTS((centers_.extent(0) == list_sizes_.extent(0)) && + (!center_norms_.has_value() || centers_.extent(0) == center_norms_->extent(0)), + "inconsistent number of lists (clusters)"); +} + +template struct index; + +} // namespace cuvs::neighbors::ivf_sq diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 35794adf9b..208c330de7 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -131,6 +131,13 @@ ConfigureTest( PERCENT 100 ) +ConfigureTest( + NAME NEIGHBORS_ANN_IVF_SQ_TEST + PATH neighbors/ann_ivf_sq/test_float_uint8_t.cu + GPUS 1 + PERCENT 100 +) + ConfigureTest( NAME NEIGHBORS_ANN_IVF_PQ_TEST PATH neighbors/ann_ivf_pq/test_float_int64_t.cu neighbors/ann_ivf_pq/test_int8_t_int64_t.cu diff --git a/cpp/tests/neighbors/ann_ivf_sq.cuh b/cpp/tests/neighbors/ann_ivf_sq.cuh new file mode 100644 index 0000000000..a7e02315e4 --- /dev/null +++ b/cpp/tests/neighbors/ann_ivf_sq.cuh @@ -0,0 +1,457 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ +#pragma once + +#include "../test_utils.cuh" +#include "ann_utils.cuh" +#include "naive_knn.cuh" + +#include +#include +#include +#include + +#include +#include + +namespace cuvs::neighbors::ivf_sq { + +struct test_ivf_sample_filter { + static constexpr unsigned offset = 300; +}; + +template +struct AnnIvfSqInputs { + IdxT num_queries; + IdxT num_db_vecs; + IdxT dim; + IdxT k; + IdxT nprobe; + IdxT nlist; + cuvs::distance::DistanceType metric; + bool adaptive_centers; +}; + +template +::std::ostream& operator<<(::std::ostream& os, const AnnIvfSqInputs& p) +{ + os << "{ " << p.num_queries << ", " << p.num_db_vecs << ", " << p.dim << ", " << p.k << ", " + << p.nprobe << ", " << p.nlist << ", " + << cuvs::neighbors::print_metric{static_cast((int)p.metric)} + << ", " << p.adaptive_centers << '}' << std::endl; + return os; +} + +template +class AnnIVFSQTest : public ::testing::TestWithParam> { + public: + AnnIVFSQTest() + : stream_(raft::resource::get_cuda_stream(handle_)), + ps(::testing::TestWithParam>::GetParam()), + database(0, stream_), + search_queries(0, stream_) + { + } + + void testIVFSQ() + { + size_t queries_size = ps.num_queries * ps.k; + std::vector indices_ivfsq(queries_size); + std::vector indices_naive(queries_size); + std::vector distances_ivfsq(queries_size); + std::vector distances_naive(queries_size); + + { + rmm::device_uvector distances_naive_dev(queries_size, stream_); + rmm::device_uvector indices_naive_dev(queries_size, stream_); + cuvs::neighbors::naive_knn(handle_, + distances_naive_dev.data(), + indices_naive_dev.data(), + search_queries.data(), + database.data(), + ps.num_queries, + ps.num_db_vecs, + ps.dim, + ps.k, + ps.metric); + raft::update_host(distances_naive.data(), distances_naive_dev.data(), queries_size, stream_); + raft::update_host(indices_naive.data(), indices_naive_dev.data(), queries_size, stream_); + raft::resource::sync_stream(handle_); + } + + { + double min_recall = + std::min(1.0, static_cast(ps.nprobe) / static_cast(ps.nlist)); + + rmm::device_uvector distances_ivfsq_dev(queries_size, stream_); + rmm::device_uvector indices_ivfsq_dev(queries_size, stream_); + + { + cuvs::neighbors::ivf_sq::index_params index_params; + cuvs::neighbors::ivf_sq::search_params search_params; + index_params.n_lists = ps.nlist; + index_params.metric = ps.metric; + index_params.adaptive_centers = ps.adaptive_centers; + search_params.n_probes = ps.nprobe; + + index_params.add_data_on_build = true; + index_params.kmeans_trainset_fraction = 0.5; + + auto database_view = raft::make_device_matrix_view( + (const DataT*)database.data(), ps.num_db_vecs, ps.dim); + + auto idx = cuvs::neighbors::ivf_sq::build(handle_, index_params, database_view); + + // Test extend: build without data, then extend + cuvs::neighbors::ivf_sq::index_params index_params_no_add; + index_params_no_add.n_lists = ps.nlist; + index_params_no_add.metric = ps.metric; + index_params_no_add.adaptive_centers = ps.adaptive_centers; + index_params_no_add.add_data_on_build = false; + index_params_no_add.kmeans_trainset_fraction = 0.5; + + auto idx_empty = + cuvs::neighbors::ivf_sq::build(handle_, index_params_no_add, database_view); + + auto vector_indices = raft::make_device_vector(handle_, ps.num_db_vecs); + raft::linalg::map_offset(handle_, vector_indices.view(), raft::identity_op{}); + raft::resource::sync_stream(handle_); + + auto indices_view = raft::make_device_vector_view( + vector_indices.data_handle(), ps.num_db_vecs); + cuvs::neighbors::ivf_sq::extend( + handle_, + database_view, + std::make_optional>(indices_view), + &idx_empty); + + // Serialize / deserialize round-trip + tmp_index_file index_file; + cuvs::neighbors::ivf_sq::serialize(handle_, index_file.filename, idx); + cuvs::neighbors::ivf_sq::index index_loaded(handle_); + cuvs::neighbors::ivf_sq::deserialize(handle_, index_file.filename, &index_loaded); + ASSERT_EQ(idx.size(), index_loaded.size()); + ASSERT_EQ(idx.dim(), index_loaded.dim()); + ASSERT_EQ(idx.n_lists(), index_loaded.n_lists()); + + auto search_queries_view = raft::make_device_matrix_view( + search_queries.data(), ps.num_queries, ps.dim); + auto indices_out_view = + raft::make_device_matrix_view(indices_ivfsq_dev.data(), ps.num_queries, ps.k); + auto dists_out_view = + raft::make_device_matrix_view(distances_ivfsq_dev.data(), ps.num_queries, ps.k); + + cuvs::neighbors::ivf_sq::search(handle_, + search_params, + index_loaded, + search_queries_view, + indices_out_view, + dists_out_view); + + raft::update_host( + distances_ivfsq.data(), distances_ivfsq_dev.data(), queries_size, stream_); + raft::update_host(indices_ivfsq.data(), indices_ivfsq_dev.data(), queries_size, stream_); + raft::resource::sync_stream(handle_); + } + // SQ introduces quantization error, so we relax the distance epsilon + float eps = 0.1; + ASSERT_TRUE(eval_neighbours(indices_naive, + indices_ivfsq, + distances_naive, + distances_ivfsq, + ps.num_queries, + ps.k, + eps, + min_recall)); + } + } + + void testFilter() + { + if (ps.num_db_vecs <= static_cast(test_ivf_sample_filter::offset)) { + GTEST_SKIP() << "Skipping filter test: num_db_vecs <= filter offset"; + } + + size_t queries_size = ps.num_queries * ps.k; + std::vector indices_ivfsq(queries_size); + std::vector indices_naive(queries_size); + std::vector distances_ivfsq(queries_size); + std::vector distances_naive(queries_size); + + { + rmm::device_uvector distances_naive_dev(queries_size, stream_); + rmm::device_uvector indices_naive_dev(queries_size, stream_); + auto* database_filtered_ptr = database.data() + test_ivf_sample_filter::offset * ps.dim; + cuvs::neighbors::naive_knn(handle_, + distances_naive_dev.data(), + indices_naive_dev.data(), + search_queries.data(), + database_filtered_ptr, + ps.num_queries, + ps.num_db_vecs - test_ivf_sample_filter::offset, + ps.dim, + ps.k, + ps.metric); + raft::linalg::addScalar(indices_naive_dev.data(), + indices_naive_dev.data(), + IdxT(test_ivf_sample_filter::offset), + queries_size, + stream_); + raft::update_host(distances_naive.data(), distances_naive_dev.data(), queries_size, stream_); + raft::update_host(indices_naive.data(), indices_naive_dev.data(), queries_size, stream_); + raft::resource::sync_stream(handle_); + } + + { + double min_recall = + std::min(1.0, static_cast(ps.nprobe) / static_cast(ps.nlist)); + + rmm::device_uvector distances_ivfsq_dev(queries_size, stream_); + rmm::device_uvector indices_ivfsq_dev(queries_size, stream_); + + { + cuvs::neighbors::ivf_sq::index_params index_params; + cuvs::neighbors::ivf_sq::search_params search_params; + index_params.n_lists = ps.nlist; + index_params.metric = ps.metric; + index_params.adaptive_centers = ps.adaptive_centers; + search_params.n_probes = ps.nprobe; + + index_params.add_data_on_build = true; + index_params.kmeans_trainset_fraction = 0.5; + + auto database_view = raft::make_device_matrix_view( + (const DataT*)database.data(), ps.num_db_vecs, ps.dim); + auto index = cuvs::neighbors::ivf_sq::build(handle_, index_params, database_view); + + auto removed_indices = + raft::make_device_vector(handle_, test_ivf_sample_filter::offset); + raft::linalg::map_offset(handle_, removed_indices.view(), raft::identity_op{}); + raft::resource::sync_stream(handle_); + + cuvs::core::bitset removed_indices_bitset( + handle_, removed_indices.view(), ps.num_db_vecs); + auto bitset_filter_obj = + cuvs::neighbors::filtering::bitset_filter(removed_indices_bitset.view()); + + auto search_queries_view = raft::make_device_matrix_view( + search_queries.data(), ps.num_queries, ps.dim); + auto indices_out_view = + raft::make_device_matrix_view(indices_ivfsq_dev.data(), ps.num_queries, ps.k); + auto dists_out_view = + raft::make_device_matrix_view(distances_ivfsq_dev.data(), ps.num_queries, ps.k); + + cuvs::neighbors::ivf_sq::search(handle_, + search_params, + index, + search_queries_view, + indices_out_view, + dists_out_view, + bitset_filter_obj); + + raft::update_host( + distances_ivfsq.data(), distances_ivfsq_dev.data(), queries_size, stream_); + raft::update_host(indices_ivfsq.data(), indices_ivfsq_dev.data(), queries_size, stream_); + raft::resource::sync_stream(handle_); + } + float eps = 0.1; + ASSERT_TRUE(eval_neighbours(indices_naive, + indices_ivfsq, + distances_naive, + distances_ivfsq, + ps.num_queries, + ps.k, + eps, + min_recall)); + } + } + + void SetUp() override + { + database.resize(ps.num_db_vecs * ps.dim, stream_); + search_queries.resize(ps.num_queries * ps.dim, stream_); + + raft::random::RngState r(1234ULL); + if constexpr (std::is_same_v || std::is_same_v) { + raft::random::uniform( + handle_, r, database.data(), ps.num_db_vecs * ps.dim, DataT(0.1), DataT(2.0)); + raft::random::uniform( + handle_, r, search_queries.data(), ps.num_queries * ps.dim, DataT(0.1), DataT(2.0)); + } + raft::resource::sync_stream(handle_); + } + + void TearDown() override + { + raft::resource::sync_stream(handle_); + database.resize(0, stream_); + search_queries.resize(0, stream_); + } + + private: + raft::resources handle_; + rmm::cuda_stream_view stream_; + AnnIvfSqInputs ps; + rmm::device_uvector database; + rmm::device_uvector search_queries; +}; + +const std::vector> inputs = { + // num_queries, num_db_vecs, dim, k, nprobe, nlist, metric, adaptive_centers + + // ===== Dimension edge cases (all four metrics) ===== + // dim=1 (CosineExpanded excluded: requires dim > 1) + {1000, 10000, 1, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, + {1000, 10000, 1, 10, 40, 1024, cuvs::distance::DistanceType::InnerProduct, false}, + {1000, 10000, 1, 10, 40, 1024, cuvs::distance::DistanceType::L2SqrtExpanded, false}, + // dim=2,3,4,5 (unaligned) + {1000, 10000, 2, 16, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, + {1000, 10000, 2, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, + {1000, 10000, 3, 16, 40, 1024, cuvs::distance::DistanceType::L2Expanded, true}, + {1000, 10000, 3, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, true}, + {1000, 10000, 4, 16, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, + {1000, 10000, 4, 16, 40, 1024, cuvs::distance::DistanceType::InnerProduct, false}, + {1000, 10000, 5, 16, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, + {1000, 10000, 5, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, + // dim=7,8 (around veclen=16 boundary, not a multiple of veclen) + {1000, 10000, 7, 16, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, + {1000, 10000, 7, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, + {1000, 10000, 8, 16, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, + {1000, 10000, 8, 16, 40, 1024, cuvs::distance::DistanceType::InnerProduct, true}, + {1000, 10000, 8, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, true}, + // dim=15,16,17 (around veclen=16 boundary) + {1000, 10000, 15, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, + {1000, 10000, 15, 10, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, + {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, + {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::InnerProduct, false}, + {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, + {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::L2SqrtExpanded, false}, + {1000, 10000, 17, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, + {1000, 10000, 17, 10, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, + // dim=31,32,33 (around 2*veclen boundary) + {1000, 10000, 31, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, + {1000, 10000, 31, 10, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, + {1000, 10000, 32, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, + {1000, 10000, 32, 10, 40, 1024, cuvs::distance::DistanceType::InnerProduct, false}, + {1000, 10000, 32, 10, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, + {1000, 10000, 33, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, + {1000, 10000, 33, 10, 40, 1024, cuvs::distance::DistanceType::InnerProduct, false}, + // medium dims + {1000, 10000, 64, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, + {1000, 10000, 64, 10, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, + {1000, 10000, 128, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, + {1000, 10000, 128, 10, 40, 1024, cuvs::distance::DistanceType::InnerProduct, false}, + {1000, 10000, 128, 10, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, + {1000, 10000, 128, 10, 40, 1024, cuvs::distance::DistanceType::L2SqrtExpanded, false}, + {1000, 10000, 256, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, + {1000, 10000, 256, 10, 40, 1024, cuvs::distance::DistanceType::InnerProduct, false}, + // large dims (may exceed shared memory limits) + {1000, 10000, 2048, 16, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, + {1000, 10000, 2048, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, + {1000, 10000, 2049, 16, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, + {1000, 10000, 2049, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, + {1000, 10000, 2050, 16, 40, 1024, cuvs::distance::DistanceType::InnerProduct, false}, + {1000, 10000, 2050, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, + {1000, 10000, 4096, 20, 50, 1024, cuvs::distance::DistanceType::L2Expanded, false}, + {1000, 10000, 4096, 20, 50, 1024, cuvs::distance::DistanceType::InnerProduct, false}, + {1000, 10000, 4096, 20, 50, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, + + // ===== k edge cases ===== + {1000, 10000, 16, 1, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, + {1000, 10000, 16, 1, 40, 1024, cuvs::distance::DistanceType::InnerProduct, false}, + {1000, 10000, 16, 1, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, + {1000, 10000, 16, 2, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, + {1000, 10000, 16, 5, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, + {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, + {1000, 10000, 16, 20, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, + {1000, 10000, 16, 20, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, + {1000, 10000, 16, 50, 100, 1024, cuvs::distance::DistanceType::L2Expanded, false}, + {1000, 10000, 16, 100, 200, 1024, cuvs::distance::DistanceType::L2Expanded, false}, + {1000, 10000, 16, 100, 200, 1024, cuvs::distance::DistanceType::InnerProduct, false}, + + // ===== nprobe / nlist edge cases ===== + // nprobe == nlist (exhaustive probe) + {1000, 10000, 16, 10, 64, 64, cuvs::distance::DistanceType::L2Expanded, false}, + {1000, 10000, 16, 10, 64, 64, cuvs::distance::DistanceType::InnerProduct, false}, + {1000, 10000, 16, 10, 64, 64, cuvs::distance::DistanceType::CosineExpanded, false}, + // nprobe == 1 (minimal probe) + {1000, 10000, 16, 10, 1, 1024, cuvs::distance::DistanceType::L2Expanded, false}, + {1000, 10000, 16, 10, 1, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, + // nprobe > nlist (clamped to nlist) + {1000, 10000, 16, 10, 2048, 1024, cuvs::distance::DistanceType::L2Expanded, false}, + {1000, 10000, 16, 10, 2048, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, + // various nprobe + {1000, 10000, 16, 10, 50, 1024, cuvs::distance::DistanceType::L2Expanded, false}, + {1000, 10000, 16, 10, 70, 1024, cuvs::distance::DistanceType::L2Expanded, false}, + {1000, 10000, 16, 10, 50, 1024, cuvs::distance::DistanceType::InnerProduct, false}, + {1000, 10000, 16, 10, 70, 1024, cuvs::distance::DistanceType::InnerProduct, false}, + {1000, 10000, 16, 10, 50, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, + {1000, 10000, 16, 10, 70, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, + {1000, 10000, 16, 10, 50, 1024, cuvs::distance::DistanceType::L2SqrtExpanded, false}, + {1000, 10000, 16, 10, 70, 1024, cuvs::distance::DistanceType::L2SqrtExpanded, false}, + // very small nlist + {100, 10000, 16, 10, 8, 8, cuvs::distance::DistanceType::L2Expanded, false}, + {100, 10000, 16, 10, 8, 8, cuvs::distance::DistanceType::CosineExpanded, false}, + // smaller nlist + {100, 10000, 16, 10, 20, 512, cuvs::distance::DistanceType::L2Expanded, false}, + {100, 10000, 16, 10, 20, 512, cuvs::distance::DistanceType::InnerProduct, false}, + {100, 10000, 16, 10, 20, 512, cuvs::distance::DistanceType::CosineExpanded, false}, + {100, 10000, 16, 10, 20, 512, cuvs::distance::DistanceType::L2SqrtExpanded, false}, + + // ===== Dataset size edge cases ===== + // single query + {1, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, + {1, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, + // very few queries + {2, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, + {5, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, + // very few db vectors (nlist reduced to fit) + {100, 500, 16, 10, 40, 256, cuvs::distance::DistanceType::L2Expanded, false}, + {100, 500, 16, 10, 40, 256, cuvs::distance::DistanceType::CosineExpanded, false}, + // small db with many empty clusters + {100, 100, 16, 5, 20, 64, cuvs::distance::DistanceType::L2Expanded, false}, + {100, 100, 16, 5, 20, 64, cuvs::distance::DistanceType::CosineExpanded, false}, + // larger datasets + {20, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::L2Expanded, true}, + {20, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::CosineExpanded, true}, + {1000, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::L2Expanded, true}, + {1000, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::CosineExpanded, true}, + {10000, 131072, 8, 10, 20, 1024, cuvs::distance::DistanceType::L2Expanded, false}, + {10000, 131072, 8, 10, 20, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, + {10000, 131072, 8, 10, 50, 1024, cuvs::distance::DistanceType::InnerProduct, true}, + {10000, 131072, 8, 10, 50, 1024, cuvs::distance::DistanceType::L2SqrtExpanded, false}, + + // ===== Large query batches (gridDim.x > 65535) ===== + {100000, 1024, 32, 10, 64, 64, cuvs::distance::DistanceType::L2Expanded, false}, + {100000, 1024, 32, 10, 64, 64, cuvs::distance::DistanceType::InnerProduct, false}, + {100000, 1024, 32, 10, 64, 64, cuvs::distance::DistanceType::CosineExpanded, false}, + {100000, 1024, 32, 10, 64, 64, cuvs::distance::DistanceType::L2SqrtExpanded, false}, + {100000, 8712, 3, 10, 51, 66, cuvs::distance::DistanceType::L2Expanded, false}, + {100000, 8712, 3, 10, 51, 66, cuvs::distance::DistanceType::CosineExpanded, false}, + // just above the old 65535 limit + {65536, 1024, 16, 10, 32, 64, cuvs::distance::DistanceType::L2Expanded, false}, + {65536, 1024, 16, 10, 32, 64, cuvs::distance::DistanceType::CosineExpanded, false}, + + // ===== Adaptive centers (all four metrics, multiple dims) ===== + {1000, 10000, 8, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded, true}, + {1000, 10000, 8, 10, 40, 1024, cuvs::distance::DistanceType::InnerProduct, true}, + {1000, 10000, 8, 10, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, true}, + {1000, 10000, 8, 10, 40, 1024, cuvs::distance::DistanceType::L2SqrtExpanded, true}, + {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded, true}, + {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::InnerProduct, true}, + {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, true}, + {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::L2SqrtExpanded, true}, + {1000, 10000, 32, 10, 50, 1024, cuvs::distance::DistanceType::L2Expanded, true}, + {1000, 10000, 32, 10, 50, 1024, cuvs::distance::DistanceType::InnerProduct, true}, + {1000, 10000, 32, 10, 50, 1024, cuvs::distance::DistanceType::CosineExpanded, true}, + {1000, 10000, 128, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded, true}, + {1000, 10000, 128, 10, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, true}, + + // ===== Recall-stability: same data, different query counts ===== + {20000, 8712, 3, 10, 51, 66, cuvs::distance::DistanceType::L2Expanded, false}, + {50000, 8712, 3, 10, 51, 66, cuvs::distance::DistanceType::L2Expanded, false}, +}; + +} // namespace cuvs::neighbors::ivf_sq diff --git a/cpp/tests/neighbors/ann_ivf_sq/test_float_uint8_t.cu b/cpp/tests/neighbors/ann_ivf_sq/test_float_uint8_t.cu new file mode 100644 index 0000000000..02ec8a7dfc --- /dev/null +++ b/cpp/tests/neighbors/ann_ivf_sq/test_float_uint8_t.cu @@ -0,0 +1,21 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include + +#include "../ann_ivf_sq.cuh" + +namespace cuvs::neighbors::ivf_sq { + +typedef AnnIVFSQTest AnnIVFSQTestF_float; +TEST_P(AnnIVFSQTestF_float, AnnIVFSQ) +{ + this->testIVFSQ(); + this->testFilter(); +} + +INSTANTIATE_TEST_CASE_P(AnnIVFSQTest, AnnIVFSQTestF_float, ::testing::ValuesIn(inputs)); + +} // namespace cuvs::neighbors::ivf_sq diff --git a/python/cuvs_bench/cuvs_bench/config/algorithms.yaml b/python/cuvs_bench/cuvs_bench/config/algorithms.yaml index fa2195fc61..3a787f65ab 100644 --- a/python/cuvs_bench/cuvs_bench/config/algorithms.yaml +++ b/python/cuvs_bench/cuvs_bench/config/algorithms.yaml @@ -34,6 +34,9 @@ cuvs_ivf_flat: cuvs_ivf_pq: executable: CUVS_IVF_PQ_ANN_BENCH requires_gpu: true +cuvs_ivf_sq: + executable: CUVS_IVF_SQ_ANN_BENCH + requires_gpu: true cuvs_cagra: executable: CUVS_CAGRA_ANN_BENCH requires_gpu: true diff --git a/python/cuvs_bench/cuvs_bench/config/algos/cuvs_ivf_sq.yaml b/python/cuvs_bench/cuvs_bench/config/algos/cuvs_ivf_sq.yaml new file mode 100644 index 0000000000..711f3e8ce8 --- /dev/null +++ b/python/cuvs_bench/cuvs_bench/config/algos/cuvs_ivf_sq.yaml @@ -0,0 +1,16 @@ +name: cuvs_ivf_sq +groups: + base: + build: + nlist: [1024, 2048, 4096, 8192, 16384, 32000, 64000] + ratio: [1, 2, 4] + niter: [20, 25] + search: + nprobe: [1, 5, 10, 50, 100, 200, 500, 1000, 2000] + test: + build: + nlist: [1024] + ratio: [1] + niter: [20] + search: + nprobe: [1, 5] From cf19a8629c3377426a2a6bfa3b3e7d900044b42a Mon Sep 17 00:00:00 2001 From: vic Date: Mon, 2 Mar 2026 11:25:03 +0100 Subject: [PATCH 02/31] add IVF-SQ bench constraints --- .../cuvs_bench/config/algos/constraints/__init__.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/python/cuvs_bench/cuvs_bench/config/algos/constraints/__init__.py b/python/cuvs_bench/cuvs_bench/config/algos/constraints/__init__.py index 9111bdc3b9..ea2afe351e 100644 --- a/python/cuvs_bench/cuvs_bench/config/algos/constraints/__init__.py +++ b/python/cuvs_bench/cuvs_bench/config/algos/constraints/__init__.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. # SPDX-License-Identifier: Apache-2.0 @@ -50,6 +50,12 @@ def cuvs_cagra_search(params, build_params, k, batch_size): return True +def cuvs_ivf_sq_search(params, build_params, k, batch_size): + if "nlist" in build_params and "nprobe" in params: + return build_params["nlist"] >= params["nprobe"] + return True + + ############################################################################### # FAISS constraints # ############################################################################### From 6a95e8a8215016a01127d9c62e97cbb2fffa1cac Mon Sep 17 00:00:00 2001 From: vic Date: Mon, 2 Mar 2026 12:21:49 +0100 Subject: [PATCH 03/31] Update default IVF-SQ benchmark config --- .../cuvs_bench/config/algos/cuvs_ivf_sq.yaml | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/python/cuvs_bench/cuvs_bench/config/algos/cuvs_ivf_sq.yaml b/python/cuvs_bench/cuvs_bench/config/algos/cuvs_ivf_sq.yaml index 711f3e8ce8..adaad54e04 100644 --- a/python/cuvs_bench/cuvs_bench/config/algos/cuvs_ivf_sq.yaml +++ b/python/cuvs_bench/cuvs_bench/config/algos/cuvs_ivf_sq.yaml @@ -1,12 +1,21 @@ name: cuvs_ivf_sq +constraints: + search: cuvs_bench.config.algos.constraints.cuvs_ivf_sq_search groups: base: build: - nlist: [1024, 2048, 4096, 8192, 16384, 32000, 64000] - ratio: [1, 2, 4] - niter: [20, 25] + nlist: [1024, 2048, 4096, 8192] + ratio: [1, 2] + niter: [25] search: - nprobe: [1, 5, 10, 50, 100, 200, 500, 1000, 2000] + nprobe: [1, 5, 10, 20, 50, 100, 200, 500] + large: + build: + nlist: [8192, 16384, 32000, 64000] + ratio: [2, 4] + niter: [20] + search: + nprobe: [10, 20, 50, 100, 200, 500, 1000, 2000] test: build: nlist: [1024] From 83b8c6351239cd03aaff97a761eca6a49e6f08f5 Mon Sep 17 00:00:00 2001 From: vic Date: Thu, 12 Mar 2026 11:40:17 +0100 Subject: [PATCH 04/31] Update postprocess_neighbors signature --- cpp/src/neighbors/ivf_sq/ivf_sq_search.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/neighbors/ivf_sq/ivf_sq_search.cuh b/cpp/src/neighbors/ivf_sq/ivf_sq_search.cuh index 39c653b048..a17992ff19 100644 --- a/cpp/src/neighbors/ivf_sq/ivf_sq_search.cuh +++ b/cpp/src/neighbors/ivf_sq/ivf_sq_search.cuh @@ -427,7 +427,7 @@ void search_impl(raft::resources const& handle, num_samples_view); ivf::detail::postprocess_distances( - distances, distances, index.metric(), n_queries, k, 1.0, false, stream); + handle, distances, distances, index.metric(), n_queries, k, 1.0, false); ivf::detail::postprocess_neighbors(neighbors, neighbors_uint32_ptr, From 1050debec9dc64c52f766a5e6e69ea9569789430 Mon Sep 17 00:00:00 2001 From: vic Date: Thu, 12 Mar 2026 12:01:52 +0100 Subject: [PATCH 05/31] update testing --- cpp/tests/neighbors/ann_ivf_sq.cuh | 3 --- 1 file changed, 3 deletions(-) diff --git a/cpp/tests/neighbors/ann_ivf_sq.cuh b/cpp/tests/neighbors/ann_ivf_sq.cuh index a7e02315e4..25abc82740 100644 --- a/cpp/tests/neighbors/ann_ivf_sq.cuh +++ b/cpp/tests/neighbors/ann_ivf_sq.cuh @@ -410,9 +410,6 @@ const std::vector> inputs = { // very few db vectors (nlist reduced to fit) {100, 500, 16, 10, 40, 256, cuvs::distance::DistanceType::L2Expanded, false}, {100, 500, 16, 10, 40, 256, cuvs::distance::DistanceType::CosineExpanded, false}, - // small db with many empty clusters - {100, 100, 16, 5, 20, 64, cuvs::distance::DistanceType::L2Expanded, false}, - {100, 100, 16, 5, 20, 64, cuvs::distance::DistanceType::CosineExpanded, false}, // larger datasets {20, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::L2Expanded, true}, {20, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::CosineExpanded, true}, From 3a911d887a19471137085bcc19f6ab881af5f76f Mon Sep 17 00:00:00 2001 From: vic Date: Fri, 13 Mar 2026 11:14:58 +0100 Subject: [PATCH 06/31] documentation --- cpp/include/cuvs/neighbors/ivf_sq.hpp | 554 ++++++++++++++++++++++- docs/source/cpp_api/neighbors.rst | 1 + docs/source/cpp_api/neighbors_ivf_sq.rst | 68 +++ 3 files changed, 614 insertions(+), 9 deletions(-) create mode 100644 docs/source/cpp_api/neighbors_ivf_sq.rst diff --git a/cpp/include/cuvs/neighbors/ivf_sq.hpp b/cpp/include/cuvs/neighbors/ivf_sq.hpp index 2f09751e95..042ecda10a 100644 --- a/cpp/include/cuvs/neighbors/ivf_sq.hpp +++ b/cpp/include/cuvs/neighbors/ivf_sq.hpp @@ -24,19 +24,62 @@ namespace cuvs::neighbors::ivf_sq { constexpr static uint32_t kIndexGroupSize = 32; struct index_params : cuvs::neighbors::index_params { - uint32_t n_lists = 1024; - uint32_t kmeans_n_iters = 20; - double kmeans_trainset_fraction = 0.5; - bool adaptive_centers = false; + /** The number of inverted lists (clusters) */ + uint32_t n_lists = 1024; + /** The number of iterations searching for kmeans centers (index building). */ + uint32_t kmeans_n_iters = 20; + /** The fraction of data to use during iterative kmeans building. */ + double kmeans_trainset_fraction = 0.5; + /** + * By default (adaptive_centers = false), the cluster centers are trained in `ivf_sq::build`, + * and never modified in `ivf_sq::extend`. As a result, you may need to retrain the index + * from scratch after invoking (`ivf_sq::extend`) a few times with new data, the distribution of + * which is no longer representative of the original training set. + * + * The alternative behavior (adaptive_centers = true) is to update the cluster centers for new + * data when it is added. In this case, `index.centers()` are always exactly the centroids of the + * data in the corresponding clusters. The drawback of this behavior is that the centroids depend + * on the order of adding new data (through the classification of the added data); that is, + * `index.centers()` "drift" together with the changing distribution of the newly added data. + */ + bool adaptive_centers = false; + /** + * By default, the algorithm allocates more space than necessary for individual clusters + * (`list_data`). This allows to amortize the cost of memory allocation and reduce the number of + * data copies during repeated calls to `extend` (extending the database). + * + * The alternative is the conservative allocation behavior; when enabled, the algorithm always + * allocates the minimum amount of memory required to store the given number of records. Set this + * flag to `true` if you prefer to use as little GPU memory for the database as possible. + */ bool conservative_memory_allocation = false; - bool add_data_on_build = true; + /** + * Whether to add the dataset content to the index, i.e.: + * + * - `true` means the index is filled with the dataset vectors and ready to search after calling + * `build`. + * - `false` means `build` only trains the underlying model (e.g. quantizer or clustering), but + * the index is left empty; you'd need to call `extend` on the index afterwards to populate it. + */ + bool add_data_on_build = true; }; +static_assert(std::is_aggregate_v); + +/** + * @} + */ + +/** + * @defgroup ivf_sq_cpp_search_params IVF-SQ index search parameters + * @{ + */ + struct search_params : cuvs::neighbors::search_params { + /** The number of clusters to search. */ uint32_t n_probes = 20; }; -static_assert(std::is_aggregate_v); static_assert(std::is_aggregate_v); /** @@ -97,11 +140,23 @@ using list_data = ivf::list; /** * @brief IVF-SQ index. * + * In the IVF-SQ index, a database vector is first assigned to the nearest cluster center + * using an inverted file (IVF) structure, and then compressed using scalar quantization (SQ). + * + * Scalar quantization independently maps each dimension of the vector to a fixed-width integer + * code. For 8-bit quantization (uint8_t), each floating-point component is linearly mapped to + * an integer in [0, 255] using learned per-dimension minimum (`sq_vmin`) and range (`sq_delta`) + * values: + * + * code_i = round((x_i - vmin_i) / delta_i * 255) + * + * This provides a compact representation (1 byte per dimension) while preserving the relative + * distances between vectors with high fidelity, offering a good trade-off between index size, + * search speed, and recall compared to flat (uncompressed) and product-quantized (PQ) + * representations. + * * @tparam IdxT SQ code type. Only uint8_t (8-bit, codes in [0,255]) for now. * - * No member depends on the raw data type T (float, half). T appears only - * in the free-function signatures (build, search, extend) where input data - * is consumed, following the IVF-PQ pattern. */ template struct index : cuvs::neighbors::index { @@ -192,41 +247,219 @@ struct index : cuvs::neighbors::index { * @{ */ +/** + * @brief Build the index from the dataset for efficient search. + * + * NB: Currently, the following distance metrics are supported: + * - L2Expanded + * - L2Unexpanded + * - InnerProduct + * - CosineExpanded + * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * // use default index parameters + * ivf_sq::index_params index_params; + * // create and fill the index from a [N, D] dataset + * auto index = ivf_sq::build(handle, index_params, dataset); + * @endcode + * + * @param[in] handle + * @param[in] index_params configure the index building + * @param[in] dataset a device pointer to a row-major matrix [n_rows, dim] + * + * @return the constructed ivf-sq index + */ auto build(raft::resources const& handle, const cuvs::neighbors::ivf_sq::index_params& index_params, raft::device_matrix_view dataset) -> cuvs::neighbors::ivf_sq::index; +/** + * @brief Build the index from the dataset for efficient search. + * + * NB: Currently, the following distance metrics are supported: + * - L2Expanded + * - L2Unexpanded + * - InnerProduct + * - CosineExpanded + * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * // use default index parameters + * ivf_sq::index_params index_params; + * // create and fill the index from a [N, D] dataset + * ivf_sq::index index; + * ivf_sq::build(handle, index_params, dataset, index); + * @endcode + * + * @param[in] handle + * @param[in] index_params configure the index building + * @param[in] dataset raft::device_matrix_view to a row-major matrix [n_rows, dim] + * @param[out] idx reference to ivf_sq::index + * + */ void build(raft::resources const& handle, const cuvs::neighbors::ivf_sq::index_params& index_params, raft::device_matrix_view dataset, cuvs::neighbors::ivf_sq::index& idx); +/** + * @brief Build the index from the dataset for efficient search. + * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * // use default index parameters + * ivf_sq::index_params index_params; + * // create and fill the index from a [N, D] dataset + * auto index = ivf_sq::build(handle, index_params, dataset); + * @endcode + * + * @param[in] handle + * @param[in] index_params configure the index building + * @param[in] dataset a device pointer to a row-major matrix [n_rows, dim] + * + * @return the constructed ivf-sq index + */ auto build(raft::resources const& handle, const cuvs::neighbors::ivf_sq::index_params& index_params, raft::device_matrix_view dataset) -> cuvs::neighbors::ivf_sq::index; +/** + * @brief Build the index from the dataset for efficient search. + * + * NB: Currently, the following distance metrics are supported: + * - L2Expanded + * - L2Unexpanded + * - InnerProduct + * - CosineExpanded + * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * // use default index parameters + * ivf_sq::index_params index_params; + * // create and fill the index from a [N, D] dataset + * ivf_sq::index index; + * ivf_sq::build(handle, index_params, dataset, index); + * @endcode + * + * @param[in] handle + * @param[in] index_params configure the index building + * @param[in] dataset raft::device_matrix_view to a row-major matrix [n_rows, dim] + * @param[out] idx reference to ivf_sq::index + * + */ void build(raft::resources const& handle, const cuvs::neighbors::ivf_sq::index_params& index_params, raft::device_matrix_view dataset, cuvs::neighbors::ivf_sq::index& idx); +/** + * @brief Build the index from the dataset for efficient search. + * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * // use default index parameters + * ivf_sq::index_params index_params; + * // create and fill the index from a [N, D] dataset + * auto index = ivf_sq::build(handle, index_params, dataset); + * @endcode + * + * @param[in] handle + * @param[in] index_params configure the index building + * @param[in] dataset a host pointer to a row-major matrix [n_rows, dim] + * + * @return the constructed ivf-sq index + */ auto build(raft::resources const& handle, const cuvs::neighbors::ivf_sq::index_params& index_params, raft::host_matrix_view dataset) -> cuvs::neighbors::ivf_sq::index; +/** + * @brief Build the index from the dataset for efficient search. + * + * NB: Currently, the following distance metrics are supported: + * - L2Expanded + * - L2Unexpanded + * - InnerProduct + * - CosineExpanded + * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * // use default index parameters + * ivf_sq::index_params index_params; + * // create and fill the index from a [N, D] dataset + * ivf_sq::index index; + * ivf_sq::build(handle, index_params, dataset, index); + * @endcode + * + * @param[in] handle + * @param[in] index_params configure the index building + * @param[in] dataset raft::host_matrix_view to a row-major matrix [n_rows, dim] + * @param[out] idx reference to ivf_sq::index + * + */ void build(raft::resources const& handle, const cuvs::neighbors::ivf_sq::index_params& index_params, raft::host_matrix_view dataset, cuvs::neighbors::ivf_sq::index& idx); +/** + * @brief Build the index from the dataset for efficient search. + * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * // use default index parameters + * ivf_sq::index_params index_params; + * // create and fill the index from a [N, D] dataset + * auto index = ivf_sq::build(handle, index_params, dataset); + * @endcode + * + * @param[in] handle + * @param[in] index_params configure the index building + * @param[in] dataset a host pointer to a row-major matrix [n_rows, dim] + * + * @return the constructed ivf-sq index + */ auto build(raft::resources const& handle, const cuvs::neighbors::ivf_sq::index_params& index_params, raft::host_matrix_view dataset) -> cuvs::neighbors::ivf_sq::index; +/** + * @brief Build the index from the dataset for efficient search. + * + * NB: Currently, the following distance metrics are supported: + * - L2Expanded + * - L2Unexpanded + * - InnerProduct + * - CosineExpanded + * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * // use default index parameters + * ivf_sq::index_params index_params; + * // create and fill the index from a [N, D] dataset + * ivf_sq::index index; + * ivf_sq::build(handle, index_params, dataset, index); + * @endcode + * + * @param[in] handle + * @param[in] index_params configure the index building + * @param[in] dataset raft::host_matrix_view to a row-major matrix [n_rows, dim] + * @param[out] idx reference to ivf_sq::index + * + */ void build(raft::resources const& handle, const cuvs::neighbors::ivf_sq::index_params& index_params, raft::host_matrix_view dataset, @@ -241,45 +474,237 @@ void build(raft::resources const& handle, * @{ */ +/** + * @brief Extend the index with the new data. + * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * ivf_sq::index_params index_params; + * index_params.add_data_on_build = false; // don't populate index on build + * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training + * // train the index from a [N, D] dataset + * auto index_empty = ivf_sq::build(handle, index_params, dataset); + * // fill the index with the data + * std::optional> no_op = std::nullopt; + * auto index = ivf_sq::extend(handle, new_vectors, no_op, index_empty); + * @endcode + * + * @param[in] handle + * @param[in] new_vectors a device matrix view to a row-major matrix [n_rows, idx.dim()] + * @param[in] new_indices a device vector view to a vector of indices [n_rows]. + * If the original index is empty (`idx.size() == 0`), you can pass `std::nullopt` + * here to imply a continuous range `[0...n_rows)`. + * @param[in] orig_index the original index + * + * @return the constructed extended ivf-sq index + */ auto extend(raft::resources const& handle, raft::device_matrix_view new_vectors, std::optional> new_indices, const cuvs::neighbors::ivf_sq::index& orig_index) -> cuvs::neighbors::ivf_sq::index; +/** + * @brief Extend the index with the new data. + * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * ivf_sq::index_params index_params; + * index_params.add_data_on_build = false; // don't populate index on build + * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training + * // train the index from a [N, D] dataset + * auto index_empty = ivf_sq::build(handle, index_params, dataset); + * // fill the index with the data + * std::optional> no_op = std::nullopt; + * ivf_sq::extend(handle, new_vectors, no_op, &index_empty); + * @endcode + * + * @param[in] handle + * @param[in] new_vectors a device matrix view to a row-major matrix [n_rows, idx.dim()] + * @param[in] new_indices a device vector view to a vector of indices [n_rows]. + * If the original index is empty (`idx.size() == 0`), you can pass `std::nullopt` + * here to imply a continuous range `[0...n_rows)`. + * @param[inout] idx pointer to ivf_sq::index + */ void extend(raft::resources const& handle, raft::device_matrix_view new_vectors, std::optional> new_indices, cuvs::neighbors::ivf_sq::index* idx); +/** + * @brief Extend the index with the new data. + * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * ivf_sq::index_params index_params; + * index_params.add_data_on_build = false; // don't populate index on build + * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training + * // train the index from a [N, D] dataset + * auto index_empty = ivf_sq::build(handle, index_params, dataset); + * // fill the index with the data + * std::optional> no_op = std::nullopt; + * auto index = ivf_sq::extend(handle, new_vectors, no_op, index_empty); + * @endcode + * + * @param[in] handle + * @param[in] new_vectors a device matrix view to a row-major matrix [n_rows, idx.dim()] + * @param[in] new_indices a device vector view to a vector of indices [n_rows]. + * If the original index is empty (`idx.size() == 0`), you can pass `std::nullopt` + * here to imply a continuous range `[0...n_rows)`. + * @param[in] orig_index the original index + * + * @return the constructed extended ivf-sq index + */ auto extend(raft::resources const& handle, raft::device_matrix_view new_vectors, std::optional> new_indices, const cuvs::neighbors::ivf_sq::index& orig_index) -> cuvs::neighbors::ivf_sq::index; +/** + * @brief Extend the index with the new data. + * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * ivf_sq::index_params index_params; + * index_params.add_data_on_build = false; // don't populate index on build + * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training + * // train the index from a [N, D] dataset + * auto index_empty = ivf_sq::build(handle, index_params, dataset); + * // fill the index with the data + * std::optional> no_op = std::nullopt; + * ivf_sq::extend(handle, new_vectors, no_op, &index_empty); + * @endcode + * + * @param[in] handle + * @param[in] new_vectors a device matrix view to a row-major matrix [n_rows, idx.dim()] + * @param[in] new_indices a device vector view to a vector of indices [n_rows]. + * If the original index is empty (`idx.size() == 0`), you can pass `std::nullopt` + * here to imply a continuous range `[0...n_rows)`. + * @param[inout] idx pointer to ivf_sq::index + */ void extend(raft::resources const& handle, raft::device_matrix_view new_vectors, std::optional> new_indices, cuvs::neighbors::ivf_sq::index* idx); +/** + * @brief Extend the index with the new data. + * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * ivf_sq::index_params index_params; + * index_params.add_data_on_build = false; // don't populate index on build + * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training + * // train the index from a [N, D] dataset + * auto index_empty = ivf_sq::build(handle, index_params, dataset); + * // fill the index with the data + * std::optional> no_op = std::nullopt; + * auto index = ivf_sq::extend(handle, new_vectors, no_op, index_empty); + * @endcode + * + * @param[in] handle + * @param[in] new_vectors a host matrix view to a row-major matrix [n_rows, idx.dim()] + * @param[in] new_indices a host vector view to a vector of indices [n_rows]. + * If the original index is empty (`idx.size() == 0`), you can pass `std::nullopt` + * here to imply a continuous range `[0...n_rows)`. + * @param[in] orig_index the original index + * + * @return the constructed extended ivf-sq index + */ auto extend(raft::resources const& handle, raft::host_matrix_view new_vectors, std::optional> new_indices, const cuvs::neighbors::ivf_sq::index& orig_index) -> cuvs::neighbors::ivf_sq::index; +/** + * @brief Extend the index with the new data. + * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * ivf_sq::index_params index_params; + * index_params.add_data_on_build = false; // don't populate index on build + * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training + * // train the index from a [N, D] dataset + * auto index_empty = ivf_sq::build(handle, index_params, dataset); + * // fill the index with the data + * std::optional> no_op = std::nullopt; + * ivf_sq::extend(handle, new_vectors, no_op, &index_empty); + * @endcode + * + * @param[in] handle + * @param[in] new_vectors a host matrix view to a row-major matrix [n_rows, idx.dim()] + * @param[in] new_indices a host vector view to a vector of indices [n_rows]. + * If the original index is empty (`idx.size() == 0`), you can pass `std::nullopt` + * here to imply a continuous range `[0...n_rows)`. + * @param[inout] idx pointer to ivf_sq::index + */ void extend(raft::resources const& handle, raft::host_matrix_view new_vectors, std::optional> new_indices, cuvs::neighbors::ivf_sq::index* idx); +/** + * @brief Extend the index with the new data. + * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * ivf_sq::index_params index_params; + * index_params.add_data_on_build = false; // don't populate index on build + * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training + * // train the index from a [N, D] dataset + * auto index_empty = ivf_sq::build(handle, index_params, dataset); + * // fill the index with the data + * std::optional> no_op = std::nullopt; + * auto index = ivf_sq::extend(handle, new_vectors, no_op, index_empty); + * @endcode + * + * @param[in] handle + * @param[in] new_vectors a host matrix view to a row-major matrix [n_rows, idx.dim()] + * @param[in] new_indices a host vector view to a vector of indices [n_rows]. + * If the original index is empty (`idx.size() == 0`), you can pass `std::nullopt` + * here to imply a continuous range `[0...n_rows)`. + * @param[in] orig_index the original index + * + * @return the constructed extended ivf-sq index + */ auto extend(raft::resources const& handle, raft::host_matrix_view new_vectors, std::optional> new_indices, const cuvs::neighbors::ivf_sq::index& orig_index) -> cuvs::neighbors::ivf_sq::index; +/** + * @brief Extend the index with the new data. + * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * ivf_sq::index_params index_params; + * index_params.add_data_on_build = false; // don't populate index on build + * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training + * // train the index from a [N, D] dataset + * auto index_empty = ivf_sq::build(handle, index_params, dataset); + * // fill the index with the data + * std::optional> no_op = std::nullopt; + * ivf_sq::extend(handle, new_vectors, no_op, &index_empty); + * @endcode + * + * @param[in] handle + * @param[in] new_vectors a host matrix view to a row-major matrix [n_rows, idx.dim()] + * @param[in] new_indices a host vector view to a vector of indices [n_rows]. + * If the original index is empty (`idx.size() == 0`), you can pass `std::nullopt` + * here to imply a continuous range `[0...n_rows)`. + * @param[inout] idx pointer to ivf_sq::index + */ void extend(raft::resources const& handle, raft::host_matrix_view new_vectors, std::optional> new_indices, @@ -294,6 +719,41 @@ void extend(raft::resources const& handle, * @{ */ +/** + * @brief Search ANN using the constructed index. + * + * See the [ivf_sq::build](#ivf_sq::build) documentation for a usage example. + * + * Note, this function requires a temporary buffer to store intermediate results between cuda kernel + * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can + * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or + * eliminate entirely allocations happening within `search`. + * The exact size of the temporary buffer depends on multiple factors and is an implementation + * detail. However, you can safely specify a small initial size for the memory pool, so that only a + * few allocations happen to grow it during the first invocations of the `search`. + * + * @code{.cpp} + * using namespace cuvs::neighbors; + * // use default search parameters + * ivf_sq::search_params search_params; + * // Use the same allocator across multiple searches to reduce the number of + * // cuda memory allocations + * ivf_sq::search(handle, search_params, index, queries1, out_inds1, out_dists1); + * ivf_sq::search(handle, search_params, index, queries2, out_inds2, out_dists2); + * ivf_sq::search(handle, search_params, index, queries3, out_inds3, out_dists3); + * @endcode + * + * @param[in] handle + * @param[in] params configure the search + * @param[in] index ivf-sq constructed index + * @param[in] queries raft::device_matrix_view to a row-major matrix [n_queries, index->dim()] + * @param[out] neighbors raft::device_matrix_view to the indices of the neighbors in the source + * dataset [n_queries, k] + * @param[out] distances raft::device_matrix_view to the distances to the selected neighbors + * [n_queries, k] + * @param[in] sample_filter an optional device filter function object that greenlights samples + * for a given query. (none_sample_filter for no filtering) + */ void search(raft::resources const& handle, const cuvs::neighbors::ivf_sq::search_params& params, const cuvs::neighbors::ivf_sq::index& index, @@ -303,6 +763,38 @@ void search(raft::resources const& handle, const cuvs::neighbors::filtering::base_filter& sample_filter = cuvs::neighbors::filtering::none_sample_filter{}); +/** + * @brief Search ANN using the constructed index with half-precision queries. + * + * See the [ivf_sq::build](#ivf_sq::build) documentation for a usage example. + * + * Note, this function requires a temporary buffer to store intermediate results between cuda kernel + * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can + * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or + * eliminate entirely allocations happening within `search`. + * + * @code{.cpp} + * using namespace cuvs::neighbors; + * // use default search parameters + * ivf_sq::search_params search_params; + * // Use the same allocator across multiple searches to reduce the number of + * // cuda memory allocations + * ivf_sq::search(handle, search_params, index, queries1, out_inds1, out_dists1); + * ivf_sq::search(handle, search_params, index, queries2, out_inds2, out_dists2); + * ivf_sq::search(handle, search_params, index, queries3, out_inds3, out_dists3); + * @endcode + * + * @param[in] handle + * @param[in] params configure the search + * @param[in] index ivf-sq constructed index + * @param[in] queries raft::device_matrix_view to a row-major matrix [n_queries, index->dim()] + * @param[out] neighbors raft::device_matrix_view to the indices of the neighbors in the source + * dataset [n_queries, k] + * @param[out] distances raft::device_matrix_view to the distances to the selected neighbors + * [n_queries, k] + * @param[in] sample_filter an optional device filter function object that greenlights samples + * for a given query. (none_sample_filter for no filtering) + */ void search(raft::resources const& handle, const cuvs::neighbors::ivf_sq::search_params& params, const cuvs::neighbors::ivf_sq::index& index, @@ -321,10 +813,54 @@ void search(raft::resources const& handle, * @{ */ +/** + * Save the index to file. + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * #include + * + * raft::resources handle; + * + * // create a string with a filepath + * std::string filename("/path/to/index"); + * // create an index with `auto index = ivf_sq::build(...);` + * cuvs::neighbors::ivf_sq::serialize(handle, filename, index); + * @endcode + * + * @param[in] handle the raft handle + * @param[in] filename the file name for saving the index + * @param[in] index IVF-SQ index + * + */ void serialize(raft::resources const& handle, const std::string& filename, const cuvs::neighbors::ivf_sq::index& index); +/** + * Load index from file. + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * #include + * + * raft::resources handle; + * + * // create a string with a filepath + * std::string filename("/path/to/index"); + * // create an empty index with `ivf_sq::index index(handle);` + * cuvs::neighbors::ivf_sq::deserialize(handle, filename, &index); + * @endcode + * + * @param[in] handle the raft handle + * @param[in] filename the name of the file that stores the index + * @param[out] index IVF-SQ index + * + */ void deserialize(raft::resources const& handle, const std::string& filename, cuvs::neighbors::ivf_sq::index* index); diff --git a/docs/source/cpp_api/neighbors.rst b/docs/source/cpp_api/neighbors.rst index 0c6e9cfd86..1c266bb902 100644 --- a/docs/source/cpp_api/neighbors.rst +++ b/docs/source/cpp_api/neighbors.rst @@ -18,6 +18,7 @@ Nearest Neighbors neighbors_hnsw.rst neighbors_ivf_flat.rst neighbors_ivf_pq.rst + neighbors_ivf_sq.rst neighbors_nn_descent.rst neighbors_refine.rst neighbors_mg.rst diff --git a/docs/source/cpp_api/neighbors_ivf_sq.rst b/docs/source/cpp_api/neighbors_ivf_sq.rst new file mode 100644 index 0000000000..d0554f926a --- /dev/null +++ b/docs/source/cpp_api/neighbors_ivf_sq.rst @@ -0,0 +1,68 @@ +IVF-SQ +====== + +The IVF-SQ method is an ANN algorithm. Like IVF-Flat, IVF-SQ splits the points into a number of clusters (also specified by a parameter called n_lists) and searches the closest clusters to compute the nearest neighbors (also specified by a parameter called n_probes), but it shrinks the sizes of the vectors using scalar quantization, independently mapping each dimension to a fixed-width integer code. + +.. role:: py(code) + :language: c++ + :class: highlight + +``#include `` + +namespace *cuvs::neighbors::ivf_sq* + +Index build parameters +---------------------- + +.. doxygengroup:: ivf_sq_cpp_index_params + :project: cuvs + :members: + :content-only: + +Index search parameters +----------------------- + +.. doxygengroup:: ivf_sq_cpp_search_params + :project: cuvs + :members: + :content-only: + +Index +----- + +.. doxygengroup:: ivf_sq_cpp_index + :project: cuvs + :members: + :content-only: + +Index build +----------- + +.. doxygengroup:: ivf_sq_cpp_index_build + :project: cuvs + :members: + :content-only: + +Index extend +------------ + +.. doxygengroup:: ivf_sq_cpp_index_extend + :project: cuvs + :members: + :content-only: + +Index search +------------ + +.. doxygengroup:: ivf_sq_cpp_index_search + :project: cuvs + :members: + :content-only: + +Index serialize +--------------- + +.. doxygengroup:: ivf_sq_cpp_index_serialize + :project: cuvs + :members: + :content-only: From b1246284aef228e2e5a255e862149523716e1238 Mon Sep 17 00:00:00 2001 From: vic Date: Fri, 13 Mar 2026 14:03:36 +0100 Subject: [PATCH 07/31] memset in index constructor --- cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh | 4 ---- cpp/src/neighbors/ivf_sq_index.cpp | 14 +++++++++++++- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh b/cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh index 6c46a20e65..13838556ae 100644 --- a/cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh +++ b/cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh @@ -456,10 +456,6 @@ inline auto build(raft::resources const& handle, "Cosine metric requires more than one dim"); index idx(handle, params, dim); - utils::memzero(idx.accum_sorted_sizes().data_handle(), idx.accum_sorted_sizes().size(), stream); - utils::memzero(idx.list_sizes().data_handle(), idx.list_sizes().size(), stream); - utils::memzero(idx.data_ptrs().data_handle(), idx.data_ptrs().size(), stream); - utils::memzero(idx.inds_ptrs().data_handle(), idx.inds_ptrs().size(), stream); // Train k-means centroids and SQ parameters on the same training subset. // This mirrors IVF-PQ, which also trains its codebook on a subset of the data. diff --git a/cpp/src/neighbors/ivf_sq_index.cpp b/cpp/src/neighbors/ivf_sq_index.cpp index d97ace7dcb..ffa54bd2e9 100644 --- a/cpp/src/neighbors/ivf_sq_index.cpp +++ b/cpp/src/neighbors/ivf_sq_index.cpp @@ -5,6 +5,11 @@ #include +#include +#include + +#include + namespace cuvs::neighbors::ivf_sq { template @@ -46,7 +51,14 @@ index::index(raft::resources const& res, accum_sorted_sizes_{raft::make_host_vector(n_lists + 1)} { check_consistency(); - accum_sorted_sizes_(n_lists) = 0; + auto stream = raft::resource::get_cuda_stream(res); + std::memset(accum_sorted_sizes_.data_handle(), 0, accum_sorted_sizes_.size() * sizeof(int64_t)); + RAFT_CUDA_TRY( + cudaMemsetAsync(list_sizes_.data_handle(), 0, list_sizes_.size() * sizeof(uint32_t), stream)); + RAFT_CUDA_TRY( + cudaMemsetAsync(data_ptrs_.data_handle(), 0, data_ptrs_.size() * sizeof(IdxT*), stream)); + RAFT_CUDA_TRY( + cudaMemsetAsync(inds_ptrs_.data_handle(), 0, inds_ptrs_.size() * sizeof(int64_t*), stream)); } template From 641c6ca34babd5d2914d6c1b6c7bbe47aec92980 Mon Sep 17 00:00:00 2001 From: vic Date: Fri, 13 Mar 2026 14:35:56 +0100 Subject: [PATCH 08/31] random sampling --- cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh | 26 +++++++++++------------ 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh b/cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh index 13838556ae..e7c0f734b4 100644 --- a/cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh +++ b/cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh @@ -26,6 +26,8 @@ #include #include #include +#include +#include #include #include @@ -460,22 +462,18 @@ inline auto build(raft::resources const& handle, // Train k-means centroids and SQ parameters on the same training subset. // This mirrors IVF-PQ, which also trains its codebook on a subset of the data. { + raft::random::RngState random_state{137}; auto trainset_ratio = std::max( 1, n_rows / std::max(params.kmeans_trainset_fraction * n_rows, idx.n_lists())); auto n_rows_train = n_rows / trainset_ratio; - rmm::device_uvector trainset( - n_rows_train * idx.dim(), stream, raft::resource::get_large_workspace_resource(handle)); - RAFT_CUDA_TRY(cudaMemcpy2DAsync(trainset.data(), - sizeof(T) * idx.dim(), - dataset, - sizeof(T) * idx.dim() * trainset_ratio, - sizeof(T) * idx.dim(), - n_rows_train, - cudaMemcpyDefault, - stream)); - auto trainset_const_view = - raft::make_device_matrix_view(trainset.data(), n_rows_train, idx.dim()); - auto centers_view = raft::make_device_matrix_view( + auto trainset = + raft::make_device_mdarray(handle, + raft::resource::get_large_workspace_resource(handle), + raft::make_extents(n_rows_train, idx.dim())); + auto dataset_view = raft::make_device_matrix_view(dataset, n_rows, idx.dim()); + raft::matrix::sample_rows(handle, random_state, dataset_view, trainset.view()); + auto trainset_const_view = raft::make_const_mdspan(trainset.view()); + auto centers_view = raft::make_device_matrix_view( idx.centers().data_handle(), idx.n_lists(), idx.dim()); cuvs::cluster::kmeans::balanced_params kmeans_params; kmeans_params.n_iters = params.kmeans_n_iters; @@ -502,7 +500,7 @@ inline auto build(raft::resources const& handle, dim3 threads(32, 8); dim3 blocks(raft::ceildiv(n_rows_train, threads.x), raft::ceildiv(dim, threads.y)); - compute_residuals_kernel<<>>(trainset.data(), + compute_residuals_kernel<<>>(trainset.data_handle(), idx.centers().data_handle(), train_labels.data_handle(), residuals.data(), From 70ca00a331cd96138b0eab3a57bb135858882341 Mon Sep 17 00:00:00 2001 From: vic Date: Fri, 13 Mar 2026 15:16:41 +0100 Subject: [PATCH 09/31] inplace residuals --- cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh | 53 ++++++++++++++--------- 1 file changed, 33 insertions(+), 20 deletions(-) diff --git a/cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh b/cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh index e7c0f734b4..e11f5dd6b5 100644 --- a/cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh +++ b/cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh @@ -64,8 +64,8 @@ struct ColMinMaxOp { * Row-loop is manually 4x-unrolled so the compiler can overlap four * independent __ldg requests in the memory pipeline. */ -template -__launch_bounds__(BlockSize) RAFT_KERNEL fused_column_minmax_kernel(const float* __restrict__ data, +template +__launch_bounds__(BlockSize) RAFT_KERNEL fused_column_minmax_kernel(const T* __restrict__ data, float* __restrict__ col_min, float* __restrict__ col_max, int64_t n_rows, @@ -83,15 +83,15 @@ __launch_bounds__(BlockSize) RAFT_KERNEL fused_column_minmax_kernel(const float* int64_t row = static_cast(threadIdx.x); for (; row + 3 * stride < n_rows; row += 4 * stride) { - float v0 = __ldg(&data[row * dim + col]); - float v1 = __ldg(&data[(row + stride) * dim + col]); - float v2 = __ldg(&data[(row + 2 * stride) * dim + col]); - float v3 = __ldg(&data[(row + 3 * stride) * dim + col]); + float v0 = float(data[row * dim + col]); + float v1 = float(data[(row + stride) * dim + col]); + float v2 = float(data[(row + 2 * stride) * dim + col]); + float v3 = float(data[(row + 3 * stride) * dim + col]); agg.min_val = fminf(agg.min_val, fminf(fminf(v0, v1), fminf(v2, v3))); agg.max_val = fmaxf(agg.max_val, fmaxf(fmaxf(v0, v1), fmaxf(v2, v3))); } for (; row < n_rows; row += stride) { - float val = __ldg(&data[row * dim + col]); + float val = float(data[row * dim + col]); agg.min_val = fminf(agg.min_val, val); agg.max_val = fmaxf(agg.max_val, val); } @@ -229,6 +229,20 @@ RAFT_KERNEL compute_residuals_kernel(const T* dataset, residuals[i * dim + j] = val - centers[c * dim + j]; } +/** In-place variant: dataset[i] = cast(cast(dataset[i]) - centers[labels[i]]) */ +template +RAFT_KERNEL compute_residuals_inplace_kernel( + T* dataset, const float* centers, const uint32_t* labels, int64_t n_rows, uint32_t dim) +{ + int64_t i = int64_t(blockIdx.x) * blockDim.x + threadIdx.x; + uint32_t j = blockIdx.y * blockDim.y + threadIdx.y; + if (i >= n_rows || j >= dim) return; + + float val = utils::mapping{}(dataset[i * dim + j]); + uint32_t c = labels[i]; + dataset[i * dim + j] = utils::mapping{}(val - centers[c * dim + j]); +} + template void extend(raft::resources const& handle, index* index, @@ -481,10 +495,10 @@ inline auto build(raft::resources const& handle, cuvs::cluster::kmeans::fit(handle, kmeans_params, trainset_const_view, centers_view); raft::resource::sync_stream(handle); - // Train SQ: predict labels for the training subset, compute its residuals, + // Train SQ: predict labels for the training subset, compute residuals in-place, // and derive per-dimension vmin/delta from them. - auto train_labels = raft::make_device_vector(handle, n_rows_train); { + auto train_labels = raft::make_device_vector(handle, n_rows_train); cuvs::cluster::kmeans::balanced_params pred_params; pred_params.metric = idx.metric(); auto centers_const_view = raft::make_device_matrix_view( @@ -492,23 +506,22 @@ inline auto build(raft::resources const& handle, cuvs::cluster::kmeans::predict( handle, pred_params, trainset_const_view, centers_const_view, train_labels.view()); raft::resource::sync_stream(handle); - } - rmm::device_uvector residuals( - n_rows_train * dim, stream, raft::resource::get_large_workspace_resource(handle)); - { dim3 threads(32, 8); dim3 blocks(raft::ceildiv(n_rows_train, threads.x), raft::ceildiv(dim, threads.y)); - compute_residuals_kernel<<>>(trainset.data_handle(), - idx.centers().data_handle(), - train_labels.data_handle(), - residuals.data(), - n_rows_train, - dim); + compute_residuals_inplace_kernel + <<>>(trainset.data_handle(), + idx.centers().data_handle(), + train_labels.data_handle(), + n_rows_train, + dim); RAFT_CUDA_TRY(cudaPeekAtLastError()); } + // After the in-place kernel, trainset now contains residuals. + auto& residuals = trainset; + { auto vmax_buf = raft::make_device_vector(handle, dim); auto* vmin_ptr = idx.sq_vmin().data_handle(); @@ -516,7 +529,7 @@ inline auto build(raft::resources const& handle, constexpr int kMinMaxBlockSize = 256; fused_column_minmax_kernel<<>>( - residuals.data(), vmin_ptr, vmax_ptr, n_rows_train, dim); + residuals.data_handle(), vmin_ptr, vmax_ptr, n_rows_train, dim); RAFT_CUDA_TRY(cudaPeekAtLastError()); // Expand the observed range by a small margin to reduce clipping on unseen data, From e7d660cddcc9e197c12dacae0db54a978d75fb1e Mon Sep 17 00:00:00 2001 From: vic Date: Fri, 13 Mar 2026 15:27:29 +0100 Subject: [PATCH 10/31] improved kernel layout for residuals computation --- cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh | 28 +++++++++++------------ 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh b/cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh index e11f5dd6b5..617d01179a 100644 --- a/cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh +++ b/cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh @@ -234,13 +234,13 @@ template RAFT_KERNEL compute_residuals_inplace_kernel( T* dataset, const float* centers, const uint32_t* labels, int64_t n_rows, uint32_t dim) { - int64_t i = int64_t(blockIdx.x) * blockDim.x + threadIdx.x; - uint32_t j = blockIdx.y * blockDim.y + threadIdx.y; - if (i >= n_rows || j >= dim) return; - - float val = utils::mapping{}(dataset[i * dim + j]); - uint32_t c = labels[i]; - dataset[i * dim + j] = utils::mapping{}(val - centers[c * dim + j]); + int64_t i = blockIdx.x; + if (i >= n_rows) return; + uint32_t c = labels[i]; + for (uint32_t j = threadIdx.x; j < dim; j += blockDim.x) { + float val = utils::mapping{}(dataset[i * dim + j]); + dataset[i * dim + j] = utils::mapping{}(val - centers[c * dim + j]); + } } template @@ -507,15 +507,13 @@ inline auto build(raft::resources const& handle, handle, pred_params, trainset_const_view, centers_const_view, train_labels.view()); raft::resource::sync_stream(handle); - dim3 threads(32, 8); - dim3 blocks(raft::ceildiv(n_rows_train, threads.x), - raft::ceildiv(dim, threads.y)); + constexpr int kResidualBlockSize = 256; compute_residuals_inplace_kernel - <<>>(trainset.data_handle(), - idx.centers().data_handle(), - train_labels.data_handle(), - n_rows_train, - dim); + <<>>(trainset.data_handle(), + idx.centers().data_handle(), + train_labels.data_handle(), + n_rows_train, + dim); RAFT_CUDA_TRY(cudaPeekAtLastError()); } From 96b28db42cbb0aae730fefffc13ba0da2c763f59 Mon Sep 17 00:00:00 2001 From: vic Date: Fri, 13 Mar 2026 15:40:27 +0100 Subject: [PATCH 11/31] raft::device_vector --- cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh b/cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh index 617d01179a..3691b2306d 100644 --- a/cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh +++ b/cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh @@ -307,8 +307,7 @@ void extend(raft::resources const& handle, } auto* list_sizes_ptr = index->list_sizes().data_handle(); - auto old_list_sizes_dev = raft::make_device_mdarray( - handle, raft::resource::get_workspace_resource(handle), raft::make_extents(n_lists)); + auto old_list_sizes_dev = raft::make_device_vector(handle, n_lists); raft::copy(old_list_sizes_dev.data_handle(), list_sizes_ptr, n_lists, stream); if (index->adaptive_centers()) { From 206cb2e612538afadbb3017d90c668f742b5f7e1 Mon Sep 17 00:00:00 2001 From: vic Date: Fri, 13 Mar 2026 17:32:06 +0100 Subject: [PATCH 12/31] drop adaptative_centers feature --- cpp/include/cuvs/neighbors/ivf_sq.hpp | 16 -- cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh | 50 +--- cpp/src/neighbors/ivf_sq/ivf_sq_serialize.cuh | 16 +- cpp/src/neighbors/ivf_sq_index.cpp | 17 +- cpp/tests/neighbors/ann_ivf_sq.cuh | 243 ++++++++---------- 5 files changed, 132 insertions(+), 210 deletions(-) diff --git a/cpp/include/cuvs/neighbors/ivf_sq.hpp b/cpp/include/cuvs/neighbors/ivf_sq.hpp index 042ecda10a..10d9a4c856 100644 --- a/cpp/include/cuvs/neighbors/ivf_sq.hpp +++ b/cpp/include/cuvs/neighbors/ivf_sq.hpp @@ -30,19 +30,6 @@ struct index_params : cuvs::neighbors::index_params { uint32_t kmeans_n_iters = 20; /** The fraction of data to use during iterative kmeans building. */ double kmeans_trainset_fraction = 0.5; - /** - * By default (adaptive_centers = false), the cluster centers are trained in `ivf_sq::build`, - * and never modified in `ivf_sq::extend`. As a result, you may need to retrain the index - * from scratch after invoking (`ivf_sq::extend`) a few times with new data, the distribution of - * which is no longer representative of the original training set. - * - * The alternative behavior (adaptive_centers = true) is to update the cluster centers for new - * data when it is added. In this case, `index.centers()` are always exactly the centroids of the - * data in the corresponding clusters. The drawback of this behavior is that the centroids depend - * on the order of adding new data (through the classification of the added data); that is, - * `index.centers()` "drift" together with the changing distribution of the newly added data. - */ - bool adaptive_centers = false; /** * By default, the algorithm allocates more space than necessary for individual clusters * (`list_data`). This allows to amortize the cost of memory allocation and reduce the number of @@ -181,11 +168,9 @@ struct index : cuvs::neighbors::index { cuvs::distance::DistanceType metric, uint32_t n_lists, uint32_t dim, - bool adaptive_centers, bool conservative_memory_allocation); cuvs::distance::DistanceType metric() const noexcept; - bool adaptive_centers() const noexcept; int64_t size() const noexcept; uint32_t dim() const noexcept; uint32_t n_lists() const noexcept; @@ -223,7 +208,6 @@ struct index : cuvs::neighbors::index { private: cuvs::distance::DistanceType metric_; - bool adaptive_centers_; bool conservative_memory_allocation_; std::vector>> lists_; diff --git a/cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh b/cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh index 3691b2306d..e4a373a80b 100644 --- a/cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh +++ b/cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh @@ -13,7 +13,6 @@ #include #include -#include "../../cluster/kmeans_balanced.cuh" #include "../detail/ann_utils.cuh" #include #include @@ -109,12 +108,8 @@ auto clone(const raft::resources& res, const index& source) -> index { auto stream = raft::resource::get_cuda_stream(res); - index target(res, - source.metric(), - source.n_lists(), - source.dim(), - source.adaptive_centers(), - source.conservative_memory_allocation()); + index target( + res, source.metric(), source.n_lists(), source.dim(), source.conservative_memory_allocation()); raft::copy(target.list_sizes().data_handle(), source.list_sizes().data_handle(), @@ -310,36 +305,15 @@ void extend(raft::resources const& handle, auto old_list_sizes_dev = raft::make_device_vector(handle, n_lists); raft::copy(old_list_sizes_dev.data_handle(), list_sizes_ptr, n_lists, stream); - if (index->adaptive_centers()) { - auto centroids_view = raft::make_device_matrix_view( - index->centers().data_handle(), index->centers().extent(0), index->centers().extent(1)); - auto list_sizes_view = - raft::make_device_vector_view, int64_t>( - list_sizes_ptr, n_lists); - for (const auto& batch : vec_batches) { - auto batch_data_view = - raft::make_device_matrix_view(batch.data(), batch.size(), index->dim()); - auto batch_labels_view = raft::make_device_vector_view( - new_labels.data_handle() + batch.offset(), batch.size()); - cuvs::cluster::kmeans_balanced::helpers::calc_centers_and_sizes(handle, - batch_data_view, - batch_labels_view, - centroids_view, - list_sizes_view, - false, - utils::mapping{}); - } - } else { - raft::stats::histogram(raft::stats::HistTypeAuto, - reinterpret_cast(list_sizes_ptr), - int64_t(n_lists), - new_labels.data_handle(), - n_rows, - 1, - stream); - raft::linalg::add( - list_sizes_ptr, list_sizes_ptr, old_list_sizes_dev.data_handle(), n_lists, stream); - } + raft::stats::histogram(raft::stats::HistTypeAuto, + reinterpret_cast(list_sizes_ptr), + int64_t(n_lists), + new_labels.data_handle(), + n_rows, + 1, + stream); + raft::linalg::add( + list_sizes_ptr, list_sizes_ptr, old_list_sizes_dev.data_handle(), n_lists, stream); std::vector new_list_sizes(n_lists); std::vector old_list_sizes(n_lists); @@ -437,8 +411,6 @@ void extend(raft::resources const& handle, if (!index->center_norms().has_value()) { index->allocate_center_norms(handle); if (index->center_norms().has_value()) { compute_center_norms(); } - } else if (index->adaptive_centers()) { - compute_center_norms(); } } diff --git a/cpp/src/neighbors/ivf_sq/ivf_sq_serialize.cuh b/cpp/src/neighbors/ivf_sq/ivf_sq_serialize.cuh index b95e63ee33..8aa1f12e04 100644 --- a/cpp/src/neighbors/ivf_sq/ivf_sq_serialize.cuh +++ b/cpp/src/neighbors/ivf_sq/ivf_sq_serialize.cuh @@ -37,7 +37,6 @@ void serialize(raft::resources const& handle, std::ostream& os, const index index if (ver != serialization_version) { RAFT_FAIL("serialization version mismatch, expected %d, got %d ", serialization_version, ver); } - auto n_rows = raft::deserialize_scalar(handle, is); - auto dim = raft::deserialize_scalar(handle, is); - auto n_lists = raft::deserialize_scalar(handle, is); - auto metric = raft::deserialize_scalar(handle, is); - bool adaptive_centers = raft::deserialize_scalar(handle, is); - bool cma = raft::deserialize_scalar(handle, is); - - index index_ = index(handle, metric, n_lists, dim, adaptive_centers, cma); + auto n_rows = raft::deserialize_scalar(handle, is); + auto dim = raft::deserialize_scalar(handle, is); + auto n_lists = raft::deserialize_scalar(handle, is); + auto metric = raft::deserialize_scalar(handle, is); + bool cma = raft::deserialize_scalar(handle, is); + + index index_ = index(handle, metric, n_lists, dim, cma); deserialize_mdspan(handle, is, index_.centers()); diff --git a/cpp/src/neighbors/ivf_sq_index.cpp b/cpp/src/neighbors/ivf_sq_index.cpp index ffa54bd2e9..8b4de55f54 100644 --- a/cpp/src/neighbors/ivf_sq_index.cpp +++ b/cpp/src/neighbors/ivf_sq_index.cpp @@ -14,18 +14,13 @@ namespace cuvs::neighbors::ivf_sq { template index::index(raft::resources const& res) - : index(res, cuvs::distance::DistanceType::L2Expanded, 0, 0, false, false) + : index(res, cuvs::distance::DistanceType::L2Expanded, 0, 0, false) { } template index::index(raft::resources const& res, const index_params& params, uint32_t dim) - : index(res, - params.metric, - params.n_lists, - dim, - params.adaptive_centers, - params.conservative_memory_allocation) + : index(res, params.metric, params.n_lists, dim, params.conservative_memory_allocation) { } @@ -34,11 +29,9 @@ index::index(raft::resources const& res, cuvs::distance::DistanceType metric, uint32_t n_lists, uint32_t dim, - bool adaptive_centers, bool conservative_memory_allocation) : cuvs::neighbors::index(), metric_(metric), - adaptive_centers_(adaptive_centers), conservative_memory_allocation_(conservative_memory_allocation), lists_{n_lists}, list_sizes_{raft::make_device_vector(res, n_lists)}, @@ -67,12 +60,6 @@ cuvs::distance::DistanceType index::metric() const noexcept return metric_; } -template -bool index::adaptive_centers() const noexcept -{ - return adaptive_centers_; -} - template int64_t index::size() const noexcept { diff --git a/cpp/tests/neighbors/ann_ivf_sq.cuh b/cpp/tests/neighbors/ann_ivf_sq.cuh index 25abc82740..d90ec66959 100644 --- a/cpp/tests/neighbors/ann_ivf_sq.cuh +++ b/cpp/tests/neighbors/ann_ivf_sq.cuh @@ -31,7 +31,6 @@ struct AnnIvfSqInputs { IdxT nprobe; IdxT nlist; cuvs::distance::DistanceType metric; - bool adaptive_centers; }; template @@ -40,7 +39,7 @@ template os << "{ " << p.num_queries << ", " << p.num_db_vecs << ", " << p.dim << ", " << p.k << ", " << p.nprobe << ", " << p.nlist << ", " << cuvs::neighbors::print_metric{static_cast((int)p.metric)} - << ", " << p.adaptive_centers << '}' << std::endl; + << '}' << std::endl; return os; } @@ -91,10 +90,9 @@ class AnnIVFSQTest : public ::testing::TestWithParam> { { cuvs::neighbors::ivf_sq::index_params index_params; cuvs::neighbors::ivf_sq::search_params search_params; - index_params.n_lists = ps.nlist; - index_params.metric = ps.metric; - index_params.adaptive_centers = ps.adaptive_centers; - search_params.n_probes = ps.nprobe; + index_params.n_lists = ps.nlist; + index_params.metric = ps.metric; + search_params.n_probes = ps.nprobe; index_params.add_data_on_build = true; index_params.kmeans_trainset_fraction = 0.5; @@ -108,7 +106,6 @@ class AnnIVFSQTest : public ::testing::TestWithParam> { cuvs::neighbors::ivf_sq::index_params index_params_no_add; index_params_no_add.n_lists = ps.nlist; index_params_no_add.metric = ps.metric; - index_params_no_add.adaptive_centers = ps.adaptive_centers; index_params_no_add.add_data_on_build = false; index_params_no_add.kmeans_trainset_fraction = 0.5; @@ -214,10 +211,9 @@ class AnnIVFSQTest : public ::testing::TestWithParam> { { cuvs::neighbors::ivf_sq::index_params index_params; cuvs::neighbors::ivf_sq::search_params search_params; - index_params.n_lists = ps.nlist; - index_params.metric = ps.metric; - index_params.adaptive_centers = ps.adaptive_centers; - search_params.n_probes = ps.nprobe; + index_params.n_lists = ps.nlist; + index_params.metric = ps.metric; + search_params.n_probes = ps.nprobe; index_params.add_data_on_build = true; index_params.kmeans_trainset_fraction = 0.5; @@ -299,156 +295,141 @@ class AnnIVFSQTest : public ::testing::TestWithParam> { }; const std::vector> inputs = { - // num_queries, num_db_vecs, dim, k, nprobe, nlist, metric, adaptive_centers + // num_queries, num_db_vecs, dim, k, nprobe, nlist, metric // ===== Dimension edge cases (all four metrics) ===== // dim=1 (CosineExpanded excluded: requires dim > 1) - {1000, 10000, 1, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, - {1000, 10000, 1, 10, 40, 1024, cuvs::distance::DistanceType::InnerProduct, false}, - {1000, 10000, 1, 10, 40, 1024, cuvs::distance::DistanceType::L2SqrtExpanded, false}, + {1000, 10000, 1, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded}, + {1000, 10000, 1, 10, 40, 1024, cuvs::distance::DistanceType::InnerProduct}, + {1000, 10000, 1, 10, 40, 1024, cuvs::distance::DistanceType::L2SqrtExpanded}, // dim=2,3,4,5 (unaligned) - {1000, 10000, 2, 16, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, - {1000, 10000, 2, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, - {1000, 10000, 3, 16, 40, 1024, cuvs::distance::DistanceType::L2Expanded, true}, - {1000, 10000, 3, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, true}, - {1000, 10000, 4, 16, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, - {1000, 10000, 4, 16, 40, 1024, cuvs::distance::DistanceType::InnerProduct, false}, - {1000, 10000, 5, 16, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, - {1000, 10000, 5, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, + {1000, 10000, 2, 16, 40, 1024, cuvs::distance::DistanceType::L2Expanded}, + {1000, 10000, 2, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded}, + {1000, 10000, 3, 16, 40, 1024, cuvs::distance::DistanceType::L2Expanded}, + {1000, 10000, 3, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded}, + {1000, 10000, 4, 16, 40, 1024, cuvs::distance::DistanceType::L2Expanded}, + {1000, 10000, 4, 16, 40, 1024, cuvs::distance::DistanceType::InnerProduct}, + {1000, 10000, 5, 16, 40, 1024, cuvs::distance::DistanceType::L2Expanded}, + {1000, 10000, 5, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded}, // dim=7,8 (around veclen=16 boundary, not a multiple of veclen) - {1000, 10000, 7, 16, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, - {1000, 10000, 7, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, - {1000, 10000, 8, 16, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, - {1000, 10000, 8, 16, 40, 1024, cuvs::distance::DistanceType::InnerProduct, true}, - {1000, 10000, 8, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, true}, + {1000, 10000, 7, 16, 40, 1024, cuvs::distance::DistanceType::L2Expanded}, + {1000, 10000, 7, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded}, + {1000, 10000, 8, 16, 40, 1024, cuvs::distance::DistanceType::L2Expanded}, + {1000, 10000, 8, 16, 40, 1024, cuvs::distance::DistanceType::InnerProduct}, + {1000, 10000, 8, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded}, // dim=15,16,17 (around veclen=16 boundary) - {1000, 10000, 15, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, - {1000, 10000, 15, 10, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, - {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, - {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::InnerProduct, false}, - {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, - {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::L2SqrtExpanded, false}, - {1000, 10000, 17, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, - {1000, 10000, 17, 10, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, + {1000, 10000, 15, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded}, + {1000, 10000, 15, 10, 40, 1024, cuvs::distance::DistanceType::CosineExpanded}, + {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded}, + {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::InnerProduct}, + {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::CosineExpanded}, + {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::L2SqrtExpanded}, + {1000, 10000, 17, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded}, + {1000, 10000, 17, 10, 40, 1024, cuvs::distance::DistanceType::CosineExpanded}, // dim=31,32,33 (around 2*veclen boundary) - {1000, 10000, 31, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, - {1000, 10000, 31, 10, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, - {1000, 10000, 32, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, - {1000, 10000, 32, 10, 40, 1024, cuvs::distance::DistanceType::InnerProduct, false}, - {1000, 10000, 32, 10, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, - {1000, 10000, 33, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, - {1000, 10000, 33, 10, 40, 1024, cuvs::distance::DistanceType::InnerProduct, false}, + {1000, 10000, 31, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded}, + {1000, 10000, 31, 10, 40, 1024, cuvs::distance::DistanceType::CosineExpanded}, + {1000, 10000, 32, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded}, + {1000, 10000, 32, 10, 40, 1024, cuvs::distance::DistanceType::InnerProduct}, + {1000, 10000, 32, 10, 40, 1024, cuvs::distance::DistanceType::CosineExpanded}, + {1000, 10000, 33, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded}, + {1000, 10000, 33, 10, 40, 1024, cuvs::distance::DistanceType::InnerProduct}, // medium dims - {1000, 10000, 64, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, - {1000, 10000, 64, 10, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, - {1000, 10000, 128, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, - {1000, 10000, 128, 10, 40, 1024, cuvs::distance::DistanceType::InnerProduct, false}, - {1000, 10000, 128, 10, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, - {1000, 10000, 128, 10, 40, 1024, cuvs::distance::DistanceType::L2SqrtExpanded, false}, - {1000, 10000, 256, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, - {1000, 10000, 256, 10, 40, 1024, cuvs::distance::DistanceType::InnerProduct, false}, + {1000, 10000, 64, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded}, + {1000, 10000, 64, 10, 40, 1024, cuvs::distance::DistanceType::CosineExpanded}, + {1000, 10000, 128, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded}, + {1000, 10000, 128, 10, 40, 1024, cuvs::distance::DistanceType::InnerProduct}, + {1000, 10000, 128, 10, 40, 1024, cuvs::distance::DistanceType::CosineExpanded}, + {1000, 10000, 128, 10, 40, 1024, cuvs::distance::DistanceType::L2SqrtExpanded}, + {1000, 10000, 256, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded}, + {1000, 10000, 256, 10, 40, 1024, cuvs::distance::DistanceType::InnerProduct}, // large dims (may exceed shared memory limits) - {1000, 10000, 2048, 16, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, - {1000, 10000, 2048, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, - {1000, 10000, 2049, 16, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, - {1000, 10000, 2049, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, - {1000, 10000, 2050, 16, 40, 1024, cuvs::distance::DistanceType::InnerProduct, false}, - {1000, 10000, 2050, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, - {1000, 10000, 4096, 20, 50, 1024, cuvs::distance::DistanceType::L2Expanded, false}, - {1000, 10000, 4096, 20, 50, 1024, cuvs::distance::DistanceType::InnerProduct, false}, - {1000, 10000, 4096, 20, 50, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, + {1000, 10000, 2048, 16, 40, 1024, cuvs::distance::DistanceType::L2Expanded}, + {1000, 10000, 2048, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded}, + {1000, 10000, 2049, 16, 40, 1024, cuvs::distance::DistanceType::L2Expanded}, + {1000, 10000, 2049, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded}, + {1000, 10000, 2050, 16, 40, 1024, cuvs::distance::DistanceType::InnerProduct}, + {1000, 10000, 2050, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded}, + {1000, 10000, 4096, 20, 50, 1024, cuvs::distance::DistanceType::L2Expanded}, + {1000, 10000, 4096, 20, 50, 1024, cuvs::distance::DistanceType::InnerProduct}, + {1000, 10000, 4096, 20, 50, 1024, cuvs::distance::DistanceType::CosineExpanded}, // ===== k edge cases ===== - {1000, 10000, 16, 1, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, - {1000, 10000, 16, 1, 40, 1024, cuvs::distance::DistanceType::InnerProduct, false}, - {1000, 10000, 16, 1, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, - {1000, 10000, 16, 2, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, - {1000, 10000, 16, 5, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, - {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, - {1000, 10000, 16, 20, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, - {1000, 10000, 16, 20, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, - {1000, 10000, 16, 50, 100, 1024, cuvs::distance::DistanceType::L2Expanded, false}, - {1000, 10000, 16, 100, 200, 1024, cuvs::distance::DistanceType::L2Expanded, false}, - {1000, 10000, 16, 100, 200, 1024, cuvs::distance::DistanceType::InnerProduct, false}, + {1000, 10000, 16, 1, 40, 1024, cuvs::distance::DistanceType::L2Expanded}, + {1000, 10000, 16, 1, 40, 1024, cuvs::distance::DistanceType::InnerProduct}, + {1000, 10000, 16, 1, 40, 1024, cuvs::distance::DistanceType::CosineExpanded}, + {1000, 10000, 16, 2, 40, 1024, cuvs::distance::DistanceType::L2Expanded}, + {1000, 10000, 16, 5, 40, 1024, cuvs::distance::DistanceType::L2Expanded}, + {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded}, + {1000, 10000, 16, 20, 40, 1024, cuvs::distance::DistanceType::L2Expanded}, + {1000, 10000, 16, 20, 40, 1024, cuvs::distance::DistanceType::CosineExpanded}, + {1000, 10000, 16, 50, 100, 1024, cuvs::distance::DistanceType::L2Expanded}, + {1000, 10000, 16, 100, 200, 1024, cuvs::distance::DistanceType::L2Expanded}, + {1000, 10000, 16, 100, 200, 1024, cuvs::distance::DistanceType::InnerProduct}, // ===== nprobe / nlist edge cases ===== // nprobe == nlist (exhaustive probe) - {1000, 10000, 16, 10, 64, 64, cuvs::distance::DistanceType::L2Expanded, false}, - {1000, 10000, 16, 10, 64, 64, cuvs::distance::DistanceType::InnerProduct, false}, - {1000, 10000, 16, 10, 64, 64, cuvs::distance::DistanceType::CosineExpanded, false}, + {1000, 10000, 16, 10, 64, 64, cuvs::distance::DistanceType::L2Expanded}, + {1000, 10000, 16, 10, 64, 64, cuvs::distance::DistanceType::InnerProduct}, + {1000, 10000, 16, 10, 64, 64, cuvs::distance::DistanceType::CosineExpanded}, // nprobe == 1 (minimal probe) - {1000, 10000, 16, 10, 1, 1024, cuvs::distance::DistanceType::L2Expanded, false}, - {1000, 10000, 16, 10, 1, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, + {1000, 10000, 16, 10, 1, 1024, cuvs::distance::DistanceType::L2Expanded}, + {1000, 10000, 16, 10, 1, 1024, cuvs::distance::DistanceType::CosineExpanded}, // nprobe > nlist (clamped to nlist) - {1000, 10000, 16, 10, 2048, 1024, cuvs::distance::DistanceType::L2Expanded, false}, - {1000, 10000, 16, 10, 2048, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, + {1000, 10000, 16, 10, 2048, 1024, cuvs::distance::DistanceType::L2Expanded}, + {1000, 10000, 16, 10, 2048, 1024, cuvs::distance::DistanceType::CosineExpanded}, // various nprobe - {1000, 10000, 16, 10, 50, 1024, cuvs::distance::DistanceType::L2Expanded, false}, - {1000, 10000, 16, 10, 70, 1024, cuvs::distance::DistanceType::L2Expanded, false}, - {1000, 10000, 16, 10, 50, 1024, cuvs::distance::DistanceType::InnerProduct, false}, - {1000, 10000, 16, 10, 70, 1024, cuvs::distance::DistanceType::InnerProduct, false}, - {1000, 10000, 16, 10, 50, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, - {1000, 10000, 16, 10, 70, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, - {1000, 10000, 16, 10, 50, 1024, cuvs::distance::DistanceType::L2SqrtExpanded, false}, - {1000, 10000, 16, 10, 70, 1024, cuvs::distance::DistanceType::L2SqrtExpanded, false}, + {1000, 10000, 16, 10, 50, 1024, cuvs::distance::DistanceType::L2Expanded}, + {1000, 10000, 16, 10, 70, 1024, cuvs::distance::DistanceType::L2Expanded}, + {1000, 10000, 16, 10, 50, 1024, cuvs::distance::DistanceType::InnerProduct}, + {1000, 10000, 16, 10, 70, 1024, cuvs::distance::DistanceType::InnerProduct}, + {1000, 10000, 16, 10, 50, 1024, cuvs::distance::DistanceType::CosineExpanded}, + {1000, 10000, 16, 10, 70, 1024, cuvs::distance::DistanceType::CosineExpanded}, + {1000, 10000, 16, 10, 50, 1024, cuvs::distance::DistanceType::L2SqrtExpanded}, + {1000, 10000, 16, 10, 70, 1024, cuvs::distance::DistanceType::L2SqrtExpanded}, // very small nlist - {100, 10000, 16, 10, 8, 8, cuvs::distance::DistanceType::L2Expanded, false}, - {100, 10000, 16, 10, 8, 8, cuvs::distance::DistanceType::CosineExpanded, false}, + {100, 10000, 16, 10, 8, 8, cuvs::distance::DistanceType::L2Expanded}, + {100, 10000, 16, 10, 8, 8, cuvs::distance::DistanceType::CosineExpanded}, // smaller nlist - {100, 10000, 16, 10, 20, 512, cuvs::distance::DistanceType::L2Expanded, false}, - {100, 10000, 16, 10, 20, 512, cuvs::distance::DistanceType::InnerProduct, false}, - {100, 10000, 16, 10, 20, 512, cuvs::distance::DistanceType::CosineExpanded, false}, - {100, 10000, 16, 10, 20, 512, cuvs::distance::DistanceType::L2SqrtExpanded, false}, + {100, 10000, 16, 10, 20, 512, cuvs::distance::DistanceType::L2Expanded}, + {100, 10000, 16, 10, 20, 512, cuvs::distance::DistanceType::InnerProduct}, + {100, 10000, 16, 10, 20, 512, cuvs::distance::DistanceType::CosineExpanded}, + {100, 10000, 16, 10, 20, 512, cuvs::distance::DistanceType::L2SqrtExpanded}, // ===== Dataset size edge cases ===== // single query - {1, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, - {1, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, + {1, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded}, + {1, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::CosineExpanded}, // very few queries - {2, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, - {5, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, + {2, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded}, + {5, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::CosineExpanded}, // very few db vectors (nlist reduced to fit) - {100, 500, 16, 10, 40, 256, cuvs::distance::DistanceType::L2Expanded, false}, - {100, 500, 16, 10, 40, 256, cuvs::distance::DistanceType::CosineExpanded, false}, + {100, 500, 16, 10, 40, 256, cuvs::distance::DistanceType::L2Expanded}, + {100, 500, 16, 10, 40, 256, cuvs::distance::DistanceType::CosineExpanded}, // larger datasets - {20, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::L2Expanded, true}, - {20, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::CosineExpanded, true}, - {1000, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::L2Expanded, true}, - {1000, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::CosineExpanded, true}, - {10000, 131072, 8, 10, 20, 1024, cuvs::distance::DistanceType::L2Expanded, false}, - {10000, 131072, 8, 10, 20, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, - {10000, 131072, 8, 10, 50, 1024, cuvs::distance::DistanceType::InnerProduct, true}, - {10000, 131072, 8, 10, 50, 1024, cuvs::distance::DistanceType::L2SqrtExpanded, false}, + {20, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::L2Expanded}, + {20, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::CosineExpanded}, + {1000, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::L2Expanded}, + {1000, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::CosineExpanded}, + {10000, 131072, 8, 10, 20, 1024, cuvs::distance::DistanceType::L2Expanded}, + {10000, 131072, 8, 10, 20, 1024, cuvs::distance::DistanceType::CosineExpanded}, + {10000, 131072, 8, 10, 50, 1024, cuvs::distance::DistanceType::InnerProduct}, + {10000, 131072, 8, 10, 50, 1024, cuvs::distance::DistanceType::L2SqrtExpanded}, // ===== Large query batches (gridDim.x > 65535) ===== - {100000, 1024, 32, 10, 64, 64, cuvs::distance::DistanceType::L2Expanded, false}, - {100000, 1024, 32, 10, 64, 64, cuvs::distance::DistanceType::InnerProduct, false}, - {100000, 1024, 32, 10, 64, 64, cuvs::distance::DistanceType::CosineExpanded, false}, - {100000, 1024, 32, 10, 64, 64, cuvs::distance::DistanceType::L2SqrtExpanded, false}, - {100000, 8712, 3, 10, 51, 66, cuvs::distance::DistanceType::L2Expanded, false}, - {100000, 8712, 3, 10, 51, 66, cuvs::distance::DistanceType::CosineExpanded, false}, + {100000, 1024, 32, 10, 64, 64, cuvs::distance::DistanceType::L2Expanded}, + {100000, 1024, 32, 10, 64, 64, cuvs::distance::DistanceType::InnerProduct}, + {100000, 1024, 32, 10, 64, 64, cuvs::distance::DistanceType::CosineExpanded}, + {100000, 1024, 32, 10, 64, 64, cuvs::distance::DistanceType::L2SqrtExpanded}, + {100000, 8712, 3, 10, 51, 66, cuvs::distance::DistanceType::L2Expanded}, + {100000, 8712, 3, 10, 51, 66, cuvs::distance::DistanceType::CosineExpanded}, // just above the old 65535 limit - {65536, 1024, 16, 10, 32, 64, cuvs::distance::DistanceType::L2Expanded, false}, - {65536, 1024, 16, 10, 32, 64, cuvs::distance::DistanceType::CosineExpanded, false}, - - // ===== Adaptive centers (all four metrics, multiple dims) ===== - {1000, 10000, 8, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded, true}, - {1000, 10000, 8, 10, 40, 1024, cuvs::distance::DistanceType::InnerProduct, true}, - {1000, 10000, 8, 10, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, true}, - {1000, 10000, 8, 10, 40, 1024, cuvs::distance::DistanceType::L2SqrtExpanded, true}, - {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded, true}, - {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::InnerProduct, true}, - {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, true}, - {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::L2SqrtExpanded, true}, - {1000, 10000, 32, 10, 50, 1024, cuvs::distance::DistanceType::L2Expanded, true}, - {1000, 10000, 32, 10, 50, 1024, cuvs::distance::DistanceType::InnerProduct, true}, - {1000, 10000, 32, 10, 50, 1024, cuvs::distance::DistanceType::CosineExpanded, true}, - {1000, 10000, 128, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded, true}, - {1000, 10000, 128, 10, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, true}, + {65536, 1024, 16, 10, 32, 64, cuvs::distance::DistanceType::L2Expanded}, + {65536, 1024, 16, 10, 32, 64, cuvs::distance::DistanceType::CosineExpanded}, // ===== Recall-stability: same data, different query counts ===== - {20000, 8712, 3, 10, 51, 66, cuvs::distance::DistanceType::L2Expanded, false}, - {50000, 8712, 3, 10, 51, 66, cuvs::distance::DistanceType::L2Expanded, false}, + {20000, 8712, 3, 10, 51, 66, cuvs::distance::DistanceType::L2Expanded}, + {50000, 8712, 3, 10, 51, 66, cuvs::distance::DistanceType::L2Expanded}, }; } // namespace cuvs::neighbors::ivf_sq From e34bdd8178000a92aa4ee6083edf71b1e36e74a9 Mon Sep 17 00:00:00 2001 From: vic Date: Mon, 16 Mar 2026 11:19:03 +0100 Subject: [PATCH 13/31] Add IVF-SQ FAISS benchmark --- cpp/bench/ann/CMakeLists.txt | 15 ++++++++++ .../cuvs_bench/config/algorithms.yaml | 5 +++- .../config/algos/constraints/__init__.py | 12 ++++++++ .../cuvs_bench/config/algos/cuvs_ivf_sq.yaml | 14 +++++----- .../config/algos/faiss_cpu_ivf_sq.yaml | 28 +++++++++++++++++++ .../config/algos/faiss_gpu_ivf_sq.yaml | 28 +++++++++++++++++++ 6 files changed, 94 insertions(+), 8 deletions(-) create mode 100644 python/cuvs_bench/cuvs_bench/config/algos/faiss_cpu_ivf_sq.yaml create mode 100644 python/cuvs_bench/cuvs_bench/config/algos/faiss_gpu_ivf_sq.yaml diff --git a/cpp/bench/ann/CMakeLists.txt b/cpp/bench/ann/CMakeLists.txt index c377b64ec6..0755a15a2f 100644 --- a/cpp/bench/ann/CMakeLists.txt +++ b/cpp/bench/ann/CMakeLists.txt @@ -13,6 +13,7 @@ list(APPEND CMAKE_MODULE_PATH "${CUVS_SOURCE_DIR}") option(CUVS_ANN_BENCH_USE_FAISS_GPU_FLAT "Include faiss' brute-force knn algorithm in benchmark" ON) option(CUVS_ANN_BENCH_USE_FAISS_GPU_IVF_FLAT "Include faiss' ivf flat algorithm in benchmark" ON) option(CUVS_ANN_BENCH_USE_FAISS_GPU_IVF_PQ "Include faiss' ivf pq algorithm in benchmark" ON) +option(CUVS_ANN_BENCH_USE_FAISS_GPU_IVF_SQ "Include faiss' ivf sq algorithm in benchmark" ON) option(CUVS_ANN_BENCH_USE_FAISS_GPU_CAGRA "Include faiss' cagra algorithm in benchmark" ON) option(CUVS_ANN_BENCH_USE_FAISS_GPU_CAGRA_HNSW "Include faiss' cagra algorithm for build and hnsw for search in benchmark" ON @@ -22,6 +23,7 @@ option(CUVS_ANN_BENCH_USE_FAISS_CPU_IVF_FLAT "Include faiss' cpu ivf flat algori ON ) option(CUVS_ANN_BENCH_USE_FAISS_CPU_IVF_PQ "Include faiss' cpu ivf pq algorithm in benchmark" ON) +option(CUVS_ANN_BENCH_USE_FAISS_CPU_IVF_SQ "Include faiss' cpu ivf sq algorithm in benchmark" ON) option(CUVS_ANN_BENCH_USE_FAISS_CPU_HNSW_FLAT "Include faiss' hnsw algorithm in benchmark" ON) option(CUVS_ANN_BENCH_USE_CUVS_IVF_FLAT "Include cuVS ivf flat algorithm in benchmark" ON) option(CUVS_ANN_BENCH_USE_CUVS_IVF_SQ "Include cuVS ivf sq algorithm in benchmark" ON) @@ -318,6 +320,12 @@ if(CUVS_ANN_BENCH_USE_FAISS_CPU_IVF_PQ) ) endif() +if(CUVS_ANN_BENCH_USE_FAISS_CPU_IVF_SQ) + ConfigureAnnBench( + NAME FAISS_CPU_IVF_SQ PATH src/faiss/faiss_cpu_benchmark.cpp LINKS ${CUVS_FAISS_TARGETS} + ) +endif() + if(CUVS_ANN_BENCH_USE_FAISS_CPU_HNSW_FLAT) ConfigureAnnBench( NAME FAISS_CPU_HNSW_FLAT PATH src/faiss/faiss_cpu_benchmark.cpp LINKS ${CUVS_FAISS_TARGETS} @@ -338,6 +346,13 @@ if(CUVS_ANN_BENCH_USE_FAISS_GPU_IVF_PQ AND CUVS_FAISS_ENABLE_GPU) ) endif() +if(CUVS_ANN_BENCH_USE_FAISS_GPU_IVF_SQ AND CUVS_FAISS_ENABLE_GPU) + ConfigureAnnBench( + NAME FAISS_GPU_IVF_SQ PATH src/faiss/faiss_gpu_benchmark.cu LINKS ${CUVS_FAISS_TARGETS} + raft::raft + ) +endif() + if(CUVS_ANN_BENCH_USE_FAISS_GPU_FLAT AND CUVS_FAISS_ENABLE_GPU) ConfigureAnnBench( NAME FAISS_GPU_FLAT PATH src/faiss/faiss_gpu_benchmark.cu LINKS ${CUVS_FAISS_TARGETS} diff --git a/python/cuvs_bench/cuvs_bench/config/algorithms.yaml b/python/cuvs_bench/cuvs_bench/config/algorithms.yaml index 3a787f65ab..f181e549eb 100644 --- a/python/cuvs_bench/cuvs_bench/config/algorithms.yaml +++ b/python/cuvs_bench/cuvs_bench/config/algorithms.yaml @@ -8,7 +8,7 @@ faiss_gpu_ivf_pq: executable: FAISS_GPU_IVF_PQ_ANN_BENCH requires_gpu: true faiss_gpu_ivf_sq: - executable: FAISS_GPU_IVF_PQ_ANN_BENCH + executable: FAISS_GPU_IVF_SQ_ANN_BENCH requires_gpu: true faiss_gpu_cagra: executable: FAISS_GPU_CAGRA_ANN_BENCH @@ -25,6 +25,9 @@ faiss_cpu_ivf_flat: faiss_cpu_ivf_pq: executable: FAISS_CPU_IVF_PQ_ANN_BENCH requires_gpu: false +faiss_cpu_ivf_sq: + executable: FAISS_CPU_IVF_SQ_ANN_BENCH + requires_gpu: false faiss_cpu_hnsw_flat: executable: FAISS_CPU_HNSW_FLAT_ANN_BENCH requires_gpu: false diff --git a/python/cuvs_bench/cuvs_bench/config/algos/constraints/__init__.py b/python/cuvs_bench/cuvs_bench/config/algos/constraints/__init__.py index ea2afe351e..f22852f0ae 100644 --- a/python/cuvs_bench/cuvs_bench/config/algos/constraints/__init__.py +++ b/python/cuvs_bench/cuvs_bench/config/algos/constraints/__init__.py @@ -61,6 +61,18 @@ def cuvs_ivf_sq_search(params, build_params, k, batch_size): ############################################################################### +def faiss_gpu_ivf_sq_search(params, build_params, k, batch_size): + if "nlist" in build_params and "nprobe" in params: + return build_params["nlist"] >= params["nprobe"] + return True + + +def faiss_cpu_ivf_sq_search(params, build_params, k, batch_size): + if "nlist" in build_params and "nprobe" in params: + return build_params["nlist"] >= params["nprobe"] + return True + + def faiss_gpu_ivf_pq_build(params, dims): ret = True # M must be defined diff --git a/python/cuvs_bench/cuvs_bench/config/algos/cuvs_ivf_sq.yaml b/python/cuvs_bench/cuvs_bench/config/algos/cuvs_ivf_sq.yaml index adaad54e04..af59493eef 100644 --- a/python/cuvs_bench/cuvs_bench/config/algos/cuvs_ivf_sq.yaml +++ b/python/cuvs_bench/cuvs_bench/config/algos/cuvs_ivf_sq.yaml @@ -4,22 +4,22 @@ constraints: groups: base: build: - nlist: [1024, 2048, 4096, 8192] - ratio: [1, 2] + nlist: [1024, 2048, 4096] + ratio: [2, 4] niter: [25] search: - nprobe: [1, 5, 10, 20, 50, 100, 200, 500] + nprobe: [1, 5, 10, 20, 50, 100, 200] large: build: - nlist: [8192, 16384, 32000, 64000] - ratio: [2, 4] + nlist: [8192, 16384, 32768] + ratio: [4] niter: [20] search: - nprobe: [10, 20, 50, 100, 200, 500, 1000, 2000] + nprobe: [10, 20, 50, 100, 200, 500, 1000] test: build: nlist: [1024] - ratio: [1] + ratio: [2] niter: [20] search: nprobe: [1, 5] diff --git a/python/cuvs_bench/cuvs_bench/config/algos/faiss_cpu_ivf_sq.yaml b/python/cuvs_bench/cuvs_bench/config/algos/faiss_cpu_ivf_sq.yaml new file mode 100644 index 0000000000..ce237f280d --- /dev/null +++ b/python/cuvs_bench/cuvs_bench/config/algos/faiss_cpu_ivf_sq.yaml @@ -0,0 +1,28 @@ +name: faiss_cpu_ivf_sq +constraints: + search: cuvs_bench.config.algos.constraints.faiss_cpu_ivf_sq_search +groups: + base: + build: + nlist: [1024, 2048, 4096] + ratio: [2, 4] + quantizer_type: [int8] + search: + nprobe: [1, 5, 10, 20, 50, 100, 200] + refine_ratio: [1] + large: + build: + nlist: [8192, 16384, 32768] + ratio: [4] + quantizer_type: [int8] + search: + nprobe: [10, 20, 50, 100, 200, 500, 1000] + refine_ratio: [1, 2] + test: + build: + nlist: [1024] + ratio: [2] + quantizer_type: [int8] + search: + nprobe: [1, 5] + refine_ratio: [1] diff --git a/python/cuvs_bench/cuvs_bench/config/algos/faiss_gpu_ivf_sq.yaml b/python/cuvs_bench/cuvs_bench/config/algos/faiss_gpu_ivf_sq.yaml new file mode 100644 index 0000000000..d49df8116d --- /dev/null +++ b/python/cuvs_bench/cuvs_bench/config/algos/faiss_gpu_ivf_sq.yaml @@ -0,0 +1,28 @@ +name: faiss_gpu_ivf_sq +constraints: + search: cuvs_bench.config.algos.constraints.faiss_gpu_ivf_sq_search +groups: + base: + build: + nlist: [1024, 2048, 4096] + ratio: [2, 4] + quantizer_type: [int8] + search: + nprobe: [1, 5, 10, 20, 50, 100, 200] + refine_ratio: [1] + large: + build: + nlist: [8192, 16384, 32768] + ratio: [4] + quantizer_type: [int8] + search: + nprobe: [10, 20, 50, 100, 200, 500, 1000] + refine_ratio: [1, 2] + test: + build: + nlist: [1024] + ratio: [2] + quantizer_type: [int8] + search: + nprobe: [1, 5] + refine_ratio: [1] From 9bd7bc05aa49e7049a120c3764e81bce6709ef71 Mon Sep 17 00:00:00 2001 From: vic Date: Thu, 19 Mar 2026 14:36:02 +0100 Subject: [PATCH 14/31] Adressing review --- cpp/src/neighbors/ivf_sq/ivf_sq_search.cuh | 7 +++++++ cpp/src/neighbors/ivf_sq_index.cpp | 10 ++++++++++ 2 files changed, 17 insertions(+) diff --git a/cpp/src/neighbors/ivf_sq/ivf_sq_search.cuh b/cpp/src/neighbors/ivf_sq/ivf_sq_search.cuh index a17992ff19..88c1b71970 100644 --- a/cpp/src/neighbors/ivf_sq/ivf_sq_search.cuh +++ b/cpp/src/neighbors/ivf_sq/ivf_sq_search.cuh @@ -459,6 +459,13 @@ inline void search_with_filtering(raft::resources const& handle, RAFT_EXPECTS(params.n_probes > 0, "n_probes (number of clusters to probe in the search) must be positive."); auto n_probes = std::min(params.n_probes, index.n_lists()); + if (n_probes < params.n_probes) { + RAFT_LOG_WARN( + "n_probes (%u) is larger than the number of lists in the index (%u), clamping to %u.", + params.n_probes, + index.n_lists(), + n_probes); + } uint32_t max_samples = std::max(static_cast(index.accum_sorted_sizes()(n_probes)), k); diff --git a/cpp/src/neighbors/ivf_sq_index.cpp b/cpp/src/neighbors/ivf_sq_index.cpp index 8b4de55f54..91eb86704f 100644 --- a/cpp/src/neighbors/ivf_sq_index.cpp +++ b/cpp/src/neighbors/ivf_sq_index.cpp @@ -43,6 +43,8 @@ index::index(raft::resources const& res, inds_ptrs_{raft::make_device_vector(res, n_lists)}, accum_sorted_sizes_{raft::make_host_vector(n_lists + 1)} { + RAFT_EXPECTS(n_lists > 0, "n_lists must be positive."); + RAFT_EXPECTS(dim > 0, "dim must be positive."); check_consistency(); auto stream = raft::resource::get_cuda_stream(res); std::memset(accum_sorted_sizes_.data_handle(), 0, accum_sorted_sizes_.size() * sizeof(int64_t)); @@ -228,6 +230,14 @@ void index::check_consistency() RAFT_EXPECTS((centers_.extent(0) == list_sizes_.extent(0)) && (!center_norms_.has_value() || centers_.extent(0) == center_norms_->extent(0)), "inconsistent number of lists (clusters)"); + RAFT_EXPECTS(sq_vmin_.extent(0) == centers_.extent(1), + "sq_vmin size (%u) does not match dim (%u)", + static_cast(sq_vmin_.extent(0)), + static_cast(centers_.extent(1))); + RAFT_EXPECTS(sq_delta_.extent(0) == centers_.extent(1), + "sq_delta size (%u) does not match dim (%u)", + static_cast(sq_delta_.extent(0)), + static_cast(centers_.extent(1))); } template struct index; From 0ce16416ff1364bba28440abd1da7ce4e5c339e5 Mon Sep 17 00:00:00 2001 From: vic Date: Fri, 20 Mar 2026 13:55:58 +0100 Subject: [PATCH 15/31] Addressing review --- cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh | 48 ++++++++--------------- cpp/src/neighbors/ivf_sq_index.cpp | 2 - 2 files changed, 16 insertions(+), 34 deletions(-) diff --git a/cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh b/cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh index e4a373a80b..5c2f6808f1 100644 --- a/cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh +++ b/cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh @@ -204,26 +204,6 @@ __launch_bounds__(BlockSize) RAFT_KERNEL encode_and_fill_kernel(const uint32_t* } } -/** - * Compute residuals: residual[i] = cast(x_i) - centers[labels[i]] - */ -template -RAFT_KERNEL compute_residuals_kernel(const T* dataset, - const float* centers, - const uint32_t* labels, - float* residuals, - int64_t n_rows, - uint32_t dim) -{ - int64_t i = int64_t(blockIdx.x) * blockDim.x + threadIdx.x; - uint32_t j = blockIdx.y * blockDim.y + threadIdx.y; - if (i >= n_rows || j >= dim) return; - - float val = utils::mapping{}(dataset[i * dim + j]); - uint32_t c = labels[i]; - residuals[i * dim + j] = val - centers[c * dim + j]; -} - /** In-place variant: dataset[i] = cast(cast(dataset[i]) - centers[labels[i]]) */ template RAFT_KERNEL compute_residuals_inplace_kernel( @@ -348,16 +328,22 @@ void extend(raft::resources const& handle, int64_t bs = batch.size(); { - dim3 threads(32, 8); - dim3 blocks(raft::ceildiv(bs, threads.x), raft::ceildiv(dim, threads.y)); - compute_residuals_kernel - <<>>(batch.data(), - index->centers().data_handle(), - new_labels.data_handle() + batch.offset(), - residuals_buf.data_handle(), - bs, - dim); - RAFT_CUDA_TRY(cudaPeekAtLastError()); + auto batch_view = raft::make_device_matrix_view(batch.data(), bs, dim); + auto residuals_view = + raft::make_device_matrix_view(residuals_buf.data_handle(), bs, dim); + + const float* centers_ptr = index->centers().data_handle(); + const uint32_t* labels_ptr = new_labels.data_handle() + batch.offset(); + + raft::linalg::map_offset( + handle, + residuals_view, + [centers_ptr, labels_ptr, dim] __device__(auto idx, T x) { + auto i = idx / dim; + auto j = idx % dim; + return utils::mapping{}(x)-centers_ptr[labels_ptr[i] * dim + j]; + }, + batch_view); } { @@ -464,7 +450,6 @@ inline auto build(raft::resources const& handle, kmeans_params.n_iters = params.kmeans_n_iters; kmeans_params.metric = idx.metric(); cuvs::cluster::kmeans::fit(handle, kmeans_params, trainset_const_view, centers_view); - raft::resource::sync_stream(handle); // Train SQ: predict labels for the training subset, compute residuals in-place, // and derive per-dimension vmin/delta from them. @@ -476,7 +461,6 @@ inline auto build(raft::resources const& handle, idx.centers().data_handle(), idx.n_lists(), dim); cuvs::cluster::kmeans::predict( handle, pred_params, trainset_const_view, centers_const_view, train_labels.view()); - raft::resource::sync_stream(handle); constexpr int kResidualBlockSize = 256; compute_residuals_inplace_kernel diff --git a/cpp/src/neighbors/ivf_sq_index.cpp b/cpp/src/neighbors/ivf_sq_index.cpp index 91eb86704f..bf5b6df288 100644 --- a/cpp/src/neighbors/ivf_sq_index.cpp +++ b/cpp/src/neighbors/ivf_sq_index.cpp @@ -43,8 +43,6 @@ index::index(raft::resources const& res, inds_ptrs_{raft::make_device_vector(res, n_lists)}, accum_sorted_sizes_{raft::make_host_vector(n_lists + 1)} { - RAFT_EXPECTS(n_lists > 0, "n_lists must be positive."); - RAFT_EXPECTS(dim > 0, "dim must be positive."); check_consistency(); auto stream = raft::resource::get_cuda_stream(res); std::memset(accum_sorted_sizes_.data_handle(), 0, accum_sorted_sizes_.size() * sizeof(int64_t)); From 77c4a7996d2a7dd9c7e9fbfef0c4cea0a72259db Mon Sep 17 00:00:00 2001 From: vic Date: Thu, 2 Apr 2026 14:33:49 +0200 Subject: [PATCH 16/31] Fix issue with host data + half testing --- cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh | 44 +++---- cpp/tests/neighbors/ann_ivf_sq.cuh | 116 ++++++++++++++---- .../ann_ivf_sq/test_float_uint8_t.cu | 9 ++ 3 files changed, 118 insertions(+), 51 deletions(-) diff --git a/cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh b/cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh index 5c2f6808f1..8fbaf30240 100644 --- a/cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh +++ b/cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -412,14 +413,16 @@ auto extend(raft::resources const& handle, return ext_index; } -template -inline auto build(raft::resources const& handle, - const index_params& params, - const T* dataset, - int64_t n_rows, - uint32_t dim) -> index +template +inline auto build( + raft::resources const& handle, + const index_params& params, + raft::mdspan, raft::row_major, accessor> dataset) + -> index { - auto stream = raft::resource::get_cuda_stream(handle); + int64_t n_rows = dataset.extent(0); + uint32_t dim = dataset.extent(1); + auto stream = raft::resource::get_cuda_stream(handle); cuvs::common::nvtx::range fun_scope( "ivf_sq::build(%zu, %u)", size_t(n_rows), dim); static_assert(std::is_same_v || std::is_same_v, "unsupported data type"); @@ -441,8 +444,7 @@ inline auto build(raft::resources const& handle, raft::make_device_mdarray(handle, raft::resource::get_large_workspace_resource(handle), raft::make_extents(n_rows_train, idx.dim())); - auto dataset_view = raft::make_device_matrix_view(dataset, n_rows, idx.dim()); - raft::matrix::sample_rows(handle, random_state, dataset_view, trainset.view()); + raft::matrix::sample_rows(handle, random_state, dataset, trainset.view()); auto trainset_const_view = raft::make_const_mdspan(trainset.view()); auto centers_view = raft::make_device_matrix_view( idx.centers().data_handle(), idx.n_lists(), idx.dim()); @@ -499,31 +501,13 @@ inline auto build(raft::resources const& handle, } } - if (params.add_data_on_build) { detail::extend(handle, &idx, dataset, nullptr, n_rows); } + if (params.add_data_on_build) { + detail::extend(handle, &idx, dataset.data_handle(), nullptr, n_rows); + } return idx; } -template -auto build(raft::resources const& handle, - const index_params& params, - raft::device_matrix_view dataset) -> index -{ - int64_t n_rows = dataset.extent(0); - uint32_t dim = dataset.extent(1); - return build(handle, params, dataset.data_handle(), n_rows, dim); -} - -template -auto build(raft::resources const& handle, - const index_params& params, - raft::host_matrix_view dataset) -> index -{ - int64_t n_rows = dataset.extent(0); - uint32_t dim = dataset.extent(1); - return build(handle, params, dataset.data_handle(), n_rows, dim); -} - template void build(raft::resources const& handle, const index_params& params, diff --git a/cpp/tests/neighbors/ann_ivf_sq.cuh b/cpp/tests/neighbors/ann_ivf_sq.cuh index d90ec66959..c4f3c9de74 100644 --- a/cpp/tests/neighbors/ann_ivf_sq.cuh +++ b/cpp/tests/neighbors/ann_ivf_sq.cuh @@ -16,6 +16,8 @@ #include #include +#include + namespace cuvs::neighbors::ivf_sq { struct test_ivf_sample_filter { @@ -31,6 +33,7 @@ struct AnnIvfSqInputs { IdxT nprobe; IdxT nlist; cuvs::distance::DistanceType metric; + bool host_dataset = false; }; template @@ -39,7 +42,7 @@ template os << "{ " << p.num_queries << ", " << p.num_db_vecs << ", " << p.dim << ", " << p.k << ", " << p.nprobe << ", " << p.nlist << ", " << cuvs::neighbors::print_metric{static_cast((int)p.metric)} - << '}' << std::endl; + << ", " << (p.host_dataset ? "host" : "device") << '}' << std::endl; return os; } @@ -97,32 +100,60 @@ class AnnIVFSQTest : public ::testing::TestWithParam> { index_params.add_data_on_build = true; index_params.kmeans_trainset_fraction = 0.5; - auto database_view = raft::make_device_matrix_view( - (const DataT*)database.data(), ps.num_db_vecs, ps.dim); - - auto idx = cuvs::neighbors::ivf_sq::build(handle_, index_params, database_view); - - // Test extend: build without data, then extend cuvs::neighbors::ivf_sq::index_params index_params_no_add; index_params_no_add.n_lists = ps.nlist; index_params_no_add.metric = ps.metric; index_params_no_add.add_data_on_build = false; index_params_no_add.kmeans_trainset_fraction = 0.5; - auto idx_empty = - cuvs::neighbors::ivf_sq::build(handle_, index_params_no_add, database_view); - - auto vector_indices = raft::make_device_vector(handle_, ps.num_db_vecs); - raft::linalg::map_offset(handle_, vector_indices.view(), raft::identity_op{}); - raft::resource::sync_stream(handle_); - - auto indices_view = raft::make_device_vector_view( - vector_indices.data_handle(), ps.num_db_vecs); - cuvs::neighbors::ivf_sq::extend( - handle_, - database_view, - std::make_optional>(indices_view), - &idx_empty); + cuvs::neighbors::ivf_sq::index idx(handle_); + cuvs::neighbors::ivf_sq::index idx_empty(handle_); + + if (!ps.host_dataset) { + auto database_view = raft::make_device_matrix_view( + (const DataT*)database.data(), ps.num_db_vecs, ps.dim); + + idx = cuvs::neighbors::ivf_sq::build(handle_, index_params, database_view); + + idx_empty = cuvs::neighbors::ivf_sq::build(handle_, index_params_no_add, database_view); + + auto vector_indices = raft::make_device_vector(handle_, ps.num_db_vecs); + raft::linalg::map_offset(handle_, vector_indices.view(), raft::identity_op{}); + raft::resource::sync_stream(handle_); + + auto indices_view = raft::make_device_vector_view( + vector_indices.data_handle(), ps.num_db_vecs); + cuvs::neighbors::ivf_sq::extend( + handle_, + database_view, + std::make_optional>(indices_view), + &idx_empty); + } else { + auto host_database = raft::make_host_matrix(ps.num_db_vecs, ps.dim); + raft::copy( + host_database.data_handle(), database.data(), ps.num_db_vecs * ps.dim, stream_); + raft::resource::sync_stream(handle_); + + idx = cuvs::neighbors::ivf_sq::build( + handle_, index_params, raft::make_const_mdspan(host_database.view())); + + idx_empty = cuvs::neighbors::ivf_sq::build( + handle_, index_params_no_add, raft::make_const_mdspan(host_database.view())); + + auto vector_indices = raft::make_host_vector(handle_, ps.num_db_vecs); + std::iota( + vector_indices.data_handle(), vector_indices.data_handle() + ps.num_db_vecs, IdxT(0)); + + auto indices_view = raft::make_host_vector_view( + vector_indices.data_handle(), ps.num_db_vecs); + auto host_database_view = raft::make_host_matrix_view( + host_database.data_handle(), ps.num_db_vecs, ps.dim); + cuvs::neighbors::ivf_sq::extend( + handle_, + host_database_view, + std::make_optional>(indices_view), + &idx_empty); + } // Serialize / deserialize round-trip tmp_index_file index_file; @@ -430,6 +461,49 @@ const std::vector> inputs = { // ===== Recall-stability: same data, different query counts ===== {20000, 8712, 3, 10, 51, 66, cuvs::distance::DistanceType::L2Expanded}, {50000, 8712, 3, 10, 51, 66, cuvs::distance::DistanceType::L2Expanded}, + + // ===== Host dataset: build + extend from host_matrix_view ===== + {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded, true}, + {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::InnerProduct, true}, + {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, true}, + {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::L2SqrtExpanded, true}, + {1000, 10000, 3, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded, true}, + {1000, 10000, 128, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded, true}, + {1000, 10000, 128, 10, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, true}, + {100, 10000, 64, 10, 20, 512, cuvs::distance::DistanceType::InnerProduct, true}, +}; + +const std::vector> inputs_half = { + // num_queries, num_db_vecs, dim, k, nprobe, nlist, metric, host_dataset + + // All four metrics at a standard dimension + {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded}, + {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::InnerProduct}, + {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::CosineExpanded}, + {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::L2SqrtExpanded}, + + // Unaligned and small dimensions + {1000, 10000, 3, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded}, + {1000, 10000, 7, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded}, + + // Medium / larger dimensions + {1000, 10000, 64, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded}, + {1000, 10000, 128, 10, 40, 1024, cuvs::distance::DistanceType::InnerProduct}, + {1000, 10000, 128, 10, 40, 1024, cuvs::distance::DistanceType::CosineExpanded}, + {1000, 10000, 256, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded}, + + // k edge cases + {1000, 10000, 16, 1, 40, 1024, cuvs::distance::DistanceType::L2Expanded}, + {1000, 10000, 16, 50, 100, 1024, cuvs::distance::DistanceType::L2Expanded}, + + // nprobe / nlist edge cases + {1000, 10000, 16, 10, 64, 64, cuvs::distance::DistanceType::L2Expanded}, + {100, 10000, 16, 10, 20, 512, cuvs::distance::DistanceType::CosineExpanded}, + + // Host dataset + {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded, true}, + {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, true}, + {1000, 10000, 128, 10, 40, 1024, cuvs::distance::DistanceType::InnerProduct, true}, }; } // namespace cuvs::neighbors::ivf_sq diff --git a/cpp/tests/neighbors/ann_ivf_sq/test_float_uint8_t.cu b/cpp/tests/neighbors/ann_ivf_sq/test_float_uint8_t.cu index 02ec8a7dfc..f136b6f41c 100644 --- a/cpp/tests/neighbors/ann_ivf_sq/test_float_uint8_t.cu +++ b/cpp/tests/neighbors/ann_ivf_sq/test_float_uint8_t.cu @@ -18,4 +18,13 @@ TEST_P(AnnIVFSQTestF_float, AnnIVFSQ) INSTANTIATE_TEST_CASE_P(AnnIVFSQTest, AnnIVFSQTestF_float, ::testing::ValuesIn(inputs)); +typedef AnnIVFSQTest AnnIVFSQTestF_half; +TEST_P(AnnIVFSQTestF_half, AnnIVFSQ) +{ + this->testIVFSQ(); + this->testFilter(); +} + +INSTANTIATE_TEST_CASE_P(AnnIVFSQTest, AnnIVFSQTestF_half, ::testing::ValuesIn(inputs_half)); + } // namespace cuvs::neighbors::ivf_sq From b46ea79f377360cdd1e9d3bc7f505e910dcccf63 Mon Sep 17 00:00:00 2001 From: vic Date: Thu, 2 Apr 2026 16:39:14 +0200 Subject: [PATCH 17/31] Update metric in doc --- cpp/include/cuvs/neighbors/ivf_sq.hpp | 10 +++++----- cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/cpp/include/cuvs/neighbors/ivf_sq.hpp b/cpp/include/cuvs/neighbors/ivf_sq.hpp index 10d9a4c856..bb29b3fef1 100644 --- a/cpp/include/cuvs/neighbors/ivf_sq.hpp +++ b/cpp/include/cuvs/neighbors/ivf_sq.hpp @@ -236,7 +236,7 @@ struct index : cuvs::neighbors::index { * * NB: Currently, the following distance metrics are supported: * - L2Expanded - * - L2Unexpanded + * - L2SqrtExpanded * - InnerProduct * - CosineExpanded * @@ -265,7 +265,7 @@ auto build(raft::resources const& handle, * * NB: Currently, the following distance metrics are supported: * - L2Expanded - * - L2Unexpanded + * - L2SqrtExpanded * - InnerProduct * - CosineExpanded * @@ -318,7 +318,7 @@ auto build(raft::resources const& handle, * * NB: Currently, the following distance metrics are supported: * - L2Expanded - * - L2Unexpanded + * - L2SqrtExpanded * - InnerProduct * - CosineExpanded * @@ -371,7 +371,7 @@ auto build(raft::resources const& handle, * * NB: Currently, the following distance metrics are supported: * - L2Expanded - * - L2Unexpanded + * - L2SqrtExpanded * - InnerProduct * - CosineExpanded * @@ -424,7 +424,7 @@ auto build(raft::resources const& handle, * * NB: Currently, the following distance metrics are supported: * - L2Expanded - * - L2Unexpanded + * - L2SqrtExpanded * - InnerProduct * - CosineExpanded * diff --git a/cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh b/cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh index 8fbaf30240..7934460fb5 100644 --- a/cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh +++ b/cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh @@ -62,7 +62,7 @@ struct ColMinMaxOp { * rows and feed CUB BlockReduce with a combined min/max pair. * * Row-loop is manually 4x-unrolled so the compiler can overlap four - * independent __ldg requests in the memory pipeline. + * independent read-only loads in the memory pipeline. */ template __launch_bounds__(BlockSize) RAFT_KERNEL fused_column_minmax_kernel(const T* __restrict__ data, From 44c5f0a3826408dd764fc8f048acbb5bf052be53 Mon Sep 17 00:00:00 2001 From: vic Date: Thu, 2 Apr 2026 16:42:40 +0200 Subject: [PATCH 18/31] Fix manage_local_topk / Capacity mismatch in IVF-SQ search --- cpp/src/neighbors/ivf_sq/ivf_sq_search.cuh | 705 +++++++++++++++------ 1 file changed, 516 insertions(+), 189 deletions(-) diff --git a/cpp/src/neighbors/ivf_sq/ivf_sq_search.cuh b/cpp/src/neighbors/ivf_sq/ivf_sq_search.cuh index 88c1b71970..c79e043e20 100644 --- a/cpp/src/neighbors/ivf_sq/ivf_sq_search.cuh +++ b/cpp/src/neighbors/ivf_sq/ivf_sq_search.cuh @@ -23,6 +23,7 @@ #include #include #include +#include #include @@ -34,143 +35,369 @@ using namespace cuvs::spatial::knn::detail; // NOLINT enum class SqScanMetric { kL2, kIP, kCosine }; -/** - * Per-probe scan kernel for IVF-SQ search. - * - * Grid: (n_queries, n_probes). Each block handles one (query, probe) pair. - * Within a block, each warp processes one interleaved group of kIndexGroupSize - * (=32) vectors at a time, with each lane responsible for one vector. - * Dimension blocks of veclen=16 bytes are loaded as coalesced uint4 reads - * across the warp (32 lanes x 16 bytes = 512 bytes = 4 cache lines), giving - * full memory-bandwidth utilisation. - * - * Per-dimension constants that are invariant across rows are precomputed into - * shared memory so the hot loop only reads from smem + one uint4 per dim-block: - * - * L2 / L2Sqrt: - * s_query_term[d] = query[d] - centroid[d] - sq_vmin[d] - * dist += (s_query_term[d] - code * s_sq_scale[d])^2 - * - * InnerProduct / Cosine: - * s_query_term[d] = query[d] - * s_recon_base[d] = centroid[d] + sq_vmin[d] - * v_d = s_recon_base[d] + code * s_sq_scale[d] - * dist += s_query_term[d] * v_d - * - * Shared-memory layout adapts to the metric to avoid waste: - * L2 / L2Sqrt : [s_query_term | s_sq_scale] (2 * dim floats) - * InnerProduct/Cosine: [s_query_term | s_recon_base | s_sq_scale] (3 * dim floats) - */ -template -__launch_bounds__(BlockDim) RAFT_KERNEL ivf_sq_scan_kernel(const uint8_t* const* data_ptrs, - const uint32_t* list_sizes, - const uint32_t* coarse_indices, - const float* queries_float, - const float* centers, - const float* sq_vmin, - const float* sq_delta, - const float* query_norms, - uint32_t n_probes, - uint32_t dim, - uint32_t max_samples, - const uint32_t* chunk_indices, - float* out_distances, - uint32_t* out_indices, - IvfSampleFilterT sample_filter) +static constexpr int kSqScanThreads = 128; + +// Maximum fused top-k capacity we instantiate for the scan kernel. +// Must match the highest Capacity case in ivf_sq_scan's switch. +static constexpr int kMaxSqScanCapacity = 256; + +auto RAFT_WEAK_FUNCTION is_local_topk_feasible(uint32_t k) -> bool +{ + return k <= kMaxSqScanCapacity && + k <= raft::matrix::detail::select::warpsort::kMaxCapacity; +} + +// --------------------------------------------------------------------------- +// block_sort type selection (fused top-k vs dummy for Capacity == 0) +// --------------------------------------------------------------------------- +template +struct sq_block_sort { + using type = raft::matrix::detail::select::warpsort::block_sort< + raft::matrix::detail::select::warpsort::warp_sort_filtered, + Capacity, + Ascending, + float, + uint32_t>; +}; + +template +struct sq_block_sort<0, Ascending> { + using type = ivf::detail::dummy_block_sort_t; +}; + +template +using sq_block_sort_t = typename sq_block_sort::type; + +// --------------------------------------------------------------------------- +// configure_grid_dim_x: choose grid.x to saturate the GPU +// --------------------------------------------------------------------------- +inline uint32_t configure_grid_dim_x(uint32_t n_queries, + uint32_t n_probes, + int smem_size, + int block_size, + const void* kernel_ptr) +{ + int dev_id; + RAFT_CUDA_TRY(cudaGetDevice(&dev_id)); + int num_sms; + RAFT_CUDA_TRY(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id)); + int num_blocks_per_sm = 0; + RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &num_blocks_per_sm, kernel_ptr, block_size, smem_size)); + + size_t min_grid_size = size_t(num_sms) * num_blocks_per_sm; + size_t min_grid_x = raft::ceildiv(min_grid_size, n_queries); + return std::min(n_probes, static_cast(min_grid_x)); +} + +// --------------------------------------------------------------------------- +// IVF-SQ scan kernel with fused in-kernel top-k +// +// Grid layout: +// kManageLocalTopK (Capacity > 0): +// grid (grid_dim_x, n_queries) — each block loops over probes +// otherwise (Capacity == 0): +// grid (n_probes, n_queries) — one block per (query, probe) +// +// Shared-memory layout: [s_sq_scale(dim) | s_query_term(dim) | s_recon_base(dim)?] +// s_sq_scale is loaded once (invariant across probes). +// For L2, s_query_term is reloaded per probe. +// For IP/Cosine, s_query_term is loaded once and s_recon_base is reloaded per probe. +// After all probes are scanned, the smem is reused for block_sort merge. +// --------------------------------------------------------------------------- +template +__launch_bounds__(BlockDim) RAFT_KERNEL + ivf_sq_scan_kernel(const uint8_t* const* data_ptrs, + const uint32_t* list_sizes, + const uint32_t* coarse_indices, + const float* queries_float, + const float* centers, + const float* sq_vmin, + const float* sq_delta, + const float* query_norms, + uint32_t n_probes, + uint32_t dim, + uint32_t k, + uint32_t max_samples, + const uint32_t* chunk_indices, + float* out_distances, + uint32_t* out_indices, + IvfSampleFilterT sample_filter) { static_assert(kIndexGroupSize == raft::WarpSize, "Warp-coalesced scan requires kIndexGroupSize == WarpSize"); - extern __shared__ float smem[]; + constexpr bool kManageLocalTopK = (Capacity > 0); + constexpr bool kIsL2 = (Metric == SqScanMetric::kL2); + constexpr bool kIsCosine = (Metric == SqScanMetric::kCosine); + constexpr bool kAscending = (Metric != SqScanMetric::kIP); - constexpr bool kIsL2 = (Metric == SqScanMetric::kL2); - constexpr bool kIsCosine = (Metric == SqScanMetric::kCosine); + extern __shared__ __align__(256) uint8_t smem_buf[]; + float* smem = reinterpret_cast(smem_buf); - float* s_query_term = smem; - float* s_recon_base = smem + dim; - float* s_sq_scale = kIsL2 ? (smem + dim) : (smem + 2 * dim); + // smem layout: [s_sq_scale | s_query_term | (s_recon_base for IP/Cosine)] + float* s_sq_scale = smem; + float* s_query_term = smem + dim; + float* s_recon_base = kIsL2 ? nullptr : (smem + 2 * dim); - const uint32_t query_ix = blockIdx.x; - const uint32_t probe_ix = blockIdx.y; + const uint32_t query_ix = blockIdx.y; + const float* query = queries_float + query_ix * dim; - const uint32_t* my_coarse = coarse_indices + query_ix * n_probes; - const uint32_t cluster_id = my_coarse[probe_ix]; - const uint32_t cluster_sz = list_sizes[cluster_id]; - if (cluster_sz == 0) return; - - const uint8_t* codes = data_ptrs[cluster_id]; - const float* query = queries_float + query_ix * dim; - const float* centroid = centers + cluster_id * dim; + // Point output to this block's slice when using fused top-k + if constexpr (kManageLocalTopK) { + out_distances += uint64_t(query_ix) * k * gridDim.x + blockIdx.x * k; + out_indices += uint64_t(query_ix) * k * gridDim.x + blockIdx.x * k; + } + // --- Phase 1: load shared memory that is invariant across probes --- for (uint32_t d = threadIdx.x; d < dim; d += BlockDim) { - float vmin_d = sq_vmin[d]; s_sq_scale[d] = sq_delta[d]; - if constexpr (kIsL2) { - s_query_term[d] = query[d] - centroid[d] - vmin_d; - } else { - s_query_term[d] = query[d]; - s_recon_base[d] = centroid[d] + vmin_d; - } + if constexpr (!kIsL2) { s_query_term[d] = query[d]; } } - __syncthreads(); - const uint32_t* my_chunk = chunk_indices + query_ix * n_probes; - uint32_t out_base = (probe_ix > 0) ? my_chunk[probe_ix - 1] : 0; + using local_topk_t = sq_block_sort_t; + local_topk_t queue(k); + + const uint32_t* my_coarse = coarse_indices + query_ix * n_probes; + const uint32_t* my_chunk = chunk_indices + query_ix * n_probes; constexpr uint32_t veclen = 16; constexpr uint32_t kWarpsPerBlock = BlockDim / raft::WarpSize; const uint32_t warp_id = threadIdx.x / raft::WarpSize; const uint32_t lane_id = threadIdx.x % raft::WarpSize; - uint32_t padded_dim = ((dim + veclen - 1) / veclen) * veclen; - uint32_t n_dim_blocks = padded_dim / veclen; + // --- Phase 2: loop over probes --- + for (uint32_t probe_ix = blockIdx.x; probe_ix < n_probes; + probe_ix += (kManageLocalTopK ? gridDim.x : uint32_t{1})) { + const uint32_t cluster_id = my_coarse[probe_ix]; + const uint32_t cluster_sz = list_sizes[cluster_id]; + + // Load centroid-dependent shared memory terms + { + const float* centroid = centers + cluster_id * dim; + for (uint32_t d = threadIdx.x; d < dim; d += BlockDim) { + if constexpr (kIsL2) { + s_query_term[d] = query[d] - centroid[d] - sq_vmin[d]; + } else { + s_recon_base[d] = centroid[d] + sq_vmin[d]; + } + } + } + __syncthreads(); + + if (cluster_sz == 0) { + if constexpr (!kManageLocalTopK) break; + continue; + } + + const uint8_t* codes = data_ptrs[cluster_id]; + uint32_t sample_offset = (probe_ix > 0) ? my_chunk[probe_ix - 1] : 0; + uint32_t padded_dim = ((dim + veclen - 1) / veclen) * veclen; + uint32_t n_dim_blocks = padded_dim / veclen; - for (uint32_t group = warp_id * kIndexGroupSize; group < cluster_sz; - group += kWarpsPerBlock * kIndexGroupSize) { - const uint32_t row = group + lane_id; - const bool valid = (row < cluster_sz) && sample_filter(query_ix, cluster_id, row); + for (uint32_t group = warp_id * kIndexGroupSize; group < cluster_sz; + group += kWarpsPerBlock * kIndexGroupSize) { + const uint32_t row = group + lane_id; + const bool valid = (row < cluster_sz) && sample_filter(query_ix, cluster_id, row); - float dist = 0.0f; - float v_norm_sq = 0.0f; + float dist = 0.0f; + float v_norm_sq = 0.0f; - const uint8_t* group_data = codes + size_t(group) * padded_dim; + const uint8_t* group_data = codes + size_t(group) * padded_dim; - for (uint32_t bl = 0; bl < n_dim_blocks; bl++) { - uint8_t codes_local[veclen]; - *reinterpret_cast(codes_local) = *reinterpret_cast( - group_data + bl * (veclen * kIndexGroupSize) + lane_id * veclen); + for (uint32_t bl = 0; bl < n_dim_blocks; bl++) { + uint8_t codes_local[veclen]; + *reinterpret_cast(codes_local) = *reinterpret_cast( + group_data + bl * (veclen * kIndexGroupSize) + lane_id * veclen); - const uint32_t l = bl * veclen; + const uint32_t l = bl * veclen; #pragma unroll - for (uint32_t j = 0; j < veclen; j++) { - if (l + j < dim) { - float recon = float(codes_local[j]) * s_sq_scale[l + j]; - - if constexpr (kIsL2) { - float diff = s_query_term[l + j] - recon; - dist += diff * diff; - } else { - float v_d = s_recon_base[l + j] + recon; - dist += s_query_term[l + j] * v_d; - if constexpr (kIsCosine) { v_norm_sq += v_d * v_d; } + for (uint32_t j = 0; j < veclen; j++) { + if (l + j < dim) { + float recon = float(codes_local[j]) * s_sq_scale[l + j]; + + if constexpr (kIsL2) { + float diff = s_query_term[l + j] - recon; + dist += diff * diff; + } else { + float v_d = s_recon_base[l + j] + recon; + dist += s_query_term[l + j] * v_d; + if constexpr (kIsCosine) { v_norm_sq += v_d * v_d; } + } } } } + + if constexpr (kIsCosine) { + float denom = query_norms[query_ix] * sqrtf(v_norm_sq); + dist = (denom > 0.0f) ? 1.0f - dist / denom : 0.0f; + } + + if constexpr (kManageLocalTopK) { + float val = valid ? dist : local_topk_t::queue_t::kDummy; + queue.add(val, sample_offset + row); + } else { + if (valid) { + uint32_t out_idx = query_ix * max_samples + sample_offset + row; + out_distances[out_idx] = dist; + out_indices[out_idx] = sample_offset + row; + } + } } - if constexpr (kIsCosine) { - float denom = query_norms[query_ix] * sqrtf(v_norm_sq); - dist = (denom > 0.0f) ? 1.0f - dist / denom : 0.0f; + __syncthreads(); + if constexpr (!kManageLocalTopK) break; + } + + if constexpr (kManageLocalTopK) { + __syncthreads(); + queue.done(smem_buf); + queue.store(out_distances, out_indices); + + // block_sort initializes unused slots with (kDummy, idx=0). When the + // probed clusters have fewer than k total valid vectors, those slots + // survive into the output and share idx=0 with the real first vector, + // causing duplicates. Mark them with an invalid index so + // postprocess_neighbors treats them as out-of-bounds. + // store() is a warp-0-only operation, restrict the fixup to the same warp. + if (threadIdx.x < raft::WarpSize) { + constexpr auto kDummyVal = local_topk_t::queue_t::kDummy; + for (uint32_t i = threadIdx.x; i < k; i += raft::WarpSize) { + if (out_distances[i] == kDummyVal) { out_indices[i] = uint32_t(0xFFFFFFFF); } + } } + } +} + +// --------------------------------------------------------------------------- +// Compute shared-memory size for a given kernel configuration +// --------------------------------------------------------------------------- +inline size_t sq_scan_smem_size(uint32_t dim, SqScanMetric metric) +{ + return (metric == SqScanMetric::kL2 ? 2 : 3) * dim * sizeof(float); +} + +template +size_t sq_scan_total_smem(uint32_t dim, uint32_t k, SqScanMetric metric) +{ + size_t scan_smem = sq_scan_smem_size(dim, metric); + if constexpr (Capacity > 0) { + constexpr int kSubwarpSize = std::min(Capacity, raft::WarpSize); + int num_subwarps = kSqScanThreads / kSubwarpSize; + size_t merge_smem = + raft::matrix::detail::select::warpsort::calc_smem_size_for_block_wide( + num_subwarps, k); + return std::max(scan_smem, merge_smem); + } + return scan_smem; +} + +// --------------------------------------------------------------------------- +// Launch helper: dispatches on Metric, handles grid_dim_x query vs launch +// --------------------------------------------------------------------------- +template +void ivf_sq_scan_launch(const index& idx, + const float* queries_float, + const float* query_norms, + uint32_t n_queries, + uint32_t n_probes, + uint32_t k, + uint32_t max_samples, + const uint32_t* coarse_indices, + const uint32_t* chunk_indices, + float* out_distances, + uint32_t* out_indices, + IvfSampleFilterT sample_filter, + uint32_t& grid_dim_x, + rmm::cuda_stream_view stream) +{ + constexpr bool kManageLocalTopK = (Capacity > 0); + constexpr int kThreads = kSqScanThreads; + uint32_t dim = idx.dim(); + + constexpr uint32_t kMaxGridY = 32768; - if (valid) { - uint32_t out_idx = query_ix * max_samples + out_base + row; - out_distances[out_idx] = dist; - out_indices[out_idx] = out_base + row; + auto do_launch = [&](auto kernel_ptr, SqScanMetric metric_val) { + size_t smem = sq_scan_total_smem(dim, k, metric_val); + + RAFT_CUDA_TRY( + cudaFuncSetAttribute(kernel_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, smem)); + + // If grid_dim_x == 0, compute the optimal value and return + if constexpr (kManageLocalTopK) { + if (grid_dim_x == 0) { + grid_dim_x = configure_grid_dim_x( + std::min(kMaxGridY, n_queries), n_probes, smem, kThreads, + reinterpret_cast(kernel_ptr)); + return; + } } + + dim3 block(kThreads); + + // Batch over queries to respect the gridDim.y limit (65535) + for (uint32_t query_offset = 0; query_offset < n_queries; query_offset += kMaxGridY) { + uint32_t batch = std::min(kMaxGridY, n_queries - query_offset); + dim3 grid = kManageLocalTopK ? dim3(grid_dim_x, batch) : dim3(n_probes, batch); + + auto q_ptr = queries_float + uint64_t(query_offset) * dim; + auto qn_ptr = query_norms ? query_norms + query_offset : query_norms; + auto ci = coarse_indices + uint64_t(query_offset) * n_probes; + auto ch = chunk_indices + uint64_t(query_offset) * n_probes; + auto od = out_distances; + auto oi = out_indices; + if constexpr (kManageLocalTopK) { + od += uint64_t(query_offset) * grid_dim_x * k; + oi += uint64_t(query_offset) * grid_dim_x * k; + } else { + od += uint64_t(query_offset) * max_samples; + oi += uint64_t(query_offset) * max_samples; + } + + kernel_ptr<<>>(idx.data_ptrs().data_handle(), + idx.list_sizes().data_handle(), + ci, + q_ptr, + idx.centers().data_handle(), + idx.sq_vmin().data_handle(), + idx.sq_delta().data_handle(), + qn_ptr, + n_probes, + dim, + k, + max_samples, + ch, + od, + oi, + sample_filter); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + } + }; + + switch (idx.metric()) { + case cuvs::distance::DistanceType::L2Expanded: + case cuvs::distance::DistanceType::L2SqrtExpanded: + do_launch( + ivf_sq_scan_kernel, + SqScanMetric::kL2); + break; + case cuvs::distance::DistanceType::InnerProduct: + do_launch( + ivf_sq_scan_kernel, + SqScanMetric::kIP); + break; + case cuvs::distance::DistanceType::CosineExpanded: + do_launch( + ivf_sq_scan_kernel, + SqScanMetric::kCosine); + break; + default: RAFT_FAIL("Unsupported metric type for IVF-SQ scan."); } } +// --------------------------------------------------------------------------- +// ivf_sq_scan: top-level scan dispatch with Capacity selection +// --------------------------------------------------------------------------- template void ivf_sq_scan(raft::resources const& handle, const index& idx, @@ -178,58 +405,57 @@ void ivf_sq_scan(raft::resources const& handle, const float* query_norms, uint32_t n_queries, uint32_t n_probes, + uint32_t k, uint32_t max_samples, const uint32_t* coarse_indices, const uint32_t* chunk_indices, float* out_distances, uint32_t* out_indices, IvfSampleFilterT sample_filter, + uint32_t& grid_dim_x, rmm::cuda_stream_view stream) { - constexpr int kThreads = 256; - dim3 grid(n_queries, n_probes); - dim3 block(kThreads); - uint32_t dim = idx.dim(); + // Determine the fused top-k capacity (0 = disabled / fallback to materialization) + int capacity = is_local_topk_feasible(k) ? raft::bound_by_power_of_two(int(k)) : 0; + + // Clamp to supported compile-time Capacity values. + // Using a limited set to avoid excessive template instantiations. + if (capacity > 0 && capacity <= 32) { + capacity = 32; + } else if (capacity > 32 && capacity <= 256) { + capacity = 256; + } else if (capacity > 256) { + capacity = 0; + } - auto do_launch = [&](auto kernel_ptr, size_t smem) { - RAFT_CUDA_TRY( - cudaFuncSetAttribute(kernel_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, smem)); - kernel_ptr<<>>(idx.data_ptrs().data_handle(), - idx.list_sizes().data_handle(), - coarse_indices, - queries_float, - idx.centers().data_handle(), - idx.sq_vmin().data_handle(), - idx.sq_delta().data_handle(), - query_norms, - n_probes, - dim, - max_samples, - chunk_indices, - out_distances, - out_indices, - sample_filter); - RAFT_CUDA_TRY(cudaPeekAtLastError()); + auto fwd = [&](auto cap_tag) { + ivf_sq_scan_launch(idx, + queries_float, + query_norms, + n_queries, + n_probes, + k, + max_samples, + coarse_indices, + chunk_indices, + out_distances, + out_indices, + sample_filter, + grid_dim_x, + stream); }; - switch (idx.metric()) { - case cuvs::distance::DistanceType::L2Expanded: - case cuvs::distance::DistanceType::L2SqrtExpanded: - do_launch(ivf_sq_scan_kernel, - 2 * dim * sizeof(float)); - break; - case cuvs::distance::DistanceType::InnerProduct: - do_launch(ivf_sq_scan_kernel, - 3 * dim * sizeof(float)); - break; - case cuvs::distance::DistanceType::CosineExpanded: - do_launch(ivf_sq_scan_kernel, - 3 * dim * sizeof(float)); - break; - default: RAFT_FAIL("Unsupported metric type for IVF-SQ scan."); + switch (capacity) { + case 0: fwd(std::integral_constant{}); break; + case 32: fwd(std::integral_constant{}); break; + case 256: fwd(std::integral_constant{}); break; + default: RAFT_FAIL("Unexpected capacity value %d", capacity); } } +// --------------------------------------------------------------------------- +// search_impl — host-side search logic +// --------------------------------------------------------------------------- template void search_impl(raft::resources const& handle, const index& index, @@ -366,41 +592,37 @@ void search_impl(raft::resources const& handle, num_samples.data(), stream); - uint32_t max_samples = - std::max(static_cast(index.accum_sorted_sizes()(n_probes)), k); - - rmm::device_uvector all_distances(std::size_t(n_queries) * max_samples, stream, search_mr); - rmm::device_uvector all_indices( - std::size_t(n_queries) * max_samples, stream, search_mr); - - float init_val = - select_min ? std::numeric_limits::max() : std::numeric_limits::lowest(); - thrust::fill_n(raft::resource::get_thrust_policy(handle), - all_distances.data(), - std::size_t(n_queries) * max_samples, - init_val); - thrust::fill_n(raft::resource::get_thrust_policy(handle), - all_indices.data(), - std::size_t(n_queries) * max_samples, - uint32_t(0xFFFFFFFF)); - auto filter_adapter = cuvs::neighbors::filtering::ivf_to_sample_filter( index.inds_ptrs().data_handle(), sample_filter); - ivf_sq_scan(handle, - index, - converted_queries_ptr, - query_norm_dev.data(), - n_queries, - n_probes, - max_samples, - coarse_indices_dev.data(), - chunk_index.data(), - all_distances.data(), - all_indices.data(), - filter_adapter, - stream); + bool manage_local_topk = is_local_topk_feasible(k); + + // Determine grid_dim_x for the fused path + uint32_t grid_dim_x = 0; + if (manage_local_topk) { + if (n_probes > 1) { + // Query the occupancy to compute optimal grid_dim_x (does not launch) + ivf_sq_scan(handle, + index, + converted_queries_ptr, + query_norm_dev.data(), + n_queries, + n_probes, + k, + 0, + coarse_indices_dev.data(), + chunk_index.data(), + nullptr, + nullptr, + filter_adapter, + grid_dim_x, + stream); + } else { + grid_dim_x = 1; + } + } + // Prepare uint32 neighbors buffer for postprocessing rmm::device_uvector neighbors_uint32(0, stream, search_mr); uint32_t* neighbors_uint32_ptr = nullptr; if constexpr (sizeof(int64_t) == sizeof(uint32_t)) { @@ -410,21 +632,110 @@ void search_impl(raft::resources const& handle, neighbors_uint32_ptr = neighbors_uint32.data(); } - auto num_samples_view = - raft::make_device_vector_view(num_samples.data(), n_queries); + if (manage_local_topk) { + // --- Fused top-k path --- + auto target_size = std::size_t(n_queries) * grid_dim_x * k; + rmm::device_uvector distances_tmp(0, stream, search_mr); + rmm::device_uvector indices_tmp(0, stream, search_mr); - cuvs::selection::select_k( - handle, - raft::make_device_matrix_view( - all_distances.data(), n_queries, max_samples), - raft::make_device_matrix_view( - all_indices.data(), n_queries, max_samples), - raft::make_device_matrix_view(distances, n_queries, k), - raft::make_device_matrix_view(neighbors_uint32_ptr, n_queries, k), - select_min, - false, - cuvs::selection::SelectAlgo::kAuto, - num_samples_view); + float* dist_out_ptr = nullptr; + uint32_t* idx_out_ptr = nullptr; + + if (grid_dim_x > 1) { + distances_tmp.resize(target_size, stream); + indices_tmp.resize(target_size, stream); + dist_out_ptr = distances_tmp.data(); + idx_out_ptr = indices_tmp.data(); + } else { + dist_out_ptr = distances; + idx_out_ptr = neighbors_uint32_ptr; + } + + ivf_sq_scan(handle, + index, + converted_queries_ptr, + query_norm_dev.data(), + n_queries, + n_probes, + k, + 0, + coarse_indices_dev.data(), + chunk_index.data(), + dist_out_ptr, + idx_out_ptr, + filter_adapter, + grid_dim_x, + stream); + + // Merge across blocks if needed + if (grid_dim_x > 1) { + auto cols = uint32_t(grid_dim_x) * k; + cuvs::selection::select_k( + handle, + raft::make_device_matrix_view( + distances_tmp.data(), n_queries, cols), + raft::make_device_matrix_view( + indices_tmp.data(), n_queries, cols), + raft::make_device_matrix_view(distances, n_queries, k), + raft::make_device_matrix_view(neighbors_uint32_ptr, n_queries, k), + select_min); + } + } else { + // --- Fallback: materialize all distances --- + uint32_t max_samples = + std::max(static_cast(index.accum_sorted_sizes()(n_probes)), k); + + rmm::device_uvector all_distances( + std::size_t(n_queries) * max_samples, stream, search_mr); + rmm::device_uvector all_indices( + std::size_t(n_queries) * max_samples, stream, search_mr); + + float init_val = + select_min ? std::numeric_limits::max() : std::numeric_limits::lowest(); + thrust::fill_n(raft::resource::get_thrust_policy(handle), + all_distances.data(), + std::size_t(n_queries) * max_samples, + init_val); + thrust::fill_n(raft::resource::get_thrust_policy(handle), + all_indices.data(), + std::size_t(n_queries) * max_samples, + uint32_t(0xFFFFFFFF)); + + // grid_dim_x is unused for the non-fused path; set to n_probes so each + // block in the (n_probes, n_queries) grid processes exactly one probe + uint32_t gdx = n_probes; + ivf_sq_scan(handle, + index, + converted_queries_ptr, + query_norm_dev.data(), + n_queries, + n_probes, + k, + max_samples, + coarse_indices_dev.data(), + chunk_index.data(), + all_distances.data(), + all_indices.data(), + filter_adapter, + gdx, + stream); + + auto num_samples_view = + raft::make_device_vector_view(num_samples.data(), n_queries); + + cuvs::selection::select_k( + handle, + raft::make_device_matrix_view( + all_distances.data(), n_queries, max_samples), + raft::make_device_matrix_view( + all_indices.data(), n_queries, max_samples), + raft::make_device_matrix_view(distances, n_queries, k), + raft::make_device_matrix_view(neighbors_uint32_ptr, n_queries, k), + select_min, + false, + cuvs::selection::SelectAlgo::kAuto, + num_samples_view); + } ivf::detail::postprocess_distances( handle, distances, distances, index.metric(), n_queries, k, 1.0, false); @@ -467,17 +778,33 @@ inline void search_with_filtering(raft::resources const& handle, n_probes); } - uint32_t max_samples = - std::max(static_cast(index.accum_sorted_sizes()(n_probes)), k); + bool manage_local_topk = is_local_topk_feasible(k); + + uint32_t max_samples = 0; + if (!manage_local_topk) { + max_samples = + std::max(static_cast(index.accum_sorted_sizes()(n_probes)), k); + } constexpr uint64_t kExpectedWsSize = 1024ull * 1024 * 1024; uint64_t max_ws_size = std::min(raft::resource::get_workspace_free_bytes(handle), kExpectedWsSize); uint64_t converted_query_floats = std::is_same_v ? 0 : index.dim(); - uint64_t ws_per_query = sizeof(float) * (uint64_t(index.n_lists()) + n_probes + 1 + max_samples + - converted_query_floats) + - sizeof(uint32_t) * (uint64_t(n_probes) * 2 + 1 + max_samples + k); + uint64_t ws_per_query; + if (manage_local_topk) { + // Fused path: only small per-query buffers for coarse search + chunk indices + // (The scan output is at most grid_dim_x * k per query, which is small) + // Conservatively assume grid_dim_x <= n_probes for the workspace estimate + uint64_t fused_out = uint64_t(n_probes) * k; + ws_per_query = sizeof(float) * (uint64_t(index.n_lists()) + n_probes + 1 + fused_out + + converted_query_floats) + + sizeof(uint32_t) * (uint64_t(n_probes) * 2 + 1 + fused_out + k); + } else { + ws_per_query = sizeof(float) * (uint64_t(index.n_lists()) + n_probes + 1 + max_samples + + converted_query_floats) + + sizeof(uint32_t) * (uint64_t(n_probes) * 2 + 1 + max_samples + k); + } const uint32_t max_queries = std::min(n_queries, std::max(1, max_ws_size / ws_per_query)); From ef957f7373256be6a15721404387681028fad06a Mon Sep 17 00:00:00 2001 From: vic Date: Thu, 2 Apr 2026 16:44:51 +0200 Subject: [PATCH 19/31] Add large-k tests for IVF-SQ materialized fallback path --- cpp/tests/neighbors/ann_ivf_sq.cuh | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/cpp/tests/neighbors/ann_ivf_sq.cuh b/cpp/tests/neighbors/ann_ivf_sq.cuh index c4f3c9de74..5cc4e55075 100644 --- a/cpp/tests/neighbors/ann_ivf_sq.cuh +++ b/cpp/tests/neighbors/ann_ivf_sq.cuh @@ -398,6 +398,15 @@ const std::vector> inputs = { {1000, 10000, 16, 100, 200, 1024, cuvs::distance::DistanceType::L2Expanded}, {1000, 10000, 16, 100, 200, 1024, cuvs::distance::DistanceType::InnerProduct}, + // ===== Large k (beyond fused top-k kMaxSqScanCapacity=256, exercises materialized fallback) ===== + // k=257: smallest k that forces the materialized path (Capacity clamped to 0) + {100, 10000, 32, 257, 100, 64, cuvs::distance::DistanceType::L2Expanded}, + {100, 10000, 32, 257, 100, 64, cuvs::distance::DistanceType::InnerProduct}, + {100, 10000, 32, 257, 100, 64, cuvs::distance::DistanceType::CosineExpanded}, + // k=300: comfortably above the fused top-k threshold + {100, 10000, 32, 300, 64, 64, cuvs::distance::DistanceType::L2Expanded}, + {100, 10000, 32, 300, 64, 64, cuvs::distance::DistanceType::InnerProduct}, + // ===== nprobe / nlist edge cases ===== // nprobe == nlist (exhaustive probe) {1000, 10000, 16, 10, 64, 64, cuvs::distance::DistanceType::L2Expanded}, From 56ebfc9508caf20b0b592b1d8572577d21b965c7 Mon Sep 17 00:00:00 2001 From: vic Date: Thu, 2 Apr 2026 16:53:06 +0200 Subject: [PATCH 20/31] Improve shared memory synchronization in IVF-SQ scan kernel --- cpp/src/neighbors/ivf_sq/ivf_sq_search.cuh | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/cpp/src/neighbors/ivf_sq/ivf_sq_search.cuh b/cpp/src/neighbors/ivf_sq/ivf_sq_search.cuh index c79e043e20..c7b9d65e20 100644 --- a/cpp/src/neighbors/ivf_sq/ivf_sq_search.cuh +++ b/cpp/src/neighbors/ivf_sq/ivf_sq_search.cuh @@ -154,6 +154,7 @@ __launch_bounds__(BlockDim) RAFT_KERNEL s_sq_scale[d] = sq_delta[d]; if constexpr (!kIsL2) { s_query_term[d] = query[d]; } } + __syncthreads(); using local_topk_t = sq_block_sort_t; local_topk_t queue(k); @@ -167,6 +168,15 @@ __launch_bounds__(BlockDim) RAFT_KERNEL const uint32_t lane_id = threadIdx.x % raft::WarpSize; // --- Phase 2: loop over probes --- + // Synchronization protocol: + // (a) __syncthreads after Phase 1 (above) ensures s_sq_scale / s_query_term + // are visible before any probe iteration overwrites s_query_term or + // s_recon_base. + // (b) __syncthreads after per-probe smem writes (below) ensures + // probe-specific values are visible before the distance computation. + // (c) __syncthreads at the end of each iteration ensures all distance + // computation reads are complete before the next iteration overwrites + // the same smem regions. for (uint32_t probe_ix = blockIdx.x; probe_ix < n_probes; probe_ix += (kManageLocalTopK ? gridDim.x : uint32_t{1})) { const uint32_t cluster_id = my_coarse[probe_ix]; @@ -183,9 +193,11 @@ __launch_bounds__(BlockDim) RAFT_KERNEL } } } - __syncthreads(); + __syncthreads(); // (b) if (cluster_sz == 0) { + // No distance computation reads happened, so no end-of-iteration + // barrier is needed; the next iteration's barrier (b) is sufficient. if constexpr (!kManageLocalTopK) break; continue; } @@ -245,11 +257,15 @@ __launch_bounds__(BlockDim) RAFT_KERNEL } } - __syncthreads(); + __syncthreads(); // (c) if constexpr (!kManageLocalTopK) break; } if constexpr (kManageLocalTopK) { + // All probe iterations are done; smem_buf is reused for block_sort merge. + // The loop's last (b) or (c) barrier ensures all prior smem accesses have + // completed, so this additional barrier is only needed to synchronize any + // register-level state across warps before the merge. __syncthreads(); queue.done(smem_buf); queue.store(out_distances, out_indices); From 15b2f158378251104d2598577ef033e300ee50ca Mon Sep 17 00:00:00 2001 From: vic Date: Thu, 2 Apr 2026 18:41:05 +0200 Subject: [PATCH 21/31] IVF-SQ scan: reduce L2 global reads and refine fused top-k capacity selection --- cpp/src/neighbors/ivf_sq/ivf_sq_search.cuh | 240 ++++++++++++--------- 1 file changed, 138 insertions(+), 102 deletions(-) diff --git a/cpp/src/neighbors/ivf_sq/ivf_sq_search.cuh b/cpp/src/neighbors/ivf_sq/ivf_sq_search.cuh index c7b9d65e20..17e6c4f9eb 100644 --- a/cpp/src/neighbors/ivf_sq/ivf_sq_search.cuh +++ b/cpp/src/neighbors/ivf_sq/ivf_sq_search.cuh @@ -40,11 +40,13 @@ static constexpr int kSqScanThreads = 128; // Maximum fused top-k capacity we instantiate for the scan kernel. // Must match the highest Capacity case in ivf_sq_scan's switch. static constexpr int kMaxSqScanCapacity = 256; +static_assert(kMaxSqScanCapacity <= raft::matrix::detail::select::warpsort::kMaxCapacity, + "kMaxSqScanCapacity must not exceed the warpsort library's maximum supported " + "capacity; reduce kMaxSqScanCapacity or update the warpsort dependency."); auto RAFT_WEAK_FUNCTION is_local_topk_feasible(uint32_t k) -> bool { - return k <= kMaxSqScanCapacity && - k <= raft::matrix::detail::select::warpsort::kMaxCapacity; + return k <= kMaxSqScanCapacity; } // --------------------------------------------------------------------------- @@ -71,11 +73,8 @@ using sq_block_sort_t = typename sq_block_sort::type; // --------------------------------------------------------------------------- // configure_grid_dim_x: choose grid.x to saturate the GPU // --------------------------------------------------------------------------- -inline uint32_t configure_grid_dim_x(uint32_t n_queries, - uint32_t n_probes, - int smem_size, - int block_size, - const void* kernel_ptr) +inline uint32_t configure_grid_dim_x( + uint32_t n_queries, uint32_t n_probes, int smem_size, int block_size, const void* kernel_ptr) { int dev_id; RAFT_CUDA_TRY(cudaGetDevice(&dev_id)); @@ -99,30 +98,42 @@ inline uint32_t configure_grid_dim_x(uint32_t n_queries, // otherwise (Capacity == 0): // grid (n_probes, n_queries) — one block per (query, probe) // -// Shared-memory layout: [s_sq_scale(dim) | s_query_term(dim) | s_recon_base(dim)?] -// s_sq_scale is loaded once (invariant across probes). -// For L2, s_query_term is reloaded per probe. -// For IP/Cosine, s_query_term is loaded once and s_recon_base is reloaded per probe. +// Shared-memory layout (always 3 × dim floats): +// [s_sq_scale(dim) | s_query_term(dim) | s_aux(dim)] +// +// s_sq_scale = delta[d] — SQ dequantization scale, invariant (Phase 1). +// +// L2 path: +// Phase 1: s_aux[d] = query[d] - vmin[d] (invariant) +// Phase 2: s_query_term[d] = s_aux[d] - centroid[d] (per-probe) +// The full SQ reconstruction is centroid + vmin + code*delta, so +// query - reconstructed = (query - vmin - centroid) - code*delta +// = s_query_term - code*s_sq_scale. +// +// IP/Cosine path: +// Phase 1: s_query_term[d] = query[d] (invariant) +// Phase 2: s_aux[d] = centroid[d] + vmin[d] (per-probe) +// Reconstructed vector component: s_aux[d] + code*s_sq_scale[d]. +// // After all probes are scanned, the smem is reused for block_sort merge. // --------------------------------------------------------------------------- template -__launch_bounds__(BlockDim) RAFT_KERNEL - ivf_sq_scan_kernel(const uint8_t* const* data_ptrs, - const uint32_t* list_sizes, - const uint32_t* coarse_indices, - const float* queries_float, - const float* centers, - const float* sq_vmin, - const float* sq_delta, - const float* query_norms, - uint32_t n_probes, - uint32_t dim, - uint32_t k, - uint32_t max_samples, - const uint32_t* chunk_indices, - float* out_distances, - uint32_t* out_indices, - IvfSampleFilterT sample_filter) +__launch_bounds__(BlockDim) RAFT_KERNEL ivf_sq_scan_kernel(const uint8_t* const* data_ptrs, + const uint32_t* list_sizes, + const uint32_t* coarse_indices, + const float* queries_float, + const float* centers, + const float* sq_vmin, + const float* sq_delta, + const float* query_norms, + uint32_t n_probes, + uint32_t dim, + uint32_t k, + uint32_t max_samples, + const uint32_t* chunk_indices, + float* out_distances, + uint32_t* out_indices, + IvfSampleFilterT sample_filter) { static_assert(kIndexGroupSize == raft::WarpSize, "Warp-coalesced scan requires kIndexGroupSize == WarpSize"); @@ -135,10 +146,9 @@ __launch_bounds__(BlockDim) RAFT_KERNEL extern __shared__ __align__(256) uint8_t smem_buf[]; float* smem = reinterpret_cast(smem_buf); - // smem layout: [s_sq_scale | s_query_term | (s_recon_base for IP/Cosine)] float* s_sq_scale = smem; float* s_query_term = smem + dim; - float* s_recon_base = kIsL2 ? nullptr : (smem + 2 * dim); + float* s_aux = smem + 2 * dim; const uint32_t query_ix = blockIdx.y; const float* query = queries_float + query_ix * dim; @@ -152,7 +162,11 @@ __launch_bounds__(BlockDim) RAFT_KERNEL // --- Phase 1: load shared memory that is invariant across probes --- for (uint32_t d = threadIdx.x; d < dim; d += BlockDim) { s_sq_scale[d] = sq_delta[d]; - if constexpr (!kIsL2) { s_query_term[d] = query[d]; } + if constexpr (kIsL2) { + s_aux[d] = query[d] - sq_vmin[d]; + } else { + s_query_term[d] = query[d]; + } } __syncthreads(); @@ -169,14 +183,18 @@ __launch_bounds__(BlockDim) RAFT_KERNEL // --- Phase 2: loop over probes --- // Synchronization protocol: - // (a) __syncthreads after Phase 1 (above) ensures s_sq_scale / s_query_term - // are visible before any probe iteration overwrites s_query_term or - // s_recon_base. - // (b) __syncthreads after per-probe smem writes (below) ensures - // probe-specific values are visible before the distance computation. + // (a) __syncthreads after Phase 1 (above) ensures invariant smem arrays + // (s_sq_scale, and L2: s_aux / IP-Cosine: s_query_term) are visible + // before Phase 2 overwrites the per-probe array. + // (b) __syncthreads after per-probe smem writes (L2: s_query_term / + // IP-Cosine: s_aux) ensures probe-specific values are visible before + // the distance computation. // (c) __syncthreads at the end of each iteration ensures all distance // computation reads are complete before the next iteration overwrites - // the same smem regions. + // the per-probe smem region. + // When cluster_sz == 0, barrier (c) is skipped because no distance reads + // occurred; all threads converge on the same branch uniformly, and the + // next iteration's barrier (b) provides the needed ordering. for (uint32_t probe_ix = blockIdx.x; probe_ix < n_probes; probe_ix += (kManageLocalTopK ? gridDim.x : uint32_t{1})) { const uint32_t cluster_id = my_coarse[probe_ix]; @@ -187,9 +205,9 @@ __launch_bounds__(BlockDim) RAFT_KERNEL const float* centroid = centers + cluster_id * dim; for (uint32_t d = threadIdx.x; d < dim; d += BlockDim) { if constexpr (kIsL2) { - s_query_term[d] = query[d] - centroid[d] - sq_vmin[d]; + s_query_term[d] = s_aux[d] - centroid[d]; } else { - s_recon_base[d] = centroid[d] + sq_vmin[d]; + s_aux[d] = centroid[d] + sq_vmin[d]; } } } @@ -202,10 +220,10 @@ __launch_bounds__(BlockDim) RAFT_KERNEL continue; } - const uint8_t* codes = data_ptrs[cluster_id]; - uint32_t sample_offset = (probe_ix > 0) ? my_chunk[probe_ix - 1] : 0; - uint32_t padded_dim = ((dim + veclen - 1) / veclen) * veclen; - uint32_t n_dim_blocks = padded_dim / veclen; + const uint8_t* codes = data_ptrs[cluster_id]; + uint32_t sample_offset = (probe_ix > 0) ? my_chunk[probe_ix - 1] : 0; + uint32_t padded_dim = ((dim + veclen - 1) / veclen) * veclen; + uint32_t n_dim_blocks = padded_dim / veclen; for (uint32_t group = warp_id * kIndexGroupSize; group < cluster_sz; group += kWarpsPerBlock * kIndexGroupSize) { @@ -232,7 +250,7 @@ __launch_bounds__(BlockDim) RAFT_KERNEL float diff = s_query_term[l + j] - recon; dist += diff * diff; } else { - float v_d = s_recon_base[l + j] + recon; + float v_d = s_aux[l + j] + recon; dist += s_query_term[l + j] * v_d; if constexpr (kIsCosine) { v_norm_sq += v_d * v_d; } } @@ -288,15 +306,12 @@ __launch_bounds__(BlockDim) RAFT_KERNEL // --------------------------------------------------------------------------- // Compute shared-memory size for a given kernel configuration // --------------------------------------------------------------------------- -inline size_t sq_scan_smem_size(uint32_t dim, SqScanMetric metric) -{ - return (metric == SqScanMetric::kL2 ? 2 : 3) * dim * sizeof(float); -} +inline size_t sq_scan_smem_size(uint32_t dim) { return 3 * dim * sizeof(float); } template -size_t sq_scan_total_smem(uint32_t dim, uint32_t k, SqScanMetric metric) +size_t sq_scan_total_smem(uint32_t dim, uint32_t k) { - size_t scan_smem = sq_scan_smem_size(dim, metric); + size_t scan_smem = sq_scan_smem_size(dim); if constexpr (Capacity > 0) { constexpr int kSubwarpSize = std::min(Capacity, raft::WarpSize); int num_subwarps = kSqScanThreads / kSubwarpSize; @@ -331,10 +346,25 @@ void ivf_sq_scan_launch(const index& idx, constexpr int kThreads = kSqScanThreads; uint32_t dim = idx.dim(); - constexpr uint32_t kMaxGridY = 32768; + constexpr uint32_t kMaxGridY = 65535; + + auto do_launch = [&](auto kernel_ptr) { + size_t smem = sq_scan_total_smem(dim, k); - auto do_launch = [&](auto kernel_ptr, SqScanMetric metric_val) { - size_t smem = sq_scan_total_smem(dim, k, metric_val); + { + int dev_id; + RAFT_CUDA_TRY(cudaGetDevice(&dev_id)); + int max_smem; + RAFT_CUDA_TRY( + cudaDeviceGetAttribute(&max_smem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev_id)); + RAFT_EXPECTS(smem <= size_t(max_smem), + "IVF-SQ scan kernel requires %zu bytes of shared memory (dim=%u, k=%u), " + "but the device supports at most %d bytes per block.", + smem, + dim, + k, + max_smem); + } RAFT_CUDA_TRY( cudaFuncSetAttribute(kernel_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, smem)); @@ -342,9 +372,11 @@ void ivf_sq_scan_launch(const index& idx, // If grid_dim_x == 0, compute the optimal value and return if constexpr (kManageLocalTopK) { if (grid_dim_x == 0) { - grid_dim_x = configure_grid_dim_x( - std::min(kMaxGridY, n_queries), n_probes, smem, kThreads, - reinterpret_cast(kernel_ptr)); + grid_dim_x = configure_grid_dim_x(std::min(kMaxGridY, n_queries), + n_probes, + smem, + kThreads, + reinterpret_cast(kernel_ptr)); return; } } @@ -393,19 +425,14 @@ void ivf_sq_scan_launch(const index& idx, switch (idx.metric()) { case cuvs::distance::DistanceType::L2Expanded: case cuvs::distance::DistanceType::L2SqrtExpanded: - do_launch( - ivf_sq_scan_kernel, - SqScanMetric::kL2); + do_launch(ivf_sq_scan_kernel); break; case cuvs::distance::DistanceType::InnerProduct: - do_launch( - ivf_sq_scan_kernel, - SqScanMetric::kIP); + do_launch(ivf_sq_scan_kernel); break; case cuvs::distance::DistanceType::CosineExpanded: do_launch( - ivf_sq_scan_kernel, - SqScanMetric::kCosine); + ivf_sq_scan_kernel); break; default: RAFT_FAIL("Unsupported metric type for IVF-SQ scan."); } @@ -434,13 +461,13 @@ void ivf_sq_scan(raft::resources const& handle, // Determine the fused top-k capacity (0 = disabled / fallback to materialization) int capacity = is_local_topk_feasible(k) ? raft::bound_by_power_of_two(int(k)) : 0; - // Clamp to supported compile-time Capacity values. - // Using a limited set to avoid excessive template instantiations. - if (capacity > 0 && capacity <= 32) { + // Snap to the nearest supported compile-time Capacity value (must be a + // power of two). Values up to 32 share one instantiation; 64, 128 and 256 + // each get their own. Beyond kMaxSqScanCapacity we fall back to the + // non-fused path (Capacity == 0). + if (capacity > 0 && capacity < 32) { capacity = 32; - } else if (capacity > 32 && capacity <= 256) { - capacity = 256; - } else if (capacity > 256) { + } else if (capacity > kMaxSqScanCapacity) { capacity = 0; } @@ -464,6 +491,8 @@ void ivf_sq_scan(raft::resources const& handle, switch (capacity) { case 0: fwd(std::integral_constant{}); break; case 32: fwd(std::integral_constant{}); break; + case 64: fwd(std::integral_constant{}); break; + case 128: fwd(std::integral_constant{}); break; case 256: fwd(std::integral_constant{}); break; default: RAFT_FAIL("Unexpected capacity value %d", capacity); } @@ -490,7 +519,8 @@ void search_impl(raft::resources const& handle, std::size_t n_queries_probes = std::size_t(n_queries) * std::size_t(n_probes); - rmm::device_uvector query_norm_dev(n_queries, stream, search_mr); + bool needs_query_norms = index.metric() != cuvs::distance::DistanceType::InnerProduct; + rmm::device_uvector query_norm_dev(needs_query_norms ? n_queries : 0, stream, search_mr); rmm::device_uvector distance_buffer_dev(n_queries * index.n_lists(), stream, search_mr); rmm::device_uvector coarse_distances_dev(n_queries_probes, stream, search_mr); rmm::device_uvector coarse_indices_dev(n_queries_probes, stream, search_mr); @@ -581,7 +611,7 @@ void search_impl(raft::resources const& handle, raft::linalg::map_offset( handle, distance_buffer_dev_view, - [=] __device__(const uint32_t idx, const float dist) { + [=] __device__(const int64_t idx, const float dist) { const auto query = idx / n_lists_local; const auto cluster = idx % n_lists_local; float denom = q_norm_ptr[query] * center_norm_ptr[cluster]; @@ -616,25 +646,29 @@ void search_impl(raft::resources const& handle, // Determine grid_dim_x for the fused path uint32_t grid_dim_x = 0; if (manage_local_topk) { - if (n_probes > 1) { - // Query the occupancy to compute optimal grid_dim_x (does not launch) - ivf_sq_scan(handle, - index, - converted_queries_ptr, - query_norm_dev.data(), - n_queries, - n_probes, - k, - 0, - coarse_indices_dev.data(), - chunk_index.data(), - nullptr, - nullptr, - filter_adapter, - grid_dim_x, - stream); - } else { - grid_dim_x = 1; + // Query the occupancy to compute optimal grid_dim_x (does not launch) + ivf_sq_scan(handle, + index, + converted_queries_ptr, + query_norm_dev.data(), + n_queries, + n_probes, + k, + 0, + coarse_indices_dev.data(), + chunk_index.data(), + nullptr, + nullptr, + filter_adapter, + grid_dim_x, + stream); + if (grid_dim_x == 0) { + manage_local_topk = false; + RAFT_LOG_WARN( + "IVF-SQ fused top-k kernel has zero occupancy (dim=%u, k=%u); " + "falling back to the non-fused scan path.", + index.dim(), + k); } } @@ -654,8 +688,8 @@ void search_impl(raft::resources const& handle, rmm::device_uvector distances_tmp(0, stream, search_mr); rmm::device_uvector indices_tmp(0, stream, search_mr); - float* dist_out_ptr = nullptr; - uint32_t* idx_out_ptr = nullptr; + float* dist_out_ptr = nullptr; + uint32_t* idx_out_ptr = nullptr; if (grid_dim_x > 1) { distances_tmp.resize(target_size, stream); @@ -688,18 +722,18 @@ void search_impl(raft::resources const& handle, auto cols = uint32_t(grid_dim_x) * k; cuvs::selection::select_k( handle, - raft::make_device_matrix_view( - distances_tmp.data(), n_queries, cols), - raft::make_device_matrix_view( - indices_tmp.data(), n_queries, cols), + raft::make_device_matrix_view(distances_tmp.data(), n_queries, cols), + raft::make_device_matrix_view(indices_tmp.data(), n_queries, cols), raft::make_device_matrix_view(distances, n_queries, k), raft::make_device_matrix_view(neighbors_uint32_ptr, n_queries, k), select_min); } } else { // --- Fallback: materialize all distances --- - uint32_t max_samples = - std::max(static_cast(index.accum_sorted_sizes()(n_probes)), k); + int64_t ms = std::max(index.accum_sorted_sizes()(n_probes), k); + RAFT_EXPECTS(ms <= int64_t(std::numeric_limits::max()), + "The maximum sample size is too big."); + uint32_t max_samples = static_cast(ms); rmm::device_uvector all_distances( std::size_t(n_queries) * max_samples, stream, search_mr); @@ -798,8 +832,10 @@ inline void search_with_filtering(raft::resources const& handle, uint32_t max_samples = 0; if (!manage_local_topk) { - max_samples = - std::max(static_cast(index.accum_sorted_sizes()(n_probes)), k); + int64_t ms = std::max(index.accum_sorted_sizes()(n_probes), k); + RAFT_EXPECTS(ms <= int64_t(std::numeric_limits::max()), + "The maximum sample size is too big."); + max_samples = static_cast(ms); } constexpr uint64_t kExpectedWsSize = 1024ull * 1024 * 1024; @@ -813,7 +849,7 @@ inline void search_with_filtering(raft::resources const& handle, // (The scan output is at most grid_dim_x * k per query, which is small) // Conservatively assume grid_dim_x <= n_probes for the workspace estimate uint64_t fused_out = uint64_t(n_probes) * k; - ws_per_query = sizeof(float) * (uint64_t(index.n_lists()) + n_probes + 1 + fused_out + + ws_per_query = sizeof(float) * (uint64_t(index.n_lists()) + n_probes + 1 + fused_out + converted_query_floats) + sizeof(uint32_t) * (uint64_t(n_probes) * 2 + 1 + fused_out + k); } else { From 3a3427f5ca5de84086c920172f8f244c36702c7f Mon Sep 17 00:00:00 2001 From: vic Date: Tue, 7 Apr 2026 14:06:49 +0200 Subject: [PATCH 22/31] Addressing review (tests updates) --- cpp/tests/CMakeLists.txt | 2 +- cpp/tests/neighbors/ann_ivf_sq.cuh | 317 ++++++++++-------- .../ann_ivf_sq/test_float_int64_t.cu | 20 ++ ..._float_uint8_t.cu => test_half_int64_t.cu} | 18 +- 4 files changed, 207 insertions(+), 150 deletions(-) create mode 100644 cpp/tests/neighbors/ann_ivf_sq/test_float_int64_t.cu rename cpp/tests/neighbors/ann_ivf_sq/{test_float_uint8_t.cu => test_half_int64_t.cu} (55%) diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 59a5f1fb70..eea4d7f775 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -135,7 +135,7 @@ ConfigureTest( ConfigureTest( NAME NEIGHBORS_ANN_IVF_SQ_TEST - PATH neighbors/ann_ivf_sq/test_float_uint8_t.cu + PATH neighbors/ann_ivf_sq/test_float_int64_t.cu neighbors/ann_ivf_sq/test_half_int64_t.cu GPUS 1 PERCENT 100 ) diff --git a/cpp/tests/neighbors/ann_ivf_sq.cuh b/cpp/tests/neighbors/ann_ivf_sq.cuh index 5cc4e55075..f7311b75e4 100644 --- a/cpp/tests/neighbors/ann_ivf_sq.cuh +++ b/cpp/tests/neighbors/ann_ivf_sq.cuh @@ -57,143 +57,67 @@ class AnnIVFSQTest : public ::testing::TestWithParam> { { } - void testIVFSQ() + void testSearch() { - size_t queries_size = ps.num_queries * ps.k; - std::vector indices_ivfsq(queries_size); - std::vector indices_naive(queries_size); - std::vector distances_ivfsq(queries_size); - std::vector distances_naive(queries_size); - - { - rmm::device_uvector distances_naive_dev(queries_size, stream_); - rmm::device_uvector indices_naive_dev(queries_size, stream_); - cuvs::neighbors::naive_knn(handle_, - distances_naive_dev.data(), - indices_naive_dev.data(), - search_queries.data(), - database.data(), - ps.num_queries, - ps.num_db_vecs, - ps.dim, - ps.k, - ps.metric); - raft::update_host(distances_naive.data(), distances_naive_dev.data(), queries_size, stream_); - raft::update_host(indices_naive.data(), indices_naive_dev.data(), queries_size, stream_); - raft::resource::sync_stream(handle_); - } - - { - double min_recall = - std::min(1.0, static_cast(ps.nprobe) / static_cast(ps.nlist)); - - rmm::device_uvector distances_ivfsq_dev(queries_size, stream_); - rmm::device_uvector indices_ivfsq_dev(queries_size, stream_); - - { - cuvs::neighbors::ivf_sq::index_params index_params; - cuvs::neighbors::ivf_sq::search_params search_params; - index_params.n_lists = ps.nlist; - index_params.metric = ps.metric; - search_params.n_probes = ps.nprobe; - - index_params.add_data_on_build = true; - index_params.kmeans_trainset_fraction = 0.5; - - cuvs::neighbors::ivf_sq::index_params index_params_no_add; - index_params_no_add.n_lists = ps.nlist; - index_params_no_add.metric = ps.metric; - index_params_no_add.add_data_on_build = false; - index_params_no_add.kmeans_trainset_fraction = 0.5; - - cuvs::neighbors::ivf_sq::index idx(handle_); - cuvs::neighbors::ivf_sq::index idx_empty(handle_); - - if (!ps.host_dataset) { - auto database_view = raft::make_device_matrix_view( - (const DataT*)database.data(), ps.num_db_vecs, ps.dim); - - idx = cuvs::neighbors::ivf_sq::build(handle_, index_params, database_view); - - idx_empty = cuvs::neighbors::ivf_sq::build(handle_, index_params_no_add, database_view); - - auto vector_indices = raft::make_device_vector(handle_, ps.num_db_vecs); - raft::linalg::map_offset(handle_, vector_indices.view(), raft::identity_op{}); - raft::resource::sync_stream(handle_); - - auto indices_view = raft::make_device_vector_view( - vector_indices.data_handle(), ps.num_db_vecs); - cuvs::neighbors::ivf_sq::extend( - handle_, - database_view, - std::make_optional>(indices_view), - &idx_empty); - } else { - auto host_database = raft::make_host_matrix(ps.num_db_vecs, ps.dim); - raft::copy( - host_database.data_handle(), database.data(), ps.num_db_vecs * ps.dim, stream_); - raft::resource::sync_stream(handle_); - - idx = cuvs::neighbors::ivf_sq::build( - handle_, index_params, raft::make_const_mdspan(host_database.view())); - - idx_empty = cuvs::neighbors::ivf_sq::build( - handle_, index_params_no_add, raft::make_const_mdspan(host_database.view())); - - auto vector_indices = raft::make_host_vector(handle_, ps.num_db_vecs); - std::iota( - vector_indices.data_handle(), vector_indices.data_handle() + ps.num_db_vecs, IdxT(0)); - - auto indices_view = raft::make_host_vector_view( - vector_indices.data_handle(), ps.num_db_vecs); - auto host_database_view = raft::make_host_matrix_view( - host_database.data_handle(), ps.num_db_vecs, ps.dim); - cuvs::neighbors::ivf_sq::extend( - handle_, - host_database_view, - std::make_optional>(indices_view), - &idx_empty); - } - - // Serialize / deserialize round-trip - tmp_index_file index_file; - cuvs::neighbors::ivf_sq::serialize(handle_, index_file.filename, idx); - cuvs::neighbors::ivf_sq::index index_loaded(handle_); - cuvs::neighbors::ivf_sq::deserialize(handle_, index_file.filename, &index_loaded); - ASSERT_EQ(idx.size(), index_loaded.size()); - ASSERT_EQ(idx.dim(), index_loaded.dim()); - ASSERT_EQ(idx.n_lists(), index_loaded.n_lists()); - - auto search_queries_view = raft::make_device_matrix_view( - search_queries.data(), ps.num_queries, ps.dim); - auto indices_out_view = - raft::make_device_matrix_view(indices_ivfsq_dev.data(), ps.num_queries, ps.k); - auto dists_out_view = - raft::make_device_matrix_view(distances_ivfsq_dev.data(), ps.num_queries, ps.k); + auto naive = compute_naive_knn(); + auto idx = build_index(true); + auto results = search_index(idx); + + float eps = 0.1; + ASSERT_TRUE(eval_neighbours(naive.indices, + results.indices, + naive.distances, + results.distances, + ps.num_queries, + ps.k, + eps, + min_recall_threshold())); + } - cuvs::neighbors::ivf_sq::search(handle_, - search_params, - index_loaded, - search_queries_view, - indices_out_view, - dists_out_view); + void testSerialize() + { + auto idx = build_index(true); + + tmp_index_file index_file; + cuvs::neighbors::ivf_sq::serialize(handle_, index_file.filename, idx); + cuvs::neighbors::ivf_sq::index index_loaded(handle_); + cuvs::neighbors::ivf_sq::deserialize(handle_, index_file.filename, &index_loaded); + + ASSERT_EQ(idx.size(), index_loaded.size()); + ASSERT_EQ(idx.dim(), index_loaded.dim()); + ASSERT_EQ(idx.n_lists(), index_loaded.n_lists()); + + auto results_orig = search_index(idx); + auto results_loaded = search_index(index_loaded); + + float eps = 0.001; + ASSERT_TRUE(eval_neighbours(results_orig.indices, + results_loaded.indices, + results_orig.distances, + results_loaded.distances, + ps.num_queries, + ps.k, + eps, + 1.0)); + } - raft::update_host( - distances_ivfsq.data(), distances_ivfsq_dev.data(), queries_size, stream_); - raft::update_host(indices_ivfsq.data(), indices_ivfsq_dev.data(), queries_size, stream_); - raft::resource::sync_stream(handle_); - } - // SQ introduces quantization error, so we relax the distance epsilon - float eps = 0.1; - ASSERT_TRUE(eval_neighbours(indices_naive, - indices_ivfsq, - distances_naive, - distances_ivfsq, - ps.num_queries, - ps.k, - eps, - min_recall)); - } + void testExtend() + { + auto naive = compute_naive_knn(); + auto idx_empty = build_index(false); + extend_index(&idx_empty); + + auto results = search_index(idx_empty); + + float eps = 0.1; + ASSERT_TRUE(eval_neighbours(naive.indices, + results.indices, + naive.distances, + results.distances, + ps.num_queries, + ps.k, + eps, + min_recall_threshold())); } void testFilter() @@ -318,6 +242,128 @@ class AnnIVFSQTest : public ::testing::TestWithParam> { } private: + struct SearchResults { + std::vector indices; + std::vector distances; + }; + + double min_recall_threshold() + { + return std::min(1.0, static_cast(ps.nprobe) / static_cast(ps.nlist)); + } + + SearchResults compute_naive_knn() + { + size_t queries_size = ps.num_queries * ps.k; + SearchResults results; + results.indices.resize(queries_size); + results.distances.resize(queries_size); + + rmm::device_uvector distances_dev(queries_size, stream_); + rmm::device_uvector indices_dev(queries_size, stream_); + cuvs::neighbors::naive_knn(handle_, + distances_dev.data(), + indices_dev.data(), + search_queries.data(), + database.data(), + ps.num_queries, + ps.num_db_vecs, + ps.dim, + ps.k, + ps.metric); + raft::update_host(results.distances.data(), distances_dev.data(), queries_size, stream_); + raft::update_host(results.indices.data(), indices_dev.data(), queries_size, stream_); + raft::resource::sync_stream(handle_); + return results; + } + + cuvs::neighbors::ivf_sq::index build_index(bool add_data_on_build) + { + cuvs::neighbors::ivf_sq::index_params index_params; + index_params.n_lists = ps.nlist; + index_params.metric = ps.metric; + index_params.add_data_on_build = add_data_on_build; + index_params.kmeans_trainset_fraction = 0.5; + + if (!ps.host_dataset) { + auto database_view = raft::make_device_matrix_view( + (const DataT*)database.data(), ps.num_db_vecs, ps.dim); + return cuvs::neighbors::ivf_sq::build(handle_, index_params, database_view); + } else { + auto host_database = raft::make_host_matrix(ps.num_db_vecs, ps.dim); + raft::copy(host_database.data_handle(), database.data(), ps.num_db_vecs * ps.dim, stream_); + raft::resource::sync_stream(handle_); + return cuvs::neighbors::ivf_sq::build( + handle_, index_params, raft::make_const_mdspan(host_database.view())); + } + } + + void extend_index(cuvs::neighbors::ivf_sq::index* idx) + { + if (!ps.host_dataset) { + auto database_view = raft::make_device_matrix_view( + (const DataT*)database.data(), ps.num_db_vecs, ps.dim); + auto vector_indices = raft::make_device_vector(handle_, ps.num_db_vecs); + raft::linalg::map_offset(handle_, vector_indices.view(), raft::identity_op{}); + raft::resource::sync_stream(handle_); + + auto indices_view = raft::make_device_vector_view( + vector_indices.data_handle(), ps.num_db_vecs); + cuvs::neighbors::ivf_sq::extend( + handle_, + database_view, + std::make_optional>(indices_view), + idx); + } else { + auto host_database = raft::make_host_matrix(ps.num_db_vecs, ps.dim); + raft::copy(host_database.data_handle(), database.data(), ps.num_db_vecs * ps.dim, stream_); + raft::resource::sync_stream(handle_); + + auto vector_indices = raft::make_host_vector(handle_, ps.num_db_vecs); + std::iota( + vector_indices.data_handle(), vector_indices.data_handle() + ps.num_db_vecs, IdxT(0)); + + auto indices_view = + raft::make_host_vector_view(vector_indices.data_handle(), ps.num_db_vecs); + auto host_database_view = raft::make_host_matrix_view( + host_database.data_handle(), ps.num_db_vecs, ps.dim); + cuvs::neighbors::ivf_sq::extend( + handle_, + host_database_view, + std::make_optional>(indices_view), + idx); + } + } + + SearchResults search_index(const cuvs::neighbors::ivf_sq::index& idx) + { + size_t queries_size = ps.num_queries * ps.k; + SearchResults results; + results.indices.resize(queries_size); + results.distances.resize(queries_size); + + cuvs::neighbors::ivf_sq::search_params search_params; + search_params.n_probes = ps.nprobe; + + rmm::device_uvector distances_dev(queries_size, stream_); + rmm::device_uvector indices_dev(queries_size, stream_); + + auto search_queries_view = raft::make_device_matrix_view( + search_queries.data(), ps.num_queries, ps.dim); + auto indices_out_view = + raft::make_device_matrix_view(indices_dev.data(), ps.num_queries, ps.k); + auto dists_out_view = + raft::make_device_matrix_view(distances_dev.data(), ps.num_queries, ps.k); + + cuvs::neighbors::ivf_sq::search( + handle_, search_params, idx, search_queries_view, indices_out_view, dists_out_view); + + raft::update_host(results.distances.data(), distances_dev.data(), queries_size, stream_); + raft::update_host(results.indices.data(), indices_dev.data(), queries_size, stream_); + raft::resource::sync_stream(handle_); + return results; + } + raft::resources handle_; rmm::cuda_stream_view stream_; AnnIvfSqInputs ps; @@ -398,7 +444,8 @@ const std::vector> inputs = { {1000, 10000, 16, 100, 200, 1024, cuvs::distance::DistanceType::L2Expanded}, {1000, 10000, 16, 100, 200, 1024, cuvs::distance::DistanceType::InnerProduct}, - // ===== Large k (beyond fused top-k kMaxSqScanCapacity=256, exercises materialized fallback) ===== + // ===== Large k (beyond fused top-k kMaxSqScanCapacity=256, exercises materialized fallback) + // ===== // k=257: smallest k that forces the materialized path (Capacity clamped to 0) {100, 10000, 32, 257, 100, 64, cuvs::distance::DistanceType::L2Expanded}, {100, 10000, 32, 257, 100, 64, cuvs::distance::DistanceType::InnerProduct}, diff --git a/cpp/tests/neighbors/ann_ivf_sq/test_float_int64_t.cu b/cpp/tests/neighbors/ann_ivf_sq/test_float_int64_t.cu new file mode 100644 index 0000000000..734f736b09 --- /dev/null +++ b/cpp/tests/neighbors/ann_ivf_sq/test_float_int64_t.cu @@ -0,0 +1,20 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include + +#include "../ann_ivf_sq.cuh" + +namespace cuvs::neighbors::ivf_sq { + +typedef AnnIVFSQTest AnnIVFSQTestF_float; +TEST_P(AnnIVFSQTestF_float, AnnIVFSQSearch) { this->testSearch(); } +TEST_P(AnnIVFSQTestF_float, AnnIVFSQSerialize) { this->testSerialize(); } +TEST_P(AnnIVFSQTestF_float, AnnIVFSQExtend) { this->testExtend(); } +TEST_P(AnnIVFSQTestF_float, AnnIVFSQFilter) { this->testFilter(); } + +INSTANTIATE_TEST_CASE_P(AnnIVFSQTest, AnnIVFSQTestF_float, ::testing::ValuesIn(inputs)); + +} // namespace cuvs::neighbors::ivf_sq diff --git a/cpp/tests/neighbors/ann_ivf_sq/test_float_uint8_t.cu b/cpp/tests/neighbors/ann_ivf_sq/test_half_int64_t.cu similarity index 55% rename from cpp/tests/neighbors/ann_ivf_sq/test_float_uint8_t.cu rename to cpp/tests/neighbors/ann_ivf_sq/test_half_int64_t.cu index f136b6f41c..e6f5e44dd3 100644 --- a/cpp/tests/neighbors/ann_ivf_sq/test_float_uint8_t.cu +++ b/cpp/tests/neighbors/ann_ivf_sq/test_half_int64_t.cu @@ -9,21 +9,11 @@ namespace cuvs::neighbors::ivf_sq { -typedef AnnIVFSQTest AnnIVFSQTestF_float; -TEST_P(AnnIVFSQTestF_float, AnnIVFSQ) -{ - this->testIVFSQ(); - this->testFilter(); -} - -INSTANTIATE_TEST_CASE_P(AnnIVFSQTest, AnnIVFSQTestF_float, ::testing::ValuesIn(inputs)); - typedef AnnIVFSQTest AnnIVFSQTestF_half; -TEST_P(AnnIVFSQTestF_half, AnnIVFSQ) -{ - this->testIVFSQ(); - this->testFilter(); -} +TEST_P(AnnIVFSQTestF_half, AnnIVFSQSearch) { this->testSearch(); } +TEST_P(AnnIVFSQTestF_half, AnnIVFSQSerialize) { this->testSerialize(); } +TEST_P(AnnIVFSQTestF_half, AnnIVFSQExtend) { this->testExtend(); } +TEST_P(AnnIVFSQTestF_half, AnnIVFSQFilter) { this->testFilter(); } INSTANTIATE_TEST_CASE_P(AnnIVFSQTest, AnnIVFSQTestF_half, ::testing::ValuesIn(inputs_half)); From 1b182d70487eae2b1e0b5609fc775506a90167de Mon Sep 17 00:00:00 2001 From: vic Date: Mon, 20 Apr 2026 13:34:01 +0200 Subject: [PATCH 23/31] Swap IdxT for CodeT --- cpp/include/cuvs/neighbors/ivf_sq.hpp | 36 ++--- cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh | 76 +++++----- ...f_sq_build_extend_float_uint8_t_int64_t.cu | 142 +++++++++--------- ...vf_sq_build_extend_half_uint8_t_int64_t.cu | 142 +++++++++--------- cpp/src/neighbors/ivf_sq/ivf_sq_search.cuh | 86 ++++++----- .../ivf_sq_search_float_uint8_t_int64_t.cu | 4 +- .../ivf_sq_search_half_uint8_t_int64_t.cu | 4 +- cpp/src/neighbors/ivf_sq/ivf_sq_serialize.cuh | 54 +++---- cpp/src/neighbors/ivf_sq_index.cpp | 124 +++++++-------- 9 files changed, 336 insertions(+), 332 deletions(-) diff --git a/cpp/include/cuvs/neighbors/ivf_sq.hpp b/cpp/include/cuvs/neighbors/ivf_sq.hpp index bb29b3fef1..ba4bb39437 100644 --- a/cpp/include/cuvs/neighbors/ivf_sq.hpp +++ b/cpp/include/cuvs/neighbors/ivf_sq.hpp @@ -78,13 +78,13 @@ static_assert(std::is_aggregate_v); * @{ */ -template +template struct list_spec { - static_assert(std::is_same_v, "IVF-SQ code type IdxT must be uint8_t"); + static_assert(std::is_same_v, "IVF-SQ code type CodeT must be uint8_t"); - using value_type = IdxT; + using value_type = CodeT; using list_extents = raft::matrix_extent; - using index_type = ExtT; + using index_type = IdxT; SizeT align_max; SizeT align_min; @@ -98,7 +98,7 @@ struct list_spec { } template - constexpr explicit list_spec(const list_spec& other_spec) + constexpr explicit list_spec(const list_spec& other_spec) : dim{other_spec.dim}, align_min{other_spec.align_min}, align_max{other_spec.align_max} { } @@ -112,8 +112,8 @@ struct list_spec { } }; -template -using list_data = ivf::list; +template +using list_data = ivf::list; /** * @} @@ -142,18 +142,18 @@ using list_data = ivf::list; * search speed, and recall compared to flat (uncompressed) and product-quantized (PQ) * representations. * - * @tparam IdxT SQ code type. Only uint8_t (8-bit, codes in [0,255]) for now. + * @tparam CodeT SQ code type. Only uint8_t (8-bit, codes in [0,255]) for now. * */ -template +template struct index : cuvs::neighbors::index { - static_assert(std::is_same_v, "IVF-SQ code type IdxT must be uint8_t for now."); + static_assert(std::is_same_v, "IVF-SQ code type CodeT must be uint8_t for now."); using index_params_type = ivf_sq::index_params; using search_params_type = ivf_sq::search_params; - using code_type = IdxT; + using code_type = CodeT; - static constexpr uint32_t sq_bits = sizeof(IdxT) * 8; + static constexpr uint32_t sq_bits = sizeof(CodeT) * 8; public: index(const index&) = delete; @@ -195,14 +195,14 @@ struct index : cuvs::neighbors::index { raft::host_vector_view accum_sorted_sizes() noexcept; [[nodiscard]] raft::host_vector_view accum_sorted_sizes() const noexcept; - raft::device_vector_view data_ptrs() noexcept; - raft::device_vector_view data_ptrs() const noexcept; + raft::device_vector_view data_ptrs() noexcept; + raft::device_vector_view data_ptrs() const noexcept; raft::device_vector_view inds_ptrs() noexcept; raft::device_vector_view inds_ptrs() const noexcept; - std::vector>>& lists() noexcept; - const std::vector>>& lists() const noexcept; + std::vector>>& lists() noexcept; + const std::vector>>& lists() const noexcept; void check_consistency(); @@ -210,14 +210,14 @@ struct index : cuvs::neighbors::index { cuvs::distance::DistanceType metric_; bool conservative_memory_allocation_; - std::vector>> lists_; + std::vector>> lists_; raft::device_vector list_sizes_; raft::device_matrix centers_; std::optional> center_norms_; raft::device_vector sq_vmin_; raft::device_vector sq_delta_; - raft::device_vector data_ptrs_; + raft::device_vector data_ptrs_; raft::device_vector inds_ptrs_; raft::host_vector accum_sorted_sizes_; }; diff --git a/cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh b/cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh index 7934460fb5..7940d0d7fb 100644 --- a/cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh +++ b/cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh @@ -104,12 +104,12 @@ __launch_bounds__(BlockSize) RAFT_KERNEL fused_column_minmax_kernel(const T* __r } } -template -auto clone(const raft::resources& res, const index& source) -> index +template +auto clone(const raft::resources& res, const index& source) -> index { auto stream = raft::resource::get_cuda_stream(res); - index target( + index target( res, source.metric(), source.n_lists(), source.dim(), source.conservative_memory_allocation()); raft::copy(target.list_sizes().data_handle(), @@ -219,9 +219,9 @@ RAFT_KERNEL compute_residuals_inplace_kernel( } } -template +template void extend(raft::resources const& handle, - index* index, + index* index, const T* new_vectors, const int64_t* new_indices, int64_t n_rows) @@ -233,8 +233,8 @@ void extend(raft::resources const& handle, auto stream = raft::resource::get_cuda_stream(handle); auto n_lists = index->n_lists(); auto dim = index->dim(); - list_spec list_device_spec{index->dim(), - index->conservative_memory_allocation()}; + list_spec list_device_spec{index->dim(), + index->conservative_memory_allocation()}; cuvs::common::nvtx::range fun_scope( "ivf_sq::extend(%zu, %u)", size_t(n_rows), dim); @@ -401,24 +401,24 @@ void extend(raft::resources const& handle, } } -template +template auto extend(raft::resources const& handle, - const index& orig_index, + const index& orig_index, const T* new_vectors, const int64_t* new_indices, - int64_t n_rows) -> index + int64_t n_rows) -> index { auto ext_index = clone(handle, orig_index); detail::extend(handle, &ext_index, new_vectors, new_indices, n_rows); return ext_index; } -template +template inline auto build( raft::resources const& handle, const index_params& params, raft::mdspan, raft::row_major, accessor> dataset) - -> index + -> index { int64_t n_rows = dataset.extent(0); uint32_t dim = dataset.extent(1); @@ -431,7 +431,7 @@ inline auto build( RAFT_EXPECTS(params.metric != cuvs::distance::DistanceType::CosineExpanded || dim > 1, "Cosine metric requires more than one dim"); - index idx(handle, params, dim); + index idx(handle, params, dim); // Train k-means centroids and SQ parameters on the same training subset. // This mirrors IVF-PQ, which also trains its codebook on a subset of the data. @@ -502,35 +502,35 @@ inline auto build( } if (params.add_data_on_build) { - detail::extend(handle, &idx, dataset.data_handle(), nullptr, n_rows); + detail::extend(handle, &idx, dataset.data_handle(), nullptr, n_rows); } return idx; } -template +template void build(raft::resources const& handle, const index_params& params, raft::device_matrix_view dataset, - index& idx) + index& idx) { - idx = build(handle, params, dataset); + idx = build(handle, params, dataset); } -template +template void build(raft::resources const& handle, const index_params& params, raft::host_matrix_view dataset, - index& idx) + index& idx) { - idx = build(handle, params, dataset); + idx = build(handle, params, dataset); } -template +template auto extend(raft::resources const& handle, raft::device_matrix_view new_vectors, std::optional> new_indices, - const index& orig_index) -> index + const index& orig_index) -> index { RAFT_EXPECTS(new_vectors.extent(1) == orig_index.dim(), "new_vectors should have the same dimension as the index"); @@ -539,18 +539,18 @@ auto extend(raft::resources const& handle, "new_vectors and new_indices have different number of rows"); } int64_t n_rows = new_vectors.extent(0); - return extend(handle, - orig_index, - new_vectors.data_handle(), - new_indices.has_value() ? new_indices.value().data_handle() : nullptr, - n_rows); + return extend(handle, + orig_index, + new_vectors.data_handle(), + new_indices.has_value() ? new_indices.value().data_handle() : nullptr, + n_rows); } -template +template auto extend(raft::resources const& handle, raft::host_matrix_view new_vectors, std::optional> new_indices, - const index& orig_index) -> index + const index& orig_index) -> index { RAFT_EXPECTS(new_vectors.extent(1) == orig_index.dim(), "new_vectors should have the same dimension as the index"); @@ -559,18 +559,18 @@ auto extend(raft::resources const& handle, "new_vectors and new_indices have different number of rows"); } int64_t n_rows = new_vectors.extent(0); - return extend(handle, - orig_index, - new_vectors.data_handle(), - new_indices.has_value() ? new_indices.value().data_handle() : nullptr, - n_rows); + return extend(handle, + orig_index, + new_vectors.data_handle(), + new_indices.has_value() ? new_indices.value().data_handle() : nullptr, + n_rows); } -template +template void extend(raft::resources const& handle, raft::device_matrix_view new_vectors, std::optional> new_indices, - index* idx) + index* idx) { RAFT_EXPECTS(new_vectors.extent(1) == idx->dim(), "new_vectors should have the same dimension as the index"); @@ -585,11 +585,11 @@ void extend(raft::resources const& handle, new_vectors.extent(0)); } -template +template void extend(raft::resources const& handle, raft::host_matrix_view new_vectors, std::optional> new_indices, - index* idx) + index* idx) { RAFT_EXPECTS(new_vectors.extent(1) == idx->dim(), "new_vectors should have the same dimension as the index"); diff --git a/cpp/src/neighbors/ivf_sq/ivf_sq_build_extend_float_uint8_t_int64_t.cu b/cpp/src/neighbors/ivf_sq/ivf_sq_build_extend_float_uint8_t_int64_t.cu index a97aebb11c..f22154e50f 100644 --- a/cpp/src/neighbors/ivf_sq/ivf_sq_build_extend_float_uint8_t_int64_t.cu +++ b/cpp/src/neighbors/ivf_sq/ivf_sq_build_extend_float_uint8_t_int64_t.cu @@ -9,77 +9,77 @@ namespace cuvs::neighbors::ivf_sq { -#define CUVS_INST_IVF_SQ_BUILD_EXTEND(T, IdxT) \ - auto build(raft::resources const& handle, \ - const cuvs::neighbors::ivf_sq::index_params& params, \ - raft::device_matrix_view dataset) \ - -> cuvs::neighbors::ivf_sq::index \ - { \ - return cuvs::neighbors::ivf_sq::index( \ - std::move(cuvs::neighbors::ivf_sq::detail::build(handle, params, dataset))); \ - } \ - \ - void build(raft::resources const& handle, \ - const cuvs::neighbors::ivf_sq::index_params& params, \ - raft::device_matrix_view dataset, \ - cuvs::neighbors::ivf_sq::index& idx) \ - { \ - cuvs::neighbors::ivf_sq::detail::build(handle, params, dataset, idx); \ - } \ - \ - auto build(raft::resources const& handle, \ - const cuvs::neighbors::ivf_sq::index_params& params, \ - raft::host_matrix_view dataset) \ - -> cuvs::neighbors::ivf_sq::index \ - { \ - return cuvs::neighbors::ivf_sq::index( \ - std::move(cuvs::neighbors::ivf_sq::detail::build(handle, params, dataset))); \ - } \ - \ - void build(raft::resources const& handle, \ - const cuvs::neighbors::ivf_sq::index_params& params, \ - raft::host_matrix_view dataset, \ - cuvs::neighbors::ivf_sq::index& idx) \ - { \ - cuvs::neighbors::ivf_sq::detail::build(handle, params, dataset, idx); \ - } \ - \ - auto extend(raft::resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - const cuvs::neighbors::ivf_sq::index& orig_index) \ - -> cuvs::neighbors::ivf_sq::index \ - { \ - return cuvs::neighbors::ivf_sq::index( \ - std::move(cuvs::neighbors::ivf_sq::detail::extend( \ - handle, new_vectors, new_indices, orig_index))); \ - } \ - \ - void extend(raft::resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - cuvs::neighbors::ivf_sq::index* idx) \ - { \ - cuvs::neighbors::ivf_sq::detail::extend(handle, new_vectors, new_indices, idx); \ - } \ - \ - auto extend(raft::resources const& handle, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices, \ - const cuvs::neighbors::ivf_sq::index& orig_index) \ - -> cuvs::neighbors::ivf_sq::index \ - { \ - return cuvs::neighbors::ivf_sq::index( \ - std::move(cuvs::neighbors::ivf_sq::detail::extend( \ - handle, new_vectors, new_indices, orig_index))); \ - } \ - \ - void extend(raft::resources const& handle, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices, \ - cuvs::neighbors::ivf_sq::index* idx) \ - { \ - cuvs::neighbors::ivf_sq::detail::extend(handle, new_vectors, new_indices, idx); \ +#define CUVS_INST_IVF_SQ_BUILD_EXTEND(T, CodeT) \ + auto build(raft::resources const& handle, \ + const cuvs::neighbors::ivf_sq::index_params& params, \ + raft::device_matrix_view dataset) \ + -> cuvs::neighbors::ivf_sq::index \ + { \ + return cuvs::neighbors::ivf_sq::index( \ + std::move(cuvs::neighbors::ivf_sq::detail::build(handle, params, dataset))); \ + } \ + \ + void build(raft::resources const& handle, \ + const cuvs::neighbors::ivf_sq::index_params& params, \ + raft::device_matrix_view dataset, \ + cuvs::neighbors::ivf_sq::index& idx) \ + { \ + cuvs::neighbors::ivf_sq::detail::build(handle, params, dataset, idx); \ + } \ + \ + auto build(raft::resources const& handle, \ + const cuvs::neighbors::ivf_sq::index_params& params, \ + raft::host_matrix_view dataset) \ + -> cuvs::neighbors::ivf_sq::index \ + { \ + return cuvs::neighbors::ivf_sq::index( \ + std::move(cuvs::neighbors::ivf_sq::detail::build(handle, params, dataset))); \ + } \ + \ + void build(raft::resources const& handle, \ + const cuvs::neighbors::ivf_sq::index_params& params, \ + raft::host_matrix_view dataset, \ + cuvs::neighbors::ivf_sq::index& idx) \ + { \ + cuvs::neighbors::ivf_sq::detail::build(handle, params, dataset, idx); \ + } \ + \ + auto extend(raft::resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + const cuvs::neighbors::ivf_sq::index& orig_index) \ + -> cuvs::neighbors::ivf_sq::index \ + { \ + return cuvs::neighbors::ivf_sq::index( \ + std::move(cuvs::neighbors::ivf_sq::detail::extend( \ + handle, new_vectors, new_indices, orig_index))); \ + } \ + \ + void extend(raft::resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + cuvs::neighbors::ivf_sq::index* idx) \ + { \ + cuvs::neighbors::ivf_sq::detail::extend(handle, new_vectors, new_indices, idx); \ + } \ + \ + auto extend(raft::resources const& handle, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices, \ + const cuvs::neighbors::ivf_sq::index& orig_index) \ + -> cuvs::neighbors::ivf_sq::index \ + { \ + return cuvs::neighbors::ivf_sq::index( \ + std::move(cuvs::neighbors::ivf_sq::detail::extend( \ + handle, new_vectors, new_indices, orig_index))); \ + } \ + \ + void extend(raft::resources const& handle, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices, \ + cuvs::neighbors::ivf_sq::index* idx) \ + { \ + cuvs::neighbors::ivf_sq::detail::extend(handle, new_vectors, new_indices, idx); \ } CUVS_INST_IVF_SQ_BUILD_EXTEND(float, uint8_t); diff --git a/cpp/src/neighbors/ivf_sq/ivf_sq_build_extend_half_uint8_t_int64_t.cu b/cpp/src/neighbors/ivf_sq/ivf_sq_build_extend_half_uint8_t_int64_t.cu index 9148e5c328..e6900af4a0 100644 --- a/cpp/src/neighbors/ivf_sq/ivf_sq_build_extend_half_uint8_t_int64_t.cu +++ b/cpp/src/neighbors/ivf_sq/ivf_sq_build_extend_half_uint8_t_int64_t.cu @@ -9,77 +9,77 @@ namespace cuvs::neighbors::ivf_sq { -#define CUVS_INST_IVF_SQ_BUILD_EXTEND(T, IdxT) \ - auto build(raft::resources const& handle, \ - const cuvs::neighbors::ivf_sq::index_params& params, \ - raft::device_matrix_view dataset) \ - -> cuvs::neighbors::ivf_sq::index \ - { \ - return cuvs::neighbors::ivf_sq::index( \ - std::move(cuvs::neighbors::ivf_sq::detail::build(handle, params, dataset))); \ - } \ - \ - void build(raft::resources const& handle, \ - const cuvs::neighbors::ivf_sq::index_params& params, \ - raft::device_matrix_view dataset, \ - cuvs::neighbors::ivf_sq::index& idx) \ - { \ - cuvs::neighbors::ivf_sq::detail::build(handle, params, dataset, idx); \ - } \ - \ - auto build(raft::resources const& handle, \ - const cuvs::neighbors::ivf_sq::index_params& params, \ - raft::host_matrix_view dataset) \ - -> cuvs::neighbors::ivf_sq::index \ - { \ - return cuvs::neighbors::ivf_sq::index( \ - std::move(cuvs::neighbors::ivf_sq::detail::build(handle, params, dataset))); \ - } \ - \ - void build(raft::resources const& handle, \ - const cuvs::neighbors::ivf_sq::index_params& params, \ - raft::host_matrix_view dataset, \ - cuvs::neighbors::ivf_sq::index& idx) \ - { \ - cuvs::neighbors::ivf_sq::detail::build(handle, params, dataset, idx); \ - } \ - \ - auto extend(raft::resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - const cuvs::neighbors::ivf_sq::index& orig_index) \ - -> cuvs::neighbors::ivf_sq::index \ - { \ - return cuvs::neighbors::ivf_sq::index( \ - std::move(cuvs::neighbors::ivf_sq::detail::extend( \ - handle, new_vectors, new_indices, orig_index))); \ - } \ - \ - void extend(raft::resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - cuvs::neighbors::ivf_sq::index* idx) \ - { \ - cuvs::neighbors::ivf_sq::detail::extend(handle, new_vectors, new_indices, idx); \ - } \ - \ - auto extend(raft::resources const& handle, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices, \ - const cuvs::neighbors::ivf_sq::index& orig_index) \ - -> cuvs::neighbors::ivf_sq::index \ - { \ - return cuvs::neighbors::ivf_sq::index( \ - std::move(cuvs::neighbors::ivf_sq::detail::extend( \ - handle, new_vectors, new_indices, orig_index))); \ - } \ - \ - void extend(raft::resources const& handle, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices, \ - cuvs::neighbors::ivf_sq::index* idx) \ - { \ - cuvs::neighbors::ivf_sq::detail::extend(handle, new_vectors, new_indices, idx); \ +#define CUVS_INST_IVF_SQ_BUILD_EXTEND(T, CodeT) \ + auto build(raft::resources const& handle, \ + const cuvs::neighbors::ivf_sq::index_params& params, \ + raft::device_matrix_view dataset) \ + -> cuvs::neighbors::ivf_sq::index \ + { \ + return cuvs::neighbors::ivf_sq::index( \ + std::move(cuvs::neighbors::ivf_sq::detail::build(handle, params, dataset))); \ + } \ + \ + void build(raft::resources const& handle, \ + const cuvs::neighbors::ivf_sq::index_params& params, \ + raft::device_matrix_view dataset, \ + cuvs::neighbors::ivf_sq::index& idx) \ + { \ + cuvs::neighbors::ivf_sq::detail::build(handle, params, dataset, idx); \ + } \ + \ + auto build(raft::resources const& handle, \ + const cuvs::neighbors::ivf_sq::index_params& params, \ + raft::host_matrix_view dataset) \ + -> cuvs::neighbors::ivf_sq::index \ + { \ + return cuvs::neighbors::ivf_sq::index( \ + std::move(cuvs::neighbors::ivf_sq::detail::build(handle, params, dataset))); \ + } \ + \ + void build(raft::resources const& handle, \ + const cuvs::neighbors::ivf_sq::index_params& params, \ + raft::host_matrix_view dataset, \ + cuvs::neighbors::ivf_sq::index& idx) \ + { \ + cuvs::neighbors::ivf_sq::detail::build(handle, params, dataset, idx); \ + } \ + \ + auto extend(raft::resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + const cuvs::neighbors::ivf_sq::index& orig_index) \ + -> cuvs::neighbors::ivf_sq::index \ + { \ + return cuvs::neighbors::ivf_sq::index( \ + std::move(cuvs::neighbors::ivf_sq::detail::extend( \ + handle, new_vectors, new_indices, orig_index))); \ + } \ + \ + void extend(raft::resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + cuvs::neighbors::ivf_sq::index* idx) \ + { \ + cuvs::neighbors::ivf_sq::detail::extend(handle, new_vectors, new_indices, idx); \ + } \ + \ + auto extend(raft::resources const& handle, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices, \ + const cuvs::neighbors::ivf_sq::index& orig_index) \ + -> cuvs::neighbors::ivf_sq::index \ + { \ + return cuvs::neighbors::ivf_sq::index( \ + std::move(cuvs::neighbors::ivf_sq::detail::extend( \ + handle, new_vectors, new_indices, orig_index))); \ + } \ + \ + void extend(raft::resources const& handle, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices, \ + cuvs::neighbors::ivf_sq::index* idx) \ + { \ + cuvs::neighbors::ivf_sq::detail::extend(handle, new_vectors, new_indices, idx); \ } CUVS_INST_IVF_SQ_BUILD_EXTEND(half, uint8_t); diff --git a/cpp/src/neighbors/ivf_sq/ivf_sq_search.cuh b/cpp/src/neighbors/ivf_sq/ivf_sq_search.cuh index 17e6c4f9eb..9728b90796 100644 --- a/cpp/src/neighbors/ivf_sq/ivf_sq_search.cuh +++ b/cpp/src/neighbors/ivf_sq/ivf_sq_search.cuh @@ -117,7 +117,11 @@ inline uint32_t configure_grid_dim_x( // // After all probes are scanned, the smem is reused for block_sort merge. // --------------------------------------------------------------------------- -template +template __launch_bounds__(BlockDim) RAFT_KERNEL ivf_sq_scan_kernel(const uint8_t* const* data_ptrs, const uint32_t* list_sizes, const uint32_t* coarse_indices, @@ -326,8 +330,8 @@ size_t sq_scan_total_smem(uint32_t dim, uint32_t k) // --------------------------------------------------------------------------- // Launch helper: dispatches on Metric, handles grid_dim_x query vs launch // --------------------------------------------------------------------------- -template -void ivf_sq_scan_launch(const index& idx, +template +void ivf_sq_scan_launch(const index& idx, const float* queries_float, const float* query_norms, uint32_t n_queries, @@ -425,14 +429,14 @@ void ivf_sq_scan_launch(const index& idx, switch (idx.metric()) { case cuvs::distance::DistanceType::L2Expanded: case cuvs::distance::DistanceType::L2SqrtExpanded: - do_launch(ivf_sq_scan_kernel); + do_launch(ivf_sq_scan_kernel); break; case cuvs::distance::DistanceType::InnerProduct: - do_launch(ivf_sq_scan_kernel); + do_launch(ivf_sq_scan_kernel); break; case cuvs::distance::DistanceType::CosineExpanded: do_launch( - ivf_sq_scan_kernel); + ivf_sq_scan_kernel); break; default: RAFT_FAIL("Unsupported metric type for IVF-SQ scan."); } @@ -441,9 +445,9 @@ void ivf_sq_scan_launch(const index& idx, // --------------------------------------------------------------------------- // ivf_sq_scan: top-level scan dispatch with Capacity selection // --------------------------------------------------------------------------- -template +template void ivf_sq_scan(raft::resources const& handle, - const index& idx, + const index& idx, const float* queries_float, const float* query_norms, uint32_t n_queries, @@ -472,20 +476,20 @@ void ivf_sq_scan(raft::resources const& handle, } auto fwd = [&](auto cap_tag) { - ivf_sq_scan_launch(idx, - queries_float, - query_norms, - n_queries, - n_probes, - k, - max_samples, - coarse_indices, - chunk_indices, - out_distances, - out_indices, - sample_filter, - grid_dim_x, - stream); + ivf_sq_scan_launch(idx, + queries_float, + query_norms, + n_queries, + n_probes, + k, + max_samples, + coarse_indices, + chunk_indices, + out_distances, + out_indices, + sample_filter, + grid_dim_x, + stream); }; switch (capacity) { @@ -501,9 +505,9 @@ void ivf_sq_scan(raft::resources const& handle, // --------------------------------------------------------------------------- // search_impl — host-side search logic // --------------------------------------------------------------------------- -template +template void search_impl(raft::resources const& handle, - const index& index, + const index& index, const T* queries, uint32_t n_queries, uint32_t k, @@ -802,11 +806,11 @@ void search_impl(raft::resources const& handle, } template inline void search_with_filtering(raft::resources const& handle, const search_params& params, - const index& index, + const index& index, const T* queries, uint32_t n_queries, uint32_t k, @@ -864,24 +868,24 @@ inline void search_with_filtering(raft::resources const& handle, for (uint32_t offset_q = 0; offset_q < n_queries; offset_q += max_queries) { uint32_t queries_batch = std::min(max_queries, n_queries - offset_q); - search_impl(handle, - index, - queries + std::size_t(offset_q) * index.dim(), - queries_batch, - k, - n_probes, - cuvs::distance::is_min_close(index.metric()), - neighbors + std::size_t(offset_q) * k, - distances + std::size_t(offset_q) * k, - raft::resource::get_workspace_resource(handle), - sample_filter); + search_impl(handle, + index, + queries + std::size_t(offset_q) * index.dim(), + queries_batch, + k, + n_probes, + cuvs::distance::is_min_close(index.metric()), + neighbors + std::size_t(offset_q) * k, + distances + std::size_t(offset_q) * k, + raft::resource::get_workspace_resource(handle), + sample_filter); } } -template +template void search_with_filtering(raft::resources const& handle, const search_params& params, - const index& index, + const index& index, raft::device_matrix_view queries, raft::device_matrix_view neighbors, raft::device_matrix_view distances, @@ -906,10 +910,10 @@ void search_with_filtering(raft::resources const& handle, sample_filter); } -template +template void search(raft::resources const& handle, const search_params& params, - const index& idx, + const index& idx, raft::device_matrix_view queries, raft::device_matrix_view neighbors, raft::device_matrix_view distances, diff --git a/cpp/src/neighbors/ivf_sq/ivf_sq_search_float_uint8_t_int64_t.cu b/cpp/src/neighbors/ivf_sq/ivf_sq_search_float_uint8_t_int64_t.cu index 60d95a153f..de185de8ec 100644 --- a/cpp/src/neighbors/ivf_sq/ivf_sq_search_float_uint8_t_int64_t.cu +++ b/cpp/src/neighbors/ivf_sq/ivf_sq_search_float_uint8_t_int64_t.cu @@ -9,10 +9,10 @@ namespace cuvs::neighbors::ivf_sq { -#define CUVS_INST_IVF_SQ_SEARCH(T, IdxT) \ +#define CUVS_INST_IVF_SQ_SEARCH(T, CodeT) \ void search(raft::resources const& handle, \ const cuvs::neighbors::ivf_sq::search_params& params, \ - const cuvs::neighbors::ivf_sq::index& index, \ + const cuvs::neighbors::ivf_sq::index& index, \ raft::device_matrix_view queries, \ raft::device_matrix_view neighbors, \ raft::device_matrix_view distances, \ diff --git a/cpp/src/neighbors/ivf_sq/ivf_sq_search_half_uint8_t_int64_t.cu b/cpp/src/neighbors/ivf_sq/ivf_sq_search_half_uint8_t_int64_t.cu index fbed3fd432..40029119b2 100644 --- a/cpp/src/neighbors/ivf_sq/ivf_sq_search_half_uint8_t_int64_t.cu +++ b/cpp/src/neighbors/ivf_sq/ivf_sq_search_half_uint8_t_int64_t.cu @@ -9,10 +9,10 @@ namespace cuvs::neighbors::ivf_sq { -#define CUVS_INST_IVF_SQ_SEARCH(T, IdxT) \ +#define CUVS_INST_IVF_SQ_SEARCH(T, CodeT) \ void search(raft::resources const& handle, \ const cuvs::neighbors::ivf_sq::search_params& params, \ - const cuvs::neighbors::ivf_sq::index& index, \ + const cuvs::neighbors::ivf_sq::index& index, \ raft::device_matrix_view queries, \ raft::device_matrix_view neighbors, \ raft::device_matrix_view distances, \ diff --git a/cpp/src/neighbors/ivf_sq/ivf_sq_serialize.cuh b/cpp/src/neighbors/ivf_sq/ivf_sq_serialize.cuh index 8aa1f12e04..b201cceee7 100644 --- a/cpp/src/neighbors/ivf_sq/ivf_sq_serialize.cuh +++ b/cpp/src/neighbors/ivf_sq/ivf_sq_serialize.cuh @@ -22,13 +22,13 @@ namespace cuvs::neighbors::ivf_sq::detail { constexpr int serialization_version = 1; -template -void serialize(raft::resources const& handle, std::ostream& os, const index& index_) +template +void serialize(raft::resources const& handle, std::ostream& os, const index& index_) { RAFT_LOG_DEBUG( "Saving IVF-SQ index, size %zu, dim %u", static_cast(index_.size()), index_.dim()); - std::string dtype_string = raft::detail::numpy_serializer::get_numpy_dtype().to_string(); + std::string dtype_string = raft::detail::numpy_serializer::get_numpy_dtype().to_string(); dtype_string.resize(4); os << dtype_string; @@ -60,7 +60,7 @@ void serialize(raft::resources const& handle, std::ostream& os, const index list_store_spec{index_.dim(), true}; + list_spec list_store_spec{index_.dim(), true}; for (uint32_t label = 0; label < index_.n_lists(); label++) { ivf::serialize_list(handle, os, @@ -71,10 +71,10 @@ void serialize(raft::resources const& handle, std::ostream& os, const index +template void serialize(raft::resources const& handle, const std::string& filename, - const index& index_) + const index& index_) { std::ofstream of(filename, std::ios::out | std::ios::binary); if (!of) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } @@ -83,8 +83,8 @@ void serialize(raft::resources const& handle, if (!of) { RAFT_FAIL("Error writing output %s", filename.c_str()); } } -template -auto deserialize(raft::resources const& handle, std::istream& is) -> index +template +auto deserialize(raft::resources const& handle, std::istream& is) -> index { char dtype_string[4]; is.read(dtype_string, 4); @@ -99,7 +99,7 @@ auto deserialize(raft::resources const& handle, std::istream& is) -> index auto metric = raft::deserialize_scalar(handle, is); bool cma = raft::deserialize_scalar(handle, is); - index index_ = index(handle, metric, n_lists, dim, cma); + index index_ = index(handle, metric, n_lists, dim, cma); deserialize_mdspan(handle, is, index_.centers()); @@ -119,8 +119,8 @@ auto deserialize(raft::resources const& handle, std::istream& is) -> index deserialize_mdspan(handle, is, index_.list_sizes()); - list_spec list_device_spec{index_.dim(), cma}; - list_spec list_store_spec{index_.dim(), true}; + list_spec list_device_spec{index_.dim(), cma}; + list_spec list_store_spec{index_.dim(), true}; for (uint32_t label = 0; label < index_.n_lists(); label++) { ivf::deserialize_list(handle, is, index_.lists()[label], list_store_spec, list_device_spec); } @@ -131,29 +131,29 @@ auto deserialize(raft::resources const& handle, std::istream& is) -> index return index_; } -template -auto deserialize(raft::resources const& handle, const std::string& filename) -> index +template +auto deserialize(raft::resources const& handle, const std::string& filename) -> index { std::ifstream is(filename, std::ios::in | std::ios::binary); if (!is) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } - auto index = detail::deserialize(handle, is); + auto index = detail::deserialize(handle, is); is.close(); return index; } } // namespace cuvs::neighbors::ivf_sq::detail -#define CUVS_INST_IVF_SQ_SERIALIZE(IdxT) \ - void serialize(raft::resources const& handle, \ - const std::string& filename, \ - const cuvs::neighbors::ivf_sq::index& index) \ - { \ - cuvs::neighbors::ivf_sq::detail::serialize(handle, filename, index); \ - } \ - \ - void deserialize(raft::resources const& handle, \ - const std::string& filename, \ - cuvs::neighbors::ivf_sq::index* index) \ - { \ - *index = cuvs::neighbors::ivf_sq::detail::deserialize(handle, filename); \ +#define CUVS_INST_IVF_SQ_SERIALIZE(CodeT) \ + void serialize(raft::resources const& handle, \ + const std::string& filename, \ + const cuvs::neighbors::ivf_sq::index& index) \ + { \ + cuvs::neighbors::ivf_sq::detail::serialize(handle, filename, index); \ + } \ + \ + void deserialize(raft::resources const& handle, \ + const std::string& filename, \ + cuvs::neighbors::ivf_sq::index* index) \ + { \ + *index = cuvs::neighbors::ivf_sq::detail::deserialize(handle, filename); \ } diff --git a/cpp/src/neighbors/ivf_sq_index.cpp b/cpp/src/neighbors/ivf_sq_index.cpp index bf5b6df288..5a110a476c 100644 --- a/cpp/src/neighbors/ivf_sq_index.cpp +++ b/cpp/src/neighbors/ivf_sq_index.cpp @@ -12,24 +12,24 @@ namespace cuvs::neighbors::ivf_sq { -template -index::index(raft::resources const& res) +template +index::index(raft::resources const& res) : index(res, cuvs::distance::DistanceType::L2Expanded, 0, 0, false) { } -template -index::index(raft::resources const& res, const index_params& params, uint32_t dim) +template +index::index(raft::resources const& res, const index_params& params, uint32_t dim) : index(res, params.metric, params.n_lists, dim, params.conservative_memory_allocation) { } -template -index::index(raft::resources const& res, - cuvs::distance::DistanceType metric, - uint32_t n_lists, - uint32_t dim, - bool conservative_memory_allocation) +template +index::index(raft::resources const& res, + cuvs::distance::DistanceType metric, + uint32_t n_lists, + uint32_t dim, + bool conservative_memory_allocation) : cuvs::neighbors::index(), metric_(metric), conservative_memory_allocation_(conservative_memory_allocation), @@ -39,7 +39,7 @@ index::index(raft::resources const& res, center_norms_(std::nullopt), sq_vmin_{raft::make_device_vector(res, dim)}, sq_delta_{raft::make_device_vector(res, dim)}, - data_ptrs_{raft::make_device_vector(res, n_lists)}, + data_ptrs_{raft::make_device_vector(res, n_lists)}, inds_ptrs_{raft::make_device_vector(res, n_lists)}, accum_sorted_sizes_{raft::make_host_vector(n_lists + 1)} { @@ -49,68 +49,68 @@ index::index(raft::resources const& res, RAFT_CUDA_TRY( cudaMemsetAsync(list_sizes_.data_handle(), 0, list_sizes_.size() * sizeof(uint32_t), stream)); RAFT_CUDA_TRY( - cudaMemsetAsync(data_ptrs_.data_handle(), 0, data_ptrs_.size() * sizeof(IdxT*), stream)); + cudaMemsetAsync(data_ptrs_.data_handle(), 0, data_ptrs_.size() * sizeof(CodeT*), stream)); RAFT_CUDA_TRY( cudaMemsetAsync(inds_ptrs_.data_handle(), 0, inds_ptrs_.size() * sizeof(int64_t*), stream)); } -template -cuvs::distance::DistanceType index::metric() const noexcept +template +cuvs::distance::DistanceType index::metric() const noexcept { return metric_; } -template -int64_t index::size() const noexcept +template +int64_t index::size() const noexcept { return accum_sorted_sizes()(n_lists()); } -template -uint32_t index::dim() const noexcept +template +uint32_t index::dim() const noexcept { return centers_.extent(1); } -template -uint32_t index::n_lists() const noexcept +template +uint32_t index::n_lists() const noexcept { return lists_.size(); } -template -bool index::conservative_memory_allocation() const noexcept +template +bool index::conservative_memory_allocation() const noexcept { return conservative_memory_allocation_; } -template -raft::device_vector_view index::list_sizes() noexcept +template +raft::device_vector_view index::list_sizes() noexcept { return list_sizes_.view(); } -template -raft::device_vector_view index::list_sizes() const noexcept +template +raft::device_vector_view index::list_sizes() const noexcept { return list_sizes_.view(); } -template -raft::device_matrix_view index::centers() noexcept +template +raft::device_matrix_view index::centers() noexcept { return centers_.view(); } -template -raft::device_matrix_view index::centers() +template +raft::device_matrix_view index::centers() const noexcept { return centers_.view(); } -template -std::optional> index::center_norms() noexcept +template +std::optional> index::center_norms() noexcept { if (center_norms_.has_value()) { return std::make_optional>(center_norms_->view()); @@ -119,8 +119,8 @@ std::optional> index::center_nor } } -template -std::optional> index::center_norms() +template +std::optional> index::center_norms() const noexcept { if (center_norms_.has_value()) { @@ -131,8 +131,8 @@ std::optional> index::cent } } -template -void index::allocate_center_norms(raft::resources const& res) +template +void index::allocate_center_norms(raft::resources const& res) { switch (metric_) { case cuvs::distance::DistanceType::L2Expanded: @@ -146,80 +146,80 @@ void index::allocate_center_norms(raft::resources const& res) } } -template -raft::device_vector_view index::sq_vmin() noexcept +template +raft::device_vector_view index::sq_vmin() noexcept { return sq_vmin_.view(); } -template -raft::device_vector_view index::sq_vmin() const noexcept +template +raft::device_vector_view index::sq_vmin() const noexcept { return sq_vmin_.view(); } -template -raft::device_vector_view index::sq_delta() noexcept +template +raft::device_vector_view index::sq_delta() noexcept { return sq_delta_.view(); } -template -raft::device_vector_view index::sq_delta() const noexcept +template +raft::device_vector_view index::sq_delta() const noexcept { return sq_delta_.view(); } -template -raft::host_vector_view index::accum_sorted_sizes() noexcept +template +raft::host_vector_view index::accum_sorted_sizes() noexcept { return accum_sorted_sizes_.view(); } -template -raft::host_vector_view index::accum_sorted_sizes() const noexcept +template +raft::host_vector_view index::accum_sorted_sizes() const noexcept { return accum_sorted_sizes_.view(); } -template -raft::device_vector_view index::data_ptrs() noexcept +template +raft::device_vector_view index::data_ptrs() noexcept { return data_ptrs_.view(); } -template -raft::device_vector_view index::data_ptrs() const noexcept +template +raft::device_vector_view index::data_ptrs() const noexcept { return data_ptrs_.view(); } -template -raft::device_vector_view index::inds_ptrs() noexcept +template +raft::device_vector_view index::inds_ptrs() noexcept { return inds_ptrs_.view(); } -template -raft::device_vector_view index::inds_ptrs() const noexcept +template +raft::device_vector_view index::inds_ptrs() const noexcept { return inds_ptrs_.view(); } -template -std::vector>>& index::lists() noexcept +template +std::vector>>& index::lists() noexcept { return lists_; } -template -const std::vector>>& index::lists() const noexcept +template +const std::vector>>& index::lists() const noexcept { return lists_; } -template -void index::check_consistency() +template +void index::check_consistency() { auto n_lists = lists_.size(); RAFT_EXPECTS(list_sizes_.extent(0) == n_lists, "inconsistent list size"); From 8c44557997c77a0d41c865e8c046fc7a01ba155a Mon Sep 17 00:00:00 2001 From: vic Date: Mon, 20 Apr 2026 16:12:20 +0200 Subject: [PATCH 24/31] addressing review --- cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh | 229 +++++++++++++----- cpp/tests/neighbors/ann_ivf_sq.cuh | 60 +++-- .../ann_ivf_sq/test_float_int64_t.cu | 5 +- .../neighbors/ann_ivf_sq/test_half_int64_t.cu | 5 +- 4 files changed, 207 insertions(+), 92 deletions(-) diff --git a/cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh b/cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh index 7940d0d7fb..570ccbb872 100644 --- a/cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh +++ b/cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh @@ -56,15 +56,99 @@ struct ColMinMaxOp { } }; +/** + * Vectorized load helper: reads VecCols contiguous elements of type T as + * a single aligned wide load and unpacks them into floats. + * + * The primary benefit over scalar loads is halving (VecCols=2) or + * quartering (VecCols=4) the number of LDG instructions issued per warp, + * which is the dominant cost in the column-strided access pattern of + * fused_column_minmax_kernel. VecCols=1 is provided as the degenerate + * scalar fallback for odd `dim`. + * + * Requires `p` to be aligned to sizeof(T) * VecCols. + */ +template +struct vec_loader; + +template <> +struct vec_loader { + __device__ __forceinline__ static void load(const float* p, float (&out)[1]) { out[0] = *p; } +}; + +template <> +struct vec_loader { + __device__ __forceinline__ static void load(const half* p, float (&out)[1]) + { + out[0] = float(*p); + } +}; + +template <> +struct vec_loader { + __device__ __forceinline__ static void load(const float* p, float (&out)[4]) + { + float4 v = *reinterpret_cast(p); + out[0] = v.x; + out[1] = v.y; + out[2] = v.z; + out[3] = v.w; + } +}; + +template <> +struct vec_loader { + __device__ __forceinline__ static void load(const float* p, float (&out)[2]) + { + float2 v = *reinterpret_cast(p); + out[0] = v.x; + out[1] = v.y; + } +}; + +template <> +struct vec_loader { + __device__ __forceinline__ static void load(const half* p, float (&out)[4]) + { + // Single 8-byte load covering 4 halves; memcpy avoids aliasing issues + // and is compiled to a register move in device code. + uint2 raw = *reinterpret_cast(p); + half h[4]; + static_assert(sizeof(h) == sizeof(raw), "unexpected half packing"); + memcpy(&h[0], &raw, sizeof(raw)); +#pragma unroll + for (int k = 0; k < 4; ++k) + out[k] = float(h[k]); + } +}; + +template <> +struct vec_loader { + __device__ __forceinline__ static void load(const half* p, float (&out)[2]) + { + uint32_t raw = *reinterpret_cast(p); + half h[2]; + static_assert(sizeof(h) == sizeof(raw), "unexpected half packing"); + memcpy(&h[0], &raw, sizeof(raw)); + out[0] = float(h[0]); + out[1] = float(h[1]); + } +}; + /** * Fused per-column min+max in a single pass (2x less DRAM traffic than two - * separate reductions). One thread block per column; threads stride over - * rows and feed CUB BlockReduce with a combined min/max pair. + * separate reductions). Each block owns VecCols contiguous columns and + * threads read them with a single aligned wide load (float4/float2 for + * float, uint2/uint32 for half) per row instead of VecCols scalar loads. + * Requires dim to be a multiple of VecCols so that + * `data + row*dim + col_base` is `sizeof(T)*VecCols` aligned for every + * row; the host-side dispatcher picks the widest VecCols that satisfies + * this, down to VecCols=1 (pure scalar fallback) for arbitrary dim. * * Row-loop is manually 4x-unrolled so the compiler can overlap four * independent read-only loads in the memory pipeline. */ -template +template __launch_bounds__(BlockSize) RAFT_KERNEL fused_column_minmax_kernel(const T* __restrict__ data, float* __restrict__ col_min, float* __restrict__ col_max, @@ -74,33 +158,76 @@ __launch_bounds__(BlockSize) RAFT_KERNEL fused_column_minmax_kernel(const T* __r using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; - const uint32_t col = blockIdx.x; - if (col >= dim) return; + const uint32_t col_base = blockIdx.x * VecCols; + // When launched with gridDim.x = dim / VecCols and dim % VecCols == 0 + // (enforced by the host-side dispatch), col_base + VecCols <= dim always. - ColMinMaxPair agg = {std::numeric_limits::max(), std::numeric_limits::lowest()}; + ColMinMaxPair agg[VecCols]; +#pragma unroll + for (int k = 0; k < VecCols; ++k) { + agg[k] = {std::numeric_limits::max(), std::numeric_limits::lowest()}; + } const int64_t stride = static_cast(BlockSize); int64_t row = static_cast(threadIdx.x); + // 4x row-unrolled loop with vectorized loads: 4 * VecCols values per iter + // are pulled into registers via 4 wide LDGs, exposing ILP across both + // the column and row axes. for (; row + 3 * stride < n_rows; row += 4 * stride) { - float v0 = float(data[row * dim + col]); - float v1 = float(data[(row + stride) * dim + col]); - float v2 = float(data[(row + 2 * stride) * dim + col]); - float v3 = float(data[(row + 3 * stride) * dim + col]); - agg.min_val = fminf(agg.min_val, fminf(fminf(v0, v1), fminf(v2, v3))); - agg.max_val = fmaxf(agg.max_val, fmaxf(fmaxf(v0, v1), fmaxf(v2, v3))); + float r0[VecCols], r1[VecCols], r2[VecCols], r3[VecCols]; + vec_loader::load(data + row * dim + col_base, r0); + vec_loader::load(data + (row + stride) * dim + col_base, r1); + vec_loader::load(data + (row + 2 * stride) * dim + col_base, r2); + vec_loader::load(data + (row + 3 * stride) * dim + col_base, r3); +#pragma unroll + for (int k = 0; k < VecCols; ++k) { + float mn = fminf(fminf(r0[k], r1[k]), fminf(r2[k], r3[k])); + float mx = fmaxf(fmaxf(r0[k], r1[k]), fmaxf(r2[k], r3[k])); + agg[k].min_val = fminf(agg[k].min_val, mn); + agg[k].max_val = fmaxf(agg[k].max_val, mx); + } } for (; row < n_rows; row += stride) { - float val = float(data[row * dim + col]); - agg.min_val = fminf(agg.min_val, val); - agg.max_val = fmaxf(agg.max_val, val); + float r[VecCols]; + vec_loader::load(data + row * dim + col_base, r); +#pragma unroll + for (int k = 0; k < VecCols; ++k) { + agg[k].min_val = fminf(agg[k].min_val, r[k]); + agg[k].max_val = fmaxf(agg[k].max_val, r[k]); + } } - agg = BlockReduce(temp_storage).Reduce(agg, ColMinMaxOp()); + // One block-reduce per owned column. CUB requires a __syncthreads() + // between reuses of the shared temp_storage. + for (int k = 0; k < VecCols; ++k) { + if (k > 0) __syncthreads(); + auto r = BlockReduce(temp_storage).Reduce(agg[k], ColMinMaxOp()); + if (threadIdx.x == 0) { + col_min[col_base + k] = r.min_val; + col_max[col_base + k] = r.max_val; + } + } +} - if (threadIdx.x == 0) { - col_min[col] = agg.min_val; - col_max[col] = agg.max_val; +/** + * Host-side dispatch that selects the widest VecCols compatible with the + * given dim alignment (VecCols in {4, 2, 1}), and launches + * fused_column_minmax_kernel with the corresponding grid shape. + */ +template +inline void launch_fused_column_minmax( + const T* data, float* col_min, float* col_max, int64_t n_rows, uint32_t dim, cudaStream_t stream) +{ + if (dim % 4 == 0) { + fused_column_minmax_kernel + <<>>(data, col_min, col_max, n_rows, dim); + } else if (dim % 2 == 0) { + fused_column_minmax_kernel + <<>>(data, col_min, col_max, n_rows, dim); + } else { + fused_column_minmax_kernel + <<>>(data, col_min, col_max, n_rows, dim); } } @@ -141,15 +268,21 @@ auto clone(const raft::resources& res, const index& source) -> index(new_vectors[i,d]) - centers[list_id,d]) is computed + * in registers and consumed immediately, avoiding a full HBM round trip through + * an intermediate residuals buffer. */ -template +template __launch_bounds__(BlockSize) RAFT_KERNEL encode_and_fill_kernel(const uint32_t* labels, - const float* residuals, + const T* new_vectors, + const float* centers, const int64_t* source_ixs, uint8_t** list_data_ptrs, int64_t** list_index_ptrs, @@ -186,12 +319,13 @@ __launch_bounds__(BlockSize) RAFT_KERNEL encode_and_fill_kernel(const uint32_t* constexpr uint32_t veclen = list_spec::kVecLen; uint32_t padded_dim = ((dim + veclen - 1) / veclen) * veclen; auto* list_dat = list_data_ptrs[list_id] + static_cast(group_offset) * padded_dim; - const float* src = residuals + row_id * dim; + const T* src = new_vectors + row_id * dim; + const float* ctr = centers + static_cast(list_id) * dim; for (uint32_t d = lane_id; d < padded_dim; d += kWarpSize) { uint8_t out; if (d < dim) { - float val = src[d]; + float val = utils::mapping{}(src[d]) - ctr[d]; float dv = delta[d]; float v = vmin[d]; float code = (dv > 0.0f) ? roundf((val - v) / dv) : 0.0f; @@ -271,6 +405,8 @@ void extend(raft::resources const& handle, enable_prefetch); vec_batches.prefetch_next_batch(); + const bool needs_prefetch_sync = enable_prefetch && vec_batches.does_copy(); + for (const auto& batch : vec_batches) { auto batch_data_view = raft::make_device_matrix_view(batch.data(), batch.size(), index->dim()); @@ -279,7 +415,7 @@ void extend(raft::resources const& handle, cuvs::cluster::kmeans::predict( handle, kmeans_params, batch_data_view, orig_centroids_view, batch_labels_view); vec_batches.prefetch_next_batch(); - raft::resource::sync_stream(handle); + if (needs_prefetch_sync) { raft::resource::sync_stream(handle); } } auto* list_sizes_ptr = index->list_sizes().data_handle(); @@ -320,41 +456,21 @@ void extend(raft::resources const& handle, vec_batches.prefetch_next_batch(); utils::batch_load_iterator idx_batch = vec_indices.begin(); - auto residuals_buf = raft::make_device_vector(handle, max_batch_size * dim); - size_t next_report_offset = 0; size_t d_report_offset = n_rows * 5 / 100; for (const auto& batch : vec_batches) { int64_t bs = batch.size(); - { - auto batch_view = raft::make_device_matrix_view(batch.data(), bs, dim); - auto residuals_view = - raft::make_device_matrix_view(residuals_buf.data_handle(), bs, dim); - - const float* centers_ptr = index->centers().data_handle(); - const uint32_t* labels_ptr = new_labels.data_handle() + batch.offset(); - - raft::linalg::map_offset( - handle, - residuals_view, - [centers_ptr, labels_ptr, dim] __device__(auto idx, T x) { - auto i = idx / dim; - auto j = idx % dim; - return utils::mapping{}(x)-centers_ptr[labels_ptr[i] * dim + j]; - }, - batch_view); - } - { constexpr int kEncodeBlockSize = 256; constexpr int kEncodeWarpsPerBlk = kEncodeBlockSize / kIndexGroupSize; const dim3 block_dim(kEncodeBlockSize); const dim3 grid_dim(raft::ceildiv(bs, int64_t(kEncodeWarpsPerBlk))); - encode_and_fill_kernel + encode_and_fill_kernel <<>>(new_labels.data_handle() + batch.offset(), - residuals_buf.data_handle(), + batch.data(), + index->centers().data_handle(), idx_batch->data(), index->data_ptrs().data_handle(), index->inds_ptrs().data_handle(), @@ -368,8 +484,7 @@ void extend(raft::resources const& handle, } vec_batches.prefetch_next_batch(); - raft::resource::sync_stream(handle); - RAFT_CUDA_TRY(cudaPeekAtLastError()); + if (needs_prefetch_sync) { raft::resource::sync_stream(handle); } if (batch.offset() > next_report_offset) { float progress = batch.offset() * 100.0f / n_rows; @@ -456,13 +571,11 @@ inline auto build( // Train SQ: predict labels for the training subset, compute residuals in-place, // and derive per-dimension vmin/delta from them. { - auto train_labels = raft::make_device_vector(handle, n_rows_train); - cuvs::cluster::kmeans::balanced_params pred_params; - pred_params.metric = idx.metric(); + auto train_labels = raft::make_device_vector(handle, n_rows_train); auto centers_const_view = raft::make_device_matrix_view( idx.centers().data_handle(), idx.n_lists(), dim); cuvs::cluster::kmeans::predict( - handle, pred_params, trainset_const_view, centers_const_view, train_labels.view()); + handle, kmeans_params, trainset_const_view, centers_const_view, train_labels.view()); constexpr int kResidualBlockSize = 256; compute_residuals_inplace_kernel @@ -483,8 +596,8 @@ inline auto build( auto* vmax_ptr = vmax_buf.data_handle(); constexpr int kMinMaxBlockSize = 256; - fused_column_minmax_kernel<<>>( - residuals.data_handle(), vmin_ptr, vmax_ptr, n_rows_train, dim); + launch_fused_column_minmax( + residuals.data_handle(), vmin_ptr, vmax_ptr, n_rows_train, dim, stream); RAFT_CUDA_TRY(cudaPeekAtLastError()); // Expand the observed range by a small margin to reduce clipping on unseen data, diff --git a/cpp/tests/neighbors/ann_ivf_sq.cuh b/cpp/tests/neighbors/ann_ivf_sq.cuh index f7311b75e4..b62892d9b1 100644 --- a/cpp/tests/neighbors/ann_ivf_sq.cuh +++ b/cpp/tests/neighbors/ann_ivf_sq.cuh @@ -57,10 +57,37 @@ class AnnIVFSQTest : public ::testing::TestWithParam> { { } - void testSearch() + void testAll() + { + auto naive = compute_naive_knn(); + auto idx = build_index(true); + + { + SCOPED_TRACE("Search"); + checkSearch(idx, naive); + } + { + SCOPED_TRACE("Serialize"); + checkSerialize(idx); + } + { + SCOPED_TRACE("Filter"); + checkFilter(idx); + } + { + SCOPED_TRACE("Extend"); + checkExtend(naive); + } + } + + protected: + struct SearchResults { + std::vector indices; + std::vector distances; + }; + + void checkSearch(const cuvs::neighbors::ivf_sq::index& idx, const SearchResults& naive) { - auto naive = compute_naive_knn(); - auto idx = build_index(true); auto results = search_index(idx); float eps = 0.1; @@ -74,10 +101,8 @@ class AnnIVFSQTest : public ::testing::TestWithParam> { min_recall_threshold())); } - void testSerialize() + void checkSerialize(const cuvs::neighbors::ivf_sq::index& idx) { - auto idx = build_index(true); - tmp_index_file index_file; cuvs::neighbors::ivf_sq::serialize(handle_, index_file.filename, idx); cuvs::neighbors::ivf_sq::index index_loaded(handle_); @@ -101,9 +126,8 @@ class AnnIVFSQTest : public ::testing::TestWithParam> { 1.0)); } - void testExtend() + void checkExtend(const SearchResults& naive) { - auto naive = compute_naive_knn(); auto idx_empty = build_index(false); extend_index(&idx_empty); @@ -120,7 +144,7 @@ class AnnIVFSQTest : public ::testing::TestWithParam> { min_recall_threshold())); } - void testFilter() + void checkFilter(const cuvs::neighbors::ivf_sq::index& idx) { if (ps.num_db_vecs <= static_cast(test_ivf_sample_filter::offset)) { GTEST_SKIP() << "Skipping filter test: num_db_vecs <= filter offset"; @@ -164,19 +188,9 @@ class AnnIVFSQTest : public ::testing::TestWithParam> { rmm::device_uvector indices_ivfsq_dev(queries_size, stream_); { - cuvs::neighbors::ivf_sq::index_params index_params; cuvs::neighbors::ivf_sq::search_params search_params; - index_params.n_lists = ps.nlist; - index_params.metric = ps.metric; search_params.n_probes = ps.nprobe; - index_params.add_data_on_build = true; - index_params.kmeans_trainset_fraction = 0.5; - - auto database_view = raft::make_device_matrix_view( - (const DataT*)database.data(), ps.num_db_vecs, ps.dim); - auto index = cuvs::neighbors::ivf_sq::build(handle_, index_params, database_view); - auto removed_indices = raft::make_device_vector(handle_, test_ivf_sample_filter::offset); raft::linalg::map_offset(handle_, removed_indices.view(), raft::identity_op{}); @@ -196,7 +210,7 @@ class AnnIVFSQTest : public ::testing::TestWithParam> { cuvs::neighbors::ivf_sq::search(handle_, search_params, - index, + idx, search_queries_view, indices_out_view, dists_out_view, @@ -241,12 +255,6 @@ class AnnIVFSQTest : public ::testing::TestWithParam> { search_queries.resize(0, stream_); } - private: - struct SearchResults { - std::vector indices; - std::vector distances; - }; - double min_recall_threshold() { return std::min(1.0, static_cast(ps.nprobe) / static_cast(ps.nlist)); diff --git a/cpp/tests/neighbors/ann_ivf_sq/test_float_int64_t.cu b/cpp/tests/neighbors/ann_ivf_sq/test_float_int64_t.cu index 734f736b09..8831ae720a 100644 --- a/cpp/tests/neighbors/ann_ivf_sq/test_float_int64_t.cu +++ b/cpp/tests/neighbors/ann_ivf_sq/test_float_int64_t.cu @@ -10,10 +10,7 @@ namespace cuvs::neighbors::ivf_sq { typedef AnnIVFSQTest AnnIVFSQTestF_float; -TEST_P(AnnIVFSQTestF_float, AnnIVFSQSearch) { this->testSearch(); } -TEST_P(AnnIVFSQTestF_float, AnnIVFSQSerialize) { this->testSerialize(); } -TEST_P(AnnIVFSQTestF_float, AnnIVFSQExtend) { this->testExtend(); } -TEST_P(AnnIVFSQTestF_float, AnnIVFSQFilter) { this->testFilter(); } +TEST_P(AnnIVFSQTestF_float, AnnIVFSQ) { this->testAll(); } INSTANTIATE_TEST_CASE_P(AnnIVFSQTest, AnnIVFSQTestF_float, ::testing::ValuesIn(inputs)); diff --git a/cpp/tests/neighbors/ann_ivf_sq/test_half_int64_t.cu b/cpp/tests/neighbors/ann_ivf_sq/test_half_int64_t.cu index e6f5e44dd3..fc17e246dd 100644 --- a/cpp/tests/neighbors/ann_ivf_sq/test_half_int64_t.cu +++ b/cpp/tests/neighbors/ann_ivf_sq/test_half_int64_t.cu @@ -10,10 +10,7 @@ namespace cuvs::neighbors::ivf_sq { typedef AnnIVFSQTest AnnIVFSQTestF_half; -TEST_P(AnnIVFSQTestF_half, AnnIVFSQSearch) { this->testSearch(); } -TEST_P(AnnIVFSQTestF_half, AnnIVFSQSerialize) { this->testSerialize(); } -TEST_P(AnnIVFSQTestF_half, AnnIVFSQExtend) { this->testExtend(); } -TEST_P(AnnIVFSQTestF_half, AnnIVFSQFilter) { this->testFilter(); } +TEST_P(AnnIVFSQTestF_half, AnnIVFSQ) { this->testAll(); } INSTANTIATE_TEST_CASE_P(AnnIVFSQTest, AnnIVFSQTestF_half, ::testing::ValuesIn(inputs_half)); From 80a55fd0406fb212345b74230145c722dbed2f60 Mon Sep 17 00:00:00 2001 From: vic Date: Wed, 22 Apr 2026 18:56:44 +0200 Subject: [PATCH 25/31] account for RAFT update --- cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh | 13 +++++++++---- cpp/src/neighbors/ivf_sq/ivf_sq_search.cuh | 2 +- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh b/cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh index 570ccbb872..a4a2063ca5 100644 --- a/cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh +++ b/cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh @@ -377,7 +377,7 @@ void extend(raft::resources const& handle, auto new_labels = raft::make_device_mdarray(handle, - raft::resource::get_large_workspace_resource(handle), + raft::resource::get_large_workspace_resource_ref(handle), raft::make_extents(n_rows)); cuvs::cluster::kmeans::balanced_params kmeans_params; kmeans_params.metric = index->metric(); @@ -401,7 +401,7 @@ void extend(raft::resources const& handle, index->dim(), max_batch_size, copy_stream, - raft::resource::get_workspace_resource(handle), + raft::resource::get_workspace_resource_ref(handle), enable_prefetch); vec_batches.prefetch_next_batch(); @@ -451,7 +451,12 @@ void extend(raft::resources const& handle, raft::copy(list_sizes_ptr, old_list_sizes_dev.data_handle(), n_lists, stream); utils::batch_load_iterator vec_indices( - new_indices, n_rows, 1, max_batch_size, stream, raft::resource::get_workspace_resource(handle)); + new_indices, + n_rows, + 1, + max_batch_size, + stream, + raft::resource::get_workspace_resource_ref(handle)); vec_batches.reset(); vec_batches.prefetch_next_batch(); utils::batch_load_iterator idx_batch = vec_indices.begin(); @@ -557,7 +562,7 @@ inline auto build( auto n_rows_train = n_rows / trainset_ratio; auto trainset = raft::make_device_mdarray(handle, - raft::resource::get_large_workspace_resource(handle), + raft::resource::get_large_workspace_resource_ref(handle), raft::make_extents(n_rows_train, idx.dim())); raft::matrix::sample_rows(handle, random_state, dataset, trainset.view()); auto trainset_const_view = raft::make_const_mdspan(trainset.view()); diff --git a/cpp/src/neighbors/ivf_sq/ivf_sq_search.cuh b/cpp/src/neighbors/ivf_sq/ivf_sq_search.cuh index 9728b90796..6201cdd4f4 100644 --- a/cpp/src/neighbors/ivf_sq/ivf_sq_search.cuh +++ b/cpp/src/neighbors/ivf_sq/ivf_sq_search.cuh @@ -877,7 +877,7 @@ inline void search_with_filtering(raft::resources const& handle, cuvs::distance::is_min_close(index.metric()), neighbors + std::size_t(offset_q) * k, distances + std::size_t(offset_q) * k, - raft::resource::get_workspace_resource(handle), + raft::resource::get_workspace_resource_ref(handle), sample_filter); } } From ac8ea4e8836dbdda4822d23b4f53b4dbf850814a Mon Sep 17 00:00:00 2001 From: vic Date: Mon, 27 Apr 2026 16:27:16 +0200 Subject: [PATCH 26/31] IVF-SQ JIT-LTO --- cpp/CMakeLists.txt | 32 +- .../detail/jit_lto/ivf_sq/scan_fragments.hpp | 20 + .../jit_lto_kernels/device_functions.cuh | 25 + .../jit_lto_kernels/filter_kernel.cu.in | 31 + .../detail/jit_lto_kernels/filter_matrix.json | 3 + .../detail/jit_lto_kernels/kernel_def.hpp | 38 ++ .../detail/jit_lto_kernels/scan_impl.cuh | 260 ++++++++ .../detail/jit_lto_kernels/scan_kernel.cu.in | 62 ++ .../detail/jit_lto_kernels/scan_matrix.json | 8 + .../detail/jit_lto_kernels/scan_planner.hpp | 38 ++ cpp/src/neighbors/ivf_sq/ivf_sq_search.cuh | 567 ++++++------------ .../ivf_sq_search_half_uint8_t_int64_t.cu | 29 - ..._t.cu => ivf_sq_search_uint8_t_int64_t.cu} | 1 + 13 files changed, 711 insertions(+), 403 deletions(-) create mode 100644 cpp/include/cuvs/detail/jit_lto/ivf_sq/scan_fragments.hpp create mode 100644 cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/device_functions.cuh create mode 100644 cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/filter_kernel.cu.in create mode 100644 cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/filter_matrix.json create mode 100644 cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/kernel_def.hpp create mode 100644 cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_impl.cuh create mode 100644 cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_kernel.cu.in create mode 100644 cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_matrix.json create mode 100644 cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_planner.hpp delete mode 100644 cpp/src/neighbors/ivf_sq/ivf_sq_search_half_uint8_t_int64_t.cu rename cpp/src/neighbors/ivf_sq/{ivf_sq_search_float_uint8_t_int64_t.cu => ivf_sq_search_uint8_t_int64_t.cu} (97%) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index e3f366661e..f957f8b702 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -487,6 +487,35 @@ if(NOT BUILD_CPU_ONLY) OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/ivf_flat/post_process" KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements ) + set(ivf_sq_ns "cuvs::neighbors::ivf_sq::detail") + generate_jit_lto_kernels( + jit_lto_files + NAME_FORMAT "ivf_sq_scan_capacity_@capacity@_metric_@metric_name@" + MATRIX_JSON_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_matrix.json" + KERNEL_INPUT_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_kernel.cu.in" + FRAGMENT_TAG_FORMAT + "${ivf_sq_ns}::fragment_tag_ivf_sq_scan<${ivf_sq_ns}::tag_metric_@metric_name@, @capacity@>" + FRAGMENT_TAG_HEADER_FILES "" + "" + OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/ivf_sq/scan" + KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements + ) + generate_jit_lto_kernels( + jit_lto_files + NAME_FORMAT "ivf_sq_filter_@filter_name@" + MATRIX_JSON_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/ivf_sq/detail/jit_lto_kernels/filter_matrix.json" + KERNEL_INPUT_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/ivf_sq/detail/jit_lto_kernels/filter_kernel.cu.in" + FRAGMENT_TAG_FORMAT + "${ivf_sq_ns}::fragment_tag_ivf_sq_filter<${neighbors_ns}::tag_filter_@filter_name@>" + FRAGMENT_TAG_HEADER_FILES "" + "" + OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/ivf_sq/filter" + KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements + ) set(ivf_pq_ns "cuvs::neighbors::ivf_pq::detail") generate_jit_lto_kernels( jit_lto_files @@ -900,8 +929,7 @@ if(NOT BUILD_CPU_ONLY) src/neighbors/ivf_sq_index.cpp src/neighbors/ivf_sq/ivf_sq_build_extend_float_uint8_t_int64_t.cu src/neighbors/ivf_sq/ivf_sq_build_extend_half_uint8_t_int64_t.cu - src/neighbors/ivf_sq/ivf_sq_search_float_uint8_t_int64_t.cu - src/neighbors/ivf_sq/ivf_sq_search_half_uint8_t_int64_t.cu + src/neighbors/ivf_sq/ivf_sq_search_uint8_t_int64_t.cu src/neighbors/ivf_sq/ivf_sq_serialize_uint8_t.cu src/neighbors/knn_merge_parts.cu src/neighbors/nn_descent.cu diff --git a/cpp/include/cuvs/detail/jit_lto/ivf_sq/scan_fragments.hpp b/cpp/include/cuvs/detail/jit_lto/ivf_sq/scan_fragments.hpp new file mode 100644 index 0000000000..b20684b11f --- /dev/null +++ b/cpp/include/cuvs/detail/jit_lto/ivf_sq/scan_fragments.hpp @@ -0,0 +1,20 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +namespace cuvs::neighbors::ivf_sq::detail { + +struct tag_metric_l2 {}; +struct tag_metric_ip {}; +struct tag_metric_cosine {}; + +template +struct fragment_tag_ivf_sq_scan {}; + +template +struct fragment_tag_ivf_sq_filter {}; + +} // namespace cuvs::neighbors::ivf_sq::detail diff --git a/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/device_functions.cuh b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/device_functions.cuh new file mode 100644 index 0000000000..188933d803 --- /dev/null +++ b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/device_functions.cuh @@ -0,0 +1,25 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include + +namespace cuvs::neighbors::ivf_sq::detail { + +// Forward declaration of the sample filter device function. +// The concrete implementation is provided by a JIT-LTO filter-adapter fragment +// (see filter_kernel.cu.in) that delegates to the shared +// cuvs::neighbors::detail::sample_filter_ fragment. +template +__device__ bool sample_filter(const IndexT* const* const inds_ptrs, + const uint32_t query_ix, + const uint32_t cluster_ix, + const uint32_t sample_ix, + uint32_t* bitset_ptr, + IndexT bitset_len, + IndexT original_nbits); + +} // namespace cuvs::neighbors::ivf_sq::detail diff --git a/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/filter_kernel.cu.in b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/filter_kernel.cu.in new file mode 100644 index 0000000000..c93038de84 --- /dev/null +++ b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/filter_kernel.cu.in @@ -0,0 +1,31 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include + +namespace { + +constexpr auto sample_filter_impl = + cuvs::neighbors::detail::sample_filter_@filter_name@; + +} // namespace + +namespace cuvs::neighbors::ivf_sq::detail { + +template <> +__device__ bool sample_filter(const int64_t* const* const inds_ptrs, + const uint32_t query_ix, + const uint32_t cluster_ix, + const uint32_t sample_ix, + uint32_t* bitset_ptr, + int64_t bitset_len, + int64_t original_nbits) +{ + return sample_filter_impl( + inds_ptrs, query_ix, cluster_ix, sample_ix, bitset_ptr, bitset_len, original_nbits); +} + +} // namespace cuvs::neighbors::ivf_sq::detail diff --git a/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/filter_matrix.json b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/filter_matrix.json new file mode 100644 index 0000000000..2a01bc4583 --- /dev/null +++ b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/filter_matrix.json @@ -0,0 +1,3 @@ +{ + "filter_name": ["none", "bitset"] +} diff --git a/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/kernel_def.hpp b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/kernel_def.hpp new file mode 100644 index 0000000000..5df97deabc --- /dev/null +++ b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/kernel_def.hpp @@ -0,0 +1,38 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include + +namespace cuvs::neighbors::ivf_sq::detail { + +static constexpr int kSqScanThreads = 128; + +// Function-pointer signature for the JIT-LTO scan entrypoint. +// Must exactly match the extern "C" __global__ ivf_sq_scan(...) signature +// produced by scan_kernel.cu.in. +template +using ivf_sq_scan_func_t = void(const uint8_t* const* data_ptrs, + const uint32_t* list_sizes, + const uint32_t* coarse_indices, + const float* queries_float, + const float* centers, + const float* sq_vmin, + const float* sq_delta, + const float* query_norms, + uint32_t n_probes, + uint32_t dim, + uint32_t k, + uint32_t max_samples, + const uint32_t* chunk_indices, + float* out_distances, + uint32_t* out_indices, + IdxT* const* inds_ptrs, + uint32_t* bitset_ptr, + IdxT bitset_len, + IdxT original_nbits); + +} // namespace cuvs::neighbors::ivf_sq::detail diff --git a/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_impl.cuh b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_impl.cuh new file mode 100644 index 0000000000..bc8f7587ee --- /dev/null +++ b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_impl.cuh @@ -0,0 +1,260 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "device_functions.cuh" +#include "kernel_def.hpp" +#include +#include +#include + +#include +#include +#include + +#include +#include + +namespace cuvs::neighbors::ivf_sq::detail { + +// block_sort type selection: dispatch the dummy block sort when Capacity == 0 +// so the same impl body works for both the fused top-k path (Capacity > 0) and +// the materialize-all path (Capacity == 0). +template +struct sq_block_sort { + using type = raft::matrix::detail::select::warpsort::block_sort< + raft::matrix::detail::select::warpsort::warp_sort_filtered, + Capacity, + Ascending, + float, + uint32_t>; +}; + +template +struct sq_block_sort<0, Ascending> { + using type = ivf::detail::dummy_block_sort_t; +}; + +template +using sq_block_sort_t = typename sq_block_sort::type; + +// IVF-SQ scan kernel body with fused in-kernel top-k. +// +// Grid layout: +// kManageLocalTopK (Capacity > 0): +// grid (grid_dim_x, n_queries) - each block loops over probes +// otherwise (Capacity == 0): +// grid (n_probes, n_queries) - one block per (query, probe) +// +// Shared-memory layout (always 3 x dim floats): +// [s_sq_scale(dim) | s_query_term(dim) | s_aux(dim)] +// +// s_sq_scale = delta[d] - SQ dequantization scale, invariant (Phase 1). +// +// L2 path: +// Phase 1: s_aux[d] = query[d] - vmin[d] (invariant) +// Phase 2: s_query_term[d] = s_aux[d] - centroid[d] (per-probe) +// The full SQ reconstruction is centroid + vmin + code*delta, so +// query - reconstructed = (query - vmin - centroid) - code*delta +// = s_query_term - code*s_sq_scale. +// +// IP/Cosine path: +// Phase 1: s_query_term[d] = query[d] (invariant) +// Phase 2: s_aux[d] = centroid[d] + vmin[d] (per-probe) +// Reconstructed vector component: s_aux[d] + code*s_sq_scale[d]. +// +// After all probes are scanned, the smem is reused for block_sort merge. +template +__device__ __forceinline__ void ivf_sq_scan_impl(const uint8_t* const* data_ptrs, + const uint32_t* list_sizes, + const uint32_t* coarse_indices, + const float* queries_float, + const float* centers, + const float* sq_vmin, + const float* sq_delta, + const float* query_norms, + uint32_t n_probes, + uint32_t dim, + uint32_t k, + uint32_t max_samples, + const uint32_t* chunk_indices, + float* out_distances, + uint32_t* out_indices, + const int64_t* const* inds_ptrs, + uint32_t* bitset_ptr, + int64_t bitset_len, + int64_t original_nbits) +{ + static_assert(kIndexGroupSize == raft::WarpSize, + "Warp-coalesced scan requires kIndexGroupSize == WarpSize"); + + constexpr int BlockDim = kSqScanThreads; + constexpr bool kManageLocalTopK = (Capacity > 0); + constexpr bool kIsL2 = std::is_same_v; + constexpr bool kIsCosine = std::is_same_v; + constexpr bool kIsIP = std::is_same_v; + constexpr bool kAscending = !kIsIP; + + extern __shared__ __align__(256) uint8_t smem_buf[]; + float* smem = reinterpret_cast(smem_buf); + + float* s_sq_scale = smem; + float* s_query_term = smem + dim; + float* s_aux = smem + 2 * dim; + + const uint32_t query_ix = blockIdx.y; + const float* query = queries_float + query_ix * dim; + + if constexpr (kManageLocalTopK) { + out_distances += uint64_t(query_ix) * k * gridDim.x + blockIdx.x * k; + out_indices += uint64_t(query_ix) * k * gridDim.x + blockIdx.x * k; + } + + // Phase 1: load shared memory that is invariant across probes. + for (uint32_t d = threadIdx.x; d < dim; d += BlockDim) { + s_sq_scale[d] = sq_delta[d]; + if constexpr (kIsL2) { + s_aux[d] = query[d] - sq_vmin[d]; + } else { + s_query_term[d] = query[d]; + } + } + __syncthreads(); + + using local_topk_t = sq_block_sort_t; + local_topk_t queue(k); + + const uint32_t* my_coarse = coarse_indices + query_ix * n_probes; + const uint32_t* my_chunk = chunk_indices + query_ix * n_probes; + + constexpr uint32_t veclen = 16; + constexpr uint32_t kWarpsPerBlock = BlockDim / raft::WarpSize; + const uint32_t warp_id = threadIdx.x / raft::WarpSize; + const uint32_t lane_id = threadIdx.x % raft::WarpSize; + + // Phase 2: loop over probes. + // Synchronization protocol: + // (a) __syncthreads after Phase 1 (above) ensures invariant smem arrays + // (s_sq_scale, and L2: s_aux / IP-Cosine: s_query_term) are visible + // before Phase 2 overwrites the per-probe array. + // (b) __syncthreads after per-probe smem writes (L2: s_query_term / + // IP-Cosine: s_aux) ensures probe-specific values are visible before + // the distance computation. + // (c) __syncthreads at the end of each iteration ensures all distance + // computation reads are complete before the next iteration overwrites + // the per-probe smem region. + // When cluster_sz == 0, barrier (c) is skipped because no distance reads + // occurred; all threads converge on the same branch uniformly, and the + // next iteration's barrier (b) provides the needed ordering. + for (uint32_t probe_ix = blockIdx.x; probe_ix < n_probes; + probe_ix += (kManageLocalTopK ? gridDim.x : uint32_t{1})) { + const uint32_t cluster_id = my_coarse[probe_ix]; + const uint32_t cluster_sz = list_sizes[cluster_id]; + + { + const float* centroid = centers + cluster_id * dim; + for (uint32_t d = threadIdx.x; d < dim; d += BlockDim) { + if constexpr (kIsL2) { + s_query_term[d] = s_aux[d] - centroid[d]; + } else { + s_aux[d] = centroid[d] + sq_vmin[d]; + } + } + } + __syncthreads(); // (b) + + if (cluster_sz == 0) { + if constexpr (!kManageLocalTopK) break; + continue; + } + + const uint8_t* codes = data_ptrs[cluster_id]; + uint32_t sample_offset = (probe_ix > 0) ? my_chunk[probe_ix - 1] : 0; + uint32_t padded_dim = ((dim + veclen - 1) / veclen) * veclen; + uint32_t n_dim_blocks = padded_dim / veclen; + + for (uint32_t group = warp_id * kIndexGroupSize; group < cluster_sz; + group += kWarpsPerBlock * kIndexGroupSize) { + const uint32_t row = group + lane_id; + const bool valid = + (row < cluster_sz) && + sample_filter( + inds_ptrs, query_ix, cluster_id, row, bitset_ptr, bitset_len, original_nbits); + + float dist = 0.0f; + float v_norm_sq = 0.0f; + + const uint8_t* group_data = codes + size_t(group) * padded_dim; + + for (uint32_t bl = 0; bl < n_dim_blocks; bl++) { + uint8_t codes_local[veclen]; + *reinterpret_cast(codes_local) = *reinterpret_cast( + group_data + bl * (veclen * kIndexGroupSize) + lane_id * veclen); + + const uint32_t l = bl * veclen; +#pragma unroll + for (uint32_t j = 0; j < veclen; j++) { + if (l + j < dim) { + float recon = float(codes_local[j]) * s_sq_scale[l + j]; + + if constexpr (kIsL2) { + float diff = s_query_term[l + j] - recon; + dist += diff * diff; + } else { + float v_d = s_aux[l + j] + recon; + dist += s_query_term[l + j] * v_d; + if constexpr (kIsCosine) { v_norm_sq += v_d * v_d; } + } + } + } + } + + if constexpr (kIsCosine) { + float denom = query_norms[query_ix] * sqrtf(v_norm_sq); + dist = (denom > 0.0f) ? 1.0f - dist / denom : 0.0f; + } + + if constexpr (kManageLocalTopK) { + float val = valid ? dist : local_topk_t::queue_t::kDummy; + queue.add(val, sample_offset + row); + } else { + if (valid) { + uint32_t out_idx = query_ix * max_samples + sample_offset + row; + out_distances[out_idx] = dist; + out_indices[out_idx] = sample_offset + row; + } + } + } + + __syncthreads(); // (c) + if constexpr (!kManageLocalTopK) break; + } + + if constexpr (kManageLocalTopK) { + // All probe iterations are done; smem_buf is reused for block_sort merge. + // The loop's last (b) or (c) barrier ensures all prior smem accesses have + // completed, so this additional barrier is only needed to synchronize any + // register-level state across warps before the merge. + __syncthreads(); + queue.done(smem_buf); + queue.store(out_distances, out_indices); + + // block_sort initializes unused slots with (kDummy, idx=0). When the + // probed clusters have fewer than k total valid vectors, those slots + // survive into the output and share idx=0 with the real first vector, + // causing duplicates. Mark them with an invalid index so + // postprocess_neighbors treats them as out-of-bounds. + // store() is a warp-0-only operation, restrict the fixup to the same warp. + if (threadIdx.x < raft::WarpSize) { + constexpr auto kDummyVal = local_topk_t::queue_t::kDummy; + for (uint32_t i = threadIdx.x; i < k; i += raft::WarpSize) { + if (out_distances[i] == kDummyVal) { out_indices[i] = uint32_t(0xFFFFFFFF); } + } + } + } +} + +} // namespace cuvs::neighbors::ivf_sq::detail diff --git a/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_kernel.cu.in b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_kernel.cu.in new file mode 100644 index 0000000000..3c59f91764 --- /dev/null +++ b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_kernel.cu.in @@ -0,0 +1,62 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include + +namespace { + +using metric_tag = cuvs::neighbors::ivf_sq::detail::tag_metric_@metric_name@; +constexpr int capacity = @capacity@; + +} // namespace + +namespace cuvs::neighbors::ivf_sq::detail { + +extern "C" __global__ __launch_bounds__(kSqScanThreads) void ivf_sq_scan( + const uint8_t* const* data_ptrs, + const uint32_t* list_sizes, + const uint32_t* coarse_indices, + const float* queries_float, + const float* centers, + const float* sq_vmin, + const float* sq_delta, + const float* query_norms, + uint32_t n_probes, + uint32_t dim, + uint32_t k, + uint32_t max_samples, + const uint32_t* chunk_indices, + float* out_distances, + uint32_t* out_indices, + int64_t* const* inds_ptrs, + uint32_t* bitset_ptr, + int64_t bitset_len, + int64_t original_nbits) +{ + ivf_sq_scan_impl(data_ptrs, + list_sizes, + coarse_indices, + queries_float, + centers, + sq_vmin, + sq_delta, + query_norms, + n_probes, + dim, + k, + max_samples, + chunk_indices, + out_distances, + out_indices, + inds_ptrs, + bitset_ptr, + bitset_len, + original_nbits); +} + +static_assert(std::is_same_v>); + +} // namespace cuvs::neighbors::ivf_sq::detail diff --git a/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_matrix.json b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_matrix.json new file mode 100644 index 0000000000..fac4154c18 --- /dev/null +++ b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_matrix.json @@ -0,0 +1,8 @@ +{ + "capacity": [ + "0", "32", "64", "128", "256" + ], + "metric_name": [ + "l2", "ip", "cosine" + ] +} diff --git a/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_planner.hpp b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_planner.hpp new file mode 100644 index 0000000000..ae157e3994 --- /dev/null +++ b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_planner.hpp @@ -0,0 +1,38 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include +#include +#include + +#include + +namespace cuvs::neighbors::ivf_sq::detail { + +struct IvfSqScanPlanner : AlgorithmPlanner { + inline static LauncherJitCache launcher_jit_cache{}; + + IvfSqScanPlanner() : AlgorithmPlanner("ivf_sq_scan", launcher_jit_cache) {} + + template + void add_entrypoint() + { + this->add_static_fragment>(); + } + + template + void add_filter_device_function() + { + this->add_static_fragment>(); + this->add_static_fragment< + cuvs::neighbors::detail::fragment_tag_sample_filter>(); + } +}; + +} // namespace cuvs::neighbors::ivf_sq::detail diff --git a/cpp/src/neighbors/ivf_sq/ivf_sq_search.cuh b/cpp/src/neighbors/ivf_sq/ivf_sq_search.cuh index 6201cdd4f4..bbe6f05df1 100644 --- a/cpp/src/neighbors/ivf_sq/ivf_sq_search.cuh +++ b/cpp/src/neighbors/ivf_sq/ivf_sq_search.cuh @@ -9,6 +9,10 @@ #include "../detail/ann_utils.cuh" #include "../ivf_common.cuh" #include "../sample_filter.cuh" +#include "detail/jit_lto_kernels/kernel_def.hpp" +#include "detail/jit_lto_kernels/scan_planner.hpp" +#include +#include #include #include @@ -23,6 +27,7 @@ #include #include #include +#include #include #include @@ -33,10 +38,6 @@ namespace cuvs::neighbors::ivf_sq::detail { using namespace cuvs::spatial::knn::detail; // NOLINT -enum class SqScanMetric { kL2, kIP, kCosine }; - -static constexpr int kSqScanThreads = 128; - // Maximum fused top-k capacity we instantiate for the scan kernel. // Must match the highest Capacity case in ivf_sq_scan's switch. static constexpr int kMaxSqScanCapacity = 256; @@ -49,32 +50,57 @@ auto RAFT_WEAK_FUNCTION is_local_topk_feasible(uint32_t k) -> bool return k <= kMaxSqScanCapacity; } -// --------------------------------------------------------------------------- -// block_sort type selection (fused top-k vs dummy for Capacity == 0) -// --------------------------------------------------------------------------- -template -struct sq_block_sort { - using type = raft::matrix::detail::select::warpsort::block_sort< - raft::matrix::detail::select::warpsort::warp_sort_filtered, - Capacity, - Ascending, - float, - uint32_t>; -}; - -template -struct sq_block_sort<0, Ascending> { - using type = ivf::detail::dummy_block_sort_t; -}; - -template -using sq_block_sort_t = typename sq_block_sort::type; +// Compute shared-memory size for the scan kernel (3 x dim floats), plus the +// block_sort merge buffer when fused top-k is enabled. Must match the smem +// layout used by ivf_sq_scan_impl. +inline size_t sq_scan_smem_size(uint32_t dim) { return 3 * dim * sizeof(float); } -// --------------------------------------------------------------------------- -// configure_grid_dim_x: choose grid.x to saturate the GPU -// --------------------------------------------------------------------------- +template +size_t sq_scan_total_smem(uint32_t dim, uint32_t k) +{ + size_t scan_smem = sq_scan_smem_size(dim); + if constexpr (Capacity > 0) { + constexpr int kSubwarpSize = std::min(Capacity, raft::WarpSize); + int num_subwarps = kSqScanThreads / kSubwarpSize; + size_t merge_smem = + raft::matrix::detail::select::warpsort::calc_smem_size_for_block_wide( + num_subwarps, k); + return std::max(scan_smem, merge_smem); + } + return scan_smem; +} + +// Map the runtime distance metric onto the corresponding compile-time tag and +// invoke `f` (a generic lambda) with the tag type. Used by the JIT-LTO launch +// dispatch to forward the metric to the planner without bloating the static +// library with the kernel body. +template +void dispatch_metric_tag(cuvs::distance::DistanceType metric, F&& f) +{ + switch (metric) { + case cuvs::distance::DistanceType::L2Expanded: + case cuvs::distance::DistanceType::L2SqrtExpanded: f(tag_metric_l2{}); return; + case cuvs::distance::DistanceType::InnerProduct: f(tag_metric_ip{}); return; + case cuvs::distance::DistanceType::CosineExpanded: f(tag_metric_cosine{}); return; + default: RAFT_FAIL("Unsupported metric type for IVF-SQ scan."); + } +} + +template +constexpr auto get_filter_type_tag() +{ + using namespace cuvs::neighbors::filtering; + if constexpr (std::is_same_v) { + return cuvs::neighbors::detail::tag_filter_none{}; + } else if constexpr (std::is_same_v>) { + return cuvs::neighbors::detail::tag_filter_bitset{}; + } +} + +// Choose grid.x to saturate the GPU. Mirrors the previous static +// configure_grid_dim_x but takes the JIT-linked cudaKernel_t. inline uint32_t configure_grid_dim_x( - uint32_t n_queries, uint32_t n_probes, int smem_size, int block_size, const void* kernel_ptr) + uint32_t n_queries, uint32_t n_probes, int smem_size, int block_size, cudaKernel_t kernel) { int dev_id; RAFT_CUDA_TRY(cudaGetDevice(&dev_id)); @@ -82,7 +108,7 @@ inline uint32_t configure_grid_dim_x( RAFT_CUDA_TRY(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id)); int num_blocks_per_sm = 0; RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &num_blocks_per_sm, kernel_ptr, block_size, smem_size)); + &num_blocks_per_sm, kernel, block_size, smem_size)); size_t min_grid_size = size_t(num_sms) * num_blocks_per_sm; size_t min_grid_x = raft::ceildiv(min_grid_size, n_queries); @@ -90,355 +116,149 @@ inline uint32_t configure_grid_dim_x( } // --------------------------------------------------------------------------- -// IVF-SQ scan kernel with fused in-kernel top-k +// JIT-LTO scan launcher +// +// The scan kernel `ivf_sq_scan_impl` lives as a fatbin fragment under +// `detail/jit_lto_kernels/`. At runtime we ask the planner for the linked +// kernel matching (Capacity, MetricTag, FilterTag), set the dynamic shared +// memory attribute, compute grid_dim_x via cudaOccupancy*, and dispatch. // -// Grid layout: +// Grid layout matches the original static kernel: // kManageLocalTopK (Capacity > 0): -// grid (grid_dim_x, n_queries) — each block loops over probes +// grid (grid_dim_x, n_queries) - each block loops over probes // otherwise (Capacity == 0): -// grid (n_probes, n_queries) — one block per (query, probe) -// -// Shared-memory layout (always 3 × dim floats): -// [s_sq_scale(dim) | s_query_term(dim) | s_aux(dim)] -// -// s_sq_scale = delta[d] — SQ dequantization scale, invariant (Phase 1). -// -// L2 path: -// Phase 1: s_aux[d] = query[d] - vmin[d] (invariant) -// Phase 2: s_query_term[d] = s_aux[d] - centroid[d] (per-probe) -// The full SQ reconstruction is centroid + vmin + code*delta, so -// query - reconstructed = (query - vmin - centroid) - code*delta -// = s_query_term - code*s_sq_scale. -// -// IP/Cosine path: -// Phase 1: s_query_term[d] = query[d] (invariant) -// Phase 2: s_aux[d] = centroid[d] + vmin[d] (per-probe) -// Reconstructed vector component: s_aux[d] + code*s_sq_scale[d]. -// -// After all probes are scanned, the smem is reused for block_sort merge. +// grid (n_probes, n_queries) - one block per (query, probe) // --------------------------------------------------------------------------- -template -__launch_bounds__(BlockDim) RAFT_KERNEL ivf_sq_scan_kernel(const uint8_t* const* data_ptrs, - const uint32_t* list_sizes, - const uint32_t* coarse_indices, - const float* queries_float, - const float* centers, - const float* sq_vmin, - const float* sq_delta, - const float* query_norms, - uint32_t n_probes, - uint32_t dim, - uint32_t k, - uint32_t max_samples, - const uint32_t* chunk_indices, - float* out_distances, - uint32_t* out_indices, - IvfSampleFilterT sample_filter) +void launch_kernel(const index& idx, + const float* queries_float, + const float* query_norms, + uint32_t n_queries, + uint32_t n_probes, + uint32_t k, + uint32_t max_samples, + const uint32_t* coarse_indices, + const uint32_t* chunk_indices, + float* out_distances, + uint32_t* out_indices, + uint32_t& grid_dim_x, + rmm::cuda_stream_view stream, + IvfSampleFilterT sample_filter) { - static_assert(kIndexGroupSize == raft::WarpSize, - "Warp-coalesced scan requires kIndexGroupSize == WarpSize"); + static_assert(std::is_same_v, "IVF-SQ JIT-LTO scan only supports CodeT=uint8_t"); constexpr bool kManageLocalTopK = (Capacity > 0); - constexpr bool kIsL2 = (Metric == SqScanMetric::kL2); - constexpr bool kIsCosine = (Metric == SqScanMetric::kCosine); - constexpr bool kAscending = (Metric != SqScanMetric::kIP); - - extern __shared__ __align__(256) uint8_t smem_buf[]; - float* smem = reinterpret_cast(smem_buf); + constexpr int kThreads = kSqScanThreads; + uint32_t dim = idx.dim(); - float* s_sq_scale = smem; - float* s_query_term = smem + dim; - float* s_aux = smem + 2 * dim; + IvfSqScanPlanner kernel_planner; + kernel_planner.add_entrypoint(); + kernel_planner.add_filter_device_function(); + auto kernel_launcher = kernel_planner.get_launcher(); + + // Extract bitset arguments from the user's filter object so they can be + // passed to the JIT-linked kernel as plain pointers/scalars; the JIT'd + // sample_filter fragment reconstructs the bitset_view inside the kernel. + int64_t* const* inds_ptrs_dev = idx.inds_ptrs().data_handle(); + uint32_t* bitset_ptr = nullptr; + int64_t bitset_len = 0; + int64_t original_nbits = 0; + if constexpr (std::is_same_v>) { + bitset_ptr = sample_filter.view().data(); + bitset_len = sample_filter.view().size(); + original_nbits = sample_filter.view().get_original_nbits(); + } - const uint32_t query_ix = blockIdx.y; - const float* query = queries_float + query_ix * dim; + size_t smem = sq_scan_total_smem(dim, k); - // Point output to this block's slice when using fused top-k - if constexpr (kManageLocalTopK) { - out_distances += uint64_t(query_ix) * k * gridDim.x + blockIdx.x * k; - out_indices += uint64_t(query_ix) * k * gridDim.x + blockIdx.x * k; + { + int dev_id; + RAFT_CUDA_TRY(cudaGetDevice(&dev_id)); + int max_smem; + RAFT_CUDA_TRY( + cudaDeviceGetAttribute(&max_smem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev_id)); + RAFT_EXPECTS(smem <= size_t(max_smem), + "IVF-SQ scan kernel requires %zu bytes of shared memory (dim=%u, k=%u), " + "but the device supports at most %d bytes per block.", + smem, + dim, + k, + max_smem); } - // --- Phase 1: load shared memory that is invariant across probes --- - for (uint32_t d = threadIdx.x; d < dim; d += BlockDim) { - s_sq_scale[d] = sq_delta[d]; - if constexpr (kIsL2) { - s_aux[d] = query[d] - sq_vmin[d]; - } else { - s_query_term[d] = query[d]; - } + { + int dev_id; + RAFT_CUDA_TRY(cudaGetDevice(&dev_id)); + RAFT_CUDA_TRY(cudaKernelSetAttributeForDevice(kernel_launcher->get_kernel(), + cudaFuncAttributeMaxDynamicSharedMemorySize, + static_cast(smem), + dev_id)); } - __syncthreads(); - - using local_topk_t = sq_block_sort_t; - local_topk_t queue(k); - - const uint32_t* my_coarse = coarse_indices + query_ix * n_probes; - const uint32_t* my_chunk = chunk_indices + query_ix * n_probes; - - constexpr uint32_t veclen = 16; - constexpr uint32_t kWarpsPerBlock = BlockDim / raft::WarpSize; - const uint32_t warp_id = threadIdx.x / raft::WarpSize; - const uint32_t lane_id = threadIdx.x % raft::WarpSize; - - // --- Phase 2: loop over probes --- - // Synchronization protocol: - // (a) __syncthreads after Phase 1 (above) ensures invariant smem arrays - // (s_sq_scale, and L2: s_aux / IP-Cosine: s_query_term) are visible - // before Phase 2 overwrites the per-probe array. - // (b) __syncthreads after per-probe smem writes (L2: s_query_term / - // IP-Cosine: s_aux) ensures probe-specific values are visible before - // the distance computation. - // (c) __syncthreads at the end of each iteration ensures all distance - // computation reads are complete before the next iteration overwrites - // the per-probe smem region. - // When cluster_sz == 0, barrier (c) is skipped because no distance reads - // occurred; all threads converge on the same branch uniformly, and the - // next iteration's barrier (b) provides the needed ordering. - for (uint32_t probe_ix = blockIdx.x; probe_ix < n_probes; - probe_ix += (kManageLocalTopK ? gridDim.x : uint32_t{1})) { - const uint32_t cluster_id = my_coarse[probe_ix]; - const uint32_t cluster_sz = list_sizes[cluster_id]; - - // Load centroid-dependent shared memory terms - { - const float* centroid = centers + cluster_id * dim; - for (uint32_t d = threadIdx.x; d < dim; d += BlockDim) { - if constexpr (kIsL2) { - s_query_term[d] = s_aux[d] - centroid[d]; - } else { - s_aux[d] = centroid[d] + sq_vmin[d]; - } - } - } - __syncthreads(); // (b) - - if (cluster_sz == 0) { - // No distance computation reads happened, so no end-of-iteration - // barrier is needed; the next iteration's barrier (b) is sufficient. - if constexpr (!kManageLocalTopK) break; - continue; - } - - const uint8_t* codes = data_ptrs[cluster_id]; - uint32_t sample_offset = (probe_ix > 0) ? my_chunk[probe_ix - 1] : 0; - uint32_t padded_dim = ((dim + veclen - 1) / veclen) * veclen; - uint32_t n_dim_blocks = padded_dim / veclen; - - for (uint32_t group = warp_id * kIndexGroupSize; group < cluster_sz; - group += kWarpsPerBlock * kIndexGroupSize) { - const uint32_t row = group + lane_id; - const bool valid = (row < cluster_sz) && sample_filter(query_ix, cluster_id, row); - - float dist = 0.0f; - float v_norm_sq = 0.0f; - - const uint8_t* group_data = codes + size_t(group) * padded_dim; - - for (uint32_t bl = 0; bl < n_dim_blocks; bl++) { - uint8_t codes_local[veclen]; - *reinterpret_cast(codes_local) = *reinterpret_cast( - group_data + bl * (veclen * kIndexGroupSize) + lane_id * veclen); - - const uint32_t l = bl * veclen; -#pragma unroll - for (uint32_t j = 0; j < veclen; j++) { - if (l + j < dim) { - float recon = float(codes_local[j]) * s_sq_scale[l + j]; - - if constexpr (kIsL2) { - float diff = s_query_term[l + j] - recon; - dist += diff * diff; - } else { - float v_d = s_aux[l + j] + recon; - dist += s_query_term[l + j] * v_d; - if constexpr (kIsCosine) { v_norm_sq += v_d * v_d; } - } - } - } - } - - if constexpr (kIsCosine) { - float denom = query_norms[query_ix] * sqrtf(v_norm_sq); - dist = (denom > 0.0f) ? 1.0f - dist / denom : 0.0f; - } - - if constexpr (kManageLocalTopK) { - float val = valid ? dist : local_topk_t::queue_t::kDummy; - queue.add(val, sample_offset + row); - } else { - if (valid) { - uint32_t out_idx = query_ix * max_samples + sample_offset + row; - out_distances[out_idx] = dist; - out_indices[out_idx] = sample_offset + row; - } - } - } - __syncthreads(); // (c) - if constexpr (!kManageLocalTopK) break; - } + constexpr uint32_t kMaxGridY = 65535; if constexpr (kManageLocalTopK) { - // All probe iterations are done; smem_buf is reused for block_sort merge. - // The loop's last (b) or (c) barrier ensures all prior smem accesses have - // completed, so this additional barrier is only needed to synchronize any - // register-level state across warps before the merge. - __syncthreads(); - queue.done(smem_buf); - queue.store(out_distances, out_indices); - - // block_sort initializes unused slots with (kDummy, idx=0). When the - // probed clusters have fewer than k total valid vectors, those slots - // survive into the output and share idx=0 with the real first vector, - // causing duplicates. Mark them with an invalid index so - // postprocess_neighbors treats them as out-of-bounds. - // store() is a warp-0-only operation, restrict the fixup to the same warp. - if (threadIdx.x < raft::WarpSize) { - constexpr auto kDummyVal = local_topk_t::queue_t::kDummy; - for (uint32_t i = threadIdx.x; i < k; i += raft::WarpSize) { - if (out_distances[i] == kDummyVal) { out_indices[i] = uint32_t(0xFFFFFFFF); } - } + if (grid_dim_x == 0) { + grid_dim_x = configure_grid_dim_x( + std::min(kMaxGridY, n_queries), n_probes, smem, kThreads, kernel_launcher->get_kernel()); + return; } } -} - -// --------------------------------------------------------------------------- -// Compute shared-memory size for a given kernel configuration -// --------------------------------------------------------------------------- -inline size_t sq_scan_smem_size(uint32_t dim) { return 3 * dim * sizeof(float); } - -template -size_t sq_scan_total_smem(uint32_t dim, uint32_t k) -{ - size_t scan_smem = sq_scan_smem_size(dim); - if constexpr (Capacity > 0) { - constexpr int kSubwarpSize = std::min(Capacity, raft::WarpSize); - int num_subwarps = kSqScanThreads / kSubwarpSize; - size_t merge_smem = - raft::matrix::detail::select::warpsort::calc_smem_size_for_block_wide( - num_subwarps, k); - return std::max(scan_smem, merge_smem); - } - return scan_smem; -} - -// --------------------------------------------------------------------------- -// Launch helper: dispatches on Metric, handles grid_dim_x query vs launch -// --------------------------------------------------------------------------- -template -void ivf_sq_scan_launch(const index& idx, - const float* queries_float, - const float* query_norms, - uint32_t n_queries, - uint32_t n_probes, - uint32_t k, - uint32_t max_samples, - const uint32_t* coarse_indices, - const uint32_t* chunk_indices, - float* out_distances, - uint32_t* out_indices, - IvfSampleFilterT sample_filter, - uint32_t& grid_dim_x, - rmm::cuda_stream_view stream) -{ - constexpr bool kManageLocalTopK = (Capacity > 0); - constexpr int kThreads = kSqScanThreads; - uint32_t dim = idx.dim(); - constexpr uint32_t kMaxGridY = 65535; - - auto do_launch = [&](auto kernel_ptr) { - size_t smem = sq_scan_total_smem(dim, k); - - { - int dev_id; - RAFT_CUDA_TRY(cudaGetDevice(&dev_id)); - int max_smem; - RAFT_CUDA_TRY( - cudaDeviceGetAttribute(&max_smem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev_id)); - RAFT_EXPECTS(smem <= size_t(max_smem), - "IVF-SQ scan kernel requires %zu bytes of shared memory (dim=%u, k=%u), " - "but the device supports at most %d bytes per block.", - smem, - dim, - k, - max_smem); - } + dim3 block(kThreads); - RAFT_CUDA_TRY( - cudaFuncSetAttribute(kernel_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, smem)); + for (uint32_t query_offset = 0; query_offset < n_queries; query_offset += kMaxGridY) { + uint32_t batch = std::min(kMaxGridY, n_queries - query_offset); + dim3 grid = kManageLocalTopK ? dim3(grid_dim_x, batch) : dim3(n_probes, batch); - // If grid_dim_x == 0, compute the optimal value and return + const float* q_ptr = queries_float + uint64_t(query_offset) * dim; + const float* qn_ptr = query_norms ? query_norms + query_offset : query_norms; + const uint32_t* ci = coarse_indices + uint64_t(query_offset) * n_probes; + const uint32_t* ch = chunk_indices + uint64_t(query_offset) * n_probes; + float* od = out_distances; + uint32_t* oi = out_indices; if constexpr (kManageLocalTopK) { - if (grid_dim_x == 0) { - grid_dim_x = configure_grid_dim_x(std::min(kMaxGridY, n_queries), - n_probes, - smem, - kThreads, - reinterpret_cast(kernel_ptr)); - return; - } - } - - dim3 block(kThreads); - - // Batch over queries to respect the gridDim.y limit (65535) - for (uint32_t query_offset = 0; query_offset < n_queries; query_offset += kMaxGridY) { - uint32_t batch = std::min(kMaxGridY, n_queries - query_offset); - dim3 grid = kManageLocalTopK ? dim3(grid_dim_x, batch) : dim3(n_probes, batch); - - auto q_ptr = queries_float + uint64_t(query_offset) * dim; - auto qn_ptr = query_norms ? query_norms + query_offset : query_norms; - auto ci = coarse_indices + uint64_t(query_offset) * n_probes; - auto ch = chunk_indices + uint64_t(query_offset) * n_probes; - auto od = out_distances; - auto oi = out_indices; - if constexpr (kManageLocalTopK) { - od += uint64_t(query_offset) * grid_dim_x * k; - oi += uint64_t(query_offset) * grid_dim_x * k; - } else { - od += uint64_t(query_offset) * max_samples; - oi += uint64_t(query_offset) * max_samples; - } - - kernel_ptr<<>>(idx.data_ptrs().data_handle(), - idx.list_sizes().data_handle(), - ci, - q_ptr, - idx.centers().data_handle(), - idx.sq_vmin().data_handle(), - idx.sq_delta().data_handle(), - qn_ptr, - n_probes, - dim, - k, - max_samples, - ch, - od, - oi, - sample_filter); - RAFT_CUDA_TRY(cudaPeekAtLastError()); + od += uint64_t(query_offset) * grid_dim_x * k; + oi += uint64_t(query_offset) * grid_dim_x * k; + } else { + od += uint64_t(query_offset) * max_samples; + oi += uint64_t(query_offset) * max_samples; } - }; - switch (idx.metric()) { - case cuvs::distance::DistanceType::L2Expanded: - case cuvs::distance::DistanceType::L2SqrtExpanded: - do_launch(ivf_sq_scan_kernel); - break; - case cuvs::distance::DistanceType::InnerProduct: - do_launch(ivf_sq_scan_kernel); - break; - case cuvs::distance::DistanceType::CosineExpanded: - do_launch( - ivf_sq_scan_kernel); - break; - default: RAFT_FAIL("Unsupported metric type for IVF-SQ scan."); + const uint8_t* const* data_ptrs = idx.data_ptrs().data_handle(); + const uint32_t* list_sizes = idx.list_sizes().data_handle(); + const float* centers = idx.centers().data_handle(); + const float* sq_vmin = idx.sq_vmin().data_handle(); + const float* sq_delta = idx.sq_delta().data_handle(); + + kernel_launcher->dispatch>(stream, + grid, + block, + smem, + data_ptrs, + list_sizes, + ci, + q_ptr, + centers, + sq_vmin, + sq_delta, + qn_ptr, + n_probes, + dim, + k, + max_samples, + ch, + od, + oi, + inds_ptrs_dev, + bitset_ptr, + bitset_len, + original_nbits); } } @@ -475,21 +295,27 @@ void ivf_sq_scan(raft::resources const& handle, capacity = 0; } + using IvfSampleFilterTag = decltype(get_filter_type_tag()); + auto fwd = [&](auto cap_tag) { - ivf_sq_scan_launch(idx, - queries_float, - query_norms, - n_queries, - n_probes, - k, - max_samples, - coarse_indices, - chunk_indices, - out_distances, - out_indices, - sample_filter, - grid_dim_x, - stream); + constexpr int kCap = decltype(cap_tag)::value; + dispatch_metric_tag(idx.metric(), [&](auto metric_tag) { + using MetricTag = decltype(metric_tag); + launch_kernel(idx, + queries_float, + query_norms, + n_queries, + n_probes, + k, + max_samples, + coarse_indices, + chunk_indices, + out_distances, + out_indices, + grid_dim_x, + stream, + sample_filter); + }); }; switch (capacity) { @@ -642,9 +468,6 @@ void search_impl(raft::resources const& handle, num_samples.data(), stream); - auto filter_adapter = cuvs::neighbors::filtering::ivf_to_sample_filter( - index.inds_ptrs().data_handle(), sample_filter); - bool manage_local_topk = is_local_topk_feasible(k); // Determine grid_dim_x for the fused path @@ -663,7 +486,7 @@ void search_impl(raft::resources const& handle, chunk_index.data(), nullptr, nullptr, - filter_adapter, + sample_filter, grid_dim_x, stream); if (grid_dim_x == 0) { @@ -717,7 +540,7 @@ void search_impl(raft::resources const& handle, chunk_index.data(), dist_out_ptr, idx_out_ptr, - filter_adapter, + sample_filter, grid_dim_x, stream); @@ -770,7 +593,7 @@ void search_impl(raft::resources const& handle, chunk_index.data(), all_distances.data(), all_indices.data(), - filter_adapter, + sample_filter, gdx, stream); diff --git a/cpp/src/neighbors/ivf_sq/ivf_sq_search_half_uint8_t_int64_t.cu b/cpp/src/neighbors/ivf_sq/ivf_sq_search_half_uint8_t_int64_t.cu deleted file mode 100644 index 40029119b2..0000000000 --- a/cpp/src/neighbors/ivf_sq/ivf_sq_search_half_uint8_t_int64_t.cu +++ /dev/null @@ -1,29 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. - * SPDX-License-Identifier: Apache-2.0 - */ - -#include - -#include "ivf_sq_search.cuh" - -namespace cuvs::neighbors::ivf_sq { - -#define CUVS_INST_IVF_SQ_SEARCH(T, CodeT) \ - void search(raft::resources const& handle, \ - const cuvs::neighbors::ivf_sq::search_params& params, \ - const cuvs::neighbors::ivf_sq::index& index, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances, \ - const cuvs::neighbors::filtering::base_filter& sample_filter) \ - { \ - cuvs::neighbors::ivf_sq::detail::search( \ - handle, params, index, queries, neighbors, distances, sample_filter); \ - } - -CUVS_INST_IVF_SQ_SEARCH(half, uint8_t); - -#undef CUVS_INST_IVF_SQ_SEARCH - -} // namespace cuvs::neighbors::ivf_sq diff --git a/cpp/src/neighbors/ivf_sq/ivf_sq_search_float_uint8_t_int64_t.cu b/cpp/src/neighbors/ivf_sq/ivf_sq_search_uint8_t_int64_t.cu similarity index 97% rename from cpp/src/neighbors/ivf_sq/ivf_sq_search_float_uint8_t_int64_t.cu rename to cpp/src/neighbors/ivf_sq/ivf_sq_search_uint8_t_int64_t.cu index de185de8ec..48247e207b 100644 --- a/cpp/src/neighbors/ivf_sq/ivf_sq_search_float_uint8_t_int64_t.cu +++ b/cpp/src/neighbors/ivf_sq/ivf_sq_search_uint8_t_int64_t.cu @@ -23,6 +23,7 @@ namespace cuvs::neighbors::ivf_sq { } CUVS_INST_IVF_SQ_SEARCH(float, uint8_t); +CUVS_INST_IVF_SQ_SEARCH(half, uint8_t); #undef CUVS_INST_IVF_SQ_SEARCH From 55a91bf794dc23f0795b8b2c8d0d7d00008f451c Mon Sep 17 00:00:00 2001 From: vic Date: Thu, 30 Apr 2026 15:24:20 +0200 Subject: [PATCH 27/31] doc fix + build assert addition --- cpp/include/cuvs/neighbors/ivf_sq.hpp | 14 +++++++++----- cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh | 2 ++ 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/cpp/include/cuvs/neighbors/ivf_sq.hpp b/cpp/include/cuvs/neighbors/ivf_sq.hpp index ba4bb39437..b342d0ce50 100644 --- a/cpp/include/cuvs/neighbors/ivf_sq.hpp +++ b/cpp/include/cuvs/neighbors/ivf_sq.hpp @@ -130,12 +130,16 @@ using list_data = ivf::list; * In the IVF-SQ index, a database vector is first assigned to the nearest cluster center * using an inverted file (IVF) structure, and then compressed using scalar quantization (SQ). * - * Scalar quantization independently maps each dimension of the vector to a fixed-width integer - * code. For 8-bit quantization (uint8_t), each floating-point component is linearly mapped to - * an integer in [0, 255] using learned per-dimension minimum (`sq_vmin`) and range (`sq_delta`) - * values: + * Scalar quantization independently maps each dimension of the per-cluster residual (the + * input vector minus its assigned centroid) to a fixed-width integer code. For 8-bit + * quantization (uint8_t), each residual component r_i = x_i - centroid_i is linearly + * mapped to an integer in [0, 255] using learned per-dimension minimum (`sq_vmin`) and + * step-size (`sq_delta`) values: * - * code_i = round((x_i - vmin_i) / delta_i * 255) + * code_i = clamp(round((r_i - vmin_i) / delta_i), 0, 255) + * + * where delta_i is the per-level step size (range divided by 255), so the corresponding + * reconstruction is x_i ≈ centroid_i + vmin_i + code_i * delta_i. * * This provides a compact representation (1 byte per dimension) while preserving the relative * distances between vectors with high fidelity, offering a good trade-off between index size, diff --git a/cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh b/cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh index a4a2063ca5..7574c3d892 100644 --- a/cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh +++ b/cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh @@ -548,6 +548,8 @@ inline auto build( static_assert(std::is_same_v || std::is_same_v, "unsupported data type"); RAFT_EXPECTS(n_rows > 0 && dim > 0, "empty dataset"); RAFT_EXPECTS(n_rows >= params.n_lists, "number of rows can't be less than n_lists"); + RAFT_EXPECTS(params.kmeans_trainset_fraction > 0.0 && params.kmeans_trainset_fraction <= 1.0, + "kmeans_trainset_fraction must be in (0, 1]"); RAFT_EXPECTS(params.metric != cuvs::distance::DistanceType::CosineExpanded || dim > 1, "Cosine metric requires more than one dim"); From 6d5ec7247af304e37d6736052a1883861f54f462 Mon Sep 17 00:00:00 2001 From: vic Date: Mon, 4 May 2026 10:45:51 +0200 Subject: [PATCH 28/31] Switching to raft::TxN_t --- cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh | 91 +++++------------------ 1 file changed, 19 insertions(+), 72 deletions(-) diff --git a/cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh b/cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh index 7574c3d892..52be653985 100644 --- a/cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh +++ b/cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh @@ -30,6 +30,7 @@ #include #include #include +#include #include @@ -58,82 +59,28 @@ struct ColMinMaxOp { /** * Vectorized load helper: reads VecCols contiguous elements of type T as - * a single aligned wide load and unpacks them into floats. + * a single aligned wide load (via raft::TxN_t -> __ldg on the promoted + * io_t) and unpacks them into floats. * * The primary benefit over scalar loads is halving (VecCols=2) or * quartering (VecCols=4) the number of LDG instructions issued per warp, * which is the dominant cost in the column-strided access pattern of - * fused_column_minmax_kernel. VecCols=1 is provided as the degenerate - * scalar fallback for odd `dim`. + * fused_column_minmax_kernel. VecCols=1 is the degenerate scalar + * fallback for odd `dim`. * - * Requires `p` to be aligned to sizeof(T) * VecCols. + * Requires `p` to be aligned to sizeof(raft::IOType::Type), + * i.e. sizeof(T) * VecCols. */ template -struct vec_loader; - -template <> -struct vec_loader { - __device__ __forceinline__ static void load(const float* p, float (&out)[1]) { out[0] = *p; } -}; - -template <> -struct vec_loader { - __device__ __forceinline__ static void load(const half* p, float (&out)[1]) - { - out[0] = float(*p); - } -}; - -template <> -struct vec_loader { - __device__ __forceinline__ static void load(const float* p, float (&out)[4]) - { - float4 v = *reinterpret_cast(p); - out[0] = v.x; - out[1] = v.y; - out[2] = v.z; - out[3] = v.w; - } -}; - -template <> -struct vec_loader { - __device__ __forceinline__ static void load(const float* p, float (&out)[2]) - { - float2 v = *reinterpret_cast(p); - out[0] = v.x; - out[1] = v.y; - } -}; - -template <> -struct vec_loader { - __device__ __forceinline__ static void load(const half* p, float (&out)[4]) - { - // Single 8-byte load covering 4 halves; memcpy avoids aliasing issues - // and is compiled to a register move in device code. - uint2 raw = *reinterpret_cast(p); - half h[4]; - static_assert(sizeof(h) == sizeof(raw), "unexpected half packing"); - memcpy(&h[0], &raw, sizeof(raw)); +__device__ __forceinline__ void load_cols_as_float(const T* p, float (&out)[VecCols]) +{ + raft::TxN_t v; + v.load(p, 0); #pragma unroll - for (int k = 0; k < 4; ++k) - out[k] = float(h[k]); - } -}; - -template <> -struct vec_loader { - __device__ __forceinline__ static void load(const half* p, float (&out)[2]) - { - uint32_t raw = *reinterpret_cast(p); - half h[2]; - static_assert(sizeof(h) == sizeof(raw), "unexpected half packing"); - memcpy(&h[0], &raw, sizeof(raw)); - out[0] = float(h[0]); - out[1] = float(h[1]); + for (int k = 0; k < VecCols; ++k) { + out[k] = static_cast(v.val.data[k]); } -}; +} /** * Fused per-column min+max in a single pass (2x less DRAM traffic than two @@ -176,10 +123,10 @@ __launch_bounds__(BlockSize) RAFT_KERNEL fused_column_minmax_kernel(const T* __r // the column and row axes. for (; row + 3 * stride < n_rows; row += 4 * stride) { float r0[VecCols], r1[VecCols], r2[VecCols], r3[VecCols]; - vec_loader::load(data + row * dim + col_base, r0); - vec_loader::load(data + (row + stride) * dim + col_base, r1); - vec_loader::load(data + (row + 2 * stride) * dim + col_base, r2); - vec_loader::load(data + (row + 3 * stride) * dim + col_base, r3); + load_cols_as_float(data + row * dim + col_base, r0); + load_cols_as_float(data + (row + stride) * dim + col_base, r1); + load_cols_as_float(data + (row + 2 * stride) * dim + col_base, r2); + load_cols_as_float(data + (row + 3 * stride) * dim + col_base, r3); #pragma unroll for (int k = 0; k < VecCols; ++k) { float mn = fminf(fminf(r0[k], r1[k]), fminf(r2[k], r3[k])); @@ -190,7 +137,7 @@ __launch_bounds__(BlockSize) RAFT_KERNEL fused_column_minmax_kernel(const T* __r } for (; row < n_rows; row += stride) { float r[VecCols]; - vec_loader::load(data + row * dim + col_base, r); + load_cols_as_float(data + row * dim + col_base, r); #pragma unroll for (int k = 0; k < VecCols; ++k) { agg[k].min_val = fminf(agg[k].min_val, r[k]); From 688962424f93ee27587264a04f902a0232af562b Mon Sep 17 00:00:00 2001 From: vic Date: Tue, 5 May 2026 11:39:27 +0200 Subject: [PATCH 29/31] Dropping the MetricTag template parameter --- cpp/CMakeLists.txt | 57 ++++++++++++++++- .../detail/jit_lto/ivf_sq/scan_fragments.hpp | 17 ++++- .../accumulate_distance_impl.cuh | 46 ++++++++++++++ .../accumulate_distance_kernel.cu.in | 24 +++++++ .../accumulate_distance_matrix.json | 3 + .../jit_lto_kernels/device_functions.cuh | 28 +++++++++ .../finalize_distance_impl.cuh | 40 ++++++++++++ .../finalize_distance_kernel.cu.in | 23 +++++++ .../finalize_distance_matrix.json | 3 + .../detail/jit_lto_kernels/scan_impl.cuh | 63 +++++++++---------- .../detail/jit_lto_kernels/scan_kernel.cu.in | 42 ++++++------- .../detail/jit_lto_kernels/scan_matrix.json | 11 +++- .../detail/jit_lto_kernels/scan_planner.hpp | 28 ++++++++- .../setup_invariant_smem_impl.cuh | 50 +++++++++++++++ .../setup_invariant_smem_kernel.cu.in | 27 ++++++++ .../setup_invariant_smem_matrix.json | 3 + .../setup_per_probe_smem_impl.cuh | 55 ++++++++++++++++ .../setup_per_probe_smem_kernel.cu.in | 27 ++++++++ .../setup_per_probe_smem_matrix.json | 3 + cpp/src/neighbors/ivf_sq/ivf_sq_search.cuh | 12 +++- 20 files changed, 497 insertions(+), 65 deletions(-) create mode 100644 cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/accumulate_distance_impl.cuh create mode 100644 cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/accumulate_distance_kernel.cu.in create mode 100644 cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/accumulate_distance_matrix.json create mode 100644 cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/finalize_distance_impl.cuh create mode 100644 cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/finalize_distance_kernel.cu.in create mode 100644 cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/finalize_distance_matrix.json create mode 100644 cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/setup_invariant_smem_impl.cuh create mode 100644 cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/setup_invariant_smem_kernel.cu.in create mode 100644 cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/setup_invariant_smem_matrix.json create mode 100644 cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/setup_per_probe_smem_impl.cuh create mode 100644 cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/setup_per_probe_smem_kernel.cu.in create mode 100644 cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/setup_per_probe_smem_matrix.json diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index f957f8b702..65cec8837a 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -490,18 +490,69 @@ if(NOT BUILD_CPU_ONLY) set(ivf_sq_ns "cuvs::neighbors::ivf_sq::detail") generate_jit_lto_kernels( jit_lto_files - NAME_FORMAT "ivf_sq_scan_capacity_@capacity@_metric_@metric_name@" + NAME_FORMAT "ivf_sq_scan_capacity_@capacity@_@ascending_descending@" MATRIX_JSON_FILE "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_matrix.json" KERNEL_INPUT_FILE "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_kernel.cu.in" - FRAGMENT_TAG_FORMAT - "${ivf_sq_ns}::fragment_tag_ivf_sq_scan<${ivf_sq_ns}::tag_metric_@metric_name@, @capacity@>" + FRAGMENT_TAG_FORMAT "${ivf_sq_ns}::fragment_tag_ivf_sq_scan<@capacity@, @ascending_value@>" FRAGMENT_TAG_HEADER_FILES "" "" OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/ivf_sq/scan" KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements ) + generate_jit_lto_kernels( + jit_lto_files + NAME_FORMAT "ivf_sq_setup_invariant_smem_metric_@metric_name@" + MATRIX_JSON_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/ivf_sq/detail/jit_lto_kernels/setup_invariant_smem_matrix.json" + KERNEL_INPUT_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/ivf_sq/detail/jit_lto_kernels/setup_invariant_smem_kernel.cu.in" + FRAGMENT_TAG_FORMAT + "${ivf_sq_ns}::fragment_tag_setup_invariant_smem<${ivf_sq_ns}::tag_metric_@metric_name@>" + FRAGMENT_TAG_HEADER_FILES "" + OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/ivf_sq/setup_invariant_smem" + KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements + ) + generate_jit_lto_kernels( + jit_lto_files + NAME_FORMAT "ivf_sq_setup_per_probe_smem_metric_@metric_name@" + MATRIX_JSON_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/ivf_sq/detail/jit_lto_kernels/setup_per_probe_smem_matrix.json" + KERNEL_INPUT_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/ivf_sq/detail/jit_lto_kernels/setup_per_probe_smem_kernel.cu.in" + FRAGMENT_TAG_FORMAT + "${ivf_sq_ns}::fragment_tag_setup_per_probe_smem<${ivf_sq_ns}::tag_metric_@metric_name@>" + FRAGMENT_TAG_HEADER_FILES "" + OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/ivf_sq/setup_per_probe_smem" + KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements + ) + generate_jit_lto_kernels( + jit_lto_files + NAME_FORMAT "ivf_sq_accumulate_distance_metric_@metric_name@" + MATRIX_JSON_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/ivf_sq/detail/jit_lto_kernels/accumulate_distance_matrix.json" + KERNEL_INPUT_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/ivf_sq/detail/jit_lto_kernels/accumulate_distance_kernel.cu.in" + FRAGMENT_TAG_FORMAT + "${ivf_sq_ns}::fragment_tag_accumulate_distance<${ivf_sq_ns}::tag_metric_@metric_name@>" + FRAGMENT_TAG_HEADER_FILES "" + OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/ivf_sq/accumulate_distance" + KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements + ) + generate_jit_lto_kernels( + jit_lto_files + NAME_FORMAT "ivf_sq_finalize_distance_metric_@metric_name@" + MATRIX_JSON_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/ivf_sq/detail/jit_lto_kernels/finalize_distance_matrix.json" + KERNEL_INPUT_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/ivf_sq/detail/jit_lto_kernels/finalize_distance_kernel.cu.in" + FRAGMENT_TAG_FORMAT + "${ivf_sq_ns}::fragment_tag_finalize_distance<${ivf_sq_ns}::tag_metric_@metric_name@>" + FRAGMENT_TAG_HEADER_FILES "" + OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/ivf_sq/finalize_distance" + KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements + ) generate_jit_lto_kernels( jit_lto_files NAME_FORMAT "ivf_sq_filter_@filter_name@" diff --git a/cpp/include/cuvs/detail/jit_lto/ivf_sq/scan_fragments.hpp b/cpp/include/cuvs/detail/jit_lto/ivf_sq/scan_fragments.hpp index b20684b11f..c505ed3f8f 100644 --- a/cpp/include/cuvs/detail/jit_lto/ivf_sq/scan_fragments.hpp +++ b/cpp/include/cuvs/detail/jit_lto/ivf_sq/scan_fragments.hpp @@ -11,10 +11,25 @@ struct tag_metric_l2 {}; struct tag_metric_ip {}; struct tag_metric_cosine {}; -template +// Scan entrypoint fragment. Templated only on (Capacity, Ascending). The +// metric specialization lives in the four device-function fragments below. +template struct fragment_tag_ivf_sq_scan {}; template struct fragment_tag_ivf_sq_filter {}; +// Metric-specific device-function fragments composed in at JIT-link time. +template +struct fragment_tag_setup_invariant_smem {}; + +template +struct fragment_tag_setup_per_probe_smem {}; + +template +struct fragment_tag_accumulate_distance {}; + +template +struct fragment_tag_finalize_distance {}; + } // namespace cuvs::neighbors::ivf_sq::detail diff --git a/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/accumulate_distance_impl.cuh b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/accumulate_distance_impl.cuh new file mode 100644 index 0000000000..5368e101cd --- /dev/null +++ b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/accumulate_distance_impl.cuh @@ -0,0 +1,46 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include + +namespace cuvs::neighbors::ivf_sq::detail { + +// Per-element distance accumulator. Called inside the unrolled inner loop of +// ivf_sq_scan_impl. After JIT-LTO inlining, these bodies fold directly into +// the unrolled loop and the v_norm_sq plumbing is dead-code-eliminated for +// non-cosine metrics. +// +// L2: diff = qt - code*scale; dist += diff*diff +// IP: v = aux + code*scale; dist += qt*v +// Cosine: as IP, plus v_norm_sq += v*v + +__device__ void accumulate_distance_l2_impl( + float qt, float /* aux */, float scale, uint8_t code, float& dist, float& /* v_norm_sq */) +{ + float recon = float(code) * scale; + float diff = qt - recon; + dist += diff * diff; +} + +__device__ void accumulate_distance_ip_impl( + float qt, float aux, float scale, uint8_t code, float& dist, float& /* v_norm_sq */) +{ + float recon = float(code) * scale; + float v_d = aux + recon; + dist += qt * v_d; +} + +__device__ void accumulate_distance_cosine_impl( + float qt, float aux, float scale, uint8_t code, float& dist, float& v_norm_sq) +{ + float recon = float(code) * scale; + float v_d = aux + recon; + dist += qt * v_d; + v_norm_sq += v_d * v_d; +} + +} // namespace cuvs::neighbors::ivf_sq::detail diff --git a/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/accumulate_distance_kernel.cu.in b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/accumulate_distance_kernel.cu.in new file mode 100644 index 0000000000..7d7e77c85a --- /dev/null +++ b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/accumulate_distance_kernel.cu.in @@ -0,0 +1,24 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include + +namespace { + +constexpr auto accumulate_distance_impl = + cuvs::neighbors::ivf_sq::detail::accumulate_distance_@metric_name@_impl; + +} // namespace + +namespace cuvs::neighbors::ivf_sq::detail { + +__device__ void accumulate_distance( + float qt, float aux, float scale, uint8_t code, float& dist, float& v_norm_sq) +{ + accumulate_distance_impl(qt, aux, scale, code, dist, v_norm_sq); +} + +} // namespace cuvs::neighbors::ivf_sq::detail diff --git a/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/accumulate_distance_matrix.json b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/accumulate_distance_matrix.json new file mode 100644 index 0000000000..915ea1284a --- /dev/null +++ b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/accumulate_distance_matrix.json @@ -0,0 +1,3 @@ +{ + "metric_name": ["l2", "ip", "cosine"] +} diff --git a/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/device_functions.cuh b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/device_functions.cuh index 188933d803..ca22c21168 100644 --- a/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/device_functions.cuh +++ b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/device_functions.cuh @@ -22,4 +22,32 @@ __device__ bool sample_filter(const IndexT* const* const inds_ptrs, IndexT bitset_len, IndexT original_nbits); +// Forward declarations of the metric-specific scan device functions. The +// concrete implementations are provided by JIT-LTO fragments generated from +// setup_invariant_smem_kernel.cu.in, setup_per_probe_smem_kernel.cu.in, +// accumulate_distance_kernel.cu.in and finalize_distance_kernel.cu.in. After +// nvJitLink LTO inlines these into ivf_sq_scan_impl, the codegen matches the +// pre-refactor `if constexpr (kIsL2 / kIsCosine)` form. + +// Phase 1: load the metric-invariant smem array. Called once per query. +__device__ void setup_invariant_smem(uint32_t dim, + const float* __restrict__ query, + const float* __restrict__ sq_vmin, + float* __restrict__ s_aux, + float* __restrict__ s_query_term); + +// Phase 2: load the per-probe smem array. Called once per (query, probe). +__device__ void setup_per_probe_smem(uint32_t dim, + const float* __restrict__ centroid, + const float* __restrict__ sq_vmin, + float* __restrict__ s_aux, + float* __restrict__ s_query_term); + +// Per-element distance accumulator. Called inside the unrolled inner loop. +__device__ void accumulate_distance( + float qt, float aux, float scale, uint8_t code, float& dist, float& v_norm_sq); + +// Per-row distance finalize. Called once per scanned row. +__device__ float finalize_distance(float dist, float v_norm_sq, float query_norm); + } // namespace cuvs::neighbors::ivf_sq::detail diff --git a/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/finalize_distance_impl.cuh b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/finalize_distance_impl.cuh new file mode 100644 index 0000000000..2d31ff863a --- /dev/null +++ b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/finalize_distance_impl.cuh @@ -0,0 +1,40 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include + +#include + +namespace cuvs::neighbors::ivf_sq::detail { + +// Per-row distance finalize. Called once per scanned row. +// +// L2/IP: return dist +// Cosine: denom = query_norm * sqrtf(v_norm_sq); +// return denom > 0 ? 1 - dist/denom : 0 + +__device__ float finalize_distance_l2_impl(float dist, + float /* v_norm_sq */, + float /* query_norm */) +{ + return dist; +} + +__device__ float finalize_distance_ip_impl(float dist, + float /* v_norm_sq */, + float /* query_norm */) +{ + return dist; +} + +__device__ float finalize_distance_cosine_impl(float dist, float v_norm_sq, float query_norm) +{ + float denom = query_norm * sqrtf(v_norm_sq); + return (denom > 0.0f) ? 1.0f - dist / denom : 0.0f; +} + +} // namespace cuvs::neighbors::ivf_sq::detail diff --git a/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/finalize_distance_kernel.cu.in b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/finalize_distance_kernel.cu.in new file mode 100644 index 0000000000..642e4f1d5d --- /dev/null +++ b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/finalize_distance_kernel.cu.in @@ -0,0 +1,23 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include + +namespace { + +constexpr auto finalize_distance_impl = + cuvs::neighbors::ivf_sq::detail::finalize_distance_@metric_name@_impl; + +} // namespace + +namespace cuvs::neighbors::ivf_sq::detail { + +__device__ float finalize_distance(float dist, float v_norm_sq, float query_norm) +{ + return finalize_distance_impl(dist, v_norm_sq, query_norm); +} + +} // namespace cuvs::neighbors::ivf_sq::detail diff --git a/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/finalize_distance_matrix.json b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/finalize_distance_matrix.json new file mode 100644 index 0000000000..915ea1284a --- /dev/null +++ b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/finalize_distance_matrix.json @@ -0,0 +1,3 @@ +{ + "metric_name": ["l2", "ip", "cosine"] +} diff --git a/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_impl.cuh b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_impl.cuh index bc8f7587ee..38bea46eef 100644 --- a/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_impl.cuh +++ b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_impl.cuh @@ -43,6 +43,18 @@ using sq_block_sort_t = typename sq_block_sort::type; // IVF-SQ scan kernel body with fused in-kernel top-k. // +// The kernel is metric-agnostic: the four metric-specific pieces are +// linked in at runtime via JIT-LTO as fragments declared in +// device_functions.cuh: +// - setup_invariant_smem (Phase 1, once per query) +// - setup_per_probe_smem (Phase 2, once per probe) +// - accumulate_distance (per-element, inner unrolled loop) +// - finalize_distance (per-row, after the dim accumulation) +// +// The host launcher (ivf_sq_search.cuh) derives Ascending from the metric +// (Ascending = !is_ip) and registers the matching metric variant of each +// fragment with the planner. +// // Grid layout: // kManageLocalTopK (Capacity > 0): // grid (grid_dim_x, n_queries) - each block loops over probes @@ -67,7 +79,7 @@ using sq_block_sort_t = typename sq_block_sort::type; // Reconstructed vector component: s_aux[d] + code*s_sq_scale[d]. // // After all probes are scanned, the smem is reused for block_sort merge. -template +template __device__ __forceinline__ void ivf_sq_scan_impl(const uint8_t* const* data_ptrs, const uint32_t* list_sizes, const uint32_t* coarse_indices, @@ -93,10 +105,6 @@ __device__ __forceinline__ void ivf_sq_scan_impl(const uint8_t* const* data_ptrs constexpr int BlockDim = kSqScanThreads; constexpr bool kManageLocalTopK = (Capacity > 0); - constexpr bool kIsL2 = std::is_same_v; - constexpr bool kIsCosine = std::is_same_v; - constexpr bool kIsIP = std::is_same_v; - constexpr bool kAscending = !kIsIP; extern __shared__ __align__(256) uint8_t smem_buf[]; float* smem = reinterpret_cast(smem_buf); @@ -108,6 +116,11 @@ __device__ __forceinline__ void ivf_sq_scan_impl(const uint8_t* const* data_ptrs const uint32_t query_ix = blockIdx.y; const float* query = queries_float + query_ix * dim; + // Hoist the per-query scalar load to a uniform value. The cosine fragment of + // finalize_distance is the only consumer; for L2/IP, the unused argument is + // dead-code-eliminated after JIT-LTO inlining and this load disappears. + const float q_norm = (query_norms != nullptr) ? query_norms[query_ix] : 0.0f; + if constexpr (kManageLocalTopK) { out_distances += uint64_t(query_ix) * k * gridDim.x + blockIdx.x * k; out_indices += uint64_t(query_ix) * k * gridDim.x + blockIdx.x * k; @@ -116,15 +129,11 @@ __device__ __forceinline__ void ivf_sq_scan_impl(const uint8_t* const* data_ptrs // Phase 1: load shared memory that is invariant across probes. for (uint32_t d = threadIdx.x; d < dim; d += BlockDim) { s_sq_scale[d] = sq_delta[d]; - if constexpr (kIsL2) { - s_aux[d] = query[d] - sq_vmin[d]; - } else { - s_query_term[d] = query[d]; - } } + setup_invariant_smem(dim, query, sq_vmin, s_aux, s_query_term); __syncthreads(); - using local_topk_t = sq_block_sort_t; + using local_topk_t = sq_block_sort_t; local_topk_t queue(k); const uint32_t* my_coarse = coarse_indices + query_ix * n_probes; @@ -154,16 +163,7 @@ __device__ __forceinline__ void ivf_sq_scan_impl(const uint8_t* const* data_ptrs const uint32_t cluster_id = my_coarse[probe_ix]; const uint32_t cluster_sz = list_sizes[cluster_id]; - { - const float* centroid = centers + cluster_id * dim; - for (uint32_t d = threadIdx.x; d < dim; d += BlockDim) { - if constexpr (kIsL2) { - s_query_term[d] = s_aux[d] - centroid[d]; - } else { - s_aux[d] = centroid[d] + sq_vmin[d]; - } - } - } + setup_per_probe_smem(dim, centers + cluster_id * dim, sq_vmin, s_aux, s_query_term); __syncthreads(); // (b) if (cluster_sz == 0) { @@ -198,24 +198,17 @@ __device__ __forceinline__ void ivf_sq_scan_impl(const uint8_t* const* data_ptrs #pragma unroll for (uint32_t j = 0; j < veclen; j++) { if (l + j < dim) { - float recon = float(codes_local[j]) * s_sq_scale[l + j]; - - if constexpr (kIsL2) { - float diff = s_query_term[l + j] - recon; - dist += diff * diff; - } else { - float v_d = s_aux[l + j] + recon; - dist += s_query_term[l + j] * v_d; - if constexpr (kIsCosine) { v_norm_sq += v_d * v_d; } - } + accumulate_distance(s_query_term[l + j], + s_aux[l + j], + s_sq_scale[l + j], + codes_local[j], + dist, + v_norm_sq); } } } - if constexpr (kIsCosine) { - float denom = query_norms[query_ix] * sqrtf(v_norm_sq); - dist = (denom > 0.0f) ? 1.0f - dist / denom : 0.0f; - } + dist = finalize_distance(dist, v_norm_sq, q_norm); if constexpr (kManageLocalTopK) { float val = valid ? dist : local_topk_t::queue_t::kDummy; diff --git a/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_kernel.cu.in b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_kernel.cu.in index 3c59f91764..252ad98ed3 100644 --- a/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_kernel.cu.in +++ b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_kernel.cu.in @@ -8,8 +8,8 @@ namespace { -using metric_tag = cuvs::neighbors::ivf_sq::detail::tag_metric_@metric_name@; -constexpr int capacity = @capacity@; +constexpr int capacity = @capacity@; +constexpr bool ascending = @ascending_value@; } // namespace @@ -36,25 +36,25 @@ extern "C" __global__ __launch_bounds__(kSqScanThreads) void ivf_sq_scan( int64_t bitset_len, int64_t original_nbits) { - ivf_sq_scan_impl(data_ptrs, - list_sizes, - coarse_indices, - queries_float, - centers, - sq_vmin, - sq_delta, - query_norms, - n_probes, - dim, - k, - max_samples, - chunk_indices, - out_distances, - out_indices, - inds_ptrs, - bitset_ptr, - bitset_len, - original_nbits); + ivf_sq_scan_impl(data_ptrs, + list_sizes, + coarse_indices, + queries_float, + centers, + sq_vmin, + sq_delta, + query_norms, + n_probes, + dim, + k, + max_samples, + chunk_indices, + out_distances, + out_indices, + inds_ptrs, + bitset_ptr, + bitset_len, + original_nbits); } static_assert(std::is_same_v>); diff --git a/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_matrix.json b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_matrix.json index fac4154c18..79250fea1b 100644 --- a/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_matrix.json +++ b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_matrix.json @@ -2,7 +2,14 @@ "capacity": [ "0", "32", "64", "128", "256" ], - "metric_name": [ - "l2", "ip", "cosine" + "_ascending": [ + { + "ascending_descending": "ascending", + "ascending_value": "true" + }, + { + "ascending_descending": "descending", + "ascending_value": "false" + } ] } diff --git a/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_planner.hpp b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_planner.hpp index ae157e3994..da0fc1ce23 100644 --- a/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_planner.hpp +++ b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_planner.hpp @@ -18,10 +18,10 @@ struct IvfSqScanPlanner : AlgorithmPlanner { IvfSqScanPlanner() : AlgorithmPlanner("ivf_sq_scan", launcher_jit_cache) {} - template + template void add_entrypoint() { - this->add_static_fragment>(); + this->add_static_fragment>(); } template @@ -33,6 +33,30 @@ struct IvfSqScanPlanner : AlgorithmPlanner { cuvs::neighbors::detail::tag_index_i64, FilterTag>>(); } + + template + void add_setup_invariant_smem_function() + { + this->add_static_fragment>(); + } + + template + void add_setup_per_probe_smem_function() + { + this->add_static_fragment>(); + } + + template + void add_accumulate_distance_function() + { + this->add_static_fragment>(); + } + + template + void add_finalize_distance_function() + { + this->add_static_fragment>(); + } }; } // namespace cuvs::neighbors::ivf_sq::detail diff --git a/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/setup_invariant_smem_impl.cuh b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/setup_invariant_smem_impl.cuh new file mode 100644 index 0000000000..b0e2d5ea9e --- /dev/null +++ b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/setup_invariant_smem_impl.cuh @@ -0,0 +1,50 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include + +namespace cuvs::neighbors::ivf_sq::detail { + +// Phase 1 invariant smem setup. Called once per query. +// +// Loads the metric-invariant smem array. The metric-variant array is left +// untouched and will be filled per-probe by setup_per_probe_smem. + +__device__ void setup_invariant_smem_l2_impl(uint32_t dim, + const float* __restrict__ query, + const float* __restrict__ sq_vmin, + float* __restrict__ s_aux, + float* __restrict__ /* s_query_term */) +{ + for (uint32_t d = threadIdx.x; d < dim; d += blockDim.x) { + s_aux[d] = query[d] - sq_vmin[d]; + } +} + +__device__ void setup_invariant_smem_ip_impl(uint32_t dim, + const float* __restrict__ query, + const float* __restrict__ /* sq_vmin */, + float* __restrict__ /* s_aux */, + float* __restrict__ s_query_term) +{ + for (uint32_t d = threadIdx.x; d < dim; d += blockDim.x) { + s_query_term[d] = query[d]; + } +} + +__device__ void setup_invariant_smem_cosine_impl(uint32_t dim, + const float* __restrict__ query, + const float* __restrict__ /* sq_vmin */, + float* __restrict__ /* s_aux */, + float* __restrict__ s_query_term) +{ + for (uint32_t d = threadIdx.x; d < dim; d += blockDim.x) { + s_query_term[d] = query[d]; + } +} + +} // namespace cuvs::neighbors::ivf_sq::detail diff --git a/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/setup_invariant_smem_kernel.cu.in b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/setup_invariant_smem_kernel.cu.in new file mode 100644 index 0000000000..43a31f7fc4 --- /dev/null +++ b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/setup_invariant_smem_kernel.cu.in @@ -0,0 +1,27 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include + +namespace { + +constexpr auto setup_invariant_smem_impl = + cuvs::neighbors::ivf_sq::detail::setup_invariant_smem_@metric_name@_impl; + +} // namespace + +namespace cuvs::neighbors::ivf_sq::detail { + +__device__ void setup_invariant_smem(uint32_t dim, + const float* __restrict__ query, + const float* __restrict__ sq_vmin, + float* __restrict__ s_aux, + float* __restrict__ s_query_term) +{ + setup_invariant_smem_impl(dim, query, sq_vmin, s_aux, s_query_term); +} + +} // namespace cuvs::neighbors::ivf_sq::detail diff --git a/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/setup_invariant_smem_matrix.json b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/setup_invariant_smem_matrix.json new file mode 100644 index 0000000000..915ea1284a --- /dev/null +++ b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/setup_invariant_smem_matrix.json @@ -0,0 +1,3 @@ +{ + "metric_name": ["l2", "ip", "cosine"] +} diff --git a/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/setup_per_probe_smem_impl.cuh b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/setup_per_probe_smem_impl.cuh new file mode 100644 index 0000000000..6e0b19a1f9 --- /dev/null +++ b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/setup_per_probe_smem_impl.cuh @@ -0,0 +1,55 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include + +namespace cuvs::neighbors::ivf_sq::detail { + +// Phase 2 per-probe smem setup. Called once per (query, probe). +// +// Loads the per-probe smem array using the invariant array filled in Phase 1. +// +// L2: s_query_term[d] = s_aux[d] - centroid[d] (where s_aux holds query - sq_vmin) +// IP/Cosine: s_aux[d] = centroid[d] + sq_vmin[d] +// +// The "in" and "out" smem arrays do not alias — they are distinct regions of +// the kernel's smem layout — so __restrict__ is safe. + +__device__ void setup_per_probe_smem_l2_impl(uint32_t dim, + const float* __restrict__ centroid, + const float* __restrict__ /* sq_vmin */, + float* __restrict__ s_aux, + float* __restrict__ s_query_term) +{ + for (uint32_t d = threadIdx.x; d < dim; d += blockDim.x) { + s_query_term[d] = s_aux[d] - centroid[d]; + } +} + +__device__ void setup_per_probe_smem_ip_impl(uint32_t dim, + const float* __restrict__ centroid, + const float* __restrict__ sq_vmin, + float* __restrict__ s_aux, + float* __restrict__ /* s_query_term */) +{ + for (uint32_t d = threadIdx.x; d < dim; d += blockDim.x) { + s_aux[d] = centroid[d] + sq_vmin[d]; + } +} + +__device__ void setup_per_probe_smem_cosine_impl(uint32_t dim, + const float* __restrict__ centroid, + const float* __restrict__ sq_vmin, + float* __restrict__ s_aux, + float* __restrict__ /* s_query_term */) +{ + for (uint32_t d = threadIdx.x; d < dim; d += blockDim.x) { + s_aux[d] = centroid[d] + sq_vmin[d]; + } +} + +} // namespace cuvs::neighbors::ivf_sq::detail diff --git a/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/setup_per_probe_smem_kernel.cu.in b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/setup_per_probe_smem_kernel.cu.in new file mode 100644 index 0000000000..81d206f78d --- /dev/null +++ b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/setup_per_probe_smem_kernel.cu.in @@ -0,0 +1,27 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include + +namespace { + +constexpr auto setup_per_probe_smem_impl = + cuvs::neighbors::ivf_sq::detail::setup_per_probe_smem_@metric_name@_impl; + +} // namespace + +namespace cuvs::neighbors::ivf_sq::detail { + +__device__ void setup_per_probe_smem(uint32_t dim, + const float* __restrict__ centroid, + const float* __restrict__ sq_vmin, + float* __restrict__ s_aux, + float* __restrict__ s_query_term) +{ + setup_per_probe_smem_impl(dim, centroid, sq_vmin, s_aux, s_query_term); +} + +} // namespace cuvs::neighbors::ivf_sq::detail diff --git a/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/setup_per_probe_smem_matrix.json b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/setup_per_probe_smem_matrix.json new file mode 100644 index 0000000000..915ea1284a --- /dev/null +++ b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/setup_per_probe_smem_matrix.json @@ -0,0 +1,3 @@ +{ + "metric_name": ["l2", "ip", "cosine"] +} diff --git a/cpp/src/neighbors/ivf_sq/ivf_sq_search.cuh b/cpp/src/neighbors/ivf_sq/ivf_sq_search.cuh index bbe6f05df1..d4171a0fcc 100644 --- a/cpp/src/neighbors/ivf_sq/ivf_sq_search.cuh +++ b/cpp/src/neighbors/ivf_sq/ivf_sq_search.cuh @@ -155,8 +155,18 @@ void launch_kernel(const index& idx, constexpr int kThreads = kSqScanThreads; uint32_t dim = idx.dim(); + // The scan entrypoint is templated only on (Capacity, Ascending). Metric + // specialization is composed in at link time via four device-function + // fragments (setup_invariant_smem, setup_per_probe_smem, + // accumulate_distance, finalize_distance). + constexpr bool kAscending = !std::is_same_v; + IvfSqScanPlanner kernel_planner; - kernel_planner.add_entrypoint(); + kernel_planner.add_entrypoint(); + kernel_planner.add_setup_invariant_smem_function(); + kernel_planner.add_setup_per_probe_smem_function(); + kernel_planner.add_accumulate_distance_function(); + kernel_planner.add_finalize_distance_function(); kernel_planner.add_filter_device_function(); auto kernel_launcher = kernel_planner.get_launcher(); From df55c51b3cf74cb460d0e0d954c68c00269caebd Mon Sep 17 00:00:00 2001 From: vic Date: Tue, 5 May 2026 12:01:26 +0200 Subject: [PATCH 30/31] Inner product trick --- cpp/CMakeLists.txt | 4 +- .../detail/jit_lto/ivf_sq/scan_fragments.hpp | 8 ++-- .../finalize_distance_impl.cuh | 13 ++++-- .../detail/jit_lto_kernels/scan_impl.cuh | 33 +++++++++------ .../detail/jit_lto_kernels/scan_kernel.cu.in | 41 +++++++++---------- .../detail/jit_lto_kernels/scan_matrix.json | 10 ----- .../detail/jit_lto_kernels/scan_planner.hpp | 4 +- cpp/src/neighbors/ivf_sq/ivf_sq_search.cuh | 38 +++++++++-------- 8 files changed, 80 insertions(+), 71 deletions(-) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 65cec8837a..f501f5a39d 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -490,12 +490,12 @@ if(NOT BUILD_CPU_ONLY) set(ivf_sq_ns "cuvs::neighbors::ivf_sq::detail") generate_jit_lto_kernels( jit_lto_files - NAME_FORMAT "ivf_sq_scan_capacity_@capacity@_@ascending_descending@" + NAME_FORMAT "ivf_sq_scan_capacity_@capacity@" MATRIX_JSON_FILE "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_matrix.json" KERNEL_INPUT_FILE "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_kernel.cu.in" - FRAGMENT_TAG_FORMAT "${ivf_sq_ns}::fragment_tag_ivf_sq_scan<@capacity@, @ascending_value@>" + FRAGMENT_TAG_FORMAT "${ivf_sq_ns}::fragment_tag_ivf_sq_scan<@capacity@>" FRAGMENT_TAG_HEADER_FILES "" "" OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/ivf_sq/scan" diff --git a/cpp/include/cuvs/detail/jit_lto/ivf_sq/scan_fragments.hpp b/cpp/include/cuvs/detail/jit_lto/ivf_sq/scan_fragments.hpp index c505ed3f8f..d16d47d653 100644 --- a/cpp/include/cuvs/detail/jit_lto/ivf_sq/scan_fragments.hpp +++ b/cpp/include/cuvs/detail/jit_lto/ivf_sq/scan_fragments.hpp @@ -11,9 +11,11 @@ struct tag_metric_l2 {}; struct tag_metric_ip {}; struct tag_metric_cosine {}; -// Scan entrypoint fragment. Templated only on (Capacity, Ascending). The -// metric specialization lives in the four device-function fragments below. -template +// Scan entrypoint fragment. Templated only on Capacity. The warpsort runs +// ascending for all metrics; metric specialization lives in the four +// device-function fragments below (finalize_distance is responsible for +// returning a min-close value). +template struct fragment_tag_ivf_sq_scan {}; template diff --git a/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/finalize_distance_impl.cuh b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/finalize_distance_impl.cuh index 2d31ff863a..4276f80dd7 100644 --- a/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/finalize_distance_impl.cuh +++ b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/finalize_distance_impl.cuh @@ -13,9 +13,14 @@ namespace cuvs::neighbors::ivf_sq::detail { // Per-row distance finalize. Called once per scanned row. // -// L2/IP: return dist -// Cosine: denom = query_norm * sqrtf(v_norm_sq); -// return denom > 0 ? 1 - dist/denom : 0 +// All metrics return a min-close value so the warpsort epilogue and the +// host-side select_k can be hardcoded to ascending / select-min: +// L2: return dist (squared L2 is min-close) +// IP: return -dist (negate so min-close; +// postprocess_distances +// undoes the sign) +// Cosine: denom = query_norm * sqrtf(v_norm_sq); +// return denom > 0 ? 1 - dist/denom : 0 (cosine distance is min-close) __device__ float finalize_distance_l2_impl(float dist, float /* v_norm_sq */, @@ -28,7 +33,7 @@ __device__ float finalize_distance_ip_impl(float dist, float /* v_norm_sq */, float /* query_norm */) { - return dist; + return -dist; } __device__ float finalize_distance_cosine_impl(float dist, float v_norm_sq, float query_norm) diff --git a/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_impl.cuh b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_impl.cuh index 38bea46eef..85cb9605be 100644 --- a/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_impl.cuh +++ b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_impl.cuh @@ -23,23 +23,27 @@ namespace cuvs::neighbors::ivf_sq::detail { // block_sort type selection: dispatch the dummy block sort when Capacity == 0 // so the same impl body works for both the fused top-k path (Capacity > 0) and // the materialize-all path (Capacity == 0). -template +// +// All metrics are min-close after finalize_distance (IP is negated; cosine +// returns 1 - cos_sim; L2 is squared L2), so the warpsort is hardcoded to +// ascending. +template struct sq_block_sort { using type = raft::matrix::detail::select::warpsort::block_sort< raft::matrix::detail::select::warpsort::warp_sort_filtered, Capacity, - Ascending, + /*Ascending=*/true, float, uint32_t>; }; -template -struct sq_block_sort<0, Ascending> { - using type = ivf::detail::dummy_block_sort_t; +template <> +struct sq_block_sort<0> { + using type = ivf::detail::dummy_block_sort_t; }; -template -using sq_block_sort_t = typename sq_block_sort::type; +template +using sq_block_sort_t = typename sq_block_sort::type; // IVF-SQ scan kernel body with fused in-kernel top-k. // @@ -49,11 +53,14 @@ using sq_block_sort_t = typename sq_block_sort::type; // - setup_invariant_smem (Phase 1, once per query) // - setup_per_probe_smem (Phase 2, once per probe) // - accumulate_distance (per-element, inner unrolled loop) -// - finalize_distance (per-row, after the dim accumulation) +// - finalize_distance (per-row, after the dim accumulation; ensures +// all metrics are min-close so the warpsort +// epilogue and the host select_k can be +// hardcoded to ascending / select-min) // -// The host launcher (ivf_sq_search.cuh) derives Ascending from the metric -// (Ascending = !is_ip) and registers the matching metric variant of each -// fragment with the planner. +// The host launcher (ivf_sq_search.cuh) registers the matching metric +// variant of each fragment with the planner. IP scores are negated by +// finalize_distance and undone by postprocess_distances. // // Grid layout: // kManageLocalTopK (Capacity > 0): @@ -79,7 +86,7 @@ using sq_block_sort_t = typename sq_block_sort::type; // Reconstructed vector component: s_aux[d] + code*s_sq_scale[d]. // // After all probes are scanned, the smem is reused for block_sort merge. -template +template __device__ __forceinline__ void ivf_sq_scan_impl(const uint8_t* const* data_ptrs, const uint32_t* list_sizes, const uint32_t* coarse_indices, @@ -133,7 +140,7 @@ __device__ __forceinline__ void ivf_sq_scan_impl(const uint8_t* const* data_ptrs setup_invariant_smem(dim, query, sq_vmin, s_aux, s_query_term); __syncthreads(); - using local_topk_t = sq_block_sort_t; + using local_topk_t = sq_block_sort_t; local_topk_t queue(k); const uint32_t* my_coarse = coarse_indices + query_ix * n_probes; diff --git a/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_kernel.cu.in b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_kernel.cu.in index 252ad98ed3..d48c4d259f 100644 --- a/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_kernel.cu.in +++ b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_kernel.cu.in @@ -8,8 +8,7 @@ namespace { -constexpr int capacity = @capacity@; -constexpr bool ascending = @ascending_value@; +constexpr int capacity = @capacity@; } // namespace @@ -36,25 +35,25 @@ extern "C" __global__ __launch_bounds__(kSqScanThreads) void ivf_sq_scan( int64_t bitset_len, int64_t original_nbits) { - ivf_sq_scan_impl(data_ptrs, - list_sizes, - coarse_indices, - queries_float, - centers, - sq_vmin, - sq_delta, - query_norms, - n_probes, - dim, - k, - max_samples, - chunk_indices, - out_distances, - out_indices, - inds_ptrs, - bitset_ptr, - bitset_len, - original_nbits); + ivf_sq_scan_impl(data_ptrs, + list_sizes, + coarse_indices, + queries_float, + centers, + sq_vmin, + sq_delta, + query_norms, + n_probes, + dim, + k, + max_samples, + chunk_indices, + out_distances, + out_indices, + inds_ptrs, + bitset_ptr, + bitset_len, + original_nbits); } static_assert(std::is_same_v>); diff --git a/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_matrix.json b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_matrix.json index 79250fea1b..4af912b122 100644 --- a/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_matrix.json +++ b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_matrix.json @@ -1,15 +1,5 @@ { "capacity": [ "0", "32", "64", "128", "256" - ], - "_ascending": [ - { - "ascending_descending": "ascending", - "ascending_value": "true" - }, - { - "ascending_descending": "descending", - "ascending_value": "false" - } ] } diff --git a/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_planner.hpp b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_planner.hpp index da0fc1ce23..05ea34532e 100644 --- a/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_planner.hpp +++ b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_planner.hpp @@ -18,10 +18,10 @@ struct IvfSqScanPlanner : AlgorithmPlanner { IvfSqScanPlanner() : AlgorithmPlanner("ivf_sq_scan", launcher_jit_cache) {} - template + template void add_entrypoint() { - this->add_static_fragment>(); + this->add_static_fragment>(); } template diff --git a/cpp/src/neighbors/ivf_sq/ivf_sq_search.cuh b/cpp/src/neighbors/ivf_sq/ivf_sq_search.cuh index d4171a0fcc..86a3ee8a1f 100644 --- a/cpp/src/neighbors/ivf_sq/ivf_sq_search.cuh +++ b/cpp/src/neighbors/ivf_sq/ivf_sq_search.cuh @@ -155,14 +155,12 @@ void launch_kernel(const index& idx, constexpr int kThreads = kSqScanThreads; uint32_t dim = idx.dim(); - // The scan entrypoint is templated only on (Capacity, Ascending). Metric - // specialization is composed in at link time via four device-function - // fragments (setup_invariant_smem, setup_per_probe_smem, - // accumulate_distance, finalize_distance). - constexpr bool kAscending = !std::is_same_v; - + // The scan entrypoint is templated only on Capacity. The warpsort runs + // ascending for all metrics; metric specialization is composed in at link + // time via four device-function fragments (setup_invariant_smem, + // setup_per_probe_smem, accumulate_distance, finalize_distance). IvfSqScanPlanner kernel_planner; - kernel_planner.add_entrypoint(); + kernel_planner.add_entrypoint(); kernel_planner.add_setup_invariant_smem_function(); kernel_planner.add_setup_per_probe_smem_function(); kernel_planner.add_accumulate_distance_function(); @@ -348,7 +346,6 @@ void search_impl(raft::resources const& handle, uint32_t n_queries, uint32_t k, uint32_t n_probes, - bool select_min, int64_t* neighbors, float* distances, rmm::device_async_resource_ref search_mr, @@ -357,6 +354,13 @@ void search_impl(raft::resources const& handle, auto stream = raft::resource::get_cuda_stream(handle); auto dim = index.dim(); + // The scan kernel emits min-close distances for every metric (IP scores are + // negated by finalize_distance and undone by postprocess_distances below), + // so all scan-related top-k selections use select_min = true. Only the + // coarse search, which runs on the raw GEMM output, follows the metric's + // native direction. + const bool coarse_select_min = cuvs::distance::is_min_close(index.metric()); + std::size_t n_queries_probes = std::size_t(n_queries) * std::size_t(n_probes); bool needs_query_norms = index.metric() != cuvs::distance::DistanceType::InnerProduct; @@ -467,7 +471,7 @@ void search_impl(raft::resources const& handle, raft::make_device_matrix_view(coarse_distances_dev.data(), n_queries, n_probes), raft::make_device_matrix_view( coarse_indices_dev.data(), n_queries, n_probes), - select_min); + coarse_select_min); rmm::device_uvector num_samples(n_queries, stream, search_mr); rmm::device_uvector chunk_index(n_queries_probes, stream, search_mr); @@ -563,7 +567,7 @@ void search_impl(raft::resources const& handle, raft::make_device_matrix_view(indices_tmp.data(), n_queries, cols), raft::make_device_matrix_view(distances, n_queries, k), raft::make_device_matrix_view(neighbors_uint32_ptr, n_queries, k), - select_min); + /*select_min=*/true); } } else { // --- Fallback: materialize all distances --- @@ -577,12 +581,10 @@ void search_impl(raft::resources const& handle, rmm::device_uvector all_indices( std::size_t(n_queries) * max_samples, stream, search_mr); - float init_val = - select_min ? std::numeric_limits::max() : std::numeric_limits::lowest(); thrust::fill_n(raft::resource::get_thrust_policy(handle), all_distances.data(), std::size_t(n_queries) * max_samples, - init_val); + std::numeric_limits::max()); thrust::fill_n(raft::resource::get_thrust_policy(handle), all_indices.data(), std::size_t(n_queries) * max_samples, @@ -618,14 +620,19 @@ void search_impl(raft::resources const& handle, all_indices.data(), n_queries, max_samples), raft::make_device_matrix_view(distances, n_queries, k), raft::make_device_matrix_view(neighbors_uint32_ptr, n_queries, k), - select_min, + /*select_min=*/true, false, cuvs::selection::SelectAlgo::kAuto, num_samples_view); } + // For IP, the kernel returned negated scores so the warpsort could run + // ascending; postprocess_distances multiplies by -1 to restore the public + // contract (positive inner products). Cosine distance is already min-close + // so it must not be flipped (account_for_max_close = false for cosine). + const bool account_for_max_close = index.metric() != cuvs::distance::DistanceType::CosineExpanded; ivf::detail::postprocess_distances( - handle, distances, distances, index.metric(), n_queries, k, 1.0, false); + handle, distances, distances, index.metric(), n_queries, k, 1.0, account_for_max_close); ivf::detail::postprocess_neighbors(neighbors, neighbors_uint32_ptr, @@ -707,7 +714,6 @@ inline void search_with_filtering(raft::resources const& handle, queries_batch, k, n_probes, - cuvs::distance::is_min_close(index.metric()), neighbors + std::size_t(offset_q) * k, distances + std::size_t(offset_q) * k, raft::resource::get_workspace_resource_ref(handle), From c5948a2755d5b0b88c64549ce47547b3bd2db518 Mon Sep 17 00:00:00 2001 From: vic Date: Tue, 5 May 2026 12:59:36 +0200 Subject: [PATCH 31/31] Fix + minor cleanups --- cpp/src/neighbors/ivf_sq/ivf_sq_search.cuh | 64 ++++++++++++---------- 1 file changed, 36 insertions(+), 28 deletions(-) diff --git a/cpp/src/neighbors/ivf_sq/ivf_sq_search.cuh b/cpp/src/neighbors/ivf_sq/ivf_sq_search.cuh index 86a3ee8a1f..6bc64c1569 100644 --- a/cpp/src/neighbors/ivf_sq/ivf_sq_search.cuh +++ b/cpp/src/neighbors/ivf_sq/ivf_sq_search.cuh @@ -365,6 +365,9 @@ void search_impl(raft::resources const& handle, bool needs_query_norms = index.metric() != cuvs::distance::DistanceType::InnerProduct; rmm::device_uvector query_norm_dev(needs_query_norms ? n_queries : 0, stream, search_mr); + // Pass nullptr to the kernel for IP, where the empty uvector's data() pointer + // is implementation-defined and the kernel's `query_norms` argument is unused. + const float* qn_ptr = needs_query_norms ? query_norm_dev.data() : nullptr; rmm::device_uvector distance_buffer_dev(n_queries * index.n_lists(), stream, search_mr); rmm::device_uvector coarse_distances_dev(n_queries_probes, stream, search_mr); rmm::device_uvector coarse_indices_dev(n_queries_probes, stream, search_mr); @@ -491,7 +494,7 @@ void search_impl(raft::resources const& handle, ivf_sq_scan(handle, index, converted_queries_ptr, - query_norm_dev.data(), + qn_ptr, n_queries, n_probes, k, @@ -513,15 +516,10 @@ void search_impl(raft::resources const& handle, } } - // Prepare uint32 neighbors buffer for postprocessing - rmm::device_uvector neighbors_uint32(0, stream, search_mr); - uint32_t* neighbors_uint32_ptr = nullptr; - if constexpr (sizeof(int64_t) == sizeof(uint32_t)) { - neighbors_uint32_ptr = reinterpret_cast(neighbors); - } else { - neighbors_uint32.resize(std::size_t(n_queries) * k, stream); - neighbors_uint32_ptr = neighbors_uint32.data(); - } + // The scan kernel writes sample-local uint32 indices; postprocess_neighbors + // lifts them to int64_t database indices via inds_ptrs. + rmm::device_uvector neighbors_uint32(std::size_t(n_queries) * k, stream, search_mr); + uint32_t* neighbors_uint32_ptr = neighbors_uint32.data(); if (manage_local_topk) { // --- Fused top-k path --- @@ -545,7 +543,7 @@ void search_impl(raft::resources const& handle, ivf_sq_scan(handle, index, converted_queries_ptr, - query_norm_dev.data(), + qn_ptr, n_queries, n_probes, k, @@ -596,7 +594,7 @@ void search_impl(raft::resources const& handle, ivf_sq_scan(handle, index, converted_queries_ptr, - query_norm_dev.data(), + qn_ptr, n_queries, n_probes, k, @@ -674,32 +672,42 @@ inline void search_with_filtering(raft::resources const& handle, bool manage_local_topk = is_local_topk_feasible(k); - uint32_t max_samples = 0; - if (!manage_local_topk) { - int64_t ms = std::max(index.accum_sorted_sizes()(n_probes), k); - RAFT_EXPECTS(ms <= int64_t(std::numeric_limits::max()), - "The maximum sample size is too big."); - max_samples = static_cast(ms); - } + // Always compute max_samples: even on the fused path, search_impl may fall + // back to the materialize-all branch when the scan kernel reports zero + // occupancy, in which case the workspace must already be sized for the + // (much larger) materialize-path requirement. + int64_t ms = std::max(index.accum_sorted_sizes()(n_probes), k); + RAFT_EXPECTS(ms <= int64_t(std::numeric_limits::max()), + "The maximum sample size is too big."); + uint32_t max_samples = static_cast(ms); constexpr uint64_t kExpectedWsSize = 1024ull * 1024 * 1024; uint64_t max_ws_size = std::min(raft::resource::get_workspace_free_bytes(handle), kExpectedWsSize); uint64_t converted_query_floats = std::is_same_v ? 0 : index.dim(); + + // Materialize-path estimate: used directly for the non-fused path, and as + // a conservative upper bound for the fused path so the fallback branch in + // search_impl never overshoots the workspace. + uint64_t ws_materialize = sizeof(float) * (uint64_t(index.n_lists()) + n_probes + 1 + + max_samples + converted_query_floats) + + sizeof(uint32_t) * (uint64_t(n_probes) * 2 + 1 + max_samples + k); + uint64_t ws_per_query; if (manage_local_topk) { - // Fused path: only small per-query buffers for coarse search + chunk indices - // (The scan output is at most grid_dim_x * k per query, which is small) - // Conservatively assume grid_dim_x <= n_probes for the workspace estimate + // Fused path: only small per-query buffers for coarse search + chunk indices. + // The scan output is at most grid_dim_x * k per query (we conservatively + // assume grid_dim_x <= n_probes for the workspace estimate). uint64_t fused_out = uint64_t(n_probes) * k; - ws_per_query = sizeof(float) * (uint64_t(index.n_lists()) + n_probes + 1 + fused_out + - converted_query_floats) + - sizeof(uint32_t) * (uint64_t(n_probes) * 2 + 1 + fused_out + k); + uint64_t ws_fused = sizeof(float) * (uint64_t(index.n_lists()) + n_probes + 1 + fused_out + + converted_query_floats) + + sizeof(uint32_t) * (uint64_t(n_probes) * 2 + 1 + fused_out + k); + // Take the conservative max so search_impl can safely fall back to the + // materialize path if the fused-path occupancy probe returns zero. + ws_per_query = std::max(ws_fused, ws_materialize); } else { - ws_per_query = sizeof(float) * (uint64_t(index.n_lists()) + n_probes + 1 + max_samples + - converted_query_floats) + - sizeof(uint32_t) * (uint64_t(n_probes) * 2 + 1 + max_samples + k); + ws_per_query = ws_materialize; } const uint32_t max_queries =