diff --git a/c/src/preprocessing/quantize/pq.cpp b/c/src/preprocessing/quantize/pq.cpp index 1e3a48694a..c049d7e3b5 100644 --- a/c/src/preprocessing/quantize/pq.cpp +++ b/c/src/preprocessing/quantize/pq.cpp @@ -244,7 +244,7 @@ extern "C" cuvsError_t cuvsProductQuantizerGetPqCodebook(cuvsProductQuantizer_t if (quantizer->dtype.code == kDLFloat && quantizer->dtype.bits == 32) { auto pq_mdspan = (reinterpret_cast*>(quant_addr)) - ->vpq_codebooks.pq_code_book.view(); + ->codebooks.pq_code_book(); cuvs::core::to_dlpack(pq_mdspan, pq_codebook); } else { RAFT_FAIL("Unsupported quantizer dtype: %d and bits: %d", @@ -264,10 +264,12 @@ extern "C" cuvsError_t cuvsProductQuantizerGetVqCodebook(cuvsProductQuantizer_t if (quantizer != nullptr) { auto quant_addr = quantizer->addr; if (quantizer->dtype.code == kDLFloat && quantizer->dtype.bits == 32) { - auto pq_mdspan = + auto vq_opt = (reinterpret_cast*>(quant_addr)) - ->vpq_codebooks.vq_code_book.view(); - cuvs::core::to_dlpack(pq_mdspan, vq_codebook); + ->codebooks.vq_code_book(); + RAFT_EXPECTS(vq_opt.has_value(), + "quantizer has no VQ codebook (build with use_vq=true to enable)"); + cuvs::core::to_dlpack(vq_opt.value(), vq_codebook); } else { RAFT_FAIL("Unsupported quantizer dtype: %d and bits: %d", quantizer->dtype.code, diff --git a/cpp/bench/ann/CMakeLists.txt b/cpp/bench/ann/CMakeLists.txt index 4e4527267c..6b5d2ebe9c 100644 --- a/cpp/bench/ann/CMakeLists.txt +++ b/cpp/bench/ann/CMakeLists.txt @@ -157,7 +157,7 @@ function(ConfigureAnnBench) add_dependencies(${BENCH_NAME} ANN_BENCH) else() add_executable(${BENCH_NAME} ${ConfigureAnnBench_PATH}) - target_compile_definitions(${BENCH_NAME} PRIVATE ANN_BENCH_BUILD_MAIN>) + target_compile_definitions(${BENCH_NAME} PRIVATE ANN_BENCH_BUILD_MAIN) target_link_libraries( ${BENCH_NAME} PRIVATE benchmark::benchmark $<$:CUDA::nvtx3> ) diff --git a/cpp/bench/ann/src/cuvs/cuvs_cagra_wrapper.h b/cpp/bench/ann/src/cuvs/cuvs_cagra_wrapper.h index 98dd94c2e1..dbb5c840d8 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_cagra_wrapper.h +++ b/cpp/bench/ann/src/cuvs/cuvs_cagra_wrapper.h @@ -17,6 +17,8 @@ #include #include #include +#include +#include #include #include #include @@ -357,8 +359,10 @@ void cuvs_cagra::set_search_dataset(const T* dataset, size_t nrow) } else { using ds_idx_type = decltype(index_->data().n_rows()); bool is_vpq = - dynamic_cast*>(&index_->data()) || - dynamic_cast*>(&index_->data()); + dynamic_cast*>( + &index_->data()) || + dynamic_cast*>( + &index_->data()); // It can happen that we are re-using a previous algo object which already has // the dataset set. Check if we need update. if (static_cast(input_dataset_v_->extent(0)) != nrow || @@ -385,8 +389,10 @@ void cuvs_cagra::save(const std::string& file) const } else { using ds_idx_type = decltype(index_->data().n_rows()); bool is_vpq = - dynamic_cast*>(&index_->data()) || - dynamic_cast*>(&index_->data()); + dynamic_cast*>( + &index_->data()) || + dynamic_cast*>( + &index_->data()); cuvs::neighbors::cagra::serialize(handle_, file, *index_, is_vpq); } } diff --git a/cpp/bench/ann/src/diskann/diskann_benchmark.cpp b/cpp/bench/ann/src/diskann/diskann_benchmark.cpp index 4129f3cdab..cc437389cb 100644 --- a/cpp/bench/ann/src/diskann/diskann_benchmark.cpp +++ b/cpp/bench/ann/src/diskann/diskann_benchmark.cpp @@ -26,7 +26,7 @@ void parse_build_param(const nlohmann::json& conf, { param.R = conf.at("R"); param.L_build = conf.at("L_build"); - if (conf.contains("alpha")) { param.num_threads = conf.at("alpha"); } + if (conf.contains("alpha")) { param.alpha = conf.at("alpha"); } if (conf.contains("num_threads")) { param.num_threads = conf.at("num_threads"); } } diff --git a/cpp/include/cuvs/neighbors/common.hpp b/cpp/include/cuvs/neighbors/common.hpp index c7111aaf4a..da81570218 100644 --- a/cpp/include/cuvs/neighbors/common.hpp +++ b/cpp/include/cuvs/neighbors/common.hpp @@ -391,100 +391,6 @@ auto make_aligned_dataset(const raft::resources& res, SrcT src, uint32_t align_b raft::round_up_safe(src.extent(1) * kSize, std::lcm(align_bytes, kSize)) / kSize; return make_strided_dataset(res, std::forward(src), required_stride); } -/** - * @brief VPQ compressed dataset. - * - * The dataset is compressed using two level quantization - * - * 1. Vector Quantization - * 2. Product Quantization of residuals - * - * @tparam MathT the type of elements in the codebooks - * @tparam IdxT type of the vector indices (represent dataset.extent(0)) - * - */ -template -struct vpq_dataset : public dataset { - using index_type = IdxT; - using math_type = MathT; - /** Vector Quantization codebook - "coarse cluster centers". */ - raft::device_matrix vq_code_book; - /** Product Quantization codebook - "fine cluster centers". */ - raft::device_matrix pq_code_book; - /** Compressed dataset. */ - raft::device_matrix data; - - vpq_dataset(raft::device_matrix&& vq_code_book, - raft::device_matrix&& pq_code_book, - raft::device_matrix&& data) - : vq_code_book{std::move(vq_code_book)}, - pq_code_book{std::move(pq_code_book)}, - data{std::move(data)} - { - } - - [[nodiscard]] auto n_rows() const noexcept -> index_type final { return data.extent(0); } - [[nodiscard]] auto dim() const noexcept -> uint32_t final { return vq_code_book.extent(1); } - [[nodiscard]] auto is_owning() const noexcept -> bool final { return true; } - - /** Row length of the encoded data in bytes. */ - [[nodiscard]] constexpr inline auto encoded_row_length() const noexcept -> uint32_t - { - return data.extent(1); - } - /** The number of "coarse cluster centers" */ - [[nodiscard]] constexpr inline auto vq_n_centers() const noexcept -> uint32_t - { - return vq_code_book.extent(0); - } - /** The bit length of an encoded vector element after compression by PQ. */ - [[nodiscard]] constexpr inline auto pq_bits() const noexcept -> uint32_t - { - /* - NOTE: pq_bits and the book size - - Normally, we'd store `pq_bits` as a part of the index. - However, we know there's an invariant `pq_n_centers = 1 << pq_bits`, i.e. the codebook size is - the same as the number of possible code values. Hence, we don't store the pq_bits and derive it - from the array dimensions instead. - */ - auto pq_width = pq_n_centers(); -#ifdef __cpp_lib_bitops - return std::countr_zero(pq_width); -#else - uint32_t pq_bits = 0; - while (pq_width > 1) { - pq_bits++; - pq_width >>= 1; - } - return pq_bits; -#endif - } - /** The dimensionality of an encoded vector after compression by PQ. */ - [[nodiscard]] constexpr inline auto pq_dim() const noexcept -> uint32_t - { - return raft::div_rounding_up_unsafe(dim(), pq_len()); - } - /** Dimensionality of a subspaces, i.e. the number of vector components mapped to a subspace */ - [[nodiscard]] constexpr inline auto pq_len() const noexcept -> uint32_t - { - return pq_code_book.extent(1); - } - /** The number of vectors in a PQ codebook (`1 << pq_bits`). */ - [[nodiscard]] constexpr inline auto pq_n_centers() const noexcept -> uint32_t - { - return pq_code_book.extent(0); - } -}; - -template -struct is_vpq_dataset : std::false_type {}; - -template -struct is_vpq_dataset> : std::true_type {}; - -template -inline constexpr bool is_vpq_dataset_v = is_vpq_dataset::value; namespace filtering { diff --git a/cpp/include/cuvs/preprocessing/quantize/pq.hpp b/cpp/include/cuvs/preprocessing/quantize/pq.hpp index 104e43dc72..4fe4c9bd14 100644 --- a/cpp/include/cuvs/preprocessing/quantize/pq.hpp +++ b/cpp/include/cuvs/preprocessing/quantize/pq.hpp @@ -7,6 +7,7 @@ #include #include +#include #include #include #include @@ -135,19 +136,21 @@ struct params { /** * @brief Defines and stores VPQ codebooks upon training * - * @tparam T data element type + * The quantizer holds a vpq_dataset, which can either own the codebooks + * or non-owning (referencing external codebooks). * + * @tparam T data element type */ template struct quantizer { /** Parameters used to build this quantizer. */ params params_quantizer; - /** VPQ codebooks produced during training. */ - cuvs::neighbors::vpq_dataset vpq_codebooks; + /** VPQ codebooks (owning or view). */ + cuvs::preprocessing::quantize::pq::vpq_codebooks codebooks; }; /** - * @brief Initializes a product quantizer to be used later for quantizing the dataset. + * @brief Initializes a product quantizer by training on the dataset (owning). * * The use of a pool memory resource is recommended for more consistent training performance. * @@ -161,7 +164,7 @@ struct quantizer { * @endcode * * @param[in] res raft resource - * @param[in] params configure product quantizer, e.g. quantile + * @param[in] params configure product quantizer, e.g. pq_bits, pq_dim * @param[in] dataset a row-major matrix view on device or host * * @return quantizer @@ -175,6 +178,48 @@ quantizer build(raft::resources const& res, const params params, raft::host_matrix_view dataset); +/** + * @brief Creates a product quantizer from pre-computed codebooks. + * + * This function creates a non-owning quantizer that references the provided codebooks. + * + * Usage example: + * @code{.cpp} + * raft::handle_t handle; + * // Assume pq_centers and vq_centers are pre-computed on device + * cuvs::preprocessing::quantize::pq::params params; + * params.pq_bits = 8; + * params.pq_dim = 32; + * params.use_vq = true; + * params.use_subspaces = true; + * // With VQ centers: + * auto quant_view = cuvs::preprocessing::quantize::pq::build(handle, params, + * pq_centers_view, + * std::make_optional>(vq_centers_view)); + * // Without VQ (PQ only): + * auto quant_pq_only = cuvs::preprocessing::quantize::pq::build(handle, params, pq_centers_view); + * @endcode + * + * @param[in] res raft resource + * @param[in] params configure product quantizer parameters. Must be fully specified + * (pq_bits, pq_dim must be set; use_subspaces and use_vq must match the codebook shapes). + * @param[in] pq_centers PQ codebook on device memory: + * - For use_subspaces=true: [pq_dim * pq_n_centers, pq_len] + * - For use_subspaces=false: [pq_n_centers, pq_len] + * where pq_n_centers = (1 << pq_bits), pq_len = dim / pq_dim + * @param[in] vq_centers Optional VQ codebook on device memory [vq_n_centers, dim]. + * Required when use_vq=true. Defaults to std::nullopt (no VQ). + * + * @return A view-type quantizer that references the provided data + */ +quantizer build( + raft::resources const& res, + const params params, + raft::device_matrix_view pq_centers, + std::optional> vq_centers = + std::nullopt); + /** * @brief Applies quantization transform to given dataset * diff --git a/cpp/include/cuvs/preprocessing/quantize/vpq_dataset.hpp b/cpp/include/cuvs/preprocessing/quantize/vpq_dataset.hpp new file mode 100644 index 0000000000..52f1c7d958 --- /dev/null +++ b/cpp/include/cuvs/preprocessing/quantize/vpq_dataset.hpp @@ -0,0 +1,208 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include +#include + +#include + +namespace cuvs::preprocessing::quantize::pq { + +/** + * @brief Abstract interface for VPQ codebook access. + * + * @tparam MathT the type of elements in the codebooks + */ +template +class vpq_codebooks_iface { + public: + using math_type = MathT; + + virtual ~vpq_codebooks_iface() = default; + + /** + * VQ codebook [vq_n_centers, dim]. + * + * Returns std::nullopt when no VQ codebook is configured (i.e. PQ-only, + * use_vq=false). Callers that need to forward a `device_matrix_view` + * downstream should materialize an empty 0x0 view themselves on nullopt. + */ + [[nodiscard]] virtual auto vq_code_book() const noexcept + -> std::optional> = 0; + + /** PQ codebook [pq_n_centers (× pq_dim for subspaces), pq_len]. */ + [[nodiscard]] virtual auto pq_code_book() const noexcept + -> raft::device_matrix_view = 0; + + [[nodiscard]] virtual auto dim() const noexcept -> uint32_t + { + auto vq = vq_code_book(); + if (!vq.has_value()) { + RAFT_LOG_WARN( + "vpq_codebooks_iface::dim() returns 0 when no VQ codebook is present; " + "the original vector dimension cannot be recovered from PQ codebooks alone."); + return 0; + } + return vq->extent(1); + } + [[nodiscard]] virtual auto vq_n_centers() const noexcept -> uint32_t + { + auto vq = vq_code_book(); + return vq.has_value() ? vq->extent(0) : 0; + } + [[nodiscard]] virtual auto pq_len() const noexcept -> uint32_t + { + return pq_code_book().extent(1); + } + [[nodiscard]] virtual auto pq_n_centers() const noexcept -> uint32_t + { + return pq_code_book().extent(0); + } + [[nodiscard]] virtual auto pq_bits() const noexcept -> uint32_t + { + auto w = pq_n_centers(); + uint32_t bits = 0; + while (w > 1) { + bits++; + w >>= 1; + } + return bits; + } + [[nodiscard]] virtual auto pq_dim() const noexcept -> uint32_t + { + auto l = pq_len(); + return l > 0 ? (dim() + l - 1) / l : 0; + } +}; + +/** + * @addtogroup pq + * @{ + */ + +/** + * @brief PIMPL wrapper for VQ + PQ codebooks. + * + * Internally delegates to either an owning implementation (holds device + * matrices) or a view implementation (references external device memory). + * + * @tparam MathT the type of elements in the codebooks + */ +template +class vpq_codebooks { + public: + using math_type = MathT; + + vpq_codebooks() = default; + + explicit vpq_codebooks(std::unique_ptr> impl) : impl_{std::move(impl)} + { + } + + vpq_codebooks(const vpq_codebooks&) = delete; + vpq_codebooks& operator=(const vpq_codebooks&) = delete; + vpq_codebooks(vpq_codebooks&&) = default; + vpq_codebooks& operator=(vpq_codebooks&&) = default; + ~vpq_codebooks() = default; + + /** + * VQ codebook view, or std::nullopt when no VQ codebook is configured. + */ + [[nodiscard]] auto vq_code_book() const noexcept + -> std::optional> + { + return impl_->vq_code_book(); + } + + [[nodiscard]] auto pq_code_book() const noexcept + -> raft::device_matrix_view + { + return impl_->pq_code_book(); + } + + [[nodiscard]] auto dim() const noexcept -> uint32_t { return impl_->dim(); } + [[nodiscard]] auto vq_n_centers() const noexcept -> uint32_t { return impl_->vq_n_centers(); } + [[nodiscard]] auto pq_bits() const noexcept -> uint32_t { return impl_->pq_bits(); } + [[nodiscard]] auto pq_dim() const noexcept -> uint32_t { return impl_->pq_dim(); } + [[nodiscard]] auto pq_len() const noexcept -> uint32_t { return impl_->pq_len(); } + [[nodiscard]] auto pq_n_centers() const noexcept -> uint32_t { return impl_->pq_n_centers(); } + + private: + std::unique_ptr> impl_; +}; + +/** + * @brief VPQ compressed dataset. + * + * Holds a set of VQ + PQ codebooks together with the encoded dataset. + * Both the codebooks and the encoded data are always owned by this object. + * + * @tparam MathT the type of elements in the codebooks + * @tparam IdxT type of the vector indices (represent dataset.extent(0)) + */ +template +class vpq_dataset : public cuvs::neighbors::dataset { + public: + using index_type = IdxT; + using math_type = MathT; + + /** VQ + PQ codebooks (owning or view). */ + vpq_codebooks codebooks; + /** Encoded (compressed) data [n_rows, encoded_row_length]. */ + raft::device_matrix data; + + vpq_dataset() = default; + + vpq_dataset(vpq_codebooks&& codebooks_in, + raft::device_matrix&& data_in) + : codebooks{std::move(codebooks_in)}, data{std::move(data_in)} + { + } + + vpq_dataset(const vpq_dataset&) = delete; + vpq_dataset& operator=(const vpq_dataset&) = delete; + vpq_dataset(vpq_dataset&&) = default; + vpq_dataset& operator=(vpq_dataset&&) = default; + ~vpq_dataset() override = default; + + [[nodiscard]] index_type n_rows() const noexcept override { return data.extent(0); } + [[nodiscard]] uint32_t dim() const noexcept override { return codebooks.dim(); } + [[nodiscard]] bool is_owning() const noexcept override { return true; } + + /** + * VQ codebook view, or std::nullopt when no VQ codebook is configured. + */ + [[nodiscard]] auto vq_code_book() const noexcept + -> std::optional> + { + return codebooks.vq_code_book(); + } + [[nodiscard]] auto pq_code_book() const noexcept + -> raft::device_matrix_view + { + return codebooks.pq_code_book(); + } + [[nodiscard]] auto encoded_row_length() const noexcept -> uint32_t { return data.extent(1); } + [[nodiscard]] auto vq_n_centers() const noexcept -> uint32_t { return codebooks.vq_n_centers(); } + [[nodiscard]] auto pq_bits() const noexcept -> uint32_t { return codebooks.pq_bits(); } + [[nodiscard]] auto pq_dim() const noexcept -> uint32_t { return codebooks.pq_dim(); } + [[nodiscard]] auto pq_len() const noexcept -> uint32_t { return codebooks.pq_len(); } + [[nodiscard]] auto pq_n_centers() const noexcept -> uint32_t { return codebooks.pq_n_centers(); } +}; + +template +struct is_vpq_dataset : std::false_type {}; + +template +struct is_vpq_dataset> : std::true_type {}; + +template +inline constexpr bool is_vpq_dataset_v = is_vpq_dataset::value; + +/** @} */ // end of group pq + +} // namespace cuvs::preprocessing::quantize::pq diff --git a/cpp/src/neighbors/detail/cagra/cagra_search.cuh b/cpp/src/neighbors/detail/cagra/cagra_search.cuh index f1650980e0..9db2afbde0 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_search.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_search.cuh @@ -24,6 +24,7 @@ #include "../../ivf_common.cuh" #include "../../ivf_pq/ivf_pq_search.cuh" #include +#include // TODO: This shouldn't be calling spatial/knn apis #include "../ann_utils.cuh" @@ -173,11 +174,15 @@ void search_main(raft::resources const& res, neighbors, distances, sample_filter); - } else if (auto* vpq_dset = dynamic_cast*>(&index.data()); + } else if (auto* vpq_dset = dynamic_cast< + const cuvs::preprocessing::quantize::pq::vpq_dataset*>( + &index.data()); vpq_dset != nullptr) { // Search using a compressed dataset RAFT_FAIL("FP32 VPQ dataset support is coming soon"); - } else if (auto* vpq_dset = dynamic_cast*>(&index.data()); + } else if (auto* vpq_dset = dynamic_cast< + const cuvs::preprocessing::quantize::pq::vpq_dataset*>( + &index.data()); vpq_dset != nullptr) { auto desc = dataset_descriptor_init_with_cache( res, params, *vpq_dset, index.metric(), nullptr); diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_vpq.hpp b/cpp/src/neighbors/detail/cagra/compute_distance_vpq.hpp index 2b69a1cef4..6daefa501f 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance_vpq.hpp +++ b/cpp/src/neighbors/detail/cagra/compute_distance_vpq.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -8,6 +8,7 @@ #include "compute_distance.hpp" #include +#include #include @@ -31,14 +32,14 @@ struct vpq_descriptor_spec : public instance_spec { template constexpr static inline auto accepts_dataset() - -> std::enable_if_t, bool> + -> std::enable_if_t, bool> { return std::is_same_v; } template constexpr static inline auto accepts_dataset() - -> std::enable_if_t, bool> + -> std::enable_if_t, bool> { return false; } @@ -49,11 +50,12 @@ struct vpq_descriptor_spec : public instance_spec { cuvs::distance::DistanceType metric, const DistanceT* dataset_norms = nullptr) -> host_type { + auto vq_view = dataset.vq_code_book(); return init_(params, dataset.data.data_handle(), dataset.encoded_row_length(), - dataset.vq_code_book.data_handle(), - dataset.pq_code_book.data_handle(), + vq_view.has_value() ? vq_view->data_handle() : nullptr, + dataset.pq_code_book().data_handle(), IndexT(dataset.n_rows()), dataset.dim()); } diff --git a/cpp/src/neighbors/detail/cagra/factory.cuh b/cpp/src/neighbors/detail/cagra/factory.cuh index 26cd13bab8..e29ca05579 100644 --- a/cpp/src/neighbors/detail/cagra/factory.cuh +++ b/cpp/src/neighbors/detail/cagra/factory.cuh @@ -107,12 +107,12 @@ template auto make_key(const cagra::search_params& params, const DatasetT& dataset, cuvs::distance::DistanceType metric) - -> std::enable_if_t, key> + -> std::enable_if_t, key> { return key{reinterpret_cast(dataset.data.data_handle()), uint64_t(dataset.n_rows()), dataset.dim(), - uint32_t(reinterpret_cast(dataset.pq_code_book.data_handle()) >> 6), + uint32_t(reinterpret_cast(dataset.pq_code_book().data_handle()) >> 6), uint32_t(params.team_size), uint32_t(metric)}; } diff --git a/cpp/src/neighbors/detail/dataset_serialize.hpp b/cpp/src/neighbors/detail/dataset_serialize.hpp index be11f2da53..d2cc0432ee 100644 --- a/cpp/src/neighbors/detail/dataset_serialize.hpp +++ b/cpp/src/neighbors/detail/dataset_serialize.hpp @@ -4,7 +4,9 @@ */ #pragma once +#include "../../preprocessing/quantize/detail/vpq_dataset_impl.hpp" #include +#include #include #include @@ -59,7 +61,7 @@ void serialize(const raft::resources& res, template void serialize(const raft::resources& res, std::ostream& os, - const vpq_dataset& dataset) + const cuvs::preprocessing::quantize::pq::vpq_dataset& dataset) { raft::serialize_scalar(res, os, dataset.n_rows()); raft::serialize_scalar(res, os, dataset.dim()); @@ -67,9 +69,11 @@ void serialize(const raft::resources& res, raft::serialize_scalar(res, os, dataset.pq_n_centers()); raft::serialize_scalar(res, os, dataset.pq_len()); raft::serialize_scalar(res, os, dataset.encoded_row_length()); - raft::serialize_mdspan(res, os, make_const_mdspan(dataset.vq_code_book.view())); - raft::serialize_mdspan(res, os, make_const_mdspan(dataset.pq_code_book.view())); - raft::serialize_mdspan(res, os, make_const_mdspan(dataset.data.view())); + auto vq_view = dataset.vq_code_book().value_or( + raft::make_device_matrix_view(nullptr, 0, 0)); + raft::serialize_mdspan(res, os, vq_view); + raft::serialize_mdspan(res, os, dataset.pq_code_book()); + raft::serialize_mdspan(res, os, dataset.data.view()); } template @@ -99,12 +103,16 @@ void serialize(const raft::resources& res, std::ostream& os, const dataset raft::serialize_scalar(res, os, CUDA_R_8U); return serialize(res, os, *x); } - if (auto x = dynamic_cast*>(&dataset); x != nullptr) { + if (auto x = + dynamic_cast*>(&dataset); + x != nullptr) { raft::serialize_scalar(res, os, kSerializeVPQDataset); raft::serialize_scalar(res, os, CUDA_R_32F); return serialize(res, os, *x); } - if (auto x = dynamic_cast*>(&dataset); x != nullptr) { + if (auto x = + dynamic_cast*>(&dataset); + x != nullptr) { raft::serialize_scalar(res, os, kSerializeVPQDataset); raft::serialize_scalar(res, os, CUDA_R_16F); return serialize(res, os, *x); @@ -134,7 +142,7 @@ auto deserialize_strided(raft::resources const& res, std::istream& is) template auto deserialize_vpq(raft::resources const& res, std::istream& is) - -> std::unique_ptr> + -> std::unique_ptr> { auto n_rows = raft::deserialize_scalar(res, is); auto dim = raft::deserialize_scalar(res, is); @@ -154,8 +162,13 @@ auto deserialize_vpq(raft::resources const& res, std::istream& is) raft::deserialize_mdspan(res, is, pq_code_book.view()); raft::deserialize_mdspan(res, is, data.view()); - return std::make_unique>( - std::move(vq_code_book), std::move(pq_code_book), std::move(data)); + std::optional> vq_code_book_opt; + if (vq_n_centers > 0) { vq_code_book_opt = std::move(vq_code_book); } + return std::make_unique>( + cuvs::preprocessing::quantize::pq::vpq_codebooks{ + std::make_unique>( + std::move(pq_code_book), std::move(vq_code_book_opt))}, + std::move(data)); } template diff --git a/cpp/src/neighbors/detail/vamana/vamana_build.cuh b/cpp/src/neighbors/detail/vamana/vamana_build.cuh index 107bc39ee5..72c059f649 100644 --- a/cpp/src/neighbors/detail/vamana/vamana_build.cuh +++ b/cpp/src/neighbors/detail/vamana/vamana_build.cuh @@ -644,12 +644,8 @@ index build( // process in batches const uint32_t n_rows = dataset.extent(0); - auto quantizer = cuvs::preprocessing::quantize::pq::quantizer( - pq_params, - cuvs::neighbors::vpq_dataset{ - raft::make_device_matrix(res, 0, 0), - std::move(pq_codebook), - raft::make_device_matrix(res, 0, 0)}); + auto quantizer = cuvs::preprocessing::quantize::pq::build( + res, pq_params, raft::make_const_mdspan(pq_codebook.view())); const int64_t codes_rowlen = cuvs::preprocessing::quantize::pq::get_quantized_dim(pq_params); quantized_vectors = raft::make_device_matrix(res, n_rows, codes_rowlen); diff --git a/cpp/src/neighbors/detail/vpq_dataset.cuh b/cpp/src/neighbors/detail/vpq_dataset.cuh index 75785da5d2..3d5e7f2247 100644 --- a/cpp/src/neighbors/detail/vpq_dataset.cuh +++ b/cpp/src/neighbors/detail/vpq_dataset.cuh @@ -6,6 +6,7 @@ #include #include +#include #include "../../cluster/kmeans_balanced.cuh" #include "../../preprocessing/quantize/detail/pq_codepacking.cuh" // pq_bits-bitfield @@ -415,7 +416,7 @@ void process_and_fill_codes( bool inline_vq_labels = false) { using data_t = typename DatasetT::value_type; - using cdataset_t = vpq_dataset; + using cdataset_t = cuvs::preprocessing::quantize::pq::vpq_dataset; using label_t = uint32_t; const ix_t n_rows = dataset.extent(0); @@ -803,7 +804,7 @@ void process_and_fill_codes_subspaces( raft::device_matrix_view codes) { using data_t = typename DatasetT::value_type; - using cdataset_t = vpq_dataset; + using cdataset_t = cuvs::preprocessing::quantize::pq::vpq_dataset; using label_t = uint32_t; const ix_t n_rows = dataset.extent(0); diff --git a/cpp/src/neighbors/scann/detail/scann_build.cuh b/cpp/src/neighbors/scann/detail/scann_build.cuh index 66ebd994b2..47d542bcef 100644 --- a/cpp/src/neighbors/scann/detail/scann_build.cuh +++ b/cpp/src/neighbors/scann/detail/scann_build.cuh @@ -288,7 +288,7 @@ index build( // Codebooks from VPQ have the shape [subspace idx, subspace dim, code] // This converts the codebook into matrix format for easy interoperability // with open-source ScaNN search - auto full_codebook_view = pq_quantizer.vpq_codebooks.pq_code_book.view(); + auto full_codebook_view = pq_quantizer.codebooks.pq_code_book(); raft::linalg::map_offset( res, diff --git a/cpp/src/preprocessing/quantize/detail/pq.cuh b/cpp/src/preprocessing/quantize/detail/pq.cuh index 5d77e2dd44..87771474b8 100644 --- a/cpp/src/preprocessing/quantize/detail/pq.cuh +++ b/cpp/src/preprocessing/quantize/detail/pq.cuh @@ -8,6 +8,7 @@ #include "../../../core/nvtx.hpp" #include "../../../neighbors/detail/vpq_dataset.cuh" #include "pq_codepacking.cuh" // pq_bits-bitfield +#include "vpq_dataset_impl.hpp" #include #include @@ -183,7 +184,6 @@ quantizer build( if (filled_params.use_vq) { vq_code_book = cuvs::neighbors::detail::train_vq(res, vpq_params, dataset); } - auto empty_codes = raft::make_device_matrix(res, 0, 0); auto pq_code_book = raft::make_device_matrix(res, 0, 0); if (filled_params.use_subspaces) { pq_code_book = train_pq_subspaces( @@ -192,9 +192,63 @@ quantizer build( pq_code_book = cuvs::neighbors::detail::train_pq( res, filled_params, dataset, raft::make_const_mdspan(vq_code_book.view())); } + std::optional> vq_code_book_opt; + if (filled_params.use_vq) { vq_code_book_opt = std::move(vq_code_book); } return {filled_params, - cuvs::neighbors::vpq_dataset{ - std::move(vq_code_book), std::move(pq_code_book), std::move(empty_codes)}}; + vpq_codebooks{std::make_unique>( + std::move(pq_code_book), std::move(vq_code_book_opt))}}; +} + +template +quantizer build_view( + raft::resources const& res, + const params& params, + raft::device_matrix_view pq_centers, + std::optional> vq_centers = + std::nullopt) +{ + RAFT_EXPECTS(params.pq_bits >= 4 && params.pq_bits <= 16, + "PQ bits must be within [4, 16], got %u", + params.pq_bits); + RAFT_EXPECTS(params.pq_dim > 0, "pq_dim must be specified for view-type quantizer"); + + const uint32_t pq_n_centers = 1u << params.pq_bits; + + if (params.use_subspaces) { + RAFT_EXPECTS(pq_centers.extent(0) == params.pq_dim * pq_n_centers, + "For use_subspaces=true, pq_centers must have shape [pq_dim * pq_n_centers, " + "pq_len], got [%u, %u]", + pq_centers.extent(0), + pq_centers.extent(1)); + } else { + RAFT_EXPECTS(pq_centers.extent(0) == pq_n_centers, + "For use_subspaces=false, pq_centers must have shape [pq_n_centers, pq_len], got " + "[%u, %u]", + pq_centers.extent(0), + pq_centers.extent(1)); + } + + if (params.use_vq) { + RAFT_EXPECTS(vq_centers.has_value(), "vq_centers must be provided when use_vq=true"); + RAFT_EXPECTS(params.vq_n_centers > 0, + "params.vq_n_centers must be > 0 when use_vq=true (got %u)", + params.vq_n_centers); + RAFT_EXPECTS(vq_centers.value().data_handle() != nullptr, + "vq_centers data pointer must be non-null when use_vq=true"); + RAFT_EXPECTS(vq_centers.value().extent(0) == params.vq_n_centers, + "vq_centers must have shape [vq_n_centers, dim] (vq_n_centers=%u), got " + "extent(0)=%u", + params.vq_n_centers, + vq_centers.value().extent(0)); + return { + params, + vpq_codebooks{std::make_unique>(pq_centers, vq_centers)}}; + } else { + if (vq_centers.has_value()) { + RAFT_LOG_WARN("vq_centers will be ignored since params.use_vq=false"); + } + return {params, vpq_codebooks{std::make_unique>(pq_centers)}}; + } } template @@ -216,8 +270,21 @@ void transform( "Output matrix doesn't have the correct number of columns"); RAFT_EXPECTS(quantizer.params_quantizer.pq_bits >= 4 && quantizer.params_quantizer.pq_bits <= 16, "PQ bits must be within [4, 16]"); - // Encode dataset - auto vq_centers = raft::make_const_mdspan(quantizer.vpq_codebooks.vq_code_book.view()); + + // Honor params.use_vq as the source of truth: when it is false, pass an + // empty view to the kernel regardless of what the codebooks contain + // (the kernel gates VQ subtraction on vq_centers.empty(), so an empty view + // guarantees no residual VQ is applied even if a misconfigured quantizer + // somehow carries non-empty centers). Conversely, when it is true the + // codebook must be present. + auto vq_centers_opt = quantizer.codebooks.vq_code_book(); + RAFT_EXPECTS(!quantizer.params_quantizer.use_vq || vq_centers_opt.has_value(), + "Quantizer has params.use_vq=true but no VQ codebook"); + auto vq_centers = + quantizer.params_quantizer.use_vq + ? vq_centers_opt.value() + : raft::make_device_matrix_view(nullptr, 0, 0); + auto pq_centers = quantizer.codebooks.pq_code_book(); auto vq_labels_view = raft::make_device_vector_view(nullptr, 0); if (vq_labels.has_value()) { vq_labels_view = vq_labels.value(); } @@ -226,7 +293,7 @@ void transform( res, to_vpq_params(quantizer.params_quantizer), dataset, - raft::make_const_mdspan(quantizer.vpq_codebooks.pq_code_book.view()), + pq_centers, vq_centers, vq_labels_view, pq_codes_out); @@ -235,7 +302,7 @@ void transform( res, to_vpq_params(quantizer.params_quantizer), dataset, - raft::make_const_mdspan(quantizer.vpq_codebooks.pq_code_book.view()), + pq_centers, vq_centers, vq_labels_view, pq_codes_out); @@ -294,8 +361,6 @@ auto reconstruct_vectors( bool use_subspaces) { const IdxT n_rows = out_vectors.extent(0); - const IdxT dim = out_vectors.extent(1); - const IdxT pq_dim = params.pq_dim; const IdxT pq_bits = params.pq_bits; const IdxT pq_n_centers = IdxT{1} << pq_bits; @@ -331,14 +396,12 @@ auto reconstruct_vectors( kernel<<>>( codes, out_vectors, pq_centers, vq_centers, vq_labels, pq_bits, use_subspaces); RAFT_CUDA_TRY(cudaPeekAtLastError()); - - return codes; } template void inverse_transform( raft::resources const& res, - const quantizer& quant, + const quantizer& quantizer, raft::device_matrix_view codes, raft::device_matrix_view out, std::optional> vq_labels = std::nullopt) @@ -352,34 +415,67 @@ void inverse_transform( using idx_t = int64_t; RAFT_EXPECTS(out.extent(0) == codes.extent(0), "Output matrix must have the same number of rows as the input codes"); - RAFT_EXPECTS(codes.extent(1) == get_quantized_dim(quant.params_quantizer), + RAFT_EXPECTS(codes.extent(1) == get_quantized_dim(quantizer.params_quantizer), "Codes matrix doesn't have the correct number of columns"); - RAFT_EXPECTS(quant.params_quantizer.pq_bits >= 4 && quant.params_quantizer.pq_bits <= 16, + RAFT_EXPECTS(quantizer.params_quantizer.pq_bits >= 4 && quantizer.params_quantizer.pq_bits <= 16, "PQ bits must be within [4, 16]"); - reconstruct_vectors( - res, - quant.params_quantizer, - codes, - raft::make_const_mdspan(quant.vpq_codebooks.pq_code_book.view()), - raft::make_const_mdspan(quant.vpq_codebooks.vq_code_book.view()), - vq_labels, - out, - quant.params_quantizer.use_subspaces); + + // Honor params.use_vq strictly (see the matching block in transform()). + auto vq_centers_opt = quantizer.codebooks.vq_code_book(); + RAFT_EXPECTS(!quantizer.params_quantizer.use_vq || vq_centers_opt.has_value(), + "Quantizer has params.use_vq=true but no VQ codebook"); + auto vq_centers = + quantizer.params_quantizer.use_vq + ? vq_centers_opt.value() + : raft::make_device_matrix_view(nullptr, 0, 0); + + // VQ-label preflight: when use_vq is true the kernel reads vq_labels(row) per + // row and falls back to label 0 when vq_labels is absent. Without this check every vector can be + // reconstructed with vq_label 0. + if (quantizer.params_quantizer.use_vq) { + RAFT_EXPECTS(vq_labels.has_value(), + "When params.use_vq is true, vq_labels must be provided to inverse_transform()"); + RAFT_EXPECTS(vq_labels.value().extent(0) == codes.extent(0), + "When params.use_vq is true, vq_labels must have the same number of rows as " + "codes (got %zu vs %zu)", + size_t(vq_labels.value().extent(0)), + size_t(codes.extent(0))); + } + + reconstruct_vectors(res, + quantizer.params_quantizer, + codes, + quantizer.codebooks.pq_code_book(), + vq_centers, + vq_labels, + out, + quantizer.params_quantizer.use_subspaces); } -template -void vpq_convert_math_type(const raft::resources& res, - const cuvs::neighbors::vpq_dataset& src, - cuvs::neighbors::vpq_dataset& dst) +template +auto vpq_convert_math_type(const raft::resources& res, const vpq_codebooks& src) + -> vpq_codebooks { - raft::linalg::map(res, - dst.vq_code_book.view(), - cuvs::spatial::knn::detail::utils::mapping{}, - raft::make_const_mdspan(src.vq_code_book.view())); - raft::linalg::map(res, - dst.pq_code_book.view(), - cuvs::spatial::knn::detail::utils::mapping{}, - raft::make_const_mdspan(src.pq_code_book.view())); + auto vq_src_opt = src.vq_code_book(); + auto pq_src = src.pq_code_book(); + + auto pq_new = raft::make_device_matrix( + res, pq_src.extent(0), pq_src.extent(1)); + raft::linalg::map( + res, pq_new.view(), cuvs::spatial::knn::detail::utils::mapping{}, pq_src); + + std::optional> vq_new_opt; + if (vq_src_opt.has_value()) { + auto vq_src = vq_src_opt.value(); + auto vq_new = raft::make_device_matrix( + res, vq_src.extent(0), vq_src.extent(1)); + raft::linalg::map( + res, vq_new.view(), cuvs::spatial::knn::detail::utils::mapping{}, vq_src); + vq_new_opt = std::move(vq_new); + } + + return vpq_codebooks{ + std::make_unique>(std::move(pq_new), std::move(vq_new_opt))}; } inline auto make_pq_params_from_vpq(const cuvs::neighbors::vpq_params& in_params, @@ -409,11 +505,10 @@ inline auto make_pq_params_from_vpq(const cuvs::neighbors::vpq_params& in_params template auto vpq_build(const raft::resources& res, const cuvs::neighbors::vpq_params& params, - const DatasetT& dataset) -> cuvs::neighbors::vpq_dataset + const DatasetT& dataset) -> vpq_dataset { using label_t = uint32_t; - // Use a heuristic to impute missing parameters. - auto ps = cuvs::neighbors::detail::fill_missing_params_heuristics(params, dataset); + auto ps = cuvs::neighbors::detail::fill_missing_params_heuristics(params, dataset); auto pq_params = make_pq_params_from_vpq(ps, dataset.extent(0)); // Train codes @@ -421,7 +516,6 @@ auto vpq_build(const raft::resources& res, auto pq_code_book = cuvs::neighbors::detail::train_pq( res, pq_params, dataset, raft::make_const_mdspan(vq_code_book.view())); - // Encode dataset const IdxT n_rows = dataset.extent(0); const IdxT codes_rowlen = sizeof(label_t) * (1 + raft::div_rounding_up_safe( ps.pq_dim * ps.pq_bits, 8 * sizeof(label_t))); @@ -437,21 +531,21 @@ auto vpq_build(const raft::resources& res, codes.view(), true); - return cuvs::neighbors::vpq_dataset{ - std::move(vq_code_book), std::move(pq_code_book), std::move(codes)}; + return vpq_dataset{ + vpq_codebooks{std::make_unique>(std::move(pq_code_book), + std::move(vq_code_book))}, + std::move(codes)}; } template auto vpq_build_half(const raft::resources& res, const cuvs::neighbors::vpq_params& params, - const DatasetT& dataset) -> cuvs::neighbors::vpq_dataset + const DatasetT& dataset) -> vpq_dataset { - auto old_type = vpq_build(res, params, dataset); - auto new_type = cuvs::neighbors::vpq_dataset{ - raft::make_device_mdarray(res, old_type.vq_code_book.extents()), - raft::make_device_mdarray(res, old_type.pq_code_book.extents()), - std::move(old_type.data)}; - vpq_convert_math_type(res, old_type, new_type); - return new_type; + // Build in float, then convert codebooks to half; data (uint8 codes) is moved, not copied. + auto float_ds = vpq_build(res, params, dataset); + auto half_codebooks = vpq_convert_math_type(res, float_ds.codebooks); + return vpq_dataset{std::move(half_codebooks), std::move(float_ds.data)}; } + } // namespace cuvs::preprocessing::quantize::pq::detail diff --git a/cpp/src/preprocessing/quantize/detail/vpq_dataset_impl.hpp b/cpp/src/preprocessing/quantize/detail/vpq_dataset_impl.hpp new file mode 100644 index 0000000000..9d50044ff6 --- /dev/null +++ b/cpp/src/preprocessing/quantize/detail/vpq_dataset_impl.hpp @@ -0,0 +1,92 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include + +#include + +namespace cuvs::preprocessing::quantize::pq { + +template +class vpq_codebooks_owning : public vpq_codebooks_iface { + public: + using math_type = MathT; + using matrix_type = raft::device_matrix; + + // PQ codebook is required; VQ codebook is optional and defaults to absent. + // When VQ is not provided, vq_code_book() returns std::nullopt. + explicit vpq_codebooks_owning(matrix_type&& pq_code_book, + std::optional&& vq_code_book = std::nullopt) + : pq_code_book_{std::move(pq_code_book)}, vq_code_book_{std::move(vq_code_book)} + { + } + + vpq_codebooks_owning(const vpq_codebooks_owning&) = delete; + vpq_codebooks_owning& operator=(const vpq_codebooks_owning&) = delete; + vpq_codebooks_owning(vpq_codebooks_owning&&) = default; + vpq_codebooks_owning& operator=(vpq_codebooks_owning&&) = default; + ~vpq_codebooks_owning() override = default; + + [[nodiscard]] auto vq_code_book() const noexcept + -> std::optional> override + { + if (!vq_code_book_.has_value()) { return std::nullopt; } + return vq_code_book_.value().view(); + } + + [[nodiscard]] auto pq_code_book() const noexcept + -> raft::device_matrix_view override + { + return pq_code_book_.view(); + } + + [[nodiscard]] auto pq_code_book() noexcept + -> raft::device_matrix_view + { + return pq_code_book_.view(); + } + + private: + matrix_type pq_code_book_; + std::optional vq_code_book_; +}; + +template +class vpq_codebooks_view : public vpq_codebooks_iface { + public: + using math_type = MathT; + using view_type = raft::device_matrix_view; + + // PQ codebook view is required; VQ codebook view is optional and defaults to + // absent. When VQ is not provided, vq_code_book() returns std::nullopt. + explicit vpq_codebooks_view(view_type pq_code_book_view, + std::optional vq_code_book_view = std::nullopt) + : pq_code_book_view_{pq_code_book_view}, vq_code_book_view_{vq_code_book_view} + { + } + + vpq_codebooks_view(const vpq_codebooks_view&) = default; + vpq_codebooks_view& operator=(const vpq_codebooks_view&) = default; + vpq_codebooks_view(vpq_codebooks_view&&) = default; + vpq_codebooks_view& operator=(vpq_codebooks_view&&) = default; + ~vpq_codebooks_view() override = default; + + [[nodiscard]] auto vq_code_book() const noexcept -> std::optional override + { + return vq_code_book_view_; + } + [[nodiscard]] auto pq_code_book() const noexcept -> view_type override + { + return pq_code_book_view_; + } + + private: + view_type pq_code_book_view_; + std::optional vq_code_book_view_; +}; + +} // namespace cuvs::preprocessing::quantize::pq diff --git a/cpp/src/preprocessing/quantize/pq.cu b/cpp/src/preprocessing/quantize/pq.cu index 761474bdf8..6eacb92224 100644 --- a/cpp/src/preprocessing/quantize/pq.cu +++ b/cpp/src/preprocessing/quantize/pq.cu @@ -9,43 +9,52 @@ namespace cuvs::preprocessing::quantize::pq { -#define CUVS_INST_QUANTIZATION(T, QuantI) \ - auto build(raft::resources const& res, \ - const params params, \ - raft::device_matrix_view dataset) -> quantizer \ - { \ - return detail::build(res, params, dataset); \ - } \ - auto build(raft::resources const& res, \ - const params params, \ - raft::host_matrix_view dataset) -> quantizer \ - { \ - return detail::build(res, params, dataset); \ - } \ - void transform(raft::resources const& res, \ - const quantizer& quantizer, \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view codes_out, \ - std::optional> vq_labels) \ - { \ - detail::transform(res, quantizer, dataset, codes_out, vq_labels); \ - } \ - void transform(raft::resources const& res, \ - const quantizer& quantizer, \ - raft::host_matrix_view dataset, \ - raft::device_matrix_view codes_out, \ - std::optional> vq_labels) \ - { \ - detail::transform(res, quantizer, dataset, codes_out, vq_labels); \ - } \ - void inverse_transform( \ - raft::resources const& res, \ - const quantizer& quantizer, \ - raft::device_matrix_view pq_codes, \ - raft::device_matrix_view out, \ - std::optional> vq_labels) \ - { \ - detail::inverse_transform(res, quantizer, pq_codes, out, vq_labels); \ +#define CUVS_INST_QUANTIZATION(T, QuantI) \ + auto build(raft::resources const& res, \ + const params params, \ + raft::device_matrix_view dataset) -> quantizer \ + { \ + return detail::build(res, params, dataset); \ + } \ + auto build(raft::resources const& res, \ + const params params, \ + raft::host_matrix_view dataset) -> quantizer \ + { \ + return detail::build(res, params, dataset); \ + } \ + auto build( \ + raft::resources const& res, \ + const params params, \ + raft::device_matrix_view pq_centers, \ + std::optional> vq_centers) \ + -> quantizer \ + { \ + return detail::build_view(res, params, pq_centers, vq_centers); \ + } \ + void transform(raft::resources const& res, \ + const quantizer& quantizer, \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view codes_out, \ + std::optional> vq_labels) \ + { \ + detail::transform(res, quantizer, dataset, codes_out, vq_labels); \ + } \ + void transform(raft::resources const& res, \ + const quantizer& quantizer, \ + raft::host_matrix_view dataset, \ + raft::device_matrix_view codes_out, \ + std::optional> vq_labels) \ + { \ + detail::transform(res, quantizer, dataset, codes_out, vq_labels); \ + } \ + void inverse_transform( \ + raft::resources const& res, \ + const quantizer& quantizer, \ + raft::device_matrix_view pq_codes, \ + raft::device_matrix_view out, \ + std::optional> vq_labels) \ + { \ + detail::inverse_transform(res, quantizer, pq_codes, out, vq_labels); \ } CUVS_INST_QUANTIZATION(float, uint8_t); @@ -73,4 +82,9 @@ CUVS_INST_VPQ_BUILD(uint8_t); #undef CUVS_INST_VPQ_BUILD +template class vpq_codebooks; +template class vpq_codebooks; +template class vpq_dataset; +template class vpq_dataset; + } // namespace cuvs::preprocessing::quantize::pq diff --git a/cpp/src/preprocessing/quantize/vpq_build-ext.cuh b/cpp/src/preprocessing/quantize/vpq_build-ext.cuh index 1745e53a33..7d4c8083b1 100644 --- a/cpp/src/preprocessing/quantize/vpq_build-ext.cuh +++ b/cpp/src/preprocessing/quantize/vpq_build-ext.cuh @@ -5,16 +5,17 @@ #pragma once #include +#include #include namespace cuvs::preprocessing::quantize::pq { #define CUVS_INST_VPQ_BUILD(T) \ - cuvs::neighbors::vpq_dataset vpq_build( \ + vpq_dataset vpq_build( \ const raft::resources& res, \ const cuvs::neighbors::vpq_params& params, \ const raft::host_matrix_view& dataset); \ - cuvs::neighbors::vpq_dataset vpq_build( \ + vpq_dataset vpq_build( \ const raft::resources& res, \ const cuvs::neighbors::vpq_params& params, \ const raft::device_matrix_view& dataset); diff --git a/cpp/tests/neighbors/ann_scann.cuh b/cpp/tests/neighbors/ann_scann.cuh index eafddec9d2..9df3bcadad 100644 --- a/cpp/tests/neighbors/ann_scann.cuh +++ b/cpp/tests/neighbors/ann_scann.cuh @@ -182,12 +182,11 @@ class scann_test : public ::testing::TestWithParam { handle_, idx.centers().extent(0), idx.centers().extent(1)); raft::copy( vq_codebook.data_handle(), idx.centers().data_handle(), idx.centers().size(), stream_); - auto empty_data = raft::make_device_matrix(handle_, 0, 0); - - cuvs::preprocessing::quantize::pq::quantizer quantizer{ - pq_params, - cuvs::neighbors::vpq_dataset{ - std::move(vq_codebook), std::move(pq_codebook_copy), std::move(empty_data)}}; + auto quantizer = + cuvs::preprocessing::quantize::pq::build(handle_, + pq_params, + raft::make_const_mdspan(pq_codebook_copy.view()), + raft::make_const_mdspan(vq_codebook.view())); auto quantized_residuals_device = raft::make_device_matrix(handle_, ps.num_db_vecs, num_subspaces); diff --git a/cpp/tests/preprocessing/product_quantization.cu b/cpp/tests/preprocessing/product_quantization.cu index 6ea881cd0b..a3daaea3a8 100644 --- a/cpp/tests/preprocessing/product_quantization.cu +++ b/cpp/tests/preprocessing/product_quantization.cu @@ -147,7 +147,10 @@ class ProductQuantizationTest : public ::testing::TestWithParam(dataset_.data_handle(), n_take, dim); std::optional> vq_labels_view = std::nullopt; - if (vq_labels) { vq_labels_view = raft::make_const_mdspan(vq_labels.value()); } + if (vq_labels) { + vq_labels_view = raft::make_device_vector_view( + vq_labels.value().data_handle(), static_cast(n_take)); + } inverse_transform(handle, quantizer, raft::device_matrix_view( @@ -240,6 +243,71 @@ class ProductQuantizationTest : public ::testing::TestWithParam(n_samples_) < params_.n_vq_centers)) { + EXPECT_THROW(build(handle, config, raft::make_const_mdspan(dataset_.view())), + raft::logic_error); + return; + } + auto owning_quant = build(handle, config, raft::make_const_mdspan(dataset_.view())); + + auto pq_centers_view = owning_quant.codebooks.pq_code_book(); + auto vq_centers_view = owning_quant.codebooks.vq_code_book(); + + auto view_quant = + build(handle, owning_quant.params_quantizer, pq_centers_view, vq_centers_view); + + auto n_encoded_cols = get_quantized_dim(owning_quant.params_quantizer); + auto codes_owning = + raft::make_device_matrix(handle, n_samples_, n_encoded_cols); + transform(handle, + owning_quant, + raft::make_const_mdspan(dataset_.view()), + codes_owning.view(), + std::nullopt); + + auto codes_view = + raft::make_device_matrix(handle, n_samples_, n_encoded_cols); + transform(handle, + view_quant, + raft::make_const_mdspan(dataset_.view()), + codes_view.view(), + std::nullopt); + + cuvs::devArrMatch(codes_owning.data_handle(), + codes_view.data_handle(), + n_samples_ * n_encoded_cols, + cuvs::Compare()); + + auto reconstructed_owning = + raft::make_device_matrix(handle, n_samples_, n_features_); + auto reconstructed_view = + raft::make_device_matrix(handle, n_samples_, n_features_); + + inverse_transform(handle, + owning_quant, + raft::make_const_mdspan(codes_owning.view()), + reconstructed_owning.view(), + std::nullopt); + inverse_transform(handle, + view_quant, + raft::make_const_mdspan(codes_view.view()), + reconstructed_view.view(), + std::nullopt); + cuvs::devArrMatch(reconstructed_owning.data_handle(), + reconstructed_view.data_handle(), + n_samples_ * n_features_, + cuvs::Compare()); + } + private: raft::resources handle; cudaStream_t stream; @@ -428,7 +496,7 @@ const std::vector> inputs = { typedef ProductQuantizationTest ProductQuantizationTestF; TEST_P(ProductQuantizationTestF, Result) { this->testProductQuantizationFromDataset(); } - +TEST_P(ProductQuantizationTestF, ViewQuantizer) { this->testViewQuantizer(); } INSTANTIATE_TEST_CASE_P(ProductQuantizationTests, ProductQuantizationTestF, ::testing::ValuesIn(inputs));