feat(tinker): Add support for built-in loss functions and checkpoint control #523
Summary
This PR adds two new features to `TinkerBackend` that are fully backwards-compatible with existing code.

1. Built-in Loss Functions
Adds support for Tinker's optimized built-in loss functions via new parameters on `TinkerBackend.train()`:
- `tinker_loss_fn`: Select from `"importance_sampling"`, `"ppo"`, `"cispo"`, `"dro"`
- `tinker_loss_fn_config`: Pass loss-specific config (e.g., `{"clip_low_threshold": 0.0, "clip_high_threshold": 6.0}`)

Benefits:
- Uses `forward_backward_async` instead of `forward_backward_custom_async`

Default behavior is unchanged: when `tinker_loss_fn=None` (default), training continues to use ART's custom loss implementation.

2. Checkpoint Control
The existing `save_checkpoint` parameter now controls checkpoint behavior in `TinkerBackend`:
- `save_checkpoint=True` (default): Saves full state + optimizer (enables training resumption)
- `save_checkpoint=False`: Only saves sampler weights (fast, for inference only)

This enables faster training when full checkpoints are only needed at specific intervals (e.g., at eval steps).
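The interval-based pattern described above could be sketched roughly as follows. The helper name `should_save_full_checkpoint` and the `eval_every` setting are hypothetical and not part of this PR; only the `save_checkpoint` flag itself comes from the description.

```python
def should_save_full_checkpoint(step: int, eval_every: int) -> bool:
    """Return True on steps where a full (resumable) checkpoint is wanted.

    Hypothetical helper: `eval_every` is an assumed training-loop setting,
    not a parameter introduced by this PR.
    """
    if eval_every <= 0:
        raise ValueError("eval_every must be positive")
    return step % eval_every == 0


# For each of the first six steps, the value one might pass as
# save_checkpoint=... when calling train().
schedule = {
    step: should_save_full_checkpoint(step, eval_every=3)
    for step in range(1, 7)
}
```

With this schedule, only every third step would pay for the full state + optimizer save; the remaining steps would save sampler weights only.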
Usage
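A minimal sketch of how the new parameters might be assembled for a call to `TinkerBackend.train()`. The helper `build_train_kwargs` is hypothetical and exists only for illustration; the parameter names and the four loss names are taken from the summary above. The full `train()` signature (model and trajectory arguments) is not shown in this description, so the final call is left as a comment, and pairing the example config with `"cispo"` is an assumption.

```python
from typing import Any, Optional

# Built-in loss names listed in the PR summary.
BUILTIN_LOSSES = ("importance_sampling", "ppo", "cispo", "dro")


def build_train_kwargs(
    tinker_loss_fn: Optional[str] = None,
    tinker_loss_fn_config: Optional[dict[str, Any]] = None,
    save_checkpoint: bool = True,
) -> dict[str, Any]:
    """Collect the keyword arguments described in this PR (hypothetical helper)."""
    if tinker_loss_fn is not None and tinker_loss_fn not in BUILTIN_LOSSES:
        raise ValueError(f"unknown built-in loss: {tinker_loss_fn!r}")
    # Assumption: a loss config is only meaningful with a built-in loss.
    if tinker_loss_fn is None and tinker_loss_fn_config is not None:
        raise ValueError("tinker_loss_fn_config requires tinker_loss_fn")

    kwargs: dict[str, Any] = {"save_checkpoint": save_checkpoint}
    if tinker_loss_fn is not None:
        kwargs["tinker_loss_fn"] = tinker_loss_fn
        if tinker_loss_fn_config is not None:
            kwargs["tinker_loss_fn_config"] = tinker_loss_fn_config
    return kwargs


kwargs = build_train_kwargs(
    tinker_loss_fn="cispo",
    tinker_loss_fn_config={"clip_low_threshold": 0.0, "clip_high_threshold": 6.0},
    save_checkpoint=False,  # sampler weights only, for a fast non-eval step
)
# Assumed call shape; check the PR diff for the real signature:
# await backend.train(model, trajectory_groups, **kwargs)
```

Calling `build_train_kwargs()` with no arguments yields only `save_checkpoint=True`, which mirrors the default path where ART's custom loss is used.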
Files Changed
- `src/art/dev/train.py`: Added `tinker_loss_fn`, `tinker_loss_fn_config`, `tinker_save_checkpoint` to TrainConfig
- `src/art/tinker/backend.py`: Overrode `train()` with new parameters
- `src/art/tinker/service.py`: Added dispatch logic for built-in vs custom loss, added `_save_sampler_weights_only()` method

Backwards Compatibility
All existing code continues to work unchanged. The new parameters are optional with sensible defaults that preserve current behavior.
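To make the default-preserving behavior concrete, here is a rough sketch of the kind of dispatch the description implies. It is not the code in `service.py`: the function names and the save-mode labels are hypothetical, and only the two forward/backward method names and the parameter defaults come from this PR.

```python
from typing import Optional


def select_training_path(tinker_loss_fn: Optional[str] = None) -> str:
    """Name the forward/backward call implied by the summary (illustrative)."""
    if tinker_loss_fn is None:
        # Default: ART's custom loss implementation, as before this PR.
        return "forward_backward_custom_async"
    # A built-in Tinker loss was requested.
    return "forward_backward_async"


def select_save_mode(save_checkpoint: bool = True) -> str:
    """Return a hypothetical label for what gets saved after a step."""
    if save_checkpoint:
        return "full_state_and_optimizer"  # resumable (default)
    return "sampler_weights_only"  # fast, inference only


# Callers that pass neither new argument stay on the pre-existing path.
default_path = (select_training_path(), select_save_mode())
```

Because both defaults map to the pre-existing behavior, code written before this PR would follow the same path without modification.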