
test: Add profiler #519

Merged
ValerianRey merged 6 commits into main from add-memory-profiler on Jan 19, 2026

Conversation

Contributor

@ValerianRey ValerianRey commented Jan 16, 2026

Not very clean yet (and I'm not sure I'll make it very clean), but this allows:

  • Profiling of autojac and autogram
  • Change of device (cpu or cuda) is supported
  • Various architectures / batch sizes are tested

Changes:

  • Add traces/ to .gitignore
  • Add test.profiling package
  • Move speed tests to test.profiling
  • Add run_profiler in test.profiling
  • Fix call to jac_to_grad in autojac_forward_backward to give it directly the list of model params (instead of a generator)
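The jac_to_grad fix in the last bullet is worth a short illustration: `model.parameters()` returns a generator, which is exhausted after a single iteration, so any code that walks the parameters more than once silently sees an empty sequence the second time. A minimal stand-alone sketch (hypothetical names, no torch dependency) of the failure mode:

```python
# Sketch: why passing a generator instead of a list can silently break
# code that iterates over "parameters" more than once.

def consume_twice(params):
    # First pass works; second pass sees nothing if params is a generator.
    first = [p for p in params]
    second = [p for p in params]
    return len(first), len(second)

gen = (x for x in [1, 2, 3])   # a generator, like model.parameters()
print(consume_twice(gen))       # (3, 0): the generator is exhausted

lst = [1, 2, 3]                 # materializing it as a list avoids this
print(consume_twice(lst))       # (3, 3)
```

Passing `list(model.parameters())` instead of the raw generator sidesteps this class of bug entirely.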

@ValerianRey ValerianRey added the cc: test label (Conventional commit type for changes to tests) on Jan 16, 2026
@ValerianRey ValerianRey self-assigned this Jan 16, 2026
@codecov

codecov bot commented Jan 16, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.


@claude

claude bot commented Jan 16, 2026


Code review

No issues found. Checked for bugs and CLAUDE.md compliance.


Contributor

@PierreQuinton PierreQuinton left a comment


Hard (impossible?) for me to review, but this is clearly a strict improvement, so I vote merge. We can improve later if needed, and it is not user-facing anyway. Can ask Claude what it thinks.

@ValerianRey
Contributor Author

@claude Please review this in detail and suggest easy improvements (like things that could be factorized, etc.)

@claude

claude bot commented Jan 19, 2026

Claude finished @ValerianRey's task in 1m 34s


Detailed Code Review

I've reviewed the profiling additions and have several suggestions for easy improvements, particularly around factorization and code organization.

Major Improvements

1. Significant Code Duplication in run_profiler.py

The profile_autojac and profile_autogram functions share ~90% of their code. They can be factored into a single function:

Current Issues (lines 47-147):

  • Near-identical structure between the two functions
  • Duplicate cache clearing logic (lines 61-62, 75-76, 113, 127-128)
  • Duplicate activities setup (lines 69-71, 121-123)
  • Duplicate profiling configuration (lines 80-85, 132-137)
  • Duplicate trace saving logic (lines 88-95, 140-147)

Suggested Refactoring:

# Note: assumes the imports and helpers already present in run_profiler.py
# (gc, torch, Path, profile, ProfilerActivity, Callable, DEVICE, ModuleFactory,
# make_inputs_and_targets, make_mse_loss_fn, ...) remain in scope.
def profile_method(
    method_name: str,
    factory: ModuleFactory,
    batch_size: int,
    forward_backward_fn: Callable,
) -> None:
    """
    Profiles memory and computation time of a forward and backward pass.
    
    :param method_name: Name of the method being profiled (for output paths)
    :param factory: A ModuleFactory that creates the model to profile.
    :param batch_size: The batch size to use for profiling.
    :param forward_backward_fn: Function to execute the forward and backward pass.
    """
    print(f"{method_name}: {factory} with batch_size={batch_size} on {DEVICE}:")
    
    _clear_memory()
    
    model = factory()
    inputs, targets = make_inputs_and_targets(model, batch_size)
    loss_fn = make_mse_loss_fn(targets)
    
    activities = _get_profiler_activities()
    
    # Warmup run
    forward_backward_fn(model, inputs, loss_fn)
    _clear_memory()
    
    # Profiled run
    model.zero_grad()
    with profile(
        activities=activities,
        profile_memory=True,
        record_shapes=True,
        with_stack=True,
    ) as prof:
        forward_backward_fn(model, inputs, loss_fn)
    
    _save_and_print_trace(prof, method_name, factory, batch_size)

def _clear_memory() -> None:
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

def _get_profiler_activities() -> list[ProfilerActivity]:
    activities = [ProfilerActivity.CPU]
    if DEVICE.type == "cuda":
        activities.append(ProfilerActivity.CUDA)
    return activities

def _save_and_print_trace(
    prof: profile, method_name: str, factory: ModuleFactory, batch_size: int
) -> None:
    filename = f"{factory}-bs{batch_size}-{DEVICE.type}.json"
    torchjd_dir = Path(__file__).parent.parent.parent
    traces_dir = torchjd_dir / "traces" / method_name
    traces_dir.mkdir(parents=True, exist_ok=True)
    trace_path = traces_dir / filename
    
    prof.export_chrome_trace(str(trace_path))
    print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=20))

def profile_autojac(factory: ModuleFactory, batch_size: int) -> None:
    def forward_backward_fn(model, inputs, loss_fn):
        aggregator = Mean()
        autojac_forward_backward(model, inputs, loss_fn, aggregator)
    
    profile_method("autojac", factory, batch_size, forward_backward_fn)

def profile_autogram(factory: ModuleFactory, batch_size: int) -> None:
    def forward_backward_fn(model, inputs, loss_fn):
        engine = Engine(model, batch_dim=0)
        weighting = MeanWeighting()
        autogram_forward_backward(model, inputs, loss_fn, engine, weighting)
    
    profile_method("autogram", factory, batch_size, forward_backward_fn)


2. Magic Number for row_limit

In run_profiler.py lines 95 and 147, the value 20 is hardcoded:

print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=20))

This should be extracted as a module-level constant:

DEFAULT_TABLE_ROW_LIMIT = 20


3. Code Duplication in speed_grad_vs_jac_vs_gram.py

Multiple similar function definitions with nearly identical structure (lines 62-92):

  • fn_autograd, init_fn_autograd
  • fn_autograd_gramian, init_fn_autograd_gramian
  • fn_autojac, init_fn_autojac
  • fn_autogram, init_fn_autogram

These could be factored using a helper function:

def make_timed_functions(fn: Callable) -> tuple[Callable, Callable]:
    """Create timed and init functions for a given callable."""
    def init_fn():
        torch.cuda.empty_cache()
        gc.collect()
        fn()
    
    return fn, init_fn

# Usage:
fn_autograd, init_fn_autograd = make_timed_functions(
    lambda: autograd_forward_backward(model, inputs, loss_fn)
)


Minor Improvements

4. Inconsistent Expression in run_profiler.py

Lines 62, 76, 113, 128 have an odd ternary expression:

torch.cuda.empty_cache() if torch.cuda.is_available() else None

This is more clearly written as:

if torch.cuda.is_available():
    torch.cuda.empty_cache()

Or better yet, inside the suggested _clear_memory() helper.

5. Magic Number for n_runs

In speed_grad_vs_jac_vs_gram.py line 105, the value 10 is hardcoded multiple times. Consider extracting:

DEFAULT_N_RUNS = 10


6. Missing Closing Parenthesis in Print Statement

In speed_grad_vs_jac_vs_gram.py line 142, the f-string is missing a closing parenthesis:

print(f"{name} times (avg = {times.mean():.5f}, std = {times.std():.5f}")

Should be:

print(f"{name} times (avg = {times.mean():.5f}, std = {times.std():.5f})")


7. Inconsistent PARAMETRIZATIONS

The PARAMETRIZATIONS constant differs between the two files:

  • run_profiler.py: 7 parametrizations
  • speed_grad_vs_jac_vs_gram.py: 9 parametrizations (includes FreeParam and NoFreeParam)

Consider:

  • Either unifying them if they should be the same
  • Or extracting them to a shared configuration file if they intentionally differ
  • Or documenting why they differ

8. Type Hints

The forward_backward_fn parameter in my suggested refactoring needs proper typing. Consider using Protocol or Callable with full signature.
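One way to do that, sketched here with generic `object` annotations so the snippet stays self-contained (in the real file the parameters would be the torch types used by the refactoring), is a `Protocol` that pins down the full call signature rather than a bare `Callable`:

```python
# Hedged sketch: a structural type for the forward/backward callback.
# The parameter names mirror the suggested refactoring; `object` stands in
# for the actual torch types to keep this example dependency-free.
from typing import Callable, Protocol, runtime_checkable


@runtime_checkable
class ForwardBackwardFn(Protocol):
    def __call__(
        self, model: object, inputs: object, loss_fn: Callable[..., object]
    ) -> None: ...


def run_forward_backward(fn: ForwardBackwardFn) -> None:
    # A type checker now verifies the whole signature, not just "is callable".
    fn(model=None, inputs=None, loss_fn=lambda *args: None)


def example_fn(model: object, inputs: object, loss_fn: Callable[..., object]) -> None:
    pass


run_forward_backward(example_fn)
```

Compared to `Callable[[nn.Module, Tensor, Callable], None]`, the `Protocol` also preserves keyword-argument names, so call sites like `fn(model=..., inputs=..., loss_fn=...)` type-check correctly.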

Positive Aspects

  • Good separation of concerns with the package rename from tests/speed to tests/profiling
  • Proper documentation strings for the profiling functions
  • The warmup run before profiling is excellent practice
  • Chrome trace export is a great debugging feature
  • The fix to jac_to_grad in forward_backwards.py:34 is correct (converting generator to list)
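On the warmup point: the first call typically pays one-time costs (CUDA context creation, kernel compilation, allocator growth) that would otherwise pollute the profile. A deliberately contrived pure-Python sketch of the effect, with the one-time cost simulated by a sleep:

```python
# Sketch: why the first call should not be the measured one.
import time


def op_with_one_time_setup():
    # Simulates a one-time initialization cost (e.g. CUDA context creation
    # or kernel compilation) followed by cheap steady-state work.
    if not hasattr(op_with_one_time_setup, "_ready"):
        time.sleep(0.05)  # paid only on the first call
        op_with_one_time_setup._ready = True


op_with_one_time_setup()  # warmup run absorbs the one-time cost

t0 = time.perf_counter()
op_with_one_time_setup()  # measured run reflects steady-state behavior
elapsed = time.perf_counter() - t0
assert elapsed < 0.05
```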

Summary

The main opportunity is reducing duplication in run_profiler.py: the two profiling functions are ~90% identical and can share most of their implementation. The second biggest opportunity is factoring the repeated function-definition pattern in speed_grad_vs_jac_vs_gram.py. There is also a typo (a missing closing parenthesis in the printed text) that should be fixed.

These refactorings would significantly improve maintainability without changing functionality. Let me know if you'd like me to implement any of these suggestions!


@ValerianRey ValerianRey merged commit a258254 into main Jan 19, 2026
15 checks passed
@ValerianRey ValerianRey deleted the add-memory-profiler branch on January 19, 2026 at 22:14