From b9e49a89649953cd460d1675a2519cddb151a681 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?George=20Guimar=C3=A3es?= Date: Tue, 14 Apr 2026 18:55:02 -0300 Subject: [PATCH 1/3] fix: Reduce K-means memory from O(n*k*d) to O(n*k) and fix empty cluster NaN MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two changes: 1. calculate_inertia: use ||x-c||^2 = ||x||^2 + ||c||^2 - 2·x·cᵀ to compute distances via matrix multiply instead of broadcasting a {runs, k*n, d} tensor. Peak memory drops from O(n*k*d) to O(n*k). Measured RSS with EXLA (1536-dim embeddings, k=20): n=100: 1604MB → 146MB (11x) n=500: 7768MB → 755MB (10x) n=1000: OOM → 707MB 2. Centroid update: when a cluster has zero members, keep the previous centroid instead of computing 0/0 = NaN. This is a simple, common fix; scikit-learn goes further and relocates empty clusters to the samples farthest from their centers. Without it, NaN propagates through all subsequent iterations, causing K-means to exit after 1 iteration. --- lib/scholar/cluster/k_means.ex | 41 +++++++++++++++++++++------------- 1 file changed, 25 insertions(+), 16 deletions(-) diff --git a/lib/scholar/cluster/k_means.ex b/lib/scholar/cluster/k_means.ex index e826a709..a47c9a62 100644 --- a/lib/scholar/cluster/k_means.ex +++ b/lib/scholar/cluster/k_means.ex @@ -181,10 +181,19 @@ defmodule Scholar.Cluster.KMeans do broadcast_weights group_sizes = Nx.sum(group_masks, axes: [2], keep_axes: true) + empty_clusters = group_sizes == 0 - centroids = + new_centroids = ((Nx.new_axis(group_masks, -1) * Nx.new_axis(broadcast_x, 1)) |> Nx.sum(axes: [2])) / - group_sizes + Nx.max(group_sizes, 1) + + # Keep previous centroid for empty clusters instead of NaN from 0/0 + centroids = + Nx.select( + Nx.broadcast(empty_clusters, Nx.shape(new_centroids)), + previous_iteration_centroids, + new_centroids + ) distance = Scholar.Metrics.Distance.squared_euclidean(centroids, previous_iteration_centroids, @@ -228,22 +237,22 @@ defmodule Scholar.Cluster.KMeans do end end - defnp calculate_inertia(x, centroids, 
num_clusters, num_runs) do - {num_samples, num_features} = Nx.shape(x) - - modified_centroids = - centroids - |> Nx.new_axis(2) - |> Nx.broadcast({num_runs, num_clusters, num_samples, num_features}) - |> Nx.reshape({num_runs, num_clusters * num_samples, num_features}) + defnp calculate_inertia(x, centroids, _num_clusters, _num_runs) do + # Use the identity ||x - c||^2 = ||x||^2 + ||c||^2 - 2·x·cᵀ + # to compute distances via matrix multiply instead of broadcasting. + # Peak memory is O(runs*k*n) instead of O(runs*k*n*d). + # + # This expansion has slightly more floating-point cancellation than + # direct subtraction for nearby points, but the difference is negligible + # for argmin-based cluster assignment. + x_sq = Nx.sum(x * x, axes: [1]) + c_sq = Nx.sum(centroids * centroids, axes: [2]) + dot = Nx.dot(centroids, [2], x, [1]) inertia_for_centroids = - Scholar.Metrics.Distance.squared_euclidean( - Nx.tile(x, [num_runs, num_clusters, 1]), - modified_centroids, - axes: [2] - ) - |> Nx.reshape({num_runs, num_clusters, num_samples}) + Nx.new_axis(Nx.new_axis(x_sq, 0), 0) + + Nx.new_axis(c_sq, 2) - + 2 * dot {inertia_for_centroids, Nx.reduce_min(inertia_for_centroids, axes: [1])} end From 42c0a94a29e790a7586b5d580878f3c364cc0bba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?George=20Guimar=C3=A3es?= Date: Mon, 27 Apr 2026 11:15:36 -0300 Subject: [PATCH 2/3] test: update k-means expected cluster ordering after empty-cluster fix The empty-cluster fix changes which run wins the lowest-inertia tie-break across num_runs initializations, swapping the label permutation. Pin tests and doctests to the new ordering. 
--- lib/scholar/cluster/k_means.ex | 8 ++++---- test/scholar/cluster/k_means_test.exs | 18 +++++++++--------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/lib/scholar/cluster/k_means.ex b/lib/scholar/cluster/k_means.ex index a47c9a62..0441b4b9 100644 --- a/lib/scholar/cluster/k_means.ex +++ b/lib/scholar/cluster/k_means.ex @@ -116,15 +116,15 @@ defmodule Scholar.Cluster.KMeans do ), clusters: Nx.tensor( [ - [1.0, 2.5], - [2.0, 4.5] + [2.0, 4.5], + [1.0, 2.5] ] ), inertia: Nx.tensor( 1.0 ), labels: Nx.tensor( - [0, 1, 0, 1] + [1, 0, 1, 0] ) } """ @@ -307,7 +307,7 @@ defmodule Scholar.Cluster.KMeans do iex> model = Scholar.Cluster.KMeans.fit(x, num_clusters: 2, key: key) iex> Scholar.Cluster.KMeans.predict(model, Nx.tensor([[1.9, 4.3], [1.1, 2.0]])) Nx.tensor( - [1, 0] + [0, 1] ) """ defn predict(%__MODULE__{clusters: clusters} = _model, x) do diff --git a/test/scholar/cluster/k_means_test.exs b/test/scholar/cluster/k_means_test.exs index bff2b1af..556aa8a8 100644 --- a/test/scholar/cluster/k_means_test.exs +++ b/test/scholar/cluster/k_means_test.exs @@ -15,13 +15,13 @@ defmodule Scholar.Cluster.KMeansTest do key: key() ) - assert model.clusters == Nx.tensor([[1.0, 2.5], [2.0, 4.5]]) + assert model.clusters == Nx.tensor([[2.0, 4.5], [1.0, 2.5]]) assert model.inertia == Nx.tensor(1.0, type: {:f, 32}) - assert model.labels == Nx.tensor([0, 1, 0, 1]) + assert model.labels == Nx.tensor([1, 0, 1, 0]) assert model.num_iterations == Nx.tensor(2) predictions = KMeans.predict(model, Nx.tensor([[1.9, 4.3], [1.1, 2.0]])) - assert predictions == Nx.tensor([1, 0]) + assert predictions == Nx.tensor([0, 1]) end test "fit and predict without weights and :init set as :random" do @@ -49,13 +49,13 @@ defmodule Scholar.Cluster.KMeansTest do weights: [1, 2, 3, 4] ) - assert model.clusters == Nx.tensor([[1.0, 2.75], [2.0, 4.75]]) + assert model.clusters == Nx.tensor([[2.0, 4.75], [1.0, 2.75]]) assert model.inertia == Nx.tensor(1.5, type: {:f, 32}) - assert model.labels 
== Nx.tensor([0, 1, 0, 1]) + assert model.labels == Nx.tensor([1, 0, 1, 0]) assert model.num_iterations == Nx.tensor(2) predictions = KMeans.predict(model, Nx.tensor([[1.9, 4.3], [1.1, 2.0]])) - assert predictions == Nx.tensor([1, 0]) + assert predictions == Nx.tensor([0, 1]) end test "fit and predict with weights as a tensor" do @@ -66,13 +66,13 @@ defmodule Scholar.Cluster.KMeansTest do weights: Nx.tensor([1, 2, 3, 4], type: {:f, 32}) ) - assert model.clusters == Nx.tensor([[1.0, 2.75], [2.0, 4.75]]) + assert model.clusters == Nx.tensor([[2.0, 4.75], [1.0, 2.75]]) assert model.inertia == Nx.tensor(1.5, type: {:f, 32}) - assert model.labels == Nx.tensor([0, 1, 0, 1]) + assert model.labels == Nx.tensor([1, 0, 1, 0]) assert model.num_iterations == Nx.tensor(2) predictions = KMeans.predict(model, Nx.tensor([[1.9, 4.3], [1.1, 2.0]])) - assert predictions == Nx.tensor([1, 0]) + assert predictions == Nx.tensor([0, 1]) end test "transform" do From 44408ab81d00f8bdd1f0d4db29f35ffb0064b4fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?George=20Guimar=C3=A3es?= Date: Mon, 27 Apr 2026 11:23:17 -0300 Subject: [PATCH 3/3] fix: restore inf for NaN distances in k-means++ initialization MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The new ||x||² + ||c||² - 2·x·cᵀ expansion produces inf - inf = NaN when k-means++ pads unused centroid slots with infinity, breaking the weighted sampling that picks the next initial centroid. The original direct ||x - c||² formula returned inf naturally. Mapping NaN back to inf restores k-means++ behavior. With this fix, clustering matches main exactly, so revert the test pinning from the previous commit. 
--- lib/scholar/cluster/k_means.ex | 21 +++++++++++++-------- test/scholar/cluster/k_means_test.exs | 18 +++++++++--------- 2 files changed, 22 insertions(+), 17 deletions(-) diff --git a/lib/scholar/cluster/k_means.ex b/lib/scholar/cluster/k_means.ex index 0441b4b9..91077495 100644 --- a/lib/scholar/cluster/k_means.ex +++ b/lib/scholar/cluster/k_means.ex @@ -116,15 +116,15 @@ defmodule Scholar.Cluster.KMeans do ), clusters: Nx.tensor( [ - [2.0, 4.5], - [1.0, 2.5] + [1.0, 2.5], + [2.0, 4.5] ] ), inertia: Nx.tensor( 1.0 ), labels: Nx.tensor( - [1, 0, 1, 0] + [0, 1, 0, 1] ) } """ @@ -241,10 +241,6 @@ defmodule Scholar.Cluster.KMeans do # Use the identity ||x - c||^2 = ||x||^2 + ||c||^2 - 2·x·cᵀ # to compute distances via matrix multiply instead of broadcasting. # Peak memory is O(runs*k*n) instead of O(runs*k*n*d). - # - # This expansion has slightly more floating-point cancellation than - # direct subtraction for nearby points, but the difference is negligible - # for argmin-based cluster assignment. x_sq = Nx.sum(x * x, axes: [1]) c_sq = Nx.sum(centroids * centroids, axes: [2]) dot = Nx.dot(centroids, [2], x, [1]) @@ -254,6 +250,15 @@ defmodule Scholar.Cluster.KMeans do Nx.new_axis(c_sq, 2) - 2 * dot + # k-means++ pads unused centroid slots with infinity. The expansion + # produces inf - inf = NaN there; restore inf so weighted sampling works. 
+ inertia_for_centroids = + Nx.select( + Nx.is_nan(inertia_for_centroids), + Nx.Constants.infinity(Nx.type(inertia_for_centroids)), + inertia_for_centroids + ) + {inertia_for_centroids, Nx.reduce_min(inertia_for_centroids, axes: [1])} end @@ -307,7 +312,7 @@ defmodule Scholar.Cluster.KMeans do iex> model = Scholar.Cluster.KMeans.fit(x, num_clusters: 2, key: key) iex> Scholar.Cluster.KMeans.predict(model, Nx.tensor([[1.9, 4.3], [1.1, 2.0]])) Nx.tensor( - [0, 1] + [1, 0] ) """ defn predict(%__MODULE__{clusters: clusters} = _model, x) do diff --git a/test/scholar/cluster/k_means_test.exs b/test/scholar/cluster/k_means_test.exs index 556aa8a8..bff2b1af 100644 --- a/test/scholar/cluster/k_means_test.exs +++ b/test/scholar/cluster/k_means_test.exs @@ -15,13 +15,13 @@ defmodule Scholar.Cluster.KMeansTest do key: key() ) - assert model.clusters == Nx.tensor([[2.0, 4.5], [1.0, 2.5]]) + assert model.clusters == Nx.tensor([[1.0, 2.5], [2.0, 4.5]]) assert model.inertia == Nx.tensor(1.0, type: {:f, 32}) - assert model.labels == Nx.tensor([1, 0, 1, 0]) + assert model.labels == Nx.tensor([0, 1, 0, 1]) assert model.num_iterations == Nx.tensor(2) predictions = KMeans.predict(model, Nx.tensor([[1.9, 4.3], [1.1, 2.0]])) - assert predictions == Nx.tensor([0, 1]) + assert predictions == Nx.tensor([1, 0]) end test "fit and predict without weights and :init set as :random" do @@ -49,13 +49,13 @@ defmodule Scholar.Cluster.KMeansTest do weights: [1, 2, 3, 4] ) - assert model.clusters == Nx.tensor([[2.0, 4.75], [1.0, 2.75]]) + assert model.clusters == Nx.tensor([[1.0, 2.75], [2.0, 4.75]]) assert model.inertia == Nx.tensor(1.5, type: {:f, 32}) - assert model.labels == Nx.tensor([1, 0, 1, 0]) + assert model.labels == Nx.tensor([0, 1, 0, 1]) assert model.num_iterations == Nx.tensor(2) predictions = KMeans.predict(model, Nx.tensor([[1.9, 4.3], [1.1, 2.0]])) - assert predictions == Nx.tensor([0, 1]) + assert predictions == Nx.tensor([1, 0]) end test "fit and predict with weights as a tensor" do 
@@ -66,13 +66,13 @@ defmodule Scholar.Cluster.KMeansTest do weights: Nx.tensor([1, 2, 3, 4], type: {:f, 32}) ) - assert model.clusters == Nx.tensor([[2.0, 4.75], [1.0, 2.75]]) + assert model.clusters == Nx.tensor([[1.0, 2.75], [2.0, 4.75]]) assert model.inertia == Nx.tensor(1.5, type: {:f, 32}) - assert model.labels == Nx.tensor([1, 0, 1, 0]) + assert model.labels == Nx.tensor([0, 1, 0, 1]) assert model.num_iterations == Nx.tensor(2) predictions = KMeans.predict(model, Nx.tensor([[1.9, 4.3], [1.1, 2.0]])) - assert predictions == Nx.tensor([0, 1]) + assert predictions == Nx.tensor([1, 0]) end test "transform" do