Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 10 additions & 40 deletions tests/speed/autogram/grad_vs_jac_vs_gram.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import gc
import time

import torch
from device import DEVICE
Expand All @@ -24,6 +23,7 @@
)
from utils.tensors import make_inputs_and_targets

from tests.speed.utils import print_times, time_call
from torchjd.aggregation import Mean
from torchjd.autogram import Engine

Expand Down Expand Up @@ -96,50 +96,20 @@ def post_fn():
optionally_cuda_sync()

n_runs = 10
autograd_times = torch.tensor(time_call(fn_autograd, init_fn_autograd, pre_fn, post_fn, n_runs))
print(f"autograd times (avg = {autograd_times.mean():.5f}, std = {autograd_times.std():.5f}")
print(autograd_times)
print()
autograd_times = time_call(fn_autograd, init_fn_autograd, pre_fn, post_fn, n_runs)
print_times("autograd", autograd_times)

autograd_gramian_times = torch.tensor(
time_call(fn_autograd_gramian, init_fn_autograd_gramian, pre_fn, post_fn, n_runs)
autograd_gramian_times = time_call(
fn_autograd_gramian, init_fn_autograd_gramian, pre_fn, post_fn, n_runs
)
print(
f"autograd gramian times (avg = {autograd_gramian_times.mean():.5f}, std = "
f"{autograd_gramian_times.std():.5f}"
)
print(autograd_gramian_times)
print()
print_times("autograd gramian", autograd_gramian_times)

autojac_times = torch.tensor(time_call(fn_autojac, init_fn_autojac, pre_fn, post_fn, n_runs))
print(f"autojac times (avg = {autojac_times.mean():.5f}, std = {autojac_times.std():.5f}")
print(autojac_times)
print()
autojac_times = time_call(fn_autojac, init_fn_autojac, pre_fn, post_fn, n_runs)
print_times("autojac", autojac_times)

engine = Engine(model, batch_dim=0)
autogram_times = torch.tensor(time_call(fn_autogram, init_fn_autogram, pre_fn, post_fn, n_runs))
print(f"autogram times (avg = {autogram_times.mean():.5f}, std = {autogram_times.std():.5f}")
print(autogram_times)
print()


def noop():
pass


def time_call(fn, init_fn=noop, pre_fn=noop, post_fn=noop, n_runs: int = 10) -> list[float]:
init_fn()

times = []
for _ in range(n_runs):
pre_fn()
start = time.perf_counter()
fn()
post_fn()
elapsed_time = time.perf_counter() - start
times.append(elapsed_time)

return times
autogram_times = time_call(fn_autogram, init_fn_autogram, pre_fn, post_fn, n_runs)
print_times("autogram", autogram_times)


def main():
Expand Down
29 changes: 29 additions & 0 deletions tests/speed/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import time

import torch
from torch import Tensor


def noop():
pass


def time_call(fn, init_fn=noop, pre_fn=noop, post_fn=noop, n_runs: int = 10) -> Tensor:
init_fn()

times = []
for _ in range(n_runs):
pre_fn()
start = time.perf_counter()
fn()
post_fn()
elapsed_time = time.perf_counter() - start
times.append(elapsed_time)

return torch.tensor(times)


def print_times(name: str, times: Tensor) -> None:
print(f"{name} times (avg = {times.mean():.5f}, std = {times.std():.5f}")
print(times)
print()
Loading