Conversation


@kevalmorabia97 kevalmorabia97 commented Jan 21, 2026

What does this PR do?

Type of change: new example

Adds Megatron-Bridge pruning example scripts (HF input, HF / Megatron output). Also defines utility functions we can reuse when adding examples for quantization or other optimizations (a minimal Python usage sketch follows the list):

  • modelopt.torch.utils.plugins.mbridge.load_mbridge_model_from_hf: Load HF to MBridge with ModelOpt spec in desired TP/PP/etc configuration
  • modelopt.torch.utils.plugins.mbridge.get_hf_mbridge_calibration_loop: Create forward_loop for calibration on a HF dataset
    • Supports all datasets available in modelopt.torch.utils.dataset_utils (cnn_dailymail, nemotron-post-training-dataset-v2, etc.)
    • Supports Micro Batch Size >= 1
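
A minimal Python sketch of how these helpers compose (illustrative only; the argument names and order below are assumptions based on this description, not the exact signatures; see modelopt/torch/utils/plugins/mbridge.py for the real API):

# Illustrative sketch only: argument names/order are assumptions, not the exact signatures.
from modelopt.torch.utils.plugins.mbridge import (
    get_hf_mbridge_calibration_loop,
    load_mbridge_model_from_hf,
)

# Load the HF checkpoint into a Megatron-Bridge model with the ModelOpt layer spec
# (returns the wrapped model list, the provider, and the unwrapped Megatron model).
model, provider, unwrapped_model = load_mbridge_model_from_hf("Qwen/Qwen3-8B")

# Build a forward_loop that runs calibration batches from a supported HF dataset
# (any dataset from modelopt.torch.utils.dataset_utils; micro batch size >= 1).
forward_loop = get_hf_mbridge_calibration_loop(
    model,
    dataset_name="cnn_dailymail",
    num_samples=512,
    micro_batch_size=1,
)

# forward_loop is then handed to ModelOpt pruning; see prune_minitron.py for the
# actual mtp.prune(...) call and Minitron configuration.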

Usage

Run from the nvcr.io/nvidian/nemo:26.02.rc1 container (with the latest code mounted at /opt/Megatron-Bridge and /opt/Model-Optimizer):

torchrun --nproc_per_node 2 /opt/Model-Optimizer/examples/megatron_bridge/prune_minitron.py \
    --hf_model_name_or_path Qwen/Qwen3-8B \
    --prune_target_params 6e9 \
    --hparams_to_skip num_attention_heads \
    --output_hf_path /tmp/Qwen3-8B-Pruned-6B

Testing

  • Manually ran the pruning script in the nemo:25.11 container (with the latest modelopt and mbridge mounted) for Qwen3-8B and Nemotron-Nano-9B-v2 with PP=8 and PP=4
  • Added per-PR CI/CD test for example script

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes
  • Did you write any new necessary tests?: Yes
  • Did you add or update any necessary documentation?: ‼️ TODO
  • Did you update Changelog?: Yes

Additional Information

Summary by CodeRabbit

Release Notes

New Features

  • Added new Megatron-Bridge pruning example demonstrating Minitron-based model optimization with advanced pruning configurations.

Documentation

  • Updated core project documentation to highlight Megatron-Bridge as a supported optimization framework.
  • Added comprehensive example documentation for Megatron-Bridge workflows including pruning, distillation, and quantization.
  • Updated pruning guides with Megatron-Bridge integration examples and best practices.



copy-pr-bot bot commented Jan 21, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.


coderabbitai bot commented Jan 21, 2026

📝 Walkthrough

Walkthrough

This PR introduces Megatron-Bridge support to ModelOpt with a new pruning example using Minitron on Qwen3-8B. Changes include new utility functions for loading HF models via Megatron-Bridge and constructing calibration loops, documentation updates, and a plugins-system refactor that prevents auto-registration of the bridge plugin.

Changes

  • Configuration & Metadata (.github/CODEOWNERS, CHANGELOG.rst, README.md): Added a CODEOWNERS entry for the megatron_bridge examples, updated the CHANGELOG with the new pruning example, and replaced NVIDIA NeMo with NVIDIA Megatron-Bridge in the integration-target documentation.
  • New Megatron-Bridge Example (examples/megatron_bridge/README.md, examples/megatron_bridge/prune_minitron.py): Introduced a new example directory with pruning documentation and a CLI script orchestrating NAS-based pruning of Megatron-Bridge models using the Minitron algorithm. Supports Qwen3-8B to 6B reduction with calibration dataset configuration, pruning modes (target params or export config), and dual-format model saving (Megatron or HF checkpoint).
  • Documentation Updates (examples/pruning/README.md): Updated terminology from Megatron-LM/NeMo to Megatron-LM (M-LM) and Megatron-Bridge (M-Bridge), replaced the NeMo container reference (25.11 → 26.02), rewrote code examples to use the Megatron-Bridge model loading utilities, and adjusted model applicability notes for pipeline parallelism.
  • Core Megatron-Bridge Utilities (modelopt/torch/utils/plugins/mbridge.py): New module providing load_mbridge_model_from_hf() for instantiating Megatron-Bridge models from HF checkpoints with provider customization, get_hf_mbridge_calibration_loop() for constructing ModelOpt calibration loops, and an internal _get_dataset_cfg() helper for dataset preparation.
  • Utility Modifications (modelopt/torch/utils/dataset_utils.py): Exposed get_dataset_samples() as a public function (previously the private _get_dataset_samples) and added it to the __all__ export list (a usage sketch follows this list).
  • Distributed Utility Updates (modelopt/torch/utils/distributed.py): Modified local_rank() to fall back to the global rank with a warning instead of raising an error when the LOCAL_RANK environment variable is missing.
  • Plugin System Refactoring (modelopt/torch/utils/plugins/__init__.py): Commented out the auto-import of the megatron bridge plugin at module initialization while retaining the other plugin imports (megatron_generate, megatron_mmlu, megatron_preprocess_data).
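
As referenced above, a minimal sketch of the newly public helper (the dataset name and sample count are illustrative; per the review discussion it returns a list of raw text strings):

from modelopt.torch.utils.dataset_utils import get_dataset_samples

# Returns a list[str] of raw text samples from a supported calibration dataset.
samples = get_dataset_samples("cnn_dailymail", 128)
print(len(samples), samples[0][:80])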

Sequence Diagram(s)

sequenceDiagram
    participant CLI as prune_minitron.py
    participant Bridge as Megatron-Bridge
    participant Dataset as Dataset/Calibration
    participant Pruning as ModelOpt Pruning
    participant Output as Output Format

    CLI->>CLI: Parse arguments & validate config
    CLI->>Bridge: load_mbridge_model_from_hf()
    Bridge-->>CLI: Return model, provider, unwrapped_model
    CLI->>Dataset: get_hf_mbridge_calibration_loop()
    Dataset-->>CLI: Return calibration loop closure
    CLI->>Pruning: Build NAS search space & config
    CLI->>Pruning: Execute pruning with forward_loop
    Pruning-->>CLI: Return pruned model
    CLI->>Output: Save to Megatron or HF format
    Output-->>CLI: Model checkpoint written

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)
  • Docstring Coverage (⚠️ Warning): Docstring coverage is 58.33%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)
  • Description Check (✅ Passed): Check skipped because CodeRabbit's high-level summary is enabled.
  • Title check (✅ Passed): The title 'Add Megatron-Bridge pruning example scripts' accurately captures the main change: new example scripts for Megatron-Bridge pruning with supporting utility functions, directly reflecting the PR's primary objective.



@kevalmorabia97 kevalmorabia97 force-pushed the kmorabia/mbridge-pruning branch from ae6a842 to dc1cadc Compare January 21, 2026 09:58
@kevalmorabia97 kevalmorabia97 changed the base branch from main to kmorabia/minitron-auto January 21, 2026 09:59

codecov bot commented Jan 21, 2026

Codecov Report

❌ Patch coverage is 40.00000% with 3 lines in your changes missing coverage. Please review.
✅ Project coverage is 74.18%. Comparing base (945ee02) to head (44920ad).

Files with missing lines:
  • modelopt/torch/utils/distributed.py: Patch 33.33%, 2 lines missing ⚠️
  • modelopt/torch/utils/dataset_utils.py: Patch 50.00%, 1 line missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #800      +/-   ##
==========================================
- Coverage   74.18%   74.18%   -0.01%     
==========================================
  Files         192      192              
  Lines       19236    19239       +3     
==========================================
+ Hits        14271    14272       +1     
- Misses       4965     4967       +2     

☔ View full report in Codecov by Sentry.


@kevalmorabia97 kevalmorabia97 force-pushed the kmorabia/mbridge-pruning branch 2 times, most recently from 2281a23 to 9c79afd Compare January 21, 2026 12:37
@kevalmorabia97 kevalmorabia97 force-pushed the kmorabia/mbridge-pruning branch from 9c79afd to a65050a Compare January 21, 2026 20:36
Base automatically changed from kmorabia/minitron-auto to main January 21, 2026 22:34
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
@kevalmorabia97 kevalmorabia97 force-pushed the kmorabia/mbridge-pruning branch from a65050a to 44920ad Compare January 22, 2026 10:40
@kevalmorabia97 kevalmorabia97 marked this pull request as ready for review January 22, 2026 10:43
@kevalmorabia97 kevalmorabia97 requested review from a team as code owners January 22, 2026 10:43

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 5

🤖 Fix all issues with AI agents
In `@examples/megatron_bridge/prune_minitron.py`:
- Around line 205-210: The existence checks for output files should be guarded
by whether the corresponding output path args are set: before calling
os.path.exists with args.output_megatron_path or args.output_hf_path, check that
args.output_megatron_path and args.output_hf_path are truthy respectively;
update the block around the os.path.exists checks (the conditions using
args.output_megatron_path and args.output_hf_path) so you only call
os.path.exists when the arg is not None/empty, and keep the existing
warn_rank_0(...) and return behavior unchanged.
- Around line 167-175: The defaulting logic for
args.prune_intermediate_checkpoint can point to a file under
args.output_megatron_path or args.output_hf_path which may not exist; before
calling mtp.prune (or any operation that writes that checkpoint) create the
parent directory for args.prune_intermediate_checkpoint using os.path.dirname
and os.makedirs(..., exist_ok=True). Ensure the directory creation happens right
after the block that sets args.prune_intermediate_checkpoint and before any
prune/save calls so mtp.prune won't fail due to a missing directory.

In `@examples/megatron_bridge/README.md`:
- Around line 59-62: Remove the duplicate "## Resources" heading: locate the
repeated heading string "## Resources" in the README (the two identical headings
shown in the diff) and delete the redundant one so there is only a single "##
Resources" section header remaining.

In `@modelopt/torch/utils/plugins/mbridge.py`:
- Around line 162-163: The code sets global_batch_size = micro_batch_size and
does integer division num_iters = num_samples // global_batch_size which yields
zero when num_samples < micro_batch_size; fix by using ceiling division and
guard zero samples: if num_samples <= 0 raise/return early for invalid input,
otherwise compute num_iters = max(1, (num_samples + global_batch_size - 1) //
global_batch_size) (or math.ceil(num_samples / global_batch_size)) so at least
one calibration iteration runs when num_samples > 0; update any callers
expecting num_iters and add a short unit test or assertion near these variables
(global_batch_size, micro_batch_size, num_samples, num_iters).
- Around line 118-135: _get_dataset_cfg currently passes the raw list returned
by get_dataset_samples into DatasetDict, causing a type mismatch because
DatasetDict expects datasets.Dataset instances; convert the list[str] to a
HuggingFace Dataset (e.g., datasets.Dataset.from_dict or from_list) before
constructing DatasetDict and update the process_example_fn to reference the
field name used (for example use Dataset.from_dict({"text": dataset}) and change
process_example_fn to use example["text"]), keeping references to
HFDatasetConfig, DatasetDict, get_dataset_samples, and process_example_fn in the
fix.
🧹 Nitpick comments (4)
modelopt/torch/utils/distributed.py (1)

76-81: Consider using warnings.warn with a filter to avoid repeated warnings.

The warning will fire on every call to local_rank() when LOCAL_RANK is not set. For workflows that call this function repeatedly, this could produce excessive log noise.

♻️ Suggested fix using `warnings.warn` with stacklevel and category
+import functools
+
+@functools.lru_cache(maxsize=1)
+def _warn_local_rank_fallback():
+    warn("LOCAL_RANK environment variable not found. Using global rank instead.", stacklevel=3)
+
 def local_rank() -> int:
     """Returns the local rank of the current process."""
     if "LOCAL_RANK" in os.environ:
         return int(os.environ["LOCAL_RANK"])
-    warn("LOCAL_RANK environment variable not found. Using global rank instead.")
+    _warn_local_rank_fallback()
     return rank()
modelopt/torch/utils/plugins/__init__.py (1)

28-32: LGTM with minor grammar nit.

The comment explaining why the Megatron-Bridge plugin is not pre-imported is helpful for maintainability.

✏️ Optional: Fix apostrophes in comment
-# NOTE: Dont pre-import megatron bridge plugin here to avoid circular dependency issues.
-#   We dont register anything so this isnt a problem.
+# NOTE: Don't pre-import megatron bridge plugin here to avoid circular dependency issues.
+#   We don't register anything so this isn't a problem.
modelopt/torch/utils/plugins/mbridge.py (2)

97-109: Consider handling NemotronHModelProvider explicitly and improving error messages.

  1. NemotronHModelProvider is imported (line 32) and used in get_hf_mbridge_calibration_loop (line 166), but it falls through to the else branch here. If this is intentional, consider documenting it or adding it to the type hints.

  2. The assert statements provide no context when they fail.

♻️ Suggested improvements
     if isinstance(provider, MambaModelProvider):
         provider.mamba_stack_spec = modelopt_mamba_stack_spec
     else:
+        # GPTModelProvider and NemotronHModelProvider use transformer_layer_spec
         provider.transformer_layer_spec = modelopt_transformer_layer_spec
 
     provider.finalize()
     if init_model_parallel:
         provider.initialize_model_parallel(seed=0)
 
     model = provider.provide_distributed_model(wrap_with_ddp=False)
-    assert len(model) == 1
+    assert len(model) == 1, f"Expected single model, got {len(model)} models"
     unwrapped_model = unwrap_model(model[0])
-    assert isinstance(unwrapped_model, (GPTModel, MambaModel))
+    assert isinstance(unwrapped_model, (GPTModel, MambaModel)), (
+        f"Expected GPTModel or MambaModel, got {type(unwrapped_model)}"
+    )

Also consider updating the return type annotation on line 65 to include NemotronHModelProvider if it's a supported provider type:

GPTModelProvider | MambaModelProvider | NemotronHModelProvider

211-221: Unused parameter m in forward_loop - consider prefixing with underscore.

The forward_loop function accepts parameter m but uses the outer scope model variable instead. This is likely intentional for ModelOpt API compatibility, but the unused parameter could confuse readers.

♻️ Suggested fix: prefix unused parameter
-    def forward_loop(m):
+    def forward_loop(_model):
+        # NOTE: _model parameter is unused; the Megatron model list from closure is used instead
         evaluate_and_print_results(
             state,
             prefix="iteration 1",
             forward_step_func=forward_step,
             data_iterator=train_data_iterator,
             model=model,
             config=cfg,
             verbose=True,
             write_to_tensorboard=False,
         )

Comment on lines +167 to +175
if args.prune_intermediate_checkpoint is None:
if args.output_megatron_path:
args.prune_intermediate_checkpoint = (
f"{args.output_megatron_path}/modelopt_pruning_scores.pth"
)
elif args.output_hf_path:
args.prune_intermediate_checkpoint = (
f"{args.output_hf_path}/modelopt_pruning_scores.pth"
)

⚠️ Potential issue | 🟠 Major

Create the pruning-score checkpoint directory before pruning.
Line 167 defaults the checkpoint under the output path, but that directory may not exist yet. If mtp.prune writes the scores before you save outputs, it can fail with “No such file or directory”. Create the parent directory up front.

Suggested fix
     if args.prune_intermediate_checkpoint is None:
         if args.output_megatron_path:
             args.prune_intermediate_checkpoint = (
                 f"{args.output_megatron_path}/modelopt_pruning_scores.pth"
             )
         elif args.output_hf_path:
             args.prune_intermediate_checkpoint = (
                 f"{args.output_hf_path}/modelopt_pruning_scores.pth"
             )
         print_rank_0(
             "No checkpoint provided to cache intermediate pruning scores. "
             f"Setting to: {args.prune_intermediate_checkpoint}"
         )
+    checkpoint_dir = os.path.dirname(args.prune_intermediate_checkpoint)
+    if checkpoint_dir:
+        os.makedirs(checkpoint_dir, exist_ok=True)

Comment on lines +205 to +210
if os.path.exists(f"{args.output_megatron_path}/latest_checkpointed_iteration.txt"):
warn_rank_0(f"\nPruned model already exists at {args.output_megatron_path}. Exiting...")
return
elif os.path.exists(f"{args.output_hf_path}/config.json"):
warn_rank_0(f"\nPruned model already exists at {args.output_hf_path}. Exiting...")
return

⚠️ Potential issue | 🟡 Minor

Guard existence checks when output path is unset.
If --output_hf_path is used, args.output_megatron_path is None, and Line 205 checks None/latest_checkpointed_iteration.txt. If a local None/ dir exists, the script may exit incorrectly. Guard with if args.output_megatron_path.

Suggested fix
-    if os.path.exists(f"{args.output_megatron_path}/latest_checkpointed_iteration.txt"):
+    if args.output_megatron_path and os.path.exists(
+        f"{args.output_megatron_path}/latest_checkpointed_iteration.txt"
+    ):
         warn_rank_0(f"\nPruned model already exists at {args.output_megatron_path}. Exiting...")
         return
-    elif os.path.exists(f"{args.output_hf_path}/config.json"):
+    elif args.output_hf_path and os.path.exists(f"{args.output_hf_path}/config.json"):
         warn_rank_0(f"\nPruned model already exists at {args.output_hf_path}. Exiting...")
         return

Comment on lines +59 to +62
## Resources

## Resources


⚠️ Potential issue | 🟡 Minor

Remove duplicate “Resources” heading.
The “## Resources” header appears twice, which is redundant.

Suggested fix
-## Resources
-
 ## Resources

Comment on lines +118 to +135
def _get_dataset_cfg(dataset_name: str, num_samples: int, seq_length: int) -> HFDatasetConfig:
"""Get a dataset config for the dataset."""
dataset = get_dataset_samples(dataset_name, num_samples)
dataset_cfg = HFDatasetConfig(
dataset_name=f"{dataset_name}_{num_samples}",
dataset_dict=DatasetDict({"train": dataset}),
process_example_fn=lambda example, tokenizer: {"input": example, "output": ""},
seq_length=seq_length,
dataloader_type="batch",
num_workers=1,
do_validation=False,
do_test=False,
val_proportion=None,
split_val_from_train=False,
rewrite=False,
)

return dataset_cfg

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# First, let's find and examine the get_dataset_samples function to see what it actually returns
rg -n "def get_dataset_samples" --type py -A 10

Repository: NVIDIA/Model-Optimizer

Length of output: 891


🏁 Script executed:

#!/bin/bash
# Check how HFDatasetConfig and DatasetDict are used in the codebase
rg -n "HFDatasetConfig\(" --type py -B 2 -A 15 | head -100

Repository: NVIDIA/Model-Optimizer

Length of output: 1455


🏁 Script executed:

#!/bin/bash
# Search for DatasetDict usage patterns to see what types are typically passed
rg -n "DatasetDict\(" --type py -A 3 | head -80

Repository: NVIDIA/Model-Optimizer

Length of output: 447


🌐 Web query:

HuggingFace datasets DatasetDict constructor accepts list or Dataset

💡 Result:

Short answer: No — DatasetDict expects a mapping of split name → Dataset (i.e., a dict of str: Dataset). Construct it with a dict (or use class methods like from_dict/from_parquet). References: Hugging Face docs for DatasetDict (shows it's a dict of splits to Dataset). [1][2]

Sources:
[1] Hugging Face docs — DatasetDict description and examples. [2] Hugging Face package reference — DatasetDict/from_dict examples.


Type mismatch: DatasetDict expects Dataset objects but receives list[str]

get_dataset_samples returns list[str] (confirmed at modelopt/torch/utils/dataset_utils.py:106), but DatasetDict({"train": dataset}) expects a mapping of split names to Dataset objects. The HuggingFace datasets library does not accept raw Python lists for this parameter. Convert the list to a Dataset object before passing it to DatasetDict.
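
A minimal sketch of that conversion (assuming a "text" column name, as suggested in the prompt above; not necessarily the author's final fix):

from datasets import Dataset, DatasetDict

# Wrap the raw list[str] samples in a HuggingFace Dataset before building the DatasetDict.
samples = get_dataset_samples(dataset_name, num_samples)
dataset_dict = DatasetDict({"train": Dataset.from_dict({"text": samples})})

# process_example_fn then reads the named column instead of a raw string.
def process_example_fn(example, tokenizer):
    return {"input": example["text"], "output": ""}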


Comment on lines +162 to +163
global_batch_size = micro_batch_size
num_iters = num_samples // global_batch_size

⚠️ Potential issue | 🟡 Minor

Potential zero iterations when num_samples < micro_batch_size.

If num_samples is less than micro_batch_size, integer division will result in num_iters = 0, causing no calibration to occur silently.

🐛 Suggested fix: add validation or ceiling division
     global_batch_size = micro_batch_size
-    num_iters = num_samples // global_batch_size
+    num_iters = num_samples // global_batch_size
+    if num_iters == 0:
+        raise ValueError(
+            f"num_samples ({num_samples}) must be >= micro_batch_size ({micro_batch_size}) "
+            "to have at least one calibration iteration."
+        )
