diff --git a/numbast/src/numbast/callconv.py b/numbast/src/numbast/callconv.py index 5e10eadd..94df7b40 100644 --- a/numbast/src/numbast/callconv.py +++ b/numbast/src/numbast/callconv.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +# NUMBAST_RETVAL_ALIGN_FIX_APPLIED from numbast.args import prepare_ir_types from numbast.intent import IntentPlan @@ -107,6 +108,28 @@ def _lower_impl(self, builder, context, sig, args): else: retval_ty = context.get_value_type(cxx_return_type) retval_ptr = builder.alloca(retval_ty, name="retval") + # Use the Numba type's alignof_ to set the alloca alignment. + # + # LLVM computes struct ABI alignment as the max alignment of its + # members. For CUDA vector types (float2, float4, uchar4, …) + # declared with __align__(N) in the CUDA headers, N can exceed + # the member alignment: float2 is {float,float} with member + # alignment 4 B but __align__(8). LLVM therefore assigns a 4 B + # alloca, while the NVRTC shim uses a vector instruction + # (ld/st.v2.f32) that requires 8 B alignment, causing + # cudaErrorMisalignedAddress at runtime. + # + # Numbast's struct binder already sets alignof_ on user-defined + # bound structs (propagated from ast_canopy). For built-in CUDA + # vector types, callers must set alignof_ on the Numba type when + # registering it in CTYPE_MAPS. When alignof_ is present it is + # used here, matching the convention already applied to loads and + # stores (getattr(argty, "alignof_", None)). When absent, LLVM's + # default ABI alignment is used, which is correct for scalars and + # structs without an explicit __align__ attribute. + _nb_align = getattr(cxx_return_type, "alignof_", None) + if _nb_align is not None: + retval_ptr.align = _nb_align # 2. Prepare arguments if self._intent_plan is None: @@ -154,8 +177,11 @@ def _lower_impl(self, builder, context, sig, args): ptrs.append(arg) else: ptr = cgutils.alloca_once(builder, vty) + _nb_align = getattr(argty, "alignof_", None) + if _nb_align is not None: + ptr.align = _nb_align # see retval_ptr comment above builder.store( - arg, ptr, align=getattr(argty, "alignof_", None) + arg, ptr, align=_nb_align ) ptrs.append(ptr) else: @@ -175,6 +201,9 @@ def _lower_impl(self, builder, context, sig, args): out_nbty = self._out_return_types[out_pos] vty = context.get_value_type(out_nbty) ptr = cgutils.alloca_once(builder, vty) + _nb_align = getattr(out_nbty, "alignof_", None) + if _nb_align is not None: + ptr.align = _nb_align # see retval_ptr comment above ptrs.append(ptr) arg_pointer_types.append(ir.PointerType(vty)) out_return_ptrs.append((out_nbty, ptr)) @@ -194,8 +223,11 @@ def _lower_impl(self, builder, context, sig, args): arg_pointer_types.append(vty) else: ptr = cgutils.alloca_once(builder, vty) + _nb_align = getattr(argty, "alignof_", None) + if _nb_align is not None: + ptr.align = _nb_align # see retval_ptr comment above builder.store( - arg, ptr, align=getattr(argty, "alignof_", None) + arg, ptr, align=_nb_align ) ptrs.append(ptr) arg_pointer_types.append(ir.PointerType(vty))