refactor(autogram): Make engine hook recursively #451

Merged
ValerianRey merged 2 commits into main from recursive-hooking
Oct 11, 2025

@ValerianRey ValerianRey commented Oct 11, 2025

  • Make the Engine hook modules recursively. If a direct parameter exists, hook the module and do not hook its children. If not, try to hook the child modules.
  • This change means that only the parentmost module with direct requires-grad parameters gets hooked. The parameters it uses are thus simply module.parameters(recurse=True) now. This holds even in the special cases where the parent uses its children's parameters, so we no longer need a special case for MHA (MultiheadAttention).
  • Remove _module_utils.py: it's now trivial to know with respect to which parameters to differentiate.
  • Update all usages to now create the Engine with Engine(model) instead of Engine(model.modules()). For partial JD, users have to be more careful: they should sometimes specify several modules, and these modules should be "disjoint" (i.e. no specified module should be a child of another specified module).
  • This mostly makes a difference on FreeParam. Before, we had 2 hooks: one for the parent, parameterized with the parent's param (aka the free param), and one for the child module, parameterized with the child's params. Now we simply have 1 hook for the parent, parameterized with all parameters (i.e. parent.parameters(recurse=True)). This is probably faster (because we don't have to do 2 extra forwards and 2 extra backwards for the child, but just 1 now), but maybe a bit more memory-consuming (because we have to store the Jacobian wrt the child's params and wrt the parent's free param at the same time). This case is quite niche though, and I still see it as an improvement.
  • Change Engine to take *modules: nn.Module instead of Iterable[nn.Module] (more convenient for the new usage, because we only specify one model 99% of the time). Update the docstring accordingly.
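
The hooking rule and the "disjoint" requirement described above can be sketched as two short helpers. This is only an illustration, not the code in `_engine.py`: the names `modules_to_hook` and `are_disjoint` are hypothetical, and a "direct parameter" is assumed to mean one returned by `module.parameters(recurse=False)` with `requires_grad=True`.

```python
from torch import nn


def modules_to_hook(module: nn.Module) -> list[nn.Module]:
    """Return the parentmost modules that own a direct trainable parameter."""
    if any(p.requires_grad for p in module.parameters(recurse=False)):
        # A direct parameter exists: hook this module and skip its children,
        # since module.parameters(recurse=True) already covers them.
        return [module]
    # No direct parameter: try to hook the child modules instead.
    return [m for child in module.children() for m in modules_to_hook(child)]


def are_disjoint(*modules: nn.Module) -> bool:
    """Return True if no given module is contained in another given module."""
    for i, outer in enumerate(modules):
        contained = {id(m) for m in outer.modules()}  # includes outer itself
        if any(id(m) in contained for j, m in enumerate(modules) if j != i):
            return False
    return True


model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
# Sequential and ReLU own no parameters, so only the two Linear layers are selected.
hooked = modules_to_hook(model)
```

Here `are_disjoint(model[0], model[2])` holds, while `are_disjoint(model, model[0])` does not, because the first `Linear` is a child of `model`. A check of this kind is one way a caller doing partial JD could validate the modules before passing them on.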

Pros:

  • Cleaner interface
  • No more indirectly_used_params to handle (no more special case for MultiheadAttention)
  • Faster on FreeParam and similar architectures

Cons:

  • Consumes more memory on FreeParam and similar architectures

For reference, here is the definition of FreeParam:

class FreeParam(ShapedModule):
    """
    Model that contains a free (i.e. not contained in a submodule) parameter, that is used at the
    beginning of the forward pass.
    """

    INPUT_SHAPES = (15,)
    OUTPUT_SHAPES = (80,)

    def __init__(self):
        super().__init__()
        self.matrix = nn.Parameter(torch.randn(15, 16))  # Free parameter
        self.relu = nn.ReLU()
        self.linear1 = nn.Linear(16, 50)
        self.linear2 = nn.Linear(50, 60)
        self.linear3 = nn.Linear(60, 70)
        self.linear4 = nn.Linear(70, 80)

    def forward(self, input: Tensor) -> Tensor:
        output = self.relu(input @ self.matrix)
        output = self.relu(self.linear1(output))
        output = self.relu(self.linear2(output))
        output = self.relu(self.linear3(output))
        output = self.linear4(output)
        return output
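
To see why this architecture is the interesting case, the sketch below inspects which parameters the root module owns directly and which it reaches through its children. `FreeParamLike` is a hypothetical stand-in that subclasses `nn.Module` directly, since `ShapedModule` is a test helper not shown here; the batch size of 8 is also an arbitrary choice.

```python
import torch
from torch import Tensor, nn


class FreeParamLike(nn.Module):
    """Stand-in for FreeParam, without the ShapedModule base class."""

    def __init__(self):
        super().__init__()
        self.matrix = nn.Parameter(torch.randn(15, 16))  # Free parameter
        self.relu = nn.ReLU()
        self.linear1 = nn.Linear(16, 50)
        self.linear2 = nn.Linear(50, 60)
        self.linear3 = nn.Linear(60, 70)
        self.linear4 = nn.Linear(70, 80)

    def forward(self, input: Tensor) -> Tensor:
        output = self.relu(input @ self.matrix)
        output = self.relu(self.linear1(output))
        output = self.relu(self.linear2(output))
        output = self.relu(self.linear3(output))
        return self.linear4(output)


model = FreeParamLike()

# Parameters owned by the root itself (not by any submodule).
direct = [name for name, _ in model.named_parameters(recurse=False)]

# Everything a single hook on the root would differentiate with respect to.
all_params = [name for name, _ in model.named_parameters(recurse=True)]

# Sanity check that the stand-in runs on a batch of inputs.
output = model(torch.randn(8, 15))
```

Because the root owns `matrix` directly, the recursive rule stops at the root: `direct` contains only `matrix`, while `all_params` lists the free parameter plus the weight and bias of each of the four `Linear` layers.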

@ValerianRey ValerianRey added the package: autojac and cc: refactor labels Oct 11, 2025
@ValerianRey ValerianRey self-assigned this Oct 11, 2025

codecov bot commented Oct 11, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.

Files with missing lines Coverage Δ
src/torchjd/autogram/_engine.py 100.00% <100.00%> (ø)
src/torchjd/autogram/_module_hook_manager.py 100.00% <100.00%> (ø)
src/torchjd/autogram/_vjp.py 100.00% <100.00%> (ø)

@PierreQuinton PierreQuinton left a comment

I like it! I still find it weird that we have params in both the module hook manager and the vjp, but I am working on a solution to that.

@ValerianRey ValerianRey merged commit 19d375d into main Oct 11, 2025
17 checks passed
@ValerianRey ValerianRey deleted the recursive-hooking branch October 11, 2025 14:20