
Add HierarchicalWeighting #414

Closed
ValerianRey wants to merge 46 commits into main from add-hierarchical-weighting

Conversation

Contributor

@ValerianRey ValerianRey commented Sep 10, 2025

This PR contains 0a25555, which adds the HierarchicalWeighting.

I think we should wait until we have tested this in proper experiments / speed benchmarks before merging.

PierreQuinton and others added 30 commits August 31, 2025 15:25
Add reshape of jacobian for the scalar output case.

Fix reshape of the Gramian: for the last half of the dimensions, we need to reshape in the same order as the first half, then move the dimensions. We could in principle create a `reshape_gramian` function that does this, as well as a `move_dim_gramian`.
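A minimal sketch of the `reshape_gramian` / `move_dim_gramian` idea described above: a Gramian of an output with shape `s` has shape `(*s, *s)`, so any reshape or movedim applied to the output must be applied identically to both halves. The function names and signatures here are assumptions for illustration, not the PR's actual API.

```python
import torch

def reshape_gramian(gramian: torch.Tensor, shape: list[int]) -> torch.Tensor:
    # Hypothetical helper: reshape both halves (*s, *s) -> (*shape, *shape),
    # applying the same reshape, in the same order, to each half.
    return gramian.reshape(*shape, *shape)

def move_dim_gramian(gramian: torch.Tensor, source: int, destination: int) -> torch.Tensor:
    # Hypothetical helper: move a dimension of the output in both halves at once.
    k = gramian.ndim // 2
    return gramian.movedim([source, source + k], [destination, destination + k])
```

Both operations preserve the quadratic form: for any weighting `w` with the output's shape, `w · G · w` is unchanged when `w` and `G` are reshaped (or dim-moved) consistently.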

Add a value test for all four cases of having a batched/non-batched dimension. Tests that reshape/move-dim work should go in another test.

Remove some tests that do not test anything more than `test_gramian_is_correct`.

Add `_gramian_utils.py`, which contains helpers to `reshape` and `movedim` a Gramian.

Add `generate_vmap_rule = True` for `JacobianAccumulator`. This allows vmapping the forward phase, which enables having several Engines defined on the same module.
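For context, `generate_vmap_rule = True` is PyTorch's mechanism for making a custom `autograd.Function` compatible with `torch.func.vmap` without hand-writing a vmap rule; it requires the new-style `forward`/`setup_context` split. The toy Function below (which just doubles its input, unlike the PR's `JacobianAccumulator`) shows the pattern:

```python
import torch
from torch.func import vmap

class Double(torch.autograd.Function):
    # Ask torch to derive a vmap rule for this Function automatically.
    generate_vmap_rule = True

    @staticmethod
    def forward(x):
        # New-style forward: no ctx argument (required for generated vmap rules).
        return 2 * x

    @staticmethod
    def setup_context(ctx, inputs, output):
        pass  # nothing to save for backward in this toy example

    @staticmethod
    def backward(ctx, grad_output):
        return 2 * grad_output

x = torch.randn(4, 3)
y = vmap(Double.apply)(x)  # vmapping over the custom forward now works
```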

Add `test_reshape_equivariance`

Add tests to verify that gramian utils yield the correct quadratic forms.

Add tests to verify that gramian utils yield the correct quadratic forms.

Add `test_movedim_equivariance`

Fix warning.

Fix warning.

Remove handles from `ModuleHookManager`

Change `batched_dims` to a single optional `batched_dim`. Fix movedim in `compute_gramian` and add `test_movedim_equivariance`

Remove `grad_output`, can be added later, but should be `jac_output` instead.

Make modules with incompatible batched operations compatible with non-batched autogram.

Fix doc tests

Provide the autograd VJP for when no dimension is batched. This enables having a single forward pass in that case, which should be faster.
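A sketch of what such an autograd-backed VJP could look like as a callable class (names and signature are assumptions, not the PR's actual `AutogradVJP`): one forward pass is done up front, and each call computes a vector-Jacobian product with `torch.autograd.grad`.

```python
import torch

class AutogradVJP:
    # Hypothetical sketch: a VJP as a callable class backed by plain autograd,
    # usable when no dimension is batched.
    def __init__(self, inputs: list[torch.Tensor], fn):
        self.inputs = inputs
        self.output = fn(*inputs)  # single forward pass

    def __call__(self, cotangent: torch.Tensor) -> tuple[torch.Tensor, ...]:
        # Vector-Jacobian product of the recorded forward with `cotangent`.
        return torch.autograd.grad(
            self.output,
            self.inputs,
            grad_outputs=cotangent,
            retain_graph=True,  # allow repeated VJP calls on the same graph
        )
```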

Make VJPs into Callable classes.
Co-authored-by: Valérian Rey <31951177+ValerianRey@users.noreply.github.com>
Co-authored-by: Valérian Rey <31951177+ValerianRey@users.noreply.github.com>
Not the final name I think, but at least it's consistent with the method name
Co-authored-by: Pierre Quinton <pierre.quinton@epfl.ch>
ValerianRey and others added 16 commits September 5, 2025 19:16
At this point, 3 architectures fail: SomeFrozenParam, SomeUnusedParam, and MultiOutputWithFrozenBranch
… variables

This fixes non-batched engine on SomeFrozenParams architecture
…grad in AutogradVJP

This fixes non-batched engine on SomeUnusedParam
…adVJP

This fixes non-batched engine on MultiOutputWithFrozenBranch
Maybe not a definitive name, but I think it's clearer
* Small improvement of clarity
…e can also contain (at most) one element set to -1; the size of that dimension is deduced from the total number of elements
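The "-1" rule described above (at most one -1 in the target shape, deduced from the total element count) could be resolved with a small helper like this hypothetical sketch:

```python
def resolve_shape(shape: list[int], numel: int) -> list[int]:
    # Hypothetical helper: replace a single -1 entry by the size deduced
    # from the total number of elements, mirroring torch.reshape semantics.
    if shape.count(-1) > 1:
        raise ValueError("at most one dimension may be set to -1")
    known = 1
    for s in shape:
        if s != -1:
            known *= s
    if -1 in shape:
        if numel % known != 0:
            raise ValueError(f"cannot deduce -1: {numel} is not divisible by {known}")
        return [numel // known if s == -1 else s for s in shape]
    if known != numel:
        raise ValueError("shape does not match the number of elements")
    return list(shape)
```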
@ValerianRey
Contributor Author

Closing this PR in favor of an archive branch (archive/add-hierarchical-weighting). We can start a new PR when we want to experiment on this.

@ValerianRey ValerianRey deleted the add-hierarchical-weighting branch September 25, 2025 22:14