
markusheinonen/shadow


Shadow

Minimal Bayesian probabilistic programming language (PPL) for JAX, inspired by DenseJax.

Write a plain log-density function, declare parameter constraints, get HMC samples.

import jax.numpy as jnp
import shadow

y = jnp.array([1.0, 1.5, 2.0, 1.8, 1.2])

def log_pdf(mu, sigma):
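    # prior: mu ~ Normal(0, 10), up to an additive constant
    lp  = -0.5 * (mu / 10.0) ** 2
    # prior: sigma ~ Exponential(rate=0.5), up to an additive constant
    lp += -0.5 * sigma
    # likelihood: y ~ Normal(mu, sigma), up to an additive constant
    lp += -jnp.sum(jnp.log(sigma) + 0.5 * ((y - mu) / sigma) ** 2)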
    return lp

draws = shadow.sample(log_pdf, sigma=shadow.positive, num_draws=1000)
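Because NUTS follows gradients of the log-density, log_pdf has to be written with JAX operations that jax.grad can differentiate. A quick pre-flight check is to evaluate the model and its gradient with jax.value_and_grad at a test point on the constrained scale. The sketch below repeats the model above so it runs on its own; the test point (mu=1.5, sigma=1.0) is arbitrary.

```python
import jax
import jax.numpy as jnp

y = jnp.array([1.0, 1.5, 2.0, 1.8, 1.2])

def log_pdf(mu, sigma):
    lp = -0.5 * (mu / 10.0) ** 2                                      # mu prior
    lp += -0.5 * sigma                                                # sigma prior
    lp += -jnp.sum(jnp.log(sigma) + 0.5 * ((y - mu) / sigma) ** 2)   # likelihood
    return lp

# Log-density value and its gradient w.r.t. (mu, sigma) at a test point.
value, (grad_mu, grad_sigma) = jax.value_and_grad(log_pdf, argnums=(0, 1))(1.5, 1.0)
print(value, grad_mu, grad_sigma)  # approximately -0.851, -0.015, -4.82
```

A finite scalar value and finite gradients are a cheap sign that the function is ready to hand to shadow.sample.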

That's it. No model class, no random variable objects, no DSL. Shadow automatically maps constrained parameters into unconstrained space and samples in that shadow domain via NUTS.

How it works

  1. You write log_pdf(mu, sigma) using constrained parameter values and normal JAX math.
  2. You call shadow.sample(log_pdf, sigma=shadow.positive) — declaring which arguments have constraints.
  3. Shadow inspects the function signature, wraps the function with a bijective transform plus a log-det-Jacobian correction to obtain a target on unconstrained space, and runs a pure-JAX NUTS sampler with windowed adaptation on that target.
  4. Draws come back on the constrained (natural) scale as a dict[str, ndarray].
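Step 3 is the standard change-of-variables trick. For a positive parameter written as sigma = exp(u), the density over the unconstrained u picks up the log-Jacobian term log|d sigma / du| = u. The minimal sketch below uses a hypothetical helper, exp_transform_target (not part of Shadow's API), and checks numerically that with the correction an Exponential(1) density on sigma becomes a properly normalized density on u, while without it the result is not a density at all.

```python
import jax
import jax.numpy as jnp

def exp_transform_target(log_pdf):
    """Hypothetical helper: turn a log-density over a positive scalar sigma
    into a log-density over u = log(sigma), which ranges over all of R."""
    def target(u):
        sigma = jnp.exp(u)   # constrain: R -> (0, inf)
        log_det_jac = u      # log |d sigma / d u| = log(exp(u)) = u
        return log_pdf(sigma) + log_det_jac
    return target

# Example: sigma ~ Exponential(1) has log-density -sigma on (0, inf).
target = exp_transform_target(lambda sigma: -sigma)

# Riemann sum over a wide grid of u values.
u_grid = jnp.linspace(-20.0, 5.0, 20001)
du = u_grid[1] - u_grid[0]
mass = jnp.sum(jnp.exp(jax.vmap(target)(u_grid))) * du               # close to 1
mass_without = jnp.sum(jnp.exp(-jnp.exp(u_grid))) * du                # no Jacobian: about 19
```

The sampler only ever sees target(u); the draws are mapped back through sigma = exp(u), which is why they are returned on the constrained scale.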

Parameters with no declared constraint default to shadow.real (unconstrained).
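The NUTS sampler in step 3 builds each trajectory out of leapfrog steps of Hamiltonian dynamics on the unconstrained target. The sketch below shows only that inner integrator, assuming an identity mass matrix; the tree building, U-turn criterion and adaptation of a full NUTS implementation are omitted, and the name leapfrog is illustrative rather than Shadow's.

```python
import jax
import jax.numpy as jnp

def leapfrog(logp, x, p, step_size, num_steps):
    """Leapfrog integration of Hamiltonian dynamics with H(x, p) = -logp(x) + |p|^2 / 2."""
    grad = jax.grad(logp)
    p = p + 0.5 * step_size * grad(x)          # initial half step for the momentum

    def body(_, state):
        x, p = state
        x = x + step_size * p                  # full step for the position
        p = p + step_size * grad(x)            # full step for the momentum
        return x, p

    x, p = jax.lax.fori_loop(0, num_steps - 1, body, (x, p))
    x = x + step_size * p                      # last full position step
    p = p + 0.5 * step_size * grad(x)          # final half step for the momentum
    return x, p

# Usage on a standard normal target.
logp = lambda x: -0.5 * jnp.sum(x ** 2)
x0, p0 = jnp.array([1.0]), jnp.array([0.5])
x1, p1 = leapfrog(logp, x0, p0, step_size=0.1, num_steps=20)
```

Leapfrog is time-reversible and nearly conserves the Hamiltonian, which is what lets HMC and NUTS make long moves that are still accepted with high probability.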

Install

pip install -e .

Requires Python ≥ 3.10, JAX ≥ 0.4.20. No other dependencies.

Constraints

| Constraint | Example | Domain |
|---|---|---|
| shadow.real | (default) | $\mathbb{R}$ |
| shadow.positive | sigma=shadow.positive | $(0, \infty)$ via exp |
| shadow.softplus | sigma=shadow.softplus | $(0, \infty)$ via softplus |
| shadow.unit_interval | p=shadow.unit_interval | $(0, 1)$ |
| shadow.ordered | cuts=shadow.ordered | sorted vector |
| shadow.simplex | theta=shadow.simplex | sums to 1 |
| shadow.corr_cholesky | L=shadow.corr_cholesky | Cholesky factor of a correlation matrix |
| shadow.cov_cholesky | L=shadow.cov_cholesky | Cholesky factor of a covariance matrix |
| shadow.lower_bounded(a) | x=shadow.lower_bounded(2.0) | $(a, \infty)$ |
| shadow.upper_bounded(b) | x=shadow.upper_bounded(10.0) | $(-\infty, b)$ |
| shadow.bounded(a, b) | x=shadow.bounded(0, 10) | $(a, b)$ |
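As an illustration of what one row of this table involves, here is a hypothetical bounded(a, b) constraint written against the constrain / unconstrain / log_det_jacobian protocol described under Custom constraints. It uses a scaled logistic sigmoid, a common choice; Shadow's own shadow.bounded may use a different transform.

```python
import jax
import jax.numpy as jnp

class Bounded:
    """Hypothetical sketch: R -> (a, b) via a scaled sigmoid (not Shadow's code)."""

    def __init__(self, a, b):
        self.a, self.b = a, b

    def constrain(self, x):
        # unconstrained x -> y in (a, b)
        return self.a + (self.b - self.a) * jax.nn.sigmoid(x)

    def unconstrain(self, y):
        # y in (a, b) -> x = logit((y - a) / (b - a))
        p = (y - self.a) / (self.b - self.a)
        return jnp.log(p) - jnp.log1p(-p)

    def log_det_jacobian(self, x):
        # dy/dx = (b - a) * sigmoid(x) * (1 - sigmoid(x)), and 1 - sigmoid(x) = sigmoid(-x)
        return jnp.sum(jnp.log(self.b - self.a) + jax.nn.log_sigmoid(x) + jax.nn.log_sigmoid(-x))

c = Bounded(0.0, 10.0)
x = jnp.array([-3.0, 0.0, 2.5])
y = c.constrain(x)   # every value lies strictly inside (0, 10); x = 0 maps to 5.0
```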

sample() API

shadow.sample(
    log_density_fn,       # callable: params -> scalar log-density
    *,
    init=None,            # dict of initial values (constrained scale)
    num_draws=1000,
    num_warmup=1000,
    num_chains=1,
    seed=0,
    **constraints,        # e.g. sigma=shadow.positive
)
  • init — required when any parameter is non-scalar (so shape can be inferred). Values are on the constrained scale.
  • **constraints — keyword arguments mapping parameter names to constraint objects. Unmentioned parameters default to shadow.real.

Returns a dict[str, jnp.ndarray] mapping each parameter name to an array of shape (num_draws, *param_shape).
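Since every entry of the returned dict stacks the draws along axis 0, posterior summaries reduce over that axis. A minimal sketch with a hypothetical summarize helper (not part of Shadow), exercised on synthetic arrays that stand in for real draws:

```python
import jax.numpy as jnp

def summarize(draws, probs=(0.05, 0.5, 0.95)):
    """Hypothetical helper: posterior mean, sd and quantiles per parameter,
    for arrays of shape (num_draws, *param_shape)."""
    out = {}
    for name, x in draws.items():
        out[name] = {
            "mean": jnp.mean(x, axis=0),                              # shape param_shape
            "sd": jnp.std(x, axis=0),                                 # shape param_shape
            "quantiles": jnp.quantile(x, jnp.array(probs), axis=0),   # (len(probs), *param_shape)
        }
    return out

# Synthetic stand-ins for shadow.sample output.
fake = {"mu": jnp.linspace(-1.0, 1.0, 1000), "W": jnp.zeros((1000, 3, 2))}
summary = summarize(fake)
```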

Array parameters

For non-scalar parameters (vectors, matrices), pass init so shadow knows the shape:

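import jax
import jax.numpy as jnp
import shadow

# Placeholder data for illustration: 50 observations, 3 inputs, 2 outputs
X = jax.random.normal(jax.random.PRNGKey(0), (50, 3))
Y = jax.random.normal(jax.random.PRNGKey(1), (50, 2))

# W is a (3, 2) weight matrix, b is a (2,) bias vector
def log_pdf(W, b, sigma):
    pred = X @ W + b                                  # (50, 2) predictions
    lp  = -0.5 * jnp.sum((W / 5.0) ** 2)              # W ~ Normal(0, 5)
    lp += -0.5 * jnp.sum((b / 5.0) ** 2)              # b ~ Normal(0, 5)
    lp += -sigma                                      # sigma ~ Exponential(1)
    lp += -jnp.sum(jnp.log(sigma) + 0.5 * ((Y - pred) / sigma) ** 2)  # likelihood
    return lp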

draws = shadow.sample(
    log_pdf,
    sigma=shadow.positive,
    init={"W": jnp.zeros((3, 2)), "b": jnp.zeros(2), "sigma": 1.0},
)
# draws["W"].shape == (1000, 3, 2)
# draws["b"].shape == (1000, 2)
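One reason init is needed: a NUTS sampler moves a single flat vector, so named, differently shaped parameters must be packed into one array and unpacked again before calling the log-density. The sketch below shows one plausible way to do that in JAX with jax.flatten_util.ravel_pytree; whether Shadow does exactly this is an assumption.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# The init dict from the example above (sigma given as a 0-d array).
init = {"W": jnp.zeros((3, 2)), "b": jnp.zeros(2), "sigma": jnp.array(1.0)}

# Flatten all parameters into one vector and keep a function that restores the shapes.
flat, unravel = ravel_pytree(init)   # flat has 3*2 + 2 + 1 = 9 entries
restored = unravel(flat)             # same keys and shapes as init
```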

Convenience distributions

Shadow also ships optional log-density helpers if you prefer readable notation over raw math:

def log_pdf(mu, sigma):
    lp  = shadow.normal(mu, 0, 10)        # mu ~ Normal(0, 10)
    lp += shadow.exponential(sigma, 0.5)  # sigma ~ Exponential(0.5)
    lp += shadow.normal(y, mu, sigma)     # y ~ Normal(mu, sigma)
    return lp

These are thin wrappers around JAX math — no tracing or side effects. Available: normal, half_normal, log_normal, exponential, gamma, inv_gamma, beta_dist, uniform, cauchy, student_t, double_exponential, bernoulli_logit, poisson, binomial_logit, neg_binomial_2, multi_normal, dirichlet, lkj_corr_cholesky.
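For intuition, two of these wrappers could look roughly like this on top of jax.scipy.stats. The function bodies are assumptions, not Shadow's code, including the rate parameterization of exponential. Unlike the hand-written first example, they include normalizing constants; an additive constant in the log-density changes neither the posterior nor the gradients NUTS uses.

```python
import jax.numpy as jnp
from jax.scipy.stats import expon, norm

def normal(x, mu, sigma):
    # Hypothetical stand-in for shadow.normal: Normal log-density summed over all elements of x.
    return jnp.sum(norm.logpdf(x, loc=mu, scale=sigma))

def exponential(x, rate):
    # Hypothetical stand-in for shadow.exponential, assuming a rate parameterization.
    return jnp.sum(expon.logpdf(x, scale=1.0 / rate))
```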

Custom constraints

Any object with constrain, unconstrain, and log_det_jacobian methods works as a constraint — no base class needed:

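import jax
import jax.numpy as jnp

class MyPositive:
    """R -> (1e-6, inf) via softplus, shifted by a small offset to stay away from zero."""
    def constrain(self, x):
        # unconstrained x -> constrained y
        return jax.nn.softplus(x) + 1e-6
    def unconstrain(self, y):
        # inverse softplus log(exp(z) - 1), written stably as z + log(1 - exp(-z))
        z = y - 1e-6
        return z + jnp.log(-jnp.expm1(-z))
    def log_det_jacobian(self, x):
        # d softplus(x) / dx = sigmoid(x), so log|J| = sum(log_sigmoid(x))
        return jnp.sum(jax.nn.log_sigmoid(x))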

draws = shadow.sample(log_pdf, sigma=MyPositive())
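Because a custom constraint's log_det_jacobian must agree with the actual derivative of constrain, it is worth checking it against autodiff before sampling. A minimal sketch, assuming an elementwise transform and that log_det_jacobian takes the unconstrained x, as MyPositive does; check_constraint and ExpPositive are hypothetical names used only for illustration.

```python
import jax
import jax.numpy as jnp

def check_constraint(c, x, atol=1e-4):
    """Hypothetical test helper: returns (round_trip_ok, log_det_jacobian_ok)."""
    y = c.constrain(x)
    roundtrip_ok = jnp.allclose(c.unconstrain(y), x, atol=atol)
    # Elementwise derivative of the scalar transform, computed by autodiff.
    dydx = jax.vmap(jax.grad(c.constrain))(x)
    ldj_ok = jnp.allclose(c.log_det_jacobian(x), jnp.sum(jnp.log(jnp.abs(dydx))), atol=atol)
    return bool(roundtrip_ok), bool(ldj_ok)

class ExpPositive:
    """Example constraint R -> (0, inf) via exp, used only to exercise check_constraint."""
    def constrain(self, x):
        return jnp.exp(x)
    def unconstrain(self, y):
        return jnp.log(y)
    def log_det_jacobian(self, x):
        return jnp.sum(x)   # log |d exp(x) / dx| = x

print(check_constraint(ExpPositive(), jnp.array([-2.0, 0.0, 1.5])))  # (True, True)
```

The same call works on MyPositive() or any other object that follows the protocol.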

Tests

pip install -e ".[test]"
pytest tests/ -v                     # fast tests only
pytest tests/ -v -m slow             # NUTS integration tests
pytest tests/ -v -m "not slow"       # skip slow tests

License

MIT

About

Minimal Bayesian PPL with pure-JAX NUTS sampling
