From 76490424c6f8ce062ba37aae4937c9f4d3dec197 Mon Sep 17 00:00:00 2001 From: zhangyue Date: Sat, 18 Apr 2026 05:11:22 +0800 Subject: [PATCH 1/7] =?UTF-8?q?feat(ascend):=20op-simple=20group=20?= =?UTF-8?q?=E2=80=94=20Add,=20Mul,=20Cast,=20Cat,=20Matmul,=20Gemm,=20Line?= =?UTF-8?q?ar?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Seven foundational Ascend operators: | op | impl | |---|---| | Add | aclnnAdd | | Mul | aclnnMul | | Cast | aclnnCast | | Cat | aclnnCat | | Matmul | aclnnMatmul | | Gemm | aclnnAddmm / aclnnBaddbmm (also carries the cached-executor / workspace-pool rework) | | Linear | aclnnMatmul, or aclnnAddmm / aclnnBaddbmm when bias is given | Also ships: - `src/base/<op>.h` for the 5 new ops (cast/cat/linear/matmul/mul); `add.h` and `gemm.h` existed on master and are updated in-place - `src/cpu/<op>/<op>.h` reference impls for cast/cat/linear/mul (add/gemm/matmul had CPU refs on master already) - `tests/test_<op>.py` for each operator (add and gemm have MODIFY diffs; others are new) --- src/ascend/add/kernel.h | 92 +++++++++++++++++++++++++++ src/ascend/cast/kernel.h | 64 +++++++++++++++++++ src/ascend/cat/kernel.h | 98 +++++++++++++++++++++++++++++ src/ascend/gemm/kernel.h | 75 ++++++++++++++-------- src/ascend/linear/kernel.h | 125 +++++++++++++++++++++++++++++++++++++ src/ascend/matmul/kernel.h | 68 ++++++++++++++++++++ src/ascend/mul/kernel.h | 68 ++++++++++++++++++++ src/base/cast.h | 52 +++++++++++++++ src/base/cat.h | 35 +++++++++++ src/base/linear.h | 65 +++++++++++++++++++ src/base/mat_mul.h | 31 --------- src/base/matmul.h | 41 ++++++++++++ src/base/mul.h | 67 ++++++++++++++++++++ src/cpu/cast/cast.h | 57 +++++++++++++++++ src/cpu/cat/cat.h | 71 +++++++++++++++++++++ src/cpu/linear/linear.h | 108 ++++++++++++++++++++++++++++++++ src/cpu/mul/mul.h | 63 +++++++++++++++++++ tests/test_cast.py | 62 ++++++++++++++++++ tests/test_cat.py | 69 ++++++++++++++++++++ tests/test_linear.py | 90 ++++++++++++++++++++++++++ tests/test_matmul.py | 76 ++++++++++++++++++++++ 
tests/test_mul.py | 87 ++++++++++++++++++++++++++ 22 files changed, 1508 insertions(+), 56 deletions(-) create mode 100644 src/ascend/add/kernel.h create mode 100644 src/ascend/cast/kernel.h create mode 100644 src/ascend/cat/kernel.h create mode 100644 src/ascend/linear/kernel.h create mode 100644 src/ascend/matmul/kernel.h create mode 100644 src/ascend/mul/kernel.h create mode 100644 src/base/cast.h create mode 100644 src/base/cat.h create mode 100644 src/base/linear.h delete mode 100644 src/base/mat_mul.h create mode 100644 src/base/matmul.h create mode 100644 src/base/mul.h create mode 100644 src/cpu/cast/cast.h create mode 100644 src/cpu/cat/cat.h create mode 100644 src/cpu/linear/linear.h create mode 100644 src/cpu/mul/mul.h create mode 100644 tests/test_cast.py create mode 100644 tests/test_cat.py create mode 100644 tests/test_linear.py create mode 100644 tests/test_matmul.py create mode 100644 tests/test_mul.py diff --git a/src/ascend/add/kernel.h b/src/ascend/add/kernel.h new file mode 100644 index 00000000..73b3005b --- /dev/null +++ b/src/ascend/add/kernel.h @@ -0,0 +1,92 @@ +#ifndef INFINI_OPS_ASCEND_ADD_KERNEL_H_ +#define INFINI_OPS_ASCEND_ADD_KERNEL_H_ + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnn_add.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" +#include "base/add.h" +#include "data_type.h" +#include "operator.h" + +namespace infini::ops { + +template <> +class Operator : public Add { + public: + Operator(const Tensor input, const Tensor other, Tensor out) + : Add(input, other, out), + in_cache_(input), + oth_cache_(other), + out_cache_(out) { + // `aclCreateScalar` stores the pointer rather than copying the value, so + // `alpha_storage_*` must remain alive for the lifetime of `alpha_`. + // The alpha scalar type must match the tensor dtype: use int64 for integer + // dtypes and float for floating-point dtypes. 
+ if (ascend::IsIntegerDtype(input.dtype())) { + alpha_ = aclCreateScalar(&alpha_int_storage_, ACL_INT64); + } else { + alpha_ = aclCreateScalar(&alpha_float_storage_, ACL_FLOAT); + } + } + + ~Operator() { + if (!ascend::IsAclRuntimeAlive()) return; + + // Destroy cached tensors and the executor, then the scalar. + // Historical note: this active-destroy pattern works for `Add` at + // process exit but crashed for most other operators — see `64c367c` + // and the rest of `src/ascend/*/kernel.h` which use `release()` only. + in_cache_.destroy(); + oth_cache_.destroy(); + out_cache_.destroy(); + + if (executor_) aclDestroyAclOpExecutor(executor_); + if (alpha_) aclDestroyScalar(alpha_); + } + + void operator()(const Tensor input, const Tensor other, + Tensor out) const override { + auto stream = static_cast(stream_); + auto t_in = in_cache_.get(const_cast(input.data())); + auto t_oth = oth_cache_.get(const_cast(other.data())); + auto t_out = out_cache_.get(out.data()); + + if (!executor_) { + aclnnAddGetWorkspaceSize(t_in, t_oth, alpha_, t_out, &ws_size_, + &executor_); + aclSetAclOpExecutorRepeatable(executor_); + } else { + aclSetInputTensorAddr(executor_, 0, t_in, + const_cast(input.data())); + aclSetInputTensorAddr(executor_, 1, t_oth, + const_cast(other.data())); + aclSetOutputTensorAddr(executor_, 0, t_out, out.data()); + } + + auto& arena = ascend::GetWorkspacePool().Ensure(stream, ws_size_); + aclnnAdd(arena.buf, ws_size_, executor_, stream); + } + + private: + mutable ascend::AclTensorCache in_cache_; + + mutable ascend::AclTensorCache oth_cache_; + + mutable ascend::AclTensorCache out_cache_; + + mutable aclOpExecutor* executor_ = nullptr; + + mutable uint64_t ws_size_ = 0; + + float alpha_float_storage_ = + 1.0f; // Stable address for `aclCreateScalar` (float). + int64_t alpha_int_storage_ = + 1; // Stable address for `aclCreateScalar` (int). 
+ aclScalar* alpha_ = nullptr; +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/cast/kernel.h b/src/ascend/cast/kernel.h new file mode 100644 index 00000000..d918aa84 --- /dev/null +++ b/src/ascend/cast/kernel.h @@ -0,0 +1,64 @@ +#ifndef INFINI_OPS_ASCEND_CAST_KERNEL_H_ +#define INFINI_OPS_ASCEND_CAST_KERNEL_H_ + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnnop/aclnn_cast.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" +#include "base/cast.h" +#include "operator.h" + +namespace infini::ops { + +template <> +class Operator : public Cast { + public: + Operator(const Tensor input, Tensor out) + : Cast(input, out), + in_cache_(input), + out_cache_(out), + acl_out_dtype_(ascend::ToAclDtype(out.dtype())) {} + + ~Operator() { + if (!ascend::IsAclRuntimeAlive()) return; + + // Null cached descriptors — see `AclTensorCache::release()`. + in_cache_.release(); + out_cache_.release(); + } + + void operator()(const Tensor input, Tensor out) const override { + auto stream = static_cast(stream_); + auto t_in = in_cache_.get(const_cast(input.data())); + auto t_out = out_cache_.get(out.data()); + + if (!executor_) { + aclnnCastGetWorkspaceSize(t_in, acl_out_dtype_, t_out, &ws_size_, + &executor_); + aclSetAclOpExecutorRepeatable(executor_); + } else { + aclSetInputTensorAddr(executor_, 0, t_in, + const_cast(input.data())); + aclSetOutputTensorAddr(executor_, 0, t_out, out.data()); + } + + auto& arena = ascend::GetWorkspacePool().Ensure(stream, ws_size_); + aclnnCast(arena.buf, ws_size_, executor_, stream); + } + + private: + mutable ascend::AclTensorCache in_cache_; + + mutable ascend::AclTensorCache out_cache_; + + aclDataType acl_out_dtype_; + + mutable aclOpExecutor* executor_ = nullptr; + + mutable uint64_t ws_size_ = 0; +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/cat/kernel.h b/src/ascend/cat/kernel.h new file mode 100644 index 00000000..bb821073 --- /dev/null +++ b/src/ascend/cat/kernel.h @@ 
-0,0 +1,98 @@ +#ifndef INFINI_OPS_ASCEND_CAT_KERNEL_H_ +#define INFINI_OPS_ASCEND_CAT_KERNEL_H_ + +#include + +#include "acl/acl.h" +#include "aclnn/acl_meta.h" +#include "aclnn/aclnn_base.h" +#include "aclnnop/aclnn_cat.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" +#include "base/cat.h" +#include "operator.h" + +namespace infini::ops { + +template <> +class Operator : public Cat { + public: + Operator(const Tensor first_input, std::vector rest_inputs, + int64_t dim, Tensor out) + : Cat(first_input, rest_inputs, dim, out), out_cache_(out) { + // Build `AclTensorCache` for each input tensor. + in_caches_.reserve(input_count_); + in_caches_.emplace_back(first_input); + for (const auto& t : rest_inputs) { + in_caches_.emplace_back(t); + } + } + + ~Operator() { + if (!ascend::IsAclRuntimeAlive()) return; + + // Null cached descriptors — see `AclTensorCache::release()`. + out_cache_.release(); + + if (tensor_list_) aclDestroyTensorList(tensor_list_); + } + + void operator()(const Tensor first_input, std::vector rest_inputs, + int64_t /*dim*/, Tensor out) const override { + auto stream = static_cast(stream_); + + // Collect all input tensors in order. + std::vector inputs; + inputs.reserve(input_count_); + inputs.push_back(&first_input); + for (const auto& t : rest_inputs) { + inputs.push_back(&t); + } + + auto t_out = out_cache_.get(out.data()); + + if (!executor_) { + // First call: create descriptors, tensor list, and executor. + std::vector acl_tensors(input_count_); + for (size_t i = 0; i < input_count_; ++i) { + acl_tensors[i] = + in_caches_[i].get(const_cast(inputs[i]->data())); + } + + tensor_list_ = + aclCreateTensorList(const_cast(acl_tensors.data()), + static_cast(input_count_)); + + aclnnCatGetWorkspaceSize(tensor_list_, dim_, t_out, &ws_size_, + &executor_); + aclSetAclOpExecutorRepeatable(executor_); + } else { + // Subsequent calls: update data pointers on cached descriptors via + // `aclSetRawTensorAddr`. 
The executor holds references to the same + // `aclTensor*` objects inside `tensor_list_`, so updating their data + // pointers is sufficient — no `aclSetInputTensorAddr` needed. + for (size_t i = 0; i < input_count_; ++i) { + in_caches_[i].get(const_cast(inputs[i]->data())); + } + aclSetOutputTensorAddr(executor_, 0, t_out, out.data()); + } + + auto& arena = ascend::GetWorkspacePool().Ensure(stream, ws_size_); + aclnnCat(arena.buf, ws_size_, executor_, stream); + } + + private: + mutable std::vector in_caches_; + + mutable ascend::AclTensorCache out_cache_; + + mutable aclTensorList* tensor_list_ = nullptr; + + mutable aclOpExecutor* executor_ = nullptr; + + mutable uint64_t ws_size_ = 0; +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/gemm/kernel.h b/src/ascend/gemm/kernel.h index 16f8c50f..1795baf2 100644 --- a/src/ascend/gemm/kernel.h +++ b/src/ascend/gemm/kernel.h @@ -21,14 +21,26 @@ class Operator : public Gemm { : Gemm(a, b, alpha, beta, trans_a, trans_b, c), batched_{batch_count_ > 1}, alpha_val_{alpha.value_or(1.0f)}, - beta_val_{beta.value_or(1.0f)} { + beta_val_{beta.value_or(1.0f)}, + self_cache_(c), + a_cache_(a, trans_a_), + b_cache_(b, trans_b_), + out_cache_(c) { alpha_scalar_ = aclCreateScalar(&alpha_val_, ACL_FLOAT); beta_scalar_ = aclCreateScalar(&beta_val_, ACL_FLOAT); } ~Operator() { - aclDestroyScalar(alpha_scalar_); - aclDestroyScalar(beta_scalar_); + if (!ascend::IsAclRuntimeAlive()) return; + + // Null cached descriptors — see `AclTensorCache::release()`. 
+ self_cache_.release(); + a_cache_.release(); + b_cache_.release(); + out_cache_.release(); + + if (alpha_scalar_) aclDestroyScalar(alpha_scalar_); + if (beta_scalar_) aclDestroyScalar(beta_scalar_); } void operator()(const Tensor a, const Tensor b, std::optional alpha, @@ -36,35 +48,36 @@ class Operator : public Gemm { std::optional trans_b, Tensor c) const override { auto stream = static_cast(stream_); - auto t_self = ascend::BuildAclTensor(c); - auto t_a = ascend::BuildAclTensor(a, trans_a_); - auto t_b = ascend::BuildAclTensor(b, trans_b_); - auto t_out = ascend::BuildAclTensor(c); - - uint64_t ws_needed = 0; - aclOpExecutor* executor = nullptr; - - if (batched_) { - aclnnBaddbmmGetWorkspaceSize(t_self, t_a, t_b, beta_scalar_, - alpha_scalar_, t_out, 0, &ws_needed, - &executor); + auto t_self = self_cache_.get(c.data()); + auto t_a = a_cache_.get(const_cast(a.data())); + auto t_b = b_cache_.get(const_cast(b.data())); + auto t_out = out_cache_.get(c.data()); + + if (!executor_) { + if (batched_) { + aclnnBaddbmmGetWorkspaceSize(t_self, t_a, t_b, beta_scalar_, + alpha_scalar_, t_out, 0, &ws_size_, + &executor_); + } else { + aclnnAddmmGetWorkspaceSize(t_self, t_a, t_b, beta_scalar_, + alpha_scalar_, t_out, 0, &ws_size_, + &executor_); + } + aclSetAclOpExecutorRepeatable(executor_); } else { - aclnnAddmmGetWorkspaceSize(t_self, t_a, t_b, beta_scalar_, alpha_scalar_, - t_out, 0, &ws_needed, &executor); + aclSetInputTensorAddr(executor_, 0, t_self, c.data()); + aclSetInputTensorAddr(executor_, 1, t_a, const_cast(a.data())); + aclSetInputTensorAddr(executor_, 2, t_b, const_cast(b.data())); + aclSetOutputTensorAddr(executor_, 0, t_out, c.data()); } - auto& arena = ascend::GetWorkspacePool().Ensure(stream, ws_needed); + auto& arena = ascend::GetWorkspacePool().Ensure(stream, ws_size_); if (batched_) { - aclnnBaddbmm(arena.buf, ws_needed, executor, stream); + aclnnBaddbmm(arena.buf, ws_size_, executor_, stream); } else { - aclnnAddmm(arena.buf, ws_needed, executor, 
stream); + aclnnAddmm(arena.buf, ws_size_, executor_, stream); } - - aclDestroyTensor(t_self); - aclDestroyTensor(t_a); - aclDestroyTensor(t_b); - aclDestroyTensor(t_out); } private: @@ -77,6 +90,18 @@ class Operator : public Gemm { aclScalar* alpha_scalar_ = nullptr; aclScalar* beta_scalar_ = nullptr; + + mutable ascend::AclTensorCache self_cache_; + + mutable ascend::AclTensorCache a_cache_; + + mutable ascend::AclTensorCache b_cache_; + + mutable ascend::AclTensorCache out_cache_; + + mutable aclOpExecutor* executor_ = nullptr; + + mutable uint64_t ws_size_ = 0; }; } // namespace infini::ops diff --git a/src/ascend/linear/kernel.h b/src/ascend/linear/kernel.h new file mode 100644 index 00000000..497dd806 --- /dev/null +++ b/src/ascend/linear/kernel.h @@ -0,0 +1,125 @@ +#ifndef INFINI_OPS_ASCEND_LINEAR_KERNEL_H_ +#define INFINI_OPS_ASCEND_LINEAR_KERNEL_H_ + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnnop/aclnn_addmm.h" +#include "aclnnop/aclnn_baddbmm.h" +#include "aclnnop/aclnn_matmul.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" +#include "base/linear.h" +#include "operator.h" + +namespace infini::ops { + +template <> +class Operator : public Linear { + public: + Operator(const Tensor a, const Tensor b, std::optional bias, + bool trans_a, bool trans_b, Tensor out) + : Linear(a, b, bias, trans_a, trans_b, out), + batched_{out.ndim() > 2}, + a_cache_(a, trans_a), + b_cache_(b, trans_b), + out_cache_(out) { + if (has_bias_) { + bias_cache_ = ascend::AclTensorCache(*bias); + alpha_scalar_ = aclCreateScalar(&alpha_storage_, ACL_FLOAT); + beta_scalar_ = aclCreateScalar(&beta_storage_, ACL_FLOAT); + } + } + + ~Operator() { + if (!ascend::IsAclRuntimeAlive()) return; + + // Null cached descriptors — see `AclTensorCache::release()`. 
+ bias_cache_.release(); + a_cache_.release(); + b_cache_.release(); + out_cache_.release(); + + if (alpha_scalar_) aclDestroyScalar(alpha_scalar_); + if (beta_scalar_) aclDestroyScalar(beta_scalar_); + } + + void operator()(const Tensor a, const Tensor b, std::optional bias, + bool trans_a, bool trans_b, Tensor out) const override { + auto stream = static_cast(stream_); + auto t_a = a_cache_.get(const_cast(a.data())); + auto t_b = b_cache_.get(const_cast(b.data())); + auto t_out = out_cache_.get(out.data()); + + if (has_bias_) { + auto t_bias = bias_cache_.get(const_cast(bias->data())); + + if (!executor_) { + if (batched_) { + aclnnBaddbmmGetWorkspaceSize(t_bias, t_a, t_b, beta_scalar_, + alpha_scalar_, t_out, 0, &ws_size_, + &executor_); + } else { + aclnnAddmmGetWorkspaceSize(t_bias, t_a, t_b, beta_scalar_, + alpha_scalar_, t_out, 0, &ws_size_, + &executor_); + } + aclSetAclOpExecutorRepeatable(executor_); + } else { + aclSetInputTensorAddr(executor_, 0, t_bias, + const_cast(bias->data())); + aclSetInputTensorAddr(executor_, 1, t_a, const_cast(a.data())); + aclSetInputTensorAddr(executor_, 2, t_b, const_cast(b.data())); + aclSetOutputTensorAddr(executor_, 0, t_out, out.data()); + } + + auto& arena = ascend::GetWorkspacePool().Ensure(stream, ws_size_); + + if (batched_) { + aclnnBaddbmm(arena.buf, ws_size_, executor_, stream); + } else { + aclnnAddmm(arena.buf, ws_size_, executor_, stream); + } + } else { + if (!executor_) { + int8_t cube_math_type = 1; + aclnnMatmulGetWorkspaceSize(t_a, t_b, t_out, cube_math_type, &ws_size_, + &executor_); + aclSetAclOpExecutorRepeatable(executor_); + } else { + aclSetInputTensorAddr(executor_, 0, t_a, const_cast(a.data())); + aclSetInputTensorAddr(executor_, 1, t_b, const_cast(b.data())); + aclSetOutputTensorAddr(executor_, 0, t_out, out.data()); + } + + auto& arena = ascend::GetWorkspacePool().Ensure(stream, ws_size_); + aclnnMatmul(arena.buf, ws_size_, executor_, stream); + } + } + + private: + bool batched_; + + mutable 
ascend::AclTensorCache bias_cache_; + + mutable ascend::AclTensorCache a_cache_; + + mutable ascend::AclTensorCache b_cache_; + + mutable ascend::AclTensorCache out_cache_; + + float alpha_storage_ = 1.0f; + + float beta_storage_ = 1.0f; + + aclScalar* alpha_scalar_ = nullptr; + + aclScalar* beta_scalar_ = nullptr; + + mutable aclOpExecutor* executor_ = nullptr; + + mutable uint64_t ws_size_ = 0; +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/matmul/kernel.h b/src/ascend/matmul/kernel.h new file mode 100644 index 00000000..df05677f --- /dev/null +++ b/src/ascend/matmul/kernel.h @@ -0,0 +1,68 @@ +#ifndef INFINI_OPS_ASCEND_MATMUL_KERNEL_H_ +#define INFINI_OPS_ASCEND_MATMUL_KERNEL_H_ + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnnop/aclnn_matmul.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" +#include "base/matmul.h" +#include "operator.h" + +namespace infini::ops { + +template <> +class Operator : public Matmul { + public: + Operator(const Tensor a, const Tensor b, Tensor c, bool trans_a, bool trans_b) + : Matmul(a, b, c, trans_a, trans_b), + a_cache_(a, trans_a), + b_cache_(b, trans_b), + out_cache_(c) {} + + ~Operator() { + if (!ascend::IsAclRuntimeAlive()) return; + + // Null cached descriptors — see `AclTensorCache::release()`. 
+ a_cache_.release(); + b_cache_.release(); + out_cache_.release(); + } + + void operator()(const Tensor a, const Tensor b, Tensor c, bool trans_a, + bool trans_b) const override { + auto stream = static_cast(stream_); + auto t_a = a_cache_.get(const_cast(a.data())); + auto t_b = b_cache_.get(const_cast(b.data())); + auto t_out = out_cache_.get(c.data()); + + if (!executor_) { + int8_t cube_math_type = 1; + aclnnMatmulGetWorkspaceSize(t_a, t_b, t_out, cube_math_type, &ws_size_, + &executor_); + aclSetAclOpExecutorRepeatable(executor_); + } else { + aclSetInputTensorAddr(executor_, 0, t_a, const_cast(a.data())); + aclSetInputTensorAddr(executor_, 1, t_b, const_cast(b.data())); + aclSetOutputTensorAddr(executor_, 0, t_out, c.data()); + } + + auto& arena = ascend::GetWorkspacePool().Ensure(stream, ws_size_); + aclnnMatmul(arena.buf, ws_size_, executor_, stream); + } + + private: + mutable ascend::AclTensorCache a_cache_; + + mutable ascend::AclTensorCache b_cache_; + + mutable ascend::AclTensorCache out_cache_; + + mutable aclOpExecutor* executor_ = nullptr; + + mutable uint64_t ws_size_ = 0; +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/mul/kernel.h b/src/ascend/mul/kernel.h new file mode 100644 index 00000000..f1cfed67 --- /dev/null +++ b/src/ascend/mul/kernel.h @@ -0,0 +1,68 @@ +#ifndef INFINI_OPS_ASCEND_MUL_KERNEL_H_ +#define INFINI_OPS_ASCEND_MUL_KERNEL_H_ + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnn_mul.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" +#include "base/mul.h" +#include "operator.h" + +namespace infini::ops { + +template <> +class Operator : public Mul { + public: + Operator(const Tensor input, const Tensor other, Tensor out) + : Mul(input, other, out), + in_cache_(input), + oth_cache_(other), + out_cache_(out) {} + + ~Operator() { + if (!ascend::IsAclRuntimeAlive()) return; + + // Null cached descriptors — see `AclTensorCache::release()`. 
+ in_cache_.release(); + oth_cache_.release(); + out_cache_.release(); + } + + void operator()(const Tensor input, const Tensor other, + Tensor out) const override { + auto stream = static_cast(stream_); + auto t_in = in_cache_.get(const_cast(input.data())); + auto t_oth = oth_cache_.get(const_cast(other.data())); + auto t_out = out_cache_.get(out.data()); + + if (!executor_) { + aclnnMulGetWorkspaceSize(t_in, t_oth, t_out, &ws_size_, &executor_); + aclSetAclOpExecutorRepeatable(executor_); + } else { + aclSetInputTensorAddr(executor_, 0, t_in, + const_cast(input.data())); + aclSetInputTensorAddr(executor_, 1, t_oth, + const_cast(other.data())); + aclSetOutputTensorAddr(executor_, 0, t_out, out.data()); + } + + auto& arena = ascend::GetWorkspacePool().Ensure(stream, ws_size_); + aclnnMul(arena.buf, ws_size_, executor_, stream); + } + + private: + mutable ascend::AclTensorCache in_cache_; + + mutable ascend::AclTensorCache oth_cache_; + + mutable ascend::AclTensorCache out_cache_; + + mutable aclOpExecutor* executor_ = nullptr; + + mutable uint64_t ws_size_ = 0; +}; + +} // namespace infini::ops + +#endif diff --git a/src/base/cast.h b/src/base/cast.h new file mode 100644 index 00000000..29f1f40c --- /dev/null +++ b/src/base/cast.h @@ -0,0 +1,52 @@ +#ifndef INFINI_OPS_BASE_CAST_H_ +#define INFINI_OPS_BASE_CAST_H_ + +#include "operator.h" + +namespace infini::ops { + +class Cast : public Operator { + public: + Cast(const Tensor input, Tensor out) + : ndim_{out.ndim()}, + output_size_{out.numel()}, + input_dtype_{input.dtype()}, + out_dtype_{out.dtype()}, + input_shape_{input.shape()}, + out_shape_{out.shape()}, + input_strides_{input.strides()}, + out_strides_{out.strides()}, + is_input_contiguous_{input.IsContiguous()}, + is_out_contiguous_{out.IsContiguous()} { + assert(input.numel() == out.numel() && + "the input and output of `Cast` must have the same number of " + "elements"); + } + + virtual void operator()(const Tensor input, Tensor out) const = 0; + + 
protected: + Tensor::Size ndim_{0}; + + Tensor::Size output_size_{0}; + + const DataType input_dtype_; + + const DataType out_dtype_; + + Tensor::Shape input_shape_; + + Tensor::Shape out_shape_; + + Tensor::Strides input_strides_; + + Tensor::Strides out_strides_; + + bool is_input_contiguous_{false}; + + bool is_out_contiguous_{false}; +}; + +} // namespace infini::ops + +#endif diff --git a/src/base/cat.h b/src/base/cat.h new file mode 100644 index 00000000..dcb0ba58 --- /dev/null +++ b/src/base/cat.h @@ -0,0 +1,35 @@ +#ifndef INFINI_OPS_BASE_CAT_H_ +#define INFINI_OPS_BASE_CAT_H_ + +#include + +#include "operator.h" + +namespace infini::ops { + +class Cat : public Operator { + public: + Cat(const Tensor first_input, std::vector rest_inputs, int64_t dim, + Tensor out) + : input_count_{1 + rest_inputs.size()} { + assert(input_count_ >= 2 && "`Cat` requires at least 2 input tensors"); + + auto ndim = static_cast(out.ndim()); + // Normalize negative dim (e.g. -1 means last dimension). + dim_ = dim < 0 ? dim + ndim : dim; + assert(dim_ >= 0 && dim_ < ndim && "`Cat` dim out of range"); + } + + virtual void operator()(const Tensor first_input, + std::vector rest_inputs, int64_t dim, + Tensor out) const = 0; + + protected: + int64_t dim_; + + size_t input_count_; +}; + +} // namespace infini::ops + +#endif diff --git a/src/base/linear.h b/src/base/linear.h new file mode 100644 index 00000000..a5276e61 --- /dev/null +++ b/src/base/linear.h @@ -0,0 +1,65 @@ +#ifndef INFINI_OPS_BASE_LINEAR_H_ +#define INFINI_OPS_BASE_LINEAR_H_ + +#include + +#include "operator.h" + +namespace infini::ops { + +// Fused linear projection: out = a @ b (+ bias). +// +// When bias is present, computes out = a @ b + bias in a single dispatch. +// When bias is absent, computes out = a @ b (equivalent to Matmul). +// `trans_a` / `trans_b`: If true, transpose the last two dims before +// multiplying. 
+class Linear : public Operator { + public: + Linear(const Tensor a, const Tensor b, std::optional bias, + bool trans_a, bool trans_b, Tensor out) + : a_shape_{a.shape()}, + b_shape_{b.shape()}, + out_shape_{out.shape()}, + a_strides_{a.strides()}, + b_strides_{b.strides()}, + out_strides_{out.strides()}, + trans_a_{trans_a}, + trans_b_{trans_b}, + has_bias_{bias.has_value()} { + assert(a.dtype() == b.dtype() && + "operator `Linear` requires a and b to have the same dtype"); + assert(a.dtype() == out.dtype() && + "operator `Linear` requires a and out to have the same dtype"); + if (has_bias_) { + assert(bias->dtype() == out.dtype() && + "operator `Linear` requires bias and out to have the same dtype"); + } + } + + virtual void operator()(const Tensor a, const Tensor b, + std::optional bias, bool trans_a, + bool trans_b, Tensor out) const = 0; + + protected: + Tensor::Shape a_shape_; + + Tensor::Shape b_shape_; + + Tensor::Shape out_shape_; + + Tensor::Strides a_strides_; + + Tensor::Strides b_strides_; + + Tensor::Strides out_strides_; + + bool trans_a_{false}; + + bool trans_b_{false}; + + bool has_bias_{false}; +}; + +} // namespace infini::ops + +#endif diff --git a/src/base/mat_mul.h b/src/base/mat_mul.h deleted file mode 100644 index 6180c8bf..00000000 --- a/src/base/mat_mul.h +++ /dev/null @@ -1,31 +0,0 @@ -#ifndef INFINI_OPS_BASE_MAT_MUL_H_ -#define INFINI_OPS_BASE_MAT_MUL_H_ - -#include "operator.h" -#include "tensor.h" - -namespace infini::ops { - -class MatMul : public Operator { - public: - MatMul(const Tensor input, const Tensor other, Tensor out) - : input_shape_{input.shape()}, - other_shape_{other.shape()}, - out_shape_{out.shape()} { - assert(input.dtype() == other.dtype()); - } - - virtual void operator()(const Tensor input, const Tensor other, - Tensor out) const = 0; - - protected: - Tensor::Shape input_shape_; - - Tensor::Shape other_shape_; - - Tensor::Shape out_shape_; -}; - -} // namespace infini::ops - -#endif diff --git a/src/base/matmul.h 
b/src/base/matmul.h new file mode 100644 index 00000000..071feaea --- /dev/null +++ b/src/base/matmul.h @@ -0,0 +1,41 @@ +#ifndef INFINI_OPS_BASE_MATMUL_H_ +#define INFINI_OPS_BASE_MATMUL_H_ + +#include "operator.h" +#include "tensor.h" + +namespace infini::ops { + +class Matmul : public Operator { + public: + // `trans_a` / `trans_b`: If true, transpose the last two dims of `a` / `b` + // before multiplying. These are constructor parameters so the `CacheKey` + // encodes the transposition and distinct descriptors are cached for each + // combination. + Matmul(const Tensor a, const Tensor b, Tensor c, bool trans_a, bool trans_b) + : a_shape_{a.shape()}, + b_shape_{b.shape()}, + c_shape_{c.shape()}, + trans_a_{trans_a}, + trans_b_{trans_b} { + assert(a.dtype() == b.dtype()); + } + + virtual void operator()(const Tensor a, const Tensor b, Tensor c, + bool trans_a, bool trans_b) const = 0; + + protected: + Tensor::Shape a_shape_; + + Tensor::Shape b_shape_; + + Tensor::Shape c_shape_; + + bool trans_a_{false}; + + bool trans_b_{false}; +}; + +} // namespace infini::ops + +#endif diff --git a/src/base/mul.h b/src/base/mul.h new file mode 100644 index 00000000..9e7be223 --- /dev/null +++ b/src/base/mul.h @@ -0,0 +1,67 @@ +#ifndef INFINI_OPS_BASE_MUL_H_ +#define INFINI_OPS_BASE_MUL_H_ + +#include "operator.h" + +namespace infini::ops { + +class Mul : public Operator { + public: + Mul(const Tensor input, const Tensor other, Tensor out) + : ndim_{out.ndim()}, + output_size_{out.numel()}, + input_type_{input.dtype()}, + other_type_{other.dtype()}, + out_type_{out.dtype()}, + input_shape_{input.shape()}, + other_shape_{other.shape()}, + out_shape_{out.shape()}, + input_strides_{input.strides()}, + other_strides_{other.strides()}, + out_strides_{out.strides()}, + is_input_contiguous_{input.IsContiguous()}, + is_other_contiguous_{other.IsContiguous()}, + is_out_contiguous_{out.IsContiguous()} { + assert(!out.HasBroadcastDim() && + "the output of `Mul` should NOT have 
broadcasted dim!"); + assert(input_type_ == other_type_ && other_type_ == out_type_ && + "operator `Mul` requires all input and output tensors to have the " + "same dtype"); + } + + virtual void operator()(const Tensor input, const Tensor other, + Tensor out) const = 0; + + protected: + Tensor::Size ndim_{0}; + + Tensor::Size output_size_{0}; + + const DataType input_type_; + + const DataType other_type_; + + const DataType out_type_; + + Tensor::Shape input_shape_; + + Tensor::Shape other_shape_; + + Tensor::Shape out_shape_; + + Tensor::Strides input_strides_; + + Tensor::Strides other_strides_; + + Tensor::Strides out_strides_; + + bool is_input_contiguous_{false}; + + bool is_other_contiguous_{false}; + + bool is_out_contiguous_{false}; +}; + +} // namespace infini::ops + +#endif diff --git a/src/cpu/cast/cast.h b/src/cpu/cast/cast.h new file mode 100644 index 00000000..67c8367c --- /dev/null +++ b/src/cpu/cast/cast.h @@ -0,0 +1,57 @@ +#ifndef INFINI_OPS_CPU_CAST_CAST_H_ +#define INFINI_OPS_CPU_CAST_CAST_H_ + +#include "base/cast.h" +#include "common/generic_utils.h" +#include "cpu/caster_.h" + +namespace infini::ops { + +template <> +class Operator : public Cast { + public: + Operator(const Tensor input, Tensor out) : Cast{input, out} {} + + void operator()(const Tensor input, Tensor out) const override { + DispatchFunc( + input_dtype_, + [&](auto in_tag) { + using InT = typename decltype(in_tag)::type; + DispatchFunc( + out_dtype_, + [&](auto out_tag) { + using OutT = typename decltype(out_tag)::type; + Compute(input, out); + }, + "`Operator::operator()` (out)"); + }, + "`Operator::operator()` (in)"); + } + + private: + template + void Compute(const Tensor input, Tensor out) const { + const auto* in_ptr = static_cast(input.data()); + auto* out_ptr = static_cast(out.data()); + + auto get_idx = [&](Tensor::Size i, bool is_contig, const auto* shape, + const auto* strides) { + return is_contig ? 
i : utils::IndexToOffset(i, ndim_, shape, strides); + }; + +#pragma omp parallel for + for (Tensor::Size i = 0; i < output_size_; ++i) { + auto in_idx = get_idx(i, is_input_contiguous_, input_shape_.data(), + input_strides_.data()); + auto out_idx = get_idx(i, is_out_contiguous_, out_shape_.data(), + out_strides_.data()); + + out_ptr[out_idx] = + Caster::template Cast(in_ptr[in_idx]); + } + } +}; + +} // namespace infini::ops + +#endif diff --git a/src/cpu/cat/cat.h b/src/cpu/cat/cat.h new file mode 100644 index 00000000..18b45247 --- /dev/null +++ b/src/cpu/cat/cat.h @@ -0,0 +1,71 @@ +#ifndef INFINI_OPS_CPU_CAT_CAT_H_ +#define INFINI_OPS_CPU_CAT_CAT_H_ + +#include +#include + +#include "base/cat.h" +#include "cpu/caster_.h" + +namespace infini::ops { + +template <> +class Operator : public Cat { + public: + Operator(const Tensor first_input, std::vector rest_inputs, + int64_t dim, Tensor out) + : Cat{first_input, rest_inputs, dim, out} {} + + void operator()(const Tensor first_input, std::vector rest_inputs, + int64_t /*dim*/, Tensor out) const override { + // Collect all input tensors. + std::vector inputs; + inputs.reserve(input_count_); + inputs.push_back(&first_input); + for (const auto& t : rest_inputs) { + inputs.push_back(&t); + } + + // Use normalized `dim_` from base class (handles negative dim). + auto dim = dim_; + auto elem_size = kDataTypeToSize.at(out.dtype()); + auto ndim = out.ndim(); + auto out_shape = out.shape(); + + // Compute outer and inner sizes relative to the cat dimension. + Tensor::Size outer = 1; + for (int64_t i = 0; i < dim; ++i) { + outer *= out_shape[i]; + } + + Tensor::Size inner = 1; + for (size_t i = static_cast(dim) + 1; i < ndim; ++i) { + inner *= out_shape[i]; + } + + auto* out_ptr = static_cast(out.data()); + Tensor::Size out_dim_size = out_shape[dim]; + + // For each outer index, copy slices from each input along the cat dim. 
+ for (Tensor::Size o = 0; o < outer; ++o) { + Tensor::Size offset_in_dim = 0; + + for (size_t t = 0; t < input_count_; ++t) { + auto in_dim = inputs[t]->shape()[dim]; + auto in_ptr = static_cast(inputs[t]->data()); + + auto src_offset = (o * in_dim) * inner * elem_size; + auto dst_offset = + (o * out_dim_size + offset_in_dim) * inner * elem_size; + auto copy_size = in_dim * inner * elem_size; + + std::memcpy(out_ptr + dst_offset, in_ptr + src_offset, copy_size); + offset_in_dim += in_dim; + } + } + } +}; + +} // namespace infini::ops + +#endif diff --git a/src/cpu/linear/linear.h b/src/cpu/linear/linear.h new file mode 100644 index 00000000..f5323c2f --- /dev/null +++ b/src/cpu/linear/linear.h @@ -0,0 +1,108 @@ +#ifndef INFINI_OPS_CPU_LINEAR_LINEAR_H_ +#define INFINI_OPS_CPU_LINEAR_LINEAR_H_ + +#include + +#include "base/linear.h" +#include "common/generic_utils.h" +#include "cpu/caster_.h" + +namespace infini::ops { + +template <> +class Operator : public Linear, + Caster { + public: + Operator(const Tensor a, const Tensor b, std::optional bias, + bool trans_a, bool trans_b, Tensor out) + : Linear{a, b, bias, trans_a, trans_b, out} {} + + void operator()(const Tensor a, const Tensor b, std::optional bias, + bool trans_a, bool trans_b, Tensor out) const override { + DispatchFunc( + out.dtype(), + [&](auto tag) { + using T = typename decltype(tag)::type; + Compute(a, b, bias, trans_a, trans_b, out); + }, + "`Operator::operator()`"); + } + + private: + template + void Compute(const Tensor a, const Tensor b, std::optional bias, + bool trans_a, bool trans_b, Tensor out) const { + const auto* A = static_cast(a.data()); + const auto* B = static_cast(b.data()); + auto* Out = static_cast(out.data()); + const T* Bias = bias ? static_cast(bias->data()) : nullptr; + + // Determine M, K, N from shapes and transpose flags. 
+ auto ndim_a = a_shape_.size(); + auto ndim_b = b_shape_.size(); + auto ndim_out = out_shape_.size(); + + Tensor::Size M = out_shape_[ndim_out - 2]; + Tensor::Size N = out_shape_[ndim_out - 1]; + Tensor::Size K = trans_a ? a_shape_[ndim_a - 2] : a_shape_[ndim_a - 1]; + + // Compute strides for the inner matrix dimensions after transpose. + Tensor::Stride stride_a_m = + trans_a ? a_strides_[ndim_a - 1] : a_strides_[ndim_a - 2]; + Tensor::Stride stride_a_k = + trans_a ? a_strides_[ndim_a - 2] : a_strides_[ndim_a - 1]; + Tensor::Stride stride_b_k = + trans_b ? b_strides_[ndim_b - 1] : b_strides_[ndim_b - 2]; + Tensor::Stride stride_b_n = + trans_b ? b_strides_[ndim_b - 2] : b_strides_[ndim_b - 1]; + Tensor::Stride stride_out_m = out_strides_[ndim_out - 2]; + Tensor::Stride stride_out_n = out_strides_[ndim_out - 1]; + + // Batch dimensions. + Tensor::Size batch_count = 1; + for (size_t i = 0; i + 2 < ndim_out; ++i) { + batch_count *= out_shape_[i]; + } + + Tensor::Stride batch_stride_a = ndim_a > 2 ? a_strides_[ndim_a - 3] : 0; + Tensor::Stride batch_stride_b = ndim_b > 2 ? b_strides_[ndim_b - 3] : 0; + Tensor::Stride batch_stride_out = + ndim_out > 2 ? out_strides_[ndim_out - 3] : 0; + + // Bias stride: for 1D bias [N], stride is 1. For batched bias, use last + // stride. 
+ Tensor::Stride bias_stride = 0; + if (Bias && bias) { + auto ndim_bias = bias->shape().size(); + bias_stride = bias->strides()[ndim_bias - 1]; + } + + for (Tensor::Size batch = 0; batch < batch_count; ++batch) { + const auto* A_batch = A + batch * batch_stride_a; + const auto* B_batch = B + batch * batch_stride_b; + auto* Out_batch = Out + batch * batch_stride_out; + + for (Tensor::Size i = 0; i < M; ++i) { + for (Tensor::Size j = 0; j < N; ++j) { + float sum = 0.0f; + + for (Tensor::Size l = 0; l < K; ++l) { + float a_val = Cast(A_batch[i * stride_a_m + l * stride_a_k]); + float b_val = Cast(B_batch[l * stride_b_k + j * stride_b_n]); + sum += a_val * b_val; + } + + if (Bias) { + sum += Cast(Bias[j * bias_stride]); + } + + Out_batch[i * stride_out_m + j * stride_out_n] = Cast(sum); + } + } + } + } +}; + +} // namespace infini::ops + +#endif diff --git a/src/cpu/mul/mul.h b/src/cpu/mul/mul.h new file mode 100644 index 00000000..0bdefb96 --- /dev/null +++ b/src/cpu/mul/mul.h @@ -0,0 +1,63 @@ +#ifndef INFINI_OPS_CPU_MUL_MUL_H_ +#define INFINI_OPS_CPU_MUL_MUL_H_ + +#include + +#include "base/mul.h" +#include "common/generic_utils.h" +#include "cpu/caster_.h" + +namespace infini::ops { + +template <> +class Operator : public Mul, + Caster { + public: + Operator(const Tensor input, const Tensor other, Tensor out) + : Mul{input, other, out} {} + + void operator()(const Tensor input, const Tensor other, + Tensor out) const override { + DispatchFunc( + out_type_, + [&](auto tag) { + using T = typename decltype(tag)::type; + Compute(input, other, out); + }, + "`Operator::operator()`"); + } + + private: + template + void Compute(const Tensor input, const Tensor other, Tensor out) const { + using ComputeType = std::conditional_t || + IsFP16, + float, T>; + + const auto* input_ptr = static_cast(input.data()); + const auto* other_ptr = static_cast(other.data()); + auto* out_ptr = static_cast(out.data()); + + auto get_idx = [&](Tensor::Size i, bool is_contig, const auto* shape, 
+ const auto* strides) { + return is_contig ? i : utils::IndexToOffset(i, ndim_, shape, strides); + }; + +#pragma omp parallel for + for (Tensor::Size i = 0; i < output_size_; ++i) { + auto input_idx = get_idx(i, is_input_contiguous_, input_shape_.data(), + input_strides_.data()); + auto other_idx = get_idx(i, is_other_contiguous_, other_shape_.data(), + other_strides_.data()); + auto out_idx = get_idx(i, is_out_contiguous_, out_shape_.data(), + out_strides_.data()); + + out_ptr[out_idx] = Cast(Cast(input_ptr[input_idx]) * + Cast(other_ptr[other_idx])); + } + } +}; + +} // namespace infini::ops + +#endif diff --git a/tests/test_cast.py b/tests/test_cast.py new file mode 100644 index 00000000..bd19d934 --- /dev/null +++ b/tests/test_cast.py @@ -0,0 +1,62 @@ +import infini.ops +import pytest +import torch + +from tests.utils import Payload, empty_strided, get_stream, randn_strided + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "shape, input_strides, out_strides", + ( + ((13, 4), None, None), + ((13, 4), (10, 1), (10, 1)), + ((13, 4, 4), None, None), + ((16, 5632), None, None), + ((4, 4, 5632), None, None), + ), +) +@pytest.mark.parametrize( + ("input_dtype", "out_dtype", "rtol", "atol"), + ( + (torch.float16, torch.float32, 1e-3, 1e-3), + (torch.float32, torch.float16, 1e-3, 1e-3), + (torch.bfloat16, torch.float32, 1e-2, 5e-3), + (torch.float32, torch.bfloat16, 1e-2, 5e-3), + (torch.float16, torch.bfloat16, 1e-2, 5e-3), + (torch.bfloat16, torch.float16, 1e-2, 5e-3), + ), +) +def test_cast( + shape, + input_strides, + out_strides, + input_dtype, + out_dtype, + device, + rtol, + atol, +): + input = randn_strided(shape, input_strides, dtype=input_dtype, device=device) + out = empty_strided(shape, out_strides, dtype=out_dtype, device=device) + + return Payload( + _cast, + _torch_cast, + (input, out), + {}, + rtol=rtol, + atol=atol, + ) + + +def _cast(input, out): + infini.ops.cast(input, out, stream=get_stream(input.device)) + + return out + + +def 
_torch_cast(input, out): + out.copy_(input.to(out.dtype)) + + return out diff --git a/tests/test_cat.py b/tests/test_cat.py new file mode 100644 index 00000000..85428b53 --- /dev/null +++ b/tests/test_cat.py @@ -0,0 +1,69 @@ +import infini.ops +import pytest +import torch + +from tests.utils import Payload, empty_strided, get_stream, randn_strided + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "shapes, dim, out_shape", + ( + # 2 inputs, dim=0 + (((4, 64), (4, 64)), 0, (8, 64)), + # 2 inputs, dim=1 + (((4, 32), (4, 64)), 1, (4, 96)), + # 2 inputs, dim=-1 (negative dim) + (((4, 32), (4, 64)), -1, (4, 96)), + # 3 inputs, dim=1 + (((4, 16), (4, 32), (4, 16)), 1, (4, 64)), + # 2 inputs, dim=0, 3D + (((2, 4, 64), (2, 4, 64)), 0, (4, 4, 64)), + # 2 inputs, dim=2, 3D + (((2, 4, 32), (2, 4, 64)), 2, (2, 4, 96)), + # 4 inputs, dim=1 + (((1, 1024), (1, 1024), (1, 1024), (1, 1024)), 1, (1, 4096)), + ), +) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float32, 1e-7, 1e-7), + (torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-2, 5e-3), + ), +) +def test_cat(shapes, dim, out_shape, dtype, device, rtol, atol): + inputs = [randn_strided(s, None, dtype=dtype, device=device) for s in shapes] + out = empty_strided(out_shape, None, dtype=dtype, device=device) + + return Payload( + lambda *args: _cat(*args, dim=dim), + lambda *args: _torch_cat(*args, dim=dim), + (*inputs, out), + {}, + rtol=rtol, + atol=atol, + ) + + +def _cat(*args, dim): + inputs = list(args[:-1]) + out = args[-1] + + first = inputs[0] + rest = inputs[1:] + + infini.ops.cat(first, rest, dim, out, stream=get_stream(first.device)) + + return out + + +def _torch_cat(*args, dim): + inputs = list(args[:-1]) + out = args[-1] + + result = torch.cat(inputs, dim=dim) + out.copy_(result) + + return out diff --git a/tests/test_linear.py b/tests/test_linear.py new file mode 100644 index 00000000..364ba5fc --- /dev/null +++ b/tests/test_linear.py @@ -0,0 +1,90 @@ +import infini.ops 
+import pytest +import torch + +from tests.utils import Payload, empty_strided, get_stream, randn_strided + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "a_shape, b_shape, out_shape", + ( + ((4, 64), (64, 32), (4, 32)), + ((2, 128), (128, 256), (2, 256)), + ((1, 4096), (4096, 4096), (1, 4096)), + ((2, 4, 64), (2, 64, 32), (2, 4, 32)), + ((4, 8, 128), (4, 128, 64), (4, 8, 64)), + ), +) +@pytest.mark.parametrize("trans_a", (False, True)) +@pytest.mark.parametrize("trans_b", (False, True)) +@pytest.mark.parametrize("has_bias", (False, True)) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float32, 1e-2, 5e-2), + (torch.float16, 1e-2, 1e-2), + (torch.bfloat16, 1e-2, 1e-2), + ), +) +def test_linear( + a_shape, + b_shape, + out_shape, + trans_a, + trans_b, + has_bias, + dtype, + device, + rtol, + atol, +): + a = randn_strided(a_shape, None, dtype=dtype, device=device) + b = randn_strided(b_shape, None, dtype=dtype, device=device) + + if trans_a: + a = a.transpose(-2, -1) + + if trans_b: + b = b.transpose(-2, -1) + + # Bias shape is [N], the last dim of the output. 
+ bias = None + + if has_bias: + N = out_shape[-1] + bias = randn_strided((N,), None, dtype=dtype, device=device) + + out = empty_strided(out_shape, None, dtype=dtype, device=device) + + return Payload( + lambda *args: _linear(*args, trans_a=trans_a, trans_b=trans_b), + lambda *args: _torch_linear(*args, trans_a=trans_a, trans_b=trans_b), + (a, b, bias, out), + {}, + rtol=rtol, + atol=atol, + ) + + +def _linear(a, b, bias, out, trans_a=False, trans_b=False): + infini.ops.linear(a, b, bias, trans_a, trans_b, out, stream=get_stream(a.device)) + + return out + + +def _torch_linear(a, b, bias, out, trans_a=False, trans_b=False): + if trans_a: + a = a.transpose(-2, -1) + + if trans_b: + b = b.transpose(-2, -1) + + result = torch.matmul(a.float(), b.float()) + + if bias is not None: + result = result + bias.float() + + out.copy_(result.to(out.dtype)) + + return out diff --git a/tests/test_matmul.py b/tests/test_matmul.py new file mode 100644 index 00000000..fea3822a --- /dev/null +++ b/tests/test_matmul.py @@ -0,0 +1,76 @@ +import infini.ops +import pytest +import torch + +from tests.utils import Payload, empty_strided, get_stream, randn_strided + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "a_shape, b_shape, c_shape", + ( + ((4, 64), (64, 32), (4, 32)), + ((2, 128), (128, 256), (2, 256)), + ((2, 4, 64), (2, 64, 32), (2, 4, 32)), + ((4, 8, 128), (4, 128, 64), (4, 8, 64)), + ), +) +@pytest.mark.parametrize("trans_a", (False, True)) +@pytest.mark.parametrize("trans_b", (False, True)) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float32, 1e-2, 1e-2), + (torch.float16, 1e-2, 1e-2), + (torch.bfloat16, 1e-2, 1e-2), + ), +) +def test_matmul( + a_shape, + b_shape, + c_shape, + trans_a, + trans_b, + dtype, + device, + rtol, + atol, +): + a = randn_strided(a_shape, None, dtype=dtype, device=device) + b = randn_strided(b_shape, None, dtype=dtype, device=device) + + if trans_a: + a = a.transpose(-2, -1) + + if trans_b: + b = 
b.transpose(-2, -1) + + c = empty_strided(c_shape, None, dtype=dtype, device=device) + + return Payload( + lambda *args: _matmul(*args, trans_a=trans_a, trans_b=trans_b), + lambda *args: _torch_matmul(*args, trans_a=trans_a, trans_b=trans_b), + (a, b, c), + {}, + rtol=rtol, + atol=atol, + ) + + +def _matmul(a, b, c, trans_a=False, trans_b=False): + infini.ops.matmul(a, b, c, trans_a, trans_b, stream=get_stream(a.device)) + + return c + + +def _torch_matmul(a, b, c, trans_a=False, trans_b=False): + if trans_a: + a = a.transpose(-2, -1) + + if trans_b: + b = b.transpose(-2, -1) + + result = torch.matmul(a.float(), b.float()).to(c.dtype) + c.copy_(result) + + return c diff --git a/tests/test_mul.py b/tests/test_mul.py new file mode 100644 index 00000000..e368f96d --- /dev/null +++ b/tests/test_mul.py @@ -0,0 +1,87 @@ +import infini.ops +import pytest +import torch + +from tests.utils import ( + Payload, + empty_strided, + get_stream, + randint_strided, + randn_strided, +) + +_INT_DTYPES = (torch.int16, torch.int32, torch.int64) + +_UINT_DTYPES = tuple( + filter(None, (getattr(torch, f"uint{bits}", None) for bits in (16, 32, 64))) +) + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "shape, input_strides, other_strides, out_strides", + ( + ((13, 4), None, None, None), + ((13, 4), (10, 1), (10, 1), (10, 1)), + ((13, 4), (0, 1), None, None), + ((13, 4, 4), None, None, None), + ((13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1)), + ((13, 4, 4), (4, 0, 1), (0, 4, 1), None), + ((16, 5632), None, None, None), + ((16, 5632), (13312, 1), (13312, 1), (13312, 1)), + ((13, 16, 2), (128, 4, 1), (0, 2, 1), (64, 4, 1)), + ((13, 16, 2), (128, 4, 1), (2, 0, 1), (64, 4, 1)), + ((4, 4, 5632), None, None, None), + ((4, 4, 5632), (45056, 5632, 1), (45056, 5632, 1), (45056, 5632, 1)), + ), +) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float32, 1e-7, 1e-7), + (torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-2, 5e-3), + ) + + tuple((dtype, 0, 0) for 
dtype in _INT_DTYPES + _UINT_DTYPES), +) +def test_mul( + shape, input_strides, other_strides, out_strides, dtype, device, rtol, atol +): + if device == "musa" and dtype in _UINT_DTYPES: + pytest.skip( + "The `torch.musa` test cloning path does not support `uint16`, `uint32`, or `uint64`." + ) + + if dtype in _INT_DTYPES or dtype in _UINT_DTYPES: + input = randint_strided( + 0, 100, shape, input_strides, dtype=dtype, device=device + ) + other = randint_strided( + 0, 100, shape, other_strides, dtype=dtype, device=device + ) + else: + input = randn_strided(shape, input_strides, dtype=dtype, device=device) + other = randn_strided(shape, other_strides, dtype=dtype, device=device) + + out = empty_strided(shape, out_strides, dtype=dtype, device=device) + + return Payload(_mul, _torch_mul, (input, other, out), {}, rtol=rtol, atol=atol) + + +def _mul(input, other, out): + infini.ops.mul(input, other, out, stream=get_stream(input.device)) + + return out + + +def _torch_mul(input, other, out): + if input.dtype in _UINT_DTYPES: + input = input.to(torch.int64) + + if other.dtype in _UINT_DTYPES: + other = other.to(torch.int64) + + res = torch.mul(input, other) + out.copy_(res.to(out.dtype)) + + return out From bdbc695a069c304ad11b4c871f6194514f76f948 Mon Sep 17 00:00:00 2001 From: zhangyue Date: Tue, 21 Apr 2026 16:15:35 +0800 Subject: [PATCH 2/7] =?UTF-8?q?fix(ascend):=20Add/Cat=20destructor=20?= =?UTF-8?q?=E2=80=94=20use=20`release()`=20for=20executor-owned=20caches?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - `add/kernel.h`: swap destroy() → release() on in_cache_/oth_cache_/out_cache_ and drop aclDestroyAclOpExecutor (both are referenced by the Repeatable executor; destroying them causes double-free at shutdown per the pattern documented in common.h and commit 64c367c). 
- `cat/kernel.h`: release all in_caches_[i] in the destructor; without it, ~AclTensorCache() on vector teardown double-frees descriptors held by tensor_list_ / executor_. - Also group the alpha_* storage members with blank lines to match file convention. --- src/ascend/add/kernel.h | 27 ++++++++++++++------------- src/ascend/cat/kernel.h | 9 ++++++++- 2 files changed, 22 insertions(+), 14 deletions(-) diff --git a/src/ascend/add/kernel.h b/src/ascend/add/kernel.h index 73b3005b..251c3136 100644 --- a/src/ascend/add/kernel.h +++ b/src/ascend/add/kernel.h @@ -34,15 +34,14 @@ class Operator : public Add { ~Operator() { if (!ascend::IsAclRuntimeAlive()) return; - // Destroy cached tensors and the executor, then the scalar. - // Historical note: this active-destroy pattern works for `Add` at - // process exit but crashed for most other operators — see `64c367c` - // and the rest of `src/ascend/*/kernel.h` which use `release()` only. - in_cache_.destroy(); - oth_cache_.destroy(); - out_cache_.destroy(); - - if (executor_) aclDestroyAclOpExecutor(executor_); + // Null cached descriptors — see `AclTensorCache::release()`. The + // descriptors are still referenced by the Repeatable `executor_`, so + // skipping `aclDestroyTensor` (and leaking the executor at shutdown) + // avoids a double-free; see `64c367c`. + in_cache_.release(); + oth_cache_.release(); + out_cache_.release(); + if (alpha_) aclDestroyScalar(alpha_); } @@ -80,10 +79,12 @@ class Operator : public Add { mutable uint64_t ws_size_ = 0; - float alpha_float_storage_ = - 1.0f; // Stable address for `aclCreateScalar` (float). - int64_t alpha_int_storage_ = - 1; // Stable address for `aclCreateScalar` (int). + // Stable address for `aclCreateScalar` (float). + float alpha_float_storage_ = 1.0f; + + // Stable address for `aclCreateScalar` (int). 
+ int64_t alpha_int_storage_ = 1; + aclScalar* alpha_ = nullptr; }; diff --git a/src/ascend/cat/kernel.h b/src/ascend/cat/kernel.h index bb821073..018f966a 100644 --- a/src/ascend/cat/kernel.h +++ b/src/ascend/cat/kernel.h @@ -31,7 +31,14 @@ class Operator : public Cat { ~Operator() { if (!ascend::IsAclRuntimeAlive()) return; - // Null cached descriptors — see `AclTensorCache::release()`. + // Null cached descriptors — see `AclTensorCache::release()`. The input + // descriptors are referenced by the Repeatable `executor_` via + // `tensor_list_`, so every `in_caches_[i]` must be released alongside + // `out_cache_`; otherwise `~AclTensorCache()` double-frees them when the + // vector destructs. + for (auto& c : in_caches_) { + c.release(); + } out_cache_.release(); if (tensor_list_) aclDestroyTensorList(tensor_list_); From 9d7cb0e74c5674d12a1637ac564c50c1a3383b39 Mon Sep 17 00:00:00 2001 From: zhangyue Date: Tue, 21 Apr 2026 16:55:27 +0800 Subject: [PATCH 3/7] test: generate `implementation_index` dynamically from `active_implementation_indices` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replaces hardcoded `(0, 1)` / `(0, 1, 2)` tuples in test_add, test_gemm, test_rms_norm, test_swiglu with a union over the locally-available devices' active implementation indices. New helper `tests.utils.all_active_implementation_indices(op_cls)` only iterates `get_available_devices()` to avoid `DispatchFunc::std::abort` on device types outside the build's `ActiveDevices` set. Effect on Ascend CI: skipped-test count drops from 3246 to 1686 — impl=1 (`cuBLASLt`) no longer parametrized when no CUDA device is visible, and RmsNorm/Swiglu's custom-kernel slot drops out of the matrix on op-simple where the framework layer hasn't merged the AscendC impl yet. 
--- tests/test_add.py | 7 ++++--- tests/test_gemm.py | 13 +++++++++---- tests/test_rms_norm.py | 12 ++++++++++-- tests/test_swiglu.py | 12 ++++++++++-- tests/utils.py | 22 ++++++++++++++++++++++ 5 files changed, 55 insertions(+), 11 deletions(-) diff --git a/tests/test_add.py b/tests/test_add.py index 12c4b9b5..1a501825 100644 --- a/tests/test_add.py +++ b/tests/test_add.py @@ -4,6 +4,7 @@ from tests.utils import ( Payload, + all_active_implementation_indices, empty_strided, get_stream, randint_strided, @@ -35,9 +36,9 @@ ((4, 4, 5632), (45056, 5632, 1), (45056, 5632, 1), (45056, 5632, 1)), ), ) -# TODO: Generate implementation indices dynamically from -# `Add.active_implementation_indices` instead of hardcoding. -@pytest.mark.parametrize("implementation_index", (0, 1)) +@pytest.mark.parametrize( + "implementation_index", all_active_implementation_indices(infini.ops.Add) +) @pytest.mark.parametrize( ("dtype", "rtol", "atol"), ( diff --git a/tests/test_gemm.py b/tests/test_gemm.py index 97d06069..4e9ccc61 100644 --- a/tests/test_gemm.py +++ b/tests/test_gemm.py @@ -2,7 +2,12 @@ import pytest import torch -from tests.utils import Payload, get_stream, randn_strided +from tests.utils import ( + Payload, + all_active_implementation_indices, + get_stream, + randn_strided, +) @pytest.mark.auto_act_and_assert @@ -20,9 +25,9 @@ @pytest.mark.parametrize("beta", (-1, -0.5, 0, 0.5, 1)) @pytest.mark.parametrize("trans_a", (False, True)) @pytest.mark.parametrize("trans_b", (False, True)) -# TODO: Generate implementation indices dynamically from -# `Gemm.active_implementation_indices` instead of hardcoding. 
-@pytest.mark.parametrize("implementation_index", (0, 1, 2)) +@pytest.mark.parametrize( + "implementation_index", all_active_implementation_indices(infini.ops.Gemm) +) @pytest.mark.parametrize( ("dtype", "rtol", "atol"), ( diff --git a/tests/test_rms_norm.py b/tests/test_rms_norm.py index 52fd1ae4..be8e3203 100644 --- a/tests/test_rms_norm.py +++ b/tests/test_rms_norm.py @@ -2,7 +2,13 @@ import pytest import torch -from tests.utils import Payload, empty_strided, get_stream, randn_strided +from tests.utils import ( + Payload, + all_active_implementation_indices, + empty_strided, + get_stream, + randn_strided, +) @pytest.mark.auto_act_and_assert @@ -18,7 +24,9 @@ ), ) @pytest.mark.parametrize("eps", (1e-6, 1e-5)) -@pytest.mark.parametrize("implementation_index", (0, 1)) +@pytest.mark.parametrize( + "implementation_index", all_active_implementation_indices(infini.ops.RmsNorm) +) @pytest.mark.parametrize( ("dtype", "rtol", "atol"), ( diff --git a/tests/test_swiglu.py b/tests/test_swiglu.py index 23c29943..a8a3ac15 100644 --- a/tests/test_swiglu.py +++ b/tests/test_swiglu.py @@ -2,7 +2,13 @@ import pytest import torch -from tests.utils import Payload, empty_strided, get_stream, rand_strided +from tests.utils import ( + Payload, + all_active_implementation_indices, + empty_strided, + get_stream, + rand_strided, +) @pytest.mark.auto_act_and_assert @@ -19,7 +25,9 @@ ((4, 4, 5632), (45056, 5632, 1), (45056, 5632, 1), (45056, 5632, 1)), ), ) -@pytest.mark.parametrize("implementation_index", (0, 1)) +@pytest.mark.parametrize( + "implementation_index", all_active_implementation_indices(infini.ops.Swiglu) +) @pytest.mark.parametrize( ("dtype", "rtol", "atol"), ( diff --git a/tests/utils.py b/tests/utils.py index 982d05ae..d524214c 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -122,6 +122,28 @@ def get_stream(device): return getattr(stream, attr, 0) +def all_active_implementation_indices(op_cls): + """Union of `op_cls.active_implementation_indices(device)` across every + 
locally-available torch device type. + + Use as the `@pytest.mark.parametrize("implementation_index", ...)` value so + the test matrix grows automatically when a new backend implementation is + added. Per-device filtering (skipping indices not active on the currently + selected device) stays the test body's responsibility — see the `skip` + pattern in `test_gemm.py`. + + Limited to `get_available_devices()` to avoid `DispatchFunc::std::abort` + for device types outside the build's `ActiveDevices` set (e.g., querying + `"cuda"` on an Ascend-only build). + """ + indices = set() + + for device in get_available_devices(): + indices.update(op_cls.active_implementation_indices(device)) + + return tuple(sorted(indices)) + + def clone_strided(input): output = empty_strided( input.size(), input.stride(), dtype=input.dtype, device=input.device From 1dd288fed447e8ae83ab7d51b8cfc8f6e839f415 Mon Sep 17 00:00:00 2001 From: zhangyue Date: Tue, 21 Apr 2026 17:47:00 +0800 Subject: [PATCH 4/7] test(conftest): joint `(device, implementation_index)` parametrize MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replaces the per-test `@pytest.mark.parametrize("implementation_index", ...)` + runtime `if impl not in active_indices: skip` pattern with a single hook in `conftest.pytest_generate_tests` that emits only the (device, impl) pairs actually active on each device. Rationale: kernel dispatch is per-device, so cross-device union (previous `all_active_implementation_indices` helper) polluted the matrix with impls that the selected device can't run — runtime-skipped noise. Joint generation keeps the matrix to its semantic cell: "this device has this impl, so run it". - `tests/conftest.py`: when both `device` and `implementation_index` are in fixturenames, emit pairs via `op_cls.active_implementation_indices(dev)`; fall back to a skipped placeholder (`id="skip"`) when no device has an active impl, avoiding `[NOTSET-...]` test IDs. 
- `tests/{test_add,test_gemm,test_rms_norm,test_swiglu}.py`: drop the hardcoded `implementation_index` parametrize decorator and the runtime `active_indices` guard — conftest now handles both. - `tests/utils.py`: remove the `all_active_implementation_indices` helper (superseded by per-device generation in conftest). Same test outcome on Ascend CI (1935 passed / 1686 skipped) but the remaining skips are now either semantically mandatory (uint dtypes unsupported by `torch_npu`, Gemm impl=2 SFINAE-only workaround, op missing ascend impl on op-simple pending PR #66) rather than mechanism artifacts. --- tests/conftest.py | 60 +++++++++++++++++++++++++++++++++++++++++- tests/test_add.py | 9 ------- tests/test_gemm.py | 15 +---------- tests/test_rms_norm.py | 16 +---------- tests/test_swiglu.py | 16 +---------- tests/utils.py | 22 ---------------- 6 files changed, 62 insertions(+), 76 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 7b39007f..77104e7e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -191,7 +191,65 @@ def pytest_generate_tests(metafunc): else: devices = () - metafunc.parametrize("device", devices or available) + devices = devices or available + + # Joint `(device, implementation_index)` parametrize: generate only + # pairs where the op has an active implementation on that device. + # Avoids cross-device pollution — an impl active on `cpu` but not on + # `npu` no longer appears as a runtime skip in the npu column. + if ( + "implementation_index" in metafunc.fixturenames + and "implementation_index" not in already_parametrized + ): + op_cls = _op_class_from_module(metafunc.module) + + if op_cls is not None and hasattr(op_cls, "active_implementation_indices"): + pairs = [ + (dev, idx) + for dev in devices + for idx in op_cls.active_implementation_indices(dev) + ] + + if not pairs: + # Emit one skipped placeholder so test IDs read + # `[skip-dtype0-...]` instead of `[NOTSET-...]`. 
+ pairs = [ + pytest.param( + devices[0] if devices else "cpu", + 0, + marks=pytest.mark.skip( + reason=( + f"{op_cls.__name__} has no active " + "implementation on any available device" + ) + ), + id="skip", + ) + ] + + metafunc.parametrize("device, implementation_index", pairs) + + return + + metafunc.parametrize("device", devices) + + +def _op_class_from_module(module): + """Derive the `infini.ops.` class from a `tests/test_.py` module.""" + module_name = module.__name__.rsplit(".", 1)[-1] + + if not module_name.startswith("test_"): + return None + + op_snake = module_name[len("test_") :] + op_pascal = "".join(part.capitalize() for part in op_snake.split("_")) + + try: + import infini.ops as _ops + except ImportError: + return None + + return getattr(_ops, op_pascal, None) @pytest.hookimpl(tryfirst=True) diff --git a/tests/test_add.py b/tests/test_add.py index 1a501825..e2266c30 100644 --- a/tests/test_add.py +++ b/tests/test_add.py @@ -4,7 +4,6 @@ from tests.utils import ( Payload, - all_active_implementation_indices, empty_strided, get_stream, randint_strided, @@ -36,9 +35,6 @@ ((4, 4, 5632), (45056, 5632, 1), (45056, 5632, 1), (45056, 5632, 1)), ), ) -@pytest.mark.parametrize( - "implementation_index", all_active_implementation_indices(infini.ops.Add) -) @pytest.mark.parametrize( ("dtype", "rtol", "atol"), ( @@ -64,11 +60,6 @@ def test_add( "The `torch.musa` test cloning path does not support `uint16`, `uint32`, or `uint64`." 
) - active_indices = infini.ops.Add.active_implementation_indices(device) - - if implementation_index not in active_indices: - pytest.skip(f"implementation `{implementation_index}` not active on `{device}`") - if implementation_index == 1 and dtype in _UINT_DTYPES: pytest.skip("ATen `add` does not support unsigned integer types") diff --git a/tests/test_gemm.py b/tests/test_gemm.py index 4e9ccc61..71e0e8fd 100644 --- a/tests/test_gemm.py +++ b/tests/test_gemm.py @@ -2,12 +2,7 @@ import pytest import torch -from tests.utils import ( - Payload, - all_active_implementation_indices, - get_stream, - randn_strided, -) +from tests.utils import Payload, get_stream, randn_strided @pytest.mark.auto_act_and_assert @@ -25,9 +20,6 @@ @pytest.mark.parametrize("beta", (-1, -0.5, 0, 0.5, 1)) @pytest.mark.parametrize("trans_a", (False, True)) @pytest.mark.parametrize("trans_b", (False, True)) -@pytest.mark.parametrize( - "implementation_index", all_active_implementation_indices(infini.ops.Gemm) -) @pytest.mark.parametrize( ("dtype", "rtol", "atol"), ( @@ -61,11 +53,6 @@ def test_gemm( if device == "mlu" and dtype == torch.bfloat16: pytest.skip("`bfloat16` is not supported by `cnnlBatchMatMulEx`") - active_indices = infini.ops.Gemm.active_implementation_indices(device) - - if implementation_index not in active_indices: - pytest.skip(f"implementation `{implementation_index}` not active on `{device}`") - if implementation_index == 1 and dtype in (torch.float16, torch.bfloat16): pytest.skip("cuBLASLt half-precision exceeds current tolerances") diff --git a/tests/test_rms_norm.py b/tests/test_rms_norm.py index be8e3203..45f9199b 100644 --- a/tests/test_rms_norm.py +++ b/tests/test_rms_norm.py @@ -2,13 +2,7 @@ import pytest import torch -from tests.utils import ( - Payload, - all_active_implementation_indices, - empty_strided, - get_stream, - randn_strided, -) +from tests.utils import Payload, empty_strided, get_stream, randn_strided @pytest.mark.auto_act_and_assert @@ -24,9 +18,6 @@ ), 
) @pytest.mark.parametrize("eps", (1e-6, 1e-5)) -@pytest.mark.parametrize( - "implementation_index", all_active_implementation_indices(infini.ops.RmsNorm) -) @pytest.mark.parametrize( ("dtype", "rtol", "atol"), ( @@ -48,11 +39,6 @@ def test_rms_norm( rtol, atol, ): - active_indices = infini.ops.RmsNorm.active_implementation_indices(device) - - if implementation_index not in active_indices: - pytest.skip(f"implementation `{implementation_index}` not active on `{device}`") - input = randn_strided(input_shape, input_strides, dtype=dtype, device=device) weight = randn_strided(weight_shape, weight_strides, dtype=dtype, device=device) out = empty_strided(input_shape, out_strides, dtype=dtype, device=device) diff --git a/tests/test_swiglu.py b/tests/test_swiglu.py index a8a3ac15..f159742c 100644 --- a/tests/test_swiglu.py +++ b/tests/test_swiglu.py @@ -2,13 +2,7 @@ import pytest import torch -from tests.utils import ( - Payload, - all_active_implementation_indices, - empty_strided, - get_stream, - rand_strided, -) +from tests.utils import Payload, empty_strided, get_stream, rand_strided @pytest.mark.auto_act_and_assert @@ -25,9 +19,6 @@ ((4, 4, 5632), (45056, 5632, 1), (45056, 5632, 1), (45056, 5632, 1)), ), ) -@pytest.mark.parametrize( - "implementation_index", all_active_implementation_indices(infini.ops.Swiglu) -) @pytest.mark.parametrize( ("dtype", "rtol", "atol"), ( @@ -47,11 +38,6 @@ def test_swiglu( rtol, atol, ): - active_indices = infini.ops.Swiglu.active_implementation_indices(device) - - if implementation_index not in active_indices: - pytest.skip(f"implementation `{implementation_index}` not active on `{device}`") - input = rand_strided(shape, input_strides, dtype=dtype, device=device) gate = rand_strided(shape, gate_strides, dtype=dtype, device=device) out = empty_strided(shape, out_strides, dtype=dtype, device=device) diff --git a/tests/utils.py b/tests/utils.py index d524214c..982d05ae 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -122,28 +122,6 @@ 
def get_stream(device): return getattr(stream, attr, 0) -def all_active_implementation_indices(op_cls): - """Union of `op_cls.active_implementation_indices(device)` across every - locally-available torch device type. - - Use as the `@pytest.mark.parametrize("implementation_index", ...)` value so - the test matrix grows automatically when a new backend implementation is - added. Per-device filtering (skipping indices not active on the currently - selected device) stays the test body's responsibility — see the `skip` - pattern in `test_gemm.py`. - - Limited to `get_available_devices()` to avoid `DispatchFunc::std::abort` - for device types outside the build's `ActiveDevices` set (e.g., querying - `"cuda"` on an Ascend-only build). - """ - indices = set() - - for device in get_available_devices(): - indices.update(op_cls.active_implementation_indices(device)) - - return tuple(sorted(indices)) - - def clone_strided(input): output = empty_strided( input.size(), input.stride(), dtype=input.dtype, device=input.device From 3abf50beb1bb683dfa0929684537616c1158fd99 Mon Sep 17 00:00:00 2001 From: zhangyue Date: Tue, 21 Apr 2026 21:28:33 +0800 Subject: [PATCH 5/7] refactor(conftest): dedupe `_op_class_from_module`, short-circuit redundant fixture MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Post-review cleanup of the joint-parametrize refactor (1dd288f): - Extract `_op_class_from_module` as a shared helper; `skip_op_without_platform_impl` fixture now calls it instead of re-deriving the snake→pascal class name inline. - Short-circuit the fixture when `implementation_index` is already in callspec — `pytest_generate_tests` has already pruned empty-impl pairs, so per-case `active_implementation_indices` calls are wasted work. - Drop `try/except ImportError` inside the helper — collection has already imported `infini.ops` via test modules; masking a real import failure only turns it into a cryptic NOTSET fixture. 
- Drop the `devices[0] if devices else "cpu"` fallback — `get_available_devices()` always includes `"cpu"`, making the `else` arm unreachable. --- tests/conftest.py | 49 ++++++++++++++++++++--------------------------- 1 file changed, 21 insertions(+), 28 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 77104e7e..d995459f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -106,44 +106,33 @@ def skip_op_without_platform_impl(request): """Skip `device=` parametrizations when the op has no implementation on any of the corresponding platforms. - Derives the InfiniOps class name from the test module filename - (`tests/test_.py` → ``) and checks - `infini.ops..active_implementation_indices()` for every - platform that maps to the test's torch device type. Skips only when - every mapped platform reports no active implementation — avoids - `Fatal Python error: Aborted` from dispatching through a base class - that has no backend specialization on the current branch. + Only runs for tests that parametrize `device` but not + `implementation_index` — joint `(device, impl_idx)` parametrize in + `pytest_generate_tests` already prunes empty-impl pairs at collection + time, making this check redundant (and wasteful) on those tests. 
""" if not hasattr(request.node, "callspec"): return - device = request.node.callspec.params.get("device") - platforms = _TORCH_DEVICE_TO_PLATFORMS.get(device) - - if not platforms: - return - - module_name = request.node.module.__name__.rsplit(".", 1)[-1] + params = request.node.callspec.params - if not module_name.startswith("test_"): + if "implementation_index" in params: return - op_snake = module_name[len("test_") :] - op_pascal = "".join(part.capitalize() for part in op_snake.split("_")) + platforms = _TORCH_DEVICE_TO_PLATFORMS.get(params.get("device")) - try: - import infini.ops as _ops - except ImportError: + if not platforms: return - op_cls = getattr(_ops, op_pascal, None) + op_cls = _op_class_from_module(request.node.module) if op_cls is None or not hasattr(op_cls, "active_implementation_indices"): return if not any(op_cls.active_implementation_indices(p) for p in platforms): pytest.skip( - f"{op_pascal} has no implementation on any `{device}`-mapped platform" + f"{op_cls.__name__} has no implementation on any " + f"`{params.get('device')}`-mapped platform" ) @@ -213,9 +202,11 @@ def pytest_generate_tests(metafunc): if not pairs: # Emit one skipped placeholder so test IDs read # `[skip-dtype0-...]` instead of `[NOTSET-...]`. + # `get_available_devices()` always includes `"cpu"`, so + # `devices[0]` is safe. pairs = [ pytest.param( - devices[0] if devices else "cpu", + devices[0], 0, marks=pytest.mark.skip( reason=( @@ -235,7 +226,12 @@ def pytest_generate_tests(metafunc): def _op_class_from_module(module): - """Derive the `infini.ops.` class from a `tests/test_.py` module.""" + """Derive the `infini.ops.` class from a `tests/test_.py` module. + + Test modules have already imported `infini.ops` by the time this runs, so + no `try/except ImportError` is needed — a real import failure would have + aborted collection long before. 
+ """ module_name = module.__name__.rsplit(".", 1)[-1] if not module_name.startswith("test_"): @@ -244,10 +240,7 @@ def _op_class_from_module(module): op_snake = module_name[len("test_") :] op_pascal = "".join(part.capitalize() for part in op_snake.split("_")) - try: - import infini.ops as _ops - except ImportError: - return None + import infini.ops as _ops return getattr(_ops, op_pascal, None) From 2a28bb5833c2927917a307597b4a08b965c46bfd Mon Sep 17 00:00:00 2001 From: zhangyue Date: Wed, 22 Apr 2026 13:54:44 +0800 Subject: [PATCH 6/7] refactor(cpu): flatten nested `DispatchFunc` in Cast; snake_case variables in Linear MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per PR #65 review: - `src/cpu/cast/cast.h`: replace nested `DispatchFunc(in_dtype, ...)` inside `DispatchFunc(out_dtype, ...)` with a single multi-dispatch call `DispatchFunc({in, out}, [](in_tag, out_tag) {...})` per the multi-dispatch idiom documented in `CONTRIBUTING.md`. - `src/cpu/linear/linear.h`: rename PascalCase locals to snake_case: `A/B/Out/Bias` → `a_ptr/b_ptr/out_ptr/bias_ptr`, `A_batch/B_batch/Out_batch` → `a_batch/b_batch/out_batch`, `M/N/K` → `m/n/k` (matching master's `src/cpu/gemm/gemm.h` which already uses lowercase dim names `m_/n_/k_`). 
--- src/cpu/cast/cast.h | 17 ++++++----------- src/cpu/linear/linear.h | 42 ++++++++++++++++++++--------------------- 2 files changed, 27 insertions(+), 32 deletions(-) diff --git a/src/cpu/cast/cast.h b/src/cpu/cast/cast.h index 67c8367c..ef89b8ac 100644 --- a/src/cpu/cast/cast.h +++ b/src/cpu/cast/cast.h @@ -13,19 +13,14 @@ class Operator : public Cast { Operator(const Tensor input, Tensor out) : Cast{input, out} {} void operator()(const Tensor input, Tensor out) const override { - DispatchFunc( - input_dtype_, - [&](auto in_tag) { + DispatchFunc( + {input_dtype_, out_dtype_}, + [&](auto in_tag, auto out_tag) { using InT = typename decltype(in_tag)::type; - DispatchFunc( - out_dtype_, - [&](auto out_tag) { - using OutT = typename decltype(out_tag)::type; - Compute(input, out); - }, - "`Operator::operator()` (out)"); + using OutT = typename decltype(out_tag)::type; + Compute(input, out); }, - "`Operator::operator()` (in)"); + "`Operator::operator()`"); } private: diff --git a/src/cpu/linear/linear.h b/src/cpu/linear/linear.h index f5323c2f..50bc8aeb 100644 --- a/src/cpu/linear/linear.h +++ b/src/cpu/linear/linear.h @@ -32,19 +32,19 @@ class Operator : public Linear, template void Compute(const Tensor a, const Tensor b, std::optional bias, bool trans_a, bool trans_b, Tensor out) const { - const auto* A = static_cast(a.data()); - const auto* B = static_cast(b.data()); - auto* Out = static_cast(out.data()); - const T* Bias = bias ? static_cast(bias->data()) : nullptr; + const auto* a_ptr = static_cast(a.data()); + const auto* b_ptr = static_cast(b.data()); + auto* out_ptr = static_cast(out.data()); + const T* bias_ptr = bias ? static_cast(bias->data()) : nullptr; - // Determine M, K, N from shapes and transpose flags. + // Determine `m`, `n`, `k` from shapes and transpose flags. 
auto ndim_a = a_shape_.size(); auto ndim_b = b_shape_.size(); auto ndim_out = out_shape_.size(); - Tensor::Size M = out_shape_[ndim_out - 2]; - Tensor::Size N = out_shape_[ndim_out - 1]; - Tensor::Size K = trans_a ? a_shape_[ndim_a - 2] : a_shape_[ndim_a - 1]; + Tensor::Size m = out_shape_[ndim_out - 2]; + Tensor::Size n = out_shape_[ndim_out - 1]; + Tensor::Size k = trans_a ? a_shape_[ndim_a - 2] : a_shape_[ndim_a - 1]; // Compute strides for the inner matrix dimensions after transpose. Tensor::Stride stride_a_m = @@ -69,34 +69,34 @@ class Operator : public Linear, Tensor::Stride batch_stride_out = ndim_out > 2 ? out_strides_[ndim_out - 3] : 0; - // Bias stride: for 1D bias [N], stride is 1. For batched bias, use last + // Bias stride: for 1D bias `[n]`, stride is 1. For batched bias, use last // stride. Tensor::Stride bias_stride = 0; - if (Bias && bias) { + if (bias_ptr && bias) { auto ndim_bias = bias->shape().size(); bias_stride = bias->strides()[ndim_bias - 1]; } for (Tensor::Size batch = 0; batch < batch_count; ++batch) { - const auto* A_batch = A + batch * batch_stride_a; - const auto* B_batch = B + batch * batch_stride_b; - auto* Out_batch = Out + batch * batch_stride_out; + const auto* a_batch = a_ptr + batch * batch_stride_a; + const auto* b_batch = b_ptr + batch * batch_stride_b; + auto* out_batch = out_ptr + batch * batch_stride_out; - for (Tensor::Size i = 0; i < M; ++i) { - for (Tensor::Size j = 0; j < N; ++j) { + for (Tensor::Size i = 0; i < m; ++i) { + for (Tensor::Size j = 0; j < n; ++j) { float sum = 0.0f; - for (Tensor::Size l = 0; l < K; ++l) { - float a_val = Cast(A_batch[i * stride_a_m + l * stride_a_k]); - float b_val = Cast(B_batch[l * stride_b_k + j * stride_b_n]); + for (Tensor::Size l = 0; l < k; ++l) { + float a_val = Cast(a_batch[i * stride_a_m + l * stride_a_k]); + float b_val = Cast(b_batch[l * stride_b_k + j * stride_b_n]); sum += a_val * b_val; } - if (Bias) { - sum += Cast(Bias[j * bias_stride]); + if (bias_ptr) { + sum += 
Cast(bias_ptr[j * bias_stride]); } - Out_batch[i * stride_out_m + j * stride_out_n] = Cast(sum); + out_batch[i * stride_out_m + j * stride_out_n] = Cast(sum); } } } From b8f874b726f3e88bc90dd749789a9e3fcb0f92c0 Mon Sep 17 00:00:00 2001 From: zhangyue Date: Wed, 22 Apr 2026 13:59:13 +0800 Subject: [PATCH 7/7] refactor(cpu/linear): drop redundant `&& bias` guard + narrating comment MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - `if (bias_ptr && bias)` → `if (bias_ptr)` (line 75). `bias_ptr` is `nullptr` iff `!bias` by construction at line 38, so `&& bias` is dead. - Remove `// Determine `m`, `n`, `k` from shapes and transpose flags.` — the three lines below literally do exactly that; self-describing now that names are snake_case. --- src/cpu/linear/linear.h | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/cpu/linear/linear.h b/src/cpu/linear/linear.h index 50bc8aeb..21e1bb26 100644 --- a/src/cpu/linear/linear.h +++ b/src/cpu/linear/linear.h @@ -37,7 +37,6 @@ class Operator : public Linear, auto* out_ptr = static_cast(out.data()); const T* bias_ptr = bias ? static_cast(bias->data()) : nullptr; - // Determine `m`, `n`, `k` from shapes and transpose flags. auto ndim_a = a_shape_.size(); auto ndim_b = b_shape_.size(); auto ndim_out = out_shape_.size(); @@ -72,7 +71,7 @@ class Operator : public Linear, // Bias stride: for 1D bias `[n]`, stride is 1. For batched bias, use last // stride. Tensor::Stride bias_stride = 0; - if (bias_ptr && bias) { + if (bias_ptr) { auto ndim_bias = bias->shape().size(); bias_stride = bias->strides()[ndim_bias - 1]; }