diff --git a/src/mcore_bridge/model/gpts/glm4.py b/src/mcore_bridge/model/gpts/glm4.py
index 5c7a8ca..861f94d 100644
--- a/src/mcore_bridge/model/gpts/glm4.py
+++ b/src/mcore_bridge/model/gpts/glm4.py
@@ -91,11 +91,11 @@ def _set_layer_attn(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: boo
 class Glm4Loader(ModelLoader):
 
     def get_transformer_layer_spec(self, vp_stage: Optional[int] = None):
-        layer_spec = self._get_transformer_layer_spec()
-        layer_spec.submodules.self_attention.module = Glm4SelfAttention
-        layer_spec.submodules.mlp.module = Glm4MLP
-        transformer_layer.MLP = Glm4MLP  # patch
-        return layer_spec
+        transformer_layer_spec = super().get_transformer_layer_spec(vp_stage)
+        for layer_spec in transformer_layer_spec.layer_specs:
+            layer_spec.submodules.self_attention.module = Glm4SelfAttention
+            layer_spec.submodules.mlp.module = Glm4MLP
+        return transformer_layer_spec
 
 
 register_model(ModelMeta(
diff --git a/src/mcore_bridge/model/gpts/minimax_m2.py b/src/mcore_bridge/model/gpts/minimax_m2.py
index 81b11b6..c03f803 100644
--- a/src/mcore_bridge/model/gpts/minimax_m2.py
+++ b/src/mcore_bridge/model/gpts/minimax_m2.py
@@ -95,9 +95,10 @@ def _set_moe_state(
 class MinimaxM2Loader(ModelLoader):
 
     def get_transformer_layer_spec(self, vp_stage: Optional[int] = None):
-        layer_spec = self._get_transformer_layer_spec()
-        layer_spec.submodules.self_attention.module = MinimaxM2SelfAttention
-        return layer_spec
+        transformer_layer_spec = super().get_transformer_layer_spec(vp_stage)
+        for layer_spec in transformer_layer_spec.layer_specs:
+            layer_spec.submodules.self_attention.module = MinimaxM2SelfAttention
+        return transformer_layer_spec
 
 
 register_model(ModelMeta(
diff --git a/src/mcore_bridge/model/mm_gpt_model.py b/src/mcore_bridge/model/mm_gpt_model.py
index 7c36a19..515366f 100644
--- a/src/mcore_bridge/model/mm_gpt_model.py
+++ b/src/mcore_bridge/model/mm_gpt_model.py
@@ -1,5 +1,4 @@
 # Copyright (c) ModelScope Contributors. All rights reserved.
-import megatron.core
 import torch
 from contextlib import contextmanager
 from megatron.core import InferenceParams
@@ -15,6 +14,7 @@ class MultimodalGPTModel(MegatronModule):
+    language_model_cls = GPTModel
 
     def __init__(self,
                  config: ModelConfig,
@@ -26,7 +26,8 @@ def __init__(self,
         super().__init__(config)
         self.pre_process = pre_process
         self.post_process = post_process
-        self.language_model = GPTModel(config, transformer_layer_spec, pre_process, post_process, *_args, **kwargs)
+        self.language_model = self.language_model_cls(config, transformer_layer_spec, pre_process, post_process, *_args,
+                                                      **kwargs)
         self.vp_stage = self.language_model.vp_stage
         self.share_embeddings_and_output_weights = self.language_model.share_embeddings_and_output_weights
         self.model_meta = config.model_meta
diff --git a/src/mcore_bridge/model/modules/__init__.py b/src/mcore_bridge/model/modules/__init__.py
index 6fd1ac7..eff1bd6 100644
--- a/src/mcore_bridge/model/modules/__init__.py
+++ b/src/mcore_bridge/model/modules/__init__.py
@@ -2,3 +2,4 @@
 from .gated_delta_net import GatedDeltaNet
 from .gated_self_attention import GatedSelfAttention
 from .mtp_layer import MultiTokenPredictionLayer
+from .transformer_layer import CustomTransformerLayer
diff --git a/src/mcore_bridge/model/modules/transformer_layer.py b/src/mcore_bridge/model/modules/transformer_layer.py
new file mode 100644
index 0000000..024135e
--- /dev/null
+++ b/src/mcore_bridge/model/modules/transformer_layer.py
@@ -0,0 +1,271 @@
+import enum
+import inspect
+import torch
+from megatron.core.process_groups_config import ProcessGroupCollection
+from megatron.core.tensor_parallel.mappings import (gather_from_sequence_parallel_region,
+                                                    scatter_to_sequence_parallel_region)
+from megatron.core.transformer.identity_op import IdentityOp
+from megatron.core.transformer.mlp import MLP
+from megatron.core.transformer.spec_utils import ModuleSpec, build_module
+from megatron.core.transformer.transformer_config import TransformerConfig
+from megatron.core.transformer.transformer_layer import (TransformerLayer, TransformerLayerSubmodules,
+                                                         get_transformer_layer_offset)
+from megatron.core.utils import get_pg_rank
+from typing import Optional
+
+from mcore_bridge.utils import get_logger
+
+try:
+    from megatron.core.transformer.enums import CudaGraphScope
+except ImportError:
+
+    class CudaGraphScope(enum.Enum):
+        """Cuda Graph Scope - defines which parts of the model to capture."""
+
+        full_iteration = 1  # Captures the entire training/inference iteration
+        attn = 2  # Captures attention layers
+        mlp = 3  # Captures MLP layers (dense layers only)
+        moe = 4  # Captures MoE layers (drop-and-pad MoE layers only)
+        moe_router = 5  # Captures MoE router part
+        moe_preprocess = 6  # Captures MoE preprocessing part (requires moe_router)
+        mamba = 7  # Captures Mamba layers
+
+
+logger = get_logger()
+
+
+class CustomTransformerLayer(TransformerLayer):
+
+    def __init__(
+        self,
+        config: TransformerConfig,
+        submodules: TransformerLayerSubmodules,
+        layer_number: int = 1,
+        hidden_dropout: Optional[float] = None,
+        pg_collection: Optional[ProcessGroupCollection] = None,
+        vp_stage: Optional[int] = None,
+        is_mtp_layer: bool = False,
+        add_layer_offset: bool = True,
+        pp_layer_offset: Optional[int] = None,
+    ):
+        self.submodules_config = submodules
+        super(TransformerLayer, self).__init__(config=config, vp_stage=vp_stage)
+
+        if pg_collection is None:
+            pg_collection = ProcessGroupCollection.use_mpu_process_groups()
+        self.pg_collection = pg_collection
+        self.tp_group = pg_collection.tp
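+
+        # Note on numbering: layer_number is local to this pipeline stage (1-indexed);
+        # get_transformer_layer_offset converts it to a global index. For example, with
+        # 32 layers and pp_size=4 (assuming a uniform layer split), pp_rank 1 has an
+        # offset of 8, so its local layer 1 becomes global layer 9.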
+
+        # MTP inner layers use their own layer numbering (starting from 1 within each MTP depth),
+        # so they should NOT add the decoder layer offset. router.py handles MTP layer
+        # numbering separately by adding config.num_layers to distinguish MTP layers from decoder
+        # layers in the aux loss tracker.
+        #
+        # When add_layer_offset is False, the caller has already included the correct offset
+        # in layer_number (e.g. when using --hybrid-layer-pattern with fVPP).
+        if is_mtp_layer or not add_layer_offset:
+            self.layer_number = layer_number
+        else:
+            self.layer_number = layer_number + get_transformer_layer_offset(self.config, vp_stage,
+                                                                            get_pg_rank(pg_collection.pp))
+        self.hidden_dropout = config.hidden_dropout if hidden_dropout is None else hidden_dropout
+        self.is_mtp_layer = is_mtp_layer
+
+        # [Module 1: Input Layernorm] Optional Layernorm on the input data
+        # TODO: add pytorch only layernorm
+        self.input_layernorm = submodules.input_layernorm(
+            config=self.config,
+            hidden_size=self.config.hidden_size,
+            eps=self.config.layernorm_epsilon,
+        )
+
+        attention_optional_kwargs = {}
+        if config.context_parallel_size > 1 and config.cp_comm_type is not None:
+            if isinstance(config.cp_comm_type, list):
+                # layer_number is 1-indexed, so we need to subtract 1 to get the correct index
+                attention_optional_kwargs['cp_comm_type'] = config.cp_comm_type[self.layer_number - 1]
+            else:
+                attention_optional_kwargs['cp_comm_type'] = config.cp_comm_type
+
+        attention_optional_kwargs['pg_collection'] = pg_collection
+        if pp_layer_offset is not None:
+            attention_optional_kwargs['pp_layer_offset'] = pp_layer_offset
+
+        # [Module 2: SelfAttention]
+        self.self_attention = build_module(
+            submodules.self_attention,
+            config=self.config,
+            layer_number=self.layer_number,
+            **attention_optional_kwargs,
+        )
+
+        # [Module 3: BiasDropoutFusion]
+        self.self_attn_bda = build_module(submodules.self_attn_bda)
+
+        # [Module 4: Post SelfAttention] Optional Layernorm after self-attn
+        self.pre_cross_attn_layernorm = submodules.pre_cross_attn_layernorm(
+            config=self.config,
+            hidden_size=self.config.hidden_size,
+            eps=self.config.layernorm_epsilon,
+        )
+
+        # [Module 5: CrossAttention]
+        self.cross_attention = build_module(
+            submodules.cross_attention,
+            config=self.config,
+            layer_number=self.layer_number,
+            **attention_optional_kwargs,
+        )
+
+        # [Module 6: BiasDropoutFusion]
+        self.cross_attn_bda = build_module(submodules.cross_attn_bda, config=self.config)
+
+        # [Module 7: Pre MLP] Optional Layernorm before MLP
+        self.pre_mlp_layernorm = submodules.pre_mlp_layernorm(
+            config=self.config,
+            hidden_size=self.config.hidden_size,
+            eps=self.config.layernorm_epsilon,
+        )
+        # [Module 8: MLP block]
+        additional_mlp_kwargs = {}
+        # import here to avoid circular import
+        from megatron.core.extensions.transformer_engine import TEFusedMLP
+        from megatron.core.transformer.moe.experts import SequentialMLP, TEGroupedMLP
+        from megatron.core.transformer.moe.moe_layer import MoELayer
+
+        from mcore_bridge.model.gpts.glm4 import Glm4MLP
+
+        # MLP expects tp_group but MoELayer expects pg_collection to be passed in.
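+        # (For reference, the dispatch below: MoELayer/TEGroupedMLP/SequentialMLP take the
+        # full pg_collection; MLP/Glm4MLP/TEFusedMLP take only pg_collection.tp; any other
+        # spec falls through with default kwargs and a one-time warning.)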
+        # We can change MLP to accept pg_collection, but that would make the logic implicit.
+        # The conditional below makes the logic explicit.
+        # If submodules.mlp is not a ModuleSpec, we don't have to handle passing additional kwargs.
+        if isinstance(submodules.mlp, ModuleSpec):
+            if submodules.mlp.module in (MoELayer, TEGroupedMLP, SequentialMLP):
+                additional_mlp_kwargs['pg_collection'] = pg_collection
+                # Pass is_mtp_layer flag to MoELayer to distinguish MTP MoE layers.
+                if submodules.mlp.module == MoELayer and 'is_mtp_layer' in inspect.signature(MoELayer).parameters:
+                    additional_mlp_kwargs['is_mtp_layer'] = self.is_mtp_layer
+            elif submodules.mlp.module in (MLP, Glm4MLP):
+                assert hasattr(pg_collection, 'tp'), 'TP process group is required for MLP in TransformerLayer'
+                additional_mlp_kwargs['tp_group'] = pg_collection.tp
+            elif TEFusedMLP is not None and submodules.mlp.module == TEFusedMLP:
+                assert hasattr(pg_collection, 'tp'), 'TP process group is required for TEFusedMLP in TransformerLayer'
+                additional_mlp_kwargs['tp_group'] = pg_collection.tp
+            else:
+                logger.warning_once(f'Unknown MLP type: {submodules.mlp.module}. Using default kwargs.')
+        self.mlp = build_module(submodules.mlp, config=self.config, **additional_mlp_kwargs)
+        if hasattr(self.mlp, 'set_layer_number'):
+            self.mlp.set_layer_number(self.layer_number)
+
+        # [Module 9: BiasDropoutFusion]
+        self.mlp_bda = build_module(submodules.mlp_bda)
+
+        self.is_moe_layer = isinstance(self.mlp, MoELayer)
+
+        self.recompute_input_layernorm = False
+        self.recompute_pre_mlp_layernorm = False
+        self.recompute_mlp = False
+        if self.config.recompute_granularity == 'selective':
+            assert self.config.recompute_modules is not None
+            if 'layernorm' in self.config.recompute_modules:
+                if not isinstance(self.input_layernorm, IdentityOp):
+                    self.recompute_input_layernorm = True
+                    if self.config.fp8 or self.config.fp4:
+                        self.self_attention.set_for_recompute_input_layernorm()
+
+                def can_recompute_pre_mlp_layernorm_for_cudagraph():
+                    if (not self.is_moe_layer or CudaGraphScope.moe_router not in self.config.cuda_graph_scope
+                            or self.config.cuda_graph_impl == 'local'):
+                        # Not a MoE layer, or not capturing the router part.
+                        return True
+                    if (self.config.moe_shared_expert_intermediate_size is not None
+                            and self.config.moe_shared_expert_overlap):
+                        # If shared expert overlap is used, we cannot recompute the pre-mlp
+                        # layernorm, because the shared expert takes the layernorm output as
+                        # input, and it is outside of the CUDA graph scope.
+                        logger.warning(
+                            'pre_mlp_layernorm recompute is not supported with moe router '
+                            'cudagraph + shared expert overlap. Disabling pre_mlp_layernorm '
+                            'recompute.', )
+                        return False
+                    if CudaGraphScope.moe_preprocess in self.config.cuda_graph_scope and (
+                            self.config.moe_token_dispatcher_type == 'alltoall' or self.config.moe_latent_size):
+                        # Only when capturing the preprocess part and using the alltoall token
+                        # dispatcher or latent MoE can we recompute the pre-mlp layernorm.
+                        # In the other cases the layernorm output is returned directly as one
+                        # of the outputs of the cudagraph, which is allocated a static buffer
+                        # and thus cannot be released.
+                        return True
+                    logger.warning(
+                        'pre_mlp_layernorm recompute is only supported with moe router + '
+                        'preprocess cudagraph with alltoall token dispatcher or latent MoE. '
+                        'Disabling pre_mlp_layernorm recompute.', )
+                    return False
+
+                if (not isinstance(self.pre_mlp_layernorm, IdentityOp)
+                        and can_recompute_pre_mlp_layernorm_for_cudagraph()):
+                    self.recompute_pre_mlp_layernorm = True
+                    if self.config.fp8 or self.config.fp4:
+                        if isinstance(self.mlp, MoELayer):
+                            self.mlp.set_for_recompute_pre_mlp_layernorm()
+                        else:
+                            from megatron.core.extensions.transformer_engine import set_save_original_input
+
+                            set_save_original_input(self.mlp.linear_fc1)
+            if 'mlp' in self.config.recompute_modules:
+                if not self.is_moe_layer:
+                    self.recompute_mlp = True
+        if hasattr(self.config, 'fine_grained_activation_offloading'):
+            self.offload_attn_norm = (
+                self.config.fine_grained_activation_offloading and 'attn_norm' in self.config.offload_modules
+                and not isinstance(self.input_layernorm, IdentityOp))
+            self.offload_mlp_norm = (
+                self.config.fine_grained_activation_offloading and 'mlp_norm' in self.config.offload_modules
+                and not isinstance(self.pre_mlp_layernorm, IdentityOp))
+
+        # @jcasper how should we handle nvfuser?
+        # Set bias+dropout+add fusion grad_enable execution handler.
+        # TORCH_MAJOR = int(torch.__version__.split('.')[0])
+        # TORCH_MINOR = int(torch.__version__.split('.')[1])
+        # use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)
+        # self.bias_dropout_add_exec_handler = nullcontext if use_nvfuser else torch.enable_grad
+        self.bias_dropout_add_exec_handler = torch.enable_grad
+
+    def forward(self, *args, **kwargs):
+        """
+        Perform a forward pass through the transformer layer.
+
+        This method calls the core computation of a transformer layer, including
+        self-attention, cross-attention (if applicable), and feed-forward operations.
+        """
+        hidden_states, context = self._forward_attention(*args, **kwargs)
+        # If padding_free is set, attention_mask does not exist.
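+        # Padding-free MLP flow: (1) under sequence parallelism, gather the full
+        # sequence so the mask can see every position; (2) drop positions whose
+        # attention-mask columns are fully masked (pure padding); (3) pad the packed
+        # tokens to a multiple of tensor_model_parallel_size and scatter them back;
+        # (4) after the MLP, reverse the process to restore the original layout.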
+        mlp_padding_free = self.config.mlp_padding_free and 'attention_mask' in kwargs
+        mask = None
+        enable_sp = self.config.sequence_parallel and self.config.tensor_model_parallel_size > 1
+        pad_size = 0
+        if mlp_padding_free and hidden_states.shape[1] > 1:
+            if enable_sp:
+                hidden_states = gather_from_sequence_parallel_region(hidden_states, tensor_parallel_output_grad=False)
+            mask = ((~kwargs['attention_mask']).sum(dim=(1, 2)) > 0).t()
+            hidden_states = hidden_states[mask][:, None]
+            if enable_sp:
+                tp_size = self.config.tensor_model_parallel_size
+                num_tokens = hidden_states.shape[0]
+                remainder = num_tokens % tp_size
+                if remainder != 0:
+                    pad_size = tp_size - remainder
+                    hidden_states = torch.nn.functional.pad(hidden_states, (0, 0, 0, 0, 0, pad_size))
+                hidden_states = scatter_to_sequence_parallel_region(hidden_states)
+        output = self._forward_mlp(hidden_states, kwargs.get('inference_context', None))
+        if mask is not None:
+            if enable_sp:
+                output = gather_from_sequence_parallel_region(output, tensor_parallel_output_grad=False)
+                if pad_size > 0:
+                    output = output[:-pad_size]
+            new_output = output.new_zeros((*mask.shape, output.shape[-1]))
+            new_output[mask] = output.squeeze(1)
+            output = new_output
+            if enable_sp:
+                output = scatter_to_sequence_parallel_region(output)
+        return output, context
diff --git a/src/mcore_bridge/model/register.py b/src/mcore_bridge/model/register.py
index 0e67e90..15b37fe 100644
--- a/src/mcore_bridge/model/register.py
+++ b/src/mcore_bridge/model/register.py
@@ -4,9 +4,7 @@
 from megatron.core import mpu
 from megatron.core.enums import ModelType
 from megatron.core.extensions.transformer_engine import TEGroupedLinear, TELayerNormColumnParallelLinear, TELinear
-from megatron.core.models.gpt.gpt_layer_specs import (get_gpt_decoder_block_spec,
-                                                      get_gpt_layer_with_transformer_engine_spec,
-                                                      get_gpt_mtp_block_spec)
+from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec, get_gpt_mtp_block_spec
 from packaging import version
 from torch import nn
 from typing import TYPE_CHECKING, List, Optional, Type, Union
@@ -15,7 +13,7 @@
 from mcore_bridge.config import ModelConfig
 from mcore_bridge.utils import get_logger
 
-from .modules import MultiTokenPredictionLayer
+from .modules import CustomTransformerLayer, MultiTokenPredictionLayer
 
 if TYPE_CHECKING:
     from .gpt_model import GPTModel
@@ -90,41 +88,20 @@ def _replace_spec_dsa(self, layer_spec):
         layer_spec.submodules.self_attention = dsa_spec
 
     def get_transformer_layer_spec(self, vp_stage: Optional[int] = None):
-        if self.config.num_moe_experts:
-            transformer_layer_spec = get_gpt_decoder_block_spec(
-                self.config,
-                use_transformer_engine=True,
-                normalization=self.config.normalization,
-                qk_l2_norm=self.config.qk_l2_norm,
-                vp_stage=vp_stage)
-            if self.config.experimental_attention_variant == 'dsa':
-                for layer_spec in transformer_layer_spec.layer_specs:
-                    self._replace_spec_dsa(layer_spec)
-        else:
-            transformer_layer_spec = self._get_transformer_layer_spec()
-        return transformer_layer_spec
-
-    def _get_transformer_layer_spec(self):
-        config = self.config
-        transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
-            config.num_moe_experts,
-            config.moe_grouped_gemm,
-            config.qk_layernorm,
-            config.multi_latent_attention,
-            qk_l2_norm=config.qk_l2_norm,
-        )
+        transformer_layer_spec = get_gpt_decoder_block_spec(
+            self.config,
+            use_transformer_engine=True,
+            normalization=self.config.normalization,
+            qk_l2_norm=self.config.qk_l2_norm,
+            vp_stage=vp_stage)
+        if self.config.experimental_attention_variant == 'dsa':
+            for layer_spec in transformer_layer_spec.layer_specs:
+                self._replace_spec_dsa(layer_spec)
         return transformer_layer_spec
 
     def get_mtp_block_spec(self, transformer_layer_spec, vp_stage: Optional[int] = None):
-        if hasattr(transformer_layer_spec, 'layer_specs') and len(transformer_layer_spec.layer_specs) == 0:
-            # Get the decoder layer spec explicitly if no decoder layer in the last stage,
-            # Only happens with block spec (TransformerBlockSubmodules) when using MoE.
-            # TODO: remove
-            transformer_layer_spec_for_mtp = self._get_transformer_layer_spec()
-        else:
-            transformer_layer_spec_for_mtp = transformer_layer_spec
         mtp_block_spec = get_gpt_mtp_block_spec(
-            self.config, transformer_layer_spec_for_mtp, use_transformer_engine=True, vp_stage=vp_stage)
+            self.config, transformer_layer_spec, use_transformer_engine=True, vp_stage=vp_stage)
         if mtp_block_spec is not None:
             for layer_spec in mtp_block_spec.layer_specs:
                 layer_spec.module = MultiTokenPredictionLayer
@@ -138,6 +115,10 @@ def _set_shared_expert_gate(self, transformer_layer_spec):
             if hasattr(layer_spec.submodules.mlp.submodules, 'shared_experts'):
                 layer_spec.submodules.mlp.submodules.shared_experts.params = {'gate': True}
 
+    def _set_custom_layer(self, transformer_layer_spec):
+        for layer_spec in transformer_layer_spec.layer_specs:
+            layer_spec.module = CustomTransformerLayer
+
     def build_model(
         self,
         pre_process=True,
@@ -146,6 +127,7 @@ def build_model(
     ) -> Union['GPTModel', 'MultimodalGPTModel']:
         transformer_layer_spec = self.get_transformer_layer_spec(vp_stage=vp_stage)
         self._set_shared_expert_gate(transformer_layer_spec)
+        self._set_custom_layer(transformer_layer_spec)
         mtp_block_spec = None
         if self.config.mtp_num_layers is not None:
             mtp_block_spec = self.get_mtp_block_spec(transformer_layer_spec, vp_stage=vp_stage)
diff --git a/src/mcore_bridge/model/rope.py b/src/mcore_bridge/model/rope.py
index 5cabe42..e7db3c3 100644
--- a/src/mcore_bridge/model/rope.py
+++ b/src/mcore_bridge/model/rope.py
@@ -106,12 +106,12 @@ def _get_rope_type(rope_scaling: Optional[Dict[str, Any]]):
     return rope_type
 
 
-def get_rope_inv_freq(config, seq_len=None):
+def get_rope_inv_freq(config, seq_len=None, **kwargs):
     from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
     ROPE_INIT_FUNCTIONS.update(EXTENDED_ROPE_INIT_FUNCTIONS)
     dummy_config = _get_dummy_config(config)
     rope_init_fn = ROPE_INIT_FUNCTIONS[_get_rope_type(config.rope_scaling)]
-    inv_freq, attention_scaling = rope_init_fn(dummy_config, 'cpu', seq_len=seq_len)
+    inv_freq, attention_scaling = rope_init_fn(dummy_config, 'cpu', seq_len=seq_len, **kwargs)
     if attention_scaling is None:
         attention_scaling = 1.
     return inv_freq, attention_scaling
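For reference, `get_rope_inv_freq` only forwards the extra keyword arguments to the HF rope-init function it looks up; with no scaling configured, that function reduces to the standard RoPE inverse-frequency formula. A minimal sketch of that default case (illustrative; `default_rope_inv_freq` is a hypothetical name, not part of the bridge):

```python
import torch


def default_rope_inv_freq(rope_theta: float, head_dim: int) -> torch.Tensor:
    """Standard RoPE inverse frequencies: inv_freq[i] = 1 / theta^(2i / head_dim)."""
    exponents = torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim
    return 1.0 / (rope_theta ** exponents)


# e.g. Llama-style defaults: theta=10000, head_dim=128 -> 64 frequencies
inv_freq = default_rope_inv_freq(10000.0, 128)
assert inv_freq.shape == (64,)
```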
diff --git a/src/mcore_bridge/patcher.py b/src/mcore_bridge/patcher.py
index fa859e0..79527ee 100644
--- a/src/mcore_bridge/patcher.py
+++ b/src/mcore_bridge/patcher.py
@@ -12,7 +12,6 @@
 from megatron.core.tensor_parallel.mappings import (gather_from_sequence_parallel_region,
                                                     gather_from_tensor_model_parallel_region,
                                                     scatter_to_sequence_parallel_region)
-from megatron.core.transformer import TransformerLayer
 from megatron.core.transformer.multi_latent_attention import MLASelfAttention, MultiLatentAttention
 from megatron.core.transformer.multi_token_prediction import MultiTokenPredictionBlock, get_mtp_layer_offset
 from megatron.core.utils import deprecate_inference_params
@@ -407,31 +406,6 @@ def sharded_state_dict(
     peft_module.OriginModulesToSaveWrapper = OriginModulesToSaveWrapper
 
 
-def _patch_TransformerLayer():
-
-    def forward(self, *_args, **kwargs):
-        """
-        Perform a forward pass through the transformer layer.
-
-        This method calls the core computation of a transformer layer, including
-        self-attention, cross-attention (if applicable), and feed-forward operations.
-        """
-        hidden_states, context = self._forward_attention(*_args, **kwargs)
-        mlp_padding_free = self.config.mlp_padding_free and 'attention_mask' in kwargs
-        mask = None
-        if mlp_padding_free and hidden_states.shape[1] > 1:
-            mask = ((~kwargs['attention_mask']).sum(dim=(1, 2)) > 0).t()
-            hidden_states = hidden_states[mask][:, None]
-        output = self._forward_mlp(hidden_states, kwargs.get('inference_context', None))
-        if mask is not None:
-            new_output = hidden_states.new_zeros((*mask.shape, output.shape[-1]))
-            new_output[mask] = output.squeeze(1)
-            output = new_output
-        return output, context
-
-    TransformerLayer.forward = forward
-
-
 def _patch_TELinear():
 
     def __repr__(self):
@@ -759,7 +733,6 @@ def apply_patch():
     # patch module
     _patch_mla_attention()
     _patch_TEGroupedLinear()
-    _patch_TransformerLayer()
     _patch_TELinear()
     _patch_mrope()
     _patch_mtp()
diff --git a/src/mcore_bridge/tuners/patcher.py b/src/mcore_bridge/tuners/patcher.py
index e715c35..f9cae8c 100644
--- a/src/mcore_bridge/tuners/patcher.py
+++ b/src/mcore_bridge/tuners/patcher.py
@@ -1,6 +1,4 @@
 # Copyright (c) ModelScope Contributors. All rights reserved.
-import copy
-from contextlib import contextmanager
 from megatron.core.extensions.transformer_engine import TEGroupedLinear, TELayerNormColumnParallelLinear, TELinear
 from megatron.core.transformer.module import MegatronModule
 from megatron.core.transformer.moe.router import TopKRouter
@@ -11,6 +9,8 @@
 from torch import nn
 from typing import Optional
 
+from mcore_bridge.utils import patch_deepcopy
+
 from .lora import LoraParallelLinear
 
 
@@ -37,39 +37,6 @@ def dispatch_megatron(
     model.dispatch_megatron = dispatch_megatron
 
 
-@contextmanager
-def _patch_deepcopy():
-    _origin_deepcopy = copy.deepcopy
-    copy_keys = ('tp_group', '_tp_group', 'config')
-
-    def new_deepcopy(x, *args, **kwargs):
-        if not isinstance(x, nn.Module):
-            return _origin_deepcopy(x, *args, **kwargs)
-
-        saved_values = {}
-        for key in copy_keys:
-            val = getattr(x, key, None)
-            if val is not None:
-                saved_values[key] = val
-                setattr(x, key, None)
-
-        try:
-            res = _origin_deepcopy(x, *args, **kwargs)
-        finally:
-            for key, value in saved_values.items():
-                setattr(x, key, value)
-
-        for key, value in saved_values.items():
-            setattr(res, key, value)
-        return res
-
-    copy.deepcopy = new_deepcopy
-    try:
-        yield
-    finally:
-        copy.deepcopy = _origin_deepcopy
-
-
 def _patch_lora_model():
     if hasattr(LoraModel, '_mcore_patched'):
         return
@@ -77,7 +44,7 @@ def _patch_lora_model():
     __origin_init__ = LoraModel.__init__
 
     def __new_init__(self, *args, **kwargs):
-        with _patch_deepcopy():
+        with patch_deepcopy():
             __origin_init__(self, *args, **kwargs)
         if not isinstance(self.model, MegatronModule):
             return
diff --git a/src/mcore_bridge/utils/__init__.py b/src/mcore_bridge/utils/__init__.py
index 34cdbd1..d4285be 100644
--- a/src/mcore_bridge/utils/__init__.py
+++ b/src/mcore_bridge/utils/__init__.py
@@ -6,4 +6,4 @@
 from .megatron_utils import get_local_layer_specs, set_random_seed, split_cp_inputs, unwrap_model
 from .safetensors import SafetensorLazyLoader, StreamingSafetensorSaver
 from .torch_utils import gc_collect, get_current_device, safe_ddp_context, to_device
-from .utils import deep_getattr, get_env_args, json_parse_to_dict
+from .utils import deep_getattr, get_env_args, json_parse_to_dict, patch_deepcopy
diff --git a/src/mcore_bridge/utils/utils.py b/src/mcore_bridge/utils/utils.py
index 6905c41..a7e525b 100644
--- a/src/mcore_bridge/utils/utils.py
+++ b/src/mcore_bridge/utils/utils.py
@@ -1,5 +1,8 @@
+import copy
 import json
 import os
+from contextlib import contextmanager
+from torch import nn
 from transformers.utils import strtobool
 from typing import Callable, Dict, Optional, TypeVar, Union
 
@@ -58,3 +61,36 @@ def deep_getattr(obj, attr: str, default=None):
         else:
             obj = getattr(obj, a, default)
     return obj
+
+
+@contextmanager
+def patch_deepcopy():
+    _origin_deepcopy = copy.deepcopy
+    copy_keys = ('tp_group', '_tp_group', 'config')
+
+    def new_deepcopy(x, *args, **kwargs):
+        if not isinstance(x, nn.Module):
+            return _origin_deepcopy(x, *args, **kwargs)
+
+        saved_values = {}
+        for key in copy_keys:
+            val = getattr(x, key, None)
+            if val is not None:
+                saved_values[key] = val
+                setattr(x, key, None)
+
+        try:
+            res = _origin_deepcopy(x, *args, **kwargs)
+        finally:
+            for key, value in saved_values.items():
+                setattr(x, key, value)
+
+        for key, value in saved_values.items():
+            setattr(res, key, value)
+        return res
+
+    copy.deepcopy = new_deepcopy
+    try:
+        yield
+    finally:
+        copy.deepcopy = _origin_deepcopy
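For context on the relocated helper: `_patch_lora_model` wraps `LoraModel.__init__` in `patch_deepcopy` so that any deep copies made during LoRA injection do not choke on attributes such as `tp_group`, which hold `torch.distributed` process-group handles that cannot be pickled. The context manager detaches those attributes for the copy and re-attaches the originals to both the source module and the copy. A minimal usage sketch, with `threading.Lock` as an arbitrary stand-in for a non-copyable handle and `ToyParallelLinear` as a hypothetical module:

```python
import copy
import threading

from torch import nn

from mcore_bridge.utils import patch_deepcopy


class ToyParallelLinear(nn.Linear):
    """Stand-in for a tensor-parallel linear that keeps a process-group handle."""

    def __init__(self):
        super().__init__(4, 4)
        self.tp_group = threading.Lock()  # not deep-copyable, like a real process group


layer = ToyParallelLinear()

# A bare copy.deepcopy(layer) would raise "cannot pickle '_thread.lock' object".
with patch_deepcopy():
    clone = copy.deepcopy(layer)

# The handle is shared, not copied; the weights are real copies.
assert clone.tp_group is layer.tp_group
assert clone.weight is not layer.weight
```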