Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/api_reference/public/inference/filter_configs.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ The single `Filter()` handler is directed to the appropriate filtering algorithm
| Config class | Time domain | When it fits best |
|----------------------------|---------------------|-------------------|
| `KFConfig` | Discrete | Linear-Gaussian dynamics and linear-Gaussian observations (exact & optimal). |
| `EKFConfig` | Discrete | Nonlinear, differentiable Gaussian dynamics, nonlinear but differentiable Gaussian observations (approximate). *(default)*. |
| `EnKFConfig` | Discrete | Nonlinear or expensive models with Gaussian observations; cuthbert-backed and a good general-purpose default. *(default)* |
| `EKFConfig` | Discrete | Nonlinear, differentiable Gaussian dynamics, nonlinear (and with `cuthbert`, non-Gaussian) but differentiable observations (approximate). |
| `UKFConfig` | Discrete | Nonlinear, differentiable Gaussian dynamics, nonlinear but differentiable Gaussian observations (approximate). Generally more accurate, but slower than `EKFConfig`. |
| `EnKFConfig` | Discrete | High-dimensional or expensive models with lower-dimensional structure and Gaussian observations (approximate). |
| `PFConfig` | Discrete | Applicable for arbitrary state-space models, but quite expensive and noisy estimates (asymptotically exact in the limit of infinite particles, approximate in practice). |
| `HMMConfig` | Discrete (HMM) | Finite discrete latent state space (exact & optimal). |
| `ContinuousTimeKFConfig` | Continuous-discrete | Linear-Gaussian SDE + linear-Gaussian observations (exact and optimal). |
Expand Down Expand Up @@ -48,4 +48,4 @@ The single `Filter()` handler is directed to the appropriate filtering algorithm
::: dynestyx.inference.filter_configs
options:
members:
- HMMConfig
- HMMConfig
76 changes: 43 additions & 33 deletions docs/deep_dives/discrete_time_lti_profile_likelihood.ipynb
Comment thread
mattlevine22 marked this conversation as resolved.

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions docs/faq.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@ with Filter(filter_config=HMMConfig()):
return model(obs_times=obs_times, obs_values=obs_values)
```

- **Discrete-time**: Either a **Simulator** (NUTS samples both parameters and latent states) or a **Filter** (pseudo-marginal MCMC—parameters only). Note: the usage of discrete-time filters is currently under active development (likely incorrect implementations).
- **Discrete-time**: Either a **Simulator** (NUTS samples both parameters and latent states) or a **Filter** (parameters only, with latent states marginalized by a filtering algorithm). `Filter()` defaults to the cuthbert-backed EnKF for Gaussian observation models. Use `PFConfig` when you need non-Gaussian observations or a fully particle-based approximation.
For explicit representation of latent states (NUTS / SVI do all the work of parameter and latent state inference), use the simulator approach (currently working reliably), do:
```python
with DiscreteTimeSimulator():
return model(obs_times=obs_times, obs_values=obs_values)
```
For filter-based marginalization (currently not working reliably), do:
For filter-based marginalization with the default EnKF, do:
```python
with Filter():
return model(obs_times=obs_times, obs_values=obs_values)
Expand Down
2 changes: 1 addition & 1 deletion docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ Other JAX-based libraries for dynamical systems:
- **[dynamax](https://github.com/probml/dynamax)** — Discrete-time state space models with linear/non-linear Kalman filters and Bayesian parameter estimation
- **[cd-dynamax](https://github.com/hd-UQ/cd_dynamax)** — Continuous-discrete state space models with EnKF, EKF, UKF, PF and Bayesian parameter estimation
- **[PFJax](https://pfjax.readthedocs.io/en/latest/)** — Nonlinear and non-Gaussian discrete-time models with particle filters and particle MCMC
- **[Cuthbert](https://state-space-models.github.io/cuthbert/)** — Discrete-time state space models with linear/non-linear Kalman (and Particle Filters) filters, options for associative scans.
- **[Cuthbert](https://state-space-models.github.io/cuthbert/)** — Discrete-time state space models with linear/non-linear Kalman, ensemble Kalman, and particle filters, plus options for associative scans.
- **[diffrax](https://docs.kidger.site/diffrax/)** - Numerical differential equation solvers.

Other probabilistic programming languages with support for dynamical systems:
Expand Down
19 changes: 11 additions & 8 deletions dynestyx/inference/filter_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,17 @@ class BaseFilterConfig:
class EnKFConfig(BaseFilterConfig):
r"""Ensemble Kalman Filter (EnKF) for discrete-time models.

A good general-purpose filter for nonlinear models. Works with any
The **default filter** for discrete-time models. A good general-purpose
filter for nonlinear models with Gaussian observations. Works with any
differentiable or non-differentiable dynamics and scales well to moderate
state dimensions. Cheaper per-step than the particle filter, but assumes
observations are approximately Gaussian given the ensemble.

The observation noise covariance must be **state-independent** (it may
still depend on time or controls). Using a state-dependent scale with the
cuthbert backend raises a `ValueError`; if you need heteroscedastic noise,
use `PFConfig` instead.

The primary tuning knob is `n_particles`, with more particles providing
more accurate results at the cost of higher compute.
If the ensemble collapses over long trajectories, increase
Expand All @@ -108,7 +114,7 @@ class EnKFConfig(BaseFilterConfig):
inflation_delta (float | None): Scale ensemble anomalies by
\(\sqrt{1 + \delta}\) before the update to prevent collapse.
`None` disables inflation.
filter_source (FilterSource): Backend. Defaults to `"cd_dynamax"`.
filter_source (FilterSource): Backend. Defaults to `"cuthbert"`.

??? note "Algorithm Reference"
The ensemble Kalman filter comprises ensemble members $x_t^{(i)}, i = 1, \ldots, N_{\text{particles}}$.
Expand Down Expand Up @@ -159,7 +165,7 @@ class EnKFConfig(BaseFilterConfig):
)
perturb_measurements: bool | None = None
inflation_delta: float | None = None
filter_source: CDDynamaxOnlyFilterSource = "cd_dynamax"
filter_source: CuthbertOnlyFilterSource = "cuthbert"


@dataclasses.dataclass
Expand Down Expand Up @@ -282,9 +288,6 @@ class EKFConfig(BaseFilterConfig):

This is exact (but wasteful) for linear-Gaussian models.

This is the **default discrete-time filter** when no `filter_config` is
passed to `Filter`.

Attributes:
filter_emission_order (FilterEmissionOrder): Linearisation order for
the observation function. `"first"` *(default)* is the standard
Expand Down Expand Up @@ -375,7 +378,7 @@ class KFConfig(BaseFilterConfig):
- For more details on the `cuthbert` implementation, see the [cuthbert documentation](https://state-space-models.github.io/cuthbert/cuthbert_api/gaussian/kalman/).
"""

filter_source: CDDynamaxOnlyFilterSource = "cd_dynamax"
filter_source: CuthbertOrCDDynamaxFilterSource = "cd_dynamax"


@dataclasses.dataclass
Expand Down Expand Up @@ -520,7 +523,7 @@ class ContinuousTimeEnKFConfig(EnKFConfig, ContinuousTimeConfig):
[Available Online](https://epubs.siam.org/doi/abs/10.1137/21M1434477).
"""

filter_source: CDDynamaxOnlyFilterSource = "cd_dynamax"
filter_source: CDDynamaxOnlyFilterSource = "cd_dynamax" # type: ignore[assignment]


@dataclasses.dataclass
Expand Down
26 changes: 8 additions & 18 deletions dynestyx/inference/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
)
from dynestyx.inference.integrations.utils import (
WeightedParticles,
covariance_from_cholesky,
particles_to_delta_mixtures,
)
from dynestyx.models import DynamicalModel
Expand Down Expand Up @@ -112,25 +113,15 @@ def _cuthbert_states_to_dists(
if isinstance(config, PFConfig):
particles = states.particles
log_weights = states.log_weights
# cuthbert includes an init step at index 0; align with dynestyx T convention.
particles = particles[
(slice(None),) * len(plate_shapes) + (slice(1, None), ...)
]
log_weights = log_weights[
(slice(None),) * len(plate_shapes) + (slice(1, None), ...)
]
return _particle_to_batched_dists(
particles,
log_weights,
plate_shapes=plate_shapes,
)

# Kalman / Taylor-KF variants expose mean/chol_cov and include init at index 0.
mean = states.mean[(slice(None),) * len(plate_shapes) + (slice(1, None), ...)]
chol_cov = states.chol_cov[
(slice(None),) * len(plate_shapes) + (slice(1, None), ...)
]
cov = jnp.matmul(chol_cov, jnp.swapaxes(chol_cov, -1, -2))
mean = states.mean
chol_cov = states.chol_cov
cov = covariance_from_cholesky(chol_cov)
t_len = _time_len_from_array(mean, plate_shapes)
return [
numpyro.distributions.MultivariateNormal(
Expand Down Expand Up @@ -297,8 +288,7 @@ def _default_filter_config(dynamics: DynamicalModel):
if dynamics.continuous_time:
return ContinuousTimeEnKFConfig()

# default to particle filter in discrete time
return EKFConfig(filter_source="cuthbert")
return EnKFConfig()


@dataclasses.dataclass
Expand Down Expand Up @@ -342,7 +332,7 @@ class Filter(BaseLogFactorAdder):
If `filter_config=None`, defaults are:

- `ContinuousTimeEnKFConfig()` for continuous-time models, and
- `EKFConfig(filter_source="cuthbert")` for discrete-time models.
- `EnKFConfig()` for discrete-time models.

Notes:
- If your latent state is *discrete* (an HMM), you must use `HMMConfig`.
Expand Down Expand Up @@ -643,8 +633,8 @@ def _filter_discrete_time(
) -> list[numpyro.distributions.Distribution]:
"""Discrete-time marginal likelihood via cuthbert or cd-dynamax.

Filter type inferred from config class: KFConfig, EKFConfig, UKFConfig (cd-dynamax)
or EKFConfig (cuthbert), PFConfig (cuthbert).
Filter type inferred from config class: KFConfig, EKFConfig, UKFConfig
(cd-dynamax) or KFConfig, EKFConfig, EnKFConfig, PFConfig (cuthbert).

Args:
name: Name of the factor.
Expand Down
Loading
Loading