Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions examples/qwen3/qwen3_32b_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,16 +243,19 @@ def qwen3_decode_layer(
scores = pl.mul(pl.matmul(q_rot_bf16, k_tile, b_trans=True), ATTN_SCALE)
# TODO(valid_shape): once the compiler propagates valid_shape
# from k_tile, scores will auto-get vs=[1, valid_len] and the
# manual scores_valid view + exp_pad can be removed.
scores_valid = pl.slice(scores, [1, valid_len], [0, 0])
cur_mi = pl.cast(pl.row_max(scores_valid), target_type=pl.FP32)
exp_scores = pl.exp(pl.row_expand_sub(scores_valid, cur_mi))
# manual scores_valid view can be removed.
scores_valid = pl.slice(
scores,
[1, SEQ_TILE],
[0, 0],
valid_shape=[1, valid_len],
)
scores_padded = pl.fillpad(scores_valid, pad_value=pl.PadValue.min)
cur_mi = pl.cast(pl.row_max(scores_padded), target_type=pl.FP32)
exp_scores = pl.exp(pl.row_expand_sub(scores_padded, cur_mi))
cur_li = pl.cast(pl.row_sum(exp_scores), target_type=pl.FP32)
exp_pad = pl.create_tensor([1, SEQ_TILE], dtype=pl.FP32)
exp_pad = pl.mul(exp_pad, 0.0)
exp_pad = pl.assemble(exp_pad, exp_scores, [0, 0])
oi_tmp = pl.matmul(
pl.cast(exp_pad, target_type=pl.BF16),
pl.cast(exp_scores, target_type=pl.BF16),
v_tile,
out_dtype=pl.FP32,
)
Expand Down
Loading