diff --git a/docker/Dockerfile_ascend_a3 b/docker/Dockerfile_ascend_a3 index d8fc152ed1..aa975d3b9c 100644 --- a/docker/Dockerfile_ascend_a3 +++ b/docker/Dockerfile_ascend_a3 @@ -4,7 +4,7 @@ ARG ASCEND_DEVICE_TYPE=ascend_a3 ARG ASCEND_HUB=swr.cn-south-1.myhuaweicloud.com/ascendhub -FROM ${ASCEND_HUB}/cann:8.3.rc1-a3-openeuler24.03-py3.11 AS ascend_a3_base +FROM ${ASCEND_HUB}/cann:8.5.0-a3-openeuler24.03-py3.11 AS ascend_a3_base FROM ${ASCEND_DEVICE_TYPE}_base AS builder ENV DEBIAN_FRONTEND=noninteractive @@ -22,6 +22,6 @@ ARG LMDEPLOY_TAG=main RUN --mount=type=cache,target=/root/.cache \ pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple && \ pip config set global.trusted-host pypi.tuna.tsinghua.edu.cn && \ - pip install --no-cache-dir torch==2.8.0 torch-npu==2.8.0 torchvision==0.23.0 && \ + pip install --no-cache-dir torch==2.9.0 torch-npu==2.9.0 torchvision==0.24.0 && \ TORCH_DEVICE_BACKEND_AUTOLOAD=0 DEVICE=ascend pip install git+https://github.com/DeepLink-org/dlinfer.git@${DLINFER_TAG} && \ LMDEPLOY_TARGET_DEVICE=ascend pip install git+https://github.com/InternLM/lmdeploy.git@${LMDEPLOY_TAG} diff --git a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py index e48ab79f18..484cbd1b72 100644 --- a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py +++ b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py @@ -1,17 +1,22 @@ # Copyright (c) OpenMMLab. All rights reserved. import itertools +import math import os import re +from dataclasses import dataclass from functools import lru_cache from pathlib import Path from typing import Dict, Tuple import torch +import torch.distributed as dist from lmdeploy.pytorch import envs as _envs from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig +from lmdeploy.pytorch.distributed import get_dist_manager from lmdeploy.utils import get_logger +from ..moe import DlinferMoECommType, DlinferMoeMetadata from ..op_backend import DlinferOpsBackend logger = get_logger('lmdeploy') @@ -40,6 +45,30 @@ def is_Ascend310P(cls) -> bool: def is_Ascend910(cls) -> bool: return cls.device_name().startswith(cls.Ascend910) + @classmethod + @lru_cache(maxsize=1) + def soc_version(cls) -> int: + return torch.npu.get_soc_version() + + @classmethod + def is_A2(cls) -> bool: + return 220 <= cls.soc_version() <= 225 + + @classmethod + def is_A3(cls) -> bool: + return 250 <= cls.soc_version() <= 255 + + +@dataclass +class DistMeta: + dp_size: int + tp_size: int + ep_size: int + tp_rank: int + ep_rank: int + tp_group: torch.distributed.ProcessGroup + ep_group: torch.distributed.ProcessGroup + class AscendKVQuantMeta: has_set_value: bool = False @@ -88,10 +117,10 @@ def set_value(cls, device: str, dtype: torch.dtype, record_file: str, total_laye class AscendOpsBackend(DlinferOpsBackend): """Ascend layer backend.""" - enable_graph = False - half_negative_inf = torch.finfo(torch.float16).min + enable_graph: bool = False total_slots = None max_batches = None + dist_meta: DistMeta = None @staticmethod def get_name() -> str: @@ -232,6 +261,103 @@ def get_kv_start_indices_and_attention_mask(is_decoding, is_unpaged_prefill, q_s return kv_start_indices, attention_mask + def get_dist_meta(): + if cls.dist_meta is not None: + return cls.dist_meta + dist_ctx = get_dist_manager().current_context() + dp_size, tp_size, ep_size = dist_ctx.dist_config.dp, dist_ctx.dist_config.tp, dist_ctx.dist_config.ep + tp_rank, ep_rank = dist_ctx.attn_tp_group.rank, dist_ctx.ep_rank + tp_group = dist_ctx.attn_tp_group.gpu_group + ep_group = dist_ctx.ep_gpu_group + cls.dist_meta = DistMeta(dp_size=dp_size, + tp_size=tp_size, + ep_size=ep_size, + tp_rank=tp_rank, + ep_rank=ep_rank, + tp_group=tp_group, + ep_group=ep_group) + return cls.dist_meta + + def get_tokens_info(dp_size, tp_size, ep_size, ep_group): + if ep_size <= 1: + return 0, 0, 0 + # get padded_tokens_current_rank + is_graph = cls.enable_graph and step_context.is_decoding + if is_graph: + from dlinfer.framework.lmdeploy_ext.cudagraph.ascend_cudagraph import get_ascend_compatible_size + actual_tokens_current_rank = step_context.q_seqlens.shape[0] + padded_tokens_current_rank = min(get_ascend_compatible_size(actual_tokens_current_rank), + cls.max_batches) + else: + actual_tokens_current_rank = step_context.q_seqlens.sum().item() + padded_tokens_current_rank = actual_tokens_current_rank + # get max_tokens_across_dp + if dp_size > 1: + runtime_tokens_tensor = torch.tensor([padded_tokens_current_rank], + dtype=step_context.q_seqlens.dtype, + device=torch.npu.current_device()) + world_size = dp_size * tp_size + runtime_tokens_buffer = torch.zeros([world_size], + dtype=step_context.q_seqlens.dtype, + device=torch.npu.current_device()) + dist.all_gather_into_tensor(runtime_tokens_buffer, runtime_tokens_tensor, ep_group) + max_tokens_across_dp = torch.max(runtime_tokens_buffer).item() + else: + max_tokens_across_dp = padded_tokens_current_rank + return actual_tokens_current_rank, padded_tokens_current_rank, max_tokens_across_dp + + @lru_cache + def init_mc2_token_capacity(tp_size): + max_num_tokens = min(cls.max_batches, 512) + num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size + return num_tokens_per_tp_rank * tp_size + + def select_moe_comm_type(max_tokens_across_dp, dp_size, tp_size, ep_size): + if ep_size <= 1: + return DlinferMoECommType.ALLGATHER + mc2_token_capacity = init_mc2_token_capacity(tp_size) + is_graph = cls.enable_graph and step_context.is_decoding + if is_graph: + max_tokens_across_dp = math.ceil(max_tokens_across_dp / tp_size) * tp_size + if SocVersion.is_A2(): + if max_tokens_across_dp <= mc2_token_capacity and dp_size * tp_size >= 16: + return DlinferMoECommType.MC2 + else: + return DlinferMoECommType.ALLGATHER + elif SocVersion.is_A3(): + if max_tokens_across_dp <= mc2_token_capacity: + return DlinferMoECommType.MC2 + else: + return DlinferMoECommType.ALLTOALL + else: + raise ValueError(f'Unsupported soc_version: {SocVersion.soc_version()}') + + def get_pad_info(actual_tokens_current_rank, padded_tokens_current_rank, max_tokens_across_dp, tp_size, + moe_comm_type): + x_active_mask = None + if moe_comm_type == DlinferMoECommType.MC2: + padded_size = math.ceil(max_tokens_across_dp / tp_size) * tp_size + pad_size = padded_size - padded_tokens_current_rank + x_active_mask = torch.ones(actual_tokens_current_rank, + dtype=torch.bool, + device=torch.npu.current_device()) + elif moe_comm_type == DlinferMoECommType.ALLTOALL: + pad_size = tp_size - padded_tokens_current_rank + elif moe_comm_type == DlinferMoECommType.ALLGATHER: + pad_size = max_tokens_across_dp - padded_tokens_current_rank + else: + pad_size = 0 + return pad_size, x_active_mask + + @lru_cache(maxsize=1) + def get_moe_group_name(group): + if group is None: + return None + local_rank = torch.distributed.get_rank(group=group) + backend = group._get_backend(torch.device('npu')) + group_name = backend.get_hccl_comm_name(local_rank) + return group_name + q_seqlens_cpu, kv_seqlens_cpu, kv_seqlens_expanded = get_cpu_seqlens(step_context.is_decoding, is_unpaged_prefill) q_seqlens_list, kv_seqlens_list = get_list_seqlens(step_context.is_decoding, is_unpaged_prefill, q_seqlens_cpu, @@ -274,8 +400,32 @@ def get_kv_start_indices_and_attention_mask(is_decoding, is_unpaged_prefill, q_s quant_policy=step_context.kv_quant_policy, quant_meta=AscendKVQuantMeta.quant_meta, ) - step_context.attn_metadata = attn_metadata + + cls.dist_meta = get_dist_meta() + actual_tokens_current_rank, padded_tokens_current_rank, max_tokens_across_dp = get_tokens_info( + cls.dist_meta.dp_size, cls.dist_meta.tp_size, cls.dist_meta.ep_size, cls.dist_meta.ep_group) + moe_comm_type = select_moe_comm_type(max_tokens_across_dp, cls.dist_meta.dp_size, cls.dist_meta.tp_size, + cls.dist_meta.ep_size) + pad_size, x_active_mask = get_pad_info(actual_tokens_current_rank, padded_tokens_current_rank, + max_tokens_across_dp, cls.dist_meta.tp_size, moe_comm_type) + moe_group_name = get_moe_group_name(cls.dist_meta.ep_group) + + moe_metadata = DlinferMoeMetadata( + max_tokens_across_dp=max_tokens_across_dp, + pad_size=pad_size, + dp_size=cls.dist_meta.dp_size, + tp_size=cls.dist_meta.tp_size, + ep_size=cls.dist_meta.ep_size, + tp_rank=cls.dist_meta.tp_rank, + ep_rank=cls.dist_meta.ep_rank, + tp_group=cls.dist_meta.tp_group, + ep_group=cls.dist_meta.ep_group, + moe_comm_type=moe_comm_type, + x_active_mask=x_active_mask, + moe_group_name=moe_group_name, + ) + step_context.moe_metadata = moe_metadata return step_context @staticmethod diff --git a/lmdeploy/pytorch/backends/dlinfer/moe.py b/lmdeploy/pytorch/backends/dlinfer/moe.py index 70b134c786..f034a0bb07 100644 --- a/lmdeploy/pytorch/backends/dlinfer/moe.py +++ b/lmdeploy/pytorch/backends/dlinfer/moe.py @@ -1,10 +1,13 @@ # Copyright (c) OpenMMLab. All rights reserved. - +import os from typing import Callable, List import torch +from lmdeploy.pytorch.kernels.dlinfer import DlinferMoECommType # noqa: F401 +from lmdeploy.pytorch.kernels.dlinfer import DlinferMoeMetadata # noqa: F401 from lmdeploy.pytorch.kernels.dlinfer import fused_moe, moe_gating_topk_softmax +from lmdeploy.pytorch.model_inputs import get_step_ctx_manager from ..moe import FusedMoEBuilder, FusedMoEImpl, SoftmaxTopKBuilder, SoftmaxTopKImpl @@ -19,7 +22,9 @@ def __init__(self, top_k: int, dim: int = -1, n_groups: int = -1): raise NotImplementedError('Group router not supported') def forward(self, x: torch.Tensor): - routing_weights, selected_experts = moe_gating_topk_softmax(x, self.top_k) + step_context = get_step_ctx_manager().current_context() + moe_metadata = getattr(step_context, 'moe_metadata', None) + routing_weights, selected_experts = moe_gating_topk_softmax(x, self.top_k, moe_metadata) return routing_weights, selected_experts @@ -35,17 +40,42 @@ def build(top_k: int, dim: int = -1, n_groups: int = -1): class DlinferFusedMoEImpl(FusedMoEImpl): """Dlinfer fused moe implementation.""" - def __init__(self, top_k: int, renormalize: bool = False): + def __init__(self, + top_k: int, + num_experts: int, + renormalize: bool = False, + ep_size: int = 1, + ep_group: torch.distributed.ProcessGroup = None): self.top_k = top_k + self.num_experts = num_experts self.renormalize = renormalize + self.ep_size = ep_size + self.ep_group = ep_group + self.expert_ids_per_ep_rank = None + if self.ep_size > 1: + self.expert_ids_per_ep_rank = torch.tensor( + [i % (self.num_experts // self.ep_size) for i in range(num_experts)], + dtype=torch.int32, + device=torch.cuda.current_device(), + ) def update_weights(self, gate_up_weights: torch.Tensor, down_weights: torch.Tensor): """Update weights.""" device_type = gate_up_weights.device.type if device_type in ['npu']: + if os.getenv('DLINFER_RESET_MOE_UPDATE_WEIGHTS', '0') == '1': + return gate_up_weights, down_weights return gate_up_weights.transpose(-1, -2).contiguous(), down_weights.transpose(-1, -2).contiguous() return gate_up_weights, down_weights + def ep_expert_list(self, world_size: int, rank: int): + """Experts list of current rank.""" + num_experts = self.num_experts + expert_per_rank = (num_experts + world_size - 1) // world_size + first_expert = rank * expert_per_rank + last_expert = min(first_expert + expert_per_rank, num_experts) + return list(range(first_expert, last_expert)) + def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, @@ -59,8 +89,13 @@ def forward(self, """forward.""" assert gate_up_bias is None assert down_bias is None + + step_context = get_step_ctx_manager().current_context() + moe_metadata = getattr(step_context, 'moe_metadata', None) + if moe_metadata is not None: + moe_metadata.expert_ids_per_ep_rank = self.expert_ids_per_ep_rank return fused_moe(hidden_states, gate_up_weights, down_weights, topk_weights, topk_ids, self.top_k, - self.renormalize) + self.renormalize, moe_metadata) class DlinferFusedMoEBuilder(FusedMoEBuilder): @@ -76,4 +111,8 @@ def build(top_k: int, layer_idx: int = 0, out_dtype: torch.dtype = torch.bfloat16): """Build from mlp.""" - return DlinferFusedMoEImpl(top_k=top_k, renormalize=renormalize) + return DlinferFusedMoEImpl(top_k=top_k, + num_experts=num_experts, + renormalize=renormalize, + ep_size=ep_size, + ep_group=ep_group) diff --git a/lmdeploy/pytorch/configurations/utils.py b/lmdeploy/pytorch/configurations/utils.py index dfdd50512e..2ea21364a7 100644 --- a/lmdeploy/pytorch/configurations/utils.py +++ b/lmdeploy/pytorch/configurations/utils.py @@ -11,9 +11,11 @@ def flash_mla_available(): # use flash_mla by default if it is installed use_flash_mla = False try: - # torch_npu device_properties doesn't have 'major' attribute + """In some torch_npu versions, device_properties doesn't have 'major' + attribute; In other torch_npu versions, the value of major is None.""" device_properties = torch.cuda.get_device_properties(0) - if hasattr(device_properties, 'major') and device_properties.major >= 9: + major = getattr(device_properties, 'major', None) + if isinstance(major, int) and major >= 9: import flash_mla # noqa use_flash_mla = True except ImportError: diff --git a/lmdeploy/pytorch/kernels/dlinfer/__init__.py b/lmdeploy/pytorch/kernels/dlinfer/__init__.py index 7b226d7ff4..660368ba23 100644 --- a/lmdeploy/pytorch/kernels/dlinfer/__init__.py +++ b/lmdeploy/pytorch/kernels/dlinfer/__init__.py @@ -4,7 +4,7 @@ from .awq_kernels import awq_linear from .fill_kv_cache import fill_kv_cache from .flash_attention import flash_attention_fwd -from .fused_moe import fused_moe +from .fused_moe import DlinferMoECommType, DlinferMoeMetadata, fused_moe from .linear import linear from .moe_gating_topk_softmax import moe_gating_topk_softmax from .pagedattention import paged_attention_fwd @@ -15,6 +15,8 @@ 'apply_rotary_pos_emb', 'awq_linear', 'fill_kv_cache', + 'DlinferMoECommType', + 'DlinferMoeMetadata', 'fused_moe', 'paged_attention_fwd', 'flash_attention_fwd', diff --git a/lmdeploy/pytorch/kernels/dlinfer/flash_attention.py b/lmdeploy/pytorch/kernels/dlinfer/flash_attention.py index a1b4c659d1..7f3037b247 100644 --- a/lmdeploy/pytorch/kernels/dlinfer/flash_attention.py +++ b/lmdeploy/pytorch/kernels/dlinfer/flash_attention.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import dlinfer.ops as ext_ops -from dlinfer.utils.type_annotation import Tensor +from torch import Tensor def flash_attention_fwd( diff --git a/lmdeploy/pytorch/kernels/dlinfer/fused_moe.py b/lmdeploy/pytorch/kernels/dlinfer/fused_moe.py index 4bcfade78d..4624e0c199 100644 --- a/lmdeploy/pytorch/kernels/dlinfer/fused_moe.py +++ b/lmdeploy/pytorch/kernels/dlinfer/fused_moe.py @@ -1,5 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. import dlinfer.ops as ext_ops +from dlinfer.utils.type_annotation import MoECommType as DlinferMoECommType # noqa: F401 +from dlinfer.utils.type_annotation import MoeMetadata as DlinferMoeMetadata from torch import Tensor @@ -11,6 +13,8 @@ def fused_moe( topk_ids: Tensor, topk: int, renormalize: bool, + moe_metadata: DlinferMoeMetadata, ): """Dlinfer fused moe.""" - return ext_ops.fused_moe(hidden_states, gate_up_weights, down_weights, topk_weights, topk_ids, topk, renormalize) + return ext_ops.fused_moe(hidden_states, gate_up_weights, down_weights, topk_weights, topk_ids, topk, renormalize, + moe_metadata) diff --git a/lmdeploy/pytorch/kernels/dlinfer/moe_gating_topk_softmax.py b/lmdeploy/pytorch/kernels/dlinfer/moe_gating_topk_softmax.py index ad2fe66056..cc1a324bf4 100644 --- a/lmdeploy/pytorch/kernels/dlinfer/moe_gating_topk_softmax.py +++ b/lmdeploy/pytorch/kernels/dlinfer/moe_gating_topk_softmax.py @@ -1,8 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import Tuple + import dlinfer.ops as ext_ops +from dlinfer.utils.type_annotation import MoeMetadata as DlinferMoeMetadata from torch import Tensor -def moe_gating_topk_softmax(router_logits: Tensor, topk: int): - routing_weights, selected_experts = ext_ops.moe_gating_topk_softmax(router_logits, topk) +def moe_gating_topk_softmax(router_logits: Tensor, topk: int, + moe_metadata: DlinferMoeMetadata) -> Tuple[Tensor, Tensor]: + routing_weights, selected_experts = ext_ops.moe_gating_topk_softmax(router_logits, topk, moe_metadata) return routing_weights, selected_experts diff --git a/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py b/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py index 13f4e12a58..8996508aff 100644 --- a/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py +++ b/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py @@ -1,7 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Sequence + import dlinfer.ops as ext_ops -import torch -from dlinfer.utils.type_annotation import Optional, Sequence, Tensor +from torch import Tensor def prefill_attention( @@ -111,8 +112,8 @@ def paged_token_attention( def paged_attention_fwd( query_states: Tensor, - key_states: torch.Tensor, - value_states: torch.Tensor, + key_states: Tensor, + value_states: Tensor, attn_output: Tensor, key_cache: Tensor, value_cache: Tensor, diff --git a/requirements/runtime_ascend.txt b/requirements/runtime_ascend.txt index d94a38d0bf..22d1ca8418 100644 --- a/requirements/runtime_ascend.txt +++ b/requirements/runtime_ascend.txt @@ -22,9 +22,9 @@ safetensors sentencepiece shortuuid tiktoken -torch>=2.3.1,<2.9.0 -torch-npu>=2.3.1,<2.9.0 -torchvision>=0.18.1,<0.24.0 +torch>=2.3.1,<2.10.0 +torch-npu>=2.3.1,<2.10.0 +torchvision>=0.18.1,<0.25.0 transformers uvicorn xgrammar