diff --git a/lib/scholar/cluster/k_means.ex b/lib/scholar/cluster/k_means.ex index e826a709..91077495 100644 --- a/lib/scholar/cluster/k_means.ex +++ b/lib/scholar/cluster/k_means.ex @@ -181,10 +181,19 @@ defmodule Scholar.Cluster.KMeans do broadcast_weights group_sizes = Nx.sum(group_masks, axes: [2], keep_axes: true) + empty_clusters = group_sizes == 0 - centroids = + new_centroids = ((Nx.new_axis(group_masks, -1) * Nx.new_axis(broadcast_x, 1)) |> Nx.sum(axes: [2])) / - group_sizes + Nx.max(group_sizes, 1) + + # Keep previous centroid for empty clusters instead of NaN from 0/0 + centroids = + Nx.select( + Nx.broadcast(empty_clusters, Nx.shape(new_centroids)), + previous_iteration_centroids, + new_centroids + ) distance = Scholar.Metrics.Distance.squared_euclidean(centroids, previous_iteration_centroids, @@ -228,22 +237,27 @@ defmodule Scholar.Cluster.KMeans do end end - defnp calculate_inertia(x, centroids, num_clusters, num_runs) do - {num_samples, num_features} = Nx.shape(x) + defnp calculate_inertia(x, centroids, _num_clusters, _num_runs) do + # Use the identity ||x - c||^2 = ||x||^2 + ||c||^2 - 2·x·cᵀ + # to compute distances via matrix multiply instead of broadcasting. + # Peak memory is O(runs*k*n) instead of O(runs*k*n*d). + x_sq = Nx.sum(x * x, axes: [1]) + c_sq = Nx.sum(centroids * centroids, axes: [2]) + dot = Nx.dot(centroids, [2], x, [1]) - modified_centroids = - centroids - |> Nx.new_axis(2) - |> Nx.broadcast({num_runs, num_clusters, num_samples, num_features}) - |> Nx.reshape({num_runs, num_clusters * num_samples, num_features}) + inertia_for_centroids = + Nx.new_axis(Nx.new_axis(x_sq, 0), 0) + + Nx.new_axis(c_sq, 2) - + 2 * dot + # k-means++ pads unused centroid slots with infinity. The expansion + # produces inf - inf = NaN there; restore inf so weighted sampling works. inertia_for_centroids = - Scholar.Metrics.Distance.squared_euclidean( - Nx.tile(x, [num_runs, num_clusters, 1]), - modified_centroids, - axes: [2] + Nx.select( + Nx.is_nan(inertia_for_centroids), + Nx.Constants.infinity(Nx.type(inertia_for_centroids)), + inertia_for_centroids ) - |> Nx.reshape({num_runs, num_clusters, num_samples}) {inertia_for_centroids, Nx.reduce_min(inertia_for_centroids, axes: [1])} end