diff --git a/CMakeLists.txt b/CMakeLists.txt index cacf9849..e176455d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -71,6 +71,8 @@ if (RSC_BUILD_EXTENSIONS) add_nb_cuda_module(_qc_dask_cuda src/rapids_singlecell/_cuda/qc_dask/qc_kernels_dask.cu) add_nb_cuda_module(_bbknn_cuda src/rapids_singlecell/_cuda/bbknn/bbknn.cu) add_nb_cuda_module(_norm_cuda src/rapids_singlecell/_cuda/norm/norm.cu) + add_nb_cuda_module(_gmm_cuda src/rapids_singlecell/_cuda/gmm/gmm.cu) + target_link_libraries(_gmm_cuda PRIVATE CUDA::cublas) add_nb_cuda_module(_pr_cuda src/rapids_singlecell/_cuda/pr/pr.cu) add_nb_cuda_module(_nn_descent_cuda src/rapids_singlecell/_cuda/nn_descent/nn_descent.cu) add_nb_cuda_module(_aucell_cuda src/rapids_singlecell/_cuda/aucell/aucell.cu) diff --git a/docs/api/squidpy_gpu.md b/docs/api/squidpy_gpu.md index 99da5825..bb271054 100644 --- a/docs/api/squidpy_gpu.md +++ b/docs/api/squidpy_gpu.md @@ -13,4 +13,5 @@ gr.spatial_autocorr gr.co_occurrence gr.ligrec + gr.calculate_niche ``` diff --git a/docs/release-notes/0.15.1.md b/docs/release-notes/0.15.1.md index d5b484a2..9b9b9e97 100644 --- a/docs/release-notes/0.15.1.md +++ b/docs/release-notes/0.15.1.md @@ -1,5 +1,9 @@ ### 0.15.1 {small}`the-future` +```{rubric} Features +``` +* Add `rsc.gr.calculate_niche` with flavors `neighborhood`, `utag`, and `cellcharter`. 
Mirrors `squidpy.gr.calculate_niche` {pr}`644` {smaller}`S Dicks` +* Add a minimal full-covariance GMM (`squidpy_gpu._gmm.gmm_fit_predict`) used by the `cellcharter` {pr}`644` {smaller}`S Dicks` ```{rubric} Bug fixes ``` * Fixes `tl.rank_genes_groups` returning NaN/zero `logfoldchanges`/`pvals` with `groups=[subset]` and `reference='rest'` {pr}`651` {smaller}`S Dicks` diff --git a/src/rapids_singlecell/_cuda/__init__.py b/src/rapids_singlecell/_cuda/__init__.py index 35e82a0d..1ca2d22a 100644 --- a/src/rapids_singlecell/_cuda/__init__.py +++ b/src/rapids_singlecell/_cuda/__init__.py @@ -20,6 +20,7 @@ "_bbknn_cuda", "_cooc_cuda", "_edistance_cuda", + "_gmm_cuda", "_harmony_clustering_cuda", "_harmony_colsum_cuda", "_harmony_correction_batched_cuda", diff --git a/src/rapids_singlecell/_cuda/gmm/gmm.cu b/src/rapids_singlecell/_cuda/gmm/gmm.cu new file mode 100644 index 00000000..b6254036 --- /dev/null +++ b/src/rapids_singlecell/_cuda/gmm/gmm.cu @@ -0,0 +1,366 @@ +#include +#include + +#include + +#include +#include +#include + +#include "../cublas_helpers.cuh" +#include "../nb_types.h" + +#include "kernels_gmm.cuh" + +using namespace nb::literals; + +constexpr int E_STEP_BLOCK = 64; +constexpr int E_STEP_LARGE64_TILE = 64; +constexpr int E_STEP_THREAD64_BLOCK = 512; +constexpr int NORMALIZE_BLOCK = 32; +constexpr size_t DEFAULT_DYNAMIC_SMEM_LIMIT = 48 * 1024; + +static inline size_t upper_tri_size(size_t d) { + return (d * (d + 1)) / 2; +} + +static inline void cuda_check_runtime(cudaError_t err, const char* what) { + if (err != cudaSuccess) { + throw std::runtime_error(std::string(what) + + " failed: " + cudaGetErrorString(err)); + } +} + +static inline void cublas_check_status(cublasStatus_t status, + const char* what) { + if (status != CUBLAS_STATUS_SUCCESS) { + throw std::runtime_error(std::string(what) + + " failed with cuBLAS status " + + std::to_string(static_cast(status))); + } +} + +template +static inline void launch_e_step_log_prob_fixed_d_impl( + const T* 
X, const T* weights, const T* means, const T* prec_chol, + const T* log_det_half, int n, int K, T* log_prob, dim3 grid, dim3 block, + cudaStream_t stream) { + size_t shmem = (D + upper_tri_size(D)) * sizeof(T); + if (shmem > DEFAULT_DYNAMIC_SMEM_LIMIT) { + cuda_check_runtime( + cudaFuncSetAttribute(e_step_log_prob_small_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + (int)shmem), + "cudaFuncSetAttribute(e_step_log_prob_small_kernel)"); + } + e_step_log_prob_small_kernel<<>>( + X, weights, means, prec_chol, log_det_half, n, D, K, log_prob); + CUDA_CHECK_LAST_ERROR(e_step_log_prob_small_kernel); +} + +template +static inline void launch_e_step(const T* X, const T* weights, const T* means, + const T* prec_chol, const T* log_det_half, + int n, int d, int K, T* log_prob, T* resp, + T* ll_per_cell, cudaStream_t stream) { + if (n == 0 || d == 0 || K == 0) return; + if (d <= 64) { + dim3 block(E_STEP_BLOCK); + dim3 grid((n + E_STEP_BLOCK - 1) / E_STEP_BLOCK, K); + if (d == 16) { + launch_e_step_log_prob_fixed_d_impl( + X, weights, means, prec_chol, log_det_half, n, K, log_prob, + grid, block, stream); + } else if (d == 32) { + launch_e_step_log_prob_fixed_d_impl( + X, weights, means, prec_chol, log_det_half, n, K, log_prob, + grid, block, stream); + } else if (d == 50) { + launch_e_step_log_prob_fixed_d_impl( + X, weights, means, prec_chol, log_det_half, n, K, log_prob, + grid, block, stream); + } else if (d == 64) { + launch_e_step_log_prob_fixed_d_impl( + X, weights, means, prec_chol, log_det_half, n, K, log_prob, + grid, block, stream); + } else { + size_t shmem = ((size_t)d + upper_tri_size(d)) * sizeof(T); + e_step_log_prob_small_kernel<<>>( + X, weights, means, prec_chol, log_det_half, n, d, K, log_prob); + CUDA_CHECK_LAST_ERROR(e_step_log_prob_small_kernel); + } + } else { + dim3 block(E_STEP_THREAD64_BLOCK); + dim3 grid((n + E_STEP_THREAD64_BLOCK - 1) / E_STEP_THREAD64_BLOCK, K); + size_t shmem = ((size_t)E_STEP_LARGE64_TILE + + (size_t)E_STEP_LARGE64_TILE 
* E_STEP_LARGE64_TILE) * + sizeof(T); + e_step_log_prob_large_d_thread64_kernel + <<>>(X, weights, means, prec_chol, + log_det_half, n, d, K, log_prob); + CUDA_CHECK_LAST_ERROR(e_step_log_prob_large_d_thread64_kernel); + } + { + dim3 block(NORMALIZE_BLOCK); + dim3 grid(n); + e_step_normalize_kernel + <<>>(log_prob, n, K, resp, ll_per_cell); + CUDA_CHECK_LAST_ERROR(e_step_normalize_kernel); + } +} + +template +static inline void launch_e_step_cublas(const T* X, const T* weights, + const T* means, const T* prec_chol, + const T* log_det_half, int n, int d, + int K, T* centered_workspace, + T* y_workspace, T* log_prob, T* resp, + T* ll_per_cell, cudaStream_t stream, + cublasHandle_t handle) { + if (n == 0 || d == 0 || K == 0) return; + + bool own_handle = handle == nullptr; + if (own_handle) cublas_check_status(cublasCreate(&handle), "cublasCreate"); + cublas_check_status(cublasSetStream(handle, stream), "cublasSetStream"); + + T one = T(1); + T zero = T(0); + int threads = 256; + int center_blocks = (int)(((size_t)n * d + threads - 1) / threads); + int row_blocks = (n + threads - 1) / threads; + + for (int k = 0; k < K; ++k) { + e_step_center_kernel<<>>( + X, means, n, d, k, centered_workspace); + CUDA_CHECK_LAST_ERROR(e_step_center_kernel); + + const T* pc_k = prec_chol + (size_t)k * d * d; + cublas_check_status( + cublas_gemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, d, n, d, &one, + pc_k, d, centered_workspace, d, &zero, y_workspace, + d), + "cublas_gemm(e_step)"); + + e_step_log_prob_from_y_kernel<<>>( + y_workspace, weights, log_det_half, n, d, K, k, log_prob); + CUDA_CHECK_LAST_ERROR(e_step_log_prob_from_y_kernel); + } + + { + dim3 block(NORMALIZE_BLOCK); + dim3 grid(n); + e_step_normalize_kernel + <<>>(log_prob, n, K, resp, ll_per_cell); + CUDA_CHECK_LAST_ERROR(e_step_normalize_kernel); + } + + if (own_handle) cublas_check_status(cublasDestroy(handle), "cublasDestroy"); +} + +template +static inline void launch_m_step(const T* resp, const T* X, const T* ones, + int n, 
int d, int K, T reg_covar, T* weights, + T* means, T* covariances, T* workspace_N_k, + T* workspace_num, T* workspace_centered, + cudaStream_t stream, cublasHandle_t handle) { + if (n == 0 || d == 0 || K == 0) return; + + bool own_handle = handle == nullptr; + if (own_handle) cublas_check_status(cublasCreate(&handle), "cublasCreate"); + cublas_check_status(cublasSetStream(handle, stream), "cublasSetStream"); + + T one = T(1); + T zero = T(0); + T eps = std::numeric_limits::epsilon(); + + // Row-major resp(n,K) is cuBLAS column-major (K,n). N_k = resp.T @ 1. + cublas_check_status(cublas_gemv(handle, CUBLAS_OP_N, K, n, &one, resp, K, + ones, 1, &zero, workspace_N_k, 1), + "cublas_gemv(N_k)"); + + // Row-major X(n,d) is cuBLAS column-major (d,n). Fill row-major + // workspace_num(K,d) through its column-major (d,K) view with X.T @ resp. + cublas_check_status( + cublas_gemm(handle, CUBLAS_OP_N, CUBLAS_OP_T, d, K, n, &one, X, d, + resp, K, &zero, workspace_num, d), + "cublas_gemm(num)"); + + { + int threads = 256; + dim3 block(threads); + dim3 grid(K); + m_step_finalize_means_kernel<<>>( + workspace_N_k, workspace_num, weights, means, eps, n, d, K); + CUDA_CHECK_LAST_ERROR(m_step_finalize_means_kernel); + } + + { + int threads = 256; + int blocks = (int)(((size_t)n * d + threads - 1) / threads); + for (int k = 0; k < K; ++k) { + weighted_center_kernel<<>>( + X, resp, means, n, d, K, k, workspace_centered); + CUDA_CHECK_LAST_ERROR(weighted_center_kernel); + + T* cov_k = covariances + (size_t)k * d * d; + cublas_check_status( + cublas_gemm(handle, CUBLAS_OP_N, CUBLAS_OP_T, d, d, n, &one, + workspace_centered, d, workspace_centered, d, + &zero, cov_k, d), + "cublas_gemm(covariance)"); + } + } + + { + int threads = 256; + dim3 block(threads); + dim3 grid(K); + m_step_finalize_cov_cublas_kernel<<>>( + workspace_N_k, covariances, reg_covar, eps, d, K); + CUDA_CHECK_LAST_ERROR(m_step_finalize_cov_cublas_kernel); + } + + if (own_handle) 
cublas_check_status(cublasDestroy(handle), "cublasDestroy"); +} + +template +void register_bindings(nb::module_& m) { + m.def( + "e_step", + [](gpu_array_c X, + gpu_array_c weights, + gpu_array_c means, + gpu_array_c prec_chol, + gpu_array_c log_det_half, + gpu_array_c log_prob, gpu_array_c resp, + gpu_array_c ll_per_cell, int n, int d, int K, + std::uintptr_t stream) { + launch_e_step(X.data(), weights.data(), means.data(), + prec_chol.data(), log_det_half.data(), n, d, K, + log_prob.data(), resp.data(), + ll_per_cell.data(), (cudaStream_t)stream); + }, + "X"_a, "weights"_a, "means"_a, "prec_chol"_a, "log_det_half"_a, + "log_prob"_a, "resp"_a, "ll_per_cell"_a, nb::kw_only(), "n"_a, "d"_a, + "K"_a, "stream"_a = 0); + + m.def( + "e_step", + [](gpu_array_c X, + gpu_array_c weights, + gpu_array_c means, + gpu_array_c prec_chol, + gpu_array_c log_det_half, + gpu_array_c log_prob, + gpu_array_c resp, + gpu_array_c ll_per_cell, int n, int d, int K, + std::uintptr_t stream) { + launch_e_step(X.data(), weights.data(), means.data(), + prec_chol.data(), log_det_half.data(), n, d, + K, log_prob.data(), resp.data(), + ll_per_cell.data(), (cudaStream_t)stream); + }, + "X"_a, "weights"_a, "means"_a, "prec_chol"_a, "log_det_half"_a, + "log_prob"_a, "resp"_a, "ll_per_cell"_a, nb::kw_only(), "n"_a, "d"_a, + "K"_a, "stream"_a = 0); + + m.def( + "e_step_cublas", + [](gpu_array_c X, + gpu_array_c weights, + gpu_array_c means, + gpu_array_c prec_chol, + gpu_array_c log_det_half, + gpu_array_c centered_workspace, + gpu_array_c y_workspace, + gpu_array_c log_prob, gpu_array_c resp, + gpu_array_c ll_per_cell, int n, int d, int K, + std::uintptr_t stream, std::uintptr_t handle) { + launch_e_step_cublas( + X.data(), weights.data(), means.data(), prec_chol.data(), + log_det_half.data(), n, d, K, centered_workspace.data(), + y_workspace.data(), log_prob.data(), resp.data(), + ll_per_cell.data(), (cudaStream_t)stream, + (cublasHandle_t)handle); + }, + "X"_a, "weights"_a, "means"_a, 
"prec_chol"_a, "log_det_half"_a, + "centered_workspace"_a, "y_workspace"_a, "log_prob"_a, "resp"_a, + "ll_per_cell"_a, nb::kw_only(), "n"_a, "d"_a, "K"_a, "stream"_a = 0, + "handle"_a = 0); + + m.def( + "e_step_cublas", + [](gpu_array_c X, + gpu_array_c weights, + gpu_array_c means, + gpu_array_c prec_chol, + gpu_array_c log_det_half, + gpu_array_c centered_workspace, + gpu_array_c y_workspace, + gpu_array_c log_prob, + gpu_array_c resp, + gpu_array_c ll_per_cell, int n, int d, int K, + std::uintptr_t stream, std::uintptr_t handle) { + launch_e_step_cublas( + X.data(), weights.data(), means.data(), prec_chol.data(), + log_det_half.data(), n, d, K, centered_workspace.data(), + y_workspace.data(), log_prob.data(), resp.data(), + ll_per_cell.data(), (cudaStream_t)stream, + (cublasHandle_t)handle); + }, + "X"_a, "weights"_a, "means"_a, "prec_chol"_a, "log_det_half"_a, + "centered_workspace"_a, "y_workspace"_a, "log_prob"_a, "resp"_a, + "ll_per_cell"_a, nb::kw_only(), "n"_a, "d"_a, "K"_a, "stream"_a = 0, + "handle"_a = 0); + + m.def( + "m_step", + [](gpu_array_c resp, + gpu_array_c X, + gpu_array_c ones, + gpu_array_c weights, gpu_array_c means, + gpu_array_c covariances, + gpu_array_c N_k_workspace, + gpu_array_c num_workspace, + gpu_array_c centered_workspace, int n, int d, int K, + float reg_covar, std::uintptr_t stream, std::uintptr_t handle) { + launch_m_step(resp.data(), X.data(), ones.data(), n, d, K, + reg_covar, weights.data(), means.data(), + covariances.data(), N_k_workspace.data(), + num_workspace.data(), + centered_workspace.data(), + (cudaStream_t)stream, (cublasHandle_t)handle); + }, + "resp"_a, "X"_a, "ones"_a, "weights"_a, "means"_a, "covariances"_a, + "N_k_workspace"_a, "num_workspace"_a, "centered_workspace"_a, + nb::kw_only(), "n"_a, "d"_a, "K"_a, "reg_covar"_a, "stream"_a = 0, + "handle"_a = 0); + + m.def( + "m_step", + [](gpu_array_c resp, + gpu_array_c X, + gpu_array_c ones, + gpu_array_c weights, + gpu_array_c means, + gpu_array_c covariances, + 
gpu_array_c N_k_workspace, + gpu_array_c num_workspace, + gpu_array_c centered_workspace, int n, int d, int K, + double reg_covar, std::uintptr_t stream, std::uintptr_t handle) { + launch_m_step(resp.data(), X.data(), ones.data(), n, d, K, + reg_covar, weights.data(), means.data(), + covariances.data(), N_k_workspace.data(), + num_workspace.data(), + centered_workspace.data(), + (cudaStream_t)stream, (cublasHandle_t)handle); + }, + "resp"_a, "X"_a, "ones"_a, "weights"_a, "means"_a, "covariances"_a, + "N_k_workspace"_a, "num_workspace"_a, "centered_workspace"_a, + nb::kw_only(), "n"_a, "d"_a, "K"_a, "reg_covar"_a, "stream"_a = 0, + "handle"_a = 0); +} + +NB_MODULE(_gmm_cuda, m) { + REGISTER_GPU_BINDINGS(register_bindings, m); +} diff --git a/src/rapids_singlecell/_cuda/gmm/kernels_gmm.cuh b/src/rapids_singlecell/_cuda/gmm/kernels_gmm.cuh new file mode 100644 index 00000000..c3b061cf --- /dev/null +++ b/src/rapids_singlecell/_cuda/gmm/kernels_gmm.cuh @@ -0,0 +1,368 @@ +#pragma once + +#include + +// ---------------------------------------------------------------------------- +// Per-(n, k) E-step log-probability. +// +// Each block (k, n_chunk) caches means[k] and prec_chol[k] in shared memory, +// then each thread computes mahalanobis for one cell against the cached +// component. Output is row-major log_prob[n, k] with the log-weight already +// folded in: +// +// y[j] = Σ_d (X[n, d] − means[k, d]) · prec_chol[k, d, j] +// mahal[n, k] = Σ_j y[j]² +// log_prob[n, k] = −0.5·d·log(2π) + log_det_half[k] − 0.5·mahal + +// log(weights[k]) +// +// A separate normalize kernel does the per-row logsumexp. 
+// ---------------------------------------------------------------------------- + +constexpr float LOG_2PI_F = 1.8378770664093453f; +constexpr double LOG_2PI_D = 1.8378770664093453; + +template +__device__ __forceinline__ T log_2pi_const(); +template <> +__device__ __forceinline__ float log_2pi_const() { + return LOG_2PI_F; +} +template <> +__device__ __forceinline__ double log_2pi_const() { + return LOG_2PI_D; +} + +__device__ __forceinline__ int upper_tri_col_offset(int col) { + return (col * (col + 1)) / 2; +} + +template +__global__ void e_step_log_prob_small_kernel( + const T* __restrict__ X, // (n, d) row-major + const T* __restrict__ weights, // (K,) + const T* __restrict__ means, // (K, d) + const T* __restrict__ prec_chol, // (K, d, d) row-major; upper factor + // with cov_inv = chol·cholᵀ + const T* __restrict__ log_det_half, // (K,) + int n, int d, int K, + T* __restrict__ log_prob // (n, K) +) { + static_assert(D >= 0 && D <= 64, + "GMM small E-step supports runtime d or fixed D <= 64"); + constexpr bool fixed_d = D != 0; + int dim = fixed_d ? D : d; + int k = blockIdx.y; + int n_idx = blockIdx.x * blockDim.x + threadIdx.x; + int tid = threadIdx.x; + + extern __shared__ unsigned char smem_raw[]; + T* sh_mean = reinterpret_cast(smem_raw); + T* sh_pc = sh_mean + dim; + + // Cooperatively load means[k] and the used upper triangle of prec_chol[k] + // into shared memory. 
+ for (int i = tid; i < dim; i += blockDim.x) + sh_mean[i] = means[(size_t)k * dim + i]; + int pc_size_dense = dim * dim; + for (int i = tid; i < pc_size_dense; i += blockDim.x) { + int row = i / dim; + int col = i - row * dim; + if (row <= col) { + sh_pc[upper_tri_col_offset(col) + row] = + prec_chol[(size_t)k * pc_size_dense + i]; + } + } + + __shared__ T sh_const; + if (tid == 0) { + sh_const = T(-0.5) * T(dim) * log_2pi_const() + log_det_half[k] + + log(weights[k]); + } + + __syncthreads(); + + if (n_idx >= n) return; + + // Compute mahal = || (X[n] - μ_k) · prec_chol[k] ||² + T centered_vals[fixed_d ? D : 64]; + if constexpr (fixed_d) { +#pragma unroll + for (int dd = 0; dd < D; ++dd) + centered_vals[dd] = X[(size_t)n_idx * D + dd] - sh_mean[dd]; + } else { + for (int dd = 0; dd < dim; ++dd) + centered_vals[dd] = X[(size_t)n_idx * dim + dd] - sh_mean[dd]; + } + + T mahal = T(0); + if constexpr (fixed_d) { +#pragma unroll + for (int j = 0; j < D; ++j) { + T y = T(0); + int pc_col = upper_tri_col_offset(j); +#pragma unroll + for (int dd = 0; dd <= j; ++dd) { + y += centered_vals[dd] * sh_pc[pc_col + dd]; + } + mahal += y * y; + } + } else { + for (int j = 0; j < dim; ++j) { + T y = T(0); + int pc_col = upper_tri_col_offset(j); + // prec_chol is the upper triangular precision factor, so entries + // below the diagonal are zero. Skip that half of the multiply. 
+ for (int dd = 0; dd <= j; ++dd) { + y += centered_vals[dd] * sh_pc[pc_col + dd]; + } + mahal += y * y; + } + } + log_prob[(size_t)n_idx * K + k] = sh_const - T(0.5) * mahal; +} + +template +__global__ void e_step_log_prob_large_d_thread64_kernel( + const T* __restrict__ X, // (n, d) row-major + const T* __restrict__ weights, // (K,) + const T* __restrict__ means, // (K, d) + const T* __restrict__ prec_chol, // (K, d, d) row-major; upper factor + const T* __restrict__ log_det_half, // (K,) + int n, int d, int K, + T* __restrict__ log_prob // (n, K) +) { + static_assert(TILE_D == 64, + "GMM thread64 E-step expects a 64-column precision tile"); + + int k = blockIdx.y; + int row = blockIdx.x * blockDim.x + threadIdx.x; + int tid = threadIdx.x; + + extern __shared__ unsigned char smem_raw[]; + T* sh_mean = reinterpret_cast(smem_raw); // (64,) + T* sh_pc = sh_mean + TILE_D; // (64, 64) + + __shared__ T sh_const; + if (tid == 0) { + sh_const = T(-0.5) * T(d) * log_2pi_const() + log_det_half[k] + + log(weights[k]); + } + + T local_mahal = T(0); + const T* pc = prec_chol + (size_t)k * d * d; + + for (int j_base = 0; j_base < d; j_base += TILE_D) { + int cols_in_tile = min(TILE_D, d - j_base); + int dd_limit = min(d, j_base + TILE_D); + T y[TILE_D]; +#pragma unroll + for (int col = 0; col < TILE_D; ++col) y[col] = T(0); + + for (int dd_base = 0; dd_base < dd_limit; dd_base += TILE_D) { + int feats_in_tile = min(TILE_D, dd_limit - dd_base); + + for (int idx = tid; idx < TILE_D; idx += blockDim.x) { + sh_mean[idx] = (idx < feats_in_tile) + ? 
means[(size_t)k * d + dd_base + idx] + : T(0); + } + + constexpr int pc_tile_elems = TILE_D * TILE_D; + for (int idx = tid; idx < pc_tile_elems; idx += blockDim.x) { + int feat = idx / TILE_D; + int col_local = idx - feat * TILE_D; + int dd = dd_base + feat; + int col = j_base + col_local; + T val = T(0); + if (feat < feats_in_tile && col_local < cols_in_tile && + dd <= col) { + val = pc[(size_t)dd * d + col]; + } + sh_pc[feat * TILE_D + col_local] = val; + } + + __syncthreads(); + + if (row < n) { +#pragma unroll + for (int feat = 0; feat < TILE_D; ++feat) { + if (feat >= feats_in_tile) break; + T diff = + X[(size_t)row * d + dd_base + feat] - sh_mean[feat]; +#pragma unroll + for (int col = 0; col < TILE_D; ++col) { + if (col >= cols_in_tile) break; + y[col] += diff * sh_pc[feat * TILE_D + col]; + } + } + } + + __syncthreads(); + } + + if (row < n) { +#pragma unroll + for (int col = 0; col < TILE_D; ++col) { + if (col >= cols_in_tile) break; + local_mahal += y[col] * y[col]; + } + } + } + + if (row < n) + log_prob[(size_t)row * K + k] = sh_const - T(0.5) * local_mahal; +} + +template +__global__ void e_step_center_kernel(const T* __restrict__ X, + const T* __restrict__ means, int n, int d, + int k, T* __restrict__ centered) { + size_t idx = (size_t)blockIdx.x * blockDim.x + threadIdx.x; + size_t total = (size_t)n * d; + if (idx >= total) return; + + int col = idx % d; + centered[idx] = X[idx] - means[(size_t)k * d + col]; +} + +template +__global__ void e_step_log_prob_from_y_kernel( + const T* __restrict__ y, const T* __restrict__ weights, + const T* __restrict__ log_det_half, int n, int d, int K, int k, + T* __restrict__ log_prob) { + int row = blockIdx.x * blockDim.x + threadIdx.x; + if (row >= n) return; + + T mahal = T(0); + T compensation = T(0); + for (int col = 0; col < d; ++col) { + T v = y[(size_t)row * d + col]; + T term = v * v - compensation; + T next = mahal + term; + compensation = (next - mahal) - term; + mahal = next; + } + + T constant = + 
T(-0.5) * T(d) * log_2pi_const() + log_det_half[k] + log(weights[k]); + log_prob[(size_t)row * K + k] = constant - T(0.5) * mahal; +} + +// ---------------------------------------------------------------------------- +// Per-cell logsumexp normalize: resp[n, k] = exp(log_prob[n, k] − logΣ_k). +// Also writes per-cell log-likelihood (= logΣ_k) into ll_per_cell for later +// reduction. One block per cell; threads stride across K. +// ---------------------------------------------------------------------------- + +template +__global__ void e_step_normalize_kernel( + const T* __restrict__ log_prob, // (n, K) + int n, int K, + T* __restrict__ resp, // (n, K) + T* __restrict__ ll_per_cell // (n,) +) { + int n_idx = blockIdx.x; + if (n_idx >= n) return; + int tid = threadIdx.x; + + __shared__ T sh_max; + __shared__ T sh_sum; + + // pass 1: max over K + T local_max = -CUDART_INF_F; + for (int k = tid; k < K; k += blockDim.x) { + T v = log_prob[n_idx * K + k]; + if (v > local_max) local_max = v; + } + // warp + block reduce max + for (int off = 16; off > 0; off >>= 1) { + T other = __shfl_down_sync(0xffffffff, local_max, off); + if (other > local_max) local_max = other; + } + if (tid == 0) sh_max = local_max; + __syncthreads(); + T mx = sh_max; + + // pass 2: sum exp(log_prob - max) + T local_sum = T(0); + for (int k = tid; k < K; k += blockDim.x) { + local_sum += exp(log_prob[n_idx * K + k] - mx); + } + for (int off = 16; off > 0; off >>= 1) + local_sum += __shfl_down_sync(0xffffffff, local_sum, off); + if (tid == 0) { + sh_sum = local_sum; + T log_total = log(local_sum) + mx; + ll_per_cell[n_idx] = log_total; + } + __syncthreads(); + T log_total = log(sh_sum) + mx; + + // pass 3: write normalized responsibilities + for (int k = tid; k < K; k += blockDim.x) { + resp[n_idx * K + k] = exp(log_prob[n_idx * K + k] - log_total); + } +} + +template +__global__ void m_step_finalize_means_kernel(const T* __restrict__ N_k, + const T* __restrict__ num, + T* __restrict__ weights, + T* 
__restrict__ means, T eps, + int n, int d, int K) { + int k = blockIdx.x; + int tid = threadIdx.x; + if (k >= K) return; + + T Nk = N_k[k] + T(10) * eps; + T inv_Nk = T(1) / Nk; + if (tid == 0) weights[k] = Nk / T(n); + + for (int i = tid; i < d; i += blockDim.x) + means[k * d + i] = num[k * d + i] * inv_Nk; +} + +template +__global__ void weighted_center_kernel(const T* __restrict__ X, + const T* __restrict__ resp, + const T* __restrict__ means, int n, + int d, int K, int k, + T* __restrict__ centered) { + size_t idx = (size_t)blockIdx.x * blockDim.x + threadIdx.x; + size_t total = (size_t)n * d; + if (idx >= total) return; + + int row = idx / d; + int col = idx - (size_t)row * d; + T r = resp[row * K + k]; + centered[idx] = sqrt(r) * (X[idx] - means[k * d + col]); +} + +template +__global__ void m_step_finalize_cov_cublas_kernel(const T* __restrict__ N_k, + T* __restrict__ covariances, + T reg_covar, T eps, int d, + int K) { + int k = blockIdx.x; + int tid = threadIdx.x; + if (k >= K) return; + + T Nk = N_k[k] + T(10) * eps; + T inv_Nk = T(1) / Nk; + int total = d * d; + T* cov = covariances + (size_t)k * d * d; + + for (int idx = tid; idx < total; idx += blockDim.x) { + int i = idx / d; + int j = idx % d; + if (i > j) continue; + + // cuBLAS wrote the row-major symmetric result through a column-major + // view. Read the transposed element and write a symmetric row-major + // covariance. 
+ T v = cov[j * d + i] * inv_Nk; + if (i == j) v += reg_covar; + cov[i * d + j] = v; + if (i != j) cov[j * d + i] = v; + } +} diff --git a/src/rapids_singlecell/squidpy_gpu/__init__.py b/src/rapids_singlecell/squidpy_gpu/__init__.py index 168e44d5..fa3cc3fd 100644 --- a/src/rapids_singlecell/squidpy_gpu/__init__.py +++ b/src/rapids_singlecell/squidpy_gpu/__init__.py @@ -3,3 +3,4 @@ from ._autocorr import spatial_autocorr from ._co_oc import co_occurrence from ._ligrec import ligrec +from ._niche import calculate_niche diff --git a/src/rapids_singlecell/squidpy_gpu/_gmm.py b/src/rapids_singlecell/squidpy_gpu/_gmm.py new file mode 100644 index 00000000..f47a066e --- /dev/null +++ b/src/rapids_singlecell/squidpy_gpu/_gmm.py @@ -0,0 +1,520 @@ +"""Full-covariance GMM for the CellCharter niche flavor. + +The public behavior mirrors :class:`sklearn.mixture.GaussianMixture` with +``covariance_type="full"``. EM is CUDA-only: CuPy is used for array handling and +the precision-Cholesky factorization, while E-step and M-step work is delegated +to the nanobind/CUDA extension. +""" + +from __future__ import annotations + +from typing import Literal + +import cupy as cp +import numpy as np +from cupyx.scipy.linalg import solve_triangular + +from rapids_singlecell._cuda import _gmm_cuda as _gc + +_GMMInit = Literal["kmeans", "random_from_data", "sklearn_kmeans"] +_EStepRoute = Literal["fused", "cublas"] + +_KMEANS_MAX_ITER = 100 +_SKLEARN_SEEDED_KMEANS_MAX_ITER = 300 + +# Fused kernels cover the CellCharter regime. Wider float32 embeddings and +# float64 embeddings above 64 dimensions use the cuBLAS E-step. 
+_CUDA_FUSED_E_STEP_MAX_D = 512 +_CUDA_FUSED_FLOAT64_MAX_D = 64 +_CUDA_CUBLAS_E_STEP_MIN_D = 257 + + +def _allocate_m_step_workspace(X: cp.ndarray, K: int) -> dict[str, cp.ndarray]: + n, d = X.shape + return { + "ones": cp.ones(n, dtype=X.dtype), + "effective_counts": cp.empty(K, dtype=X.dtype), + "weighted_sums": cp.empty((K, d), dtype=X.dtype), + "centered": cp.empty_like(X), + } + + +def _allocate_em_workspace( + X: cp.ndarray, K: int, e_step_route: _EStepRoute +) -> dict[str, cp.ndarray]: + n = X.shape[0] + workspace = { + "log_prob": cp.empty((n, K), dtype=X.dtype), + "responsibilities": cp.empty((n, K), dtype=X.dtype), + "ll_per_cell": cp.empty(n, dtype=X.dtype), + **_allocate_m_step_workspace(X, K), + } + if e_step_route == "cublas": + workspace["e_step_y"] = cp.empty_like(X) + return workspace + + +def _e_step( + X: cp.ndarray, + weights: cp.ndarray, + means: cp.ndarray, + prec_chol: cp.ndarray, + log_det_half: cp.ndarray, + *, + log_prob: cp.ndarray, + responsibilities: cp.ndarray, + ll_per_cell: cp.ndarray, + centered: cp.ndarray, + e_step_y: cp.ndarray | None, + e_step_route: _EStepRoute, + stream: int, + handle: int, +) -> tuple[cp.ndarray, cp.ndarray]: + if e_step_route == "cublas": + return _e_step_cublas( + X, + weights, + means, + prec_chol, + log_det_half, + centered=centered, + e_step_y=e_step_y, + log_prob=log_prob, + responsibilities=responsibilities, + ll_per_cell=ll_per_cell, + stream=stream, + handle=handle, + ) + return _e_step_fused( + X, + weights, + means, + prec_chol, + log_det_half, + log_prob=log_prob, + responsibilities=responsibilities, + ll_per_cell=ll_per_cell, + stream=stream, + ) + + +def _e_step_fused( + X: cp.ndarray, + weights: cp.ndarray, + means: cp.ndarray, + prec_chol: cp.ndarray, + log_det_half: cp.ndarray, + *, + log_prob: cp.ndarray, + responsibilities: cp.ndarray, + ll_per_cell: cp.ndarray, + stream: int, +) -> tuple[cp.ndarray, cp.ndarray]: + n, d = X.shape + K = int(weights.shape[0]) + _gc.e_step( + X, + weights, + 
means, + prec_chol, + log_det_half, + log_prob, + responsibilities, + ll_per_cell, + n=int(n), + d=int(d), + K=K, + stream=stream, + ) + return responsibilities, ll_per_cell.mean() + + +def _e_step_cublas( + X: cp.ndarray, + weights: cp.ndarray, + means: cp.ndarray, + prec_chol: cp.ndarray, + log_det_half: cp.ndarray, + *, + centered: cp.ndarray, + e_step_y: cp.ndarray | None, + log_prob: cp.ndarray, + responsibilities: cp.ndarray, + ll_per_cell: cp.ndarray, + stream: int, + handle: int, +) -> tuple[cp.ndarray, cp.ndarray]: + n, d = X.shape + K = int(weights.shape[0]) + _gc.e_step_cublas( + X, + weights, + means, + prec_chol, + log_det_half, + centered, + e_step_y, + log_prob, + responsibilities, + ll_per_cell, + n=int(n), + d=int(d), + K=K, + stream=stream, + handle=handle, + ) + return responsibilities, ll_per_cell.mean() + + +def gmm_fit_predict( + X: cp.ndarray, + n_components: int, + *, + random_state: int = 0, + max_iter: int = 100, + tol: float = 1e-3, + reg_covar: float = 1e-6, + init: _GMMInit = "kmeans", + kmeans_n_init: int = 1, +) -> cp.ndarray: + """Fit a full-covariance GMM and return cluster labels. + + Parameters + ---------- + X + GPU matrix with observations in rows and features in columns. + n_components + Number of mixture components. + random_state + Seed used by the selected initialization strategy. + max_iter + Maximum number of EM iterations. + tol + Convergence threshold on the mean log-likelihood change. + reg_covar + Non-negative regularization added to each covariance diagonal. + init + Initialization strategy. ``"kmeans"`` uses native cuML KMeans, + ``"random_from_data"`` matches sklearn/Squidpy random-from-data, and + ``"sklearn_kmeans"`` uses sklearn k-means++ seeding followed by cuML + KMeans. + kmeans_n_init + Number of cuML KMeans restarts for ``init="kmeans"``. 
+ """ + K = int(n_components) + if K < 1: + raise ValueError("n_components must be >= 1.") + if int(kmeans_n_init) < 1: + raise ValueError("kmeans_n_init must be >= 1.") + + X = cp.ascontiguousarray(X) + weights, means, covariances = _initialize_parameters( + X, + K, + init=init, + random_state=random_state, + reg_covar=reg_covar, + kmeans_n_init=int(kmeans_n_init), + ) + responsibilities = _run_em( + X, + weights, + means, + covariances, + max_iter=int(max_iter), + tol=float(tol), + reg_covar=float(reg_covar), + ) + return responsibilities.argmax(axis=1).astype(cp.int32) + + +def _run_em( + X: cp.ndarray, + weights: cp.ndarray, + means: cp.ndarray, + covariances: cp.ndarray, + *, + max_iter: int, + tol: float, + reg_covar: float, +) -> cp.ndarray: + n, d = X.shape + K = int(weights.shape[0]) + stream = cp.cuda.get_current_stream().ptr + handle = cp.cuda.device.get_cublas_handle() + e_step_route = _choose_e_step(int(d), X.dtype) + workspace = _allocate_em_workspace(X, K, e_step_route) + + prec_chol, log_det_half = _precision_cholesky(covariances) + previous_ll = -np.inf + + for _ in range(max_iter): + responsibilities, mean_ll = _e_step( + X, + weights, + means, + prec_chol, + log_det_half, + log_prob=workspace["log_prob"], + responsibilities=workspace["responsibilities"], + ll_per_cell=workspace["ll_per_cell"], + centered=workspace["centered"], + e_step_y=workspace.get("e_step_y"), + e_step_route=e_step_route, + stream=stream, + handle=handle, + ) + mean_ll = float(mean_ll) + if abs(mean_ll - previous_ll) < tol: + return responsibilities + + previous_ll = mean_ll + _m_step( + X, + responsibilities, + weights, + means, + covariances, + reg_covar=reg_covar, + ones=workspace["ones"], + effective_counts=workspace["effective_counts"], + weighted_sums=workspace["weighted_sums"], + centered=workspace["centered"], + stream=stream, + handle=handle, + ) + prec_chol, log_det_half = _precision_cholesky(covariances) + + responsibilities, _ = _e_step( + X, + weights, + means, + 
prec_chol, + log_det_half, + log_prob=workspace["log_prob"], + responsibilities=workspace["responsibilities"], + ll_per_cell=workspace["ll_per_cell"], + centered=workspace["centered"], + e_step_y=workspace.get("e_step_y"), + e_step_route=e_step_route, + stream=stream, + handle=handle, + ) + return responsibilities + + +def _initialize_parameters( + X: cp.ndarray, + K: int, + *, + init: str, + random_state: int, + reg_covar: float, + kmeans_n_init: int, +) -> tuple[cp.ndarray, cp.ndarray, cp.ndarray]: + if init == "random_from_data": + return _random_from_data_init(X, K, random_state, reg_covar) + + if init not in ("kmeans", "sklearn_kmeans"): + raise ValueError( + "init must be 'kmeans', 'random_from_data', or " + f"'sklearn_kmeans', got {init!r}" + ) + + labels, centers = _fit_kmeans( + X, + K, + init=init, + random_state=random_state, + kmeans_n_init=kmeans_n_init, + ) + return _parameters_from_labels(X, labels, centers, reg_covar) + + +def _random_from_data_init( + X: cp.ndarray, + K: int, + random_state: int, + reg_covar: float, +) -> tuple[cp.ndarray, cp.ndarray, cp.ndarray]: + n, d = X.shape + rng = np.random.RandomState(random_state) + idx = cp.asarray(rng.choice(n, size=K, replace=False)) + eye_reg = reg_covar * cp.eye(d, dtype=X.dtype) + return ( + cp.full(K, 1.0 / K, dtype=X.dtype), + X[idx].copy(), + cp.broadcast_to(eye_reg, (K, d, d)).copy(), + ) + + +def _fit_kmeans( + X: cp.ndarray, + K: int, + *, + init: str, + random_state: int, + kmeans_n_init: int, +) -> tuple[cp.ndarray, cp.ndarray]: + from cuml.cluster import KMeans + + kwargs = {} + if init == "sklearn_kmeans": + from sklearn.cluster import kmeans_plusplus + + centers, _ = kmeans_plusplus(cp.asnumpy(X), K, random_state=random_state) + kwargs["init"] = cp.asarray(centers, dtype=X.dtype) + kmeans_n_init = 1 + max_iter = _SKLEARN_SEEDED_KMEANS_MAX_ITER + else: + max_iter = _KMEANS_MAX_ITER + + km = KMeans( + n_clusters=K, + random_state=random_state, + n_init=int(kmeans_n_init), + 
max_iter=max_iter, + **kwargs, + ) + km.fit(X) + return ( + cp.asarray(km.labels_).astype(cp.int64, copy=False), + cp.asarray(km.cluster_centers_, dtype=X.dtype), + ) + + +def _parameters_from_labels( + X: cp.ndarray, + labels: cp.ndarray, + means_init: cp.ndarray, + reg_covar: float, +) -> tuple[cp.ndarray, cp.ndarray, cp.ndarray]: + n, d = X.shape + K = int(means_init.shape[0]) + weights = cp.empty(K, dtype=X.dtype) + means = cp.empty((K, d), dtype=X.dtype) + covariances = cp.empty((K, d, d), dtype=X.dtype) + responsibilities = cp.zeros((n, K), dtype=X.dtype) + workspace = _allocate_m_step_workspace(X, K) + responsibilities[cp.arange(n), labels] = X.dtype.type(1.0) + _m_step( + X, + responsibilities, + weights, + means, + covariances, + reg_covar=reg_covar, + ones=workspace["ones"], + effective_counts=workspace["effective_counts"], + weighted_sums=workspace["weighted_sums"], + centered=workspace["centered"], + stream=cp.cuda.get_current_stream().ptr, + handle=cp.cuda.device.get_cublas_handle(), + ) + _restore_empty_components( + weights, + means, + covariances, + labels, + means_init, + n=n, + reg_covar=reg_covar, + ) + return weights, means, covariances + + +def _restore_empty_components( + weights: cp.ndarray, + means: cp.ndarray, + covariances: cp.ndarray, + labels: cp.ndarray, + means_init: cp.ndarray, + *, + n: int, + reg_covar: float, +) -> None: + """Repair empty cuML KMeans components before EM starts. + + sklearn's GMM init estimates parameters from hard KMeans responsibilities + and adds ``10 * eps`` to component counts, relying on sklearn KMeans to + avoid empty final labels in normal cases. cuML can still hand back an empty + component, so keep its center, give it a tiny finite weight, and use the + regularized identity covariance instead of letting the M-step create a + zero-mean component from an empty responsibility column. 
+ """ + K, d = means_init.shape + counts = cp.bincount(labels, minlength=int(K)).astype(means_init.dtype, copy=False) + empty = counts == 0 + eye_reg = reg_covar * cp.eye(d, dtype=means_init.dtype) + + weights[...] = cp.where(empty, means_init.dtype.type(1.0 / n), counts / n) + means[...] = cp.where(empty[:, None], means_init, means) + covariances[...] = cp.where( + empty[:, None, None], + cp.broadcast_to(eye_reg, covariances.shape), + covariances, + ) + + +def _precision_cholesky(covariances: cp.ndarray) -> tuple[cp.ndarray, cp.ndarray]: + """Return sklearn-oriented precision Cholesky without forming an inverse.""" + cov_chol = cp.linalg.cholesky(covariances) + eye = cp.broadcast_to( + cp.eye(covariances.shape[-1], dtype=covariances.dtype), + covariances.shape, + ) + cov_chol_inv = solve_triangular(cov_chol, eye, lower=True) + return ( + cp.ascontiguousarray(cov_chol_inv.transpose(0, 2, 1)), + -cp.sum( + cp.log(cp.diagonal(cov_chol, axis1=1, axis2=2)), + axis=1, + ), + ) + + +def _m_step( + X: cp.ndarray, + responsibilities: cp.ndarray, + weights: cp.ndarray, + means: cp.ndarray, + covariances: cp.ndarray, + *, + reg_covar: float, + ones: cp.ndarray, + effective_counts: cp.ndarray, + weighted_sums: cp.ndarray, + centered: cp.ndarray, + stream: int, + handle: int, +) -> None: + n, d = X.shape + K = int(weights.shape[0]) + + _gc.m_step( + responsibilities, + X, + ones, + weights, + means, + covariances, + effective_counts, + weighted_sums, + centered, + n=int(n), + d=int(d), + K=K, + reg_covar=float(reg_covar), + stream=stream, + handle=handle, + ) + + +def _choose_e_step(d: int, dtype) -> _EStepRoute: + """Select the CUDA E-step implementation for a feature width and dtype.""" + dtype = np.dtype(dtype) + if dtype == np.dtype("float32"): + return "cublas" if d >= _CUDA_CUBLAS_E_STEP_MIN_D else "fused" + if dtype == np.dtype("float64"): + return "cublas" if d > _CUDA_FUSED_FLOAT64_MAX_D else "fused" + return "cublas" if d >= _CUDA_CUBLAS_E_STEP_MIN_D else "fused" 
diff --git a/src/rapids_singlecell/squidpy_gpu/_niche.py b/src/rapids_singlecell/squidpy_gpu/_niche.py new file mode 100644 index 00000000..87cceb94 --- /dev/null +++ b/src/rapids_singlecell/squidpy_gpu/_niche.py @@ -0,0 +1,385 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal + +import cupy as cp +import numpy as np +import pandas as pd +from anndata import AnnData +from cupyx.scipy import sparse as sparse_gpu + +import rapids_singlecell as rsc + +if TYPE_CHECKING: + from collections.abc import Sequence + + +__all__ = ["calculate_niche"] + + +def calculate_niche( + adata: AnnData, + *, + flavor: Literal["neighborhood", "utag", "cellcharter"], + groups: str | None = None, + n_neighbors: int = 15, + resolutions: float | Sequence[float] = (0.5,), + distance: int | None = None, + n_hop_weights: Sequence[float] | None = None, + abs_nhood: bool = False, + scale: bool = True, + min_niche_size: int | None = None, + aggregation: Literal["mean", "variance"] = "mean", + n_components: int = 10, + use_rep: str | None = None, + gmm_init: Literal[ + "random_from_data", "kmeans", "sklearn_kmeans" + ] = "random_from_data", + spatial_connectivities_key: str = "spatial_connectivities", + random_state: int = 42, + copy: bool = False, +) -> AnnData | None: + """\ + Compute spatial niches on the GPU. + + Mirrors :func:`squidpy.gr.calculate_niche` for the ``"neighborhood"``, + ``"utag"`` and ``"cellcharter"`` flavors. The spatial graph in + ``adata.obsp[spatial_connectivities_key]`` must be precomputed + (e.g. via :func:`squidpy.gr.spatial_neighbors`). + + Parameters + ---------- + adata + Annotated data matrix. + flavor + - ``"neighborhood"`` cluster cell-type frequency profiles among spatial neighbors + :cite:p:`monkeybread`. + - ``"utag"`` cluster gene expression smoothed across spatial neighbors + :cite:p:`UTAG2022`. 
+ - ``"cellcharter"`` shell-aggregate gene expression over n-hop neighborhoods, + PCA-reduce, then cluster with a Gaussian mixture :cite:p:`CellCharter2024`. + groups + Column in ``adata.obs`` with cell-type labels. Required for ``flavor="neighborhood"``. + n_neighbors + Neighbors for the post-aggregation kNN graph passed to leiden. + resolutions + Resolution(s) for leiden. A label column is written for each value. + Ignored for ``flavor="cellcharter"``. + distance + Number of n-hop neighborhoods to include. Defaults to 3 for ``cellcharter``, + 1 for ``neighborhood``. + n_hop_weights + Per-hop weights when ``distance > 1`` (``flavor="neighborhood"`` only). + abs_nhood + Use absolute neighbor counts instead of per-cell relative frequencies + (``flavor="neighborhood"`` only). + scale + Z-score the neighborhood profile before clustering (``flavor="neighborhood"`` only). + min_niche_size + Discard niches with fewer cells than this; relabel as ``"not_a_niche"``. + aggregation + Per-shell aggregation for ``flavor="cellcharter"``. ``"mean"`` (default) or ``"variance"``. + n_components + Number of mixture components for ``flavor="cellcharter"``. + use_rep + Key in ``adata.obsm`` to use as the embedding for ``flavor="cellcharter"``; + if provided, the first ``n_components`` columns are used and the shell-aggregation + + PCA step is skipped. + gmm_init + GMM initialization for ``flavor="cellcharter"``. ``"random_from_data"`` + (default) matches Squidpy's CellCharter path. ``"kmeans"`` uses native + cuML KMeans. ``"sklearn_kmeans"`` uses sklearn-compatible k-means++ seeding + followed by cuML KMeans. + spatial_connectivities_key + Key in ``adata.obsp`` with the spatial connectivity matrix. + random_state + Random seed for leiden / GMM. + copy + Return a copy with the niche columns instead of writing in place. + """ + if spatial_connectivities_key not in adata.obsp: + raise KeyError( + f"'{spatial_connectivities_key}' not found in `adata.obsp`. 
" + "Compute it first with `squidpy.gr.spatial_neighbors`." + ) + if flavor not in ("neighborhood", "utag", "cellcharter"): + raise ValueError( + f"Unknown flavor '{flavor}'. Use 'neighborhood', 'utag', or 'cellcharter'." + ) + if distance is None: + distance = 3 if flavor == "cellcharter" else 1 + if flavor in ("neighborhood",) and distance < 1: + raise ValueError(f"`distance` must be >= 1, got {distance}.") + if flavor == "cellcharter" and distance < 0: + raise ValueError(f"`distance` must be >= 0, got {distance}.") + + adata = adata.copy() if copy else adata + + if flavor == "cellcharter": + if gmm_init not in ("random_from_data", "kmeans", "sklearn_kmeans"): + raise ValueError( + "`gmm_init` must be one of 'random_from_data', 'kmeans', or " + f"'sklearn_kmeans', got {gmm_init!r}." + ) + _run_cellcharter( + adata, + distance=distance, + aggregation=aggregation, + n_components=n_components, + use_rep=use_rep, + gmm_init=gmm_init, + random_state=random_state, + key=spatial_connectivities_key, + ) + return adata if copy else None + + if flavor == "neighborhood": + if groups is None: + raise ValueError("`groups` is required for flavor='neighborhood'.") + if groups not in adata.obs.columns: + raise KeyError(f"'{groups}' not found in `adata.obs`.") + profile = _neighborhood_profile( + adata, + groups=groups, + distance=distance, + weights=n_hop_weights, + abs_nhood=abs_nhood, + key=spatial_connectivities_key, + ) + prefix = "nhood_niche" + else: + profile = _utag_features(adata, spatial_connectivities_key) + prefix = "utag_niche" + + inner = AnnData(X=profile, obs=pd.DataFrame(index=adata.obs_names.copy())) + + if flavor == "neighborhood": + if scale: + rsc.pp.scale(inner, zero_center=True) + rsc.pp.neighbors( + inner, n_neighbors=n_neighbors, use_rep="X", random_state=random_state + ) + else: + rsc.pp.pca(inner) + rsc.pp.neighbors( + inner, n_neighbors=n_neighbors, use_rep="X_pca", random_state=random_state + ) + + res_list = ( + [float(resolutions)] + if 
isinstance(resolutions, (int, float)) + else [float(r) for r in resolutions] + ) + base = "_niche_leiden" + rsc.tl.leiden( + inner, + resolution=res_list, + key_added=base, + random_state=random_state, + dtype=np.float64, + ) + for res in res_list: + src = f"{base}_{res}" if len(res_list) > 1 else base + out_key = f"{prefix}_res={res}" + labels = inner.obs[src].astype(str) + if min_niche_size is not None and flavor == "neighborhood": + counts = labels.value_counts() + small = counts[counts < min_niche_size].index + labels = labels.where(~labels.isin(small), other="not_a_niche") + adata.obs[out_key] = pd.Categorical(labels.values) + + return adata if copy else None + + +def _neighborhood_profile( + adata: AnnData, + *, + groups: str, + distance: int, + weights: Sequence[float] | None, + abs_nhood: bool, + key: str, +) -> np.ndarray: + """Cells x categories matrix of cell-type counts (or relative frequencies) over n-hop neighbors.""" + cats = pd.Categorical(adata.obs[groups]) + n_cats = len(cats.categories) + n_obs = adata.n_obs + + one_hot = cp.zeros((n_obs, n_cats), dtype=cp.float32) + one_hot[cp.arange(n_obs), cp.asarray(cats.codes, dtype=cp.int64)] = 1.0 + + adj = rsc.get.X_to_GPU(adata.obsp[key]).astype(cp.float32) + adj.eliminate_zeros() + # Binarize so adj.data == 1: each existing edge contributes one neighbor count. 
+ adj_bin = adj.copy() + adj_bin.data[:] = 1.0 + + if weights is None: + weights = [1.0] * distance + elif len(weights) < distance: + weights = list(weights) + [weights[-1]] * (distance - len(weights)) + + profile = cp.zeros((n_obs, n_cats), dtype=cp.float32) + adj_k = adj_bin + for hop in range(distance): + if hop == 0: + adj_hop = adj_bin + else: + adj_k = adj_k @ adj_bin + adj_hop = adj_k.copy() + adj_hop.data[:] = 1.0 + counts = adj_hop @ one_hot # (n_obs, n_cats) dense + if not abs_nhood: + row_sum = adj_hop.sum(axis=1).reshape(-1, 1) + row_sum = cp.where(row_sum == 0, cp.float32(1.0), row_sum) + counts = counts / row_sum + profile += cp.float32(weights[hop]) * counts + + if not abs_nhood: + profile /= cp.float32(sum(weights)) + + return profile + + +def _utag_features(adata: AnnData, key: str) -> cp.ndarray | sparse_gpu.csr_matrix: + """L1-row-normalize the spatial adjacency and propagate expression: D^-1 A @ X.""" + from rapids_singlecell._cuda import _norm_cuda as _nc + + adj = rsc.get.X_to_GPU(adata.obsp[key]) + if adj.dtype != cp.float32: + adj = adj.astype(cp.float32) + _nc.mul_csr( + adj.indptr, + adj.data, + nrows=adj.shape[0], + target_sum=1.0, + stream=cp.cuda.get_current_stream().ptr, + ) + + X = rsc.get.X_to_GPU(adata.X).astype(cp.float32) + if sparse_gpu.issparse(X): + out = adj @ X + return out.tocsr() + out = adj @ X + return out + + +def _run_cellcharter( + adata: AnnData, + *, + distance: int, + aggregation: str, + n_components: int, + use_rep: str | None, + gmm_init: str, + random_state: int, + key: str, +) -> None: + """Cellcharter pipeline: shell-aggregate → PCA → GMM.""" + if aggregation not in ("mean", "variance"): + raise ValueError( + f"aggregation={aggregation!r} not supported. Use 'mean' or 'variance'." 
+ ) + if not isinstance(n_components, int) or n_components < 1: + raise ValueError(f"`n_components` must be an int >= 1, got {n_components}.") + + if use_rep is not None: + if use_rep not in adata.obsm: + raise KeyError(f"'{use_rep}' not found in `adata.obsm`.") + emb = adata.obsm[use_rep] + if emb.shape[1] < n_components: + raise ValueError( + f"`adata.obsm['{use_rep}']` has {emb.shape[1]} columns, " + f"need at least n_components={n_components}." + ) + embedding = cp.asarray(emb[:, :n_components], dtype=cp.float32) + else: + feat = _cellcharter_features(adata, distance, aggregation, key) + inner = AnnData(X=feat, obs=pd.DataFrame(index=adata.obs_names.copy())) + rsc.get.anndata_to_GPU(inner) + rsc.pp.pca(inner) + embedding = cp.asarray(inner.obsm["X_pca"], dtype=cp.float32) + + from ._gmm import gmm_fit_predict + + labels = gmm_fit_predict( + embedding, + n_components=n_components, + random_state=random_state, + init=gmm_init, + ) + adata.obs["cellcharter_niche"] = pd.Categorical(cp.asnumpy(labels).astype(str)) + + +def _cellcharter_features( + adata: AnnData, + distance: int, + aggregation: str, + key: str, +) -> cp.ndarray | sparse_gpu.csr_matrix: + """Build the shell-aggregated feature matrix: ``[X | Â₁X | Â₂X | …]``. + + For each k in ``1..distance`` the kth-shell adjacency is computed by + multiplying the previous adjacency by the base graph and subtracting the + already-visited neighbors. Each shell is row-L1-normalized via the same + fused ``mul_csr`` kernel used for utag, then aggregated as either: + + - ``"mean"``: ``Âₖ @ X`` + - ``"variance"``: ``Âₖ @ (X·X) − (Âₖ @ X)²`` (matches squidpy's path; densifies X) + + All layers are concatenated horizontally. + """ + from rapids_singlecell._cuda import _norm_cuda as _nc + + adj = rsc.get.X_to_GPU(adata.obsp[key]) + if adj.dtype != cp.float32: + adj = adj.astype(cp.float32) + + # 1-hop adjacency, no self-loops; visited tracks {self ∪ 1-hop}. 
+ adj_hop = adj.copy() + adj_hop.setdiag(cp.float32(0.0)) + adj_hop.eliminate_zeros() + adj_visited = adj.copy() + adj_visited.setdiag(cp.float32(1.0)) + + X = rsc.get.X_to_GPU(adata.X) + if aggregation == "variance": + # Variance needs element-wise square of X; densify once up front. + X_dense = X.toarray() if sparse_gpu.issparse(X) else X + X_sq = X_dense * X_dense + aggregated: list = [X_dense] + else: + aggregated = [X] + + for k in range(1, distance + 1): + if k > 1: + # Walk one more hop, keep only newly reachable neighbors. + adj_hop = adj_hop @ adj + new_shell = (adj_hop > adj_visited).astype(cp.float32) + adj_hop = new_shell + adj_visited = adj_visited + new_shell + + # L1 row-normalize the shell adjacency in place. + adj_norm = adj_hop.copy() + if adj_norm.nnz > 0: + _nc.mul_csr( + adj_norm.indptr, + adj_norm.data, + nrows=adj_norm.shape[0], + target_sum=1.0, + stream=cp.cuda.get_current_stream().ptr, + ) + + if aggregation == "variance": + mean = adj_norm @ X_dense + mean_sq = adj_norm @ X_sq + aggregated.append(mean_sq - mean * mean) + else: + aggregated.append(adj_norm @ X) + + if all(not sparse_gpu.issparse(m) for m in aggregated): + return cp.concatenate(aggregated, axis=1) + aggregated = [ + m if sparse_gpu.issparse(m) else sparse_gpu.csr_matrix(m) for m in aggregated + ] + return sparse_gpu.hstack(aggregated, format="csr") diff --git a/tests/test_gmm.py b/tests/test_gmm.py new file mode 100644 index 00000000..d3e8c5aa --- /dev/null +++ b/tests/test_gmm.py @@ -0,0 +1,547 @@ +from __future__ import annotations + +import inspect + +import cupy as cp +import numpy as np +import pytest +from cupyx.scipy.special import logsumexp +from sklearn.metrics import adjusted_rand_score as ARI +from sklearn.mixture import GaussianMixture + +from rapids_singlecell.squidpy_gpu._gmm import ( + _choose_e_step, + _e_step, + _e_step_cublas, + _e_step_fused, + _m_step, + _precision_cholesky, + gmm_fit_predict, +) + + +def _well_separated(n_per: int, K: int, d: int, sep: 
float, seed: int): + rng = np.random.default_rng(seed) + centers = rng.normal(scale=sep, size=(K, d)) + X = np.vstack( + [rng.normal(loc=c, scale=1.0, size=(n_per, d)) for c in centers] + ).astype(np.float32) + y = np.repeat(np.arange(K), n_per) + perm = rng.permutation(len(X)) + return X[perm], y[perm] + + +def _pca_like_mixture(n_per: int, K: int, d: int, seed: int) -> np.ndarray: + rng = np.random.default_rng(seed) + centers = rng.normal(scale=5.0, size=(K, d)) + rows = [] + for k in range(K): + # A compact low-rank perturbation of the identity gives each synthetic + # cell state its own correlated PCA-space geometry. + factors = rng.normal(scale=0.3 + 0.05 * k, size=(d, 3)) + cov = np.eye(d) * (0.35 + 0.03 * k) + factors @ factors.T + rows.append(rng.multivariate_normal(centers[k], cov, size=n_per)) + X = np.vstack(rows).astype(np.float32) + return np.ascontiguousarray(X[rng.permutation(len(X))]) + + +_LOG_2PI = float(np.log(2.0 * np.pi)) + + +def _reference_e_step( + X: cp.ndarray, + weights: cp.ndarray, + means: cp.ndarray, + prec_chol: cp.ndarray, + log_det_half: cp.ndarray, +) -> tuple[cp.ndarray, cp.ndarray]: + n, d = X.shape + K = means.shape[0] + log_prob = cp.empty((n, K), dtype=X.dtype) + half_d_log2pi = X.dtype.type(0.5 * d * _LOG_2PI) + for k in range(K): + y = (X - means[k]) @ prec_chol[k] + mahal = cp.einsum("ij,ij->i", y, y) + log_prob[:, k] = ( + -X.dtype.type(0.5) * mahal + + log_det_half[k] + - half_d_log2pi + + cp.log(weights[k]) + ) + + log_total = logsumexp(log_prob, axis=1, keepdims=True) + return cp.exp(log_prob - log_total), log_total.mean() + + +def _reference_m_step( + X: cp.ndarray, + resp: cp.ndarray, + reg_covar: float, +) -> tuple[cp.ndarray, cp.ndarray, cp.ndarray]: + n, d = X.shape + K = resp.shape[1] + N_k = resp.sum(axis=0) + 10.0 * cp.finfo(X.dtype).eps + weights = N_k / n + means = (resp.T @ X) / N_k[:, None] + covariances = cp.empty((K, d, d), dtype=X.dtype) + eye_reg = reg_covar * cp.eye(d, dtype=X.dtype) + for k in 
range(K): + diff = X - means[k] + covariances[k] = ((resp[:, k : k + 1] * diff).T @ diff) / N_k[k] + eye_reg + return weights, means, covariances + + +def _e_step_buffers(X: cp.ndarray, K: int, route: str): + n = X.shape[0] + return ( + cp.empty((n, K), dtype=X.dtype), + cp.empty((n, K), dtype=X.dtype), + cp.empty(n, dtype=X.dtype), + cp.empty_like(X), + cp.empty_like(X) if route == "cublas" else None, + ) + + +def _cuda_e_step( + X: cp.ndarray, + weights: cp.ndarray, + means: cp.ndarray, + prec_chol: cp.ndarray, + log_det_half: cp.ndarray, +) -> tuple[cp.ndarray, cp.ndarray]: + n, d = X.shape + K = int(means.shape[0]) + e_step_route = _choose_e_step(d, X.dtype) + log_prob, responsibilities, ll_per_cell, centered, e_step_y = _e_step_buffers( + X, K, e_step_route + ) + return _e_step( + X, + weights, + means, + prec_chol, + log_det_half, + log_prob, + responsibilities, + ll_per_cell, + centered, + e_step_y, + e_step_route=e_step_route, + stream=cp.cuda.get_current_stream().ptr, + handle=cp.cuda.device.get_cublas_handle(), + ) + + +def _cuda_m_step( + X: cp.ndarray, + resp: cp.ndarray, + reg_covar: float, +) -> tuple[cp.ndarray, cp.ndarray, cp.ndarray]: + K = resp.shape[1] + weights = cp.empty(K, dtype=X.dtype) + means = cp.empty((K, X.shape[1]), dtype=X.dtype) + covariances = cp.empty((K, X.shape[1], X.shape[1]), dtype=X.dtype) + _m_step( + X, + resp, + weights, + means, + covariances, + reg_covar, + cp.ones(X.shape[0], dtype=X.dtype), + cp.empty(K, dtype=X.dtype), + cp.empty((K, X.shape[1]), dtype=X.dtype), + cp.empty_like(X), + stream=cp.cuda.get_current_stream().ptr, + handle=cp.cuda.device.get_cublas_handle(), + ) + return weights, means, covariances + + +def test_kmeans_init_recovers_well_separated_clusters(): + """kmeans init should land at near-truth on well-separated data.""" + X_np, y = _well_separated(n_per=300, K=5, d=20, sep=6.0, seed=0) + labels = cp.asnumpy( + gmm_fit_predict(cp.asarray(X_np), n_components=5, random_state=0, init="kmeans") + ) + assert 
ARI(y, labels) >= 0.95 + + +def test_full_cov_gmm_matches_sklearn_on_singlecell_embedding(): + X = _pca_like_mixture(n_per=250, K=5, d=16, seed=7) + sk_labels = GaussianMixture( + n_components=5, + covariance_type="full", + tol=1e-3, + reg_covar=1e-6, + max_iter=100, + n_init=1, + init_params="kmeans", + random_state=0, + ).fit_predict(X) + rsc_labels = cp.asnumpy( + gmm_fit_predict( + cp.asarray(X), + n_components=5, + random_state=0, + init="kmeans", + kmeans_n_init=1, + ) + ) + + assert ARI(sk_labels, rsc_labels) >= 0.99 + + +def test_random_from_data_init_runs(): + """random_from_data may land at a worse local optimum than kmeans, but should + still produce a non-trivial partition on well-separated data.""" + X_np, y = _well_separated(n_per=300, K=5, d=20, sep=6.0, seed=0) + labels = cp.asnumpy( + gmm_fit_predict( + cp.asarray(X_np), n_components=5, random_state=0, init="random_from_data" + ) + ) + assert ARI(y, labels) >= 0.35 + assert len(set(labels.tolist())) >= 2 + + +def test_output_shape_and_dtype(): + rng = np.random.default_rng(0) + X = rng.standard_normal((500, 8)).astype(np.float32) + labels = gmm_fit_predict(cp.asarray(X), n_components=4, random_state=0) + assert labels.shape == (500,) + assert labels.dtype == cp.int32 + assert int(labels.max()) < 4 + assert int(labels.min()) >= 0 + + +def test_non_contiguous_input_is_normalized_at_public_boundary(): + rng = np.random.default_rng(2) + X = cp.asarray(rng.standard_normal((6, 240)).astype(np.float32)).T + + assert not X.flags.c_contiguous + + labels = gmm_fit_predict( + X, + n_components=3, + random_state=0, + max_iter=2, + init="random_from_data", + ) + + assert labels.shape == (240,) + assert labels.dtype == cp.int32 + + +@pytest.mark.parametrize("init", ["kmeans", "random_from_data", "sklearn_kmeans"]) +def test_determinism_same_seed(init): + rng = np.random.default_rng(1) + X = cp.asarray(rng.standard_normal((800, 10)).astype(np.float32)) + a = cp.asnumpy(gmm_fit_predict(X, n_components=5, 
random_state=42, init=init)) + b = cp.asnumpy(gmm_fit_predict(X, n_components=5, random_state=42, init=init)) + np.testing.assert_array_equal(a, b) + + +def test_invalid_init_raises(): + X = cp.asarray(np.zeros((100, 5), dtype=np.float32)) + with pytest.raises(ValueError, match="init"): + gmm_fit_predict(X, n_components=3, init="bogus") + + +def test_backend_parameter_is_not_exposed(): + assert "backend" not in inspect.signature(gmm_fit_predict).parameters + + +def test_invalid_kmeans_n_init_raises(): + X = cp.asarray(np.zeros((100, 5), dtype=np.float32)) + with pytest.raises(ValueError, match="kmeans_n_init"): + gmm_fit_predict(X, n_components=3, kmeans_n_init=0) + + +def test_invalid_n_components_raises(): + X = cp.asarray(np.zeros((100, 5), dtype=np.float32)) + with pytest.raises(ValueError, match="n_components"): + gmm_fit_predict(X, n_components=0) + + +@pytest.mark.parametrize( + ("d", "dtype", "route"), + [ + (16, cp.float32, "fused"), + (32, cp.float32, "fused"), + (50, cp.float32, "fused"), + (64, cp.float32, "fused"), + (80, cp.float32, "fused"), + (96, cp.float32, "fused"), + (128, cp.float32, "fused"), + (256, cp.float32, "fused"), + (384, cp.float32, "cublas"), + (512, cp.float32, "cublas"), + (768, cp.float32, "cublas"), + (2000, cp.float32, "cublas"), + (64, cp.float64, "fused"), + (128, cp.float64, "cublas"), + (512, cp.float64, "cublas"), + (2000, cp.float64, "cublas"), + ], +) +def test_cuda_e_step_routing_uses_cublas_for_high_d_and_wide_float64(d, dtype, route): + assert _choose_e_step(d, dtype) == route + + +def test_n_components_one_returns_single_label(): + rng = np.random.default_rng(0) + X = cp.asarray(rng.standard_normal((200, 4)).astype(np.float32)) + labels = cp.asnumpy(gmm_fit_predict(X, n_components=1, random_state=0)) + assert set(labels.tolist()) == {0} + + +def test_float64_input_accepted(): + rng = np.random.default_rng(0) + X = cp.asarray(rng.standard_normal((300, 6)).astype(np.float64)) + labels = gmm_fit_predict(X, 
n_components=3, random_state=0) + assert labels.shape == (300,) + + +def test_cuda_matches_reference_steps(): + rng = cp.random.RandomState(0) + n, d, K = 40_000, 6, 3 # large enough to exercise the cuBLAS M-step path + X = rng.standard_normal((n, d), dtype=cp.float32) + logits = rng.standard_normal((n, K), dtype=cp.float32) + resp = cp.exp(logits - cp.log(cp.exp(logits).sum(axis=1, keepdims=True))) + + w_c, m_c, cov_c = _reference_m_step(X, resp, 1e-6) + w_g, m_g, cov_g = _cuda_m_step(X, resp, 1e-6) + + assert cp.max(cp.abs(w_c - w_g)).item() < 1e-5 + assert cp.max(cp.abs(m_c - m_g)).item() < 1e-5 + assert cp.max(cp.abs(cov_c - cov_g)).item() < 1e-4 + + weights = cp.full(K, 1 / K, dtype=cp.float32) + means = rng.standard_normal((K, d), dtype=cp.float32) + A = rng.standard_normal((K, d, d), dtype=cp.float32) + cov = A @ A.transpose(0, 2, 1) + cp.eye(d, dtype=cp.float32)[None] * 0.1 + prec_chol, log_det_half = _precision_cholesky(cov) + + r_c, ll_c = _reference_e_step(X, weights, means, prec_chol, log_det_half) + r_g, ll_g = _cuda_e_step(X, weights, means, prec_chol, log_det_half) + log_prob, resp, ll_per_cell, _, _ = _e_step_buffers(X, K, "fused") + r_f, ll_f = _e_step_fused( + X, + weights, + means, + prec_chol, + log_det_half, + log_prob, + resp, + ll_per_cell, + stream=cp.cuda.get_current_stream().ptr, + ) + + assert cp.max(cp.abs(r_c - r_g)).item() < 1e-4 + assert cp.abs(ll_c - ll_g).item() < 1e-4 + assert cp.max(cp.abs(r_c - r_f)).item() < 1e-4 + assert cp.abs(ll_c - ll_f).item() < 1e-4 + + +def test_cuda_large_e_step_matches_reference_for_large_feature_count(): + rng = cp.random.RandomState(2) + n, d, K = 2048, 96, 4 + X = rng.standard_normal((n, d), dtype=cp.float32) + weights = cp.asarray([0.15, 0.2, 0.3, 0.35], dtype=cp.float32) + means = rng.standard_normal((K, d), dtype=cp.float32) + A = rng.standard_normal((K, d, d), dtype=cp.float32) + cov = (A @ A.transpose(0, 2, 1)) / d + cp.eye(d, dtype=cp.float32)[None] * 0.5 + prec_chol, log_det_half = 
_precision_cholesky(cov) + + r_c, ll_c = _reference_e_step(X, weights, means, prec_chol, log_det_half) + r_g, ll_g = _cuda_e_step(X, weights, means, prec_chol, log_det_half) + + assert cp.max(cp.abs(r_c - r_g)).item() < 5e-4 + assert cp.abs(ll_c - ll_g).item() < 5e-4 + + +def test_cuda_512_e_step_matches_reference_for_cublas_route(): + rng = cp.random.RandomState(5) + n, d, K = 384, 512, 3 + X = rng.standard_normal((n, d), dtype=cp.float32) + weights = cp.asarray([0.2, 0.3, 0.5], dtype=cp.float32) + means = rng.standard_normal((K, d), dtype=cp.float32) + A = rng.standard_normal((K, d, d), dtype=cp.float32) + cov = (A @ A.transpose(0, 2, 1)) / d + cp.eye(d, dtype=cp.float32)[None] * 0.5 + prec_chol, log_det_half = _precision_cholesky(cov) + + r_c, ll_c = _reference_e_step(X, weights, means, prec_chol, log_det_half) + r_g, ll_g = _cuda_e_step(X, weights, means, prec_chol, log_det_half) + log_prob, resp, ll_per_cell, centered, e_step_y = _e_step_buffers(X, K, "cublas") + r_b, ll_b = _e_step_cublas( + X, + weights, + means, + prec_chol, + log_det_half, + centered, + e_step_y, + log_prob, + resp, + ll_per_cell, + stream=cp.cuda.get_current_stream().ptr, + handle=cp.cuda.device.get_cublas_handle(), + ) + + assert cp.max(cp.abs(r_c - r_g)).item() < 1e-3 + assert cp.abs(ll_c - ll_g).item() < 1e-3 + assert cp.max(cp.abs(r_c - r_b)).item() < 1e-3 + assert cp.abs(ll_c - ll_b).item() < 1e-3 + + +def test_cuda_768_e_step_uses_cublas_route(): + rng = cp.random.RandomState(8) + n, d, K = 64, 768, 2 + X = rng.standard_normal((n, d), dtype=cp.float32) + weights = cp.asarray([0.45, 0.55], dtype=cp.float32) + means = rng.standard_normal((K, d), dtype=cp.float32) + eye = cp.eye(d, dtype=cp.float32) + cov = cp.stack((eye * 1.5, eye * 2.0)) + prec_chol, log_det_half = _precision_cholesky(cov) + + log_prob, resp, ll_per_cell, centered, e_step_y = _e_step_buffers(X, K, "cublas") + stream = cp.cuda.get_current_stream().ptr + handle = cp.cuda.device.get_cublas_handle() + r_c, ll_c = 
_reference_e_step(X, weights, means, prec_chol, log_det_half) + r_g, ll_g = _e_step( + X, + weights, + means, + prec_chol, + log_det_half, + log_prob, + resp, + ll_per_cell, + centered, + e_step_y, + e_step_route="cublas", + stream=stream, + handle=handle, + ) + log_prob_b, resp_b, ll_per_cell_b, centered_b, e_step_y_b = _e_step_buffers( + X, K, "cublas" + ) + r_b, ll_b = _e_step_cublas( + X, + weights, + means, + prec_chol, + log_det_half, + centered_b, + e_step_y_b, + log_prob_b, + resp_b, + ll_per_cell_b, + stream=stream, + handle=handle, + ) + + assert _choose_e_step(d, X.dtype) == "cublas" + assert cp.max(cp.abs(r_c - r_g)).item() < 1e-3 + assert cp.abs(ll_c - ll_g).item() < 1e-3 + assert cp.max(cp.abs(r_c - r_b)).item() < 1e-3 + assert cp.abs(ll_c - ll_b).item() < 1e-3 + + +def test_cuda_float64_wide_e_step_uses_cublas_route(): + rng = cp.random.RandomState(7) + n, d, K = 256, 128, 3 + X = rng.standard_normal((n, d), dtype=cp.float64) + weights = cp.asarray([0.2, 0.3, 0.5], dtype=cp.float64) + means = rng.standard_normal((K, d), dtype=cp.float64) + A = rng.standard_normal((K, d, d), dtype=cp.float64) + cov = (A @ A.transpose(0, 2, 1)) / d + cp.eye(d, dtype=cp.float64)[None] * 0.5 + prec_chol, log_det_half = _precision_cholesky(cov) + + route = _choose_e_step(d, X.dtype) + log_prob, resp, ll_per_cell, centered, e_step_y = _e_step_buffers(X, K, route) + r_c, ll_c = _reference_e_step(X, weights, means, prec_chol, log_det_half) + r_g, ll_g = _e_step( + X, + weights, + means, + prec_chol, + log_det_half, + log_prob, + resp, + ll_per_cell, + centered, + e_step_y, + e_step_route=route, + stream=cp.cuda.get_current_stream().ptr, + handle=cp.cuda.device.get_cublas_handle(), + ) + + assert cp.max(cp.abs(r_c - r_g)).item() < 1e-12 + assert cp.abs(ll_c - ll_g).item() < 1e-12 + + +def test_cuda_fixed_e_step_matches_reference_for_medium_regime(): + rng = cp.random.RandomState(4) + n, d, K = 1024, 16, 8 + X = rng.standard_normal((n, d), dtype=cp.float32) + weights = 
cp.full(K, 1 / K, dtype=cp.float32) + means = rng.standard_normal((K, d), dtype=cp.float32) + A = rng.standard_normal((K, d, d), dtype=cp.float32) + cov = (A @ A.transpose(0, 2, 1)) / d + cp.eye(d, dtype=cp.float32)[None] * 0.5 + prec_chol, log_det_half = _precision_cholesky(cov) + + r_c, ll_c = _reference_e_step(X, weights, means, prec_chol, log_det_half) + r_g, ll_g = _cuda_e_step(X, weights, means, prec_chol, log_det_half) + + assert cp.max(cp.abs(r_c - r_g)).item() < 5e-4 + assert cp.abs(ll_c - ll_g).item() < 5e-4 + + +def test_cuda_fused_e_step_matches_reference_for_50_pc_regime(): + rng = cp.random.RandomState(6) + n, d, K = 1024, 50, 12 + X = rng.standard_normal((n, d), dtype=cp.float32) + weights = cp.full(K, 1 / K, dtype=cp.float32) + means = rng.standard_normal((K, d), dtype=cp.float32) + A = rng.standard_normal((K, d, d), dtype=cp.float32) + cov = (A @ A.transpose(0, 2, 1)) / d + cp.eye(d, dtype=cp.float32)[None] * 0.5 + prec_chol, log_det_half = _precision_cholesky(cov) + + log_prob, resp, ll_per_cell, centered, e_step_y = _e_step_buffers(X, K, "fused") + r_c, ll_c = _reference_e_step(X, weights, means, prec_chol, log_det_half) + r_d, ll_d = _e_step( + X, + weights, + means, + prec_chol, + log_det_half, + log_prob, + resp, + ll_per_cell, + centered, + e_step_y, + e_step_route="fused", + stream=cp.cuda.get_current_stream().ptr, + handle=cp.cuda.device.get_cublas_handle(), + ) + + assert cp.max(cp.abs(r_c - r_d)).item() < 5e-4 + assert cp.abs(ll_c - ll_d).item() < 5e-4 + + +def test_cuda_runs_large_feature_count(): + rng = np.random.default_rng(3) + X = cp.asarray(rng.standard_normal((360, 80)).astype(np.float32)) + labels = gmm_fit_predict( + X, + n_components=3, + random_state=0, + max_iter=2, + reg_covar=1e-2, + init="random_from_data", + ) + + assert labels.shape == (360,) + assert labels.dtype == cp.int32 diff --git a/tests/test_niche.py b/tests/test_niche.py new file mode 100644 index 00000000..fbee6226 --- /dev/null +++ b/tests/test_niche.py @@ 
-0,0 +1,381 @@ +from __future__ import annotations + +import inspect +from pathlib import Path + +import cupy as cp +import numpy as np +import pandas as pd +import pytest +from anndata import read_h5ad +from cupyx.scipy import sparse as sparse_gpu +from scipy import sparse + +from rapids_singlecell.gr import calculate_niche +from rapids_singlecell.squidpy_gpu._niche import ( + _neighborhood_profile, + _utag_features, +) + +DATA = Path(__file__).parent / "_data" / "dummy.h5ad" +SPATIAL_CONNECTIVITIES_KEY = "spatial_connectivities" +GROUPS = "cluster" + + +@pytest.fixture +def adata(): + a = read_h5ad(DATA) + # _neighborhood_profile uses pd.Categorical on this column + a.obs[GROUPS] = pd.Categorical(a.obs[GROUPS]) + return a + + +# -- semantic tests adapted from squidpy/tests/graph/test_niche.py (BSD-3) -- + + +def test_niche_calc_nhood(adata): + """Adapted from squidpy: profile shape, normalization, min_niche_size labels.""" + calculate_niche( + adata, + flavor="neighborhood", + groups=GROUPS, + n_neighbors=10, + resolutions=[0.1], + min_niche_size=20, + ) + niches = adata.obs["nhood_niche_res=0.1"] + + # no NaNs, more cells in real niches than in 'not_a_niche' + assert niches.isna().sum() == 0 + assert len(niches[niches != "not_a_niche"]) > len(niches[niches == "not_a_niche"]) + for label in niches.unique(): + if label != "not_a_niche": + assert (niches == label).sum() >= 20 + + # profile shape + n_cats = len(adata.obs[GROUPS].cat.categories) + rel = cp.asnumpy( + _neighborhood_profile( + adata, + groups=GROUPS, + distance=1, + weights=None, + abs_nhood=False, + key=SPATIAL_CONNECTIVITIES_KEY, + ) + ) + abs_ = cp.asnumpy( + _neighborhood_profile( + adata, + groups=GROUPS, + distance=1, + weights=None, + abs_nhood=True, + key=SPATIAL_CONNECTIVITIES_KEY, + ) + ) + assert rel.shape == (adata.n_obs, n_cats) + assert abs_.shape == rel.shape + + # relative profile: each row sums to 1 when the cell has neighbors (all do here), + # so total sum == n_obs and the per-row 
max sum is 1. + np.testing.assert_allclose(rel.sum(axis=1).sum(), adata.n_obs, atol=1e-4) + assert rel.sum(axis=1).max() == pytest.approx(1.0, abs=1e-5) + + # absolute profile: per-row sum equals that cell's degree in the spatial graph + deg = np.asarray((adata.obsp[SPATIAL_CONNECTIVITIES_KEY] != 0).sum(axis=1)).ravel() + np.testing.assert_array_equal(abs_.sum(axis=1).astype(int), deg) + + +def test_niche_calc_utag(adata): + """Adapted from squidpy: utag output shape, sparsity, sensitivity to graph.""" + calculate_niche(adata, flavor="utag", n_neighbors=10, resolutions=[0.1, 1.0]) + + niches_high = adata.obs["utag_niche_res=1.0"] + niches_low = adata.obs["utag_niche_res=0.1"] + assert niches_high.isna().sum() == 0 + # higher resolution → strictly more (or at least as many) clusters + assert niches_high.nunique() >= niches_low.nunique() + + # output shape matches X (returns cupy.ndarray for dense X) + feat = _utag_features(adata, SPATIAL_CONNECTIVITIES_KEY) + assert feat.shape == adata.X.shape + + # sparsity preserved when input X is sparse (returns cupyx sparse) + a_sparse = adata.copy() + a_sparse.X = sparse.csr_matrix(adata.X) + feat_sparse = _utag_features(a_sparse, SPATIAL_CONNECTIVITIES_KEY) + assert sparse_gpu.issparse(feat_sparse) + assert feat_sparse.shape == adata.X.shape + + # different spatial graph structure → different feature matrix + # (uniform value scaling is invisible after row-normalization, so we drop edges) + a2 = adata.copy() + G = a2.obsp[SPATIAL_CONNECTIVITIES_KEY].tolil() + G[0, :] = 0 + G[1, :] = 0 + G = G.tocsr() + G.eliminate_zeros() + a2.obsp[SPATIAL_CONNECTIVITIES_KEY] = G + feat2 = _utag_features(a2, SPATIAL_CONNECTIVITIES_KEY) + assert not cp.allclose(feat, feat2) + + +# -- additional rsc-specific tests -- + + +@pytest.mark.parametrize("flavor", ["neighborhood", "utag"]) +def test_basic_runs_inplace(adata, flavor): + kw = {"groups": GROUPS} if flavor == "neighborhood" else {} + out = calculate_niche(adata, flavor=flavor, 
n_neighbors=10, resolutions=0.5, **kw) + assert out is None + prefix = "nhood_niche" if flavor == "neighborhood" else "utag_niche" + col = f"{prefix}_res=0.5" + assert col in adata.obs.columns + assert isinstance(adata.obs[col].dtype, pd.CategoricalDtype) + + +def test_copy_returns_new_object(adata): + before = list(adata.obs.columns) + out = calculate_niche( + adata, + flavor="neighborhood", + groups=GROUPS, + n_neighbors=10, + resolutions=0.5, + copy=True, + ) + assert out is not None + assert "nhood_niche_res=0.5" in out.obs.columns + assert list(adata.obs.columns) == before + + +def test_multiple_resolutions(adata): + calculate_niche( + adata, + flavor="neighborhood", + groups=GROUPS, + n_neighbors=10, + resolutions=[0.3, 0.7], + ) + assert "nhood_niche_res=0.3" in adata.obs.columns + assert "nhood_niche_res=0.7" in adata.obs.columns + + +def test_n_hop_neighbors(adata): + calculate_niche( + adata, + flavor="neighborhood", + groups=GROUPS, + n_neighbors=10, + resolutions=0.5, + distance=3, + n_hop_weights=[1.0, 0.5, 0.25], + ) + assert "nhood_niche_res=0.5" in adata.obs.columns + + +def test_min_niche_size_relabels_all(adata): + """min_niche_size > n_obs should send every cell to 'not_a_niche'.""" + calculate_niche( + adata, + flavor="neighborhood", + groups=GROUPS, + n_neighbors=10, + resolutions=2.0, + min_niche_size=adata.n_obs + 1, + ) + labels = adata.obs["nhood_niche_res=2.0"].astype(str) + assert (labels == "not_a_niche").all() + + +def test_determinism_same_seed(adata): + a1, a2 = adata.copy(), adata.copy() + calculate_niche( + a1, + flavor="neighborhood", + groups=GROUPS, + n_neighbors=10, + resolutions=0.5, + random_state=42, + ) + calculate_niche( + a2, + flavor="neighborhood", + groups=GROUPS, + n_neighbors=10, + resolutions=0.5, + random_state=42, + ) + np.testing.assert_array_equal( + a1.obs["nhood_niche_res=0.5"].astype(str).values, + a2.obs["nhood_niche_res=0.5"].astype(str).values, + ) + + +def test_determinism_utag_same_seed(adata): + a1, a2 = 
adata.copy(), adata.copy() + calculate_niche(a1, flavor="utag", n_neighbors=10, resolutions=0.5, random_state=7) + calculate_niche(a2, flavor="utag", n_neighbors=10, resolutions=0.5, random_state=7) + np.testing.assert_array_equal( + a1.obs["utag_niche_res=0.5"].astype(str).values, + a2.obs["utag_niche_res=0.5"].astype(str).values, + ) + + +def test_unknown_flavor_raises(adata): + with pytest.raises(ValueError, match="Unknown flavor"): + calculate_niche(adata, flavor="bogus", n_neighbors=10, resolutions=0.5) + + +def test_neighborhood_requires_groups(adata): + with pytest.raises(ValueError, match="`groups` is required"): + calculate_niche(adata, flavor="neighborhood", n_neighbors=10, resolutions=0.5) + + +def test_groups_not_in_obs_raises(adata): + with pytest.raises(KeyError): + calculate_niche( + adata, + flavor="neighborhood", + groups="missing_col", + n_neighbors=10, + resolutions=0.5, + ) + + +def test_missing_connectivities_raises(adata): + del adata.obsp["spatial_connectivities"] + with pytest.raises(KeyError, match="spatial_connectivities"): + calculate_niche( + adata, flavor="neighborhood", groups=GROUPS, n_neighbors=10, resolutions=0.5 + ) + + +def test_invalid_distance_raises(adata): + with pytest.raises(ValueError, match="distance"): + calculate_niche( + adata, + flavor="neighborhood", + groups=GROUPS, + n_neighbors=10, + resolutions=0.5, + distance=0, + ) + + +def test_custom_connectivity_key(adata): + adata.obsp["my_graph"] = adata.obsp["spatial_connectivities"] + del adata.obsp["spatial_connectivities"] + calculate_niche( + adata, + flavor="neighborhood", + groups=GROUPS, + n_neighbors=10, + resolutions=0.5, + spatial_connectivities_key="my_graph", + ) + assert "nhood_niche_res=0.5" in adata.obs.columns + + +# -- cellcharter flavor tests -- + + +def test_cellcharter_basic(adata): + calculate_niche(adata, flavor="cellcharter", n_components=4) + assert "cellcharter_niche" in adata.obs.columns + col = adata.obs["cellcharter_niche"] + assert 
isinstance(col.dtype, pd.CategoricalDtype) + assert col.isna().sum() == 0 + assert col.nunique() <= 4 + + +def test_calculate_niche_exposes_only_gmm_init_extension(): + params = inspect.signature(calculate_niche).parameters + assert "init" not in params + assert "kmeans_n_init" not in params + assert "gmm_init" in params + assert params["gmm_init"].default == "random_from_data" + assert params["random_state"].default == 42 + + +def test_cellcharter_distance_zero(adata): + """distance=0 falls back to PCA + GMM on raw X (no shell aggregation).""" + calculate_niche(adata, flavor="cellcharter", n_components=3, distance=0) + assert adata.obs["cellcharter_niche"].nunique() <= 3 + + +def test_cellcharter_use_rep(adata): + """use_rep skips shell-aggregation and PCA; uses adata.obsm[key] directly.""" + rng = np.random.default_rng(0) + adata.obsm["X_test"] = rng.standard_normal((adata.n_obs, 10)).astype(np.float32) + calculate_niche(adata, flavor="cellcharter", n_components=4, use_rep="X_test") + assert "cellcharter_niche" in adata.obs.columns + + +@pytest.mark.parametrize("gmm_init", ["random_from_data", "kmeans", "sklearn_kmeans"]) +def test_cellcharter_gmm_init_options(adata, gmm_init): + rng = np.random.default_rng(0) + adata.obsm["X_test"] = rng.standard_normal((adata.n_obs, 10)).astype(np.float32) + calculate_niche( + adata, + flavor="cellcharter", + n_components=4, + use_rep="X_test", + gmm_init=gmm_init, + random_state=0, + ) + assert "cellcharter_niche" in adata.obs.columns + + +def test_cellcharter_determinism(adata): + a1 = adata.copy() + a2 = adata.copy() + calculate_niche(a1, flavor="cellcharter", n_components=4, random_state=42) + calculate_niche(a2, flavor="cellcharter", n_components=4, random_state=42) + np.testing.assert_array_equal( + a1.obs["cellcharter_niche"].astype(str).values, + a2.obs["cellcharter_niche"].astype(str).values, + ) + + +def test_cellcharter_variance(adata): + """`aggregation="variance"` runs and produces a categorical column.""" + 
calculate_niche(adata, flavor="cellcharter", n_components=4, aggregation="variance") + assert "cellcharter_niche" in adata.obs.columns + assert isinstance(adata.obs["cellcharter_niche"].dtype, pd.CategoricalDtype) + + +def test_cellcharter_invalid_aggregation(adata): + with pytest.raises(ValueError, match="aggregation"): + calculate_niche( + adata, flavor="cellcharter", n_components=4, aggregation="bogus" + ) + + +def test_cellcharter_invalid_gmm_init(adata): + with pytest.raises(ValueError, match="gmm_init"): + calculate_niche(adata, flavor="cellcharter", n_components=4, gmm_init="bogus") + + +def test_cellcharter_bad_n_components(adata): + with pytest.raises(ValueError, match="n_components"): + calculate_niche(adata, flavor="cellcharter", n_components=0) + + +def test_cellcharter_missing_use_rep(adata): + with pytest.raises(KeyError): + calculate_niche( + adata, flavor="cellcharter", n_components=4, use_rep="not_there" + ) + + +def test_cellcharter_use_rep_too_few_dims(adata): + adata.obsm["X_small"] = np.zeros((adata.n_obs, 3), dtype=np.float32) + with pytest.raises(ValueError, match="at least"): + calculate_niche(adata, flavor="cellcharter", n_components=10, use_rep="X_small") + + +def test_cellcharter_invalid_distance_negative(adata): + with pytest.raises(ValueError, match="distance"): + calculate_niche(adata, flavor="cellcharter", n_components=4, distance=-1)