diff --git a/src/mcore_bridge/model/gpt_model.py b/src/mcore_bridge/model/gpt_model.py
index 165ff0c..1903b7c 100644
--- a/src/mcore_bridge/model/gpt_model.py
+++ b/src/mcore_bridge/model/gpt_model.py
@@ -477,7 +477,7 @@ def _postprocess(
         if self.config.task_type in {'seq_cls', 'embedding'
                                      } and self.config.sequence_parallel and self.config.tensor_model_parallel_size > 1:
-            hidden_states = gather_from_sequence_parallel_region(hidden_states)
+            hidden_states = gather_from_sequence_parallel_region(hidden_states, tensor_parallel_output_grad=False)
         if self.config.task_type == 'embedding':
             logits = F.normalize(hidden_states, p=2, dim=-1)
diff --git a/src/mcore_bridge/model/gpts/qwen3_next.py b/src/mcore_bridge/model/gpts/qwen3_next.py
index 9b135cd..316a7f0 100644
--- a/src/mcore_bridge/model/gpts/qwen3_next.py
+++ b/src/mcore_bridge/model/gpts/qwen3_next.py
@@ -9,8 +9,8 @@
 from megatron.core.packed_seq_params import PackedSeqParams
 from megatron.core.parallel_state import get_tensor_model_parallel_rank
 from megatron.core.tensor_parallel import (all_gather_last_dim_from_tensor_parallel_region,
-                                           gather_from_sequence_parallel_region,
-                                           reduce_scatter_to_sequence_parallel_region)
+                                           gather_from_sequence_parallel_region, scatter_to_sequence_parallel_region)
+from megatron.core.tensor_parallel.random import get_cuda_rng_tracker
 from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
 from megatron.core.transformer.spec_utils import build_module
 from megatron.core.transformer.transformer_block import TransformerBlockSubmodules
@@ -462,7 +462,7 @@ def __init__(self, config: ModelConfig, submodules: SelfAttentionSubmodules, lay
     def forward(self, hidden_states: torch.Tensor, **kwargs):
         config = self.config
         if config.sequence_parallel and config.tensor_model_parallel_size > 1:
-            hidden_states = gather_from_sequence_parallel_region(hidden_states)
+            hidden_states = gather_from_sequence_parallel_region(hidden_states, tensor_parallel_output_grad=False)
         seq_len = hidden_states.shape[0]
         packed_seq_params = kwargs.get('packed_seq_params')
         thd_format = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd'
@@ -482,15 +482,15 @@ def forward(self, hidden_states: torch.Tensor, **kwargs):
         else:
             hidden_states = hidden_states.transpose(0, 1)
         attention_mask = resolve_gdn_attention_mask(kwargs)
-        res = super().forward(hidden_states=hidden_states, attention_mask=attention_mask)
+        with get_cuda_rng_tracker().fork('data-parallel-rng'):
+            res = super().forward(hidden_states=hidden_states, attention_mask=attention_mask)
         if thd_format:
             res = res[attention_mask][:, None]
             res = torch.concat([res, res.new_zeros(seq_len - res.shape[0], 1, res.shape[2])])
         else:
             res = res.transpose(0, 1).contiguous()
         if config.sequence_parallel and config.tensor_model_parallel_size > 1:
-            # Quick fix for dropout issue, awaiting ms-swift 4.0 refactor
-            res = reduce_scatter_to_sequence_parallel_region(res) / config.tensor_model_parallel_size
+            res = scatter_to_sequence_parallel_region(res)
         return res, None
diff --git a/src/mcore_bridge/model/mm_gpt_model.py b/src/mcore_bridge/model/mm_gpt_model.py
index 515366f..9d669e6 100644
--- a/src/mcore_bridge/model/mm_gpt_model.py
+++ b/src/mcore_bridge/model/mm_gpt_model.py
@@ -3,7 +3,7 @@
 from contextlib import contextmanager
 from megatron.core import InferenceParams
 from megatron.core.packed_seq_params import PackedSeqParams
-from megatron.core.tensor_parallel import VocabParallelEmbedding, reduce_scatter_to_sequence_parallel_region
+from megatron.core.tensor_parallel import VocabParallelEmbedding, scatter_to_sequence_parallel_region
 from megatron.core.transformer.module import MegatronModule
 from megatron.core.transformer.spec_utils import ModuleSpec
@@ -58,8 +58,7 @@ def forward(_self, input_):
                 res = split_cp_inputs(res, getattr(packed_seq_params, 'cu_seqlens_q', None), 1)
             if reduce_scatter_embeddings:
                 res = res.transpose(0, 1).contiguous()
-                res = reduce_scatter_to_sequence_parallel_region(
-                    res, group=_self.tp_group) / self.config.tensor_model_parallel_size
+                res = scatter_to_sequence_parallel_region(res, group=_self.tp_group)
             return res
 
         VocabParallelEmbedding.forward = forward
diff --git a/src/mcore_bridge/model/mm_gpts/qwen3_5.py b/src/mcore_bridge/model/mm_gpts/qwen3_5.py
index 801c55c..8ba5378 100644
--- a/src/mcore_bridge/model/mm_gpts/qwen3_5.py
+++ b/src/mcore_bridge/model/mm_gpts/qwen3_5.py
@@ -2,8 +2,8 @@
 import torch
 from megatron.core.extensions.transformer_engine import _get_extra_te_kwargs
 from megatron.core.models.huggingface import HuggingFaceModule as _HuggingFaceModule
-from megatron.core.tensor_parallel import (gather_from_sequence_parallel_region,
-                                           reduce_scatter_to_sequence_parallel_region)
+from megatron.core.tensor_parallel import gather_from_sequence_parallel_region, scatter_to_sequence_parallel_region
+from megatron.core.tensor_parallel.random import get_cuda_rng_tracker
 from megatron.core.transformer.attention import SelfAttentionSubmodules
 from megatron.core.transformer.transformer_config import TransformerConfig
@@ -33,7 +33,7 @@ def __init__(self, config: TransformerConfig, submodules: SelfAttentionSubmodule
     def forward(self, hidden_states: torch.Tensor, **kwargs):
         config = self.config
         if config.sequence_parallel and config.tensor_model_parallel_size > 1:
-            hidden_states = gather_from_sequence_parallel_region(hidden_states)
+            hidden_states = gather_from_sequence_parallel_region(hidden_states, tensor_parallel_output_grad=False)
         seq_len = hidden_states.shape[0]
         packed_seq_params = kwargs.get('packed_seq_params')
         thd_format = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd'
@@ -53,15 +53,15 @@ def forward(self, hidden_states: torch.Tensor, **kwargs):
         else:
             hidden_states = hidden_states.transpose(0, 1)
         attention_mask = resolve_gdn_attention_mask(kwargs)
-        res = super().forward(hidden_states=hidden_states, attention_mask=attention_mask)
+        with get_cuda_rng_tracker().fork('data-parallel-rng'):
+            res = super().forward(hidden_states=hidden_states, attention_mask=attention_mask)
         if thd_format:
             res = res[attention_mask][:, None]
             res = torch.concat([res, res.new_zeros(seq_len - res.shape[0], 1, res.shape[2])])
         else:
             res = res.transpose(0, 1).contiguous()
         if config.sequence_parallel and config.tensor_model_parallel_size > 1:
-            # Quick fix for dropout issue, awaiting ms-swift 4.0 refactor
-            res = reduce_scatter_to_sequence_parallel_region(res) / config.tensor_model_parallel_size
+            res = scatter_to_sequence_parallel_region(res)
         return res, None
diff --git a/src/mcore_bridge/model/modules/mtp_layer.py b/src/mcore_bridge/model/modules/mtp_layer.py
index c2d069b..fa89bf4 100644
--- a/src/mcore_bridge/model/modules/mtp_layer.py
+++ b/src/mcore_bridge/model/modules/mtp_layer.py
@@ -184,7 +184,7 @@ def _get_embeddings(
         else:
             enable_sp = self.config.sequence_parallel and self.config.tensor_model_parallel_size > 1
             if enable_sp:
-                decoder_input = gather_from_sequence_parallel_region(decoder_input)
+                decoder_input = gather_from_sequence_parallel_region(decoder_input, tensor_parallel_output_grad=False)
             decoder_input, _ = roll_tensor(
                 decoder_input.transpose(0, 2),
                 shifts=-1,
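
All five files apply the same change to the sequence-parallel path around modules that are not tensor-parallel (the gated-delta-net blocks, the HuggingFace module, the MTP embeddings): gather the full sequence with tensor_parallel_output_grad=False, run the module with dropout forked onto the data-parallel RNG state so every TP rank draws identical masks, and split the result back with scatter_to_sequence_parallel_region instead of the previous reduce_scatter_to_sequence_parallel_region(...) / tp_size workaround. The sketch below is illustrative only and is not part of the diff; replicated_sp_forward and inner are hypothetical names standing in for the patched modules, and config is assumed to carry the same sequence_parallel / tensor_model_parallel_size fields used above.

    # Illustrative sketch (not part of the diff) of the pattern the changes apply.
    import torch
    from megatron.core.tensor_parallel import (gather_from_sequence_parallel_region,
                                               scatter_to_sequence_parallel_region)
    from megatron.core.tensor_parallel.random import get_cuda_rng_tracker


    def replicated_sp_forward(inner, hidden_states: torch.Tensor, config) -> torch.Tensor:
        """Run a non-tensor-parallel `inner` module on the full sequence under sequence parallelism."""
        sp = config.sequence_parallel and config.tensor_model_parallel_size > 1
        if sp:
            # Forward: all-gather the per-rank sequence shards so every TP rank sees
            # the whole sequence. With tensor_parallel_output_grad=False the backward
            # pass is a plain split rather than a reduce-scatter, which matches a
            # consumer that is replicated (not tensor-parallel).
            hidden_states = gather_from_sequence_parallel_region(hidden_states, tensor_parallel_output_grad=False)
        # Dropout inside the replicated module must draw the same mask on every TP
        # rank; the 'data-parallel-rng' state is seeded identically across TP ranks.
        with get_cuda_rng_tracker().fork('data-parallel-rng'):
            out = inner(hidden_states)
        if sp:
            # Forward: keep only this rank's sequence shard (backward: all-gather).
            # Every rank already holds the identical full output, so no reduce-scatter
            # (and hence no division by the TP size) is needed.
            out = scatter_to_sequence_parallel_region(out)
        return out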