diff --git a/nemo_rl/distributed/model_utils.py b/nemo_rl/distributed/model_utils.py index c4717a7bfc..a1ae4e2d70 100644 --- a/nemo_rl/distributed/model_utils.py +++ b/nemo_rl/distributed/model_utils.py @@ -12,15 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Optional import torch -from megatron.core.models.gpt import GPTModel -from megatron.core.parallel_state import ( - get_tensor_model_parallel_group, - get_tensor_model_parallel_rank, -) -from megatron.core.utils import deprecate_inference_params, get_pg_size from torch.distributed.tensor import DTensor, distribute_tensor from nemo_rl.algorithms.logits_sampling_utils import ( @@ -29,6 +23,11 @@ need_top_k_or_top_p_filtering, ) +if TYPE_CHECKING: + # megatron-core (optional "mcore" extra) is imported lazily below so this + # module imports without mcore installed. + from megatron.core.models.gpt import GPTModel + @torch.no_grad() def _compute_distributed_log_softmax( @@ -2044,6 +2043,8 @@ def backward( def patch_gpt_model_forward_for_linear_ce_fusion(*, chunk_size: int) -> None: + from megatron.core.models.gpt import GPTModel + if getattr(GPTModel, "_linear_ce_fusion_forward_patched", False): GPTModel._linear_ce_fusion_chunk_size = chunk_size return @@ -2054,7 +2055,7 @@ def patch_gpt_model_forward_for_linear_ce_fusion(*, chunk_size: int) -> None: def _gpt_forward_with_linear_ce_fusion( - self: GPTModel, + self: "GPTModel", input_ids: torch.Tensor, position_ids: torch.Tensor, attention_mask: torch.Tensor, @@ -2070,6 +2071,12 @@ def _gpt_forward_with_linear_ce_fusion( padding_mask: Optional[torch.Tensor] = None, return_logprobs_for_linear_ce_fusion: bool = False, ) -> torch.Tensor: + from megatron.core.parallel_state import ( + get_tensor_model_parallel_group, + get_tensor_model_parallel_rank, + ) + from megatron.core.utils import deprecate_inference_params, get_pg_size + if not return_logprobs_for_linear_ce_fusion: return self._original_forward_for_linear_ce_fusion( input_ids=input_ids,