Skip to content
Open

Sft #458

Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
498d3df
SFT data iterator
angkywilliam Nov 13, 2025
3bd818f
Add SFT LR utils
angkywilliam Nov 14, 2025
66ec620
train_sft skeleton
angkywilliam Nov 14, 2025
4aeda2f
SFT Shape 0.1
angkywilliam Nov 14, 2025
4ff152b
Add shuffle to SFTConfig
angkywilliam Nov 14, 2025
b6f0380
change SFT args order
angkywilliam Nov 14, 2025
e32db37
Refactor SFT to accept batched trajectories
angkywilliam Nov 18, 2025
9138b07
Tokenize SFT Batch
angkywilliam Nov 19, 2025
18a7897
Add num_trainable_tokens to SFTBatch
angkywilliam Nov 19, 2025
90bf94b
draft train_sft
angkywilliam Nov 19, 2025
12e2142
Flatten trajectory for train_sft
angkywilliam Nov 21, 2025
4ea6c5e
Tokenize SFT Batches support flat list and add padding
angkywilliam Nov 21, 2025
f7bb203
Fix max_length duplicate name issue
angkywilliam Nov 21, 2025
d59e524
Remove unused file
angkywilliam Nov 21, 2025
7f6309a
remove unused typing
angkywilliam Nov 21, 2025
5ec5575
sft iterator
angkywilliam Nov 22, 2025
d6688cf
SFT Iterator
angkywilliam Nov 22, 2025
6c63af5
Use Unsloth for train on response
angkywilliam Nov 25, 2025
d2b39d5
Merge branch 'main' of github.com:OpenPipe/ART into sft
Kovbo Jan 14, 2026
ca5177b
refactoring
Kovbo Jan 14, 2026
c3a06b4
implement local backend SFT training
Kovbo Jan 15, 2026
9cf747d
Add SFT to Local Backend
Kovbo Jan 15, 2026
28205cb
avg loss
Kovbo Jan 15, 2026
64454b1
refactor, sft works good
Kovbo Jan 17, 2026
739eb45
Merge branch 'sft' of github.com:OpenPipe/ART into sft
Kovbo Jan 17, 2026
9918f65
Merge remote-tracking branch 'origin/main' into sft
Kovbo Jan 20, 2026
fb706f9
remove logging
Kovbo Jan 20, 2026
08d87d1
move tokenizer, update backend
Kovbo Jan 20, 2026
0573bc8
update lr schedule and tests
Kovbo Jan 20, 2026
904c3ff
refactor sft training from file
Kovbo Jan 20, 2026
2078d5e
change batch sft
Kovbo Jan 21, 2026
381ac7d
refactor step count based on checkpoints
Kovbo Jan 21, 2026
4bc79ed
update sft warmup script
Kovbo Jan 21, 2026
db6833c
fix model registration
Kovbo Jan 21, 2026
9544df9
make local random
Kovbo Jan 22, 2026
c6b2874
refactor backend
Kovbo Jan 22, 2026
834b37e
refactor
Kovbo Jan 22, 2026
736f259
Merge branch 'main' of github.com:OpenPipe/ART into sft
Kovbo Jan 22, 2026
84e6ceb
update example
Kovbo Jan 22, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions dev/sft/dataset.jsonl

Large diffs are not rendered by default.

67 changes: 67 additions & 0 deletions dev/sft/distillation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""Distillation example: Train a small model using completions from a large model."""

import asyncio
import os

from dotenv import load_dotenv
from openai import AsyncOpenAI

import art
from art.local import LocalBackend

load_dotenv()

if not os.environ.get("OPENROUTER_API_KEY"):
raise ValueError("OPENROUTER_API_KEY environment variable is required")

TEACHER_MODEL = "qwen/qwen3-235b-a22b-2507"
STUDENT_BASE_MODEL = "Qwen/Qwen2.5-7B-Instruct"
PROMPT = "Explain the concept of recursion in programming with a simple example."


async def main():
# Get completion from teacher model
teacher_client = AsyncOpenAI(
api_key=os.environ["OPENROUTER_API_KEY"],
base_url="https://openrouter.ai/api/v1",
)

print(f"Getting completion from teacher model ({TEACHER_MODEL})...")
completion = await teacher_client.chat.completions.create(
model=TEACHER_MODEL,
messages=[{"role": "user", "content": PROMPT}],
)
teacher_response = completion.choices[0].message.content
print(
f"Teacher response ({len(teacher_response)} chars):\n{teacher_response[:500]}..."
)

# Create trajectory from teacher completion
trajectory = art.Trajectory(
messages_and_choices=[
{"role": "user", "content": PROMPT},
{"role": "assistant", "content": teacher_response},
],
reward=0.0,
)

# Train student model
backend = LocalBackend()
student = art.TrainableModel(
name="distillation-demo-11",
project="sft-distillation",
base_model=STUDENT_BASE_MODEL,
)
await student.register(backend)

print(f"Training student model ({STUDENT_BASE_MODEL})...")
await student.train_sft(
[trajectory, trajectory, trajectory],
config=art.SFTConfig(learning_rate=2e-4),
verbose=True,
)
print("Training complete!")


if __name__ == "__main__":
asyncio.run(main())
31 changes: 31 additions & 0 deletions dev/sft/sft-from-file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""Simple SFT training script using train_sft_from_file helper."""

import asyncio

import art
from art.local import LocalBackend
from art.utils.sft import train_sft_from_file


async def main():
backend = LocalBackend()
model = art.TrainableModel(
name="run-5",
project="sft-from-file",
base_model="Qwen/Qwen2.5-7B-Instruct",
)
await model.register(backend)

await train_sft_from_file(
model=model,
file_path="dev/sft/dataset.jsonl",
epochs=1,
chunk_size=10,
peak_lr=2e-4,
)

print("Training complete!")


if __name__ == "__main__":
asyncio.run(main())
123 changes: 123 additions & 0 deletions dev/sft/sft-warmup-before-rl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
"""Minimal example demonstrating SFT -> RL -> SFT switching."""

import asyncio
import os

from dotenv import load_dotenv

import art
from art.local import LocalBackend

# Simple SFT trajectories - teach model to respond "maybe"
SFT_TRAJECTORIES = [
art.Trajectory(
messages_and_choices=[
{"role": "user", "content": "respond with yes or no"},
{"role": "assistant", "content": "maybe"},
],
reward=0.0, # reward unused for SFT
),
] * 10


async def rl_rollout(client, model_name: str, prompt: str) -> art.Trajectory:
"""Single RL rollout with reward based on response."""
messages: art.Messages = [{"role": "user", "content": prompt}]
completion = await client.chat.completions.create(
messages=messages, model=model_name, max_tokens=10, timeout=30
)
choice = completion.choices[0]
content = choice.message.content or ""

# Reward: "maybe" > "no" > "yes" > other
reward = {"maybe": 1.0, "no": 0.75, "yes": 0.5}.get(content.strip().lower(), 0.0)
return art.Trajectory(messages_and_choices=[*messages, choice], reward=reward)


async def main():
load_dotenv()

backend = LocalBackend()
model = art.TrainableModel(
name="sft-rl-switch-test-13",
project="sft-rl-demo",
base_model="Qwen/Qwen2.5-7B-Instruct",
)
await model.register(backend)

# ========================================================================
# Phase 1: SFT
# ========================================================================
print("\n[Phase 1] SFT training...")
await model.train_sft(
SFT_TRAJECTORIES,
config=art.SFTConfig(learning_rate=2e-6),
)
print("SFT phase 1 complete.")

# ========================================================================
# Phase 2: RL (GRPO)
# ========================================================================
print("\n[Phase 2] RL training...")
client = model.openai_client()
prompt = "respond with yes, no, or maybe"

for i in range(5):
print(f" RL step {i + 1}")
train_groups = await art.gather_trajectory_groups(
[
art.TrajectoryGroup(
rl_rollout(client, model.name, prompt) for _ in range(6)
)
for _ in range(12)
]
)
await model.train(train_groups, config=art.TrainConfig(learning_rate=1e-5))
print("RL phase 2 complete.")

# ========================================================================
# Phase 3: SFT again
# ========================================================================
print("\n[Phase 3] SFT training again...")
await model.train_sft(
SFT_TRAJECTORIES,
config=art.SFTConfig(batch_size=1, learning_rate=2e-6),
)
print("SFT phase 3 complete.")

# ========================================================================
# Phase 4: RL (GRPO) again
# ========================================================================
print("\n[Phase 4] RL training...")
client = model.openai_client()
prompt = "respond with yes, no, or maybe"

for i in range(5):
print(f" RL step {i + 1}")
train_groups = await art.gather_trajectory_groups(
[
art.TrajectoryGroup(
rl_rollout(client, model.name, prompt) for _ in range(6)
)
for _ in range(12)
]
)
await model.train(train_groups, config=art.TrainConfig(learning_rate=1e-5))
print("RL phase 4 complete.")

# ========================================================================
# Test: Check model output
# ========================================================================
print("\n[Test] Model output after training:")
completion = await client.chat.completions.create(
messages=[{"role": "user", "content": "respond with yes or no"}],
model=model.name,
max_tokens=10,
)
print(f"Response: {completion.choices[0].message.content}")

print("\nAll phases complete!")


if __name__ == "__main__":
asyncio.run(main())
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ dependencies = [
"openai>=2.14.0",
"typer>=0.15.2",
"litellm>=1.74.1",
"weave>=0.52.23",
"weave>=0.52.24",
"tinker>=0.7.0",
"tinker-cookbook>=0.1.0",
"polars>=1.26.0",
Expand All @@ -32,7 +32,7 @@ backend = [
"accelerate==1.7.0",
"awscli>=1.38.1",
"setuptools>=78.1.0",
"wandb==0.23.1",
"wandb==0.24.0",
"transformers>=4.55.2,<=4.57.3",
"duckdb>=1.0.0",
"pyarrow>=15.0.0",
Expand Down
2 changes: 2 additions & 0 deletions src/art/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def __init__(self, **kwargs):
Messages,
MessagesAndChoices,
ServerlessTrainResult,
SFTConfig,
Tools,
TrainConfig,
TrainResult,
Expand All @@ -89,6 +90,7 @@ def __init__(self, **kwargs):
"Model",
"TrainableModel",
"retry",
"SFTConfig",
"TrainConfig",
"TrainResult",
"TinkerBackend",
Expand Down
19 changes: 17 additions & 2 deletions src/art/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
)

from . import dev
from .trajectories import TrajectoryGroup
from .types import TrainConfig, TrainResult
from .trajectories import Trajectory, TrajectoryGroup
from .types import SFTConfig, TrainConfig, TrainResult

if TYPE_CHECKING:
from .model import Model, TrainableModel
Expand Down Expand Up @@ -149,6 +149,21 @@ async def _train_model(
if pbar is not None:
pbar.close()

async def _train_sft(
self,
model: "TrainableModel",
trajectories: Iterable[Trajectory],
config: SFTConfig,
dev_config: dev.SFTConfig,
verbose: bool = False,
) -> AsyncIterator[dict[str, float]]:
raise NotImplementedError(
"SFT training is not yet implemented. "
"This method will be available in a future release."
)
# This yield is unreachable but makes this an async generator
yield # type: ignore

# ------------------------------------------------------------------
# Experimental support for S3
# ------------------------------------------------------------------
Expand Down
3 changes: 2 additions & 1 deletion src/art/dev/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
TrainerArgs,
)
from .openai_server import OpenAIServerConfig, ServerArgs, get_openai_server_config
from .train import TrainConfig
from .train import SFTConfig, TrainConfig

__all__ = [
"EngineArgs",
Expand All @@ -21,5 +21,6 @@
"get_openai_server_config",
"OpenAIServerConfig",
"ServerArgs",
"SFTConfig",
"TrainConfig",
]
17 changes: 9 additions & 8 deletions src/art/dev/openai_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,26 +18,27 @@ def get_openai_server_config(
config = OpenAIServerConfig()
log_file = config.get("log_file", log_file)

# Extract step from lora_path for multi-checkpoint support
# lora_path format is: {output_dir}/checkpoints/{step:04d}
lora_name = model_name
# Build LoRA modules list for multi-checkpoint support
# Register under both model_name (for "current" model) and model_name@step (for specific checkpoint)
lora_modules: list[str] | None = None
if lora_path:
step = int(os.path.basename(lora_path))
lora_name = f"{model_name}@{step}"
lora_modules = [
f'{{"name": "{model_name}", "path": "{lora_path}"}}',
f'{{"name": "{model_name}@{step}", "path": "{lora_path}"}}',
]

server_args = ServerArgs(
api_key="default",
lora_modules=(
[f'{{"name": "{lora_name}", "path": "{lora_path}"}}'] if lora_path else None
),
lora_modules=lora_modules,
return_tokens_as_token_ids=True,
enable_auto_tool_choice=True,
tool_call_parser="hermes",
)
server_args.update(config.get("server_args", {}))
engine_args = EngineArgs(
model=base_model,
served_model_name=base_model if lora_path else model_name,
served_model_name=model_name,
generation_config="vllm",
)
engine_args.update(config.get("engine_args", {}))
Expand Down
14 changes: 14 additions & 0 deletions src/art/dev/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,17 @@ class TrainConfig(TypedDict, total=False):
scale_learning_rate_by_reward_std_dev: bool
scale_rewards: bool
truncated_importance_sampling: float | None


class SFTConfig(TypedDict, total=False):
"""Experimental SFT configuration options. Use at your own risk.

Undocumented options (may change):
instruction_part: Override auto-detected instruction marker for tokenization.
Used to identify where user turns begin in the chat template.
response_part: Override auto-detected response marker for tokenization.
Used to identify where assistant turns begin (train on responses only).
"""

instruction_part: str
response_part: str
Loading
Loading