
Support structured diffusion specs in continuous-time models and enforce cd-dynamax full-diffusion constraints #214

Open
mattlevine22 wants to merge 6 commits into main from ml-feature-212

Conversation

@mattlevine22
Collaborator

@mattlevine22 mattlevine22 commented Apr 27, 2026

Addresses #212

Summary

  • Adds a new diffusion helper module (dynestyx/models/diffusions.py) that standardizes evaluation and conversion of diffusion inputs.
  • Expands the ContinuousTimeStateEvolution API so diffusion_coefficient can be any of the following, optionally annotated with diffusion_type = "full" | "diag" | "scalar":
    • a callable L(x, u, t)
    • a constant scalar
    • a constant vector
    • a constant matrix
  • Clarifies bm_dim behavior:
    • full diffusion: inferred from shape (..., state_dim, bm_dim) if not provided
    • diag/scalar diffusion: must be provided explicitly and must be 1 or state_dim
  • Updates SDE solver internals to use the same diffusion semantics everywhere (EM moments, sampling, and diffrax solve path).
  • Updates cd-dynamax continuous integration to validate diffusion shape up front and reject unsupported rectangular diffusion for continuous filters (bm_dim != state_dim).
  • Updates LTI_continuous to use the new constant-matrix API directly (diffusion_coefficient=L, diffusion_type="full").
  • Expands tests across core model validation, discretizers, plate/bm_dim behavior, and hierarchical smokes.
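As an illustration of what "the same diffusion semantics everywhere" could mean for the EM moments, here is a minimal NumPy sketch that converts each diffusion form into the dense one-step noise covariance L Lᵀ dt. The function name `diffusion_covariance` is hypothetical and is not the helper in dynestyx/models/diffusions.py; the sketch also assumes unbatched constants and, for the "diag" and "scalar" cases, bm_dim == state_dim.

```python
import numpy as np


def diffusion_covariance(value, diffusion_type, state_dim, dt):
    """Dense one-step noise covariance L @ L.T * dt (hypothetical helper).

    Assumes an unbatched constant and, for "diag"/"scalar",
    bm_dim == state_dim.
    """
    value = np.asarray(value, dtype=float)
    if diffusion_type == "full":
        # value has shape (state_dim, bm_dim); bm_dim may differ from state_dim.
        return value @ value.T * dt
    if diffusion_type == "diag":
        # value holds the diagonal entries, shape (state_dim,).
        return np.diag(value**2) * dt
    if diffusion_type == "scalar":
        # value is a single number, shape () or (1,).
        sigma = float(value.reshape(-1)[0])
        return sigma**2 * np.eye(state_dim) * dt
    raise ValueError(f"Unknown diffusion_type: {diffusion_type!r}")


# The structured forms agree with the equivalent dense matrix.
dense = diffusion_covariance(np.diag([1.0, 2.0]), "full", state_dim=2, dt=0.1)
assert np.allclose(diffusion_covariance([1.0, 2.0], "diag", 2, 0.1), dense)
```

This mirrors the kind of check the description attributes to tests/test_discretizers.py, where structured forms are compared against dense covariance behavior.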

Expanded Diffusion API (new behavior)

For ContinuousTimeStateEvolution(...):

  • diffusion_coefficient now accepts DiffusionSpec:
    • callable: lambda x, u, t: ...
    • constant value: scalar / vector / matrix
  • diffusion_type is optional:
    • "full": treat value as matrix (..., state_dim, bm_dim)
    • "diag": treat value as diagonal entries (..., state_dim)
    • "scalar": treat value as scalar () or (..., 1)
  • If diffusion_type is omitted, behavior is inferred from shape for backward compatibility:
    • matrix-like -> "full"
    • trailing dim state_dim -> "diag"
    • scalar or trailing dim 1 -> "scalar"
  • bm_dim rules:
    • inferred automatically only for "full"
    • required for "diag" and "scalar"; must be 1 or state_dim
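The inference and bm_dim rules above can be sketched in plain Python. This is only an illustration under stated assumptions: the function names `infer_diffusion_type` and `resolve_bm_dim` are hypothetical, the input is the shape of an unbatched constant, and a 2-D value is always treated as "full" (so a (state_dim, 1) matrix is not read as "scalar"). The real helpers may resolve these ambiguities differently.

```python
from typing import Optional, Tuple


def infer_diffusion_type(shape: Tuple[int, ...], state_dim: int) -> str:
    """Guess the diffusion type of an unbatched constant from its shape."""
    if len(shape) == 0 or shape == (1,):
        # Checked first, so state_dim == 1 with shape (1,) resolves to "scalar".
        return "scalar"
    if len(shape) == 1 and shape[0] == state_dim:
        return "diag"
    if len(shape) == 2 and shape[0] == state_dim:
        return "full"
    raise ValueError(f"Cannot infer diffusion_type from shape {shape}.")


def resolve_bm_dim(
    diffusion_type: str,
    shape: Tuple[int, ...],
    state_dim: int,
    bm_dim: Optional[int] = None,
) -> int:
    """Apply the bm_dim rules: inferred for "full", required otherwise."""
    if diffusion_type == "full":
        inferred = shape[-1]
        if bm_dim is not None and bm_dim != inferred:
            raise ValueError(f"bm_dim={bm_dim} conflicts with shape {shape}.")
        return inferred
    if bm_dim is None:
        raise ValueError(f"bm_dim is required for {diffusion_type!r} diffusion.")
    if bm_dim not in (1, state_dim):
        raise ValueError(f"bm_dim must be 1 or state_dim={state_dim}.")
    return bm_dim
```

For example, `resolve_bm_dim("full", (3, 2), state_dim=3)` returns 2, while `resolve_bm_dim("diag", (3,), state_dim=3)` raises because bm_dim was not given.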

Tests

  • Added/updated coverage in:
    • tests/test_models_core.py
    • tests/test_discretizers.py
    • tests/test_bm_dim_plate.py
    • tests/test_hierarchical_simulator_discretizer_smokes.py

Why

  • This makes specifying an SDE model easier (you can just write diffusion_coefficient = L if you want).
  • This allows faster SDE solves in high dimensions when the diffusion coefficient happens to be scalar or diagonal (documented below).
  • Sets up a future where structured inference methods exploit scalar/diag cases. Note that I tried to exploit in Filter + Discretizer, but didn't see any speedups so I removed those implementations.

It also makes invalid setups fail early with clear errors, especially in continuous cd-dynamax filtering where only full square diffusion is currently supported.

Notes

  • Continuous cd-dynamax filters now reject diffusion with bm_dim > state_dim.
  • Rectangular diffusion with bm_dim < state_dim is zero-padded to a square matrix and remains supported in this backend; simulation-oriented EM paths accept rectangular diffusion as before.

[Image: speed-ups of the SDE solver when using a scalar, diagonal, or full diffusion coefficient.]
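One plausible source of the speed-up is the cost of applying the diffusion to a Brownian increment: a full matrix needs a matrix-vector product, O(state_dim × bm_dim), while a diagonal or scalar coefficient needs only an elementwise multiply, O(state_dim). The NumPy sketch below illustrates this; `apply_diffusion` is a hypothetical name, not the solver's actual function, and the diag/scalar branch assumes bm_dim == state_dim.

```python
import numpy as np


def apply_diffusion(value, diffusion_type, dW):
    """Return L @ dW without building a dense L for diag/scalar inputs.

    Hypothetical helper; assumes bm_dim == state_dim for "diag"/"scalar".
    """
    value = np.asarray(value, dtype=float)
    if diffusion_type == "full":
        # (state_dim, bm_dim) @ (bm_dim,) -> (state_dim,)
        return value @ dW
    if diffusion_type in ("diag", "scalar"):
        # Elementwise multiply; a scalar broadcasts across all components.
        return value * dW
    raise ValueError(f"Unknown diffusion_type: {diffusion_type!r}")


rng = np.random.default_rng(0)
dt = 0.01
sigma = np.array([0.1, 0.2, 0.3, 0.4])
dW = rng.normal(scale=np.sqrt(dt), size=sigma.shape[0])

# The diagonal shortcut matches the dense matrix-vector product.
assert np.allclose(
    apply_diffusion(sigma, "diag", dW),
    apply_diffusion(np.diag(sigma), "full", dW),
)
```

In an Euler-Maruyama step this term is added to `x + drift(x, u, t) * dt`, so the saving applies once per solver step.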

Contributor

Copilot AI left a comment

Pull request overview

This PR extends continuous-time SDE model support by introducing a structured diffusion specification API (constant scalar/diag/full or callable) and plumbing those semantics through validation, solvers, and the cd-dynamax integration path.

Changes:

  • Add dynestyx/models/diffusions.py to standardize diffusion evaluation/inference (diffusion_type, bm_dim) and conversions (matrix/covariance/application).
  • Expand ContinuousTimeStateEvolution to accept DiffusionSpec + optional diffusion_type, update core validation/inference accordingly.
  • Update SDE solver internals and cd-dynamax integration utilities to use the shared diffusion semantics; expand test coverage across models/discretizers/plates.

Reviewed changes

Copilot reviewed 11 out of 11 changed files in this pull request and generated 3 comments.

| File | Description |
| --- | --- |
| dynestyx/models/diffusions.py | New helper module implementing evaluation, type inference, covariance/matrix conversion, and applying diffusion to Brownian increments. |
| dynestyx/models/core.py | Extends ContinuousTimeStateEvolution API to accept structured diffusion specs and optional diffusion_type. |
| dynestyx/models/checkers.py | Updates bm_dim inference/validation to use the new diffusion evaluation/metadata resolution helpers. |
| dynestyx/solvers/sde.py | Replaces ad-hoc diffusion handling with shared helpers for EM moments/sampling and diffrax diffusion callbacks. |
| dynestyx/inference/integrations/cd_dynamax/utils.py | Normalizes/validates diffusion for cd-dynamax using the new diffusion semantics and adds early validation. |
| dynestyx/models/lti_dynamics.py | Updates LTI_continuous to pass constant diffusion matrix directly using the new API. |
| dynestyx/models/__init__.py | Re-exports DiffusionType. |
| tests/test_models_core.py | Adds unit tests for new diffusion semantics and cd-dynamax diffusion constraints/normalization. |
| tests/test_discretizers.py | Adds parametrized tests ensuring structured diffusion forms match expected dense covariance behavior. |
| tests/test_bm_dim_plate.py | Expands plate-context tests to cover structured diffusion specs and expected bm_dim resolution behavior. |
| tests/test_hierarchical_simulator_discretizer_smokes.py | Updates/clarifies a comment related to rectangular diffusion behavior in cd-dynamax integration. |


Comment thread dynestyx/models/diffusions.py Outdated
Comment thread dynestyx/inference/integrations/cd_dynamax/utils.py Outdated
Comment on lines +84 to +89
bm_dim = diffusion_as_matrix(diffusion, state_dim=state_dim).shape[-1]
if bm_dim > state_dim:
    raise ValueError(
        "Continuous cd-dynamax filters require bm_dim <= state_dim. "
        f"Got state_dim={state_dim}, bm_dim={bm_dim}."
    )

Copilot AI Apr 27, 2026

The PR description says continuous cd-dynamax filtering requires square full diffusion (bm_dim == state_dim / no rectangular diffusion), but the implemented validation only rejects bm_dim > state_dim and _normalize_cd_dynamax_diffusion explicitly pads when bm_dim < state_dim. Please reconcile the intended constraint: either tighten validation to require bm_dim == state_dim (and possibly diffusion_type == 'full') or update the PR description/docs to match the current bm_dim <= state_dim + zero-padding behavior.

Collaborator Author

@copilot update the docs/descriptions to match the current behavior.

Contributor

Done in commit aa50c38. Updated the PR description Notes section and renamed the two misleading test functions to accurately reflect that only bm_dim > state_dim is rejected — rectangular diffusion with bm_dim < state_dim is zero-padded and supported.

Collaborator Author

sure
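For readers following this thread, the behavior agreed on above (reject bm_dim > state_dim, zero-pad when bm_dim < state_dim) can be sketched in NumPy as follows. This is an assumption-laden illustration, not the code in dynestyx/inference/integrations/cd_dynamax/utils.py: the function name `pad_to_square` is hypothetical, and padding with zero columns on the right is one possible convention.

```python
import numpy as np


def pad_to_square(L, state_dim):
    """Validate a full diffusion matrix and zero-pad it to (state_dim, state_dim).

    Hypothetical sketch of the bm_dim <= state_dim rule discussed above.
    """
    L = np.asarray(L, dtype=float)
    if L.ndim != 2 or L.shape[0] != state_dim:
        raise ValueError(f"Expected shape (state_dim, bm_dim), got {L.shape}.")
    bm_dim = L.shape[-1]
    if bm_dim > state_dim:
        raise ValueError(
            "Continuous cd-dynamax filters require bm_dim <= state_dim. "
            f"Got state_dim={state_dim}, bm_dim={bm_dim}."
        )
    # Append zero columns; np.pad defaults to constant zero padding.
    return np.pad(L, ((0, 0), (0, state_dim - bm_dim)))


L_rect = np.array([[1.0], [2.0], [3.0]])  # state_dim=3, bm_dim=1
L_sq = pad_to_square(L_rect, state_dim=3)

# Zero columns do not change the noise covariance L @ L.T.
assert np.allclose(L_sq @ L_sq.T, L_rect @ L_rect.T)
```

Because the added columns are zero, the padded matrix drives the state with the same noise covariance as the original rectangular one, which is why padding is safe for a filter that expects a square matrix.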

mattlevine22 and others added 2 commits April 26, 2026 23:15
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
@DanWaxman
Collaborator

Thanks, this looks interesting!!

Note that I tried to exploit in Filter + Discretizer, but didn't see any speedups so I removed those implementations.

Out of curiosity, did you try for state estimation only or for parameter inference?

@mattlevine22
Collaborator Author

Only did a forward pass of the filter; didn't use in parameter inference, and actually didn't even look at expense of gradients. Would you expect those to be noticeably different?

@mattlevine22
Collaborator Author

@DanWaxman I looked at value_and_grad, but saw inconsistent improvements. The best I saw was like a 5% gain. I think this is because currently, we only do this multiply once per observation (and other per-observation operations dominate). When we have fancier discretizations that take many steps between observations, it will be more impactful I think.

I'd recommend for now introducing the API and leveraging in the simple SDE solver where we know it clearly helps (esp. in large dims), then leveraging diagonal structure as it becomes usable on the methods end.

@mattlevine22 mattlevine22 marked this pull request as ready for review April 27, 2026 20:00
@DanWaxman
Collaborator

That makes sense!
