-
Notifications
You must be signed in to change notification settings - Fork 3
feat: add cambricon causal softmax op #35
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,61 @@ | ||
| #ifndef INFINI_OPS_CAMBRICON_CAUSAL_SOFTMAX_H | ||
| #define INFINI_OPS_CAMBRICON_CAUSAL_SOFTMAX_H | ||
|
|
||
| #include "base/causal_softmax.h" | ||
| #include "cambricon/common.h" | ||
| #include "cambricon/data_type_.h" | ||
|
|
||
| namespace infini::ops { | ||
|
|
||
| // TODO: Remove forward declaration. | ||
| template <typename T> | ||
| void CausalSoftmaxUnion(void *workspace, int core_per_cluster, | ||
| int cluster_count, cnrtQueue_t queue, void *y, | ||
| const void *x, size_t batch_size_, size_t seq_len_, | ||
| size_t total_seq_len_, ptrdiff_t y_stride_b, | ||
| ptrdiff_t y_stride_i, ptrdiff_t y_stride_j, | ||
| ptrdiff_t x_stride_b, ptrdiff_t x_stride_i, | ||
| ptrdiff_t x_stride_j); | ||
|
|
||
| template <> | ||
| class Operator<CausalSoftmax, Device::Type::kCambricon> : public CausalSoftmax { | ||
| public: | ||
| Operator(const Tensor input, Tensor out) : CausalSoftmax{input, out} { | ||
| cnrt_utils::GetLaunchConfig(input.device(), &core_per_cluster, | ||
| &cluster_count); | ||
| } | ||
| void operator()(const Tensor input, Tensor out) const override { | ||
| auto queue = static_cast<cnrtQueue_t>(stream_ ? stream_ : 0); | ||
| auto workspace{workspace_ ? workspace_ : default_workspace_}; | ||
| ptrdiff_t y_stride_b = ndim_ == 3 ? out_strides_[0] : 1; | ||
| ptrdiff_t y_stride_i = ndim_ == 3 ? out_strides_[1] : out_strides_[0]; | ||
| ptrdiff_t y_stride_j = ndim_ == 3 ? out_strides_[2] : out_strides_[1]; | ||
| ptrdiff_t x_stride_b = ndim_ == 3 ? input_strides_[0] : 1; | ||
| ptrdiff_t x_stride_i = ndim_ == 3 ? input_strides_[1] : input_strides_[0]; | ||
| ptrdiff_t x_stride_j = ndim_ == 3 ? input_strides_[2] : input_strides_[1]; | ||
|
|
||
| DispatchFunc< | ||
| List<DataType::kFloat16, DataType::kBFloat16, DataType::kFloat32>>( | ||
| {static_cast<int64_t>(input.dtype())}, | ||
| [&](auto input_tag) { | ||
| using InputT = infini::ops::TypeMapType<Device::Type::kCambricon, ListGet<0>(input_tag)>; | ||
| CausalSoftmaxUnion<InputT>( | ||
| workspace, core_per_cluster, cluster_count, queue, out.data(), | ||
| input.data(), batch_size_, seq_len_, total_seq_len_, y_stride_b, | ||
| y_stride_i, y_stride_j, x_stride_b, x_stride_i, x_stride_j); | ||
| }, | ||
| "CambriconCausalSoftmax::operator() - output dispatch"); | ||
| } | ||
|
|
||
| std::size_t workspace_size_in_bytes() const override { return 0; } | ||
|
|
||
| ~Operator() {} | ||
|
|
||
| void *default_workspace_{nullptr}; | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里需要满足 |
||
| int core_per_cluster = 0; | ||
| int cluster_count = 0; | ||
| }; | ||
|
|
||
| } // namespace infini::ops | ||
|
|
||
| #endif | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 检查参数顺序是否符合 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,164 @@ | ||
| #include "causal_softmax.h" | ||
|
|
||
| __nram__ char nram_buffer[NRAM_MAX_SIZE]; | ||
| const int SRC_MAX_SIZE = NRAM_MAX_SIZE / 4; | ||
|
|
||
| namespace infini::ops { | ||
|
|
||
// Applies one element-wise phase of the softmax to a strided row of
// `num_elements` values, streaming through NRAM in fixed-size chunks:
//   - is_exp_phase == true : output[j] = exp(input[j] - scalar)  (scalar = row max)
//   - is_exp_phase == false: output[j] = output[j] * scalar      (scalar = 1/row sum)
// `stride` is the element stride (in units of T) between consecutive row
// entries in GDRAM. Half/bfloat16 data is widened to float in NRAM so the
// arithmetic is always performed in fp32.
template <typename T>
__mlu_func__ void ProcessSoftmaxStep(const T *input, T *output, float scalar,
                                     int num_elements, int stride,
                                     bool is_exp_phase) {
  constexpr bool is_half = std::is_same_v<T, __half>;
  constexpr bool is_bfloat16 = std::is_same_v<T, __bang_bfloat16>;
  constexpr bool is_float = !is_half && !is_bfloat16;

  // Narrow types need two NRAM regions (fp32 work buffer + T staging buffer),
  // so halve the per-chunk element count to stay within SRC_MAX_SIZE.
  const int chunk_size =
      SRC_MAX_SIZE /
      ((is_half || is_bfloat16) ? (2 * sizeof(float)) : sizeof(float));
  float *float_buffer = (float *)nram_buffer;
  // Staging buffer for the narrow type, placed after the fp32 buffer;
  // unused (nullptr) in the pure-float path.
  T *temp_buffer =
      is_float ? nullptr : (T *)(nram_buffer + chunk_size * sizeof(float));

  // Byte strides between consecutive row elements in GDRAM (same layout for
  // the gather and the scatter).
  const int src_stride = stride * sizeof(T);
  const int dst_stride = stride * sizeof(T);

  int processed = 0;
  while (processed < num_elements) {
    int curr_batch = std::min(chunk_size, num_elements - processed);

    // Gather `curr_batch` strided elements from GDRAM into contiguous NRAM.
    // The exp phase reads from `input`; the normalize phase re-reads the
    // partially written `output`. (segment size, dir, dst stride, src stride,
    // extra-segment count = curr_batch - 1.)
    if constexpr (is_float) {
      __memcpy(
          float_buffer, (is_exp_phase ? input : output) + processed * stride,
          sizeof(float), GDRAM2NRAM, sizeof(float), src_stride, curr_batch - 1);
    } else {
      __memcpy(temp_buffer,
               (is_exp_phase ? input : output) + processed * stride, sizeof(T),
               GDRAM2NRAM, sizeof(T), src_stride, curr_batch - 1);

      // Widen the staged narrow values to fp32 for the arithmetic below.
      if constexpr (is_half) {
        __bang_half2float(float_buffer, reinterpret_cast<half *>(temp_buffer),
                          curr_batch);
      } else if constexpr (is_bfloat16) {
        __bang_bfloat162float(float_buffer, temp_buffer, curr_batch);
      }
    }

    // Common fp32 processing for all element types.
    if (is_exp_phase) {
      __bang_sub_scalar(float_buffer, float_buffer, scalar,
                        curr_batch); // scalar is max_val
      __bang_active_exphp(float_buffer, float_buffer, curr_batch);
    } else {
      __bang_mul_scalar(float_buffer, float_buffer, scalar,
                        curr_batch); // scalar is 1.0f/sum_val
    }

    // Scatter the chunk back to the strided GDRAM row, narrowing first when
    // the element type is half/bfloat16.
    if constexpr (is_float) {
      __memcpy(output + processed * stride, float_buffer, sizeof(float),
               NRAM2GDRAM, dst_stride, sizeof(float), curr_batch - 1);
    } else {
      if constexpr (is_half) {
        __bang_float2half(reinterpret_cast<half *>(temp_buffer), float_buffer,
                          curr_batch);
      } else if constexpr (is_bfloat16) {
        __bang_float2bfloat16(temp_buffer, float_buffer, curr_batch);
      }

      __memcpy(output + processed * stride, temp_buffer, sizeof(T), NRAM2GDRAM,
               dst_stride, sizeof(T), curr_batch - 1);
    }

    processed += curr_batch;
  }
}
|
|
||
// Device kernel: computes causal softmax row-by-row over a (batch, seq,
// total_seq) tensor. Each (batch, i) row is one unit of work; rows are
// divided evenly across the launched tasks. For each row only the first
// `total_seq_len - seq_len + i + 1` positions are attended (causal mask with
// a KV-cache style offset when total_seq_len > seq_len); masked positions
// are written as zero.
template <typename T>
__mlu_global__ void CausalSoftmax(T *y, const T *x, size_t batch_size,
                                  size_t seq_len, size_t total_seq_len,
                                  ptrdiff_t y_stride_b, ptrdiff_t y_stride_i,
                                  ptrdiff_t y_stride_j, ptrdiff_t x_stride_b,
                                  ptrdiff_t x_stride_i, ptrdiff_t x_stride_j) {
  size_t task_id = taskId;
  size_t task_num = taskDimX * taskDimY;

  // Evenly partition the batch*seq rows across tasks (ceiling division);
  // the tail task may receive fewer rows.
  size_t total_tasks = batch_size * seq_len;
  size_t tasks_per_core = (total_tasks + task_num - 1) / task_num;
  size_t start = task_id * tasks_per_core;
  size_t end = std::min(start + tasks_per_core, total_tasks);

  // NRAM scratch handed to the batched reductions: a T staging region
  // followed by an fp32 region.
  const int max_batch = SRC_MAX_SIZE / sizeof(T);
  T *src = (T *)nram_buffer;
  float *dst = (float *)(nram_buffer + max_batch * sizeof(T));

  for (size_t index = start; index < end; index++) {
    // Decode the flat row index into (batch, i) and locate the row bases.
    size_t batch = index / seq_len;
    size_t i = (index % seq_len);
    ptrdiff_t y_offset = batch * y_stride_b + i * y_stride_i;
    ptrdiff_t x_offset = batch * x_stride_b + i * x_stride_i;
    T *y_ = y + y_offset;
    const T *x_ = x + x_offset;

    // Calculate the valid (unmasked) sequence length for this position.
    size_t valid_len = total_seq_len - seq_len + i + 1;

    // Zero out future (masked) positions of the output row.
    for (size_t j = valid_len; j < total_seq_len; j++) {
      y_[j * y_stride_j] = (T)0.0f;
    }

    // Row max for the numerically stable softmax.
    // NOTE(review): MaxBatched/SumBatched receive no j-stride — presumably
    // they assume a contiguous last dimension; confirm behavior for inputs
    // with x_stride_j/y_stride_j != 1.
    float max_val =
        infini::ops::reduce::MaxBatched(x_, src, dst, valid_len, max_batch);

    // Compute `y = exp(x - max)` over the valid prefix.
    ProcessSoftmaxStep(x_, y_, max_val, valid_len, x_stride_j, true);

    // Sum of exponentials over the valid prefix.
    float sum_val =
        infini::ops::reduce::SumBatched(y_, src, dst, valid_len, max_batch);

    // Normalize in place: `y *= 1/sum`.
    ProcessSoftmaxStep(y_, y_, 1.0f / sum_val, valid_len, y_stride_j, false);
  }
}
|
|
||
| template <typename T> | ||
| void CausalSoftmaxUnion(void *workspace, int core_per_cluster, | ||
| int cluster_count, cnrtQueue_t queue, void *y, | ||
| const void *x, size_t batch_size_, size_t seq_len_, | ||
| size_t total_seq_len_, ptrdiff_t y_stride_b, | ||
| ptrdiff_t y_stride_i, ptrdiff_t y_stride_j, | ||
| ptrdiff_t x_stride_b, ptrdiff_t x_stride_i, | ||
| ptrdiff_t x_stride_j) { | ||
| cnrtDim3_t kernel_dim; | ||
| cnrtFunctionType_t kernel_type; | ||
|
|
||
| kernel_dim.x = core_per_cluster; | ||
| kernel_dim.y = cluster_count; | ||
| kernel_dim.z = 1; | ||
| kernel_type = cnrtFuncTypeUnion1; | ||
|
|
||
| CausalSoftmax<T><<<kernel_dim, kernel_type, queue>>>( | ||
| (T *)y, (const T *)x, batch_size_, seq_len_, total_seq_len_, y_stride_b, | ||
| y_stride_i, y_stride_j, x_stride_b, x_stride_i, x_stride_j); | ||
|
|
||
| cnrtQueueSync(queue); | ||
| } | ||
|
|
||
| template void CausalSoftmaxUnion<__half>(void *, int, int, cnrtQueue_t, void *, | ||
| const void *, size_t, size_t, size_t, | ||
| ptrdiff_t, ptrdiff_t, ptrdiff_t, | ||
| ptrdiff_t, ptrdiff_t, ptrdiff_t); | ||
|
|
||
| template void CausalSoftmaxUnion<__bang_bfloat16>( | ||
| void *, int, int, cnrtQueue_t, void *, const void *, size_t, size_t, size_t, | ||
| ptrdiff_t, ptrdiff_t, ptrdiff_t, ptrdiff_t, ptrdiff_t, ptrdiff_t); | ||
|
|
||
| template void CausalSoftmaxUnion<float>(void *, int, int, cnrtQueue_t, void *, | ||
| const void *, size_t, size_t, size_t, | ||
| ptrdiff_t, ptrdiff_t, ptrdiff_t, | ||
| ptrdiff_t, ptrdiff_t, ptrdiff_t); | ||
|
|
||
| } // namespace infini::ops |
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 检查函数参数顺序，可以参考 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里也需要满足要求空行。