From 891fb9e0bb2b15606552adc6fd00bcb1ce8b1dff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 25 Nov 2025 19:17:43 +0100 Subject: [PATCH] test: Extract timing utilities into `utils.py` --- tests/speed/autogram/grad_vs_jac_vs_gram.py | 50 +++++---------------- tests/speed/utils.py | 29 ++++++++++++ 2 files changed, 39 insertions(+), 40 deletions(-) create mode 100644 tests/speed/utils.py diff --git a/tests/speed/autogram/grad_vs_jac_vs_gram.py b/tests/speed/autogram/grad_vs_jac_vs_gram.py index 7a6556f2..64d33911 100644 --- a/tests/speed/autogram/grad_vs_jac_vs_gram.py +++ b/tests/speed/autogram/grad_vs_jac_vs_gram.py @@ -1,5 +1,4 @@ import gc -import time import torch from device import DEVICE @@ -24,6 +23,7 @@ ) from utils.tensors import make_inputs_and_targets +from tests.speed.utils import print_times, time_call from torchjd.aggregation import Mean from torchjd.autogram import Engine @@ -96,50 +96,20 @@ def post_fn(): optionally_cuda_sync() n_runs = 10 - autograd_times = torch.tensor(time_call(fn_autograd, init_fn_autograd, pre_fn, post_fn, n_runs)) - print(f"autograd times (avg = {autograd_times.mean():.5f}, std = {autograd_times.std():.5f}") - print(autograd_times) - print() + autograd_times = time_call(fn_autograd, init_fn_autograd, pre_fn, post_fn, n_runs) + print_times("autograd", autograd_times) - autograd_gramian_times = torch.tensor( - time_call(fn_autograd_gramian, init_fn_autograd_gramian, pre_fn, post_fn, n_runs) + autograd_gramian_times = time_call( + fn_autograd_gramian, init_fn_autograd_gramian, pre_fn, post_fn, n_runs ) - print( - f"autograd gramian times (avg = {autograd_gramian_times.mean():.5f}, std = " - f"{autograd_gramian_times.std():.5f}" - ) - print(autograd_gramian_times) - print() + print_times("autograd gramian", autograd_gramian_times) - autojac_times = torch.tensor(time_call(fn_autojac, init_fn_autojac, pre_fn, post_fn, n_runs)) - print(f"autojac times (avg = {autojac_times.mean():.5f}, std = {autojac_times.std():.5f}") - print(autojac_times) - print() + autojac_times = time_call(fn_autojac, init_fn_autojac, pre_fn, post_fn, n_runs) + print_times("autojac", autojac_times) engine = Engine(model, batch_dim=0) - autogram_times = torch.tensor(time_call(fn_autogram, init_fn_autogram, pre_fn, post_fn, n_runs)) - print(f"autogram times (avg = {autogram_times.mean():.5f}, std = {autogram_times.std():.5f}") - print(autogram_times) - print() - - -def noop(): - pass - - -def time_call(fn, init_fn=noop, pre_fn=noop, post_fn=noop, n_runs: int = 10) -> list[float]: - init_fn() - - times = [] - for _ in range(n_runs): - pre_fn() - start = time.perf_counter() - fn() - post_fn() - elapsed_time = time.perf_counter() - start - times.append(elapsed_time) - - return times + autogram_times = time_call(fn_autogram, init_fn_autogram, pre_fn, post_fn, n_runs) + print_times("autogram", autogram_times) def main(): diff --git a/tests/speed/utils.py b/tests/speed/utils.py new file mode 100644 index 00000000..11e87334 --- /dev/null +++ b/tests/speed/utils.py @@ -0,0 +1,29 @@ +import time + +import torch +from torch import Tensor + + +def noop(): + pass + + +def time_call(fn, init_fn=noop, pre_fn=noop, post_fn=noop, n_runs: int = 10) -> Tensor: + init_fn() + + times = [] + for _ in range(n_runs): + pre_fn() + start = time.perf_counter() + fn() + post_fn() + elapsed_time = time.perf_counter() - start + times.append(elapsed_time) + + return torch.tensor(times) + + +def print_times(name: str, times: Tensor) -> None: + print(f"{name} times (avg = {times.mean():.5f}, std = {times.std():.5f}") + print(times) + print()