
Refactor topk softmax asm bind#2327

Open
yzhou103 wants to merge 4 commits into main from refactor_topk_softmax_asm_bind

Conversation

@yzhou103
Contributor

Motivation

Technical Details

Test Plan

Test Result

Submission Checklist

@yzhou103 yzhou103 requested review from a team and Copilot on March 18, 2026 at 08:50
@github-actions
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

| Label | Tests |
| --- | --- |
| `ci:sglang` | SGLang integration tests |
| `ci:atom` | ATOM benchmark (DeepSeek-R1 + GPT-OSS) |
| `ci:vllm` | vLLM benchmark |
| `ci:all` | All of the above |

Add labels via the sidebar or `gh pr edit 2327 --add-label <label>`

Contributor

Copilot AI left a comment


Pull request overview

This PR refactors several ASM-backed ops to use a torch-free C ABI (extern "C") plus a Python ctypes call path (via a new AiterTensor struct), reducing reliance on PyTorch C++/pybind for these kernels.

Changes:

  • Introduces AiterTensor/AiterDtype plumbing and a compile_ops(..., ffi_type="ctypes") dispatch path that loads and calls .so symbols via ctypes.
  • Refactors ASM implementations (topk-softmax, layernorm, GEMM a16w16) to accept AiterTensor* and an explicit hipStream_t.
  • Adjusts build config and Python wrappers to use the new ctypes path; adds contiguity gating for the ASM topk-softmax callsites/tests.
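The `ffi_type="ctypes"` dispatch idea in the first bullet can be reduced to a small sketch. This is illustrative only, not the PR's `_ctypes_call()`: load a shared object with `ctypes.CDLL`, declare the `extern "C"` symbol's signature, and call it directly with no pybind layer. Here libc's `abs()` stands in for a kernel entry point.

```python
import ctypes

# Illustrative sketch of a ctypes call path (not the PR's _ctypes_call()):
# load a shared object, declare the extern "C" symbol's signature, and
# call it directly -- no pybind involved. libc's abs() stands in for a
# kernel entry point.
lib = ctypes.CDLL(None)  # handle to the running process's symbols (Linux)
lib.abs.argtypes = [ctypes.c_int]
lib.abs.restype = ctypes.c_int

result = lib.abs(-7)
```

A real kernel symbol would instead take `AiterTensor*` arguments (passed via `ctypes.byref(...)`) plus an explicit `hipStream_t` passed as a `c_void_p`.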

Reviewed changes

Copilot reviewed 19 out of 19 changed files in this pull request and generated 4 comments.

| File | Description |
| --- | --- |
| `op_tests/test_moeTopkSoftmax.py` | Makes the ASM path contiguous; minor cleanup in the allclose helper |
| `csrc/py_itfs_cu/asm_topksoftmax.cu` | Switches to C ABI + `AiterTensor*` args and explicit stream/device handling |
| `csrc/py_itfs_cu/asm_layernorm.cu` | Switches to C ABI + `AiterTensor*` args and explicit stream/device handling |
| `csrc/py_itfs_cu/asm_gemm_a16w16.cu` | Switches to C ABI + `AiterTensor*` args; removes torch types; explicit device handling |
| `csrc/include/rocm_ops.hpp` | Removes some pybind bindings; updates macros (notably quant bindings) |
| `csrc/include/norm.h` | Formatting; removes ASM layernorm declarations from this torch header |
| `csrc/include/moe_op.h` | Removes the torch-signature declaration for `topk_softmax_asm` |
| `csrc/include/asm_gemm_a16w16.h` | Deletes the old torch-signature header for the GEMM ASM op |
| `csrc/include/aiter_tensor.h` | Adds the new C struct used for ctypes FFI tensor metadata |
| `csrc/include/aiter_hip_common.h` | Adds `AITER_CHECK` and includes for the new enum/tensor headers |
| `csrc/include/aiter_enum.h` | Adds the `AiterDtype` enum and helpers |
| `aiter/utility/aiter_types.py` | Adds ctypes definitions + header parsing for dtype IDs |
| `aiter/utility/dtypes.py` | Adds `torch_to_aiter()` conversion and dtype maps for the ctypes path |
| `aiter/ops/norm.py` | Points ASM layernorm wrappers at a new module + ctypes FFI |
| `aiter/ops/moe_op.py` | Switches the `topk_softmax_asm` wrapper to ctypes FFI |
| `aiter/ops/gemm_op_a16w16.py` | Switches the GEMM ASM wrapper to ctypes FFI and returns `out` explicitly |
| `aiter/jit/optCompilerConfig.json` | Updates build modules (drops some pybind sources; adds `module_asm_layernorm`) |
| `aiter/jit/core.py` | Adds `_ctypes_call()` and the `ffi_type` option to `compile_ops` |
| `aiter/fused_moe.py` | Adds a `gating_output.is_contiguous()` requirement for the ASM topk-softmax fast path |
Comments suppressed due to low confidence (2)

aiter/ops/moe_op.py:32

  • topk_softmax_asm is switched to ffi_type="ctypes", but it still targets module_moe_asm, whose build config includes torch/pybind sources (e.g., pybind/moe_op_pybind.cu). _ctypes_call() forces torch_exclude=True when it needs to build the module, so a first-time call to topk_softmax_asm (without the pybind module already built) is likely to fail to link/compile. Suggested fix: create a dedicated torch-free module (e.g., module_asm_topksoftmax) that only compiles py_itfs_cu/asm_topksoftmax.cu (+ any needed deps) and point this decorator at it; alternatively, avoid setting torch_exclude=True for modules that still depend on torch/pybind.
```python
@compile_ops("module_moe_asm", fc_name="topk_softmax_asm", ffi_type="ctypes")
def topk_softmax_asm(
    topk_weights: Tensor,
    topk_indices: Tensor,
    token_expert_indices: Tensor,
    gating_output: Tensor,
    need_renorm: bool,
) -> None: ...
```

csrc/include/rocm_ops.hpp:1459

  • QUANT_PYBIND no longer binds moe_smooth_per_token_scaled_quant_v1/v2, but Python still calls these via @compile_ops("module_quant") (see aiter/ops/quant.py). This will cause runtime AttributeError when the compiled module_quant is imported and getattr(module, ...) is attempted. Either restore these m.def(...) bindings in QUANT_PYBIND or update the Python API to stop referencing these functions.
```cpp
#define QUANT_PYBIND                                                     \
    m.def("static_per_tensor_quant", &aiter::static_per_tensor_quant);   \
    m.def("dynamic_per_tensor_quant", &aiter::dynamic_per_tensor_quant); \
    m.def("dynamic_per_token_scaled_quant",                              \
          &aiter::dynamic_per_token_scaled_quant,                        \
          py::arg("out"),                                                \
          py::arg("input"),                                              \
          py::arg("scales"),                                             \
          py::arg("scale_ub")        = std::nullopt,                     \
          py::arg("shuffle_scale")   = false,                            \
          py::arg("num_rows")        = std::nullopt,                     \
          py::arg("num_rows_factor") = 1);                               \
    m.def("dynamic_per_group_scaled_quant_fp4",                          \
          &aiter::dynamic_per_group_scaled_quant_fp4,                    \
          py::arg("out"),                                                \
          py::arg("input"),                                              \
          py::arg("scales"),                                             \
          py::arg("group_size")      = 32,                               \
          py::arg("shuffle_scale")   = true,                             \
          py::arg("num_rows")        = std::nullopt,                     \
          py::arg("num_rows_factor") = 1);                               \
    m.def("smooth_per_token_scaled_quant",                               \
          &aiter::smooth_per_token_scaled_quant,                         \
          py::arg("out"),                                                \
          py::arg("input"),                                              \
          py::arg("scales"),                                             \
          py::arg("smooth_scale"),                                       \
          py::arg("smooth_scale_map")      = std::nullopt,               \
          py::arg("shuffle_scale")         = false,                      \
          py::arg("num_rows")              = std::nullopt,               \
          py::arg("num_rows_factor")       = 1,                          \
          py::arg("smooth_scale_map_hash") = std::nullopt,               \
          py::arg("enable_ps")             = true);                      \
    m.def("partial_transpose",                                           \
          &aiter::partial_transpose,                                     \
          py::arg("out"),                                                \
          py::arg("input"),                                              \
          py::arg("num_rows"));
```
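The failure mode described above can be reproduced in miniature. This hypothetical stand-in uses `types.SimpleNamespace` in place of the compiled `module_quant` extension to show the `AttributeError` that `getattr` raises once a binding has been dropped:

```python
import types

# Hypothetical stand-in for the compiled module_quant extension: only
# bindings that QUANT_PYBIND still defines are present.
module_quant = types.SimpleNamespace(
    static_per_tensor_quant=lambda *args: None,
    smooth_per_token_scaled_quant=lambda *args: None,
)

# The Python wrapper still looks up the v1 symbol, which now fails:
try:
    getattr(module_quant, "moe_smooth_per_token_scaled_quant_v1")
    dropped_binding_found = True
except AttributeError:
    dropped_binding_found = False
```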


Comment on lines +220 to 222
```python
torch.isclose(sr, st, rtol=rtol, atol=atol)
logger.info(
    f"{msg} [{i_row}x{i_col}], r:{r_idx[i_row, i_col]}->{sr}, t:{t_idx[i_row, i_col]}->{st}"
```
Comment on lines +4 to +12
```cpp
#include "aiter_enum.h"
#include <cstdint>

struct AiterTensor
{
    void* ptr;        // data_ptr, pointer to GPU memory
    size_t numel_;    // total number of elements
    int ndim;         // number of dimensions
    int64_t shape[8]; // size of each dimension, up to 8 dims (PyTorch limit)
```
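On the Python side, the struct above maps naturally onto a `ctypes.Structure`. The sketch below is a hypothetical mirror: the trailing fields (`strides`, `dtype_`, `device_id`) are inferred from the assignments `torch_to_aiter()` makes elsewhere in this PR, and the actual definition in `aiter/utility/aiter_types.py` may differ in detail.

```python
import ctypes

# Hypothetical ctypes mirror of AiterTensor; field names follow the C
# struct excerpt and the assignments made in torch_to_aiter(). The real
# definition in aiter/utility/aiter_types.py may differ.
class AiterTensor(ctypes.Structure):
    _fields_ = [
        ("ptr", ctypes.c_void_p),        # data_ptr, pointer to GPU memory
        ("numel_", ctypes.c_size_t),     # total number of elements
        ("ndim", ctypes.c_int),          # number of dimensions
        ("shape", ctypes.c_int64 * 8),   # size of each dimension
        ("strides", ctypes.c_int64 * 8), # stride of each dimension
        ("dtype_", ctypes.c_int),        # AiterDtype enum value
        ("device_id", ctypes.c_int),     # GPU device index
    ]

# Filling it in by hand (normally done by torch_to_aiter()), e.g. for a
# contiguous 64x8 tensor:
t = AiterTensor()
t.ndim = 2
t.shape[0], t.shape[1] = 64, 8
t.strides[0], t.strides[1] = 8, 1
t.numel_ = 64 * 8
```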
Comment on lines 3 to +55
```cpp
// @@ -18,3 +19,55 @@ enum class QuantType : int
    per_1x128,
    per_128x128,
};

typedef enum {
    AITER_DTYPE_fp8,
    AITER_DTYPE_fp8_e8m0,
    AITER_DTYPE_fp16,
    AITER_DTYPE_bf16,
    AITER_DTYPE_fp32,
    AITER_DTYPE_i4x2,
    AITER_DTYPE_fp4x2,
    AITER_DTYPE_u32,
    AITER_DTYPE_i32,
    AITER_DTYPE_i16,
    AITER_DTYPE_i8,
    AITER_DTYPE_u8,
} AiterDtype;

static inline size_t AiterDtype_element_size(AiterDtype dtype)
{
    switch (dtype) {
    case AITER_DTYPE_fp8:
    case AITER_DTYPE_fp8_e8m0:
    case AITER_DTYPE_i4x2:
    case AITER_DTYPE_fp4x2:
    case AITER_DTYPE_i8:
    case AITER_DTYPE_u8: return 1;
    case AITER_DTYPE_fp16:
    case AITER_DTYPE_bf16:
    case AITER_DTYPE_i16: return 2;
    case AITER_DTYPE_fp32:
    case AITER_DTYPE_u32:
    case AITER_DTYPE_i32: return 4;
    default: return 0;
    }
}
```
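For reference, the same element-size table can be expressed on the Python side. This is an illustrative helper, not part of the PR, useful for sanity-checking buffer sizes on the ctypes path:

```python
# Python mirror of AiterDtype_element_size (illustrative, not in the PR).
AITER_ELEMENT_SIZE = {
    "fp8": 1, "fp8_e8m0": 1, "i4x2": 1, "fp4x2": 1, "i8": 1, "u8": 1,
    "fp16": 2, "bf16": 2, "i16": 2,
    "fp32": 4, "u32": 4, "i32": 4,
}

def buffer_nbytes(dtype_name: str, numel: int) -> int:
    """Byte size of a buffer holding `numel` elements of `dtype_name`."""
    return AITER_ELEMENT_SIZE[dtype_name] * numel
```

Note that the packed types (`i4x2`, `fp4x2`) store two values per element, which is consistent with the C helper's per-element size of 1 byte.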

Comment on lines 22 to +56
```python
i4x2 = getattr(torch, "int4", _8bit_fallback)
fp4x2 = getattr(torch, "float4_e2m1fn_x2", _8bit_fallback)
fp8 = get_dtype_fp8()
fp8_e8m0 = getattr(torch, "float8_e8m0fnu", _8bit_fallback)
fp16 = torch.float16
bf16 = torch.bfloat16
fp32 = torch.float32
u32 = torch.uint32
i32 = torch.int32
i16 = torch.int16
i8 = torch.int8

d_dtypes = {
    "fp8": fp8,
    "fp8_e8m0": fp8_e8m0,
    "fp16": fp16,
    "bf16": bf16,
    "fp32": fp32,
    "i4x2": i4x2,
    "fp4x2": fp4x2,
    "u32": u32,
    "i32": i32,
    "i16": i16,
    "i8": i8,
}
u8 = torch.uint8

d_dtypes = {name: globals()[name] for name in aiter_dtypes}

globals().update({f"AITER_DTYPE_{name}": idx for name, idx in aiter_dtypes.items()})
_torch_to_aiter_dtype = {globals()[name]: idx for name, idx in aiter_dtypes.items()}


def torch_to_aiter(tensor: torch.Tensor) -> AiterTensor:
    """torch.Tensor -> AiterTensor, zero-copy, points to the same GPU memory."""
    assert tensor.is_cuda, "AiterTensor only supports CUDA tensors"
    assert tensor.ndim <= 8, f"AiterTensor supports at most 8 dims, got {tensor.ndim}"
    assert tensor.dtype in _torch_to_aiter_dtype, f"Unsupported dtype: {tensor.dtype}"

    at = AiterTensor()
    at.ptr = tensor.data_ptr()
    at.numel_ = tensor.numel()
    at.ndim = tensor.ndim
    for i in range(tensor.ndim):
        at.shape[i] = tensor.shape[i]
        at.strides[i] = tensor.stride(i)
    at.dtype_ = _torch_to_aiter_dtype[tensor.dtype]
    at.device_id = tensor.device.index or 0
    return at
```

3 participants