Stanpol commented on Nov 20, 2025

This pull request introduces a new ONNX export example notebook and refactors key flow and spline utility functions to support ONNX export. The changes are grouped below:

New Example and ONNX Export Support

  • Added examples/onnx_export.ipynb, a notebook demonstrating how to train a Neural Spline Flow model and export it to ONNX format. The notebook includes custom symbolic registration for ONNX export and validation using onnxruntime.
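For readers who do not open the notebook, here is a runnable sketch of the export-and-validate pattern it follows. `FlowStandIn` is a placeholder for the trained Neural Spline Flow (which the notebook builds with normflows), the tensor names, shapes, and opset are illustrative, and the custom symbolic registration the notebook performs is omitted here; only `torch.onnx.export` and `onnxruntime.InferenceSession` are the actual APIs used.

```python
import numpy as np
import torch
import onnxruntime as ort


class FlowStandIn(torch.nn.Module):
    """Placeholder for the trained flow: a single tensor-in/tensor-out forward
    pass, which is what torch.onnx.export traces."""

    def forward(self, z):
        # Stand-in transform; the notebook pushes z through the NSF layers instead.
        return torch.tanh(z)


model = FlowStandIn().eval()
dummy = torch.randn(16, 2)  # example batch of base samples

torch.onnx.export(
    model,
    (dummy,),
    "nsf.onnx",
    input_names=["z"],
    output_names=["x"],
    dynamic_axes={"z": {0: "batch"}, "x": {0: "batch"}},
    opset_version=17,
)

# Validate the exported graph against the PyTorch output with onnxruntime.
sess = ort.InferenceSession("nsf.onnx")
z = torch.randn(100, 2)
with torch.no_grad():
    ref = model(z).numpy()
ort_out = sess.run(["x"], {"z": z.numpy()})[0]
print("max abs diff:", float(np.abs(ref - ort_out).max()))
```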

Refactoring and Numerical Stability in Mixing Flows

  • Simplified the inverse operation of the mixing flows (inverse_no_cache in normflows/flows/mixing.py) by replacing the triangular-solve logic with a direct matrix inversion via weight_inverse(), which is ONNX-compatible and produces the same outputs as before.
  • Refactored weight_inverse() in normflows/flows/mixing.py to call torch.inverse() directly, removing the previous two-step triangular solve; a standalone sketch of the change follows this list.
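The sketch below illustrates the equivalence the refactor relies on, using a freshly built LU-style weight rather than the actual code in normflows/flows/mixing.py: inverting W = L @ U via two triangular solves versus a single torch.inverse() call, which traces to a plain ONNX op.

```python
import torch

torch.manual_seed(0)
n = 4
lower = torch.tril(torch.randn(n, n), diagonal=-1) + torch.eye(n)   # unit lower-triangular L
upper = torch.triu(torch.randn(n, n)) + 2.0 * torch.eye(n)          # upper-triangular U, nonzero diagonal
weight = lower @ upper

# Previous two-step approach: solve L X = I, then U W_inv = X.
identity = torch.eye(n)
lower_inv = torch.linalg.solve_triangular(lower, identity, upper=False, unitriangular=True)
weight_inv_tri = torch.linalg.solve_triangular(upper, lower_inv, upper=True)

# New approach: invert the assembled weight directly.
weight_inv_direct = torch.inverse(weight)

print(torch.allclose(weight_inv_tri, weight_inv_direct, atol=1e-5))  # True
```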

Spline Utility Improvements

  • Refactored unconstrained_rational_quadratic_spline in normflows/utils/splines.py to clarify the handling of derivatives and output initialization, and to clamp inputs to the tail bounds. The current implementation uses boolean indexing (e.g., inputs[mask]) to separate values inside the tail bounds from those outside. This works in eager mode but breaks ONNX export and TorchScript: when exporting, PyTorch traces the graph with a specific batch of data, and boolean indexing hardcodes the number of elements that fell inside the mask during the trace (e.g., if 900 of 1000 elements are inside, the graph contains a Reshape node of size 900). At inference time, if a different number of elements fall inside the bounds, ONNX Runtime crashes with a shape-mismatch error. This PR instead pre-clamps the inputs to [-tail_bound, tail_bound] before passing them to the spline function, so out-of-bound values cannot cause numerical instability or NaNs, and merges the results with torch.where(mask, spline_output, identity_output).
  • Updated the output assignment in unconstrained_rational_quadratic_spline to use torch.where, simplifying the masking logic and ensuring that outputs and log determinants are handled correctly for inputs inside and outside the interval. Both changes are needed for these edge cases to export correctly to ONNX; a standalone sketch of the clamp-and-where pattern follows this list.
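The sketch below shows the clamp-and-where pattern in isolation. The spline is replaced by a placeholder transform and the function name is made up, so this is not the code in normflows/utils/splines.py, only the shape-static masking idea it now uses.

```python
import torch

def spline_like_transform(x):
    # Stand-in for rational_quadratic_spline applied elementwise;
    # returns the transformed values and their elementwise log-derivatives.
    return torch.tanh(x), -2.0 * torch.log(torch.cosh(x))

def tails_linear(inputs, tail_bound=1.0):
    inside = (inputs >= -tail_bound) & (inputs <= tail_bound)

    # Clamp first so out-of-bound values never reach the spline
    # (avoids NaNs without introducing data-dependent shapes).
    clamped = torch.clamp(inputs, -tail_bound, tail_bound)
    spline_out, spline_logdet = spline_like_transform(clamped)

    # Outside the interval the transform is the identity with zero log-det.
    outputs = torch.where(inside, spline_out, inputs)
    logabsdet = torch.where(inside, spline_logdet, torch.zeros_like(inputs))
    return outputs, logabsdet

x = torch.linspace(-3, 3, 7)
y, ld = tails_linear(x)
```

Because every element goes through the same ops and only torch.where selects between the two branches, the traced graph has no shape that depends on how many elements happen to fall inside the bounds.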

Let me know if anything is unclear or needs to be fixed.
