Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/fp-arena-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ jobs:
sudo apt-get update
# gcc/g++ (C++20) + cmake/ninja drive DaCe's CPU codegen build; the
# generated maps use OpenMP (libgomp).
sudo apt-get install -y cmake ninja-build gcc g++ libgomp1
sudo apt-get install -y cmake ninja-build gcc g++ libgomp1 libmpfr-dev

- name: Install DaCe (yakup/dev) and FP-Arena
run: |
Expand Down
4 changes: 4 additions & 0 deletions fp_arena/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
Float64sr,
float32sr,
float64sr,
mpfr,
register,
FP_ARENA_TYPECLASSES,
)
Expand All @@ -29,6 +30,7 @@
INCLUDE_DIR,
)
from fp_arena.transformations.change_fp_types import change_fptype
from fp_arena.transformations.change_and_propagate_fp_types import change_and_propagate_fp_types

# Register the types and the SDFG convenience method on import (idempotent).
register()
Expand All @@ -47,6 +49,7 @@
"Float64sr",
"float32sr",
"float64sr",
"mpfr",
"register",
"FP_ARENA_TYPECLASSES",
"enable_fp_arena_extensions",
Expand All @@ -59,4 +62,5 @@
"fp_arena_global_code",
"INCLUDE_DIR",
"change_fptype",
"change_and_propagate_fp_types",
]
61 changes: 60 additions & 1 deletion fp_arena/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
"""

import numpy
import ctypes

import dace
from dace import dtypes as _ddtypes
from dace import dtypes as _ddtypes, typeclass

#: C++ namespace-qualified type names emitted into generated code.
_FLOAT32SR_CTYPE = "fp_arena::float32sr"
Expand Down Expand Up @@ -77,6 +78,60 @@ def __repr__(self) -> str:
}


class _mpfr_t(ctypes.Structure):
_fields_ = [
("_mpfr_prec", ctypes.c_long),
("_mpfr_sign", ctypes.c_int),
("_mpfr_exp", ctypes.c_long),
("_mpfr_d", ctypes.c_void_p),
]


class mpfr(typeclass):
"""
A data type for custom Multiple Precision Floating-Point (MPFR) types.

Example use: `dace.mpfr(128)` for 128-bit precision.
"""

def __init__(self, precision: int):
self.precision = precision
self.type = numpy.object_
self.bytes = ctypes.sizeof(_mpfr_t)
self.dtype = self
self.typename = f"mpfr{precision}"

def to_string(self):
return self.typename

def to_json(self):
return {"type": "mpfr", "precision": self.precision}

@staticmethod
def from_json(json_obj, context=None):
if json_obj["type"] != "mpfr":
raise TypeError("Invalid type for mpfr")
return mpfr(json_obj["precision"])

@property
def ctype(self):
return f"dace::mpfr<{self.precision}>"

@property
def ctype_unaligned(self):
return self.ctype

def as_ctypes(self):
return ctypes.c_void_p

def as_numpy_dtype(self):
return numpy.dtype(numpy.object_)

@property
def base_type(self):
return self


def register():
"""
Register the FP-Arena types into DaCe's global registries.
Expand All @@ -100,4 +155,8 @@ def register():
_ddtypes.TYPECLASS_STRINGS.append(name)
_ddtypes.TYPECLASS_TO_STRING.setdefault(tc, tc.ctype)

# Also expose the parametric mpfr class so `dace.mpfr(128)` works.
setattr(_ddtypes, "mpfr", mpfr)
setattr(dace, "mpfr", mpfr)

return FP_ARENA_TYPECLASSES
21 changes: 16 additions & 5 deletions fp_arena/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
_HEADERS = (
os.path.join(INCLUDE_DIR, "fp_arena", "float32sr.h"),
os.path.join(INCLUDE_DIR, "fp_arena", "float64sr.h"),
os.path.join(INCLUDE_DIR, "fp_arena", "mpfr.h"),
)

#: Backends whose global-code section receives the includes (CPU frame + CUDA).
Expand All @@ -51,8 +52,8 @@
#: Marker so the includes are only injected once per SDFG.
_GUARD = "// fp_arena extensions enabled"

#: Substring identifying an FP-Arena C type (used to detect SR usage in an SDFG).
_CTYPE_MARKER = "fp_arena::"
#: Substrings identifying an FP-Arena C type (fp_arena:: for SR types, dace::mpfr for mpfr).
_CTYPE_MARKERS = ("fp_arena::", "dace::mpfr")


def fp_arena_global_code() -> str:
Expand Down Expand Up @@ -145,13 +146,23 @@ def enable_fp_arena_extensions(sdfg: dace.SDFG) -> dace.SDFG:
def uses_fp_arena_types(sdfg: dace.SDFG) -> bool:
"""
:param sdfg: the SDFG to inspect.
:returns: ``True`` if any data descriptor in ``sdfg`` or its nested SDFGs has
an FP-Arena C type (e.g. ``float32sr``), ``False`` otherwise.
:returns: ``True`` if any data descriptor or tasklet body in ``sdfg`` or its
nested SDFGs references an FP-Arena C type, ``False`` otherwise.
"""
from dace.sdfg import nodes as _dnodes
for nested in sdfg.all_sdfgs_recursive():
for desc in nested.arrays.values():
if _CTYPE_MARKER in (getattr(desc.dtype, "ctype", "") or ""):
if any(m in (getattr(desc.dtype, "ctype", "") or "") for m in _CTYPE_MARKERS):
return True
for state in nested.states():
for node in state.nodes():
if isinstance(node, _dnodes.Tasklet):
try:
code_str = node.code.as_string
except AttributeError:
code_str = str(node.code)
if any(m in code_str for m in _CTYPE_MARKERS):
return True
return False


Expand Down
222 changes: 222 additions & 0 deletions fp_arena/runtime/include/fp_arena/mpfr.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
#pragma once

#include <mpfr.h>

#include <iostream>

namespace dace {

// Process-global exponent width (0 = unlimited).
// Set once before any computation via set_mpfr_exponent_bits(). Not
// thread-safe.
inline unsigned int mpfr_exponent_bits = 0;

// Configure the exponent range for all dace::mpfr values.
// Uses IEEE-style bias: emax = 2^(bits-1)-1, emin = 1-emax.

inline void set_mpfr_exponent_bits(unsigned int bits) {
mpfr_exponent_bits = bits;
if (bits > 0) {
mpfr_exp_t emax = (mpfr_exp_t(1) << (bits - 1)) - 1;
mpfr_set_emin(1 - emax);
mpfr_set_emax(emax);
} else {
mpfr_set_emin(MPFR_EMIN_DEFAULT);
mpfr_set_emax(MPFR_EMAX_DEFAULT);
}
}

template <unsigned int Precision> class mpfr {
private:
mpfr_t val;

void clamp_exp(int t) {
if (mpfr_exponent_bits > 0) {
int temp = mpfr_check_range(val, t, MPFR_RNDN);
mpfr_subnormalize(val, temp, MPFR_RNDN);
}
}

public:
// Constructors
mpfr() { mpfr_init2(val, Precision); }

mpfr(double d) {
mpfr_init2(val, Precision);
clamp_exp(mpfr_set_d(val, d, MPFR_RNDN));
}

mpfr(float f) {
mpfr_init2(val, Precision);
clamp_exp(mpfr_set_flt(val, f, MPFR_RNDN));
}

mpfr(int i) {
mpfr_init2(val, Precision);
clamp_exp(mpfr_set_si(val, i, MPFR_RNDN));
}

mpfr(unsigned int i) {
mpfr_init2(val, Precision);
clamp_exp(mpfr_set_ui(val, i, MPFR_RNDN));
}

mpfr(long int i) {
mpfr_init2(val, Precision);
clamp_exp(mpfr_set_si(val, i, MPFR_RNDN));
}

mpfr(unsigned long int i) {
mpfr_init2(val, Precision);
clamp_exp(mpfr_set_ui(val, i, MPFR_RNDN));
}

// Copy Constructor
mpfr(const mpfr &other) {
mpfr_init2(val, Precision);
mpfr_set(val, other.val, MPFR_RNDN);
}

// Move Constructor
mpfr(mpfr &&other) noexcept {
mpfr_init2(val, Precision);
mpfr_swap(val, other.val);
}

// Destructor
~mpfr() { mpfr_clear(val); }

// Copy Assignment
mpfr &operator=(const mpfr &other) {
if (this != &other) {
mpfr_set(val, other.val, MPFR_RNDN);
}
return *this;
}

// Move Assignment
mpfr &operator=(mpfr &&other) noexcept {
if (this != &other) {
mpfr_swap(val, other.val);
}
return *this;
}

// Assign from basic types
mpfr &operator=(double d) {
clamp_exp(mpfr_set_d(val, d, MPFR_RNDN));
return *this;
}

mpfr &operator=(float f) {
clamp_exp(mpfr_set_flt(val, f, MPFR_RNDN));
return *this;
}

mpfr &operator=(int i) {
clamp_exp(mpfr_set_si(val, i, MPFR_RNDN));
return *this;
}

// Conversions
explicit operator double() const { return mpfr_get_d(val, MPFR_RNDN); }

explicit operator float() const { return mpfr_get_flt(val, MPFR_RNDN); }

explicit operator int() const { return (int)mpfr_get_si(val, MPFR_RNDN); }

explicit operator long() const { return mpfr_get_si(val, MPFR_RNDN); }

explicit operator bool() const { return mpfr_cmp_d(val, 0.0) != 0; }

// Arithmetic Operators
friend mpfr operator+(const mpfr &lhs, const mpfr &rhs) {
mpfr result;
result.clamp_exp(mpfr_add(result.val, lhs.val, rhs.val, MPFR_RNDN));
return result;
}

friend mpfr operator-(const mpfr &lhs, const mpfr &rhs) {
mpfr result;
result.clamp_exp(mpfr_sub(result.val, lhs.val, rhs.val, MPFR_RNDN));
return result;
}

friend mpfr operator*(const mpfr &lhs, const mpfr &rhs) {
mpfr result;
result.clamp_exp(mpfr_mul(result.val, lhs.val, rhs.val, MPFR_RNDN));
return result;
}

friend mpfr operator/(const mpfr &lhs, const mpfr &rhs) {
mpfr result;
result.clamp_exp(mpfr_div(result.val, lhs.val, rhs.val, MPFR_RNDN));
return result;
}

// Unary minus
mpfr operator-() const {
mpfr result;
result.clamp_exp(mpfr_neg(result.val, this->val, MPFR_RNDN));
return result;
}

// Compound assignments
mpfr &operator+=(const mpfr &other) {
clamp_exp(mpfr_add(val, val, other.val, MPFR_RNDN));
return *this;
}

mpfr &operator-=(const mpfr &other) {
clamp_exp(mpfr_sub(val, val, other.val, MPFR_RNDN));
return *this;
}

mpfr &operator*=(const mpfr &other) {
clamp_exp(mpfr_mul(val, val, other.val, MPFR_RNDN));
return *this;
}

mpfr &operator/=(const mpfr &other) {
clamp_exp(mpfr_div(val, val, other.val, MPFR_RNDN));
return *this;
}

// Comparison Operators
friend bool operator==(const mpfr &lhs, const mpfr &rhs) {
return mpfr_cmp(lhs.val, rhs.val) == 0;
}

friend bool operator!=(const mpfr &lhs, const mpfr &rhs) {
return mpfr_cmp(lhs.val, rhs.val) != 0;
}

friend bool operator<(const mpfr &lhs, const mpfr &rhs) {
return mpfr_cmp(lhs.val, rhs.val) < 0;
}

friend bool operator>(const mpfr &lhs, const mpfr &rhs) {
return mpfr_cmp(lhs.val, rhs.val) > 0;
}

friend bool operator<=(const mpfr &lhs, const mpfr &rhs) {
return mpfr_cmp(lhs.val, rhs.val) <= 0;
}

friend bool operator>=(const mpfr &lhs, const mpfr &rhs) {
return mpfr_cmp(lhs.val, rhs.val) >= 0;
}

// IO Stream
friend std::ostream &operator<<(std::ostream &os, const mpfr &m) {
char *str = nullptr;
mpfr_asprintf(&str, "%.*RNg", (int)(Precision * 0.30103) + 2, m.val);
if (str) {
os << str;
mpfr_free_str(str);
}
return os;
}
};

} // namespace dace
3 changes: 2 additions & 1 deletion fp_arena/transformations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@
"""FP-Arena SDFG transformations."""

from fp_arena.transformations.change_fp_types import change_fptype
from fp_arena.transformations.change_and_propagate_fp_types import change_and_propagate_fp_types

__all__ = ["change_fptype"]
__all__ = ["change_fptype", "change_and_propagate_fp_types"]
Loading
Loading