
[draft][plugin] sgl radix attn backend #296

Open
ZhiweiYan-96 wants to merge 3 commits into ROCm:main from zejunchen-zejun:guanbao/rebase_sgl_attn_backend

Conversation

@ZhiweiYan-96

No description provided.

self.k_scale = torch.tensor([1.0], dtype=torch.float32)
self.v_scale = torch.tensor([1.0], dtype=torch.float32)

def forward_sgl_plugin_mode(

Contributor

Hi, @ZhiweiYan-96 @zejunchen-zejun

This should be refined: the execution details should be contained inside the radix attention rather than in the model forward.

Contributor

Copilot AI left a comment

Pull request overview

Adds an ATOM-provided SGLang attention backend (using aiter + Triton) and wires it into the SGLang plugin path, including an env-flagged fused RoPE/QK-norm path and updated plugin/config handling.

Changes:

  • Register an ATOM SGLang attention backend under the existing "aiter" backend name.
  • Introduce ATOM_ROPE_FUSED_QKNORM toggle and integrate fused RoPE+QKNorm+KV-cache behavior in the Qwen3-MoE SGL plugin path.
  • Adjust SGLang plugin config parsing and plugin-mode hf-config initialization.

Reviewed changes

Copilot reviewed 7 out of 8 changed files in this pull request and generated 9 comments.

Summary per file:

  • atom/utils/envs.py: Adds ATOM_ROPE_FUSED_QKNORM env flag.
  • atom/plugin/register.py: Switches SGLang backend registration to ATOM's backend implementation.
  • atom/plugin/config.py: Updates SGLang argv parsing before calling prepare_server_args.
  • atom/plugin/attention_backend/sgl_attn_backend.py: New ATOM SGLang attention backend implementation (aiter + Triton).
  • atom/plugin/attention_backend/__init__.py: Package init for attention_backend.
  • atom/models/qwen3_moe.py: Adds fused RoPE+QKNorm+KV-cache path for SGL plugin mode and adjusts rope scaling handling.
  • atom/model_ops/radix_attention.py: Routes SGLang attention call with save_kv_cache controlled by env flag.
  • atom/config.py: Plugin-mode hf-config sourcing and rope-parameter compatibility changes.


Comment on lines +131 to +137
# Format1: sglang serve --model-path ...
# Format2: python3 -m sglang.launch_server --model-path ...
args_list = sys.argv[2:] if sys.argv[1] == "serve" else sys.argv[1:]
# sglang has no global config variable like vllm,
# so here construct the server args from sys.argv passed by users
# this is the only way to get full arguments
server_args: ServerArgs = prepare_server_args(sys.argv[1:])
server_args: ServerArgs = prepare_server_args(args_list)

Copilot AI Mar 10, 2026

args_list = sys.argv[2:] if sys.argv[1] == "serve" else sys.argv[1:] can raise IndexError when _generate_atom_config_from_sglang_config() is called in environments where sys.argv has fewer than 2 elements (e.g., programmatic invocation/tests). Guard len(sys.argv) > 1 before indexing, and fall back to sys.argv[1:] when no subcommand is present.
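
The guard Copilot suggests could look like the following sketch (the helper name `resolve_sglang_args` is hypothetical; in the PR the parsing is inline in the config code):

```python
import sys

def resolve_sglang_args(argv=None):
    """Sketch of the suggested guard: tolerate a short argv
    (programmatic invocation, tests) before indexing into it."""
    argv = sys.argv if argv is None else argv
    # Only strip the "serve" subcommand when it is actually present
    if len(argv) > 1 and argv[1] == "serve":
        return argv[2:]
    return argv[1:]

# Format1: sglang serve --model-path ...
print(resolve_sglang_args(["sglang", "serve", "--model-path", "m"]))  # ['--model-path', 'm']
# Format2: python3 -m sglang.launch_server --model-path ...
print(resolve_sglang_args(["launch_server", "--model-path", "m"]))    # ['--model-path', 'm']
# A bare argv no longer raises IndexError
print(resolve_sglang_args(["prog"]))                                  # []
```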

Comment on lines 642 to 646
# Compatible with both transformers < 5
rope_params = getattr(self.hf_config, "rope_scaling", {})
if rope_params is None:
rope_params = {}
rope_params["rope_theta"] = getattr(self.hf_config, "rope_theta", None)
rope_params["rope_type"] = getattr(self.hf_config, "rope_type", "default")
rope_params = getattr(self.hf_config, "rope_scaling", {}) or {}
rope_params["rope_theta"] = self.hf_config.rope_theta
rope_params["rope_type"] = getattr(rope_params, "rope_type", "default")
self.hf_config.rope_parameters = rope_params

Copilot AI Mar 10, 2026

rope_params["rope_type"] = getattr(rope_params, "rope_type", "default") is incorrect because rope_params is a dict; getattr(...) will always return the default and ignore an existing "rope_type" key. Use rope_params.get("rope_type", "default") (or read from self.hf_config) so non-default rope settings are preserved.
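
A minimal demonstration of why `getattr` on a dict misses existing keys (illustrative values only):

```python
rope_params = {"rope_type": "yarn"}  # a non-default rope setting

# getattr looks up *attributes*, not keys, so the default always wins:
print(getattr(rope_params, "rope_type", "default"))  # default

# dict.get preserves the existing setting, as the fix suggests:
print(rope_params.get("rope_type", "default"))       # yarn
```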

Comment on lines +87 to +88
# forward_batch contains the filed attn_backend, which will find the backend registered in ATOM
return self.attn(

Copilot AI Mar 10, 2026

Typo in comment: filed attn_backend should be field attn_backend.

"ATOM_LLAMA_ENABLE_AITER_TRITON_FUSED_SILU_MUL_QUANT", "1"
)
== "1",
"ATOM_ROPE_FUSED_QKNORM": lambda: os.getenv("AITER_ROPE_FUSED_QKNORM", "0") == "1",

Copilot AI Mar 10, 2026

Environment variable key ATOM_ROPE_FUSED_QKNORM is reading AITER_ROPE_FUSED_QKNORM, which breaks the usual ATOM_* env naming pattern in this module and makes ATOM_ROPE_FUSED_QKNORM=1 ineffective. Consider reading ATOM_ROPE_FUSED_QKNORM (or renaming the key to AITER_ROPE_FUSED_QKNORM) to keep the variable name consistent for users.
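
A sketch of the consistent version the review suggests, where the registered key and the variable actually read share one name (flag table shape assumed from the quoted snippet):

```python
import os

# The registered key and the env var it reads should match so that
# setting ATOM_ROPE_FUSED_QKNORM=1 actually takes effect:
env_flags = {
    "ATOM_ROPE_FUSED_QKNORM": lambda: os.getenv("ATOM_ROPE_FUSED_QKNORM", "0") == "1",
}

os.environ["ATOM_ROPE_FUSED_QKNORM"] = "1"
print(env_flags["ATOM_ROPE_FUSED_QKNORM"]())  # True: the user-facing name now works
```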

Comment on lines +21 to +34
try:
from aiter import (
flash_attn_varlen_func,
dtypes,
get_pa_metadata_info_v1,
get_pa_metadata_v1,
pa_fwd_asm,
pa_persistent_fwd,
)
except ImportError:
print(
"aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
)

Copilot AI Mar 10, 2026

The except ImportError: branch prints a message and continues, but the module later uses flash_attn_varlen_func, dtypes, etc., which will raise at runtime if aiter is missing. Prefer raising an ImportError with a clear message (or using logger.error then re-raising) so failures are immediate and actionable.
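
The re-raise pattern the review suggests can be sketched with a hypothetical helper (demonstrated against a deliberately missing module name):

```python
import importlib

def import_required(module_name):
    """Sketch of the suggested pattern: re-raise with an actionable
    message instead of printing and continuing (hypothetical helper)."""
    try:
        return importlib.import_module(module_name)
    except ImportError as exc:
        raise ImportError(
            f"{module_name} is an AMD-specific kernel library; "
            "please make sure it is installed on your AMD device."
        ) from exc

# Failures are now immediate and carry the original cause:
try:
    import_required("aiter_not_installed_example")  # hypothetical missing dep
except ImportError as exc:
    print(exc.__cause__ is not None)  # True: original ImportError is chained
```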

Comment on lines +62 to +65
slot_id = tl.load(slot_mapping_ptr + tid)
if slot_id < 0:
return
block_id = slot_id // block_size

Copilot AI Mar 10, 2026

This Triton kernel uses if slot_id < 0: return, but slot_id is a tl.tensor from tl.load, so Python control flow on it will not compile. Replace this with masked loads/stores (e.g., compute a mask = slot_id >= 0 and pass it to tl.load/tl.store, and/or set slot_id = tl.maximum(slot_id, 0) under a mask).

Comment on lines +821 to +830
reshape_and_cache_shuffle_triton(
k,
v,
k_buffer,
v_buffer,
cache_loc,
"auto",
k_scale,
v_scale,
)

Copilot AI Mar 10, 2026

reshape_and_cache_shuffle_triton() only enables quant handling when kv_cache_dtype.startswith("fp8"), but callers always pass "auto" here, so the QUANT path is never taken and k_scale/v_scale are ignored. Pass the real kv-cache dtype (e.g., from config/model_runner) or infer it from key_cache.dtype so FP8 kv-cache works correctly.
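
The inference the review suggests could be a small helper keyed off the cache buffer's dtype name (helper name and mapping are hypothetical, matching only the `startswith("fp8")` check described above):

```python
def infer_kv_cache_dtype(buffer_dtype_name: str) -> str:
    """Sketch: derive the kv_cache_dtype string from the kv-cache
    buffer's dtype instead of hard-coding "auto" at the call site."""
    if "fp8" in buffer_dtype_name or "float8" in buffer_dtype_name:
        return "fp8"
    return "auto"

# e.g. str(key_cache.dtype) for an FP8 cache vs. a bf16 cache:
print(infer_kv_cache_dtype("torch.float8_e4m3fnuz"))  # fp8
print(infer_kv_cache_dtype("torch.bfloat16"))         # auto
```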

Comment on lines +233 to +235
self.k_scale = torch.tensor([1.0], dtype=torch.float32)
self.v_scale = torch.tensor([1.0], dtype=torch.float32)

Copilot AI Mar 10, 2026

self.k_scale/self.v_scale are created as CPU tensors in __init__, but are later passed alongside GPU kv-cache buffers in plugin mode. This will cause device mismatch errors when the fused path is enabled. Consider registering them as buffers and initializing/moving them to the same device as the model (or creating them on-demand on qkv.device).
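
The buffer-registration fix could look like this minimal sketch (the class name is hypothetical, not the PR's module):

```python
import torch
import torch.nn as nn

class ScaledAttn(nn.Module):
    """Sketch: register k/v scales as buffers so module.to(device)
    moves them along with the model's parameters."""
    def __init__(self):
        super().__init__()
        # Buffers follow the module across devices, unlike plain tensors
        self.register_buffer("k_scale", torch.tensor([1.0], dtype=torch.float32))
        self.register_buffer("v_scale", torch.tensor([1.0], dtype=torch.float32))

m = ScaledAttn()
# m.to("cuda") would now carry k_scale/v_scale to the GPU,
# avoiding the CPU/GPU mismatch in the fused plugin path
print("k_scale" in dict(m.named_buffers()))  # True
```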

Comment on lines +21 to +37
try:
from aiter import (
flash_attn_varlen_func,
dtypes,
get_pa_metadata_info_v1,
get_pa_metadata_v1,
pa_fwd_asm,
pa_persistent_fwd,
)
except ImportError:
print(
"aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
)

import triton
import triton.language as tl

Copilot AI Mar 10, 2026

Ruff E402: import triton / import triton.language as tl come after a try/except block (and a print(...)), so these imports are not at the top of the module. To avoid CI failures, move the Triton imports above the try/except (or include them in the same initial import section) and avoid executable statements before the import block.

"ATOM_LLAMA_ENABLE_AITER_TRITON_FUSED_SILU_MUL_QUANT", "1"
)
== "1",
"ATOM_ROPE_FUSED_QKNORM": lambda: os.getenv("AITER_ROPE_FUSED_QKNORM", "0") == "1",

Contributor

Is it possible to remove this env flag? IIRC, there is a duplicated one.

"""

logger.info("Generate atom config for plugin mode from passed config")

Contributor

No need to change this, right?


assert 1 <= self.tensor_parallel_size <= 8
self.hf_config = get_hf_config(self.model)
if is_plugin_mode():

Contributor

Should this follow the atom config post-init here?
