Conversation
Code Review
This pull request introduces two major features: multi-GPU training support using DistributedDataParallel (DDP) and experiment tracking with Weights & Biases. The DDP implementation is well-structured, handling process group initialization, data sampling, and metric synchronization correctly. The introduction of a logging protocol with a W&B implementation is a great addition for experiment tracking. I've identified a critical issue regarding DDP support for multiple models, which will cause a crash. I've also made a couple of medium-severity suggestions to improve code clarity and documentation.
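The torchrun auto-detection the summary refers to could be sketched roughly as follows (the function name, backend choice, and environment-variable handling here are illustrative assumptions, not code from the PR; only the fact that torchrun launch is auto-detected comes from the review):

```python
import os


def maybe_init_ddp() -> bool:
    """Initialize a DDP process group iff launched via torchrun.

    torchrun sets RANK, WORLD_SIZE, and LOCAL_RANK in the environment, so
    their presence signals a distributed launch; a plain `python` invocation
    falls through to single-process mode with no code changes needed.
    """
    if "RANK" not in os.environ or "WORLD_SIZE" not in os.environ:
        return False

    # Imported lazily so single-process runs do not require distributed setup.
    import torch
    import torch.distributed as dist

    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(backend=backend)
    if torch.cuda.is_available():
        # Pin this process to the GPU torchrun assigned it.
        torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))
    return True
```

Callers can then wrap the model in `DistributedDataParallel` only when this returns `True`, keeping the single-GPU path untouched.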
```python
self.finetuned_estimator_.model_.recompute_layer = True  # type: ignore

# --- DDP model wrapping ---
model_for_optimization = self.finetuned_estimator_.model_
```
This line assumes that self.finetuned_estimator_ has a single model, as it accesses the .model_ property. However, TabPFNClassifier can be initialized with multiple models, in which case len(self.finetuned_estimator_.models_) > 1 and accessing .model_ will raise a ValueError. Finetuning, especially with DDP, seems designed for a single model. It would be safer to explicitly check for this and raise a more informative error if multiple models are provided for finetuning.
```diff
@@ -0,0 +1 @@
+Add multi-GPU DDP support for finetuning via torchrun (auto-detected, no code changes needed)
```
The changelog entry only mentions the multi-GPU DDP support. This pull request also introduces W&B logging support, which is a significant feature. It would be beneficial to also include this in the changelog to accurately reflect all the changes.
```diff
-Add multi-GPU DDP support for finetuning via torchrun (auto-detected, no code changes needed)
+Add multi-GPU DDP support for finetuning via torchrun and W&B logging support.
```
```python
# Store the original training size for checkpoint naming
train_size = X.shape[0]
start_time = time.monotonic()
```
Issue
Closes #810
Motivation and Context
Tracking training runs is important, so this PR implements a logging class that can be extended to support additional loggers. The first supported backend is W&B.
This PR builds on top of #812 and should only be merged afterwards.
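The logging abstraction described above might look roughly like this (a sketch with invented class and method names; only the idea of an extensible logger interface with a W&B implementation comes from the PR):

```python
from typing import Optional, Protocol, runtime_checkable


@runtime_checkable
class TrainingLogger(Protocol):
    """Minimal interface a logger backend must implement."""

    def log_metrics(self, metrics: dict, step: int) -> None: ...

    def finish(self) -> None: ...


class WandbLogger:
    """W&B-backed implementation of the logger protocol."""

    def __init__(self, project: str, config: Optional[dict] = None):
        # Imported lazily so wandb stays an optional dependency.
        import wandb

        self._run = wandb.init(project=project, config=config)

    def log_metrics(self, metrics: dict, step: int) -> None:
        self._run.log(metrics, step=step)

    def finish(self) -> None:
        self._run.finish()
```

New backends (e.g. a TensorBoard or CSV logger) only need to satisfy the protocol, so the training loop can stay agnostic about where metrics go.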
Public API Changes
How Has This Been Tested?
Checklist
changelog/README.md), or "no changelog needed" label requested.