[Model Support] Qwen3.5 Support#333

Open
ganyi1996ppo wants to merge 15 commits into main from ganyi/qwen3.5

Conversation

ganyi1996ppo (Contributor) commented Mar 15, 2026

Summary

This PR adds Qwen3.5 model support to ATOM, including both dense and MoE variants. The implementation leverages ATOM's vLLM plugin mode to integrate seamlessly with vLLM's ecosystem while using ATOM's optimized kernels for the linear attention (GatedDeltaNet) layers. It works only in vLLM plugin mode.

Key Features

1. Qwen3.5 Model Support

  • Qwen3.5 Dense: 35B-A3B model with hybrid attention architecture
  • Qwen3.5 MoE: MoE variants with sparse expert routing
  • Architecture: Combination of linear attention (GatedDeltaNet), full attention, and MoE layers
  • Implementation: vLLM plugin mode only (not standalone ATOM mode)

2. Multi-Modal Input Support

  • Full vision-language capability via Qwen3VL integration
  • Supports image + text inputs through vLLM's multimodal pipeline
  • Compatible with OpenAI API format for multimodal requests
  • Image encoding via base64 or URLs

Technical Details

Files Added

Model Configuration:

  • atom/model_config/qwen3_5.py - Qwen3.5 dense config
  • atom/model_config/qwen3_5_moe.py - Qwen3.5 MoE config

Model Implementation:

  • atom/models/qwen3_5.py - Main Qwen3.5 model implementation
    • Qwen3_5GatedDeltaNet - Linear attention layer (optimized with ATOM kernels)
    • Qwen3_5Attention - Full attention layer (uses vLLM native implementation)
    • Qwen3_5ForCausalLM - Dense model wrapper
    • Qwen3_5MoeForCausalLM - MoE model wrapper
    • Qwen3_5ForConditionalGeneration - Multimodal model with vision encoder

Plugin Integration:

  • atom/plugin/vllm/attention_backend/attention_gdn.py - GatedDeltaNet backend for vLLM
  • atom/plugin/vllm/model_wrapper.py - Enhanced wrapper for Qwen3.5 models
  • Weight mapping for Qwen3.5 checkpoint loading

Files Modified

Core Components:

  • atom/config.py - Added Qwen3.5 config support
  • atom/model_loader/loader.py - Enhanced weight loading for Qwen3.5
  • atom/model_engine/model_runner.py - vLLM plugin mode integration
  • atom/models/qwen3_next.py - Refactored for Qwen3.5 compatibility
  • atom/model_ops/base_attention.py - Enhanced attention abstraction
  • atom/model_ops/layernorm.py - Added support for Qwen3.5 norm layers
  • atom/model_ops/linear.py - Weight loading improvements
  • atom/model_ops/moe.py - MoE layer enhancements

Architecture

Qwen3.5 uses a hybrid architecture with:

  • Linear Attention (GatedDeltaNet): Layers 0, 1, 2, 4, 5, 6, 8, 9, 10, 12, 13, 14, 16, 17, 18, 20, 21, 22, 24, 25, 26, 28, 29, 30, 32, 33, 34, 36, 37, 38

    • Uses ATOM's optimized torch.ops.aiter.linear_attention_with_output_base
    • Flash Linear Attention (FLA) kernels
    • Chunk-based gated delta rule
  • Full Attention: Layers 3, 7, 11, 15, 19, 23, 27, 31, 35, 39

    • Uses vLLM's native torch.ops.vllm.unified_attention_with_output
    • Flash Attention via ROCM_AITER_FA backend
    • RoPE positional encoding with Q/K normalization
  • MoE Layers: Integrated into most layers

    • Top-K expert routing
    • Shared expert with gating
    • Fused MoE kernels from AITER
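The schedule above follows a simple rule: full attention on every 4th layer (index ≡ 3 mod 4), GatedDeltaNet everywhere else. A small illustrative sketch, with the layer count and rule inferred from the lists above (not taken from the PR's actual config code):

```python
# Illustrative only: derive the hybrid layer schedule listed above.
# Rule inferred from the layer lists: full attention on every 4th layer,
# GatedDeltaNet linear attention on the rest.
NUM_LAYERS = 40

def layer_kind(idx: int) -> str:
    """Return the attention type used by decoder layer `idx`."""
    return "full_attention" if idx % 4 == 3 else "linear_attention"

full_layers = [i for i in range(NUM_LAYERS) if layer_kind(i) == "full_attention"]
linear_layers = [i for i in range(NUM_LAYERS) if layer_kind(i) == "linear_attention"]
print(full_layers)  # [3, 7, 11, 15, 19, 23, 27, 31, 35, 39]
```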

Usage

Starting Server

# Eager Mode
vllm serve /path/to/Qwen3.5-35B-A3B-FP8 \
  --tensor-parallel-size 2 \
  --gpu_memory_utilization 0.7 \
  --attention-backend ROCM_AITER_FA \
  --enforce-eager \
  --port 8000

# With CUDA graphs (production)
vllm serve /path/to/Qwen3.5-35B-A3B-FP8 \
  --tensor-parallel-size 2 \
  --gpu_memory_utilization 0.7 \
  --attention-backend ROCM_AITER_FA \
  --compilation-config '{"cudagraph_mode": "FULL_AND_PIECEWISE"}' \
  --port 8000

Text-Only Inference

curl -X POST "http://localhost:8000/v1/completions" \
  -H "Content-Type: application/json" \
  -d '{
    "model": "/path/to/Qwen3.5-35B-A3B-FP8",
    "prompt": "What is 2+2?",
    "max_tokens": 256,
    "temperature": 0
  }'

Multimodal (Image + Text) Inference

# Encode image to base64
IMAGE_BASE64=$(base64 -w 0 /path/to/image.jpg)

curl -X POST "http://localhost:8000/v1/chat/completions" \
  -H "Content-Type: application/json" \
  -d '{
    "model": "/path/to/Qwen3.5-35B-A3B-FP8",
    "messages": [{
      "role": "user",
      "content": [
        {"type": "text", "text": "Describe this image"},
        {"type": "image_url", "image_url": {
          "url": "data:image/jpeg;base64,'"$IMAGE_BASE64"'"
        }}
      ]
    }],
    "max_tokens": 256,
    "temperature": 0
  }'
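The same multimodal request can be assembled in Python; a minimal sketch (the `build_image_message` helper is illustrative, not part of ATOM or vLLM):

```python
# Sketch: build the OpenAI-format multimodal chat payload from the curl
# example above in Python. `build_image_message` is an illustrative helper.
import base64

def build_image_message(image_bytes: bytes, prompt: str,
                        mime: str = "image/jpeg") -> dict:
    """One user message containing a text part and a base64 data-URL image."""
    b64 = base64.b64encode(image_bytes).decode("ascii")
    return {
        "role": "user",
        "content": [
            {"type": "text", "text": prompt},
            {"type": "image_url",
             "image_url": {"url": f"data:{mime};base64,{b64}"}},
        ],
    }

payload = {
    "model": "/path/to/Qwen3.5-35B-A3B-FP8",
    "messages": [build_image_message(b"...image bytes...", "Describe this image")],
    "max_tokens": 256,
    "temperature": 0,
}
```

POST this dict as JSON to /v1/chat/completions, exactly as in the curl example.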

Environment Variables

Required:

export VLLM_USE_V1=1
export VLLM_ROCM_USE_AITER=1
export VLLM_ROCM_USE_AITER_MHA=1
# The current attention implementation is problematic on hybrid-attention models (mixed full and linear attention), so the attention fallback is still required for hybrid models
export ATOM_DISABLE_VLLM_PLUGIN_ATTENTION=1  

Testing

Validated on:

  • AMD Instinct MI300X GPUs
  • ROCm 7.0.2+
  • vLLM v1 architecture
  • Models: Qwen3.5-35B-A3B-FP8

Text Only

gsm8k results on Qwen/Qwen3.5-35B-A3B-FP8

# ATOM plugin with vLLM
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.7521|±  |0.0119|
|     |       |strict-match    |     5|exact_match|↑  |0.7362|±  |0.0121|

# vLLM
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.7582|±  |0.0118|
|     |       |strict-match    |     5|exact_match|↑  |0.7445|±  |0.0120|

Image

Request

(test image: dog.png)

PORT=${1:-8000}

# Convert image to base64
IMAGE_BASE64=$(base64 -w 0 /home/ygan/dog.png)

curl -X POST "http://localhost:${PORT}/v1/chat/completions" \
  -H "Content-Type: application/json" \
  -d '{
    "model": "/mnt/data/pretrained_model/Qwen/Qwen3.5-35B-A3B-FP8",
    "messages": [
      {
        "role": "user",
        "content": [
          {
            "type": "text",
            "text": "What is in this image? Describe it in detail."
          },
          {
            "type": "image_url",
            "image_url": {
              "url": "data:image/png;base64,'"$IMAGE_BASE64"'"
            }
          }
        ]
      }
    ],
    "max_tokens": 256,
    "temperature": 0,
    "top_p": 1,
    "seed": 123,
    "stream": false
  }' | python3 -m json.tool

Result:

{
    "id": "chatcmpl-8191fab14490fc74",
    "object": "chat.completion",
    "created": 1773631986,
    "model": "/mnt/data/pretrained_model/Qwen/Qwen3.5-35B-A3B-FP8",
    "choices": [
        {
            "index": 0,
            "message": {
                "role": "assistant",
                "content": "The user wants a detailed description of the image provided.\n\n1.  **Subject Identification:** The main subject is a dog. It looks like a Golden Retriever or a similar mix (maybe a Nova Scotia Duck Tolling Retriever or a mix with a spaniel). It has golden-brown fur.\n2.  **Composition:** It's a close-up portrait, focusing on the head and upper chest. The background is blurred (bokeh), suggesting a shallow depth of field.\n3.  **Physical Features - Head:**\n    *   **Ears:** Floppy, medium-sized, covered in slightly longer, feathery fur. They are set high on the head.\n    *   **Eyes:** Large, dark brown, expressive. They are looking slightly upward and to the left (viewer's left). There are catchlights (reflections) in the eyes, indicating a light source.\n    *   **Forehead:** Smooth, with a slight stop (indentation) between the eyes. The fur is short and sleek here.\n    *   **Nose:** Black, wet-looking, prominent. The nostrils are clearly visible.\n    *   **Muzzle:** Tapered but sturdy. There",
                "refusal": null,
                "annotations": null,
                "audio": null,
                "function_call": null,
                "tool_calls": [],
                "reasoning": null
            },
            "logprobs": null,
            "finish_reason": "length",
            "stop_reason": null,
            "token_ids": null
        }
    ],
    "service_tier": null,
    "system_fingerprint": null,
    "usage": {
        "prompt_tokens": 1048,
        "total_tokens": 1304,
        "completion_tokens": 256,
        "prompt_tokens_details": null
    },
    "prompt_logprobs": null,
    "prompt_token_ids": null,
    "kv_transfer_params": null
}

Implementation Notes

  1. vLLM Plugin Mode Only: This implementation requires vLLM plugin mode. Standalone ATOM mode is not supported for Qwen3.5.

  2. Hybrid Attention: GatedDeltaNet layers use ATOM's optimized kernels, while full attention uses vLLM's native implementation for maximum compatibility.

  3. Multimodal: Inherits from Qwen3VLForConditionalGeneration for vision-language support.

Breaking Changes

None - this is a new model addition.

Related Issues

  • Adds support for Qwen3.5 model family
  • Complements existing Qwen3-Next support
  • Part of ongoing effort to support latest Qwen models

Submission Checklist

@ganyi1996ppo ganyi1996ppo marked this pull request as ready for review March 16, 2026 03:06
Copilot AI review requested due to automatic review settings March 16, 2026 03:06

Copilot AI left a comment


Pull request overview

This PR adds Qwen3.5 model support (dense and MoE variants) to the ATOM framework, targeting vLLM plugin mode. It introduces model configurations, model implementations with hybrid attention (GatedDeltaNet linear attention + full attention), multimodal support via Qwen3VL integration, and a new GDN attention backend for vLLM.

Changes:

  • Added Qwen3.5 dense and MoE model configs, model implementations, and conditional generation wrappers with multimodal (vision-language) support
  • Added a new GatedDeltaNet attention backend for vLLM plugin mode and refactored existing attention/loader code for Qwen3.5 compatibility
  • Extended weight loading, quantization config, and MoE layers to handle Qwen3.5's separate projection weights and fused expert patterns

Reviewed changes

Copilot reviewed 22 out of 23 changed files in this pull request and generated 14 comments.

Show a summary per file
File Description
atom/model_config/qwen3_5.py New: Qwen3.5 dense model configuration
atom/model_config/qwen3_5_moe.py New: Qwen3.5 MoE model configuration
atom/models/qwen3_5.py New: Main Qwen3.5 model implementation (GDN, attention, CausalLM, ConditionalGeneration)
atom/models/interfaces.py New: Multimodal protocol interfaces for model support
atom/models/utils.py Added utility classes (StageMissingLayer, collect_children, no_init_weights, common_prefix)
atom/models/qwen3_next.py Refactored for Qwen3.5 compatibility: changed config access to .text_config, added vLLM attention path
atom/plugin/vllm/attention_backend/attention_gdn.py New: GatedDeltaNet attention backend for vLLM
atom/plugin/vllm/attention_backend/gdn_attn.py New: GDN attention backend wrapper
atom/plugin/vllm/model_wrapper.py Added ATOMForConditionalGeneration multimodal wrapper
atom/plugin/vllm/register.py Registered Qwen3.5 model architectures
atom/plugin/vllm/platform.py Updated get_attn_backend_cls signature for vLLM compatibility
atom/plugin/config.py Added vllm_config field to PluginConfig
atom/plugin/attention.py Changed supported kernel block sizes from [16, 32] to [16]
atom/config.py Enhanced quant config for packed modules; fallback config loading via vLLM; missing return bug
atom/model_loader/loader.py Added WeightsMapper for weight name remapping; plugin mode weight loading changes
atom/model_engine/model_runner.py Added Qwen3.5 to model registry; changed hf_config to text_config
atom/model_ops/base_attention.py Updated linear attention forward to pass layer_name
atom/model_ops/attention_gdn.py Added vLLM forward context support (with is_vllm bug)
atom/model_ops/layernorm.py Switched RMSNormGated to vLLM's rmsnorm_fn (breaks non-vllm)
atom/model_ops/linear.py Extended weight_loader for tuple shard_ids; new QKVZBAParallelLinear shard types
atom/model_ops/moe.py Added w13 shard support; parameterized gate_up_proj name
atom/utils/selector.py Added vLLM-specific GDN attention backend selection


Comment on lines +376 to +390
# if torch.compiler.is_compiling():
# return self.forward_native(x, z)
# return self.forward_native(x, z)

from vllm.model_executor.layers.fla.ops.layernorm_guard import rmsnorm_fn

return rmsnorm_fn(
x,
self.weight,
self.bias,
z=z,
eps=self.eps,
group_size=self.group_size,
norm_before_gate=self.norm_before_gate,
)
Copilot AI review requested due to automatic review settings March 16, 2026 09:04
Signed-off-by: ganyi <ygan@amd.com>

Copilot AI left a comment


Pull request overview

This PR adds Qwen3.5 model support (dense and MoE variants) to the ATOM framework, leveraging vLLM's plugin mode for multimodal inference. It also refactors shared components (layernorm fusion, weight loading, attention backends) and improves the CI benchmark infrastructure.

Changes:

  • Adds Qwen3.5 dense and MoE model implementations with hybrid attention (GatedDeltaNet linear + full attention), multimodal support via Qwen3VL integration, and vLLM plugin registration
  • Refactors RMSNorm quantization fusion into reusable DualRMSNorm and fuse_rmsnorm_group_quant utilities in layernorm.py, removing duplicate code from deepseek_v2.py
  • Restructures CI benchmark workflow to use external models.json config and adds a dedicated profiler analysis job with a new regression_rerun.py script

Reviewed changes

Copilot reviewed 24 out of 25 changed files in this pull request and generated 8 comments.

Show a summary per file
File Description
atom/models/qwen3_5.py New Qwen3.5 model (dense, MoE, multimodal) implementation
atom/models/qwen3_next.py Refactored to support Qwen3.5 via text_config, method extraction
atom/model_config/qwen3_5.py Qwen3.5 dense model configuration
atom/model_config/qwen3_5_moe.py Qwen3.5 MoE model configuration
atom/model_ops/layernorm.py DualRMSNorm, fuse_rmsnorm_group_quant, RMSNorm fused quant refactor
atom/model_ops/linear.py Tuple shard loading, new QKVZBAParallelLinear shard IDs, import change
atom/model_ops/moe.py w13 shard support, make_expert_params_mapping extension
atom/models/deepseek_v2.py Migrated fusion code to layernorm.py, uses DualRMSNorm
atom/config.py QuantizationConfig updates, vLLM config fallback, env var generalization
atom/model_loader/loader.py WeightsMapper class, text_config-aware loading
atom/model_engine/model_runner.py Qwen3.5 model registration, text_config loading
atom/plugin/vllm/model_wrapper.py ATOMForConditionalGeneration multimodal wrapper
atom/plugin/vllm/register.py Qwen3.5 model registration
atom/plugin/vllm/attention_backend/attention_gdn.py GatedDeltaNet vLLM attention backend
atom/plugin/vllm/attention_backend/gdn_attn.py GDN backend class for vLLM plugin
atom/plugin/config.py Added vllm_config field
atom/plugin/attention.py Block size change (16 only)
atom/utils/selector.py vLLM-aware GDN backend selection
atom/utils/envs.py New master RMSNORM_QUANT_FUSION switch, custom all-gather
atom/models/utils.py StageMissingLayer, collect_children, no_init_weights utilities
atom/models/interfaces.py SupportsMultiModal protocol for ATOM models
atom/model_ops/base_attention.py layer_name passthrough, prefix storage
atom/model_ops/attention_gdn.py GatedDeltaNet rename fix, layer_name param
atom/model_ops/attentions/gdn_attn.py GatedDeltaNet rename fix
atom/model_ops/embed_head.py Configurable custom all-gather
docs/model_ops_guide.md DualRMSNorm documentation
docs/environment_variables.md Updated env var docs
.github/workflows/atom-benchmark.yaml Refactored CI: models.json, profiler analysis job
.github/benchmark/models.json Externalized model configs
.github/scripts/regression_rerun.py New regression re-run config generator
.github/scripts/summarize.py model_id in regression report
.github/scripts/atom_test.sh Cache clearing, stop command, result filename
Comments suppressed due to low confidence (1)

atom/model_ops/layernorm.py:708

  • Bug: forward_cuda now unconditionally imports from vllm.model_executor.layers.fla.ops.layernorm_guard, which will fail with ImportError in standalone ATOM mode (non-vLLM). The previous code correctly fell back to self.forward_native(x, z). Since RMSNormGated is a general utility in model_ops/layernorm.py (not specific to vLLM plugin mode), this import should be guarded, e.g. with a try/except that falls back to the native implementation, or a check for vLLM availability.
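The guard the comment asks for could look roughly like this (a sketch of the pattern only; `gated_rmsnorm`, `native_fn`, and `fused_fn` are invented for illustration and are not the actual RMSNormGated code):

```python
# Sketch of the guarded-import pattern suggested above: try the vLLM fused
# kernel, and fall back to the native path when vLLM is not installed.
try:
    from vllm.model_executor.layers.fla.ops.layernorm_guard import rmsnorm_fn
except ImportError:  # standalone ATOM mode without vLLM
    rmsnorm_fn = None

def gated_rmsnorm(x, native_fn, fused_fn=rmsnorm_fn, **kwargs):
    """Dispatch to the fused vLLM kernel when available, else the native path."""
    if fused_fn is None:
        return native_fn(x)  # equivalent of self.forward_native(x, z)
    return fused_fn(x, **kwargs)
```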


Comment on lines +436 to +437
self.config.n_shared_experts = 1
self.config.n_routed_experts = self.config.num_experts
@staticmethod
def get_supported_kernel_block_sizes():
return [16, 32]
return [16]
Contributor Author


gluon pa does not support block-size 32, so I removed it

Contributor


Can we add a comment here explaining why only block size 16 is supported?

self.model = model_class(config)
torch.set_default_device(None)
load_model(self.model, config.model, config.hf_config, config.load_dummy)
if hasattr(config.hc_config, "text_config"):
Contributor Author

This PR requires ROCm/aiter#2292 to be merged into aiter.

Signed-off-by: ganyi <ygan@amd.com>
Signed-off-by: ganyi <ygan@amd.com>
# return self.forward_native(x, z)
# return self.forward_native(x, z)

from vllm.model_executor.layers.fla.ops.layernorm_guard import rmsnorm_fn
Collaborator


emmm

Contributor Author


I'll revert this.



@triton.jit
def shard_qkvzba_kernel(
Collaborator


move to model ops or aiter..

########################################################

# ConditionalGeneration model scope should only works on plugin mode
if is_vllm():
Collaborator


emmm

Contributor Author


There are many required multimodal modules missing in ATOM, so Qwen3.5 is only supported in vLLM plugin mode for now. We will extend it to ATOM native support in the future.

Collaborator

ChuanLi1101 left a comment


Review focusing on critical and potential bug issues in the newly added code.

self.model = model_class(config)
torch.set_default_device(None)
load_model(self.model, config.model, config.hf_config, config.load_dummy)
if hasattr(config.hc_config, "text_config"):
Collaborator


[Critical] Typo: config.hc_config should be config.hf_config

config (which is Config) has no attribute hc_config. This will raise AttributeError at runtime for any model with a text_config sub-config (including Qwen3.5). Should be:

if hasattr(config.hf_config, "text_config"):


self.config = config
self.config.n_shared_experts = 1
self.config.n_routed_experts = self.config.num_experts
Collaborator


[Critical] self.config.num_experts will crash for dense (non-MoE) Qwen3.5 models

Qwen3_5Model is shared by both Qwen3_5ForCausalLM (dense) and Qwen3_5MoeForCausalLM (MoE). For dense models, the config is Qwen3_5TextConfig which does not define num_experts, so this line will raise AttributeError.

Suggested fix — guard with a conditional or use getattr:

self.config.n_shared_experts = getattr(self.config, "n_shared_experts", 1)
self.config.n_routed_experts = getattr(self.config, "num_experts", 0)

Or split Qwen3_5Model to avoid unconditionally setting MoE-specific attributes on a dense config.

gemm_a8w8_blockscale_bpreshuffle_triton = None

# For Triton FP8 Blockscale GEMM is mostly slower then AITER GEMM, we turn off Triton FP8 GEMM
from aiter.ops.triton.gemm.basic.gemm_a8w8_blockscale import (
Collaborator


[Potential Bug] Unprotected import lost its fallback guard

Previously gemm_a8w8_blockscale_bpreshuffle_triton was set to None as a safe fallback. Now this import is outside the try-except block — if aiter.ops.triton.gemm.basic.gemm_a8w8_blockscale is not available (e.g., older aiter version), it will crash module import entirely.

Consider wrapping this in its own try-except with a None fallback:

try:
    from aiter.ops.triton.gemm.basic.gemm_a8w8_blockscale import (
        gemm_a8w8_blockscale_preshuffle as gemm_a8w8_blockscale_bpreshuffle_triton,
    )
except ImportError:
    gemm_a8w8_blockscale_bpreshuffle_triton = None

Contributor Author


done

Signed-off-by: ganyi <ygan@amd.com>
Copilot AI review requested due to automatic review settings March 17, 2026 02:03

Copilot AI left a comment


Pull request overview

This PR adds Qwen3.5 model support (dense and MoE variants) to ATOM, leveraging vLLM's plugin mode for hybrid attention architectures (GatedDeltaNet linear attention + full attention). It includes multimodal (vision-language) support via Qwen3VL integration.

Changes:

  • Adds Qwen3.5 dense and MoE model implementations, configs, and vLLM plugin integration including a GatedDeltaNet attention backend
  • Refactors shared components (Qwen3Next, weight loading, attention selection) to support the new model architecture and vLLM plugin mode
  • Adds utility classes (WeightsMapper, StageMissingLayer, interfaces) and environment variable controls for the new model support

Reviewed changes

Copilot reviewed 23 out of 24 changed files in this pull request and generated 10 comments.

Show a summary per file
File Description
atom/models/qwen3_5.py New Qwen3.5 model implementation (dense, MoE, multimodal)
atom/model_config/qwen3_5.py Qwen3.5 dense model configuration
atom/model_config/qwen3_5_moe.py Qwen3.5 MoE model configuration
atom/plugin/vllm/attention_backend/attention_gdn.py GatedDeltaNet attention backend for vLLM plugin mode
atom/plugin/vllm/attention_backend/gdn_attn.py GDN attention backend wrapper for vLLM
atom/plugin/vllm/model_wrapper.py ATOMForConditionalGeneration wrapper for multimodal
atom/plugin/vllm/register.py Register Qwen3.5 models in vLLM plugin registry
atom/plugin/config.py Add vllm_config to PluginConfig
atom/plugin/attention.py Reduce supported block sizes to [16]
atom/models/qwen3_next.py Refactor for Qwen3.5 compatibility and vLLM support
atom/models/interfaces.py New multimodal protocol interfaces (currently unused)
atom/models/utils.py Add utility classes (StageMissingLayer, collect_children, etc.)
atom/model_ops/linear.py Support tuple shard IDs and new QKV loading modes
atom/model_ops/layernorm.py Switch RMSNormGated to use vLLM's rmsnorm_fn
atom/model_ops/base_attention.py Pass layer_name to GDN forward, store prefix
atom/model_ops/attention_gdn.py Rename GatedDetlaNet → GatedDeltaNet, add layer_name param
atom/model_ops/attentions/gdn_attn.py Update import for renamed GatedDeltaNet
atom/model_ops/embed_head.py Configurable custom all-gather via env var
atom/model_loader/loader.py Add WeightsMapper, support weights remapping and text_config
atom/model_engine/model_runner.py Add Qwen3.5 to model arch dict, text_config handling
atom/config.py Enhanced quantization config, vLLM fallback for config loading
atom/utils/selector.py Route GDN attention to vLLM-specific backend when in plugin mode
atom/utils/envs.py Add ATOM_USE_CUSTOM_ALL_GATHER env var


self.model = model_class(config)
torch.set_default_device(None)
load_model(self.model, config.model, config.hf_config, config.load_dummy)
if hasattr(config.hc_config, "text_config"):
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
print(f"register layer {prefix} to static forward context for Mamba")
Comment on lines +52 to +54
from aiter.ops.triton.gemm.basic.gemm_a8w8_blockscale import (
gemm_a8w8_blockscale_preshuffle as gemm_a8w8_blockscale_bpreshuffle_triton,
)
self,
hidden_states: torch.Tensor,
) -> torch.Tensor | None:
print("hidden states shape: ", hidden_states.shape)
Comment on lines +381 to +395
# if torch.compiler.is_compiling():
# return self.forward_native(x, z)
# return self.forward_native(x, z)

from vllm.model_executor.layers.fla.ops.layernorm_guard import rmsnorm_fn

return rmsnorm_fn(
x,
self.weight,
self.bias,
z=z,
eps=self.eps,
group_size=self.group_size,
norm_before_gate=self.norm_before_gate,
)
on those tokens. Note however that doing so increases memory usage
as an additional buffer is needed to hold the input embeddings.
"""
from .utils import _merge_multimodal_embeddings
Comment on lines +436 to +437
self.config.n_shared_experts = 1
self.config.n_routed_experts = self.config.num_experts

@MULTIMODAL_REGISTRY.register_processor(
Qwen3VLMultiModalProcessor,
info=Qwen3_5MoeProcessingInfo,
Signed-off-by: ganyi <ygan@amd.com>
Signed-off-by: ganyi <ygan@amd.com>
Signed-off-by: ganyi <ygan@amd.com>
Copilot AI review requested due to automatic review settings March 17, 2026 07:01

Copilot AI left a comment


Pull request overview

This PR adds Qwen3.5 model support (both dense and MoE variants) to the ATOM framework, operating exclusively in vLLM plugin mode. The implementation provides a hybrid attention architecture combining GatedDeltaNet linear attention layers (using ATOM's optimized kernels) with full attention layers (using vLLM's native implementation), plus multimodal (vision-language) support via Qwen3VL integration.

Changes:

  • Added Qwen3.5 model configs, model classes, and multimodal wrappers with dedicated weight loading and mapping logic
  • Refactored Qwen3NextGatedDeltaNet, Qwen3NextAttention, and Qwen3NextDecoderLayer to support both standalone and vLLM plugin modes, with new GatedDeltaNet attention backend for vLLM
  • Enhanced weight loading infrastructure with WeightsMapper, fused expert weight support, and tuple-shard loading for MergedColumnParallelLinear

Reviewed changes

Copilot reviewed 22 out of 23 changed files in this pull request and generated 5 comments.

Show a summary per file
File Description
atom/model_config/qwen3_5.py New Qwen3.5 dense model configuration (text + vision)
atom/model_config/qwen3_5_moe.py New Qwen3.5 MoE model configuration
atom/models/qwen3_5.py Main Qwen3.5 model implementation with dense, MoE, and multimodal variants
atom/models/qwen3_next.py Refactored base classes for Qwen3.5 compatibility (config access, method extraction)
atom/models/interfaces.py New multimodal interface protocols (appears unused)
atom/models/utils.py Added utility classes/functions for model registration and weight init
atom/plugin/vllm/attention_backend/attention_gdn.py New GatedDeltaNet attention backend for vLLM plugin mode
atom/plugin/vllm/attention_backend/gdn_attn.py Backend wrapper for GDN attention
atom/plugin/vllm/model_wrapper.py Added ATOMForConditionalGeneration wrapper for multimodal models
atom/plugin/vllm/register.py Registered Qwen3.5 model classes
atom/plugin/config.py Added vllm_config to plugin config
atom/plugin/attention.py Reduced supported block sizes to [16]
atom/config.py Enhanced QuantizationConfig with vllm integration, fallback config loading
atom/model_loader/loader.py Added WeightsMapper, fused expert loading, weight name remapping
atom/model_ops/linear.py Added tuple shard_id support and new QKVZBA shard ids
atom/model_ops/layernorm.py Switched RMSNormGated to use vllm's rmsnorm kernel
atom/model_ops/base_attention.py Added layer_name parameter and prefix tracking
atom/model_ops/attention_gdn.py Fixed typo GatedDetlaNet → GatedDeltaNet, added layer_name param
atom/model_ops/attentions/gdn_attn.py Updated to use corrected GatedDeltaNet class name
atom/model_ops/embed_head.py Made custom all-gather configurable via env var
atom/utils/selector.py Added vLLM-specific GDN attention backend selection
atom/utils/envs.py Added ATOM_USE_CUSTOM_ALL_GATHER env var


Comment on lines +385 to +395
from vllm.model_executor.layers.fla.ops.layernorm_guard import rmsnorm_fn

return rmsnorm_fn(
x,
self.weight,
self.bias,
z=z,
eps=self.eps,
group_size=self.group_size,
norm_before_gate=self.norm_before_gate,
)
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.atom_config = atom_config
config = atom_config.hf_config.text_config
Signed-off-by: ganyi <ygan@amd.com>
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
print(f"register layer {prefix} to static forward context for Mamba")
Contributor

zejunchen-zejun commented Mar 17, 2026


can use logger here



@runtime_checkable
class SupportsMultiModal(Protocol):
Contributor


Hi, @ganyi1996ppo
Is this class entirely ported from vLLM? If so, we can use vLLM's SupportsMultiModal directly to reduce maintenance effort.

Signed-off-by: ganyi <ygan@amd.com>
Copilot AI review requested due to automatic review settings March 17, 2026 08:55
@contextmanager
def _mark_language_model(
self,
atom_config: nn.Module,
Contributor


Hi @ganyi1996ppo
Here the argument is atom_config, but the annotated type is nn.Module, which is misleading.
Also, the argument actually passed is vllm_config; see:
with self._mark_language_model(vllm_config):


Copilot AI left a comment


Pull request overview

This PR adds Qwen3.5 model support to ATOM (both dense and MoE variants), operating exclusively in vLLM plugin mode. The implementation leverages a hybrid attention architecture with GatedDeltaNet (linear attention) and full attention layers, multimodal (vision-language) capabilities via Qwen3VL integration, and a dual-class pattern for vLLM compatibility.

Changes:

  • Added Qwen3.5 dense and MoE model implementations with hybrid attention, multimodal support, and custom weight loading for both FP8 and BF16 checkpoint formats
  • Refactored Qwen3NextGatedDeltaNet and related classes to support Qwen3.5's separate QKVZ/BA projections, and enhanced the weight loader to handle fused expert weights and tuple shard IDs
  • Added vLLM plugin infrastructure: GDN attention backend, conditional generation wrapper (ATOMForConditionalGeneration), model registration, and config enhancements

Reviewed changes

Copilot reviewed 20 out of 21 changed files in this pull request and generated 3 comments.

Show a summary per file
File Description
atom/models/qwen3_5.py New Qwen3.5 model implementations (dense, MoE, multimodal) with triton kernels and dual-class vLLM pattern
atom/models/qwen3_next.py Refactored GatedDeltaNet, decoder layer, and MoE block for Qwen3.5 compatibility
atom/model_config/qwen3_5.py Qwen3.5 dense configuration (text, vision, composite)
atom/model_config/qwen3_5_moe.py Qwen3.5 MoE configuration
atom/plugin/vllm/attention_backend/attention_gdn.py GatedDeltaNet attention backend for vLLM plugin mode
atom/plugin/vllm/attention_backend/gdn_attn.py Backend wrapper class for GDN attention
atom/plugin/vllm/model_wrapper.py New ATOMForConditionalGeneration wrapper with multimodal/MRoPE support
atom/plugin/vllm/register.py Model registration entries for Qwen3.5
atom/model_loader/loader.py WeightsMapper, fused expert loading, and enhanced weight loading pipeline
atom/model_ops/linear.py Tuple shard_id support and new shard types (qkv, z, b, a) in weight loaders
atom/config.py QuantizationConfig enhancements, vLLM fallback for config loading, packed_modules_mapping support
atom/model_ops/base_attention.py Added layer_name param and prefix attribute to linear attention
atom/model_ops/attention_gdn.py Fixed GatedDetlaNet → GatedDeltaNet typo, added layer_name param
atom/model_ops/attentions/gdn_attn.py Updated import for renamed GatedDeltaNet class
atom/model_ops/embed_head.py Configurable custom all-gather via env variable
atom/models/utils.py Added utility classes/functions for model construction
atom/plugin/config.py Added vllm_config field to PluginConfig
atom/plugin/attention.py Reduced supported block sizes to [16] only
atom/utils/selector.py vLLM-aware attention backend selection
atom/utils/envs.py New ATOM_USE_CUSTOM_ALL_GATHER env variable


Copilot AI review requested due to automatic review settings March 18, 2026 01:59
Contributor
Copilot AI left a comment

Pull request overview

This PR adds Qwen3.5 model family support to ATOM (dense + MoE, including multimodal ConditionalGeneration) targeting vLLM plugin mode, and extends ATOM’s plugin/loader infrastructure to support Qwen3.5-specific attention backends and checkpoint weight mappings.

Changes:

  • Add new Qwen3.5 / Qwen3.5-MoE HF config definitions and a new atom/models/qwen3_5.py implementation (hybrid linear attention + full attention + multimodal wrappers).
  • Add vLLM-plugin GatedDeltaNet (GDN) attention backend and wire backend selection to choose plugin vs non-plugin implementations.
  • Enhance weight loading utilities (name mapping + packed-module handling + fused expert handling) and expose a new env toggle for TP all-gather.

Reviewed changes

Copilot reviewed 20 out of 21 changed files in this pull request and generated 7 comments.

Show a summary per file
File Description
atom/utils/selector.py Passes plugin-mode flag into attention backend selection; routes GDN backend to vLLM plugin implementation.
atom/utils/envs.py Adds ATOM_USE_CUSTOM_ALL_GATHER env var definition.
atom/plugin/vllm/register.py Registers Qwen3.5 ConditionalGeneration architectures for vLLM plugin mode.
atom/plugin/vllm/model_wrapper.py Extends vLLM wrapper interfaces to support multimodal + MRoPE forwarding.
atom/plugin/vllm/attention_backend/gdn_attn.py Adds a vLLM plugin attention-backend entrypoint for GDN.
atom/plugin/vllm/attention_backend/attention_gdn.py Implements GatedDeltaNet attention for vLLM v1 attention metadata/caching.
atom/plugin/config.py Stores the raw vllm_config on the plugin config object.
atom/plugin/attention.py Changes supported kernel KV block sizes exposed to vLLM plugin attention backend.
atom/models/utils.py Adds utilities for “no init weights” / collecting children and a common_prefix helper.
atom/models/qwen3_next.py Refactors Qwen3-Next to improve compatibility with Qwen3.5 + vLLM attention integration.
atom/models/qwen3_5.py New Qwen3.5 implementation including multimodal ConditionalGeneration in vLLM plugin mode.
atom/model_ops/linear.py Extends packed-shard weight loading for merged column-parallel linears; adds additional shard ids.
atom/model_ops/embed_head.py Adds env-controlled toggle for custom TP all-gather in LM head.
atom/model_ops/base_attention.py Threads layer_name into linear-attention op path and stores prefix.
atom/model_ops/attentions/gdn_attn.py Fixes GatedDeltaNet naming typo and type annotations.
atom/model_ops/attention_gdn.py Renames GatedDetlaNet → GatedDeltaNet and updates forward signature to accept layer_name.
atom/model_loader/loader.py Adds WeightsMapper, supports mapping during iteration, adds fused-expert handling hooks, extends packed-module loading behavior.
atom/model_config/qwen3_5.py Adds Qwen3.5 dense config definitions.
atom/model_config/qwen3_5_moe.py Adds Qwen3.5 MoE config definitions.
atom/config.py Adds vLLM config fallback for unsupported HF configs; extends quant config with packed-module mapping awareness.


Comment on lines +78 to +81
"ATOM_USE_CUSTOM_ALL_GATHER": lambda: os.getenv(
"ATOM_USE_CUSTOM_ALL_GATHER", "0"
).lower()
== "1",
Contributor Author

added
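The pattern above parses the environment variable as a boolean flag. The same idea as a standalone sketch (helper name is illustrative; ATOM's actual lookup table uses per-variable lambdas as shown):

```python
import os

def env_flag(name: str, default: str = "0") -> bool:
    """Treat '1' (case-insensitive) as enabled; anything else is disabled."""
    return os.getenv(name, default).lower() == "1"
```

One consequence of this exact comparison: values like "true" or "yes" do not enable the flag — only "1" does.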

Comment on lines +333 to +340
for shard_idx, target_name in enumerate(packed_value):
    param_name = name.replace(k, target_name)
    if "output_scale" not in param_name:
        param = model.get_parameter(param_name)
        weight_loader = getattr(param, "weight_loader")
        futures.append(
            executor.submit(
                weight_loader, param, weight_tensor, shard_idx
Comment on lines +495 to +499
for ori_param, (
    model_param,
    shard_id,
) in self.packed_modules_mapping.items():
    if proj_name in model_param:
Copilot AI review requested due to automatic review settings March 18, 2026 08:13
Contributor
Copilot AI left a comment

Pull request overview

Adds Qwen3.5 (dense + MoE, incl. multimodal ConditionalGeneration) support to ATOM, targeting vLLM plugin mode integration and hybrid attention (GatedDeltaNet + full attention) interoperability, along with weight-loading/mapping enhancements needed for Qwen3.5 checkpoints.

Changes:

  • Introduces Qwen3.5 model/config implementations (dense + MoE) and a vLLM-side GatedDeltaNet attention backend.
  • Extends plugin-mode integration (model registry/wrapper, config plumbing) and improves weight loading/mapping to support packed/fused patterns.
  • Adds a fused Triton split/chunk kernel used by the Qwen3.5 linear-attention path and adds an env toggle for custom all-gather.

Reviewed changes

Copilot reviewed 22 out of 23 changed files in this pull request and generated 4 comments.

Show a summary per file
File Description
tests/test_envs.py Tracks new env var in test cleanup list (needs explicit default/override assertions).
atom/utils/selector.py Routes GDN attention backend selection differently when running under vLLM plugin mode.
atom/utils/envs.py Adds ATOM_USE_CUSTOM_ALL_GATHER env var (default enabled).
atom/plugin/vllm/register.py Registers Qwen3.5 ConditionalGeneration architectures for vLLM model registry overrides.
atom/plugin/vllm/model_wrapper.py Extends wrapper interfaces for multimodal/MRoPE and maps Qwen3.5 arch → ATOM implementations.
atom/plugin/vllm/attention_backend/gdn_attn.py Adds vLLM plugin attention backend wrapper for GDN.
atom/plugin/vllm/attention_backend/attention_gdn.py Implements vLLM-plugin-mode GatedDeltaNet attention path.
atom/plugin/vllm/attention_backend/init.py Package marker for vLLM attention backend modules.
atom/plugin/config.py Stores full vllm_config in plugin config for downstream access.
atom/plugin/attention.py Updates supported kernel block sizes list for plugin attention backend.
atom/models/utils.py Adds utilities for module-child collection and meta-device init avoidance.
atom/models/qwen3_next.py Refactors Qwen3-Next components for improved vLLM compatibility and config variations.
atom/models/qwen3_5.py Adds Qwen3.5 dense/MoE model implementations + vLLM multimodal wrappers and weight mapping hooks.
atom/model_ops/split_chunk.py New fused Triton kernel to split/chunk Qwen3.5 projection outputs efficiently.
atom/model_ops/linear.py Enhances packed shard weight loading and expands QKVZBA shard IDs; adjusts Triton GEMM imports.
atom/model_ops/embed_head.py Makes TP all-gather implementation selectable via env var.
atom/model_ops/base_attention.py Passes layer_name through linear-attention custom op to backend impl.
atom/model_ops/attentions/gdn_attn.py Fixes GatedDeltaNet naming and type annotations.
atom/model_ops/attention_gdn.py Renames class to GatedDeltaNet and updates forward signature to accept layer_name.
atom/model_loader/loader.py Adds WeightsMapper, plugin-mode config selection, packed-module list handling, and fused-expert loading hooks.
atom/model_config/qwen3_5.py Adds Qwen3.5 dense HF config definitions (text + vision).
atom/model_config/qwen3_5_moe.py Adds Qwen3.5 MoE HF config definitions (text + vision).
atom/config.py Improves HF config loading fallback in plugin mode and passes vLLM quant info into QuantizationConfig.


Comment on lines +331 to +343
if isinstance(packed_value, list):
    # Checkpoint has fused weight, split into separate params
    for shard_idx, target_name in enumerate(packed_value):
        param_name = name.replace(k, target_name)
        if "output_scale" not in param_name:
            param = model.get_parameter(param_name)
            weight_loader = getattr(param, "weight_loader")
            futures.append(
                executor.submit(
                    weight_loader, param, weight_tensor, shard_idx
                )
            )
            loaded_weights_record.add(prefix + param_name)
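The loop above hands one fused checkpoint tensor to each target parameter's `weight_loader` by shard index. The underlying split can be sketched in isolation — the packed mapping and shard sizes below are illustrative, and real weights are tensors rather than lists:

```python
def split_fused_weight(fused_rows, shard_sizes):
    """Slice a fused weight's output rows into per-target shards."""
    shards, start = [], 0
    for size in shard_sizes:
        shards.append(fused_rows[start:start + size])
        start += size
    assert start == len(fused_rows), "shard sizes must cover the fused tensor"
    return shards

# e.g. a packed qkv_proj checkpoint tensor feeding separate q/k/v params
packed_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
fused = list(range(8))  # stand-in for the fused weight's output rows
q, k, v = split_fused_weight(fused, [4, 2, 2])
```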
tl.store(z_ptr + z_out_base + dim_idx, z_vals, mask=mask)

# Store zeros to core_attn_out: coalesced write
zeros = tl.zeros([BLOCK_SIZE], dtype=tl.float16)
Collaborator

@ganyi1996ppo fix this?

Contributor Author

Sure

Comment on lines +149 to +150
# Must be >= head_v_dim (128) and >= 2*num_v_heads_tp (32)
BLOCK_SIZE = 128
Contributor Author

the ba proj data only needs a single program to handle; the others are all handled with BLOCK_SIZE = 128
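The constraint in the comment can be sanity-checked directly. The concrete values below are taken from the comment itself (`2 * num_v_heads_tp = 32` implies 16 value heads per TP rank) and are assumptions about this configuration, not universal:

```python
head_v_dim = 128      # per the comment: BLOCK_SIZE must cover the value head dim
num_v_heads_tp = 16   # per-TP-rank value heads; 2 * 16 = 32 as stated
BLOCK_SIZE = 128

# BLOCK_SIZE must dominate both quantities for the kernel's tiling to be valid.
assert BLOCK_SIZE >= head_v_dim
assert BLOCK_SIZE >= 2 * num_v_heads_tp
```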

Comment on lines 24 to 28
"ATOM_DISABLE_MMAP",
"ATOM_DISABLE_VLLM_PLUGIN",
"ATOM_DISABLE_VLLM_PLUGIN_ATTENTION",
"ATOM_USE_CUSTOM_ALL_GATHER",
]