
test(autogram): Add extra test for batched equivalence #445

Closed
ValerianRey wants to merge 3 commits into main from add-batched-equivalence-test

Conversation

ValerianRey (Contributor) commented Oct 2, 2025

This test can be quite practical for checking that both engines give the same results. For instance, when working on Transformer, autogram is not yet equivalent to autograd_gramian, but the two engine configurations (autogram with batch_dim=None and autogram with batch_dim=0) are equivalent to each other.

It seems that this test does not pass for FreeParam. This is odd, considering that both configurations should be equivalent to autograd_gramian, so we may have a bug here. It could also be caused by having two engines, but even then it would be unexpected behavior.

@ValerianRey ValerianRey added cc: test Conventional commit type for changes to tests. package: autogram labels Oct 2, 2025
@ValerianRey ValerianRey self-assigned this Oct 2, 2025
@ValerianRey ValerianRey changed the title test(autogram): Add extra test for batched equivalance test(autogram): Add extra test for batched equivalence Oct 2, 2025
@ValerianRey (Contributor)
I just investigated: having a second engine makes the result for engine_none different if and only if engine_0.compute_gramian is called before engine_none.compute_gramian.

ValerianRey (Contributor) commented Oct 2, 2025

The problem can be reproduced even without having two engines.

engine_none = Engine(model.modules(), batch_dim=None)

inputs = make_tensors(batch_size, input_shapes)
targets = make_tensors(batch_size, output_shapes)
loss_fn = make_mse_loss_fn(targets)

# Call the model's forward pass twice, to make the hook run twice instead of once
torch.random.manual_seed(0)  # Fix randomness for random models
output = model(inputs)
losses = reduce_to_vector(loss_fn(output))
torch.random.manual_seed(0)  # Fix randomness for random models
output = model(inputs)
losses = reduce_to_vector(loss_fn(output))

gramian_none = engine_none.compute_gramian(losses)

Having two engines (where one does an extra forward pass first) is just a convoluted way of calling the model's forward pass twice, which triggers the same bug.
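For readers unfamiliar with this failure mode, here is a minimal sketch (plain Python, with a hypothetical TinyEngine class standing in for the autogram engine; not the actual API) of how a stateful forward hook that runs on every forward pass accumulates stale state:

```python
class TinyEngine:
    """Stand-in for a hook-based engine; `records` is state mutated by every hook call."""

    def __init__(self):
        self.records = []  # accumulated across ALL forward passes

    def hook(self, activation):
        # Runs on every forward pass of the model, whether or not we intend
        # to use that pass for the Gramian computation.
        self.records.append(activation)

    def compute(self):
        # Pretends to reduce the recorded activations into one result;
        # it silently includes contributions from earlier forward passes.
        return sum(self.records)


engine = TinyEngine()
engine.hook(1.0)  # first forward pass
engine.hook(1.0)  # second forward pass re-runs the hook on the same state
print(engine.compute())  # 2.0, not the 1.0 a single forward pass would give
```

The second call doubles the accumulated state, which mirrors how calling the model's forward pass twice here changes the computed Gramian.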

@ValerianRey (Contributor)

A solution is to reset the state of the engine between the two forward passes. We don't want the first hook call to have any effect on the state of the engine.

Should we add a public reset_state method to the engine? Or at least a private one that we could use in our tests involving multiple engines?
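As a sketch of what such a reset could look like (hypothetical names, using the same TinyEngine stand-in rather than the actual autogram engine), assuming the engine keeps its per-forward accumulation in a list:

```python
class TinyEngine:
    """Stand-in for the autogram engine; `records` is the hook-mutated state."""

    def __init__(self):
        self.records = []

    def hook(self, activation):
        self.records.append(activation)

    def _reset_state(self):
        # Discard everything accumulated by earlier forward passes, so the
        # next forward pass starts from a clean slate.
        self.records.clear()


engine = TinyEngine()
engine.hook(1.0)       # first (throw-away) forward pass pollutes the state
engine._reset_state()  # reset between the two forward passes
engine.hook(2.0)       # second forward pass, the one we actually care about
print(engine.records)  # [2.0]: only the second pass's contribution remains
```

Whether this should be public or private (and test-only) is exactly the open question above.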

@PierreQuinton

codecov bot commented Oct 2, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.


@ValerianRey (Contributor)

I'm not sure this is such a good test, to be honest; I'll probably not merge this.

@PierreQuinton (Contributor)

I like this test. I'm not sure I understand exactly what the problem is with the state. Could we do the reset in the hook, like at the beginning?

ValerianRey (Contributor) commented Oct 3, 2025

> I like this test. I'm not sure I understand exactly what the problem is with the state. Could we do the reset in the hook, like at the beginning?

We don't want to reset at every hook call: we want to reset before the model's forward pass. So we'd need a model-level hook for that (which would change the engine constructor), or we could use a context manager. It would be something like:

engine = ...
with engine.activate_hooks():
    output = ...

losses = ...
gramian = engine.compute_gramian(losses)

The hooks would thus be activated manually, rather than always being active except during the Gramian computation phase. I'm not a big fan of that either, but it's not that bad. IMO we shouldn't change anything for now, but we should still think about it a bit.
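One way such an activate_hooks context manager could be built is with contextlib.contextmanager; this is a sketch under the same TinyEngine stand-in (hypothetical names, not the actual autogram implementation), where entering the context resets the state and enables recording:

```python
from contextlib import contextmanager


class TinyEngine:
    """Stand-in for the autogram engine with manually activated hooks."""

    def __init__(self):
        self.records = []
        self._active = False

    @contextmanager
    def activate_hooks(self):
        # Reset accumulated state on entry, record only inside the block,
        # and deactivate again on exit even if an exception is raised.
        self.records.clear()
        self._active = True
        try:
            yield self
        finally:
            self._active = False

    def hook(self, activation):
        if self._active:
            self.records.append(activation)


engine = TinyEngine()
engine.hook(1.0)  # outside the context: ignored, cannot pollute the state
with engine.activate_hooks():
    engine.hook(2.0)  # inside the context: recorded
print(engine.records)  # [2.0]
```

This keeps stray forward passes (like the throw-away first pass above) from ever touching the engine's state, at the cost of requiring the caller to wrap the relevant forward pass explicitly.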
