Pull request overview
This PR refactors several ASM-backed ops to use a torch-free C ABI (extern "C") plus a Python ctypes call path (via a new AiterTensor struct), reducing reliance on PyTorch C++/pybind for these kernels.
Changes:
- Introduces `AiterTensor`/`AiterDtype` plumbing and a `compile_ops(..., ffi_type="ctypes")` dispatch path that loads and calls `.so` symbols via `ctypes`.
- Refactors ASM implementations (topk-softmax, layernorm, GEMM a16w16) to accept `AiterTensor*` and an explicit `hipStream_t`.
- Adjusts build config and Python wrappers to use the new ctypes path; adds contiguity gating for the ASM topk-softmax callsites/tests.
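The ctypes dispatch path described above follows the standard pattern of loading a shared object and declaring argument/return types before calling a symbol. A minimal, self-contained illustration of that pattern using libm's `sqrt` (not the PR's actual module or symbols, which live in the compiled aiter `.so`):

```python
import ctypes
import ctypes.util

# Load a shared library and resolve a symbol -- the same mechanism a
# ctypes-based dispatcher uses on a compiled .so (names here are libm's).
libm = ctypes.CDLL(ctypes.util.find_library("m") or "libm.so.6")

# Declaring argtypes/restype is essential: without it, ctypes assumes
# int-sized arguments and return values, silently corrupting doubles
# and pointers (the same care applies to AiterTensor* and hipStream_t).
libm.sqrt.argtypes = [ctypes.c_double]
libm.sqrt.restype = ctypes.c_double

print(libm.sqrt(9.0))  # 3.0
```

The aiter path additionally passes struct pointers (`ctypes.byref(...)` on an `AiterTensor`) and an opaque stream handle instead of plain scalars, but the load-declare-call shape is the same.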
Reviewed changes
Copilot reviewed 19 out of 19 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| op_tests/test_moeTopkSoftmax.py | Makes ASM path contiguous; minor cleanup in allclose helper |
| csrc/py_itfs_cu/asm_topksoftmax.cu | Switches to C ABI + AiterTensor* args and explicit stream/device handling |
| csrc/py_itfs_cu/asm_layernorm.cu | Switches to C ABI + AiterTensor* args and explicit stream/device handling |
| csrc/py_itfs_cu/asm_gemm_a16w16.cu | Switches to C ABI + AiterTensor* args; removes torch types; explicit device handling |
| csrc/include/rocm_ops.hpp | Removes some pybind bindings; updates macros (notably quant bindings) |
| csrc/include/norm.h | Formatting + removes ASM layernorm declarations from this torch header |
| csrc/include/moe_op.h | Removes torch-signature declaration for topk_softmax_asm |
| csrc/include/asm_gemm_a16w16.h | Deletes old torch-signature header for GEMM ASM |
| csrc/include/aiter_tensor.h | Adds new C struct used for ctypes FFI tensor metadata |
| csrc/include/aiter_hip_common.h | Adds AITER_CHECK and includes for new enum/tensor headers |
| csrc/include/aiter_enum.h | Adds AiterDtype enum and helpers |
| aiter/utility/aiter_types.py | Adds ctypes definitions + header parsing for dtype IDs |
| aiter/utility/dtypes.py | Adds torch_to_aiter() conversion and dtype maps for ctypes path |
| aiter/ops/norm.py | Points ASM layernorm wrappers to a new module + ctypes FFI |
| aiter/ops/moe_op.py | Switches topk_softmax_asm wrapper to ctypes FFI |
| aiter/ops/gemm_op_a16w16.py | Switches GEMM ASM wrapper to ctypes FFI and returns out explicitly |
| aiter/jit/optCompilerConfig.json | Updates build modules (drops some pybind sources; adds module_asm_layernorm) |
| aiter/jit/core.py | Adds _ctypes_call() and ffi_type option to compile_ops |
| aiter/fused_moe.py | Adds gating_output.is_contiguous() requirement for ASM topk-softmax fast path |
Comments suppressed due to low confidence (2)
aiter/ops/moe_op.py:32
`topk_softmax_asm` is switched to `ffi_type="ctypes"`, but it still targets `module_moe_asm`, whose build config includes torch/pybind sources (e.g., `pybind/moe_op_pybind.cu`). `_ctypes_call()` forces `torch_exclude=True` when it needs to build the module, so a first-time call to `topk_softmax_asm` (without the pybind module already built) is likely to fail to link/compile. Suggested fix: create a dedicated torch-free module (e.g., `module_asm_topksoftmax`) that only compiles `py_itfs_cu/asm_topksoftmax.cu` (+ any needed deps) and point this decorator at it; alternatively, avoid setting `torch_exclude=True` for modules that still depend on torch/pybind.
```python
@compile_ops("module_moe_asm", fc_name="topk_softmax_asm", ffi_type="ctypes")
def topk_softmax_asm(
    topk_weights: Tensor,
    topk_indices: Tensor,
    token_expert_indices: Tensor,
    gating_output: Tensor,
    need_renorm: bool,
) -> None: ...
```
csrc/include/rocm_ops.hpp:1459
`QUANT_PYBIND` no longer binds `moe_smooth_per_token_scaled_quant_v1/v2`, but Python still calls these via `@compile_ops("module_quant")` (see `aiter/ops/quant.py`). This will cause a runtime `AttributeError` when the compiled `module_quant` is imported and `getattr(module, ...)` is attempted. Either restore these `m.def(...)` bindings in `QUANT_PYBIND` or update the Python API to stop referencing these functions.
```cpp
#define QUANT_PYBIND                                                      \
    m.def("static_per_tensor_quant", &aiter::static_per_tensor_quant);    \
    m.def("dynamic_per_tensor_quant", &aiter::dynamic_per_tensor_quant);  \
    m.def("dynamic_per_token_scaled_quant",                               \
          &aiter::dynamic_per_token_scaled_quant,                         \
          py::arg("out"),                                                 \
          py::arg("input"),                                               \
          py::arg("scales"),                                              \
          py::arg("scale_ub") = std::nullopt,                             \
          py::arg("shuffle_scale") = false,                               \
          py::arg("num_rows") = std::nullopt,                             \
          py::arg("num_rows_factor") = 1);                                \
    m.def("dynamic_per_group_scaled_quant_fp4",                           \
          &aiter::dynamic_per_group_scaled_quant_fp4,                     \
          py::arg("out"),                                                 \
          py::arg("input"),                                               \
          py::arg("scales"),                                              \
          py::arg("group_size") = 32,                                     \
          py::arg("shuffle_scale") = true,                                \
          py::arg("num_rows") = std::nullopt,                             \
          py::arg("num_rows_factor") = 1);                                \
    m.def("smooth_per_token_scaled_quant",                                \
          &aiter::smooth_per_token_scaled_quant,                          \
          py::arg("out"),                                                 \
          py::arg("input"),                                               \
          py::arg("scales"),                                              \
          py::arg("smooth_scale"),                                        \
          py::arg("smooth_scale_map") = std::nullopt,                     \
          py::arg("shuffle_scale") = false,                               \
          py::arg("num_rows") = std::nullopt,                             \
          py::arg("num_rows_factor") = 1,                                 \
          py::arg("smooth_scale_map_hash") = std::nullopt,                \
          py::arg("enable_ps") = true);                                   \
    m.def("partial_transpose",                                            \
          &aiter::partial_transpose,                                      \
          py::arg("out"),                                                 \
          py::arg("input"),                                               \
          py::arg("num_rows"));
```
Comment on lines +220 to 222:

```python
torch.isclose(sr, st, rtol=rtol, atol=atol)
logger.info(
    f"{msg} [{i_row}x{i_col}], r:{r_idx[i_row, i_col]}->{sr}, t:{t_idx[i_row, i_col]}->{st}"
```
Comment on lines +4 to +12:

```c
#include "aiter_enum.h"
#include <cstdint>

struct AiterTensor
{
    void* ptr;        // data_ptr, pointer to GPU memory
    size_t numel_;    // total number of elements
    int ndim;         // number of dimensions
    int64_t shape[8]; // size of each dimension, up to 8 dims (PyTorch limit)
```
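On the Python side, an `AiterTensor*` argument is built from a `ctypes.Structure` mirroring this C struct. A sketch of such a mirror; the first four fields come from the excerpt above, while the tail fields (`strides`, `dtype_`, `device_id`) are inferred from the `torch_to_aiter()` code elsewhere in this PR and are an assumption about the part of the struct not shown:

```python
import ctypes

class AiterTensor(ctypes.Structure):
    # Field order and types must match the C definition exactly, or the
    # kernel will read garbage metadata. The last three fields are an
    # assumption based on the attributes torch_to_aiter() assigns.
    _fields_ = [
        ("ptr", ctypes.c_void_p),         # data_ptr, pointer to GPU memory
        ("numel_", ctypes.c_size_t),      # total number of elements
        ("ndim", ctypes.c_int),           # number of dimensions
        ("shape", ctypes.c_int64 * 8),    # per-dim sizes, up to 8 dims
        ("strides", ctypes.c_int64 * 8),  # per-dim strides, in elements
        ("dtype_", ctypes.c_int),         # AiterDtype enum value
        ("device_id", ctypes.c_int),      # GPU device index
    ]
```

Because `ctypes.Structure` uses C layout rules, `ctypes.byref(at)` on an instance of this class can be passed directly where the C side expects `AiterTensor*`.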
Comment on lines 3 to +55:

```diff
@@ -18,3 +19,55 @@ enum class QuantType : int
     per_1x128,
     per_128x128,
 };
+typedef enum {
+    AITER_DTYPE_fp8,
+    AITER_DTYPE_fp8_e8m0,
+    AITER_DTYPE_fp16,
+    AITER_DTYPE_bf16,
+    AITER_DTYPE_fp32,
+    AITER_DTYPE_i4x2,
+    AITER_DTYPE_fp4x2,
+    AITER_DTYPE_u32,
+    AITER_DTYPE_i32,
+    AITER_DTYPE_i16,
+    AITER_DTYPE_i8,
+    AITER_DTYPE_u8,
+} AiterDtype;
+
+static inline size_t AiterDtype_element_size(AiterDtype dtype)
+{
+    switch (dtype) {
+    case AITER_DTYPE_fp8:
+    case AITER_DTYPE_fp8_e8m0:
+    case AITER_DTYPE_i4x2:
+    case AITER_DTYPE_fp4x2:
+    case AITER_DTYPE_i8:
+    case AITER_DTYPE_u8: return 1;
+    case AITER_DTYPE_fp16:
+    case AITER_DTYPE_bf16:
+    case AITER_DTYPE_i16: return 2;
+    case AITER_DTYPE_fp32:
+    case AITER_DTYPE_u32:
+    case AITER_DTYPE_i32: return 4;
+    default: return 0;
+    }
+}
```
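The switch above reduces to a simple bytes-per-element table. A Python mirror of it, shown purely for cross-checking (the dict and helper name are illustrative, not code from the PR):

```python
# Bytes per element for each AiterDtype, mirroring AiterDtype_element_size().
# The packed types (i4x2, fp4x2) store two values per byte, so one "element"
# of the packed representation is still 1 byte.
AITER_ELEMENT_SIZE = {
    "fp8": 1, "fp8_e8m0": 1, "i4x2": 1, "fp4x2": 1, "i8": 1, "u8": 1,
    "fp16": 2, "bf16": 2, "i16": 2,
    "fp32": 4, "u32": 4, "i32": 4,
}

def element_size(name: str) -> int:
    # Unknown dtypes map to 0, matching the C helper's default branch.
    return AITER_ELEMENT_SIZE.get(name, 0)
```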
Comment on lines 22 to +56:

```diff
 i4x2 = getattr(torch, "int4", _8bit_fallback)
 fp4x2 = getattr(torch, "float4_e2m1fn_x2", _8bit_fallback)
 fp8 = get_dtype_fp8()
 fp8_e8m0 = getattr(torch, "float8_e8m0fnu", _8bit_fallback)
 fp16 = torch.float16
 bf16 = torch.bfloat16
 fp32 = torch.float32
 u32 = torch.uint32
 i32 = torch.int32
 i16 = torch.int16
 i8 = torch.int8
-d_dtypes = {
-    "fp8": fp8,
-    "fp8_e8m0": fp8_e8m0,
-    "fp16": fp16,
-    "bf16": bf16,
-    "fp32": fp32,
-    "i4x2": i4x2,
-    "fp4x2": fp4x2,
-    "u32": u32,
-    "i32": i32,
-    "i16": i16,
-    "i8": i8,
-}
+u8 = torch.uint8
+
+d_dtypes = {name: globals()[name] for name in aiter_dtypes}
+
+globals().update({f"AITER_DTYPE_{name}": idx for name, idx in aiter_dtypes.items()})
+_torch_to_aiter_dtype = {globals()[name]: idx for name, idx in aiter_dtypes.items()}
+
+
+def torch_to_aiter(tensor: torch.Tensor) -> AiterTensor:
+    """torch.Tensor -> AiterTensor, zero-copy, points to the same GPU memory."""
+    assert tensor.is_cuda, "AiterTensor only supports CUDA tensors"
+    assert tensor.ndim <= 8, f"AiterTensor supports at most 8 dims, got {tensor.ndim}"
+    assert tensor.dtype in _torch_to_aiter_dtype, f"Unsupported dtype: {tensor.dtype}"
+
+    at = AiterTensor()
+    at.ptr = tensor.data_ptr()
+    at.numel_ = tensor.numel()
+    at.ndim = tensor.ndim
+    for i in range(tensor.ndim):
+        at.shape[i] = tensor.shape[i]
+        at.strides[i] = tensor.stride(i)
+    at.dtype_ = _torch_to_aiter_dtype[tensor.dtype]
+    at.device_id = tensor.device.index or 0
+    return at
```
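`torch_to_aiter()` copies shape and strides verbatim, and the PR separately gates the ASM topk-softmax fast path on `gating_output.is_contiguous()`. For a C-contiguous tensor the strides are fully determined by the shape, which is exactly what that gate guarantees to the kernel. A small torch-free sketch of the row-major rule (the helper name is illustrative):

```python
def contiguous_strides(shape):
    """Row-major (C-contiguous) strides in elements, matching what
    torch.Tensor.stride() reports for a contiguous tensor."""
    strides = [1] * len(shape)
    # Walk from the innermost dimension outward: each stride is the
    # product of all dimension sizes to its right.
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides

print(contiguous_strides([2, 3, 4]))  # [12, 4, 1]
```

A non-contiguous view (e.g. a transpose) breaks this relationship, which is why the fast path falls back to a contiguous copy rather than trusting arbitrary strides.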
Motivation
Technical Details
Test Plan
Test Result
Submission Checklist