diff --git a/README.md b/README.md
index 5c9d77b..8c43082 100644
--- a/README.md
+++ b/README.md
@@ -20,7 +20,7 @@
-
+
diff --git a/README_zh.md b/README_zh.md
index 3be691a..00a0e2d 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -20,7 +20,7 @@
-
+
diff --git a/requirements.txt b/requirements.txt
index e41cb66..a867d28 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-megatron-core>=0.12
+megatron-core>=0.15
modelscope
peft>=0.11,<0.20
safetensors
diff --git a/src/mcore_bridge/bridge/gpt_bridge.py b/src/mcore_bridge/bridge/gpt_bridge.py
index 96b4da6..b171f20 100644
--- a/src/mcore_bridge/bridge/gpt_bridge.py
+++ b/src/mcore_bridge/bridge/gpt_bridge.py
@@ -1,13 +1,11 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import math
-import megatron.core
import re
import torch
import torch.distributed as dist
import torch.nn.functional as F
from contextlib import contextmanager
from megatron.core import mpu
-from packaging import version
from peft import PeftModel
from peft.utils import ModulesToSaveWrapper
from tqdm import tqdm
@@ -22,8 +20,6 @@
logger = get_logger()
-mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0')
-
EP_PP_SIZE = None
EP_PP_GROUP = None
EP_PP_RANK = None
@@ -60,7 +56,6 @@ def __init__(self, config: ModelConfig):
self.model_type = config.hf_model_type
self.llm_model_type = config.llm_model_type
self.is_multimodal = config.is_multimodal
- self.mcore_014 = version.parse(megatron.core.__version__) >= version.parse('0.14.0rc0')
self.module_mapping = config.model_meta.visual_cls.module_mapping if self.is_multimodal else {}
self.tp_size = self.config.tensor_model_parallel_size
self.pp_size = self.config.pipeline_model_parallel_size
@@ -130,9 +125,6 @@ def _get_tp_split_dim(self, mg_key: Optional[str]) -> Optional[int]:
}
if self.config.task_type in {'causal_lm', 'generative_reranker'}:
dim0_keys.add('output_layer')
- if not self.mcore_014:
- # https://github.com/NVIDIA/Megatron-LM/commit/720c8b40d8e7e2de1dd303d792f29093101c5e72
- dim0_keys.update({'linear_q_down_proj', 'linear_kv_down_proj'})
# RowLinear
dim1_keys = {'out_proj', 'linear_proj', 'linear_fc2'}
if 'lora_A' not in mg_key and 'lora_B' not in mg_key:
@@ -1679,12 +1671,8 @@ def _convert(self, mg_models, hf_state_dict, hf_prefix: str, to_mcore: bool, tqd
hf_state_dict = {}
mg_models = iter(mg_models)
mg_model = next(mg_models)
- if mcore_013:
- is_pp_first_stage = mpu.is_pipeline_first_stage(ignore_virtual=False, vp_stage=mg_model.vp_stage)
- is_pp_last_stage = mpu.is_pipeline_last_stage(ignore_virtual=False, vp_stage=mg_model.vp_stage)
- else:
- is_pp_first_stage = mpu.is_pipeline_first_stage()
- is_pp_last_stage = mpu.is_pipeline_last_stage()
+ is_pp_first_stage = mpu.is_pipeline_first_stage(ignore_virtual=False, vp_stage=mg_model.vp_stage)
+ is_pp_last_stage = mpu.is_pipeline_last_stage(ignore_virtual=False, vp_stage=mg_model.vp_stage)
if not to_mcore or is_pp_first_stage:
hf_state_dict.update(self._convert_pre_process(mg_model, hf_state_dict, '', to_mcore))
if to_mcore:
diff --git a/src/mcore_bridge/model/gpt_model.py b/src/mcore_bridge/model/gpt_model.py
index 5fb714c..165ff0c 100644
--- a/src/mcore_bridge/model/gpt_model.py
+++ b/src/mcore_bridge/model/gpt_model.py
@@ -1,7 +1,6 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import copy
import math
-import megatron.core
import os
import torch
import torch.nn.functional as F
@@ -20,7 +19,6 @@
from megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler, MTPLossLoggingHelper, roll_tensor
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.utils import WrappedTensor, deprecate_inference_params
-from packaging import version
from typing import Optional, Tuple
from mcore_bridge.config import ModelConfig
@@ -30,8 +28,6 @@
logger = get_logger()
-mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0')
-
class OutputLayerLinear(TELinear):
@@ -79,12 +75,6 @@ def __init__(
config.mscale_all_dim = hf_rope_scaling['mscale_all_dim']
config.rotary_scaling_factor = hf_rope_scaling['factor']
self.hf_rope_scaling = hf_rope_scaling
- if mcore_013:
- kwargs = {'vp_stage': vp_stage}
- else:
- self.vp_stage = vp_stage
- assert vp_stage is None, 'megatron-core==0.12 does not support vp_stage'
- kwargs = {}
super().__init__(
config,
transformer_layer_spec,
@@ -96,7 +86,7 @@ def __init__(
position_embedding_type=config.position_embedding_type,
rotary_base=config.rotary_base,
mtp_block_spec=mtp_block_spec,
- **kwargs,
+ vp_stage=vp_stage,
)
if config.multi_latent_attention:
self.rotary_pos_emb = RotaryEmbedding(
diff --git a/src/mcore_bridge/model/gpts/glm4.py b/src/mcore_bridge/model/gpts/glm4.py
index 30c5eaf..5c7a8ca 100644
--- a/src/mcore_bridge/model/gpts/glm4.py
+++ b/src/mcore_bridge/model/gpts/glm4.py
@@ -1,5 +1,4 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
-import megatron.core
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.extensions.transformer_engine import TENorm
from megatron.core.transformer import transformer_layer
@@ -7,7 +6,6 @@
from megatron.core.transformer.mlp import MLP, apply_swiglu_sharded_factory
from megatron.core.transformer.spec_utils import build_module
from megatron.core.transformer.utils import sharded_state_dict_default
-from packaging import version
from typing import Optional
from mcore_bridge.bridge import GPTBridge
@@ -16,8 +14,6 @@
from ..constant import ModelType
from ..register import ModelLoader, ModelMeta, register_model
-mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0')
-
class Glm4SelfAttention(SelfAttention):
diff --git a/src/mcore_bridge/model/gpts/minimax_m2.py b/src/mcore_bridge/model/gpts/minimax_m2.py
index ebdfae3..81b11b6 100644
--- a/src/mcore_bridge/model/gpts/minimax_m2.py
+++ b/src/mcore_bridge/model/gpts/minimax_m2.py
@@ -1,12 +1,10 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
-import megatron.core
from megatron.core import mpu
from megatron.core.tensor_parallel.mappings import (gather_from_tensor_model_parallel_region,
scatter_to_tensor_model_parallel_region)
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.spec_utils import build_module
-from packaging import version
from typing import Optional
from mcore_bridge.bridge import GPTBridge
@@ -15,8 +13,6 @@
from ..constant import ModelType
from ..register import ModelLoader, ModelMeta, register_model
-mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0')
-
class MinimaxM2SelfAttention(SelfAttention):
diff --git a/src/mcore_bridge/model/gpts/olmoe.py b/src/mcore_bridge/model/gpts/olmoe.py
index f436517..63c1957 100644
--- a/src/mcore_bridge/model/gpts/olmoe.py
+++ b/src/mcore_bridge/model/gpts/olmoe.py
@@ -1,4 +1,3 @@
-import megatron.core
import torch
import torch.distributed as dist
from copy import deepcopy
@@ -9,7 +8,6 @@
from megatron.core.transformer.spec_utils import build_module
from megatron.core.transformer.transformer_block import TransformerBlockSubmodules, get_num_layers_to_build
from megatron.core.transformer.transformer_layer import get_transformer_layer_offset
-from packaging import version
from typing import Optional
from mcore_bridge.bridge import GPTBridge
@@ -19,8 +17,6 @@
from ..constant import ModelType
from ..register import ModelLoader, ModelMeta, register_model
-mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0')
-
class OLMoESelfAttention(SelfAttentionBase):
@@ -78,13 +74,12 @@ def get_olmoe_decoder_block_spec(
) -> TransformerBlockSubmodules:
"""GPT block spec."""
layer_norm_impl = TENorm
- kwargs = {'use_kitchen': config.use_kitchen} if mcore_013 else {}
moe_layer_spec = get_gpt_layer_with_transformer_engine_spec(
num_experts=config.num_moe_experts,
moe_grouped_gemm=config.moe_grouped_gemm,
qk_layernorm=True,
multi_latent_attention=False,
- **kwargs,
+ use_kitchen=config.use_kitchen,
)
layer_specs = []
for _ in range(config.num_layers):
diff --git a/src/mcore_bridge/model/gpts/qwen3_next.py b/src/mcore_bridge/model/gpts/qwen3_next.py
index 08cce14..65a90df 100644
--- a/src/mcore_bridge/model/gpts/qwen3_next.py
+++ b/src/mcore_bridge/model/gpts/qwen3_next.py
@@ -1,5 +1,4 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
-import megatron.core
import torch
from copy import deepcopy
from megatron.core.extensions.transformer_engine import TEColumnParallelLinear, _get_extra_te_kwargs
@@ -16,7 +15,6 @@
from megatron.core.transformer.spec_utils import build_module
from megatron.core.transformer.transformer_block import TransformerBlockSubmodules
from megatron.core.utils import deprecate_inference_params, is_fa_min_version
-from packaging import version
from transformers.utils import is_torch_npu_available
from typing import Optional, Tuple, Union
@@ -27,8 +25,6 @@
from ..constant import ModelType
from ..register import ModelLoader, ModelMeta, register_model
-mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0')
-mcore_015 = version.parse(megatron.core.__version__) >= version.parse('0.15.0rc0')
try:
from flashattn_hopper.flash_attn_interface import _flash_attn_forward
from flashattn_hopper.flash_attn_interface import flash_attn_with_kvcache as flash_attn3_with_kvcache
@@ -102,10 +98,6 @@ class Qwen3NextSelfAttention(SelfAttention):
def __init__(self, config: ModelConfig, submodules: SelfAttentionSubmodules, *args, **kwargs):
super(SelfAttention, self).__init__(config, submodules, *args, attention_type='self', **kwargs)
kwargs = {}
- if mcore_015:
- kwargs['tp_group'] = self.pg_collection.tp
- elif mcore_013:
- kwargs['tp_group'] = self.model_comm_pgs.tp
self.linear_qkv = build_module(
submodules.linear_qkv,
self.config.hidden_size,
@@ -117,6 +109,7 @@ def __init__(self, config: ModelConfig, submodules: SelfAttentionSubmodules, *ar
skip_bias_add=False,
is_expert=False,
tp_comm_buffer_name='qkv',
+ tp_group=self.pg_collection.tp,
**kwargs,
)
@@ -253,7 +246,7 @@ def nvtx_range_push(*args, **kwargs):
if (in_decode_mode and self.config.enable_cuda_graph and inference_context.is_static_batching()):
raise ValueError('CUDA graphs must use flash decode with static batching!')
- result = self._adjust_key_value_for_inference(
+ query, key, value, rotary_pos_emb, attn_mask_type, block_table = self._adjust_key_value_for_inference(
inference_context,
query,
key,
@@ -263,10 +256,6 @@ def nvtx_range_push(*args, **kwargs):
rotary_pos_sin,
sequence_len_offset,
)
- if mcore_013:
- query, key, value, rotary_pos_emb, attn_mask_type, block_table = result
- else:
- query, key, value, rotary_pos_emb, attn_mask_type = result
if packed_seq_params is not None:
query = query.squeeze(1)
@@ -277,11 +266,6 @@ def nvtx_range_push(*args, **kwargs):
# ================================================
# relative positional embedding (rotary embedding)
# ================================================
- kwargs = {}
- if mcore_015:
- kwargs['cp_group'] = self.pg_collection.cp
- elif mcore_013:
- kwargs['cp_group'] = self.model_comm_pgs.cp
nvtx_range_push(suffix='rotary_pos_emb')
if rotary_pos_emb is not None and not self.config.flash_decode:
q_pos_emb, k_pos_emb = rotary_pos_emb
@@ -306,11 +290,11 @@ def nvtx_range_push(*args, **kwargs):
q_pos_emb,
config=self.config,
cu_seqlens=cu_seqlens_q,
- **kwargs,
+ cp_group=self.pg_collection.cp,
)
else:
- query = inference_context.apply_rotary_emb_query(query, q_pos_emb, self.config, cu_seqlens_q,
- **kwargs)
+ query = inference_context.apply_rotary_emb_query(
+ query, q_pos_emb, self.config, cu_seqlens_q, cp_group=self.pg_collection.cp)
if k_pos_emb is not None:
key = apply_rotary_pos_emb(
key,
@@ -561,13 +545,12 @@ def get_transformer_layer_spec(self, vp_stage: Optional[int] = None):
# Use Zero-Centered RMSNorm to match HuggingFace exactly (no +1/-1 conversion needed)
layer_norm_impl = Qwen3NextRMSNorm
- kwargs = {'use_kitchen': config.use_kitchen} if mcore_013 else {}
moe_layer_spec = get_gpt_layer_with_transformer_engine_spec(
num_experts=config.num_moe_experts,
moe_grouped_gemm=config.moe_grouped_gemm,
qk_layernorm=config.qk_layernorm,
multi_latent_attention=config.multi_latent_attention,
- **kwargs,
+ use_kitchen=config.use_kitchen,
)
layer_specs = []
for is_linear_attention in self.config.linear_attention_freq:
diff --git a/src/mcore_bridge/model/mm_gpt_model.py b/src/mcore_bridge/model/mm_gpt_model.py
index b68e82b..7c36a19 100644
--- a/src/mcore_bridge/model/mm_gpt_model.py
+++ b/src/mcore_bridge/model/mm_gpt_model.py
@@ -7,15 +7,12 @@
from megatron.core.tensor_parallel import VocabParallelEmbedding, reduce_scatter_to_sequence_parallel_region
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec
-from packaging import version
from mcore_bridge.config import ModelConfig
from mcore_bridge.utils import split_cp_inputs
from .gpt_model import GPTModel
-mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0')
-
class MultimodalGPTModel(MegatronModule):
@@ -60,9 +57,8 @@ def forward(_self, input_):
res = split_cp_inputs(res, getattr(packed_seq_params, 'cu_seqlens_q', None), 1)
if reduce_scatter_embeddings:
res = res.transpose(0, 1).contiguous()
- group_kwargs = {'group': _self.tp_group} if mcore_013 else {}
- res = reduce_scatter_to_sequence_parallel_region(res, **
- group_kwargs) / self.config.tensor_model_parallel_size
+ res = reduce_scatter_to_sequence_parallel_region(
+ res, group=_self.tp_group) / self.config.tensor_model_parallel_size
return res
VocabParallelEmbedding.forward = forward
diff --git a/src/mcore_bridge/model/mm_gpts/llama4.py b/src/mcore_bridge/model/mm_gpts/llama4.py
index b48cd7d..7c508e2 100644
--- a/src/mcore_bridge/model/mm_gpts/llama4.py
+++ b/src/mcore_bridge/model/mm_gpts/llama4.py
@@ -1,10 +1,8 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
-import megatron.core
import torch
from copy import deepcopy
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.transformer_layer import get_transformer_layer_offset
-from packaging import version
from transformers import PretrainedConfig
from typing import Optional
@@ -14,8 +12,6 @@
from ..register import ModelLoader, ModelMeta, register_model
from .utils import HuggingFaceVit
-mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0')
-
class Llama4Vit(HuggingFaceVit):
module_mapping = {'multi_modal_projector': 'multi_modal_projector', 'vision_model': 'vision_model'}
diff --git a/src/mcore_bridge/model/register.py b/src/mcore_bridge/model/register.py
index f5113c6..0e67e90 100644
--- a/src/mcore_bridge/model/register.py
+++ b/src/mcore_bridge/model/register.py
@@ -71,7 +71,6 @@ class ModelLoader:
def __init__(self, config: ModelConfig):
from mcore_bridge.model import GPTModel, MultimodalGPTModel
- self.mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0')
self.config = config
if self.model_cls is None:
self.model_cls = MultimodalGPTModel if config.is_multimodal else GPTModel
@@ -92,9 +91,12 @@ def _replace_spec_dsa(self, layer_spec):
def get_transformer_layer_spec(self, vp_stage: Optional[int] = None):
if self.config.num_moe_experts:
- kwargs = {'qk_l2_norm': self.config.qk_l2_norm, 'vp_stage': vp_stage} if self.mcore_013 else {}
transformer_layer_spec = get_gpt_decoder_block_spec(
- self.config, use_transformer_engine=True, normalization=self.config.normalization, **kwargs)
+ self.config,
+ use_transformer_engine=True,
+ normalization=self.config.normalization,
+ qk_l2_norm=self.config.qk_l2_norm,
+ vp_stage=vp_stage)
if self.config.experimental_attention_variant == 'dsa':
for layer_spec in transformer_layer_spec.layer_specs:
self._replace_spec_dsa(layer_spec)
@@ -104,13 +106,12 @@ def get_transformer_layer_spec(self, vp_stage: Optional[int] = None):
def _get_transformer_layer_spec(self):
config = self.config
- kwargs = {'qk_l2_norm': config.qk_l2_norm} if self.mcore_013 else {}
transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
config.num_moe_experts,
config.moe_grouped_gemm,
config.qk_layernorm,
config.multi_latent_attention,
- **kwargs,
+ qk_l2_norm=config.qk_l2_norm,
)
return transformer_layer_spec
@@ -122,9 +123,8 @@ def get_mtp_block_spec(self, transformer_layer_spec, vp_stage: Optional[int] = N
transformer_layer_spec_for_mtp = self._get_transformer_layer_spec()
else:
transformer_layer_spec_for_mtp = transformer_layer_spec
- kwargs = {'vp_stage': vp_stage} if self.mcore_013 else {}
mtp_block_spec = get_gpt_mtp_block_spec(
- self.config, transformer_layer_spec_for_mtp, use_transformer_engine=True, **kwargs)
+ self.config, transformer_layer_spec_for_mtp, use_transformer_engine=True, vp_stage=vp_stage)
if mtp_block_spec is not None:
for layer_spec in mtp_block_spec.layer_specs:
layer_spec.module = MultiTokenPredictionLayer
diff --git a/src/mcore_bridge/patcher.py b/src/mcore_bridge/patcher.py
index 15280b5..fa859e0 100644
--- a/src/mcore_bridge/patcher.py
+++ b/src/mcore_bridge/patcher.py
@@ -1,4 +1,3 @@
-import megatron.core
import peft
import sys
import torch
@@ -25,7 +24,6 @@
from mcore_bridge.utils import get_logger, is_flash_attn_3_available
-mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0')
logger = get_logger()
@@ -103,12 +101,8 @@ def forward(
# Adjust key, value for inference
# ===================================================
# rotary_pos_emb = None
- if mcore_013:
- query, key, value, _, attn_mask_type, _ = self._adjust_key_value_for_inference(
- inference_context, query, key, value, rotary_pos_emb=None)
- else:
- query, key, value, _, attn_mask_type = self._adjust_key_value_for_inference(
- inference_context, query, key, value, rotary_pos_emb=None)
+ query, key, value, _, attn_mask_type, _ = self._adjust_key_value_for_inference(
+ inference_context, query, key, value, rotary_pos_emb=None)
# TODO: Currently, TE can only accept contiguous tensors for MLA
query = query.contiguous()
@@ -414,7 +408,6 @@ def sharded_state_dict(
def _patch_TransformerLayer():
- _origin_forward = TransformerLayer.forward
def forward(self, *_args, **kwargs):
"""
@@ -423,8 +416,6 @@ def forward(self, *_args, **kwargs):
This method calls the core computation of a transformer layer, including
self-attention, cross-attention (if applicable), and feed-forward operations.
"""
- if not mcore_013:
- return _origin_forward(self, *_args, **kwargs)
hidden_states, context = self._forward_attention(*_args, **kwargs)
mlp_padding_free = self.config.mlp_padding_free and 'attention_mask' in kwargs
mask = None
@@ -552,8 +543,6 @@ def _apply_rotary_pos_emb_thd(
use_batched_rope = (freqs.dim() >= 1 and freqs.shape[0] == cu_seqlens_for_batched[-1]).item()
if not use_batched_rope:
logger.warning_once('Using non-batched RoPE, which may affect performance.')
- if mcore_013:
- kwargs['cp_group'] = cp_group
return _origin_apply_rotary_pos_emb_thd(
t,
cu_seqlens,
@@ -561,6 +550,7 @@ def _apply_rotary_pos_emb_thd(
rotary_interleaved=rotary_interleaved,
multi_latent_attention=multi_latent_attention,
mscale=mscale,
+ cp_group=cp_group,
**kwargs,
)
diff --git a/src/mcore_bridge/tuners/lora.py b/src/mcore_bridge/tuners/lora.py
index 2cc0f16..eab70e9 100644
--- a/src/mcore_bridge/tuners/lora.py
+++ b/src/mcore_bridge/tuners/lora.py
@@ -29,7 +29,6 @@
from .utils import tuners_sharded_state_dict
-mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0')
mcore_016 = version.parse(megatron.core.__version__) >= version.parse('0.16.0rc0')
MINDSPEED_015 = version.parse('0.15.0')
@@ -163,7 +162,7 @@ def update_layer(self, adapter_name, r, *, lora_alpha, lora_dropout, init_lora_w
'config': self.config,
'is_expert': self.is_expert,
}
- if mcore_013 and not (mcore_016 and self.is_grouped):
+ if not (mcore_016 and self.is_grouped):
tp_group = _get_tensor_parallel_group_for_lora(self.base_layer)
if tp_group is not None:
kwargs['tp_group'] = tp_group
diff --git a/src/mcore_bridge/utils/megatron_utils.py b/src/mcore_bridge/utils/megatron_utils.py
index 031b4e5..9fb20a0 100644
--- a/src/mcore_bridge/utils/megatron_utils.py
+++ b/src/mcore_bridge/utils/megatron_utils.py
@@ -1,13 +1,11 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
# code borrowed from modelscope/ms-swift
-import megatron.core
import torch
from megatron.core import mpu, tensor_parallel
from megatron.core.distributed import DistributedDataParallel as DDP
from megatron.core.transformer.module import Float16Module
from megatron.core.transformer.transformer_block import get_num_layers_to_build
from megatron.core.transformer.transformer_layer import get_transformer_layer_offset
-from packaging import version
from transformers import set_seed
from typing import Optional
@@ -15,8 +13,6 @@
logger = get_logger()
-mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0')
-
def unwrap_model(models, module_instances=None):
"""Unwrap_model to return the final model instance"""
@@ -67,17 +63,16 @@ def split_cp_inputs(inputs: torch.Tensor, cu_seqlens: Optional[torch.Tensor], di
def get_local_layer_specs(config, layer_specs, vp_stage=None):
- kwargs = {'vp_stage': vp_stage} if mcore_013 else {}
- num_layers_to_build = get_num_layers_to_build(config, **kwargs)
+ num_layers_to_build = get_num_layers_to_build(config, vp_stage=vp_stage)
if getattr(config, 'pipeline_model_parallel_layout', None) is not None:
from megatron.core.transformer.enums import LayerType
local_layer_specs = [
layer_specs[layer_id] for layer_id in config.pipeline_model_parallel_layout.get_layer_id_list(
- layer_type=LayerType.decoder, **kwargs)
+ layer_type=LayerType.decoder, vp_stage=vp_stage)
]
else:
- offset = get_transformer_layer_offset(config, **kwargs)
+ offset = get_transformer_layer_offset(config, vp_stage=vp_stage)
local_layer_specs = layer_specs[offset:offset + num_layers_to_build]
return local_layer_specs