Layerwise KD mode #802
base: main
Conversation
Signed-off-by: Asha Anoosheh <aanoosheh@nvidia.com>
📝 Walkthrough

Introduces bypass-enabled knowledge distillation via a new BypassDistillationModel class that extends DistillationModel with hook management for capturing and injecting intermediate layer activations. Adds configuration support, mode descriptor, conversion functions, and comprehensive test coverage.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant Student as Student<br/>Model
    participant TeacherRef as Teacher Ref<br/>(on Student)
    participant Teacher
    participant Loss as Loss<br/>Computation
    User->>Student: forward(input)
    activate Student
    Note over Student: Pre-hook: bypass teacher input<br/>into student computation
    TeacherRef->>Teacher: forward(teacher_input)
    activate Teacher
    Teacher->>Teacher: Capture intermediate<br/>output
    Teacher-->>TeacherRef: Return output
    deactivate Teacher
    Student->>Student: Process with<br/>injected input
    Student-->>User: Return output
    deactivate Student
    User->>Loss: compute_kd_loss()
    activate Loss
    Loss->>Loss: Compute losses for<br/>each (student, teacher) pair
    Loss-->>User: Return aggregated loss
    deactivate Loss
```
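The capture-and-inject flow in the diagram can be pictured with plain PyTorch hooks. A minimal sketch, assuming a forward hook on the teacher layer and a forward pre-hook on the student layer; the helper names and the `storage` dict are illustrative, not the actual BypassDistillationModel internals:

```python
import torch
import torch.nn as nn

# Illustrative capture/inject pattern; not the real modelopt.torch.distill code.

def make_capture_hook(storage: dict, key: str):
    """Forward hook for a teacher layer: stash its output for later injection."""
    def hook(module, inputs, output):
        storage[key] = output.detach()
    return hook

def make_bypass_pre_hook(storage: dict, key: str):
    """Forward pre-hook for a student layer: swap in the captured teacher output."""
    def pre_hook(module, inputs):
        captured = storage.pop(key, None)
        if captured is not None:
            return (captured,) + tuple(inputs[1:])  # replace the first positional input
        return None  # no capture available -> keep the original inputs
    return pre_hook

teacher_layer = nn.Linear(16, 16)
student_layer = nn.Linear(16, 16)

storage: dict = {}
teacher_layer.register_forward_hook(make_capture_hook(storage, "block0"))
student_layer.register_forward_pre_hook(make_bypass_pre_hook(storage, "block0"))

x = torch.randn(2, 16)
with torch.no_grad():
    teacher_layer(x)                       # teacher pass captures the activation
out = student_layer(torch.zeros_like(x))   # student input is bypassed with it
```

The real class additionally tracks the hook handles so they can be removed on `export()`, as discussed in the comments below.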
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes

🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (inconclusive)
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
modelopt/torch/distill/mode.py (1)
177-184: Registration caching may use the wrong model class when converting the same student class with different modes.

If the same student model class is converted first with `kd_loss` mode (which registers `DistillationModel`) and later with `bypass_kd` mode (which passes `BypassDistillationModel`), the second conversion will skip registration (the line 179 condition fails since the class is already registered) and reuse the cached `DistillationModel` instead of the intended `BypassDistillationModel`.

The registration key is only the student class type (`original_cls`), not a tuple of `(original_cls, model_cls)`. The `register()` method stores the dynamic module class at `_registry[nn_cls_] = dm_class` on first registration and never updates it, even if the same `nn_cls` is encountered again with a different `model_cls`.

Consider:

- Keying the registration on an `(original_cls, model_cls)` tuple
- Checking whether the registered class matches the intended `model_cls` and re-registering if needed
- Documenting this as an unsupported use case
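If the first option is taken, the cache lookup could be keyed on the pair rather than the student class alone. A minimal sketch; the plain dict and function name below stand in for the real registry, whose API is not reproduced here:

```python
# Sketch only: key the cached dynamic class on (original_cls, model_cls)
# instead of original_cls alone. The dict is a placeholder for the actual
# registry, not its real API.
_dynamic_cls_cache: dict = {}

def get_or_create_dynamic_cls(original_cls: type, model_cls: type) -> type:
    key = (original_cls, model_cls)
    if key not in _dynamic_cls_cache:
        # One dynamic class per (student class, distillation class) pair, so
        # kd_loss and bypass_kd conversions of the same student never collide.
        _dynamic_cls_cache[key] = type(
            f"{model_cls.__name__}[{original_cls.__name__}]",
            (model_cls, original_cls),
            {},
        )
    return _dynamic_cls_cache[key]
```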
🧹 Nitpick comments (4)
modelopt/torch/distill/distillation_model.py (1)
122-127: Consider cleaning up intermediate output attributes during export.

The export method removes hook handles but leaves `_intermediate_output` attributes on the student and teacher layers. While these are typically `None` after `compute_kd_loss`, they could contain stale tensors if export is called mid-training without loss computation.

♻️ Optional cleanup of intermediate attributes

```diff
 def export(self):
     """Export the distillation model."""
     for handle in self._hook_handles:
         handle.remove()
     self._hook_handles.clear()
+    for student_layer, teacher_layer in self._layers_to_loss:
+        if hasattr(student_layer, "_intermediate_output"):
+            delattr(student_layer, "_intermediate_output")
+        if hasattr(teacher_layer, "_intermediate_output"):
+            delattr(teacher_layer, "_intermediate_output")
     return super().export()
```

tests/unit/torch/distill/test_bypass.py (2)
26-32: Code duplication with test_distill.py.

The `get_input_tensor()` and `tiny_mobilenet()` helper functions are duplicated from `tests/unit/torch/distill/test_distill.py`. Consider extracting these to a shared conftest.py or test utilities module; a sketch follows below.
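One possible shape for that shared module, with placeholder fixture bodies (the tensor shape and model construction are assumptions and should simply mirror whatever the existing helpers return):

```python
# tests/unit/torch/distill/conftest.py -- hypothetical shared home for the helpers
import pytest
import torch

@pytest.fixture
def input_tensor():
    # Placeholder shape; should match what get_input_tensor() returns today.
    return torch.randn(1, 3, 32, 32)

@pytest.fixture
def tiny_mobilenet():
    # Placeholder body; should reuse the existing tiny_mobilenet() construction
    # so test_distill.py and test_bypass.py share a single definition.
    from torchvision.models import mobilenet_v2
    return mobilenet_v2(width_mult=0.25, num_classes=10)
```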
63-90: Tests `test_bypass_forward_pass` and `test_bypass_input_injection` have overlapping assertions.

Both tests verify that `teacher_layer._intermediate_input is None` after forward. Consider consolidating them or differentiating the test purposes more clearly.

modelopt/torch/distill/mode.py (1)
146-159: Document the new `model_cls` parameter.

The docstring's `Args` section doesn't document the new `model_cls` parameter, and the `Returns` section mentions only `DistillationModel`, but the function can now return subclasses like `BypassDistillationModel`.

📝 Suggested docstring update

```diff
 """Function for converting a model to a distillation meta-model.

 This is the only utility needed to use the ``modelopt.torch.distill`` API directly.

 Args:
     model: The base model to be used as the student.
     config: A KDLossConfig instance defining the configuration options.
+    model_cls: The distillation model class to use. Defaults to DistillationModel.

 Returns:
-    A ``DistillationModel`` encapsulating the teacher and students, to be used in place
+    A distillation model (of type ``model_cls``) encapsulating the teacher and students, to be used in place
     of the original model.

     Metadata dictionary containing the config arguments needed to recreate the model
     from a checkpoint.
 """
```
Is there a need to update the changelog / distillation README? Or is this more of an internal feature which we will only use via Puzzletron?
Signed-off-by: Asha Anoosheh <aanoosheh@nvidia.com>
For now, using it via Puzzletron was my idea.
Signed-off-by: Asha Anoosheh <aanoosheh@nvidia.com>
What does this PR do?
Type of change: new feature
Overview: Add a subclass of `DistillationModel` which implements slightly different hooks to inject teacher tensors into corresponding student layers, for module-replacement purposes as opposed to logits distillation.

Usage
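A hedged usage sketch: the `bypass_kd` mode name is taken from the review thread above, and the config keys mirror the existing `kd_loss` pattern, so treat both as assumptions rather than the confirmed API of this PR.

```python
import torch
import torch.nn as nn

import modelopt.torch.distill as mtd

# Toy student/teacher pair; real models would come from the Puzzletron workflow.
student = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 4))
teacher = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 4))

config = {
    "teacher_model": teacher,
    # Pair the student/teacher layers to compare; with bypass enabled, the
    # teacher layer's output is also injected into the student computation.
    "criterion": {("0", "0"): nn.MSELoss()},
}

model = mtd.convert(student, mode=[("bypass_kd", config)])

out = model(torch.randn(2, 8))    # forward runs teacher + student with injection
loss = model.compute_kd_loss()    # aggregates the per-layer losses
```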
Testing
New unit tests
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
New Features

- Bypass-enabled knowledge distillation via a new `BypassDistillationModel` that injects teacher activations into corresponding student layers.

API Changes

- New mode descriptor, configuration support, and conversion functions for the bypass KD mode.