2 changes: 1 addition & 1 deletion src/mcore_bridge/model/gpt_model.py
@@ -477,7 +477,7 @@ def _postprocess(

if self.config.task_type in {'seq_cls', 'embedding'
} and self.config.sequence_parallel and self.config.tensor_model_parallel_size > 1:
- hidden_states = gather_from_sequence_parallel_region(hidden_states)
+ hidden_states = gather_from_sequence_parallel_region(hidden_states, tensor_parallel_output_grad=False)

if self.config.task_type == 'embedding':
logits = F.normalize(hidden_states, p=2, dim=-1)
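Note on the new tensor_parallel_output_grad=False argument (a sketch of how gather_from_sequence_parallel_region behaves in Megatron-Core as I understand it, not something this diff states): the forward pass gathers the sequence-parallel chunks either way, while the flag selects the backward. True reduce-scatters the incoming gradient, which is right when the downstream op is tensor-parallel and each rank holds only a partial gradient; False simply splits it, which is right when the downstream head runs replicated on every rank, as the seq_cls / embedding heads do after this gather, so reduce-scattering would over-count by tensor_model_parallel_size. A single-process simulation of the two backward conventions, with illustrative names:

import torch

def backward_reduce_scatter(grad_per_rank, rank, tp_size):
    # tensor_parallel_output_grad=True: per-rank gradients are partial, so sum
    # across ranks, then keep this rank's sequence chunk.
    return torch.stack(grad_per_rank).sum(dim=0).chunk(tp_size, dim=0)[rank]

def backward_split(grad_per_rank, rank, tp_size):
    # tensor_parallel_output_grad=False: the gradient is already complete and
    # replicated on every rank, so just keep this rank's sequence chunk.
    return grad_per_rank[rank].chunk(tp_size, dim=0)[rank]

tp_size = 2
full_grad = torch.randn(8, 4)                      # gradient from a replicated head
replicated = [full_grad.clone() for _ in range(tp_size)]
for rank in range(tp_size):
    # With replicated gradients, reduce-scatter over-counts by exactly tp_size.
    assert torch.allclose(backward_reduce_scatter(replicated, rank, tp_size),
                          tp_size * backward_split(replicated, rank, tp_size))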
12 changes: 6 additions & 6 deletions src/mcore_bridge/model/gpts/qwen3_next.py
@@ -9,8 +9,8 @@
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.parallel_state import get_tensor_model_parallel_rank
from megatron.core.tensor_parallel import (all_gather_last_dim_from_tensor_parallel_region,
- gather_from_sequence_parallel_region,
- reduce_scatter_to_sequence_parallel_region)
+ gather_from_sequence_parallel_region, scatter_to_sequence_parallel_region)
+ from megatron.core.tensor_parallel.random import get_cuda_rng_tracker
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.spec_utils import build_module
from megatron.core.transformer.transformer_block import TransformerBlockSubmodules
@@ -462,7 +462,7 @@ def __init__(self, config: ModelConfig, submodules: SelfAttentionSubmodules, lay
def forward(self, hidden_states: torch.Tensor, **kwargs):
config = self.config
if config.sequence_parallel and config.tensor_model_parallel_size > 1:
- hidden_states = gather_from_sequence_parallel_region(hidden_states)
+ hidden_states = gather_from_sequence_parallel_region(hidden_states, tensor_parallel_output_grad=False)
seq_len = hidden_states.shape[0]
packed_seq_params = kwargs.get('packed_seq_params')
thd_format = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd'
@@ -482,15 +482,15 @@ def forward(self, hidden_states: torch.Tensor, **kwargs):
else:
hidden_states = hidden_states.transpose(0, 1)
attention_mask = resolve_gdn_attention_mask(kwargs)
- res = super().forward(hidden_states=hidden_states, attention_mask=attention_mask)
+ with get_cuda_rng_tracker().fork('data-parallel-rng'):
+     res = super().forward(hidden_states=hidden_states, attention_mask=attention_mask)
if thd_format:
res = res[attention_mask][:, None]
res = torch.concat([res, res.new_zeros(seq_len - res.shape[0], 1, res.shape[2])])
else:
res = res.transpose(0, 1).contiguous()
if config.sequence_parallel and config.tensor_model_parallel_size > 1:
- # Quick fix for dropout issue, awaiting ms-swift 4.0 refactor
- res = reduce_scatter_to_sequence_parallel_region(res) / config.tensor_model_parallel_size
+ res = scatter_to_sequence_parallel_region(res)
return res, None


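The two changes in this forward go together (a sketch of the rationale, assuming Megatron's usual seeding, where the 'data-parallel-rng' state is identical on every tensor-parallel rank): running the gathered, full-sequence module under get_cuda_rng_tracker().fork('data-parallel-rng') gives all TP ranks the same dropout mask, so each rank computes the identical full-sequence output; keeping only this rank's sequence chunk via scatter_to_sequence_parallel_region is then exact, and the old reduce-scatter-then-divide-by-TP workaround, which averages mismatched dropout masks, is no longer needed. A single-process sketch of that argument, with illustrative names:

import torch
import torch.nn.functional as F

tp_size, seq, hidden = 2, 8, 4
x = torch.randn(seq, hidden)

def rank_output(x, seed):
    torch.manual_seed(seed)          # stands in for the rank's forked RNG state
    return F.dropout(x, p=0.5, training=True)

# Shared RNG state (the 'data-parallel-rng' fork): every rank produces the same
# full-sequence output, so scatter (keep your own chunk) reproduces it exactly.
same = [rank_output(x, seed=0) for _ in range(tp_size)]
assert torch.equal(same[0], same[1])

# Per-rank RNG states (the old situation): reduce_scatter / tp_size averages
# different dropout masks and matches no rank's actual output.
diff = [rank_output(x, seed=r) for r in range(tp_size)]
averaged = torch.stack(diff).sum(0) / tp_size
assert not torch.equal(averaged.chunk(tp_size, dim=0)[0], diff[0].chunk(tp_size, dim=0)[0])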
5 changes: 2 additions & 3 deletions src/mcore_bridge/model/mm_gpt_model.py
@@ -3,7 +3,7 @@
from contextlib import contextmanager
from megatron.core import InferenceParams
from megatron.core.packed_seq_params import PackedSeqParams
- from megatron.core.tensor_parallel import VocabParallelEmbedding, reduce_scatter_to_sequence_parallel_region
+ from megatron.core.tensor_parallel import VocabParallelEmbedding, scatter_to_sequence_parallel_region
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec

@@ -58,8 +58,7 @@ def forward(_self, input_):
res = split_cp_inputs(res, getattr(packed_seq_params, 'cu_seqlens_q', None), 1)
if reduce_scatter_embeddings:
res = res.transpose(0, 1).contiguous()
- res = reduce_scatter_to_sequence_parallel_region(
-     res, group=_self.tp_group) / self.config.tensor_model_parallel_size
+ res = scatter_to_sequence_parallel_region(res, group=_self.tp_group)
return res

VocabParallelEmbedding.forward = forward
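Same idea in the patched embedding forward: the old reduce_scatter divided by tensor_model_parallel_size only makes sense if res is already the full, identical embedding on every tensor-parallel rank (sum of tp_size identical copies, then divide back), and under that assumption scatter_to_sequence_parallel_region is the exact operation with no division hack. Roughly, the scatter's forward keeps this rank's slice of the sequence dimension and its backward is the matching gather; a conceptual, single-process sketch under that assumption, not Megatron's implementation:

import torch

def sp_scatter_forward(full: torch.Tensor, rank: int, tp_size: int) -> torch.Tensor:
    # Replicated [s, b, h] on every rank -> this rank's [s / tp_size, b, h] chunk.
    return full.chunk(tp_size, dim=0)[rank]

def sp_scatter_backward(grad_chunks: list) -> torch.Tensor:
    # Gather the per-rank sequence chunks back into a full [s, b, h] gradient.
    return torch.cat(grad_chunks, dim=0)

# Shape round-trip: the forward chunks reassemble to the original full tensor.
full = torch.randn(8, 2, 4)
chunks = [sp_scatter_forward(full, r, 2) for r in range(2)]
assert torch.equal(sp_scatter_backward(chunks), full)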
12 changes: 6 additions & 6 deletions src/mcore_bridge/model/mm_gpts/qwen3_5.py
@@ -2,8 +2,8 @@
import torch
from megatron.core.extensions.transformer_engine import _get_extra_te_kwargs
from megatron.core.models.huggingface import HuggingFaceModule as _HuggingFaceModule
- from megatron.core.tensor_parallel import (gather_from_sequence_parallel_region,
-     reduce_scatter_to_sequence_parallel_region)
+ from megatron.core.tensor_parallel import gather_from_sequence_parallel_region, scatter_to_sequence_parallel_region
+ from megatron.core.tensor_parallel.random import get_cuda_rng_tracker
from megatron.core.transformer.attention import SelfAttentionSubmodules
from megatron.core.transformer.transformer_config import TransformerConfig

@@ -33,7 +33,7 @@ def __init__(self, config: TransformerConfig, submodules: SelfAttentionSubmodule
def forward(self, hidden_states: torch.Tensor, **kwargs):
config = self.config
if config.sequence_parallel and config.tensor_model_parallel_size > 1:
- hidden_states = gather_from_sequence_parallel_region(hidden_states)
+ hidden_states = gather_from_sequence_parallel_region(hidden_states, tensor_parallel_output_grad=False)
seq_len = hidden_states.shape[0]
packed_seq_params = kwargs.get('packed_seq_params')
thd_format = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd'
@@ -53,15 +53,15 @@ def forward(self, hidden_states: torch.Tensor, **kwargs):
else:
hidden_states = hidden_states.transpose(0, 1)
attention_mask = resolve_gdn_attention_mask(kwargs)
- res = super().forward(hidden_states=hidden_states, attention_mask=attention_mask)
+ with get_cuda_rng_tracker().fork('data-parallel-rng'):
+     res = super().forward(hidden_states=hidden_states, attention_mask=attention_mask)
if thd_format:
res = res[attention_mask][:, None]
res = torch.concat([res, res.new_zeros(seq_len - res.shape[0], 1, res.shape[2])])
else:
res = res.transpose(0, 1).contiguous()
if config.sequence_parallel and config.tensor_model_parallel_size > 1:
- # Quick fix for dropout issue, awaiting ms-swift 4.0 refactor
- res = reduce_scatter_to_sequence_parallel_region(res) / config.tensor_model_parallel_size
+ res = scatter_to_sequence_parallel_region(res)
return res, None


2 changes: 1 addition & 1 deletion src/mcore_bridge/model/modules/mtp_layer.py
@@ -184,7 +184,7 @@ def _get_embeddings(
else:
enable_sp = self.config.sequence_parallel and self.config.tensor_model_parallel_size > 1
if enable_sp:
- decoder_input = gather_from_sequence_parallel_region(decoder_input)
+ decoder_input = gather_from_sequence_parallel_region(decoder_input, tensor_parallel_output_grad=False)
decoder_input, _ = roll_tensor(
decoder_input.transpose(0, 2),
shifts=-1,