[draft][plugin] sgl radix attn backend #296
Conversation
```python
self.k_scale = torch.tensor([1.0], dtype=torch.float32)
self.v_scale = torch.tensor([1.0], dtype=torch.float32)

def forward_sgl_plugin_mode(
```
Hi, @ZhiweiYan-96 @zejunchen-zejun
This should be refined: the execution details should live inside the radix attention backend rather than in the model forward.
Pull request overview
Adds an ATOM-provided SGLang attention backend (using aiter + Triton) and wires it into the SGLang plugin path, including an env-flagged fused RoPE/QK-norm path and updated plugin/config handling.
Changes:
- Register an ATOM SGLang attention backend under the existing `aiter` backend name.
- Introduce an `ATOM_ROPE_FUSED_QKNORM` toggle and integrate fused RoPE+QKNorm+KV-cache behavior in the Qwen3-MoE SGL plugin path.
- Adjust SGLang plugin config parsing and plugin-mode hf-config initialization.
Reviewed changes
Copilot reviewed 7 out of 8 changed files in this pull request and generated 9 comments.
Summary per file:
| File | Description |
|---|---|
| atom/utils/envs.py | Adds ATOM_ROPE_FUSED_QKNORM env flag. |
| atom/plugin/register.py | Switches SGLang backend registration to ATOM’s backend implementation. |
| atom/plugin/config.py | Updates SGLang argv parsing before calling prepare_server_args. |
| atom/plugin/attention_backend/sgl_attn_backend.py | New ATOM SGLang attention backend implementation (aiter + Triton). |
| atom/plugin/attention_backend/`__init__.py` | Package init for attention_backend. |
| atom/models/qwen3_moe.py | Adds fused RoPE+QKNorm+KV-cache path for SGL plugin mode and adjusts rope scaling handling. |
| atom/model_ops/radix_attention.py | Routes SGLang attention call with save_kv_cache controlled by env flag. |
| atom/config.py | Plugin-mode hf-config sourcing and rope-parameter compatibility changes. |
```diff
  # Format1: sglang serve --model-path ...
  # Format2: python3 -m sglang.launch_server --model-path ...
  args_list = sys.argv[2:] if sys.argv[1] == "serve" else sys.argv[1:]
  # sglang has no global config variable like vllm,
  # so here construct the server args from sys.argv passed by users
  # this is the only way to get full arguments
- server_args: ServerArgs = prepare_server_args(sys.argv[1:])
+ server_args: ServerArgs = prepare_server_args(args_list)
```
args_list = sys.argv[2:] if sys.argv[1] == "serve" else sys.argv[1:] can raise IndexError when _generate_atom_config_from_sglang_config() is called in environments where sys.argv has fewer than 2 elements (e.g., programmatic invocation/tests). Guard len(sys.argv) > 1 before indexing, and fall back to sys.argv[1:] when no subcommand is present.
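A guarded version of the parsing could look like the sketch below (the `serve` subcommand convention comes from the diff above; the helper name is illustrative):

```python
import sys

def resolve_sglang_args(argv):
    # Sketch of the suggested guard: only strip a leading "serve"
    # subcommand when argv actually contains one; otherwise fall back
    # to everything after the program name. Avoids IndexError when
    # argv has fewer than 2 elements (programmatic invocation/tests).
    if len(argv) > 1 and argv[1] == "serve":
        return argv[2:]
    return argv[1:]
```

With this, `resolve_sglang_args(["prog"])` returns `[]` instead of raising, and both launch formats still resolve to the same argument list.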
```diff
  # Compatible with both transformers < 5
- rope_params = getattr(self.hf_config, "rope_scaling", {})
- if rope_params is None:
-     rope_params = {}
- rope_params["rope_theta"] = getattr(self.hf_config, "rope_theta", None)
- rope_params["rope_type"] = getattr(self.hf_config, "rope_type", "default")
+ rope_params = getattr(self.hf_config, "rope_scaling", {}) or {}
+ rope_params["rope_theta"] = self.hf_config.rope_theta
+ rope_params["rope_type"] = getattr(rope_params, "rope_type", "default")
  self.hf_config.rope_parameters = rope_params
```
rope_params["rope_type"] = getattr(rope_params, "rope_type", "default") is incorrect because rope_params is a dict; getattr(...) will always return the default and ignore an existing "rope_type" key. Use rope_params.get("rope_type", "default") (or read from self.hf_config) so non-default rope settings are preserved.
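The difference is easy to demonstrate: `getattr` looks up attributes, not keys, so on a plain dict it always falls through to the default:

```python
rope_params = {"rope_type": "yarn"}

# getattr() inspects attributes, and a dict has no "rope_type" attribute,
# so the configured value is silently ignored:
wrong = getattr(rope_params, "rope_type", "default")

# dict.get() reads the key and preserves the configured value:
right = rope_params.get("rope_type", "default")
```

Here `wrong` is `"default"` while `right` is `"yarn"`, which is exactly how a non-default rope setting would get lost.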
```python
# forward_batch contains the filed attn_backend, which will find the backend registered in ATOM
return self.attn(
```
Typo in comment: filed attn_backend should be field attn_backend.
```python
        "ATOM_LLAMA_ENABLE_AITER_TRITON_FUSED_SILU_MUL_QUANT", "1"
    )
    == "1",
    "ATOM_ROPE_FUSED_QKNORM": lambda: os.getenv("AITER_ROPE_FUSED_QKNORM", "0") == "1",
```
Environment variable key ATOM_ROPE_FUSED_QKNORM is reading AITER_ROPE_FUSED_QKNORM, which breaks the usual ATOM_* env naming pattern in this module and makes ATOM_ROPE_FUSED_QKNORM=1 ineffective. Consider reading ATOM_ROPE_FUSED_QKNORM (or renaming the key to AITER_ROPE_FUSED_QKNORM) to keep the variable name consistent for users.
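A minimal sketch of the fix, assuming the ATOM-prefixed name is the intended one (variable names taken from the diff above):

```python
import os

def rope_fused_qknorm_enabled():
    # Sketch of the suggested fix: the dict key and the env var it
    # reads agree, so setting ATOM_ROPE_FUSED_QKNORM=1 takes effect.
    return os.getenv("ATOM_ROPE_FUSED_QKNORM", "0") == "1"
```

With the original lambda, exporting `ATOM_ROPE_FUSED_QKNORM=1` would do nothing, because the `AITER_*` name is the one actually consulted.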
```python
try:
    from aiter import (
        flash_attn_varlen_func,
        dtypes,
        get_pa_metadata_info_v1,
        get_pa_metadata_v1,
        pa_fwd_asm,
        pa_persistent_fwd,
    )
except ImportError:
    print(
        "aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
    )
```
The except ImportError: branch prints a message and continues, but the module later uses flash_attn_varlen_func, dtypes, etc., which will raise at runtime if aiter is missing. Prefer raising an ImportError with a clear message (or using logger.error then re-raising) so failures are immediate and actionable.
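One way to make the failure immediate, sketched with a generic helper so it stays self-contained (the helper name is hypothetical):

```python
import importlib

def import_required(module_name):
    # Fail-fast pattern: chain the original ImportError instead of
    # printing and continuing with undefined names.
    try:
        return importlib.import_module(module_name)
    except ImportError as e:
        raise ImportError(
            f"{module_name} is required but not installed. "
            "aiter is an AMD-specific kernel library; please install it on your AMD device."
        ) from e
```

The chained `from e` keeps the original traceback, so users see both the actionable message and the underlying import failure at startup rather than a `NameError` deep inside a forward pass.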
```python
slot_id = tl.load(slot_mapping_ptr + tid)
if slot_id < 0:
    return
block_id = slot_id // block_size
```
This Triton kernel uses if slot_id < 0: return, but slot_id is a tl.tensor from tl.load, so Python control flow on it will not compile. Replace this with masked loads/stores (e.g., compute a mask = slot_id >= 0 and pass it to tl.load/tl.store, and/or set slot_id = tl.maximum(slot_id, 0) under a mask).
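The intended masked semantics can be modeled outside Triton with NumPy (a behavioral sketch, not the kernel itself): negative slot ids are padding and must leave the cache untouched, which is what `tl.load`/`tl.store` with a `slot_id >= 0` mask achieves.

```python
import numpy as np

def masked_cache_write(cache, values, slot_ids, block_size):
    # mask plays the role of the Triton store mask: padding slots
    # (slot_id < 0) are never written.
    mask = slot_ids >= 0
    safe = np.where(mask, slot_ids, 0)   # clamp so the index math stays in-bounds
    block_ids = safe // block_size
    offsets = safe % block_size
    cache[block_ids[mask], offsets[mask]] = values[mask]
    return cache
```

In the kernel itself the same effect comes from computing `mask = slot_id >= 0`, clamping with `tl.maximum(slot_id, 0)` for the address arithmetic, and passing `mask=mask` to `tl.store`.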
```python
reshape_and_cache_shuffle_triton(
    k,
    v,
    k_buffer,
    v_buffer,
    cache_loc,
    "auto",
    k_scale,
    v_scale,
)
```
reshape_and_cache_shuffle_triton() only enables quant handling when kv_cache_dtype.startswith("fp8"), but callers always pass "auto" here, so the QUANT path is never taken and k_scale/v_scale are ignored. Pass the real kv-cache dtype (e.g., from config/model_runner) or infer it from key_cache.dtype so FP8 kv-cache works correctly.
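If plumbing the configured dtype through is awkward, it could also be inferred from the buffer itself; a hypothetical helper operating on the dtype's string name:

```python
def infer_kv_cache_dtype(buffer_dtype_name: str) -> str:
    # Hypothetical helper: map the kv-cache buffer's dtype name to the
    # string reshape_and_cache_shuffle_triton() switches on, so the
    # fp8 QUANT path is taken when the buffer really is fp8.
    if "float8" in buffer_dtype_name or "fp8" in buffer_dtype_name:
        return "fp8"
    return "auto"
```

The caller would then pass `infer_kv_cache_dtype(str(k_buffer.dtype))` instead of the hard-coded `"auto"`, so `k_scale`/`v_scale` are actually consumed on the fp8 path.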
```python
self.k_scale = torch.tensor([1.0], dtype=torch.float32)
self.v_scale = torch.tensor([1.0], dtype=torch.float32)
```
self.k_scale/self.v_scale are created as CPU tensors in __init__, but are later passed alongside GPU kv-cache buffers in plugin mode. This will cause device mismatch errors when the fused path is enabled. Consider registering them as buffers and initializing/moving them to the same device as the model (or creating them on-demand on qkv.device).
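One fix, sketched with a minimal standalone module (assumes PyTorch; the class name is illustrative): register the scales as buffers so `.to(device)` moves them with the model.

```python
import torch
import torch.nn as nn

class AttnScales(nn.Module):
    # Minimal sketch: buffers follow the module across devices,
    # unlike plain tensor attributes created in __init__.
    def __init__(self):
        super().__init__()
        self.register_buffer("k_scale", torch.tensor([1.0], dtype=torch.float32))
        self.register_buffer("v_scale", torch.tensor([1.0], dtype=torch.float32))
```

After `module.to("cuda")`, `module.k_scale.device` matches the GPU kv-cache buffers; the alternative is creating the scales on demand on `qkv.device` inside the forward.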
```python
try:
    from aiter import (
        flash_attn_varlen_func,
        dtypes,
        get_pa_metadata_info_v1,
        get_pa_metadata_v1,
        pa_fwd_asm,
        pa_persistent_fwd,
    )
except ImportError:
    print(
        "aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
    )

import triton
import triton.language as tl
```
Ruff E402: import triton / import triton.language as tl come after a try/except block (and a print(...)), so these imports are not at the top of the module. To avoid CI failures, move the Triton imports above the try/except (or include them in the same initial import section) and avoid executable statements before the import block.
```python
        "ATOM_LLAMA_ENABLE_AITER_TRITON_FUSED_SILU_MUL_QUANT", "1"
    )
    == "1",
    "ATOM_ROPE_FUSED_QKNORM": lambda: os.getenv("AITER_ROPE_FUSED_QKNORM", "0") == "1",
```
Is it possible that we remove this env flag because IIRC, there is duplicated
```python
    """

    logger.info("Generate atom config for plugin mode from passed config")

    assert 1 <= self.tensor_parallel_size <= 8
    self.hf_config = get_hf_config(self.model)
    if is_plugin_mode():
```
Do we follow the atom config post-init here?