diff --git a/chelombus/clustering/PyQKmeans.py b/chelombus/clustering/PyQKmeans.py index e08091d..b8de301 100644 --- a/chelombus/clustering/PyQKmeans.py +++ b/chelombus/clustering/PyQKmeans.py @@ -314,12 +314,16 @@ def _fit_gpu(self, X_train: np.ndarray, return_labels: bool = False) -> np.ndarr verbose=self.verbose, ) - def predict(self, X: np.ndarray, device: str = 'auto') -> np.ndarray: + def predict(self, X: np.ndarray, device: str = 'auto', + batch_size: int = 0) -> np.ndarray: """Predict cluster labels for PQ codes. Args: X: PQ codes of shape (n_samples, n_subvectors), dtype uint8 device: 'cpu' for Numba, 'gpu' for Triton/CUDA, 'auto' to pick GPU if available. + batch_size: GPU-only. Max points per GPU batch. 0 (default) = + auto-detect from free VRAM. Set a manual cap to bound peak + VRAM on large N (e.g. N > 1B on 16 GB cards). Returns: Cluster labels of shape (n_samples,) @@ -335,7 +339,7 @@ def predict(self, X: np.ndarray, device: str = 'auto') -> np.ndarray: if use_gpu: codes = np.asarray(X, dtype=np.uint8) centers = np.asarray(self._centers_u8, dtype=np.uint8) - return predict_gpu(codes, centers, self._dtables) + return predict_gpu(codes, centers, self._dtables, batch_size=batch_size) codes = np.asarray(X, dtype=self.encoder.codebook_dtype) return _predict_numba(codes, self._centers_u8, self._dtables) diff --git a/chelombus/encoder/encoder.py b/chelombus/encoder/encoder.py index 8a9e077..f331121 100644 --- a/chelombus/encoder/encoder.py +++ b/chelombus/encoder/encoder.py @@ -164,7 +164,7 @@ def _fit_gpu(self, X_train: NDArray, verbose: int = 1) -> None: for subvector_idx in iterable: sub_slice = X_f32[:, subvector_dim * subvector_idx : subvector_dim * (subvector_idx + 1)] X_gpu = torch.from_numpy(sub_slice).cuda() - # Precompute ||x||² (stays constant across iterations) + # Precompute ||x||^2 (stays constant across iterations) x_sq = (X_gpu * X_gpu).sum(dim=1) # (N,) B = self._gpu_encoder_batch_size(N, self.k)