refactor(autogram): Make VJP take flat grad_outputs. #438
Conversation
Codecov Report: ✅ All modified and coverable lines are covered by tests.
... and 1 file with indirect coverage changes
I really like this PR. Before it, we did an unflatten followed by a flatten for AutogradVJP, and now we do neither of those. So it should be a performance improvement.
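The round-trip described above can be sketched in plain Python. Everything here is a hypothetical stand-in (a toy `flatten`/`unflatten` instead of the library's real pytree utilities, and dummy VJP functions instead of `AutogradVJP`), just to show why passing flat grad_outputs directly removes two conversions per call:

```python
# Hypothetical minimal pytree flatten/unflatten (NOT the library's real API).
def flatten(tree):
    """Flatten a nested list/tuple into (flat_leaves, spec)."""
    if isinstance(tree, (list, tuple)):
        leaves, spec = [], []
        for child in tree:
            child_leaves, child_spec = flatten(child)
            leaves.extend(child_leaves)
            spec.append((child_spec, len(child_leaves)))
        return leaves, (type(tree), spec)
    return [tree], None  # a leaf

def unflatten(leaves, spec):
    """Inverse of flatten: rebuild the nested structure from flat leaves."""
    if spec is None:
        return leaves[0]
    cls, child_specs = spec
    out, i = [], 0
    for child_spec, n in child_specs:
        out.append(unflatten(leaves[i:i + n], child_spec))
        i += n
    return cls(out)

# Before the PR: the VJP received structured grad_outputs, so the caller had
# to unflatten its flat gradients only for the VJP to flatten them again.
def vjp_structured(grad_outputs_tree):
    flat, _ = flatten(grad_outputs_tree)   # flatten again, inside the VJP
    return [2.0 * g for g in flat]         # dummy gradient computation

def call_old(flat_grads, output_spec):
    tree = unflatten(flat_grads, output_spec)  # unflatten at the call site...
    return vjp_structured(tree)                # ...just to re-flatten inside

# After the PR: the VJP takes flat grad_outputs directly, so neither the
# round-trip nor the `output_spec` parameter is needed.
def vjp_flat(flat_grads):
    return [2.0 * g for g in flat_grads]

flat = [1.0, 2.0, 3.0]
_, spec = flatten((1.0, (2.0, 3.0)))
assert call_old(flat, spec) == vjp_flat(flat)  # same result, fewer conversions
```

Under these assumptions, both paths compute the same gradients; the flat-only path simply skips the structure round-trip, which is where the performance gain comes from.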
We have a mypy error with this: But it seems that PyTorch is in the wrong here. In the documentation of grad, they say:
However, they type it as @PierreQuinton, did I miss something? If not, I'll probably open an issue or a PR in torch.
I just opened an issue in pytorch: pytorch/pytorch#164298. In the meantime, I think it's fine to replace
Yeah, the type hint is wrong; we will update that in another PR though. The thing is, when we filter with
…FunctionalVJP. `output_spec` now only appears in the hook.
…one in another PR.
* They're always flat now
There seems to be a problem with macOS runners. I think this is safe to merge anyway, so should we?
Yes, no need to wait, especially since nothing here is macOS-specific (no numerical errors can be introduced by this PR). Here is the status btw: https://www.githubstatus.com/
This allows removing the parameter `output_spec` from both `autograd.Function`s in `ModuleHookManager`.