Skip to content

[NO MERGE] Integrate CK varlen cross attention for small-seq (s_q=1, s_kv<=16)#461

Open
VeeraRajasekhar wants to merge 11 commits intodevfrom
veergopu/fused-varlen-ck-smallseq-integration
Open

[NO MERGE] Integrate CK varlen cross attention for small-seq (s_q=1, s_kv<=16)#461
VeeraRajasekhar wants to merge 11 commits intodevfrom
veergopu/fused-varlen-ck-smallseq-integration

Conversation

@VeeraRajasekhar
Copy link
Contributor

Integrate the CK team's unfused variable-length attention HIP kernels from varlen_attn/ into Transformer Engine's ROCm fused-attn path as a specialized path for specialized cross-attention (Q length 1, KV length 2-16, large batch)..

  • Add fused_attn_smallseq.hpp and fused_attn_smallseq.cpp under fused_attn_rocm/: declarations and implementation adapted from varlen_attn/attn_fwd.cpp and attn_bwd.cpp (scores, mask+softmax, output; grad_V, grad_attn, softmax bwd, grad_Q/grad_K). Runtime dispatch over max_seqlen_kv in {2,4,6,8,12,16}, head_dim 128, BF16.

  • Add fused_attn_smallseq.cpp to the ROCm fused-attn build in transformer_engine/common/CMakeLists.txt.

  • In fused_attn_ck_fwd: when THD and no bias, branch to small-seq path when max_seqlen_q==1 and 2<=max_seqlen_kv<=16. On shape query (Aux_CTX_Tensors->size == 0) skip get_runtime_max_seqlen (cu_seqlens pointers are null); use host max_seqlen_kv and set output_S to attention-weights shape {max_tokens_q, h_q, 1, runtime_max_seqlen_kv} and dtype QKV_type. On real run (size >= 2) call get_runtime_max_seqlen then fused_attn_smallseq_fwd. Use sequence count b_varlen = max_tokens_q (not segment count b) for get_runtime_max_seqlen, output_S shape, workspace size, and small-seq fwd so varlen kernel indexing matches Q and cu_seqlens_kv (THD may pass segment-level cu_seqlens; varlen kernel expects sequence-level batch).

  • In fused_attn_ck_bwd: same THD/small-seq condition. On workspace query (workspace->data.dptr == nullptr) skip get_runtime_max_seqlen and use host max_seqlen_kv; on real run call get_runtime_max_seqlen then fused_attn_smallseq_bwd. Use b_varlen = max_tokens_q_bwd for get_runtime_max_seqlen, workspace size, and small-seq bwd.

  • Reuse softmax LSE auxiliary buffer for attention weights in the small-seq path (forward write, backward read);

  • JAX attention.py: in NVTE_CK block, when THD and q_max_seqlen==1 and kv_max_seqlen<=16 set softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, kv_max_seqlen) and softmax_dtype = q_dtype so Python aux buffer matches C++ attention-weights convention.

  • Add test_ck_unfused_smallseq_backend in tests/jax/test_fused_attn.py (parametrized s_kv in {2,4,6,8,12,16}, b=30720, s_q=1, THD_THD_THD, SeqDescFormat.Seqlens) and optional NVTE_LOG_CK_SMALLSEQ debug logging in C++.

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Integrate the CK team's unfused variable-length attention HIP kernels from
varlen_attn/ into Transformer Engine's ROCm fused-attn path as a specialized
path for specialized cross-attention (Q length 1, KV length 2-16, large
batch)..

- Add fused_attn_smallseq.hpp and fused_attn_smallseq.cpp under
  fused_attn_rocm/: declarations and implementation adapted from
  varlen_attn/attn_fwd.cpp and attn_bwd.cpp (scores, mask+softmax, output;
  grad_V, grad_attn, softmax bwd, grad_Q/grad_K). Runtime dispatch over
  max_seqlen_kv in {2,4,6,8,12,16}, head_dim 128, BF16.

- Add fused_attn_smallseq.cpp to the ROCm fused-attn build in
  transformer_engine/common/CMakeLists.txt.

- In fused_attn_ck_fwd: when THD and no bias, branch to small-seq path when
  max_seqlen_q==1 and 2<=max_seqlen_kv<=16. On shape query (Aux_CTX_Tensors->size
  == 0) skip get_runtime_max_seqlen (cu_seqlens pointers are null); use host
  max_seqlen_kv and set output_S to attention-weights shape {max_tokens_q,
  h_q, 1, runtime_max_seqlen_kv} and dtype QKV_type. On real run (size >= 2)
  call get_runtime_max_seqlen then fused_attn_smallseq_fwd. Use sequence
  count b_varlen = max_tokens_q (not segment count b) for get_runtime_max_seqlen,
  output_S shape, workspace size, and small-seq fwd so varlen kernel indexing
  matches Q and cu_seqlens_kv (THD may pass segment-level cu_seqlens; varlen
  kernel expects sequence-level batch).

- In fused_attn_ck_bwd: same THD/small-seq condition. On workspace query
  (workspace->data.dptr == nullptr) skip get_runtime_max_seqlen and use host
  max_seqlen_kv; on real run call get_runtime_max_seqlen then
  fused_attn_smallseq_bwd. Use b_varlen = max_tokens_q_bwd for
  get_runtime_max_seqlen, workspace size, and small-seq bwd.

- Reuse softmax LSE auxiliary buffer for attention weights in the small-seq
  path (forward write, backward read);

- JAX attention.py: in NVTE_CK block, when THD and q_max_seqlen==1 and
  kv_max_seqlen<=16 set softmax_shape = (*batch_shape, attn_heads,
  q_max_seqlen, kv_max_seqlen) and softmax_dtype = q_dtype so Python aux
  buffer matches C++ attention-weights convention.

- Add test_ck_unfused_smallseq_backend in tests/jax/test_fused_attn.py
  (parametrized s_kv in {2,4,6,8,12,16}, b=30720, s_q=1, THD_THD_THD,
  SeqDescFormat.Seqlens) and optional NVTE_LOG_CK_SMALLSEQ debug logging in
  C++.
@wangye805
Copy link
Collaborator

wangye805 commented Feb 25, 2026

Let's make this PR work for jax extension first. Later we can support pytorch.

One key difference btw jax and pytorch fused-attn dispatch is that pytorch can calculate, request, and allocate softmax_aux, workspace during runtime with actual cu_seqlen_q/kv data. However, in jax extension, softmax_aux and workspace calculation is done in

if backend == NVTE_Fused_Attn_Backend.NVTE_AOTriton:
softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, config.max_segments_per_seq)
softmax_dtype = dtypes.canonicalize_dtype(jnp.float32)
elif backend == NVTE_Fused_Attn_Backend.NVTE_CK:
if config.qkv_layout.is_thd():
softmax_shape = (*batch_shape, q_max_seqlen, attn_heads, 1)
else:
softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, 1)
softmax_dtype = dtypes.canonicalize_dtype(jnp.float32)
else:
raise ValueError(f"Unsupported {backend=}")
softmax_aux_aval = q_aval.update(shape=softmax_shape, dtype=softmax_dtype)
and
wkspace_info = transformer_engine_jax.get_fused_attn_fwd_workspace_sizes(
, without knowing actual runtime cu_seqlen_q/kv. Aux tensors are prepared in
void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack, const size_t input_batch,
, also without the knowledge of runtime cu_seqlen_q/kv

General guideline:
1). Pre-allocate large enough softmax_aux and workspace ahead of time. Do not modify the aux preparation function or the c++ level aux workspace calculation/preparation, since we know our softmax aux and workspace size will be large enough for both flow, and the special flow only need a valid start pointer address.
2). During actual kernel dispatch, we do a seqlen_q/kv check, if it satisfy the special cross-attn condition, we launch it here
3). Use an env to guard this new flow and disable it when CP is used

NVTE_CHECK(workspace != nullptr, "small-seq bwd requires workspace.");

float sqr_dk_scale = attn_scale;
hipStream_t hip_stream = reinterpret_cast<hipStream_t>(stream);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably no need for this cast. cudaStream_t will be hipified correctly to hipStream_t

- tests/jax: CK small-seq tests use fixture to set/restore
  NVTE_FUSED_ATTN_CK_SMALLSEQ=1; parametrize dtype (BF16/FP16) and add sequence-packing
  cases (2048-2-4, 2-4096-8192); when env set, num_segments_per_seq =
  max_seqlen_q for THD else 2.
- JAX attention.py: THD softmax shape/dtype uses small-seq path only when
  env=1, else original layout
- JAX attention.cpp: Added env guard
- fused_attn_smallseq: Use TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT for
  fwd/bwd; add FP16 (__half) support; fix __half*float with T(scale).
@VeeraRajasekhar
Copy link
Contributor Author

const T* V_ptr = static_cast<const T*>(devPtrV);
T* O_ptr = static_cast<T*>(devPtrO);
T* attn_workspace = static_cast<T*>(attn_weights_buffer);
const int* cu_kv = static_cast<const int*>(devPtrCuSeqlensKV);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will it run into issues if we don't pass in cu_seqlen_q/cu_seqlen_q_padded?

For example, if there are several empty segments for q/kv but for all non-empty ones, s_q always equal to 1?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Working on integrating new updates from ck team as they have added this support

)
ck_smallseq_softmax_aux_size = (
batch_size * attn_heads * q_max_seqlen
* min(kv_max_seqlen, 16) * 2
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking at the implementation, we only support kv_max_seqlen<=16 right? So should this be checked via an assertion instead of enforced via min?

Copy link
Contributor Author

@VeeraRajasekhar VeeraRajasekhar Mar 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We cannot keep that cause, we care about run_time_max_seq_len, here we don't know the run_time_max_seqlen, for example

pytest.param(2048, 2, 4, 16, 16, 128, 128, id="seqpack-2048-2-4-16-16-128-128"),
pytest.param(2, 4096, 8192, 16, 16, 128, 128, id="seqpack-2-4096-8192-16-16-128-128"),
this test cases s_kv is not 16 but the the num of segments and inputs are chosen in such a way that
size_t runtime_max_seqlen_kv = static_cast<size_t>(ck_fused_attn::get_runtime_max_seqlen(
b, devPtrCuSeqlensKV, nullptr, max_seqlen_workspace, stream));
this returns <=16

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So this test would break for any case where the runtime_max_seqlen_kv is actually >16?

Comment on lines +382 to +387
if ck_standard_softmax_aux_size >= ck_smallseq_softmax_aux_size:
softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, 1)
softmax_dtype = dtypes.canonicalize_dtype(q_dtype)
else:
softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, min(kv_max_seqlen, 16))
softmax_dtype = dtypes.canonicalize_dtype(q_dtype)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if ck_standard_softmax_aux_size >= ck_smallseq_softmax_aux_size:
softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, 1)
softmax_dtype = dtypes.canonicalize_dtype(q_dtype)
else:
softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, min(kv_max_seqlen, 16))
softmax_dtype = dtypes.canonicalize_dtype(q_dtype)
softmax_shape = (*batch_shape, attn_heads, q_max_seqlen)
softmax_dtype = dtypes.canonicalize_dtype(q_dtype)
if ck_standard_softmax_aux_size >= ck_smallseq_softmax_aux_size:
softmax_shape += (1,)
else:
softmax_shape += (min(kv_max_seqlen, 16),)

@VeeraRajasekhar VeeraRajasekhar force-pushed the veergopu/fused-varlen-ck-smallseq-integration branch from b5c5fb7 to c6e0eae Compare March 3, 2026 21:17
)
ck_smallseq_softmax_aux_size = (
batch_size * attn_heads * q_max_seqlen
* min(kv_max_seqlen, 16) * 2
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So this test would break for any case where the runtime_max_seqlen_kv is actually >16?

run 1 test_fused_attn.py
NVTE_CK_USES_FWD_V3=0 NVTE_CK_USES_BWD_V3=0 run_default_fa_lbl "v2" 3 test_fused_attn.py # Using FAv2 for forward and backward pass
run 1 test_fused_attn.py -k 'not test_ck_unfused_smallseq_backend' # skip smallseq in normal flow
XLA_FLAGS='--xla_gpu_graph_level=0' run 1 test_fused_attn.py -k 'test_ck_unfused_smallseq_backend' # CK smallseq with GPU graph disabled
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For our new rocm7.2 image, the xla cudagraph disabling need to use

XLA_FLAGS="--xla_gpu_enable_command_buffer="

generates random segments. For KV: always generates random segments.
"""
num_segments_per_seq = self.max_seqlen_q
if self.max_seqlen_q == 1:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will it run into problems if we call generate_random_segment_ids directly when self.max_seqlen_q==1?

self.segment_ids_kv, self.segment_pos_kv, self.pad_kv = generate_random_segment_ids(
self.batch_size,
self.max_seqlen_kv,
if is_hip_extension() and os.environ.get("NVTE_FUSED_ATTN_CK_SMALLSEQ", "0") == "1":
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe put the small_seq into self.config to replace the checking with ENV?

def ck_smallseq_env(monkeypatch):
"""Enable CK small-seq path and disable XLA GPU graphs for these tests."""
if "xla_gpu_graph_level=0" not in os.environ.get("XLA_FLAGS", ""):
pytest.skip("Test must be run with XLA_FLAGS='--xla_gpu_graph_level=0'")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure the new XLA_FLAG for cudagraph is due to the change of rocm or jax. If it's with the jax change, we can use a jax version check

void* workspace_next = workspace;

const char* nvte_smallseq = std::getenv("NVTE_FUSED_ATTN_CK_SMALLSEQ");
if (is_ragged && nvte_smallseq && std::string(nvte_smallseq) == "1") {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's add another filter s_q !=s_kv here

void* workspace_next = workspace;

const char* nvte_smallseq = std::getenv("NVTE_FUSED_ATTN_CK_SMALLSEQ");
if (is_ragged && nvte_smallseq && std::string(nvte_smallseq) == "1") {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here, let's add s_q!=s_kv here

if config.qkv_layout.is_thd():
softmax_shape = (*batch_shape, q_max_seqlen, attn_heads, 1)
# THD only: check env; run small-seq logic only when enabled
if os.environ.get("NVTE_FUSED_ATTN_CK_SMALLSEQ", "0") != "1":
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also filter it with q_max_seqlen != kv_max_seqlen

) # 2 bytes for bf16/fp16
if ck_standard_softmax_aux_size >= ck_smallseq_softmax_aux_size:
softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, 1)
softmax_dtype = dtypes.canonicalize_dtype(q_dtype)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

softmax_dtype for old ck flow is fp32, I recall?

auto work_shape = MakeShapeVector(query_workspace_tensor.shape());

const char* nvte_smallseq = std::getenv("NVTE_FUSED_ATTN_CK_SMALLSEQ");
if (is_ragged && nvte_smallseq && std::string(nvte_smallseq) == "1") {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add the q_max_seqlen != kv_max_seqlen filter here as well

wangye805 and others added 2 commits March 13, 2026 17:07
…dded_q_to_batch)

CK (fused_attn_ck.cpp):
- Add build_padded_q_to_batch_kernel: from cu_seqlens_q_padded writes
  padded_q_to_batch[slot] = batch_idx for the first Q slot of each batch.
- In smallseq fwd/bwd paths (max_seqlen_q==1, max_seqlen_kv 2..16):
  allocate workspace for padded_q_to_batch, run the kernel, pass
  devPtrCuSeqlensQ, devPtrSeqOffsetsQ, total_padded_q, devPtrPaddedQToBatch
  to smallseq, and use a dedicated smallseq_workspace pointer for the
  smallseq backend.

Smallseq (fused_attn_smallseq.cpp / .h):
- Forward/backward APIs now take Q sequence/offset and packed-Q mapping:
  devPtrCuSeqlensQ, devPtrCuSeqlensQPadded, total_padded_q,
  devPtrPaddedQToBatch (caller builds padded_q_to_batch on device).
- Kernels use packed Q layout: Q/scores indexed by q_storage_offset
  (cu_seqlens_q_padded) and skip batches with actual_seq_q == 0.
- Softmax/grad grids use total_padded_q * head_num * max_seq_kv (total_elt)
  with padded_q_to_batch for batch mapping; backward workspace size
  uses total_padded_q instead of batch count b.
- fused_attn_smallseq_bwd_workspace_size(b,...) -> (total_padded_q,...).

Tests (tests/jax/test_fused_attn.py):
- max_seqlen_q==1: use get_seqlens_and_offsets(segment_ids_q) for
  offsets_q (same convention as q>1), then override seqlens_q to ones
  (bincount length=1 quirk).
- Temporarily disable two seqpack tests that hang with updated kernels:
  seqpack-2048-2-4-16-16-128-128, seqpack-2-4096-8192-16-16-128-128.
@VeeraRajasekhar
Copy link
Contributor Author

@wangye805

I am currently working on

  • Need to investigate why
    # Following tests are hanging with updated kernels, investigating the issue.
    # pytest.param(2048, 2, 4, 16, 16, 128, 128, id="seqpack-2048-2-4-16-16-128-128"),
    # pytest.param(2, 4096, 8192, 16, 16, 128, 128, id="seqpack-2-4096-8192-16-16-128-128"),
    are not working
  • Need to add tests and check if zero seq len tests are working through TE invocation.(In progress)

namespace transformer_engine {
namespace fused_attn_rocm {

__global__ void build_padded_q_to_batch_kernel(const int* cu_seqlens_q_padded,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Based on the name of this kernel, I presume you wanted to get a map from token id including padding to which batch this token belongs to. Probably you need a for-loop for all i in [start, end), assign the padded_q_to_batch[i] to b.

total_padded_q * sizeof(int));
constexpr int block = 256;
dim3 grid((b + block - 1) / block);
build_padded_q_to_batch_kernel<<<grid, block, 0, stream>>>(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, now we need extra workspace for devPtrPaddedQToBatch, you will need to edit the jax workspace size function to request it:

pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
, just like what you did in bwd:
const char* nvte_smallseq = std::getenv("NVTE_FUSED_ATTN_CK_SMALLSEQ");
if (is_ragged && q_max_seqlen!=kv_max_seqlen && nvte_smallseq && std::string(nvte_smallseq) == "1") {
size_t workspace_elems = product(work_shape);
size_t elt_size = transformer_engine::typeToSize(query_workspace_tensor.dtype());
size_t workspace_bytes = workspace_elems * elt_size;
size_t unfused_small_seq_workspace = input_batch * attn_heads * 16 * 2; // min for unfused small-seq (bf16/fp16)
if (workspace_bytes < unfused_small_seq_workspace) {
size_t min_elems = (unfused_small_seq_workspace + elt_size - 1) / elt_size;
work_shape = std::vector<size_t>{min_elems};
workspace_elems = min_elems;
workspace_bytes = workspace_elems * elt_size;
}
const char* nvte_log_ck_config = std::getenv("NVTE_LOG_CK_CONFIG");
if (nvte_log_ck_config && std::string(nvte_log_ck_config) == "1") {
std::cout << std::endl << "attn_bwd(ck unfused small-seq workspace size): ";
std::cout << "input_batch: " << input_batch << ", ";
std::cout << "is_ragged: " << is_ragged << ", ";
std::cout << "workspace_elems: " << workspace_elems << ", ";
std::cout << "workspace_bytes: " << workspace_bytes << ", ";
std::cout << "unfused_small_seq_min_bytes: " << unfused_small_seq_workspace << ", ";
std::cout << "workspace_bytes >= unfused_small_seq_workspace: "
<< (workspace_bytes >= unfused_small_seq_workspace ? "true" : "false") << std::endl;
}
}

constexpr int block = 256;
dim3 grid((b + block - 1) / block);
build_padded_q_to_batch_kernel<<<grid, block, 0, stream>>>(
static_cast<const int*>(devPtrSeqOffsetsQ), static_cast<int>(b), devPtrPaddedQToBatch);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here, you will need to request the extra buffer inside

const char* nvte_smallseq = std::getenv("NVTE_FUSED_ATTN_CK_SMALLSEQ");
if (is_ragged && q_max_seqlen!=kv_max_seqlen && nvte_smallseq && std::string(nvte_smallseq) == "1") {
size_t workspace_elems = product(work_shape);
size_t elt_size = transformer_engine::typeToSize(query_workspace_tensor.dtype());
size_t workspace_bytes = workspace_elems * elt_size;
size_t unfused_small_seq_workspace = input_batch * attn_heads * 16 * 2; // min for unfused small-seq (bf16/fp16)
if (workspace_bytes < unfused_small_seq_workspace) {
size_t min_elems = (unfused_small_seq_workspace + elt_size - 1) / elt_size;
work_shape = std::vector<size_t>{min_elems};
workspace_elems = min_elems;
workspace_bytes = workspace_elems * elt_size;
}
const char* nvte_log_ck_config = std::getenv("NVTE_LOG_CK_CONFIG");
if (nvte_log_ck_config && std::string(nvte_log_ck_config) == "1") {
std::cout << std::endl << "attn_bwd(ck unfused small-seq workspace size): ";
std::cout << "input_batch: " << input_batch << ", ";
std::cout << "is_ragged: " << is_ragged << ", ";
std::cout << "workspace_elems: " << workspace_elems << ", ";
std::cout << "workspace_bytes: " << workspace_bytes << ", ";
std::cout << "unfused_small_seq_min_bytes: " << unfused_small_seq_workspace << ", ";
std::cout << "workspace_bytes >= unfused_small_seq_workspace: "
<< (workspace_bytes >= unfused_small_seq_workspace ? "true" : "false") << std::endl;
}
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants