18 changes: 18 additions & 0 deletions src/art/dev/train.py
@@ -27,3 +27,21 @@ class TrainConfig(TypedDict, total=False):
scale_learning_rate_by_reward_std_dev: bool
scale_rewards: bool
truncated_importance_sampling: float | None

# Tinker built-in loss configuration (only used by TinkerBackend)
# When set, uses Tinker's optimized built-in loss instead of ART's custom loss
tinker_loss_fn: Literal["importance_sampling", "ppo", "cispo", "dro"] | None
tinker_loss_fn_config: (
dict[str, float] | None
) # e.g., {"clip_low_threshold": 0.0, "clip_high_threshold": 6.0}

# Tinker checkpoint control (only used by TinkerBackend)
# When False, skips saving full checkpoint (state + optimizer) after training.
# Sampler weights are still saved for inference. Use this for faster training
# when you only need full checkpoints at specific intervals.
tinker_save_checkpoint: bool

# Adam optimizer parameters (only used by TinkerBackend)
adam_beta1: float
adam_beta2: float
adam_eps: float
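For orientation, a minimal sketch of a dev config that exercises the new fields (field names come from the diff above; the `art.dev` import path and the concrete values are illustrative assumptions, not part of this change):

```python
# Illustrative sketch only; import path and values are assumptions.
from art import dev

dev_config: dev.TrainConfig = {
    # Use Tinker's built-in CISPO loss instead of ART's custom loss
    "tinker_loss_fn": "cispo",
    "tinker_loss_fn_config": {"clip_low_threshold": 0.0, "clip_high_threshold": 6.0},
    # Skip the full (state + optimizer) checkpoint after this step;
    # sampler weights are still saved for inference
    "tinker_save_checkpoint": False,
    # Override Tinker's Adam defaults
    "adam_beta1": 0.9,
    "adam_beta2": 0.95,
    "adam_eps": 1e-8,
}
```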
11 changes: 10 additions & 1 deletion src/art/loss.py
@@ -143,4 +143,13 @@ def loss_fn(


def shift_tensor(tensor: torch.Tensor, pad: int | float | bool) -> torch.Tensor:
return torch.nn.functional.pad(tensor[:, 1:], (0, 1), value=pad)
"""Shift tensor left by 1 position, padding the right with `pad`.

Handles both 1D tensors (sequence) and 2D tensors (batch x sequence).
"""
if tensor.ndim == 1:
# 1D tensor: just shift and pad
return torch.nn.functional.pad(tensor[1:], (0, 1), value=pad)
else:
# 2D tensor: shift along sequence dimension
return torch.nn.functional.pad(tensor[:, 1:], (0, 1), value=pad)
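A small sanity check of the new two-branch behavior (the import path is assumed from the file location; the values are arbitrary):

```python
import torch

from art.loss import shift_tensor  # import path assumed from src/art/loss.py

# 1D sequence: shift left by one, pad the right with the given value
t1 = torch.tensor([1, 2, 3])
assert torch.equal(shift_tensor(t1, 0), torch.tensor([2, 3, 0]))

# 2D batch x sequence: each row shifts along the sequence dimension
t2 = torch.tensor([[1, 2, 3], [4, 5, 6]])
assert torch.equal(shift_tensor(t2, 0), torch.tensor([[2, 3, 0], [5, 6, 0]]))
```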
185 changes: 183 additions & 2 deletions src/art/tinker/backend.py
@@ -1,11 +1,15 @@
import os
from typing import Iterable, Literal

from mp_actors import move_to_child_process

from ..local.backend import LocalBackend
from .. import dev
from ..local.backend import LocalBackend, LocalTrainResult
from ..local.service import ModelService
from ..model import TrainableModel
from ..utils.output_dirs import get_model_dir
from ..trajectories import TrajectoryGroup
from ..types import TrainConfig
from ..utils.output_dirs import get_model_dir, get_step_checkpoint_dir


class TinkerBackend(LocalBackend):
@@ -24,6 +28,183 @@ def __init__(
os.environ["TINKER_API_KEY"] = tinker_api_key
super().__init__(in_process=in_process, path=path)

async def train( # type: ignore[override]
self,
model: TrainableModel,
trajectory_groups: Iterable[TrajectoryGroup],
*,
# Core training parameters
learning_rate: float = 5e-6,
beta: float = 0.0,
# RL algorithm settings (used by ART's custom loss when tinker_loss_fn is None)
ppo: bool = False,
epsilon: float | None = None,
epsilon_high: float | None = None,
# Advantage computation
advantage_balance: float = 0.0,
scale_rewards: bool = True,
# Importance sampling
importance_sampling_level: Literal[
"token", "sequence", "average", "geometric_average"
] = "token",
max_negative_advantage_importance_sampling_weight: float | None = None,
mask_prob_ratio: bool = False,
# Experimental parameters
kimi_k2_tau: float | None = None,
precalculate_logprobs: bool = False,
# LocalBackend-specific parameters
allow_training_without_logprobs: bool = False,
plot_tensors: bool = False,
truncated_importance_sampling: float | None = None,
scale_learning_rate_by_reward_std_dev: bool = False,
logprob_calculation_chunk_size: int = 1024,
num_trajectories_learning_rate_multiplier_power: float = 0.0,
# Checkpoint behavior
save_checkpoint: bool = True,
# Verbosity
verbose: bool = False,
# Tinker-specific: built-in loss function
tinker_loss_fn: Literal["importance_sampling", "ppo", "cispo", "dro"]
| None = None,
tinker_loss_fn_config: dict[str, float] | None = None,
# Adam optimizer parameters
adam_beta1: float | None = None,
adam_beta2: float | None = None,
adam_eps: float | None = None,
) -> LocalTrainResult:
"""Train the model on trajectory groups, with optional Tinker built-in loss.

When tinker_loss_fn is specified, uses Tinker's optimized built-in loss
function (e.g., "cispo", "ppo"). This is faster than ART's custom loss
(1.5x fewer FLOPs, up to 3x faster wall time).

When tinker_loss_fn is None (default), uses ART's custom loss implementation,
which is compatible with other backends like LocalBackend.

Args:
model: The trainable model to train.
trajectory_groups: Batches of trajectories to train on.
learning_rate: Learning rate for training. Defaults to 5e-6.
beta: KL penalty coefficient. Defaults to 0.0.
tinker_loss_fn: Tinker built-in loss function. Options:
- "importance_sampling": REINFORCE with importance sampling
- "ppo": Proximal Policy Optimization with clipping
- "cispo": Clipped Importance Sampling Policy Optimization
- "dro": Direct Reward Optimization
If None, uses ART's custom loss (controlled by ppo, epsilon, etc.)
tinker_loss_fn_config: Config dict for built-in loss, e.g.:
{"clip_low_threshold": 0.0, "clip_high_threshold": 6.0}
adam_beta1: Adam optimizer beta1 parameter. Defaults to Tinker default (0.9).
adam_beta2: Adam optimizer beta2 parameter. Defaults to Tinker default (0.999).
adam_eps: Adam optimizer epsilon parameter. Defaults to Tinker default (1e-8).
See LocalBackend.train() for documentation of the remaining parameters.

Returns:
LocalTrainResult with step number, training metrics, and checkpoint path.

Example:
# Use Tinker's built-in CISPO with custom Adam params
result = await backend.train(
model,
trajectory_groups,
learning_rate=5e-6,
tinker_loss_fn="cispo",
tinker_loss_fn_config={"clip_low_threshold": 0.0, "clip_high_threshold": 6.0},
adam_beta1=0.9,
adam_beta2=0.95, # Custom beta2
adam_eps=1e-8,
)

# Use ART's custom loss (default, for compatibility)
result = await backend.train(
model,
trajectory_groups,
learning_rate=5e-6,
ppo=False,
epsilon=1.0,
)
"""
groups_list = list(trajectory_groups)

# Build config objects from explicit kwargs
config = TrainConfig(learning_rate=learning_rate, beta=beta)
dev_config: dev.TrainConfig = {
"advantage_balance": advantage_balance,
"allow_training_without_logprobs": allow_training_without_logprobs,
"importance_sampling_level": importance_sampling_level,
"mask_prob_ratio": mask_prob_ratio,
"plot_tensors": plot_tensors,
"ppo": ppo,
"precalculate_logprobs": precalculate_logprobs,
"scale_learning_rate_by_reward_std_dev": scale_learning_rate_by_reward_std_dev,
"scale_rewards": scale_rewards,
"logprob_calculation_chunk_size": logprob_calculation_chunk_size,
"num_trajectories_learning_rate_multiplier_power": num_trajectories_learning_rate_multiplier_power,
}
# Only include optional fields if they're set
if epsilon is not None:
dev_config["epsilon"] = epsilon
if epsilon_high is not None:
dev_config["epsilon_high"] = epsilon_high
if max_negative_advantage_importance_sampling_weight is not None:
dev_config["max_negative_advantage_importance_sampling_weight"] = (
max_negative_advantage_importance_sampling_weight
)
if kimi_k2_tau is not None:
dev_config["kimi_k2_tau"] = kimi_k2_tau
if truncated_importance_sampling is not None:
dev_config["truncated_importance_sampling"] = truncated_importance_sampling

# Tinker-specific: built-in loss function
if tinker_loss_fn is not None:
dev_config["tinker_loss_fn"] = tinker_loss_fn
if tinker_loss_fn_config is not None:
dev_config["tinker_loss_fn_config"] = tinker_loss_fn_config

# Tinker-specific: checkpoint control
dev_config["tinker_save_checkpoint"] = save_checkpoint

# Tinker-specific: Adam optimizer parameters
if adam_beta1 is not None:
dev_config["adam_beta1"] = adam_beta1
if adam_beta2 is not None:
dev_config["adam_beta2"] = adam_beta2
if adam_eps is not None:
dev_config["adam_eps"] = adam_eps

# Collect metrics from training
training_metrics: list[dict[str, float]] = []
async for metrics in self._train_model(
model, groups_list, config, dev_config, verbose
):
training_metrics.append(metrics)

# Aggregate metrics
avg_metrics: dict[str, float] = {}
if training_metrics:
avg_metrics = {
k: sum(d.get(k, 0) for d in training_metrics)
/ sum(1 for d in training_metrics if k in d)
for k in {k for d in training_metrics for k in d}
if k != "num_gradient_steps"
}

# Get step and checkpoint path
step = await self._get_step(model)
checkpoint_path: str | None = None
if save_checkpoint:
checkpoint_path = get_step_checkpoint_dir(
get_model_dir(model=model, art_path=self._path), step
)
if not os.path.exists(checkpoint_path):
checkpoint_path = None

return LocalTrainResult(
step=step,
metrics=avg_metrics,
checkpoint_path=checkpoint_path,
)

async def _get_service(self, model: TrainableModel) -> ModelService:
from ..dev.get_model_config import get_model_config
from ..dev.model import TinkerArgs
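For the checkpoint-control path specifically, a hedged usage sketch (constructor arguments are inferred from the partial `__init__` shown above; `model` and `trajectory_groups` are assumed to already exist):

```python
# Sketch only: argument names inferred from the diff, values illustrative.
backend = TinkerBackend(tinker_api_key="...", in_process=False, path="./.art")

# Fast intermediate step: sampler weights are still saved for inference,
# but the full (state + optimizer) checkpoint is skipped.
result = await backend.train(
    model,
    trajectory_groups,
    learning_rate=5e-6,
    tinker_loss_fn="cispo",
    save_checkpoint=False,
)
assert result.checkpoint_path is None  # no full checkpoint was written this step
```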