diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 85488702..a10e7d24 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -69,16 +69,47 @@ jobs: path = pathlib.Path("pyproject.toml") text = path.read_text() + def remove_toml_array(text, key): + lines = text.splitlines(keepends=True) + out = [] + i = 0 + while i < len(lines): + if lines[i].startswith(f"{key} = ["): + depth = lines[i].count("[") - lines[i].count("]") + i += 1 + while i < len(lines) and depth > 0: + depth += lines[i].count("[") - lines[i].count("]") + i += 1 + continue + out.append(lines[i]) + i += 1 + return "".join(out) + # Rename package text = text.replace( 'name = "rapids-singlecell"', f'name = "rapids-singlecell-cu{cuda}"', ) # Rename matching extra to "rapids", remove the other - text = text.replace(f'rapids-cu{cuda} =', 'rapids =') - # Remove the other CUDA extra line entirely - lines = text.splitlines(keepends=True) - text = "".join(l for l in lines if f'rapids-cu{other}' not in l) + text = text.replace(f'rapids-cu{cuda} = [', 'rapids = [') + text = remove_toml_array(text, f"rapids-cu{other}") + + # librmm is needed at build time because CMake links the CUDA + # extension against librmm. Add the matching wheel to the isolated + # PEP 517 build requirements after selecting the CUDA package variant. + for dep in ( + f' "librmm-cu{other}>=25.10",\n', + f' "rmm-cu{other}>=25.10",\n', + ): + text = text.replace(dep, "") + rmm_build_req = f' "librmm-cu{cuda}>=25.10",\n' + build_system_text = text.split("[project]", 1)[0] + if f'"librmm-cu{cuda}>=25.10"' not in build_system_text: + text = text.replace( + ']\nbuild-backend = "scikit_build_core.build"', + f'{rmm_build_req}]\nbuild-backend = "scikit_build_core.build"', + 1, + ) # Set CUDA architectures (replace "native" with CI target archs) text = text.replace( @@ -96,6 +127,7 @@ jobs: - name: Sanity check pyproject.toml run: | + python3 -c "import tomllib; tomllib.load(open('pyproject.toml', 'rb'))" grep -E "name|rapids|CUDA_ARCH" pyproject.toml - name: Build CUDA manylinux image @@ -116,11 +148,23 @@ jobs: LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH PATH=/usr/local/cuda/bin:$PATH CIBW_BEFORE_BUILD: > + rm -f build/.librmm_dir && + mkdir -p build && python -m pip install -U pip scikit-build-core cmake ninja nanobind + librmm-cu${{ matrix.cuda_major }} && + RMM_ROOT=$(python -c "import librmm; print(librmm.__path__[0])") && + LOG_ROOT=$(python -c "import rapids_logger; print(rapids_logger.__path__[0])") && + echo "[rsc-build] librmm=$RMM_ROOT" && + echo "[rsc-build] rapids_logger=$LOG_ROOT" && + ln -sf "$RMM_ROOT/lib64/librmm.so" /usr/local/lib/librmm.so && + ln -sf "$LOG_ROOT/lib64/librapids_logger.so" /usr/local/lib/librapids_logger.so && + ldconfig && + python -c "import librmm; print(librmm.__path__[0])" > build/.librmm_dir && + echo "[rsc-build] marker=$(cat build/.librmm_dir)" CIBW_TEST_SKIP: "*" CIBW_TEST_COMMAND: "" - CIBW_REPAIR_WHEEL_COMMAND: "auditwheel repair --exclude libcublas.so.${{ matrix.cuda_major }} --exclude libcublasLt.so.${{ matrix.cuda_major }} --exclude libcudart.so.${{ matrix.cuda_major }} -w {dest_dir} {wheel}" + CIBW_REPAIR_WHEEL_COMMAND: "auditwheel repair --exclude libcublas.so.${{ matrix.cuda_major }} --exclude libcublasLt.so.${{ matrix.cuda_major }} --exclude libcudart.so.${{ matrix.cuda_major }} --exclude librmm.so --exclude librapids_logger.so -w {dest_dir} {wheel}" CIBW_BUILD_VERBOSITY: "1" - uses: actions/upload-artifact@v4 diff --git a/.gitignore b/.gitignore index 
2e107eab..c9252c90 100644 --- a/.gitignore +++ b/.gitignore @@ -49,6 +49,8 @@ coverage.xml .claude/ .codex CLAUDE.md +.codex # tmp_scripts tmp_scripts/ +/benchmarks/ diff --git a/CMakeLists.txt b/CMakeLists.txt index cacf9849..4e404263 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -14,6 +14,104 @@ if (RSC_BUILD_EXTENSIONS) find_package(Python REQUIRED COMPONENTS Interpreter Development.Module ${SKBUILD_SABI_COMPONENT}) find_package(nanobind CONFIG REQUIRED) find_package(CUDAToolkit REQUIRED) + set(RSC_RMM_HINTS) + set(RSC_RAPIDS_CMAKE_PREFIXES) + set(RSC_CCCL_HINTS) + set(RSC_RAPIDS_LOGGER_HINTS) + set(RSC_NVTX3_HINTS) + macro(_rsc_collect_rapids_python_prefix _rsc_prefix) + if (NOT "${_rsc_prefix}" STREQUAL "") + file(GLOB _rsc_rmm_dirs "${_rsc_prefix}/lib/python*/site-packages/librmm/lib64/cmake/rmm") + file(GLOB _rsc_rapids_prefixes + "${_rsc_prefix}/lib/python*/site-packages/librmm/lib64" + "${_rsc_prefix}/lib/python*/site-packages/librmm/lib64/rapids" + "${_rsc_prefix}/lib/python*/site-packages/rapids_logger/lib64" + "${_rsc_prefix}/lib/python*/site-packages/nvidia/cu*/lib" + ) + file(GLOB _rsc_cccl_dirs + "${_rsc_prefix}/lib/python*/site-packages/librmm/lib64/rapids/cmake/cccl" + "${_rsc_prefix}/lib/python*/site-packages/nvidia/cu*/lib/cmake/cccl" + ) + file(GLOB _rsc_rapids_logger_dirs "${_rsc_prefix}/lib/python*/site-packages/rapids_logger/lib64/cmake/rapids_logger") + file(GLOB _rsc_nvtx3_dirs "${_rsc_prefix}/lib/python*/site-packages/librmm/lib64/cmake/nvtx3") + list(APPEND RSC_RMM_HINTS ${_rsc_rmm_dirs}) + list(APPEND RSC_RAPIDS_CMAKE_PREFIXES ${_rsc_rapids_prefixes}) + list(APPEND RSC_CCCL_HINTS ${_rsc_cccl_dirs}) + list(APPEND RSC_RAPIDS_LOGGER_HINTS ${_rsc_rapids_logger_dirs}) + list(APPEND RSC_NVTX3_HINTS ${_rsc_nvtx3_dirs}) + endif() + endmacro() + execute_process( + COMMAND "${Python_EXECUTABLE}" -c "import importlib.util, pathlib; spec = importlib.util.find_spec('librmm'); print(pathlib.Path(spec.origin).parent / 'lib64' / 'cmake' / 'rmm' if spec else '')" + OUTPUT_VARIABLE RSC_PYTHON_RMM_DIR + OUTPUT_STRIP_TRAILING_WHITESPACE + ERROR_QUIET + ) + if (RSC_PYTHON_RMM_DIR AND EXISTS "${RSC_PYTHON_RMM_DIR}/rmm-config.cmake") + list(APPEND RSC_RMM_HINTS "${RSC_PYTHON_RMM_DIR}") + endif() + # Wheel builds install librmm/rapids_logger into the isolated build env and + # write build/.librmm_dir from CIBW_BEFORE_BUILD. publish.yml also symlinks + # those shared libraries into /usr/local/lib so auditwheel can see and exclude + # them instead of bundling RAPIDS runtime libraries into the wheel. 
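The marker handshake is the glue here: CIBW_BEFORE_BUILD writes the librmm package directory into build/.librmm_dir, and the CMake logic below only trusts that path when lib64/cmake/rmm/rmm-config.cmake exists under it. A minimal host-side sketch of the producer side (hypothetical helper; assumes the matching librmm-cu* wheel is importable in the isolated build environment):

import pathlib

def write_librmm_marker(build_dir: str = "build") -> pathlib.Path:
    # Mirrors the `python -c "import librmm; print(librmm.__path__[0])"` step
    # in CIBW_BEFORE_BUILD above.
    import librmm  # provided by the librmm-cu12 / librmm-cu13 wheel
    root = pathlib.Path(librmm.__path__[0])
    marker = pathlib.Path(build_dir) / ".librmm_dir"
    marker.parent.mkdir(parents=True, exist_ok=True)
    marker.write_text(str(root))
    # CMakeLists.txt accepts the marker only if this config file exists:
    return root / "lib64" / "cmake" / "rmm" / "rmm-config.cmake"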
+ if(DEFINED ENV{RSC_LIBRMM_DIR} AND EXISTS "$ENV{RSC_LIBRMM_DIR}/lib64/cmake/rmm/rmm-config.cmake") + set(_rsc_librmm_marker "$ENV{RSC_LIBRMM_DIR}") + elseif(EXISTS "${CMAKE_SOURCE_DIR}/build/.librmm_dir") + file(READ "${CMAKE_SOURCE_DIR}/build/.librmm_dir" _rsc_librmm_marker) + string(STRIP "${_rsc_librmm_marker}" _rsc_librmm_marker) + else() + set(_rsc_librmm_marker "") + endif() + if(NOT "${_rsc_librmm_marker}" STREQUAL "" AND EXISTS "${_rsc_librmm_marker}/lib64/cmake/rmm/rmm-config.cmake") + file(GLOB _rsc_marker_rmm_dirs "${_rsc_librmm_marker}/lib64/cmake/rmm") + file(GLOB _rsc_marker_rapids_prefixes + "${_rsc_librmm_marker}/lib64" + "${_rsc_librmm_marker}/lib64/rapids" + "${_rsc_librmm_marker}/../rapids_logger/lib64" + ) + file(GLOB _rsc_marker_cccl_dirs + "${_rsc_librmm_marker}/lib64/rapids/cmake/cccl" + ) + file(GLOB _rsc_marker_rapids_logger_dirs "${_rsc_librmm_marker}/../rapids_logger/lib64/cmake/rapids_logger") + file(GLOB _rsc_marker_nvtx3_dirs "${_rsc_librmm_marker}/lib64/cmake/nvtx3") + list(APPEND RSC_RMM_HINTS ${_rsc_marker_rmm_dirs}) + list(APPEND RSC_RAPIDS_CMAKE_PREFIXES ${_rsc_marker_rapids_prefixes}) + list(APPEND RSC_CCCL_HINTS ${_rsc_marker_cccl_dirs}) + list(APPEND RSC_RAPIDS_LOGGER_HINTS ${_rsc_marker_rapids_logger_dirs}) + list(APPEND RSC_NVTX3_HINTS ${_rsc_marker_nvtx3_dirs}) + endif() + foreach(_rsc_python_prefix IN ITEMS "${Python_ROOT_DIR}" "${Python3_ROOT_DIR}") + _rsc_collect_rapids_python_prefix("${_rsc_python_prefix}") + endforeach() + foreach(_rsc_env_prefix IN ITEMS "$ENV{CONDA_PREFIX}" "$ENV{VIRTUAL_ENV}") + _rsc_collect_rapids_python_prefix("${_rsc_env_prefix}") + endforeach() + string(REPLACE ":" ";" _rsc_path_entries "$ENV{PATH}") + foreach(_rsc_path_entry IN LISTS _rsc_path_entries) + get_filename_component(_rsc_path_prefix "${_rsc_path_entry}/.." 
ABSOLUTE) + _rsc_collect_rapids_python_prefix("${_rsc_path_prefix}") + endforeach() + if (RSC_RAPIDS_CMAKE_PREFIXES) + list(APPEND CMAKE_PREFIX_PATH ${RSC_RAPIDS_CMAKE_PREFIXES}) + if (RSC_CCCL_HINTS) + list(GET RSC_CCCL_HINTS 0 _rsc_cccl_dir) + set(CCCL_DIR "${_rsc_cccl_dir}" CACHE PATH "Path to CCCL package config" FORCE) + endif() + if (RSC_RAPIDS_LOGGER_HINTS) + list(GET RSC_RAPIDS_LOGGER_HINTS 0 _rsc_rapids_logger_dir) + set(rapids_logger_DIR "${_rsc_rapids_logger_dir}" CACHE PATH "Path to rapids_logger package config" FORCE) + endif() + if (RSC_NVTX3_HINTS) + list(GET RSC_NVTX3_HINTS 0 _rsc_nvtx3_dir) + set(nvtx3_DIR "${_rsc_nvtx3_dir}" CACHE PATH "Path to nvtx3 package config" FORCE) + endif() + endif() + if (RSC_RMM_HINTS) + find_package(rmm CONFIG REQUIRED HINTS ${RSC_RMM_HINTS}) + else() + find_package(rmm CONFIG REQUIRED) + endif() + message(STATUS "Using RMM for CUDA extension scratch allocations") message(STATUS "Building for CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}") else() message(STATUS "RSC_BUILD_EXTENSIONS=OFF -> skipping compiled extensions for docs") @@ -85,6 +183,11 @@ if (RSC_BUILD_EXTENSIONS) add_nb_cuda_module(_hvg_cuda src/rapids_singlecell/_cuda/hvg/hvg.cu) add_nb_cuda_module(_kde_cuda src/rapids_singlecell/_cuda/kde/kde.cu) add_nb_cuda_module(_wilcoxon_cuda src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu) + target_sources(_wilcoxon_cuda PRIVATE src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_rmm.cu) + target_link_libraries(_wilcoxon_cuda PRIVATE rmm::rmm) + add_nb_cuda_module(_wilcoxon_sparse_cuda src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu) + target_sources(_wilcoxon_sparse_cuda PRIVATE src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_rmm.cu) + target_link_libraries(_wilcoxon_sparse_cuda PRIVATE rmm::rmm) # Harmony CUDA modules add_nb_cuda_module(_harmony_scatter_cuda src/rapids_singlecell/_cuda/harmony/scatter/scatter.cu) add_nb_cuda_module(_harmony_outer_cuda src/rapids_singlecell/_cuda/harmony/outer/outer.cu) diff --git a/pyproject.toml b/pyproject.toml index 6dea3a07..e0961f15 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,6 +3,11 @@ requires = [ "scikit-build-core>=0.10", "nanobind>=2.0.0", "setuptools-scm>=8", + # librmm headers/CMake config are needed at build time for Wilcoxon. + # Generic isolated source builds default to CUDA 12. CUDA wheel builds + # rewrite this to the matching cu12/cu13 package; CUDA 13 source builds + # should build in an existing RAPIDS env with --no-build-isolation. 
+ "librmm-cu12>=25.10", ] build-backend = "scikit_build_core.build" @@ -32,8 +37,22 @@ dependencies = [ ] [project.optional-dependencies] -rapids-cu13 = [ "cupy-cuda13x", "cudf-cu13>=25.10", "cuml-cu13>=25.10", "cugraph-cu13>=25.10", "cuvs-cu13>=25.10" ] -rapids-cu12 = [ "cupy-cuda12x", "cudf-cu12>=25.10", "cuml-cu12>=25.10", "cugraph-cu12>=25.10", "cuvs-cu12>=25.10" ] +rapids-cu13 = [ + "cupy-cuda13x", + "cudf-cu13>=25.10", + "cuml-cu13>=25.10", + "cugraph-cu13>=25.10", + "cuvs-cu13>=25.10", + "librmm-cu13>=25.10", +] +rapids-cu12 = [ + "cupy-cuda12x", + "cudf-cu12>=25.10", + "cuml-cu12>=25.10", + "cugraph-cu12>=25.10", + "cuvs-cu12>=25.10", + "librmm-cu12>=25.10", +] doc = [ "sphinx>=4.5.0", @@ -150,7 +169,7 @@ sdist.include = [ "src/rapids_singlecell/_version.py" ] # Use abi3audit to catch issues with Limited API wheels [tool.cibuildwheel.linux] repair-wheel-command = [ - "auditwheel repair --exclude libcublas.so.12 --exclude libcublas.so.13 --exclude libcublasLt.so.12 --exclude libcublasLt.so.13 --exclude libcudart.so.12 --exclude libcudart.so.13 -w {dest_dir} {wheel}", + "auditwheel repair --exclude libcublas.so.12 --exclude libcublas.so.13 --exclude libcublasLt.so.12 --exclude libcublasLt.so.13 --exclude libcudart.so.12 --exclude libcudart.so.13 --exclude librmm.so --exclude librapids_logger.so -w {dest_dir} {wheel}", "pipx run abi3audit --strict --report {wheel}", ] [tool.cibuildwheel.macos] diff --git a/src/rapids_singlecell/_cuda/nb_types.h b/src/rapids_singlecell/_cuda/nb_types.h index 905e1e07..eb343815 100644 --- a/src/rapids_singlecell/_cuda/nb_types.h +++ b/src/rapids_singlecell/_cuda/nb_types.h @@ -42,6 +42,13 @@ using gpu_array = nb::ndarray; template using gpu_array_contig = nb::ndarray; +// Host (NumPy) array aliases +template +using host_array = nb::ndarray>; + +template +using host_array_2d = nb::ndarray>; + // Register bindings for both regular CUDA and managed-memory arrays. // Usage: // template diff --git a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh index c89d913a..5af4e964 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh +++ b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon.cuh @@ -2,143 +2,207 @@ #include +__device__ __forceinline__ double wilcoxon_block_sum(double val, + double* warp_buf) { +#pragma unroll + for (int off = 16; off > 0; off >>= 1) + val += __shfl_down_sync(0xffffffff, val, off); + int lane = threadIdx.x & 31; + int wid = threadIdx.x >> 5; + if (lane == 0) warp_buf[wid] = val; + __syncthreads(); + if (threadIdx.x < 32) { + double v = (threadIdx.x < ((blockDim.x + 31) >> 5)) + ? warp_buf[threadIdx.x] + : 0.0; +#pragma unroll + for (int off = 16; off > 0; off >>= 1) + v += __shfl_down_sync(0xffffffff, v, off); + return v; + } + return 0.0; +} + /** - * Kernel to compute tie correction factor for Wilcoxon test. - * Formula: tc = 1 - sum(t^3 - t) / (n^3 - n) where t is the count of tied - * values. + * OVR dense rank-sum kernel for data sorted by column. * - * Each block handles one column. Uses binary search to find tie groups. - * Assumes input is sorted column-wise (F-order). + * sorted_vals and sorted_row_idx are F-order arrays from a segmented + * SortPairs. One block owns one column, walks tie runs, and accumulates the + * average ranks per group without materializing a full rank matrix. 
*/ -__global__ void tie_correction_kernel(const double* __restrict__ sorted_vals, - double* __restrict__ correction, - const int n_rows, const int n_cols) { - // Each block handles one column +__global__ void rank_sums_from_sorted_kernel( + const float* __restrict__ sorted_vals, + const int* __restrict__ sorted_row_idx, const int* __restrict__ group_codes, + double* __restrict__ rank_sums, double* __restrict__ tie_corr, int n_rows, + int n_cols, int n_groups, bool compute_tie_corr, bool use_gmem) { int col = blockIdx.x; if (col >= n_cols) return; - const double* sv = sorted_vals + (size_t)col * n_rows; + extern __shared__ double smem[]; + + double* grp_sums; + if (use_gmem) { + grp_sums = rank_sums + (size_t)col; + } else { + grp_sums = smem; + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { + grp_sums[g] = 0.0; + } + __syncthreads(); + } + + const float* sv = sorted_vals + (size_t)col * n_rows; + const int* si = sorted_row_idx + (size_t)col * n_rows; - double local_sum = 0.0; - int tid = threadIdx.x; + int chunk = (n_rows + blockDim.x - 1) / blockDim.x; + int my_start = threadIdx.x * chunk; + int my_end = my_start + chunk; + if (my_end > n_rows) my_end = n_rows; - // Each thread processes positions where it detects END of a tie group - // Start from index 1, check if sv[i-1] != sv[i] (boundary detected) - // When at boundary, use binary search to find tie group size - for (int i = tid + 1; i <= n_rows; i += blockDim.x) { - // Detect boundary: either at the end, or value changed - bool at_boundary = (i == n_rows) || (sv[i] != sv[i - 1]); + double local_tie_sum = 0.0; + int acc_stride = use_gmem ? n_cols : 1; - if (at_boundary) { - // Found end of tie group at position i-1 - // Binary search for start of this tie group - double val = sv[i - 1]; - int lo = 0, hi = i - 1; + int i = my_start; + while (i < my_end) { + double val = sv[i]; + + int tie_local_end = i + 1; + while (tie_local_end < my_end && sv[tie_local_end] == val) { + ++tie_local_end; + } + + int tie_global_start = i; + if (i == my_start && i > 0 && sv[i - 1] == val) { + int lo = 0; + int hi = i; while (lo < hi) { - int mid = (lo + hi) / 2; - if (sv[mid] < val) { + int mid = lo + ((hi - lo) >> 1); + if (sv[mid] < val) lo = mid + 1; - } else { + else hi = mid; - } } - int tie_count = i - lo; + tie_global_start = lo; + } - // t^3 - t for this tie group - double t = (double)tie_count; - local_sum += t * t * t - t; + int tie_global_end = tie_local_end; + if (tie_local_end == my_end && tie_local_end < n_rows && + sv[tie_local_end] == val) { + int lo = tie_local_end; + int hi = n_rows - 1; + while (lo < hi) { + int mid = hi - ((hi - lo) >> 1); + if (sv[mid] > val) + hi = mid - 1; + else + lo = mid; + } + tie_global_end = lo + 1; } - } - // Warp-level reduction using shuffle -#pragma unroll - for (int offset = 16; offset > 0; offset >>= 1) { - local_sum += __shfl_down_sync(0xffffffff, local_sum, offset); - } + int total_tie = tie_global_end - tie_global_start; + double avg_rank = (double)(tie_global_start + tie_global_end + 1) / 2.0; - // Cross-warp reduction using small shared memory - __shared__ double warp_sums[32]; - int lane = tid & 31; - int warp_id = tid >> 5; + for (int j = i; j < tie_local_end; ++j) { + int grp = group_codes[si[j]]; + if (grp < n_groups) { + atomicAdd(&grp_sums[grp * acc_stride], avg_rank); + } + } - if (lane == 0) { - warp_sums[warp_id] = local_sum; + if (compute_tie_corr && tie_global_start >= my_start && total_tie > 1) { + double t = (double)total_tie; + local_tie_sum += t * t * t - t; + } + + i = 
tie_local_end; } + __syncthreads(); - // Final reduction in first warp - // Note: blockDim.x must be a multiple of 32 for correct warp reduction - if (tid < 32) { - double val = (tid < (blockDim.x >> 5)) ? warp_sums[tid] : 0.0; -#pragma unroll - for (int offset = 16; offset > 0; offset >>= 1) { - val += __shfl_down_sync(0xffffffff, val, offset); + if (!use_gmem) { + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { + rank_sums[(size_t)g * n_cols + col] = grp_sums[g]; } - if (tid == 0) { + } + + if (compute_tie_corr) { + int warp_buf_off = use_gmem ? 0 : n_groups; + double* warp_buf = smem + warp_buf_off; + double tie_sum = wilcoxon_block_sum(local_tie_sum, warp_buf); + if (threadIdx.x == 0) { double n = (double)n_rows; double denom = n * n * n - n; - if (denom > 0) { - correction[col] = 1.0 - val / denom; - } else { - correction[col] = 1.0; - } + tie_corr[col] = (denom > 0.0) ? (1.0 - tie_sum / denom) : 1.0; } } } /** - * Kernel to compute average ranks for each column. - * Uses scipy.stats.rankdata 'average' method: ties get the average of the ranks - * they would span. + * OVR dense rank core. * - * Each block handles one column. Assumes input is sorted column-wise (F-order). + * sorted_vals and sorter are F-order outputs of sorting each column of the + * current dense block. The kernel directly accumulates rank sums per group, + * avoiding a full ranks matrix and a group one-hot matrix multiply. */ -__global__ void average_rank_kernel(const double* __restrict__ sorted_vals, - const int* __restrict__ sorter, - double* __restrict__ ranks, - const int n_rows, const int n_cols) { - // Each thread block handles one column +__global__ void ovr_rank_dense_kernel(const float* __restrict__ sorted_vals, + const int* __restrict__ sorter, + const int* __restrict__ group_codes, + double* __restrict__ rank_sums, + double* __restrict__ tie_corr, int n_rows, + int n_cols, int n_groups, + bool compute_tie_corr) { int col = blockIdx.x; if (col >= n_cols) return; - // Pointers to this column's data - const double* sv = sorted_vals + (size_t)col * n_rows; - const int* si = sorter + (size_t)col * n_rows; - double* rk = ranks + (size_t)col * n_rows; + const float* sv = sorted_vals + (long long)col * n_rows; + const int* si = sorter + (long long)col * n_rows; - // Each thread processes multiple rows + double local_tie = 0.0; for (int i = threadIdx.x; i < n_rows; i += blockDim.x) { - double val = sv[i]; + float val = sv[i]; - // Binary search for tie_start (first element equal to val) int lo = 0, hi = i; while (lo < hi) { - int mid = (lo + hi) / 2; - if (sv[mid] < val) { + int mid = lo + ((hi - lo) >> 1); + if (sv[mid] < val) lo = mid + 1; - } else { + else hi = mid; - } } int tie_start = lo; - // Binary search for tie_end (last element equal to val) lo = i; hi = n_rows - 1; while (lo < hi) { - int mid = (lo + hi + 1) / 2; - if (sv[mid] > val) { + int mid = lo + ((hi - lo + 1) >> 1); + if (sv[mid] > val) hi = mid - 1; - } else { + else lo = mid; - } } int tie_end = lo; - - // Average rank for ties: (start + end + 2) / 2 (1-based ranks) double avg_rank = (double)(tie_start + tie_end + 2) / 2.0; - // Write rank to original position - rk[si[i]] = avg_rank; + int row = si[i]; + int group = group_codes[row]; + if (group >= 0 && group < n_groups) { + atomicAdd(&rank_sums[(size_t)group * n_cols + col], avg_rank); + } + + if (compute_tie_corr && i == tie_end) { + double t = (double)(tie_end - tie_start + 1); + if (t > 1.0) local_tie += t * t * t - t; + } + } + + if (!compute_tie_corr) return; + + __shared__ 
double warp_buf[32]; + double tie_sum = wilcoxon_block_sum(local_tie, warp_buf); + if (threadIdx.x == 0) { + double n = (double)n_rows; + double denom = n * n * n - n; + tie_corr[col] = (denom > 0.0) ? (1.0 - tie_sum / denom) : 1.0; } } diff --git a/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh new file mode 100644 index 00000000..a8e9ed4f --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh @@ -0,0 +1,978 @@ +#pragma once + +#include + +#include "wilcoxon_fast_common.cuh" + +// ============================================================================ +// Warp reduction helper (sum doubles across block via warp_buf) +// ============================================================================ + +__device__ __forceinline__ double block_reduce_sum(double val, + double* warp_buf) { +#pragma unroll + for (int off = 16; off > 0; off >>= 1) + val += __shfl_down_sync(0xffffffff, val, off); + int lane = threadIdx.x & 31; + int wid = threadIdx.x >> 5; + if (lane == 0) warp_buf[wid] = val; + __syncthreads(); + if (threadIdx.x < 32) { + double v2 = (threadIdx.x < ((blockDim.x + 31) >> 5)) + ? warp_buf[threadIdx.x] + : 0.0; +#pragma unroll + for (int off = 16; off > 0; off >>= 1) + v2 += __shfl_down_sync(0xffffffff, v2, off); + return v2; // only lane 0 of warp 0 has the final result + } + return 0.0; +} + +// ============================================================================ +// Parallel tie correction — all threads collaborate. +// +// For each unique value in the combined sorted (ref, grp) arrays, accumulate +// t^3 - t where t = count of that value. Uses two passes: +// 1. Iterate unique values in ref_col, count in both arrays. +// 2. Iterate unique values in grp_col that do NOT appear in ref_col. +// +// Incremental binary search bounds exploit monotonicity within each thread's +// stride to reduce total search work. +// +// Caller must __syncthreads() before calling. warp_buf is reused for +// reduction (32 doubles, shared memory). +// ============================================================================ + +__device__ __forceinline__ void compute_tie_correction_parallel( + const float* ref_col, int n_ref, const float* grp_col, int n_grp, + double* warp_buf, double* out) { + double local_tie = 0.0; + + // Pass 1: unique values in ref_col + int grp_lb = 0, grp_ub = 0; + for (int i = threadIdx.x; i < n_ref; i += blockDim.x) { + if (i == 0 || ref_col[i] != ref_col[i - 1]) { + float v = ref_col[i]; + + // Count in ref: upper_bound from i+1 + int lo = i + 1, hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int cnt_ref = lo - i; + + // Count in grp: incremental lower/upper bound + lo = grp_lb; + hi = n_grp; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (grp_col[m] < v) + lo = m + 1; + else + hi = m; + } + int lb = lo; + grp_lb = lb; + + lo = (grp_ub > lb) ? 
grp_ub : lb; + hi = n_grp; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (grp_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int cnt_grp = lo - lb; + grp_ub = lo; + + int cnt = cnt_ref + cnt_grp; + if (cnt > 1) { + double t = (double)cnt; + local_tie += t * t * t - t; + } + } + } + + // Pass 2: unique values in grp_col that are absent from ref_col + int ref_lb = 0; + for (int i = threadIdx.x; i < n_grp; i += blockDim.x) { + if (i == 0 || grp_col[i] != grp_col[i - 1]) { + float v = grp_col[i]; + + // Incremental lower_bound in ref + int lo = ref_lb, hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] < v) + lo = m + 1; + else + hi = m; + } + ref_lb = lo; + + if (lo >= n_ref || ref_col[lo] != v) { + // Value not in ref — count in grp only (upper_bound from i+1) + lo = i + 1; + hi = n_grp; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (grp_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int cnt = lo - i; + if (cnt > 1) { + double t = (double)cnt; + local_tie += t * t * t - t; + } + } + } + } + + // Block-wide reduction + double tie_sum = block_reduce_sum(local_tie, warp_buf); + if (threadIdx.x == 0) { + int n = n_ref + n_grp; + double dn = (double)n; + double denom = dn * dn * dn - dn; + *out = (denom > 0.0) ? (1.0 - tie_sum / denom) : 1.0; + } +} + +// ============================================================================ +// Batched rank sums — pre-sorted (binary search, no shared memory sort) +// Used by the OVO streaming pipeline in wilcoxon_streaming.cu. +// +// Incremental binary search: each thread carries forward lower/upper bound +// positions across loop iterations, exploiting the monotonicity of the +// sorted grp_col values within each thread's stride. +// ============================================================================ + +__global__ void batched_rank_sums_presorted_kernel( + const float* __restrict__ ref_sorted, const float* __restrict__ grp_sorted, + const int* __restrict__ grp_offsets, double* __restrict__ rank_sums, + double* __restrict__ tie_corr, int n_ref, int n_all_grp, int n_cols, + int n_groups, bool compute_tie_corr, int skip_n_grp_le /*= 0*/) { + int col = blockIdx.x; + int grp = blockIdx.y; + if (col >= n_cols || grp >= n_groups) return; + + int g_start = grp_offsets[grp]; + int g_end = grp_offsets[grp + 1]; + int n_grp = g_end - g_start; + + // Size-gated dispatch (see ovo_fused_sort_rank_kernel for the contract). + if (n_grp <= skip_n_grp_le) return; + + if (n_grp == 0) { + if (threadIdx.x == 0) { + rank_sums[grp * n_cols + col] = 0.0; + if (compute_tie_corr) tie_corr[grp * n_cols + col] = 1.0; + } + return; + } + + const float* ref_col = ref_sorted + (long long)col * n_ref; + const float* grp_col = grp_sorted + (long long)col * n_all_grp + g_start; + + // Incremental binary search bounds (advance monotonically per thread) + int ref_lb = 0, ref_ub = 0; + int grp_lb = 0, grp_ub = 0; + double local_sum = 0.0; + + for (int i = threadIdx.x; i < n_grp; i += blockDim.x) { + float v = grp_col[i]; + int lo, hi; + + // Lower bound in ref (from ref_lb) + lo = ref_lb; + hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] < v) + lo = m + 1; + else + hi = m; + } + int n_lt_ref = lo; + ref_lb = n_lt_ref; + + // Upper bound in ref (from max(ref_ub, n_lt_ref)) + lo = (ref_ub > n_lt_ref) ? 
ref_ub : n_lt_ref; + hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int n_eq_ref = lo - n_lt_ref; + ref_ub = lo; + + // Lower bound in grp (from grp_lb) + lo = grp_lb; + hi = n_grp; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (grp_col[m] < v) + lo = m + 1; + else + hi = m; + } + int n_lt_grp = lo; + grp_lb = n_lt_grp; + + // Upper bound in grp (from max(grp_ub, n_lt_grp)) + lo = (grp_ub > n_lt_grp) ? grp_ub : n_lt_grp; + hi = n_grp; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (grp_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int n_eq_grp = lo - n_lt_grp; + grp_ub = lo; + + local_sum += (double)(n_lt_ref + n_lt_grp) + + ((double)(n_eq_ref + n_eq_grp) + 1.0) / 2.0; + } + + __shared__ double warp_buf[32]; + double total = block_reduce_sum(local_sum, warp_buf); + if (threadIdx.x == 0) rank_sums[grp * n_cols + col] = total; + + if (!compute_tie_corr) return; + __syncthreads(); + + compute_tie_correction_parallel(ref_col, n_ref, grp_col, n_grp, warp_buf, + &tie_corr[grp * n_cols + col]); +} + +// ============================================================================ +// Tier 1 fused kernel: smem bitonic sort + binary search rank sums +// For small groups (< ~2K cells). No CUB, no global memory sort buffers. +// Grid: (n_cols, n_groups), Block: min(padded_grp_size, 512) +// Shared memory: padded_grp_size floats + 32 doubles (warp reduction) +// ============================================================================ + +__global__ void ovo_fused_sort_rank_kernel( + const float* __restrict__ ref_sorted, // F-order (n_ref, n_cols) sorted + const float* __restrict__ grp_dense, // F-order (n_all_grp, n_cols) + // unsorted + const int* __restrict__ grp_offsets, // (n_groups + 1,) + double* __restrict__ rank_sums, // (n_groups, n_cols) row-major + double* __restrict__ tie_corr, // (n_groups, n_cols) row-major + int n_ref, int n_all_grp, int n_cols, int n_groups, bool compute_tie_corr, + int padded_grp_size, int skip_n_grp_le /*= 0*/) { + int col = blockIdx.x; + int grp = blockIdx.y; + if (col >= n_cols || grp >= n_groups) return; + + int g_start = grp_offsets[grp]; + int g_end = grp_offsets[grp + 1]; + int n_grp = g_end - g_start; + + // Size-gated dispatch: when co-launched with the Tier 0 warp kernel we + // skip groups it's already handling. Each group owns its own + // rank_sums row, so the two kernels' writes never alias. 
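All of the OVO tiers below (warp, small-64, medium no-sort, this fused sort, and the presorted CUB path) accumulate the same per-(column, group) statistic and only differ in how they sort and search. A NumPy reference of that statistic for one reference column and one group column, useful for spot-checking any tier (sketch only, not part of this PR; assumes 1-D NumPy inputs):

import numpy as np

def ovo_rank_sum_and_tie_corr(ref_col, grp_col):
    # ref_col, grp_col: 1-D arrays holding one gene column for the reference
    # cells and the group cells.
    ref_sorted = np.sort(ref_col)
    grp_sorted = np.sort(grp_col)
    # Per group value v, its average rank in the combined column is
    #   (#ref < v) + (#grp < v) + ((#ref == v) + (#grp == v) + 1) / 2
    n_lt = (np.searchsorted(ref_sorted, grp_sorted, side="left")
            + np.searchsorted(grp_sorted, grp_sorted, side="left"))
    n_le = (np.searchsorted(ref_sorted, grp_sorted, side="right")
            + np.searchsorted(grp_sorted, grp_sorted, side="right"))
    rank_sum = float(np.sum(n_lt + (n_le - n_lt + 1) / 2.0))
    # Tie correction over the combined column: 1 - sum(t^3 - t) / (n^3 - n).
    _, t = np.unique(np.concatenate([ref_col, grp_col]), return_counts=True)
    n = float(ref_col.size + grp_col.size)
    denom = n * n * n - n
    tie_sum = float(np.sum(t.astype(np.float64) ** 3 - t))
    tie_corr = (1.0 - tie_sum / denom) if denom > 0 else 1.0
    return rank_sum, tie_corr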
+ if (n_grp <= skip_n_grp_le) return; + + if (n_grp == 0) { + if (threadIdx.x == 0) { + rank_sums[grp * n_cols + col] = 0.0; + if (compute_tie_corr) tie_corr[grp * n_cols + col] = 1.0; + } + return; + } + + // Shared memory: [padded_grp_size floats | 32 doubles for warp reduction] + extern __shared__ char smem_raw[]; + float* grp_smem = (float*)smem_raw; + double* warp_buf = (double*)(smem_raw + padded_grp_size * sizeof(float)); + + // Load group data into shared memory, pad with +INF + const float* grp_col = grp_dense + (long long)col * n_all_grp + g_start; + for (int i = threadIdx.x; i < n_grp; i += blockDim.x) + grp_smem[i] = grp_col[i]; + for (int i = n_grp + threadIdx.x; i < padded_grp_size; i += blockDim.x) + grp_smem[i] = __int_as_float(0x7f800000); // +INF + __syncthreads(); + + // Bitonic sort in shared memory + for (int k = 2; k <= padded_grp_size; k <<= 1) { + for (int j = k >> 1; j > 0; j >>= 1) { + for (int i = threadIdx.x; i < padded_grp_size; i += blockDim.x) { + int ixj = i ^ j; + if (ixj > i) { + bool asc = ((i & k) == 0); + float a = grp_smem[i], b = grp_smem[ixj]; + if (asc ? (a > b) : (a < b)) { + grp_smem[i] = b; + grp_smem[ixj] = a; + } + } + } + __syncthreads(); + } + } + + // Binary search each sorted grp element against sorted ref + // Incremental bounds: values are monotonic within each thread's stride + const float* ref_col = ref_sorted + (long long)col * n_ref; + int ref_lb = 0, ref_ub = 0; + int grp_lb = 0, grp_ub = 0; + double local_sum = 0.0; + + for (int i = threadIdx.x; i < n_grp; i += blockDim.x) { + float v = grp_smem[i]; + int lo, hi; + + lo = ref_lb; + hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] < v) + lo = m + 1; + else + hi = m; + } + int n_lt_ref = lo; + ref_lb = n_lt_ref; + + lo = (ref_ub > n_lt_ref) ? ref_ub : n_lt_ref; + hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int n_eq_ref = lo - n_lt_ref; + ref_ub = lo; + + lo = grp_lb; + hi = n_grp; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (grp_smem[m] < v) + lo = m + 1; + else + hi = m; + } + int n_lt_grp = lo; + grp_lb = n_lt_grp; + + lo = (grp_ub > n_lt_grp) ? grp_ub : n_lt_grp; + hi = n_grp; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (grp_smem[m] <= v) + lo = m + 1; + else + hi = m; + } + int n_eq_grp = lo - n_lt_grp; + grp_ub = lo; + + local_sum += (double)(n_lt_ref + n_lt_grp) + + ((double)(n_eq_ref + n_eq_grp) + 1.0) / 2.0; + } + + // Block reduction → write rank_sums + double total = block_reduce_sum(local_sum, warp_buf); + if (threadIdx.x == 0) rank_sums[grp * n_cols + col] = total; + + if (!compute_tie_corr) return; + __syncthreads(); + + // Parallel tie correction (grp_smem is sorted shared memory) + compute_tie_correction_parallel(ref_col, n_ref, grp_smem, n_grp, warp_buf, + &tie_corr[grp * n_cols + col]); +} + +// ============================================================================ +// Tier 2 helper: tie contribution of the sorted reference alone. +// One block per column. The medium unsorted-rank kernel uses this as a base +// and only adds group-only/overlap deltas from the unsorted group values. 
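+//
+// The delta form used by the small-64 / medium kernels relies on the identity
+// (c_ref(v), c_grp(v) = multiplicities of value v in the reference / group column):
+//   tie_sum(ref + grp) = tie_sum(ref)
+//       + sum over distinct v present in grp of
+//         [(c_ref(v) + c_grp(v))^3 - (c_ref(v) + c_grp(v))] - [c_ref(v)^3 - c_ref(v)]
+// which reduces to c_grp(v)^3 - c_grp(v) whenever v does not occur in ref.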
+// ============================================================================ + +__global__ void ref_tie_sum_kernel(const float* __restrict__ ref_sorted, + double* __restrict__ ref_tie_sums, int n_ref, + int n_cols) { + int col = blockIdx.x; + if (col >= n_cols) return; + const float* ref_col = ref_sorted + (long long)col * n_ref; + + double local_tie = 0.0; + for (int i = threadIdx.x; i < n_ref; i += blockDim.x) { + if (i == 0 || ref_col[i] != ref_col[i - 1]) { + float v = ref_col[i]; + int lo = i + 1, hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int cnt = lo - i; + if (cnt > 1) { + double t = (double)cnt; + local_tie += t * t * t - t; + } + } + } + + __shared__ double warp_buf[32]; + double total = block_reduce_sum(local_tie, warp_buf); + if (threadIdx.x == 0) ref_tie_sums[col] = total; +} + +__global__ void ovo_small64_sort_rank_kernel( + const float* __restrict__ ref_sorted, const float* __restrict__ grp_dense, + const int* __restrict__ grp_offsets, + const double* __restrict__ ref_tie_sums, double* __restrict__ rank_sums, + double* __restrict__ tie_corr, int n_ref, int n_all_grp, int n_cols, + int n_groups, bool compute_tie_corr, int skip_n_grp_le) { + int col = blockIdx.x; + int grp = blockIdx.y; + if (col >= n_cols || grp >= n_groups) return; + + int g_start = grp_offsets[grp]; + int g_end = grp_offsets[grp + 1]; + int n_grp = g_end - g_start; + if (n_grp <= skip_n_grp_le || n_grp > TIER0_64_GROUP_THRESHOLD) return; + + __shared__ float grp_smem[TIER0_64_GROUP_THRESHOLD]; + __shared__ double warp_buf[WARP_REDUCE_BUF]; + + const float* grp_col = grp_dense + (long long)col * n_all_grp + g_start; + const float POS_INF = __int_as_float(0x7f800000); + if (threadIdx.x < TIER0_64_GROUP_THRESHOLD) { + grp_smem[threadIdx.x] = + (threadIdx.x < n_grp) ? grp_col[threadIdx.x] : POS_INF; + } + __syncthreads(); + + for (int k = 2; k <= TIER0_64_GROUP_THRESHOLD; k <<= 1) { + for (int j = k >> 1; j > 0; j >>= 1) { + int i = threadIdx.x; + int ixj = i ^ j; + if (i < TIER0_64_GROUP_THRESHOLD && ixj > i) { + bool asc = ((i & k) == 0); + float a = grp_smem[i], b = grp_smem[ixj]; + if (asc ? 
(a > b) : (a < b)) { + grp_smem[i] = b; + grp_smem[ixj] = a; + } + } + __syncthreads(); + } + } + + const float* ref_col = ref_sorted + (long long)col * n_ref; + double local_sum = 0.0; + double local_tie_delta = 0.0; + + if (threadIdx.x < n_grp) { + float v = grp_smem[threadIdx.x]; + + int lo = 0, hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] < v) + lo = m + 1; + else + hi = m; + } + int n_lt_ref = lo; + hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int n_eq_ref = lo - n_lt_ref; + + lo = 0; + hi = n_grp; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (grp_smem[m] < v) + lo = m + 1; + else + hi = m; + } + int n_lt_grp = lo; + hi = n_grp; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (grp_smem[m] <= v) + lo = m + 1; + else + hi = m; + } + int n_eq_grp = lo - n_lt_grp; + + local_sum += (double)(n_lt_ref + n_lt_grp) + + ((double)(n_eq_ref + n_eq_grp) + 1.0) / 2.0; + + if (compute_tie_corr && + (threadIdx.x == 0 || v != grp_smem[threadIdx.x - 1])) { + double combined = (double)(n_eq_ref + n_eq_grp); + if (combined > 1.0) { + local_tie_delta += combined * combined * combined - combined; + } + if (n_eq_ref > 1) { + double cr = (double)n_eq_ref; + local_tie_delta -= cr * cr * cr - cr; + } + } + } + + double total = block_reduce_sum(local_sum, warp_buf); + if (threadIdx.x == 0) rank_sums[grp * n_cols + col] = total; + + if (!compute_tie_corr) return; + __syncthreads(); + + double tie_delta = block_reduce_sum(local_tie_delta, warp_buf); + if (threadIdx.x == 0) { + int n = n_ref + n_grp; + double dn = (double)n; + double denom = dn * dn * dn - dn; + double tie_sum = ref_tie_sums[col] + tie_delta; + tie_corr[grp * n_cols + col] = + (denom > 0.0) ? (1.0 - tie_sum / denom) : 1.0; + } +} + +// ============================================================================ +// Tier 2 fused kernel: no-sort direct rank for medium groups. +// +// Avoids the smem bitonic sort for groups in (skip_n_grp_le, +// max_n_grp_le]. Ranks are computed from ref binary searches plus an +// in-group scan over unsorted shared values. Tie correction starts from +// ref_tie_sums[col] and adds only group-only / ref-overlap deltas. 
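+//
+// Note: the in-group counts come from a linear scan over the unsorted
+// shared-memory values, so each (column, group) pair costs roughly
+// O(n_grp * (log n_ref + n_grp)) work; the skip_n_grp_le / max_n_grp_le window
+// is what keeps that quadratic in-group term bounded.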
+// ============================================================================ + +__global__ void ovo_medium_unsorted_rank_kernel( + const float* __restrict__ ref_sorted, const float* __restrict__ grp_dense, + const int* __restrict__ grp_offsets, + const double* __restrict__ ref_tie_sums, double* __restrict__ rank_sums, + double* __restrict__ tie_corr, int n_ref, int n_all_grp, int n_cols, + int n_groups, bool compute_tie_corr, int skip_n_grp_le, int max_n_grp_le) { + int col = blockIdx.x; + int grp = blockIdx.y; + if (col >= n_cols || grp >= n_groups) return; + + int g_start = grp_offsets[grp]; + int g_end = grp_offsets[grp + 1]; + int n_grp = g_end - g_start; + if (n_grp <= skip_n_grp_le || n_grp > max_n_grp_le) return; + + extern __shared__ char smem_raw[]; + float* grp_smem = (float*)smem_raw; + double* warp_buf = (double*)(smem_raw + max_n_grp_le * sizeof(float)); + + const float* grp_col = grp_dense + (long long)col * n_all_grp + g_start; + for (int i = threadIdx.x; i < n_grp; i += blockDim.x) + grp_smem[i] = grp_col[i]; + __syncthreads(); + + const float* ref_col = ref_sorted + (long long)col * n_ref; + double local_sum = 0.0; + double local_tie_delta = 0.0; + + for (int i = threadIdx.x; i < n_grp; i += blockDim.x) { + float v = grp_smem[i]; + + int lo = 0, hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] < v) + lo = m + 1; + else + hi = m; + } + int n_lt_ref = lo; + + hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int n_eq_ref = lo - n_lt_ref; + + int n_lt_grp = 0; + int n_eq_grp = 0; + bool first_in_grp = true; + for (int j = 0; j < n_grp; ++j) { + float w = grp_smem[j]; + if (w < v) ++n_lt_grp; + if (w == v) { + ++n_eq_grp; + if (j < i) first_in_grp = false; + } + } + + local_sum += (double)(n_lt_ref + n_lt_grp) + + ((double)(n_eq_ref + n_eq_grp) + 1.0) / 2.0; + + if (compute_tie_corr && first_in_grp) { + double cg = (double)n_eq_grp; + double cr = (double)n_eq_ref; + double group_tie = (cg > 1.0) ? (cg * cg * cg - cg) : 0.0; + local_tie_delta += group_tie; + if (cr > 0.0) { + double combined = cr + cg; + double ref_tie = (cr > 1.0) ? (cr * cr * cr - cr) : 0.0; + local_tie_delta += combined * combined * combined - combined - + ref_tie - group_tie; + } + } + } + + double total = block_reduce_sum(local_sum, warp_buf); + if (threadIdx.x == 0) rank_sums[grp * n_cols + col] = total; + + if (!compute_tie_corr) return; + __syncthreads(); + + double tie_delta = block_reduce_sum(local_tie_delta, warp_buf); + if (threadIdx.x == 0) { + int n = n_ref + n_grp; + double dn = (double)n; + double denom = dn * dn * dn - dn; + double tie_sum = ref_tie_sums[col] + tie_delta; + tie_corr[grp * n_cols + col] = + (denom > 0.0) ? (1.0 - tie_sum / denom) : 1.0; + } +} + +// ============================================================================ +// Warp-scoped tie correction for Tier 0. +// +// Sorted values live in a 32-lane register (one per lane, with unused lanes +// carrying +INF). Walks unique values via lane-step differentials and +// counts ties across the sorted ref column via binary search. All the +// sync is __syncwarp — no smem, no __syncthreads. 
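+//
+// Illustrative example: ref = [1, 2, 2, 5] (sorted), group lanes = [2, 2, 7]
+// padded with +INF. Pass 1 visits the unique ref values 1, 2, 5; only 2 is
+// tied, with t = 2 + 2 = 4 contributing 4^3 - 4 = 60. Pass 2 visits 7, the
+// only group value absent from ref, with t = 1 (no contribution). With
+// n = n_ref + n_grp = 7, the correction is 1 - 60 / (7^3 - 7) = 1 - 60/336 ≈ 0.821.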
+// ============================================================================ + +__device__ __forceinline__ double tier0_tie_sum_warp(const float* ref_col, + int n_ref, float v_lane, + int n_grp, + unsigned int active_mask) { + int lane = threadIdx.x & 31; + double local_tie = 0.0; + + // Pass 1: for each unique value in ref_col, count occurrences in ref and + // in the sorted group (held in register v_lane across 32 lanes). + for (int base = 0; base < n_ref; base += 32) { + int i = base + lane; + bool in_ref_lane = (i < n_ref); + float v = in_ref_lane ? ref_col[i] : 0.0f; + bool is_first = in_ref_lane && ((i == 0) || (v != ref_col[i - 1])); + int cnt_ref = 0; + if (is_first) { + // Count in ref: upper_bound from i+1 + int lo = i + 1, hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] <= v) + lo = m + 1; + else + hi = m; + } + cnt_ref = lo - i; + } + + // Count in grp: look up how many lanes hold v_lane == v. All lanes + // execute the shuffle loop; only lanes owning a unique ref value use + // the result. + int cnt_grp = 0; +#pragma unroll + for (int lane_i = 0; lane_i < TIER0_GROUP_THRESHOLD; ++lane_i) { + float vi = __shfl_sync(0xffffffff, v_lane, lane_i); + if (is_first && lane_i < n_grp && vi == v) ++cnt_grp; + } + + if (is_first) { + int cnt = cnt_ref + cnt_grp; + if (cnt > 1) { + double t = (double)cnt; + local_tie += t * t * t - t; + } + } + } + + // Pass 2: unique values in grp that are absent from ref. + // Walk lanes 0..n_grp-1; for each lane whose v differs from prev lane's, + // binary-search ref for v. If not present, count consecutive matching + // lanes (tie block). + if (lane < n_grp) { + float v = v_lane; + float prev_lane_v = + __shfl_sync(active_mask, v_lane, (lane > 0) ? lane - 1 : 0); + float v_prev = + (lane > 0) ? prev_lane_v : __int_as_float(0xff800000); // -INF + bool first_in_grp = (lane == 0) || (v != v_prev); + bool in_ref = false; + if (first_in_grp) { + // Binary search in ref. + int lo = 0, hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] < v) + lo = m + 1; + else + hi = m; + } + in_ref = (lo < n_ref) && (ref_col[lo] == v); + } + + // Count how many lanes ≥ this lane hold the same v. Keep the shuffle + // uniform across active lanes even though only unique, ref-absent + // group values consume the count. + int cnt = 0; +#pragma unroll + for (int lane_i = 0; lane_i < TIER0_GROUP_THRESHOLD; ++lane_i) { + int src_lane = (lane_i < n_grp) ? lane_i : 0; + float vi = __shfl_sync(active_mask, v_lane, src_lane); + if (first_in_grp && !in_ref && lane_i >= lane && lane_i < n_grp && + vi == v) { + ++cnt; + } + } + if (first_in_grp && !in_ref && cnt > 1) { + double t = (double)cnt; + local_tie += t * t * t - t; + } + } + + // Warp reduce. +#pragma unroll + for (int off = 16; off > 0; off >>= 1) + local_tie += __shfl_down_sync(0xffffffff, local_tie, off); + return local_tie; // meaningful on lane 0. +} + +__device__ __forceinline__ double tier0_tie_delta_warp( + const float* ref_col, int n_ref, float v_lane, int n_grp, + unsigned int active_mask) { + int lane = threadIdx.x & 31; + double local_delta = 0.0; + + if (lane < n_grp) { + float v = v_lane; + float prev_lane_v = + __shfl_sync(active_mask, v_lane, (lane > 0) ? lane - 1 : 0); + float v_prev = + (lane > 0) ? prev_lane_v : __int_as_float(0xff800000); // -INF + bool first_in_grp = (lane == 0) || (v != v_prev); + + int cnt_grp = 0; +#pragma unroll + for (int lane_i = 0; lane_i < TIER0_GROUP_THRESHOLD; ++lane_i) { + int src_lane = (lane_i < n_grp) ? 
lane_i : 0; + float vi = __shfl_sync(active_mask, v_lane, src_lane); + if (lane_i < n_grp && vi == v) ++cnt_grp; + } + + if (first_in_grp) { + int lo = 0, hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] < v) + lo = m + 1; + else + hi = m; + } + int ref_lb = lo; + hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int cnt_ref = lo - ref_lb; + + double combined = (double)(cnt_ref + cnt_grp); + if (combined > 1.0) { + local_delta += combined * combined * combined - combined; + } + if (cnt_ref > 1) { + double cr = (double)cnt_ref; + local_delta -= cr * cr * cr - cr; + } + } + } + +#pragma unroll + for (int off = 16; off > 0; off >>= 1) + local_delta += __shfl_down_sync(0xffffffff, local_delta, off); + return local_delta; // meaningful on lane 0. +} + +// ============================================================================ +// Tier 0 fused kernel: warp-per-(col, group) pair, 8 warps packed per block. +// +// Each warp independently: +// 1. Loads ≤ 32 group values into a single register (one per lane, +// padded with +INF). +// 2. Bitonic-sorts via __shfl_xor_sync — no smem, no __syncthreads. +// 3. Binary-searches into sorted ref for each lane's value and +// accumulates the rank-sum term. +// 4. Warp-shuffle reduces to lane 0 and writes rank_sums / tie_corr. +// +// 8 (col, group) pairs per block cuts block count 8× vs the block-per-pair +// Tier 1, and the lack of __syncthreads / smem sort lets each warp run +// independently at full throughput. +// +// Grid: (n_cols, ceil(n_groups / 8)), Block: 256. +// ============================================================================ + +__global__ void ovo_warp_sort_rank_kernel( + const float* __restrict__ ref_sorted, const float* __restrict__ grp_dense, + const int* __restrict__ grp_offsets, + const double* __restrict__ ref_tie_sums, double* __restrict__ rank_sums, + double* __restrict__ tie_corr, int n_ref, int n_all_grp, int n_cols, + int n_groups, bool compute_tie_corr) { + constexpr int WARPS_PER_BLOCK = 8; + int warp_id = threadIdx.x >> 5; + int lane = threadIdx.x & 31; + + int col = blockIdx.x; + int grp = blockIdx.y * WARPS_PER_BLOCK + warp_id; + if (col >= n_cols || grp >= n_groups) return; + + int g_start = grp_offsets[grp]; + int g_end = grp_offsets[grp + 1]; + int n_grp = g_end - g_start; + + // This kernel only handles groups that fit in a single warp (one value + // per lane). Larger groups are delegated to Tier 1/3 in a co-launched + // kernel; since each group owns its own row in rank_sums/tie_corr, the + // two kernels interlace into the output without conflict. + if (n_grp > TIER0_GROUP_THRESHOLD) return; + + if (n_grp == 0) { + if (lane == 0) { + rank_sums[grp * n_cols + col] = 0.0; + if (compute_tie_corr) tie_corr[grp * n_cols + col] = 1.0; + } + return; + } + + // One value per lane, pad with +INF so sort pushes them to the end. + const float POS_INF = __int_as_float(0x7f800000); + const float* grp_col = grp_dense + (long long)col * n_all_grp + g_start; + float x = (lane < n_grp) ? grp_col[lane] : POS_INF; + unsigned int active_mask = __ballot_sync(0xffffffff, lane < n_grp); + + // Warp-shuffle bitonic sort (ascending) — 32 elements in registers. + for (int k = 1; k <= 16; k <<= 1) { + for (int j = k; j > 0; j >>= 1) { + float y = __shfl_xor_sync(0xffffffff, x, j); + bool asc = (((lane & (k << 1)) == 0)); + bool take_min = (((lane & j) == 0) == asc); + x = take_min ? 
fminf(x, y) : fmaxf(x, y); + } + } + + // After sort, x[lane] holds the lane-th smallest group value (lanes + // ≥ n_grp hold +INF). Binary-search each value into the sorted ref. + const float* ref_col = ref_sorted + (long long)col * n_ref; + double local_sum = 0.0; + + if (lane < n_grp) { + float v = x; + // Lower bound in ref. + int lo = 0, hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] < v) + lo = m + 1; + else + hi = m; + } + int n_lt_ref = lo; + // Upper bound in ref. + hi = n_ref; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (ref_col[m] <= v) + lo = m + 1; + else + hi = m; + } + int n_eq_ref = lo - n_lt_ref; + + // In-group counts: in the sorted warp-register x, count lanes < this + // one that hold strictly less, and lanes with equal value. + int n_lt_grp = 0; + int n_eq_grp_offset = 0; // tied lanes strictly before this one + int n_eq_grp_after = 1; // count self +#pragma unroll + for (int lane_i = 0; lane_i < TIER0_GROUP_THRESHOLD; ++lane_i) { + if (lane_i >= n_grp) continue; + float vi = __shfl_sync(active_mask, v, lane_i); + if (lane_i < lane) { + if (vi < v) + ++n_lt_grp; + else if (vi == v) + ++n_eq_grp_offset; + } else if (lane_i > lane) { + if (vi == v) ++n_eq_grp_after; + } + } + int n_eq_grp_total = n_eq_grp_offset + n_eq_grp_after; + // Contribution: rank = n_lt_ref + n_lt_grp + (n_eq_ref + + // n_eq_grp_total + 1) / 2, but we sum per lane so each tie lane + // gets the same mid-rank. This matches the Tier 1 accumulation. + local_sum = (double)(n_lt_ref + n_lt_grp) + + ((double)(n_eq_ref + n_eq_grp_total) + 1.0) / 2.0; + } + + // Warp reduce. +#pragma unroll + for (int off = 16; off > 0; off >>= 1) + local_sum += __shfl_down_sync(0xffffffff, local_sum, off); + if (lane == 0) rank_sums[grp * n_cols + col] = local_sum; + + if (!compute_tie_corr) return; + + // Warp-scoped tie correction. + double tie_sum; + if (ref_tie_sums != nullptr) { + tie_sum = ref_tie_sums[col] + + tier0_tie_delta_warp(ref_col, n_ref, x, n_grp, active_mask); + } else { + tie_sum = tier0_tie_sum_warp(ref_col, n_ref, x, n_grp, active_mask); + } + if (lane == 0) { + int n = n_ref + n_grp; + double dn = (double)n; + double denom = dn * dn * dn - dn; + tie_corr[grp * n_cols + col] = + (denom > 0.0) ? (1.0 - tie_sum / denom) : 1.0; + } +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu index d25f7d0f..d314b289 100644 --- a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu @@ -1,70 +1,500 @@ #include +#include + +#include +#include +#include + #include "../nb_types.h" #include "kernels_wilcoxon.cuh" +#include "wilcoxon_fast_common.cuh" +#include "kernels_wilcoxon_ovo.cuh" +#include "wilcoxon_ovr_kernels.cuh" +#include "wilcoxon_ovo_kernels.cuh" using namespace nb::literals; -// Constants for kernel launch configuration -constexpr int WARP_SIZE = 32; -constexpr int MAX_THREADS_PER_BLOCK = 512; - -static inline int round_up_to_warp(int n) { - int rounded = ((n + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; - return (rounded < MAX_THREADS_PER_BLOCK) ? 
rounded : MAX_THREADS_PER_BLOCK; -} - -static inline void launch_tie_correction(const double* sorted_vals, - double* correction, int n_rows, - int n_cols, cudaStream_t stream) { +static inline void launch_ovr_rank_dense( + const float* sorted_vals, const int* sorter, const int* group_codes, + double* rank_sums, double* tie_corr, int n_rows, int n_cols, int n_groups, + bool compute_tie_corr, cudaStream_t stream) { int threads_per_block = round_up_to_warp(n_rows); dim3 block(threads_per_block); dim3 grid(n_cols); - tie_correction_kernel<<>>(sorted_vals, correction, - n_rows, n_cols); - CUDA_CHECK_LAST_ERROR(tie_correction_kernel); + ovr_rank_dense_kernel<<>>( + sorted_vals, sorter, group_codes, rank_sums, tie_corr, n_rows, n_cols, + n_groups, compute_tie_corr); + CUDA_CHECK_LAST_ERROR(ovr_rank_dense_kernel); } -static inline void launch_average_rank(const double* sorted_vals, - const int* sorter, double* ranks, - int n_rows, int n_cols, - cudaStream_t stream) { - int threads_per_block = round_up_to_warp(n_rows); - dim3 block(threads_per_block); - dim3 grid(n_cols); - average_rank_kernel<<>>(sorted_vals, sorter, ranks, - n_rows, n_cols); - CUDA_CHECK_LAST_ERROR(average_rank_kernel); +static void launch_ovr_rank_dense_streaming( + const float* block, const int* group_codes, double* rank_sums, + double* tie_corr, int n_rows, int n_cols, int n_groups, + bool compute_tie_corr, int sub_batch_cols, cudaStream_t upstream_stream) { + if (n_rows == 0 || n_cols == 0 || n_groups == 0) return; + if (sub_batch_cols <= 0) sub_batch_cols = SUB_BATCH_COLS; + + int n_streams = N_STREAMS; + if (n_cols < n_streams * sub_batch_cols) { + n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + } + + size_t sub_items = (size_t)n_rows * sub_batch_cols; + int sub_items_i32 = checked_cub_items(sub_items, "Dense OVR sub-batch"); + + size_t cub_temp_bytes = 0; + { + auto* fk = reinterpret_cast(1); + auto* iv = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortPairs( + nullptr, cub_temp_bytes, fk, fk, iv, iv, sub_items_i32, + sub_batch_cols, iv, iv + 1, BEGIN_BIT, END_BIT); + } + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; ++i) { + cudaStreamCreateWithFlags(&streams[i], cudaStreamNonBlocking); + } + + cudaEvent_t inputs_ready; + cudaEventCreateWithFlags(&inputs_ready, cudaEventDisableTiming); + cudaEventRecord(inputs_ready, upstream_stream); + for (int i = 0; i < n_streams; ++i) { + cudaStreamWaitEvent(streams[i], inputs_ready, 0); + } + + RmmScratchPool pool; + struct StreamBuf { + float* keys_out; + int* vals_in; + int* vals_out; + int* seg_offsets; + uint8_t* cub_temp; + double* sub_rank_sums; + double* sub_tie_corr; + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; ++s) { + bufs[s].keys_out = pool.alloc(sub_items); + bufs[s].vals_in = pool.alloc(sub_items); + bufs[s].vals_out = pool.alloc(sub_items); + bufs[s].seg_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].sub_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].sub_tie_corr = pool.alloc(sub_batch_cols); + } + + int tpb_rank = round_up_to_warp(n_rows); + bool use_gmem = false; + size_t smem_rank = ovr_smem_config(n_groups, use_gmem); + + int col = 0; + int batch_idx = 0; + while (col < n_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int sb_items = checked_int_product((size_t)n_rows, (size_t)sb_cols, + "Dense OVR active sub-batch"); + int s = batch_idx % n_streams; + cudaStream_t stream = streams[s]; + auto& buf = 
bufs[s]; + + upload_linear_offsets(buf.seg_offsets, sb_cols, n_rows, stream); + fill_row_indices_kernel<<>>( + buf.vals_in, n_rows, sb_cols); + CUDA_CHECK_LAST_ERROR(fill_row_indices_kernel); + + const float* keys_in = block + (size_t)col * n_rows; + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortPairs( + buf.cub_temp, temp, keys_in, buf.keys_out, buf.vals_in, + buf.vals_out, sb_items, sb_cols, buf.seg_offsets, + buf.seg_offsets + 1, BEGIN_BIT, END_BIT, stream); + + if (use_gmem) { + cudaMemsetAsync(buf.sub_rank_sums, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + } + rank_sums_from_sorted_kernel<<>>( + buf.keys_out, buf.vals_out, group_codes, buf.sub_rank_sums, + buf.sub_tie_corr, n_rows, sb_cols, n_groups, compute_tie_corr, + use_gmem); + CUDA_CHECK_LAST_ERROR(rank_sums_from_sorted_kernel); + + cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), + buf.sub_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpyAsync(tie_corr + col, buf.sub_tie_corr, + sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, + stream); + } + + col += sb_cols; + ++batch_idx; + } + + for (int s = 0; s < n_streams; ++s) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) { + throw std::runtime_error( + std::string("CUDA error in dense OVR streaming rank: ") + + cudaGetErrorString(err)); + } + } + cudaEventDestroy(inputs_ready); + for (int s = 0; s < n_streams; ++s) cudaStreamDestroy(streams[s]); +} + +static void launch_ovo_rank_dense_tiered_impl( + const float* ref_data, bool ref_is_sorted, const float* grp_data, + const int* grp_offsets, double* rank_sums, double* tie_corr, int n_ref, + int n_all_grp, int n_cols, int n_groups, bool compute_tie_corr, + int sub_batch_cols, cudaStream_t upstream_stream) { + if (n_cols == 0 || n_ref == 0 || n_all_grp == 0 || n_groups == 0) return; + if (sub_batch_cols <= 0) sub_batch_cols = SUB_BATCH_COLS; + + std::vector h_offsets(n_groups + 1); + cudaStreamSynchronize(upstream_stream); + cudaMemcpy(h_offsets.data(), grp_offsets, (n_groups + 1) * sizeof(int), + cudaMemcpyDeviceToHost); + auto t1 = make_tier1_config(h_offsets.data(), n_groups); + int max_grp_size = t1.max_grp_size; + bool use_tier1 = t1.any_above_t2 && t1.use_tier1; + bool needs_tier3 = t1.any_above_t2 && !use_tier1; + int padded_grp_size = t1.padded_grp_size; + int tier1_tpb = t1.tier1_tpb; + size_t tier1_smem = t1.tier1_smem; + + std::vector h_sort_group_ids; + int n_sort_groups = n_groups; + if (needs_tier3) { + h_sort_group_ids = make_sort_group_ids(h_offsets.data(), n_groups, + TIER2_GROUP_THRESHOLD); + n_sort_groups = (int)h_sort_group_ids.size(); + } + + int n_streams = N_STREAMS; + if (n_cols < n_streams * sub_batch_cols) + n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + + size_t sub_ref_items = (size_t)n_ref * sub_batch_cols; + int sub_ref_items_i32 = + checked_cub_items(sub_ref_items, "Dense OVO reference sub-batch"); + + size_t sub_grp_items = (size_t)n_all_grp * sub_batch_cols; + int sub_grp_items_i32 = + checked_cub_items(sub_grp_items, "Dense OVO group sub-batch"); + + size_t grp_cub_temp_bytes = 0; + if (needs_tier3) { + int max_grp_seg = + checked_int_product((size_t)n_sort_groups, (size_t)sub_batch_cols, + "Dense OVO group segment count"); + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, grp_cub_temp_bytes, fk, fk, sub_grp_items_i32, 
max_grp_seg, + doff, doff + 1, BEGIN_BIT, END_BIT); + } + size_t ref_cub_temp_bytes = 0; + if (!ref_is_sorted) { + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, ref_cub_temp_bytes, fk, fk, sub_ref_items_i32, + sub_batch_cols, doff, doff + 1, BEGIN_BIT, END_BIT); + } + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; ++i) { + cudaStreamCreateWithFlags(&streams[i], cudaStreamNonBlocking); + } + + cudaEvent_t inputs_ready; + cudaEventCreateWithFlags(&inputs_ready, cudaEventDisableTiming); + cudaEventRecord(inputs_ready, upstream_stream); + for (int i = 0; i < n_streams; ++i) { + cudaStreamWaitEvent(streams[i], inputs_ready, 0); + } + + RmmScratchPool pool; + int* d_sort_group_ids = nullptr; + if (needs_tier3) { + d_sort_group_ids = pool.alloc(h_sort_group_ids.size()); + cudaMemcpy(d_sort_group_ids, h_sort_group_ids.data(), + h_sort_group_ids.size() * sizeof(int), + cudaMemcpyHostToDevice); + } + + struct StreamBuf { + float* ref_sorted; + int* ref_seg_offsets; + uint8_t* ref_cub_temp; + float* grp_sorted; + int* grp_seg_offsets; + int* grp_seg_ends; + uint8_t* grp_cub_temp; + double* ref_tie_sums; + double* sub_rank_sums; + double* sub_tie_corr; + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; ++s) { + if (ref_is_sorted) { + bufs[s].ref_sorted = nullptr; + bufs[s].ref_seg_offsets = nullptr; + bufs[s].ref_cub_temp = nullptr; + } else { + bufs[s].ref_sorted = pool.alloc(sub_ref_items); + bufs[s].ref_seg_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].ref_cub_temp = pool.alloc(ref_cub_temp_bytes); + } + bufs[s].grp_cub_temp = + needs_tier3 ? pool.alloc(grp_cub_temp_bytes) : nullptr; + bufs[s].ref_tie_sums = + (compute_tie_corr && + (t1.use_tier0 || t1.any_tier0_64 || t1.any_tier2)) + ? 
pool.alloc(sub_batch_cols) + : nullptr; + bufs[s].sub_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].sub_tie_corr = + pool.alloc((size_t)n_groups * sub_batch_cols); + if (needs_tier3) { + bufs[s].grp_sorted = pool.alloc(sub_grp_items); + int max_seg = checked_int_product((size_t)n_sort_groups, + (size_t)sub_batch_cols, + "Dense OVO group segment buffer"); + bufs[s].grp_seg_offsets = pool.alloc(max_seg); + bufs[s].grp_seg_ends = pool.alloc(max_seg); + } else { + bufs[s].grp_sorted = nullptr; + bufs[s].grp_seg_offsets = nullptr; + bufs[s].grp_seg_ends = nullptr; + } + } + + int tpb_rank = + round_up_to_warp(std::min(max_grp_size, MAX_THREADS_PER_BLOCK)); + + int col = 0; + int batch_idx = 0; + while (col < n_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int sb_ref_items_actual = + checked_int_product((size_t)n_ref, (size_t)sb_cols, + "Dense OVO active reference sub-batch"); + int sb_grp_items_actual = + checked_int_product((size_t)n_all_grp, (size_t)sb_cols, + "Dense OVO active group sub-batch"); + int s = batch_idx % n_streams; + cudaStream_t stream = streams[s]; + auto& buf = bufs[s]; + const float* ref_sub = ref_data + (size_t)col * n_ref; + const float* grp_sub = grp_data + (size_t)col * n_all_grp; + if (!ref_is_sorted) { + upload_linear_offsets(buf.ref_seg_offsets, sb_cols, n_ref, stream); + size_t temp = ref_cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + buf.ref_cub_temp, temp, ref_sub, buf.ref_sorted, + sb_ref_items_actual, sb_cols, buf.ref_seg_offsets, + buf.ref_seg_offsets + 1, BEGIN_BIT, END_BIT, stream); + ref_sub = buf.ref_sorted; + } + + int skip_le = 0; + bool run_tier0 = t1.use_tier0; + bool run_tier0_64 = t1.any_tier0_64; + bool run_tier2 = t1.any_tier2; + if (compute_tie_corr && (run_tier0 || run_tier0_64 || run_tier2)) { + launch_ref_tie_sums(ref_sub, buf.ref_tie_sums, n_ref, sb_cols, + stream); + } + if (run_tier0) { + launch_tier0(ref_sub, grp_sub, grp_offsets, buf.ref_tie_sums, + buf.sub_rank_sums, buf.sub_tie_corr, n_ref, n_all_grp, + sb_cols, n_groups, compute_tie_corr, stream); + if (t1.any_above_t0) skip_le = TIER0_GROUP_THRESHOLD; + } + if (run_tier0_64) { + launch_tier0_64(ref_sub, grp_sub, grp_offsets, buf.ref_tie_sums, + buf.sub_rank_sums, buf.sub_tie_corr, n_ref, + n_all_grp, sb_cols, n_groups, compute_tie_corr, + skip_le, stream); + if (t1.max_grp_size > TIER0_64_GROUP_THRESHOLD) { + skip_le = TIER0_64_GROUP_THRESHOLD; + } + } + if (run_tier2) { + launch_tier2_medium(ref_sub, grp_sub, grp_offsets, buf.ref_tie_sums, + buf.sub_rank_sums, buf.sub_tie_corr, n_ref, + n_all_grp, sb_cols, n_groups, compute_tie_corr, + skip_le, stream); + } + + int upper_skip_le = t1.any_above_t2 ? 
TIER2_GROUP_THRESHOLD : skip_le; + if (t1.any_above_t2 && use_tier1) { + dim3 grid(sb_cols, n_groups); + ovo_fused_sort_rank_kernel<<>>( + ref_sub, grp_sub, grp_offsets, buf.sub_rank_sums, + buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, + compute_tie_corr, padded_grp_size, upper_skip_le); + CUDA_CHECK_LAST_ERROR(ovo_fused_sort_rank_kernel); + } else if (needs_tier3) { + int sb_grp_seg = + checked_int_product((size_t)n_sort_groups, (size_t)sb_cols, + "Dense OVO active group segment count"); + int blk = (sb_grp_seg + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + build_tier3_seg_begin_end_offsets_kernel<<>>( + grp_offsets, d_sort_group_ids, buf.grp_seg_offsets, + buf.grp_seg_ends, n_all_grp, n_sort_groups, sb_cols); + CUDA_CHECK_LAST_ERROR(build_tier3_seg_begin_end_offsets_kernel); + + size_t temp = grp_cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + buf.grp_cub_temp, temp, grp_sub, buf.grp_sorted, + sb_grp_items_actual, sb_grp_seg, buf.grp_seg_offsets, + buf.grp_seg_ends, BEGIN_BIT, END_BIT, stream); + + dim3 grid(sb_cols, n_groups); + batched_rank_sums_presorted_kernel<<>>( + ref_sub, buf.grp_sorted, grp_offsets, buf.sub_rank_sums, + buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, + compute_tie_corr, upper_skip_le); + CUDA_CHECK_LAST_ERROR(batched_rank_sums_presorted_kernel); + } + + cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), + buf.sub_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpy2DAsync(tie_corr + col, n_cols * sizeof(double), + buf.sub_tie_corr, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + } + + col += sb_cols; + ++batch_idx; + } + + for (int s = 0; s < n_streams; ++s) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) { + throw std::runtime_error( + std::string("CUDA error in dense OVO tiered rank: ") + + cudaGetErrorString(err)); + } + } + cudaEventDestroy(inputs_ready); + for (int s = 0; s < n_streams; ++s) cudaStreamDestroy(streams[s]); +} + +static void launch_ovo_rank_dense_tiered( + const float* ref_sorted, const float* grp_data, const int* grp_offsets, + double* rank_sums, double* tie_corr, int n_ref, int n_all_grp, int n_cols, + int n_groups, bool compute_tie_corr, int sub_batch_cols, + cudaStream_t upstream_stream) { + launch_ovo_rank_dense_tiered_impl(ref_sorted, true, grp_data, grp_offsets, + rank_sums, tie_corr, n_ref, n_all_grp, + n_cols, n_groups, compute_tie_corr, + sub_batch_cols, upstream_stream); +} + +static void launch_ovo_rank_dense_tiered_unsorted_ref( + const float* ref_data, const float* grp_data, const int* grp_offsets, + double* rank_sums, double* tie_corr, int n_ref, int n_all_grp, int n_cols, + int n_groups, bool compute_tie_corr, int sub_batch_cols, + cudaStream_t upstream_stream) { + launch_ovo_rank_dense_tiered_impl(ref_data, false, grp_data, grp_offsets, + rank_sums, tie_corr, n_ref, n_all_grp, + n_cols, n_groups, compute_tie_corr, + sub_batch_cols, upstream_stream); } template void register_bindings(nb::module_& m) { m.doc() = "CUDA kernels for Wilcoxon rank-sum test"; - // Tie correction kernel m.def( - "tie_correction", - [](gpu_array_f sorted_vals, - gpu_array correction, int n_rows, int n_cols, + "ovo_rank_dense_tiered", + [](gpu_array_f ref_sorted, + gpu_array_f grp_data, + gpu_array_c grp_offsets, + gpu_array_c rank_sums, + gpu_array_c tie_corr, int n_ref, int n_all_grp, + int n_cols, int n_groups, bool 
compute_tie_corr, int sub_batch_cols, std::uintptr_t stream) { - launch_tie_correction(sorted_vals.data(), correction.data(), n_rows, - n_cols, (cudaStream_t)stream); + launch_ovo_rank_dense_tiered(ref_sorted.data(), grp_data.data(), + grp_offsets.data(), rank_sums.data(), + tie_corr.data(), n_ref, n_all_grp, + n_cols, n_groups, compute_tie_corr, + sub_batch_cols, (cudaStream_t)stream); }, - "sorted_vals"_a, "correction"_a, nb::kw_only(), "n_rows"_a, "n_cols"_a, + "ref_sorted"_a, "grp_data"_a, "grp_offsets"_a, "rank_sums"_a, + "tie_corr"_a, nb::kw_only(), "n_ref"_a, "n_all_grp"_a, "n_cols"_a, + "n_groups"_a, "compute_tie_corr"_a, "sub_batch_cols"_a = SUB_BATCH_COLS, "stream"_a = 0); - // Average rank kernel m.def( - "average_rank", - [](gpu_array_f sorted_vals, + "ovo_rank_dense_tiered_unsorted_ref", + [](gpu_array_f ref_data, + gpu_array_f grp_data, + gpu_array_c grp_offsets, + gpu_array_c rank_sums, + gpu_array_c tie_corr, int n_ref, int n_all_grp, + int n_cols, int n_groups, bool compute_tie_corr, int sub_batch_cols, + std::uintptr_t stream) { + launch_ovo_rank_dense_tiered_unsorted_ref( + ref_data.data(), grp_data.data(), grp_offsets.data(), + rank_sums.data(), tie_corr.data(), n_ref, n_all_grp, n_cols, + n_groups, compute_tie_corr, sub_batch_cols, + (cudaStream_t)stream); + }, + "ref_data"_a, "grp_data"_a, "grp_offsets"_a, "rank_sums"_a, + "tie_corr"_a, nb::kw_only(), "n_ref"_a, "n_all_grp"_a, "n_cols"_a, + "n_groups"_a, "compute_tie_corr"_a, "sub_batch_cols"_a = SUB_BATCH_COLS, + "stream"_a = 0); + + m.def( + "ovr_rank_dense", + [](gpu_array_f sorted_vals, gpu_array_f sorter, - gpu_array_f ranks, int n_rows, int n_cols, + gpu_array_c group_codes, + gpu_array_c rank_sums, + gpu_array_c tie_corr, int n_rows, int n_cols, + int n_groups, bool compute_tie_corr, std::uintptr_t stream) { + launch_ovr_rank_dense(sorted_vals.data(), sorter.data(), + group_codes.data(), rank_sums.data(), + tie_corr.data(), n_rows, n_cols, n_groups, + compute_tie_corr, (cudaStream_t)stream); + }, + "sorted_vals"_a, "sorter"_a, "group_codes"_a, "rank_sums"_a, + "tie_corr"_a, nb::kw_only(), "n_rows"_a, "n_cols"_a, "n_groups"_a, + "compute_tie_corr"_a, "stream"_a = 0); + + m.def( + "ovr_rank_dense_streaming", + [](gpu_array_f block, + gpu_array_c group_codes, + gpu_array_c rank_sums, + gpu_array_c tie_corr, int n_rows, int n_cols, + int n_groups, bool compute_tie_corr, int sub_batch_cols, std::uintptr_t stream) { - launch_average_rank(sorted_vals.data(), sorter.data(), ranks.data(), - n_rows, n_cols, (cudaStream_t)stream); + launch_ovr_rank_dense_streaming( + block.data(), group_codes.data(), rank_sums.data(), + tie_corr.data(), n_rows, n_cols, n_groups, compute_tie_corr, + sub_batch_cols, (cudaStream_t)stream); }, - "sorted_vals"_a, "sorter"_a, "ranks"_a, nb::kw_only(), "n_rows"_a, - "n_cols"_a, "stream"_a = 0); + "block"_a, "group_codes"_a, "rank_sums"_a, "tie_corr"_a, nb::kw_only(), + "n_rows"_a, "n_cols"_a, "n_groups"_a, "compute_tie_corr"_a, + "sub_batch_cols"_a = SUB_BATCH_COLS, "stream"_a = 0); } NB_MODULE(_wilcoxon_cuda, m) { diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh new file mode 100644 index 00000000..ec723b55 --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh @@ -0,0 +1,381 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include + +#include "../nb_types.h" // for CUDA_CHECK_LAST_ERROR + +void* wilcoxon_rmm_allocate(size_t bytes); +void 
wilcoxon_rmm_deallocate(void* ptr, size_t bytes);
+
+constexpr int WARP_SIZE = 32;
+constexpr int MAX_THREADS_PER_BLOCK = 512;
+constexpr int N_STREAMS = 4;
+constexpr int SUB_BATCH_COLS = 64;
+constexpr int BEGIN_BIT = 0;
+constexpr int END_BIT = 32;
+// Default thread-per-block for utility kernels (extract, gather, offsets,
+// etc.).
+constexpr int UTIL_BLOCK_SIZE = 256;
+// Scratch slots for warp-level reduction (one slot per warp, 32 warps max).
+constexpr int WARP_REDUCE_BUF = 32;
+// Max group size for the super-fast "warp-per-(col,group)" fused kernel
+// (Tier 0). Each warp sorts and ranks one (col, group) pair entirely in
+// registers via warp-shuffle bitonic sort — no smem sort buffer, no
+// __syncthreads(). Blocks pack 8 warps so block launch overhead is
+// amortised 8× across (col, group) work items. This path is the fast
+// route for per-celltype perturbation-style workloads where most test
+// groups have only a few dozen cells.
+constexpr int TIER0_GROUP_THRESHOLD = 32;
+// Second small-group tier for perturbation workloads where most groups are
+// slightly larger than one warp. Uses one compact shared-memory sort block per
+// (column, group), avoiding the heavier Tier 2 in-group scan.
+constexpr int TIER0_64_GROUP_THRESHOLD = 64;
+// Medium-group cutoff for the unsorted direct-rank kernel. For perturbation
+// workloads most groups sit below this range, where avoiding a full smem
+// bitonic sort wins despite the O(n^2) in-group count.
+constexpr int TIER2_GROUP_THRESHOLD = 512;
+// Max group size for the fused smem-sort rank kernel (Tier 1 fast path).
+// Beyond this, fall back to CUB segmented sort + binary-search rank kernel.
+constexpr int TIER1_GROUP_THRESHOLD = 2500;
+// Per-stream dense slab budget (float32 items). Dynamic sub-batching sizes
+// each group's column batch so that (n_g × eff_sb_cols) ≤ this. Bigger =
+// fewer kernel launches; smaller = less per-stream memory. 128M items × 4B =
+// 512 MB per stream dense slab + same for sorted copy ≈ 1 GB / stream.
+constexpr size_t GROUP_DENSE_BUDGET_ITEMS = 128 * 1024 * 1024;
+
+static inline size_t wilcoxon_max_smem_per_block() {
+    int device = 0;
+    int max_smem = 0;
+    cudaGetDevice(&device);
+    cudaDeviceGetAttribute(&max_smem, cudaDevAttrMaxSharedMemoryPerBlock,
+                           device);
+    return (size_t)max_smem;
+}
+
+static inline int checked_cub_items(size_t count, const char* context) {
+    if (count > (size_t)std::numeric_limits<int>::max()) {
+        throw std::runtime_error(std::string(context) +
+                                 " exceeds CUB int item limit");
+    }
+    return (int)count;
+}
+
+static inline int checked_int_span(size_t count, const char* context) {
+    if (count > (size_t)std::numeric_limits<int>::max()) {
+        throw std::runtime_error(std::string(context) +
+                                 " exceeds int32 offset limit");
+    }
+    return (int)count;
+}
+
+static inline int checked_int_product(size_t a, size_t b, const char* context) {
+    if (a != 0 && b > (size_t)std::numeric_limits<int>::max() / a) {
+        throw std::runtime_error(std::string(context) +
+                                 " exceeds int32 item limit");
+    }
+    return (int)(a * b);
+}
+
+// ---------------------------------------------------------------------------
+// RAII guard for cudaHostRegister. Unregisters on scope exit even when an
+// exception unwinds — prevents leaked host pinning on stream-sync failures.
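+// A minimal usage sketch (hypothetical call site, not part of this patch;
+// `host_buf` and `n_bytes` are illustrative names):
+//   HostRegisterGuard pin(host_buf, n_bytes, cudaHostRegisterMapped);
+//   ... enqueue kernels / copies that read the registered range ...
+//   // pages are unregistered automatically when `pin` leaves scope, even if
+//   // a later stream-synchronize error throws.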
+// --------------------------------------------------------------------------- +struct HostRegisterGuard { + void* ptr = nullptr; + + HostRegisterGuard() = default; + HostRegisterGuard(void* p, size_t bytes, unsigned int flags = 0) { + if (p && bytes > 0) { + cudaError_t err = cudaHostRegister(p, bytes, flags); + if (err != cudaSuccess) { + // Already-registered memory belongs to another owner; use it + // without unregistering here. Other failures mean mapped reads + // would be unsafe, so surface them immediately. + if (err == cudaErrorHostMemoryAlreadyRegistered) { + cudaGetLastError(); // clear sticky error flag + } else { + throw std::runtime_error( + std::string("cudaHostRegister failed (") + + std::to_string((size_t)bytes) + + " bytes, flags=" + std::to_string(flags) + + "): " + cudaGetErrorString(err)); + } + } else { + ptr = p; + } + } + } + ~HostRegisterGuard() { + if (ptr) cudaHostUnregister(ptr); + } + HostRegisterGuard(const HostRegisterGuard&) = delete; + HostRegisterGuard& operator=(const HostRegisterGuard&) = delete; + HostRegisterGuard(HostRegisterGuard&& other) noexcept : ptr(other.ptr) { + other.ptr = nullptr; + } + HostRegisterGuard& operator=(HostRegisterGuard&& other) noexcept { + if (this != &other) { + if (ptr) cudaHostUnregister(ptr); + ptr = other.ptr; + other.ptr = nullptr; + } + return *this; + } +}; + +// --------------------------------------------------------------------------- +// Small allocation pool for temporary CUDA buffers. Uses the current RMM device +// resource so scratch participates in the same pool as CuPy/RAPIDS allocations. +// --------------------------------------------------------------------------- +struct RmmScratchPool { + struct Allocation { + void* ptr = nullptr; + size_t bytes = 0; + }; + std::vector bufs; + + ~RmmScratchPool() { + for (Allocation alloc : bufs) { + if (!alloc.ptr) continue; + wilcoxon_rmm_deallocate(alloc.ptr, alloc.bytes); + } + } + + template + T* alloc(size_t count) { + if (count == 0) count = 1; + if (count > std::numeric_limits::max() / sizeof(T)) { + throw std::runtime_error( + "Wilcoxon scratch allocation size overflow"); + } + size_t bytes = count * sizeof(T); + void* ptr = wilcoxon_rmm_allocate(bytes); + bufs.push_back({ptr, bytes}); + return static_cast(ptr); + } +}; + +struct ScopedCudaBuffer { + void* ptr = nullptr; + size_t bytes = 0; + + explicit ScopedCudaBuffer(size_t requested_bytes) { + bytes = requested_bytes == 0 ? 1 : requested_bytes; + ptr = wilcoxon_rmm_allocate(bytes); + } + + ~ScopedCudaBuffer() { + if (!ptr) return; + wilcoxon_rmm_deallocate(ptr, bytes); + } + + void* data() { + return ptr; + } + + ScopedCudaBuffer(const ScopedCudaBuffer&) = delete; + ScopedCudaBuffer& operator=(const ScopedCudaBuffer&) = delete; +}; + +static inline int round_up_to_warp(int n) { + int rounded = ((n + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; + return (rounded < MAX_THREADS_PER_BLOCK) ? rounded : MAX_THREADS_PER_BLOCK; +} + +/** Fill linear segment offsets [0, stride, 2*stride, ..., n_segments*stride] + * on-device. One thread per output slot. */ +__global__ void fill_linear_offsets_kernel(int* __restrict__ out, + int n_segments, int stride) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i <= n_segments) out[i] = i * stride; +} + +/** Fill per-row stats codes for a pack of K groups. + * Given pack_grp_offsets (size K+1, relative to pack start), write + * stats_codes[r] = base_slot + group_idx_of_row_r for r in [0, pack_n_rows). + * Binary search within the K+1 offsets. 
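+ * Illustrative example (hypothetical sizes): with K = 3 pack groups of sizes
+ * {2, 3, 1}, pack_grp_offsets = {0, 2, 5, 6} and base_slot = 4, rows 0..5
+ * receive stats_codes {4, 4, 5, 5, 5, 6}.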
*/ +__global__ void fill_pack_stats_codes_kernel( + const int* __restrict__ pack_grp_offsets, int* __restrict__ stats_codes, + int K, int base_slot) { + int r = blockIdx.x * blockDim.x + threadIdx.x; + int pack_n_rows = pack_grp_offsets[K]; + if (r >= pack_n_rows) return; + int lo = 0, hi = K; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (pack_grp_offsets[m + 1] <= r) + lo = m + 1; + else + hi = m; + } + stats_codes[r] = base_slot + lo; +} + +/** Rebase a slice of indptr: out[i] = indptr[col + i] - indptr[col]. + * Grid-strided: supports arbitrary `count` (no single-block thread limit). + * Templated so that 64-bit global indptrs can produce 32-bit pack-local + * indptrs (per-pack nnz always fits in int32 thanks to the memory budget). + */ +template +__global__ void rebase_indptr_kernel(const IdxIn* __restrict__ indptr, + IdxOut* __restrict__ out, int col, + int count) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < count) out[i] = (IdxOut)(indptr[col + i] - indptr[col]); +} + +/** Fused gather + cast-to-float32 + stats accumulation, reading from mapped + * pinned host memory. Block-per-row; threads in the block cooperate on the + * row's nnz. Each nnz is read from host over PCIe exactly once — no + * intermediate native-dtype GPU buffer, no second GPU pass. + * + * h_data / h_indices: device-accessible pointers into mapped pinned host + * memory (cudaHostRegisterMapped). + * d_indptr_full: full-matrix indptr on device. + * d_row_ids: rows to gather (size n_target_rows). + * d_out_indptr: pre-computed compacted indptr, size n_target_rows+1 with + * out_indptr[i+1] - out_indptr[i] equal to the source row's + * nnz. + * + * Slot dispatch: + * d_stats_codes != nullptr → slot = d_stats_codes[r]; otherwise slot = + * fixed_slot (used for the Ref phase where every row maps to the same + * slot). slot ∉ [0, n_groups_stats) skips accumulation. + */ +template +__global__ void csr_gather_cast_accumulate_mapped_kernel( + const InT* __restrict__ h_data, const IndexT* __restrict__ h_indices, + const IndptrT* __restrict__ d_indptr_full, + const int* __restrict__ d_row_ids, const int* __restrict__ d_out_indptr, + const int* __restrict__ d_stats_codes, int fixed_slot, + float* __restrict__ d_out_data_f32, int* __restrict__ d_out_indices, + double* __restrict__ group_sums, double* __restrict__ group_sq_sums, + double* __restrict__ group_nnz, int n_target_rows, int n_cols, + int n_groups_stats, bool compute_sums, bool compute_sq_sums, + bool compute_nnz) { + int r = blockIdx.x; + if (r >= n_target_rows) return; + int src_row = d_row_ids[r]; + IndptrT rs = d_indptr_full[src_row]; + IndptrT re = d_indptr_full[src_row + 1]; + int row_nnz = (int)(re - rs); + int ds = d_out_indptr[r]; + int slot = (d_stats_codes != nullptr) ? d_stats_codes[r] : fixed_slot; + bool accumulate = (slot >= 0 && slot < n_groups_stats); + for (int i = threadIdx.x; i < row_nnz; i += blockDim.x) { + InT v_in = h_data[rs + i]; + int c = (int)h_indices[rs + i]; + double v = (double)v_in; + d_out_data_f32[ds + i] = (float)v_in; + d_out_indices[ds + i] = c; + if (accumulate) { + if (compute_sums) { + atomicAdd(&group_sums[(size_t)slot * n_cols + c], v); + } + if (compute_sq_sums) { + atomicAdd(&group_sq_sums[(size_t)slot * n_cols + c], v * v); + } + if (compute_nnz && v != 0.0) { + atomicAdd(&group_nnz[(size_t)slot * n_cols + c], 1.0); + } + } + } +} + +/** Fill linear segment offsets [0, stride, 2*stride, ...] on device. + * Runs on the supplied stream so it doesn't serialize multi-stream pipelines. 
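+ * Illustrative example: n_segments = 4, stride = 10 writes {0, 10, 20, 30, 40}.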
+ */ +static inline void upload_linear_offsets(int* d_offsets, int n_segments, + int stride, cudaStream_t stream) { + int count = n_segments + 1; + int blk = (count + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + fill_linear_offsets_kernel<<>>( + d_offsets, n_segments, stride); + CUDA_CHECK_LAST_ERROR(fill_linear_offsets_kernel); +} + +// ============================================================================ +// CSR → dense F-order extraction (templated on data type) +// ============================================================================ + +template +__global__ void csr_extract_dense_kernel(const T* __restrict__ data, + const int* __restrict__ indices, + const int* __restrict__ indptr, + const int* __restrict__ row_ids, + T* __restrict__ out, int n_target, + int col_start, int col_stop) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid >= n_target) return; + + int row = row_ids[tid]; + int rs = indptr[row]; + int re = indptr[row + 1]; + + int lo = rs, hi = re; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (indices[m] < col_start) + lo = m + 1; + else + hi = m; + } + + for (int p = lo; p < re; ++p) { + int c = indices[p]; + if (c >= col_stop) break; + out[(long long)(c - col_start) * n_target + tid] = data[p]; + } +} + +template +__global__ void csr_extract_dense_identity_rows_kernel( + const T* __restrict__ data, const int* __restrict__ indices, + const int* __restrict__ indptr, T* __restrict__ out, int n_target, + int col_start, int col_stop) { + int row = blockIdx.x; + if (row >= n_target) return; + + int rs = indptr[row]; + int re = indptr[row + 1]; + + int lo = rs, hi = re; + while (lo < hi) { + int m = lo + ((hi - lo) >> 1); + if (indices[m] < col_start) + lo = m + 1; + else + hi = m; + } + + for (int p = lo + threadIdx.x; p < re; p += blockDim.x) { + int c = indices[p]; + if (c >= col_stop) break; + out[(long long)(c - col_start) * n_target + row] = data[p]; + } +} + +template +__global__ void csr_extract_dense_identity_rows_unsorted_kernel( + const T* __restrict__ data, const int* __restrict__ indices, + const int* __restrict__ indptr, T* __restrict__ out, int n_target, + int col_start, int col_stop) { + int row = blockIdx.x; + if (row >= n_target) return; + + int rs = indptr[row]; + int re = indptr[row + 1]; + + for (int p = rs + threadIdx.x; p < re; p += blockDim.x) { + int c = indices[p]; + if (c >= col_start && c < col_stop) { + out[(long long)(c - col_start) * n_target + row] = data[p]; + } + } +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh new file mode 100644 index 00000000..b60b87ff --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_device_sparse.cuh @@ -0,0 +1,548 @@ +#pragma once + +/** + * CSR-direct OVO streaming pipeline. + * + * One C++ call does everything. Reference rows are extracted and sorted once + * across all columns, then each group sub-batch ranks against that cached + * reference slice. This mirrors the fast host-CSR path and avoids redoing the + * reference dense extraction + segmented sort for every column sub-batch. 
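+ * Illustrative sizing (hypothetical numbers): with n_ref = 50,000 reference
+ * cells and 24 GB of free device memory, the free/3 budget divided by
+ * n_ref * 8 B per column (dense + sorted copies) caches roughly 20,000
+ * columns per reference pass.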
+ */ +static void ovo_streaming_csr_impl( + const float* csr_data, const int* csr_indices, const int* csr_indptr, + const int* ref_row_ids, const int* grp_row_ids, const int* grp_offsets, + double* rank_sums, double* tie_corr, int n_ref, int n_all_grp, int n_cols, + int n_groups, bool compute_tie_corr, int sub_batch_cols) { + if (n_cols == 0 || n_ref == 0 || n_all_grp == 0) return; + + std::vector h_offsets(n_groups + 1); + cudaMemcpy(h_offsets.data(), grp_offsets, (n_groups + 1) * sizeof(int), + cudaMemcpyDeviceToHost); + auto t1 = make_tier1_config(h_offsets.data(), n_groups); + int max_grp_size = t1.max_grp_size; + bool use_tier1 = t1.any_above_t2 && t1.use_tier1; + bool needs_tier3 = t1.any_above_t2 && !use_tier1; + int padded_grp_size = t1.padded_grp_size; + int tier1_tpb = t1.tier1_tpb; + size_t tier1_smem = t1.tier1_smem; + std::vector h_sort_group_ids; + int n_sort_groups = n_groups; + if (needs_tier3) { + h_sort_group_ids = make_sort_group_ids(h_offsets.data(), n_groups, + TIER2_GROUP_THRESHOLD); + n_sort_groups = (int)h_sort_group_ids.size(); + } + + int n_streams = N_STREAMS; + if (n_cols < n_streams * sub_batch_cols) + n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + + size_t sub_grp_items = (size_t)n_all_grp * sub_batch_cols; + + size_t max_ref_cols = 2147483647LL / (size_t)n_ref; + if (max_ref_cols == 0) { + throw std::runtime_error( + "OVO device CSR reference group exceeds CUB int item limit"); + } + int ref_cache_cols = std::min(n_cols, (int)max_ref_cols); + size_t free_bytes = 0; + size_t total_bytes = 0; + if (cudaMemGetInfo(&free_bytes, &total_bytes) == cudaSuccess) { + size_t bytes_per_col = (size_t)n_ref * sizeof(float) * 2; + size_t target_bytes = free_bytes / 3; + if (bytes_per_col > 0 && target_bytes >= bytes_per_col) { + size_t mem_cols = target_bytes / bytes_per_col; + if (mem_cols > 0 && mem_cols < (size_t)ref_cache_cols) { + ref_cache_cols = (int)mem_cols; + } + } + } + if (ref_cache_cols < 1) ref_cache_cols = 1; + + RmmScratchPool pool; + + size_t cub_temp_bytes = 0; + if (needs_tier3) { + size_t cub_grp_bytes = 0; + int sub_grp_items_i32 = + checked_cub_items(sub_grp_items, "OVO device CSR group sub-batch"); + int max_grp_seg = + checked_int_product((size_t)n_sort_groups, (size_t)sub_batch_cols, + "OVO device CSR group segment count"); + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, cub_grp_bytes, fk, fk, sub_grp_items_i32, max_grp_seg, + doff, doff + 1, BEGIN_BIT, END_BIT); + cub_temp_bytes = cub_grp_bytes; + } + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + cudaStream_t ref_stream; + cudaStreamCreateWithFlags(&ref_stream, cudaStreamNonBlocking); + + int* d_sort_group_ids = nullptr; + if (needs_tier3) { + d_sort_group_ids = pool.alloc(h_sort_group_ids.size()); + cudaMemcpy(d_sort_group_ids, h_sort_group_ids.data(), + h_sort_group_ids.size() * sizeof(int), + cudaMemcpyHostToDevice); + } + + struct StreamBuf { + float* grp_dense; + float* grp_sorted; + int* grp_seg_offsets; + int* grp_seg_ends; + uint8_t* cub_temp; + double* ref_tie_sums; + double* sub_rank_sums; + double* sub_tie_corr; + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + bufs[s].grp_dense = pool.alloc(sub_grp_items); + bufs[s].cub_temp = + needs_tier3 ? pool.alloc(cub_temp_bytes) : nullptr; + bufs[s].ref_tie_sums = + (compute_tie_corr && + (t1.use_tier0 || t1.any_tier0_64 || t1.any_tier2)) + ? 
pool.alloc(sub_batch_cols) + : nullptr; + bufs[s].sub_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].sub_tie_corr = + pool.alloc((size_t)n_groups * sub_batch_cols); + if (needs_tier3) { + bufs[s].grp_sorted = pool.alloc(sub_grp_items); + int max_seg = checked_int_product( + (size_t)n_sort_groups, (size_t)sub_batch_cols, + "OVO device CSR group segment buffer"); + bufs[s].grp_seg_offsets = pool.alloc(max_seg); + bufs[s].grp_seg_ends = pool.alloc(max_seg); + } else { + bufs[s].grp_sorted = nullptr; + bufs[s].grp_seg_offsets = nullptr; + bufs[s].grp_seg_ends = nullptr; + } + } + + int tpb_extract = round_up_to_warp(std::max(n_ref, n_all_grp)); + int tpb_rank = + round_up_to_warp(std::min(max_grp_size, MAX_THREADS_PER_BLOCK)); + + for (int cache_col = 0; cache_col < n_cols; cache_col += ref_cache_cols) { + int cache_cols = std::min(ref_cache_cols, n_cols - cache_col); + size_t cache_ref_items = (size_t)n_ref * cache_cols; + int cache_ref_items_i32 = checked_cub_items( + cache_ref_items, "OVO device CSR reference cache"); + + ScopedCudaBuffer ref_dense_buf(cache_ref_items * sizeof(float)); + ScopedCudaBuffer ref_sorted_buf(cache_ref_items * sizeof(float)); + ScopedCudaBuffer ref_seg_offsets_buf((size_t)(cache_cols + 1) * + sizeof(int)); + float* d_ref_dense = (float*)ref_dense_buf.data(); + float* d_ref_sorted = (float*)ref_sorted_buf.data(); + int* d_ref_seg_offsets = (int*)ref_seg_offsets_buf.data(); + + cudaMemsetAsync(d_ref_dense, 0, cache_ref_items * sizeof(float), + ref_stream); + int tpb_ref_extract = round_up_to_warp(n_ref); + int ref_blk = (n_ref + tpb_ref_extract - 1) / tpb_ref_extract; + csr_extract_dense_kernel<<>>( + csr_data, csr_indices, csr_indptr, ref_row_ids, d_ref_dense, n_ref, + cache_col, cache_col + cache_cols); + CUDA_CHECK_LAST_ERROR(csr_extract_dense_kernel); + + upload_linear_offsets(d_ref_seg_offsets, cache_cols, n_ref, ref_stream); + + size_t ref_cub_bytes = 0; + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, ref_cub_bytes, fk, fk, cache_ref_items_i32, cache_cols, + doff, doff + 1, BEGIN_BIT, END_BIT); + ScopedCudaBuffer ref_cub_temp_buf(ref_cub_bytes); + size_t ref_temp = ref_cub_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + ref_cub_temp_buf.data(), ref_temp, d_ref_dense, d_ref_sorted, + cache_ref_items_i32, cache_cols, d_ref_seg_offsets, + d_ref_seg_offsets + 1, BEGIN_BIT, END_BIT, ref_stream); + cudaStreamSynchronize(ref_stream); + + int col = cache_col; + int cache_stop = cache_col + cache_cols; + int batch_idx = 0; + while (col < cache_stop) { + int sb_cols = std::min(sub_batch_cols, cache_stop - col); + int sb_grp_items_actual = + checked_int_product((size_t)n_all_grp, (size_t)sb_cols, + "OVO device CSR active group sub-batch"); + int s = batch_idx % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + const float* ref_sub = + d_ref_sorted + (size_t)(col - cache_col) * n_ref; + + cudaMemsetAsync(buf.grp_dense, 0, + sb_grp_items_actual * sizeof(float), stream); + { + int blk = (n_all_grp + tpb_extract - 1) / tpb_extract; + csr_extract_dense_kernel<<>>( + csr_data, csr_indices, csr_indptr, grp_row_ids, + buf.grp_dense, n_all_grp, col, col + sb_cols); + CUDA_CHECK_LAST_ERROR(csr_extract_dense_kernel); + } + + int skip_le = 0; + bool run_tier0 = t1.use_tier0; + bool run_tier0_64 = t1.any_tier0_64; + bool run_tier2 = t1.any_tier2; + if (compute_tie_corr && (run_tier0 || run_tier0_64 || run_tier2)) { + launch_ref_tie_sums(ref_sub, buf.ref_tie_sums, 
n_ref, sb_cols, + stream); + } + if (run_tier0) { + launch_tier0(ref_sub, buf.grp_dense, grp_offsets, + buf.ref_tie_sums, buf.sub_rank_sums, + buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, + n_groups, compute_tie_corr, stream); + if (t1.any_above_t0) skip_le = TIER0_GROUP_THRESHOLD; + } + if (run_tier0_64) { + launch_tier0_64(ref_sub, buf.grp_dense, grp_offsets, + buf.ref_tie_sums, buf.sub_rank_sums, + buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, + n_groups, compute_tie_corr, skip_le, stream); + if (t1.max_grp_size > TIER0_64_GROUP_THRESHOLD) { + skip_le = TIER0_64_GROUP_THRESHOLD; + } + } + if (run_tier2) { + launch_tier2_medium( + ref_sub, buf.grp_dense, grp_offsets, buf.ref_tie_sums, + buf.sub_rank_sums, buf.sub_tie_corr, n_ref, n_all_grp, + sb_cols, n_groups, compute_tie_corr, skip_le, stream); + } + + int upper_skip_le = + t1.any_above_t2 ? TIER2_GROUP_THRESHOLD : skip_le; + if (t1.any_above_t2 && use_tier1) { + dim3 grid(sb_cols, n_groups); + ovo_fused_sort_rank_kernel<<>>( + ref_sub, buf.grp_dense, grp_offsets, buf.sub_rank_sums, + buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, + compute_tie_corr, padded_grp_size, upper_skip_le); + CUDA_CHECK_LAST_ERROR(ovo_fused_sort_rank_kernel); + } else if (needs_tier3) { + int sb_grp_seg = checked_int_product( + (size_t)n_sort_groups, (size_t)sb_cols, + "OVO device CSR active group segment count"); + { + int blk = + (sb_grp_seg + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + build_tier3_seg_begin_end_offsets_kernel<<< + blk, UTIL_BLOCK_SIZE, 0, stream>>>( + grp_offsets, d_sort_group_ids, buf.grp_seg_offsets, + buf.grp_seg_ends, n_all_grp, n_sort_groups, sb_cols); + CUDA_CHECK_LAST_ERROR( + build_tier3_seg_begin_end_offsets_kernel); + } + { + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + buf.cub_temp, temp, buf.grp_dense, buf.grp_sorted, + sb_grp_items_actual, sb_grp_seg, buf.grp_seg_offsets, + buf.grp_seg_ends, BEGIN_BIT, END_BIT, stream); + } + { + dim3 grid(sb_cols, n_groups); + batched_rank_sums_presorted_kernel<<>>( + ref_sub, buf.grp_sorted, grp_offsets, buf.sub_rank_sums, + buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, + compute_tie_corr, upper_skip_le); + CUDA_CHECK_LAST_ERROR(batched_rank_sums_presorted_kernel); + } + } + + cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), + buf.sub_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpy2DAsync(tie_corr + col, n_cols * sizeof(double), + buf.sub_tie_corr, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + } + + col += sb_cols; + batch_idx++; + } + + for (int s = 0; s < n_streams; s++) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) + throw std::runtime_error( + std::string("CUDA error in OVO device CSR streaming: ") + + cudaGetErrorString(err)); + } + } + for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); + cudaStreamDestroy(ref_stream); +} + +/** + * CSC-direct OVO streaming pipeline. + * + * Like the CSR variant, but extracts rows via lookup maps so it can operate on + * native CSC input without converting the whole matrix. 
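+ * Illustrative row map (assuming rows outside the selection carry a negative
+ * sentinel): with 6 matrix rows and reference rows {1, 4}, ref_row_map would
+ * read {-1, 0, -1, -1, 1, -1}, so each CSC entry scatters straight into its
+ * slot of the dense reference slab.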
+ */ +static void ovo_streaming_csc_impl( + const float* csc_data, const int* csc_indices, const int* csc_indptr, + const int* ref_row_map, const int* grp_row_map, const int* grp_offsets, + double* rank_sums, double* tie_corr, int n_ref, int n_all_grp, int n_cols, + int n_groups, bool compute_tie_corr, int sub_batch_cols) { + if (n_cols == 0 || n_ref == 0 || n_all_grp == 0) return; + + std::vector h_offsets(n_groups + 1); + cudaMemcpy(h_offsets.data(), grp_offsets, (n_groups + 1) * sizeof(int), + cudaMemcpyDeviceToHost); + auto t1 = make_tier1_config(h_offsets.data(), n_groups); + int max_grp_size = t1.max_grp_size; + bool use_tier1 = t1.any_above_t2 && t1.use_tier1; + bool needs_tier3 = t1.any_above_t2 && !use_tier1; + int padded_grp_size = t1.padded_grp_size; + int tier1_tpb = t1.tier1_tpb; + size_t tier1_smem = t1.tier1_smem; + std::vector h_sort_group_ids; + int n_sort_groups = n_groups; + if (needs_tier3) { + h_sort_group_ids = make_sort_group_ids(h_offsets.data(), n_groups, + TIER2_GROUP_THRESHOLD); + n_sort_groups = (int)h_sort_group_ids.size(); + } + + int n_streams = N_STREAMS; + if (n_cols < n_streams * sub_batch_cols) + n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + + size_t sub_ref_items = (size_t)n_ref * sub_batch_cols; + size_t sub_grp_items = (size_t)n_all_grp * sub_batch_cols; + int sub_ref_items_i32 = + checked_cub_items(sub_ref_items, "OVO device CSC reference sub-batch"); + int sub_grp_items_i32 = + checked_cub_items(sub_grp_items, "OVO device CSC group sub-batch"); + + size_t cub_ref_bytes = 0; + { + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, cub_ref_bytes, fk, fk, sub_ref_items_i32, sub_batch_cols, + doff, doff + 1, BEGIN_BIT, END_BIT); + } + size_t cub_temp_bytes = cub_ref_bytes; + if (needs_tier3) { + size_t cub_grp_bytes = 0; + int max_grp_seg = + checked_int_product((size_t)n_sort_groups, (size_t)sub_batch_cols, + "OVO device CSC group segment count"); + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, cub_grp_bytes, fk, fk, sub_grp_items_i32, max_grp_seg, + doff, doff + 1, BEGIN_BIT, END_BIT); + cub_temp_bytes = std::max(cub_ref_bytes, cub_grp_bytes); + } + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + + RmmScratchPool pool; + int* d_sort_group_ids = nullptr; + if (needs_tier3) { + d_sort_group_ids = pool.alloc(h_sort_group_ids.size()); + cudaMemcpy(d_sort_group_ids, h_sort_group_ids.data(), + h_sort_group_ids.size() * sizeof(int), + cudaMemcpyHostToDevice); + } + + struct StreamBuf { + float* ref_dense; + float* ref_sorted; + float* grp_dense; + float* grp_sorted; + int* ref_seg_offsets; + int* grp_seg_offsets; + int* grp_seg_ends; + uint8_t* cub_temp; + double* ref_tie_sums; + double* sub_rank_sums; + double* sub_tie_corr; + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + bufs[s].ref_dense = pool.alloc(sub_ref_items); + bufs[s].ref_sorted = pool.alloc(sub_ref_items); + bufs[s].grp_dense = pool.alloc(sub_grp_items); + bufs[s].ref_seg_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].ref_tie_sums = + (compute_tie_corr && + (t1.use_tier0 || t1.any_tier0_64 || t1.any_tier2)) + ? 
pool.alloc(sub_batch_cols) + : nullptr; + bufs[s].sub_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].sub_tie_corr = + pool.alloc((size_t)n_groups * sub_batch_cols); + if (needs_tier3) { + bufs[s].grp_sorted = pool.alloc(sub_grp_items); + int max_grp_seg = checked_int_product( + (size_t)n_sort_groups, (size_t)sub_batch_cols, + "OVO device CSC group segment buffer"); + bufs[s].grp_seg_offsets = pool.alloc(max_grp_seg); + bufs[s].grp_seg_ends = pool.alloc(max_grp_seg); + } else { + bufs[s].grp_sorted = nullptr; + bufs[s].grp_seg_offsets = nullptr; + bufs[s].grp_seg_ends = nullptr; + } + } + + int tpb_rank = + round_up_to_warp(std::min(max_grp_size, MAX_THREADS_PER_BLOCK)); + + int col = 0; + int batch_idx = 0; + while (col < n_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int sb_ref_items_actual = + checked_int_product((size_t)n_ref, (size_t)sb_cols, + "OVO device CSC active reference sub-batch"); + int sb_grp_items_actual = + checked_int_product((size_t)n_all_grp, (size_t)sb_cols, + "OVO device CSC active group sub-batch"); + int s = batch_idx % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + + cudaMemsetAsync(buf.ref_dense, 0, sb_ref_items_actual * sizeof(float), + stream); + csc_extract_mapped_kernel<<>>( + csc_data, csc_indices, csc_indptr, ref_row_map, buf.ref_dense, + n_ref, col); + CUDA_CHECK_LAST_ERROR(csc_extract_mapped_kernel); + upload_linear_offsets(buf.ref_seg_offsets, sb_cols, n_ref, stream); + { + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + buf.cub_temp, temp, buf.ref_dense, buf.ref_sorted, + sb_ref_items_actual, sb_cols, buf.ref_seg_offsets, + buf.ref_seg_offsets + 1, BEGIN_BIT, END_BIT, stream); + } + + cudaMemsetAsync(buf.grp_dense, 0, sb_grp_items_actual * sizeof(float), + stream); + csc_extract_mapped_kernel<<>>( + csc_data, csc_indices, csc_indptr, grp_row_map, buf.grp_dense, + n_all_grp, col); + CUDA_CHECK_LAST_ERROR(csc_extract_mapped_kernel); + + int skip_le = 0; + bool run_tier0 = t1.use_tier0; + bool run_tier0_64 = t1.any_tier0_64; + bool run_tier2 = t1.any_tier2; + if (compute_tie_corr && (run_tier0 || run_tier0_64 || run_tier2)) { + launch_ref_tie_sums(buf.ref_sorted, buf.ref_tie_sums, n_ref, + sb_cols, stream); + } + if (run_tier0) { + launch_tier0(buf.ref_sorted, buf.grp_dense, grp_offsets, + buf.ref_tie_sums, buf.sub_rank_sums, buf.sub_tie_corr, + n_ref, n_all_grp, sb_cols, n_groups, compute_tie_corr, + stream); + if (t1.any_above_t0) skip_le = TIER0_GROUP_THRESHOLD; + } + if (run_tier0_64) { + launch_tier0_64(buf.ref_sorted, buf.grp_dense, grp_offsets, + buf.ref_tie_sums, buf.sub_rank_sums, + buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, + n_groups, compute_tie_corr, skip_le, stream); + if (t1.max_grp_size > TIER0_64_GROUP_THRESHOLD) { + skip_le = TIER0_64_GROUP_THRESHOLD; + } + } + if (run_tier2) { + launch_tier2_medium(buf.ref_sorted, buf.grp_dense, grp_offsets, + buf.ref_tie_sums, buf.sub_rank_sums, + buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, + n_groups, compute_tie_corr, skip_le, stream); + } + + int upper_skip_le = t1.any_above_t2 ? 
TIER2_GROUP_THRESHOLD : skip_le; + if (t1.any_above_t2 && use_tier1) { + dim3 grid(sb_cols, n_groups); + ovo_fused_sort_rank_kernel<<>>( + buf.ref_sorted, buf.grp_dense, grp_offsets, buf.sub_rank_sums, + buf.sub_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, + compute_tie_corr, padded_grp_size, upper_skip_le); + CUDA_CHECK_LAST_ERROR(ovo_fused_sort_rank_kernel); + } else if (needs_tier3) { + int sb_grp_seg = checked_int_product( + (size_t)n_sort_groups, (size_t)sb_cols, + "OVO device CSC active group segment count"); + { + int blk = (sb_grp_seg + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + build_tier3_seg_begin_end_offsets_kernel<<>>( + grp_offsets, d_sort_group_ids, buf.grp_seg_offsets, + buf.grp_seg_ends, n_all_grp, n_sort_groups, sb_cols); + CUDA_CHECK_LAST_ERROR(build_tier3_seg_begin_end_offsets_kernel); + } + { + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + buf.cub_temp, temp, buf.grp_dense, buf.grp_sorted, + sb_grp_items_actual, sb_grp_seg, buf.grp_seg_offsets, + buf.grp_seg_ends, BEGIN_BIT, END_BIT, stream); + } + { + dim3 grid(sb_cols, n_groups); + batched_rank_sums_presorted_kernel<<>>( + buf.ref_sorted, buf.grp_sorted, grp_offsets, + buf.sub_rank_sums, buf.sub_tie_corr, n_ref, n_all_grp, + sb_cols, n_groups, compute_tie_corr, upper_skip_le); + CUDA_CHECK_LAST_ERROR(batched_rank_sums_presorted_kernel); + } + } + + cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), + buf.sub_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpy2DAsync(tie_corr + col, n_cols * sizeof(double), + buf.sub_tie_corr, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + } + + col += sb_cols; + batch_idx++; + } + + for (int s = 0; s < n_streams; s++) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) + throw std::runtime_error( + std::string("CUDA error in OVO device CSC streaming: ") + + cudaGetErrorString(err)); + } + for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh new file mode 100644 index 00000000..53f27bbe --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_host_sparse.cuh @@ -0,0 +1,917 @@ +#pragma once + +/** + * Host-streaming CSC OVO pipeline. + * + * CSC arrays live on host. Only the sparse data for each sub-batch of + * columns is transferred to GPU. Row maps + group offsets are uploaded once. + * Results are written back to host per sub-batch. 
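+ * Illustrative transfer (hypothetical sizes): a 64-column sub-batch holding
+ * 2M stored values moves about 2M * (4 + 4) B = 16 MB of data plus indices
+ * per H2D step when InT and IndexT are 32-bit.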
+ */ +template +static void ovo_streaming_csc_host_impl( + const InT* h_data, const IndexT* h_indices, const IndptrT* h_indptr, + const int* h_ref_row_map, const int* h_grp_row_map, + const int* h_grp_offsets, const int* h_stats_codes, double* d_rank_sums, + double* d_tie_corr, double* d_group_sums, double* d_group_sq_sums, + double* d_group_nnz, int n_ref, int n_all_grp, int n_rows, int n_cols, + int n_groups, int n_groups_stats, bool compute_tie_corr, + bool compute_sq_sums, bool compute_nnz, int sub_batch_cols) { + if (n_cols == 0 || n_ref == 0 || n_all_grp == 0) return; + + // ---- Tier dispatch from host offsets ---- + auto t1 = make_tier1_config(h_grp_offsets, n_groups); + int max_grp_size = t1.max_grp_size; + bool use_tier1 = t1.any_above_t2 && t1.use_tier1; + bool needs_tier3 = t1.any_above_t2 && !use_tier1; + int padded_grp_size = t1.padded_grp_size; + int tier1_tpb = t1.tier1_tpb; + size_t tier1_smem = t1.tier1_smem; + std::vector h_sort_group_ids; + int n_sort_groups = n_groups; + if (needs_tier3) { + h_sort_group_ids = + make_sort_group_ids(h_grp_offsets, n_groups, TIER2_GROUP_THRESHOLD); + n_sort_groups = (int)h_sort_group_ids.size(); + } + + int n_streams = N_STREAMS; + if (n_cols < n_streams * sub_batch_cols) + n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + + size_t sub_ref_items = (size_t)n_ref * sub_batch_cols; + size_t sub_grp_items = (size_t)n_all_grp * sub_batch_cols; + int sub_ref_items_i32 = + checked_cub_items(sub_ref_items, "OVO host CSC reference sub-batch"); + int sub_grp_items_i32 = + checked_cub_items(sub_grp_items, "OVO host CSC group sub-batch"); + + // CUB temp + size_t cub_ref_bytes = 0; + { + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, cub_ref_bytes, fk, fk, sub_ref_items_i32, sub_batch_cols, + doff, doff + 1, BEGIN_BIT, END_BIT); + } + size_t cub_temp_bytes = cub_ref_bytes; + if (needs_tier3) { + size_t cub_grp_bytes = 0; + int max_grp_seg = + checked_int_product((size_t)n_sort_groups, (size_t)sub_batch_cols, + "OVO host CSC group segment count"); + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, cub_grp_bytes, fk, fk, sub_grp_items_i32, max_grp_seg, + doff, doff + 1, BEGIN_BIT, END_BIT); + cub_temp_bytes = std::max(cub_ref_bytes, cub_grp_bytes); + } + + // Max nnz across any sub-batch for sparse transfer buffer sizing + size_t max_nnz = 0; + for (int c = 0; c < n_cols; c += sub_batch_cols) { + int sb = std::min(sub_batch_cols, n_cols - c); + size_t nnz = (size_t)(h_indptr[c + sb] - h_indptr[c]); + if (nnz > max_nnz) max_nnz = nnz; + } + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + + RmmScratchPool pool; + + int n_batches = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + std::vector h_all_offsets((size_t)n_batches * (sub_batch_cols + 1), 0); + for (int b = 0; b < n_batches; b++) { + int col_start = b * sub_batch_cols; + int sb = std::min(sub_batch_cols, n_cols - col_start); + IndptrT ptr_start = h_indptr[col_start]; + int* off = &h_all_offsets[(size_t)b * (sub_batch_cols + 1)]; + for (int i = 0; i <= sb; i++) { + off[i] = + checked_int_span((size_t)(h_indptr[col_start + i] - ptr_start), + "OVO host CSC rebased column offsets"); + } + } + int* d_all_offsets = + pool.alloc((size_t)n_batches * (sub_batch_cols + 1)); + cudaMemcpy(d_all_offsets, h_all_offsets.data(), + h_all_offsets.size() * sizeof(int), cudaMemcpyHostToDevice); + + // 
GPU copies of row maps + group offsets + stats codes (uploaded once) + int* d_ref_row_map = pool.alloc(n_rows); + int* d_grp_row_map = pool.alloc(n_rows); + int* d_grp_offsets = pool.alloc(n_groups + 1); + int* d_stats_codes = pool.alloc(n_rows); + int* d_sort_group_ids = nullptr; + cudaMemcpy(d_ref_row_map, h_ref_row_map, n_rows * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(d_grp_row_map, h_grp_row_map, n_rows * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(d_grp_offsets, h_grp_offsets, (n_groups + 1) * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(d_stats_codes, h_stats_codes, n_rows * sizeof(int), + cudaMemcpyHostToDevice); + if (needs_tier3) { + d_sort_group_ids = pool.alloc(h_sort_group_ids.size()); + cudaMemcpy(d_sort_group_ids, h_sort_group_ids.data(), + h_sort_group_ids.size() * sizeof(int), + cudaMemcpyHostToDevice); + } + + struct StreamBuf { + InT* d_sparse_data_orig; + float* d_sparse_data_f32; + IndexT* d_sparse_indices; + int* d_indptr; + float* ref_dense; + float* ref_sorted; + float* grp_dense; + float* grp_sorted; + int* ref_seg_offsets; + int* grp_seg_offsets; + int* grp_seg_ends; + uint8_t* cub_temp; + double* ref_tie_sums; + double* d_rank_sums; + double* d_tie_corr; + double* d_group_sums; + double* d_group_sq_sums; + double* d_group_nnz; + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + bufs[s].d_sparse_data_orig = pool.alloc(max_nnz); + bufs[s].d_sparse_data_f32 = pool.alloc(max_nnz); + bufs[s].d_sparse_indices = pool.alloc(max_nnz); + bufs[s].d_indptr = pool.alloc(sub_batch_cols + 1); + bufs[s].ref_dense = pool.alloc(sub_ref_items); + bufs[s].ref_sorted = pool.alloc(sub_ref_items); + bufs[s].grp_dense = pool.alloc(sub_grp_items); + bufs[s].ref_seg_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].ref_tie_sums = + (compute_tie_corr && + (t1.use_tier0 || t1.any_tier0_64 || t1.any_tier2)) + ? pool.alloc(sub_batch_cols) + : nullptr; + bufs[s].d_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].d_tie_corr = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].d_group_sums = + pool.alloc((size_t)n_groups_stats * sub_batch_cols); + bufs[s].d_group_sq_sums = pool.alloc( + compute_sq_sums ? (size_t)n_groups_stats * sub_batch_cols : 1); + bufs[s].d_group_nnz = pool.alloc( + compute_nnz ? (size_t)n_groups_stats * sub_batch_cols : 1); + if (needs_tier3) { + bufs[s].grp_sorted = pool.alloc(sub_grp_items); + int max_grp_seg = checked_int_product( + (size_t)n_sort_groups, (size_t)sub_batch_cols, + "OVO host CSC stream group segment count"); + bufs[s].grp_seg_offsets = pool.alloc(max_grp_seg); + bufs[s].grp_seg_ends = pool.alloc(max_grp_seg); + } else { + bufs[s].grp_sorted = nullptr; + bufs[s].grp_seg_offsets = nullptr; + bufs[s].grp_seg_ends = nullptr; + } + } + + int tpb_rank = + round_up_to_warp(std::min(max_grp_size, MAX_THREADS_PER_BLOCK)); + bool cast_use_gmem = false; + size_t smem_cast = cast_accumulate_smem_config( + n_groups_stats, compute_sq_sums, compute_nnz, cast_use_gmem); + + // Pin only the sparse input arrays; outputs live on the device. 
+ size_t total_nnz = (size_t)h_indptr[n_cols]; + HostRegisterGuard _pin_data(const_cast(h_data), + total_nnz * sizeof(InT)); + HostRegisterGuard _pin_indices(const_cast(h_indices), + total_nnz * sizeof(IndexT)); + + int col = 0; + int batch_idx = 0; + while (col < n_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int sb_ref_actual = + checked_int_product((size_t)n_ref, (size_t)sb_cols, + "OVO host CSC active reference sub-batch"); + int sb_grp_actual = + checked_int_product((size_t)n_all_grp, (size_t)sb_cols, + "OVO host CSC active group sub-batch"); + int s = batch_idx % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + + // ---- H2D: sparse data for this column range (native dtype) ---- + IndptrT ptr_start = h_indptr[col]; + IndptrT ptr_end = h_indptr[col + sb_cols]; + size_t nnz = (size_t)(ptr_end - ptr_start); + checked_int_span(nnz, "OVO host CSC active batch nnz"); + cudaMemcpyAsync(buf.d_sparse_data_orig, h_data + ptr_start, + nnz * sizeof(InT), cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(buf.d_sparse_indices, h_indices + ptr_start, + nnz * sizeof(IndexT), cudaMemcpyHostToDevice, stream); + int* src = d_all_offsets + (size_t)batch_idx * (sub_batch_cols + 1); + cudaMemcpyAsync(buf.d_indptr, src, (sb_cols + 1) * sizeof(int), + cudaMemcpyDeviceToDevice, stream); + + // ---- Cast to float32 for sort + accumulate stats in float64 ---- + launch_ovr_cast_and_accumulate_sparse( + buf.d_sparse_data_orig, buf.d_sparse_data_f32, buf.d_sparse_indices, + buf.d_indptr, d_stats_codes, buf.d_group_sums, buf.d_group_sq_sums, + buf.d_group_nnz, sb_cols, n_groups_stats, compute_sq_sums, + compute_nnz, UTIL_BLOCK_SIZE, smem_cast, cast_use_gmem, stream); + + // ---- Extract ref from CSC via row_map, sort ---- + cudaMemsetAsync(buf.ref_dense, 0, sb_ref_actual * sizeof(float), + stream); + csc_extract_mapped_kernel<<>>( + buf.d_sparse_data_f32, buf.d_sparse_indices, buf.d_indptr, + d_ref_row_map, buf.ref_dense, n_ref, 0); + CUDA_CHECK_LAST_ERROR(csc_extract_mapped_kernel); + upload_linear_offsets(buf.ref_seg_offsets, sb_cols, n_ref, stream); + { + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + buf.cub_temp, temp, buf.ref_dense, buf.ref_sorted, + sb_ref_actual, sb_cols, buf.ref_seg_offsets, + buf.ref_seg_offsets + 1, BEGIN_BIT, END_BIT, stream); + } + + // ---- Extract grp from CSC via row_map ---- + cudaMemsetAsync(buf.grp_dense, 0, sb_grp_actual * sizeof(float), + stream); + csc_extract_mapped_kernel<<>>( + buf.d_sparse_data_f32, buf.d_sparse_indices, buf.d_indptr, + d_grp_row_map, buf.grp_dense, n_all_grp, 0); + CUDA_CHECK_LAST_ERROR(csc_extract_mapped_kernel); + + // ---- Tier dispatch: sort grp + rank ---- + int skip_le = 0; + bool run_tier0 = t1.use_tier0; + bool run_tier0_64 = t1.any_tier0_64; + bool run_tier2 = t1.any_tier2; + if (compute_tie_corr && (run_tier0 || run_tier0_64 || run_tier2)) { + launch_ref_tie_sums(buf.ref_sorted, buf.ref_tie_sums, n_ref, + sb_cols, stream); + } + if (run_tier0) { + launch_tier0(buf.ref_sorted, buf.grp_dense, d_grp_offsets, + buf.ref_tie_sums, buf.d_rank_sums, buf.d_tie_corr, + n_ref, n_all_grp, sb_cols, n_groups, compute_tie_corr, + stream); + if (t1.any_above_t0) skip_le = TIER0_GROUP_THRESHOLD; + } + if (run_tier0_64) { + launch_tier0_64(buf.ref_sorted, buf.grp_dense, d_grp_offsets, + buf.ref_tie_sums, buf.d_rank_sums, buf.d_tie_corr, + n_ref, n_all_grp, sb_cols, n_groups, + compute_tie_corr, skip_le, stream); + if (t1.max_grp_size > TIER0_64_GROUP_THRESHOLD) { + skip_le = 
TIER0_64_GROUP_THRESHOLD; + } + } + if (run_tier2) { + launch_tier2_medium(buf.ref_sorted, buf.grp_dense, d_grp_offsets, + buf.ref_tie_sums, buf.d_rank_sums, + buf.d_tie_corr, n_ref, n_all_grp, sb_cols, + n_groups, compute_tie_corr, skip_le, stream); + } + + int upper_skip_le = t1.any_above_t2 ? TIER2_GROUP_THRESHOLD : skip_le; + if (t1.any_above_t2 && use_tier1) { + dim3 grid(sb_cols, n_groups); + ovo_fused_sort_rank_kernel<<>>( + buf.ref_sorted, buf.grp_dense, d_grp_offsets, buf.d_rank_sums, + buf.d_tie_corr, n_ref, n_all_grp, sb_cols, n_groups, + compute_tie_corr, padded_grp_size, upper_skip_le); + CUDA_CHECK_LAST_ERROR(ovo_fused_sort_rank_kernel); + } else if (needs_tier3) { + int sb_grp_seg = + checked_int_product((size_t)n_sort_groups, (size_t)sb_cols, + "OVO host CSC active group segment count"); + { + int blk = (sb_grp_seg + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + build_tier3_seg_begin_end_offsets_kernel<<>>( + d_grp_offsets, d_sort_group_ids, buf.grp_seg_offsets, + buf.grp_seg_ends, n_all_grp, n_sort_groups, sb_cols); + CUDA_CHECK_LAST_ERROR(build_tier3_seg_begin_end_offsets_kernel); + } + { + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + buf.cub_temp, temp, buf.grp_dense, buf.grp_sorted, + sb_grp_actual, sb_grp_seg, buf.grp_seg_offsets, + buf.grp_seg_ends, BEGIN_BIT, END_BIT, stream); + } + { + dim3 grid(sb_cols, n_groups); + batched_rank_sums_presorted_kernel<<>>( + buf.ref_sorted, buf.grp_sorted, d_grp_offsets, + buf.d_rank_sums, buf.d_tie_corr, n_ref, n_all_grp, sb_cols, + n_groups, compute_tie_corr, upper_skip_le); + CUDA_CHECK_LAST_ERROR(batched_rank_sums_presorted_kernel); + } + } + + // ---- D2D: scatter sub-batch results into caller's GPU buffers ---- + cudaMemcpy2DAsync(d_rank_sums + col, n_cols * sizeof(double), + buf.d_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpy2DAsync(d_tie_corr + col, n_cols * sizeof(double), + buf.d_tie_corr, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + } + cudaMemcpy2DAsync(d_group_sums + col, n_cols * sizeof(double), + buf.d_group_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups_stats, + cudaMemcpyDeviceToDevice, stream); + if (compute_sq_sums) { + cudaMemcpy2DAsync(d_group_sq_sums + col, n_cols * sizeof(double), + buf.d_group_sq_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups_stats, + cudaMemcpyDeviceToDevice, stream); + } + if (compute_nnz) { + cudaMemcpy2DAsync(d_group_nnz + col, n_cols * sizeof(double), + buf.d_group_nnz, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups_stats, + cudaMemcpyDeviceToDevice, stream); + } + + col += sb_cols; + batch_idx++; + } + + for (int s = 0; s < n_streams; s++) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) + throw std::runtime_error( + std::string("CUDA error in wilcoxon streaming: ") + + cudaGetErrorString(err)); + } + + for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); +} + +/** + * Host CSR OVO pipeline — zero-copy mapped full-CSR with GPU-side row gather. + * + * Setup: pin the full host CSR with cudaHostRegisterMapped, upload the full + * indptr (small) + row_ids + pre-computed compacted indptrs. Each pack + * gathers only its rows over PCIe via a UVA kernel — the full matrix is never + * transferred to GPU. 
+ * + * Phase 1 (Ref): fused gather + cast + stats over ref rows; segmented sort + * to d_ref_sorted (cached for the whole run). + * Phase 2 (per pack, round-robin across N_STREAMS): + * 1. rebase per-pack output indptr from the pre-uploaded global compacted + * indptr. + * 2. rebase per-pack group offsets + build per-row stats codes. + * 3. csr_gather_cast_accumulate_mapped_kernel — one PCIe pass, writes + * compacted f32 data + indices and accumulates per-group stats. + * 4. Per sub-batch: extract dense → sort → rank vs ref_sorted → scatter. + * + * Memory: d_ref_sorted (n_ref × n_cols × 4B) + N_STREAMS pack buffers sized + * for max_pack_rows × sb_cols (dense) and max_pack_nnz (compacted CSR). + * Full CSR stays on host (pinned-mapped). + */ +template +static void ovo_streaming_csr_host_impl( + const InT* h_data, const IndexT* h_indices, const IndptrT* h_indptr, + int n_full_rows, const int* h_ref_row_ids, int n_ref, + const int* h_grp_row_ids, const int* h_grp_offsets, int n_all_grp, + int n_test, double* d_rank_sums, double* d_tie_corr, double* d_group_sums, + double* d_group_sq_sums, double* d_group_nnz, int n_cols, + int n_groups_stats, bool compute_tie_corr, bool compute_sq_sums, + bool compute_nnz, bool compute_sums, int sub_batch_cols) { + if (n_cols == 0 || n_ref == 0 || n_test == 0 || n_all_grp == 0) return; + + // ---- Pre-compute compacted indptrs on host (O(n_ref + n_all_grp)) ---- + // Use IndptrT for the global compacted indptr because the grp side can + // exceed 2^31 nnz on very large / dense matrices. Ref always fits in + // int32 since n_ref × n_cols ≪ 2B; keeping int32 there matches the + // downstream CUB segmented-sort temp sizing. + std::vector h_ref_indptr_compact(n_ref + 1); + h_ref_indptr_compact[0] = 0; + for (int i = 0; i < n_ref; i++) { + int r = h_ref_row_ids[i]; + IndptrT row_nnz = h_indptr[r + 1] - h_indptr[r]; + if ((size_t)row_nnz > (size_t)std::numeric_limits::max()) { + throw std::runtime_error( + "OVO host CSR reference row exceeds int32 compacted nnz limit"); + } + int nnz_i = (int)row_nnz; + if ((size_t)h_ref_indptr_compact[i] + (size_t)nnz_i > + (size_t)std::numeric_limits::max()) { + throw std::runtime_error( + "OVO host CSR reference compacted nnz exceeds int32 limit"); + } + h_ref_indptr_compact[i + 1] = h_ref_indptr_compact[i] + nnz_i; + } + int ref_nnz = h_ref_indptr_compact[n_ref]; + + // grp: compacted indptr over concatenated test-group rows (IndptrT). 
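The loop just below builds that prefix sum; for intuition, a tiny worked example of the compaction with invented values:

// Hypothetical illustration (numbers invented for clarity):
//   full indptr        = {0, 3, 5, 9, 10}   // 4 rows in the host CSR
//   selected row ids   = {2, 0}             // rows taken in pack order
//   per-row nnz        = {9 - 5, 3 - 0} = {4, 3}
//   compacted indptr   = {0, 4, 7}          // prefix sum over the selection
// Row i of the gathered CSR then occupies [compact[i], compact[i+1]) in the
// compacted data/indices buffers, regardless of where it sat in the full matrix.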
+ std::vector h_grp_indptr_compact(n_all_grp + 1); + h_grp_indptr_compact[0] = 0; + for (int i = 0; i < n_all_grp; i++) { + int r = h_grp_row_ids[i]; + IndptrT nnz_i = h_indptr[r + 1] - h_indptr[r]; + h_grp_indptr_compact[i + 1] = h_grp_indptr_compact[i] + nnz_i; + } + + // ---- Build packs (same rule as grp_impl, but uses compacted indptr) ---- + struct Pack { + int first; + int end; + int n_rows; + size_t nnz; + int sb_cols; + }; + std::vector packs; + int max_pack_rows = 0; + size_t max_pack_nnz = 0; + int max_pack_K = 0; + int max_pack_items = 0; + int max_pack_sb_cols = sub_batch_cols; + { + int target_packs = N_STREAMS; + int target_rows = (n_all_grp + target_packs - 1) / target_packs; + if (target_rows < 1) target_rows = 1; + size_t budget_cap_rows = + GROUP_DENSE_BUDGET_ITEMS / (size_t)sub_batch_cols; + if ((size_t)target_rows > budget_cap_rows) + target_rows = (int)budget_cap_rows; + + int cur_first = 0; + int cur_rows = 0; + size_t cur_nnz = 0; + for (int g = 0; g < n_test; g++) { + int n_g = h_grp_offsets[g + 1] - h_grp_offsets[g]; + size_t nnz_g = (size_t)(h_grp_indptr_compact[h_grp_offsets[g + 1]] - + h_grp_indptr_compact[h_grp_offsets[g]]); + int new_rows = cur_rows + n_g; + bool can_add = (cur_rows == 0) || (new_rows <= target_rows); + if (!can_add) { + size_t sb_size = + std::min((size_t)n_cols, + GROUP_DENSE_BUDGET_ITEMS / (size_t)cur_rows); + if (sb_size < (size_t)sub_batch_cols) sb_size = sub_batch_cols; + packs.push_back( + {cur_first, g, cur_rows, cur_nnz, (int)sb_size}); + cur_first = g; + cur_rows = n_g; + cur_nnz = nnz_g; + } else { + cur_rows = new_rows; + cur_nnz += nnz_g; + } + } + if (cur_rows > 0) { + size_t sb_size = std::min( + (size_t)n_cols, GROUP_DENSE_BUDGET_ITEMS / (size_t)cur_rows); + if (sb_size < (size_t)sub_batch_cols) sb_size = sub_batch_cols; + packs.push_back( + {cur_first, n_test, cur_rows, cur_nnz, (int)sb_size}); + } + } + for (const Pack& pk : packs) { + int K = pk.end - pk.first; + if (pk.n_rows > max_pack_rows) max_pack_rows = pk.n_rows; + if (pk.nnz > max_pack_nnz) max_pack_nnz = pk.nnz; + if (K > max_pack_K) max_pack_K = K; + int pack_items = + checked_int_product((size_t)pk.n_rows, (size_t)pk.sb_cols, + "OVO host CSR pack dense slab"); + if (pack_items > max_pack_items) max_pack_items = pack_items; + checked_int_span(pk.nnz, "OVO host CSR pack compacted nnz"); + if (pk.sb_cols > max_pack_sb_cols) max_pack_sb_cols = pk.sb_cols; + } + int max_group_rows = max_pack_rows; + size_t max_sub_items = (size_t)max_pack_items; + if (max_pack_rows == 0) return; + + RmmScratchPool pool; + + // Zero stats outputs. + if (compute_sums) { + cudaMemsetAsync(d_group_sums, 0, + (size_t)n_groups_stats * n_cols * sizeof(double)); + } + if (compute_sq_sums) { + cudaMemsetAsync(d_group_sq_sums, 0, + (size_t)n_groups_stats * n_cols * sizeof(double)); + } + if (compute_nnz) { + cudaMemsetAsync(d_group_nnz, 0, + (size_t)n_groups_stats * n_cols * sizeof(double)); + } + + // ---- Pin full host data + indices as MAPPED (zero-copy accessible) ---- + size_t full_nnz = (size_t)h_indptr[n_full_rows]; + HostRegisterGuard _pin_data(const_cast(h_data), + full_nnz * sizeof(InT), cudaHostRegisterMapped); + HostRegisterGuard _pin_indices(const_cast(h_indices), + full_nnz * sizeof(IndexT), + cudaHostRegisterMapped); + + // Get device-accessible pointers (UVA makes these equal to host ptrs on + // Linux x86-64, but the API is the safe/portable way). 
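cudaHostGetDevicePointer is called just below; the HostRegisterGuard doing the pinning above is defined elsewhere in this file set and presumably wraps little more than cudaHostRegister/cudaHostUnregister. A minimal RAII sketch of that assumed shape (illustrative only, not the project's class; needs <cuda_runtime.h>):

struct PinnedSpan {
  void* ptr = nullptr;
  PinnedSpan(void* p, size_t bytes, unsigned flags = cudaHostRegisterDefault) {
    // Only remember the pointer if registration actually succeeded, so the
    // destructor never unregisters memory we did not pin.
    if (p && bytes && cudaHostRegister(p, bytes, flags) == cudaSuccess) ptr = p;
  }
  ~PinnedSpan() { if (ptr) cudaHostUnregister(ptr); }
  PinnedSpan(const PinnedSpan&) = delete;
  PinnedSpan& operator=(const PinnedSpan&) = delete;
};

With cudaHostRegisterMapped, the same host pages become addressable from device code, which is what makes the zero-copy gather below possible.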
+ InT* d_data_zc = nullptr; + IndexT* d_indices_zc = nullptr; + if (full_nnz > 0) { + cudaError_t e1 = cudaHostGetDevicePointer((void**)&d_data_zc, + const_cast(h_data), 0); + cudaError_t e2 = cudaHostGetDevicePointer( + (void**)&d_indices_zc, const_cast(h_indices), 0); + if (e1 != cudaSuccess || e2 != cudaSuccess) { + throw std::runtime_error( + std::string("cudaHostGetDevicePointer failed: ") + + cudaGetErrorString(e1 != cudaSuccess ? e1 : e2)); + } + } + + // ---- Upload full indptr (keep native IndptrT — can exceed int32) ---- + IndptrT* d_indptr_full = pool.alloc(n_full_rows + 1); + cudaMemcpy(d_indptr_full, h_indptr, (n_full_rows + 1) * sizeof(IndptrT), + cudaMemcpyHostToDevice); + + // ---- Upload row_ids + compacted indptrs + group boundaries ---- + int* d_ref_row_ids = pool.alloc(n_ref); + int* d_grp_row_ids = pool.alloc(n_all_grp); + IndptrT* d_grp_indptr_compact = pool.alloc(n_all_grp + 1); + int* d_grp_offsets_full = pool.alloc(n_test + 1); + cudaMemcpy(d_ref_row_ids, h_ref_row_ids, n_ref * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(d_grp_row_ids, h_grp_row_ids, n_all_grp * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(d_grp_indptr_compact, h_grp_indptr_compact.data(), + (n_all_grp + 1) * sizeof(IndptrT), cudaMemcpyHostToDevice); + cudaMemcpy(d_grp_offsets_full, h_grp_offsets, (n_test + 1) * sizeof(int), + cudaMemcpyHostToDevice); + + // ---- Phase 1: Ref setup (scoped scratch, ref_sorted persists) ---- + size_t ref_items = (size_t)n_ref * (size_t)n_cols; + if (n_ref > 0 && (size_t)n_cols > (size_t)std::numeric_limits::max() / + (size_t)n_ref) { + throw std::runtime_error( + "OVO host CSR dense reference cache exceeds CUB int item limit; " + "use native CSC/device sparse input or reduce genes/reference " + "size"); + } + if (ref_items > std::numeric_limits::max() / (2 * sizeof(float))) { + throw std::runtime_error( + "OVO host CSR dense reference cache size overflows size_t"); + } + size_t free_bytes = 0; + size_t total_bytes = 0; + if (cudaMemGetInfo(&free_bytes, &total_bytes) == cudaSuccess && + total_bytes > 0 && ref_items * 2 * sizeof(float) > total_bytes) { + throw std::runtime_error( + "OVO host CSR dense reference cache requires more GPU memory than " + "the device provides; use native CSC/device sparse input or reduce " + "genes/reference size"); + } + int ref_items_i32 = + checked_cub_items(ref_items, "OVO host CSR dense reference cache"); + float* d_ref_sorted = pool.alloc(ref_items); + cudaStream_t ref_stream; + cudaStreamCreateWithFlags(&ref_stream, cudaStreamNonBlocking); + { + ScopedCudaBuffer ref_data_f32_buf(ref_nnz * sizeof(float)); + ScopedCudaBuffer ref_indices_buf(ref_nnz * sizeof(int)); + ScopedCudaBuffer ref_indptr_buf((n_ref + 1) * sizeof(int)); + ScopedCudaBuffer ref_dense_buf(ref_items * sizeof(float)); + ScopedCudaBuffer ref_seg_buf((n_cols + 1) * sizeof(int)); + + float* d_ref_data_f32 = (float*)ref_data_f32_buf.data(); + int* d_ref_indices = (int*)ref_indices_buf.data(); + int* d_ref_indptr = (int*)ref_indptr_buf.data(); + float* d_ref_dense = (float*)ref_dense_buf.data(); + int* d_ref_seg = (int*)ref_seg_buf.data(); + + // Upload ref compacted indptr + cudaMemcpy(d_ref_indptr, h_ref_indptr_compact.data(), + (n_ref + 1) * sizeof(int), cudaMemcpyHostToDevice); + + // Fused gather + cast + stats for ref (fixed slot = n_test). One + // pass over PCIe, no intermediate native-dtype GPU buffer. 
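The fused launch follows below. As a stripped-down stand-in for the idea only (this is not the real csr_gather_cast_accumulate_mapped_kernel, which also folds in the per-group statistics), here is a gather that reads the mapped host CSR through the UVA pointers exactly once per nonzero and writes a compacted float32 copy on the device:

// Simplified zero-copy gather sketch (hypothetical kernel, for illustration).
// One block per gathered row; threads stride over that row's nonzeros.
template <typename InT, typename IndexT, typename IndptrT>
__global__ void gather_rows_zero_copy_sketch(
    const InT* __restrict__ host_data,        // mapped host pointer (UVA)
    const IndexT* __restrict__ host_indices,  // mapped host pointer (UVA)
    const IndptrT* __restrict__ indptr_full,  // device copy of the full indptr
    const int* __restrict__ row_ids,          // rows to gather, pack order
    const int* __restrict__ indptr_compact,   // destination offsets (int32)
    float* __restrict__ out_data, int* __restrict__ out_indices, int n_rows) {
  int i = blockIdx.x;  // one gathered row per block
  if (i >= n_rows) return;
  IndptrT src = indptr_full[row_ids[i]];
  int dst = indptr_compact[i];
  int nnz = indptr_compact[i + 1] - dst;
  for (int p = threadIdx.x; p < nnz; p += blockDim.x) {
    out_data[dst + p] = static_cast<float>(host_data[src + p]);      // PCIe read
    out_indices[dst + p] = static_cast<int>(host_indices[src + p]);  // PCIe read
  }
}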
+ if (n_ref > 0 && ref_nnz > 0) { + csr_gather_cast_accumulate_mapped_kernel + <<>>( + d_data_zc, d_indices_zc, d_indptr_full, d_ref_row_ids, + d_ref_indptr, /*d_stats_codes=*/nullptr, + /*fixed_slot=*/n_test, d_ref_data_f32, d_ref_indices, + d_group_sums, d_group_sq_sums, d_group_nnz, n_ref, n_cols, + n_groups_stats, compute_sums, compute_sq_sums, compute_nnz); + CUDA_CHECK_LAST_ERROR(csr_gather_cast_accumulate_mapped_kernel); + } + + // Extract ref dense (F-order) from compacted CSR. + cudaMemsetAsync(d_ref_dense, 0, ref_items * sizeof(float), ref_stream); + { + csr_extract_dense_identity_rows_unsorted_kernel + <<>>( + d_ref_data_f32, d_ref_indices, d_ref_indptr, d_ref_dense, + n_ref, 0, n_cols); + CUDA_CHECK_LAST_ERROR( + csr_extract_dense_identity_rows_unsorted_kernel); + } + + // Segmented sort ref_dense by column → ref_sorted + size_t ref_cub_bytes = 0; + { + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, ref_cub_bytes, fk, fk, ref_items_i32, n_cols, doff, + doff + 1, BEGIN_BIT, END_BIT); + } + ScopedCudaBuffer cub_temp_buf(ref_cub_bytes); + upload_linear_offsets(d_ref_seg, n_cols, n_ref, ref_stream); + size_t temp = ref_cub_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + cub_temp_buf.data(), temp, d_ref_dense, d_ref_sorted, ref_items_i32, + n_cols, d_ref_seg, d_ref_seg + 1, BEGIN_BIT, END_BIT, ref_stream); + cudaStreamSynchronize(ref_stream); + } // ref scratch drops here + cudaStreamDestroy(ref_stream); + + // ---- Phase 2: Per-pack streaming ---- + auto t1 = make_tier1_config(h_grp_offsets, n_test); + bool may_need_cub = (t1.max_grp_size > TIER1_GROUP_THRESHOLD); + + constexpr int MAX_GROUP_STREAMS = 4; + int n_streams = MAX_GROUP_STREAMS; + if (n_test < n_streams) n_streams = n_test; + if (n_streams < 1) n_streams = 1; + if ((int)packs.size() < n_streams) n_streams = (int)packs.size(); + if (n_streams < 1) n_streams = 1; + + size_t cub_grp_bytes = 0; + if (may_need_cub && max_sub_items > 0) { + int max_sub_items_i32 = + checked_cub_items(max_sub_items, "OVO host CSR group pack"); + auto* fk = reinterpret_cast(1); + auto* doff = reinterpret_cast(1); + int max_segments = + checked_int_product((size_t)max_pack_K, (size_t)max_pack_sb_cols, + "OVO host CSR max group segment count"); + cub::DeviceSegmentedRadixSort::SortKeys( + nullptr, cub_grp_bytes, fk, fk, max_sub_items_i32, max_segments, + doff, doff + 1, BEGIN_BIT, END_BIT); + } + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + + struct StreamBuf { + float* d_grp_data_f32; + int* d_grp_indices; + int* d_grp_indptr; + int* d_pack_grp_offsets; + int* d_pack_stats_codes; + float* d_grp_dense; + float* d_grp_sorted; + double* d_ref_tie_sums; + int* d_sort_group_ids; + int* d_grp_seg_offsets; + int* d_grp_seg_ends; + uint8_t* cub_temp; + double* d_rank_sums; + double* d_tie_corr; + }; + std::vector bufs(n_streams); + int max_pack_kernel_seg = + checked_int_product((size_t)max_pack_K, (size_t)max_pack_sb_cols, + "OVO host CSR pack segment buffer"); + for (int s = 0; s < n_streams; s++) { + bufs[s].d_grp_data_f32 = pool.alloc(max_pack_nnz); + bufs[s].d_grp_indices = pool.alloc(max_pack_nnz); + bufs[s].d_grp_indptr = pool.alloc(max_pack_rows + 1); + bufs[s].d_pack_grp_offsets = pool.alloc(max_pack_K + 1); + bufs[s].d_pack_stats_codes = pool.alloc(max_pack_rows); + bufs[s].d_grp_dense = pool.alloc(max_sub_items); + bufs[s].d_ref_tie_sums = pool.alloc(max_pack_sb_cols); + bufs[s].d_rank_sums = + 
pool.alloc((size_t)max_pack_K * max_pack_sb_cols); + bufs[s].d_tie_corr = + pool.alloc((size_t)max_pack_K * max_pack_sb_cols); + if (may_need_cub) { + bufs[s].d_grp_sorted = pool.alloc(max_sub_items); + bufs[s].d_sort_group_ids = pool.alloc(max_pack_K); + bufs[s].d_grp_seg_offsets = pool.alloc(max_pack_kernel_seg); + bufs[s].d_grp_seg_ends = pool.alloc(max_pack_kernel_seg); + bufs[s].cub_temp = pool.alloc(cub_grp_bytes); + } else { + bufs[s].d_grp_sorted = nullptr; + bufs[s].d_sort_group_ids = nullptr; + bufs[s].d_grp_seg_offsets = nullptr; + bufs[s].d_grp_seg_ends = nullptr; + bufs[s].cub_temp = nullptr; + } + } + + for (int p = 0; p < (int)packs.size(); p++) { + const Pack& pack = packs[p]; + int K = pack.end - pack.first; + if (K == 0 || pack.n_rows == 0) continue; + Tier1Config pack_t1 = make_tier1_config(h_grp_offsets + pack.first, K); + int pack_tpb_rank = round_up_to_warp( + std::min(pack_t1.max_grp_size, MAX_THREADS_PER_BLOCK)); + bool pack_has_above_t2 = pack_t1.max_grp_size > TIER2_GROUP_THRESHOLD; + int pack_tier3_skip_le = + pack_has_above_t2 ? TIER2_GROUP_THRESHOLD : TIER0_GROUP_THRESHOLD; + std::vector h_sort_group_ids; + int pack_n_sort_groups = K; + if (pack_t1.any_above_t0 && !pack_t1.use_tier1) { + h_sort_group_ids = make_sort_group_ids(h_grp_offsets + pack.first, + K, pack_tier3_skip_le); + pack_n_sort_groups = (int)h_sort_group_ids.size(); + } + + int s = p % n_streams; + cudaStream_t stream = streams[s]; + auto& buf = bufs[s]; + + if (pack_t1.any_above_t0 && !pack_t1.use_tier1) { + cudaMemcpyAsync(buf.d_sort_group_ids, h_sort_group_ids.data(), + h_sort_group_ids.size() * sizeof(int), + cudaMemcpyHostToDevice, stream); + } + + int row_start = h_grp_offsets[pack.first]; + int pack_rows = pack.n_rows; + int pack_sb = pack.sb_cols; + + // Rebase pack's output indptr from pre-uploaded global compacted indptr + // (IndptrT → int32: pack nnz is bounded by GROUP_DENSE_BUDGET so fits). + { + int count = pack_rows + 1; + int blk = (count + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + rebase_indptr_kernel + <<>>( + d_grp_indptr_compact, buf.d_grp_indptr, row_start, count); + CUDA_CHECK_LAST_ERROR(rebase_indptr_kernel); + } + + // Build per-pack group offsets on GPU (on this stream) — needed to + // compute stats codes before the fused gather kernel can run. + { + int count = K + 1; + int blk = (count + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + rebase_indptr_kernel<<>>( + d_grp_offsets_full, buf.d_pack_grp_offsets, pack.first, count); + CUDA_CHECK_LAST_ERROR(rebase_indptr_kernel); + } + + // Fill per-row stats codes for this pack + { + int blk = (pack_rows + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + fill_pack_stats_codes_kernel<<>>( + buf.d_pack_grp_offsets, buf.d_pack_stats_codes, K, pack.first); + CUDA_CHECK_LAST_ERROR(fill_pack_stats_codes_kernel); + } + + // Fused gather + cast + stats for the pack. One pass over PCIe + // (reads mapped host via UVA), no intermediate native-dtype GPU + // buffer, writes f32 + indices + atomics. 
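That fused launch appears immediately below; first, a note on the two rebase launches above. rebase_indptr_kernel is defined elsewhere in this file set; its presumed effect is a plain offset subtraction with a narrowing cast, roughly:

// Sketch of the assumed rebasing (not the real kernel): every entry becomes
// relative to the first entry of the window, narrowed to int32.
template <typename T>
__global__ void rebase_indptr_sketch(const T* __restrict__ in,
                                     int* __restrict__ out, int base, int count) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < count) out[i] = (int)(in[base + i] - in[base]);
}

Applied to the global compacted indptr at row_start this yields the pack-local CSR indptr; applied to the full group offsets at pack.first it yields the pack-local group boundaries.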
+ if (pack.nnz > 0) { + csr_gather_cast_accumulate_mapped_kernel + <<>>( + d_data_zc, d_indices_zc, d_indptr_full, + d_grp_row_ids + row_start, buf.d_grp_indptr, + buf.d_pack_stats_codes, /*fixed_slot=*/-1, + buf.d_grp_data_f32, buf.d_grp_indices, d_group_sums, + d_group_sq_sums, d_group_nnz, pack_rows, n_cols, + n_groups_stats, compute_sums, compute_sq_sums, compute_nnz); + CUDA_CHECK_LAST_ERROR(csr_gather_cast_accumulate_mapped_kernel); + } + + // Per col sub-batch + int col = 0; + while (col < n_cols) { + int sb_cols = std::min(pack_sb, n_cols - col); + int sb_items = + checked_int_product((size_t)pack_rows, (size_t)sb_cols, + "OVO host CSR active group sub-batch"); + + cudaMemsetAsync(buf.d_grp_dense, 0, sb_items * sizeof(float), + stream); + csr_extract_dense_identity_rows_unsorted_kernel + <<>>( + buf.d_grp_data_f32, buf.d_grp_indices, buf.d_grp_indptr, + buf.d_grp_dense, pack_rows, col, col + sb_cols); + CUDA_CHECK_LAST_ERROR( + csr_extract_dense_identity_rows_unsorted_kernel); + + const float* ref_sub = d_ref_sorted + (size_t)col * n_ref; + + int skip_le = 0; + bool run_tier0 = pack_t1.use_tier0; + bool run_tier0_64 = pack_t1.any_tier0_64; + bool run_tier2 = pack_t1.any_tier2; + if (compute_tie_corr && (run_tier0 || run_tier0_64 || run_tier2)) { + launch_ref_tie_sums(ref_sub, buf.d_ref_tie_sums, n_ref, sb_cols, + stream); + } + if (run_tier0) { + launch_tier0(ref_sub, buf.d_grp_dense, buf.d_pack_grp_offsets, + buf.d_ref_tie_sums, buf.d_rank_sums, + buf.d_tie_corr, n_ref, pack_rows, sb_cols, K, + compute_tie_corr, stream); + if (pack_t1.any_above_t0) skip_le = TIER0_GROUP_THRESHOLD; + } + if (run_tier0_64) { + launch_tier0_64( + ref_sub, buf.d_grp_dense, buf.d_pack_grp_offsets, + buf.d_ref_tie_sums, buf.d_rank_sums, buf.d_tie_corr, n_ref, + pack_rows, sb_cols, K, compute_tie_corr, skip_le, stream); + if (pack_t1.max_grp_size > TIER0_64_GROUP_THRESHOLD) { + skip_le = TIER0_64_GROUP_THRESHOLD; + } + } + if (run_tier2) { + launch_tier2_medium( + ref_sub, buf.d_grp_dense, buf.d_pack_grp_offsets, + buf.d_ref_tie_sums, buf.d_rank_sums, buf.d_tie_corr, n_ref, + pack_rows, sb_cols, K, compute_tie_corr, skip_le, stream); + } + + int upper_skip_le = + pack_has_above_t2 ? 
TIER2_GROUP_THRESHOLD : skip_le; + if (pack_has_above_t2 && pack_t1.use_tier1) { + dim3 grid(sb_cols, K); + ovo_fused_sort_rank_kernel<<>>( + ref_sub, buf.d_grp_dense, buf.d_pack_grp_offsets, + buf.d_rank_sums, buf.d_tie_corr, n_ref, pack_rows, sb_cols, + K, compute_tie_corr, pack_t1.padded_grp_size, + upper_skip_le); + CUDA_CHECK_LAST_ERROR(ovo_fused_sort_rank_kernel); + } else if (pack_has_above_t2) { + int n_seg = checked_int_product( + (size_t)pack_n_sort_groups, (size_t)sb_cols, + "OVO host CSR active group segment count"); + { + int blk = (n_seg + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + build_tier3_seg_begin_end_offsets_kernel<<< + blk, UTIL_BLOCK_SIZE, 0, stream>>>( + buf.d_pack_grp_offsets, buf.d_sort_group_ids, + buf.d_grp_seg_offsets, buf.d_grp_seg_ends, pack_rows, + pack_n_sort_groups, sb_cols); + CUDA_CHECK_LAST_ERROR( + build_tier3_seg_begin_end_offsets_kernel); + } + { + size_t temp = cub_grp_bytes; + cub::DeviceSegmentedRadixSort::SortKeys( + buf.cub_temp, temp, buf.d_grp_dense, buf.d_grp_sorted, + sb_items, n_seg, buf.d_grp_seg_offsets, + buf.d_grp_seg_ends, BEGIN_BIT, END_BIT, stream); + } + dim3 grid(sb_cols, K); + batched_rank_sums_presorted_kernel<<>>( + ref_sub, buf.d_grp_sorted, buf.d_pack_grp_offsets, + buf.d_rank_sums, buf.d_tie_corr, n_ref, pack_rows, sb_cols, + K, compute_tie_corr, upper_skip_le); + CUDA_CHECK_LAST_ERROR(batched_rank_sums_presorted_kernel); + } + + cudaMemcpy2DAsync(d_rank_sums + (size_t)pack.first * n_cols + col, + n_cols * sizeof(double), buf.d_rank_sums, + sb_cols * sizeof(double), + sb_cols * sizeof(double), K, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpy2DAsync( + d_tie_corr + (size_t)pack.first * n_cols + col, + n_cols * sizeof(double), buf.d_tie_corr, + sb_cols * sizeof(double), sb_cols * sizeof(double), K, + cudaMemcpyDeviceToDevice, stream); + } + + col += sb_cols; + } + } + + for (int s = 0; s < n_streams; s++) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) + throw std::runtime_error( + std::string("CUDA error in ovo csr host streaming: ") + + cudaGetErrorString(err)); + } + for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh new file mode 100644 index 00000000..9fd626b6 --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovo_kernels.cuh @@ -0,0 +1,173 @@ +#pragma once + +/** + * Build CUB segmented-sort ranges only for groups that Tier 3 will rank. + * Group ids are relative to grp_offsets, and ranges still point into the + * original dense group layout so the presorted rank kernel can read from the + * normal per-group positions. + */ +__global__ void build_tier3_seg_begin_end_offsets_kernel( + const int* __restrict__ grp_offsets, const int* __restrict__ group_ids, + int* __restrict__ begins, int* __restrict__ ends, int n_all_grp, + int n_sort_groups, int sb_cols) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int total = sb_cols * n_sort_groups; + if (idx >= total) return; + + int c = idx / n_sort_groups; + int local = idx % n_sort_groups; + int g = group_ids[local]; + int base = c * n_all_grp; + begins[idx] = base + grp_offsets[g]; + ends[idx] = base + grp_offsets[g + 1]; +} + +/** + * Extract specific rows from CSC into dense F-order, using a row lookup map. + * row_map[original_row] = output_row_index (or -1 to skip). + * One block per column, threads scatter matching nonzeros. 
+ * Output must be pre-zeroed. + */ +template +__global__ void csc_extract_mapped_kernel(const float* __restrict__ data, + const IndexT* __restrict__ indices, + const int* __restrict__ indptr, + const int* __restrict__ row_map, + float* __restrict__ out, int n_target, + int col_start) { + int col_local = blockIdx.x; + int col = col_start + col_local; + + int start = indptr[col]; + int end = indptr[col + 1]; + + for (int p = start + threadIdx.x; p < end; p += blockDim.x) { + int out_row = row_map[(int)indices[p]]; + if (out_row >= 0) { + out[(long long)col_local * n_target + out_row] = data[p]; + } + } +} + +/** + * Tier 1 dispatch: when the largest group fits in shared memory, a fused + * bitonic-sort + binary-search kernel handles the whole group per block. + * Otherwise we fall back to CUB segmented sort plus the pre-sorted rank + * kernel. This struct bundles the sizing knobs derived from the host-side + * group offsets so each streaming impl can drop a 15-line prep block. + */ +struct Tier1Config { + int max_grp_size = 0; + int min_grp_size = 0; + bool use_tier0 = + false; // any group fits in one warp (≤ TIER0_GROUP_THRESHOLD) + bool use_tier1 = + false; // any group needs > tier0 but fits in tier1 smem sort + bool any_above_t0 = + false; // at least one group exceeds TIER0_GROUP_THRESHOLD + bool any_tier0_64 = false; // any group needs Tier 0.5: (T0, T0_64] + bool any_tier2 = false; // any group needs Tier 2: (T0_64, T2] + bool any_above_t2 = + false; // at least one group exceeds TIER2_GROUP_THRESHOLD + int padded_grp_size = 0; + int tier1_tpb = 0; + size_t tier1_smem = 0; +}; + +static Tier1Config make_tier1_config(const int* h_grp_offsets, int n_groups) { + Tier1Config c; + c.min_grp_size = INT_MAX; + for (int g = 0; g < n_groups; g++) { + int sz = h_grp_offsets[g + 1] - h_grp_offsets[g]; + if (sz > c.max_grp_size) c.max_grp_size = sz; + if (sz < c.min_grp_size) c.min_grp_size = sz; + if (sz > TIER0_GROUP_THRESHOLD && sz <= TIER0_64_GROUP_THRESHOLD) { + c.any_tier0_64 = true; + } + if (sz > TIER0_64_GROUP_THRESHOLD && sz <= TIER2_GROUP_THRESHOLD) { + c.any_tier2 = true; + } + if (sz > TIER2_GROUP_THRESHOLD) c.any_above_t2 = true; + } + if (n_groups == 0) c.min_grp_size = 0; + + // use_tier0: Tier 0 kernel is worth running (at least one group small + // enough to benefit from the warp path). + c.use_tier0 = (c.min_grp_size <= TIER0_GROUP_THRESHOLD); + // any_above_t0: at least one group needs a non-Tier-0 kernel. + c.any_above_t0 = (c.max_grp_size > TIER0_GROUP_THRESHOLD); + // use_tier1: the fused smem-sort fast path (for groups > T0 but ≤ T1). + c.use_tier1 = c.any_above_t0 && (c.max_grp_size <= TIER1_GROUP_THRESHOLD); + if (c.use_tier1) { + c.padded_grp_size = 1; + while (c.padded_grp_size < c.max_grp_size) c.padded_grp_size <<= 1; + c.tier1_tpb = std::min(c.padded_grp_size, MAX_THREADS_PER_BLOCK); + c.tier1_smem = (size_t)c.padded_grp_size * sizeof(float) + + WARP_REDUCE_BUF * sizeof(double); + } + return c; +} + +static std::vector make_sort_group_ids(const int* h_grp_offsets, + int n_groups, int skip_n_grp_le) { + std::vector ids; + ids.reserve(n_groups); + for (int g = 0; g < n_groups; ++g) { + int sz = h_grp_offsets[g + 1] - h_grp_offsets[g]; + if (skip_n_grp_le > 0 && sz <= skip_n_grp_le) continue; + ids.push_back(g); + } + return ids; +} + +// Tier 0 kernel launcher: 8 warps × 32 threads per block, one (col, group) +// pair per warp. grid.y covers ceil(K/8) pair rows. 
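The launcher below builds exactly that grid; the index math it implies inside ovo_warp_sort_rank_kernel (which lives elsewhere in this file set) is roughly the following skeleton, shown for illustration only:

// Presumed warp-per-(column, group) mapping for the Tier 0 grid described above.
__global__ void warp_per_pair_skeleton(int sb_cols, int K) {
  constexpr int WARPS_PER_BLOCK = 8;
  int col = blockIdx.x;                         // one column per grid.x slot
  int warp = threadIdx.x / 32;                  // warp within the 256-thread block
  int lane = threadIdx.x % 32;                  // thread within the warp
  int g = blockIdx.y * WARPS_PER_BLOCK + warp;  // (col, group) pair for this warp
  if (col >= sb_cols || g >= K) return;         // tail block in grid.y
  (void)lane;  // each warp then sorts and ranks group g of column col
}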
+static inline void launch_tier0(const float* ref_sorted, const float* grp_dense, + const int* grp_offsets, + const double* ref_tie_sums, double* rank_sums, + double* tie_corr, int n_ref, int n_all_grp, + int sb_cols, int K, bool compute_tie_corr, + cudaStream_t stream) { + constexpr int WARPS_PER_BLOCK = 8; + dim3 grid(sb_cols, (K + WARPS_PER_BLOCK - 1) / WARPS_PER_BLOCK); + ovo_warp_sort_rank_kernel<<>>( + ref_sorted, grp_dense, grp_offsets, ref_tie_sums, rank_sums, tie_corr, + n_ref, n_all_grp, sb_cols, K, compute_tie_corr); + CUDA_CHECK_LAST_ERROR(ovo_warp_sort_rank_kernel); +} + +static inline void launch_ref_tie_sums(const float* ref_sorted, + double* ref_tie_sums, int n_ref, + int sb_cols, cudaStream_t stream) { + ref_tie_sum_kernel<<>>( + ref_sorted, ref_tie_sums, n_ref, sb_cols); + CUDA_CHECK_LAST_ERROR(ref_tie_sum_kernel); +} + +static inline void launch_tier0_64( + const float* ref_sorted, const float* grp_dense, const int* grp_offsets, + const double* ref_tie_sums, double* rank_sums, double* tie_corr, int n_ref, + int n_all_grp, int sb_cols, int K, bool compute_tie_corr, int skip_n_grp_le, + cudaStream_t stream) { + dim3 grid(sb_cols, K); + ovo_small64_sort_rank_kernel<<>>( + ref_sorted, grp_dense, grp_offsets, ref_tie_sums, rank_sums, tie_corr, + n_ref, n_all_grp, sb_cols, K, compute_tie_corr, skip_n_grp_le); + CUDA_CHECK_LAST_ERROR(ovo_small64_sort_rank_kernel); +} + +static inline void launch_tier2_medium( + const float* ref_sorted, const float* grp_dense, const int* grp_offsets, + const double* ref_tie_sums, double* rank_sums, double* tie_corr, int n_ref, + int n_all_grp, int sb_cols, int K, bool compute_tie_corr, int skip_n_grp_le, + cudaStream_t stream) { + constexpr int tpb = 256; + size_t smem = (size_t)TIER2_GROUP_THRESHOLD * sizeof(float) + + WARP_REDUCE_BUF * sizeof(double); + dim3 grid(sb_cols, K); + ovo_medium_unsorted_rank_kernel<<>>( + ref_sorted, grp_dense, grp_offsets, ref_tie_sums, rank_sums, tie_corr, + n_ref, n_all_grp, sb_cols, K, compute_tie_corr, skip_n_grp_le, + TIER2_GROUP_THRESHOLD); + CUDA_CHECK_LAST_ERROR(ovo_medium_unsorted_rank_kernel); +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh new file mode 100644 index 00000000..2323e27f --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_kernels.cuh @@ -0,0 +1,89 @@ +#pragma once + +/** Count nonzeros per column from CSR. One thread per row. */ +template +__global__ void csr_col_histogram_kernel(const IndexT* __restrict__ indices, + const IndptrT* __restrict__ indptr, + unsigned int* __restrict__ col_counts, + int n_rows, int n_cols) { + int row = blockIdx.x * blockDim.x + threadIdx.x; + if (row >= n_rows) return; + IndptrT rs = indptr[row]; + IndptrT re = indptr[row + 1]; + for (IndptrT p = rs; p < re; ++p) { + int c = (int)indices[p]; + if (c < n_cols) atomicAdd(&col_counts[c], 1u); + } +} + +/** + * Scatter CSR nonzeros into CSC layout for columns [col_start, col_stop). + * write_pos[c - col_start] must be initialized to the prefix-sum offset + * for column c. Each thread atomically claims a unique destination slot. 
+ */ +template +__global__ void csr_scatter_to_csc_kernel( + const InT* __restrict__ data, const IndexT* __restrict__ indices, + const IndptrT* __restrict__ indptr, int* __restrict__ write_pos, + InT* __restrict__ csc_vals, int* __restrict__ csc_row_idx, int n_rows, + int col_start, int col_stop) { + int row = blockIdx.x * blockDim.x + threadIdx.x; + if (row >= n_rows) return; + IndptrT rs = indptr[row]; + IndptrT re = indptr[row + 1]; + // Binary search for col_start (overflow-safe midpoint) + IndptrT lo = rs, hi = re; + while (lo < hi) { + IndptrT m = lo + ((hi - lo) >> 1); + if (indices[m] < col_start) + lo = m + 1; + else + hi = m; + } + for (IndptrT p = lo; p < re; ++p) { + int c = (int)indices[p]; + if (c >= col_stop) break; + int dest = atomicAdd(&write_pos[c - col_start], 1); + csc_vals[dest] = data[p]; + csc_row_idx[dest] = row; + } +} + +static size_t ovr_smem_config(int n_groups, bool& use_gmem) { + size_t need = (size_t)(n_groups + 32) * sizeof(double); + if (need <= wilcoxon_max_smem_per_block()) { + use_gmem = false; + return need; + } + // Fall back to global memory accumulators; only need warp buf in smem + use_gmem = true; + return 32 * sizeof(double); +} + +/** + * Decide smem-vs-gmem for the sparse OVR rank kernel. Two accumulator + * arrays (grp_sums + grp_nz_count) of size n_groups each plus warp buf. + */ +static size_t sparse_ovr_smem_config(int n_groups, bool& use_gmem) { + size_t need = (size_t)(2 * n_groups + 32) * sizeof(double); + if (need <= wilcoxon_max_smem_per_block()) { + use_gmem = false; + return need; + } + use_gmem = true; + return 32 * sizeof(double); +} + +/** + * Fill sort values with row indices [0,1,...,n_rows-1] per column. + * Grid: (n_cols,), block: 256 threads. + */ +__global__ void fill_row_indices_kernel(int* __restrict__ vals, int n_rows, + int n_cols) { + int col = blockIdx.x; + if (col >= n_cols) return; + int* out = vals + (long long)col * n_rows; + for (int i = threadIdx.x; i < n_rows; i += blockDim.x) { + out[i] = i; + } +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh new file mode 100644 index 00000000..257bbbb3 --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_ovr_sparse.cuh @@ -0,0 +1,882 @@ +#pragma once + +/** + * Sparse-aware host-streaming CSC OVR pipeline. + * + * Like ovr_streaming_csc_host_impl but sorts only stored nonzeros per column + * instead of extracting dense blocks. GPU memory is O(max_batch_nnz) instead + * of O(sub_batch * n_rows), and sort work is proportional to nnz, not n_rows. 
+ */ +template +static void ovr_sparse_csc_host_streaming_impl( + const InT* h_data, const IndexT* h_indices, const IndptrT* h_indptr, + const int* h_group_codes, const double* h_group_sizes, double* d_rank_sums, + double* d_tie_corr, double* d_group_sums, double* d_group_sq_sums, + double* d_group_nnz, int n_rows, int n_cols, int n_groups, + bool compute_tie_corr, bool compute_sq_sums, bool compute_nnz, + int sub_batch_cols) { + if (n_rows == 0 || n_cols == 0) return; + + int n_streams = N_STREAMS; + if (n_cols < n_streams * sub_batch_cols) + n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + + // Find max nnz across any sub-batch + size_t max_nnz = 0; + for (int col = 0; col < n_cols; col += sub_batch_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + size_t nnz = (size_t)(h_indptr[col + sb_cols] - h_indptr[col]); + if (nnz > max_nnz) max_nnz = nnz; + } + + // CUB temp size for max_nnz items + size_t cub_temp_bytes = 0; + if (max_nnz > 0) { + int max_nnz_i32 = + checked_cub_items(max_nnz, "OVR host CSC sparse sub-batch nnz"); + auto* fk = reinterpret_cast(1); + auto* iv = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortPairs( + nullptr, cub_temp_bytes, fk, fk, iv, iv, max_nnz_i32, + sub_batch_cols, iv, iv + 1, BEGIN_BIT, END_BIT); + } + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + + RmmScratchPool pool; + int* d_group_codes = pool.alloc(n_rows); + double* d_group_sizes = pool.alloc(n_groups); + struct StreamBuf { + InT* d_sparse_data_orig; + float* d_sparse_data_f32; + IndexT* d_sparse_indices; + int* d_seg_offsets; + float* keys_out; + IndexT* vals_out; + uint8_t* cub_temp; + double* d_rank_sums; + double* d_tie_corr; + double* d_group_sums; + double* d_group_sq_sums; + double* d_group_nnz; + double* d_nz_scratch; // gmem-only; non-null when rank_use_gmem + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + bufs[s].d_sparse_data_orig = pool.alloc(max_nnz); + bufs[s].d_sparse_data_f32 = pool.alloc(max_nnz); + bufs[s].d_sparse_indices = pool.alloc(max_nnz); + bufs[s].d_seg_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].keys_out = pool.alloc(max_nnz); + bufs[s].vals_out = pool.alloc(max_nnz); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].d_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].d_tie_corr = pool.alloc(sub_batch_cols); + bufs[s].d_group_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].d_group_sq_sums = + compute_sq_sums + ? pool.alloc((size_t)n_groups * sub_batch_cols) + : nullptr; + bufs[s].d_group_nnz = + compute_nnz ? pool.alloc((size_t)n_groups * sub_batch_cols) + : nullptr; + } + + // Transfer group codes + sizes once + cudaMemcpy(d_group_codes, h_group_codes, n_rows * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(d_group_sizes, h_group_sizes, n_groups * sizeof(double), + cudaMemcpyHostToDevice); + + // Pre-compute rebased per-batch offsets and upload once (avoids per-batch + // H2D copy from a transient host buffer). 
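The per-batch offsets built just below also serve as the begin/end segment iterators for CUB's segmented sort. For reference, the two-phase temp-storage pattern behind every SortPairs/SortKeys call in this file, as a minimal standalone example (names and sizes are illustrative; the pipelines above instead query the size once for the largest batch and keep the scratch per stream):

#include <cuda_runtime.h>
#include <cub/cub.cuh>

void segmented_sort_pairs_example(float* d_keys_in, float* d_keys_out,
                                  int* d_vals_in, int* d_vals_out,
                                  int num_items, int num_segments,
                                  const int* d_offsets,  // num_segments + 1 entries
                                  cudaStream_t stream) {
  size_t temp_bytes = 0;
  // Phase 1: null temp storage -> CUB only reports how many scratch bytes it needs.
  cub::DeviceSegmentedRadixSort::SortPairs(
      nullptr, temp_bytes, d_keys_in, d_keys_out, d_vals_in, d_vals_out,
      num_items, num_segments, d_offsets, d_offsets + 1, 0, 32, stream);
  void* d_temp = nullptr;
  cudaMallocAsync(&d_temp, temp_bytes, stream);
  // Phase 2: identical call with real scratch; segment i is [offsets[i], offsets[i+1]).
  cub::DeviceSegmentedRadixSort::SortPairs(
      d_temp, temp_bytes, d_keys_in, d_keys_out, d_vals_in, d_vals_out,
      num_items, num_segments, d_offsets, d_offsets + 1, 0, 32, stream);
  cudaFreeAsync(d_temp, stream);
}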
+ int n_batches = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + std::vector h_all_offsets((size_t)n_batches * (sub_batch_cols + 1), 0); + for (int b = 0; b < n_batches; b++) { + int col_start = b * sub_batch_cols; + int sb = std::min(sub_batch_cols, n_cols - col_start); + IndptrT ptr_start = h_indptr[col_start]; + int* off = &h_all_offsets[(size_t)b * (sub_batch_cols + 1)]; + for (int i = 0; i <= sb; i++) { + off[i] = + checked_int_span((size_t)(h_indptr[col_start + i] - ptr_start), + "OVR host CSC rebased column offsets"); + } + } + int* d_all_offsets = + pool.alloc((size_t)n_batches * (sub_batch_cols + 1)); + cudaMemcpy(d_all_offsets, h_all_offsets.data(), + h_all_offsets.size() * sizeof(int), cudaMemcpyHostToDevice); + + int tpb = UTIL_BLOCK_SIZE; + bool rank_use_gmem = false; + size_t smem_bytes = sparse_ovr_smem_config(n_groups, rank_use_gmem); + bool cast_use_gmem = false; + size_t smem_cast = cast_accumulate_smem_config(n_groups, compute_sq_sums, + compute_nnz, cast_use_gmem); + + // In gmem mode the sparse rank kernel accumulates into rank_sums directly + // and needs a per-stream nz_count scratch buffer sized (n_groups, sb_cols). + for (int s = 0; s < n_streams; s++) { + if (rank_use_gmem) { + bufs[s].d_nz_scratch = + pool.alloc((size_t)n_groups * sub_batch_cols); + } else { + bufs[s].d_nz_scratch = nullptr; + } + } + + // Pin only the host input arrays; outputs live on the device. + size_t total_nnz = (size_t)h_indptr[n_cols]; + HostRegisterGuard _pin_data(const_cast(h_data), + total_nnz * sizeof(InT)); + HostRegisterGuard _pin_indices(const_cast(h_indices), + total_nnz * sizeof(IndexT)); + + cudaDeviceSynchronize(); + + int col = 0; + int batch_idx = 0; + while (col < n_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int s = batch_idx % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + + IndptrT ptr_start = h_indptr[col]; + IndptrT ptr_end = h_indptr[col + sb_cols]; + int batch_nnz = checked_int_span((size_t)(ptr_end - ptr_start), + "OVR host CSC active batch nnz"); + + // H2D: transfer sparse data for this column range (native dtype) + if (batch_nnz > 0) { + cudaMemcpyAsync(buf.d_sparse_data_orig, h_data + ptr_start, + (size_t)batch_nnz * sizeof(InT), + cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(buf.d_sparse_indices, h_indices + ptr_start, + (size_t)batch_nnz * sizeof(IndexT), + cudaMemcpyHostToDevice, stream); + } + + // D2D: copy this batch's rebased offsets from the pre-uploaded buffer + int* src = d_all_offsets + (size_t)batch_idx * (sub_batch_cols + 1); + cudaMemcpyAsync(buf.d_seg_offsets, src, (sb_cols + 1) * sizeof(int), + cudaMemcpyDeviceToDevice, stream); + + // Cast to float32 for sort + accumulate stats in float64 + launch_ovr_cast_and_accumulate_sparse( + buf.d_sparse_data_orig, buf.d_sparse_data_f32, buf.d_sparse_indices, + buf.d_seg_offsets, d_group_codes, buf.d_group_sums, + buf.d_group_sq_sums, buf.d_group_nnz, sb_cols, n_groups, + compute_sq_sums, compute_nnz, tpb, smem_cast, cast_use_gmem, + stream); + + // CUB sort only stored nonzeros (float32 keys) + if (batch_nnz > 0) { + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortPairs( + buf.cub_temp, temp, buf.d_sparse_data_f32, buf.keys_out, + buf.d_sparse_indices, buf.vals_out, batch_nnz, sb_cols, + buf.d_seg_offsets, buf.d_seg_offsets + 1, BEGIN_BIT, END_BIT, + stream); + } + + // Sparse rank kernel (stats already captured above) + if (rank_use_gmem) { + cudaMemsetAsync(buf.d_rank_sums, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + 
cudaMemsetAsync(buf.d_nz_scratch, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + } + rank_sums_sparse_ovr_kernel + <<>>( + buf.keys_out, buf.vals_out, buf.d_seg_offsets, d_group_codes, + d_group_sizes, buf.d_rank_sums, buf.d_tie_corr, + buf.d_nz_scratch, n_rows, sb_cols, n_groups, compute_tie_corr, + rank_use_gmem); + CUDA_CHECK_LAST_ERROR(rank_sums_sparse_ovr_kernel); + + // D2D: scatter sub-batch results into caller's GPU buffers + cudaMemcpy2DAsync(d_rank_sums + col, n_cols * sizeof(double), + buf.d_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpyAsync(d_tie_corr + col, buf.d_tie_corr, + sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, + stream); + } + cudaMemcpy2DAsync(d_group_sums + col, n_cols * sizeof(double), + buf.d_group_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_sq_sums) { + cudaMemcpy2DAsync(d_group_sq_sums + col, n_cols * sizeof(double), + buf.d_group_sq_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + } + if (compute_nnz) { + cudaMemcpy2DAsync(d_group_nnz + col, n_cols * sizeof(double), + buf.d_group_nnz, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + } + + col += sb_cols; + batch_idx++; + } + + for (int s = 0; s < n_streams; s++) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) + throw std::runtime_error( + std::string("CUDA error in sparse host CSC streaming: ") + + cudaGetErrorString(err)); + } + + for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); +} + +// ============================================================================ +// Sparse-aware host-streaming CSR OVR pipeline. +// ============================================================================ + +/** + * Host CSR variant of the sparse OVR stream. + * + * The CSR input stays in host memory. We count columns once on the CPU, then + * use mapped pinned CSR arrays for bounded per-column-batch CSR->CSC scatter + * on the GPU. This avoids both a full host->device sparse upload and any + * whole-matrix CSR->CSC conversion. 
+ */ +template +static void ovr_sparse_csr_host_streaming_impl( + const InT* h_data, const IndexT* h_indices, const IndptrT* h_indptr, + const int* h_group_codes, const double* h_group_sizes, double* d_rank_sums, + double* d_tie_corr, double* d_group_sums, double* d_group_sq_sums, + double* d_group_nnz, int n_rows, int n_cols, int n_groups, + bool compute_tie_corr, bool compute_sq_sums, bool compute_nnz, + int sub_batch_cols) { + if (n_rows == 0 || n_cols == 0) return; + + RmmScratchPool pool; + size_t total_nnz = (size_t)h_indptr[n_rows]; + + // ---- Phase 0: CPU planning in native CSR order ---- + std::vector h_col_counts(n_cols, 0); + for (int row = 0; row < n_rows; row++) { + IndptrT rs = h_indptr[row]; + IndptrT re = h_indptr[row + 1]; + for (IndptrT p = rs; p < re; ++p) { + int c = (int)h_indices[p]; + if (c >= 0 && c < n_cols) h_col_counts[c]++; + } + } + + int n_batches = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + size_t max_batch_nnz = 0; + std::vector h_all_offsets((size_t)n_batches * (sub_batch_cols + 1), 0); + std::vector h_batch_nnz(n_batches); + for (int b = 0; b < n_batches; b++) { + int col_start = b * sub_batch_cols; + int sb_cols = std::min(sub_batch_cols, n_cols - col_start); + int* off = &h_all_offsets[(size_t)b * (sub_batch_cols + 1)]; + for (int i = 0; i < sb_cols; i++) + off[i + 1] = checked_int_span( + (size_t)off[i] + (size_t)h_col_counts[col_start + i], + "OVR host CSR rebased column offsets"); + h_batch_nnz[b] = (size_t)off[sb_cols]; + if (h_batch_nnz[b] > max_batch_nnz) max_batch_nnz = h_batch_nnz[b]; + } + + int* d_all_offsets = + pool.alloc((size_t)n_batches * (sub_batch_cols + 1)); + cudaMemcpy(d_all_offsets, h_all_offsets.data(), + h_all_offsets.size() * sizeof(int), cudaMemcpyHostToDevice); + + // ---- Phase 1: allocate per-stream bounded work buffers ---- + size_t cub_temp_bytes = 0; + if (max_batch_nnz > 0) { + int max_batch_nnz_i32 = checked_cub_items( + max_batch_nnz, "OVR host CSR sparse sub-batch nnz"); + auto* fk = reinterpret_cast(1); + auto* iv = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortPairs( + nullptr, cub_temp_bytes, fk, fk, iv, iv, max_batch_nnz_i32, + sub_batch_cols, iv, iv + 1, BEGIN_BIT, END_BIT); + } + + int tpb = UTIL_BLOCK_SIZE; + bool rank_use_gmem = false; + size_t smem_bytes = sparse_ovr_smem_config(n_groups, rank_use_gmem); + bool cast_use_gmem = false; + size_t smem_cast = cast_accumulate_smem_config(n_groups, compute_sq_sums, + compute_nnz, cast_use_gmem); + + int n_streams = N_STREAMS; + if (n_batches < n_streams) n_streams = n_batches; + + size_t per_stream_bytes = + max_batch_nnz * (sizeof(InT) + sizeof(float) + 2 * sizeof(int)) + + (sub_batch_cols + 1 + sub_batch_cols) * sizeof(int) + cub_temp_bytes + + 2 * (size_t)n_groups * sub_batch_cols * sizeof(double) + + sub_batch_cols * sizeof(double); + if (compute_sq_sums) { + per_stream_bytes += (size_t)n_groups * sub_batch_cols * sizeof(double); + } + if (compute_nnz) { + per_stream_bytes += (size_t)n_groups * sub_batch_cols * sizeof(double); + } + if (rank_use_gmem) { + per_stream_bytes += (size_t)n_groups * sub_batch_cols * sizeof(double); + } + + size_t free_mem = 0, total_mem = 0; + cudaMemGetInfo(&free_mem, &total_mem); + constexpr double MEM_BUDGET_FRAC = 0.8; + size_t budget = (size_t)(free_mem * MEM_BUDGET_FRAC); + while (n_streams > 1 && (size_t)n_streams * per_stream_bytes > budget) + n_streams--; + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + + // Pin the source CSR arrays as mapped memory. 
The scatter kernel reads + // only the requested column window from each row. + HostRegisterGuard pin_data; + HostRegisterGuard pin_indices; + InT* d_data_zc = nullptr; + IndexT* d_indices_zc = nullptr; + if (total_nnz > 0) { + pin_data = + HostRegisterGuard(const_cast(h_data), total_nnz * sizeof(InT), + cudaHostRegisterMapped); + pin_indices = HostRegisterGuard(const_cast(h_indices), + total_nnz * sizeof(IndexT), + cudaHostRegisterMapped); + cudaError_t e1 = cudaHostGetDevicePointer((void**)&d_data_zc, + const_cast(h_data), 0); + cudaError_t e2 = cudaHostGetDevicePointer( + (void**)&d_indices_zc, const_cast(h_indices), 0); + if (e1 != cudaSuccess || e2 != cudaSuccess) { + throw std::runtime_error( + std::string("cudaHostGetDevicePointer failed: ") + + cudaGetErrorString(e1 != cudaSuccess ? e1 : e2)); + } + } + + IndptrT* d_indptr_full = pool.alloc(n_rows + 1); + cudaMemcpy(d_indptr_full, h_indptr, (n_rows + 1) * sizeof(IndptrT), + cudaMemcpyHostToDevice); + + int* d_group_codes = pool.alloc(n_rows); + double* d_group_sizes = pool.alloc(n_groups); + cudaMemcpy(d_group_codes, h_group_codes, n_rows * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(d_group_sizes, h_group_sizes, n_groups * sizeof(double), + cudaMemcpyHostToDevice); + + int scatter_blocks = (n_rows + tpb - 1) / tpb; + + struct StreamBuf { + int* col_offsets; + int* write_pos; + InT* csc_vals_orig; + float* csc_vals_f32; + int* csc_row_idx; + float* keys_out; + int* vals_out; + uint8_t* cub_temp; + double* sub_rank_sums; + double* sub_tie_corr; + double* sub_group_sums; + double* sub_group_sq_sums; + double* sub_group_nnz; + double* d_nz_scratch; + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + bufs[s].col_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].write_pos = pool.alloc(sub_batch_cols); + bufs[s].csc_vals_orig = pool.alloc(max_batch_nnz); + bufs[s].csc_vals_f32 = pool.alloc(max_batch_nnz); + bufs[s].csc_row_idx = pool.alloc(max_batch_nnz); + bufs[s].keys_out = pool.alloc(max_batch_nnz); + bufs[s].vals_out = pool.alloc(max_batch_nnz); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].sub_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].sub_tie_corr = pool.alloc(sub_batch_cols); + bufs[s].sub_group_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].sub_group_sq_sums = + compute_sq_sums + ? pool.alloc((size_t)n_groups * sub_batch_cols) + : nullptr; + bufs[s].sub_group_nnz = + compute_nnz ? pool.alloc((size_t)n_groups * sub_batch_cols) + : nullptr; + bufs[s].d_nz_scratch = + rank_use_gmem + ? 
pool.alloc((size_t)n_groups * sub_batch_cols) + : nullptr; + } + + cudaDeviceSynchronize(); + + // ---- Phase 2: bounded CSR->CSC scatter + GPU rank batches ---- + int col = 0; + for (int b = 0; b < n_batches; b++) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int s = b % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + int batch_nnz = + checked_int_span(h_batch_nnz[b], "OVR host CSR active batch nnz"); + + int* src = d_all_offsets + (size_t)b * (sub_batch_cols + 1); + cudaMemcpyAsync(buf.col_offsets, src, (sb_cols + 1) * sizeof(int), + cudaMemcpyDeviceToDevice, stream); + cudaMemcpyAsync(buf.write_pos, src, sb_cols * sizeof(int), + cudaMemcpyDeviceToDevice, stream); + + if (batch_nnz > 0) { + csr_scatter_to_csc_kernel + <<>>( + d_data_zc, d_indices_zc, d_indptr_full, buf.write_pos, + buf.csc_vals_orig, buf.csc_row_idx, n_rows, col, + col + sb_cols); + CUDA_CHECK_LAST_ERROR(csr_scatter_to_csc_kernel); + } + + launch_ovr_cast_and_accumulate_sparse( + buf.csc_vals_orig, buf.csc_vals_f32, buf.csc_row_idx, + buf.col_offsets, d_group_codes, buf.sub_group_sums, + buf.sub_group_sq_sums, buf.sub_group_nnz, sb_cols, n_groups, + compute_sq_sums, compute_nnz, tpb, smem_cast, cast_use_gmem, + stream); + + if (batch_nnz > 0) { + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortPairs( + buf.cub_temp, temp, buf.csc_vals_f32, buf.keys_out, + buf.csc_row_idx, buf.vals_out, batch_nnz, sb_cols, + buf.col_offsets, buf.col_offsets + 1, BEGIN_BIT, END_BIT, + stream); + } + + if (rank_use_gmem) { + cudaMemsetAsync(buf.sub_rank_sums, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + cudaMemsetAsync(buf.d_nz_scratch, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + } + rank_sums_sparse_ovr_kernel<<>>( + buf.keys_out, buf.vals_out, buf.col_offsets, d_group_codes, + d_group_sizes, buf.sub_rank_sums, buf.sub_tie_corr, + buf.d_nz_scratch, n_rows, sb_cols, n_groups, compute_tie_corr, + rank_use_gmem); + CUDA_CHECK_LAST_ERROR(rank_sums_sparse_ovr_kernel); + + cudaMemcpy2DAsync(d_rank_sums + col, n_cols * sizeof(double), + buf.sub_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpyAsync(d_tie_corr + col, buf.sub_tie_corr, + sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, + stream); + } + cudaMemcpy2DAsync(d_group_sums + col, n_cols * sizeof(double), + buf.sub_group_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_sq_sums) { + cudaMemcpy2DAsync(d_group_sq_sums + col, n_cols * sizeof(double), + buf.sub_group_sq_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + } + if (compute_nnz) { + cudaMemcpy2DAsync(d_group_nnz + col, n_cols * sizeof(double), + buf.sub_group_nnz, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + } + + col += sb_cols; + } + + for (int s = 0; s < n_streams; s++) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) + throw std::runtime_error( + std::string("CUDA error in sparse host CSR streaming: ") + + cudaGetErrorString(err)); + } + + for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); +} + +// ============================================================================ +// Sparse-aware CSC OVR streaming (sort only stored nonzeros) +// 
============================================================================ + +static void ovr_sparse_csc_streaming_impl( + const float* csc_data, const int* csc_indices, const int* csc_indptr, + const int* group_codes, const double* group_sizes, double* rank_sums, + double* tie_corr, int n_rows, int n_cols, int n_groups, + bool compute_tie_corr, int sub_batch_cols) { + if (n_rows == 0 || n_cols == 0) return; + + // Read indptr to host for batch planning + std::vector h_indptr(n_cols + 1); + cudaMemcpy(h_indptr.data(), csc_indptr, (n_cols + 1) * sizeof(int), + cudaMemcpyDeviceToHost); + + int n_streams = N_STREAMS; + if (n_cols < n_streams * sub_batch_cols) + n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + + // Find max nnz across any sub-batch for buffer sizing + size_t max_nnz = 0; + for (int col = 0; col < n_cols; col += sub_batch_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + size_t nnz = (size_t)(h_indptr[col + sb_cols] - h_indptr[col]); + if (nnz > max_nnz) max_nnz = nnz; + } + + // CUB temp size for max_nnz items + size_t cub_temp_bytes = 0; + if (max_nnz > 0) { + int max_nnz_i32 = + checked_cub_items(max_nnz, "OVR device CSC sparse sub-batch nnz"); + auto* fk = reinterpret_cast(1); + auto* iv = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortPairs( + nullptr, cub_temp_bytes, fk, fk, iv, iv, max_nnz_i32, + sub_batch_cols, iv, iv + 1, BEGIN_BIT, END_BIT); + } + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + + int tpb = UTIL_BLOCK_SIZE; + bool rank_use_gmem = false; + size_t smem_bytes = sparse_ovr_smem_config(n_groups, rank_use_gmem); + + RmmScratchPool pool; + struct StreamBuf { + float* keys_out; + int* vals_out; + int* seg_offsets; + uint8_t* cub_temp; + double* sub_rank_sums; + double* sub_tie_corr; + double* d_nz_scratch; // gmem-only + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + bufs[s].keys_out = pool.alloc(max_nnz); + bufs[s].vals_out = pool.alloc(max_nnz); + bufs[s].seg_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].sub_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].sub_tie_corr = pool.alloc(sub_batch_cols); + bufs[s].d_nz_scratch = + rank_use_gmem + ? 
pool.alloc((size_t)n_groups * sub_batch_cols) + : nullptr; + } + + cudaDeviceSynchronize(); + + int col = 0; + int batch_idx = 0; + while (col < n_cols) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int s = batch_idx % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + + int ptr_start = h_indptr[col]; + int ptr_end = h_indptr[col + sb_cols]; + int batch_nnz = checked_int_span((size_t)(ptr_end - ptr_start), + "OVR device CSC active batch nnz"); + + // Compute rebased segment offsets on GPU (avoids host pinned-buffer + // race) + { + int count = sb_cols + 1; + int blk = (count + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + rebase_indptr_kernel<<>>( + csc_indptr, buf.seg_offsets, col, count); + CUDA_CHECK_LAST_ERROR(rebase_indptr_kernel); + } + + // Sort only stored values (keys=data, vals=row_indices) + if (batch_nnz > 0) { + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortPairs( + buf.cub_temp, temp, csc_data + ptr_start, buf.keys_out, + csc_indices + ptr_start, buf.vals_out, batch_nnz, sb_cols, + buf.seg_offsets, buf.seg_offsets + 1, BEGIN_BIT, END_BIT, + stream); + } + + // Sparse rank kernel (handles implicit zeros analytically) + if (rank_use_gmem) { + cudaMemsetAsync(buf.sub_rank_sums, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + cudaMemsetAsync(buf.d_nz_scratch, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + } + rank_sums_sparse_ovr_kernel<<>>( + buf.keys_out, buf.vals_out, buf.seg_offsets, group_codes, + group_sizes, buf.sub_rank_sums, buf.sub_tie_corr, buf.d_nz_scratch, + n_rows, sb_cols, n_groups, compute_tie_corr, rank_use_gmem); + CUDA_CHECK_LAST_ERROR(rank_sums_sparse_ovr_kernel); + + // Scatter results to global output + cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), + buf.sub_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpyAsync(tie_corr + col, buf.sub_tie_corr, + sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, + stream); + } + + col += sb_cols; + batch_idx++; + } + + for (int s = 0; s < n_streams; s++) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) + throw std::runtime_error( + std::string("CUDA error in sparse ovr streaming: ") + + cudaGetErrorString(err)); + } + + for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); +} + +// ============================================================================ +// Sparse-aware CSR OVR streaming (partial CSR→CSC transpose per sub-batch) +// ============================================================================ + +/** + * Sparse-aware OVR streaming pipeline for GPU CSR data. + * + * Phase 0: One histogram kernel counts nnz per column. D2H + host prefix sums + * give exact per-batch nnz and max_batch_nnz for buffer sizing. + * Phase 1: Allocate per-stream buffers sized to max_batch_nnz. + * Phase 2: For each sub-batch: scatter CSR→CSC (partial transpose via + * atomics) → CUB sort only nonzeros → sparse rank kernel. + * + * Compared to the dense CSR path, sort work drops by ~1/sparsity. 
+ */ +static void ovr_sparse_csr_streaming_impl( + const float* csr_data, const int* csr_indices, const int* csr_indptr, + const int* group_codes, const double* group_sizes, double* rank_sums, + double* tie_corr, int n_rows, int n_cols, int n_groups, + bool compute_tie_corr, int sub_batch_cols) { + if (n_rows == 0 || n_cols == 0) return; + + // ---- Phase 0: Planning — count nnz per column via histogram ---- + RmmScratchPool pool; + unsigned int* d_col_counts = pool.alloc(n_cols); + cudaMemset(d_col_counts, 0, n_cols * sizeof(unsigned int)); + { + int blocks = (n_rows + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE; + csr_col_histogram_kernel<<>>( + csr_indices, csr_indptr, d_col_counts, n_rows, n_cols); + CUDA_CHECK_LAST_ERROR(csr_col_histogram_kernel); + } + std::vector h_col_counts(n_cols); + cudaMemcpy(h_col_counts.data(), d_col_counts, n_cols * sizeof(unsigned int), + cudaMemcpyDeviceToHost); + + // Per-batch prefix sums on host + int n_batches = (n_cols + sub_batch_cols - 1) / sub_batch_cols; + size_t max_batch_nnz = 0; + + // Flat array: n_batches × (sub_batch_cols + 1) offsets + std::vector h_all_offsets((size_t)n_batches * (sub_batch_cols + 1), 0); + std::vector h_batch_nnz(n_batches); + + for (int b = 0; b < n_batches; b++) { + int col_start = b * sub_batch_cols; + int sb_cols = std::min(sub_batch_cols, n_cols - col_start); + int* off = &h_all_offsets[(size_t)b * (sub_batch_cols + 1)]; + off[0] = 0; + for (int i = 0; i < sb_cols; i++) + off[i + 1] = checked_int_span( + (size_t)off[i] + (size_t)h_col_counts[col_start + i], + "OVR device CSR rebased column offsets"); + h_batch_nnz[b] = (size_t)off[sb_cols]; + if (h_batch_nnz[b] > max_batch_nnz) max_batch_nnz = h_batch_nnz[b]; + } + + // Upload all batch offsets to GPU in one shot (~20 KB) + int* d_all_offsets = + pool.alloc((size_t)n_batches * (sub_batch_cols + 1)); + cudaMemcpy(d_all_offsets, h_all_offsets.data(), + h_all_offsets.size() * sizeof(int), cudaMemcpyHostToDevice); + + // ---- Phase 1: Allocate per-stream buffers ---- + size_t cub_temp_bytes = 0; + if (max_batch_nnz > 0) { + int max_batch_nnz_i32 = checked_cub_items( + max_batch_nnz, "OVR device CSR sparse sub-batch nnz"); + auto* fk = reinterpret_cast(1); + auto* iv = reinterpret_cast(1); + cub::DeviceSegmentedRadixSort::SortPairs( + nullptr, cub_temp_bytes, fk, fk, iv, iv, max_batch_nnz_i32, + sub_batch_cols, iv, iv + 1, BEGIN_BIT, END_BIT); + } + + int n_streams = N_STREAMS; + if (n_batches < n_streams) n_streams = n_batches; + + // CSR path needs 4 sort arrays per stream (scatter intermediates + + // CUB output). Fit stream count to available GPU memory. 
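With invented numbers for illustration: if per_stream_bytes works out to about 1.5 GB and cudaMemGetInfo reports 4 GB free, the 80% budget is about 3.2 GB, so the shrink loop below settles on n_streams = 2 (two per-stream buffer sets fit, three do not). For per_stream_bytes > 0 the loop is equivalent to n_streams = max(1, min(n_streams, budget / per_stream_bytes)).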
+ size_t per_stream_bytes = + max_batch_nnz * (2 * sizeof(float) + 2 * sizeof(int)) + + (sub_batch_cols + 1 + sub_batch_cols) * sizeof(int) + cub_temp_bytes + + (size_t)n_groups * sub_batch_cols * sizeof(double) + + sub_batch_cols * sizeof(double); + + size_t free_mem = 0, total_mem = 0; + cudaMemGetInfo(&free_mem, &total_mem); + constexpr double MEM_BUDGET_FRAC = 0.8; + size_t budget = (size_t)(free_mem * MEM_BUDGET_FRAC); + while (n_streams > 1 && (size_t)n_streams * per_stream_bytes > budget) + n_streams--; + + std::vector streams(n_streams); + for (int i = 0; i < n_streams; i++) cudaStreamCreate(&streams[i]); + + int tpb = UTIL_BLOCK_SIZE; + bool rank_use_gmem = false; + size_t smem_bytes = sparse_ovr_smem_config(n_groups, rank_use_gmem); + int scatter_blocks = (n_rows + tpb - 1) / tpb; + + struct StreamBuf { + int* col_offsets; // [sub_batch_cols + 1] CSC-style offsets + int* write_pos; // [sub_batch_cols] atomic write counters + float* csc_vals; // [max_batch_nnz] transposed values + int* csc_row_idx; // [max_batch_nnz] transposed row indices + float* keys_out; // [max_batch_nnz] CUB sort output + int* vals_out; // [max_batch_nnz] CUB sort output + uint8_t* cub_temp; + double* sub_rank_sums; + double* sub_tie_corr; + double* d_nz_scratch; // gmem-only + }; + std::vector bufs(n_streams); + for (int s = 0; s < n_streams; s++) { + bufs[s].col_offsets = pool.alloc(sub_batch_cols + 1); + bufs[s].write_pos = pool.alloc(sub_batch_cols); + bufs[s].csc_vals = pool.alloc(max_batch_nnz); + bufs[s].csc_row_idx = pool.alloc(max_batch_nnz); + bufs[s].keys_out = pool.alloc(max_batch_nnz); + bufs[s].vals_out = pool.alloc(max_batch_nnz); + bufs[s].cub_temp = pool.alloc(cub_temp_bytes); + bufs[s].sub_rank_sums = + pool.alloc((size_t)n_groups * sub_batch_cols); + bufs[s].sub_tie_corr = pool.alloc(sub_batch_cols); + bufs[s].d_nz_scratch = + rank_use_gmem + ? 
pool.alloc((size_t)n_groups * sub_batch_cols) + : nullptr; + } + + cudaDeviceSynchronize(); + + // ---- Phase 2: Stream loop ---- + int col = 0; + for (int b = 0; b < n_batches; b++) { + int sb_cols = std::min(sub_batch_cols, n_cols - col); + int s = b % n_streams; + auto stream = streams[s]; + auto& buf = bufs[s]; + int batch_nnz = + checked_int_span(h_batch_nnz[b], "OVR device CSR active batch nnz"); + + // D2D copy pre-computed col_offsets for this batch + int* src = d_all_offsets + (size_t)b * (sub_batch_cols + 1); + cudaMemcpyAsync(buf.col_offsets, src, (sb_cols + 1) * sizeof(int), + cudaMemcpyDeviceToDevice, stream); + + // Initialize write_pos = col_offsets[0..sb_cols-1] (same D2D source) + cudaMemcpyAsync(buf.write_pos, src, sb_cols * sizeof(int), + cudaMemcpyDeviceToDevice, stream); + + if (batch_nnz > 0) { + // Scatter CSR → CSC layout for this sub-batch + csr_scatter_to_csc_kernel<<>>( + csr_data, csr_indices, csr_indptr, buf.write_pos, buf.csc_vals, + buf.csc_row_idx, n_rows, col, col + sb_cols); + CUDA_CHECK_LAST_ERROR(csr_scatter_to_csc_kernel); + + // CUB sort only the nonzeros + size_t temp = cub_temp_bytes; + cub::DeviceSegmentedRadixSort::SortPairs( + buf.cub_temp, temp, buf.csc_vals, buf.keys_out, buf.csc_row_idx, + buf.vals_out, batch_nnz, sb_cols, buf.col_offsets, + buf.col_offsets + 1, BEGIN_BIT, END_BIT, stream); + } + + // Sparse rank kernel (handles implicit zeros analytically) + if (rank_use_gmem) { + cudaMemsetAsync(buf.sub_rank_sums, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + cudaMemsetAsync(buf.d_nz_scratch, 0, + (size_t)n_groups * sb_cols * sizeof(double), + stream); + } + rank_sums_sparse_ovr_kernel<<>>( + buf.keys_out, buf.vals_out, buf.col_offsets, group_codes, + group_sizes, buf.sub_rank_sums, buf.sub_tie_corr, buf.d_nz_scratch, + n_rows, sb_cols, n_groups, compute_tie_corr, rank_use_gmem); + CUDA_CHECK_LAST_ERROR(rank_sums_sparse_ovr_kernel); + + // Scatter results to global output + cudaMemcpy2DAsync(rank_sums + col, n_cols * sizeof(double), + buf.sub_rank_sums, sb_cols * sizeof(double), + sb_cols * sizeof(double), n_groups, + cudaMemcpyDeviceToDevice, stream); + if (compute_tie_corr) { + cudaMemcpyAsync(tie_corr + col, buf.sub_tie_corr, + sb_cols * sizeof(double), cudaMemcpyDeviceToDevice, + stream); + } + + col += sb_cols; + } + + for (int s = 0; s < n_streams; s++) { + cudaError_t err = cudaStreamSynchronize(streams[s]); + if (err != cudaSuccess) + throw std::runtime_error( + std::string("CUDA error in sparse CSR ovr streaming: ") + + cudaGetErrorString(err)); + } + + for (int s = 0; s < n_streams; s++) cudaStreamDestroy(streams[s]); +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_rmm.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_rmm.cu new file mode 100644 index 00000000..94a101e9 --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_rmm.cu @@ -0,0 +1,20 @@ +#include +#include +#include + +#include +#include + +void* wilcoxon_rmm_allocate(size_t bytes) { + try { + return rmm::mr::get_current_device_resource()->allocate_sync(bytes); + } catch (std::exception const& e) { + throw std::runtime_error( + std::string("RMM allocation failed in Wilcoxon scratch (") + + std::to_string(bytes) + " bytes): " + e.what()); + } +} + +void wilcoxon_rmm_deallocate(void* ptr, size_t bytes) { + rmm::mr::get_current_device_resource()->deallocate_sync(ptr, bytes); +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu new file mode 100644 index 
00000000..4316d284 --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse.cu @@ -0,0 +1,302 @@ +#include +#include + +#include + +#include "../nb_types.h" +#include "wilcoxon_fast_common.cuh" +#include "wilcoxon_sparse_kernels.cuh" +#include "wilcoxon_ovr_kernels.cuh" +#include "wilcoxon_ovr_sparse.cuh" +#include "kernels_wilcoxon_ovo.cuh" +#include "wilcoxon_ovo_kernels.cuh" +#include "wilcoxon_ovo_device_sparse.cuh" +#include "wilcoxon_ovo_host_sparse.cuh" + +using namespace nb::literals; + +template +void register_sparse_bindings(nb::module_& m) { + m.doc() = "Sparse-native host Wilcoxon CUDA kernels"; + + m.def( + "ovr_sparse_csc_device", + [](gpu_array_c csc_data, + gpu_array_c csc_indices, + gpu_array_c csc_indptr, + gpu_array_c group_codes, + gpu_array_c group_sizes, + gpu_array_c rank_sums, + gpu_array_c tie_corr, int n_rows, int n_cols, + int n_groups, bool compute_tie_corr, int sub_batch_cols) { + ovr_sparse_csc_streaming_impl( + csc_data.data(), csc_indices.data(), csc_indptr.data(), + group_codes.data(), group_sizes.data(), rank_sums.data(), + tie_corr.data(), n_rows, n_cols, n_groups, compute_tie_corr, + sub_batch_cols); + }, + "csc_data"_a, "csc_indices"_a, "csc_indptr"_a, "group_codes"_a, + "group_sizes"_a, "rank_sums"_a, "tie_corr"_a, nb::kw_only(), "n_rows"_a, + "n_cols"_a, "n_groups"_a, "compute_tie_corr"_a, + "sub_batch_cols"_a = SUB_BATCH_COLS); + + m.def( + "ovr_sparse_csr_device", + [](gpu_array_c csr_data, + gpu_array_c csr_indices, + gpu_array_c csr_indptr, + gpu_array_c group_codes, + gpu_array_c group_sizes, + gpu_array_c rank_sums, + gpu_array_c tie_corr, int n_rows, int n_cols, + int n_groups, bool compute_tie_corr, int sub_batch_cols) { + ovr_sparse_csr_streaming_impl( + csr_data.data(), csr_indices.data(), csr_indptr.data(), + group_codes.data(), group_sizes.data(), rank_sums.data(), + tie_corr.data(), n_rows, n_cols, n_groups, compute_tie_corr, + sub_batch_cols); + }, + "csr_data"_a, "csr_indices"_a, "csr_indptr"_a, "group_codes"_a, + "group_sizes"_a, "rank_sums"_a, "tie_corr"_a, nb::kw_only(), "n_rows"_a, + "n_cols"_a, "n_groups"_a, "compute_tie_corr"_a, + "sub_batch_cols"_a = SUB_BATCH_COLS); + +#define RSC_OVR_SPARSE_CSC_HOST_BINDING(NAME, InT, IndexT, IndptrT) \ + m.def( \ + NAME, \ + [](host_array h_data, host_array h_indices, \ + host_array h_indptr, \ + host_array h_group_codes, \ + host_array h_group_sizes, \ + gpu_array_c d_rank_sums, \ + gpu_array_c d_tie_corr, \ + gpu_array_c d_group_sums, \ + gpu_array_c d_group_sq_sums, \ + gpu_array_c d_group_nnz, int n_rows, int n_cols, \ + int n_groups, bool compute_tie_corr, bool compute_sq_sums, \ + bool compute_nnz, int sub_batch_cols) { \ + ovr_sparse_csc_host_streaming_impl( \ + h_data.data(), h_indices.data(), h_indptr.data(), \ + h_group_codes.data(), h_group_sizes.data(), \ + d_rank_sums.data(), d_tie_corr.data(), d_group_sums.data(), \ + d_group_sq_sums.data(), d_group_nnz.data(), n_rows, n_cols, \ + n_groups, compute_tie_corr, compute_sq_sums, compute_nnz, \ + sub_batch_cols); \ + }, \ + "h_data"_a, "h_indices"_a, "h_indptr"_a, "h_group_codes"_a, \ + "h_group_sizes"_a, "d_rank_sums"_a, "d_tie_corr"_a, "d_group_sums"_a, \ + "d_group_sq_sums"_a, "d_group_nnz"_a, nb::kw_only(), "n_rows"_a, \ + "n_cols"_a, "n_groups"_a, "compute_tie_corr"_a, \ + "compute_sq_sums"_a = true, "compute_nnz"_a = true, \ + "sub_batch_cols"_a = SUB_BATCH_COLS) + + RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host", float, int, int); + RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_i64", float, int, + 
int64_t); + RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_idx64", float, int64_t, + int); + RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_idx64_i64", float, + int64_t, int64_t); + RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_f64", double, int, + int); + RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_f64_i64", double, int, + int64_t); + RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_f64_idx64", double, + int64_t, int); + RSC_OVR_SPARSE_CSC_HOST_BINDING("ovr_sparse_csc_host_f64_idx64_i64", double, + int64_t, int64_t); +#undef RSC_OVR_SPARSE_CSC_HOST_BINDING + +#define RSC_OVR_SPARSE_CSR_HOST_BINDING(NAME, InT, IndexT, IndptrT) \ + m.def( \ + NAME, \ + [](host_array h_data, host_array h_indices, \ + host_array h_indptr, \ + host_array h_group_codes, \ + host_array h_group_sizes, \ + gpu_array_c d_rank_sums, \ + gpu_array_c d_tie_corr, \ + gpu_array_c d_group_sums, \ + gpu_array_c d_group_sq_sums, \ + gpu_array_c d_group_nnz, int n_rows, int n_cols, \ + int n_groups, bool compute_tie_corr, bool compute_sq_sums, \ + bool compute_nnz, int sub_batch_cols) { \ + ovr_sparse_csr_host_streaming_impl( \ + h_data.data(), h_indices.data(), h_indptr.data(), \ + h_group_codes.data(), h_group_sizes.data(), \ + d_rank_sums.data(), d_tie_corr.data(), d_group_sums.data(), \ + d_group_sq_sums.data(), d_group_nnz.data(), n_rows, n_cols, \ + n_groups, compute_tie_corr, compute_sq_sums, compute_nnz, \ + sub_batch_cols); \ + }, \ + "h_data"_a, "h_indices"_a, "h_indptr"_a, "h_group_codes"_a, \ + "h_group_sizes"_a, "d_rank_sums"_a, "d_tie_corr"_a, "d_group_sums"_a, \ + "d_group_sq_sums"_a, "d_group_nnz"_a, nb::kw_only(), "n_rows"_a, \ + "n_cols"_a, "n_groups"_a, "compute_tie_corr"_a, \ + "compute_sq_sums"_a = true, "compute_nnz"_a = true, \ + "sub_batch_cols"_a = SUB_BATCH_COLS) + + RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host", float, int, int); + RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host_i64", float, int, + int64_t); + RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host_idx64", float, int64_t, + int); + RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host_idx64_i64", float, + int64_t, int64_t); + RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host_f64", double, int, + int); + RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host_f64_i64", double, int, + int64_t); + RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host_f64_idx64", double, + int64_t, int); + RSC_OVR_SPARSE_CSR_HOST_BINDING("ovr_sparse_csr_host_f64_idx64_i64", double, + int64_t, int64_t); +#undef RSC_OVR_SPARSE_CSR_HOST_BINDING + + m.def( + "ovo_streaming_csc_device", + [](gpu_array_c csc_data, + gpu_array_c csc_indices, + gpu_array_c csc_indptr, + gpu_array_c ref_row_map, + gpu_array_c grp_row_map, + gpu_array_c grp_offsets, + gpu_array_c rank_sums, + gpu_array_c tie_corr, int n_ref, int n_all_grp, + int n_cols, int n_groups, bool compute_tie_corr, + int sub_batch_cols) { + ovo_streaming_csc_impl( + csc_data.data(), csc_indices.data(), csc_indptr.data(), + ref_row_map.data(), grp_row_map.data(), grp_offsets.data(), + rank_sums.data(), tie_corr.data(), n_ref, n_all_grp, n_cols, + n_groups, compute_tie_corr, sub_batch_cols); + }, + "csc_data"_a, "csc_indices"_a, "csc_indptr"_a, "ref_row_map"_a, + "grp_row_map"_a, "grp_offsets"_a, "rank_sums"_a, "tie_corr"_a, + nb::kw_only(), "n_ref"_a, "n_all_grp"_a, "n_cols"_a, "n_groups"_a, + "compute_tie_corr"_a, "sub_batch_cols"_a = SUB_BATCH_COLS); + + m.def( + "ovo_streaming_csr_device", + [](gpu_array_c csr_data, + gpu_array_c csr_indices, + gpu_array_c 
csr_indptr, + gpu_array_c ref_row_ids, + gpu_array_c grp_row_ids, + gpu_array_c grp_offsets, + gpu_array_c rank_sums, + gpu_array_c tie_corr, int n_ref, int n_all_grp, + int n_cols, int n_groups, bool compute_tie_corr, + int sub_batch_cols) { + ovo_streaming_csr_impl( + csr_data.data(), csr_indices.data(), csr_indptr.data(), + ref_row_ids.data(), grp_row_ids.data(), grp_offsets.data(), + rank_sums.data(), tie_corr.data(), n_ref, n_all_grp, n_cols, + n_groups, compute_tie_corr, sub_batch_cols); + }, + "csr_data"_a, "csr_indices"_a, "csr_indptr"_a, "ref_row_ids"_a, + "grp_row_ids"_a, "grp_offsets"_a, "rank_sums"_a, "tie_corr"_a, + nb::kw_only(), "n_ref"_a, "n_all_grp"_a, "n_cols"_a, "n_groups"_a, + "compute_tie_corr"_a, "sub_batch_cols"_a = SUB_BATCH_COLS); + +#define RSC_OVO_CSC_HOST_BINDING(NAME, InT, IndexT, IndptrT) \ + m.def( \ + NAME, \ + [](host_array h_data, host_array h_indices, \ + host_array h_indptr, \ + host_array h_ref_row_map, \ + host_array h_grp_row_map, \ + host_array h_grp_offsets, \ + host_array h_stats_codes, \ + gpu_array_c d_rank_sums, \ + gpu_array_c d_tie_corr, \ + gpu_array_c d_group_sums, \ + gpu_array_c d_group_sq_sums, \ + gpu_array_c d_group_nnz, int n_ref, int n_all_grp, \ + int n_rows, int n_cols, int n_groups, int n_groups_stats, \ + bool compute_tie_corr, bool compute_sq_sums, bool compute_nnz, \ + int sub_batch_cols) { \ + ovo_streaming_csc_host_impl( \ + h_data.data(), h_indices.data(), h_indptr.data(), \ + h_ref_row_map.data(), h_grp_row_map.data(), \ + h_grp_offsets.data(), h_stats_codes.data(), \ + d_rank_sums.data(), d_tie_corr.data(), d_group_sums.data(), \ + d_group_sq_sums.data(), d_group_nnz.data(), n_ref, n_all_grp, \ + n_rows, n_cols, n_groups, n_groups_stats, compute_tie_corr, \ + compute_sq_sums, compute_nnz, sub_batch_cols); \ + }, \ + "h_data"_a, "h_indices"_a, "h_indptr"_a, "h_ref_row_map"_a, \ + "h_grp_row_map"_a, "h_grp_offsets"_a, "h_stats_codes"_a, \ + "d_rank_sums"_a, "d_tie_corr"_a, "d_group_sums"_a, \ + "d_group_sq_sums"_a, "d_group_nnz"_a, nb::kw_only(), "n_ref"_a, \ + "n_all_grp"_a, "n_rows"_a, "n_cols"_a, "n_groups"_a, \ + "n_groups_stats"_a, "compute_tie_corr"_a, "compute_sq_sums"_a = true, \ + "compute_nnz"_a = true, "sub_batch_cols"_a = SUB_BATCH_COLS) + + RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host", float, int, int); + RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_i64", float, int, int64_t); + RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_idx64", float, int64_t, + int); + RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_idx64_i64", float, int64_t, + int64_t); + RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_f64", double, int, int); + RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_f64_i64", double, int, + int64_t); + RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_f64_idx64", double, + int64_t, int); + RSC_OVO_CSC_HOST_BINDING("ovo_streaming_csc_host_f64_idx64_i64", double, + int64_t, int64_t); +#undef RSC_OVO_CSC_HOST_BINDING + +#define RSC_OVO_CSR_HOST_BINDING(NAME, InT, IndexT, IndptrT) \ + m.def( \ + NAME, \ + [](host_array h_data, host_array h_indices, \ + host_array h_indptr, \ + host_array h_ref_row_ids, \ + host_array h_grp_row_ids, \ + host_array h_grp_offsets, \ + gpu_array_c d_rank_sums, \ + gpu_array_c d_tie_corr, \ + gpu_array_c d_group_sums, \ + gpu_array_c d_group_sq_sums, \ + gpu_array_c d_group_nnz, int n_full_rows, \ + int n_ref, int n_all_grp, int n_cols, int n_test, \ + int n_groups_stats, bool compute_tie_corr, bool compute_sq_sums, \ + bool compute_nnz, bool compute_sums, int 
sub_batch_cols) { \ + ovo_streaming_csr_host_impl( \ + h_data.data(), h_indices.data(), h_indptr.data(), n_full_rows, \ + h_ref_row_ids.data(), n_ref, h_grp_row_ids.data(), \ + h_grp_offsets.data(), n_all_grp, n_test, d_rank_sums.data(), \ + d_tie_corr.data(), d_group_sums.data(), \ + d_group_sq_sums.data(), d_group_nnz.data(), n_cols, \ + n_groups_stats, compute_tie_corr, compute_sq_sums, \ + compute_nnz, compute_sums, sub_batch_cols); \ + }, \ + "h_data"_a, "h_indices"_a, "h_indptr"_a, "h_ref_row_ids"_a, \ + "h_grp_row_ids"_a, "h_grp_offsets"_a, "d_rank_sums"_a, "d_tie_corr"_a, \ + "d_group_sums"_a, "d_group_sq_sums"_a, "d_group_nnz"_a, nb::kw_only(), \ + "n_full_rows"_a, "n_ref"_a, "n_all_grp"_a, "n_cols"_a, "n_test"_a, \ + "n_groups_stats"_a, "compute_tie_corr"_a, "compute_sq_sums"_a = true, \ + "compute_nnz"_a = true, "compute_sums"_a = true, \ + "sub_batch_cols"_a = SUB_BATCH_COLS) + + RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host", float, int, int); + RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_i64", float, int, int64_t); + RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_idx64", float, int64_t, + int); + RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_idx64_i64", float, int64_t, + int64_t); + RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_f64", double, int, int); + RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_f64_i64", double, int, + int64_t); + RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_f64_idx64", double, + int64_t, int); + RSC_OVO_CSR_HOST_BINDING("ovo_streaming_csr_host_f64_idx64_i64", double, + int64_t, int64_t); +#undef RSC_OVO_CSR_HOST_BINDING +} + +NB_MODULE(_wilcoxon_sparse_cuda, m) { + REGISTER_GPU_BINDINGS(register_sparse_bindings, m); +} diff --git a/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh new file mode 100644 index 00000000..efdac894 --- /dev/null +++ b/src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_sparse_kernels.cuh @@ -0,0 +1,494 @@ +#pragma once + +#include + +/** + * Sparse-aware OVR rank-sum kernel for nonnegative sorted stored values. + * + * Sparse rank_genes_groups now rejects explicit negative sparse values before + * reaching CUDA, so after CUB sort each column segment is: + * [stored_zeros..., positives...] + * + * Implicit zeros (n_rows - nnz_stored) join stored zeros as the first tie + * block. The kernel ranks only stored positive values and adds each group's + * zero contribution analytically. + * + * Full sorted array (conceptual): + * [ALL_zeros (stored+implicit)..., positives...] + * + * Rank offsets: + * positive at stored pos i : full pos = i + n_implicit_zero + * zeros : avg rank = (total_zero + 1) / 2 + * + * Shared-memory layout (doubles): + * grp_sums[n_groups] rank-sum accumulators + * grp_nz_count[n_groups] nonzero-per-group counters + * warp_buf[32] tie-correction reduction scratch + * + * n_rows is the ranking population, including rows whose group code is the + * n_groups sentinel. Sentinel rows contribute to the "rest" distribution and + * tie-correction denominator but do not receive rank-sum accumulation. 
+ * + * Grid: (sb_cols,) Block: (tpb,) + */ +template +__global__ void rank_sums_sparse_ovr_kernel( + const float* __restrict__ sorted_vals, + const IndexT* __restrict__ sorted_row_idx, + const int* __restrict__ col_seg_offsets, + const int* __restrict__ group_codes, const double* __restrict__ group_sizes, + double* __restrict__ rank_sums, double* __restrict__ tie_corr, + double* __restrict__ nz_count_scratch, int n_rows, int sb_cols, + int n_groups, bool compute_tie_corr, bool use_gmem) { + int col = blockIdx.x; + if (col >= sb_cols) return; + + int seg_start = col_seg_offsets[col]; + int seg_end = col_seg_offsets[col + 1]; + int nnz_stored = seg_end - seg_start; + + const float* sv = sorted_vals + seg_start; + const IndexT* si = sorted_row_idx + seg_start; + + extern __shared__ double smem[]; + double* grp_sums; + double* grp_nz_count; + // Accumulator stride: 1 for shared mem (dense per-block), sb_cols for + // gmem (row-major layout (n_groups, sb_cols) shared across blocks). + int acc_stride; + + if (use_gmem) { + // Output rank_sums doubles as accumulator (pre-zeroed by caller). + grp_sums = rank_sums + (size_t)col; + grp_nz_count = nz_count_scratch + (size_t)col; + acc_stride = sb_cols; + } else { + grp_sums = smem; + grp_nz_count = smem + n_groups; + acc_stride = 1; + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { + grp_sums[g] = 0.0; + grp_nz_count[g] = 0.0; + } + __syncthreads(); + } + + // --- Find stored zero range: pos_start = first val > 0 --- + __shared__ int sh_pos_start; + if (threadIdx.x == 0) { + // Binary search: first index where sv[i] > 0.0 + int lo = 0, hi = nnz_stored; + while (lo < hi) { + int mid = lo + ((hi - lo) >> 1); + if (sv[mid] <= 0.0f) + lo = mid + 1; + else + hi = mid; + } + sh_pos_start = lo; + } + __syncthreads(); + + int pos_start = sh_pos_start; + int n_stored_zero = pos_start; + int n_implicit_zero = n_rows - nnz_stored; + int total_zero = n_implicit_zero + n_stored_zero; + double zero_avg_rank = (total_zero > 0) ? 
(total_zero + 1.0) / 2.0 : 0.0; + + // Rank offset for positive stored values: + // full_pos(i) = i + n_implicit_zero for i >= pos_start + // So avg_rank for tie group [a,b) of positives: + // = n_implicit_zero + (a + b + 1) / 2 + int offset_pos = n_implicit_zero; + + // --- Count stored positive values per group --- + for (int i = pos_start + threadIdx.x; i < nnz_stored; i += blockDim.x) { + int grp = group_codes[si[i]]; + if (grp < n_groups) { + atomicAdd(&grp_nz_count[grp * acc_stride], 1.0); + } + } + __syncthreads(); + + // --- Zero-rank contribution per group --- + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { + double n_zero_in_g = group_sizes[g] - grp_nz_count[g * acc_stride]; + grp_sums[g * acc_stride] = n_zero_in_g * zero_avg_rank; + } + __syncthreads(); + + // --- Walk stored positives only and compute ranks --- + int n_pos = nnz_stored - pos_start; + int chunk = (n_pos + blockDim.x - 1) / blockDim.x; + int my_start = pos_start + threadIdx.x * chunk; + int my_end = my_start + chunk; + if (my_end > nnz_stored) my_end = nnz_stored; + + double local_tie_sum = 0.0; + + int i = my_start; + while (i < my_end) { + float val = sv[i]; + + int tie_local_end = i + 1; + while (tie_local_end < my_end && sv[tie_local_end] == val) + ++tie_local_end; + + int tie_global_start = i; + if (i == my_start && i > 0 && sv[i - 1] == val) { + // Binary search for first occurrence + int lo = pos_start, hi = i; + while (lo < hi) { + int mid = lo + ((hi - lo) >> 1); + if (sv[mid] < val) + lo = mid + 1; + else + hi = mid; + } + tie_global_start = lo; + } + + int tie_global_end = tie_local_end; + if (tie_local_end == my_end && tie_local_end < nnz_stored && + sv[tie_local_end] == val) { + int lo = tie_local_end, hi = nnz_stored - 1; + while (lo < hi) { + int mid = hi - ((hi - lo) >> 1); + if (sv[mid] > val) + hi = mid - 1; + else + lo = mid; + } + tie_global_end = lo + 1; + } + + int total_tie = tie_global_end - tie_global_start; + + double avg_rank = (double)offset_pos + + (double)(tie_global_start + tie_global_end + 1) / 2.0; + + for (int j = i; j < tie_local_end; ++j) { + int grp = group_codes[si[j]]; + if (grp < n_groups) { + atomicAdd(&grp_sums[grp * acc_stride], avg_rank); + } + } + + if (compute_tie_corr && tie_global_start >= my_start && total_tie > 1) { + double t = (double)total_tie; + local_tie_sum += t * t * t - t; + } + + i = tie_local_end; + } + + __syncthreads(); + + // Write rank sums to global output (smem path only — gmem path is direct) + if (!use_gmem) { + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { + rank_sums[(size_t)g * sb_cols + col] = grp_sums[g]; + } + } + + // Tie correction: warp + block reduction + if (compute_tie_corr) { + // Zero tie group contribution (one thread only) + if (threadIdx.x == 0 && total_zero > 1) { + double tz = (double)total_zero; + local_tie_sum += tz * tz * tz - tz; + } + + // smem path: warp buf after both accumulator arrays (2 * n_groups). + // gmem path: accumulators are in gmem, warp buf starts at smem[0]. + int warp_buf_off = use_gmem ? 0 : 2 * n_groups; + double* warp_buf = smem + warp_buf_off; + +#pragma unroll + for (int off = 16; off > 0; off >>= 1) + local_tie_sum += __shfl_down_sync(0xffffffff, local_tie_sum, off); + int lane = threadIdx.x & 31; + int wid = threadIdx.x >> 5; + if (lane == 0) warp_buf[wid] = local_tie_sum; + __syncthreads(); + if (threadIdx.x < 32) { + double v = (threadIdx.x < ((blockDim.x + 31) >> 5)) + ? 
warp_buf[threadIdx.x] + : 0.0; +#pragma unroll + for (int off = 16; off > 0; off >>= 1) + v += __shfl_down_sync(0xffffffff, v, off); + if (threadIdx.x == 0) { + double n = (double)n_rows; + double denom = n * n * n - n; + tie_corr[col] = (denom > 0.0) ? (1.0 - v / denom) : 1.0; + } + } + } +} + +static size_t cast_accumulate_smem_config(int n_groups, bool compute_sq_sums, + bool compute_nnz, bool& use_gmem) { + int n_arrays = 1 + (compute_sq_sums ? 1 : 0) + (compute_nnz ? 1 : 0); + size_t need = (size_t)n_arrays * n_groups * sizeof(double); + if (need <= wilcoxon_max_smem_per_block()) { + use_gmem = false; + return need; + } + use_gmem = true; + return 0; +} + +/** + * Pre-sort cast-and-accumulate kernel for dense OVR host streaming. + * + * Reads a sub-batch block in its native host dtype (InT = float or double), + * writes a float32 copy used as the sort input, and accumulates per-group + * sum, sum-of-squares and nonzero counts in float64. Stats are derived + * from the original-precision values so float64 host input keeps its + * precision while the sort still runs on float32 keys. + * + * Block-per-column layout (grid: (sb_cols,), block: (tpb,)). + * Shared memory: 3 * n_groups doubles (s_sum, s_sq, s_nnz). + */ +template +__global__ void ovr_cast_and_accumulate_dense_kernel( + const InT* __restrict__ block_in, float* __restrict__ block_f32_out, + const int* __restrict__ group_codes, double* __restrict__ group_sums, + double* __restrict__ group_sq_sums, double* __restrict__ group_nnz, + int n_rows, int sb_cols, int n_groups, bool compute_sq_sums = true, + bool compute_nnz = true) { + int col = blockIdx.x; + if (col >= sb_cols) return; + + extern __shared__ double smem[]; + double* s_sum = smem; + double* s_sq = smem + n_groups; + double* s_nnz = smem + 2 * n_groups; + + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { + s_sum[g] = 0.0; + if (compute_sq_sums) s_sq[g] = 0.0; + if (compute_nnz) s_nnz[g] = 0.0; + } + __syncthreads(); + + const InT* src = block_in + (size_t)col * n_rows; + float* dst = block_f32_out + (size_t)col * n_rows; + + for (int r = threadIdx.x; r < n_rows; r += blockDim.x) { + InT v_in = src[r]; + double v = (double)v_in; + dst[r] = (float)v_in; + int g = group_codes[r]; + if (g < n_groups) { + atomicAdd(&s_sum[g], v); + if (compute_sq_sums) atomicAdd(&s_sq[g], v * v); + if (compute_nnz && v != 0.0) atomicAdd(&s_nnz[g], 1.0); + } + } + __syncthreads(); + + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { + group_sums[(size_t)g * sb_cols + col] = s_sum[g]; + if (compute_sq_sums) { + group_sq_sums[(size_t)g * sb_cols + col] = s_sq[g]; + } + if (compute_nnz) { + group_nnz[(size_t)g * sb_cols + col] = s_nnz[g]; + } + } +} + +template +__global__ void ovr_cast_and_accumulate_dense_global_kernel( + const InT* __restrict__ block_in, float* __restrict__ block_f32_out, + const int* __restrict__ group_codes, double* __restrict__ group_sums, + double* __restrict__ group_sq_sums, double* __restrict__ group_nnz, + int n_rows, int sb_cols, int n_groups, bool compute_sq_sums = true, + bool compute_nnz = true) { + int col = blockIdx.x; + if (col >= sb_cols) return; + + const InT* src = block_in + (size_t)col * n_rows; + float* dst = block_f32_out + (size_t)col * n_rows; + + for (int r = threadIdx.x; r < n_rows; r += blockDim.x) { + InT v_in = src[r]; + double v = (double)v_in; + dst[r] = (float)v_in; + int g = group_codes[r]; + if (g < n_groups) { + atomicAdd(&group_sums[(size_t)g * sb_cols + col], v); + if (compute_sq_sums) { + 
atomicAdd(&group_sq_sums[(size_t)g * sb_cols + col], v * v); + } + if (compute_nnz && v != 0.0) { + atomicAdd(&group_nnz[(size_t)g * sb_cols + col], 1.0); + } + } + } +} + +/** + * Pre-sort cast-and-accumulate kernel for sparse OVR host streaming. + * + * Sub-batch CSC data is laid out contiguously: values for column c live + * at positions [col_seg_offsets[c], col_seg_offsets[c+1]). For each + * stored value, read the native-dtype InT, write a float32 copy for the + * CUB sort, and accumulate per-group sum/sum-sq/nnz in float64. Implicit + * zeros contribute nothing to any of these stats. + * + * Block-per-column layout (grid: (sb_cols,), block: (tpb,)). + * Shared memory: 3 * n_groups doubles. + */ +template +__global__ void ovr_cast_and_accumulate_sparse_kernel( + const InT* __restrict__ data_in, float* __restrict__ data_f32_out, + const IndexT* __restrict__ indices, const int* __restrict__ col_seg_offsets, + const int* __restrict__ group_codes, double* __restrict__ group_sums, + double* __restrict__ group_sq_sums, double* __restrict__ group_nnz, + int sb_cols, int n_groups, bool compute_sq_sums = true, + bool compute_nnz = true) { + int col = blockIdx.x; + if (col >= sb_cols) return; + + int seg_start = col_seg_offsets[col]; + int seg_end = col_seg_offsets[col + 1]; + + extern __shared__ double smem[]; + double* s_sum = smem; + double* s_sq = smem + n_groups; + double* s_nnz = smem + 2 * n_groups; + + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { + s_sum[g] = 0.0; + if (compute_sq_sums) s_sq[g] = 0.0; + if (compute_nnz) s_nnz[g] = 0.0; + } + __syncthreads(); + + for (int i = seg_start + threadIdx.x; i < seg_end; i += blockDim.x) { + InT v_in = data_in[i]; + double v = (double)v_in; + data_f32_out[i] = (float)v_in; + int row = (int)indices[i]; + int g = group_codes[row]; + if (g < n_groups) { + atomicAdd(&s_sum[g], v); + if (compute_sq_sums) atomicAdd(&s_sq[g], v * v); + if (compute_nnz && v != 0.0) atomicAdd(&s_nnz[g], 1.0); + } + } + __syncthreads(); + + for (int g = threadIdx.x; g < n_groups; g += blockDim.x) { + group_sums[(size_t)g * sb_cols + col] = s_sum[g]; + if (compute_sq_sums) { + group_sq_sums[(size_t)g * sb_cols + col] = s_sq[g]; + } + if (compute_nnz) { + group_nnz[(size_t)g * sb_cols + col] = s_nnz[g]; + } + } +} + +template +__global__ void ovr_cast_and_accumulate_sparse_global_kernel( + const InT* __restrict__ data_in, float* __restrict__ data_f32_out, + const IndexT* __restrict__ indices, const int* __restrict__ col_seg_offsets, + const int* __restrict__ group_codes, double* __restrict__ group_sums, + double* __restrict__ group_sq_sums, double* __restrict__ group_nnz, + int sb_cols, int n_groups, bool compute_sq_sums = true, + bool compute_nnz = true) { + int col = blockIdx.x; + if (col >= sb_cols) return; + + int seg_start = col_seg_offsets[col]; + int seg_end = col_seg_offsets[col + 1]; + + for (int i = seg_start + threadIdx.x; i < seg_end; i += blockDim.x) { + InT v_in = data_in[i]; + double v = (double)v_in; + data_f32_out[i] = (float)v_in; + int row = (int)indices[i]; + int g = group_codes[row]; + if (g < n_groups) { + atomicAdd(&group_sums[(size_t)g * sb_cols + col], v); + if (compute_sq_sums) { + atomicAdd(&group_sq_sums[(size_t)g * sb_cols + col], v * v); + } + if (compute_nnz && v != 0.0) { + atomicAdd(&group_nnz[(size_t)g * sb_cols + col], 1.0); + } + } + } +} + +template +static void launch_ovr_cast_and_accumulate_dense( + const InT* d_block_orig, float* d_block_f32, const int* d_group_codes, + double* d_group_sums, double* d_group_sq_sums, 
double* d_group_nnz, + int n_rows, int sb_cols, int n_groups, bool compute_sq_sums, + bool compute_nnz, int tpb, size_t smem_cast, bool use_gmem, + cudaStream_t stream) { + if (use_gmem) { + size_t stats_items = (size_t)n_groups * sb_cols; + cudaMemsetAsync(d_group_sums, 0, stats_items * sizeof(double), stream); + if (compute_sq_sums) { + cudaMemsetAsync(d_group_sq_sums, 0, stats_items * sizeof(double), + stream); + } + if (compute_nnz) { + cudaMemsetAsync(d_group_nnz, 0, stats_items * sizeof(double), + stream); + } + ovr_cast_and_accumulate_dense_global_kernel + <<>>( + d_block_orig, d_block_f32, d_group_codes, d_group_sums, + d_group_sq_sums, d_group_nnz, n_rows, sb_cols, n_groups, + compute_sq_sums, compute_nnz); + CUDA_CHECK_LAST_ERROR(ovr_cast_and_accumulate_dense_global_kernel); + } else { + ovr_cast_and_accumulate_dense_kernel + <<>>( + d_block_orig, d_block_f32, d_group_codes, d_group_sums, + d_group_sq_sums, d_group_nnz, n_rows, sb_cols, n_groups, + compute_sq_sums, compute_nnz); + CUDA_CHECK_LAST_ERROR(ovr_cast_and_accumulate_dense_kernel); + } +} + +template +static void launch_ovr_cast_and_accumulate_sparse( + const InT* d_data_orig, float* d_data_f32, const IndexT* d_indices, + const int* d_col_offsets, const int* d_group_codes, double* d_group_sums, + double* d_group_sq_sums, double* d_group_nnz, int sb_cols, int n_groups, + bool compute_sq_sums, bool compute_nnz, int tpb, size_t smem_cast, + bool use_gmem, cudaStream_t stream) { + if (use_gmem) { + size_t stats_items = (size_t)n_groups * sb_cols; + cudaMemsetAsync(d_group_sums, 0, stats_items * sizeof(double), stream); + if (compute_sq_sums) { + cudaMemsetAsync(d_group_sq_sums, 0, stats_items * sizeof(double), + stream); + } + if (compute_nnz) { + cudaMemsetAsync(d_group_nnz, 0, stats_items * sizeof(double), + stream); + } + ovr_cast_and_accumulate_sparse_global_kernel + <<>>( + d_data_orig, d_data_f32, d_indices, d_col_offsets, + d_group_codes, d_group_sums, d_group_sq_sums, d_group_nnz, + sb_cols, n_groups, compute_sq_sums, compute_nnz); + CUDA_CHECK_LAST_ERROR(ovr_cast_and_accumulate_sparse_global_kernel); + } else { + ovr_cast_and_accumulate_sparse_kernel + <<>>( + d_data_orig, d_data_f32, d_indices, d_col_offsets, + d_group_codes, d_group_sums, d_group_sq_sums, d_group_nnz, + sb_cols, n_groups, compute_sq_sums, compute_nnz); + CUDA_CHECK_LAST_ERROR(ovr_cast_and_accumulate_sparse_kernel); + } +} diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py b/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py index 0b9753a3..a204d73e 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/__init__.py @@ -21,6 +21,33 @@ ] +def _array_result_to_records( + arrays: dict[str, object], field: str, dtype: str | np.dtype +) -> np.ndarray: + group_names = tuple(str(name) for name in arrays["group_names"]) + values = np.asarray(arrays[field]) + out = np.empty( + values.shape[1], + dtype=[(group_name, np.dtype(dtype)) for group_name in group_names], + ) + for row, group_name in enumerate(group_names): + out[group_name] = values[row] + return out + + +def _array_result_to_names(arrays: dict[str, object]) -> np.ndarray: + group_names = tuple(str(name) for name in arrays["group_names"]) + var_names = np.asarray(arrays["var_names"]) + gene_indices = np.asarray(arrays["gene_indices"], dtype=np.intp) + out = np.empty( + gene_indices.shape[1], + dtype=[(group_name, object) for group_name in group_names], + ) + for row, group_name in 
enumerate(group_names): + out[group_name] = var_names[gene_indices[row]] + return out + + def rank_genes_groups( adata: AnnData, groupby: str, @@ -37,17 +64,21 @@ def rank_genes_groups( corr_method: _CorrMethod = "benjamini-hochberg", tie_correct: bool = False, use_continuity: bool = False, + return_u_values: bool = False, layer: str | None = None, chunk_size: int | None = None, pre_load: bool = False, n_bins: int | None = None, bin_range: Literal["log1p", "auto"] | None = None, + skip_empty_groups: bool = False, **kwds, ) -> None: """ Rank genes for characterizing groups using GPU acceleration. - Expects logarithmized data. + Expects nonnegative expression data. Log1p/log-normalized data is expected + for biologically meaningful log fold changes; negative values are rejected + for eager in-memory inputs. .. note:: **Dask support:** `'t-test'`, `'t-test_overestim_var'`, and @@ -101,6 +132,10 @@ def rank_genes_groups( z-scores. Subtracts 0.5 from ``|R - E[R]|`` before dividing by the standard deviation, matching :func:`scipy.stats.mannwhitneyu` default behavior. + return_u_values + For `'wilcoxon'`, store Mann-Whitney U statistics in `scores` instead + of z-scores. P-values are still computed from the z-score normal + approximation using the selected tie and continuity settings. layer Key from `adata.layers` whose value will be used to perform tests on. chunk_size @@ -119,15 +154,20 @@ def rank_genes_groups( ``None`` (default) uses ``'auto'`` for in-memory arrays and ``'log1p'`` for Dask arrays (to avoid a costly data scan). ``'log1p'`` uses a fixed [0, 15] range suitable for most log1p-normalized data. - ``'auto'`` computes the actual data range. Use this for z-scored - or unnormalized data. + ``'auto'`` computes the actual data range. Use this for nonnegative + expression data outside the fixed log1p range. + skip_empty_groups + Skip selected groups with fewer than two observations after filtering. + This is useful for perturbation workflows where a per-cell-type slice + keeps categories that are empty or singleton in that slice. **kwds Additional arguments passed to the method. For `'logreg'`, these are passed to :class:`cuml.linear_model.LogisticRegression`. Returns ------- - Updates `adata` with the following fields: + Updates `adata` with the following fields. Rank result fields are + Scanpy-compatible structured arrays. `adata.uns['rank_genes_groups' | key_added]['names']` Structured array to be indexed by group id storing the gene @@ -135,7 +175,8 @@ def rank_genes_groups( `adata.uns['rank_genes_groups' | key_added]['scores']` Structured array to be indexed by group id storing the z-score underlying the computation of a p-value for each gene for each - group. Ordered according to scores. + group, or the Mann-Whitney U statistic when + `return_u_values=True`. Ordered according to scores. `adata.uns['rank_genes_groups' | key_added]['logfoldchanges']` Structured array to be indexed by group id storing the log2 fold change for each gene for each group. @@ -154,6 +195,13 @@ def rank_genes_groups( msg = "corr_method must be either 'benjamini-hochberg' or 'bonferroni'." raise ValueError(msg) + if "return_format" in kwds: + msg = ( + "return_format has been removed; rank_genes_groups always writes " + "Scanpy-compatible structured results to adata.uns." 
+ ) + raise TypeError(msg) + if method is None: method = "t-test" @@ -170,6 +218,10 @@ def rank_genes_groups( ) raise ValueError(msg) + if return_u_values and method != "wilcoxon": + msg = "return_u_values is only supported for method='wilcoxon'." + raise ValueError(msg) + if key_added is None: key_added = "rank_genes_groups" @@ -197,6 +249,7 @@ def rank_genes_groups( layer=layer, comp_pts=pts, pre_load=pre_load, + skip_empty_groups=skip_empty_groups, ) # Determine n_genes_user @@ -211,25 +264,14 @@ def rank_genes_groups( rankby_abs=rankby_abs, tie_correct=tie_correct, use_continuity=use_continuity, + return_u_values=return_u_values, chunk_size=chunk_size, n_bins=n_bins, bin_range=bin_range, **kwds, ) - # Build output - test_obj.stats.columns = test_obj.stats.columns.swaplevel() - - dtypes = { - "names": "U50", - "scores": "float32", - "logfoldchanges": "float32", - "pvals": "float64", - "pvals_adj": "float64", - } - - adata.uns[key_added] = {} - adata.uns[key_added]["params"] = { + params = { "groupby": groupby, "reference": reference, "method": method, @@ -237,10 +279,22 @@ def rank_genes_groups( "layer": layer, "corr_method": corr_method, } + if method == "wilcoxon": + params["tie_correct"] = tie_correct + params["return_u_values"] = return_u_values + + arrays = test_obj.stats_arrays or {} + adata.uns[key_added] = {"params": params} + if arrays and len(arrays.get("group_names", ())) > 0: + adata.uns[key_added]["names"] = _array_result_to_names(arrays) + for col in ("scores", "logfoldchanges", "pvals", "pvals_adj"): + if col in arrays: + values = arrays[col] + dtype = values.dtype + adata.uns[key_added][col] = _array_result_to_records(arrays, col, dtype) - # Store pts results if computed + groups_names = [str(name) for name in test_obj.groups_order] if test_obj.pts is not None: - groups_names = [str(name) for name in test_obj.groups_order] adata.uns[key_added]["pts"] = pd.DataFrame( test_obj.pts.T, index=test_obj.var_names, columns=groups_names ) @@ -249,14 +303,7 @@ def rank_genes_groups( test_obj.pts_rest.T, index=test_obj.var_names, columns=groups_names ) - if method == "wilcoxon": - adata.uns[key_added]["params"]["tie_correct"] = tie_correct - - for col in test_obj.stats.columns.levels[0]: - if col in dtypes: - adata.uns[key_added][col] = test_obj.stats[col].to_records( - index=False, column_dtypes=dtypes[col] - ) + return None if TYPE_CHECKING: @@ -285,7 +332,7 @@ def rank_genes_groups_logreg( layer: str | None = None, **kwds, ) -> None: - rank_genes_groups( + return rank_genes_groups( adata, groupby, groups=groups, diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_core.py b/src/rapids_singlecell/tools/_rank_genes_groups/_core.py index 019e126b..7f062b21 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_core.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_core.py @@ -1,18 +1,77 @@ from __future__ import annotations +import os +from concurrent.futures import ThreadPoolExecutor from typing import TYPE_CHECKING, Literal, assert_never import cupy as cp import numpy as np import pandas as pd -from statsmodels.stats.multitest import multipletests from rapids_singlecell._compat import DaskArray from rapids_singlecell.get import X_to_GPU from rapids_singlecell.get._aggregated import Aggregate from rapids_singlecell.preprocessing._utils import _check_gpu_X -from ._utils import EPS, _select_groups, _select_top_n +from ._utils import EPS, _check_sparse_nonnegative, _select_groups + +_FDR_BH_REVERSE_CUMMIN_KERNEL = cp.RawKernel( + r""" +extern "C" __global__ void 
fdr_bh_reverse_cummin(double* values, const int n_cols) { + const int row = blockIdx.x; + double running = 1.0; + double* row_values = values + static_cast(row) * n_cols; + for (int col = n_cols - 1; col >= 0; --col) { + double value = row_values[col]; + if (!(value == value)) { + value = 1.0; + } + if (value < running) { + running = value; + } + row_values[col] = running; + } +} +""", + "fdr_bh_reverse_cummin", +) +_GROUP_CHUNK_STATS_KERNEL = cp.RawKernel( + r""" +extern "C" __global__ void group_chunk_stats( + const double* block, + const int* group_codes, + double* group_sums, + double* group_sum_sq, + double* group_nnz, + const int n_rows, + const int n_cols, + const int n_groups, + const bool compute_nnz +) { + const long long idx = blockIdx.x * blockDim.x + threadIdx.x; + const long long total = static_cast(n_rows) * n_cols; + if (idx >= total) { + return; + } + const int row = idx % n_rows; + const int col = idx / n_rows; + const int group = group_codes[row]; + if (group < 0 || group >= n_groups) { + return; + } + const double value = block[idx]; + const long long out = static_cast(group) * n_cols + col; + atomicAdd(group_sums + out, value); + atomicAdd(group_sum_sq + out, value * value); + if (compute_nnz && value != 0.0) { + atomicAdd(group_nnz + out, 1.0); + } +} +""", + "group_chunk_stats", +) +_RANK_SORT_MIN_ELEMENTS = 1_000_000 +_RANK_SORT_MAX_WORKERS = 64 if TYPE_CHECKING: from collections.abc import Iterable @@ -38,6 +97,7 @@ def __init__( layer: str | None = None, comp_pts: bool = False, pre_load: bool = False, + skip_empty_groups: bool = False, ) -> None: # Handle groups parameter if groups == "all" or groups is None: @@ -63,7 +123,10 @@ def __init__( raise ValueError(msg) self.groups_order, self.group_codes, self.group_sizes = _select_groups( - self.labels, selected + self.labels, + selected, + reference=reference, + skip_empty_groups=skip_empty_groups, ) # Get data matrix @@ -91,6 +154,8 @@ def __init__( self.X = self.X[:, mask_var] self.var_names = self.var_names[mask_var] + _check_sparse_nonnegative(self.X) + self.pre_load = pre_load self.ireference = None @@ -100,6 +165,7 @@ def __init__( # Set up expm1 function based on log base self.is_log1p = "log1p" in adata.uns base = adata.uns.get("log1p", {}).get("base") + self._log1p_base = base if base is not None: self.expm1_func = lambda x: np.expm1(x * np.log(base)) else: @@ -115,8 +181,14 @@ def __init__( self.pts_rest: np.ndarray | None = None self.stats: pd.DataFrame | None = None + self.stats_arrays: dict[str, object] | None = None + self._store_wilcoxon_gpu_result = False + self._wilcoxon_gpu_result: ( + tuple[np.ndarray, cp.ndarray, cp.ndarray, cp.ndarray | None] | None + ) = None self._compute_stats_in_chunks: bool = False self._ref_chunk_computed: set[int] = set() + self._score_dtype = np.dtype(np.float32) def _init_stats_arrays(self, n_genes: int) -> None: """Pre-allocate stats arrays before chunk loop.""" @@ -231,7 +303,7 @@ def _accumulate_chunk_stats_vs_rest( start: int, stop: int, *, - group_matrix: cp.ndarray, + group_codes_dev: cp.ndarray, group_sizes_dev: cp.ndarray, n_cells: int, ) -> None: @@ -241,9 +313,31 @@ def _accumulate_chunk_stats_vs_rest( rest_sizes = n_cells - group_sizes_dev - # Group sums and sum of squares - group_sums = group_matrix.T @ block - group_sum_sq = group_matrix.T @ (block**2) + n_groups = len(self.groups_order) + n_cols = stop - start + group_sums = cp.zeros((n_groups, n_cols), dtype=cp.float64) + group_sum_sq = cp.zeros((n_groups, n_cols), dtype=cp.float64) + group_nnz = ( + 
cp.zeros((n_groups, n_cols), dtype=cp.float64) if self.comp_pts else None + ) + n_items = n_cells * n_cols + threads = 256 + blocks = (n_items + threads - 1) // threads + _GROUP_CHUNK_STATS_KERNEL( + (blocks,), + (threads,), + ( + block, + group_codes_dev, + group_sums, + group_sum_sq, + group_nnz if group_nnz is not None else group_sums, + np.int32(n_cells), + np.int32(n_cols), + np.int32(n_groups), + self.comp_pts, + ), + ) # Means chunk_means = group_sums / group_sizes_dev[:, None] @@ -256,7 +350,6 @@ def _accumulate_chunk_stats_vs_rest( # Pts (fraction expressing) if self.comp_pts: - group_nnz = group_matrix.T @ (block != 0).astype(cp.float64) self.pts[:, start:stop] = cp.asnumpy(group_nnz / group_sizes_dev[:, None]) # Rest statistics @@ -337,6 +430,7 @@ def wilcoxon( tie_correct: bool, use_continuity: bool = False, chunk_size: int | None = None, + return_u_values: bool = False, ) -> list[tuple[int, NDArray, NDArray]]: """Compute Wilcoxon rank-sum test statistics.""" from ._wilcoxon import wilcoxon @@ -346,6 +440,7 @@ def wilcoxon( tie_correct=tie_correct, use_continuity=use_continuity, chunk_size=chunk_size, + return_u_values=return_u_values, ) def wilcoxon_binned( @@ -387,6 +482,7 @@ def compute_statistics( chunk_size: int | None = None, n_bins: int | None = None, bin_range: Literal["log1p", "auto"] | None = None, + return_u_values: bool = False, **kwds, ) -> None: """Compute statistics for all groups.""" @@ -397,17 +493,28 @@ def compute_statistics( }: self.X = X_to_GPU(self.X) + n_genes = self.X.shape[1] + if n_genes_user is None: + n_genes_user = n_genes + if method in {"t-test", "t-test_overestim_var"}: test_results = self.t_test(method) elif method == "wilcoxon": if isinstance(self.X, DaskArray): msg = "Wilcoxon test is not supported for Dask arrays. Please convert your data to CuPy arrays." 
raise ValueError(msg) - test_results = self.wilcoxon( - tie_correct=tie_correct, - use_continuity=use_continuity, - chunk_size=chunk_size, - ) + self._score_dtype = np.dtype(np.float64 if return_u_values else np.float32) + self._wilcoxon_gpu_result = None + self._store_wilcoxon_gpu_result = True + try: + test_results = self.wilcoxon( + tie_correct=tie_correct, + use_continuity=use_continuity, + chunk_size=chunk_size, + return_u_values=return_u_values, + ) + finally: + self._store_wilcoxon_gpu_result = False elif method == "wilcoxon_binned": test_results = self.wilcoxon_binned( tie_correct=tie_correct, @@ -421,58 +528,225 @@ def compute_statistics( else: assert_never(method) - n_genes = self.X.shape[1] + if not test_results and self._wilcoxon_gpu_result is None: + self.stats_arrays = { + "group_indices": np.empty(0, dtype=np.intp), + "group_names": np.empty(0, dtype=object), + "var_names": np.asarray(self.var_names), + "gene_indices": np.empty((0, n_genes_user), dtype=np.intp), + } + self.stats = None + return + + if self._wilcoxon_gpu_result is not None: + group_indices, scores_gpu, pvals_gpu, logfoldchanges_gpu = ( + self._wilcoxon_gpu_result + ) + try: + self._compute_statistics_gpu_arrays( + group_indices, + scores_gpu, + pvals_gpu, + logfoldchanges_gpu, + corr_method=corr_method, + n_genes_user=n_genes_user, + n_genes=n_genes, + rankby_abs=rankby_abs, + ) + finally: + self._wilcoxon_gpu_result = None + return - # Collect all stats data first to avoid DataFrame fragmentation - stats_data: dict[tuple[str, str], np.ndarray] = {} + self._compute_statistics_arrays( + test_results, + corr_method=corr_method, + n_genes_user=n_genes_user, + n_genes=n_genes, + rankby_abs=rankby_abs, + ) + + @staticmethod + def _rank_indices_matrix(scores: np.ndarray, n_top: int) -> np.ndarray: + if n_top >= scores.shape[1]: + return _RankGenes._argsort_desc_matrix(scores) + partition = np.argpartition(scores, -n_top, axis=1)[:, -n_top:] + row_ids = np.arange(scores.shape[0])[:, None] + order = np.argsort(scores[row_ids, partition], axis=1)[:, ::-1] + return partition[row_ids, order] + + @staticmethod + def _argsort_desc_matrix(scores: np.ndarray) -> np.ndarray: + n_rows, n_cols = scores.shape + n_elements = n_rows * n_cols + n_workers = min(_RANK_SORT_MAX_WORKERS, os.cpu_count() or 1, n_rows) + if n_workers <= 1 or n_elements < _RANK_SORT_MIN_ELEMENTS: + return np.argsort(scores, axis=1)[:, ::-1] + + chunks = np.linspace(0, n_rows, n_workers + 1, dtype=np.intp) + indices = np.empty((n_rows, n_cols), dtype=np.intp) + + def sort_chunk(chunk_index: int) -> None: + start = int(chunks[chunk_index]) + stop = int(chunks[chunk_index + 1]) + if start < stop: + indices[start:stop] = np.argsort(scores[start:stop], axis=1)[:, ::-1] + + with ThreadPoolExecutor(max_workers=n_workers) as executor: + list(executor.map(sort_chunk, range(n_workers))) + return indices + + @staticmethod + def _fdr_bh_matrix(pvals: np.ndarray) -> np.ndarray: + pvals_clean = np.array(pvals, copy=True) + pvals_clean[np.isnan(pvals_clean)] = 1.0 + order = np.argsort(pvals_clean, axis=1) + sorted_p = np.take_along_axis(pvals_clean, order, axis=1) + n_tests = sorted_p.shape[1] + scale = n_tests / np.arange(1, n_tests + 1, dtype=np.float64) + corrected_sorted = sorted_p * scale + corrected_sorted = np.minimum.accumulate(corrected_sorted[:, ::-1], axis=1)[ + :, ::-1 + ] + corrected_sorted[corrected_sorted > 1.0] = 1.0 + corrected = np.empty_like(corrected_sorted) + np.put_along_axis(corrected, order, corrected_sorted, axis=1) + return corrected + + 
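+    # Illustrative check (hypothetical values): for one row of p-values
+    # [0.01, 0.04, 0.03, 0.20], sorting gives [0.01, 0.03, 0.04, 0.20],
+    # scaling by n/rank gives [0.04, 0.06, 0.0533, 0.20], and the reverse
+    # cumulative minimum yields [0.04, 0.0533, 0.0533, 0.20] before being
+    # scattered back to the original p-value order; this should match
+    # statsmodels.stats.multitest.multipletests(method="fdr_bh") row by row.
+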
@staticmethod + def _fdr_bh_matrix_gpu(pvals: cp.ndarray) -> cp.ndarray: + pvals_clean = cp.nan_to_num(pvals, nan=1.0) + order = cp.argsort(pvals_clean, axis=1) + corrected_sorted = cp.take_along_axis(pvals_clean, order, axis=1) + corrected_sorted *= corrected_sorted.shape[1] / cp.arange( + 1, corrected_sorted.shape[1] + 1, dtype=cp.float64 + ) + _FDR_BH_REVERSE_CUMMIN_KERNEL( + (corrected_sorted.shape[0],), + (1,), + (corrected_sorted, np.int32(corrected_sorted.shape[1])), + ) + corrected = cp.empty_like(corrected_sorted) + cp.put_along_axis(corrected, order, corrected_sorted, axis=1) + return corrected - for group_index, scores, pvals in test_results: - group_name = str(self.groups_order[group_index]) + def _compute_statistics_arrays( + self, + test_results: list[tuple[int, NDArray, NDArray]], + *, + corr_method: _CorrMethod, + n_genes_user: int, + n_genes: int, + rankby_abs: bool, + ) -> None: + group_indices = np.asarray([r[0] for r in test_results], dtype=np.intp) + scores = np.vstack([r[1] for r in test_results]) + sort_scores = np.abs(scores) if rankby_abs else scores + top_idx = self._rank_indices_matrix(sort_scores, n_genes_user) + + arrays: dict[str, object] = { + "group_indices": group_indices, + "group_names": np.asarray( + [str(self.groups_order[i]) for i in group_indices], dtype=object + ), + "var_names": np.asarray(self.var_names), + "gene_indices": top_idx.astype(np.intp, copy=False), + "scores": np.take_along_axis(scores, top_idx, axis=1).astype( + self._score_dtype, copy=False + ), + } + + if test_results[0][2] is not None: + pvals = np.vstack([r[2] for r in test_results]) + arrays["pvals"] = np.take_along_axis(pvals, top_idx, axis=1) + if corr_method == "benjamini-hochberg": + pvals_adj = self._fdr_bh_matrix(pvals) + elif corr_method == "bonferroni": + pvals_adj = np.minimum(pvals * n_genes, 1.0) + else: + msg = f"Unsupported correction method: {corr_method!r}." 
+ raise ValueError(msg) + arrays["pvals_adj"] = np.take_along_axis(pvals_adj, top_idx, axis=1) - if n_genes_user is not None: - scores_sort = np.abs(scores) if rankby_abs else scores - global_indices = _select_top_n(scores_sort, n_genes_user) + if self.means is not None: + mean_group = self.means[group_indices] + if self.ireference is None: + mean_rest = self.means_rest[group_indices] else: - global_indices = slice(None) - - if n_genes_user is not None: - stats_data[group_name, "names"] = np.asarray(self.var_names)[ - global_indices - ] - - stats_data[group_name, "scores"] = scores[global_indices] - - if pvals is not None: - stats_data[group_name, "pvals"] = pvals[global_indices] - if corr_method == "benjamini-hochberg": - pvals_clean = np.array(pvals, copy=True) - pvals_clean[np.isnan(pvals_clean)] = 1.0 - _, pvals_adj, _, _ = multipletests( - pvals_clean, alpha=0.05, method="fdr_bh" - ) - elif corr_method == "bonferroni": - pvals_adj = np.minimum(pvals * n_genes, 1.0) - stats_data[group_name, "pvals_adj"] = pvals_adj[global_indices] - - # Compute logfoldchanges - if self.means is not None: - mean_group = self.means[group_index] - if self.ireference is None: - mean_rest = self.means_rest[group_index] - else: - mean_rest = self.means[self.ireference] - foldchanges = (self.expm1_func(mean_group) + EPS) / ( - self.expm1_func(mean_rest) + EPS + mean_rest = self.means[self.ireference][None, :] + foldchanges = (self.expm1_func(mean_group) + EPS) / ( + self.expm1_func(mean_rest) + EPS + ) + logfoldchanges = np.log2(foldchanges) + arrays["logfoldchanges"] = np.take_along_axis( + logfoldchanges, top_idx, axis=1 + ).astype(np.float32, copy=False) + + self.stats_arrays = arrays + self.stats = None + + def _compute_statistics_gpu_arrays( + self, + group_indices: np.ndarray, + scores_gpu: cp.ndarray, + pvals_gpu: cp.ndarray, + logfoldchanges_gpu: cp.ndarray | None, + *, + corr_method: _CorrMethod, + n_genes_user: int, + n_genes: int, + rankby_abs: bool, + ) -> None: + group_indices = np.asarray(group_indices, dtype=np.intp) + scores = cp.asnumpy(scores_gpu) + sort_scores = np.abs(scores) if rankby_abs else scores + top_idx = self._rank_indices_matrix(sort_scores, n_genes_user) + top_idx_gpu = cp.asarray(top_idx) + + arrays: dict[str, object] = { + "group_indices": group_indices, + "group_names": np.asarray( + [str(self.groups_order[i]) for i in group_indices], dtype=object + ), + "var_names": np.asarray(self.var_names), + "gene_indices": top_idx.astype(np.intp, copy=False), + "scores": cp.asnumpy( + cp.take_along_axis(scores_gpu, top_idx_gpu, axis=1).astype( + self._score_dtype, copy=False ) - stats_data[group_name, "logfoldchanges"] = np.log2( - foldchanges[global_indices] + ), + "pvals": cp.asnumpy(cp.take_along_axis(pvals_gpu, top_idx_gpu, axis=1)), + } + + if corr_method == "benjamini-hochberg": + pvals_adj_gpu = self._fdr_bh_matrix_gpu(pvals_gpu) + elif corr_method == "bonferroni": + pvals_adj_gpu = cp.minimum(pvals_gpu * n_genes, 1.0) + else: + msg = f"Unsupported correction method: {corr_method!r}." 
+ raise ValueError(msg) + arrays["pvals_adj"] = cp.asnumpy( + cp.take_along_axis(pvals_adj_gpu, top_idx_gpu, axis=1) + ) + + if logfoldchanges_gpu is not None: + arrays["logfoldchanges"] = cp.asnumpy( + cp.take_along_axis(logfoldchanges_gpu, top_idx_gpu, axis=1).astype( + cp.float32, copy=False ) + ) + elif self.means is not None: + mean_group = self.means[group_indices] + if self.ireference is None: + mean_rest = self.means_rest[group_indices] + else: + mean_rest = self.means[self.ireference][None, :] + foldchanges = (self.expm1_func(mean_group) + EPS) / ( + self.expm1_func(mean_rest) + EPS + ) + logfoldchanges = np.log2(foldchanges) + arrays["logfoldchanges"] = np.take_along_axis( + logfoldchanges, top_idx, axis=1 + ).astype(np.float32, copy=False) - # Create DataFrame all at once to avoid fragmentation - if stats_data: - self.stats = pd.DataFrame(stats_data) - self.stats.columns = pd.MultiIndex.from_tuples(self.stats.columns) - if n_genes_user is None: - self.stats.index = self.var_names - else: - self.stats = None + self.stats_arrays = arrays + self.stats = None diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py b/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py index c4f2c601..e9efbc50 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_utils.py @@ -16,11 +16,49 @@ EPS = 1e-9 WARP_SIZE = 32 MAX_THREADS_PER_BLOCK = 512 +MIN_GROUP_SIZE_WARNING = 25 + + +def _nonnegative_error(prefix: str) -> ValueError: + msg = ( + f"{prefix} contains negative values. rank_genes_groups expects " + "nonnegative expression values; use raw counts or log1p/log-normalized " + "expression, not scaled or centered data." + ) + return ValueError(msg) + + +def _check_sparse_nonnegative(X) -> None: + """Reject inputs with negative values where an eager check is cheap. + + Sparse rank_genes_groups code treats missing entries as true expression + zeros. Optimized sparse Wilcoxon paths may rank explicit nonzeros and add + implicit zeros analytically, which is only valid when explicit sparse + values are nonnegative expression values. + """ + dtype = None + if sp.issparse(X) or cpsp.issparse(X): + dtype = np.dtype(X.data.dtype) + elif isinstance(X, np.ndarray | cp.ndarray): + dtype = np.dtype(X.dtype) + if dtype is not None and dtype.kind == "c": + msg = "rank_genes_groups does not support complex expression values." + raise TypeError(msg) + + if sp.issparse(X): + if X.nnz > 0 and float(X.data.min()) < 0: + raise _nonnegative_error("Sparse input") + elif cpsp.issparse(X): + if X.nnz > 0 and float(X.data.min()) < 0: + raise _nonnegative_error("Sparse input") def _select_groups( labels: pd.Series, selected: list | None, + *, + reference: str = "rest", + skip_empty_groups: bool = False, ) -> tuple[NDArray, NDArray[np.int32], NDArray[np.int64]]: """Build integer group codes from a categorical Series. 
@@ -51,6 +89,29 @@ def _select_groups( cat_order = {str(c): i for i, c in enumerate(all_categories)} selected.sort(key=lambda x: cat_order.get(str(x), len(all_categories))) + if skip_empty_groups: + counts = { + str(name): int(count) for name, count in labels.value_counts().items() + } + valid_selected = [group for group in selected if counts.get(str(group), 0) >= 2] + if reference != "rest": + ref_matches = [group for group in selected if str(group) == str(reference)] + if ref_matches: + ref_group = ref_matches[0] + if ref_group not in valid_selected: + msg = ( + f"reference = {reference} has fewer than two samples after " + "filtering and cannot be used for rank_genes_groups." + ) + raise ValueError(msg) + selected = valid_selected + if len(selected) == 0: + msg = ( + "No groups with at least two samples remain after applying " + "skip_empty_groups=True." + ) + raise ValueError(msg) + n_groups = len(selected) groups_order = np.array(selected) @@ -76,7 +137,7 @@ def _select_groups( if invalid_groups: msg = ( f"Could not calculate statistics for groups {', '.join(invalid_groups)} " - "since they only contain one sample." + "since they contain fewer than two samples." ) raise ValueError(msg) @@ -88,20 +149,6 @@ def _round_up_to_warp(n: int) -> int: return min(MAX_THREADS_PER_BLOCK, ((n + WARP_SIZE - 1) // WARP_SIZE) * WARP_SIZE) -def _select_top_n(scores: NDArray, n_top: int) -> NDArray: - """Select indices of top n scores. - - Uses argpartition + argsort for O(n + k log k) complexity where k = n_top. - This is faster than full sorting when k << n. - """ - n_from = scores.shape[0] - reference_indices = np.arange(n_from, dtype=int) - partition = np.argpartition(scores, -n_top)[-n_top:] - partial_indices = np.argsort(scores[partition])[::-1] - global_indices = reference_indices[partition][partial_indices] - return global_indices - - def _choose_chunk_size(requested: int | None) -> int: """Choose chunk size for gene processing.""" if requested is not None: diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py index c14c760d..880da7e0 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon.py @@ -4,86 +4,402 @@ from typing import TYPE_CHECKING import cupy as cp +import cupyx.scipy.sparse as cpsp import cupyx.scipy.special as cupyx_special import numpy as np import scipy.sparse as sp from rapids_singlecell._cuda import _wilcoxon_cuda as _wc -from rapids_singlecell._utils._csr_to_csc import _fast_csr_to_csc +from rapids_singlecell._cuda import _wilcoxon_sparse_cuda as _wcs -from ._utils import _choose_chunk_size, _get_column_block +from ._utils import EPS, MIN_GROUP_SIZE_WARNING, _choose_chunk_size, _get_column_block if TYPE_CHECKING: from numpy.typing import NDArray from ._core import _RankGenes -MIN_GROUP_SIZE_WARNING = 25 +DEFAULT_WILCOXON_CHUNK_SIZE = 512 +OVR_HOST_CSC_SUB_BATCH = 512 +OVR_HOST_CSR_SUB_BATCH = 2048 +OVR_DEVICE_CSC_SUB_BATCH = 2048 +OVR_DEVICE_CSR_SUB_BATCH = 2048 +OVO_HOST_SPARSE_SUB_BATCH = 256 +OVO_DEVICE_SPARSE_SUB_BATCH = 128 +OVR_DENSE_SUB_BATCH = 64 +OVO_DENSE_TIERED_SUB_BATCH = 256 +DENSE_HOST_PRELOAD_MAX_GPU_FRACTION = 0.55 # leave headroom for rank buffers -def _average_ranks( - matrix: cp.ndarray, *, return_sorted: bool = False -) -> cp.ndarray | tuple[cp.ndarray, cp.ndarray]: - """ - Compute average ranks for each column using GPU kernel. 
+def _maybe_preload_host_dense(rg: _RankGenes) -> None: + """Preload moderate host-dense matrices to avoid repeated chunk transfers.""" + X = rg.X + if not isinstance(X, np.ndarray) or X.size == 0: + return - Uses scipy.stats.rankdata 'average' method: ties get the average - of the ranks they would span. + try: + _, total = cp.cuda.runtime.memGetInfo() + except cp.cuda.runtime.CUDARuntimeError: + return - Parameters - ---------- - matrix - Input matrix (n_rows, n_cols) - return_sorted - If True, also return sorted values (useful for tie correction) + if X.nbytes > total * DENSE_HOST_PRELOAD_MAX_GPU_FRACTION: + return - Returns - ------- - ranks or (ranks, sorted_vals) - """ - n_rows, n_cols = matrix.shape + registered = False + if X.flags.c_contiguous or X.flags.f_contiguous: + try: + cp.cuda.runtime.hostRegister(X.ctypes.data, X.nbytes, 0) + registered = True + except cp.cuda.runtime.CUDARuntimeError: + registered = False - # Sort each column - sorter = cp.argsort(matrix, axis=0) - sorted_vals = cp.take_along_axis(matrix, sorter, axis=0) + try: + X_gpu = cp.asarray(X) + cp.cuda.get_current_stream().synchronize() + except cp.cuda.memory.OutOfMemoryError: + cp.get_default_memory_pool().free_all_blocks() + return + except cp.cuda.runtime.CUDARuntimeError: + return + finally: + if registered: + try: + cp.cuda.runtime.hostUnregister(X.ctypes.data) + except cp.cuda.runtime.CUDARuntimeError: + pass + rg.X = X_gpu - # Ensure F-order for kernel (columns contiguous in memory) - sorted_vals = cp.asfortranarray(sorted_vals) - sorter = cp.asfortranarray(sorter.astype(cp.int32)) - stream = cp.cuda.get_current_stream().ptr - _wc.average_rank( - sorted_vals, sorter, matrix, n_rows=n_rows, n_cols=n_cols, stream=stream - ) +def _get_dense_column_block_f32(X, start: int, stop: int) -> cp.ndarray: + """Extract a dense column block as F-order float32 CuPy memory.""" + if isinstance(X, np.ndarray | cp.ndarray): + return cp.asarray(X[:, start:stop], dtype=cp.float32, order="F") + raise TypeError(f"Expected dense matrix, got {type(X)}") + + +def _extract_dense_rows_cols( + X, row_ids: np.ndarray, start: int, stop: int +) -> cp.ndarray: + """Extract a bounded row/column block as F-order CuPy dense memory.""" + if isinstance(X, np.ndarray): + return cp.asarray(X[row_ids, start:stop], order="F") + if isinstance(X, cp.ndarray): + rows = cp.asarray(row_ids, dtype=cp.int32) + return cp.asfortranarray(X[rows, start:stop]) + if isinstance(X, sp.spmatrix | sp.sparray): + return cp.asarray(X[row_ids][:, start:stop].toarray(), order="F") + if cpsp.issparse(X): + rows = cp.asarray(row_ids, dtype=cp.int32) + return cp.asfortranarray(X[rows][:, start:stop].toarray()) + raise TypeError(f"Unsupported matrix type: {type(X)}") + + +def _choose_wilcoxon_chunk_size(requested: int | None, n_genes: int) -> int: + if requested is not None: + return _choose_chunk_size(requested) + return min(DEFAULT_WILCOXON_CHUNK_SIZE, max(1, n_genes)) + + +def _fill_ovo_chunk_stats( + rg: _RankGenes, + ref_block: cp.ndarray, + grp_block: cp.ndarray, + *, + offsets: np.ndarray, + test_group_indices: list[int], + start: int, + stop: int, + group_sizes: NDArray, +) -> None: + if not rg._compute_stats_in_chunks: + return + + ireference = rg.ireference + n_ref = int(group_sizes[ireference]) + ref_mean = ref_block.mean(axis=0) + rg.means[ireference, start:stop] = cp.asnumpy(ref_mean) + if n_ref > 1: + rg.vars[ireference, start:stop] = cp.asnumpy(ref_block.var(axis=0, ddof=1)) + if rg.comp_pts: + ref_nnz = (ref_block != 0).sum(axis=0) + rg.pts[ireference, 
start:stop] = cp.asnumpy(ref_nnz / n_ref) + + for slot, group_index in enumerate(test_group_indices): + begin = int(offsets[slot]) + end = int(offsets[slot + 1]) + n_group = int(group_sizes[group_index]) + group_block = grp_block[begin:end] + group_mean = group_block.mean(axis=0) + rg.means[group_index, start:stop] = cp.asnumpy(group_mean) + if n_group > 1: + rg.vars[group_index, start:stop] = cp.asnumpy( + group_block.var(axis=0, ddof=1) + ) + if rg.comp_pts: + group_nnz = (group_block != 0).sum(axis=0) + rg.pts[group_index, start:stop] = cp.asnumpy(group_nnz / n_group) + + +def _fill_basic_stats_from_accumulators( + rg: _RankGenes, + group_sums: cp.ndarray, + group_sq_sums: cp.ndarray, + group_nnz: cp.ndarray, + group_sizes: np.ndarray, + *, + n_cells: int, + compute_vars: bool, + total_sums: cp.ndarray | None = None, + total_sq_sums: cp.ndarray | None = None, + total_nnz: cp.ndarray | None = None, +) -> None: + n = cp.asarray(group_sizes, dtype=cp.float64)[:, None] + means = group_sums / n + rg.means = cp.asnumpy(means) + if compute_vars: + group_ss = group_sq_sums - n * means**2 + rg.vars = cp.asnumpy(cp.maximum(group_ss / cp.maximum(n - 1, 1), 0)) + else: + rg.vars = np.zeros_like(rg.means) + rg.pts = cp.asnumpy(group_nnz / n) if rg.comp_pts else None - if return_sorted: - return matrix, sorted_vals - return matrix + n_rest = cp.float64(n_cells) - n + if total_sums is None: + total_sums = group_sums.sum(axis=0, keepdims=True) + rest_sums = total_sums - group_sums + rest_means = rest_sums / n_rest + rg.means_rest = cp.asnumpy(rest_means) + if compute_vars: + if total_sq_sums is None: + total_sq_sums = group_sq_sums.sum(axis=0, keepdims=True) + rest_ss = (total_sq_sums - group_sq_sums) - n_rest * rest_means**2 + rg.vars_rest = cp.asnumpy(cp.maximum(rest_ss / cp.maximum(n_rest - 1, 1), 0)) + else: + rg.vars_rest = np.zeros_like(rg.means_rest) + if rg.comp_pts: + if total_nnz is None: + total_nnz = group_nnz.sum(axis=0, keepdims=True) + rg.pts_rest = cp.asnumpy((total_nnz - group_nnz) / n_rest) + else: + rg.pts_rest = None + rg._compute_stats_in_chunks = False -def _tie_correction(sorted_vals: cp.ndarray) -> cp.ndarray: - """ - Compute tie correction factor for Wilcoxon test. 
+def _fill_ovo_stats_from_accumulators( + rg: _RankGenes, + group_sums_slots: cp.ndarray, + group_sq_sums_slots: cp.ndarray, + group_nnz_slots: cp.ndarray, + *, + group_sizes: NDArray, + test_group_indices: list[int], + n_ref: int, + compute_vars: bool, +) -> None: + n_test = len(test_group_indices) + n_genes = int(group_sums_slots.shape[1]) + n_groups = len(rg.groups_order) + slot_group_indices = np.empty(n_test + 1, dtype=np.intp) + slot_group_indices[:n_test] = np.asarray(test_group_indices, dtype=np.intp) + slot_group_indices[n_test] = rg.ireference + slot_sizes = np.empty(n_test + 1, dtype=np.float64) + slot_sizes[:n_test] = group_sizes[slot_group_indices[:n_test]] + slot_sizes[n_test] = n_ref + slot_sizes_dev = cp.asarray(slot_sizes, dtype=cp.float64)[:, None] + + rg.means = np.zeros((n_groups, n_genes), dtype=np.float64) + rg.vars = np.zeros((n_groups, n_genes), dtype=np.float64) + rg.pts = np.zeros((n_groups, n_genes), dtype=np.float64) if rg.comp_pts else None + + means_slots = group_sums_slots / slot_sizes_dev + rg.means[slot_group_indices] = cp.asnumpy(means_slots) + if compute_vars: + group_ss = group_sq_sums_slots - slot_sizes_dev * means_slots**2 + denom = cp.maximum(slot_sizes_dev - 1.0, 1.0) + rg.vars[slot_group_indices] = cp.asnumpy(cp.maximum(group_ss / denom, 0)) + if rg.comp_pts: + rg.pts[slot_group_indices] = cp.asnumpy(group_nnz_slots / slot_sizes_dev) - Takes pre-sorted values (column-wise) to avoid re-sorting. - Formula: tc = 1 - sum(t^3 - t) / (n^3 - n) - where t is the count of tied values. - """ - n_rows, n_cols = sorted_vals.shape - correction = cp.ones(n_cols, dtype=cp.float64) + rg.means_rest = None + rg.vars_rest = None + rg.pts_rest = None + rg._compute_stats_in_chunks = False - if n_rows < 2: - return correction - # Ensure F-order - sorted_vals = cp.asfortranarray(sorted_vals) +def _ovo_logfoldchanges_from_sums( + rg: _RankGenes, + group_sums_slots: cp.ndarray, + test_sizes: cp.ndarray, + n_ref: int, +) -> cp.ndarray: + n_test = int(test_sizes.shape[0]) + mean_group = group_sums_slots[:n_test] / test_sizes[:, None] + mean_ref = group_sums_slots[n_test][None, :] / cp.float64(n_ref) + if rg._log1p_base is not None: + scale = cp.float64(np.log(rg._log1p_base)) + group_expr = cp.expm1(mean_group * scale) + ref_expr = cp.expm1(mean_ref * scale) + else: + group_expr = cp.expm1(mean_group) + ref_expr = cp.expm1(mean_ref) + return cp.log2((group_expr + EPS) / (ref_expr + EPS)) + + +def _wilcoxon_scores( + rank_sums: cp.ndarray, + group_sizes: cp.ndarray, + z_scores: cp.ndarray, + *, + return_u_values: bool, +) -> cp.ndarray: + if not return_u_values: + return z_scores + n_group = group_sizes[:, None] + return rank_sums - n_group * (n_group + 1.0) / 2.0 + + +def _host_sparse_fn_and_arrays(module, base_name: str, X, *, support_idx64: bool): + data_dtype = np.dtype(X.data.dtype) + if data_dtype == np.float64: + is_f64 = True + data_arr = X.data + elif data_dtype == np.float32 or data_dtype.kind in {"b", "i", "u"}: + is_f64 = False + data_arr = X.data.astype(np.float32, copy=False) + else: + msg = ( + "Wilcoxon sparse input data dtype must be float32, float64, bool, " + f"or integer; got {data_dtype}." 
+ ) + raise TypeError(msg) + + is_idx64 = support_idx64 and X.indices.dtype == np.int64 + is_i64 = X.indptr.dtype == np.int64 + suffix = "" + if is_f64: + suffix += "_f64" + if is_idx64: + suffix += "_idx64" + if is_i64: + suffix += "_i64" + fn = getattr(module, base_name + suffix) + indices_arr = X.indices if is_idx64 else X.indices.astype(np.int32, copy=False) + return fn, data_arr, indices_arr + + +def _device_sparse_arrays_i32_f32(X): + data_dtype = np.dtype(X.data.dtype) + if data_dtype == np.float32 or data_dtype == np.float64: + pass + elif data_dtype.kind in {"b", "i", "u"}: + pass + else: + msg = ( + "Wilcoxon device sparse input data dtype must be float32, float64, " + f"bool, or integer; got {data_dtype}." + ) + raise TypeError(msg) + + if X.indptr.dtype != cp.int32: + max_indptr = int(cp.asnumpy(X.indptr[-1])) + if max_indptr > np.iinfo(np.int32).max: + warnings.warn( + "Wilcoxon device sparse path requires int32 indptr for CUDA " + "kernels; falling back to the bounded dense chunk path because " + f"nnz={max_indptr} exceeds int32.", + RuntimeWarning, + stacklevel=3, + ) + return None + data = X.data.astype(cp.float32, copy=False) + indices = X.indices.astype(cp.int32, copy=False) + indptr = X.indptr.astype(cp.int32, copy=False) + return data, indices, indptr - stream = cp.cuda.get_current_stream().ptr - _wc.tie_correction( - sorted_vals, correction, n_rows=n_rows, n_cols=n_cols, stream=stream + +def _column_totals_for_host_matrix( + X, *, compute_sq_sums: bool, compute_nnz: bool +) -> tuple[cp.ndarray, cp.ndarray | None, cp.ndarray | None]: + n_cols = X.shape[1] + if isinstance(X, sp.spmatrix | sp.sparray): + data = np.asarray(X.data) + values = data.astype(np.float64, copy=False) + if X.format == "csc": + indptr = np.asarray(X.indptr) + counts = np.diff(indptr) + nonempty = counts > 0 + starts = indptr[:-1][nonempty] + sums = np.zeros(n_cols, dtype=np.float64) + if starts.size: + sums[nonempty] = np.add.reduceat(values, starts) + sq_sums = None + if compute_sq_sums: + sq_sums = np.zeros(n_cols, dtype=np.float64) + if starts.size: + sq_sums[nonempty] = np.add.reduceat(values * values, starts) + nnz = None + if compute_nnz: + nnz = np.zeros(n_cols, dtype=np.float64) + if starts.size: + nnz[nonempty] = np.add.reduceat( + (data != 0).astype(np.float64, copy=False), starts + ) + elif X.format == "csr": + indices = np.asarray(X.indices, dtype=np.intp) + sums = np.bincount(indices, weights=values, minlength=n_cols).astype( + np.float64, copy=False + ) + sq_sums = ( + np.bincount(indices, weights=values * values, minlength=n_cols).astype( + np.float64, copy=False + ) + if compute_sq_sums + else None + ) + nnz = ( + np.bincount( + indices, + weights=(data != 0).astype(np.float64, copy=False), + minlength=n_cols, + ).astype(np.float64, copy=False) + if compute_nnz + else None + ) + else: + raise TypeError( + "Wilcoxon sparse input must be CSR or CSC; refusing hidden " + f"full-matrix conversion from {X.format!r}." 
+ ) + else: + raise TypeError(f"Unsupported host matrix type: {type(X)}") + + total_sums = cp.asarray(sums.reshape(1, n_cols), dtype=cp.float64) + total_sq_sums = ( + cp.asarray(sq_sums.reshape(1, n_cols), dtype=cp.float64) + if sq_sums is not None + else None + ) + total_nnz = ( + cp.asarray(nnz.reshape(1, n_cols), dtype=cp.float64) + if nnz is not None + else None ) + return total_sums, total_sq_sums, total_nnz - return correction + +def _host_ovr_totals_if_needed( + X, + group_codes: np.ndarray, + n_groups: int, + *, + compute_sq_sums: bool, + compute_nnz: bool, +) -> tuple[cp.ndarray | None, cp.ndarray | None, cp.ndarray | None]: + if not np.any(group_codes == n_groups): + return None, None, None + return _column_totals_for_host_matrix( + X, compute_sq_sums=compute_sq_sums, compute_nnz=compute_nnz + ) def wilcoxon( @@ -92,8 +408,10 @@ def wilcoxon( tie_correct: bool, use_continuity: bool = False, chunk_size: int | None = None, + return_u_values: bool = False, ) -> list[tuple[int, NDArray, NDArray]]: """Compute Wilcoxon rank-sum test statistics.""" + _maybe_preload_host_dense(rg) # Compute basic stats - uses Aggregate if on GPU, else defers to chunks rg._basic_stats() X = rg.X @@ -110,6 +428,7 @@ def wilcoxon( tie_correct=tie_correct, use_continuity=use_continuity, chunk_size=chunk_size, + return_u_values=return_u_values, ) # Compare each group against "rest" (all other cells) return _wilcoxon_vs_rest( @@ -121,6 +440,7 @@ def wilcoxon( tie_correct=tie_correct, use_continuity=use_continuity, chunk_size=chunk_size, + return_u_values=return_u_values, ) @@ -134,6 +454,7 @@ def _wilcoxon_vs_rest( tie_correct: bool, use_continuity: bool, chunk_size: int | None, + return_u_values: bool, ) -> list[tuple[int, NDArray, NDArray]]: """Wilcoxon test: each group vs rest of cells.""" n_groups = len(rg.groups_order) @@ -149,50 +470,233 @@ def _wilcoxon_vs_rest( stacklevel=4, ) - # Build one-hot indicator matrix from group codes - codes_gpu = cp.asarray(rg.group_codes, dtype=cp.int64) - group_matrix = cp.zeros((n_cells, n_groups), dtype=cp.float64) - valid_idx = cp.where(codes_gpu < n_groups)[0] - group_matrix[valid_idx, codes_gpu[valid_idx]] = 1.0 + host_sparse = isinstance(X, sp.spmatrix | sp.sparray) + if host_sparse: + if X.format not in {"csr", "csc"}: + raise TypeError( + "Wilcoxon sparse input must be CSR or CSC; refusing hidden " + f"full-matrix conversion from {X.format!r}." 
+ ) + + group_codes = rg.group_codes.astype(np.int32, copy=False) + group_sizes_np = group_sizes.astype(np.float64, copy=False) + group_sizes_dev = cp.asarray(group_sizes_np, dtype=cp.float64) + rest_sizes = n_cells - group_sizes_dev + compute_vars = False + compute_nnz = rg.comp_pts + + rank_sums = cp.empty((n_groups, n_total_genes), dtype=cp.float64) + tie_corr = cp.ones(n_total_genes, dtype=cp.float64) + group_sums = cp.empty((n_groups, n_total_genes), dtype=cp.float64) + group_sq_sums = cp.empty( + (n_groups, n_total_genes) if compute_vars else (1, 1), + dtype=cp.float64, + ) + group_nnz = cp.empty( + (n_groups, n_total_genes) if compute_nnz else (1, 1), + dtype=cp.float64, + ) + + if X.format == "csc": + csc = X + if not csc.has_sorted_indices: + csc = csc.copy() + csc.sort_indices() + csc_host_fn, data_arr, indices_arr = _host_sparse_fn_and_arrays( + _wcs, "ovr_sparse_csc_host", csc, support_idx64=True + ) + csc_host_fn( + data_arr, + indices_arr, + csc.indptr, + group_codes, + group_sizes_np, + rank_sums, + tie_corr, + group_sums, + group_sq_sums, + group_nnz, + n_rows=n_cells, + n_cols=n_total_genes, + n_groups=n_groups, + compute_tie_corr=tie_correct, + compute_sq_sums=compute_vars, + compute_nnz=compute_nnz, + sub_batch_cols=OVR_HOST_CSC_SUB_BATCH, + ) + else: + csr = X + if not csr.has_sorted_indices: + csr = csr.copy() + csr.sort_indices() + csr_host_fn, data_arr, indices_arr = _host_sparse_fn_and_arrays( + _wcs, "ovr_sparse_csr_host", csr, support_idx64=True + ) + csr_host_fn( + data_arr, + indices_arr, + csr.indptr, + group_codes, + group_sizes_np, + rank_sums, + tie_corr, + group_sums, + group_sq_sums, + group_nnz, + n_rows=n_cells, + n_cols=n_total_genes, + n_groups=n_groups, + compute_tie_corr=tie_correct, + compute_sq_sums=compute_vars, + compute_nnz=compute_nnz, + sub_batch_cols=OVR_HOST_CSR_SUB_BATCH, + ) + + if rg._compute_stats_in_chunks: + total_sums, total_sq_sums, total_nnz = _host_ovr_totals_if_needed( + X, + group_codes, + n_groups, + compute_sq_sums=compute_vars, + compute_nnz=compute_nnz, + ) + _fill_basic_stats_from_accumulators( + rg, + group_sums, + group_sq_sums, + group_nnz, + group_sizes_np, + n_cells=n_cells, + compute_vars=compute_vars, + total_sums=total_sums, + total_sq_sums=total_sq_sums, + total_nnz=total_nnz, + ) + + expected = group_sizes_dev[:, None] * (n_cells + 1) / 2.0 + variance = tie_corr[None, :] * group_sizes_dev[:, None] * rest_sizes[:, None] + variance *= (n_cells + 1) / 12.0 + diff = rank_sums - expected + if use_continuity: + diff = cp.sign(diff) * cp.maximum(cp.abs(diff) - 0.5, 0.0) + z = diff / cp.sqrt(variance) + cp.nan_to_num(z, copy=False) + p_values = cupyx_special.erfc(cp.abs(z) * cp.float64(cp.sqrt(0.5))) + scores_host = _wilcoxon_scores( + rank_sums, group_sizes_dev, z, return_u_values=return_u_values + ).get() + p_host = p_values.get() + return [(gi, scores_host[gi], p_host[gi]) for gi in range(n_groups)] + + if cpsp.isspmatrix_csc(X) or cpsp.isspmatrix_csr(X): + sparse_arrays = _device_sparse_arrays_i32_f32(X) + if sparse_arrays is not None: + data, indices, indptr = sparse_arrays + group_codes_gpu = cp.asarray(rg.group_codes, dtype=cp.int32) + group_sizes_dev = cp.asarray(group_sizes, dtype=cp.float64) + rest_sizes = n_cells - group_sizes_dev + rank_sums = cp.empty((n_groups, n_total_genes), dtype=cp.float64) + tie_corr = cp.ones(n_total_genes, dtype=cp.float64) + if cpsp.isspmatrix_csc(X): + _wcs.ovr_sparse_csc_device( + data, + indices, + indptr, + group_codes_gpu, + group_sizes_dev, + rank_sums, + tie_corr, + 
n_rows=n_cells, + n_cols=n_total_genes, + n_groups=n_groups, + compute_tie_corr=tie_correct, + sub_batch_cols=OVR_DEVICE_CSC_SUB_BATCH, + ) + else: + sparse_X = X + if not sparse_X.has_sorted_indices: + sparse_X = sparse_X.copy() + sparse_X.sort_indices() + data, indices, indptr = _device_sparse_arrays_i32_f32(sparse_X) + _wcs.ovr_sparse_csr_device( + data, + indices, + indptr, + group_codes_gpu, + group_sizes_dev, + rank_sums, + tie_corr, + n_rows=n_cells, + n_cols=n_total_genes, + n_groups=n_groups, + compute_tie_corr=tie_correct, + sub_batch_cols=OVR_DEVICE_CSR_SUB_BATCH, + ) + + expected = group_sizes_dev[:, None] * (n_cells + 1) / 2.0 + variance = ( + tie_corr[None, :] * group_sizes_dev[:, None] * rest_sizes[:, None] + ) + variance *= (n_cells + 1) / 12.0 + diff = rank_sums - expected + if use_continuity: + diff = cp.sign(diff) * cp.maximum(cp.abs(diff) - 0.5, 0.0) + z = diff / cp.sqrt(variance) + cp.nan_to_num(z, copy=False) + p_values = cupyx_special.erfc(cp.abs(z) * cp.float64(cp.sqrt(0.5))) + scores_host = _wilcoxon_scores( + rank_sums, group_sizes_dev, z, return_u_values=return_u_values + ).get() + p_host = p_values.get() + return [(gi, scores_host[gi], p_host[gi]) for gi in range(n_groups)] + + group_codes_gpu = cp.asarray(rg.group_codes, dtype=cp.int32) group_sizes_dev = cp.asarray(group_sizes, dtype=cp.float64) rest_sizes = n_cells - group_sizes_dev - chunk_width = _choose_chunk_size(chunk_size) + chunk_width = _choose_wilcoxon_chunk_size(chunk_size, n_total_genes) # Accumulate results per group all_scores: dict[int, list] = {i: [] for i in range(n_groups)} all_pvals: dict[int, list] = {i: [] for i in range(n_groups)} - # One-time CSR->CSC via fast parallel Numba kernel; _get_column_block - # then uses direct indptr pointer copy for each chunk. 
- if isinstance(X, sp.spmatrix | sp.sparray): - X = _fast_csr_to_csc(X) if X.format == "csr" else X.tocsc() - for start in range(0, n_total_genes, chunk_width): stop = min(start + chunk_width, n_total_genes) - # Slice and convert to dense GPU array (F-order for column ops) - block = _get_column_block(X, start, stop) - - # Accumulate stats for this chunk - rg._accumulate_chunk_stats_vs_rest( - block, - start, - stop, - group_matrix=group_matrix, - group_sizes_dev=group_sizes_dev, - n_cells=n_cells, - ) - - if tie_correct: - ranks, sorted_vals = _average_ranks(block, return_sorted=True) - tie_corr = _tie_correction(sorted_vals) + if rg._compute_stats_in_chunks: + block = _get_column_block(X, start, stop) + rg._accumulate_chunk_stats_vs_rest( + block, + start, + stop, + group_codes_dev=group_codes_gpu, + group_sizes_dev=group_sizes_dev, + n_cells=n_cells, + ) + block_f32 = cp.asfortranarray(block.astype(cp.float32, copy=False)) else: - ranks = _average_ranks(block) - tie_corr = cp.ones(ranks.shape[1], dtype=cp.float64) + block_f32 = _get_dense_column_block_f32(X, start, stop) - rank_sums = group_matrix.T @ ranks + n_cols = stop - start + rank_sums = cp.empty((n_groups, n_cols), dtype=cp.float64) + tie_corr = ( + cp.empty(n_cols, dtype=cp.float64) + if tie_correct + else cp.ones(n_cols, dtype=cp.float64) + ) + _wc.ovr_rank_dense_streaming( + block_f32, + group_codes_gpu, + rank_sums, + tie_corr, + n_rows=n_cells, + n_cols=n_cols, + n_groups=n_groups, + compute_tie_corr=tie_correct, + sub_batch_cols=OVR_DENSE_SUB_BATCH, + stream=cp.cuda.get_current_stream().ptr, + ) expected = group_sizes_dev[:, None] * (n_cells + 1) / 2.0 variance = tie_corr[None, :] * group_sizes_dev[:, None] * rest_sizes[:, None] variance *= (n_cells + 1) / 12.0 @@ -203,12 +707,15 @@ def _wilcoxon_vs_rest( z = diff / std cp.nan_to_num(z, copy=False) p_values = cupyx_special.erfc(cp.abs(z) * cp.float64(cp.sqrt(0.5))) + scores = _wilcoxon_scores( + rank_sums, group_sizes_dev, z, return_u_values=return_u_values + ) - z_host = z.get() + scores_host = scores.get() p_host = p_values.get() for idx in range(n_groups): - all_scores[idx].append(z_host[idx]) + all_scores[idx].append(scores_host[idx]) all_pvals[idx].append(p_host[idx]) # Collect results per group @@ -227,98 +734,359 @@ def _wilcoxon_with_reference( tie_correct: bool, use_continuity: bool, chunk_size: int | None, + return_u_values: bool, ) -> list[tuple[int, NDArray, NDArray]]: - """Wilcoxon test: each group vs a specific reference group.""" + """Wilcoxon test: all selected groups vs a specific reference group.""" codes = rg.group_codes - n_ref = int(group_sizes[rg.ireference]) - mask_ref = codes == rg.ireference - - results: list[tuple[int, NDArray, NDArray]] = [] + n_groups = len(rg.groups_order) + ireference = rg.ireference + n_ref = int(group_sizes[ireference]) + ref_row_ids = np.flatnonzero(codes == ireference).astype(np.int32, copy=False) - for group_index in range(len(rg.groups_order)): - if group_index == rg.ireference: - continue + test_group_indices = [i for i in range(n_groups) if i != ireference] + if not test_group_indices: + return [] - n_group = int(group_sizes[group_index]) - n_combined = n_group + n_ref + offsets = [0] + row_id_parts = [] + small_groups = [] + for group_index in test_group_indices: + group_rows = np.flatnonzero(codes == group_index).astype(np.int32, copy=False) + row_id_parts.append(group_rows) + offsets.append(offsets[-1] + int(group_rows.size)) + if int(group_sizes[group_index]) <= MIN_GROUP_SIZE_WARNING: + 
small_groups.append(str(rg.groups_order[group_index])) - # Warn for small groups - if n_group <= MIN_GROUP_SIZE_WARNING or n_ref <= MIN_GROUP_SIZE_WARNING: - warnings.warn( - f"Group {rg.groups_order[group_index]} has size {n_group} " - f"(reference {n_ref}); normal approximation " - "of the Wilcoxon statistic may be inaccurate.", - RuntimeWarning, - stacklevel=4, + if n_ref <= MIN_GROUP_SIZE_WARNING or small_groups: + parts = [] + if small_groups: + parts.append( + f"{len(small_groups)} test group(s) have size " + f"<= {MIN_GROUP_SIZE_WARNING} (first few: " + f"{', '.join(small_groups[:5])}" + f"{'...' if len(small_groups) > 5 else ''})" ) + if n_ref <= MIN_GROUP_SIZE_WARNING: + parts.append(f"reference has size {n_ref}") + warnings.warn( + f"Small groups detected: {'; '.join(parts)}. normal approximation " + "of the Wilcoxon statistic may be inaccurate.", + RuntimeWarning, + stacklevel=4, + ) - # Combined mask: group + reference - mask_obs = codes == group_index - mask_combined = mask_obs | mask_ref - - # Subset matrix ONCE before chunking (10x faster than filtering each chunk) - X_subset = X[mask_combined, :] + all_grp_row_ids = ( + np.concatenate(row_id_parts).astype(np.int32, copy=False) + if row_id_parts + else np.empty(0, dtype=np.int32) + ) + offsets_np = np.asarray(offsets, dtype=np.int32) + offsets_gpu = cp.asarray(offsets_np) + n_all_grp = int(all_grp_row_ids.size) + n_test = len(test_group_indices) + test_sizes = cp.asarray( + group_sizes[np.asarray(test_group_indices, dtype=np.intp)].astype( + np.float64, copy=False + ) + ) - # One-time CSR->CSC via fast parallel Numba kernel - if isinstance(X_subset, sp.spmatrix | sp.sparray): - X_subset = ( - _fast_csr_to_csc(X_subset) - if X_subset.format == "csr" - else X_subset.tocsc() + host_sparse = isinstance(X, sp.spmatrix | sp.sparray) + if host_sparse: + if X.format not in {"csr", "csc"}: + raise TypeError( + "Wilcoxon sparse input must be CSR or CSC; refusing hidden " + f"full-matrix conversion from {X.format!r}." 
) - # Within the combined array, True = group cell, False = reference cell - group_mask_gpu = cp.asarray(mask_obs[mask_combined]) - - chunk_width = _choose_chunk_size(chunk_size) - - # Pre-allocate output arrays - scores = np.empty(n_total_genes, dtype=np.float64) - pvals = np.empty(n_total_genes, dtype=np.float64) - - for start in range(0, n_total_genes, chunk_width): - stop = min(start + chunk_width, n_total_genes) + rank_sums = cp.empty((n_test, n_total_genes), dtype=cp.float64) + tie_corr_arr = cp.ones((n_test, n_total_genes), dtype=cp.float64) + n_groups_stats = n_test + 1 + compute_vars = False + compute_sums = rg._compute_stats_in_chunks + compute_nnz = rg.comp_pts + group_sums = cp.empty( + (n_groups_stats, n_total_genes) + if (compute_sums or X.format == "csc") + else (1,), + dtype=cp.float64, + ) + group_sq_sums = cp.empty( + (n_groups_stats, n_total_genes) if compute_vars else (1,), + dtype=cp.float64, + ) + group_nnz = cp.empty( + (n_groups_stats, n_total_genes) if compute_nnz else (1,), + dtype=cp.float64, + ) - # Get block for combined cells only - block = _get_column_block(X_subset, start, stop) + stats_code_lookup = np.full(n_groups + 1, n_groups_stats, dtype=np.int32) + test_group_indices_np = np.asarray(test_group_indices, dtype=np.intp) + stats_code_lookup[test_group_indices_np] = np.arange(n_test, dtype=np.int32) + stats_code_lookup[ireference] = n_test + stats_codes = stats_code_lookup[codes] - # Accumulate stats for this chunk - rg._accumulate_chunk_stats_with_ref( - block, - start, - stop, - group_index=group_index, - group_mask_gpu=group_mask_gpu, - n_group=n_group, + if X.format == "csc": + csc = X + if not csc.has_sorted_indices: + csc = csc.copy() + csc.sort_indices() + ref_row_map = np.full(X.shape[0], -1, dtype=np.int32) + ref_row_map[ref_row_ids] = np.arange(n_ref, dtype=np.int32) + grp_row_map = np.full(X.shape[0], -1, dtype=np.int32) + grp_row_map[all_grp_row_ids] = np.arange(n_all_grp, dtype=np.int32) + csc_host_fn, data_arr, indices_arr = _host_sparse_fn_and_arrays( + _wcs, "ovo_streaming_csc_host", csc, support_idx64=True + ) + csc_host_fn( + data_arr, + indices_arr, + csc.indptr, + ref_row_map, + grp_row_map, + offsets_np, + stats_codes, + rank_sums, + tie_corr_arr, + group_sums, + group_sq_sums, + group_nnz, + n_ref=n_ref, + n_all_grp=n_all_grp, + n_rows=X.shape[0], + n_cols=n_total_genes, + n_groups=n_test, + n_groups_stats=n_groups_stats, + compute_tie_corr=tie_correct, + compute_sq_sums=compute_vars, + compute_nnz=compute_nnz, + sub_batch_cols=OVO_HOST_SPARSE_SUB_BATCH, + ) + else: + csr = X + # Host CSR gather scans each row's native index list and tolerates + # unsorted row indices; avoid a full CSR copy just to sort. 
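+            # Rank statistics depend only on the multiset of values gathered per
+            # gene column, not on the order row indices are stored in, so leaving
+            # the CSR indices unsorted cannot change the rank sums or tie counts.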
+ csr_host_fn, data_arr, indices_arr = _host_sparse_fn_and_arrays( + _wcs, "ovo_streaming_csr_host", csr, support_idx64=True + ) + csr_host_fn( + data_arr, + indices_arr, + csr.indptr, + ref_row_ids.astype(np.int32, copy=False), + all_grp_row_ids.astype(np.int32, copy=False), + offsets_np, + rank_sums, + tie_corr_arr, + group_sums, + group_sq_sums, + group_nnz, + n_full_rows=X.shape[0], n_ref=n_ref, + n_all_grp=n_all_grp, + n_cols=n_total_genes, + n_test=n_test, + n_groups_stats=n_groups_stats, + compute_tie_corr=tie_correct, + compute_sq_sums=compute_vars, + compute_nnz=compute_nnz, + compute_sums=compute_sums, + sub_batch_cols=OVO_HOST_SPARSE_SUB_BATCH, ) - # Ranks for combined group+reference cells - if tie_correct: - ranks, sorted_vals = _average_ranks(block, return_sorted=True) - tie_corr = _tie_correction(sorted_vals) + logfoldchanges_gpu = None + if rg._compute_stats_in_chunks: + if rg._store_wilcoxon_gpu_result and not rg.comp_pts: + logfoldchanges_gpu = _ovo_logfoldchanges_from_sums( + rg, + group_sums, + test_sizes, + n_ref, + ) + rg._compute_stats_in_chunks = False else: - ranks = _average_ranks(block) - tie_corr = cp.ones(ranks.shape[1], dtype=cp.float64) + _fill_ovo_stats_from_accumulators( + rg, + group_sums, + group_sq_sums, + group_nnz, + group_sizes=group_sizes, + test_group_indices=test_group_indices, + n_ref=n_ref, + compute_vars=compute_vars, + ) - # Rank sum for the group - rank_sums = (ranks * group_mask_gpu[:, None]).sum(axis=0) + n_combined = test_sizes + n_ref + expected = test_sizes[:, None] * (n_combined[:, None] + 1) / 2.0 + variance = test_sizes[:, None] * n_ref * (n_combined[:, None] + 1) / 12.0 + if tie_correct: + variance = variance * tie_corr_arr + diff = rank_sums - expected + if use_continuity: + diff = cp.sign(diff) * cp.maximum(cp.abs(diff) - 0.5, 0.0) + z = diff / cp.sqrt(variance) + cp.nan_to_num(z, copy=False) + p_values = cupyx_special.erfc(cp.abs(z) * cp.float64(cp.sqrt(0.5))) + scores = _wilcoxon_scores( + rank_sums, test_sizes, z, return_u_values=return_u_values + ) + if rg._store_wilcoxon_gpu_result: + rg._wilcoxon_gpu_result = ( + np.asarray(test_group_indices, dtype=np.intp), + scores, + p_values, + logfoldchanges_gpu, + ) + return [] + scores_host = scores.get() + p_host = p_values.get() + return [ + (group_index, scores_host[slot], p_host[slot]) + for slot, group_index in enumerate(test_group_indices) + ] + + if cpsp.isspmatrix_csc(X) or cpsp.isspmatrix_csr(X): + sparse_X = X + if cpsp.isspmatrix_csr(sparse_X) and not sparse_X.has_sorted_indices: + sparse_X = sparse_X.copy() + sparse_X.sort_indices() + sparse_arrays = _device_sparse_arrays_i32_f32(sparse_X) + if sparse_arrays is not None: + data, indices, indptr = sparse_arrays + offsets_gpu = cp.asarray(offsets_np, dtype=cp.int32) + rank_sums = cp.empty((n_test, n_total_genes), dtype=cp.float64) + tie_corr_arr = cp.ones((n_test, n_total_genes), dtype=cp.float64) + + if cpsp.isspmatrix_csc(sparse_X): + ref_row_map = np.full(X.shape[0], -1, dtype=np.int32) + ref_row_map[ref_row_ids] = np.arange(n_ref, dtype=np.int32) + grp_row_map = np.full(X.shape[0], -1, dtype=np.int32) + grp_row_map[all_grp_row_ids] = np.arange(n_all_grp, dtype=np.int32) + _wcs.ovo_streaming_csc_device( + data, + indices, + indptr, + cp.asarray(ref_row_map), + cp.asarray(grp_row_map), + offsets_gpu, + rank_sums, + tie_corr_arr, + n_ref=n_ref, + n_all_grp=n_all_grp, + n_cols=n_total_genes, + n_groups=n_test, + compute_tie_corr=tie_correct, + sub_batch_cols=OVO_DEVICE_SPARSE_SUB_BATCH, + ) + else: + 
_wcs.ovo_streaming_csr_device( + data, + indices, + indptr, + cp.asarray(ref_row_ids, dtype=cp.int32), + cp.asarray(all_grp_row_ids, dtype=cp.int32), + offsets_gpu, + rank_sums, + tie_corr_arr, + n_ref=n_ref, + n_all_grp=n_all_grp, + n_cols=n_total_genes, + n_groups=n_test, + compute_tie_corr=tie_correct, + sub_batch_cols=OVO_DEVICE_SPARSE_SUB_BATCH, + ) - # Wilcoxon z-score formula for two groups - expected = n_group * (n_combined + 1) / 2.0 - variance = tie_corr * n_group * n_ref * (n_combined + 1) / 12.0 - std = cp.sqrt(variance) + n_combined = test_sizes + n_ref + expected = test_sizes[:, None] * (n_combined[:, None] + 1) / 2.0 + variance = test_sizes[:, None] * n_ref * (n_combined[:, None] + 1) / 12.0 + if tie_correct: + variance = variance * tie_corr_arr diff = rank_sums - expected if use_continuity: diff = cp.sign(diff) * cp.maximum(cp.abs(diff) - 0.5, 0.0) - z = diff / std + z = diff / cp.sqrt(variance) cp.nan_to_num(z, copy=False) p_values = cupyx_special.erfc(cp.abs(z) * cp.float64(cp.sqrt(0.5))) + scores = _wilcoxon_scores( + rank_sums, test_sizes, z, return_u_values=return_u_values + ) + if rg._store_wilcoxon_gpu_result: + rg._wilcoxon_gpu_result = ( + np.asarray(test_group_indices, dtype=np.intp), + scores, + p_values, + None, + ) + return [] + scores_host = scores.get() + p_host = p_values.get() + return [ + (group_index, scores_host[slot], p_host[slot]) + for slot, group_index in enumerate(test_group_indices) + ] + + chunk_width = _choose_wilcoxon_chunk_size(chunk_size, n_total_genes) - # Fill pre-allocated arrays - scores[start:stop] = z.get() - pvals[start:stop] = p_values.get() + scores_host = np.empty((n_test, n_total_genes), dtype=np.float64) + pvals_host = np.empty((n_test, n_total_genes), dtype=np.float64) - results.append((group_index, scores, pvals)) + for start in range(0, n_total_genes, chunk_width): + stop = min(start + chunk_width, n_total_genes) + n_cols = stop - start - return results + ref_block = _extract_dense_rows_cols(X, ref_row_ids, start, stop) + grp_block = _extract_dense_rows_cols(X, all_grp_row_ids, start, stop) + + _fill_ovo_chunk_stats( + rg, + ref_block, + grp_block, + offsets=offsets_np, + test_group_indices=test_group_indices, + start=start, + stop=stop, + group_sizes=group_sizes, + ) + + ref_f32 = cp.asarray(ref_block, dtype=cp.float32, order="F") + grp_f32 = cp.asarray(grp_block, dtype=cp.float32, order="F") + rank_sums = cp.empty((n_test, n_cols), dtype=cp.float64) + tie_corr = cp.empty((n_test, n_cols), dtype=cp.float64) + + _wc.ovo_rank_dense_tiered_unsorted_ref( + ref_f32, + grp_f32, + offsets_gpu, + rank_sums, + tie_corr, + n_ref=n_ref, + n_all_grp=n_all_grp, + n_cols=n_cols, + n_groups=n_test, + compute_tie_corr=tie_correct, + sub_batch_cols=OVO_DENSE_TIERED_SUB_BATCH, + stream=cp.cuda.get_current_stream().ptr, + ) + + n_combined = test_sizes + n_ref + expected = test_sizes[:, None] * (n_combined[:, None] + 1) / 2.0 + variance = test_sizes[:, None] * n_ref * (n_combined[:, None] + 1) / 12.0 + if tie_correct: + variance = variance * tie_corr + std = cp.sqrt(variance) + diff = rank_sums - expected + if use_continuity: + diff = cp.sign(diff) * cp.maximum(cp.abs(diff) - 0.5, 0.0) + z = diff / std + cp.nan_to_num(z, copy=False) + p_values = cupyx_special.erfc(cp.abs(z) * cp.float64(cp.sqrt(0.5))) + scores = _wilcoxon_scores( + rank_sums, test_sizes, z, return_u_values=return_u_values + ) + + scores_host[:, start:stop] = scores.get() + pvals_host[:, start:stop] = p_values.get() + + return [ + (group_index, scores_host[slot], pvals_host[slot]) 
+ for slot, group_index in enumerate(test_group_indices) + ] diff --git a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon_binned.py b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon_binned.py index fa4bbccf..14793834 100644 --- a/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon_binned.py +++ b/src/rapids_singlecell/tools/_rank_genes_groups/_wilcoxon_binned.py @@ -11,6 +11,8 @@ from rapids_singlecell._compat import DaskArray from rapids_singlecell._cuda import _wilcoxon_binned_cuda as _wb +from ._utils import MIN_GROUP_SIZE_WARNING + if TYPE_CHECKING: from numpy.typing import NDArray @@ -102,7 +104,7 @@ def wilcoxon_binned( ``'log1p'`` uses a fixed [0, 15] range suitable for log1p-normalized data. ``'auto'`` computes the actual (min, max) of the data. Use this - for z-scored or unnormalized data. + for nonnegative expression data outside the fixed log1p range. """ if not rg.is_log1p: warnings.warn( @@ -119,20 +121,6 @@ def wilcoxon_binned( if n_bins is None: n_bins = _DASK_N_BINS if isinstance(X, DaskArray) else _DEFAULT_N_BINS - # Sparse kernels assume non-negative data (pre-fill+correct pattern). - # Dense kernel handles any range. - # NOTE: Dask sparse is not validated here because checking .data.min() - # would require materializing all blocks. The sparse histogram kernels - # will silently produce incorrect results for negative Dask sparse data. - if not isinstance(X, DaskArray) and cpsp.issparse(X) and X.nnz > 0: - if float(X.data.min()) < 0: - msg = ( - "Sparse input contains negative values. The sparse histogram " - "kernels assume non-negative data. Convert to dense or use " - "bin_range='auto' with a dense array." - ) - raise ValueError(msg) - n_groups = len(rg.groups_order) n_cells, n_genes = X.shape group_sizes = rg.group_sizes @@ -173,7 +161,7 @@ def wilcoxon_binned( ): if gi == ireference: continue - if size <= 25 or n_ref <= 25: + if size <= MIN_GROUP_SIZE_WARNING or n_ref <= MIN_GROUP_SIZE_WARNING: warnings.warn( f"Group {name} has size {size} (reference {n_ref}); normal " "approximation of the Wilcoxon statistic may be inaccurate.", @@ -183,7 +171,7 @@ def wilcoxon_binned( else: for name, size in zip(rg.groups_order, group_sizes, strict=True): rest = n_cells - size - if size <= 25 or rest <= 25: + if size <= MIN_GROUP_SIZE_WARNING or rest <= MIN_GROUP_SIZE_WARNING: warnings.warn( f"Group {name} has size {size} (rest {rest}); normal " "approximation of the Wilcoxon statistic may be inaccurate.", diff --git a/tests/test_rank_genes_groups_ttest.py b/tests/test_rank_genes_groups_ttest.py index e1684536..7f109e24 100644 --- a/tests/test_rank_genes_groups_ttest.py +++ b/tests/test_rank_genes_groups_ttest.py @@ -1,5 +1,6 @@ from __future__ import annotations +import anndata as ad import numpy as np import pytest import scanpy as sc @@ -10,6 +11,10 @@ import rapids_singlecell as rsc +def _make_nonnegative(adata): + adata.X = np.abs(adata.X) + + @pytest.mark.parametrize("reference", ["rest", "1"]) @pytest.mark.parametrize("method", ["t-test", "t-test_overestim_var"]) @pytest.mark.parametrize("sparse", [True, False]) @@ -18,11 +23,15 @@ def test_rank_genes_groups_ttest_matches_scanpy(reference, method, sparse): np.random.seed(42) adata_gpu = sc.datasets.blobs(n_variables=6, n_centers=3, n_observations=200) adata_gpu.obs["blobs"] = adata_gpu.obs["blobs"].astype("category") + _make_nonnegative(adata_gpu) if sparse: + adata_gpu.X = adata_gpu.X.astype(np.float32) adata_gpu.X = sp.csr_matrix(adata_gpu.X) adata_cpu = adata_gpu.copy() + if sparse: + adata_cpu.X = 
adata_cpu.X.astype(np.float64) rsc.tl.rank_genes_groups( adata_gpu, @@ -75,6 +84,7 @@ def test_rank_genes_groups_ttest_honors_layer_and_use_raw(reference, method): np.random.seed(42) base = sc.datasets.blobs(n_variables=5, n_centers=3, n_observations=150) base.obs["blobs"] = base.obs["blobs"].astype("category") + _make_nonnegative(base) base.layers["signal"] = base.X.copy() ref_adata = base.copy() @@ -123,6 +133,7 @@ def test_rank_genes_groups_ttest_subset_and_bonferroni(reference, method): np.random.seed(42) adata = sc.datasets.blobs(n_variables=5, n_centers=4, n_observations=150) adata.obs["blobs"] = adata.obs["blobs"].astype("category") + _make_nonnegative(adata) groups = ["0", "1", "2"] if reference != "rest" else ["0", "2"] @@ -218,6 +229,7 @@ def test_rank_genes_groups_ttest_with_renamed_categories( np.random.seed(42) adata = sc.datasets.blobs(n_variables=4, n_centers=3, n_observations=200) adata.obs["blobs"] = adata.obs["blobs"].astype("category") + _make_nonnegative(adata) # First run with original category names rsc.tl.rank_genes_groups(adata, "blobs", method=method, reference=reference_before) @@ -246,6 +258,7 @@ def test_rank_genes_groups_ttest_with_unsorted_groups(reference, method): np.random.seed(42) adata = sc.datasets.blobs(n_variables=6, n_centers=4, n_observations=180) adata.obs["blobs"] = adata.obs["blobs"].astype("category") + _make_nonnegative(adata) bdata = adata.copy() groups = ["0", "1", "2", "3"] if reference != "rest" else ["0", "2", "3"] @@ -285,6 +298,7 @@ def test_rank_genes_groups_ttest_pts(reference, method): np.random.seed(42) adata_gpu = sc.datasets.blobs(n_variables=6, n_centers=3, n_observations=200) adata_gpu.obs["blobs"] = adata_gpu.obs["blobs"].astype("category") + _make_nonnegative(adata_gpu) adata_cpu = adata_gpu.copy() # Run with pts=True @@ -346,8 +360,6 @@ def test_rank_genes_groups_ttest_direct_scipy(): Creates a simple two-group dataset and compares rapids_singlecell t-test directly against scipy.stats.ttest_ind without intermediate statistics. 
""" - import anndata as ad - np.random.seed(42) n_group1, n_group2, n_genes = 50, 60, 20 @@ -357,6 +369,9 @@ def test_rank_genes_groups_ttest_direct_scipy(): # Combine into AnnData X = np.vstack([X_group1, X_group2]) + X -= X.min() + X_group1 = X[:n_group1] + X_group2 = X[n_group1:] obs = {"group": ["A"] * n_group1 + ["B"] * n_group2} adata = ad.AnnData(X=X, obs=obs) adata.obs["group"] = adata.obs["group"].astype("category") @@ -399,6 +414,7 @@ def test_rank_genes_groups_ttest_matches_scipy(): adata = pbmc68k_reduced() # Convert to float64 for maximum precision in comparison adata.X = adata.X.astype(np.float64) + _make_nonnegative(adata) # Run rapids_singlecell t-test rsc.tl.rank_genes_groups(adata, "bulk_labels", method="t-test", use_raw=False) @@ -461,6 +477,7 @@ def test_rank_genes_groups_ttest_mask_var_array(method): np.random.seed(42) adata = sc.datasets.blobs(n_variables=10, n_centers=3, n_observations=150) adata.obs["blobs"] = adata.obs["blobs"].astype("category") + _make_nonnegative(adata) # Create mask to select only first 5 genes mask = np.array([True] * 5 + [False] * 5) @@ -488,6 +505,7 @@ def test_rank_genes_groups_ttest_mask_var_string(method): np.random.seed(42) adata = sc.datasets.blobs(n_variables=10, n_centers=3, n_observations=150) adata.obs["blobs"] = adata.obs["blobs"].astype("category") + _make_nonnegative(adata) # Add mask column to adata.var adata.var["highly_variable"] = [True] * 6 + [False] * 4 @@ -514,6 +532,7 @@ def test_rank_genes_groups_ttest_mask_var_matches_scanpy(method): np.random.seed(42) adata_gpu = sc.datasets.blobs(n_variables=8, n_centers=3, n_observations=150) adata_gpu.obs["blobs"] = adata_gpu.obs["blobs"].astype("category") + _make_nonnegative(adata_gpu) adata_cpu = adata_gpu.copy() mask = np.array([True, False, True, False, True, True, False, True]) @@ -546,6 +565,7 @@ def test_rank_genes_groups_ttest_rankby_abs(method): np.random.seed(42) adata = sc.datasets.blobs(n_variables=6, n_centers=3, n_observations=200) adata.obs["blobs"] = adata.obs["blobs"].astype("category") + _make_nonnegative(adata) adata_abs = adata.copy() # Run without rankby_abs @@ -573,6 +593,7 @@ def test_rank_genes_groups_ttest_key_added(method): np.random.seed(42) adata = sc.datasets.blobs(n_variables=6, n_centers=3, n_observations=200) adata.obs["blobs"] = adata.obs["blobs"].astype("category") + _make_nonnegative(adata) custom_key = "my_custom_key" diff --git a/tests/test_rank_genes_groups_wilcoxon.py b/tests/test_rank_genes_groups_wilcoxon.py index 7f32f0e5..af39da54 100644 --- a/tests/test_rank_genes_groups_wilcoxon.py +++ b/tests/test_rank_genes_groups_wilcoxon.py @@ -1,16 +1,284 @@ from __future__ import annotations import cupy as cp +import cupyx.scipy.sparse as cpsp import numpy as np import pandas as pd import pytest import scanpy as sc import scipy.sparse as sp -from scipy.stats import mannwhitneyu, rankdata, tiecorrect +from scipy.stats import mannwhitneyu import rapids_singlecell as rsc +def _to_format(X_dense, fmt): + if fmt == "numpy_dense": + return np.asarray(X_dense) + if fmt == "scipy_csr": + return sp.csr_matrix(X_dense) + if fmt == "scipy_csc": + return sp.csc_matrix(X_dense) + if fmt == "cupy_dense": + return cp.asarray(X_dense) + if fmt == "cupy_csr": + return cpsp.csr_matrix(cp.asarray(X_dense)) + if fmt == "cupy_csc": + return cpsp.csc_matrix(cp.asarray(X_dense)) + raise ValueError(f"Unknown format: {fmt}") + + +def _make_nonnegative(adata): + adata.X = np.abs(np.asarray(adata.X)).astype(np.float32) + return adata + + +@pytest.mark.parametrize( + 
"method", + ["t-test", "t-test_overestim_var", "wilcoxon", "wilcoxon_binned", "logreg"], +) +@pytest.mark.parametrize("fmt", ["scipy_csr", "scipy_csc", "cupy_csr", "cupy_csc"]) +def test_rank_genes_groups_sparse_negative_values_raise(method, fmt): + X = np.array( + [ + [-1.0, 0.0, 2.0], + [0.0, 1.0, 0.0], + [2.0, 0.0, 1.0], + [0.0, 3.0, 0.0], + ], + dtype=np.float32, + ) + adata = sc.AnnData( + X=_to_format(X, fmt), + obs=pd.DataFrame( + {"group": pd.Categorical(["a", "a", "b", "b"], categories=["a", "b"])} + ), + var=pd.DataFrame(index=["g0", "g1", "g2"]), + ) + + with pytest.raises(ValueError, match="Sparse input contains negative values"): + rsc.tl.rank_genes_groups(adata, "group", method=method, use_raw=False) + + +@pytest.mark.parametrize("fmt", ["numpy_dense", "scipy_csr", "cupy_dense", "cupy_csr"]) +def test_rank_genes_groups_complex_values_raise(fmt): + X = np.array( + [ + [1.0 + 0.0j, 0.0, 2.0], + [0.0, 1.0, 0.0], + [2.0, 0.0, 1.0], + [0.0, 3.0, 0.0], + ], + dtype=np.complex64, + ) + adata = sc.AnnData( + X=_to_format(X, fmt), + obs=pd.DataFrame( + {"group": pd.Categorical(["a", "a", "b", "b"], categories=["a", "b"])} + ), + var=pd.DataFrame(index=["g0", "g1", "g2"]), + ) + + with pytest.raises(TypeError, match="complex expression values"): + rsc.tl.rank_genes_groups(adata, "group", method="wilcoxon", use_raw=False) + + +def test_device_sparse_int64_indptr_overflow_warns(): + from rapids_singlecell.tools._rank_genes_groups._wilcoxon import ( + _device_sparse_arrays_i32_f32, + ) + + class FakeSparse: + data = cp.asarray([1.0], dtype=cp.float32) + indices = cp.asarray([0], dtype=cp.int32) + indptr = cp.asarray([0, np.iinfo(np.int32).max + 1], dtype=cp.int64) + + with pytest.warns(RuntimeWarning, match="requires int32 indptr"): + assert _device_sparse_arrays_i32_f32(FakeSparse()) is None + + +def test_rank_genes_groups_structured_results_get_df_and_h5ad_match_scanpy(tmp_path): + np.random.seed(42) + adata_rsc = sc.datasets.blobs(n_variables=6, n_centers=3, n_observations=120) + _make_nonnegative(adata_rsc) + adata_rsc.obs["blobs"] = adata_rsc.obs["blobs"].astype("category") + adata_rsc.X = sp.csr_matrix(adata_rsc.X) + adata_cpu = adata_rsc.copy() + adata_cpu.X = adata_cpu.X.toarray() + + kw = { + "groupby": "blobs", + "method": "wilcoxon", + "reference": "1", + "use_raw": False, + "tie_correct": True, + "n_genes": 4, + } + rsc.tl.rank_genes_groups(adata_rsc, **kw) + sc.tl.rank_genes_groups(adata_cpu, **kw) + + rsc_result = adata_rsc.uns["rank_genes_groups"] + assert isinstance(rsc_result["names"], np.ndarray) + assert rsc_result["names"].dtype.names == ("0", "2") + assert tuple(rsc_result["names"][0]) == tuple( + adata_cpu.uns["rank_genes_groups"]["names"][0] + ) + np.testing.assert_array_equal( + rsc_result["names"].copy(), + np.asarray(rsc_result["names"]), + ) + + h5ad_path = tmp_path / "rank_genes_groups.h5ad" + adata_rsc.write_h5ad(h5ad_path) + adata_rsc = sc.read_h5ad(h5ad_path) + + rsc_df = sc.get.rank_genes_groups_df(adata_rsc, group=None) + scanpy_df = sc.get.rank_genes_groups_df(adata_cpu, group=None) + pd.testing.assert_frame_equal(rsc_df, scanpy_df) + + +def test_rank_genes_groups_return_format_removed(): + adata = sc.datasets.blobs(n_variables=3, n_centers=2, n_observations=20) + _make_nonnegative(adata) + adata.obs["blobs"] = adata.obs["blobs"].astype("category") + + with pytest.raises(TypeError, match="return_format has been removed"): + rsc.tl.rank_genes_groups( + adata, + "blobs", + method="wilcoxon", + use_raw=False, + return_format="arrays", + ) + + 
+@pytest.mark.parametrize("reference", ["rest", "b"]) +@pytest.mark.parametrize("fmt", ["numpy_dense", "scipy_csr", "cupy_csr"]) +def test_rank_genes_groups_wilcoxon_return_u_values(reference, fmt): + X = np.array( + [ + [5.0, 0.0, 1.0, 2.0], + [4.0, 0.0, 1.0, 2.0], + [1.0, 3.0, 2.0, 2.0], + [0.0, 2.0, 2.0, 2.0], + [2.0, 1.0, 0.0, 3.0], + [3.0, 1.0, 0.0, 3.0], + ], + dtype=np.float32, + ) + labels = np.array(["a", "a", "b", "b", "c", "c"]) + adata = sc.AnnData( + X=_to_format(X, fmt), + obs=pd.DataFrame({"group": pd.Categorical(labels)}), + var=pd.DataFrame(index=[f"g{i}" for i in range(X.shape[1])]), + ) + + rsc.tl.rank_genes_groups( + adata, + "group", + groups=["a"], + reference=reference, + method="wilcoxon", + use_raw=False, + tie_correct=True, + use_continuity=True, + return_u_values=True, + n_genes=adata.n_vars, + ) + + result = adata.uns["rank_genes_groups"] + assert result["params"]["return_u_values"] is True + assert result["scores"].dtype["a"] == np.dtype("float64") + + df = sc.get.rank_genes_groups_df(adata, group="a").sort_values("names") + mask_group = labels == "a" + mask_ref = labels != "a" if reference == "rest" else labels == reference + expected = np.array( + [ + mannwhitneyu( + X[mask_group, gene], + X[mask_ref, gene], + alternative="two-sided", + ).statistic + for gene in range(X.shape[1]) + ], + dtype=np.float64, + ) + + gene_to_idx = {name: idx for idx, name in enumerate(adata.var_names)} + expected_sorted = np.array([expected[gene_to_idx[name]] for name in df["names"]]) + np.testing.assert_allclose(df["scores"].to_numpy(), expected_sorted) + + +def test_rank_genes_groups_wilcoxon_dense_edge_cases_match_scipy(): + X = np.array( + [ + [1.0, 5.0, 0.0, 2.0, 1.0], + [2.0, 5.0, 0.0, 2.0, 1.0], + [3.0, 5.0, 1.0, 2.0, 1.0], + [4.0, 5.0, 1.0, 3.0, 2.0], + [5.0, 5.0, 1.0, 3.0, 2.0], + [6.0, 5.0, 2.0, 3.0, 2.0], + [7.0, 5.0, 2.0, 4.0, 3.0], + [8.0, 5.0, 2.0, 4.0, 3.0], + ], + dtype=np.float32, + ) + labels = np.array(["a", "a", "a", "a", "b", "b", "b", "b"]) + adata = sc.AnnData( + X=X, + obs=pd.DataFrame({"group": pd.Categorical(labels)}), + var=pd.DataFrame(index=["no_ties", "all_ties", "zero_ties", "mixed", "pairs"]), + ) + rsc.tl.rank_genes_groups( + adata, + "group", + groups=["a"], + reference="b", + method="wilcoxon", + use_raw=False, + tie_correct=True, + use_continuity=True, + return_u_values=True, + n_genes=adata.n_vars, + ) + + df = sc.get.rank_genes_groups_df(adata, group="a").sort_values("names") + expected_u = {} + for idx, name in enumerate(adata.var_names): + result = mannwhitneyu( + X[labels == "a", idx], + X[labels == "b", idx], + alternative="two-sided", + method="asymptotic", + use_continuity=True, + ) + expected_u[name] = result.statistic + + np.testing.assert_allclose( + df["scores"].to_numpy(), + np.array([expected_u[name] for name in df["names"]]), + rtol=1e-13, + atol=1e-15, + ) + assert np.isfinite(df["pvals"]).all() + + +def test_rank_genes_groups_return_u_values_requires_wilcoxon(): + adata = sc.datasets.blobs(n_variables=3, n_centers=2, n_observations=20) + _make_nonnegative(adata) + adata.obs["blobs"] = adata.obs["blobs"].astype("category") + + with pytest.raises(ValueError, match="only supported for method='wilcoxon'"): + rsc.tl.rank_genes_groups( + adata, + "blobs", + method="t-test", + use_raw=False, + return_u_values=True, + ) + + @pytest.mark.parametrize("reference", ["rest", "1"]) @pytest.mark.parametrize("tie_correct", [True, False]) @pytest.mark.parametrize("sparse", [True, False]) @@ -18,6 +286,7 @@ def 
test_rank_genes_groups_wilcoxon_matches_scanpy(reference, tie_correct, spars """Test wilcoxon matches scanpy output across configurations.""" np.random.seed(42) adata_gpu = sc.datasets.blobs(n_variables=6, n_centers=3, n_observations=200) + _make_nonnegative(adata_gpu) adata_gpu.obs["blobs"] = adata_gpu.obs["blobs"].astype("category") if sparse: @@ -55,11 +324,13 @@ def test_rank_genes_groups_wilcoxon_matches_scanpy(reference, tie_correct, spars for field in ("scores", "logfoldchanges", "pvals", "pvals_adj"): gpu_field = gpu_result[field] cpu_field = cpu_result[field] + rtol = 1e-13 assert gpu_field.dtype.names == cpu_field.dtype.names for group in gpu_field.dtype.names: gpu_values = np.asarray(gpu_field[group], dtype=float) cpu_values = np.asarray(cpu_field[group], dtype=float) - np.testing.assert_allclose(gpu_values, cpu_values, rtol=1e-13, atol=1e-15) + atol = 1e-15 + np.testing.assert_allclose(gpu_values, cpu_values, rtol=rtol, atol=atol) params = gpu_result["params"] assert params["use_raw"] is False @@ -69,11 +340,46 @@ def test_rank_genes_groups_wilcoxon_matches_scanpy(reference, tie_correct, spars assert params["reference"] == reference +def test_rank_genes_groups_wilcoxon_dense_ovr_ties_match_scanpy(): + rng = np.random.default_rng(16) + X = rng.integers(0, 40, size=(128, 7)).astype(np.float32) + labels = rng.integers(0, 7, size=128).astype(str) + adata_gpu = sc.AnnData( + X=X.copy(), + obs=pd.DataFrame({"group": pd.Categorical(labels)}), + var=pd.DataFrame(index=[f"g{i}" for i in range(X.shape[1])]), + ) + adata_cpu = adata_gpu.copy() + + kw = { + "groupby": "group", + "method": "wilcoxon", + "reference": "rest", + "use_raw": False, + "tie_correct": True, + "n_genes": adata_gpu.n_vars, + } + rsc.tl.rank_genes_groups(adata_gpu, **kw) + sc.tl.rank_genes_groups(adata_cpu, **kw) + + gpu_result = adata_gpu.uns["rank_genes_groups"] + cpu_result = adata_cpu.uns["rank_genes_groups"] + for group in gpu_result["scores"].dtype.names: + assert list(gpu_result["names"][group]) == list(cpu_result["names"][group]) + np.testing.assert_allclose( + gpu_result["scores"][group], cpu_result["scores"][group], rtol=1e-13 + ) + np.testing.assert_allclose( + gpu_result["pvals"][group], cpu_result["pvals"][group], rtol=1e-13 + ) + + @pytest.mark.parametrize("reference", ["rest", "1"]) def test_rank_genes_groups_wilcoxon_honors_layer_and_use_raw(reference): """Test that layer parameter is respected.""" np.random.seed(42) base = sc.datasets.blobs(n_variables=5, n_centers=3, n_observations=150) + _make_nonnegative(base) base.obs["blobs"] = base.obs["blobs"].astype("category") base.layers["signal"] = base.X.copy() @@ -121,6 +427,7 @@ def test_rank_genes_groups_wilcoxon_subset_and_bonferroni(reference): """Test group subsetting and bonferroni correction.""" np.random.seed(42) adata = sc.datasets.blobs(n_variables=5, n_centers=4, n_observations=150) + _make_nonnegative(adata) adata.obs["blobs"] = adata.obs["blobs"].astype("category") groups = ["0", "1", "2"] if reference != "rest" else ["0", "2"] @@ -148,6 +455,233 @@ def test_rank_genes_groups_wilcoxon_subset_and_bonferroni(reference): assert np.all(adjusted <= 1.0) +def test_rank_genes_groups_wilcoxon_skip_empty_groups_filters_singletons(): + np.random.seed(42) + adata = sc.datasets.blobs(n_variables=5, n_centers=2, n_observations=21) + _make_nonnegative(adata) + adata.obs["target"] = pd.Categorical( + ["ref"] * 10 + ["valid"] * 10 + ["singleton"], + categories=["ref", "valid", "singleton", "empty"], + ) + + rsc.tl.rank_genes_groups( + adata, + "target", + 
method="wilcoxon", + reference="ref", + use_raw=False, + n_genes=3, + skip_empty_groups=True, + ) + + result = adata.uns["rank_genes_groups"] + assert result["names"].dtype.names == ("valid",) + assert result["scores"].dtype.names == ("valid",) + + +def test_rank_genes_groups_wilcoxon_skip_empty_groups_all_tests_filtered(): + np.random.seed(42) + adata = sc.datasets.blobs(n_variables=5, n_centers=2, n_observations=11) + _make_nonnegative(adata) + adata.obs["target"] = pd.Categorical( + ["ref"] * 10 + ["singleton"], + categories=["ref", "singleton", "empty"], + ) + + rsc.tl.rank_genes_groups( + adata, + "target", + method="wilcoxon", + reference="ref", + use_raw=False, + skip_empty_groups=True, + ) + + result = adata.uns["rank_genes_groups"] + assert "names" not in result + assert result["params"]["reference"] == "ref" + + +@pytest.mark.parametrize( + "fmt", + [ + pytest.param("scipy_csr", id="host_csr"), + pytest.param("scipy_csc", id="host_csc"), + pytest.param("cupy_dense", id="device_dense"), + ], +) +def test_wilcoxon_subset_rest_stats_match_scanpy(fmt): + """groups=... with reference='rest' must use all other cells for stats.""" + np.random.seed(42) + adata_gpu = sc.datasets.blobs(n_variables=6, n_centers=4, n_observations=160) + _make_nonnegative(adata_gpu) + adata_gpu.obs["blobs"] = adata_gpu.obs["blobs"].astype("category") + adata_cpu = adata_gpu.copy() + adata_gpu.X = _to_format(adata_gpu.X, fmt) + + kw = { + "groupby": "blobs", + "method": "wilcoxon", + "use_raw": False, + "groups": ["0", "2"], + "reference": "rest", + "pts": True, + "n_genes": 6, + } + rsc.tl.rank_genes_groups(adata_gpu, **kw) + sc.tl.rank_genes_groups(adata_cpu, **kw) + + gpu_result = adata_gpu.uns["rank_genes_groups"] + cpu_result = adata_cpu.uns["rank_genes_groups"] + for field in ("scores", "logfoldchanges", "pvals", "pvals_adj"): + rtol = 1e-13 + atol = 1e-15 + for group in gpu_result[field].dtype.names: + np.testing.assert_allclose( + np.asarray(gpu_result[field][group], dtype=float), + np.asarray(cpu_result[field][group], dtype=float), + rtol=rtol, + atol=atol, + equal_nan=True, + ) + + for key in ("pts", "pts_rest"): + gpu_pts = gpu_result[key] + cpu_pts = cpu_result[key] + for col in gpu_pts.columns: + np.testing.assert_allclose( + gpu_pts[col].values, cpu_pts[col].values, rtol=1e-13, atol=1e-15 + ) + + +@pytest.mark.parametrize("reference", ["rest", "1"]) +@pytest.mark.parametrize("fmt", ["scipy_csr", "scipy_csc"]) +def test_wilcoxon_zero_nnz_host_sparse_does_not_crash(reference, fmt): + obs = pd.DataFrame( + { + "group": pd.Categorical( + ["0"] * 4 + ["1"] * 4 + ["2"] * 4, + categories=["0", "1", "2"], + ) + } + ) + adata = sc.AnnData( + X=_to_format(np.zeros((12, 5), dtype=np.float32), fmt), + obs=obs, + var=pd.DataFrame(index=[f"g{i}" for i in range(5)]), + ) + + rsc.tl.rank_genes_groups( + adata, + "group", + method="wilcoxon", + use_raw=False, + reference=reference, + pts=True, + ) + + result = adata.uns["rank_genes_groups"] + for field in ("scores", "pvals"): + for group in result[field].dtype.names: + assert np.all(np.isfinite(np.asarray(result[field][group], dtype=float))) + + +def test_wilcoxon_ovo_host_csr_unsorted_indices_match_sorted(): + rng = np.random.default_rng(42) + dense = rng.poisson(1.0, size=(80, 12)).astype(np.float32) + dense[rng.random(dense.shape) < 0.55] = 0 + sorted_csr = sp.csr_matrix(dense) + unsorted_csr = sorted_csr.copy() + for row in range(unsorted_csr.shape[0]): + start, stop = unsorted_csr.indptr[row : row + 2] + order = np.arange(stop - start)[::-1] + 
unsorted_csr.indices[start:stop] = unsorted_csr.indices[start:stop][order] + unsorted_csr.data[start:stop] = unsorted_csr.data[start:stop][order] + unsorted_csr.has_sorted_indices = False + + obs = pd.DataFrame( + { + "group": pd.Categorical( + ["ref"] * 20 + ["a"] * 20 + ["b"] * 20 + ["c"] * 20, + categories=["ref", "a", "b", "c"], + ) + } + ) + var = pd.DataFrame(index=[f"g{i}" for i in range(dense.shape[1])]) + sorted_adata = sc.AnnData(X=sorted_csr, obs=obs.copy(), var=var.copy()) + unsorted_adata = sc.AnnData(X=unsorted_csr, obs=obs.copy(), var=var.copy()) + + kw = { + "groupby": "group", + "method": "wilcoxon", + "reference": "ref", + "use_raw": False, + "tie_correct": True, + "n_genes": dense.shape[1], + } + rsc.tl.rank_genes_groups(sorted_adata, **kw) + rsc.tl.rank_genes_groups(unsorted_adata, **kw) + + sorted_result = sorted_adata.uns["rank_genes_groups"] + unsorted_result = unsorted_adata.uns["rank_genes_groups"] + for field in ("scores", "logfoldchanges", "pvals", "pvals_adj"): + for group in sorted_result[field].dtype.names: + np.testing.assert_allclose( + np.asarray(unsorted_result[field][group], dtype=float), + np.asarray(sorted_result[field][group], dtype=float), + rtol=1e-13, + atol=1e-15, + equal_nan=True, + ) + + +@pytest.mark.parametrize("reference", ["rest", "1"]) +@pytest.mark.parametrize( + "fmt", + [ + "numpy_dense", + "scipy_csr", + "scipy_csc", + "cupy_dense", + "cupy_csr", + "cupy_csc", + ], +) +@pytest.mark.parametrize("pre_load", [False, True]) +def test_wilcoxon_all_public_formats_match_scanpy(reference, fmt, pre_load): + np.random.seed(42) + adata_gpu = sc.datasets.blobs(n_variables=5, n_centers=3, n_observations=120) + _make_nonnegative(adata_gpu) + adata_gpu.obs["blobs"] = adata_gpu.obs["blobs"].astype("category") + adata_cpu = adata_gpu.copy() + adata_gpu.X = _to_format(adata_gpu.X, fmt) + + kw = { + "groupby": "blobs", + "method": "wilcoxon", + "use_raw": False, + "reference": reference, + "tie_correct": True, + "n_genes": 5, + } + rsc.tl.rank_genes_groups(adata_gpu, **kw, pre_load=pre_load) + sc.tl.rank_genes_groups(adata_cpu, **kw) + + gpu_result = adata_gpu.uns["rank_genes_groups"] + cpu_result = adata_cpu.uns["rank_genes_groups"] + for field in ("scores", "logfoldchanges", "pvals", "pvals_adj"): + rtol = 1e-13 + atol = 1e-15 + for group in gpu_result[field].dtype.names: + np.testing.assert_allclose( + np.asarray(gpu_result[field][group], dtype=float), + np.asarray(cpu_result[field][group], dtype=float), + rtol=rtol, + atol=atol, + equal_nan=True, + ) + + @pytest.mark.parametrize( ("groups", "reference"), [ @@ -221,6 +755,7 @@ def test_rank_genes_groups_wilcoxon_with_renamed_categories( """Test with renamed category labels.""" np.random.seed(42) adata = sc.datasets.blobs(n_variables=4, n_centers=3, n_observations=200) + _make_nonnegative(adata) adata.obs["blobs"] = adata.obs["blobs"].astype("category") # First run with original category names @@ -252,6 +787,7 @@ def test_rank_genes_groups_wilcoxon_with_unsorted_groups(reference): """Test that group order doesn't affect results.""" np.random.seed(42) adata = sc.datasets.blobs(n_variables=6, n_centers=4, n_observations=180) + _make_nonnegative(adata) adata.obs["blobs"] = adata.obs["blobs"].astype("category") bdata = adata.copy() @@ -291,6 +827,7 @@ def test_rank_genes_groups_wilcoxon_pts(reference, pre_load): """Test that pts (fraction of cells expressing) is computed correctly.""" np.random.seed(42) adata_gpu = sc.datasets.blobs(n_variables=6, n_centers=3, n_observations=200) + 
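Aside (not part of the patch): the tests in this block lean on two module-level helpers, _make_nonnegative and _to_format, that are defined earlier in this test file. A hypothetical sketch of what they are assumed to do (the real definitions may differ):

    import cupy as cp
    import cupyx.scipy.sparse as cpsp
    import numpy as np
    import scipy.sparse as sp


    def _make_nonnegative(adata):
        # Shift X so every value is >= 0 (sc.datasets.blobs can produce negatives).
        adata.X = adata.X - adata.X.min()


    def _to_format(X, fmt):
        # Convert a dense array into one of the input formats the tests exercise.
        X = np.asarray(X, dtype=np.float32)
        converters = {
            "numpy_dense": lambda x: x,
            "scipy_csr": sp.csr_matrix,
            "scipy_csc": sp.csc_matrix,
            "cupy_dense": cp.asarray,
            "cupy_csr": lambda x: cpsp.csr_matrix(cp.asarray(x)),
            "cupy_csc": lambda x: cpsp.csc_matrix(cp.asarray(x)),
        }
        return converters[fmt](X)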
@@ -291,6 +827,7 @@ def test_rank_genes_groups_wilcoxon_pts(reference, pre_load):
     """Test that pts (fraction of cells expressing) is computed correctly."""
     np.random.seed(42)
     adata_gpu = sc.datasets.blobs(n_variables=6, n_centers=3, n_observations=200)
+    _make_nonnegative(adata_gpu)
     adata_gpu.obs["blobs"] = adata_gpu.obs["blobs"].astype("category")
 
     adata_cpu = adata_gpu.copy()
@@ -504,188 +1041,3 @@ def test_sparse_matches_dense(self, perturbation_adata, sparse):
         np.testing.assert_array_equal(
             dense_df["pvals"].values, sparse_df["pvals"].values
         )
-
-
-# ============================================================================
-# Tests for ranking and tie correction kernels (edge cases from scipy)
-# ============================================================================
-
-
-class TestRankingKernel:
-    """Tests for _average_ranks based on scipy.stats.rankdata edge cases."""
-
-    @pytest.fixture
-    def average_ranks(self):
-        """Import the ranking function."""
-        from rapids_singlecell.tools._rank_genes_groups._wilcoxon import (
-            _average_ranks,
-        )
-
-        return _average_ranks
-
-    @staticmethod
-    def _to_gpu(values):
-        """Convert 1D values to GPU column matrix with F-order."""
-        arr = np.asarray(values, dtype=np.float64).reshape(-1, 1)
-        return cp.asarray(arr, order="F")
-
-    def test_basic_ranking(self, average_ranks):
-        """Test basic average ranking on simple data."""
-        values = [3.0, 1.0, 2.0]
-        result = average_ranks(self._to_gpu(values))
-        expected = rankdata(values, method="average")
-        np.testing.assert_allclose(result.get().flatten(), expected)
-
-    def test_all_ties(self, average_ranks):
-        """All identical values should get the average rank."""
-        values = [5.0, 5.0, 5.0, 5.0]
-        result = average_ranks(self._to_gpu(values))
-        expected = rankdata(values, method="average")
-        np.testing.assert_allclose(result.get().flatten(), expected)
-
-    def test_no_ties(self, average_ranks):
-        """All unique values should get sequential ranks."""
-        values = [1.0, 2.0, 3.0, 4.0, 5.0]
-        result = average_ranks(self._to_gpu(values))
-        expected = rankdata(values, method="average")
-        np.testing.assert_allclose(result.get().flatten(), expected)
-
-    def test_mixed_ties(self, average_ranks):
-        """Mix of ties and unique values."""
-        values = [1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0]
-        result = average_ranks(self._to_gpu(values))
-        expected = rankdata(values, method="average")
-        np.testing.assert_allclose(result.get().flatten(), expected)
-
-    def test_negative_values(self, average_ranks):
-        """Test with negative values."""
-        values = [-3.0, -1.0, -2.0, 0.0, 1.0]
-        result = average_ranks(self._to_gpu(values))
-        expected = rankdata(values, method="average")
-        np.testing.assert_allclose(result.get().flatten(), expected)
-
-    def test_single_element(self, average_ranks):
-        """Single element should have rank 1."""
-        values = [42.0]
-        result = average_ranks(self._to_gpu(values))
-        np.testing.assert_allclose(result.get().flatten(), [1.0])
-
-    def test_two_elements_tied(self, average_ranks):
-        """Two tied elements should both have rank 1.5."""
-        values = [7.0, 7.0]
-        result = average_ranks(self._to_gpu(values))
-        np.testing.assert_allclose(result.get().flatten(), [1.5, 1.5])
-
-    def test_multiple_columns(self, average_ranks):
-        """Test ranking across multiple columns independently."""
-        col0 = [3.0, 1.0, 2.0]
-        col1 = [1.0, 1.0, 2.0]
-        data = np.column_stack([col0, col1]).astype(np.float64)
-        result = average_ranks(cp.asarray(data, order="F"))
-
-        np.testing.assert_allclose(result.get()[:, 0], rankdata(col0, method="average"))
-        np.testing.assert_allclose(result.get()[:, 1], rankdata(col1, method="average"))
-
-
-class TestTieCorrectionKernel:
-    """Tests for _tie_correction based on scipy.stats.tiecorrect edge cases."""
-
-    @pytest.fixture
-    def tie_correction(self):
-        """Import the tie correction function and ranking function."""
-        from rapids_singlecell.tools._rank_genes_groups._wilcoxon import (
-            _average_ranks,
-            _tie_correction,
-        )
-
-        return _tie_correction, _average_ranks
-
-    @staticmethod
-    def _to_gpu(values):
-        """Convert 1D values to GPU column matrix with F-order."""
-        arr = np.asarray(values, dtype=np.float64).reshape(-1, 1)
-        return cp.asarray(arr, order="F")
-
-    def test_no_ties(self, tie_correction):
-        """No ties should give correction factor 1.0."""
-        _tie_correction, _average_ranks = tie_correction
-
-        values = [1.0, 2.0, 3.0, 4.0, 5.0]
-        _, sorted_vals = _average_ranks(self._to_gpu(values), return_sorted=True)
-        result = _tie_correction(sorted_vals)
-
-        expected = tiecorrect(rankdata(values))
-        np.testing.assert_allclose(result.get()[0], expected, rtol=1e-10)
-
-    def test_all_ties(self, tie_correction):
-        """All tied values should give correction factor 0.0."""
-        _tie_correction, _average_ranks = tie_correction
-
-        values = [5.0, 5.0, 5.0, 5.0]
-        _, sorted_vals = _average_ranks(self._to_gpu(values), return_sorted=True)
-        result = _tie_correction(sorted_vals)
-
-        expected = tiecorrect(rankdata(values))
-        np.testing.assert_allclose(result.get()[0], expected, rtol=1e-10)
-
-    def test_mixed_ties(self, tie_correction):
-        """Mix of ties should give intermediate correction factor."""
-        _tie_correction, _average_ranks = tie_correction
-
-        values = [1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0]
-        _, sorted_vals = _average_ranks(self._to_gpu(values), return_sorted=True)
-        result = _tie_correction(sorted_vals)
-
-        expected = tiecorrect(rankdata(values))
-        np.testing.assert_allclose(result.get()[0], expected, rtol=1e-10)
-
-    def test_two_elements_tied(self, tie_correction):
-        """Two tied elements."""
-        _tie_correction, _average_ranks = tie_correction
-
-        values = [7.0, 7.0]
-        _, sorted_vals = _average_ranks(self._to_gpu(values), return_sorted=True)
-        result = _tie_correction(sorted_vals)
-
-        expected = tiecorrect(rankdata(values))
-        np.testing.assert_allclose(result.get()[0], expected, rtol=1e-10)
-
-    def test_single_element(self, tie_correction):
-        """Single element should give correction factor 1.0."""
-        _tie_correction, _average_ranks = tie_correction
-
-        values = [42.0]
-        _, sorted_vals = _average_ranks(self._to_gpu(values), return_sorted=True)
-        result = _tie_correction(sorted_vals)
-
-        # Single element: n^3 - n = 0, so formula gives 1.0
-        np.testing.assert_allclose(result.get()[0], 1.0, rtol=1e-10)
-
-    def test_multiple_columns(self, tie_correction):
-        """Test tie correction across multiple columns independently."""
-        _tie_correction, _average_ranks = tie_correction
-
-        col0 = [1.0, 2.0, 3.0]  # No ties
-        col1 = [5.0, 5.0, 5.0]  # All ties
-        data = np.column_stack([col0, col1]).astype(np.float64)
-        _, sorted_vals = _average_ranks(cp.asarray(data, order="F"), return_sorted=True)
-        result = _tie_correction(sorted_vals)
-
-        np.testing.assert_allclose(
-            result.get()[0], tiecorrect(rankdata(col0)), rtol=1e-10
-        )
-        np.testing.assert_allclose(
-            result.get()[1], tiecorrect(rankdata(col1)), rtol=1e-10
-        )
-
-    def test_large_tie_groups(self, tie_correction):
-        """Test with large tie groups."""
-        _tie_correction, _average_ranks = tie_correction
-
-        # 50 values of 1, 50 values of 2 (non-multiple of 32 to test warp handling)
-        values = [1.0] * 50 + [2.0] * 50
-        _, sorted_vals = _average_ranks(self._to_gpu(values), return_sorted=True)
-        result = _tie_correction(sorted_vals)
-
-        expected = tiecorrect(rankdata(values))
-        np.testing.assert_allclose(result.get()[0], expected, rtol=1e-10)
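Aside (not part of the patch): the deleted kernel tests above compared the GPU tie correction against scipy.stats.tiecorrect. The factor under test is T = 1 - sum(t_j**3 - t_j) / (n**3 - n), where the t_j are the sizes of the tie groups among n ranked values; it is 1.0 with no ties, 0.0 when all n values are tied, and defined as 1.0 for n < 2. A small self-contained check of that formula, assuming only numpy and scipy:

    import numpy as np
    from scipy.stats import rankdata, tiecorrect


    def tie_correction(values):
        values = np.asarray(values, dtype=np.float64)
        n = values.size
        if n < 2:
            return 1.0
        _, counts = np.unique(values, return_counts=True)  # sizes of the tie groups
        return 1.0 - (counts**3 - counts).sum() / (n**3 - n)


    cases = (
        [1.0, 2.0, 3.0, 4.0, 5.0],            # no ties -> 1.0
        [5.0, 5.0, 5.0, 5.0],                 # all tied -> 0.0
        [1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0],  # mixed ties -> between 0 and 1
        [1.0] * 50 + [2.0] * 50,              # large tie groups
    )
    for vals in cases:
        assert np.isclose(tie_correction(vals), tiecorrect(rankdata(vals)))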