SFT (local backend) #530
base: main
Conversation
Move batching and shuffling logic from SFTConfig into iterator functions. train_sft now accepts Iterable[List[Trajectory]] instead of individual trajectories, simplifying the API and making batch management more explicit.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
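For context, a rough sketch of what the caller side might look like under the new signature; the `art.Trajectory` import path, the `batched_trajectories` helper, and the `model.train_sft` entry point are assumptions, not code from this PR:

```python
import random
from typing import Iterable, List

from art import Trajectory  # assumed import path


def batched_trajectories(
    trajectories: List[Trajectory], batch_size: int, seed: int = 0
) -> Iterable[List[Trajectory]]:
    """Shuffle once and yield fixed-size batches (hypothetical helper)."""
    rng = random.Random(seed)
    shuffled = list(trajectories)
    rng.shuffle(shuffled)
    for i in range(0, len(shuffled), batch_size):
        yield shuffled[i : i + batch_size]


# Usage (assumed signature):
# await model.train_sft(batched_trajectories(trajectories, batch_size=2))
```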
src/art/types.py
Outdated
class SFTConfig(pydantic.BaseModel):
    learning_rate: float = 1e-4
- Remove custom_lr_schedule
- Make learning_rate: float | list[float] (see the sketch below)
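A minimal sketch of the suggested shape, purely illustrative (the default value is an assumption):

```python
import pydantic


class SFTConfig(pydantic.BaseModel):
    # Either a single constant learning rate, or one learning rate per batch;
    # this replaces the separate custom_lr_schedule field.
    learning_rate: float | list[float] = 1e-4
```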
    Used to identify where assistant turns begin (train on responses only).
    """

    instruction_part: str
Can we keep this class empty for now? I'm not sure instruction_part and response_part are a good fit for an experimental feature.
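If the fields are dropped for now, the experimental config could simply be an empty placeholder; a sketch with a hypothetical class name:

```python
import pydantic


class ChatTemplateConfig(pydantic.BaseModel):  # hypothetical name for the class above
    """Experimental; intentionally left empty until the field set is settled."""
```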
src/art/local/backend.py
Outdated
batch_size = 2  # Default to 2 for SFT

# Determine learning rates
if config.custom_lr_schedule and len(config.custom_lr_schedule) > 0:
- Refactor/Remove custom_lr_schedule; learning_rate is float | list[float]
- Add validation for num_learning_rate == num_batches (see the sketch below)
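A hedged sketch of the requested validation; `resolve_learning_rates` is a hypothetical helper name:

```python
def resolve_learning_rates(
    learning_rate: float | list[float], num_batches: int
) -> list[float]:
    """Expand a scalar LR to one per batch, or validate a per-batch schedule."""
    if isinstance(learning_rate, float):
        return [learning_rate] * num_batches
    if len(learning_rate) != num_batches:
        raise ValueError(
            f"custom learning-rate schedule has {len(learning_rate)} entries "
            f"but there are {num_batches} batches"
        )
    return list(learning_rate)
```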
src/art/unsloth/service.py
Outdated
# Save checkpoint after training
# Name checkpoint by final training step: starting_step + num_batches
final_step = get_step_from_dir(self.output_dir) + len(sft_batches)
The checkpoint step should still be incremented by 1. Checkpoint step != gradient step.
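So the line above would plausibly become something like the following (a sketch, reusing the names from the diff):

```python
# One checkpoint per train_sft call: advance the checkpoint step by 1,
# even though the call may have run len(sft_batches) gradient steps.
final_step = get_step_from_dir(self.output_dir) + 1
```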
    response_part="<|im_start|>assistant\n",
),
# Qwen 3 models (with thinking tokens)
"Qwen/Qwen3-8B": ModelConfig(
- How did we decide to support all of these models?
- Prefer to keep it simple and start with the models that are widely used in the OpenPipe Platform and ART?
- Research the Qwen chat template; IIRC, <think></think> only shows up in the last turn. We may need to remove <think></think> from response_part for Qwen (see the sketch below).
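A sketch of what that could look like for the Qwen 3 entry; `ModelConfig` and the field names come from the diff, while the exact `instruction_part` string is an assumption:

```python
# Qwen 3 models: <think></think> typically appears only in the final assistant
# turn, so match on the plain assistant header rather than including it.
"Qwen/Qwen3-8B": ModelConfig(
    instruction_part="<|im_start|>user\n",
    response_part="<|im_start|>assistant\n",  # no <think></think> suffix
),
```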
    progress_bar.close()


def iterate_file(
Have iterate_file take in an epoch parameter. See the following PR for reference.
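A minimal sketch of `iterate_file` taking an `epoch` argument; the per-epoch shuffle and the use of `json.loads` in place of `_parse_jsonl_line` are assumptions:

```python
import json
import random
from typing import Iterator


def iterate_file(path: str, epoch: int) -> Iterator[dict]:
    """Yield parsed JSONL records, reshuffled deterministically for each epoch."""
    with open(path) as f:
        lines = [line for line in f if line.strip()]
    random.Random(epoch).shuffle(lines)
    for line in lines:
        yield json.loads(line)
```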
        yield _parse_jsonl_line(line)


async def train_sft_from_file(
Modify this so the user can keep training running after closing their laptop (see the sketch after this list):
- iterate_file(file, epoch)
- Write to local disk
- Upload to a wandb artifact
- Calculate the lr
- Call train_sft(url, lr)
- Monitor training status
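A hedged sketch of that sequencing; every helper below (`write_batches_to_disk`, `upload_to_wandb_artifact`, `resolve_learning_rates`, `poll_training_status`) and the `train_sft(url, lrs)` signature are hypothetical and only illustrate the flow:

```python
async def train_sft_from_file(
    model, file_path: str, epochs: int, base_lr: float = 1e-4
) -> None:
    """Run SFT from a JSONL file so training survives the client disconnecting."""
    for epoch in range(epochs):
        # 1. Read this epoch's batches from the file.
        batches = list(iterate_file(file_path, epoch))
        # 2. Persist locally, then upload as a W&B artifact so the backend can
        #    fetch the data independently of the client machine.
        local_path = write_batches_to_disk(batches, epoch)
        artifact_url = upload_to_wandb_artifact(local_path)
        # 3. Compute the learning-rate schedule for this epoch.
        lrs = resolve_learning_rates(base_lr, num_batches=len(batches))
        # 4. Hand off by URL and poll status instead of streaming from the client.
        job = await model.train_sft(artifact_url, lrs)
        await poll_training_status(job)
```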