Skip to content

feat(ascend): op-norm-rope group — Swiglu, SiluAndMul, CausalSoftmax, RmsNorm, AddRmsNorm, ApplyRotaryPosEmb, RotaryEmbedding#66

Merged
voltjia merged 27 commits intomasterfrom
feat/ascend-op-norm-rope
Apr 24, 2026
Merged

feat(ascend): op-norm-rope group — Swiglu, SiluAndMul, CausalSoftmax, RmsNorm, AddRmsNorm, ApplyRotaryPosEmb, RotaryEmbedding#66
voltjia merged 27 commits intomasterfrom
feat/ascend-op-norm-rope

Conversation

@zhangyue207
Copy link
Copy Markdown
Collaborator

Summary

Seven layer-level Ascend operators — Swiglu, SiluAndMul, CausalSoftmax,
RmsNorm, AddRmsNorm, ApplyRotaryPosEmb, RotaryEmbedding — covering the
norm + RoPE paths used by every transformer inference layer.

Part 3 of 4 in the Ascend operator split. Parallel-reviewable with
op-simple and op-cache-attn (operator sets are disjoint; no textual
conflict).

Depends on: feat/ascend-framework-pr must merge first.

Operators

op impl Notes
Swiglu aclnnSilu + aclnnMul (decomposed); kernel_fused.h wraps fused swiglu where the ACLNN version is reliable
SiluAndMul custom AscendC kernel (kernel.h routes to ascend_kernel::silu_and_mul)
CausalSoftmax aclnnSoftmax + pre-computed mask
RmsNorm aclnnRmsNorm (kernel.h); custom AscendC variant (kernel_custom.h)
AddRmsNorm 3 impls: decomposed aclnnAdd+aclnnRmsNorm (kernel.h); fused aclnnAddRmsNorm (kernel_fused.h); custom AscendC (kernel_custom.h)
ApplyRotaryPosEmb aclnnApplyRotaryPosEmbV2 (kernel.h); ATB RopeParam (kernel_atb.h)
RotaryEmbedding 3 impls: V2 (kernel.h); ATB RopeParam with both neox / interleave (kernel_atb.h); aclnnRopeWithSinCosCache for partial rotary (kernel_sincos_cache.h) see API alignment below

vLLM API alignment

src/base/rotary_embedding.h: query_out / key_out are now
std::optional<Tensor>. When omitted, the kernel writes back in place on
query / key — matches vLLM's RotaryEmbedding.forward(positions, query, key) inplace signature. Explicit out buffers are still supported.
All three Ascend impls resolve the optional to a concrete tensor via
value_or(query).

test_rotary_embedding_inplace covers fp16 / bf16 × impl=0 / impl=1.
Tolerance is atol=5e-3, matching the V2 ~4 ULP fp16 accumulator error
documented in kernel.h.

Base headers

  • NEW: src/base/apply_rotary_pos_emb.h, silu_and_mul.h
  • MODIFY: src/base/rotary_embedding.h (optional out), add_rms_norm.h
    (constructor parameter alignment), rms_norm.h, causal_softmax.h,
    swiglu.h

Verification

  • python3 .ci/run.py --local --gpu-id <N> (Ascend 910B + CANN 8.5.1):
    3347 passed / 1684 skipped / 0 failed

Test plan

  • python3 .ci/run.py --local
  • test_rotary_embedding_inplace (fp16/bf16 × impl=0/1): 4 passed
  • clang-format passes locally
  • CUDA / Metax / Cambricon / Moore / Iluvatar regressions (CI-verified)

@zhangyue207
Copy link
Copy Markdown
Collaborator Author

merge test

[gw0] [ 99%] PASSED tests/test_swiglu.py::test_swiglu[npu-dtype2-0.01-0.005-1-shape3-input_strides3-gate_strides3-out_strides3] 
tests/test_swiglu.py::test_swiglu[npu-dtype2-0.01-0.005-1-shape4-None-None-None] 
[gw0] [ 99%] PASSED tests/test_swiglu.py::test_swiglu[npu-dtype2-0.01-0.005-1-shape4-None-None-None] 
tests/test_swiglu.py::test_swiglu[npu-dtype2-0.01-0.005-1-shape5-input_strides5-gate_strides5-out_strides5] 
[gw0] [ 99%] PASSED tests/test_swiglu.py::test_swiglu[npu-dtype2-0.01-0.005-1-shape5-input_strides5-gate_strides5-out_strides5] 
tests/test_swiglu.py::test_swiglu[npu-dtype2-0.01-0.005-1-shape6-None-None-None] 
[gw0] [ 99%] PASSED tests/test_swiglu.py::test_swiglu[npu-dtype2-0.01-0.005-1-shape6-None-None-None] 
tests/test_swiglu.py::test_swiglu[npu-dtype2-0.01-0.005-1-shape7-input_strides7-gate_strides7-out_strides7] 
[gw0] [100%] PASSED tests/test_swiglu.py::test_swiglu[npu-dtype2-0.01-0.005-1-shape7-input_strides7-gate_strides7-out_strides7] 

----------- generated xml file: /workspace/results/test-results.xml ------------
===================== 3767 passed, 1664 skipped in 45.95s ======================
========== Summary ==========
[warn] job ascend_npu: container exited with 137 (likely docker teardown SIGKILL after clean pytest); junit XML reports no failures — treating as success
EXIT=0

@zhangyue207 zhangyue207 force-pushed the feat/ascend-op-norm-rope branch 14 times, most recently from 2b2577f to 65287ae Compare April 21, 2026 06:40
Comment thread src/ascend/add_rms_norm/kernel.h Outdated
Comment thread src/ascend/add_rms_norm/kernel.h Outdated
Comment thread src/ascend/add_rms_norm/kernel_custom.h Outdated
Comment thread src/ascend/add_rms_norm/kernel_custom.h Outdated
Comment thread src/ascend/add_rms_norm/kernel_custom.h Outdated
zhangyue207 pushed a commit that referenced this pull request Apr 21, 2026
Replaces the per-test `@pytest.mark.parametrize("implementation_index", ...)`
+ runtime `if impl not in active_indices: skip` pattern with a single hook in
`conftest.pytest_generate_tests` that emits only the (device, impl) pairs
actually active on each device.

Rationale: kernel dispatch is per-device, so cross-device union (previous
`all_active_implementation_indices` helper) polluted the matrix with impls
that the selected device can't run — runtime-skipped noise.  Joint generation
keeps the matrix to its semantic cell: "this device has this impl, so run it".

- `tests/conftest.py`: when both `device` and `implementation_index` are in
  fixturenames, emit pairs via `op_cls.active_implementation_indices(dev)`;
  fall back to a skipped placeholder (`id="skip"`) when no device has an
  active impl, avoiding `[NOTSET-...]` test IDs.
- `tests/{test_add,test_gemm,test_rms_norm,test_swiglu}.py`: drop the hardcoded
  `implementation_index` parametrize decorator and the runtime `active_indices`
  guard — conftest now handles both.
- `tests/utils.py`: remove the `all_active_implementation_indices` helper
  (superseded by per-device generation in conftest).

Same test outcome on Ascend CI (1935 passed / 1686 skipped) but the remaining
skips are now either semantically mandatory (uint dtypes unsupported by
`torch_npu`, Gemm impl=2 SFINAE-only workaround, op missing ascend impl on
op-simple pending PR #66) rather than mechanism artifacts.
voltjia pushed a commit that referenced this pull request Apr 22, 2026
…near (#65)

* feat(ascend): op-simple group — Add, Mul, Cast, Cat, Matmul, Gemm, Linear

Seven foundational Ascend operators:

| op | impl |
|---|---|
| Add | aclnnAdd |
| Mul | aclnnMul |
| Cast | aclnnCast |
| Cat | aclnnCat |
| Matmul | aclnnMatmul |
| Gemm | aclnnMm (also carries the cached-executor / workspace-pool rework) |
| Linear | aclnnMatmul + optional bias |

Also ships:
- `src/base/<op>.h` for the 5 new ops (cast/cat/linear/matmul/mul);
  `add.h` and `gemm.h` existed on master and are updated in-place
- `src/cpu/<op>/<op>.h` reference impls for cast/cat/linear/mul (add/gemm/matmul
  had CPU refs on master already)
- `tests/test_<op>.py` for each operator (add and gemm have MODIFY diffs;
  others are new)

* fix(ascend): Add/Cat destructor — use `release()` for executor-owned caches

- `add/kernel.h`: swap destroy() → release() on in_cache_/oth_cache_/out_cache_
  and drop aclDestroyAclOpExecutor (both are referenced by the Repeatable
  executor; destroying them causes double-free at shutdown per the pattern
  documented in common.h and commit 64c367c).
- `cat/kernel.h`: release all in_caches_[i] in the destructor; without it,
  ~AclTensorCache() on vector teardown double-frees descriptors held by
  tensor_list_ / executor_.
- Also group the alpha_* storage members with blank lines to match file
  convention.

* test: generate `implementation_index` dynamically from `active_implementation_indices`

Replaces hardcoded `(0, 1)` / `(0, 1, 2)` tuples in test_add, test_gemm,
test_rms_norm, test_swiglu with a union over the locally-available devices'
active implementation indices.

New helper `tests.utils.all_active_implementation_indices(op_cls)` only
iterates `get_available_devices()` to avoid `DispatchFunc::std::abort` on
device types outside the build's `ActiveDevices` set.

Effect on Ascend CI: skipped-test count drops from 3246 to 1686 — impl=1
(`cuBLASLt`) no longer parametrized when no CUDA device is visible, and
RmsNorm/Swiglu's custom-kernel slot drops out of the matrix on op-simple
where the framework layer hasn't merged the AscendC impl yet.

* test(conftest): joint `(device, implementation_index)` parametrize

Replaces the per-test `@pytest.mark.parametrize("implementation_index", ...)`
+ runtime `if impl not in active_indices: skip` pattern with a single hook in
`conftest.pytest_generate_tests` that emits only the (device, impl) pairs
actually active on each device.

Rationale: kernel dispatch is per-device, so cross-device union (previous
`all_active_implementation_indices` helper) polluted the matrix with impls
that the selected device can't run — runtime-skipped noise.  Joint generation
keeps the matrix to its semantic cell: "this device has this impl, so run it".

- `tests/conftest.py`: when both `device` and `implementation_index` are in
  fixturenames, emit pairs via `op_cls.active_implementation_indices(dev)`;
  fall back to a skipped placeholder (`id="skip"`) when no device has an
  active impl, avoiding `[NOTSET-...]` test IDs.
- `tests/{test_add,test_gemm,test_rms_norm,test_swiglu}.py`: drop the hardcoded
  `implementation_index` parametrize decorator and the runtime `active_indices`
  guard — conftest now handles both.
- `tests/utils.py`: remove the `all_active_implementation_indices` helper
  (superseded by per-device generation in conftest).

Same test outcome on Ascend CI (1935 passed / 1686 skipped) but the remaining
skips are now either semantically mandatory (uint dtypes unsupported by
`torch_npu`, Gemm impl=2 SFINAE-only workaround, op missing ascend impl on
op-simple pending PR #66) rather than mechanism artifacts.

* refactor(conftest): dedupe `_op_class_from_module`, short-circuit redundant fixture

Post-review cleanup of the joint-parametrize refactor (1dd288f):

- Extract `_op_class_from_module` as a shared helper; `skip_op_without_platform_impl` fixture now calls it instead of re-deriving the snake→pascal class name inline.
- Short-circuit the fixture when `implementation_index` is already in callspec — `pytest_generate_tests` has already pruned empty-impl pairs, so per-case `active_implementation_indices` calls are wasted work.
- Drop `try/except ImportError` inside the helper — collection has already imported `infini.ops` via test modules; masking a real import failure only turns it into a cryptic NOTSET fixture.
- Drop the `devices[0] if devices else "cpu"` fallback — `get_available_devices()` always includes `"cpu"`, making the `else` arm unreachable.

* refactor(cpu): flatten nested `DispatchFunc` in Cast; snake_case variables in Linear

Per PR #65 review:

- `src/cpu/cast/cast.h`: replace nested `DispatchFunc(in_dtype, ...)` inside
  `DispatchFunc(out_dtype, ...)` with a single multi-dispatch call
  `DispatchFunc<kCpu, AllTypes, AllTypes>({in, out}, [](in_tag, out_tag) {...})`
  per the multi-dispatch idiom documented in `CONTRIBUTING.md`.
- `src/cpu/linear/linear.h`: rename PascalCase locals to snake_case:
  `A/B/Out/Bias` → `a_ptr/b_ptr/out_ptr/bias_ptr`,
  `A_batch/B_batch/Out_batch` → `a_batch/b_batch/out_batch`,
  `M/N/K` → `m/n/k` (matching master's `src/cpu/gemm/gemm.h` which already
  uses lowercase dim names `m_/n_/k_`).

* refactor(cpu/linear): drop redundant `&& bias` guard + narrating comment

- `if (bias_ptr && bias)` → `if (bias_ptr)` (line 75). `bias_ptr` is
  `nullptr` iff `!bias` by construction at line 38, so `&& bias` is dead.
- Remove `// Determine `m`, `n`, `k` from shapes and transpose flags.` —
  the three lines below literally do exactly that; self-describing now that
  names are snake_case.

---------

Co-authored-by: zhangyue <zhangyue@example.com>
… RmsNorm, AddRmsNorm, ApplyRotaryPosEmb, RotaryEmbedding

Seven layer-level Ascend operators:

| op | impl |
|---|---|
| Swiglu | aclnnSilu + aclnnMul (decomposed); `kernel_fused.h` wraps fused swiglu where available |
| SiluAndMul | custom AscendC kernel |
| CausalSoftmax | aclnnSoftmax + pre-computed mask |
| RmsNorm | aclnnRmsNorm (kernel.h); custom AscendC variant (kernel_custom.h) |
| AddRmsNorm | 3 impls: decomposed aclnnAdd+aclnnRmsNorm (kernel.h); fused aclnnAddRmsNorm (kernel_fused.h); custom AscendC (kernel_custom.h) |
| ApplyRotaryPosEmb | aclnnApplyRotaryPosEmbV2 (kernel.h); ATB RopeParam (kernel_atb.h) |
| RotaryEmbedding | **3 impls**: aclnnApplyRotaryPosEmbV2 (kernel.h); ATB RopeParam with both neox/interleave (kernel_atb.h); aclnnRopeWithSinCosCache for partial rotary (kernel_sincos_cache.h) |

Bundles the RotaryEmbedding API alignment: `query_out` / `key_out`
are now `std::optional<Tensor>` — omitted → inplace on `query` / `key`
(matches vLLM `RotaryEmbedding.forward(positions, query, key)`).

New `src/base/<op>.h`: apply_rotary_pos_emb, silu_and_mul.
Modified: add_rms_norm (constructor signature alignment),
rotary_embedding (optional query_out/key_out).
@zhangyue207 zhangyue207 force-pushed the feat/ascend-op-norm-rope branch from 65287ae to e38d08b Compare April 22, 2026 06:39
zhangyue added 9 commits April 22, 2026 15:52
…rnel registration

- swiglu/kernel_fused.h: release() cat_out_cache_ and out_staging_cache_
  to avoid double-free; drop aclDestroyTensorList per 64c367c convention.
- silu_and_mul/kernel.h: release() out_staging_cache_ to avoid double-free.
- custom/CMakeLists.txt: add add_rms_norm sources to OP_SRCS and register
  its op_kernel via ascendc_library(no_workspace_kernel ...); without
  this, aclrtlaunch_add_rms_norm has no backing implementation.
- `x1/x2/gamma/y_out/x_out` -> `input/other/weight/out/rstd_out`.
- Propagate through base header, all three Ascend kernel variants
  (`kernel.h`, `kernel_fused.h`, `kernel_custom.h`), and test file.
- Remove stale `rstd_shape_` field from base (unused; `kernel.h` holds
  its own copy).
- Upgrade assertion messages to complete sentences with backticked
  identifiers.
… kernels

- Wrap `aclnn*` / `aclrt*` identifiers in backticks and ensure
  complete-sentence, period-terminated comments per CONTRIBUTING.md.
- `silu_and_mul` base header: upgrade assertion message to a
  complete sentence with backticked identifiers.
- Files touched: causal_softmax/kernel.h, rms_norm/kernel.h,
  swiglu/kernel.h, swiglu/kernel_fused.h, base/silu_and_mul.h.
…d coverage

- Wire `implementation_index` into joint `(device, implementation_index)`
  parametrize via conftest; enforces fixture symmetry with `test_swiglu.py`.
- Add two non-contiguous shape cases to exercise the staging-buffer copy
  path in `src/ascend/silu_and_mul/kernel.h`.
…aryPosEmb base ops

Merge the two rope base headers into one vLLM-compatible op matching
`RotaryEmbedding.forward(positions, query, key=None) -> (query, key|None)`.
`key` becomes `std::optional<Tensor>` (MLA), `query_out` / `key_out` remain
optional for the vLLM-native inplace path, and a new `bool pre_gathered`
constructor flag folds the old `ApplyRotaryPosEmb` fast path into the
unified op.

Kernel updates across all three Ascend impls:
- impl 0 (`aclnnApplyRotaryPosEmbV2`) and impl 1 (ATB `RopeParam`) accept
  the optional `key` / out tensors and honor `pre_gathered` (skipping
  internal `aclnnIndexSelect` when the caller has pre-gathered).
- impl 0 and impl 1 re-upload the expanded cos/sin tables on cache-pointer
  change (reviewer-flagged stale-pointer bug).
- impl 2 (`aclnnRopeWithSinCosCache`) destroys its per-call
  `aclOpExecutor` instead of leaking it (reviewer-flagged leak).
- Uppercase locals (`D`, `T`, `Nq`, `Nkv`, `half_D`, `hiddenQ`,
  `hiddenK`) renamed to snake_case, and `uploadCosSinCache` renamed to
  `UploadCosSinCache` per Google C++ style.
After the `ApplyRotaryPosEmb` base class was folded into the unified
`RotaryEmbedding` op, vllm-infini still calls
`infini.ops.apply_rotary_pos_emb(...)` — preserve that symbol as a
pybind11 Python-level shim bound alongside the generated
`rotary_embedding` binding.

The shim un-expands the caller's neox-duplicated `[T, head_size]` cos /
sin halves, concats into a `[T, head_size*2]` pre-gathered cache,
synthesizes `positions = arange(T)`, and forwards to the unified op
with `pre_gathered=True`.  No vllm-infini changes are needed.
…3D/partial

Consolidate `test_apply_rotary_pos_emb.py` (deleted separately) into
`test_rotary_embedding.py`:

- `test_apply_rotary_pos_emb`      — pre-gathered fast path through the
  new Python shim; asserts bit-exact parity against
  `infini.ops.rotary_embedding` on the same data.
- `test_apply_rotary_pos_emb_3d`   — 3D `[T, Nq, D]` / `[T, Nkv, D]`
  layout through the shim (reviewer gap).
- `test_rotary_embedding_partial`  — extend to cover
  `is_neox_style=False` on impl 2 (`aclnnRopeWithSinCosCache`),
  matching the reviewer's partial-rotary gap on the non-neox path.
- `_ref_rotary_embedding` now tolerates `key=None` (MLA).
…nature

Without this, the unified `RotaryEmbedding`'s new `bool pre_gathered`
parameter became a required positional kwarg on the Python side, breaking
every existing `infini.ops.rotary_embedding(...)` caller that did not
pass it.  Regex-scan the base header for `<scalar_type> name = <literal>`
patterns and emit `py::arg(name) = <literal>` in `_generate_py_args`.

Also restore the default on the virtual `operator()` override in
`src/base/rotary_embedding.h` so the regex picks it up.
…ncos executor destroy

Two in-flight regressions from the previous commit:

1. The `pre_gathered=true` path in kernel.h / kernel_atb.h assumed the
   caller's `cos_sin_cache` is `[T, head_size*2]` (dim-1 concat), but
   that layout can't be split with a flat byte offset because row-major
   contiguous layout interleaves cos and sin per row.  Change the wire
   format to `[2T, head_size]` (dim-0 concat) so the first
   `T * head_size * elem_sz` bytes are contiguous cos and the next
   are contiguous sin; update both kernels and the `apply_rotary_pos_emb`
   Python shim to match.

   Also set the initial `sin_v2_cache_` base pointer to the sin offset
   so the V2 executor captures distinct cos/sin addresses on first call.

2. `kernel_sincos_cache.h` (impl 2) SIGABRTs when the per-call
   `aclOpExecutor*` is destroyed right after `aclnnRopeWithSinCosCache`
   — the kernel is async on the stream and the executor backs the
   enqueued launch.  Revert the `aclDestroyAclOpExecutor` call (still
   leaks, but matches the prior behavior that passed all partial-rotary
   tests) and leave a TODO for proper Repeatable-executor caching once
   the input-address index layout for this kernel is confirmed.
@zhangyue207 zhangyue207 force-pushed the feat/ascend-op-norm-rope branch 5 times, most recently from 7ab7966 to d34542d Compare April 22, 2026 17:30
zhangyue added 3 commits April 23, 2026 02:06
… dep tracking

In-tree `ascendc_library()` trips a `CANN` `extract_host_stub.py` path
bug (`KeyError` on `/./workspace/...` paths in `$<TARGET_OBJECTS>`)
whenever it runs under `scikit-build-core`'s temp-dir builds.  Standalone
`src/ascend/custom/build.sh` avoids the bug by invoking a separate
`cmake` with `src/ascend/custom/` as its `SOURCE_DIR`.  This commit
drives `build.sh` from the main build so devs / CI get a working install
from a single `pip install` call.

- `option(BUILD_ASCEND_CUSTOM …)` replaces the old `BUILD_CUSTOM_KERNEL`
  (name is Ascend-specific now that the driver is CMake-native) and
  **defaults to ON**.  Non-Ascend builds ignore it (gated by
  `WITH_ASCEND` in `src/CMakeLists.txt`); users who don't want the
  `ccec` build on Ascend pass `-DBUILD_ASCEND_CUSTOM=OFF`.

- `src/CMakeLists.txt` registers `build.sh` as a build-phase
  `add_custom_command(OUTPUT …/libno_workspace_kernel.a)` with explicit
  dependencies on every `src/ascend/custom/**/*.{cpp,h}` file (via
  `file(GLOB_RECURSE … CONFIGURE_DEPENDS)`) — edits to any `op_host/` or
  `op_kernel/` source now re-trigger the build, instead of silently
  reusing a stale `.a`.  The outer `scikit-build-core` env (`CMAKE_GENERATOR`,
  `CMAKE_EXPORT_COMPILE_COMMANDS`, …) is scrubbed via `cmake -E env
  --unset=…` before invoking `build.sh` — leaving them set makes the
  nested `cmake`'s `ninja` generator emit the bug-triggering
  `/./workspace/...` paths even though the outer configure dir is clean.

- `src/ascend/custom/cmake/detect_soc.cmake` holds
  `infiniops_detect_soc(<out>)`, which parses `npu-smi info` for the
  first `910*` / `310*` entry and falls back to `Ascend910B4`.  Both
  `src/CMakeLists.txt` (outer build) and
  `src/ascend/custom/cmake/config_ascend.cmake` (sub-build driven by
  `build.sh`) `include()` this file — SOC detection lives in one place.

- `src/ascend/custom/CMakeLists.txt` pushes the main `src/` dir onto
  the interface target's `INCLUDES` property so the kernel TU can
  `#include "data_type.h"`.

- `src/ascend/custom/add_rms_norm/op_kernel/.clang-tidy`: disables all
  `clang-tidy` checks on `ccec`-compiled device code (absent from
  `compile_commands.json`, `__aicore__` macro parses incorrectly
  without `kernel_operator.h`).

Dev workflow: `pip install -e .[dev]` gives a fully working install on
Ascend; editing any custom-kernel source and re-running `pip install`
re-triggers the `ccec` build automatically.
The `AscendC` custom kernels forward `static_cast<int64_t>(input.dtype())`
to their `aclrtlaunch_*` entry points and dispatch on the same enum —
making `DataType`'s integer values part of a host↔device ABI.

Assign explicit values (`kInt8 = 0, …, kFloat64 = 11`) to pin that ABI:
reordering or inserting entries above existing ones would silently
change the integers seen by device code.  No behaviour change at call
sites (the enum is still accessed by symbolic name everywhere except
the `int64_t` cast).
bf16 was silently producing garbage / NaN on impl 1 (`rms_norm`) and
impl 2 (`add_rms_norm`): the kernels only instantiated `<half>` and
`<float>`, and the launchers mapped bf16 to the fp32 byte-size path,
so bf16 weight was read as if it were fp32 and the fp16 output cast
used `CAST_ROUND` (fp16-only alias).

Kernel dispatch:

- `op_kernel/rms_norm.cpp` / `op_kernel/add_rms_norm.cpp`: add a
  `KernelXxx<bfloat16_t>` instantiation; dispatch in the `extern "C"`
  entry is now `switch (static_cast<infini::ops::DataType>(dtypeCode))`
  (shared enum forwarded from the launcher via `int64_t`).  The
  fp16/bf16 branch uses `CAST_RINT` for the fp32 → T writeback —
  defined for both `half` and `bfloat16_t` destinations, whereas
  `CAST_ROUND` is a `half`-specific alias.

Launchers (`kernel_custom.h`):

- Store `DataType dtype_` (replaces the old `int64_t dtype_size_` which
  collapsed fp16 and bf16 onto the same code).
- Use `ascend::ToAclDtype(dtype_)` and `kDataTypeToSize.at(dtype_)`
  instead of hand-rolled ternaries (consistent with the rest of the
  Ascend backend).
- Forward `static_cast<int64_t>(dtype_)` as the kernel's `dtypeCode`.
- `extern "C" aclrtlaunch_*` forward-decl parameters renamed to
  `snake_case`; the function name itself is generated by
  `ascendc_add_operator(OP_NAME …)` and carries
  `// NOLINTNEXTLINE(readability-identifier-naming)` so `clang-tidy`
  accepts it.

Identifier naming (Google C++ Style):

- `op_kernel/*.cpp` members `snake_case_`, params / locals `snake_case`,
  constants `kPascalCase` (was `BUFFER_NUM` / `dimLength` / `inQueueX1`
  / `blockRows`, etc. — inherited from the `vllm-ascend` sample style).

Verified: `pytest tests/test_rms_norm.py tests/test_add_rms_norm.py
--devices ascend` → 144 passed / 0 failed (fp32 / fp16 / bf16 × both
ops × full shape + stride matrix).
@zhangyue207 zhangyue207 force-pushed the feat/ascend-op-norm-rope branch from d34542d to 33e99af Compare April 22, 2026 18:06
zhangyue added 2 commits April 23, 2026 03:11
…th vLLM

Bring `src/base/*.h` interfaces and tensor conventions into strict alignment
with vLLM's public kernel contracts.  Derived Ascend kernels and tests
follow.  `generated/bindings/` will regenerate on next build.

- **`SiluAndMul`**: rename `x` → `input` (matches `F.glu(input, dim)`); add
  `(input, out)` overload with `dim = -1` default to match vLLM's hardcoded
  last-dim behavior.
- **`Linear`**: add vLLM-aligned `(input, weight, bias?, out)` overload with
  weight stored as `[out_features, in_features]` (identical to
  `F.linear(input, weight, bias)`).  Deprecated 6-arg
  `(a, b, bias, trans_a, trans_b, out)` form retained.  CPU and Ascend
  subclasses gain matching 4-arg ctors that delegate to the 6-arg form with
  `trans_a = false, trans_b = true`.
- **`AddRmsNorm`**: rename `other` → `residual` (matches vLLM's
  `fused_add_rms_norm(input, residual, weight, eps)` schema); add inplace
  `(input, residual, weight, eps)` overload that forwards to the
  out-of-place primary form with aliased buffers.
- **`RotaryEmbedding`**: reorder first six parameters to match vLLM's
  `rotary_embedding(positions, query, key?, head_size, cos_sin_cache,
  is_neox)` schema verbatim; `rotary_dim` / `query_out?` / `key_out?` /
  `pre_gathered` remain as InfiniOps extensions at the tail.  Added
  `positions.dtype() == int64` assert per vLLM convention.

Verified on NPU: `pytest tests/test_{silu_and_mul,add_rms_norm,rotary_embedding,linear}.py --devices ascend` → 295 passed, 4 skipped, 0 failed.
Follow-up to `c23901a`.  Per CLAUDE.md "default to writing no comments",
strip doc-comments that narrate the change or restate well-named
identifiers from the four refactored base headers.  Keep only the one
WHY comment in `rotary_embedding.h` explaining `pre_gathered`'s
index_select+neox precondition (the name alone doesn't carry it).

Also replace the two delegating ctors in `src/cpu/linear/linear.h` with
`using Linear::Linear;` — matches the pattern already used in
`src/cpu/{rms_norm,swiglu}/*.h`, `src/cuda/{rms_norm,causal_softmax}/*.h`.

Verified: `pytest tests/test_{silu_and_mul,add_rms_norm,rotary_embedding,linear}.py --devices ascend` → 295 passed, 4 skipped.
Comment thread tests/test_rotary_embedding.py Outdated
@zhangyue207 zhangyue207 force-pushed the feat/ascend-op-norm-rope branch from 3b85437 to 305aa96 Compare April 23, 2026 04:51
- `tests/test_add_rms_norm.py`: extend `implementation_index` parametrize
  to `(0, 1, 2)`; add `_clear_add_rms_norm_cache` autouse fixture to
  avoid cross-test state pollution in the custom AscendC kernel (impl 2)
  whose cached fp32 weight buffer collides across tests with matching
  shape/dtype keys.  Coverage: +54 test cases (108 total, all green).

- `src/base/rotary_embedding.h`: assert `key.has_value()` with a TODO
  noting MLA is not yet implemented on any Ascend backend.  All three
  impls already assert `has_key_` individually; hoisting the check to
  base turns a silent crash (if a caller passes `key=None`) into a clean
  assert.  Keeps `std::optional<Tensor> key` in the signature for future
  MLA support without breaking vLLM API compatibility.

- `src/ascend/causal_softmax/kernel.h`: add justification for the
  3-primitive decomposition (no single CANN 8.5 API covers causal-mask
  + softmax; `aclnnSoftmaxV2` lacks the mask argument, and
  `aclnnScaledMaskedSoftmax` requires a pre-scaled attention score), per
  CLAUDE.md Ascend rule "never decompose when a fused API exists".

Verified: `pytest tests/test_{silu_and_mul,add_rms_norm,rotary_embedding,linear,causal_softmax}.py --devices ascend` → 349 passed, 4 skipped.
@zhangyue207 zhangyue207 force-pushed the feat/ascend-op-norm-rope branch from 305aa96 to ca52518 Compare April 23, 2026 05:07
zhangyue added 2 commits April 23, 2026 14:18
The legacy `apply_rotary_pos_emb` shim existed only as a vllm-infini
compat alias after the `ApplyRotaryPosEmb` base op was folded into the
unified `RotaryEmbedding`.  vllm-infini is out of scope for this PR, so
drop the shim entirely:

- `scripts/generate_wrappers.py`: remove `_generate_apply_rotary_pos_emb_shim`
  and the `extra_shim` emission hook — the Python-level wrapper was
  ~60 lines of pybind C++ that concatenated cos/sin, synthesized
  `positions = arange(T)`, and forwarded to `rotary_embedding` with
  `pre_gathered=True`.  Callers that need the pre-gather fast path can
  invoke `infini.ops.rotary_embedding(..., pre_gathered=True)` directly.
- `tests/test_rotary_embedding.py`: remove `test_apply_rotary_pos_emb` /
  `test_apply_rotary_pos_emb_3d` and the `_expand_cos_sin` helper that
  only those tests used.  `pre_gathered=True` remains exercised
  indirectly via `test_rotary_embedding_full` when impl 2 requires the
  caller to pre-gather (handled internally by the kernel).
- Touch up two stale `apply_rotary_pos_emb shim` comments in
  `kernel{,_atb}.h` that no longer point anywhere.

Verified: `pytest tests/ --devices ascend` → 2278 passed, 1612 skipped
(was 2306 / 1612 — delta is the 28 removed `apply_rotary_pos_emb` cases).
Fold the deleted `test_apply_rotary_pos_emb` / `_3d` cases into a single
`test_rotary_embedding_pre_gathered` that exercises the `pre_gathered`
fast path directly on the `rotary_embedding` overload (no shim).
Parametrize over 2D / 3D query-key layouts, impls 0 and 1 (impl 2 asserts
`!pre_gathered_`), neox / GPT-J styles, fp16 / bf16.  The new
`_build_pre_gathered_cache` helper constructs the `[2*T, head_size]`
wire format that `src/ascend/rotary_embedding/kernel.h` expects —
cos rows 0..T-1, sin rows T..2T-1, both neox-expanded per token.

Coverage: 12 new cases pass (4 skip for `impl=0 + not-neox`, same as the
`test_rotary_embedding_full` skip — V2 only supports `rotaryMode="half"`).

Full rotary suite: 88 passed, 8 skipped (was 80 passed, 4 skipped before
this test was added).
@zhangyue207 zhangyue207 force-pushed the feat/ascend-op-norm-rope branch from 10e0ad6 to 694506b Compare April 23, 2026 06:40
- `src/base/add_rms_norm.h`: `#include <cstddef>` — no `size_t` usage.
- `src/base/rotary_embedding.h`: same.
- `src/ascend/add_rms_norm/kernel_custom.h`: `#include <vector>` — no
  `std::vector` / `std::array` usage.

Build + 355 passed / 8 skipped on Ascend unchanged.
Comment thread src/ascend/add_rms_norm/kernel_custom.h Outdated
Comment thread src/ascend/add_rms_norm/kernel_custom.h Outdated
Comment thread src/ascend/add_rms_norm/kernel_fused.h Outdated
Comment thread src/ascend/rms_norm/kernel_custom.h Outdated
Comment thread src/ascend/rotary_embedding/kernel_atb.h Outdated
@zhangyue207
Copy link
Copy Markdown
Collaborator Author

ascend:

Installing collected packages: InfiniOps
  Attempting uninstall: InfiniOps
    Found existing installation: InfiniOps 0.1.0
    Uninstalling InfiniOps-0.1.0:
      Successfully uninstalled InfiniOps-0.1.0
Successfully installed InfiniOps-0.1.0
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.
[root@localhost workspace]# pytest tests/ --devices ascend
============================================= test session starts ==============================================
platform linux -- Python 3.11.14, pytest-9.0.3, pluggy-1.6.0
rootdir: /workspace
configfile: pyproject.toml
plugins: cov-7.1.0, xdist-3.8.0, anyio-4.13.0
collected 2298 items                                                                                           

tests/test_add.py ........................................................................ssssssssssssss [  3%]
ssssssssssssssssssssss                                                                                   [  4%]
tests/test_add_rms_norm.py ............................................................................. [  8%]
...............................                                                                          [  9%]
tests/test_cast.py ..............................                                                        [ 10%]
tests/test_cat.py .....................                                                                  [ 11%]
tests/test_causal_softmax.py ..................                                                          [ 12%]
tests/test_gemm.py ..................................................................................... [ 16%]
........................................................................................................ [ 20%]
........................................................................................................ [ 25%]
........................................................................................................ [ 29%]
........................................................................................................ [ 34%]
........................................................................................................ [ 38%]
........................................................................................................ [ 43%]
........................................................................................................ [ 47%]
........................................................................................................ [ 52%]
........................................................................................................ [ 56%]
........................................................................................................ [ 61%]
........................................................................................................ [ 65%]
........................................................................................................ [ 70%]
........................................................................................................ [ 74%]
...............................................................                                          [ 77%]
tests/test_linear.py ................................................................................... [ 81%]
.....................................                                                                    [ 82%]
tests/test_matmul.py ................................................                                    [ 84%]
tests/test_mul.py ........................................................................ssssssssssssss [ 88%]
ssssssssssssssssssssss                                                                                   [ 89%]
tests/test_rms_norm.py ........................................................................          [ 92%]
tests/test_rotary_embedding.py ..ss..ss................................................................. [ 95%]
...........s.s.....s.s.                                                                                  [ 96%]
tests/test_silu_and_mul.py .....................                                                         [ 97%]
tests/test_swiglu.py ................................................                                    [100%]

====================================== 2218 passed, 80 skipped in 13.91s =======================================

zhangyue added 3 commits April 23, 2026 19:53
Addresses inline review comments on #66 (reviewer: Ziminli) across all
PR-touched files:

- C4: strip trailing periods from assert messages; lowercase the
  sentence-starting word when it is bare English (e.g. "Ascend ..." →
  "ascend ..."), leave backticked identifiers untouched.
- G4: backtick `RmsNorm` in kernel_custom.h header comment; backtick
  `aclnn` / `cos_sin_cache` / `infini.ops.add_rms_norm(...)` in kernel
  comments that were still running raw text.
- C2: rename `aclrtlaunch_add_rms_norm` / `aclrtlaunch_rms_norm`
  forward-decl parameter names from AscendC internals (`x1, x2, y,
  x_out`) to the base-header semantic names (`input, residual, weight,
  out, residual_out`).  The extern "C" symbol is name-blind so the
  AscendC kernel .cpp can keep its local names — the wrapper .h just
  presents the public contract.
- Pre-gathered rotary test: drop the hardcoded
  `implementation_index=(0, 1)` parametrize, let conftest auto-inject
  and skip impl 2 inline (the impl 2 kernel asserts
  `!pre_gathered_`).

Verified locally (`--gpu-id 3/4/5 --local`):
  test_add_rms_norm.py:      108 passed
  test_rms_norm.py:            72 passed
  test_rotary_embedding.py:    88 passed, 16 skipped (expected:
                                          impl 2 + pre_gathered,
                                          impl 0 + non-neox)
…m order

Addresses Ziminli's comment on `aclrtlaunch_add_rms_norm` forward-decl
(#66 discussion 3115868675 / 3129096521):

- **函数名格式:** the AscendC kernel entry-points `add_rms_norm` /
  `rms_norm` are renamed to `AddRmsNorm` / `RmsNorm`.  The AscendC
  toolchain prepends `aclrtlaunch_` on the symbol regardless of case,
  so the exported names become `aclrtlaunch_AddRmsNorm` /
  `aclrtlaunch_RmsNorm` — matching the base-class names and
  `readability-identifier-naming.FunctionCase = CamelCase`.
  The `NOLINTNEXTLINE(readability-identifier-naming)` shim and the
  "PascalCase rule does not apply" apology comments go away.

- **参数列表顺序 (C2):** reorder parameters to `inputs, attributes,
  outputs`.  Both `.cpp` kernel entry, `KernelAddRmsNorm::Init` /
  `KernelRmsNorm::Init`, and the `extern "C"` forward-decl in
  `kernel_custom.h` are updated together, along with the call sites
  in `operator()`.

- **Variable naming (`.cpp` internals):** `x1/x2/y/x_out` →
  `input/residual/out/residual_out`; `x/y` → `input/out`.  Cascaded
  through member names (`*_gm_`, `*_queue_*`, `*_local`) for
  consistency — internal to each kernel class, no ABI impact.

- **`op_host/*.cpp`:** updated to include the PascalCase generated
  header `aclrtlaunch_AddRmsNorm.h` / `aclrtlaunch_RmsNorm.h` and to
  match the reordered `EXEC_KERNEL_CMD` argument list.

Verified locally with `.ci/run.py --local`:
  test_add_rms_norm.py:      108 passed
  test_rms_norm.py:            72 passed

The AscendC toolchain successfully compiles the PascalCase kernel
entries and generates matching launch headers — the
`aclrtlaunch_<ENTRY>` macro concatenates regardless of case.
/simplify found 4 comment blocks that narrate the rename rationale
rather than encode load-bearing contracts:

- `kernel_custom.h` forward-decl — compress build-system detail
  (`no_workspace_kernel`, `ascendc_library()`) to one line, keep only
  the ABI contract (`aclrtlaunch_<Entry>` is generated by AscendC from
  `op_kernel/`).
- `op_host/<op>.cpp` `EXEC_KERNEL_CMD` — drop "Parameter order follows
  the base class: inputs, attributes, outputs."; the signature itself
  is self-evident.
- `op_kernel/<op>.cpp` kernel entry — drop "Parameters follow the C2
  convention ..." and "`aclrtlaunch_AddRmsNorm` matches the base
  `AddRmsNorm` class name"; these are commit-message material, not
  comments.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants