From ce4ed8f90926b1956050d5086e5264031d752642 Mon Sep 17 00:00:00 2001 From: Cursor Bot Date: Tue, 20 Jan 2026 02:27:28 +0000 Subject: [PATCH 1/7] feat(tinker): add support for built-in loss functions and checkpoint control Add two new features to TinkerBackend: 1. Built-in loss functions (tinker_loss_fn, tinker_loss_fn_config) - Supports Tinker's optimized losses: importance_sampling, ppo, cispo, dro - Uses forward_backward_async instead of forward_backward_custom_async - ~1.5x fewer FLOPs, up to 3x faster (per Tinker docs) - Default behavior unchanged (uses ART's custom loss) 2. Checkpoint control (save_checkpoint parameter) - When False, only saves sampler weights (fast, for inference) - When True (default), saves full state + optimizer (for resumption) - Enables faster training when full checkpoints only needed at intervals Both features are backwards-compatible - existing code works unchanged. --- src/art/dev/train.py | 11 +++ src/art/tinker/backend.py | 167 +++++++++++++++++++++++++++++++++++++- src/art/tinker/service.py | 134 +++++++++++++++++++++++------- 3 files changed, 280 insertions(+), 32 deletions(-) diff --git a/src/art/dev/train.py b/src/art/dev/train.py index bd415074..2842f856 100644 --- a/src/art/dev/train.py +++ b/src/art/dev/train.py @@ -27,3 +27,14 @@ class TrainConfig(TypedDict, total=False): scale_learning_rate_by_reward_std_dev: bool scale_rewards: bool truncated_importance_sampling: float | None + + # Tinker built-in loss configuration (only used by TinkerBackend) + # When set, uses Tinker's optimized built-in loss instead of ART's custom loss + tinker_loss_fn: Literal["importance_sampling", "ppo", "cispo", "dro"] | None + tinker_loss_fn_config: dict[str, float] | None # e.g., {"clip_low_threshold": 0.0, "clip_high_threshold": 6.0} + + # Tinker checkpoint control (only used by TinkerBackend) + # When False, skips saving full checkpoint (state + optimizer) after training. + # Sampler weights are still saved for inference. Use this for faster training + # when you only need full checkpoints at specific intervals. + tinker_save_checkpoint: bool diff --git a/src/art/tinker/backend.py b/src/art/tinker/backend.py index e6a79aca..b48dcf80 100644 --- a/src/art/tinker/backend.py +++ b/src/art/tinker/backend.py @@ -1,11 +1,15 @@ import os +from typing import Iterable, Literal from mp_actors import move_to_child_process -from ..local.backend import LocalBackend +from .. 
import dev +from ..local.backend import LocalBackend, LocalTrainResult from ..local.service import ModelService from ..model import TrainableModel -from ..utils.output_dirs import get_model_dir +from ..trajectories import TrajectoryGroup +from ..types import TrainConfig +from ..utils.output_dirs import get_model_dir, get_step_checkpoint_dir class TinkerBackend(LocalBackend): @@ -24,6 +28,165 @@ def __init__( os.environ["TINKER_API_KEY"] = tinker_api_key super().__init__(in_process=in_process, path=path) + async def train( # type: ignore[override] + self, + model: TrainableModel, + trajectory_groups: Iterable[TrajectoryGroup], + *, + # Core training parameters + learning_rate: float = 5e-6, + beta: float = 0.0, + # RL algorithm settings (used by ART's custom loss when tinker_loss_fn is None) + ppo: bool = False, + epsilon: float | None = None, + epsilon_high: float | None = None, + # Advantage computation + advantage_balance: float = 0.0, + scale_rewards: bool = True, + # Importance sampling + importance_sampling_level: Literal[ + "token", "sequence", "average", "geometric_average" + ] = "token", + max_negative_advantage_importance_sampling_weight: float | None = None, + mask_prob_ratio: bool = False, + # Experimental parameters + kimi_k2_tau: float | None = None, + precalculate_logprobs: bool = False, + # LocalBackend-specific parameters + allow_training_without_logprobs: bool = False, + plot_tensors: bool = False, + truncated_importance_sampling: float | None = None, + scale_learning_rate_by_reward_std_dev: bool = False, + logprob_calculation_chunk_size: int = 1024, + num_trajectories_learning_rate_multiplier_power: float = 0.0, + # Checkpoint behavior + save_checkpoint: bool = True, + # Verbosity + verbose: bool = False, + # Tinker-specific: built-in loss function + tinker_loss_fn: Literal["importance_sampling", "ppo", "cispo", "dro"] + | None = None, + tinker_loss_fn_config: dict[str, float] | None = None, + ) -> LocalTrainResult: + """Train the model on trajectory groups, with optional Tinker built-in loss. + + When tinker_loss_fn is specified, uses Tinker's optimized built-in loss + function (e.g., "cispo", "ppo"). This is faster than ART's custom loss + (1.5x fewer FLOPs, up to 3x faster wall time). + + When tinker_loss_fn is None (default), uses ART's custom loss implementation, + which is compatible with other backends like LocalBackend. + + Args: + model: The trainable model to train. + trajectory_groups: Batches of trajectories to train on. + learning_rate: Learning rate for training. Defaults to 5e-6. + beta: KL penalty coefficient. Defaults to 0.0. + tinker_loss_fn: Tinker built-in loss function. Options: + - "importance_sampling": REINFORCE with importance sampling + - "ppo": Proximal Policy Optimization with clipping + - "cispo": Clipped Importance Sampling Policy Optimization + - "dro": Direct Reward Optimization + If None, uses ART's custom loss (controlled by ppo, epsilon, etc.) + tinker_loss_fn_config: Config dict for built-in loss, e.g.: + {"clip_low_threshold": 0.0, "clip_high_threshold": 6.0} + **other_args: See LocalBackend.train() for other parameters. + + Returns: + LocalTrainResult with step number, training metrics, and checkpoint path. 
+ + Example: + # Use Tinker's built-in CISPO (recommended for speed) + result = await backend.train( + model, + trajectory_groups, + learning_rate=5e-6, + tinker_loss_fn="cispo", + tinker_loss_fn_config={"clip_low_threshold": 0.0, "clip_high_threshold": 6.0}, + ) + + # Use ART's custom loss (default, for compatibility) + result = await backend.train( + model, + trajectory_groups, + learning_rate=5e-6, + ppo=False, + epsilon=1.0, + ) + """ + groups_list = list(trajectory_groups) + + # Build config objects from explicit kwargs + config = TrainConfig(learning_rate=learning_rate, beta=beta) + dev_config: dev.TrainConfig = { + "advantage_balance": advantage_balance, + "allow_training_without_logprobs": allow_training_without_logprobs, + "importance_sampling_level": importance_sampling_level, + "mask_prob_ratio": mask_prob_ratio, + "plot_tensors": plot_tensors, + "ppo": ppo, + "precalculate_logprobs": precalculate_logprobs, + "scale_learning_rate_by_reward_std_dev": scale_learning_rate_by_reward_std_dev, + "scale_rewards": scale_rewards, + "logprob_calculation_chunk_size": logprob_calculation_chunk_size, + "num_trajectories_learning_rate_multiplier_power": num_trajectories_learning_rate_multiplier_power, + } + # Only include optional fields if they're set + if epsilon is not None: + dev_config["epsilon"] = epsilon + if epsilon_high is not None: + dev_config["epsilon_high"] = epsilon_high + if max_negative_advantage_importance_sampling_weight is not None: + dev_config["max_negative_advantage_importance_sampling_weight"] = ( + max_negative_advantage_importance_sampling_weight + ) + if kimi_k2_tau is not None: + dev_config["kimi_k2_tau"] = kimi_k2_tau + if truncated_importance_sampling is not None: + dev_config["truncated_importance_sampling"] = truncated_importance_sampling + + # Tinker-specific: built-in loss function + if tinker_loss_fn is not None: + dev_config["tinker_loss_fn"] = tinker_loss_fn + if tinker_loss_fn_config is not None: + dev_config["tinker_loss_fn_config"] = tinker_loss_fn_config + + # Tinker-specific: checkpoint control + dev_config["tinker_save_checkpoint"] = save_checkpoint + + # Collect metrics from training + training_metrics: list[dict[str, float]] = [] + async for metrics in self._train_model( + model, groups_list, config, dev_config, verbose + ): + training_metrics.append(metrics) + + # Aggregate metrics + avg_metrics: dict[str, float] = {} + if training_metrics: + avg_metrics = { + k: sum(d.get(k, 0) for d in training_metrics) + / sum(1 for d in training_metrics if k in d) + for k in {k for d in training_metrics for k in d} + if k != "num_gradient_steps" + } + + # Get step and checkpoint path + step = await self._get_step(model) + checkpoint_path: str | None = None + if save_checkpoint: + checkpoint_path = get_step_checkpoint_dir( + get_model_dir(model=model, art_path=self._path), step + ) + if not os.path.exists(checkpoint_path): + checkpoint_path = None + + return LocalTrainResult( + step=step, + metrics=avg_metrics, + checkpoint_path=checkpoint_path, + ) + async def _get_service(self, model: TrainableModel) -> ModelService: from ..dev.get_model_config import get_model_config from ..dev.model import TinkerArgs diff --git a/src/art/tinker/service.py b/src/art/tinker/service.py index 1db653c0..70437edd 100644 --- a/src/art/tinker/service.py +++ b/src/art/tinker/service.py @@ -104,6 +104,10 @@ async def train( packed_tensors = packed_tensors_from_dir(**disk_packed_tensors) state = await self._state_task + # Check if using Tinker's built-in loss function + tinker_loss_fn 
= _config.get("tinker_loss_fn") + tinker_loss_fn_config = _config.get("tinker_loss_fn_config") + def custom_loss_fn( _: list[tinker.Datum], logprobs_list: list[torch.Tensor], @@ -134,35 +138,69 @@ def custom_loss_fn( ][0] ] ] - forward_backward_output_future = ( - await state.training_client.forward_backward_custom_async( - data=[ - tinker.Datum( - loss_fn_inputs={ - "target_tokens": tinker.TensorData.from_torch( - shifted_tokens[i][mask] - ), - "weights": tinker.TensorData.from_torch( - torch.ones_like( - shifted_tokens[i][mask], dtype=torch.float32 - ) - ), - }, - model_input=tinker.ModelInput.from_ints( - packed_tensors["tokens"][i][mask].tolist() + + if tinker_loss_fn: + # Use Tinker's optimized built-in loss function + # Build datums with logprobs and advantages for the built-in loss + datums = [ + tinker.Datum( + loss_fn_inputs={ + "target_tokens": tinker.TensorData.from_torch( + shifted_tokens[i][mask] + ), + "logprobs": tinker.TensorData.from_torch( + shift_tensor(packed_tensors["logprobs"][i], 0.0)[mask] ), - ) - for mask in masks - ], - loss_fn=partial( - custom_loss_fn, - masks=masks, - inputs=create_train_inputs( - packed_tensors, i, config, _config, False + "advantages": tinker.TensorData.from_torch( + shift_tensor(packed_tensors["advantages"][i], 0.0)[mask] + ), + }, + model_input=tinker.ModelInput.from_ints( + packed_tensors["tokens"][i][mask].tolist() ), - ), + ) + for mask in masks + ] + forward_backward_output_future = ( + await state.training_client.forward_backward_async( + data=datums, + loss_fn=tinker_loss_fn, + loss_fn_config=tinker_loss_fn_config, + ) ) - ) + else: + # Use ART's custom loss function (default behavior) + datums = [ + tinker.Datum( + loss_fn_inputs={ + "target_tokens": tinker.TensorData.from_torch( + shifted_tokens[i][mask] + ), + "weights": tinker.TensorData.from_torch( + torch.ones_like( + shifted_tokens[i][mask], dtype=torch.float32 + ) + ), + }, + model_input=tinker.ModelInput.from_ints( + packed_tensors["tokens"][i][mask].tolist() + ), + ) + for mask in masks + ] + forward_backward_output_future = ( + await state.training_client.forward_backward_custom_async( + data=datums, + loss_fn=partial( + custom_loss_fn, + masks=masks, + inputs=create_train_inputs( + packed_tensors, i, config, _config, False + ), + ), + ) + ) + optim_step_future = await state.training_client.optim_step_async( adam_params=tinker.AdamParams(learning_rate=config.learning_rate), ) @@ -173,13 +211,26 @@ def custom_loss_fn( **forward_backward_output.metrics, **(optim_step_response.metrics or {}), } + + # Save checkpoint or just sampler weights based on config last_checkpoint_dir = self._get_last_checkpoint_dir() assert last_checkpoint_dir is not None, "No checkpoint found" next_step = int(last_checkpoint_dir.name) + 1 - new_sampler_client = await self._save_checkpoint( - last_checkpoint_dir.with_name(f"{next_step:04d}"), - state.training_client, - ) + + save_checkpoint = _config.get("tinker_save_checkpoint", True) + if save_checkpoint: + # Full checkpoint: saves training state + optimizer + sampler weights + new_sampler_client = await self._save_checkpoint( + last_checkpoint_dir.with_name(f"{next_step:04d}"), + state.training_client, + ) + else: + # Fast path: only save sampler weights for inference + new_sampler_client = await self._save_sampler_weights_only( + next_step, + state.training_client, + ) + # Add new sampler client to the dict and update latest step state.sampler_clients[next_step] = new_sampler_client state.latest_step = next_step @@ -274,6 +325,10 @@ def 
_get_last_checkpoint_dir(self) -> Path | None: async def _save_checkpoint( self, checkpoint_dir: Path, training_client: tinker.TrainingClient ) -> tinker.SamplingClient: + """Save full checkpoint (training state + optimizer + sampler weights). + + This is slower but enables full resumption of training. + """ with log_timing("Saving Tinker checkpoint"): state_response, sampler_response = await asyncio.gather( *await asyncio.gather( @@ -296,6 +351,25 @@ async def _save_checkpoint( ) return sampling_client + async def _save_sampler_weights_only( + self, step: int, training_client: tinker.TrainingClient + ) -> tinker.SamplingClient: + """Save only sampler weights (fast, for inference only). + + This is faster but does NOT save optimizer state, so training cannot + be resumed from this step. Use this when you only need the model for + inference and will save full checkpoints at specific intervals. + """ + with log_timing("Saving sampler weights"): + sampler_response = await ( + await training_client.save_weights_for_sampler_async(f"{step:04d}") + ) + with log_timing("Creating Tinker sampling client"): + sampling_client = await training_client.create_sampling_client_async( + model_path=sampler_response.path + ) + return sampling_client + async def _run_openai_server( self, config: dev.OpenAIServerConfig | None, state: "TinkerState" ) -> None: From 8dc64b485b8f073b4737fe22dc743389abafaaf1 Mon Sep 17 00:00:00 2001 From: Cursor Bot Date: Tue, 20 Jan 2026 02:52:37 +0000 Subject: [PATCH 2/7] style: fix ruff formatting in dev/train.py --- src/art/dev/train.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/art/dev/train.py b/src/art/dev/train.py index 2842f856..7fcf3109 100644 --- a/src/art/dev/train.py +++ b/src/art/dev/train.py @@ -31,7 +31,9 @@ class TrainConfig(TypedDict, total=False): # Tinker built-in loss configuration (only used by TinkerBackend) # When set, uses Tinker's optimized built-in loss instead of ART's custom loss tinker_loss_fn: Literal["importance_sampling", "ppo", "cispo", "dro"] | None - tinker_loss_fn_config: dict[str, float] | None # e.g., {"clip_low_threshold": 0.0, "clip_high_threshold": 6.0} + tinker_loss_fn_config: ( + dict[str, float] | None + ) # e.g., {"clip_low_threshold": 0.0, "clip_high_threshold": 6.0} # Tinker checkpoint control (only used by TinkerBackend) # When False, skips saving full checkpoint (state + optimizer) after training. From 415e7ac92587b98a80f5de9abda6aa6dec38e5dd Mon Sep 17 00:00:00 2001 From: Cursor Bot Date: Tue, 20 Jan 2026 02:58:21 +0000 Subject: [PATCH 3/7] feat: add adam_beta1, adam_beta2, adam_eps params to TinkerBackend.train() - Add optional adam_beta1, adam_beta2, adam_eps parameters to train() - Pass through to TinkerService via dev_config - Use params when calling optim_step_async with tinker.AdamParams This allows customization of Adam optimizer hyperparameters, which is needed when using non-default values (e.g., beta2=0.95 instead of 0.999). 
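
For reference, a minimal usage sketch (assumes a TinkerBackend, model, and
trajectory_groups are already set up; the values mirror the docstring example
added in this patch and are illustrative, not recommendations):

    result = await backend.train(
        model,
        trajectory_groups,
        learning_rate=5e-6,
        adam_beta1=0.9,   # Tinker default
        adam_beta2=0.95,  # overrides Tinker's default of 0.999
        adam_eps=1e-8,    # Tinker default
    )

Parameters left as None fall back to Tinker's own Adam defaults.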
--- src/art/tinker/backend.py | 20 +++++++++++++++++++- src/art/tinker/service.py | 10 +++++++++- 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/src/art/tinker/backend.py b/src/art/tinker/backend.py index b48dcf80..4621db39 100644 --- a/src/art/tinker/backend.py +++ b/src/art/tinker/backend.py @@ -67,6 +67,10 @@ async def train( # type: ignore[override] tinker_loss_fn: Literal["importance_sampling", "ppo", "cispo", "dro"] | None = None, tinker_loss_fn_config: dict[str, float] | None = None, + # Adam optimizer parameters + adam_beta1: float | None = None, + adam_beta2: float | None = None, + adam_eps: float | None = None, ) -> LocalTrainResult: """Train the model on trajectory groups, with optional Tinker built-in loss. @@ -90,19 +94,25 @@ async def train( # type: ignore[override] If None, uses ART's custom loss (controlled by ppo, epsilon, etc.) tinker_loss_fn_config: Config dict for built-in loss, e.g.: {"clip_low_threshold": 0.0, "clip_high_threshold": 6.0} + adam_beta1: Adam optimizer beta1 parameter. Defaults to Tinker default (0.9). + adam_beta2: Adam optimizer beta2 parameter. Defaults to Tinker default (0.999). + adam_eps: Adam optimizer epsilon parameter. Defaults to Tinker default (1e-8). **other_args: See LocalBackend.train() for other parameters. Returns: LocalTrainResult with step number, training metrics, and checkpoint path. Example: - # Use Tinker's built-in CISPO (recommended for speed) + # Use Tinker's built-in CISPO with custom Adam params result = await backend.train( model, trajectory_groups, learning_rate=5e-6, tinker_loss_fn="cispo", tinker_loss_fn_config={"clip_low_threshold": 0.0, "clip_high_threshold": 6.0}, + adam_beta1=0.9, + adam_beta2=0.95, # Custom beta2 + adam_eps=1e-8, ) # Use ART's custom loss (default, for compatibility) @@ -154,6 +164,14 @@ async def train( # type: ignore[override] # Tinker-specific: checkpoint control dev_config["tinker_save_checkpoint"] = save_checkpoint + # Tinker-specific: Adam optimizer parameters + if adam_beta1 is not None: + dev_config["adam_beta1"] = adam_beta1 + if adam_beta2 is not None: + dev_config["adam_beta2"] = adam_beta2 + if adam_eps is not None: + dev_config["adam_eps"] = adam_eps + # Collect metrics from training training_metrics: list[dict[str, float]] = [] async for metrics in self._train_model( diff --git a/src/art/tinker/service.py b/src/art/tinker/service.py index 70437edd..58605d9c 100644 --- a/src/art/tinker/service.py +++ b/src/art/tinker/service.py @@ -201,8 +201,16 @@ def custom_loss_fn( ) ) + # Build Adam params with optional beta1/beta2/eps from config + adam_params = tinker.AdamParams(learning_rate=config.learning_rate) + if "adam_beta1" in _config: + adam_params.beta1 = _config["adam_beta1"] + if "adam_beta2" in _config: + adam_params.beta2 = _config["adam_beta2"] + if "adam_eps" in _config: + adam_params.eps = _config["adam_eps"] optim_step_future = await state.training_client.optim_step_async( - adam_params=tinker.AdamParams(learning_rate=config.learning_rate), + adam_params=adam_params, ) forward_backward_output, optim_step_response = await asyncio.gather( forward_backward_output_future, optim_step_future From bd00982d163a392c331395401d61591bbc3205e5 Mon Sep 17 00:00:00 2001 From: Cursor Bot Date: Wed, 21 Jan 2026 19:23:11 +0000 Subject: [PATCH 4/7] Add Adam optimizer params to TrainConfig TypedDict Add adam_beta1, adam_beta2, and adam_eps to fix Pyright type errors when assigning these keys to the dev_config dict. 
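
A minimal sketch of the Pyright issue this resolves (hypothetical standalone
snippet; in the real code the assignments live in TinkerBackend.train(), where
`dev` is already imported via `from .. import dev`):

    dev_config: dev.TrainConfig = {}
    # These keys were previously not declared on the total=False TypedDict,
    # so Pyright flagged the assignments; declaring them lets the optional
    # fields type-check without changing runtime behavior.
    dev_config["adam_beta1"] = 0.9
    dev_config["adam_beta2"] = 0.95
    dev_config["adam_eps"] = 1e-8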
--- src/art/dev/train.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/art/dev/train.py b/src/art/dev/train.py index 7fcf3109..ead24ab2 100644 --- a/src/art/dev/train.py +++ b/src/art/dev/train.py @@ -40,3 +40,8 @@ class TrainConfig(TypedDict, total=False): # Sampler weights are still saved for inference. Use this for faster training # when you only need full checkpoints at specific intervals. tinker_save_checkpoint: bool + + # Adam optimizer parameters (only used by TinkerBackend) + adam_beta1: float + adam_beta2: float + adam_eps: float From 888e5991bd00d4b4edc617e7ad7c1345360730e6 Mon Sep 17 00:00:00 2001 From: Cursor Bot Date: Thu, 22 Jan 2026 00:33:38 +0000 Subject: [PATCH 5/7] fix: handle 1D tensors in shift_tensor and NaN logprobs in Tinker API - Update shift_tensor to support both 1D and 2D tensors - Replace NaN values in logprobs before JSON serialization to Tinker API - Guard Qwen3InstructRenderer patch for older tinker_cookbook versions --- src/art/loss.py | 11 ++++++++++- src/art/tinker/service.py | 32 ++++++++++++++++++++------------ 2 files changed, 30 insertions(+), 13 deletions(-) diff --git a/src/art/loss.py b/src/art/loss.py index dda53958..570272cc 100644 --- a/src/art/loss.py +++ b/src/art/loss.py @@ -143,4 +143,13 @@ def loss_fn( def shift_tensor(tensor: torch.Tensor, pad: int | float | bool) -> torch.Tensor: - return torch.nn.functional.pad(tensor[:, 1:], (0, 1), value=pad) + """Shift tensor left by 1 position, padding the right with `pad`. + + Handles both 1D tensors (sequence) and 2D tensors (batch x sequence). + """ + if tensor.ndim == 1: + # 1D tensor: just shift and pad + return torch.nn.functional.pad(tensor[1:], (0, 1), value=pad) + else: + # 2D tensor: shift along sequence dimension + return torch.nn.functional.pad(tensor[:, 1:], (0, 1), value=pad) diff --git a/src/art/tinker/service.py b/src/art/tinker/service.py index 58605d9c..dec79a22 100644 --- a/src/art/tinker/service.py +++ b/src/art/tinker/service.py @@ -36,17 +36,19 @@ packed_tensors_from_dir, ) -# Patch Tinker's Qwen3InstructRenderer which mistakenly expects "args" instead of "arguments" in tool calls. -_parse_tool_call = renderers.Qwen3InstructRenderer._parse_tool_call - - -def _patched_parse_tool_call( - self, tool_call_str: str -) -> list[renderers.ToolCall] | None: - return _parse_tool_call(self, tool_call_str.replace('"arguments": ', '"args": ')) - +# Patch Tinker's Qwen3InstructRenderer which mistakenly expects "args" instead of +# "arguments" in tool calls. Guard against older renderers that lack this class. 
+if hasattr(renderers, "Qwen3InstructRenderer"): + _parse_tool_call = renderers.Qwen3InstructRenderer._parse_tool_call + + def _patched_parse_tool_call( + self, tool_call_str: str + ) -> list[renderers.ToolCall] | None: + return _parse_tool_call( + self, tool_call_str.replace('"arguments": ', '"args": ') + ) -renderers.Qwen3InstructRenderer._parse_tool_call = _patched_parse_tool_call + renderers.Qwen3InstructRenderer._parse_tool_call = _patched_parse_tool_call @contextmanager @@ -142,6 +144,12 @@ def custom_loss_fn( if tinker_loss_fn: # Use Tinker's optimized built-in loss function # Build datums with logprobs and advantages for the built-in loss + # Note: logprobs may contain NaN for padded positions; replace with 0.0 + # for JSON serialization (Tinker's API requires valid floats) + shifted_logprobs = torch.nan_to_num( + shift_tensor(packed_tensors["logprobs"][i], 0.0), nan=0.0 + ) + shifted_advantages = shift_tensor(packed_tensors["advantages"][i], 0.0) datums = [ tinker.Datum( loss_fn_inputs={ @@ -149,10 +157,10 @@ def custom_loss_fn( shifted_tokens[i][mask] ), "logprobs": tinker.TensorData.from_torch( - shift_tensor(packed_tensors["logprobs"][i], 0.0)[mask] + shifted_logprobs[mask] ), "advantages": tinker.TensorData.from_torch( - shift_tensor(packed_tensors["advantages"][i], 0.0)[mask] + shifted_advantages[mask] ), }, model_input=tinker.ModelInput.from_ints( From 044f7210bd2692ad7c45aa29c572eab4682e2bee Mon Sep 17 00:00:00 2001 From: Cursor Bot Date: Thu, 22 Jan 2026 01:24:43 +0000 Subject: [PATCH 6/7] fix: ensure TinkerBackend uses same port for server and client Previously, if port 8000 was already in use, the server would bind to a different port via get_free_port() but the client would still try to connect to port 8000, causing connection failures. Now the port is determined once upfront and passed to both the server and client. 
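
In simplified form, the new flow looks like this (excerpt-style sketch using
names from the service code, not a standalone program; error handling and the
readiness poll are omitted):

    _config = config or {}
    host = _config.get("host", "0.0.0.0")
    port = _config.get("port") or get_free_port()  # decided exactly once

    # Both the server task and the client now agree on the same address.
    self._openai_server_task = asyncio.create_task(
        self._run_openai_server(host, port, await self._state_task)
    )
    client = AsyncOpenAI(base_url=f"http://{host}:{port}/v1")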
--- src/art/tinker/service.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/art/tinker/service.py b/src/art/tinker/service.py index dec79a22..8f0d7d27 100644 --- a/src/art/tinker/service.py +++ b/src/art/tinker/service.py @@ -69,12 +69,15 @@ class TinkerService: _openai_server_task: asyncio.Task[None] | None = None async def start_openai_server(self, config: dev.OpenAIServerConfig | None) -> None: + # Determine the port upfront so both server and client use the same port + _config = config or {} + host = _config.get("host", "0.0.0.0") + port = _config.get("port") or get_free_port() + self._openai_server_task = asyncio.create_task( - self._run_openai_server(config, await self._state_task) - ) - client = AsyncOpenAI( - base_url=f"http://{(config or {}).get('host', '0.0.0.0')}:{(config or {}).get('port', 8000)}/v1" + self._run_openai_server(host, port, await self._state_task) ) + client = AsyncOpenAI(base_url=f"http://{host}:{port}/v1") with log_timing("Waiting for server"): start = time.time() while True: @@ -387,9 +390,8 @@ async def _save_sampler_weights_only( return sampling_client async def _run_openai_server( - self, config: dev.OpenAIServerConfig | None, state: "TinkerState" + self, host: str, port: int, state: "TinkerState" ) -> None: - config = config or {} app = FastAPI() @app.get("/metrics") @@ -495,8 +497,8 @@ async def chat_completions( server_config = uvicorn.Config( app, - host=config.get("host", "0.0.0.0"), - port=config.get("port", get_free_port()), + host=host, + port=port, log_level="error", ) server = uvicorn.Server(server_config) From 86984a3b23fec9de55284cfb792e94ebbe417f2d Mon Sep 17 00:00:00 2001 From: Cursor Bot Date: Thu, 22 Jan 2026 20:06:28 +0000 Subject: [PATCH 7/7] partial changes, unclear if helpful --- src/art/tinker/service.py | 78 ++++++++++++++++++++++----------------- 1 file changed, 44 insertions(+), 34 deletions(-) diff --git a/src/art/tinker/service.py b/src/art/tinker/service.py index 8f0d7d27..de01a8d8 100644 --- a/src/art/tinker/service.py +++ b/src/art/tinker/service.py @@ -14,10 +14,6 @@ from openai import AsyncOpenAI from openai.types.chat.chat_completion import ChatCompletion, Choice, ChoiceLogprobs from openai.types.chat.chat_completion_message import ChatCompletionMessage -from openai.types.chat.chat_completion_message_function_tool_call import ( - ChatCompletionMessageFunctionToolCall, - Function, -) from openai.types.chat.chat_completion_token_logprob import ChatCompletionTokenLogprob from openai.types.chat.completion_create_params import CompletionCreateParams from openai.types.completion_usage import CompletionUsage @@ -424,20 +420,30 @@ async def chat_completions( add_generation_prompt=True, ) ) - sample_response = await sampler_client.sample_async( - prompt=prompt, - num_samples=body.get("n") or 1, - sampling_params=tinker.SamplingParams( - max_tokens=body.get("max_completion_tokens") - or body.get("max_tokens"), - seed=body.get("seed"), - temperature=t - if (t := body.get("temperature")) is not None - else 1.0, - top_k=body.get("top_k") or -1, - top_p=body.get("top_p") or 1.0, - ), - ) + try: + sample_response = await sampler_client.sample_async( + prompt=prompt, + num_samples=body.get("n") or 1, + sampling_params=tinker.SamplingParams( + max_tokens=body.get("max_completion_tokens") + or body.get("max_tokens"), + seed=body.get("seed"), + temperature=t + if (t := body.get("temperature")) is not None + else 1.0, + top_k=body.get("top_k") or -1, + top_p=body.get("top_p") or 1.0, + ), + ) + except 
Exception as e: + max_tokens = body.get("max_completion_tokens") or body.get("max_tokens") + print( + "[tinker-service] sample_async error " + f"model={model_name} step={step} " + f"prompt_tokens={prompt.length} n={body.get('n') or 1} " + f"max_tokens={max_tokens} error={e}" + ) + raise choices: list[Choice] = [] for i, sequence in enumerate(sample_response.sequences): assert sequence.logprobs is not None, "Logprobs are required" @@ -445,26 +451,30 @@ async def chat_completions( "Tokens and logprobs must have the same length" ) message, _ = state.renderer.parse_response(sequence.tokens) + # Convert to OpenAI format - handles list content, tool_calls, reasoning_content + openai_message = state.renderer.to_openai_message(message) + + # Ensure tool_calls are valid for OpenAI schema (id must be a string) + tool_calls = openai_message.get("tool_calls") + if tool_calls: + sanitized_tool_calls = [] + for idx, tool_call in enumerate(tool_calls): + if not isinstance(tool_call, dict): + continue + tool_call = dict(tool_call) + if not tool_call.get("id"): + tool_call["id"] = f"call_{idx}" + sanitized_tool_calls.append(tool_call) + if sanitized_tool_calls: + openai_message["tool_calls"] = sanitized_tool_calls + else: + openai_message.pop("tool_calls", None) + choices.append( Choice( finish_reason=sequence.stop_reason, index=i, - message=ChatCompletionMessage( - content=message["content"], - role="assistant", - tool_calls=[ - ChatCompletionMessageFunctionToolCall( - type="function", - id=tool_call.id or "", - function=Function( - name=tool_call.function.name, - arguments=tool_call.function.arguments, - ), - ) - for tool_call in message.get("tool_calls", []) - ] - or None, - ), + message=ChatCompletionMessage(**openai_message), logprobs=ChoiceLogprobs( content=[ ChatCompletionTokenLogprob(