From b2f41ea886115cc7e87cb4d6f0e3afc11a25472e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Val=C3=A9rian=20Rey?=
Date: Sat, 13 Sep 2025 17:23:14 +0200
Subject: [PATCH 1/3] Stop reshaping Jacobians in Jac

---
 src/torchjd/autojac/_transform/_jac.py | 48 +++++++++-----------------
 1 file changed, 16 insertions(+), 32 deletions(-)

diff --git a/src/torchjd/autojac/_transform/_jac.py b/src/torchjd/autojac/_transform/_jac.py
index fbed6ea0..358ef90c 100644
--- a/src/torchjd/autojac/_transform/_jac.py
+++ b/src/torchjd/autojac/_transform/_jac.py
@@ -1,10 +1,9 @@
 import math
 from collections.abc import Callable, Sequence
 from functools import partial
-from itertools import accumulate

 import torch
-from torch import Size, Tensor
+from torch import Tensor

 from ._differentiate import Differentiate
 from ._materialize import materialize
@@ -69,7 +68,7 @@ def _differentiate(self, jac_outputs: Sequence[Tensor]) -> tuple[Tensor, ...]:
             ]
         )

-        def _get_vjp(grad_outputs: Sequence[Tensor], retain_graph: bool) -> Tensor:
+        def _get_vjp(grad_outputs: Sequence[Tensor], retain_graph: bool) -> tuple[Tensor, ...]:
             optional_grads = torch.autograd.grad(
                 self.outputs,
                 self.inputs,
@@ -79,15 +78,16 @@ def _get_vjp(grad_outputs: Sequence[Tensor], retain_graph: bool) -> Tensor:
                 allow_unused=True,
             )
             grads = materialize(optional_grads, inputs=self.inputs)
-            return torch.concatenate([grad.reshape([-1]) for grad in grads])
+            return grads

         # If the jac_outputs are correct, this value should be the same for all jac_outputs.
         m = jac_outputs[0].shape[0]

         max_chunk_size = self.chunk_size if self.chunk_size is not None else m
         n_chunks = math.ceil(m / max_chunk_size)

-        # List of tensors of shape [k_i, n] where the k_i's sum to m
-        jac_matrix_chunks = []
+        # One tuple per chunk (i), with one tensor per input (j), of shape [k_i] + shape[j],
+        # where k_i is the number of rows in the chunk (the k_i's sum to m)
+        jacs_chunks: list[tuple[Tensor, ...]] = []

         # First differentiations: always retain graph
         get_vjp_retain = partial(_get_vjp, retain_graph=True)
@@ -95,27 +95,22 @@ def _get_vjp(grad_outputs: Sequence[Tensor], retain_graph: bool) -> Tensor:
         for i in range(n_chunks - 1):
             start = i * max_chunk_size
             end = (i + 1) * max_chunk_size
             jac_outputs_chunk = [jac_output[start:end] for jac_output in jac_outputs]
-            jac_matrix_chunks.append(_get_jac_matrix_chunk(jac_outputs_chunk, get_vjp_retain))
+            jacs_chunks.append(_get_jacs_chunk(jac_outputs_chunk, get_vjp_retain))

         # Last differentiation: retain the graph only if self.retain_graph==True
-        get_vjp_last = partial(_get_vjp, retain_graph=self.retain_graph)
+        get_vjp_last = partial(_get_vjp, retain_graph=self.retain_graph)
         start = (n_chunks - 1) * max_chunk_size
         jac_outputs_chunk = [jac_output[start:] for jac_output in jac_outputs]
-        jac_matrix_chunks.append(_get_jac_matrix_chunk(jac_outputs_chunk, get_vjp_last))
+        jacs_chunks.append(_get_jacs_chunk(jac_outputs_chunk, get_vjp_last))

-        jac_matrix = torch.vstack(jac_matrix_chunks)
-        lengths = [input.numel() for input in self.inputs]
-        jac_matrices = _extract_sub_matrices(jac_matrix, lengths)
+        n_inputs = len(self.inputs)
+        jacs = tuple(torch.cat([chunks[i] for chunks in jacs_chunks]) for i in range(n_inputs))
+        return jacs

-        shapes = [input.shape for input in self.inputs]
-        jacs = _reshape_matrices(jac_matrices, shapes)
-        return tuple(jacs)
-
-
-def _get_jac_matrix_chunk(
-    jac_outputs_chunk: list[Tensor], get_vjp: Callable[[Sequence[Tensor]], Tensor]
-) -> Tensor:
+
+def _get_jacs_chunk(
+    jac_outputs_chunk: list[Tensor], get_vjp: Callable[[Sequence[Tensor]], tuple[Tensor, ...]]
+) -> tuple[Tensor, ...]:
     """
     Computes the chunk of Jacobians
     corresponding to the provided get_vjp function, either by calling get_vjp directly or by
     wrapping it into a call to ``torch.vmap``, depending on the shape
@@ -126,18 +121,7 @@
     chunk_size = jac_outputs_chunk[0].shape[0]
     if chunk_size == 1:
         grad_outputs = [tensor.squeeze(0) for tensor in jac_outputs_chunk]
-        gradient_vector = get_vjp(grad_outputs)
-        return gradient_vector.unsqueeze(0)
+        gradients = get_vjp(grad_outputs)
+        return tuple(gradient.unsqueeze(0) for gradient in gradients)
     else:
         return torch.vmap(get_vjp, chunk_size=chunk_size)(jac_outputs_chunk)
-
-
-def _extract_sub_matrices(matrix: Tensor, lengths: Sequence[int]) -> list[Tensor]:
-    cumulative_lengths = [*accumulate(lengths)]
-    start_indices = [0] + cumulative_lengths[:-1]
-    end_indices = cumulative_lengths
-    return [matrix[:, start:end] for start, end in zip(start_indices, end_indices)]
-
-
-def _reshape_matrices(matrices: Sequence[Tensor], shapes: Sequence[Size]) -> Sequence[Tensor]:
-    return [matrix.view((matrix.shape[0],) + shape) for matrix, shape in zip(matrices, shapes)]

From 2322a86c4da752584d0e173a03a29bf9f5418c63 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Val=C3=A9rian=20Rey?=
Date: Sat, 13 Sep 2025 17:27:34 +0200
Subject: [PATCH 2/3] Factorize _get_vjp between Grad and Jac

---
 .../autojac/_transform/_differentiate.py | 14 ++++++++++++++
 src/torchjd/autojac/_transform/_grad.py  | 11 +----------
 src/torchjd/autojac/_transform/_jac.py   | 17 ++---------------
 3 files changed, 17 insertions(+), 25 deletions(-)

diff --git a/src/torchjd/autojac/_transform/_differentiate.py b/src/torchjd/autojac/_transform/_differentiate.py
index e095e051..ddd1f064 100644
--- a/src/torchjd/autojac/_transform/_differentiate.py
+++ b/src/torchjd/autojac/_transform/_differentiate.py
@@ -1,9 +1,11 @@
 from abc import ABC, abstractmethod
 from collections.abc import Sequence

+import torch
 from torch import Tensor

 from ._base import RequirementError, TensorDict, Transform
+from ._materialize import materialize
 from ._ordered_set import OrderedSet


@@ -61,3 +63,15 @@ def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
                 f"outputs {outputs}."
             )
         return set(self.inputs)
+
+    def _get_vjp(self, grad_outputs: Sequence[Tensor], retain_graph: bool) -> tuple[Tensor, ...]:
+        optional_grads = torch.autograd.grad(
+            self.outputs,
+            self.inputs,
+            grad_outputs=grad_outputs,
+            retain_graph=retain_graph,
+            create_graph=self.create_graph,
+            allow_unused=True,
+        )
+        grads = materialize(optional_grads, inputs=self.inputs)
+        return grads
diff --git a/src/torchjd/autojac/_transform/_grad.py b/src/torchjd/autojac/_transform/_grad.py
index 88a846e4..dae694b6 100644
--- a/src/torchjd/autojac/_transform/_grad.py
+++ b/src/torchjd/autojac/_transform/_grad.py
@@ -4,7 +4,6 @@
 from torch import Tensor

 from ._differentiate import Differentiate
-from ._materialize import materialize
 from ._ordered_set import OrderedSet


@@ -54,13 +53,5 @@ def _differentiate(self, grad_outputs: Sequence[Tensor]) -> tuple[Tensor, ...]:
         if len(self.outputs) == 0:
             return tuple([torch.zeros_like(input) for input in self.inputs])

-        optional_grads = torch.autograd.grad(
-            self.outputs,
-            self.inputs,
-            grad_outputs=grad_outputs,
-            retain_graph=self.retain_graph,
-            create_graph=self.create_graph,
-            allow_unused=True,
-        )
-        grads = materialize(optional_grads, self.inputs)
+        grads = self._get_vjp(grad_outputs, self.retain_graph)
         return grads
diff --git a/src/torchjd/autojac/_transform/_jac.py b/src/torchjd/autojac/_transform/_jac.py
index 358ef90c..133bf72d 100644
--- a/src/torchjd/autojac/_transform/_jac.py
+++ b/src/torchjd/autojac/_transform/_jac.py
@@ -6,7 +6,6 @@
 from torch import Tensor

 from ._differentiate import Differentiate
-from ._materialize import materialize
 from ._ordered_set import OrderedSet


@@ -68,18 +67,6 @@ def _differentiate(self, jac_outputs: Sequence[Tensor]) -> tuple[Tensor, ...]:
             ]
         )

-        def _get_vjp(grad_outputs: Sequence[Tensor], retain_graph: bool) -> tuple[Tensor, ...]:
-            optional_grads = torch.autograd.grad(
-                self.outputs,
-                self.inputs,
-                grad_outputs=grad_outputs,
-                retain_graph=retain_graph,
-                create_graph=self.create_graph,
-                allow_unused=True,
-            )
-            grads = materialize(optional_grads, inputs=self.inputs)
-            return grads
-
         # If the jac_outputs are correct, this value should be the same for all jac_outputs.
         m = jac_outputs[0].shape[0]

         max_chunk_size = self.chunk_size if self.chunk_size is not None else m
@@ -90,7 +77,7 @@ def _get_vjp(grad_outputs: Sequence[Tensor], retain_graph: bool) -> tuple[Tensor
         jacs_chunks: list[tuple[Tensor, ...]] = []

         # First differentiations: always retain graph
-        get_vjp_retain = partial(_get_vjp, retain_graph=True)
+        get_vjp_retain = partial(self._get_vjp, retain_graph=True)
         for i in range(n_chunks - 1):
             start = i * max_chunk_size
             end = (i + 1) * max_chunk_size
@@ -98,7 +85,7 @@ def _get_vjp(grad_outputs: Sequence[Tensor], retain_graph: bool) -> tuple[Tensor
             jac_outputs_chunk = [jac_output[start:end] for jac_output in jac_outputs]
             jacs_chunks.append(_get_jacs_chunk(jac_outputs_chunk, get_vjp_retain))

         # Last differentiation: retain the graph only if self.retain_graph==True
-        get_vjp_last = partial(_get_vjp, retain_graph=self.retain_graph)
+        get_vjp_last = partial(self._get_vjp, retain_graph=self.retain_graph)
         start = (n_chunks - 1) * max_chunk_size
         jac_outputs_chunk = [jac_output[start:] for jac_output in jac_outputs]
         jacs_chunks.append(_get_jacs_chunk(jac_outputs_chunk, get_vjp_last))

From 81ff735632392ea779819551c73e987d5e984382 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Val=C3=A9rian=20Rey?=
Date: Sat, 13 Sep 2025 17:36:35 +0200
Subject: [PATCH 3/3] Add changelog entry

---
 CHANGELOG.md | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index c6d6bd37..de73934b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -34,6 +34,8 @@ changes that do not affect the user.

 ### Changed

+- Removed an unnecessary internal reshape when computing Jacobians. This should have no effect
+  other than a slight performance improvement in `autojac`.
 - Revamped documentation.
 - Made `backward` and `mtl_backward` importable from `torchjd.autojac` (as they were prior to 0.7.0).
 - Deprecated importing `backward` and `mtl_backward` from `torchjd` directly.
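
Background on the technique these patches converge on: the Jacobian of some outputs with
respect to several inputs is computed by vmapping a vector-Jacobian product (VJP) over the
rows of the jac_outputs, keeping one tensor per input, of shape [m] + input.shape, from
start to finish, instead of flattening each VJP into a row of an [m, n] matrix and
reshaping at the end. The following is a minimal standalone sketch of that idea using only
public PyTorch APIs; the function name `jac` and the example tensors are illustrative and
are not part of torchjd.

    import torch
    from torch import Tensor


    def jac(outputs: list[Tensor], inputs: list[Tensor]) -> tuple[Tensor, ...]:
        # Concatenate the outputs into a single flat vector of length m.
        flat = torch.cat([output.reshape(-1) for output in outputs])
        m = flat.numel()

        def get_vjp(grad_output: Tensor) -> tuple[Tensor, ...]:
            # One VJP: gradients of <flat, grad_output> w.r.t. every input.
            optional_grads = torch.autograd.grad(
                flat, inputs, grad_outputs=grad_output, retain_graph=True, allow_unused=True
            )
            # Materialize unused (None) gradients as zeros, in the spirit of materialize.
            return tuple(
                g if g is not None else torch.zeros_like(x)
                for g, x in zip(optional_grads, inputs)
            )

        # vmap the VJP over the rows of the identity matrix: row i yields the i-th row of
        # the Jacobian, one tensor per input, already shaped like that input.
        return torch.vmap(get_vjp)(torch.eye(m))


    # Each resulting Jacobian has shape [m] + input.shape, with no reshape needed:
    x = torch.randn(3, requires_grad=True)
    w = torch.randn(2, 3, requires_grad=True)
    y = w @ x  # shape [2], so m = 2
    jac_x, jac_w = jac([y], [x, w])
    assert jac_x.shape == (2, 3)  # [m] + x.shape; equals w here, since y = w @ x
    assert jac_w.shape == (2, 2, 3)  # [m] + w.shape

The chunk_size handling in patch 1 is orthogonal to this sketch: it slices the rows of the
jac_outputs into groups of at most chunk_size rows, runs the vmapped VJP once per group,
and concatenates the per-input results along dim 0, bounding peak memory at the cost of
more backward passes.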