diff --git a/.github/workflows/fp-arena-ci.yml b/.github/workflows/fp-arena-ci.yml index c91c7a8..2e8c2ef 100644 --- a/.github/workflows/fp-arena-ci.yml +++ b/.github/workflows/fp-arena-ci.yml @@ -42,7 +42,7 @@ jobs: sudo apt-get update # gcc/g++ (C++20) + cmake/ninja drive DaCe's CPU codegen build; the # generated maps use OpenMP (libgomp). - sudo apt-get install -y cmake ninja-build gcc g++ libgomp1 + sudo apt-get install -y cmake ninja-build gcc g++ libgomp1 libmpfr-dev - name: Install DaCe (yakup/dev) and FP-Arena run: | diff --git a/fp_arena/__init__.py b/fp_arena/__init__.py index 1284e1d..e881e72 100644 --- a/fp_arena/__init__.py +++ b/fp_arena/__init__.py @@ -14,6 +14,7 @@ Float64sr, float32sr, float64sr, + mpfr, register, FP_ARENA_TYPECLASSES, ) @@ -29,6 +30,7 @@ INCLUDE_DIR, ) from fp_arena.transformations.change_fp_types import change_fptype +from fp_arena.transformations.change_and_propagate_fp_types import change_and_propagate_fp_types # Register the types and the SDFG convenience method on import (idempotent). register() @@ -47,6 +49,7 @@ "Float64sr", "float32sr", "float64sr", + "mpfr", "register", "FP_ARENA_TYPECLASSES", "enable_fp_arena_extensions", @@ -59,4 +62,5 @@ "fp_arena_global_code", "INCLUDE_DIR", "change_fptype", + "change_and_propagate_fp_types", ] diff --git a/fp_arena/dtypes.py b/fp_arena/dtypes.py index f2bf2e1..3caf66f 100644 --- a/fp_arena/dtypes.py +++ b/fp_arena/dtypes.py @@ -13,9 +13,10 @@ """ import numpy +import ctypes import dace -from dace import dtypes as _ddtypes +from dace import dtypes as _ddtypes, typeclass #: C++ namespace-qualified type names emitted into generated code. _FLOAT32SR_CTYPE = "fp_arena::float32sr" @@ -77,6 +78,60 @@ def __repr__(self) -> str: } +class _mpfr_t(ctypes.Structure): + _fields_ = [ + ("_mpfr_prec", ctypes.c_long), + ("_mpfr_sign", ctypes.c_int), + ("_mpfr_exp", ctypes.c_long), + ("_mpfr_d", ctypes.c_void_p), + ] + + +class mpfr(typeclass): + """ + A data type for custom Multiple Precision Floating-Point (MPFR) types. + + Example use: `dace.mpfr(128)` for 128-bit precision. + """ + + def __init__(self, precision: int): + self.precision = precision + self.type = numpy.object_ + self.bytes = ctypes.sizeof(_mpfr_t) + self.dtype = self + self.typename = f"mpfr{precision}" + + def to_string(self): + return self.typename + + def to_json(self): + return {"type": "mpfr", "precision": self.precision} + + @staticmethod + def from_json(json_obj, context=None): + if json_obj["type"] != "mpfr": + raise TypeError("Invalid type for mpfr") + return mpfr(json_obj["precision"]) + + @property + def ctype(self): + return f"dace::mpfr<{self.precision}>" + + @property + def ctype_unaligned(self): + return self.ctype + + def as_ctypes(self): + return ctypes.c_void_p + + def as_numpy_dtype(self): + return numpy.dtype(numpy.object_) + + @property + def base_type(self): + return self + + def register(): """ Register the FP-Arena types into DaCe's global registries. @@ -100,4 +155,8 @@ def register(): _ddtypes.TYPECLASS_STRINGS.append(name) _ddtypes.TYPECLASS_TO_STRING.setdefault(tc, tc.ctype) + # Also expose the parametric mpfr class so `dace.mpfr(128)` works. + setattr(_ddtypes, "mpfr", mpfr) + setattr(dace, "mpfr", mpfr) + return FP_ARENA_TYPECLASSES diff --git a/fp_arena/extensions.py b/fp_arena/extensions.py index 8356ae2..f49867d 100644 --- a/fp_arena/extensions.py +++ b/fp_arena/extensions.py @@ -43,6 +43,7 @@ _HEADERS = ( os.path.join(INCLUDE_DIR, "fp_arena", "float32sr.h"), os.path.join(INCLUDE_DIR, "fp_arena", "float64sr.h"), + os.path.join(INCLUDE_DIR, "fp_arena", "mpfr.h"), ) #: Backends whose global-code section receives the includes (CPU frame + CUDA). @@ -51,8 +52,8 @@ #: Marker so the includes are only injected once per SDFG. _GUARD = "// fp_arena extensions enabled" -#: Substring identifying an FP-Arena C type (used to detect SR usage in an SDFG). -_CTYPE_MARKER = "fp_arena::" +#: Substrings identifying an FP-Arena C type (fp_arena:: for SR types, dace::mpfr for mpfr). +_CTYPE_MARKERS = ("fp_arena::", "dace::mpfr") def fp_arena_global_code() -> str: @@ -145,13 +146,23 @@ def enable_fp_arena_extensions(sdfg: dace.SDFG) -> dace.SDFG: def uses_fp_arena_types(sdfg: dace.SDFG) -> bool: """ :param sdfg: the SDFG to inspect. - :returns: ``True`` if any data descriptor in ``sdfg`` or its nested SDFGs has - an FP-Arena C type (e.g. ``float32sr``), ``False`` otherwise. + :returns: ``True`` if any data descriptor or tasklet body in ``sdfg`` or its + nested SDFGs references an FP-Arena C type, ``False`` otherwise. """ + from dace.sdfg import nodes as _dnodes for nested in sdfg.all_sdfgs_recursive(): for desc in nested.arrays.values(): - if _CTYPE_MARKER in (getattr(desc.dtype, "ctype", "") or ""): + if any(m in (getattr(desc.dtype, "ctype", "") or "") for m in _CTYPE_MARKERS): return True + for state in nested.states(): + for node in state.nodes(): + if isinstance(node, _dnodes.Tasklet): + try: + code_str = node.code.as_string + except AttributeError: + code_str = str(node.code) + if any(m in code_str for m in _CTYPE_MARKERS): + return True return False diff --git a/fp_arena/runtime/include/fp_arena/mpfr.h b/fp_arena/runtime/include/fp_arena/mpfr.h new file mode 100644 index 0000000..fed1e87 --- /dev/null +++ b/fp_arena/runtime/include/fp_arena/mpfr.h @@ -0,0 +1,222 @@ +#pragma once + +#include + +#include + +namespace dace { + +// Process-global exponent width (0 = unlimited). +// Set once before any computation via set_mpfr_exponent_bits(). Not +// thread-safe. +inline unsigned int mpfr_exponent_bits = 0; + +// Configure the exponent range for all dace::mpfr values. +// Uses IEEE-style bias: emax = 2^(bits-1)-1, emin = 1-emax. + +inline void set_mpfr_exponent_bits(unsigned int bits) { + mpfr_exponent_bits = bits; + if (bits > 0) { + mpfr_exp_t emax = (mpfr_exp_t(1) << (bits - 1)) - 1; + mpfr_set_emin(1 - emax); + mpfr_set_emax(emax); + } else { + mpfr_set_emin(MPFR_EMIN_DEFAULT); + mpfr_set_emax(MPFR_EMAX_DEFAULT); + } +} + +template class mpfr { +private: + mpfr_t val; + + void clamp_exp(int t) { + if (mpfr_exponent_bits > 0) { + int temp = mpfr_check_range(val, t, MPFR_RNDN); + mpfr_subnormalize(val, temp, MPFR_RNDN); + } + } + +public: + // Constructors + mpfr() { mpfr_init2(val, Precision); } + + mpfr(double d) { + mpfr_init2(val, Precision); + clamp_exp(mpfr_set_d(val, d, MPFR_RNDN)); + } + + mpfr(float f) { + mpfr_init2(val, Precision); + clamp_exp(mpfr_set_flt(val, f, MPFR_RNDN)); + } + + mpfr(int i) { + mpfr_init2(val, Precision); + clamp_exp(mpfr_set_si(val, i, MPFR_RNDN)); + } + + mpfr(unsigned int i) { + mpfr_init2(val, Precision); + clamp_exp(mpfr_set_ui(val, i, MPFR_RNDN)); + } + + mpfr(long int i) { + mpfr_init2(val, Precision); + clamp_exp(mpfr_set_si(val, i, MPFR_RNDN)); + } + + mpfr(unsigned long int i) { + mpfr_init2(val, Precision); + clamp_exp(mpfr_set_ui(val, i, MPFR_RNDN)); + } + + // Copy Constructor + mpfr(const mpfr &other) { + mpfr_init2(val, Precision); + mpfr_set(val, other.val, MPFR_RNDN); + } + + // Move Constructor + mpfr(mpfr &&other) noexcept { + mpfr_init2(val, Precision); + mpfr_swap(val, other.val); + } + + // Destructor + ~mpfr() { mpfr_clear(val); } + + // Copy Assignment + mpfr &operator=(const mpfr &other) { + if (this != &other) { + mpfr_set(val, other.val, MPFR_RNDN); + } + return *this; + } + + // Move Assignment + mpfr &operator=(mpfr &&other) noexcept { + if (this != &other) { + mpfr_swap(val, other.val); + } + return *this; + } + + // Assign from basic types + mpfr &operator=(double d) { + clamp_exp(mpfr_set_d(val, d, MPFR_RNDN)); + return *this; + } + + mpfr &operator=(float f) { + clamp_exp(mpfr_set_flt(val, f, MPFR_RNDN)); + return *this; + } + + mpfr &operator=(int i) { + clamp_exp(mpfr_set_si(val, i, MPFR_RNDN)); + return *this; + } + + // Conversions + explicit operator double() const { return mpfr_get_d(val, MPFR_RNDN); } + + explicit operator float() const { return mpfr_get_flt(val, MPFR_RNDN); } + + explicit operator int() const { return (int)mpfr_get_si(val, MPFR_RNDN); } + + explicit operator long() const { return mpfr_get_si(val, MPFR_RNDN); } + + explicit operator bool() const { return mpfr_cmp_d(val, 0.0) != 0; } + + // Arithmetic Operators + friend mpfr operator+(const mpfr &lhs, const mpfr &rhs) { + mpfr result; + result.clamp_exp(mpfr_add(result.val, lhs.val, rhs.val, MPFR_RNDN)); + return result; + } + + friend mpfr operator-(const mpfr &lhs, const mpfr &rhs) { + mpfr result; + result.clamp_exp(mpfr_sub(result.val, lhs.val, rhs.val, MPFR_RNDN)); + return result; + } + + friend mpfr operator*(const mpfr &lhs, const mpfr &rhs) { + mpfr result; + result.clamp_exp(mpfr_mul(result.val, lhs.val, rhs.val, MPFR_RNDN)); + return result; + } + + friend mpfr operator/(const mpfr &lhs, const mpfr &rhs) { + mpfr result; + result.clamp_exp(mpfr_div(result.val, lhs.val, rhs.val, MPFR_RNDN)); + return result; + } + + // Unary minus + mpfr operator-() const { + mpfr result; + result.clamp_exp(mpfr_neg(result.val, this->val, MPFR_RNDN)); + return result; + } + + // Compound assignments + mpfr &operator+=(const mpfr &other) { + clamp_exp(mpfr_add(val, val, other.val, MPFR_RNDN)); + return *this; + } + + mpfr &operator-=(const mpfr &other) { + clamp_exp(mpfr_sub(val, val, other.val, MPFR_RNDN)); + return *this; + } + + mpfr &operator*=(const mpfr &other) { + clamp_exp(mpfr_mul(val, val, other.val, MPFR_RNDN)); + return *this; + } + + mpfr &operator/=(const mpfr &other) { + clamp_exp(mpfr_div(val, val, other.val, MPFR_RNDN)); + return *this; + } + + // Comparison Operators + friend bool operator==(const mpfr &lhs, const mpfr &rhs) { + return mpfr_cmp(lhs.val, rhs.val) == 0; + } + + friend bool operator!=(const mpfr &lhs, const mpfr &rhs) { + return mpfr_cmp(lhs.val, rhs.val) != 0; + } + + friend bool operator<(const mpfr &lhs, const mpfr &rhs) { + return mpfr_cmp(lhs.val, rhs.val) < 0; + } + + friend bool operator>(const mpfr &lhs, const mpfr &rhs) { + return mpfr_cmp(lhs.val, rhs.val) > 0; + } + + friend bool operator<=(const mpfr &lhs, const mpfr &rhs) { + return mpfr_cmp(lhs.val, rhs.val) <= 0; + } + + friend bool operator>=(const mpfr &lhs, const mpfr &rhs) { + return mpfr_cmp(lhs.val, rhs.val) >= 0; + } + + // IO Stream + friend std::ostream &operator<<(std::ostream &os, const mpfr &m) { + char *str = nullptr; + mpfr_asprintf(&str, "%.*RNg", (int)(Precision * 0.30103) + 2, m.val); + if (str) { + os << str; + mpfr_free_str(str); + } + return os; + } +}; + +} // namespace dace diff --git a/fp_arena/transformations/__init__.py b/fp_arena/transformations/__init__.py index 1781a2e..664df36 100644 --- a/fp_arena/transformations/__init__.py +++ b/fp_arena/transformations/__init__.py @@ -2,5 +2,6 @@ """FP-Arena SDFG transformations.""" from fp_arena.transformations.change_fp_types import change_fptype +from fp_arena.transformations.change_and_propagate_fp_types import change_and_propagate_fp_types -__all__ = ["change_fptype"] +__all__ = ["change_fptype", "change_and_propagate_fp_types"] diff --git a/fp_arena/transformations/change_and_propagate_fp_types.py b/fp_arena/transformations/change_and_propagate_fp_types.py new file mode 100644 index 0000000..69bb29a --- /dev/null +++ b/fp_arena/transformations/change_and_propagate_fp_types.py @@ -0,0 +1,413 @@ +from collections import defaultdict, deque +from functools import reduce +from typing import Dict, FrozenSet, Optional, Set, Tuple + +import dace +from dace.sdfg import nodes, utils as sdfg_utils +from dace.sdfg.state import AbstractControlFlowRegion, SDFGState + + +# Returns the promoted type of *t1* and *t2* according to *rules*. +def _promote( + t1: Optional[dace.dtypes.typeclass], + t2: Optional[dace.dtypes.typeclass], + rules: Dict[FrozenSet, dace.dtypes.typeclass], +) -> Optional[dace.dtypes.typeclass]: + if t1 is None: + return t2 + if t2 is None: + return t1 + if t1 == t2: + return t1 + key = frozenset({t1, t2}) + if key not in rules: + raise ValueError(f"No promotion rule defined for types {t1} and {t2}") + return rules[key] + + +# None-safe equality for optional typeclasses (``dace.typeclass != None`` is unreliable). +def _types_equal( + a: Optional[dace.dtypes.typeclass], + b: Optional[dace.dtypes.typeclass], +) -> bool: + if a is None or b is None: + return a is b + return bool(a == b) + + +# Returns states using topological_sort, visiting nested regions recursively +def _states_in_order(cfg: AbstractControlFlowRegion): + for block in sdfg_utils.dfs_topological_sort(cfg): + if isinstance(block, SDFGState): + yield block + elif isinstance(block, AbstractControlFlowRegion): + yield from _states_in_order(block) + + +# Builds the array-level dataflow graph of the SDFG. +# +# Returns: +# producers: array -> set of arrays that feed any computation writing it +# consumers: the reverse map (array -> arrays it feeds into) +def _build_dataflow( + sdfg: dace.SDFG, +) -> Tuple[Dict[str, Set[str]], Dict[str, Set[str]]]: + producers: Dict[str, Set[str]] = defaultdict(set) + consumers: Dict[str, Set[str]] = defaultdict(set) + + for state in _states_in_order(sdfg): + for node in state.nodes(): + if isinstance(node, nodes.NestedSDFG): + raise NotImplementedError( + "Type propagation does not currently support NestedSDFG nodes. " + f"Found '{node.label}' in state '{state.label}'." + ) + + if isinstance(node, (nodes.Tasklet, nodes.LibraryNode)): + ins = { + e.data.data + for e in state.in_edges(node) + if e.data is not None and e.data.data is not None + } + outs = { + e.data.data + for e in state.out_edges(node) + if e.data is not None and e.data.data is not None + } + for out in outs: + producers[out].update(ins) + for inp in ins: + consumers[inp].add(out) + + # Direct AccessNode -> AccessNode copies. + for e in state.edges(): + if e.data is None or e.data.data is None: + continue + if isinstance(e.src, nodes.AccessNode) and isinstance( + e.dst, nodes.AccessNode + ): + producers[e.dst.data].add(e.src.data) + consumers[e.src.data].add(e.dst.data) + + return producers, consumers + + +# Computes a fixed-point precision for every array. +# +# Each array's type is the join (widest, via *rules*) of its producers' types: +# - pinned arrays (in *initial_types*) are constants and never recomputed; +# - source arrays (no producers) are constants at their original dtype; +# - derived arrays start at None (bottom) and only ever widen, so demotion +# propagates correctly and the monotone worklist always terminates. +# Returns name -> dtype, where None means "leave the array's dtype unchanged". +def _infer_types( + sdfg: dace.SDFG, + producers: Dict[str, Set[str]], + consumers: Dict[str, Set[str]], + initial_types: Dict[str, dace.dtypes.typeclass], + original_types: Dict[str, dace.dtypes.typeclass], + supported: Set[dace.dtypes.typeclass], + rules: Dict[FrozenSet, dace.dtypes.typeclass], +) -> Dict[str, Optional[dace.dtypes.typeclass]]: + + inferred: Dict[str, Optional[dace.dtypes.typeclass]] = {} + for name in sdfg.arrays: + if name in initial_types: + inferred[name] = initial_types[name] # pin: constant + elif not producers.get(name): + inferred[name] = original_types[name] # source: constant at original + else: + inferred[name] = None # derived: bottom, widened below + + worklist = deque( + name + for name in sdfg.arrays + if name not in initial_types and producers.get(name) + ) + queued = set(worklist) + while worklist: + name = worklist.popleft() + queued.discard(name) + + new: Optional[dace.dtypes.typeclass] = None + for prod in producers[name]: + t = inferred.get(prod) + if t is None or t not in supported: + continue + new = _promote(new, t, rules) + + if not _types_equal(new, inferred[name]): + inferred[name] = new + for cons in consumers.get(name, ()): + if cons in initial_types or cons in queued or not producers.get(cons): + continue + worklist.append(cons) + queued.add(cons) + + return inferred + + +# Prints a report of each array's original and final (inferred) precision +def _print_type_report( + original_types: Dict[str, dace.dtypes.typeclass], + inferred: Dict[str, Optional[dace.dtypes.typeclass]], +) -> None: + name_w = max((len(n) for n in original_types), default=0) + name_w = max(name_w, len("array")) + lines = [ + "change_and_propagate_fp_types: precision report", + f" {'array':<{name_w}} {'original':<9} {'final':<9}", + ] + for name in sorted(original_types): + orig = original_types[name] + final = inferred.get(name) + if final is None: + final = orig + marker = "" if _types_equal(final, orig) else " (changed)" + lines.append( + f" {name:<{name_w}} {orig.to_string():<9} {final.to_string():<9}{marker}" + ) + print("\n".join(lines)) + + +# Writes inferred types onto tasklet/map/library-node connectors. +def _apply_connector_types( + sdfg: dace.SDFG, + inferred: Dict[str, Optional[dace.dtypes.typeclass]], + supported: Set[dace.dtypes.typeclass], + rules: Dict[FrozenSet, dace.dtypes.typeclass], +) -> None: + for state in _states_in_order(sdfg): + for node in state.nodes(): + if isinstance(node, (nodes.Tasklet, nodes.LibraryNode)): + in_types = [] + for e in state.in_edges(node): + if e.dst_conn is None or e.data is None or e.data.data is None: + continue + t = inferred.get(e.data.data) + if t in supported: + node.in_connectors[e.dst_conn] = t + in_types.append(t) + + if not in_types: + continue + compute = reduce(lambda a, b: _promote(a, b, rules), in_types) + if compute not in supported: + continue + for e in state.out_edges(node): + if e.src_conn is None or e.data is None or e.data.data is None: + continue + if inferred.get(e.data.data) in supported: + node.out_connectors[e.src_conn] = compute + + elif isinstance(node, (nodes.EntryNode, nodes.ExitNode)): + # MapEntry/Exit connectors follow the IN_/OUT_ convention; keep both in sync. + for e in state.in_edges(node): + if not (e.dst_conn and e.dst_conn.startswith("IN_")): + continue + if e.data is None or e.data.data is None: + continue + t = inferred.get(e.data.data) + if t not in supported: + continue + out_conn = "OUT_" + e.dst_conn[3:] + if e.dst_conn in node.in_connectors: + node.in_connectors[e.dst_conn] = t + if out_conn in node.out_connectors: + node.out_connectors[out_conn] = t + + +# Adds a map that copies and casts *src_name* to *dst_name* +def _add_copy_map( + state: dace.SDFGState, + src_name: str, + src_arr: dace.data.Data, + dst_name: str, + dst_arr: dace.data.Data, +) -> None: + + if src_arr.shape != dst_arr.shape: + raise ValueError( + f"Shape mismatch in _add_copy_map: " + f"{src_name}{src_arr.shape} vs {dst_name}{dst_arr.shape}" + ) + + tasklet = state.add_tasklet( + name=f"cast_{src_name}_to_{dst_name}", + inputs={"_in"}, + outputs={"_out"}, + code=f"_out = static_cast<{dst_arr.dtype.ctype}>(_in);", + language=dace.Language.CPP, + ) + + if isinstance(src_arr, dace.data.Array): + if not isinstance(dst_arr, dace.data.Array): + raise TypeError( + f"_add_copy_map: src '{src_name}' is an Array but dst '{dst_name}' is " + f"{type(dst_arr).__name__}" + ) + map_ranges = {f"_i{d}": f"0:{s}" for d, s in enumerate(src_arr.shape)} + idx = ", ".join(map_ranges) + + map_entry, map_exit = state.add_map( + name=f"cast_map_{src_name}_to_{dst_name}", + ndrange=map_ranges, + ) + map_entry.add_in_connector(f"IN_{src_name}") + map_entry.add_out_connector(f"OUT_{src_name}") + map_exit.add_in_connector(f"IN_{dst_name}") + map_exit.add_out_connector(f"OUT_{dst_name}") + + src_an = state.add_access(src_name) + dst_an = state.add_access(dst_name) + state.add_edge( + src_an, + None, + map_entry, + f"IN_{src_name}", + dace.memlet.Memlet.from_array(src_name, src_arr), + ) + state.add_edge( + map_entry, + f"OUT_{src_name}", + tasklet, + "_in", + dace.Memlet(expr=f"{src_name}[{idx}]"), + ) + state.add_edge( + tasklet, + "_out", + map_exit, + f"IN_{dst_name}", + dace.Memlet(expr=f"{dst_name}[{idx}]"), + ) + state.add_edge( + map_exit, + f"OUT_{dst_name}", + dst_an, + None, + dace.memlet.Memlet.from_array(dst_name, dst_arr), + ) + + else: + if not isinstance(src_arr, dace.data.Scalar): + raise TypeError( + f"_add_copy_map: unsupported src descriptor type " + f"{type(src_arr).__name__} for '{src_name}'" + ) + if not isinstance(dst_arr, dace.data.Scalar): + raise TypeError( + f"_add_copy_map: unsupported dst descriptor type " + f"{type(dst_arr).__name__} for '{dst_name}'" + ) + src_an = state.add_access(src_name) + dst_an = state.add_access(dst_name) + state.add_edge(src_an, None, tasklet, "_in", dace.Memlet(expr=src_name)) + state.add_edge(tasklet, "_out", dst_an, None, dace.Memlet(expr=dst_name)) + + +# Main entry point to change and propagate fp types through an SDFG. +def change_and_propagate_fp_types( + sdfg: dace.SDFG, + initial_types: Dict[str, dace.dtypes.typeclass], + promotion_rules: Dict[FrozenSet[dace.dtypes.typeclass], dace.dtypes.typeclass], +) -> None: + + supported: Set[dace.dtypes.typeclass] = set() + for pair in promotion_rules: + supported.update(pair) + + original_types: Dict[str, dace.dtypes.typeclass] = { + name: desc.dtype for name, desc in sdfg.arrays.items() + } + original_nontransients: Dict[str, dace.dtypes.typeclass] = { + name: dtype + for name, dtype in original_types.items() + if not sdfg.arrays[name].transient + } + + # Build the array-level dataflow graph and solve the precision fixed point. + producers, consumers = _build_dataflow(sdfg) + inferred = _infer_types( + sdfg, + producers, + consumers, + initial_types, + original_types, + supported, + promotion_rules, + ) + + _print_type_report(original_types, inferred) + + # Apply inferred dtypes to SDFG arrays. + for name, dtype in inferred.items(): + if dtype is not None and sdfg.arrays[name].dtype != dtype: + sdfg.arrays[name].dtype = dtype + + _apply_connector_types(sdfg, inferred, supported, promotion_rules) + + # Preserve the external interface: non-transient arrays that changed type are renamed to an internal transient. + changed_interface = { + name + for name, orig_dtype in original_nontransients.items() + if sdfg.arrays[name].dtype != orig_dtype + } + if not changed_interface: + return + + repl_dict = { + name: f"fp_casted_{name}_{sdfg.arrays[name].dtype.to_string()}" + for name in changed_interface + } + + # Snapshot read/write sets before renaming so orig_name is still meaningful. + sdfg_inputs, sdfg_outputs = sdfg.read_and_write_sets() + sdfg.replace_dict(repl_dict) + + # Reconstruct original-typed descriptors and add them back to the SDFG. + orig_descs: Dict[str, dace.data.Data] = {} + for orig_name, casted_name in repl_dict.items(): + casted_desc = sdfg.arrays[casted_name] + casted_desc.transient = True + + orig_desc = casted_desc.clone() + orig_desc.transient = False + orig_desc.dtype = original_nontransients[orig_name] + orig_descs[orig_name] = orig_desc + + sdfg.add_datadesc(name=orig_name, datadesc=orig_desc) + + # TODO: This is currently inefficient, copies all changed inputs for each sink state. + for sink in sdfg.sink_nodes(): + copy_out_state = None + for orig_name, casted_name in repl_dict.items(): + if orig_name not in sdfg_outputs: + continue + if copy_out_state is None: + copy_out_state = sdfg.add_state_after( + state=sink, label=f"copy_out_{sink.label}" + ) + _add_copy_map( + copy_out_state, + casted_name, + sdfg.arrays[casted_name], + orig_name, + orig_descs[orig_name], + ) + + copy_in_state = None + for orig_name, casted_name in repl_dict.items(): + if orig_name not in sdfg_inputs: + continue + if copy_in_state is None: + copy_in_state = sdfg.add_state_before( + state=sdfg.start_block, label="copy_in" + ) + _add_copy_map( + copy_in_state, + orig_name, + orig_descs[orig_name], + casted_name, + sdfg.arrays[casted_name], + ) diff --git a/tests/test_change_and_propagate_fp_types.py b/tests/test_change_and_propagate_fp_types.py new file mode 100644 index 0000000..b989571 --- /dev/null +++ b/tests/test_change_and_propagate_fp_types.py @@ -0,0 +1,460 @@ +import dace +from dace.libraries.standard.nodes.reduce import Reduce +from fp_arena.transformations.change_and_propagate_fp_types import ( + change_and_propagate_fp_types, +) + +RULES = { + frozenset({dace.float16, dace.float32}): dace.float32, + frozenset({dace.float32, dace.float64}): dace.float64, + frozenset({dace.float16, dace.float64}): dace.float64, +} + + +def _tasklet_chain(n_states: int, transient_intermediates: bool = True): + """ + Build an SDFG with a linear chain of states: + A -> [T0 -> B0] -> [T1 -> B1] -> ... -> [T_{n-1} -> B_{n-1}] + Returns (sdfg, 'A', 'B0', ..., 'B_{n-1}'). + """ + sdfg = dace.SDFG("chain") + sdfg.add_array("A", [1], dace.float32, transient=False) + arr_names = ["A"] + states = [] + prev_name = "A" + for i in range(n_states): + out_name = f"B{i}" + sdfg.add_array( + out_name, + [1], + dace.float32, + transient=(transient_intermediates and i < n_states - 1), + ) + arr_names.append(out_name) + s = sdfg.add_state(f"s{i}") + states.append(s) + if i > 0: + sdfg.add_edge(states[-2], s, dace.InterstateEdge()) + t = s.add_tasklet(f"t{i}", {"x"}, {"y"}, "y = x") + s.add_edge(s.add_read(prev_name), None, t, "x", dace.Memlet(f"{prev_name}[0]")) + s.add_edge(t, "y", s.add_write(out_name), None, dace.Memlet(f"{out_name}[0]")) + prev_name = out_name + return sdfg, arr_names + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +def test_transient_intermediate_propagates(): + """Type propagates from a non-transient input through a transient intermediate to the output.""" + sdfg = dace.SDFG("test") + sdfg.add_array("A", [1], dace.float32, transient=False) + sdfg.add_array("B", [1], dace.float32, transient=True) + sdfg.add_array("C", [1], dace.float32, transient=False) + + s1 = sdfg.add_state("s1") + s2 = sdfg.add_state("s2") + sdfg.add_edge(s1, s2, dace.InterstateEdge()) + + t1 = s1.add_tasklet("t1", {"a"}, {"b"}, "b = a") + s1.add_edge(s1.add_read("A"), None, t1, "a", dace.Memlet("A[0]")) + s1.add_edge(t1, "b", s1.add_write("B"), None, dace.Memlet("B[0]")) + + t2 = s2.add_tasklet("t2", {"b"}, {"c"}, "c = b") + s2.add_edge(s2.add_read("B"), None, t2, "b", dace.Memlet("B[0]")) + s2.add_edge(t2, "c", s2.add_write("C"), None, dace.Memlet("C[0]")) + + change_and_propagate_fp_types(sdfg, {"A": dace.float16}, RULES) + + # Transient intermediate should be promoted to f16. + assert sdfg.arrays["B"].dtype == dace.float16, sdfg.arrays["B"].dtype + + # Non-transient interface: external type preserved (f32), internal casted array exists. + assert sdfg.arrays["A"].dtype == dace.float32 + assert sdfg.arrays["C"].dtype == dace.float32 + assert "fp_casted_A_float16" in sdfg.arrays + assert "fp_casted_C_float16" in sdfg.arrays + assert sdfg.arrays["fp_casted_A_float16"].dtype == dace.float16 + assert sdfg.arrays["fp_casted_C_float16"].dtype == dace.float16 + + sdfg.validate() + # sdfg.compile() TO + + +def test_all_nontransient_interface_preserved(): + """All non-transient: all arrays get cast wrappers, none change dtype externally.""" + sdfg = dace.SDFG("test_state_order") + sdfg.add_array("A", [1], dace.float32) + sdfg.add_array("B", [1], dace.float32) + sdfg.add_array("C", [1], dace.float32) + + # States added in reverse alphabetical order to exercise topological sort. + state_C = sdfg.add_state("state_C") + state_B = sdfg.add_state("state_B") + state_A = sdfg.add_state("state_A") + sdfg.add_edge(state_A, state_B, dace.InterstateEdge()) + sdfg.add_edge(state_B, state_C, dace.InterstateEdge()) + + ta = state_A.add_tasklet("ta", {"a"}, {"b"}, "b = a") + state_A.add_edge(state_A.add_read("A"), None, ta, "a", dace.Memlet("A[0]")) + state_A.add_edge(ta, "b", state_A.add_write("B"), None, dace.Memlet("B[0]")) + + tc = state_C.add_tasklet("tc", {"b"}, {"c"}, "c = b") + state_C.add_edge(state_C.add_read("B"), None, tc, "b", dace.Memlet("B[0]")) + state_C.add_edge(tc, "c", state_C.add_write("C"), None, dace.Memlet("C[0]")) + + change_and_propagate_fp_types(sdfg, {"A": dace.float16}, RULES) + + # All are non-transient: external types preserved. + assert sdfg.arrays["A"].dtype == dace.float32 + assert sdfg.arrays["B"].dtype == dace.float32 + assert sdfg.arrays["C"].dtype == dace.float32 + # Casted versions exist for the changed arrays. + assert "fp_casted_A_float16" in sdfg.arrays + assert "fp_casted_C_float16" in sdfg.arrays + + sdfg.validate() + # sdfg.compile() TODO: Compilation of half-precision reduction currently fails + + +def test_mixed_precision_promotes(): + """When a f16 and f32 array feed the same tasklet, the output is f32""" + sdfg = dace.SDFG("mixed") + sdfg.add_array("A", [1], dace.float32, transient=False) + sdfg.add_array("D", [1], dace.float32, transient=False) + sdfg.add_array("E", [1], dace.float32, transient=True) + + s = sdfg.add_state("s") + t = s.add_tasklet("t", {"a", "d"}, {"e"}, "e = a + d") + s.add_edge(s.add_read("A"), None, t, "a", dace.Memlet("A[0]")) + s.add_edge(s.add_read("D"), None, t, "d", dace.Memlet("D[0]")) + s.add_edge(t, "e", s.add_write("E"), None, dace.Memlet("E[0]")) + + change_and_propagate_fp_types(sdfg, {"A": dace.float16}, RULES) + + # D stays f32, A is demoted to f16; their mix promotes E to f32. + assert sdfg.arrays["E"].dtype == dace.float32, sdfg.arrays["E"].dtype + + sdfg.validate() + # sdfg.compile() TODO: Compilation of half-precision reduction currently fails + + +def test_map_passthrough(): + """Type flows through MapEntry and MapExit connectors.""" + sdfg = dace.SDFG("map_test") + sdfg.add_array("A", [4], dace.float32, transient=False) + sdfg.add_array("B", [4], dace.float32, transient=True) + + s = sdfg.add_state("s") + me, mx = s.add_map("m", {"i": "0:4"}) + t = s.add_tasklet("t", {"a"}, {"b"}, "b = a * 2.0;", language=dace.Language.CPP) + + a_an = s.add_read("A") + b_an = s.add_write("B") + me.add_in_connector("IN_A") + me.add_out_connector("OUT_A") + mx.add_in_connector("IN_B") + mx.add_out_connector("OUT_B") + + s.add_edge(a_an, None, me, "IN_A", dace.Memlet("A[0:4]")) + s.add_edge(me, "OUT_A", t, "a", dace.Memlet("A[i]")) + s.add_edge(t, "b", mx, "IN_B", dace.Memlet("B[i]")) + s.add_edge(mx, "OUT_B", b_an, None, dace.Memlet("B[0:4]")) + + change_and_propagate_fp_types(sdfg, {"A": dace.float16}, RULES) + + assert sdfg.arrays["B"].dtype == dace.float16, sdfg.arrays["B"].dtype + + sdfg.validate() + # sdfg.compile() TODO: Compilation of half-precision reduction currently fails + + +def test_reduce_node(): + """Type propagates through a Reduce library node.""" + sdfg = dace.SDFG("reduce_test") + sdfg.add_array("A", [4], dace.float32, transient=False) + sdfg.add_scalar("S", dace.float32, transient=True) + + s = sdfg.add_state("s") + reduce_node = Reduce("sum", "lambda a, b: a + b", axes=[0], identity=0) + s.add_node(reduce_node) + s.add_edge(s.add_read("A"), None, reduce_node, None, dace.Memlet("A[0:4]")) + s.add_edge(reduce_node, None, s.add_write("S"), None, dace.Memlet("S")) + + change_and_propagate_fp_types(sdfg, {"A": dace.float16}, RULES) + + assert sdfg.arrays["S"].dtype == dace.float16, sdfg.arrays["S"].dtype + + sdfg.validate() + # sdfg.compile() TODO: Compilation of half-precision reduction currently fails + + +def test_initial_type_pinned(): + """An array listed in initial_types keeps that type even if higher-precision data flows into it.""" + sdfg = dace.SDFG("pinned") + sdfg.add_array("A", [1], dace.float64, transient=False) + sdfg.add_array("B", [1], dace.float32, transient=True) + + s = sdfg.add_state("s") + t = s.add_tasklet("t", {"a"}, {"b"}, "b = a") + s.add_edge(s.add_read("A"), None, t, "a", dace.Memlet("A[0]")) + s.add_edge(t, "b", s.add_write("B"), None, dace.Memlet("B[0]")) + + # Even though A is f64, B must stay f16. + change_and_propagate_fp_types(sdfg, {"A": dace.float64, "B": dace.float16}, RULES) + + assert sdfg.arrays["B"].dtype == dace.float16, sdfg.arrays["B"].dtype + + sdfg.validate() + # sdfg.compile() TODO: Compilation of half-precision reduction currently fails + + +def test_long_chain_convergence(): + """Fixpoint converges for a longer state chain without hitting the iteration cap.""" + sdfg, arr_names = _tasklet_chain(n_states=5, transient_intermediates=True) + + change_and_propagate_fp_types(sdfg, {"A": dace.float16}, RULES) + + for name in arr_names[1:-1]: # B0..B3 are transient intermediates + assert sdfg.arrays[name].dtype == dace.float16, ( + f"{name}: {sdfg.arrays[name].dtype}" + ) + + sdfg.validate() + # sdfg.compile() TODO: Compilation of half-precision reduction currently fails + + +def test_unconnected_array_unchanged(): + """Arrays with no data-flow path from initial_types are not modified.""" + sdfg = dace.SDFG("unconnected") + sdfg.add_array("A", [1], dace.float32, transient=False) + sdfg.add_array("B", [1], dace.float32, transient=True) + sdfg.add_array("X", [1], dace.float64, transient=True) # not connected to A + + s1 = sdfg.add_state("s1") + s2 = sdfg.add_state("s2") + sdfg.add_edge(s1, s2, dace.InterstateEdge()) + + t1 = s1.add_tasklet("t1", {"a"}, {"b"}, "b = a") + s1.add_edge(s1.add_read("A"), None, t1, "a", dace.Memlet("A[0]")) + s1.add_edge(t1, "b", s1.add_write("B"), None, dace.Memlet("B[0]")) + + # s2 works on X independently. + t2 = s2.add_tasklet("t2", {"x"}, {"y"}, "y = x") + s2.add_edge(s2.add_read("X"), None, t2, "x", dace.Memlet("X[0]")) + s2.add_edge(t2, "y", s2.add_write("X"), None, dace.Memlet("X[0]")) + + change_and_propagate_fp_types(sdfg, {"A": dace.float16}, RULES) + + assert sdfg.arrays["B"].dtype == dace.float16 + assert sdfg.arrays["X"].dtype == dace.float64 # unchanged + + sdfg.validate() + # sdfg.compile() TODO: Compilation of half-precision reduction currently fails + + +def test_interface_copy_in_only_for_inputs(): + """Arrays that are only written (output-only) do not get a copy-in state.""" + sdfg = dace.SDFG("output_only") + sdfg.add_array("A", [1], dace.float32, transient=False) + sdfg.add_array("B", [1], dace.float32, transient=False) + + s = sdfg.add_state("s") + t = s.add_tasklet("t", {}, {"b"}, "b = 1.0;", language=dace.Language.CPP) + # B is purely written (no read from outside), A is not used. + s.add_edge(t, "b", s.add_write("B"), None, dace.Memlet("B[0]")) + + change_and_propagate_fp_types(sdfg, {"B": dace.float16}, RULES) + + # B is non-transient and changed: external B stays f32, casted version is f16. + assert sdfg.arrays["B"].dtype == dace.float32 + casted_name = "fp_casted_B_float16" + assert casted_name in sdfg.arrays + + # B is output-only and A is unused, so nothing needs casting on the way in: + # the copy_in state is created lazily and should not exist at all here. + copy_in = next((st for st in sdfg.states() if st.label == "copy_in"), None) + assert copy_in is None, ( + "Output-only arrays should not produce a copy_in state" + ) + + sdfg.validate() + # sdfg.compile() TODO: Compilation of half-precision reduction currently fails + + +def test_requires_two_fixpoint_passes(): + """ + Verify that the fixpoint loop runs at least two passes. In the first pass, B is promoted to f32 due to the write from E; in the second pass, C is promoted to f32 due to the read from B. + """ + sdfg = dace.SDFG("two_pass_required") + sdfg.add_array("A", [1], dace.float32, transient=False) + sdfg.add_array("E", [1], dace.float32, transient=False) + sdfg.add_array("B", [1], dace.float32, transient=True) + sdfg.add_array("C", [1], dace.float32, transient=True) + + s1 = sdfg.add_state("s1") + s2 = sdfg.add_state("s2") + s3 = sdfg.add_state("s3") + sdfg.add_edge(s1, s2, dace.InterstateEdge()) + sdfg.add_edge(s2, s3, dace.InterstateEdge()) + + # S1: A -> B + t1 = s1.add_tasklet("t1", {"a"}, {"b"}, "b = a") + s1.add_edge(s1.add_read("A"), None, t1, "a", dace.Memlet("A[0]")) + s1.add_edge(t1, "b", s1.add_write("B"), None, dace.Memlet("B[0]")) + + # S2: B -> C (visited before S3 promotes B to f32) + t2 = s2.add_tasklet("t2", {"b"}, {"c"}, "c = b") + s2.add_edge(s2.add_read("B"), None, t2, "b", dace.Memlet("B[0]")) + s2.add_edge(t2, "c", s2.add_write("C"), None, dace.Memlet("C[0]")) + + # S3: E(f32) -> B (second write; promotes B from f16 to f32) + t3 = s3.add_tasklet("t3", {"e"}, {"b"}, "b = e") + s3.add_edge(s3.add_read("E"), None, t3, "e", dace.Memlet("E[0]")) + s3.add_edge(t3, "b", s3.add_write("B"), None, dace.Memlet("B[0]")) + + change_and_propagate_fp_types(sdfg, {"A": dace.float16}, RULES) + + # B is written by both S1(f16) and S3(f32) -> promoted to f32. + assert sdfg.arrays["B"].dtype == dace.float32, ( + f"B should be f32 (promoted), got {sdfg.arrays['B'].dtype}" + ) + # C reads from B; only correct after the second pass propagates B=f32 into S2. + assert sdfg.arrays["C"].dtype == dace.float32, ( + f"C should be f32 (requires two passes), got {sdfg.arrays['C'].dtype} — " + "this failure means the fixpoint loop ran only once" + ) + + sdfg.validate() + # sdfg.compile() TODO: Compilation of half-precision reduction currently fails + + +def test_three_level_lattice(): + """With three precision levels, each array takes the widest (join) of its producers.""" + sdfg = dace.SDFG("three_level") + sdfg.add_array("A", [1], dace.float32, transient=False) # pinned f16 + sdfg.add_array("B", [1], dace.float32, transient=False) # source f32 + sdfg.add_array("C", [1], dace.float64, transient=False) # source f64 + sdfg.add_array("AB", [1], dace.float32, transient=True) # join(f16, f32) = f32 + sdfg.add_array("AC", [1], dace.float32, transient=True) # join(f16, f64) = f64 + sdfg.add_array("ABC", [1], dace.float32, transient=True) # join(f32, f64) = f64 + + s = sdfg.add_state("s") + ab = s.add_access("AB") + + t1 = s.add_tasklet("t1", {"a", "b"}, {"o"}, "o = a + b") + s.add_edge(s.add_read("A"), None, t1, "a", dace.Memlet("A[0]")) + s.add_edge(s.add_read("B"), None, t1, "b", dace.Memlet("B[0]")) + s.add_edge(t1, "o", ab, None, dace.Memlet("AB[0]")) + + t2 = s.add_tasklet("t2", {"a", "c"}, {"o"}, "o = a + c") + s.add_edge(s.add_read("A"), None, t2, "a", dace.Memlet("A[0]")) + s.add_edge(s.add_read("C"), None, t2, "c", dace.Memlet("C[0]")) + s.add_edge(t2, "o", s.add_write("AC"), None, dace.Memlet("AC[0]")) + + t3 = s.add_tasklet("t3", {"ab", "c"}, {"o"}, "o = ab + c") + s.add_edge(ab, None, t3, "ab", dace.Memlet("AB[0]")) + s.add_edge(s.add_read("C"), None, t3, "c", dace.Memlet("C[0]")) + s.add_edge(t3, "o", s.add_write("ABC"), None, dace.Memlet("ABC[0]")) + + change_and_propagate_fp_types(sdfg, {"A": dace.float16}, RULES) + + assert sdfg.arrays["AB"].dtype == dace.float32, sdfg.arrays["AB"].dtype + assert sdfg.arrays["AC"].dtype == dace.float64, sdfg.arrays["AC"].dtype + assert sdfg.arrays["ABC"].dtype == dace.float64, sdfg.arrays["ABC"].dtype + + sdfg.validate() + # sdfg.compile() TODO: Compilation of half-precision reduction currently fails + + +def test_cyclic_dependency_terminates(): + """A dependency cycle that no seed reaches keeps original precision and terminates.""" + sdfg = dace.SDFG("cycle") + sdfg.add_array("P", [1], dace.float64, transient=True) + sdfg.add_array("Q", [1], dace.float64, transient=True) + + s1 = sdfg.add_state("s1") + s2 = sdfg.add_state("s2") + sdfg.add_edge(s1, s2, dace.InterstateEdge()) + + t1 = s1.add_tasklet("t1", {"q"}, {"p"}, "p = q") + s1.add_edge(s1.add_read("Q"), None, t1, "q", dace.Memlet("Q[0]")) + s1.add_edge(t1, "p", s1.add_write("P"), None, dace.Memlet("P[0]")) + + t2 = s2.add_tasklet("t2", {"p"}, {"q"}, "q = p") + s2.add_edge(s2.add_read("P"), None, t2, "p", dace.Memlet("P[0]")) + s2.add_edge(t2, "q", s2.add_write("Q"), None, dace.Memlet("Q[0]")) + + change_and_propagate_fp_types(sdfg, {}, RULES) + + assert sdfg.arrays["P"].dtype == dace.float64, sdfg.arrays["P"].dtype + assert sdfg.arrays["Q"].dtype == dace.float64, sdfg.arrays["Q"].dtype + + sdfg.save("test_sdfg") + + sdfg.validate() + sdfg.compile() + + +def test_end_to_end_float32_runs(): + """Full pipeline: transform (f64->f32), compile, run, and check numerics.""" + import numpy as np + + n = 8 + sdfg = dace.SDFG("e2e_f32") + sdfg.add_array("A", [n], dace.float64, transient=False) + sdfg.add_array("B", [n], dace.float64, transient=True) + sdfg.add_array("C", [n], dace.float64, transient=False) + + s1 = sdfg.add_state("s1") + s2 = sdfg.add_state("s2") + sdfg.add_edge(s1, s2, dace.InterstateEdge()) + + for st, src, dst in [(s1, "A", "B"), (s2, "B", "C")]: + me, mx = st.add_map("m", {"i": f"0:{n}"}) + t = st.add_tasklet( + "t", {"x"}, {"y"}, "y = x * 2.0;", language=dace.Language.CPP + ) + me.add_in_connector(f"IN_{src}") + me.add_out_connector(f"OUT_{src}") + mx.add_in_connector(f"IN_{dst}") + mx.add_out_connector(f"OUT_{dst}") + st.add_edge( + st.add_read(src), None, me, f"IN_{src}", dace.Memlet(f"{src}[0:{n}]") + ) + st.add_edge(me, f"OUT_{src}", t, "x", dace.Memlet(f"{src}[i]")) + st.add_edge(t, "y", mx, f"IN_{dst}", dace.Memlet(f"{dst}[i]")) + st.add_edge( + mx, f"OUT_{dst}", st.add_write(dst), None, dace.Memlet(f"{dst}[0:{n}]") + ) + + change_and_propagate_fp_types(sdfg, {"A": dace.float32}, RULES) + sdfg.validate() + + assert sdfg.arrays["B"].dtype == dace.float32, sdfg.arrays["B"].dtype + assert sdfg.arrays["A"].dtype == dace.float64 + assert sdfg.arrays["C"].dtype == dace.float64 + + A = np.arange(n, dtype=np.float64) + C = np.zeros(n, dtype=np.float64) + sdfg(A=A, C=C) + np.testing.assert_allclose(C, A * 4.0) + + +if __name__ == "__main__": + test_transient_intermediate_propagates() + test_all_nontransient_interface_preserved() + test_mixed_precision_promotes() + test_map_passthrough() + test_reduce_node() + test_initial_type_pinned() + test_long_chain_convergence() + test_unconnected_array_unchanged() + test_interface_copy_in_only_for_inputs() + test_requires_two_fixpoint_passes() + test_three_level_lattice() + test_cyclic_dependency_terminates() + test_end_to_end_float32_runs() + print("All tests passed.") diff --git a/tests/test_mpfr.py b/tests/test_mpfr.py new file mode 100644 index 0000000..e9cb21d --- /dev/null +++ b/tests/test_mpfr.py @@ -0,0 +1,325 @@ +import numpy as np +import dace +import fp_arena +from fp_arena.transformations.change_and_propagate_fp_types import ( + change_and_propagate_fp_types, +) + +dace.Config.append("compiler", "cpu", "libs", value="mpfr") + + +def test_sdfg_scalar_compute(): + """Compute 1/3 at 128-bit precision and return as double.""" + sdfg = dace.SDFG("mpfr_scalar") + state = sdfg.add_state() + + sdfg.add_scalar("tmp", dace.mpfr(128), transient=True) + sdfg.add_array("out", [1], dace.float64, transient=False) + + tmp_node = state.add_access("tmp") + out_node = state.add_write("out") + + init = state.add_tasklet( + "init_mpfr", + {}, + {"t"}, + "t = 1.0 / 3.0;", + language=dace.Language.CPP, + ) + conv = state.add_tasklet( + "to_double", + {"t"}, + {"o"}, + "o = (double)t;", + language=dace.Language.CPP, + ) + + state.add_edge(init, "t", tmp_node, None, dace.Memlet("tmp")) + state.add_edge(tmp_node, None, conv, "t", dace.Memlet("tmp")) + state.add_edge(conv, "o", out_node, None, dace.Memlet("out[0]")) + + csdfg = sdfg.compile() + + out = np.zeros(1, dtype=np.float64) + csdfg(out=out) + assert abs(out[0] - 1.0 / 3.0) < 1e-15, f"Expected ~0.333…, got {out[0]}" + + +def test_sdfg_array_sum(): + """Fill an mpfr array with values 1..N, sum into a double output.""" + N = 4 + sdfg = dace.SDFG("mpfr_array_sum") + state = sdfg.add_state() + + sdfg.add_array("arr", [N], dace.mpfr(128), transient=True) + sdfg.add_array("out", [1], dace.float64, transient=False) + + arr_node = state.add_access("arr") + out_node = state.add_write("out") + + fill = state.add_tasklet( + "fill", + {}, + {"a"}, + "\n".join(f"a[{i}] = {i + 1};" for i in range(N)), + language=dace.Language.CPP, + ) + sumup = state.add_tasklet( + "sum", + {"a"}, + {"o"}, + f"""dace::mpfr<128> acc(0); for (int i = 0; i < {N}; ++i) acc += a[i]; o = (double)acc;""", + language=dace.Language.CPP, + ) + + state.add_edge(fill, "a", arr_node, None, dace.Memlet(f"arr[0:{N}]")) + state.add_edge(arr_node, None, sumup, "a", dace.Memlet(f"arr[0:{N}]")) + state.add_edge(sumup, "o", out_node, None, dace.Memlet("out[0]")) + + csdfg = sdfg.compile() + + out = np.zeros(1, dtype=np.float64) + csdfg(out=out) + expected = N * (N + 1) / 2 # 1+2+3+4 = 10 + assert abs(out[0] - expected) < 1e-15, f"Expected {expected}, got {out[0]}" + + +def test_sdfg_binary_ops(): + """Test each binary operator (+, -, *, /) independently.""" + cases = [ + ("add", "lhs + rhs", 7.0, 3.0, 10.0), + ("sub", "lhs - rhs", 7.0, 3.0, 4.0), + ("mul", "lhs * rhs", 7.0, 3.0, 21.0), + ("div", "lhs / rhs", 9.0, 3.0, 3.0), + ] + for op_name, expr, lv, rv, expected in cases: + sdfg = dace.SDFG(f"mpfr_{op_name}") + state = sdfg.add_state() + + sdfg.add_array("out", [1], dace.float64, transient=False) + out_node = state.add_write("out") + + tasklet = state.add_tasklet( + op_name, + {}, + {"o"}, + f"dace::mpfr<128> lhs({lv}), rhs({rv}), res = {expr}; o = (double)res;", + language=dace.Language.CPP, + ) + state.add_edge(tasklet, "o", out_node, None, dace.Memlet("out[0]")) + + csdfg = sdfg.compile() + + out = np.zeros(1, dtype=np.float64) + csdfg(out=out) + assert abs(out[0] - expected) < 1e-14, ( + f"operator{op_name}: expected {expected}, got {out[0]}" + ) + + +def test_sdfg_copy(): + """Copy one mpfr scalar into another and verify the value is preserved.""" + sdfg = dace.SDFG("mpfr_copy") + state = sdfg.add_state() + + sdfg.add_scalar("src", dace.mpfr(128), transient=True) + sdfg.add_scalar("dst", dace.mpfr(128), transient=True) + sdfg.add_array("out", [1], dace.float64, transient=False) + + src_node = state.add_access("src") + dst_node = state.add_access("dst") + out_node = state.add_write("out") + + init = state.add_tasklet( + "init", {}, {"s"}, "s = 1.0 / 3.0;", language=dace.Language.CPP + ) + copy = state.add_tasklet("copy", {"s"}, {"d"}, "d = s;", language=dace.Language.CPP) + conv = state.add_tasklet( + "conv", {"d"}, {"o"}, "o = (double)d;", language=dace.Language.CPP + ) + + state.add_edge(init, "s", src_node, None, dace.Memlet("src")) + state.add_edge(src_node, None, copy, "s", dace.Memlet("src")) + state.add_edge(copy, "d", dst_node, None, dace.Memlet("dst")) + state.add_edge(dst_node, None, conv, "d", dace.Memlet("dst")) + state.add_edge(conv, "o", out_node, None, dace.Memlet("out[0]")) + + csdfg = sdfg.compile() + + out = np.zeros(1, dtype=np.float64) + csdfg(out=out) + assert abs(out[0] - 1.0 / 3.0) < 1e-15, f"Copy changed value: got {out[0]}" + + +_MPFR128 = dace.mpfr(128) +_MPFR_RULES = {frozenset({dace.float64, _MPFR128}): _MPFR128} + + +@dace.program +def _prog_double(x: dace.float64[1], y: dace.float64[1]): + y[0] = x[0] * 2.0 + + +@dace.program +def _prog_add_one(A: dace.float64[4], B: dace.float64[4]): + for i in dace.map[0:4]: + B[i] = A[i] + 1.0 + + +@dace.program +def _prog_chain(A: dace.float64[1], tmp: dace.float64[1], B: dace.float64[1]): + tmp[0] = A[0] + 0.5 + B[0] = tmp[0] * 2.0 + + +def test_cap_elementwise(): + """change_and_propagate: float64 scalar doubled via mpfr(128) internally.""" + sdfg = _prog_double.to_sdfg() + change_and_propagate_fp_types(sdfg, {"x": _MPFR128}, _MPFR_RULES) + + # External interface stays float64; internal casted arrays are mpfr(128). + assert sdfg.arrays["x"].dtype == dace.float64 + assert "fp_casted_x_mpfr128" in sdfg.arrays + assert sdfg.arrays["fp_casted_x_mpfr128"].dtype == _MPFR128 + + csdfg = sdfg.compile() + x = np.array([1.5], dtype=np.float64) + y = np.zeros(1, dtype=np.float64) + csdfg(x=x, y=y) + assert abs(y[0] - 3.0) < 1e-15, f"Expected 3.0, got {y[0]}" + + +def test_cap_array_map(): + """change_and_propagate: float64 array map promoted to mpfr(128).""" + sdfg = _prog_add_one.to_sdfg() + change_and_propagate_fp_types(sdfg, {"A": _MPFR128}, _MPFR_RULES) + + assert sdfg.arrays["A"].dtype == dace.float64 + assert sdfg.arrays["fp_casted_A_mpfr128"].dtype == _MPFR128 + + csdfg = sdfg.compile() + A = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float64) + B = np.zeros(4, dtype=np.float64) + csdfg(A=A, B=B) + np.testing.assert_allclose(B, A + 1.0, atol=1e-15) + + +def test_cap_chain_transient(): + """change_and_propagate: all arrays promoted to mpfr(128); float64 interface preserved.""" + sdfg = _prog_chain.to_sdfg() + change_and_propagate_fp_types(sdfg, {"A": _MPFR128}, _MPFR_RULES) + + # All three non-transients are changed, so each gets a cast wrapper; external type stays float64. + for name in ("A", "tmp", "B"): + assert sdfg.arrays[name].dtype == dace.float64 + assert sdfg.arrays[f"fp_casted_{name}_mpfr128"].dtype == _MPFR128 + + csdfg = sdfg.compile() + A = np.array([2.0], dtype=np.float64) + tmp = np.zeros(1, dtype=np.float64) + B = np.zeros(1, dtype=np.float64) + csdfg(A=A, tmp=tmp, B=B) + # (2.0 + 0.5) * 2.0 = 5.0 + assert abs(B[0] - 5.0) < 1e-15, f"Expected 5.0, got {B[0]}" + + +# ── exponent-bits tests ────────────────────────────────────────────────────── +# +# All tests use 5 exponent bits, giving: +# emax = 2^(5-1) - 1 = 15 +# emin = 1 - 15 = -14 +# +# MPFR exponent convention: x = m * 2^e with 1/2 <= |m| < 1. +# 16384 = 2^14 = 0.5 * 2^15 → MPFR exponent 15 (= emax, largest normal) +# 32768 = 2^15 = 0.5 * 2^16 → MPFR exponent 16 (> emax → overflow → +inf) +# +# Minimum subnormal exponent: emin - (precision - 1) = -14 - 127 = -141. +# Starting from 1.0 and dividing by 2 n times gives MPFR exponent -(n-1); +# at n=143 the exponent reaches -142 < -141 and the value flushes to 0. + + +def _make_exp_bits_sdfg(name: str, body: str) -> dace.SDFG: + """Build a single-tasklet SDFG that writes one double output.""" + sdfg = dace.SDFG(name) + state = sdfg.add_state() + sdfg.add_array("out", [1], dace.float64) + out_node = state.add_write("out") + tasklet = state.add_tasklet("op", {}, {"o"}, body, language=dace.Language.CPP) + state.add_edge(tasklet, "o", out_node, None, dace.Memlet("out[0]")) + return sdfg + + +def test_exponent_bits_value_in_range(): + """With emax=15, 1.0 + 1.0 = 2.0 is within range and unchanged.""" + sdfg = _make_exp_bits_sdfg( + "mpfr_exp_in_range", + "dace::set_mpfr_exponent_bits(5);\n" + "dace::mpfr<128> a(1.0), b(1.0);\n" + "o = (double)(a + b);", + ) + csdfg = sdfg.compile() + out = np.zeros(1, dtype=np.float64) + csdfg(out=out) + assert abs(out[0] - 2.0) < 1e-15, f"Expected 2.0, got {out[0]}" + + +def test_exponent_bits_overflow_construction(): + """With emax=15, constructing 32768.0 (MPFR exponent 16) gives +inf.""" + sdfg = _make_exp_bits_sdfg( + "mpfr_exp_overflow_ctor", + "dace::set_mpfr_exponent_bits(5);\n" + "dace::mpfr<128> a(32768.0);\n" # 2^15, MPFR exp 16 > emax=15 + "o = (double)a;", + ) + csdfg = sdfg.compile() + out = np.zeros(1, dtype=np.float64) + csdfg(out=out) + assert np.isposinf(out[0]), f"Expected +inf, got {out[0]}" + + +def test_exponent_bits_overflow_arithmetic(): + """With emax=15, 16384.0 * 2.0 = 32768.0 (MPFR exponent 16) gives +inf.""" + sdfg = _make_exp_bits_sdfg( + "mpfr_exp_overflow_arith", + "dace::set_mpfr_exponent_bits(5);\n" + "dace::mpfr<128> a(16384.0), b(2.0);\n" # result exp 16 > emax=15 + "o = (double)(a * b);", + ) + csdfg = sdfg.compile() + out = np.zeros(1, dtype=np.float64) + csdfg(out=out) + assert np.isposinf(out[0]), f"Expected +inf, got {out[0]}" + + +def test_exponent_bits_underflow_to_zero(): + """With emin=-14 and precision=128, values below the min subnormal flush to 0. + + Dividing 1.0 by 2 repeatedly: after n steps the MPFR exponent is -(n-1). + At n=143 the exponent is -142 < emin-(precision-1)=-141, so it rounds to 0. + """ + sdfg = _make_exp_bits_sdfg( + "mpfr_exp_underflow", + "dace::set_mpfr_exponent_bits(5);\n" + "dace::mpfr<128> a(1.0), two(2.0);\n" + "for (int i = 0; i < 200; ++i) a /= two;\n" + "o = (double)a;", + ) + csdfg = sdfg.compile() + out = np.zeros(1, dtype=np.float64) + csdfg(out=out) + assert out[0] == 0.0, f"Expected 0.0, got {out[0]}" + + +if __name__ == "__main__": + test_sdfg_scalar_compute() + test_sdfg_array_sum() + test_sdfg_binary_ops() + test_sdfg_copy() + test_cap_elementwise() + test_cap_array_map() + test_cap_chain_transient() + test_exponent_bits_value_in_range() + test_exponent_bits_overflow_construction() + test_exponent_bits_overflow_arithmetic() + test_exponent_bits_underflow_to_zero() + print("All SDFG tests passed.")