Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Profiling results
traces/

# uv
uv.lock

Expand Down
File renamed without changes.
133 changes: 133 additions & 0 deletions tests/profiling/run_profiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import gc
from pathlib import Path
from typing import Callable

import torch
from settings import DEVICE
from torch.profiler import ProfilerActivity, profile
from utils.architectures import (
AlexNet,
Cifar10Model,
GroupNormMobileNetV3Small,
InstanceNormMobileNetV2,
InstanceNormResNet18,
ModuleFactory,
SqueezeNet,
WithTransformerLarge,
)
from utils.forward_backwards import (
autogram_forward_backward,
autojac_forward_backward,
make_mse_loss_fn,
)
from utils.tensors import make_inputs_and_targets

from torchjd.aggregation import UPGrad, UPGradWeighting
from torchjd.autogram import Engine

# (ModuleFactory, batch_size) pairs defining the workloads to profile.
# Batch sizes differ per architecture — presumably tuned so each profiled
# run fits in memory; confirm before changing them.
PARAMETRIZATIONS = [
    (ModuleFactory(WithTransformerLarge), 4),
    (ModuleFactory(Cifar10Model), 64),
    (ModuleFactory(AlexNet), 4),
    (ModuleFactory(InstanceNormResNet18), 4),
    (ModuleFactory(GroupNormMobileNetV3Small), 8),
    (ModuleFactory(SqueezeNet), 4),
    (ModuleFactory(InstanceNormMobileNetV2), 2),
]


def profile_method(
    method_name: str,
    forward_backward_fn: Callable,
    factory: ModuleFactory,
    batch_size: int,
) -> None:
    """
    Profiles memory and computation time of a forward and backward pass.

    :param method_name: Name of the method being profiled (for output paths)
    :param forward_backward_fn: Function to execute the forward and backward pass.
    :param factory: A ModuleFactory that creates the model to profile.
    :param batch_size: The batch size to use for profiling.
    """
    print(f"{method_name}: {factory} with batch_size={batch_size} on {DEVICE}:")

    _clear_unused_memory()
    model = factory()
    inputs, targets = make_inputs_and_targets(model, batch_size)
    loss_fn = make_mse_loss_fn(targets)

    # Warmup run, so that one-time initialization costs do not end up in the
    # profiled measurements.
    forward_backward_fn(model, inputs, loss_fn)
    model.zero_grad()
    _clear_unused_memory()

    # Measured run, recorded by the torch profiler.
    with profile(
        activities=_get_profiler_activities(),
        profile_memory=True,
        record_shapes=False,  # Otherwise some tensors may be referenced longer than normal
        with_stack=True,
    ) as prof:
        forward_backward_fn(model, inputs, loss_fn)

    _save_and_print_trace(prof, method_name, factory, batch_size)


def _clear_unused_memory() -> None:
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()


def _get_profiler_activities() -> list[ProfilerActivity]:
    """Return the activities to record: CPU always, plus CUDA on a CUDA device."""
    if DEVICE.type == "cuda":
        return [ProfilerActivity.CPU, ProfilerActivity.CUDA]
    return [ProfilerActivity.CPU]


def _save_and_print_trace(
    prof: profile, method_name: str, factory: ModuleFactory, batch_size: int
) -> None:
    """Export the trace as a Chrome trace file and print a memory-sorted summary table."""
    # Traces go to <repo root>/traces/<method_name>/, created on demand.
    repo_root = Path(__file__).parent.parent.parent
    traces_dir = repo_root / "traces" / method_name
    traces_dir.mkdir(parents=True, exist_ok=True)
    trace_path = traces_dir / f"{factory}-bs{batch_size}-{DEVICE.type}.json"

    prof.export_chrome_trace(str(trace_path))
    print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=20))


def profile_autojac(factory: ModuleFactory, batch_size: int) -> None:
    """Profile a forward-backward pass performed via autojac with a UPGrad aggregator."""

    def run(model, inputs, loss_fn):
        autojac_forward_backward(model, inputs, loss_fn, UPGrad())

    profile_method("autojac", run, factory, batch_size)


def profile_autogram(factory: ModuleFactory, batch_size: int) -> None:
    """Profile a forward-backward pass performed via the autogram engine with UPGrad weighting."""

    def run(model, inputs, loss_fn):
        engine = Engine(model, batch_dim=0)
        autogram_forward_backward(model, inputs, loss_fn, engine, UPGradWeighting())

    profile_method("autogram", run, factory, batch_size)


def main():
    """Profile every parametrization with both autojac and autogram."""
    separator = "\n" + "=" * 80 + "\n"
    for factory, batch_size in PARAMETRIZATIONS:
        profile_autojac(factory, batch_size)
        print(separator)
        profile_autogram(factory, batch_size)
        print(separator)


if __name__ == "__main__":
    # To run this on CUDA, set the following environment variables:
    # CUBLAS_WORKSPACE_CONFIG=:4096:8;PYTEST_TORCH_DEVICE=cuda:0
    main()
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import gc
import time

import torch
from settings import DEVICE
from torch import Tensor
from utils.architectures import (
AlexNet,
Cifar10Model,
Expand All @@ -23,7 +25,6 @@
)
from utils.tensors import make_inputs_and_targets

from tests.speed.utils import print_times, time_call
from torchjd.aggregation import Mean
from torchjd.autogram import Engine

Expand All @@ -40,6 +41,12 @@
]


def main():
    """Run the speed comparison for every (factory, batch size) parametrization."""
    for factory, batch_size in PARAMETRIZATIONS:
        compare_autograd_autojac_and_autogram_speed(factory, batch_size)
        print("\n")

def compare_autograd_autojac_and_autogram_speed(factory: ModuleFactory, batch_size: int):
model = factory()
inputs, targets = make_inputs_and_targets(model, batch_size)
Expand Down Expand Up @@ -85,7 +92,7 @@ def init_fn_autogram():
fn_autogram()

def optionally_cuda_sync():
if str(DEVICE).startswith("cuda"):
if DEVICE.type == "cuda":
torch.cuda.synchronize()

def pre_fn():
Expand All @@ -112,10 +119,29 @@ def post_fn():
print_times("autogram", autogram_times)


def main():
for factory, batch_size in PARAMETRIZATIONS:
compare_autograd_autojac_and_autogram_speed(factory, batch_size)
print("\n")
def noop():
    """Do nothing; default hook used by time_call for init/pre/post callbacks."""
    return None


def time_call(fn, init_fn=noop, pre_fn=noop, post_fn=noop, n_runs: int = 10) -> Tensor:
    """
    Time n_runs calls of fn and return the elapsed times as a 1-D tensor.

    init_fn runs once before any measurement. pre_fn runs before each timed
    call, outside the timing window; post_fn runs after each call, inside the
    window (e.g. a CUDA synchronize, so the measurement covers the work).
    """
    init_fn()

    def timed_once() -> float:
        pre_fn()
        start = time.perf_counter()
        fn()
        post_fn()
        return time.perf_counter() - start

    return torch.tensor([timed_once() for _ in range(n_runs)])


def print_times(name: str, times: Tensor) -> None:
    """Print the mean and standard deviation of times, then the raw values and a blank line."""
    mean = times.mean()
    std = times.std()
    print(f"{name} times (avg = {mean:.5f}, std = {std:.5f})")
    print(times)
    print()


if __name__ == "__main__":
Expand Down
Empty file removed tests/speed/autogram/__init__.py
Empty file.
29 changes: 0 additions & 29 deletions tests/speed/utils.py

This file was deleted.

2 changes: 1 addition & 1 deletion tests/utils/forward_backwards.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def autojac_forward_backward(
) -> None:
losses = forward_pass(model, inputs, loss_fn, reduce_to_vector)
backward(losses)
jac_to_grad(model.parameters(), aggregator)
jac_to_grad(list(model.parameters()), aggregator)


def autograd_gramian_forward_backward(
Expand Down
Loading