
[feat] Make ATOM work with SGLang out-of-tree#355

Draft
zhuyuhua-v wants to merge 7 commits into ROCm:main from zejunchen-zejun:plugin_for_sgl_oot

Conversation


@zhuyuhua-v zhuyuhua-v commented Mar 19, 2026

Motivation

ATOM currently integrates with SGLang by maintaining a forked version of SGLang that adds --model-impl atom support. This approach requires invasive modifications to SGLang's codebase (6+ places), and imposes a continuous maintenance burden to keep the fork synchronized with upstream SGLang updates.

SGLang PR #13429 introduced the SGLANG_EXTERNAL_MODEL_PACKAGE mechanism, which allows external packages to register custom model implementations without modifying upstream SGLang. This provides a clean, officially supported extension point that perfectly fits ATOM's use case.

By adopting this out-of-tree (OOT) approach, we eliminate the need for an SGLang fork entirely. ATOM users can run optimized models on AMD GPUs using unmodified upstream SGLang with a single environment variable:

export SGLANG_EXTERNAL_MODEL_PACKAGE=atom.plugin.sglang.oot
python3 -m sglang.launch_server --model-path <model_path> ...

Comparison with the fork-based approach

| Dimension | `--model-impl atom` (fork) | `SGLANG_EXTERNAL_MODEL_PACKAGE` (OOT) |
|---|---|---|
| SGLang source | Modified fork, requires continuous sync | Upstream SGLang, zero modifications |
| Model registration | Built-in `ATOMForCausalLM` + `ModelImpl.ATOM` | External package with per-architecture `EntryClass` |
| Supported models | Single wrapper for all architectures | One submodule per model family, precise per-arch registration |
| Maintenance cost | Must track and merge upstream changes | Only maintain ATOM-side code |
| Adding new models | Modify the fork | Add a `.py` file under `atom/plugin/sglang/oot/` |

Design

SGLang's External Model Package Mechanism

SGLang's SGLANG_EXTERNAL_MODEL_PACKAGE (PR #13429) works through an automatic module discovery mechanism in sglang/srt/models/registry.py:

ModelRegistry = _ModelRegistry()
ModelRegistry.register("sglang.srt.models")                     # 1. Register built-in models

if external_pkg := envs.SGLANG_EXTERNAL_MODEL_PACKAGE.get():
    ModelRegistry.register(external_pkg, overwrite=True)         # 2. Register external models (overwrite)

When SGLANG_EXTERNAL_MODEL_PACKAGE=atom.plugin.sglang.oot is set, the registry:

  1. Imports the package (executes __init__.py)
  2. Uses pkgutil.iter_modules to discover all .py submodules
  3. Imports each submodule and collects classes from EntryClass attributes
  4. Registers them by cls.__name__, with overwrite=True to replace built-in implementations
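The four steps above can be sketched in a few lines. This is a simplified, self-contained illustration of the discovery pattern, not SGLang's actual implementation (the real code lives in `sglang/srt/models/registry.py`); it builds a throwaway package on disk so the demo is runnable anywhere:

```python
# Simplified sketch of EntryClass-based discovery: enumerate a package's
# .py submodules with pkgutil, import each, and collect the classes listed
# in its EntryClass attribute, keyed by class name.
import importlib
import pkgutil
import sys
import tempfile
from pathlib import Path

def discover_entry_classes(package_name: str) -> dict:
    """Import every submodule of `package_name` and register the classes
    found in each module's EntryClass attribute, keyed by cls.__name__."""
    package = importlib.import_module(package_name)  # step 1: import package
    registry = {}
    for info in pkgutil.iter_modules(package.__path__):  # step 2: discover
        module = importlib.import_module(f"{package_name}.{info.name}")  # step 3
        entry = getattr(module, "EntryClass", None)
        if entry is None:
            continue
        classes = entry if isinstance(entry, (list, tuple)) else [entry]
        for cls in classes:
            registry[cls.__name__] = cls  # step 4: key is the class name
    return registry

# Demo: a fake external package with one model submodule.
root = Path(tempfile.mkdtemp())
pkg = root / "fake_oot"
pkg.mkdir()
(pkg / "__init__.py").write_text("")
(pkg / "qwen3_moe.py").write_text(
    "class Qwen3MoeForCausalLM:\n    pass\n"
    "EntryClass = [Qwen3MoeForCausalLM]\n"
)
sys.path.insert(0, str(root))

registry = discover_entry_classes("fake_oot")
print(sorted(registry))  # ['Qwen3MoeForCausalLM']
```

Because registration is keyed by `cls.__name__` with `overwrite=True`, an external class with the same name as a built-in one transparently replaces it.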

ATOM OOT Package Structure

atom/plugin/sglang/oot/
├── __init__.py       # Package entry
├── deepseek.py       # DeepseekV2ForCausalLM, DeepseekV3ForCausalLM wrappers
└── qwen3_moe.py      # Qwen3MoeForCausalLM, Qwen3ForCausalLM wrappers

Each submodule defines thin wrapper classes that:

  • Conform to SGLang's model interface (__init__, forward, load_weights signatures)
  • Delegate to ATOM's optimized implementation via atom.prepare_model(config, engine="sglang")
  • Bridge the output format by converting ATOM's hidden_states output to SGLang's LogitsProcessorOutput

Architecture Diagram

[Architecture diagram image]

Execution Flow

[Execution flow diagram image]

Model Wrapper Interface

Each OOT wrapper conforms to SGLang's model interface:

class Qwen3MoeForCausalLM(nn.Module):
    def __init__(self, config, quant_config=None, prefix=""):
        super().__init__()
        # Delegate model creation to ATOM
        self.model = atom.prepare_model(config=config, engine="sglang")
        self.logits_processor = LogitsProcessor(config)

    def forward(self, input_ids, positions, forward_batch, ...) -> LogitsProcessorOutput:
        # ATOM model returns hidden_states; convert to SGLang's expected output
        hidden_states = self.model(input_ids=input_ids, positions=positions,
                                   forward_batch=forward_batch, ...)
        return self.logits_processor(input_ids, hidden_states,
                                     self.model.lm_head, forward_batch)

    def load_weights(self, weights):
        # Delegate to ATOM's weight loading (handles sharding, quantization, etc.)
        self.model.load_weights(weights)

EntryClass = [Qwen3MoeForCausalLM, Qwen3ForCausalLM]

The wrapper class name must match the HF config's architectures field (e.g. "Qwen3MoeForCausalLM"), because SGLang uses cls.__name__ as the registry key.
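This name-based matching can be illustrated with a small sketch. The `resolve_model_class` helper below is hypothetical (SGLang's actual resolution logic differs); it only shows why the wrapper's class name must equal the string in `architectures`:

```python
# Illustrative sketch: a registry keyed by cls.__name__ resolved against
# the HF config.json "architectures" field. Not SGLang's actual code.

class Qwen3MoeForCausalLM:  # stand-in for the OOT wrapper class
    pass

# Registry as built by EntryClass discovery: key is the class name.
registry = {Qwen3MoeForCausalLM.__name__: Qwen3MoeForCausalLM}

def resolve_model_class(hf_config: dict):
    """Return the first registered class matching the config's architectures."""
    for arch in hf_config.get("architectures", []):
        if arch in registry:
            return registry[arch]
    raise ValueError(f"no registered model for {hf_config.get('architectures')}")

# Typical config.json content for a Qwen3 MoE checkpoint:
config = {"architectures": ["Qwen3MoeForCausalLM"]}
print(resolve_model_class(config).__name__)  # Qwen3MoeForCausalLM
```

If the wrapper were named anything else, the lookup would miss and SGLang would fall back to (or fail to find) another implementation.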

Attention

SGLang provides the register_attention_backend decorator for registering custom attention backends. ATOM leverages this mechanism to inject its optimized attention implementation:

@register_attention_backend("aiter")
def create_atom_backend(runner):
    from atom.plugin.attention_backend.sgl_attn_backend import ATOMAttnBackendForSgl
    return ATOMAttnBackendForSgl(runner)

The ATOMAttnBackendForSgl backend provides:

  • Prefill: flash_attn_varlen_func from aiter for variable-length attention
  • Decode: pa_persistent_fwd / pa_fwd_asm for paged attention with optimized AMD GPU kernels
  • KV Cache: Shuffle-based layout via reshape_and_cache_shuffle_triton
  • CUDA Graph: Support for capture and replay in decode phase

This registration uses the name "aiter" to align with SGLang's existing attention backend selection, ensuring the ATOM backend is transparently selected when running on AMD GPUs.

Supported Models

| Model | Architecture Name | OOT Submodule |
|---|---|---|
| DeepSeek-R1 / V3 | `DeepseekV3ForCausalLM` | `deepseek.py` |
| Qwen3 MoE (235B-A22B) | `Qwen3MoeForCausalLM` | `qwen3_moe.py` |

How to Add a New Model

Adding a new model to ATOM's OOT package requires no changes to upstream SGLang and minimal boilerplate:

  1. Add the model implementation to ATOM core (under atom/models/)
  2. Register it in atom/plugin/register.py's _ATOM_SUPPORTED_MODELS
  3. Create a new .py file under atom/plugin/sglang/oot/ with the wrapper class and EntryClass

Taking Qwen3 MoE as an example:

# atom/plugin/sglang/oot/qwen3_moe.py

class Qwen3MoeForCausalLM(nn.Module):
    def __init__(self, config, quant_config=None, prefix=""):
        super().__init__()
        import atom
        self.model = atom.prepare_model(config=config, engine="sglang")
        self.logits_processor = LogitsProcessor(config)

    def forward(self, input_ids, positions, forward_batch, ...) -> LogitsProcessorOutput:
        hidden_states = self.model(input_ids=input_ids, positions=positions,
                                   forward_batch=forward_batch, ...)
        return self.logits_processor(input_ids, hidden_states,
                                     self.model.lm_head, forward_batch)

    def load_weights(self, weights):
        self.model.load_weights(weights)

class Qwen3ForCausalLM(Qwen3MoeForCausalLM):
    pass

EntryClass = [Qwen3MoeForCausalLM, Qwen3ForCausalLM]

No changes to __init__.py, environment variables, or SGLang code are needed. SGLang's pkgutil.iter_modules automatically discovers the new file.
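Since discovery relies purely on `pkgutil` enumeration, a quick sanity check that a newly added file is discoverable is possible before launching the server. The sketch below uses the stdlib `json` package as a stand-in for `atom.plugin.sglang.oot` so it runs anywhere:

```python
# Discoverability check: list the submodules pkgutil would find in a
# package -- the same enumeration SGLang's registry performs. The stdlib
# `json` package stands in for atom.plugin.sglang.oot here.
import importlib
import pkgutil

def list_submodules(package_name: str) -> list:
    """Return the names of all submodules pkgutil can see in a package."""
    package = importlib.import_module(package_name)
    return sorted(m.name for m in pkgutil.iter_modules(package.__path__))

print(list_submodules("json"))
# A new .py file dropped into the package directory appears in this list
# with no other changes -- exactly how a new OOT model file is picked up.
```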

Usage

# Use upstream SGLang + ATOM via external model package
export PYTHONPATH=/path/to/upstream-sglang/python:/path/to/ATOM
export SGLANG_EXTERNAL_MODEL_PACKAGE=atom.plugin.sglang.oot

python3 -m sglang.launch_server \
    --model-path /data/models/Qwen3-235B-A22B-Instruct-2507-FP8 \
    --tensor-parallel-size 8 \
    --expert-parallel-size 8 \
    --kv-cache-dtype fp8_e4m3 \
    --mem-fraction-static 0.8 \
    --page-size 1024 \
    --cuda-graph-max-bs 16

Switching models only requires changing --model-path. SGLang automatically selects the correct ATOM wrapper based on the model's HF config.json architectures field.

Limitations

  • Models not yet supported by ATOM's OOT package will fall back to SGLang's built-in implementations. This fallback works correctly since the OOT package only overrides specific architecture names.

PRs

Signed-off-by: zhuyuhua-v <yuhzhu@amd.com>