Support for ONNX export #74
Open
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This pull request introduces a new ONNX export example notebook and refactors key flow and spline utility functions to support ONNX export. The changes are grouped below:
New Example and ONNX Export Support
examples/onnx_export.ipynb, a notebook demonstrating how to train a Neural Spline Flow model and export it to ONNX format. The notebook includes custom symbolic registration for ONNX export and validation usingonnxruntime.Refactoring and Numerical Stability in Mixing Flows
Mixingflow (inverse_no_cacheinnormflows/flows/mixing.py) by replacing triangular solve logic with a direct matrix inversion viaweight_inverse(), making it compatible with ONNX and keeping the same outputs as before.weight_inverse()innormflows/flows/mixing.pyto usetorch.inverse()directly, removing the previous two-step triangular solve.Spline Utility Improvements
unconstrained_rational_quadratic_splineinnormflows/utils/splines.pyto clarify handling of derivatives and output initialization, and to use clamping for input bounds. The current implementation ofunconstrained_rational_quadratic_splineuses boolean indexing (e.g.,inputs[mask]) to separate values inside the tail bounds from those outside. While this works in eager mode, it breaks ONNX export and TorchScript. When exporting, PyTorch traces the graph with a specific batch of data. If boolean indexing is used, the resulting ONNX graph hardcodes the specific number of elements that fell inside the mask during the trace (e.g., if 900/1000 elements are inside, the graph creates a Reshape node for size 900). During inference, if a different number of elements fall inside the bounds, ONNX Runtime crashes with a shape mismatch error. This PR pre-clamp the inputs to[-tail_bound, tail_bound]before passing them to the spline function to prevent numerical instability orNaNsfor out-of-bound values. I usetorch.where(mask, spline_output, identity_output)to merge the results.unconstrained_rational_quadratic_splineto usetorch.where, simplifying the masking logic and ensuring correct handling of outputs and log determinants for inputs inside and outside the interval. These both changes are needed for edge cases to work with ONNX.Let me know if something is unclear or need to be fixed.