From 66d7fd3d76929d4f45fb46b24b992dc6697df62b Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Fri, 10 Apr 2026 15:54:51 -0700 Subject: [PATCH 01/50] combine impls --- cpp/src/cluster/detail/kmeans.cuh | 567 ++++++++++++---------- cpp/src/cluster/detail/kmeans_batched.cuh | 510 ------------------- cpp/src/cluster/detail/kmeans_common.cuh | 178 +++++++ cpp/src/cluster/kmeans_fit_double.cu | 3 +- cpp/src/cluster/kmeans_fit_float.cu | 3 +- cpp/src/cluster/kmeans_impl.cuh | 24 +- 6 files changed, 501 insertions(+), 784 deletions(-) delete mode 100644 cpp/src/cluster/detail/kmeans_batched.cuh diff --git a/cpp/src/cluster/detail/kmeans.cuh b/cpp/src/cluster/detail/kmeans.cuh index 5a35f203b3..1c019a07d8 100644 --- a/cpp/src/cluster/detail/kmeans.cuh +++ b/cpp/src/cluster/detail/kmeans.cuh @@ -5,6 +5,7 @@ #pragma once #include "../../core/nvtx.hpp" +#include "../../neighbors/detail/ann_utils.cuh" #include "kmeans_common.cuh" #include @@ -31,6 +32,7 @@ #include #include #include +#include #include #include #include @@ -44,6 +46,7 @@ #include #include #include +#include #include #include @@ -303,150 +306,17 @@ void update_centroids(raft::resources const& handle, new_centroids); } -// TODO: Resizing is needed to use mdarray instead of rmm::device_uvector -template -void kmeans_fit_main(raft::resources const& handle, - const cuvs::cluster::kmeans::params& params, - raft::device_matrix_view X, - raft::device_vector_view weight, - raft::device_matrix_view centroidsRawData, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter, - rmm::device_uvector& workspace) -{ - raft::common::nvtx::range fun_scope("kmeans_fit_main"); - raft::default_logger().set_level(params.verbosity); - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - auto n_samples = X.extent(0); - auto n_features = X.extent(1); - auto n_clusters = params.n_clusters; - auto metric = params.metric; - - // stores (key, value) pair corresponding to each sample where - // - key is the index of nearest cluster - // - value is the distance to the nearest cluster - auto minClusterAndDistance = - raft::make_device_vector, IndexT>(handle, n_samples); - - // temporary buffer to store L2 norm of centroids or distance matrix, - // destructor releases the resource - rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); - - // temporary buffer to store intermediate centroids, destructor releases the - // resource - auto newCentroids = raft::make_device_matrix(handle, n_clusters, n_features); - - // temporary buffer to store weights per cluster, destructor releases the - // resource - auto wtInCluster = raft::make_device_vector(handle, n_clusters); - - rmm::device_scalar clusterCostD(stream); - - // L2 norm of X: ||x||^2 - auto L2NormX = raft::make_device_vector(handle, n_samples); - auto l2normx_view = - raft::make_device_vector_view(L2NormX.data_handle(), n_samples); - - if (metric == cuvs::distance::DistanceType::L2Expanded || - metric == cuvs::distance::DistanceType::L2SqrtExpanded) { - raft::linalg::norm(handle, X, L2NormX.view()); - } - - RAFT_LOG_DEBUG( - "Calling KMeans.fit with %d samples of input data and the initialized " - "cluster centers", - n_samples); - - DataT priorClusteringCost = 0; - for (n_iter[0] = 1; n_iter[0] <= params.max_iter; ++n_iter[0]) { - RAFT_LOG_DEBUG( - "KMeans.fit: Iteration-%d: fitting the model using the initialized " - "cluster centers", - n_iter[0]); - - auto centroids = raft::make_device_matrix_view( - centroidsRawData.data_handle(), n_clusters, n_features); - - // computes minClusterAndDistance[0:n_samples) where - // minClusterAndDistance[i] is a pair where - // 'key' is index to a sample in 'centroids' (index of the nearest - // centroid) and 'value' is the distance between the sample 'X[i]' and the - // 'centroid[key]' - cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( - handle, - X, - centroids, - minClusterAndDistance.view(), - l2normx_view, - L2NormBuf_OR_DistBuf, - params.metric, - params.batch_samples, - params.batch_centroids, - workspace); - - update_centroids( - handle, - X, - weight, - raft::make_device_matrix_view( - centroidsRawData.data_handle(), n_clusters, n_features), - cuda::transform_iterator(minClusterAndDistance.data_handle(), - cuvs::cluster::kmeans::detail::KeyValueIndexOp{}), - wtInCluster.view(), - newCentroids.view(), - workspace); - - // Compute how much centroids shifted - DataT sqrdNormError = compute_centroid_shift( - handle, raft::make_const_mdspan(centroids), raft::make_const_mdspan(newCentroids.view())); - - raft::copy(handle, - raft::make_device_vector_view(centroidsRawData.data_handle(), newCentroids.size()), - raft::make_device_vector_view(newCentroids.data_handle(), newCentroids.size())); - - bool done = false; - if (params.inertia_check) { - // calculate cluster cost phi_x(C) - cuvs::cluster::kmeans::detail::computeClusterCost( - handle, - minClusterAndDistance.view(), - workspace, - raft::make_device_scalar_view(clusterCostD.data()), - raft::value_op{}, - raft::add_op{}); - - DataT curClusteringCost = clusterCostD.value(stream); - - ASSERT(curClusteringCost != (DataT)0.0, - "Too few points and centroids being found is getting 0 cost from " - "centers"); - - if (n_iter[0] > 1) { - DataT delta = curClusteringCost / priorClusteringCost; - if (delta > 1 - params.tol) done = true; - } - priorClusteringCost = curClusteringCost; - } - - if (sqrdNormError < params.tol) done = true; - - if (done) { - RAFT_LOG_DEBUG("Threshold triggered after %d iterations. Terminating early.", n_iter[0]); - break; - } - } - - cuvs::cluster::kmeans::cluster_cost(handle, - X, - raft::make_device_matrix_view( - centroidsRawData.data_handle(), n_clusters, n_features), - inertia, - std::make_optional(weight)); - - RAFT_LOG_DEBUG("KMeans.fit: completed after %d iterations with %f inertia[0] ", - n_iter[0] > params.max_iter ? n_iter[0] - 1 : n_iter[0], - inertia[0]); -} +template +void kmeans_fit( + raft::resources const& handle, + const cuvs::cluster::kmeans::params& pams, + raft::mdspan, raft::row_major, Accessor> X, + std::optional< + raft::mdspan, raft::layout_right, Accessor>> + sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter); /* * @brief Selects 'n_clusters' samples from X using scalable kmeans++ algorithm. @@ -651,17 +521,20 @@ void initScalableKMeansPlusPlus(raft::resources const& handle, auto inertia = raft::make_host_scalar(0); auto n_iter = raft::make_host_scalar(0); - cuvs::cluster::kmeans::params default_params; - default_params.n_clusters = params.n_clusters; - - cuvs::cluster::kmeans::detail::kmeans_fit_main(handle, - default_params, - potentialCentroids, - weight.view(), - centroidsRawData, - inertia.view(), - n_iter.view(), - workspace); + cuvs::cluster::kmeans::params recluster_params; + recluster_params.n_clusters = params.n_clusters; + recluster_params.init = cuvs::cluster::kmeans::params::InitMethod::Array; + recluster_params.n_init = 1; + + auto weight_opt = std::make_optional(raft::make_const_mdspan(weight.view())); + cuvs::cluster::kmeans::detail::kmeans_fit( + handle, + recluster_params, + raft::make_const_mdspan(potentialCentroids), + weight_opt, + centroidsRawData, + inertia.view(), + n_iter.view()); } else if ((int)potentialCentroids.extent(0) < n_clusters) { // supplement with random @@ -697,90 +570,118 @@ void initScalableKMeansPlusPlus(raft::resources const& handle, } /** - * @brief Find clusters with k-means algorithm. - * Initial centroids are chosen with k-means++ algorithm. Empty - * clusters are reinitialized by choosing new centroids with - * k-means++ algorithm. - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. - * @param[in] X Training instances to cluster. It must be noted - * that the data must be in row-major format and stored in device accessible - * location. - * @param[in] n_samples Number of samples in the input X. - * @param[in] n_features Number of features or the dimensions of each - * sample. - * @param[in] sample_weight Optional weights for each observation in X. - * @param[inout] centroids [in] When init is InitMethod::Array, use - * centroids as the initial cluster centers - * [out] Otherwise, generated centroids from the - * kmeans algorithm is stored at the address pointed by 'centroids'. - * @param[out] inertia Sum of squared distances of samples to their - * closest cluster center. - * @param[out] n_iter Number of iterations run. + * @brief Unified k-means fit (works with host or device data). + * + * Handles centroid initialization, the n_init best-of loop, and the batched + * Lloyd iteration. All reusable work buffers are allocated once before the + * n_init loop and shared across iterations. + * + * Data and weights are batched via batch_load_iterator (transparent H2D copy + * for host memory, zero-copy offset for device). Only batch-sized device + * buffers are allocated — no O(n_samples) device memory. + * + * @tparam DataT Data / weight type + * @tparam IndexT Index type + * @tparam Accessor Accessor policy (host or device); deduced from X */ -template -void kmeans_fit(raft::resources const& handle, - const cuvs::cluster::kmeans::params& pams, - raft::device_matrix_view X, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) +template +void kmeans_fit( + raft::resources const& handle, + const cuvs::cluster::kmeans::params& pams, + raft::mdspan, raft::row_major, Accessor> X, + std::optional< + raft::mdspan, raft::layout_right, Accessor>> + sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) { raft::common::nvtx::range fun_scope("kmeans_fit"); auto n_samples = X.extent(0); auto n_features = X.extent(1); auto n_clusters = pams.n_clusters; + auto metric = pams.metric; cudaStream_t stream = raft::resource::get_cuda_stream(handle); - // Check that parameters are valid + if (sample_weight.has_value()) RAFT_EXPECTS(sample_weight.value().extent(0) == n_samples, "invalid parameter (sample_weight!=n_samples)"); RAFT_EXPECTS(n_clusters > 0, "invalid parameter (n_clusters<=0)"); RAFT_EXPECTS(pams.tol > 0, "invalid parameter (tol<=0)"); RAFT_EXPECTS(pams.oversampling_factor >= 0, "invalid parameter (oversampling_factor<0)"); - RAFT_EXPECTS((int)centroids.extent(0) == pams.n_clusters, + RAFT_EXPECTS(static_cast(centroids.extent(0)) == n_clusters, "invalid parameter (centroids.extent(0) != n_clusters)"); RAFT_EXPECTS(centroids.extent(1) == n_features, "invalid parameter (centroids.extent(1) != n_features)"); - // Display a message if the batch size is smaller than n_samples but will be ignored - if (pams.batch_samples < (int)n_samples && - (pams.metric == cuvs::distance::DistanceType::L2Expanded || - pams.metric == cuvs::distance::DistanceType::L2SqrtExpanded)) { - RAFT_LOG_DEBUG( - "batch_samples=%d was passed, but batch_samples=%d will be used (reason: " - "batch_samples has no impact on the memory footprint when FusedL2NN can be used)", - pams.batch_samples, - (int)n_samples); - } - // Display a message if batch_centroids is set and a fusedL2NN-compatible metric is used - if (pams.batch_centroids != 0 && pams.batch_centroids != pams.n_clusters && - (pams.metric == cuvs::distance::DistanceType::L2Expanded || - pams.metric == cuvs::distance::DistanceType::L2SqrtExpanded)) { - RAFT_LOG_DEBUG( - "batch_centroids=%d was passed, but batch_centroids=%d will be used (reason: " - "batch_centroids has no impact on the memory footprint when FusedL2NN can be used)", - pams.batch_centroids, - pams.n_clusters); + raft::default_logger().set_level(pams.verbosity); + + IndexT streaming_batch_size = static_cast(pams.streaming_batch_size); + if (streaming_batch_size <= 0 || streaming_batch_size > static_cast(n_samples)) { + streaming_batch_size = static_cast(n_samples); } - raft::default_logger().set_level(pams.verbosity); + const DataT* weight_ptr = + sample_weight.has_value() ? sample_weight.value().data_handle() : nullptr; + DataT weight_scale = compute_weight_scale(weight_ptr, n_samples, stream); - // Allocate memory rmm::device_uvector workspace(0, stream); - auto weight = raft::make_device_vector(handle, n_samples); - if (sample_weight.has_value()) - raft::copy(handle, weight.view(), sample_weight.value()); - else - raft::matrix::fill(handle, weight.view(), DataT(1)); - // check if weights sum up to n_samples - checkWeight(handle, weight.view(), workspace); + constexpr bool data_on_device = !raft::is_host_mdspan_v; - auto centroidsRawData = raft::make_device_matrix(handle, n_clusters, n_features); + auto init_centroids = [&](const cuvs::cluster::kmeans::params& iter_params, + raft::device_matrix_view centroidsRawData) { + if (iter_params.init == cuvs::cluster::kmeans::params::InitMethod::Array) { + raft::copy( + handle, + raft::make_device_vector_view(centroidsRawData.data_handle(), n_clusters * n_features), + raft::make_device_vector_view(centroids.data_handle(), n_clusters * n_features)); + return; + } + + if constexpr (data_on_device) { + auto X_dev = + raft::make_device_matrix_view(X.data_handle(), n_samples, n_features); + + if (iter_params.init == cuvs::cluster::kmeans::params::InitMethod::Random) { + initRandom(handle, iter_params, X_dev, centroidsRawData); + } else if (iter_params.init == cuvs::cluster::kmeans::params::InitMethod::KMeansPlusPlus) { + if (iter_params.oversampling_factor == 0) + kmeansPlusPlus(handle, iter_params, X_dev, centroidsRawData, workspace); + else + initScalableKMeansPlusPlus( + handle, iter_params, X_dev, centroidsRawData, workspace); + } else { + THROW("unknown initialization method to select initial centers"); + } + } else { + raft::random::RngState random_state(iter_params.rng_state.seed); + + if (iter_params.init == cuvs::cluster::kmeans::params::InitMethod::Random) { + raft::matrix::sample_rows(handle, random_state, X, centroidsRawData); + } else if (iter_params.init == cuvs::cluster::kmeans::params::InitMethod::KMeansPlusPlus) { + IndexT init_sample_size = + std::min(static_cast(3 * streaming_batch_size), n_samples); + if (init_sample_size < n_clusters) { + init_sample_size = std::min(static_cast(3 * n_clusters), n_samples); + } + + auto init_sample = + raft::make_device_matrix(handle, init_sample_size, n_features); + raft::matrix::sample_rows(handle, random_state, X, init_sample.view()); + + auto init_sample_const = raft::make_const_mdspan(init_sample.view()); + if (iter_params.oversampling_factor == 0) + kmeansPlusPlus( + handle, iter_params, init_sample_const, centroidsRawData, workspace); + else + initScalableKMeansPlusPlus( + handle, iter_params, init_sample_const, centroidsRawData, workspace); + } else { + THROW("unknown initialization method to select initial centers"); + } + } + }; auto n_init = pams.n_init; if (pams.init == cuvs::cluster::kmeans::params::InitMethod::Array && n_init != 1) { @@ -791,61 +692,184 @@ void kmeans_fit(raft::resources const& handle, n_init = 1; } + auto centroidsRawData = raft::make_device_matrix(handle, n_clusters, n_features); + + auto minClusterAndDistance = raft::make_device_vector, IndexT>( + handle, streaming_batch_size); + auto L2NormBatch = raft::make_device_vector(handle, streaming_batch_size); + auto batch_weights_buf = raft::make_device_vector(handle, streaming_batch_size); + rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); + + auto centroid_sums = raft::make_device_matrix(handle, n_clusters, n_features); + auto weight_per_cluster = raft::make_device_vector(handle, n_clusters); + auto new_centroids = raft::make_device_matrix(handle, n_clusters, n_features); + auto clustering_cost = raft::make_device_scalar(handle, DataT{0}); + auto batch_sums = raft::make_device_matrix(handle, n_clusters, n_features); + auto batch_counts = raft::make_device_vector(handle, n_clusters); + + cuvs::spatial::knn::detail::utils::batch_load_iterator data_batches( + X.data_handle(), n_samples, n_features, streaming_batch_size, stream); + cuvs::spatial::knn::detail::utils::batch_load_iterator weight_batches( + weight_ptr, n_samples, 1, streaming_batch_size, stream); + + if (weight_ptr == nullptr) { raft::matrix::fill(handle, batch_weights_buf.view(), DataT{1}); } + + auto prepare_batch_weights = [&](const auto& wt_batch, IndexT cur_batch_size) { + if (weight_ptr != nullptr) { + raft::copy(batch_weights_buf.data_handle(), wt_batch.data(), cur_batch_size, stream); + if (weight_scale != DataT{1}) { + auto bw = raft::make_device_vector_view(batch_weights_buf.data_handle(), + cur_batch_size); + raft::linalg::map( + handle, bw, raft::mul_const_op{weight_scale}, raft::make_const_mdspan(bw)); + } + } + return raft::make_device_vector_view(batch_weights_buf.data_handle(), + cur_batch_size); + }; + + RAFT_LOG_DEBUG( + "KMeans.fit: n_samples=%zu, n_features=%zu, n_clusters=%d, streaming_batch_size=%zu", + static_cast(n_samples), + static_cast(n_features), + n_clusters, + static_cast(streaming_batch_size)); + std::mt19937 gen(pams.rng_state.seed); inertia[0] = std::numeric_limits::max(); - for (auto seed_iter = 0; seed_iter < n_init; ++seed_iter) { + for (int seed_iter = 0; seed_iter < n_init; ++seed_iter) { cuvs::cluster::kmeans::params iter_params = pams; iter_params.rng_state.seed = gen(); - DataT iter_inertia = std::numeric_limits::max(); - IndexT n_current_iter = 0; - if (iter_params.init == cuvs::cluster::kmeans::params::InitMethod::Random) { - // initializing with random samples from input dataset - RAFT_LOG_DEBUG( - "KMeans.fit (Iteration-%d/%d): initialize cluster centers by " - "randomly choosing from the " - "input data.", - seed_iter + 1, - n_init); - initRandom(handle, iter_params, X, centroidsRawData.view()); - } else if (iter_params.init == cuvs::cluster::kmeans::params::InitMethod::KMeansPlusPlus) { - // default method to initialize is kmeans++ - RAFT_LOG_DEBUG( - "KMeans.fit (Iteration-%d/%d): initialize cluster centers using " - "k-means++ algorithm.", - seed_iter + 1, - n_init); - if (iter_params.oversampling_factor == 0) - cuvs::cluster::kmeans::detail::kmeansPlusPlus( - handle, iter_params, X, centroidsRawData.view(), workspace); - else - cuvs::cluster::kmeans::detail::initScalableKMeansPlusPlus( - handle, iter_params, X, centroidsRawData.view(), workspace); - } else if (iter_params.init == cuvs::cluster::kmeans::params::InitMethod::Array) { - RAFT_LOG_DEBUG( - "KMeans.fit (Iteration-%d/%d): initialize cluster centers from " - "the ndarray array input " - "passed to init argument.", - seed_iter + 1, - n_init); + RAFT_LOG_DEBUG("KMeans.fit: n_init iteration %d/%d (seed=%llu)", + seed_iter + 1, + n_init, + (unsigned long long)iter_params.rng_state.seed); + + init_centroids(iter_params, centroidsRawData.view()); + + DataT iter_inertia = std::numeric_limits::max(); + IndexT n_current_iter = 0; + DataT priorClusteringCost = 0; + + for (n_current_iter = 1; n_current_iter <= iter_params.max_iter; ++n_current_iter) { + RAFT_LOG_DEBUG("KMeans.fit: Iteration-%d", n_current_iter); + + raft::matrix::fill(handle, centroid_sums.view(), DataT{0}); + raft::matrix::fill(handle, weight_per_cluster.view(), DataT{0}); + raft::linalg::map(handle, + raft::make_device_scalar_view(clustering_cost.data_handle()), + raft::const_op{DataT{0}}); + + auto centroids_const = raft::make_device_matrix_view( + centroidsRawData.data_handle(), n_clusters, n_features); + + data_batches.reset(); + weight_batches.reset(); + auto wt_it = weight_batches.begin(); + for (const auto& data_batch : data_batches) { + IndexT cur_batch_size = static_cast(data_batch.size()); + const auto& wt_batch = *wt_it; + ++wt_it; + + auto batch_data_view = raft::make_device_matrix_view( + data_batch.data(), cur_batch_size, n_features); + auto batch_weights_view = prepare_batch_weights(wt_batch, cur_batch_size); + + auto minCAD_view = raft::make_device_vector_view, IndexT>( + minClusterAndDistance.data_handle(), cur_batch_size); + auto l2_view = + raft::make_device_vector_view(L2NormBatch.data_handle(), cur_batch_size); + + process_batch(handle, + batch_data_view, + batch_weights_view, + centroids_const, + metric, + iter_params.batch_samples, + iter_params.batch_centroids, + minCAD_view, + l2_view, + L2NormBuf_OR_DistBuf, + workspace, + centroid_sums.view(), + weight_per_cluster.view(), + batch_sums.view(), + batch_counts.view(), + clustering_cost.view()); + } + + finalize_centroids(handle, + raft::make_const_mdspan(centroid_sums.view()), + raft::make_const_mdspan(weight_per_cluster.view()), + centroids_const, + new_centroids.view()); + + DataT sqrdNormError = + compute_centroid_shift(handle, + raft::make_const_mdspan(centroids_const), + raft::make_const_mdspan(new_centroids.view())); + raft::copy( handle, - raft::make_device_vector_view(centroidsRawData.data_handle(), n_clusters * n_features), - raft::make_device_vector_view(centroids.data_handle(), n_clusters * n_features)); - } else { - THROW("unknown initialization method to select initial centers"); + raft::make_device_vector_view(centroidsRawData.data_handle(), new_centroids.size()), + raft::make_device_vector_view(new_centroids.data_handle(), new_centroids.size())); + + bool done = false; + + DataT curClusteringCost = DataT{0}; + raft::copy(&curClusteringCost, clustering_cost.data_handle(), 1, stream); + raft::resource::sync_stream(handle, stream); + + ASSERT(curClusteringCost != DataT{0}, + "Too few points and centroids being found is getting 0 cost from centers"); + + if (n_current_iter > 1) { + DataT delta = curClusteringCost / priorClusteringCost; + if (delta > 1 - iter_params.tol) done = true; + } + priorClusteringCost = curClusteringCost; + + if (sqrdNormError < iter_params.tol) done = true; + + if (done) { + RAFT_LOG_DEBUG("Threshold triggered after %d iterations. Terminating early.", + n_current_iter); + break; + } + } + + { + auto centroids_const = raft::make_device_matrix_view( + centroidsRawData.data_handle(), n_clusters, n_features); + + iter_inertia = DataT{0}; + data_batches.reset(); + weight_batches.reset(); + auto wt_it = weight_batches.begin(); + for (const auto& data_batch : data_batches) { + IndexT cur_batch_size = static_cast(data_batch.size()); + const auto& wt_batch = *wt_it; + ++wt_it; + + auto batch_data_view = raft::make_device_matrix_view( + data_batch.data(), cur_batch_size, n_features); + + std::optional> batch_sw = std::nullopt; + if (weight_ptr != nullptr) { batch_sw = prepare_batch_weights(wt_batch, cur_batch_size); } + + DataT batch_cost = DataT{0}; + cuvs::cluster::kmeans::cluster_cost(handle, + batch_data_view, + centroids_const, + raft::make_host_scalar_view(&batch_cost), + batch_sw); + + iter_inertia += batch_cost; + } } - cuvs::cluster::kmeans::detail::kmeans_fit_main( - handle, - iter_params, - X, - weight.view(), - centroidsRawData.view(), - raft::make_host_scalar_view(&iter_inertia), - raft::make_host_scalar_view(&n_current_iter), - workspace); if (iter_inertia < inertia[0]) { inertia[0] = iter_inertia; n_iter[0] = n_current_iter; @@ -854,13 +878,13 @@ void kmeans_fit(raft::resources const& handle, raft::make_device_vector_view(centroids.data_handle(), n_clusters * n_features), raft::make_device_vector_view(centroidsRawData.data_handle(), n_clusters * n_features)); } - RAFT_LOG_DEBUG("KMeans.fit after iteration-%d/%d: inertia - %f, n_iter[0] - %d", + RAFT_LOG_DEBUG("KMeans.fit after iteration-%d/%d: inertia - %f, n_iter - %d", seed_iter + 1, n_init, inertia[0], n_iter[0]); } - RAFT_LOG_DEBUG("KMeans.fit: async call returned (fit could still be running on the device)"); + RAFT_LOG_DEBUG("KMeans.fit: completed."); } template @@ -877,15 +901,26 @@ void kmeans_fit(raft::resources const& handle, auto XView = raft::make_device_matrix_view(X, n_samples, n_features); auto centroidsView = raft::make_device_matrix_view(centroids, pams.n_clusters, n_features); - std::optional> sample_weightView = std::nullopt; + std::optional> sample_weightView = std::nullopt; if (sample_weight) sample_weightView = raft::make_device_vector_view(sample_weight, n_samples); auto inertiaView = raft::make_host_scalar_view(&inertia); auto n_iterView = raft::make_host_scalar_view(&n_iter); - cuvs::cluster::kmeans::detail::kmeans_fit( - handle, pams, XView, sample_weightView, centroidsView, inertiaView, n_iterView); + kmeans_fit(handle, pams, XView, sample_weightView, centroidsView, inertiaView, n_iterView); +} + +template +void fit(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) +{ + kmeans_fit(handle, params, X, sample_weight, centroids, inertia, n_iter); } template diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh deleted file mode 100644 index e2fc8d334f..0000000000 --- a/cpp/src/cluster/detail/kmeans_batched.cuh +++ /dev/null @@ -1,510 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. - * SPDX-License-Identifier: Apache-2.0 - */ -#pragma once - -#include "kmeans.cuh" -#include "kmeans_common.cuh" - -#include "../../neighbors/detail/ann_utils.cuh" -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -#include - -#include -#include -#include -#include -#include -#include - -namespace cuvs::cluster::kmeans::detail { - -/** - * @brief Initialize centroids from host data - * - * @tparam T Input data type - * @tparam IdxT Index type - */ -template -void init_centroids_from_host_sample(raft::resources const& handle, - const cuvs::cluster::kmeans::params& params, - IdxT streaming_batch_size, - raft::host_matrix_view X, - raft::device_matrix_view centroids, - rmm::device_uvector& workspace) -{ - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - auto n_samples = X.extent(0); - auto n_features = X.extent(1); - auto n_clusters = params.n_clusters; - - if (params.init == cuvs::cluster::kmeans::params::InitMethod::KMeansPlusPlus) { - // this is a heuristic to speed up the initialization - IdxT init_sample_size = 3 * streaming_batch_size; - if (init_sample_size < n_clusters) { init_sample_size = 3 * n_clusters; } - init_sample_size = std::min(init_sample_size, n_samples); - - auto init_sample = raft::make_device_matrix(handle, init_sample_size, n_features); - raft::random::RngState random_state(params.rng_state.seed); - raft::matrix::sample_rows(handle, random_state, X, init_sample.view()); - - if (params.oversampling_factor == 0) { - cuvs::cluster::kmeans::detail::kmeansPlusPlus( - handle, params, raft::make_const_mdspan(init_sample.view()), centroids, workspace); - } else { - cuvs::cluster::kmeans::detail::initScalableKMeansPlusPlus( - handle, params, raft::make_const_mdspan(init_sample.view()), centroids, workspace); - } - } else if (params.init == cuvs::cluster::kmeans::params::InitMethod::Random) { - raft::random::RngState random_state(params.rng_state.seed); - raft::matrix::sample_rows(handle, random_state, X, centroids); - } else if (params.init == cuvs::cluster::kmeans::params::InitMethod::Array) { - // already provided - } else { - RAFT_FAIL("Unknown initialization method"); - } -} - -/** - * @brief Compute the weight normalization scale factor for host sample weights. Weights are - * normalized to sum to n_samples. Returns the scale factor to apply to each weight. - * - * @param[in] sample_weight Optional host vector of sample weights - * @param[in] n_samples Number of samples - * @return Scale factor (1.0 if no weights or already normalized) - */ -template -T compute_host_weight_scale( - const std::optional>& sample_weight, IdxT n_samples) -{ - if (!sample_weight.has_value()) { return T{1}; } - - T wt_sum = T{0}; - const T* sw_ptr = sample_weight->data_handle(); - for (IdxT i = 0; i < n_samples; ++i) { - wt_sum += sw_ptr[i]; - } - if (wt_sum == static_cast(n_samples)) { return T{1}; } - - RAFT_LOG_DEBUG( - "[Warning!] KMeans: normalizing the user provided sample weight to " - "sum up to %zu samples (scale=%f)", - static_cast(n_samples), - static_cast(static_cast(n_samples) / wt_sum)); - return static_cast(n_samples) / wt_sum; -} - -/** - * @brief Copy host sample weights to device and apply normalization scale. - * - * When sample_weight is provided, copies the batch slice to the device buffer - * and applies the normalization scale factor. When not provided, the device - * buffer is assumed to already be filled with 1.0. - * - * @param[in] handle RAFT resources handle - * @param[in] sample_weight Optional host weights - * @param[in] batch_offset Offset into the host weights for this batch - * @param[in] batch_size Number of elements in this batch - * @param[in] weight_scale Scale factor from compute_host_weight_scale - * @param[inout] batch_weights Device buffer to write normalized weights into - */ -template -void copy_and_scale_batch_weights( - raft::resources const& handle, - const std::optional>& sample_weight, - size_t batch_offset, - IdxT batch_size, - T weight_scale, - raft::device_vector& batch_weights) -{ - if (!sample_weight.has_value()) { return; } - - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - raft::copy( - batch_weights.data_handle(), sample_weight->data_handle() + batch_offset, batch_size, stream); - - if (weight_scale != T{1}) { - auto batch_weights_view = - raft::make_device_vector_view(batch_weights.data_handle(), batch_size); - raft::linalg::map(handle, - batch_weights_view, - raft::mul_const_op{weight_scale}, - raft::make_const_mdspan(batch_weights_view)); - } -} - -/** - * @brief Accumulate partial centroid sums and counts from a batch - * - * This function adds the partial sums from a batch to the running accumulators. - * It does NOT divide - that happens once at the end of all batches. - */ -template -void accumulate_batch_centroids( - raft::resources const& handle, - raft::device_matrix_view batch_data, - raft::device_vector_view, IdxT> minClusterAndDistance, - raft::device_vector_view sample_weights, - raft::device_matrix_view centroid_sums, - raft::device_vector_view cluster_counts, - raft::device_matrix_view batch_sums, - raft::device_vector_view batch_counts) -{ - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - - auto workspace = rmm::device_uvector( - batch_data.extent(0), stream, raft::resource::get_workspace_resource(handle)); - - cuvs::cluster::kmeans::detail::KeyValueIndexOp conversion_op; - thrust::transform_iterator, - const raft::KeyValuePair*> - labels_itr(minClusterAndDistance.data_handle(), conversion_op); - - cuvs::cluster::kmeans::detail::compute_centroid_adjustments( - handle, - batch_data, - sample_weights, - labels_itr, - static_cast(centroid_sums.extent(0)), - batch_sums, - batch_counts, - workspace); - - raft::linalg::add(centroid_sums.data_handle(), - centroid_sums.data_handle(), - batch_sums.data_handle(), - centroid_sums.size(), - stream); - - raft::linalg::add(cluster_counts.data_handle(), - cluster_counts.data_handle(), - batch_counts.data_handle(), - cluster_counts.size(), - stream); -} - -/** - * @brief Main fit function for batched k-means with host data (full-batch / Lloyd's algorithm). - * - * Processes host data in GPU-sized batches per iteration, accumulating partial centroid - * sums and counts, then finalizes centroids at the end of each iteration. - * - * @tparam T Input data type (float, double) - * @tparam IdxT Index type (int, int64_t) - * - * @param[in] handle RAFT resources handle - * @param[in] params K-means parameters - * @param[in] X Input data on HOST [n_samples x n_features] - * @param[in] sample_weight Optional weights per sample (on host) - * @param[inout] centroids Initial/output cluster centers [n_clusters x n_features] - * @param[out] inertia Sum of squared distances to nearest centroid - * @param[out] n_iter Number of iterations run - */ -template -void fit(raft::resources const& handle, - const cuvs::cluster::kmeans::params& params, - raft::host_matrix_view X, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) -{ - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - auto n_samples = X.extent(0); - auto n_features = X.extent(1); - auto n_clusters = params.n_clusters; - auto metric = params.metric; - - IdxT streaming_batch_size = static_cast(params.streaming_batch_size); - - if (params.streaming_batch_size == 0) { - streaming_batch_size = static_cast(n_samples); - } else if (params.streaming_batch_size < 0 || params.streaming_batch_size > n_samples) { - RAFT_LOG_WARN("streaming_batch_size must be >= 0 and <= n_samples, using n_samples=%zu", - static_cast(n_samples)); - streaming_batch_size = static_cast(n_samples); - } - - RAFT_EXPECTS(n_clusters > 0, "n_clusters must be positive"); - RAFT_EXPECTS(static_cast(centroids.extent(0)) == n_clusters, - "centroids.extent(0) must equal n_clusters"); - RAFT_EXPECTS(centroids.extent(1) == n_features, "centroids.extent(1) must equal n_features"); - - RAFT_LOG_DEBUG( - "KMeans batched fit: n_samples=%zu, n_features=%zu, n_clusters=%d, streaming_batch_size=%zu", - static_cast(n_samples), - static_cast(n_features), - n_clusters, - static_cast(streaming_batch_size)); - - rmm::device_uvector workspace(0, stream); - - auto n_init = params.n_init; - if (params.init == cuvs::cluster::kmeans::params::InitMethod::Array && n_init != 1) { - RAFT_LOG_DEBUG( - "Explicit initial center position passed: performing only one init in " - "k-means instead of n_init=%d", - n_init); - n_init = 1; - } - - auto best_centroids = n_init > 1 - ? raft::make_device_matrix(handle, n_clusters, n_features) - : raft::make_device_matrix(handle, 0, 0); - T best_inertia = std::numeric_limits::max(); - IdxT best_n_iter = 0; - - std::mt19937 gen(params.rng_state.seed); - - // ----- Allocate reusable work buffers (shared across n_init iterations) ----- - auto batch_weights = raft::make_device_vector(handle, streaming_batch_size); - auto minClusterAndDistance = - raft::make_device_vector, IdxT>(handle, streaming_batch_size); - auto L2NormBatch = raft::make_device_vector(handle, streaming_batch_size); - rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); - - auto centroid_sums = raft::make_device_matrix(handle, n_clusters, n_features); - auto weight_per_cluster = raft::make_device_vector(handle, n_clusters); - auto new_centroids = raft::make_device_matrix(handle, n_clusters, n_features); - auto clustering_cost = raft::make_device_vector(handle, 1); - auto batch_clustering_cost = raft::make_device_vector(handle, 1); - auto batch_sums = raft::make_device_matrix(handle, n_clusters, n_features); - auto batch_counts = raft::make_device_vector(handle, n_clusters); - - // Compute weight normalization (matches checkWeight in regular kmeans) - T weight_scale = compute_host_weight_scale(sample_weight, n_samples); - - // ---- Main n_init loop ---- - for (int seed_iter = 0; seed_iter < n_init; ++seed_iter) { - cuvs::cluster::kmeans::params iter_params = params; - iter_params.rng_state.seed = gen(); - - RAFT_LOG_DEBUG("KMeans batched fit: n_init iteration %d/%d (seed=%llu)", - seed_iter + 1, - n_init, - (unsigned long long)iter_params.rng_state.seed); - - if (iter_params.init != cuvs::cluster::kmeans::params::InitMethod::Array) { - init_centroids_from_host_sample( - handle, iter_params, streaming_batch_size, X, centroids, workspace); - } - - if (!sample_weight.has_value()) { raft::matrix::fill(handle, batch_weights.view(), T{1}); } - - // Reset per-iteration state - T prior_cluster_cost = 0; - - cuvs::spatial::knn::detail::utils::batch_load_iterator data_batches( - X.data_handle(), n_samples, n_features, streaming_batch_size, stream); - - for (n_iter[0] = 1; n_iter[0] <= iter_params.max_iter; ++n_iter[0]) { - RAFT_LOG_DEBUG("KMeans batched: Iteration %d", n_iter[0]); - - raft::matrix::fill(handle, centroid_sums.view(), T{0}); - raft::matrix::fill(handle, weight_per_cluster.view(), T{0}); - - raft::matrix::fill(handle, clustering_cost.view(), T{0}); - - auto centroids_const = raft::make_const_mdspan(centroids); - - for (const auto& data_batch : data_batches) { - IdxT current_batch_size = static_cast(data_batch.size()); - raft::matrix::fill(handle, batch_clustering_cost.view(), T{0}); - - auto batch_data_view = raft::make_device_matrix_view( - data_batch.data(), current_batch_size, n_features); - - copy_and_scale_batch_weights(handle, - sample_weight, - data_batch.offset(), - current_batch_size, - weight_scale, - batch_weights); - - auto batch_weights_view = raft::make_device_vector_view( - batch_weights.data_handle(), current_batch_size); - - auto L2NormBatch_view = - raft::make_device_vector_view(L2NormBatch.data_handle(), current_batch_size); - - if (metric == cuvs::distance::DistanceType::L2Expanded || - metric == cuvs::distance::DistanceType::L2SqrtExpanded) { - raft::linalg::norm( - handle, - raft::make_device_matrix_view( - data_batch.data(), current_batch_size, n_features), - L2NormBatch_view); - } - - auto L2NormBatch_const = raft::make_const_mdspan(L2NormBatch_view); - - auto minClusterAndDistance_view = - raft::make_device_vector_view, IdxT>( - minClusterAndDistance.data_handle(), current_batch_size); - - cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( - handle, - batch_data_view, - centroids_const, - minClusterAndDistance_view, - L2NormBatch_const, - L2NormBuf_OR_DistBuf, - metric, - params.batch_samples, - params.batch_centroids, - workspace); - - auto minClusterAndDistance_const = raft::make_const_mdspan(minClusterAndDistance_view); - - accumulate_batch_centroids(handle, - batch_data_view, - minClusterAndDistance_const, - batch_weights_view, - centroid_sums.view(), - weight_per_cluster.view(), - batch_sums.view(), - batch_counts.view()); - - if (params.inertia_check) { - raft::linalg::map( - handle, - minClusterAndDistance_view, - [=] __device__(const raft::KeyValuePair kvp, T wt) { - raft::KeyValuePair res; - res.value = kvp.value * wt; - res.key = kvp.key; - return res; - }, - raft::make_const_mdspan(minClusterAndDistance_view), - batch_weights_view); - - cuvs::cluster::kmeans::detail::computeClusterCost( - handle, - minClusterAndDistance_view, - workspace, - raft::make_device_scalar_view(batch_clustering_cost.data_handle()), - raft::value_op{}, - raft::add_op{}); - raft::linalg::add(handle, - raft::make_const_mdspan(clustering_cost.view()), - raft::make_const_mdspan(batch_clustering_cost.view()), - clustering_cost.view()); - } - } - - auto centroid_sums_const = raft::make_device_matrix_view( - centroid_sums.data_handle(), n_clusters, n_features); - auto weight_per_cluster_const = - raft::make_device_vector_view(weight_per_cluster.data_handle(), n_clusters); - - finalize_centroids(handle, - centroid_sums_const, - weight_per_cluster_const, - centroids_const, - new_centroids.view()); - - T sqrdNormError = compute_centroid_shift( - handle, raft::make_const_mdspan(centroids), raft::make_const_mdspan(new_centroids.view())); - - raft::copy(handle, centroids, new_centroids.view()); - - bool done = false; - if (params.inertia_check) { - raft::copy(inertia.data_handle(), clustering_cost.data_handle(), 1, stream); - raft::resource::sync_stream(handle); - ASSERT(inertia[0] != (T)0.0, - "Too few points and centroids being found is getting 0 cost from " - "centers"); - if (n_iter[0] > 1) { - T delta = inertia[0] / prior_cluster_cost; - if (delta > 1 - params.tol) done = true; - } - prior_cluster_cost = inertia[0]; - } - - if (sqrdNormError < params.tol) done = true; - - if (done) { - RAFT_LOG_DEBUG("Threshold triggered after %d iterations. Terminating early.", n_iter[0]); - break; - } - } - - // Recompute final weighted inertia with the converged centroids. - { - auto centroids_const_view = raft::make_device_matrix_view( - centroids.data_handle(), n_clusters, n_features); - - inertia[0] = T{0}; - for (const auto& data_batch : data_batches) { - IdxT current_batch_size = static_cast(data_batch.size()); - - auto batch_data_view = raft::make_device_matrix_view( - data_batch.data(), current_batch_size, n_features); - - std::optional> batch_sw = std::nullopt; - if (sample_weight.has_value()) { - copy_and_scale_batch_weights(handle, - sample_weight, - data_batch.offset(), - current_batch_size, - weight_scale, - batch_weights); - batch_sw = raft::make_device_vector_view(batch_weights.data_handle(), - current_batch_size); - } - - T batch_cost = T{0}; - cuvs::cluster::kmeans::cluster_cost(handle, - batch_data_view, - centroids_const_view, - raft::make_host_scalar_view(&batch_cost), - batch_sw); - - inertia[0] += batch_cost; - } - } - - RAFT_LOG_DEBUG("KMeans batched: n_init %d/%d completed with inertia=%f", - seed_iter + 1, - n_init, - static_cast(inertia[0])); - - if (n_init > 1 && inertia[0] < best_inertia) { - best_inertia = inertia[0]; - best_n_iter = n_iter[0]; - raft::copy(best_centroids.data_handle(), centroids.data_handle(), centroids.size(), stream); - } - } - if (n_init > 1) { - inertia[0] = best_inertia; - n_iter[0] = best_n_iter; - raft::copy(handle, centroids, best_centroids.view()); - RAFT_LOG_DEBUG("KMeans batched: Best of %d runs: inertia=%f, n_iter=%d", - n_init, - static_cast(best_inertia), - best_n_iter); - } -} - -} // namespace cuvs::cluster::kmeans::detail diff --git a/cpp/src/cluster/detail/kmeans_common.cuh b/cpp/src/cluster/detail/kmeans_common.cuh index 250563dd12..c5ca78941e 100644 --- a/cpp/src/cluster/detail/kmeans_common.cuh +++ b/cpp/src/cluster/detail/kmeans_common.cuh @@ -43,6 +43,9 @@ #include #include #include +#include + +#include #include #include @@ -598,4 +601,179 @@ DataT compute_centroid_shift(raft::resources const& handle, return result; } +/** + * @brief Compute the weight normalization scale factor for sample weights that may + * reside on host memory. Weights are normalized to sum to n_samples. + * + * Works on any contiguous pointer (host or device) by copying to host for the sum. + * + * @tparam DataT Weight type + * @tparam IndexT Index type + * + * @param[in] weight_ptr Pointer to sample weights (host or device), may be nullptr + * @param[in] n_samples Number of samples + * @param[in] stream CUDA stream (used when pointer is device memory) + * @return Scale factor (1.0 if weights already sum to n_samples or nullptr) + */ +template +DataT compute_weight_scale(const DataT* weight_ptr, IndexT n_samples, cudaStream_t stream) +{ + if (weight_ptr == nullptr) { return DataT{1}; } + + bool is_host = true; + cudaPointerAttributes attr; + auto err = cudaPointerGetAttributes(&attr, weight_ptr); + if (err == cudaSuccess && attr.type == cudaMemoryTypeDevice) { is_host = false; } + cudaGetLastError(); // clear any error + + DataT wt_sum = DataT{0}; + if (is_host) { + for (IndexT i = 0; i < n_samples; ++i) { + wt_sum += weight_ptr[i]; + } + } else { + std::vector h_weights(n_samples); + raft::copy(h_weights.data(), weight_ptr, n_samples, stream); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + for (IndexT i = 0; i < n_samples; ++i) { + wt_sum += h_weights[i]; + } + } + + if (wt_sum == static_cast(n_samples)) { return DataT{1}; } + + RAFT_LOG_DEBUG( + "[Warning!] KMeans: normalizing the user provided sample weight to " + "sum up to %zu samples", + static_cast(n_samples)); + return static_cast(n_samples) / wt_sum; +} + +/** + * @brief Process a single batch of data in the Lloyd iteration. + * + * This is the shared per-batch helper used by both the device (single-batch) and + * host-streaming (multi-batch) k-means paths. It operates entirely on device + * buffers: given one batch of data + weights + current centroids it + * 1. computes L2 norms (if needed), + * 2. finds the nearest centroid for every sample, + * 3. accumulates weighted centroid sums and counts into the running accumulators, + * 4. accumulates the weighted clustering cost (inertia). + * + * @tparam DataT Data / weight type (float, double) + * @tparam IndexT Index type (int, int64_t) + * + * @param[in] handle RAFT resources handle + * @param[in] batch_data Device batch data [batch_size x n_features] + * @param[in] batch_weights Device batch weights [batch_size] + * @param[in] centroids Current centroids [n_clusters x n_features] + * @param[in] metric Distance metric + * @param[in] batch_samples_param Batch-samples param forwarded to minClusterAndDistanceCompute + * @param[in] batch_centroids_param Batch-centroids param forwarded to + * minClusterAndDistanceCompute + * @param[inout] minClusterAndDistance Work buffer [batch_size] + * @param[inout] L2NormBatch Work buffer for L2 norms [batch_size] + * @param[inout] L2NormBuf_OR_DistBuf Resizable scratch + * @param[inout] workspace Resizable scratch + * @param[inout] centroid_sums Running weighted sums [n_clusters x n_features] (added into) + * @param[inout] weight_per_cluster Running weight counts [n_clusters] (added into) + * @param[inout] batch_sums Scratch for this batch [n_clusters x n_features] + * @param[inout] batch_counts Scratch for this batch [n_clusters] + * @param[inout] clustering_cost Running cost scalar (device) (added into) + */ +template +void process_batch( + raft::resources const& handle, + raft::device_matrix_view batch_data, + raft::device_vector_view batch_weights, + raft::device_matrix_view centroids, + cuvs::distance::DistanceType metric, + int batch_samples_param, + int batch_centroids_param, + raft::device_vector_view, IndexT> minClusterAndDistance, + raft::device_vector_view L2NormBatch, + rmm::device_uvector& L2NormBuf_OR_DistBuf, + rmm::device_uvector& workspace, + raft::device_matrix_view centroid_sums, + raft::device_vector_view weight_per_cluster, + raft::device_matrix_view batch_sums, + raft::device_vector_view batch_counts, + raft::device_scalar_view clustering_cost) +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + IndexT current_batch_sz = batch_data.extent(0); + + if (metric == cuvs::distance::DistanceType::L2Expanded || + metric == cuvs::distance::DistanceType::L2SqrtExpanded) { + raft::linalg::norm( + handle, + raft::make_device_matrix_view( + batch_data.data_handle(), current_batch_sz, batch_data.extent(1)), + L2NormBatch); + } + + auto L2NormBatch_const = raft::make_const_mdspan(L2NormBatch); + + minClusterAndDistanceCompute(handle, + batch_data, + centroids, + minClusterAndDistance, + L2NormBatch_const, + L2NormBuf_OR_DistBuf, + metric, + batch_samples_param, + batch_centroids_param, + workspace); + + KeyValueIndexOp conversion_op; + thrust::transform_iterator, + const raft::KeyValuePair*> + labels_itr(minClusterAndDistance.data_handle(), conversion_op); + + auto batch_workspace = rmm::device_uvector( + current_batch_sz, stream, raft::resource::get_workspace_resource(handle)); + + compute_centroid_adjustments(handle, + batch_data, + batch_weights, + labels_itr, + static_cast(centroid_sums.extent(0)), + batch_sums, + batch_counts, + batch_workspace); + + raft::linalg::add(centroid_sums.data_handle(), + centroid_sums.data_handle(), + batch_sums.data_handle(), + centroid_sums.size(), + stream); + + raft::linalg::add(weight_per_cluster.data_handle(), + weight_per_cluster.data_handle(), + batch_counts.data_handle(), + weight_per_cluster.size(), + stream); + + raft::linalg::map( + handle, + minClusterAndDistance, + [=] __device__(const raft::KeyValuePair kvp, DataT wt) { + raft::KeyValuePair res; + res.value = kvp.value * wt; + res.key = kvp.key; + return res; + }, + raft::make_const_mdspan(minClusterAndDistance), + batch_weights); + + auto batch_cost = raft::make_device_scalar(handle, DataT{0}); + computeClusterCost( + handle, minClusterAndDistance, workspace, batch_cost.view(), raft::value_op{}, raft::add_op{}); + raft::linalg::add(clustering_cost.data_handle(), + clustering_cost.data_handle(), + batch_cost.data_handle(), + 1, + stream); +} + } // namespace cuvs::cluster::kmeans::detail diff --git a/cpp/src/cluster/kmeans_fit_double.cu b/cpp/src/cluster/kmeans_fit_double.cu index d7e4748e33..51cd21cb51 100644 --- a/cpp/src/cluster/kmeans_fit_double.cu +++ b/cpp/src/cluster/kmeans_fit_double.cu @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ -#include "detail/kmeans_batched.cuh" #include "kmeans.cuh" #include "kmeans_impl.cuh" #include @@ -72,7 +71,7 @@ void fit(raft::resources const& handle, raft::host_scalar_view inertia, raft::host_scalar_view n_iter) { - cuvs::cluster::kmeans::detail::fit( + cuvs::cluster::kmeans::fit( handle, params, X, sample_weight, centroids, inertia, n_iter); } diff --git a/cpp/src/cluster/kmeans_fit_float.cu b/cpp/src/cluster/kmeans_fit_float.cu index f86fabcfbd..000774b9c6 100644 --- a/cpp/src/cluster/kmeans_fit_float.cu +++ b/cpp/src/cluster/kmeans_fit_float.cu @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ -#include "detail/kmeans_batched.cuh" #include "kmeans.cuh" #include "kmeans_impl.cuh" #include @@ -72,7 +71,7 @@ void fit(raft::resources const& handle, raft::host_scalar_view inertia, raft::host_scalar_view n_iter) { - cuvs::cluster::kmeans::detail::fit( + cuvs::cluster::kmeans::fit( handle, params, X, sample_weight, centroids, inertia, n_iter); } diff --git a/cpp/src/cluster/kmeans_impl.cuh b/cpp/src/cluster/kmeans_impl.cuh index 437aa16c76..f521edd07f 100644 --- a/cpp/src/cluster/kmeans_impl.cuh +++ b/cpp/src/cluster/kmeans_impl.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once @@ -18,8 +18,12 @@ void fit_main(raft::resources const& handle, raft::host_scalar_view n_iter, rmm::device_uvector& workspace) { - cuvs::cluster::kmeans::detail::kmeans_fit_main( - handle, params, X, sample_weights, centroids, inertia, n_iter, workspace); + cuvs::cluster::kmeans::params p = params; + p.init = kmeans::params::InitMethod::Array; + p.n_init = 1; + auto sw = std::make_optional( + raft::make_device_vector_view(sample_weights.data_handle(), X.extent(0))); + cuvs::cluster::kmeans::detail::kmeans_fit(handle, p, X, sw, centroids, inertia, n_iter); } template @@ -31,7 +35,6 @@ void fit(raft::resources const& handle, raft::host_scalar_view inertia, raft::host_scalar_view n_iter) { - // use the mnmg kmeans fit if we have comms initialize, single gpu otherwise if (raft::resource::comms_initialized(handle)) { cuvs::cluster::kmeans::mg::fit(handle, params, X, sample_weight, centroids, inertia, n_iter); } else { @@ -54,4 +57,17 @@ void predict(raft::resources const& handle, handle, params, X, sample_weight, centroids, labels, normalize_weight, inertia); } +template +void fit(raft::resources const& handle, + const kmeans::params& params, + raft::host_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) +{ + cuvs::cluster::kmeans::detail::fit( + handle, params, X, sample_weight, centroids, inertia, n_iter); +} + } // namespace cuvs::cluster::kmeans From 0a09e6f417a3c3d2ed4d3835d05a80c939f522e2 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Mon, 13 Apr 2026 11:50:48 -0700 Subject: [PATCH 02/50] rm inertia_check --- c/include/cuvs/cluster/kmeans.h | 3 --- c/src/cluster/kmeans.cpp | 2 -- cpp/include/cuvs/cluster/kmeans.hpp | 5 ----- cpp/src/cluster/detail/kmeans_mg.cuh | 2 +- cpp/tests/cluster/kmeans.cu | 6 ++---- python/cuvs/cuvs/cluster/kmeans/kmeans.pxd | 1 - python/cuvs/cuvs/cluster/kmeans/kmeans.pyx | 9 --------- 7 files changed, 3 insertions(+), 25 deletions(-) diff --git a/c/include/cuvs/cluster/kmeans.h b/c/include/cuvs/cluster/kmeans.h index 8f55edb925..ccf1144bb1 100644 --- a/c/include/cuvs/cluster/kmeans.h +++ b/c/include/cuvs/cluster/kmeans.h @@ -91,9 +91,6 @@ struct cuvsKMeansParams { */ int batch_centroids; - /** Check inertia during iterations for early convergence. */ - bool inertia_check; - /** * Whether to use hierarchical (balanced) kmeans or not */ diff --git a/c/src/cluster/kmeans.cpp b/c/src/cluster/kmeans.cpp index a84cd50259..2f9156ecab 100644 --- a/c/src/cluster/kmeans.cpp +++ b/c/src/cluster/kmeans.cpp @@ -28,7 +28,6 @@ cuvs::cluster::kmeans::params convert_params(const cuvsKMeansParams& params) kmeans_params.oversampling_factor = params.oversampling_factor; kmeans_params.batch_samples = params.batch_samples; kmeans_params.batch_centroids = params.batch_centroids; - kmeans_params.inertia_check = params.inertia_check; kmeans_params.streaming_batch_size = params.streaming_batch_size; return kmeans_params; } @@ -237,7 +236,6 @@ extern "C" cuvsError_t cuvsKMeansParamsCreate(cuvsKMeansParams_t* params) .oversampling_factor = cpp_params.oversampling_factor, .batch_samples = cpp_params.batch_samples, .batch_centroids = cpp_params.batch_centroids, - .inertia_check = cpp_params.inertia_check, .hierarchical = false, .hierarchical_n_iters = static_cast(cpp_balanced_params.n_iters), .streaming_batch_size = cpp_params.streaming_batch_size}; diff --git a/cpp/include/cuvs/cluster/kmeans.hpp b/cpp/include/cuvs/cluster/kmeans.hpp index d299d9f483..1122c19f22 100644 --- a/cpp/include/cuvs/cluster/kmeans.hpp +++ b/cpp/include/cuvs/cluster/kmeans.hpp @@ -112,11 +112,6 @@ struct params : base_params { */ int batch_centroids = 0; - /** - * If true, check inertia during iterations for early convergence. - */ - bool inertia_check = false; - /** * Number of samples to process per GPU batch when fitting with host data. * When set to 0, defaults to n_samples (process all at once). diff --git a/cpp/src/cluster/detail/kmeans_mg.cuh b/cpp/src/cluster/detail/kmeans_mg.cuh index 4c8d7f8b2a..4506aad36e 100644 --- a/cpp/src/cluster/detail/kmeans_mg.cuh +++ b/cpp/src/cluster/detail/kmeans_mg.cuh @@ -701,7 +701,7 @@ void fit(const raft::resources& handle, raft::make_device_vector_view(newCentroids.data_handle(), newCentroids.size())); bool done = false; - if (params.inertia_check) { + { rmm::device_scalar> clusterCostD(stream); // calculate cluster cost phi_x(C) diff --git a/cpp/tests/cluster/kmeans.cu b/cpp/tests/cluster/kmeans.cu index 1ef8d07623..ee6aac097b 100644 --- a/cpp/tests/cluster/kmeans.cu +++ b/cpp/tests/cluster/kmeans.cu @@ -433,9 +433,8 @@ class KmeansFitBatchedTest : public ::testing::TestWithParam(d_centroids_ref.data(), params.n_clusters, n_features); - params.init = cuvs::cluster::kmeans::params::Array; - params.inertia_check = true; - params.max_iter = 20; + params.init = cuvs::cluster::kmeans::params::Array; + params.max_iter = 20; T ref_inertia = 0; int ref_n_iter = 0; @@ -448,7 +447,6 @@ class KmeansFitBatchedTest : public ::testing::TestWithParam(&ref_n_iter)); cuvs::cluster::kmeans::params batched_params = params; - batched_params.inertia_check = true; batched_params.streaming_batch_size = testparams.streaming_batch_size; std::optional> h_sw = std::nullopt; diff --git a/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd b/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd index 6d0c878660..a99ac50464 100644 --- a/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd +++ b/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd @@ -32,7 +32,6 @@ cdef extern from "cuvs/cluster/kmeans.h" nogil: double oversampling_factor, int batch_samples, int batch_centroids, - bool inertia_check, int64_t streaming_batch_size, bool hierarchical, int hierarchical_n_iters diff --git a/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx b/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx index b267c908c9..656da2f978 100644 --- a/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx +++ b/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx @@ -76,8 +76,6 @@ cdef class KMeansParams: [batch_samples x n_clusters]. batch_centroids : int Number of centroids to process in each batch. If 0, uses n_clusters. - inertia_check : bool - If True, check inertia during iterations for early convergence. streaming_batch_size : int Number of samples to process per GPU batch when fitting with host (numpy) data. When set to 0, defaults to n_samples (process all @@ -111,7 +109,6 @@ cdef class KMeansParams: oversampling_factor=None, batch_samples=None, batch_centroids=None, - inertia_check=None, streaming_batch_size=None, hierarchical=None, hierarchical_n_iters=None): @@ -134,8 +131,6 @@ cdef class KMeansParams: self.params.batch_samples = batch_samples if batch_centroids is not None: self.params.batch_centroids = batch_centroids - if inertia_check is not None: - self.params.inertia_check = inertia_check if streaming_batch_size is not None: self.params.streaming_batch_size = streaming_batch_size if hierarchical is not None: @@ -182,10 +177,6 @@ cdef class KMeansParams: def batch_centroids(self): return self.params.batch_centroids - @property - def inertia_check(self): - return self.params.inertia_check - @property def streaming_batch_size(self): return self.params.streaming_batch_size From 99a5730fc2f69268d54e99dd1deffc9bcf5c0063 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Mon, 13 Apr 2026 12:27:38 -0700 Subject: [PATCH 03/50] change to warning --- cpp/src/cluster/detail/kmeans.cuh | 7 +++---- cpp/src/cluster/detail/kmeans_mg.cuh | 8 +++----- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans.cuh b/cpp/src/cluster/detail/kmeans.cuh index 1c019a07d8..45b9eef2aa 100644 --- a/cpp/src/cluster/detail/kmeans.cuh +++ b/cpp/src/cluster/detail/kmeans.cuh @@ -822,10 +822,9 @@ void kmeans_fit( raft::copy(&curClusteringCost, clustering_cost.data_handle(), 1, stream); raft::resource::sync_stream(handle, stream); - ASSERT(curClusteringCost != DataT{0}, - "Too few points and centroids being found is getting 0 cost from centers"); - - if (n_current_iter > 1) { + if (curClusteringCost == DataT{0}) { + RAFT_LOG_WARN("Zero clustering cost detected: all points coincide with their centroids."); + } else if (n_current_iter > 1) { DataT delta = curClusteringCost / priorClusteringCost; if (delta > 1 - iter_params.tol) done = true; } diff --git a/cpp/src/cluster/detail/kmeans_mg.cuh b/cpp/src/cluster/detail/kmeans_mg.cuh index 4506aad36e..47997b2535 100644 --- a/cpp/src/cluster/detail/kmeans_mg.cuh +++ b/cpp/src/cluster/detail/kmeans_mg.cuh @@ -734,11 +734,9 @@ void fit(const raft::resources& handle, ASSERT(comm.sync_stream(stream) == raft::comms::status_t::SUCCESS, "An error occurred in the distributed operation. This can result " "from a failed rank"); - ASSERT(curClusteringCost != (DataT)0.0, - "Too few points and centroids being found is getting 0 cost from " - "centers\n"); - - if (n_iter[0] > 0) { + if (curClusteringCost == (DataT)0.0) { + RAFT_LOG_WARN("Zero clustering cost detected: all points coincide with their centroids."); + } else if (n_iter[0] > 0) { DataT delta = curClusteringCost / priorClusteringCost; if (delta > 1 - params.tol) done = true; } From a07740659fa00c67396f7dd3f5a21d11b0e542f4 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Mon, 13 Apr 2026 12:29:18 -0700 Subject: [PATCH 04/50] style --- cpp/tests/cluster/kmeans.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/tests/cluster/kmeans.cu b/cpp/tests/cluster/kmeans.cu index ee6aac097b..5d48ef099e 100644 --- a/cpp/tests/cluster/kmeans.cu +++ b/cpp/tests/cluster/kmeans.cu @@ -433,7 +433,7 @@ class KmeansFitBatchedTest : public ::testing::TestWithParam(d_centroids_ref.data(), params.n_clusters, n_features); - params.init = cuvs::cluster::kmeans::params::Array; + params.init = cuvs::cluster::kmeans::params::Array; params.max_iter = 20; T ref_inertia = 0; From d6598755b0a6a3f6e8112280b5760d9ed0724060 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Mon, 13 Apr 2026 12:43:52 -0700 Subject: [PATCH 05/50] add init_size param --- c/include/cuvs/cluster/kmeans.h | 6 +++ c/src/cluster/kmeans.cpp | 2 + cpp/include/cuvs/cluster/kmeans.hpp | 10 ++++ cpp/src/cluster/detail/kmeans.cuh | 63 ++++++++-------------- python/cuvs/cuvs/cluster/kmeans/kmeans.pxd | 1 + python/cuvs/cuvs/cluster/kmeans/kmeans.pyx | 11 ++++ 6 files changed, 53 insertions(+), 40 deletions(-) diff --git a/c/include/cuvs/cluster/kmeans.h b/c/include/cuvs/cluster/kmeans.h index ccf1144bb1..4118986043 100644 --- a/c/include/cuvs/cluster/kmeans.h +++ b/c/include/cuvs/cluster/kmeans.h @@ -101,6 +101,12 @@ struct cuvsKMeansParams { */ int hierarchical_n_iters; + /** + * Number of samples to draw for KMeansPlusPlus initialization with host data. + * When set to 0, uses heuristic min(3 * n_clusters, n_samples). + */ + int64_t init_size; + /** * Number of samples to process per GPU batch for the batched (host-data) API. * When set to 0, defaults to n_samples (process all at once). diff --git a/c/src/cluster/kmeans.cpp b/c/src/cluster/kmeans.cpp index 2f9156ecab..dd1436c1f0 100644 --- a/c/src/cluster/kmeans.cpp +++ b/c/src/cluster/kmeans.cpp @@ -28,6 +28,7 @@ cuvs::cluster::kmeans::params convert_params(const cuvsKMeansParams& params) kmeans_params.oversampling_factor = params.oversampling_factor; kmeans_params.batch_samples = params.batch_samples; kmeans_params.batch_centroids = params.batch_centroids; + kmeans_params.init_size = params.init_size; kmeans_params.streaming_batch_size = params.streaming_batch_size; return kmeans_params; } @@ -238,6 +239,7 @@ extern "C" cuvsError_t cuvsKMeansParamsCreate(cuvsKMeansParams_t* params) .batch_centroids = cpp_params.batch_centroids, .hierarchical = false, .hierarchical_n_iters = static_cast(cpp_balanced_params.n_iters), + .init_size = cpp_params.init_size, .streaming_batch_size = cpp_params.streaming_batch_size}; }); } diff --git a/cpp/include/cuvs/cluster/kmeans.hpp b/cpp/include/cuvs/cluster/kmeans.hpp index 1122c19f22..ff7d056f7d 100644 --- a/cpp/include/cuvs/cluster/kmeans.hpp +++ b/cpp/include/cuvs/cluster/kmeans.hpp @@ -112,6 +112,16 @@ struct params : base_params { */ int batch_centroids = 0; + /** + * Number of samples to randomly draw for the KMeansPlusPlus initialization + * step. A random subset of this size is used for centroid seeding. + * When set to 0 the default depends on the data location: + * - Device data: n_samples (use the full dataset). + * - Host data: min(3 * n_clusters, n_samples). + * Default: 0. + */ + int64_t init_size = 0; + /** * Number of samples to process per GPU batch when fitting with host data. * When set to 0, defaults to n_samples (process all at once). diff --git a/cpp/src/cluster/detail/kmeans.cuh b/cpp/src/cluster/detail/kmeans.cuh index 45b9eef2aa..5add300985 100644 --- a/cpp/src/cluster/detail/kmeans.cuh +++ b/cpp/src/cluster/detail/kmeans.cuh @@ -639,47 +639,30 @@ void kmeans_fit( return; } - if constexpr (data_on_device) { - auto X_dev = - raft::make_device_matrix_view(X.data_handle(), n_samples, n_features); - - if (iter_params.init == cuvs::cluster::kmeans::params::InitMethod::Random) { - initRandom(handle, iter_params, X_dev, centroidsRawData); - } else if (iter_params.init == cuvs::cluster::kmeans::params::InitMethod::KMeansPlusPlus) { - if (iter_params.oversampling_factor == 0) - kmeansPlusPlus(handle, iter_params, X_dev, centroidsRawData, workspace); - else - initScalableKMeansPlusPlus( - handle, iter_params, X_dev, centroidsRawData, workspace); - } else { - THROW("unknown initialization method to select initial centers"); - } + raft::random::RngState random_state(iter_params.rng_state.seed); + + if (iter_params.init == cuvs::cluster::kmeans::params::InitMethod::Random) { + raft::matrix::sample_rows(handle, random_state, X, centroidsRawData); + } else if (iter_params.init == cuvs::cluster::kmeans::params::InitMethod::KMeansPlusPlus) { + IndexT default_init_size = + data_on_device ? n_samples : std::min(static_cast(3 * n_clusters), n_samples); + IndexT init_sample_size = iter_params.init_size > 0 + ? std::min(static_cast(iter_params.init_size), n_samples) + : default_init_size; + + auto init_sample = + raft::make_device_matrix(handle, init_sample_size, n_features); + raft::matrix::sample_rows(handle, random_state, X, init_sample.view()); + + auto init_sample_const = raft::make_const_mdspan(init_sample.view()); + if (iter_params.oversampling_factor == 0) + kmeansPlusPlus( + handle, iter_params, init_sample_const, centroidsRawData, workspace); + else + initScalableKMeansPlusPlus( + handle, iter_params, init_sample_const, centroidsRawData, workspace); } else { - raft::random::RngState random_state(iter_params.rng_state.seed); - - if (iter_params.init == cuvs::cluster::kmeans::params::InitMethod::Random) { - raft::matrix::sample_rows(handle, random_state, X, centroidsRawData); - } else if (iter_params.init == cuvs::cluster::kmeans::params::InitMethod::KMeansPlusPlus) { - IndexT init_sample_size = - std::min(static_cast(3 * streaming_batch_size), n_samples); - if (init_sample_size < n_clusters) { - init_sample_size = std::min(static_cast(3 * n_clusters), n_samples); - } - - auto init_sample = - raft::make_device_matrix(handle, init_sample_size, n_features); - raft::matrix::sample_rows(handle, random_state, X, init_sample.view()); - - auto init_sample_const = raft::make_const_mdspan(init_sample.view()); - if (iter_params.oversampling_factor == 0) - kmeansPlusPlus( - handle, iter_params, init_sample_const, centroidsRawData, workspace); - else - initScalableKMeansPlusPlus( - handle, iter_params, init_sample_const, centroidsRawData, workspace); - } else { - THROW("unknown initialization method to select initial centers"); - } + THROW("unknown initialization method to select initial centers"); } }; diff --git a/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd b/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd index a99ac50464..0a793d81a6 100644 --- a/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd +++ b/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd @@ -32,6 +32,7 @@ cdef extern from "cuvs/cluster/kmeans.h" nogil: double oversampling_factor, int batch_samples, int batch_centroids, + int64_t init_size, int64_t streaming_batch_size, bool hierarchical, int hierarchical_n_iters diff --git a/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx b/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx index 656da2f978..246ac4138c 100644 --- a/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx +++ b/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx @@ -76,6 +76,10 @@ cdef class KMeansParams: [batch_samples x n_clusters]. batch_centroids : int Number of centroids to process in each batch. If 0, uses n_clusters. + init_size : int + Number of samples to draw for KMeansPlusPlus initialization with + host (out-of-core) data. When set to 0, uses the heuristic + min(3 * n_clusters, n_samples). Default: 0. streaming_batch_size : int Number of samples to process per GPU batch when fitting with host (numpy) data. When set to 0, defaults to n_samples (process all @@ -109,6 +113,7 @@ cdef class KMeansParams: oversampling_factor=None, batch_samples=None, batch_centroids=None, + init_size=None, streaming_batch_size=None, hierarchical=None, hierarchical_n_iters=None): @@ -131,6 +136,8 @@ cdef class KMeansParams: self.params.batch_samples = batch_samples if batch_centroids is not None: self.params.batch_centroids = batch_centroids + if init_size is not None: + self.params.init_size = init_size if streaming_batch_size is not None: self.params.streaming_batch_size = streaming_batch_size if hierarchical is not None: @@ -177,6 +184,10 @@ cdef class KMeansParams: def batch_centroids(self): return self.params.batch_centroids + @property + def init_size(self): + return self.params.init_size + @property def streaming_batch_size(self): return self.params.streaming_batch_size From 03a64736812c702fdcae319593f522e3e6cea281 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Mon, 13 Apr 2026 12:56:22 -0700 Subject: [PATCH 06/50] docs --- cpp/src/cluster/detail/kmeans.cuh | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/cpp/src/cluster/detail/kmeans.cuh b/cpp/src/cluster/detail/kmeans.cuh index 5add300985..cd69068256 100644 --- a/cpp/src/cluster/detail/kmeans.cuh +++ b/cpp/src/cluster/detail/kmeans.cuh @@ -583,6 +583,22 @@ void initScalableKMeansPlusPlus(raft::resources const& handle, * @tparam DataT Data / weight type * @tparam IndexT Index type * @tparam Accessor Accessor policy (host or device); deduced from X + * + * @param[in] handle The raft handle. + * @param[in] pams Parameters for the KMeans model. + * @param[in] X Training instances to cluster (host or device). + * Row-major, [n_samples x n_features]. + * @param[in] sample_weight Optional weights for each observation in X. + * [n_samples]. When std::nullopt, uniform weights + * are used. + * @param[inout] centroids [in] When init is InitMethod::Array, used as + * the initial cluster centers. + * [out] The final centroids produced by the + * algorithm. [n_clusters x n_features]. + * @param[out] inertia Sum of squared distances of samples to their + * closest cluster center. + * @param[out] n_iter Number of iterations run for the best + * initialization. */ template void kmeans_fit( From 86af2fa3b688e971c14142e922ab0dec076e2f2e Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Mon, 13 Apr 2026 16:54:27 -0700 Subject: [PATCH 07/50] rm direct cuda api calls --- cpp/src/cluster/detail/kmeans.cuh | 10 +-- cpp/src/cluster/detail/kmeans_common.cuh | 12 ++-- cpp/src/cluster/detail/kmeans_mg.cuh | 78 ++++++++++++------------ 3 files changed, 44 insertions(+), 56 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans.cuh b/cpp/src/cluster/detail/kmeans.cuh index cd69068256..5893ba12ef 100644 --- a/cpp/src/cluster/detail/kmeans.cuh +++ b/cpp/src/cluster/detail/kmeans.cuh @@ -572,14 +572,6 @@ void initScalableKMeansPlusPlus(raft::resources const& handle, /** * @brief Unified k-means fit (works with host or device data). * - * Handles centroid initialization, the n_init best-of loop, and the batched - * Lloyd iteration. All reusable work buffers are allocated once before the - * n_init loop and shared across iterations. - * - * Data and weights are batched via batch_load_iterator (transparent H2D copy - * for host memory, zero-copy offset for device). Only batch-sized device - * buffers are allocated — no O(n_samples) device memory. - * * @tparam DataT Data / weight type * @tparam IndexT Index type * @tparam Accessor Accessor policy (host or device); deduced from X @@ -882,7 +874,7 @@ void kmeans_fit( inertia[0], n_iter[0]); } - RAFT_LOG_DEBUG("KMeans.fit: completed."); + RAFT_LOG_DEBUG("KMeans.fit: async call returned (fit could still be running on the device)"); } template diff --git a/cpp/src/cluster/detail/kmeans_common.cuh b/cpp/src/cluster/detail/kmeans_common.cuh index c5ca78941e..682bf1244c 100644 --- a/cpp/src/cluster/detail/kmeans_common.cuh +++ b/cpp/src/cluster/detail/kmeans_common.cuh @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -620,21 +621,18 @@ DataT compute_weight_scale(const DataT* weight_ptr, IndexT n_samples, cudaStream { if (weight_ptr == nullptr) { return DataT{1}; } - bool is_host = true; - cudaPointerAttributes attr; - auto err = cudaPointerGetAttributes(&attr, weight_ptr); - if (err == cudaSuccess && attr.type == cudaMemoryTypeDevice) { is_host = false; } - cudaGetLastError(); // clear any error + bool is_device_accessible = + raft::is_device_accessible(raft::memory_type_from_pointer(weight_ptr)); DataT wt_sum = DataT{0}; - if (is_host) { + if (!is_device_accessible) { for (IndexT i = 0; i < n_samples; ++i) { wt_sum += weight_ptr[i]; } } else { std::vector h_weights(n_samples); raft::copy(h_weights.data(), weight_ptr, n_samples, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + raft::resource::sync_stream(handle); for (IndexT i = 0; i < n_samples; ++i) { wt_sum += h_weights[i]; } diff --git a/cpp/src/cluster/detail/kmeans_mg.cuh b/cpp/src/cluster/detail/kmeans_mg.cuh index 47997b2535..9dd2eaa060 100644 --- a/cpp/src/cluster/detail/kmeans_mg.cuh +++ b/cpp/src/cluster/detail/kmeans_mg.cuh @@ -701,47 +701,45 @@ void fit(const raft::resources& handle, raft::make_device_vector_view(newCentroids.data_handle(), newCentroids.size())); bool done = false; - { - rmm::device_scalar> clusterCostD(stream); - - // calculate cluster cost phi_x(C) - cuvs::cluster::kmeans::cluster_cost( - handle, - minClusterAndDistance.view(), - workspace, - raft::make_device_scalar_view(clusterCostD.data()), - cuda::proclaim_return_type>( - [] __device__(const raft::KeyValuePair& a, - const raft::KeyValuePair& b) { - raft::KeyValuePair res; - res.key = 0; - res.value = a.value + b.value; - return res; - })); - - // Cluster cost phi_x(C) from all ranks - comm.allreduce(&(clusterCostD.data()->value), - &(clusterCostD.data()->value), - 1, - raft::comms::op_t::SUM, - stream); - - DataT curClusteringCost = 0; - raft::copy(handle, - raft::make_host_scalar_view(&curClusteringCost), - raft::make_device_scalar_view(&(clusterCostD.data()->value))); - - ASSERT(comm.sync_stream(stream) == raft::comms::status_t::SUCCESS, - "An error occurred in the distributed operation. This can result " - "from a failed rank"); - if (curClusteringCost == (DataT)0.0) { - RAFT_LOG_WARN("Zero clustering cost detected: all points coincide with their centroids."); - } else if (n_iter[0] > 0) { - DataT delta = curClusteringCost / priorClusteringCost; - if (delta > 1 - params.tol) done = true; - } - priorClusteringCost = curClusteringCost; + rmm::device_scalar> clusterCostD(stream); + + // calculate cluster cost phi_x(C) + cuvs::cluster::kmeans::cluster_cost( + handle, + minClusterAndDistance.view(), + workspace, + raft::make_device_scalar_view(clusterCostD.data()), + cuda::proclaim_return_type>( + [] __device__(const raft::KeyValuePair& a, + const raft::KeyValuePair& b) { + raft::KeyValuePair res; + res.key = 0; + res.value = a.value + b.value; + return res; + })); + + // Cluster cost phi_x(C) from all ranks + comm.allreduce(&(clusterCostD.data()->value), + &(clusterCostD.data()->value), + 1, + raft::comms::op_t::SUM, + stream); + + DataT curClusteringCost = 0; + raft::copy(handle, + raft::make_host_scalar_view(&curClusteringCost), + raft::make_device_scalar_view(&(clusterCostD.data()->value))); + + ASSERT(comm.sync_stream(stream) == raft::comms::status_t::SUCCESS, + "An error occurred in the distributed operation. This can result " + "from a failed rank"); + if (curClusteringCost == (DataT)0.0) { + RAFT_LOG_WARN("Zero clustering cost detected: all points coincide with their centroids."); + } else if (n_iter[0] > 0) { + DataT delta = curClusteringCost / priorClusteringCost; + if (delta > 1 - params.tol) done = true; } + priorClusteringCost = curClusteringCost; raft::resource::sync_stream(handle, stream); if (sqrdNormError < params.tol) done = true; From d4e4e2cff099bcb7888d8f5285a75acce8805e4a Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Mon, 13 Apr 2026 17:04:19 -0700 Subject: [PATCH 08/50] std::swap instead of raft::copy --- cpp/src/cluster/detail/kmeans.cuh | 33 +++++++++++++++++-------------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans.cuh b/cpp/src/cluster/detail/kmeans.cuh index 5893ba12ef..39fbdcdca0 100644 --- a/cpp/src/cluster/detail/kmeans.cuh +++ b/cpp/src/cluster/detail/kmeans.cuh @@ -683,7 +683,11 @@ void kmeans_fit( n_init = 1; } - auto centroidsRawData = raft::make_device_matrix(handle, n_clusters, n_features); + IndexT centroid_buf_size = n_clusters * n_features; + rmm::device_uvector centroid_buf_A(centroid_buf_size, stream); + rmm::device_uvector centroid_buf_B(centroid_buf_size, stream); + DataT* cur_centroids_ptr = centroid_buf_A.data(); + DataT* new_centroids_ptr = centroid_buf_B.data(); auto minClusterAndDistance = raft::make_device_vector, IndexT>( handle, streaming_batch_size); @@ -693,7 +697,6 @@ void kmeans_fit( auto centroid_sums = raft::make_device_matrix(handle, n_clusters, n_features); auto weight_per_cluster = raft::make_device_vector(handle, n_clusters); - auto new_centroids = raft::make_device_matrix(handle, n_clusters, n_features); auto clustering_cost = raft::make_device_scalar(handle, DataT{0}); auto batch_sums = raft::make_device_matrix(handle, n_clusters, n_features); auto batch_counts = raft::make_device_vector(handle, n_clusters); @@ -738,7 +741,11 @@ void kmeans_fit( n_init, (unsigned long long)iter_params.rng_state.seed); - init_centroids(iter_params, centroidsRawData.view()); + cur_centroids_ptr = centroid_buf_A.data(); + new_centroids_ptr = centroid_buf_B.data(); + init_centroids( + iter_params, + raft::make_device_matrix_view(cur_centroids_ptr, n_clusters, n_features)); DataT iter_inertia = std::numeric_limits::max(); IndexT n_current_iter = 0; @@ -754,7 +761,9 @@ void kmeans_fit( raft::const_op{DataT{0}}); auto centroids_const = raft::make_device_matrix_view( - centroidsRawData.data_handle(), n_clusters, n_features); + cur_centroids_ptr, n_clusters, n_features); + auto new_centroids_view = + raft::make_device_matrix_view(new_centroids_ptr, n_clusters, n_features); data_batches.reset(); weight_batches.reset(); @@ -795,17 +804,14 @@ void kmeans_fit( raft::make_const_mdspan(centroid_sums.view()), raft::make_const_mdspan(weight_per_cluster.view()), centroids_const, - new_centroids.view()); + new_centroids_view); DataT sqrdNormError = compute_centroid_shift(handle, raft::make_const_mdspan(centroids_const), - raft::make_const_mdspan(new_centroids.view())); + raft::make_const_mdspan(new_centroids_view)); - raft::copy( - handle, - raft::make_device_vector_view(centroidsRawData.data_handle(), new_centroids.size()), - raft::make_device_vector_view(new_centroids.data_handle(), new_centroids.size())); + std::swap(cur_centroids_ptr, new_centroids_ptr); bool done = false; @@ -832,7 +838,7 @@ void kmeans_fit( { auto centroids_const = raft::make_device_matrix_view( - centroidsRawData.data_handle(), n_clusters, n_features); + cur_centroids_ptr, n_clusters, n_features); iter_inertia = DataT{0}; data_batches.reset(); @@ -863,10 +869,7 @@ void kmeans_fit( if (iter_inertia < inertia[0]) { inertia[0] = iter_inertia; n_iter[0] = n_current_iter; - raft::copy( - handle, - raft::make_device_vector_view(centroids.data_handle(), n_clusters * n_features), - raft::make_device_vector_view(centroidsRawData.data_handle(), n_clusters * n_features)); + raft::copy(centroids.data_handle(), cur_centroids_ptr, centroid_buf_size, stream); } RAFT_LOG_DEBUG("KMeans.fit after iteration-%d/%d: inertia - %f, n_iter - %d", seed_iter + 1, From 0819af5e21ee2d1ad1bbceb89e18130c4bb9a5e4 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Mon, 13 Apr 2026 17:20:36 -0700 Subject: [PATCH 09/50] cache batch norms --- cpp/src/cluster/detail/kmeans.cuh | 55 ++++++++++++++++++++++-- cpp/src/cluster/detail/kmeans_common.cuh | 29 +++++-------- 2 files changed, 61 insertions(+), 23 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans.cuh b/cpp/src/cluster/detail/kmeans.cuh index 39fbdcdca0..f13f86e7a9 100644 --- a/cpp/src/cluster/detail/kmeans.cuh +++ b/cpp/src/cluster/detail/kmeans.cuh @@ -631,7 +631,7 @@ void kmeans_fit( const DataT* weight_ptr = sample_weight.has_value() ? sample_weight.value().data_handle() : nullptr; - DataT weight_scale = compute_weight_scale(weight_ptr, n_samples, stream); + DataT weight_scale = compute_weight_scale(handle, weight_ptr, n_samples); rmm::device_uvector workspace(0, stream); @@ -729,6 +729,33 @@ void kmeans_fit( n_clusters, static_cast(streaming_batch_size)); + bool compute_norms = metric == cuvs::distance::DistanceType::L2Expanded || + metric == cuvs::distance::DistanceType::L2SqrtExpanded || + metric == cuvs::distance::DistanceType::CosineExpanded; + bool use_norm_cache = compute_norms && !data_on_device; + std::vector h_norm_cache; + if (use_norm_cache) { h_norm_cache.resize(n_samples); } + bool norms_cached = false; + + auto compute_batch_norms = [&](const DataT* batch_ptr, IndexT batch_size) { + auto batch_view = + raft::make_device_matrix_view(batch_ptr, batch_size, n_features); + auto norm_view = + raft::make_device_vector_view(L2NormBatch.data_handle(), batch_size); + if (metric == cuvs::distance::DistanceType::CosineExpanded) { + raft::linalg::norm( + handle, batch_view, norm_view, raft::sqrt_op{}); + } else { + raft::linalg::norm( + handle, batch_view, norm_view); + } + }; + + if (compute_norms && data_on_device) { + compute_batch_norms(X.data_handle(), n_samples); + norms_cached = true; + } + std::mt19937 gen(pams.rng_state.seed); inertia[0] = std::numeric_limits::max(); @@ -779,8 +806,24 @@ void kmeans_fit( auto minCAD_view = raft::make_device_vector_view, IndexT>( minClusterAndDistance.data_handle(), cur_batch_size); - auto l2_view = - raft::make_device_vector_view(L2NormBatch.data_handle(), cur_batch_size); + + if (compute_norms && !norms_cached) { + compute_batch_norms(data_batch.data(), cur_batch_size); + if (use_norm_cache) { + raft::copy(h_norm_cache.data() + data_batch.offset(), + L2NormBatch.data_handle(), + cur_batch_size, + stream); + } + } else if (use_norm_cache) { + raft::copy(L2NormBatch.data_handle(), + h_norm_cache.data() + data_batch.offset(), + cur_batch_size, + stream); + } + + auto l2_const_view = raft::make_device_vector_view( + L2NormBatch.data_handle(), cur_batch_size); process_batch(handle, batch_data_view, @@ -790,7 +833,7 @@ void kmeans_fit( iter_params.batch_samples, iter_params.batch_centroids, minCAD_view, - l2_view, + l2_const_view, L2NormBuf_OR_DistBuf, workspace, centroid_sums.view(), @@ -799,6 +842,10 @@ void kmeans_fit( batch_counts.view(), clustering_cost.view()); } + if (!norms_cached && use_norm_cache) { + raft::resource::sync_stream(handle, stream); + norms_cached = true; + } finalize_centroids(handle, raft::make_const_mdspan(centroid_sums.view()), diff --git a/cpp/src/cluster/detail/kmeans_common.cuh b/cpp/src/cluster/detail/kmeans_common.cuh index 682bf1244c..511aed9566 100644 --- a/cpp/src/cluster/detail/kmeans_common.cuh +++ b/cpp/src/cluster/detail/kmeans_common.cuh @@ -158,7 +158,7 @@ void checkWeight(raft::resources const& handle, raft::copy(handle, raft::make_host_scalar_view(&wt_sum), raft::make_device_scalar_view(wt_aggr.data_handle())); - raft::resource::sync_stream(handle, stream); + raft::resource::sync_stream(handle); if (wt_sum != n_samples) { RAFT_LOG_DEBUG( @@ -266,7 +266,7 @@ void sampleCentroids(raft::resources const& handle, raft::copy(handle, raft::make_host_scalar_view(&nPtsSampledInRank), raft::make_device_scalar_view(nSelected.data_handle())); - raft::resource::sync_stream(handle, stream); + raft::resource::sync_stream(handle); uint8_t* rawPtr_isSampleCentroid = isSampleCentroid.data_handle(); thrust::for_each_n(raft::resource::get_thrust_policy(handle), @@ -598,7 +598,7 @@ DataT compute_centroid_shift(raft::resources const& handle, new_centroids.data_handle()); DataT result = 0; raft::copy(&result, sqrdNorm.data_handle(), 1, stream); - raft::resource::sync_stream(handle, stream); + raft::resource::sync_stream(handle); return result; } @@ -611,13 +611,13 @@ DataT compute_centroid_shift(raft::resources const& handle, * @tparam DataT Weight type * @tparam IndexT Index type * + * @param[in] handle RAFT resources handle * @param[in] weight_ptr Pointer to sample weights (host or device), may be nullptr * @param[in] n_samples Number of samples - * @param[in] stream CUDA stream (used when pointer is device memory) * @return Scale factor (1.0 if weights already sum to n_samples or nullptr) */ template -DataT compute_weight_scale(const DataT* weight_ptr, IndexT n_samples, cudaStream_t stream) +DataT compute_weight_scale(raft::resources const& handle, const DataT* weight_ptr, IndexT n_samples) { if (weight_ptr == nullptr) { return DataT{1}; } @@ -631,7 +631,9 @@ DataT compute_weight_scale(const DataT* weight_ptr, IndexT n_samples, cudaStream } } else { std::vector h_weights(n_samples); - raft::copy(h_weights.data(), weight_ptr, n_samples, stream); + auto d_view = raft::make_device_vector_view(weight_ptr, n_samples); + auto h_view = raft::make_host_vector_view(h_weights.data(), n_samples); + raft::copy(handle, h_view, d_view); raft::resource::sync_stream(handle); for (IndexT i = 0; i < n_samples; ++i) { wt_sum += h_weights[i]; @@ -689,7 +691,7 @@ void process_batch( int batch_samples_param, int batch_centroids_param, raft::device_vector_view, IndexT> minClusterAndDistance, - raft::device_vector_view L2NormBatch, + raft::device_vector_view L2NormBatch, rmm::device_uvector& L2NormBuf_OR_DistBuf, rmm::device_uvector& workspace, raft::device_matrix_view centroid_sums, @@ -701,22 +703,11 @@ void process_batch( cudaStream_t stream = raft::resource::get_cuda_stream(handle); IndexT current_batch_sz = batch_data.extent(0); - if (metric == cuvs::distance::DistanceType::L2Expanded || - metric == cuvs::distance::DistanceType::L2SqrtExpanded) { - raft::linalg::norm( - handle, - raft::make_device_matrix_view( - batch_data.data_handle(), current_batch_sz, batch_data.extent(1)), - L2NormBatch); - } - - auto L2NormBatch_const = raft::make_const_mdspan(L2NormBatch); - minClusterAndDistanceCompute(handle, batch_data, centroids, minClusterAndDistance, - L2NormBatch_const, + L2NormBatch, L2NormBuf_OR_DistBuf, metric, batch_samples_param, From e0f079c0cbd0978fffdc0993cd8916523d03d0e1 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Mon, 13 Apr 2026 17:53:33 -0700 Subject: [PATCH 10/50] centroid norms can also be cached per iteration --- cpp/src/cluster/detail/kmeans.cuh | 25 ++++-- cpp/src/cluster/detail/kmeans_common.cuh | 37 +++++--- .../detail/minClusterDistanceCompute.cu | 87 ++++++++++--------- 3 files changed, 87 insertions(+), 62 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans.cuh b/cpp/src/cluster/detail/kmeans.cuh index f13f86e7a9..f9e7fe7505 100644 --- a/cpp/src/cluster/detail/kmeans.cuh +++ b/cpp/src/cluster/detail/kmeans.cuh @@ -697,6 +697,7 @@ void kmeans_fit( auto centroid_sums = raft::make_device_matrix(handle, n_clusters, n_features); auto weight_per_cluster = raft::make_device_vector(handle, n_clusters); + auto centroid_norms_buf = raft::make_device_vector(handle, n_clusters); auto clustering_cost = raft::make_device_scalar(handle, DataT{0}); auto batch_sums = raft::make_device_matrix(handle, n_clusters, n_features); auto batch_counts = raft::make_device_vector(handle, n_clusters); @@ -729,10 +730,10 @@ void kmeans_fit( n_clusters, static_cast(streaming_batch_size)); - bool compute_norms = metric == cuvs::distance::DistanceType::L2Expanded || - metric == cuvs::distance::DistanceType::L2SqrtExpanded || - metric == cuvs::distance::DistanceType::CosineExpanded; - bool use_norm_cache = compute_norms && !data_on_device; + bool need_compute_norms = metric == cuvs::distance::DistanceType::L2Expanded || + metric == cuvs::distance::DistanceType::L2SqrtExpanded || + metric == cuvs::distance::DistanceType::CosineExpanded; + bool use_norm_cache = need_compute_norms && !data_on_device; std::vector h_norm_cache; if (use_norm_cache) { h_norm_cache.resize(n_samples); } bool norms_cached = false; @@ -751,7 +752,7 @@ void kmeans_fit( } }; - if (compute_norms && data_on_device) { + if (need_compute_norms && data_on_device) { compute_batch_norms(X.data_handle(), n_samples); norms_cached = true; } @@ -792,6 +793,15 @@ void kmeans_fit( auto new_centroids_view = raft::make_device_matrix_view(new_centroids_ptr, n_clusters, n_features); + std::optional> centroid_norms_opt = + std::nullopt; + if (need_compute_norms) { + raft::linalg::norm( + handle, centroids_const, centroid_norms_buf.view()); + centroid_norms_opt = raft::make_device_vector_view( + centroid_norms_buf.data_handle(), n_clusters); + } + data_batches.reset(); weight_batches.reset(); auto wt_it = weight_batches.begin(); @@ -807,7 +817,7 @@ void kmeans_fit( auto minCAD_view = raft::make_device_vector_view, IndexT>( minClusterAndDistance.data_handle(), cur_batch_size); - if (compute_norms && !norms_cached) { + if (need_compute_norms && !norms_cached) { compute_batch_norms(data_batch.data(), cur_batch_size); if (use_norm_cache) { raft::copy(h_norm_cache.data() + data_batch.offset(), @@ -840,7 +850,8 @@ void kmeans_fit( weight_per_cluster.view(), batch_sums.view(), batch_counts.view(), - clustering_cost.view()); + clustering_cost.view(), + centroid_norms_opt); } if (!norms_cached && use_norm_cache) { raft::resource::sync_stream(handle, stream); diff --git a/cpp/src/cluster/detail/kmeans_common.cuh b/cpp/src/cluster/detail/kmeans_common.cuh index 511aed9566..bc2de15726 100644 --- a/cpp/src/cluster/detail/kmeans_common.cuh +++ b/cpp/src/cluster/detail/kmeans_common.cuh @@ -371,7 +371,9 @@ void minClusterAndDistanceCompute( cuvs::distance::DistanceType metric, int batch_samples, int batch_centroids, - rmm::device_uvector& workspace); + rmm::device_uvector& workspace, + std::optional> precomputed_centroid_norms = + std::nullopt); #define EXTERN_TEMPLATE_MIN_CLUSTER_AND_DISTANCE(DataT, IndexT) \ extern template void minClusterAndDistanceCompute( \ @@ -384,7 +386,8 @@ void minClusterAndDistanceCompute( cuvs::distance::DistanceType metric, \ int batch_samples, \ int batch_centroids, \ - rmm::device_uvector& workspace); + rmm::device_uvector& workspace, \ + std::optional>); EXTERN_TEMPLATE_MIN_CLUSTER_AND_DISTANCE(float, int64_t) EXTERN_TEMPLATE_MIN_CLUSTER_AND_DISTANCE(float, int) @@ -403,7 +406,9 @@ void minClusterDistanceCompute(raft::resources const& handle, cuvs::distance::DistanceType metric, int batch_samples, int batch_centroids, - rmm::device_uvector& workspace); + rmm::device_uvector& workspace, + std::optional> + precomputed_centroid_norms = std::nullopt); #define EXTERN_TEMPLATE_MIN_CLUSTER_DISTANCE(DataT, IndexT) \ extern template void minClusterDistanceCompute( \ @@ -416,7 +421,8 @@ void minClusterDistanceCompute(raft::resources const& handle, cuvs::distance::DistanceType metric, \ int batch_samples, \ int batch_centroids, \ - rmm::device_uvector& workspace); + rmm::device_uvector& workspace, \ + std::optional>); EXTERN_TEMPLATE_MIN_CLUSTER_DISTANCE(float, int64_t) EXTERN_TEMPLATE_MIN_CLUSTER_DISTANCE(double, int64_t) @@ -652,13 +658,12 @@ DataT compute_weight_scale(raft::resources const& handle, const DataT* weight_pt /** * @brief Process a single batch of data in the Lloyd iteration. * - * This is the shared per-batch helper used by both the device (single-batch) and - * host-streaming (multi-batch) k-means paths. It operates entirely on device - * buffers: given one batch of data + weights + current centroids it - * 1. computes L2 norms (if needed), - * 2. finds the nearest centroid for every sample, - * 3. accumulates weighted centroid sums and counts into the running accumulators, - * 4. accumulates the weighted clustering cost (inertia). + * Given one batch of data + precomputed norms + weights + current centroids it + * 1. finds the nearest centroid for every sample, + * 2. accumulates weighted centroid sums and counts into the running accumulators, + * 3. accumulates the weighted clustering cost (inertia). + * + * Data norms must be precomputed by the caller and passed in via L2NormBatch. * * @tparam DataT Data / weight type (float, double) * @tparam IndexT Index type (int, int64_t) @@ -672,7 +677,7 @@ DataT compute_weight_scale(raft::resources const& handle, const DataT* weight_pt * @param[in] batch_centroids_param Batch-centroids param forwarded to * minClusterAndDistanceCompute * @param[inout] minClusterAndDistance Work buffer [batch_size] - * @param[inout] L2NormBatch Work buffer for L2 norms [batch_size] + * @param[in] L2NormBatch Precomputed data norms [batch_size] * @param[inout] L2NormBuf_OR_DistBuf Resizable scratch * @param[inout] workspace Resizable scratch * @param[inout] centroid_sums Running weighted sums [n_clusters x n_features] (added into) @@ -680,6 +685,8 @@ DataT compute_weight_scale(raft::resources const& handle, const DataT* weight_pt * @param[inout] batch_sums Scratch for this batch [n_clusters x n_features] * @param[inout] batch_counts Scratch for this batch [n_clusters] * @param[inout] clustering_cost Running cost scalar (device) (added into) + * @param[in] centroid_norms Optional precomputed centroid norms [n_clusters]. + * When provided, skips internal centroid norm computation. */ template void process_batch( @@ -698,7 +705,8 @@ void process_batch( raft::device_vector_view weight_per_cluster, raft::device_matrix_view batch_sums, raft::device_vector_view batch_counts, - raft::device_scalar_view clustering_cost) + raft::device_scalar_view clustering_cost, + std::optional> centroid_norms = std::nullopt) { cudaStream_t stream = raft::resource::get_cuda_stream(handle); IndexT current_batch_sz = batch_data.extent(0); @@ -712,7 +720,8 @@ void process_batch( metric, batch_samples_param, batch_centroids_param, - workspace); + workspace, + centroid_norms); KeyValueIndexOp conversion_op; thrust::transform_iterator, diff --git a/cpp/src/cluster/detail/minClusterDistanceCompute.cu b/cpp/src/cluster/detail/minClusterDistanceCompute.cu index 8370ff922f..bcfc381753 100644 --- a/cpp/src/cluster/detail/minClusterDistanceCompute.cu +++ b/cpp/src/cluster/detail/minClusterDistanceCompute.cu @@ -7,6 +7,8 @@ #include +#include + namespace cuvs::cluster::kmeans::detail { // Calculates a pair for every sample in input 'X' where key is an @@ -23,36 +25,34 @@ void minClusterAndDistanceCompute( cuvs::distance::DistanceType metric, int batch_samples, int batch_centroids, - rmm::device_uvector& workspace) + rmm::device_uvector& workspace, + std::optional> precomputed_centroid_norms) { cudaStream_t stream = raft::resource::get_cuda_stream(handle); auto n_samples = X.extent(0); auto n_features = X.extent(1); auto n_clusters = centroids.extent(0); - // todo(lsugy): change batch size computation when using fusedL2NN! - bool is_fused = metric == cuvs::distance::DistanceType::L2Expanded || + bool is_fused = metric == cuvs::distance::DistanceType::L2Expanded || metric == cuvs::distance::DistanceType::L2SqrtExpanded; auto dataBatchSize = is_fused ? (IndexT)n_samples : getDataBatchSize(batch_samples, n_samples); auto centroidsBatchSize = getCentroidsBatchSize(batch_centroids, n_clusters); if (is_fused) { - L2NormBuf_OR_DistBuf.resize(n_clusters, stream); - raft::linalg::norm( - handle, - centroids, - raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters)); + if (!precomputed_centroid_norms.has_value()) { + L2NormBuf_OR_DistBuf.resize(n_clusters, stream); + raft::linalg::norm( + handle, + centroids, + raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters)); + } } else { - // TODO: Unless pool allocator is used, passing in a workspace for this - // isn't really increasing performance because this needs to do a re-allocation - // anyways. ref https://github.com/rapidsai/raft/issues/930 L2NormBuf_OR_DistBuf.resize(dataBatchSize * centroidsBatchSize, stream); } - // Note - pairwiseDistance and centroidsNorm share the same buffer - // centroidsNorm [n_clusters] - tensor wrapper around centroids L2 Norm - auto centroidsNorm = - raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters); - // pairwiseDistance[ns x nc] - tensor wrapper around the distance buffer + auto centroidsNorm_view = + precomputed_centroid_norms.has_value() + ? precomputed_centroid_norms.value() + : raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters); auto pairwiseDistance = raft::make_device_matrix_view( L2NormBuf_OR_DistBuf.data(), dataBatchSize, centroidsBatchSize); @@ -87,7 +87,7 @@ void minClusterAndDistanceCompute( datasetView.data_handle(), centroids.data_handle(), L2NormXView.data_handle(), - centroidsNorm.data_handle(), + centroidsNorm_view.data_handle(), ns, n_clusters, n_features, @@ -154,7 +154,8 @@ void minClusterAndDistanceCompute( cuvs::distance::DistanceType metric, \ int batch_samples, \ int batch_centroids, \ - rmm::device_uvector& workspace); + rmm::device_uvector& workspace, \ + std::optional>); INSTANTIATE_MIN_CLUSTER_AND_DISTANCE(float, int64_t) INSTANTIATE_MIN_CLUSTER_AND_DISTANCE(double, int64_t) @@ -164,16 +165,18 @@ INSTANTIATE_MIN_CLUSTER_AND_DISTANCE(double, int) #undef INSTANTIATE_MIN_CLUSTER_AND_DISTANCE template -void minClusterDistanceCompute(raft::resources const& handle, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_vector_view minClusterDistance, - raft::device_vector_view L2NormX, - rmm::device_uvector& L2NormBuf_OR_DistBuf, - cuvs::distance::DistanceType metric, - int batch_samples, - int batch_centroids, - rmm::device_uvector& workspace) +void minClusterDistanceCompute( + raft::resources const& handle, + raft::device_matrix_view X, + raft::device_matrix_view centroids, + raft::device_vector_view minClusterDistance, + raft::device_vector_view L2NormX, + rmm::device_uvector& L2NormBuf_OR_DistBuf, + cuvs::distance::DistanceType metric, + int batch_samples, + int batch_centroids, + rmm::device_uvector& workspace, + std::optional> precomputed_centroid_norms) { cudaStream_t stream = raft::resource::get_cuda_stream(handle); auto n_samples = X.extent(0); @@ -186,21 +189,22 @@ void minClusterDistanceCompute(raft::resources const& handle, auto centroidsBatchSize = getCentroidsBatchSize(batch_centroids, n_clusters); if (is_fused) { - L2NormBuf_OR_DistBuf.resize(n_clusters, stream); - raft::linalg::norm( - handle, - raft::make_device_matrix_view( - centroids.data_handle(), centroids.extent(0), centroids.extent(1)), - raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters)); + if (!precomputed_centroid_norms.has_value()) { + L2NormBuf_OR_DistBuf.resize(n_clusters, stream); + raft::linalg::norm( + handle, + raft::make_device_matrix_view( + centroids.data_handle(), centroids.extent(0), centroids.extent(1)), + raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters)); + } } else { L2NormBuf_OR_DistBuf.resize(dataBatchSize * centroidsBatchSize, stream); } - // Note - pairwiseDistance and centroidsNorm share the same buffer - // centroidsNorm [n_clusters] - tensor wrapper around centroids L2 Norm - auto centroidsNorm = - raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters); - // pairwiseDistance[ns x nc] - tensor wrapper around the distance buffer + auto centroidsNorm_view = + precomputed_centroid_norms.has_value() + ? precomputed_centroid_norms.value() + : raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters); auto pairwiseDistance = raft::make_device_matrix_view( L2NormBuf_OR_DistBuf.data(), dataBatchSize, centroidsBatchSize); @@ -232,7 +236,7 @@ void minClusterDistanceCompute(raft::resources const& handle, datasetView.data_handle(), centroids.data_handle(), L2NormXView.data_handle(), - centroidsNorm.data_handle(), + centroidsNorm_view.data_handle(), ns, n_clusters, n_features, @@ -290,7 +294,8 @@ void minClusterDistanceCompute(raft::resources const& handle, cuvs::distance::DistanceType metric, \ int batch_samples, \ int batch_centroids, \ - rmm::device_uvector& workspace); + rmm::device_uvector& workspace, \ + std::optional>); INSTANTIATE_MIN_CLUSTER_DISTANCE(float, int64_t) INSTANTIATE_MIN_CLUSTER_DISTANCE(double, int64_t) From c2f739003ff8131a3809c794d7edf117687d9f48 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Mon, 13 Apr 2026 18:11:28 -0700 Subject: [PATCH 11/50] mg n_iter --- cpp/src/cluster/detail/kmeans_mg.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/cluster/detail/kmeans_mg.cuh b/cpp/src/cluster/detail/kmeans_mg.cuh index 9dd2eaa060..cbc75c822c 100644 --- a/cpp/src/cluster/detail/kmeans_mg.cuh +++ b/cpp/src/cluster/detail/kmeans_mg.cuh @@ -735,7 +735,7 @@ void fit(const raft::resources& handle, "from a failed rank"); if (curClusteringCost == (DataT)0.0) { RAFT_LOG_WARN("Zero clustering cost detected: all points coincide with their centroids."); - } else if (n_iter[0] > 0) { + } else if (n_iter[0] > 1) { DataT delta = curClusteringCost / priorClusteringCost; if (delta > 1 - params.tol) done = true; } From b9c310289f86039b80fb07aac73f70cb870794a7 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Mon, 13 Apr 2026 18:19:28 -0700 Subject: [PATCH 12/50] pre-commit --- cpp/src/cluster/detail/kmeans.cuh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans.cuh b/cpp/src/cluster/detail/kmeans.cuh index f9e7fe7505..a35e557ba5 100644 --- a/cpp/src/cluster/detail/kmeans.cuh +++ b/cpp/src/cluster/detail/kmeans.cuh @@ -731,9 +731,9 @@ void kmeans_fit( static_cast(streaming_batch_size)); bool need_compute_norms = metric == cuvs::distance::DistanceType::L2Expanded || - metric == cuvs::distance::DistanceType::L2SqrtExpanded || - metric == cuvs::distance::DistanceType::CosineExpanded; - bool use_norm_cache = need_compute_norms && !data_on_device; + metric == cuvs::distance::DistanceType::L2SqrtExpanded || + metric == cuvs::distance::DistanceType::CosineExpanded; + bool use_norm_cache = need_compute_norms && !data_on_device; std::vector h_norm_cache; if (use_norm_cache) { h_norm_cache.resize(n_samples); } bool norms_cached = false; From e3956c1323b03cfca58c414a6d07142ca3a7a041 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Mon, 13 Apr 2026 18:59:22 -0700 Subject: [PATCH 13/50] do not break c abi --- c/include/cuvs/cluster/kmeans.h | 16 ++++++++++------ c/src/cluster/kmeans.cpp | 5 +++-- python/cuvs/cuvs/cluster/kmeans/kmeans.pxd | 7 ++++--- 3 files changed, 17 insertions(+), 11 deletions(-) diff --git a/c/include/cuvs/cluster/kmeans.h b/c/include/cuvs/cluster/kmeans.h index 4118986043..fbc6877a00 100644 --- a/c/include/cuvs/cluster/kmeans.h +++ b/c/include/cuvs/cluster/kmeans.h @@ -91,6 +91,9 @@ struct cuvsKMeansParams { */ int batch_centroids; + /** Deprecated, ignored. Kept for ABI compatibility. */ + bool inertia_check; + /** * Whether to use hierarchical (balanced) kmeans or not */ @@ -102,16 +105,17 @@ struct cuvsKMeansParams { int hierarchical_n_iters; /** - * Number of samples to draw for KMeansPlusPlus initialization with host data. - * When set to 0, uses heuristic min(3 * n_clusters, n_samples). + * Number of samples to process per GPU batch for the batched (host-data) API. + * When set to 0, defaults to n_samples (process all at once). */ - int64_t init_size; + int64_t streaming_batch_size; /** - * Number of samples to process per GPU batch for the batched (host-data) API. - * When set to 0, defaults to n_samples (process all at once). + * Number of samples to draw for KMeansPlusPlus initialization. + * When set to 0, uses heuristic min(3 * n_clusters, n_samples) for host data, + * or n_samples for device data. */ - int64_t streaming_batch_size; + int64_t init_size; }; typedef struct cuvsKMeansParams* cuvsKMeansParams_t; diff --git a/c/src/cluster/kmeans.cpp b/c/src/cluster/kmeans.cpp index dd1436c1f0..495a83f8d5 100644 --- a/c/src/cluster/kmeans.cpp +++ b/c/src/cluster/kmeans.cpp @@ -237,10 +237,11 @@ extern "C" cuvsError_t cuvsKMeansParamsCreate(cuvsKMeansParams_t* params) .oversampling_factor = cpp_params.oversampling_factor, .batch_samples = cpp_params.batch_samples, .batch_centroids = cpp_params.batch_centroids, + .inertia_check = false, .hierarchical = false, .hierarchical_n_iters = static_cast(cpp_balanced_params.n_iters), - .init_size = cpp_params.init_size, - .streaming_batch_size = cpp_params.streaming_batch_size}; + .streaming_batch_size = cpp_params.streaming_batch_size, + .init_size = cpp_params.init_size}; }); } diff --git a/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd b/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd index 0a793d81a6..ccacb7042b 100644 --- a/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd +++ b/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd @@ -32,10 +32,11 @@ cdef extern from "cuvs/cluster/kmeans.h" nogil: double oversampling_factor, int batch_samples, int batch_centroids, - int64_t init_size, - int64_t streaming_batch_size, + bool inertia_check, bool hierarchical, - int hierarchical_n_iters + int hierarchical_n_iters, + int64_t streaming_batch_size, + int64_t init_size ctypedef cuvsKMeansParams* cuvsKMeansParams_t From 384d054af87048dcec03e810d580d4589e05b43a Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Tue, 21 Apr 2026 15:03:30 -0700 Subject: [PATCH 14/50] fix checkWeight --- cpp/src/cluster/detail/kmeans.cuh | 76 +++++++++++----- cpp/src/cluster/detail/kmeans_common.cuh | 109 +++++++---------------- cpp/src/cluster/detail/kmeans_mg.cuh | 20 ++--- 3 files changed, 91 insertions(+), 114 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans.cuh b/cpp/src/cluster/detail/kmeans.cuh index a35e557ba5..7ac43f43b2 100644 --- a/cpp/src/cluster/detail/kmeans.cuh +++ b/cpp/src/cluster/detail/kmeans.cuh @@ -631,11 +631,21 @@ void kmeans_fit( const DataT* weight_ptr = sample_weight.has_value() ? sample_weight.value().data_handle() : nullptr; - DataT weight_scale = compute_weight_scale(handle, weight_ptr, n_samples); + DataT wt_sum = sample_weight.has_value() ? weightSum(handle, sample_weight.value()) + : static_cast(n_samples); rmm::device_uvector workspace(0, stream); - constexpr bool data_on_device = !raft::is_host_mdspan_v; + constexpr bool data_on_device = raft::is_device_mdspan_v; + + if (data_on_device && streaming_batch_size != static_cast(n_samples)) { + RAFT_LOG_WARN( + "KMeans: streaming_batch_size (%zu) ignored when data resides on device; using n_samples " + "(%zu)", + static_cast(streaming_batch_size), + static_cast(n_samples)); + streaming_batch_size = static_cast(n_samples); + } auto init_centroids = [&](const cuvs::cluster::kmeans::params& iter_params, raft::device_matrix_view centroidsRawData) { @@ -652,23 +662,30 @@ void kmeans_fit( if (iter_params.init == cuvs::cluster::kmeans::params::InitMethod::Random) { raft::matrix::sample_rows(handle, random_state, X, centroidsRawData); } else if (iter_params.init == cuvs::cluster::kmeans::params::InitMethod::KMeansPlusPlus) { - IndexT default_init_size = - data_on_device ? n_samples : std::min(static_cast(3 * n_clusters), n_samples); - IndexT init_sample_size = iter_params.init_size > 0 - ? std::min(static_cast(iter_params.init_size), n_samples) - : default_init_size; - - auto init_sample = - raft::make_device_matrix(handle, init_sample_size, n_features); - raft::matrix::sample_rows(handle, random_state, X, init_sample.view()); - - auto init_sample_const = raft::make_const_mdspan(init_sample.view()); - if (iter_params.oversampling_factor == 0) - kmeansPlusPlus( - handle, iter_params, init_sample_const, centroidsRawData, workspace); - else - initScalableKMeansPlusPlus( - handle, iter_params, init_sample_const, centroidsRawData, workspace); + auto run_kmeanspp = [&](raft::device_matrix_view init_data) { + if (iter_params.oversampling_factor == 0) + kmeansPlusPlus( + handle, iter_params, init_data, centroidsRawData, workspace); + else + initScalableKMeansPlusPlus( + handle, iter_params, init_data, centroidsRawData, workspace); + }; + + if constexpr (data_on_device) { + run_kmeanspp(X); + } else { + IndexT default_init_size = + std::min(static_cast(std::int64_t{3} * n_clusters), n_samples); + IndexT init_sample_size = + iter_params.init_size > 0 + ? std::min(static_cast(iter_params.init_size), n_samples) + : default_init_size; + + auto init_sample = + raft::make_device_matrix(handle, init_sample_size, n_features); + raft::matrix::sample_rows(handle, random_state, X, init_sample.view()); + run_kmeanspp(raft::make_const_mdspan(init_sample.view())); + } } else { THROW("unknown initialization method to select initial centers"); } @@ -712,11 +729,14 @@ void kmeans_fit( auto prepare_batch_weights = [&](const auto& wt_batch, IndexT cur_batch_size) { if (weight_ptr != nullptr) { raft::copy(batch_weights_buf.data_handle(), wt_batch.data(), cur_batch_size, stream); - if (weight_scale != DataT{1}) { + if (wt_sum != static_cast(n_samples)) { auto bw = raft::make_device_vector_view(batch_weights_buf.data_handle(), cur_batch_size); - raft::linalg::map( - handle, bw, raft::mul_const_op{weight_scale}, raft::make_const_mdspan(bw)); + raft::linalg::map(handle, + bw, + raft::compose_op(raft::mul_const_op{static_cast(n_samples)}, + raft::div_const_op{wt_sum}), + raft::make_const_mdspan(bw)); } } return raft::make_device_vector_view(batch_weights_buf.data_handle(), @@ -1012,8 +1032,16 @@ void kmeans_predict(raft::resources const& handle, else raft::matrix::fill(handle, weight.view(), DataT(1)); - // check if weights sum up to n_samples - if (normalize_weight) checkWeight(handle, weight.view(), workspace); + if (normalize_weight) { + DataT wt_sum = weightSum(handle, raft::make_const_mdspan(weight.view())); + if (wt_sum != static_cast(n_samples)) { + raft::linalg::map(handle, + weight.view(), + raft::compose_op(raft::mul_const_op{static_cast(n_samples)}, + raft::div_const_op{wt_sum}), + raft::make_const_mdspan(weight.view())); + } + } auto minClusterAndDistance = raft::make_device_vector, IndexT>(handle, n_samples); diff --git a/cpp/src/cluster/detail/kmeans_common.cuh b/cpp/src/cluster/detail/kmeans_common.cuh index bc2de15726..97e8360d19 100644 --- a/cpp/src/cluster/detail/kmeans_common.cuh +++ b/cpp/src/cluster/detail/kmeans_common.cuh @@ -133,43 +133,43 @@ void countLabels(raft::resources const& handle, stream)); } -template -void checkWeight(raft::resources const& handle, - raft::device_vector_view weight, - rmm::device_uvector& workspace) +/** + * @brief Compute the sum of sample weights. + * + * Device-accessible mdspans are reduced on device via mapThenSumReduce; + * host mdspans are summed on the host. + * + * @return Sum of weights. + */ +template +DataT weightSum( + raft::resources const& handle, + raft::mdspan, raft::layout_right, Accessor> weight) { - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - auto wt_aggr = raft::make_device_scalar(handle, 0); - auto n_samples = weight.extent(0); - - size_t temp_storage_bytes = 0; - RAFT_CUDA_TRY(cub::DeviceReduce::Sum( - nullptr, temp_storage_bytes, weight.data_handle(), wt_aggr.data_handle(), n_samples, stream)); - - workspace.resize(temp_storage_bytes, stream); + auto n_samples = weight.extent(0); + auto ns = static_cast(n_samples); - RAFT_CUDA_TRY(cub::DeviceReduce::Sum(workspace.data(), - temp_storage_bytes, - weight.data_handle(), - wt_aggr.data_handle(), - n_samples, - stream)); - DataT wt_sum = 0; - raft::copy(handle, - raft::make_host_scalar_view(&wt_sum), - raft::make_device_scalar_view(wt_aggr.data_handle())); - raft::resource::sync_stream(handle); + DataT wt_sum = DataT{0}; + if constexpr (raft::is_device_mdspan_v) { + auto stream = raft::resource::get_cuda_stream(handle); + auto d_wt_sum = raft::make_device_scalar(handle, DataT{0}); + raft::linalg::mapThenSumReduce( + d_wt_sum.data_handle(), n_samples, raft::identity_op{}, stream, weight.data_handle()); + raft::copy(&wt_sum, d_wt_sum.data_handle(), 1, stream); + raft::resource::sync_stream(handle); + } else { + for (IndexT i = 0; i < n_samples; ++i) { + wt_sum += weight(i); + } + } - if (wt_sum != n_samples) { + if (wt_sum != ns) { RAFT_LOG_DEBUG( "[Warning!] KMeans: normalizing the user provided sample weight to " - "sum up to %d samples", - n_samples); - - auto scale = static_cast(n_samples) / wt_sum; - raft::linalg::map( - handle, weight, raft::mul_const_op{scale}, raft::make_const_mdspan(weight)); + "sum up to %zu samples", + static_cast(n_samples)); } + return wt_sum; } template @@ -608,53 +608,6 @@ DataT compute_centroid_shift(raft::resources const& handle, return result; } -/** - * @brief Compute the weight normalization scale factor for sample weights that may - * reside on host memory. Weights are normalized to sum to n_samples. - * - * Works on any contiguous pointer (host or device) by copying to host for the sum. - * - * @tparam DataT Weight type - * @tparam IndexT Index type - * - * @param[in] handle RAFT resources handle - * @param[in] weight_ptr Pointer to sample weights (host or device), may be nullptr - * @param[in] n_samples Number of samples - * @return Scale factor (1.0 if weights already sum to n_samples or nullptr) - */ -template -DataT compute_weight_scale(raft::resources const& handle, const DataT* weight_ptr, IndexT n_samples) -{ - if (weight_ptr == nullptr) { return DataT{1}; } - - bool is_device_accessible = - raft::is_device_accessible(raft::memory_type_from_pointer(weight_ptr)); - - DataT wt_sum = DataT{0}; - if (!is_device_accessible) { - for (IndexT i = 0; i < n_samples; ++i) { - wt_sum += weight_ptr[i]; - } - } else { - std::vector h_weights(n_samples); - auto d_view = raft::make_device_vector_view(weight_ptr, n_samples); - auto h_view = raft::make_host_vector_view(h_weights.data(), n_samples); - raft::copy(handle, h_view, d_view); - raft::resource::sync_stream(handle); - for (IndexT i = 0; i < n_samples; ++i) { - wt_sum += h_weights[i]; - } - } - - if (wt_sum == static_cast(n_samples)) { return DataT{1}; } - - RAFT_LOG_DEBUG( - "[Warning!] KMeans: normalizing the user provided sample weight to " - "sum up to %zu samples", - static_cast(n_samples)); - return static_cast(n_samples) / wt_sum; -} - /** * @brief Process a single batch of data in the Lloyd iteration. * diff --git a/cpp/src/cluster/detail/kmeans_mg.cuh b/cpp/src/cluster/detail/kmeans_mg.cuh index cbc75c822c..2c8bf19018 100644 --- a/cpp/src/cluster/detail/kmeans_mg.cuh +++ b/cpp/src/cluster/detail/kmeans_mg.cuh @@ -467,15 +467,9 @@ void checkWeights(const raft::resources& handle, const auto& comm = raft::resource::get_comms(handle); - auto n_samples = weight.extent(0); - size_t temp_storage_bytes = 0; - RAFT_CUDA_TRY(cub::DeviceReduce::Sum( - nullptr, temp_storage_bytes, weight.data_handle(), wt_aggr.data(), n_samples, stream)); - - workspace.resize(temp_storage_bytes, stream); - - RAFT_CUDA_TRY(cub::DeviceReduce::Sum( - workspace.data(), temp_storage_bytes, weight.data_handle(), wt_aggr.data(), n_samples, stream)); + auto n_samples = weight.extent(0); + raft::linalg::mapThenSumReduce( + wt_aggr.data(), n_samples, raft::identity_op{}, stream, weight.data_handle()); comm.allreduce(wt_aggr.data(), // sendbuff wt_aggr.data(), // recvbuff @@ -491,9 +485,11 @@ void checkWeights(const raft::resources& handle, "sum up to %d samples", n_samples); - DataT scale = n_samples / wt_sum; - raft::linalg::map( - handle, weight, raft::mul_const_op(scale), raft::make_const_mdspan(weight)); + raft::linalg::map(handle, + weight, + raft::compose_op(raft::mul_const_op{static_cast(n_samples)}, + raft::div_const_op{wt_sum}), + raft::make_const_mdspan(weight)); } } From 6ba759c4e37ed27c7de8220df883f36311931b36 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Tue, 21 Apr 2026 15:35:05 -0700 Subject: [PATCH 15/50] fix compilation --- cpp/src/cluster/detail/kmeans.cuh | 7 ++++--- cpp/src/cluster/detail/kmeans_common.cuh | 7 ++----- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans.cuh b/cpp/src/cluster/detail/kmeans.cuh index 7ac43f43b2..109e4cbdfb 100644 --- a/cpp/src/cluster/detail/kmeans.cuh +++ b/cpp/src/cluster/detail/kmeans.cuh @@ -719,6 +719,8 @@ void kmeans_fit( auto batch_sums = raft::make_device_matrix(handle, n_clusters, n_features); auto batch_counts = raft::make_device_vector(handle, n_clusters); + rmm::device_uvector batch_workspace(streaming_batch_size, stream); + cuvs::spatial::knn::detail::utils::batch_load_iterator data_batches( X.data_handle(), n_samples, n_features, streaming_batch_size, stream); cuvs::spatial::knn::detail::utils::batch_load_iterator weight_batches( @@ -804,9 +806,7 @@ void kmeans_fit( raft::matrix::fill(handle, centroid_sums.view(), DataT{0}); raft::matrix::fill(handle, weight_per_cluster.view(), DataT{0}); - raft::linalg::map(handle, - raft::make_device_scalar_view(clustering_cost.data_handle()), - raft::const_op{DataT{0}}); + raft::matrix::fill(handle, clustering_cost.view(), DataT{0}); auto centroids_const = raft::make_device_matrix_view( cur_centroids_ptr, n_clusters, n_features); @@ -871,6 +871,7 @@ void kmeans_fit( batch_sums.view(), batch_counts.view(), clustering_cost.view(), + batch_workspace, centroid_norms_opt); } if (!norms_cached && use_norm_cache) { diff --git a/cpp/src/cluster/detail/kmeans_common.cuh b/cpp/src/cluster/detail/kmeans_common.cuh index 97e8360d19..186a9c5882 100644 --- a/cpp/src/cluster/detail/kmeans_common.cuh +++ b/cpp/src/cluster/detail/kmeans_common.cuh @@ -659,10 +659,10 @@ void process_batch( raft::device_matrix_view batch_sums, raft::device_vector_view batch_counts, raft::device_scalar_view clustering_cost, + rmm::device_uvector& batch_workspace, std::optional> centroid_norms = std::nullopt) { - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - IndexT current_batch_sz = batch_data.extent(0); + cudaStream_t stream = raft::resource::get_cuda_stream(handle); minClusterAndDistanceCompute(handle, batch_data, @@ -681,9 +681,6 @@ void process_batch( const raft::KeyValuePair*> labels_itr(minClusterAndDistance.data_handle(), conversion_op); - auto batch_workspace = rmm::device_uvector( - current_batch_sz, stream, raft::resource::get_workspace_resource(handle)); - compute_centroid_adjustments(handle, batch_data, batch_weights, From e76eaac63b0138ee55948c12502fa72fd843a443 Mon Sep 17 00:00:00 2001 From: Tarang Jain <40517122+tarang-jain@users.noreply.github.com> Date: Wed, 22 Apr 2026 08:01:54 -0700 Subject: [PATCH 16/50] rel_tol Co-authored-by: Victor Lafargue --- cpp/src/cluster/detail/kmeans.cuh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cpp/src/cluster/detail/kmeans.cuh b/cpp/src/cluster/detail/kmeans.cuh index 109e4cbdfb..7e60c7422b 100644 --- a/cpp/src/cluster/detail/kmeans.cuh +++ b/cpp/src/cluster/detail/kmeans.cuh @@ -1035,7 +1035,8 @@ void kmeans_predict(raft::resources const& handle, if (normalize_weight) { DataT wt_sum = weightSum(handle, raft::make_const_mdspan(weight.view())); - if (wt_sum != static_cast(n_samples)) { + const DataT rel_tol = n_samples * std::numeric_limits::epsilon(); + if (std::abs(wt_sum - n_samples) <= rel_tol) { raft::linalg::map(handle, weight.view(), raft::compose_op(raft::mul_const_op{static_cast(n_samples)}, From afbefdf5e666b59fff2c00f312eaa08974a33347 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Wed, 22 Apr 2026 09:13:51 -0700 Subject: [PATCH 17/50] pass workspace --- cpp/src/cluster/detail/kmeans.cuh | 19 +++++++++++-------- cpp/src/cluster/kmeans_impl.cuh | 4 ++-- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans.cuh b/cpp/src/cluster/detail/kmeans.cuh index 109e4cbdfb..5740cccc2d 100644 --- a/cpp/src/cluster/detail/kmeans.cuh +++ b/cpp/src/cluster/detail/kmeans.cuh @@ -316,7 +316,8 @@ void kmeans_fit( sample_weight, raft::device_matrix_view centroids, raft::host_scalar_view inertia, - raft::host_scalar_view n_iter); + raft::host_scalar_view n_iter, + std::optional>> workspace = std::nullopt); /* * @brief Selects 'n_clusters' samples from X using scalable kmeans++ algorithm. @@ -534,7 +535,8 @@ void initScalableKMeansPlusPlus(raft::resources const& handle, weight_opt, centroidsRawData, inertia.view(), - n_iter.view()); + n_iter.view(), + std::ref(workspace)); } else if ((int)potentialCentroids.extent(0) < n_clusters) { // supplement with random @@ -602,7 +604,8 @@ void kmeans_fit( sample_weight, raft::device_matrix_view centroids, raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) + raft::host_scalar_view n_iter, + std::optional>> workspace) { raft::common::nvtx::range fun_scope("kmeans_fit"); auto n_samples = X.extent(0); @@ -634,7 +637,8 @@ void kmeans_fit( DataT wt_sum = sample_weight.has_value() ? weightSum(handle, sample_weight.value()) : static_cast(n_samples); - rmm::device_uvector workspace(0, stream); + rmm::device_uvector local_workspace(0, stream); + rmm::device_uvector& ws = workspace.has_value() ? workspace->get() : local_workspace; constexpr bool data_on_device = raft::is_device_mdspan_v; @@ -664,11 +668,10 @@ void kmeans_fit( } else if (iter_params.init == cuvs::cluster::kmeans::params::InitMethod::KMeansPlusPlus) { auto run_kmeanspp = [&](raft::device_matrix_view init_data) { if (iter_params.oversampling_factor == 0) - kmeansPlusPlus( - handle, iter_params, init_data, centroidsRawData, workspace); + kmeansPlusPlus(handle, iter_params, init_data, centroidsRawData, ws); else initScalableKMeansPlusPlus( - handle, iter_params, init_data, centroidsRawData, workspace); + handle, iter_params, init_data, centroidsRawData, ws); }; if constexpr (data_on_device) { @@ -865,7 +868,7 @@ void kmeans_fit( minCAD_view, l2_const_view, L2NormBuf_OR_DistBuf, - workspace, + ws, centroid_sums.view(), weight_per_cluster.view(), batch_sums.view(), diff --git a/cpp/src/cluster/kmeans_impl.cuh b/cpp/src/cluster/kmeans_impl.cuh index f521edd07f..9be6aac929 100644 --- a/cpp/src/cluster/kmeans_impl.cuh +++ b/cpp/src/cluster/kmeans_impl.cuh @@ -20,10 +20,10 @@ void fit_main(raft::resources const& handle, { cuvs::cluster::kmeans::params p = params; p.init = kmeans::params::InitMethod::Array; - p.n_init = 1; auto sw = std::make_optional( raft::make_device_vector_view(sample_weights.data_handle(), X.extent(0))); - cuvs::cluster::kmeans::detail::kmeans_fit(handle, p, X, sw, centroids, inertia, n_iter); + cuvs::cluster::kmeans::detail::kmeans_fit( + handle, p, X, sw, centroids, inertia, n_iter, std::ref(workspace)); } template From e4f08bf9b980f19929cb3e64a758f9437a797bfc Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Wed, 22 Apr 2026 09:15:10 -0700 Subject: [PATCH 18/50] style --- cpp/src/cluster/detail/kmeans.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/cluster/detail/kmeans.cuh b/cpp/src/cluster/detail/kmeans.cuh index add329f3a7..6b7ec54f35 100644 --- a/cpp/src/cluster/detail/kmeans.cuh +++ b/cpp/src/cluster/detail/kmeans.cuh @@ -1037,7 +1037,7 @@ void kmeans_predict(raft::resources const& handle, raft::matrix::fill(handle, weight.view(), DataT(1)); if (normalize_weight) { - DataT wt_sum = weightSum(handle, raft::make_const_mdspan(weight.view())); + DataT wt_sum = weightSum(handle, raft::make_const_mdspan(weight.view())); const DataT rel_tol = n_samples * std::numeric_limits::epsilon(); if (std::abs(wt_sum - n_samples) <= rel_tol) { raft::linalg::map(handle, From 4a8a85c7130cc058caf3e5664db331baf1cc8f62 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Wed, 22 Apr 2026 09:38:52 -0700 Subject: [PATCH 19/50] do not use batch scratch space; rm update_centroids --- cpp/src/cluster/detail/kmeans.cuh | 48 ------------------------ cpp/src/cluster/detail/kmeans_common.cuh | 32 +++++----------- cpp/src/cluster/kmeans.cuh | 33 ---------------- 3 files changed, 10 insertions(+), 103 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans.cuh b/cpp/src/cluster/detail/kmeans.cuh index 6b7ec54f35..4032b4cb14 100644 --- a/cpp/src/cluster/detail/kmeans.cuh +++ b/cpp/src/cluster/detail/kmeans.cuh @@ -262,50 +262,6 @@ void kmeansPlusPlus(raft::resources const& handle, } /// <<<< Step-5 >>> } -/** - * - * @tparam DataT - * @tparam IndexT - * @param handle - * @param[in] X input matrix (size n_samples, n_features) - * @param[in] weight number of samples currently assigned to each centroid - * @param[in] cur_centroids matrix of current centroids (size n_clusters, n_features) - * @param[in] l2norm_x - * @param[out] min_cluster_and_dist - * @param[out] new_centroids - * @param[out] new_weight - * @param[inout] workspace - */ -template -void update_centroids(raft::resources const& handle, - raft::device_matrix_view X, - raft::device_vector_view sample_weights, - raft::device_matrix_view centroids, - - // TODO: Figure out how to best wrap iterator types in mdspan - LabelsIterator cluster_labels, - raft::device_vector_view weight_per_cluster, - raft::device_matrix_view new_centroids, - rmm::device_uvector& workspace) -{ - auto n_clusters = centroids.extent(0); - - cuvs::cluster::kmeans::detail::compute_centroid_adjustments(handle, - X, - sample_weights, - cluster_labels, - static_cast(n_clusters), - new_centroids, - weight_per_cluster, - workspace); - - cuvs::cluster::kmeans::detail::finalize_centroids(handle, - raft::make_const_mdspan(new_centroids), - raft::make_const_mdspan(weight_per_cluster), - centroids, - new_centroids); -} - template void kmeans_fit( raft::resources const& handle, @@ -719,8 +675,6 @@ void kmeans_fit( auto weight_per_cluster = raft::make_device_vector(handle, n_clusters); auto centroid_norms_buf = raft::make_device_vector(handle, n_clusters); auto clustering_cost = raft::make_device_scalar(handle, DataT{0}); - auto batch_sums = raft::make_device_matrix(handle, n_clusters, n_features); - auto batch_counts = raft::make_device_vector(handle, n_clusters); rmm::device_uvector batch_workspace(streaming_batch_size, stream); @@ -871,8 +825,6 @@ void kmeans_fit( ws, centroid_sums.view(), weight_per_cluster.view(), - batch_sums.view(), - batch_counts.view(), clustering_cost.view(), batch_workspace, centroid_norms_opt); diff --git a/cpp/src/cluster/detail/kmeans_common.cuh b/cpp/src/cluster/detail/kmeans_common.cuh index 186a9c5882..0d9eac1916 100644 --- a/cpp/src/cluster/detail/kmeans_common.cuh +++ b/cpp/src/cluster/detail/kmeans_common.cuh @@ -512,7 +512,8 @@ void compute_centroid_adjustments( IndexT n_clusters, raft::device_matrix_view centroid_sums, raft::device_vector_view weight_per_cluster, - rmm::device_uvector& workspace) + rmm::device_uvector& workspace, + bool reset_sums = true) { cudaStream_t stream = raft::resource::get_cuda_stream(handle); auto n_samples = X.extent(0); @@ -528,7 +529,8 @@ void compute_centroid_adjustments( X.extent(1), n_clusters, centroid_sums.data_handle(), - stream); + stream, + reset_sums); raft::linalg::reduce_cols_by_key(sample_weights.data_handle(), cluster_labels, @@ -536,7 +538,8 @@ void compute_centroid_adjustments( static_cast(1), static_cast(n_samples), n_clusters, - stream); + stream, + reset_sums); } /** * @brief Finalize centroids by dividing accumulated sums by counts. @@ -635,8 +638,6 @@ DataT compute_centroid_shift(raft::resources const& handle, * @param[inout] workspace Resizable scratch * @param[inout] centroid_sums Running weighted sums [n_clusters x n_features] (added into) * @param[inout] weight_per_cluster Running weight counts [n_clusters] (added into) - * @param[inout] batch_sums Scratch for this batch [n_clusters x n_features] - * @param[inout] batch_counts Scratch for this batch [n_clusters] * @param[inout] clustering_cost Running cost scalar (device) (added into) * @param[in] centroid_norms Optional precomputed centroid norms [n_clusters]. * When provided, skips internal centroid norm computation. @@ -656,8 +657,6 @@ void process_batch( rmm::device_uvector& workspace, raft::device_matrix_view centroid_sums, raft::device_vector_view weight_per_cluster, - raft::device_matrix_view batch_sums, - raft::device_vector_view batch_counts, raft::device_scalar_view clustering_cost, rmm::device_uvector& batch_workspace, std::optional> centroid_norms = std::nullopt) @@ -686,21 +685,10 @@ void process_batch( batch_weights, labels_itr, static_cast(centroid_sums.extent(0)), - batch_sums, - batch_counts, - batch_workspace); - - raft::linalg::add(centroid_sums.data_handle(), - centroid_sums.data_handle(), - batch_sums.data_handle(), - centroid_sums.size(), - stream); - - raft::linalg::add(weight_per_cluster.data_handle(), - weight_per_cluster.data_handle(), - batch_counts.data_handle(), - weight_per_cluster.size(), - stream); + centroid_sums, + weight_per_cluster, + batch_workspace, + /*reset_sums=*/false); raft::linalg::map( handle, diff --git a/cpp/src/cluster/kmeans.cuh b/cpp/src/cluster/kmeans.cuh index e4f9821990..12d3cc2c4b 100644 --- a/cpp/src/cluster/kmeans.cuh +++ b/cpp/src/cluster/kmeans.cuh @@ -324,39 +324,6 @@ void cluster_cost(raft::resources const& handle, handle, minClusterDistance, workspace, clusterCost, raft::identity_op{}, reduction_op); } -/** - * @brief Update centroids given current centroids and number of points assigned to each centroid. - * This function also produces a vector of RAFT key/value pairs containing the cluster assignment - * for each point and its distance. - * - * @tparam DataT - * @tparam IndexT - * @param[in] handle: Raft handle to use for managing library resources - * @param[in] X: input matrix (size n_samples, n_features) - * @param[in] sample_weights: number of samples currently assigned to each centroid (size n_samples) - * @param[in] centroids: matrix of current centroids (size n_clusters, n_features) - * @param[in] labels: Iterator of labels (can also be a raw pointer) - * @param[out] weight_per_cluster: sum of sample weights per cluster (size n_clusters) - * @param[out] new_centroids: output matrix of updated centroids (size n_clusters, n_features) - */ -template -void update_centroids(raft::resources const& handle, - raft::device_matrix_view X, - raft::device_vector_view sample_weights, - raft::device_matrix_view centroids, - LabelsIterator labels, - raft::device_vector_view weight_per_cluster, - raft::device_matrix_view new_centroids) -{ - // TODO: Passing these into the algorithm doesn't really present much of a benefit - // because they are being resized anyways. - // ref https://github.com/rapidsai/raft/issues/930 - rmm::device_uvector workspace(0, raft::resource::get_cuda_stream(handle)); - - cuvs::cluster::kmeans::detail::update_centroids( - handle, X, sample_weights, centroids, labels, weight_per_cluster, new_centroids, workspace); -} - /** * @brief Compute distance for every sample to it's nearest centroid * From bbf2a9fc38de8b02f80493fae086950b8a297016 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Wed, 22 Apr 2026 09:50:20 -0700 Subject: [PATCH 20/50] move the debug log --- cpp/src/cluster/detail/kmeans.cuh | 10 +++++++++- cpp/src/cluster/detail/kmeans_common.cuh | 7 ------- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans.cuh b/cpp/src/cluster/detail/kmeans.cuh index 4032b4cb14..f2ccbafd33 100644 --- a/cpp/src/cluster/detail/kmeans.cuh +++ b/cpp/src/cluster/detail/kmeans.cuh @@ -592,6 +592,11 @@ void kmeans_fit( sample_weight.has_value() ? sample_weight.value().data_handle() : nullptr; DataT wt_sum = sample_weight.has_value() ? weightSum(handle, sample_weight.value()) : static_cast(n_samples); + if (sample_weight.has_value() && wt_sum != static_cast(n_samples)) { + RAFT_LOG_DEBUG( + "[Warning!] KMeans: normalizing the user provided sample weight to sum up to %zu samples", + static_cast(n_samples)); + } rmm::device_uvector local_workspace(0, stream); rmm::device_uvector& ws = workspace.has_value() ? workspace->get() : local_workspace; @@ -991,7 +996,10 @@ void kmeans_predict(raft::resources const& handle, if (normalize_weight) { DataT wt_sum = weightSum(handle, raft::make_const_mdspan(weight.view())); const DataT rel_tol = n_samples * std::numeric_limits::epsilon(); - if (std::abs(wt_sum - n_samples) <= rel_tol) { + if (std::abs(wt_sum - n_samples) > rel_tol) { + RAFT_LOG_DEBUG( + "[Warning!] KMeans: normalizing the user provided sample weight to sum up to %zu samples", + static_cast(n_samples)); raft::linalg::map(handle, weight.view(), raft::compose_op(raft::mul_const_op{static_cast(n_samples)}, diff --git a/cpp/src/cluster/detail/kmeans_common.cuh b/cpp/src/cluster/detail/kmeans_common.cuh index 0d9eac1916..01506df227 100644 --- a/cpp/src/cluster/detail/kmeans_common.cuh +++ b/cpp/src/cluster/detail/kmeans_common.cuh @@ -162,13 +162,6 @@ DataT weightSum( wt_sum += weight(i); } } - - if (wt_sum != ns) { - RAFT_LOG_DEBUG( - "[Warning!] KMeans: normalizing the user provided sample weight to " - "sum up to %zu samples", - static_cast(n_samples)); - } return wt_sum; } From 410092c9b573c5b96ceb71122317200a2d90dc56 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Wed, 22 Apr 2026 11:57:05 -0700 Subject: [PATCH 21/50] add new suffixed param struct --- c/include/cuvs/cluster/kmeans.h | 80 +++++++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) diff --git a/c/include/cuvs/cluster/kmeans.h b/c/include/cuvs/cluster/kmeans.h index fbc6877a00..3f3f487590 100644 --- a/c/include/cuvs/cluster/kmeans.h +++ b/c/include/cuvs/cluster/kmeans.h @@ -39,6 +39,8 @@ typedef enum { /** * @brief Hyper-parameters for the kmeans algorithm + * NB: The inertia_check field is kept for ABI compatibility. Removed in cuvsKMeansParams_v1. + * CalVer for the replacement: 26.08 */ struct cuvsKMeansParams { cuvsDistanceType metric; @@ -118,6 +120,84 @@ struct cuvsKMeansParams { int64_t init_size; }; +/** + * @brief Hyper-parameters for the kmeans algorithm + */ + struct cuvsKMeansParams_v1 { + cuvsDistanceType metric; + + /** + * The number of clusters to form as well as the number of centroids to generate (default:8). + */ + int n_clusters; + + /** + * Method for initialization, defaults to k-means++: + * - cuvsKMeansInitMethod::KMeansPlusPlus (k-means++): Use scalable k-means++ algorithm + * to select the initial cluster centers. + * - cuvsKMeansInitMethod::Random (random): Choose 'n_clusters' observations (rows) at + * random from the input data for the initial centroids. + * - cuvsKMeansInitMethod::Array (ndarray): Use 'centroids' as initial cluster centers. + */ + cuvsKMeansInitMethod init; + + /** + * Maximum number of iterations of the k-means algorithm for a single run. + */ + int max_iter; + + /** + * Relative tolerance with regards to inertia to declare convergence. + */ + double tol; + + /** + * Number of instance k-means algorithm will be run with different seeds. + */ + int n_init; + + /** + * Oversampling factor for use in the k-means|| algorithm + */ + double oversampling_factor; + + /** + * batch_samples and batch_centroids are used to tile 1NN computation which is + * useful to optimize/control the memory footprint + * Default tile is [batch_samples x n_clusters] i.e. when batch_centroids is 0 + * then don't tile the centroids + */ + int batch_samples; + + /** + * if 0 then batch_centroids = n_clusters + */ + int batch_centroids; + + /** + * Whether to use hierarchical (balanced) kmeans or not + */ + bool hierarchical; + + /** + * For hierarchical k-means , defines the number of training iterations + */ + int hierarchical_n_iters; + + /** + * Number of samples to process per GPU batch for the batched (host-data) API. + * When set to 0, defaults to n_samples (process all at once). + */ + int64_t streaming_batch_size; + + /** + * Number of samples to draw for KMeansPlusPlus initialization. + * When set to 0, uses heuristic min(3 * n_clusters, n_samples) for host data, + * or n_samples for device data. + */ + int64_t init_size; +}; + typedef struct cuvsKMeansParams* cuvsKMeansParams_t; /** From c515c1ed6b382d6aa509846863c646b99c344199 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Wed, 22 Apr 2026 14:36:25 -0700 Subject: [PATCH 22/50] address pr reviews --- cpp/include/cuvs/cluster/kmeans.hpp | 10 +++-- cpp/src/cluster/detail/kmeans.cuh | 50 +++++++++++++++------- cpp/src/cluster/detail/kmeans_common.cuh | 22 ++++++---- cpp/src/cluster/detail/kmeans_mg.cuh | 1 + python/cuvs/cuvs/cluster/kmeans/kmeans.pyx | 14 ++++++ 5 files changed, 70 insertions(+), 27 deletions(-) diff --git a/cpp/include/cuvs/cluster/kmeans.hpp b/cpp/include/cuvs/cluster/kmeans.hpp index ff7d056f7d..36e6f46c51 100644 --- a/cpp/include/cuvs/cluster/kmeans.hpp +++ b/cpp/include/cuvs/cluster/kmeans.hpp @@ -115,9 +115,13 @@ struct params : base_params { /** * Number of samples to randomly draw for the KMeansPlusPlus initialization * step. A random subset of this size is used for centroid seeding. - * When set to 0 the default depends on the data location: - * - Device data: n_samples (use the full dataset). - * - Host data: min(3 * n_clusters, n_samples). + * + * Only applies when dataset is on host; for device data the full dataset + * is always used for seeding and this parameter is ignored. + * + * When set to 0 (default) with host data uses `min(3 * n_clusters, n_samples)` + * as a default. + * * Default: 0. */ int64_t init_size = 0; diff --git a/cpp/src/cluster/detail/kmeans.cuh b/cpp/src/cluster/detail/kmeans.cuh index f2ccbafd33..a1b94adf5c 100644 --- a/cpp/src/cluster/detail/kmeans.cuh +++ b/cpp/src/cluster/detail/kmeans.cuh @@ -590,9 +590,12 @@ void kmeans_fit( const DataT* weight_ptr = sample_weight.has_value() ? sample_weight.value().data_handle() : nullptr; - DataT wt_sum = sample_weight.has_value() ? weightSum(handle, sample_weight.value()) - : static_cast(n_samples); - if (sample_weight.has_value() && wt_sum != static_cast(n_samples)) { + DataT wt_sum = sample_weight.has_value() ? weightSum(handle, sample_weight.value()) + : static_cast(n_samples); + const DataT wt_rel_tol = n_samples * std::numeric_limits::epsilon(); + const bool needs_wt_rescale = + sample_weight.has_value() && std::abs(wt_sum - static_cast(n_samples)) > wt_rel_tol; + if (needs_wt_rescale) { RAFT_LOG_DEBUG( "[Warning!] KMeans: normalizing the user provided sample weight to sum up to %zu samples", static_cast(n_samples)); @@ -612,6 +615,29 @@ void kmeans_fit( streaming_batch_size = static_cast(n_samples); } + // Preallocate the host-side KMeans++ init sample buffer. + std::optional> init_sample; + if constexpr (!data_on_device) { + if (pams.init == cuvs::cluster::kmeans::params::InitMethod::KMeansPlusPlus) { + IndexT default_init_size = + std::min(static_cast(std::int64_t{3} * n_clusters), n_samples); + IndexT init_sample_size = pams.init_size > 0 + ? std::min(static_cast(pams.init_size), n_samples) + : default_init_size; + + if (pams.init_size <= 0 && init_sample_size < n_samples) { + RAFT_LOG_WARN( + "KMeans.fit: KMeans++ initialization is using a random subsample of %zu/%zu host rows " + "(params.init_size=0 defaults to min(3 * n_clusters, n_samples) for host data). " + "Set params.init_size to n_samples to use the full dataset for seeding.", + static_cast(init_sample_size), + static_cast(n_samples)); + } + + init_sample = raft::make_device_matrix(handle, init_sample_size, n_features); + } + } + auto init_centroids = [&](const cuvs::cluster::kmeans::params& iter_params, raft::device_matrix_view centroidsRawData) { if (iter_params.init == cuvs::cluster::kmeans::params::InitMethod::Array) { @@ -638,17 +664,8 @@ void kmeans_fit( if constexpr (data_on_device) { run_kmeanspp(X); } else { - IndexT default_init_size = - std::min(static_cast(std::int64_t{3} * n_clusters), n_samples); - IndexT init_sample_size = - iter_params.init_size > 0 - ? std::min(static_cast(iter_params.init_size), n_samples) - : default_init_size; - - auto init_sample = - raft::make_device_matrix(handle, init_sample_size, n_features); - raft::matrix::sample_rows(handle, random_state, X, init_sample.view()); - run_kmeanspp(raft::make_const_mdspan(init_sample.view())); + raft::matrix::sample_rows(handle, random_state, X, init_sample->view()); + run_kmeanspp(raft::make_const_mdspan(init_sample->view())); } } else { THROW("unknown initialization method to select initial centers"); @@ -693,7 +710,7 @@ void kmeans_fit( auto prepare_batch_weights = [&](const auto& wt_batch, IndexT cur_batch_size) { if (weight_ptr != nullptr) { raft::copy(batch_weights_buf.data_handle(), wt_batch.data(), cur_batch_size, stream); - if (wt_sum != static_cast(n_samples)) { + if (needs_wt_rescale) { auto bw = raft::make_device_vector_view(batch_weights_buf.data_handle(), cur_batch_size); raft::linalg::map(handle, @@ -994,7 +1011,8 @@ void kmeans_predict(raft::resources const& handle, raft::matrix::fill(handle, weight.view(), DataT(1)); if (normalize_weight) { - DataT wt_sum = weightSum(handle, raft::make_const_mdspan(weight.view())); + DataT wt_sum = weightSum(handle, raft::make_const_mdspan(weight.view())); + RAFT_EXPECTS(wt_sum > DataT{0}, "invalid parameter (sum of sample weights must be positive)"); const DataT rel_tol = n_samples * std::numeric_limits::epsilon(); if (std::abs(wt_sum - n_samples) > rel_tol) { RAFT_LOG_DEBUG( diff --git a/cpp/src/cluster/detail/kmeans_common.cuh b/cpp/src/cluster/detail/kmeans_common.cuh index 01506df227..4dd75fb0f2 100644 --- a/cpp/src/cluster/detail/kmeans_common.cuh +++ b/cpp/src/cluster/detail/kmeans_common.cuh @@ -487,14 +487,20 @@ void countSamplesInCluster(raft::resources const& handle, * @tparam IndexT Index type * @tparam LabelsIterator Iterator type for cluster labels * - * @param[in] handle RAFT resources handle - * @param[in] X Input samples [n_samples x n_features] - * @param[in] sample_weights Weights for each sample [n_samples] - * @param[in] cluster_labels Cluster assignment for each sample (iterator) - * @param[in] n_clusters Number of clusters - * @param[out] centroid_sums Output weighted sum per cluster [n_clusters x n_features] - * @param[out] weight_per_cluster Output sum of weights per cluster [n_clusters] - * @param[inout] workspace Workspace buffer for intermediate operations + * @param[in] handle RAFT resources handle + * @param[in] X Input samples [n_samples x n_features] + * @param[in] sample_weights Weights for each sample [n_samples] + * @param[in] cluster_labels Cluster assignment for each sample (iterator) + * @param[in] n_clusters Number of clusters + * @param[inout] centroid_sums Weighted sum per cluster [n_clusters x n_features]. When + * @p reset_sums is true this is overwritten; otherwise the + * contribution from @p X is added to the existing contents. + * @param[inout] weight_per_cluster Sum of weights per cluster [n_clusters]. Follows the same + * overwrite-vs-accumulate semantics as @p centroid_sums. + * @param[inout] workspace Workspace buffer for intermediate operations + * @param[in] reset_sums If true (default), outputs are reset to zero before reducing; + * if false, this call's contribution is accumulated into the + * existing values (used by `process_batch` for streaming fits). */ template void compute_centroid_adjustments( diff --git a/cpp/src/cluster/detail/kmeans_mg.cuh b/cpp/src/cluster/detail/kmeans_mg.cuh index 2c8bf19018..6f8a1054af 100644 --- a/cpp/src/cluster/detail/kmeans_mg.cuh +++ b/cpp/src/cluster/detail/kmeans_mg.cuh @@ -478,6 +478,7 @@ void checkWeights(const raft::resources& handle, stream); DataT wt_sum = wt_aggr.value(stream); raft::resource::sync_stream(handle, stream); + RAFT_EXPECTS(wt_sum > DataT{0}, "invalid parameter (sum of sample weights must be positive)"); if (wt_sum != n_samples) { CUVS_LOG_KMEANS(handle, diff --git a/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx b/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx index 246ac4138c..b192d71058 100644 --- a/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx +++ b/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx @@ -4,6 +4,7 @@ # # cython: language_level=3 +import warnings from collections import namedtuple import numpy as np @@ -76,6 +77,10 @@ cdef class KMeansParams: [batch_samples x n_clusters]. batch_centroids : int Number of centroids to process in each batch. If 0, uses n_clusters. + inertia_check : bool + Deprecated and ignored. Will be + removed in a future release. Inertia-based convergence checking + always runs. init_size : int Number of samples to draw for KMeansPlusPlus initialization with host (out-of-core) data. When set to 0, uses the heuristic @@ -113,6 +118,7 @@ cdef class KMeansParams: oversampling_factor=None, batch_samples=None, batch_centroids=None, + inertia_check=None, init_size=None, streaming_batch_size=None, hierarchical=None, @@ -136,6 +142,14 @@ cdef class KMeansParams: self.params.batch_samples = batch_samples if batch_centroids is not None: self.params.batch_centroids = batch_centroids + if inertia_check is not None: + warnings.warn( + "KMeansParams `inertia_check` is deprecated and ignored; " + "inertia-based convergence checking always runs." + DeprecationWarning, + stacklevel=2, + ) + self.params.inertia_check = inertia_check if init_size is not None: self.params.init_size = init_size if streaming_batch_size is not None: From e8e63ab5f05d8cc5be46f39a0e7ca447ea5b904c Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Wed, 22 Apr 2026 14:42:34 -0700 Subject: [PATCH 23/50] fix docstring --- cpp/src/cluster/detail/kmeans_common.cuh | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans_common.cuh b/cpp/src/cluster/detail/kmeans_common.cuh index 4dd75fb0f2..cfd500d036 100644 --- a/cpp/src/cluster/detail/kmeans_common.cuh +++ b/cpp/src/cluster/detail/kmeans_common.cuh @@ -492,15 +492,13 @@ void countSamplesInCluster(raft::resources const& handle, * @param[in] sample_weights Weights for each sample [n_samples] * @param[in] cluster_labels Cluster assignment for each sample (iterator) * @param[in] n_clusters Number of clusters - * @param[inout] centroid_sums Weighted sum per cluster [n_clusters x n_features]. When - * @p reset_sums is true this is overwritten; otherwise the - * contribution from @p X is added to the existing contents. + * @param[inout] centroid_sums Weighted sum per cluster [n_clusters x n_features] * @param[inout] weight_per_cluster Sum of weights per cluster [n_clusters]. Follows the same - * overwrite-vs-accumulate semantics as @p centroid_sums. + * overwrite-vs-accumulate semantics as `centroid_sums` * @param[inout] workspace Workspace buffer for intermediate operations * @param[in] reset_sums If true (default), outputs are reset to zero before reducing; * if false, this call's contribution is accumulated into the - * existing values (used by `process_batch` for streaming fits). + * existing `centroid_sums` */ template void compute_centroid_adjustments( From 30c457cc589048d8aa83fe2ba126cf9c2ff3c215 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Wed, 22 Apr 2026 14:53:27 -0700 Subject: [PATCH 24/50] fix wt_sum warning --- cpp/src/cluster/detail/kmeans.cuh | 3 +-- cpp/src/cluster/detail/kmeans_common.cuh | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans.cuh b/cpp/src/cluster/detail/kmeans.cuh index a1b94adf5c..540b3fd4c2 100644 --- a/cpp/src/cluster/detail/kmeans.cuh +++ b/cpp/src/cluster/detail/kmeans.cuh @@ -1011,8 +1011,7 @@ void kmeans_predict(raft::resources const& handle, raft::matrix::fill(handle, weight.view(), DataT(1)); if (normalize_weight) { - DataT wt_sum = weightSum(handle, raft::make_const_mdspan(weight.view())); - RAFT_EXPECTS(wt_sum > DataT{0}, "invalid parameter (sum of sample weights must be positive)"); + DataT wt_sum = weightSum(handle, raft::make_const_mdspan(weight.view())); const DataT rel_tol = n_samples * std::numeric_limits::epsilon(); if (std::abs(wt_sum - n_samples) > rel_tol) { RAFT_LOG_DEBUG( diff --git a/cpp/src/cluster/detail/kmeans_common.cuh b/cpp/src/cluster/detail/kmeans_common.cuh index cfd500d036..8af6114b4b 100644 --- a/cpp/src/cluster/detail/kmeans_common.cuh +++ b/cpp/src/cluster/detail/kmeans_common.cuh @@ -147,7 +147,6 @@ DataT weightSum( raft::mdspan, raft::layout_right, Accessor> weight) { auto n_samples = weight.extent(0); - auto ns = static_cast(n_samples); DataT wt_sum = DataT{0}; if constexpr (raft::is_device_mdspan_v) { @@ -162,6 +161,7 @@ DataT weightSum( wt_sum += weight(i); } } + RAFT_EXPECTS(wt_sum > DataT{0}, "invalid parameter (sum of sample weights must be positive)"); return wt_sum; } From ab966239611889d286e69e026fea5e71c93a9c2f Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Wed, 22 Apr 2026 15:00:25 -0700 Subject: [PATCH 25/50] rm deprecationwarning and instead add FutureWarning:= --- python/cuvs/cuvs/cluster/kmeans/kmeans.pyx | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx b/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx index b192d71058..2e9046b4b2 100644 --- a/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx +++ b/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx @@ -145,11 +145,9 @@ cdef class KMeansParams: if inertia_check is not None: warnings.warn( "KMeansParams `inertia_check` is deprecated and ignored; " - "inertia-based convergence checking always runs." - DeprecationWarning, - stacklevel=2, + "inertia-based convergence checking always runs.", + FutureWarning ) - self.params.inertia_check = inertia_check if init_size is not None: self.params.init_size = init_size if streaming_batch_size is not None: From 269f23cde51cc8bc5929b313377e56b7f953a8f4 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Wed, 22 Apr 2026 15:06:44 -0700 Subject: [PATCH 26/50] unweighted to never materialize batch weights --- cpp/src/cluster/detail/kmeans.cuh | 55 +++++++++++++++++++++---------- 1 file changed, 38 insertions(+), 17 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans.cuh b/cpp/src/cluster/detail/kmeans.cuh index 540b3fd4c2..4485ffd3e6 100644 --- a/cpp/src/cluster/detail/kmeans.cuh +++ b/cpp/src/cluster/detail/kmeans.cuh @@ -702,14 +702,21 @@ void kmeans_fit( cuvs::spatial::knn::detail::utils::batch_load_iterator data_batches( X.data_handle(), n_samples, n_features, streaming_batch_size, stream); - cuvs::spatial::knn::detail::utils::batch_load_iterator weight_batches( - weight_ptr, n_samples, 1, streaming_batch_size, stream); - - if (weight_ptr == nullptr) { raft::matrix::fill(handle, batch_weights_buf.view(), DataT{1}); } + // Only materialize weight batches when weights are provided; otherwise we + // never touch this iterator (and never dereference a null-pointer batch). + std::optional> weight_batches; + if (weight_ptr != nullptr) { + weight_batches.emplace(weight_ptr, n_samples, 1, streaming_batch_size, stream); + } else { + raft::matrix::fill(handle, batch_weights_buf.view(), DataT{1}); + } - auto prepare_batch_weights = [&](const auto& wt_batch, IndexT cur_batch_size) { - if (weight_ptr != nullptr) { - raft::copy(batch_weights_buf.data_handle(), wt_batch.data(), cur_batch_size, stream); + // Copies and optionally rescales `wt_data` into `batch_weights_buf`. When + // `wt_data == nullptr`, the buffer is assumed to have been pre-filled with 1 + // for the unweighted case. + auto prepare_batch_weights = [&](const DataT* wt_data, IndexT cur_batch_size) { + if (wt_data != nullptr) { + raft::copy(batch_weights_buf.data_handle(), wt_data, cur_batch_size, stream); if (needs_wt_rescale) { auto bw = raft::make_device_vector_view(batch_weights_buf.data_handle(), cur_batch_size); @@ -802,16 +809,23 @@ void kmeans_fit( } data_batches.reset(); - weight_batches.reset(); - auto wt_it = weight_batches.begin(); + using wt_iter_t = cuvs::spatial::knn::detail::utils::batch_load_iterator; + std::optional wt_it; + if (weight_batches.has_value()) { + weight_batches->reset(); + wt_it = weight_batches->begin(); + } for (const auto& data_batch : data_batches) { IndexT cur_batch_size = static_cast(data_batch.size()); - const auto& wt_batch = *wt_it; - ++wt_it; + const DataT* wt_data = nullptr; + if (wt_it.has_value()) { + wt_data = (**wt_it).data(); + ++(*wt_it); + } auto batch_data_view = raft::make_device_matrix_view( data_batch.data(), cur_batch_size, n_features); - auto batch_weights_view = prepare_batch_weights(wt_batch, cur_batch_size); + auto batch_weights_view = prepare_batch_weights(wt_data, cur_batch_size); auto minCAD_view = raft::make_device_vector_view, IndexT>( minClusterAndDistance.data_handle(), cur_batch_size); @@ -898,18 +912,25 @@ void kmeans_fit( iter_inertia = DataT{0}; data_batches.reset(); - weight_batches.reset(); - auto wt_it = weight_batches.begin(); + using wt_iter_t = cuvs::spatial::knn::detail::utils::batch_load_iterator; + std::optional wt_it; + if (weight_batches.has_value()) { + weight_batches->reset(); + wt_it = weight_batches->begin(); + } for (const auto& data_batch : data_batches) { IndexT cur_batch_size = static_cast(data_batch.size()); - const auto& wt_batch = *wt_it; - ++wt_it; + const DataT* wt_data = nullptr; + if (wt_it.has_value()) { + wt_data = (**wt_it).data(); + ++(*wt_it); + } auto batch_data_view = raft::make_device_matrix_view( data_batch.data(), cur_batch_size, n_features); std::optional> batch_sw = std::nullopt; - if (weight_ptr != nullptr) { batch_sw = prepare_batch_weights(wt_batch, cur_batch_size); } + if (weight_ptr != nullptr) { batch_sw = prepare_batch_weights(wt_data, cur_batch_size); } DataT batch_cost = DataT{0}; cuvs::cluster::kmeans::cluster_cost(handle, From 80a22ca572546c6d0fca93a9595e2e6cb6a6b14f Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Wed, 22 Apr 2026 19:11:33 -0700 Subject: [PATCH 27/50] add cpp tests --- cpp/tests/cluster/kmeans.cu | 245 ++++++++++++++++++++++++++++++------ 1 file changed, 207 insertions(+), 38 deletions(-) diff --git a/cpp/tests/cluster/kmeans.cu b/cpp/tests/cluster/kmeans.cu index 5d48ef099e..1213d2a18f 100644 --- a/cpp/tests/cluster/kmeans.cu +++ b/cpp/tests/cluster/kmeans.cu @@ -365,30 +365,27 @@ class KmeansFitBatchedTest : public ::testing::TestWithParam>::GetParam(); + testparams = ::testing::TestWithParam>::GetParam(); + int n_samples = testparams.n_row; + int n_features = testparams.n_col; + int n_clusters = testparams.n_clusters; + auto stream = raft::resource::get_cuda_stream(handle); - int n_samples = testparams.n_row; - int n_features = testparams.n_col; - params.n_clusters = testparams.n_clusters; - params.tol = testparams.tol; - params.rng_state.seed = 1; - params.oversampling_factor = 0; - - auto stream = raft::resource::get_cuda_stream(handle); - auto X = raft::make_device_matrix(handle, n_samples, n_features); - auto labels = raft::make_device_vector(handle, n_samples); - - raft::random::make_blobs(X.data_handle(), - labels.data_handle(), + d_X.resize(static_cast(n_samples) * n_features, stream); + d_labels_ref.resize(n_samples, stream); + raft::random::make_blobs(d_X.data(), + d_labels_ref.data(), n_samples, n_features, - params.n_clusters, + n_clusters, stream, true, nullptr, @@ -399,36 +396,63 @@ class KmeansFitBatchedTest : public ::testing::TestWithParam h_X(n_samples * n_features); - raft::update_host(h_X.data(), X.data_handle(), n_samples * n_features, stream); + h_X.resize(static_cast(n_samples) * n_features); + raft::update_host(h_X.data(), d_X.data(), h_X.size(), stream); + + if (testparams.weighted) { + d_sample_weight.resize(n_samples, stream); + raft::matrix::fill( + handle, raft::make_device_vector_view(d_sample_weight.data(), n_samples), T(1)); + } else { + d_sample_weight.resize(0, stream); + } + raft::resource::sync_stream(handle, stream); + } + + raft::device_matrix_view X_dview() const + { + return raft::make_device_matrix_view( + d_X.data(), testparams.n_row, testparams.n_col); + } + + raft::host_matrix_view h_X_view() const + { + return raft::make_host_matrix_view( + h_X.data(), testparams.n_row, testparams.n_col); + } + + std::optional> d_sw_view() const + { + if (!testparams.weighted) return std::nullopt; + return std::make_optional( + raft::make_device_vector_view(d_sample_weight.data(), testparams.n_row)); + } + + void fitBatchedTest() + { + int n_samples = testparams.n_row; + int n_features = testparams.n_col; + params.n_clusters = testparams.n_clusters; + params.tol = testparams.tol; + params.rng_state.seed = 1; + params.oversampling_factor = 0; + + auto stream = raft::resource::get_cuda_stream(handle); d_labels.resize(n_samples, stream); - d_labels_ref.resize(n_samples, stream); d_centroids.resize(params.n_clusters * n_features, stream); d_centroids_ref.resize(params.n_clusters * n_features, stream); - raft::copy(d_labels_ref.data(), labels.data_handle(), n_samples, stream); raft::random::RngState rng(params.rng_state.seed); raft::random::uniform( handle, rng, d_centroids.data(), params.n_clusters * n_features, T(-1), T(1)); raft::copy(d_centroids_ref.data(), d_centroids.data(), params.n_clusters * n_features, stream); - auto h_X_view = - raft::make_host_matrix_view(h_X.data(), n_samples, n_features); auto d_centroids_view = raft::make_device_matrix_view(d_centroids.data(), params.n_clusters, n_features); - std::optional> d_sw = std::nullopt; - rmm::device_uvector d_sample_weight(0, stream); - if (testparams.weighted) { - d_sample_weight.resize(n_samples, stream); - d_sw = std::make_optional( - raft::make_device_vector_view(d_sample_weight.data(), n_samples)); - raft::matrix::fill( - handle, raft::make_device_vector_view(d_sample_weight.data(), n_samples), T(1)); - } + auto d_sw = d_sw_view(); auto d_centroids_ref_view = raft::make_device_matrix_view(d_centroids_ref.data(), params.n_clusters, n_features); @@ -440,7 +464,7 @@ class KmeansFitBatchedTest : public ::testing::TestWithParam(&ref_inertia), @@ -462,7 +486,7 @@ class KmeansFitBatchedTest : public ::testing::TestWithParam(&inertia), @@ -481,7 +505,7 @@ class KmeansFitBatchedTest : public ::testing::TestWithParam( d_centroids_ref.data(), params.n_clusters, n_features), @@ -493,7 +517,7 @@ class KmeansFitBatchedTest : public ::testing::TestWithParam( d_centroids.data(), params.n_clusters, n_features), @@ -523,7 +547,138 @@ class KmeansFitBatchedTest : public ::testing::TestWithParam(handle, n_clusters, n_features); + T inertia = 0; + int64_t n_iter = 0; + cuvs::cluster::kmeans::fit( + handle, + p, + h_X_view(), + std::optional>{std::nullopt}, + d_centroids_buf.view(), + raft::make_host_scalar_view(&inertia), + raft::make_host_scalar_view(&n_iter)); + raft::resource::sync_stream(handle, stream); + return inertia; + } + + void runInitSizeCompare() + { + T inertia_default = fitHostWithInitSize(0); + T inertia_full = fitHostWithInitSize(testparams.n_row); + + ASSERT_TRUE(std::isfinite(inertia_default)); + ASSERT_TRUE(std::isfinite(inertia_full)); + ASSERT_GT(inertia_default, T(0)); + ASSERT_GT(inertia_full, T(0)); + + // Full-dataset seeding has at least as much information as the subsample + // default, so the converged inertia should not be worse (modulo reduction + // -order noise from the nondeterministic atomics used in the fused NN + // kernel, worst observed ~8.6e-7 in single precision). + const T rel = T(1e-5); + ASSERT_LE(inertia_full, inertia_default * (T(1) + rel)); + } + + T fitKMeansPlusPlus(int n_init_value) + { + int n_features = testparams.n_col; + int n_clusters = testparams.n_clusters; + auto stream = raft::resource::get_cuda_stream(handle); + + cuvs::cluster::kmeans::params p; + p.n_clusters = n_clusters; + p.tol = testparams.tol; + p.n_init = n_init_value; + p.init = cuvs::cluster::kmeans::params::KMeansPlusPlus; + p.max_iter = 10; + p.rng_state.seed = 7; + p.oversampling_factor = 0; + + auto d_centroids_buf = raft::make_device_matrix(handle, n_clusters, n_features); + T inertia = 0; + int n_iter = 0; + cuvs::cluster::kmeans::fit(handle, + p, + X_dview(), + std::optional>{std::nullopt}, + d_centroids_buf.view(), + raft::make_host_scalar_view(&inertia), + raft::make_host_scalar_view(&n_iter)); + raft::resource::sync_stream(handle, stream); + return inertia; + } + + void runMultiSeedCheck() + { + T inertia1 = fitKMeansPlusPlus(1); + T inertia3 = fitKMeansPlusPlus(3); + ASSERT_GT(inertia1, T(0)); + ASSERT_GT(inertia3, T(0)); + // n_init > 1 keeps the best trial, so it should not be worse than + // n_init == 1 (modulo reduction-order noise from the nondeterministic + // atomics used in the fused NN kernel). + const T rel = T(1e-5); + ASSERT_LE(inertia3, inertia1 * (T(1) + rel)); + } + + void runZeroCost() + { + int n_samples = testparams.n_row; + int n_features = testparams.n_col; + int n_clusters = n_samples; + auto stream = raft::resource::get_cuda_stream(handle); + + auto d_centroids_buf = raft::make_device_matrix(handle, n_clusters, n_features); + raft::copy(d_centroids_buf.data_handle(), + d_X.data(), + static_cast(n_samples) * n_features, + stream); + + cuvs::cluster::kmeans::params p; + p.n_clusters = n_clusters; + p.tol = testparams.tol; + p.n_init = 1; + p.init = cuvs::cluster::kmeans::params::Array; + p.max_iter = 5; + p.rng_state.seed = 1; + p.oversampling_factor = 0; + + T inertia = 0; + int n_iter = 0; + ASSERT_NO_THROW(cuvs::cluster::kmeans::fit( + handle, + p, + X_dview(), + std::optional>{std::nullopt}, + d_centroids_buf.view(), + raft::make_host_scalar_view(&inertia), + raft::make_host_scalar_view(&n_iter))); + raft::resource::sync_stream(handle, stream); + + // The expanded L2 formula ||x||^2 - 2 x.c + ||c||^2 does not cancel to + // exactly 0 even when x == c due to float roundoff. The largest residual + // we observe across our parameterized blob shapes is ~2 in single + // precision; use an absolute upper bound of 10 for headroom. + ASSERT_LE(inertia, T(10)); + } protected: raft::resources handle; @@ -532,6 +687,9 @@ class KmeansFitBatchedTest : public ::testing::TestWithParam d_labels_ref; rmm::device_uvector d_centroids; rmm::device_uvector d_centroids_ref; + rmm::device_uvector d_X; + rmm::device_uvector d_sample_weight; + std::vector h_X; double score; testing::AssertionResult centroids_match = testing::AssertionSuccess(); bool inertia_match = false; @@ -568,16 +726,26 @@ typedef KmeansFitBatchedTest KmeansFitBatchedTestD; TEST_P(KmeansFitBatchedTestF, Result) { + prepareBlobInputs(); + fitBatchedTest(); ASSERT_TRUE(centroids_match); ASSERT_TRUE(score >= 0.99); ASSERT_TRUE(inertia_match); + runInitSizeCompare(); + runMultiSeedCheck(); + runZeroCost(); } TEST_P(KmeansFitBatchedTestD, Result) { + prepareBlobInputs(); + fitBatchedTest(); ASSERT_TRUE(centroids_match); ASSERT_TRUE(score >= 0.99); ASSERT_TRUE(inertia_match); + runInitSizeCompare(); + runMultiSeedCheck(); + runZeroCost(); } INSTANTIATE_TEST_CASE_P(KmeansFitBatchedTests, @@ -586,4 +754,5 @@ INSTANTIATE_TEST_CASE_P(KmeansFitBatchedTests, INSTANTIATE_TEST_CASE_P(KmeansFitBatchedTests, KmeansFitBatchedTestD, ::testing::ValuesIn(batched_inputsd2)); + } // namespace cuvs From ac06b0570b9726d2aa48c1d635bfd7ba81daceef Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Wed, 22 Apr 2026 19:22:40 -0700 Subject: [PATCH 28/50] update cpp tests --- cpp/tests/cluster/kmeans.cu | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/cpp/tests/cluster/kmeans.cu b/cpp/tests/cluster/kmeans.cu index 1213d2a18f..b35b9d38d9 100644 --- a/cpp/tests/cluster/kmeans.cu +++ b/cpp/tests/cluster/kmeans.cu @@ -581,19 +581,29 @@ class KmeansFitBatchedTest : public ::testing::TestWithParam 1 keeps the best trial, so it should not be worse than - // n_init == 1 (modulo reduction-order noise from the nondeterministic - // atomics used in the fused NN kernel). + // n_init == 1. const T rel = T(1e-5); ASSERT_LE(inertia3, inertia1 * (T(1) + rel)); } From 056934087673595585f5dbf5dddce962c93ac6e1 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Thu, 23 Apr 2026 16:52:31 -0700 Subject: [PATCH 29/50] revert batch norms cache --- cpp/src/cluster/detail/kmeans.cuh | 36 ++++++------------------------- 1 file changed, 6 insertions(+), 30 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans.cuh b/cpp/src/cluster/detail/kmeans.cuh index 4485ffd3e6..c6eecd1ae3 100644 --- a/cpp/src/cluster/detail/kmeans.cuh +++ b/cpp/src/cluster/detail/kmeans.cuh @@ -741,10 +741,6 @@ void kmeans_fit( bool need_compute_norms = metric == cuvs::distance::DistanceType::L2Expanded || metric == cuvs::distance::DistanceType::L2SqrtExpanded || metric == cuvs::distance::DistanceType::CosineExpanded; - bool use_norm_cache = need_compute_norms && !data_on_device; - std::vector h_norm_cache; - if (use_norm_cache) { h_norm_cache.resize(n_samples); } - bool norms_cached = false; auto compute_batch_norms = [&](const DataT* batch_ptr, IndexT batch_size) { auto batch_view = @@ -760,11 +756,6 @@ void kmeans_fit( } }; - if (need_compute_norms && data_on_device) { - compute_batch_norms(X.data_handle(), n_samples); - norms_cached = true; - } - std::mt19937 gen(pams.rng_state.seed); inertia[0] = std::numeric_limits::max(); @@ -785,7 +776,7 @@ void kmeans_fit( DataT iter_inertia = std::numeric_limits::max(); IndexT n_current_iter = 0; - DataT priorClusteringCost = 0; + // DataT priorClusteringCost = 0; for (n_current_iter = 1; n_current_iter <= iter_params.max_iter; ++n_current_iter) { RAFT_LOG_DEBUG("KMeans.fit: Iteration-%d", n_current_iter); @@ -830,19 +821,8 @@ void kmeans_fit( auto minCAD_view = raft::make_device_vector_view, IndexT>( minClusterAndDistance.data_handle(), cur_batch_size); - if (need_compute_norms && !norms_cached) { + if (need_compute_norms) { compute_batch_norms(data_batch.data(), cur_batch_size); - if (use_norm_cache) { - raft::copy(h_norm_cache.data() + data_batch.offset(), - L2NormBatch.data_handle(), - cur_batch_size, - stream); - } - } else if (use_norm_cache) { - raft::copy(L2NormBatch.data_handle(), - h_norm_cache.data() + data_batch.offset(), - cur_batch_size, - stream); } auto l2_const_view = raft::make_device_vector_view( @@ -865,10 +845,6 @@ void kmeans_fit( batch_workspace, centroid_norms_opt); } - if (!norms_cached && use_norm_cache) { - raft::resource::sync_stream(handle, stream); - norms_cached = true; - } finalize_centroids(handle, raft::make_const_mdspan(centroid_sums.view()), @@ -887,15 +863,15 @@ void kmeans_fit( DataT curClusteringCost = DataT{0}; raft::copy(&curClusteringCost, clustering_cost.data_handle(), 1, stream); - raft::resource::sync_stream(handle, stream); + // raft::resource::sync_stream(handle, stream); if (curClusteringCost == DataT{0}) { RAFT_LOG_WARN("Zero clustering cost detected: all points coincide with their centroids."); } else if (n_current_iter > 1) { - DataT delta = curClusteringCost / priorClusteringCost; - if (delta > 1 - iter_params.tol) done = true; + // DataT delta = curClusteringCost / priorClusteringCost; + // if (delta > 1 - iter_params.tol) done = true; } - priorClusteringCost = curClusteringCost; + // priorClusteringCost = curClusteringCost; if (sqrdNormError < iter_params.tol) done = true; From 8cac63a8a553b20d77ec44557eb2fc61bd1e0229 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Fri, 24 Apr 2026 13:03:13 -0700 Subject: [PATCH 30/50] increase zero cost threshold --- cpp/tests/cluster/kmeans.cu | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cpp/tests/cluster/kmeans.cu b/cpp/tests/cluster/kmeans.cu index b35b9d38d9..6261da7afe 100644 --- a/cpp/tests/cluster/kmeans.cu +++ b/cpp/tests/cluster/kmeans.cu @@ -684,9 +684,9 @@ class KmeansFitBatchedTest : public ::testing::TestWithParam Date: Fri, 24 Apr 2026 12:50:53 -0700 Subject: [PATCH 31/50] apply cuda event plus re-add h_norm_cache --- cpp/src/cluster/detail/kmeans.cuh | 87 +++++++++++++++++------- cpp/src/cluster/detail/kmeans_common.cuh | 47 ++++++++++--- 2 files changed, 97 insertions(+), 37 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans.cuh b/cpp/src/cluster/detail/kmeans.cuh index c6eecd1ae3..6650491b1b 100644 --- a/cpp/src/cluster/detail/kmeans.cuh +++ b/cpp/src/cluster/detail/kmeans.cuh @@ -21,6 +21,8 @@ #include #include #include +#include +#include #include #include #include @@ -741,6 +743,10 @@ void kmeans_fit( bool need_compute_norms = metric == cuvs::distance::DistanceType::L2Expanded || metric == cuvs::distance::DistanceType::L2SqrtExpanded || metric == cuvs::distance::DistanceType::CosineExpanded; + bool use_norm_cache = need_compute_norms && !data_on_device; + auto h_norm_cache = + raft::make_pinned_vector(handle, use_norm_cache ? n_samples : 0); + bool norms_cached = false; auto compute_batch_norms = [&](const DataT* batch_ptr, IndexT batch_size) { auto batch_view = @@ -774,11 +780,28 @@ void kmeans_fit( iter_params, raft::make_device_matrix_view(cur_centroids_ptr, n_clusters, n_features)); - DataT iter_inertia = std::numeric_limits::max(); - IndexT n_current_iter = 0; - // DataT priorClusteringCost = 0; + DataT iter_inertia = std::numeric_limits::max(); + IndexT n_current_iter = 0; + auto sqrdNormError = raft::make_device_scalar(handle, DataT{0}); + + auto d_prior_cost = raft::make_device_scalar(handle, DataT{0}); + auto d_done_flag = raft::make_device_scalar(handle, 0); + auto h_done_flag = raft::make_pinned_scalar(handle, 0); + + cudaEvent_t convergence_event; + RAFT_CUDA_TRY(cudaEventCreateWithFlags(&convergence_event, cudaEventDisableTiming)); for (n_current_iter = 1; n_current_iter <= iter_params.max_iter; ++n_current_iter) { + if (n_current_iter > 1) { + RAFT_CUDA_TRY(cudaEventSynchronize(convergence_event)); + if (*h_done_flag.data_handle()) { + n_current_iter--; + RAFT_LOG_DEBUG("Threshold triggered after %d iterations. Terminating early.", + n_current_iter); + break; + } + } + RAFT_LOG_DEBUG("KMeans.fit: Iteration-%d", n_current_iter); raft::matrix::fill(handle, centroid_sums.view(), DataT{0}); @@ -823,6 +846,17 @@ void kmeans_fit( if (need_compute_norms) { compute_batch_norms(data_batch.data(), cur_batch_size); + if (use_norm_cache) { + raft::copy(h_norm_cache.data_handle() + data_batch.offset(), + L2NormBatch.data_handle(), + cur_batch_size, + stream); + } + } else if (use_norm_cache) { + raft::copy(L2NormBatch.data_handle(), + h_norm_cache.data_handle() + data_batch.offset(), + cur_batch_size, + stream); } auto l2_const_view = raft::make_device_vector_view( @@ -845,6 +879,7 @@ void kmeans_fit( batch_workspace, centroid_norms_opt); } + if (!norms_cached && use_norm_cache) { norms_cached = true; } finalize_centroids(handle, raft::make_const_mdspan(centroid_sums.view()), @@ -852,36 +887,36 @@ void kmeans_fit( centroids_const, new_centroids_view); - DataT sqrdNormError = - compute_centroid_shift(handle, - raft::make_const_mdspan(centroids_const), - raft::make_const_mdspan(new_centroids_view)); + compute_centroid_shift(handle, + raft::make_const_mdspan(centroids_const), + raft::make_const_mdspan(new_centroids_view), + sqrdNormError.view()); std::swap(cur_centroids_ptr, new_centroids_ptr); - bool done = false; - - DataT curClusteringCost = DataT{0}; - raft::copy(&curClusteringCost, clustering_cost.data_handle(), 1, stream); - // raft::resource::sync_stream(handle, stream); + auto d_cost_view = raft::make_device_scalar_view(clustering_cost.data_handle()); + auto d_prior_view = d_prior_cost.view(); + auto d_norm_view = raft::make_device_scalar_view(sqrdNormError.data_handle()); + auto d_done_view = d_done_flag.view(); + DataT tol = iter_params.tol; + int iter = n_current_iter; - if (curClusteringCost == DataT{0}) { - RAFT_LOG_WARN("Zero clustering cost detected: all points coincide with their centroids."); - } else if (n_current_iter > 1) { - // DataT delta = curClusteringCost / priorClusteringCost; - // if (delta > 1 - iter_params.tol) done = true; - } - // priorClusteringCost = curClusteringCost; - - if (sqrdNormError < iter_params.tol) done = true; + raft::linalg::map_offset( + handle, + raft::make_device_vector_view(d_done_flag.data_handle(), 1), + [=] __device__(int) { + check_convergence(d_cost_view, d_prior_view, d_norm_view, tol, iter, d_done_view); + return *d_done_view.data_handle(); + }); - if (done) { - RAFT_LOG_DEBUG("Threshold triggered after %d iterations. Terminating early.", - n_current_iter); - break; - } + raft::copy(handle, + raft::make_pinned_scalar_view(h_done_flag.data_handle()), + raft::make_device_scalar_view(d_done_flag.data_handle())); + RAFT_CUDA_TRY(cudaEventRecord(convergence_event, stream)); } + RAFT_CUDA_TRY(cudaEventDestroy(convergence_event)); + { auto centroids_const = raft::make_device_matrix_view( cur_centroids_ptr, n_clusters, n_features); diff --git a/cpp/src/cluster/detail/kmeans_common.cuh b/cpp/src/cluster/detail/kmeans_common.cuh index 8af6114b4b..ec46b97b4e 100644 --- a/cpp/src/cluster/detail/kmeans_common.cuh +++ b/cpp/src/cluster/detail/kmeans_common.cuh @@ -586,26 +586,51 @@ void finalize_centroids(raft::resources const& handle, /** * @brief Compute the squared norm difference between two centroid sets. * - * Returns sum((old_centroids - new_centroids)^2). - * Used for convergence checking. + * Writes sum((old_centroids - new_centroids)^2) into @p sqrd_norm_out. + * Used for convergence checking. Fully asynchronous — no stream sync. */ template -DataT compute_centroid_shift(raft::resources const& handle, - raft::device_matrix_view old_centroids, - raft::device_matrix_view new_centroids) +void compute_centroid_shift(raft::resources const& handle, + raft::device_matrix_view old_centroids, + raft::device_matrix_view new_centroids, + raft::device_scalar_view sqrd_norm_out) { cudaStream_t stream = raft::resource::get_cuda_stream(handle); - auto sqrdNorm = raft::make_device_scalar(handle, DataT{0}); - raft::linalg::mapThenSumReduce(sqrdNorm.data_handle(), + raft::linalg::mapThenSumReduce(sqrd_norm_out.data_handle(), old_centroids.size(), raft::sqdiff_op{}, stream, old_centroids.data_handle(), new_centroids.data_handle()); - DataT result = 0; - raft::copy(&result, sqrdNorm.data_handle(), 1, stream); - raft::resource::sync_stream(handle); - return result; +} + +/** + * @brief Evaluate convergence criteria entirely on device. + * + * Checks the cost-ratio and centroid-shift stopping conditions and writes + * a boolean result (0 or 1) into @p done_flag. Also advances + * @p prior_clustering_cost to the current cost for the next iteration. + */ +template +__device__ void check_convergence(raft::device_scalar_view clustering_cost, + raft::device_scalar_view prior_clustering_cost, + raft::device_scalar_view sqrd_norm_error, + DataT tol, + int n_iter, + raft::device_scalar_view done_flag) +{ + DataT cur_cost = *clustering_cost.data_handle(); + DataT norm_err = *sqrd_norm_error.data_handle(); + int done = 0; + + if (cur_cost != DataT{0} && n_iter > 1) { + DataT delta = cur_cost / *prior_clustering_cost.data_handle(); + if (delta > DataT{1} - tol) done = 1; + } + if (norm_err < tol) done = 1; + + *prior_clustering_cost.data_handle() = cur_cost; + *done_flag.data_handle() = done; } /** From 9fc74b1a1a1de80f5abd99e8defa43bd0a9aaa60 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Fri, 24 Apr 2026 13:37:56 -0700 Subject: [PATCH 32/50] rm cosine expanded stuff --- cpp/src/cluster/detail/kmeans.cuh | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans.cuh b/cpp/src/cluster/detail/kmeans.cuh index 6650491b1b..4a0195d1de 100644 --- a/cpp/src/cluster/detail/kmeans.cuh +++ b/cpp/src/cluster/detail/kmeans.cuh @@ -741,8 +741,7 @@ void kmeans_fit( static_cast(streaming_batch_size)); bool need_compute_norms = metric == cuvs::distance::DistanceType::L2Expanded || - metric == cuvs::distance::DistanceType::L2SqrtExpanded || - metric == cuvs::distance::DistanceType::CosineExpanded; + metric == cuvs::distance::DistanceType::L2SqrtExpanded; bool use_norm_cache = need_compute_norms && !data_on_device; auto h_norm_cache = raft::make_pinned_vector(handle, use_norm_cache ? n_samples : 0); @@ -753,13 +752,8 @@ void kmeans_fit( raft::make_device_matrix_view(batch_ptr, batch_size, n_features); auto norm_view = raft::make_device_vector_view(L2NormBatch.data_handle(), batch_size); - if (metric == cuvs::distance::DistanceType::CosineExpanded) { - raft::linalg::norm( - handle, batch_view, norm_view, raft::sqrt_op{}); - } else { - raft::linalg::norm( - handle, batch_view, norm_view); - } + raft::linalg::norm( + handle, batch_view, norm_view); }; std::mt19937 gen(pams.rng_state.seed); From 0d030a2e801e42e1dd2cb86ee67f3093572c6124 Mon Sep 17 00:00:00 2001 From: tarangj Date: Tue, 28 Apr 2026 09:36:28 -0700 Subject: [PATCH 33/50] change suffix of the params struct --- c/include/cuvs/cluster/kmeans.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/c/include/cuvs/cluster/kmeans.h b/c/include/cuvs/cluster/kmeans.h index 3f3f487590..f31d8f4c80 100644 --- a/c/include/cuvs/cluster/kmeans.h +++ b/c/include/cuvs/cluster/kmeans.h @@ -39,7 +39,7 @@ typedef enum { /** * @brief Hyper-parameters for the kmeans algorithm - * NB: The inertia_check field is kept for ABI compatibility. Removed in cuvsKMeansParams_v1. + * NB: The inertia_check field is kept for ABI compatibility. Removed in cuvsKMeansParams_v06. * CalVer for the replacement: 26.08 */ struct cuvsKMeansParams { @@ -123,7 +123,7 @@ struct cuvsKMeansParams { /** * @brief Hyper-parameters for the kmeans algorithm */ - struct cuvsKMeansParams_v1 { + struct cuvsKMeansParams_v06 { cuvsDistanceType metric; /** From b1c034ef1a42e3c5104acf0f5a7871c9f692ca1e Mon Sep 17 00:00:00 2001 From: tarangj Date: Tue, 28 Apr 2026 09:57:00 -0700 Subject: [PATCH 34/50] replace 06 by 08, add todo and note --- c/include/cuvs/cluster/kmeans.h | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/c/include/cuvs/cluster/kmeans.h b/c/include/cuvs/cluster/kmeans.h index f31d8f4c80..9eb7999043 100644 --- a/c/include/cuvs/cluster/kmeans.h +++ b/c/include/cuvs/cluster/kmeans.h @@ -39,8 +39,8 @@ typedef enum { /** * @brief Hyper-parameters for the kmeans algorithm - * NB: The inertia_check field is kept for ABI compatibility. Removed in cuvsKMeansParams_v06. - * CalVer for the replacement: 26.08 + * NB: The inertia_check field is kept for ABI compatibility. Removed in cuvsKMeansParams_v08. + * TODO: CalVer for the replacement: 26.08 */ struct cuvsKMeansParams { cuvsDistanceType metric; @@ -122,8 +122,9 @@ struct cuvsKMeansParams { /** * @brief Hyper-parameters for the kmeans algorithm + * TODO: Remove this after cuvsKMeansParams is replaced in 26.08 */ - struct cuvsKMeansParams_v06 { + struct cuvsKMeansParams_v08 { cuvsDistanceType metric; /** From a482495d4305e17d570ac9bbdb26212d777166ac Mon Sep 17 00:00:00 2001 From: tarangj Date: Tue, 28 Apr 2026 11:00:40 -0700 Subject: [PATCH 35/50] update to v2 --- c/include/cuvs/cluster/kmeans.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/c/include/cuvs/cluster/kmeans.h b/c/include/cuvs/cluster/kmeans.h index 9eb7999043..79ceaf842c 100644 --- a/c/include/cuvs/cluster/kmeans.h +++ b/c/include/cuvs/cluster/kmeans.h @@ -39,7 +39,7 @@ typedef enum { /** * @brief Hyper-parameters for the kmeans algorithm - * NB: The inertia_check field is kept for ABI compatibility. Removed in cuvsKMeansParams_v08. + * NB: The inertia_check field is kept for ABI compatibility. Removed in cuvsKMeansParams_v2. * TODO: CalVer for the replacement: 26.08 */ struct cuvsKMeansParams { @@ -122,9 +122,9 @@ struct cuvsKMeansParams { /** * @brief Hyper-parameters for the kmeans algorithm - * TODO: Remove this after cuvsKMeansParams is replaced in 26.08 + * TODO: Remove this after cuvsKMeansParams is replaced in ABI 2.0 */ - struct cuvsKMeansParams_v08 { + struct cuvsKMeansParams_v2 { cuvsDistanceType metric; /** From 8ecfdc1e2d014ca100bc6a51d469544bab9d1527 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Wed, 29 Apr 2026 09:10:52 -0700 Subject: [PATCH 36/50] avoid stream sync inside weight sum --- cpp/src/cluster/detail/kmeans.cuh | 65 ++++++++++-------------- cpp/src/cluster/detail/kmeans_common.cuh | 23 ++++----- 2 files changed, 37 insertions(+), 51 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans.cuh b/cpp/src/cluster/detail/kmeans.cuh index 4a0195d1de..c339470491 100644 --- a/cpp/src/cluster/detail/kmeans.cuh +++ b/cpp/src/cluster/detail/kmeans.cuh @@ -590,24 +590,17 @@ void kmeans_fit( streaming_batch_size = static_cast(n_samples); } + constexpr bool data_on_device = raft::is_device_mdspan_v; + const DataT* weight_ptr = sample_weight.has_value() ? sample_weight.value().data_handle() : nullptr; - DataT wt_sum = sample_weight.has_value() ? weightSum(handle, sample_weight.value()) - : static_cast(n_samples); - const DataT wt_rel_tol = n_samples * std::numeric_limits::epsilon(); - const bool needs_wt_rescale = - sample_weight.has_value() && std::abs(wt_sum - static_cast(n_samples)) > wt_rel_tol; - if (needs_wt_rescale) { - RAFT_LOG_DEBUG( - "[Warning!] KMeans: normalizing the user provided sample weight to sum up to %zu samples", - static_cast(n_samples)); - } + + auto d_wt_sum = raft::make_device_scalar(handle, static_cast(n_samples)); + if (sample_weight.has_value()) { weightSum(handle, sample_weight.value(), d_wt_sum.view()); } rmm::device_uvector local_workspace(0, stream); rmm::device_uvector& ws = workspace.has_value() ? workspace->get() : local_workspace; - constexpr bool data_on_device = raft::is_device_mdspan_v; - if (data_on_device && streaming_batch_size != static_cast(n_samples)) { RAFT_LOG_WARN( "KMeans: streaming_batch_size (%zu) ignored when data resides on device; using n_samples " @@ -713,21 +706,21 @@ void kmeans_fit( raft::matrix::fill(handle, batch_weights_buf.view(), DataT{1}); } - // Copies and optionally rescales `wt_data` into `batch_weights_buf`. When - // `wt_data == nullptr`, the buffer is assumed to have been pre-filled with 1 - // for the unweighted case. + // Copies and rescales `wt_data` into `batch_weights_buf` so that weights + // are normalized to sum to n_samples. + const DataT* d_wt_sum_ptr = d_wt_sum.data_handle(); auto prepare_batch_weights = [&](const DataT* wt_data, IndexT cur_batch_size) { if (wt_data != nullptr) { raft::copy(batch_weights_buf.data_handle(), wt_data, cur_batch_size, stream); - if (needs_wt_rescale) { - auto bw = raft::make_device_vector_view(batch_weights_buf.data_handle(), - cur_batch_size); - raft::linalg::map(handle, - bw, - raft::compose_op(raft::mul_const_op{static_cast(n_samples)}, - raft::div_const_op{wt_sum}), - raft::make_const_mdspan(bw)); - } + auto bw = raft::make_device_vector_view(batch_weights_buf.data_handle(), + cur_batch_size); + raft::linalg::map( + handle, + bw, + [n_samples, d_wt_sum_ptr] __device__(DataT w) { + return w * static_cast(n_samples) / *d_wt_sum_ptr; + }, + raft::make_const_mdspan(bw)); } return raft::make_device_vector_view(batch_weights_buf.data_handle(), cur_batch_size); @@ -1036,19 +1029,17 @@ void kmeans_predict(raft::resources const& handle, else raft::matrix::fill(handle, weight.view(), DataT(1)); - if (normalize_weight) { - DataT wt_sum = weightSum(handle, raft::make_const_mdspan(weight.view())); - const DataT rel_tol = n_samples * std::numeric_limits::epsilon(); - if (std::abs(wt_sum - n_samples) > rel_tol) { - RAFT_LOG_DEBUG( - "[Warning!] KMeans: normalizing the user provided sample weight to sum up to %zu samples", - static_cast(n_samples)); - raft::linalg::map(handle, - weight.view(), - raft::compose_op(raft::mul_const_op{static_cast(n_samples)}, - raft::div_const_op{wt_sum}), - raft::make_const_mdspan(weight.view())); - } + if (normalize_weight && sample_weight.has_value()) { + auto d_wt_sum = raft::make_device_scalar(handle, DataT{0}); + weightSum(handle, raft::make_const_mdspan(weight.view()), d_wt_sum.view()); + const DataT* d_wt_sum_ptr = d_wt_sum.data_handle(); + raft::linalg::map( + handle, + weight.view(), + [n_samples, d_wt_sum_ptr] __device__(DataT w) { + return w * static_cast(n_samples) / *d_wt_sum_ptr; + }, + raft::make_const_mdspan(weight.view())); } auto minClusterAndDistance = diff --git a/cpp/src/cluster/detail/kmeans_common.cuh b/cpp/src/cluster/detail/kmeans_common.cuh index ec46b97b4e..d53a5943ff 100644 --- a/cpp/src/cluster/detail/kmeans_common.cuh +++ b/cpp/src/cluster/detail/kmeans_common.cuh @@ -134,35 +134,30 @@ void countLabels(raft::resources const& handle, } /** - * @brief Compute the sum of sample weights. + * @brief Compute the sum of sample weights into a device scalar. * - * Device-accessible mdspans are reduced on device via mapThenSumReduce; - * host mdspans are summed on the host. - * - * @return Sum of weights. + * Device-accessible mdspans are reduced on device. Host mdspans are summed on the host. */ template -DataT weightSum( +void weightSum( raft::resources const& handle, - raft::mdspan, raft::layout_right, Accessor> weight) + raft::mdspan, raft::layout_right, Accessor> weight, + raft::device_scalar_view d_wt_sum) { auto n_samples = weight.extent(0); + auto stream = raft::resource::get_cuda_stream(handle); - DataT wt_sum = DataT{0}; if constexpr (raft::is_device_mdspan_v) { - auto stream = raft::resource::get_cuda_stream(handle); - auto d_wt_sum = raft::make_device_scalar(handle, DataT{0}); raft::linalg::mapThenSumReduce( d_wt_sum.data_handle(), n_samples, raft::identity_op{}, stream, weight.data_handle()); - raft::copy(&wt_sum, d_wt_sum.data_handle(), 1, stream); - raft::resource::sync_stream(handle); } else { + DataT wt_sum = DataT{0}; for (IndexT i = 0; i < n_samples; ++i) { wt_sum += weight(i); } + RAFT_EXPECTS(wt_sum > DataT{0}, "invalid parameter (sum of sample weights must be positive)"); + raft::copy(d_wt_sum.data_handle(), &wt_sum, 1, stream); } - RAFT_EXPECTS(wt_sum > DataT{0}, "invalid parameter (sum of sample weights must be positive)"); - return wt_sum; } template From ec22e079396a476bb42330f7445faa5fc1de060c Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Wed, 29 Apr 2026 09:38:16 -0700 Subject: [PATCH 37/50] empty From d2e410df188d3965ac8e81f8cb5afb1a5ba60faa Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Wed, 29 Apr 2026 09:38:54 -0700 Subject: [PATCH 38/50] empty From a05a006db6a471818e77378c1bbe1cef2de9ca34 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Wed, 29 Apr 2026 10:16:38 -0700 Subject: [PATCH 39/50] new signatures with new struct --- c/include/cuvs/cluster/kmeans.h | 96 ++++++++++++++++++++++ c/src/cluster/kmeans.cpp | 88 ++++++++++++++++++-- python/cuvs/cuvs/cluster/kmeans/kmeans.pxd | 39 +++++---- python/cuvs/cuvs/cluster/kmeans/kmeans.pyx | 10 +-- rust/cuvs/src/cluster/kmeans/mod.rs | 6 +- rust/cuvs/src/cluster/kmeans/params.rs | 18 ++-- 6 files changed, 216 insertions(+), 41 deletions(-) diff --git a/c/include/cuvs/cluster/kmeans.h b/c/include/cuvs/cluster/kmeans.h index 79ceaf842c..0f5f4554c1 100644 --- a/c/include/cuvs/cluster/kmeans.h +++ b/c/include/cuvs/cluster/kmeans.h @@ -200,10 +200,14 @@ struct cuvsKMeansParams { }; typedef struct cuvsKMeansParams* cuvsKMeansParams_t; +typedef struct cuvsKMeansParams_v2* cuvsKMeansParams_v2_t; /** * @brief Allocate KMeans params, and populate with default values * + * @note In cuVS 26.08 (next ABI major version) this signature will be + * replaced by cuvsKMeansParamsCreate_v2. + * * @param[in] params cuvsKMeansParams_t to allocate * @return cuvsError_t */ @@ -212,11 +216,33 @@ cuvsError_t cuvsKMeansParamsCreate(cuvsKMeansParams_t* params); /** * @brief De-allocate KMeans params * + * @note In cuVS 26.08 (next ABI major version) this signature will be + * replaced by cuvsKMeansParamsDestroy_v2. + * * @param[in] params * @return cuvsError_t */ cuvsError_t cuvsKMeansParamsDestroy(cuvsKMeansParams_t params); +/** + * @brief Allocate KMeans params + * + * Mirrors cuvsKMeansParamsCreate but operates on cuvsKMeansParams_v2. + * Will become the unsuffixed cuvsKMeansParamsCreate in cuVS 26.08. + * + * @param[in] params cuvsKMeansParams_v2_t to allocate + * @return cuvsError_t + */ +cuvsError_t cuvsKMeansParamsCreate_v2(cuvsKMeansParams_v2_t* params); + +/** + * @brief De-allocate KMeans params allocated by cuvsKMeansParamsCreate_v2. + * + * @param[in] params + * @return cuvsError_t + */ +cuvsError_t cuvsKMeansParamsDestroy_v2(cuvsKMeansParams_v2_t params); + /** * @brief Type of k-means algorithm. */ @@ -242,6 +268,9 @@ typedef enum { CUVS_KMEANS_TYPE_KMEANS = 0, CUVS_KMEANS_TYPE_KMEANS_BALANCED = 1 * When X is on the host the data is streamed to the GPU in * batches controlled by params->streaming_batch_size. * + * @note In cuVS 26.08 (next ABI major version) this signature will be + * replaced by cuvsKMeansFit_v2. + * * @param[in] res opaque C handle * @param[in] params Parameters for KMeans model. * @param[in] X Training instances to cluster. The data must @@ -269,9 +298,45 @@ cuvsError_t cuvsKMeansFit(cuvsResources_t res, double* inertia, int* n_iter); +/** + * @brief Find clusters with k-means algorithm (v2 params layout). + * + * Mirrors cuvsKMeansFit but takes cuvsKMeansParams_v2_t. Will become the + * unsuffixed cuvsKMeansFit in cuVS 26.08. + * + * @param[in] res opaque C handle + * @param[in] params Parameters for KMeans model (v2 layout). + * @param[in] X Training instances to cluster. The data must + * be in row-major format. May be on host or + * device memory. + * [dim = n_samples x n_features] + * @param[in] sample_weight Optional weights for each observation in X. + * Must be on the same memory space as X. + * [len = n_samples] + * @param[inout] centroids [in] When init is InitMethod::Array, use + * centroids as the initial cluster centers. + * [out] The generated centroids from the + * kmeans algorithm are stored at the address + * pointed by 'centroids'. Must be on device. + * [dim = n_clusters x n_features] + * @param[out] inertia Sum of squared distances of samples to their + * closest cluster center. + * @param[out] n_iter Number of iterations run. + */ +cuvsError_t cuvsKMeansFit_v2(cuvsResources_t res, + cuvsKMeansParams_v2_t params, + DLManagedTensor* X, + DLManagedTensor* sample_weight, + DLManagedTensor* centroids, + double* inertia, + int* n_iter); + /** * @brief Predict the closest cluster each sample in X belongs to. * + * @note In cuVS 26.08 (next ABI major version) this signature will be + * replaced by cuvsKMeansPredict_v2. + * * @param[in] res opaque C handle * @param[in] params Parameters for KMeans model. * @param[in] X New data to predict. @@ -297,6 +362,37 @@ cuvsError_t cuvsKMeansPredict(cuvsResources_t res, bool normalize_weight, double* inertia); +/** + * @brief Predict the closest cluster each sample in X belongs to (v2 params layout). + * + * Mirrors cuvsKMeansPredict but takes cuvsKMeansParams_v2_t. Will become the + * unsuffixed cuvsKMeansPredict in cuVS 26.08. + * + * @param[in] res opaque C handle + * @param[in] params Parameters for KMeans model (v2 layout). + * @param[in] X New data to predict. + * [dim = n_samples x n_features] + * @param[in] sample_weight Optional weights for each observation in X. + * [len = n_samples] + * @param[in] centroids Cluster centroids. The data must be in + * row-major format. + * [dim = n_clusters x n_features] + * @param[in] normalize_weight True if the weights should be normalized + * @param[out] labels Index of the cluster each sample in X + * belongs to. + * [len = n_samples] + * @param[out] inertia Sum of squared distances of samples to + * their closest cluster center. + */ +cuvsError_t cuvsKMeansPredict_v2(cuvsResources_t res, + cuvsKMeansParams_v2_t params, + DLManagedTensor* X, + DLManagedTensor* sample_weight, + DLManagedTensor* centroids, + DLManagedTensor* labels, + bool normalize_weight, + double* inertia); + /** * @brief Compute cluster cost * diff --git a/c/src/cluster/kmeans.cpp b/c/src/cluster/kmeans.cpp index 495a83f8d5..8e46764ce4 100644 --- a/c/src/cluster/kmeans.cpp +++ b/c/src/cluster/kmeans.cpp @@ -16,7 +16,9 @@ namespace { -cuvs::cluster::kmeans::params convert_params(const cuvsKMeansParams& params) +// The conversions are templated on the C struct type and reused by both API surfaces. +template +cuvs::cluster::kmeans::params convert_params(const ParamsT& params) { auto kmeans_params = cuvs::cluster::kmeans::params(); kmeans_params.metric = static_cast(params.metric); @@ -33,7 +35,8 @@ cuvs::cluster::kmeans::params convert_params(const cuvsKMeansParams& params) return kmeans_params; } -cuvs::cluster::kmeans::balanced_params convert_balanced_params(const cuvsKMeansParams& params) +template +cuvs::cluster::kmeans::balanced_params convert_balanced_params(const ParamsT& params) { auto kmeans_params = cuvs::cluster::kmeans::balanced_params(); kmeans_params.metric = static_cast(params.metric); @@ -41,9 +44,9 @@ cuvs::cluster::kmeans::balanced_params convert_balanced_params(const cuvsKMeansP return kmeans_params; } -template +template void _fit(cuvsResources_t res, - const cuvsKMeansParams& params, + const ParamsT& params, DLManagedTensor* X_tensor, DLManagedTensor* sample_weight_tensor, DLManagedTensor* centroids_tensor, @@ -140,9 +143,9 @@ void _fit(cuvsResources_t res, } } -template +template void _predict(cuvsResources_t res, - const cuvsKMeansParams& params, + const ParamsT& params, DLManagedTensor* X_tensor, DLManagedTensor* sample_weight_tensor, DLManagedTensor* centroids_tensor, @@ -295,6 +298,79 @@ extern "C" cuvsError_t cuvsKMeansPredict(cuvsResources_t res, }); } +extern "C" cuvsError_t cuvsKMeansParamsCreate_v2(cuvsKMeansParams_v2_t* params) +{ + return cuvs::core::translate_exceptions([=] { + cuvs::cluster::kmeans::params cpp_params; + cuvs::cluster::kmeans::balanced_params cpp_balanced_params; + *params = new cuvsKMeansParams_v2{ + .metric = static_cast(cpp_params.metric), + .n_clusters = cpp_params.n_clusters, + .init = static_cast(cpp_params.init), + .max_iter = cpp_params.max_iter, + .tol = cpp_params.tol, + .n_init = cpp_params.n_init, + .oversampling_factor = cpp_params.oversampling_factor, + .batch_samples = cpp_params.batch_samples, + .batch_centroids = cpp_params.batch_centroids, + .hierarchical = false, + .hierarchical_n_iters = static_cast(cpp_balanced_params.n_iters), + .streaming_batch_size = cpp_params.streaming_batch_size, + .init_size = cpp_params.init_size}; + }); +} + +extern "C" cuvsError_t cuvsKMeansParamsDestroy_v2(cuvsKMeansParams_v2_t params) +{ + return cuvs::core::translate_exceptions([=] { delete params; }); +} + +extern "C" cuvsError_t cuvsKMeansFit_v2(cuvsResources_t res, + cuvsKMeansParams_v2_t params, + DLManagedTensor* X, + DLManagedTensor* sample_weight, + DLManagedTensor* centroids, + double* inertia, + int* n_iter) +{ + return cuvs::core::translate_exceptions([=] { + auto dataset = X->dl_tensor; + if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 32) { + _fit(res, *params, X, sample_weight, centroids, inertia, n_iter); + } else if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 64) { + _fit(res, *params, X, sample_weight, centroids, inertia, n_iter); + } else { + RAFT_FAIL("Unsupported dataset DLtensor dtype: %d and bits: %d", + dataset.dtype.code, + dataset.dtype.bits); + } + }); +} + +extern "C" cuvsError_t cuvsKMeansPredict_v2(cuvsResources_t res, + cuvsKMeansParams_v2_t params, + DLManagedTensor* X, + DLManagedTensor* sample_weight, + DLManagedTensor* centroids, + DLManagedTensor* labels, + bool normalize_weight, + double* inertia) +{ + return cuvs::core::translate_exceptions([=] { + auto dataset = X->dl_tensor; + if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 32) { + _predict(res, *params, X, sample_weight, centroids, labels, normalize_weight, inertia); + } else if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 64) { + _predict( + res, *params, X, sample_weight, centroids, labels, normalize_weight, inertia); + } else { + RAFT_FAIL("Unsupported dataset DLtensor dtype: %d and bits: %d", + dataset.dtype.code, + dataset.dtype.bits); + } + }); +} + extern "C" cuvsError_t cuvsKMeansClusterCost(cuvsResources_t res, DLManagedTensor* X, DLManagedTensor* centroids, diff --git a/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd b/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd index ccacb7042b..38f0b14a91 100644 --- a/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd +++ b/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd @@ -22,7 +22,7 @@ cdef extern from "cuvs/cluster/kmeans.h" nogil: CUVS_KMEANS_TYPE_KMEANS CUVS_KMEANS_TYPE_KMEANS_BALANCED - ctypedef struct cuvsKMeansParams: + ctypedef struct cuvsKMeansParams_v2: cuvsDistanceType metric, int n_clusters, cuvsKMeansInitMethod init, @@ -32,34 +32,33 @@ cdef extern from "cuvs/cluster/kmeans.h" nogil: double oversampling_factor, int batch_samples, int batch_centroids, - bool inertia_check, bool hierarchical, int hierarchical_n_iters, int64_t streaming_batch_size, int64_t init_size - ctypedef cuvsKMeansParams* cuvsKMeansParams_t + ctypedef cuvsKMeansParams_v2* cuvsKMeansParams_v2_t - cuvsError_t cuvsKMeansParamsCreate(cuvsKMeansParams_t* index) + cuvsError_t cuvsKMeansParamsCreate_v2(cuvsKMeansParams_v2_t* index) - cuvsError_t cuvsKMeansParamsDestroy(cuvsKMeansParams_t index) + cuvsError_t cuvsKMeansParamsDestroy_v2(cuvsKMeansParams_v2_t index) - cuvsError_t cuvsKMeansFit(cuvsResources_t res, - cuvsKMeansParams_t params, - DLManagedTensor* X, - DLManagedTensor* sample_weight, - DLManagedTensor * centroids, - double * inertia, - int * n_iter) except + + cuvsError_t cuvsKMeansFit_v2(cuvsResources_t res, + cuvsKMeansParams_v2_t params, + DLManagedTensor* X, + DLManagedTensor* sample_weight, + DLManagedTensor * centroids, + double * inertia, + int * n_iter) except + - cuvsError_t cuvsKMeansPredict(cuvsResources_t res, - cuvsKMeansParams_t params, - DLManagedTensor* X, - DLManagedTensor* sample_weight, - DLManagedTensor * centroids, - DLManagedTensor * labels, - bool normalize_weight, - double * inertia) + cuvsError_t cuvsKMeansPredict_v2(cuvsResources_t res, + cuvsKMeansParams_v2_t params, + DLManagedTensor* X, + DLManagedTensor* sample_weight, + DLManagedTensor * centroids, + DLManagedTensor * labels, + bool normalize_weight, + double * inertia) cuvsError_t cuvsKMeansClusterCost(cuvsResources_t res, DLManagedTensor* X, diff --git a/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx b/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx index 2e9046b4b2..2efc969698 100644 --- a/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx +++ b/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx @@ -100,13 +100,13 @@ cdef class KMeansParams: For hierarchical k-means , defines the number of training iterations """ - cdef cuvsKMeansParams* params + cdef cuvsKMeansParams_v2* params def __cinit__(self): - cuvsKMeansParamsCreate(&self.params) + cuvsKMeansParamsCreate_v2(&self.params) def __dealloc__(self): - check_cuvs(cuvsKMeansParamsDestroy(self.params)) + check_cuvs(cuvsKMeansParamsDestroy_v2(self.params)) def __init__(self, *, metric=None, @@ -338,7 +338,7 @@ def fit( cydlpack.dlpack_c(wrap_array(sample_weights)) with cuda_interruptible(): - check_cuvs(cuvsKMeansFit( + check_cuvs(cuvsKMeansFit_v2( res, params.params, x_dlpack, @@ -431,7 +431,7 @@ def predict( cdef double inertia = 0 with cuda_interruptible(): - check_cuvs(cuvsKMeansPredict( + check_cuvs(cuvsKMeansPredict_v2( res, params.params, x_dlpack, diff --git a/rust/cuvs/src/cluster/kmeans/mod.rs b/rust/cuvs/src/cluster/kmeans/mod.rs index 6fb0848e3d..74ac7732a7 100644 --- a/rust/cuvs/src/cluster/kmeans/mod.rs +++ b/rust/cuvs/src/cluster/kmeans/mod.rs @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -68,7 +68,7 @@ pub fn fit( Some(tensor) => tensor.as_ptr(), None => std::ptr::null_mut(), }; - check_cuvs(ffi::cuvsKMeansFit( + check_cuvs(ffi::cuvsKMeansFit_v2( res.0, params.0, x.as_ptr(), @@ -108,7 +108,7 @@ pub fn predict( Some(tensor) => tensor.as_ptr(), None => std::ptr::null_mut(), }; - check_cuvs(ffi::cuvsKMeansPredict( + check_cuvs(ffi::cuvsKMeansPredict_v2( res.0, params.0, x.as_ptr(), diff --git a/rust/cuvs/src/cluster/kmeans/params.rs b/rust/cuvs/src/cluster/kmeans/params.rs index 46e4957a32..b241cddc23 100644 --- a/rust/cuvs/src/cluster/kmeans/params.rs +++ b/rust/cuvs/src/cluster/kmeans/params.rs @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -8,14 +8,14 @@ use crate::error::{check_cuvs, Result}; use std::fmt; use std::io::{stderr, Write}; -pub struct Params(pub ffi::cuvsKMeansParams_t); +pub struct Params(pub ffi::cuvsKMeansParams_v2_t); impl Params { /// Returns a new Params pub fn new() -> Result { unsafe { - let mut params = std::mem::MaybeUninit::::uninit(); - check_cuvs(ffi::cuvsKMeansParamsCreate(params.as_mut_ptr()))?; + let mut params = std::mem::MaybeUninit::::uninit(); + check_cuvs(ffi::cuvsKMeansParamsCreate_v2(params.as_mut_ptr()))?; Ok(Params(params.assume_init())) } } @@ -115,9 +115,13 @@ impl fmt::Debug for Params { impl Drop for Params { fn drop(&mut self) { - if let Err(e) = check_cuvs(unsafe { ffi::cuvsKMeansParamsDestroy(self.0) }) { - write!(stderr(), "failed to call cuvsKMeansParamsDestroy {:?}", e) - .expect("failed to write to stderr"); + if let Err(e) = check_cuvs(unsafe { ffi::cuvsKMeansParamsDestroy_v2(self.0) }) { + write!( + stderr(), + "failed to call cuvsKMeansParamsDestroy_v2 {:?}", + e + ) + .expect("failed to write to stderr"); } } } From e2035ec5dd9d3b509f6b300a23e7ddcf26a053d3 Mon Sep 17 00:00:00 2001 From: tarangj Date: Thu, 30 Apr 2026 10:41:44 -0700 Subject: [PATCH 40/50] revert change to calls in py and rust; add c tests --- c/tests/CMakeLists.txt | 1 + c/tests/cluster/kmeans_c.cu | 211 +++++++++++++++++++++ c/tests/cluster/run_kmeans_c.c | 200 +++++++++++++++++++ python/cuvs/cuvs/cluster/kmeans/kmeans.pxd | 44 +++-- python/cuvs/cuvs/cluster/kmeans/kmeans.pyx | 10 +- rust/cuvs/src/cluster/kmeans/mod.rs | 6 +- rust/cuvs/src/cluster/kmeans/params.rs | 18 +- 7 files changed, 452 insertions(+), 38 deletions(-) create mode 100644 c/tests/cluster/kmeans_c.cu create mode 100644 c/tests/cluster/run_kmeans_c.c diff --git a/c/tests/CMakeLists.txt b/c/tests/CMakeLists.txt index a80f0b518a..351edc90b9 100644 --- a/c/tests/CMakeLists.txt +++ b/c/tests/CMakeLists.txt @@ -74,6 +74,7 @@ ConfigureTest(NAME INTEROP_TEST PATH core/interop.cu) ConfigureTest( NAME DISTANCE_C_TEST PATH distance/run_pairwise_distance_c.c distance/pairwise_distance_c.cu ) +ConfigureTest(NAME KMEANS_C_TEST PATH cluster/run_kmeans_c.c cluster/kmeans_c.cu) ConfigureTest(NAME BRUTEFORCE_C_TEST PATH neighbors/run_brute_force_c.c neighbors/brute_force_c.cu) ConfigureTest(NAME IVF_FLAT_C_TEST PATH neighbors/run_ivf_flat_c.c neighbors/ann_ivf_flat_c.cu) ConfigureTest(NAME IVF_PQ_C_TEST PATH neighbors/run_ivf_pq_c.c neighbors/ann_ivf_pq_c.cu) diff --git a/c/tests/cluster/kmeans_c.cu b/c/tests/cluster/kmeans_c.cu new file mode 100644 index 0000000000..fbf0054a4a --- /dev/null +++ b/c/tests/cluster/kmeans_c.cu @@ -0,0 +1,211 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + + #include "test_utils.cuh" + + #include + #include + + #include + #include + #include + + #include + #include + + #include + #include + + extern "C" void run_kmeans(int64_t n_samples, + int64_t n_features, + int n_clusters, + int max_iter, + double tol, + cuvsKMeansInitMethod init, + bool dataset_on_host, + int64_t streaming_batch_size, + void* dataset_data, + float* centroids_data, + int32_t* labels_data, + double* inertia_out, + int* n_iter_out, + double* predict_inertia_out, + double* cluster_cost_out); + + // TODO(cuVS 26.08): remove run_kmeans_v2 declaration once the `_v2` ABI is +// promoted to the unsuffixed names. +extern "C" void run_kmeans_v2(int64_t n_samples, + int64_t n_features, + int n_clusters, + int max_iter, + double tol, + cuvsKMeansInitMethod init, + bool dataset_on_host, + int64_t streaming_batch_size, + void* dataset_data, + float* centroids_data, + int32_t* labels_data, + double* inertia_out, + int* n_iter_out, + double* predict_inertia_out, + double* cluster_cost_out); + + namespace { + + constexpr int64_t kNSamples = 8; + constexpr int64_t kNFeatures = 2; + constexpr int kNClusters = 2; + + float kDataset[kNSamples][kNFeatures] = { + {1.0f, 1.0f}, + {1.0f, 2.0f}, + {2.0f, 1.0f}, + {2.0f, 2.0f}, + {10.0f, 10.0f}, + {10.0f, 11.0f}, + {11.0f, 10.0f}, + {11.0f, 11.0f}, + }; + +float kInitCentroids[kNClusters][kNFeatures] = { + {0.0f, 0.0f}, + {12.0f, 12.0f}, +}; + + float kExpectedCentroids[kNClusters * kNFeatures] = {1.5f, 1.5f, 10.5f, 10.5f}; + int32_t kExpectedLabels[kNSamples] = {0, 0, 0, 0, 1, 1, 1, 1}; + + // 8 points, each at squared distance 0.5 from its cluster mean -> 4.0. + constexpr double kExpectedInertia = 4.0; + + template + void test_fit_predict(RunFn run_fn) + { + raft::handle_t handle; + auto stream = raft::resource::get_cuda_stream(handle); + + rmm::device_uvector dataset_d(kNSamples * kNFeatures, stream); + rmm::device_uvector centroids_d(kNClusters * kNFeatures, stream); + rmm::device_uvector labels_d(kNSamples, stream); + + raft::copy(dataset_d.data(), + reinterpret_cast(kDataset), + kNSamples * kNFeatures, + stream); + raft::copy(centroids_d.data(), + reinterpret_cast(kInitCentroids), + kNClusters * kNFeatures, + stream); + + double inertia = -1.0; + int n_iter = -1; + double predict_inertia = -1.0; + double cluster_cost = -1.0; + + run_fn(kNSamples, + kNFeatures, + kNClusters, + 100, + 1e-6, + Array, + false, + 0, + dataset_d.data(), + centroids_d.data(), + labels_d.data(), + &inertia, + &n_iter, + &predict_inertia, + &cluster_cost); + + ASSERT_TRUE(cuvs::devArrMatchHost(kExpectedCentroids, + centroids_d.data(), + kNClusters * kNFeatures, + cuvs::CompareApprox(1e-4f))); + ASSERT_TRUE(cuvs::devArrMatchHost( + kExpectedLabels, labels_d.data(), kNSamples, cuvs::Compare())); + + EXPECT_GT(n_iter, 0); + EXPECT_NEAR(inertia, kExpectedInertia, 1e-4); + EXPECT_NEAR(predict_inertia, kExpectedInertia, 1e-4); + EXPECT_NEAR(cluster_cost, kExpectedInertia, 1e-4); + } + + template + void test_fit_host(RunFn run_fn) + { + raft::handle_t handle; + auto stream = raft::resource::get_cuda_stream(handle); + + rmm::device_uvector centroids_d(kNClusters * kNFeatures, stream); + raft::copy(centroids_d.data(), + reinterpret_cast(kInitCentroids), + kNClusters * kNFeatures, + stream); + + double inertia = -1.0; + int n_iter = -1; + double unused_predict = 0.0; + double unused_cost = 0.0; + + run_fn(kNSamples, + kNFeatures, + kNClusters, + 100, + 1e-6, + Array, + true, + 4, // force at least 2 streamed batches + reinterpret_cast(kDataset), + centroids_d.data(), + nullptr, + &inertia, + &n_iter, + &unused_predict, + &unused_cost); + + ASSERT_TRUE(cuvs::devArrMatchHost(kExpectedCentroids, + centroids_d.data(), + kNClusters * kNFeatures, + cuvs::CompareApprox(1e-4f))); + + EXPECT_GT(n_iter, 0); + EXPECT_NEAR(inertia, kExpectedInertia, 1e-4); + } + + } // namespace + +TEST(KMeansC, FitPredict) { test_fit_predict(run_kmeans); } +// TODO(cuVS 26.08): remove FitPredictV2 once `_v2` is promoted to the +// unsuffixed ABI -- it will be redundant with FitPredict at that point. +TEST(KMeansC, FitPredictV2) { test_fit_predict(run_kmeans_v2); } + +TEST(KMeansC, FitHost) { test_fit_host(run_kmeans); } +// TODO(cuVS 26.08): remove FitHostV2 once `_v2` is promoted to the +// unsuffixed ABI. +TEST(KMeansC, FitHostV2) { test_fit_host(run_kmeans_v2); } + + TEST(KMeansC, ParamsCreateDestroy) + { + cuvsKMeansParams_t params = nullptr; + ASSERT_EQ(cuvsKMeansParamsCreate(¶ms), CUVS_SUCCESS); + ASSERT_NE(params, nullptr); + EXPECT_GT(params->n_clusters, 0); + EXPECT_GT(params->max_iter, 0); + ASSERT_EQ(cuvsKMeansParamsDestroy(params), CUVS_SUCCESS); + } + +// TODO(cuVS 26.08): remove ParamsCreateDestroyV2 once cuvsKMeansParamsCreate_v2 +// / cuvsKMeansParamsDestroy_v2 are promoted to the unsuffixed entry points and +// the `_v2` symbols are deleted from the public header. +TEST(KMeansC, ParamsCreateDestroyV2) +{ + cuvsKMeansParams_v2_t params = nullptr; + ASSERT_EQ(cuvsKMeansParamsCreate_v2(¶ms), CUVS_SUCCESS); + ASSERT_NE(params, nullptr); + EXPECT_GT(params->n_clusters, 0); + EXPECT_GT(params->max_iter, 0); + ASSERT_EQ(cuvsKMeansParamsDestroy_v2(params), CUVS_SUCCESS); + } diff --git a/c/tests/cluster/run_kmeans_c.c b/c/tests/cluster/run_kmeans_c.c new file mode 100644 index 0000000000..c2b140db9d --- /dev/null +++ b/c/tests/cluster/run_kmeans_c.c @@ -0,0 +1,200 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + + #include + #include + #include + #include + #include + #include + + static void fill_matrix_tensor(DLManagedTensor* t, + void* data, + int64_t* shape, + DLDeviceType device_type, + uint8_t code, + uint8_t bits) + { + t->dl_tensor.data = data; + t->dl_tensor.device.device_type = device_type; + t->dl_tensor.device.device_id = 0; + t->dl_tensor.ndim = 2; + t->dl_tensor.dtype.code = code; + t->dl_tensor.dtype.bits = bits; + t->dl_tensor.dtype.lanes = 1; + t->dl_tensor.shape = shape; + t->dl_tensor.strides = NULL; + t->dl_tensor.byte_offset = 0; + t->manager_ctx = NULL; + t->deleter = NULL; + } + + static void fill_vector_tensor(DLManagedTensor* t, + void* data, + int64_t* shape, + DLDeviceType device_type, + uint8_t code, + uint8_t bits) + { + t->dl_tensor.data = data; + t->dl_tensor.device.device_type = device_type; + t->dl_tensor.device.device_id = 0; + t->dl_tensor.ndim = 1; + t->dl_tensor.dtype.code = code; + t->dl_tensor.dtype.bits = bits; + t->dl_tensor.dtype.lanes = 1; + t->dl_tensor.shape = shape; + t->dl_tensor.strides = NULL; + t->dl_tensor.byte_offset = 0; + t->manager_ctx = NULL; + t->deleter = NULL; + } + + /** + * Run KMeans fit + (optional) predict + cluster_cost using the C API. + * + * If `dataset_on_host` is true, `dataset_data` is a host pointer, otherwise it is a + * device pointer. `centroids_data` and `labels_data` are always device pointers. + * + * `predict_inertia_out`/`labels_data`/`cluster_cost_out` are only used when + * `dataset_on_host` is false (predict + cluster_cost require device data). + */ + void run_kmeans(int64_t n_samples, + int64_t n_features, + int n_clusters, + int max_iter, + double tol, + cuvsKMeansInitMethod init, + bool dataset_on_host, + int64_t streaming_batch_size, + void* dataset_data, + float* centroids_data, + int32_t* labels_data, + double* inertia_out, + int* n_iter_out, + double* predict_inertia_out, + double* cluster_cost_out) + { + cuvsResources_t res; + cuvsResourcesCreate(&res); + + cuvsKMeansParams_t params; + cuvsKMeansParamsCreate(¶ms); + params->n_clusters = n_clusters; + params->max_iter = max_iter; + params->tol = tol; + params->init = init; + params->streaming_batch_size = streaming_batch_size; + + DLManagedTensor dataset_tensor; + int64_t dataset_shape[2] = {n_samples, n_features}; + fill_matrix_tensor(&dataset_tensor, + dataset_data, + dataset_shape, + dataset_on_host ? kDLCPU : kDLCUDA, + kDLFloat, + 32); + + DLManagedTensor centroids_tensor; + int64_t centroids_shape[2] = {n_clusters, n_features}; + fill_matrix_tensor( + ¢roids_tensor, centroids_data, centroids_shape, kDLCUDA, kDLFloat, 32); + + cuvsKMeansFit( + res, params, &dataset_tensor, NULL, ¢roids_tensor, inertia_out, n_iter_out); + + if (!dataset_on_host) { + DLManagedTensor labels_tensor; + int64_t labels_shape[1] = {n_samples}; + fill_vector_tensor(&labels_tensor, labels_data, labels_shape, kDLCUDA, kDLInt, 32); + + cuvsKMeansPredict(res, + params, + &dataset_tensor, + NULL, + ¢roids_tensor, + &labels_tensor, + false, + predict_inertia_out); + + cuvsKMeansClusterCost(res, &dataset_tensor, ¢roids_tensor, cluster_cost_out); + } + + cuvsKMeansParamsDestroy(params); + cuvsResourcesDestroy(res); + } + +/** + * Run KMeans fit + (optional) predict + cluster_cost. + * + * TODO(cuVS 26.08): delete run_kmeans_v2 once the `_v2` entry points + * (cuvsKMeansFit_v2 / cuvsKMeansPredict_v2 / cuvsKMeansParamsCreate_v2 / + * cuvsKMeansParamsDestroy_v2) are promoted to the unsuffixed names in the + * public header. + */ +void run_kmeans_v2(int64_t n_samples, + int64_t n_features, + int n_clusters, + int max_iter, + double tol, + cuvsKMeansInitMethod init, + bool dataset_on_host, + int64_t streaming_batch_size, + void* dataset_data, + float* centroids_data, + int32_t* labels_data, + double* inertia_out, + int* n_iter_out, + double* predict_inertia_out, + double* cluster_cost_out) + { + cuvsResources_t res; + cuvsResourcesCreate(&res); + + cuvsKMeansParams_v2_t params; + cuvsKMeansParamsCreate_v2(¶ms); + params->n_clusters = n_clusters; + params->max_iter = max_iter; + params->tol = tol; + params->init = init; + params->streaming_batch_size = streaming_batch_size; + + DLManagedTensor dataset_tensor; + int64_t dataset_shape[2] = {n_samples, n_features}; + fill_matrix_tensor(&dataset_tensor, + dataset_data, + dataset_shape, + dataset_on_host ? kDLCPU : kDLCUDA, + kDLFloat, + 32); + + DLManagedTensor centroids_tensor; + int64_t centroids_shape[2] = {n_clusters, n_features}; + fill_matrix_tensor( + ¢roids_tensor, centroids_data, centroids_shape, kDLCUDA, kDLFloat, 32); + + cuvsKMeansFit_v2( + res, params, &dataset_tensor, NULL, ¢roids_tensor, inertia_out, n_iter_out); + + if (!dataset_on_host) { + DLManagedTensor labels_tensor; + int64_t labels_shape[1] = {n_samples}; + fill_vector_tensor(&labels_tensor, labels_data, labels_shape, kDLCUDA, kDLInt, 32); + + cuvsKMeansPredict_v2(res, + params, + &dataset_tensor, + NULL, + ¢roids_tensor, + &labels_tensor, + false, + predict_inertia_out); + + cuvsKMeansClusterCost(res, &dataset_tensor, ¢roids_tensor, cluster_cost_out); + } + + cuvsKMeansParamsDestroy_v2(params); + cuvsResourcesDestroy(res); + } diff --git a/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd b/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd index 38f0b14a91..975ef386df 100644 --- a/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd +++ b/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd @@ -22,7 +22,12 @@ cdef extern from "cuvs/cluster/kmeans.h" nogil: CUVS_KMEANS_TYPE_KMEANS CUVS_KMEANS_TYPE_KMEANS_BALANCED - ctypedef struct cuvsKMeansParams_v2: + # NOTE: The Python binding currently targets the unsuffixed cuvsKMeansParams + # ABI (which still carries the deprecated `inertia_check` field). In cuVS + # 26.08 this struct/entry-point set will be replaced by the contents of + # cuvsKMeansParams_v2 -- once that lands, the `inertia_check` field below + # should be deleted. + ctypedef struct cuvsKMeansParams: cuvsDistanceType metric, int n_clusters, cuvsKMeansInitMethod init, @@ -32,33 +37,34 @@ cdef extern from "cuvs/cluster/kmeans.h" nogil: double oversampling_factor, int batch_samples, int batch_centroids, + bool inertia_check, bool hierarchical, int hierarchical_n_iters, int64_t streaming_batch_size, int64_t init_size - ctypedef cuvsKMeansParams_v2* cuvsKMeansParams_v2_t + ctypedef cuvsKMeansParams* cuvsKMeansParams_t - cuvsError_t cuvsKMeansParamsCreate_v2(cuvsKMeansParams_v2_t* index) + cuvsError_t cuvsKMeansParamsCreate(cuvsKMeansParams_t* index) - cuvsError_t cuvsKMeansParamsDestroy_v2(cuvsKMeansParams_v2_t index) + cuvsError_t cuvsKMeansParamsDestroy(cuvsKMeansParams_t index) - cuvsError_t cuvsKMeansFit_v2(cuvsResources_t res, - cuvsKMeansParams_v2_t params, - DLManagedTensor* X, - DLManagedTensor* sample_weight, - DLManagedTensor * centroids, - double * inertia, - int * n_iter) except + + cuvsError_t cuvsKMeansFit(cuvsResources_t res, + cuvsKMeansParams_t params, + DLManagedTensor* X, + DLManagedTensor* sample_weight, + DLManagedTensor * centroids, + double * inertia, + int * n_iter) except + - cuvsError_t cuvsKMeansPredict_v2(cuvsResources_t res, - cuvsKMeansParams_v2_t params, - DLManagedTensor* X, - DLManagedTensor* sample_weight, - DLManagedTensor * centroids, - DLManagedTensor * labels, - bool normalize_weight, - double * inertia) + cuvsError_t cuvsKMeansPredict(cuvsResources_t res, + cuvsKMeansParams_t params, + DLManagedTensor* X, + DLManagedTensor* sample_weight, + DLManagedTensor * centroids, + DLManagedTensor * labels, + bool normalize_weight, + double * inertia) cuvsError_t cuvsKMeansClusterCost(cuvsResources_t res, DLManagedTensor* X, diff --git a/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx b/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx index 2efc969698..2e9046b4b2 100644 --- a/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx +++ b/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx @@ -100,13 +100,13 @@ cdef class KMeansParams: For hierarchical k-means , defines the number of training iterations """ - cdef cuvsKMeansParams_v2* params + cdef cuvsKMeansParams* params def __cinit__(self): - cuvsKMeansParamsCreate_v2(&self.params) + cuvsKMeansParamsCreate(&self.params) def __dealloc__(self): - check_cuvs(cuvsKMeansParamsDestroy_v2(self.params)) + check_cuvs(cuvsKMeansParamsDestroy(self.params)) def __init__(self, *, metric=None, @@ -338,7 +338,7 @@ def fit( cydlpack.dlpack_c(wrap_array(sample_weights)) with cuda_interruptible(): - check_cuvs(cuvsKMeansFit_v2( + check_cuvs(cuvsKMeansFit( res, params.params, x_dlpack, @@ -431,7 +431,7 @@ def predict( cdef double inertia = 0 with cuda_interruptible(): - check_cuvs(cuvsKMeansPredict_v2( + check_cuvs(cuvsKMeansPredict( res, params.params, x_dlpack, diff --git a/rust/cuvs/src/cluster/kmeans/mod.rs b/rust/cuvs/src/cluster/kmeans/mod.rs index 74ac7732a7..6fb0848e3d 100644 --- a/rust/cuvs/src/cluster/kmeans/mod.rs +++ b/rust/cuvs/src/cluster/kmeans/mod.rs @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -68,7 +68,7 @@ pub fn fit( Some(tensor) => tensor.as_ptr(), None => std::ptr::null_mut(), }; - check_cuvs(ffi::cuvsKMeansFit_v2( + check_cuvs(ffi::cuvsKMeansFit( res.0, params.0, x.as_ptr(), @@ -108,7 +108,7 @@ pub fn predict( Some(tensor) => tensor.as_ptr(), None => std::ptr::null_mut(), }; - check_cuvs(ffi::cuvsKMeansPredict_v2( + check_cuvs(ffi::cuvsKMeansPredict( res.0, params.0, x.as_ptr(), diff --git a/rust/cuvs/src/cluster/kmeans/params.rs b/rust/cuvs/src/cluster/kmeans/params.rs index b241cddc23..46e4957a32 100644 --- a/rust/cuvs/src/cluster/kmeans/params.rs +++ b/rust/cuvs/src/cluster/kmeans/params.rs @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -8,14 +8,14 @@ use crate::error::{check_cuvs, Result}; use std::fmt; use std::io::{stderr, Write}; -pub struct Params(pub ffi::cuvsKMeansParams_v2_t); +pub struct Params(pub ffi::cuvsKMeansParams_t); impl Params { /// Returns a new Params pub fn new() -> Result { unsafe { - let mut params = std::mem::MaybeUninit::::uninit(); - check_cuvs(ffi::cuvsKMeansParamsCreate_v2(params.as_mut_ptr()))?; + let mut params = std::mem::MaybeUninit::::uninit(); + check_cuvs(ffi::cuvsKMeansParamsCreate(params.as_mut_ptr()))?; Ok(Params(params.assume_init())) } } @@ -115,13 +115,9 @@ impl fmt::Debug for Params { impl Drop for Params { fn drop(&mut self) { - if let Err(e) = check_cuvs(unsafe { ffi::cuvsKMeansParamsDestroy_v2(self.0) }) { - write!( - stderr(), - "failed to call cuvsKMeansParamsDestroy_v2 {:?}", - e - ) - .expect("failed to write to stderr"); + if let Err(e) = check_cuvs(unsafe { ffi::cuvsKMeansParamsDestroy(self.0) }) { + write!(stderr(), "failed to call cuvsKMeansParamsDestroy {:?}", e) + .expect("failed to write to stderr"); } } } From 55bbdadb25b6ef651b8513724a4fda31d27e6237 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Mon, 4 May 2026 23:04:48 -0700 Subject: [PATCH 41/50] use to_dlpack --- c/tests/CMakeLists.txt | 2 +- c/tests/cluster/kmeans_c.cu | 431 ++++++++++++++++++--------------- c/tests/cluster/run_kmeans_c.c | 200 --------------- 3 files changed, 243 insertions(+), 390 deletions(-) delete mode 100644 c/tests/cluster/run_kmeans_c.c diff --git a/c/tests/CMakeLists.txt b/c/tests/CMakeLists.txt index 351edc90b9..c1f502adf4 100644 --- a/c/tests/CMakeLists.txt +++ b/c/tests/CMakeLists.txt @@ -74,7 +74,7 @@ ConfigureTest(NAME INTEROP_TEST PATH core/interop.cu) ConfigureTest( NAME DISTANCE_C_TEST PATH distance/run_pairwise_distance_c.c distance/pairwise_distance_c.cu ) -ConfigureTest(NAME KMEANS_C_TEST PATH cluster/run_kmeans_c.c cluster/kmeans_c.cu) +ConfigureTest(NAME KMEANS_C_TEST PATH cluster/kmeans_c.cu) ConfigureTest(NAME BRUTEFORCE_C_TEST PATH neighbors/run_brute_force_c.c neighbors/brute_force_c.cu) ConfigureTest(NAME IVF_FLAT_C_TEST PATH neighbors/run_ivf_flat_c.c neighbors/ann_ivf_flat_c.cu) ConfigureTest(NAME IVF_PQ_C_TEST PATH neighbors/run_ivf_pq_c.c neighbors/ann_ivf_pq_c.cu) diff --git a/c/tests/cluster/kmeans_c.cu b/c/tests/cluster/kmeans_c.cu index fbf0054a4a..3c87d035d3 100644 --- a/c/tests/cluster/kmeans_c.cu +++ b/c/tests/cluster/kmeans_c.cu @@ -3,209 +3,262 @@ * SPDX-License-Identifier: Apache-2.0 */ - #include "test_utils.cuh" - - #include - #include - - #include - #include - #include - - #include - #include - - #include - #include - - extern "C" void run_kmeans(int64_t n_samples, - int64_t n_features, - int n_clusters, - int max_iter, - double tol, - cuvsKMeansInitMethod init, - bool dataset_on_host, - int64_t streaming_batch_size, - void* dataset_data, - float* centroids_data, - int32_t* labels_data, - double* inertia_out, - int* n_iter_out, - double* predict_inertia_out, - double* cluster_cost_out); - - // TODO(cuVS 26.08): remove run_kmeans_v2 declaration once the `_v2` ABI is -// promoted to the unsuffixed names. -extern "C" void run_kmeans_v2(int64_t n_samples, - int64_t n_features, - int n_clusters, - int max_iter, - double tol, - cuvsKMeansInitMethod init, - bool dataset_on_host, - int64_t streaming_batch_size, - void* dataset_data, - float* centroids_data, - int32_t* labels_data, - double* inertia_out, - int* n_iter_out, - double* predict_inertia_out, - double* cluster_cost_out); - - namespace { - - constexpr int64_t kNSamples = 8; - constexpr int64_t kNFeatures = 2; - constexpr int kNClusters = 2; - - float kDataset[kNSamples][kNFeatures] = { - {1.0f, 1.0f}, - {1.0f, 2.0f}, - {2.0f, 1.0f}, - {2.0f, 2.0f}, - {10.0f, 10.0f}, - {10.0f, 11.0f}, - {11.0f, 10.0f}, - {11.0f, 11.0f}, - }; +#include "test_utils.cuh" + +#include +#include + +#include +#include +#include +#include +#include + +#include "../../src/core/interop.hpp" +#include +#include + +#include + +namespace { + +constexpr int64_t kNSamples = 8; +constexpr int64_t kNFeatures = 2; +constexpr int kNClusters = 2; + +float kDataset[kNSamples][kNFeatures] = { + {1.0f, 1.0f}, + {1.0f, 2.0f}, + {2.0f, 1.0f}, + {2.0f, 2.0f}, + {10.0f, 10.0f}, + {10.0f, 11.0f}, + {11.0f, 10.0f}, + {11.0f, 11.0f}, +}; float kInitCentroids[kNClusters][kNFeatures] = { {0.0f, 0.0f}, {12.0f, 12.0f}, }; - float kExpectedCentroids[kNClusters * kNFeatures] = {1.5f, 1.5f, 10.5f, 10.5f}; - int32_t kExpectedLabels[kNSamples] = {0, 0, 0, 0, 1, 1, 1, 1}; - - // 8 points, each at squared distance 0.5 from its cluster mean -> 4.0. - constexpr double kExpectedInertia = 4.0; - - template - void test_fit_predict(RunFn run_fn) - { - raft::handle_t handle; - auto stream = raft::resource::get_cuda_stream(handle); - - rmm::device_uvector dataset_d(kNSamples * kNFeatures, stream); - rmm::device_uvector centroids_d(kNClusters * kNFeatures, stream); - rmm::device_uvector labels_d(kNSamples, stream); - - raft::copy(dataset_d.data(), - reinterpret_cast(kDataset), - kNSamples * kNFeatures, - stream); - raft::copy(centroids_d.data(), - reinterpret_cast(kInitCentroids), - kNClusters * kNFeatures, - stream); - - double inertia = -1.0; - int n_iter = -1; - double predict_inertia = -1.0; - double cluster_cost = -1.0; - - run_fn(kNSamples, - kNFeatures, - kNClusters, - 100, - 1e-6, - Array, - false, - 0, - dataset_d.data(), - centroids_d.data(), - labels_d.data(), - &inertia, - &n_iter, - &predict_inertia, - &cluster_cost); - - ASSERT_TRUE(cuvs::devArrMatchHost(kExpectedCentroids, - centroids_d.data(), - kNClusters * kNFeatures, - cuvs::CompareApprox(1e-4f))); - ASSERT_TRUE(cuvs::devArrMatchHost( - kExpectedLabels, labels_d.data(), kNSamples, cuvs::Compare())); - - EXPECT_GT(n_iter, 0); - EXPECT_NEAR(inertia, kExpectedInertia, 1e-4); - EXPECT_NEAR(predict_inertia, kExpectedInertia, 1e-4); - EXPECT_NEAR(cluster_cost, kExpectedInertia, 1e-4); - } - - template - void test_fit_host(RunFn run_fn) - { - raft::handle_t handle; - auto stream = raft::resource::get_cuda_stream(handle); - - rmm::device_uvector centroids_d(kNClusters * kNFeatures, stream); - raft::copy(centroids_d.data(), - reinterpret_cast(kInitCentroids), - kNClusters * kNFeatures, - stream); - - double inertia = -1.0; - int n_iter = -1; - double unused_predict = 0.0; - double unused_cost = 0.0; - - run_fn(kNSamples, - kNFeatures, - kNClusters, - 100, - 1e-6, - Array, - true, - 4, // force at least 2 streamed batches - reinterpret_cast(kDataset), - centroids_d.data(), - nullptr, - &inertia, - &n_iter, - &unused_predict, - &unused_cost); - - ASSERT_TRUE(cuvs::devArrMatchHost(kExpectedCentroids, - centroids_d.data(), - kNClusters * kNFeatures, - cuvs::CompareApprox(1e-4f))); - - EXPECT_GT(n_iter, 0); - EXPECT_NEAR(inertia, kExpectedInertia, 1e-4); - } - - } // namespace - -TEST(KMeansC, FitPredict) { test_fit_predict(run_kmeans); } +float kExpectedCentroids[kNClusters * kNFeatures] = {1.5f, 1.5f, 10.5f, 10.5f}; +int32_t kExpectedLabels[kNSamples] = {0, 0, 0, 0, 1, 1, 1, 1}; + +// 8 points, each at squared distance 0.5 from its cluster mean -> 4.0. +constexpr double kExpectedInertia = 4.0; + +// Type-erased dispatcher to exercise both the v1 and v2 entry points with +// shared test bodies. +struct kmeans_api_v1 { + using params_t = cuvsKMeansParams_t; + static cuvsError_t params_create(params_t* p) { return cuvsKMeansParamsCreate(p); } + static cuvsError_t params_destroy(params_t p) { return cuvsKMeansParamsDestroy(p); } + static cuvsError_t fit(cuvsResources_t res, + params_t params, + DLManagedTensor* dataset, + DLManagedTensor* centroids, + double* inertia, + int* n_iter) + { + return cuvsKMeansFit(res, params, dataset, NULL, centroids, inertia, n_iter); + } + static cuvsError_t predict(cuvsResources_t res, + params_t params, + DLManagedTensor* dataset, + DLManagedTensor* centroids, + DLManagedTensor* labels, + double* inertia) + { + return cuvsKMeansPredict( + res, params, dataset, NULL, centroids, labels, false, inertia); + } +}; + +struct kmeans_api_v2 { + using params_t = cuvsKMeansParams_v2_t; + static cuvsError_t params_create(params_t* p) { return cuvsKMeansParamsCreate_v2(p); } + static cuvsError_t params_destroy(params_t p) { return cuvsKMeansParamsDestroy_v2(p); } + static cuvsError_t fit(cuvsResources_t res, + params_t params, + DLManagedTensor* dataset, + DLManagedTensor* centroids, + double* inertia, + int* n_iter) + { + return cuvsKMeansFit_v2(res, params, dataset, NULL, centroids, inertia, n_iter); + } + static cuvsError_t predict(cuvsResources_t res, + params_t params, + DLManagedTensor* dataset, + DLManagedTensor* centroids, + DLManagedTensor* labels, + double* inertia) + { + return cuvsKMeansPredict_v2( + res, params, dataset, NULL, centroids, labels, false, inertia); + } +}; + +template +void test_fit_predict() +{ + raft::handle_t handle; + auto stream = raft::resource::get_cuda_stream(handle); + + rmm::device_uvector dataset_d(kNSamples * kNFeatures, stream); + rmm::device_uvector centroids_d(kNClusters * kNFeatures, stream); + rmm::device_uvector labels_d(kNSamples, stream); + + raft::copy(dataset_d.data(), + reinterpret_cast(kDataset), + kNSamples * kNFeatures, + stream); + raft::copy(centroids_d.data(), + reinterpret_cast(kInitCentroids), + kNClusters * kNFeatures, + stream); + + cuvsResources_t res; + ASSERT_EQ(cuvsResourcesCreate(&res), CUVS_SUCCESS); + + typename Api::params_t params; + ASSERT_EQ(Api::params_create(¶ms), CUVS_SUCCESS); + params->n_clusters = kNClusters; + params->max_iter = 100; + params->tol = 1e-6; + params->init = Array; + params->streaming_batch_size = 0; + + DLManagedTensor dataset_t{}; + cuvs::core::to_dlpack( + raft::make_device_matrix_view(dataset_d.data(), kNSamples, kNFeatures), + &dataset_t); + + DLManagedTensor centroids_t{}; + cuvs::core::to_dlpack( + raft::make_device_matrix_view(centroids_d.data(), kNClusters, kNFeatures), + ¢roids_t); + + DLManagedTensor labels_t{}; + cuvs::core::to_dlpack( + raft::make_device_vector_view(labels_d.data(), kNSamples), &labels_t); + + double inertia = -1.0; + int n_iter = -1; + double predict_inertia = -1.0; + double cluster_cost = -1.0; + + ASSERT_EQ(Api::fit(res, params, &dataset_t, ¢roids_t, &inertia, &n_iter), CUVS_SUCCESS); + ASSERT_EQ(Api::predict(res, params, &dataset_t, ¢roids_t, &labels_t, &predict_inertia), + CUVS_SUCCESS); + ASSERT_EQ(cuvsKMeansClusterCost(res, &dataset_t, ¢roids_t, &cluster_cost), CUVS_SUCCESS); + + ASSERT_TRUE(cuvs::devArrMatchHost(kExpectedCentroids, + centroids_d.data(), + kNClusters * kNFeatures, + cuvs::CompareApprox(1e-4f))); + ASSERT_TRUE(cuvs::devArrMatchHost( + kExpectedLabels, labels_d.data(), kNSamples, cuvs::Compare())); + + EXPECT_GT(n_iter, 0); + EXPECT_NEAR(inertia, kExpectedInertia, 1e-4); + EXPECT_NEAR(predict_inertia, kExpectedInertia, 1e-4); + EXPECT_NEAR(cluster_cost, kExpectedInertia, 1e-4); + + labels_t.deleter(&labels_t); + centroids_t.deleter(¢roids_t); + dataset_t.deleter(&dataset_t); + + ASSERT_EQ(Api::params_destroy(params), CUVS_SUCCESS); + ASSERT_EQ(cuvsResourcesDestroy(res), CUVS_SUCCESS); +} + +template +void test_fit_host() +{ + raft::handle_t handle; + auto stream = raft::resource::get_cuda_stream(handle); + + rmm::device_uvector centroids_d(kNClusters * kNFeatures, stream); + raft::copy(centroids_d.data(), + reinterpret_cast(kInitCentroids), + kNClusters * kNFeatures, + stream); + + cuvsResources_t res; + ASSERT_EQ(cuvsResourcesCreate(&res), CUVS_SUCCESS); + + typename Api::params_t params; + ASSERT_EQ(Api::params_create(¶ms), CUVS_SUCCESS); + params->n_clusters = kNClusters; + params->max_iter = 100; + params->tol = 1e-6; + params->init = Array; + params->streaming_batch_size = 4; // force at least 2 streamed batches + + DLManagedTensor dataset_t{}; + cuvs::core::to_dlpack( + raft::make_host_matrix_view( + reinterpret_cast(kDataset), kNSamples, kNFeatures), + &dataset_t); + + DLManagedTensor centroids_t{}; + cuvs::core::to_dlpack( + raft::make_device_matrix_view(centroids_d.data(), kNClusters, kNFeatures), + ¢roids_t); + + double inertia = -1.0; + int n_iter = -1; + + ASSERT_EQ(Api::fit(res, params, &dataset_t, ¢roids_t, &inertia, &n_iter), CUVS_SUCCESS); + + ASSERT_TRUE(cuvs::devArrMatchHost(kExpectedCentroids, + centroids_d.data(), + kNClusters * kNFeatures, + cuvs::CompareApprox(1e-4f))); + + EXPECT_GT(n_iter, 0); + EXPECT_NEAR(inertia, kExpectedInertia, 1e-4); + + centroids_t.deleter(¢roids_t); + dataset_t.deleter(&dataset_t); + + ASSERT_EQ(Api::params_destroy(params), CUVS_SUCCESS); + ASSERT_EQ(cuvsResourcesDestroy(res), CUVS_SUCCESS); +} + +} // namespace + +TEST(KMeansC, FitPredict) { test_fit_predict(); } // TODO(cuVS 26.08): remove FitPredictV2 once `_v2` is promoted to the // unsuffixed ABI -- it will be redundant with FitPredict at that point. -TEST(KMeansC, FitPredictV2) { test_fit_predict(run_kmeans_v2); } +TEST(KMeansC, FitPredictV2) { test_fit_predict(); } -TEST(KMeansC, FitHost) { test_fit_host(run_kmeans); } +TEST(KMeansC, FitHost) { test_fit_host(); } // TODO(cuVS 26.08): remove FitHostV2 once `_v2` is promoted to the // unsuffixed ABI. -TEST(KMeansC, FitHostV2) { test_fit_host(run_kmeans_v2); } - - TEST(KMeansC, ParamsCreateDestroy) - { - cuvsKMeansParams_t params = nullptr; - ASSERT_EQ(cuvsKMeansParamsCreate(¶ms), CUVS_SUCCESS); - ASSERT_NE(params, nullptr); - EXPECT_GT(params->n_clusters, 0); - EXPECT_GT(params->max_iter, 0); - ASSERT_EQ(cuvsKMeansParamsDestroy(params), CUVS_SUCCESS); - } +TEST(KMeansC, FitHostV2) { test_fit_host(); } + +TEST(KMeansC, ParamsCreateDestroy) +{ + cuvsKMeansParams_t params = nullptr; + ASSERT_EQ(cuvsKMeansParamsCreate(¶ms), CUVS_SUCCESS); + ASSERT_NE(params, nullptr); + EXPECT_GT(params->n_clusters, 0); + EXPECT_GT(params->max_iter, 0); + ASSERT_EQ(cuvsKMeansParamsDestroy(params), CUVS_SUCCESS); +} // TODO(cuVS 26.08): remove ParamsCreateDestroyV2 once cuvsKMeansParamsCreate_v2 // / cuvsKMeansParamsDestroy_v2 are promoted to the unsuffixed entry points and // the `_v2` symbols are deleted from the public header. TEST(KMeansC, ParamsCreateDestroyV2) { - cuvsKMeansParams_v2_t params = nullptr; - ASSERT_EQ(cuvsKMeansParamsCreate_v2(¶ms), CUVS_SUCCESS); - ASSERT_NE(params, nullptr); - EXPECT_GT(params->n_clusters, 0); - EXPECT_GT(params->max_iter, 0); - ASSERT_EQ(cuvsKMeansParamsDestroy_v2(params), CUVS_SUCCESS); - } + cuvsKMeansParams_v2_t params = nullptr; + ASSERT_EQ(cuvsKMeansParamsCreate_v2(¶ms), CUVS_SUCCESS); + ASSERT_NE(params, nullptr); + EXPECT_GT(params->n_clusters, 0); + EXPECT_GT(params->max_iter, 0); + ASSERT_EQ(cuvsKMeansParamsDestroy_v2(params), CUVS_SUCCESS); +} diff --git a/c/tests/cluster/run_kmeans_c.c b/c/tests/cluster/run_kmeans_c.c deleted file mode 100644 index c2b140db9d..0000000000 --- a/c/tests/cluster/run_kmeans_c.c +++ /dev/null @@ -1,200 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. - * SPDX-License-Identifier: Apache-2.0 - */ - - #include - #include - #include - #include - #include - #include - - static void fill_matrix_tensor(DLManagedTensor* t, - void* data, - int64_t* shape, - DLDeviceType device_type, - uint8_t code, - uint8_t bits) - { - t->dl_tensor.data = data; - t->dl_tensor.device.device_type = device_type; - t->dl_tensor.device.device_id = 0; - t->dl_tensor.ndim = 2; - t->dl_tensor.dtype.code = code; - t->dl_tensor.dtype.bits = bits; - t->dl_tensor.dtype.lanes = 1; - t->dl_tensor.shape = shape; - t->dl_tensor.strides = NULL; - t->dl_tensor.byte_offset = 0; - t->manager_ctx = NULL; - t->deleter = NULL; - } - - static void fill_vector_tensor(DLManagedTensor* t, - void* data, - int64_t* shape, - DLDeviceType device_type, - uint8_t code, - uint8_t bits) - { - t->dl_tensor.data = data; - t->dl_tensor.device.device_type = device_type; - t->dl_tensor.device.device_id = 0; - t->dl_tensor.ndim = 1; - t->dl_tensor.dtype.code = code; - t->dl_tensor.dtype.bits = bits; - t->dl_tensor.dtype.lanes = 1; - t->dl_tensor.shape = shape; - t->dl_tensor.strides = NULL; - t->dl_tensor.byte_offset = 0; - t->manager_ctx = NULL; - t->deleter = NULL; - } - - /** - * Run KMeans fit + (optional) predict + cluster_cost using the C API. - * - * If `dataset_on_host` is true, `dataset_data` is a host pointer, otherwise it is a - * device pointer. `centroids_data` and `labels_data` are always device pointers. - * - * `predict_inertia_out`/`labels_data`/`cluster_cost_out` are only used when - * `dataset_on_host` is false (predict + cluster_cost require device data). - */ - void run_kmeans(int64_t n_samples, - int64_t n_features, - int n_clusters, - int max_iter, - double tol, - cuvsKMeansInitMethod init, - bool dataset_on_host, - int64_t streaming_batch_size, - void* dataset_data, - float* centroids_data, - int32_t* labels_data, - double* inertia_out, - int* n_iter_out, - double* predict_inertia_out, - double* cluster_cost_out) - { - cuvsResources_t res; - cuvsResourcesCreate(&res); - - cuvsKMeansParams_t params; - cuvsKMeansParamsCreate(¶ms); - params->n_clusters = n_clusters; - params->max_iter = max_iter; - params->tol = tol; - params->init = init; - params->streaming_batch_size = streaming_batch_size; - - DLManagedTensor dataset_tensor; - int64_t dataset_shape[2] = {n_samples, n_features}; - fill_matrix_tensor(&dataset_tensor, - dataset_data, - dataset_shape, - dataset_on_host ? kDLCPU : kDLCUDA, - kDLFloat, - 32); - - DLManagedTensor centroids_tensor; - int64_t centroids_shape[2] = {n_clusters, n_features}; - fill_matrix_tensor( - ¢roids_tensor, centroids_data, centroids_shape, kDLCUDA, kDLFloat, 32); - - cuvsKMeansFit( - res, params, &dataset_tensor, NULL, ¢roids_tensor, inertia_out, n_iter_out); - - if (!dataset_on_host) { - DLManagedTensor labels_tensor; - int64_t labels_shape[1] = {n_samples}; - fill_vector_tensor(&labels_tensor, labels_data, labels_shape, kDLCUDA, kDLInt, 32); - - cuvsKMeansPredict(res, - params, - &dataset_tensor, - NULL, - ¢roids_tensor, - &labels_tensor, - false, - predict_inertia_out); - - cuvsKMeansClusterCost(res, &dataset_tensor, ¢roids_tensor, cluster_cost_out); - } - - cuvsKMeansParamsDestroy(params); - cuvsResourcesDestroy(res); - } - -/** - * Run KMeans fit + (optional) predict + cluster_cost. - * - * TODO(cuVS 26.08): delete run_kmeans_v2 once the `_v2` entry points - * (cuvsKMeansFit_v2 / cuvsKMeansPredict_v2 / cuvsKMeansParamsCreate_v2 / - * cuvsKMeansParamsDestroy_v2) are promoted to the unsuffixed names in the - * public header. - */ -void run_kmeans_v2(int64_t n_samples, - int64_t n_features, - int n_clusters, - int max_iter, - double tol, - cuvsKMeansInitMethod init, - bool dataset_on_host, - int64_t streaming_batch_size, - void* dataset_data, - float* centroids_data, - int32_t* labels_data, - double* inertia_out, - int* n_iter_out, - double* predict_inertia_out, - double* cluster_cost_out) - { - cuvsResources_t res; - cuvsResourcesCreate(&res); - - cuvsKMeansParams_v2_t params; - cuvsKMeansParamsCreate_v2(¶ms); - params->n_clusters = n_clusters; - params->max_iter = max_iter; - params->tol = tol; - params->init = init; - params->streaming_batch_size = streaming_batch_size; - - DLManagedTensor dataset_tensor; - int64_t dataset_shape[2] = {n_samples, n_features}; - fill_matrix_tensor(&dataset_tensor, - dataset_data, - dataset_shape, - dataset_on_host ? kDLCPU : kDLCUDA, - kDLFloat, - 32); - - DLManagedTensor centroids_tensor; - int64_t centroids_shape[2] = {n_clusters, n_features}; - fill_matrix_tensor( - ¢roids_tensor, centroids_data, centroids_shape, kDLCUDA, kDLFloat, 32); - - cuvsKMeansFit_v2( - res, params, &dataset_tensor, NULL, ¢roids_tensor, inertia_out, n_iter_out); - - if (!dataset_on_host) { - DLManagedTensor labels_tensor; - int64_t labels_shape[1] = {n_samples}; - fill_vector_tensor(&labels_tensor, labels_data, labels_shape, kDLCUDA, kDLInt, 32); - - cuvsKMeansPredict_v2(res, - params, - &dataset_tensor, - NULL, - ¢roids_tensor, - &labels_tensor, - false, - predict_inertia_out); - - cuvsKMeansClusterCost(res, &dataset_tensor, ¢roids_tensor, cluster_cost_out); - } - - cuvsKMeansParamsDestroy_v2(params); - cuvsResourcesDestroy(res); - } From 9a9b8ee636402654d282034b5edf06835ce1a834 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Mon, 4 May 2026 23:16:38 -0700 Subject: [PATCH 42/50] cache device weights --- cpp/src/cluster/detail/kmeans.cuh | 52 +++++++++++++-- .../detail/minClusterDistanceCompute.cu | 66 +++++++++++-------- 2 files changed, 84 insertions(+), 34 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans.cuh b/cpp/src/cluster/detail/kmeans.cuh index cba7742975..29d0f2a881 100644 --- a/cpp/src/cluster/detail/kmeans.cuh +++ b/cpp/src/cluster/detail/kmeans.cuh @@ -700,15 +700,35 @@ void kmeans_fit( cuvs::spatial::knn::detail::utils::batch_load_iterator data_batches( X.data_handle(), n_samples, n_features, streaming_batch_size, stream); - // Only materialize weight batches when weights are provided; otherwise we - // never touch this iterator (and never dereference a null-pointer batch). + // Host-path weight batches: only materialized when weights are provided and + // the data resides on host std::optional> weight_batches; - if (weight_ptr != nullptr) { - weight_batches.emplace(weight_ptr, n_samples, 1, streaming_batch_size, stream); - } else { + if constexpr (!data_on_device) { + if (weight_ptr != nullptr) { + weight_batches.emplace(weight_ptr, n_samples, 1, streaming_batch_size, stream); + } else { + raft::matrix::fill(handle, batch_weights_buf.view(), DataT{1}); + } + } else if (weight_ptr == nullptr) { raft::matrix::fill(handle, batch_weights_buf.view(), DataT{1}); } + std::optional> prenormalized_weights; + if constexpr (data_on_device) { + if (weight_ptr != nullptr) { + prenormalized_weights.emplace( + raft::make_device_vector(handle, n_samples)); + const DataT* d_wt_sum_ptr = d_wt_sum.data_handle(); + raft::linalg::map( + handle, + prenormalized_weights->view(), + [n_samples, d_wt_sum_ptr] __device__(DataT w) { + return w * static_cast(n_samples) / *d_wt_sum_ptr; + }, + sample_weight.value()); + } + } + // Copies and rescales `wt_data` into `batch_weights_buf` so that weights // are normalized to sum to n_samples. const DataT* d_wt_sum_ptr = d_wt_sum.data_handle(); @@ -729,6 +749,20 @@ void kmeans_fit( cur_batch_size); }; + // Returns the weights view to feed into `process_batch` / `cluster_cost` + // for the current batch. On the device path this is a sub-range of the + // pre-normalized buffer; on the host path it delegates to `prepare_batch_weights`. + auto cur_batch_weights = [&](IndexT batch_offset, const DataT* wt_data, IndexT cur_batch_size) { + if constexpr (data_on_device) { + const DataT* base = prenormalized_weights.has_value() + ? prenormalized_weights->data_handle() + batch_offset + : batch_weights_buf.data_handle(); + return raft::make_device_vector_view(base, cur_batch_size); + } else { + return prepare_batch_weights(wt_data, cur_batch_size); + } + }; + RAFT_LOG_DEBUG( "KMeans.fit: n_samples=%zu, n_features=%zu, n_clusters=%d, streaming_batch_size=%zu", static_cast(n_samples), @@ -829,7 +863,8 @@ void kmeans_fit( auto batch_data_view = raft::make_device_matrix_view( data_batch.data(), cur_batch_size, n_features); - auto batch_weights_view = prepare_batch_weights(wt_data, cur_batch_size); + auto batch_weights_view = + cur_batch_weights(static_cast(data_batch.offset()), wt_data, cur_batch_size); auto minCAD_view = raft::make_device_vector_view, IndexT>( minClusterAndDistance.data_handle(), cur_batch_size); @@ -931,7 +966,10 @@ void kmeans_fit( data_batch.data(), cur_batch_size, n_features); std::optional> batch_sw = std::nullopt; - if (weight_ptr != nullptr) { batch_sw = prepare_batch_weights(wt_data, cur_batch_size); } + if (weight_ptr != nullptr) { + batch_sw = + cur_batch_weights(static_cast(data_batch.offset()), wt_data, cur_batch_size); + } DataT batch_cost = DataT{0}; cuvs::cluster::kmeans::cluster_cost(handle, diff --git a/cpp/src/cluster/detail/minClusterDistanceCompute.cu b/cpp/src/cluster/detail/minClusterDistanceCompute.cu index 52a5204baa..60bfb9918a 100644 --- a/cpp/src/cluster/detail/minClusterDistanceCompute.cu +++ b/cpp/src/cluster/detail/minClusterDistanceCompute.cu @@ -38,16 +38,22 @@ void minClusterAndDistanceCompute( metric == cuvs::distance::DistanceType::CosineExpanded; if (is_fused) { - L2NormBuf_OR_DistBuf.resize(n_clusters, stream); - auto centroidsNorm = - raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters); - - if (metric == cuvs::distance::DistanceType::CosineExpanded) { - raft::linalg::norm( - handle, centroids, centroidsNorm, raft::sqrt_op{}); + const DataT* centroidsNorm_ptr = nullptr; + if (precomputed_centroid_norms.has_value()) { + centroidsNorm_ptr = precomputed_centroid_norms->data_handle(); } else { - raft::linalg::norm( - handle, centroids, centroidsNorm); + L2NormBuf_OR_DistBuf.resize(n_clusters, stream); + auto centroidsNorm = + raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters); + + if (metric == cuvs::distance::DistanceType::CosineExpanded) { + raft::linalg::norm( + handle, centroids, centroidsNorm, raft::sqrt_op{}); + } else { + raft::linalg::norm( + handle, centroids, centroidsNorm); + } + centroidsNorm_ptr = centroidsNorm.data_handle(); } raft::KeyValuePair initial_value(0, std::numeric_limits::max()); @@ -60,7 +66,7 @@ void minClusterAndDistanceCompute( X.data_handle(), centroids.data_handle(), L2NormX.data_handle(), - centroidsNorm.data_handle(), + centroidsNorm_ptr, n_samples, n_clusters, n_features, @@ -193,23 +199,29 @@ void minClusterDistanceCompute( raft::matrix::fill(handle, minClusterDistance, std::numeric_limits::max()); if (is_fused) { - L2NormBuf_OR_DistBuf.resize(n_clusters, stream); - auto centroidsNorm = - raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters); - - if (metric == cuvs::distance::DistanceType::CosineExpanded) { - raft::linalg::norm( - handle, - raft::make_device_matrix_view( - centroids.data_handle(), centroids.extent(0), centroids.extent(1)), - centroidsNorm, - raft::sqrt_op{}); + const DataT* centroidsNorm_ptr = nullptr; + if (precomputed_centroid_norms.has_value()) { + centroidsNorm_ptr = precomputed_centroid_norms->data_handle(); } else { - raft::linalg::norm( - handle, - raft::make_device_matrix_view( - centroids.data_handle(), centroids.extent(0), centroids.extent(1)), - centroidsNorm); + L2NormBuf_OR_DistBuf.resize(n_clusters, stream); + auto centroidsNorm = + raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters); + + if (metric == cuvs::distance::DistanceType::CosineExpanded) { + raft::linalg::norm( + handle, + raft::make_device_matrix_view( + centroids.data_handle(), centroids.extent(0), centroids.extent(1)), + centroidsNorm, + raft::sqrt_op{}); + } else { + raft::linalg::norm( + handle, + raft::make_device_matrix_view( + centroids.data_handle(), centroids.extent(0), centroids.extent(1)), + centroidsNorm); + } + centroidsNorm_ptr = centroidsNorm.data_handle(); } workspace.resize(sizeof(int) * n_samples, stream); @@ -219,7 +231,7 @@ void minClusterDistanceCompute( X.data_handle(), centroids.data_handle(), L2NormX.data_handle(), - centroidsNorm.data_handle(), + centroidsNorm_ptr, n_samples, n_clusters, n_features, From a800b279fe06d5cbaf661215afc37a0d7520b5ab Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Mon, 4 May 2026 23:47:43 -0700 Subject: [PATCH 43/50] rm event --- cpp/src/cluster/detail/kmeans.cuh | 42 +++++++++++++++---------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans.cuh b/cpp/src/cluster/detail/kmeans.cuh index 29d0f2a881..c7ad1c6f11 100644 --- a/cpp/src/cluster/detail/kmeans.cuh +++ b/cpp/src/cluster/detail/kmeans.cuh @@ -716,8 +716,7 @@ void kmeans_fit( std::optional> prenormalized_weights; if constexpr (data_on_device) { if (weight_ptr != nullptr) { - prenormalized_weights.emplace( - raft::make_device_vector(handle, n_samples)); + prenormalized_weights.emplace(raft::make_device_vector(handle, n_samples)); const DataT* d_wt_sum_ptr = d_wt_sum.data_handle(); raft::linalg::map( handle, @@ -775,7 +774,6 @@ void kmeans_fit( bool use_norm_cache = need_compute_norms && !data_on_device; auto h_norm_cache = raft::make_pinned_vector(handle, use_norm_cache ? n_samples : 0); - bool norms_cached = false; auto compute_batch_norms = [&](const DataT* batch_ptr, IndexT batch_size) { auto batch_view = @@ -786,6 +784,13 @@ void kmeans_fit( handle, batch_view, norm_view); }; + // Device path: compute X norms once up front + if constexpr (data_on_device) { + if (need_compute_norms) { + compute_batch_norms(X.data_handle(), static_cast(n_samples)); + } + } + std::mt19937 gen(pams.rng_state.seed); inertia[0] = std::numeric_limits::max(); @@ -812,12 +817,9 @@ void kmeans_fit( auto d_done_flag = raft::make_device_scalar(handle, 0); auto h_done_flag = raft::make_pinned_scalar(handle, 0); - cudaEvent_t convergence_event; - RAFT_CUDA_TRY(cudaEventCreateWithFlags(&convergence_event, cudaEventDisableTiming)); - for (n_current_iter = 1; n_current_iter <= iter_params.max_iter; ++n_current_iter) { if (n_current_iter > 1) { - RAFT_CUDA_TRY(cudaEventSynchronize(convergence_event)); + raft::resource::sync_stream(handle); if (*h_done_flag.data_handle()) { n_current_iter--; RAFT_LOG_DEBUG("Threshold triggered after %d iterations. Terminating early.", @@ -869,19 +871,21 @@ void kmeans_fit( auto minCAD_view = raft::make_device_vector_view, IndexT>( minClusterAndDistance.data_handle(), cur_batch_size); - if (need_compute_norms) { - compute_batch_norms(data_batch.data(), cur_batch_size); - if (use_norm_cache) { - raft::copy(h_norm_cache.data_handle() + data_batch.offset(), - L2NormBatch.data_handle(), + if constexpr (!data_on_device) { + if (need_compute_norms) { + compute_batch_norms(data_batch.data(), cur_batch_size); + if (use_norm_cache) { + raft::copy(h_norm_cache.data_handle() + data_batch.offset(), + L2NormBatch.data_handle(), + cur_batch_size, + stream); + } + } else if (use_norm_cache) { + raft::copy(L2NormBatch.data_handle(), + h_norm_cache.data_handle() + data_batch.offset(), cur_batch_size, stream); } - } else if (use_norm_cache) { - raft::copy(L2NormBatch.data_handle(), - h_norm_cache.data_handle() + data_batch.offset(), - cur_batch_size, - stream); } auto l2_const_view = raft::make_device_vector_view( @@ -904,7 +908,6 @@ void kmeans_fit( batch_workspace, centroid_norms_opt); } - if (!norms_cached && use_norm_cache) { norms_cached = true; } finalize_centroids(handle, raft::make_const_mdspan(centroid_sums.view()), @@ -937,11 +940,8 @@ void kmeans_fit( raft::copy(handle, raft::make_pinned_scalar_view(h_done_flag.data_handle()), raft::make_device_scalar_view(d_done_flag.data_handle())); - RAFT_CUDA_TRY(cudaEventRecord(convergence_event, stream)); } - RAFT_CUDA_TRY(cudaEventDestroy(convergence_event)); - { auto centroids_const = raft::make_device_matrix_view( cur_centroids_ptr, n_clusters, n_features); From 3db8582c369940d6243701b896d2a6ab2b7569a5 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Mon, 4 May 2026 23:58:59 -0700 Subject: [PATCH 44/50] update names --- cpp/src/cluster/detail/kmeans.cuh | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans.cuh b/cpp/src/cluster/detail/kmeans.cuh index c7ad1c6f11..072d19b2f0 100644 --- a/cpp/src/cluster/detail/kmeans.cuh +++ b/cpp/src/cluster/detail/kmeans.cuh @@ -639,10 +639,7 @@ void kmeans_fit( auto init_centroids = [&](const cuvs::cluster::kmeans::params& iter_params, raft::device_matrix_view centroidsRawData) { if (iter_params.init == cuvs::cluster::kmeans::params::InitMethod::Array) { - raft::copy( - handle, - raft::make_device_vector_view(centroidsRawData.data_handle(), n_clusters * n_features), - raft::make_device_vector_view(centroids.data_handle(), n_clusters * n_features)); + raft::copy(handle, centroidsRawData.view(), centroids.view()); return; } @@ -680,10 +677,10 @@ void kmeans_fit( } IndexT centroid_buf_size = n_clusters * n_features; - rmm::device_uvector centroid_buf_A(centroid_buf_size, stream); - rmm::device_uvector centroid_buf_B(centroid_buf_size, stream); - DataT* cur_centroids_ptr = centroid_buf_A.data(); - DataT* new_centroids_ptr = centroid_buf_B.data(); + rmm::device_uvector cur_centroids_buf(centroid_buf_size, stream); + rmm::device_uvector new_centroids_buf(centroid_buf_size, stream); + DataT* cur_centroids_ptr = cur_centroids_buf.data(); + DataT* new_centroids_ptr = new_centroids_buf.data(); auto minClusterAndDistance = raft::make_device_vector, IndexT>( handle, streaming_batch_size); @@ -803,8 +800,8 @@ void kmeans_fit( n_init, (unsigned long long)iter_params.rng_state.seed); - cur_centroids_ptr = centroid_buf_A.data(); - new_centroids_ptr = centroid_buf_B.data(); + cur_centroids_ptr = cur_centroids_buf.data(); + new_centroids_ptr = new_centroids_buf.data(); init_centroids( iter_params, raft::make_device_matrix_view(cur_centroids_ptr, n_clusters, n_features)); From c048352a9cb27a5b5a7289984c0699b3e9fe5882 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Tue, 5 May 2026 00:14:49 -0700 Subject: [PATCH 45/50] rename --- cpp/src/cluster/detail/kmeans.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/cluster/detail/kmeans.cuh b/cpp/src/cluster/detail/kmeans.cuh index 072d19b2f0..0ac96e9f03 100644 --- a/cpp/src/cluster/detail/kmeans.cuh +++ b/cpp/src/cluster/detail/kmeans.cuh @@ -639,7 +639,7 @@ void kmeans_fit( auto init_centroids = [&](const cuvs::cluster::kmeans::params& iter_params, raft::device_matrix_view centroidsRawData) { if (iter_params.init == cuvs::cluster::kmeans::params::InitMethod::Array) { - raft::copy(handle, centroidsRawData.view(), centroids.view()); + raft::copy(handle, centroidsRawData, centroids); return; } From 2f968f8063f0cb830a9a5be6242867bafe844a8e Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Tue, 5 May 2026 08:51:55 -0700 Subject: [PATCH 46/50] rm docs --- cpp/src/cluster/detail/kmeans.cuh | 3 --- 1 file changed, 3 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans.cuh b/cpp/src/cluster/detail/kmeans.cuh index 0ac96e9f03..19fa7592de 100644 --- a/cpp/src/cluster/detail/kmeans.cuh +++ b/cpp/src/cluster/detail/kmeans.cuh @@ -745,9 +745,6 @@ void kmeans_fit( cur_batch_size); }; - // Returns the weights view to feed into `process_batch` / `cluster_cost` - // for the current batch. On the device path this is a sub-range of the - // pre-normalized buffer; on the host path it delegates to `prepare_batch_weights`. auto cur_batch_weights = [&](IndexT batch_offset, const DataT* wt_data, IndexT cur_batch_size) { if constexpr (data_on_device) { const DataT* base = prenormalized_weights.has_value() From affe85abf79e89a99f965957fd08e4ac430d5283 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Tue, 5 May 2026 08:55:07 -0700 Subject: [PATCH 47/50] empty From c6dea64dcb8dfe3f62729fbf657ba50ea5c58c69 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Tue, 5 May 2026 09:39:20 -0700 Subject: [PATCH 48/50] fix norm cache --- cpp/src/cluster/detail/kmeans.cuh | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans.cuh b/cpp/src/cluster/detail/kmeans.cuh index 19fa7592de..6e32522bcb 100644 --- a/cpp/src/cluster/detail/kmeans.cuh +++ b/cpp/src/cluster/detail/kmeans.cuh @@ -765,9 +765,9 @@ void kmeans_fit( bool need_compute_norms = metric == cuvs::distance::DistanceType::L2Expanded || metric == cuvs::distance::DistanceType::L2SqrtExpanded; - bool use_norm_cache = need_compute_norms && !data_on_device; - auto h_norm_cache = - raft::make_pinned_vector(handle, use_norm_cache ? n_samples : 0); + auto h_norm_cache = raft::make_pinned_vector( + handle, (need_compute_norms && !data_on_device) ? n_samples : 0); + bool norms_cached = false; auto compute_batch_norms = [&](const DataT* batch_ptr, IndexT batch_size) { auto batch_view = @@ -867,18 +867,18 @@ void kmeans_fit( if constexpr (!data_on_device) { if (need_compute_norms) { - compute_batch_norms(data_batch.data(), cur_batch_size); - if (use_norm_cache) { + if (!norms_cached) { + compute_batch_norms(data_batch.data(), cur_batch_size); raft::copy(h_norm_cache.data_handle() + data_batch.offset(), L2NormBatch.data_handle(), cur_batch_size, stream); + } else { + raft::copy(L2NormBatch.data_handle(), + h_norm_cache.data_handle() + data_batch.offset(), + cur_batch_size, + stream); } - } else if (use_norm_cache) { - raft::copy(L2NormBatch.data_handle(), - h_norm_cache.data_handle() + data_batch.offset(), - cur_batch_size, - stream); } } @@ -902,6 +902,7 @@ void kmeans_fit( batch_workspace, centroid_norms_opt); } + if (need_compute_norms) { norms_cached = true; } finalize_centroids(handle, raft::make_const_mdspan(centroid_sums.view()), From 7dfab3e0d16a2f2afac0406ca87b0e1927abb0d9 Mon Sep 17 00:00:00 2001 From: tarang-jain Date: Tue, 5 May 2026 22:20:47 -0700 Subject: [PATCH 49/50] revert changes to minClusterDistanceCompute --- cpp/src/cluster/detail/kmeans.cuh | 13 +-- cpp/src/cluster/detail/kmeans_common.cuh | 22 ++--- .../detail/minClusterDistanceCompute.cu | 99 ++++++++----------- 3 files changed, 47 insertions(+), 87 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans.cuh b/cpp/src/cluster/detail/kmeans.cuh index 6e32522bcb..297b4b2fd3 100644 --- a/cpp/src/cluster/detail/kmeans.cuh +++ b/cpp/src/cluster/detail/kmeans.cuh @@ -690,7 +690,6 @@ void kmeans_fit( auto centroid_sums = raft::make_device_matrix(handle, n_clusters, n_features); auto weight_per_cluster = raft::make_device_vector(handle, n_clusters); - auto centroid_norms_buf = raft::make_device_vector(handle, n_clusters); auto clustering_cost = raft::make_device_scalar(handle, DataT{0}); rmm::device_uvector batch_workspace(streaming_batch_size, stream); @@ -833,15 +832,6 @@ void kmeans_fit( auto new_centroids_view = raft::make_device_matrix_view(new_centroids_ptr, n_clusters, n_features); - std::optional> centroid_norms_opt = - std::nullopt; - if (need_compute_norms) { - raft::linalg::norm( - handle, centroids_const, centroid_norms_buf.view()); - centroid_norms_opt = raft::make_device_vector_view( - centroid_norms_buf.data_handle(), n_clusters); - } - data_batches.reset(); using wt_iter_t = cuvs::spatial::knn::detail::utils::batch_load_iterator; std::optional wt_it; @@ -899,8 +889,7 @@ void kmeans_fit( centroid_sums.view(), weight_per_cluster.view(), clustering_cost.view(), - batch_workspace, - centroid_norms_opt); + batch_workspace); } if (need_compute_norms) { norms_cached = true; } diff --git a/cpp/src/cluster/detail/kmeans_common.cuh b/cpp/src/cluster/detail/kmeans_common.cuh index 3907fd59dd..348a0ec0e0 100644 --- a/cpp/src/cluster/detail/kmeans_common.cuh +++ b/cpp/src/cluster/detail/kmeans_common.cuh @@ -358,9 +358,7 @@ void minClusterAndDistanceCompute( cuvs::distance::DistanceType metric, int batch_samples, int batch_centroids, - rmm::device_uvector& workspace, - std::optional> precomputed_centroid_norms = - std::nullopt); + rmm::device_uvector& workspace); #define EXTERN_TEMPLATE_MIN_CLUSTER_AND_DISTANCE(DataT, IndexT) \ extern template void minClusterAndDistanceCompute( \ @@ -373,8 +371,7 @@ void minClusterAndDistanceCompute( cuvs::distance::DistanceType metric, \ int batch_samples, \ int batch_centroids, \ - rmm::device_uvector& workspace, \ - std::optional>); + rmm::device_uvector& workspace); EXTERN_TEMPLATE_MIN_CLUSTER_AND_DISTANCE(float, int64_t) EXTERN_TEMPLATE_MIN_CLUSTER_AND_DISTANCE(float, int) @@ -393,9 +390,7 @@ void minClusterDistanceCompute(raft::resources const& handle, cuvs::distance::DistanceType metric, int batch_samples, int batch_centroids, - rmm::device_uvector& workspace, - std::optional> - precomputed_centroid_norms = std::nullopt); + rmm::device_uvector& workspace); #define EXTERN_TEMPLATE_MIN_CLUSTER_DISTANCE(DataT, IndexT) \ extern template void minClusterDistanceCompute( \ @@ -408,8 +403,7 @@ void minClusterDistanceCompute(raft::resources const& handle, cuvs::distance::DistanceType metric, \ int batch_samples, \ int batch_centroids, \ - rmm::device_uvector& workspace, \ - std::optional>); + rmm::device_uvector& workspace); EXTERN_TEMPLATE_MIN_CLUSTER_DISTANCE(float, int64_t) EXTERN_TEMPLATE_MIN_CLUSTER_DISTANCE(double, int64_t) @@ -655,8 +649,6 @@ __device__ void check_convergence(raft::device_scalar_view clusteri * @param[inout] centroid_sums Running weighted sums [n_clusters x n_features] (added into) * @param[inout] weight_per_cluster Running weight counts [n_clusters] (added into) * @param[inout] clustering_cost Running cost scalar (device) (added into) - * @param[in] centroid_norms Optional precomputed centroid norms [n_clusters]. - * When provided, skips internal centroid norm computation. */ template void process_batch( @@ -674,8 +666,7 @@ void process_batch( raft::device_matrix_view centroid_sums, raft::device_vector_view weight_per_cluster, raft::device_scalar_view clustering_cost, - rmm::device_uvector& batch_workspace, - std::optional> centroid_norms = std::nullopt) + rmm::device_uvector& batch_workspace) { cudaStream_t stream = raft::resource::get_cuda_stream(handle); @@ -688,8 +679,7 @@ void process_batch( metric, batch_samples_param, batch_centroids_param, - workspace, - centroid_norms); + workspace); KeyValueIndexOp conversion_op; thrust::transform_iterator, diff --git a/cpp/src/cluster/detail/minClusterDistanceCompute.cu b/cpp/src/cluster/detail/minClusterDistanceCompute.cu index 60bfb9918a..e271a861e8 100644 --- a/cpp/src/cluster/detail/minClusterDistanceCompute.cu +++ b/cpp/src/cluster/detail/minClusterDistanceCompute.cu @@ -8,8 +8,6 @@ #include -#include - namespace cuvs::cluster::kmeans::detail { // Calculates a pair for every sample in input 'X' where key is an @@ -26,8 +24,7 @@ void minClusterAndDistanceCompute( cuvs::distance::DistanceType metric, int batch_samples, int batch_centroids, - rmm::device_uvector& workspace, - std::optional> precomputed_centroid_norms) + rmm::device_uvector& workspace) { cudaStream_t stream = raft::resource::get_cuda_stream(handle); auto n_samples = X.extent(0); @@ -38,22 +35,16 @@ void minClusterAndDistanceCompute( metric == cuvs::distance::DistanceType::CosineExpanded; if (is_fused) { - const DataT* centroidsNorm_ptr = nullptr; - if (precomputed_centroid_norms.has_value()) { - centroidsNorm_ptr = precomputed_centroid_norms->data_handle(); + L2NormBuf_OR_DistBuf.resize(n_clusters, stream); + auto centroidsNorm = + raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters); + + if (metric == cuvs::distance::DistanceType::CosineExpanded) { + raft::linalg::norm( + handle, centroids, centroidsNorm, raft::sqrt_op{}); } else { - L2NormBuf_OR_DistBuf.resize(n_clusters, stream); - auto centroidsNorm = - raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters); - - if (metric == cuvs::distance::DistanceType::CosineExpanded) { - raft::linalg::norm( - handle, centroids, centroidsNorm, raft::sqrt_op{}); - } else { - raft::linalg::norm( - handle, centroids, centroidsNorm); - } - centroidsNorm_ptr = centroidsNorm.data_handle(); + raft::linalg::norm( + handle, centroids, centroidsNorm); } raft::KeyValuePair initial_value(0, std::numeric_limits::max()); @@ -66,7 +57,7 @@ void minClusterAndDistanceCompute( X.data_handle(), centroids.data_handle(), L2NormX.data_handle(), - centroidsNorm_ptr, + centroidsNorm.data_handle(), n_samples, n_clusters, n_features, @@ -163,8 +154,7 @@ void minClusterAndDistanceCompute( cuvs::distance::DistanceType metric, \ int batch_samples, \ int batch_centroids, \ - rmm::device_uvector& workspace, \ - std::optional>); + rmm::device_uvector& workspace); INSTANTIATE_MIN_CLUSTER_AND_DISTANCE(float, int64_t) INSTANTIATE_MIN_CLUSTER_AND_DISTANCE(double, int64_t) @@ -174,18 +164,16 @@ INSTANTIATE_MIN_CLUSTER_AND_DISTANCE(double, int) #undef INSTANTIATE_MIN_CLUSTER_AND_DISTANCE template -void minClusterDistanceCompute( - raft::resources const& handle, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_vector_view minClusterDistance, - raft::device_vector_view L2NormX, - rmm::device_uvector& L2NormBuf_OR_DistBuf, - cuvs::distance::DistanceType metric, - int batch_samples, - int batch_centroids, - rmm::device_uvector& workspace, - std::optional> precomputed_centroid_norms) +void minClusterDistanceCompute(raft::resources const& handle, + raft::device_matrix_view X, + raft::device_matrix_view centroids, + raft::device_vector_view minClusterDistance, + raft::device_vector_view L2NormX, + rmm::device_uvector& L2NormBuf_OR_DistBuf, + cuvs::distance::DistanceType metric, + int batch_samples, + int batch_centroids, + rmm::device_uvector& workspace) { cudaStream_t stream = raft::resource::get_cuda_stream(handle); auto n_samples = X.extent(0); @@ -199,29 +187,23 @@ void minClusterDistanceCompute( raft::matrix::fill(handle, minClusterDistance, std::numeric_limits::max()); if (is_fused) { - const DataT* centroidsNorm_ptr = nullptr; - if (precomputed_centroid_norms.has_value()) { - centroidsNorm_ptr = precomputed_centroid_norms->data_handle(); + L2NormBuf_OR_DistBuf.resize(n_clusters, stream); + auto centroidsNorm = + raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters); + + if (metric == cuvs::distance::DistanceType::CosineExpanded) { + raft::linalg::norm( + handle, + raft::make_device_matrix_view( + centroids.data_handle(), centroids.extent(0), centroids.extent(1)), + centroidsNorm, + raft::sqrt_op{}); } else { - L2NormBuf_OR_DistBuf.resize(n_clusters, stream); - auto centroidsNorm = - raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters); - - if (metric == cuvs::distance::DistanceType::CosineExpanded) { - raft::linalg::norm( - handle, - raft::make_device_matrix_view( - centroids.data_handle(), centroids.extent(0), centroids.extent(1)), - centroidsNorm, - raft::sqrt_op{}); - } else { - raft::linalg::norm( - handle, - raft::make_device_matrix_view( - centroids.data_handle(), centroids.extent(0), centroids.extent(1)), - centroidsNorm); - } - centroidsNorm_ptr = centroidsNorm.data_handle(); + raft::linalg::norm( + handle, + raft::make_device_matrix_view( + centroids.data_handle(), centroids.extent(0), centroids.extent(1)), + centroidsNorm); } workspace.resize(sizeof(int) * n_samples, stream); @@ -231,7 +213,7 @@ void minClusterDistanceCompute( X.data_handle(), centroids.data_handle(), L2NormX.data_handle(), - centroidsNorm_ptr, + centroidsNorm.data_handle(), n_samples, n_clusters, n_features, @@ -301,8 +283,7 @@ void minClusterDistanceCompute( cuvs::distance::DistanceType metric, \ int batch_samples, \ int batch_centroids, \ - rmm::device_uvector& workspace, \ - std::optional>); + rmm::device_uvector& workspace); INSTANTIATE_MIN_CLUSTER_DISTANCE(float, int64_t) INSTANTIATE_MIN_CLUSTER_DISTANCE(double, int64_t) From 7a383da1f0f79daa229e372bfde3942efa5338e5 Mon Sep 17 00:00:00 2001 From: tarang-jain Date: Tue, 5 May 2026 23:02:04 -0700 Subject: [PATCH 50/50] update tests to use mdspan instead of rmm --- cpp/tests/cluster/kmeans.cu | 157 +++++++++++++++--------------------- 1 file changed, 67 insertions(+), 90 deletions(-) diff --git a/cpp/tests/cluster/kmeans.cu b/cpp/tests/cluster/kmeans.cu index 6261da7afe..08aaa0949a 100644 --- a/cpp/tests/cluster/kmeans.cu +++ b/cpp/tests/cluster/kmeans.cu @@ -6,6 +6,8 @@ #include "../test_utils.cuh" #include +#include +#include #include #include #include @@ -17,8 +19,6 @@ #include -#include - #include #include @@ -361,15 +361,7 @@ struct KmeansBatchedInputs { template class KmeansFitBatchedTest : public ::testing::TestWithParam> { protected: - KmeansFitBatchedTest() - : d_labels(0, raft::resource::get_cuda_stream(handle)), - d_labels_ref(0, raft::resource::get_cuda_stream(handle)), - d_centroids(0, raft::resource::get_cuda_stream(handle)), - d_centroids_ref(0, raft::resource::get_cuda_stream(handle)), - d_X(0, raft::resource::get_cuda_stream(handle)), - d_sample_weight(0, raft::resource::get_cuda_stream(handle)) - { - } + KmeansFitBatchedTest() = default; void prepareBlobInputs() { @@ -379,10 +371,10 @@ class KmeansFitBatchedTest : public ::testing::TestWithParam(n_samples) * n_features, stream); - d_labels_ref.resize(n_samples, stream); - raft::random::make_blobs(d_X.data(), - d_labels_ref.data(), + d_X.emplace(raft::make_device_matrix(handle, n_samples, n_features)); + d_labels_ref.emplace(raft::make_device_vector(handle, n_samples)); + raft::random::make_blobs(d_X->data_handle(), + d_labels_ref->data_handle(), n_samples, n_features, n_clusters, @@ -396,37 +388,24 @@ class KmeansFitBatchedTest : public ::testing::TestWithParam(n_samples) * n_features); - raft::update_host(h_X.data(), d_X.data(), h_X.size(), stream); + h_X.emplace(raft::make_host_matrix(n_samples, n_features)); + raft::update_host( + h_X->data_handle(), d_X->data_handle(), static_cast(n_samples) * n_features, stream); if (testparams.weighted) { - d_sample_weight.resize(n_samples, stream); - raft::matrix::fill( - handle, raft::make_device_vector_view(d_sample_weight.data(), n_samples), T(1)); + d_sample_weight.emplace(raft::make_device_vector(handle, n_samples)); + raft::matrix::fill(handle, d_sample_weight->view(), T(1)); } else { - d_sample_weight.resize(0, stream); + d_sample_weight.reset(); } raft::resource::sync_stream(handle, stream); } - raft::device_matrix_view X_dview() const - { - return raft::make_device_matrix_view( - d_X.data(), testparams.n_row, testparams.n_col); - } - - raft::host_matrix_view h_X_view() const - { - return raft::make_host_matrix_view( - h_X.data(), testparams.n_row, testparams.n_col); - } - std::optional> d_sw_view() const { - if (!testparams.weighted) return std::nullopt; - return std::make_optional( - raft::make_device_vector_view(d_sample_weight.data(), testparams.n_row)); + if (!d_sample_weight.has_value()) return std::nullopt; + return std::make_optional(raft::make_const_mdspan(d_sample_weight->view())); } void fitBatchedTest() @@ -440,23 +419,24 @@ class KmeansFitBatchedTest : public ::testing::TestWithParam(handle, n_samples)); + d_centroids.emplace(raft::make_device_matrix(handle, params.n_clusters, n_features)); + d_centroids_ref.emplace( + raft::make_device_matrix(handle, params.n_clusters, n_features)); raft::random::RngState rng(params.rng_state.seed); raft::random::uniform( - handle, rng, d_centroids.data(), params.n_clusters * n_features, T(-1), T(1)); - raft::copy(d_centroids_ref.data(), d_centroids.data(), params.n_clusters * n_features, stream); + handle, rng, d_centroids->data_handle(), params.n_clusters * n_features, T(-1), T(1)); + raft::copy(d_centroids_ref->data_handle(), + d_centroids->data_handle(), + params.n_clusters * n_features, + stream); - auto d_centroids_view = - raft::make_device_matrix_view(d_centroids.data(), params.n_clusters, n_features); + auto d_centroids_view = raft::make_device_matrix_view( + d_centroids->data_handle(), params.n_clusters, n_features); auto d_sw = d_sw_view(); - auto d_centroids_ref_view = - raft::make_device_matrix_view(d_centroids_ref.data(), params.n_clusters, n_features); - params.init = cuvs::cluster::kmeans::params::Array; params.max_iter = 20; @@ -464,9 +444,9 @@ class KmeansFitBatchedTest : public ::testing::TestWithParamview()), d_sw, - d_centroids_ref_view, + d_centroids_ref->view(), raft::make_host_scalar_view(&ref_inertia), raft::make_host_scalar_view(&ref_n_iter)); @@ -474,11 +454,10 @@ class KmeansFitBatchedTest : public ::testing::TestWithParam> h_sw = std::nullopt; - std::vector h_sample_weight; + auto h_sample_weight = raft::make_host_vector(testparams.weighted ? n_samples : 0); if (testparams.weighted) { - h_sample_weight.resize(n_samples, T(1)); - h_sw = std::make_optional( - raft::make_host_vector_view(h_sample_weight.data(), n_samples)); + std::fill_n(h_sample_weight.data_handle(), n_samples, T(1)); + h_sw = std::make_optional(raft::make_const_mdspan(h_sample_weight.view())); } T inertia = 0; @@ -486,7 +465,7 @@ class KmeansFitBatchedTest : public ::testing::TestWithParamview()), h_sw, d_centroids_view, raft::make_host_scalar_view(&inertia), @@ -494,48 +473,46 @@ class KmeansFitBatchedTest : public ::testing::TestWithParamdata_handle(), + d_centroids->data_handle(), params.n_clusters, n_features, CompareApprox(T(1e-2)), stream); T ref_pred_inertia = 0; - cuvs::cluster::kmeans::predict( - handle, - params, - X_dview(), - d_sw, - raft::make_device_matrix_view( - d_centroids_ref.data(), params.n_clusters, n_features), - raft::make_device_vector_view(d_labels_ref.data(), n_samples), - true, - raft::make_host_scalar_view(&ref_pred_inertia)); + cuvs::cluster::kmeans::predict(handle, + params, + raft::make_const_mdspan(d_X->view()), + d_sw, + raft::make_const_mdspan(d_centroids_ref->view()), + d_labels_ref->view(), + true, + raft::make_host_scalar_view(&ref_pred_inertia)); T pred_inertia = 0; - cuvs::cluster::kmeans::predict( - handle, - params, - X_dview(), - d_sw, - raft::make_device_matrix_view( - d_centroids.data(), params.n_clusters, n_features), - raft::make_device_vector_view(d_labels.data(), n_samples), - true, - raft::make_host_scalar_view(&pred_inertia)); + cuvs::cluster::kmeans::predict(handle, + params, + raft::make_const_mdspan(d_X->view()), + d_sw, + raft::make_const_mdspan(d_centroids->view()), + d_labels->view(), + true, + raft::make_host_scalar_view(&pred_inertia)); raft::resource::sync_stream(handle, stream); - score = raft::stats::adjusted_rand_index( - d_labels_ref.data(), d_labels.data(), n_samples, raft::resource::get_cuda_stream(handle)); + score = raft::stats::adjusted_rand_index(d_labels_ref->data_handle(), + d_labels->data_handle(), + n_samples, + raft::resource::get_cuda_stream(handle)); if (score < 0.99) { std::stringstream ss; - ss << "Expected: " << raft::arr2Str(d_labels_ref.data(), 25, "d_labels_ref", stream); + ss << "Expected: " << raft::arr2Str(d_labels_ref->data_handle(), 25, "d_labels_ref", stream); std::cout << (ss.str().c_str()) << '\n'; ss.str(std::string()); - ss << "Actual: " << raft::arr2Str(d_labels.data(), 25, "d_labels", stream); + ss << "Actual: " << raft::arr2Str(d_labels->data_handle(), 25, "d_labels", stream); std::cout << (ss.str().c_str()) << '\n'; std::cout << "Score = " << score << '\n'; } @@ -570,7 +547,7 @@ class KmeansFitBatchedTest : public ::testing::TestWithParamview()), std::optional>{std::nullopt}, d_centroids_buf.view(), raft::make_host_scalar_view(&inertia), @@ -627,7 +604,7 @@ class KmeansFitBatchedTest : public ::testing::TestWithParamview()), std::optional>{std::nullopt}, d_centroids_buf.view(), raft::make_host_scalar_view(&inertia), @@ -657,7 +634,7 @@ class KmeansFitBatchedTest : public ::testing::TestWithParam(handle, n_clusters, n_features); raft::copy(d_centroids_buf.data_handle(), - d_X.data(), + d_X->data_handle(), static_cast(n_samples) * n_features, stream); @@ -675,7 +652,7 @@ class KmeansFitBatchedTest : public ::testing::TestWithParamview()), std::optional>{std::nullopt}, d_centroids_buf.view(), raft::make_host_scalar_view(&inertia), @@ -692,13 +669,13 @@ class KmeansFitBatchedTest : public ::testing::TestWithParam testparams; - rmm::device_uvector d_labels; - rmm::device_uvector d_labels_ref; - rmm::device_uvector d_centroids; - rmm::device_uvector d_centroids_ref; - rmm::device_uvector d_X; - rmm::device_uvector d_sample_weight; - std::vector h_X; + std::optional> d_labels; + std::optional> d_labels_ref; + std::optional> d_centroids; + std::optional> d_centroids_ref; + std::optional> d_X; + std::optional> d_sample_weight; + std::optional> h_X; double score; testing::AssertionResult centroids_match = testing::AssertionSuccess(); bool inertia_match = false;