Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
dbda34b
Add array-api-compat and array-api-extra as core dependencies
frazane Jun 6, 2026
1b7996d
Remove vendored array-api copies in favour of pip dependencies
frazane Jun 6, 2026
7fdd6f3
Narrow coverage omit so the array-api extension layer is measured
frazane Jun 6, 2026
1f89e8e
Add array-api augmented namespace with input inference
frazane Jun 6, 2026
bbc248d
Export get_namespace from scoringrules.backend
frazane Jun 6, 2026
e0d30a1
Guard namespace __getattr__ recursion and cover linalg/gather delegation
frazane Jun 6, 2026
3612d82
Add per-framework special-function support audit
frazane Jun 6, 2026
a6c5da6
Add native-first special-function extension layer
frazane Jun 6, 2026
64e4ad6
Fix jax apply_along_axis axis handling; drop dead SCIPY_ARRAY_API; ex…
frazane Jun 6, 2026
bdc4692
Default active backend to numpy so numba selection is explicit
frazane Jun 6, 2026
7304231
Add use_numba dispatch helper and deprecate legacy backend selection
frazane Jun 6, 2026
7417a2f
Tidy warnings import and test global-numba non-numpy path
frazane Jun 6, 2026
a3f4259
Add per-backend native-array fixtures and inference assertion
frazane Jun 6, 2026
a6429c7
Add xp-parameterised univariate ensemble helpers for the CRPS pilot
frazane Jun 6, 2026
d79808a
Migrate CRPS ensemble-estimator core to the array-api xp namespace
frazane Jun 7, 2026
a1b3c1d
Migrate CRPS closed-form scores to the array-api xp namespace
frazane Jun 7, 2026
db9d5f5
Coerce python scalars in elementwise namespace ops so torch sqrt/log …
frazane Jun 7, 2026
76e14a0
Migrate CRPS ensemble-family public functions to array-api dispatch
frazane Jun 7, 2026
def4232
Migrate CRPS closed-form public wrappers to array-api dispatch
frazane Jun 7, 2026
2bfb283
Fix comb to return 0 for k outside [0, n], matching scipy.special.comb
frazane Jun 7, 2026
1617a91
Compute censored-distribution tail masses via xp so clogistic/cnormal…
frazane Jun 7, 2026
e574a9d
Exercise CRPS tests with native arrays and assert backend inference
frazane Jun 7, 2026
aaf8e4e
Preserve torch grad in namespace asarray, add namespace size, stop ge…
frazane Jun 7, 2026
ffb6d83
Rewrite user-guide backend section for inferred array frameworks
frazane Jun 7, 2026
9bc3db1
Exercise weighted-CRPS tests with native arrays and assert backend in…
frazane Jun 7, 2026
632ffcc
Drop unused array-api-extra dep, fix stale backend docstrings, silenc…
frazane Jun 7, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -153,4 +153,3 @@ _devlog/
tests/output
.devcontainer/
docs/generated
scoringrules/vendored
89 changes: 89 additions & 0 deletions docs/special_functions_audit.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Special-function support audit

`scoringrules` needs a handful of special functions that the array-API standard
does not provide. The extension layer (`scoringrules/backend/extensions.py`)
dispatches **native-first** — using each framework's own implementation where it
exists (so autograd and device placement are preserved) — and falls back to
`scipy.special` only for plain numpy.

This document records, per function and per framework, whether a usable path
exists for the **forward** evaluation and for the **gradient**. The values below
were produced by probing the actually-installed frameworks (jax, torch, scipy)
on this branch, not from memory.

## Legend

- **PASS** — a usable native path exists. For jax this is `jax.scipy.special`;
for torch it is either `torch.special` or a differentiable composition built
from native primitives (`exp`/`lgamma`); for numpy it is `scipy.special`. The
forward value is correct and, where the *grad* column says PASS, the gradient
is finite.
- **PARTIAL (numpy round-trip)** — only reachable by detouring through
numpy/scipy, which breaks the framework's native array type and autograd. No
special function currently needs this; the only numpy round-trip in the layer
is the `apply_along_axis` helper on torch (see *Non-standard helpers* below).
- **BLOCKED** — no native and no differentiable path. The scipy fallback cannot
be used because, on that framework, calling scipy raises (torch refuses
`numpy()` on grad-requiring tensors; see probe output). These functions must
receive numpy or jax inputs. The extension layer raises an explicit
`NotImplementedError` for them on torch.

## Probe results (this branch)

```
JAX native: erf T gamma T gammainc T gammaincc T beta T betainc T i0 T i1 T hyp2f1 T expi T factorial T comb F
TORCH native: erf T gamma F gammainc T gammaincc T beta F betainc F i0 T i1 T hyp2f1 F expi F factorial F comb F
torch scipy-fallback under grad: FAIL (RuntimeError: can't call numpy() on a tensor that requires grad)
jax scipy-fallback under grad: PASS (jax traces through scipy.special for these ufuncs)
```

Notes from the probe that shaped the implementation:

- `jax.scipy.special.hyp2f1` **exists and is correct** (the legacy backend had it
commented out): `hyp2f1(1,1,2,0.5) = 1.386294`, matching scipy.
- `jax.scipy.special.expi` accepts **0-d input** fine (`expi(1.0) = 1.895118`),
so no scalar-reshape workaround is needed.
- jax has **no** `comb`, and composing it as `factorial(n) // (factorial(k)·
factorial(n-k))` is wrong: jax `factorial` is float32, and floor-division of
`120.0001 // 12.00001` rounds to **9.0** instead of 10. The layer therefore
composes `comb` with floating division and `round`, which gives 10.0 on numpy,
jax, and torch.
- torch has no `gamma`/`beta`/`factorial`, but all three are built from the
native, differentiable `torch.lgamma`, so they are PASS (forward + grad).

## Support matrix

| fn | numpy | jax fwd | jax grad | torch fwd | torch grad | notes |
|----|-------|---------|----------|-----------|------------|-------|
| `erf` | PASS | PASS | PASS | PASS | PASS | `jax.scipy.special.erf` / `torch.special.erf` |
| `gamma` | PASS | PASS | PASS | PASS | PASS | torch via `exp(lgamma)` |
| `gammainc` (reg. lower) | PASS | PASS | PASS | PASS | PASS | `torch.special.gammainc` |
| `gammalinc` (unreg. lower) | PASS | PASS | PASS | PASS | PASS | composed `gammainc·gamma` |
| `gammauinc` (unreg. upper) | PASS | PASS | PASS | PASS | PASS | composed `gammaincc·gamma`; `torch.special.gammaincc` present |
| `beta` | PASS | PASS | PASS | PASS | PASS | torch via `lgamma` |
| `betainc` (reg. incomplete) | PASS | PASS | PASS | BLOCKED | BLOCKED | no torch native; scipy raises under grad |
| `mbessel0` (`i0`) | PASS | PASS | PASS | PASS | PASS | `torch.special.i0` |
| `mbessel1` (`i1`) | PASS | PASS | PASS | PASS | PASS | `torch.special.i1` |
| `hypergeometric` (`hyp2f1`) | PASS | PASS | PASS | BLOCKED | BLOCKED | jax: `jax.scipy.special.hyp2f1`; torch: none |
| `expi` | PASS | PASS | PASS | BLOCKED | BLOCKED | no torch native; jax 0-d works |
| `factorial` | PASS | PASS | PASS | PASS | PASS | torch via `exp(lgamma(n+1))` |
| `comb` | PASS | PASS | n/a | PASS | n/a | composed from `factorial` with `/ + round`; discrete, so gradient is not meaningful |

`n/a` for `comb` grad: `comb` is integer-valued and routed through `round`, so it
has no meaningful gradient (this matched the old `//` form, which was also
non-differentiable). `factorial` itself, built from the smooth `lgamma`, *is*
differentiable.

## Non-standard helpers

These are not special functions but are also provided by the extension layer:

- `apply_along_axis` — numpy/jax have native paths (jax via `vmap` with a
`jnp.apply_along_axis` fallback); **torch has no native equivalent**, so it is
the one **numpy round-trip** in the layer (PARTIAL): the per-slice callable is
evaluated through `numpy.apply_along_axis` and the result re-wrapped as a
tensor. This breaks autograd on torch and should be avoided in grad-sensitive
code paths; it also requires CPU tensors, since `numpy()` raises on CUDA
tensors.
- `cov` — `jnp.cov` / `torch.cov` (with `correction` mapped from `bias`) / `np.cov`.
- `indices` — `jnp.indices` / `np.indices` (torch wraps `np.indices`).
54 changes: 34 additions & 20 deletions docs/user_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,39 +37,53 @@ sr.variogram_score(obs, fct)
For the univariate ensemble metrics, the ensemble dimension is on the last axis unless you specify otherwise with the `axis` argument. For the multivariate ensemble metrics, the ensemble dimension and the variable dimension are on the second last and last axis respectively, unless specified otherwise with `m_axis` and `v_axis`.

## Backends
Scoringrules supports multiple backends. By default, the `numpy` and `numba` backends will be registered when importing the library. You can see the list of registered backends with

Scoringrules runs every score across multiple array frameworks — numpy, jax, and
torch — from a single implementation. The framework is **inferred from the input
arrays**: pass numpy, jax, or torch arrays and the result is returned in the same
framework, with no configuration required.

```python
print(sr.backends)
# {'numpy': <scoringrules.backend.numpy.NumpyBackend at 0x2ba2d6f391b0>,
# 'numba': <scoringrules.backend.numpy.NumbaBackend at 0x2ba2d6f38ac0>}
```
import numpy as np
import scoringrules as sr

and the currently active backend, used by default in all metrics, can be seen with
sr.crps_normal(np.array([0.1]), np.array([0.0]), np.array([1.0])) # numpy array in, numpy array out
```

```python
print(sr.backends.active)
# <scoringrules.backend.numpy.NumpyBackend at 0x2ba2d6f38ac0>
```
import jax.numpy as jnp

The default backend can also be changed with
sr.crps_normal(jnp.array([0.1]), jnp.array([0.0]), jnp.array([1.0])) # jax array out
```

```python
sr.backends.set_active("numba")
print(sr.backends.active)
# <scoringrules.backend.numpy.NumbaBackend at 0x2ba2d6f38ac0>
import torch

mu = torch.tensor([0.0], requires_grad=True)
sr.crps_normal(torch.tensor([0.1]), mu, torch.tensor([1.0])) # torch tensor out; autograd preserved
```
When computing a metric, the `backend` argument can be used to override the default choice.

Inputs must come from a single framework; mixing (e.g. a numpy observation with a
torch forecast) raises an error.

### The numba fast path

To register a new backend, for example `torch`, simply use
For numpy inputs, opt into the compiled [numba](https://numba.pydata.org/) gufuncs
with `backend="numba"`:

```python
sr.register_backend("torch")
sr.crps_ensemble(np.random.randn(5), np.random.randn(5, 11), backend="numba")
```

You can now use `torch` to compute metrics, either by setting it as the default backend or by specifying it on a specific metric:
`backend="numba"` requires numpy-compatible inputs.

```python
sr.crps_normal(0.1, 1.0, 0.0, backend="torch")
```
### Deprecations

Selecting an array-API backend explicitly is deprecated and will be removed in 1.0
— the framework is inferred from the input instead. This applies to:

- the `backend="numpy"`, `backend="jax"`, and `backend="torch"` arguments, and
- `sr.register_backend(...)` / `sr.backends.set_active(...)` for those backends.

`backend="numba"` is **not** deprecated; it remains the supported way to reach the
numba fast path.
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ requires-python = ">=3.12"
dependencies = [
"numpy>=2.0.0",
"scipy>=1.14.0",
"array-api-compat>=1.9",
]
authors = [
{name = "Francesco Zanetta", email = "zanetta.francesco@gmail.com"},
Expand Down Expand Up @@ -58,8 +59,10 @@ ignore = ["E741"]
[tool.coverage.run]
omit = [
"**/_gufuncs.py", # numba gufuncs are not python code
"**/_gufuncs_w.py",
"**/_gufunc.py",
"scoringrules/backend/*.py", # superfluous
"scoringrules/backend/base.py", # ABC, being removed
"scoringrules/backend/registry.py", # backend glue
"scoringrules/core/typing.py" # only type hints
]

Expand Down
Loading
Loading