A minimal Bayesian PPL for JAX, inspired by DenseJax. Write a plain log-density function, declare parameter constraints, and get HMC samples.
```python
import jax.numpy as jnp
import shadow

y = jnp.array([1.0, 1.5, 2.0, 1.8, 1.2])

def log_pdf(mu, sigma):
    # prior: mu ~ Normal(0, 10)
    lp = -0.5 * (mu / 10.0) ** 2
    # prior: sigma ~ Exponential(0.5)
    lp += -0.5 * sigma
    # likelihood: y ~ Normal(mu, sigma)
    lp += -jnp.sum(jnp.log(sigma) + 0.5 * ((y - mu) / sigma) ** 2)
    return lp

draws = shadow.sample(log_pdf, sigma=shadow.positive, num_draws=1000)
```

That's it. No model class, no random-variable objects, no DSL. Shadow automatically maps constrained parameters into unconstrained space and samples in that shadow domain via NUTS.
- You write `log_pdf(mu, sigma)` using constrained parameter values and plain JAX math.
- You call `shadow.sample(log_pdf, sigma=shadow.positive)`, declaring which arguments have constraints.
- Shadow inspects the function signature, wraps it with a bijective transform plus a log-det-Jacobian correction to create an unconstrained target, and runs a pure-JAX NUTS sampler with windowed adaptation.
- Draws come back on the constrained (natural) scale as a `dict[str, jnp.ndarray]`.
Parameters with no declared constraint default to `shadow.real` (unconstrained).
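Under the hood, the unconstrained target is just the original log-density composed with the constraining map, plus a log-det-Jacobian term. A minimal sketch of that construction for one positive parameter, using an exp transform as an illustrative choice (not necessarily shadow's exact bijector):

```python
import jax.numpy as jnp

y = jnp.array([1.0, 1.5, 2.0, 1.8, 1.2])

def log_pdf(mu, sigma):
    # same constrained-space density as the quickstart above
    lp = -0.5 * (mu / 10.0) ** 2
    lp += -0.5 * sigma
    lp += -jnp.sum(jnp.log(sigma) + 0.5 * ((y - mu) / sigma) ** 2)
    return lp

def unconstrained_log_pdf(mu, u):
    # sigma = exp(u) maps R -> (0, inf); log|d sigma / d u| = u
    return log_pdf(mu, jnp.exp(u)) + u
```

NUTS runs on `unconstrained_log_pdf`, and the transform is inverted to report draws on the natural scale.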
```
pip install -e .
```

Requires Python ≥ 3.10 and JAX ≥ 0.4.20. No other dependencies.
| Constraint | Example | Domain |
|---|---|---|
| `shadow.real` | (default) | ℝ |
| `shadow.positive` | `sigma=shadow.positive` | (0, ∞) |
| `shadow.softplus` | `sigma=shadow.softplus` | (0, ∞) |
| `shadow.unit_interval` | `p=shadow.unit_interval` | (0, 1) |
| `shadow.ordered` | `cuts=shadow.ordered` | sorted vector |
| `shadow.simplex` | `theta=shadow.simplex` | non-negative, sums to 1 |
| `shadow.corr_cholesky` | `L=shadow.corr_cholesky` | Cholesky factor of a correlation matrix |
| `shadow.cov_cholesky` | `L=shadow.cov_cholesky` | Cholesky factor of a covariance matrix |
| `shadow.lower_bounded(a)` | `x=shadow.lower_bounded(2.0)` | (a, ∞) |
| `shadow.upper_bounded(b)` | `x=shadow.upper_bounded(10.0)` | (−∞, b) |
| `shadow.bounded(a, b)` | `x=shadow.bounded(0, 10)` | (a, b) |
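As a concrete example of a transform pair, an interval constraint like `bounded(a, b)` is commonly implemented as a scaled-and-shifted sigmoid. A sketch of such a forward map and its log-det-Jacobian (an illustrative parameterization, not necessarily the one shadow uses internally):

```python
import jax
import jax.numpy as jnp

def constrain_bounded(x, a, b):
    # R -> (a, b) via a scaled sigmoid
    return a + (b - a) * jax.nn.sigmoid(x)

def log_det_jacobian_bounded(x, a, b):
    # d/dx [a + (b - a) * sigmoid(x)] = (b - a) * sigmoid(x) * (1 - sigmoid(x))
    return jnp.log(b - a) + jax.nn.log_sigmoid(x) + jax.nn.log_sigmoid(-x)
```

The log-det-Jacobian is what keeps the density correct after the change of variables; it can always be cross-checked against `jax.grad` of the forward map.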
```python
shadow.sample(
    log_density_fn,   # callable: params -> scalar log-density
    *,
    init=None,        # dict of initial values (constrained scale)
    num_draws=1000,
    num_warmup=1000,
    num_chains=1,
    seed=0,
    **constraints,    # e.g. sigma=shadow.positive
)
```

- `init`: required when any parameter is non-scalar (so shapes can be inferred). Values are on the constrained scale.
- `**constraints`: keyword arguments mapping parameter names to constraint objects. Unmentioned parameters default to `shadow.real`.
Returns a `dict[str, jnp.ndarray]` with shape `(num_draws, *param_shape)` per parameter.
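Because the return value is a plain dict of arrays, posterior summaries are ordinary reductions over axis 0. A sketch, with synthetic arrays standing in for real sampler output:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
# stand-in for shadow.sample output: draws stacked along axis 0
draws = {
    "mu": jax.random.normal(key, (1000,)),
    "W": jax.random.normal(key, (1000, 3, 2)),
}

post_mean = {k: v.mean(axis=0) for k, v in draws.items()}
# central 90% interval per parameter element
post_ci = {k: jnp.quantile(v, jnp.array([0.05, 0.95]), axis=0) for k, v in draws.items()}
```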
For non-scalar parameters (vectors, matrices), pass `init` so shadow knows the shape:
```python
# W is a (3, 2) weight matrix, b is a (2,) bias vector
def log_pdf(W, b, sigma):
    pred = X @ W + b
    lp = -0.5 * jnp.sum((W / 5.0) ** 2)    # prior: W ~ Normal(0, 5)
    lp += -0.5 * jnp.sum((b / 5.0) ** 2)   # prior: b ~ Normal(0, 5)
    lp += -sigma                           # prior: sigma ~ Exponential(1)
    lp += -jnp.sum(jnp.log(sigma) + 0.5 * ((Y - pred) / sigma) ** 2)
    return lp

draws = shadow.sample(
    log_pdf,
    sigma=shadow.positive,
    init={"W": jnp.zeros((3, 2)), "b": jnp.zeros(2), "sigma": 1.0},
)

# draws["W"].shape == (1000, 3, 2)
# draws["b"].shape == (1000, 2)
```

Shadow also ships optional log-density helpers if you prefer readable notation over raw math:
```python
def log_pdf(mu, sigma):
    lp = shadow.normal(mu, 0, 10)         # mu ~ Normal(0, 10)
    lp += shadow.exponential(sigma, 0.5)  # sigma ~ Exponential(0.5)
    lp += shadow.normal(y, mu, sigma)     # likelihood
    return lp
```

These are thin wrappers around JAX math — no tracing or side effects. Available: `normal`, `half_normal`, `log_normal`, `exponential`, `gamma`, `inv_gamma`, `beta_dist`, `uniform`, `cauchy`, `student_t`, `double_exponential`, `bernoulli_logit`, `poisson`, `binomial_logit`, `neg_binomial_2`, `multi_normal`, `dirichlet`, `lkj_corr_cholesky`.
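For a sense of what these helpers reduce to, here is a plausible implementation of the normal helper, correct up to an additive constant (named `normal_lp` here to make clear it is a sketch, not shadow's actual code):

```python
import jax.numpy as jnp

def normal_lp(x, loc, scale):
    # sum of unnormalized Normal(loc, scale) log-densities over all elements of x
    return -jnp.sum(jnp.log(scale) + 0.5 * ((x - loc) / scale) ** 2)
```

Dropping the `-0.5 * log(2 * pi)` constant is harmless for MCMC, which only needs the log-density up to an additive constant.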
Any object with `constrain`, `unconstrain`, and `log_det_jacobian` methods works as a constraint — no base class needed:
```python
import jax
import jax.numpy as jnp

class MyPositive:
    """R -> R+ via softplus with a custom shift."""

    def constrain(self, x):
        return jax.nn.softplus(x) + 1e-6

    def unconstrain(self, y):
        # inverse softplus, log(expm1(z)), written stably
        z = y - 1e-6
        return z + jnp.log(-jnp.expm1(-z))

    def log_det_jacobian(self, x):
        # d/dx softplus(x) = sigmoid(x)
        return jnp.sum(jax.nn.log_sigmoid(x))
```
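Before handing a custom constraint to the sampler, it is worth sanity-checking it: the round trip `unconstrain(constrain(x))` should recover `x`, and `log_det_jacobian` should match autodiff of the forward map. A sketch of such a check, using a hypothetical minimal exp-based constraint so the snippet is self-contained:

```python
import jax
import jax.numpy as jnp

def check_constraint(c, x):
    """Round-trip and autodiff sanity checks for a scalar constraint object."""
    y = c.constrain(x)
    # unconstrain must invert constrain
    assert jnp.allclose(c.unconstrain(y), x, atol=1e-5)
    # declared log-det-Jacobian must match autodiff of the forward map
    assert jnp.allclose(jnp.log(jax.grad(c.constrain)(x)), c.log_det_jacobian(x), atol=1e-5)

class ExpPositive:
    """Hypothetical example: R -> R+ via exp."""
    def constrain(self, x):
        return jnp.exp(x)
    def unconstrain(self, y):
        return jnp.log(y)
    def log_det_jacobian(self, x):
        return jnp.sum(x)  # log|d exp(x)/dx| = x

check_constraint(ExpPositive(), jnp.array(0.7))
```

The same `check_constraint` helper (an illustration, not part of shadow) applies to any object implementing the three-method protocol.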
Pass an instance to `shadow.sample`:

```python
draws = shadow.sample(log_pdf, sigma=MyPositive())
```

To work on shadow itself, install with the test extra:

```
pip install -e ".[test]"
```

```
pytest tests/ -v                  # fast tests only
pytest tests/ -v -m slow          # NUTS integration tests
pytest tests/ -v -m "not slow"    # skip slow tests
```

MIT