Support structured diffusion specs in continuous-time models and enforce cd-dynamax full-diffusion constraints#214
Support structured diffusion specs in continuous-time models and enforce cd-dynamax full-diffusion constraints#214mattlevine22 wants to merge 6 commits intomainfrom
Conversation
There was a problem hiding this comment.
Pull request overview
This PR extends continuous-time SDE model support by introducing a structured diffusion specification API (constant scalar/diag/full or callable) and plumbing those semantics through validation, solvers, and the cd-dynamax integration path.
Changes:
- Add
dynestyx/models/diffusions.pyto standardize diffusion evaluation/inference (diffusion_type,bm_dim) and conversions (matrix/covariance/application). - Expand
ContinuousTimeStateEvolutionto acceptDiffusionSpec+ optionaldiffusion_type, update core validation/inference accordingly. - Update SDE solver internals and cd-dynamax integration utilities to use the shared diffusion semantics; expand test coverage across models/discretizers/plates.
Reviewed changes
Copilot reviewed 11 out of 11 changed files in this pull request and generated 3 comments.
Show a summary per file
| File | Description |
|---|---|
dynestyx/models/diffusions.py |
New helper module implementing evaluation, type inference, covariance/matrix conversion, and applying diffusion to Brownian increments. |
dynestyx/models/core.py |
Extends ContinuousTimeStateEvolution API to accept structured diffusion specs and optional diffusion_type. |
dynestyx/models/checkers.py |
Updates bm_dim inference/validation to use the new diffusion evaluation/metadata resolution helpers. |
dynestyx/solvers/sde.py |
Replaces ad-hoc diffusion handling with shared helpers for EM moments/sampling and diffrax diffusion callbacks. |
dynestyx/inference/integrations/cd_dynamax/utils.py |
Normalizes/validates diffusion for cd-dynamax using the new diffusion semantics and adds early validation. |
dynestyx/models/lti_dynamics.py |
Updates LTI_continuous to pass constant diffusion matrix directly using the new API. |
dynestyx/models/__init__.py |
Re-exports DiffusionType. |
tests/test_models_core.py |
Adds unit tests for new diffusion semantics and cd-dynamax diffusion constraints/normalization. |
tests/test_discretizers.py |
Adds parametrized tests ensuring structured diffusion forms match expected dense covariance behavior. |
tests/test_bm_dim_plate.py |
Expands plate-context tests to cover structured diffusion specs and expected bm_dim resolution behavior. |
tests/test_hierarchical_simulator_discretizer_smokes.py |
Updates/clarifies a comment related to rectangular diffusion behavior in cd-dynamax integration. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| bm_dim = diffusion_as_matrix(diffusion, state_dim=state_dim).shape[-1] | ||
| if bm_dim > state_dim: | ||
| raise ValueError( | ||
| "Continuous cd-dynamax filters require bm_dim <= state_dim. " | ||
| f"Got state_dim={state_dim}, bm_dim={bm_dim}." | ||
| ) |
There was a problem hiding this comment.
The PR description says continuous cd-dynamax filtering requires square full diffusion (bm_dim == state_dim / no rectangular diffusion), but the implemented validation only rejects bm_dim > state_dim and _normalize_cd_dynamax_diffusion explicitly pads when bm_dim < state_dim. Please reconcile the intended constraint: either tighten validation to require bm_dim == state_dim (and possibly diffusion_type == 'full') or update the PR description/docs to match the current bm_dim <= state_dim + zero-padding behavior.
There was a problem hiding this comment.
@copilot update the docs/descriptions to match the current behavior.
There was a problem hiding this comment.
Done in commit aa50c38. Updated the PR description Notes section and renamed the two misleading test functions to accurately reflect that only bm_dim > state_dim is rejected — rectangular diffusion with bm_dim < state_dim is zero-padded and supported.
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
…traint Agent-Logs-Url: https://github.com/BasisResearch/dynestyx/sessions/9646e756-8a5e-47b4-9c36-33cb58c2110e Co-authored-by: mattlevine22 <11492591+mattlevine22@users.noreply.github.com>
|
Thanks, this looks interesting!!
Out of curiosity, did you try for state estimation only or for parameter inference? |
Only did a forward pass of the filter; didn't use in parameter inference, and actually didn't even look at expense of gradients. Would you expect those to be noticeably different? |
@DanWaxman I looked at value_and_grad, but saw inconsistent improvements. The best I saw was like a 5% gain. I think this is because currently, we only do this multiply once-per-observation (and we have other operations that occur per-observation that dominate). When we have fancier discretization that take many steps between observations, it will be more impactful I think. I'd recommend for now introducing the API and leveraging in the simple SDE solver where we know it clearly helps (esp. in large dims), then leveraging diagonal structure as it becomes usable on the methods end. |
That makes sense! |
Addresses #212
Summary
dynestyx/models/diffusions.py) that standardizes evaluation and conversion of diffusion inputs.ContinuousTimeStateEvolutionAPI sodiffusion_coefficientcan be:L(x, u, t)and can be annotated with
diffusion_type="full" | "diag" | "scalar".bm_dimbehavior:(..., state_dim, bm_dim)if not provided1orstate_dimbm_dim != state_dim).LTI_continuousto use the new constant-matrix API directly (diffusion_coefficient=L,diffusion_type="full").Expanded Diffusion API (new behavior)
For
ContinuousTimeStateEvolution(...):diffusion_coefficientnow accepts DiffusionSpec:lambda x, u, t: ...diffusion_typeis optional:"full": treat value as matrix(..., state_dim, bm_dim)"diag": treat value as diagonal entries(..., state_dim)"scalar": treat value as scalar()or(..., 1)diffusion_typeis omitted, behavior is inferred from shape for backward compatibility:"full"state_dim->"diag"1->"scalar"bm_dimrules:"full""diag"and"scalar"; must be1orstate_dimTests
tests/test_models_core.pytests/test_discretizers.pytests/test_bm_dim_plate.pytests/test_hierarchical_simulator_discretizer_smokes.pyWhy
diffusion_coefficient = Lif you want).It also makes invalid setups fail early with clear errors, especially in continuous cd-dynamax filtering where only full square diffusion is currently supported.
Notes
bm_dim == state_dim.Below image showing speed-ups of solver using scalar/diag/full diffusion coefficient in SDE.
