[Lang] Add qd.math.fma(...) single-rounding fused multiply-add #478
base: duburcqa/qd_precise
@@ -19,6 +19,7 @@ scalar_tensors
 matrix_vector
 compound_types
 static
+precise
 sub_functions
 parallelization
 ```
@@ -0,0 +1,129 @@
# qd.precise

`qd.precise(expr)` marks a floating-point expression as IEEE-strict. Every binary and unary FP op inside the wrapped subtree is evaluated in source order, with no reassociation, no FMA contraction, no approximate transcendental substitution, and no non-IEEE-exact algebraic simplification, regardless of the module-level `fast_math` setting. Folds that are IEEE-exact for every input (e.g. `a - 0 -> a`, `a > a -> false`) are still applied. It is equivalent to the `precise` keyword in MSL / HLSL.
## Why

Quadrants compiles kernels with `fast_math=True` by default. Under that mode the compiler is free to:

- **reassociate** FP ops (e.g. `(a + b) + c -> a + (b + c)`)
- **contract** mul-then-add into FMA
- **substitute approximations** for `sqrt`, `sin`, `cos`, `log`, `1/x`
- **algebraically simplify** (e.g. `a - a -> 0`, `a / a -> 1`)

This silently breaks compensated-arithmetic primitives (Dekker / Kahan 2Sum, Veltkamp split, double-single accumulators), whose correctness rests on expressions like `(a - aa) + (b - bb)` capturing the exact rounding error under IEEE arithmetic instead of being folded to zero. The traditional workaround is to flip the global `fast_math=False` switch, but that pays the performance cost everywhere, even when only a handful of lines need IEEE semantics.
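The reassociation hazard is easy to reproduce in plain Python, which always evaluates IEEE binary64 arithmetic strictly in source order, by spelling out the two parenthesizations by hand:

```python
# Floating-point addition is not associative: the two groupings of
# 0.1 + 0.2 + 0.3 round differently in IEEE binary64.
left = (0.1 + 0.2) + 0.3   # 0.30000000000000004 + 0.3
right = 0.1 + (0.2 + 0.3)  # 0.1 + 0.5
print(left, right, left == right)  # 0.6000000000000001 0.6 False
```

A fast-math compiler is free to pick either grouping, which is exactly what `qd.precise` forbids inside the wrapped subtree.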
`qd.precise(expr)` is the per-expression opt-in: keep `fast_math=True` globally for speed, and wrap only the expressions that must be IEEE-exact.
## Basic usage

```python
@qd.func
def fast_two_sum(a, b):
    s = qd.precise(a + b)
    e = qd.precise(b - (s - a))  # would fold to 0 under fast-math without precise
    return s, e
```
Any expression value can be wrapped. The wrapper returns the same expression with every reachable FP op tagged as precise; at codegen time the tagged ops opt out of the optimizations above.
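A plain-Python analogue shows the intended semantics. CPython never reassociates or folds float expressions, so it behaves like an interpreter with everything wrapped in `qd.precise`; this is a cross-check of the algorithm, not of the Quadrants codegen:

```python
def fast_two_sum(a, b):
    # Dekker's Fast2Sum: requires |a| >= |b|. CPython floats follow IEEE-754
    # binary64 with strict source-order evaluation, so the error term below
    # is never folded to zero.
    s = a + b
    e = b - (s - a)  # exact rounding error of s = fl(a + b)
    return s, e

s, e = fast_two_sum(1.0, 1e-17)
print(s, e)  # 1.0 1e-17: the addend lost in s is recovered exactly in e
```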
## What gets protected

`qd.precise` walks the wrapped expression tree and tags:

- Every `BinaryOp` (`+`, `-`, `*`, `/`, `%`, FP comparisons)
- Every `UnaryOp` (`neg`, `sqrt`, `sin`, `cos`, `log`, `exp`, `rsqrt`, casts, `bit_cast`, ...)

Bitwise operations (`bit_and`, `bit_or`, `bit_xor`, `bit_shl`, `bit_sar`) are integer-domain; the walker tags them for completeness, but the flag has no effect on integer IR.
The walker descends through `BinaryOp`, `UnaryOp`, and `TernaryOp` (e.g. `qd.select`) nodes, so wrapping a composite expression protects the inner ops too:

```python
# All four FP ops below are tagged: the outer sqrt, the inner add, and the two inner muls.
r = qd.precise(qd.sqrt(a * a + b * b))

# Ternary is traversed; both branches and the condition's inner ops are tagged.
r = qd.precise(qd.select(cond, a + b, a - b))
```
## Where the walker stops

`qd.precise` does not descend into:

- Loads (ndarray indexing, field access)
- Constants
- `qd.func` call sites
- Atomic ops
- Intermediate Python variable assignments (`tmp = a + b` wraps the RHS in an internal alloca, so `qd.precise(tmp)` sees the alloca, not the inner `BinaryOp`, and is a silent no-op)

Semantics inside a `qd.func` body are governed by that body's own ops. If you want IEEE-strict behavior inside a called function, wrap the relevant ops inside the function's body, not at the call site. Similarly, apply `qd.precise` directly to the expression rather than to a variable that was assigned earlier:
```python
@qd.func
def dot_precise(a, b, c, d):
    # Wrap inside the body, not at the caller.
    return qd.precise(a * b + c * d)

@qd.kernel
def k(...):
    r = dot_precise(x, y, z, w)  # inner ops are already precise
```
## Interaction with fast_math

`qd.precise` is a per-op override. It takes effect whether `fast_math` is on or off:

| Setting | Non-precise op | `qd.precise` op |
|---|---|---|
| `fast_math=True` | reassoc / contract / simplify | IEEE-strict |
| `fast_math=False` | IEEE-strict | IEEE-strict (redundant but harmless) |

The recommended workflow is to leave `fast_math=True` globally for throughput and reach for `qd.precise` only in the handful of spots that need IEEE behavior.
## Backend coverage

| Backend | Reassoc / contraction / algebraic folds | Approximate transcendentals (`sin` / `cos` / `log`) |
|---|---|---|
| CPU | LLVM FMF cleared | libc `sinf` is already correctly rounded |
| CUDA | LLVM FMF cleared | libdevice `__nv_<fn>f` (non-fast) selected |
| AMDGPU | LLVM FMF cleared | `__ocml_<fn>` already correctly rounded |
| Vulkan / MoltenVK | SPIR-V `NoContraction` decoration | best-effort: driver stdlib default (spec only guarantees 2^-11 absolute error) |
| Metal | SPIR-V `NoContraction` decoration | best-effort: driver stdlib default (spec only guarantees 2^-11 absolute error) |

On SPIR-V backends, the spec defines `NoContraction` as applying to arithmetic instructions only, and most consumers ignore it on the `OpExtInst` calls used for transcendentals. The decoration is still emitted (it is harmless and future-proofs against downstream toolchains that start honoring it), but the tag cannot guarantee correctness of `qd.precise(qd.sin(x))` / `qd.precise(qd.cos(x))` on Metal / Vulkan. The Vulkan precision requirement for GLSL.std.450 `Sin` / `Cos` is stated as 2^-11 absolute error, which for reference magnitudes below 1 amounts to thousands of ULPs, and drivers are within their rights to use that latitude. If you need correctly rounded sin / cos, use the CPU / CUDA / AMDGPU backends.
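To make "2^-11 absolute error is thousands of ULPs" concrete, here is the arithmetic in plain Python. The helper `f32_ulp` is illustrative only (not part of the Quadrants API) and assumes normal-range binary32 values:

```python
import math

def f32_ulp(x):
    # ULP of a normal binary32 value with the magnitude of x:
    # 2**(exponent - 23), where 23 is the number of mantissa bits.
    return 2.0 ** (math.floor(math.log2(abs(x))) - 23)

allowed = 2.0 ** -11  # Vulkan absolute-error bound for Sin/Cos
for ref in (0.5, 0.01):
    print(f"ref={ref}: {allowed / f32_ulp(ref):.0f} ULPs of latitude")
```

At a reference magnitude of 0.5 the bound already allows 2^13 = 8192 ULPs of error, and the latitude grows as the reference value shrinks.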
## Example: Dekker 2Sum

A textbook compensated addition that computes `s + e = a + b` exactly in f32:

```python
@qd.func
def two_sum(a, b):
    s = qd.precise(a + b)
    bb = qd.precise(s - a)
    aa = qd.precise(s - bb)
    e = qd.precise((a - aa) + (b - bb))
    return s, e
```

Without the `qd.precise` wrappers, under `fast_math=True` the compiler recognizes `(a - (s - (s - a))) + (b - (s - a))` as algebraically zero and folds `e` to `0`. The wrappers prevent that fold, and `s + e` reproduces `a + b` to full precision.
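The exactness claim can be checked mechanically in plain Python (IEEE-strict, binary64 here rather than f32) by comparing against rational arithmetic. This is a cross-check of the branch-free 2Sum algorithm itself, which is exact for any finite doubles absent overflow, not of the Quadrants codegen:

```python
from fractions import Fraction

def two_sum(a, b):
    # Branch-free 2Sum: s + e == a + b exactly, for any order of inputs.
    s = a + b
    bb = s - a
    aa = s - bb
    e = (a - aa) + (b - bb)
    return s, e

a, b = 0.1, 1e17
s, e = two_sum(a, b)
assert e != 0.0  # a genuine rounding error was captured
# Exactness: verify s + e == a + b in exact rational arithmetic.
assert Fraction(s) + Fraction(e) == Fraction(a) + Fraction(b)
```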
## Caveats

- `qd.precise` is a scalar primitive. Passing a `Vector` / `Matrix` will raise. Apply it to individual components instead, or refactor your expression to use scalar ops inside.
- `qd.precise` does not mutate its input. It returns a fresh expression subtree with every reachable FP op tagged; the original expression is unchanged. Reusing the original elsewhere is safe and never inherits the tag.
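The clone-and-share behavior can be sketched with a toy expression IR. The names here (`Node`, `ARITY`, `tag_precise`) are illustrative stand-ins, not the real Quadrants internals:

```python
from dataclasses import dataclass

@dataclass
class Node:
    op: str
    args: tuple = ()
    precise: bool = False

# Walked node kinds; anything else (loads, constants, calls) is a leaf.
ARITY = {"+": "binary", "*": "binary", "sqrt": "unary", "select": "ternary"}

def tag_precise(n):
    # Clone every Binary/Unary/Ternary node with the precise flag set;
    # share leaves with the input by reference, exactly as described above.
    if n.op not in ARITY:
        return n
    return Node(n.op, tuple(tag_precise(a) for a in n.args), True)

a, b = Node("load"), Node("load")
expr = Node("sqrt", (Node("+", (Node("*", (a, a)), Node("*", (b, b)))),))
tagged = tag_precise(expr)
assert tagged.precise and not expr.precise  # original subtree untouched
assert tagged.args[0].args[0].args[0] is a  # leaf shared by reference
```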
## Companion: `qd.math.fma`

Compensated-arithmetic blocks typically need two things: (1) IEEE-strict ordering on ordinary ops, provided by `qd.precise`, and (2) a guaranteed single-rounding fused multiply-add for error-free transforms. The second is exposed separately as `qd.math.fma(a, b, c)`:

```python
# Two-product error-free transform (TwoProd).
# Returns p = round(a*b) and e such that a*b = p + e exactly.
p = a * b
e = qd.math.fma(a, b, -p)  # single rounding: exact residual of p
```

`qd.math.fma` is lowered to the native FMA on every backend: `llvm.fma` on CPU, `__nv_fma` / `__nv_fmaf` (libdevice) on CUDA, `GLSL.std.450 Fma` on Vulkan / Metal. Unlike relying on the compiler to contract `mul; add` into an FMA (which requires both fast-math flags that permit contraction *and* inputs that survive algebraic simplification), this is an explicit instruction, so the TwoProd residual, Fast2Sum, and double-single multiply patterns port over directly without per-backend contraction hints.

Backends without hardware FMA fall back to a regular mul-then-add and lose the single-rounding guarantee; on those targets, compensated algorithms should be rewritten in the Dekker / Veltkamp-split form.
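The Dekker / Veltkamp-split form mentioned above can be sketched in plain Python for binary64. It is exact absent overflow/underflow, and the split constant 2^27 + 1 is specific to the 53-bit significand:

```python
from fractions import Fraction

def split(a):
    # Veltkamp split: a == hi + lo, with hi and lo each fitting in 26 bits.
    c = 134217729.0 * a  # 2**27 + 1
    hi = c - (c - a)
    return hi, a - hi

def two_prod(a, b):
    # Dekker's TwoProd without FMA: p + e == a * b exactly.
    p = a * b
    ah, al = split(a)
    bh, bl = split(b)
    e = ((ah * bh - p) + ah * bl + al * bh) + al * bl
    return p, e

p, e = two_prod(0.1, 0.3)
# Check exactness against rational arithmetic.
assert Fraction(p) + Fraction(e) == Fraction(0.1) * Fraction(0.3)
```

Where `qd.math.fma` is available, `e = qd.math.fma(a, b, -p)` replaces all of the split arithmetic with one instruction.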
@@ -95,6 +95,59 @@ def cast(obj, dtype):
    return expr.Expr(_qd_core.value_cast(expr.Expr(obj).ptr, dtype))
def precise(obj):
    """Mark a floating-point expression as IEEE-strict.

    Every binary and unary FP op inside ``obj`` is evaluated in source
    order with no reassociation, no FMA contraction, no approximate
    transcendental substitution, and no non-IEEE-exact algebraic
    simplification, regardless of the module-level :attr:`fast_math`
    setting. Folds that are IEEE-exact for every input (e.g.
    ``a - 0 -> a``, ``a > a -> false``) are still applied. This is
    equivalent to MSL's / HLSL's ``precise`` keyword and lets you keep
    ``fast_math=True`` globally while protecting compensated-arithmetic
    blocks (Dekker / Kahan 2Sum, Veltkamp split, etc.) from being folded
    away.

    Recursion descends through ``BinaryOp``, ``UnaryOp`` (cast, bit_cast,
    neg, sqrt, ...), and ``TernaryOp`` (select) wrappers so that inner
    binary ops are reached even when wrapped, e.g.
    ``qd.precise(qd.bit_cast(a + b, qd.f32))``. It stops at loads,
    constants, ``qd.func`` calls, ndarray accesses, etc.; semantics inside
    a ``qd.func`` body are governed by that body's own ops, so wrap calls
    separately if needed.

    Notes:
        * ``qd.precise`` does NOT mutate the input expression. It returns
          a fresh subtree that mirrors the input's structure, with every
          reachable Binary / Unary / Ternary node cloned and the new
          Binary / Unary nodes tagged as ``precise``. Non-walked nodes
          (loads, constants, ``qd.func`` calls, ndarray accesses, ...)
          are shared with the input by reference. The practical upshot:
          reusing the original (pre-``precise``) expression value
          elsewhere is safe; it will NOT pick up the tag.

    Args:
        obj: A scalar Quadrants expression (typically a chain of FP ops).

    Returns:
        A fresh expression subtree with every reachable binary and unary
        FP op tagged as ``precise``. The original ``obj`` is unchanged.

    Example::

        >>> @qd.func
        ... def fast_two_sum(a, b):
        ...     # Local IEEE region, survives even with fast_math=True.
        ...     s = qd.precise(a + b)
        ...     e = qd.precise(b - (s - a))
        ...     return s, e
    """
    if is_quadrants_class(obj):
        raise ValueError("Cannot apply precise on Quadrants classes")
    return expr.Expr(_qd_core.precise(expr.Expr(obj).ptr))
def bit_cast(obj, dtype):
    """Copy and cast a scalar to a specified data type with its underlying
    bits preserved. Must be called in quadrants scope.

@@ -1117,6 +1170,42 @@ def py_select(cond, x1, x2):
    return _ternary_operation(_qd_core.expr_select, py_select, cond, x1, x2)
def fma(a, b, c):
    """Fused multiply-add: return ``a * b + c`` computed as a single rounded
    operation.

    Unlike a plain ``a * b + c``, the intermediate product is not rounded:
    the exact value of ``a * b + c`` is rounded just once, at the end.
    This is the hardware FMA available on every modern FP pipeline (x86
    FMA3, ARM, Apple Silicon, NVIDIA ``fma``, AMD, RISC-V Zfa). It is
    exposed here primarily to let compensated-arithmetic primitives
    (TwoProd, Fast2Sum + FMA, double-single accumulators) get the
    single-rounding guarantee without relying on backend-specific FMF
    contraction.

    Classic two-product error-free transform::

        p = a * b
        e = qd.math.fma(a, b, -p)  # exact residual of p

    Each backend maps this to its native FMA (LLVM ``llvm.fma`` intrinsic
    on CPU, ``__nv_fma`` / ``__nv_fmaf`` on CUDA via libdevice,
    GLSL.std.450 ``Fma`` on Vulkan / Metal). Backends without hardware
    FMA fall back to a regular mul-then-add and lose the single-rounding
    guarantee.

    Args:
        a, b, c: Homogeneous FP scalars (``f16`` / ``f32`` / ``f64``).
            Integer inputs are rejected.

    Returns:
        ``a * b + c`` computed as a single rounded operation.
    """

    def py_fma(a, b, c):
        return a * b + c

    return _ternary_operation(_qd_core.expr_fma, py_fma, a, b, c)
def ifte(cond, x1, x2):
    """Evaluate and return `x1` if `cond` is true; otherwise evaluate and return `x2`. This operator
    guarantees short-circuit semantics: exactly one of `x1` or `x2` will be evaluated.

@@ -1535,4 +1624,5 @@ def min(*args):  # pylint: disable=W0622
     "select",
     "abs",
     "pow",
+    "precise",
 ]
Review comment on lines 1624 to 1628:
🟡 The `fma` docstring in `ops.py` shows `e = qd.fma(a, b, -p)` in its TwoProd example, but `fma` is not in `ops.__all__` and is therefore not re-exported to the top-level `qd` namespace. Copying this example verbatim raises `AttributeError: module 'quadrants' has no attribute 'fma'`. The correct public API is `qd.math.fma`, which is what the companion documentation in `precise.md` correctly uses.

Extended reasoning:

What the bug is and how it manifests: the `fma` function added in `python/quadrants/lang/ops.py` has a docstring that demonstrates the classic TwoProd error-free transform with `e = qd.fma(a, b, -p)`. A user reading the docstring and copying this example will immediately encounter `AttributeError: module 'quadrants' has no attribute 'fma'`.

The specific code path that triggers it: the top-level `qd` namespace is populated via `quadrants/__init__.py`, which does `from quadrants.lang import *`, which chains through `from quadrants.lang.ops import *`. Only names listed in `ops.__all__` (lines 1592-1628) reach the top level. The `fma` function defined in `ops.py` is deliberately NOT added to `ops.__all__`; it is an internal function wrapped by `qd.math.fma`.

Why existing code does not prevent it: the docstring was written using the wrong namespace prefix. The omission of `fma` from `ops.__all__` is correct and intentional (it keeps `qd.fma` from polluting the top-level namespace), but the docstring example was never updated to reflect that the public entry point is `qd.math.fma` rather than `qd.fma`.

What the impact would be: the bug is documentation-only. The runtime implementation is correct and `qd.math.fma` works as advertised, but any user reading the `ops.fma` docstring (via `help()`, an IDE, or generated API docs) and copying the example verbatim gets a runtime `AttributeError`. It also leaves the docstring inconsistent with `docs/source/user_guide/precise.md` (added in the same PR), which correctly uses `qd.math.fma(a, b, -p)`.

How to fix it: change the docstring example in `python/quadrants/lang/ops.py` (around line 1188) from `e = qd.fma(a, b, -p)` to `e = qd.math.fma(a, b, -p)`.