diff --git a/examples/qwen3/qwen3_32b_decode.py b/examples/qwen3/qwen3_32b_decode.py index 02727f5..296d682 100644 --- a/examples/qwen3/qwen3_32b_decode.py +++ b/examples/qwen3/qwen3_32b_decode.py @@ -243,16 +243,19 @@ def qwen3_decode_layer( scores = pl.mul(pl.matmul(q_rot_bf16, k_tile, b_trans=True), ATTN_SCALE) # TODO(valid_shape): once the compiler propagates valid_shape # from k_tile, scores will auto-get vs=[1, valid_len] and the - # manual scores_valid view + exp_pad can be removed. - scores_valid = pl.slice(scores, [1, valid_len], [0, 0]) - cur_mi = pl.cast(pl.row_max(scores_valid), target_type=pl.FP32) - exp_scores = pl.exp(pl.row_expand_sub(scores_valid, cur_mi)) + # manual scores_valid view can be removed. + scores_valid = pl.slice( + scores, + [1, SEQ_TILE], + [0, 0], + valid_shape=[1, valid_len], + ) + scores_padded = pl.fillpad(scores_valid, pad_value=pl.PadValue.min) + cur_mi = pl.cast(pl.row_max(scores_padded), target_type=pl.FP32) + exp_scores = pl.exp(pl.row_expand_sub(scores_padded, cur_mi)) cur_li = pl.cast(pl.row_sum(exp_scores), target_type=pl.FP32) - exp_pad = pl.create_tensor([1, SEQ_TILE], dtype=pl.FP32) - exp_pad = pl.mul(exp_pad, 0.0) - exp_pad = pl.assemble(exp_pad, exp_scores, [0, 0]) oi_tmp = pl.matmul( - pl.cast(exp_pad, target_type=pl.BF16), + pl.cast(exp_scores, target_type=pl.BF16), v_tile, out_dtype=pl.FP32, )