Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions ci/jax.sh
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,9 @@ run_test_config() {
export NVTE_JAX_UNITTEST_LEVEL=L0 # this env variable controls parameters set for some tests
run_default_fa 1 test_custom_call_compute.py
run_default_fa 1 test_functions.py
run 1 test_fused_attn.py
NVTE_CK_USES_FWD_V3=0 NVTE_CK_USES_BWD_V3=0 run_default_fa_lbl "v2" 3 test_fused_attn.py # Using FAv2 for forward and backward pass
run 1 test_fused_attn.py -k 'not test_ck_unfused_smallseq_backend' # skip smallseq in normal flow
XLA_FLAGS='--xla_gpu_graph_level=0' run 1 test_fused_attn.py -k 'test_ck_unfused_smallseq_backend' # CK smallseq with GPU graph disabled
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For our new rocm7.2 image, disabling the XLA cudagraph needs to use

XLA_FLAGS="--xla_gpu_enable_command_buffer="

NVTE_CK_USES_FWD_V3=0 NVTE_CK_USES_BWD_V3=0 run_default_fa_lbl "v2" 3 test_fused_attn.py -k 'not test_ck_unfused_smallseq_backend' # Using FAv2 for forward and backward pass
run_default_fa 1 test_helper.py
run_default_fa 1 test_layer.py # it effectively always uses unfused attention
run_default_fa 1 test_sanity_import.py
Expand Down
168 changes: 148 additions & 20 deletions tests/jax/test_fused_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from functools import partial
from math import sqrt
from typing import Tuple, Optional, Dict
import os
import random

import jax
Expand Down Expand Up @@ -329,7 +330,11 @@ class FusedAttnRunner:
# generating zero-length ragged tensors. This setting adjusts the test to avoid the zero-length cases.
def _get_max_segments_per_sequence(self):
if self.qkv_layout.is_thd():
if 90400 <= get_cudnn_version() < 90500:
if (
90400 <= get_cudnn_version() < 90500
or ( is_hip_extension() and
os.environ.get("NVTE_FUSED_ATTN_CK_SMALLSEQ", "0") == "1")
):
return self.num_segments_per_seq
else:
# +1 for testing runtime_segments < max_segments
Expand Down Expand Up @@ -418,6 +423,55 @@ def _check_configs(self):
"the F16_arbitrary_seqlen backend."
)

def _setup_thd_segments_ck_smallseq(self, generate_random_segment_ids):
    """
    Build THD segment descriptors for the CK small-seq path (NVTE_FUSED_ATTN_CK_SMALLSEQ=1).

    Uses num_segments_per_seq = max_seqlen_q for both Q and KV. For Q: if max_seqlen_q == 1,
    uses a fixed layout (one token per batch, cu_seqlens [0,1,...,batch_size]); otherwise
    generates random segments. For KV: always generates random segments.

    Args:
        generate_random_segment_ids: callable producing
            (segment_ids, segment_pos, segment_pad) arrays for a
            (batch_size, max_seqlen) layout.

    Returns:
        Tuple of (num_segments_per_seq, segment_ids_q, segment_pos_q, pad_q,
        seqlens_q, offsets_q, segment_ids_kv, segment_pos_kv, pad_kv,
        seqlens_kv, offsets_kv).
    """
    num_segments_per_seq = self.max_seqlen_q
    if self.max_seqlen_q == 1:
        # Q: deterministic - one segment of length 1 per batch -> cu_seqlen [0,1,2,...,batch_size]
        # Use same path as q>1 and KV: get_seqlens_and_offsets(segment_ids_q) so offsets follow
        # the same convention (segment start indices, -1 padding). For (batch,1) all-ones,
        # get_seqlens_and_offsets returns offsets [0, -1] per row (correct) but seqlens is wrong
        # because bincount(..., length=1) truncates segment id 1, so we fix seqlens_q only.
        # NOTE(review): confirm whether calling generate_random_segment_ids directly for
        # max_seqlen_q == 1 would also be correct; this deterministic path sidesteps the
        # bincount truncation but duplicates layout logic.
        segment_ids_q = jnp.ones((self.batch_size, self.max_seqlen_q), dtype=jnp.int32)
        segment_pos_q = jnp.zeros((self.batch_size, self.max_seqlen_q), dtype=jnp.int32)
        pad_q = jnp.zeros((self.batch_size, self.max_seqlen_q), dtype=jnp.int32)
        seqlens_q, offsets_q = get_seqlens_and_offsets(segment_ids_q)
        seqlens_q = jnp.ones((self.batch_size, 1), dtype=jnp.int32)  # bincount length=1 quirk
    else:
        segment_ids_q, segment_pos_q, pad_q = generate_random_segment_ids(
            self.batch_size, self.max_seqlen_q, num_segments_per_seq, seed=42
        )
        seqlens_q, offsets_q = get_seqlens_and_offsets(segment_ids_q)

    # Force kv_len >= q_len when sliding-window attention is enabled, matching the
    # constraint applied in the regular THD setup path.
    min_segment_len = None if self.window_size is None else seqlens_q
    segment_ids_kv, segment_pos_kv, pad_kv = generate_random_segment_ids(
        self.batch_size,
        self.max_seqlen_kv,
        num_segments_per_seq,
        seed=2024,
        min_segment_len=min_segment_len,
    )
    seqlens_kv, offsets_kv = get_seqlens_and_offsets(segment_ids_kv)
    return (
        num_segments_per_seq,
        segment_ids_q,
        segment_pos_q,
        pad_q,
        seqlens_q,
        offsets_q,
        segment_ids_kv,
        segment_pos_kv,
        pad_kv,
        seqlens_kv,
        offsets_kv,
    )

def _setup_inputs(self):
self._check_configs()

Expand Down Expand Up @@ -539,27 +593,42 @@ def generate_random_segment_ids(
return segment_ids, segment_pos, segment_pad

if self.qkv_layout.is_thd():
self.num_segments_per_seq = 2
self.segment_ids_q, self.segment_pos_q, self.pad_q = generate_random_segment_ids(
self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42
)
self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(self.segment_ids_q)
# TODO(rewang): record only self attention and find the reason of cross attention
if self.qkv_layout == QKVLayout.T3HD or self.max_seqlen_q == self.max_seqlen_kv:
self.segment_ids_kv = self.segment_ids_q
self.segment_pos_kv = self.segment_pos_q
self.pad_kv = self.pad_q
else:
# Force kv_len >= q_len for swa, otherwise, cuDNN kernels don't support
min_segment_len = None if self.window_size is None else self.seqlens_q
self.segment_ids_kv, self.segment_pos_kv, self.pad_kv = generate_random_segment_ids(
self.batch_size,
self.max_seqlen_kv,
if is_hip_extension() and os.environ.get("NVTE_FUSED_ATTN_CK_SMALLSEQ", "0") == "1":
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe put the small_seq into self.config to replace the checking with ENV?

(
self.num_segments_per_seq,
seed=2024,
min_segment_len=min_segment_len,
self.segment_ids_q,
self.segment_pos_q,
self.pad_q,
self.seqlens_q,
self.offsets_q,
self.segment_ids_kv,
self.segment_pos_kv,
self.pad_kv,
self.seqlens_kv,
self.offsets_kv,
) = self._setup_thd_segments_ck_smallseq(generate_random_segment_ids)
else:
self.num_segments_per_seq = 2
self.segment_ids_q, self.segment_pos_q, self.pad_q = generate_random_segment_ids(
self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42
)
self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets(self.segment_ids_kv)
self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(self.segment_ids_q)
# TODO(rewang): record only self attention and find the reason of cross attention
if self.qkv_layout == QKVLayout.T3HD or self.max_seqlen_q == self.max_seqlen_kv:
self.segment_ids_kv = self.segment_ids_q
self.segment_pos_kv = self.segment_pos_q
self.pad_kv = self.pad_q
else:
# Force kv_len >= q_len for swa, otherwise, cuDNN kernels don't support
min_segment_len = None if self.window_size is None else self.seqlens_q
self.segment_ids_kv, self.segment_pos_kv, self.pad_kv = generate_random_segment_ids(
self.batch_size,
self.max_seqlen_kv,
self.num_segments_per_seq,
seed=2024,
min_segment_len=min_segment_len,
)
self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets(self.segment_ids_kv)
else:
self.num_segments_per_seq = 1
self.segment_ids_q, self.pad_q = gen_valid(
Expand Down Expand Up @@ -1214,3 +1283,62 @@ def test_jax_new_rng():
)
runner = FusedAttnRunner(**kwargs)
runner.test_forward()


# ROCm CK small-seq varlen tests.
@pytest.fixture
def ck_smallseq_env(monkeypatch):
    """Enable CK small-seq path and disable XLA GPU graphs for these tests.

    Skips unless XLA_FLAGS already contains '--xla_gpu_graph_level=0': the flag must
    be set before JAX/XLA initializes, so it cannot be applied from inside the test
    process. NVTE_FUSED_ATTN_CK_SMALLSEQ=1 is set via monkeypatch, which restores
    the previous value after each test.

    NOTE(review): it is unclear whether the newer cudagraph-disabling flag
    (XLA_FLAGS="--xla_gpu_enable_command_buffer=") comes from a ROCm or a JAX
    change; if it is a JAX change, gate on the JAX version instead of the flag
    string.
    """
    if "xla_gpu_graph_level=0" not in os.environ.get("XLA_FLAGS", ""):
        pytest.skip("Test must be run with XLA_FLAGS='--xla_gpu_graph_level=0'")
    monkeypatch.setenv("NVTE_FUSED_ATTN_CK_SMALLSEQ", "1")
    yield

@pytest.mark.parametrize("dtype", [jnp.bfloat16, jnp.float16], ids=["BF16", "FP16"])
@pytest.mark.parametrize(
    "b, s_q, s_kv, h_q, h_kv, d_qk, d_v",
    [
        # Shapes target the CK small-seq window: s_q == 1 and 2 <= s_kv <= 16,
        # with a large batch so the varlen/THD path is exercised at scale.
        pytest.param(4000, 1, 2, 16, 16, 128, 128, id="4000-1-2-16-16-128-128"),
        pytest.param(4000, 1, 4, 16, 16, 128, 128, id="4000-1-4-16-16-128-128"),
        pytest.param(4000, 1, 6, 16, 16, 128, 128, id="4000-1-6-16-16-128-128"),
        pytest.param(4000, 1, 8, 16, 16, 128, 128, id="4000-1-8-16-16-128-128"),
        pytest.param(4000, 1, 12, 16, 16, 128, 128, id="4000-1-12-16-16-128-128"),
        pytest.param(4000, 1, 16, 16, 16, 128, 128, id="4000-1-16-16-16-128-128"),
        # Following tests are hanging with updated kernels, investigating the issue.
        # pytest.param(2048, 2, 4, 16, 16, 128, 128, id="seqpack-2048-2-4-16-16-128-128"),
        # pytest.param(2, 4096, 8192, 16, 16, 128, 128, id="seqpack-2-4096-8192-16-16-128-128"),
    ],
)
@pytest.mark.skipif(
    not is_hip_extension(), reason="CK unfused smallseq backend only available on AMD hardware"
)
def test_ck_unfused_smallseq_backend(
    ck_smallseq_env, b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype
):
    """
    Test the CK unfused small-seq (varlen) path on ROCm: s_q=1, s_kv<=16, THD layout.
    Uses THD_THD_THD (Q,K,V all THD). ck_smallseq_env sets NVTE_FUSED_ATTN_CK_SMALLSEQ=1 and
    restores it after the test.
    """
    runner = FusedAttnRunner(
        batch_size=b,
        max_seqlen_q=s_q,
        max_seqlen_kv=s_kv,
        num_heads_q=h_q,
        num_heads_kv=h_kv,
        head_dim_qk=d_qk,
        head_dim_v=d_v,
        attn_bias_type=AttnBiasType.NO_BIAS,
        attn_mask_type=AttnMaskType.PADDING_MASK,
        dropout_prob=0.0,
        use_old_rng=True,
        dtype=dtype,
        is_training=True,
        qkv_layout=QKVLayout.THD_THD_THD,
        bias_shape=None,
        window_size=None,
        seq_desc_format=SeqDescFormat.Seqlens,
    )
    # NOTE(review): _setup_inputs() is invoked explicitly here — verify that
    # test_backward() does not redo the setup internally.
    runner._setup_inputs()
    # runner.test_forward()  # NOTE(review): forward check disabled with no recorded reason — confirm and re-enable or document why.
    runner.test_backward()
1 change: 1 addition & 0 deletions transformer_engine/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ else()
fused_attn_rocm/fused_attn.cpp
fused_attn_rocm/fused_attn_aotriton.cpp
fused_attn_rocm/fused_attn_ck.cpp
fused_attn_rocm/fused_attn_smallseq.cpp
fused_attn_rocm/utils.cpp
gemm/rocm_gemm.cu
amd_detail/system.cpp)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*************************************************************************
* Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
*
* License for AMD contributions = MIT. See LICENSE for more information
************************************************************************/
Expand Down Expand Up @@ -168,6 +168,12 @@ hipError_t ck_attn_varlen_bwd(
int how_v3_bf16_cvt,
hipStream_t stream);

uint64_t get_runtime_max_seqlen(uint64_t b,
const void* cu_seqlen_ptr,
const void* cu_seqlen_padded_ptr,
void* workspace,
hipStream_t stream);

}//namespace ck_fused_attn
#endif // CK_FUSED_ATTN_H

110 changes: 107 additions & 3 deletions transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*************************************************************************
* Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
*
* License for AMD contributions = MIT. See LICENSE for more information
************************************************************************/
Expand All @@ -9,6 +9,7 @@
#include <numeric> // Required for std::accumulate
#ifdef USE_FUSED_ATTN_CK
#include <ck_fused_attn/ck_fused_attn.hpp>
#include "fused_attn_smallseq.h"
#endif // USE_FUSED_ATTN_CK
#include "../util/cuda_runtime.h"
#include "../util/system.h"
Expand All @@ -18,6 +19,17 @@
namespace transformer_engine {
namespace fused_attn_rocm {

// Builds a map from padded Q token index to the batch that token belongs to:
// every index i in [cu_seqlens_q_padded[b], cu_seqlens_q_padded[b + 1]) maps to b.
// One thread per batch entry; each thread fills its whole (small) segment.
//
// cu_seqlens_q_padded: device array of bs + 1 prefix sums (padded cu_seqlens).
// bs:                  number of batch entries.
// padded_q_to_batch:   output map, carved from uninitialized workspace.
__global__ void build_padded_q_to_batch_kernel(const int* cu_seqlens_q_padded,
                                               int bs,
                                               int* padded_q_to_batch) {
  int b = blockIdx.x * blockDim.x + threadIdx.x;
  if (b >= bs) return;
  int start = cu_seqlens_q_padded[b];
  int end = cu_seqlens_q_padded[b + 1];
  // Fix (from review): previously only padded_q_to_batch[start] was written,
  // leaving the remaining entries of the workspace-backed map uninitialized.
  for (int i = start; i < end; ++i) {
    padded_q_to_batch[i] = b;
  }
}

// check the fused attn config to see whether it's ck backend supported
// single filtering followed by joint filtering
bool is_ck_backend_supported(
Expand Down Expand Up @@ -614,6 +626,52 @@ void fused_attn_ck_fwd_impl(
// denote the next available section of workspace from upstream
void* workspace_next = workspace;

const char* nvte_smallseq = std::getenv("NVTE_FUSED_ATTN_CK_SMALLSEQ");
if (is_ragged && s_q!=s_kv && nvte_smallseq && std::string(nvte_smallseq) == "1") {
void* max_seqlen_workspace = workspace_next;
size_t runtime_max_seqlen_q = static_cast<size_t>(ck_fused_attn::get_runtime_max_seqlen(
static_cast<uint64_t>(b), devPtrCuSeqlensQ, nullptr, max_seqlen_workspace, stream));
size_t runtime_max_seqlen_kv = static_cast<size_t>(ck_fused_attn::get_runtime_max_seqlen(
static_cast<uint64_t>(b), devPtrCuSeqlensKV, nullptr, max_seqlen_workspace, stream));
workspace_next = static_cast<void*>(static_cast<int8_t*>(workspace_next) + sizeof(uint64_t));

if (nvte_log_ck_config) {
std::cout << std::endl << "attn_fwd(ck small-seq): ";
std::cout << "b: " << b << ", ";
std::cout << "runtime_max_seqlen_q: " << runtime_max_seqlen_q << ", ";
std::cout << "runtime_max_seqlen_kv: " << runtime_max_seqlen_kv << ", ";
std::cout << "flow: "
<< (runtime_max_seqlen_q == 1 && runtime_max_seqlen_kv >= 2 &&
runtime_max_seqlen_kv <= 16
? "ck-smallseq"
: "regular ck/aiter")
<< std::endl;
}

if (runtime_max_seqlen_q == 1 && runtime_max_seqlen_kv >= 2 && runtime_max_seqlen_kv <= 16) {
int total_padded_q = static_cast<int>(max_tokens_q);
int* devPtrPaddedQToBatch = static_cast<int*>(workspace_next);
workspace_next = static_cast<void*>(static_cast<int8_t*>(workspace_next) +
total_padded_q * sizeof(int));
constexpr int block = 256;
dim3 grid((b + block - 1) / block);
build_padded_q_to_batch_kernel<<<grid, block, 0, stream>>>(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, now we need extra workspace for devPtrPaddedQToBatch, you will need to edit the jax workspace size function to request it:

pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
, just like what you did in bwd:
const char* nvte_smallseq = std::getenv("NVTE_FUSED_ATTN_CK_SMALLSEQ");
if (is_ragged && q_max_seqlen!=kv_max_seqlen && nvte_smallseq && std::string(nvte_smallseq) == "1") {
size_t workspace_elems = product(work_shape);
size_t elt_size = transformer_engine::typeToSize(query_workspace_tensor.dtype());
size_t workspace_bytes = workspace_elems * elt_size;
size_t unfused_small_seq_workspace = input_batch * attn_heads * 16 * 2; // min for unfused small-seq (bf16/fp16)
if (workspace_bytes < unfused_small_seq_workspace) {
size_t min_elems = (unfused_small_seq_workspace + elt_size - 1) / elt_size;
work_shape = std::vector<size_t>{min_elems};
workspace_elems = min_elems;
workspace_bytes = workspace_elems * elt_size;
}
const char* nvte_log_ck_config = std::getenv("NVTE_LOG_CK_CONFIG");
if (nvte_log_ck_config && std::string(nvte_log_ck_config) == "1") {
std::cout << std::endl << "attn_bwd(ck unfused small-seq workspace size): ";
std::cout << "input_batch: " << input_batch << ", ";
std::cout << "is_ragged: " << is_ragged << ", ";
std::cout << "workspace_elems: " << workspace_elems << ", ";
std::cout << "workspace_bytes: " << workspace_bytes << ", ";
std::cout << "unfused_small_seq_min_bytes: " << unfused_small_seq_workspace << ", ";
std::cout << "workspace_bytes >= unfused_small_seq_workspace: "
<< (workspace_bytes >= unfused_small_seq_workspace ? "true" : "false") << std::endl;
}
}

static_cast<const int*>(devPtrSeqOffsetsQ), static_cast<int>(b), devPtrPaddedQToBatch);
void* smallseq_workspace = workspace_next;

fused_attn_rocm::fused_attn_smallseq_fwd(
b, h, hg, runtime_max_seqlen_kv, d_qk, d_v,
is_training, scaling_factor, dropout_probability,
devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxAux,
devPtrCuSeqlensQ, devPtrSeqOffsetsQ,
total_padded_q, devPtrPaddedQToBatch,
devPtrCuSeqlensKV, devPtrSeqOffsetsKV,
devPtrDropoutSeed, devPtrDropoutOffset,
dtype, smallseq_workspace, workspace_size, stream);
return;
}
}

std::array<uint64_t, 4> q_stride;
std::array<uint64_t, 4> k_stride;
std::array<uint64_t, 4> v_stride;
Expand Down Expand Up @@ -916,6 +974,53 @@ void fused_attn_ck_bwd_impl(
// denote the next available section of workspace from upstream
void* workspace_next = workspace;

const char* nvte_smallseq = std::getenv("NVTE_FUSED_ATTN_CK_SMALLSEQ");
if (is_ragged && s_q!=s_kv && nvte_smallseq && std::string(nvte_smallseq) == "1") {
void* max_seqlen_workspace = workspace_next;
size_t runtime_max_seqlen_q = static_cast<size_t>(ck_fused_attn::get_runtime_max_seqlen(
b, devPtrCuSeqlensQ, nullptr, max_seqlen_workspace, stream));
size_t runtime_max_seqlen_kv = static_cast<size_t>(ck_fused_attn::get_runtime_max_seqlen(
b, devPtrCuSeqlensKV, nullptr, max_seqlen_workspace, stream));
workspace_next = static_cast<void*>(static_cast<int8_t*>(workspace_next) + sizeof(uint64_t));

if (nvte_log_ck_config) {
std::cout << std::endl << "attn_bwd(ck small-seq): ";
std::cout << "b: " << b << ", ";
std::cout << "runtime_max_seqlen_q: " << runtime_max_seqlen_q << ", ";
std::cout << "runtime_max_seqlen_kv: " << runtime_max_seqlen_kv << ", ";
std::cout << "flow: "
<< (runtime_max_seqlen_q == 1 && runtime_max_seqlen_kv >= 2 &&
runtime_max_seqlen_kv <= 16
? "ck-smallseq"
: "regular ck/aiter")
<< std::endl;
}

if (runtime_max_seqlen_q == 1 && runtime_max_seqlen_kv >= 2 && runtime_max_seqlen_kv <= 16) {
int total_padded_q = static_cast<int>(max_tokens_q);
int* devPtrPaddedQToBatch = static_cast<int*>(workspace_next);
workspace_next = static_cast<void*>(static_cast<int8_t*>(workspace_next) +
total_padded_q * sizeof(int));
void* smallseq_workspace = workspace_next;

constexpr int block = 256;
dim3 grid((b + block - 1) / block);
build_padded_q_to_batch_kernel<<<grid, block, 0, stream>>>(
static_cast<const int*>(devPtrSeqOffsetsQ), static_cast<int>(b), devPtrPaddedQToBatch);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here, you will need to request the extra buffer inside

const char* nvte_smallseq = std::getenv("NVTE_FUSED_ATTN_CK_SMALLSEQ");
if (is_ragged && q_max_seqlen!=kv_max_seqlen && nvte_smallseq && std::string(nvte_smallseq) == "1") {
size_t workspace_elems = product(work_shape);
size_t elt_size = transformer_engine::typeToSize(query_workspace_tensor.dtype());
size_t workspace_bytes = workspace_elems * elt_size;
size_t unfused_small_seq_workspace = input_batch * attn_heads * 16 * 2; // min for unfused small-seq (bf16/fp16)
if (workspace_bytes < unfused_small_seq_workspace) {
size_t min_elems = (unfused_small_seq_workspace + elt_size - 1) / elt_size;
work_shape = std::vector<size_t>{min_elems};
workspace_elems = min_elems;
workspace_bytes = workspace_elems * elt_size;
}
const char* nvte_log_ck_config = std::getenv("NVTE_LOG_CK_CONFIG");
if (nvte_log_ck_config && std::string(nvte_log_ck_config) == "1") {
std::cout << std::endl << "attn_bwd(ck unfused small-seq workspace size): ";
std::cout << "input_batch: " << input_batch << ", ";
std::cout << "is_ragged: " << is_ragged << ", ";
std::cout << "workspace_elems: " << workspace_elems << ", ";
std::cout << "workspace_bytes: " << workspace_bytes << ", ";
std::cout << "unfused_small_seq_min_bytes: " << unfused_small_seq_workspace << ", ";
std::cout << "workspace_bytes >= unfused_small_seq_workspace: "
<< (workspace_bytes >= unfused_small_seq_workspace ? "true" : "false") << std::endl;
}
}


fused_attn_rocm::fused_attn_smallseq_bwd(
b, h, hg, runtime_max_seqlen_kv, d_qk, d_v,
scaling_factor, dropout_probability,
devPtrQ, devPtrK, devPtrV, devPtrO, devPtrdO, devPtrSoftmaxAux,
devPtrdQ, devPtrdK, devPtrdV,
devPtrCuSeqlensQ, devPtrSeqOffsetsQ,
total_padded_q, devPtrPaddedQToBatch,
devPtrCuSeqlensKV, devPtrSeqOffsetsKV,
dtype, smallseq_workspace, workspace_size, stream);
return;
}
}

std::array<uint64_t, 4> q_stride;
std::array<uint64_t, 4> k_stride;
std::array<uint64_t, 4> v_stride;
Expand Down Expand Up @@ -1828,7 +1933,7 @@ void fused_attn_ck_fwd(
size_t max_tokens_q = std::accumulate((input_Q->data).shape.begin(), (input_Q->data).shape.end(), static_cast<size_t>(1), std::multiplies<size_t>())/h_q/d_qk;
size_t max_tokens_kv = std::accumulate((input_K->data).shape.begin(), (input_K->data).shape.end(), static_cast<size_t>(1), std::multiplies<size_t>())/h_kv/d_qk;

bool is_ragged = nvte_get_qkv_format(qkv_layout)==NVTE_QKV_Format::NVTE_THD;
bool is_ragged = nvte_get_qkv_format(qkv_layout)==NVTE_QKV_Format::NVTE_THD;
if (Aux_CTX_Tensors->size == 0) {
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
Aux_CTX_Tensors->size = 3;
Expand Down Expand Up @@ -1883,7 +1988,6 @@ void fused_attn_ck_fwd(
bool is_padding = (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK);

fused_attn_ck_fwd_impl(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, bias_b, bias_h,
max_tokens_q, max_tokens_kv,
Expand Down
Loading