diff --git a/c/include/cuvs/cluster/kmeans.h b/c/include/cuvs/cluster/kmeans.h index 8f55edb925..0f5f4554c1 100644 --- a/c/include/cuvs/cluster/kmeans.h +++ b/c/include/cuvs/cluster/kmeans.h @@ -39,6 +39,8 @@ typedef enum { /** * @brief Hyper-parameters for the kmeans algorithm + * NB: The inertia_check field is kept for ABI compatibility. Removed in cuvsKMeansParams_v2. + * TODO: CalVer for the replacement: 26.08 */ struct cuvsKMeansParams { cuvsDistanceType metric; @@ -91,7 +93,7 @@ struct cuvsKMeansParams { */ int batch_centroids; - /** Check inertia during iterations for early convergence. */ + /** Deprecated, ignored. Kept for ABI compatibility. */ bool inertia_check; /** @@ -108,14 +110,104 @@ struct cuvsKMeansParams { * Number of samples to process per GPU batch for the batched (host-data) API. * When set to 0, defaults to n_samples (process all at once). */ - int64_t streaming_batch_size; + int64_t streaming_batch_size; + + /** + * Number of samples to draw for KMeansPlusPlus initialization. + * When set to 0, uses heuristic min(3 * n_clusters, n_samples) for host data, + * or n_samples for device data. + */ + int64_t init_size; +}; + +/** + * @brief Hyper-parameters for the kmeans algorithm + * TODO: Remove this after cuvsKMeansParams is replaced in ABI 2.0 + */ + struct cuvsKMeansParams_v2 { + cuvsDistanceType metric; + + /** + * The number of clusters to form as well as the number of centroids to generate (default:8). + */ + int n_clusters; + + /** + * Method for initialization, defaults to k-means++: + * - cuvsKMeansInitMethod::KMeansPlusPlus (k-means++): Use scalable k-means++ algorithm + * to select the initial cluster centers. + * - cuvsKMeansInitMethod::Random (random): Choose 'n_clusters' observations (rows) at + * random from the input data for the initial centroids. + * - cuvsKMeansInitMethod::Array (ndarray): Use 'centroids' as initial cluster centers. + */ + cuvsKMeansInitMethod init; + + /** + * Maximum number of iterations of the k-means algorithm for a single run. + */ + int max_iter; + + /** + * Relative tolerance with regards to inertia to declare convergence. + */ + double tol; + + /** + * Number of instance k-means algorithm will be run with different seeds. + */ + int n_init; + + /** + * Oversampling factor for use in the k-means|| algorithm + */ + double oversampling_factor; + + /** + * batch_samples and batch_centroids are used to tile 1NN computation which is + * useful to optimize/control the memory footprint + * Default tile is [batch_samples x n_clusters] i.e. when batch_centroids is 0 + * then don't tile the centroids + */ + int batch_samples; + + /** + * if 0 then batch_centroids = n_clusters + */ + int batch_centroids; + + /** + * Whether to use hierarchical (balanced) kmeans or not + */ + bool hierarchical; + + /** + * For hierarchical k-means , defines the number of training iterations + */ + int hierarchical_n_iters; + + /** + * Number of samples to process per GPU batch for the batched (host-data) API. + * When set to 0, defaults to n_samples (process all at once). + */ + int64_t streaming_batch_size; + + /** + * Number of samples to draw for KMeansPlusPlus initialization. + * When set to 0, uses heuristic min(3 * n_clusters, n_samples) for host data, + * or n_samples for device data. + */ + int64_t init_size; }; typedef struct cuvsKMeansParams* cuvsKMeansParams_t; +typedef struct cuvsKMeansParams_v2* cuvsKMeansParams_v2_t; /** * @brief Allocate KMeans params, and populate with default values * + * @note In cuVS 26.08 (next ABI major version) this signature will be + * replaced by cuvsKMeansParamsCreate_v2. + * * @param[in] params cuvsKMeansParams_t to allocate * @return cuvsError_t */ @@ -124,11 +216,33 @@ cuvsError_t cuvsKMeansParamsCreate(cuvsKMeansParams_t* params); /** * @brief De-allocate KMeans params * + * @note In cuVS 26.08 (next ABI major version) this signature will be + * replaced by cuvsKMeansParamsDestroy_v2. + * * @param[in] params * @return cuvsError_t */ cuvsError_t cuvsKMeansParamsDestroy(cuvsKMeansParams_t params); +/** + * @brief Allocate KMeans params + * + * Mirrors cuvsKMeansParamsCreate but operates on cuvsKMeansParams_v2. + * Will become the unsuffixed cuvsKMeansParamsCreate in cuVS 26.08. + * + * @param[in] params cuvsKMeansParams_v2_t to allocate + * @return cuvsError_t + */ +cuvsError_t cuvsKMeansParamsCreate_v2(cuvsKMeansParams_v2_t* params); + +/** + * @brief De-allocate KMeans params allocated by cuvsKMeansParamsCreate_v2. + * + * @param[in] params + * @return cuvsError_t + */ +cuvsError_t cuvsKMeansParamsDestroy_v2(cuvsKMeansParams_v2_t params); + /** * @brief Type of k-means algorithm. */ @@ -154,6 +268,9 @@ typedef enum { CUVS_KMEANS_TYPE_KMEANS = 0, CUVS_KMEANS_TYPE_KMEANS_BALANCED = 1 * When X is on the host the data is streamed to the GPU in * batches controlled by params->streaming_batch_size. * + * @note In cuVS 26.08 (next ABI major version) this signature will be + * replaced by cuvsKMeansFit_v2. + * * @param[in] res opaque C handle * @param[in] params Parameters for KMeans model. * @param[in] X Training instances to cluster. The data must @@ -181,9 +298,45 @@ cuvsError_t cuvsKMeansFit(cuvsResources_t res, double* inertia, int* n_iter); +/** + * @brief Find clusters with k-means algorithm (v2 params layout). + * + * Mirrors cuvsKMeansFit but takes cuvsKMeansParams_v2_t. Will become the + * unsuffixed cuvsKMeansFit in cuVS 26.08. + * + * @param[in] res opaque C handle + * @param[in] params Parameters for KMeans model (v2 layout). + * @param[in] X Training instances to cluster. The data must + * be in row-major format. May be on host or + * device memory. + * [dim = n_samples x n_features] + * @param[in] sample_weight Optional weights for each observation in X. + * Must be on the same memory space as X. + * [len = n_samples] + * @param[inout] centroids [in] When init is InitMethod::Array, use + * centroids as the initial cluster centers. + * [out] The generated centroids from the + * kmeans algorithm are stored at the address + * pointed by 'centroids'. Must be on device. + * [dim = n_clusters x n_features] + * @param[out] inertia Sum of squared distances of samples to their + * closest cluster center. + * @param[out] n_iter Number of iterations run. + */ +cuvsError_t cuvsKMeansFit_v2(cuvsResources_t res, + cuvsKMeansParams_v2_t params, + DLManagedTensor* X, + DLManagedTensor* sample_weight, + DLManagedTensor* centroids, + double* inertia, + int* n_iter); + /** * @brief Predict the closest cluster each sample in X belongs to. * + * @note In cuVS 26.08 (next ABI major version) this signature will be + * replaced by cuvsKMeansPredict_v2. + * * @param[in] res opaque C handle * @param[in] params Parameters for KMeans model. * @param[in] X New data to predict. @@ -209,6 +362,37 @@ cuvsError_t cuvsKMeansPredict(cuvsResources_t res, bool normalize_weight, double* inertia); +/** + * @brief Predict the closest cluster each sample in X belongs to (v2 params layout). + * + * Mirrors cuvsKMeansPredict but takes cuvsKMeansParams_v2_t. Will become the + * unsuffixed cuvsKMeansPredict in cuVS 26.08. + * + * @param[in] res opaque C handle + * @param[in] params Parameters for KMeans model (v2 layout). + * @param[in] X New data to predict. + * [dim = n_samples x n_features] + * @param[in] sample_weight Optional weights for each observation in X. + * [len = n_samples] + * @param[in] centroids Cluster centroids. The data must be in + * row-major format. + * [dim = n_clusters x n_features] + * @param[in] normalize_weight True if the weights should be normalized + * @param[out] labels Index of the cluster each sample in X + * belongs to. + * [len = n_samples] + * @param[out] inertia Sum of squared distances of samples to + * their closest cluster center. + */ +cuvsError_t cuvsKMeansPredict_v2(cuvsResources_t res, + cuvsKMeansParams_v2_t params, + DLManagedTensor* X, + DLManagedTensor* sample_weight, + DLManagedTensor* centroids, + DLManagedTensor* labels, + bool normalize_weight, + double* inertia); + /** * @brief Compute cluster cost * diff --git a/c/src/cluster/kmeans.cpp b/c/src/cluster/kmeans.cpp index a84cd50259..8e46764ce4 100644 --- a/c/src/cluster/kmeans.cpp +++ b/c/src/cluster/kmeans.cpp @@ -16,7 +16,9 @@ namespace { -cuvs::cluster::kmeans::params convert_params(const cuvsKMeansParams& params) +// The conversions are templated on the C struct type and reused by both API surfaces. +template +cuvs::cluster::kmeans::params convert_params(const ParamsT& params) { auto kmeans_params = cuvs::cluster::kmeans::params(); kmeans_params.metric = static_cast(params.metric); @@ -28,12 +30,13 @@ cuvs::cluster::kmeans::params convert_params(const cuvsKMeansParams& params) kmeans_params.oversampling_factor = params.oversampling_factor; kmeans_params.batch_samples = params.batch_samples; kmeans_params.batch_centroids = params.batch_centroids; - kmeans_params.inertia_check = params.inertia_check; + kmeans_params.init_size = params.init_size; kmeans_params.streaming_batch_size = params.streaming_batch_size; return kmeans_params; } -cuvs::cluster::kmeans::balanced_params convert_balanced_params(const cuvsKMeansParams& params) +template +cuvs::cluster::kmeans::balanced_params convert_balanced_params(const ParamsT& params) { auto kmeans_params = cuvs::cluster::kmeans::balanced_params(); kmeans_params.metric = static_cast(params.metric); @@ -41,9 +44,9 @@ cuvs::cluster::kmeans::balanced_params convert_balanced_params(const cuvsKMeansP return kmeans_params; } -template +template void _fit(cuvsResources_t res, - const cuvsKMeansParams& params, + const ParamsT& params, DLManagedTensor* X_tensor, DLManagedTensor* sample_weight_tensor, DLManagedTensor* centroids_tensor, @@ -140,9 +143,9 @@ void _fit(cuvsResources_t res, } } -template +template void _predict(cuvsResources_t res, - const cuvsKMeansParams& params, + const ParamsT& params, DLManagedTensor* X_tensor, DLManagedTensor* sample_weight_tensor, DLManagedTensor* centroids_tensor, @@ -237,10 +240,11 @@ extern "C" cuvsError_t cuvsKMeansParamsCreate(cuvsKMeansParams_t* params) .oversampling_factor = cpp_params.oversampling_factor, .batch_samples = cpp_params.batch_samples, .batch_centroids = cpp_params.batch_centroids, - .inertia_check = cpp_params.inertia_check, + .inertia_check = false, .hierarchical = false, .hierarchical_n_iters = static_cast(cpp_balanced_params.n_iters), - .streaming_batch_size = cpp_params.streaming_batch_size}; + .streaming_batch_size = cpp_params.streaming_batch_size, + .init_size = cpp_params.init_size}; }); } @@ -294,6 +298,79 @@ extern "C" cuvsError_t cuvsKMeansPredict(cuvsResources_t res, }); } +extern "C" cuvsError_t cuvsKMeansParamsCreate_v2(cuvsKMeansParams_v2_t* params) +{ + return cuvs::core::translate_exceptions([=] { + cuvs::cluster::kmeans::params cpp_params; + cuvs::cluster::kmeans::balanced_params cpp_balanced_params; + *params = new cuvsKMeansParams_v2{ + .metric = static_cast(cpp_params.metric), + .n_clusters = cpp_params.n_clusters, + .init = static_cast(cpp_params.init), + .max_iter = cpp_params.max_iter, + .tol = cpp_params.tol, + .n_init = cpp_params.n_init, + .oversampling_factor = cpp_params.oversampling_factor, + .batch_samples = cpp_params.batch_samples, + .batch_centroids = cpp_params.batch_centroids, + .hierarchical = false, + .hierarchical_n_iters = static_cast(cpp_balanced_params.n_iters), + .streaming_batch_size = cpp_params.streaming_batch_size, + .init_size = cpp_params.init_size}; + }); +} + +extern "C" cuvsError_t cuvsKMeansParamsDestroy_v2(cuvsKMeansParams_v2_t params) +{ + return cuvs::core::translate_exceptions([=] { delete params; }); +} + +extern "C" cuvsError_t cuvsKMeansFit_v2(cuvsResources_t res, + cuvsKMeansParams_v2_t params, + DLManagedTensor* X, + DLManagedTensor* sample_weight, + DLManagedTensor* centroids, + double* inertia, + int* n_iter) +{ + return cuvs::core::translate_exceptions([=] { + auto dataset = X->dl_tensor; + if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 32) { + _fit(res, *params, X, sample_weight, centroids, inertia, n_iter); + } else if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 64) { + _fit(res, *params, X, sample_weight, centroids, inertia, n_iter); + } else { + RAFT_FAIL("Unsupported dataset DLtensor dtype: %d and bits: %d", + dataset.dtype.code, + dataset.dtype.bits); + } + }); +} + +extern "C" cuvsError_t cuvsKMeansPredict_v2(cuvsResources_t res, + cuvsKMeansParams_v2_t params, + DLManagedTensor* X, + DLManagedTensor* sample_weight, + DLManagedTensor* centroids, + DLManagedTensor* labels, + bool normalize_weight, + double* inertia) +{ + return cuvs::core::translate_exceptions([=] { + auto dataset = X->dl_tensor; + if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 32) { + _predict(res, *params, X, sample_weight, centroids, labels, normalize_weight, inertia); + } else if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 64) { + _predict( + res, *params, X, sample_weight, centroids, labels, normalize_weight, inertia); + } else { + RAFT_FAIL("Unsupported dataset DLtensor dtype: %d and bits: %d", + dataset.dtype.code, + dataset.dtype.bits); + } + }); +} + extern "C" cuvsError_t cuvsKMeansClusterCost(cuvsResources_t res, DLManagedTensor* X, DLManagedTensor* centroids, diff --git a/c/tests/CMakeLists.txt b/c/tests/CMakeLists.txt index a80f0b518a..c1f502adf4 100644 --- a/c/tests/CMakeLists.txt +++ b/c/tests/CMakeLists.txt @@ -74,6 +74,7 @@ ConfigureTest(NAME INTEROP_TEST PATH core/interop.cu) ConfigureTest( NAME DISTANCE_C_TEST PATH distance/run_pairwise_distance_c.c distance/pairwise_distance_c.cu ) +ConfigureTest(NAME KMEANS_C_TEST PATH cluster/kmeans_c.cu) ConfigureTest(NAME BRUTEFORCE_C_TEST PATH neighbors/run_brute_force_c.c neighbors/brute_force_c.cu) ConfigureTest(NAME IVF_FLAT_C_TEST PATH neighbors/run_ivf_flat_c.c neighbors/ann_ivf_flat_c.cu) ConfigureTest(NAME IVF_PQ_C_TEST PATH neighbors/run_ivf_pq_c.c neighbors/ann_ivf_pq_c.cu) diff --git a/c/tests/cluster/kmeans_c.cu b/c/tests/cluster/kmeans_c.cu new file mode 100644 index 0000000000..3c87d035d3 --- /dev/null +++ b/c/tests/cluster/kmeans_c.cu @@ -0,0 +1,264 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include "test_utils.cuh" + +#include +#include + +#include +#include +#include +#include +#include + +#include "../../src/core/interop.hpp" +#include +#include + +#include + +namespace { + +constexpr int64_t kNSamples = 8; +constexpr int64_t kNFeatures = 2; +constexpr int kNClusters = 2; + +float kDataset[kNSamples][kNFeatures] = { + {1.0f, 1.0f}, + {1.0f, 2.0f}, + {2.0f, 1.0f}, + {2.0f, 2.0f}, + {10.0f, 10.0f}, + {10.0f, 11.0f}, + {11.0f, 10.0f}, + {11.0f, 11.0f}, +}; + +float kInitCentroids[kNClusters][kNFeatures] = { + {0.0f, 0.0f}, + {12.0f, 12.0f}, +}; + +float kExpectedCentroids[kNClusters * kNFeatures] = {1.5f, 1.5f, 10.5f, 10.5f}; +int32_t kExpectedLabels[kNSamples] = {0, 0, 0, 0, 1, 1, 1, 1}; + +// 8 points, each at squared distance 0.5 from its cluster mean -> 4.0. +constexpr double kExpectedInertia = 4.0; + +// Type-erased dispatcher to exercise both the v1 and v2 entry points with +// shared test bodies. +struct kmeans_api_v1 { + using params_t = cuvsKMeansParams_t; + static cuvsError_t params_create(params_t* p) { return cuvsKMeansParamsCreate(p); } + static cuvsError_t params_destroy(params_t p) { return cuvsKMeansParamsDestroy(p); } + static cuvsError_t fit(cuvsResources_t res, + params_t params, + DLManagedTensor* dataset, + DLManagedTensor* centroids, + double* inertia, + int* n_iter) + { + return cuvsKMeansFit(res, params, dataset, NULL, centroids, inertia, n_iter); + } + static cuvsError_t predict(cuvsResources_t res, + params_t params, + DLManagedTensor* dataset, + DLManagedTensor* centroids, + DLManagedTensor* labels, + double* inertia) + { + return cuvsKMeansPredict( + res, params, dataset, NULL, centroids, labels, false, inertia); + } +}; + +struct kmeans_api_v2 { + using params_t = cuvsKMeansParams_v2_t; + static cuvsError_t params_create(params_t* p) { return cuvsKMeansParamsCreate_v2(p); } + static cuvsError_t params_destroy(params_t p) { return cuvsKMeansParamsDestroy_v2(p); } + static cuvsError_t fit(cuvsResources_t res, + params_t params, + DLManagedTensor* dataset, + DLManagedTensor* centroids, + double* inertia, + int* n_iter) + { + return cuvsKMeansFit_v2(res, params, dataset, NULL, centroids, inertia, n_iter); + } + static cuvsError_t predict(cuvsResources_t res, + params_t params, + DLManagedTensor* dataset, + DLManagedTensor* centroids, + DLManagedTensor* labels, + double* inertia) + { + return cuvsKMeansPredict_v2( + res, params, dataset, NULL, centroids, labels, false, inertia); + } +}; + +template +void test_fit_predict() +{ + raft::handle_t handle; + auto stream = raft::resource::get_cuda_stream(handle); + + rmm::device_uvector dataset_d(kNSamples * kNFeatures, stream); + rmm::device_uvector centroids_d(kNClusters * kNFeatures, stream); + rmm::device_uvector labels_d(kNSamples, stream); + + raft::copy(dataset_d.data(), + reinterpret_cast(kDataset), + kNSamples * kNFeatures, + stream); + raft::copy(centroids_d.data(), + reinterpret_cast(kInitCentroids), + kNClusters * kNFeatures, + stream); + + cuvsResources_t res; + ASSERT_EQ(cuvsResourcesCreate(&res), CUVS_SUCCESS); + + typename Api::params_t params; + ASSERT_EQ(Api::params_create(¶ms), CUVS_SUCCESS); + params->n_clusters = kNClusters; + params->max_iter = 100; + params->tol = 1e-6; + params->init = Array; + params->streaming_batch_size = 0; + + DLManagedTensor dataset_t{}; + cuvs::core::to_dlpack( + raft::make_device_matrix_view(dataset_d.data(), kNSamples, kNFeatures), + &dataset_t); + + DLManagedTensor centroids_t{}; + cuvs::core::to_dlpack( + raft::make_device_matrix_view(centroids_d.data(), kNClusters, kNFeatures), + ¢roids_t); + + DLManagedTensor labels_t{}; + cuvs::core::to_dlpack( + raft::make_device_vector_view(labels_d.data(), kNSamples), &labels_t); + + double inertia = -1.0; + int n_iter = -1; + double predict_inertia = -1.0; + double cluster_cost = -1.0; + + ASSERT_EQ(Api::fit(res, params, &dataset_t, ¢roids_t, &inertia, &n_iter), CUVS_SUCCESS); + ASSERT_EQ(Api::predict(res, params, &dataset_t, ¢roids_t, &labels_t, &predict_inertia), + CUVS_SUCCESS); + ASSERT_EQ(cuvsKMeansClusterCost(res, &dataset_t, ¢roids_t, &cluster_cost), CUVS_SUCCESS); + + ASSERT_TRUE(cuvs::devArrMatchHost(kExpectedCentroids, + centroids_d.data(), + kNClusters * kNFeatures, + cuvs::CompareApprox(1e-4f))); + ASSERT_TRUE(cuvs::devArrMatchHost( + kExpectedLabels, labels_d.data(), kNSamples, cuvs::Compare())); + + EXPECT_GT(n_iter, 0); + EXPECT_NEAR(inertia, kExpectedInertia, 1e-4); + EXPECT_NEAR(predict_inertia, kExpectedInertia, 1e-4); + EXPECT_NEAR(cluster_cost, kExpectedInertia, 1e-4); + + labels_t.deleter(&labels_t); + centroids_t.deleter(¢roids_t); + dataset_t.deleter(&dataset_t); + + ASSERT_EQ(Api::params_destroy(params), CUVS_SUCCESS); + ASSERT_EQ(cuvsResourcesDestroy(res), CUVS_SUCCESS); +} + +template +void test_fit_host() +{ + raft::handle_t handle; + auto stream = raft::resource::get_cuda_stream(handle); + + rmm::device_uvector centroids_d(kNClusters * kNFeatures, stream); + raft::copy(centroids_d.data(), + reinterpret_cast(kInitCentroids), + kNClusters * kNFeatures, + stream); + + cuvsResources_t res; + ASSERT_EQ(cuvsResourcesCreate(&res), CUVS_SUCCESS); + + typename Api::params_t params; + ASSERT_EQ(Api::params_create(¶ms), CUVS_SUCCESS); + params->n_clusters = kNClusters; + params->max_iter = 100; + params->tol = 1e-6; + params->init = Array; + params->streaming_batch_size = 4; // force at least 2 streamed batches + + DLManagedTensor dataset_t{}; + cuvs::core::to_dlpack( + raft::make_host_matrix_view( + reinterpret_cast(kDataset), kNSamples, kNFeatures), + &dataset_t); + + DLManagedTensor centroids_t{}; + cuvs::core::to_dlpack( + raft::make_device_matrix_view(centroids_d.data(), kNClusters, kNFeatures), + ¢roids_t); + + double inertia = -1.0; + int n_iter = -1; + + ASSERT_EQ(Api::fit(res, params, &dataset_t, ¢roids_t, &inertia, &n_iter), CUVS_SUCCESS); + + ASSERT_TRUE(cuvs::devArrMatchHost(kExpectedCentroids, + centroids_d.data(), + kNClusters * kNFeatures, + cuvs::CompareApprox(1e-4f))); + + EXPECT_GT(n_iter, 0); + EXPECT_NEAR(inertia, kExpectedInertia, 1e-4); + + centroids_t.deleter(¢roids_t); + dataset_t.deleter(&dataset_t); + + ASSERT_EQ(Api::params_destroy(params), CUVS_SUCCESS); + ASSERT_EQ(cuvsResourcesDestroy(res), CUVS_SUCCESS); +} + +} // namespace + +TEST(KMeansC, FitPredict) { test_fit_predict(); } +// TODO(cuVS 26.08): remove FitPredictV2 once `_v2` is promoted to the +// unsuffixed ABI -- it will be redundant with FitPredict at that point. +TEST(KMeansC, FitPredictV2) { test_fit_predict(); } + +TEST(KMeansC, FitHost) { test_fit_host(); } +// TODO(cuVS 26.08): remove FitHostV2 once `_v2` is promoted to the +// unsuffixed ABI. +TEST(KMeansC, FitHostV2) { test_fit_host(); } + +TEST(KMeansC, ParamsCreateDestroy) +{ + cuvsKMeansParams_t params = nullptr; + ASSERT_EQ(cuvsKMeansParamsCreate(¶ms), CUVS_SUCCESS); + ASSERT_NE(params, nullptr); + EXPECT_GT(params->n_clusters, 0); + EXPECT_GT(params->max_iter, 0); + ASSERT_EQ(cuvsKMeansParamsDestroy(params), CUVS_SUCCESS); +} + +// TODO(cuVS 26.08): remove ParamsCreateDestroyV2 once cuvsKMeansParamsCreate_v2 +// / cuvsKMeansParamsDestroy_v2 are promoted to the unsuffixed entry points and +// the `_v2` symbols are deleted from the public header. +TEST(KMeansC, ParamsCreateDestroyV2) +{ + cuvsKMeansParams_v2_t params = nullptr; + ASSERT_EQ(cuvsKMeansParamsCreate_v2(¶ms), CUVS_SUCCESS); + ASSERT_NE(params, nullptr); + EXPECT_GT(params->n_clusters, 0); + EXPECT_GT(params->max_iter, 0); + ASSERT_EQ(cuvsKMeansParamsDestroy_v2(params), CUVS_SUCCESS); +} diff --git a/cpp/include/cuvs/cluster/kmeans.hpp b/cpp/include/cuvs/cluster/kmeans.hpp index d299d9f483..36e6f46c51 100644 --- a/cpp/include/cuvs/cluster/kmeans.hpp +++ b/cpp/include/cuvs/cluster/kmeans.hpp @@ -113,9 +113,18 @@ struct params : base_params { int batch_centroids = 0; /** - * If true, check inertia during iterations for early convergence. + * Number of samples to randomly draw for the KMeansPlusPlus initialization + * step. A random subset of this size is used for centroid seeding. + * + * Only applies when dataset is on host; for device data the full dataset + * is always used for seeding and this parameter is ignored. + * + * When set to 0 (default) with host data uses `min(3 * n_clusters, n_samples)` + * as a default. + * + * Default: 0. */ - bool inertia_check = false; + int64_t init_size = 0; /** * Number of samples to process per GPU batch when fitting with host data. diff --git a/cpp/src/cluster/detail/kmeans.cuh b/cpp/src/cluster/detail/kmeans.cuh index 97332d7165..297b4b2fd3 100644 --- a/cpp/src/cluster/detail/kmeans.cuh +++ b/cpp/src/cluster/detail/kmeans.cuh @@ -5,6 +5,7 @@ #pragma once #include "../../core/nvtx.hpp" +#include "../../neighbors/detail/ann_utils.cuh" #include "kmeans_common.cuh" #include @@ -20,6 +21,8 @@ #include #include #include +#include +#include #include #include #include @@ -31,6 +34,7 @@ #include #include #include +#include #include #include #include @@ -44,6 +48,7 @@ #include #include #include +#include #include #include @@ -259,197 +264,18 @@ void kmeansPlusPlus(raft::resources const& handle, } /// <<<< Step-5 >>> } -/** - * - * @tparam DataT - * @tparam IndexT - * @param handle - * @param[in] X input matrix (size n_samples, n_features) - * @param[in] weight number of samples currently assigned to each centroid - * @param[in] cur_centroids matrix of current centroids (size n_clusters, n_features) - * @param[in] l2norm_x - * @param[out] min_cluster_and_dist - * @param[out] new_centroids - * @param[out] new_weight - * @param[inout] workspace - */ -template -void update_centroids(raft::resources const& handle, - raft::device_matrix_view X, - raft::device_vector_view sample_weights, - raft::device_matrix_view centroids, - - // TODO: Figure out how to best wrap iterator types in mdspan - LabelsIterator cluster_labels, - raft::device_vector_view weight_per_cluster, - raft::device_matrix_view new_centroids, - rmm::device_uvector& workspace) -{ - auto n_clusters = centroids.extent(0); - - cuvs::cluster::kmeans::detail::compute_centroid_adjustments(handle, - X, - sample_weights, - cluster_labels, - static_cast(n_clusters), - new_centroids, - weight_per_cluster, - workspace); - - cuvs::cluster::kmeans::detail::finalize_centroids(handle, - raft::make_const_mdspan(new_centroids), - raft::make_const_mdspan(weight_per_cluster), - centroids, - new_centroids); -} - -// TODO: Resizing is needed to use mdarray instead of rmm::device_uvector -template -void kmeans_fit_main(raft::resources const& handle, - const cuvs::cluster::kmeans::params& params, - raft::device_matrix_view X, - raft::device_vector_view weight, - raft::device_matrix_view centroidsRawData, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter, - rmm::device_uvector& workspace) -{ - RAFT_EXPECTS(params.metric == cuvs::distance::DistanceType::L2Expanded || - params.metric == cuvs::distance::DistanceType::L2SqrtExpanded, - "kmeans only supports L2Expanded or L2SqrtExpanded distance metrics."); - raft::common::nvtx::range fun_scope("kmeans_fit_main"); - raft::default_logger().set_level(params.verbosity); - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - auto n_samples = X.extent(0); - auto n_features = X.extent(1); - auto n_clusters = params.n_clusters; - auto metric = params.metric; - - // stores (key, value) pair corresponding to each sample where - // - key is the index of nearest cluster - // - value is the distance to the nearest cluster - auto minClusterAndDistance = - raft::make_device_vector, IndexT>(handle, n_samples); - - // temporary buffer to store L2 norm of centroids or distance matrix, - // destructor releases the resource - rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); - - // temporary buffer to store intermediate centroids, destructor releases the - // resource - auto newCentroids = raft::make_device_matrix(handle, n_clusters, n_features); - - // temporary buffer to store weights per cluster, destructor releases the - // resource - auto wtInCluster = raft::make_device_vector(handle, n_clusters); - - rmm::device_scalar clusterCostD(stream); - - // L2 norm of X: ||x||^2 - auto L2NormX = raft::make_device_vector(handle, n_samples); - auto l2normx_view = - raft::make_device_vector_view(L2NormX.data_handle(), n_samples); - - if (metric == cuvs::distance::DistanceType::L2Expanded || - metric == cuvs::distance::DistanceType::L2SqrtExpanded) { - raft::linalg::norm(handle, X, L2NormX.view()); - } - - RAFT_LOG_DEBUG( - "Calling KMeans.fit with %d samples of input data and the initialized " - "cluster centers", - n_samples); - - DataT priorClusteringCost = 0; - for (n_iter[0] = 1; n_iter[0] <= params.max_iter; ++n_iter[0]) { - RAFT_LOG_DEBUG( - "KMeans.fit: Iteration-%d: fitting the model using the initialized " - "cluster centers", - n_iter[0]); - - auto centroids = raft::make_device_matrix_view( - centroidsRawData.data_handle(), n_clusters, n_features); - - // computes minClusterAndDistance[0:n_samples) where - // minClusterAndDistance[i] is a pair where - // 'key' is index to a sample in 'centroids' (index of the nearest - // centroid) and 'value' is the distance between the sample 'X[i]' and the - // 'centroid[key]' - cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( - handle, - X, - centroids, - minClusterAndDistance.view(), - l2normx_view, - L2NormBuf_OR_DistBuf, - params.metric, - params.batch_samples, - params.batch_centroids, - workspace); - - update_centroids( - handle, - X, - weight, - raft::make_device_matrix_view( - centroidsRawData.data_handle(), n_clusters, n_features), - cuda::transform_iterator(minClusterAndDistance.data_handle(), - cuvs::cluster::kmeans::detail::KeyValueIndexOp{}), - wtInCluster.view(), - newCentroids.view(), - workspace); - - // Compute how much centroids shifted - DataT sqrdNormError = compute_centroid_shift( - handle, raft::make_const_mdspan(centroids), raft::make_const_mdspan(newCentroids.view())); - - raft::copy(handle, - raft::make_device_vector_view(centroidsRawData.data_handle(), newCentroids.size()), - raft::make_device_vector_view(newCentroids.data_handle(), newCentroids.size())); - - bool done = false; - if (params.inertia_check) { - // calculate cluster cost phi_x(C) - cuvs::cluster::kmeans::detail::computeClusterCost( - handle, - minClusterAndDistance.view(), - workspace, - raft::make_device_scalar_view(clusterCostD.data()), - raft::value_op{}, - raft::add_op{}); - - DataT curClusteringCost = clusterCostD.value(stream); - - ASSERT(curClusteringCost != (DataT)0.0, - "Too few points and centroids being found is getting 0 cost from " - "centers"); - - if (n_iter[0] > 1) { - DataT delta = curClusteringCost / priorClusteringCost; - if (delta > 1 - params.tol) done = true; - } - priorClusteringCost = curClusteringCost; - } - - if (sqrdNormError < params.tol) done = true; - - if (done) { - RAFT_LOG_DEBUG("Threshold triggered after %d iterations. Terminating early.", n_iter[0]); - break; - } - } - - cuvs::cluster::kmeans::cluster_cost(handle, - X, - raft::make_device_matrix_view( - centroidsRawData.data_handle(), n_clusters, n_features), - inertia, - std::make_optional(weight)); - - RAFT_LOG_DEBUG("KMeans.fit: completed after %d iterations with %f inertia[0] ", - n_iter[0] > params.max_iter ? n_iter[0] - 1 : n_iter[0], - inertia[0]); -} +template +void kmeans_fit( + raft::resources const& handle, + const cuvs::cluster::kmeans::params& pams, + raft::mdspan, raft::row_major, Accessor> X, + std::optional< + raft::mdspan, raft::layout_right, Accessor>> + sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter, + std::optional>> workspace = std::nullopt); /* * @brief Selects 'n_clusters' samples from X using scalable kmeans++ algorithm. @@ -654,17 +480,21 @@ void initScalableKMeansPlusPlus(raft::resources const& handle, auto inertia = raft::make_host_scalar(0); auto n_iter = raft::make_host_scalar(0); - cuvs::cluster::kmeans::params default_params; - default_params.n_clusters = params.n_clusters; - - cuvs::cluster::kmeans::detail::kmeans_fit_main(handle, - default_params, - potentialCentroids, - weight.view(), - centroidsRawData, - inertia.view(), - n_iter.view(), - workspace); + cuvs::cluster::kmeans::params recluster_params; + recluster_params.n_clusters = params.n_clusters; + recluster_params.init = cuvs::cluster::kmeans::params::InitMethod::Array; + recluster_params.n_init = 1; + + auto weight_opt = std::make_optional(raft::make_const_mdspan(weight.view())); + cuvs::cluster::kmeans::detail::kmeans_fit( + handle, + recluster_params, + raft::make_const_mdspan(potentialCentroids), + weight_opt, + centroidsRawData, + inertia.view(), + n_iter.view(), + std::ref(workspace)); } else if ((int)potentialCentroids.extent(0) < n_clusters) { // supplement with random @@ -700,37 +530,40 @@ void initScalableKMeansPlusPlus(raft::resources const& handle, } /** - * @brief Find clusters with k-means algorithm. - * Initial centroids are chosen with k-means++ algorithm. Empty - * clusters are reinitialized by choosing new centroids with - * k-means++ algorithm. - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. + * @brief Unified k-means fit (works with host or device data). + * + * @tparam DataT Data / weight type + * @tparam IndexT Index type + * @tparam Accessor Accessor policy (host or device); deduced from X + * * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. - * @param[in] X Training instances to cluster. It must be noted - * that the data must be in row-major format and stored in device accessible - * location. - * @param[in] n_samples Number of samples in the input X. - * @param[in] n_features Number of features or the dimensions of each - * sample. + * @param[in] pams Parameters for the KMeans model. + * @param[in] X Training instances to cluster (host or device). + * Row-major, [n_samples x n_features]. * @param[in] sample_weight Optional weights for each observation in X. - * @param[inout] centroids [in] When init is InitMethod::Array, use - * centroids as the initial cluster centers - * [out] Otherwise, generated centroids from the - * kmeans algorithm is stored at the address pointed by 'centroids'. + * [n_samples]. When std::nullopt, uniform weights + * are used. + * @param[inout] centroids [in] When init is InitMethod::Array, used as + * the initial cluster centers. + * [out] The final centroids produced by the + * algorithm. [n_clusters x n_features]. * @param[out] inertia Sum of squared distances of samples to their - * closest cluster center. - * @param[out] n_iter Number of iterations run. + * closest cluster center. + * @param[out] n_iter Number of iterations run for the best + * initialization. */ -template -void kmeans_fit(raft::resources const& handle, - const cuvs::cluster::kmeans::params& pams, - raft::device_matrix_view X, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) +template +void kmeans_fit( + raft::resources const& handle, + const cuvs::cluster::kmeans::params& pams, + raft::mdspan, raft::row_major, Accessor> X, + std::optional< + raft::mdspan, raft::layout_right, Accessor>> + sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter, + std::optional>> workspace) { RAFT_EXPECTS(pams.metric == cuvs::distance::DistanceType::L2Expanded || pams.metric == cuvs::distance::DistanceType::L2SqrtExpanded, @@ -739,54 +572,100 @@ void kmeans_fit(raft::resources const& handle, auto n_samples = X.extent(0); auto n_features = X.extent(1); auto n_clusters = pams.n_clusters; + auto metric = pams.metric; cudaStream_t stream = raft::resource::get_cuda_stream(handle); - // Check that parameters are valid + if (sample_weight.has_value()) RAFT_EXPECTS(sample_weight.value().extent(0) == n_samples, "invalid parameter (sample_weight!=n_samples)"); RAFT_EXPECTS(n_clusters > 0, "invalid parameter (n_clusters<=0)"); RAFT_EXPECTS(pams.tol > 0, "invalid parameter (tol<=0)"); RAFT_EXPECTS(pams.oversampling_factor >= 0, "invalid parameter (oversampling_factor<0)"); - RAFT_EXPECTS((int)centroids.extent(0) == pams.n_clusters, + RAFT_EXPECTS(static_cast(centroids.extent(0)) == n_clusters, "invalid parameter (centroids.extent(0) != n_clusters)"); RAFT_EXPECTS(centroids.extent(1) == n_features, "invalid parameter (centroids.extent(1) != n_features)"); - // Display a message if the batch size is smaller than n_samples but will be ignored - if (pams.batch_samples < (int)n_samples && - (pams.metric == cuvs::distance::DistanceType::L2Expanded || - pams.metric == cuvs::distance::DistanceType::L2SqrtExpanded)) { - RAFT_LOG_DEBUG( - "batch_samples=%d was passed, but batch_samples=%d will be used (reason: " - "batch_samples has no impact on the memory footprint when FusedL2NN can be used)", - pams.batch_samples, - (int)n_samples); + raft::default_logger().set_level(pams.verbosity); + + IndexT streaming_batch_size = static_cast(pams.streaming_batch_size); + if (streaming_batch_size <= 0 || streaming_batch_size > static_cast(n_samples)) { + streaming_batch_size = static_cast(n_samples); } - // Display a message if batch_centroids is set and a fusedL2NN-compatible metric is used - if (pams.batch_centroids != 0 && pams.batch_centroids != pams.n_clusters && - (pams.metric == cuvs::distance::DistanceType::L2Expanded || - pams.metric == cuvs::distance::DistanceType::L2SqrtExpanded)) { - RAFT_LOG_DEBUG( - "batch_centroids=%d was passed, but batch_centroids=%d will be used (reason: " - "batch_centroids has no impact on the memory footprint when FusedL2NN can be used)", - pams.batch_centroids, - pams.n_clusters); + + constexpr bool data_on_device = raft::is_device_mdspan_v; + + const DataT* weight_ptr = + sample_weight.has_value() ? sample_weight.value().data_handle() : nullptr; + + auto d_wt_sum = raft::make_device_scalar(handle, static_cast(n_samples)); + if (sample_weight.has_value()) { weightSum(handle, sample_weight.value(), d_wt_sum.view()); } + + rmm::device_uvector local_workspace(0, stream); + rmm::device_uvector& ws = workspace.has_value() ? workspace->get() : local_workspace; + + if (data_on_device && streaming_batch_size != static_cast(n_samples)) { + RAFT_LOG_WARN( + "KMeans: streaming_batch_size (%zu) ignored when data resides on device; using n_samples " + "(%zu)", + static_cast(streaming_batch_size), + static_cast(n_samples)); + streaming_batch_size = static_cast(n_samples); } - raft::default_logger().set_level(pams.verbosity); + // Preallocate the host-side KMeans++ init sample buffer. + std::optional> init_sample; + if constexpr (!data_on_device) { + if (pams.init == cuvs::cluster::kmeans::params::InitMethod::KMeansPlusPlus) { + IndexT default_init_size = + std::min(static_cast(std::int64_t{3} * n_clusters), n_samples); + IndexT init_sample_size = pams.init_size > 0 + ? std::min(static_cast(pams.init_size), n_samples) + : default_init_size; + + if (pams.init_size <= 0 && init_sample_size < n_samples) { + RAFT_LOG_WARN( + "KMeans.fit: KMeans++ initialization is using a random subsample of %zu/%zu host rows " + "(params.init_size=0 defaults to min(3 * n_clusters, n_samples) for host data). " + "Set params.init_size to n_samples to use the full dataset for seeding.", + static_cast(init_sample_size), + static_cast(n_samples)); + } - // Allocate memory - rmm::device_uvector workspace(0, stream); - auto weight = raft::make_device_vector(handle, n_samples); - if (sample_weight.has_value()) - raft::copy(handle, weight.view(), sample_weight.value()); - else - raft::matrix::fill(handle, weight.view(), DataT(1)); + init_sample = raft::make_device_matrix(handle, init_sample_size, n_features); + } + } - // check if weights sum up to n_samples - checkWeight(handle, weight.view(), workspace); + auto init_centroids = [&](const cuvs::cluster::kmeans::params& iter_params, + raft::device_matrix_view centroidsRawData) { + if (iter_params.init == cuvs::cluster::kmeans::params::InitMethod::Array) { + raft::copy(handle, centroidsRawData, centroids); + return; + } - auto centroidsRawData = raft::make_device_matrix(handle, n_clusters, n_features); + raft::random::RngState random_state(iter_params.rng_state.seed); + + if (iter_params.init == cuvs::cluster::kmeans::params::InitMethod::Random) { + raft::matrix::sample_rows(handle, random_state, X, centroidsRawData); + } else if (iter_params.init == cuvs::cluster::kmeans::params::InitMethod::KMeansPlusPlus) { + auto run_kmeanspp = [&](raft::device_matrix_view init_data) { + if (iter_params.oversampling_factor == 0) + kmeansPlusPlus(handle, iter_params, init_data, centroidsRawData, ws); + else + initScalableKMeansPlusPlus( + handle, iter_params, init_data, centroidsRawData, ws); + }; + + if constexpr (data_on_device) { + run_kmeanspp(X); + } else { + raft::matrix::sample_rows(handle, random_state, X, init_sample->view()); + run_kmeanspp(raft::make_const_mdspan(init_sample->view())); + } + } else { + THROW("unknown initialization method to select initial centers"); + } + }; auto n_init = pams.n_init; if (pams.init == cuvs::cluster::kmeans::params::InitMethod::Array && n_init != 1) { @@ -797,70 +676,302 @@ void kmeans_fit(raft::resources const& handle, n_init = 1; } + IndexT centroid_buf_size = n_clusters * n_features; + rmm::device_uvector cur_centroids_buf(centroid_buf_size, stream); + rmm::device_uvector new_centroids_buf(centroid_buf_size, stream); + DataT* cur_centroids_ptr = cur_centroids_buf.data(); + DataT* new_centroids_ptr = new_centroids_buf.data(); + + auto minClusterAndDistance = raft::make_device_vector, IndexT>( + handle, streaming_batch_size); + auto L2NormBatch = raft::make_device_vector(handle, streaming_batch_size); + auto batch_weights_buf = raft::make_device_vector(handle, streaming_batch_size); + rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); + + auto centroid_sums = raft::make_device_matrix(handle, n_clusters, n_features); + auto weight_per_cluster = raft::make_device_vector(handle, n_clusters); + auto clustering_cost = raft::make_device_scalar(handle, DataT{0}); + + rmm::device_uvector batch_workspace(streaming_batch_size, stream); + + cuvs::spatial::knn::detail::utils::batch_load_iterator data_batches( + X.data_handle(), n_samples, n_features, streaming_batch_size, stream); + // Host-path weight batches: only materialized when weights are provided and + // the data resides on host + std::optional> weight_batches; + if constexpr (!data_on_device) { + if (weight_ptr != nullptr) { + weight_batches.emplace(weight_ptr, n_samples, 1, streaming_batch_size, stream); + } else { + raft::matrix::fill(handle, batch_weights_buf.view(), DataT{1}); + } + } else if (weight_ptr == nullptr) { + raft::matrix::fill(handle, batch_weights_buf.view(), DataT{1}); + } + + std::optional> prenormalized_weights; + if constexpr (data_on_device) { + if (weight_ptr != nullptr) { + prenormalized_weights.emplace(raft::make_device_vector(handle, n_samples)); + const DataT* d_wt_sum_ptr = d_wt_sum.data_handle(); + raft::linalg::map( + handle, + prenormalized_weights->view(), + [n_samples, d_wt_sum_ptr] __device__(DataT w) { + return w * static_cast(n_samples) / *d_wt_sum_ptr; + }, + sample_weight.value()); + } + } + + // Copies and rescales `wt_data` into `batch_weights_buf` so that weights + // are normalized to sum to n_samples. + const DataT* d_wt_sum_ptr = d_wt_sum.data_handle(); + auto prepare_batch_weights = [&](const DataT* wt_data, IndexT cur_batch_size) { + if (wt_data != nullptr) { + raft::copy(batch_weights_buf.data_handle(), wt_data, cur_batch_size, stream); + auto bw = raft::make_device_vector_view(batch_weights_buf.data_handle(), + cur_batch_size); + raft::linalg::map( + handle, + bw, + [n_samples, d_wt_sum_ptr] __device__(DataT w) { + return w * static_cast(n_samples) / *d_wt_sum_ptr; + }, + raft::make_const_mdspan(bw)); + } + return raft::make_device_vector_view(batch_weights_buf.data_handle(), + cur_batch_size); + }; + + auto cur_batch_weights = [&](IndexT batch_offset, const DataT* wt_data, IndexT cur_batch_size) { + if constexpr (data_on_device) { + const DataT* base = prenormalized_weights.has_value() + ? prenormalized_weights->data_handle() + batch_offset + : batch_weights_buf.data_handle(); + return raft::make_device_vector_view(base, cur_batch_size); + } else { + return prepare_batch_weights(wt_data, cur_batch_size); + } + }; + + RAFT_LOG_DEBUG( + "KMeans.fit: n_samples=%zu, n_features=%zu, n_clusters=%d, streaming_batch_size=%zu", + static_cast(n_samples), + static_cast(n_features), + n_clusters, + static_cast(streaming_batch_size)); + + bool need_compute_norms = metric == cuvs::distance::DistanceType::L2Expanded || + metric == cuvs::distance::DistanceType::L2SqrtExpanded; + auto h_norm_cache = raft::make_pinned_vector( + handle, (need_compute_norms && !data_on_device) ? n_samples : 0); + bool norms_cached = false; + + auto compute_batch_norms = [&](const DataT* batch_ptr, IndexT batch_size) { + auto batch_view = + raft::make_device_matrix_view(batch_ptr, batch_size, n_features); + auto norm_view = + raft::make_device_vector_view(L2NormBatch.data_handle(), batch_size); + raft::linalg::norm( + handle, batch_view, norm_view); + }; + + // Device path: compute X norms once up front + if constexpr (data_on_device) { + if (need_compute_norms) { + compute_batch_norms(X.data_handle(), static_cast(n_samples)); + } + } + std::mt19937 gen(pams.rng_state.seed); inertia[0] = std::numeric_limits::max(); - for (auto seed_iter = 0; seed_iter < n_init; ++seed_iter) { + for (int seed_iter = 0; seed_iter < n_init; ++seed_iter) { cuvs::cluster::kmeans::params iter_params = pams; iter_params.rng_state.seed = gen(); + RAFT_LOG_DEBUG("KMeans.fit: n_init iteration %d/%d (seed=%llu)", + seed_iter + 1, + n_init, + (unsigned long long)iter_params.rng_state.seed); + + cur_centroids_ptr = cur_centroids_buf.data(); + new_centroids_ptr = new_centroids_buf.data(); + init_centroids( + iter_params, + raft::make_device_matrix_view(cur_centroids_ptr, n_clusters, n_features)); + DataT iter_inertia = std::numeric_limits::max(); IndexT n_current_iter = 0; - if (iter_params.init == cuvs::cluster::kmeans::params::InitMethod::Random) { - // initializing with random samples from input dataset - RAFT_LOG_DEBUG( - "KMeans.fit (Iteration-%d/%d): initialize cluster centers by " - "randomly choosing from the " - "input data.", - seed_iter + 1, - n_init); - initRandom(handle, iter_params, X, centroidsRawData.view()); - } else if (iter_params.init == cuvs::cluster::kmeans::params::InitMethod::KMeansPlusPlus) { - // default method to initialize is kmeans++ - RAFT_LOG_DEBUG( - "KMeans.fit (Iteration-%d/%d): initialize cluster centers using " - "k-means++ algorithm.", - seed_iter + 1, - n_init); - if (iter_params.oversampling_factor == 0) - cuvs::cluster::kmeans::detail::kmeansPlusPlus( - handle, iter_params, X, centroidsRawData.view(), workspace); - else - cuvs::cluster::kmeans::detail::initScalableKMeansPlusPlus( - handle, iter_params, X, centroidsRawData.view(), workspace); - } else if (iter_params.init == cuvs::cluster::kmeans::params::InitMethod::Array) { - RAFT_LOG_DEBUG( - "KMeans.fit (Iteration-%d/%d): initialize cluster centers from " - "the ndarray array input " - "passed to init argument.", - seed_iter + 1, - n_init); - raft::copy( + auto sqrdNormError = raft::make_device_scalar(handle, DataT{0}); + + auto d_prior_cost = raft::make_device_scalar(handle, DataT{0}); + auto d_done_flag = raft::make_device_scalar(handle, 0); + auto h_done_flag = raft::make_pinned_scalar(handle, 0); + + for (n_current_iter = 1; n_current_iter <= iter_params.max_iter; ++n_current_iter) { + if (n_current_iter > 1) { + raft::resource::sync_stream(handle); + if (*h_done_flag.data_handle()) { + n_current_iter--; + RAFT_LOG_DEBUG("Threshold triggered after %d iterations. Terminating early.", + n_current_iter); + break; + } + } + + RAFT_LOG_DEBUG("KMeans.fit: Iteration-%d", n_current_iter); + + raft::matrix::fill(handle, centroid_sums.view(), DataT{0}); + raft::matrix::fill(handle, weight_per_cluster.view(), DataT{0}); + raft::matrix::fill(handle, clustering_cost.view(), DataT{0}); + + auto centroids_const = raft::make_device_matrix_view( + cur_centroids_ptr, n_clusters, n_features); + auto new_centroids_view = + raft::make_device_matrix_view(new_centroids_ptr, n_clusters, n_features); + + data_batches.reset(); + using wt_iter_t = cuvs::spatial::knn::detail::utils::batch_load_iterator; + std::optional wt_it; + if (weight_batches.has_value()) { + weight_batches->reset(); + wt_it = weight_batches->begin(); + } + for (const auto& data_batch : data_batches) { + IndexT cur_batch_size = static_cast(data_batch.size()); + const DataT* wt_data = nullptr; + if (wt_it.has_value()) { + wt_data = (**wt_it).data(); + ++(*wt_it); + } + + auto batch_data_view = raft::make_device_matrix_view( + data_batch.data(), cur_batch_size, n_features); + auto batch_weights_view = + cur_batch_weights(static_cast(data_batch.offset()), wt_data, cur_batch_size); + + auto minCAD_view = raft::make_device_vector_view, IndexT>( + minClusterAndDistance.data_handle(), cur_batch_size); + + if constexpr (!data_on_device) { + if (need_compute_norms) { + if (!norms_cached) { + compute_batch_norms(data_batch.data(), cur_batch_size); + raft::copy(h_norm_cache.data_handle() + data_batch.offset(), + L2NormBatch.data_handle(), + cur_batch_size, + stream); + } else { + raft::copy(L2NormBatch.data_handle(), + h_norm_cache.data_handle() + data_batch.offset(), + cur_batch_size, + stream); + } + } + } + + auto l2_const_view = raft::make_device_vector_view( + L2NormBatch.data_handle(), cur_batch_size); + + process_batch(handle, + batch_data_view, + batch_weights_view, + centroids_const, + metric, + iter_params.batch_samples, + iter_params.batch_centroids, + minCAD_view, + l2_const_view, + L2NormBuf_OR_DistBuf, + ws, + centroid_sums.view(), + weight_per_cluster.view(), + clustering_cost.view(), + batch_workspace); + } + if (need_compute_norms) { norms_cached = true; } + + finalize_centroids(handle, + raft::make_const_mdspan(centroid_sums.view()), + raft::make_const_mdspan(weight_per_cluster.view()), + centroids_const, + new_centroids_view); + + compute_centroid_shift(handle, + raft::make_const_mdspan(centroids_const), + raft::make_const_mdspan(new_centroids_view), + sqrdNormError.view()); + + std::swap(cur_centroids_ptr, new_centroids_ptr); + + auto d_cost_view = raft::make_device_scalar_view(clustering_cost.data_handle()); + auto d_prior_view = d_prior_cost.view(); + auto d_norm_view = raft::make_device_scalar_view(sqrdNormError.data_handle()); + auto d_done_view = d_done_flag.view(); + DataT tol = iter_params.tol; + int iter = n_current_iter; + + raft::linalg::map_offset( handle, - raft::make_device_vector_view(centroidsRawData.data_handle(), n_clusters * n_features), - raft::make_device_vector_view(centroids.data_handle(), n_clusters * n_features)); - } else { - THROW("unknown initialization method to select initial centers"); + raft::make_device_vector_view(d_done_flag.data_handle(), 1), + [=] __device__(int) { + check_convergence(d_cost_view, d_prior_view, d_norm_view, tol, iter, d_done_view); + return *d_done_view.data_handle(); + }); + + raft::copy(handle, + raft::make_pinned_scalar_view(h_done_flag.data_handle()), + raft::make_device_scalar_view(d_done_flag.data_handle())); + } + + { + auto centroids_const = raft::make_device_matrix_view( + cur_centroids_ptr, n_clusters, n_features); + + iter_inertia = DataT{0}; + data_batches.reset(); + using wt_iter_t = cuvs::spatial::knn::detail::utils::batch_load_iterator; + std::optional wt_it; + if (weight_batches.has_value()) { + weight_batches->reset(); + wt_it = weight_batches->begin(); + } + for (const auto& data_batch : data_batches) { + IndexT cur_batch_size = static_cast(data_batch.size()); + const DataT* wt_data = nullptr; + if (wt_it.has_value()) { + wt_data = (**wt_it).data(); + ++(*wt_it); + } + + auto batch_data_view = raft::make_device_matrix_view( + data_batch.data(), cur_batch_size, n_features); + + std::optional> batch_sw = std::nullopt; + if (weight_ptr != nullptr) { + batch_sw = + cur_batch_weights(static_cast(data_batch.offset()), wt_data, cur_batch_size); + } + + DataT batch_cost = DataT{0}; + cuvs::cluster::kmeans::cluster_cost(handle, + batch_data_view, + centroids_const, + raft::make_host_scalar_view(&batch_cost), + batch_sw); + + iter_inertia += batch_cost; + } } - cuvs::cluster::kmeans::detail::kmeans_fit_main( - handle, - iter_params, - X, - weight.view(), - centroidsRawData.view(), - raft::make_host_scalar_view(&iter_inertia), - raft::make_host_scalar_view(&n_current_iter), - workspace); if (iter_inertia < inertia[0]) { inertia[0] = iter_inertia; n_iter[0] = n_current_iter; - raft::copy( - handle, - raft::make_device_vector_view(centroids.data_handle(), n_clusters * n_features), - raft::make_device_vector_view(centroidsRawData.data_handle(), n_clusters * n_features)); + raft::copy(centroids.data_handle(), cur_centroids_ptr, centroid_buf_size, stream); } - RAFT_LOG_DEBUG("KMeans.fit after iteration-%d/%d: inertia - %f, n_iter[0] - %d", + RAFT_LOG_DEBUG("KMeans.fit after iteration-%d/%d: inertia - %f, n_iter - %d", seed_iter + 1, n_init, inertia[0], @@ -883,15 +994,26 @@ void kmeans_fit(raft::resources const& handle, auto XView = raft::make_device_matrix_view(X, n_samples, n_features); auto centroidsView = raft::make_device_matrix_view(centroids, pams.n_clusters, n_features); - std::optional> sample_weightView = std::nullopt; + std::optional> sample_weightView = std::nullopt; if (sample_weight) sample_weightView = raft::make_device_vector_view(sample_weight, n_samples); auto inertiaView = raft::make_host_scalar_view(&inertia); auto n_iterView = raft::make_host_scalar_view(&n_iter); - cuvs::cluster::kmeans::detail::kmeans_fit( - handle, pams, XView, sample_weightView, centroidsView, inertiaView, n_iterView); + kmeans_fit(handle, pams, XView, sample_weightView, centroidsView, inertiaView, n_iterView); +} + +template +void fit(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) +{ + kmeans_fit(handle, params, X, sample_weight, centroids, inertia, n_iter); } template @@ -935,8 +1057,18 @@ void kmeans_predict(raft::resources const& handle, else raft::matrix::fill(handle, weight.view(), DataT(1)); - // check if weights sum up to n_samples - if (normalize_weight) checkWeight(handle, weight.view(), workspace); + if (normalize_weight && sample_weight.has_value()) { + auto d_wt_sum = raft::make_device_scalar(handle, DataT{0}); + weightSum(handle, raft::make_const_mdspan(weight.view()), d_wt_sum.view()); + const DataT* d_wt_sum_ptr = d_wt_sum.data_handle(); + raft::linalg::map( + handle, + weight.view(), + [n_samples, d_wt_sum_ptr] __device__(DataT w) { + return w * static_cast(n_samples) / *d_wt_sum_ptr; + }, + raft::make_const_mdspan(weight.view())); + } auto minClusterAndDistance = raft::make_device_vector, IndexT>(handle, n_samples); diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh deleted file mode 100644 index 6c0b9a3253..0000000000 --- a/cpp/src/cluster/detail/kmeans_batched.cuh +++ /dev/null @@ -1,513 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. - * SPDX-License-Identifier: Apache-2.0 - */ -#pragma once - -#include "kmeans.cuh" -#include "kmeans_common.cuh" - -#include "../../neighbors/detail/ann_utils.cuh" -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -#include - -#include -#include -#include -#include -#include -#include - -namespace cuvs::cluster::kmeans::detail { - -/** - * @brief Initialize centroids from host data - * - * @tparam T Input data type - * @tparam IdxT Index type - */ -template -void init_centroids_from_host_sample(raft::resources const& handle, - const cuvs::cluster::kmeans::params& params, - IdxT streaming_batch_size, - raft::host_matrix_view X, - raft::device_matrix_view centroids, - rmm::device_uvector& workspace) -{ - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - auto n_samples = X.extent(0); - auto n_features = X.extent(1); - auto n_clusters = params.n_clusters; - - if (params.init == cuvs::cluster::kmeans::params::InitMethod::KMeansPlusPlus) { - // this is a heuristic to speed up the initialization - IdxT init_sample_size = 3 * streaming_batch_size; - if (init_sample_size < n_clusters) { init_sample_size = 3 * n_clusters; } - init_sample_size = std::min(init_sample_size, n_samples); - - auto init_sample = raft::make_device_matrix(handle, init_sample_size, n_features); - raft::random::RngState random_state(params.rng_state.seed); - raft::matrix::sample_rows(handle, random_state, X, init_sample.view()); - - if (params.oversampling_factor == 0) { - cuvs::cluster::kmeans::detail::kmeansPlusPlus( - handle, params, raft::make_const_mdspan(init_sample.view()), centroids, workspace); - } else { - cuvs::cluster::kmeans::detail::initScalableKMeansPlusPlus( - handle, params, raft::make_const_mdspan(init_sample.view()), centroids, workspace); - } - } else if (params.init == cuvs::cluster::kmeans::params::InitMethod::Random) { - raft::random::RngState random_state(params.rng_state.seed); - raft::matrix::sample_rows(handle, random_state, X, centroids); - } else if (params.init == cuvs::cluster::kmeans::params::InitMethod::Array) { - // already provided - } else { - RAFT_FAIL("Unknown initialization method"); - } -} - -/** - * @brief Compute the weight normalization scale factor for host sample weights. Weights are - * normalized to sum to n_samples. Returns the scale factor to apply to each weight. - * - * @param[in] sample_weight Optional host vector of sample weights - * @param[in] n_samples Number of samples - * @return Scale factor (1.0 if no weights or already normalized) - */ -template -T compute_host_weight_scale( - const std::optional>& sample_weight, IdxT n_samples) -{ - if (!sample_weight.has_value()) { return T{1}; } - - T wt_sum = T{0}; - const T* sw_ptr = sample_weight->data_handle(); - for (IdxT i = 0; i < n_samples; ++i) { - wt_sum += sw_ptr[i]; - } - if (wt_sum == static_cast(n_samples)) { return T{1}; } - - RAFT_LOG_DEBUG( - "[Warning!] KMeans: normalizing the user provided sample weight to " - "sum up to %zu samples (scale=%f)", - static_cast(n_samples), - static_cast(static_cast(n_samples) / wt_sum)); - return static_cast(n_samples) / wt_sum; -} - -/** - * @brief Copy host sample weights to device and apply normalization scale. - * - * When sample_weight is provided, copies the batch slice to the device buffer - * and applies the normalization scale factor. When not provided, the device - * buffer is assumed to already be filled with 1.0. - * - * @param[in] handle RAFT resources handle - * @param[in] sample_weight Optional host weights - * @param[in] batch_offset Offset into the host weights for this batch - * @param[in] batch_size Number of elements in this batch - * @param[in] weight_scale Scale factor from compute_host_weight_scale - * @param[inout] batch_weights Device buffer to write normalized weights into - */ -template -void copy_and_scale_batch_weights( - raft::resources const& handle, - const std::optional>& sample_weight, - size_t batch_offset, - IdxT batch_size, - T weight_scale, - raft::device_vector& batch_weights) -{ - if (!sample_weight.has_value()) { return; } - - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - raft::copy( - batch_weights.data_handle(), sample_weight->data_handle() + batch_offset, batch_size, stream); - - if (weight_scale != T{1}) { - auto batch_weights_view = - raft::make_device_vector_view(batch_weights.data_handle(), batch_size); - raft::linalg::map(handle, - batch_weights_view, - raft::mul_const_op{weight_scale}, - raft::make_const_mdspan(batch_weights_view)); - } -} - -/** - * @brief Accumulate partial centroid sums and counts from a batch - * - * This function adds the partial sums from a batch to the running accumulators. - * It does NOT divide - that happens once at the end of all batches. - */ -template -void accumulate_batch_centroids( - raft::resources const& handle, - raft::device_matrix_view batch_data, - raft::device_vector_view, IdxT> minClusterAndDistance, - raft::device_vector_view sample_weights, - raft::device_matrix_view centroid_sums, - raft::device_vector_view cluster_counts, - raft::device_matrix_view batch_sums, - raft::device_vector_view batch_counts) -{ - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - - auto workspace = rmm::device_uvector( - batch_data.extent(0), stream, raft::resource::get_workspace_resource_ref(handle)); - - cuvs::cluster::kmeans::detail::KeyValueIndexOp conversion_op; - thrust::transform_iterator, - const raft::KeyValuePair*> - labels_itr(minClusterAndDistance.data_handle(), conversion_op); - - cuvs::cluster::kmeans::detail::compute_centroid_adjustments( - handle, - batch_data, - sample_weights, - labels_itr, - static_cast(centroid_sums.extent(0)), - batch_sums, - batch_counts, - workspace); - - raft::linalg::add(centroid_sums.data_handle(), - centroid_sums.data_handle(), - batch_sums.data_handle(), - centroid_sums.size(), - stream); - - raft::linalg::add(cluster_counts.data_handle(), - cluster_counts.data_handle(), - batch_counts.data_handle(), - cluster_counts.size(), - stream); -} - -/** - * @brief Main fit function for batched k-means with host data (full-batch / Lloyd's algorithm). - * - * Processes host data in GPU-sized batches per iteration, accumulating partial centroid - * sums and counts, then finalizes centroids at the end of each iteration. - * - * @tparam T Input data type (float, double) - * @tparam IdxT Index type (int, int64_t) - * - * @param[in] handle RAFT resources handle - * @param[in] params K-means parameters - * @param[in] X Input data on HOST [n_samples x n_features] - * @param[in] sample_weight Optional weights per sample (on host) - * @param[inout] centroids Initial/output cluster centers [n_clusters x n_features] - * @param[out] inertia Sum of squared distances to nearest centroid - * @param[out] n_iter Number of iterations run - */ -template -void fit(raft::resources const& handle, - const cuvs::cluster::kmeans::params& params, - raft::host_matrix_view X, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) -{ - RAFT_EXPECTS(params.metric == cuvs::distance::DistanceType::L2Expanded || - params.metric == cuvs::distance::DistanceType::L2SqrtExpanded, - "kmeans only supports L2Expanded or L2SqrtExpanded distance metrics."); - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - auto n_samples = X.extent(0); - auto n_features = X.extent(1); - auto n_clusters = params.n_clusters; - auto metric = params.metric; - - IdxT streaming_batch_size = static_cast(params.streaming_batch_size); - - if (params.streaming_batch_size == 0) { - streaming_batch_size = static_cast(n_samples); - } else if (params.streaming_batch_size < 0 || params.streaming_batch_size > n_samples) { - RAFT_LOG_WARN("streaming_batch_size must be >= 0 and <= n_samples, using n_samples=%zu", - static_cast(n_samples)); - streaming_batch_size = static_cast(n_samples); - } - - RAFT_EXPECTS(n_clusters > 0, "n_clusters must be positive"); - RAFT_EXPECTS(static_cast(centroids.extent(0)) == n_clusters, - "centroids.extent(0) must equal n_clusters"); - RAFT_EXPECTS(centroids.extent(1) == n_features, "centroids.extent(1) must equal n_features"); - - RAFT_LOG_DEBUG( - "KMeans batched fit: n_samples=%zu, n_features=%zu, n_clusters=%d, streaming_batch_size=%zu", - static_cast(n_samples), - static_cast(n_features), - n_clusters, - static_cast(streaming_batch_size)); - - rmm::device_uvector workspace(0, stream); - - auto n_init = params.n_init; - if (params.init == cuvs::cluster::kmeans::params::InitMethod::Array && n_init != 1) { - RAFT_LOG_DEBUG( - "Explicit initial center position passed: performing only one init in " - "k-means instead of n_init=%d", - n_init); - n_init = 1; - } - - auto best_centroids = n_init > 1 - ? raft::make_device_matrix(handle, n_clusters, n_features) - : raft::make_device_matrix(handle, 0, 0); - T best_inertia = std::numeric_limits::max(); - IdxT best_n_iter = 0; - - std::mt19937 gen(params.rng_state.seed); - - // ----- Allocate reusable work buffers (shared across n_init iterations) ----- - auto batch_weights = raft::make_device_vector(handle, streaming_batch_size); - auto minClusterAndDistance = - raft::make_device_vector, IdxT>(handle, streaming_batch_size); - auto L2NormBatch = raft::make_device_vector(handle, streaming_batch_size); - rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); - - auto centroid_sums = raft::make_device_matrix(handle, n_clusters, n_features); - auto weight_per_cluster = raft::make_device_vector(handle, n_clusters); - auto new_centroids = raft::make_device_matrix(handle, n_clusters, n_features); - auto clustering_cost = raft::make_device_vector(handle, 1); - auto batch_clustering_cost = raft::make_device_vector(handle, 1); - auto batch_sums = raft::make_device_matrix(handle, n_clusters, n_features); - auto batch_counts = raft::make_device_vector(handle, n_clusters); - - // Compute weight normalization (matches checkWeight in regular kmeans) - T weight_scale = compute_host_weight_scale(sample_weight, n_samples); - - // ---- Main n_init loop ---- - for (int seed_iter = 0; seed_iter < n_init; ++seed_iter) { - cuvs::cluster::kmeans::params iter_params = params; - iter_params.rng_state.seed = gen(); - - RAFT_LOG_DEBUG("KMeans batched fit: n_init iteration %d/%d (seed=%llu)", - seed_iter + 1, - n_init, - (unsigned long long)iter_params.rng_state.seed); - - if (iter_params.init != cuvs::cluster::kmeans::params::InitMethod::Array) { - init_centroids_from_host_sample( - handle, iter_params, streaming_batch_size, X, centroids, workspace); - } - - if (!sample_weight.has_value()) { raft::matrix::fill(handle, batch_weights.view(), T{1}); } - - // Reset per-iteration state - T prior_cluster_cost = 0; - - cuvs::spatial::knn::detail::utils::batch_load_iterator data_batches( - X.data_handle(), n_samples, n_features, streaming_batch_size, stream); - - for (n_iter[0] = 1; n_iter[0] <= iter_params.max_iter; ++n_iter[0]) { - RAFT_LOG_DEBUG("KMeans batched: Iteration %d", n_iter[0]); - - raft::matrix::fill(handle, centroid_sums.view(), T{0}); - raft::matrix::fill(handle, weight_per_cluster.view(), T{0}); - - raft::matrix::fill(handle, clustering_cost.view(), T{0}); - - auto centroids_const = raft::make_const_mdspan(centroids); - - for (const auto& data_batch : data_batches) { - IdxT current_batch_size = static_cast(data_batch.size()); - raft::matrix::fill(handle, batch_clustering_cost.view(), T{0}); - - auto batch_data_view = raft::make_device_matrix_view( - data_batch.data(), current_batch_size, n_features); - - copy_and_scale_batch_weights(handle, - sample_weight, - data_batch.offset(), - current_batch_size, - weight_scale, - batch_weights); - - auto batch_weights_view = raft::make_device_vector_view( - batch_weights.data_handle(), current_batch_size); - - auto L2NormBatch_view = - raft::make_device_vector_view(L2NormBatch.data_handle(), current_batch_size); - - if (metric == cuvs::distance::DistanceType::L2Expanded || - metric == cuvs::distance::DistanceType::L2SqrtExpanded) { - raft::linalg::norm( - handle, - raft::make_device_matrix_view( - data_batch.data(), current_batch_size, n_features), - L2NormBatch_view); - } - - auto L2NormBatch_const = raft::make_const_mdspan(L2NormBatch_view); - - auto minClusterAndDistance_view = - raft::make_device_vector_view, IdxT>( - minClusterAndDistance.data_handle(), current_batch_size); - - cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( - handle, - batch_data_view, - centroids_const, - minClusterAndDistance_view, - L2NormBatch_const, - L2NormBuf_OR_DistBuf, - metric, - params.batch_samples, - params.batch_centroids, - workspace); - - auto minClusterAndDistance_const = raft::make_const_mdspan(minClusterAndDistance_view); - - accumulate_batch_centroids(handle, - batch_data_view, - minClusterAndDistance_const, - batch_weights_view, - centroid_sums.view(), - weight_per_cluster.view(), - batch_sums.view(), - batch_counts.view()); - - if (params.inertia_check) { - raft::linalg::map( - handle, - minClusterAndDistance_view, - [=] __device__(const raft::KeyValuePair kvp, T wt) { - raft::KeyValuePair res; - res.value = kvp.value * wt; - res.key = kvp.key; - return res; - }, - raft::make_const_mdspan(minClusterAndDistance_view), - batch_weights_view); - - cuvs::cluster::kmeans::detail::computeClusterCost( - handle, - minClusterAndDistance_view, - workspace, - raft::make_device_scalar_view(batch_clustering_cost.data_handle()), - raft::value_op{}, - raft::add_op{}); - raft::linalg::add(handle, - raft::make_const_mdspan(clustering_cost.view()), - raft::make_const_mdspan(batch_clustering_cost.view()), - clustering_cost.view()); - } - } - - auto centroid_sums_const = raft::make_device_matrix_view( - centroid_sums.data_handle(), n_clusters, n_features); - auto weight_per_cluster_const = - raft::make_device_vector_view(weight_per_cluster.data_handle(), n_clusters); - - finalize_centroids(handle, - centroid_sums_const, - weight_per_cluster_const, - centroids_const, - new_centroids.view()); - - T sqrdNormError = compute_centroid_shift( - handle, raft::make_const_mdspan(centroids), raft::make_const_mdspan(new_centroids.view())); - - raft::copy(handle, centroids, new_centroids.view()); - - bool done = false; - if (params.inertia_check) { - raft::copy(inertia.data_handle(), clustering_cost.data_handle(), 1, stream); - raft::resource::sync_stream(handle); - ASSERT(inertia[0] != (T)0.0, - "Too few points and centroids being found is getting 0 cost from " - "centers"); - if (n_iter[0] > 1) { - T delta = inertia[0] / prior_cluster_cost; - if (delta > 1 - params.tol) done = true; - } - prior_cluster_cost = inertia[0]; - } - - if (sqrdNormError < params.tol) done = true; - - if (done) { - RAFT_LOG_DEBUG("Threshold triggered after %d iterations. Terminating early.", n_iter[0]); - break; - } - } - - // Recompute final weighted inertia with the converged centroids. - { - auto centroids_const_view = raft::make_device_matrix_view( - centroids.data_handle(), n_clusters, n_features); - - inertia[0] = T{0}; - for (const auto& data_batch : data_batches) { - IdxT current_batch_size = static_cast(data_batch.size()); - - auto batch_data_view = raft::make_device_matrix_view( - data_batch.data(), current_batch_size, n_features); - - std::optional> batch_sw = std::nullopt; - if (sample_weight.has_value()) { - copy_and_scale_batch_weights(handle, - sample_weight, - data_batch.offset(), - current_batch_size, - weight_scale, - batch_weights); - batch_sw = raft::make_device_vector_view(batch_weights.data_handle(), - current_batch_size); - } - - T batch_cost = T{0}; - cuvs::cluster::kmeans::cluster_cost(handle, - batch_data_view, - centroids_const_view, - raft::make_host_scalar_view(&batch_cost), - batch_sw); - - inertia[0] += batch_cost; - } - } - - RAFT_LOG_DEBUG("KMeans batched: n_init %d/%d completed with inertia=%f", - seed_iter + 1, - n_init, - static_cast(inertia[0])); - - if (n_init > 1 && inertia[0] < best_inertia) { - best_inertia = inertia[0]; - best_n_iter = n_iter[0]; - raft::copy(best_centroids.data_handle(), centroids.data_handle(), centroids.size(), stream); - } - } - if (n_init > 1) { - inertia[0] = best_inertia; - n_iter[0] = best_n_iter; - raft::copy(handle, centroids, best_centroids.view()); - RAFT_LOG_DEBUG("KMeans batched: Best of %d runs: inertia=%f, n_iter=%d", - n_init, - static_cast(best_inertia), - best_n_iter); - } -} - -} // namespace cuvs::cluster::kmeans::detail diff --git a/cpp/src/cluster/detail/kmeans_common.cuh b/cpp/src/cluster/detail/kmeans_common.cuh index e42d868dd9..348a0ec0e0 100644 --- a/cpp/src/cluster/detail/kmeans_common.cuh +++ b/cpp/src/cluster/detail/kmeans_common.cuh @@ -16,6 +16,7 @@ #include #include #include +#include #include #include #include @@ -42,6 +43,9 @@ #include #include #include +#include + +#include #include #include @@ -128,42 +132,30 @@ void countLabels(raft::resources const& handle, stream)); } -template -void checkWeight(raft::resources const& handle, - raft::device_vector_view weight, - rmm::device_uvector& workspace) +/** + * @brief Compute the sum of sample weights into a device scalar. + * + * Device-accessible mdspans are reduced on device. Host mdspans are summed on the host. + */ +template +void weightSum( + raft::resources const& handle, + raft::mdspan, raft::layout_right, Accessor> weight, + raft::device_scalar_view d_wt_sum) { - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - auto wt_aggr = raft::make_device_scalar(handle, 0); - auto n_samples = weight.extent(0); + auto n_samples = weight.extent(0); + auto stream = raft::resource::get_cuda_stream(handle); - size_t temp_storage_bytes = 0; - RAFT_CUDA_TRY(cub::DeviceReduce::Sum( - nullptr, temp_storage_bytes, weight.data_handle(), wt_aggr.data_handle(), n_samples, stream)); - - workspace.resize(temp_storage_bytes, stream); - - RAFT_CUDA_TRY(cub::DeviceReduce::Sum(workspace.data(), - temp_storage_bytes, - weight.data_handle(), - wt_aggr.data_handle(), - n_samples, - stream)); - DataT wt_sum = 0; - raft::copy(handle, - raft::make_host_scalar_view(&wt_sum), - raft::make_device_scalar_view(wt_aggr.data_handle())); - raft::resource::sync_stream(handle, stream); - - if (wt_sum != n_samples) { - RAFT_LOG_DEBUG( - "[Warning!] KMeans: normalizing the user provided sample weight to " - "sum up to %d samples", - n_samples); - - auto scale = static_cast(n_samples) / wt_sum; - raft::linalg::map( - handle, weight, raft::mul_const_op{scale}, raft::make_const_mdspan(weight)); + if constexpr (raft::is_device_mdspan_v) { + raft::linalg::mapThenSumReduce( + d_wt_sum.data_handle(), n_samples, raft::identity_op{}, stream, weight.data_handle()); + } else { + DataT wt_sum = DataT{0}; + for (IndexT i = 0; i < n_samples; ++i) { + wt_sum += weight(i); + } + RAFT_EXPECTS(wt_sum > DataT{0}, "invalid parameter (sum of sample weights must be positive)"); + raft::copy(d_wt_sum.data_handle(), &wt_sum, 1, stream); } } @@ -261,7 +253,7 @@ void sampleCentroids(raft::resources const& handle, raft::copy(handle, raft::make_host_scalar_view(&nPtsSampledInRank), raft::make_device_scalar_view(nSelected.data_handle())); - raft::resource::sync_stream(handle, stream); + raft::resource::sync_stream(handle); uint8_t* rawPtr_isSampleCentroid = isSampleCentroid.data_handle(); thrust::for_each_n(raft::resource::get_thrust_policy(handle), @@ -483,14 +475,18 @@ void countSamplesInCluster(raft::resources const& handle, * @tparam IndexT Index type * @tparam LabelsIterator Iterator type for cluster labels * - * @param[in] handle RAFT resources handle - * @param[in] X Input samples [n_samples x n_features] - * @param[in] sample_weights Weights for each sample [n_samples] - * @param[in] cluster_labels Cluster assignment for each sample (iterator) - * @param[in] n_clusters Number of clusters - * @param[out] centroid_sums Output weighted sum per cluster [n_clusters x n_features] - * @param[out] weight_per_cluster Output sum of weights per cluster [n_clusters] - * @param[inout] workspace Workspace buffer for intermediate operations + * @param[in] handle RAFT resources handle + * @param[in] X Input samples [n_samples x n_features] + * @param[in] sample_weights Weights for each sample [n_samples] + * @param[in] cluster_labels Cluster assignment for each sample (iterator) + * @param[in] n_clusters Number of clusters + * @param[inout] centroid_sums Weighted sum per cluster [n_clusters x n_features] + * @param[inout] weight_per_cluster Sum of weights per cluster [n_clusters]. Follows the same + * overwrite-vs-accumulate semantics as `centroid_sums` + * @param[inout] workspace Workspace buffer for intermediate operations + * @param[in] reset_sums If true (default), outputs are reset to zero before reducing; + * if false, this call's contribution is accumulated into the + * existing `centroid_sums` */ template void compute_centroid_adjustments( @@ -501,7 +497,8 @@ void compute_centroid_adjustments( IndexT n_clusters, raft::device_matrix_view centroid_sums, raft::device_vector_view weight_per_cluster, - rmm::device_uvector& workspace) + rmm::device_uvector& workspace, + bool reset_sums = true) { cudaStream_t stream = raft::resource::get_cuda_stream(handle); auto n_samples = X.extent(0); @@ -517,7 +514,8 @@ void compute_centroid_adjustments( X.extent(1), n_clusters, centroid_sums.data_handle(), - stream); + stream, + reset_sums); raft::linalg::reduce_cols_by_key(sample_weights.data_handle(), cluster_labels, @@ -525,7 +523,8 @@ void compute_centroid_adjustments( static_cast(1), static_cast(n_samples), n_clusters, - stream); + stream, + reset_sums); } /** * @brief Finalize centroids by dividing accumulated sums by counts. @@ -575,26 +574,148 @@ void finalize_centroids(raft::resources const& handle, /** * @brief Compute the squared norm difference between two centroid sets. * - * Returns sum((old_centroids - new_centroids)^2). - * Used for convergence checking. + * Writes sum((old_centroids - new_centroids)^2) into @p sqrd_norm_out. + * Used for convergence checking. Fully asynchronous — no stream sync. */ template -DataT compute_centroid_shift(raft::resources const& handle, - raft::device_matrix_view old_centroids, - raft::device_matrix_view new_centroids) +void compute_centroid_shift(raft::resources const& handle, + raft::device_matrix_view old_centroids, + raft::device_matrix_view new_centroids, + raft::device_scalar_view sqrd_norm_out) { cudaStream_t stream = raft::resource::get_cuda_stream(handle); - auto sqrdNorm = raft::make_device_scalar(handle, DataT{0}); - raft::linalg::mapThenSumReduce(sqrdNorm.data_handle(), + raft::linalg::mapThenSumReduce(sqrd_norm_out.data_handle(), old_centroids.size(), raft::sqdiff_op{}, stream, old_centroids.data_handle(), new_centroids.data_handle()); - DataT result = 0; - raft::copy(&result, sqrdNorm.data_handle(), 1, stream); - raft::resource::sync_stream(handle, stream); - return result; +} + +/** + * @brief Evaluate convergence criteria entirely on device. + * + * Checks the cost-ratio and centroid-shift stopping conditions and writes + * a boolean result (0 or 1) into @p done_flag. Also advances + * @p prior_clustering_cost to the current cost for the next iteration. + */ +template +__device__ void check_convergence(raft::device_scalar_view clustering_cost, + raft::device_scalar_view prior_clustering_cost, + raft::device_scalar_view sqrd_norm_error, + DataT tol, + int n_iter, + raft::device_scalar_view done_flag) +{ + DataT cur_cost = *clustering_cost.data_handle(); + DataT norm_err = *sqrd_norm_error.data_handle(); + int done = 0; + + if (cur_cost != DataT{0} && n_iter > 1) { + DataT delta = cur_cost / *prior_clustering_cost.data_handle(); + if (delta > DataT{1} - tol) done = 1; + } + if (norm_err < tol) done = 1; + + *prior_clustering_cost.data_handle() = cur_cost; + *done_flag.data_handle() = done; +} + +/** + * @brief Process a single batch of data in the Lloyd iteration. + * + * Given one batch of data + precomputed norms + weights + current centroids it + * 1. finds the nearest centroid for every sample, + * 2. accumulates weighted centroid sums and counts into the running accumulators, + * 3. accumulates the weighted clustering cost (inertia). + * + * Data norms must be precomputed by the caller and passed in via L2NormBatch. + * + * @tparam DataT Data / weight type (float, double) + * @tparam IndexT Index type (int, int64_t) + * + * @param[in] handle RAFT resources handle + * @param[in] batch_data Device batch data [batch_size x n_features] + * @param[in] batch_weights Device batch weights [batch_size] + * @param[in] centroids Current centroids [n_clusters x n_features] + * @param[in] metric Distance metric + * @param[in] batch_samples_param Batch-samples param forwarded to minClusterAndDistanceCompute + * @param[in] batch_centroids_param Batch-centroids param forwarded to + * minClusterAndDistanceCompute + * @param[inout] minClusterAndDistance Work buffer [batch_size] + * @param[in] L2NormBatch Precomputed data norms [batch_size] + * @param[inout] L2NormBuf_OR_DistBuf Resizable scratch + * @param[inout] workspace Resizable scratch + * @param[inout] centroid_sums Running weighted sums [n_clusters x n_features] (added into) + * @param[inout] weight_per_cluster Running weight counts [n_clusters] (added into) + * @param[inout] clustering_cost Running cost scalar (device) (added into) + */ +template +void process_batch( + raft::resources const& handle, + raft::device_matrix_view batch_data, + raft::device_vector_view batch_weights, + raft::device_matrix_view centroids, + cuvs::distance::DistanceType metric, + int batch_samples_param, + int batch_centroids_param, + raft::device_vector_view, IndexT> minClusterAndDistance, + raft::device_vector_view L2NormBatch, + rmm::device_uvector& L2NormBuf_OR_DistBuf, + rmm::device_uvector& workspace, + raft::device_matrix_view centroid_sums, + raft::device_vector_view weight_per_cluster, + raft::device_scalar_view clustering_cost, + rmm::device_uvector& batch_workspace) +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + + minClusterAndDistanceCompute(handle, + batch_data, + centroids, + minClusterAndDistance, + L2NormBatch, + L2NormBuf_OR_DistBuf, + metric, + batch_samples_param, + batch_centroids_param, + workspace); + + KeyValueIndexOp conversion_op; + thrust::transform_iterator, + const raft::KeyValuePair*> + labels_itr(minClusterAndDistance.data_handle(), conversion_op); + + compute_centroid_adjustments(handle, + batch_data, + batch_weights, + labels_itr, + static_cast(centroid_sums.extent(0)), + centroid_sums, + weight_per_cluster, + batch_workspace, + /*reset_sums=*/false); + + raft::linalg::map( + handle, + minClusterAndDistance, + [=] __device__(const raft::KeyValuePair kvp, DataT wt) { + raft::KeyValuePair res; + res.value = kvp.value * wt; + res.key = kvp.key; + return res; + }, + raft::make_const_mdspan(minClusterAndDistance), + batch_weights); + + auto batch_cost = raft::make_device_scalar(handle, DataT{0}); + computeClusterCost( + handle, minClusterAndDistance, workspace, batch_cost.view(), raft::value_op{}, raft::add_op{}); + raft::linalg::add(clustering_cost.data_handle(), + clustering_cost.data_handle(), + batch_cost.data_handle(), + 1, + stream); } } // namespace cuvs::cluster::kmeans::detail diff --git a/cpp/src/cluster/detail/kmeans_mg.cuh b/cpp/src/cluster/detail/kmeans_mg.cuh index 09347f3153..f4830dd56c 100644 --- a/cpp/src/cluster/detail/kmeans_mg.cuh +++ b/cpp/src/cluster/detail/kmeans_mg.cuh @@ -467,15 +467,9 @@ void checkWeights(const raft::resources& handle, const auto& comm = raft::resource::get_comms(handle); - auto n_samples = weight.extent(0); - size_t temp_storage_bytes = 0; - RAFT_CUDA_TRY(cub::DeviceReduce::Sum( - nullptr, temp_storage_bytes, weight.data_handle(), wt_aggr.data(), n_samples, stream)); - - workspace.resize(temp_storage_bytes, stream); - - RAFT_CUDA_TRY(cub::DeviceReduce::Sum( - workspace.data(), temp_storage_bytes, weight.data_handle(), wt_aggr.data(), n_samples, stream)); + auto n_samples = weight.extent(0); + raft::linalg::mapThenSumReduce( + wt_aggr.data(), n_samples, raft::identity_op{}, stream, weight.data_handle()); comm.allreduce(wt_aggr.data(), // sendbuff wt_aggr.data(), // recvbuff @@ -484,6 +478,7 @@ void checkWeights(const raft::resources& handle, stream); DataT wt_sum = wt_aggr.value(stream); raft::resource::sync_stream(handle, stream); + RAFT_EXPECTS(wt_sum > DataT{0}, "invalid parameter (sum of sample weights must be positive)"); if (wt_sum != n_samples) { CUVS_LOG_KMEANS(handle, @@ -491,9 +486,11 @@ void checkWeights(const raft::resources& handle, "sum up to %d samples", n_samples); - DataT scale = n_samples / wt_sum; - raft::linalg::map( - handle, weight, raft::mul_const_op(scale), raft::make_const_mdspan(weight)); + raft::linalg::map(handle, + weight, + raft::compose_op(raft::mul_const_op{static_cast(n_samples)}, + raft::div_const_op{wt_sum}), + raft::make_const_mdspan(weight)); } } @@ -704,49 +701,45 @@ void fit(const raft::resources& handle, raft::make_device_vector_view(newCentroids.data_handle(), newCentroids.size())); bool done = false; - if (params.inertia_check) { - rmm::device_scalar> clusterCostD(stream); - - // calculate cluster cost phi_x(C) - cuvs::cluster::kmeans::cluster_cost( - handle, - minClusterAndDistance.view(), - workspace, - raft::make_device_scalar_view(clusterCostD.data()), - cuda::proclaim_return_type>( - [] __device__(const raft::KeyValuePair& a, - const raft::KeyValuePair& b) { - raft::KeyValuePair res; - res.key = 0; - res.value = a.value + b.value; - return res; - })); - - // Cluster cost phi_x(C) from all ranks - comm.allreduce(&(clusterCostD.data()->value), - &(clusterCostD.data()->value), - 1, - raft::comms::op_t::SUM, - stream); - - DataT curClusteringCost = 0; - raft::copy(handle, - raft::make_host_scalar_view(&curClusteringCost), - raft::make_device_scalar_view(&(clusterCostD.data()->value))); - - ASSERT(comm.sync_stream(stream) == raft::comms::status_t::SUCCESS, - "An error occurred in the distributed operation. This can result " - "from a failed rank"); - ASSERT(curClusteringCost != (DataT)0.0, - "Too few points and centroids being found is getting 0 cost from " - "centers\n"); - - if (n_iter[0] > 1) { - DataT delta = curClusteringCost / priorClusteringCost; - if (delta > 1 - params.tol) done = true; - } - priorClusteringCost = curClusteringCost; + rmm::device_scalar> clusterCostD(stream); + + // calculate cluster cost phi_x(C) + cuvs::cluster::kmeans::cluster_cost( + handle, + minClusterAndDistance.view(), + workspace, + raft::make_device_scalar_view(clusterCostD.data()), + cuda::proclaim_return_type>( + [] __device__(const raft::KeyValuePair& a, + const raft::KeyValuePair& b) { + raft::KeyValuePair res; + res.key = 0; + res.value = a.value + b.value; + return res; + })); + + // Cluster cost phi_x(C) from all ranks + comm.allreduce(&(clusterCostD.data()->value), + &(clusterCostD.data()->value), + 1, + raft::comms::op_t::SUM, + stream); + + DataT curClusteringCost = 0; + raft::copy(handle, + raft::make_host_scalar_view(&curClusteringCost), + raft::make_device_scalar_view(&(clusterCostD.data()->value))); + + ASSERT(comm.sync_stream(stream) == raft::comms::status_t::SUCCESS, + "An error occurred in the distributed operation. This can result " + "from a failed rank"); + if (curClusteringCost == (DataT)0.0) { + RAFT_LOG_WARN("Zero clustering cost detected: all points coincide with their centroids."); + } else if (n_iter[0] > 1) { + DataT delta = curClusteringCost / priorClusteringCost; + if (delta > 1 - params.tol) done = true; } + priorClusteringCost = curClusteringCost; raft::resource::sync_stream(handle, stream); if (sqrdNormError < params.tol) done = true; diff --git a/cpp/src/cluster/kmeans.cuh b/cpp/src/cluster/kmeans.cuh index e4f9821990..12d3cc2c4b 100644 --- a/cpp/src/cluster/kmeans.cuh +++ b/cpp/src/cluster/kmeans.cuh @@ -324,39 +324,6 @@ void cluster_cost(raft::resources const& handle, handle, minClusterDistance, workspace, clusterCost, raft::identity_op{}, reduction_op); } -/** - * @brief Update centroids given current centroids and number of points assigned to each centroid. - * This function also produces a vector of RAFT key/value pairs containing the cluster assignment - * for each point and its distance. - * - * @tparam DataT - * @tparam IndexT - * @param[in] handle: Raft handle to use for managing library resources - * @param[in] X: input matrix (size n_samples, n_features) - * @param[in] sample_weights: number of samples currently assigned to each centroid (size n_samples) - * @param[in] centroids: matrix of current centroids (size n_clusters, n_features) - * @param[in] labels: Iterator of labels (can also be a raw pointer) - * @param[out] weight_per_cluster: sum of sample weights per cluster (size n_clusters) - * @param[out] new_centroids: output matrix of updated centroids (size n_clusters, n_features) - */ -template -void update_centroids(raft::resources const& handle, - raft::device_matrix_view X, - raft::device_vector_view sample_weights, - raft::device_matrix_view centroids, - LabelsIterator labels, - raft::device_vector_view weight_per_cluster, - raft::device_matrix_view new_centroids) -{ - // TODO: Passing these into the algorithm doesn't really present much of a benefit - // because they are being resized anyways. - // ref https://github.com/rapidsai/raft/issues/930 - rmm::device_uvector workspace(0, raft::resource::get_cuda_stream(handle)); - - cuvs::cluster::kmeans::detail::update_centroids( - handle, X, sample_weights, centroids, labels, weight_per_cluster, new_centroids, workspace); -} - /** * @brief Compute distance for every sample to it's nearest centroid * diff --git a/cpp/src/cluster/kmeans_fit_double.cu b/cpp/src/cluster/kmeans_fit_double.cu index d7e4748e33..51cd21cb51 100644 --- a/cpp/src/cluster/kmeans_fit_double.cu +++ b/cpp/src/cluster/kmeans_fit_double.cu @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ -#include "detail/kmeans_batched.cuh" #include "kmeans.cuh" #include "kmeans_impl.cuh" #include @@ -72,7 +71,7 @@ void fit(raft::resources const& handle, raft::host_scalar_view inertia, raft::host_scalar_view n_iter) { - cuvs::cluster::kmeans::detail::fit( + cuvs::cluster::kmeans::fit( handle, params, X, sample_weight, centroids, inertia, n_iter); } diff --git a/cpp/src/cluster/kmeans_fit_float.cu b/cpp/src/cluster/kmeans_fit_float.cu index f86fabcfbd..000774b9c6 100644 --- a/cpp/src/cluster/kmeans_fit_float.cu +++ b/cpp/src/cluster/kmeans_fit_float.cu @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ -#include "detail/kmeans_batched.cuh" #include "kmeans.cuh" #include "kmeans_impl.cuh" #include @@ -72,7 +71,7 @@ void fit(raft::resources const& handle, raft::host_scalar_view inertia, raft::host_scalar_view n_iter) { - cuvs::cluster::kmeans::detail::fit( + cuvs::cluster::kmeans::fit( handle, params, X, sample_weight, centroids, inertia, n_iter); } diff --git a/cpp/src/cluster/kmeans_impl.cuh b/cpp/src/cluster/kmeans_impl.cuh index 437aa16c76..9be6aac929 100644 --- a/cpp/src/cluster/kmeans_impl.cuh +++ b/cpp/src/cluster/kmeans_impl.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once @@ -18,8 +18,12 @@ void fit_main(raft::resources const& handle, raft::host_scalar_view n_iter, rmm::device_uvector& workspace) { - cuvs::cluster::kmeans::detail::kmeans_fit_main( - handle, params, X, sample_weights, centroids, inertia, n_iter, workspace); + cuvs::cluster::kmeans::params p = params; + p.init = kmeans::params::InitMethod::Array; + auto sw = std::make_optional( + raft::make_device_vector_view(sample_weights.data_handle(), X.extent(0))); + cuvs::cluster::kmeans::detail::kmeans_fit( + handle, p, X, sw, centroids, inertia, n_iter, std::ref(workspace)); } template @@ -31,7 +35,6 @@ void fit(raft::resources const& handle, raft::host_scalar_view inertia, raft::host_scalar_view n_iter) { - // use the mnmg kmeans fit if we have comms initialize, single gpu otherwise if (raft::resource::comms_initialized(handle)) { cuvs::cluster::kmeans::mg::fit(handle, params, X, sample_weight, centroids, inertia, n_iter); } else { @@ -54,4 +57,17 @@ void predict(raft::resources const& handle, handle, params, X, sample_weight, centroids, labels, normalize_weight, inertia); } +template +void fit(raft::resources const& handle, + const kmeans::params& params, + raft::host_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) +{ + cuvs::cluster::kmeans::detail::fit( + handle, params, X, sample_weight, centroids, inertia, n_iter); +} + } // namespace cuvs::cluster::kmeans diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh index 2ff8195b8b..d7eb46c5a3 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh @@ -1329,6 +1329,9 @@ auto build(raft::resources const& handle, rmm::device_uvector labels(n_rows_train, stream, big_memory_resource); auto centers_const_view = raft::make_device_matrix_view( cluster_centers, impl->n_lists(), impl->dim()); + if (impl->metric() == distance::DistanceType::CosineExpanded) { + raft::linalg::row_normalize(handle, centers_const_view, centers_view); + } auto labels_view = raft::make_device_vector_view(labels.data(), n_rows_train); cuvs::cluster::kmeans::predict( diff --git a/cpp/tests/cluster/kmeans.cu b/cpp/tests/cluster/kmeans.cu index 1ef8d07623..08aaa0949a 100644 --- a/cpp/tests/cluster/kmeans.cu +++ b/cpp/tests/cluster/kmeans.cu @@ -6,6 +6,8 @@ #include "../test_utils.cuh" #include +#include +#include #include #include #include @@ -17,8 +19,6 @@ #include -#include - #include #include @@ -361,34 +361,23 @@ struct KmeansBatchedInputs { template class KmeansFitBatchedTest : public ::testing::TestWithParam> { protected: - KmeansFitBatchedTest() - : d_labels(0, raft::resource::get_cuda_stream(handle)), - d_labels_ref(0, raft::resource::get_cuda_stream(handle)), - d_centroids(0, raft::resource::get_cuda_stream(handle)), - d_centroids_ref(0, raft::resource::get_cuda_stream(handle)) - { - } + KmeansFitBatchedTest() = default; - void fitBatchedTest() + void prepareBlobInputs() { - testparams = ::testing::TestWithParam>::GetParam(); - - int n_samples = testparams.n_row; - int n_features = testparams.n_col; - params.n_clusters = testparams.n_clusters; - params.tol = testparams.tol; - params.rng_state.seed = 1; - params.oversampling_factor = 0; - - auto stream = raft::resource::get_cuda_stream(handle); - auto X = raft::make_device_matrix(handle, n_samples, n_features); - auto labels = raft::make_device_vector(handle, n_samples); - - raft::random::make_blobs(X.data_handle(), - labels.data_handle(), + testparams = ::testing::TestWithParam>::GetParam(); + int n_samples = testparams.n_row; + int n_features = testparams.n_col; + int n_clusters = testparams.n_clusters; + auto stream = raft::resource::get_cuda_stream(handle); + + d_X.emplace(raft::make_device_matrix(handle, n_samples, n_features)); + d_labels_ref.emplace(raft::make_device_vector(handle, n_samples)); + raft::random::make_blobs(d_X->data_handle(), + d_labels_ref->data_handle(), n_samples, n_features, - params.n_clusters, + n_clusters, stream, true, nullptr, @@ -399,64 +388,76 @@ class KmeansFitBatchedTest : public ::testing::TestWithParam h_X(n_samples * n_features); - raft::update_host(h_X.data(), X.data_handle(), n_samples * n_features, stream); + h_X.emplace(raft::make_host_matrix(n_samples, n_features)); + raft::update_host( + h_X->data_handle(), d_X->data_handle(), static_cast(n_samples) * n_features, stream); + + if (testparams.weighted) { + d_sample_weight.emplace(raft::make_device_vector(handle, n_samples)); + raft::matrix::fill(handle, d_sample_weight->view(), T(1)); + } else { + d_sample_weight.reset(); + } + raft::resource::sync_stream(handle, stream); + } - d_labels.resize(n_samples, stream); - d_labels_ref.resize(n_samples, stream); - d_centroids.resize(params.n_clusters * n_features, stream); - d_centroids_ref.resize(params.n_clusters * n_features, stream); - raft::copy(d_labels_ref.data(), labels.data_handle(), n_samples, stream); + std::optional> d_sw_view() const + { + if (!d_sample_weight.has_value()) return std::nullopt; + return std::make_optional(raft::make_const_mdspan(d_sample_weight->view())); + } + + void fitBatchedTest() + { + int n_samples = testparams.n_row; + int n_features = testparams.n_col; + params.n_clusters = testparams.n_clusters; + params.tol = testparams.tol; + params.rng_state.seed = 1; + params.oversampling_factor = 0; + + auto stream = raft::resource::get_cuda_stream(handle); + + d_labels.emplace(raft::make_device_vector(handle, n_samples)); + d_centroids.emplace(raft::make_device_matrix(handle, params.n_clusters, n_features)); + d_centroids_ref.emplace( + raft::make_device_matrix(handle, params.n_clusters, n_features)); raft::random::RngState rng(params.rng_state.seed); raft::random::uniform( - handle, rng, d_centroids.data(), params.n_clusters * n_features, T(-1), T(1)); - raft::copy(d_centroids_ref.data(), d_centroids.data(), params.n_clusters * n_features, stream); - - auto h_X_view = - raft::make_host_matrix_view(h_X.data(), n_samples, n_features); - auto d_centroids_view = - raft::make_device_matrix_view(d_centroids.data(), params.n_clusters, n_features); + handle, rng, d_centroids->data_handle(), params.n_clusters * n_features, T(-1), T(1)); + raft::copy(d_centroids_ref->data_handle(), + d_centroids->data_handle(), + params.n_clusters * n_features, + stream); - std::optional> d_sw = std::nullopt; - rmm::device_uvector d_sample_weight(0, stream); - if (testparams.weighted) { - d_sample_weight.resize(n_samples, stream); - d_sw = std::make_optional( - raft::make_device_vector_view(d_sample_weight.data(), n_samples)); - raft::matrix::fill( - handle, raft::make_device_vector_view(d_sample_weight.data(), n_samples), T(1)); - } + auto d_centroids_view = raft::make_device_matrix_view( + d_centroids->data_handle(), params.n_clusters, n_features); - auto d_centroids_ref_view = - raft::make_device_matrix_view(d_centroids_ref.data(), params.n_clusters, n_features); + auto d_sw = d_sw_view(); - params.init = cuvs::cluster::kmeans::params::Array; - params.inertia_check = true; - params.max_iter = 20; + params.init = cuvs::cluster::kmeans::params::Array; + params.max_iter = 20; T ref_inertia = 0; int ref_n_iter = 0; cuvs::cluster::kmeans::fit(handle, params, - raft::make_const_mdspan(X.view()), + raft::make_const_mdspan(d_X->view()), d_sw, - d_centroids_ref_view, + d_centroids_ref->view(), raft::make_host_scalar_view(&ref_inertia), raft::make_host_scalar_view(&ref_n_iter)); cuvs::cluster::kmeans::params batched_params = params; - batched_params.inertia_check = true; batched_params.streaming_batch_size = testparams.streaming_batch_size; std::optional> h_sw = std::nullopt; - std::vector h_sample_weight; + auto h_sample_weight = raft::make_host_vector(testparams.weighted ? n_samples : 0); if (testparams.weighted) { - h_sample_weight.resize(n_samples, T(1)); - h_sw = std::make_optional( - raft::make_host_vector_view(h_sample_weight.data(), n_samples)); + std::fill_n(h_sample_weight.data_handle(), n_samples, T(1)); + h_sw = std::make_optional(raft::make_const_mdspan(h_sample_weight.view())); } T inertia = 0; @@ -464,7 +465,7 @@ class KmeansFitBatchedTest : public ::testing::TestWithParamview()), h_sw, d_centroids_view, raft::make_host_scalar_view(&inertia), @@ -472,48 +473,46 @@ class KmeansFitBatchedTest : public ::testing::TestWithParamdata_handle(), + d_centroids->data_handle(), params.n_clusters, n_features, CompareApprox(T(1e-2)), stream); T ref_pred_inertia = 0; - cuvs::cluster::kmeans::predict( - handle, - params, - raft::make_const_mdspan(X.view()), - d_sw, - raft::make_device_matrix_view( - d_centroids_ref.data(), params.n_clusters, n_features), - raft::make_device_vector_view(d_labels_ref.data(), n_samples), - true, - raft::make_host_scalar_view(&ref_pred_inertia)); + cuvs::cluster::kmeans::predict(handle, + params, + raft::make_const_mdspan(d_X->view()), + d_sw, + raft::make_const_mdspan(d_centroids_ref->view()), + d_labels_ref->view(), + true, + raft::make_host_scalar_view(&ref_pred_inertia)); T pred_inertia = 0; - cuvs::cluster::kmeans::predict( - handle, - params, - raft::make_const_mdspan(X.view()), - d_sw, - raft::make_device_matrix_view( - d_centroids.data(), params.n_clusters, n_features), - raft::make_device_vector_view(d_labels.data(), n_samples), - true, - raft::make_host_scalar_view(&pred_inertia)); + cuvs::cluster::kmeans::predict(handle, + params, + raft::make_const_mdspan(d_X->view()), + d_sw, + raft::make_const_mdspan(d_centroids->view()), + d_labels->view(), + true, + raft::make_host_scalar_view(&pred_inertia)); raft::resource::sync_stream(handle, stream); - score = raft::stats::adjusted_rand_index( - d_labels_ref.data(), d_labels.data(), n_samples, raft::resource::get_cuda_stream(handle)); + score = raft::stats::adjusted_rand_index(d_labels_ref->data_handle(), + d_labels->data_handle(), + n_samples, + raft::resource::get_cuda_stream(handle)); if (score < 0.99) { std::stringstream ss; - ss << "Expected: " << raft::arr2Str(d_labels_ref.data(), 25, "d_labels_ref", stream); + ss << "Expected: " << raft::arr2Str(d_labels_ref->data_handle(), 25, "d_labels_ref", stream); std::cout << (ss.str().c_str()) << '\n'; ss.str(std::string()); - ss << "Actual: " << raft::arr2Str(d_labels.data(), 25, "d_labels", stream); + ss << "Actual: " << raft::arr2Str(d_labels->data_handle(), 25, "d_labels", stream); std::cout << (ss.str().c_str()) << '\n'; std::cout << "Score = " << score << '\n'; } @@ -525,15 +524,158 @@ class KmeansFitBatchedTest : public ::testing::TestWithParam(handle, n_clusters, n_features); + T inertia = 0; + int64_t n_iter = 0; + cuvs::cluster::kmeans::fit( + handle, + p, + raft::make_const_mdspan(h_X->view()), + std::optional>{std::nullopt}, + d_centroids_buf.view(), + raft::make_host_scalar_view(&inertia), + raft::make_host_scalar_view(&n_iter)); + raft::resource::sync_stream(handle, stream); + return inertia; + } + + void runInitSizeCompare() + { + int n_samples = testparams.n_row; + int n_clusters = testparams.n_clusters; + int default_init_size = std::min(3 * n_clusters, n_samples); + + T inertia_default = fitHostWithInitSize(0); + T inertia_explicit = fitHostWithInitSize(default_init_size); + T inertia_full = fitHostWithInitSize(n_samples); + + ASSERT_TRUE(std::isfinite(inertia_default)); + ASSERT_TRUE(std::isfinite(inertia_explicit)); + ASSERT_TRUE(std::isfinite(inertia_full)); + ASSERT_GT(inertia_default, T(0)); + ASSERT_GT(inertia_explicit, T(0)); + ASSERT_GT(inertia_full, T(0)); + + const T rel = T(1e-5); + + // init_size = 0 must resolve to the documented default (min(3*k, n)); + // feeding that value explicitly should reproduce the same inertia. + ASSERT_NEAR(inertia_default, inertia_explicit, std::abs(inertia_default) * rel); + + // Full-dataset seeding has at least as much information as the subsample + // default, so the converged inertia should not be worse. + ASSERT_LE(inertia_full, inertia_default * (T(1) + rel)); + } + + T fitKMeansPlusPlus(int n_init_value) + { + int n_features = testparams.n_col; + int n_clusters = testparams.n_clusters; + auto stream = raft::resource::get_cuda_stream(handle); + + cuvs::cluster::kmeans::params p; + p.n_clusters = n_clusters; + p.tol = testparams.tol; + p.n_init = n_init_value; + p.init = cuvs::cluster::kmeans::params::KMeansPlusPlus; + p.max_iter = 10; + p.rng_state.seed = 7; + p.oversampling_factor = 0; + + auto d_centroids_buf = raft::make_device_matrix(handle, n_clusters, n_features); + T inertia = 0; + int n_iter = 0; + cuvs::cluster::kmeans::fit(handle, + p, + raft::make_const_mdspan(d_X->view()), + std::optional>{std::nullopt}, + d_centroids_buf.view(), + raft::make_host_scalar_view(&inertia), + raft::make_host_scalar_view(&n_iter)); + raft::resource::sync_stream(handle, stream); + return inertia; + } + + void runMultiSeedCheck() + { + T inertia1 = fitKMeansPlusPlus(1); + T inertia3 = fitKMeansPlusPlus(3); + ASSERT_GT(inertia1, T(0)); + ASSERT_GT(inertia3, T(0)); + // n_init > 1 keeps the best trial, so it should not be worse than + // n_init == 1. + const T rel = T(1e-5); + ASSERT_LE(inertia3, inertia1 * (T(1) + rel)); + } + + void runZeroCost() + { + int n_samples = testparams.n_row; + int n_features = testparams.n_col; + int n_clusters = n_samples; + auto stream = raft::resource::get_cuda_stream(handle); + + auto d_centroids_buf = raft::make_device_matrix(handle, n_clusters, n_features); + raft::copy(d_centroids_buf.data_handle(), + d_X->data_handle(), + static_cast(n_samples) * n_features, + stream); + + cuvs::cluster::kmeans::params p; + p.n_clusters = n_clusters; + p.tol = testparams.tol; + p.n_init = 1; + p.init = cuvs::cluster::kmeans::params::Array; + p.max_iter = 5; + p.rng_state.seed = 1; + p.oversampling_factor = 0; + + T inertia = 0; + int n_iter = 0; + ASSERT_NO_THROW(cuvs::cluster::kmeans::fit( + handle, + p, + raft::make_const_mdspan(d_X->view()), + std::optional>{std::nullopt}, + d_centroids_buf.view(), + raft::make_host_scalar_view(&inertia), + raft::make_host_scalar_view(&n_iter))); + raft::resource::sync_stream(handle, stream); + + // The expanded L2 formula ||x||^2 - 2 x.c + ||c||^2 does not cancel to + // exactly 0 even when x == c due to float roundoff. The largest residual + // inertia observed across parameterized blob shapes is ~27.95; use an + // absolute upper bound of 100 for headroom. + ASSERT_LE(inertia, T(100)); + } protected: raft::resources handle; KmeansBatchedInputs testparams; - rmm::device_uvector d_labels; - rmm::device_uvector d_labels_ref; - rmm::device_uvector d_centroids; - rmm::device_uvector d_centroids_ref; + std::optional> d_labels; + std::optional> d_labels_ref; + std::optional> d_centroids; + std::optional> d_centroids_ref; + std::optional> d_X; + std::optional> d_sample_weight; + std::optional> h_X; double score; testing::AssertionResult centroids_match = testing::AssertionSuccess(); bool inertia_match = false; @@ -570,16 +712,26 @@ typedef KmeansFitBatchedTest KmeansFitBatchedTestD; TEST_P(KmeansFitBatchedTestF, Result) { + prepareBlobInputs(); + fitBatchedTest(); ASSERT_TRUE(centroids_match); ASSERT_TRUE(score >= 0.99); ASSERT_TRUE(inertia_match); + runInitSizeCompare(); + runMultiSeedCheck(); + runZeroCost(); } TEST_P(KmeansFitBatchedTestD, Result) { + prepareBlobInputs(); + fitBatchedTest(); ASSERT_TRUE(centroids_match); ASSERT_TRUE(score >= 0.99); ASSERT_TRUE(inertia_match); + runInitSizeCompare(); + runMultiSeedCheck(); + runZeroCost(); } INSTANTIATE_TEST_CASE_P(KmeansFitBatchedTests, @@ -588,4 +740,5 @@ INSTANTIATE_TEST_CASE_P(KmeansFitBatchedTests, INSTANTIATE_TEST_CASE_P(KmeansFitBatchedTests, KmeansFitBatchedTestD, ::testing::ValuesIn(batched_inputsd2)); + } // namespace cuvs diff --git a/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd b/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd index 6d0c878660..975ef386df 100644 --- a/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd +++ b/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd @@ -22,6 +22,11 @@ cdef extern from "cuvs/cluster/kmeans.h" nogil: CUVS_KMEANS_TYPE_KMEANS CUVS_KMEANS_TYPE_KMEANS_BALANCED + # NOTE: The Python binding currently targets the unsuffixed cuvsKMeansParams + # ABI (which still carries the deprecated `inertia_check` field). In cuVS + # 26.08 this struct/entry-point set will be replaced by the contents of + # cuvsKMeansParams_v2 -- once that lands, the `inertia_check` field below + # should be deleted. ctypedef struct cuvsKMeansParams: cuvsDistanceType metric, int n_clusters, @@ -33,9 +38,10 @@ cdef extern from "cuvs/cluster/kmeans.h" nogil: int batch_samples, int batch_centroids, bool inertia_check, - int64_t streaming_batch_size, bool hierarchical, - int hierarchical_n_iters + int hierarchical_n_iters, + int64_t streaming_batch_size, + int64_t init_size ctypedef cuvsKMeansParams* cuvsKMeansParams_t diff --git a/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx b/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx index b267c908c9..2e9046b4b2 100644 --- a/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx +++ b/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx @@ -4,6 +4,7 @@ # # cython: language_level=3 +import warnings from collections import namedtuple import numpy as np @@ -77,7 +78,13 @@ cdef class KMeansParams: batch_centroids : int Number of centroids to process in each batch. If 0, uses n_clusters. inertia_check : bool - If True, check inertia during iterations for early convergence. + Deprecated and ignored. Will be + removed in a future release. Inertia-based convergence checking + always runs. + init_size : int + Number of samples to draw for KMeansPlusPlus initialization with + host (out-of-core) data. When set to 0, uses the heuristic + min(3 * n_clusters, n_samples). Default: 0. streaming_batch_size : int Number of samples to process per GPU batch when fitting with host (numpy) data. When set to 0, defaults to n_samples (process all @@ -112,6 +119,7 @@ cdef class KMeansParams: batch_samples=None, batch_centroids=None, inertia_check=None, + init_size=None, streaming_batch_size=None, hierarchical=None, hierarchical_n_iters=None): @@ -135,7 +143,13 @@ cdef class KMeansParams: if batch_centroids is not None: self.params.batch_centroids = batch_centroids if inertia_check is not None: - self.params.inertia_check = inertia_check + warnings.warn( + "KMeansParams `inertia_check` is deprecated and ignored; " + "inertia-based convergence checking always runs.", + FutureWarning + ) + if init_size is not None: + self.params.init_size = init_size if streaming_batch_size is not None: self.params.streaming_batch_size = streaming_batch_size if hierarchical is not None: @@ -183,8 +197,8 @@ cdef class KMeansParams: return self.params.batch_centroids @property - def inertia_check(self): - return self.params.inertia_check + def init_size(self): + return self.params.init_size @property def streaming_batch_size(self):