From bc5127333d7881c47dde6c36872b554d1dd8fee0 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Fri, 24 Apr 2026 15:38:28 +0200 Subject: [PATCH 1/7] first iteration of refactor --- .gitignore | 2 + CMakeLists.txt | 1 + src/rapids_singlecell/_cuda/nb_types.h | 7 + .../_cuda/wilcoxon/kernels_wilcoxon.cuh | 400 +++++++ .../_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh | 976 ++++++++++++++++++ .../_cuda/wilcoxon/wilcoxon.cu | 91 ++ .../_cuda/wilcoxon/wilcoxon_fast_common.cuh | 345 +++++++ .../wilcoxon/wilcoxon_ovo_device_sparse.cuh | 518 ++++++++++ .../wilcoxon/wilcoxon_ovo_host_sparse.cuh | 853 +++++++++++++++ .../_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh | 182 ++++ .../_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh | 104 ++ .../_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh | 861 +++++++++++++++ .../_cuda/wilcoxon/wilcoxon_sparse.cu | 292 ++++++ .../wilcoxon/wilcoxon_sparse_kernels.cuh | 651 ++++++++++++ .../tools/_rank_genes_groups/__init__.py | 182 +++- .../tools/_rank_genes_groups/_core.py | 340 ++++-- .../tools/_rank_genes_groups/_utils.py | 54 +- .../tools/_rank_genes_groups/_wilcoxon.py | 966 +++++++++++++++-- .../_rank_genes_groups/_wilcoxon_binned.py | 16 +- tests/test_rank_genes_groups_ttest.py | 10 +- tests/test_rank_genes_groups_wilcoxon.py | 401 ++++++- 21 files changed, 7054 insertions(+), 198 deletions(-) create mode 100644 src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh create mode 100644 src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh create mode 100644 src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh create mode 100644 src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh create mode 100644 src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh create mode 100644 src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh create mode 100644 src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh create mode 100644 src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu create mode 100644 src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh diff --git a/.gitignore b/.gitignore index c0e83438..6994e147 100644 --- a/.gitignore +++ b/.gitignore @@ -47,6 +47,8 @@ coverage.xml .cursor/ .claude/ CLAUDE.md +.codex # tmp_scripts tmp_scripts/ +benchmarks/ diff --git a/CMakeLists.txt b/CMakeLists.txt index cacf9849..85d33e91 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -85,6 +85,7 @@ if (RSC_BUILD_EXTENSIONS) add_nb_cuda_module(_hvg_cuda src/rapids_singlecell/_cuda/hvg/hvg.cu) add_nb_cuda_module(_kde_cuda src/rapids_singlecell/_cuda/kde/kde.cu) add_nb_cuda_module(_wilcoxon_cuda src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu) + add_nb_cuda_module(_wilcoxon_sparse_cuda src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu) # Harmony CUDA modules add_nb_cuda_module(_harmony_scatter_cuda src/rapids_singlecell/_cuda/harmony/scatter/scatter.cu) add_nb_cuda_module(_harmony_outer_cuda src/rapids_singlecell/_cuda/harmony/outer/outer.cu) diff --git a/src/rapids_singlecell/_cuda/nb_types.h b/src/rapids_singlecell/_cuda/nb_types.h index 905e1e07..eb343815 100644 --- a/src/rapids_singlecell/_cuda/nb_types.h +++ b/src/rapids_singlecell/_cuda/nb_types.h @@ -42,6 +42,13 @@ using gpu_array = nb::ndarray; template using gpu_array_contig = nb::ndarray; +// Host (NumPy) array aliases +template +using host_array = nb::ndarray>; + +template +using host_array_2d = nb::ndarray>; + // Register bindings for both regular CUDA and managed-memory arrays. // Usage: // template diff --git a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh index c89d913a..8b6af5f6 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh @@ -2,6 +2,27 @@ #include +__device__ __forceinline__ double wilcoxon_block_sum(double val, + double* warp_buf) { +#pragma unroll + for (int off = 16; off > 0; off >>= 1) + val += __shfl_down_sync(0xffffffff, val, off); + int lane = threadIdx.x & 31; + int wid = threadIdx.x >> 5; + if (lane == 0) warp_buf[wid] = val; + __syncthreads(); + if (threadIdx.x < 32) { + double v = (threadIdx.x < ((blockDim.x + 31) >> 5)) + ? warp_buf[threadIdx.x] + : 0.0; +#pragma unroll + for (int off = 16; off > 0; off >>= 1) + v += __shfl_down_sync(0xffffffff, v, off); + return v; + } + return 0.0; +} + /** * Kernel to compute tie correction factor for Wilcoxon test. * Formula: tc = 1 - sum(t^3 - t) / (n^3 - n) where t is the count of tied @@ -142,3 +163,382 @@ __global__ void average_rank_kernel(const double* __restrict__ sorted_vals, rk[si[i]] = avg_rank; } } + +/** + * OVO dense rank core. + * + * ref_sorted is F-order and sorted independently for every column. + * grp_data is F-order and contains test-group rows concatenated by + * grp_offsets. One block computes one (column, test-group) result. + * + * This intentionally centralizes the OVO math; host/device and CSR/CSC/dense + * paths only need to materialize bounded dense column batches that feed this + * kernel. + */ +__global__ void ovo_rank_dense_kernel(const float* __restrict__ ref_sorted, + const float* __restrict__ grp_data, + const int* __restrict__ grp_offsets, + double* __restrict__ rank_sums, + double* __restrict__ tie_corr, int n_ref, + int n_all_grp, int n_cols, int n_groups, + bool compute_tie_corr) { + int col = blockIdx.x; + int grp = blockIdx.y; + if (col >= n_cols || grp >= n_groups) return; + + int g_start = grp_offsets[grp]; + int g_end = grp_offsets[grp + 1]; + int n_grp = g_end - g_start; + + const float* ref_col = ref_sorted + (long long)col * n_ref; + const float* grp_col = grp_data + (long long)col * n_all_grp + g_start; + + __shared__ double warp_buf[32]; + double local_rank = 0.0; + + for (int i = threadIdx.x; i < n_grp; i += blockDim.x) { + float v = grp_col[i]; + + int lo = 0, hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] < v) + lo = m + 1; + else + hi = m; + } + int n_lt_ref = lo; + + hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int n_eq_ref = lo - n_lt_ref; + + int n_lt_grp = 0; + int n_eq_grp = 0; + for (int j = 0; j < n_grp; ++j) { + float u = grp_col[j]; + n_lt_grp += (u < v); + n_eq_grp += (u == v); + } + + local_rank += (double)(n_lt_ref + n_lt_grp) + + ((double)(n_eq_ref + n_eq_grp) + 1.0) / 2.0; + } + + double total_rank = wilcoxon_block_sum(local_rank, warp_buf); + if (threadIdx.x == 0) { + rank_sums[(size_t)grp * n_cols + col] = total_rank; + } + + if (!compute_tie_corr) return; + __syncthreads(); + + double local_tie = 0.0; + + for (int i = threadIdx.x; i < n_ref; i += blockDim.x) { + if (i == 0 || ref_col[i] != ref_col[i - 1]) { + float v = ref_col[i]; + int lo = i + 1, hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int count = lo - i; + for (int j = 0; j < n_grp; ++j) count += (grp_col[j] == v); + if (count > 1) { + double t = (double)count; + local_tie += t * t * t - t; + } + } + } + + for (int i = threadIdx.x; i < n_grp; i += blockDim.x) { + float v = grp_col[i]; + bool seen_in_group = false; + for (int j = 0; j < i; ++j) { + if (grp_col[j] == v) { + seen_in_group = true; + break; + } + } + if (seen_in_group) continue; + + int lo = 0, hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] < v) + lo = m + 1; + else + hi = m; + } + if (lo < n_ref && ref_col[lo] == v) continue; + + int count = 0; + for (int j = 0; j < n_grp; ++j) count += (grp_col[j] == v); + if (count > 1) { + double t = (double)count; + local_tie += t * t * t - t; + } + } + + double tie_sum = wilcoxon_block_sum(local_tie, warp_buf); + if (threadIdx.x == 0) { + int n = n_ref + n_grp; + double dn = (double)n; + double denom = dn * dn * dn - dn; + tie_corr[(size_t)grp * n_cols + col] = + (denom > 0.0) ? (1.0 - tie_sum / denom) : 1.0; + } +} + +__global__ void ovo_rank_presorted_kernel(const float* __restrict__ ref_sorted, + const float* __restrict__ grp_sorted, + const int* __restrict__ grp_offsets, + double* __restrict__ rank_sums, + double* __restrict__ tie_corr, + int n_ref, int n_all_grp, int n_cols, + int n_groups, bool compute_tie_corr) { + int col = blockIdx.x; + int grp = blockIdx.y; + if (col >= n_cols || grp >= n_groups) return; + + int g_start = grp_offsets[grp]; + int g_end = grp_offsets[grp + 1]; + int n_grp = g_end - g_start; + + const float* ref_col = ref_sorted + (long long)col * n_ref; + const float* grp_col = grp_sorted + (long long)col * n_all_grp + g_start; + + __shared__ double warp_buf[32]; + double local_rank = 0.0; + + int ref_lb = 0, ref_ub = 0; + int grp_lb = 0, grp_ub = 0; + for (int i = threadIdx.x; i < n_grp; i += blockDim.x) { + float v = grp_col[i]; + + int lo = ref_lb, hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] < v) + lo = m + 1; + else + hi = m; + } + int n_lt_ref = lo; + ref_lb = n_lt_ref; + + lo = (ref_ub > n_lt_ref) ? ref_ub : n_lt_ref; + hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int n_eq_ref = lo - n_lt_ref; + ref_ub = lo; + + lo = grp_lb; + hi = n_grp; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (grp_col[m] < v) + lo = m + 1; + else + hi = m; + } + int n_lt_grp = lo; + grp_lb = n_lt_grp; + + lo = (grp_ub > n_lt_grp) ? grp_ub : n_lt_grp; + hi = n_grp; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (grp_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int n_eq_grp = lo - n_lt_grp; + grp_ub = lo; + + local_rank += (double)(n_lt_ref + n_lt_grp) + + ((double)(n_eq_ref + n_eq_grp) + 1.0) / 2.0; + } + + double total_rank = wilcoxon_block_sum(local_rank, warp_buf); + if (threadIdx.x == 0) { + rank_sums[(size_t)grp * n_cols + col] = total_rank; + } + + if (!compute_tie_corr) return; + __syncthreads(); + + double local_tie = 0.0; + int grp_lb_tie = 0, grp_ub_tie = 0; + for (int i = threadIdx.x; i < n_ref; i += blockDim.x) { + if (i == 0 || ref_col[i] != ref_col[i - 1]) { + float v = ref_col[i]; + int lo = i + 1, hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int cnt_ref = lo - i; + + lo = grp_lb_tie; + hi = n_grp; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (grp_col[m] < v) + lo = m + 1; + else + hi = m; + } + int lb = lo; + grp_lb_tie = lb; + + lo = (grp_ub_tie > lb) ? grp_ub_tie : lb; + hi = n_grp; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (grp_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int cnt_grp = lo - lb; + grp_ub_tie = lo; + + int cnt = cnt_ref + cnt_grp; + if (cnt > 1) { + double t = (double)cnt; + local_tie += t * t * t - t; + } + } + } + + int ref_lb_tie = 0; + for (int i = threadIdx.x; i < n_grp; i += blockDim.x) { + if (i == 0 || grp_col[i] != grp_col[i - 1]) { + float v = grp_col[i]; + int lo = ref_lb_tie, hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] < v) + lo = m + 1; + else + hi = m; + } + ref_lb_tie = lo; + if (lo < n_ref && ref_col[lo] == v) continue; + + lo = i + 1; + hi = n_grp; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (grp_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int cnt = lo - i; + if (cnt > 1) { + double t = (double)cnt; + local_tie += t * t * t - t; + } + } + } + + double tie_sum = wilcoxon_block_sum(local_tie, warp_buf); + if (threadIdx.x == 0) { + int n = n_ref + n_grp; + double dn = (double)n; + double denom = dn * dn * dn - dn; + tie_corr[(size_t)grp * n_cols + col] = + (denom > 0.0) ? (1.0 - tie_sum / denom) : 1.0; + } +} + +/** + * OVR dense rank core. + * + * sorted_vals and sorter are F-order outputs of sorting each column of the + * current dense block. The kernel directly accumulates rank sums per group, + * avoiding a full ranks matrix and a group one-hot matrix multiply. + */ +__global__ void ovr_rank_dense_kernel(const float* __restrict__ sorted_vals, + const int* __restrict__ sorter, + const int* __restrict__ group_codes, + double* __restrict__ rank_sums, + double* __restrict__ tie_corr, int n_rows, + int n_cols, int n_groups, + bool compute_tie_corr) { + int col = blockIdx.x; + if (col >= n_cols) return; + + const float* sv = sorted_vals + (long long)col * n_rows; + const int* si = sorter + (long long)col * n_rows; + + double local_tie = 0.0; + for (int i = threadIdx.x; i < n_rows; i += blockDim.x) { + float val = sv[i]; + + int lo = 0, hi = i; + while (lo < hi) { + int mid = lo + ((hi - lo) >> 1); + if (sv[mid] < val) + lo = mid + 1; + else + hi = mid; + } + int tie_start = lo; + + lo = i; + hi = n_rows - 1; + while (lo < hi) { + int mid = lo + ((hi - lo + 1) >> 1); + if (sv[mid] > val) + hi = mid - 1; + else + lo = mid; + } + int tie_end = lo; + double avg_rank = (double)(tie_start + tie_end + 2) / 2.0; + + int row = si[i]; + int group = group_codes[row]; + if (group >= 0 && group < n_groups) { + atomicAdd(&rank_sums[(size_t)group * n_cols + col], avg_rank); + } + + if (compute_tie_corr && i == tie_end) { + double t = (double)(tie_end - tie_start + 1); + if (t > 1.0) local_tie += t * t * t - t; + } + } + + if (!compute_tie_corr) return; + + __shared__ double warp_buf[32]; + double tie_sum = wilcoxon_block_sum(local_tie, warp_buf); + if (threadIdx.x == 0) { + double n = (double)n_rows; + double denom = n * n * n - n; + tie_corr[col] = (denom > 0.0) ? (1.0 - tie_sum / denom) : 1.0; + } +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh new file mode 100644 index 00000000..5b4c0b8c --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh @@ -0,0 +1,976 @@ +#pragma once + +#include + +// ============================================================================ +// Warp reduction helper (sum doubles across block via warp_buf) +// ============================================================================ + +__device__ __forceinline__ double block_reduce_sum(double val, + double* warp_buf) { +#pragma unroll + for (int off = 16; off > 0; off >>= 1) + val += __shfl_down_sync(0xffffffff, val, off); + int lane = threadIdx.x & 31; + int wid = threadIdx.x >> 5; + if (lane == 0) warp_buf[wid] = val; + __syncthreads(); + if (threadIdx.x < 32) { + double v2 = (threadIdx.x < ((blockDim.x + 31) >> 5)) + ? warp_buf[threadIdx.x] + : 0.0; +#pragma unroll + for (int off = 16; off > 0; off >>= 1) + v2 += __shfl_down_sync(0xffffffff, v2, off); + return v2; // only lane 0 of warp 0 has the final result + } + return 0.0; +} + +// ============================================================================ +// Parallel tie correction — all threads collaborate. +// +// For each unique value in the combined sorted (ref, grp) arrays, accumulate +// t^3 - t where t = count of that value. Uses two passes: +// 1. Iterate unique values in ref_col, count in both arrays. +// 2. Iterate unique values in grp_col that do NOT appear in ref_col. +// +// Incremental binary search bounds exploit monotonicity within each thread's +// stride to reduce total search work. +// +// Caller must __syncthreads() before calling. warp_buf is reused for +// reduction (32 doubles, shared memory). +// ============================================================================ + +__device__ __forceinline__ void compute_tie_correction_parallel( + const float* ref_col, int n_ref, const float* grp_col, int n_grp, + double* warp_buf, double* out) { + double local_tie = 0.0; + + // Pass 1: unique values in ref_col + int grp_lb = 0, grp_ub = 0; + for (int i = threadIdx.x; i < n_ref; i += blockDim.x) { + if (i == 0 || ref_col[i] != ref_col[i - 1]) { + float v = ref_col[i]; + + // Count in ref: upper_bound from i+1 + int lo = i + 1, hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int cnt_ref = lo - i; + + // Count in grp: incremental lower/upper bound + lo = grp_lb; + hi = n_grp; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (grp_col[m] < v) + lo = m + 1; + else + hi = m; + } + int lb = lo; + grp_lb = lb; + + lo = (grp_ub > lb) ? grp_ub : lb; + hi = n_grp; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (grp_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int cnt_grp = lo - lb; + grp_ub = lo; + + int cnt = cnt_ref + cnt_grp; + if (cnt > 1) { + double t = (double)cnt; + local_tie += t * t * t - t; + } + } + } + + // Pass 2: unique values in grp_col that are absent from ref_col + int ref_lb = 0; + for (int i = threadIdx.x; i < n_grp; i += blockDim.x) { + if (i == 0 || grp_col[i] != grp_col[i - 1]) { + float v = grp_col[i]; + + // Incremental lower_bound in ref + int lo = ref_lb, hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] < v) + lo = m + 1; + else + hi = m; + } + ref_lb = lo; + + if (lo >= n_ref || ref_col[lo] != v) { + // Value not in ref — count in grp only (upper_bound from i+1) + lo = i + 1; + hi = n_grp; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (grp_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int cnt = lo - i; + if (cnt > 1) { + double t = (double)cnt; + local_tie += t * t * t - t; + } + } + } + } + + // Block-wide reduction + double tie_sum = block_reduce_sum(local_tie, warp_buf); + if (threadIdx.x == 0) { + int n = n_ref + n_grp; + double dn = (double)n; + double denom = dn * dn * dn - dn; + *out = (denom > 0.0) ? (1.0 - tie_sum / denom) : 1.0; + } +} + +// ============================================================================ +// Batched rank sums — pre-sorted (binary search, no shared memory sort) +// Used by the OVO streaming pipeline in wilcoxon_streaming.cu. +// +// Incremental binary search: each thread carries forward lower/upper bound +// positions across loop iterations, exploiting the monotonicity of the +// sorted grp_col values within each thread's stride. +// ============================================================================ + +__global__ void batched_rank_sums_presorted_kernel( + const float* __restrict__ ref_sorted, const float* __restrict__ grp_sorted, + const int* __restrict__ grp_offsets, double* __restrict__ rank_sums, + double* __restrict__ tie_corr, int n_ref, int n_all_grp, int n_cols, + int n_groups, bool compute_tie_corr, int skip_n_grp_le /*= 0*/) { + int col = blockIdx.x; + int grp = blockIdx.y; + if (col >= n_cols || grp >= n_groups) return; + + int g_start = grp_offsets[grp]; + int g_end = grp_offsets[grp + 1]; + int n_grp = g_end - g_start; + + // Size-gated dispatch (see ovo_fused_sort_rank_kernel for the contract). + if (n_grp <= skip_n_grp_le) return; + + if (n_grp == 0) { + if (threadIdx.x == 0) { + rank_sums[grp * n_cols + col] = 0.0; + if (compute_tie_corr) tie_corr[grp * n_cols + col] = 1.0; + } + return; + } + + const float* ref_col = ref_sorted + (long long)col * n_ref; + const float* grp_col = grp_sorted + (long long)col * n_all_grp + g_start; + + // Incremental binary search bounds (advance monotonically per thread) + int ref_lb = 0, ref_ub = 0; + int grp_lb = 0, grp_ub = 0; + double local_sum = 0.0; + + for (int i = threadIdx.x; i < n_grp; i += blockDim.x) { + float v = grp_col[i]; + int lo, hi; + + // Lower bound in ref (from ref_lb) + lo = ref_lb; + hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] < v) + lo = m + 1; + else + hi = m; + } + int n_lt_ref = lo; + ref_lb = n_lt_ref; + + // Upper bound in ref (from max(ref_ub, n_lt_ref)) + lo = (ref_ub > n_lt_ref) ? ref_ub : n_lt_ref; + hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int n_eq_ref = lo - n_lt_ref; + ref_ub = lo; + + // Lower bound in grp (from grp_lb) + lo = grp_lb; + hi = n_grp; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (grp_col[m] < v) + lo = m + 1; + else + hi = m; + } + int n_lt_grp = lo; + grp_lb = n_lt_grp; + + // Upper bound in grp (from max(grp_ub, n_lt_grp)) + lo = (grp_ub > n_lt_grp) ? grp_ub : n_lt_grp; + hi = n_grp; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (grp_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int n_eq_grp = lo - n_lt_grp; + grp_ub = lo; + + local_sum += (double)(n_lt_ref + n_lt_grp) + + ((double)(n_eq_ref + n_eq_grp) + 1.0) / 2.0; + } + + __shared__ double warp_buf[32]; + double total = block_reduce_sum(local_sum, warp_buf); + if (threadIdx.x == 0) rank_sums[grp * n_cols + col] = total; + + if (!compute_tie_corr) return; + __syncthreads(); + + compute_tie_correction_parallel(ref_col, n_ref, grp_col, n_grp, warp_buf, + &tie_corr[grp * n_cols + col]); +} + +// ============================================================================ +// Tier 1 fused kernel: smem bitonic sort + binary search rank sums +// For small groups (< ~2K cells). No CUB, no global memory sort buffers. +// Grid: (n_cols, n_groups), Block: min(padded_grp_size, 512) +// Shared memory: padded_grp_size floats + 32 doubles (warp reduction) +// ============================================================================ + +__global__ void ovo_fused_sort_rank_kernel( + const float* __restrict__ ref_sorted, // F-order (n_ref, n_cols) sorted + const float* __restrict__ grp_dense, // F-order (n_all_grp, n_cols) + // unsorted + const int* __restrict__ grp_offsets, // (n_groups + 1,) + double* __restrict__ rank_sums, // (n_groups, n_cols) row-major + double* __restrict__ tie_corr, // (n_groups, n_cols) row-major + int n_ref, int n_all_grp, int n_cols, int n_groups, bool compute_tie_corr, + int padded_grp_size, int skip_n_grp_le /*= 0*/) { + int col = blockIdx.x; + int grp = blockIdx.y; + if (col >= n_cols || grp >= n_groups) return; + + int g_start = grp_offsets[grp]; + int g_end = grp_offsets[grp + 1]; + int n_grp = g_end - g_start; + + // Size-gated dispatch: when co-launched with the Tier 0 warp kernel we + // skip groups it's already handling. Each group owns its own + // rank_sums row, so the two kernels' writes never alias. + if (n_grp <= skip_n_grp_le) return; + + if (n_grp == 0) { + if (threadIdx.x == 0) { + rank_sums[grp * n_cols + col] = 0.0; + if (compute_tie_corr) tie_corr[grp * n_cols + col] = 1.0; + } + return; + } + + // Shared memory: [padded_grp_size floats | 32 doubles for warp reduction] + extern __shared__ char smem_raw[]; + float* grp_smem = (float*)smem_raw; + double* warp_buf = (double*)(smem_raw + padded_grp_size * sizeof(float)); + + // Load group data into shared memory, pad with +INF + const float* grp_col = grp_dense + (long long)col * n_all_grp + g_start; + for (int i = threadIdx.x; i < n_grp; i += blockDim.x) + grp_smem[i] = grp_col[i]; + for (int i = n_grp + threadIdx.x; i < padded_grp_size; i += blockDim.x) + grp_smem[i] = __int_as_float(0x7f800000); // +INF + __syncthreads(); + + // Bitonic sort in shared memory + for (int k = 2; k <= padded_grp_size; k <<= 1) { + for (int j = k >> 1; j > 0; j >>= 1) { + for (int i = threadIdx.x; i < padded_grp_size; i += blockDim.x) { + int ixj = i ^ j; + if (ixj > i) { + bool asc = ((i & k) == 0); + float a = grp_smem[i], b = grp_smem[ixj]; + if (asc ? (a > b) : (a < b)) { + grp_smem[i] = b; + grp_smem[ixj] = a; + } + } + } + __syncthreads(); + } + } + + // Binary search each sorted grp element against sorted ref + // Incremental bounds: values are monotonic within each thread's stride + const float* ref_col = ref_sorted + (long long)col * n_ref; + int ref_lb = 0, ref_ub = 0; + int grp_lb = 0, grp_ub = 0; + double local_sum = 0.0; + + for (int i = threadIdx.x; i < n_grp; i += blockDim.x) { + float v = grp_smem[i]; + int lo, hi; + + lo = ref_lb; + hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] < v) + lo = m + 1; + else + hi = m; + } + int n_lt_ref = lo; + ref_lb = n_lt_ref; + + lo = (ref_ub > n_lt_ref) ? ref_ub : n_lt_ref; + hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int n_eq_ref = lo - n_lt_ref; + ref_ub = lo; + + lo = grp_lb; + hi = n_grp; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (grp_smem[m] < v) + lo = m + 1; + else + hi = m; + } + int n_lt_grp = lo; + grp_lb = n_lt_grp; + + lo = (grp_ub > n_lt_grp) ? grp_ub : n_lt_grp; + hi = n_grp; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (grp_smem[m] <= v) + lo = m + 1; + else + hi = m; + } + int n_eq_grp = lo - n_lt_grp; + grp_ub = lo; + + local_sum += (double)(n_lt_ref + n_lt_grp) + + ((double)(n_eq_ref + n_eq_grp) + 1.0) / 2.0; + } + + // Block reduction → write rank_sums + double total = block_reduce_sum(local_sum, warp_buf); + if (threadIdx.x == 0) rank_sums[grp * n_cols + col] = total; + + if (!compute_tie_corr) return; + __syncthreads(); + + // Parallel tie correction (grp_smem is sorted shared memory) + compute_tie_correction_parallel(ref_col, n_ref, grp_smem, n_grp, warp_buf, + &tie_corr[grp * n_cols + col]); +} + +// ============================================================================ +// Tier 2 helper: tie contribution of the sorted reference alone. +// One block per column. The medium unsorted-rank kernel uses this as a base +// and only adds group-only/overlap deltas from the unsorted group values. +// ============================================================================ + +__global__ void ref_tie_sum_kernel(const float* __restrict__ ref_sorted, + double* __restrict__ ref_tie_sums, int n_ref, + int n_cols) { + int col = blockIdx.x; + if (col >= n_cols) return; + const float* ref_col = ref_sorted + (long long)col * n_ref; + + double local_tie = 0.0; + for (int i = threadIdx.x; i < n_ref; i += blockDim.x) { + if (i == 0 || ref_col[i] != ref_col[i - 1]) { + float v = ref_col[i]; + int lo = i + 1, hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int cnt = lo - i; + if (cnt > 1) { + double t = (double)cnt; + local_tie += t * t * t - t; + } + } + } + + __shared__ double warp_buf[32]; + double total = block_reduce_sum(local_tie, warp_buf); + if (threadIdx.x == 0) ref_tie_sums[col] = total; +} + +__global__ void ovo_small64_sort_rank_kernel( + const float* __restrict__ ref_sorted, const float* __restrict__ grp_dense, + const int* __restrict__ grp_offsets, + const double* __restrict__ ref_tie_sums, double* __restrict__ rank_sums, + double* __restrict__ tie_corr, int n_ref, int n_all_grp, int n_cols, + int n_groups, bool compute_tie_corr, int skip_n_grp_le) { + int col = blockIdx.x; + int grp = blockIdx.y; + if (col >= n_cols || grp >= n_groups) return; + + int g_start = grp_offsets[grp]; + int g_end = grp_offsets[grp + 1]; + int n_grp = g_end - g_start; + if (n_grp <= skip_n_grp_le || n_grp > TIER0_64_GROUP_THRESHOLD) return; + + __shared__ float grp_smem[TIER0_64_GROUP_THRESHOLD]; + __shared__ double warp_buf[WARP_REDUCE_BUF]; + + const float* grp_col = grp_dense + (long long)col * n_all_grp + g_start; + const float POS_INF = __int_as_float(0x7f800000); + if (threadIdx.x < TIER0_64_GROUP_THRESHOLD) { + grp_smem[threadIdx.x] = + (threadIdx.x < n_grp) ? grp_col[threadIdx.x] : POS_INF; + } + __syncthreads(); + + for (int k = 2; k <= TIER0_64_GROUP_THRESHOLD; k <<= 1) { + for (int j = k >> 1; j > 0; j >>= 1) { + int i = threadIdx.x; + int ixj = i ^ j; + if (i < TIER0_64_GROUP_THRESHOLD && ixj > i) { + bool asc = ((i & k) == 0); + float a = grp_smem[i], b = grp_smem[ixj]; + if (asc ? (a > b) : (a < b)) { + grp_smem[i] = b; + grp_smem[ixj] = a; + } + } + __syncthreads(); + } + } + + const float* ref_col = ref_sorted + (long long)col * n_ref; + double local_sum = 0.0; + double local_tie_delta = 0.0; + + if (threadIdx.x < n_grp) { + float v = grp_smem[threadIdx.x]; + + int lo = 0, hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] < v) + lo = m + 1; + else + hi = m; + } + int n_lt_ref = lo; + hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int n_eq_ref = lo - n_lt_ref; + + lo = 0; + hi = n_grp; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (grp_smem[m] < v) + lo = m + 1; + else + hi = m; + } + int n_lt_grp = lo; + hi = n_grp; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (grp_smem[m] <= v) + lo = m + 1; + else + hi = m; + } + int n_eq_grp = lo - n_lt_grp; + + local_sum += (double)(n_lt_ref + n_lt_grp) + + ((double)(n_eq_ref + n_eq_grp) + 1.0) / 2.0; + + if (compute_tie_corr && + (threadIdx.x == 0 || v != grp_smem[threadIdx.x - 1])) { + double combined = (double)(n_eq_ref + n_eq_grp); + if (combined > 1.0) { + local_tie_delta += combined * combined * combined - combined; + } + if (n_eq_ref > 1) { + double cr = (double)n_eq_ref; + local_tie_delta -= cr * cr * cr - cr; + } + } + } + + double total = block_reduce_sum(local_sum, warp_buf); + if (threadIdx.x == 0) rank_sums[grp * n_cols + col] = total; + + if (!compute_tie_corr) return; + __syncthreads(); + + double tie_delta = block_reduce_sum(local_tie_delta, warp_buf); + if (threadIdx.x == 0) { + int n = n_ref + n_grp; + double dn = (double)n; + double denom = dn * dn * dn - dn; + double tie_sum = ref_tie_sums[col] + tie_delta; + tie_corr[grp * n_cols + col] = + (denom > 0.0) ? (1.0 - tie_sum / denom) : 1.0; + } +} + +// ============================================================================ +// Tier 2 fused kernel: no-sort direct rank for medium groups. +// +// Avoids the smem bitonic sort for groups in (skip_n_grp_le, +// max_n_grp_le]. Ranks are computed from ref binary searches plus an +// in-group scan over unsorted shared values. Tie correction starts from +// ref_tie_sums[col] and adds only group-only / ref-overlap deltas. +// ============================================================================ + +__global__ void ovo_medium_unsorted_rank_kernel( + const float* __restrict__ ref_sorted, const float* __restrict__ grp_dense, + const int* __restrict__ grp_offsets, + const double* __restrict__ ref_tie_sums, double* __restrict__ rank_sums, + double* __restrict__ tie_corr, int n_ref, int n_all_grp, int n_cols, + int n_groups, bool compute_tie_corr, int skip_n_grp_le, int max_n_grp_le) { + int col = blockIdx.x; + int grp = blockIdx.y; + if (col >= n_cols || grp >= n_groups) return; + + int g_start = grp_offsets[grp]; + int g_end = grp_offsets[grp + 1]; + int n_grp = g_end - g_start; + if (n_grp <= skip_n_grp_le || n_grp > max_n_grp_le) return; + + extern __shared__ char smem_raw[]; + float* grp_smem = (float*)smem_raw; + double* warp_buf = (double*)(smem_raw + max_n_grp_le * sizeof(float)); + + const float* grp_col = grp_dense + (long long)col * n_all_grp + g_start; + for (int i = threadIdx.x; i < n_grp; i += blockDim.x) + grp_smem[i] = grp_col[i]; + __syncthreads(); + + const float* ref_col = ref_sorted + (long long)col * n_ref; + double local_sum = 0.0; + double local_tie_delta = 0.0; + + for (int i = threadIdx.x; i < n_grp; i += blockDim.x) { + float v = grp_smem[i]; + + int lo = 0, hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] < v) + lo = m + 1; + else + hi = m; + } + int n_lt_ref = lo; + + hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int n_eq_ref = lo - n_lt_ref; + + int n_lt_grp = 0; + int n_eq_grp = 0; + bool first_in_grp = true; + for (int j = 0; j < n_grp; ++j) { + float w = grp_smem[j]; + if (w < v) ++n_lt_grp; + if (w == v) { + ++n_eq_grp; + if (j < i) first_in_grp = false; + } + } + + local_sum += (double)(n_lt_ref + n_lt_grp) + + ((double)(n_eq_ref + n_eq_grp) + 1.0) / 2.0; + + if (compute_tie_corr && first_in_grp) { + double cg = (double)n_eq_grp; + double cr = (double)n_eq_ref; + double group_tie = (cg > 1.0) ? (cg * cg * cg - cg) : 0.0; + local_tie_delta += group_tie; + if (cr > 0.0) { + double combined = cr + cg; + double ref_tie = (cr > 1.0) ? (cr * cr * cr - cr) : 0.0; + local_tie_delta += combined * combined * combined - combined - + ref_tie - group_tie; + } + } + } + + double total = block_reduce_sum(local_sum, warp_buf); + if (threadIdx.x == 0) rank_sums[grp * n_cols + col] = total; + + if (!compute_tie_corr) return; + __syncthreads(); + + double tie_delta = block_reduce_sum(local_tie_delta, warp_buf); + if (threadIdx.x == 0) { + int n = n_ref + n_grp; + double dn = (double)n; + double denom = dn * dn * dn - dn; + double tie_sum = ref_tie_sums[col] + tie_delta; + tie_corr[grp * n_cols + col] = + (denom > 0.0) ? (1.0 - tie_sum / denom) : 1.0; + } +} + +// ============================================================================ +// Warp-scoped tie correction for Tier 0. +// +// Sorted values live in a 32-lane register (one per lane, with unused lanes +// carrying +INF). Walks unique values via lane-step differentials and +// counts ties across the sorted ref column via binary search. All the +// sync is __syncwarp — no smem, no __syncthreads. +// ============================================================================ + +__device__ __forceinline__ double tier0_tie_sum_warp(const float* ref_col, + int n_ref, float v_lane, + int n_grp, + unsigned int active_mask) { + int lane = threadIdx.x & 31; + double local_tie = 0.0; + + // Pass 1: for each unique value in ref_col, count occurrences in ref and + // in the sorted group (held in register v_lane across 32 lanes). + for (int base = 0; base < n_ref; base += 32) { + int i = base + lane; + bool in_ref_lane = (i < n_ref); + float v = in_ref_lane ? ref_col[i] : 0.0f; + bool is_first = in_ref_lane && ((i == 0) || (v != ref_col[i - 1])); + int cnt_ref = 0; + if (is_first) { + // Count in ref: upper_bound from i+1 + int lo = i + 1, hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] <= v) + lo = m + 1; + else + hi = m; + } + cnt_ref = lo - i; + } + + // Count in grp: look up how many lanes hold v_lane == v. All lanes + // execute the shuffle loop; only lanes owning a unique ref value use + // the result. + int cnt_grp = 0; +#pragma unroll + for (int lane_i = 0; lane_i < TIER0_GROUP_THRESHOLD; ++lane_i) { + float vi = __shfl_sync(0xffffffff, v_lane, lane_i); + if (is_first && lane_i < n_grp && vi == v) ++cnt_grp; + } + + if (is_first) { + int cnt = cnt_ref + cnt_grp; + if (cnt > 1) { + double t = (double)cnt; + local_tie += t * t * t - t; + } + } + } + + // Pass 2: unique values in grp that are absent from ref. + // Walk lanes 0..n_grp-1; for each lane whose v differs from prev lane's, + // binary-search ref for v. If not present, count consecutive matching + // lanes (tie block). + if (lane < n_grp) { + float v = v_lane; + float prev_lane_v = + __shfl_sync(active_mask, v_lane, (lane > 0) ? lane - 1 : 0); + float v_prev = + (lane > 0) ? prev_lane_v : __int_as_float(0xff800000); // -INF + bool first_in_grp = (lane == 0) || (v != v_prev); + bool in_ref = false; + if (first_in_grp) { + // Binary search in ref. + int lo = 0, hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] < v) + lo = m + 1; + else + hi = m; + } + in_ref = (lo < n_ref) && (ref_col[lo] == v); + } + + // Count how many lanes ≥ this lane hold the same v. Keep the shuffle + // uniform across active lanes even though only unique, ref-absent + // group values consume the count. + int cnt = 0; +#pragma unroll + for (int lane_i = 0; lane_i < TIER0_GROUP_THRESHOLD; ++lane_i) { + int src_lane = (lane_i < n_grp) ? lane_i : 0; + float vi = __shfl_sync(active_mask, v_lane, src_lane); + if (first_in_grp && !in_ref && lane_i >= lane && lane_i < n_grp && + vi == v) { + ++cnt; + } + } + if (first_in_grp && !in_ref && cnt > 1) { + double t = (double)cnt; + local_tie += t * t * t - t; + } + } + + // Warp reduce. +#pragma unroll + for (int off = 16; off > 0; off >>= 1) + local_tie += __shfl_down_sync(0xffffffff, local_tie, off); + return local_tie; // meaningful on lane 0. +} + +__device__ __forceinline__ double tier0_tie_delta_warp( + const float* ref_col, int n_ref, float v_lane, int n_grp, + unsigned int active_mask) { + int lane = threadIdx.x & 31; + double local_delta = 0.0; + + if (lane < n_grp) { + float v = v_lane; + float prev_lane_v = + __shfl_sync(active_mask, v_lane, (lane > 0) ? lane - 1 : 0); + float v_prev = + (lane > 0) ? prev_lane_v : __int_as_float(0xff800000); // -INF + bool first_in_grp = (lane == 0) || (v != v_prev); + + int cnt_grp = 0; +#pragma unroll + for (int lane_i = 0; lane_i < TIER0_GROUP_THRESHOLD; ++lane_i) { + int src_lane = (lane_i < n_grp) ? lane_i : 0; + float vi = __shfl_sync(active_mask, v_lane, src_lane); + if (lane_i < n_grp && vi == v) ++cnt_grp; + } + + if (first_in_grp) { + int lo = 0, hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] < v) + lo = m + 1; + else + hi = m; + } + int ref_lb = lo; + hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int cnt_ref = lo - ref_lb; + + double combined = (double)(cnt_ref + cnt_grp); + if (combined > 1.0) { + local_delta += combined * combined * combined - combined; + } + if (cnt_ref > 1) { + double cr = (double)cnt_ref; + local_delta -= cr * cr * cr - cr; + } + } + } + +#pragma unroll + for (int off = 16; off > 0; off >>= 1) + local_delta += __shfl_down_sync(0xffffffff, local_delta, off); + return local_delta; // meaningful on lane 0. +} + +// ============================================================================ +// Tier 0 fused kernel: warp-per-(col, group) pair, 8 warps packed per block. +// +// Each warp independently: +// 1. Loads ≤ 32 group values into a single register (one per lane, +// padded with +INF). +// 2. Bitonic-sorts via __shfl_xor_sync — no smem, no __syncthreads. +// 3. Binary-searches into sorted ref for each lane's value and +// accumulates the rank-sum term. +// 4. Warp-shuffle reduces to lane 0 and writes rank_sums / tie_corr. +// +// 8 (col, group) pairs per block cuts block count 8× vs the block-per-pair +// Tier 1, and the lack of __syncthreads / smem sort lets each warp run +// independently at full throughput. +// +// Grid: (n_cols, ceil(n_groups / 8)), Block: 256. +// ============================================================================ + +__global__ void ovo_warp_sort_rank_kernel( + const float* __restrict__ ref_sorted, const float* __restrict__ grp_dense, + const int* __restrict__ grp_offsets, + const double* __restrict__ ref_tie_sums, double* __restrict__ rank_sums, + double* __restrict__ tie_corr, int n_ref, int n_all_grp, int n_cols, + int n_groups, bool compute_tie_corr) { + constexpr int WARPS_PER_BLOCK = 8; + int warp_id = threadIdx.x >> 5; + int lane = threadIdx.x & 31; + + int col = blockIdx.x; + int grp = blockIdx.y * WARPS_PER_BLOCK + warp_id; + if (col >= n_cols || grp >= n_groups) return; + + int g_start = grp_offsets[grp]; + int g_end = grp_offsets[grp + 1]; + int n_grp = g_end - g_start; + + // This kernel only handles groups that fit in a single warp (one value + // per lane). Larger groups are delegated to Tier 1/3 in a co-launched + // kernel; since each group owns its own row in rank_sums/tie_corr, the + // two kernels interlace into the output without conflict. + if (n_grp > TIER0_GROUP_THRESHOLD) return; + + if (n_grp == 0) { + if (lane == 0) { + rank_sums[grp * n_cols + col] = 0.0; + if (compute_tie_corr) tie_corr[grp * n_cols + col] = 1.0; + } + return; + } + + // One value per lane, pad with +INF so sort pushes them to the end. + const float POS_INF = __int_as_float(0x7f800000); + const float* grp_col = grp_dense + (long long)col * n_all_grp + g_start; + float x = (lane < n_grp) ? grp_col[lane] : POS_INF; + unsigned int active_mask = __ballot_sync(0xffffffff, lane < n_grp); + + // Warp-shuffle bitonic sort (ascending) — 32 elements in registers. + for (int k = 1; k <= 16; k <<= 1) { + for (int j = k; j > 0; j >>= 1) { + float y = __shfl_xor_sync(0xffffffff, x, j); + bool asc = (((lane & (k << 1)) == 0)); + bool take_min = (((lane & j) == 0) == asc); + x = take_min ? fminf(x, y) : fmaxf(x, y); + } + } + + // After sort, x[lane] holds the lane-th smallest group value (lanes + // ≥ n_grp hold +INF). Binary-search each value into the sorted ref. + const float* ref_col = ref_sorted + (long long)col * n_ref; + double local_sum = 0.0; + + if (lane < n_grp) { + float v = x; + // Lower bound in ref. + int lo = 0, hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] < v) + lo = m + 1; + else + hi = m; + } + int n_lt_ref = lo; + // Upper bound in ref. + hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int n_eq_ref = lo - n_lt_ref; + + // In-group counts: in the sorted warp-register x, count lanes < this + // one that hold strictly less, and lanes with equal value. + int n_lt_grp = 0; + int n_eq_grp_offset = 0; // tied lanes strictly before this one + int n_eq_grp_after = 1; // count self +#pragma unroll + for (int lane_i = 0; lane_i < TIER0_GROUP_THRESHOLD; ++lane_i) { + if (lane_i >= n_grp) continue; + float vi = __shfl_sync(active_mask, v, lane_i); + if (lane_i < lane) { + if (vi < v) + ++n_lt_grp; + else if (vi == v) + ++n_eq_grp_offset; + } else if (lane_i > lane) { + if (vi == v) ++n_eq_grp_after; + } + } + int n_eq_grp_total = n_eq_grp_offset + n_eq_grp_after; + // Contribution: rank = n_lt_ref + n_lt_grp + (n_eq_ref + + // n_eq_grp_total + 1) / 2, but we sum per lane so each tie lane + // gets the same mid-rank. This matches the Tier 1 accumulation. + local_sum = (double)(n_lt_ref + n_lt_grp) + + ((double)(n_eq_ref + n_eq_grp_total) + 1.0) / 2.0; + } + + // Warp reduce. +#pragma unroll + for (int off = 16; off > 0; off >>= 1) + local_sum += __shfl_down_sync(0xffffffff, local_sum, off); + if (lane == 0) rank_sums[grp * n_cols + col] = local_sum; + + if (!compute_tie_corr) return; + + // Warp-scoped tie correction. + double tie_sum; + if (ref_tie_sums != nullptr) { + tie_sum = ref_tie_sums[col] + + tier0_tie_delta_warp(ref_col, n_ref, x, n_grp, active_mask); + } else { + tie_sum = tier0_tie_sum_warp(ref_col, n_ref, x, n_grp, active_mask); + } + if (lane == 0) { + int n = n_ref + n_grp; + double dn = (double)n; + double denom = dn * dn * dn - dn; + tie_corr[grp * n_cols + col] = + (denom > 0.0) ? (1.0 - tie_sum / denom) : 1.0; + } +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu index d25f7d0f..0ab5b26c 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu @@ -8,6 +8,7 @@ using namespace nb::literals; // Constants for kernel launch configuration constexpr int WARP_SIZE = 32; constexpr int MAX_THREADS_PER_BLOCK = 512; +constexpr int OVO_THREADS_PER_BLOCK = 256; static inline int round_up_to_warp(int n) { int rounded = ((n + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; @@ -37,6 +38,43 @@ static inline void launch_average_rank(const double* sorted_vals, CUDA_CHECK_LAST_ERROR(average_rank_kernel); } +static inline void launch_ovo_rank_dense( + const float* ref_sorted, const float* grp_data, const int* grp_offsets, + double* rank_sums, double* tie_corr, int n_ref, int n_all_grp, int n_cols, + int n_groups, bool compute_tie_corr, cudaStream_t stream) { + dim3 block(OVO_THREADS_PER_BLOCK); + dim3 grid(n_cols, n_groups); + ovo_rank_dense_kernel<<>>( + ref_sorted, grp_data, grp_offsets, rank_sums, tie_corr, n_ref, + n_all_grp, n_cols, n_groups, compute_tie_corr); + CUDA_CHECK_LAST_ERROR(ovo_rank_dense_kernel); +} + +static inline void launch_ovo_rank_presorted( + const float* ref_sorted, const float* grp_sorted, const int* grp_offsets, + double* rank_sums, double* tie_corr, int n_ref, int n_all_grp, int n_cols, + int n_groups, bool compute_tie_corr, cudaStream_t stream) { + dim3 block(OVO_THREADS_PER_BLOCK); + dim3 grid(n_cols, n_groups); + ovo_rank_presorted_kernel<<>>( + ref_sorted, grp_sorted, grp_offsets, rank_sums, tie_corr, n_ref, + n_all_grp, n_cols, n_groups, compute_tie_corr); + CUDA_CHECK_LAST_ERROR(ovo_rank_presorted_kernel); +} + +static inline void launch_ovr_rank_dense( + const float* sorted_vals, const int* sorter, const int* group_codes, + double* rank_sums, double* tie_corr, int n_rows, int n_cols, int n_groups, + bool compute_tie_corr, cudaStream_t stream) { + int threads_per_block = round_up_to_warp(n_rows); + dim3 block(threads_per_block); + dim3 grid(n_cols); + ovr_rank_dense_kernel<<>>( + sorted_vals, sorter, group_codes, rank_sums, tie_corr, n_rows, n_cols, + n_groups, compute_tie_corr); + CUDA_CHECK_LAST_ERROR(ovr_rank_dense_kernel); +} + template void register_bindings(nb::module_& m) { m.doc() = "CUDA kernels for Wilcoxon rank-sum test"; @@ -65,6 +103,59 @@ void register_bindings(nb::module_& m) { }, "sorted_vals"_a, "sorter"_a, "ranks"_a, nb::kw_only(), "n_rows"_a, "n_cols"_a, "stream"_a = 0); + + m.def( + "ovo_rank_dense", + [](gpu_array_f ref_sorted, + gpu_array_f grp_data, + gpu_array_c grp_offsets, + gpu_array_c rank_sums, + gpu_array_c tie_corr, int n_ref, int n_all_grp, + int n_cols, int n_groups, bool compute_tie_corr, + std::uintptr_t stream) { + launch_ovo_rank_dense( + ref_sorted.data(), grp_data.data(), grp_offsets.data(), + rank_sums.data(), tie_corr.data(), n_ref, n_all_grp, n_cols, + n_groups, compute_tie_corr, (cudaStream_t)stream); + }, + "ref_sorted"_a, "grp_data"_a, "grp_offsets"_a, "rank_sums"_a, + "tie_corr"_a, nb::kw_only(), "n_ref"_a, "n_all_grp"_a, "n_cols"_a, + "n_groups"_a, "compute_tie_corr"_a, "stream"_a = 0); + + m.def( + "ovo_rank_presorted", + [](gpu_array_f ref_sorted, + gpu_array_f grp_sorted, + gpu_array_c grp_offsets, + gpu_array_c rank_sums, + gpu_array_c tie_corr, int n_ref, int n_all_grp, + int n_cols, int n_groups, bool compute_tie_corr, + std::uintptr_t stream) { + launch_ovo_rank_presorted( + ref_sorted.data(), grp_sorted.data(), grp_offsets.data(), + rank_sums.data(), tie_corr.data(), n_ref, n_all_grp, n_cols, + n_groups, compute_tie_corr, (cudaStream_t)stream); + }, + "ref_sorted"_a, "grp_sorted"_a, "grp_offsets"_a, "rank_sums"_a, + "tie_corr"_a, nb::kw_only(), "n_ref"_a, "n_all_grp"_a, "n_cols"_a, + "n_groups"_a, "compute_tie_corr"_a, "stream"_a = 0); + + m.def( + "ovr_rank_dense", + [](gpu_array_f sorted_vals, + gpu_array_f sorter, + gpu_array_c group_codes, + gpu_array_c rank_sums, + gpu_array_c tie_corr, int n_rows, int n_cols, + int n_groups, bool compute_tie_corr, std::uintptr_t stream) { + launch_ovr_rank_dense(sorted_vals.data(), sorter.data(), + group_codes.data(), rank_sums.data(), + tie_corr.data(), n_rows, n_cols, n_groups, + compute_tie_corr, (cudaStream_t)stream); + }, + "sorted_vals"_a, "sorter"_a, "group_codes"_a, "rank_sums"_a, + "tie_corr"_a, nb::kw_only(), "n_rows"_a, "n_cols"_a, "n_groups"_a, + "compute_tie_corr"_a, "stream"_a = 0); } NB_MODULE(_wilcoxon_cuda, m) { diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh new file mode 100644 index 00000000..dd50d2cb --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh @@ -0,0 +1,345 @@ +#pragma once + +#include +#include +#include +#include + +#include + +#include "../nb_types.h" // for CUDA_CHECK_LAST_ERROR + +constexpr int WARP_SIZE = 32; +constexpr int MAX_THREADS_PER_BLOCK = 512; +constexpr int N_STREAMS = 4; +constexpr int SUB_BATCH_COLS = 64; +constexpr int BEGIN_BIT = 0; +constexpr int END_BIT = 32; +// Default thread-per-block for utility kernels (extract, gather, offsets, +// etc.). +constexpr int UTIL_BLOCK_SIZE = 256; +// Scratch slots for warp-level reduction (one slot per warp, 32 warps max). +constexpr int WARP_REDUCE_BUF = 32; +// Max group size for the super-fast "warp-per-(col,group)" fused kernel +// (Tier 0). Each warp sorts and ranks one (col, group) pair entirely in +// registers via warp-shuffle bitonic sort — no smem sort buffer, no +// __syncthreads(). Blocks pack 8 warps so block launch overhead is +// amortised 8× across (col, group) work items. This path is the fast +// route for per-celltype perturbation-style workloads where most test +// groups have only a few dozen cells. +constexpr int TIER0_GROUP_THRESHOLD = 32; +// Second small-group tier for perturbation workloads where most groups are +// slightly larger than one warp. Uses one compact shared-memory sort block per +// (column, group), avoiding the heavier Tier 2 in-group scan. +constexpr int TIER0_64_GROUP_THRESHOLD = 64; +// Medium-group cutoff for the unsorted direct-rank kernel. For perturbation +// workloads most groups sit below this range, where avoiding a full smem +// bitonic sort wins despite the O(n^2) in-group count. +constexpr int TIER2_GROUP_THRESHOLD = 512; +// Max group size for the fused smem-sort rank kernel (Tier 1 fast path). +// Beyond this, fall back to CUB segmented sort + binary-search rank kernel. +constexpr int TIER1_GROUP_THRESHOLD = 2500; +// Per-stream dense slab budget (float32 items). Dynamic sub-batching sizes +// each group's column batch so that (n_g × eff_sb_cols) ≤ this. Bigger = +// fewer kernel launches; smaller = less per-stream memory. 128M items × 4B = +// 512 MB per stream dense slab + same for sorted copy ≈ 1 GB / stream. +constexpr size_t GROUP_DENSE_BUDGET_ITEMS = 128 * 1024 * 1024; + +// --------------------------------------------------------------------------- +// RAII guard for cudaHostRegister. Unregisters on scope exit even when an +// exception unwinds — prevents leaked host pinning on stream-sync failures. +// --------------------------------------------------------------------------- +struct HostRegisterGuard { + void* ptr = nullptr; + + HostRegisterGuard() = default; + HostRegisterGuard(void* p, size_t bytes, unsigned int flags = 0) { + if (p && bytes > 0) { + cudaError_t err = cudaHostRegister(p, bytes, flags); + if (err != cudaSuccess) { + // Already-registered memory is fine; anything else means the + // subsequent kernels would read garbage from an unmapped + // pointer, so surface the error immediately. + if (err == cudaErrorHostMemoryAlreadyRegistered) { + cudaGetLastError(); // clear sticky error flag + } else { + throw std::runtime_error( + std::string("cudaHostRegister failed (") + + std::to_string((size_t)bytes) + + " bytes, flags=" + std::to_string(flags) + + "): " + cudaGetErrorString(err)); + } + } else { + ptr = p; + } + } + } + ~HostRegisterGuard() { + if (ptr) cudaHostUnregister(ptr); + } + HostRegisterGuard(const HostRegisterGuard&) = delete; + HostRegisterGuard& operator=(const HostRegisterGuard&) = delete; + HostRegisterGuard(HostRegisterGuard&& other) noexcept : ptr(other.ptr) { + other.ptr = nullptr; + } + HostRegisterGuard& operator=(HostRegisterGuard&& other) noexcept { + if (this != &other) { + if (ptr) cudaHostUnregister(ptr); + ptr = other.ptr; + other.ptr = nullptr; + } + return *this; + } +}; + +// --------------------------------------------------------------------------- +// Small allocation pool for temporary CUDA buffers. The previous PR used RMM +// here, but these sparse Wilcoxon kernels only need scoped scratch memory; +// using cudaMalloc keeps this module independent of an extra build-time +// dependency. +// --------------------------------------------------------------------------- +struct RmmPool { + std::vector bufs; + + ~RmmPool() { + for (void* ptr : bufs) { + if (ptr) cudaFree(ptr); + } + } + + template + T* alloc(size_t count) { + if (count == 0) count = 1; + void* ptr = nullptr; + cudaError_t err = cudaMalloc(&ptr, count * sizeof(T)); + if (err != cudaSuccess) { + throw std::runtime_error( + std::string("cudaMalloc failed in Wilcoxon scratch pool: ") + + cudaGetErrorString(err)); + } + bufs.push_back(ptr); + return static_cast(ptr); + } +}; + +struct ScopedCudaBuffer { + void* ptr = nullptr; + + explicit ScopedCudaBuffer(size_t bytes) { + if (bytes == 0) bytes = 1; + cudaError_t err = cudaMalloc(&ptr, bytes); + if (err != cudaSuccess) { + throw std::runtime_error( + std::string("cudaMalloc failed in Wilcoxon scoped buffer: ") + + cudaGetErrorString(err)); + } + } + + ~ScopedCudaBuffer() { + if (ptr) cudaFree(ptr); + } + + void* data() { + return ptr; + } + + ScopedCudaBuffer(const ScopedCudaBuffer&) = delete; + ScopedCudaBuffer& operator=(const ScopedCudaBuffer&) = delete; +}; + +static inline int round_up_to_warp(int n) { + int rounded = ((n + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; + return (rounded < MAX_THREADS_PER_BLOCK) ? rounded : MAX_THREADS_PER_BLOCK; +} + +/** Fill linear segment offsets [0, stride, 2*stride, ..., n_segments*stride] + * on-device. One thread per output slot. */ +__global__ void fill_linear_offsets_kernel(int* __restrict__ out, + int n_segments, int stride) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i <= n_segments) out[i] = i * stride; +} + +/** Fill per-row stats codes for a pack of K groups. + * Given pack_grp_offsets (size K+1, relative to pack start), write + * stats_codes[r] = base_slot + group_idx_of_row_r for r in [0, pack_n_rows). + * Binary search within the K+1 offsets. */ +__global__ void fill_pack_stats_codes_kernel( + const int* __restrict__ pack_grp_offsets, int* __restrict__ stats_codes, + int K, int base_slot) { + int r = blockIdx.x * blockDim.x + threadIdx.x; + int pack_n_rows = pack_grp_offsets[K]; + if (r >= pack_n_rows) return; + int lo = 0, hi = K; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (pack_grp_offsets[m + 1] <= r) + lo = m + 1; + else + hi = m; + } + stats_codes[r] = base_slot + lo; +} + +/** Rebase a slice of indptr: out[i] = indptr[col + i] - indptr[col]. + * Grid-strided: supports arbitrary `count` (no single-block thread limit). + * Templated so that 64-bit global indptrs can produce 32-bit pack-local + * indptrs (per-pack nnz always fits in int32 thanks to the memory budget). + */ +template +__global__ void rebase_indptr_kernel(const IdxIn* __restrict__ indptr, + IdxOut* __restrict__ out, int col, + int count) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < count) out[i] = (IdxOut)(indptr[col + i] - indptr[col]); +} + +/** Fused gather + cast-to-float32 + stats accumulation, reading from mapped + * pinned host memory. Block-per-row; threads in the block cooperate on the + * row's nnz. Each nnz is read from host over PCIe exactly once — no + * intermediate native-dtype GPU buffer, no second GPU pass. + * + * h_data / h_indices: device-accessible pointers into mapped pinned host + * memory (cudaHostRegisterMapped). + * d_indptr_full: full-matrix indptr on device. + * d_row_ids: rows to gather (size n_target_rows). + * d_out_indptr: pre-computed compacted indptr, size n_target_rows+1 with + * out_indptr[i+1] - out_indptr[i] equal to the source row's + * nnz. + * + * Slot dispatch: + * d_stats_codes != nullptr → slot = d_stats_codes[r]; otherwise slot = + * fixed_slot (used for the Ref phase where every row maps to the same + * slot). slot ∉ [0, n_groups_stats) skips accumulation. + */ +template +__global__ void csr_gather_cast_accumulate_mapped_kernel( + const InT* __restrict__ h_data, const IndexT* __restrict__ h_indices, + const IndptrT* __restrict__ d_indptr_full, + const int* __restrict__ d_row_ids, const int* __restrict__ d_out_indptr, + const int* __restrict__ d_stats_codes, int fixed_slot, + float* __restrict__ d_out_data_f32, int* __restrict__ d_out_indices, + double* __restrict__ group_sums, double* __restrict__ group_sq_sums, + double* __restrict__ group_nnz, int n_target_rows, int n_cols, + int n_groups_stats, bool compute_sums, bool compute_sq_sums, + bool compute_nnz) { + int r = blockIdx.x; + if (r >= n_target_rows) return; + int src_row = d_row_ids[r]; + IndptrT rs = d_indptr_full[src_row]; + IndptrT re = d_indptr_full[src_row + 1]; + int row_nnz = (int)(re - rs); + int ds = d_out_indptr[r]; + int slot = (d_stats_codes != nullptr) ? d_stats_codes[r] : fixed_slot; + bool accumulate = (slot >= 0 && slot < n_groups_stats); + for (int i = threadIdx.x; i < row_nnz; i += blockDim.x) { + InT v_in = h_data[rs + i]; + int c = (int)h_indices[rs + i]; + double v = (double)v_in; + d_out_data_f32[ds + i] = (float)v_in; + d_out_indices[ds + i] = c; + if (accumulate) { + if (compute_sums) { + atomicAdd(&group_sums[(size_t)slot * n_cols + c], v); + } + if (compute_sq_sums) { + atomicAdd(&group_sq_sums[(size_t)slot * n_cols + c], v * v); + } + if (compute_nnz && v != 0.0) { + atomicAdd(&group_nnz[(size_t)slot * n_cols + c], 1.0); + } + } + } +} + +/** Fill linear segment offsets [0, stride, 2*stride, ...] on device. + * Runs on the supplied stream so it doesn't serialize multi-stream pipelines. + */ +static inline void upload_linear_offsets(int* d_offsets, int n_segments, + int stride, cudaStream_t stream) { + int count = n_segments + 1; + int blk = (count + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + fill_linear_offsets_kernel<<>>( + d_offsets, n_segments, stride); + CUDA_CHECK_LAST_ERROR(fill_linear_offsets_kernel); +} + +// ============================================================================ +// CSR → dense F-order extraction (templated on data type) +// ============================================================================ + +template +__global__ void csr_extract_dense_kernel(const T* __restrict__ data, + const int* __restrict__ indices, + const int* __restrict__ indptr, + const int* __restrict__ row_ids, + T* __restrict__ out, int n_target, + int col_start, int col_stop) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid >= n_target) return; + + int row = row_ids[tid]; + int rs = indptr[row]; + int re = indptr[row + 1]; + + int lo = rs, hi = re; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (indices[m] < col_start) + lo = m + 1; + else + hi = m; + } + + for (int p = lo; p < re; ++p) { + int c = indices[p]; + if (c >= col_stop) break; + out[(long long)(c - col_start) * n_target + tid] = data[p]; + } +} + +template +__global__ void csr_extract_dense_identity_rows_kernel( + const T* __restrict__ data, const int* __restrict__ indices, + const int* __restrict__ indptr, T* __restrict__ out, int n_target, + int col_start, int col_stop) { + int row = blockIdx.x; + if (row >= n_target) return; + + int rs = indptr[row]; + int re = indptr[row + 1]; + + int lo = rs, hi = re; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (indices[m] < col_start) + lo = m + 1; + else + hi = m; + } + + for (int p = lo + threadIdx.x; p < re; p += blockDim.x) { + int c = indices[p]; + if (c >= col_stop) break; + out[(long long)(c - col_start) * n_target + row] = data[p]; + } +} + +template +__global__ void csr_extract_dense_identity_rows_unsorted_kernel( + const T* __restrict__ data, const int* __restrict__ indices, + const int* __restrict__ indptr, T* __restrict__ out, int n_target, + int col_start, int col_stop) { + int row = blockIdx.x; + if (row >= n_target) return; + + int rs = indptr[row]; + int re = indptr[row + 1]; + + for (int p = rs + threadIdx.x; p < re; p += blockDim.x) { + int c = indices[p]; + if (c >= col_start && c < col_stop) { + out[(long long)(c - col_start) * n_target + row] = data[p]; + } + } +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh new file mode 100644 index 00000000..7ad20b01 --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh @@ -0,0 +1,518 @@ +#pragma once + +/** + * CSR-direct OVO streaming pipeline. + * + * One C++ call does everything. Reference rows are extracted and sorted once + * across all columns, then each group sub-batch ranks against that cached + * reference slice. This mirrors the fast host-CSR path and avoids redoing the + * reference dense extraction + segmented sort for every column sub-batch. + */ +static void ovo_streaming_csr_impl( + const float* csr_data, const int* csr_indices, const int* csr_indptr, + const int* ref_row_ids, const int* grp_row_ids, const int* grp_offsets, + double* rank_sums, double* tie_corr, int n_ref, int n_all_grp, int n_cols, + int n_groups, bool compute_tie_corr, int sub_batch_cols) { + if (n_cols == 0 || n_ref == 0 || n_all_grp == 0) return; + + std::vector h_offsets(n_groups + 1); + cudaMemcpy(h_offsets.data(), grp_offsets, (n_groups + 1) * sizeof(int), + cudaMemcpyDeviceToHost); + auto t1 = make_tier1_config(h_offsets.data(), n_groups); + int max_grp_size = t1.max_grp_size; + bool use_tier1 = t1.any_above_t2 && t1.use_tier1; + bool needs_tier3 = t1.any_above_t2 && !use_tier1; + int padded_grp_size = t1.padded_grp_size; + int tier1_tpb = t1.tier1_tpb; + size_t tier1_smem = t1.tier1_smem; + std::vector h_sort_group_ids; + int n_sort_groups = n_groups; + if (needs_tier3) { + h_sort_group_ids = make_sort_group_ids(h_offsets.data(), n_groups, + TIER2_GROUP_THRESHOLD); + n_sort_groups = (int)h_sort_group_ids.size(); + } + + int n_streams = N_STREAMS; + if (n_cols < n_streams * sub_batch_cols) + n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + + size_t sub_grp_items = (size_t)n_all_grp * sub_batch_cols; + + size_t max_ref_cols = 2147483647LL / (size_t)n_ref; + if (max_ref_cols == 0) { + throw std::runtime_error( + "OVO device CSR reference group exceeds CUB int item limit"); + } + int ref_cache_cols = std::min(n_cols, (int)max_ref_cols); + size_t free_bytes = 0; + size_t total_bytes = 0; + if (cudaMemGetInfo(&free_bytes, &total_bytes) == cudaSuccess) { + size_t bytes_per_col = (size_t)n_ref * sizeof(float) * 2; + size_t target_bytes = free_bytes / 3; + if (bytes_per_col > 0 && target_bytes >= bytes_per_col) { + size_t mem_cols = target_bytes / bytes_per_col; + if (mem_cols > 0 && mem_cols < (size_t)ref_cache_cols) { + ref_cache_cols = (int)mem_cols; + } + } + } + if (ref_cache_cols < 1) ref_cache_cols = 1; + + RmmPool pool; + + size_t cub_temp_bytes = 0; + if (needs_tier3) { + size_t cub_grp_bytes = 0; + int max_grp_seg = n_sort_groups * sub_batch_cols; + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, cub_grp_bytes, fk, fk, (int)sub_grp_items, max_grp_seg, + doff, doff + 1, BEGIN_BIT, END_BIT); + cub_temp_bytes = cub_grp_bytes; + } + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + + int* d_sort_group_ids = nullptr; + if (needs_tier3) { + d_sort_group_ids = pool.alloc(h_sort_group_ids.size()); + cudaMemcpy(d_sort_group_ids, h_sort_group_ids.data(), + h_sort_group_ids.size() * sizeof(int), + cudaMemcpyHostToDevice); + } + + struct StreamBuf { + float* grp_dense; + float* grp_sorted; + int* grp_seg_offsets; + int* grp_seg_ends; + uint8_t* cub_temp; + double* ref_tie_sums; + double* sub_rank_sums; + double* sub_tie_corr; + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + bufs[s].grp_dense = pool.alloc(sub_grp_items); + bufs[s].cub_temp = + needs_tier3 ? pool.alloc(cub_temp_bytes) : nullptr; + bufs[s].ref_tie_sums = + (compute_tie_corr && + (t1.use_tier0 || t1.any_tier0_64 || t1.any_tier2)) + ? pool.alloc(sub_batch_cols) + : nullptr; + bufs[s].sub_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].sub_tie_corr = + pool.alloc((size_t)n_groups * sub_batch_cols); + if (needs_tier3) { + bufs[s].grp_sorted = pool.alloc(sub_grp_items); + int max_seg = n_sort_groups * sub_batch_cols; + bufs[s].grp_seg_offsets = pool.alloc(max_seg); + bufs[s].grp_seg_ends = pool.alloc(max_seg); + } else { + bufs[s].grp_sorted = nullptr; + bufs[s].grp_seg_offsets = nullptr; + bufs[s].grp_seg_ends = nullptr; + } + } + + int tpb_extract = round_up_to_warp(std::max(n_ref, n_all_grp)); + int tpb_rank = + round_up_to_warp(std::min(max_grp_size, MAX_THREADS_PER_BLOCK)); + + for (int cache_col = 0; cache_col < n_cols; cache_col += ref_cache_cols) { + int cache_cols = std::min(ref_cache_cols, n_cols - cache_col); + size_t cache_ref_items = (size_t)n_ref * cache_cols; + + ScopedCudaBuffer ref_dense_buf(cache_ref_items * sizeof(float)); + ScopedCudaBuffer ref_sorted_buf(cache_ref_items * sizeof(float)); + ScopedCudaBuffer ref_seg_offsets_buf((size_t)(cache_cols + 1) * + sizeof(int)); + float* d_ref_dense = (float*)ref_dense_buf.data(); + float* d_ref_sorted = (float*)ref_sorted_buf.data(); + int* d_ref_seg_offsets = (int*)ref_seg_offsets_buf.data(); + + cudaMemsetAsync(d_ref_dense, 0, cache_ref_items * sizeof(float)); + int tpb_ref_extract = round_up_to_warp(n_ref); + int ref_blk = (n_ref + tpb_ref_extract - 1) / tpb_ref_extract; + csr_extract_dense_kernel<<>>( + csr_data, csr_indices, csr_indptr, ref_row_ids, d_ref_dense, n_ref, + cache_col, cache_col + cache_cols); + CUDA_CHECK_LAST_ERROR(csr_extract_dense_kernel); + + upload_linear_offsets(d_ref_seg_offsets, cache_cols, n_ref, 0); + + size_t ref_cub_bytes = 0; + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, ref_cub_bytes, fk, fk, (int)cache_ref_items, cache_cols, + doff, doff + 1, BEGIN_BIT, END_BIT); + ScopedCudaBuffer ref_cub_temp_buf(ref_cub_bytes); + size_t ref_temp = ref_cub_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + ref_cub_temp_buf.data(), ref_temp, d_ref_dense, d_ref_sorted, + (int)cache_ref_items, cache_cols, d_ref_seg_offsets, + d_ref_seg_offsets + 1, BEGIN_BIT, END_BIT); + cudaDeviceSynchronize(); + + int col = cache_col; + int cache_stop = cache_col + cache_cols; + int batch_idx = 0; + while (col < cache_stop) { + int sb_cols = std::min(sub_batch_cols, cache_stop - col); + int sb_grp_items_actual = n_all_grp * sb_cols; + int s = batch_idx % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + const float* ref_sub = + d_ref_sorted + (size_t)(col - cache_col) * n_ref; + + cudaMemsetAsync(buf.grp_dense, 0, + sb_grp_items_actual * sizeof(float), stream); + { + int blk = (n_all_grp + tpb_extract - 1) / tpb_extract; + csr_extract_dense_kernel<<>>( + csr_data, csr_indices, csr_indptr, grp_row_ids, + buf.grp_dense, n_all_grp, col, col + sb_cols); + CUDA_CHECK_LAST_ERROR(csr_extract_dense_kernel); + } + + int skip_le = 0; + bool run_tier0 = t1.use_tier0; + bool run_tier0_64 = t1.any_tier0_64; + bool run_tier2 = t1.any_tier2; + if (compute_tie_corr && (run_tier0 || run_tier0_64 || run_tier2)) { + launch_ref_tie_sums(ref_sub, buf.ref_tie_sums, n_ref, sb_cols, + stream); + } + if (run_tier0) { + launch_tier0(ref_sub, buf.grp_dense, grp_offsets, + buf.ref_tie_sums, buf.sub_rank_sums, + buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, + n_groups, compute_tie_corr, stream); + if (t1.any_above_t0) skip_le = TIER0_GROUP_THRESHOLD; + } + if (run_tier0_64) { + launch_tier0_64(ref_sub, buf.grp_dense, grp_offsets, + buf.ref_tie_sums, buf.sub_rank_sums, + buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, + n_groups, compute_tie_corr, skip_le, stream); + if (t1.max_grp_size > TIER0_64_GROUP_THRESHOLD) { + skip_le = TIER0_64_GROUP_THRESHOLD; + } + } + if (run_tier2) { + launch_tier2_medium( + ref_sub, buf.grp_dense, grp_offsets, buf.ref_tie_sums, + buf.sub_rank_sums, buf.sub_tie_corr, n_ref, n_all_grp, + sb_cols, n_groups, compute_tie_corr, skip_le, stream); + } + + int upper_skip_le = + t1.any_above_t2 ? TIER2_GROUP_THRESHOLD : skip_le; + if (t1.any_above_t2 && use_tier1) { + dim3 grid(sb_cols, n_groups); + ovo_fused_sort_rank_kernel<<>>( + ref_sub, buf.grp_dense, grp_offsets, buf.sub_rank_sums, + buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, + compute_tie_corr, padded_grp_size, upper_skip_le); + CUDA_CHECK_LAST_ERROR(ovo_fused_sort_rank_kernel); + } else if (needs_tier3) { + int sb_grp_seg = n_sort_groups * sb_cols; + { + int blk = + (sb_grp_seg + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + build_tier3_seg_begin_end_offsets_kernel<<< + blk, UTIL_BLOCK_SIZE, 0, stream>>>( + grp_offsets, d_sort_group_ids, buf.grp_seg_offsets, + buf.grp_seg_ends, n_all_grp, n_sort_groups, sb_cols); + CUDA_CHECK_LAST_ERROR( + build_tier3_seg_begin_end_offsets_kernel); + } + { + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + buf.cub_temp, temp, buf.grp_dense, buf.grp_sorted, + sb_grp_items_actual, sb_grp_seg, buf.grp_seg_offsets, + buf.grp_seg_ends, BEGIN_BIT, END_BIT, stream); + } + { + dim3 grid(sb_cols, n_groups); + batched_rank_sums_presorted_kernel<<>>( + ref_sub, buf.grp_sorted, grp_offsets, buf.sub_rank_sums, + buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, + compute_tie_corr, upper_skip_le); + CUDA_CHECK_LAST_ERROR(batched_rank_sums_presorted_kernel); + } + } + + cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), + buf.sub_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpy2DAsync(tie_corr + col, n_cols * sizeof(double), + buf.sub_tie_corr, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + } + + col += sb_cols; + batch_idx++; + } + + for (int s = 0; s < n_streams; s++) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) + throw std::runtime_error( + std::string("CUDA error in OVO device CSR streaming: ") + + cudaGetErrorString(err)); + } + } + for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); +} + +/** + * CSC-direct OVO streaming pipeline. + * + * Like the CSR variant, but extracts rows via lookup maps so it can operate on + * native CSC input without converting the whole matrix. + */ +static void ovo_streaming_csc_impl( + const float* csc_data, const int* csc_indices, const int* csc_indptr, + const int* ref_row_map, const int* grp_row_map, const int* grp_offsets, + double* rank_sums, double* tie_corr, int n_ref, int n_all_grp, int n_cols, + int n_groups, bool compute_tie_corr, int sub_batch_cols) { + if (n_cols == 0 || n_ref == 0 || n_all_grp == 0) return; + + std::vector h_offsets(n_groups + 1); + cudaMemcpy(h_offsets.data(), grp_offsets, (n_groups + 1) * sizeof(int), + cudaMemcpyDeviceToHost); + auto t1 = make_tier1_config(h_offsets.data(), n_groups); + int max_grp_size = t1.max_grp_size; + bool use_tier1 = t1.any_above_t2 && t1.use_tier1; + bool needs_tier3 = t1.any_above_t2 && !use_tier1; + int padded_grp_size = t1.padded_grp_size; + int tier1_tpb = t1.tier1_tpb; + size_t tier1_smem = t1.tier1_smem; + std::vector h_sort_group_ids; + int n_sort_groups = n_groups; + if (needs_tier3) { + h_sort_group_ids = make_sort_group_ids(h_offsets.data(), n_groups, + TIER2_GROUP_THRESHOLD); + n_sort_groups = (int)h_sort_group_ids.size(); + } + + int n_streams = N_STREAMS; + if (n_cols < n_streams * sub_batch_cols) + n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + + size_t sub_ref_items = (size_t)n_ref * sub_batch_cols; + size_t sub_grp_items = (size_t)n_all_grp * sub_batch_cols; + + size_t cub_ref_bytes = 0; + { + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, cub_ref_bytes, fk, fk, (int)sub_ref_items, sub_batch_cols, + doff, doff + 1, BEGIN_BIT, END_BIT); + } + size_t cub_temp_bytes = cub_ref_bytes; + if (needs_tier3) { + size_t cub_grp_bytes = 0; + int max_grp_seg = n_sort_groups * sub_batch_cols; + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, cub_grp_bytes, fk, fk, (int)sub_grp_items, max_grp_seg, + doff, doff + 1, BEGIN_BIT, END_BIT); + cub_temp_bytes = std::max(cub_ref_bytes, cub_grp_bytes); + } + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + + RmmPool pool; + int* d_sort_group_ids = nullptr; + if (needs_tier3) { + d_sort_group_ids = pool.alloc(h_sort_group_ids.size()); + cudaMemcpy(d_sort_group_ids, h_sort_group_ids.data(), + h_sort_group_ids.size() * sizeof(int), + cudaMemcpyHostToDevice); + } + + struct StreamBuf { + float* ref_dense; + float* ref_sorted; + float* grp_dense; + float* grp_sorted; + int* ref_seg_offsets; + int* grp_seg_offsets; + int* grp_seg_ends; + uint8_t* cub_temp; + double* ref_tie_sums; + double* sub_rank_sums; + double* sub_tie_corr; + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + bufs[s].ref_dense = pool.alloc(sub_ref_items); + bufs[s].ref_sorted = pool.alloc(sub_ref_items); + bufs[s].grp_dense = pool.alloc(sub_grp_items); + bufs[s].ref_seg_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].ref_tie_sums = + (compute_tie_corr && + (t1.use_tier0 || t1.any_tier0_64 || t1.any_tier2)) + ? pool.alloc(sub_batch_cols) + : nullptr; + bufs[s].sub_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].sub_tie_corr = + pool.alloc((size_t)n_groups * sub_batch_cols); + if (needs_tier3) { + bufs[s].grp_sorted = pool.alloc(sub_grp_items); + int max_grp_seg = n_sort_groups * sub_batch_cols; + bufs[s].grp_seg_offsets = pool.alloc(max_grp_seg); + bufs[s].grp_seg_ends = pool.alloc(max_grp_seg); + } else { + bufs[s].grp_sorted = nullptr; + bufs[s].grp_seg_offsets = nullptr; + bufs[s].grp_seg_ends = nullptr; + } + } + + int tpb_rank = + round_up_to_warp(std::min(max_grp_size, MAX_THREADS_PER_BLOCK)); + + int col = 0; + int batch_idx = 0; + while (col < n_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int sb_ref_items_actual = n_ref * sb_cols; + int sb_grp_items_actual = n_all_grp * sb_cols; + int s = batch_idx % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + + cudaMemsetAsync(buf.ref_dense, 0, sb_ref_items_actual * sizeof(float), + stream); + csc_extract_mapped_kernel<<>>( + csc_data, csc_indices, csc_indptr, ref_row_map, buf.ref_dense, + n_ref, col); + CUDA_CHECK_LAST_ERROR(csc_extract_mapped_kernel); + upload_linear_offsets(buf.ref_seg_offsets, sb_cols, n_ref, stream); + { + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + buf.cub_temp, temp, buf.ref_dense, buf.ref_sorted, + sb_ref_items_actual, sb_cols, buf.ref_seg_offsets, + buf.ref_seg_offsets + 1, BEGIN_BIT, END_BIT, stream); + } + + cudaMemsetAsync(buf.grp_dense, 0, sb_grp_items_actual * sizeof(float), + stream); + csc_extract_mapped_kernel<<>>( + csc_data, csc_indices, csc_indptr, grp_row_map, buf.grp_dense, + n_all_grp, col); + CUDA_CHECK_LAST_ERROR(csc_extract_mapped_kernel); + + int skip_le = 0; + bool run_tier0 = t1.use_tier0; + bool run_tier0_64 = t1.any_tier0_64; + bool run_tier2 = t1.any_tier2; + if (compute_tie_corr && (run_tier0 || run_tier0_64 || run_tier2)) { + launch_ref_tie_sums(buf.ref_sorted, buf.ref_tie_sums, n_ref, + sb_cols, stream); + } + if (run_tier0) { + launch_tier0(buf.ref_sorted, buf.grp_dense, grp_offsets, + buf.ref_tie_sums, buf.sub_rank_sums, buf.sub_tie_corr, + n_ref, n_all_grp, sb_cols, n_groups, compute_tie_corr, + stream); + if (t1.any_above_t0) skip_le = TIER0_GROUP_THRESHOLD; + } + if (run_tier0_64) { + launch_tier0_64(buf.ref_sorted, buf.grp_dense, grp_offsets, + buf.ref_tie_sums, buf.sub_rank_sums, + buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, + n_groups, compute_tie_corr, skip_le, stream); + if (t1.max_grp_size > TIER0_64_GROUP_THRESHOLD) { + skip_le = TIER0_64_GROUP_THRESHOLD; + } + } + if (run_tier2) { + launch_tier2_medium(buf.ref_sorted, buf.grp_dense, grp_offsets, + buf.ref_tie_sums, buf.sub_rank_sums, + buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, + n_groups, compute_tie_corr, skip_le, stream); + } + + int upper_skip_le = t1.any_above_t2 ? TIER2_GROUP_THRESHOLD : skip_le; + if (t1.any_above_t2 && use_tier1) { + dim3 grid(sb_cols, n_groups); + ovo_fused_sort_rank_kernel<<>>( + buf.ref_sorted, buf.grp_dense, grp_offsets, buf.sub_rank_sums, + buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, + compute_tie_corr, padded_grp_size, upper_skip_le); + CUDA_CHECK_LAST_ERROR(ovo_fused_sort_rank_kernel); + } else if (needs_tier3) { + int sb_grp_seg = n_sort_groups * sb_cols; + { + int blk = (sb_grp_seg + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + build_tier3_seg_begin_end_offsets_kernel<<>>( + grp_offsets, d_sort_group_ids, buf.grp_seg_offsets, + buf.grp_seg_ends, n_all_grp, n_sort_groups, sb_cols); + CUDA_CHECK_LAST_ERROR(build_tier3_seg_begin_end_offsets_kernel); + } + { + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + buf.cub_temp, temp, buf.grp_dense, buf.grp_sorted, + sb_grp_items_actual, sb_grp_seg, buf.grp_seg_offsets, + buf.grp_seg_ends, BEGIN_BIT, END_BIT, stream); + } + { + dim3 grid(sb_cols, n_groups); + batched_rank_sums_presorted_kernel<<>>( + buf.ref_sorted, buf.grp_sorted, grp_offsets, + buf.sub_rank_sums, buf.sub_tie_corr, n_ref, n_all_grp, + sb_cols, n_groups, compute_tie_corr, upper_skip_le); + CUDA_CHECK_LAST_ERROR(batched_rank_sums_presorted_kernel); + } + } + + cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), + buf.sub_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpy2DAsync(tie_corr + col, n_cols * sizeof(double), + buf.sub_tie_corr, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + } + + col += sb_cols; + batch_idx++; + } + + for (int s = 0; s < n_streams; s++) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) + throw std::runtime_error( + std::string("CUDA error in OVO device CSC streaming: ") + + cudaGetErrorString(err)); + } + for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh new file mode 100644 index 00000000..feb86e57 --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh @@ -0,0 +1,853 @@ +#pragma once + +/** + * Host-streaming CSC OVO pipeline. + * + * CSC arrays live on host. Only the sparse data for each sub-batch of + * columns is transferred to GPU. Row maps + group offsets are uploaded once. + * Results are written back to host per sub-batch. + */ +template +static void ovo_streaming_csc_host_impl( + const InT* h_data, const IndexT* h_indices, const IndptrT* h_indptr, + const int* h_ref_row_map, const int* h_grp_row_map, + const int* h_grp_offsets, const int* h_stats_codes, double* d_rank_sums, + double* d_tie_corr, double* d_group_sums, double* d_group_sq_sums, + double* d_group_nnz, int n_ref, int n_all_grp, int n_rows, int n_cols, + int n_groups, int n_groups_stats, bool compute_tie_corr, + bool compute_sq_sums, bool compute_nnz, int sub_batch_cols) { + if (n_cols == 0 || n_ref == 0 || n_all_grp == 0) return; + + // ---- Tier dispatch from host offsets ---- + auto t1 = make_tier1_config(h_grp_offsets, n_groups); + int max_grp_size = t1.max_grp_size; + bool use_tier1 = t1.any_above_t2 && t1.use_tier1; + bool needs_tier3 = t1.any_above_t2 && !use_tier1; + int padded_grp_size = t1.padded_grp_size; + int tier1_tpb = t1.tier1_tpb; + size_t tier1_smem = t1.tier1_smem; + std::vector h_sort_group_ids; + int n_sort_groups = n_groups; + if (needs_tier3) { + h_sort_group_ids = + make_sort_group_ids(h_grp_offsets, n_groups, TIER2_GROUP_THRESHOLD); + n_sort_groups = (int)h_sort_group_ids.size(); + } + + int n_streams = N_STREAMS; + if (n_cols < n_streams * sub_batch_cols) + n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + + size_t sub_ref_items = (size_t)n_ref * sub_batch_cols; + size_t sub_grp_items = (size_t)n_all_grp * sub_batch_cols; + + // CUB temp + size_t cub_ref_bytes = 0; + { + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, cub_ref_bytes, fk, fk, (int)sub_ref_items, sub_batch_cols, + doff, doff + 1, BEGIN_BIT, END_BIT); + } + size_t cub_temp_bytes = cub_ref_bytes; + if (needs_tier3) { + size_t cub_grp_bytes = 0; + int max_grp_seg = n_sort_groups * sub_batch_cols; + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, cub_grp_bytes, fk, fk, (int)sub_grp_items, max_grp_seg, + doff, doff + 1, BEGIN_BIT, END_BIT); + cub_temp_bytes = std::max(cub_ref_bytes, cub_grp_bytes); + } + + // Max nnz across any sub-batch for sparse transfer buffer sizing + size_t max_nnz = 0; + for (int c = 0; c < n_cols; c += sub_batch_cols) { + int sb = std::min(sub_batch_cols, n_cols - c); + size_t nnz = (size_t)(h_indptr[c + sb] - h_indptr[c]); + if (nnz > max_nnz) max_nnz = nnz; + } + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + + RmmPool pool; + + int n_batches = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + std::vector h_all_offsets((size_t)n_batches * (sub_batch_cols + 1), 0); + for (int b = 0; b < n_batches; b++) { + int col_start = b * sub_batch_cols; + int sb = std::min(sub_batch_cols, n_cols - col_start); + IndptrT ptr_start = h_indptr[col_start]; + int* off = &h_all_offsets[(size_t)b * (sub_batch_cols + 1)]; + for (int i = 0; i <= sb; i++) + off[i] = (int)(h_indptr[col_start + i] - ptr_start); + } + int* d_all_offsets = + pool.alloc((size_t)n_batches * (sub_batch_cols + 1)); + cudaMemcpy(d_all_offsets, h_all_offsets.data(), + h_all_offsets.size() * sizeof(int), cudaMemcpyHostToDevice); + + // GPU copies of row maps + group offsets + stats codes (uploaded once) + int* d_ref_row_map = pool.alloc(n_rows); + int* d_grp_row_map = pool.alloc(n_rows); + int* d_grp_offsets = pool.alloc(n_groups + 1); + int* d_stats_codes = pool.alloc(n_rows); + int* d_sort_group_ids = nullptr; + cudaMemcpy(d_ref_row_map, h_ref_row_map, n_rows * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(d_grp_row_map, h_grp_row_map, n_rows * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(d_grp_offsets, h_grp_offsets, (n_groups + 1) * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(d_stats_codes, h_stats_codes, n_rows * sizeof(int), + cudaMemcpyHostToDevice); + if (needs_tier3) { + d_sort_group_ids = pool.alloc(h_sort_group_ids.size()); + cudaMemcpy(d_sort_group_ids, h_sort_group_ids.data(), + h_sort_group_ids.size() * sizeof(int), + cudaMemcpyHostToDevice); + } + + struct StreamBuf { + InT* d_sparse_data_orig; + float* d_sparse_data_f32; + IndexT* d_sparse_indices; + int* d_indptr; + float* ref_dense; + float* ref_sorted; + float* grp_dense; + float* grp_sorted; + int* ref_seg_offsets; + int* grp_seg_offsets; + int* grp_seg_ends; + uint8_t* cub_temp; + double* ref_tie_sums; + double* d_rank_sums; + double* d_tie_corr; + double* d_group_sums; + double* d_group_sq_sums; + double* d_group_nnz; + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + bufs[s].d_sparse_data_orig = pool.alloc(max_nnz); + bufs[s].d_sparse_data_f32 = pool.alloc(max_nnz); + bufs[s].d_sparse_indices = pool.alloc(max_nnz); + bufs[s].d_indptr = pool.alloc(sub_batch_cols + 1); + bufs[s].ref_dense = pool.alloc(sub_ref_items); + bufs[s].ref_sorted = pool.alloc(sub_ref_items); + bufs[s].grp_dense = pool.alloc(sub_grp_items); + bufs[s].ref_seg_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].ref_tie_sums = + (compute_tie_corr && + (t1.use_tier0 || t1.any_tier0_64 || t1.any_tier2)) + ? pool.alloc(sub_batch_cols) + : nullptr; + bufs[s].d_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].d_tie_corr = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].d_group_sums = + pool.alloc((size_t)n_groups_stats * sub_batch_cols); + bufs[s].d_group_sq_sums = pool.alloc( + compute_sq_sums ? (size_t)n_groups_stats * sub_batch_cols : 1); + bufs[s].d_group_nnz = pool.alloc( + compute_nnz ? (size_t)n_groups_stats * sub_batch_cols : 1); + if (needs_tier3) { + bufs[s].grp_sorted = pool.alloc(sub_grp_items); + int max_grp_seg = n_sort_groups * sub_batch_cols; + bufs[s].grp_seg_offsets = pool.alloc(max_grp_seg); + bufs[s].grp_seg_ends = pool.alloc(max_grp_seg); + } else { + bufs[s].grp_sorted = nullptr; + bufs[s].grp_seg_offsets = nullptr; + bufs[s].grp_seg_ends = nullptr; + } + } + + int tpb_rank = + round_up_to_warp(std::min(max_grp_size, MAX_THREADS_PER_BLOCK)); + bool cast_use_gmem = false; + size_t smem_cast = cast_accumulate_smem_config( + n_groups_stats, compute_sq_sums, compute_nnz, cast_use_gmem); + + // Pin only the sparse input arrays; outputs live on the device. + size_t total_nnz = (size_t)h_indptr[n_cols]; + HostRegisterGuard _pin_data(const_cast(h_data), + total_nnz * sizeof(InT)); + HostRegisterGuard _pin_indices(const_cast(h_indices), + total_nnz * sizeof(IndexT)); + + int col = 0; + int batch_idx = 0; + while (col < n_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int sb_ref_actual = n_ref * sb_cols; + int sb_grp_actual = n_all_grp * sb_cols; + int s = batch_idx % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + + // ---- H2D: sparse data for this column range (native dtype) ---- + IndptrT ptr_start = h_indptr[col]; + IndptrT ptr_end = h_indptr[col + sb_cols]; + size_t nnz = (size_t)(ptr_end - ptr_start); + cudaMemcpyAsync(buf.d_sparse_data_orig, h_data + ptr_start, + nnz * sizeof(InT), cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(buf.d_sparse_indices, h_indices + ptr_start, + nnz * sizeof(IndexT), cudaMemcpyHostToDevice, stream); + int* src = d_all_offsets + (size_t)batch_idx * (sub_batch_cols + 1); + cudaMemcpyAsync(buf.d_indptr, src, (sb_cols + 1) * sizeof(int), + cudaMemcpyDeviceToDevice, stream); + + // ---- Cast to float32 for sort + accumulate stats in float64 ---- + launch_ovr_cast_and_accumulate_sparse( + buf.d_sparse_data_orig, buf.d_sparse_data_f32, buf.d_sparse_indices, + buf.d_indptr, d_stats_codes, buf.d_group_sums, buf.d_group_sq_sums, + buf.d_group_nnz, sb_cols, n_groups_stats, compute_sq_sums, + compute_nnz, UTIL_BLOCK_SIZE, smem_cast, cast_use_gmem, stream); + + // ---- Extract ref from CSC via row_map, sort ---- + cudaMemsetAsync(buf.ref_dense, 0, sb_ref_actual * sizeof(float), + stream); + csc_extract_mapped_kernel<<>>( + buf.d_sparse_data_f32, buf.d_sparse_indices, buf.d_indptr, + d_ref_row_map, buf.ref_dense, n_ref, 0); + CUDA_CHECK_LAST_ERROR(csc_extract_mapped_kernel); + upload_linear_offsets(buf.ref_seg_offsets, sb_cols, n_ref, stream); + { + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + buf.cub_temp, temp, buf.ref_dense, buf.ref_sorted, + sb_ref_actual, sb_cols, buf.ref_seg_offsets, + buf.ref_seg_offsets + 1, BEGIN_BIT, END_BIT, stream); + } + + // ---- Extract grp from CSC via row_map ---- + cudaMemsetAsync(buf.grp_dense, 0, sb_grp_actual * sizeof(float), + stream); + csc_extract_mapped_kernel<<>>( + buf.d_sparse_data_f32, buf.d_sparse_indices, buf.d_indptr, + d_grp_row_map, buf.grp_dense, n_all_grp, 0); + CUDA_CHECK_LAST_ERROR(csc_extract_mapped_kernel); + + // ---- Tier dispatch: sort grp + rank ---- + int skip_le = 0; + bool run_tier0 = t1.use_tier0; + bool run_tier0_64 = t1.any_tier0_64; + bool run_tier2 = t1.any_tier2; + if (compute_tie_corr && (run_tier0 || run_tier0_64 || run_tier2)) { + launch_ref_tie_sums(buf.ref_sorted, buf.ref_tie_sums, n_ref, + sb_cols, stream); + } + if (run_tier0) { + launch_tier0(buf.ref_sorted, buf.grp_dense, d_grp_offsets, + buf.ref_tie_sums, buf.d_rank_sums, buf.d_tie_corr, + n_ref, n_all_grp, sb_cols, n_groups, compute_tie_corr, + stream); + if (t1.any_above_t0) skip_le = TIER0_GROUP_THRESHOLD; + } + if (run_tier0_64) { + launch_tier0_64(buf.ref_sorted, buf.grp_dense, d_grp_offsets, + buf.ref_tie_sums, buf.d_rank_sums, buf.d_tie_corr, + n_ref, n_all_grp, sb_cols, n_groups, + compute_tie_corr, skip_le, stream); + if (t1.max_grp_size > TIER0_64_GROUP_THRESHOLD) { + skip_le = TIER0_64_GROUP_THRESHOLD; + } + } + if (run_tier2) { + launch_tier2_medium(buf.ref_sorted, buf.grp_dense, d_grp_offsets, + buf.ref_tie_sums, buf.d_rank_sums, + buf.d_tie_corr, n_ref, n_all_grp, sb_cols, + n_groups, compute_tie_corr, skip_le, stream); + } + + int upper_skip_le = t1.any_above_t2 ? TIER2_GROUP_THRESHOLD : skip_le; + if (t1.any_above_t2 && use_tier1) { + dim3 grid(sb_cols, n_groups); + ovo_fused_sort_rank_kernel<<>>( + buf.ref_sorted, buf.grp_dense, d_grp_offsets, buf.d_rank_sums, + buf.d_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, + compute_tie_corr, padded_grp_size, upper_skip_le); + CUDA_CHECK_LAST_ERROR(ovo_fused_sort_rank_kernel); + } else if (needs_tier3) { + int sb_grp_seg = n_sort_groups * sb_cols; + { + int blk = (sb_grp_seg + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + build_tier3_seg_begin_end_offsets_kernel<<>>( + d_grp_offsets, d_sort_group_ids, buf.grp_seg_offsets, + buf.grp_seg_ends, n_all_grp, n_sort_groups, sb_cols); + CUDA_CHECK_LAST_ERROR(build_tier3_seg_begin_end_offsets_kernel); + } + { + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + buf.cub_temp, temp, buf.grp_dense, buf.grp_sorted, + sb_grp_actual, sb_grp_seg, buf.grp_seg_offsets, + buf.grp_seg_ends, BEGIN_BIT, END_BIT, stream); + } + { + dim3 grid(sb_cols, n_groups); + batched_rank_sums_presorted_kernel<<>>( + buf.ref_sorted, buf.grp_sorted, d_grp_offsets, + buf.d_rank_sums, buf.d_tie_corr, n_ref, n_all_grp, sb_cols, + n_groups, compute_tie_corr, upper_skip_le); + CUDA_CHECK_LAST_ERROR(batched_rank_sums_presorted_kernel); + } + } + + // ---- D2D: scatter sub-batch results into caller's GPU buffers ---- + cudaMemcpy2DAsync(d_rank_sums + col, n_cols * sizeof(double), + buf.d_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpy2DAsync(d_tie_corr + col, n_cols * sizeof(double), + buf.d_tie_corr, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + } + cudaMemcpy2DAsync(d_group_sums + col, n_cols * sizeof(double), + buf.d_group_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups_stats, + cudaMemcpyDeviceToDevice, stream); + if (compute_sq_sums) { + cudaMemcpy2DAsync(d_group_sq_sums + col, n_cols * sizeof(double), + buf.d_group_sq_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups_stats, + cudaMemcpyDeviceToDevice, stream); + } + if (compute_nnz) { + cudaMemcpy2DAsync(d_group_nnz + col, n_cols * sizeof(double), + buf.d_group_nnz, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups_stats, + cudaMemcpyDeviceToDevice, stream); + } + + col += sb_cols; + batch_idx++; + } + + for (int s = 0; s < n_streams; s++) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) + throw std::runtime_error( + std::string("CUDA error in wilcoxon streaming: ") + + cudaGetErrorString(err)); + } + + for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); +} + +/** + * Host CSR OVO pipeline — zero-copy mapped full-CSR with GPU-side row gather. + * + * Setup: pin the full host CSR with cudaHostRegisterMapped, upload the full + * indptr (small) + row_ids + pre-computed compacted indptrs. Each pack + * gathers only its rows over PCIe via a UVA kernel — the full matrix is never + * transferred to GPU. + * + * Phase 1 (Ref): fused gather + cast + stats over ref rows; segmented sort + * to d_ref_sorted (cached for the whole run). + * Phase 2 (per pack, round-robin across N_STREAMS): + * 1. rebase per-pack output indptr from the pre-uploaded global compacted + * indptr. + * 2. rebase per-pack group offsets + build per-row stats codes. + * 3. csr_gather_cast_accumulate_mapped_kernel — one PCIe pass, writes + * compacted f32 data + indices and accumulates per-group stats. + * 4. Per sub-batch: extract dense → sort → rank vs ref_sorted → scatter. + * + * Memory: d_ref_sorted (n_ref × n_cols × 4B) + N_STREAMS pack buffers sized + * for max_pack_rows × sb_cols (dense) and max_pack_nnz (compacted CSR). + * Full CSR stays on host (pinned-mapped). + */ +template +static void ovo_streaming_csr_host_impl( + const InT* h_data, const IndexT* h_indices, const IndptrT* h_indptr, + int n_full_rows, const int* h_ref_row_ids, int n_ref, + const int* h_grp_row_ids, const int* h_grp_offsets, int n_all_grp, + int n_test, double* d_rank_sums, double* d_tie_corr, double* d_group_sums, + double* d_group_sq_sums, double* d_group_nnz, int n_cols, + int n_groups_stats, bool compute_tie_corr, bool compute_sq_sums, + bool compute_nnz, bool compute_sums, int sub_batch_cols) { + if (n_cols == 0 || n_ref == 0 || n_test == 0 || n_all_grp == 0) return; + + // ---- Pre-compute compacted indptrs on host (O(n_ref + n_all_grp)) ---- + // Use IndptrT for the global compacted indptr because the grp side can + // exceed 2^31 nnz on very large / dense matrices. Ref always fits in + // int32 since n_ref × n_cols ≪ 2B; keeping int32 there matches the + // downstream CUB segmented-sort temp sizing. + std::vector h_ref_indptr_compact(n_ref + 1); + h_ref_indptr_compact[0] = 0; + for (int i = 0; i < n_ref; i++) { + int r = h_ref_row_ids[i]; + int nnz_i = (int)(h_indptr[r + 1] - h_indptr[r]); + h_ref_indptr_compact[i + 1] = h_ref_indptr_compact[i] + nnz_i; + } + int ref_nnz = h_ref_indptr_compact[n_ref]; + + // grp: compacted indptr over concatenated test-group rows (IndptrT). + std::vector h_grp_indptr_compact(n_all_grp + 1); + h_grp_indptr_compact[0] = 0; + for (int i = 0; i < n_all_grp; i++) { + int r = h_grp_row_ids[i]; + IndptrT nnz_i = h_indptr[r + 1] - h_indptr[r]; + h_grp_indptr_compact[i + 1] = h_grp_indptr_compact[i] + nnz_i; + } + + // ---- Build packs (same rule as grp_impl, but uses compacted indptr) ---- + struct Pack { + int first; + int end; + int n_rows; + size_t nnz; + int sb_cols; + }; + std::vector packs; + int max_pack_rows = 0; + size_t max_pack_nnz = 0; + int max_pack_K = 0; + int max_pack_items = 0; + int max_pack_sb_cols = sub_batch_cols; + { + int target_packs = N_STREAMS; + int target_rows = (n_all_grp + target_packs - 1) / target_packs; + if (target_rows < 1) target_rows = 1; + size_t budget_cap_rows = + GROUP_DENSE_BUDGET_ITEMS / (size_t)sub_batch_cols; + if ((size_t)target_rows > budget_cap_rows) + target_rows = (int)budget_cap_rows; + + int cur_first = 0; + int cur_rows = 0; + size_t cur_nnz = 0; + for (int g = 0; g < n_test; g++) { + int n_g = h_grp_offsets[g + 1] - h_grp_offsets[g]; + size_t nnz_g = (size_t)(h_grp_indptr_compact[h_grp_offsets[g + 1]] - + h_grp_indptr_compact[h_grp_offsets[g]]); + int new_rows = cur_rows + n_g; + bool can_add = (cur_rows == 0) || (new_rows <= target_rows); + if (!can_add) { + size_t sb_size = + std::min((size_t)n_cols, + GROUP_DENSE_BUDGET_ITEMS / (size_t)cur_rows); + if (sb_size < (size_t)sub_batch_cols) sb_size = sub_batch_cols; + packs.push_back( + {cur_first, g, cur_rows, cur_nnz, (int)sb_size}); + cur_first = g; + cur_rows = n_g; + cur_nnz = nnz_g; + } else { + cur_rows = new_rows; + cur_nnz += nnz_g; + } + } + if (cur_rows > 0) { + size_t sb_size = std::min( + (size_t)n_cols, GROUP_DENSE_BUDGET_ITEMS / (size_t)cur_rows); + if (sb_size < (size_t)sub_batch_cols) sb_size = sub_batch_cols; + packs.push_back( + {cur_first, n_test, cur_rows, cur_nnz, (int)sb_size}); + } + } + for (const Pack& pk : packs) { + int K = pk.end - pk.first; + if (pk.n_rows > max_pack_rows) max_pack_rows = pk.n_rows; + if (pk.nnz > max_pack_nnz) max_pack_nnz = pk.nnz; + if (K > max_pack_K) max_pack_K = K; + int pack_items = pk.n_rows * pk.sb_cols; + if (pack_items > max_pack_items) max_pack_items = pack_items; + if (pk.sb_cols > max_pack_sb_cols) max_pack_sb_cols = pk.sb_cols; + } + int max_group_rows = max_pack_rows; + size_t max_sub_items = (size_t)max_pack_items; + if (max_pack_rows == 0) return; + + RmmPool pool; + + // Zero stats outputs. + if (compute_sums) { + cudaMemsetAsync(d_group_sums, 0, + (size_t)n_groups_stats * n_cols * sizeof(double)); + } + if (compute_sq_sums) { + cudaMemsetAsync(d_group_sq_sums, 0, + (size_t)n_groups_stats * n_cols * sizeof(double)); + } + if (compute_nnz) { + cudaMemsetAsync(d_group_nnz, 0, + (size_t)n_groups_stats * n_cols * sizeof(double)); + } + + // ---- Pin full host data + indices as MAPPED (zero-copy accessible) ---- + size_t full_nnz = (size_t)h_indptr[n_full_rows]; + HostRegisterGuard _pin_data(const_cast(h_data), + full_nnz * sizeof(InT), cudaHostRegisterMapped); + HostRegisterGuard _pin_indices(const_cast(h_indices), + full_nnz * sizeof(IndexT), + cudaHostRegisterMapped); + + // Get device-accessible pointers (UVA makes these equal to host ptrs on + // Linux x86-64, but the API is the safe/portable way). + InT* d_data_zc = nullptr; + IndexT* d_indices_zc = nullptr; + if (full_nnz > 0) { + cudaError_t e1 = cudaHostGetDevicePointer((void**)&d_data_zc, + const_cast(h_data), 0); + cudaError_t e2 = cudaHostGetDevicePointer( + (void**)&d_indices_zc, const_cast(h_indices), 0); + if (e1 != cudaSuccess || e2 != cudaSuccess) { + throw std::runtime_error( + std::string("cudaHostGetDevicePointer failed: ") + + cudaGetErrorString(e1 != cudaSuccess ? e1 : e2)); + } + } + + // ---- Upload full indptr (keep native IndptrT — can exceed int32) ---- + IndptrT* d_indptr_full = pool.alloc(n_full_rows + 1); + cudaMemcpy(d_indptr_full, h_indptr, (n_full_rows + 1) * sizeof(IndptrT), + cudaMemcpyHostToDevice); + + // ---- Upload row_ids + compacted indptrs + group boundaries ---- + int* d_ref_row_ids = pool.alloc(n_ref); + int* d_grp_row_ids = pool.alloc(n_all_grp); + IndptrT* d_grp_indptr_compact = pool.alloc(n_all_grp + 1); + int* d_grp_offsets_full = pool.alloc(n_test + 1); + cudaMemcpy(d_ref_row_ids, h_ref_row_ids, n_ref * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(d_grp_row_ids, h_grp_row_ids, n_all_grp * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(d_grp_indptr_compact, h_grp_indptr_compact.data(), + (n_all_grp + 1) * sizeof(IndptrT), cudaMemcpyHostToDevice); + cudaMemcpy(d_grp_offsets_full, h_grp_offsets, (n_test + 1) * sizeof(int), + cudaMemcpyHostToDevice); + + // ---- Phase 1: Ref setup (scoped scratch, ref_sorted persists) ---- + float* d_ref_sorted = pool.alloc((size_t)n_ref * n_cols); + { + ScopedCudaBuffer ref_data_f32_buf(ref_nnz * sizeof(float)); + ScopedCudaBuffer ref_indices_buf(ref_nnz * sizeof(int)); + ScopedCudaBuffer ref_indptr_buf((n_ref + 1) * sizeof(int)); + ScopedCudaBuffer ref_dense_buf((size_t)n_ref * n_cols * sizeof(float)); + ScopedCudaBuffer ref_seg_buf((n_cols + 1) * sizeof(int)); + + float* d_ref_data_f32 = (float*)ref_data_f32_buf.data(); + int* d_ref_indices = (int*)ref_indices_buf.data(); + int* d_ref_indptr = (int*)ref_indptr_buf.data(); + float* d_ref_dense = (float*)ref_dense_buf.data(); + int* d_ref_seg = (int*)ref_seg_buf.data(); + + // Upload ref compacted indptr + cudaMemcpy(d_ref_indptr, h_ref_indptr_compact.data(), + (n_ref + 1) * sizeof(int), cudaMemcpyHostToDevice); + + // Fused gather + cast + stats for ref (fixed slot = n_test). One + // pass over PCIe, no intermediate native-dtype GPU buffer. + if (n_ref > 0 && ref_nnz > 0) { + csr_gather_cast_accumulate_mapped_kernel + <<>>( + d_data_zc, d_indices_zc, d_indptr_full, d_ref_row_ids, + d_ref_indptr, /*d_stats_codes=*/nullptr, + /*fixed_slot=*/n_test, d_ref_data_f32, d_ref_indices, + d_group_sums, d_group_sq_sums, d_group_nnz, n_ref, n_cols, + n_groups_stats, compute_sums, compute_sq_sums, compute_nnz); + CUDA_CHECK_LAST_ERROR(csr_gather_cast_accumulate_mapped_kernel); + } + + // Extract ref dense (F-order) from compacted CSR. + cudaMemsetAsync(d_ref_dense, 0, (size_t)n_ref * n_cols * sizeof(float)); + { + csr_extract_dense_identity_rows_unsorted_kernel + <<>>(d_ref_data_f32, d_ref_indices, + d_ref_indptr, d_ref_dense, n_ref, + 0, n_cols); + CUDA_CHECK_LAST_ERROR( + csr_extract_dense_identity_rows_unsorted_kernel); + } + + // Segmented sort ref_dense by column → ref_sorted + size_t ref_cub_bytes = 0; + { + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, ref_cub_bytes, fk, fk, (int)((size_t)n_ref * n_cols), + n_cols, doff, doff + 1, BEGIN_BIT, END_BIT); + } + ScopedCudaBuffer cub_temp_buf(ref_cub_bytes); + upload_linear_offsets(d_ref_seg, n_cols, n_ref, 0); + size_t temp = ref_cub_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + cub_temp_buf.data(), temp, d_ref_dense, d_ref_sorted, + (int)((size_t)n_ref * n_cols), n_cols, d_ref_seg, d_ref_seg + 1, + BEGIN_BIT, END_BIT); + cudaDeviceSynchronize(); + } // ref scratch drops here + + // ---- Phase 2: Per-pack streaming ---- + auto t1 = make_tier1_config(h_grp_offsets, n_test); + bool may_need_cub = (t1.max_grp_size > TIER1_GROUP_THRESHOLD); + + constexpr int MAX_GROUP_STREAMS = 4; + int n_streams = MAX_GROUP_STREAMS; + if (n_test < n_streams) n_streams = n_test; + if (n_streams < 1) n_streams = 1; + if ((int)packs.size() < n_streams) n_streams = (int)packs.size(); + if (n_streams < 1) n_streams = 1; + + size_t cub_grp_bytes = 0; + if (may_need_cub && max_sub_items > 0) { + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + int max_segments = max_pack_K * max_pack_sb_cols; + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, cub_grp_bytes, fk, fk, (int)max_sub_items, max_segments, + doff, doff + 1, BEGIN_BIT, END_BIT); + } + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + + struct StreamBuf { + float* d_grp_data_f32; + int* d_grp_indices; + int* d_grp_indptr; + int* d_pack_grp_offsets; + int* d_pack_stats_codes; + float* d_grp_dense; + float* d_grp_sorted; + double* d_ref_tie_sums; + int* d_sort_group_ids; + int* d_grp_seg_offsets; + int* d_grp_seg_ends; + uint8_t* cub_temp; + double* d_rank_sums; + double* d_tie_corr; + }; + std::vector bufs(n_streams); + int max_pack_kernel_seg = max_pack_K * max_pack_sb_cols; + for (int s = 0; s < n_streams; s++) { + bufs[s].d_grp_data_f32 = pool.alloc(max_pack_nnz); + bufs[s].d_grp_indices = pool.alloc(max_pack_nnz); + bufs[s].d_grp_indptr = pool.alloc(max_pack_rows + 1); + bufs[s].d_pack_grp_offsets = pool.alloc(max_pack_K + 1); + bufs[s].d_pack_stats_codes = pool.alloc(max_pack_rows); + bufs[s].d_grp_dense = pool.alloc(max_sub_items); + bufs[s].d_ref_tie_sums = pool.alloc(max_pack_sb_cols); + bufs[s].d_rank_sums = + pool.alloc((size_t)max_pack_K * max_pack_sb_cols); + bufs[s].d_tie_corr = + pool.alloc((size_t)max_pack_K * max_pack_sb_cols); + if (may_need_cub) { + bufs[s].d_grp_sorted = pool.alloc(max_sub_items); + bufs[s].d_sort_group_ids = pool.alloc(max_pack_K); + bufs[s].d_grp_seg_offsets = pool.alloc(max_pack_kernel_seg); + bufs[s].d_grp_seg_ends = pool.alloc(max_pack_kernel_seg); + bufs[s].cub_temp = pool.alloc(cub_grp_bytes); + } else { + bufs[s].d_grp_sorted = nullptr; + bufs[s].d_sort_group_ids = nullptr; + bufs[s].d_grp_seg_offsets = nullptr; + bufs[s].d_grp_seg_ends = nullptr; + bufs[s].cub_temp = nullptr; + } + } + + cudaDeviceSynchronize(); // ensure Phase 1 done before Phase 2 streams + + for (int p = 0; p < (int)packs.size(); p++) { + const Pack& pack = packs[p]; + int K = pack.end - pack.first; + if (K == 0 || pack.n_rows == 0) continue; + Tier1Config pack_t1 = make_tier1_config(h_grp_offsets + pack.first, K); + int pack_tpb_rank = round_up_to_warp( + std::min(pack_t1.max_grp_size, MAX_THREADS_PER_BLOCK)); + bool pack_has_above_t2 = pack_t1.max_grp_size > TIER2_GROUP_THRESHOLD; + int pack_tier3_skip_le = + pack_has_above_t2 ? TIER2_GROUP_THRESHOLD : TIER0_GROUP_THRESHOLD; + std::vector h_sort_group_ids; + int pack_n_sort_groups = K; + if (pack_t1.any_above_t0 && !pack_t1.use_tier1) { + h_sort_group_ids = make_sort_group_ids(h_grp_offsets + pack.first, + K, pack_tier3_skip_le); + pack_n_sort_groups = (int)h_sort_group_ids.size(); + } + + int s = p % n_streams; + cudaStream_t stream = streams[s]; + auto& buf = bufs[s]; + + if (pack_t1.any_above_t0 && !pack_t1.use_tier1) { + cudaMemcpyAsync(buf.d_sort_group_ids, h_sort_group_ids.data(), + h_sort_group_ids.size() * sizeof(int), + cudaMemcpyHostToDevice, stream); + } + + int row_start = h_grp_offsets[pack.first]; + int pack_rows = pack.n_rows; + int pack_sb = pack.sb_cols; + + // Rebase pack's output indptr from pre-uploaded global compacted indptr + // (IndptrT → int32: pack nnz is bounded by GROUP_DENSE_BUDGET so fits). + { + int count = pack_rows + 1; + int blk = (count + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + rebase_indptr_kernel + <<>>( + d_grp_indptr_compact, buf.d_grp_indptr, row_start, count); + CUDA_CHECK_LAST_ERROR(rebase_indptr_kernel); + } + + // Build per-pack group offsets on GPU (on this stream) — needed to + // compute stats codes before the fused gather kernel can run. + { + int count = K + 1; + int blk = (count + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + rebase_indptr_kernel<<>>( + d_grp_offsets_full, buf.d_pack_grp_offsets, pack.first, count); + CUDA_CHECK_LAST_ERROR(rebase_indptr_kernel); + } + + // Fill per-row stats codes for this pack + { + int blk = (pack_rows + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + fill_pack_stats_codes_kernel<<>>( + buf.d_pack_grp_offsets, buf.d_pack_stats_codes, K, pack.first); + CUDA_CHECK_LAST_ERROR(fill_pack_stats_codes_kernel); + } + + // Fused gather + cast + stats for the pack. One pass over PCIe + // (reads mapped host via UVA), no intermediate native-dtype GPU + // buffer, writes f32 + indices + atomics. + if (pack.nnz > 0) { + csr_gather_cast_accumulate_mapped_kernel + <<>>( + d_data_zc, d_indices_zc, d_indptr_full, + d_grp_row_ids + row_start, buf.d_grp_indptr, + buf.d_pack_stats_codes, /*fixed_slot=*/-1, + buf.d_grp_data_f32, buf.d_grp_indices, d_group_sums, + d_group_sq_sums, d_group_nnz, pack_rows, n_cols, + n_groups_stats, compute_sums, compute_sq_sums, compute_nnz); + CUDA_CHECK_LAST_ERROR(csr_gather_cast_accumulate_mapped_kernel); + } + + // Per col sub-batch + int col = 0; + while (col < n_cols) { + int sb_cols = std::min(pack_sb, n_cols - col); + int sb_items = pack_rows * sb_cols; + + cudaMemsetAsync(buf.d_grp_dense, 0, sb_items * sizeof(float), + stream); + csr_extract_dense_identity_rows_unsorted_kernel + <<>>( + buf.d_grp_data_f32, buf.d_grp_indices, buf.d_grp_indptr, + buf.d_grp_dense, pack_rows, col, col + sb_cols); + CUDA_CHECK_LAST_ERROR( + csr_extract_dense_identity_rows_unsorted_kernel); + + const float* ref_sub = d_ref_sorted + (size_t)col * n_ref; + + int skip_le = 0; + bool run_tier0 = pack_t1.use_tier0; + bool run_tier0_64 = pack_t1.any_tier0_64; + bool run_tier2 = pack_t1.any_tier2; + if (compute_tie_corr && (run_tier0 || run_tier0_64 || run_tier2)) { + launch_ref_tie_sums(ref_sub, buf.d_ref_tie_sums, n_ref, sb_cols, + stream); + } + if (run_tier0) { + launch_tier0(ref_sub, buf.d_grp_dense, buf.d_pack_grp_offsets, + buf.d_ref_tie_sums, buf.d_rank_sums, + buf.d_tie_corr, n_ref, pack_rows, sb_cols, K, + compute_tie_corr, stream); + if (pack_t1.any_above_t0) skip_le = TIER0_GROUP_THRESHOLD; + } + if (run_tier0_64) { + launch_tier0_64( + ref_sub, buf.d_grp_dense, buf.d_pack_grp_offsets, + buf.d_ref_tie_sums, buf.d_rank_sums, buf.d_tie_corr, n_ref, + pack_rows, sb_cols, K, compute_tie_corr, skip_le, stream); + if (pack_t1.max_grp_size > TIER0_64_GROUP_THRESHOLD) { + skip_le = TIER0_64_GROUP_THRESHOLD; + } + } + if (run_tier2) { + launch_tier2_medium( + ref_sub, buf.d_grp_dense, buf.d_pack_grp_offsets, + buf.d_ref_tie_sums, buf.d_rank_sums, buf.d_tie_corr, n_ref, + pack_rows, sb_cols, K, compute_tie_corr, skip_le, stream); + } + + int upper_skip_le = + pack_has_above_t2 ? TIER2_GROUP_THRESHOLD : skip_le; + if (pack_has_above_t2 && pack_t1.use_tier1) { + dim3 grid(sb_cols, K); + ovo_fused_sort_rank_kernel<<>>( + ref_sub, buf.d_grp_dense, buf.d_pack_grp_offsets, + buf.d_rank_sums, buf.d_tie_corr, n_ref, pack_rows, sb_cols, + K, compute_tie_corr, pack_t1.padded_grp_size, + upper_skip_le); + CUDA_CHECK_LAST_ERROR(ovo_fused_sort_rank_kernel); + } else if (pack_has_above_t2) { + int n_seg = pack_n_sort_groups * sb_cols; + { + int blk = (n_seg + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + build_tier3_seg_begin_end_offsets_kernel<<< + blk, UTIL_BLOCK_SIZE, 0, stream>>>( + buf.d_pack_grp_offsets, buf.d_sort_group_ids, + buf.d_grp_seg_offsets, buf.d_grp_seg_ends, pack_rows, + pack_n_sort_groups, sb_cols); + CUDA_CHECK_LAST_ERROR( + build_tier3_seg_begin_end_offsets_kernel); + } + { + size_t temp = cub_grp_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + buf.cub_temp, temp, buf.d_grp_dense, buf.d_grp_sorted, + sb_items, n_seg, buf.d_grp_seg_offsets, + buf.d_grp_seg_ends, BEGIN_BIT, END_BIT, stream); + } + dim3 grid(sb_cols, K); + batched_rank_sums_presorted_kernel<<>>( + ref_sub, buf.d_grp_sorted, buf.d_pack_grp_offsets, + buf.d_rank_sums, buf.d_tie_corr, n_ref, pack_rows, sb_cols, + K, compute_tie_corr, upper_skip_le); + CUDA_CHECK_LAST_ERROR(batched_rank_sums_presorted_kernel); + } + + cudaMemcpy2DAsync(d_rank_sums + (size_t)pack.first * n_cols + col, + n_cols * sizeof(double), buf.d_rank_sums, + sb_cols * sizeof(double), + sb_cols * sizeof(double), K, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpy2DAsync( + d_tie_corr + (size_t)pack.first * n_cols + col, + n_cols * sizeof(double), buf.d_tie_corr, + sb_cols * sizeof(double), sb_cols * sizeof(double), K, + cudaMemcpyDeviceToDevice, stream); + } + + col += sb_cols; + } + } + + for (int s = 0; s < n_streams; s++) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) + throw std::runtime_error( + std::string("CUDA error in ovo csr host streaming: ") + + cudaGetErrorString(err)); + } + for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh new file mode 100644 index 00000000..afac20f2 --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh @@ -0,0 +1,182 @@ +#pragma once + +/** + * Build CUB segmented-sort ranges only for groups that Tier 3 will rank. + * Group ids are relative to grp_offsets, and ranges still point into the + * original dense group layout so the presorted rank kernel can read from the + * normal per-group positions. + */ +__global__ void build_tier3_seg_begin_end_offsets_kernel( + const int* __restrict__ grp_offsets, const int* __restrict__ group_ids, + int* __restrict__ begins, int* __restrict__ ends, int n_all_grp, + int n_sort_groups, int sb_cols) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int total = sb_cols * n_sort_groups; + if (idx >= total) return; + + int c = idx / n_sort_groups; + int local = idx % n_sort_groups; + int g = group_ids[local]; + int base = c * n_all_grp; + begins[idx] = base + grp_offsets[g]; + ends[idx] = base + grp_offsets[g + 1]; +} + +/** + * Extract specific rows from CSC into dense F-order, using a row lookup map. + * row_map[original_row] = output_row_index (or -1 to skip). + * One block per column, threads scatter matching nonzeros. + * Output must be pre-zeroed. + */ +template +__global__ void csc_extract_mapped_kernel(const float* __restrict__ data, + const IndexT* __restrict__ indices, + const int* __restrict__ indptr, + const int* __restrict__ row_map, + float* __restrict__ out, int n_target, + int col_start) { + int col_local = blockIdx.x; + int col = col_start + col_local; + + int start = indptr[col]; + int end = indptr[col + 1]; + + for (int p = start + threadIdx.x; p < end; p += blockDim.x) { + int out_row = row_map[(int)indices[p]]; + if (out_row >= 0) { + out[(long long)col_local * n_target + out_row] = data[p]; + } + } +} + +static size_t get_seg_sort_temp_bytes(int n_items, int n_segments) { + size_t bytes = 0; + auto* dk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys(nullptr, bytes, dk, dk, n_items, + n_segments, doff, doff + 1, 0, 32); + return bytes; +} + +/** + * Tier 1 dispatch: when the largest group fits in shared memory, a fused + * bitonic-sort + binary-search kernel handles the whole group per block. + * Otherwise we fall back to CUB segmented sort plus the pre-sorted rank + * kernel. This struct bundles the sizing knobs derived from the host-side + * group offsets so each streaming impl can drop a 15-line prep block. + */ +struct Tier1Config { + int max_grp_size = 0; + int min_grp_size = 0; + bool use_tier0 = + false; // any group fits in one warp (≤ TIER0_GROUP_THRESHOLD) + bool use_tier1 = + false; // any group needs > tier0 but fits in tier1 smem sort + bool any_above_t0 = + false; // at least one group exceeds TIER0_GROUP_THRESHOLD + bool any_tier0_64 = false; // any group needs Tier 0.5: (T0, T0_64] + bool any_tier2 = false; // any group needs Tier 2: (T0_64, T2] + bool any_above_t2 = + false; // at least one group exceeds TIER2_GROUP_THRESHOLD + int padded_grp_size = 0; + int tier1_tpb = 0; + size_t tier1_smem = 0; +}; + +static Tier1Config make_tier1_config(const int* h_grp_offsets, int n_groups) { + Tier1Config c; + c.min_grp_size = INT_MAX; + for (int g = 0; g < n_groups; g++) { + int sz = h_grp_offsets[g + 1] - h_grp_offsets[g]; + if (sz > c.max_grp_size) c.max_grp_size = sz; + if (sz < c.min_grp_size) c.min_grp_size = sz; + if (sz > TIER0_GROUP_THRESHOLD && sz <= TIER0_64_GROUP_THRESHOLD) { + c.any_tier0_64 = true; + } + if (sz > TIER0_64_GROUP_THRESHOLD && sz <= TIER2_GROUP_THRESHOLD) { + c.any_tier2 = true; + } + if (sz > TIER2_GROUP_THRESHOLD) c.any_above_t2 = true; + } + if (n_groups == 0) c.min_grp_size = 0; + + // use_tier0: Tier 0 kernel is worth running (at least one group small + // enough to benefit from the warp path). + c.use_tier0 = (c.min_grp_size <= TIER0_GROUP_THRESHOLD); + // any_above_t0: at least one group needs a non-Tier-0 kernel. + c.any_above_t0 = (c.max_grp_size > TIER0_GROUP_THRESHOLD); + // use_tier1: the fused smem-sort fast path (for groups > T0 but ≤ T1). + c.use_tier1 = c.any_above_t0 && (c.max_grp_size <= TIER1_GROUP_THRESHOLD); + if (c.use_tier1) { + c.padded_grp_size = 1; + while (c.padded_grp_size < c.max_grp_size) c.padded_grp_size <<= 1; + c.tier1_tpb = std::min(c.padded_grp_size, MAX_THREADS_PER_BLOCK); + c.tier1_smem = (size_t)c.padded_grp_size * sizeof(float) + + WARP_REDUCE_BUF * sizeof(double); + } + return c; +} + +static std::vector make_sort_group_ids(const int* h_grp_offsets, + int n_groups, int skip_n_grp_le) { + std::vector ids; + ids.reserve(n_groups); + for (int g = 0; g < n_groups; ++g) { + int sz = h_grp_offsets[g + 1] - h_grp_offsets[g]; + if (skip_n_grp_le > 0 && sz <= skip_n_grp_le) continue; + ids.push_back(g); + } + return ids; +} + +// Tier 0 kernel launcher: 8 warps × 32 threads per block, one (col, group) +// pair per warp. grid.y covers ceil(K/8) pair rows. +static inline void launch_tier0(const float* ref_sorted, const float* grp_dense, + const int* grp_offsets, + const double* ref_tie_sums, double* rank_sums, + double* tie_corr, int n_ref, int n_all_grp, + int sb_cols, int K, bool compute_tie_corr, + cudaStream_t stream) { + constexpr int WARPS_PER_BLOCK = 8; + dim3 grid(sb_cols, (K + WARPS_PER_BLOCK - 1) / WARPS_PER_BLOCK); + ovo_warp_sort_rank_kernel<<>>( + ref_sorted, grp_dense, grp_offsets, ref_tie_sums, rank_sums, tie_corr, + n_ref, n_all_grp, sb_cols, K, compute_tie_corr); + CUDA_CHECK_LAST_ERROR(ovo_warp_sort_rank_kernel); +} + +static inline void launch_ref_tie_sums(const float* ref_sorted, + double* ref_tie_sums, int n_ref, + int sb_cols, cudaStream_t stream) { + ref_tie_sum_kernel<<>>( + ref_sorted, ref_tie_sums, n_ref, sb_cols); + CUDA_CHECK_LAST_ERROR(ref_tie_sum_kernel); +} + +static inline void launch_tier0_64( + const float* ref_sorted, const float* grp_dense, const int* grp_offsets, + const double* ref_tie_sums, double* rank_sums, double* tie_corr, int n_ref, + int n_all_grp, int sb_cols, int K, bool compute_tie_corr, int skip_n_grp_le, + cudaStream_t stream) { + dim3 grid(sb_cols, K); + ovo_small64_sort_rank_kernel<<>>( + ref_sorted, grp_dense, grp_offsets, ref_tie_sums, rank_sums, tie_corr, + n_ref, n_all_grp, sb_cols, K, compute_tie_corr, skip_n_grp_le); + CUDA_CHECK_LAST_ERROR(ovo_small64_sort_rank_kernel); +} + +static inline void launch_tier2_medium( + const float* ref_sorted, const float* grp_dense, const int* grp_offsets, + const double* ref_tie_sums, double* rank_sums, double* tie_corr, int n_ref, + int n_all_grp, int sb_cols, int K, bool compute_tie_corr, int skip_n_grp_le, + cudaStream_t stream) { + constexpr int tpb = 256; + size_t smem = (size_t)TIER2_GROUP_THRESHOLD * sizeof(float) + + WARP_REDUCE_BUF * sizeof(double); + dim3 grid(sb_cols, K); + ovo_medium_unsorted_rank_kernel<<>>( + ref_sorted, grp_dense, grp_offsets, ref_tie_sums, rank_sums, tie_corr, + n_ref, n_all_grp, sb_cols, K, compute_tie_corr, skip_n_grp_le, + TIER2_GROUP_THRESHOLD); + CUDA_CHECK_LAST_ERROR(ovo_medium_unsorted_rank_kernel); +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh new file mode 100644 index 00000000..006002b9 --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh @@ -0,0 +1,104 @@ +#pragma once + +/** Count nonzeros per column from CSR. One thread per row. */ +template +__global__ void csr_col_histogram_kernel(const IndexT* __restrict__ indices, + const IndptrT* __restrict__ indptr, + int* __restrict__ col_counts, + int n_rows, int n_cols) { + int row = blockIdx.x * blockDim.x + threadIdx.x; + if (row >= n_rows) return; + IndptrT rs = indptr[row]; + IndptrT re = indptr[row + 1]; + for (IndptrT p = rs; p < re; ++p) { + int c = (int)indices[p]; + if (c < n_cols) atomicAdd(&col_counts[c], 1); + } +} + +/** + * Scatter CSR nonzeros into CSC layout for columns [col_start, col_stop). + * write_pos[c - col_start] must be initialized to the prefix-sum offset + * for column c. Each thread atomically claims a unique destination slot. + */ +template +__global__ void csr_scatter_to_csc_kernel( + const InT* __restrict__ data, const IndexT* __restrict__ indices, + const IndptrT* __restrict__ indptr, int* __restrict__ write_pos, + InT* __restrict__ csc_vals, int* __restrict__ csc_row_idx, int n_rows, + int col_start, int col_stop) { + int row = blockIdx.x * blockDim.x + threadIdx.x; + if (row >= n_rows) return; + IndptrT rs = indptr[row]; + IndptrT re = indptr[row + 1]; + // Binary search for col_start (overflow-safe midpoint) + IndptrT lo = rs, hi = re; + while (lo < hi) { + IndptrT m = lo + ((hi - lo) >> 1); + if (indices[m] < col_start) + lo = m + 1; + else + hi = m; + } + for (IndptrT p = lo; p < re; ++p) { + int c = (int)indices[p]; + if (c >= col_stop) break; + int dest = atomicAdd(&write_pos[c - col_start], 1); + csc_vals[dest] = data[p]; + csc_row_idx[dest] = row; + } +} + +/** + * Decide whether to use shared or global memory for OVR rank accumulators. + * Returns the smem size to request and sets use_gmem accordingly. + */ +static int query_max_smem_per_block() { + static int cached = -1; + if (cached < 0) { + int device; + cudaGetDevice(&device); + cudaDeviceGetAttribute(&cached, cudaDevAttrMaxSharedMemoryPerBlock, + device); + } + return cached; +} + +static size_t ovr_smem_config(int n_groups, bool& use_gmem) { + size_t need = (size_t)(n_groups + 32) * sizeof(double); + if ((int)need <= query_max_smem_per_block()) { + use_gmem = false; + return need; + } + // Fall back to global memory accumulators; only need warp buf in smem + use_gmem = true; + return 32 * sizeof(double); +} + +/** + * Decide smem-vs-gmem for the sparse OVR rank kernel. Two accumulator + * arrays (grp_sums + grp_nz_count) of size n_groups each plus warp buf. + */ +static size_t sparse_ovr_smem_config(int n_groups, bool& use_gmem) { + size_t need = (size_t)(2 * n_groups + 32) * sizeof(double); + if ((int)need <= query_max_smem_per_block()) { + use_gmem = false; + return need; + } + use_gmem = true; + return 32 * sizeof(double); +} + +/** + * Fill sort values with row indices [0,1,...,n_rows-1] per column. + * Grid: (n_cols,), block: 256 threads. + */ +__global__ void fill_row_indices_kernel(int* __restrict__ vals, int n_rows, + int n_cols) { + int col = blockIdx.x; + if (col >= n_cols) return; + int* out = vals + (long long)col * n_rows; + for (int i = threadIdx.x; i < n_rows; i += blockDim.x) { + out[i] = i; + } +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh new file mode 100644 index 00000000..0f74a2c8 --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh @@ -0,0 +1,861 @@ +#pragma once + +/** + * Sparse-aware host-streaming CSC OVR pipeline. + * + * Like ovr_streaming_csc_host_impl but sorts only stored nonzeros per column + * instead of extracting dense blocks. GPU memory is O(max_batch_nnz) instead + * of O(sub_batch * n_rows), and sort work is proportional to nnz, not n_rows. + */ +template +static void ovr_sparse_csc_host_streaming_impl( + const InT* h_data, const int* h_indices, const IndptrT* h_indptr, + const int* h_group_codes, const double* h_group_sizes, double* d_rank_sums, + double* d_tie_corr, double* d_group_sums, double* d_group_sq_sums, + double* d_group_nnz, int n_rows, int n_cols, int n_groups, + bool compute_tie_corr, bool compute_sq_sums, bool compute_nnz, + int sub_batch_cols) { + if (n_rows == 0 || n_cols == 0) return; + + int n_streams = N_STREAMS; + if (n_cols < n_streams * sub_batch_cols) + n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + + // Find max nnz across any sub-batch + size_t max_nnz = 0; + for (int col = 0; col < n_cols; col += sub_batch_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + size_t nnz = (size_t)(h_indptr[col + sb_cols] - h_indptr[col]); + if (nnz > max_nnz) max_nnz = nnz; + } + + // CUB temp size for max_nnz items + size_t cub_temp_bytes = 0; + if (max_nnz > 0) { + auto* fk = reinterpret_cast(1); + auto* iv = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortPairs( + nullptr, cub_temp_bytes, fk, fk, iv, iv, (int)max_nnz, + sub_batch_cols, iv, iv + 1, BEGIN_BIT, END_BIT); + } + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + + RmmPool pool; + int* d_group_codes = pool.alloc(n_rows); + double* d_group_sizes = pool.alloc(n_groups); + struct StreamBuf { + InT* d_sparse_data_orig; + float* d_sparse_data_f32; + int* d_sparse_indices; + int* d_seg_offsets; + float* keys_out; + int* vals_out; + uint8_t* cub_temp; + double* d_rank_sums; + double* d_tie_corr; + double* d_group_sums; + double* d_group_sq_sums; + double* d_group_nnz; + double* d_nz_scratch; // gmem-only; non-null when rank_use_gmem + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + bufs[s].d_sparse_data_orig = pool.alloc(max_nnz); + bufs[s].d_sparse_data_f32 = pool.alloc(max_nnz); + bufs[s].d_sparse_indices = pool.alloc(max_nnz); + bufs[s].d_seg_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].keys_out = pool.alloc(max_nnz); + bufs[s].vals_out = pool.alloc(max_nnz); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].d_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].d_tie_corr = pool.alloc(sub_batch_cols); + bufs[s].d_group_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].d_group_sq_sums = + compute_sq_sums + ? pool.alloc((size_t)n_groups * sub_batch_cols) + : nullptr; + bufs[s].d_group_nnz = + compute_nnz ? pool.alloc((size_t)n_groups * sub_batch_cols) + : nullptr; + } + + // Transfer group codes + sizes once + cudaMemcpy(d_group_codes, h_group_codes, n_rows * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(d_group_sizes, h_group_sizes, n_groups * sizeof(double), + cudaMemcpyHostToDevice); + + // Pre-compute rebased per-batch offsets and upload once (avoids per-batch + // H2D copy from a transient host buffer). + int n_batches = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + std::vector h_all_offsets((size_t)n_batches * (sub_batch_cols + 1), 0); + for (int b = 0; b < n_batches; b++) { + int col_start = b * sub_batch_cols; + int sb = std::min(sub_batch_cols, n_cols - col_start); + IndptrT ptr_start = h_indptr[col_start]; + int* off = &h_all_offsets[(size_t)b * (sub_batch_cols + 1)]; + for (int i = 0; i <= sb; i++) + off[i] = (int)(h_indptr[col_start + i] - ptr_start); + } + int* d_all_offsets = + pool.alloc((size_t)n_batches * (sub_batch_cols + 1)); + cudaMemcpy(d_all_offsets, h_all_offsets.data(), + h_all_offsets.size() * sizeof(int), cudaMemcpyHostToDevice); + + int tpb = UTIL_BLOCK_SIZE; + bool rank_use_gmem = false; + size_t smem_bytes = sparse_ovr_smem_config(n_groups, rank_use_gmem); + bool cast_use_gmem = false; + size_t smem_cast = cast_accumulate_smem_config(n_groups, compute_sq_sums, + compute_nnz, cast_use_gmem); + + // In gmem mode the sparse rank kernel accumulates into rank_sums directly + // and needs a per-stream nz_count scratch buffer sized (n_groups, sb_cols). + for (int s = 0; s < n_streams; s++) { + if (rank_use_gmem) { + bufs[s].d_nz_scratch = + pool.alloc((size_t)n_groups * sub_batch_cols); + } else { + bufs[s].d_nz_scratch = nullptr; + } + } + + // Pin only the host input arrays; outputs live on the device. + size_t total_nnz = (size_t)h_indptr[n_cols]; + HostRegisterGuard _pin_data(const_cast(h_data), + total_nnz * sizeof(InT)); + HostRegisterGuard _pin_indices(const_cast(h_indices), + total_nnz * sizeof(int)); + + cudaDeviceSynchronize(); + + int col = 0; + int batch_idx = 0; + while (col < n_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int s = batch_idx % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + + IndptrT ptr_start = h_indptr[col]; + IndptrT ptr_end = h_indptr[col + sb_cols]; + int batch_nnz = (int)(ptr_end - ptr_start); + + // H2D: transfer sparse data for this column range (native dtype) + if (batch_nnz > 0) { + cudaMemcpyAsync(buf.d_sparse_data_orig, h_data + ptr_start, + (size_t)batch_nnz * sizeof(InT), + cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(buf.d_sparse_indices, h_indices + ptr_start, + (size_t)batch_nnz * sizeof(int), + cudaMemcpyHostToDevice, stream); + } + + // D2D: copy this batch's rebased offsets from the pre-uploaded buffer + int* src = d_all_offsets + (size_t)batch_idx * (sub_batch_cols + 1); + cudaMemcpyAsync(buf.d_seg_offsets, src, (sb_cols + 1) * sizeof(int), + cudaMemcpyDeviceToDevice, stream); + + // Cast to float32 for sort + accumulate stats in float64 + launch_ovr_cast_and_accumulate_sparse( + buf.d_sparse_data_orig, buf.d_sparse_data_f32, buf.d_sparse_indices, + buf.d_seg_offsets, d_group_codes, buf.d_group_sums, + buf.d_group_sq_sums, buf.d_group_nnz, sb_cols, n_groups, + compute_sq_sums, compute_nnz, tpb, smem_cast, cast_use_gmem, + stream); + + // CUB sort only stored nonzeros (float32 keys) + if (batch_nnz > 0) { + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortPairs( + buf.cub_temp, temp, buf.d_sparse_data_f32, buf.keys_out, + buf.d_sparse_indices, buf.vals_out, batch_nnz, sb_cols, + buf.d_seg_offsets, buf.d_seg_offsets + 1, BEGIN_BIT, END_BIT, + stream); + } + + // Sparse rank kernel (stats already captured above) + if (rank_use_gmem) { + cudaMemsetAsync(buf.d_rank_sums, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + cudaMemsetAsync(buf.d_nz_scratch, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + } + rank_sums_sparse_ovr_kernel<<>>( + buf.keys_out, buf.vals_out, buf.d_seg_offsets, d_group_codes, + d_group_sizes, buf.d_rank_sums, buf.d_tie_corr, buf.d_nz_scratch, + n_rows, sb_cols, n_groups, compute_tie_corr, rank_use_gmem); + CUDA_CHECK_LAST_ERROR(rank_sums_sparse_ovr_kernel); + + // D2D: scatter sub-batch results into caller's GPU buffers + cudaMemcpy2DAsync(d_rank_sums + col, n_cols * sizeof(double), + buf.d_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpyAsync(d_tie_corr + col, buf.d_tie_corr, + sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, + stream); + } + cudaMemcpy2DAsync(d_group_sums + col, n_cols * sizeof(double), + buf.d_group_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_sq_sums) { + cudaMemcpy2DAsync(d_group_sq_sums + col, n_cols * sizeof(double), + buf.d_group_sq_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + } + if (compute_nnz) { + cudaMemcpy2DAsync(d_group_nnz + col, n_cols * sizeof(double), + buf.d_group_nnz, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + } + + col += sb_cols; + batch_idx++; + } + + for (int s = 0; s < n_streams; s++) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) + throw std::runtime_error( + std::string("CUDA error in sparse host CSC streaming: ") + + cudaGetErrorString(err)); + } + + for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); +} + +// ============================================================================ +// Sparse-aware host-streaming CSR OVR pipeline. +// ============================================================================ + +/** + * Host CSR variant of the sparse OVR stream. + * + * The CSR input stays in host memory. We count columns once on the CPU, then + * use mapped pinned CSR arrays for bounded per-column-batch CSR->CSC scatter + * on the GPU. This avoids both a full host->device sparse upload and any + * whole-matrix CSR->CSC conversion. + */ +template +static void ovr_sparse_csr_host_streaming_impl( + const InT* h_data, const IndexT* h_indices, const IndptrT* h_indptr, + const int* h_group_codes, const double* h_group_sizes, double* d_rank_sums, + double* d_tie_corr, double* d_group_sums, double* d_group_sq_sums, + double* d_group_nnz, int n_rows, int n_cols, int n_groups, + bool compute_tie_corr, bool compute_sq_sums, bool compute_nnz, + int sub_batch_cols) { + if (n_rows == 0 || n_cols == 0) return; + + RmmPool pool; + size_t total_nnz = (size_t)h_indptr[n_rows]; + + // ---- Phase 0: CPU planning in native CSR order ---- + std::vector h_col_counts(n_cols, 0); + for (int row = 0; row < n_rows; row++) { + IndptrT rs = h_indptr[row]; + IndptrT re = h_indptr[row + 1]; + for (IndptrT p = rs; p < re; ++p) { + int c = (int)h_indices[p]; + if (c >= 0 && c < n_cols) h_col_counts[c]++; + } + } + + int n_batches = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + size_t max_batch_nnz = 0; + std::vector h_all_offsets((size_t)n_batches * (sub_batch_cols + 1), 0); + std::vector h_batch_nnz(n_batches); + for (int b = 0; b < n_batches; b++) { + int col_start = b * sub_batch_cols; + int sb_cols = std::min(sub_batch_cols, n_cols - col_start); + int* off = &h_all_offsets[(size_t)b * (sub_batch_cols + 1)]; + for (int i = 0; i < sb_cols; i++) + off[i + 1] = off[i] + h_col_counts[col_start + i]; + h_batch_nnz[b] = (size_t)off[sb_cols]; + if (h_batch_nnz[b] > max_batch_nnz) max_batch_nnz = h_batch_nnz[b]; + } + + int* d_all_offsets = + pool.alloc((size_t)n_batches * (sub_batch_cols + 1)); + cudaMemcpy(d_all_offsets, h_all_offsets.data(), + h_all_offsets.size() * sizeof(int), cudaMemcpyHostToDevice); + + // ---- Phase 1: allocate per-stream bounded work buffers ---- + size_t cub_temp_bytes = 0; + if (max_batch_nnz > 0) { + auto* fk = reinterpret_cast(1); + auto* iv = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortPairs( + nullptr, cub_temp_bytes, fk, fk, iv, iv, (int)max_batch_nnz, + sub_batch_cols, iv, iv + 1, BEGIN_BIT, END_BIT); + } + + int tpb = UTIL_BLOCK_SIZE; + bool rank_use_gmem = false; + size_t smem_bytes = sparse_ovr_smem_config(n_groups, rank_use_gmem); + bool cast_use_gmem = false; + size_t smem_cast = cast_accumulate_smem_config(n_groups, compute_sq_sums, + compute_nnz, cast_use_gmem); + + int n_streams = N_STREAMS; + if (n_batches < n_streams) n_streams = n_batches; + + size_t per_stream_bytes = + max_batch_nnz * (sizeof(InT) + sizeof(float) + 2 * sizeof(int)) + + (sub_batch_cols + 1 + sub_batch_cols) * sizeof(int) + cub_temp_bytes + + 2 * (size_t)n_groups * sub_batch_cols * sizeof(double) + + sub_batch_cols * sizeof(double); + if (compute_sq_sums) { + per_stream_bytes += (size_t)n_groups * sub_batch_cols * sizeof(double); + } + if (compute_nnz) { + per_stream_bytes += (size_t)n_groups * sub_batch_cols * sizeof(double); + } + if (rank_use_gmem) { + per_stream_bytes += (size_t)n_groups * sub_batch_cols * sizeof(double); + } + + size_t free_mem = 0, total_mem = 0; + cudaMemGetInfo(&free_mem, &total_mem); + constexpr double MEM_BUDGET_FRAC = 0.8; + size_t budget = (size_t)(free_mem * MEM_BUDGET_FRAC); + while (n_streams > 1 && (size_t)n_streams * per_stream_bytes > budget) + n_streams--; + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + + // Pin the source CSR arrays as mapped memory. The scatter kernel reads + // only the requested column window from each row. + HostRegisterGuard pin_data; + HostRegisterGuard pin_indices; + InT* d_data_zc = nullptr; + IndexT* d_indices_zc = nullptr; + if (total_nnz > 0) { + pin_data = + HostRegisterGuard(const_cast(h_data), total_nnz * sizeof(InT), + cudaHostRegisterMapped); + pin_indices = HostRegisterGuard(const_cast(h_indices), + total_nnz * sizeof(IndexT), + cudaHostRegisterMapped); + cudaError_t e1 = cudaHostGetDevicePointer((void**)&d_data_zc, + const_cast(h_data), 0); + cudaError_t e2 = cudaHostGetDevicePointer( + (void**)&d_indices_zc, const_cast(h_indices), 0); + if (e1 != cudaSuccess || e2 != cudaSuccess) { + throw std::runtime_error( + std::string("cudaHostGetDevicePointer failed: ") + + cudaGetErrorString(e1 != cudaSuccess ? e1 : e2)); + } + } + + IndptrT* d_indptr_full = pool.alloc(n_rows + 1); + cudaMemcpy(d_indptr_full, h_indptr, (n_rows + 1) * sizeof(IndptrT), + cudaMemcpyHostToDevice); + + int* d_group_codes = pool.alloc(n_rows); + double* d_group_sizes = pool.alloc(n_groups); + cudaMemcpy(d_group_codes, h_group_codes, n_rows * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(d_group_sizes, h_group_sizes, n_groups * sizeof(double), + cudaMemcpyHostToDevice); + + int scatter_blocks = (n_rows + tpb - 1) / tpb; + + struct StreamBuf { + int* col_offsets; + int* write_pos; + InT* csc_vals_orig; + float* csc_vals_f32; + int* csc_row_idx; + float* keys_out; + int* vals_out; + uint8_t* cub_temp; + double* sub_rank_sums; + double* sub_tie_corr; + double* sub_group_sums; + double* sub_group_sq_sums; + double* sub_group_nnz; + double* d_nz_scratch; + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + bufs[s].col_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].write_pos = pool.alloc(sub_batch_cols); + bufs[s].csc_vals_orig = pool.alloc(max_batch_nnz); + bufs[s].csc_vals_f32 = pool.alloc(max_batch_nnz); + bufs[s].csc_row_idx = pool.alloc(max_batch_nnz); + bufs[s].keys_out = pool.alloc(max_batch_nnz); + bufs[s].vals_out = pool.alloc(max_batch_nnz); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].sub_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].sub_tie_corr = pool.alloc(sub_batch_cols); + bufs[s].sub_group_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].sub_group_sq_sums = + compute_sq_sums + ? pool.alloc((size_t)n_groups * sub_batch_cols) + : nullptr; + bufs[s].sub_group_nnz = + compute_nnz ? pool.alloc((size_t)n_groups * sub_batch_cols) + : nullptr; + bufs[s].d_nz_scratch = + rank_use_gmem + ? pool.alloc((size_t)n_groups * sub_batch_cols) + : nullptr; + } + + cudaDeviceSynchronize(); + + // ---- Phase 2: bounded CSR->CSC scatter + GPU rank batches ---- + int col = 0; + for (int b = 0; b < n_batches; b++) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int s = b % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + int batch_nnz = (int)h_batch_nnz[b]; + + int* src = d_all_offsets + (size_t)b * (sub_batch_cols + 1); + cudaMemcpyAsync(buf.col_offsets, src, (sb_cols + 1) * sizeof(int), + cudaMemcpyDeviceToDevice, stream); + cudaMemcpyAsync(buf.write_pos, src, sb_cols * sizeof(int), + cudaMemcpyDeviceToDevice, stream); + + if (batch_nnz > 0) { + csr_scatter_to_csc_kernel + <<>>( + d_data_zc, d_indices_zc, d_indptr_full, buf.write_pos, + buf.csc_vals_orig, buf.csc_row_idx, n_rows, col, + col + sb_cols); + CUDA_CHECK_LAST_ERROR(csr_scatter_to_csc_kernel); + } + + launch_ovr_cast_and_accumulate_sparse( + buf.csc_vals_orig, buf.csc_vals_f32, buf.csc_row_idx, + buf.col_offsets, d_group_codes, buf.sub_group_sums, + buf.sub_group_sq_sums, buf.sub_group_nnz, sb_cols, n_groups, + compute_sq_sums, compute_nnz, tpb, smem_cast, cast_use_gmem, + stream); + + if (batch_nnz > 0) { + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortPairs( + buf.cub_temp, temp, buf.csc_vals_f32, buf.keys_out, + buf.csc_row_idx, buf.vals_out, batch_nnz, sb_cols, + buf.col_offsets, buf.col_offsets + 1, BEGIN_BIT, END_BIT, + stream); + } + + if (rank_use_gmem) { + cudaMemsetAsync(buf.sub_rank_sums, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + cudaMemsetAsync(buf.d_nz_scratch, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + } + rank_sums_sparse_ovr_kernel<<>>( + buf.keys_out, buf.vals_out, buf.col_offsets, d_group_codes, + d_group_sizes, buf.sub_rank_sums, buf.sub_tie_corr, + buf.d_nz_scratch, n_rows, sb_cols, n_groups, compute_tie_corr, + rank_use_gmem); + CUDA_CHECK_LAST_ERROR(rank_sums_sparse_ovr_kernel); + + cudaMemcpy2DAsync(d_rank_sums + col, n_cols * sizeof(double), + buf.sub_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpyAsync(d_tie_corr + col, buf.sub_tie_corr, + sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, + stream); + } + cudaMemcpy2DAsync(d_group_sums + col, n_cols * sizeof(double), + buf.sub_group_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_sq_sums) { + cudaMemcpy2DAsync(d_group_sq_sums + col, n_cols * sizeof(double), + buf.sub_group_sq_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + } + if (compute_nnz) { + cudaMemcpy2DAsync(d_group_nnz + col, n_cols * sizeof(double), + buf.sub_group_nnz, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + } + + col += sb_cols; + } + + for (int s = 0; s < n_streams; s++) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) + throw std::runtime_error( + std::string("CUDA error in sparse host CSR streaming: ") + + cudaGetErrorString(err)); + } + + for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); +} + +// ============================================================================ +// Sparse-aware CSC OVR streaming (sort only stored nonzeros) +// ============================================================================ + +static void ovr_sparse_csc_streaming_impl( + const float* csc_data, const int* csc_indices, const int* csc_indptr, + const int* group_codes, const double* group_sizes, double* rank_sums, + double* tie_corr, int n_rows, int n_cols, int n_groups, + bool compute_tie_corr, int sub_batch_cols) { + if (n_rows == 0 || n_cols == 0) return; + + // Read indptr to host for batch planning + std::vector h_indptr(n_cols + 1); + cudaMemcpy(h_indptr.data(), csc_indptr, (n_cols + 1) * sizeof(int), + cudaMemcpyDeviceToHost); + + int n_streams = N_STREAMS; + if (n_cols < n_streams * sub_batch_cols) + n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + + // Find max nnz across any sub-batch for buffer sizing + size_t max_nnz = 0; + for (int col = 0; col < n_cols; col += sub_batch_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + size_t nnz = (size_t)(h_indptr[col + sb_cols] - h_indptr[col]); + if (nnz > max_nnz) max_nnz = nnz; + } + + // CUB temp size for max_nnz items + size_t cub_temp_bytes = 0; + if (max_nnz > 0) { + auto* fk = reinterpret_cast(1); + auto* iv = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortPairs( + nullptr, cub_temp_bytes, fk, fk, iv, iv, (int)max_nnz, + sub_batch_cols, iv, iv + 1, BEGIN_BIT, END_BIT); + } + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + + int tpb = UTIL_BLOCK_SIZE; + bool rank_use_gmem = false; + size_t smem_bytes = sparse_ovr_smem_config(n_groups, rank_use_gmem); + + RmmPool pool; + struct StreamBuf { + float* keys_out; + int* vals_out; + int* seg_offsets; + uint8_t* cub_temp; + double* sub_rank_sums; + double* sub_tie_corr; + double* d_nz_scratch; // gmem-only + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + bufs[s].keys_out = pool.alloc(max_nnz); + bufs[s].vals_out = pool.alloc(max_nnz); + bufs[s].seg_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].sub_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].sub_tie_corr = pool.alloc(sub_batch_cols); + bufs[s].d_nz_scratch = + rank_use_gmem + ? pool.alloc((size_t)n_groups * sub_batch_cols) + : nullptr; + } + + cudaDeviceSynchronize(); + + int col = 0; + int batch_idx = 0; + while (col < n_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int s = batch_idx % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + + int ptr_start = h_indptr[col]; + int ptr_end = h_indptr[col + sb_cols]; + int batch_nnz = ptr_end - ptr_start; + + // Compute rebased segment offsets on GPU (avoids host pinned-buffer + // race) + { + int count = sb_cols + 1; + int blk = (count + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + rebase_indptr_kernel<<>>( + csc_indptr, buf.seg_offsets, col, count); + CUDA_CHECK_LAST_ERROR(rebase_indptr_kernel); + } + + // Sort only stored values (keys=data, vals=row_indices) + if (batch_nnz > 0) { + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortPairs( + buf.cub_temp, temp, csc_data + ptr_start, buf.keys_out, + csc_indices + ptr_start, buf.vals_out, batch_nnz, sb_cols, + buf.seg_offsets, buf.seg_offsets + 1, BEGIN_BIT, END_BIT, + stream); + } + + // Sparse rank kernel (handles implicit zeros analytically) + if (rank_use_gmem) { + cudaMemsetAsync(buf.sub_rank_sums, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + cudaMemsetAsync(buf.d_nz_scratch, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + } + rank_sums_sparse_ovr_kernel<<>>( + buf.keys_out, buf.vals_out, buf.seg_offsets, group_codes, + group_sizes, buf.sub_rank_sums, buf.sub_tie_corr, buf.d_nz_scratch, + n_rows, sb_cols, n_groups, compute_tie_corr, rank_use_gmem); + CUDA_CHECK_LAST_ERROR(rank_sums_sparse_ovr_kernel); + + // Scatter results to global output + cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), + buf.sub_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpyAsync(tie_corr + col, buf.sub_tie_corr, + sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, + stream); + } + + col += sb_cols; + batch_idx++; + } + + for (int s = 0; s < n_streams; s++) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) + throw std::runtime_error( + std::string("CUDA error in sparse ovr streaming: ") + + cudaGetErrorString(err)); + } + + for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); +} + +// ============================================================================ +// Sparse-aware CSR OVR streaming (partial CSR→CSC transpose per sub-batch) +// ============================================================================ + +/** + * Sparse-aware OVR streaming pipeline for GPU CSR data. + * + * Phase 0: One histogram kernel counts nnz per column. D2H + host prefix sums + * give exact per-batch nnz and max_batch_nnz for buffer sizing. + * Phase 1: Allocate per-stream buffers sized to max_batch_nnz. + * Phase 2: For each sub-batch: scatter CSR→CSC (partial transpose via + * atomics) → CUB sort only nonzeros → sparse rank kernel. + * + * Compared to the dense CSR path, sort work drops by ~1/sparsity. + */ +static void ovr_sparse_csr_streaming_impl( + const float* csr_data, const int* csr_indices, const int* csr_indptr, + const int* group_codes, const double* group_sizes, double* rank_sums, + double* tie_corr, int n_rows, int n_cols, int n_groups, + bool compute_tie_corr, int sub_batch_cols) { + if (n_rows == 0 || n_cols == 0) return; + + // ---- Phase 0: Planning — count nnz per column via histogram ---- + RmmPool pool; + int* d_col_counts = pool.alloc(n_cols); + cudaMemset(d_col_counts, 0, n_cols * sizeof(int)); + { + int blocks = (n_rows + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + csr_col_histogram_kernel<<>>( + csr_indices, csr_indptr, d_col_counts, n_rows, n_cols); + CUDA_CHECK_LAST_ERROR(csr_col_histogram_kernel); + } + std::vector h_col_counts(n_cols); + cudaMemcpy(h_col_counts.data(), d_col_counts, n_cols * sizeof(int), + cudaMemcpyDeviceToHost); + + // Per-batch prefix sums on host + int n_batches = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + size_t max_batch_nnz = 0; + + // Flat array: n_batches × (sub_batch_cols + 1) offsets + std::vector h_all_offsets((size_t)n_batches * (sub_batch_cols + 1), 0); + std::vector h_batch_nnz(n_batches); + + for (int b = 0; b < n_batches; b++) { + int col_start = b * sub_batch_cols; + int sb_cols = std::min(sub_batch_cols, n_cols - col_start); + int* off = &h_all_offsets[(size_t)b * (sub_batch_cols + 1)]; + off[0] = 0; + for (int i = 0; i < sb_cols; i++) + off[i + 1] = off[i] + h_col_counts[col_start + i]; + h_batch_nnz[b] = (size_t)off[sb_cols]; + if (h_batch_nnz[b] > max_batch_nnz) max_batch_nnz = h_batch_nnz[b]; + } + + // Upload all batch offsets to GPU in one shot (~20 KB) + int* d_all_offsets = + pool.alloc((size_t)n_batches * (sub_batch_cols + 1)); + cudaMemcpy(d_all_offsets, h_all_offsets.data(), + h_all_offsets.size() * sizeof(int), cudaMemcpyHostToDevice); + + // ---- Phase 1: Allocate per-stream buffers ---- + size_t cub_temp_bytes = 0; + if (max_batch_nnz > 0) { + auto* fk = reinterpret_cast(1); + auto* iv = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortPairs( + nullptr, cub_temp_bytes, fk, fk, iv, iv, (int)max_batch_nnz, + sub_batch_cols, iv, iv + 1, BEGIN_BIT, END_BIT); + } + + int n_streams = N_STREAMS; + if (n_batches < n_streams) n_streams = n_batches; + + // CSR path needs 4 sort arrays per stream (scatter intermediates + + // CUB output). Fit stream count to available GPU memory. + size_t per_stream_bytes = + max_batch_nnz * (2 * sizeof(float) + 2 * sizeof(int)) + + (sub_batch_cols + 1 + sub_batch_cols) * sizeof(int) + cub_temp_bytes + + (size_t)n_groups * sub_batch_cols * sizeof(double) + + sub_batch_cols * sizeof(double); + + size_t free_mem = 0, total_mem = 0; + cudaMemGetInfo(&free_mem, &total_mem); + constexpr double MEM_BUDGET_FRAC = 0.8; + size_t budget = (size_t)(free_mem * MEM_BUDGET_FRAC); + while (n_streams > 1 && (size_t)n_streams * per_stream_bytes > budget) + n_streams--; + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + + int tpb = UTIL_BLOCK_SIZE; + bool rank_use_gmem = false; + size_t smem_bytes = sparse_ovr_smem_config(n_groups, rank_use_gmem); + int scatter_blocks = (n_rows + tpb - 1) / tpb; + + struct StreamBuf { + int* col_offsets; // [sub_batch_cols + 1] CSC-style offsets + int* write_pos; // [sub_batch_cols] atomic write counters + float* csc_vals; // [max_batch_nnz] transposed values + int* csc_row_idx; // [max_batch_nnz] transposed row indices + float* keys_out; // [max_batch_nnz] CUB sort output + int* vals_out; // [max_batch_nnz] CUB sort output + uint8_t* cub_temp; + double* sub_rank_sums; + double* sub_tie_corr; + double* d_nz_scratch; // gmem-only + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + bufs[s].col_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].write_pos = pool.alloc(sub_batch_cols); + bufs[s].csc_vals = pool.alloc(max_batch_nnz); + bufs[s].csc_row_idx = pool.alloc(max_batch_nnz); + bufs[s].keys_out = pool.alloc(max_batch_nnz); + bufs[s].vals_out = pool.alloc(max_batch_nnz); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].sub_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].sub_tie_corr = pool.alloc(sub_batch_cols); + bufs[s].d_nz_scratch = + rank_use_gmem + ? pool.alloc((size_t)n_groups * sub_batch_cols) + : nullptr; + } + + cudaDeviceSynchronize(); + + // ---- Phase 2: Stream loop ---- + int col = 0; + for (int b = 0; b < n_batches; b++) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int s = b % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + int batch_nnz = (int)h_batch_nnz[b]; + + // D2D copy pre-computed col_offsets for this batch + int* src = d_all_offsets + (size_t)b * (sub_batch_cols + 1); + cudaMemcpyAsync(buf.col_offsets, src, (sb_cols + 1) * sizeof(int), + cudaMemcpyDeviceToDevice, stream); + + // Initialize write_pos = col_offsets[0..sb_cols-1] (same D2D source) + cudaMemcpyAsync(buf.write_pos, src, sb_cols * sizeof(int), + cudaMemcpyDeviceToDevice, stream); + + if (batch_nnz > 0) { + // Scatter CSR → CSC layout for this sub-batch + csr_scatter_to_csc_kernel<<>>( + csr_data, csr_indices, csr_indptr, buf.write_pos, buf.csc_vals, + buf.csc_row_idx, n_rows, col, col + sb_cols); + CUDA_CHECK_LAST_ERROR(csr_scatter_to_csc_kernel); + + // CUB sort only the nonzeros + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortPairs( + buf.cub_temp, temp, buf.csc_vals, buf.keys_out, buf.csc_row_idx, + buf.vals_out, batch_nnz, sb_cols, buf.col_offsets, + buf.col_offsets + 1, BEGIN_BIT, END_BIT, stream); + } + + // Sparse rank kernel (handles implicit zeros analytically) + if (rank_use_gmem) { + cudaMemsetAsync(buf.sub_rank_sums, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + cudaMemsetAsync(buf.d_nz_scratch, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + } + rank_sums_sparse_ovr_kernel<<>>( + buf.keys_out, buf.vals_out, buf.col_offsets, group_codes, + group_sizes, buf.sub_rank_sums, buf.sub_tie_corr, buf.d_nz_scratch, + n_rows, sb_cols, n_groups, compute_tie_corr, rank_use_gmem); + CUDA_CHECK_LAST_ERROR(rank_sums_sparse_ovr_kernel); + + // Scatter results to global output + cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), + buf.sub_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpyAsync(tie_corr + col, buf.sub_tie_corr, + sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, + stream); + } + + col += sb_cols; + } + + for (int s = 0; s < n_streams; s++) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) + throw std::runtime_error( + std::string("CUDA error in sparse CSR ovr streaming: ") + + cudaGetErrorString(err)); + } + + for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu new file mode 100644 index 00000000..19f1ef57 --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu @@ -0,0 +1,292 @@ +#include +#include + +#include + +#include "../nb_types.h" +#include "wilcoxon_fast_common.cuh" +#include "wilcoxon_sparse_kernels.cuh" +#include "wilcoxon_ovr_kernels.cuh" +#include "wilcoxon_ovr_sparse.cuh" +#include "kernels_wilcoxon_ovo.cuh" +#include "wilcoxon_ovo_kernels.cuh" +#include "wilcoxon_ovo_device_sparse.cuh" +#include "wilcoxon_ovo_host_sparse.cuh" + +using namespace nb::literals; + +template +void register_sparse_bindings(nb::module_& m) { + m.doc() = "Sparse-native host Wilcoxon CUDA kernels"; + + m.def( + "ovr_sparse_csc_device", + [](gpu_array_c csc_data, + gpu_array_c csc_indices, + gpu_array_c csc_indptr, + gpu_array_c group_codes, + gpu_array_c group_sizes, + gpu_array_c rank_sums, + gpu_array_c tie_corr, int n_rows, int n_cols, + int n_groups, bool compute_tie_corr, int sub_batch_cols) { + ovr_sparse_csc_streaming_impl( + csc_data.data(), csc_indices.data(), csc_indptr.data(), + group_codes.data(), group_sizes.data(), rank_sums.data(), + tie_corr.data(), n_rows, n_cols, n_groups, compute_tie_corr, + sub_batch_cols); + }, + "csc_data"_a, "csc_indices"_a, "csc_indptr"_a, "group_codes"_a, + "group_sizes"_a, "rank_sums"_a, "tie_corr"_a, nb::kw_only(), "n_rows"_a, + "n_cols"_a, "n_groups"_a, "compute_tie_corr"_a, + "sub_batch_cols"_a = SUB_BATCH_COLS); + + m.def( + "ovr_sparse_csr_device", + [](gpu_array_c csr_data, + gpu_array_c csr_indices, + gpu_array_c csr_indptr, + gpu_array_c group_codes, + gpu_array_c group_sizes, + gpu_array_c rank_sums, + gpu_array_c tie_corr, int n_rows, int n_cols, + int n_groups, bool compute_tie_corr, int sub_batch_cols) { + ovr_sparse_csr_streaming_impl( + csr_data.data(), csr_indices.data(), csr_indptr.data(), + group_codes.data(), group_sizes.data(), rank_sums.data(), + tie_corr.data(), n_rows, n_cols, n_groups, compute_tie_corr, + sub_batch_cols); + }, + "csr_data"_a, "csr_indices"_a, "csr_indptr"_a, "group_codes"_a, + "group_sizes"_a, "rank_sums"_a, "tie_corr"_a, nb::kw_only(), "n_rows"_a, + "n_cols"_a, "n_groups"_a, "compute_tie_corr"_a, + "sub_batch_cols"_a = SUB_BATCH_COLS); + +#define RSC_OVR_SPARSE_CSC_HOST_BINDING(NAME, InT, IndptrT) \ + m.def( \ + NAME, \ + [](host_array h_data, host_array h_indices, \ + host_array h_indptr, \ + host_array h_group_codes, \ + host_array h_group_sizes, \ + gpu_array_c d_rank_sums, \ + gpu_array_c d_tie_corr, \ + gpu_array_c d_group_sums, \ + gpu_array_c d_group_sq_sums, \ + gpu_array_c d_group_nnz, int n_rows, int n_cols, \ + int n_groups, bool compute_tie_corr, bool compute_sq_sums, \ + bool compute_nnz, int sub_batch_cols) { \ + ovr_sparse_csc_host_streaming_impl( \ + h_data.data(), h_indices.data(), h_indptr.data(), \ + h_group_codes.data(), h_group_sizes.data(), \ + d_rank_sums.data(), d_tie_corr.data(), d_group_sums.data(), \ + d_group_sq_sums.data(), d_group_nnz.data(), n_rows, n_cols, \ + n_groups, compute_tie_corr, compute_sq_sums, compute_nnz, \ + sub_batch_cols); \ + }, \ + "h_data"_a, "h_indices"_a, "h_indptr"_a, "h_group_codes"_a, \ + "h_group_sizes"_a, "d_rank_sums"_a, "d_tie_corr"_a, "d_group_sums"_a, \ + "d_group_sq_sums"_a, "d_group_nnz"_a, nb::kw_only(), "n_rows"_a, \ + "n_cols"_a, "n_groups"_a, "compute_tie_corr"_a, \ + "compute_sq_sums"_a = true, "compute_nnz"_a = true, \ + "sub_batch_cols"_a = SUB_BATCH_COLS) + + RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host", float, int); + RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_i64", float, int64_t); + RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_f64", double, int); + RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_f64_i64", double, + int64_t); +#undef RSC_OVR_SPARSE_CSC_HOST_BINDING + +#define RSC_OVR_SPARSE_CSR_HOST_BINDING(NAME, InT, IndexT, IndptrT) \ + m.def( \ + NAME, \ + [](host_array h_data, host_array h_indices, \ + host_array h_indptr, \ + host_array h_group_codes, \ + host_array h_group_sizes, \ + gpu_array_c d_rank_sums, \ + gpu_array_c d_tie_corr, \ + gpu_array_c d_group_sums, \ + gpu_array_c d_group_sq_sums, \ + gpu_array_c d_group_nnz, int n_rows, int n_cols, \ + int n_groups, bool compute_tie_corr, bool compute_sq_sums, \ + bool compute_nnz, int sub_batch_cols) { \ + ovr_sparse_csr_host_streaming_impl( \ + h_data.data(), h_indices.data(), h_indptr.data(), \ + h_group_codes.data(), h_group_sizes.data(), \ + d_rank_sums.data(), d_tie_corr.data(), d_group_sums.data(), \ + d_group_sq_sums.data(), d_group_nnz.data(), n_rows, n_cols, \ + n_groups, compute_tie_corr, compute_sq_sums, compute_nnz, \ + sub_batch_cols); \ + }, \ + "h_data"_a, "h_indices"_a, "h_indptr"_a, "h_group_codes"_a, \ + "h_group_sizes"_a, "d_rank_sums"_a, "d_tie_corr"_a, "d_group_sums"_a, \ + "d_group_sq_sums"_a, "d_group_nnz"_a, nb::kw_only(), "n_rows"_a, \ + "n_cols"_a, "n_groups"_a, "compute_tie_corr"_a, \ + "compute_sq_sums"_a = true, "compute_nnz"_a = true, \ + "sub_batch_cols"_a = SUB_BATCH_COLS) + + RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host", float, int, int); + RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host_i64", float, int, + int64_t); + RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host_idx64", float, int64_t, + int); + RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host_idx64_i64", float, + int64_t, int64_t); + RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host_f64", double, int, + int); + RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host_f64_i64", double, int, + int64_t); + RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host_f64_idx64", double, + int64_t, int); + RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host_f64_idx64_i64", double, + int64_t, int64_t); +#undef RSC_OVR_SPARSE_CSR_HOST_BINDING + + m.def( + "ovo_streaming_csc_device", + [](gpu_array_c csc_data, + gpu_array_c csc_indices, + gpu_array_c csc_indptr, + gpu_array_c ref_row_map, + gpu_array_c grp_row_map, + gpu_array_c grp_offsets, + gpu_array_c rank_sums, + gpu_array_c tie_corr, int n_ref, int n_all_grp, + int n_cols, int n_groups, bool compute_tie_corr, + int sub_batch_cols) { + ovo_streaming_csc_impl( + csc_data.data(), csc_indices.data(), csc_indptr.data(), + ref_row_map.data(), grp_row_map.data(), grp_offsets.data(), + rank_sums.data(), tie_corr.data(), n_ref, n_all_grp, n_cols, + n_groups, compute_tie_corr, sub_batch_cols); + }, + "csc_data"_a, "csc_indices"_a, "csc_indptr"_a, "ref_row_map"_a, + "grp_row_map"_a, "grp_offsets"_a, "rank_sums"_a, "tie_corr"_a, + nb::kw_only(), "n_ref"_a, "n_all_grp"_a, "n_cols"_a, "n_groups"_a, + "compute_tie_corr"_a, "sub_batch_cols"_a = SUB_BATCH_COLS); + + m.def( + "ovo_streaming_csr_device", + [](gpu_array_c csr_data, + gpu_array_c csr_indices, + gpu_array_c csr_indptr, + gpu_array_c ref_row_ids, + gpu_array_c grp_row_ids, + gpu_array_c grp_offsets, + gpu_array_c rank_sums, + gpu_array_c tie_corr, int n_ref, int n_all_grp, + int n_cols, int n_groups, bool compute_tie_corr, + int sub_batch_cols) { + ovo_streaming_csr_impl( + csr_data.data(), csr_indices.data(), csr_indptr.data(), + ref_row_ids.data(), grp_row_ids.data(), grp_offsets.data(), + rank_sums.data(), tie_corr.data(), n_ref, n_all_grp, n_cols, + n_groups, compute_tie_corr, sub_batch_cols); + }, + "csr_data"_a, "csr_indices"_a, "csr_indptr"_a, "ref_row_ids"_a, + "grp_row_ids"_a, "grp_offsets"_a, "rank_sums"_a, "tie_corr"_a, + nb::kw_only(), "n_ref"_a, "n_all_grp"_a, "n_cols"_a, "n_groups"_a, + "compute_tie_corr"_a, "sub_batch_cols"_a = SUB_BATCH_COLS); + +#define RSC_OVO_CSC_HOST_BINDING(NAME, InT, IndexT, IndptrT) \ + m.def( \ + NAME, \ + [](host_array h_data, host_array h_indices, \ + host_array h_indptr, \ + host_array h_ref_row_map, \ + host_array h_grp_row_map, \ + host_array h_grp_offsets, \ + host_array h_stats_codes, \ + gpu_array_c d_rank_sums, \ + gpu_array_c d_tie_corr, \ + gpu_array_c d_group_sums, \ + gpu_array_c d_group_sq_sums, \ + gpu_array_c d_group_nnz, int n_ref, int n_all_grp, \ + int n_rows, int n_cols, int n_groups, int n_groups_stats, \ + bool compute_tie_corr, bool compute_sq_sums, bool compute_nnz, \ + int sub_batch_cols) { \ + ovo_streaming_csc_host_impl( \ + h_data.data(), h_indices.data(), h_indptr.data(), \ + h_ref_row_map.data(), h_grp_row_map.data(), \ + h_grp_offsets.data(), h_stats_codes.data(), \ + d_rank_sums.data(), d_tie_corr.data(), d_group_sums.data(), \ + d_group_sq_sums.data(), d_group_nnz.data(), n_ref, n_all_grp, \ + n_rows, n_cols, n_groups, n_groups_stats, compute_tie_corr, \ + compute_sq_sums, compute_nnz, sub_batch_cols); \ + }, \ + "h_data"_a, "h_indices"_a, "h_indptr"_a, "h_ref_row_map"_a, \ + "h_grp_row_map"_a, "h_grp_offsets"_a, "h_stats_codes"_a, \ + "d_rank_sums"_a, "d_tie_corr"_a, "d_group_sums"_a, \ + "d_group_sq_sums"_a, "d_group_nnz"_a, nb::kw_only(), "n_ref"_a, \ + "n_all_grp"_a, "n_rows"_a, "n_cols"_a, "n_groups"_a, \ + "n_groups_stats"_a, "compute_tie_corr"_a, "compute_sq_sums"_a = true, \ + "compute_nnz"_a = true, "sub_batch_cols"_a = SUB_BATCH_COLS) + + RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host", float, int, int); + RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_i64", float, int, int64_t); + RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_idx64", float, int64_t, + int); + RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_idx64_i64", float, int64_t, + int64_t); + RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_f64", double, int, int); + RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_f64_i64", double, int, + int64_t); + RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_f64_idx64", double, + int64_t, int); + RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_f64_idx64_i64", double, + int64_t, int64_t); +#undef RSC_OVO_CSC_HOST_BINDING + +#define RSC_OVO_CSR_HOST_BINDING(NAME, InT, IndexT, IndptrT) \ + m.def( \ + NAME, \ + [](host_array h_data, host_array h_indices, \ + host_array h_indptr, \ + host_array h_ref_row_ids, \ + host_array h_grp_row_ids, \ + host_array h_grp_offsets, \ + gpu_array_c d_rank_sums, \ + gpu_array_c d_tie_corr, \ + gpu_array_c d_group_sums, \ + gpu_array_c d_group_sq_sums, \ + gpu_array_c d_group_nnz, int n_full_rows, \ + int n_ref, int n_all_grp, int n_cols, int n_test, \ + int n_groups_stats, bool compute_tie_corr, bool compute_sq_sums, \ + bool compute_nnz, bool compute_sums, int sub_batch_cols) { \ + ovo_streaming_csr_host_impl( \ + h_data.data(), h_indices.data(), h_indptr.data(), n_full_rows, \ + h_ref_row_ids.data(), n_ref, h_grp_row_ids.data(), \ + h_grp_offsets.data(), n_all_grp, n_test, d_rank_sums.data(), \ + d_tie_corr.data(), d_group_sums.data(), \ + d_group_sq_sums.data(), d_group_nnz.data(), n_cols, \ + n_groups_stats, compute_tie_corr, compute_sq_sums, \ + compute_nnz, compute_sums, sub_batch_cols); \ + }, \ + "h_data"_a, "h_indices"_a, "h_indptr"_a, "h_ref_row_ids"_a, \ + "h_grp_row_ids"_a, "h_grp_offsets"_a, "d_rank_sums"_a, "d_tie_corr"_a, \ + "d_group_sums"_a, "d_group_sq_sums"_a, "d_group_nnz"_a, nb::kw_only(), \ + "n_full_rows"_a, "n_ref"_a, "n_all_grp"_a, "n_cols"_a, "n_test"_a, \ + "n_groups_stats"_a, "compute_tie_corr"_a, "compute_sq_sums"_a = true, \ + "compute_nnz"_a = true, "compute_sums"_a = true, \ + "sub_batch_cols"_a = SUB_BATCH_COLS) + + RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host", float, int, int); + RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_i64", float, int, int64_t); + RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_idx64", float, int64_t, + int); + RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_idx64_i64", float, int64_t, + int64_t); + RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_f64", double, int, int); + RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_f64_i64", double, int, + int64_t); + RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_f64_idx64", double, + int64_t, int); + RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_f64_idx64_i64", double, + int64_t, int64_t); +#undef RSC_OVO_CSR_HOST_BINDING +} + +NB_MODULE(_wilcoxon_sparse_cuda, m) { + REGISTER_GPU_BINDINGS(register_sparse_bindings, m); +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh new file mode 100644 index 00000000..b0e40fdc --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh @@ -0,0 +1,651 @@ +#pragma once + +#include + +/** + * Fused rank-sum kernel: walk sorted data, compute per-group rank sums + * and tie correction without materializing a rank matrix. + * + * Each thread processes a CONTIGUOUS chunk of sorted elements, detecting + * tie groups by adjacent comparison (sequential access, no binary search). + * Cross-boundary ties are resolved via binary search at chunk boundaries. + * + * When use_gmem is false, per-group accumulators live in shared memory + * (fast atomics, limited to ~1500 groups on 48 KB devices). When use_gmem + * is true, accumulators write directly to ``rank_sums`` in global memory, + * supporting an arbitrary number of groups. The caller must pre-zero + * ``rank_sums`` before launching in the gmem path. + * + * Shared memory layout: + * use_gmem=false: (n_groups + 32) doubles (accumulators + warp buf) + * use_gmem=true: 32 doubles (warp buf only) + */ +__global__ void rank_sums_from_sorted_kernel( + const float* __restrict__ sorted_vals, + const int* __restrict__ sorted_row_idx, const int* __restrict__ group_codes, + double* __restrict__ rank_sums, double* __restrict__ tie_corr, int n_rows, + int n_cols, int n_groups, bool compute_tie_corr, bool use_gmem) { + int col = blockIdx.x; + if (col >= n_cols) return; + + extern __shared__ double smem[]; + + double* grp_sums; + if (use_gmem) { + // Global memory path: write directly to output (must be pre-zeroed) + grp_sums = rank_sums + (size_t)col; // stride: n_cols + } else { + // Shared memory path: per-block accumulators + grp_sums = smem; + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { + grp_sums[g] = 0.0; + } + __syncthreads(); + } + + const float* sv = sorted_vals + (size_t)col * n_rows; + const int* si = sorted_row_idx + (size_t)col * n_rows; + + int chunk = (n_rows + blockDim.x - 1) / blockDim.x; + int my_start = threadIdx.x * chunk; + int my_end = my_start + chunk; + if (my_end > n_rows) my_end = n_rows; + + double local_tie_sum = 0.0; + + // Stride for accumulator indexing: 1 for shared mem, n_cols for global mem + int acc_stride = use_gmem ? n_cols : 1; + + int i = my_start; + while (i < my_end) { + double val = sv[i]; + + int tie_local_end = i + 1; + while (tie_local_end < my_end && sv[tie_local_end] == val) + ++tie_local_end; + + int tie_global_start = i; + if (i == my_start && i > 0 && sv[i - 1] == val) { + int lo = 0, hi = i; + while (lo < hi) { + int mid = lo + (hi - lo) / 2; + if (sv[mid] < val) + lo = mid + 1; + else + hi = mid; + } + tie_global_start = lo; + } + + int tie_global_end = tie_local_end; + if (tie_local_end == my_end && tie_local_end < n_rows && + sv[tie_local_end] == val) { + int lo = tie_local_end, hi = n_rows - 1; + while (lo < hi) { + int mid = hi - ((hi - lo) >> 1); + if (sv[mid] > val) + hi = mid - 1; + else + lo = mid; + } + tie_global_end = lo + 1; + } + + int total_tie = tie_global_end - tie_global_start; + double avg_rank = (double)(tie_global_start + tie_global_end + 1) / 2.0; + + for (int j = i; j < tie_local_end; ++j) { + int grp = group_codes[si[j]]; + if (grp < n_groups) { + atomicAdd(&grp_sums[grp * acc_stride], avg_rank); + } + } + + if (compute_tie_corr && tie_global_start >= my_start && total_tie > 1) { + double t = (double)total_tie; + local_tie_sum += t * t * t - t; + } + + i = tie_local_end; + } + + __syncthreads(); + + // Copy shared memory accumulators to global output (smem path only) + if (!use_gmem) { + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { + rank_sums[(size_t)g * n_cols + col] = grp_sums[g]; + } + } + + if (compute_tie_corr) { + // Warp buf sits after accumulator array in shared memory. + // gmem path: warp buf starts at smem[0]. + // smem path: n_groups doubles, then warp buf. + int warp_buf_off = use_gmem ? 0 : n_groups; + double* warp_buf = smem + warp_buf_off; +#pragma unroll + for (int off = 16; off > 0; off >>= 1) + local_tie_sum += __shfl_down_sync(0xffffffff, local_tie_sum, off); + int lane = threadIdx.x & 31; + int wid = threadIdx.x >> 5; + if (lane == 0) warp_buf[wid] = local_tie_sum; + __syncthreads(); + if (threadIdx.x < 32) { + double val = (threadIdx.x < ((blockDim.x + 31) >> 5)) + ? warp_buf[threadIdx.x] + : 0.0; +#pragma unroll + for (int off = 16; off > 0; off >>= 1) + val += __shfl_down_sync(0xffffffff, val, off); + if (threadIdx.x == 0) { + double n = (double)n_rows; + double denom = n * n * n - n; + tie_corr[col] = (denom > 0.0) ? (1.0 - val / denom) : 1.0; + } + } + } +} + +/** + * Sparse-aware OVR rank-sum kernel for nonnegative sorted stored values. + * + * Sparse rank_genes_groups now rejects explicit negative sparse values before + * reaching CUDA, so after CUB sort each column segment is: + * [stored_zeros..., positives...] + * + * Implicit zeros (n_rows - nnz_stored) join stored zeros as the first tie + * block. The kernel ranks only stored positive values and adds each group's + * zero contribution analytically. + * + * Full sorted array (conceptual): + * [ALL_zeros (stored+implicit)..., positives...] + * + * Rank offsets: + * positive at stored pos i : full pos = i + n_implicit_zero + * zeros : avg rank = (total_zero + 1) / 2 + * + * Shared-memory layout (doubles): + * grp_sums[n_groups] rank-sum accumulators + * grp_nz_count[n_groups] nonzero-per-group counters + * warp_buf[32] tie-correction reduction scratch + * + * Grid: (sb_cols,) Block: (tpb,) + */ +__global__ void rank_sums_sparse_ovr_kernel( + const float* __restrict__ sorted_vals, + const int* __restrict__ sorted_row_idx, + const int* __restrict__ col_seg_offsets, + const int* __restrict__ group_codes, const double* __restrict__ group_sizes, + double* __restrict__ rank_sums, double* __restrict__ tie_corr, + double* __restrict__ nz_count_scratch, int n_rows, int sb_cols, + int n_groups, bool compute_tie_corr, bool use_gmem) { + int col = blockIdx.x; + if (col >= sb_cols) return; + + int seg_start = col_seg_offsets[col]; + int seg_end = col_seg_offsets[col + 1]; + int nnz_stored = seg_end - seg_start; + + const float* sv = sorted_vals + seg_start; + const int* si = sorted_row_idx + seg_start; + + extern __shared__ double smem[]; + double* grp_sums; + double* grp_nz_count; + // Accumulator stride: 1 for shared mem (dense per-block), sb_cols for + // gmem (row-major layout (n_groups, sb_cols) shared across blocks). + int acc_stride; + + if (use_gmem) { + // Output rank_sums doubles as accumulator (pre-zeroed by caller). + grp_sums = rank_sums + (size_t)col; + grp_nz_count = nz_count_scratch + (size_t)col; + acc_stride = sb_cols; + } else { + grp_sums = smem; + grp_nz_count = smem + n_groups; + acc_stride = 1; + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { + grp_sums[g] = 0.0; + grp_nz_count[g] = 0.0; + } + __syncthreads(); + } + + // --- Find stored zero range: pos_start = first val > 0 --- + __shared__ int sh_pos_start; + if (threadIdx.x == 0) { + // Binary search: first index where sv[i] > 0.0 + int lo = 0, hi = nnz_stored; + while (lo < hi) { + int mid = lo + ((hi - lo) >> 1); + if (sv[mid] <= 0.0f) + lo = mid + 1; + else + hi = mid; + } + sh_pos_start = lo; + } + __syncthreads(); + + int pos_start = sh_pos_start; + int n_stored_zero = pos_start; + int n_implicit_zero = n_rows - nnz_stored; + int total_zero = n_implicit_zero + n_stored_zero; + double zero_avg_rank = (total_zero > 0) ? (total_zero + 1.0) / 2.0 : 0.0; + + // Rank offset for positive stored values: + // full_pos(i) = i + n_implicit_zero for i >= pos_start + // So avg_rank for tie group [a,b) of positives: + // = n_implicit_zero + (a + b + 1) / 2 + int offset_pos = n_implicit_zero; + + // --- Count stored positive values per group --- + for (int i = pos_start + threadIdx.x; i < nnz_stored; i += blockDim.x) { + int grp = group_codes[si[i]]; + if (grp < n_groups) { + atomicAdd(&grp_nz_count[grp * acc_stride], 1.0); + } + } + __syncthreads(); + + // --- Zero-rank contribution per group --- + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { + double n_zero_in_g = group_sizes[g] - grp_nz_count[g * acc_stride]; + grp_sums[g * acc_stride] = n_zero_in_g * zero_avg_rank; + } + __syncthreads(); + + // --- Walk stored positives only and compute ranks --- + int n_pos = nnz_stored - pos_start; + int chunk = (n_pos + blockDim.x - 1) / blockDim.x; + int my_start = pos_start + threadIdx.x * chunk; + int my_end = my_start + chunk; + if (my_end > nnz_stored) my_end = nnz_stored; + + double local_tie_sum = 0.0; + + int i = my_start; + while (i < my_end) { + float val = sv[i]; + + int tie_local_end = i + 1; + while (tie_local_end < my_end && sv[tie_local_end] == val) + ++tie_local_end; + + int tie_global_start = i; + if (i == my_start && i > 0 && sv[i - 1] == val) { + // Binary search for first occurrence + int lo = pos_start, hi = i; + while (lo < hi) { + int mid = lo + ((hi - lo) >> 1); + if (sv[mid] < val) + lo = mid + 1; + else + hi = mid; + } + tie_global_start = lo; + } + + int tie_global_end = tie_local_end; + if (tie_local_end == my_end && tie_local_end < nnz_stored && + sv[tie_local_end] == val) { + int lo = tie_local_end, hi = nnz_stored - 1; + while (lo < hi) { + int mid = hi - ((hi - lo) >> 1); + if (sv[mid] > val) + hi = mid - 1; + else + lo = mid; + } + tie_global_end = lo + 1; + } + + int total_tie = tie_global_end - tie_global_start; + + double avg_rank = (double)offset_pos + + (double)(tie_global_start + tie_global_end + 1) / 2.0; + + for (int j = i; j < tie_local_end; ++j) { + int grp = group_codes[si[j]]; + if (grp < n_groups) { + atomicAdd(&grp_sums[grp * acc_stride], avg_rank); + } + } + + if (compute_tie_corr && tie_global_start >= my_start && total_tie > 1) { + double t = (double)total_tie; + local_tie_sum += t * t * t - t; + } + + i = tie_local_end; + } + + __syncthreads(); + + // Write rank sums to global output (smem path only — gmem path is direct) + if (!use_gmem) { + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { + rank_sums[(size_t)g * sb_cols + col] = grp_sums[g]; + } + } + + // Tie correction: warp + block reduction + if (compute_tie_corr) { + // Zero tie group contribution (one thread only) + if (threadIdx.x == 0 && total_zero > 1) { + double tz = (double)total_zero; + local_tie_sum += tz * tz * tz - tz; + } + + // smem path: warp buf after both accumulator arrays (2 * n_groups). + // gmem path: accumulators are in gmem, warp buf starts at smem[0]. + int warp_buf_off = use_gmem ? 0 : 2 * n_groups; + double* warp_buf = smem + warp_buf_off; + +#pragma unroll + for (int off = 16; off > 0; off >>= 1) + local_tie_sum += __shfl_down_sync(0xffffffff, local_tie_sum, off); + int lane = threadIdx.x & 31; + int wid = threadIdx.x >> 5; + if (lane == 0) warp_buf[wid] = local_tie_sum; + __syncthreads(); + if (threadIdx.x < 32) { + double v = (threadIdx.x < ((blockDim.x + 31) >> 5)) + ? warp_buf[threadIdx.x] + : 0.0; +#pragma unroll + for (int off = 16; off > 0; off >>= 1) + v += __shfl_down_sync(0xffffffff, v, off); + if (threadIdx.x == 0) { + double n = (double)n_rows; + double denom = n * n * n - n; + tie_corr[col] = (denom > 0.0) ? (1.0 - v / denom) : 1.0; + } + } + } +} + +/** + * Decide whether the host cast+stats kernels can use per-block shared memory + * accumulators. Large group counts exceed the dynamic smem launch limit, so + * those cases fall back to direct global-memory atomics after zeroing the + * per-stream output buffers. + */ +static int wilcoxon_cast_max_smem_per_block() { + static int cached = -1; + if (cached < 0) { + int device; + cudaGetDevice(&device); + cudaDeviceGetAttribute(&cached, cudaDevAttrMaxSharedMemoryPerBlock, + device); + } + return cached; +} + +static size_t cast_accumulate_smem_config(int n_groups, bool compute_sq_sums, + bool compute_nnz, bool& use_gmem) { + int n_arrays = 1 + (compute_sq_sums ? 1 : 0) + (compute_nnz ? 1 : 0); + size_t need = (size_t)n_arrays * n_groups * sizeof(double); + if (need <= (size_t)wilcoxon_cast_max_smem_per_block()) { + use_gmem = false; + return need; + } + use_gmem = true; + return 0; +} + +/** + * Pre-sort cast-and-accumulate kernel for dense OVR host streaming. + * + * Reads a sub-batch block in its native host dtype (InT = float or double), + * writes a float32 copy used as the sort input, and accumulates per-group + * sum, sum-of-squares and nonzero counts in float64. Stats are derived + * from the original-precision values so float64 host input keeps its + * precision while the sort still runs on float32 keys. + * + * Block-per-column layout (grid: (sb_cols,), block: (tpb,)). + * Shared memory: 3 * n_groups doubles (s_sum, s_sq, s_nnz). + */ +template +__global__ void ovr_cast_and_accumulate_dense_kernel( + const InT* __restrict__ block_in, float* __restrict__ block_f32_out, + const int* __restrict__ group_codes, double* __restrict__ group_sums, + double* __restrict__ group_sq_sums, double* __restrict__ group_nnz, + int n_rows, int sb_cols, int n_groups, bool compute_sq_sums = true, + bool compute_nnz = true) { + int col = blockIdx.x; + if (col >= sb_cols) return; + + extern __shared__ double smem[]; + double* s_sum = smem; + double* s_sq = smem + n_groups; + double* s_nnz = smem + 2 * n_groups; + + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { + s_sum[g] = 0.0; + if (compute_sq_sums) s_sq[g] = 0.0; + if (compute_nnz) s_nnz[g] = 0.0; + } + __syncthreads(); + + const InT* src = block_in + (size_t)col * n_rows; + float* dst = block_f32_out + (size_t)col * n_rows; + + for (int r = threadIdx.x; r < n_rows; r += blockDim.x) { + InT v_in = src[r]; + double v = (double)v_in; + dst[r] = (float)v_in; + int g = group_codes[r]; + if (g < n_groups) { + atomicAdd(&s_sum[g], v); + if (compute_sq_sums) atomicAdd(&s_sq[g], v * v); + if (compute_nnz && v != 0.0) atomicAdd(&s_nnz[g], 1.0); + } + } + __syncthreads(); + + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { + group_sums[(size_t)g * sb_cols + col] = s_sum[g]; + if (compute_sq_sums) { + group_sq_sums[(size_t)g * sb_cols + col] = s_sq[g]; + } + if (compute_nnz) { + group_nnz[(size_t)g * sb_cols + col] = s_nnz[g]; + } + } +} + +template +__global__ void ovr_cast_and_accumulate_dense_global_kernel( + const InT* __restrict__ block_in, float* __restrict__ block_f32_out, + const int* __restrict__ group_codes, double* __restrict__ group_sums, + double* __restrict__ group_sq_sums, double* __restrict__ group_nnz, + int n_rows, int sb_cols, int n_groups, bool compute_sq_sums = true, + bool compute_nnz = true) { + int col = blockIdx.x; + if (col >= sb_cols) return; + + const InT* src = block_in + (size_t)col * n_rows; + float* dst = block_f32_out + (size_t)col * n_rows; + + for (int r = threadIdx.x; r < n_rows; r += blockDim.x) { + InT v_in = src[r]; + double v = (double)v_in; + dst[r] = (float)v_in; + int g = group_codes[r]; + if (g < n_groups) { + atomicAdd(&group_sums[(size_t)g * sb_cols + col], v); + if (compute_sq_sums) { + atomicAdd(&group_sq_sums[(size_t)g * sb_cols + col], v * v); + } + if (compute_nnz && v != 0.0) { + atomicAdd(&group_nnz[(size_t)g * sb_cols + col], 1.0); + } + } + } +} + +/** + * Pre-sort cast-and-accumulate kernel for sparse OVR host streaming. + * + * Sub-batch CSC data is laid out contiguously: values for column c live + * at positions [col_seg_offsets[c], col_seg_offsets[c+1]). For each + * stored value, read the native-dtype InT, write a float32 copy for the + * CUB sort, and accumulate per-group sum/sum-sq/nnz in float64. Implicit + * zeros contribute nothing to any of these stats. + * + * Block-per-column layout (grid: (sb_cols,), block: (tpb,)). + * Shared memory: 3 * n_groups doubles. + */ +template +__global__ void ovr_cast_and_accumulate_sparse_kernel( + const InT* __restrict__ data_in, float* __restrict__ data_f32_out, + const IndexT* __restrict__ indices, const int* __restrict__ col_seg_offsets, + const int* __restrict__ group_codes, double* __restrict__ group_sums, + double* __restrict__ group_sq_sums, double* __restrict__ group_nnz, + int sb_cols, int n_groups, bool compute_sq_sums = true, + bool compute_nnz = true) { + int col = blockIdx.x; + if (col >= sb_cols) return; + + int seg_start = col_seg_offsets[col]; + int seg_end = col_seg_offsets[col + 1]; + + extern __shared__ double smem[]; + double* s_sum = smem; + double* s_sq = smem + n_groups; + double* s_nnz = smem + 2 * n_groups; + + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { + s_sum[g] = 0.0; + if (compute_sq_sums) s_sq[g] = 0.0; + if (compute_nnz) s_nnz[g] = 0.0; + } + __syncthreads(); + + for (int i = seg_start + threadIdx.x; i < seg_end; i += blockDim.x) { + InT v_in = data_in[i]; + double v = (double)v_in; + data_f32_out[i] = (float)v_in; + int row = (int)indices[i]; + int g = group_codes[row]; + if (g < n_groups) { + atomicAdd(&s_sum[g], v); + if (compute_sq_sums) atomicAdd(&s_sq[g], v * v); + if (compute_nnz && v != 0.0) atomicAdd(&s_nnz[g], 1.0); + } + } + __syncthreads(); + + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { + group_sums[(size_t)g * sb_cols + col] = s_sum[g]; + if (compute_sq_sums) { + group_sq_sums[(size_t)g * sb_cols + col] = s_sq[g]; + } + if (compute_nnz) { + group_nnz[(size_t)g * sb_cols + col] = s_nnz[g]; + } + } +} + +template +__global__ void ovr_cast_and_accumulate_sparse_global_kernel( + const InT* __restrict__ data_in, float* __restrict__ data_f32_out, + const IndexT* __restrict__ indices, const int* __restrict__ col_seg_offsets, + const int* __restrict__ group_codes, double* __restrict__ group_sums, + double* __restrict__ group_sq_sums, double* __restrict__ group_nnz, + int sb_cols, int n_groups, bool compute_sq_sums = true, + bool compute_nnz = true) { + int col = blockIdx.x; + if (col >= sb_cols) return; + + int seg_start = col_seg_offsets[col]; + int seg_end = col_seg_offsets[col + 1]; + + for (int i = seg_start + threadIdx.x; i < seg_end; i += blockDim.x) { + InT v_in = data_in[i]; + double v = (double)v_in; + data_f32_out[i] = (float)v_in; + int row = (int)indices[i]; + int g = group_codes[row]; + if (g < n_groups) { + atomicAdd(&group_sums[(size_t)g * sb_cols + col], v); + if (compute_sq_sums) { + atomicAdd(&group_sq_sums[(size_t)g * sb_cols + col], v * v); + } + if (compute_nnz && v != 0.0) { + atomicAdd(&group_nnz[(size_t)g * sb_cols + col], 1.0); + } + } + } +} + +template +static void launch_ovr_cast_and_accumulate_dense( + const InT* d_block_orig, float* d_block_f32, const int* d_group_codes, + double* d_group_sums, double* d_group_sq_sums, double* d_group_nnz, + int n_rows, int sb_cols, int n_groups, bool compute_sq_sums, + bool compute_nnz, int tpb, size_t smem_cast, bool use_gmem, + cudaStream_t stream) { + if (use_gmem) { + size_t stats_items = (size_t)n_groups * sb_cols; + cudaMemsetAsync(d_group_sums, 0, stats_items * sizeof(double), stream); + if (compute_sq_sums) { + cudaMemsetAsync(d_group_sq_sums, 0, stats_items * sizeof(double), + stream); + } + if (compute_nnz) { + cudaMemsetAsync(d_group_nnz, 0, stats_items * sizeof(double), + stream); + } + ovr_cast_and_accumulate_dense_global_kernel + <<>>( + d_block_orig, d_block_f32, d_group_codes, d_group_sums, + d_group_sq_sums, d_group_nnz, n_rows, sb_cols, n_groups, + compute_sq_sums, compute_nnz); + CUDA_CHECK_LAST_ERROR(ovr_cast_and_accumulate_dense_global_kernel); + } else { + ovr_cast_and_accumulate_dense_kernel + <<>>( + d_block_orig, d_block_f32, d_group_codes, d_group_sums, + d_group_sq_sums, d_group_nnz, n_rows, sb_cols, n_groups, + compute_sq_sums, compute_nnz); + CUDA_CHECK_LAST_ERROR(ovr_cast_and_accumulate_dense_kernel); + } +} + +template +static void launch_ovr_cast_and_accumulate_sparse( + const InT* d_data_orig, float* d_data_f32, const IndexT* d_indices, + const int* d_col_offsets, const int* d_group_codes, double* d_group_sums, + double* d_group_sq_sums, double* d_group_nnz, int sb_cols, int n_groups, + bool compute_sq_sums, bool compute_nnz, int tpb, size_t smem_cast, + bool use_gmem, cudaStream_t stream) { + if (use_gmem) { + size_t stats_items = (size_t)n_groups * sb_cols; + cudaMemsetAsync(d_group_sums, 0, stats_items * sizeof(double), stream); + if (compute_sq_sums) { + cudaMemsetAsync(d_group_sq_sums, 0, stats_items * sizeof(double), + stream); + } + if (compute_nnz) { + cudaMemsetAsync(d_group_nnz, 0, stats_items * sizeof(double), + stream); + } + ovr_cast_and_accumulate_sparse_global_kernel + <<>>( + d_data_orig, d_data_f32, d_indices, d_col_offsets, + d_group_codes, d_group_sums, d_group_sq_sums, d_group_nnz, + sb_cols, n_groups, compute_sq_sums, compute_nnz); + CUDA_CHECK_LAST_ERROR(ovr_cast_and_accumulate_sparse_global_kernel); + } else { + ovr_cast_and_accumulate_sparse_kernel + <<>>( + d_data_orig, d_data_f32, d_indices, d_col_offsets, + d_group_codes, d_group_sums, d_group_sq_sums, d_group_nnz, + sb_cols, n_groups, compute_sq_sums, compute_nnz); + CUDA_CHECK_LAST_ERROR(ovr_cast_and_accumulate_sparse_kernel); + } +} diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py b/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py index 0b9753a3..d399a301 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py @@ -21,6 +21,102 @@ ] +class _LazyRankGenesColumn: + def __init__( + self, + values: np.ndarray | None = None, + *, + var_names: np.ndarray | None = None, + gene_indices: np.ndarray | None = None, + dtype: str | np.dtype, + ) -> None: + self._values = values + self._var_names = var_names + self._gene_indices = gene_indices + self._dtype = np.dtype(dtype) + + def __len__(self) -> int: + if self._values is not None: + return int(self._values.shape[0]) + return int(self._gene_indices.shape[0]) + + def __getitem__(self, key): + if self._values is not None: + return self._values[key] + return self._var_names[self._gene_indices[key]] + + def __iter__(self): + for idx in range(len(self)): + yield self[idx] + + def __array__(self, dtype=None, copy=None) -> np.ndarray: + if self._values is not None: + arr = np.asarray(self._values, dtype=self._dtype) + else: + arr = np.asarray(self._var_names[self._gene_indices], dtype=self._dtype) + if dtype is not None: + arr = np.asarray(arr, dtype=dtype) + if copy: + arr = arr.copy() + return arr + + +class _LazyRankGenesRecords(dict): + def __init__( + self, group_names: np.ndarray, columns: dict[str, object], dtype: str | np.dtype + ) -> None: + super().__init__(columns) + self._group_names = tuple(str(name) for name in group_names) + self._dtype = np.dtype([(name, np.dtype(dtype)) for name in self._group_names]) + + @property + def dtype(self) -> np.dtype: + return self._dtype + + def __getitem__(self, key): + if isinstance(key, str): + return super().__getitem__(key) + return np.asarray(self)[key] + + def __array__(self, dtype=None, copy=None) -> np.ndarray: + out = np.empty(len(next(iter(self.values()))) if self else 0, dtype=self._dtype) + for name in self._group_names: + out[name] = np.asarray(super().__getitem__(name)) + if dtype is not None: + out = np.asarray(out, dtype=dtype) + if copy: + out = out.copy() + return out + + def copy(self) -> np.ndarray: + return np.asarray(self).copy() + + +def _array_result_to_lazy_records( + arrays: dict[str, object], field: str, dtype: str | np.dtype +) -> _LazyRankGenesRecords: + group_names = arrays["group_names"] + values = arrays[field] + columns = { + str(group_name): _LazyRankGenesColumn(values[row], dtype=dtype) + for row, group_name in enumerate(group_names) + } + return _LazyRankGenesRecords(group_names, columns, dtype) + + +def _array_result_to_lazy_names(arrays: dict[str, object]) -> _LazyRankGenesRecords: + group_names = arrays["group_names"] + var_names = arrays["var_names"] + gene_indices = arrays["gene_indices"] + columns = { + str(group_name): _LazyRankGenesColumn( + var_names=var_names, gene_indices=gene_indices[row], dtype=object + ) + for row, group_name in enumerate(group_names) + } + return _LazyRankGenesRecords(group_names, columns, object) + + def rank_genes_groups( adata: AnnData, groupby: str, @@ -37,17 +133,21 @@ def rank_genes_groups( corr_method: _CorrMethod = "benjamini-hochberg", tie_correct: bool = False, use_continuity: bool = False, + return_u_values: bool = False, layer: str | None = None, chunk_size: int | None = None, pre_load: bool = False, n_bins: int | None = None, bin_range: Literal["log1p", "auto"] | None = None, + skip_empty_groups: bool = False, **kwds, ) -> None: """ Rank genes for characterizing groups using GPU acceleration. - Expects logarithmized data. + Expects nonnegative expression data. Log1p/log-normalized data is expected + for biologically meaningful log fold changes; sparse inputs with explicit + negative values are rejected. .. note:: **Dask support:** `'t-test'`, `'t-test_overestim_var'`, and @@ -101,6 +201,10 @@ def rank_genes_groups( z-scores. Subtracts 0.5 from ``|R - E[R]|`` before dividing by the standard deviation, matching :func:`scipy.stats.mannwhitneyu` default behavior. + return_u_values + For `'wilcoxon'`, store Mann-Whitney U statistics in `scores` instead + of z-scores. P-values are still computed from the z-score normal + approximation using the selected tie and continuity settings. layer Key from `adata.layers` whose value will be used to perform tests on. chunk_size @@ -119,15 +223,22 @@ def rank_genes_groups( ``None`` (default) uses ``'auto'`` for in-memory arrays and ``'log1p'`` for Dask arrays (to avoid a costly data scan). ``'log1p'`` uses a fixed [0, 15] range suitable for most log1p-normalized data. - ``'auto'`` computes the actual data range. Use this for z-scored - or unnormalized data. + ``'auto'`` computes the actual data range. Use this for nonnegative + expression data outside the fixed log1p range. + skip_empty_groups + Skip selected groups with fewer than two observations after filtering. + This is useful for perturbation workflows where a per-cell-type slice + keeps categories that are empty or singleton in that slice. **kwds Additional arguments passed to the method. For `'logreg'`, these are passed to :class:`cuml.linear_model.LogisticRegression`. Returns ------- - Updates `adata` with the following fields: + Updates `adata` with the following fields. Rank result fields are lazy + Scanpy-compatible record objects: group fields can be indexed like + structured arrays, while full structured arrays are materialized only when + requested through NumPy conversion or `.copy()`. `adata.uns['rank_genes_groups' | key_added]['names']` Structured array to be indexed by group id storing the gene @@ -135,7 +246,8 @@ def rank_genes_groups( `adata.uns['rank_genes_groups' | key_added]['scores']` Structured array to be indexed by group id storing the z-score underlying the computation of a p-value for each gene for each - group. Ordered according to scores. + group, or the Mann-Whitney U statistic when + `return_u_values=True`. Ordered according to scores. `adata.uns['rank_genes_groups' | key_added]['logfoldchanges']` Structured array to be indexed by group id storing the log2 fold change for each gene for each group. @@ -154,6 +266,13 @@ def rank_genes_groups( msg = "corr_method must be either 'benjamini-hochberg' or 'bonferroni'." raise ValueError(msg) + if "return_format" in kwds: + msg = ( + "return_format has been removed; rank_genes_groups always writes " + "lazy Scanpy-compatible results to adata.uns." + ) + raise TypeError(msg) + if method is None: method = "t-test" @@ -170,6 +289,10 @@ def rank_genes_groups( ) raise ValueError(msg) + if return_u_values and method != "wilcoxon": + msg = "return_u_values is only supported for method='wilcoxon'." + raise ValueError(msg) + if key_added is None: key_added = "rank_genes_groups" @@ -197,6 +320,7 @@ def rank_genes_groups( layer=layer, comp_pts=pts, pre_load=pre_load, + skip_empty_groups=skip_empty_groups, ) # Determine n_genes_user @@ -211,25 +335,14 @@ def rank_genes_groups( rankby_abs=rankby_abs, tie_correct=tie_correct, use_continuity=use_continuity, + return_u_values=return_u_values, chunk_size=chunk_size, n_bins=n_bins, bin_range=bin_range, **kwds, ) - # Build output - test_obj.stats.columns = test_obj.stats.columns.swaplevel() - - dtypes = { - "names": "U50", - "scores": "float32", - "logfoldchanges": "float32", - "pvals": "float64", - "pvals_adj": "float64", - } - - adata.uns[key_added] = {} - adata.uns[key_added]["params"] = { + params = { "groupby": groupby, "reference": reference, "method": method, @@ -237,8 +350,28 @@ def rank_genes_groups( "layer": layer, "corr_method": corr_method, } + if method == "wilcoxon": + params["tie_correct"] = tie_correct + params["return_u_values"] = return_u_values + + arrays = test_obj.stats_arrays or {} + adata.uns[key_added] = {"params": params} + if arrays and len(arrays.get("group_names", ())) > 0: + adata.uns[key_added]["names"] = _array_result_to_lazy_names(arrays) + for col, dtype in { + "scores": "float32", + "logfoldchanges": "float32", + "pvals": "float64", + "pvals_adj": "float64", + }.items(): + if col in arrays: + values = arrays[col] + if hasattr(values, "dtype"): + dtype = values.dtype + adata.uns[key_added][col] = _array_result_to_lazy_records( + arrays, col, dtype + ) - # Store pts results if computed if test_obj.pts is not None: groups_names = [str(name) for name in test_obj.groups_order] adata.uns[key_added]["pts"] = pd.DataFrame( @@ -249,14 +382,7 @@ def rank_genes_groups( test_obj.pts_rest.T, index=test_obj.var_names, columns=groups_names ) - if method == "wilcoxon": - adata.uns[key_added]["params"]["tie_correct"] = tie_correct - - for col in test_obj.stats.columns.levels[0]: - if col in dtypes: - adata.uns[key_added][col] = test_obj.stats[col].to_records( - index=False, column_dtypes=dtypes[col] - ) + return None if TYPE_CHECKING: @@ -285,7 +411,7 @@ def rank_genes_groups_logreg( layer: str | None = None, **kwds, ) -> None: - rank_genes_groups( + return rank_genes_groups( adata, groupby, groups=groups, diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_core.py b/src/rapids_singlecell/tools/_rank_genes_groups/_core.py index c65bbf7c..acfbe2e2 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_core.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_core.py @@ -1,18 +1,42 @@ from __future__ import annotations +import os +from concurrent.futures import ThreadPoolExecutor from typing import TYPE_CHECKING, Literal, assert_never import cupy as cp import numpy as np import pandas as pd -from statsmodels.stats.multitest import multipletests from rapids_singlecell._compat import DaskArray from rapids_singlecell.get import X_to_GPU from rapids_singlecell.get._aggregated import Aggregate from rapids_singlecell.preprocessing._utils import _check_gpu_X -from ._utils import EPS, _select_groups, _select_top_n +from ._utils import EPS, _check_sparse_nonnegative, _select_groups + +_FDR_BH_REVERSE_CUMMIN_KERNEL = cp.RawKernel( + r""" +extern "C" __global__ void fdr_bh_reverse_cummin(double* values, const int n_cols) { + const int row = blockIdx.x; + double running = 1.0; + double* row_values = values + static_cast(row) * n_cols; + for (int col = n_cols - 1; col >= 0; --col) { + double value = row_values[col]; + if (!(value == value)) { + value = 1.0; + } + if (value < running) { + running = value; + } + row_values[col] = running; + } +} +""", + "fdr_bh_reverse_cummin", +) +_RANK_SORT_MIN_ELEMENTS = 1_000_000 +_RANK_SORT_MAX_WORKERS = 64 if TYPE_CHECKING: from collections.abc import Iterable @@ -38,6 +62,7 @@ def __init__( layer: str | None = None, comp_pts: bool = False, pre_load: bool = False, + skip_empty_groups: bool = False, ) -> None: # Handle groups parameter if groups == "all" or groups is None: @@ -63,7 +88,10 @@ def __init__( raise ValueError(msg) self.groups_order, self.group_codes, self.group_sizes = _select_groups( - self.labels, selected + self.labels, + selected, + reference=reference, + skip_empty_groups=skip_empty_groups, ) # Get data matrix @@ -91,6 +119,8 @@ def __init__( self.X = self.X[:, mask_var] self.var_names = self.var_names[mask_var] + _check_sparse_nonnegative(self.X) + self.pre_load = pre_load self.ireference = None @@ -100,6 +130,7 @@ def __init__( # Set up expm1 function based on log base self.is_log1p = "log1p" in adata.uns base = adata.uns.get("log1p", {}).get("base") + self._log1p_base = base if base is not None: self.expm1_func = lambda x: np.expm1(x * np.log(base)) else: @@ -115,8 +146,14 @@ def __init__( self.pts_rest: np.ndarray | None = None self.stats: pd.DataFrame | None = None + self.stats_arrays: dict[str, object] | None = None + self._store_wilcoxon_gpu_result = False + self._wilcoxon_gpu_result: ( + tuple[np.ndarray, cp.ndarray, cp.ndarray, cp.ndarray | None] | None + ) = None self._compute_stats_in_chunks: bool = False self._ref_chunk_computed: set[int] = set() + self._score_dtype = np.dtype(np.float32) def _init_stats_arrays(self, n_genes: int) -> None: """Pre-allocate stats arrays before chunk loop.""" @@ -190,16 +227,18 @@ def _basic_stats(self) -> None: # Compute rest statistics if reference='rest' if self.ireference is None: - n_rest = n.sum() - n - means_rest = (sums.sum(axis=0) - sums) / n_rest - rest_ss = (sq_sums.sum(axis=0) - sq_sums) - n_rest * means_rest**2 + n_rest = cp.float64(self.X.shape[0]) - n + total_sums = result["sum"].sum(axis=0, keepdims=True) + total_sq_sums = result["sq_sum"].sum(axis=0, keepdims=True) + means_rest = (total_sums - sums) / n_rest + rest_ss = (total_sq_sums - sq_sums) - n_rest * means_rest**2 vars_rest = cp.maximum(rest_ss / cp.maximum(n_rest - 1, 1), 0) self.means_rest = cp.asnumpy(means_rest) self.vars_rest = cp.asnumpy(vars_rest) if self.comp_pts: - total_count = (pts * n).sum(axis=0) + total_count = result["count_nonzero"].sum(axis=0, keepdims=True) self.pts_rest = cp.asnumpy((total_count - pts * n) / n_rest) else: self.pts_rest = None @@ -325,6 +364,7 @@ def wilcoxon( tie_correct: bool, use_continuity: bool = False, chunk_size: int | None = None, + return_u_values: bool = False, ) -> list[tuple[int, NDArray, NDArray]]: """Compute Wilcoxon rank-sum test statistics.""" from ._wilcoxon import wilcoxon @@ -334,6 +374,7 @@ def wilcoxon( tie_correct=tie_correct, use_continuity=use_continuity, chunk_size=chunk_size, + return_u_values=return_u_values, ) def wilcoxon_binned( @@ -375,6 +416,7 @@ def compute_statistics( chunk_size: int | None = None, n_bins: int | None = None, bin_range: Literal["log1p", "auto"] | None = None, + return_u_values: bool = False, **kwds, ) -> None: """Compute statistics for all groups.""" @@ -385,17 +427,28 @@ def compute_statistics( }: self.X = X_to_GPU(self.X) + n_genes = self.X.shape[1] + if n_genes_user is None: + n_genes_user = n_genes + if method in {"t-test", "t-test_overestim_var"}: test_results = self.t_test(method) elif method == "wilcoxon": if isinstance(self.X, DaskArray): msg = "Wilcoxon test is not supported for Dask arrays. Please convert your data to CuPy arrays." raise ValueError(msg) - test_results = self.wilcoxon( - tie_correct=tie_correct, - use_continuity=use_continuity, - chunk_size=chunk_size, - ) + self._score_dtype = np.dtype(np.float64 if return_u_values else np.float32) + self._wilcoxon_gpu_result = None + self._store_wilcoxon_gpu_result = n_genes_user is not None + try: + test_results = self.wilcoxon( + tie_correct=tie_correct, + use_continuity=use_continuity, + chunk_size=chunk_size, + return_u_values=return_u_values, + ) + finally: + self._store_wilcoxon_gpu_result = False elif method == "wilcoxon_binned": test_results = self.wilcoxon_binned( tie_correct=tie_correct, @@ -409,58 +462,225 @@ def compute_statistics( else: assert_never(method) - n_genes = self.X.shape[1] + if not test_results and self._wilcoxon_gpu_result is None: + self.stats_arrays = { + "group_indices": np.empty(0, dtype=np.intp), + "group_names": np.empty(0, dtype=object), + "var_names": np.asarray(self.var_names), + "gene_indices": np.empty((0, n_genes_user), dtype=np.intp), + } + self.stats = None + return + + if self._wilcoxon_gpu_result is not None: + group_indices, scores_gpu, pvals_gpu, logfoldchanges_gpu = ( + self._wilcoxon_gpu_result + ) + try: + self._compute_statistics_gpu_arrays( + group_indices, + scores_gpu, + pvals_gpu, + logfoldchanges_gpu, + corr_method=corr_method, + n_genes_user=n_genes_user, + n_genes=n_genes, + rankby_abs=rankby_abs, + ) + finally: + self._wilcoxon_gpu_result = None + return - # Collect all stats data first to avoid DataFrame fragmentation - stats_data: dict[tuple[str, str], np.ndarray] = {} + self._compute_statistics_arrays( + test_results, + corr_method=corr_method, + n_genes_user=n_genes_user, + n_genes=n_genes, + rankby_abs=rankby_abs, + ) - for group_index, scores, pvals in test_results: - group_name = str(self.groups_order[group_index]) + @staticmethod + def _rank_indices_matrix(scores: np.ndarray, n_top: int) -> np.ndarray: + if n_top >= scores.shape[1]: + return _RankGenes._argsort_desc_matrix(scores) + partition = np.argpartition(scores, -n_top, axis=1)[:, -n_top:] + row_ids = np.arange(scores.shape[0])[:, None] + order = np.argsort(scores[row_ids, partition], axis=1)[:, ::-1] + return partition[row_ids, order] + + @staticmethod + def _argsort_desc_matrix(scores: np.ndarray) -> np.ndarray: + n_rows, n_cols = scores.shape + n_elements = n_rows * n_cols + n_workers = min(_RANK_SORT_MAX_WORKERS, os.cpu_count() or 1, n_rows) + if n_workers <= 1 or n_elements < _RANK_SORT_MIN_ELEMENTS: + return np.argsort(scores, axis=1)[:, ::-1] + + chunks = np.linspace(0, n_rows, n_workers + 1, dtype=np.intp) + indices = np.empty((n_rows, n_cols), dtype=np.intp) + + def sort_chunk(chunk_index: int) -> None: + start = int(chunks[chunk_index]) + stop = int(chunks[chunk_index + 1]) + if start < stop: + indices[start:stop] = np.argsort(scores[start:stop], axis=1)[:, ::-1] + + with ThreadPoolExecutor(max_workers=n_workers) as executor: + list(executor.map(sort_chunk, range(n_workers))) + return indices + + @staticmethod + def _fdr_bh_matrix(pvals: np.ndarray) -> np.ndarray: + pvals_clean = np.array(pvals, copy=True) + pvals_clean[np.isnan(pvals_clean)] = 1.0 + order = np.argsort(pvals_clean, axis=1) + sorted_p = np.take_along_axis(pvals_clean, order, axis=1) + n_tests = sorted_p.shape[1] + scale = n_tests / np.arange(1, n_tests + 1, dtype=np.float64) + corrected_sorted = sorted_p * scale + corrected_sorted = np.minimum.accumulate(corrected_sorted[:, ::-1], axis=1)[ + :, ::-1 + ] + corrected_sorted[corrected_sorted > 1.0] = 1.0 + corrected = np.empty_like(corrected_sorted) + np.put_along_axis(corrected, order, corrected_sorted, axis=1) + return corrected + + @staticmethod + def _fdr_bh_matrix_gpu(pvals: cp.ndarray) -> cp.ndarray: + pvals_clean = cp.nan_to_num(pvals, nan=1.0) + order = cp.argsort(pvals_clean, axis=1) + corrected_sorted = cp.take_along_axis(pvals_clean, order, axis=1) + corrected_sorted *= corrected_sorted.shape[1] / cp.arange( + 1, corrected_sorted.shape[1] + 1, dtype=cp.float64 + ) + _FDR_BH_REVERSE_CUMMIN_KERNEL( + (corrected_sorted.shape[0],), + (1,), + (corrected_sorted, np.int32(corrected_sorted.shape[1])), + ) + corrected = cp.empty_like(corrected_sorted) + cp.put_along_axis(corrected, order, corrected_sorted, axis=1) + return corrected - if n_genes_user is not None: - scores_sort = np.abs(scores) if rankby_abs else scores - global_indices = _select_top_n(scores_sort, n_genes_user) + def _compute_statistics_arrays( + self, + test_results: list[tuple[int, NDArray, NDArray]], + *, + corr_method: _CorrMethod, + n_genes_user: int, + n_genes: int, + rankby_abs: bool, + ) -> None: + group_indices = np.asarray([r[0] for r in test_results], dtype=np.intp) + scores = np.vstack([r[1] for r in test_results]) + sort_scores = np.abs(scores) if rankby_abs else scores + top_idx = self._rank_indices_matrix(sort_scores, n_genes_user) + + arrays: dict[str, object] = { + "group_indices": group_indices, + "group_names": np.asarray( + [str(self.groups_order[i]) for i in group_indices], dtype=object + ), + "var_names": np.asarray(self.var_names), + "gene_indices": top_idx.astype(np.intp, copy=False), + "scores": np.take_along_axis(scores, top_idx, axis=1).astype( + self._score_dtype, copy=False + ), + } + + if test_results[0][2] is not None: + pvals = np.vstack([r[2] for r in test_results]) + arrays["pvals"] = np.take_along_axis(pvals, top_idx, axis=1) + if corr_method == "benjamini-hochberg": + pvals_adj = self._fdr_bh_matrix(pvals) + elif corr_method == "bonferroni": + pvals_adj = np.minimum(pvals * n_genes, 1.0) else: - global_indices = slice(None) - - if n_genes_user is not None: - stats_data[group_name, "names"] = np.asarray(self.var_names)[ - global_indices - ] - - stats_data[group_name, "scores"] = scores[global_indices] - - if pvals is not None: - stats_data[group_name, "pvals"] = pvals[global_indices] - if corr_method == "benjamini-hochberg": - pvals_clean = np.array(pvals, copy=True) - pvals_clean[np.isnan(pvals_clean)] = 1.0 - _, pvals_adj, _, _ = multipletests( - pvals_clean, alpha=0.05, method="fdr_bh" - ) - elif corr_method == "bonferroni": - pvals_adj = np.minimum(pvals * n_genes, 1.0) - stats_data[group_name, "pvals_adj"] = pvals_adj[global_indices] - - # Compute logfoldchanges - if self.means is not None: - mean_group = self.means[group_index] - if self.ireference is None: - mean_rest = self.means_rest[group_index] - else: - mean_rest = self.means[self.ireference] - foldchanges = (self.expm1_func(mean_group) + EPS) / ( - self.expm1_func(mean_rest) + EPS + msg = f"Unsupported correction method: {corr_method!r}." + raise ValueError(msg) + arrays["pvals_adj"] = np.take_along_axis(pvals_adj, top_idx, axis=1) + + if self.means is not None: + mean_group = self.means[group_indices] + if self.ireference is None: + mean_rest = self.means_rest[group_indices] + else: + mean_rest = self.means[self.ireference][None, :] + foldchanges = (self.expm1_func(mean_group) + EPS) / ( + self.expm1_func(mean_rest) + EPS + ) + logfoldchanges = np.log2(foldchanges) + arrays["logfoldchanges"] = np.take_along_axis( + logfoldchanges, top_idx, axis=1 + ).astype(np.float32, copy=False) + + self.stats_arrays = arrays + self.stats = None + + def _compute_statistics_gpu_arrays( + self, + group_indices: np.ndarray, + scores_gpu: cp.ndarray, + pvals_gpu: cp.ndarray, + logfoldchanges_gpu: cp.ndarray | None, + *, + corr_method: _CorrMethod, + n_genes_user: int, + n_genes: int, + rankby_abs: bool, + ) -> None: + group_indices = np.asarray(group_indices, dtype=np.intp) + scores = cp.asnumpy(scores_gpu) + sort_scores = np.abs(scores) if rankby_abs else scores + top_idx = self._rank_indices_matrix(sort_scores, n_genes_user) + top_idx_gpu = cp.asarray(top_idx) + + arrays: dict[str, object] = { + "group_indices": group_indices, + "group_names": np.asarray( + [str(self.groups_order[i]) for i in group_indices], dtype=object + ), + "var_names": np.asarray(self.var_names), + "gene_indices": top_idx.astype(np.intp, copy=False), + "scores": cp.asnumpy( + cp.take_along_axis(scores_gpu, top_idx_gpu, axis=1).astype( + self._score_dtype, copy=False ) - stats_data[group_name, "logfoldchanges"] = np.log2( - foldchanges[global_indices] + ), + "pvals": cp.asnumpy(cp.take_along_axis(pvals_gpu, top_idx_gpu, axis=1)), + } + + if corr_method == "benjamini-hochberg": + pvals_adj_gpu = self._fdr_bh_matrix_gpu(pvals_gpu) + elif corr_method == "bonferroni": + pvals_adj_gpu = cp.minimum(pvals_gpu * n_genes, 1.0) + else: + msg = f"Unsupported correction method: {corr_method!r}." + raise ValueError(msg) + arrays["pvals_adj"] = cp.asnumpy( + cp.take_along_axis(pvals_adj_gpu, top_idx_gpu, axis=1) + ) + + if logfoldchanges_gpu is not None: + arrays["logfoldchanges"] = cp.asnumpy( + cp.take_along_axis(logfoldchanges_gpu, top_idx_gpu, axis=1).astype( + cp.float32, copy=False ) + ) + elif self.means is not None: + mean_group = self.means[group_indices] + if self.ireference is None: + mean_rest = self.means_rest[group_indices] + else: + mean_rest = self.means[self.ireference][None, :] + foldchanges = (self.expm1_func(mean_group) + EPS) / ( + self.expm1_func(mean_rest) + EPS + ) + logfoldchanges = np.log2(foldchanges) + arrays["logfoldchanges"] = np.take_along_axis( + logfoldchanges, top_idx, axis=1 + ).astype(np.float32, copy=False) - # Create DataFrame all at once to avoid fragmentation - if stats_data: - self.stats = pd.DataFrame(stats_data) - self.stats.columns = pd.MultiIndex.from_tuples(self.stats.columns) - if n_genes_user is None: - self.stats.index = self.var_names - else: - self.stats = None + self.stats_arrays = arrays + self.stats = None diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py b/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py index c4f2c601..4ec37e40 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py @@ -18,9 +18,38 @@ MAX_THREADS_PER_BLOCK = 512 +def _check_sparse_nonnegative(X) -> None: + """Reject sparse matrices with explicit negative values. + + Sparse rank_genes_groups code treats missing entries as true expression + zeros. Optimized sparse Wilcoxon paths may rank explicit nonzeros and add + implicit zeros analytically, which is only valid when explicit sparse + values are nonnegative expression values. + """ + if sp.issparse(X): + if X.nnz > 0 and float(X.data.min()) < 0: + msg = ( + "Sparse input contains negative values. rank_genes_groups " + "expects nonnegative expression values; use raw counts or " + "log1p/log-normalized expression, not scaled or centered data." + ) + raise ValueError(msg) + elif cpsp.issparse(X): + if X.nnz > 0 and float(X.data.min()) < 0: + msg = ( + "Sparse input contains negative values. rank_genes_groups " + "expects nonnegative expression values; use raw counts or " + "log1p/log-normalized expression, not scaled or centered data." + ) + raise ValueError(msg) + + def _select_groups( labels: pd.Series, selected: list | None, + *, + reference: str = "rest", + skip_empty_groups: bool = False, ) -> tuple[NDArray, NDArray[np.int32], NDArray[np.int64]]: """Build integer group codes from a categorical Series. @@ -51,6 +80,29 @@ def _select_groups( cat_order = {str(c): i for i, c in enumerate(all_categories)} selected.sort(key=lambda x: cat_order.get(str(x), len(all_categories))) + if skip_empty_groups: + counts = { + str(name): int(count) for name, count in labels.value_counts().items() + } + valid_selected = [group for group in selected if counts.get(str(group), 0) >= 2] + if reference != "rest": + ref_matches = [group for group in selected if str(group) == str(reference)] + if ref_matches: + ref_group = ref_matches[0] + if ref_group not in valid_selected: + msg = ( + f"reference = {reference} has fewer than two samples after " + "filtering and cannot be used for rank_genes_groups." + ) + raise ValueError(msg) + selected = valid_selected + if len(selected) == 0: + msg = ( + "No groups with at least two samples remain after applying " + "skip_empty_groups=True." + ) + raise ValueError(msg) + n_groups = len(selected) groups_order = np.array(selected) @@ -76,7 +128,7 @@ def _select_groups( if invalid_groups: msg = ( f"Could not calculate statistics for groups {', '.join(invalid_groups)} " - "since they only contain one sample." + "since they contain fewer than two samples." ) raise ValueError(msg) diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py index c14c760d..e20af614 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py @@ -4,14 +4,15 @@ from typing import TYPE_CHECKING import cupy as cp +import cupyx.scipy.sparse as cpsp import cupyx.scipy.special as cupyx_special import numpy as np import scipy.sparse as sp from rapids_singlecell._cuda import _wilcoxon_cuda as _wc -from rapids_singlecell._utils._csr_to_csc import _fast_csr_to_csc +from rapids_singlecell._cuda import _wilcoxon_sparse_cuda as _wcs -from ._utils import _choose_chunk_size, _get_column_block +from ._utils import EPS, _choose_chunk_size, _get_column_block if TYPE_CHECKING: from numpy.typing import NDArray @@ -19,6 +20,14 @@ from ._core import _RankGenes MIN_GROUP_SIZE_WARNING = 25 +DEFAULT_WILCOXON_CHUNK_SIZE = 512 +OVO_SORT_GROUP_THRESHOLD = 512 +OVR_HOST_CSC_SUB_BATCH = 512 +OVR_HOST_CSR_SUB_BATCH = 2048 +OVR_DEVICE_CSC_SUB_BATCH = 2048 +OVR_DEVICE_CSR_SUB_BATCH = 2048 +OVO_HOST_SPARSE_SUB_BATCH = 256 +OVO_DEVICE_SPARSE_SUB_BATCH = 128 def _average_ranks( @@ -86,12 +95,307 @@ def _tie_correction(sorted_vals: cp.ndarray) -> cp.ndarray: return correction +def _extract_dense_rows_cols( + X, row_ids: np.ndarray, start: int, stop: int +) -> cp.ndarray: + """Extract a bounded row/column block as F-order CuPy dense memory.""" + if isinstance(X, np.ndarray): + return cp.asarray(X[row_ids, start:stop], order="F") + if isinstance(X, cp.ndarray): + rows = cp.asarray(row_ids, dtype=cp.int32) + return cp.asfortranarray(X[rows, start:stop]) + if isinstance(X, sp.spmatrix | sp.sparray): + return cp.asarray(X[row_ids][:, start:stop].toarray(), order="F") + if cpsp.issparse(X): + rows = cp.asarray(row_ids, dtype=cp.int32) + return cp.asfortranarray(X[rows][:, start:stop].toarray()) + raise TypeError(f"Unsupported matrix type: {type(X)}") + + +def _choose_wilcoxon_chunk_size(requested: int | None, n_genes: int) -> int: + if requested is not None: + return _choose_chunk_size(requested) + return min(DEFAULT_WILCOXON_CHUNK_SIZE, max(1, n_genes)) + + +def _fill_ovo_chunk_stats( + rg: _RankGenes, + ref_block: cp.ndarray, + grp_block: cp.ndarray, + *, + offsets: np.ndarray, + test_group_indices: list[int], + start: int, + stop: int, + group_sizes: NDArray, +) -> None: + if not rg._compute_stats_in_chunks: + return + + ireference = rg.ireference + n_ref = int(group_sizes[ireference]) + ref_mean = ref_block.mean(axis=0) + rg.means[ireference, start:stop] = cp.asnumpy(ref_mean) + if n_ref > 1: + rg.vars[ireference, start:stop] = cp.asnumpy(ref_block.var(axis=0, ddof=1)) + if rg.comp_pts: + ref_nnz = (ref_block != 0).sum(axis=0) + rg.pts[ireference, start:stop] = cp.asnumpy(ref_nnz / n_ref) + + for slot, group_index in enumerate(test_group_indices): + begin = int(offsets[slot]) + end = int(offsets[slot + 1]) + n_group = int(group_sizes[group_index]) + group_block = grp_block[begin:end] + group_mean = group_block.mean(axis=0) + rg.means[group_index, start:stop] = cp.asnumpy(group_mean) + if n_group > 1: + rg.vars[group_index, start:stop] = cp.asnumpy( + group_block.var(axis=0, ddof=1) + ) + if rg.comp_pts: + group_nnz = (group_block != 0).sum(axis=0) + rg.pts[group_index, start:stop] = cp.asnumpy(group_nnz / n_group) + + +def _fill_basic_stats_from_accumulators( + rg: _RankGenes, + group_sums: cp.ndarray, + group_sq_sums: cp.ndarray, + group_nnz: cp.ndarray, + group_sizes: np.ndarray, + *, + n_cells: int, + compute_vars: bool, + total_sums: cp.ndarray | None = None, + total_sq_sums: cp.ndarray | None = None, + total_nnz: cp.ndarray | None = None, +) -> None: + n = cp.asarray(group_sizes, dtype=cp.float64)[:, None] + means = group_sums / n + rg.means = cp.asnumpy(means) + if compute_vars: + group_ss = group_sq_sums - n * means**2 + rg.vars = cp.asnumpy(cp.maximum(group_ss / cp.maximum(n - 1, 1), 0)) + else: + rg.vars = np.zeros_like(rg.means) + rg.pts = cp.asnumpy(group_nnz / n) if rg.comp_pts else None + + n_rest = cp.float64(n_cells) - n + if total_sums is None: + total_sums = group_sums.sum(axis=0, keepdims=True) + rest_sums = total_sums - group_sums + rest_means = rest_sums / n_rest + rg.means_rest = cp.asnumpy(rest_means) + if compute_vars: + if total_sq_sums is None: + total_sq_sums = group_sq_sums.sum(axis=0, keepdims=True) + rest_ss = (total_sq_sums - group_sq_sums) - n_rest * rest_means**2 + rg.vars_rest = cp.asnumpy(cp.maximum(rest_ss / cp.maximum(n_rest - 1, 1), 0)) + else: + rg.vars_rest = np.zeros_like(rg.means_rest) + if rg.comp_pts: + if total_nnz is None: + total_nnz = group_nnz.sum(axis=0, keepdims=True) + rg.pts_rest = cp.asnumpy((total_nnz - group_nnz) / n_rest) + else: + rg.pts_rest = None + rg._compute_stats_in_chunks = False + + +def _fill_ovo_stats_from_accumulators( + rg: _RankGenes, + group_sums_slots: cp.ndarray, + group_sq_sums_slots: cp.ndarray, + group_nnz_slots: cp.ndarray, + *, + group_sizes: NDArray, + test_group_indices: list[int], + n_ref: int, + compute_vars: bool, +) -> None: + n_test = len(test_group_indices) + n_genes = int(group_sums_slots.shape[1]) + n_groups = len(rg.groups_order) + slot_group_indices = np.empty(n_test + 1, dtype=np.intp) + slot_group_indices[:n_test] = np.asarray(test_group_indices, dtype=np.intp) + slot_group_indices[n_test] = rg.ireference + slot_sizes = np.empty(n_test + 1, dtype=np.float64) + slot_sizes[:n_test] = group_sizes[slot_group_indices[:n_test]] + slot_sizes[n_test] = n_ref + slot_sizes_dev = cp.asarray(slot_sizes, dtype=cp.float64)[:, None] + + rg.means = np.zeros((n_groups, n_genes), dtype=np.float64) + rg.vars = np.zeros((n_groups, n_genes), dtype=np.float64) + rg.pts = np.zeros((n_groups, n_genes), dtype=np.float64) if rg.comp_pts else None + + means_slots = group_sums_slots / slot_sizes_dev + rg.means[slot_group_indices] = cp.asnumpy(means_slots) + if compute_vars: + group_ss = group_sq_sums_slots - slot_sizes_dev * means_slots**2 + denom = cp.maximum(slot_sizes_dev - 1.0, 1.0) + rg.vars[slot_group_indices] = cp.asnumpy(cp.maximum(group_ss / denom, 0)) + if rg.comp_pts: + rg.pts[slot_group_indices] = cp.asnumpy(group_nnz_slots / slot_sizes_dev) + + rg.means_rest = None + rg.vars_rest = None + rg.pts_rest = None + rg._compute_stats_in_chunks = False + + +def _ovo_logfoldchanges_from_sums( + rg: _RankGenes, + group_sums_slots: cp.ndarray, + test_sizes: cp.ndarray, + n_ref: int, +) -> cp.ndarray: + n_test = int(test_sizes.shape[0]) + mean_group = group_sums_slots[:n_test] / test_sizes[:, None] + mean_ref = group_sums_slots[n_test][None, :] / cp.float64(n_ref) + if rg._log1p_base is not None: + scale = cp.float64(np.log(rg._log1p_base)) + group_expr = cp.expm1(mean_group * scale) + ref_expr = cp.expm1(mean_ref * scale) + else: + group_expr = cp.expm1(mean_group) + ref_expr = cp.expm1(mean_ref) + return cp.log2((group_expr + EPS) / (ref_expr + EPS)) + + +def _wilcoxon_scores( + rank_sums: cp.ndarray, + group_sizes: cp.ndarray, + z_scores: cp.ndarray, + *, + return_u_values: bool, +) -> cp.ndarray: + if not return_u_values: + return z_scores + n_group = group_sizes[:, None] + return rank_sums - n_group * (n_group + 1.0) / 2.0 + + +def _host_sparse_fn_and_arrays(module, base_name: str, X, *, support_idx64: bool): + is_f64 = X.data.dtype == np.float64 + is_idx64 = support_idx64 and X.indices.dtype == np.int64 + is_i64 = X.indptr.dtype == np.int64 + suffix = "" + if is_f64: + suffix += "_f64" + if is_idx64: + suffix += "_idx64" + if is_i64: + suffix += "_i64" + fn = getattr(module, base_name + suffix) + data_arr = X.data if is_f64 else X.data.astype(np.float32, copy=False) + indices_arr = X.indices if is_idx64 else X.indices.astype(np.int32, copy=False) + return fn, data_arr, indices_arr + + +def _device_sparse_arrays_i32_f32(X): + if X.indptr.dtype != cp.int32: + max_indptr = int(cp.asnumpy(X.indptr[-1])) + if max_indptr > np.iinfo(np.int32).max: + return None + data = X.data.astype(cp.float32, copy=False) + indices = X.indices.astype(cp.int32, copy=False) + indptr = X.indptr.astype(cp.int32, copy=False) + return data, indices, indptr + + +def _column_totals_for_host_matrix( + X, *, compute_sq_sums: bool, compute_nnz: bool +) -> tuple[cp.ndarray, cp.ndarray | None, cp.ndarray | None]: + n_cols = X.shape[1] + if isinstance(X, sp.spmatrix | sp.sparray): + data = np.asarray(X.data) + values = data.astype(np.float64, copy=False) + if X.format == "csc": + indptr = np.asarray(X.indptr) + counts = np.diff(indptr) + nonempty = counts > 0 + starts = indptr[:-1][nonempty] + sums = np.zeros(n_cols, dtype=np.float64) + if starts.size: + sums[nonempty] = np.add.reduceat(values, starts) + sq_sums = None + if compute_sq_sums: + sq_sums = np.zeros(n_cols, dtype=np.float64) + if starts.size: + sq_sums[nonempty] = np.add.reduceat(values * values, starts) + nnz = None + if compute_nnz: + nnz = np.zeros(n_cols, dtype=np.float64) + if starts.size: + nnz[nonempty] = np.add.reduceat( + (data != 0).astype(np.float64, copy=False), starts + ) + elif X.format == "csr": + indices = np.asarray(X.indices, dtype=np.intp) + sums = np.bincount(indices, weights=values, minlength=n_cols).astype( + np.float64, copy=False + ) + sq_sums = ( + np.bincount(indices, weights=values * values, minlength=n_cols).astype( + np.float64, copy=False + ) + if compute_sq_sums + else None + ) + nnz = ( + np.bincount( + indices, + weights=(data != 0).astype(np.float64, copy=False), + minlength=n_cols, + ).astype(np.float64, copy=False) + if compute_nnz + else None + ) + else: + raise TypeError( + "Wilcoxon sparse input must be CSR or CSC; refusing hidden " + f"full-matrix conversion from {X.format!r}." + ) + else: + raise TypeError(f"Unsupported host matrix type: {type(X)}") + + total_sums = cp.asarray(sums.reshape(1, n_cols), dtype=cp.float64) + total_sq_sums = ( + cp.asarray(sq_sums.reshape(1, n_cols), dtype=cp.float64) + if sq_sums is not None + else None + ) + total_nnz = ( + cp.asarray(nnz.reshape(1, n_cols), dtype=cp.float64) + if nnz is not None + else None + ) + return total_sums, total_sq_sums, total_nnz + + +def _host_ovr_totals_if_needed( + X, + group_codes: np.ndarray, + n_groups: int, + *, + compute_sq_sums: bool, + compute_nnz: bool, +) -> tuple[cp.ndarray | None, cp.ndarray | None, cp.ndarray | None]: + if not np.any(group_codes == n_groups): + return None, None, None + return _column_totals_for_host_matrix( + X, compute_sq_sums=compute_sq_sums, compute_nnz=compute_nnz + ) + + def wilcoxon( rg: _RankGenes, *, tie_correct: bool, use_continuity: bool = False, chunk_size: int | None = None, + return_u_values: bool = False, ) -> list[tuple[int, NDArray, NDArray]]: """Compute Wilcoxon rank-sum test statistics.""" # Compute basic stats - uses Aggregate if on GPU, else defers to chunks @@ -110,6 +414,7 @@ def wilcoxon( tie_correct=tie_correct, use_continuity=use_continuity, chunk_size=chunk_size, + return_u_values=return_u_values, ) # Compare each group against "rest" (all other cells) return _wilcoxon_vs_rest( @@ -121,6 +426,7 @@ def wilcoxon( tie_correct=tie_correct, use_continuity=use_continuity, chunk_size=chunk_size, + return_u_values=return_u_values, ) @@ -134,6 +440,7 @@ def _wilcoxon_vs_rest( tie_correct: bool, use_continuity: bool, chunk_size: int | None, + return_u_values: bool, ) -> list[tuple[int, NDArray, NDArray]]: """Wilcoxon test: each group vs rest of cells.""" n_groups = len(rg.groups_order) @@ -149,26 +456,203 @@ def _wilcoxon_vs_rest( stacklevel=4, ) - # Build one-hot indicator matrix from group codes - codes_gpu = cp.asarray(rg.group_codes, dtype=cp.int64) - group_matrix = cp.zeros((n_cells, n_groups), dtype=cp.float64) - valid_idx = cp.where(codes_gpu < n_groups)[0] - group_matrix[valid_idx, codes_gpu[valid_idx]] = 1.0 + host_sparse = isinstance(X, sp.spmatrix | sp.sparray) + if host_sparse: + if X.format not in {"csr", "csc"}: + raise TypeError( + "Wilcoxon sparse input must be CSR or CSC; refusing hidden " + f"full-matrix conversion from {X.format!r}." + ) + + group_codes = rg.group_codes.astype(np.int32, copy=False) + group_sizes_np = group_sizes.astype(np.float64, copy=False) + group_sizes_dev = cp.asarray(group_sizes_np, dtype=cp.float64) + rest_sizes = n_cells - group_sizes_dev + compute_vars = False + compute_nnz = rg.comp_pts + + rank_sums = cp.empty((n_groups, n_total_genes), dtype=cp.float64) + tie_corr = cp.ones(n_total_genes, dtype=cp.float64) + group_sums = cp.empty((n_groups, n_total_genes), dtype=cp.float64) + group_sq_sums = cp.empty( + (n_groups, n_total_genes) if compute_vars else (1, 1), + dtype=cp.float64, + ) + group_nnz = cp.empty( + (n_groups, n_total_genes) if compute_nnz else (1, 1), + dtype=cp.float64, + ) + + if X.format == "csc": + csc = X + if not csc.has_sorted_indices: + csc = csc.copy() + csc.sort_indices() + csc_host_fn, data_arr, indices_arr = _host_sparse_fn_and_arrays( + _wcs, "ovr_sparse_csc_host", csc, support_idx64=False + ) + csc_host_fn( + data_arr, + indices_arr, + csc.indptr, + group_codes, + group_sizes_np, + rank_sums, + tie_corr, + group_sums, + group_sq_sums, + group_nnz, + n_rows=n_cells, + n_cols=n_total_genes, + n_groups=n_groups, + compute_tie_corr=tie_correct, + compute_sq_sums=compute_vars, + compute_nnz=compute_nnz, + sub_batch_cols=OVR_HOST_CSC_SUB_BATCH, + ) + else: + csr = X + if not csr.has_sorted_indices: + csr = csr.copy() + csr.sort_indices() + csr_host_fn, data_arr, indices_arr = _host_sparse_fn_and_arrays( + _wcs, "ovr_sparse_csr_host", csr, support_idx64=True + ) + csr_host_fn( + data_arr, + indices_arr, + csr.indptr, + group_codes, + group_sizes_np, + rank_sums, + tie_corr, + group_sums, + group_sq_sums, + group_nnz, + n_rows=n_cells, + n_cols=n_total_genes, + n_groups=n_groups, + compute_tie_corr=tie_correct, + compute_sq_sums=compute_vars, + compute_nnz=compute_nnz, + sub_batch_cols=OVR_HOST_CSR_SUB_BATCH, + ) + + if rg._compute_stats_in_chunks: + total_sums, total_sq_sums, total_nnz = _host_ovr_totals_if_needed( + X, + group_codes, + n_groups, + compute_sq_sums=compute_vars, + compute_nnz=compute_nnz, + ) + _fill_basic_stats_from_accumulators( + rg, + group_sums, + group_sq_sums, + group_nnz, + group_sizes_np, + n_cells=n_cells, + compute_vars=compute_vars, + total_sums=total_sums, + total_sq_sums=total_sq_sums, + total_nnz=total_nnz, + ) + + expected = group_sizes_dev[:, None] * (n_cells + 1) / 2.0 + variance = tie_corr[None, :] * group_sizes_dev[:, None] * rest_sizes[:, None] + variance *= (n_cells + 1) / 12.0 + diff = rank_sums - expected + if use_continuity: + diff = cp.sign(diff) * cp.maximum(cp.abs(diff) - 0.5, 0.0) + z = diff / cp.sqrt(variance) + cp.nan_to_num(z, copy=False) + p_values = cupyx_special.erfc(cp.abs(z) * cp.float64(cp.sqrt(0.5))) + scores_host = _wilcoxon_scores( + rank_sums, group_sizes_dev, z, return_u_values=return_u_values + ).get() + p_host = p_values.get() + return [(gi, scores_host[gi], p_host[gi]) for gi in range(n_groups)] + + if cpsp.isspmatrix_csc(X) or cpsp.isspmatrix_csr(X): + sparse_arrays = _device_sparse_arrays_i32_f32(X) + if sparse_arrays is not None: + data, indices, indptr = sparse_arrays + group_codes_gpu = cp.asarray(rg.group_codes, dtype=cp.int32) + group_sizes_dev = cp.asarray(group_sizes, dtype=cp.float64) + rest_sizes = n_cells - group_sizes_dev + rank_sums = cp.empty((n_groups, n_total_genes), dtype=cp.float64) + tie_corr = cp.ones(n_total_genes, dtype=cp.float64) + if cpsp.isspmatrix_csc(X): + _wcs.ovr_sparse_csc_device( + data, + indices, + indptr, + group_codes_gpu, + group_sizes_dev, + rank_sums, + tie_corr, + n_rows=n_cells, + n_cols=n_total_genes, + n_groups=n_groups, + compute_tie_corr=tie_correct, + sub_batch_cols=OVR_DEVICE_CSC_SUB_BATCH, + ) + else: + sparse_X = X + if not sparse_X.has_sorted_indices: + sparse_X = sparse_X.copy() + sparse_X.sort_indices() + data, indices, indptr = _device_sparse_arrays_i32_f32(sparse_X) + _wcs.ovr_sparse_csr_device( + data, + indices, + indptr, + group_codes_gpu, + group_sizes_dev, + rank_sums, + tie_corr, + n_rows=n_cells, + n_cols=n_total_genes, + n_groups=n_groups, + compute_tie_corr=tie_correct, + sub_batch_cols=OVR_DEVICE_CSR_SUB_BATCH, + ) + + expected = group_sizes_dev[:, None] * (n_cells + 1) / 2.0 + variance = ( + tie_corr[None, :] * group_sizes_dev[:, None] * rest_sizes[:, None] + ) + variance *= (n_cells + 1) / 12.0 + diff = rank_sums - expected + if use_continuity: + diff = cp.sign(diff) * cp.maximum(cp.abs(diff) - 0.5, 0.0) + z = diff / cp.sqrt(variance) + cp.nan_to_num(z, copy=False) + p_values = cupyx_special.erfc(cp.abs(z) * cp.float64(cp.sqrt(0.5))) + scores_host = _wilcoxon_scores( + rank_sums, group_sizes_dev, z, return_u_values=return_u_values + ).get() + p_host = p_values.get() + return [(gi, scores_host[gi], p_host[gi]) for gi in range(n_groups)] + + group_codes_gpu = cp.asarray(rg.group_codes, dtype=cp.int32) + group_matrix = None + if rg._compute_stats_in_chunks: + codes_gpu = cp.asarray(rg.group_codes, dtype=cp.int64) + group_matrix = cp.zeros((n_cells, n_groups), dtype=cp.float64) + valid_idx = cp.where(codes_gpu < n_groups)[0] + group_matrix[valid_idx, codes_gpu[valid_idx]] = 1.0 group_sizes_dev = cp.asarray(group_sizes, dtype=cp.float64) rest_sizes = n_cells - group_sizes_dev - chunk_width = _choose_chunk_size(chunk_size) + chunk_width = _choose_wilcoxon_chunk_size(chunk_size, n_total_genes) # Accumulate results per group all_scores: dict[int, list] = {i: [] for i in range(n_groups)} all_pvals: dict[int, list] = {i: [] for i in range(n_groups)} - # One-time CSR->CSC via fast parallel Numba kernel; _get_column_block - # then uses direct indptr pointer copy for each chunk. - if isinstance(X, sp.spmatrix | sp.sparray): - X = _fast_csr_to_csc(X) if X.format == "csr" else X.tocsc() - for start in range(0, n_total_genes, chunk_width): stop = min(start + chunk_width, n_total_genes) @@ -185,14 +669,28 @@ def _wilcoxon_vs_rest( n_cells=n_cells, ) - if tie_correct: - ranks, sorted_vals = _average_ranks(block, return_sorted=True) - tie_corr = _tie_correction(sorted_vals) - else: - ranks = _average_ranks(block) - tie_corr = cp.ones(ranks.shape[1], dtype=cp.float64) - - rank_sums = group_matrix.T @ ranks + block_f32 = cp.asfortranarray(block.astype(cp.float32, copy=False)) + sorter = cp.asfortranarray(cp.argsort(block_f32, axis=0).astype(cp.int32)) + sorted_vals = cp.asfortranarray(cp.take_along_axis(block_f32, sorter, axis=0)) + n_cols = stop - start + rank_sums = cp.zeros((n_groups, n_cols), dtype=cp.float64) + tie_corr = ( + cp.empty(n_cols, dtype=cp.float64) + if tie_correct + else cp.ones(n_cols, dtype=cp.float64) + ) + _wc.ovr_rank_dense( + sorted_vals, + sorter, + group_codes_gpu, + rank_sums, + tie_corr, + n_rows=n_cells, + n_cols=n_cols, + n_groups=n_groups, + compute_tie_corr=tie_correct, + stream=cp.cuda.get_current_stream().ptr, + ) expected = group_sizes_dev[:, None] * (n_cells + 1) / 2.0 variance = tie_corr[None, :] * group_sizes_dev[:, None] * rest_sizes[:, None] variance *= (n_cells + 1) / 12.0 @@ -203,12 +701,15 @@ def _wilcoxon_vs_rest( z = diff / std cp.nan_to_num(z, copy=False) p_values = cupyx_special.erfc(cp.abs(z) * cp.float64(cp.sqrt(0.5))) + scores = _wilcoxon_scores( + rank_sums, group_sizes_dev, z, return_u_values=return_u_values + ) - z_host = z.get() + scores_host = scores.get() p_host = p_values.get() for idx in range(n_groups): - all_scores[idx].append(z_host[idx]) + all_scores[idx].append(scores_host[idx]) all_pvals[idx].append(p_host[idx]) # Collect results per group @@ -227,98 +728,379 @@ def _wilcoxon_with_reference( tie_correct: bool, use_continuity: bool, chunk_size: int | None, + return_u_values: bool, ) -> list[tuple[int, NDArray, NDArray]]: - """Wilcoxon test: each group vs a specific reference group.""" + """Wilcoxon test: all selected groups vs a specific reference group.""" codes = rg.group_codes - n_ref = int(group_sizes[rg.ireference]) - mask_ref = codes == rg.ireference - - results: list[tuple[int, NDArray, NDArray]] = [] + n_groups = len(rg.groups_order) + ireference = rg.ireference + n_ref = int(group_sizes[ireference]) + ref_row_ids = np.flatnonzero(codes == ireference).astype(np.int32, copy=False) - for group_index in range(len(rg.groups_order)): - if group_index == rg.ireference: - continue + test_group_indices = [i for i in range(n_groups) if i != ireference] + if not test_group_indices: + return [] - n_group = int(group_sizes[group_index]) - n_combined = n_group + n_ref + offsets = [0] + row_id_parts = [] + small_groups = [] + for group_index in test_group_indices: + group_rows = np.flatnonzero(codes == group_index).astype(np.int32, copy=False) + row_id_parts.append(group_rows) + offsets.append(offsets[-1] + int(group_rows.size)) + if int(group_sizes[group_index]) <= MIN_GROUP_SIZE_WARNING: + small_groups.append(str(rg.groups_order[group_index])) - # Warn for small groups - if n_group <= MIN_GROUP_SIZE_WARNING or n_ref <= MIN_GROUP_SIZE_WARNING: - warnings.warn( - f"Group {rg.groups_order[group_index]} has size {n_group} " - f"(reference {n_ref}); normal approximation " - "of the Wilcoxon statistic may be inaccurate.", - RuntimeWarning, - stacklevel=4, + if n_ref <= MIN_GROUP_SIZE_WARNING or small_groups: + parts = [] + if small_groups: + parts.append( + f"{len(small_groups)} test group(s) have size " + f"<= {MIN_GROUP_SIZE_WARNING} (first few: " + f"{', '.join(small_groups[:5])}" + f"{'...' if len(small_groups) > 5 else ''})" ) + if n_ref <= MIN_GROUP_SIZE_WARNING: + parts.append(f"reference has size {n_ref}") + warnings.warn( + f"Small groups detected: {'; '.join(parts)}. normal approximation " + "of the Wilcoxon statistic may be inaccurate.", + RuntimeWarning, + stacklevel=4, + ) - # Combined mask: group + reference - mask_obs = codes == group_index - mask_combined = mask_obs | mask_ref - - # Subset matrix ONCE before chunking (10x faster than filtering each chunk) - X_subset = X[mask_combined, :] + all_grp_row_ids = ( + np.concatenate(row_id_parts).astype(np.int32, copy=False) + if row_id_parts + else np.empty(0, dtype=np.int32) + ) + offsets_np = np.asarray(offsets, dtype=np.int32) + offsets_gpu = cp.asarray(offsets_np) + n_all_grp = int(all_grp_row_ids.size) + n_test = len(test_group_indices) + max_test_size = int(np.diff(offsets_np).max(initial=0)) + use_presorted_groups = max_test_size > OVO_SORT_GROUP_THRESHOLD + test_sizes = cp.asarray( + group_sizes[np.asarray(test_group_indices, dtype=np.intp)].astype( + np.float64, copy=False + ) + ) - # One-time CSR->CSC via fast parallel Numba kernel - if isinstance(X_subset, sp.spmatrix | sp.sparray): - X_subset = ( - _fast_csr_to_csc(X_subset) - if X_subset.format == "csr" - else X_subset.tocsc() + host_sparse = isinstance(X, sp.spmatrix | sp.sparray) + if host_sparse: + if X.format not in {"csr", "csc"}: + raise TypeError( + "Wilcoxon sparse input must be CSR or CSC; refusing hidden " + f"full-matrix conversion from {X.format!r}." ) - # Within the combined array, True = group cell, False = reference cell - group_mask_gpu = cp.asarray(mask_obs[mask_combined]) - - chunk_width = _choose_chunk_size(chunk_size) - - # Pre-allocate output arrays - scores = np.empty(n_total_genes, dtype=np.float64) - pvals = np.empty(n_total_genes, dtype=np.float64) - - for start in range(0, n_total_genes, chunk_width): - stop = min(start + chunk_width, n_total_genes) + rank_sums = cp.empty((n_test, n_total_genes), dtype=cp.float64) + tie_corr_arr = cp.ones((n_test, n_total_genes), dtype=cp.float64) + n_groups_stats = n_test + 1 + compute_vars = False + compute_sums = rg._compute_stats_in_chunks + compute_nnz = rg.comp_pts + group_sums = cp.empty( + (n_groups_stats, n_total_genes) + if (compute_sums or X.format == "csc") + else (1,), + dtype=cp.float64, + ) + group_sq_sums = cp.empty( + (n_groups_stats, n_total_genes) if compute_vars else (1,), + dtype=cp.float64, + ) + group_nnz = cp.empty( + (n_groups_stats, n_total_genes) if compute_nnz else (1,), + dtype=cp.float64, + ) - # Get block for combined cells only - block = _get_column_block(X_subset, start, stop) + stats_code_lookup = np.full(n_groups + 1, n_groups_stats, dtype=np.int32) + test_group_indices_np = np.asarray(test_group_indices, dtype=np.intp) + stats_code_lookup[test_group_indices_np] = np.arange(n_test, dtype=np.int32) + stats_code_lookup[ireference] = n_test + stats_codes = stats_code_lookup[codes] - # Accumulate stats for this chunk - rg._accumulate_chunk_stats_with_ref( - block, - start, - stop, - group_index=group_index, - group_mask_gpu=group_mask_gpu, - n_group=n_group, + if X.format == "csc": + csc = X + if not csc.has_sorted_indices: + csc = csc.copy() + csc.sort_indices() + ref_row_map = np.full(X.shape[0], -1, dtype=np.int32) + ref_row_map[ref_row_ids] = np.arange(n_ref, dtype=np.int32) + grp_row_map = np.full(X.shape[0], -1, dtype=np.int32) + grp_row_map[all_grp_row_ids] = np.arange(n_all_grp, dtype=np.int32) + csc_host_fn, data_arr, indices_arr = _host_sparse_fn_and_arrays( + _wcs, "ovo_streaming_csc_host", csc, support_idx64=True + ) + csc_host_fn( + data_arr, + indices_arr, + csc.indptr, + ref_row_map, + grp_row_map, + offsets_np, + stats_codes, + rank_sums, + tie_corr_arr, + group_sums, + group_sq_sums, + group_nnz, + n_ref=n_ref, + n_all_grp=n_all_grp, + n_rows=X.shape[0], + n_cols=n_total_genes, + n_groups=n_test, + n_groups_stats=n_groups_stats, + compute_tie_corr=tie_correct, + compute_sq_sums=compute_vars, + compute_nnz=compute_nnz, + sub_batch_cols=OVO_HOST_SPARSE_SUB_BATCH, + ) + else: + csr = X + csr_host_fn, data_arr, indices_arr = _host_sparse_fn_and_arrays( + _wcs, "ovo_streaming_csr_host", csr, support_idx64=True + ) + csr_host_fn( + data_arr, + indices_arr, + csr.indptr, + ref_row_ids.astype(np.int32, copy=False), + all_grp_row_ids.astype(np.int32, copy=False), + offsets_np, + rank_sums, + tie_corr_arr, + group_sums, + group_sq_sums, + group_nnz, + n_full_rows=X.shape[0], n_ref=n_ref, + n_all_grp=n_all_grp, + n_cols=n_total_genes, + n_test=n_test, + n_groups_stats=n_groups_stats, + compute_tie_corr=tie_correct, + compute_sq_sums=compute_vars, + compute_nnz=compute_nnz, + compute_sums=compute_sums, + sub_batch_cols=OVO_HOST_SPARSE_SUB_BATCH, ) - # Ranks for combined group+reference cells - if tie_correct: - ranks, sorted_vals = _average_ranks(block, return_sorted=True) - tie_corr = _tie_correction(sorted_vals) + logfoldchanges_gpu = None + if rg._compute_stats_in_chunks: + if rg._store_wilcoxon_gpu_result and not rg.comp_pts: + logfoldchanges_gpu = _ovo_logfoldchanges_from_sums( + rg, + group_sums, + test_sizes, + n_ref, + ) + rg._compute_stats_in_chunks = False else: - ranks = _average_ranks(block) - tie_corr = cp.ones(ranks.shape[1], dtype=cp.float64) + _fill_ovo_stats_from_accumulators( + rg, + group_sums, + group_sq_sums, + group_nnz, + group_sizes=group_sizes, + test_group_indices=test_group_indices, + n_ref=n_ref, + compute_vars=compute_vars, + ) - # Rank sum for the group - rank_sums = (ranks * group_mask_gpu[:, None]).sum(axis=0) + n_combined = test_sizes + n_ref + expected = test_sizes[:, None] * (n_combined[:, None] + 1) / 2.0 + variance = test_sizes[:, None] * n_ref * (n_combined[:, None] + 1) / 12.0 + if tie_correct: + variance = variance * tie_corr_arr + diff = rank_sums - expected + if use_continuity: + diff = cp.sign(diff) * cp.maximum(cp.abs(diff) - 0.5, 0.0) + z = diff / cp.sqrt(variance) + cp.nan_to_num(z, copy=False) + p_values = cupyx_special.erfc(cp.abs(z) * cp.float64(cp.sqrt(0.5))) + scores = _wilcoxon_scores( + rank_sums, test_sizes, z, return_u_values=return_u_values + ) + if rg._store_wilcoxon_gpu_result: + rg._wilcoxon_gpu_result = ( + np.asarray(test_group_indices, dtype=np.intp), + scores, + p_values, + logfoldchanges_gpu, + ) + return [] + scores_host = scores.get() + p_host = p_values.get() + return [ + (group_index, scores_host[slot], p_host[slot]) + for slot, group_index in enumerate(test_group_indices) + ] + + if cpsp.isspmatrix_csc(X) or cpsp.isspmatrix_csr(X): + sparse_X = X + if cpsp.isspmatrix_csr(sparse_X) and not sparse_X.has_sorted_indices: + sparse_X = sparse_X.copy() + sparse_X.sort_indices() + sparse_arrays = _device_sparse_arrays_i32_f32(sparse_X) + if sparse_arrays is not None: + data, indices, indptr = sparse_arrays + offsets_gpu = cp.asarray(offsets_np, dtype=cp.int32) + rank_sums = cp.empty((n_test, n_total_genes), dtype=cp.float64) + tie_corr_arr = cp.ones((n_test, n_total_genes), dtype=cp.float64) + + if cpsp.isspmatrix_csc(sparse_X): + ref_row_map = np.full(X.shape[0], -1, dtype=np.int32) + ref_row_map[ref_row_ids] = np.arange(n_ref, dtype=np.int32) + grp_row_map = np.full(X.shape[0], -1, dtype=np.int32) + grp_row_map[all_grp_row_ids] = np.arange(n_all_grp, dtype=np.int32) + _wcs.ovo_streaming_csc_device( + data, + indices, + indptr, + cp.asarray(ref_row_map), + cp.asarray(grp_row_map), + offsets_gpu, + rank_sums, + tie_corr_arr, + n_ref=n_ref, + n_all_grp=n_all_grp, + n_cols=n_total_genes, + n_groups=n_test, + compute_tie_corr=tie_correct, + sub_batch_cols=OVO_DEVICE_SPARSE_SUB_BATCH, + ) + else: + _wcs.ovo_streaming_csr_device( + data, + indices, + indptr, + cp.asarray(ref_row_ids, dtype=cp.int32), + cp.asarray(all_grp_row_ids, dtype=cp.int32), + offsets_gpu, + rank_sums, + tie_corr_arr, + n_ref=n_ref, + n_all_grp=n_all_grp, + n_cols=n_total_genes, + n_groups=n_test, + compute_tie_corr=tie_correct, + sub_batch_cols=OVO_DEVICE_SPARSE_SUB_BATCH, + ) - # Wilcoxon z-score formula for two groups - expected = n_group * (n_combined + 1) / 2.0 - variance = tie_corr * n_group * n_ref * (n_combined + 1) / 12.0 - std = cp.sqrt(variance) + n_combined = test_sizes + n_ref + expected = test_sizes[:, None] * (n_combined[:, None] + 1) / 2.0 + variance = test_sizes[:, None] * n_ref * (n_combined[:, None] + 1) / 12.0 + if tie_correct: + variance = variance * tie_corr_arr diff = rank_sums - expected if use_continuity: diff = cp.sign(diff) * cp.maximum(cp.abs(diff) - 0.5, 0.0) - z = diff / std + z = diff / cp.sqrt(variance) cp.nan_to_num(z, copy=False) p_values = cupyx_special.erfc(cp.abs(z) * cp.float64(cp.sqrt(0.5))) + scores = _wilcoxon_scores( + rank_sums, test_sizes, z, return_u_values=return_u_values + ) + if rg._store_wilcoxon_gpu_result: + rg._wilcoxon_gpu_result = ( + np.asarray(test_group_indices, dtype=np.intp), + scores, + p_values, + None, + ) + return [] + scores_host = scores.get() + p_host = p_values.get() + return [ + (group_index, scores_host[slot], p_host[slot]) + for slot, group_index in enumerate(test_group_indices) + ] + + chunk_width = _choose_wilcoxon_chunk_size(chunk_size, n_total_genes) + + scores_host = np.empty((n_test, n_total_genes), dtype=np.float64) + pvals_host = np.empty((n_test, n_total_genes), dtype=np.float64) + + for start in range(0, n_total_genes, chunk_width): + stop = min(start + chunk_width, n_total_genes) + n_cols = stop - start + + ref_block = _extract_dense_rows_cols(X, ref_row_ids, start, stop) + grp_block = _extract_dense_rows_cols(X, all_grp_row_ids, start, stop) + + _fill_ovo_chunk_stats( + rg, + ref_block, + grp_block, + offsets=offsets_np, + test_group_indices=test_group_indices, + start=start, + stop=stop, + group_sizes=group_sizes, + ) + + ref_sorted = cp.asfortranarray(cp.sort(ref_block.astype(cp.float32), axis=0)) + grp_f32 = cp.asfortranarray(grp_block.astype(cp.float32, copy=False)) + rank_sums = cp.empty((n_test, n_cols), dtype=cp.float64) + tie_corr = cp.empty((n_test, n_cols), dtype=cp.float64) + + if use_presorted_groups: + grp_rank_input = cp.empty_like(grp_f32) + for slot in range(n_test): + begin = int(offsets_np[slot]) + end = int(offsets_np[slot + 1]) + grp_rank_input[begin:end] = cp.sort(grp_f32[begin:end], axis=0) + grp_rank_input = cp.asfortranarray(grp_rank_input) + _wc.ovo_rank_presorted( + ref_sorted, + grp_rank_input, + offsets_gpu, + rank_sums, + tie_corr, + n_ref=n_ref, + n_all_grp=n_all_grp, + n_cols=n_cols, + n_groups=n_test, + compute_tie_corr=tie_correct, + stream=cp.cuda.get_current_stream().ptr, + ) + else: + _wc.ovo_rank_dense( + ref_sorted, + grp_f32, + offsets_gpu, + rank_sums, + tie_corr, + n_ref=n_ref, + n_all_grp=n_all_grp, + n_cols=n_cols, + n_groups=n_test, + compute_tie_corr=tie_correct, + stream=cp.cuda.get_current_stream().ptr, + ) - # Fill pre-allocated arrays - scores[start:stop] = z.get() - pvals[start:stop] = p_values.get() + n_combined = test_sizes + n_ref + expected = test_sizes[:, None] * (n_combined[:, None] + 1) / 2.0 + variance = test_sizes[:, None] * n_ref * (n_combined[:, None] + 1) / 12.0 + if tie_correct: + variance = variance * tie_corr + std = cp.sqrt(variance) + diff = rank_sums - expected + if use_continuity: + diff = cp.sign(diff) * cp.maximum(cp.abs(diff) - 0.5, 0.0) + z = diff / std + cp.nan_to_num(z, copy=False) + p_values = cupyx_special.erfc(cp.abs(z) * cp.float64(cp.sqrt(0.5))) + scores = _wilcoxon_scores( + rank_sums, test_sizes, z, return_u_values=return_u_values + ) - results.append((group_index, scores, pvals)) + scores_host[:, start:stop] = scores.get() + pvals_host[:, start:stop] = p_values.get() - return results + return [ + (group_index, scores_host[slot], pvals_host[slot]) + for slot, group_index in enumerate(test_group_indices) + ] diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon_binned.py b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon_binned.py index fa4bbccf..70d049af 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon_binned.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon_binned.py @@ -102,7 +102,7 @@ def wilcoxon_binned( ``'log1p'`` uses a fixed [0, 15] range suitable for log1p-normalized data. ``'auto'`` computes the actual (min, max) of the data. Use this - for z-scored or unnormalized data. + for nonnegative expression data outside the fixed log1p range. """ if not rg.is_log1p: warnings.warn( @@ -119,20 +119,6 @@ def wilcoxon_binned( if n_bins is None: n_bins = _DASK_N_BINS if isinstance(X, DaskArray) else _DEFAULT_N_BINS - # Sparse kernels assume non-negative data (pre-fill+correct pattern). - # Dense kernel handles any range. - # NOTE: Dask sparse is not validated here because checking .data.min() - # would require materializing all blocks. The sparse histogram kernels - # will silently produce incorrect results for negative Dask sparse data. - if not isinstance(X, DaskArray) and cpsp.issparse(X) and X.nnz > 0: - if float(X.data.min()) < 0: - msg = ( - "Sparse input contains negative values. The sparse histogram " - "kernels assume non-negative data. Convert to dense or use " - "bin_range='auto' with a dense array." - ) - raise ValueError(msg) - n_groups = len(rg.groups_order) n_cells, n_genes = X.shape group_sizes = rg.group_sizes diff --git a/tests/test_rank_genes_groups_ttest.py b/tests/test_rank_genes_groups_ttest.py index 8fe93ae7..24a40721 100644 --- a/tests/test_rank_genes_groups_ttest.py +++ b/tests/test_rank_genes_groups_ttest.py @@ -20,6 +20,7 @@ def test_rank_genes_groups_ttest_matches_scanpy(reference, method, sparse): adata_gpu.obs["blobs"] = adata_gpu.obs["blobs"].astype("category") if sparse: + adata_gpu.X = np.abs(adata_gpu.X).astype(np.float32) adata_gpu.X = sp.csr_matrix(adata_gpu.X) adata_cpu = adata_gpu.copy() @@ -52,12 +53,19 @@ def test_rank_genes_groups_ttest_matches_scanpy(reference, method, sparse): for field in ("scores", "logfoldchanges", "pvals", "pvals_adj"): gpu_field = gpu_result[field] cpu_field = cpu_result[field] + rtol = 1e-6 if sparse else 1e-13 + if sparse and field in {"scores", "logfoldchanges"}: + atol = 1e-6 + elif sparse: + atol = 1e-12 + else: + atol = 1e-15 assert gpu_field.dtype.names == cpu_field.dtype.names for group in gpu_field.dtype.names: gpu_values = np.asarray(gpu_field[group], dtype=float) cpu_values = np.asarray(cpu_field[group], dtype=float) np.testing.assert_allclose( - gpu_values, cpu_values, rtol=1e-13, atol=1e-15, equal_nan=True + gpu_values, cpu_values, rtol=rtol, atol=atol, equal_nan=True ) params = gpu_result["params"] diff --git a/tests/test_rank_genes_groups_wilcoxon.py b/tests/test_rank_genes_groups_wilcoxon.py index 0c6844da..87030dfb 100644 --- a/tests/test_rank_genes_groups_wilcoxon.py +++ b/tests/test_rank_genes_groups_wilcoxon.py @@ -1,6 +1,7 @@ from __future__ import annotations import cupy as cp +import cupyx.scipy.sparse as cpsp import numpy as np import pandas as pd import pytest @@ -11,6 +12,177 @@ import rapids_singlecell as rsc +def _to_format(X_dense, fmt): + if fmt == "numpy_dense": + return np.asarray(X_dense) + if fmt == "scipy_csr": + return sp.csr_matrix(X_dense) + if fmt == "scipy_csc": + return sp.csc_matrix(X_dense) + if fmt == "cupy_dense": + return cp.asarray(X_dense) + if fmt == "cupy_csr": + return cpsp.csr_matrix(cp.asarray(X_dense)) + if fmt == "cupy_csc": + return cpsp.csc_matrix(cp.asarray(X_dense)) + raise ValueError(f"Unknown format: {fmt}") + + +def _make_nonnegative(adata): + adata.X = np.abs(np.asarray(adata.X)).astype(np.float32) + return adata + + +@pytest.mark.parametrize( + "method", + ["t-test", "t-test_overestim_var", "wilcoxon", "wilcoxon_binned", "logreg"], +) +@pytest.mark.parametrize("fmt", ["scipy_csr", "scipy_csc", "cupy_csr", "cupy_csc"]) +def test_rank_genes_groups_sparse_negative_values_raise(method, fmt): + X = np.array( + [ + [-1.0, 0.0, 2.0], + [0.0, 1.0, 0.0], + [2.0, 0.0, 1.0], + [0.0, 3.0, 0.0], + ], + dtype=np.float32, + ) + adata = sc.AnnData( + X=_to_format(X, fmt), + obs=pd.DataFrame( + {"group": pd.Categorical(["a", "a", "b", "b"], categories=["a", "b"])} + ), + var=pd.DataFrame(index=["g0", "g1", "g2"]), + ) + + with pytest.raises(ValueError, match="Sparse input contains negative values"): + rsc.tl.rank_genes_groups(adata, "group", method=method, use_raw=False) + + +def test_rank_genes_groups_default_lazy_get_df_matches_scanpy(): + np.random.seed(42) + adata_lazy = sc.datasets.blobs(n_variables=6, n_centers=3, n_observations=120) + _make_nonnegative(adata_lazy) + adata_lazy.obs["blobs"] = adata_lazy.obs["blobs"].astype("category") + adata_lazy.X = sp.csr_matrix(adata_lazy.X) + adata_cpu = adata_lazy.copy() + adata_cpu.X = adata_cpu.X.toarray() + + kw = { + "groupby": "blobs", + "method": "wilcoxon", + "reference": "1", + "use_raw": False, + "tie_correct": True, + "n_genes": 4, + } + rsc.tl.rank_genes_groups(adata_lazy, **kw) + sc.tl.rank_genes_groups(adata_cpu, **kw) + + lazy_result = adata_lazy.uns["rank_genes_groups"] + assert lazy_result["names"].dtype.names == ("0", "2") + assert tuple(lazy_result["names"][0]) == tuple( + adata_cpu.uns["rank_genes_groups"]["names"][0] + ) + np.testing.assert_array_equal( + lazy_result["names"].copy(), + np.asarray(lazy_result["names"]), + ) + + lazy_df = sc.get.rank_genes_groups_df(adata_lazy, group=None) + scanpy_df = sc.get.rank_genes_groups_df(adata_cpu, group=None) + pd.testing.assert_frame_equal(lazy_df, scanpy_df) + + +def test_rank_genes_groups_return_format_removed(): + adata = sc.datasets.blobs(n_variables=3, n_centers=2, n_observations=20) + _make_nonnegative(adata) + adata.obs["blobs"] = adata.obs["blobs"].astype("category") + + with pytest.raises(TypeError, match="return_format has been removed"): + rsc.tl.rank_genes_groups( + adata, + "blobs", + method="wilcoxon", + use_raw=False, + return_format="arrays", + ) + + +@pytest.mark.parametrize("reference", ["rest", "b"]) +@pytest.mark.parametrize("fmt", ["numpy_dense", "scipy_csr", "cupy_csr"]) +def test_rank_genes_groups_wilcoxon_return_u_values(reference, fmt): + X = np.array( + [ + [5.0, 0.0, 1.0, 2.0], + [4.0, 0.0, 1.0, 2.0], + [1.0, 3.0, 2.0, 2.0], + [0.0, 2.0, 2.0, 2.0], + [2.0, 1.0, 0.0, 3.0], + [3.0, 1.0, 0.0, 3.0], + ], + dtype=np.float32, + ) + labels = np.array(["a", "a", "b", "b", "c", "c"]) + adata = sc.AnnData( + X=_to_format(X, fmt), + obs=pd.DataFrame({"group": pd.Categorical(labels)}), + var=pd.DataFrame(index=[f"g{i}" for i in range(X.shape[1])]), + ) + + rsc.tl.rank_genes_groups( + adata, + "group", + groups=["a"], + reference=reference, + method="wilcoxon", + use_raw=False, + tie_correct=True, + use_continuity=True, + return_u_values=True, + n_genes=adata.n_vars, + ) + + result = adata.uns["rank_genes_groups"] + assert result["params"]["return_u_values"] is True + assert result["scores"].dtype["a"] == np.dtype("float64") + + df = sc.get.rank_genes_groups_df(adata, group="a").sort_values("names") + mask_group = labels == "a" + mask_ref = labels != "a" if reference == "rest" else labels == reference + expected = np.array( + [ + mannwhitneyu( + X[mask_group, gene], + X[mask_ref, gene], + alternative="two-sided", + ).statistic + for gene in range(X.shape[1]) + ], + dtype=np.float64, + ) + + gene_to_idx = {name: idx for idx, name in enumerate(adata.var_names)} + expected_sorted = np.array([expected[gene_to_idx[name]] for name in df["names"]]) + np.testing.assert_allclose(df["scores"].to_numpy(), expected_sorted) + + +def test_rank_genes_groups_return_u_values_requires_wilcoxon(): + adata = sc.datasets.blobs(n_variables=3, n_centers=2, n_observations=20) + _make_nonnegative(adata) + adata.obs["blobs"] = adata.obs["blobs"].astype("category") + + with pytest.raises(ValueError, match="only supported for method='wilcoxon'"): + rsc.tl.rank_genes_groups( + adata, + "blobs", + method="t-test", + use_raw=False, + return_u_values=True, + ) + + @pytest.mark.parametrize("reference", ["rest", "1"]) @pytest.mark.parametrize("tie_correct", [True, False]) @pytest.mark.parametrize("sparse", [True, False]) @@ -21,6 +193,7 @@ def test_rank_genes_groups_wilcoxon_matches_scanpy(reference, tie_correct, spars adata_gpu.obs["blobs"] = adata_gpu.obs["blobs"].astype("category") if sparse: + _make_nonnegative(adata_gpu) adata_gpu.X = sp.csr_matrix(adata_gpu.X) adata_cpu = adata_gpu.copy() @@ -55,11 +228,13 @@ def test_rank_genes_groups_wilcoxon_matches_scanpy(reference, tie_correct, spars for field in ("scores", "logfoldchanges", "pvals", "pvals_adj"): gpu_field = gpu_result[field] cpu_field = cpu_result[field] + rtol = 1e-6 if field == "logfoldchanges" else 1e-13 assert gpu_field.dtype.names == cpu_field.dtype.names for group in gpu_field.dtype.names: gpu_values = np.asarray(gpu_field[group], dtype=float) cpu_values = np.asarray(cpu_field[group], dtype=float) - np.testing.assert_allclose(gpu_values, cpu_values, rtol=1e-13, atol=1e-15) + atol = 1e-6 if field == "logfoldchanges" else 1e-15 + np.testing.assert_allclose(gpu_values, cpu_values, rtol=rtol, atol=atol) params = gpu_result["params"] assert params["use_raw"] is False @@ -148,6 +323,230 @@ def test_rank_genes_groups_wilcoxon_subset_and_bonferroni(reference): assert np.all(adjusted <= 1.0) +def test_rank_genes_groups_wilcoxon_skip_empty_groups_filters_singletons(): + np.random.seed(42) + adata = sc.datasets.blobs(n_variables=5, n_centers=2, n_observations=21) + adata.obs["target"] = pd.Categorical( + ["ref"] * 10 + ["valid"] * 10 + ["singleton"], + categories=["ref", "valid", "singleton", "empty"], + ) + + rsc.tl.rank_genes_groups( + adata, + "target", + method="wilcoxon", + reference="ref", + use_raw=False, + n_genes=3, + skip_empty_groups=True, + ) + + result = adata.uns["rank_genes_groups"] + assert result["names"].dtype.names == ("valid",) + assert result["scores"].dtype.names == ("valid",) + + +def test_rank_genes_groups_wilcoxon_skip_empty_groups_all_tests_filtered(): + np.random.seed(42) + adata = sc.datasets.blobs(n_variables=5, n_centers=2, n_observations=11) + adata.obs["target"] = pd.Categorical( + ["ref"] * 10 + ["singleton"], + categories=["ref", "singleton", "empty"], + ) + + rsc.tl.rank_genes_groups( + adata, + "target", + method="wilcoxon", + reference="ref", + use_raw=False, + skip_empty_groups=True, + ) + + result = adata.uns["rank_genes_groups"] + assert "names" not in result + assert result["params"]["reference"] == "ref" + + +@pytest.mark.parametrize( + "fmt", + [ + pytest.param("scipy_csr", id="host_csr"), + pytest.param("scipy_csc", id="host_csc"), + pytest.param("cupy_dense", id="device_dense"), + ], +) +def test_wilcoxon_subset_rest_stats_match_scanpy(fmt): + """groups=... with reference='rest' must use all other cells for stats.""" + np.random.seed(42) + adata_gpu = sc.datasets.blobs(n_variables=6, n_centers=4, n_observations=160) + _make_nonnegative(adata_gpu) + adata_gpu.obs["blobs"] = adata_gpu.obs["blobs"].astype("category") + adata_cpu = adata_gpu.copy() + adata_gpu.X = _to_format(adata_gpu.X, fmt) + + kw = { + "groupby": "blobs", + "method": "wilcoxon", + "use_raw": False, + "groups": ["0", "2"], + "reference": "rest", + "pts": True, + "n_genes": 6, + } + rsc.tl.rank_genes_groups(adata_gpu, **kw) + sc.tl.rank_genes_groups(adata_cpu, **kw) + + gpu_result = adata_gpu.uns["rank_genes_groups"] + cpu_result = adata_cpu.uns["rank_genes_groups"] + for field in ("scores", "logfoldchanges", "pvals", "pvals_adj"): + rtol = 1e-6 if field == "logfoldchanges" else 1e-13 + atol = 1e-6 if field == "logfoldchanges" else 1e-15 + for group in gpu_result[field].dtype.names: + np.testing.assert_allclose( + np.asarray(gpu_result[field][group], dtype=float), + np.asarray(cpu_result[field][group], dtype=float), + rtol=rtol, + atol=atol, + equal_nan=True, + ) + + for key in ("pts", "pts_rest"): + gpu_pts = gpu_result[key] + cpu_pts = cpu_result[key] + for col in gpu_pts.columns: + np.testing.assert_allclose( + gpu_pts[col].values, cpu_pts[col].values, rtol=1e-13, atol=1e-15 + ) + + +@pytest.mark.parametrize("reference", ["rest", "1"]) +@pytest.mark.parametrize("fmt", ["scipy_csr", "scipy_csc"]) +def test_wilcoxon_zero_nnz_host_sparse_does_not_crash(reference, fmt): + obs = pd.DataFrame( + { + "group": pd.Categorical( + ["0"] * 4 + ["1"] * 4 + ["2"] * 4, + categories=["0", "1", "2"], + ) + } + ) + adata = sc.AnnData( + X=_to_format(np.zeros((12, 5), dtype=np.float32), fmt), + obs=obs, + var=pd.DataFrame(index=[f"g{i}" for i in range(5)]), + ) + + rsc.tl.rank_genes_groups( + adata, + "group", + method="wilcoxon", + use_raw=False, + reference=reference, + pts=True, + ) + + result = adata.uns["rank_genes_groups"] + for field in ("scores", "pvals"): + for group in result[field].dtype.names: + assert np.all(np.isfinite(np.asarray(result[field][group], dtype=float))) + + +def test_wilcoxon_ovo_host_csr_unsorted_indices_match_sorted(): + rng = np.random.default_rng(42) + dense = rng.poisson(1.0, size=(80, 12)).astype(np.float32) + dense[rng.random(dense.shape) < 0.55] = 0 + sorted_csr = sp.csr_matrix(dense) + unsorted_csr = sorted_csr.copy() + for row in range(unsorted_csr.shape[0]): + start, stop = unsorted_csr.indptr[row : row + 2] + order = np.arange(stop - start)[::-1] + unsorted_csr.indices[start:stop] = unsorted_csr.indices[start:stop][order] + unsorted_csr.data[start:stop] = unsorted_csr.data[start:stop][order] + unsorted_csr.has_sorted_indices = False + + obs = pd.DataFrame( + { + "group": pd.Categorical( + ["ref"] * 20 + ["a"] * 20 + ["b"] * 20 + ["c"] * 20, + categories=["ref", "a", "b", "c"], + ) + } + ) + var = pd.DataFrame(index=[f"g{i}" for i in range(dense.shape[1])]) + sorted_adata = sc.AnnData(X=sorted_csr, obs=obs.copy(), var=var.copy()) + unsorted_adata = sc.AnnData(X=unsorted_csr, obs=obs.copy(), var=var.copy()) + + kw = { + "groupby": "group", + "method": "wilcoxon", + "reference": "ref", + "use_raw": False, + "tie_correct": True, + "n_genes": dense.shape[1], + } + rsc.tl.rank_genes_groups(sorted_adata, **kw) + rsc.tl.rank_genes_groups(unsorted_adata, **kw) + + sorted_result = sorted_adata.uns["rank_genes_groups"] + unsorted_result = unsorted_adata.uns["rank_genes_groups"] + for field in ("scores", "logfoldchanges", "pvals", "pvals_adj"): + for group in sorted_result[field].dtype.names: + np.testing.assert_allclose( + np.asarray(unsorted_result[field][group], dtype=float), + np.asarray(sorted_result[field][group], dtype=float), + rtol=1e-13, + atol=1e-15, + equal_nan=True, + ) + + +@pytest.mark.parametrize("reference", ["rest", "1"]) +@pytest.mark.parametrize( + "fmt", + [ + "numpy_dense", + "scipy_csr", + "scipy_csc", + "cupy_dense", + "cupy_csr", + "cupy_csc", + ], +) +def test_wilcoxon_all_public_formats_match_scanpy(reference, fmt): + np.random.seed(42) + adata_gpu = sc.datasets.blobs(n_variables=5, n_centers=3, n_observations=120) + _make_nonnegative(adata_gpu) + adata_gpu.obs["blobs"] = adata_gpu.obs["blobs"].astype("category") + adata_cpu = adata_gpu.copy() + adata_gpu.X = _to_format(adata_gpu.X, fmt) + + kw = { + "groupby": "blobs", + "method": "wilcoxon", + "use_raw": False, + "reference": reference, + "tie_correct": True, + "n_genes": 5, + } + rsc.tl.rank_genes_groups(adata_gpu, **kw) + sc.tl.rank_genes_groups(adata_cpu, **kw) + + gpu_result = adata_gpu.uns["rank_genes_groups"] + cpu_result = adata_cpu.uns["rank_genes_groups"] + for field in ("scores", "logfoldchanges", "pvals", "pvals_adj"): + rtol = 1e-6 if field == "logfoldchanges" else 1e-13 + atol = 1e-6 if field == "logfoldchanges" else 1e-15 + for group in gpu_result[field].dtype.names: + np.testing.assert_allclose( + np.asarray(gpu_result[field][group], dtype=float), + np.asarray(cpu_result[field][group], dtype=float), + rtol=rtol, + atol=atol, + equal_nan=True, + ) + + @pytest.mark.parametrize( "reference_before,reference_after", [("rest", "rest"), ("1", "One")], From 4094e6be4a54dc72b52ca1adb1e12267ac3e5c1b Mon Sep 17 00:00:00 2001 From: Intron7 Date: Fri, 24 Apr 2026 17:07:29 +0200 Subject: [PATCH 2/7] add rmm --- CMakeLists.txt | 70 +++++++++++++++++++ notebooks | 2 +- pyproject.toml | 18 ++++- .../_cuda/wilcoxon/wilcoxon_fast_common.cuh | 52 +++++++------- .../wilcoxon/wilcoxon_ovo_device_sparse.cuh | 4 +- .../wilcoxon/wilcoxon_ovo_host_sparse.cuh | 4 +- .../_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh | 46 ++++++------ .../_cuda/wilcoxon/wilcoxon_rmm.cu | 20 ++++++ .../_cuda/wilcoxon/wilcoxon_sparse.cu | 24 +++++-- .../wilcoxon/wilcoxon_sparse_kernels.cuh | 5 +- .../tools/_rank_genes_groups/_wilcoxon.py | 2 +- 11 files changed, 181 insertions(+), 66 deletions(-) create mode 100644 src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_rmm.cu diff --git a/CMakeLists.txt b/CMakeLists.txt index 85d33e91..67d8090c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -14,6 +14,74 @@ if (RSC_BUILD_EXTENSIONS) find_package(Python REQUIRED COMPONENTS Interpreter Development.Module ${SKBUILD_SABI_COMPONENT}) find_package(nanobind CONFIG REQUIRED) find_package(CUDAToolkit REQUIRED) + set(RSC_RMM_HINTS) + set(RSC_RAPIDS_CMAKE_PREFIXES) + set(RSC_CCCL_HINTS) + set(RSC_RAPIDS_LOGGER_HINTS) + set(RSC_NVTX3_HINTS) + macro(_rsc_collect_rapids_python_prefix _rsc_prefix) + if (NOT "${_rsc_prefix}" STREQUAL "") + file(GLOB _rsc_rmm_dirs "${_rsc_prefix}/lib/python*/site-packages/librmm/lib64/cmake/rmm") + file(GLOB _rsc_rapids_prefixes + "${_rsc_prefix}/lib/python*/site-packages/librmm/lib64" + "${_rsc_prefix}/lib/python*/site-packages/librmm/lib64/rapids" + "${_rsc_prefix}/lib/python*/site-packages/rapids_logger/lib64" + "${_rsc_prefix}/lib/python*/site-packages/nvidia/cu*/lib" + ) + file(GLOB _rsc_cccl_dirs + "${_rsc_prefix}/lib/python*/site-packages/librmm/lib64/rapids/cmake/cccl" + "${_rsc_prefix}/lib/python*/site-packages/nvidia/cu*/lib/cmake/cccl" + ) + file(GLOB _rsc_rapids_logger_dirs "${_rsc_prefix}/lib/python*/site-packages/rapids_logger/lib64/cmake/rapids_logger") + file(GLOB _rsc_nvtx3_dirs "${_rsc_prefix}/lib/python*/site-packages/librmm/lib64/cmake/nvtx3") + list(APPEND RSC_RMM_HINTS ${_rsc_rmm_dirs}) + list(APPEND RSC_RAPIDS_CMAKE_PREFIXES ${_rsc_rapids_prefixes}) + list(APPEND RSC_CCCL_HINTS ${_rsc_cccl_dirs}) + list(APPEND RSC_RAPIDS_LOGGER_HINTS ${_rsc_rapids_logger_dirs}) + list(APPEND RSC_NVTX3_HINTS ${_rsc_nvtx3_dirs}) + endif() + endmacro() + execute_process( + COMMAND "${Python_EXECUTABLE}" -c "import importlib.util, pathlib; spec = importlib.util.find_spec('librmm'); print(pathlib.Path(spec.origin).parent / 'lib64' / 'cmake' / 'rmm' if spec else '')" + OUTPUT_VARIABLE RSC_PYTHON_RMM_DIR + OUTPUT_STRIP_TRAILING_WHITESPACE + ERROR_QUIET + ) + if (RSC_PYTHON_RMM_DIR AND EXISTS "${RSC_PYTHON_RMM_DIR}/rmm-config.cmake") + list(APPEND RSC_RMM_HINTS "${RSC_PYTHON_RMM_DIR}") + endif() + foreach(_rsc_python_prefix IN ITEMS "${Python_ROOT_DIR}" "${Python3_ROOT_DIR}") + _rsc_collect_rapids_python_prefix("${_rsc_python_prefix}") + endforeach() + foreach(_rsc_env_prefix IN ITEMS "$ENV{CONDA_PREFIX}" "$ENV{VIRTUAL_ENV}") + _rsc_collect_rapids_python_prefix("${_rsc_env_prefix}") + endforeach() + string(REPLACE ":" ";" _rsc_path_entries "$ENV{PATH}") + foreach(_rsc_path_entry IN LISTS _rsc_path_entries) + get_filename_component(_rsc_path_prefix "${_rsc_path_entry}/.." ABSOLUTE) + _rsc_collect_rapids_python_prefix("${_rsc_path_prefix}") + endforeach() + if (RSC_RAPIDS_CMAKE_PREFIXES) + list(APPEND CMAKE_PREFIX_PATH ${RSC_RAPIDS_CMAKE_PREFIXES}) + if (RSC_CCCL_HINTS) + list(GET RSC_CCCL_HINTS 0 _rsc_cccl_dir) + set(CCCL_DIR "${_rsc_cccl_dir}" CACHE PATH "Path to CCCL package config" FORCE) + endif() + if (RSC_RAPIDS_LOGGER_HINTS) + list(GET RSC_RAPIDS_LOGGER_HINTS 0 _rsc_rapids_logger_dir) + set(rapids_logger_DIR "${_rsc_rapids_logger_dir}" CACHE PATH "Path to rapids_logger package config" FORCE) + endif() + if (RSC_NVTX3_HINTS) + list(GET RSC_NVTX3_HINTS 0 _rsc_nvtx3_dir) + set(nvtx3_DIR "${_rsc_nvtx3_dir}" CACHE PATH "Path to nvtx3 package config" FORCE) + endif() + endif() + if (RSC_RMM_HINTS) + find_package(rmm CONFIG REQUIRED HINTS ${RSC_RMM_HINTS}) + else() + find_package(rmm CONFIG REQUIRED) + endif() + message(STATUS "Using RMM for CUDA extension scratch allocations") message(STATUS "Building for CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}") else() message(STATUS "RSC_BUILD_EXTENSIONS=OFF -> skipping compiled extensions for docs") @@ -86,6 +154,8 @@ if (RSC_BUILD_EXTENSIONS) add_nb_cuda_module(_kde_cuda src/rapids_singlecell/_cuda/kde/kde.cu) add_nb_cuda_module(_wilcoxon_cuda src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu) add_nb_cuda_module(_wilcoxon_sparse_cuda src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu) + target_sources(_wilcoxon_sparse_cuda PRIVATE src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_rmm.cu) + target_link_libraries(_wilcoxon_sparse_cuda PRIVATE rmm::rmm) # Harmony CUDA modules add_nb_cuda_module(_harmony_scatter_cuda src/rapids_singlecell/_cuda/harmony/scatter/scatter.cu) add_nb_cuda_module(_harmony_outer_cuda src/rapids_singlecell/_cuda/harmony/outer/outer.cu) diff --git a/notebooks b/notebooks index 4cdaa44f..e5c97b34 160000 --- a/notebooks +++ b/notebooks @@ -1 +1 @@ -Subproject commit 4cdaa44fbd93b6f812fc8d2c72b89180ef92047d +Subproject commit e5c97b34f4acbf919fb3118c987cc5893e5b5fdf diff --git a/pyproject.toml b/pyproject.toml index c38e1d00..dc69471a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,8 +32,22 @@ dependencies = [ ] [project.optional-dependencies] -rapids-cu13 = [ "cupy-cuda13x", "cudf-cu13>=25.10", "cuml-cu13>=25.10", "cugraph-cu13>=25.10", "cuvs-cu13>=25.10" ] -rapids-cu12 = [ "cupy-cuda12x", "cudf-cu12>=25.10", "cuml-cu12>=25.10", "cugraph-cu12>=25.10", "cuvs-cu12>=25.10" ] +rapids-cu13 = [ + "cupy-cuda13x", + "cudf-cu13>=25.10", + "cuml-cu13>=25.10", + "cugraph-cu13>=25.10", + "cuvs-cu13>=25.10", + "rmm-cu13>=25.10", +] +rapids-cu12 = [ + "cupy-cuda12x", + "cudf-cu12>=25.10", + "cuml-cu12>=25.10", + "cugraph-cu12>=25.10", + "cuvs-cu12>=25.10", + "rmm-cu12>=25.10", +] doc = [ "sphinx>=4.5.0", diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh index dd50d2cb..2d5b3f2c 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh @@ -9,6 +9,9 @@ #include "../nb_types.h" // for CUDA_CHECK_LAST_ERROR +void* wilcoxon_rmm_allocate(size_t bytes); +void wilcoxon_rmm_deallocate(void* ptr, size_t bytes); + constexpr int WARP_SIZE = 32; constexpr int MAX_THREADS_PER_BLOCK = 512; constexpr int N_STREAMS = 4; @@ -93,50 +96,45 @@ struct HostRegisterGuard { }; // --------------------------------------------------------------------------- -// Small allocation pool for temporary CUDA buffers. The previous PR used RMM -// here, but these sparse Wilcoxon kernels only need scoped scratch memory; -// using cudaMalloc keeps this module independent of an extra build-time -// dependency. +// Small allocation pool for temporary CUDA buffers. Uses the current RMM device +// resource so scratch participates in the same pool as CuPy/RAPIDS allocations. // --------------------------------------------------------------------------- -struct RmmPool { - std::vector bufs; - - ~RmmPool() { - for (void* ptr : bufs) { - if (ptr) cudaFree(ptr); +struct RmmScratchPool { + struct Allocation { + void* ptr = nullptr; + size_t bytes = 0; + }; + std::vector bufs; + + ~RmmScratchPool() { + for (Allocation alloc : bufs) { + if (!alloc.ptr) continue; + wilcoxon_rmm_deallocate(alloc.ptr, alloc.bytes); } } template T* alloc(size_t count) { if (count == 0) count = 1; - void* ptr = nullptr; - cudaError_t err = cudaMalloc(&ptr, count * sizeof(T)); - if (err != cudaSuccess) { - throw std::runtime_error( - std::string("cudaMalloc failed in Wilcoxon scratch pool: ") + - cudaGetErrorString(err)); - } - bufs.push_back(ptr); + size_t bytes = count * sizeof(T); + void* ptr = wilcoxon_rmm_allocate(bytes); + bufs.push_back({ptr, bytes}); return static_cast(ptr); } }; struct ScopedCudaBuffer { void* ptr = nullptr; + size_t bytes = 0; - explicit ScopedCudaBuffer(size_t bytes) { - if (bytes == 0) bytes = 1; - cudaError_t err = cudaMalloc(&ptr, bytes); - if (err != cudaSuccess) { - throw std::runtime_error( - std::string("cudaMalloc failed in Wilcoxon scoped buffer: ") + - cudaGetErrorString(err)); - } + explicit ScopedCudaBuffer(size_t requested_bytes) { + bytes = requested_bytes == 0 ? 1 : requested_bytes; + ptr = wilcoxon_rmm_allocate(bytes); } ~ScopedCudaBuffer() { - if (ptr) cudaFree(ptr); + if (!ptr) return; + wilcoxon_rmm_deallocate(ptr, bytes); } void* data() { diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh index 7ad20b01..b195bee0 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh @@ -59,7 +59,7 @@ static void ovo_streaming_csr_impl( } if (ref_cache_cols < 1) ref_cache_cols = 1; - RmmPool pool; + RmmScratchPool pool; size_t cub_temp_bytes = 0; if (needs_tier3) { @@ -340,7 +340,7 @@ static void ovo_streaming_csc_impl( std::vector streams(n_streams); for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); - RmmPool pool; + RmmScratchPool pool; int* d_sort_group_ids = nullptr; if (needs_tier3) { d_sort_group_ids = pool.alloc(h_sort_group_ids.size()); diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh index feb86e57..11827b0a 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh @@ -73,7 +73,7 @@ static void ovo_streaming_csc_host_impl( std::vector streams(n_streams); for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); - RmmPool pool; + RmmScratchPool pool; int n_batches = (n_cols + sub_batch_cols - 1) / sub_batch_cols; std::vector h_all_offsets((size_t)n_batches * (sub_batch_cols + 1), 0); @@ -470,7 +470,7 @@ static void ovo_streaming_csr_host_impl( size_t max_sub_items = (size_t)max_pack_items; if (max_pack_rows == 0) return; - RmmPool pool; + RmmScratchPool pool; // Zero stats outputs. if (compute_sums) { diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh index 0f74a2c8..6eae2a28 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh @@ -7,9 +7,9 @@ * instead of extracting dense blocks. GPU memory is O(max_batch_nnz) instead * of O(sub_batch * n_rows), and sort work is proportional to nnz, not n_rows. */ -template +template static void ovr_sparse_csc_host_streaming_impl( - const InT* h_data, const int* h_indices, const IndptrT* h_indptr, + const InT* h_data, const IndexT* h_indices, const IndptrT* h_indptr, const int* h_group_codes, const double* h_group_sizes, double* d_rank_sums, double* d_tie_corr, double* d_group_sums, double* d_group_sq_sums, double* d_group_nnz, int n_rows, int n_cols, int n_groups, @@ -33,7 +33,7 @@ static void ovr_sparse_csc_host_streaming_impl( size_t cub_temp_bytes = 0; if (max_nnz > 0) { auto* fk = reinterpret_cast(1); - auto* iv = reinterpret_cast(1); + auto* iv = reinterpret_cast(1); cub::DeviceSegmentedRadixSort::SortPairs( nullptr, cub_temp_bytes, fk, fk, iv, iv, (int)max_nnz, sub_batch_cols, iv, iv + 1, BEGIN_BIT, END_BIT); @@ -42,16 +42,16 @@ static void ovr_sparse_csc_host_streaming_impl( std::vector streams(n_streams); for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); - RmmPool pool; + RmmScratchPool pool; int* d_group_codes = pool.alloc(n_rows); double* d_group_sizes = pool.alloc(n_groups); struct StreamBuf { InT* d_sparse_data_orig; float* d_sparse_data_f32; - int* d_sparse_indices; + IndexT* d_sparse_indices; int* d_seg_offsets; float* keys_out; - int* vals_out; + IndexT* vals_out; uint8_t* cub_temp; double* d_rank_sums; double* d_tie_corr; @@ -64,10 +64,10 @@ static void ovr_sparse_csc_host_streaming_impl( for (int s = 0; s < n_streams; s++) { bufs[s].d_sparse_data_orig = pool.alloc(max_nnz); bufs[s].d_sparse_data_f32 = pool.alloc(max_nnz); - bufs[s].d_sparse_indices = pool.alloc(max_nnz); + bufs[s].d_sparse_indices = pool.alloc(max_nnz); bufs[s].d_seg_offsets = pool.alloc(sub_batch_cols + 1); bufs[s].keys_out = pool.alloc(max_nnz); - bufs[s].vals_out = pool.alloc(max_nnz); + bufs[s].vals_out = pool.alloc(max_nnz); bufs[s].cub_temp = pool.alloc(cub_temp_bytes); bufs[s].d_rank_sums = pool.alloc((size_t)n_groups * sub_batch_cols); @@ -128,8 +128,8 @@ static void ovr_sparse_csc_host_streaming_impl( size_t total_nnz = (size_t)h_indptr[n_cols]; HostRegisterGuard _pin_data(const_cast(h_data), total_nnz * sizeof(InT)); - HostRegisterGuard _pin_indices(const_cast(h_indices), - total_nnz * sizeof(int)); + HostRegisterGuard _pin_indices(const_cast(h_indices), + total_nnz * sizeof(IndexT)); cudaDeviceSynchronize(); @@ -151,7 +151,7 @@ static void ovr_sparse_csc_host_streaming_impl( (size_t)batch_nnz * sizeof(InT), cudaMemcpyHostToDevice, stream); cudaMemcpyAsync(buf.d_sparse_indices, h_indices + ptr_start, - (size_t)batch_nnz * sizeof(int), + (size_t)batch_nnz * sizeof(IndexT), cudaMemcpyHostToDevice, stream); } @@ -161,7 +161,7 @@ static void ovr_sparse_csc_host_streaming_impl( cudaMemcpyDeviceToDevice, stream); // Cast to float32 for sort + accumulate stats in float64 - launch_ovr_cast_and_accumulate_sparse( + launch_ovr_cast_and_accumulate_sparse( buf.d_sparse_data_orig, buf.d_sparse_data_f32, buf.d_sparse_indices, buf.d_seg_offsets, d_group_codes, buf.d_group_sums, buf.d_group_sq_sums, buf.d_group_nnz, sb_cols, n_groups, @@ -187,10 +187,12 @@ static void ovr_sparse_csc_host_streaming_impl( (size_t)n_groups * sb_cols * sizeof(double), stream); } - rank_sums_sparse_ovr_kernel<<>>( - buf.keys_out, buf.vals_out, buf.d_seg_offsets, d_group_codes, - d_group_sizes, buf.d_rank_sums, buf.d_tie_corr, buf.d_nz_scratch, - n_rows, sb_cols, n_groups, compute_tie_corr, rank_use_gmem); + rank_sums_sparse_ovr_kernel + <<>>( + buf.keys_out, buf.vals_out, buf.d_seg_offsets, d_group_codes, + d_group_sizes, buf.d_rank_sums, buf.d_tie_corr, + buf.d_nz_scratch, n_rows, sb_cols, n_groups, compute_tie_corr, + rank_use_gmem); CUDA_CHECK_LAST_ERROR(rank_sums_sparse_ovr_kernel); // D2D: scatter sub-batch results into caller's GPU buffers @@ -257,7 +259,7 @@ static void ovr_sparse_csr_host_streaming_impl( int sub_batch_cols) { if (n_rows == 0 || n_cols == 0) return; - RmmPool pool; + RmmScratchPool pool; size_t total_nnz = (size_t)h_indptr[n_rows]; // ---- Phase 0: CPU planning in native CSR order ---- @@ -466,7 +468,7 @@ static void ovr_sparse_csr_host_streaming_impl( (size_t)n_groups * sb_cols * sizeof(double), stream); } - rank_sums_sparse_ovr_kernel<<>>( + rank_sums_sparse_ovr_kernel<<>>( buf.keys_out, buf.vals_out, buf.col_offsets, d_group_codes, d_group_sizes, buf.sub_rank_sums, buf.sub_tie_corr, buf.d_nz_scratch, n_rows, sb_cols, n_groups, compute_tie_corr, @@ -558,7 +560,7 @@ static void ovr_sparse_csc_streaming_impl( bool rank_use_gmem = false; size_t smem_bytes = sparse_ovr_smem_config(n_groups, rank_use_gmem); - RmmPool pool; + RmmScratchPool pool; struct StreamBuf { float* keys_out; int* vals_out; @@ -626,7 +628,7 @@ static void ovr_sparse_csc_streaming_impl( (size_t)n_groups * sb_cols * sizeof(double), stream); } - rank_sums_sparse_ovr_kernel<<>>( + rank_sums_sparse_ovr_kernel<<>>( buf.keys_out, buf.vals_out, buf.seg_offsets, group_codes, group_sizes, buf.sub_rank_sums, buf.sub_tie_corr, buf.d_nz_scratch, n_rows, sb_cols, n_groups, compute_tie_corr, rank_use_gmem); @@ -681,7 +683,7 @@ static void ovr_sparse_csr_streaming_impl( if (n_rows == 0 || n_cols == 0) return; // ---- Phase 0: Planning — count nnz per column via histogram ---- - RmmPool pool; + RmmScratchPool pool; int* d_col_counts = pool.alloc(n_cols); cudaMemset(d_col_counts, 0, n_cols * sizeof(int)); { @@ -829,7 +831,7 @@ static void ovr_sparse_csr_streaming_impl( (size_t)n_groups * sb_cols * sizeof(double), stream); } - rank_sums_sparse_ovr_kernel<<>>( + rank_sums_sparse_ovr_kernel<<>>( buf.keys_out, buf.vals_out, buf.col_offsets, group_codes, group_sizes, buf.sub_rank_sums, buf.sub_tie_corr, buf.d_nz_scratch, n_rows, sb_cols, n_groups, compute_tie_corr, rank_use_gmem); diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_rmm.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_rmm.cu new file mode 100644 index 00000000..26e37f42 --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_rmm.cu @@ -0,0 +1,20 @@ +#include +#include +#include + +#include +#include + +void* wilcoxon_rmm_allocate(size_t bytes) { + try { + return rmm::mr::get_current_device_resource()->allocate_sync(bytes); + } catch (std::exception const& e) { + throw std::runtime_error( + std::string("RMM allocation failed in Wilcoxon scratch: ") + + e.what()); + } +} + +void wilcoxon_rmm_deallocate(void* ptr, size_t bytes) { + rmm::mr::get_current_device_resource()->deallocate_sync(ptr, bytes); +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu index 19f1ef57..4316d284 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu @@ -61,10 +61,10 @@ void register_sparse_bindings(nb::module_& m) { "n_cols"_a, "n_groups"_a, "compute_tie_corr"_a, "sub_batch_cols"_a = SUB_BATCH_COLS); -#define RSC_OVR_SPARSE_CSC_HOST_BINDING(NAME, InT, IndptrT) \ +#define RSC_OVR_SPARSE_CSC_HOST_BINDING(NAME, InT, IndexT, IndptrT) \ m.def( \ NAME, \ - [](host_array h_data, host_array h_indices, \ + [](host_array h_data, host_array h_indices, \ host_array h_indptr, \ host_array h_group_codes, \ host_array h_group_sizes, \ @@ -75,7 +75,7 @@ void register_sparse_bindings(nb::module_& m) { gpu_array_c d_group_nnz, int n_rows, int n_cols, \ int n_groups, bool compute_tie_corr, bool compute_sq_sums, \ bool compute_nnz, int sub_batch_cols) { \ - ovr_sparse_csc_host_streaming_impl( \ + ovr_sparse_csc_host_streaming_impl( \ h_data.data(), h_indices.data(), h_indptr.data(), \ h_group_codes.data(), h_group_sizes.data(), \ d_rank_sums.data(), d_tie_corr.data(), d_group_sums.data(), \ @@ -90,11 +90,21 @@ void register_sparse_bindings(nb::module_& m) { "compute_sq_sums"_a = true, "compute_nnz"_a = true, \ "sub_batch_cols"_a = SUB_BATCH_COLS) - RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host", float, int); - RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_i64", float, int64_t); - RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_f64", double, int); - RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_f64_i64", double, + RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host", float, int, int); + RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_i64", float, int, int64_t); + RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_idx64", float, int64_t, + int); + RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_idx64_i64", float, + int64_t, int64_t); + RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_f64", double, int, + int); + RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_f64_i64", double, int, + int64_t); + RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_f64_idx64", double, + int64_t, int); + RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_f64_idx64_i64", double, + int64_t, int64_t); #undef RSC_OVR_SPARSE_CSC_HOST_BINDING #define RSC_OVR_SPARSE_CSR_HOST_BINDING(NAME, InT, IndexT, IndptrT) \ diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh index b0e40fdc..603c1c96 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh @@ -172,9 +172,10 @@ __global__ void rank_sums_from_sorted_kernel( * * Grid: (sb_cols,) Block: (tpb,) */ +template __global__ void rank_sums_sparse_ovr_kernel( const float* __restrict__ sorted_vals, - const int* __restrict__ sorted_row_idx, + const IndexT* __restrict__ sorted_row_idx, const int* __restrict__ col_seg_offsets, const int* __restrict__ group_codes, const double* __restrict__ group_sizes, double* __restrict__ rank_sums, double* __restrict__ tie_corr, @@ -188,7 +189,7 @@ __global__ void rank_sums_sparse_ovr_kernel( int nnz_stored = seg_end - seg_start; const float* sv = sorted_vals + seg_start; - const int* si = sorted_row_idx + seg_start; + const IndexT* si = sorted_row_idx + seg_start; extern __shared__ double smem[]; double* grp_sums; diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py index e20af614..90c54eb2 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py @@ -489,7 +489,7 @@ def _wilcoxon_vs_rest( csc = csc.copy() csc.sort_indices() csc_host_fn, data_arr, indices_arr = _host_sparse_fn_and_arrays( - _wcs, "ovr_sparse_csc_host", csc, support_idx64=False + _wcs, "ovr_sparse_csc_host", csc, support_idx64=True ) csc_host_fn( data_arr, From 9c391ed369c347ad3ae4ea0fb4c8a2c169b33829 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Fri, 24 Apr 2026 17:42:23 +0200 Subject: [PATCH 3/7] update publish and cmake --- .github/workflows/publish.yml | 52 ++++- CMakeLists.txt | 20 ++ pyproject.toml | 9 +- .../_cuda/wilcoxon/kernels_wilcoxon.cuh | 141 ------------- .../_cuda/wilcoxon/wilcoxon.cu | 48 ----- .../_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh | 9 - .../wilcoxon/wilcoxon_sparse_kernels.cuh | 145 -------------- .../tools/_rank_genes_groups/_wilcoxon.py | 65 ------ tests/test_rank_genes_groups_wilcoxon.py | 187 +----------------- 9 files changed, 74 insertions(+), 602 deletions(-) diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 3f2e4447..4ca5d522 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -69,16 +69,47 @@ jobs: path = pathlib.Path("pyproject.toml") text = path.read_text() + def remove_toml_array(text, key): + lines = text.splitlines(keepends=True) + out = [] + i = 0 + while i < len(lines): + if lines[i].startswith(f"{key} = ["): + depth = lines[i].count("[") - lines[i].count("]") + i += 1 + while i < len(lines) and depth > 0: + depth += lines[i].count("[") - lines[i].count("]") + i += 1 + continue + out.append(lines[i]) + i += 1 + return "".join(out) + # Rename package text = text.replace( 'name = "rapids-singlecell"', f'name = "rapids-singlecell-cu{cuda}"', ) # Rename matching extra to "rapids", remove the other - text = text.replace(f'rapids-cu{cuda} =', 'rapids =') - # Remove the other CUDA extra line entirely - lines = text.splitlines(keepends=True) - text = "".join(l for l in lines if f'rapids-cu{other}' not in l) + text = text.replace(f'rapids-cu{cuda} = [', 'rapids = [') + text = remove_toml_array(text, f"rapids-cu{other}") + + # librmm is needed at build time because CMake links the CUDA + # extension against librmm. Add the matching wheel to the isolated + # PEP 517 build requirements after selecting the CUDA package variant. + for dep in ( + f' "librmm-cu{other}>=25.10",\n', + f' "rmm-cu{other}>=25.10",\n', + ): + text = text.replace(dep, "") + rmm_build_req = f' "librmm-cu{cuda}>=25.10",\n' + build_system_text = text.split("[project]", 1)[0] + if f'"librmm-cu{cuda}>=25.10"' not in build_system_text: + text = text.replace( + ']\nbuild-backend = "scikit_build_core.build"', + f'{rmm_build_req}]\nbuild-backend = "scikit_build_core.build"', + 1, + ) # Set CUDA architectures (replace "native" with CI target archs) text = text.replace( @@ -96,6 +127,7 @@ jobs: - name: Sanity check pyproject.toml run: | + python3 -c "import tomllib; tomllib.load(open('pyproject.toml', 'rb'))" grep -E "name|rapids|CUDA_ARCH" pyproject.toml - name: Build CUDA manylinux image @@ -117,9 +149,19 @@ jobs: CIBW_BEFORE_BUILD: > python -m pip install -U pip scikit-build-core cmake ninja nanobind + librmm-cu${{ matrix.cuda_major }} && + RMM_ROOT=$(python -c "import librmm; print(librmm.__path__[0])") && + LOG_ROOT=$(python -c "import rapids_logger; print(rapids_logger.__path__[0])") && + echo "[rsc-build] librmm=$RMM_ROOT" && + echo "[rsc-build] rapids_logger=$LOG_ROOT" && + ln -sf "$RMM_ROOT/lib64/librmm.so" /usr/local/lib/librmm.so && + ln -sf "$LOG_ROOT/lib64/librapids_logger.so" /usr/local/lib/librapids_logger.so && + ldconfig && + python -c "import librmm; print(librmm.__path__[0])" > /tmp/.librmm_dir && + echo "[rsc-build] marker=$(cat /tmp/.librmm_dir)" CIBW_TEST_SKIP: "*" CIBW_TEST_COMMAND: "" - CIBW_REPAIR_WHEEL_COMMAND: "auditwheel repair --exclude libcublas.so.${{ matrix.cuda_major }} --exclude libcublasLt.so.${{ matrix.cuda_major }} --exclude libcudart.so.${{ matrix.cuda_major }} -w {dest_dir} {wheel}" + CIBW_REPAIR_WHEEL_COMMAND: "auditwheel repair --exclude libcublas.so.${{ matrix.cuda_major }} --exclude libcublasLt.so.${{ matrix.cuda_major }} --exclude libcudart.so.${{ matrix.cuda_major }} --exclude librmm.so --exclude librapids_logger.so -w {dest_dir} {wheel}" CIBW_BUILD_VERBOSITY: "1" - uses: actions/upload-artifact@v4 diff --git a/CMakeLists.txt b/CMakeLists.txt index 67d8090c..85fcfc2d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -50,6 +50,26 @@ if (RSC_BUILD_EXTENSIONS) if (RSC_PYTHON_RMM_DIR AND EXISTS "${RSC_PYTHON_RMM_DIR}/rmm-config.cmake") list(APPEND RSC_RMM_HINTS "${RSC_PYTHON_RMM_DIR}") endif() + if(EXISTS "/tmp/.librmm_dir") + file(READ "/tmp/.librmm_dir" _rsc_librmm_marker) + string(STRIP "${_rsc_librmm_marker}" _rsc_librmm_marker) + file(GLOB _rsc_marker_rmm_dirs "${_rsc_librmm_marker}/lib64/cmake/rmm") + file(GLOB _rsc_marker_rapids_prefixes + "${_rsc_librmm_marker}/lib64" + "${_rsc_librmm_marker}/lib64/rapids" + "${_rsc_librmm_marker}/../rapids_logger/lib64" + ) + file(GLOB _rsc_marker_cccl_dirs + "${_rsc_librmm_marker}/lib64/rapids/cmake/cccl" + ) + file(GLOB _rsc_marker_rapids_logger_dirs "${_rsc_librmm_marker}/../rapids_logger/lib64/cmake/rapids_logger") + file(GLOB _rsc_marker_nvtx3_dirs "${_rsc_librmm_marker}/lib64/cmake/nvtx3") + list(APPEND RSC_RMM_HINTS ${_rsc_marker_rmm_dirs}) + list(APPEND RSC_RAPIDS_CMAKE_PREFIXES ${_rsc_marker_rapids_prefixes}) + list(APPEND RSC_CCCL_HINTS ${_rsc_marker_cccl_dirs}) + list(APPEND RSC_RAPIDS_LOGGER_HINTS ${_rsc_marker_rapids_logger_dirs}) + list(APPEND RSC_NVTX3_HINTS ${_rsc_marker_nvtx3_dirs}) + endif() foreach(_rsc_python_prefix IN ITEMS "${Python_ROOT_DIR}" "${Python3_ROOT_DIR}") _rsc_collect_rapids_python_prefix("${_rsc_python_prefix}") endforeach() diff --git a/pyproject.toml b/pyproject.toml index dc69471a..a3b07ede 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,6 +3,9 @@ requires = [ "scikit-build-core>=0.10", "nanobind>=2.0.0", "setuptools-scm>=8", + # librmm headers/CMake config are needed at build time for Wilcoxon. + # CUDA wheel builds rewrite this to the matching cu12/cu13 package. + "librmm-cu12>=25.10", ] build-backend = "scikit_build_core.build" @@ -38,7 +41,7 @@ rapids-cu13 = [ "cuml-cu13>=25.10", "cugraph-cu13>=25.10", "cuvs-cu13>=25.10", - "rmm-cu13>=25.10", + "librmm-cu13>=25.10", ] rapids-cu12 = [ "cupy-cuda12x", @@ -46,7 +49,7 @@ rapids-cu12 = [ "cuml-cu12>=25.10", "cugraph-cu12>=25.10", "cuvs-cu12>=25.10", - "rmm-cu12>=25.10", + "librmm-cu12>=25.10", ] doc = [ @@ -164,7 +167,7 @@ sdist.include = [ "src/rapids_singlecell/_version.py" ] # Use abi3audit to catch issues with Limited API wheels [tool.cibuildwheel.linux] repair-wheel-command = [ - "auditwheel repair --exclude libcublas.so.12 --exclude libcublas.so.13 --exclude libcublasLt.so.12 --exclude libcublasLt.so.13 --exclude libcudart.so.12 --exclude libcudart.so.13 -w {dest_dir} {wheel}", + "auditwheel repair --exclude libcublas.so.12 --exclude libcublas.so.13 --exclude libcublasLt.so.12 --exclude libcublasLt.so.13 --exclude libcudart.so.12 --exclude libcudart.so.13 --exclude librmm.so --exclude librapids_logger.so -w {dest_dir} {wheel}", "pipx run abi3audit --strict --report {wheel}", ] [tool.cibuildwheel.macos] diff --git a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh index 8b6af5f6..3c42f60a 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh @@ -23,147 +23,6 @@ __device__ __forceinline__ double wilcoxon_block_sum(double val, return 0.0; } -/** - * Kernel to compute tie correction factor for Wilcoxon test. - * Formula: tc = 1 - sum(t^3 - t) / (n^3 - n) where t is the count of tied - * values. - * - * Each block handles one column. Uses binary search to find tie groups. - * Assumes input is sorted column-wise (F-order). - */ -__global__ void tie_correction_kernel(const double* __restrict__ sorted_vals, - double* __restrict__ correction, - const int n_rows, const int n_cols) { - // Each block handles one column - int col = blockIdx.x; - if (col >= n_cols) return; - - const double* sv = sorted_vals + (size_t)col * n_rows; - - double local_sum = 0.0; - int tid = threadIdx.x; - - // Each thread processes positions where it detects END of a tie group - // Start from index 1, check if sv[i-1] != sv[i] (boundary detected) - // When at boundary, use binary search to find tie group size - for (int i = tid + 1; i <= n_rows; i += blockDim.x) { - // Detect boundary: either at the end, or value changed - bool at_boundary = (i == n_rows) || (sv[i] != sv[i - 1]); - - if (at_boundary) { - // Found end of tie group at position i-1 - // Binary search for start of this tie group - double val = sv[i - 1]; - int lo = 0, hi = i - 1; - while (lo < hi) { - int mid = (lo + hi) / 2; - if (sv[mid] < val) { - lo = mid + 1; - } else { - hi = mid; - } - } - int tie_count = i - lo; - - // t^3 - t for this tie group - double t = (double)tie_count; - local_sum += t * t * t - t; - } - } - - // Warp-level reduction using shuffle -#pragma unroll - for (int offset = 16; offset > 0; offset >>= 1) { - local_sum += __shfl_down_sync(0xffffffff, local_sum, offset); - } - - // Cross-warp reduction using small shared memory - __shared__ double warp_sums[32]; - int lane = tid & 31; - int warp_id = tid >> 5; - - if (lane == 0) { - warp_sums[warp_id] = local_sum; - } - __syncthreads(); - - // Final reduction in first warp - // Note: blockDim.x must be a multiple of 32 for correct warp reduction - if (tid < 32) { - double val = (tid < (blockDim.x >> 5)) ? warp_sums[tid] : 0.0; -#pragma unroll - for (int offset = 16; offset > 0; offset >>= 1) { - val += __shfl_down_sync(0xffffffff, val, offset); - } - if (tid == 0) { - double n = (double)n_rows; - double denom = n * n * n - n; - if (denom > 0) { - correction[col] = 1.0 - val / denom; - } else { - correction[col] = 1.0; - } - } - } -} - -/** - * Kernel to compute average ranks for each column. - * Uses scipy.stats.rankdata 'average' method: ties get the average of the ranks - * they would span. - * - * Each block handles one column. Assumes input is sorted column-wise (F-order). - */ -__global__ void average_rank_kernel(const double* __restrict__ sorted_vals, - const int* __restrict__ sorter, - double* __restrict__ ranks, - const int n_rows, const int n_cols) { - // Each thread block handles one column - int col = blockIdx.x; - if (col >= n_cols) return; - - // Pointers to this column's data - const double* sv = sorted_vals + (size_t)col * n_rows; - const int* si = sorter + (size_t)col * n_rows; - double* rk = ranks + (size_t)col * n_rows; - - // Each thread processes multiple rows - for (int i = threadIdx.x; i < n_rows; i += blockDim.x) { - double val = sv[i]; - - // Binary search for tie_start (first element equal to val) - int lo = 0, hi = i; - while (lo < hi) { - int mid = (lo + hi) / 2; - if (sv[mid] < val) { - lo = mid + 1; - } else { - hi = mid; - } - } - int tie_start = lo; - - // Binary search for tie_end (last element equal to val) - lo = i; - hi = n_rows - 1; - while (lo < hi) { - int mid = (lo + hi + 1) / 2; - if (sv[mid] > val) { - hi = mid - 1; - } else { - lo = mid; - } - } - int tie_end = lo; - - // Average rank for ties: (start + end + 2) / 2 (1-based ranks) - double avg_rank = (double)(tie_start + tie_end + 2) / 2.0; - - // Write rank to original position - rk[si[i]] = avg_rank; - } -} - /** * OVO dense rank core. * diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu index 0ab5b26c..38fc25ec 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu @@ -15,29 +15,6 @@ static inline int round_up_to_warp(int n) { return (rounded < MAX_THREADS_PER_BLOCK) ? rounded : MAX_THREADS_PER_BLOCK; } -static inline void launch_tie_correction(const double* sorted_vals, - double* correction, int n_rows, - int n_cols, cudaStream_t stream) { - int threads_per_block = round_up_to_warp(n_rows); - dim3 block(threads_per_block); - dim3 grid(n_cols); - tie_correction_kernel<<>>(sorted_vals, correction, - n_rows, n_cols); - CUDA_CHECK_LAST_ERROR(tie_correction_kernel); -} - -static inline void launch_average_rank(const double* sorted_vals, - const int* sorter, double* ranks, - int n_rows, int n_cols, - cudaStream_t stream) { - int threads_per_block = round_up_to_warp(n_rows); - dim3 block(threads_per_block); - dim3 grid(n_cols); - average_rank_kernel<<>>(sorted_vals, sorter, ranks, - n_rows, n_cols); - CUDA_CHECK_LAST_ERROR(average_rank_kernel); -} - static inline void launch_ovo_rank_dense( const float* ref_sorted, const float* grp_data, const int* grp_offsets, double* rank_sums, double* tie_corr, int n_ref, int n_all_grp, int n_cols, @@ -79,31 +56,6 @@ template void register_bindings(nb::module_& m) { m.doc() = "CUDA kernels for Wilcoxon rank-sum test"; - // Tie correction kernel - m.def( - "tie_correction", - [](gpu_array_f sorted_vals, - gpu_array correction, int n_rows, int n_cols, - std::uintptr_t stream) { - launch_tie_correction(sorted_vals.data(), correction.data(), n_rows, - n_cols, (cudaStream_t)stream); - }, - "sorted_vals"_a, "correction"_a, nb::kw_only(), "n_rows"_a, "n_cols"_a, - "stream"_a = 0); - - // Average rank kernel - m.def( - "average_rank", - [](gpu_array_f sorted_vals, - gpu_array_f sorter, - gpu_array_f ranks, int n_rows, int n_cols, - std::uintptr_t stream) { - launch_average_rank(sorted_vals.data(), sorter.data(), ranks.data(), - n_rows, n_cols, (cudaStream_t)stream); - }, - "sorted_vals"_a, "sorter"_a, "ranks"_a, nb::kw_only(), "n_rows"_a, - "n_cols"_a, "stream"_a = 0); - m.def( "ovo_rank_dense", [](gpu_array_f ref_sorted, diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh index afac20f2..9fd626b6 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh @@ -49,15 +49,6 @@ __global__ void csc_extract_mapped_kernel(const float* __restrict__ data, } } -static size_t get_seg_sort_temp_bytes(int n_items, int n_segments) { - size_t bytes = 0; - auto* dk = reinterpret_cast(1); - auto* doff = reinterpret_cast(1); - cub::DeviceSegmentedRadixSort::SortKeys(nullptr, bytes, dk, dk, n_items, - n_segments, doff, doff + 1, 0, 32); - return bytes; -} - /** * Tier 1 dispatch: when the largest group fits in shared memory, a fused * bitonic-sort + binary-search kernel handles the whole group per block. diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh index 603c1c96..d30f92cc 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh @@ -2,151 +2,6 @@ #include -/** - * Fused rank-sum kernel: walk sorted data, compute per-group rank sums - * and tie correction without materializing a rank matrix. - * - * Each thread processes a CONTIGUOUS chunk of sorted elements, detecting - * tie groups by adjacent comparison (sequential access, no binary search). - * Cross-boundary ties are resolved via binary search at chunk boundaries. - * - * When use_gmem is false, per-group accumulators live in shared memory - * (fast atomics, limited to ~1500 groups on 48 KB devices). When use_gmem - * is true, accumulators write directly to ``rank_sums`` in global memory, - * supporting an arbitrary number of groups. The caller must pre-zero - * ``rank_sums`` before launching in the gmem path. - * - * Shared memory layout: - * use_gmem=false: (n_groups + 32) doubles (accumulators + warp buf) - * use_gmem=true: 32 doubles (warp buf only) - */ -__global__ void rank_sums_from_sorted_kernel( - const float* __restrict__ sorted_vals, - const int* __restrict__ sorted_row_idx, const int* __restrict__ group_codes, - double* __restrict__ rank_sums, double* __restrict__ tie_corr, int n_rows, - int n_cols, int n_groups, bool compute_tie_corr, bool use_gmem) { - int col = blockIdx.x; - if (col >= n_cols) return; - - extern __shared__ double smem[]; - - double* grp_sums; - if (use_gmem) { - // Global memory path: write directly to output (must be pre-zeroed) - grp_sums = rank_sums + (size_t)col; // stride: n_cols - } else { - // Shared memory path: per-block accumulators - grp_sums = smem; - for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { - grp_sums[g] = 0.0; - } - __syncthreads(); - } - - const float* sv = sorted_vals + (size_t)col * n_rows; - const int* si = sorted_row_idx + (size_t)col * n_rows; - - int chunk = (n_rows + blockDim.x - 1) / blockDim.x; - int my_start = threadIdx.x * chunk; - int my_end = my_start + chunk; - if (my_end > n_rows) my_end = n_rows; - - double local_tie_sum = 0.0; - - // Stride for accumulator indexing: 1 for shared mem, n_cols for global mem - int acc_stride = use_gmem ? n_cols : 1; - - int i = my_start; - while (i < my_end) { - double val = sv[i]; - - int tie_local_end = i + 1; - while (tie_local_end < my_end && sv[tie_local_end] == val) - ++tie_local_end; - - int tie_global_start = i; - if (i == my_start && i > 0 && sv[i - 1] == val) { - int lo = 0, hi = i; - while (lo < hi) { - int mid = lo + (hi - lo) / 2; - if (sv[mid] < val) - lo = mid + 1; - else - hi = mid; - } - tie_global_start = lo; - } - - int tie_global_end = tie_local_end; - if (tie_local_end == my_end && tie_local_end < n_rows && - sv[tie_local_end] == val) { - int lo = tie_local_end, hi = n_rows - 1; - while (lo < hi) { - int mid = hi - ((hi - lo) >> 1); - if (sv[mid] > val) - hi = mid - 1; - else - lo = mid; - } - tie_global_end = lo + 1; - } - - int total_tie = tie_global_end - tie_global_start; - double avg_rank = (double)(tie_global_start + tie_global_end + 1) / 2.0; - - for (int j = i; j < tie_local_end; ++j) { - int grp = group_codes[si[j]]; - if (grp < n_groups) { - atomicAdd(&grp_sums[grp * acc_stride], avg_rank); - } - } - - if (compute_tie_corr && tie_global_start >= my_start && total_tie > 1) { - double t = (double)total_tie; - local_tie_sum += t * t * t - t; - } - - i = tie_local_end; - } - - __syncthreads(); - - // Copy shared memory accumulators to global output (smem path only) - if (!use_gmem) { - for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { - rank_sums[(size_t)g * n_cols + col] = grp_sums[g]; - } - } - - if (compute_tie_corr) { - // Warp buf sits after accumulator array in shared memory. - // gmem path: warp buf starts at smem[0]. - // smem path: n_groups doubles, then warp buf. - int warp_buf_off = use_gmem ? 0 : n_groups; - double* warp_buf = smem + warp_buf_off; -#pragma unroll - for (int off = 16; off > 0; off >>= 1) - local_tie_sum += __shfl_down_sync(0xffffffff, local_tie_sum, off); - int lane = threadIdx.x & 31; - int wid = threadIdx.x >> 5; - if (lane == 0) warp_buf[wid] = local_tie_sum; - __syncthreads(); - if (threadIdx.x < 32) { - double val = (threadIdx.x < ((blockDim.x + 31) >> 5)) - ? warp_buf[threadIdx.x] - : 0.0; -#pragma unroll - for (int off = 16; off > 0; off >>= 1) - val += __shfl_down_sync(0xffffffff, val, off); - if (threadIdx.x == 0) { - double n = (double)n_rows; - double denom = n * n * n - n; - tie_corr[col] = (denom > 0.0) ? (1.0 - val / denom) : 1.0; - } - } - } -} - /** * Sparse-aware OVR rank-sum kernel for nonnegative sorted stored values. * diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py index 90c54eb2..4fec5948 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py @@ -30,71 +30,6 @@ OVO_DEVICE_SPARSE_SUB_BATCH = 128 -def _average_ranks( - matrix: cp.ndarray, *, return_sorted: bool = False -) -> cp.ndarray | tuple[cp.ndarray, cp.ndarray]: - """ - Compute average ranks for each column using GPU kernel. - - Uses scipy.stats.rankdata 'average' method: ties get the average - of the ranks they would span. - - Parameters - ---------- - matrix - Input matrix (n_rows, n_cols) - return_sorted - If True, also return sorted values (useful for tie correction) - - Returns - ------- - ranks or (ranks, sorted_vals) - """ - n_rows, n_cols = matrix.shape - - # Sort each column - sorter = cp.argsort(matrix, axis=0) - sorted_vals = cp.take_along_axis(matrix, sorter, axis=0) - - # Ensure F-order for kernel (columns contiguous in memory) - sorted_vals = cp.asfortranarray(sorted_vals) - sorter = cp.asfortranarray(sorter.astype(cp.int32)) - - stream = cp.cuda.get_current_stream().ptr - _wc.average_rank( - sorted_vals, sorter, matrix, n_rows=n_rows, n_cols=n_cols, stream=stream - ) - - if return_sorted: - return matrix, sorted_vals - return matrix - - -def _tie_correction(sorted_vals: cp.ndarray) -> cp.ndarray: - """ - Compute tie correction factor for Wilcoxon test. - - Takes pre-sorted values (column-wise) to avoid re-sorting. - Formula: tc = 1 - sum(t^3 - t) / (n^3 - n) - where t is the count of tied values. - """ - n_rows, n_cols = sorted_vals.shape - correction = cp.ones(n_cols, dtype=cp.float64) - - if n_rows < 2: - return correction - - # Ensure F-order - sorted_vals = cp.asfortranarray(sorted_vals) - - stream = cp.cuda.get_current_stream().ptr - _wc.tie_correction( - sorted_vals, correction, n_rows=n_rows, n_cols=n_cols, stream=stream - ) - - return correction - - def _extract_dense_rows_cols( X, row_ids: np.ndarray, start: int, stop: int ) -> cp.ndarray: diff --git a/tests/test_rank_genes_groups_wilcoxon.py b/tests/test_rank_genes_groups_wilcoxon.py index 87030dfb..413cfe3b 100644 --- a/tests/test_rank_genes_groups_wilcoxon.py +++ b/tests/test_rank_genes_groups_wilcoxon.py @@ -7,7 +7,7 @@ import pytest import scanpy as sc import scipy.sparse as sp -from scipy.stats import mannwhitneyu, rankdata, tiecorrect +from scipy.stats import mannwhitneyu import rapids_singlecell as rsc @@ -840,188 +840,3 @@ def test_sparse_matches_dense(self, perturbation_adata, sparse): np.testing.assert_array_equal( dense_df["pvals"].values, sparse_df["pvals"].values ) - - -# ============================================================================ -# Tests for ranking and tie correction kernels (edge cases from scipy) -# ============================================================================ - - -class TestRankingKernel: - """Tests for _average_ranks based on scipy.stats.rankdata edge cases.""" - - @pytest.fixture - def average_ranks(self): - """Import the ranking function.""" - from rapids_singlecell.tools._rank_genes_groups._wilcoxon import ( - _average_ranks, - ) - - return _average_ranks - - @staticmethod - def _to_gpu(values): - """Convert 1D values to GPU column matrix with F-order.""" - arr = np.asarray(values, dtype=np.float64).reshape(-1, 1) - return cp.asarray(arr, order="F") - - def test_basic_ranking(self, average_ranks): - """Test basic average ranking on simple data.""" - values = [3.0, 1.0, 2.0] - result = average_ranks(self._to_gpu(values)) - expected = rankdata(values, method="average") - np.testing.assert_allclose(result.get().flatten(), expected) - - def test_all_ties(self, average_ranks): - """All identical values should get the average rank.""" - values = [5.0, 5.0, 5.0, 5.0] - result = average_ranks(self._to_gpu(values)) - expected = rankdata(values, method="average") - np.testing.assert_allclose(result.get().flatten(), expected) - - def test_no_ties(self, average_ranks): - """All unique values should get sequential ranks.""" - values = [1.0, 2.0, 3.0, 4.0, 5.0] - result = average_ranks(self._to_gpu(values)) - expected = rankdata(values, method="average") - np.testing.assert_allclose(result.get().flatten(), expected) - - def test_mixed_ties(self, average_ranks): - """Mix of ties and unique values.""" - values = [1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0] - result = average_ranks(self._to_gpu(values)) - expected = rankdata(values, method="average") - np.testing.assert_allclose(result.get().flatten(), expected) - - def test_negative_values(self, average_ranks): - """Test with negative values.""" - values = [-3.0, -1.0, -2.0, 0.0, 1.0] - result = average_ranks(self._to_gpu(values)) - expected = rankdata(values, method="average") - np.testing.assert_allclose(result.get().flatten(), expected) - - def test_single_element(self, average_ranks): - """Single element should have rank 1.""" - values = [42.0] - result = average_ranks(self._to_gpu(values)) - np.testing.assert_allclose(result.get().flatten(), [1.0]) - - def test_two_elements_tied(self, average_ranks): - """Two tied elements should both have rank 1.5.""" - values = [7.0, 7.0] - result = average_ranks(self._to_gpu(values)) - np.testing.assert_allclose(result.get().flatten(), [1.5, 1.5]) - - def test_multiple_columns(self, average_ranks): - """Test ranking across multiple columns independently.""" - col0 = [3.0, 1.0, 2.0] - col1 = [1.0, 1.0, 2.0] - data = np.column_stack([col0, col1]).astype(np.float64) - result = average_ranks(cp.asarray(data, order="F")) - - np.testing.assert_allclose(result.get()[:, 0], rankdata(col0, method="average")) - np.testing.assert_allclose(result.get()[:, 1], rankdata(col1, method="average")) - - -class TestTieCorrectionKernel: - """Tests for _tie_correction based on scipy.stats.tiecorrect edge cases.""" - - @pytest.fixture - def tie_correction(self): - """Import the tie correction function and ranking function.""" - from rapids_singlecell.tools._rank_genes_groups._wilcoxon import ( - _average_ranks, - _tie_correction, - ) - - return _tie_correction, _average_ranks - - @staticmethod - def _to_gpu(values): - """Convert 1D values to GPU column matrix with F-order.""" - arr = np.asarray(values, dtype=np.float64).reshape(-1, 1) - return cp.asarray(arr, order="F") - - def test_no_ties(self, tie_correction): - """No ties should give correction factor 1.0.""" - _tie_correction, _average_ranks = tie_correction - - values = [1.0, 2.0, 3.0, 4.0, 5.0] - _, sorted_vals = _average_ranks(self._to_gpu(values), return_sorted=True) - result = _tie_correction(sorted_vals) - - expected = tiecorrect(rankdata(values)) - np.testing.assert_allclose(result.get()[0], expected, rtol=1e-10) - - def test_all_ties(self, tie_correction): - """All tied values should give correction factor 0.0.""" - _tie_correction, _average_ranks = tie_correction - - values = [5.0, 5.0, 5.0, 5.0] - _, sorted_vals = _average_ranks(self._to_gpu(values), return_sorted=True) - result = _tie_correction(sorted_vals) - - expected = tiecorrect(rankdata(values)) - np.testing.assert_allclose(result.get()[0], expected, rtol=1e-10) - - def test_mixed_ties(self, tie_correction): - """Mix of ties should give intermediate correction factor.""" - _tie_correction, _average_ranks = tie_correction - - values = [1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0] - _, sorted_vals = _average_ranks(self._to_gpu(values), return_sorted=True) - result = _tie_correction(sorted_vals) - - expected = tiecorrect(rankdata(values)) - np.testing.assert_allclose(result.get()[0], expected, rtol=1e-10) - - def test_two_elements_tied(self, tie_correction): - """Two tied elements.""" - _tie_correction, _average_ranks = tie_correction - - values = [7.0, 7.0] - _, sorted_vals = _average_ranks(self._to_gpu(values), return_sorted=True) - result = _tie_correction(sorted_vals) - - expected = tiecorrect(rankdata(values)) - np.testing.assert_allclose(result.get()[0], expected, rtol=1e-10) - - def test_single_element(self, tie_correction): - """Single element should give correction factor 1.0.""" - _tie_correction, _average_ranks = tie_correction - - values = [42.0] - _, sorted_vals = _average_ranks(self._to_gpu(values), return_sorted=True) - result = _tie_correction(sorted_vals) - - # Single element: n^3 - n = 0, so formula gives 1.0 - np.testing.assert_allclose(result.get()[0], 1.0, rtol=1e-10) - - def test_multiple_columns(self, tie_correction): - """Test tie correction across multiple columns independently.""" - _tie_correction, _average_ranks = tie_correction - - col0 = [1.0, 2.0, 3.0] # No ties - col1 = [5.0, 5.0, 5.0] # All ties - data = np.column_stack([col0, col1]).astype(np.float64) - _, sorted_vals = _average_ranks(cp.asarray(data, order="F"), return_sorted=True) - result = _tie_correction(sorted_vals) - - np.testing.assert_allclose( - result.get()[0], tiecorrect(rankdata(col0)), rtol=1e-10 - ) - np.testing.assert_allclose( - result.get()[1], tiecorrect(rankdata(col1)), rtol=1e-10 - ) - - def test_large_tie_groups(self, tie_correction): - """Test with large tie groups.""" - _tie_correction, _average_ranks = tie_correction - - # 50 values of 1, 50 values of 2 (non-multiple of 32 to test warp handling) - values = [1.0] * 50 + [2.0] * 50 - _, sorted_vals = _average_ranks(self._to_gpu(values), return_sorted=True) - result = _tie_correction(sorted_vals) - - expected = tiecorrect(rankdata(values)) - np.testing.assert_allclose(result.get()[0], expected, rtol=1e-10) From 76389bbeeef5b030af0ac7763d15c42bc85a2ce5 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Fri, 24 Apr 2026 18:10:14 +0200 Subject: [PATCH 4/7] update notebooks --- notebooks | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/notebooks b/notebooks index e5c97b34..4cdaa44f 160000 --- a/notebooks +++ b/notebooks @@ -1 +1 @@ -Subproject commit e5c97b34f4acbf919fb3118c987cc5893e5b5fdf +Subproject commit 4cdaa44fbd93b6f812fc8d2c72b89180ef92047d From f69f1d85e6f527ce71948c004f9ece281248ab1e Mon Sep 17 00:00:00 2001 From: Intron7 Date: Fri, 24 Apr 2026 20:03:26 +0200 Subject: [PATCH 5/7] make dense faster --- CMakeLists.txt | 2 + .../_cuda/wilcoxon/kernels_wilcoxon.cuh | 351 +++---------- .../_cuda/wilcoxon/wilcoxon.cu | 483 ++++++++++++++++-- .../tools/_rank_genes_groups/_wilcoxon.py | 140 ++--- tests/test_rank_genes_groups_wilcoxon.py | 34 ++ 5 files changed, 629 insertions(+), 381 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 85fcfc2d..e880613d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -173,6 +173,8 @@ if (RSC_BUILD_EXTENSIONS) add_nb_cuda_module(_hvg_cuda src/rapids_singlecell/_cuda/hvg/hvg.cu) add_nb_cuda_module(_kde_cuda src/rapids_singlecell/_cuda/kde/kde.cu) add_nb_cuda_module(_wilcoxon_cuda src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu) + target_sources(_wilcoxon_cuda PRIVATE src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_rmm.cu) + target_link_libraries(_wilcoxon_cuda PRIVATE rmm::rmm) add_nb_cuda_module(_wilcoxon_sparse_cuda src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu) target_sources(_wilcoxon_sparse_cuda PRIVATE src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_rmm.cu) target_link_libraries(_wilcoxon_sparse_cuda PRIVATE rmm::rmm) diff --git a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh index 3c42f60a..5af4e964 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh @@ -24,313 +24,118 @@ __device__ __forceinline__ double wilcoxon_block_sum(double val, } /** - * OVO dense rank core. + * OVR dense rank-sum kernel for data sorted by column. * - * ref_sorted is F-order and sorted independently for every column. - * grp_data is F-order and contains test-group rows concatenated by - * grp_offsets. One block computes one (column, test-group) result. - * - * This intentionally centralizes the OVO math; host/device and CSR/CSC/dense - * paths only need to materialize bounded dense column batches that feed this - * kernel. + * sorted_vals and sorted_row_idx are F-order arrays from a segmented + * SortPairs. One block owns one column, walks tie runs, and accumulates the + * average ranks per group without materializing a full rank matrix. */ -__global__ void ovo_rank_dense_kernel(const float* __restrict__ ref_sorted, - const float* __restrict__ grp_data, - const int* __restrict__ grp_offsets, - double* __restrict__ rank_sums, - double* __restrict__ tie_corr, int n_ref, - int n_all_grp, int n_cols, int n_groups, - bool compute_tie_corr) { +__global__ void rank_sums_from_sorted_kernel( + const float* __restrict__ sorted_vals, + const int* __restrict__ sorted_row_idx, const int* __restrict__ group_codes, + double* __restrict__ rank_sums, double* __restrict__ tie_corr, int n_rows, + int n_cols, int n_groups, bool compute_tie_corr, bool use_gmem) { int col = blockIdx.x; - int grp = blockIdx.y; - if (col >= n_cols || grp >= n_groups) return; - - int g_start = grp_offsets[grp]; - int g_end = grp_offsets[grp + 1]; - int n_grp = g_end - g_start; - - const float* ref_col = ref_sorted + (long long)col * n_ref; - const float* grp_col = grp_data + (long long)col * n_all_grp + g_start; - - __shared__ double warp_buf[32]; - double local_rank = 0.0; - - for (int i = threadIdx.x; i < n_grp; i += blockDim.x) { - float v = grp_col[i]; + if (col >= n_cols) return; - int lo = 0, hi = n_ref; - while (lo < hi) { - int m = lo + ((hi - lo) >> 1); - if (ref_col[m] < v) - lo = m + 1; - else - hi = m; - } - int n_lt_ref = lo; + extern __shared__ double smem[]; - hi = n_ref; - while (lo < hi) { - int m = lo + ((hi - lo) >> 1); - if (ref_col[m] <= v) - lo = m + 1; - else - hi = m; + double* grp_sums; + if (use_gmem) { + grp_sums = rank_sums + (size_t)col; + } else { + grp_sums = smem; + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { + grp_sums[g] = 0.0; } - int n_eq_ref = lo - n_lt_ref; + __syncthreads(); + } - int n_lt_grp = 0; - int n_eq_grp = 0; - for (int j = 0; j < n_grp; ++j) { - float u = grp_col[j]; - n_lt_grp += (u < v); - n_eq_grp += (u == v); - } + const float* sv = sorted_vals + (size_t)col * n_rows; + const int* si = sorted_row_idx + (size_t)col * n_rows; - local_rank += (double)(n_lt_ref + n_lt_grp) + - ((double)(n_eq_ref + n_eq_grp) + 1.0) / 2.0; - } + int chunk = (n_rows + blockDim.x - 1) / blockDim.x; + int my_start = threadIdx.x * chunk; + int my_end = my_start + chunk; + if (my_end > n_rows) my_end = n_rows; - double total_rank = wilcoxon_block_sum(local_rank, warp_buf); - if (threadIdx.x == 0) { - rank_sums[(size_t)grp * n_cols + col] = total_rank; - } + double local_tie_sum = 0.0; + int acc_stride = use_gmem ? n_cols : 1; - if (!compute_tie_corr) return; - __syncthreads(); + int i = my_start; + while (i < my_end) { + double val = sv[i]; - double local_tie = 0.0; + int tie_local_end = i + 1; + while (tie_local_end < my_end && sv[tie_local_end] == val) { + ++tie_local_end; + } - for (int i = threadIdx.x; i < n_ref; i += blockDim.x) { - if (i == 0 || ref_col[i] != ref_col[i - 1]) { - float v = ref_col[i]; - int lo = i + 1, hi = n_ref; + int tie_global_start = i; + if (i == my_start && i > 0 && sv[i - 1] == val) { + int lo = 0; + int hi = i; while (lo < hi) { - int m = lo + ((hi - lo) >> 1); - if (ref_col[m] <= v) - lo = m + 1; + int mid = lo + ((hi - lo) >> 1); + if (sv[mid] < val) + lo = mid + 1; else - hi = m; - } - int count = lo - i; - for (int j = 0; j < n_grp; ++j) count += (grp_col[j] == v); - if (count > 1) { - double t = (double)count; - local_tie += t * t * t - t; + hi = mid; } + tie_global_start = lo; } - } - for (int i = threadIdx.x; i < n_grp; i += blockDim.x) { - float v = grp_col[i]; - bool seen_in_group = false; - for (int j = 0; j < i; ++j) { - if (grp_col[j] == v) { - seen_in_group = true; - break; + int tie_global_end = tie_local_end; + if (tie_local_end == my_end && tie_local_end < n_rows && + sv[tie_local_end] == val) { + int lo = tie_local_end; + int hi = n_rows - 1; + while (lo < hi) { + int mid = hi - ((hi - lo) >> 1); + if (sv[mid] > val) + hi = mid - 1; + else + lo = mid; } + tie_global_end = lo + 1; } - if (seen_in_group) continue; - - int lo = 0, hi = n_ref; - while (lo < hi) { - int m = lo + ((hi - lo) >> 1); - if (ref_col[m] < v) - lo = m + 1; - else - hi = m; - } - if (lo < n_ref && ref_col[lo] == v) continue; - - int count = 0; - for (int j = 0; j < n_grp; ++j) count += (grp_col[j] == v); - if (count > 1) { - double t = (double)count; - local_tie += t * t * t - t; - } - } - - double tie_sum = wilcoxon_block_sum(local_tie, warp_buf); - if (threadIdx.x == 0) { - int n = n_ref + n_grp; - double dn = (double)n; - double denom = dn * dn * dn - dn; - tie_corr[(size_t)grp * n_cols + col] = - (denom > 0.0) ? (1.0 - tie_sum / denom) : 1.0; - } -} - -__global__ void ovo_rank_presorted_kernel(const float* __restrict__ ref_sorted, - const float* __restrict__ grp_sorted, - const int* __restrict__ grp_offsets, - double* __restrict__ rank_sums, - double* __restrict__ tie_corr, - int n_ref, int n_all_grp, int n_cols, - int n_groups, bool compute_tie_corr) { - int col = blockIdx.x; - int grp = blockIdx.y; - if (col >= n_cols || grp >= n_groups) return; - - int g_start = grp_offsets[grp]; - int g_end = grp_offsets[grp + 1]; - int n_grp = g_end - g_start; - - const float* ref_col = ref_sorted + (long long)col * n_ref; - const float* grp_col = grp_sorted + (long long)col * n_all_grp + g_start; - - __shared__ double warp_buf[32]; - double local_rank = 0.0; - - int ref_lb = 0, ref_ub = 0; - int grp_lb = 0, grp_ub = 0; - for (int i = threadIdx.x; i < n_grp; i += blockDim.x) { - float v = grp_col[i]; - - int lo = ref_lb, hi = n_ref; - while (lo < hi) { - int m = lo + ((hi - lo) >> 1); - if (ref_col[m] < v) - lo = m + 1; - else - hi = m; - } - int n_lt_ref = lo; - ref_lb = n_lt_ref; - lo = (ref_ub > n_lt_ref) ? ref_ub : n_lt_ref; - hi = n_ref; - while (lo < hi) { - int m = lo + ((hi - lo) >> 1); - if (ref_col[m] <= v) - lo = m + 1; - else - hi = m; - } - int n_eq_ref = lo - n_lt_ref; - ref_ub = lo; + int total_tie = tie_global_end - tie_global_start; + double avg_rank = (double)(tie_global_start + tie_global_end + 1) / 2.0; - lo = grp_lb; - hi = n_grp; - while (lo < hi) { - int m = lo + ((hi - lo) >> 1); - if (grp_col[m] < v) - lo = m + 1; - else - hi = m; + for (int j = i; j < tie_local_end; ++j) { + int grp = group_codes[si[j]]; + if (grp < n_groups) { + atomicAdd(&grp_sums[grp * acc_stride], avg_rank); + } } - int n_lt_grp = lo; - grp_lb = n_lt_grp; - lo = (grp_ub > n_lt_grp) ? grp_ub : n_lt_grp; - hi = n_grp; - while (lo < hi) { - int m = lo + ((hi - lo) >> 1); - if (grp_col[m] <= v) - lo = m + 1; - else - hi = m; + if (compute_tie_corr && tie_global_start >= my_start && total_tie > 1) { + double t = (double)total_tie; + local_tie_sum += t * t * t - t; } - int n_eq_grp = lo - n_lt_grp; - grp_ub = lo; - local_rank += (double)(n_lt_ref + n_lt_grp) + - ((double)(n_eq_ref + n_eq_grp) + 1.0) / 2.0; - } - - double total_rank = wilcoxon_block_sum(local_rank, warp_buf); - if (threadIdx.x == 0) { - rank_sums[(size_t)grp * n_cols + col] = total_rank; + i = tie_local_end; } - if (!compute_tie_corr) return; __syncthreads(); - double local_tie = 0.0; - int grp_lb_tie = 0, grp_ub_tie = 0; - for (int i = threadIdx.x; i < n_ref; i += blockDim.x) { - if (i == 0 || ref_col[i] != ref_col[i - 1]) { - float v = ref_col[i]; - int lo = i + 1, hi = n_ref; - while (lo < hi) { - int m = lo + ((hi - lo) >> 1); - if (ref_col[m] <= v) - lo = m + 1; - else - hi = m; - } - int cnt_ref = lo - i; - - lo = grp_lb_tie; - hi = n_grp; - while (lo < hi) { - int m = lo + ((hi - lo) >> 1); - if (grp_col[m] < v) - lo = m + 1; - else - hi = m; - } - int lb = lo; - grp_lb_tie = lb; - - lo = (grp_ub_tie > lb) ? grp_ub_tie : lb; - hi = n_grp; - while (lo < hi) { - int m = lo + ((hi - lo) >> 1); - if (grp_col[m] <= v) - lo = m + 1; - else - hi = m; - } - int cnt_grp = lo - lb; - grp_ub_tie = lo; - - int cnt = cnt_ref + cnt_grp; - if (cnt > 1) { - double t = (double)cnt; - local_tie += t * t * t - t; - } + if (!use_gmem) { + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { + rank_sums[(size_t)g * n_cols + col] = grp_sums[g]; } } - int ref_lb_tie = 0; - for (int i = threadIdx.x; i < n_grp; i += blockDim.x) { - if (i == 0 || grp_col[i] != grp_col[i - 1]) { - float v = grp_col[i]; - int lo = ref_lb_tie, hi = n_ref; - while (lo < hi) { - int m = lo + ((hi - lo) >> 1); - if (ref_col[m] < v) - lo = m + 1; - else - hi = m; - } - ref_lb_tie = lo; - if (lo < n_ref && ref_col[lo] == v) continue; - - lo = i + 1; - hi = n_grp; - while (lo < hi) { - int m = lo + ((hi - lo) >> 1); - if (grp_col[m] <= v) - lo = m + 1; - else - hi = m; - } - int cnt = lo - i; - if (cnt > 1) { - double t = (double)cnt; - local_tie += t * t * t - t; - } + if (compute_tie_corr) { + int warp_buf_off = use_gmem ? 0 : n_groups; + double* warp_buf = smem + warp_buf_off; + double tie_sum = wilcoxon_block_sum(local_tie_sum, warp_buf); + if (threadIdx.x == 0) { + double n = (double)n_rows; + double denom = n * n * n - n; + tie_corr[col] = (denom > 0.0) ? (1.0 - tie_sum / denom) : 1.0; } } - - double tie_sum = wilcoxon_block_sum(local_tie, warp_buf); - if (threadIdx.x == 0) { - int n = n_ref + n_grp; - double dn = (double)n; - double denom = dn * dn * dn - dn; - tie_corr[(size_t)grp * n_cols + col] = - (denom > 0.0) ? (1.0 - tie_sum / denom) : 1.0; - } } /** diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu index 38fc25ec..9212960b 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu @@ -1,44 +1,20 @@ #include +#include + +#include +#include +#include + #include "../nb_types.h" #include "kernels_wilcoxon.cuh" +#include "wilcoxon_fast_common.cuh" +#include "kernels_wilcoxon_ovo.cuh" +#include "wilcoxon_ovr_kernels.cuh" +#include "wilcoxon_ovo_kernels.cuh" using namespace nb::literals; -// Constants for kernel launch configuration -constexpr int WARP_SIZE = 32; -constexpr int MAX_THREADS_PER_BLOCK = 512; -constexpr int OVO_THREADS_PER_BLOCK = 256; - -static inline int round_up_to_warp(int n) { - int rounded = ((n + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; - return (rounded < MAX_THREADS_PER_BLOCK) ? rounded : MAX_THREADS_PER_BLOCK; -} - -static inline void launch_ovo_rank_dense( - const float* ref_sorted, const float* grp_data, const int* grp_offsets, - double* rank_sums, double* tie_corr, int n_ref, int n_all_grp, int n_cols, - int n_groups, bool compute_tie_corr, cudaStream_t stream) { - dim3 block(OVO_THREADS_PER_BLOCK); - dim3 grid(n_cols, n_groups); - ovo_rank_dense_kernel<<>>( - ref_sorted, grp_data, grp_offsets, rank_sums, tie_corr, n_ref, - n_all_grp, n_cols, n_groups, compute_tie_corr); - CUDA_CHECK_LAST_ERROR(ovo_rank_dense_kernel); -} - -static inline void launch_ovo_rank_presorted( - const float* ref_sorted, const float* grp_sorted, const int* grp_offsets, - double* rank_sums, double* tie_corr, int n_ref, int n_all_grp, int n_cols, - int n_groups, bool compute_tie_corr, cudaStream_t stream) { - dim3 block(OVO_THREADS_PER_BLOCK); - dim3 grid(n_cols, n_groups); - ovo_rank_presorted_kernel<<>>( - ref_sorted, grp_sorted, grp_offsets, rank_sums, tie_corr, n_ref, - n_all_grp, n_cols, n_groups, compute_tie_corr); - CUDA_CHECK_LAST_ERROR(ovo_rank_presorted_kernel); -} - static inline void launch_ovr_rank_dense( const float* sorted_vals, const int* sorter, const int* group_codes, double* rank_sums, double* tie_corr, int n_rows, int n_cols, int n_groups, @@ -52,45 +28,435 @@ static inline void launch_ovr_rank_dense( CUDA_CHECK_LAST_ERROR(ovr_rank_dense_kernel); } +static void launch_ovr_rank_dense_streaming( + const float* block, const int* group_codes, double* rank_sums, + double* tie_corr, int n_rows, int n_cols, int n_groups, + bool compute_tie_corr, int sub_batch_cols, cudaStream_t upstream_stream) { + if (n_rows == 0 || n_cols == 0 || n_groups == 0) return; + if (sub_batch_cols <= 0) sub_batch_cols = SUB_BATCH_COLS; + + int n_streams = N_STREAMS; + if (n_cols < n_streams * sub_batch_cols) { + n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + } + + size_t sub_items = (size_t)n_rows * sub_batch_cols; + if (sub_items > (size_t)std::numeric_limits::max()) { + throw std::runtime_error( + "Dense OVR sub-batch exceeds CUB int item limit"); + } + + size_t cub_temp_bytes = 0; + { + auto* fk = reinterpret_cast(1); + auto* iv = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortPairs( + nullptr, cub_temp_bytes, fk, fk, iv, iv, (int)sub_items, + sub_batch_cols, iv, iv + 1, BEGIN_BIT, END_BIT); + } + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; ++i) { + cudaStreamCreateWithFlags(&streams[i], cudaStreamNonBlocking); + } + + cudaEvent_t inputs_ready; + cudaEventCreateWithFlags(&inputs_ready, cudaEventDisableTiming); + cudaEventRecord(inputs_ready, upstream_stream); + for (int i = 0; i < n_streams; ++i) { + cudaStreamWaitEvent(streams[i], inputs_ready, 0); + } + + RmmScratchPool pool; + struct StreamBuf { + float* keys_out; + int* vals_in; + int* vals_out; + int* seg_offsets; + uint8_t* cub_temp; + double* sub_rank_sums; + double* sub_tie_corr; + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; ++s) { + bufs[s].keys_out = pool.alloc(sub_items); + bufs[s].vals_in = pool.alloc(sub_items); + bufs[s].vals_out = pool.alloc(sub_items); + bufs[s].seg_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].sub_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].sub_tie_corr = pool.alloc(sub_batch_cols); + } + + int tpb_rank = round_up_to_warp(n_rows); + bool use_gmem = false; + size_t smem_rank = ovr_smem_config(n_groups, use_gmem); + + int col = 0; + int batch_idx = 0; + while (col < n_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int sb_items = n_rows * sb_cols; + int s = batch_idx % n_streams; + cudaStream_t stream = streams[s]; + auto& buf = bufs[s]; + + upload_linear_offsets(buf.seg_offsets, sb_cols, n_rows, stream); + fill_row_indices_kernel<<>>( + buf.vals_in, n_rows, sb_cols); + CUDA_CHECK_LAST_ERROR(fill_row_indices_kernel); + + const float* keys_in = block + (size_t)col * n_rows; + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortPairs( + buf.cub_temp, temp, keys_in, buf.keys_out, buf.vals_in, + buf.vals_out, sb_items, sb_cols, buf.seg_offsets, + buf.seg_offsets + 1, BEGIN_BIT, END_BIT, stream); + + if (use_gmem) { + cudaMemsetAsync(buf.sub_rank_sums, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + } + rank_sums_from_sorted_kernel<<>>( + buf.keys_out, buf.vals_out, group_codes, buf.sub_rank_sums, + buf.sub_tie_corr, n_rows, sb_cols, n_groups, compute_tie_corr, + use_gmem); + CUDA_CHECK_LAST_ERROR(rank_sums_from_sorted_kernel); + + cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), + buf.sub_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpyAsync(tie_corr + col, buf.sub_tie_corr, + sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, + stream); + } + + col += sb_cols; + ++batch_idx; + } + + for (int s = 0; s < n_streams; ++s) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) { + throw std::runtime_error( + std::string("CUDA error in dense OVR streaming rank: ") + + cudaGetErrorString(err)); + } + } + cudaEventDestroy(inputs_ready); + for (int s = 0; s < n_streams; ++s) cudaStreamDestroy(streams[s]); +} + +static void launch_ovo_rank_dense_tiered_impl( + const float* ref_data, bool ref_is_sorted, const float* grp_data, + const int* grp_offsets, double* rank_sums, double* tie_corr, int n_ref, + int n_all_grp, int n_cols, int n_groups, bool compute_tie_corr, + int sub_batch_cols, cudaStream_t upstream_stream) { + if (n_cols == 0 || n_ref == 0 || n_all_grp == 0 || n_groups == 0) return; + if (sub_batch_cols <= 0) sub_batch_cols = SUB_BATCH_COLS; + + std::vector h_offsets(n_groups + 1); + cudaStreamSynchronize(upstream_stream); + cudaMemcpy(h_offsets.data(), grp_offsets, (n_groups + 1) * sizeof(int), + cudaMemcpyDeviceToHost); + auto t1 = make_tier1_config(h_offsets.data(), n_groups); + int max_grp_size = t1.max_grp_size; + bool use_tier1 = t1.any_above_t2 && t1.use_tier1; + bool needs_tier3 = t1.any_above_t2 && !use_tier1; + int padded_grp_size = t1.padded_grp_size; + int tier1_tpb = t1.tier1_tpb; + size_t tier1_smem = t1.tier1_smem; + + std::vector h_sort_group_ids; + int n_sort_groups = n_groups; + if (needs_tier3) { + h_sort_group_ids = make_sort_group_ids(h_offsets.data(), n_groups, + TIER2_GROUP_THRESHOLD); + n_sort_groups = (int)h_sort_group_ids.size(); + } + + int n_streams = N_STREAMS; + if (n_cols < n_streams * sub_batch_cols) + n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + + size_t sub_ref_items = (size_t)n_ref * sub_batch_cols; + if (sub_ref_items > (size_t)std::numeric_limits::max()) { + throw std::runtime_error( + "Dense OVO reference sub-batch exceeds CUB int item limit"); + } + + size_t sub_grp_items = (size_t)n_all_grp * sub_batch_cols; + if (sub_grp_items > (size_t)std::numeric_limits::max()) { + throw std::runtime_error( + "Dense OVO sub-batch exceeds CUB int item limit"); + } + + size_t grp_cub_temp_bytes = 0; + if (needs_tier3) { + int max_grp_seg = n_sort_groups * sub_batch_cols; + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, grp_cub_temp_bytes, fk, fk, (int)sub_grp_items, + max_grp_seg, doff, doff + 1, BEGIN_BIT, END_BIT); + } + size_t ref_cub_temp_bytes = 0; + if (!ref_is_sorted) { + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, ref_cub_temp_bytes, fk, fk, (int)sub_ref_items, + sub_batch_cols, doff, doff + 1, BEGIN_BIT, END_BIT); + } + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; ++i) { + cudaStreamCreateWithFlags(&streams[i], cudaStreamNonBlocking); + } + + cudaEvent_t inputs_ready; + cudaEventCreateWithFlags(&inputs_ready, cudaEventDisableTiming); + cudaEventRecord(inputs_ready, upstream_stream); + for (int i = 0; i < n_streams; ++i) { + cudaStreamWaitEvent(streams[i], inputs_ready, 0); + } + + RmmScratchPool pool; + int* d_sort_group_ids = nullptr; + if (needs_tier3) { + d_sort_group_ids = pool.alloc(h_sort_group_ids.size()); + cudaMemcpy(d_sort_group_ids, h_sort_group_ids.data(), + h_sort_group_ids.size() * sizeof(int), + cudaMemcpyHostToDevice); + } + + struct StreamBuf { + float* ref_sorted; + int* ref_seg_offsets; + uint8_t* ref_cub_temp; + float* grp_sorted; + int* grp_seg_offsets; + int* grp_seg_ends; + uint8_t* grp_cub_temp; + double* ref_tie_sums; + double* sub_rank_sums; + double* sub_tie_corr; + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; ++s) { + if (ref_is_sorted) { + bufs[s].ref_sorted = nullptr; + bufs[s].ref_seg_offsets = nullptr; + bufs[s].ref_cub_temp = nullptr; + } else { + bufs[s].ref_sorted = pool.alloc(sub_ref_items); + bufs[s].ref_seg_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].ref_cub_temp = pool.alloc(ref_cub_temp_bytes); + } + bufs[s].grp_cub_temp = + needs_tier3 ? pool.alloc(grp_cub_temp_bytes) : nullptr; + bufs[s].ref_tie_sums = + (compute_tie_corr && + (t1.use_tier0 || t1.any_tier0_64 || t1.any_tier2)) + ? pool.alloc(sub_batch_cols) + : nullptr; + bufs[s].sub_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].sub_tie_corr = + pool.alloc((size_t)n_groups * sub_batch_cols); + if (needs_tier3) { + bufs[s].grp_sorted = pool.alloc(sub_grp_items); + int max_seg = n_sort_groups * sub_batch_cols; + bufs[s].grp_seg_offsets = pool.alloc(max_seg); + bufs[s].grp_seg_ends = pool.alloc(max_seg); + } else { + bufs[s].grp_sorted = nullptr; + bufs[s].grp_seg_offsets = nullptr; + bufs[s].grp_seg_ends = nullptr; + } + } + + int tpb_rank = + round_up_to_warp(std::min(max_grp_size, MAX_THREADS_PER_BLOCK)); + + int col = 0; + int batch_idx = 0; + while (col < n_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int sb_ref_items_actual = n_ref * sb_cols; + int sb_grp_items_actual = n_all_grp * sb_cols; + int s = batch_idx % n_streams; + cudaStream_t stream = streams[s]; + auto& buf = bufs[s]; + const float* ref_sub = ref_data + (size_t)col * n_ref; + const float* grp_sub = grp_data + (size_t)col * n_all_grp; + if (!ref_is_sorted) { + upload_linear_offsets(buf.ref_seg_offsets, sb_cols, n_ref, stream); + size_t temp = ref_cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + buf.ref_cub_temp, temp, ref_sub, buf.ref_sorted, + sb_ref_items_actual, sb_cols, buf.ref_seg_offsets, + buf.ref_seg_offsets + 1, BEGIN_BIT, END_BIT, stream); + ref_sub = buf.ref_sorted; + } + + int skip_le = 0; + bool run_tier0 = t1.use_tier0; + bool run_tier0_64 = t1.any_tier0_64; + bool run_tier2 = t1.any_tier2; + if (compute_tie_corr && (run_tier0 || run_tier0_64 || run_tier2)) { + launch_ref_tie_sums(ref_sub, buf.ref_tie_sums, n_ref, sb_cols, + stream); + } + if (run_tier0) { + launch_tier0(ref_sub, grp_sub, grp_offsets, buf.ref_tie_sums, + buf.sub_rank_sums, buf.sub_tie_corr, n_ref, n_all_grp, + sb_cols, n_groups, compute_tie_corr, stream); + if (t1.any_above_t0) skip_le = TIER0_GROUP_THRESHOLD; + } + if (run_tier0_64) { + launch_tier0_64(ref_sub, grp_sub, grp_offsets, buf.ref_tie_sums, + buf.sub_rank_sums, buf.sub_tie_corr, n_ref, + n_all_grp, sb_cols, n_groups, compute_tie_corr, + skip_le, stream); + if (t1.max_grp_size > TIER0_64_GROUP_THRESHOLD) { + skip_le = TIER0_64_GROUP_THRESHOLD; + } + } + if (run_tier2) { + launch_tier2_medium(ref_sub, grp_sub, grp_offsets, buf.ref_tie_sums, + buf.sub_rank_sums, buf.sub_tie_corr, n_ref, + n_all_grp, sb_cols, n_groups, compute_tie_corr, + skip_le, stream); + } + + int upper_skip_le = t1.any_above_t2 ? TIER2_GROUP_THRESHOLD : skip_le; + if (t1.any_above_t2 && use_tier1) { + dim3 grid(sb_cols, n_groups); + ovo_fused_sort_rank_kernel<<>>( + ref_sub, grp_sub, grp_offsets, buf.sub_rank_sums, + buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, + compute_tie_corr, padded_grp_size, upper_skip_le); + CUDA_CHECK_LAST_ERROR(ovo_fused_sort_rank_kernel); + } else if (needs_tier3) { + int sb_grp_seg = n_sort_groups * sb_cols; + int blk = (sb_grp_seg + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + build_tier3_seg_begin_end_offsets_kernel<<>>( + grp_offsets, d_sort_group_ids, buf.grp_seg_offsets, + buf.grp_seg_ends, n_all_grp, n_sort_groups, sb_cols); + CUDA_CHECK_LAST_ERROR(build_tier3_seg_begin_end_offsets_kernel); + + size_t temp = grp_cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + buf.grp_cub_temp, temp, grp_sub, buf.grp_sorted, + sb_grp_items_actual, sb_grp_seg, buf.grp_seg_offsets, + buf.grp_seg_ends, BEGIN_BIT, END_BIT, stream); + + dim3 grid(sb_cols, n_groups); + batched_rank_sums_presorted_kernel<<>>( + ref_sub, buf.grp_sorted, grp_offsets, buf.sub_rank_sums, + buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, + compute_tie_corr, upper_skip_le); + CUDA_CHECK_LAST_ERROR(batched_rank_sums_presorted_kernel); + } + + cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), + buf.sub_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpy2DAsync(tie_corr + col, n_cols * sizeof(double), + buf.sub_tie_corr, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + } + + col += sb_cols; + ++batch_idx; + } + + for (int s = 0; s < n_streams; ++s) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) { + throw std::runtime_error( + std::string("CUDA error in dense OVO tiered rank: ") + + cudaGetErrorString(err)); + } + } + cudaEventDestroy(inputs_ready); + for (int s = 0; s < n_streams; ++s) cudaStreamDestroy(streams[s]); +} + +static void launch_ovo_rank_dense_tiered( + const float* ref_sorted, const float* grp_data, const int* grp_offsets, + double* rank_sums, double* tie_corr, int n_ref, int n_all_grp, int n_cols, + int n_groups, bool compute_tie_corr, int sub_batch_cols, + cudaStream_t upstream_stream) { + launch_ovo_rank_dense_tiered_impl(ref_sorted, true, grp_data, grp_offsets, + rank_sums, tie_corr, n_ref, n_all_grp, + n_cols, n_groups, compute_tie_corr, + sub_batch_cols, upstream_stream); +} + +static void launch_ovo_rank_dense_tiered_unsorted_ref( + const float* ref_data, const float* grp_data, const int* grp_offsets, + double* rank_sums, double* tie_corr, int n_ref, int n_all_grp, int n_cols, + int n_groups, bool compute_tie_corr, int sub_batch_cols, + cudaStream_t upstream_stream) { + launch_ovo_rank_dense_tiered_impl(ref_data, false, grp_data, grp_offsets, + rank_sums, tie_corr, n_ref, n_all_grp, + n_cols, n_groups, compute_tie_corr, + sub_batch_cols, upstream_stream); +} + template void register_bindings(nb::module_& m) { m.doc() = "CUDA kernels for Wilcoxon rank-sum test"; m.def( - "ovo_rank_dense", + "ovo_rank_dense_tiered", [](gpu_array_f ref_sorted, gpu_array_f grp_data, gpu_array_c grp_offsets, gpu_array_c rank_sums, gpu_array_c tie_corr, int n_ref, int n_all_grp, - int n_cols, int n_groups, bool compute_tie_corr, + int n_cols, int n_groups, bool compute_tie_corr, int sub_batch_cols, std::uintptr_t stream) { - launch_ovo_rank_dense( - ref_sorted.data(), grp_data.data(), grp_offsets.data(), - rank_sums.data(), tie_corr.data(), n_ref, n_all_grp, n_cols, - n_groups, compute_tie_corr, (cudaStream_t)stream); + launch_ovo_rank_dense_tiered(ref_sorted.data(), grp_data.data(), + grp_offsets.data(), rank_sums.data(), + tie_corr.data(), n_ref, n_all_grp, + n_cols, n_groups, compute_tie_corr, + sub_batch_cols, (cudaStream_t)stream); }, "ref_sorted"_a, "grp_data"_a, "grp_offsets"_a, "rank_sums"_a, "tie_corr"_a, nb::kw_only(), "n_ref"_a, "n_all_grp"_a, "n_cols"_a, - "n_groups"_a, "compute_tie_corr"_a, "stream"_a = 0); + "n_groups"_a, "compute_tie_corr"_a, "sub_batch_cols"_a = SUB_BATCH_COLS, + "stream"_a = 0); m.def( - "ovo_rank_presorted", - [](gpu_array_f ref_sorted, - gpu_array_f grp_sorted, + "ovo_rank_dense_tiered_unsorted_ref", + [](gpu_array_f ref_data, + gpu_array_f grp_data, gpu_array_c grp_offsets, gpu_array_c rank_sums, gpu_array_c tie_corr, int n_ref, int n_all_grp, - int n_cols, int n_groups, bool compute_tie_corr, + int n_cols, int n_groups, bool compute_tie_corr, int sub_batch_cols, std::uintptr_t stream) { - launch_ovo_rank_presorted( - ref_sorted.data(), grp_sorted.data(), grp_offsets.data(), + launch_ovo_rank_dense_tiered_unsorted_ref( + ref_data.data(), grp_data.data(), grp_offsets.data(), rank_sums.data(), tie_corr.data(), n_ref, n_all_grp, n_cols, - n_groups, compute_tie_corr, (cudaStream_t)stream); + n_groups, compute_tie_corr, sub_batch_cols, + (cudaStream_t)stream); }, - "ref_sorted"_a, "grp_sorted"_a, "grp_offsets"_a, "rank_sums"_a, + "ref_data"_a, "grp_data"_a, "grp_offsets"_a, "rank_sums"_a, "tie_corr"_a, nb::kw_only(), "n_ref"_a, "n_all_grp"_a, "n_cols"_a, - "n_groups"_a, "compute_tie_corr"_a, "stream"_a = 0); + "n_groups"_a, "compute_tie_corr"_a, "sub_batch_cols"_a = SUB_BATCH_COLS, + "stream"_a = 0); m.def( "ovr_rank_dense", @@ -108,6 +474,23 @@ void register_bindings(nb::module_& m) { "sorted_vals"_a, "sorter"_a, "group_codes"_a, "rank_sums"_a, "tie_corr"_a, nb::kw_only(), "n_rows"_a, "n_cols"_a, "n_groups"_a, "compute_tie_corr"_a, "stream"_a = 0); + + m.def( + "ovr_rank_dense_streaming", + [](gpu_array_f block, + gpu_array_c group_codes, + gpu_array_c rank_sums, + gpu_array_c tie_corr, int n_rows, int n_cols, + int n_groups, bool compute_tie_corr, int sub_batch_cols, + std::uintptr_t stream) { + launch_ovr_rank_dense_streaming( + block.data(), group_codes.data(), rank_sums.data(), + tie_corr.data(), n_rows, n_cols, n_groups, compute_tie_corr, + sub_batch_cols, (cudaStream_t)stream); + }, + "block"_a, "group_codes"_a, "rank_sums"_a, "tie_corr"_a, nb::kw_only(), + "n_rows"_a, "n_cols"_a, "n_groups"_a, "compute_tie_corr"_a, + "sub_batch_cols"_a = SUB_BATCH_COLS, "stream"_a = 0); } NB_MODULE(_wilcoxon_cuda, m) { diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py index 4fec5948..b96cfee6 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py @@ -21,13 +21,60 @@ MIN_GROUP_SIZE_WARNING = 25 DEFAULT_WILCOXON_CHUNK_SIZE = 512 -OVO_SORT_GROUP_THRESHOLD = 512 OVR_HOST_CSC_SUB_BATCH = 512 OVR_HOST_CSR_SUB_BATCH = 2048 OVR_DEVICE_CSC_SUB_BATCH = 2048 OVR_DEVICE_CSR_SUB_BATCH = 2048 OVO_HOST_SPARSE_SUB_BATCH = 256 OVO_DEVICE_SPARSE_SUB_BATCH = 128 +OVR_DENSE_SUB_BATCH = 64 +OVO_DENSE_TIERED_SUB_BATCH = 256 +DENSE_HOST_PRELOAD_MAX_GPU_FRACTION = 0.55 + + +def _maybe_preload_host_dense(rg: _RankGenes) -> None: + X = rg.X + if not isinstance(X, np.ndarray) or X.size == 0: + return + + try: + _, total = cp.cuda.runtime.memGetInfo() + except cp.cuda.runtime.CUDARuntimeError: + return + + if X.nbytes > total * DENSE_HOST_PRELOAD_MAX_GPU_FRACTION: + return + + registered = False + if X.flags.c_contiguous or X.flags.f_contiguous: + try: + cp.cuda.runtime.hostRegister(X.ctypes.data, X.nbytes, 0) + registered = True + except cp.cuda.runtime.CUDARuntimeError: + registered = False + + try: + X_gpu = cp.asarray(X) + cp.cuda.get_current_stream().synchronize() + except cp.cuda.memory.OutOfMemoryError: + cp.get_default_memory_pool().free_all_blocks() + return + except cp.cuda.runtime.CUDARuntimeError: + return + finally: + if registered: + try: + cp.cuda.runtime.hostUnregister(X.ctypes.data) + except cp.cuda.runtime.CUDARuntimeError: + pass + rg.X = X_gpu + + +def _get_dense_column_block_f32(X, start: int, stop: int) -> cp.ndarray: + """Extract a dense column block as F-order float32 CuPy memory.""" + if isinstance(X, np.ndarray | cp.ndarray): + return cp.asarray(X[:, start:stop], dtype=cp.float32, order="F") + raise TypeError(f"Expected dense matrix, got {type(X)}") def _extract_dense_rows_cols( @@ -333,6 +380,7 @@ def wilcoxon( return_u_values: bool = False, ) -> list[tuple[int, NDArray, NDArray]]: """Compute Wilcoxon rank-sum test statistics.""" + _maybe_preload_host_dense(rg) # Compute basic stats - uses Aggregate if on GPU, else defers to chunks rg._basic_stats() X = rg.X @@ -591,32 +639,29 @@ def _wilcoxon_vs_rest( for start in range(0, n_total_genes, chunk_width): stop = min(start + chunk_width, n_total_genes) - # Slice and convert to dense GPU array (F-order for column ops) - block = _get_column_block(X, start, stop) - - # Accumulate stats for this chunk - rg._accumulate_chunk_stats_vs_rest( - block, - start, - stop, - group_matrix=group_matrix, - group_sizes_dev=group_sizes_dev, - n_cells=n_cells, - ) + if rg._compute_stats_in_chunks: + block = _get_column_block(X, start, stop) + rg._accumulate_chunk_stats_vs_rest( + block, + start, + stop, + group_matrix=group_matrix, + group_sizes_dev=group_sizes_dev, + n_cells=n_cells, + ) + block_f32 = cp.asfortranarray(block.astype(cp.float32, copy=False)) + else: + block_f32 = _get_dense_column_block_f32(X, start, stop) - block_f32 = cp.asfortranarray(block.astype(cp.float32, copy=False)) - sorter = cp.asfortranarray(cp.argsort(block_f32, axis=0).astype(cp.int32)) - sorted_vals = cp.asfortranarray(cp.take_along_axis(block_f32, sorter, axis=0)) n_cols = stop - start - rank_sums = cp.zeros((n_groups, n_cols), dtype=cp.float64) + rank_sums = cp.empty((n_groups, n_cols), dtype=cp.float64) tie_corr = ( cp.empty(n_cols, dtype=cp.float64) if tie_correct else cp.ones(n_cols, dtype=cp.float64) ) - _wc.ovr_rank_dense( - sorted_vals, - sorter, + _wc.ovr_rank_dense_streaming( + block_f32, group_codes_gpu, rank_sums, tie_corr, @@ -624,6 +669,7 @@ def _wilcoxon_vs_rest( n_cols=n_cols, n_groups=n_groups, compute_tie_corr=tie_correct, + sub_batch_cols=OVR_DENSE_SUB_BATCH, stream=cp.cuda.get_current_stream().ptr, ) expected = group_sizes_dev[:, None] * (n_cells + 1) / 2.0 @@ -713,8 +759,6 @@ def _wilcoxon_with_reference( offsets_gpu = cp.asarray(offsets_np) n_all_grp = int(all_grp_row_ids.size) n_test = len(test_group_indices) - max_test_size = int(np.diff(offsets_np).max(initial=0)) - use_presorted_groups = max_test_size > OVO_SORT_GROUP_THRESHOLD test_sizes = cp.asarray( group_sizes[np.asarray(test_group_indices, dtype=np.intp)].astype( np.float64, copy=False @@ -976,45 +1020,25 @@ def _wilcoxon_with_reference( group_sizes=group_sizes, ) - ref_sorted = cp.asfortranarray(cp.sort(ref_block.astype(cp.float32), axis=0)) - grp_f32 = cp.asfortranarray(grp_block.astype(cp.float32, copy=False)) + ref_f32 = cp.asarray(ref_block, dtype=cp.float32, order="F") + grp_f32 = cp.asarray(grp_block, dtype=cp.float32, order="F") rank_sums = cp.empty((n_test, n_cols), dtype=cp.float64) tie_corr = cp.empty((n_test, n_cols), dtype=cp.float64) - if use_presorted_groups: - grp_rank_input = cp.empty_like(grp_f32) - for slot in range(n_test): - begin = int(offsets_np[slot]) - end = int(offsets_np[slot + 1]) - grp_rank_input[begin:end] = cp.sort(grp_f32[begin:end], axis=0) - grp_rank_input = cp.asfortranarray(grp_rank_input) - _wc.ovo_rank_presorted( - ref_sorted, - grp_rank_input, - offsets_gpu, - rank_sums, - tie_corr, - n_ref=n_ref, - n_all_grp=n_all_grp, - n_cols=n_cols, - n_groups=n_test, - compute_tie_corr=tie_correct, - stream=cp.cuda.get_current_stream().ptr, - ) - else: - _wc.ovo_rank_dense( - ref_sorted, - grp_f32, - offsets_gpu, - rank_sums, - tie_corr, - n_ref=n_ref, - n_all_grp=n_all_grp, - n_cols=n_cols, - n_groups=n_test, - compute_tie_corr=tie_correct, - stream=cp.cuda.get_current_stream().ptr, - ) + _wc.ovo_rank_dense_tiered_unsorted_ref( + ref_f32, + grp_f32, + offsets_gpu, + rank_sums, + tie_corr, + n_ref=n_ref, + n_all_grp=n_all_grp, + n_cols=n_cols, + n_groups=n_test, + compute_tie_corr=tie_correct, + sub_batch_cols=OVO_DENSE_TIERED_SUB_BATCH, + stream=cp.cuda.get_current_stream().ptr, + ) n_combined = test_sizes + n_ref expected = test_sizes[:, None] * (n_combined[:, None] + 1) / 2.0 diff --git a/tests/test_rank_genes_groups_wilcoxon.py b/tests/test_rank_genes_groups_wilcoxon.py index 413cfe3b..6e3dbf89 100644 --- a/tests/test_rank_genes_groups_wilcoxon.py +++ b/tests/test_rank_genes_groups_wilcoxon.py @@ -244,6 +244,40 @@ def test_rank_genes_groups_wilcoxon_matches_scanpy(reference, tie_correct, spars assert params["reference"] == reference +def test_rank_genes_groups_wilcoxon_dense_ovr_ties_match_scanpy(): + rng = np.random.default_rng(16) + X = rng.integers(0, 40, size=(128, 7)).astype(np.float32) + labels = rng.integers(0, 7, size=128).astype(str) + adata_gpu = sc.AnnData( + X=X.copy(), + obs=pd.DataFrame({"group": pd.Categorical(labels)}), + var=pd.DataFrame(index=[f"g{i}" for i in range(X.shape[1])]), + ) + adata_cpu = adata_gpu.copy() + + kw = { + "groupby": "group", + "method": "wilcoxon", + "reference": "rest", + "use_raw": False, + "tie_correct": True, + "n_genes": adata_gpu.n_vars, + } + rsc.tl.rank_genes_groups(adata_gpu, **kw) + sc.tl.rank_genes_groups(adata_cpu, **kw) + + gpu_result = adata_gpu.uns["rank_genes_groups"] + cpu_result = adata_cpu.uns["rank_genes_groups"] + for group in gpu_result["scores"].dtype.names: + assert list(gpu_result["names"][group]) == list(cpu_result["names"][group]) + np.testing.assert_allclose( + gpu_result["scores"][group], cpu_result["scores"][group], rtol=1e-13 + ) + np.testing.assert_allclose( + gpu_result["pvals"][group], cpu_result["pvals"][group], rtol=1e-13 + ) + + @pytest.mark.parametrize("reference", ["rest", "1"]) def test_rank_genes_groups_wilcoxon_honors_layer_and_use_raw(reference): """Test that layer parameter is respected.""" From a0e9b0c3b1fe701c22baad8e63f43de2dfda3bf9 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Fri, 24 Apr 2026 23:08:15 +0200 Subject: [PATCH 6/7] update tests and fix issues --- .github/workflows/publish.yml | 6 +- .gitignore | 2 +- CMakeLists.txt | 14 +- pyproject.toml | 4 +- .../_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh | 2 + .../_cuda/wilcoxon/wilcoxon.cu | 48 ++--- .../_cuda/wilcoxon/wilcoxon_fast_common.cuh | 44 ++++- .../wilcoxon/wilcoxon_ovo_device_sparse.cuh | 68 +++++-- .../wilcoxon/wilcoxon_ovo_host_sparse.cuh | 128 +++++++++---- .../_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh | 23 +-- .../_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh | 53 ++++-- .../_cuda/wilcoxon/wilcoxon_rmm.cu | 4 +- .../wilcoxon/wilcoxon_sparse_kernels.cuh | 23 +-- .../tools/_rank_genes_groups/__init__.py | 147 ++++----------- .../tools/_rank_genes_groups/_core.py | 68 ++++++- .../tools/_rank_genes_groups/_utils.py | 55 +++--- .../tools/_rank_genes_groups/_wilcoxon.py | 51 +++-- .../_rank_genes_groups/_wilcoxon_binned.py | 6 +- tests/test_rank_genes_groups_ttest.py | 35 ++-- tests/test_rank_genes_groups_wilcoxon.py | 177 +++++++++++++++--- 20 files changed, 626 insertions(+), 332 deletions(-) diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 4ca5d522..0ea4ee5e 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -147,6 +147,8 @@ jobs: LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH PATH=/usr/local/cuda/bin:$PATH CIBW_BEFORE_BUILD: > + rm -f build/.librmm_dir && + mkdir -p build && python -m pip install -U pip scikit-build-core cmake ninja nanobind librmm-cu${{ matrix.cuda_major }} && @@ -157,8 +159,8 @@ jobs: ln -sf "$RMM_ROOT/lib64/librmm.so" /usr/local/lib/librmm.so && ln -sf "$LOG_ROOT/lib64/librapids_logger.so" /usr/local/lib/librapids_logger.so && ldconfig && - python -c "import librmm; print(librmm.__path__[0])" > /tmp/.librmm_dir && - echo "[rsc-build] marker=$(cat /tmp/.librmm_dir)" + python -c "import librmm; print(librmm.__path__[0])" > build/.librmm_dir && + echo "[rsc-build] marker=$(cat build/.librmm_dir)" CIBW_TEST_SKIP: "*" CIBW_TEST_COMMAND: "" CIBW_REPAIR_WHEEL_COMMAND: "auditwheel repair --exclude libcublas.so.${{ matrix.cuda_major }} --exclude libcublasLt.so.${{ matrix.cuda_major }} --exclude libcudart.so.${{ matrix.cuda_major }} --exclude librmm.so --exclude librapids_logger.so -w {dest_dir} {wheel}" diff --git a/.gitignore b/.gitignore index 6994e147..4a7497ba 100644 --- a/.gitignore +++ b/.gitignore @@ -51,4 +51,4 @@ CLAUDE.md # tmp_scripts tmp_scripts/ -benchmarks/ +/benchmarks/ diff --git a/CMakeLists.txt b/CMakeLists.txt index e880613d..4e404263 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -50,9 +50,19 @@ if (RSC_BUILD_EXTENSIONS) if (RSC_PYTHON_RMM_DIR AND EXISTS "${RSC_PYTHON_RMM_DIR}/rmm-config.cmake") list(APPEND RSC_RMM_HINTS "${RSC_PYTHON_RMM_DIR}") endif() - if(EXISTS "/tmp/.librmm_dir") - file(READ "/tmp/.librmm_dir" _rsc_librmm_marker) + # Wheel builds install librmm/rapids_logger into the isolated build env and + # write build/.librmm_dir from CIBW_BEFORE_BUILD. publish.yml also symlinks + # those shared libraries into /usr/local/lib so auditwheel can see and exclude + # them instead of bundling RAPIDS runtime libraries into the wheel. + if(DEFINED ENV{RSC_LIBRMM_DIR} AND EXISTS "$ENV{RSC_LIBRMM_DIR}/lib64/cmake/rmm/rmm-config.cmake") + set(_rsc_librmm_marker "$ENV{RSC_LIBRMM_DIR}") + elseif(EXISTS "${CMAKE_SOURCE_DIR}/build/.librmm_dir") + file(READ "${CMAKE_SOURCE_DIR}/build/.librmm_dir" _rsc_librmm_marker) string(STRIP "${_rsc_librmm_marker}" _rsc_librmm_marker) + else() + set(_rsc_librmm_marker "") + endif() + if(NOT "${_rsc_librmm_marker}" STREQUAL "" AND EXISTS "${_rsc_librmm_marker}/lib64/cmake/rmm/rmm-config.cmake") file(GLOB _rsc_marker_rmm_dirs "${_rsc_librmm_marker}/lib64/cmake/rmm") file(GLOB _rsc_marker_rapids_prefixes "${_rsc_librmm_marker}/lib64" diff --git a/pyproject.toml b/pyproject.toml index a3b07ede..b4940b18 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,9 @@ requires = [ "nanobind>=2.0.0", "setuptools-scm>=8", # librmm headers/CMake config are needed at build time for Wilcoxon. - # CUDA wheel builds rewrite this to the matching cu12/cu13 package. + # Generic isolated source builds default to CUDA 12. CUDA wheel builds + # rewrite this to the matching cu12/cu13 package; CUDA 13 source builds + # should build in an existing RAPIDS env with --no-build-isolation. "librmm-cu12>=25.10", ] build-backend = "scikit_build_core.build" diff --git a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh index 5b4c0b8c..a8e9ed4f 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh @@ -2,6 +2,8 @@ #include +#include "wilcoxon_fast_common.cuh" + // ============================================================================ // Warp reduction helper (sum doubles across block via warp_buf) // ============================================================================ diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu index 9212960b..d314b289 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu @@ -41,17 +41,14 @@ static void launch_ovr_rank_dense_streaming( } size_t sub_items = (size_t)n_rows * sub_batch_cols; - if (sub_items > (size_t)std::numeric_limits::max()) { - throw std::runtime_error( - "Dense OVR sub-batch exceeds CUB int item limit"); - } + int sub_items_i32 = checked_cub_items(sub_items, "Dense OVR sub-batch"); size_t cub_temp_bytes = 0; { auto* fk = reinterpret_cast(1); auto* iv = reinterpret_cast(1); cub::DeviceSegmentedRadixSort::SortPairs( - nullptr, cub_temp_bytes, fk, fk, iv, iv, (int)sub_items, + nullptr, cub_temp_bytes, fk, fk, iv, iv, sub_items_i32, sub_batch_cols, iv, iv + 1, BEGIN_BIT, END_BIT); } @@ -97,7 +94,8 @@ static void launch_ovr_rank_dense_streaming( int batch_idx = 0; while (col < n_cols) { int sb_cols = std::min(sub_batch_cols, n_cols - col); - int sb_items = n_rows * sb_cols; + int sb_items = checked_int_product((size_t)n_rows, (size_t)sb_cols, + "Dense OVR active sub-batch"); int s = batch_idx % n_streams; cudaStream_t stream = streams[s]; auto& buf = bufs[s]; @@ -184,32 +182,30 @@ static void launch_ovo_rank_dense_tiered_impl( n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; size_t sub_ref_items = (size_t)n_ref * sub_batch_cols; - if (sub_ref_items > (size_t)std::numeric_limits::max()) { - throw std::runtime_error( - "Dense OVO reference sub-batch exceeds CUB int item limit"); - } + int sub_ref_items_i32 = + checked_cub_items(sub_ref_items, "Dense OVO reference sub-batch"); size_t sub_grp_items = (size_t)n_all_grp * sub_batch_cols; - if (sub_grp_items > (size_t)std::numeric_limits::max()) { - throw std::runtime_error( - "Dense OVO sub-batch exceeds CUB int item limit"); - } + int sub_grp_items_i32 = + checked_cub_items(sub_grp_items, "Dense OVO group sub-batch"); size_t grp_cub_temp_bytes = 0; if (needs_tier3) { - int max_grp_seg = n_sort_groups * sub_batch_cols; + int max_grp_seg = + checked_int_product((size_t)n_sort_groups, (size_t)sub_batch_cols, + "Dense OVO group segment count"); auto* fk = reinterpret_cast(1); auto* doff = reinterpret_cast(1); cub::DeviceSegmentedRadixSort::SortKeys( - nullptr, grp_cub_temp_bytes, fk, fk, (int)sub_grp_items, - max_grp_seg, doff, doff + 1, BEGIN_BIT, END_BIT); + nullptr, grp_cub_temp_bytes, fk, fk, sub_grp_items_i32, max_grp_seg, + doff, doff + 1, BEGIN_BIT, END_BIT); } size_t ref_cub_temp_bytes = 0; if (!ref_is_sorted) { auto* fk = reinterpret_cast(1); auto* doff = reinterpret_cast(1); cub::DeviceSegmentedRadixSort::SortKeys( - nullptr, ref_cub_temp_bytes, fk, fk, (int)sub_ref_items, + nullptr, ref_cub_temp_bytes, fk, fk, sub_ref_items_i32, sub_batch_cols, doff, doff + 1, BEGIN_BIT, END_BIT); } @@ -270,7 +266,9 @@ static void launch_ovo_rank_dense_tiered_impl( pool.alloc((size_t)n_groups * sub_batch_cols); if (needs_tier3) { bufs[s].grp_sorted = pool.alloc(sub_grp_items); - int max_seg = n_sort_groups * sub_batch_cols; + int max_seg = checked_int_product((size_t)n_sort_groups, + (size_t)sub_batch_cols, + "Dense OVO group segment buffer"); bufs[s].grp_seg_offsets = pool.alloc(max_seg); bufs[s].grp_seg_ends = pool.alloc(max_seg); } else { @@ -287,8 +285,12 @@ static void launch_ovo_rank_dense_tiered_impl( int batch_idx = 0; while (col < n_cols) { int sb_cols = std::min(sub_batch_cols, n_cols - col); - int sb_ref_items_actual = n_ref * sb_cols; - int sb_grp_items_actual = n_all_grp * sb_cols; + int sb_ref_items_actual = + checked_int_product((size_t)n_ref, (size_t)sb_cols, + "Dense OVO active reference sub-batch"); + int sb_grp_items_actual = + checked_int_product((size_t)n_all_grp, (size_t)sb_cols, + "Dense OVO active group sub-batch"); int s = batch_idx % n_streams; cudaStream_t stream = streams[s]; auto& buf = bufs[s]; @@ -343,7 +345,9 @@ static void launch_ovo_rank_dense_tiered_impl( compute_tie_corr, padded_grp_size, upper_skip_le); CUDA_CHECK_LAST_ERROR(ovo_fused_sort_rank_kernel); } else if (needs_tier3) { - int sb_grp_seg = n_sort_groups * sb_cols; + int sb_grp_seg = + checked_int_product((size_t)n_sort_groups, (size_t)sb_cols, + "Dense OVO active group segment count"); int blk = (sb_grp_seg + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; build_tier3_seg_begin_end_offsets_kernel<<>>( diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh index 2d5b3f2c..ec723b55 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include #include @@ -48,6 +49,39 @@ constexpr int TIER1_GROUP_THRESHOLD = 2500; // 512 MB per stream dense slab + same for sorted copy ≈ 1 GB / stream. constexpr size_t GROUP_DENSE_BUDGET_ITEMS = 128 * 1024 * 1024; +static inline size_t wilcoxon_max_smem_per_block() { + int device = 0; + int max_smem = 0; + cudaGetDevice(&device); + cudaDeviceGetAttribute(&max_smem, cudaDevAttrMaxSharedMemoryPerBlock, + device); + return (size_t)max_smem; +} + +static inline int checked_cub_items(size_t count, const char* context) { + if (count > (size_t)std::numeric_limits::max()) { + throw std::runtime_error(std::string(context) + + " exceeds CUB int item limit"); + } + return (int)count; +} + +static inline int checked_int_span(size_t count, const char* context) { + if (count > (size_t)std::numeric_limits::max()) { + throw std::runtime_error(std::string(context) + + " exceeds int32 offset limit"); + } + return (int)count; +} + +static inline int checked_int_product(size_t a, size_t b, const char* context) { + if (a != 0 && b > (size_t)std::numeric_limits::max() / a) { + throw std::runtime_error(std::string(context) + + " exceeds int32 item limit"); + } + return (int)(a * b); +} + // --------------------------------------------------------------------------- // RAII guard for cudaHostRegister. Unregisters on scope exit even when an // exception unwinds — prevents leaked host pinning on stream-sync failures. @@ -60,9 +94,9 @@ struct HostRegisterGuard { if (p && bytes > 0) { cudaError_t err = cudaHostRegister(p, bytes, flags); if (err != cudaSuccess) { - // Already-registered memory is fine; anything else means the - // subsequent kernels would read garbage from an unmapped - // pointer, so surface the error immediately. + // Already-registered memory belongs to another owner; use it + // without unregistering here. Other failures mean mapped reads + // would be unsafe, so surface them immediately. if (err == cudaErrorHostMemoryAlreadyRegistered) { cudaGetLastError(); // clear sticky error flag } else { @@ -116,6 +150,10 @@ struct RmmScratchPool { template T* alloc(size_t count) { if (count == 0) count = 1; + if (count > std::numeric_limits::max() / sizeof(T)) { + throw std::runtime_error( + "Wilcoxon scratch allocation size overflow"); + } size_t bytes = count * sizeof(T); void* ptr = wilcoxon_rmm_allocate(bytes); bufs.push_back({ptr, bytes}); diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh index b195bee0..b60b87ff 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh @@ -64,17 +64,23 @@ static void ovo_streaming_csr_impl( size_t cub_temp_bytes = 0; if (needs_tier3) { size_t cub_grp_bytes = 0; - int max_grp_seg = n_sort_groups * sub_batch_cols; + int sub_grp_items_i32 = + checked_cub_items(sub_grp_items, "OVO device CSR group sub-batch"); + int max_grp_seg = + checked_int_product((size_t)n_sort_groups, (size_t)sub_batch_cols, + "OVO device CSR group segment count"); auto* fk = reinterpret_cast(1); auto* doff = reinterpret_cast(1); cub::DeviceSegmentedRadixSort::SortKeys( - nullptr, cub_grp_bytes, fk, fk, (int)sub_grp_items, max_grp_seg, + nullptr, cub_grp_bytes, fk, fk, sub_grp_items_i32, max_grp_seg, doff, doff + 1, BEGIN_BIT, END_BIT); cub_temp_bytes = cub_grp_bytes; } std::vector streams(n_streams); for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + cudaStream_t ref_stream; + cudaStreamCreateWithFlags(&ref_stream, cudaStreamNonBlocking); int* d_sort_group_ids = nullptr; if (needs_tier3) { @@ -110,7 +116,9 @@ static void ovo_streaming_csr_impl( pool.alloc((size_t)n_groups * sub_batch_cols); if (needs_tier3) { bufs[s].grp_sorted = pool.alloc(sub_grp_items); - int max_seg = n_sort_groups * sub_batch_cols; + int max_seg = checked_int_product( + (size_t)n_sort_groups, (size_t)sub_batch_cols, + "OVO device CSR group segment buffer"); bufs[s].grp_seg_offsets = pool.alloc(max_seg); bufs[s].grp_seg_ends = pool.alloc(max_seg); } else { @@ -127,6 +135,8 @@ static void ovo_streaming_csr_impl( for (int cache_col = 0; cache_col < n_cols; cache_col += ref_cache_cols) { int cache_cols = std::min(ref_cache_cols, n_cols - cache_col); size_t cache_ref_items = (size_t)n_ref * cache_cols; + int cache_ref_items_i32 = checked_cub_items( + cache_ref_items, "OVO device CSR reference cache"); ScopedCudaBuffer ref_dense_buf(cache_ref_items * sizeof(float)); ScopedCudaBuffer ref_sorted_buf(cache_ref_items * sizeof(float)); @@ -136,36 +146,39 @@ static void ovo_streaming_csr_impl( float* d_ref_sorted = (float*)ref_sorted_buf.data(); int* d_ref_seg_offsets = (int*)ref_seg_offsets_buf.data(); - cudaMemsetAsync(d_ref_dense, 0, cache_ref_items * sizeof(float)); + cudaMemsetAsync(d_ref_dense, 0, cache_ref_items * sizeof(float), + ref_stream); int tpb_ref_extract = round_up_to_warp(n_ref); int ref_blk = (n_ref + tpb_ref_extract - 1) / tpb_ref_extract; - csr_extract_dense_kernel<<>>( + csr_extract_dense_kernel<<>>( csr_data, csr_indices, csr_indptr, ref_row_ids, d_ref_dense, n_ref, cache_col, cache_col + cache_cols); CUDA_CHECK_LAST_ERROR(csr_extract_dense_kernel); - upload_linear_offsets(d_ref_seg_offsets, cache_cols, n_ref, 0); + upload_linear_offsets(d_ref_seg_offsets, cache_cols, n_ref, ref_stream); size_t ref_cub_bytes = 0; auto* fk = reinterpret_cast(1); auto* doff = reinterpret_cast(1); cub::DeviceSegmentedRadixSort::SortKeys( - nullptr, ref_cub_bytes, fk, fk, (int)cache_ref_items, cache_cols, + nullptr, ref_cub_bytes, fk, fk, cache_ref_items_i32, cache_cols, doff, doff + 1, BEGIN_BIT, END_BIT); ScopedCudaBuffer ref_cub_temp_buf(ref_cub_bytes); size_t ref_temp = ref_cub_bytes; cub::DeviceSegmentedRadixSort::SortKeys( ref_cub_temp_buf.data(), ref_temp, d_ref_dense, d_ref_sorted, - (int)cache_ref_items, cache_cols, d_ref_seg_offsets, - d_ref_seg_offsets + 1, BEGIN_BIT, END_BIT); - cudaDeviceSynchronize(); + cache_ref_items_i32, cache_cols, d_ref_seg_offsets, + d_ref_seg_offsets + 1, BEGIN_BIT, END_BIT, ref_stream); + cudaStreamSynchronize(ref_stream); int col = cache_col; int cache_stop = cache_col + cache_cols; int batch_idx = 0; while (col < cache_stop) { int sb_cols = std::min(sub_batch_cols, cache_stop - col); - int sb_grp_items_actual = n_all_grp * sb_cols; + int sb_grp_items_actual = + checked_int_product((size_t)n_all_grp, (size_t)sb_cols, + "OVO device CSR active group sub-batch"); int s = batch_idx % n_streams; auto stream = streams[s]; auto& buf = bufs[s]; @@ -224,7 +237,9 @@ static void ovo_streaming_csr_impl( compute_tie_corr, padded_grp_size, upper_skip_le); CUDA_CHECK_LAST_ERROR(ovo_fused_sort_rank_kernel); } else if (needs_tier3) { - int sb_grp_seg = n_sort_groups * sb_cols; + int sb_grp_seg = checked_int_product( + (size_t)n_sort_groups, (size_t)sb_cols, + "OVO device CSR active group segment count"); { int blk = (sb_grp_seg + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; @@ -277,6 +292,7 @@ static void ovo_streaming_csr_impl( } } for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); + cudaStreamDestroy(ref_stream); } /** @@ -316,23 +332,29 @@ static void ovo_streaming_csc_impl( size_t sub_ref_items = (size_t)n_ref * sub_batch_cols; size_t sub_grp_items = (size_t)n_all_grp * sub_batch_cols; + int sub_ref_items_i32 = + checked_cub_items(sub_ref_items, "OVO device CSC reference sub-batch"); + int sub_grp_items_i32 = + checked_cub_items(sub_grp_items, "OVO device CSC group sub-batch"); size_t cub_ref_bytes = 0; { auto* fk = reinterpret_cast(1); auto* doff = reinterpret_cast(1); cub::DeviceSegmentedRadixSort::SortKeys( - nullptr, cub_ref_bytes, fk, fk, (int)sub_ref_items, sub_batch_cols, + nullptr, cub_ref_bytes, fk, fk, sub_ref_items_i32, sub_batch_cols, doff, doff + 1, BEGIN_BIT, END_BIT); } size_t cub_temp_bytes = cub_ref_bytes; if (needs_tier3) { size_t cub_grp_bytes = 0; - int max_grp_seg = n_sort_groups * sub_batch_cols; + int max_grp_seg = + checked_int_product((size_t)n_sort_groups, (size_t)sub_batch_cols, + "OVO device CSC group segment count"); auto* fk = reinterpret_cast(1); auto* doff = reinterpret_cast(1); cub::DeviceSegmentedRadixSort::SortKeys( - nullptr, cub_grp_bytes, fk, fk, (int)sub_grp_items, max_grp_seg, + nullptr, cub_grp_bytes, fk, fk, sub_grp_items_i32, max_grp_seg, doff, doff + 1, BEGIN_BIT, END_BIT); cub_temp_bytes = std::max(cub_ref_bytes, cub_grp_bytes); } @@ -380,7 +402,9 @@ static void ovo_streaming_csc_impl( pool.alloc((size_t)n_groups * sub_batch_cols); if (needs_tier3) { bufs[s].grp_sorted = pool.alloc(sub_grp_items); - int max_grp_seg = n_sort_groups * sub_batch_cols; + int max_grp_seg = checked_int_product( + (size_t)n_sort_groups, (size_t)sub_batch_cols, + "OVO device CSC group segment buffer"); bufs[s].grp_seg_offsets = pool.alloc(max_grp_seg); bufs[s].grp_seg_ends = pool.alloc(max_grp_seg); } else { @@ -397,8 +421,12 @@ static void ovo_streaming_csc_impl( int batch_idx = 0; while (col < n_cols) { int sb_cols = std::min(sub_batch_cols, n_cols - col); - int sb_ref_items_actual = n_ref * sb_cols; - int sb_grp_items_actual = n_all_grp * sb_cols; + int sb_ref_items_actual = + checked_int_product((size_t)n_ref, (size_t)sb_cols, + "OVO device CSC active reference sub-batch"); + int sb_grp_items_actual = + checked_int_product((size_t)n_all_grp, (size_t)sb_cols, + "OVO device CSC active group sub-batch"); int s = batch_idx % n_streams; auto stream = streams[s]; auto& buf = bufs[s]; @@ -465,7 +493,9 @@ static void ovo_streaming_csc_impl( compute_tie_corr, padded_grp_size, upper_skip_le); CUDA_CHECK_LAST_ERROR(ovo_fused_sort_rank_kernel); } else if (needs_tier3) { - int sb_grp_seg = n_sort_groups * sb_cols; + int sb_grp_seg = checked_int_product( + (size_t)n_sort_groups, (size_t)sb_cols, + "OVO device CSC active group segment count"); { int blk = (sb_grp_seg + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; build_tier3_seg_begin_end_offsets_kernel<<(1); auto* doff = reinterpret_cast(1); cub::DeviceSegmentedRadixSort::SortKeys( - nullptr, cub_ref_bytes, fk, fk, (int)sub_ref_items, sub_batch_cols, + nullptr, cub_ref_bytes, fk, fk, sub_ref_items_i32, sub_batch_cols, doff, doff + 1, BEGIN_BIT, END_BIT); } size_t cub_temp_bytes = cub_ref_bytes; if (needs_tier3) { size_t cub_grp_bytes = 0; - int max_grp_seg = n_sort_groups * sub_batch_cols; + int max_grp_seg = + checked_int_product((size_t)n_sort_groups, (size_t)sub_batch_cols, + "OVO host CSC group segment count"); auto* fk = reinterpret_cast(1); auto* doff = reinterpret_cast(1); cub::DeviceSegmentedRadixSort::SortKeys( - nullptr, cub_grp_bytes, fk, fk, (int)sub_grp_items, max_grp_seg, + nullptr, cub_grp_bytes, fk, fk, sub_grp_items_i32, max_grp_seg, doff, doff + 1, BEGIN_BIT, END_BIT); cub_temp_bytes = std::max(cub_ref_bytes, cub_grp_bytes); } @@ -82,8 +88,11 @@ static void ovo_streaming_csc_host_impl( int sb = std::min(sub_batch_cols, n_cols - col_start); IndptrT ptr_start = h_indptr[col_start]; int* off = &h_all_offsets[(size_t)b * (sub_batch_cols + 1)]; - for (int i = 0; i <= sb; i++) - off[i] = (int)(h_indptr[col_start + i] - ptr_start); + for (int i = 0; i <= sb; i++) { + off[i] = + checked_int_span((size_t)(h_indptr[col_start + i] - ptr_start), + "OVO host CSC rebased column offsets"); + } } int* d_all_offsets = pool.alloc((size_t)n_batches * (sub_batch_cols + 1)); @@ -159,7 +168,9 @@ static void ovo_streaming_csc_host_impl( compute_nnz ? (size_t)n_groups_stats * sub_batch_cols : 1); if (needs_tier3) { bufs[s].grp_sorted = pool.alloc(sub_grp_items); - int max_grp_seg = n_sort_groups * sub_batch_cols; + int max_grp_seg = checked_int_product( + (size_t)n_sort_groups, (size_t)sub_batch_cols, + "OVO host CSC stream group segment count"); bufs[s].grp_seg_offsets = pool.alloc(max_grp_seg); bufs[s].grp_seg_ends = pool.alloc(max_grp_seg); } else { @@ -186,8 +197,12 @@ static void ovo_streaming_csc_host_impl( int batch_idx = 0; while (col < n_cols) { int sb_cols = std::min(sub_batch_cols, n_cols - col); - int sb_ref_actual = n_ref * sb_cols; - int sb_grp_actual = n_all_grp * sb_cols; + int sb_ref_actual = + checked_int_product((size_t)n_ref, (size_t)sb_cols, + "OVO host CSC active reference sub-batch"); + int sb_grp_actual = + checked_int_product((size_t)n_all_grp, (size_t)sb_cols, + "OVO host CSC active group sub-batch"); int s = batch_idx % n_streams; auto stream = streams[s]; auto& buf = bufs[s]; @@ -196,6 +211,7 @@ static void ovo_streaming_csc_host_impl( IndptrT ptr_start = h_indptr[col]; IndptrT ptr_end = h_indptr[col + sb_cols]; size_t nnz = (size_t)(ptr_end - ptr_start); + checked_int_span(nnz, "OVO host CSC active batch nnz"); cudaMemcpyAsync(buf.d_sparse_data_orig, h_data + ptr_start, nnz * sizeof(InT), cudaMemcpyHostToDevice, stream); cudaMemcpyAsync(buf.d_sparse_indices, h_indices + ptr_start, @@ -276,7 +292,9 @@ static void ovo_streaming_csc_host_impl( compute_tie_corr, padded_grp_size, upper_skip_le); CUDA_CHECK_LAST_ERROR(ovo_fused_sort_rank_kernel); } else if (needs_tier3) { - int sb_grp_seg = n_sort_groups * sb_cols; + int sb_grp_seg = + checked_int_product((size_t)n_sort_groups, (size_t)sb_cols, + "OVO host CSC active group segment count"); { int blk = (sb_grp_seg + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; build_tier3_seg_begin_end_offsets_kernel<< (size_t)std::numeric_limits::max()) { + throw std::runtime_error( + "OVO host CSR reference row exceeds int32 compacted nnz limit"); + } + int nnz_i = (int)row_nnz; + if ((size_t)h_ref_indptr_compact[i] + (size_t)nnz_i > + (size_t)std::numeric_limits::max()) { + throw std::runtime_error( + "OVO host CSR reference compacted nnz exceeds int32 limit"); + } h_ref_indptr_compact[i + 1] = h_ref_indptr_compact[i] + nnz_i; } int ref_nnz = h_ref_indptr_compact[n_ref]; @@ -462,8 +490,11 @@ static void ovo_streaming_csr_host_impl( if (pk.n_rows > max_pack_rows) max_pack_rows = pk.n_rows; if (pk.nnz > max_pack_nnz) max_pack_nnz = pk.nnz; if (K > max_pack_K) max_pack_K = K; - int pack_items = pk.n_rows * pk.sb_cols; + int pack_items = + checked_int_product((size_t)pk.n_rows, (size_t)pk.sb_cols, + "OVO host CSR pack dense slab"); if (pack_items > max_pack_items) max_pack_items = pack_items; + checked_int_span(pk.nnz, "OVO host CSR pack compacted nnz"); if (pk.sb_cols > max_pack_sb_cols) max_pack_sb_cols = pk.sb_cols; } int max_group_rows = max_pack_rows; @@ -530,12 +561,37 @@ static void ovo_streaming_csr_host_impl( cudaMemcpyHostToDevice); // ---- Phase 1: Ref setup (scoped scratch, ref_sorted persists) ---- - float* d_ref_sorted = pool.alloc((size_t)n_ref * n_cols); + size_t ref_items = (size_t)n_ref * (size_t)n_cols; + if (n_ref > 0 && (size_t)n_cols > (size_t)std::numeric_limits::max() / + (size_t)n_ref) { + throw std::runtime_error( + "OVO host CSR dense reference cache exceeds CUB int item limit; " + "use native CSC/device sparse input or reduce genes/reference " + "size"); + } + if (ref_items > std::numeric_limits::max() / (2 * sizeof(float))) { + throw std::runtime_error( + "OVO host CSR dense reference cache size overflows size_t"); + } + size_t free_bytes = 0; + size_t total_bytes = 0; + if (cudaMemGetInfo(&free_bytes, &total_bytes) == cudaSuccess && + total_bytes > 0 && ref_items * 2 * sizeof(float) > total_bytes) { + throw std::runtime_error( + "OVO host CSR dense reference cache requires more GPU memory than " + "the device provides; use native CSC/device sparse input or reduce " + "genes/reference size"); + } + int ref_items_i32 = + checked_cub_items(ref_items, "OVO host CSR dense reference cache"); + float* d_ref_sorted = pool.alloc(ref_items); + cudaStream_t ref_stream; + cudaStreamCreateWithFlags(&ref_stream, cudaStreamNonBlocking); { ScopedCudaBuffer ref_data_f32_buf(ref_nnz * sizeof(float)); ScopedCudaBuffer ref_indices_buf(ref_nnz * sizeof(int)); ScopedCudaBuffer ref_indptr_buf((n_ref + 1) * sizeof(int)); - ScopedCudaBuffer ref_dense_buf((size_t)n_ref * n_cols * sizeof(float)); + ScopedCudaBuffer ref_dense_buf(ref_items * sizeof(float)); ScopedCudaBuffer ref_seg_buf((n_cols + 1) * sizeof(int)); float* d_ref_data_f32 = (float*)ref_data_f32_buf.data(); @@ -552,7 +608,7 @@ static void ovo_streaming_csr_host_impl( // pass over PCIe, no intermediate native-dtype GPU buffer. if (n_ref > 0 && ref_nnz > 0) { csr_gather_cast_accumulate_mapped_kernel - <<>>( + <<>>( d_data_zc, d_indices_zc, d_indptr_full, d_ref_row_ids, d_ref_indptr, /*d_stats_codes=*/nullptr, /*fixed_slot=*/n_test, d_ref_data_f32, d_ref_indices, @@ -562,12 +618,12 @@ static void ovo_streaming_csr_host_impl( } // Extract ref dense (F-order) from compacted CSR. - cudaMemsetAsync(d_ref_dense, 0, (size_t)n_ref * n_cols * sizeof(float)); + cudaMemsetAsync(d_ref_dense, 0, ref_items * sizeof(float), ref_stream); { csr_extract_dense_identity_rows_unsorted_kernel - <<>>(d_ref_data_f32, d_ref_indices, - d_ref_indptr, d_ref_dense, n_ref, - 0, n_cols); + <<>>( + d_ref_data_f32, d_ref_indices, d_ref_indptr, d_ref_dense, + n_ref, 0, n_cols); CUDA_CHECK_LAST_ERROR( csr_extract_dense_identity_rows_unsorted_kernel); } @@ -578,18 +634,18 @@ static void ovo_streaming_csr_host_impl( auto* fk = reinterpret_cast(1); auto* doff = reinterpret_cast(1); cub::DeviceSegmentedRadixSort::SortKeys( - nullptr, ref_cub_bytes, fk, fk, (int)((size_t)n_ref * n_cols), - n_cols, doff, doff + 1, BEGIN_BIT, END_BIT); + nullptr, ref_cub_bytes, fk, fk, ref_items_i32, n_cols, doff, + doff + 1, BEGIN_BIT, END_BIT); } ScopedCudaBuffer cub_temp_buf(ref_cub_bytes); - upload_linear_offsets(d_ref_seg, n_cols, n_ref, 0); + upload_linear_offsets(d_ref_seg, n_cols, n_ref, ref_stream); size_t temp = ref_cub_bytes; cub::DeviceSegmentedRadixSort::SortKeys( - cub_temp_buf.data(), temp, d_ref_dense, d_ref_sorted, - (int)((size_t)n_ref * n_cols), n_cols, d_ref_seg, d_ref_seg + 1, - BEGIN_BIT, END_BIT); - cudaDeviceSynchronize(); + cub_temp_buf.data(), temp, d_ref_dense, d_ref_sorted, ref_items_i32, + n_cols, d_ref_seg, d_ref_seg + 1, BEGIN_BIT, END_BIT, ref_stream); + cudaStreamSynchronize(ref_stream); } // ref scratch drops here + cudaStreamDestroy(ref_stream); // ---- Phase 2: Per-pack streaming ---- auto t1 = make_tier1_config(h_grp_offsets, n_test); @@ -604,11 +660,15 @@ static void ovo_streaming_csr_host_impl( size_t cub_grp_bytes = 0; if (may_need_cub && max_sub_items > 0) { + int max_sub_items_i32 = + checked_cub_items(max_sub_items, "OVO host CSR group pack"); auto* fk = reinterpret_cast(1); auto* doff = reinterpret_cast(1); - int max_segments = max_pack_K * max_pack_sb_cols; + int max_segments = + checked_int_product((size_t)max_pack_K, (size_t)max_pack_sb_cols, + "OVO host CSR max group segment count"); cub::DeviceSegmentedRadixSort::SortKeys( - nullptr, cub_grp_bytes, fk, fk, (int)max_sub_items, max_segments, + nullptr, cub_grp_bytes, fk, fk, max_sub_items_i32, max_segments, doff, doff + 1, BEGIN_BIT, END_BIT); } @@ -632,7 +692,9 @@ static void ovo_streaming_csr_host_impl( double* d_tie_corr; }; std::vector bufs(n_streams); - int max_pack_kernel_seg = max_pack_K * max_pack_sb_cols; + int max_pack_kernel_seg = + checked_int_product((size_t)max_pack_K, (size_t)max_pack_sb_cols, + "OVO host CSR pack segment buffer"); for (int s = 0; s < n_streams; s++) { bufs[s].d_grp_data_f32 = pool.alloc(max_pack_nnz); bufs[s].d_grp_indices = pool.alloc(max_pack_nnz); @@ -660,8 +722,6 @@ static void ovo_streaming_csr_host_impl( } } - cudaDeviceSynchronize(); // ensure Phase 1 done before Phase 2 streams - for (int p = 0; p < (int)packs.size(); p++) { const Pack& pack = packs[p]; int K = pack.end - pack.first; @@ -742,7 +802,9 @@ static void ovo_streaming_csr_host_impl( int col = 0; while (col < n_cols) { int sb_cols = std::min(pack_sb, n_cols - col); - int sb_items = pack_rows * sb_cols; + int sb_items = + checked_int_product((size_t)pack_rows, (size_t)sb_cols, + "OVO host CSR active group sub-batch"); cudaMemsetAsync(buf.d_grp_dense, 0, sb_items * sizeof(float), stream); @@ -798,7 +860,9 @@ static void ovo_streaming_csr_host_impl( upper_skip_le); CUDA_CHECK_LAST_ERROR(ovo_fused_sort_rank_kernel); } else if (pack_has_above_t2) { - int n_seg = pack_n_sort_groups * sb_cols; + int n_seg = checked_int_product( + (size_t)pack_n_sort_groups, (size_t)sb_cols, + "OVO host CSR active group segment count"); { int blk = (n_seg + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; build_tier3_seg_begin_end_offsets_kernel<<< diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh index 006002b9..2323e27f 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh @@ -4,7 +4,7 @@ template __global__ void csr_col_histogram_kernel(const IndexT* __restrict__ indices, const IndptrT* __restrict__ indptr, - int* __restrict__ col_counts, + unsigned int* __restrict__ col_counts, int n_rows, int n_cols) { int row = blockIdx.x * blockDim.x + threadIdx.x; if (row >= n_rows) return; @@ -12,7 +12,7 @@ __global__ void csr_col_histogram_kernel(const IndexT* __restrict__ indices, IndptrT re = indptr[row + 1]; for (IndptrT p = rs; p < re; ++p) { int c = (int)indices[p]; - if (c < n_cols) atomicAdd(&col_counts[c], 1); + if (c < n_cols) atomicAdd(&col_counts[c], 1u); } } @@ -49,24 +49,9 @@ __global__ void csr_scatter_to_csc_kernel( } } -/** - * Decide whether to use shared or global memory for OVR rank accumulators. - * Returns the smem size to request and sets use_gmem accordingly. - */ -static int query_max_smem_per_block() { - static int cached = -1; - if (cached < 0) { - int device; - cudaGetDevice(&device); - cudaDeviceGetAttribute(&cached, cudaDevAttrMaxSharedMemoryPerBlock, - device); - } - return cached; -} - static size_t ovr_smem_config(int n_groups, bool& use_gmem) { size_t need = (size_t)(n_groups + 32) * sizeof(double); - if ((int)need <= query_max_smem_per_block()) { + if (need <= wilcoxon_max_smem_per_block()) { use_gmem = false; return need; } @@ -81,7 +66,7 @@ static size_t ovr_smem_config(int n_groups, bool& use_gmem) { */ static size_t sparse_ovr_smem_config(int n_groups, bool& use_gmem) { size_t need = (size_t)(2 * n_groups + 32) * sizeof(double); - if ((int)need <= query_max_smem_per_block()) { + if (need <= wilcoxon_max_smem_per_block()) { use_gmem = false; return need; } diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh index 6eae2a28..257bbbb3 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh @@ -32,10 +32,12 @@ static void ovr_sparse_csc_host_streaming_impl( // CUB temp size for max_nnz items size_t cub_temp_bytes = 0; if (max_nnz > 0) { + int max_nnz_i32 = + checked_cub_items(max_nnz, "OVR host CSC sparse sub-batch nnz"); auto* fk = reinterpret_cast(1); auto* iv = reinterpret_cast(1); cub::DeviceSegmentedRadixSort::SortPairs( - nullptr, cub_temp_bytes, fk, fk, iv, iv, (int)max_nnz, + nullptr, cub_temp_bytes, fk, fk, iv, iv, max_nnz_i32, sub_batch_cols, iv, iv + 1, BEGIN_BIT, END_BIT); } @@ -98,8 +100,11 @@ static void ovr_sparse_csc_host_streaming_impl( int sb = std::min(sub_batch_cols, n_cols - col_start); IndptrT ptr_start = h_indptr[col_start]; int* off = &h_all_offsets[(size_t)b * (sub_batch_cols + 1)]; - for (int i = 0; i <= sb; i++) - off[i] = (int)(h_indptr[col_start + i] - ptr_start); + for (int i = 0; i <= sb; i++) { + off[i] = + checked_int_span((size_t)(h_indptr[col_start + i] - ptr_start), + "OVR host CSC rebased column offsets"); + } } int* d_all_offsets = pool.alloc((size_t)n_batches * (sub_batch_cols + 1)); @@ -143,7 +148,8 @@ static void ovr_sparse_csc_host_streaming_impl( IndptrT ptr_start = h_indptr[col]; IndptrT ptr_end = h_indptr[col + sb_cols]; - int batch_nnz = (int)(ptr_end - ptr_start); + int batch_nnz = checked_int_span((size_t)(ptr_end - ptr_start), + "OVR host CSC active batch nnz"); // H2D: transfer sparse data for this column range (native dtype) if (batch_nnz > 0) { @@ -263,7 +269,7 @@ static void ovr_sparse_csr_host_streaming_impl( size_t total_nnz = (size_t)h_indptr[n_rows]; // ---- Phase 0: CPU planning in native CSR order ---- - std::vector h_col_counts(n_cols, 0); + std::vector h_col_counts(n_cols, 0); for (int row = 0; row < n_rows; row++) { IndptrT rs = h_indptr[row]; IndptrT re = h_indptr[row + 1]; @@ -282,7 +288,9 @@ static void ovr_sparse_csr_host_streaming_impl( int sb_cols = std::min(sub_batch_cols, n_cols - col_start); int* off = &h_all_offsets[(size_t)b * (sub_batch_cols + 1)]; for (int i = 0; i < sb_cols; i++) - off[i + 1] = off[i] + h_col_counts[col_start + i]; + off[i + 1] = checked_int_span( + (size_t)off[i] + (size_t)h_col_counts[col_start + i], + "OVR host CSR rebased column offsets"); h_batch_nnz[b] = (size_t)off[sb_cols]; if (h_batch_nnz[b] > max_batch_nnz) max_batch_nnz = h_batch_nnz[b]; } @@ -295,10 +303,12 @@ static void ovr_sparse_csr_host_streaming_impl( // ---- Phase 1: allocate per-stream bounded work buffers ---- size_t cub_temp_bytes = 0; if (max_batch_nnz > 0) { + int max_batch_nnz_i32 = checked_cub_items( + max_batch_nnz, "OVR host CSR sparse sub-batch nnz"); auto* fk = reinterpret_cast(1); auto* iv = reinterpret_cast(1); cub::DeviceSegmentedRadixSort::SortPairs( - nullptr, cub_temp_bytes, fk, fk, iv, iv, (int)max_batch_nnz, + nullptr, cub_temp_bytes, fk, fk, iv, iv, max_batch_nnz_i32, sub_batch_cols, iv, iv + 1, BEGIN_BIT, END_BIT); } @@ -427,7 +437,8 @@ static void ovr_sparse_csr_host_streaming_impl( int s = b % n_streams; auto stream = streams[s]; auto& buf = bufs[s]; - int batch_nnz = (int)h_batch_nnz[b]; + int batch_nnz = + checked_int_span(h_batch_nnz[b], "OVR host CSR active batch nnz"); int* src = d_all_offsets + (size_t)b * (sub_batch_cols + 1); cudaMemcpyAsync(buf.col_offsets, src, (sb_cols + 1) * sizeof(int), @@ -546,10 +557,12 @@ static void ovr_sparse_csc_streaming_impl( // CUB temp size for max_nnz items size_t cub_temp_bytes = 0; if (max_nnz > 0) { + int max_nnz_i32 = + checked_cub_items(max_nnz, "OVR device CSC sparse sub-batch nnz"); auto* fk = reinterpret_cast(1); auto* iv = reinterpret_cast(1); cub::DeviceSegmentedRadixSort::SortPairs( - nullptr, cub_temp_bytes, fk, fk, iv, iv, (int)max_nnz, + nullptr, cub_temp_bytes, fk, fk, iv, iv, max_nnz_i32, sub_batch_cols, iv, iv + 1, BEGIN_BIT, END_BIT); } @@ -597,7 +610,8 @@ static void ovr_sparse_csc_streaming_impl( int ptr_start = h_indptr[col]; int ptr_end = h_indptr[col + sb_cols]; - int batch_nnz = ptr_end - ptr_start; + int batch_nnz = checked_int_span((size_t)(ptr_end - ptr_start), + "OVR device CSC active batch nnz"); // Compute rebased segment offsets on GPU (avoids host pinned-buffer // race) @@ -684,16 +698,16 @@ static void ovr_sparse_csr_streaming_impl( // ---- Phase 0: Planning — count nnz per column via histogram ---- RmmScratchPool pool; - int* d_col_counts = pool.alloc(n_cols); - cudaMemset(d_col_counts, 0, n_cols * sizeof(int)); + unsigned int* d_col_counts = pool.alloc(n_cols); + cudaMemset(d_col_counts, 0, n_cols * sizeof(unsigned int)); { int blocks = (n_rows + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; csr_col_histogram_kernel<<>>( csr_indices, csr_indptr, d_col_counts, n_rows, n_cols); CUDA_CHECK_LAST_ERROR(csr_col_histogram_kernel); } - std::vector h_col_counts(n_cols); - cudaMemcpy(h_col_counts.data(), d_col_counts, n_cols * sizeof(int), + std::vector h_col_counts(n_cols); + cudaMemcpy(h_col_counts.data(), d_col_counts, n_cols * sizeof(unsigned int), cudaMemcpyDeviceToHost); // Per-batch prefix sums on host @@ -710,7 +724,9 @@ static void ovr_sparse_csr_streaming_impl( int* off = &h_all_offsets[(size_t)b * (sub_batch_cols + 1)]; off[0] = 0; for (int i = 0; i < sb_cols; i++) - off[i + 1] = off[i] + h_col_counts[col_start + i]; + off[i + 1] = checked_int_span( + (size_t)off[i] + (size_t)h_col_counts[col_start + i], + "OVR device CSR rebased column offsets"); h_batch_nnz[b] = (size_t)off[sb_cols]; if (h_batch_nnz[b] > max_batch_nnz) max_batch_nnz = h_batch_nnz[b]; } @@ -724,10 +740,12 @@ static void ovr_sparse_csr_streaming_impl( // ---- Phase 1: Allocate per-stream buffers ---- size_t cub_temp_bytes = 0; if (max_batch_nnz > 0) { + int max_batch_nnz_i32 = checked_cub_items( + max_batch_nnz, "OVR device CSR sparse sub-batch nnz"); auto* fk = reinterpret_cast(1); auto* iv = reinterpret_cast(1); cub::DeviceSegmentedRadixSort::SortPairs( - nullptr, cub_temp_bytes, fk, fk, iv, iv, (int)max_batch_nnz, + nullptr, cub_temp_bytes, fk, fk, iv, iv, max_batch_nnz_i32, sub_batch_cols, iv, iv + 1, BEGIN_BIT, END_BIT); } @@ -796,7 +814,8 @@ static void ovr_sparse_csr_streaming_impl( int s = b % n_streams; auto stream = streams[s]; auto& buf = bufs[s]; - int batch_nnz = (int)h_batch_nnz[b]; + int batch_nnz = + checked_int_span(h_batch_nnz[b], "OVR device CSR active batch nnz"); // D2D copy pre-computed col_offsets for this batch int* src = d_all_offsets + (size_t)b * (sub_batch_cols + 1); diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_rmm.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_rmm.cu index 26e37f42..94a101e9 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_rmm.cu +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_rmm.cu @@ -10,8 +10,8 @@ void* wilcoxon_rmm_allocate(size_t bytes) { return rmm::mr::get_current_device_resource()->allocate_sync(bytes); } catch (std::exception const& e) { throw std::runtime_error( - std::string("RMM allocation failed in Wilcoxon scratch: ") + - e.what()); + std::string("RMM allocation failed in Wilcoxon scratch (") + + std::to_string(bytes) + " bytes): " + e.what()); } } diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh index d30f92cc..efdac894 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh @@ -25,6 +25,10 @@ * grp_nz_count[n_groups] nonzero-per-group counters * warp_buf[32] tie-correction reduction scratch * + * n_rows is the ranking population, including rows whose group code is the + * n_groups sentinel. Sentinel rows contribute to the "rest" distribution and + * tie-correction denominator but do not receive rank-sum accumulation. + * * Grid: (sb_cols,) Block: (tpb,) */ template @@ -223,28 +227,11 @@ __global__ void rank_sums_sparse_ovr_kernel( } } -/** - * Decide whether the host cast+stats kernels can use per-block shared memory - * accumulators. Large group counts exceed the dynamic smem launch limit, so - * those cases fall back to direct global-memory atomics after zeroing the - * per-stream output buffers. - */ -static int wilcoxon_cast_max_smem_per_block() { - static int cached = -1; - if (cached < 0) { - int device; - cudaGetDevice(&device); - cudaDeviceGetAttribute(&cached, cudaDevAttrMaxSharedMemoryPerBlock, - device); - } - return cached; -} - static size_t cast_accumulate_smem_config(int n_groups, bool compute_sq_sums, bool compute_nnz, bool& use_gmem) { int n_arrays = 1 + (compute_sq_sums ? 1 : 0) + (compute_nnz ? 1 : 0); size_t need = (size_t)n_arrays * n_groups * sizeof(double); - if (need <= (size_t)wilcoxon_cast_max_smem_per_block()) { + if (need <= wilcoxon_max_smem_per_block()) { use_gmem = false; return need; } diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py b/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py index d399a301..a204d73e 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py @@ -21,100 +21,31 @@ ] -class _LazyRankGenesColumn: - def __init__( - self, - values: np.ndarray | None = None, - *, - var_names: np.ndarray | None = None, - gene_indices: np.ndarray | None = None, - dtype: str | np.dtype, - ) -> None: - self._values = values - self._var_names = var_names - self._gene_indices = gene_indices - self._dtype = np.dtype(dtype) - - def __len__(self) -> int: - if self._values is not None: - return int(self._values.shape[0]) - return int(self._gene_indices.shape[0]) - - def __getitem__(self, key): - if self._values is not None: - return self._values[key] - return self._var_names[self._gene_indices[key]] - - def __iter__(self): - for idx in range(len(self)): - yield self[idx] - - def __array__(self, dtype=None, copy=None) -> np.ndarray: - if self._values is not None: - arr = np.asarray(self._values, dtype=self._dtype) - else: - arr = np.asarray(self._var_names[self._gene_indices], dtype=self._dtype) - if dtype is not None: - arr = np.asarray(arr, dtype=dtype) - if copy: - arr = arr.copy() - return arr - - -class _LazyRankGenesRecords(dict): - def __init__( - self, group_names: np.ndarray, columns: dict[str, object], dtype: str | np.dtype - ) -> None: - super().__init__(columns) - self._group_names = tuple(str(name) for name in group_names) - self._dtype = np.dtype([(name, np.dtype(dtype)) for name in self._group_names]) - - @property - def dtype(self) -> np.dtype: - return self._dtype - - def __getitem__(self, key): - if isinstance(key, str): - return super().__getitem__(key) - return np.asarray(self)[key] - - def __array__(self, dtype=None, copy=None) -> np.ndarray: - out = np.empty(len(next(iter(self.values()))) if self else 0, dtype=self._dtype) - for name in self._group_names: - out[name] = np.asarray(super().__getitem__(name)) - if dtype is not None: - out = np.asarray(out, dtype=dtype) - if copy: - out = out.copy() - return out - - def copy(self) -> np.ndarray: - return np.asarray(self).copy() - - -def _array_result_to_lazy_records( +def _array_result_to_records( arrays: dict[str, object], field: str, dtype: str | np.dtype -) -> _LazyRankGenesRecords: - group_names = arrays["group_names"] - values = arrays[field] - columns = { - str(group_name): _LazyRankGenesColumn(values[row], dtype=dtype) - for row, group_name in enumerate(group_names) - } - return _LazyRankGenesRecords(group_names, columns, dtype) - - -def _array_result_to_lazy_names(arrays: dict[str, object]) -> _LazyRankGenesRecords: - group_names = arrays["group_names"] - var_names = arrays["var_names"] - gene_indices = arrays["gene_indices"] - columns = { - str(group_name): _LazyRankGenesColumn( - var_names=var_names, gene_indices=gene_indices[row], dtype=object - ) - for row, group_name in enumerate(group_names) - } - return _LazyRankGenesRecords(group_names, columns, object) +) -> np.ndarray: + group_names = tuple(str(name) for name in arrays["group_names"]) + values = np.asarray(arrays[field]) + out = np.empty( + values.shape[1], + dtype=[(group_name, np.dtype(dtype)) for group_name in group_names], + ) + for row, group_name in enumerate(group_names): + out[group_name] = values[row] + return out + + +def _array_result_to_names(arrays: dict[str, object]) -> np.ndarray: + group_names = tuple(str(name) for name in arrays["group_names"]) + var_names = np.asarray(arrays["var_names"]) + gene_indices = np.asarray(arrays["gene_indices"], dtype=np.intp) + out = np.empty( + gene_indices.shape[1], + dtype=[(group_name, object) for group_name in group_names], + ) + for row, group_name in enumerate(group_names): + out[group_name] = var_names[gene_indices[row]] + return out def rank_genes_groups( @@ -146,8 +77,8 @@ def rank_genes_groups( Rank genes for characterizing groups using GPU acceleration. Expects nonnegative expression data. Log1p/log-normalized data is expected - for biologically meaningful log fold changes; sparse inputs with explicit - negative values are rejected. + for biologically meaningful log fold changes; negative values are rejected + for eager in-memory inputs. .. note:: **Dask support:** `'t-test'`, `'t-test_overestim_var'`, and @@ -235,10 +166,8 @@ def rank_genes_groups( Returns ------- - Updates `adata` with the following fields. Rank result fields are lazy - Scanpy-compatible record objects: group fields can be indexed like - structured arrays, while full structured arrays are materialized only when - requested through NumPy conversion or `.copy()`. + Updates `adata` with the following fields. Rank result fields are + Scanpy-compatible structured arrays. `adata.uns['rank_genes_groups' | key_added]['names']` Structured array to be indexed by group id storing the gene @@ -269,7 +198,7 @@ def rank_genes_groups( if "return_format" in kwds: msg = ( "return_format has been removed; rank_genes_groups always writes " - "lazy Scanpy-compatible results to adata.uns." + "Scanpy-compatible structured results to adata.uns." ) raise TypeError(msg) @@ -357,23 +286,15 @@ def rank_genes_groups( arrays = test_obj.stats_arrays or {} adata.uns[key_added] = {"params": params} if arrays and len(arrays.get("group_names", ())) > 0: - adata.uns[key_added]["names"] = _array_result_to_lazy_names(arrays) - for col, dtype in { - "scores": "float32", - "logfoldchanges": "float32", - "pvals": "float64", - "pvals_adj": "float64", - }.items(): + adata.uns[key_added]["names"] = _array_result_to_names(arrays) + for col in ("scores", "logfoldchanges", "pvals", "pvals_adj"): if col in arrays: values = arrays[col] - if hasattr(values, "dtype"): - dtype = values.dtype - adata.uns[key_added][col] = _array_result_to_lazy_records( - arrays, col, dtype - ) + dtype = values.dtype + adata.uns[key_added][col] = _array_result_to_records(arrays, col, dtype) + groups_names = [str(name) for name in test_obj.groups_order] if test_obj.pts is not None: - groups_names = [str(name) for name in test_obj.groups_order] adata.uns[key_added]["pts"] = pd.DataFrame( test_obj.pts.T, index=test_obj.var_names, columns=groups_names ) diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_core.py b/src/rapids_singlecell/tools/_rank_genes_groups/_core.py index acfbe2e2..af91e4d5 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_core.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_core.py @@ -35,6 +35,41 @@ """, "fdr_bh_reverse_cummin", ) +_GROUP_CHUNK_STATS_KERNEL = cp.RawKernel( + r""" +extern "C" __global__ void group_chunk_stats( + const double* block, + const int* group_codes, + double* group_sums, + double* group_sum_sq, + double* group_nnz, + const int n_rows, + const int n_cols, + const int n_groups, + const bool compute_nnz +) { + const long long idx = blockIdx.x * blockDim.x + threadIdx.x; + const long long total = static_cast(n_rows) * n_cols; + if (idx >= total) { + return; + } + const int row = idx % n_rows; + const int col = idx / n_rows; + const int group = group_codes[row]; + if (group < 0 || group >= n_groups) { + return; + } + const double value = block[idx]; + const long long out = static_cast(group) * n_cols + col; + atomicAdd(group_sums + out, value); + atomicAdd(group_sum_sq + out, value * value); + if (compute_nnz && value != 0.0) { + atomicAdd(group_nnz + out, 1.0); + } +} +""", + "group_chunk_stats", +) _RANK_SORT_MIN_ELEMENTS = 1_000_000 _RANK_SORT_MAX_WORKERS = 64 @@ -258,7 +293,7 @@ def _accumulate_chunk_stats_vs_rest( start: int, stop: int, *, - group_matrix: cp.ndarray, + group_codes_dev: cp.ndarray, group_sizes_dev: cp.ndarray, n_cells: int, ) -> None: @@ -268,9 +303,31 @@ def _accumulate_chunk_stats_vs_rest( rest_sizes = n_cells - group_sizes_dev - # Group sums and sum of squares - group_sums = group_matrix.T @ block - group_sum_sq = group_matrix.T @ (block**2) + n_groups = len(self.groups_order) + n_cols = stop - start + group_sums = cp.zeros((n_groups, n_cols), dtype=cp.float64) + group_sum_sq = cp.zeros((n_groups, n_cols), dtype=cp.float64) + group_nnz = ( + cp.zeros((n_groups, n_cols), dtype=cp.float64) if self.comp_pts else None + ) + n_items = n_cells * n_cols + threads = 256 + blocks = (n_items + threads - 1) // threads + _GROUP_CHUNK_STATS_KERNEL( + (blocks,), + (threads,), + ( + block, + group_codes_dev, + group_sums, + group_sum_sq, + group_nnz if group_nnz is not None else group_sums, + np.int32(n_cells), + np.int32(n_cols), + np.int32(n_groups), + self.comp_pts, + ), + ) # Means chunk_means = group_sums / group_sizes_dev[:, None] @@ -283,7 +340,6 @@ def _accumulate_chunk_stats_vs_rest( # Pts (fraction expressing) if self.comp_pts: - group_nnz = group_matrix.T @ (block != 0).astype(cp.float64) self.pts[:, start:stop] = cp.asnumpy(group_nnz / group_sizes_dev[:, None]) # Rest statistics @@ -439,7 +495,7 @@ def compute_statistics( raise ValueError(msg) self._score_dtype = np.dtype(np.float64 if return_u_values else np.float32) self._wilcoxon_gpu_result = None - self._store_wilcoxon_gpu_result = n_genes_user is not None + self._store_wilcoxon_gpu_result = True try: test_results = self.wilcoxon( tie_correct=tie_correct, diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py b/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py index 4ec37e40..de91e25d 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py @@ -16,32 +16,47 @@ EPS = 1e-9 WARP_SIZE = 32 MAX_THREADS_PER_BLOCK = 512 +MIN_GROUP_SIZE_WARNING = 25 + + +def _nonnegative_error(prefix: str) -> ValueError: + msg = ( + f"{prefix} contains negative values. rank_genes_groups expects " + "nonnegative expression values; use raw counts or log1p/log-normalized " + "expression, not scaled or centered data." + ) + return ValueError(msg) def _check_sparse_nonnegative(X) -> None: - """Reject sparse matrices with explicit negative values. + """Reject inputs with negative values where an eager check is cheap. Sparse rank_genes_groups code treats missing entries as true expression zeros. Optimized sparse Wilcoxon paths may rank explicit nonzeros and add implicit zeros analytically, which is only valid when explicit sparse values are nonnegative expression values. """ + dtype = None + if sp.issparse(X) or cpsp.issparse(X): + dtype = np.dtype(X.data.dtype) + elif isinstance(X, np.ndarray | cp.ndarray): + dtype = np.dtype(X.dtype) + if dtype is not None and dtype.kind == "c": + msg = "rank_genes_groups does not support complex expression values." + raise TypeError(msg) + if sp.issparse(X): if X.nnz > 0 and float(X.data.min()) < 0: - msg = ( - "Sparse input contains negative values. rank_genes_groups " - "expects nonnegative expression values; use raw counts or " - "log1p/log-normalized expression, not scaled or centered data." - ) - raise ValueError(msg) + raise _nonnegative_error("Sparse input") elif cpsp.issparse(X): if X.nnz > 0 and float(X.data.min()) < 0: - msg = ( - "Sparse input contains negative values. rank_genes_groups " - "expects nonnegative expression values; use raw counts or " - "log1p/log-normalized expression, not scaled or centered data." - ) - raise ValueError(msg) + raise _nonnegative_error("Sparse input") + elif isinstance(X, np.ndarray): + if X.size > 0 and float(np.nanmin(X)) < 0: + raise _nonnegative_error("Dense input") + elif isinstance(X, cp.ndarray): + if X.size > 0 and float(cp.nanmin(X)) < 0: + raise _nonnegative_error("Dense input") def _select_groups( @@ -140,20 +155,6 @@ def _round_up_to_warp(n: int) -> int: return min(MAX_THREADS_PER_BLOCK, ((n + WARP_SIZE - 1) // WARP_SIZE) * WARP_SIZE) -def _select_top_n(scores: NDArray, n_top: int) -> NDArray: - """Select indices of top n scores. - - Uses argpartition + argsort for O(n + k log k) complexity where k = n_top. - This is faster than full sorting when k << n. - """ - n_from = scores.shape[0] - reference_indices = np.arange(n_from, dtype=int) - partition = np.argpartition(scores, -n_top)[-n_top:] - partial_indices = np.argsort(scores[partition])[::-1] - global_indices = reference_indices[partition][partial_indices] - return global_indices - - def _choose_chunk_size(requested: int | None) -> int: """Choose chunk size for gene processing.""" if requested is not None: diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py index b96cfee6..880da7e0 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py @@ -12,14 +12,13 @@ from rapids_singlecell._cuda import _wilcoxon_cuda as _wc from rapids_singlecell._cuda import _wilcoxon_sparse_cuda as _wcs -from ._utils import EPS, _choose_chunk_size, _get_column_block +from ._utils import EPS, MIN_GROUP_SIZE_WARNING, _choose_chunk_size, _get_column_block if TYPE_CHECKING: from numpy.typing import NDArray from ._core import _RankGenes -MIN_GROUP_SIZE_WARNING = 25 DEFAULT_WILCOXON_CHUNK_SIZE = 512 OVR_HOST_CSC_SUB_BATCH = 512 OVR_HOST_CSR_SUB_BATCH = 2048 @@ -29,10 +28,11 @@ OVO_DEVICE_SPARSE_SUB_BATCH = 128 OVR_DENSE_SUB_BATCH = 64 OVO_DENSE_TIERED_SUB_BATCH = 256 -DENSE_HOST_PRELOAD_MAX_GPU_FRACTION = 0.55 +DENSE_HOST_PRELOAD_MAX_GPU_FRACTION = 0.55 # leave headroom for rank buffers def _maybe_preload_host_dense(rg: _RankGenes) -> None: + """Preload moderate host-dense matrices to avoid repeated chunk transfers.""" X = rg.X if not isinstance(X, np.ndarray) or X.size == 0: return @@ -259,7 +259,20 @@ def _wilcoxon_scores( def _host_sparse_fn_and_arrays(module, base_name: str, X, *, support_idx64: bool): - is_f64 = X.data.dtype == np.float64 + data_dtype = np.dtype(X.data.dtype) + if data_dtype == np.float64: + is_f64 = True + data_arr = X.data + elif data_dtype == np.float32 or data_dtype.kind in {"b", "i", "u"}: + is_f64 = False + data_arr = X.data.astype(np.float32, copy=False) + else: + msg = ( + "Wilcoxon sparse input data dtype must be float32, float64, bool, " + f"or integer; got {data_dtype}." + ) + raise TypeError(msg) + is_idx64 = support_idx64 and X.indices.dtype == np.int64 is_i64 = X.indptr.dtype == np.int64 suffix = "" @@ -270,15 +283,33 @@ def _host_sparse_fn_and_arrays(module, base_name: str, X, *, support_idx64: bool if is_i64: suffix += "_i64" fn = getattr(module, base_name + suffix) - data_arr = X.data if is_f64 else X.data.astype(np.float32, copy=False) indices_arr = X.indices if is_idx64 else X.indices.astype(np.int32, copy=False) return fn, data_arr, indices_arr def _device_sparse_arrays_i32_f32(X): + data_dtype = np.dtype(X.data.dtype) + if data_dtype == np.float32 or data_dtype == np.float64: + pass + elif data_dtype.kind in {"b", "i", "u"}: + pass + else: + msg = ( + "Wilcoxon device sparse input data dtype must be float32, float64, " + f"bool, or integer; got {data_dtype}." + ) + raise TypeError(msg) + if X.indptr.dtype != cp.int32: max_indptr = int(cp.asnumpy(X.indptr[-1])) if max_indptr > np.iinfo(np.int32).max: + warnings.warn( + "Wilcoxon device sparse path requires int32 indptr for CUDA " + "kernels; falling back to the bounded dense chunk path because " + f"nnz={max_indptr} exceeds int32.", + RuntimeWarning, + stacklevel=3, + ) return None data = X.data.astype(cp.float32, copy=False) indices = X.indices.astype(cp.int32, copy=False) @@ -620,12 +651,6 @@ def _wilcoxon_vs_rest( return [(gi, scores_host[gi], p_host[gi]) for gi in range(n_groups)] group_codes_gpu = cp.asarray(rg.group_codes, dtype=cp.int32) - group_matrix = None - if rg._compute_stats_in_chunks: - codes_gpu = cp.asarray(rg.group_codes, dtype=cp.int64) - group_matrix = cp.zeros((n_cells, n_groups), dtype=cp.float64) - valid_idx = cp.where(codes_gpu < n_groups)[0] - group_matrix[valid_idx, codes_gpu[valid_idx]] = 1.0 group_sizes_dev = cp.asarray(group_sizes, dtype=cp.float64) rest_sizes = n_cells - group_sizes_dev @@ -645,7 +670,7 @@ def _wilcoxon_vs_rest( block, start, stop, - group_matrix=group_matrix, + group_codes_dev=group_codes_gpu, group_sizes_dev=group_sizes_dev, n_cells=n_cells, ) @@ -838,6 +863,8 @@ def _wilcoxon_with_reference( ) else: csr = X + # Host CSR gather scans each row's native index list and tolerates + # unsorted row indices; avoid a full CSR copy just to sort. csr_host_fn, data_arr, indices_arr = _host_sparse_fn_and_arrays( _wcs, "ovo_streaming_csr_host", csr, support_idx64=True ) diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon_binned.py b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon_binned.py index 70d049af..14793834 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon_binned.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon_binned.py @@ -11,6 +11,8 @@ from rapids_singlecell._compat import DaskArray from rapids_singlecell._cuda import _wilcoxon_binned_cuda as _wb +from ._utils import MIN_GROUP_SIZE_WARNING + if TYPE_CHECKING: from numpy.typing import NDArray @@ -159,7 +161,7 @@ def wilcoxon_binned( ): if gi == ireference: continue - if size <= 25 or n_ref <= 25: + if size <= MIN_GROUP_SIZE_WARNING or n_ref <= MIN_GROUP_SIZE_WARNING: warnings.warn( f"Group {name} has size {size} (reference {n_ref}); normal " "approximation of the Wilcoxon statistic may be inaccurate.", @@ -169,7 +171,7 @@ def wilcoxon_binned( else: for name, size in zip(rg.groups_order, group_sizes, strict=True): rest = n_cells - size - if size <= 25 or rest <= 25: + if size <= MIN_GROUP_SIZE_WARNING or rest <= MIN_GROUP_SIZE_WARNING: warnings.warn( f"Group {name} has size {size} (rest {rest}); normal " "approximation of the Wilcoxon statistic may be inaccurate.", diff --git a/tests/test_rank_genes_groups_ttest.py b/tests/test_rank_genes_groups_ttest.py index 24a40721..719fb939 100644 --- a/tests/test_rank_genes_groups_ttest.py +++ b/tests/test_rank_genes_groups_ttest.py @@ -1,5 +1,6 @@ from __future__ import annotations +import anndata as ad import numpy as np import pytest import scanpy as sc @@ -10,6 +11,10 @@ import rapids_singlecell as rsc +def _make_nonnegative(adata): + adata.X = np.abs(adata.X) + + @pytest.mark.parametrize("reference", ["rest", "1"]) @pytest.mark.parametrize("method", ["t-test", "t-test_overestim_var"]) @pytest.mark.parametrize("sparse", [True, False]) @@ -18,12 +23,15 @@ def test_rank_genes_groups_ttest_matches_scanpy(reference, method, sparse): np.random.seed(42) adata_gpu = sc.datasets.blobs(n_variables=6, n_centers=3, n_observations=200) adata_gpu.obs["blobs"] = adata_gpu.obs["blobs"].astype("category") + _make_nonnegative(adata_gpu) if sparse: - adata_gpu.X = np.abs(adata_gpu.X).astype(np.float32) + adata_gpu.X = adata_gpu.X.astype(np.float32) adata_gpu.X = sp.csr_matrix(adata_gpu.X) adata_cpu = adata_gpu.copy() + if sparse: + adata_cpu.X = adata_cpu.X.astype(np.float64) rsc.tl.rank_genes_groups( adata_gpu, @@ -53,19 +61,12 @@ def test_rank_genes_groups_ttest_matches_scanpy(reference, method, sparse): for field in ("scores", "logfoldchanges", "pvals", "pvals_adj"): gpu_field = gpu_result[field] cpu_field = cpu_result[field] - rtol = 1e-6 if sparse else 1e-13 - if sparse and field in {"scores", "logfoldchanges"}: - atol = 1e-6 - elif sparse: - atol = 1e-12 - else: - atol = 1e-15 assert gpu_field.dtype.names == cpu_field.dtype.names for group in gpu_field.dtype.names: gpu_values = np.asarray(gpu_field[group], dtype=float) cpu_values = np.asarray(cpu_field[group], dtype=float) np.testing.assert_allclose( - gpu_values, cpu_values, rtol=rtol, atol=atol, equal_nan=True + gpu_values, cpu_values, rtol=1e-13, atol=1e-15, equal_nan=True ) params = gpu_result["params"] @@ -83,6 +84,7 @@ def test_rank_genes_groups_ttest_honors_layer_and_use_raw(reference, method): np.random.seed(42) base = sc.datasets.blobs(n_variables=5, n_centers=3, n_observations=150) base.obs["blobs"] = base.obs["blobs"].astype("category") + _make_nonnegative(base) base.layers["signal"] = base.X.copy() ref_adata = base.copy() @@ -131,6 +133,7 @@ def test_rank_genes_groups_ttest_subset_and_bonferroni(reference, method): np.random.seed(42) adata = sc.datasets.blobs(n_variables=5, n_centers=4, n_observations=150) adata.obs["blobs"] = adata.obs["blobs"].astype("category") + _make_nonnegative(adata) groups = ["0", "1", "2"] if reference != "rest" else ["0", "2"] @@ -169,6 +172,7 @@ def test_rank_genes_groups_ttest_with_renamed_categories( np.random.seed(42) adata = sc.datasets.blobs(n_variables=4, n_centers=3, n_observations=200) adata.obs["blobs"] = adata.obs["blobs"].astype("category") + _make_nonnegative(adata) # First run with original category names rsc.tl.rank_genes_groups(adata, "blobs", method=method, reference=reference_before) @@ -197,6 +201,7 @@ def test_rank_genes_groups_ttest_with_unsorted_groups(reference, method): np.random.seed(42) adata = sc.datasets.blobs(n_variables=6, n_centers=4, n_observations=180) adata.obs["blobs"] = adata.obs["blobs"].astype("category") + _make_nonnegative(adata) bdata = adata.copy() groups = ["0", "1", "2", "3"] if reference != "rest" else ["0", "2", "3"] @@ -236,6 +241,7 @@ def test_rank_genes_groups_ttest_pts(reference, method): np.random.seed(42) adata_gpu = sc.datasets.blobs(n_variables=6, n_centers=3, n_observations=200) adata_gpu.obs["blobs"] = adata_gpu.obs["blobs"].astype("category") + _make_nonnegative(adata_gpu) adata_cpu = adata_gpu.copy() # Run with pts=True @@ -297,8 +303,6 @@ def test_rank_genes_groups_ttest_direct_scipy(): Creates a simple two-group dataset and compares rapids_singlecell t-test directly against scipy.stats.ttest_ind without intermediate statistics. """ - import anndata as ad - np.random.seed(42) n_group1, n_group2, n_genes = 50, 60, 20 @@ -308,6 +312,9 @@ def test_rank_genes_groups_ttest_direct_scipy(): # Combine into AnnData X = np.vstack([X_group1, X_group2]) + X -= X.min() + X_group1 = X[:n_group1] + X_group2 = X[n_group1:] obs = {"group": ["A"] * n_group1 + ["B"] * n_group2} adata = ad.AnnData(X=X, obs=obs) adata.obs["group"] = adata.obs["group"].astype("category") @@ -350,6 +357,7 @@ def test_rank_genes_groups_ttest_matches_scipy(): adata = pbmc68k_reduced() # Convert to float64 for maximum precision in comparison adata.X = adata.X.astype(np.float64) + _make_nonnegative(adata) # Run rapids_singlecell t-test rsc.tl.rank_genes_groups(adata, "bulk_labels", method="t-test", use_raw=False) @@ -412,6 +420,7 @@ def test_rank_genes_groups_ttest_mask_var_array(method): np.random.seed(42) adata = sc.datasets.blobs(n_variables=10, n_centers=3, n_observations=150) adata.obs["blobs"] = adata.obs["blobs"].astype("category") + _make_nonnegative(adata) # Create mask to select only first 5 genes mask = np.array([True] * 5 + [False] * 5) @@ -439,6 +448,7 @@ def test_rank_genes_groups_ttest_mask_var_string(method): np.random.seed(42) adata = sc.datasets.blobs(n_variables=10, n_centers=3, n_observations=150) adata.obs["blobs"] = adata.obs["blobs"].astype("category") + _make_nonnegative(adata) # Add mask column to adata.var adata.var["highly_variable"] = [True] * 6 + [False] * 4 @@ -465,6 +475,7 @@ def test_rank_genes_groups_ttest_mask_var_matches_scanpy(method): np.random.seed(42) adata_gpu = sc.datasets.blobs(n_variables=8, n_centers=3, n_observations=150) adata_gpu.obs["blobs"] = adata_gpu.obs["blobs"].astype("category") + _make_nonnegative(adata_gpu) adata_cpu = adata_gpu.copy() mask = np.array([True, False, True, False, True, True, False, True]) @@ -497,6 +508,7 @@ def test_rank_genes_groups_ttest_rankby_abs(method): np.random.seed(42) adata = sc.datasets.blobs(n_variables=6, n_centers=3, n_observations=200) adata.obs["blobs"] = adata.obs["blobs"].astype("category") + _make_nonnegative(adata) adata_abs = adata.copy() # Run without rankby_abs @@ -524,6 +536,7 @@ def test_rank_genes_groups_ttest_key_added(method): np.random.seed(42) adata = sc.datasets.blobs(n_variables=6, n_centers=3, n_observations=200) adata.obs["blobs"] = adata.obs["blobs"].astype("category") + _make_nonnegative(adata) custom_key = "my_custom_key" diff --git a/tests/test_rank_genes_groups_wilcoxon.py b/tests/test_rank_genes_groups_wilcoxon.py index 6e3dbf89..29871ba0 100644 --- a/tests/test_rank_genes_groups_wilcoxon.py +++ b/tests/test_rank_genes_groups_wilcoxon.py @@ -60,13 +60,77 @@ def test_rank_genes_groups_sparse_negative_values_raise(method, fmt): rsc.tl.rank_genes_groups(adata, "group", method=method, use_raw=False) -def test_rank_genes_groups_default_lazy_get_df_matches_scanpy(): +@pytest.mark.parametrize( + "method", + ["t-test", "t-test_overestim_var", "wilcoxon", "wilcoxon_binned", "logreg"], +) +@pytest.mark.parametrize("fmt", ["numpy_dense", "cupy_dense"]) +def test_rank_genes_groups_dense_negative_values_raise(method, fmt): + X = np.array( + [ + [-1.0, 0.0, 2.0], + [0.0, 1.0, 0.0], + [2.0, 0.0, 1.0], + [0.0, 3.0, 0.0], + ], + dtype=np.float32, + ) + adata = sc.AnnData( + X=_to_format(X, fmt), + obs=pd.DataFrame( + {"group": pd.Categorical(["a", "a", "b", "b"], categories=["a", "b"])} + ), + var=pd.DataFrame(index=["g0", "g1", "g2"]), + ) + + with pytest.raises(ValueError, match="Dense input contains negative values"): + rsc.tl.rank_genes_groups(adata, "group", method=method, use_raw=False) + + +@pytest.mark.parametrize("fmt", ["numpy_dense", "scipy_csr", "cupy_dense", "cupy_csr"]) +def test_rank_genes_groups_complex_values_raise(fmt): + X = np.array( + [ + [1.0 + 0.0j, 0.0, 2.0], + [0.0, 1.0, 0.0], + [2.0, 0.0, 1.0], + [0.0, 3.0, 0.0], + ], + dtype=np.complex64, + ) + adata = sc.AnnData( + X=_to_format(X, fmt), + obs=pd.DataFrame( + {"group": pd.Categorical(["a", "a", "b", "b"], categories=["a", "b"])} + ), + var=pd.DataFrame(index=["g0", "g1", "g2"]), + ) + + with pytest.raises(TypeError, match="complex expression values"): + rsc.tl.rank_genes_groups(adata, "group", method="wilcoxon", use_raw=False) + + +def test_device_sparse_int64_indptr_overflow_warns(): + from rapids_singlecell.tools._rank_genes_groups._wilcoxon import ( + _device_sparse_arrays_i32_f32, + ) + + class FakeSparse: + data = cp.asarray([1.0], dtype=cp.float32) + indices = cp.asarray([0], dtype=cp.int32) + indptr = cp.asarray([0, np.iinfo(np.int32).max + 1], dtype=cp.int64) + + with pytest.warns(RuntimeWarning, match="requires int32 indptr"): + assert _device_sparse_arrays_i32_f32(FakeSparse()) is None + + +def test_rank_genes_groups_structured_results_get_df_and_h5ad_match_scanpy(tmp_path): np.random.seed(42) - adata_lazy = sc.datasets.blobs(n_variables=6, n_centers=3, n_observations=120) - _make_nonnegative(adata_lazy) - adata_lazy.obs["blobs"] = adata_lazy.obs["blobs"].astype("category") - adata_lazy.X = sp.csr_matrix(adata_lazy.X) - adata_cpu = adata_lazy.copy() + adata_rsc = sc.datasets.blobs(n_variables=6, n_centers=3, n_observations=120) + _make_nonnegative(adata_rsc) + adata_rsc.obs["blobs"] = adata_rsc.obs["blobs"].astype("category") + adata_rsc.X = sp.csr_matrix(adata_rsc.X) + adata_cpu = adata_rsc.copy() adata_cpu.X = adata_cpu.X.toarray() kw = { @@ -77,22 +141,27 @@ def test_rank_genes_groups_default_lazy_get_df_matches_scanpy(): "tie_correct": True, "n_genes": 4, } - rsc.tl.rank_genes_groups(adata_lazy, **kw) + rsc.tl.rank_genes_groups(adata_rsc, **kw) sc.tl.rank_genes_groups(adata_cpu, **kw) - lazy_result = adata_lazy.uns["rank_genes_groups"] - assert lazy_result["names"].dtype.names == ("0", "2") - assert tuple(lazy_result["names"][0]) == tuple( + rsc_result = adata_rsc.uns["rank_genes_groups"] + assert isinstance(rsc_result["names"], np.ndarray) + assert rsc_result["names"].dtype.names == ("0", "2") + assert tuple(rsc_result["names"][0]) == tuple( adata_cpu.uns["rank_genes_groups"]["names"][0] ) np.testing.assert_array_equal( - lazy_result["names"].copy(), - np.asarray(lazy_result["names"]), + rsc_result["names"].copy(), + np.asarray(rsc_result["names"]), ) - lazy_df = sc.get.rank_genes_groups_df(adata_lazy, group=None) + h5ad_path = tmp_path / "rank_genes_groups.h5ad" + adata_rsc.write_h5ad(h5ad_path) + adata_rsc = sc.read_h5ad(h5ad_path) + + rsc_df = sc.get.rank_genes_groups_df(adata_rsc, group=None) scanpy_df = sc.get.rank_genes_groups_df(adata_cpu, group=None) - pd.testing.assert_frame_equal(lazy_df, scanpy_df) + pd.testing.assert_frame_equal(rsc_df, scanpy_df) def test_rank_genes_groups_return_format_removed(): @@ -168,6 +237,60 @@ def test_rank_genes_groups_wilcoxon_return_u_values(reference, fmt): np.testing.assert_allclose(df["scores"].to_numpy(), expected_sorted) +def test_rank_genes_groups_wilcoxon_dense_edge_cases_match_scipy(): + X = np.array( + [ + [1.0, 5.0, 0.0, 2.0, 1.0], + [2.0, 5.0, 0.0, 2.0, 1.0], + [3.0, 5.0, 1.0, 2.0, 1.0], + [4.0, 5.0, 1.0, 3.0, 2.0], + [5.0, 5.0, 1.0, 3.0, 2.0], + [6.0, 5.0, 2.0, 3.0, 2.0], + [7.0, 5.0, 2.0, 4.0, 3.0], + [8.0, 5.0, 2.0, 4.0, 3.0], + ], + dtype=np.float32, + ) + labels = np.array(["a", "a", "a", "a", "b", "b", "b", "b"]) + adata = sc.AnnData( + X=X, + obs=pd.DataFrame({"group": pd.Categorical(labels)}), + var=pd.DataFrame(index=["no_ties", "all_ties", "zero_ties", "mixed", "pairs"]), + ) + rsc.tl.rank_genes_groups( + adata, + "group", + groups=["a"], + reference="b", + method="wilcoxon", + use_raw=False, + tie_correct=True, + use_continuity=True, + return_u_values=True, + n_genes=adata.n_vars, + ) + + df = sc.get.rank_genes_groups_df(adata, group="a").sort_values("names") + expected_u = {} + for idx, name in enumerate(adata.var_names): + result = mannwhitneyu( + X[labels == "a", idx], + X[labels == "b", idx], + alternative="two-sided", + method="asymptotic", + use_continuity=True, + ) + expected_u[name] = result.statistic + + np.testing.assert_allclose( + df["scores"].to_numpy(), + np.array([expected_u[name] for name in df["names"]]), + rtol=1e-13, + atol=1e-15, + ) + assert np.isfinite(df["pvals"]).all() + + def test_rank_genes_groups_return_u_values_requires_wilcoxon(): adata = sc.datasets.blobs(n_variables=3, n_centers=2, n_observations=20) _make_nonnegative(adata) @@ -190,10 +313,10 @@ def test_rank_genes_groups_wilcoxon_matches_scanpy(reference, tie_correct, spars """Test wilcoxon matches scanpy output across configurations.""" np.random.seed(42) adata_gpu = sc.datasets.blobs(n_variables=6, n_centers=3, n_observations=200) + _make_nonnegative(adata_gpu) adata_gpu.obs["blobs"] = adata_gpu.obs["blobs"].astype("category") if sparse: - _make_nonnegative(adata_gpu) adata_gpu.X = sp.csr_matrix(adata_gpu.X) adata_cpu = adata_gpu.copy() @@ -228,12 +351,12 @@ def test_rank_genes_groups_wilcoxon_matches_scanpy(reference, tie_correct, spars for field in ("scores", "logfoldchanges", "pvals", "pvals_adj"): gpu_field = gpu_result[field] cpu_field = cpu_result[field] - rtol = 1e-6 if field == "logfoldchanges" else 1e-13 + rtol = 1e-13 assert gpu_field.dtype.names == cpu_field.dtype.names for group in gpu_field.dtype.names: gpu_values = np.asarray(gpu_field[group], dtype=float) cpu_values = np.asarray(cpu_field[group], dtype=float) - atol = 1e-6 if field == "logfoldchanges" else 1e-15 + atol = 1e-15 np.testing.assert_allclose(gpu_values, cpu_values, rtol=rtol, atol=atol) params = gpu_result["params"] @@ -283,6 +406,7 @@ def test_rank_genes_groups_wilcoxon_honors_layer_and_use_raw(reference): """Test that layer parameter is respected.""" np.random.seed(42) base = sc.datasets.blobs(n_variables=5, n_centers=3, n_observations=150) + _make_nonnegative(base) base.obs["blobs"] = base.obs["blobs"].astype("category") base.layers["signal"] = base.X.copy() @@ -330,6 +454,7 @@ def test_rank_genes_groups_wilcoxon_subset_and_bonferroni(reference): """Test group subsetting and bonferroni correction.""" np.random.seed(42) adata = sc.datasets.blobs(n_variables=5, n_centers=4, n_observations=150) + _make_nonnegative(adata) adata.obs["blobs"] = adata.obs["blobs"].astype("category") groups = ["0", "1", "2"] if reference != "rest" else ["0", "2"] @@ -360,6 +485,7 @@ def test_rank_genes_groups_wilcoxon_subset_and_bonferroni(reference): def test_rank_genes_groups_wilcoxon_skip_empty_groups_filters_singletons(): np.random.seed(42) adata = sc.datasets.blobs(n_variables=5, n_centers=2, n_observations=21) + _make_nonnegative(adata) adata.obs["target"] = pd.Categorical( ["ref"] * 10 + ["valid"] * 10 + ["singleton"], categories=["ref", "valid", "singleton", "empty"], @@ -383,6 +509,7 @@ def test_rank_genes_groups_wilcoxon_skip_empty_groups_filters_singletons(): def test_rank_genes_groups_wilcoxon_skip_empty_groups_all_tests_filtered(): np.random.seed(42) adata = sc.datasets.blobs(n_variables=5, n_centers=2, n_observations=11) + _make_nonnegative(adata) adata.obs["target"] = pd.Categorical( ["ref"] * 10 + ["singleton"], categories=["ref", "singleton", "empty"], @@ -434,8 +561,8 @@ def test_wilcoxon_subset_rest_stats_match_scanpy(fmt): gpu_result = adata_gpu.uns["rank_genes_groups"] cpu_result = adata_cpu.uns["rank_genes_groups"] for field in ("scores", "logfoldchanges", "pvals", "pvals_adj"): - rtol = 1e-6 if field == "logfoldchanges" else 1e-13 - atol = 1e-6 if field == "logfoldchanges" else 1e-15 + rtol = 1e-13 + atol = 1e-15 for group in gpu_result[field].dtype.names: np.testing.assert_allclose( np.asarray(gpu_result[field][group], dtype=float), @@ -547,7 +674,8 @@ def test_wilcoxon_ovo_host_csr_unsorted_indices_match_sorted(): "cupy_csc", ], ) -def test_wilcoxon_all_public_formats_match_scanpy(reference, fmt): +@pytest.mark.parametrize("pre_load", [False, True]) +def test_wilcoxon_all_public_formats_match_scanpy(reference, fmt, pre_load): np.random.seed(42) adata_gpu = sc.datasets.blobs(n_variables=5, n_centers=3, n_observations=120) _make_nonnegative(adata_gpu) @@ -563,14 +691,14 @@ def test_wilcoxon_all_public_formats_match_scanpy(reference, fmt): "tie_correct": True, "n_genes": 5, } - rsc.tl.rank_genes_groups(adata_gpu, **kw) + rsc.tl.rank_genes_groups(adata_gpu, **kw, pre_load=pre_load) sc.tl.rank_genes_groups(adata_cpu, **kw) gpu_result = adata_gpu.uns["rank_genes_groups"] cpu_result = adata_cpu.uns["rank_genes_groups"] for field in ("scores", "logfoldchanges", "pvals", "pvals_adj"): - rtol = 1e-6 if field == "logfoldchanges" else 1e-13 - atol = 1e-6 if field == "logfoldchanges" else 1e-15 + rtol = 1e-13 + atol = 1e-15 for group in gpu_result[field].dtype.names: np.testing.assert_allclose( np.asarray(gpu_result[field][group], dtype=float), @@ -591,6 +719,7 @@ def test_rank_genes_groups_wilcoxon_with_renamed_categories( """Test with renamed category labels.""" np.random.seed(42) adata = sc.datasets.blobs(n_variables=4, n_centers=3, n_observations=200) + _make_nonnegative(adata) adata.obs["blobs"] = adata.obs["blobs"].astype("category") # First run with original category names @@ -622,6 +751,7 @@ def test_rank_genes_groups_wilcoxon_with_unsorted_groups(reference): """Test that group order doesn't affect results.""" np.random.seed(42) adata = sc.datasets.blobs(n_variables=6, n_centers=4, n_observations=180) + _make_nonnegative(adata) adata.obs["blobs"] = adata.obs["blobs"].astype("category") bdata = adata.copy() @@ -661,6 +791,7 @@ def test_rank_genes_groups_wilcoxon_pts(reference, pre_load): """Test that pts (fraction of cells expressing) is computed correctly.""" np.random.seed(42) adata_gpu = sc.datasets.blobs(n_variables=6, n_centers=3, n_observations=200) + _make_nonnegative(adata_gpu) adata_gpu.obs["blobs"] = adata_gpu.obs["blobs"].astype("category") adata_cpu = adata_gpu.copy() From 73cda5880f6f850486b1455d81af1a224f6ac715 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Wed, 13 May 2026 18:03:55 +0200 Subject: [PATCH 7/7] fix tests --- .../tools/_rank_genes_groups/_utils.py | 6 ----- tests/test_rank_genes_groups_wilcoxon.py | 27 ------------------- 2 files changed, 33 deletions(-) diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py b/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py index de91e25d..e9efbc50 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py @@ -51,12 +51,6 @@ def _check_sparse_nonnegative(X) -> None: elif cpsp.issparse(X): if X.nnz > 0 and float(X.data.min()) < 0: raise _nonnegative_error("Sparse input") - elif isinstance(X, np.ndarray): - if X.size > 0 and float(np.nanmin(X)) < 0: - raise _nonnegative_error("Dense input") - elif isinstance(X, cp.ndarray): - if X.size > 0 and float(cp.nanmin(X)) < 0: - raise _nonnegative_error("Dense input") def _select_groups( diff --git a/tests/test_rank_genes_groups_wilcoxon.py b/tests/test_rank_genes_groups_wilcoxon.py index 5bef924b..af39da54 100644 --- a/tests/test_rank_genes_groups_wilcoxon.py +++ b/tests/test_rank_genes_groups_wilcoxon.py @@ -60,33 +60,6 @@ def test_rank_genes_groups_sparse_negative_values_raise(method, fmt): rsc.tl.rank_genes_groups(adata, "group", method=method, use_raw=False) -@pytest.mark.parametrize( - "method", - ["t-test", "t-test_overestim_var", "wilcoxon", "wilcoxon_binned", "logreg"], -) -@pytest.mark.parametrize("fmt", ["numpy_dense", "cupy_dense"]) -def test_rank_genes_groups_dense_negative_values_raise(method, fmt): - X = np.array( - [ - [-1.0, 0.0, 2.0], - [0.0, 1.0, 0.0], - [2.0, 0.0, 1.0], - [0.0, 3.0, 0.0], - ], - dtype=np.float32, - ) - adata = sc.AnnData( - X=_to_format(X, fmt), - obs=pd.DataFrame( - {"group": pd.Categorical(["a", "a", "b", "b"], categories=["a", "b"])} - ), - var=pd.DataFrame(index=["g0", "g1", "g2"]), - ) - - with pytest.raises(ValueError, match="Dense input contains negative values"): - rsc.tl.rank_genes_groups(adata, "group", method=method, use_raw=False) - - @pytest.mark.parametrize("fmt", ["numpy_dense", "scipy_csr", "cupy_dense", "cupy_csr"]) def test_rank_genes_groups_complex_values_raise(fmt): X = np.array(