From 2accbbadee0178c0e498baccad999e48c22d81ff Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Tue, 3 Feb 2026 03:55:53 -0800 Subject: [PATCH 01/36] first commit --- cpp/include/cuvs/neighbors/common.hpp | 52 +++++++ .../cuvs/preprocessing/quantize/pq.hpp | 96 ++++++++++++- cpp/src/preprocessing/quantize/detail/pq.cuh | 135 ++++++++++++++++++ cpp/src/preprocessing/quantize/pq.cu | 40 ++++++ .../preprocessing/product_quantization.cu | 135 ++++++++++++++++++ 5 files changed, 457 insertions(+), 1 deletion(-) diff --git a/cpp/include/cuvs/neighbors/common.hpp b/cpp/include/cuvs/neighbors/common.hpp index b2da48aaa1..92c0f38e1a 100644 --- a/cpp/include/cuvs/neighbors/common.hpp +++ b/cpp/include/cuvs/neighbors/common.hpp @@ -483,6 +483,58 @@ struct vpq_dataset : public dataset { } }; +/** + * @brief View-type VPQ codebooks (non-owning). + * + * This structure stores views of pre-computed VQ and PQ codebooks without copying. + * The caller is responsible for ensuring the lifetime of the underlying data + * exceeds the lifetime of this view. + * + * @tparam MathT the type of elements in the codebooks + * + */ +template +struct vpq_codebooks_view { + using math_type = MathT; + /** View of Vector Quantization codebook - "coarse cluster centers" [vq_n_centers, dim]. */ + raft::device_matrix_view vq_code_book; + /** View of Product Quantization codebook - "fine cluster centers". + * - For use_subspaces=true: [pq_dim * pq_n_centers, pq_len] + * - For use_subspaces=false: [pq_n_centers, pq_len] + */ + raft::device_matrix_view pq_code_book; + + vpq_codebooks_view( + raft::device_matrix_view vq_code_book, + raft::device_matrix_view pq_code_book) + : vq_code_book{vq_code_book}, pq_code_book{pq_code_book} + { + } + + [[nodiscard]] auto is_owning() const noexcept -> bool { return false; } + + /** Dimensionality of the original vectors. */ + [[nodiscard]] constexpr inline auto dim() const noexcept -> uint32_t + { + return vq_code_book.extent(1); + } + /** The number of "coarse cluster centers" */ + [[nodiscard]] constexpr inline auto vq_n_centers() const noexcept -> uint32_t + { + return vq_code_book.extent(0); + } + /** Dimensionality of a subspace, i.e. the number of vector components mapped to a subspace */ + [[nodiscard]] constexpr inline auto pq_len() const noexcept -> uint32_t + { + return pq_code_book.extent(1); + } + /** The number of vectors in a PQ codebook (`1 << pq_bits`). */ + [[nodiscard]] constexpr inline auto pq_n_centers() const noexcept -> uint32_t + { + return pq_code_book.extent(0); + } +}; + template struct is_vpq_dataset : std::false_type {}; diff --git a/cpp/include/cuvs/preprocessing/quantize/pq.hpp b/cpp/include/cuvs/preprocessing/quantize/pq.hpp index a8b5e6a0b9..56ff922569 100644 --- a/cpp/include/cuvs/preprocessing/quantize/pq.hpp +++ b/cpp/include/cuvs/preprocessing/quantize/pq.hpp @@ -77,7 +77,7 @@ struct params { }; /** - * @brief Defines and stores VPQ codebooks upon training + * @brief Defines and stores VPQ codebooks upon training (owning type) * * @tparam T data element type * @@ -88,6 +88,22 @@ struct quantizer { cuvs::neighbors::vpq_dataset vpq_codebooks; }; +/** + * @brief View-type quantizer that references pre-computed codebooks without copying. + * + * This structure stores views of pre-computed VQ and PQ codebooks for inference. + * The caller is responsible for ensuring the lifetime of the underlying data + * exceeds the lifetime of this view. + * + * @tparam T data element type + * + */ +template +struct quantizer_view { + params params_quantizer; + cuvs::neighbors::vpq_codebooks_view vpq_codebooks; +}; + /** * @brief Initializes a product quantizer to be used later for quantizing the dataset. * @@ -117,6 +133,45 @@ quantizer build(raft::resources const& res, const params params, raft::host_matrix_view dataset); +/** + * @brief Creates a view-type product quantizer from pre-computed codebooks. + * + * This function creates a non-owning quantizer that references the provided device data. + * The caller is responsible for ensuring the lifetime of the input data exceeds + * the lifetime of the returned quantizer_view. + * + * Usage example: + * @code{.cpp} + * raft::handle_t handle; + * // Assume pq_centers and vq_centers are pre-computed on device + * cuvs::preprocessing::quantize::pq::params params; + * params.pq_bits = 8; + * params.pq_dim = 32; + * params.use_vq = true; + * params.use_subspaces = true; + * auto quant_view = cuvs::preprocessing::quantize::pq::build(handle, params, + * pq_centers_view, vq_centers_view); + * // Use quant_view for transform/inverse_transform operations + * @endcode + * + * @param[in] res raft resource + * @param[in] params configure product quantizer parameters. Must be fully specified + * (pq_bits, pq_dim must be set; use_subspaces and use_vq must match the codebook shapes). + * @param[in] pq_centers PQ codebook on device memory: + * - For use_subspaces=true: [pq_dim * pq_n_centers, pq_len] + * - For use_subspaces=false: [pq_n_centers, pq_len] + * where pq_n_centers = (1 << pq_bits), pq_len = dim / pq_dim + * @param[in] vq_centers VQ codebook on device memory [vq_n_centers, dim]. + * Pass an empty view if use_vq=false. + * + * @return A view-type quantizer that references the provided data + */ +quantizer_view build( + raft::resources const& res, + const params params, + raft::device_matrix_view pq_centers, + raft::device_matrix_view vq_centers); + /** * @brief Applies quantization transform to given dataset * @@ -153,6 +208,28 @@ void transform(raft::resources const& res, raft::device_matrix_view codes_out, std::optional> vq_labels = std::nullopt); +/** + * @brief Applies quantization transform using a view-type quantizer + * + * @param[in] res raft resource + * @param[in] quant a view-type product quantizer + * @param[in] dataset a row-major matrix view on device + * @param[out] codes_out a row-major matrix view on device containing the PQ codes + * @param[out] vq_labels a vector view on device containing the VQ labels when VQ is used, optional + */ +void transform(raft::resources const& res, + const quantizer_view& quant, + raft::device_matrix_view dataset, + raft::device_matrix_view codes_out, + std::optional> vq_labels = std::nullopt); + +/** @copydoc transform */ +void transform(raft::resources const& res, + const quantizer_view& quant, + raft::host_matrix_view dataset, + raft::device_matrix_view codes_out, + std::optional> vq_labels = std::nullopt); + /** * @brief Get the dimension of the quantized dataset (in bytes) * @@ -181,6 +258,23 @@ void inverse_transform( raft::device_matrix_view out, std::optional> vq_labels = std::nullopt); +/** + * @brief Applies inverse quantization transform using a view-type quantizer + * + * @param[in] res raft resource + * @param[in] quant a view-type product quantizer + * @param[in] pq_codes a row-major matrix view on device containing the PQ codes + * @param[out] out a row-major matrix view on device + * @param[in] vq_labels a vector view on device containing the VQ labels when VQ is used, optional + * + */ +void inverse_transform( + raft::resources const& res, + const quantizer_view& quant, + raft::device_matrix_view pq_codes, + raft::device_matrix_view out, + std::optional> vq_labels = std::nullopt); + /** @} */ // end of group product } // namespace cuvs::preprocessing::quantize::pq diff --git a/cpp/src/preprocessing/quantize/detail/pq.cuh b/cpp/src/preprocessing/quantize/detail/pq.cuh index db7cc8d5c1..d7ffac608c 100644 --- a/cpp/src/preprocessing/quantize/detail/pq.cuh +++ b/cpp/src/preprocessing/quantize/detail/pq.cuh @@ -328,4 +328,139 @@ void inverse_transform( out, quant.params_quantizer.use_subspaces); } + +// ============= View-type quantizer functions ============= + +/** + * @brief Creates a view-type quantizer from pre-computed codebooks. + */ +template +quantizer_view build_view( + raft::resources const& res, + const params& params, + raft::device_matrix_view pq_centers, + raft::device_matrix_view vq_centers) +{ + // Validate parameters + RAFT_EXPECTS(params.pq_bits >= 4 && params.pq_bits <= 16, + "PQ bits must be within [4, 16], got %u", + params.pq_bits); + RAFT_EXPECTS(params.pq_dim > 0, "pq_dim must be specified for view-type quantizer"); + + const uint32_t pq_n_centers = 1u << params.pq_bits; + + // Validate PQ centers shape + if (params.use_subspaces) { + RAFT_EXPECTS(pq_centers.extent(0) == params.pq_dim * pq_n_centers, + "For use_subspaces=true, pq_centers must have shape [pq_dim * pq_n_centers, " + "pq_len], got [%u, %u]", + pq_centers.extent(0), + pq_centers.extent(1)); + } else { + RAFT_EXPECTS(pq_centers.extent(0) == pq_n_centers, + "For use_subspaces=false, pq_centers must have shape [pq_n_centers, pq_len], got " + "[%u, %u]", + pq_centers.extent(0), + pq_centers.extent(1)); + } + + // Validate VQ centers + if (params.use_vq) { + RAFT_EXPECTS(!vq_centers.empty(), + "vq_centers must be provided when use_vq=true"); + RAFT_EXPECTS(vq_centers.extent(0) == params.vq_n_centers, + "vq_centers must have vq_n_centers rows, got %u", + vq_centers.extent(0)); + } + + return quantizer_view{ + params, cuvs::neighbors::vpq_codebooks_view{vq_centers, pq_centers}}; +} + +/** + * @brief Applies quantization transform using a view-type quantizer. + */ +template +void transform( + raft::resources const& res, + const quantizer_view& quantizer, + raft::mdspan, raft::row_major, AccessorType> dataset, + raft::device_matrix_view pq_codes_out, + std::optional> vq_labels = std::nullopt) +{ + raft::common::nvtx::range fun_scope( + "preprocessing::quantize::pq::transform_view(%zu, %zu, %zu)", + size_t(dataset.extent(0)), + size_t(dataset.extent(1)), + size_t(pq_codes_out.extent(1))); + RAFT_EXPECTS(pq_codes_out.extent(0) == dataset.extent(0), + "Output matrix must have the same number of rows as the input dataset"); + RAFT_EXPECTS(pq_codes_out.extent(1) == get_quantized_dim(quantizer.params_quantizer), + "Output matrix doesn't have the correct number of columns"); + RAFT_EXPECTS(quantizer.params_quantizer.pq_bits >= 4 && quantizer.params_quantizer.pq_bits <= 16, + "PQ bits must be within [4, 16]"); + + // Use the views directly + auto vq_centers = quantizer.vpq_codebooks.vq_code_book; + auto pq_centers = quantizer.vpq_codebooks.pq_code_book; + auto vq_labels_view = raft::make_device_vector_view(nullptr, 0); + if (vq_labels.has_value()) { vq_labels_view = vq_labels.value(); } + + if (quantizer.params_quantizer.use_subspaces) { + cuvs::neighbors::detail::process_and_fill_codes_subspaces( + res, + to_vpq_params(quantizer.params_quantizer), + dataset, + pq_centers, + vq_centers, + vq_labels_view, + pq_codes_out); + } else { + cuvs::neighbors::detail::process_and_fill_codes( + res, + to_vpq_params(quantizer.params_quantizer), + dataset, + pq_centers, + vq_centers, + vq_labels_view, + pq_codes_out); + } +} + +/** + * @brief Applies inverse quantization transform using a view-type quantizer. + */ +template +void inverse_transform( + raft::resources const& res, + const quantizer_view& quant, + raft::device_matrix_view codes, + raft::device_matrix_view out, + std::optional> vq_labels = std::nullopt) +{ + raft::common::nvtx::range fun_scope( + "preprocessing::quantize::pq::inverse_transform_view(%zu, %zu, %zu)", + size_t(codes.extent(0)), + size_t(codes.extent(1)), + size_t(out.extent(1))); + using label_t = uint32_t; + using idx_t = int64_t; + RAFT_EXPECTS(out.extent(0) == codes.extent(0), + "Output matrix must have the same number of rows as the input codes"); + RAFT_EXPECTS(codes.extent(1) == get_quantized_dim(quant.params_quantizer), + "Codes matrix doesn't have the correct number of columns"); + RAFT_EXPECTS(quant.params_quantizer.pq_bits >= 4 && quant.params_quantizer.pq_bits <= 16, + "PQ bits must be within [4, 16]"); + + // Use the views directly + reconstruct_vectors(res, + quant.params_quantizer, + codes, + quant.vpq_codebooks.pq_code_book, + quant.vpq_codebooks.vq_code_book, + vq_labels, + out, + quant.params_quantizer.use_subspaces); +} + } // namespace cuvs::preprocessing::quantize::pq::detail diff --git a/cpp/src/preprocessing/quantize/pq.cu b/cpp/src/preprocessing/quantize/pq.cu index 4a381c59ca..7496ea186c 100644 --- a/cpp/src/preprocessing/quantize/pq.cu +++ b/cpp/src/preprocessing/quantize/pq.cu @@ -52,4 +52,44 @@ CUVS_INST_QUANTIZATION(float, uint8_t); #undef CUVS_INST_QUANTIZATION +// View-type quantizer instantiations +#define CUVS_INST_QUANTIZATION_VIEW(T, QuantI) \ + auto build(raft::resources const& res, \ + const params params, \ + raft::device_matrix_view pq_centers, \ + raft::device_matrix_view vq_centers) \ + ->quantizer_view \ + { \ + return detail::build_view(res, params, pq_centers, vq_centers); \ + } \ + void transform(raft::resources const& res, \ + const quantizer_view& quantizer, \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view codes_out, \ + std::optional> vq_labels) \ + { \ + detail::transform(res, quantizer, dataset, codes_out, vq_labels); \ + } \ + void transform(raft::resources const& res, \ + const quantizer_view& quantizer, \ + raft::host_matrix_view dataset, \ + raft::device_matrix_view codes_out, \ + std::optional> vq_labels) \ + { \ + detail::transform(res, quantizer, dataset, codes_out, vq_labels); \ + } \ + void inverse_transform( \ + raft::resources const& res, \ + const quantizer_view& quantizer, \ + raft::device_matrix_view pq_codes, \ + raft::device_matrix_view out, \ + std::optional> vq_labels) \ + { \ + detail::inverse_transform(res, quantizer, pq_codes, out, vq_labels); \ + } + +CUVS_INST_QUANTIZATION_VIEW(float, uint8_t); + +#undef CUVS_INST_QUANTIZATION_VIEW + } // namespace cuvs::preprocessing::quantize::pq diff --git a/cpp/tests/preprocessing/product_quantization.cu b/cpp/tests/preprocessing/product_quantization.cu index d8a83e747a..1ce7437527 100644 --- a/cpp/tests/preprocessing/product_quantization.cu +++ b/cpp/tests/preprocessing/product_quantization.cu @@ -372,4 +372,139 @@ INSTANTIATE_TEST_CASE_P(ProductQuantizationTests, ProductQuantizationTestF, ::testing::ValuesIn(inputs)); +// Test for view-type quantizer +class ProductQuantizationViewTest : public ::testing::Test { + public: + ProductQuantizationViewTest() + : handle{}, + stream{raft::resource::get_cuda_stream(handle)}, + n_samples_{1000}, + n_features_{128}, + pq_bits_{8}, + pq_dim_{32}, + dataset_{raft::make_device_matrix(handle, n_samples_, n_features_)} + { + } + + protected: + void SetUp() override + { + // Generate random dataset + auto labels = raft::make_device_vector(handle, n_samples_); + raft::random::make_blobs(handle, + dataset_.view(), + labels.view(), + 5, + std::nullopt, + std::nullopt, + 1.0f, + true, + -10.0f, + 10.0f, + 42ULL); + raft::resource::sync_stream(handle); + } + + void testViewQuantizerProducesSameResultsAsOwning() + { + // Build owning quantizer + params config{pq_bits_, pq_dim_, true /* use_subspaces */, false /* use_vq */}; + auto owning_quant = build(handle, config, raft::make_const_mdspan(dataset_.view())); + + // Extract codebook views from owning quantizer + auto pq_centers_view = raft::make_device_matrix_view( + owning_quant.vpq_codebooks.pq_code_book.data_handle(), + owning_quant.vpq_codebooks.pq_code_book.extent(0), + owning_quant.vpq_codebooks.pq_code_book.extent(1)); + auto vq_centers_view = raft::make_device_matrix_view( + owning_quant.vpq_codebooks.vq_code_book.data_handle(), + owning_quant.vpq_codebooks.vq_code_book.extent(0), + owning_quant.vpq_codebooks.vq_code_book.extent(1)); + + // Create view-type quantizer from the same codebooks + auto view_quant = build(handle, owning_quant.params_quantizer, pq_centers_view, vq_centers_view); + + // Transform using owning quantizer + auto n_encoded_cols = get_quantized_dim(owning_quant.params_quantizer); + auto codes_owning = raft::make_device_matrix(handle, n_samples_, n_encoded_cols); + transform(handle, + owning_quant, + raft::make_const_mdspan(dataset_.view()), + codes_owning.view(), + std::nullopt); + + // Transform using view-type quantizer + auto codes_view = raft::make_device_matrix(handle, n_samples_, n_encoded_cols); + transform(handle, + view_quant, + raft::make_const_mdspan(dataset_.view()), + codes_view.view(), + std::nullopt); + + raft::resource::sync_stream(handle); + + // Compare results - should be identical + auto h_codes_owning = raft::make_host_matrix(n_samples_, n_encoded_cols); + auto h_codes_view = raft::make_host_matrix(n_samples_, n_encoded_cols); + raft::copy(h_codes_owning.data_handle(), + codes_owning.data_handle(), + n_samples_ * n_encoded_cols, + stream); + raft::copy( + h_codes_view.data_handle(), codes_view.data_handle(), n_samples_ * n_encoded_cols, stream); + raft::resource::sync_stream(handle); + + for (int64_t i = 0; i < n_samples_ * n_encoded_cols; i++) { + ASSERT_EQ(h_codes_owning.data_handle()[i], h_codes_view.data_handle()[i]) + << "Mismatch at index " << i; + } + + // Test inverse_transform + auto reconstructed_owning = raft::make_device_matrix(handle, n_samples_, n_features_); + auto reconstructed_view = raft::make_device_matrix(handle, n_samples_, n_features_); + + inverse_transform(handle, + owning_quant, + raft::make_const_mdspan(codes_owning.view()), + reconstructed_owning.view(), + std::nullopt); + inverse_transform(handle, + view_quant, + raft::make_const_mdspan(codes_view.view()), + reconstructed_view.view(), + std::nullopt); + + raft::resource::sync_stream(handle); + + // Compare reconstructions + auto h_rec_owning = raft::make_host_matrix(n_samples_, n_features_); + auto h_rec_view = raft::make_host_matrix(n_samples_, n_features_); + raft::copy(h_rec_owning.data_handle(), + reconstructed_owning.data_handle(), + n_samples_ * n_features_, + stream); + raft::copy( + h_rec_view.data_handle(), reconstructed_view.data_handle(), n_samples_ * n_features_, stream); + raft::resource::sync_stream(handle); + + for (int64_t i = 0; i < n_samples_ * n_features_; i++) { + ASSERT_FLOAT_EQ(h_rec_owning.data_handle()[i], h_rec_view.data_handle()[i]) + << "Reconstruction mismatch at index " << i; + } + } + + raft::resources handle; + cudaStream_t stream; + int64_t n_samples_; + int64_t n_features_; + uint32_t pq_bits_; + uint32_t pq_dim_; + raft::device_matrix dataset_; +}; + +TEST_F(ProductQuantizationViewTest, ViewQuantizerProducesSameResults) +{ + testViewQuantizerProducesSameResultsAsOwning(); +} + } // namespace cuvs::preprocessing::quantize::pq From 4c6182cb181bbc2647861b4c9a1f25d25ff6849a Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Tue, 3 Feb 2026 04:34:55 -0800 Subject: [PATCH 02/36] update vpq_dataset --- cpp/include/cuvs/neighbors/common.hpp | 146 ++++++------ .../cuvs/preprocessing/quantize/pq.hpp | 69 +----- cpp/src/preprocessing/quantize/detail/pq.cuh | 210 ++++++------------ cpp/src/preprocessing/quantize/pq.cu | 48 +--- .../preprocessing/product_quantization.cu | 33 +-- 5 files changed, 178 insertions(+), 328 deletions(-) diff --git a/cpp/include/cuvs/neighbors/common.hpp b/cpp/include/cuvs/neighbors/common.hpp index 92c0f38e1a..3dbfa1b9df 100644 --- a/cpp/include/cuvs/neighbors/common.hpp +++ b/cpp/include/cuvs/neighbors/common.hpp @@ -405,6 +405,10 @@ auto make_aligned_dataset(const raft::resources& res, SrcT src, uint32_t align_b * 1. Vector Quantization * 2. Product Quantization of residuals * + * This struct can be either owning or non-owning (view-type): + * - Owning: owns the codebooks and data matrices (created via move constructor) + * - View: stores views to externally-owned data (created via view constructor) + * * @tparam MathT the type of elements in the codebooks * @tparam IdxT type of the vector indices (represent dataset.extent(0)) * @@ -413,35 +417,86 @@ template struct vpq_dataset : public dataset { using index_type = IdxT; using math_type = MathT; - /** Vector Quantization codebook - "coarse cluster centers". */ - raft::device_matrix vq_code_book; - /** Product Quantization codebook - "fine cluster centers". */ - raft::device_matrix pq_code_book; - /** Compressed dataset. */ - raft::device_matrix data; + /** + * @brief Construct an owning vpq_dataset by moving in the codebooks and data. + */ vpq_dataset(raft::device_matrix&& vq_code_book, raft::device_matrix&& pq_code_book, raft::device_matrix&& data) - : vq_code_book{std::move(vq_code_book)}, - pq_code_book{std::move(pq_code_book)}, - data{std::move(data)} + : vq_code_book_owned_{std::move(vq_code_book)}, + pq_code_book_owned_{std::move(pq_code_book)}, + data_owned_{std::move(data)}, + vq_code_book_view_{vq_code_book_owned_.view()}, + pq_code_book_view_{pq_code_book_owned_.view()}, + data_view_{data_owned_.view()}, + is_owning_{true} { } - [[nodiscard]] auto n_rows() const noexcept -> index_type final { return data.extent(0); } - [[nodiscard]] auto dim() const noexcept -> uint32_t final { return vq_code_book.extent(1); } - [[nodiscard]] auto is_owning() const noexcept -> bool final { return true; } + /** + * @brief Construct a view-type vpq_dataset from external codebook views. + * + * The caller must ensure the lifetime of the underlying data exceeds the lifetime of this object. + * + * @param vq_code_book_view View of VQ codebook [vq_n_centers, dim] + * @param pq_code_book_view View of PQ codebook [pq_dim * pq_n_centers, pq_len] or [pq_n_centers, + * pq_len] + * @param data_view View of compressed data (can be empty for quantizer-only use) + */ + vpq_dataset(raft::device_matrix_view vq_code_book_view, + raft::device_matrix_view pq_code_book_view, + raft::device_matrix_view data_view = + raft::device_matrix_view{}) + : vq_code_book_owned_{}, + pq_code_book_owned_{}, + data_owned_{}, + vq_code_book_view_{vq_code_book_view}, + pq_code_book_view_{pq_code_book_view}, + data_view_{data_view}, + is_owning_{false} + { + } + + vpq_dataset(const vpq_dataset&) = delete; + vpq_dataset& operator=(const vpq_dataset&) = delete; + vpq_dataset(vpq_dataset&&) = default; + vpq_dataset& operator=(vpq_dataset&&) = default; + + [[nodiscard]] auto n_rows() const noexcept -> index_type final { return data_view_.extent(0); } + [[nodiscard]] auto dim() const noexcept -> uint32_t final { return vq_code_book_view_.extent(1); } + [[nodiscard]] auto is_owning() const noexcept -> bool final { return is_owning_; } + + /** Get view of VQ codebook. */ + [[nodiscard]] auto vq_code_book() const noexcept + -> raft::device_matrix_view + { + return vq_code_book_view_; + } + + /** Get view of PQ codebook. */ + [[nodiscard]] auto pq_code_book() const noexcept + -> raft::device_matrix_view + { + return pq_code_book_view_; + } + + /** Get view of compressed data. */ + [[nodiscard]] auto data() const noexcept + -> raft::device_matrix_view + { + return data_view_; + } /** Row length of the encoded data in bytes. */ [[nodiscard]] constexpr inline auto encoded_row_length() const noexcept -> uint32_t { - return data.extent(1); + return data_view_.extent(1); } /** The number of "coarse cluster centers" */ [[nodiscard]] constexpr inline auto vq_n_centers() const noexcept -> uint32_t { - return vq_code_book.extent(0); + return vq_code_book_view_.extent(0); } /** The bit length of an encoded vector element after compression by PQ. */ [[nodiscard]] constexpr inline auto pq_bits() const noexcept -> uint32_t @@ -474,65 +529,26 @@ struct vpq_dataset : public dataset { /** Dimensionality of a subspaces, i.e. the number of vector components mapped to a subspace */ [[nodiscard]] constexpr inline auto pq_len() const noexcept -> uint32_t { - return pq_code_book.extent(1); + return pq_code_book_view_.extent(1); } /** The number of vectors in a PQ codebook (`1 << pq_bits`). */ [[nodiscard]] constexpr inline auto pq_n_centers() const noexcept -> uint32_t { - return pq_code_book.extent(0); + return pq_code_book_view_.extent(0); } -}; -/** - * @brief View-type VPQ codebooks (non-owning). - * - * This structure stores views of pre-computed VQ and PQ codebooks without copying. - * The caller is responsible for ensuring the lifetime of the underlying data - * exceeds the lifetime of this view. - * - * @tparam MathT the type of elements in the codebooks - * - */ -template -struct vpq_codebooks_view { - using math_type = MathT; - /** View of Vector Quantization codebook - "coarse cluster centers" [vq_n_centers, dim]. */ - raft::device_matrix_view vq_code_book; - /** View of Product Quantization codebook - "fine cluster centers". - * - For use_subspaces=true: [pq_dim * pq_n_centers, pq_len] - * - For use_subspaces=false: [pq_n_centers, pq_len] - */ - raft::device_matrix_view pq_code_book; + private: + // Owning storage (empty when is_owning_ == false) + raft::device_matrix vq_code_book_owned_; + raft::device_matrix pq_code_book_owned_; + raft::device_matrix data_owned_; - vpq_codebooks_view( - raft::device_matrix_view vq_code_book, - raft::device_matrix_view pq_code_book) - : vq_code_book{vq_code_book}, pq_code_book{pq_code_book} - { - } - - [[nodiscard]] auto is_owning() const noexcept -> bool { return false; } + // Views (always valid - either point to owned data or external data) + raft::device_matrix_view vq_code_book_view_; + raft::device_matrix_view pq_code_book_view_; + raft::device_matrix_view data_view_; - /** Dimensionality of the original vectors. */ - [[nodiscard]] constexpr inline auto dim() const noexcept -> uint32_t - { - return vq_code_book.extent(1); - } - /** The number of "coarse cluster centers" */ - [[nodiscard]] constexpr inline auto vq_n_centers() const noexcept -> uint32_t - { - return vq_code_book.extent(0); - } - /** Dimensionality of a subspace, i.e. the number of vector components mapped to a subspace */ - [[nodiscard]] constexpr inline auto pq_len() const noexcept -> uint32_t - { - return pq_code_book.extent(1); - } - /** The number of vectors in a PQ codebook (`1 << pq_bits`). */ - [[nodiscard]] constexpr inline auto pq_n_centers() const noexcept -> uint32_t - { - return pq_code_book.extent(0); - } + bool is_owning_; }; template diff --git a/cpp/include/cuvs/preprocessing/quantize/pq.hpp b/cpp/include/cuvs/preprocessing/quantize/pq.hpp index 56ff922569..29ecdd0db7 100644 --- a/cpp/include/cuvs/preprocessing/quantize/pq.hpp +++ b/cpp/include/cuvs/preprocessing/quantize/pq.hpp @@ -77,7 +77,9 @@ struct params { }; /** - * @brief Defines and stores VPQ codebooks upon training (owning type) + * @brief Defines and stores VPQ codebooks upon training. + * + * Can be either owning (trained from data) or non-owning (view of external codebooks). * * @tparam T data element type * @@ -89,23 +91,7 @@ struct quantizer { }; /** - * @brief View-type quantizer that references pre-computed codebooks without copying. - * - * This structure stores views of pre-computed VQ and PQ codebooks for inference. - * The caller is responsible for ensuring the lifetime of the underlying data - * exceeds the lifetime of this view. - * - * @tparam T data element type - * - */ -template -struct quantizer_view { - params params_quantizer; - cuvs::neighbors::vpq_codebooks_view vpq_codebooks; -}; - -/** - * @brief Initializes a product quantizer to be used later for quantizing the dataset. + * @brief Initializes a product quantizer by training on the dataset (owning). * * The use of a pool memory resource is recommended for more consistent training performance. * @@ -119,10 +105,10 @@ struct quantizer_view { * @endcode * * @param[in] res raft resource - * @param[in] params configure product quantizer, e.g. quantile + * @param[in] params configure product quantizer, e.g. pq_bits, pq_dim * @param[in] dataset a row-major matrix view on device or host * - * @return quantizer + * @return quantizer (owning) */ quantizer build(raft::resources const& res, const params params, @@ -138,7 +124,7 @@ quantizer build(raft::resources const& res, * * This function creates a non-owning quantizer that references the provided device data. * The caller is responsible for ensuring the lifetime of the input data exceeds - * the lifetime of the returned quantizer_view. + * the lifetime of the returned quantizer. * * Usage example: * @code{.cpp} @@ -166,7 +152,7 @@ quantizer build(raft::resources const& res, * * @return A view-type quantizer that references the provided data */ -quantizer_view build( +quantizer build( raft::resources const& res, const params params, raft::device_matrix_view pq_centers, @@ -208,28 +194,6 @@ void transform(raft::resources const& res, raft::device_matrix_view codes_out, std::optional> vq_labels = std::nullopt); -/** - * @brief Applies quantization transform using a view-type quantizer - * - * @param[in] res raft resource - * @param[in] quant a view-type product quantizer - * @param[in] dataset a row-major matrix view on device - * @param[out] codes_out a row-major matrix view on device containing the PQ codes - * @param[out] vq_labels a vector view on device containing the VQ labels when VQ is used, optional - */ -void transform(raft::resources const& res, - const quantizer_view& quant, - raft::device_matrix_view dataset, - raft::device_matrix_view codes_out, - std::optional> vq_labels = std::nullopt); - -/** @copydoc transform */ -void transform(raft::resources const& res, - const quantizer_view& quant, - raft::host_matrix_view dataset, - raft::device_matrix_view codes_out, - std::optional> vq_labels = std::nullopt); - /** * @brief Get the dimension of the quantized dataset (in bytes) * @@ -258,23 +222,6 @@ void inverse_transform( raft::device_matrix_view out, std::optional> vq_labels = std::nullopt); -/** - * @brief Applies inverse quantization transform using a view-type quantizer - * - * @param[in] res raft resource - * @param[in] quant a view-type product quantizer - * @param[in] pq_codes a row-major matrix view on device containing the PQ codes - * @param[out] out a row-major matrix view on device - * @param[in] vq_labels a vector view on device containing the VQ labels when VQ is used, optional - * - */ -void inverse_transform( - raft::resources const& res, - const quantizer_view& quant, - raft::device_matrix_view pq_codes, - raft::device_matrix_view out, - std::optional> vq_labels = std::nullopt); - /** @} */ // end of group product } // namespace cuvs::preprocessing::quantize::pq diff --git a/cpp/src/preprocessing/quantize/detail/pq.cuh b/cpp/src/preprocessing/quantize/detail/pq.cuh index d7ffac608c..867a6fce26 100644 --- a/cpp/src/preprocessing/quantize/detail/pq.cuh +++ b/cpp/src/preprocessing/quantize/detail/pq.cuh @@ -120,6 +120,9 @@ auto train_pq_subspaces( return pq_centers; } +/** + * @brief Build an owning quantizer by training on a dataset. + */ template quantizer build( raft::resources const& res, @@ -159,10 +162,56 @@ quantizer build( std::move(vq_code_book), std::move(pq_code_book), std::move(empty_codes)}}; } +/** + * @brief Build a view-type quantizer from pre-computed codebooks. + */ +template +quantizer build_view( + raft::resources const& res, + const params& params, + raft::device_matrix_view pq_centers, + raft::device_matrix_view vq_centers) +{ + // Validate parameters + RAFT_EXPECTS(params.pq_bits >= 4 && params.pq_bits <= 16, + "PQ bits must be within [4, 16], got %u", + params.pq_bits); + RAFT_EXPECTS(params.pq_dim > 0, "pq_dim must be specified for view-type quantizer"); + + const uint32_t pq_n_centers = 1u << params.pq_bits; + + // Validate PQ centers shape + if (params.use_subspaces) { + RAFT_EXPECTS(pq_centers.extent(0) == params.pq_dim * pq_n_centers, + "For use_subspaces=true, pq_centers must have shape [pq_dim * pq_n_centers, " + "pq_len], got [%u, %u]", + pq_centers.extent(0), + pq_centers.extent(1)); + } else { + RAFT_EXPECTS(pq_centers.extent(0) == pq_n_centers, + "For use_subspaces=false, pq_centers must have shape [pq_n_centers, pq_len], got " + "[%u, %u]", + pq_centers.extent(0), + pq_centers.extent(1)); + } + + // Validate VQ centers + if (params.use_vq) { + RAFT_EXPECTS(!vq_centers.empty(), "vq_centers must be provided when use_vq=true"); + RAFT_EXPECTS(vq_centers.extent(0) == params.vq_n_centers, + "vq_centers must have vq_n_centers rows, got %u", + vq_centers.extent(0)); + } + + // Create view-type vpq_dataset + auto empty_data = raft::device_matrix_view{}; + return {params, cuvs::neighbors::vpq_dataset{vq_centers, pq_centers, empty_data}}; +} + template void transform( raft::resources const& res, - const quantizer& quantizer, + const quantizer& quant, raft::mdspan, raft::row_major, AccessorType> dataset, raft::device_matrix_view pq_codes_out, std::optional> vq_labels = std::nullopt) @@ -174,30 +223,32 @@ void transform( size_t(pq_codes_out.extent(1))); RAFT_EXPECTS(pq_codes_out.extent(0) == dataset.extent(0), "Output matrix must have the same number of rows as the input dataset"); - RAFT_EXPECTS(pq_codes_out.extent(1) == get_quantized_dim(quantizer.params_quantizer), + RAFT_EXPECTS(pq_codes_out.extent(1) == get_quantized_dim(quant.params_quantizer), "Output matrix doesn't have the correct number of columns"); - RAFT_EXPECTS(quantizer.params_quantizer.pq_bits >= 4 && quantizer.params_quantizer.pq_bits <= 16, + RAFT_EXPECTS(quant.params_quantizer.pq_bits >= 4 && quant.params_quantizer.pq_bits <= 16, "PQ bits must be within [4, 16]"); - // Encode dataset - auto vq_centers = raft::make_const_mdspan(quantizer.vpq_codebooks.vq_code_book.view()); + + // Use view accessors from vpq_dataset + auto vq_centers = quant.vpq_codebooks.vq_code_book(); + auto pq_centers = quant.vpq_codebooks.pq_code_book(); auto vq_labels_view = raft::make_device_vector_view(nullptr, 0); if (vq_labels.has_value()) { vq_labels_view = vq_labels.value(); } - if (quantizer.params_quantizer.use_subspaces) { + if (quant.params_quantizer.use_subspaces) { cuvs::neighbors::detail::process_and_fill_codes_subspaces( res, - to_vpq_params(quantizer.params_quantizer), + to_vpq_params(quant.params_quantizer), dataset, - raft::make_const_mdspan(quantizer.vpq_codebooks.pq_code_book.view()), + pq_centers, vq_centers, vq_labels_view, pq_codes_out); } else { cuvs::neighbors::detail::process_and_fill_codes( res, - to_vpq_params(quantizer.params_quantizer), + to_vpq_params(quant.params_quantizer), dataset, - raft::make_const_mdspan(quantizer.vpq_codebooks.pq_code_book.view()), + pq_centers, vq_centers, vq_labels_view, pq_codes_out); @@ -312,139 +363,6 @@ void inverse_transform( size_t(out.extent(1))); using label_t = uint32_t; using idx_t = int64_t; - RAFT_EXPECTS(out.extent(0) == codes.extent(0), - "Output matrix must have the same number of rows as the input codes"); - RAFT_EXPECTS(codes.extent(1) == get_quantized_dim(quant.params_quantizer), - "Codes matrix doesn't have the correct number of columns"); - RAFT_EXPECTS(quant.params_quantizer.pq_bits >= 4 && quant.params_quantizer.pq_bits <= 16, - "PQ bits must be within [4, 16]"); - reconstruct_vectors( - res, - quant.params_quantizer, - codes, - raft::make_const_mdspan(quant.vpq_codebooks.pq_code_book.view()), - raft::make_const_mdspan(quant.vpq_codebooks.vq_code_book.view()), - vq_labels, - out, - quant.params_quantizer.use_subspaces); -} - -// ============= View-type quantizer functions ============= - -/** - * @brief Creates a view-type quantizer from pre-computed codebooks. - */ -template -quantizer_view build_view( - raft::resources const& res, - const params& params, - raft::device_matrix_view pq_centers, - raft::device_matrix_view vq_centers) -{ - // Validate parameters - RAFT_EXPECTS(params.pq_bits >= 4 && params.pq_bits <= 16, - "PQ bits must be within [4, 16], got %u", - params.pq_bits); - RAFT_EXPECTS(params.pq_dim > 0, "pq_dim must be specified for view-type quantizer"); - - const uint32_t pq_n_centers = 1u << params.pq_bits; - - // Validate PQ centers shape - if (params.use_subspaces) { - RAFT_EXPECTS(pq_centers.extent(0) == params.pq_dim * pq_n_centers, - "For use_subspaces=true, pq_centers must have shape [pq_dim * pq_n_centers, " - "pq_len], got [%u, %u]", - pq_centers.extent(0), - pq_centers.extent(1)); - } else { - RAFT_EXPECTS(pq_centers.extent(0) == pq_n_centers, - "For use_subspaces=false, pq_centers must have shape [pq_n_centers, pq_len], got " - "[%u, %u]", - pq_centers.extent(0), - pq_centers.extent(1)); - } - - // Validate VQ centers - if (params.use_vq) { - RAFT_EXPECTS(!vq_centers.empty(), - "vq_centers must be provided when use_vq=true"); - RAFT_EXPECTS(vq_centers.extent(0) == params.vq_n_centers, - "vq_centers must have vq_n_centers rows, got %u", - vq_centers.extent(0)); - } - - return quantizer_view{ - params, cuvs::neighbors::vpq_codebooks_view{vq_centers, pq_centers}}; -} - -/** - * @brief Applies quantization transform using a view-type quantizer. - */ -template -void transform( - raft::resources const& res, - const quantizer_view& quantizer, - raft::mdspan, raft::row_major, AccessorType> dataset, - raft::device_matrix_view pq_codes_out, - std::optional> vq_labels = std::nullopt) -{ - raft::common::nvtx::range fun_scope( - "preprocessing::quantize::pq::transform_view(%zu, %zu, %zu)", - size_t(dataset.extent(0)), - size_t(dataset.extent(1)), - size_t(pq_codes_out.extent(1))); - RAFT_EXPECTS(pq_codes_out.extent(0) == dataset.extent(0), - "Output matrix must have the same number of rows as the input dataset"); - RAFT_EXPECTS(pq_codes_out.extent(1) == get_quantized_dim(quantizer.params_quantizer), - "Output matrix doesn't have the correct number of columns"); - RAFT_EXPECTS(quantizer.params_quantizer.pq_bits >= 4 && quantizer.params_quantizer.pq_bits <= 16, - "PQ bits must be within [4, 16]"); - - // Use the views directly - auto vq_centers = quantizer.vpq_codebooks.vq_code_book; - auto pq_centers = quantizer.vpq_codebooks.pq_code_book; - auto vq_labels_view = raft::make_device_vector_view(nullptr, 0); - if (vq_labels.has_value()) { vq_labels_view = vq_labels.value(); } - - if (quantizer.params_quantizer.use_subspaces) { - cuvs::neighbors::detail::process_and_fill_codes_subspaces( - res, - to_vpq_params(quantizer.params_quantizer), - dataset, - pq_centers, - vq_centers, - vq_labels_view, - pq_codes_out); - } else { - cuvs::neighbors::detail::process_and_fill_codes( - res, - to_vpq_params(quantizer.params_quantizer), - dataset, - pq_centers, - vq_centers, - vq_labels_view, - pq_codes_out); - } -} - -/** - * @brief Applies inverse quantization transform using a view-type quantizer. - */ -template -void inverse_transform( - raft::resources const& res, - const quantizer_view& quant, - raft::device_matrix_view codes, - raft::device_matrix_view out, - std::optional> vq_labels = std::nullopt) -{ - raft::common::nvtx::range fun_scope( - "preprocessing::quantize::pq::inverse_transform_view(%zu, %zu, %zu)", - size_t(codes.extent(0)), - size_t(codes.extent(1)), - size_t(out.extent(1))); - using label_t = uint32_t; - using idx_t = int64_t; RAFT_EXPECTS(out.extent(0) == codes.extent(0), "Output matrix must have the same number of rows as the input codes"); RAFT_EXPECTS(codes.extent(1) == get_quantized_dim(quant.params_quantizer), @@ -452,12 +370,12 @@ void inverse_transform( RAFT_EXPECTS(quant.params_quantizer.pq_bits >= 4 && quant.params_quantizer.pq_bits <= 16, "PQ bits must be within [4, 16]"); - // Use the views directly + // Use view accessors from vpq_dataset reconstruct_vectors(res, quant.params_quantizer, codes, - quant.vpq_codebooks.pq_code_book, - quant.vpq_codebooks.vq_code_book, + quant.vpq_codebooks.pq_code_book(), + quant.vpq_codebooks.vq_code_book(), vq_labels, out, quant.params_quantizer.use_subspaces); diff --git a/cpp/src/preprocessing/quantize/pq.cu b/cpp/src/preprocessing/quantize/pq.cu index 7496ea186c..c01e832b4a 100644 --- a/cpp/src/preprocessing/quantize/pq.cu +++ b/cpp/src/preprocessing/quantize/pq.cu @@ -22,6 +22,14 @@ namespace cuvs::preprocessing::quantize::pq { { \ return detail::build(res, params, dataset); \ } \ + auto build(raft::resources const& res, \ + const params params, \ + raft::device_matrix_view pq_centers, \ + raft::device_matrix_view vq_centers) \ + ->quantizer \ + { \ + return detail::build_view(res, params, pq_centers, vq_centers); \ + } \ void transform(raft::resources const& res, \ const quantizer& quantizer, \ raft::device_matrix_view dataset, \ @@ -52,44 +60,4 @@ CUVS_INST_QUANTIZATION(float, uint8_t); #undef CUVS_INST_QUANTIZATION -// View-type quantizer instantiations -#define CUVS_INST_QUANTIZATION_VIEW(T, QuantI) \ - auto build(raft::resources const& res, \ - const params params, \ - raft::device_matrix_view pq_centers, \ - raft::device_matrix_view vq_centers) \ - ->quantizer_view \ - { \ - return detail::build_view(res, params, pq_centers, vq_centers); \ - } \ - void transform(raft::resources const& res, \ - const quantizer_view& quantizer, \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view codes_out, \ - std::optional> vq_labels) \ - { \ - detail::transform(res, quantizer, dataset, codes_out, vq_labels); \ - } \ - void transform(raft::resources const& res, \ - const quantizer_view& quantizer, \ - raft::host_matrix_view dataset, \ - raft::device_matrix_view codes_out, \ - std::optional> vq_labels) \ - { \ - detail::transform(res, quantizer, dataset, codes_out, vq_labels); \ - } \ - void inverse_transform( \ - raft::resources const& res, \ - const quantizer_view& quantizer, \ - raft::device_matrix_view pq_codes, \ - raft::device_matrix_view out, \ - std::optional> vq_labels) \ - { \ - detail::inverse_transform(res, quantizer, pq_codes, out, vq_labels); \ - } - -CUVS_INST_QUANTIZATION_VIEW(float, uint8_t); - -#undef CUVS_INST_QUANTIZATION_VIEW - } // namespace cuvs::preprocessing::quantize::pq diff --git a/cpp/tests/preprocessing/product_quantization.cu b/cpp/tests/preprocessing/product_quantization.cu index 1ce7437527..aa6f257e19 100644 --- a/cpp/tests/preprocessing/product_quantization.cu +++ b/cpp/tests/preprocessing/product_quantization.cu @@ -131,7 +131,7 @@ class ProductQuantizationTest : public ::testing::TestWithParam& quantizer, + void check_reconstruction(const cuvs::preprocessing::quantize::pq::quantizer& quant, raft::device_matrix_view codes, std::optional> vq_labels, double compression_ratio, @@ -149,7 +149,7 @@ class ProductQuantizationTest : public ::testing::TestWithParam> vq_labels_view = std::nullopt; if (vq_labels) { vq_labels_view = raft::make_const_mdspan(vq_labels.value()); } inverse_transform(handle, - quantizer, + quant, raft::device_matrix_view( codes.data_handle(), n_take, codes.extent(1)), rec_data.view(), @@ -161,7 +161,7 @@ class ProductQuantizationTest : public ::testing::TestWithParam( - owning_quant.vpq_codebooks.pq_code_book.data_handle(), - owning_quant.vpq_codebooks.pq_code_book.extent(0), - owning_quant.vpq_codebooks.pq_code_book.extent(1)); - auto vq_centers_view = raft::make_device_matrix_view( - owning_quant.vpq_codebooks.vq_code_book.data_handle(), - owning_quant.vpq_codebooks.vq_code_book.extent(0), - owning_quant.vpq_codebooks.vq_code_book.extent(1)); + // Extract codebook views from owning quantizer using new accessor methods + auto pq_centers_view = owning_quant.vpq_codebooks.pq_code_book(); + auto vq_centers_view = owning_quant.vpq_codebooks.vq_code_book(); // Create view-type quantizer from the same codebooks auto view_quant = build(handle, owning_quant.params_quantizer, pq_centers_view, vq_centers_view); // Transform using owning quantizer auto n_encoded_cols = get_quantized_dim(owning_quant.params_quantizer); - auto codes_owning = raft::make_device_matrix(handle, n_samples_, n_encoded_cols); + auto codes_owning = raft::make_device_matrix(handle, n_samples_, n_encoded_cols); transform(handle, owning_quant, raft::make_const_mdspan(dataset_.view()), @@ -460,8 +454,10 @@ class ProductQuantizationViewTest : public ::testing::Test { } // Test inverse_transform - auto reconstructed_owning = raft::make_device_matrix(handle, n_samples_, n_features_); - auto reconstructed_view = raft::make_device_matrix(handle, n_samples_, n_features_); + auto reconstructed_owning = + raft::make_device_matrix(handle, n_samples_, n_features_); + auto reconstructed_view = + raft::make_device_matrix(handle, n_samples_, n_features_); inverse_transform(handle, owning_quant, @@ -491,6 +487,11 @@ class ProductQuantizationViewTest : public ::testing::Test { ASSERT_FLOAT_EQ(h_rec_owning.data_handle()[i], h_rec_view.data_handle()[i]) << "Reconstruction mismatch at index " << i; } + + // Verify that the owning quantizer reports is_owning() = true + // and the view quantizer reports is_owning() = false + ASSERT_TRUE(owning_quant.vpq_codebooks.is_owning()); + ASSERT_FALSE(view_quant.vpq_codebooks.is_owning()); } raft::resources handle; From f18e00c3f05507210ba2f0e2080a4e11ab89e4f4 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Tue, 3 Feb 2026 04:52:05 -0800 Subject: [PATCH 03/36] clean pimpl separation --- cpp/include/cuvs/neighbors/common.hpp | 260 ++++++++++++------ .../cuvs/preprocessing/quantize/pq.hpp | 21 +- .../detail/cagra/compute_distance_vpq.hpp | 6 +- cpp/src/neighbors/detail/cagra/factory.cuh | 4 +- .../neighbors/detail/dataset_serialize.hpp | 8 +- .../neighbors/detail/vamana/vamana_build.cuh | 4 +- cpp/src/neighbors/detail/vpq_dataset.cuh | 31 ++- cpp/src/neighbors/vpq_dataset.cuh | 7 +- cpp/src/preprocessing/quantize/detail/pq.cuh | 18 +- .../preprocessing/product_quantization.cu | 13 +- 10 files changed, 242 insertions(+), 130 deletions(-) diff --git a/cpp/include/cuvs/neighbors/common.hpp b/cpp/include/cuvs/neighbors/common.hpp index 3dbfa1b9df..1bbd7751fb 100644 --- a/cpp/include/cuvs/neighbors/common.hpp +++ b/cpp/include/cuvs/neighbors/common.hpp @@ -398,108 +398,54 @@ auto make_aligned_dataset(const raft::resources& res, SrcT src, uint32_t align_b return make_strided_dataset(res, std::forward(src), required_stride); } /** - * @brief VPQ compressed dataset. + * @brief VPQ compressed dataset - base interface. * * The dataset is compressed using two level quantization * * 1. Vector Quantization * 2. Product Quantization of residuals * - * This struct can be either owning or non-owning (view-type): - * - Owning: owns the codebooks and data matrices (created via move constructor) - * - View: stores views to externally-owned data (created via view constructor) + * This is the abstract base class. Use vpq_dataset_owning for owning data + * or vpq_dataset_view for non-owning views. * * @tparam MathT the type of elements in the codebooks * @tparam IdxT type of the vector indices (represent dataset.extent(0)) - * */ template struct vpq_dataset : public dataset { using index_type = IdxT; using math_type = MathT; - /** - * @brief Construct an owning vpq_dataset by moving in the codebooks and data. - */ - vpq_dataset(raft::device_matrix&& vq_code_book, - raft::device_matrix&& pq_code_book, - raft::device_matrix&& data) - : vq_code_book_owned_{std::move(vq_code_book)}, - pq_code_book_owned_{std::move(pq_code_book)}, - data_owned_{std::move(data)}, - vq_code_book_view_{vq_code_book_owned_.view()}, - pq_code_book_view_{pq_code_book_owned_.view()}, - data_view_{data_owned_.view()}, - is_owning_{true} - { - } - - /** - * @brief Construct a view-type vpq_dataset from external codebook views. - * - * The caller must ensure the lifetime of the underlying data exceeds the lifetime of this object. - * - * @param vq_code_book_view View of VQ codebook [vq_n_centers, dim] - * @param pq_code_book_view View of PQ codebook [pq_dim * pq_n_centers, pq_len] or [pq_n_centers, - * pq_len] - * @param data_view View of compressed data (can be empty for quantizer-only use) - */ - vpq_dataset(raft::device_matrix_view vq_code_book_view, - raft::device_matrix_view pq_code_book_view, - raft::device_matrix_view data_view = - raft::device_matrix_view{}) - : vq_code_book_owned_{}, - pq_code_book_owned_{}, - data_owned_{}, - vq_code_book_view_{vq_code_book_view}, - pq_code_book_view_{pq_code_book_view}, - data_view_{data_view}, - is_owning_{false} - { - } - - vpq_dataset(const vpq_dataset&) = delete; - vpq_dataset& operator=(const vpq_dataset&) = delete; - vpq_dataset(vpq_dataset&&) = default; - vpq_dataset& operator=(vpq_dataset&&) = default; - - [[nodiscard]] auto n_rows() const noexcept -> index_type final { return data_view_.extent(0); } - [[nodiscard]] auto dim() const noexcept -> uint32_t final { return vq_code_book_view_.extent(1); } - [[nodiscard]] auto is_owning() const noexcept -> bool final { return is_owning_; } + ~vpq_dataset() override = default; /** Get view of VQ codebook. */ - [[nodiscard]] auto vq_code_book() const noexcept - -> raft::device_matrix_view - { - return vq_code_book_view_; - } + [[nodiscard]] virtual auto vq_code_book() const noexcept + -> raft::device_matrix_view = 0; /** Get view of PQ codebook. */ - [[nodiscard]] auto pq_code_book() const noexcept - -> raft::device_matrix_view - { - return pq_code_book_view_; - } + [[nodiscard]] virtual auto pq_code_book() const noexcept + -> raft::device_matrix_view = 0; /** Get view of compressed data. */ - [[nodiscard]] auto data() const noexcept - -> raft::device_matrix_view - { - return data_view_; - } + [[nodiscard]] virtual auto data() const noexcept + -> raft::device_matrix_view = 0; + + // Derived properties with default implementations + [[nodiscard]] auto n_rows() const noexcept -> index_type override { return data().extent(0); } + [[nodiscard]] auto dim() const noexcept -> uint32_t override { return vq_code_book().extent(1); } /** Row length of the encoded data in bytes. */ - [[nodiscard]] constexpr inline auto encoded_row_length() const noexcept -> uint32_t + [[nodiscard]] inline auto encoded_row_length() const noexcept -> uint32_t { - return data_view_.extent(1); + return data().extent(1); } /** The number of "coarse cluster centers" */ - [[nodiscard]] constexpr inline auto vq_n_centers() const noexcept -> uint32_t + [[nodiscard]] inline auto vq_n_centers() const noexcept -> uint32_t { - return vq_code_book_view_.extent(0); + return vq_code_book().extent(0); } /** The bit length of an encoded vector element after compression by PQ. */ - [[nodiscard]] constexpr inline auto pq_bits() const noexcept -> uint32_t + [[nodiscard]] inline auto pq_bits() const noexcept -> uint32_t { /* NOTE: pq_bits and the book size @@ -513,42 +459,172 @@ struct vpq_dataset : public dataset { #ifdef __cpp_lib_bitops return std::countr_zero(pq_width); #else - uint32_t pq_bits = 0; + uint32_t bits = 0; while (pq_width > 1) { - pq_bits++; + bits++; pq_width >>= 1; } - return pq_bits; + return bits; #endif } /** The dimensionality of an encoded vector after compression by PQ. */ - [[nodiscard]] constexpr inline auto pq_dim() const noexcept -> uint32_t + [[nodiscard]] inline auto pq_dim() const noexcept -> uint32_t { return raft::div_rounding_up_unsafe(dim(), pq_len()); } /** Dimensionality of a subspaces, i.e. the number of vector components mapped to a subspace */ - [[nodiscard]] constexpr inline auto pq_len() const noexcept -> uint32_t + [[nodiscard]] inline auto pq_len() const noexcept -> uint32_t { return pq_code_book().extent(1); } + /** The number of vectors in a PQ codebook (`1 << pq_bits`). */ + [[nodiscard]] inline auto pq_n_centers() const noexcept -> uint32_t { - return pq_code_book_view_.extent(1); + return pq_code_book().extent(0); } - /** The number of vectors in a PQ codebook (`1 << pq_bits`). */ - [[nodiscard]] constexpr inline auto pq_n_centers() const noexcept -> uint32_t +}; + +/** + * @brief Owning VPQ dataset - owns the codebooks and data. + * + * @tparam MathT the type of elements in the codebooks + * @tparam IdxT type of the vector indices + */ +template +struct vpq_dataset_owning : public vpq_dataset { + using index_type = IdxT; + using math_type = MathT; + + /** + * @brief Construct an owning vpq_dataset by moving in the codebooks and data. + */ + vpq_dataset_owning(raft::device_matrix&& vq_code_book, + raft::device_matrix&& pq_code_book, + raft::device_matrix&& data) + : vq_code_book_{std::move(vq_code_book)}, + pq_code_book_{std::move(pq_code_book)}, + data_{std::move(data)} + { + } + + vpq_dataset_owning(const vpq_dataset_owning&) = delete; + vpq_dataset_owning& operator=(const vpq_dataset_owning&) = delete; + vpq_dataset_owning(vpq_dataset_owning&&) = default; + vpq_dataset_owning& operator=(vpq_dataset_owning&&) = default; + ~vpq_dataset_owning() override = default; + + [[nodiscard]] auto is_owning() const noexcept -> bool override { return true; } + + [[nodiscard]] auto vq_code_book() const noexcept + -> raft::device_matrix_view override + { + return vq_code_book_.view(); + } + + [[nodiscard]] auto pq_code_book() const noexcept + -> raft::device_matrix_view override + { + return pq_code_book_.view(); + } + + [[nodiscard]] auto data() const noexcept + -> raft::device_matrix_view override + { + return data_.view(); + } + + // Non-const accessors for building/modifying + [[nodiscard]] auto vq_code_book_mut() noexcept + -> raft::device_matrix_view + { + return vq_code_book_.view(); + } + + [[nodiscard]] auto pq_code_book_mut() noexcept + -> raft::device_matrix_view + { + return pq_code_book_.view(); + } + + [[nodiscard]] auto data_mut() noexcept + -> raft::device_matrix_view + { + return data_.view(); + } + + /** Release ownership of the data matrix (for type conversion operations). */ + [[nodiscard]] auto release_data() noexcept + -> raft::device_matrix { - return pq_code_book_view_.extent(0); + return std::move(data_); } private: - // Owning storage (empty when is_owning_ == false) - raft::device_matrix vq_code_book_owned_; - raft::device_matrix pq_code_book_owned_; - raft::device_matrix data_owned_; + raft::device_matrix vq_code_book_; + raft::device_matrix pq_code_book_; + raft::device_matrix data_; +}; + +/** + * @brief View-type VPQ dataset - non-owning views to external data. + * + * The caller must ensure the lifetime of the underlying data exceeds + * the lifetime of this object. + * + * @tparam MathT the type of elements in the codebooks + * @tparam IdxT type of the vector indices + */ +template +struct vpq_dataset_view : public vpq_dataset { + using index_type = IdxT; + using math_type = MathT; + + /** + * @brief Construct a view-type vpq_dataset from external codebook views. + * + * @param vq_code_book_view View of VQ codebook [vq_n_centers, dim] + * @param pq_code_book_view View of PQ codebook [pq_dim * pq_n_centers, pq_len] or [pq_n_centers, + * pq_len] + * @param data_view View of compressed data (can be empty for quantizer-only use) + */ + vpq_dataset_view( + raft::device_matrix_view vq_code_book_view, + raft::device_matrix_view pq_code_book_view, + raft::device_matrix_view data_view = + raft::device_matrix_view{}) + : vq_code_book_view_{vq_code_book_view}, + pq_code_book_view_{pq_code_book_view}, + data_view_{data_view} + { + } + + vpq_dataset_view(const vpq_dataset_view&) = default; + vpq_dataset_view& operator=(const vpq_dataset_view&) = default; + vpq_dataset_view(vpq_dataset_view&&) = default; + vpq_dataset_view& operator=(vpq_dataset_view&&) = default; + ~vpq_dataset_view() override = default; + + [[nodiscard]] auto is_owning() const noexcept -> bool override { return false; } + + [[nodiscard]] auto vq_code_book() const noexcept + -> raft::device_matrix_view override + { + return vq_code_book_view_; + } + + [[nodiscard]] auto pq_code_book() const noexcept + -> raft::device_matrix_view override + { + return pq_code_book_view_; + } - // Views (always valid - either point to owned data or external data) + [[nodiscard]] auto data() const noexcept + -> raft::device_matrix_view override + { + return data_view_; + } + + private: raft::device_matrix_view vq_code_book_view_; raft::device_matrix_view pq_code_book_view_; raft::device_matrix_view data_view_; - - bool is_owning_; }; template @@ -557,6 +633,12 @@ struct is_vpq_dataset : std::false_type {}; template struct is_vpq_dataset> : std::true_type {}; +template +struct is_vpq_dataset> : std::true_type {}; + +template +struct is_vpq_dataset> : std::true_type {}; + template inline constexpr bool is_vpq_dataset_v = is_vpq_dataset::value; diff --git a/cpp/include/cuvs/preprocessing/quantize/pq.hpp b/cpp/include/cuvs/preprocessing/quantize/pq.hpp index 29ecdd0db7..078f9464b4 100644 --- a/cpp/include/cuvs/preprocessing/quantize/pq.hpp +++ b/cpp/include/cuvs/preprocessing/quantize/pq.hpp @@ -10,6 +10,8 @@ #include #include +#include + namespace cuvs::preprocessing::quantize::pq { /** @@ -79,15 +81,28 @@ struct params { /** * @brief Defines and stores VPQ codebooks upon training. * - * Can be either owning (trained from data) or non-owning (view of external codebooks). + * The quantizer holds a pointer to a vpq_dataset, which can be either + * owning (vpq_dataset_owning, trained from data) or non-owning + * (vpq_dataset_view, referencing external codebooks). * * @tparam T data element type - * */ template struct quantizer { params params_quantizer; - cuvs::neighbors::vpq_dataset vpq_codebooks; + std::unique_ptr> vpq_codebooks; + + quantizer() = default; + + quantizer(params p, std::unique_ptr> codebooks) + : params_quantizer(p), vpq_codebooks(std::move(codebooks)) + { + } + + quantizer(const quantizer&) = delete; + quantizer& operator=(const quantizer&) = delete; + quantizer(quantizer&&) = default; + quantizer& operator=(quantizer&&) = default; }; /** diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_vpq.hpp b/cpp/src/neighbors/detail/cagra/compute_distance_vpq.hpp index 2b69a1cef4..995265d232 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance_vpq.hpp +++ b/cpp/src/neighbors/detail/cagra/compute_distance_vpq.hpp @@ -50,10 +50,10 @@ struct vpq_descriptor_spec : public instance_spec { const DistanceT* dataset_norms = nullptr) -> host_type { return init_(params, - dataset.data.data_handle(), + dataset.data().data_handle(), dataset.encoded_row_length(), - dataset.vq_code_book.data_handle(), - dataset.pq_code_book.data_handle(), + dataset.vq_code_book().data_handle(), + dataset.pq_code_book().data_handle(), IndexT(dataset.n_rows()), dataset.dim()); } diff --git a/cpp/src/neighbors/detail/cagra/factory.cuh b/cpp/src/neighbors/detail/cagra/factory.cuh index a767d16530..910f067a36 100644 --- a/cpp/src/neighbors/detail/cagra/factory.cuh +++ b/cpp/src/neighbors/detail/cagra/factory.cuh @@ -106,10 +106,10 @@ auto make_key(const cagra::search_params& params, cuvs::distance::DistanceType metric) -> std::enable_if_t, key> { - return key{reinterpret_cast(dataset.data.data_handle()), + return key{reinterpret_cast(dataset.data().data_handle()), uint64_t(dataset.n_rows()), dataset.dim(), - uint32_t(reinterpret_cast(dataset.pq_code_book.data_handle()) >> 6), + uint32_t(reinterpret_cast(dataset.pq_code_book().data_handle()) >> 6), uint32_t(params.team_size), uint32_t(metric)}; } diff --git a/cpp/src/neighbors/detail/dataset_serialize.hpp b/cpp/src/neighbors/detail/dataset_serialize.hpp index 7da60ff906..b48d47106b 100644 --- a/cpp/src/neighbors/detail/dataset_serialize.hpp +++ b/cpp/src/neighbors/detail/dataset_serialize.hpp @@ -67,9 +67,9 @@ void serialize(const raft::resources& res, raft::serialize_scalar(res, os, dataset.pq_n_centers()); raft::serialize_scalar(res, os, dataset.pq_len()); raft::serialize_scalar(res, os, dataset.encoded_row_length()); - raft::serialize_mdspan(res, os, make_const_mdspan(dataset.vq_code_book.view())); - raft::serialize_mdspan(res, os, make_const_mdspan(dataset.pq_code_book.view())); - raft::serialize_mdspan(res, os, make_const_mdspan(dataset.data.view())); + raft::serialize_mdspan(res, os, dataset.vq_code_book()); + raft::serialize_mdspan(res, os, dataset.pq_code_book()); + raft::serialize_mdspan(res, os, dataset.data()); } template @@ -154,7 +154,7 @@ auto deserialize_vpq(raft::resources const& res, std::istream& is) raft::deserialize_mdspan(res, is, pq_code_book.view()); raft::deserialize_mdspan(res, is, data.view()); - return std::make_unique>( + return std::make_unique>( std::move(vq_code_book), std::move(pq_code_book), std::move(data)); } diff --git a/cpp/src/neighbors/detail/vamana/vamana_build.cuh b/cpp/src/neighbors/detail/vamana/vamana_build.cuh index 0f22a1a153..f810cc112b 100644 --- a/cpp/src/neighbors/detail/vamana/vamana_build.cuh +++ b/cpp/src/neighbors/detail/vamana/vamana_build.cuh @@ -641,10 +641,10 @@ index build( auto quantizer = cuvs::preprocessing::quantize::pq::quantizer( pq_params, - cuvs::neighbors::vpq_dataset{ + std::make_unique>( raft::make_device_matrix(res, 0, 0), std::move(pq_codebook), - raft::make_device_matrix(res, 0, 0)}); + raft::make_device_matrix(res, 0, 0))); const int64_t codes_rowlen = cuvs::preprocessing::quantize::pq::get_quantized_dim(pq_params); quantized_vectors = raft::make_device_matrix(res, n_rows, codes_rowlen); diff --git a/cpp/src/neighbors/detail/vpq_dataset.cuh b/cpp/src/neighbors/detail/vpq_dataset.cuh index e609100a76..e3d7e9e89f 100644 --- a/cpp/src/neighbors/detail/vpq_dataset.cuh +++ b/cpp/src/neighbors/detail/vpq_dataset.cuh @@ -504,22 +504,31 @@ void process_and_fill_codes( } template -auto vpq_convert_math_type(const raft::resources& res, vpq_dataset&& src) - -> vpq_dataset +auto vpq_convert_math_type(const raft::resources& res, + std::unique_ptr>&& src) + -> std::unique_ptr> { - auto vq_code_book = raft::make_device_mdarray(res, src.vq_code_book.extents()); - auto pq_code_book = raft::make_device_mdarray(res, src.pq_code_book.extents()); + auto vq_code_book = raft::make_device_mdarray(res, src->vq_code_book().extents()); + auto pq_code_book = raft::make_device_mdarray(res, src->pq_code_book().extents()); raft::linalg::map(res, vq_code_book.view(), cuvs::spatial::knn::detail::utils::mapping{}, - raft::make_const_mdspan(src.vq_code_book.view())); + src->vq_code_book()); raft::linalg::map(res, pq_code_book.view(), cuvs::spatial::knn::detail::utils::mapping{}, - raft::make_const_mdspan(src.pq_code_book.view())); - return vpq_dataset{ - std::move(vq_code_book), std::move(pq_code_book), std::move(src.data)}; + src->pq_code_book()); + + // Get the data from the old dataset - need to cast to owning to move the data + auto* owning_src = dynamic_cast*>(src.get()); + RAFT_EXPECTS(owning_src != nullptr, "Cannot convert non-owning vpq_dataset"); + + // Move the data matrix from the source (data type is uint8_t, independent of MathT) + auto data = owning_src->release_data(); + + return std::make_unique>( + std::move(vq_code_book), std::move(pq_code_book), std::move(data)); } // Helper for operations using vectorized loads of raft::TxN_t @@ -902,7 +911,7 @@ void process_and_fill_codes_subspaces( template auto vpq_build(const raft::resources& res, const vpq_params& params, const DatasetT& dataset) - -> vpq_dataset + -> std::unique_ptr> { using label_t = uint32_t; // Use a heuristic to impute missing parameters. @@ -928,8 +937,8 @@ auto vpq_build(const raft::resources& res, const vpq_params& params, const Datas codes.view(), true); - return vpq_dataset{ - std::move(vq_code_book), std::move(pq_code_book), std::move(codes)}; + return std::make_unique>( + std::move(vq_code_book), std::move(pq_code_book), std::move(codes)); } } // namespace cuvs::neighbors::detail diff --git a/cpp/src/neighbors/vpq_dataset.cuh b/cpp/src/neighbors/vpq_dataset.cuh index 34b011bb0b..3d22eb9eca 100644 --- a/cpp/src/neighbors/vpq_dataset.cuh +++ b/cpp/src/neighbors/vpq_dataset.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -22,13 +22,14 @@ namespace cuvs::neighbors { * @param[in] res * @param[in] params VQ and PQ parameters for compressing the data * @param[in] dataset a row-major mdspan or mdarray (device or host) [n_rows, dim]. + * + * @return a unique_ptr to the vpq_dataset */ template auto vpq_build(const raft::resources& res, const vpq_params& params, const DatasetT& dataset) - -> vpq_dataset - + -> std::unique_ptr> { if constexpr (std::is_same_v) { return detail::vpq_convert_math_type( diff --git a/cpp/src/preprocessing/quantize/detail/pq.cuh b/cpp/src/preprocessing/quantize/detail/pq.cuh index 867a6fce26..839be6a698 100644 --- a/cpp/src/preprocessing/quantize/detail/pq.cuh +++ b/cpp/src/preprocessing/quantize/detail/pq.cuh @@ -158,8 +158,8 @@ quantizer build( res, vpq_params, dataset, raft::make_const_mdspan(vq_code_book.view())); } return {filled_params, - cuvs::neighbors::vpq_dataset{ - std::move(vq_code_book), std::move(pq_code_book), std::move(empty_codes)}}; + std::make_unique>( + std::move(vq_code_book), std::move(pq_code_book), std::move(empty_codes))}; } /** @@ -205,7 +205,9 @@ quantizer build_view( // Create view-type vpq_dataset auto empty_data = raft::device_matrix_view{}; - return {params, cuvs::neighbors::vpq_dataset{vq_centers, pq_centers, empty_data}}; + return {params, + std::make_unique>( + vq_centers, pq_centers, empty_data)}; } template @@ -227,10 +229,11 @@ void transform( "Output matrix doesn't have the correct number of columns"); RAFT_EXPECTS(quant.params_quantizer.pq_bits >= 4 && quant.params_quantizer.pq_bits <= 16, "PQ bits must be within [4, 16]"); + RAFT_EXPECTS(quant.vpq_codebooks != nullptr, "Quantizer codebooks must be initialized"); // Use view accessors from vpq_dataset - auto vq_centers = quant.vpq_codebooks.vq_code_book(); - auto pq_centers = quant.vpq_codebooks.pq_code_book(); + auto vq_centers = quant.vpq_codebooks->vq_code_book(); + auto pq_centers = quant.vpq_codebooks->pq_code_book(); auto vq_labels_view = raft::make_device_vector_view(nullptr, 0); if (vq_labels.has_value()) { vq_labels_view = vq_labels.value(); } @@ -369,13 +372,14 @@ void inverse_transform( "Codes matrix doesn't have the correct number of columns"); RAFT_EXPECTS(quant.params_quantizer.pq_bits >= 4 && quant.params_quantizer.pq_bits <= 16, "PQ bits must be within [4, 16]"); + RAFT_EXPECTS(quant.vpq_codebooks != nullptr, "Quantizer codebooks must be initialized"); // Use view accessors from vpq_dataset reconstruct_vectors(res, quant.params_quantizer, codes, - quant.vpq_codebooks.pq_code_book(), - quant.vpq_codebooks.vq_code_book(), + quant.vpq_codebooks->pq_code_book(), + quant.vpq_codebooks->vq_code_book(), vq_labels, out, quant.params_quantizer.use_subspaces); diff --git a/cpp/tests/preprocessing/product_quantization.cu b/cpp/tests/preprocessing/product_quantization.cu index aa6f257e19..274ea26448 100644 --- a/cpp/tests/preprocessing/product_quantization.cu +++ b/cpp/tests/preprocessing/product_quantization.cu @@ -412,15 +412,16 @@ class ProductQuantizationViewTest : public ::testing::Test { auto owning_quant = build(handle, config, raft::make_const_mdspan(dataset_.view())); // Extract codebook views from owning quantizer using new accessor methods - auto pq_centers_view = owning_quant.vpq_codebooks.pq_code_book(); - auto vq_centers_view = owning_quant.vpq_codebooks.vq_code_book(); + auto pq_centers_view = owning_quant.vpq_codebooks->pq_code_book(); + auto vq_centers_view = owning_quant.vpq_codebooks->vq_code_book(); // Create view-type quantizer from the same codebooks auto view_quant = build(handle, owning_quant.params_quantizer, pq_centers_view, vq_centers_view); // Transform using owning quantizer - auto n_encoded_cols = get_quantized_dim(owning_quant.params_quantizer); - auto codes_owning = raft::make_device_matrix(handle, n_samples_, n_encoded_cols); + auto n_encoded_cols = get_quantized_dim(owning_quant.params_quantizer); + auto codes_owning = + raft::make_device_matrix(handle, n_samples_, n_encoded_cols); transform(handle, owning_quant, raft::make_const_mdspan(dataset_.view()), @@ -490,8 +491,8 @@ class ProductQuantizationViewTest : public ::testing::Test { // Verify that the owning quantizer reports is_owning() = true // and the view quantizer reports is_owning() = false - ASSERT_TRUE(owning_quant.vpq_codebooks.is_owning()); - ASSERT_FALSE(view_quant.vpq_codebooks.is_owning()); + ASSERT_TRUE(owning_quant.vpq_codebooks->is_owning()); + ASSERT_FALSE(view_quant.vpq_codebooks->is_owning()); } raft::resources handle; From fa70a018cd608bfa74ec61ca463bf33d7a3d0b33 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Tue, 3 Feb 2026 05:05:19 -0800 Subject: [PATCH 04/36] fix vpq_build --- cpp/include/cuvs/neighbors/common.hpp | 109 +++++++++++++++--- .../cuvs/preprocessing/quantize/pq.hpp | 11 +- .../neighbors/detail/dataset_serialize.hpp | 5 +- .../neighbors/detail/vamana/vamana_build.cuh | 9 +- cpp/src/neighbors/detail/vpq_dataset.cuh | 41 +++---- cpp/src/neighbors/vpq_dataset.cuh | 4 +- cpp/src/preprocessing/quantize/detail/pq.cuh | 20 ++-- .../preprocessing/product_quantization.cu | 10 +- 8 files changed, 145 insertions(+), 64 deletions(-) diff --git a/cpp/include/cuvs/neighbors/common.hpp b/cpp/include/cuvs/neighbors/common.hpp index 1bbd7751fb..9886338d9e 100644 --- a/cpp/include/cuvs/neighbors/common.hpp +++ b/cpp/include/cuvs/neighbors/common.hpp @@ -398,25 +398,20 @@ auto make_aligned_dataset(const raft::resources& res, SrcT src, uint32_t align_b return make_strided_dataset(res, std::forward(src), required_stride); } /** - * @brief VPQ compressed dataset - base interface. + * @brief VPQ compressed dataset - internal interface. * - * The dataset is compressed using two level quantization - * - * 1. Vector Quantization - * 2. Product Quantization of residuals - * - * This is the abstract base class. Use vpq_dataset_owning for owning data - * or vpq_dataset_view for non-owning views. + * This is the abstract base class for the internal implementation. + * Users should use vpq_dataset which wraps this via PIMPL. * * @tparam MathT the type of elements in the codebooks * @tparam IdxT type of the vector indices (represent dataset.extent(0)) */ template -struct vpq_dataset : public dataset { +struct vpq_dataset_iface : public dataset { using index_type = IdxT; using math_type = MathT; - ~vpq_dataset() override = default; + ~vpq_dataset_iface() override = default; /** Get view of VQ codebook. */ [[nodiscard]] virtual auto vq_code_book() const noexcept @@ -482,13 +477,13 @@ struct vpq_dataset : public dataset { }; /** - * @brief Owning VPQ dataset - owns the codebooks and data. + * @brief Owning VPQ dataset implementation - owns the codebooks and data. * * @tparam MathT the type of elements in the codebooks * @tparam IdxT type of the vector indices */ template -struct vpq_dataset_owning : public vpq_dataset { +struct vpq_dataset_owning : public vpq_dataset_iface { using index_type = IdxT; using math_type = MathT; @@ -563,7 +558,7 @@ struct vpq_dataset_owning : public vpq_dataset { }; /** - * @brief View-type VPQ dataset - non-owning views to external data. + * @brief View-type VPQ dataset implementation - non-owning views to external data. * * The caller must ensure the lifetime of the underlying data exceeds * the lifetime of this object. @@ -572,7 +567,7 @@ struct vpq_dataset_owning : public vpq_dataset { * @tparam IdxT type of the vector indices */ template -struct vpq_dataset_view : public vpq_dataset { +struct vpq_dataset_view : public vpq_dataset_iface { using index_type = IdxT; using math_type = MathT; @@ -627,12 +622,98 @@ struct vpq_dataset_view : public vpq_dataset { raft::device_matrix_view data_view_; }; +/** + * @brief VPQ compressed dataset (PIMPL wrapper). + * + * The dataset is compressed using two level quantization: + * 1. Vector Quantization + * 2. Product Quantization of residuals + * + * This class wraps the internal implementation (vpq_dataset_owning or vpq_dataset_view) + * and provides a stable API. + * + * @tparam MathT the type of elements in the codebooks + * @tparam IdxT type of the vector indices (represent dataset.extent(0)) + */ +template +struct vpq_dataset : public dataset { + using index_type = IdxT; + using math_type = MathT; + + vpq_dataset() = default; + + /** Construct from an implementation. */ + explicit vpq_dataset(std::unique_ptr> impl) + : impl_{std::move(impl)} + { + } + + vpq_dataset(const vpq_dataset&) = delete; + vpq_dataset& operator=(const vpq_dataset&) = delete; + vpq_dataset(vpq_dataset&&) = default; + vpq_dataset& operator=(vpq_dataset&&) = default; + ~vpq_dataset() override = default; + + // Delegation methods + [[nodiscard]] auto n_rows() const noexcept -> index_type override { return impl_->n_rows(); } + [[nodiscard]] auto dim() const noexcept -> uint32_t override { return impl_->dim(); } + [[nodiscard]] auto is_owning() const noexcept -> bool override { return impl_->is_owning(); } + + /** Get view of VQ codebook. */ + [[nodiscard]] auto vq_code_book() const noexcept + -> raft::device_matrix_view + { + return impl_->vq_code_book(); + } + + /** Get view of PQ codebook. */ + [[nodiscard]] auto pq_code_book() const noexcept + -> raft::device_matrix_view + { + return impl_->pq_code_book(); + } + + /** Get view of compressed data. */ + [[nodiscard]] auto data() const noexcept + -> raft::device_matrix_view + { + return impl_->data(); + } + + /** Row length of the encoded data in bytes. */ + [[nodiscard]] auto encoded_row_length() const noexcept -> uint32_t + { + return impl_->encoded_row_length(); + } + + /** The number of "coarse cluster centers" */ + [[nodiscard]] auto vq_n_centers() const noexcept -> uint32_t { return impl_->vq_n_centers(); } + + /** The bit length of an encoded vector element after compression by PQ. */ + [[nodiscard]] auto pq_bits() const noexcept -> uint32_t { return impl_->pq_bits(); } + + /** The dimensionality of an encoded vector after compression by PQ. */ + [[nodiscard]] auto pq_dim() const noexcept -> uint32_t { return impl_->pq_dim(); } + + /** Dimensionality of a subspaces, i.e. the number of vector components mapped to a subspace */ + [[nodiscard]] auto pq_len() const noexcept -> uint32_t { return impl_->pq_len(); } + + /** The number of vectors in a PQ codebook (`1 << pq_bits`). */ + [[nodiscard]] auto pq_n_centers() const noexcept -> uint32_t { return impl_->pq_n_centers(); } + + private: + std::unique_ptr> impl_; +}; + template struct is_vpq_dataset : std::false_type {}; template struct is_vpq_dataset> : std::true_type {}; +template +struct is_vpq_dataset> : std::true_type {}; + template struct is_vpq_dataset> : std::true_type {}; diff --git a/cpp/include/cuvs/preprocessing/quantize/pq.hpp b/cpp/include/cuvs/preprocessing/quantize/pq.hpp index 078f9464b4..5c4cf13ace 100644 --- a/cpp/include/cuvs/preprocessing/quantize/pq.hpp +++ b/cpp/include/cuvs/preprocessing/quantize/pq.hpp @@ -10,8 +10,6 @@ #include #include -#include - namespace cuvs::preprocessing::quantize::pq { /** @@ -81,20 +79,19 @@ struct params { /** * @brief Defines and stores VPQ codebooks upon training. * - * The quantizer holds a pointer to a vpq_dataset, which can be either - * owning (vpq_dataset_owning, trained from data) or non-owning - * (vpq_dataset_view, referencing external codebooks). + * The quantizer holds a vpq_dataset, which can be either owning (trained from data) + * or non-owning (referencing external codebooks). * * @tparam T data element type */ template struct quantizer { params params_quantizer; - std::unique_ptr> vpq_codebooks; + cuvs::neighbors::vpq_dataset vpq_codebooks; quantizer() = default; - quantizer(params p, std::unique_ptr> codebooks) + quantizer(params p, cuvs::neighbors::vpq_dataset&& codebooks) : params_quantizer(p), vpq_codebooks(std::move(codebooks)) { } diff --git a/cpp/src/neighbors/detail/dataset_serialize.hpp b/cpp/src/neighbors/detail/dataset_serialize.hpp index b48d47106b..becd84632f 100644 --- a/cpp/src/neighbors/detail/dataset_serialize.hpp +++ b/cpp/src/neighbors/detail/dataset_serialize.hpp @@ -154,8 +154,9 @@ auto deserialize_vpq(raft::resources const& res, std::istream& is) raft::deserialize_mdspan(res, is, pq_code_book.view()); raft::deserialize_mdspan(res, is, data.view()); - return std::make_unique>( - std::move(vq_code_book), std::move(pq_code_book), std::move(data)); + return std::make_unique>( + std::make_unique>( + std::move(vq_code_book), std::move(pq_code_book), std::move(data))); } template diff --git a/cpp/src/neighbors/detail/vamana/vamana_build.cuh b/cpp/src/neighbors/detail/vamana/vamana_build.cuh index f810cc112b..0ae2f41c9f 100644 --- a/cpp/src/neighbors/detail/vamana/vamana_build.cuh +++ b/cpp/src/neighbors/detail/vamana/vamana_build.cuh @@ -641,10 +641,11 @@ index build( auto quantizer = cuvs::preprocessing::quantize::pq::quantizer( pq_params, - std::make_unique>( - raft::make_device_matrix(res, 0, 0), - std::move(pq_codebook), - raft::make_device_matrix(res, 0, 0))); + cuvs::neighbors::vpq_dataset{ + std::make_unique>( + raft::make_device_matrix(res, 0, 0), + std::move(pq_codebook), + raft::make_device_matrix(res, 0, 0))}); const int64_t codes_rowlen = cuvs::preprocessing::quantize::pq::get_quantized_dim(pq_params); quantized_vectors = raft::make_device_matrix(res, n_rows, codes_rowlen); diff --git a/cpp/src/neighbors/detail/vpq_dataset.cuh b/cpp/src/neighbors/detail/vpq_dataset.cuh index e3d7e9e89f..ce1a1d31aa 100644 --- a/cpp/src/neighbors/detail/vpq_dataset.cuh +++ b/cpp/src/neighbors/detail/vpq_dataset.cuh @@ -504,31 +504,32 @@ void process_and_fill_codes( } template -auto vpq_convert_math_type(const raft::resources& res, - std::unique_ptr>&& src) - -> std::unique_ptr> +auto vpq_convert_math_type(const raft::resources& res, vpq_dataset&& src) + -> vpq_dataset { - auto vq_code_book = raft::make_device_mdarray(res, src->vq_code_book().extents()); - auto pq_code_book = raft::make_device_mdarray(res, src->pq_code_book().extents()); + auto vq_code_book = raft::make_device_mdarray(res, src.vq_code_book().extents()); + auto pq_code_book = raft::make_device_mdarray(res, src.pq_code_book().extents()); raft::linalg::map(res, vq_code_book.view(), cuvs::spatial::knn::detail::utils::mapping{}, - src->vq_code_book()); + src.vq_code_book()); raft::linalg::map(res, pq_code_book.view(), cuvs::spatial::knn::detail::utils::mapping{}, - src->pq_code_book()); - - // Get the data from the old dataset - need to cast to owning to move the data - auto* owning_src = dynamic_cast*>(src.get()); - RAFT_EXPECTS(owning_src != nullptr, "Cannot convert non-owning vpq_dataset"); - - // Move the data matrix from the source (data type is uint8_t, independent of MathT) - auto data = owning_src->release_data(); - - return std::make_unique>( - std::move(vq_code_book), std::move(pq_code_book), std::move(data)); + src.pq_code_book()); + + // Copy the data from the source (data type is uint8_t, independent of MathT) + auto data_view = src.data(); + auto data = raft::make_device_matrix( + res, data_view.extent(0), data_view.extent(1)); + raft::copy(data.data_handle(), + data_view.data_handle(), + data_view.size(), + raft::resource::get_cuda_stream(res)); + + return vpq_dataset{std::make_unique>( + std::move(vq_code_book), std::move(pq_code_book), std::move(data))}; } // Helper for operations using vectorized loads of raft::TxN_t @@ -911,7 +912,7 @@ void process_and_fill_codes_subspaces( template auto vpq_build(const raft::resources& res, const vpq_params& params, const DatasetT& dataset) - -> std::unique_ptr> + -> vpq_dataset { using label_t = uint32_t; // Use a heuristic to impute missing parameters. @@ -937,8 +938,8 @@ auto vpq_build(const raft::resources& res, const vpq_params& params, const Datas codes.view(), true); - return std::make_unique>( - std::move(vq_code_book), std::move(pq_code_book), std::move(codes)); + return vpq_dataset{std::make_unique>( + std::move(vq_code_book), std::move(pq_code_book), std::move(codes))}; } } // namespace cuvs::neighbors::detail diff --git a/cpp/src/neighbors/vpq_dataset.cuh b/cpp/src/neighbors/vpq_dataset.cuh index 3d22eb9eca..cfcef97b4f 100644 --- a/cpp/src/neighbors/vpq_dataset.cuh +++ b/cpp/src/neighbors/vpq_dataset.cuh @@ -23,13 +23,13 @@ namespace cuvs::neighbors { * @param[in] params VQ and PQ parameters for compressing the data * @param[in] dataset a row-major mdspan or mdarray (device or host) [n_rows, dim]. * - * @return a unique_ptr to the vpq_dataset + * @return the vpq_dataset */ template auto vpq_build(const raft::resources& res, const vpq_params& params, const DatasetT& dataset) - -> std::unique_ptr> + -> vpq_dataset { if constexpr (std::is_same_v) { return detail::vpq_convert_math_type( diff --git a/cpp/src/preprocessing/quantize/detail/pq.cuh b/cpp/src/preprocessing/quantize/detail/pq.cuh index 839be6a698..601e3d2c18 100644 --- a/cpp/src/preprocessing/quantize/detail/pq.cuh +++ b/cpp/src/preprocessing/quantize/detail/pq.cuh @@ -158,8 +158,9 @@ quantizer build( res, vpq_params, dataset, raft::make_const_mdspan(vq_code_book.view())); } return {filled_params, - std::make_unique>( - std::move(vq_code_book), std::move(pq_code_book), std::move(empty_codes))}; + cuvs::neighbors::vpq_dataset{ + std::make_unique>( + std::move(vq_code_book), std::move(pq_code_book), std::move(empty_codes))}}; } /** @@ -206,8 +207,9 @@ quantizer build_view( // Create view-type vpq_dataset auto empty_data = raft::device_matrix_view{}; return {params, - std::make_unique>( - vq_centers, pq_centers, empty_data)}; + cuvs::neighbors::vpq_dataset{ + std::make_unique>( + vq_centers, pq_centers, empty_data)}}; } template @@ -229,11 +231,10 @@ void transform( "Output matrix doesn't have the correct number of columns"); RAFT_EXPECTS(quant.params_quantizer.pq_bits >= 4 && quant.params_quantizer.pq_bits <= 16, "PQ bits must be within [4, 16]"); - RAFT_EXPECTS(quant.vpq_codebooks != nullptr, "Quantizer codebooks must be initialized"); // Use view accessors from vpq_dataset - auto vq_centers = quant.vpq_codebooks->vq_code_book(); - auto pq_centers = quant.vpq_codebooks->pq_code_book(); + auto vq_centers = quant.vpq_codebooks.vq_code_book(); + auto pq_centers = quant.vpq_codebooks.pq_code_book(); auto vq_labels_view = raft::make_device_vector_view(nullptr, 0); if (vq_labels.has_value()) { vq_labels_view = vq_labels.value(); } @@ -372,14 +373,13 @@ void inverse_transform( "Codes matrix doesn't have the correct number of columns"); RAFT_EXPECTS(quant.params_quantizer.pq_bits >= 4 && quant.params_quantizer.pq_bits <= 16, "PQ bits must be within [4, 16]"); - RAFT_EXPECTS(quant.vpq_codebooks != nullptr, "Quantizer codebooks must be initialized"); // Use view accessors from vpq_dataset reconstruct_vectors(res, quant.params_quantizer, codes, - quant.vpq_codebooks->pq_code_book(), - quant.vpq_codebooks->vq_code_book(), + quant.vpq_codebooks.pq_code_book(), + quant.vpq_codebooks.vq_code_book(), vq_labels, out, quant.params_quantizer.use_subspaces); diff --git a/cpp/tests/preprocessing/product_quantization.cu b/cpp/tests/preprocessing/product_quantization.cu index 274ea26448..723435c820 100644 --- a/cpp/tests/preprocessing/product_quantization.cu +++ b/cpp/tests/preprocessing/product_quantization.cu @@ -411,9 +411,9 @@ class ProductQuantizationViewTest : public ::testing::Test { params config{pq_bits_, pq_dim_, true /* use_subspaces */, false /* use_vq */}; auto owning_quant = build(handle, config, raft::make_const_mdspan(dataset_.view())); - // Extract codebook views from owning quantizer using new accessor methods - auto pq_centers_view = owning_quant.vpq_codebooks->pq_code_book(); - auto vq_centers_view = owning_quant.vpq_codebooks->vq_code_book(); + // Extract codebook views from owning quantizer using accessor methods + auto pq_centers_view = owning_quant.vpq_codebooks.pq_code_book(); + auto vq_centers_view = owning_quant.vpq_codebooks.vq_code_book(); // Create view-type quantizer from the same codebooks auto view_quant = build(handle, owning_quant.params_quantizer, pq_centers_view, vq_centers_view); @@ -491,8 +491,8 @@ class ProductQuantizationViewTest : public ::testing::Test { // Verify that the owning quantizer reports is_owning() = true // and the view quantizer reports is_owning() = false - ASSERT_TRUE(owning_quant.vpq_codebooks->is_owning()); - ASSERT_FALSE(view_quant.vpq_codebooks->is_owning()); + ASSERT_TRUE(owning_quant.vpq_codebooks.is_owning()); + ASSERT_FALSE(view_quant.vpq_codebooks.is_owning()); } raft::resources handle; From bf763e30d23f7c72ef783ca70bbc53b598c043ee Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Tue, 3 Feb 2026 05:39:45 -0800 Subject: [PATCH 05/36] revert changes to quantizer struct --- cpp/include/cuvs/neighbors/common.hpp | 12 ++------ .../cuvs/preprocessing/quantize/pq.hpp | 18 ++--------- cpp/src/preprocessing/quantize/detail/pq.cuh | 30 +++++++++---------- 3 files changed, 19 insertions(+), 41 deletions(-) diff --git a/cpp/include/cuvs/neighbors/common.hpp b/cpp/include/cuvs/neighbors/common.hpp index 9886338d9e..ab11f5ba14 100644 --- a/cpp/include/cuvs/neighbors/common.hpp +++ b/cpp/include/cuvs/neighbors/common.hpp @@ -407,12 +407,13 @@ auto make_aligned_dataset(const raft::resources& res, SrcT src, uint32_t align_b * @tparam IdxT type of the vector indices (represent dataset.extent(0)) */ template -struct vpq_dataset_iface : public dataset { +class vpq_dataset_iface : public dataset { using index_type = IdxT; using math_type = MathT; ~vpq_dataset_iface() override = default; + public: /** Get view of VQ codebook. */ [[nodiscard]] virtual auto vq_code_book() const noexcept -> raft::device_matrix_view = 0; @@ -711,15 +712,6 @@ struct is_vpq_dataset : std::false_type {}; template struct is_vpq_dataset> : std::true_type {}; -template -struct is_vpq_dataset> : std::true_type {}; - -template -struct is_vpq_dataset> : std::true_type {}; - -template -struct is_vpq_dataset> : std::true_type {}; - template inline constexpr bool is_vpq_dataset_v = is_vpq_dataset::value; diff --git a/cpp/include/cuvs/preprocessing/quantize/pq.hpp b/cpp/include/cuvs/preprocessing/quantize/pq.hpp index 5c4cf13ace..66fd6a9633 100644 --- a/cpp/include/cuvs/preprocessing/quantize/pq.hpp +++ b/cpp/include/cuvs/preprocessing/quantize/pq.hpp @@ -77,7 +77,7 @@ struct params { }; /** - * @brief Defines and stores VPQ codebooks upon training. + * @brief Defines and stores VPQ codebooks upon training * * The quantizer holds a vpq_dataset, which can be either owning (trained from data) * or non-owning (referencing external codebooks). @@ -88,18 +88,6 @@ template struct quantizer { params params_quantizer; cuvs::neighbors::vpq_dataset vpq_codebooks; - - quantizer() = default; - - quantizer(params p, cuvs::neighbors::vpq_dataset&& codebooks) - : params_quantizer(p), vpq_codebooks(std::move(codebooks)) - { - } - - quantizer(const quantizer&) = delete; - quantizer& operator=(const quantizer&) = delete; - quantizer(quantizer&&) = default; - quantizer& operator=(quantizer&&) = default; }; /** @@ -120,7 +108,7 @@ struct quantizer { * @param[in] params configure product quantizer, e.g. pq_bits, pq_dim * @param[in] dataset a row-major matrix view on device or host * - * @return quantizer (owning) + * @return quantizer */ quantizer build(raft::resources const& res, const params params, @@ -135,8 +123,6 @@ quantizer build(raft::resources const& res, * @brief Creates a view-type product quantizer from pre-computed codebooks. * * This function creates a non-owning quantizer that references the provided device data. - * The caller is responsible for ensuring the lifetime of the input data exceeds - * the lifetime of the returned quantizer. * * Usage example: * @code{.cpp} diff --git a/cpp/src/preprocessing/quantize/detail/pq.cuh b/cpp/src/preprocessing/quantize/detail/pq.cuh index 601e3d2c18..9e184365aa 100644 --- a/cpp/src/preprocessing/quantize/detail/pq.cuh +++ b/cpp/src/preprocessing/quantize/detail/pq.cuh @@ -215,7 +215,7 @@ quantizer build_view( template void transform( raft::resources const& res, - const quantizer& quant, + const quantizer& quantizer, raft::mdspan, raft::row_major, AccessorType> dataset, raft::device_matrix_view pq_codes_out, std::optional> vq_labels = std::nullopt) @@ -227,21 +227,21 @@ void transform( size_t(pq_codes_out.extent(1))); RAFT_EXPECTS(pq_codes_out.extent(0) == dataset.extent(0), "Output matrix must have the same number of rows as the input dataset"); - RAFT_EXPECTS(pq_codes_out.extent(1) == get_quantized_dim(quant.params_quantizer), + RAFT_EXPECTS(pq_codes_out.extent(1) == get_quantized_dim(quantizer.params_quantizer), "Output matrix doesn't have the correct number of columns"); - RAFT_EXPECTS(quant.params_quantizer.pq_bits >= 4 && quant.params_quantizer.pq_bits <= 16, + RAFT_EXPECTS(quantizer.params_quantizer.pq_bits >= 4 && quantizer.params_quantizer.pq_bits <= 16, "PQ bits must be within [4, 16]"); // Use view accessors from vpq_dataset - auto vq_centers = quant.vpq_codebooks.vq_code_book(); - auto pq_centers = quant.vpq_codebooks.pq_code_book(); + auto vq_centers = quantizer.vpq_codebooks.vq_code_book(); + auto pq_centers = quantizer.vpq_codebooks.pq_code_book(); auto vq_labels_view = raft::make_device_vector_view(nullptr, 0); if (vq_labels.has_value()) { vq_labels_view = vq_labels.value(); } - if (quant.params_quantizer.use_subspaces) { + if (quantizer.params_quantizer.use_subspaces) { cuvs::neighbors::detail::process_and_fill_codes_subspaces( res, - to_vpq_params(quant.params_quantizer), + to_vpq_params(quantizer.params_quantizer), dataset, pq_centers, vq_centers, @@ -250,7 +250,7 @@ void transform( } else { cuvs::neighbors::detail::process_and_fill_codes( res, - to_vpq_params(quant.params_quantizer), + to_vpq_params(quantizer.params_quantizer), dataset, pq_centers, vq_centers, @@ -355,7 +355,7 @@ auto reconstruct_vectors( template void inverse_transform( raft::resources const& res, - const quantizer& quant, + const quantizer& quantizer, raft::device_matrix_view codes, raft::device_matrix_view out, std::optional> vq_labels = std::nullopt) @@ -369,20 +369,20 @@ void inverse_transform( using idx_t = int64_t; RAFT_EXPECTS(out.extent(0) == codes.extent(0), "Output matrix must have the same number of rows as the input codes"); - RAFT_EXPECTS(codes.extent(1) == get_quantized_dim(quant.params_quantizer), + RAFT_EXPECTS(codes.extent(1) == get_quantized_dim(quantizer.params_quantizer), "Codes matrix doesn't have the correct number of columns"); - RAFT_EXPECTS(quant.params_quantizer.pq_bits >= 4 && quant.params_quantizer.pq_bits <= 16, + RAFT_EXPECTS(quantizer.params_quantizer.pq_bits >= 4 && quantizer.params_quantizer.pq_bits <= 16, "PQ bits must be within [4, 16]"); // Use view accessors from vpq_dataset reconstruct_vectors(res, - quant.params_quantizer, + quantizer.params_quantizer, codes, - quant.vpq_codebooks.pq_code_book(), - quant.vpq_codebooks.vq_code_book(), + quantizer.vpq_codebooks.pq_code_book(), + quantizer.vpq_codebooks.vq_code_book(), vq_labels, out, - quant.params_quantizer.use_subspaces); + quantizer.params_quantizer.use_subspaces); } } // namespace cuvs::preprocessing::quantize::pq::detail From b0aaa05663c1dfa0629cfbb48a8dba7c46f566a7 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Fri, 13 Feb 2026 06:00:03 -0800 Subject: [PATCH 06/36] make user class pure pimpl --- c/src/preprocessing/quantize/pq.cpp | 4 +- cpp/include/cuvs/neighbors/common.hpp | 211 ++---------------- .../cuvs/preprocessing/quantize/pq.hpp | 9 +- .../neighbors/detail/dataset_serialize.hpp | 3 +- .../neighbors/detail/vamana/vamana_build.cuh | 1 + cpp/src/neighbors/detail/vpq_dataset.cuh | 1 + cpp/src/neighbors/vpq_dataset_impl.hpp | 206 +++++++++++++++++ cpp/src/preprocessing/quantize/pq.cu | 2 +- .../preprocessing/product_quantization.cu | 8 +- 9 files changed, 235 insertions(+), 210 deletions(-) create mode 100644 cpp/src/neighbors/vpq_dataset_impl.hpp diff --git a/c/src/preprocessing/quantize/pq.cpp b/c/src/preprocessing/quantize/pq.cpp index 35a8f5ed72..c792f1dcce 100644 --- a/c/src/preprocessing/quantize/pq.cpp +++ b/c/src/preprocessing/quantize/pq.cpp @@ -252,7 +252,7 @@ extern "C" cuvsError_t cuvsProductQuantizerGetPqCodebook(cuvsProductQuantizer_t if (quantizer->dtype.code == kDLFloat && quantizer->dtype.bits == 32) { auto pq_mdspan = (reinterpret_cast*>(quant_addr)) - ->vpq_codebooks.pq_code_book.view(); + ->vpq_codebooks.pq_code_book(); cuvs::core::to_dlpack(pq_mdspan, pq_codebook); } else { RAFT_FAIL("Unsupported quantizer dtype: %d and bits: %d", @@ -274,7 +274,7 @@ extern "C" cuvsError_t cuvsProductQuantizerGetVqCodebook(cuvsProductQuantizer_t if (quantizer->dtype.code == kDLFloat && quantizer->dtype.bits == 32) { auto pq_mdspan = (reinterpret_cast*>(quant_addr)) - ->vpq_codebooks.vq_code_book.view(); + ->vpq_codebooks.vq_code_book(); cuvs::core::to_dlpack(pq_mdspan, vq_codebook); } else { RAFT_FAIL("Unsupported quantizer dtype: %d and bits: %d", diff --git a/cpp/include/cuvs/neighbors/common.hpp b/cpp/include/cuvs/neighbors/common.hpp index ab11f5ba14..76a06f1def 100644 --- a/cpp/include/cuvs/neighbors/common.hpp +++ b/cpp/include/cuvs/neighbors/common.hpp @@ -408,12 +408,12 @@ auto make_aligned_dataset(const raft::resources& res, SrcT src, uint32_t align_b */ template class vpq_dataset_iface : public dataset { + public: using index_type = IdxT; using math_type = MathT; ~vpq_dataset_iface() override = default; - public: /** Get view of VQ codebook. */ [[nodiscard]] virtual auto vq_code_book() const noexcept -> raft::device_matrix_view = 0; @@ -426,201 +426,15 @@ class vpq_dataset_iface : public dataset { [[nodiscard]] virtual auto data() const noexcept -> raft::device_matrix_view = 0; - // Derived properties with default implementations - [[nodiscard]] auto n_rows() const noexcept -> index_type override { return data().extent(0); } - [[nodiscard]] auto dim() const noexcept -> uint32_t override { return vq_code_book().extent(1); } - - /** Row length of the encoded data in bytes. */ - [[nodiscard]] inline auto encoded_row_length() const noexcept -> uint32_t - { - return data().extent(1); - } - /** The number of "coarse cluster centers" */ - [[nodiscard]] inline auto vq_n_centers() const noexcept -> uint32_t - { - return vq_code_book().extent(0); - } - /** The bit length of an encoded vector element after compression by PQ. */ - [[nodiscard]] inline auto pq_bits() const noexcept -> uint32_t - { - /* - NOTE: pq_bits and the book size - - Normally, we'd store `pq_bits` as a part of the index. - However, we know there's an invariant `pq_n_centers = 1 << pq_bits`, i.e. the codebook size is - the same as the number of possible code values. Hence, we don't store the pq_bits and derive it - from the array dimensions instead. - */ - auto pq_width = pq_n_centers(); -#ifdef __cpp_lib_bitops - return std::countr_zero(pq_width); -#else - uint32_t bits = 0; - while (pq_width > 1) { - bits++; - pq_width >>= 1; - } - return bits; -#endif - } - /** The dimensionality of an encoded vector after compression by PQ. */ - [[nodiscard]] inline auto pq_dim() const noexcept -> uint32_t - { - return raft::div_rounding_up_unsafe(dim(), pq_len()); - } - /** Dimensionality of a subspaces, i.e. the number of vector components mapped to a subspace */ - [[nodiscard]] inline auto pq_len() const noexcept -> uint32_t { return pq_code_book().extent(1); } - /** The number of vectors in a PQ codebook (`1 << pq_bits`). */ - [[nodiscard]] inline auto pq_n_centers() const noexcept -> uint32_t - { - return pq_code_book().extent(0); - } -}; - -/** - * @brief Owning VPQ dataset implementation - owns the codebooks and data. - * - * @tparam MathT the type of elements in the codebooks - * @tparam IdxT type of the vector indices - */ -template -struct vpq_dataset_owning : public vpq_dataset_iface { - using index_type = IdxT; - using math_type = MathT; - - /** - * @brief Construct an owning vpq_dataset by moving in the codebooks and data. - */ - vpq_dataset_owning(raft::device_matrix&& vq_code_book, - raft::device_matrix&& pq_code_book, - raft::device_matrix&& data) - : vq_code_book_{std::move(vq_code_book)}, - pq_code_book_{std::move(pq_code_book)}, - data_{std::move(data)} - { - } - - vpq_dataset_owning(const vpq_dataset_owning&) = delete; - vpq_dataset_owning& operator=(const vpq_dataset_owning&) = delete; - vpq_dataset_owning(vpq_dataset_owning&&) = default; - vpq_dataset_owning& operator=(vpq_dataset_owning&&) = default; - ~vpq_dataset_owning() override = default; - - [[nodiscard]] auto is_owning() const noexcept -> bool override { return true; } - - [[nodiscard]] auto vq_code_book() const noexcept - -> raft::device_matrix_view override - { - return vq_code_book_.view(); - } - - [[nodiscard]] auto pq_code_book() const noexcept - -> raft::device_matrix_view override - { - return pq_code_book_.view(); - } - - [[nodiscard]] auto data() const noexcept - -> raft::device_matrix_view override - { - return data_.view(); - } - - // Non-const accessors for building/modifying - [[nodiscard]] auto vq_code_book_mut() noexcept - -> raft::device_matrix_view - { - return vq_code_book_.view(); - } - - [[nodiscard]] auto pq_code_book_mut() noexcept - -> raft::device_matrix_view - { - return pq_code_book_.view(); - } - - [[nodiscard]] auto data_mut() noexcept - -> raft::device_matrix_view - { - return data_.view(); - } - - /** Release ownership of the data matrix (for type conversion operations). */ - [[nodiscard]] auto release_data() noexcept - -> raft::device_matrix - { - return std::move(data_); - } - - private: - raft::device_matrix vq_code_book_; - raft::device_matrix pq_code_book_; - raft::device_matrix data_; -}; - -/** - * @brief View-type VPQ dataset implementation - non-owning views to external data. - * - * The caller must ensure the lifetime of the underlying data exceeds - * the lifetime of this object. - * - * @tparam MathT the type of elements in the codebooks - * @tparam IdxT type of the vector indices - */ -template -struct vpq_dataset_view : public vpq_dataset_iface { - using index_type = IdxT; - using math_type = MathT; - - /** - * @brief Construct a view-type vpq_dataset from external codebook views. - * - * @param vq_code_book_view View of VQ codebook [vq_n_centers, dim] - * @param pq_code_book_view View of PQ codebook [pq_dim * pq_n_centers, pq_len] or [pq_n_centers, - * pq_len] - * @param data_view View of compressed data (can be empty for quantizer-only use) - */ - vpq_dataset_view( - raft::device_matrix_view vq_code_book_view, - raft::device_matrix_view pq_code_book_view, - raft::device_matrix_view data_view = - raft::device_matrix_view{}) - : vq_code_book_view_{vq_code_book_view}, - pq_code_book_view_{pq_code_book_view}, - data_view_{data_view} - { - } - - vpq_dataset_view(const vpq_dataset_view&) = default; - vpq_dataset_view& operator=(const vpq_dataset_view&) = default; - vpq_dataset_view(vpq_dataset_view&&) = default; - vpq_dataset_view& operator=(vpq_dataset_view&&) = default; - ~vpq_dataset_view() override = default; - - [[nodiscard]] auto is_owning() const noexcept -> bool override { return false; } - - [[nodiscard]] auto vq_code_book() const noexcept - -> raft::device_matrix_view override - { - return vq_code_book_view_; - } - - [[nodiscard]] auto pq_code_book() const noexcept - -> raft::device_matrix_view override - { - return pq_code_book_view_; - } - - [[nodiscard]] auto data() const noexcept - -> raft::device_matrix_view override - { - return data_view_; - } - - private: - raft::device_matrix_view vq_code_book_view_; - raft::device_matrix_view pq_code_book_view_; - raft::device_matrix_view data_view_; + // Derived properties - pure virtual + [[nodiscard]] virtual auto n_rows() const noexcept -> index_type = 0; + [[nodiscard]] virtual auto dim() const noexcept -> uint32_t = 0; + [[nodiscard]] virtual auto encoded_row_length() const noexcept -> uint32_t = 0; + [[nodiscard]] virtual auto vq_n_centers() const noexcept -> uint32_t = 0; + [[nodiscard]] virtual auto pq_bits() const noexcept -> uint32_t = 0; + [[nodiscard]] virtual auto pq_dim() const noexcept -> uint32_t = 0; + [[nodiscard]] virtual auto pq_len() const noexcept -> uint32_t = 0; + [[nodiscard]] virtual auto pq_n_centers() const noexcept -> uint32_t = 0; }; /** @@ -637,7 +451,8 @@ struct vpq_dataset_view : public vpq_dataset_iface { * @tparam IdxT type of the vector indices (represent dataset.extent(0)) */ template -struct vpq_dataset : public dataset { +class vpq_dataset : public dataset { + public: using index_type = IdxT; using math_type = MathT; @@ -658,7 +473,7 @@ struct vpq_dataset : public dataset { // Delegation methods [[nodiscard]] auto n_rows() const noexcept -> index_type override { return impl_->n_rows(); } [[nodiscard]] auto dim() const noexcept -> uint32_t override { return impl_->dim(); } - [[nodiscard]] auto is_owning() const noexcept -> bool override { return impl_->is_owning(); } + [[nodiscard]] auto is_owning() const noexcept -> bool final { return true; } /** Get view of VQ codebook. */ [[nodiscard]] auto vq_code_book() const noexcept diff --git a/cpp/include/cuvs/preprocessing/quantize/pq.hpp b/cpp/include/cuvs/preprocessing/quantize/pq.hpp index 66fd6a9633..6d4e8082c4 100644 --- a/cpp/include/cuvs/preprocessing/quantize/pq.hpp +++ b/cpp/include/cuvs/preprocessing/quantize/pq.hpp @@ -150,11 +150,10 @@ quantizer build(raft::resources const& res, * * @return A view-type quantizer that references the provided data */ -quantizer build( - raft::resources const& res, - const params params, - raft::device_matrix_view pq_centers, - raft::device_matrix_view vq_centers); +quantizer build(raft::resources const& res, + const params params, + raft::device_matrix_view pq_centers, + raft::device_matrix_view vq_centers); /** * @brief Applies quantization transform to given dataset diff --git a/cpp/src/neighbors/detail/dataset_serialize.hpp b/cpp/src/neighbors/detail/dataset_serialize.hpp index becd84632f..758a1e9898 100644 --- a/cpp/src/neighbors/detail/dataset_serialize.hpp +++ b/cpp/src/neighbors/detail/dataset_serialize.hpp @@ -1,9 +1,10 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once +#include "../../neighbors/vpq_dataset_impl.hpp" #include #include diff --git a/cpp/src/neighbors/detail/vamana/vamana_build.cuh b/cpp/src/neighbors/detail/vamana/vamana_build.cuh index af8a866ae6..f30d1e5051 100644 --- a/cpp/src/neighbors/detail/vamana/vamana_build.cuh +++ b/cpp/src/neighbors/detail/vamana/vamana_build.cuh @@ -5,6 +5,7 @@ #pragma once +#include "../../../neighbors/vpq_dataset_impl.hpp" #include "../../../sparse/neighbors/cross_component_nn.cuh" #include "../../detail/ann_utils.cuh" #include "greedy_search.cuh" diff --git a/cpp/src/neighbors/detail/vpq_dataset.cuh b/cpp/src/neighbors/detail/vpq_dataset.cuh index ce1a1d31aa..badeadeeab 100644 --- a/cpp/src/neighbors/detail/vpq_dataset.cuh +++ b/cpp/src/neighbors/detail/vpq_dataset.cuh @@ -4,6 +4,7 @@ */ #pragma once +#include "../../neighbors/vpq_dataset_impl.hpp" #include #include "../../cluster/kmeans_balanced.cuh" diff --git a/cpp/src/neighbors/vpq_dataset_impl.hpp b/cpp/src/neighbors/vpq_dataset_impl.hpp new file mode 100644 index 0000000000..3bd4b0b8a3 --- /dev/null +++ b/cpp/src/neighbors/vpq_dataset_impl.hpp @@ -0,0 +1,206 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include + +#include + +#ifdef __cpp_lib_bitops +#include +#endif + +namespace cuvs::neighbors { + +/** + * @brief Common VPQ dataset implementation - provides shared implementations. + * + * This class contains the common implementations for derived properties + * that are shared between owning and view implementations. + * + * @tparam MathT the type of elements in the codebooks + * @tparam IdxT type of the vector indices + */ +template +class vpq_dataset_impl : public vpq_dataset_iface { + public: + using index_type = IdxT; + using math_type = MathT; + + // Derived properties with default implementations + [[nodiscard]] auto n_rows() const noexcept -> index_type override { return this->data().extent(0); } + [[nodiscard]] auto dim() const noexcept -> uint32_t override { return this->vq_code_book().extent(1); } + [[nodiscard]] auto is_owning() const noexcept -> bool override { return true; } + + /** Row length of the encoded data in bytes. */ + [[nodiscard]] inline auto encoded_row_length() const noexcept -> uint32_t override + { + return this->data().extent(1); + } + /** The number of "coarse cluster centers" */ + [[nodiscard]] inline auto vq_n_centers() const noexcept -> uint32_t override + { + return this->vq_code_book().extent(0); + } + /** The bit length of an encoded vector element after compression by PQ. */ + [[nodiscard]] inline auto pq_bits() const noexcept -> uint32_t override + { + /* + NOTE: pq_bits and the book size + + Normally, we'd store `pq_bits` as a part of the index. + However, we know there's an invariant `pq_n_centers = 1 << pq_bits`, i.e. the codebook size is + the same as the number of possible code values. Hence, we don't store the pq_bits and derive it + from the array dimensions instead. + */ + auto pq_width = pq_n_centers(); +#ifdef __cpp_lib_bitops + return std::countr_zero(pq_width); +#else + uint32_t pq_bits = 0; + while (pq_width > 1) { + pq_bits++; + pq_width >>= 1; + } + return pq_bits; +#endif + } + /** The dimensionality of an encoded vector after compression by PQ. */ + [[nodiscard]] inline auto pq_dim() const noexcept -> uint32_t override + { + return raft::div_rounding_up_unsafe(dim(), pq_len()); + } + /** Dimensionality of a subspaces, i.e. the number of vector components mapped to a subspace */ + [[nodiscard]] inline auto pq_len() const noexcept -> uint32_t override + { + return this->pq_code_book().extent(1); + } + /** The number of vectors in a PQ codebook (`1 << pq_bits`). */ + [[nodiscard]] inline auto pq_n_centers() const noexcept -> uint32_t override + { + return this->pq_code_book().extent(0); + } +}; + +/** + * @brief Owning VPQ dataset implementation - owns the codebooks and data. + * + * @tparam MathT the type of elements in the codebooks + * @tparam IdxT type of the vector indices + */ +template +class vpq_dataset_owning : public vpq_dataset_impl { + public: + using index_type = IdxT; + using math_type = MathT; + + /** + * @brief Construct an owning vpq_dataset by moving in the codebooks and data. + */ + vpq_dataset_owning(raft::device_matrix&& vq_code_book, + raft::device_matrix&& pq_code_book, + raft::device_matrix&& data) + : vq_code_book_{std::move(vq_code_book)}, + pq_code_book_{std::move(pq_code_book)}, + data_{std::move(data)} + { + } + + vpq_dataset_owning(const vpq_dataset_owning&) = delete; + vpq_dataset_owning& operator=(const vpq_dataset_owning&) = delete; + vpq_dataset_owning(vpq_dataset_owning&&) = default; + vpq_dataset_owning& operator=(vpq_dataset_owning&&) = default; + ~vpq_dataset_owning() override = default; + + [[nodiscard]] auto vq_code_book() const noexcept + -> raft::device_matrix_view override + { + return vq_code_book_.view(); + } + + [[nodiscard]] auto pq_code_book() const noexcept + -> raft::device_matrix_view override + { + return pq_code_book_.view(); + } + + [[nodiscard]] auto data() const noexcept + -> raft::device_matrix_view override + { + return data_.view(); + } + + private: + raft::device_matrix vq_code_book_; + raft::device_matrix pq_code_book_; + raft::device_matrix data_; +}; + +/** + * @brief View-type VPQ dataset implementation - non-owning views to external data. + * + * The caller must ensure the lifetime of the underlying data exceeds + * the lifetime of this object. + * + * @tparam MathT the type of elements in the codebooks + * @tparam IdxT type of the vector indices + */ +template +class vpq_dataset_view : public vpq_dataset_impl { + public: + using index_type = IdxT; + using math_type = MathT; + + /** + * @brief Construct a view-type vpq_dataset from external codebook views. + * + * @param vq_code_book_view View of VQ codebook [vq_n_centers, dim] + * @param pq_code_book_view View of PQ codebook [pq_dim * pq_n_centers, pq_len] or [pq_n_centers, + * pq_len] + * @param data_view View of compressed data (can be empty for quantizer-only use) + */ + vpq_dataset_view( + raft::device_matrix_view vq_code_book_view, + raft::device_matrix_view pq_code_book_view, + raft::device_matrix_view data_view = + raft::device_matrix_view{}) + : vq_code_book_view_{vq_code_book_view}, + pq_code_book_view_{pq_code_book_view}, + data_view_{data_view} + { + } + + vpq_dataset_view(const vpq_dataset_view&) = default; + vpq_dataset_view& operator=(const vpq_dataset_view&) = default; + vpq_dataset_view(vpq_dataset_view&&) = default; + vpq_dataset_view& operator=(vpq_dataset_view&&) = default; + ~vpq_dataset_view() override = default; + + [[nodiscard]] auto vq_code_book() const noexcept + -> raft::device_matrix_view override + { + return vq_code_book_view_; + } + + [[nodiscard]] auto pq_code_book() const noexcept + -> raft::device_matrix_view override + { + return pq_code_book_view_; + } + + [[nodiscard]] auto data() const noexcept + -> raft::device_matrix_view override + { + return data_view_; + } + + private: + raft::device_matrix_view vq_code_book_view_; + raft::device_matrix_view pq_code_book_view_; + raft::device_matrix_view data_view_; +}; + +} // namespace cuvs::neighbors diff --git a/cpp/src/preprocessing/quantize/pq.cu b/cpp/src/preprocessing/quantize/pq.cu index c01e832b4a..22939b96d6 100644 --- a/cpp/src/preprocessing/quantize/pq.cu +++ b/cpp/src/preprocessing/quantize/pq.cu @@ -26,7 +26,7 @@ namespace cuvs::preprocessing::quantize::pq { const params params, \ raft::device_matrix_view pq_centers, \ raft::device_matrix_view vq_centers) \ - ->quantizer \ + -> quantizer \ { \ return detail::build_view(res, params, pq_centers, vq_centers); \ } \ diff --git a/cpp/tests/preprocessing/product_quantization.cu b/cpp/tests/preprocessing/product_quantization.cu index 723435c820..499ab571de 100644 --- a/cpp/tests/preprocessing/product_quantization.cu +++ b/cpp/tests/preprocessing/product_quantization.cu @@ -416,10 +416,11 @@ class ProductQuantizationViewTest : public ::testing::Test { auto vq_centers_view = owning_quant.vpq_codebooks.vq_code_book(); // Create view-type quantizer from the same codebooks - auto view_quant = build(handle, owning_quant.params_quantizer, pq_centers_view, vq_centers_view); + auto view_quant = + build(handle, owning_quant.params_quantizer, pq_centers_view, vq_centers_view); // Transform using owning quantizer - auto n_encoded_cols = get_quantized_dim(owning_quant.params_quantizer); + auto n_encoded_cols = get_quantized_dim(owning_quant.params_quantizer); auto codes_owning = raft::make_device_matrix(handle, n_samples_, n_encoded_cols); transform(handle, @@ -429,7 +430,8 @@ class ProductQuantizationViewTest : public ::testing::Test { std::nullopt); // Transform using view-type quantizer - auto codes_view = raft::make_device_matrix(handle, n_samples_, n_encoded_cols); + auto codes_view = + raft::make_device_matrix(handle, n_samples_, n_encoded_cols); transform(handle, view_quant, raft::make_const_mdspan(dataset_.view()), From 04be0a0d1d216f905bdd8401666831283e5885d7 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Fri, 13 Mar 2026 16:20:22 -0700 Subject: [PATCH 07/36] fixes --- cpp/include/cuvs/neighbors/common.hpp | 1 - .../cuvs/preprocessing/quantize/pq.hpp | 4 ++-- cpp/src/neighbors/vpq_dataset_impl.hpp | 23 +++++++++---------- cpp/src/preprocessing/quantize/detail/pq.cuh | 4 ++-- 4 files changed, 15 insertions(+), 17 deletions(-) diff --git a/cpp/include/cuvs/neighbors/common.hpp b/cpp/include/cuvs/neighbors/common.hpp index f1ad45d386..4f0356866f 100644 --- a/cpp/include/cuvs/neighbors/common.hpp +++ b/cpp/include/cuvs/neighbors/common.hpp @@ -466,7 +466,6 @@ class vpq_dataset : public dataset { vpq_dataset& operator=(vpq_dataset&&) = default; ~vpq_dataset() override = default; - // Delegation methods [[nodiscard]] auto n_rows() const noexcept -> index_type override { return impl_->n_rows(); } [[nodiscard]] auto dim() const noexcept -> uint32_t override { return impl_->dim(); } [[nodiscard]] auto is_owning() const noexcept -> bool final { return true; } diff --git a/cpp/include/cuvs/preprocessing/quantize/pq.hpp b/cpp/include/cuvs/preprocessing/quantize/pq.hpp index 6d4e8082c4..e4eb371cbe 100644 --- a/cpp/include/cuvs/preprocessing/quantize/pq.hpp +++ b/cpp/include/cuvs/preprocessing/quantize/pq.hpp @@ -79,7 +79,7 @@ struct params { /** * @brief Defines and stores VPQ codebooks upon training * - * The quantizer holds a vpq_dataset, which can be either owning (trained from data) + * The quantizer holds a vpq_dataset, which can either own the codebooks * or non-owning (referencing external codebooks). * * @tparam T data element type @@ -122,7 +122,7 @@ quantizer build(raft::resources const& res, /** * @brief Creates a view-type product quantizer from pre-computed codebooks. * - * This function creates a non-owning quantizer that references the provided device data. + * This function creates a non-owning quantizer that references the provided codebooks. * * Usage example: * @code{.cpp} diff --git a/cpp/src/neighbors/vpq_dataset_impl.hpp b/cpp/src/neighbors/vpq_dataset_impl.hpp index 3bd4b0b8a3..cf381aaaab 100644 --- a/cpp/src/neighbors/vpq_dataset_impl.hpp +++ b/cpp/src/neighbors/vpq_dataset_impl.hpp @@ -140,10 +140,10 @@ class vpq_dataset_owning : public vpq_dataset_impl { }; /** - * @brief View-type VPQ dataset implementation - non-owning views to external data. + * @brief View-type VPQ dataset implementation - owns the dataset but not the codebooks. * - * The caller must ensure the lifetime of the underlying data exceeds - * the lifetime of this object. + * The caller must ensure the lifetime of the codebook data exceeds + * the lifetime of this object. The dataset is owned by this object. * * @tparam MathT the type of elements in the codebooks * @tparam IdxT type of the vector indices @@ -155,21 +155,20 @@ class vpq_dataset_view : public vpq_dataset_impl { using math_type = MathT; /** - * @brief Construct a view-type vpq_dataset from external codebook views. + * @brief Construct a vpq_dataset that owns the dataset but references the codebooks. * - * @param vq_code_book_view View of VQ codebook [vq_n_centers, dim] + * @param vq_code_book_view View of VQ codebook [vq_n_centers, dim] (non-owning) * @param pq_code_book_view View of PQ codebook [pq_dim * pq_n_centers, pq_len] or [pq_n_centers, - * pq_len] - * @param data_view View of compressed data (can be empty for quantizer-only use) + * pq_len] (non-owning) + * @param data Compressed data matrix (moved, owned by this object) */ vpq_dataset_view( raft::device_matrix_view vq_code_book_view, raft::device_matrix_view pq_code_book_view, - raft::device_matrix_view data_view = - raft::device_matrix_view{}) + raft::device_matrix&& data) : vq_code_book_view_{vq_code_book_view}, pq_code_book_view_{pq_code_book_view}, - data_view_{data_view} + data_{std::move(data)} { } @@ -194,13 +193,13 @@ class vpq_dataset_view : public vpq_dataset_impl { [[nodiscard]] auto data() const noexcept -> raft::device_matrix_view override { - return data_view_; + return data_.view(); } private: raft::device_matrix_view vq_code_book_view_; raft::device_matrix_view pq_code_book_view_; - raft::device_matrix_view data_view_; + raft::device_matrix data_; }; } // namespace cuvs::neighbors diff --git a/cpp/src/preprocessing/quantize/detail/pq.cuh b/cpp/src/preprocessing/quantize/detail/pq.cuh index 9e184365aa..6f08e01655 100644 --- a/cpp/src/preprocessing/quantize/detail/pq.cuh +++ b/cpp/src/preprocessing/quantize/detail/pq.cuh @@ -205,11 +205,11 @@ quantizer build_view( } // Create view-type vpq_dataset - auto empty_data = raft::device_matrix_view{}; + auto empty_data = raft::make_device_matrix(res, 0, 0); return {params, cuvs::neighbors::vpq_dataset{ std::make_unique>( - vq_centers, pq_centers, empty_data)}}; + vq_centers, pq_centers, std::move(empty_data))}}; } template From f80280e5401906e360257e8990f08263a47f0e77 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Mon, 16 Mar 2026 10:40:47 -0700 Subject: [PATCH 08/36] style --- .../neighbors/detail/cagra/compute_distance_vpq.hpp | 2 +- cpp/src/neighbors/detail/cagra/factory.cuh | 2 +- cpp/src/neighbors/vpq_dataset_impl.hpp | 12 +++++++++--- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_vpq.hpp b/cpp/src/neighbors/detail/cagra/compute_distance_vpq.hpp index 995265d232..09afbad6ed 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance_vpq.hpp +++ b/cpp/src/neighbors/detail/cagra/compute_distance_vpq.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ diff --git a/cpp/src/neighbors/detail/cagra/factory.cuh b/cpp/src/neighbors/detail/cagra/factory.cuh index 910f067a36..a831327be7 100644 --- a/cpp/src/neighbors/detail/cagra/factory.cuh +++ b/cpp/src/neighbors/detail/cagra/factory.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ diff --git a/cpp/src/neighbors/vpq_dataset_impl.hpp b/cpp/src/neighbors/vpq_dataset_impl.hpp index cf381aaaab..3f79d27368 100644 --- a/cpp/src/neighbors/vpq_dataset_impl.hpp +++ b/cpp/src/neighbors/vpq_dataset_impl.hpp @@ -31,8 +31,14 @@ class vpq_dataset_impl : public vpq_dataset_iface { using math_type = MathT; // Derived properties with default implementations - [[nodiscard]] auto n_rows() const noexcept -> index_type override { return this->data().extent(0); } - [[nodiscard]] auto dim() const noexcept -> uint32_t override { return this->vq_code_book().extent(1); } + [[nodiscard]] auto n_rows() const noexcept -> index_type override + { + return this->data().extent(0); + } + [[nodiscard]] auto dim() const noexcept -> uint32_t override + { + return this->vq_code_book().extent(1); + } [[nodiscard]] auto is_owning() const noexcept -> bool override { return true; } /** Row length of the encoded data in bytes. */ @@ -150,7 +156,7 @@ class vpq_dataset_owning : public vpq_dataset_impl { */ template class vpq_dataset_view : public vpq_dataset_impl { - public: + public: using index_type = IdxT; using math_type = MathT; From 51e8209a28ed1057ab0905822421cb182ec955ff Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Mon, 16 Mar 2026 11:54:09 -0700 Subject: [PATCH 09/36] fix tests --- cpp/src/neighbors/vpq_dataset_impl.hpp | 6 +- .../preprocessing/product_quantization.cu | 202 ++++++------------ 2 files changed, 65 insertions(+), 143 deletions(-) diff --git a/cpp/src/neighbors/vpq_dataset_impl.hpp b/cpp/src/neighbors/vpq_dataset_impl.hpp index 3f79d27368..e6bf96a522 100644 --- a/cpp/src/neighbors/vpq_dataset_impl.hpp +++ b/cpp/src/neighbors/vpq_dataset_impl.hpp @@ -99,9 +99,9 @@ class vpq_dataset_impl : public vpq_dataset_iface { */ template class vpq_dataset_owning : public vpq_dataset_impl { - public: - using index_type = IdxT; - using math_type = MathT; + public: + using index_type = IdxT; + using math_type = MathT; /** * @brief Construct an owning vpq_dataset by moving in the codebooks and data. diff --git a/cpp/tests/preprocessing/product_quantization.cu b/cpp/tests/preprocessing/product_quantization.cu index 499ab571de..d0eaff4963 100644 --- a/cpp/tests/preprocessing/product_quantization.cu +++ b/cpp/tests/preprocessing/product_quantization.cu @@ -238,6 +238,67 @@ class ProductQuantizationTest : public ::testing::TestWithParam(n_samples_) < params_.n_vq_centers)) { + EXPECT_THROW(build(handle, config, raft::make_const_mdspan(dataset_.view())), + raft::logic_error); + return; + } + auto owning_quant = build(handle, config, raft::make_const_mdspan(dataset_.view())); + + auto pq_centers_view = owning_quant.vpq_codebooks.pq_code_book(); + auto vq_centers_view = owning_quant.vpq_codebooks.vq_code_book(); + + auto view_quant = + build(handle, owning_quant.params_quantizer, pq_centers_view, vq_centers_view); + + auto n_encoded_cols = get_quantized_dim(owning_quant.params_quantizer); + auto codes_owning = + raft::make_device_matrix(handle, n_samples_, n_encoded_cols); + transform(handle, + owning_quant, + raft::make_const_mdspan(dataset_.view()), + codes_owning.view(), + std::nullopt); + + auto codes_view = + raft::make_device_matrix(handle, n_samples_, n_encoded_cols); + transform(handle, + view_quant, + raft::make_const_mdspan(dataset_.view()), + codes_view.view(), + std::nullopt); + + cuvs::devArrMatch(codes_owning.data_handle(), + codes_view.data_handle(), + n_samples_ * n_encoded_cols, + cuvs::Compare()); + + auto reconstructed_owning = + raft::make_device_matrix(handle, n_samples_, n_features_); + auto reconstructed_view = + raft::make_device_matrix(handle, n_samples_, n_features_); + + inverse_transform(handle, + owning_quant, + raft::make_const_mdspan(codes_owning.view()), + reconstructed_owning.view(), + std::nullopt); + inverse_transform(handle, + view_quant, + raft::make_const_mdspan(codes_view.view()), + reconstructed_view.view(), + std::nullopt); + cuvs::devArrMatch(reconstructed_owning.data_handle(), + reconstructed_view.data_handle(), + n_samples_ * n_features_, + cuvs::Compare()); + } + private: raft::resources handle; cudaStream_t stream; @@ -367,148 +428,9 @@ const std::vector> inputs = { typedef ProductQuantizationTest ProductQuantizationTestF; TEST_P(ProductQuantizationTestF, Result) { this->testProductQuantizationFromDataset(); } - +TEST_P(ProductQuantizationTestF, ViewQuantizer) { this->testViewQuantizer(); } INSTANTIATE_TEST_CASE_P(ProductQuantizationTests, ProductQuantizationTestF, ::testing::ValuesIn(inputs)); -// Test for view-type quantizer -class ProductQuantizationViewTest : public ::testing::Test { - public: - ProductQuantizationViewTest() - : handle{}, - stream{raft::resource::get_cuda_stream(handle)}, - n_samples_{1000}, - n_features_{128}, - pq_bits_{8}, - pq_dim_{32}, - dataset_{raft::make_device_matrix(handle, n_samples_, n_features_)} - { - } - - protected: - void SetUp() override - { - // Generate random dataset - auto labels = raft::make_device_vector(handle, n_samples_); - raft::random::make_blobs(handle, - dataset_.view(), - labels.view(), - 5, - std::nullopt, - std::nullopt, - 1.0f, - true, - -10.0f, - 10.0f, - 42ULL); - raft::resource::sync_stream(handle); - } - - void testViewQuantizerProducesSameResultsAsOwning() - { - // Build owning quantizer - params config{pq_bits_, pq_dim_, true /* use_subspaces */, false /* use_vq */}; - auto owning_quant = build(handle, config, raft::make_const_mdspan(dataset_.view())); - - // Extract codebook views from owning quantizer using accessor methods - auto pq_centers_view = owning_quant.vpq_codebooks.pq_code_book(); - auto vq_centers_view = owning_quant.vpq_codebooks.vq_code_book(); - - // Create view-type quantizer from the same codebooks - auto view_quant = - build(handle, owning_quant.params_quantizer, pq_centers_view, vq_centers_view); - - // Transform using owning quantizer - auto n_encoded_cols = get_quantized_dim(owning_quant.params_quantizer); - auto codes_owning = - raft::make_device_matrix(handle, n_samples_, n_encoded_cols); - transform(handle, - owning_quant, - raft::make_const_mdspan(dataset_.view()), - codes_owning.view(), - std::nullopt); - - // Transform using view-type quantizer - auto codes_view = - raft::make_device_matrix(handle, n_samples_, n_encoded_cols); - transform(handle, - view_quant, - raft::make_const_mdspan(dataset_.view()), - codes_view.view(), - std::nullopt); - - raft::resource::sync_stream(handle); - - // Compare results - should be identical - auto h_codes_owning = raft::make_host_matrix(n_samples_, n_encoded_cols); - auto h_codes_view = raft::make_host_matrix(n_samples_, n_encoded_cols); - raft::copy(h_codes_owning.data_handle(), - codes_owning.data_handle(), - n_samples_ * n_encoded_cols, - stream); - raft::copy( - h_codes_view.data_handle(), codes_view.data_handle(), n_samples_ * n_encoded_cols, stream); - raft::resource::sync_stream(handle); - - for (int64_t i = 0; i < n_samples_ * n_encoded_cols; i++) { - ASSERT_EQ(h_codes_owning.data_handle()[i], h_codes_view.data_handle()[i]) - << "Mismatch at index " << i; - } - - // Test inverse_transform - auto reconstructed_owning = - raft::make_device_matrix(handle, n_samples_, n_features_); - auto reconstructed_view = - raft::make_device_matrix(handle, n_samples_, n_features_); - - inverse_transform(handle, - owning_quant, - raft::make_const_mdspan(codes_owning.view()), - reconstructed_owning.view(), - std::nullopt); - inverse_transform(handle, - view_quant, - raft::make_const_mdspan(codes_view.view()), - reconstructed_view.view(), - std::nullopt); - - raft::resource::sync_stream(handle); - - // Compare reconstructions - auto h_rec_owning = raft::make_host_matrix(n_samples_, n_features_); - auto h_rec_view = raft::make_host_matrix(n_samples_, n_features_); - raft::copy(h_rec_owning.data_handle(), - reconstructed_owning.data_handle(), - n_samples_ * n_features_, - stream); - raft::copy( - h_rec_view.data_handle(), reconstructed_view.data_handle(), n_samples_ * n_features_, stream); - raft::resource::sync_stream(handle); - - for (int64_t i = 0; i < n_samples_ * n_features_; i++) { - ASSERT_FLOAT_EQ(h_rec_owning.data_handle()[i], h_rec_view.data_handle()[i]) - << "Reconstruction mismatch at index " << i; - } - - // Verify that the owning quantizer reports is_owning() = true - // and the view quantizer reports is_owning() = false - ASSERT_TRUE(owning_quant.vpq_codebooks.is_owning()); - ASSERT_FALSE(view_quant.vpq_codebooks.is_owning()); - } - - raft::resources handle; - cudaStream_t stream; - int64_t n_samples_; - int64_t n_features_; - uint32_t pq_bits_; - uint32_t pq_dim_; - raft::device_matrix dataset_; -}; - -TEST_F(ProductQuantizationViewTest, ViewQuantizerProducesSameResults) -{ - testViewQuantizerProducesSameResultsAsOwning(); -} - } // namespace cuvs::preprocessing::quantize::pq From ebfc7d2e1b2c6253afcff827edcb8e9df778ca24 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Thu, 19 Mar 2026 15:40:15 -0700 Subject: [PATCH 10/36] move vpq_dataset class --- cpp/bench/ann/src/cuvs/cuvs_cagra_wrapper.h | 14 +- cpp/include/cuvs/neighbors/common.hpp | 131 ---------------- .../cuvs/preprocessing/quantize/pq.hpp | 3 +- .../preprocessing/quantize/vpq_dataset.hpp | 144 ++++++++++++++++++ .../neighbors/detail/cagra/cagra_search.cuh | 9 +- .../detail/cagra/compute_distance_vpq.hpp | 5 +- cpp/src/neighbors/detail/cagra/factory.cuh | 2 +- .../neighbors/detail/dataset_serialize.hpp | 15 +- .../neighbors/detail/vamana/vamana_build.cuh | 2 +- cpp/src/neighbors/detail/vpq_dataset.cuh | 21 +-- .../neighbors/scann/detail/scann_quantize.cuh | 2 +- cpp/src/neighbors/vpq_dataset.cuh | 2 +- cpp/src/neighbors/vpq_dataset_impl.hpp | 8 +- cpp/src/preprocessing/quantize/detail/pq.cuh | 6 +- 14 files changed, 196 insertions(+), 168 deletions(-) create mode 100644 cpp/include/cuvs/preprocessing/quantize/vpq_dataset.hpp diff --git a/cpp/bench/ann/src/cuvs/cuvs_cagra_wrapper.h b/cpp/bench/ann/src/cuvs/cuvs_cagra_wrapper.h index 34fea2f82a..b95ce0d7a9 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_cagra_wrapper.h +++ b/cpp/bench/ann/src/cuvs/cuvs_cagra_wrapper.h @@ -17,6 +17,8 @@ #include #include #include +#include +#include #include #include #include @@ -357,8 +359,10 @@ void cuvs_cagra::set_search_dataset(const T* dataset, size_t nrow) } else { using ds_idx_type = decltype(index_->data().n_rows()); bool is_vpq = - dynamic_cast*>(&index_->data()) || - dynamic_cast*>(&index_->data()); + dynamic_cast*>( + &index_->data()) || + dynamic_cast*>( + &index_->data()); // It can happen that we are re-using a previous algo object which already has // the dataset set. Check if we need update. if (static_cast(input_dataset_v_->extent(0)) != nrow || @@ -385,8 +389,10 @@ void cuvs_cagra::save(const std::string& file) const } else { using ds_idx_type = decltype(index_->data().n_rows()); bool is_vpq = - dynamic_cast*>(&index_->data()) || - dynamic_cast*>(&index_->data()); + dynamic_cast*>( + &index_->data()) || + dynamic_cast*>( + &index_->data()); cuvs::neighbors::cagra::serialize(handle_, file, *index_, is_vpq); } } diff --git a/cpp/include/cuvs/neighbors/common.hpp b/cpp/include/cuvs/neighbors/common.hpp index 4f0356866f..93b8d15bc3 100644 --- a/cpp/include/cuvs/neighbors/common.hpp +++ b/cpp/include/cuvs/neighbors/common.hpp @@ -393,137 +393,6 @@ auto make_aligned_dataset(const raft::resources& res, SrcT src, uint32_t align_b raft::round_up_safe(src.extent(1) * kSize, std::lcm(align_bytes, kSize)) / kSize; return make_strided_dataset(res, std::forward(src), required_stride); } -/** - * @brief VPQ compressed dataset - internal interface. - * - * This is the abstract base class for the internal implementation. - * Users should use vpq_dataset which wraps this via PIMPL. - * - * @tparam MathT the type of elements in the codebooks - * @tparam IdxT type of the vector indices (represent dataset.extent(0)) - */ -template -class vpq_dataset_iface : public dataset { - public: - using index_type = IdxT; - using math_type = MathT; - - ~vpq_dataset_iface() override = default; - - /** Get view of VQ codebook. */ - [[nodiscard]] virtual auto vq_code_book() const noexcept - -> raft::device_matrix_view = 0; - - /** Get view of PQ codebook. */ - [[nodiscard]] virtual auto pq_code_book() const noexcept - -> raft::device_matrix_view = 0; - - /** Get view of compressed data. */ - [[nodiscard]] virtual auto data() const noexcept - -> raft::device_matrix_view = 0; - - // Derived properties - pure virtual - [[nodiscard]] virtual auto n_rows() const noexcept -> index_type = 0; - [[nodiscard]] virtual auto dim() const noexcept -> uint32_t = 0; - [[nodiscard]] virtual auto encoded_row_length() const noexcept -> uint32_t = 0; - [[nodiscard]] virtual auto vq_n_centers() const noexcept -> uint32_t = 0; - [[nodiscard]] virtual auto pq_bits() const noexcept -> uint32_t = 0; - [[nodiscard]] virtual auto pq_dim() const noexcept -> uint32_t = 0; - [[nodiscard]] virtual auto pq_len() const noexcept -> uint32_t = 0; - [[nodiscard]] virtual auto pq_n_centers() const noexcept -> uint32_t = 0; -}; - -/** - * @brief VPQ compressed dataset (PIMPL wrapper). - * - * The dataset is compressed using two level quantization: - * 1. Vector Quantization - * 2. Product Quantization of residuals - * - * This class wraps the internal implementation (vpq_dataset_owning or vpq_dataset_view) - * and provides a stable API. - * - * @tparam MathT the type of elements in the codebooks - * @tparam IdxT type of the vector indices (represent dataset.extent(0)) - */ -template -class vpq_dataset : public dataset { - public: - using index_type = IdxT; - using math_type = MathT; - - vpq_dataset() = default; - - /** Construct from an implementation. */ - explicit vpq_dataset(std::unique_ptr> impl) - : impl_{std::move(impl)} - { - } - - vpq_dataset(const vpq_dataset&) = delete; - vpq_dataset& operator=(const vpq_dataset&) = delete; - vpq_dataset(vpq_dataset&&) = default; - vpq_dataset& operator=(vpq_dataset&&) = default; - ~vpq_dataset() override = default; - - [[nodiscard]] auto n_rows() const noexcept -> index_type override { return impl_->n_rows(); } - [[nodiscard]] auto dim() const noexcept -> uint32_t override { return impl_->dim(); } - [[nodiscard]] auto is_owning() const noexcept -> bool final { return true; } - - /** Get view of VQ codebook. */ - [[nodiscard]] auto vq_code_book() const noexcept - -> raft::device_matrix_view - { - return impl_->vq_code_book(); - } - - /** Get view of PQ codebook. */ - [[nodiscard]] auto pq_code_book() const noexcept - -> raft::device_matrix_view - { - return impl_->pq_code_book(); - } - - /** Get view of compressed data. */ - [[nodiscard]] auto data() const noexcept - -> raft::device_matrix_view - { - return impl_->data(); - } - - /** Row length of the encoded data in bytes. */ - [[nodiscard]] auto encoded_row_length() const noexcept -> uint32_t - { - return impl_->encoded_row_length(); - } - - /** The number of "coarse cluster centers" */ - [[nodiscard]] auto vq_n_centers() const noexcept -> uint32_t { return impl_->vq_n_centers(); } - - /** The bit length of an encoded vector element after compression by PQ. */ - [[nodiscard]] auto pq_bits() const noexcept -> uint32_t { return impl_->pq_bits(); } - - /** The dimensionality of an encoded vector after compression by PQ. */ - [[nodiscard]] auto pq_dim() const noexcept -> uint32_t { return impl_->pq_dim(); } - - /** Dimensionality of a subspaces, i.e. the number of vector components mapped to a subspace */ - [[nodiscard]] auto pq_len() const noexcept -> uint32_t { return impl_->pq_len(); } - - /** The number of vectors in a PQ codebook (`1 << pq_bits`). */ - [[nodiscard]] auto pq_n_centers() const noexcept -> uint32_t { return impl_->pq_n_centers(); } - - private: - std::unique_ptr> impl_; -}; - -template -struct is_vpq_dataset : std::false_type {}; - -template -struct is_vpq_dataset> : std::true_type {}; - -template -inline constexpr bool is_vpq_dataset_v = is_vpq_dataset::value; namespace filtering { diff --git a/cpp/include/cuvs/preprocessing/quantize/pq.hpp b/cpp/include/cuvs/preprocessing/quantize/pq.hpp index e4eb371cbe..884261c08d 100644 --- a/cpp/include/cuvs/preprocessing/quantize/pq.hpp +++ b/cpp/include/cuvs/preprocessing/quantize/pq.hpp @@ -6,6 +6,7 @@ #pragma once #include +#include #include #include #include @@ -87,7 +88,7 @@ struct params { template struct quantizer { params params_quantizer; - cuvs::neighbors::vpq_dataset vpq_codebooks; + cuvs::preprocessing::quantize::pq::vpq_dataset vpq_codebooks; }; /** diff --git a/cpp/include/cuvs/preprocessing/quantize/vpq_dataset.hpp b/cpp/include/cuvs/preprocessing/quantize/vpq_dataset.hpp new file mode 100644 index 0000000000..4ce90418b9 --- /dev/null +++ b/cpp/include/cuvs/preprocessing/quantize/vpq_dataset.hpp @@ -0,0 +1,144 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + + #pragma once + + #include + +namespace cuvs::preprocessing::quantize::pq { + +/** + * @brief VPQ compressed dataset - internal interface. + * + * This is the abstract base class for the internal implementation. + * Users should use vpq_dataset which wraps this via PIMPL. + * + * @tparam MathT the type of elements in the codebooks + * @tparam IdxT type of the vector indices (represent dataset.extent(0)) + */ + template + class vpq_dataset_iface : public cuvs::neighbors::dataset { + public: + using index_type = IdxT; + using math_type = MathT; + + ~vpq_dataset_iface() override = default; + + /** Get view of VQ codebook. */ + [[nodiscard]] virtual auto vq_code_book() const noexcept + -> raft::device_matrix_view = 0; + + /** Get view of PQ codebook. */ + [[nodiscard]] virtual auto pq_code_book() const noexcept + -> raft::device_matrix_view = 0; + + /** Get view of compressed data. */ + [[nodiscard]] virtual auto data() const noexcept + -> raft::device_matrix_view = 0; + + // Derived properties - pure virtual + [[nodiscard]] virtual auto n_rows() const noexcept -> index_type = 0; + [[nodiscard]] virtual auto dim() const noexcept -> uint32_t = 0; + [[nodiscard]] virtual auto encoded_row_length() const noexcept -> uint32_t = 0; + [[nodiscard]] virtual auto vq_n_centers() const noexcept -> uint32_t = 0; + [[nodiscard]] virtual auto pq_bits() const noexcept -> uint32_t = 0; + [[nodiscard]] virtual auto pq_dim() const noexcept -> uint32_t = 0; + [[nodiscard]] virtual auto pq_len() const noexcept -> uint32_t = 0; + [[nodiscard]] virtual auto pq_n_centers() const noexcept -> uint32_t = 0; + }; + + /** + * @brief VPQ compressed dataset (PIMPL wrapper). + * + * The dataset is compressed using two level quantization: + * 1. Vector Quantization + * 2. Product Quantization of residuals + * + * This class wraps the internal implementation (vpq_dataset_owning or vpq_dataset_view) + * and provides a stable API. + * + * @tparam MathT the type of elements in the codebooks + * @tparam IdxT type of the vector indices (represent dataset.extent(0)) + */ + template + class vpq_dataset : public cuvs::neighbors::dataset { + public: + using index_type = IdxT; + using math_type = MathT; + + vpq_dataset() = default; + + /** Construct from an implementation. */ + explicit vpq_dataset(std::unique_ptr> impl) + : impl_{std::move(impl)} + { + } + + vpq_dataset(const vpq_dataset&) = delete; + vpq_dataset& operator=(const vpq_dataset&) = delete; + vpq_dataset(vpq_dataset&&) = default; + vpq_dataset& operator=(vpq_dataset&&) = default; + ~vpq_dataset() override = default; + + [[nodiscard]] auto n_rows() const noexcept -> index_type override { return impl_->n_rows(); } + [[nodiscard]] auto dim() const noexcept -> uint32_t override { return impl_->dim(); } + [[nodiscard]] auto is_owning() const noexcept -> bool final { return true; } + + /** Get view of VQ codebook. */ + [[nodiscard]] auto vq_code_book() const noexcept + -> raft::device_matrix_view + { + return impl_->vq_code_book(); + } + + /** Get view of PQ codebook. */ + [[nodiscard]] auto pq_code_book() const noexcept + -> raft::device_matrix_view + { + return impl_->pq_code_book(); + } + + /** Get view of compressed data. */ + [[nodiscard]] auto data() const noexcept + -> raft::device_matrix_view + { + return impl_->data(); + } + + /** Row length of the encoded data in bytes. */ + [[nodiscard]] auto encoded_row_length() const noexcept -> uint32_t + { + return impl_->encoded_row_length(); + } + + /** The number of "coarse cluster centers" */ + [[nodiscard]] auto vq_n_centers() const noexcept -> uint32_t { return impl_->vq_n_centers(); } + + /** The bit length of an encoded vector element after compression by PQ. */ + [[nodiscard]] auto pq_bits() const noexcept -> uint32_t { return impl_->pq_bits(); } + + /** The dimensionality of an encoded vector after compression by PQ. */ + [[nodiscard]] auto pq_dim() const noexcept -> uint32_t { return impl_->pq_dim(); } + + /** Dimensionality of a subspaces, i.e. the number of vector components mapped to a subspace */ + [[nodiscard]] auto pq_len() const noexcept -> uint32_t { return impl_->pq_len(); } + + /** The number of vectors in a PQ codebook (`1 << pq_bits`). */ + [[nodiscard]] auto pq_n_centers() const noexcept -> uint32_t { return impl_->pq_n_centers(); } + + private: + std::unique_ptr> impl_; + }; + + template + struct is_vpq_dataset : std::false_type {}; + + template + struct is_vpq_dataset> : std::true_type {}; + + template + inline constexpr bool is_vpq_dataset_v = is_vpq_dataset::value; + +} // namespace cuvs::preprocessing::quantize::pq diff --git a/cpp/src/neighbors/detail/cagra/cagra_search.cuh b/cpp/src/neighbors/detail/cagra/cagra_search.cuh index f1650980e0..9db2afbde0 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_search.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_search.cuh @@ -24,6 +24,7 @@ #include "../../ivf_common.cuh" #include "../../ivf_pq/ivf_pq_search.cuh" #include +#include // TODO: This shouldn't be calling spatial/knn apis #include "../ann_utils.cuh" @@ -173,11 +174,15 @@ void search_main(raft::resources const& res, neighbors, distances, sample_filter); - } else if (auto* vpq_dset = dynamic_cast*>(&index.data()); + } else if (auto* vpq_dset = dynamic_cast< + const cuvs::preprocessing::quantize::pq::vpq_dataset*>( + &index.data()); vpq_dset != nullptr) { // Search using a compressed dataset RAFT_FAIL("FP32 VPQ dataset support is coming soon"); - } else if (auto* vpq_dset = dynamic_cast*>(&index.data()); + } else if (auto* vpq_dset = dynamic_cast< + const cuvs::preprocessing::quantize::pq::vpq_dataset*>( + &index.data()); vpq_dset != nullptr) { auto desc = dataset_descriptor_init_with_cache( res, params, *vpq_dset, index.metric(), nullptr); diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_vpq.hpp b/cpp/src/neighbors/detail/cagra/compute_distance_vpq.hpp index 09afbad6ed..2d039d699e 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance_vpq.hpp +++ b/cpp/src/neighbors/detail/cagra/compute_distance_vpq.hpp @@ -8,6 +8,7 @@ #include "compute_distance.hpp" #include +#include #include @@ -31,14 +32,14 @@ struct vpq_descriptor_spec : public instance_spec { template constexpr static inline auto accepts_dataset() - -> std::enable_if_t, bool> + -> std::enable_if_t, bool> { return std::is_same_v; } template constexpr static inline auto accepts_dataset() - -> std::enable_if_t, bool> + -> std::enable_if_t, bool> { return false; } diff --git a/cpp/src/neighbors/detail/cagra/factory.cuh b/cpp/src/neighbors/detail/cagra/factory.cuh index a831327be7..0c971b3d63 100644 --- a/cpp/src/neighbors/detail/cagra/factory.cuh +++ b/cpp/src/neighbors/detail/cagra/factory.cuh @@ -104,7 +104,7 @@ template auto make_key(const cagra::search_params& params, const DatasetT& dataset, cuvs::distance::DistanceType metric) - -> std::enable_if_t, key> + -> std::enable_if_t, key> { return key{reinterpret_cast(dataset.data().data_handle()), uint64_t(dataset.n_rows()), diff --git a/cpp/src/neighbors/detail/dataset_serialize.hpp b/cpp/src/neighbors/detail/dataset_serialize.hpp index 758a1e9898..7b6b2cc084 100644 --- a/cpp/src/neighbors/detail/dataset_serialize.hpp +++ b/cpp/src/neighbors/detail/dataset_serialize.hpp @@ -6,6 +6,7 @@ #include "../../neighbors/vpq_dataset_impl.hpp" #include +#include #include #include @@ -60,7 +61,7 @@ void serialize(const raft::resources& res, template void serialize(const raft::resources& res, std::ostream& os, - const vpq_dataset& dataset) + const cuvs::preprocessing::quantize::pq::vpq_dataset& dataset) { raft::serialize_scalar(res, os, dataset.n_rows()); raft::serialize_scalar(res, os, dataset.dim()); @@ -100,12 +101,16 @@ void serialize(const raft::resources& res, std::ostream& os, const dataset raft::serialize_scalar(res, os, CUDA_R_8U); return serialize(res, os, *x); } - if (auto x = dynamic_cast*>(&dataset); x != nullptr) { + if (auto x = + dynamic_cast*>(&dataset); + x != nullptr) { raft::serialize_scalar(res, os, kSerializeVPQDataset); raft::serialize_scalar(res, os, CUDA_R_32F); return serialize(res, os, *x); } - if (auto x = dynamic_cast*>(&dataset); x != nullptr) { + if (auto x = + dynamic_cast*>(&dataset); + x != nullptr) { raft::serialize_scalar(res, os, kSerializeVPQDataset); raft::serialize_scalar(res, os, CUDA_R_16F); return serialize(res, os, *x); @@ -135,7 +140,7 @@ auto deserialize_strided(raft::resources const& res, std::istream& is) template auto deserialize_vpq(raft::resources const& res, std::istream& is) - -> std::unique_ptr> + -> std::unique_ptr> { auto n_rows = raft::deserialize_scalar(res, is); auto dim = raft::deserialize_scalar(res, is); @@ -155,7 +160,7 @@ auto deserialize_vpq(raft::resources const& res, std::istream& is) raft::deserialize_mdspan(res, is, pq_code_book.view()); raft::deserialize_mdspan(res, is, data.view()); - return std::make_unique>( + return std::make_unique>( std::make_unique>( std::move(vq_code_book), std::move(pq_code_book), std::move(data))); } diff --git a/cpp/src/neighbors/detail/vamana/vamana_build.cuh b/cpp/src/neighbors/detail/vamana/vamana_build.cuh index 0b768dc02e..e28c1db1fc 100644 --- a/cpp/src/neighbors/detail/vamana/vamana_build.cuh +++ b/cpp/src/neighbors/detail/vamana/vamana_build.cuh @@ -647,7 +647,7 @@ index build( auto quantizer = cuvs::preprocessing::quantize::pq::quantizer( pq_params, - cuvs::neighbors::vpq_dataset{ + cuvs::preprocessing::quantize::pq::vpq_dataset{ std::make_unique>( raft::make_device_matrix(res, 0, 0), std::move(pq_codebook), diff --git a/cpp/src/neighbors/detail/vpq_dataset.cuh b/cpp/src/neighbors/detail/vpq_dataset.cuh index badeadeeab..a73b58d383 100644 --- a/cpp/src/neighbors/detail/vpq_dataset.cuh +++ b/cpp/src/neighbors/detail/vpq_dataset.cuh @@ -412,7 +412,7 @@ void process_and_fill_codes( bool inline_vq_labels = false) { using data_t = typename DatasetT::value_type; - using cdataset_t = vpq_dataset; + using cdataset_t = cuvs::preprocessing::quantize::pq::vpq_dataset; using label_t = uint32_t; const ix_t n_rows = dataset.extent(0); @@ -505,8 +505,9 @@ void process_and_fill_codes( } template -auto vpq_convert_math_type(const raft::resources& res, vpq_dataset&& src) - -> vpq_dataset +auto vpq_convert_math_type(const raft::resources& res, + cuvs::preprocessing::quantize::pq::vpq_dataset&& src) + -> cuvs::preprocessing::quantize::pq::vpq_dataset { auto vq_code_book = raft::make_device_mdarray(res, src.vq_code_book().extents()); auto pq_code_book = raft::make_device_mdarray(res, src.pq_code_book().extents()); @@ -529,8 +530,9 @@ auto vpq_convert_math_type(const raft::resources& res, vpq_dataset{std::make_unique>( - std::move(vq_code_book), std::move(pq_code_book), std::move(data))}; + return cuvs::preprocessing::quantize::pq::vpq_dataset{ + std::make_unique>( + std::move(vq_code_book), std::move(pq_code_book), std::move(data))}; } // Helper for operations using vectorized loads of raft::TxN_t @@ -808,7 +810,7 @@ void process_and_fill_codes_subspaces( raft::device_matrix_view codes) { using data_t = typename DatasetT::value_type; - using cdataset_t = vpq_dataset; + using cdataset_t = cuvs::preprocessing::quantize::pq::vpq_dataset; using label_t = uint32_t; const ix_t n_rows = dataset.extent(0); @@ -913,7 +915,7 @@ void process_and_fill_codes_subspaces( template auto vpq_build(const raft::resources& res, const vpq_params& params, const DatasetT& dataset) - -> vpq_dataset + -> cuvs::preprocessing::quantize::pq::vpq_dataset { using label_t = uint32_t; // Use a heuristic to impute missing parameters. @@ -939,8 +941,9 @@ auto vpq_build(const raft::resources& res, const vpq_params& params, const Datas codes.view(), true); - return vpq_dataset{std::make_unique>( - std::move(vq_code_book), std::move(pq_code_book), std::move(codes))}; + return cuvs::preprocessing::quantize::pq::vpq_dataset{ + std::make_unique>( + std::move(vq_code_book), std::move(pq_code_book), std::move(codes))}; } } // namespace cuvs::neighbors::detail diff --git a/cpp/src/neighbors/scann/detail/scann_quantize.cuh b/cpp/src/neighbors/scann/detail/scann_quantize.cuh index 16ef1f4295..b8b128679a 100644 --- a/cpp/src/neighbors/scann/detail/scann_quantize.cuh +++ b/cpp/src/neighbors/scann/detail/scann_quantize.cuh @@ -75,7 +75,7 @@ auto process_and_fill_codes_subspaces( -> raft::device_matrix { using data_t = typename DatasetT::value_type; - using cdataset_t = vpq_dataset; + using cdataset_t = cuvs::preprocessing::quantize::pq::vpq_dataset; using label_t = uint32_t; const ix_t n_rows = dataset.extent(0); diff --git a/cpp/src/neighbors/vpq_dataset.cuh b/cpp/src/neighbors/vpq_dataset.cuh index cfcef97b4f..d3c1312dfc 100644 --- a/cpp/src/neighbors/vpq_dataset.cuh +++ b/cpp/src/neighbors/vpq_dataset.cuh @@ -29,7 +29,7 @@ template auto vpq_build(const raft::resources& res, const vpq_params& params, const DatasetT& dataset) - -> vpq_dataset + -> cuvs::preprocessing::quantize::pq::vpq_dataset { if constexpr (std::is_same_v) { return detail::vpq_convert_math_type( diff --git a/cpp/src/neighbors/vpq_dataset_impl.hpp b/cpp/src/neighbors/vpq_dataset_impl.hpp index e6bf96a522..e6b9cac47e 100644 --- a/cpp/src/neighbors/vpq_dataset_impl.hpp +++ b/cpp/src/neighbors/vpq_dataset_impl.hpp @@ -5,14 +5,10 @@ #pragma once -#include +#include #include -#ifdef __cpp_lib_bitops -#include -#endif - namespace cuvs::neighbors { /** @@ -25,7 +21,7 @@ namespace cuvs::neighbors { * @tparam IdxT type of the vector indices */ template -class vpq_dataset_impl : public vpq_dataset_iface { +class vpq_dataset_impl : public cuvs::preprocessing::quantize::pq::vpq_dataset_iface { public: using index_type = IdxT; using math_type = MathT; diff --git a/cpp/src/preprocessing/quantize/detail/pq.cuh b/cpp/src/preprocessing/quantize/detail/pq.cuh index 6f08e01655..bed2952717 100644 --- a/cpp/src/preprocessing/quantize/detail/pq.cuh +++ b/cpp/src/preprocessing/quantize/detail/pq.cuh @@ -158,7 +158,7 @@ quantizer build( res, vpq_params, dataset, raft::make_const_mdspan(vq_code_book.view())); } return {filled_params, - cuvs::neighbors::vpq_dataset{ + cuvs::preprocessing::quantize::pq::vpq_dataset{ std::make_unique>( std::move(vq_code_book), std::move(pq_code_book), std::move(empty_codes))}}; } @@ -207,7 +207,7 @@ quantizer build_view( // Create view-type vpq_dataset auto empty_data = raft::make_device_matrix(res, 0, 0); return {params, - cuvs::neighbors::vpq_dataset{ + cuvs::preprocessing::quantize::pq::vpq_dataset{ std::make_unique>( vq_centers, pq_centers, std::move(empty_data))}}; } @@ -232,7 +232,6 @@ void transform( RAFT_EXPECTS(quantizer.params_quantizer.pq_bits >= 4 && quantizer.params_quantizer.pq_bits <= 16, "PQ bits must be within [4, 16]"); - // Use view accessors from vpq_dataset auto vq_centers = quantizer.vpq_codebooks.vq_code_book(); auto pq_centers = quantizer.vpq_codebooks.pq_code_book(); auto vq_labels_view = raft::make_device_vector_view(nullptr, 0); @@ -374,7 +373,6 @@ void inverse_transform( RAFT_EXPECTS(quantizer.params_quantizer.pq_bits >= 4 && quantizer.params_quantizer.pq_bits <= 16, "PQ bits must be within [4, 16]"); - // Use view accessors from vpq_dataset reconstruct_vectors(res, quantizer.params_quantizer, codes, From a8b3ce43476827f8cdbf651dd2a70c206511c34b Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Thu, 19 Mar 2026 15:46:32 -0700 Subject: [PATCH 11/36] fix style --- .../preprocessing/quantize/vpq_dataset.hpp | 248 +++++++++--------- 1 file changed, 124 insertions(+), 124 deletions(-) diff --git a/cpp/include/cuvs/preprocessing/quantize/vpq_dataset.hpp b/cpp/include/cuvs/preprocessing/quantize/vpq_dataset.hpp index 4ce90418b9..bd033c1b74 100644 --- a/cpp/include/cuvs/preprocessing/quantize/vpq_dataset.hpp +++ b/cpp/include/cuvs/preprocessing/quantize/vpq_dataset.hpp @@ -3,9 +3,9 @@ * SPDX-License-Identifier: Apache-2.0 */ - #pragma once +#pragma once - #include +#include namespace cuvs::preprocessing::quantize::pq { @@ -18,127 +18,127 @@ namespace cuvs::preprocessing::quantize::pq { * @tparam MathT the type of elements in the codebooks * @tparam IdxT type of the vector indices (represent dataset.extent(0)) */ - template - class vpq_dataset_iface : public cuvs::neighbors::dataset { - public: - using index_type = IdxT; - using math_type = MathT; - - ~vpq_dataset_iface() override = default; - - /** Get view of VQ codebook. */ - [[nodiscard]] virtual auto vq_code_book() const noexcept - -> raft::device_matrix_view = 0; - - /** Get view of PQ codebook. */ - [[nodiscard]] virtual auto pq_code_book() const noexcept - -> raft::device_matrix_view = 0; - - /** Get view of compressed data. */ - [[nodiscard]] virtual auto data() const noexcept - -> raft::device_matrix_view = 0; - - // Derived properties - pure virtual - [[nodiscard]] virtual auto n_rows() const noexcept -> index_type = 0; - [[nodiscard]] virtual auto dim() const noexcept -> uint32_t = 0; - [[nodiscard]] virtual auto encoded_row_length() const noexcept -> uint32_t = 0; - [[nodiscard]] virtual auto vq_n_centers() const noexcept -> uint32_t = 0; - [[nodiscard]] virtual auto pq_bits() const noexcept -> uint32_t = 0; - [[nodiscard]] virtual auto pq_dim() const noexcept -> uint32_t = 0; - [[nodiscard]] virtual auto pq_len() const noexcept -> uint32_t = 0; - [[nodiscard]] virtual auto pq_n_centers() const noexcept -> uint32_t = 0; - }; - - /** - * @brief VPQ compressed dataset (PIMPL wrapper). - * - * The dataset is compressed using two level quantization: - * 1. Vector Quantization - * 2. Product Quantization of residuals - * - * This class wraps the internal implementation (vpq_dataset_owning or vpq_dataset_view) - * and provides a stable API. - * - * @tparam MathT the type of elements in the codebooks - * @tparam IdxT type of the vector indices (represent dataset.extent(0)) - */ - template - class vpq_dataset : public cuvs::neighbors::dataset { - public: - using index_type = IdxT; - using math_type = MathT; - - vpq_dataset() = default; - - /** Construct from an implementation. */ - explicit vpq_dataset(std::unique_ptr> impl) - : impl_{std::move(impl)} - { - } - - vpq_dataset(const vpq_dataset&) = delete; - vpq_dataset& operator=(const vpq_dataset&) = delete; - vpq_dataset(vpq_dataset&&) = default; - vpq_dataset& operator=(vpq_dataset&&) = default; - ~vpq_dataset() override = default; - - [[nodiscard]] auto n_rows() const noexcept -> index_type override { return impl_->n_rows(); } - [[nodiscard]] auto dim() const noexcept -> uint32_t override { return impl_->dim(); } - [[nodiscard]] auto is_owning() const noexcept -> bool final { return true; } - - /** Get view of VQ codebook. */ - [[nodiscard]] auto vq_code_book() const noexcept - -> raft::device_matrix_view - { - return impl_->vq_code_book(); - } - - /** Get view of PQ codebook. */ - [[nodiscard]] auto pq_code_book() const noexcept - -> raft::device_matrix_view - { - return impl_->pq_code_book(); - } - - /** Get view of compressed data. */ - [[nodiscard]] auto data() const noexcept - -> raft::device_matrix_view - { - return impl_->data(); - } - - /** Row length of the encoded data in bytes. */ - [[nodiscard]] auto encoded_row_length() const noexcept -> uint32_t - { - return impl_->encoded_row_length(); - } - - /** The number of "coarse cluster centers" */ - [[nodiscard]] auto vq_n_centers() const noexcept -> uint32_t { return impl_->vq_n_centers(); } - - /** The bit length of an encoded vector element after compression by PQ. */ - [[nodiscard]] auto pq_bits() const noexcept -> uint32_t { return impl_->pq_bits(); } - - /** The dimensionality of an encoded vector after compression by PQ. */ - [[nodiscard]] auto pq_dim() const noexcept -> uint32_t { return impl_->pq_dim(); } - - /** Dimensionality of a subspaces, i.e. the number of vector components mapped to a subspace */ - [[nodiscard]] auto pq_len() const noexcept -> uint32_t { return impl_->pq_len(); } - - /** The number of vectors in a PQ codebook (`1 << pq_bits`). */ - [[nodiscard]] auto pq_n_centers() const noexcept -> uint32_t { return impl_->pq_n_centers(); } - - private: - std::unique_ptr> impl_; - }; - - template - struct is_vpq_dataset : std::false_type {}; - - template - struct is_vpq_dataset> : std::true_type {}; - - template - inline constexpr bool is_vpq_dataset_v = is_vpq_dataset::value; +template +class vpq_dataset_iface : public cuvs::neighbors::dataset { + public: + using index_type = IdxT; + using math_type = MathT; + + ~vpq_dataset_iface() override = default; + + /** Get view of VQ codebook. */ + [[nodiscard]] virtual auto vq_code_book() const noexcept + -> raft::device_matrix_view = 0; + + /** Get view of PQ codebook. */ + [[nodiscard]] virtual auto pq_code_book() const noexcept + -> raft::device_matrix_view = 0; + + /** Get view of compressed data. */ + [[nodiscard]] virtual auto data() const noexcept + -> raft::device_matrix_view = 0; + + // Derived properties - pure virtual + [[nodiscard]] virtual auto n_rows() const noexcept -> index_type = 0; + [[nodiscard]] virtual auto dim() const noexcept -> uint32_t = 0; + [[nodiscard]] virtual auto encoded_row_length() const noexcept -> uint32_t = 0; + [[nodiscard]] virtual auto vq_n_centers() const noexcept -> uint32_t = 0; + [[nodiscard]] virtual auto pq_bits() const noexcept -> uint32_t = 0; + [[nodiscard]] virtual auto pq_dim() const noexcept -> uint32_t = 0; + [[nodiscard]] virtual auto pq_len() const noexcept -> uint32_t = 0; + [[nodiscard]] virtual auto pq_n_centers() const noexcept -> uint32_t = 0; +}; + +/** + * @brief VPQ compressed dataset (PIMPL wrapper). + * + * The dataset is compressed using two level quantization: + * 1. Vector Quantization + * 2. Product Quantization of residuals + * + * This class wraps the internal implementation (vpq_dataset_owning or vpq_dataset_view) + * and provides a stable API. + * + * @tparam MathT the type of elements in the codebooks + * @tparam IdxT type of the vector indices (represent dataset.extent(0)) + */ +template +class vpq_dataset : public cuvs::neighbors::dataset { + public: + using index_type = IdxT; + using math_type = MathT; + + vpq_dataset() = default; + + /** Construct from an implementation. */ + explicit vpq_dataset(std::unique_ptr> impl) + : impl_{std::move(impl)} + { + } + + vpq_dataset(const vpq_dataset&) = delete; + vpq_dataset& operator=(const vpq_dataset&) = delete; + vpq_dataset(vpq_dataset&&) = default; + vpq_dataset& operator=(vpq_dataset&&) = default; + ~vpq_dataset() override = default; + + [[nodiscard]] auto n_rows() const noexcept -> index_type override { return impl_->n_rows(); } + [[nodiscard]] auto dim() const noexcept -> uint32_t override { return impl_->dim(); } + [[nodiscard]] auto is_owning() const noexcept -> bool final { return true; } + + /** Get view of VQ codebook. */ + [[nodiscard]] auto vq_code_book() const noexcept + -> raft::device_matrix_view + { + return impl_->vq_code_book(); + } + + /** Get view of PQ codebook. */ + [[nodiscard]] auto pq_code_book() const noexcept + -> raft::device_matrix_view + { + return impl_->pq_code_book(); + } + + /** Get view of compressed data. */ + [[nodiscard]] auto data() const noexcept + -> raft::device_matrix_view + { + return impl_->data(); + } + + /** Row length of the encoded data in bytes. */ + [[nodiscard]] auto encoded_row_length() const noexcept -> uint32_t + { + return impl_->encoded_row_length(); + } + + /** The number of "coarse cluster centers" */ + [[nodiscard]] auto vq_n_centers() const noexcept -> uint32_t { return impl_->vq_n_centers(); } + + /** The bit length of an encoded vector element after compression by PQ. */ + [[nodiscard]] auto pq_bits() const noexcept -> uint32_t { return impl_->pq_bits(); } + + /** The dimensionality of an encoded vector after compression by PQ. */ + [[nodiscard]] auto pq_dim() const noexcept -> uint32_t { return impl_->pq_dim(); } + + /** Dimensionality of a subspaces, i.e. the number of vector components mapped to a subspace */ + [[nodiscard]] auto pq_len() const noexcept -> uint32_t { return impl_->pq_len(); } + + /** The number of vectors in a PQ codebook (`1 << pq_bits`). */ + [[nodiscard]] auto pq_n_centers() const noexcept -> uint32_t { return impl_->pq_n_centers(); } + + private: + std::unique_ptr> impl_; +}; + +template +struct is_vpq_dataset : std::false_type {}; + +template +struct is_vpq_dataset> : std::true_type {}; + +template +inline constexpr bool is_vpq_dataset_v = is_vpq_dataset::value; } // namespace cuvs::preprocessing::quantize::pq From f8432b5df16f985f4352a3e374ecb1189856dd69 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Fri, 20 Mar 2026 11:55:40 -0700 Subject: [PATCH 12/36] fix the signature --- .../cuvs/preprocessing/quantize/pq.hpp | 22 +++-- cpp/src/preprocessing/quantize/detail/pq.cuh | 15 ++- cpp/src/preprocessing/quantize/pq.cu | 91 ++++++++++--------- 3 files changed, 70 insertions(+), 58 deletions(-) diff --git a/cpp/include/cuvs/preprocessing/quantize/pq.hpp b/cpp/include/cuvs/preprocessing/quantize/pq.hpp index 884261c08d..ef490c905c 100644 --- a/cpp/include/cuvs/preprocessing/quantize/pq.hpp +++ b/cpp/include/cuvs/preprocessing/quantize/pq.hpp @@ -134,9 +134,13 @@ quantizer build(raft::resources const& res, * params.pq_dim = 32; * params.use_vq = true; * params.use_subspaces = true; + * // With VQ centers: * auto quant_view = cuvs::preprocessing::quantize::pq::build(handle, params, - * pq_centers_view, vq_centers_view); - * // Use quant_view for transform/inverse_transform operations + * pq_centers_view, + * std::make_optional>(vq_centers_view)); + * // Without VQ (PQ only): + * auto quant_pq_only = cuvs::preprocessing::quantize::pq::build(handle, params, pq_centers_view); * @endcode * * @param[in] res raft resource @@ -146,15 +150,17 @@ quantizer build(raft::resources const& res, * - For use_subspaces=true: [pq_dim * pq_n_centers, pq_len] * - For use_subspaces=false: [pq_n_centers, pq_len] * where pq_n_centers = (1 << pq_bits), pq_len = dim / pq_dim - * @param[in] vq_centers VQ codebook on device memory [vq_n_centers, dim]. - * Pass an empty view if use_vq=false. + * @param[in] vq_centers Optional VQ codebook on device memory [vq_n_centers, dim]. + * Required when use_vq=true. Defaults to std::nullopt (no VQ). * * @return A view-type quantizer that references the provided data */ -quantizer build(raft::resources const& res, - const params params, - raft::device_matrix_view pq_centers, - raft::device_matrix_view vq_centers); +quantizer build( + raft::resources const& res, + const params params, + raft::device_matrix_view pq_centers, + std::optional> vq_centers = + std::nullopt); /** * @brief Applies quantization transform to given dataset diff --git a/cpp/src/preprocessing/quantize/detail/pq.cuh b/cpp/src/preprocessing/quantize/detail/pq.cuh index bed2952717..2c01e5746f 100644 --- a/cpp/src/preprocessing/quantize/detail/pq.cuh +++ b/cpp/src/preprocessing/quantize/detail/pq.cuh @@ -171,7 +171,8 @@ quantizer build_view( raft::resources const& res, const params& params, raft::device_matrix_view pq_centers, - raft::device_matrix_view vq_centers) + std::optional> vq_centers = + std::nullopt) { // Validate parameters RAFT_EXPECTS(params.pq_bits >= 4 && params.pq_bits <= 16, @@ -198,18 +199,22 @@ quantizer build_view( // Validate VQ centers if (params.use_vq) { - RAFT_EXPECTS(!vq_centers.empty(), "vq_centers must be provided when use_vq=true"); - RAFT_EXPECTS(vq_centers.extent(0) == params.vq_n_centers, + RAFT_EXPECTS(vq_centers.has_value(), "vq_centers must be provided when use_vq=true"); + RAFT_EXPECTS(vq_centers.value().extent(0) == params.vq_n_centers, "vq_centers must have vq_n_centers rows, got %u", - vq_centers.extent(0)); + vq_centers.value().extent(0)); } // Create view-type vpq_dataset auto empty_data = raft::make_device_matrix(res, 0, 0); + auto vq_view = + vq_centers.has_value() + ? vq_centers.value() + : raft::make_device_matrix_view(nullptr, 0, 0); return {params, cuvs::preprocessing::quantize::pq::vpq_dataset{ std::make_unique>( - vq_centers, pq_centers, std::move(empty_data))}}; + vq_view, pq_centers, std::move(empty_data))}}; } template diff --git a/cpp/src/preprocessing/quantize/pq.cu b/cpp/src/preprocessing/quantize/pq.cu index 22939b96d6..19c6897528 100644 --- a/cpp/src/preprocessing/quantize/pq.cu +++ b/cpp/src/preprocessing/quantize/pq.cu @@ -9,51 +9,52 @@ namespace cuvs::preprocessing::quantize::pq { -#define CUVS_INST_QUANTIZATION(T, QuantI) \ - auto build(raft::resources const& res, \ - const params params, \ - raft::device_matrix_view dataset) -> quantizer \ - { \ - return detail::build(res, params, dataset); \ - } \ - auto build(raft::resources const& res, \ - const params params, \ - raft::host_matrix_view dataset) -> quantizer \ - { \ - return detail::build(res, params, dataset); \ - } \ - auto build(raft::resources const& res, \ - const params params, \ - raft::device_matrix_view pq_centers, \ - raft::device_matrix_view vq_centers) \ - -> quantizer \ - { \ - return detail::build_view(res, params, pq_centers, vq_centers); \ - } \ - void transform(raft::resources const& res, \ - const quantizer& quantizer, \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view codes_out, \ - std::optional> vq_labels) \ - { \ - detail::transform(res, quantizer, dataset, codes_out, vq_labels); \ - } \ - void transform(raft::resources const& res, \ - const quantizer& quantizer, \ - raft::host_matrix_view dataset, \ - raft::device_matrix_view codes_out, \ - std::optional> vq_labels) \ - { \ - detail::transform(res, quantizer, dataset, codes_out, vq_labels); \ - } \ - void inverse_transform( \ - raft::resources const& res, \ - const quantizer& quantizer, \ - raft::device_matrix_view pq_codes, \ - raft::device_matrix_view out, \ - std::optional> vq_labels) \ - { \ - detail::inverse_transform(res, quantizer, pq_codes, out, vq_labels); \ +#define CUVS_INST_QUANTIZATION(T, QuantI) \ + auto build(raft::resources const& res, \ + const params params, \ + raft::device_matrix_view dataset) -> quantizer \ + { \ + return detail::build(res, params, dataset); \ + } \ + auto build(raft::resources const& res, \ + const params params, \ + raft::host_matrix_view dataset) -> quantizer \ + { \ + return detail::build(res, params, dataset); \ + } \ + auto build( \ + raft::resources const& res, \ + const params params, \ + raft::device_matrix_view pq_centers, \ + std::optional> vq_centers) \ + -> quantizer \ + { \ + return detail::build_view(res, params, pq_centers, vq_centers); \ + } \ + void transform(raft::resources const& res, \ + const quantizer& quantizer, \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view codes_out, \ + std::optional> vq_labels) \ + { \ + detail::transform(res, quantizer, dataset, codes_out, vq_labels); \ + } \ + void transform(raft::resources const& res, \ + const quantizer& quantizer, \ + raft::host_matrix_view dataset, \ + raft::device_matrix_view codes_out, \ + std::optional> vq_labels) \ + { \ + detail::transform(res, quantizer, dataset, codes_out, vq_labels); \ + } \ + void inverse_transform( \ + raft::resources const& res, \ + const quantizer& quantizer, \ + raft::device_matrix_view pq_codes, \ + raft::device_matrix_view out, \ + std::optional> vq_labels) \ + { \ + detail::inverse_transform(res, quantizer, pq_codes, out, vq_labels); \ } CUVS_INST_QUANTIZATION(float, uint8_t); From a15e0541d5bc0f683c231f109829e82dde6490b6 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Fri, 20 Mar 2026 12:07:49 -0700 Subject: [PATCH 13/36] addtogroup --- cpp/include/cuvs/preprocessing/quantize/vpq_dataset.hpp | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/cpp/include/cuvs/preprocessing/quantize/vpq_dataset.hpp b/cpp/include/cuvs/preprocessing/quantize/vpq_dataset.hpp index bd033c1b74..d95c2dde10 100644 --- a/cpp/include/cuvs/preprocessing/quantize/vpq_dataset.hpp +++ b/cpp/include/cuvs/preprocessing/quantize/vpq_dataset.hpp @@ -49,6 +49,11 @@ class vpq_dataset_iface : public cuvs::neighbors::dataset { [[nodiscard]] virtual auto pq_n_centers() const noexcept -> uint32_t = 0; }; +/** + * @addtogroup pq + * @{ + */ + /** * @brief VPQ compressed dataset (PIMPL wrapper). * @@ -141,4 +146,6 @@ struct is_vpq_dataset> : std::true_type {}; template inline constexpr bool is_vpq_dataset_v = is_vpq_dataset::value; +/** @} */ // end of group pq + } // namespace cuvs::preprocessing::quantize::pq From 34ce8ffed7e538967aa03ce33d0ab14f9de5f73f Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Wed, 25 Mar 2026 11:09:53 -0700 Subject: [PATCH 14/36] sync stream after vamana build --- cpp/bench/ann/src/cuvs/cuvs_vamana_wrapper.h | 2 ++ cpp/bench/ann/src/diskann/diskann_benchmark.cpp | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/cpp/bench/ann/src/cuvs/cuvs_vamana_wrapper.h b/cpp/bench/ann/src/cuvs/cuvs_vamana_wrapper.h index 7d7292bd50..6897416a3c 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_vamana_wrapper.h +++ b/cpp/bench/ann/src/cuvs/cuvs_vamana_wrapper.h @@ -13,6 +13,7 @@ #include #include #include +#include namespace cuvs::bench { @@ -82,6 +83,7 @@ void cuvs_vamana::build(const T* dataset, size_t nrow) dataset_is_on_host ? cuvs::neighbors::vamana::build(handle_, vamana_index_params_, dataset_view_host) : cuvs::neighbors::vamana::build(handle_, vamana_index_params_, dataset_view_device))); + raft::resource::sync_stream(handle_); } template diff --git a/cpp/bench/ann/src/diskann/diskann_benchmark.cpp b/cpp/bench/ann/src/diskann/diskann_benchmark.cpp index 4129f3cdab..cc437389cb 100644 --- a/cpp/bench/ann/src/diskann/diskann_benchmark.cpp +++ b/cpp/bench/ann/src/diskann/diskann_benchmark.cpp @@ -26,7 +26,7 @@ void parse_build_param(const nlohmann::json& conf, { param.R = conf.at("R"); param.L_build = conf.at("L_build"); - if (conf.contains("alpha")) { param.num_threads = conf.at("alpha"); } + if (conf.contains("alpha")) { param.alpha = conf.at("alpha"); } if (conf.contains("num_threads")) { param.num_threads = conf.at("num_threads"); } } From 2c1aa7138db3844746e5ed2e6fe2b33ce4c3dcea Mon Sep 17 00:00:00 2001 From: Tarang Jain <40517122+tarang-jain@users.noreply.github.com> Date: Thu, 2 Apr 2026 18:30:34 -0700 Subject: [PATCH 15/36] Update cpp/include/cuvs/preprocessing/quantize/pq.hpp Co-authored-by: Tamas Bela Feher --- cpp/include/cuvs/preprocessing/quantize/pq.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/cuvs/preprocessing/quantize/pq.hpp b/cpp/include/cuvs/preprocessing/quantize/pq.hpp index ef490c905c..f355a7f29a 100644 --- a/cpp/include/cuvs/preprocessing/quantize/pq.hpp +++ b/cpp/include/cuvs/preprocessing/quantize/pq.hpp @@ -121,7 +121,7 @@ quantizer build(raft::resources const& res, raft::host_matrix_view dataset); /** - * @brief Creates a view-type product quantizer from pre-computed codebooks. + * @brief Creates a product quantizer from pre-computed codebooks. * * This function creates a non-owning quantizer that references the provided codebooks. * From d6a836457667e060cc065b636a86bc676b218802 Mon Sep 17 00:00:00 2001 From: Tarang Jain <40517122+tarang-jain@users.noreply.github.com> Date: Thu, 2 Apr 2026 18:30:54 -0700 Subject: [PATCH 16/36] Update cpp/src/preprocessing/quantize/detail/pq.cuh Co-authored-by: Tamas Bela Feher --- cpp/src/preprocessing/quantize/detail/pq.cuh | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/cpp/src/preprocessing/quantize/detail/pq.cuh b/cpp/src/preprocessing/quantize/detail/pq.cuh index 2c01e5746f..0d01a53012 100644 --- a/cpp/src/preprocessing/quantize/detail/pq.cuh +++ b/cpp/src/preprocessing/quantize/detail/pq.cuh @@ -164,7 +164,10 @@ quantizer build( } /** - * @brief Build a view-type quantizer from pre-computed codebooks. + * @brief Build a quantizer from pre-computed codebooks. + * + * The quantizer does not own the codebook arrays + */ template quantizer build_view( From 68f016dc47f5b1986081e610cfb614427f08ade7 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Thu, 2 Apr 2026 18:45:07 -0700 Subject: [PATCH 17/36] update namespace --- .../cuvs/preprocessing/quantize/pq.hpp | 2 +- .../neighbors/detail/vamana/vamana_build.cuh | 2 +- cpp/src/preprocessing/quantize/detail/pq.cuh | 28 +++++++++---------- .../preprocessing/quantize/vpq_build-ext.cuh | 4 +-- cpp/tests/neighbors/ann_scann.cuh | 2 +- 5 files changed, 19 insertions(+), 19 deletions(-) diff --git a/cpp/include/cuvs/preprocessing/quantize/pq.hpp b/cpp/include/cuvs/preprocessing/quantize/pq.hpp index 567a55eec6..e5ffd8931d 100644 --- a/cpp/include/cuvs/preprocessing/quantize/pq.hpp +++ b/cpp/include/cuvs/preprocessing/quantize/pq.hpp @@ -90,7 +90,7 @@ struct quantizer { /** Parameters used to build this quantizer. */ params params_quantizer; /** VPQ codebooks produced during training. */ - cuvs::neighbors::vpq_dataset vpq_codebooks; + cuvs::preprocessing::quantize::pq::vpq_dataset vpq_codebooks; }; /** diff --git a/cpp/src/neighbors/detail/vamana/vamana_build.cuh b/cpp/src/neighbors/detail/vamana/vamana_build.cuh index 030860f1cd..0fbe395b04 100644 --- a/cpp/src/neighbors/detail/vamana/vamana_build.cuh +++ b/cpp/src/neighbors/detail/vamana/vamana_build.cuh @@ -648,7 +648,7 @@ index build( auto quantizer = cuvs::preprocessing::quantize::pq::quantizer( pq_params, cuvs::preprocessing::quantize::pq::vpq_dataset{ - std::make_unique>( + std::make_unique>( raft::make_device_matrix(res, 0, 0), std::move(pq_codebook), raft::make_device_matrix(res, 0, 0))}); diff --git a/cpp/src/preprocessing/quantize/detail/pq.cuh b/cpp/src/preprocessing/quantize/detail/pq.cuh index 4e40b36d17..e207bea5f2 100644 --- a/cpp/src/preprocessing/quantize/detail/pq.cuh +++ b/cpp/src/preprocessing/quantize/detail/pq.cuh @@ -159,7 +159,7 @@ quantizer build( } return {filled_params, cuvs::preprocessing::quantize::pq::vpq_dataset{ - std::make_unique>( + std::make_unique>( std::move(vq_code_book), std::move(pq_code_book), std::move(empty_codes))}}; } @@ -216,7 +216,7 @@ quantizer build_view( : raft::make_device_matrix_view(nullptr, 0, 0); return {params, cuvs::preprocessing::quantize::pq::vpq_dataset{ - std::make_unique>( + std::make_unique>( vq_view, pq_centers, std::move(empty_data))}}; } @@ -393,8 +393,8 @@ void inverse_transform( template void vpq_convert_math_type(const raft::resources& res, - const cuvs::neighbors::vpq_dataset& src, - cuvs::neighbors::vpq_dataset& dst) + const cuvs::preprocessing::quantize::pq::vpq_dataset& src, + cuvs::preprocessing::quantize::pq::vpq_dataset& dst) { raft::linalg::map(res, dst.vq_code_book.view(), @@ -409,7 +409,7 @@ void vpq_convert_math_type(const raft::resources& res, template auto vpq_build(const raft::resources& res, const cuvs::neighbors::vpq_params& params, - const DatasetT& dataset) -> cuvs::neighbors::vpq_dataset + const DatasetT& dataset) -> cuvs::preprocessing::quantize::pq::vpq_dataset { using label_t = uint32_t; // Use a heuristic to impute missing parameters. @@ -436,17 +436,17 @@ auto vpq_build(const raft::resources& res, codes.view(), true); - return cuvs::neighbors::vpq_dataset{ + return cuvs::preprocessing::quantize::pq::vpq_dataset{ std::move(vq_code_book), std::move(pq_code_book), std::move(codes)}; } template auto vpq_build_half(const raft::resources& res, const cuvs::neighbors::vpq_params& params, - const DatasetT& dataset) -> cuvs::neighbors::vpq_dataset + const DatasetT& dataset) -> cuvs::preprocessing::quantize::pq::vpq_dataset { auto old_type = vpq_build(res, params, dataset); - auto new_type = cuvs::neighbors::vpq_dataset{ + auto new_type = cuvs::preprocessing::quantize::pq::vpq_dataset{ raft::make_device_mdarray(res, old_type.vq_code_book.extents()), raft::make_device_mdarray(res, old_type.pq_code_book.extents()), std::move(old_type.data)}; @@ -456,8 +456,8 @@ auto vpq_build_half(const raft::resources& res, template void vpq_convert_math_type(const raft::resources& res, - const cuvs::neighbors::vpq_dataset& src, - cuvs::neighbors::vpq_dataset& dst) + const cuvs::preprocessing::quantize::pq::vpq_dataset& src, + cuvs::preprocessing::quantize::pq::vpq_dataset& dst) { raft::linalg::map(res, dst.vq_code_book.view(), @@ -472,7 +472,7 @@ void vpq_convert_math_type(const raft::resources& res, template auto vpq_build(const raft::resources& res, const cuvs::neighbors::vpq_params& params, - const DatasetT& dataset) -> cuvs::neighbors::vpq_dataset + const DatasetT& dataset) -> cuvs::preprocessing::quantize::pq::vpq_dataset { using label_t = uint32_t; // Use a heuristic to impute missing parameters. @@ -499,17 +499,17 @@ auto vpq_build(const raft::resources& res, codes.view(), true); - return cuvs::neighbors::vpq_dataset{ + return cuvs::preprocessing::quantize::pq::vpq_dataset{ std::move(vq_code_book), std::move(pq_code_book), std::move(codes)}; } template auto vpq_build_half(const raft::resources& res, const cuvs::neighbors::vpq_params& params, - const DatasetT& dataset) -> cuvs::neighbors::vpq_dataset + const DatasetT& dataset) -> cuvs::preprocessing::quantize::pq::vpq_dataset { auto old_type = vpq_build(res, params, dataset); - auto new_type = cuvs::neighbors::vpq_dataset{ + auto new_type = cuvs::preprocessing::quantize::pq::vpq_dataset{ raft::make_device_mdarray(res, old_type.vq_code_book.extents()), raft::make_device_mdarray(res, old_type.pq_code_book.extents()), std::move(old_type.data)}; diff --git a/cpp/src/preprocessing/quantize/vpq_build-ext.cuh b/cpp/src/preprocessing/quantize/vpq_build-ext.cuh index 1745e53a33..3b80062947 100644 --- a/cpp/src/preprocessing/quantize/vpq_build-ext.cuh +++ b/cpp/src/preprocessing/quantize/vpq_build-ext.cuh @@ -10,11 +10,11 @@ namespace cuvs::preprocessing::quantize::pq { #define CUVS_INST_VPQ_BUILD(T) \ - cuvs::neighbors::vpq_dataset vpq_build( \ + cuvs::preprocessing::quantize::pq::vpq_dataset vpq_build( \ const raft::resources& res, \ const cuvs::neighbors::vpq_params& params, \ const raft::host_matrix_view& dataset); \ - cuvs::neighbors::vpq_dataset vpq_build( \ + cuvs::preprocessing::quantize::pq::vpq_dataset vpq_build( \ const raft::resources& res, \ const cuvs::neighbors::vpq_params& params, \ const raft::device_matrix_view& dataset); diff --git a/cpp/tests/neighbors/ann_scann.cuh b/cpp/tests/neighbors/ann_scann.cuh index eafddec9d2..32d45ed7b6 100644 --- a/cpp/tests/neighbors/ann_scann.cuh +++ b/cpp/tests/neighbors/ann_scann.cuh @@ -186,7 +186,7 @@ class scann_test : public ::testing::TestWithParam { cuvs::preprocessing::quantize::pq::quantizer quantizer{ pq_params, - cuvs::neighbors::vpq_dataset{ + cuvs::preprocessing::quantize::pq::vpq_dataset{ std::move(vq_codebook), std::move(pq_codebook_copy), std::move(empty_data)}}; auto quantized_residuals_device = From 0070cdad876b588fd6ff509e63f06ddc5a9960e4 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Thu, 2 Apr 2026 19:02:30 -0700 Subject: [PATCH 18/36] fix compilation --- .../neighbors/detail/dataset_serialize.hpp | 4 +- .../neighbors/detail/vamana/vamana_build.cuh | 2 +- cpp/src/neighbors/detail/vpq_dataset.cuh | 2 +- cpp/src/neighbors/vpq_dataset.cuh | 42 ------ cpp/src/preprocessing/quantize/detail/pq.cuh | 140 ++++++------------ .../quantize/detail}/vpq_dataset_impl.hpp | 4 +- 6 files changed, 54 insertions(+), 140 deletions(-) delete mode 100644 cpp/src/neighbors/vpq_dataset.cuh rename cpp/src/{neighbors => preprocessing/quantize/detail}/vpq_dataset_impl.hpp (98%) diff --git a/cpp/src/neighbors/detail/dataset_serialize.hpp b/cpp/src/neighbors/detail/dataset_serialize.hpp index 51252dac4e..a3c17bf3e1 100644 --- a/cpp/src/neighbors/detail/dataset_serialize.hpp +++ b/cpp/src/neighbors/detail/dataset_serialize.hpp @@ -4,7 +4,7 @@ */ #pragma once -#include "../../neighbors/vpq_dataset_impl.hpp" +#include "../../preprocessing/quantize/detail/vpq_dataset_impl.hpp" #include #include @@ -161,7 +161,7 @@ auto deserialize_vpq(raft::resources const& res, std::istream& is) raft::deserialize_mdspan(res, is, data.view()); return std::make_unique>( - std::make_unique>( + std::make_unique>( std::move(vq_code_book), std::move(pq_code_book), std::move(data))); } diff --git a/cpp/src/neighbors/detail/vamana/vamana_build.cuh b/cpp/src/neighbors/detail/vamana/vamana_build.cuh index 0fbe395b04..b37c91b3bb 100644 --- a/cpp/src/neighbors/detail/vamana/vamana_build.cuh +++ b/cpp/src/neighbors/detail/vamana/vamana_build.cuh @@ -5,7 +5,7 @@ #pragma once -#include "../../../neighbors/vpq_dataset_impl.hpp" +#include "../../../preprocessing/quantize/detail/vpq_dataset_impl.hpp" #include "../../../sparse/neighbors/cross_component_nn.cuh" #include "../../detail/ann_utils.cuh" #include "greedy_search.cuh" diff --git a/cpp/src/neighbors/detail/vpq_dataset.cuh b/cpp/src/neighbors/detail/vpq_dataset.cuh index c409c8321c..4893a71190 100644 --- a/cpp/src/neighbors/detail/vpq_dataset.cuh +++ b/cpp/src/neighbors/detail/vpq_dataset.cuh @@ -4,7 +4,7 @@ */ #pragma once -#include "../../neighbors/vpq_dataset_impl.hpp" +#include "../../preprocessing/quantize/detail/vpq_dataset_impl.hpp" #include #include "../../cluster/kmeans_balanced.cuh" diff --git a/cpp/src/neighbors/vpq_dataset.cuh b/cpp/src/neighbors/vpq_dataset.cuh deleted file mode 100644 index d3c1312dfc..0000000000 --- a/cpp/src/neighbors/vpq_dataset.cuh +++ /dev/null @@ -1,42 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. - * SPDX-License-Identifier: Apache-2.0 - */ - -#pragma once - -#include "detail/vpq_dataset.cuh" -#include - -#include - -namespace cuvs::neighbors { - -/** - * @brief Compress a dataset for use in CAGRA-Q search in place of the original data. - * - * @tparam DatasetT a row-major mdspan or mdarray (device or host). - * @tparam MathT a type of the codebook elements and internal math ops. - * @tparam IdxT type of the indices in the source dataset - * - * @param[in] res - * @param[in] params VQ and PQ parameters for compressing the data - * @param[in] dataset a row-major mdspan or mdarray (device or host) [n_rows, dim]. - * - * @return the vpq_dataset - */ -template -auto vpq_build(const raft::resources& res, const vpq_params& params, const DatasetT& dataset) - -> cuvs::preprocessing::quantize::pq::vpq_dataset -{ - if constexpr (std::is_same_v) { - return detail::vpq_convert_math_type( - res, detail::vpq_build(res, params, dataset)); - } else { - return detail::vpq_build(res, params, dataset); - } -} - -} // namespace cuvs::neighbors diff --git a/cpp/src/preprocessing/quantize/detail/pq.cuh b/cpp/src/preprocessing/quantize/detail/pq.cuh index e207bea5f2..93952de1ee 100644 --- a/cpp/src/preprocessing/quantize/detail/pq.cuh +++ b/cpp/src/preprocessing/quantize/detail/pq.cuh @@ -159,15 +159,14 @@ quantizer build( } return {filled_params, cuvs::preprocessing::quantize::pq::vpq_dataset{ - std::make_unique>( + std::make_unique>( std::move(vq_code_book), std::move(pq_code_book), std::move(empty_codes))}}; } /** * @brief Build a quantizer from pre-computed codebooks. - * - * The quantizer does not own the codebook arrays - + * + * The quantizer does not own the codebook arrays. */ template quantizer build_view( @@ -216,7 +215,7 @@ quantizer build_view( : raft::make_device_matrix_view(nullptr, 0, 0); return {params, cuvs::preprocessing::quantize::pq::vpq_dataset{ - std::make_unique>( + std::make_unique>( vq_view, pq_centers, std::move(empty_data))}}; } @@ -381,98 +380,57 @@ void inverse_transform( RAFT_EXPECTS(quantizer.params_quantizer.pq_bits >= 4 && quantizer.params_quantizer.pq_bits <= 16, "PQ bits must be within [4, 16]"); - reconstruct_vectors( - res, - quant.params_quantizer, - codes, - raft::make_const_mdspan(quant.vpq_codebooks.pq_code_book.view()), - raft::make_const_mdspan(quant.vpq_codebooks.vq_code_book.view()), - vq_labels, - out, - quant.params_quantizer.use_subspaces); - -template -void vpq_convert_math_type(const raft::resources& res, - const cuvs::preprocessing::quantize::pq::vpq_dataset& src, - cuvs::preprocessing::quantize::pq::vpq_dataset& dst) -{ - raft::linalg::map(res, - dst.vq_code_book.view(), - cuvs::spatial::knn::detail::utils::mapping{}, - raft::make_const_mdspan(src.vq_code_book.view())); - raft::linalg::map(res, - dst.pq_code_book.view(), - cuvs::spatial::knn::detail::utils::mapping{}, - raft::make_const_mdspan(src.pq_code_book.view())); + reconstruct_vectors(res, + quantizer.params_quantizer, + codes, + quantizer.vpq_codebooks.pq_code_book(), + quantizer.vpq_codebooks.vq_code_book(), + vq_labels, + out, + quantizer.params_quantizer.use_subspaces); } -template -auto vpq_build(const raft::resources& res, - const cuvs::neighbors::vpq_params& params, - const DatasetT& dataset) -> cuvs::preprocessing::quantize::pq::vpq_dataset +template +auto vpq_convert_math_type(const raft::resources& res, + cuvs::preprocessing::quantize::pq::vpq_dataset&& src) + -> cuvs::preprocessing::quantize::pq::vpq_dataset { - using label_t = uint32_t; - // Use a heuristic to impute missing parameters. - auto ps = cuvs::neighbors::detail::fill_missing_params_heuristics(params, dataset); - - // Train codes - auto vq_code_book = cuvs::neighbors::detail::train_vq(res, ps, dataset); - auto pq_code_book = cuvs::neighbors::detail::train_pq( - res, ps, dataset, raft::make_const_mdspan(vq_code_book.view())); - - // Encode dataset - const IdxT n_rows = dataset.extent(0); - const IdxT codes_rowlen = sizeof(label_t) * (1 + raft::div_rounding_up_safe( - ps.pq_dim * ps.pq_bits, 8 * sizeof(label_t))); + auto vq_src = src.vq_code_book(); + auto pq_src = src.pq_code_book(); - auto codes = raft::make_device_matrix(res, n_rows, codes_rowlen); - cuvs::neighbors::detail::process_and_fill_codes( - res, - ps, - dataset, - raft::make_const_mdspan(pq_code_book.view()), - raft::make_const_mdspan(vq_code_book.view()), - raft::make_device_vector_view(nullptr, 0), - codes.view(), - true); + auto vq_new = raft::make_device_matrix( + res, vq_src.extent(0), vq_src.extent(1)); + auto pq_new = raft::make_device_matrix( + res, pq_src.extent(0), pq_src.extent(1)); - return cuvs::preprocessing::quantize::pq::vpq_dataset{ - std::move(vq_code_book), std::move(pq_code_book), std::move(codes)}; -} - -template -auto vpq_build_half(const raft::resources& res, - const cuvs::neighbors::vpq_params& params, - const DatasetT& dataset) -> cuvs::preprocessing::quantize::pq::vpq_dataset -{ - auto old_type = vpq_build(res, params, dataset); - auto new_type = cuvs::preprocessing::quantize::pq::vpq_dataset{ - raft::make_device_mdarray(res, old_type.vq_code_book.extents()), - raft::make_device_mdarray(res, old_type.pq_code_book.extents()), - std::move(old_type.data)}; - vpq_convert_math_type(res, old_type, new_type); - return new_type; -} - -template -void vpq_convert_math_type(const raft::resources& res, - const cuvs::preprocessing::quantize::pq::vpq_dataset& src, - cuvs::preprocessing::quantize::pq::vpq_dataset& dst) -{ raft::linalg::map(res, - dst.vq_code_book.view(), + vq_new.view(), cuvs::spatial::knn::detail::utils::mapping{}, - raft::make_const_mdspan(src.vq_code_book.view())); + vq_src); raft::linalg::map(res, - dst.pq_code_book.view(), + pq_new.view(), cuvs::spatial::knn::detail::utils::mapping{}, - raft::make_const_mdspan(src.pq_code_book.view())); + pq_src); + + // Data (encoded bytes) can be moved directly — independent of MathT. + auto data_src = src.data(); + auto data_new = raft::make_device_matrix( + res, data_src.extent(0), data_src.extent(1)); + raft::copy(data_new.data_handle(), + data_src.data_handle(), + data_src.size(), + raft::resource::get_cuda_stream(res)); + + return cuvs::preprocessing::quantize::pq::vpq_dataset{ + std::make_unique>( + std::move(vq_new), std::move(pq_new), std::move(data_new))}; } template auto vpq_build(const raft::resources& res, const cuvs::neighbors::vpq_params& params, - const DatasetT& dataset) -> cuvs::preprocessing::quantize::pq::vpq_dataset + const DatasetT& dataset) + -> cuvs::preprocessing::quantize::pq::vpq_dataset { using label_t = uint32_t; // Use a heuristic to impute missing parameters. @@ -500,20 +458,18 @@ auto vpq_build(const raft::resources& res, true); return cuvs::preprocessing::quantize::pq::vpq_dataset{ - std::move(vq_code_book), std::move(pq_code_book), std::move(codes)}; + std::make_unique>( + std::move(vq_code_book), std::move(pq_code_book), std::move(codes))}; } template auto vpq_build_half(const raft::resources& res, const cuvs::neighbors::vpq_params& params, - const DatasetT& dataset) -> cuvs::preprocessing::quantize::pq::vpq_dataset + const DatasetT& dataset) + -> cuvs::preprocessing::quantize::pq::vpq_dataset { - auto old_type = vpq_build(res, params, dataset); - auto new_type = cuvs::preprocessing::quantize::pq::vpq_dataset{ - raft::make_device_mdarray(res, old_type.vq_code_book.extents()), - raft::make_device_mdarray(res, old_type.pq_code_book.extents()), - std::move(old_type.data)}; - vpq_convert_math_type(res, old_type, new_type); - return new_type; + auto float_type = vpq_build(res, params, dataset); + return vpq_convert_math_type(res, std::move(float_type)); } + } // namespace cuvs::preprocessing::quantize::pq::detail diff --git a/cpp/src/neighbors/vpq_dataset_impl.hpp b/cpp/src/preprocessing/quantize/detail/vpq_dataset_impl.hpp similarity index 98% rename from cpp/src/neighbors/vpq_dataset_impl.hpp rename to cpp/src/preprocessing/quantize/detail/vpq_dataset_impl.hpp index e6b9cac47e..2e75a234ce 100644 --- a/cpp/src/neighbors/vpq_dataset_impl.hpp +++ b/cpp/src/preprocessing/quantize/detail/vpq_dataset_impl.hpp @@ -9,7 +9,7 @@ #include -namespace cuvs::neighbors { +namespace cuvs::preprocessing::quantize::pq { /** * @brief Common VPQ dataset implementation - provides shared implementations. @@ -204,4 +204,4 @@ class vpq_dataset_view : public vpq_dataset_impl { raft::device_matrix data_; }; -} // namespace cuvs::neighbors +} // namespace cuvs::preprocessing::quantize::pq From 819eef881645af35e3e69f9ba27a2c5ec2d7a7d0 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Thu, 2 Apr 2026 19:38:24 -0700 Subject: [PATCH 19/36] create vpq_codebooks --- cpp/bench/ann/src/cuvs/cuvs_vamana_wrapper.h | 2 +- .../cuvs/preprocessing/quantize/pq.hpp | 4 +- .../preprocessing/quantize/vpq_dataset.hpp | 167 ++++++++++------- .../neighbors/detail/dataset_serialize.hpp | 6 +- .../neighbors/detail/vamana/vamana_build.cuh | 7 +- cpp/src/neighbors/detail/vpq_dataset.cuh | 2 +- .../neighbors/scann/detail/scann_build.cuh | 2 +- cpp/src/preprocessing/quantize/detail/pq.cuh | 151 +++++++-------- .../quantize/detail/vpq_dataset_impl.hpp | 172 +++++------------- .../preprocessing/quantize/vpq_build-ext.cuh | 4 +- .../preprocessing/product_quantization.cu | 4 +- 11 files changed, 246 insertions(+), 275 deletions(-) diff --git a/cpp/bench/ann/src/cuvs/cuvs_vamana_wrapper.h b/cpp/bench/ann/src/cuvs/cuvs_vamana_wrapper.h index 6897416a3c..99bc1725ec 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_vamana_wrapper.h +++ b/cpp/bench/ann/src/cuvs/cuvs_vamana_wrapper.h @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once diff --git a/cpp/include/cuvs/preprocessing/quantize/pq.hpp b/cpp/include/cuvs/preprocessing/quantize/pq.hpp index e5ffd8931d..40f4719240 100644 --- a/cpp/include/cuvs/preprocessing/quantize/pq.hpp +++ b/cpp/include/cuvs/preprocessing/quantize/pq.hpp @@ -89,8 +89,8 @@ template struct quantizer { /** Parameters used to build this quantizer. */ params params_quantizer; - /** VPQ codebooks produced during training. */ - cuvs::preprocessing::quantize::pq::vpq_dataset vpq_codebooks; + /** VPQ codebooks (owning or view). */ + cuvs::preprocessing::quantize::pq::vpq_codebooks codebooks; }; /** diff --git a/cpp/include/cuvs/preprocessing/quantize/vpq_dataset.hpp b/cpp/include/cuvs/preprocessing/quantize/vpq_dataset.hpp index d95c2dde10..6ae8a9b090 100644 --- a/cpp/include/cuvs/preprocessing/quantize/vpq_dataset.hpp +++ b/cpp/include/cuvs/preprocessing/quantize/vpq_dataset.hpp @@ -9,44 +9,36 @@ namespace cuvs::preprocessing::quantize::pq { +// ============================================================================ +// vpq_codebooks — VQ + PQ codebook storage (owning or view) +// ============================================================================ + /** - * @brief VPQ compressed dataset - internal interface. - * - * This is the abstract base class for the internal implementation. - * Users should use vpq_dataset which wraps this via PIMPL. + * @brief Abstract interface for VPQ codebook access. * * @tparam MathT the type of elements in the codebooks - * @tparam IdxT type of the vector indices (represent dataset.extent(0)) */ -template -class vpq_dataset_iface : public cuvs::neighbors::dataset { +template +class vpq_codebooks_iface { public: - using index_type = IdxT; - using math_type = MathT; + using math_type = MathT; - ~vpq_dataset_iface() override = default; + virtual ~vpq_codebooks_iface() = default; - /** Get view of VQ codebook. */ + /** Get view of VQ codebook [vq_n_centers, dim]. */ [[nodiscard]] virtual auto vq_code_book() const noexcept -> raft::device_matrix_view = 0; - /** Get view of PQ codebook. */ + /** Get view of PQ codebook [pq_n_centers (× pq_dim for subspaces), pq_len]. */ [[nodiscard]] virtual auto pq_code_book() const noexcept -> raft::device_matrix_view = 0; - /** Get view of compressed data. */ - [[nodiscard]] virtual auto data() const noexcept - -> raft::device_matrix_view = 0; - - // Derived properties - pure virtual - [[nodiscard]] virtual auto n_rows() const noexcept -> index_type = 0; - [[nodiscard]] virtual auto dim() const noexcept -> uint32_t = 0; - [[nodiscard]] virtual auto encoded_row_length() const noexcept -> uint32_t = 0; - [[nodiscard]] virtual auto vq_n_centers() const noexcept -> uint32_t = 0; - [[nodiscard]] virtual auto pq_bits() const noexcept -> uint32_t = 0; - [[nodiscard]] virtual auto pq_dim() const noexcept -> uint32_t = 0; - [[nodiscard]] virtual auto pq_len() const noexcept -> uint32_t = 0; - [[nodiscard]] virtual auto pq_n_centers() const noexcept -> uint32_t = 0; + [[nodiscard]] virtual auto dim() const noexcept -> uint32_t = 0; + [[nodiscard]] virtual auto vq_n_centers() const noexcept -> uint32_t = 0; + [[nodiscard]] virtual auto pq_bits() const noexcept -> uint32_t = 0; + [[nodiscard]] virtual auto pq_dim() const noexcept -> uint32_t = 0; + [[nodiscard]] virtual auto pq_len() const noexcept -> uint32_t = 0; + [[nodiscard]] virtual auto pq_n_centers() const noexcept -> uint32_t = 0; }; /** @@ -55,17 +47,69 @@ class vpq_dataset_iface : public cuvs::neighbors::dataset { */ /** - * @brief VPQ compressed dataset (PIMPL wrapper). + * @brief PIMPL wrapper for VQ + PQ codebooks. + * + * Internally delegates to either an owning implementation (holds device + * matrices) or a view implementation (references external device memory). * - * The dataset is compressed using two level quantization: - * 1. Vector Quantization - * 2. Product Quantization of residuals + * @tparam MathT the type of elements in the codebooks + */ +template +class vpq_codebooks { + public: + using math_type = MathT; + + vpq_codebooks() = default; + + /** Construct from an implementation. */ + explicit vpq_codebooks(std::unique_ptr> impl) : impl_{std::move(impl)} + { + } + + vpq_codebooks(const vpq_codebooks&) = delete; + vpq_codebooks& operator=(const vpq_codebooks&) = delete; + vpq_codebooks(vpq_codebooks&&) = default; + vpq_codebooks& operator=(vpq_codebooks&&) = default; + ~vpq_codebooks() = default; + + [[nodiscard]] auto vq_code_book() const noexcept + -> raft::device_matrix_view + { + return impl_->vq_code_book(); + } + + [[nodiscard]] auto pq_code_book() const noexcept + -> raft::device_matrix_view + { + return impl_->pq_code_book(); + } + + [[nodiscard]] auto dim() const noexcept -> uint32_t { return impl_->dim(); } + [[nodiscard]] auto vq_n_centers() const noexcept -> uint32_t { return impl_->vq_n_centers(); } + [[nodiscard]] auto pq_bits() const noexcept -> uint32_t { return impl_->pq_bits(); } + [[nodiscard]] auto pq_dim() const noexcept -> uint32_t { return impl_->pq_dim(); } + [[nodiscard]] auto pq_len() const noexcept -> uint32_t { return impl_->pq_len(); } + [[nodiscard]] auto pq_n_centers() const noexcept -> uint32_t { return impl_->pq_n_centers(); } + + /** Check whether this object has been initialised. */ + [[nodiscard]] explicit operator bool() const noexcept { return impl_ != nullptr; } + + private: + std::unique_ptr> impl_; +}; + +// ============================================================================ +// vpq_dataset — codebooks + encoded data (always owning) +// ============================================================================ + +/** + * @brief VPQ compressed dataset. * - * This class wraps the internal implementation (vpq_dataset_owning or vpq_dataset_view) - * and provides a stable API. + * Holds a set of VQ + PQ codebooks together with the encoded dataset. + * Both the codebooks and the encoded data are always owned by this object. * * @tparam MathT the type of elements in the codebooks - * @tparam IdxT type of the vector indices (represent dataset.extent(0)) + * @tparam IdxT type of the vector indices (represent dataset.extent(0)) */ template class vpq_dataset : public cuvs::neighbors::dataset { @@ -75,9 +119,9 @@ class vpq_dataset : public cuvs::neighbors::dataset { vpq_dataset() = default; - /** Construct from an implementation. */ - explicit vpq_dataset(std::unique_ptr> impl) - : impl_{std::move(impl)} + vpq_dataset(vpq_codebooks&& codebooks, + raft::device_matrix&& data) + : codebooks_{std::move(codebooks)}, data_{std::move(data)} { } @@ -87,56 +131,51 @@ class vpq_dataset : public cuvs::neighbors::dataset { vpq_dataset& operator=(vpq_dataset&&) = default; ~vpq_dataset() override = default; - [[nodiscard]] auto n_rows() const noexcept -> index_type override { return impl_->n_rows(); } - [[nodiscard]] auto dim() const noexcept -> uint32_t override { return impl_->dim(); } - [[nodiscard]] auto is_owning() const noexcept -> bool final { return true; } + // ── dataset interface ────────────────────────────────────────────── + [[nodiscard]] auto n_rows() const noexcept -> index_type override { return data_.extent(0); } + [[nodiscard]] auto dim() const noexcept -> uint32_t override { return codebooks_.dim(); } + [[nodiscard]] auto is_owning() const noexcept -> bool override { return true; } - /** Get view of VQ codebook. */ + // ── Codebook access (convenience forwards) ─────────────────────────────── [[nodiscard]] auto vq_code_book() const noexcept -> raft::device_matrix_view { - return impl_->vq_code_book(); + return codebooks_.vq_code_book(); } - - /** Get view of PQ codebook. */ [[nodiscard]] auto pq_code_book() const noexcept -> raft::device_matrix_view { - return impl_->pq_code_book(); + return codebooks_.pq_code_book(); } - /** Get view of compressed data. */ + /** Get view of the encoded (compressed) data. */ [[nodiscard]] auto data() const noexcept -> raft::device_matrix_view { - return impl_->data(); + return data_.view(); } - /** Row length of the encoded data in bytes. */ - [[nodiscard]] auto encoded_row_length() const noexcept -> uint32_t + // ── Derived properties ─────────────────────────────────────────────────── + [[nodiscard]] auto encoded_row_length() const noexcept -> uint32_t { return data_.extent(1); } + [[nodiscard]] auto vq_n_centers() const noexcept -> uint32_t { return codebooks_.vq_n_centers(); } + [[nodiscard]] auto pq_bits() const noexcept -> uint32_t { return codebooks_.pq_bits(); } + [[nodiscard]] auto pq_dim() const noexcept -> uint32_t { return codebooks_.pq_dim(); } + [[nodiscard]] auto pq_len() const noexcept -> uint32_t { return codebooks_.pq_len(); } + [[nodiscard]] auto pq_n_centers() const noexcept -> uint32_t { return codebooks_.pq_n_centers(); } + + /** Direct access to the codebooks object. */ + [[nodiscard]] auto codebooks() const noexcept -> const vpq_codebooks& { - return impl_->encoded_row_length(); + return codebooks_; } - /** The number of "coarse cluster centers" */ - [[nodiscard]] auto vq_n_centers() const noexcept -> uint32_t { return impl_->vq_n_centers(); } - - /** The bit length of an encoded vector element after compression by PQ. */ - [[nodiscard]] auto pq_bits() const noexcept -> uint32_t { return impl_->pq_bits(); } - - /** The dimensionality of an encoded vector after compression by PQ. */ - [[nodiscard]] auto pq_dim() const noexcept -> uint32_t { return impl_->pq_dim(); } - - /** Dimensionality of a subspaces, i.e. the number of vector components mapped to a subspace */ - [[nodiscard]] auto pq_len() const noexcept -> uint32_t { return impl_->pq_len(); } - - /** The number of vectors in a PQ codebook (`1 << pq_bits`). */ - [[nodiscard]] auto pq_n_centers() const noexcept -> uint32_t { return impl_->pq_n_centers(); } - private: - std::unique_ptr> impl_; + vpq_codebooks codebooks_; + raft::device_matrix data_; }; +// ── Type trait ───────────────────────────────────────────────────────────── + template struct is_vpq_dataset : std::false_type {}; diff --git a/cpp/src/neighbors/detail/dataset_serialize.hpp b/cpp/src/neighbors/detail/dataset_serialize.hpp index a3c17bf3e1..d9112b9f9d 100644 --- a/cpp/src/neighbors/detail/dataset_serialize.hpp +++ b/cpp/src/neighbors/detail/dataset_serialize.hpp @@ -161,8 +161,10 @@ auto deserialize_vpq(raft::resources const& res, std::istream& is) raft::deserialize_mdspan(res, is, data.view()); return std::make_unique>( - std::make_unique>( - std::move(vq_code_book), std::move(pq_code_book), std::move(data))); + cuvs::preprocessing::quantize::pq::vpq_codebooks{ + std::make_unique>( + std::move(vq_code_book), std::move(pq_code_book))}, + std::move(data)); } template diff --git a/cpp/src/neighbors/detail/vamana/vamana_build.cuh b/cpp/src/neighbors/detail/vamana/vamana_build.cuh index b37c91b3bb..304a2b38ec 100644 --- a/cpp/src/neighbors/detail/vamana/vamana_build.cuh +++ b/cpp/src/neighbors/detail/vamana/vamana_build.cuh @@ -647,11 +647,10 @@ index build( auto quantizer = cuvs::preprocessing::quantize::pq::quantizer( pq_params, - cuvs::preprocessing::quantize::pq::vpq_dataset{ - std::make_unique>( + cuvs::preprocessing::quantize::pq::vpq_codebooks{ + std::make_unique>( raft::make_device_matrix(res, 0, 0), - std::move(pq_codebook), - raft::make_device_matrix(res, 0, 0))}); + std::move(pq_codebook))}); const int64_t codes_rowlen = cuvs::preprocessing::quantize::pq::get_quantized_dim(pq_params); quantized_vectors = raft::make_device_matrix(res, n_rows, codes_rowlen); diff --git a/cpp/src/neighbors/detail/vpq_dataset.cuh b/cpp/src/neighbors/detail/vpq_dataset.cuh index 4893a71190..5b76e72e33 100644 --- a/cpp/src/neighbors/detail/vpq_dataset.cuh +++ b/cpp/src/neighbors/detail/vpq_dataset.cuh @@ -4,8 +4,8 @@ */ #pragma once -#include "../../preprocessing/quantize/detail/vpq_dataset_impl.hpp" #include +#include #include "../../cluster/kmeans_balanced.cuh" #include "../../preprocessing/quantize/detail/pq_codepacking.cuh" // pq_bits-bitfield diff --git a/cpp/src/neighbors/scann/detail/scann_build.cuh b/cpp/src/neighbors/scann/detail/scann_build.cuh index 37f74f29bb..72c2a1ca11 100644 --- a/cpp/src/neighbors/scann/detail/scann_build.cuh +++ b/cpp/src/neighbors/scann/detail/scann_build.cuh @@ -286,7 +286,7 @@ index build( // Codebooks from VPQ have the shape [subspace idx, subspace dim, code] // This converts the codebook into matrix format for easy interoperability // with open-source ScaNN search - auto full_codebook_view = pq_quantizer.vpq_codebooks.pq_code_book.view(); + auto full_codebook_view = pq_quantizer.codebooks.pq_code_book(); raft::linalg::map_offset( res, diff --git a/cpp/src/preprocessing/quantize/detail/pq.cuh b/cpp/src/preprocessing/quantize/detail/pq.cuh index 93952de1ee..51cfc4d173 100644 --- a/cpp/src/preprocessing/quantize/detail/pq.cuh +++ b/cpp/src/preprocessing/quantize/detail/pq.cuh @@ -8,6 +8,7 @@ #include "../../../core/nvtx.hpp" #include "../../../neighbors/detail/vpq_dataset.cuh" #include "pq_codepacking.cuh" // pq_bits-bitfield +#include "vpq_dataset_impl.hpp" #include #include @@ -44,6 +45,27 @@ inline auto to_vpq_params(const cuvs::preprocessing::quantize::pq::params& param .max_train_points_per_vq_cluster = params.max_train_points_per_vq_cluster}; } +// ── Helper: build owning vpq_codebooks from device matrices ──────────────── + +template +auto make_owning_codebooks(raft::device_matrix&& vq, + raft::device_matrix&& pq) + -> vpq_codebooks +{ + return vpq_codebooks{ + std::make_unique>(std::move(vq), std::move(pq))}; +} + +template +auto make_view_codebooks(raft::device_matrix_view vq, + raft::device_matrix_view pq) + -> vpq_codebooks +{ + return vpq_codebooks{std::make_unique>(vq, pq)}; +} + +// ── PQ subspace training ─────────────────────────────────────────────────── + template auto train_pq_subspaces( const raft::resources& res, @@ -120,9 +142,8 @@ auto train_pq_subspaces( return pq_centers; } -/** - * @brief Build an owning quantizer by training on a dataset. - */ +// ── Build (owning, from dataset) ─────────────────────────────────────────── + template quantizer build( raft::resources const& res, @@ -148,7 +169,6 @@ quantizer build( if (filled_params.use_vq) { vq_code_book = cuvs::neighbors::detail::train_vq(res, vpq_params, dataset); } - auto empty_codes = raft::make_device_matrix(res, 0, 0); auto pq_code_book = raft::make_device_matrix(res, 0, 0); if (filled_params.use_subspaces) { pq_code_book = train_pq_subspaces( @@ -157,17 +177,11 @@ quantizer build( pq_code_book = cuvs::neighbors::detail::train_pq( res, vpq_params, dataset, raft::make_const_mdspan(vq_code_book.view())); } - return {filled_params, - cuvs::preprocessing::quantize::pq::vpq_dataset{ - std::make_unique>( - std::move(vq_code_book), std::move(pq_code_book), std::move(empty_codes))}}; + return {filled_params, make_owning_codebooks(std::move(vq_code_book), std::move(pq_code_book))}; } -/** - * @brief Build a quantizer from pre-computed codebooks. - * - * The quantizer does not own the codebook arrays. - */ +// ── Build (view, from pre-computed codebooks) ────────────────────────────── + template quantizer build_view( raft::resources const& res, @@ -176,7 +190,6 @@ quantizer build_view( std::optional> vq_centers = std::nullopt) { - // Validate parameters RAFT_EXPECTS(params.pq_bits >= 4 && params.pq_bits <= 16, "PQ bits must be within [4, 16], got %u", params.pq_bits); @@ -184,7 +197,6 @@ quantizer build_view( const uint32_t pq_n_centers = 1u << params.pq_bits; - // Validate PQ centers shape if (params.use_subspaces) { RAFT_EXPECTS(pq_centers.extent(0) == params.pq_dim * pq_n_centers, "For use_subspaces=true, pq_centers must have shape [pq_dim * pq_n_centers, " @@ -199,7 +211,6 @@ quantizer build_view( pq_centers.extent(1)); } - // Validate VQ centers if (params.use_vq) { RAFT_EXPECTS(vq_centers.has_value(), "vq_centers must be provided when use_vq=true"); RAFT_EXPECTS(vq_centers.value().extent(0) == params.vq_n_centers, @@ -207,18 +218,15 @@ quantizer build_view( vq_centers.value().extent(0)); } - // Create view-type vpq_dataset - auto empty_data = raft::make_device_matrix(res, 0, 0); auto vq_view = vq_centers.has_value() ? vq_centers.value() : raft::make_device_matrix_view(nullptr, 0, 0); - return {params, - cuvs::preprocessing::quantize::pq::vpq_dataset{ - std::make_unique>( - vq_view, pq_centers, std::move(empty_data))}}; + return {params, make_view_codebooks(vq_view, pq_centers)}; } +// ── Transform ────────────────────────────────────────────────────────────── + template void transform( raft::resources const& res, @@ -239,8 +247,8 @@ void transform( RAFT_EXPECTS(quantizer.params_quantizer.pq_bits >= 4 && quantizer.params_quantizer.pq_bits <= 16, "PQ bits must be within [4, 16]"); - auto vq_centers = quantizer.vpq_codebooks.vq_code_book(); - auto pq_centers = quantizer.vpq_codebooks.pq_code_book(); + auto vq_centers = quantizer.codebooks.vq_code_book(); + auto pq_centers = quantizer.codebooks.pq_code_book(); auto vq_labels_view = raft::make_device_vector_view(nullptr, 0); if (vq_labels.has_value()) { vq_labels_view = vq_labels.value(); } @@ -265,6 +273,8 @@ void transform( } } +// ── Inverse transform ────────────────────────────────────────────────────── + template >>( codes, out_vectors, pq_centers, vq_centers, vq_labels, pq_bits, use_subspaces); RAFT_CUDA_TRY(cudaPeekAtLastError()); - - return codes; } template @@ -381,19 +387,20 @@ void inverse_transform( "PQ bits must be within [4, 16]"); reconstruct_vectors(res, - quantizer.params_quantizer, - codes, - quantizer.vpq_codebooks.pq_code_book(), - quantizer.vpq_codebooks.vq_code_book(), - vq_labels, - out, - quantizer.params_quantizer.use_subspaces); + quantizer.params_quantizer, + codes, + quantizer.codebooks.pq_code_book(), + quantizer.codebooks.vq_code_book(), + vq_labels, + out, + quantizer.params_quantizer.use_subspaces); } -template -auto vpq_convert_math_type(const raft::resources& res, - cuvs::preprocessing::quantize::pq::vpq_dataset&& src) - -> cuvs::preprocessing::quantize::pq::vpq_dataset +// ── vpq_convert_math_type (codebooks only — no data copy) ────────────────── + +template +auto vpq_convert_math_type(const raft::resources& res, vpq_codebooks&& src) + -> vpq_codebooks { auto vq_src = src.vq_code_book(); auto pq_src = src.pq_code_book(); @@ -403,45 +410,28 @@ auto vpq_convert_math_type(const raft::resources& res, auto pq_new = raft::make_device_matrix( res, pq_src.extent(0), pq_src.extent(1)); - raft::linalg::map(res, - vq_new.view(), - cuvs::spatial::knn::detail::utils::mapping{}, - vq_src); - raft::linalg::map(res, - pq_new.view(), - cuvs::spatial::knn::detail::utils::mapping{}, - pq_src); - - // Data (encoded bytes) can be moved directly — independent of MathT. - auto data_src = src.data(); - auto data_new = raft::make_device_matrix( - res, data_src.extent(0), data_src.extent(1)); - raft::copy(data_new.data_handle(), - data_src.data_handle(), - data_src.size(), - raft::resource::get_cuda_stream(res)); - - return cuvs::preprocessing::quantize::pq::vpq_dataset{ - std::make_unique>( - std::move(vq_new), std::move(pq_new), std::move(data_new))}; + raft::linalg::map( + res, vq_new.view(), cuvs::spatial::knn::detail::utils::mapping{}, vq_src); + raft::linalg::map( + res, pq_new.view(), cuvs::spatial::knn::detail::utils::mapping{}, pq_src); + + return make_owning_codebooks(std::move(vq_new), std::move(pq_new)); } +// ── vpq_build (trains codebooks + encodes dataset → vpq_dataset) ─────────── + template auto vpq_build(const raft::resources& res, const cuvs::neighbors::vpq_params& params, - const DatasetT& dataset) - -> cuvs::preprocessing::quantize::pq::vpq_dataset + const DatasetT& dataset) -> vpq_dataset { using label_t = uint32_t; - // Use a heuristic to impute missing parameters. - auto ps = cuvs::neighbors::detail::fill_missing_params_heuristics(params, dataset); + auto ps = cuvs::neighbors::detail::fill_missing_params_heuristics(params, dataset); - // Train codes auto vq_code_book = cuvs::neighbors::detail::train_vq(res, ps, dataset); auto pq_code_book = cuvs::neighbors::detail::train_pq( res, ps, dataset, raft::make_const_mdspan(vq_code_book.view())); - // Encode dataset const IdxT n_rows = dataset.extent(0); const IdxT codes_rowlen = sizeof(label_t) * (1 + raft::div_rounding_up_safe( ps.pq_dim * ps.pq_bits, 8 * sizeof(label_t))); @@ -457,19 +447,34 @@ auto vpq_build(const raft::resources& res, codes.view(), true); - return cuvs::preprocessing::quantize::pq::vpq_dataset{ - std::make_unique>( - std::move(vq_code_book), std::move(pq_code_book), std::move(codes))}; + return vpq_dataset{ + make_owning_codebooks(std::move(vq_code_book), std::move(pq_code_book)), std::move(codes)}; } template auto vpq_build_half(const raft::resources& res, const cuvs::neighbors::vpq_params& params, - const DatasetT& dataset) - -> cuvs::preprocessing::quantize::pq::vpq_dataset + const DatasetT& dataset) -> vpq_dataset { - auto float_type = vpq_build(res, params, dataset); - return vpq_convert_math_type(res, std::move(float_type)); + // Build in float first + auto float_ds = vpq_build(res, params, dataset); + + // Convert only the codebooks to half; encoded data is uint8 and stays as-is + auto half_codebooks = vpq_convert_math_type(res, std::move(float_ds.codebooks())); + + // Transfer data out — vpq_dataset owns it, so we need to reconstruct. + // The data view is still valid because float_ds hasn't been destroyed yet. + auto data_view = float_ds.data(); + auto data_copy = raft::make_device_matrix( + res, data_view.extent(0), data_view.extent(1)); + if (data_view.size() > 0) { + raft::copy(data_copy.data_handle(), + data_view.data_handle(), + data_view.size(), + raft::resource::get_cuda_stream(res)); + } + + return vpq_dataset{std::move(half_codebooks), std::move(data_copy)}; } } // namespace cuvs::preprocessing::quantize::pq::detail diff --git a/cpp/src/preprocessing/quantize/detail/vpq_dataset_impl.hpp b/cpp/src/preprocessing/quantize/detail/vpq_dataset_impl.hpp index 2e75a234ce..65f1edf8d3 100644 --- a/cpp/src/preprocessing/quantize/detail/vpq_dataset_impl.hpp +++ b/cpp/src/preprocessing/quantize/detail/vpq_dataset_impl.hpp @@ -12,196 +12,122 @@ namespace cuvs::preprocessing::quantize::pq { /** - * @brief Common VPQ dataset implementation - provides shared implementations. - * - * This class contains the common implementations for derived properties - * that are shared between owning and view implementations. - * - * @tparam MathT the type of elements in the codebooks - * @tparam IdxT type of the vector indices + * @brief Common derived-property logic shared by owning and view codebook implementations. */ -template -class vpq_dataset_impl : public cuvs::preprocessing::quantize::pq::vpq_dataset_iface { +template +class vpq_codebooks_impl : public vpq_codebooks_iface { public: - using index_type = IdxT; - using math_type = MathT; + using math_type = MathT; - // Derived properties with default implementations - [[nodiscard]] auto n_rows() const noexcept -> index_type override - { - return this->data().extent(0); - } [[nodiscard]] auto dim() const noexcept -> uint32_t override { return this->vq_code_book().extent(1); } - [[nodiscard]] auto is_owning() const noexcept -> bool override { return true; } - - /** Row length of the encoded data in bytes. */ - [[nodiscard]] inline auto encoded_row_length() const noexcept -> uint32_t override + [[nodiscard]] auto vq_n_centers() const noexcept -> uint32_t override + { + return this->vq_code_book().extent(0); + } + [[nodiscard]] auto pq_len() const noexcept -> uint32_t override { - return this->data().extent(1); + return this->pq_code_book().extent(1); } - /** The number of "coarse cluster centers" */ - [[nodiscard]] inline auto vq_n_centers() const noexcept -> uint32_t override + [[nodiscard]] auto pq_n_centers() const noexcept -> uint32_t override { - return this->vq_code_book().extent(0); + return this->pq_code_book().extent(0); } - /** The bit length of an encoded vector element after compression by PQ. */ - [[nodiscard]] inline auto pq_bits() const noexcept -> uint32_t override + [[nodiscard]] auto pq_bits() const noexcept -> uint32_t override { - /* - NOTE: pq_bits and the book size - - Normally, we'd store `pq_bits` as a part of the index. - However, we know there's an invariant `pq_n_centers = 1 << pq_bits`, i.e. the codebook size is - the same as the number of possible code values. Hence, we don't store the pq_bits and derive it - from the array dimensions instead. - */ auto pq_width = pq_n_centers(); #ifdef __cpp_lib_bitops return std::countr_zero(pq_width); #else - uint32_t pq_bits = 0; + uint32_t bits = 0; while (pq_width > 1) { - pq_bits++; + bits++; pq_width >>= 1; } - return pq_bits; + return bits; #endif } - /** The dimensionality of an encoded vector after compression by PQ. */ - [[nodiscard]] inline auto pq_dim() const noexcept -> uint32_t override + [[nodiscard]] auto pq_dim() const noexcept -> uint32_t override { return raft::div_rounding_up_unsafe(dim(), pq_len()); } - /** Dimensionality of a subspaces, i.e. the number of vector components mapped to a subspace */ - [[nodiscard]] inline auto pq_len() const noexcept -> uint32_t override - { - return this->pq_code_book().extent(1); - } - /** The number of vectors in a PQ codebook (`1 << pq_bits`). */ - [[nodiscard]] inline auto pq_n_centers() const noexcept -> uint32_t override - { - return this->pq_code_book().extent(0); - } }; -/** - * @brief Owning VPQ dataset implementation - owns the codebooks and data. - * - * @tparam MathT the type of elements in the codebooks - * @tparam IdxT type of the vector indices - */ -template -class vpq_dataset_owning : public vpq_dataset_impl { +// ============================================================================ +// Owning variant — holds device_matrix objects +// ============================================================================ + +template +class vpq_codebooks_owning : public vpq_codebooks_impl { public: - using index_type = IdxT; - using math_type = MathT; - - /** - * @brief Construct an owning vpq_dataset by moving in the codebooks and data. - */ - vpq_dataset_owning(raft::device_matrix&& vq_code_book, - raft::device_matrix&& pq_code_book, - raft::device_matrix&& data) - : vq_code_book_{std::move(vq_code_book)}, - pq_code_book_{std::move(pq_code_book)}, - data_{std::move(data)} + using math_type = MathT; + + vpq_codebooks_owning(raft::device_matrix&& vq_code_book, + raft::device_matrix&& pq_code_book) + : vq_code_book_{std::move(vq_code_book)}, pq_code_book_{std::move(pq_code_book)} { } - vpq_dataset_owning(const vpq_dataset_owning&) = delete; - vpq_dataset_owning& operator=(const vpq_dataset_owning&) = delete; - vpq_dataset_owning(vpq_dataset_owning&&) = default; - vpq_dataset_owning& operator=(vpq_dataset_owning&&) = default; - ~vpq_dataset_owning() override = default; + vpq_codebooks_owning(const vpq_codebooks_owning&) = delete; + vpq_codebooks_owning& operator=(const vpq_codebooks_owning&) = delete; + vpq_codebooks_owning(vpq_codebooks_owning&&) = default; + vpq_codebooks_owning& operator=(vpq_codebooks_owning&&) = default; + ~vpq_codebooks_owning() override = default; [[nodiscard]] auto vq_code_book() const noexcept -> raft::device_matrix_view override { return vq_code_book_.view(); } - [[nodiscard]] auto pq_code_book() const noexcept -> raft::device_matrix_view override { return pq_code_book_.view(); } - [[nodiscard]] auto data() const noexcept - -> raft::device_matrix_view override - { - return data_.view(); - } - private: raft::device_matrix vq_code_book_; raft::device_matrix pq_code_book_; - raft::device_matrix data_; }; -/** - * @brief View-type VPQ dataset implementation - owns the dataset but not the codebooks. - * - * The caller must ensure the lifetime of the codebook data exceeds - * the lifetime of this object. The dataset is owned by this object. - * - * @tparam MathT the type of elements in the codebooks - * @tparam IdxT type of the vector indices - */ -template -class vpq_dataset_view : public vpq_dataset_impl { +// ============================================================================ +// View variant — references external device memory +// ============================================================================ + +template +class vpq_codebooks_view : public vpq_codebooks_impl { public: - using index_type = IdxT; - using math_type = MathT; - - /** - * @brief Construct a vpq_dataset that owns the dataset but references the codebooks. - * - * @param vq_code_book_view View of VQ codebook [vq_n_centers, dim] (non-owning) - * @param pq_code_book_view View of PQ codebook [pq_dim * pq_n_centers, pq_len] or [pq_n_centers, - * pq_len] (non-owning) - * @param data Compressed data matrix (moved, owned by this object) - */ - vpq_dataset_view( + using math_type = MathT; + + vpq_codebooks_view( raft::device_matrix_view vq_code_book_view, - raft::device_matrix_view pq_code_book_view, - raft::device_matrix&& data) - : vq_code_book_view_{vq_code_book_view}, - pq_code_book_view_{pq_code_book_view}, - data_{std::move(data)} + raft::device_matrix_view pq_code_book_view) + : vq_code_book_view_{vq_code_book_view}, pq_code_book_view_{pq_code_book_view} { } - vpq_dataset_view(const vpq_dataset_view&) = default; - vpq_dataset_view& operator=(const vpq_dataset_view&) = default; - vpq_dataset_view(vpq_dataset_view&&) = default; - vpq_dataset_view& operator=(vpq_dataset_view&&) = default; - ~vpq_dataset_view() override = default; + vpq_codebooks_view(const vpq_codebooks_view&) = default; + vpq_codebooks_view& operator=(const vpq_codebooks_view&) = default; + vpq_codebooks_view(vpq_codebooks_view&&) = default; + vpq_codebooks_view& operator=(vpq_codebooks_view&&) = default; + ~vpq_codebooks_view() override = default; [[nodiscard]] auto vq_code_book() const noexcept -> raft::device_matrix_view override { return vq_code_book_view_; } - [[nodiscard]] auto pq_code_book() const noexcept -> raft::device_matrix_view override { return pq_code_book_view_; } - [[nodiscard]] auto data() const noexcept - -> raft::device_matrix_view override - { - return data_.view(); - } - private: raft::device_matrix_view vq_code_book_view_; raft::device_matrix_view pq_code_book_view_; - raft::device_matrix data_; }; } // namespace cuvs::preprocessing::quantize::pq diff --git a/cpp/src/preprocessing/quantize/vpq_build-ext.cuh b/cpp/src/preprocessing/quantize/vpq_build-ext.cuh index 3b80062947..e8f31a11aa 100644 --- a/cpp/src/preprocessing/quantize/vpq_build-ext.cuh +++ b/cpp/src/preprocessing/quantize/vpq_build-ext.cuh @@ -10,11 +10,11 @@ namespace cuvs::preprocessing::quantize::pq { #define CUVS_INST_VPQ_BUILD(T) \ - cuvs::preprocessing::quantize::pq::vpq_dataset vpq_build( \ + cuvs::preprocessing::quantize::pq::vpq_dataset vpq_build( \ const raft::resources& res, \ const cuvs::neighbors::vpq_params& params, \ const raft::host_matrix_view& dataset); \ - cuvs::preprocessing::quantize::pq::vpq_dataset vpq_build( \ + cuvs::preprocessing::quantize::pq::vpq_dataset vpq_build( \ const raft::resources& res, \ const cuvs::neighbors::vpq_params& params, \ const raft::device_matrix_view& dataset); diff --git a/cpp/tests/preprocessing/product_quantization.cu b/cpp/tests/preprocessing/product_quantization.cu index e1dbad2ae8..6702c8c857 100644 --- a/cpp/tests/preprocessing/product_quantization.cu +++ b/cpp/tests/preprocessing/product_quantization.cu @@ -245,8 +245,8 @@ class ProductQuantizationTest : public ::testing::TestWithParam Date: Fri, 3 Apr 2026 11:46:14 -0700 Subject: [PATCH 20/36] reduce diff --- .../preprocessing/quantize/vpq_dataset.hpp | 53 ++++++++++------- .../quantize/detail/vpq_dataset_impl.hpp | 58 +------------------ .../preprocessing/product_quantization.cu | 16 ++--- 3 files changed, 41 insertions(+), 86 deletions(-) diff --git a/cpp/include/cuvs/preprocessing/quantize/vpq_dataset.hpp b/cpp/include/cuvs/preprocessing/quantize/vpq_dataset.hpp index 6ae8a9b090..4b9b3c34dc 100644 --- a/cpp/include/cuvs/preprocessing/quantize/vpq_dataset.hpp +++ b/cpp/include/cuvs/preprocessing/quantize/vpq_dataset.hpp @@ -9,10 +9,6 @@ namespace cuvs::preprocessing::quantize::pq { -// ============================================================================ -// vpq_codebooks — VQ + PQ codebook storage (owning or view) -// ============================================================================ - /** * @brief Abstract interface for VPQ codebook access. * @@ -25,20 +21,43 @@ class vpq_codebooks_iface { virtual ~vpq_codebooks_iface() = default; - /** Get view of VQ codebook [vq_n_centers, dim]. */ + /** VQ codebook [vq_n_centers, dim]. */ [[nodiscard]] virtual auto vq_code_book() const noexcept -> raft::device_matrix_view = 0; - /** Get view of PQ codebook [pq_n_centers (× pq_dim for subspaces), pq_len]. */ + /** PQ codebook [pq_n_centers (× pq_dim for subspaces), pq_len]. */ [[nodiscard]] virtual auto pq_code_book() const noexcept -> raft::device_matrix_view = 0; - [[nodiscard]] virtual auto dim() const noexcept -> uint32_t = 0; - [[nodiscard]] virtual auto vq_n_centers() const noexcept -> uint32_t = 0; - [[nodiscard]] virtual auto pq_bits() const noexcept -> uint32_t = 0; - [[nodiscard]] virtual auto pq_dim() const noexcept -> uint32_t = 0; - [[nodiscard]] virtual auto pq_len() const noexcept -> uint32_t = 0; - [[nodiscard]] virtual auto pq_n_centers() const noexcept -> uint32_t = 0; + // ── Derived properties (default implementations) ───────────────────────── + [[nodiscard]] virtual auto dim() const noexcept -> uint32_t { return vq_code_book().extent(1); } + [[nodiscard]] virtual auto vq_n_centers() const noexcept -> uint32_t + { + return vq_code_book().extent(0); + } + [[nodiscard]] virtual auto pq_len() const noexcept -> uint32_t + { + return pq_code_book().extent(1); + } + [[nodiscard]] virtual auto pq_n_centers() const noexcept -> uint32_t + { + return pq_code_book().extent(0); + } + [[nodiscard]] virtual auto pq_bits() const noexcept -> uint32_t + { + auto w = pq_n_centers(); + uint32_t bits = 0; + while (w > 1) { + bits++; + w >>= 1; + } + return bits; + } + [[nodiscard]] virtual auto pq_dim() const noexcept -> uint32_t + { + auto l = pq_len(); + return l > 0 ? (dim() + l - 1) / l : 0; + } }; /** @@ -61,7 +80,6 @@ class vpq_codebooks { vpq_codebooks() = default; - /** Construct from an implementation. */ explicit vpq_codebooks(std::unique_ptr> impl) : impl_{std::move(impl)} { } @@ -91,17 +109,10 @@ class vpq_codebooks { [[nodiscard]] auto pq_len() const noexcept -> uint32_t { return impl_->pq_len(); } [[nodiscard]] auto pq_n_centers() const noexcept -> uint32_t { return impl_->pq_n_centers(); } - /** Check whether this object has been initialised. */ - [[nodiscard]] explicit operator bool() const noexcept { return impl_ != nullptr; } - private: std::unique_ptr> impl_; }; -// ============================================================================ -// vpq_dataset — codebooks + encoded data (always owning) -// ============================================================================ - /** * @brief VPQ compressed dataset. * @@ -174,8 +185,6 @@ class vpq_dataset : public cuvs::neighbors::dataset { raft::device_matrix data_; }; -// ── Type trait ───────────────────────────────────────────────────────────── - template struct is_vpq_dataset : std::false_type {}; diff --git a/cpp/src/preprocessing/quantize/detail/vpq_dataset_impl.hpp b/cpp/src/preprocessing/quantize/detail/vpq_dataset_impl.hpp index 65f1edf8d3..718f95d709 100644 --- a/cpp/src/preprocessing/quantize/detail/vpq_dataset_impl.hpp +++ b/cpp/src/preprocessing/quantize/detail/vpq_dataset_impl.hpp @@ -7,60 +7,10 @@ #include -#include - namespace cuvs::preprocessing::quantize::pq { -/** - * @brief Common derived-property logic shared by owning and view codebook implementations. - */ -template -class vpq_codebooks_impl : public vpq_codebooks_iface { - public: - using math_type = MathT; - - [[nodiscard]] auto dim() const noexcept -> uint32_t override - { - return this->vq_code_book().extent(1); - } - [[nodiscard]] auto vq_n_centers() const noexcept -> uint32_t override - { - return this->vq_code_book().extent(0); - } - [[nodiscard]] auto pq_len() const noexcept -> uint32_t override - { - return this->pq_code_book().extent(1); - } - [[nodiscard]] auto pq_n_centers() const noexcept -> uint32_t override - { - return this->pq_code_book().extent(0); - } - [[nodiscard]] auto pq_bits() const noexcept -> uint32_t override - { - auto pq_width = pq_n_centers(); -#ifdef __cpp_lib_bitops - return std::countr_zero(pq_width); -#else - uint32_t bits = 0; - while (pq_width > 1) { - bits++; - pq_width >>= 1; - } - return bits; -#endif - } - [[nodiscard]] auto pq_dim() const noexcept -> uint32_t override - { - return raft::div_rounding_up_unsafe(dim(), pq_len()); - } -}; - -// ============================================================================ -// Owning variant — holds device_matrix objects -// ============================================================================ - template -class vpq_codebooks_owning : public vpq_codebooks_impl { +class vpq_codebooks_owning : public vpq_codebooks_iface { public: using math_type = MathT; @@ -92,12 +42,8 @@ class vpq_codebooks_owning : public vpq_codebooks_impl { raft::device_matrix pq_code_book_; }; -// ============================================================================ -// View variant — references external device memory -// ============================================================================ - template -class vpq_codebooks_view : public vpq_codebooks_impl { +class vpq_codebooks_view : public vpq_codebooks_iface { public: using math_type = MathT; diff --git a/cpp/tests/preprocessing/product_quantization.cu b/cpp/tests/preprocessing/product_quantization.cu index 6702c8c857..0b0c67432c 100644 --- a/cpp/tests/preprocessing/product_quantization.cu +++ b/cpp/tests/preprocessing/product_quantization.cu @@ -131,7 +131,7 @@ class ProductQuantizationTest : public ::testing::TestWithParam& quant, + void check_reconstruction(const cuvs::preprocessing::quantize::pq::quantizer& quantizer, raft::device_matrix_view codes, std::optional> vq_labels, double compression_ratio, @@ -149,7 +149,7 @@ class ProductQuantizationTest : public ::testing::TestWithParam> vq_labels_view = std::nullopt; if (vq_labels) { vq_labels_view = raft::make_const_mdspan(vq_labels.value()); } inverse_transform(handle, - quant, + quantizer, raft::device_matrix_view( codes.data_handle(), n_take, codes.extent(1)), rec_data.view(), @@ -161,7 +161,7 @@ class ProductQuantizationTest : public ::testing::TestWithParam(handle, n_samples_, n_encoded_cols); transform(handle, - owning_quant, + owning_quantizer, raft::make_const_mdspan(dataset_.view()), codes_owning.view(), std::nullopt); @@ -263,7 +263,7 @@ class ProductQuantizationTest : public ::testing::TestWithParam(handle, n_samples_, n_encoded_cols); transform(handle, - view_quant, + view_quantizer, raft::make_const_mdspan(dataset_.view()), codes_view.view(), std::nullopt); @@ -279,12 +279,12 @@ class ProductQuantizationTest : public ::testing::TestWithParam(handle, n_samples_, n_features_); inverse_transform(handle, - owning_quant, + owning_quantizer, raft::make_const_mdspan(codes_owning.view()), reconstructed_owning.view(), std::nullopt); inverse_transform(handle, - view_quant, + view_quantizer, raft::make_const_mdspan(codes_view.view()), reconstructed_view.view(), std::nullopt); From 77bd557fb58cc2d69d304b0ecb07973e24fc5ada Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Fri, 3 Apr 2026 12:18:47 -0700 Subject: [PATCH 21/36] fix compilation --- c/src/preprocessing/quantize/pq.cpp | 4 +- .../preprocessing/quantize/vpq_dataset.hpp | 51 +++++--------- .../detail/cagra/compute_distance_vpq.hpp | 2 +- cpp/src/neighbors/detail/cagra/factory.cuh | 2 +- .../neighbors/detail/dataset_serialize.hpp | 2 +- cpp/src/preprocessing/quantize/detail/pq.cuh | 70 ++++--------------- .../preprocessing/quantize/vpq_build-ext.cuh | 5 +- cpp/tests/neighbors/ann_scann.cuh | 8 +-- .../preprocessing/product_quantization.cu | 4 +- 9 files changed, 47 insertions(+), 101 deletions(-) diff --git a/c/src/preprocessing/quantize/pq.cpp b/c/src/preprocessing/quantize/pq.cpp index c792f1dcce..4b33fbf12e 100644 --- a/c/src/preprocessing/quantize/pq.cpp +++ b/c/src/preprocessing/quantize/pq.cpp @@ -252,7 +252,7 @@ extern "C" cuvsError_t cuvsProductQuantizerGetPqCodebook(cuvsProductQuantizer_t if (quantizer->dtype.code == kDLFloat && quantizer->dtype.bits == 32) { auto pq_mdspan = (reinterpret_cast*>(quant_addr)) - ->vpq_codebooks.pq_code_book(); + ->codebooks.pq_code_book(); cuvs::core::to_dlpack(pq_mdspan, pq_codebook); } else { RAFT_FAIL("Unsupported quantizer dtype: %d and bits: %d", @@ -274,7 +274,7 @@ extern "C" cuvsError_t cuvsProductQuantizerGetVqCodebook(cuvsProductQuantizer_t if (quantizer->dtype.code == kDLFloat && quantizer->dtype.bits == 32) { auto pq_mdspan = (reinterpret_cast*>(quant_addr)) - ->vpq_codebooks.vq_code_book(); + ->codebooks.vq_code_book(); cuvs::core::to_dlpack(pq_mdspan, vq_codebook); } else { RAFT_FAIL("Unsupported quantizer dtype: %d and bits: %d", diff --git a/cpp/include/cuvs/preprocessing/quantize/vpq_dataset.hpp b/cpp/include/cuvs/preprocessing/quantize/vpq_dataset.hpp index 4b9b3c34dc..db1741aa6c 100644 --- a/cpp/include/cuvs/preprocessing/quantize/vpq_dataset.hpp +++ b/cpp/include/cuvs/preprocessing/quantize/vpq_dataset.hpp @@ -128,11 +128,16 @@ class vpq_dataset : public cuvs::neighbors::dataset { using index_type = IdxT; using math_type = MathT; + /** VQ + PQ codebooks (owning or view). */ + vpq_codebooks codebooks; + /** Encoded (compressed) data [n_rows, encoded_row_length]. */ + raft::device_matrix data; + vpq_dataset() = default; - vpq_dataset(vpq_codebooks&& codebooks, - raft::device_matrix&& data) - : codebooks_{std::move(codebooks)}, data_{std::move(data)} + vpq_dataset(vpq_codebooks&& codebooks_in, + raft::device_matrix&& data_in) + : codebooks{std::move(codebooks_in)}, data{std::move(data_in)} { } @@ -143,46 +148,26 @@ class vpq_dataset : public cuvs::neighbors::dataset { ~vpq_dataset() override = default; // ── dataset interface ────────────────────────────────────────────── - [[nodiscard]] auto n_rows() const noexcept -> index_type override { return data_.extent(0); } - [[nodiscard]] auto dim() const noexcept -> uint32_t override { return codebooks_.dim(); } + [[nodiscard]] auto n_rows() const noexcept -> index_type override { return data.extent(0); } + [[nodiscard]] auto dim() const noexcept -> uint32_t override { return codebooks.dim(); } [[nodiscard]] auto is_owning() const noexcept -> bool override { return true; } - // ── Codebook access (convenience forwards) ─────────────────────────────── [[nodiscard]] auto vq_code_book() const noexcept -> raft::device_matrix_view { - return codebooks_.vq_code_book(); + return codebooks.vq_code_book(); } [[nodiscard]] auto pq_code_book() const noexcept -> raft::device_matrix_view { - return codebooks_.pq_code_book(); - } - - /** Get view of the encoded (compressed) data. */ - [[nodiscard]] auto data() const noexcept - -> raft::device_matrix_view - { - return data_.view(); - } - - // ── Derived properties ─────────────────────────────────────────────────── - [[nodiscard]] auto encoded_row_length() const noexcept -> uint32_t { return data_.extent(1); } - [[nodiscard]] auto vq_n_centers() const noexcept -> uint32_t { return codebooks_.vq_n_centers(); } - [[nodiscard]] auto pq_bits() const noexcept -> uint32_t { return codebooks_.pq_bits(); } - [[nodiscard]] auto pq_dim() const noexcept -> uint32_t { return codebooks_.pq_dim(); } - [[nodiscard]] auto pq_len() const noexcept -> uint32_t { return codebooks_.pq_len(); } - [[nodiscard]] auto pq_n_centers() const noexcept -> uint32_t { return codebooks_.pq_n_centers(); } - - /** Direct access to the codebooks object. */ - [[nodiscard]] auto codebooks() const noexcept -> const vpq_codebooks& - { - return codebooks_; + return codebooks.pq_code_book(); } - - private: - vpq_codebooks codebooks_; - raft::device_matrix data_; + [[nodiscard]] auto encoded_row_length() const noexcept -> uint32_t { return data.extent(1); } + [[nodiscard]] auto vq_n_centers() const noexcept -> uint32_t { return codebooks.vq_n_centers(); } + [[nodiscard]] auto pq_bits() const noexcept -> uint32_t { return codebooks.pq_bits(); } + [[nodiscard]] auto pq_dim() const noexcept -> uint32_t { return codebooks.pq_dim(); } + [[nodiscard]] auto pq_len() const noexcept -> uint32_t { return codebooks.pq_len(); } + [[nodiscard]] auto pq_n_centers() const noexcept -> uint32_t { return codebooks.pq_n_centers(); } }; template diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_vpq.hpp b/cpp/src/neighbors/detail/cagra/compute_distance_vpq.hpp index 2d039d699e..59f1cafff5 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance_vpq.hpp +++ b/cpp/src/neighbors/detail/cagra/compute_distance_vpq.hpp @@ -51,7 +51,7 @@ struct vpq_descriptor_spec : public instance_spec { const DistanceT* dataset_norms = nullptr) -> host_type { return init_(params, - dataset.data().data_handle(), + dataset.data.data_handle(), dataset.encoded_row_length(), dataset.vq_code_book().data_handle(), dataset.pq_code_book().data_handle(), diff --git a/cpp/src/neighbors/detail/cagra/factory.cuh b/cpp/src/neighbors/detail/cagra/factory.cuh index dcdf4e4b00..7b23946d1c 100644 --- a/cpp/src/neighbors/detail/cagra/factory.cuh +++ b/cpp/src/neighbors/detail/cagra/factory.cuh @@ -109,7 +109,7 @@ auto make_key(const cagra::search_params& params, cuvs::distance::DistanceType metric) -> std::enable_if_t, key> { - return key{reinterpret_cast(dataset.data().data_handle()), + return key{reinterpret_cast(dataset.data.data_handle()), uint64_t(dataset.n_rows()), dataset.dim(), uint32_t(reinterpret_cast(dataset.pq_code_book().data_handle()) >> 6), diff --git a/cpp/src/neighbors/detail/dataset_serialize.hpp b/cpp/src/neighbors/detail/dataset_serialize.hpp index d9112b9f9d..90676f524d 100644 --- a/cpp/src/neighbors/detail/dataset_serialize.hpp +++ b/cpp/src/neighbors/detail/dataset_serialize.hpp @@ -71,7 +71,7 @@ void serialize(const raft::resources& res, raft::serialize_scalar(res, os, dataset.encoded_row_length()); raft::serialize_mdspan(res, os, dataset.vq_code_book()); raft::serialize_mdspan(res, os, dataset.pq_code_book()); - raft::serialize_mdspan(res, os, dataset.data()); + raft::serialize_mdspan(res, os, dataset.data.view()); } template diff --git a/cpp/src/preprocessing/quantize/detail/pq.cuh b/cpp/src/preprocessing/quantize/detail/pq.cuh index 51cfc4d173..552a38dae9 100644 --- a/cpp/src/preprocessing/quantize/detail/pq.cuh +++ b/cpp/src/preprocessing/quantize/detail/pq.cuh @@ -45,27 +45,6 @@ inline auto to_vpq_params(const cuvs::preprocessing::quantize::pq::params& param .max_train_points_per_vq_cluster = params.max_train_points_per_vq_cluster}; } -// ── Helper: build owning vpq_codebooks from device matrices ──────────────── - -template -auto make_owning_codebooks(raft::device_matrix&& vq, - raft::device_matrix&& pq) - -> vpq_codebooks -{ - return vpq_codebooks{ - std::make_unique>(std::move(vq), std::move(pq))}; -} - -template -auto make_view_codebooks(raft::device_matrix_view vq, - raft::device_matrix_view pq) - -> vpq_codebooks -{ - return vpq_codebooks{std::make_unique>(vq, pq)}; -} - -// ── PQ subspace training ─────────────────────────────────────────────────── - template auto train_pq_subspaces( const raft::resources& res, @@ -142,8 +121,6 @@ auto train_pq_subspaces( return pq_centers; } -// ── Build (owning, from dataset) ─────────────────────────────────────────── - template quantizer build( raft::resources const& res, @@ -177,7 +154,9 @@ quantizer build( pq_code_book = cuvs::neighbors::detail::train_pq( res, vpq_params, dataset, raft::make_const_mdspan(vq_code_book.view())); } - return {filled_params, make_owning_codebooks(std::move(vq_code_book), std::move(pq_code_book))}; + return {filled_params, + vpq_codebooks{std::make_unique>( + std::move(vq_code_book), std::move(pq_code_book))}}; } // ── Build (view, from pre-computed codebooks) ────────────────────────────── @@ -222,11 +201,10 @@ quantizer build_view( vq_centers.has_value() ? vq_centers.value() : raft::make_device_matrix_view(nullptr, 0, 0); - return {params, make_view_codebooks(vq_view, pq_centers)}; + return {params, + vpq_codebooks{std::make_unique>(vq_view, pq_centers)}}; } -// ── Transform ────────────────────────────────────────────────────────────── - template void transform( raft::resources const& res, @@ -273,8 +251,6 @@ void transform( } } -// ── Inverse transform ────────────────────────────────────────────────────── - template -auto vpq_convert_math_type(const raft::resources& res, vpq_codebooks&& src) +auto vpq_convert_math_type(const raft::resources& res, const vpq_codebooks& src) -> vpq_codebooks { auto vq_src = src.vq_code_book(); @@ -415,11 +389,10 @@ auto vpq_convert_math_type(const raft::resources& res, vpq_codebooks&& raft::linalg::map( res, pq_new.view(), cuvs::spatial::knn::detail::utils::mapping{}, pq_src); - return make_owning_codebooks(std::move(vq_new), std::move(pq_new)); + return vpq_codebooks{ + std::make_unique>(std::move(vq_new), std::move(pq_new))}; } -// ── vpq_build (trains codebooks + encodes dataset → vpq_dataset) ─────────── - template auto vpq_build(const raft::resources& res, const cuvs::neighbors::vpq_params& params, @@ -448,7 +421,9 @@ auto vpq_build(const raft::resources& res, true); return vpq_dataset{ - make_owning_codebooks(std::move(vq_code_book), std::move(pq_code_book)), std::move(codes)}; + vpq_codebooks{std::make_unique>( + std::move(vq_code_book), std::move(pq_code_book))}, + std::move(codes)}; } template @@ -456,25 +431,10 @@ auto vpq_build_half(const raft::resources& res, const cuvs::neighbors::vpq_params& params, const DatasetT& dataset) -> vpq_dataset { - // Build in float first - auto float_ds = vpq_build(res, params, dataset); - - // Convert only the codebooks to half; encoded data is uint8 and stays as-is - auto half_codebooks = vpq_convert_math_type(res, std::move(float_ds.codebooks())); - - // Transfer data out — vpq_dataset owns it, so we need to reconstruct. - // The data view is still valid because float_ds hasn't been destroyed yet. - auto data_view = float_ds.data(); - auto data_copy = raft::make_device_matrix( - res, data_view.extent(0), data_view.extent(1)); - if (data_view.size() > 0) { - raft::copy(data_copy.data_handle(), - data_view.data_handle(), - data_view.size(), - raft::resource::get_cuda_stream(res)); - } - - return vpq_dataset{std::move(half_codebooks), std::move(data_copy)}; + // Build in float, then convert codebooks to half; data (uint8 codes) is moved, not copied. + auto float_ds = vpq_build(res, params, dataset); + auto half_codebooks = vpq_convert_math_type(res, float_ds.codebooks); + return vpq_dataset{std::move(half_codebooks), std::move(float_ds.data)}; } } // namespace cuvs::preprocessing::quantize::pq::detail diff --git a/cpp/src/preprocessing/quantize/vpq_build-ext.cuh b/cpp/src/preprocessing/quantize/vpq_build-ext.cuh index e8f31a11aa..7d4c8083b1 100644 --- a/cpp/src/preprocessing/quantize/vpq_build-ext.cuh +++ b/cpp/src/preprocessing/quantize/vpq_build-ext.cuh @@ -5,16 +5,17 @@ #pragma once #include +#include #include namespace cuvs::preprocessing::quantize::pq { #define CUVS_INST_VPQ_BUILD(T) \ - cuvs::preprocessing::quantize::pq::vpq_dataset vpq_build( \ + vpq_dataset vpq_build( \ const raft::resources& res, \ const cuvs::neighbors::vpq_params& params, \ const raft::host_matrix_view& dataset); \ - cuvs::preprocessing::quantize::pq::vpq_dataset vpq_build( \ + vpq_dataset vpq_build( \ const raft::resources& res, \ const cuvs::neighbors::vpq_params& params, \ const raft::device_matrix_view& dataset); diff --git a/cpp/tests/neighbors/ann_scann.cuh b/cpp/tests/neighbors/ann_scann.cuh index 32d45ed7b6..a26ecc9a1f 100644 --- a/cpp/tests/neighbors/ann_scann.cuh +++ b/cpp/tests/neighbors/ann_scann.cuh @@ -9,6 +9,7 @@ #include #include #include +#include "../../src/preprocessing/quantize/detail/vpq_dataset_impl.hpp" #include #include @@ -182,12 +183,11 @@ class scann_test : public ::testing::TestWithParam { handle_, idx.centers().extent(0), idx.centers().extent(1)); raft::copy( vq_codebook.data_handle(), idx.centers().data_handle(), idx.centers().size(), stream_); - auto empty_data = raft::make_device_matrix(handle_, 0, 0); - cuvs::preprocessing::quantize::pq::quantizer quantizer{ pq_params, - cuvs::preprocessing::quantize::pq::vpq_dataset{ - std::move(vq_codebook), std::move(pq_codebook_copy), std::move(empty_data)}}; + cuvs::preprocessing::quantize::pq::vpq_codebooks{ + std::make_unique>( + std::move(vq_codebook), std::move(pq_codebook_copy))}}; auto quantized_residuals_device = raft::make_device_matrix(handle_, ps.num_db_vecs, num_subspaces); diff --git a/cpp/tests/preprocessing/product_quantization.cu b/cpp/tests/preprocessing/product_quantization.cu index 0b0c67432c..df858bfa9d 100644 --- a/cpp/tests/preprocessing/product_quantization.cu +++ b/cpp/tests/preprocessing/product_quantization.cu @@ -255,7 +255,7 @@ class ProductQuantizationTest : public ::testing::TestWithParam(handle, n_samples_, n_encoded_cols); transform(handle, - owning_quantizer, + owning_quant, raft::make_const_mdspan(dataset_.view()), codes_owning.view(), std::nullopt); @@ -263,7 +263,7 @@ class ProductQuantizationTest : public ::testing::TestWithParam(handle, n_samples_, n_encoded_cols); transform(handle, - view_quantizer, + view_quant, raft::make_const_mdspan(dataset_.view()), codes_view.view(), std::nullopt); From 1c85f16c5149c9ea3917b7b9a04be517721dbc98 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Fri, 3 Apr 2026 12:22:15 -0700 Subject: [PATCH 22/36] pre-commit --- cpp/src/preprocessing/quantize/detail/pq.cuh | 4 ++-- cpp/tests/neighbors/ann_scann.cuh | 2 +- cpp/tests/preprocessing/product_quantization.cu | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/cpp/src/preprocessing/quantize/detail/pq.cuh b/cpp/src/preprocessing/quantize/detail/pq.cuh index 552a38dae9..63e741f20d 100644 --- a/cpp/src/preprocessing/quantize/detail/pq.cuh +++ b/cpp/src/preprocessing/quantize/detail/pq.cuh @@ -421,8 +421,8 @@ auto vpq_build(const raft::resources& res, true); return vpq_dataset{ - vpq_codebooks{std::make_unique>( - std::move(vq_code_book), std::move(pq_code_book))}, + vpq_codebooks{std::make_unique>(std::move(vq_code_book), + std::move(pq_code_book))}, std::move(codes)}; } diff --git a/cpp/tests/neighbors/ann_scann.cuh b/cpp/tests/neighbors/ann_scann.cuh index a26ecc9a1f..aa739da301 100644 --- a/cpp/tests/neighbors/ann_scann.cuh +++ b/cpp/tests/neighbors/ann_scann.cuh @@ -4,12 +4,12 @@ */ #pragma once +#include "../../src/preprocessing/quantize/detail/vpq_dataset_impl.hpp" #include "../test_utils.cuh" #include "ann_utils.cuh" #include #include #include -#include "../../src/preprocessing/quantize/detail/vpq_dataset_impl.hpp" #include #include diff --git a/cpp/tests/preprocessing/product_quantization.cu b/cpp/tests/preprocessing/product_quantization.cu index df858bfa9d..c18f09cb8f 100644 --- a/cpp/tests/preprocessing/product_quantization.cu +++ b/cpp/tests/preprocessing/product_quantization.cu @@ -279,12 +279,12 @@ class ProductQuantizationTest : public ::testing::TestWithParam(handle, n_samples_, n_features_); inverse_transform(handle, - owning_quantizer, + owning_quant, raft::make_const_mdspan(codes_owning.view()), reconstructed_owning.view(), std::nullopt); inverse_transform(handle, - view_quantizer, + view_quant, raft::make_const_mdspan(codes_view.view()), reconstructed_view.view(), std::nullopt); From 47084344f3889be427228567a8efe169a35bb44c Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Fri, 3 Apr 2026 16:15:35 -0700 Subject: [PATCH 23/36] revert bm change --- cpp/bench/ann/src/cuvs/cuvs_vamana_wrapper.h | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/cpp/bench/ann/src/cuvs/cuvs_vamana_wrapper.h b/cpp/bench/ann/src/cuvs/cuvs_vamana_wrapper.h index 99bc1725ec..7d7292bd50 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_vamana_wrapper.h +++ b/cpp/bench/ann/src/cuvs/cuvs_vamana_wrapper.h @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once @@ -13,7 +13,6 @@ #include #include #include -#include namespace cuvs::bench { @@ -83,7 +82,6 @@ void cuvs_vamana::build(const T* dataset, size_t nrow) dataset_is_on_host ? cuvs::neighbors::vamana::build(handle_, vamana_index_params_, dataset_view_host) : cuvs::neighbors::vamana::build(handle_, vamana_index_params_, dataset_view_device))); - raft::resource::sync_stream(handle_); } template From 55db0a45ad56dff185fe54a1e2c0abfa1d8c6292 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Fri, 3 Apr 2026 16:19:55 -0700 Subject: [PATCH 24/36] rm unnecessary commits --- cpp/include/cuvs/preprocessing/quantize/vpq_dataset.hpp | 2 -- cpp/src/preprocessing/quantize/detail/pq.cuh | 2 -- 2 files changed, 4 deletions(-) diff --git a/cpp/include/cuvs/preprocessing/quantize/vpq_dataset.hpp b/cpp/include/cuvs/preprocessing/quantize/vpq_dataset.hpp index db1741aa6c..8eafbbbe6d 100644 --- a/cpp/include/cuvs/preprocessing/quantize/vpq_dataset.hpp +++ b/cpp/include/cuvs/preprocessing/quantize/vpq_dataset.hpp @@ -29,7 +29,6 @@ class vpq_codebooks_iface { [[nodiscard]] virtual auto pq_code_book() const noexcept -> raft::device_matrix_view = 0; - // ── Derived properties (default implementations) ───────────────────────── [[nodiscard]] virtual auto dim() const noexcept -> uint32_t { return vq_code_book().extent(1); } [[nodiscard]] virtual auto vq_n_centers() const noexcept -> uint32_t { @@ -147,7 +146,6 @@ class vpq_dataset : public cuvs::neighbors::dataset { vpq_dataset& operator=(vpq_dataset&&) = default; ~vpq_dataset() override = default; - // ── dataset interface ────────────────────────────────────────────── [[nodiscard]] auto n_rows() const noexcept -> index_type override { return data.extent(0); } [[nodiscard]] auto dim() const noexcept -> uint32_t override { return codebooks.dim(); } [[nodiscard]] auto is_owning() const noexcept -> bool override { return true; } diff --git a/cpp/src/preprocessing/quantize/detail/pq.cuh b/cpp/src/preprocessing/quantize/detail/pq.cuh index 63e741f20d..74a44c80e5 100644 --- a/cpp/src/preprocessing/quantize/detail/pq.cuh +++ b/cpp/src/preprocessing/quantize/detail/pq.cuh @@ -159,8 +159,6 @@ quantizer build( std::move(vq_code_book), std::move(pq_code_book))}}; } -// ── Build (view, from pre-computed codebooks) ────────────────────────────── - template quantizer build_view( raft::resources const& res, From 6408cbb0a268addadbccbd0511918d67261991a5 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Fri, 3 Apr 2026 16:28:19 -0700 Subject: [PATCH 25/36] fix error message and copyright --- cpp/include/cuvs/preprocessing/quantize/vpq_dataset.hpp | 2 +- cpp/src/preprocessing/quantize/detail/pq.cuh | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/cpp/include/cuvs/preprocessing/quantize/vpq_dataset.hpp b/cpp/include/cuvs/preprocessing/quantize/vpq_dataset.hpp index 8eafbbbe6d..faac7b1787 100644 --- a/cpp/include/cuvs/preprocessing/quantize/vpq_dataset.hpp +++ b/cpp/include/cuvs/preprocessing/quantize/vpq_dataset.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ diff --git a/cpp/src/preprocessing/quantize/detail/pq.cuh b/cpp/src/preprocessing/quantize/detail/pq.cuh index 74a44c80e5..110f238f15 100644 --- a/cpp/src/preprocessing/quantize/detail/pq.cuh +++ b/cpp/src/preprocessing/quantize/detail/pq.cuh @@ -135,7 +135,7 @@ quantizer build( dim, params.pq_bits, params.pq_dim); - RAFT_EXPECTS(params.pq_bits >= 4 && params.pq_bits <= 16, + RAFT_EXPECTS(params.pq_bits >= 4 && params.pq_bits <= 8, "PQ bits must be within [4, 16], got %u", params.pq_bits); @@ -167,8 +167,8 @@ quantizer build_view( std::optional> vq_centers = std::nullopt) { - RAFT_EXPECTS(params.pq_bits >= 4 && params.pq_bits <= 16, - "PQ bits must be within [4, 16], got %u", + RAFT_EXPECTS(params.pq_bits >= 4 && params.pq_bits <= 8, + "PQ bits must be within [4, 8], got %u", params.pq_bits); RAFT_EXPECTS(params.pq_dim > 0, "pq_dim must be specified for view-type quantizer"); @@ -221,7 +221,7 @@ void transform( RAFT_EXPECTS(pq_codes_out.extent(1) == get_quantized_dim(quantizer.params_quantizer), "Output matrix doesn't have the correct number of columns"); RAFT_EXPECTS(quantizer.params_quantizer.pq_bits >= 4 && quantizer.params_quantizer.pq_bits <= 16, - "PQ bits must be within [4, 16]"); + "PQ bits must be within [4, 8]"); auto vq_centers = quantizer.codebooks.vq_code_book(); auto pq_centers = quantizer.codebooks.pq_code_book(); From c4250f5f7006705adcb1830d04663599e60fd326 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Mon, 6 Apr 2026 09:44:13 -0700 Subject: [PATCH 26/36] fix condition check --- cpp/src/preprocessing/quantize/detail/pq.cuh | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/cpp/src/preprocessing/quantize/detail/pq.cuh b/cpp/src/preprocessing/quantize/detail/pq.cuh index 110f238f15..74a44c80e5 100644 --- a/cpp/src/preprocessing/quantize/detail/pq.cuh +++ b/cpp/src/preprocessing/quantize/detail/pq.cuh @@ -135,7 +135,7 @@ quantizer build( dim, params.pq_bits, params.pq_dim); - RAFT_EXPECTS(params.pq_bits >= 4 && params.pq_bits <= 8, + RAFT_EXPECTS(params.pq_bits >= 4 && params.pq_bits <= 16, "PQ bits must be within [4, 16], got %u", params.pq_bits); @@ -167,8 +167,8 @@ quantizer build_view( std::optional> vq_centers = std::nullopt) { - RAFT_EXPECTS(params.pq_bits >= 4 && params.pq_bits <= 8, - "PQ bits must be within [4, 8], got %u", + RAFT_EXPECTS(params.pq_bits >= 4 && params.pq_bits <= 16, + "PQ bits must be within [4, 16], got %u", params.pq_bits); RAFT_EXPECTS(params.pq_dim > 0, "pq_dim must be specified for view-type quantizer"); @@ -221,7 +221,7 @@ void transform( RAFT_EXPECTS(pq_codes_out.extent(1) == get_quantized_dim(quantizer.params_quantizer), "Output matrix doesn't have the correct number of columns"); RAFT_EXPECTS(quantizer.params_quantizer.pq_bits >= 4 && quantizer.params_quantizer.pq_bits <= 16, - "PQ bits must be within [4, 8]"); + "PQ bits must be within [4, 16]"); auto vq_centers = quantizer.codebooks.vq_code_book(); auto pq_centers = quantizer.codebooks.pq_code_book(); From 8c1f7921ec45730ed3810ebb004e466168a13a78 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Mon, 6 Apr 2026 15:09:23 -0700 Subject: [PATCH 27/36] change trailing return type --- cpp/include/cuvs/preprocessing/quantize/vpq_dataset.hpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cpp/include/cuvs/preprocessing/quantize/vpq_dataset.hpp b/cpp/include/cuvs/preprocessing/quantize/vpq_dataset.hpp index faac7b1787..77d43b8649 100644 --- a/cpp/include/cuvs/preprocessing/quantize/vpq_dataset.hpp +++ b/cpp/include/cuvs/preprocessing/quantize/vpq_dataset.hpp @@ -146,9 +146,9 @@ class vpq_dataset : public cuvs::neighbors::dataset { vpq_dataset& operator=(vpq_dataset&&) = default; ~vpq_dataset() override = default; - [[nodiscard]] auto n_rows() const noexcept -> index_type override { return data.extent(0); } - [[nodiscard]] auto dim() const noexcept -> uint32_t override { return codebooks.dim(); } - [[nodiscard]] auto is_owning() const noexcept -> bool override { return true; } + [[nodiscard]] index_type n_rows() const noexcept override { return data.extent(0); } + [[nodiscard]] uint32_t dim() const noexcept override { return codebooks.dim(); } + [[nodiscard]] bool is_owning() const noexcept override { return true; } [[nodiscard]] auto vq_code_book() const noexcept -> raft::device_matrix_view From 4b7015f720e8ecc8de89c9a9ff442454026d14f1 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Tue, 7 Apr 2026 18:27:48 -0700 Subject: [PATCH 28/36] add non const getters --- .../quantize/detail/vpq_dataset_impl.hpp | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/cpp/src/preprocessing/quantize/detail/vpq_dataset_impl.hpp b/cpp/src/preprocessing/quantize/detail/vpq_dataset_impl.hpp index 718f95d709..dc827242e0 100644 --- a/cpp/src/preprocessing/quantize/detail/vpq_dataset_impl.hpp +++ b/cpp/src/preprocessing/quantize/detail/vpq_dataset_impl.hpp @@ -31,12 +31,25 @@ class vpq_codebooks_owning : public vpq_codebooks_iface { { return vq_code_book_.view(); } + + [[nodiscard]] auto vq_code_book() noexcept + -> raft::device_matrix_view + { + return vq_code_book_.view(); + } + [[nodiscard]] auto pq_code_book() const noexcept -> raft::device_matrix_view override { return pq_code_book_.view(); } + [[nodiscard]] auto pq_code_book() noexcept + -> raft::device_matrix_view + { + return pq_code_book_.view(); + } + private: raft::device_matrix vq_code_book_; raft::device_matrix pq_code_book_; From 724768883d003c47d0f7e9d3b86aa2be1751c48e Mon Sep 17 00:00:00 2001 From: tarangj Date: Thu, 30 Apr 2026 11:44:39 -0700 Subject: [PATCH 29/36] update to use the view api --- cpp/src/neighbors/detail/vamana/vamana_build.cuh | 9 ++------- cpp/tests/neighbors/ann_scann.cuh | 11 +++++------ 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/cpp/src/neighbors/detail/vamana/vamana_build.cuh b/cpp/src/neighbors/detail/vamana/vamana_build.cuh index 3ab5c53814..72c059f649 100644 --- a/cpp/src/neighbors/detail/vamana/vamana_build.cuh +++ b/cpp/src/neighbors/detail/vamana/vamana_build.cuh @@ -5,7 +5,6 @@ #pragma once -#include "../../../preprocessing/quantize/detail/vpq_dataset_impl.hpp" #include "../../../sparse/neighbors/cross_component_nn.cuh" #include "../../detail/ann_utils.cuh" #include "greedy_search.cuh" @@ -645,12 +644,8 @@ index build( // process in batches const uint32_t n_rows = dataset.extent(0); - auto quantizer = cuvs::preprocessing::quantize::pq::quantizer( - pq_params, - cuvs::preprocessing::quantize::pq::vpq_codebooks{ - std::make_unique>( - raft::make_device_matrix(res, 0, 0), - std::move(pq_codebook))}); + auto quantizer = cuvs::preprocessing::quantize::pq::build( + res, pq_params, raft::make_const_mdspan(pq_codebook.view())); const int64_t codes_rowlen = cuvs::preprocessing::quantize::pq::get_quantized_dim(pq_params); quantized_vectors = raft::make_device_matrix(res, n_rows, codes_rowlen); diff --git a/cpp/tests/neighbors/ann_scann.cuh b/cpp/tests/neighbors/ann_scann.cuh index aa739da301..9df3bcadad 100644 --- a/cpp/tests/neighbors/ann_scann.cuh +++ b/cpp/tests/neighbors/ann_scann.cuh @@ -4,7 +4,6 @@ */ #pragma once -#include "../../src/preprocessing/quantize/detail/vpq_dataset_impl.hpp" #include "../test_utils.cuh" #include "ann_utils.cuh" #include @@ -183,11 +182,11 @@ class scann_test : public ::testing::TestWithParam { handle_, idx.centers().extent(0), idx.centers().extent(1)); raft::copy( vq_codebook.data_handle(), idx.centers().data_handle(), idx.centers().size(), stream_); - cuvs::preprocessing::quantize::pq::quantizer quantizer{ - pq_params, - cuvs::preprocessing::quantize::pq::vpq_codebooks{ - std::make_unique>( - std::move(vq_codebook), std::move(pq_codebook_copy))}}; + auto quantizer = + cuvs::preprocessing::quantize::pq::build(handle_, + pq_params, + raft::make_const_mdspan(pq_codebook_copy.view()), + raft::make_const_mdspan(vq_codebook.view())); auto quantized_residuals_device = raft::make_device_matrix(handle_, ps.num_db_vecs, num_subspaces); From be4cb4a5c368480835c260aa2de4f33e7c57cc46 Mon Sep 17 00:00:00 2001 From: tarangj Date: Thu, 30 Apr 2026 11:54:29 -0700 Subject: [PATCH 30/36] make vq codebook optional --- c/src/preprocessing/quantize/pq.cpp | 6 ++- .../preprocessing/quantize/vpq_dataset.hpp | 31 +++++++++--- .../detail/cagra/compute_distance_vpq.hpp | 3 +- .../neighbors/detail/dataset_serialize.hpp | 8 ++- cpp/src/preprocessing/quantize/detail/pq.cuh | 47 ++++++++++------- .../quantize/detail/vpq_dataset_impl.hpp | 50 +++++++++---------- 6 files changed, 90 insertions(+), 55 deletions(-) diff --git a/c/src/preprocessing/quantize/pq.cpp b/c/src/preprocessing/quantize/pq.cpp index 4b33fbf12e..273cff5b2d 100644 --- a/c/src/preprocessing/quantize/pq.cpp +++ b/c/src/preprocessing/quantize/pq.cpp @@ -272,10 +272,12 @@ extern "C" cuvsError_t cuvsProductQuantizerGetVqCodebook(cuvsProductQuantizer_t if (quantizer != nullptr) { auto quant_addr = quantizer->addr; if (quantizer->dtype.code == kDLFloat && quantizer->dtype.bits == 32) { - auto pq_mdspan = + auto vq_opt = (reinterpret_cast*>(quant_addr)) ->codebooks.vq_code_book(); - cuvs::core::to_dlpack(pq_mdspan, vq_codebook); + RAFT_EXPECTS(vq_opt.has_value(), + "quantizer has no VQ codebook (build with use_vq=true to enable)"); + cuvs::core::to_dlpack(vq_opt.value(), vq_codebook); } else { RAFT_FAIL("Unsupported quantizer dtype: %d and bits: %d", quantizer->dtype.code, diff --git a/cpp/include/cuvs/preprocessing/quantize/vpq_dataset.hpp b/cpp/include/cuvs/preprocessing/quantize/vpq_dataset.hpp index 77d43b8649..c19aa9f1b9 100644 --- a/cpp/include/cuvs/preprocessing/quantize/vpq_dataset.hpp +++ b/cpp/include/cuvs/preprocessing/quantize/vpq_dataset.hpp @@ -7,6 +7,8 @@ #include +#include + namespace cuvs::preprocessing::quantize::pq { /** @@ -21,18 +23,29 @@ class vpq_codebooks_iface { virtual ~vpq_codebooks_iface() = default; - /** VQ codebook [vq_n_centers, dim]. */ + /** + * VQ codebook [vq_n_centers, dim]. + * + * Returns std::nullopt when no VQ codebook is configured (i.e. PQ-only, + * use_vq=false). Callers that need to forward a `device_matrix_view` + * downstream should materialize an empty 0x0 view themselves on nullopt. + */ [[nodiscard]] virtual auto vq_code_book() const noexcept - -> raft::device_matrix_view = 0; + -> std::optional> = 0; /** PQ codebook [pq_n_centers (× pq_dim for subspaces), pq_len]. */ [[nodiscard]] virtual auto pq_code_book() const noexcept -> raft::device_matrix_view = 0; - [[nodiscard]] virtual auto dim() const noexcept -> uint32_t { return vq_code_book().extent(1); } + [[nodiscard]] virtual auto dim() const noexcept -> uint32_t + { + auto vq = vq_code_book(); + return vq.has_value() ? vq->extent(1) : 0; + } [[nodiscard]] virtual auto vq_n_centers() const noexcept -> uint32_t { - return vq_code_book().extent(0); + auto vq = vq_code_book(); + return vq.has_value() ? vq->extent(0) : 0; } [[nodiscard]] virtual auto pq_len() const noexcept -> uint32_t { @@ -89,8 +102,11 @@ class vpq_codebooks { vpq_codebooks& operator=(vpq_codebooks&&) = default; ~vpq_codebooks() = default; + /** + * VQ codebook view, or std::nullopt when no VQ codebook is configured. + */ [[nodiscard]] auto vq_code_book() const noexcept - -> raft::device_matrix_view + -> std::optional> { return impl_->vq_code_book(); } @@ -150,8 +166,11 @@ class vpq_dataset : public cuvs::neighbors::dataset { [[nodiscard]] uint32_t dim() const noexcept override { return codebooks.dim(); } [[nodiscard]] bool is_owning() const noexcept override { return true; } + /** + * VQ codebook view, or std::nullopt when no VQ codebook is configured. + */ [[nodiscard]] auto vq_code_book() const noexcept - -> raft::device_matrix_view + -> std::optional> { return codebooks.vq_code_book(); } diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_vpq.hpp b/cpp/src/neighbors/detail/cagra/compute_distance_vpq.hpp index 59f1cafff5..6daefa501f 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance_vpq.hpp +++ b/cpp/src/neighbors/detail/cagra/compute_distance_vpq.hpp @@ -50,10 +50,11 @@ struct vpq_descriptor_spec : public instance_spec { cuvs::distance::DistanceType metric, const DistanceT* dataset_norms = nullptr) -> host_type { + auto vq_view = dataset.vq_code_book(); return init_(params, dataset.data.data_handle(), dataset.encoded_row_length(), - dataset.vq_code_book().data_handle(), + vq_view.has_value() ? vq_view->data_handle() : nullptr, dataset.pq_code_book().data_handle(), IndexT(dataset.n_rows()), dataset.dim()); diff --git a/cpp/src/neighbors/detail/dataset_serialize.hpp b/cpp/src/neighbors/detail/dataset_serialize.hpp index 90676f524d..d2cc0432ee 100644 --- a/cpp/src/neighbors/detail/dataset_serialize.hpp +++ b/cpp/src/neighbors/detail/dataset_serialize.hpp @@ -69,7 +69,9 @@ void serialize(const raft::resources& res, raft::serialize_scalar(res, os, dataset.pq_n_centers()); raft::serialize_scalar(res, os, dataset.pq_len()); raft::serialize_scalar(res, os, dataset.encoded_row_length()); - raft::serialize_mdspan(res, os, dataset.vq_code_book()); + auto vq_view = dataset.vq_code_book().value_or( + raft::make_device_matrix_view(nullptr, 0, 0)); + raft::serialize_mdspan(res, os, vq_view); raft::serialize_mdspan(res, os, dataset.pq_code_book()); raft::serialize_mdspan(res, os, dataset.data.view()); } @@ -160,10 +162,12 @@ auto deserialize_vpq(raft::resources const& res, std::istream& is) raft::deserialize_mdspan(res, is, pq_code_book.view()); raft::deserialize_mdspan(res, is, data.view()); + std::optional> vq_code_book_opt; + if (vq_n_centers > 0) { vq_code_book_opt = std::move(vq_code_book); } return std::make_unique>( cuvs::preprocessing::quantize::pq::vpq_codebooks{ std::make_unique>( - std::move(vq_code_book), std::move(pq_code_book))}, + std::move(pq_code_book), std::move(vq_code_book_opt))}, std::move(data)); } diff --git a/cpp/src/preprocessing/quantize/detail/pq.cuh b/cpp/src/preprocessing/quantize/detail/pq.cuh index 87daf15e57..9f17cfa426 100644 --- a/cpp/src/preprocessing/quantize/detail/pq.cuh +++ b/cpp/src/preprocessing/quantize/detail/pq.cuh @@ -154,9 +154,11 @@ quantizer build( pq_code_book = cuvs::neighbors::detail::train_pq( res, vpq_params, dataset, raft::make_const_mdspan(vq_code_book.view())); } + std::optional> vq_code_book_opt; + if (filled_params.use_vq) { vq_code_book_opt = std::move(vq_code_book); } return {filled_params, vpq_codebooks{std::make_unique>( - std::move(vq_code_book), std::move(pq_code_book))}}; + std::move(pq_code_book), std::move(vq_code_book_opt))}}; } template @@ -195,12 +197,9 @@ quantizer build_view( vq_centers.value().extent(0)); } - auto vq_view = - vq_centers.has_value() - ? vq_centers.value() - : raft::make_device_matrix_view(nullptr, 0, 0); - return {params, - vpq_codebooks{std::make_unique>(vq_view, pq_centers)}}; + return { + params, + vpq_codebooks{std::make_unique>(pq_centers, vq_centers)}}; } template @@ -223,7 +222,9 @@ void transform( RAFT_EXPECTS(quantizer.params_quantizer.pq_bits >= 4 && quantizer.params_quantizer.pq_bits <= 16, "PQ bits must be within [4, 16]"); - auto vq_centers = quantizer.codebooks.vq_code_book(); + auto vq_centers_opt = quantizer.codebooks.vq_code_book(); + auto vq_centers = vq_centers_opt.value_or( + raft::make_device_matrix_view(nullptr, 0, 0)); auto pq_centers = quantizer.codebooks.pq_code_book(); auto vq_labels_view = raft::make_device_vector_view(nullptr, 0); if (vq_labels.has_value()) { vq_labels_view = vq_labels.value(); } @@ -360,11 +361,14 @@ void inverse_transform( RAFT_EXPECTS(quantizer.params_quantizer.pq_bits >= 4 && quantizer.params_quantizer.pq_bits <= 16, "PQ bits must be within [4, 16]"); + auto vq_centers_opt = quantizer.codebooks.vq_code_book(); + auto vq_centers = vq_centers_opt.value_or( + raft::make_device_matrix_view(nullptr, 0, 0)); reconstruct_vectors(res, quantizer.params_quantizer, codes, quantizer.codebooks.pq_code_book(), - quantizer.codebooks.vq_code_book(), + vq_centers, vq_labels, out, quantizer.params_quantizer.use_subspaces); @@ -374,21 +378,26 @@ template auto vpq_convert_math_type(const raft::resources& res, const vpq_codebooks& src) -> vpq_codebooks { - auto vq_src = src.vq_code_book(); - auto pq_src = src.pq_code_book(); + auto vq_src_opt = src.vq_code_book(); + auto pq_src = src.pq_code_book(); - auto vq_new = raft::make_device_matrix( - res, vq_src.extent(0), vq_src.extent(1)); auto pq_new = raft::make_device_matrix( res, pq_src.extent(0), pq_src.extent(1)); - - raft::linalg::map( - res, vq_new.view(), cuvs::spatial::knn::detail::utils::mapping{}, vq_src); raft::linalg::map( res, pq_new.view(), cuvs::spatial::knn::detail::utils::mapping{}, pq_src); + std::optional> vq_new_opt; + if (vq_src_opt.has_value()) { + auto vq_src = vq_src_opt.value(); + auto vq_new = raft::make_device_matrix( + res, vq_src.extent(0), vq_src.extent(1)); + raft::linalg::map( + res, vq_new.view(), cuvs::spatial::knn::detail::utils::mapping{}, vq_src); + vq_new_opt = std::move(vq_new); + } + return vpq_codebooks{ - std::make_unique>(std::move(vq_new), std::move(pq_new))}; + std::make_unique>(std::move(pq_new), std::move(vq_new_opt))}; } template @@ -419,8 +428,8 @@ auto vpq_build(const raft::resources& res, true); return vpq_dataset{ - vpq_codebooks{std::make_unique>(std::move(vq_code_book), - std::move(pq_code_book))}, + vpq_codebooks{std::make_unique>(std::move(pq_code_book), + std::move(vq_code_book))}, std::move(codes)}; } diff --git a/cpp/src/preprocessing/quantize/detail/vpq_dataset_impl.hpp b/cpp/src/preprocessing/quantize/detail/vpq_dataset_impl.hpp index dc827242e0..ac879a9d4c 100644 --- a/cpp/src/preprocessing/quantize/detail/vpq_dataset_impl.hpp +++ b/cpp/src/preprocessing/quantize/detail/vpq_dataset_impl.hpp @@ -7,16 +7,21 @@ #include +#include + namespace cuvs::preprocessing::quantize::pq { template class vpq_codebooks_owning : public vpq_codebooks_iface { public: - using math_type = MathT; - - vpq_codebooks_owning(raft::device_matrix&& vq_code_book, - raft::device_matrix&& pq_code_book) - : vq_code_book_{std::move(vq_code_book)}, pq_code_book_{std::move(pq_code_book)} + using math_type = MathT; + using matrix_type = raft::device_matrix; + + // PQ codebook is required; VQ codebook is optional and defaults to absent + // When VQ is not provided, vq_code_book() returns std::nullopt. + explicit vpq_codebooks_owning(matrix_type&& pq_code_book, + std::optional&& vq_code_book = std::nullopt) + : pq_code_book_{std::move(pq_code_book)}, vq_code_book_{std::move(vq_code_book)} { } @@ -27,15 +32,10 @@ class vpq_codebooks_owning : public vpq_codebooks_iface { ~vpq_codebooks_owning() override = default; [[nodiscard]] auto vq_code_book() const noexcept - -> raft::device_matrix_view override - { - return vq_code_book_.view(); - } - - [[nodiscard]] auto vq_code_book() noexcept - -> raft::device_matrix_view + -> std::optional> override { - return vq_code_book_.view(); + if (!vq_code_book_.has_value()) { return std::nullopt; } + return vq_code_book_.value().view(); } [[nodiscard]] auto pq_code_book() const noexcept @@ -51,19 +51,21 @@ class vpq_codebooks_owning : public vpq_codebooks_iface { } private: - raft::device_matrix vq_code_book_; - raft::device_matrix pq_code_book_; + matrix_type pq_code_book_; + std::optional vq_code_book_; }; template class vpq_codebooks_view : public vpq_codebooks_iface { public: using math_type = MathT; + using view_type = raft::device_matrix_view; - vpq_codebooks_view( - raft::device_matrix_view vq_code_book_view, - raft::device_matrix_view pq_code_book_view) - : vq_code_book_view_{vq_code_book_view}, pq_code_book_view_{pq_code_book_view} + // PQ codebook view is required; VQ codebook view is optional and defaults to absent. + // When VQ is not provided, vq_code_book() returns std::nullopt. + explicit vpq_codebooks_view(view_type pq_code_book_view, + std::optional vq_code_book_view = std::nullopt) + : pq_code_book_view_{pq_code_book_view}, vq_code_book_view_{vq_code_book_view} { } @@ -73,20 +75,18 @@ class vpq_codebooks_view : public vpq_codebooks_iface { vpq_codebooks_view& operator=(vpq_codebooks_view&&) = default; ~vpq_codebooks_view() override = default; - [[nodiscard]] auto vq_code_book() const noexcept - -> raft::device_matrix_view override + [[nodiscard]] auto vq_code_book() const noexcept -> std::optional override { return vq_code_book_view_; } - [[nodiscard]] auto pq_code_book() const noexcept - -> raft::device_matrix_view override + [[nodiscard]] auto pq_code_book() const noexcept -> view_type override { return pq_code_book_view_; } private: - raft::device_matrix_view vq_code_book_view_; - raft::device_matrix_view pq_code_book_view_; + view_type pq_code_book_view_; + std::optional vq_code_book_view_; }; } // namespace cuvs::preprocessing::quantize::pq From ee0170e0c8736c405dafcd2b6385486955f42da7 Mon Sep 17 00:00:00 2001 From: tarangj Date: Thu, 30 Apr 2026 12:03:28 -0700 Subject: [PATCH 31/36] style --- cpp/src/neighbors/detail/vpq_dataset.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/neighbors/detail/vpq_dataset.cuh b/cpp/src/neighbors/detail/vpq_dataset.cuh index fb7e149775..3d5e7f2247 100644 --- a/cpp/src/neighbors/detail/vpq_dataset.cuh +++ b/cpp/src/neighbors/detail/vpq_dataset.cuh @@ -5,8 +5,8 @@ #pragma once #include -#include #include +#include #include "../../cluster/kmeans_balanced.cuh" #include "../../preprocessing/quantize/detail/pq_codepacking.cuh" // pq_bits-bitfield From 04888f901167a163e12d258001da45b867463611 Mon Sep 17 00:00:00 2001 From: tarangj Date: Thu, 30 Apr 2026 12:49:19 -0700 Subject: [PATCH 32/36] input validation for vq codebooks --- cpp/src/preprocessing/quantize/detail/pq.cuh | 44 ++++++++++++++++---- 1 file changed, 35 insertions(+), 9 deletions(-) diff --git a/cpp/src/preprocessing/quantize/detail/pq.cuh b/cpp/src/preprocessing/quantize/detail/pq.cuh index 8958801eef..932637569a 100644 --- a/cpp/src/preprocessing/quantize/detail/pq.cuh +++ b/cpp/src/preprocessing/quantize/detail/pq.cuh @@ -230,14 +230,25 @@ quantizer build_view( if (params.use_vq) { RAFT_EXPECTS(vq_centers.has_value(), "vq_centers must be provided when use_vq=true"); + RAFT_EXPECTS(params.vq_n_centers > 0, + "params.vq_n_centers must be > 0 when use_vq=true (got %u)", + params.vq_n_centers); + RAFT_EXPECTS(vq_centers.value().data_handle() != nullptr, + "vq_centers data pointer must be non-null when use_vq=true"); RAFT_EXPECTS(vq_centers.value().extent(0) == params.vq_n_centers, - "vq_centers must have vq_n_centers rows, got %u", + "vq_centers must have shape [vq_n_centers, dim] (vq_n_centers=%u), got " + "extent(0)=%u", + params.vq_n_centers, vq_centers.value().extent(0)); + return { + params, + vpq_codebooks{std::make_unique>(pq_centers, vq_centers)}}; + } else { + if (vq_centers.has_value()) { + RAFT_LOG_WARN("vq_centers will be ignored since params.use_vq=false"); + } + return {params, vpq_codebooks{std::make_unique>(pq_centers)}}; } - - return { - params, - vpq_codebooks{std::make_unique>(pq_centers, vq_centers)}}; } template @@ -260,9 +271,19 @@ void transform( RAFT_EXPECTS(quantizer.params_quantizer.pq_bits >= 4 && quantizer.params_quantizer.pq_bits <= 16, "PQ bits must be within [4, 16]"); + // Honor params.use_vq as the source of truth: when it is false, pass an + // empty view to the kernel regardless of what the codebooks contain + // (the kernel gates VQ subtraction on vq_centers.empty(), so an empty view + // guarantees no residual VQ is applied even if a misconfigured quantizer + // somehow carries non-empty centers). Conversely, when it is true the + // codebook must be present. auto vq_centers_opt = quantizer.codebooks.vq_code_book(); - auto vq_centers = vq_centers_opt.value_or( - raft::make_device_matrix_view(nullptr, 0, 0)); + RAFT_EXPECTS(!quantizer.params_quantizer.use_vq || vq_centers_opt.has_value(), + "Quantizer has params.use_vq=true but no VQ codebook"); + auto vq_centers = + quantizer.params_quantizer.use_vq + ? vq_centers_opt.value() + : raft::make_device_matrix_view(nullptr, 0, 0); auto pq_centers = quantizer.codebooks.pq_code_book(); auto vq_labels_view = raft::make_device_vector_view(nullptr, 0); if (vq_labels.has_value()) { vq_labels_view = vq_labels.value(); } @@ -399,9 +420,14 @@ void inverse_transform( RAFT_EXPECTS(quantizer.params_quantizer.pq_bits >= 4 && quantizer.params_quantizer.pq_bits <= 16, "PQ bits must be within [4, 16]"); + // Honor params.use_vq strictly (see the matching block in transform()). auto vq_centers_opt = quantizer.codebooks.vq_code_book(); - auto vq_centers = vq_centers_opt.value_or( - raft::make_device_matrix_view(nullptr, 0, 0)); + RAFT_EXPECTS(!quantizer.params_quantizer.use_vq || vq_centers_opt.has_value(), + "Quantizer has params.use_vq=true but no VQ codebook"); + auto vq_centers = + quantizer.params_quantizer.use_vq + ? vq_centers_opt.value() + : raft::make_device_matrix_view(nullptr, 0, 0); reconstruct_vectors(res, quantizer.params_quantizer, codes, From fd1eb38128aa114ec5cf5de3b1b496998505d449 Mon Sep 17 00:00:00 2001 From: tarangj Date: Thu, 30 Apr 2026 15:43:50 -0700 Subject: [PATCH 33/36] add warning --- .../cuvs/preprocessing/quantize/vpq_dataset.hpp | 11 ++++++++++- .../quantize/detail/vpq_dataset_impl.hpp | 6 +++--- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/cpp/include/cuvs/preprocessing/quantize/vpq_dataset.hpp b/cpp/include/cuvs/preprocessing/quantize/vpq_dataset.hpp index c19aa9f1b9..aadf98c4e0 100644 --- a/cpp/include/cuvs/preprocessing/quantize/vpq_dataset.hpp +++ b/cpp/include/cuvs/preprocessing/quantize/vpq_dataset.hpp @@ -6,6 +6,7 @@ #pragma once #include +#include #include @@ -117,7 +118,15 @@ class vpq_codebooks { return impl_->pq_code_book(); } - [[nodiscard]] auto dim() const noexcept -> uint32_t { return impl_->dim(); } + [[nodiscard]] auto dim() const noexcept -> uint32_t + { + if (!impl_->vq_code_book().has_value()) { + RAFT_LOG_WARN( + "vpq_codebooks::dim() returns 0 when no VQ codebook is present; " + "the original vector dimension cannot be recovered from PQ codebooks alone."); + } + return impl_->dim(); + } [[nodiscard]] auto vq_n_centers() const noexcept -> uint32_t { return impl_->vq_n_centers(); } [[nodiscard]] auto pq_bits() const noexcept -> uint32_t { return impl_->pq_bits(); } [[nodiscard]] auto pq_dim() const noexcept -> uint32_t { return impl_->pq_dim(); } diff --git a/cpp/src/preprocessing/quantize/detail/vpq_dataset_impl.hpp b/cpp/src/preprocessing/quantize/detail/vpq_dataset_impl.hpp index ac879a9d4c..9d50044ff6 100644 --- a/cpp/src/preprocessing/quantize/detail/vpq_dataset_impl.hpp +++ b/cpp/src/preprocessing/quantize/detail/vpq_dataset_impl.hpp @@ -17,7 +17,7 @@ class vpq_codebooks_owning : public vpq_codebooks_iface { using math_type = MathT; using matrix_type = raft::device_matrix; - // PQ codebook is required; VQ codebook is optional and defaults to absent + // PQ codebook is required; VQ codebook is optional and defaults to absent. // When VQ is not provided, vq_code_book() returns std::nullopt. explicit vpq_codebooks_owning(matrix_type&& pq_code_book, std::optional&& vq_code_book = std::nullopt) @@ -61,8 +61,8 @@ class vpq_codebooks_view : public vpq_codebooks_iface { using math_type = MathT; using view_type = raft::device_matrix_view; - // PQ codebook view is required; VQ codebook view is optional and defaults to absent. - // When VQ is not provided, vq_code_book() returns std::nullopt. + // PQ codebook view is required; VQ codebook view is optional and defaults to + // absent. When VQ is not provided, vq_code_book() returns std::nullopt. explicit vpq_codebooks_view(view_type pq_code_book_view, std::optional vq_code_book_view = std::nullopt) : pq_code_book_view_{pq_code_book_view}, vq_code_book_view_{vq_code_book_view} From 075c1bae37c13855e25e93fc2bac7b18b5c30a32 Mon Sep 17 00:00:00 2001 From: tarangj Date: Thu, 30 Apr 2026 16:04:54 -0700 Subject: [PATCH 34/36] check vq_labels --- cpp/src/preprocessing/quantize/detail/pq.cuh | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/cpp/src/preprocessing/quantize/detail/pq.cuh b/cpp/src/preprocessing/quantize/detail/pq.cuh index 932637569a..956bd2e4ac 100644 --- a/cpp/src/preprocessing/quantize/detail/pq.cuh +++ b/cpp/src/preprocessing/quantize/detail/pq.cuh @@ -428,6 +428,20 @@ void inverse_transform( quantizer.params_quantizer.use_vq ? vq_centers_opt.value() : raft::make_device_matrix_view(nullptr, 0, 0); + + // VQ-label preflight: when use_vq is true the kernel reads vq_labels(row) per + // row and falls back to label 0 when vq_labels is absent. Without this check every vector can be + // reconstructed with vq_label 0. + if (quantizer.params_quantizer.use_vq) { + RAFT_EXPECTS(vq_labels.has_value(), + "When params.use_vq is true, vq_labels must be provided to inverse_transform()"); + RAFT_EXPECTS(vq_labels.value().extent(0) == codes.extent(0), + "When params.use_vq is true, vq_labels must have the same number of rows as " + "codes (got %zu vs %zu)", + size_t(vq_labels.value().extent(0)), + size_t(codes.extent(0))); + } + reconstruct_vectors(res, quantizer.params_quantizer, codes, From 740b78eda123243555825860c748232178abd8d0 Mon Sep 17 00:00:00 2001 From: tarangj Date: Thu, 30 Apr 2026 17:05:01 -0700 Subject: [PATCH 35/36] fix compilation errors --- cpp/src/preprocessing/quantize/detail/pq.cuh | 24 ------------------- .../preprocessing/product_quantization.cu | 11 +++++++-- 2 files changed, 9 insertions(+), 26 deletions(-) diff --git a/cpp/src/preprocessing/quantize/detail/pq.cuh b/cpp/src/preprocessing/quantize/detail/pq.cuh index 956bd2e4ac..87771474b8 100644 --- a/cpp/src/preprocessing/quantize/detail/pq.cuh +++ b/cpp/src/preprocessing/quantize/detail/pq.cuh @@ -502,30 +502,6 @@ inline auto make_pq_params_from_vpq(const cuvs::neighbors::vpq_params& in_params max_train_points_per_vq_cluster}; } -inline auto make_pq_params_from_vpq(const cuvs::neighbors::vpq_params& in_params, - const uint64_t n_rows) - -> cuvs::preprocessing::quantize::pq::params -{ - const uint32_t pq_n_centers = 1 << in_params.pq_bits; - uint32_t max_train_points_per_vq_cluster = in_params.max_train_points_per_vq_cluster; - if (in_params.vq_n_centers > 0) { - max_train_points_per_vq_cluster = - std::min(max_train_points_per_vq_cluster, - n_rows * in_params.vq_kmeans_trainset_fraction / in_params.vq_n_centers); - } - return cuvs::preprocessing::quantize::pq::params{ - in_params.pq_bits, - in_params.pq_dim, - true, - true, - in_params.vq_n_centers, - in_params.kmeans_n_iters, - in_params.pq_kmeans_type, - std::min(in_params.max_train_points_per_pq_code, - n_rows * in_params.pq_kmeans_trainset_fraction / pq_n_centers), - max_train_points_per_vq_cluster}; -} - template auto vpq_build(const raft::resources& res, const cuvs::neighbors::vpq_params& params, diff --git a/cpp/tests/preprocessing/product_quantization.cu b/cpp/tests/preprocessing/product_quantization.cu index 37e03f2e1f..a3daaea3a8 100644 --- a/cpp/tests/preprocessing/product_quantization.cu +++ b/cpp/tests/preprocessing/product_quantization.cu @@ -147,7 +147,10 @@ class ProductQuantizationTest : public ::testing::TestWithParam(dataset_.data_handle(), n_take, dim); std::optional> vq_labels_view = std::nullopt; - if (vq_labels) { vq_labels_view = raft::make_const_mdspan(vq_labels.value()); } + if (vq_labels) { + vq_labels_view = raft::make_device_vector_view( + vq_labels.value().data_handle(), static_cast(n_take)); + } inverse_transform(handle, quantizer, raft::device_matrix_view( @@ -242,7 +245,11 @@ class ProductQuantizationTest : public ::testing::TestWithParam(n_samples_) < params_.n_vq_centers)) { From c5e231c85435539fbc67fc7755983f7aae2e46c7 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Mon, 4 May 2026 15:27:58 -0700 Subject: [PATCH 36/36] add instantiations --- cpp/bench/ann/CMakeLists.txt | 2 +- .../preprocessing/quantize/vpq_dataset.hpp | 18 ++++++++---------- cpp/src/preprocessing/quantize/pq.cu | 5 +++++ 3 files changed, 14 insertions(+), 11 deletions(-) diff --git a/cpp/bench/ann/CMakeLists.txt b/cpp/bench/ann/CMakeLists.txt index 4e4527267c..6b5d2ebe9c 100644 --- a/cpp/bench/ann/CMakeLists.txt +++ b/cpp/bench/ann/CMakeLists.txt @@ -157,7 +157,7 @@ function(ConfigureAnnBench) add_dependencies(${BENCH_NAME} ANN_BENCH) else() add_executable(${BENCH_NAME} ${ConfigureAnnBench_PATH}) - target_compile_definitions(${BENCH_NAME} PRIVATE ANN_BENCH_BUILD_MAIN>) + target_compile_definitions(${BENCH_NAME} PRIVATE ANN_BENCH_BUILD_MAIN) target_link_libraries( ${BENCH_NAME} PRIVATE benchmark::benchmark $<$:CUDA::nvtx3> ) diff --git a/cpp/include/cuvs/preprocessing/quantize/vpq_dataset.hpp b/cpp/include/cuvs/preprocessing/quantize/vpq_dataset.hpp index aadf98c4e0..52f1c7d958 100644 --- a/cpp/include/cuvs/preprocessing/quantize/vpq_dataset.hpp +++ b/cpp/include/cuvs/preprocessing/quantize/vpq_dataset.hpp @@ -41,7 +41,13 @@ class vpq_codebooks_iface { [[nodiscard]] virtual auto dim() const noexcept -> uint32_t { auto vq = vq_code_book(); - return vq.has_value() ? vq->extent(1) : 0; + if (!vq.has_value()) { + RAFT_LOG_WARN( + "vpq_codebooks_iface::dim() returns 0 when no VQ codebook is present; " + "the original vector dimension cannot be recovered from PQ codebooks alone."); + return 0; + } + return vq->extent(1); } [[nodiscard]] virtual auto vq_n_centers() const noexcept -> uint32_t { @@ -118,15 +124,7 @@ class vpq_codebooks { return impl_->pq_code_book(); } - [[nodiscard]] auto dim() const noexcept -> uint32_t - { - if (!impl_->vq_code_book().has_value()) { - RAFT_LOG_WARN( - "vpq_codebooks::dim() returns 0 when no VQ codebook is present; " - "the original vector dimension cannot be recovered from PQ codebooks alone."); - } - return impl_->dim(); - } + [[nodiscard]] auto dim() const noexcept -> uint32_t { return impl_->dim(); } [[nodiscard]] auto vq_n_centers() const noexcept -> uint32_t { return impl_->vq_n_centers(); } [[nodiscard]] auto pq_bits() const noexcept -> uint32_t { return impl_->pq_bits(); } [[nodiscard]] auto pq_dim() const noexcept -> uint32_t { return impl_->pq_dim(); } diff --git a/cpp/src/preprocessing/quantize/pq.cu b/cpp/src/preprocessing/quantize/pq.cu index c9e203c8bf..6eacb92224 100644 --- a/cpp/src/preprocessing/quantize/pq.cu +++ b/cpp/src/preprocessing/quantize/pq.cu @@ -82,4 +82,9 @@ CUVS_INST_VPQ_BUILD(uint8_t); #undef CUVS_INST_VPQ_BUILD +template class vpq_codebooks; +template class vpq_codebooks; +template class vpq_dataset; +template class vpq_dataset; + } // namespace cuvs::preprocessing::quantize::pq