Add Megatron-Bridge pruning example scripts #800
base: main
Conversation
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.
📝 Walkthrough

This PR introduces Megatron-Bridge support to ModelOpt with a new pruning example using Minitron on Qwen3-8B. Changes include new utility functions for loading HF models via Megatron-Bridge, constructing calibration loops, updating documentation, and refactoring the plugins system to prevent auto-registration of the bridge plugin.
Sequence Diagram(s)

```mermaid
sequenceDiagram
participant CLI as prune_minitron.py
participant Bridge as Megatron-Bridge
participant Dataset as Dataset/Calibration
participant Pruning as ModelOpt Pruning
participant Output as Output Format
CLI->>CLI: Parse arguments & validate config
CLI->>Bridge: load_mbridge_model_from_hf()
Bridge-->>CLI: Return model, provider, unwrapped_model
CLI->>Dataset: get_hf_mbridge_calibration_loop()
Dataset-->>CLI: Return calibration loop closure
CLI->>Pruning: Build NAS search space & config
CLI->>Pruning: Execute pruning with forward_loop
Pruning-->>CLI: Return pruned model
CLI->>Output: Save to Megatron or HF format
Output-->>CLI: Model checkpoint written
```
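To make the flow above concrete, here is a minimal sketch of a pruning script following the same steps. The helper return values match the diagram, but the keyword arguments, the pruning mode name, and the constraints shown are illustrative assumptions rather than the PR's actual `prune_minitron.py` code.

```python
# Hypothetical sketch of the flow in the diagram above; argument names, the pruning
# mode, and the constraints are illustrative assumptions, not the PR's actual code.
import modelopt.torch.prune as mtp
from modelopt.torch.utils.plugins.mbridge import (
    get_hf_mbridge_calibration_loop,
    load_mbridge_model_from_hf,
)


def prune_qwen3_sketch(hf_model_id: str = "Qwen/Qwen3-8B"):
    # Load the HF checkpoint into a Megatron-Bridge model built with the ModelOpt layer spec.
    model, provider, unwrapped_model = load_mbridge_model_from_hf(hf_model_id)

    # Build a forward_loop closure that streams calibration batches through the model.
    # (The exact parameters this helper accepts are assumptions.)
    forward_loop = get_hf_mbridge_calibration_loop(model, provider)

    # Run Minitron-style pruning driven by the calibration loop; the caller can then
    # save the result in Megatron or HF format.
    pruned_model, _ = mtp.prune(
        unwrapped_model,
        mode="mcore_minitron",  # assumed mode name
        constraints={"export_config": {"hidden_size": 3072}},  # illustrative target
        dummy_input=None,
        config={"forward_loop": forward_loop},
    )
    return pruned_model
```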
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks: ❌ Failed checks (1 warning) | ✅ Passed checks (2 passed)

✏️ Tip: You can configure your own custom pre-merge checks in the settings.
ae6a842 to dc1cadc (Compare)
Codecov Report

❌ Patch coverage is

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main     #800      +/-   ##
==========================================
- Coverage   74.18%   74.18%   -0.01%
==========================================
  Files         192      192
  Lines       19236    19239       +3
==========================================
+ Hits        14271    14272       +1
- Misses       4965     4967       +2
```

☔ View full report in Codecov by Sentry.
2281a23 to 9c79afd (Compare)
9c79afd to a65050a (Compare)
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
a65050a to 44920ad (Compare)
Actionable comments posted: 5
🤖 Fix all issues with AI agents
In `@examples/megatron_bridge/prune_minitron.py`:
- Around line 205-210: The existence checks for output files should be guarded
by whether the corresponding output path args are set: before calling
os.path.exists with args.output_megatron_path or args.output_hf_path, check that
args.output_megatron_path and args.output_hf_path are truthy respectively;
update the block around the os.path.exists checks (the conditions using
args.output_megatron_path and args.output_hf_path) so you only call
os.path.exists when the arg is not None/empty, and keep the existing
warn_rank_0(...) and return behavior unchanged.
- Around line 167-175: The defaulting logic for
args.prune_intermediate_checkpoint can point to a file under
args.output_megatron_path or args.output_hf_path which may not exist; before
calling mtp.prune (or any operation that writes that checkpoint) create the
parent directory for args.prune_intermediate_checkpoint using os.path.dirname
and os.makedirs(..., exist_ok=True). Ensure the directory creation happens right
after the block that sets args.prune_intermediate_checkpoint and before any
prune/save calls so mtp.prune won't fail due to a missing directory.
In `@examples/megatron_bridge/README.md`:
- Around line 59-62: Remove the duplicate "## Resources" heading: locate the
repeated heading string "## Resources" in the README (the two identical headings
shown in the diff) and delete the redundant one so there is only a single "##
Resources" section header remaining.
In `@modelopt/torch/utils/plugins/mbridge.py`:
- Around line 162-163: The code sets global_batch_size = micro_batch_size and
does integer division num_iters = num_samples // global_batch_size which yields
zero when num_samples < micro_batch_size; fix by using ceiling division and
guard zero samples: if num_samples <= 0 raise/return early for invalid input,
otherwise compute num_iters = max(1, (num_samples + global_batch_size - 1) //
global_batch_size) (or math.ceil(num_samples / global_batch_size)) so at least
one calibration iteration runs when num_samples > 0; update any callers
expecting num_iters and add a short unit test or assertion near these variables
(global_batch_size, micro_batch_size, num_samples, num_iters).
- Around line 118-135: _get_dataset_cfg currently passes the raw list returned
by get_dataset_samples into DatasetDict, causing a type mismatch because
DatasetDict expects datasets.Dataset instances; convert the list[str] to a
HuggingFace Dataset (e.g., datasets.Dataset.from_dict or from_list) before
constructing DatasetDict and update the process_example_fn to reference the
field name used (for example use Dataset.from_dict({"text": dataset}) and change
process_example_fn to use example["text"]), keeping references to
HFDatasetConfig, DatasetDict, get_dataset_samples, and process_example_fn in the
fix.
🧹 Nitpick comments (4)
modelopt/torch/utils/distributed.py (1)
76-81: Consider using `warnings.warn` with a filter to avoid repeated warnings.

The warning will fire on every call to `local_rank()` when `LOCAL_RANK` is not set. For workflows that call this function repeatedly, this could produce excessive log noise.

♻️ Suggested fix using `warnings.warn` with stacklevel and category

```diff
+import functools
+
+@functools.lru_cache(maxsize=1)
+def _warn_local_rank_fallback():
+    warn("LOCAL_RANK environment variable not found. Using global rank instead.", stacklevel=3)
+
 def local_rank() -> int:
     """Returns the local rank of the current process."""
     if "LOCAL_RANK" in os.environ:
         return int(os.environ["LOCAL_RANK"])
-    warn("LOCAL_RANK environment variable not found. Using global rank instead.")
+    _warn_local_rank_fallback()
     return rank()
```

modelopt/torch/utils/plugins/__init__.py (1)
28-32: LGTM with minor grammar nit.

The comment explaining why the Megatron-Bridge plugin is not pre-imported is helpful for maintainability.

✏️ Optional: Fix apostrophes in comment

```diff
-# NOTE: Dont pre-import megatron bridge plugin here to avoid circular dependency issues.
-# We dont register anything so this isnt a problem.
+# NOTE: Don't pre-import megatron bridge plugin here to avoid circular dependency issues.
+# We don't register anything so this isn't a problem.
```

modelopt/torch/utils/plugins/mbridge.py (2)
97-109: Consider handling `NemotronHModelProvider` explicitly and improving error messages.

`NemotronHModelProvider` is imported (line 32) and used in `get_hf_mbridge_calibration_loop` (line 166), but it falls through to the `else` branch here. If this is intentional, consider documenting it or adding it to the type hints.

The `assert` statements provide no context when they fail.

♻️ Suggested improvements

```diff
     if isinstance(provider, MambaModelProvider):
         provider.mamba_stack_spec = modelopt_mamba_stack_spec
     else:
+        # GPTModelProvider and NemotronHModelProvider use transformer_layer_spec
         provider.transformer_layer_spec = modelopt_transformer_layer_spec
     provider.finalize()
     if init_model_parallel:
         provider.initialize_model_parallel(seed=0)
     model = provider.provide_distributed_model(wrap_with_ddp=False)
-    assert len(model) == 1
+    assert len(model) == 1, f"Expected single model, got {len(model)} models"
     unwrapped_model = unwrap_model(model[0])
-    assert isinstance(unwrapped_model, (GPTModel, MambaModel))
+    assert isinstance(unwrapped_model, (GPTModel, MambaModel)), (
+        f"Expected GPTModel or MambaModel, got {type(unwrapped_model)}"
+    )
```

Also consider updating the return type annotation on line 65 to include `NemotronHModelProvider` if it's a supported provider type: `GPTModelProvider | MambaModelProvider | NemotronHModelProvider`.
211-221: Unused parameter `m` in `forward_loop`; consider prefixing with underscore.

The `forward_loop` function accepts parameter `m` but uses the outer-scope `model` variable instead. This is likely intentional for ModelOpt API compatibility, but the unused parameter could confuse readers.

♻️ Suggested fix: prefix unused parameter

```diff
-    def forward_loop(m):
+    def forward_loop(_model):
+        # NOTE: _model parameter is unused; the Megatron model list from closure is used instead
         evaluate_and_print_results(
             state,
             prefix="iteration 1",
             forward_step_func=forward_step,
             data_iterator=train_data_iterator,
             model=model,
             config=cfg,
             verbose=True,
             write_to_tensorboard=False,
         )
```
```python
if args.prune_intermediate_checkpoint is None:
    if args.output_megatron_path:
        args.prune_intermediate_checkpoint = (
            f"{args.output_megatron_path}/modelopt_pruning_scores.pth"
        )
    elif args.output_hf_path:
        args.prune_intermediate_checkpoint = (
            f"{args.output_hf_path}/modelopt_pruning_scores.pth"
        )
```
Create the pruning-score checkpoint directory before pruning.
Line 167 defaults the checkpoint under the output path, but that directory may not exist yet. If mtp.prune writes the scores before you save outputs, it can fail with “No such file or directory”. Create the parent directory up front.
Suggested fix
```diff
 if args.prune_intermediate_checkpoint is None:
     if args.output_megatron_path:
         args.prune_intermediate_checkpoint = (
             f"{args.output_megatron_path}/modelopt_pruning_scores.pth"
         )
     elif args.output_hf_path:
         args.prune_intermediate_checkpoint = (
             f"{args.output_hf_path}/modelopt_pruning_scores.pth"
         )
     print_rank_0(
         "No checkpoint provided to cache intermediate pruning scores. "
         f"Setting to: {args.prune_intermediate_checkpoint}"
     )
+    checkpoint_dir = os.path.dirname(args.prune_intermediate_checkpoint)
+    if checkpoint_dir:
+        os.makedirs(checkpoint_dir, exist_ok=True)
```
+ os.makedirs(checkpoint_dir, exist_ok=True)🤖 Prompt for AI Agents
In `@examples/megatron_bridge/prune_minitron.py` around lines 167 - 175, The
defaulting logic for args.prune_intermediate_checkpoint can point to a file
under args.output_megatron_path or args.output_hf_path which may not exist;
before calling mtp.prune (or any operation that writes that checkpoint) create
the parent directory for args.prune_intermediate_checkpoint using
os.path.dirname and os.makedirs(..., exist_ok=True). Ensure the directory
creation happens right after the block that sets
args.prune_intermediate_checkpoint and before any prune/save calls so mtp.prune
won't fail due to a missing directory.
```python
if os.path.exists(f"{args.output_megatron_path}/latest_checkpointed_iteration.txt"):
    warn_rank_0(f"\nPruned model already exists at {args.output_megatron_path}. Exiting...")
    return
elif os.path.exists(f"{args.output_hf_path}/config.json"):
    warn_rank_0(f"\nPruned model already exists at {args.output_hf_path}. Exiting...")
    return
```
Guard existence checks when output path is unset.
If --output_hf_path is used, args.output_megatron_path is None, and Line 205 checks None/latest_checkpointed_iteration.txt. If a local None/ dir exists, the script may exit incorrectly. Guard with if args.output_megatron_path.
Suggested fix
```diff
-    if os.path.exists(f"{args.output_megatron_path}/latest_checkpointed_iteration.txt"):
+    if args.output_megatron_path and os.path.exists(
+        f"{args.output_megatron_path}/latest_checkpointed_iteration.txt"
+    ):
         warn_rank_0(f"\nPruned model already exists at {args.output_megatron_path}. Exiting...")
         return
-    elif os.path.exists(f"{args.output_hf_path}/config.json"):
+    elif args.output_hf_path and os.path.exists(f"{args.output_hf_path}/config.json"):
         warn_rank_0(f"\nPruned model already exists at {args.output_hf_path}. Exiting...")
         return
```
In `@examples/megatron_bridge/prune_minitron.py` around lines 205 - 210, The
existence checks for output files should be guarded by whether the corresponding
output path args are set: before calling os.path.exists with
args.output_megatron_path or args.output_hf_path, check that
args.output_megatron_path and args.output_hf_path are truthy respectively;
update the block around the os.path.exists checks (the conditions using
args.output_megatron_path and args.output_hf_path) so you only call
os.path.exists when the arg is not None/empty, and keep the existing
warn_rank_0(...) and return behavior unchanged.
```markdown
## Resources

## Resources
```
Remove duplicate “Resources” heading.
The “## Resources” header appears twice, which is redundant.
Suggested fix
```diff
-## Resources
-
 ## Resources
```
```python
def _get_dataset_cfg(dataset_name: str, num_samples: int, seq_length: int) -> HFDatasetConfig:
    """Get a dataset config for the dataset."""
    dataset = get_dataset_samples(dataset_name, num_samples)
    dataset_cfg = HFDatasetConfig(
        dataset_name=f"{dataset_name}_{num_samples}",
        dataset_dict=DatasetDict({"train": dataset}),
        process_example_fn=lambda example, tokenizer: {"input": example, "output": ""},
        seq_length=seq_length,
        dataloader_type="batch",
        num_workers=1,
        do_validation=False,
        do_test=False,
        val_proportion=None,
        split_val_from_train=False,
        rewrite=False,
    )

    return dataset_cfg
```
🧩 Analysis chain
🏁 Script executed:
```bash
#!/bin/bash
# First, let's find and examine the get_dataset_samples function to see what it actually returns
rg -n "def get_dataset_samples" --type py -A 10
```

Repository: NVIDIA/Model-Optimizer
Length of output: 891
🏁 Script executed:

```bash
#!/bin/bash
# Check how HFDatasetConfig and DatasetDict are used in the codebase
rg -n "HFDatasetConfig\(" --type py -B 2 -A 15 | head -100
```

Repository: NVIDIA/Model-Optimizer
Length of output: 1455
🏁 Script executed:

```bash
#!/bin/bash
# Search for DatasetDict usage patterns to see what types are typically passed
rg -n "DatasetDict\(" --type py -A 3 | head -80
```

Repository: NVIDIA/Model-Optimizer
Length of output: 447
🌐 Web query:
HuggingFace datasets DatasetDict constructor accepts list or Dataset
💡 Result:
Short answer: No — DatasetDict expects a mapping of split name → Dataset (i.e., a dict of str: Dataset). Construct it with a dict (or use class methods like from_dict/from_parquet). References: Hugging Face docs for DatasetDict (shows it's a dict of splits to Dataset). [1][2]
Sources:
[1] Hugging Face docs — DatasetDict description and examples. [2] Hugging Face package reference — DatasetDict/from_dict examples.
Type mismatch: DatasetDict expects Dataset objects but receives list[str]
get_dataset_samples returns list[str] (confirmed at modelopt/torch/utils/dataset_utils.py:106), but DatasetDict({"train": dataset}) expects a mapping of split names to Dataset objects. The HuggingFace datasets library does not accept raw Python lists for this parameter. Convert the list to a Dataset object before passing it to DatasetDict.
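A minimal sketch of the suggested conversion; the `"text"` column name and the helper below are illustrative, not part of the PR:

```python
from datasets import Dataset, DatasetDict


def build_train_dataset_dict(samples: list[str]) -> DatasetDict:
    """Wrap raw text samples in a HuggingFace Dataset before constructing the DatasetDict."""
    # Dataset.from_dict expects a mapping of column name -> list of values;
    # "text" is an illustrative column name.
    train_ds = Dataset.from_dict({"text": samples})
    return DatasetDict({"train": train_ds})


# process_example_fn would then read the named column, e.g.:
# lambda example, tokenizer: {"input": example["text"], "output": ""}
```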
```python
global_batch_size = micro_batch_size
num_iters = num_samples // global_batch_size
```
Potential zero iterations when num_samples < micro_batch_size.
If num_samples is less than micro_batch_size, integer division will result in num_iters = 0, causing no calibration to occur silently.
🐛 Suggested fix: add validation or ceiling division
```diff
 global_batch_size = micro_batch_size
-num_iters = num_samples // global_batch_size
+num_iters = num_samples // global_batch_size
+if num_iters == 0:
+    raise ValueError(
+        f"num_samples ({num_samples}) must be >= micro_batch_size ({micro_batch_size}) "
+        "to have at least one calibration iteration."
+    )
```
What does this PR do?
Type of change: new example
Megatron-Bridge pruning example scripts (HF input, HF / Megatron output). Also defined some utility functions we can reuse for adding examples for quantization or other optimizations:
- `modelopt.torch.utils.plugins.mbridge.load_mbridge_model_from_hf`: Load HF to MBridge with ModelOpt spec in desired TP/PP/etc configuration
- `modelopt.torch.utils.plugins.mbridge.get_hf_mbridge_calibration_loop`: Create `forward_loop` for calibration on a HF dataset
- `modelopt.torch.utils.dataset_utils` (`cnn_dailymail`, `nemotron-post-training-dataset-v2`, etc.)

Usage
From `nvcr.io/nvidian/nemo:26.02.rc1` container (mount latest code to `/opt/Megatron-Bridge` and `/opt/Model-Optimizer`)

Testing
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
Release Notes
New Features
Documentation
✏️ Tip: You can customize this high-level summary in your review settings.