Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docker/Dockerfile_ascend_a3
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
ARG ASCEND_DEVICE_TYPE=ascend_a3
ARG ASCEND_HUB=swr.cn-south-1.myhuaweicloud.com/ascendhub

FROM ${ASCEND_HUB}/cann:8.3.rc1-a3-openeuler24.03-py3.11 AS ascend_a3_base
FROM ${ASCEND_HUB}/cann:8.5.0-a3-openeuler24.03-py3.11 AS ascend_a3_base

FROM ${ASCEND_DEVICE_TYPE}_base AS builder
ENV DEBIAN_FRONTEND=noninteractive
Expand All @@ -22,6 +22,6 @@ ARG LMDEPLOY_TAG=main
RUN --mount=type=cache,target=/root/.cache \
pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple && \
pip config set global.trusted-host pypi.tuna.tsinghua.edu.cn && \
pip install --no-cache-dir torch==2.8.0 torch-npu==2.8.0 torchvision==0.23.0 && \
pip install --no-cache-dir torch==2.9.0 torch-npu==2.9.0 torchvision==0.24.0 && \
TORCH_DEVICE_BACKEND_AUTOLOAD=0 DEVICE=ascend pip install git+https://github.com/DeepLink-org/dlinfer.git@${DLINFER_TAG} && \
LMDEPLOY_TARGET_DEVICE=ascend pip install git+https://github.com/InternLM/lmdeploy.git@${LMDEPLOY_TAG}
156 changes: 153 additions & 3 deletions lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
# Copyright (c) OpenMMLab. All rights reserved.
import itertools
import math
import os
import re
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Dict, Tuple

import torch
import torch.distributed as dist

from lmdeploy.pytorch import envs as _envs
from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig
from lmdeploy.pytorch.distributed import get_dist_manager
from lmdeploy.utils import get_logger

from ..moe import DlinferMoECommType, DlinferMoeMetadata
from ..op_backend import DlinferOpsBackend

logger = get_logger('lmdeploy')
Expand Down Expand Up @@ -40,6 +45,30 @@ def is_Ascend310P(cls) -> bool:
def is_Ascend910(cls) -> bool:
return cls.device_name().startswith(cls.Ascend910)

    @classmethod
    @lru_cache(maxsize=1)
    def soc_version(cls) -> int:
        """Return the Ascend SoC version number, queried from the NPU runtime once and cached."""
        return torch.npu.get_soc_version()

@classmethod
def is_A2(cls) -> bool:
return 220 <= cls.soc_version() <= 225

@classmethod
def is_A3(cls) -> bool:
return 250 <= cls.soc_version() <= 255


@dataclass
class DistMeta:
    """Snapshot of the distributed-parallel topology used by the Ascend backend.

    Built once from the current distributed context and memoized on
    ``AscendOpsBackend.dist_meta`` for reuse when constructing MoE metadata.
    """

    dp_size: int  # data-parallel world size
    tp_size: int  # tensor-parallel world size
    ep_size: int  # expert-parallel world size
    tp_rank: int  # this process's rank within the tensor-parallel group
    ep_rank: int  # this process's rank within the expert-parallel group
    tp_group: torch.distributed.ProcessGroup  # tensor-parallel process group
    ep_group: torch.distributed.ProcessGroup  # expert-parallel process group


class AscendKVQuantMeta:
has_set_value: bool = False
Expand Down Expand Up @@ -88,10 +117,10 @@ def set_value(cls, device: str, dtype: torch.dtype, record_file: str, total_laye

class AscendOpsBackend(DlinferOpsBackend):
"""Ascend layer backend."""
enable_graph = False
half_negative_inf = torch.finfo(torch.float16).min
enable_graph: bool = False
total_slots = None
max_batches = None
dist_meta: DistMeta = None

@staticmethod
def get_name() -> str:
Expand Down Expand Up @@ -232,6 +261,103 @@ def get_kv_start_indices_and_attention_mask(is_decoding, is_unpaged_prefill, q_s

return kv_start_indices, attention_mask

        def get_dist_meta():
            """Return the cached ``DistMeta``, building it on first use.

            The result is memoized on the class, so the distributed context is
            read only once per process (assumes the topology is fixed for the
            process lifetime).
            """
            if cls.dist_meta is not None:
                return cls.dist_meta
            dist_ctx = get_dist_manager().current_context()
            # World sizes come from the static distributed configuration.
            dp_size, tp_size, ep_size = dist_ctx.dist_config.dp, dist_ctx.dist_config.tp, dist_ctx.dist_config.ep
            # This rank's position inside the attention-TP and EP groups.
            tp_rank, ep_rank = dist_ctx.attn_tp_group.rank, dist_ctx.ep_rank
            tp_group = dist_ctx.attn_tp_group.gpu_group
            ep_group = dist_ctx.ep_gpu_group
            cls.dist_meta = DistMeta(dp_size=dp_size,
                                     tp_size=tp_size,
                                     ep_size=ep_size,
                                     tp_rank=tp_rank,
                                     ep_rank=ep_rank,
                                     tp_group=tp_group,
                                     ep_group=ep_group)
            return cls.dist_meta

        def get_tokens_info(dp_size, tp_size, ep_size, ep_group):
            """Compute token counts needed for MoE padding decisions.

            Returns:
                (actual_tokens_current_rank, padded_tokens_current_rank,
                max_tokens_across_dp). All zeros when expert parallelism is
                disabled (ep_size <= 1), in which case no MoE comm planning
                is needed.
            """
            if ep_size <= 1:
                return 0, 0, 0
            # --- padded_tokens_current_rank ---
            is_graph = cls.enable_graph and step_context.is_decoding
            if is_graph:
                from dlinfer.framework.lmdeploy_ext.cudagraph.ascend_cudagraph import get_ascend_compatible_size
                # Graph-mode decoding: one token per sequence, rounded up to a
                # graph-compatible batch size but never above max_batches.
                actual_tokens_current_rank = step_context.q_seqlens.shape[0]
                padded_tokens_current_rank = min(get_ascend_compatible_size(actual_tokens_current_rank),
                                                 cls.max_batches)
            else:
                # Eager mode: total number of query tokens on this rank.
                actual_tokens_current_rank = step_context.q_seqlens.sum().item()
                padded_tokens_current_rank = actual_tokens_current_rank
            # --- max_tokens_across_dp ---
            if dp_size > 1:
                runtime_tokens_tensor = torch.tensor([padded_tokens_current_rank],
                                                     dtype=step_context.q_seqlens.dtype,
                                                     device=torch.npu.current_device())
                # NOTE(review): the gather buffer is sized dp_size * tp_size but
                # the collective runs over ep_group — this assumes the EP group
                # spans exactly dp*tp ranks; confirm against the dist config.
                world_size = dp_size * tp_size
                runtime_tokens_buffer = torch.zeros([world_size],
                                                    dtype=step_context.q_seqlens.dtype,
                                                    device=torch.npu.current_device())
                dist.all_gather_into_tensor(runtime_tokens_buffer, runtime_tokens_tensor, ep_group)
                max_tokens_across_dp = torch.max(runtime_tokens_buffer).item()
            else:
                max_tokens_across_dp = padded_tokens_current_rank
            return actual_tokens_current_rank, padded_tokens_current_rank, max_tokens_across_dp

        @lru_cache
        def init_mc2_token_capacity(tp_size):
            """Token capacity of the MC2 communication path.

            The max batch size (capped at 512) is rounded up to a multiple of
            ``tp_size`` via ceil-division so the capacity divides evenly across
            TP ranks.
            NOTE(review): ``lru_cache`` on a nested def is re-created every time
            the enclosing function runs, so the cache never survives across
            calls — consider hoisting this helper to module/class scope.
            """
            max_num_tokens = min(cls.max_batches, 512)
            num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
            return num_tokens_per_tp_rank * tp_size

        def select_moe_comm_type(max_tokens_across_dp, dp_size, tp_size, ep_size):
            """Choose the MoE communication strategy for this step.

            Without expert parallelism the answer is always ALLGATHER. With EP,
            MC2 is preferred while the global token count fits its capacity;
            above that, A2 SoCs fall back to ALLGATHER and A3 SoCs to ALLTOALL.

            Raises:
                ValueError: on SoC versions that are neither A2 nor A3.
            """
            if ep_size <= 1:
                return DlinferMoECommType.ALLGATHER
            mc2_token_capacity = init_mc2_token_capacity(tp_size)
            is_graph = cls.enable_graph and step_context.is_decoding
            if is_graph:
                # Graph mode pads tokens to a tp_size multiple, so compare the
                # rounded value against the MC2 capacity.
                max_tokens_across_dp = math.ceil(max_tokens_across_dp / tp_size) * tp_size
            if SocVersion.is_A2():
                # NOTE(review): MC2 on A2 additionally requires at least 16
                # ranks (dp*tp) — presumably a hardware/runtime constraint of
                # the MC2 path; confirm against the CANN documentation.
                if max_tokens_across_dp <= mc2_token_capacity and dp_size * tp_size >= 16:
                    return DlinferMoECommType.MC2
                else:
                    return DlinferMoECommType.ALLGATHER
            elif SocVersion.is_A3():
                if max_tokens_across_dp <= mc2_token_capacity:
                    return DlinferMoECommType.MC2
                else:
                    return DlinferMoECommType.ALLTOALL
            else:
                raise ValueError(f'Unsupported soc_version: {SocVersion.soc_version()}')

        def get_pad_info(actual_tokens_current_rank, padded_tokens_current_rank, max_tokens_across_dp, tp_size,
                         moe_comm_type):
            """Compute padding for the chosen MoE communication path.

            Returns:
                (pad_size, x_active_mask): the number of filler tokens this
                rank must append, and — for MC2 only — a bool mask marking the
                real (non-padding) tokens; ``None`` otherwise.
            """
            x_active_mask = None
            if moe_comm_type == DlinferMoECommType.MC2:
                # Pad every rank to a common size that is a multiple of tp_size.
                padded_size = math.ceil(max_tokens_across_dp / tp_size) * tp_size
                pad_size = padded_size - padded_tokens_current_rank
                x_active_mask = torch.ones(actual_tokens_current_rank,
                                           dtype=torch.bool,
                                           device=torch.npu.current_device())
            elif moe_comm_type == DlinferMoECommType.ALLTOALL:
                # NOTE(review): this is negative whenever
                # padded_tokens_current_rank > tp_size — confirm downstream
                # tolerates a negative pad_size or that this case cannot occur.
                pad_size = tp_size - padded_tokens_current_rank
            elif moe_comm_type == DlinferMoECommType.ALLGATHER:
                # Align this rank to the largest token count across DP ranks.
                pad_size = max_tokens_across_dp - padded_tokens_current_rank
            else:
                pad_size = 0
            return pad_size, x_active_mask

        @lru_cache(maxsize=1)
        def get_moe_group_name(group):
            """Resolve the HCCL communicator name for ``group`` (None-safe).

            NOTE(review): relies on the private ``ProcessGroup._get_backend``
            API, which may change between torch releases. Also, ``lru_cache``
            on a nested def is re-created per call of the enclosing function,
            so the memoization does not persist across calls.
            """
            if group is None:
                return None
            local_rank = torch.distributed.get_rank(group=group)
            backend = group._get_backend(torch.device('npu'))
            group_name = backend.get_hccl_comm_name(local_rank)
            return group_name

q_seqlens_cpu, kv_seqlens_cpu, kv_seqlens_expanded = get_cpu_seqlens(step_context.is_decoding,
is_unpaged_prefill)
q_seqlens_list, kv_seqlens_list = get_list_seqlens(step_context.is_decoding, is_unpaged_prefill, q_seqlens_cpu,
Expand Down Expand Up @@ -274,8 +400,32 @@ def get_kv_start_indices_and_attention_mask(is_decoding, is_unpaged_prefill, q_s
quant_policy=step_context.kv_quant_policy,
quant_meta=AscendKVQuantMeta.quant_meta,
)

step_context.attn_metadata = attn_metadata

cls.dist_meta = get_dist_meta()
actual_tokens_current_rank, padded_tokens_current_rank, max_tokens_across_dp = get_tokens_info(
cls.dist_meta.dp_size, cls.dist_meta.tp_size, cls.dist_meta.ep_size, cls.dist_meta.ep_group)
moe_comm_type = select_moe_comm_type(max_tokens_across_dp, cls.dist_meta.dp_size, cls.dist_meta.tp_size,
cls.dist_meta.ep_size)
pad_size, x_active_mask = get_pad_info(actual_tokens_current_rank, padded_tokens_current_rank,
max_tokens_across_dp, cls.dist_meta.tp_size, moe_comm_type)
moe_group_name = get_moe_group_name(cls.dist_meta.ep_group)

moe_metadata = DlinferMoeMetadata(
max_tokens_across_dp=max_tokens_across_dp,
pad_size=pad_size,
dp_size=cls.dist_meta.dp_size,
tp_size=cls.dist_meta.tp_size,
ep_size=cls.dist_meta.ep_size,
tp_rank=cls.dist_meta.tp_rank,
ep_rank=cls.dist_meta.ep_rank,
tp_group=cls.dist_meta.tp_group,
ep_group=cls.dist_meta.ep_group,
moe_comm_type=moe_comm_type,
x_active_mask=x_active_mask,
moe_group_name=moe_group_name,
)
step_context.moe_metadata = moe_metadata
return step_context

@staticmethod
Expand Down
49 changes: 44 additions & 5 deletions lmdeploy/pytorch/backends/dlinfer/moe.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.

import os
from typing import Callable, List

import torch

from lmdeploy.pytorch.kernels.dlinfer import DlinferMoECommType # noqa: F401
from lmdeploy.pytorch.kernels.dlinfer import DlinferMoeMetadata # noqa: F401
from lmdeploy.pytorch.kernels.dlinfer import fused_moe, moe_gating_topk_softmax
from lmdeploy.pytorch.model_inputs import get_step_ctx_manager

from ..moe import FusedMoEBuilder, FusedMoEImpl, SoftmaxTopKBuilder, SoftmaxTopKImpl

Expand All @@ -19,7 +22,9 @@ def __init__(self, top_k: int, dim: int = -1, n_groups: int = -1):
raise NotImplementedError('Group router not supported')

def forward(self, x: torch.Tensor):
routing_weights, selected_experts = moe_gating_topk_softmax(x, self.top_k)
step_context = get_step_ctx_manager().current_context()
moe_metadata = getattr(step_context, 'moe_metadata', None)
routing_weights, selected_experts = moe_gating_topk_softmax(x, self.top_k, moe_metadata)
return routing_weights, selected_experts


Expand All @@ -35,17 +40,42 @@ def build(top_k: int, dim: int = -1, n_groups: int = -1):
class DlinferFusedMoEImpl(FusedMoEImpl):
"""Dlinfer fused moe implementation."""

def __init__(self, top_k: int, renormalize: bool = False):
    def __init__(self,
                 top_k: int,
                 num_experts: int,
                 renormalize: bool = False,
                 ep_size: int = 1,
                 ep_group: torch.distributed.ProcessGroup = None):
        """Initialize the dlinfer fused-MoE implementation.

        Args:
            top_k: number of experts selected per token.
            num_experts: total (global) number of experts.
            renormalize: whether routing weights are renormalized downstream.
            ep_size: expert-parallel world size; 1 disables EP.
            ep_group: expert-parallel process group (None when ep_size == 1).
        """
        self.top_k = top_k
        self.num_experts = num_experts
        self.renormalize = renormalize
        self.ep_size = ep_size
        self.ep_group = ep_group
        # Maps each global expert id to its index within its EP rank's local
        # expert shard (global_id % experts_per_rank).
        # NOTE(review): assumes num_experts is divisible by ep_size — confirm;
        # also relies on torch.cuda.current_device() resolving to the NPU via
        # torch_npu's cuda-namespace mapping — confirm.
        self.expert_ids_per_ep_rank = None
        if self.ep_size > 1:
            self.expert_ids_per_ep_rank = torch.tensor(
                [i % (self.num_experts // self.ep_size) for i in range(num_experts)],
                dtype=torch.int32,
                device=torch.cuda.current_device(),
            )

    def update_weights(self, gate_up_weights: torch.Tensor, down_weights: torch.Tensor):
        """Prepare expert weights for the fused-MoE kernel.

        On NPU tensors the last two dims are swapped and made contiguous;
        setting ``DLINFER_RESET_MOE_UPDATE_WEIGHTS=1`` skips that and returns
        the weights untouched. All other device types pass through unchanged.
        """
        device_type = gate_up_weights.device.type
        if device_type in ['npu']:
            # Escape hatch via env var; the reason some setups need the
            # original layout is not visible here — NOTE(review): confirm.
            if os.getenv('DLINFER_RESET_MOE_UPDATE_WEIGHTS', '0') == '1':
                return gate_up_weights, down_weights
            return gate_up_weights.transpose(-1, -2).contiguous(), down_weights.transpose(-1, -2).contiguous()
        return gate_up_weights, down_weights

def ep_expert_list(self, world_size: int, rank: int):
"""Experts list of current rank."""
num_experts = self.num_experts
expert_per_rank = (num_experts + world_size - 1) // world_size
first_expert = rank * expert_per_rank
last_expert = min(first_expert + expert_per_rank, num_experts)
return list(range(first_expert, last_expert))

def forward(self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
Expand All @@ -59,8 +89,13 @@ def forward(self,
"""forward."""
assert gate_up_bias is None
assert down_bias is None

step_context = get_step_ctx_manager().current_context()
moe_metadata = getattr(step_context, 'moe_metadata', None)
if moe_metadata is not None:
moe_metadata.expert_ids_per_ep_rank = self.expert_ids_per_ep_rank
return fused_moe(hidden_states, gate_up_weights, down_weights, topk_weights, topk_ids, self.top_k,
self.renormalize)
self.renormalize, moe_metadata)


class DlinferFusedMoEBuilder(FusedMoEBuilder):
Expand All @@ -76,4 +111,8 @@ def build(top_k: int,
layer_idx: int = 0,
out_dtype: torch.dtype = torch.bfloat16):
"""Build from mlp."""
return DlinferFusedMoEImpl(top_k=top_k, renormalize=renormalize)
return DlinferFusedMoEImpl(top_k=top_k,
num_experts=num_experts,
renormalize=renormalize,
ep_size=ep_size,
ep_group=ep_group)
6 changes: 4 additions & 2 deletions lmdeploy/pytorch/configurations/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@ def flash_mla_available():
# use flash_mla by default if it is installed
use_flash_mla = False
try:
# torch_npu device_properties doesn't have 'major' attribute
"""In some torch_npu versions, device_properties doesn't have 'major'
attribute; In other torch_npu versions, the value of major is None."""
device_properties = torch.cuda.get_device_properties(0)
if hasattr(device_properties, 'major') and device_properties.major >= 9:
major = getattr(device_properties, 'major', None)
if isinstance(major, int) and major >= 9:
import flash_mla # noqa
use_flash_mla = True
except ImportError:
Expand Down
4 changes: 3 additions & 1 deletion lmdeploy/pytorch/kernels/dlinfer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from .awq_kernels import awq_linear
from .fill_kv_cache import fill_kv_cache
from .flash_attention import flash_attention_fwd
from .fused_moe import fused_moe
from .fused_moe import DlinferMoECommType, DlinferMoeMetadata, fused_moe
from .linear import linear
from .moe_gating_topk_softmax import moe_gating_topk_softmax
from .pagedattention import paged_attention_fwd
Expand All @@ -15,6 +15,8 @@
'apply_rotary_pos_emb',
'awq_linear',
'fill_kv_cache',
'DlinferMoECommType',
'DlinferMoeMetadata',
'fused_moe',
'paged_attention_fwd',
'flash_attention_fwd',
Expand Down
2 changes: 1 addition & 1 deletion lmdeploy/pytorch/kernels/dlinfer/flash_attention.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import dlinfer.ops as ext_ops
from dlinfer.utils.type_annotation import Tensor
from torch import Tensor


def flash_attention_fwd(
Expand Down
6 changes: 5 additions & 1 deletion lmdeploy/pytorch/kernels/dlinfer/fused_moe.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import dlinfer.ops as ext_ops
from dlinfer.utils.type_annotation import MoECommType as DlinferMoECommType # noqa: F401
from dlinfer.utils.type_annotation import MoeMetadata as DlinferMoeMetadata
from torch import Tensor


Expand All @@ -11,6 +13,8 @@ def fused_moe(
topk_ids: Tensor,
topk: int,
renormalize: bool,
moe_metadata: DlinferMoeMetadata,
):
"""Dlinfer fused moe."""
return ext_ops.fused_moe(hidden_states, gate_up_weights, down_weights, topk_weights, topk_ids, topk, renormalize)
return ext_ops.fused_moe(hidden_states, gate_up_weights, down_weights, topk_weights, topk_ids, topk, renormalize,
moe_metadata)
8 changes: 6 additions & 2 deletions lmdeploy/pytorch/kernels/dlinfer/moe_gating_topk_softmax.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple

import dlinfer.ops as ext_ops
from dlinfer.utils.type_annotation import MoeMetadata as DlinferMoeMetadata
from torch import Tensor


def moe_gating_topk_softmax(router_logits: Tensor, topk: int):
routing_weights, selected_experts = ext_ops.moe_gating_topk_softmax(router_logits, topk)
def moe_gating_topk_softmax(router_logits: Tensor, topk: int,
                            moe_metadata: DlinferMoeMetadata) -> Tuple[Tensor, Tensor]:
    """Select ``topk`` experts per token via the dlinfer extension op.

    Args:
        router_logits: gate outputs per token.
        topk: number of experts selected per token.
        moe_metadata: per-step MoE metadata forwarded to the extension op.

    Returns:
        Tuple of (routing_weights, selected_experts).
    """
    routing_weights, selected_experts = ext_ops.moe_gating_topk_softmax(router_logits, topk, moe_metadata)
    return routing_weights, selected_experts
Loading
Loading