Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ changes that do not affect the user.

### Changed

- Removed an unnecessary internal reshape when computing Jacobians. This should have no effect
other than a slight performance improvement in `autojac`.
- Revamped documentation.
- Made `backward` and `mtl_backward` importable from `torchjd.autojac` (like it was prior to 0.7.0).
- Deprecated importing `backward` and `mtl_backward` from `torchjd` directly.
Expand Down
14 changes: 14 additions & 0 deletions src/torchjd/autojac/_transform/_differentiate.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from abc import ABC, abstractmethod
from collections.abc import Sequence

import torch
from torch import Tensor

from ._base import RequirementError, TensorDict, Transform
from ._materialize import materialize
from ._ordered_set import OrderedSet


Expand Down Expand Up @@ -61,3 +63,15 @@ def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
f"outputs {outputs}."
)
return set(self.inputs)

def _get_vjp(self, grad_outputs: Sequence[Tensor], retain_graph: bool) -> tuple[Tensor, ...]:
    """
    Compute the vector-Jacobian product of ``self.outputs`` with respect to ``self.inputs``
    for the given ``grad_outputs``, returning one gradient tensor per input.
    """
    raw_grads = torch.autograd.grad(
        outputs=self.outputs,
        inputs=self.inputs,
        grad_outputs=grad_outputs,
        retain_graph=retain_graph,
        create_graph=self.create_graph,
        allow_unused=True,  # unused inputs yield None instead of raising
    )
    # materialize turns the optional (possibly-None) grads into concrete tensors
    # matching self.inputs — presumably zeros for unused inputs (see ._materialize).
    return materialize(raw_grads, inputs=self.inputs)
11 changes: 1 addition & 10 deletions src/torchjd/autojac/_transform/_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from torch import Tensor

from ._differentiate import Differentiate
from ._materialize import materialize
from ._ordered_set import OrderedSet


Expand Down Expand Up @@ -54,13 +53,5 @@ def _differentiate(self, grad_outputs: Sequence[Tensor]) -> tuple[Tensor, ...]:
if len(self.outputs) == 0:
return tuple([torch.zeros_like(input) for input in self.inputs])

optional_grads = torch.autograd.grad(
self.outputs,
self.inputs,
grad_outputs=grad_outputs,
retain_graph=self.retain_graph,
create_graph=self.create_graph,
allow_unused=True,
)
grads = materialize(optional_grads, self.inputs)
grads = self._get_vjp(grad_outputs, self.retain_graph)
return grads
61 changes: 16 additions & 45 deletions src/torchjd/autojac/_transform/_jac.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import math
from collections.abc import Callable, Sequence
from functools import partial
from itertools import accumulate

import torch
from torch import Size, Tensor
from torch import Tensor

from ._differentiate import Differentiate
from ._materialize import materialize
from ._ordered_set import OrderedSet


Expand Down Expand Up @@ -69,53 +67,37 @@ def _differentiate(self, jac_outputs: Sequence[Tensor]) -> tuple[Tensor, ...]:
]
)

def _get_vjp(grad_outputs: Sequence[Tensor], retain_graph: bool) -> Tensor:
optional_grads = torch.autograd.grad(
self.outputs,
self.inputs,
grad_outputs=grad_outputs,
retain_graph=retain_graph,
create_graph=self.create_graph,
allow_unused=True,
)
grads = materialize(optional_grads, inputs=self.inputs)
return torch.concatenate([grad.reshape([-1]) for grad in grads])

# If the jac_outputs are correct, this value should be the same for all jac_outputs.
m = jac_outputs[0].shape[0]
max_chunk_size = self.chunk_size if self.chunk_size is not None else m
n_chunks = math.ceil(m / max_chunk_size)

# List of tensors of shape [k_i, n] where the k_i's sum to m
jac_matrix_chunks = []
# One tuple per chunk (i), with one value per input (j), of shape [k_i] + shape[j],
# where k_i is the number of rows in the chunk (the k_i's sum to m)
jacs_chunks: list[tuple[Tensor, ...]] = []

# First differentiations: always retain graph
get_vjp_retain = partial(_get_vjp, retain_graph=True)
get_vjp_retain = partial(self._get_vjp, retain_graph=True)
for i in range(n_chunks - 1):
start = i * max_chunk_size
end = (i + 1) * max_chunk_size
jac_outputs_chunk = [jac_output[start:end] for jac_output in jac_outputs]
jac_matrix_chunks.append(_get_jac_matrix_chunk(jac_outputs_chunk, get_vjp_retain))
jacs_chunks.append(_get_jacs_chunk(jac_outputs_chunk, get_vjp_retain))

# Last differentiation: retain the graph only if self.retain_graph==True
get_vjp_last = partial(_get_vjp, retain_graph=self.retain_graph)
get_vjp_last = partial(self._get_vjp, retain_graph=self.retain_graph)
start = (n_chunks - 1) * max_chunk_size
jac_outputs_chunk = [jac_output[start:] for jac_output in jac_outputs]
jac_matrix_chunks.append(_get_jac_matrix_chunk(jac_outputs_chunk, get_vjp_last))

jac_matrix = torch.vstack(jac_matrix_chunks)
lengths = [input.numel() for input in self.inputs]
jac_matrices = _extract_sub_matrices(jac_matrix, lengths)

shapes = [input.shape for input in self.inputs]
jacs = _reshape_matrices(jac_matrices, shapes)
jacs_chunks.append(_get_jacs_chunk(jac_outputs_chunk, get_vjp_last))

return tuple(jacs)
n_inputs = len(self.inputs)
jacs = tuple(torch.cat([chunks[i] for chunks in jacs_chunks]) for i in range(n_inputs))
return jacs


def _get_jac_matrix_chunk(
jac_outputs_chunk: list[Tensor], get_vjp: Callable[[Sequence[Tensor]], Tensor]
) -> Tensor:
def _get_jacs_chunk(
jac_outputs_chunk: list[Tensor], get_vjp: Callable[[Sequence[Tensor]], tuple[Tensor, ...]]
) -> tuple[Tensor, ...]:
"""
Computes the jacobian matrix chunk corresponding to the provided get_vjp function, either by
calling get_vjp directly or by wrapping it into a call to ``torch.vmap``, depending on the shape
Expand All @@ -126,18 +108,7 @@ def _get_jac_matrix_chunk(
chunk_size = jac_outputs_chunk[0].shape[0]
if chunk_size == 1:
grad_outputs = [tensor.squeeze(0) for tensor in jac_outputs_chunk]
gradient_vector = get_vjp(grad_outputs)
return gradient_vector.unsqueeze(0)
gradients = get_vjp(grad_outputs)
return tuple(gradient.unsqueeze(0) for gradient in gradients)
else:
return torch.vmap(get_vjp, chunk_size=chunk_size)(jac_outputs_chunk)


def _extract_sub_matrices(matrix: Tensor, lengths: Sequence[int]) -> list[Tensor]:
cumulative_lengths = [*accumulate(lengths)]
start_indices = [0] + cumulative_lengths[:-1]
end_indices = cumulative_lengths
return [matrix[:, start:end] for start, end in zip(start_indices, end_indices)]


def _reshape_matrices(matrices: Sequence[Tensor], shapes: Sequence[Size]) -> Sequence[Tensor]:
return [matrix.view((matrix.shape[0],) + shape) for matrix, shape in zip(matrices, shapes)]
Loading