refactor(autogram): Use GramianComputers working on modules
#453
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Changes from all commits (34 commits)
586e17a PierreQuinton: Make `GramianAccumulator` track paths to Modules rather than parameters.
7e901e2 ValerianRey: Improve docstrings
41d3b0b ValerianRey: Make InterModuleParamReuse xfail
72375bb ValerianRey: Merge branch 'main' into gramian-accumulator-handles-modules
07be8d6 PierreQuinton: Fix GramianAccumulator tests.
a2cec0b PierreQuinton: Make `_make_path_jacobians` return a list rather than a dict.
f8856ba PierreQuinton: Make `accumulate_path_jacobian` take a `list[Tensor]` of jacobians
a9d752e PierreQuinton: Rename `VJP` into `JacobianComputer` as it does not compute vector ja…
89cdc86 PierreQuinton: functional_call can take a list of dict.
3cc2cdf PierreQuinton: Fix typing
6138a72 PierreQuinton: Fix typing yet again
1a1426c PierreQuinton: Makes `JacobianComputer` return a Jacobian Matrix. Adapt `ComputeModu…
6750ffa PierreQuinton: Improve docstring.
22fd4e7 PierreQuinton: Apparently, if an in_dim corresponding to a PyTree is set to `None`, i…
13248cb ValerianRey: Fix precision of tests (for cuda testing)
b2b10d0 ValerianRey: Remove ModuleJacobianComputer
5641382 ValerianRey: Make JacobianComputers take rg_outputs at call
ea12581 ValerianRey: Move JacobianComputer construction to Engine
ab28096 ValerianRey: Add GramianComputer and subclasses
47551fe ValerianRey: Use GramianComputer instead of JacobianComputer
9e102e7 ValerianRey: Simplify JacobianAccumulator
be6e41d ValerianRey: Move vmap handling to JacobianComputer
8967ce1 ValerianRey: Improve docstrings
b791485 ValerianRey: Use _to_gramian in JacobianBasedGramianComputerWithCrossTerms
dd81b94 ValerianRey: Remove JacobianBasedGramianComputerWithoutCrossTerms
f3c3961 ValerianRey: Simplify how in_dims are computed in FunctionalJacobianComputer
944d6ac ValerianRey: Merge branch 'main' into gramian-accumulator-handles-modules
cd2a8d8 PierreQuinton: Fix mypy error
0fcd870 ValerianRey: Use ___ for variable name
ae356e4 ValerianRey: Make _make_gramian_computer protected
84a6dcf ValerianRey: Add comment
22a463c ValerianRey: Improve type consistency
3e6641b ValerianRey: Simplify JacobianComputer docstrings
08baaf1 ValerianRey: Fix comment
New file (78 lines added):

```python
from abc import ABC, abstractmethod
from typing import Optional

from torch import Tensor
from torch.utils._pytree import PyTree

from torchjd.autogram._jacobian_computer import JacobianComputer


class GramianComputer(ABC):
    @abstractmethod
    def __call__(
        self,
        rg_outputs: tuple[Tensor, ...],
        grad_outputs: tuple[Tensor, ...],
        args: tuple[PyTree, ...],
        kwargs: dict[str, PyTree],
    ) -> Optional[Tensor]:
        """Compute what we can for a module and optionally return the gramian if it's ready."""

    def track_forward_call(self) -> None:
        """Track that the module's forward was called. Necessary in some implementations."""

    def reset(self) -> None:
        """Reset state if any. Necessary in some implementations."""


class JacobianBasedGramianComputer(GramianComputer, ABC):
    def __init__(self, jacobian_computer: JacobianComputer):
        self.jacobian_computer = jacobian_computer

    @staticmethod
    def _to_gramian(jacobian: Tensor) -> Tensor:
        return jacobian @ jacobian.T


class JacobianBasedGramianComputerWithCrossTerms(JacobianBasedGramianComputer):
    """
    Stateful JacobianBasedGramianComputer that waits for all usages to be counted before returning
    the gramian.
    """

    def __init__(self, jacobian_computer: JacobianComputer):
        super().__init__(jacobian_computer)
        self.remaining_counter = 0
        self.summed_jacobian: Optional[Tensor] = None

    def reset(self) -> None:
        self.remaining_counter = 0
        self.summed_jacobian = None

    def track_forward_call(self) -> None:
        self.remaining_counter += 1

    def __call__(
        self,
        rg_outputs: tuple[Tensor, ...],
        grad_outputs: tuple[Tensor, ...],
        args: tuple[PyTree, ...],
        kwargs: dict[str, PyTree],
    ) -> Optional[Tensor]:
        """Compute what we can for a module and optionally return the gramian if it's ready."""
        jacobian_matrix = self.jacobian_computer(rg_outputs, grad_outputs, args, kwargs)

        if self.summed_jacobian is None:
            self.summed_jacobian = jacobian_matrix
        else:
            self.summed_jacobian += jacobian_matrix

        self.remaining_counter -= 1

        if self.remaining_counter == 0:
            gramian = self._to_gramian(self.summed_jacobian)
            del self.summed_jacobian
            return gramian
        else:
            return None
```
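To make the accumulation logic concrete: when a module is used several times in a forward pass, `JacobianBasedGramianComputerWithCrossTerms` sums the per-usage Jacobians before forming the Gramian, so cross terms between usages are included. A minimal standalone sketch of that pattern (the `ToyCrossTermAccumulator` below is a hypothetical toy, not part of torchjd):

```python
import torch

# Hypothetical toy mirroring JacobianBasedGramianComputerWithCrossTerms: each
# forward call of a shared module increments a counter, each backward
# contribution adds a Jacobian, and the Gramian is only formed once all usages
# have been accounted for.
class ToyCrossTermAccumulator:
    def __init__(self) -> None:
        self.remaining_counter = 0
        self.summed_jacobian = None

    def track_forward_call(self) -> None:
        self.remaining_counter += 1

    def __call__(self, jacobian: torch.Tensor):
        if self.summed_jacobian is None:
            self.summed_jacobian = jacobian
        else:
            self.summed_jacobian = self.summed_jacobian + jacobian
        self.remaining_counter -= 1
        if self.remaining_counter == 0:
            gramian = self.summed_jacobian @ self.summed_jacobian.T
            self.summed_jacobian = None
            return gramian
        return None

acc = ToyCrossTermAccumulator()
acc.track_forward_call()
acc.track_forward_call()  # the module was used twice in the forward pass

j1 = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
j2 = torch.tensor([[1.0, 1.0], [0.0, 0.0]])
first = acc(j1)    # None: the Gramian is not ready after the first usage
gramian = acc(j2)  # (j1 + j2) @ (j1 + j2).T
```

Note that `(j1 + j2) @ (j1 + j2).T` differs from `j1 @ j1.T + j2 @ j2.T` exactly by the cross terms between the two usages, which is why the summation must happen before the Gramian is formed.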
New file (187 lines added):
```python
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import cast

import torch
from torch import Tensor, nn
from torch.nn import Parameter
from torch.utils._pytree import PyTree, tree_flatten, tree_map, tree_map_only

# Note about import from protected _pytree module:
# PyTorch maintainers plan to make pytree public (see
# https://github.com/pytorch/pytorch/issues/65761, https://github.com/pytorch/pytorch/pull/137400).
# It should also come with better speed, because the current implementation is slow, according to
# https://github.com/pytorch/pytorch/issues/65761#issue-1010116111.
# When pytree becomes public, this import will have to be changed with a conditional import (to
# still support older versions of PyTorch where pytree is protected).


class JacobianComputer(ABC):
    """
    Abstract class to compute Jacobians for a module's forward pass with respect to its parameters.

    :param module: The module to differentiate.
    """

    def __init__(self, module: nn.Module):
        self.module = module

        self.rg_params = dict[str, Parameter]()
        self.frozen_params = dict[str, Parameter]()

        for name, param in module.named_parameters(recurse=True):
            if param.requires_grad:
                self.rg_params[name] = param
            else:
                self.frozen_params[name] = param

    def __call__(
        self,
        rg_outputs: tuple[Tensor, ...],
        grad_outputs: tuple[Tensor, ...],
        args: tuple[PyTree, ...],
        kwargs: dict[str, PyTree],
    ) -> Tensor:
        # This makes __call__ vmappable.
        return ComputeModuleJacobians.apply(
            self._compute_jacobian, rg_outputs, grad_outputs, args, kwargs
        )

    @abstractmethod
    def _compute_jacobian(
        self,
        rg_outputs: tuple[Tensor, ...],
        grad_outputs: tuple[Tensor, ...],
        args: tuple[PyTree, ...],
        kwargs: dict[str, PyTree],
    ) -> Tensor:
        """
        Computes and returns the Jacobian. The output must be a matrix (2D Tensor).
        """
```
```python
class FunctionalJacobianComputer(JacobianComputer):
    """
    JacobianComputer using the functional differentiation API. This requires using vmap, so it is
    not compatible with every module, and it requires an extra forward pass to create the vjp
    function.
    """

    def _compute_jacobian(
        self,
        _: tuple[Tensor, ...],
        grad_outputs: tuple[Tensor, ...],
        args: tuple[PyTree, ...],
        kwargs: dict[str, PyTree],
    ) -> Tensor:
        grad_outputs_in_dims = (0,) * len(grad_outputs)
        args_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, args)
        kwargs_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, kwargs)
        in_dims = (grad_outputs_in_dims, args_in_dims, kwargs_in_dims)
        vmapped_vjp = torch.vmap(self._call_on_one_instance, in_dims=in_dims)

        return vmapped_vjp(grad_outputs, args, kwargs)

    def _call_on_one_instance(
        self,
        grad_outputs_j: tuple[Tensor, ...],
        args_j: tuple[PyTree, ...],
        kwargs_j: dict[str, PyTree],
    ) -> Tensor:
        # Note: we use unsqueeze(0) to turn a single activation (or grad_output) into a
        # "batch" of 1 activation (or grad_output). This is because some layers (e.g.
        # nn.Flatten) do not work equivalently if they're provided with a batch or with
        # an element of a batch. We thus always provide them with batches, just of a
        # different size.
        args_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), args_j)
        kwargs_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), kwargs_j)
        grad_outputs_j_ = tuple(x.unsqueeze(0) for x in grad_outputs_j)

        def functional_model_call(rg_params: dict[str, Parameter]) -> tuple[Tensor, ...]:
            all_state = [
                cast(dict[str, Tensor], rg_params),
                dict(self.module.named_buffers()),
                cast(dict[str, Tensor], self.frozen_params),
            ]
            output = torch.func.functional_call(self.module, all_state, args_j, kwargs_j)
            flat_outputs = tree_flatten(output)[0]
            rg_outputs = tuple(
                t for t in flat_outputs if isinstance(t, Tensor) and t.requires_grad
            )
            return rg_outputs

        vjp_func = torch.func.vjp(functional_model_call, self.rg_params)[1]

        # vjp_func is a function that computes the vjp w.r.t. the primals (tuple). Here the
        # functional has a single primal, which is dict(module.named_parameters()). We therefore
        # take the 0'th element to obtain the dict of gradients w.r.t. the module's
        # named_parameters.
        gradients = vjp_func(grad_outputs_j_)[0]
        gradient = torch.cat([t.reshape(-1) for t in gradients.values()])
        return gradient
```
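As a sanity check of the functional approach, the same vjp-under-vmap pattern can be sketched in isolation. The module, tensors, and shapes below are illustrative (not torchjd code); for a bias-free `nn.Linear`, each Jacobian row can be verified analytically as the outer product of the cotangent with the input:

```python
import torch
from torch import nn

# Illustrative standalone sketch: for each batch element, build a vjp of a
# functional call of the module w.r.t. its parameters, then vmap over the batch
# to get one flattened Jacobian row per element.
module = nn.Linear(3, 2, bias=False)
params = dict(module.named_parameters())
x = torch.randn(4, 3)             # batch of 4 inputs
grad_outputs = torch.randn(4, 2)  # one cotangent per batch element

def one_instance(x_j: torch.Tensor, g_j: torch.Tensor) -> torch.Tensor:
    # unsqueeze(0): feed a "batch of 1" so batch-sensitive layers behave normally
    def call(p):
        return torch.func.functional_call(module, p, x_j.unsqueeze(0))

    vjp_fn = torch.func.vjp(call, params)[1]
    grads = vjp_fn(g_j.unsqueeze(0))[0]  # dict of grads w.r.t. named parameters
    return torch.cat([t.reshape(-1) for t in grads.values()])

jacobian = torch.vmap(one_instance)(x, grad_outputs)  # shape [4, 6]
```

For this layer, row `j` of the Jacobian equals `torch.outer(grad_outputs[j], x[j]).reshape(-1)`, since the output is `x @ W.T`.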
```python
class AutogradJacobianComputer(JacobianComputer):
    """
    JacobianComputer using the autograd engine. The main advantage of using this method is that it
    doesn't require making an extra forward pass.
    """

    def _compute_jacobian(
        self,
        rg_outputs: tuple[Tensor, ...],
        grad_outputs: tuple[Tensor, ...],
        _: tuple[PyTree, ...],
        __: dict[str, PyTree],
    ) -> Tensor:
        flat_rg_params, ___ = tree_flatten(self.rg_params)
        grads = torch.autograd.grad(
            rg_outputs,
            flat_rg_params,
            grad_outputs,
            retain_graph=True,
            allow_unused=True,
            materialize_grads=True,
        )
        flattened_grads = torch.cat([g.reshape(-1) for g in grads])
        jacobian = flattened_grads.unsqueeze(0)
        return jacobian
```
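The autograd-based variant can also be sketched in isolation. This example (illustrative module and tensors, not torchjd code) uses `torch.autograd.grad` with an explicit `grad_outputs` cotangent to obtain a single flattened Jacobian row over all parameters:

```python
import torch
from torch import nn

# Illustrative standalone sketch: autograd.grad with an explicit grad_outputs
# cotangent yields the gradients of <g, y> w.r.t. each parameter, which are
# concatenated into one flattened Jacobian row.
module = nn.Linear(3, 2, bias=False)
x = torch.randn(1, 3)
g = torch.randn(1, 2)
y = module(x)  # differentiable output

grads = torch.autograd.grad((y,), tuple(module.parameters()), (g,))
row = torch.cat([t.reshape(-1) for t in grads]).unsqueeze(0)  # shape [1, 6]
```

For this layer the row equals `(g.T @ x).reshape(-1)`, the analytic gradient of `<g, x @ W.T>` with respect to `W`.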
```python
class ComputeModuleJacobians(torch.autograd.Function):
    @staticmethod
    def forward(
        compute_jacobian_fn: Callable[
            [tuple[Tensor, ...], tuple[Tensor, ...], tuple[PyTree, ...], dict[str, PyTree]],
            Tensor,
        ],
        rg_outputs: tuple[Tensor, ...],
        grad_outputs: tuple[Tensor, ...],
        args: tuple[PyTree, ...],
        kwargs: dict[str, PyTree],
    ) -> Tensor:
        # There is no non-batched dimension
        jacobian = compute_jacobian_fn(rg_outputs, grad_outputs, args, kwargs)
        return jacobian

    @staticmethod
    def vmap(
        _,
        in_dims: tuple[None, None, tuple[int, ...], None, None],
        compute_jacobian_fn: Callable,
        rg_outputs: tuple[Tensor, ...],
        jac_outputs: tuple[Tensor, ...],
        args: tuple[PyTree, ...],
        kwargs: dict[str, PyTree],
    ) -> tuple[Tensor, None]:
        # There is a non-batched dimension.
        # We do not vmap over the args, kwargs, or rg_outputs for the non-batched dimension.
        generalized_jacobian = torch.vmap(compute_jacobian_fn, in_dims=in_dims[1:])(
            rg_outputs,
            jac_outputs,
            args,
            kwargs,
        )
        shape = generalized_jacobian.shape
        jacobian = generalized_jacobian.reshape([shape[0] * shape[1], -1])
        return jacobian, None

    @staticmethod
    def setup_context(*_) -> None:
        pass
```
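A small illustration of the final reshape in `vmap` above (the shapes are made up): vmapping `compute_jacobian_fn` over an outer non-batched dimension of size `b` produces a stack of `b` Jacobian matrices of shape `[m, p]`, and the reshape folds `b` into the row dimension to recover a 2D Jacobian matrix:

```python
import torch

# Made-up shapes: a stack of b=2 Jacobian matrices of shape [m=3, p=4] is
# folded into a single 2D Jacobian matrix of shape [b * m, p] = [6, 4].
generalized_jacobian = torch.arange(24.0).reshape(2, 3, 4)
shape = generalized_jacobian.shape
jacobian = generalized_jacobian.reshape([shape[0] * shape[1], -1])
```

Row `b * m_idx` of the result is simply row `m_idx` of the `b`-th stacked matrix, so no information is lost by the fold.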