Skip to content

Optional GPU backends: device= across predict / predict_proba / TreeSHAP#8

Merged
xRiskLab merged 8 commits into
mainfrom
feature/cuda
Jun 21, 2026
Merged

Optional GPU backends: device= across predict / predict_proba / TreeSHAP#8
xRiskLab merged 8 commits into
mainfrom
feature/cuda

Conversation

@xRiskLab

@xRiskLab xRiskLab commented Jun 21, 2026

Copy link
Copy Markdown
Owner

Adds optional, off-by-default GPU acceleration selected via a device= argument, across the predict/explain API. Native-only and excluded from the default and wasm wheels — the default build and published wheels are unchanged.

What

  • device="cuda" — native CUDA (cudarc + runtime nvrtc kernels), cuda feature, NVIDIA.
  • device="mps"/"metal"/"gpu" — wgpu → Metal / Vulkan / DX12, gpu feature.
  • device="cpu" (default) unchanged. Unavailable devices raise a clear error listing what the wheel was built with (a wheel carries one GPU backend, like torch's per-platform wheels).

Covered: RandomForestRegressor.predict, RandomForestClassifier.predict/predict_proba, RFGBoost{,Regressor,Classifier} predict/predict_proba, TreeSHAP.explain.

Design

  • GpuForest/CudaForest flatten the forest to struct-of-arrays; one multi-output kernel (out_dim=1 value, out_dim=n_classes distribution). new_scaled folds per-round lr/tree-count factors so boosting reuses the mean kernel.
  • TreeSHAP: exact 2^k coalition Shapley, one thread per sample, recursion → explicit weight-stack. k > 16 falls back to CPU.

Verified (M4 mps + A100 cuda)

  • predict/proba parity exact (argmax) / ~1e-7 (proba); boosting regressor ~4e-3 (f32 over rounds). TreeSHAP parity ~1e-7.
  • predict speedups: ~9.5×@10k, ~33×@100k, ~28×@1m (A100); ~5×/9×/12× (M4). TreeSHAP ~3× (exact-2^k algorithm; see caveat below).

Caveat

TreeSHAP GPU is a correct exact reference but uses exact 2^k enumeration (exponential). Benchmarked against shap.TreeExplainer (polynomial), shap's CPU is ~2000× faster — the right SHAP win is the polynomial algorithm, tracked separately. Predict/proba GPU paths are unambiguous wins.

Summary by Sourcery

Introduce optional, off-by-default GPU and CUDA backends for tree-based inference and SHAP explanations, selectable via a new device parameter across predict and explain APIs.

New Features:

  • Add GPU-backed multi-output forest inference pipeline supporting regression and classification outputs with configurable device selection via device parameter.
  • Add optional native CUDA backend mirroring the GPU forest capabilities for batched prediction and TreeSHAP explanations.
  • Extend RandomForestRegressor, RandomForestClassifier, RFGBoost, RFGBoostRegressor, and RFGBoost Python APIs to accept a device argument for predict and predict_proba, including device capability reporting and fallbacks.
  • Add GPU-accelerated exact TreeSHAP implementation for both CUDA and wgpu backends with automatic CPU fallback when unsupported or unavailable.

Enhancements:

  • Refactor forest flattening and prediction paths to support multi-output leaf vectors and shared kernels across regression and classification.
  • Introduce shared raw prediction dispatch logic for boosting models that routes to CPU, CUDA, or GPU paths while reusing kernels via scaled forests.
  • Ensure CUDA-enabled pyclasses are marked unsendable and cache device-specific forest representations for reuse across predictions.

Build:

  • Add optional cudarc-based CUDA dependency and feature flag, keeping default and wasm builds unchanged and GPU backends native-only.

Documentation:

  • Document new optional GPU and CUDA acceleration options, device arguments, supported backends, and observed speedups in the changelog.

deburky added 8 commits June 21, 2026 15:23
New native-NVIDIA-only, off-by-default 'cuda' Cargo feature, sibling to the
wgpu 'gpu' feature. src/cuda.rs CudaForest flattens the TreeNode forest to SoA,
uploads once, and traverses one row per CUDA thread via an nvrtc-compiled
kernel (cudarc, dynamic-loading: no CUDA toolkit needed at build time).
RandomForestRegressor.predict_cuda exposes it to Python. CudaForest::new
returns None when CUDA is unavailable so callers fall back to CPU.

Validated as a standalone Rust prototype on an A100: ~160x kernel-only / ~51x
end-to-end over 12-core CPU, exact parity with traverse. Default and wasm
builds unaffected (optional dep, native-only).
predict_cuda was rebuilding the CUDA context, recompiling the nvrtc kernel and
re-uploading the forest on every call. Cache it in a RefCell<Option<CudaForest>>
built lazily on first call and reset on fit. The pyclass becomes unsendable
under the cuda feature (CUDA context is thread-affine); default/wasm builds
keep plain #[pyclass].
The real per-call bottleneck at scale was the single-threaded f64->f32 copy of
the input (~32M elems), not the CUDA forest setup (the driver already caches the
context/kernel). par_iter over the contiguous slice, sequential fallback for
non-contiguous arrays.
Single torch-style selector: device='cpu' (default), 'cuda' (native CUDA,
cuda feature), or 'mps'/'metal'/'gpu' (wgpu->Metal, gpu feature). Both backends
cached per model; unavailable devices raise a clear error listing what the wheel
was built with. Replaces the standalone predict_cuda and exposes the wgpu/mps
path to Python for the first time.

Verified on M4: device='mps' 5.1x@10k, 9.4x@100k, 12x@1M (unified memory wins
small batches); device='cuda' errors clearly on the gpu-only wheel.
Generalize GpuForest/CudaForest to out_dim>1: leaves carry an out_dim-length
vector, one kernel accumulates it per row (out_dim=1 regression, =n_classes
proba). Wire RandomForestClassifier.predict/predict_proba(device=...) with
per-class leaf distributions (cap MAX_OUT=32 classes -> CPU fallback above).
Verified on M4 mps: proba parity ~3e-7, predict exact, for binary and 3-class.
…sifier

Add GpuForest/CudaForest::new_scaled (out_dim=1, per-tree leaf scaling, reuses
the mean kernel via a shared from_flat) so the per-round lr/tree-count factors
fold into one forest per output channel. boosting.rs raw_dispatch builds those
forests locally and combines them; predict/predict_proba on RFGBoostRegressor,
RFGBoostClassifier and the RFGBoost wrapper take device=. The sklearn Python
wrappers in _woe.py forward device to the Rust model.

Verified on M4 mps: regressor 4e-3 (f32 branch-flips), binary proba 2e-7 /
predict exact, 3-class proba 2e-3 / predict exact.
TreeSHAP.explain(device=...) runs the exact 2^k coalition enumeration on the
GPU, one thread per sample. The recursive evaluate_coalition (hidden feature ->
weight both children) becomes an explicit per-thread weight-stack. tree_shap.rs
flattens the SHAP tree to SoA (feat/thr/left/right/p_left/p_right/node-uidx/
leaf-values) + unique features + factorials; new gpu::shap_explain (wgpu, higher
storage-buffer limits) and cuda::shap_explain (cudarc + nvrtc) kernels. Trees
with >SHAP_MAX_K(16) unique features fall back to CPU.

Verified on M4 mps: classification 7e-8, regression 1e-7 vs CPU exact Shapley.
@sourcery-ai

sourcery-ai Bot commented Jun 21, 2026

Copy link
Copy Markdown

Reviewer's Guide

Adds optional, off-by-default GPU backends (CUDA and wgpu) and plumbs a unified device="..." argument through RandomForest, RFGBoost, and TreeSHAP predict/explain paths, including multi-output GPU kernels and exact-but-exponential TreeSHAP GPU implementations, while preserving CPU defaults and clear feature/availability checks.

Sequence diagram for predict(device=...) dispatch in RandomForestClassifier

sequenceDiagram
    actor User
    participant RandomForestClassifier
    participant predict_proba_impl
    participant proba_cuda_impl
    participant proba_gpu_impl
    participant traverse_proba
    participant CudaForest
    participant GpuForest

    User->>RandomForestClassifier: predict_proba(x, device)
    RandomForestClassifier->>predict_proba_impl: predict_proba_impl(x, device)
    alt device == cpu
        predict_proba_impl->>traverse_proba: traverse_proba(tree, sample, n_classes)
        traverse_proba-->>predict_proba_impl: per-tree probs
        predict_proba_impl-->>RandomForestClassifier: CPU probabilities
        RandomForestClassifier-->>User: proba
    else device == cuda (feature cuda)
        predict_proba_impl->>proba_cuda_impl: proba_cuda_impl(x)
        proba_cuda_impl->>CudaForest: CudaForest::new(trees, n_features, n_classes, leaf_proba)
        CudaForest-->>proba_cuda_impl: Option<CudaForest>
        proba_cuda_impl->>CudaForest: predict(xf, n)
        CudaForest-->>proba_cuda_impl: GPU probs
        proba_cuda_impl-->>RandomForestClassifier: CUDA probabilities
        RandomForestClassifier-->>User: proba
    else device in [mps, metal, gpu] (feature gpu)
        predict_proba_impl->>proba_gpu_impl: proba_gpu_impl(x)
        proba_gpu_impl->>GpuForest: GpuForest::new(trees, n_features, n_classes, leaf_proba)
        GpuForest-->>proba_gpu_impl: Option<GpuForest>
        proba_gpu_impl->>GpuForest: predict(xf, n)
        GpuForest-->>proba_gpu_impl: GPU probs
        proba_gpu_impl-->>RandomForestClassifier: wgpu probabilities
        RandomForestClassifier-->>User: proba
    else unsupported device
        predict_proba_impl-->>RandomForestClassifier: PyValueError
        RandomForestClassifier-->>User: error
    end
Loading

Sequence diagram for TreeSHAP.explain(device=...) with GPU fallback

sequenceDiagram
    actor User
    participant TreeSHAP
    participant explain
    participant explain_cpu
    participant explain_device
    participant flatten_for_gpu
    participant shap_explain_gpu

    User->>TreeSHAP: explain(x, device)
    TreeSHAP->>explain: explain(x, device)
    alt device == cpu
        explain->>explain_cpu: explain_cpu(x_arr)
        explain_cpu-->>explain: SHAP values
        explain-->>User: SHAP values
    else device == cuda / mps / metal / gpu
        explain->>explain_device: explain_device(x_arr, backend)
        explain_device->>flatten_for_gpu: flatten_for_gpu()
        alt flatten_for_gpu returns Some
            flatten_for_gpu-->>explain_device: ShapFlat
            explain_device->>shap_explain_gpu: shap_explain(xf, n, nf, nc, k, ...)
            alt shap_explain_gpu returns Some
                shap_explain_gpu-->>explain_device: GPU SHAP values
                explain_device-->>explain: SHAP values
                explain-->>User: SHAP values
            else shap_explain_gpu returns None
                shap_explain_gpu-->>explain_device: None
                explain_device->>explain_cpu: explain_cpu(x_arr)
                explain_cpu-->>explain_device: CPU SHAP values
                explain_device-->>explain: SHAP values
                explain-->>User: SHAP values
            end
        else flatten_for_gpu returns None
            flatten_for_gpu-->>explain_device: None
            explain_device->>explain_cpu: explain_cpu(x_arr)
            explain_cpu-->>explain_device: CPU SHAP values
            explain_device-->>explain: SHAP values
            explain-->>User: SHAP values
        end
    else unsupported device
        explain-->>User: PyValueError
    end
Loading

File-Level Changes

Change Details Files
Generalize GPU forest inference to multi-output and add a TreeSHAP GPU backend using wgpu.
  • Refactor GpuForest flattening into a struct-of-arrays with per-leaf out_dim-length value vectors and a unified kernel handling regression and class-probability outputs.
  • Introduce MAX_OUT and SHAP_MAX_K limits, TreeSHAP-specific WGSL shader and host-side shap_explain function that enumerates 2^k coalitions with an explicit stack per thread.
  • Add helpers to flatten SHAP trees to SoA, compute factorial weights, and fall back to CPU when k exceeds supported bounds or when no GPU adapter is available.
src/gpu.rs
src/tree_shap.rs
Introduce a native CUDA backend mirroring the wgpu GpuForest and TreeSHAP GPU implementations.
  • Add cuda.rs implementing CudaForest with nvrtc-compiled kernels for multi-output forest prediction and optional scaled forests for boosting.
  • Implement CUDA-based shap_explain using a 2^k coalition enumeration kernel and enforce MAX_OUT and SHAP_MAX_K limits analogous to the wgpu path.
  • Wire CUDA feature flag and cudarc dependency into Cargo features, keeping it off by default and excluded from wasm/default wheels.
src/cuda.rs
Cargo.toml
Cargo.lock
Plumb device-aware predict/predict_proba APIs for RandomForest and RFGBoost, with lazy GPU/CUDA caches and clear device-availability errors.
  • Extend RandomForestRegressor and RandomForestClassifier pyclasses with device="cpu" arguments, backend dispatchers, and per-instance CUDA/GPU forest caches invalidated on re-fit.
  • Implement leaf_proba helper to map TreeNode class_counts into probability vectors for multi-class GPU/CUDA inference while enforcing MAX_OUT limits.
  • Refactor RFGBoost(R/G) predict/predict_proba to use a shared raw_dispatch that routes to CPU, CUDA, or GPU implementations, including boosting-specific scaled forests and device availability reporting.
src/random_forest.rs
src/boosting.rs
rfgboost/_woe.py
Expose device options and GPU behavior in the public API and documentation.
  • Update Python wrapper classes to accept a device parameter and forward it to the underlying Rust predictors.
  • Document the new device="..." options, backend coverage, and performance characteristics in the changelog.
  • Ensure default behavior remains CPU-only and that requesting an unavailable device yields informative errors listing the compiled-in backends.
rfgboost/_woe.py
CHANGELOG.md

Tips and commands

Interacting with Sourcery

  • Trigger a new review: Comment @sourcery-ai review on the pull request.
  • Continue discussions: Reply directly to Sourcery's review comments.
  • Generate a GitHub issue from a review comment: Ask Sourcery to create an
    issue from a review comment by replying to it. You can also reply to a
    review comment with @sourcery-ai issue to create an issue from it.
  • Generate a pull request title: Write @sourcery-ai anywhere in the pull
    request title to generate a title at any time. You can also comment
    @sourcery-ai title on the pull request to (re-)generate the title at any time.
  • Generate a pull request summary: Write @sourcery-ai summary anywhere in
    the pull request body to generate a PR summary at any time exactly where you
    want it. You can also comment @sourcery-ai summary on the pull request to
    (re-)generate the summary at any time.
  • Generate reviewer's guide: Comment @sourcery-ai guide on the pull
    request to (re-)generate the reviewer's guide at any time.
  • Resolve all Sourcery comments: Comment @sourcery-ai resolve on the
    pull request to resolve all Sourcery comments. Useful if you've already
    addressed all the comments and don't want to see them anymore.
  • Dismiss all Sourcery reviews: Comment @sourcery-ai dismiss on the pull
    request to dismiss all existing Sourcery reviews. Especially useful if you
    want to start fresh with a new review - don't forget to comment
    @sourcery-ai review to trigger a new review!

Customizing Your Experience

Access your dashboard to:

  • Enable or disable review features such as the Sourcery-generated pull request
    summary, the reviewer's guide, and others.
  • Change the review language.
  • Add, remove or edit custom review instructions.
  • Adjust other review settings.

Getting Help

@xRiskLab xRiskLab merged commit a9fd240 into main Jun 21, 2026
18 checks passed

@sourcery-ai sourcery-ai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey - I've found 3 issues, and left some high level feedback:

  • The available_devices() helpers only list "mps" for the GPU backend while the code accepts "mps" | "metal" | "gpu"; consider either listing all accepted aliases or normalizing device to a canonical internal name so error messages accurately reflect what is supported.
  • The CUDA and wgpu forest implementations both duplicate the same SoA flattening logic (Flat + flatten_node + new_scaled); you could reduce maintenance overhead and the risk of divergence by extracting this into a shared helper module used by both backends.
Prompt for AI Agents
Please address the comments from this code review:

## Overall Comments
- The `available_devices()` helpers only list `"mps"` for the GPU backend while the code accepts `"mps" | "metal" | "gpu"`; consider either listing all accepted aliases or normalizing `device` to a canonical internal name so error messages accurately reflect what is supported.
- The CUDA and wgpu forest implementations both duplicate the same SoA flattening logic (`Flat` + `flatten_node` + `new_scaled`); you could reduce maintenance overhead and the risk of divergence by extracting this into a shared helper module used by both backends.

## Individual Comments

### Comment 1
<location path="src/gpu.rs" line_range="295-304" />
<code_context>
+// ----------------------------------------------------------------- TreeSHAP
+// Exact Shapley by 2^k coalition enumeration, one CUDA thread per sample;
+// `evaluate_coalition`'s recursion is an explicit per-thread weight-stack. f32.
+pub const SHAP_MAX_K: usize = 16;
+
+const SHAP_KERNEL: &str = r#"
</code_context>
<issue_to_address>
**issue (bug_risk):** SHAP_MAX_DEPTH guard is currently a no-op and depth is effectively unchecked on the WGSL path

`SHAP_MAX_DEPTH` is only used in `if k > SHAP_MAX_K || (SHAP_MAX_DEPTH as usize) < 1`, so with the current constant value of 64 that clause is always false and depth is effectively unchecked in the WGSL path. Meanwhile the shader allocates a fixed-size stack (`sn/sw: array<u32, 64>`), so trees with depth > 64 will silently overflow it.

If you want a real depth bound (like `MAXD` in the CUDA kernel), either enforce it by measuring depth during flattening and bailing out when it exceeds `SHAP_MAX_DEPTH`, or remove `SHAP_MAX_DEPTH` and the misleading condition. As written, it suggests a depth guard that doesn’t exist, which makes future changes to tree depth harder to reason about safely.
</issue_to_address>

### Comment 2
<location path="src/cuda.rs" line_range="193-201" />
<code_context>
+// ----------------------------------------------------------------- TreeSHAP
+// Exact Shapley by 2^k coalition enumeration, one CUDA thread per sample;
+// `evaluate_coalition`'s recursion is an explicit per-thread weight-stack. f32.
+pub const SHAP_MAX_K: usize = 16;
+
+const SHAP_KERNEL: &str = r#"
+#define MAXD 64
+__device__ float eval_coalition(const float* x, int base, unsigned mask, int c, int nc,
+    const int* feat, const float* thr, const unsigned* nleft, const unsigned* nright,
+    const float* pL, const float* pR, const int* uidx, const float* leafval)
+{
+    unsigned sn[MAXD]; float sw[MAXD]; int sp = 0;
+    sn[0] = 0u; sw[0] = 1.0f; sp = 1;
+    float acc = 0.0f;
</code_context>
<issue_to_address>
**issue (bug_risk):** Tree depth is limited implicitly by MAXD, but that constraint is not enforced in Rust

The CUDA SHAP kernel uses fixed-size per-thread stacks (`sn[MAXD]`, `sw[MAXD]`) with `MAXD = 64`; deeper trees will overrun these arrays and cause undefined behaviour. Rust enforces `SHAP_MAX_K`, but depth is not similarly constrained.

Please either add an explicit depth check in `flatten_for_gpu` (track max depth during traversal and fail/return `None` when it exceeds `MAXD`), or make the depth guarantee explicit in TreeSHAP and tie the Rust/CUDA constants together so they cannot diverge. As written, this implicit limit is easy to break with future tree changes.
</issue_to_address>

### Comment 3
<location path="src/tree_shap.rs" line_range="234-235" />
<code_context>
+            uidx: vec![], leafval: vec![], ufeat: ufeat_us.iter().map(|&f| f as u32).collect(),
+            fact: vec![], k,
+        };
+        flatten_shap(root, &mut a, &ufeat_us, self.n_classes.max(1));
+        let mut fact = vec![1.0f32; k + 1];
+        for i in 1..=k { fact[i] = fact[i - 1] * i as f32; }
+        a.fact = fact;
</code_context>
<issue_to_address>
**nitpick (bug_risk):** Factorial precomputation assumes k is small; defending against overflow or large k would clarify behaviour

The factorial array is built as `fact[i] = fact[i - 1] * i as f32` for `i in 1..=k`. With the current `SHAP_MAX_K = 16` this is safe, but for larger `k` the values will overflow `f32` and corrupt the weights.

Since both CUDA and WGPU SHAP paths already enforce `k <= SHAP_MAX_K`, please either enforce that invariant here (e.g. an assert on `k`) or document that increasing `SHAP_MAX_K` requires revisiting the numerical stability of this computation.
</issue_to_address>

Sourcery is free for open source - if you like our reviews please consider sharing them ✨
Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.

Comment thread src/gpu.rs
Comment on lines +295 to +304
pub const SHAP_MAX_K: usize = 16;
const SHAP_MAX_DEPTH: u32 = 64;

const SHAP_SHADER: &str = r#"
struct P { n_samples:u32, n_features:u32, n_classes:u32, k:u32 };
@group(0) @binding(0) var<storage, read> features: array<f32>;
@group(0) @binding(1) var<storage, read> feat: array<i32>;
@group(0) @binding(2) var<storage, read> thr: array<f32>;
@group(0) @binding(3) var<storage, read> nleft: array<u32>;
@group(0) @binding(4) var<storage, read> nright: array<u32>;

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (bug_risk): SHAP_MAX_DEPTH guard is currently a no-op and depth is effectively unchecked on the WGSL path

SHAP_MAX_DEPTH is only used in if k > SHAP_MAX_K || (SHAP_MAX_DEPTH as usize) < 1, so with the current constant value of 64 that clause is always false and depth is effectively unchecked in the WGSL path. Meanwhile the shader allocates a fixed-size stack (sn/sw: array<u32, 64>), so trees with depth > 64 will silently overflow it.

If you want a real depth bound (like MAXD in the CUDA kernel), either enforce it by measuring depth during flattening and bailing out when it exceeds SHAP_MAX_DEPTH, or remove SHAP_MAX_DEPTH and the misleading condition. As written, it suggests a depth guard that doesn’t exist, which makes future changes to tree depth harder to reason about safely.

Comment thread src/cuda.rs
Comment on lines +193 to +201
pub const SHAP_MAX_K: usize = 16;

const SHAP_KERNEL: &str = r#"
#define MAXD 64
__device__ float eval_coalition(const float* x, int base, unsigned mask, int c, int nc,
const int* feat, const float* thr, const unsigned* nleft, const unsigned* nright,
const float* pL, const float* pR, const int* uidx, const float* leafval)
{
unsigned sn[MAXD]; float sw[MAXD]; int sp = 0;

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (bug_risk): Tree depth is limited implicitly by MAXD, but that constraint is not enforced in Rust

The CUDA SHAP kernel uses fixed-size per-thread stacks (sn[MAXD], sw[MAXD]) with MAXD = 64; deeper trees will overrun these arrays and cause undefined behaviour. Rust enforces SHAP_MAX_K, but depth is not similarly constrained.

Please either add an explicit depth check in flatten_for_gpu (track max depth during traversal and fail/return None when it exceeds MAXD), or make the depth guarantee explicit in TreeSHAP and tie the Rust/CUDA constants together so they cannot diverge. As written, this implicit limit is easy to break with future tree changes.

Comment thread src/tree_shap.rs
Comment on lines +234 to +235
flatten_shap(root, &mut a, &ufeat_us, self.n_classes.max(1));
let mut fact = vec![1.0f32; k + 1];

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nitpick (bug_risk): Factorial precomputation assumes k is small; defending against overflow or large k would clarify behaviour

The factorial array is built as fact[i] = fact[i - 1] * i as f32 for i in 1..=k. With the current SHAP_MAX_K = 16 this is safe, but for larger k the values will overflow f32 and corrupt the weights.

Since both CUDA and WGPU SHAP paths already enforce k <= SHAP_MAX_K, please either enforce that invariant here (e.g. an assert on k) or document that increasing SHAP_MAX_K requires revisiting the numerical stability of this computation.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants