Optional GPU backends: device= across predict / predict_proba / TreeSHAP #8
New native-NVIDIA-only, off-by-default 'cuda' Cargo feature, sibling to the wgpu 'gpu' feature. src/cuda.rs CudaForest flattens the TreeNode forest to SoA, uploads once, and traverses one row per CUDA thread via an nvrtc-compiled kernel (cudarc, dynamic-loading: no CUDA toolkit needed at build time). RandomForestRegressor.predict_cuda exposes it to Python. CudaForest::new returns None when CUDA is unavailable so callers fall back to CPU. Validated as a standalone Rust prototype on an A100: ~160x kernel-only / ~51x end-to-end over 12-core CPU, exact parity with traverse. Default and wasm builds unaffected (optional dep, native-only).
predict_cuda was rebuilding the CUDA context, recompiling the nvrtc kernel and re-uploading the forest on every call. Cache it in a RefCell<Option<CudaForest>> built lazily on first call and reset on fit. The pyclass becomes unsendable under the cuda feature (CUDA context is thread-affine); default/wasm builds keep plain #[pyclass].
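A minimal sketch of the lazy cache described above, using only std. `ForestHandle`, `Model`, and the `builds` counter are hypothetical stand-ins (the real `CudaForest` and pyclass are not shown on this page); only the `RefCell<Option<_>>` shape comes from the commit message.

```rust
use std::cell::RefCell;

/// Hypothetical stand-in for the uploaded GPU forest.
struct ForestHandle {
    n_trees: usize,
}

struct Model {
    n_trees: usize,
    /// Counts expensive rebuilds; exists only to make the sketch observable.
    builds: RefCell<usize>,
    /// Built lazily on first predict, cleared whenever the model is refit.
    cache: RefCell<Option<ForestHandle>>,
}

impl Model {
    fn new(n_trees: usize) -> Self {
        Model { n_trees, builds: RefCell::new(0), cache: RefCell::new(None) }
    }

    /// Refitting invalidates whatever was uploaded for the old forest.
    fn fit(&mut self, n_trees: usize) {
        self.n_trees = n_trees;
        *self.cache.borrow_mut() = None;
    }

    /// Builds the handle on first use and reuses it on later calls.
    fn predict_n_trees(&self) -> usize {
        let mut slot = self.cache.borrow_mut();
        if slot.is_none() {
            *self.builds.borrow_mut() += 1; // context/kernel/upload would go here
            *slot = Some(ForestHandle { n_trees: self.n_trees });
        }
        slot.as_ref().map(|f| f.n_trees).unwrap_or(0)
    }
}
```

Because `RefCell` is not `Sync`, a type holding it cannot be shared across threads, which is consistent with the commit marking the pyclass unsendable under the cuda feature.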
The real per-call bottleneck at scale was the single-threaded f64->f32 copy of the input (~32M elems), not the CUDA forest setup (the driver already caches the context/kernel). par_iter over the contiguous slice, sequential fallback for non-contiguous arrays.
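The commit uses rayon's `par_iter` for this copy. To stay dependency-free, the sketch below swaps in `std::thread::scope` over matching chunks; `to_f32_parallel` and its `n_threads` parameter are illustrative names, not the PR's API.

```rust
use std::thread;

/// Converts a contiguous f64 slice to f32, splitting the work across
/// up to `n_threads` scoped threads (one chunk per thread).
fn to_f32_parallel(src: &[f64], n_threads: usize) -> Vec<f32> {
    let mut dst = vec![0.0f32; src.len()];
    if src.is_empty() {
        return dst;
    }
    let n = n_threads.max(1);
    // Ceiling division so every element lands in exactly one chunk.
    let chunk = ((src.len() + n - 1) / n).max(1);
    thread::scope(|s| {
        for (out, inp) in dst.chunks_mut(chunk).zip(src.chunks(chunk)) {
            s.spawn(move || {
                for (o, &v) in out.iter_mut().zip(inp) {
                    *o = v as f32;
                }
            });
        }
    });
    dst
}
```

A non-contiguous input has no single slice to split this way, which is presumably why the commit keeps a sequential fallback for that case.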
Single torch-style selector: device='cpu' (default), 'cuda' (native CUDA, cuda feature), or 'mps'/'metal'/'gpu' (wgpu->Metal, gpu feature). Both backends cached per model; unavailable devices raise a clear error listing what the wheel was built with. Replaces the standalone predict_cuda and exposes the wgpu/mps path to Python for the first time. Verified on M4: device='mps' 5.1x@10k, 9.4x@100k, 12x@1M (unified memory wins small batches); device='cuda' errors clearly on the gpu-only wheel.
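A rough sketch of how such a selector could normalize aliases and report what a build supports. The `Device` enum, `parse_device`, and the `has_cuda` / `has_gpu` booleans are assumptions for illustration; in the real crate the availability would presumably come from Cargo feature checks rather than arguments.

```rust
/// Canonical internal device names; several user-facing aliases map to `Wgpu`.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Device {
    Cpu,
    Cuda,
    Wgpu,
}

/// Every alias accepted by this (hypothetical) build.
fn available_devices(has_cuda: bool, has_gpu: bool) -> Vec<&'static str> {
    let mut v = vec!["cpu"];
    if has_cuda {
        v.push("cuda");
    }
    if has_gpu {
        v.extend(["mps", "metal", "gpu"]);
    }
    v
}

/// Maps a user string to a canonical device, or explains what is available.
fn parse_device(name: &str, has_cuda: bool, has_gpu: bool) -> Result<Device, String> {
    let dev = match name {
        "cpu" => Device::Cpu,
        "cuda" => Device::Cuda,
        "mps" | "metal" | "gpu" => Device::Wgpu,
        other => return Err(format!("unknown device '{other}'")),
    };
    let built = match dev {
        Device::Cpu => true,
        Device::Cuda => has_cuda,
        Device::Wgpu => has_gpu,
    };
    if built {
        Ok(dev)
    } else {
        Err(format!(
            "device '{name}' is not available; this build supports: {}",
            available_devices(has_cuda, has_gpu).join(", ")
        ))
    }
}
```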
Generalize GpuForest/CudaForest to out_dim>1: leaves carry an out_dim-length vector, one kernel accumulates it per row (out_dim=1 regression, =n_classes proba). Wire RandomForestClassifier.predict/predict_proba(device=...) with per-class leaf distributions (cap MAX_OUT=32 classes -> CPU fallback above). Verified on M4 mps: proba parity ~3e-7, predict exact, for binary and 3-class.
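To make the `out_dim` idea concrete, here is a CPU reference of what one GPU thread might do per row. The `FlatForest` layout, the field names, the `-1` leaf marker, and the `<=` split convention are all assumptions for this sketch; the PR's actual SoA layout is not shown on this page.

```rust
/// Hypothetical struct-of-arrays forest; all trees share one node index space.
struct FlatForest {
    feat: Vec<i32>,     // split feature per node, -1 marks a leaf
    thr: Vec<f32>,      // split threshold per node
    left: Vec<u32>,     // left child index
    right: Vec<u32>,    // right child index
    leaf_off: Vec<u32>, // for leaves: offset of the out_dim-length vector
    leaf_val: Vec<f32>, // concatenated leaf vectors
    roots: Vec<u32>,    // root node of each tree
    out_dim: usize,     // 1 for regression, n_classes for proba
}

/// Walks every tree for one row and averages the leaf vectors.
/// Assumes the forest has at least one tree.
fn predict_row(f: &FlatForest, x: &[f32]) -> Vec<f32> {
    let mut acc = vec![0.0f32; f.out_dim];
    for &root in &f.roots {
        let mut n = root as usize;
        while f.feat[n] >= 0 {
            let go_left = x[f.feat[n] as usize] <= f.thr[n];
            let next = if go_left { f.left[n] } else { f.right[n] };
            n = next as usize;
        }
        let off = f.leaf_off[n] as usize;
        for c in 0..f.out_dim {
            acc[c] += f.leaf_val[off + c];
        }
    }
    let inv = 1.0 / f.roots.len() as f32;
    acc.iter_mut().for_each(|v| *v *= inv);
    acc
}
```

With `out_dim = 1` the same loop yields a regression mean; with `out_dim = n_classes` it yields an averaged class distribution, so one kernel can serve both.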
…sifier Add GpuForest/CudaForest::new_scaled (out_dim=1, per-tree leaf scaling, reuses the mean kernel via a shared from_flat) so the per-round lr/tree-count factors fold into one forest per output channel. boosting.rs raw_dispatch builds those forests locally and combines them; predict/predict_proba on RFGBoostRegressor, RFGBoostClassifier and the RFGBoost wrapper take device=. The sklearn Python wrappers in _woe.py forward device to the Rust model. Verified on M4 mps: regressor 4e-3 (f32 branch-flips), binary proba 2e-7 / predict exact, 3-class proba 2e-3 / predict exact.
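The folding trick can be illustrated with plain arithmetic: if a kernel only computes the mean of per-tree leaf values, a weighted sum can still be obtained by pre-scaling each leaf. The exact factors used by `new_scaled` are not shown on this page, so `fold_scales` and the `weight * tree_count` factor below are assumptions chosen to make the identity hold.

```rust
/// What a "mean kernel" computes for one row: the average of per-tree outputs.
/// Assumes `leaves` is non-empty.
fn mean_kernel(leaves: &[f32]) -> f32 {
    leaves.iter().sum::<f32>() / leaves.len() as f32
}

/// Pre-scales each tree's leaf by `weight * tree_count`, so that the mean
/// of the scaled leaves equals the weighted sum of the original leaves.
fn fold_scales(leaves: &[f32], weights: &[f32]) -> Vec<f32> {
    let t = leaves.len() as f32;
    leaves.iter().zip(weights).map(|(l, w)| l * w * t).collect()
}
```

For example, leaves `[1.0, 2.0, 4.0]` with weights `[0.5, 0.25, 0.25]` have a weighted sum of 2.0, and `mean_kernel(&fold_scales(..))` gives the same value.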
TreeSHAP.explain(device=...) runs the exact 2^k coalition enumeration on the GPU, one thread per sample. The recursive evaluate_coalition (hidden feature -> weight both children) becomes an explicit per-thread weight-stack. tree_shap.rs flattens the SHAP tree to SoA (feat/thr/left/right/p_left/p_right/node-uidx/ leaf-values) + unique features + factorials; new gpu::shap_explain (wgpu, higher storage-buffer limits) and cuda::shap_explain (cudarc + nvrtc) kernels. Trees with >SHAP_MAX_K(16) unique features fall back to CPU. Verified on M4 mps: classification 7e-8, regression 1e-7 vs CPU exact Shapley.
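The following CPU sketch shows the two pieces named above for a single tree and a single output: coalition evaluation with an explicit weight-stack instead of recursion, and exact Shapley values by enumerating all 2^k coalitions. `ShapTree`, its field names, and the `<=` split convention are hypothetical; the array roles follow the list in the commit message.

```rust
/// Hypothetical flattened SHAP tree (single output); node 0 is the root.
struct ShapTree {
    feat: Vec<i32>,     // split feature, -1 marks a leaf
    thr: Vec<f32>,      // split threshold
    left: Vec<usize>,   // left child index
    right: Vec<usize>,  // right child index
    p_left: Vec<f32>,   // fraction of training weight going left
    p_right: Vec<f32>,  // fraction of training weight going right
    uidx: Vec<usize>,   // index of the node's feature among the unique features
    leaf: Vec<f32>,     // leaf value (unused for split nodes)
    k: usize,           // number of unique features in the tree
}

/// Expected tree output when only the features in `mask` are known.
/// A hidden feature sends weight down both children; a known one follows x.
fn eval_coalition(t: &ShapTree, x: &[f32], mask: u32) -> f32 {
    let mut stack: Vec<(usize, f32)> = vec![(0, 1.0)];
    let mut acc = 0.0f32;
    while let Some((n, w)) = stack.pop() {
        if t.feat[n] < 0 {
            acc += w * t.leaf[n];
        } else if ((mask >> t.uidx[n]) & 1) == 1 {
            let go_left = x[t.feat[n] as usize] <= t.thr[n];
            stack.push((if go_left { t.left[n] } else { t.right[n] }, w));
        } else {
            stack.push((t.left[n], w * t.p_left[n]));
            stack.push((t.right[n], w * t.p_right[n]));
        }
    }
    acc
}

/// Exact Shapley values over the tree's k unique features (cost grows as 2^k).
fn shapley(t: &ShapTree, x: &[f32]) -> Vec<f32> {
    let k = t.k;
    let mut fact = vec![1.0f32; k + 1];
    for i in 1..=k {
        fact[i] = fact[i - 1] * i as f32;
    }
    let mut phi = vec![0.0f32; k];
    for i in 0..k {
        let bit = 1u32 << i;
        for mask in 0..(1u32 << k) {
            if mask & bit != 0 {
                continue; // only coalitions S that exclude feature i
            }
            let s = mask.count_ones() as usize;
            let w = fact[s] * fact[k - s - 1] / fact[k];
            phi[i] += w * (eval_coalition(t, x, mask | bit) - eval_coalition(t, x, mask));
        }
    }
    phi
}
```

In this sketch the stack is a growable `Vec`; a GPU kernel would need a fixed-size per-thread array instead, so the maximum stack occupancy has to be bounded ahead of time.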
Reviewer's Guide
Adds optional, off-by-default GPU backends (CUDA and wgpu) and plumbs a unified device="..." argument through RandomForest, RFGBoost, and TreeSHAP predict/explain paths, including multi-output GPU kernels and exact-but-exponential TreeSHAP GPU implementations, while preserving CPU defaults and clear feature/availability checks.
Sequence diagram for predict(device=...) dispatch in RandomForestClassifier
sequenceDiagram
actor User
participant RandomForestClassifier
participant predict_proba_impl
participant proba_cuda_impl
participant proba_gpu_impl
participant traverse_proba
participant CudaForest
participant GpuForest
User->>RandomForestClassifier: predict_proba(x, device)
RandomForestClassifier->>predict_proba_impl: predict_proba_impl(x, device)
alt device == cpu
predict_proba_impl->>traverse_proba: traverse_proba(tree, sample, n_classes)
traverse_proba-->>predict_proba_impl: per-tree probs
predict_proba_impl-->>RandomForestClassifier: CPU probabilities
RandomForestClassifier-->>User: proba
else device == cuda (feature cuda)
predict_proba_impl->>proba_cuda_impl: proba_cuda_impl(x)
proba_cuda_impl->>CudaForest: CudaForest::new(trees, n_features, n_classes, leaf_proba)
CudaForest-->>proba_cuda_impl: Option<CudaForest>
proba_cuda_impl->>CudaForest: predict(xf, n)
CudaForest-->>proba_cuda_impl: GPU probs
proba_cuda_impl-->>RandomForestClassifier: CUDA probabilities
RandomForestClassifier-->>User: proba
else device in [mps, metal, gpu] (feature gpu)
predict_proba_impl->>proba_gpu_impl: proba_gpu_impl(x)
proba_gpu_impl->>GpuForest: GpuForest::new(trees, n_features, n_classes, leaf_proba)
GpuForest-->>proba_gpu_impl: Option<GpuForest>
proba_gpu_impl->>GpuForest: predict(xf, n)
GpuForest-->>proba_gpu_impl: GPU probs
proba_gpu_impl-->>RandomForestClassifier: wgpu probabilities
RandomForestClassifier-->>User: proba
else unsupported device
predict_proba_impl-->>RandomForestClassifier: PyValueError
RandomForestClassifier-->>User: error
end
Sequence diagram for TreeSHAP.explain(device=...) with GPU fallback
sequenceDiagram
actor User
participant TreeSHAP
participant explain
participant explain_cpu
participant explain_device
participant flatten_for_gpu
participant shap_explain_gpu
User->>TreeSHAP: explain(x, device)
TreeSHAP->>explain: explain(x, device)
alt device == cpu
explain->>explain_cpu: explain_cpu(x_arr)
explain_cpu-->>explain: SHAP values
explain-->>User: SHAP values
else device == cuda / mps / metal / gpu
explain->>explain_device: explain_device(x_arr, backend)
explain_device->>flatten_for_gpu: flatten_for_gpu()
alt flatten_for_gpu returns Some
flatten_for_gpu-->>explain_device: ShapFlat
explain_device->>shap_explain_gpu: shap_explain(xf, n, nf, nc, k, ...)
alt shap_explain_gpu returns Some
shap_explain_gpu-->>explain_device: GPU SHAP values
explain_device-->>explain: SHAP values
explain-->>User: SHAP values
else shap_explain_gpu returns None
shap_explain_gpu-->>explain_device: None
explain_device->>explain_cpu: explain_cpu(x_arr)
explain_cpu-->>explain_device: CPU SHAP values
explain_device-->>explain: SHAP values
explain-->>User: SHAP values
end
else flatten_for_gpu returns None
flatten_for_gpu-->>explain_device: None
explain_device->>explain_cpu: explain_cpu(x_arr)
explain_cpu-->>explain_device: CPU SHAP values
explain_device-->>explain: SHAP values
explain-->>User: SHAP values
end
else unsupported device
explain-->>User: PyValueError
end
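The two-level fallback in the diagram above maps naturally onto `Option` chaining. In this sketch the function names follow the diagram, but their signatures, the `ShapFlat` contents, and the `device_ok` flag are hypothetical placeholders for the real flattening and kernel launch.

```rust
/// Hypothetical flattened tree; only `k` matters for this sketch.
struct ShapFlat {
    k: usize,
}

/// Returns None when the tree cannot be handled on the GPU (too many features).
fn flatten_for_gpu(k: usize, max_k: usize) -> Option<ShapFlat> {
    if k > max_k { None } else { Some(ShapFlat { k }) }
}

/// Returns None when no usable device is found at runtime.
fn shap_explain_gpu(flat: &ShapFlat, device_ok: bool) -> Option<Vec<f32>> {
    if device_ok { Some(vec![0.0; flat.k]) } else { None }
}

/// CPU path that always succeeds.
fn explain_cpu(k: usize) -> Vec<f32> {
    vec![0.0; k]
}

/// Tries flatten, then the GPU kernel; either failure falls back to CPU.
/// The second tuple element records which path produced the values.
fn explain_device(k: usize, max_k: usize, device_ok: bool) -> (Vec<f32>, &'static str) {
    match flatten_for_gpu(k, max_k).and_then(|f| shap_explain_gpu(&f, device_ok)) {
        Some(v) => (v, "gpu"),
        None => (explain_cpu(k), "cpu"),
    }
}
```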
Hey - I've found 3 issues, and left some high level feedback:
- The `available_devices()` helpers only list `"mps"` for the GPU backend while the code accepts `"mps" | "metal" | "gpu"`; consider either listing all accepted aliases or normalizing `device` to a canonical internal name so error messages accurately reflect what is supported.
- The CUDA and wgpu forest implementations both duplicate the same SoA flattening logic (`Flat` + `flatten_node` + `new_scaled`); you could reduce maintenance overhead and the risk of divergence by extracting this into a shared helper module used by both backends.
## Individual Comments
### Comment 1
<location path="src/gpu.rs" line_range="295-304" />
<code_context>
+// ----------------------------------------------------------------- TreeSHAP
+// Exact Shapley by 2^k coalition enumeration, one CUDA thread per sample;
+// `evaluate_coalition`'s recursion is an explicit per-thread weight-stack. f32.
+pub const SHAP_MAX_K: usize = 16;
+
+const SHAP_KERNEL: &str = r#"
</code_context>
<issue_to_address>
**issue (bug_risk):** SHAP_MAX_DEPTH guard is currently a no-op and depth is effectively unchecked on the WGSL path
`SHAP_MAX_DEPTH` is only used in `if k > SHAP_MAX_K || (SHAP_MAX_DEPTH as usize) < 1`, so with the current constant value of 64 that clause is always false and depth is effectively unchecked in the WGSL path. Meanwhile the shader allocates a fixed-size stack (`sn/sw: array<u32, 64>`), so trees with depth > 64 will silently overflow it.
If you want a real depth bound (like `MAXD` in the CUDA kernel), either enforce it by measuring depth during flattening and bailing out when it exceeds `SHAP_MAX_DEPTH`, or remove `SHAP_MAX_DEPTH` and the misleading condition. As written, it suggests a depth guard that doesn’t exist, which makes future changes to tree depth harder to reason about safely.
</issue_to_address>
### Comment 2
<location path="src/cuda.rs" line_range="193-201" />
<code_context>
+// ----------------------------------------------------------------- TreeSHAP
+// Exact Shapley by 2^k coalition enumeration, one CUDA thread per sample;
+// `evaluate_coalition`'s recursion is an explicit per-thread weight-stack. f32.
+pub const SHAP_MAX_K: usize = 16;
+
+const SHAP_KERNEL: &str = r#"
+#define MAXD 64
+__device__ float eval_coalition(const float* x, int base, unsigned mask, int c, int nc,
+ const int* feat, const float* thr, const unsigned* nleft, const unsigned* nright,
+ const float* pL, const float* pR, const int* uidx, const float* leafval)
+{
+ unsigned sn[MAXD]; float sw[MAXD]; int sp = 0;
+ sn[0] = 0u; sw[0] = 1.0f; sp = 1;
+ float acc = 0.0f;
</code_context>
<issue_to_address>
**issue (bug_risk):** Tree depth is limited implicitly by MAXD, but that constraint is not enforced in Rust
The CUDA SHAP kernel uses fixed-size per-thread stacks (`sn[MAXD]`, `sw[MAXD]`) with `MAXD = 64`; deeper trees will overrun these arrays and cause undefined behaviour. Rust enforces `SHAP_MAX_K`, but depth is not similarly constrained.
Please either add an explicit depth check in `flatten_for_gpu` (track max depth during traversal and fail/return `None` when it exceeds `MAXD`), or make the depth guarantee explicit in TreeSHAP and tie the Rust/CUDA constants together so they cannot diverge. As written, this implicit limit is easy to break with future tree changes.
</issue_to_address>
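One way the suggested depth check could look, as a sketch only. `Node`, `max_depth`, `check_depth`, and the `chain` helper are hypothetical (the PR's `TreeNode` and `flatten_for_gpu` bodies are not shown here); `MAX_STACK` mirrors the 64-entry stacks quoted in the review. The bound assumes a traversal that pops one node and pushes both of its children, which holds at most depth + 1 entries at once.

```rust
/// Hypothetical tree type used only for this sketch.
enum Node {
    Leaf,
    Split { left: Box<Node>, right: Box<Node> },
}

/// Mirrors the fixed per-thread stack size (MAXD = 64) in the kernels.
const MAX_STACK: usize = 64;

/// Depth in edges from `n` down to its deepest leaf.
fn max_depth(n: &Node) -> usize {
    match n {
        Node::Leaf => 0,
        Node::Split { left, right } => 1 + max_depth(left).max(max_depth(right)),
    }
}

/// Returns the depth when the explicit stack is guaranteed to fit,
/// or None so the caller can fall back to the CPU path.
fn check_depth(root: &Node) -> Option<usize> {
    let d = max_depth(root);
    if d + 1 > MAX_STACK { None } else { Some(d) }
}

/// Helper for trying the check: a right-leaning chain with `depth` splits.
fn chain(depth: usize) -> Node {
    let mut n = Node::Leaf;
    for _ in 0..depth {
        n = Node::Split { left: Box::new(Node::Leaf), right: Box::new(n) };
    }
    n
}
```

Deriving the kernel's `#define MAXD` from the same Rust constant when building the source string would be one way to keep the two limits from diverging.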
### Comment 3
<location path="src/tree_shap.rs" line_range="234-235" />
<code_context>
+ uidx: vec![], leafval: vec![], ufeat: ufeat_us.iter().map(|&f| f as u32).collect(),
+ fact: vec![], k,
+ };
+ flatten_shap(root, &mut a, &ufeat_us, self.n_classes.max(1));
+ let mut fact = vec![1.0f32; k + 1];
+ for i in 1..=k { fact[i] = fact[i - 1] * i as f32; }
+ a.fact = fact;
</code_context>
<issue_to_address>
**nitpick (bug_risk):** Factorial precomputation assumes k is small; defending against overflow or large k would clarify behaviour
The factorial array is built as `fact[i] = fact[i - 1] * i as f32` for `i in 1..=k`. With the current `SHAP_MAX_K = 16` this is safe, but for larger `k` the values will overflow `f32` and corrupt the weights.
Since both CUDA and WGPU SHAP paths already enforce `k <= SHAP_MAX_K`, please either enforce that invariant here (e.g. an assert on `k`) or document that increasing `SHAP_MAX_K` requires revisiting the numerical stability of this computation.
</issue_to_address>
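A small sketch of enforcing the `k <= SHAP_MAX_K` invariant at the point where the factorials are built. `shapley_factorials` and `raw_factorial_f32` are hypothetical helper names; the constant value 16 is taken from the quoted code. The second helper is there only to show where `f32` stops being usable: 34! is still finite, while 35! exceeds the largest `f32` and becomes infinity.

```rust
const SHAP_MAX_K: usize = 16;

/// Builds 0!..=k! as f32, refusing any k above the supported bound
/// so an oversized tree cannot silently corrupt the Shapley weights.
fn shapley_factorials(k: usize) -> Option<Vec<f32>> {
    if k > SHAP_MAX_K {
        return None;
    }
    let mut fact = vec![1.0f32; k + 1];
    for i in 1..=k {
        fact[i] = fact[i - 1] * i as f32;
    }
    Some(fact)
}

/// Unguarded n! in f32, for probing the overflow point.
fn raw_factorial_f32(n: usize) -> f32 {
    (1..=n).fold(1.0f32, |acc, i| acc * i as f32)
}
```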
Code context for Comment 1 (src/gpu.rs):
pub const SHAP_MAX_K: usize = 16;
const SHAP_MAX_DEPTH: u32 = 64;

const SHAP_SHADER: &str = r#"
struct P { n_samples:u32, n_features:u32, n_classes:u32, k:u32 };
@group(0) @binding(0) var<storage, read> features: array<f32>;
@group(0) @binding(1) var<storage, read> feat: array<i32>;
@group(0) @binding(2) var<storage, read> thr: array<f32>;
@group(0) @binding(3) var<storage, read> nleft: array<u32>;
@group(0) @binding(4) var<storage, read> nright: array<u32>;
Adds optional, off-by-default GPU acceleration, selected via a `device=` argument, across the predict/explain API. Native-only and excluded from the default and wasm wheels: the default build and published wheels are unchanged.
What
- `device="cuda"`: native CUDA (cudarc + runtime nvrtc kernels), `cuda` feature, NVIDIA.
- `device="mps"` / `"metal"` / `"gpu"`: wgpu → Metal / Vulkan / DX12, `gpu` feature.
- `device="cpu"` (default): unchanged.
Unavailable devices raise a clear error listing what the wheel was built with (a wheel carries one GPU backend, like torch's per-platform wheels).
Covered: `RandomForestRegressor.predict`, `RandomForestClassifier.predict` / `predict_proba`, `RFGBoost{,Regressor,Classifier}` `predict` / `predict_proba`, `TreeSHAP.explain`.
Design
- `GpuForest` / `CudaForest` flatten the forest to struct-of-arrays; one multi-output kernel (out_dim=1 value, out_dim=n_classes distribution).
- `new_scaled` folds per-round lr/tree-count factors so boosting reuses the mean kernel.
- TreeSHAP: `k > 16` falls back to CPU.
Verified (M4 mps + A100 cuda)
Caveat
`TreeSHAP` GPU is a correct exact reference but uses exact 2^k enumeration (exponential). Benchmarked against `shap.TreeExplainer` (polynomial), shap's CPU is ~2000× faster; the right SHAP win is the polynomial algorithm, tracked separately. Predict/proba GPU paths are unambiguous wins.
Summary by Sourcery
Introduce optional, off-by-default GPU and CUDA backends for tree-based inference and SHAP explanations, selectable via a new device parameter across predict and explain APIs.