feat(autogram): Add support for non-Tensor args #442
Conversation
Codecov Report ✅ All modified and coverable lines are covered by tests.
The current failure highlights something slightly non-trivial for the batched case. Side note: if we implement a mapping from Module to Gramian computer, then we could also specify the in_dims there, thus making these two discussions dependent.
I completely agree. I fixed the issue with 49e5d5f. For now, the in_dims are computed by simply giving 0 to every tensor (note: grad_outputs are all tensors, so there is no need to provide any None for them). In another PR we can customize this value (and give something that depends on the module, args, and outputs, rather than always 0 for tensors).
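A minimal sketch of that default in_dims policy (the helper name and the use of `torch.utils._pytree` are my assumptions, not necessarily how the PR implements it): map every Tensor leaf of the args pytree to dim 0 and every non-Tensor leaf to None, which is the pytree shape `torch.vmap` accepts for `in_dims`.

```python
import torch
from torch.utils._pytree import tree_map

def default_in_dims(args):
    # Hypothetical helper: batch every Tensor leaf on dim 0,
    # and mark every non-Tensor leaf as not batched (None).
    return tree_map(
        lambda leaf: 0 if isinstance(leaf, torch.Tensor) else None, args
    )

dims = default_in_dims((torch.randn(4, 3), {"scale": 2.0}))
# dims == (0, {"scale": None})
```

Customizing this later (per module, args, and outputs) would only mean swapping the lambda for a richer policy.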
With WithModuleWithHybridPyTreeArg, I'm very confident that this works for any kind of args containing tensors batched on dim 0.
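For reference, a module of that flavor can be sketched as follows (this is my own illustration of a "hybrid pytree arg", not the actual WithModuleWithHybridPyTreeArg from the test suite): its forward takes a single pytree that mixes a Tensor batched on dim 0 with a plain Python value.

```python
import torch
from torch import nn

class HybridPyTreeArgModule(nn.Module):
    # Illustrative module (name hypothetical): the forward arg is a dict
    # mixing a Tensor batched on dim 0 with a non-Tensor scalar.
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 2)

    def forward(self, arg: dict) -> torch.Tensor:
        x = arg["x"]          # Tensor, batched on dim 0
        scale = arg["scale"]  # plain float, not batched
        return self.linear(x) * scale

module = HybridPyTreeArgModule()
out = module({"x": torch.randn(4, 3), "scale": 2.0})
# out has shape (4, 2)
```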
* With this extra linear module, we now also check that the gradients wrt the args that require grad are correctly backpropagated.
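That gradient check boils down to the following pattern (a standalone sketch, not the PR's actual test): feed an arg that requires grad through a linear module, backpropagate, and verify that a gradient of the right shape arrived on the arg.

```python
import torch
from torch import nn

linear = nn.Linear(3, 2)
x = torch.randn(4, 3, requires_grad=True)  # an arg that requires grad
linear(x).sum().backward()
# The gradient w.r.t. the arg must exist and match its shape.
assert x.grad is not None
assert x.grad.shape == x.shape
```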