93 changes: 93 additions & 0 deletions src/ascend/add/kernel.h
@@ -0,0 +1,93 @@
#ifndef INFINI_OPS_ASCEND_ADD_KERNEL_H_
#define INFINI_OPS_ASCEND_ADD_KERNEL_H_

#include "acl/acl.h"
#include "aclnn/aclnn_base.h"
#include "aclnn_add.h"
#include "ascend/common.h"
#include "ascend/workspace_pool_.h"
#include "base/add.h"
#include "data_type.h"
#include "operator.h"

namespace infini::ops {

template <>
class Operator<Add, Device::Type::kAscend> : public Add {
public:
Operator(const Tensor input, const Tensor other, Tensor out)
: Add(input, other, out),
in_cache_(input),
oth_cache_(other),
out_cache_(out) {
    // `aclCreateScalar` stores the pointer rather than copying the value, so
    // the `alpha_*_storage_` member must stay alive for the lifetime of
    // `alpha_`. The alpha scalar type must match the tensor dtype: int64 for
    // integer dtypes, float for floating-point dtypes.
if (ascend::IsIntegerDtype(input.dtype())) {
alpha_ = aclCreateScalar(&alpha_int_storage_, ACL_INT64);
} else {
alpha_ = aclCreateScalar(&alpha_float_storage_, ACL_FLOAT);
}
}

~Operator() {
if (!ascend::IsAclRuntimeAlive()) return;

    // Null cached descriptors — see `AclTensorCache::release()`. The
    // descriptors are still referenced by the Repeatable `executor_`, so we
    // skip `aclDestroyTensor` here (and deliberately leak the executor at
    // shutdown) to avoid a double-free; see `64c367c`.
in_cache_.release();
oth_cache_.release();
out_cache_.release();

if (alpha_) aclDestroyScalar(alpha_);
}

void operator()(const Tensor input, const Tensor other,
Tensor out) const override {
auto stream = static_cast<aclrtStream>(stream_);
auto t_in = in_cache_.get(const_cast<void*>(input.data()));
auto t_oth = oth_cache_.get(const_cast<void*>(other.data()));
auto t_out = out_cache_.get(out.data());

if (!executor_) {
aclnnAddGetWorkspaceSize(t_in, t_oth, alpha_, t_out, &ws_size_,
&executor_);
aclSetAclOpExecutorRepeatable(executor_);
} else {
aclSetInputTensorAddr(executor_, 0, t_in,
const_cast<void*>(input.data()));
aclSetInputTensorAddr(executor_, 1, t_oth,
const_cast<void*>(other.data()));
aclSetOutputTensorAddr(executor_, 0, t_out, out.data());
}

auto& arena = ascend::GetWorkspacePool().Ensure(stream, ws_size_);
aclnnAdd(arena.buf, ws_size_, executor_, stream);
}

private:
mutable ascend::AclTensorCache in_cache_;

mutable ascend::AclTensorCache oth_cache_;

mutable ascend::AclTensorCache out_cache_;

mutable aclOpExecutor* executor_ = nullptr;

mutable uint64_t ws_size_ = 0;

// Stable address for `aclCreateScalar` (float).
float alpha_float_storage_ = 1.0f;

// Stable address for `aclCreateScalar` (int).
int64_t alpha_int_storage_ = 1;

aclScalar* alpha_ = nullptr;
};

} // namespace infini::ops

#endif
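The constructor comment above is worth a concrete illustration: because `aclCreateScalar` captures the address rather than copying the value, creating a scalar from a stack local would dangle, which is why the storage lives in class members. A minimal sketch of the pitfall and the fix (hypothetical code, not part of this PR):

// BROKEN: `one` dies when the function returns, but the aclScalar
// created from it still holds its address.
aclScalar* MakeAlphaDangling() {
  float one = 1.0f;
  return aclCreateScalar(&one, ACL_FLOAT);  // dangles after return
}

// OK: mirrors the kernels above; the storage member outlives the scalar.
struct AlphaHolder {
  float storage = 1.0f;  // stable address for the scalar's lifetime
  aclScalar* scalar = aclCreateScalar(&storage, ACL_FLOAT);
  ~AlphaHolder() {
    if (scalar) aclDestroyScalar(scalar);
  }
};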
64 changes: 64 additions & 0 deletions src/ascend/cast/kernel.h
@@ -0,0 +1,64 @@
#ifndef INFINI_OPS_ASCEND_CAST_KERNEL_H_
#define INFINI_OPS_ASCEND_CAST_KERNEL_H_

#include "acl/acl.h"
#include "aclnn/aclnn_base.h"
#include "aclnnop/aclnn_cast.h"
#include "ascend/common.h"
#include "ascend/workspace_pool_.h"
#include "base/cast.h"
#include "operator.h"

namespace infini::ops {

template <>
class Operator<Cast, Device::Type::kAscend> : public Cast {
public:
Operator(const Tensor input, Tensor out)
: Cast(input, out),
in_cache_(input),
out_cache_(out),
acl_out_dtype_(ascend::ToAclDtype(out.dtype())) {}

~Operator() {
if (!ascend::IsAclRuntimeAlive()) return;

// Null cached descriptors — see `AclTensorCache::release()`.
in_cache_.release();
out_cache_.release();
}

void operator()(const Tensor input, Tensor out) const override {
auto stream = static_cast<aclrtStream>(stream_);
auto t_in = in_cache_.get(const_cast<void*>(input.data()));
auto t_out = out_cache_.get(out.data());

if (!executor_) {
aclnnCastGetWorkspaceSize(t_in, acl_out_dtype_, t_out, &ws_size_,
&executor_);
aclSetAclOpExecutorRepeatable(executor_);
} else {
aclSetInputTensorAddr(executor_, 0, t_in,
const_cast<void*>(input.data()));
aclSetOutputTensorAddr(executor_, 0, t_out, out.data());
}

auto& arena = ascend::GetWorkspacePool().Ensure(stream, ws_size_);
aclnnCast(arena.buf, ws_size_, executor_, stream);
}

private:
mutable ascend::AclTensorCache in_cache_;

mutable ascend::AclTensorCache out_cache_;

aclDataType acl_out_dtype_;

mutable aclOpExecutor* executor_ = nullptr;

mutable uint64_t ws_size_ = 0;
};

} // namespace infini::ops

#endif
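The cast kernel is the smallest instance of the launch protocol all four files share. Stripped of the caching helpers, the call sequence is the following outline (the calls are the ones shown above; variable names are placeholders and status checks are omitted):

// First launch: size the workspace, build the executor, mark it reusable.
aclnnCastGetWorkspaceSize(t_in, acl_out_dtype, t_out, &ws_size, &executor);
aclSetAclOpExecutorRepeatable(executor);
aclnnCast(workspace, ws_size, executor, stream);

// Every later launch: reuse the executor and only repoint device addresses.
aclSetInputTensorAddr(executor, /*index=*/0, t_in, new_input_ptr);
aclSetOutputTensorAddr(executor, /*index=*/0, t_out, new_output_ptr);
aclnnCast(workspace, ws_size, executor, stream);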
105 changes: 105 additions & 0 deletions src/ascend/cat/kernel.h
@@ -0,0 +1,105 @@
#ifndef INFINI_OPS_ASCEND_CAT_KERNEL_H_
#define INFINI_OPS_ASCEND_CAT_KERNEL_H_

#include <vector>

#include "acl/acl.h"
#include "aclnn/acl_meta.h"
#include "aclnn/aclnn_base.h"
#include "aclnnop/aclnn_cat.h"
#include "ascend/common.h"
#include "ascend/workspace_pool_.h"
#include "base/cat.h"
#include "operator.h"

namespace infini::ops {

template <>
class Operator<Cat, Device::Type::kAscend> : public Cat {
public:
Operator(const Tensor first_input, std::vector<Tensor> rest_inputs,
int64_t dim, Tensor out)
: Cat(first_input, rest_inputs, dim, out), out_cache_(out) {
// Build `AclTensorCache` for each input tensor.
in_caches_.reserve(input_count_);
in_caches_.emplace_back(first_input);
for (const auto& t : rest_inputs) {
in_caches_.emplace_back(t);
}
}

~Operator() {
if (!ascend::IsAclRuntimeAlive()) return;

    // Null cached descriptors — see `AclTensorCache::release()`. The input
    // descriptors are referenced by the Repeatable `executor_` via
    // `tensor_list_`, so every `in_caches_[i]` must be released alongside
    // `out_cache_`; otherwise `~AclTensorCache()` would double-free them
    // when the vector is destroyed.
for (auto& c : in_caches_) {
c.release();
}
out_cache_.release();

if (tensor_list_) aclDestroyTensorList(tensor_list_);
}

void operator()(const Tensor first_input, std::vector<Tensor> rest_inputs,
int64_t /*dim*/, Tensor out) const override {
auto stream = static_cast<aclrtStream>(stream_);

// Collect all input tensors in order.
std::vector<const Tensor*> inputs;
inputs.reserve(input_count_);
inputs.push_back(&first_input);
for (const auto& t : rest_inputs) {
inputs.push_back(&t);
}

auto t_out = out_cache_.get(out.data());

if (!executor_) {
// First call: create descriptors, tensor list, and executor.
std::vector<aclTensor*> acl_tensors(input_count_);
for (size_t i = 0; i < input_count_; ++i) {
acl_tensors[i] =
in_caches_[i].get(const_cast<void*>(inputs[i]->data()));
}

tensor_list_ =
aclCreateTensorList(const_cast<const aclTensor**>(acl_tensors.data()),
static_cast<uint64_t>(input_count_));

aclnnCatGetWorkspaceSize(tensor_list_, dim_, t_out, &ws_size_,
&executor_);
aclSetAclOpExecutorRepeatable(executor_);
} else {
// Subsequent calls: update data pointers on cached descriptors via
// `aclSetRawTensorAddr`. The executor holds references to the same
// `aclTensor*` objects inside `tensor_list_`, so updating their data
// pointers is sufficient — no `aclSetInputTensorAddr` needed.
for (size_t i = 0; i < input_count_; ++i) {
in_caches_[i].get(const_cast<void*>(inputs[i]->data()));
}
aclSetOutputTensorAddr(executor_, 0, t_out, out.data());
}

auto& arena = ascend::GetWorkspacePool().Ensure(stream, ws_size_);
aclnnCat(arena.buf, ws_size_, executor_, stream);
}

private:
mutable std::vector<ascend::AclTensorCache> in_caches_;

mutable ascend::AclTensorCache out_cache_;

mutable aclTensorList* tensor_list_ = nullptr;

mutable aclOpExecutor* executor_ = nullptr;

mutable uint64_t ws_size_ = 0;
};

} // namespace infini::ops

#endif
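`ascend::AclTensorCache` itself is not part of this diff. Pieced together from its call sites (`get()` repoints a cached descriptor, `release()` nulls it without destroying, and the cat kernel's comment names `aclSetRawTensorAddr`), a plausible sketch looks like the following; the member layout and the `BuildAclTensor`/`aclSetRawTensorAddr` signatures are assumptions, not the repository's implementation, and the transpose-aware constructor used by gemm is omitted:

namespace infini::ascend {

class AclTensorCache {
 public:
  explicit AclTensorCache(const Tensor& t) : meta_(t) {}

  // Build the descriptor on first use, then only repoint its device
  // address in place, so executors holding the descriptor see the update.
  aclTensor* get(void* data) const {
    if (!desc_) {
      desc_ = BuildAclTensor(meta_);   // shape/strides from meta_ (assumed)
    }
    aclSetRawTensorAddr(desc_, data);  // signature assumed
    return desc_;
  }

  // Forget the descriptor without destroying it: used when a Repeatable
  // executor still references it and destroying would double-free.
  void release() { desc_ = nullptr; }

  ~AclTensorCache() {
    if (desc_) aclDestroyTensor(desc_);
  }

 private:
  Tensor meta_;
  mutable aclTensor* desc_ = nullptr;
};

}  // namespace infini::ascend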
75 changes: 50 additions & 25 deletions src/ascend/gemm/kernel.h
@@ -21,50 +21,63 @@ class Operator<Gemm, Device::Type::kAscend> : public Gemm {
       : Gemm(a, b, alpha, beta, trans_a, trans_b, c),
         batched_{batch_count_ > 1},
         alpha_val_{alpha.value_or(1.0f)},
-        beta_val_{beta.value_or(1.0f)} {
+        beta_val_{beta.value_or(1.0f)},
+        self_cache_(c),
+        a_cache_(a, trans_a_),
+        b_cache_(b, trans_b_),
+        out_cache_(c) {
     alpha_scalar_ = aclCreateScalar(&alpha_val_, ACL_FLOAT);
     beta_scalar_ = aclCreateScalar(&beta_val_, ACL_FLOAT);
   }
 
   ~Operator() {
-    aclDestroyScalar(alpha_scalar_);
-    aclDestroyScalar(beta_scalar_);
+    if (!ascend::IsAclRuntimeAlive()) return;
+
+    // Null cached descriptors — see `AclTensorCache::release()`.
+    self_cache_.release();
+    a_cache_.release();
+    b_cache_.release();
+    out_cache_.release();
+
+    if (alpha_scalar_) aclDestroyScalar(alpha_scalar_);
+    if (beta_scalar_) aclDestroyScalar(beta_scalar_);
   }
 
   void operator()(const Tensor a, const Tensor b, std::optional<float> alpha,
                   std::optional<float> beta, std::optional<int> trans_a,
                   std::optional<int> trans_b, Tensor c) const override {
     auto stream = static_cast<aclrtStream>(stream_);
 
-    auto t_self = ascend::BuildAclTensor(c);
-    auto t_a = ascend::BuildAclTensor(a, trans_a_);
-    auto t_b = ascend::BuildAclTensor(b, trans_b_);
-    auto t_out = ascend::BuildAclTensor(c);
-
-    uint64_t ws_needed = 0;
-    aclOpExecutor* executor = nullptr;
-
-    if (batched_) {
-      aclnnBaddbmmGetWorkspaceSize(t_self, t_a, t_b, beta_scalar_,
-                                   alpha_scalar_, t_out, 0, &ws_needed,
-                                   &executor);
+    auto t_self = self_cache_.get(c.data());
+    auto t_a = a_cache_.get(const_cast<void*>(a.data()));
+    auto t_b = b_cache_.get(const_cast<void*>(b.data()));
+    auto t_out = out_cache_.get(c.data());
+
+    if (!executor_) {
+      if (batched_) {
+        aclnnBaddbmmGetWorkspaceSize(t_self, t_a, t_b, beta_scalar_,
+                                     alpha_scalar_, t_out, 0, &ws_size_,
+                                     &executor_);
+      } else {
+        aclnnAddmmGetWorkspaceSize(t_self, t_a, t_b, beta_scalar_,
+                                   alpha_scalar_, t_out, 0, &ws_size_,
+                                   &executor_);
+      }
+      aclSetAclOpExecutorRepeatable(executor_);
     } else {
-      aclnnAddmmGetWorkspaceSize(t_self, t_a, t_b, beta_scalar_, alpha_scalar_,
-                                 t_out, 0, &ws_needed, &executor);
+      aclSetInputTensorAddr(executor_, 0, t_self, c.data());
+      aclSetInputTensorAddr(executor_, 1, t_a, const_cast<void*>(a.data()));
+      aclSetInputTensorAddr(executor_, 2, t_b, const_cast<void*>(b.data()));
+      aclSetOutputTensorAddr(executor_, 0, t_out, c.data());
     }
 
-    auto& arena = ascend::GetWorkspacePool().Ensure(stream, ws_needed);
+    auto& arena = ascend::GetWorkspacePool().Ensure(stream, ws_size_);
 
     if (batched_) {
-      aclnnBaddbmm(arena.buf, ws_needed, executor, stream);
+      aclnnBaddbmm(arena.buf, ws_size_, executor_, stream);
     } else {
-      aclnnAddmm(arena.buf, ws_needed, executor, stream);
+      aclnnAddmm(arena.buf, ws_size_, executor_, stream);
     }
-
-    aclDestroyTensor(t_self);
-    aclDestroyTensor(t_a);
-    aclDestroyTensor(t_b);
-    aclDestroyTensor(t_out);
   }
 
 private:
@@ -77,6 +90,18 @@ class Operator<Gemm, Device::Type::kAscend> : public Gemm {
   aclScalar* alpha_scalar_ = nullptr;
 
   aclScalar* beta_scalar_ = nullptr;
+
+  mutable ascend::AclTensorCache self_cache_;
+
+  mutable ascend::AclTensorCache a_cache_;
+
+  mutable ascend::AclTensorCache b_cache_;
+
+  mutable ascend::AclTensorCache out_cache_;
+
+  mutable aclOpExecutor* executor_ = nullptr;
+
+  mutable uint64_t ws_size_ = 0;
 };
 
 } // namespace infini::ops
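Taken together, the pattern in all four kernels is construct once, launch many times: the first call pays for `GetWorkspaceSize` and executor construction, and every later call takes the cheap path of address updates on cached descriptors. A hypothetical call site (tensor setup and the exact optional arguments are assumed for illustration):

// First launch builds and caches executor_; the second reuses it and
// only repoints the device addresses through the cached descriptors.
infini::ops::Operator<infini::ops::Gemm, Device::Type::kAscend> gemm(
    a, b, /*alpha=*/1.0f, /*beta=*/0.0f, /*trans_a=*/0, /*trans_b=*/0, c);
gemm(a, b, 1.0f, 0.0f, 0, 0, c);
gemm(a2, b2, 1.0f, 0.0f, 0, 0, c2);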