Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions nemo_rl/distributed/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Optional
from typing import TYPE_CHECKING, Any, Optional

import torch
from megatron.core.models.gpt import GPTModel
from megatron.core.parallel_state import (
get_tensor_model_parallel_group,
get_tensor_model_parallel_rank,
)
from megatron.core.utils import deprecate_inference_params, get_pg_size
from torch.distributed.tensor import DTensor, distribute_tensor

from nemo_rl.algorithms.logits_sampling_utils import (
Expand All @@ -29,6 +23,11 @@
need_top_k_or_top_p_filtering,
)

if TYPE_CHECKING:
# megatron-core (optional "mcore" extra) is imported lazily below so this
# module imports without mcore installed.
from megatron.core.models.gpt import GPTModel


@torch.no_grad()
def _compute_distributed_log_softmax(
Expand Down Expand Up @@ -2044,6 +2043,8 @@ def backward(


def patch_gpt_model_forward_for_linear_ce_fusion(*, chunk_size: int) -> None:
from megatron.core.models.gpt import GPTModel

if getattr(GPTModel, "_linear_ce_fusion_forward_patched", False):
GPTModel._linear_ce_fusion_chunk_size = chunk_size
return
Expand All @@ -2054,7 +2055,7 @@ def patch_gpt_model_forward_for_linear_ce_fusion(*, chunk_size: int) -> None:


def _gpt_forward_with_linear_ce_fusion(
self: GPTModel,
self: "GPTModel",
input_ids: torch.Tensor,
position_ids: torch.Tensor,
attention_mask: torch.Tensor,
Expand All @@ -2070,6 +2071,12 @@ def _gpt_forward_with_linear_ce_fusion(
padding_mask: Optional[torch.Tensor] = None,
return_logprobs_for_linear_ce_fusion: bool = False,
) -> torch.Tensor:
from megatron.core.parallel_state import (
get_tensor_model_parallel_group,
get_tensor_model_parallel_rank,
)
from megatron.core.utils import deprecate_inference_params, get_pg_size

if not return_logprobs_for_linear_ce_fusion:
return self._original_forward_for_linear_ce_fusion(
input_ids=input_ids,
Expand Down
Loading