
Conversation

@AAnoosheh (Contributor) commented Jan 21, 2026

What does this PR do?

Type of change: new feature

Overview: Add a subclass of DistillationModel that implements slightly different hooks to inject teacher tensors into the corresponding student layers for module-replacement purposes, as opposed to logits distillation.

Usage

mtd.convert(model, mode=[("layerwise_kd", config)])
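A slightly fuller sketch of the call above. The config construction is illustrative only: it assumes, per the walkthrough further down, that the criterion maps (student_layer, teacher_layer) name pairs to loss modules; the exact field names of the new config class may differ.

    import torch.nn as nn

    import modelopt.torch.distill as mtd

    # Illustrative config only -- field names assumed from the walkthrough below,
    # not copied from the actual BypassKDConfig definition.
    config = {
        "teacher_model": teacher_model,  # a pre-built teacher nn.Module (stand-in name)
        "criterion": {
            ("features.3", "features.3"): nn.MSELoss(),  # (student layer, teacher layer) -> loss
            ("features.7", "features.7"): nn.MSELoss(),
        },
    }

    model = mtd.convert(model, mode=[("layerwise_kd", config)])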

Testing

New unit tests

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes/No
  • Did you write any new necessary tests?: Yes/No
  • Did you add or update any necessary documentation?: Yes/No
  • Did you update Changelog?: Yes/No

Additional Information

Summary by CodeRabbit

  • New Features

    • Introduced bypass-enabled knowledge distillation mode with layer-level loss mapping for fine-grained model optimization control.
    • Added model export functionality with automatic cleanup of intermediate activation capturing mechanisms.
  • API Changes

    • New bypass_kd mode configuration option available for advanced knowledge distillation workflows.
    • Updated model export interface for improved lifecycle management.


Signed-off-by: Asha Anoosheh <aanoosheh@nvidia.com>
AAnoosheh self-assigned this Jan 21, 2026
AAnoosheh requested a review from a team as a code owner on January 21, 2026 at 14:59
coderabbitai bot (Contributor) commented Jan 21, 2026

📝 Walkthrough

Walkthrough

Introduces bypass-enabled knowledge distillation via a new BypassDistillationModel class that extends DistillationModel with hook management for capturing and injecting intermediate layer activations. Adds configuration support, mode descriptor, conversion functions, and comprehensive test coverage.

Changes

Cohort / File(s) Summary
Bypass Model Implementation
modelopt/torch/distill/bypass_distillation_model.py, modelopt/torch/distill/config.py
New BypassDistillationModel class with hook registration for student-teacher layer pairs; captures intermediate inputs/outputs and enables input bypass. New BypassKDConfig enforces criterion structure as dict mapping layer pairs to losses.
Base Model Extension
modelopt/torch/distill/distillation_model.py
Adds hook lifecycle management: _hook_handles attribute, _register_hooks() method, and export() method to clean up hooks. Updates teacher_model property return type from nn.ModuleList to nn.Module.
Mode Integration
modelopt/torch/distill/mode.py, modelopt/torch/distill/__init__.py
Introduces BypassKDModeDescriptor for "bypass_kd" mode with BypassKDConfig and _convert_for_bypass() entrypoint. Updates _convert_for_kd() signature to accept model_cls parameter. Removes MetadataDict and UpdateEntrypoint dependencies from KnowledgeDistillationModeDescriptor.
Test Coverage
tests/unit/torch/distill/test_bypass.py, tests/unit/torch/distill/test_distill.py
Comprehensive new test module for bypass distillation covering hook registration, forward pass behavior, input injection, loss computation, context managers, export cleanup, and gradient flow. Minor cleanup to existing test expectations.
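A minimal sketch of the hook pattern described above, not the actual implementation: a forward pre-hook on each teacher layer captures its input, and a forward pre-hook on the paired student layer swaps that captured tensor in. Only standard torch.nn.Module hook APIs are used; attribute and function names are illustrative.

    import torch.nn as nn

    def register_bypass_hooks(student_layer: nn.Module, teacher_layer: nn.Module) -> list:
        """Illustrative only: capture the teacher layer's input and feed it to the student layer."""
        handles = []

        def capture_teacher_input(module, args):
            # Stash the teacher layer's positional input for the paired student layer.
            module._intermediate_input = args[0]

        def inject_into_student(module, args):
            # Replace the student layer's input with the captured teacher tensor, if present.
            captured = getattr(teacher_layer, "_intermediate_input", None)
            if captured is None:
                return None  # fall back to the original input
            teacher_layer._intermediate_input = None  # consume it, as the new tests appear to expect
            return (captured, *args[1:])

        handles.append(teacher_layer.register_forward_pre_hook(capture_teacher_input))
        handles.append(student_layer.register_forward_pre_hook(inject_into_student))
        return handles  # each handle.remove() would be called during export()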

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant Student as Student<br/>Model
    participant TeacherRef as Teacher Ref<br/>(on Student)
    participant Teacher
    participant Loss as Loss<br/>Computation

    User->>Student: forward(input)
    activate Student
    
    Note over Student: Pre-hook: bypass teacher input<br/>into student computation
    TeacherRef->>Teacher: forward(teacher_input)
    activate Teacher
    Teacher->>Teacher: Capture intermediate<br/>output
    Teacher-->>TeacherRef: Return output
    deactivate Teacher
    
    Student->>Student: Process with<br/>injected input
    Student-->>User: Return output
    deactivate Student
    
    User->>Loss: compute_kd_loss()
    activate Loss
    Loss->>Loss: Compute losses for<br/>each (student, teacher) pair
    Loss-->>User: Return aggregated loss
    deactivate Loss
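In training-loop terms, the flow above would look roughly like the sketch below. It assumes the converted model exposes compute_kd_loss() (referenced in the diagram and in the review comments further down); dataloader and optimizer are stand-ins.

    # Illustrative training step for the converted model.
    model.train()
    for batch in dataloader:
        optimizer.zero_grad()
        _ = model(batch)                # student forward; hooks run the teacher and capture activations
        loss = model.compute_kd_loss()  # aggregates the losses for each (student, teacher) layer pair
        loss.backward()
        optimizer.step()

    model.export()  # per the walkthrough, export() removes the registered hooks when distillation is done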

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 inconclusive)
Title check (❓ Inconclusive): The title 'Layerwise KD mode' is vague and does not clearly convey the main change. While the PR adds bypass-style knowledge distillation (which could be considered layerwise), the term 'Layerwise' is non-descriptive without additional context. Resolution: consider a more specific title, such as 'Add bypass-style knowledge distillation mode' or 'Implement BypassDistillationModel for layer-wise KD', that better describes the feature being added.
✅ Passed checks (2 passed)
Docstring Coverage (✅ Passed): Docstring coverage is 87.50%, which meets the required threshold of 80.00%.
Description Check (✅ Passed): Check skipped because CodeRabbit's high-level summary is enabled.



coderabbitai bot (Contributor) left a comment


Actionable comments posted: 0

Caution: Some comments are outside the diff and can't be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
modelopt/torch/distill/mode.py (1)

177-184: Registration caching may use wrong model class when converting the same student class with different modes.

If the same student model class is converted first with kd_loss mode (registers DistillationModel) and later with bypass_kd mode (which passes BypassDistillationModel), the second conversion will skip registration (line 179 condition fails since the class is already registered) and reuse the cached DistillationModel instead of the intended BypassDistillationModel.

The registration key is only the student class type (original_cls), not a tuple of (original_cls, model_cls). The register() method stores the dynamic module class at _registry[nn_cls_] = dm_class on first registration and never updates it, even if the same nn_cls is encountered again with a different model_cls.

Consider:

  • Keying the registration on (original_cls, model_cls) tuple
  • Checking if the registered class matches the intended model_cls and re-registering if needed
  • Documenting this as an unsupported use case
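A purely illustrative sketch of the first option above, keying on the pair. It uses a stand-in dict rather than the real registry, just to show the intended lookup behavior:

    # Stand-in registry keyed on (original_cls, model_cls) instead of original_cls alone.
    _registry: dict[tuple[type, type], type] = {}

    def register_dm_class(original_cls: type, model_cls: type, dm_class: type) -> None:
        _registry[(original_cls, model_cls)] = dm_class

    def lookup_dm_class(original_cls: type, model_cls: type) -> type | None:
        # A kd_loss conversion (DistillationModel) and a bypass_kd conversion
        # (BypassDistillationModel) of the same student class now resolve to distinct entries.
        return _registry.get((original_cls, model_cls))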
🧹 Nitpick comments (4)
modelopt/torch/distill/distillation_model.py (1)

122-127: Consider cleaning up intermediate output attributes during export.

The export method removes hook handles but leaves _intermediate_output attributes on the student and teacher layers. While these are typically None after compute_kd_loss, they could contain stale tensors if export is called mid-training without loss computation.

♻️ Optional cleanup of intermediate attributes
     def export(self):
         """Export the distillation model."""
         for handle in self._hook_handles:
             handle.remove()
         self._hook_handles.clear()
+        for student_layer, teacher_layer in self._layers_to_loss:
+            if hasattr(student_layer, "_intermediate_output"):
+                delattr(student_layer, "_intermediate_output")
+            if hasattr(teacher_layer, "_intermediate_output"):
+                delattr(teacher_layer, "_intermediate_output")
         return super().export()
tests/unit/torch/distill/test_bypass.py (2)

26-32: Code duplication with test_distill.py.

The get_input_tensor() and tiny_mobilenet() helper functions are duplicated from tests/unit/torch/distill/test_distill.py. Consider extracting these to a shared conftest.py or test utilities module.
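One possible shape for that extraction, sketched below; the real helper bodies live in test_distill.py and are not reproduced here.

    # tests/unit/torch/distill/conftest.py -- illustrative sketch only.
    import pytest
    import torch

    @pytest.fixture
    def input_tensor():
        # Stand-in body for the duplicated get_input_tensor() helper (shape is illustrative).
        return torch.randn(2, 3, 32, 32)

    # tiny_mobilenet() would move here the same way, either as a fixture or as a
    # plain helper imported by both test modules.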


63-90: Tests test_bypass_forward_pass and test_bypass_input_injection have overlapping assertions.

Both tests verify that teacher_layer._intermediate_input is None after forward. Consider consolidating or differentiating the test purposes more clearly.

modelopt/torch/distill/mode.py (1)

146-159: Document the new model_cls parameter.

The docstring's Args section doesn't document the new model_cls parameter, and the Returns section mentions only DistillationModel but the function can now return subclasses like BypassDistillationModel.

📝 Suggested docstring update
     """Function for converting a model to a distillation meta-model.

     This is the only utility needed to use the ``modelopt.torch.distill`` API directly.

     Args:
         model: The base model to be used as the student.
         config: A KDLossConfig instance defining the configuration options.
+        model_cls: The distillation model class to use. Defaults to DistillationModel.

     Returns:
-        A ``DistillationModel`` encapsulating the teacher and students, to be used in place
+        A distillation model (of type ``model_cls``) encapsulating the teacher and students, to be used in place
         of the original model.
         Metadata dictionary containing the config arguments needed to recreate the model
         from a checkpoint.
     """

@kevalmorabia97 (Collaborator) commented:

Is there a need to update changelog / distillation readme? Or is this more of an internal feature which we will only use via Puzzletron?

Signed-off-by: Asha Anoosheh <aanoosheh@nvidia.com>
@AAnoosheh (Contributor, Author) replied:

> Is there a need to update changelog / distillation readme? Or is this more of an internal feature which we will only use via Puzzletron?

For now, using it via Puzzletron was my idea.

Signed-off-by: Asha Anoosheh <aanoosheh@nvidia.com>
AAnoosheh changed the title from "Bypass-style KD mode" to "Layerwise KD mode" on Jan 22, 2026