
Question about inference_mode=False in pl.Trainer and enabling gradients in validation and test step #11

@fedeotto

Hello,

I’ve been following your work with great interest and have spent some time exploring and re-implementing your model to better understand and control its behavior.

While I believe my re-implementation is faithful to your original design, I've had difficulty reproducing your performance results at a comparable scale. I noticed that in your PyTorch Lightning setup, pl.Trainer is created with inference_mode=False, and gradient computation is enabled in both the validation_step() and test_step() methods.
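For reference, here is a minimal sketch of the pattern I mean. It is written in plain PyTorch rather than Lightning so it runs on its own. The toy model, the evaluation_step helper, and the input shapes are hypothetical. Taking the gradient of the output with respect to the inputs is just one example of why an evaluation step might need autograd, not necessarily the reason in your code. My understanding from the Lightning docs is that inference_mode=False makes the evaluation loops run under torch.no_grad() instead of torch.inference_mode(), and torch.enable_grad() can locally re-enable autograd inside no_grad.

```python
import torch
from torch import nn

# Hypothetical toy model: maps 3 input features to one scalar per sample.
model = nn.Sequential(nn.Linear(3, 16), nn.Tanh(), nn.Linear(16, 1))
model.eval()


def evaluation_step(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for a validation_step/test_step body that needs autograd.

    Returns the gradient of the summed model output with respect to the inputs.
    """
    with torch.enable_grad():
        # Detach and mark the inputs as requiring grad so autograd can differentiate w.r.t. them.
        x = x.detach().requires_grad_(True)
        out = model(x).sum()
        (grad,) = torch.autograd.grad(out, x)
    return grad


x = torch.randn(4, 3)
# Assumption: with inference_mode=False, Lightning wraps evaluation in torch.no_grad(),
# which the torch.enable_grad() block above overrides locally.
with torch.no_grad():
    grads = evaluation_step(x)
print(grads.shape)  # torch.Size([4, 3])
```

As far as I understand, the same torch.enable_grad() block would not build a graph under torch.inference_mode(), which would explain why the trainer flag is needed on top of enabling gradients in the steps.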

I haven’t yet applied this approach in my own implementation, so I wanted to ask about the rationale behind allowing gradient computation during inference. This isn’t something I’ve typically encountered, and I’m curious whether it addresses a known issue or could potentially explain some of the performance discrepancies I’m seeing.

Many thanks for your interesting work!

Best regards,
Federico
