diff --git a/python/quadrants/lang/expr.py b/python/quadrants/lang/expr.py index 0369349d6b..deee4de570 100644 --- a/python/quadrants/lang/expr.py +++ b/python/quadrants/lang/expr.py @@ -7,7 +7,12 @@ from quadrants.lang.common_ops import QuadrantsOperations from quadrants.lang.exception import QuadrantsCompilationError, QuadrantsTypeError from quadrants.lang.matrix import make_matrix -from quadrants.lang.util import is_matrix_class, is_quadrants_class, to_numpy_type +from quadrants.lang.util import ( + cook_dtype, + is_matrix_class, + is_quadrants_class, + to_numpy_type, +) from quadrants.types import primitive_types from quadrants.types.primitive_types import integer_types, real_types @@ -109,12 +114,17 @@ def _clamp_unsigned_to_range(npty, val: np.integer | int) -> np.integer | int: def make_constant_expr(val, dtype): + # Normalise dtype once up front so the per-branch fallbacks only need to + # cook the runtime defaults (default_fp / default_ip). + if dtype is not None: + dtype = cook_dtype(dtype) + if isinstance(val, (bool, np.bool_)): - constant_dtype = primitive_types.u1 + constant_dtype = cook_dtype(primitive_types.u1) return Expr(_qd_core.make_const_expr_bool(constant_dtype, val)) if isinstance(val, (float, np.floating)): - constant_dtype = impl.get_runtime().default_fp if dtype is None else dtype + constant_dtype = dtype if dtype is not None else cook_dtype(impl.get_runtime().default_fp) if constant_dtype not in real_types: raise QuadrantsTypeError( "Floating-point literals must be annotated with a floating-point type. For type casting, use `qd.cast`." @@ -122,7 +132,7 @@ def make_constant_expr(val, dtype): return Expr(_qd_core.make_const_expr_fp(constant_dtype, val)) if isinstance(val, (int, np.integer)): - constant_dtype = impl.get_runtime().default_ip if dtype is None else dtype + constant_dtype = dtype if dtype is not None else cook_dtype(impl.get_runtime().default_ip) if constant_dtype not in integer_types: raise QuadrantsTypeError( "Integer literals must be annotated with a integer type. For type casting, use `qd.cast`." diff --git a/python/quadrants/lang/impl.py b/python/quadrants/lang/impl.py index a8f4a15e1a..7a791bd2f4 100644 --- a/python/quadrants/lang/impl.py +++ b/python/quadrants/lang/impl.py @@ -84,6 +84,7 @@ def expr_init_shared_array(shape, element_type): ast_builder = get_runtime().compiling_callable.ast_builder() debug_info = _qd_core.DebugInfo(get_runtime().get_current_src_info()) + element_type = cook_dtype(element_type) return ast_builder.expr_alloca_shared_array(shape, element_type, debug_info) @@ -365,9 +366,9 @@ def __init__(self, kernels=None): self.grad_vars = [] self.dual_vars = [] self.matrix_fields = [] - self.default_fp = f32 - self.default_ip = i32 - self.default_up = u32 + self._default_fp = cook_dtype(f32) + self._default_ip = cook_dtype(i32) + self._default_up = cook_dtype(u32) self.print_full_traceback: bool = False self.target_tape = None self.fwd_mode_manager = None @@ -381,6 +382,30 @@ def __init__(self, kernels=None): self.unrolling_limit: int = 0 self.src_ll_cache: bool = True + @property + def default_fp(self) -> DataTypeCxx: + return self._default_fp + + @default_fp.setter + def default_fp(self, value: Any) -> None: + self._default_fp = cook_dtype(value) + + @property + def default_ip(self) -> DataTypeCxx: + return self._default_ip + + @default_ip.setter + def default_ip(self, value: Any) -> None: + self._default_ip = cook_dtype(value) + + @property + def default_up(self) -> DataTypeCxx: + return self._default_up + + @default_up.setter + def default_up(self, value: Any) -> None: + self._default_up = cook_dtype(value) + @property def compiling_callable(self) -> KernelCxx | Kernel | Function: if self._compiling_callable is None: @@ -747,10 +772,10 @@ def create_field_member(dtype, name, needs_grad, needs_dual): if prog.config().debug: # adjoint checkbit x_grad_checkbit = Expr(prog.make_id_expr("")) - dtype = u8 + checkbit_dtype = u8 if prog.config().arch == _qd_core.vulkan: - dtype = i32 - x_grad_checkbit.ptr = _qd_core.expr_field(x_grad_checkbit.ptr, cook_dtype(dtype)) + checkbit_dtype = i32 + x_grad_checkbit.ptr = _qd_core.expr_field(x_grad_checkbit.ptr, cook_dtype(checkbit_dtype)) x_grad_checkbit.ptr.set_name(name + ".grad_checkbit") x_grad_checkbit.ptr.set_grad_type(SNodeGradType.ADJOINT_CHECKBIT) x.ptr.set_adjoint_checkbit(x_grad_checkbit.ptr) diff --git a/python/quadrants/lang/matrix.py b/python/quadrants/lang/matrix.py index bb264321db..26c16359c2 100644 --- a/python/quadrants/lang/matrix.py +++ b/python/quadrants/lang/matrix.py @@ -176,7 +176,7 @@ def make_matrix(arr, dt=None): if len(arr) == 0: # the only usage of an empty vector is to serve as field indices shape = [0] - dt = primitive_types.i32 + dt = cook_dtype(primitive_types.i32) else: if isinstance(arr[0], Iterable): # matrix shape = [len(arr), len(arr[0])] diff --git a/python/quadrants/linalg/sparse_matrix.py b/python/quadrants/linalg/sparse_matrix.py index 7eb4f40be2..09cfd75a35 100644 --- a/python/quadrants/linalg/sparse_matrix.py +++ b/python/quadrants/linalg/sparse_matrix.py @@ -9,6 +9,7 @@ from quadrants.lang.exception import QuadrantsRuntimeError from quadrants.lang.field import Field from quadrants.lang.impl import get_runtime +from quadrants.lang.util import cook_dtype from quadrants.types import f32 @@ -24,11 +25,12 @@ class SparseMatrix: """ def __init__(self, n=None, m=None, sm=None, dtype=f32, storage_format="col_major"): - self.dtype = dtype + dtype_cxx = cook_dtype(dtype) + self.dtype = dtype_cxx if sm is None: self.n = n self.m = m if m else n - self.matrix = get_runtime().prog.create_sparse_matrix(n, m, dtype, storage_format) + self.matrix = get_runtime().prog.create_sparse_matrix(n, m, dtype_cxx, storage_format) else: self.n = sm.num_rows() self.m = sm.num_cols() @@ -247,7 +249,8 @@ def __init__( ): self.num_rows = num_rows self.num_cols = num_cols if num_cols else num_rows - self.dtype = dtype + dtype_cxx = cook_dtype(dtype) + self.dtype = dtype_cxx if num_rows is not None: quadrants_arch = get_runtime().prog.config().arch if quadrants_arch in [ @@ -259,7 +262,7 @@ def __init__( num_rows, num_cols, max_num_triplets, - dtype, + dtype_cxx, storage_format, ) self.ptr.create_ndarray(get_runtime().prog) diff --git a/python/quadrants/linalg/sparse_solver.py b/python/quadrants/linalg/sparse_solver.py index 3544d1a957..e66de69b06 100644 --- a/python/quadrants/linalg/sparse_solver.py +++ b/python/quadrants/linalg/sparse_solver.py @@ -8,6 +8,7 @@ from quadrants.lang.exception import QuadrantsRuntimeError from quadrants.lang.field import Field from quadrants.lang.impl import get_runtime +from quadrants.lang.util import cook_dtype from quadrants.linalg.sparse_matrix import SparseMatrix from quadrants.types.primitive_types import f32 @@ -24,7 +25,8 @@ class SparseSolver: def __init__(self, dtype=f32, solver_type="LLT", ordering="AMD"): self.matrix = None - self.dtype = dtype + dtype_cxx = cook_dtype(dtype) + self.dtype = dtype_cxx solver_type_list = ["LLT", "LDLT", "LU"] solver_ordering = ["AMD", "COLAMD"] if solver_type in solver_type_list and ordering in solver_ordering: @@ -35,9 +37,9 @@ def __init__(self, dtype=f32, solver_type="LLT", ordering="AMD"): or quadrants_arch == _qd_core.Arch.cuda ), "SparseSolver only supports CPU and CUDA for now." if quadrants_arch == _qd_core.Arch.cuda: - self.solver = _qd_core.make_cusparse_solver(dtype, solver_type, ordering) + self.solver = _qd_core.make_cusparse_solver(dtype_cxx, solver_type, ordering) else: - self.solver = _qd_core.make_sparse_solver(dtype, solver_type, ordering) + self.solver = _qd_core.make_sparse_solver(dtype_cxx, solver_type, ordering) else: raise QuadrantsRuntimeError( f"The solver type {solver_type} with {ordering} is not supported for now. Only {solver_type_list} with {solver_ordering} are supported." diff --git a/tests/python/test_ad_basics_fwd.py b/tests/python/test_ad_basics_fwd.py index fc37ef582c..989286c23d 100644 --- a/tests/python/test_ad_basics_fwd.py +++ b/tests/python/test_ad_basics_fwd.py @@ -1,3 +1,5 @@ +import pytest + import quadrants as qd from tests import test_utils @@ -124,3 +126,31 @@ def clear_dual_test(): with qd.ad.FwdMode(loss=loss, param=x): clear_dual_test() assert y.dual[None] == 4.0 + + +@pytest.mark.parametrize("dtype", [qd.f32, qd.f64]) +@test_utils.test(debug=True) +def test_dual_field_dtype_preserved_in_debug_mode(dtype): + """Regression: debug-mode checkbit must not shadow the outer dtype. + + Picks values whose dual is a non-integer exactly representable in + both ``f32`` and ``f64`` (``x=1.25`` -> ``dual=2.5``): under the old + bug the dual field was created as ``u8`` (or ``i32`` on Vulkan), + which would truncate ``2.5`` to ``2`` and fail the assertion. + """ + test_utils.skip_if_f64_unsupported(dtype) + + x = qd.field(dtype, shape=(), needs_dual=True) + loss = qd.field(dtype, shape=(), needs_dual=True) + + x[None] = 1.25 + + @qd.kernel + def compute(): + loss[None] = x[None] * x[None] + + with qd.ad.FwdMode(loss=loss, param=x): + compute() + + assert loss[None] == 1.5625 + assert loss.dual[None] == 2.5 diff --git a/tests/run_tests.py b/tests/run_tests.py index a454003002..36059493cc 100644 --- a/tests/run_tests.py +++ b/tests/run_tests.py @@ -52,6 +52,8 @@ def _test_python(args, default_dir="python"): pytest_args += [ "--durations=15", "-p", + "no:timeout", + "-p", "pytest_hardtle", f"--timeout={args.timeout}", ]