From c83d55352b3da431f4f7650993b304754ab358f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 16 Jan 2026 19:04:34 +0100 Subject: [PATCH 1/5] test: Add profiler --- .gitignore | 3 + tests/{speed => profiling}/__init__.py | 0 tests/profiling/run_profiler.py | 153 ++++++++++++++++++ .../speed_grad_vs_jac_vs_gram.py} | 38 ++++- tests/speed/autogram/__init__.py | 0 tests/speed/utils.py | 29 ---- tests/utils/forward_backwards.py | 2 +- 7 files changed, 189 insertions(+), 36 deletions(-) rename tests/{speed => profiling}/__init__.py (100%) create mode 100644 tests/profiling/run_profiler.py rename tests/{speed/autogram/grad_vs_jac_vs_gram.py => profiling/speed_grad_vs_jac_vs_gram.py} (84%) delete mode 100644 tests/speed/autogram/__init__.py delete mode 100644 tests/speed/utils.py diff --git a/.gitignore b/.gitignore index 902e607c..38a5d34e 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ +# Profiling results +traces/ + # uv uv.lock diff --git a/tests/speed/__init__.py b/tests/profiling/__init__.py similarity index 100% rename from tests/speed/__init__.py rename to tests/profiling/__init__.py diff --git a/tests/profiling/run_profiler.py b/tests/profiling/run_profiler.py new file mode 100644 index 00000000..b5083f03 --- /dev/null +++ b/tests/profiling/run_profiler.py @@ -0,0 +1,153 @@ +import gc +from pathlib import Path + +import torch +from settings import DEVICE +from torch.profiler import ProfilerActivity, profile +from utils.architectures import ( + AlexNet, + Cifar10Model, + GroupNormMobileNetV3Small, + InstanceNormMobileNetV2, + InstanceNormResNet18, + ModuleFactory, + SqueezeNet, + WithTransformerLarge, +) +from utils.forward_backwards import ( + autogram_forward_backward, + autojac_forward_backward, + make_mse_loss_fn, +) +from utils.tensors import make_inputs_and_targets + +from torchjd.aggregation import Mean +from torchjd.aggregation._mean import MeanWeighting +from torchjd.autogram import Engine + +PARAMETRIZATIONS = [ + (ModuleFactory(WithTransformerLarge), 4), + (ModuleFactory(Cifar10Model), 64), + (ModuleFactory(AlexNet), 4), + (ModuleFactory(InstanceNormResNet18), 4), + (ModuleFactory(GroupNormMobileNetV3Small), 8), + (ModuleFactory(SqueezeNet), 4), + (ModuleFactory(InstanceNormMobileNetV2), 2), +] + + +def main(): + for factory, batch_size in PARAMETRIZATIONS: + profile_autojac(factory, batch_size) + print("\n" + "=" * 80 + "\n") + profile_autogram(factory, batch_size) + print("\n" + "=" * 80 + "\n") + + +def profile_autojac(factory: ModuleFactory, batch_size: int) -> None: + """ + Profiles memory and computation time of autojac forward and backward pass for a given + architecture. + + Prints the result and saves it in the traces folder. The saved traces be viewed using chrome at + chrome://tracing. + + :param factory: A ModuleFactory that creates the model to profile. + :param batch_size: The batch size to use for profiling. + """ + + print(f"autojac: {factory} with batch_size={batch_size} on {DEVICE}:") + + gc.collect() + torch.cuda.empty_cache() if torch.cuda.is_available() else None + + model = factory() + inputs, targets = make_inputs_and_targets(model, batch_size) + loss_fn = make_mse_loss_fn(targets) + aggregator = Mean() + + activities = [ProfilerActivity.CPU] + if DEVICE.type == "cuda": + activities.append(ProfilerActivity.CUDA) + + # Warmup run + autojac_forward_backward(model, inputs, loss_fn, aggregator) + gc.collect() + torch.cuda.empty_cache() if torch.cuda.is_available() else None + + # Profiled run + model.zero_grad() + with profile( + activities=activities, + profile_memory=True, + record_shapes=True, + with_stack=True, + ) as prof: + autojac_forward_backward(model, inputs, loss_fn, aggregator) + + filename = f"{factory}-bs{batch_size}-{DEVICE.type}.json" + torchjd_dir = Path(__file__).parent.parent.parent + traces_dir = torchjd_dir / "traces" / "autojac" + traces_dir.mkdir(parents=True, exist_ok=True) + trace_path = traces_dir / filename + + prof.export_chrome_trace(str(trace_path)) + print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=20)) + + +def profile_autogram(factory: ModuleFactory, batch_size: int) -> None: + """ + Profiles memory and computation time of autogram forward and backward pass for a given + architecture. + + Prints the result and saves it in the traces folder. The saved traces be viewed using chrome at + chrome://tracing. + + :param factory: A ModuleFactory that creates the model to profile. + :param batch_size: The batch size to use for profiling. + """ + + print(f"autogram: {factory} with batch_size={batch_size} on {DEVICE}:") + + gc.collect() + torch.cuda.empty_cache() if torch.cuda.is_available() else None + + model = factory() + inputs, targets = make_inputs_and_targets(model, batch_size) + loss_fn = make_mse_loss_fn(targets) + engine = Engine(model, batch_dim=0) + weighting = MeanWeighting() + + activities = [ProfilerActivity.CPU] + if DEVICE.type == "cuda": + activities.append(ProfilerActivity.CUDA) + + # Warmup run + autogram_forward_backward(model, inputs, loss_fn, engine, weighting) + gc.collect() + torch.cuda.empty_cache() if torch.cuda.is_available() else None + + # Profiled run + model.zero_grad() + with profile( + activities=activities, + profile_memory=True, + record_shapes=True, + with_stack=True, + ) as prof: + autogram_forward_backward(model, inputs, loss_fn, engine, weighting) + + filename = f"{factory}-bs{batch_size}-{DEVICE.type}.json" + torchjd_dir = Path(__file__).parent.parent.parent + traces_dir = torchjd_dir / "traces" / "autogram" + traces_dir.mkdir(parents=True, exist_ok=True) + trace_path = traces_dir / filename + + prof.export_chrome_trace(str(trace_path)) + print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=20)) + + +if __name__ == "__main__": + # To test this on cuda, add the following environment variables when running this: + # CUBLAS_WORKSPACE_CONFIG=:4096:8;PYTEST_TORCH_DEVICE=cuda:0 + main() diff --git a/tests/speed/autogram/grad_vs_jac_vs_gram.py b/tests/profiling/speed_grad_vs_jac_vs_gram.py similarity index 84% rename from tests/speed/autogram/grad_vs_jac_vs_gram.py rename to tests/profiling/speed_grad_vs_jac_vs_gram.py index e6d9233a..8bf490cc 100644 --- a/tests/speed/autogram/grad_vs_jac_vs_gram.py +++ b/tests/profiling/speed_grad_vs_jac_vs_gram.py @@ -1,7 +1,9 @@ import gc +import time import torch from settings import DEVICE +from torch import Tensor from utils.architectures import ( AlexNet, Cifar10Model, @@ -23,7 +25,6 @@ ) from utils.tensors import make_inputs_and_targets -from tests.speed.utils import print_times, time_call from torchjd.aggregation import Mean from torchjd.autogram import Engine @@ -40,6 +41,12 @@ ] +def main(): + for factory, batch_size in PARAMETRIZATIONS: + compare_autograd_autojac_and_autogram_speed(factory, batch_size) + print("\n") + + def compare_autograd_autojac_and_autogram_speed(factory: ModuleFactory, batch_size: int): model = factory() inputs, targets = make_inputs_and_targets(model, batch_size) @@ -85,7 +92,7 @@ def init_fn_autogram(): fn_autogram() def optionally_cuda_sync(): - if str(DEVICE).startswith("cuda"): + if DEVICE.type == "cuda": torch.cuda.synchronize() def pre_fn(): @@ -112,10 +119,29 @@ def post_fn(): print_times("autogram", autogram_times) -def main(): - for factory, batch_size in PARAMETRIZATIONS: - compare_autograd_autojac_and_autogram_speed(factory, batch_size) - print("\n") +def noop(): + pass + + +def time_call(fn, init_fn=noop, pre_fn=noop, post_fn=noop, n_runs: int = 10) -> Tensor: + init_fn() + + times = [] + for _ in range(n_runs): + pre_fn() + start = time.perf_counter() + fn() + post_fn() + elapsed_time = time.perf_counter() - start + times.append(elapsed_time) + + return torch.tensor(times) + + +def print_times(name: str, times: Tensor) -> None: + print(f"{name} times (avg = {times.mean():.5f}, std = {times.std():.5f}") + print(times) + print() if __name__ == "__main__": diff --git a/tests/speed/autogram/__init__.py b/tests/speed/autogram/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/speed/utils.py b/tests/speed/utils.py deleted file mode 100644 index 11e87334..00000000 --- a/tests/speed/utils.py +++ /dev/null @@ -1,29 +0,0 @@ -import time - -import torch -from torch import Tensor - - -def noop(): - pass - - -def time_call(fn, init_fn=noop, pre_fn=noop, post_fn=noop, n_runs: int = 10) -> Tensor: - init_fn() - - times = [] - for _ in range(n_runs): - pre_fn() - start = time.perf_counter() - fn() - post_fn() - elapsed_time = time.perf_counter() - start - times.append(elapsed_time) - - return torch.tensor(times) - - -def print_times(name: str, times: Tensor) -> None: - print(f"{name} times (avg = {times.mean():.5f}, std = {times.std():.5f}") - print(times) - print() diff --git a/tests/utils/forward_backwards.py b/tests/utils/forward_backwards.py index b0451e16..970d70bc 100644 --- a/tests/utils/forward_backwards.py +++ b/tests/utils/forward_backwards.py @@ -31,7 +31,7 @@ def autojac_forward_backward( ) -> None: losses = forward_pass(model, inputs, loss_fn, reduce_to_vector) backward(losses) - jac_to_grad(model.parameters(), aggregator) + jac_to_grad(list(model.parameters()), aggregator) def autograd_gramian_forward_backward( From 5c32c1fb4f34755d78306455361faf03bab879e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 19 Jan 2026 22:36:12 +0100 Subject: [PATCH 2/5] Factorize code --- tests/profiling/run_profiler.py | 123 ++++++++++++++------------------ 1 file changed, 52 insertions(+), 71 deletions(-) diff --git a/tests/profiling/run_profiler.py b/tests/profiling/run_profiler.py index b5083f03..400bb36b 100644 --- a/tests/profiling/run_profiler.py +++ b/tests/profiling/run_profiler.py @@ -1,5 +1,6 @@ import gc from pathlib import Path +from typing import Callable import torch from settings import DEVICE @@ -36,110 +37,65 @@ ] -def main(): - for factory, batch_size in PARAMETRIZATIONS: - profile_autojac(factory, batch_size) - print("\n" + "=" * 80 + "\n") - profile_autogram(factory, batch_size) - print("\n" + "=" * 80 + "\n") - - -def profile_autojac(factory: ModuleFactory, batch_size: int) -> None: +def profile_method( + method_name: str, + forward_backward_fn: Callable, + factory: ModuleFactory, + batch_size: int, +) -> None: """ - Profiles memory and computation time of autojac forward and backward pass for a given - architecture. - - Prints the result and saves it in the traces folder. The saved traces be viewed using chrome at - chrome://tracing. + Profiles memory and computation time of a forward and backward pass. + :param method_name: Name of the method being profiled (for output paths) + :param forward_backward_fn: Function to execute the forward and backward pass. :param factory: A ModuleFactory that creates the model to profile. :param batch_size: The batch size to use for profiling. """ + print(f"{method_name}: {factory} with batch_size={batch_size} on {DEVICE}:") - print(f"autojac: {factory} with batch_size={batch_size} on {DEVICE}:") - - gc.collect() - torch.cuda.empty_cache() if torch.cuda.is_available() else None - + _clear_unused_memory() model = factory() inputs, targets = make_inputs_and_targets(model, batch_size) loss_fn = make_mse_loss_fn(targets) - aggregator = Mean() - activities = [ProfilerActivity.CPU] - if DEVICE.type == "cuda": - activities.append(ProfilerActivity.CUDA) + activities = _get_profiler_activities() # Warmup run - autojac_forward_backward(model, inputs, loss_fn, aggregator) - gc.collect() - torch.cuda.empty_cache() if torch.cuda.is_available() else None + forward_backward_fn(model, inputs, loss_fn) + model.zero_grad() + _clear_unused_memory() # Profiled run - model.zero_grad() with profile( activities=activities, profile_memory=True, record_shapes=True, with_stack=True, ) as prof: - autojac_forward_backward(model, inputs, loss_fn, aggregator) + forward_backward_fn(model, inputs, loss_fn) - filename = f"{factory}-bs{batch_size}-{DEVICE.type}.json" - torchjd_dir = Path(__file__).parent.parent.parent - traces_dir = torchjd_dir / "traces" / "autojac" - traces_dir.mkdir(parents=True, exist_ok=True) - trace_path = traces_dir / filename - - prof.export_chrome_trace(str(trace_path)) - print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=20)) + _save_and_print_trace(prof, method_name, factory, batch_size) -def profile_autogram(factory: ModuleFactory, batch_size: int) -> None: - """ - Profiles memory and computation time of autogram forward and backward pass for a given - architecture. - - Prints the result and saves it in the traces folder. The saved traces be viewed using chrome at - chrome://tracing. - - :param factory: A ModuleFactory that creates the model to profile. - :param batch_size: The batch size to use for profiling. - """ - - print(f"autogram: {factory} with batch_size={batch_size} on {DEVICE}:") - +def _clear_unused_memory() -> None: gc.collect() - torch.cuda.empty_cache() if torch.cuda.is_available() else None + if torch.cuda.is_available(): + torch.cuda.empty_cache() - model = factory() - inputs, targets = make_inputs_and_targets(model, batch_size) - loss_fn = make_mse_loss_fn(targets) - engine = Engine(model, batch_dim=0) - weighting = MeanWeighting() +def _get_profiler_activities() -> list[ProfilerActivity]: activities = [ProfilerActivity.CPU] if DEVICE.type == "cuda": activities.append(ProfilerActivity.CUDA) + return activities - # Warmup run - autogram_forward_backward(model, inputs, loss_fn, engine, weighting) - gc.collect() - torch.cuda.empty_cache() if torch.cuda.is_available() else None - - # Profiled run - model.zero_grad() - with profile( - activities=activities, - profile_memory=True, - record_shapes=True, - with_stack=True, - ) as prof: - autogram_forward_backward(model, inputs, loss_fn, engine, weighting) +def _save_and_print_trace( + prof: profile, method_name: str, factory: ModuleFactory, batch_size: int +) -> None: filename = f"{factory}-bs{batch_size}-{DEVICE.type}.json" torchjd_dir = Path(__file__).parent.parent.parent - traces_dir = torchjd_dir / "traces" / "autogram" + traces_dir = torchjd_dir / "traces" / method_name traces_dir.mkdir(parents=True, exist_ok=True) trace_path = traces_dir / filename @@ -147,6 +103,31 @@ def profile_autogram(factory: ModuleFactory, batch_size: int) -> None: print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=20)) +def profile_autojac(factory: ModuleFactory, batch_size: int) -> None: + def forward_backward_fn(model, inputs, loss_fn): + aggregator = Mean() + autojac_forward_backward(model, inputs, loss_fn, aggregator) + + profile_method("autojac", forward_backward_fn, factory, batch_size) + + +def profile_autogram(factory: ModuleFactory, batch_size: int) -> None: + def forward_backward_fn(model, inputs, loss_fn): + engine = Engine(model, batch_dim=0) + weighting = MeanWeighting() + autogram_forward_backward(model, inputs, loss_fn, engine, weighting) + + profile_method("autogram", forward_backward_fn, factory, batch_size) + + +def main(): + for factory, batch_size in PARAMETRIZATIONS: + profile_autojac(factory, batch_size) + print("\n" + "=" * 80 + "\n") + profile_autogram(factory, batch_size) + print("\n" + "=" * 80 + "\n") + + if __name__ == "__main__": # To test this on cuda, add the following environment variables when running this: # CUBLAS_WORKSPACE_CONFIG=:4096:8;PYTEST_TORCH_DEVICE=cuda:0 From f8c67bf0b7b9391cc7326144e0a092cb72d04489 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 19 Jan 2026 22:50:51 +0100 Subject: [PATCH 3/5] Stop recording shapes --- tests/profiling/run_profiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/profiling/run_profiler.py b/tests/profiling/run_profiler.py index 400bb36b..097c6e15 100644 --- a/tests/profiling/run_profiler.py +++ b/tests/profiling/run_profiler.py @@ -69,7 +69,7 @@ def profile_method( with profile( activities=activities, profile_memory=True, - record_shapes=True, + record_shapes=False, # Otherwise some tensors may be referenced longer than normal with_stack=True, ) as prof: forward_backward_fn(model, inputs, loss_fn) From 38ce00e5da0267f55f81282398fd0a91479dd49e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 19 Jan 2026 22:53:40 +0100 Subject: [PATCH 4/5] Use UPGrad for more realism --- tests/profiling/run_profiler.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/profiling/run_profiler.py b/tests/profiling/run_profiler.py index 097c6e15..e041f994 100644 --- a/tests/profiling/run_profiler.py +++ b/tests/profiling/run_profiler.py @@ -22,8 +22,7 @@ ) from utils.tensors import make_inputs_and_targets -from torchjd.aggregation import Mean -from torchjd.aggregation._mean import MeanWeighting +from torchjd.aggregation import UPGrad, UPGradWeighting from torchjd.autogram import Engine PARAMETRIZATIONS = [ @@ -105,7 +104,7 @@ def _save_and_print_trace( def profile_autojac(factory: ModuleFactory, batch_size: int) -> None: def forward_backward_fn(model, inputs, loss_fn): - aggregator = Mean() + aggregator = UPGrad() autojac_forward_backward(model, inputs, loss_fn, aggregator) profile_method("autojac", forward_backward_fn, factory, batch_size) @@ -114,7 +113,7 @@ def forward_backward_fn(model, inputs, loss_fn): def profile_autogram(factory: ModuleFactory, batch_size: int) -> None: def forward_backward_fn(model, inputs, loss_fn): engine = Engine(model, batch_dim=0) - weighting = MeanWeighting() + weighting = UPGradWeighting() autogram_forward_backward(model, inputs, loss_fn, engine, weighting) profile_method("autogram", forward_backward_fn, factory, batch_size) From 4910f73055cb41b798b3c6c88d45605d5875df3a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 19 Jan 2026 22:54:44 +0100 Subject: [PATCH 5/5] Add missing parenthesis --- tests/profiling/speed_grad_vs_jac_vs_gram.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/profiling/speed_grad_vs_jac_vs_gram.py b/tests/profiling/speed_grad_vs_jac_vs_gram.py index 8bf490cc..16be875e 100644 --- a/tests/profiling/speed_grad_vs_jac_vs_gram.py +++ b/tests/profiling/speed_grad_vs_jac_vs_gram.py @@ -139,7 +139,7 @@ def time_call(fn, init_fn=noop, pre_fn=noop, post_fn=noop, n_runs: int = 10) -> def print_times(name: str, times: Tensor) -> None: - print(f"{name} times (avg = {times.mean():.5f}, std = {times.std():.5f}") + print(f"{name} times (avg = {times.mean():.5f}, std = {times.std():.5f})") print(times) print()